divisor.dimoo.inference_mmu

Text understanding inference script

 1# SPDX-License-Identifier: Apache-2.0
 2# Adapted from https://github.com/Alpha-VLLM/Lumina-DiMOO
 3
 4"""
 5Text understanding inference script
 6"""
 7
 8import argparse
 9import os
10import sys
11import time
12
13import torch
14from huggingface_hub import snapshot_download
15from transformers import AutoTokenizer
16
17from divisor.dimoo.config import SPECIAL_TOKENS
18from divisor.dimoo.prompt_utils import generate_text_prompt
19from divisor.dimoo.text_understanding_generator import generate_text_understanding
20from divisor.mmada.modeling_llada import LLaDAModelLM
21from divisor.registry import gfx_device
22
# NOTE(review): this executes after the `divisor.*` imports above, so it cannot
# influence how those imports resolve. It appends the grandparent directory of
# this file to sys.path (for divisor/dimoo/inference_mmu.py that is the
# `divisor` directory). Likely a leftover from the upstream standalone script —
# confirm nothing still relies on it before removing.
sys.path.append(os.path.split(os.path.split(__file__)[0])[0])
24
25
def main():
    """Run Lumina-DiMOO text understanding on one prompt and print the answer.

    Steps:
      1. Parse CLI arguments (``--prompt`` is required).
      2. Load the tokenizer and ``LLaDAModelLM`` weights from ``--model_id``,
         downloading ``Alpha-VLLM/Lumina-DiMOO`` only when it is omitted.
      3. Wrap the prompt with ``generate_text_prompt``, tokenize it, and append
         ``[BOA] + gen_length * [MASK] + [EOA]`` as the span to be filled.
      4. Call ``generate_text_understanding`` to fill that span, decode the
         tokens between BOA and EOA, and print them with the elapsed time.
    """
    parser = argparse.ArgumentParser(description="Text understanding inference")
    # The default is resolved after parsing. Passing snapshot_download() as the
    # argparse default would download the repo while the parser is being built,
    # even when the user supplies --model_id explicitly.
    parser.add_argument(
        "--model_id",
        type=str,
        required=False,
        default=None,
        help="Model ID or local path (default: download Alpha-VLLM/Lumina-DiMOO)",
    )
    parser.add_argument("--prompt", type=str, required=True, help="Text prompt")
    parser.add_argument("--steps", type=int, default=128, help="Generation steps")
    parser.add_argument("--gen_length", type=int, default=1024, help="Generation length")
    parser.add_argument("--block_length", type=int, default=256, help="Block length")
    parser.add_argument("--temperature", type=float, default=0.0, help="Temperature")
    parser.add_argument("--cfg_scale", type=float, default=0.0, help="CFG scale")
    # NOTE(review): accepted for CLI compatibility but not used by this script.
    parser.add_argument("--vae_ckpt", type=str, default="./vae_ckpt", help="VAE checkpoint path")
    parser.add_argument("--output_dir", type=str, default="outputs_text_understanding", help="Output directory")

    args = parser.parse_args()
    model_id = args.model_id
    if model_id is None:
        model_id = snapshot_download("Alpha-VLLM/Lumina-DiMOO")

    # Special tokens
    MASK = SPECIAL_TOKENS["mask_token"]
    BOA = SPECIAL_TOKENS["answer_start"]  # Begin of Answer
    EOA = SPECIAL_TOKENS["answer_end"]  # End of Answer

    # Create output directory. NOTE(review): nothing is written to it yet.
    os.makedirs(args.output_dir, exist_ok=True)

    # `device` and `model` are used by torch.tensor(...) and
    # generate_text_understanding(...) below, so both must be bound here.
    # Assumes gfx_device is a torch.device or device string — TODO confirm.
    device = gfx_device

    # Load model and tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    # bfloat16 halves memory versus the float32 from_pretrained default.
    # NOTE(review): confirm the selected device supports bf16.
    model = LLaDAModelLM.from_pretrained(model_id, torch_dtype=torch.bfloat16).to(device).eval()

    input_prompt = generate_text_prompt(args.prompt)
    input_token = tokenizer(input_prompt)["input_ids"]

    # Index of the first generated token: all prompt tokens, then one BOA token.
    code_start = len(input_token) + 1

    # Build the masked prediction sequence: prompt + BOA + gen_length * MASK + EOA.
    input_token = input_token + [BOA] + args.gen_length * [MASK] + [EOA]
    # unsqueeze(0) adds a batch dimension: shape (1, seq_len).
    input_ids = torch.tensor(input_token, device=device).unsqueeze(0)

    # Generate text (the timing covers generation and decoding).
    start_time = time.perf_counter()
    out_new = generate_text_understanding(
        model,
        input_ids,
        steps=args.steps,
        gen_length=args.gen_length,
        block_length=args.block_length,
        temperature=args.temperature,
        cfg_scale=args.cfg_scale,
        remasking="low_confidence",
        code_start=code_start,
    )

    # Decode only the answer span: skip the prompt + BOA prefix and drop the
    # trailing EOA. Assumes out_new keeps input_ids' (1, seq_len) layout.
    text_new = tokenizer.batch_decode(out_new[:, code_start:-1], skip_special_tokens=True)[0]

    elapsed_time = time.perf_counter() - start_time
    print(f"[✓] (Time {elapsed_time:.2f}s)")

    print(f"Generated text: {text_new}")


