divisor.dimoo.inference_mmu
Text understanding inference script
# SPDX-License-Identifier: Apache-2.0
# Adapted from https://github.com/Alpha-VLLM/Lumina-DiMOO

"""
Text understanding inference script
"""

import argparse
import os
import sys
import time

import torch
from huggingface_hub import snapshot_download
from transformers import AutoTokenizer

from divisor.dimoo.config import SPECIAL_TOKENS
from divisor.dimoo.prompt_utils import generate_text_prompt
from divisor.dimoo.text_understanding_generator import generate_text_understanding
from divisor.mmada.modeling_llada import LLaDAModelLM
from divisor.registry import gfx_device

sys.path.append(os.path.dirname(os.path.dirname(__file__)))


def main():
    """Run text-understanding inference from the command line.

    Parses the CLI arguments, tokenizes the prompt, appends an answer span
    made of ``--gen_length`` mask tokens wrapped in the answer start/end
    tokens, lets ``generate_text_understanding`` fill that span, and prints
    the decoded answer together with the elapsed wall-clock time.

    When ``--model_id`` is omitted, the ``Alpha-VLLM/Lumina-DiMOO`` snapshot
    is fetched (or reused from the local Hugging Face cache) after argument
    parsing, so ``--help`` and explicit ``--model_id`` runs never trigger a
    download.
    """
    parser = argparse.ArgumentParser(description="Text understanding inference")
    # The default is resolved lazily below; evaluating snapshot_download() here
    # would hit the network/disk on every invocation, even for --help.
    parser.add_argument("--model_id", type=str, required=False, default=None, help="Model ID")
    parser.add_argument("--prompt", type=str, required=True, help="Text prompt")
    parser.add_argument("--steps", type=int, default=128, help="Generation steps")
    parser.add_argument("--gen_length", type=int, default=1024, help="Generation length")
    parser.add_argument("--block_length", type=int, default=256, help="Block length")
    parser.add_argument("--temperature", type=float, default=0.0, help="Temperature")
    parser.add_argument("--cfg_scale", type=float, default=0.0, help="CFG scale")
    # Accepted for CLI compatibility; not read anywhere in this script.
    parser.add_argument("--vae_ckpt", type=str, default="./vae_ckpt", help="VAE checkpoint path")
    parser.add_argument("--output_dir", type=str, default="outputs_text_understanding", help="Output directory")

    args = parser.parse_args()

    if args.model_id is None:
        args.model_id = snapshot_download("Alpha-VLLM/Lumina-DiMOO")

    # Special tokens
    MASK = SPECIAL_TOKENS["mask_token"]
    BOA = SPECIAL_TOKENS["answer_start"]  # Begin of Answer
    EOA = SPECIAL_TOKENS["answer_end"]  # End of Answer

    # Create output directory
    os.makedirs(args.output_dir, exist_ok=True)

    # Load model and tokenizer
    # NOTE(review): assumes gfx_device is a torch.device-compatible value
    # (torch.device or device string) -- confirm against divisor.registry.
    device = gfx_device
    tokenizer = AutoTokenizer.from_pretrained(args.model_id)
    # NOTE(review): assumes LLaDAModelLM is a transformers PreTrainedModel that
    # can load this checkpoint; torch_dtype="auto" keeps the checkpoint's dtype.
    model = LLaDAModelLM.from_pretrained(args.model_id, torch_dtype="auto").to(device)

    input_prompt = generate_text_prompt(args.prompt)
    input_token = tokenizer(input_prompt)["input_ids"]

    # Prediction text token start index (+1 skips the BOA token appended below)
    code_start = len(input_token) + 1

    # Build text mask prediction sequence: prompt + BOA + MASK * gen_length + EOA
    input_token = input_token + [BOA] + args.gen_length * [MASK] + [EOA]
    input_ids = torch.tensor(input_token, device=device).unsqueeze(0)

    # Generate text
    start_time = time.time()
    out_new = generate_text_understanding(
        model,
        input_ids,
        steps=args.steps,
        gen_length=args.gen_length,
        block_length=args.block_length,
        temperature=args.temperature,
        cfg_scale=args.cfg_scale,
        remasking="low_confidence",
        code_start=code_start,
    )

    # Decode only the answer span; the trailing EOA token is dropped by ``:-1``.
    text_new = tokenizer.batch_decode(out_new[:, code_start:-1], skip_special_tokens=True)[0]

    end_time = time.time()
    elapsed_time = end_time - start_time
    print(f"[✓] (Time {elapsed_time:.2f}s)")

    print(f"Generated text: {text_new}")