if __name__ == "__main__":
    main()
def main():
def main():
    """Run Lumina-DiMOO text understanding on one prompt and print the answer.

    Steps:
      1. Parse CLI arguments (``--prompt`` is required).
      2. Load the tokenizer and ``LLaDAModelLM`` weights from ``--model_id``,
         downloading ``Alpha-VLLM/Lumina-DiMOO`` only when it is omitted.
      3. Wrap the prompt with ``generate_text_prompt``, tokenize it, and append
         ``[BOA] + gen_length * [MASK] + [EOA]`` as the span to be filled.
      4. Call ``generate_text_understanding`` to fill that span, decode the
         tokens between BOA and EOA, and print them with the elapsed time.
    """
    parser = argparse.ArgumentParser(description="Text understanding inference")
    # The default is resolved after parsing. Passing snapshot_download() as the
    # argparse default would download the repo while the parser is being built,
    # even when the user supplies --model_id explicitly.
    parser.add_argument(
        "--model_id",
        type=str,
        required=False,
        default=None,
        help="Model ID or local path (default: download Alpha-VLLM/Lumina-DiMOO)",
    )
    parser.add_argument("--prompt", type=str, required=True, help="Text prompt")
    parser.add_argument("--steps", type=int, default=128, help="Generation steps")
    parser.add_argument("--gen_length", type=int, default=1024, help="Generation length")
    parser.add_argument("--block_length", type=int, default=256, help="Block length")
    parser.add_argument("--temperature", type=float, default=0.0, help="Temperature")
    parser.add_argument("--cfg_scale", type=float, default=0.0, help="CFG scale")
    # NOTE(review): accepted for CLI compatibility but not used by this script.
    parser.add_argument("--vae_ckpt", type=str, default="./vae_ckpt", help="VAE checkpoint path")
    parser.add_argument("--output_dir", type=str, default="outputs_text_understanding", help="Output directory")

    args = parser.parse_args()
    model_id = args.model_id
    if model_id is None:
        model_id = snapshot_download("Alpha-VLLM/Lumina-DiMOO")

    # Special tokens
    MASK = SPECIAL_TOKENS["mask_token"]
    BOA = SPECIAL_TOKENS["answer_start"]  # Begin of Answer
    EOA = SPECIAL_TOKENS["answer_end"]  # End of Answer

    # Create output directory. NOTE(review): nothing is written to it yet.
    os.makedirs(args.output_dir, exist_ok=True)

    # `device` and `model` are used by torch.tensor(...) and
    # generate_text_understanding(...) below, so both must be bound here.
    # Assumes gfx_device is a torch.device or device string — TODO confirm.
    device = gfx_device

    # Load model and tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    # bfloat16 halves memory versus the float32 from_pretrained default.
    # NOTE(review): confirm the selected device supports bf16.
    model = LLaDAModelLM.from_pretrained(model_id, torch_dtype=torch.bfloat16).to(device).eval()

    input_prompt = generate_text_prompt(args.prompt)
    input_token = tokenizer(input_prompt)["input_ids"]

    # Index of the first generated token: all prompt tokens, then one BOA token.
    code_start = len(input_token) + 1

    # Build the masked prediction sequence: prompt + BOA + gen_length * MASK + EOA.
    input_token = input_token + [BOA] + args.gen_length * [MASK] + [EOA]
    # unsqueeze(0) adds a batch dimension: shape (1, seq_len).
    input_ids = torch.tensor(input_token, device=device).unsqueeze(0)

    # Generate text (the timing covers generation and decoding).
    start_time = time.perf_counter()
    out_new = generate_text_understanding(
        model,
        input_ids,
        steps=args.steps,
        gen_length=args.gen_length,
        block_length=args.block_length,
        temperature=args.temperature,
        cfg_scale=args.cfg_scale,
        remasking="low_confidence",
        code_start=code_start,
    )

    # Decode only the answer span: skip the prompt + BOA prefix and drop the
    # trailing EOA. Assumes out_new keeps input_ids' (1, seq_len) layout.
    text_new = tokenizer.batch_decode(out_new[:, code_start:-1], skip_special_tokens=True)[0]

    elapsed_time = time.perf_counter() - start_time
    print(f"[✓] (Time {elapsed_time:.2f}s)")

    print(f"Generated text: {text_new}")