if __name__ == "__main__":
    main()
def
main():
def main():
    """Run text-understanding inference from the command line.

    Parses the CLI arguments, tokenizes the prompt, appends an answer span
    made of ``--gen_length`` mask tokens wrapped in the answer start/end
    tokens, lets ``generate_text_understanding`` fill that span, and prints
    the decoded answer together with the elapsed wall-clock time.

    When ``--model_id`` is omitted, the ``Alpha-VLLM/Lumina-DiMOO`` snapshot
    is fetched (or reused from the local Hugging Face cache) after argument
    parsing, so ``--help`` and explicit ``--model_id`` runs never trigger a
    download.
    """
    parser = argparse.ArgumentParser(description="Text understanding inference")
    # The default is resolved lazily below; evaluating snapshot_download() here
    # would hit the network/disk on every invocation, even for --help.
    parser.add_argument("--model_id", type=str, required=False, default=None, help="Model ID")
    parser.add_argument("--prompt", type=str, required=True, help="Text prompt")
    parser.add_argument("--steps", type=int, default=128, help="Generation steps")
    parser.add_argument("--gen_length", type=int, default=1024, help="Generation length")
    parser.add_argument("--block_length", type=int, default=256, help="Block length")
    parser.add_argument("--temperature", type=float, default=0.0, help="Temperature")
    parser.add_argument("--cfg_scale", type=float, default=0.0, help="CFG scale")
    # Accepted for CLI compatibility; not read anywhere in this function.
    parser.add_argument("--vae_ckpt", type=str, default="./vae_ckpt", help="VAE checkpoint path")
    parser.add_argument("--output_dir", type=str, default="outputs_text_understanding", help="Output directory")

    args = parser.parse_args()

    if args.model_id is None:
        args.model_id = snapshot_download("Alpha-VLLM/Lumina-DiMOO")

    # Special tokens
    MASK = SPECIAL_TOKENS["mask_token"]
    BOA = SPECIAL_TOKENS["answer_start"]  # Begin of Answer
    EOA = SPECIAL_TOKENS["answer_end"]  # End of Answer

    # Create output directory
    os.makedirs(args.output_dir, exist_ok=True)

    # Load model and tokenizer
    # NOTE(review): assumes gfx_device is a torch.device-compatible value
    # (torch.device or device string) -- confirm against divisor.registry.
    device = gfx_device
    tokenizer = AutoTokenizer.from_pretrained(args.model_id)
    # NOTE(review): assumes LLaDAModelLM is a transformers PreTrainedModel that
    # can load this checkpoint; torch_dtype="auto" keeps the checkpoint's dtype.
    model = LLaDAModelLM.from_pretrained(args.model_id, torch_dtype="auto").to(device)

    input_prompt = generate_text_prompt(args.prompt)
    input_token = tokenizer(input_prompt)["input_ids"]

    # Prediction text token start index (+1 skips the BOA token appended below)
    code_start = len(input_token) + 1

    # Build text mask prediction sequence: prompt + BOA + MASK * gen_length + EOA
    input_token = input_token + [BOA] + args.gen_length * [MASK] + [EOA]
    input_ids = torch.tensor(input_token, device=device).unsqueeze(0)

    # Generate text
    start_time = time.time()
    out_new = generate_text_understanding(
        model,
        input_ids,
        steps=args.steps,
        gen_length=args.gen_length,
        block_length=args.block_length,
        temperature=args.temperature,
        cfg_scale=args.cfg_scale,
        remasking="low_confidence",
        code_start=code_start,
    )

    # Decode only the answer span; the trailing EOA token is dropped by ``:-1``.
    text_new = tokenizer.batch_decode(out_new[:, code_start:-1], skip_special_tokens=True)[0]

    end_time = time.time()
    elapsed_time = end_time - start_time
    print(f"[✓] (Time {elapsed_time:.2f}s)")

    print(f"Generated text: {text_new}")