divisor.dimoo.text_understanding_generator
Text understanding generator
# SPDX-License-Identifier: Apache-2.0
# Adapted from https://github.com/Alpha-VLLM/Lumina-DiMOO

"""
Text understanding generator
"""

from typing import Optional

import numpy as np
import torch
import torch.nn.functional as F

from divisor.registry import gfx_dtype, gfx_device
from divisor.mmada.live_token import get_num_transfer_tokens
from divisor.noise import add_gumbel_noise


@torch.no_grad()
def generate_text_understanding(
    model,
    prompt,
    steps=32,
    gen_length=32,
    block_length=32,
    temperature=1.0,
    cfg_scale=0.0,
    remasking="low_confidence",
    mask_id=126336,
    code_start: Optional[int] = None,
    eos_id: Optional[int] = 126081,
):
    """
    Text understanding generation function.

    Iteratively unmasks the answer region of ``prompt`` block by block. In every
    step of a block, the number of positions given by ``get_num_transfer_tokens``
    is filled with the model's predictions. The positions are chosen by prediction
    confidence or at random, depending on ``remasking``. Positions after the
    current block are never filled.

    Args:
        model: Mask predictor, called as ``model(x).logits``.
        prompt: Input prompt tensor (1, L) whose answer region holds ``mask_id``
            tokens. It is not modified; a copy is decoded.
        steps: Total sampling steps, less than or equal to gen_length. Must be
            divisible by ``gen_length // block_length``.
        gen_length: Generated answer length. Must be divisible by block_length.
        block_length: Block length, less than or equal to gen_length.
        temperature: Categorical distribution sampling temperature, forwarded
            to ``add_gumbel_noise``.
        cfg_scale: Unsupervised classifier-free guidance scale; 0 disables it.
        remasking: Remasking strategy 'low_confidence' or 'random'.
        mask_id: Token id of [MASK] (default 126336).
        code_start: Index in ``prompt`` where the generated answer starts.
            Defaults to ``prompt.shape[1] - gen_length``, i.e. the answer
            occupies the last ``gen_length`` positions.
        eos_id: Token id that stops generation early once it appears in the
            decoded answer region. ``None`` disables early stopping. The default
            keeps the previously hard-coded value (presumably an end-of-text
            token; confirm against the tokenizer).

    Returns:
        The decoded token tensor. On early stop it is truncated right after the
        block in which ``eos_id`` appeared.

    Raises:
        NotImplementedError: if ``remasking`` is not a supported strategy.
    """
    precision = gfx_dtype
    # Decode a copy so the caller's prompt tensor is not overwritten in place.
    x = prompt.clone()

    # Positions known before decoding; hidden in the unconditional CFG pass.
    prompt_index = x != mask_id

    if code_start is None:
        code_start = x.shape[1] - gen_length

    assert gen_length % block_length == 0, "gen_length must be divisible by block_length"
    num_blocks = gen_length // block_length

    assert steps % num_blocks == 0, "steps must be divisible by gen_length // block_length"
    steps = steps // num_blocks

    for num_block in range(num_blocks):
        block_start = code_start + num_block * block_length
        block_end = block_start + block_length
        block_mask_index = x[:, block_start:block_end] == mask_id
        num_transfer_tokens = get_num_transfer_tokens(block_mask_index, steps)

        for i in range(steps):
            mask_index = x == mask_id
            if cfg_scale > 0.0:
                un_x = x.clone()
                un_x[prompt_index] = mask_id
                x_ = torch.cat([x, un_x], dim=0)
                logits = model(x_).logits
                logits, un_logits = torch.chunk(logits, 2, dim=0)
                logits = un_logits + (cfg_scale + 1) * (logits - un_logits)
            else:
                logits = model(x).logits

            logits_with_noise = add_gumbel_noise(logits, temperature=temperature)
            x0 = torch.argmax(logits_with_noise, dim=-1)  # b, l

            if remasking == "low_confidence":
                p = F.softmax(logits.to(precision), dim=-1)
                x0_p = torch.squeeze(torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)), -1)  # b, l
            elif remasking == "random":
                x0_p = torch.rand((x0.shape[0], x0.shape[1]), device=x0.device)
            else:
                raise NotImplementedError(remasking)

            # Positions beyond the current block must not be unmasked yet.
            x0_p[:, block_end:] = -np.inf

            # Keep already-decoded tokens; only masked positions compete on confidence.
            x0 = torch.where(mask_index, x0, x)
            confidence = torch.where(mask_index, x0_p, -np.inf)

            transfer_index = torch.zeros_like(x0, dtype=torch.bool, device=x0.device)
            for j in range(confidence.shape[0]):
                _, select_index = torch.topk(confidence[j], k=num_transfer_tokens[j, i])
                transfer_index[j, select_index] = True
            x[transfer_index] = x0[transfer_index]

        # Early stop: only the decoded answer region is inspected, so an
        # eos_id inside the prompt cannot end generation prematurely.
        if eos_id is not None and (x[:, code_start:block_end] == eos_id).any():
            return x[:, :block_end]

    return x
@torch.no_grad()
def
generate_text_understanding( model, prompt, steps=32, gen_length=32, block_length=32, temperature=1.0, cfg_scale=0.0, remasking='low_confidence', mask_id=126336, code_start: Optional[int] = None):
@torch.no_grad()
def generate_text_understanding(
    model,
    prompt,
    steps=32,
    gen_length=32,
    block_length=32,
    temperature=1.0,
    cfg_scale=0.0,
    remasking="low_confidence",
    mask_id=126336,
    code_start: Optional[int] = None,
    eos_id: Optional[int] = 126081,
):
    """
    Text understanding generation function.

    Iteratively unmasks the answer region of ``prompt`` block by block. In every
    step of a block, the number of positions given by ``get_num_transfer_tokens``
    is filled with the model's predictions. The positions are chosen by prediction
    confidence or at random, depending on ``remasking``. Positions after the
    current block are never filled.

    Args:
        model: Mask predictor, called as ``model(x).logits``.
        prompt: Input prompt tensor (1, L) whose answer region holds ``mask_id``
            tokens. It is not modified; a copy is decoded.
        steps: Total sampling steps, less than or equal to gen_length. Must be
            divisible by ``gen_length // block_length``.
        gen_length: Generated answer length. Must be divisible by block_length.
        block_length: Block length, less than or equal to gen_length.
        temperature: Categorical distribution sampling temperature, forwarded
            to ``add_gumbel_noise``.
        cfg_scale: Unsupervised classifier-free guidance scale; 0 disables it.
        remasking: Remasking strategy 'low_confidence' or 'random'.
        mask_id: Token id of [MASK] (default 126336).
        code_start: Index in ``prompt`` where the generated answer starts.
            Defaults to ``prompt.shape[1] - gen_length``, i.e. the answer
            occupies the last ``gen_length`` positions.
        eos_id: Token id that stops generation early once it appears in the
            decoded answer region. ``None`` disables early stopping. The default
            keeps the previously hard-coded value (presumably an end-of-text
            token; confirm against the tokenizer).

    Returns:
        The decoded token tensor. On early stop it is truncated right after the
        block in which ``eos_id`` appeared.

    Raises:
        NotImplementedError: if ``remasking`` is not a supported strategy.
    """
    precision = gfx_dtype
    # Decode a copy so the caller's prompt tensor is not overwritten in place.
    x = prompt.clone()

    # Positions known before decoding; hidden in the unconditional CFG pass.
    prompt_index = x != mask_id

    if code_start is None:
        code_start = x.shape[1] - gen_length

    assert gen_length % block_length == 0, "gen_length must be divisible by block_length"
    num_blocks = gen_length // block_length

    assert steps % num_blocks == 0, "steps must be divisible by gen_length // block_length"
    steps = steps // num_blocks

    for num_block in range(num_blocks):
        block_start = code_start + num_block * block_length
        block_end = block_start + block_length
        block_mask_index = x[:, block_start:block_end] == mask_id
        num_transfer_tokens = get_num_transfer_tokens(block_mask_index, steps)

        for i in range(steps):
            mask_index = x == mask_id
            if cfg_scale > 0.0:
                un_x = x.clone()
                un_x[prompt_index] = mask_id
                x_ = torch.cat([x, un_x], dim=0)
                logits = model(x_).logits
                logits, un_logits = torch.chunk(logits, 2, dim=0)
                logits = un_logits + (cfg_scale + 1) * (logits - un_logits)
            else:
                logits = model(x).logits

            logits_with_noise = add_gumbel_noise(logits, temperature=temperature)
            x0 = torch.argmax(logits_with_noise, dim=-1)  # b, l

            if remasking == "low_confidence":
                p = F.softmax(logits.to(precision), dim=-1)
                x0_p = torch.squeeze(torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)), -1)  # b, l
            elif remasking == "random":
                x0_p = torch.rand((x0.shape[0], x0.shape[1]), device=x0.device)
            else:
                raise NotImplementedError(remasking)

            # Positions beyond the current block must not be unmasked yet.
            x0_p[:, block_end:] = -np.inf

            # Keep already-decoded tokens; only masked positions compete on confidence.
            x0 = torch.where(mask_index, x0, x)
            confidence = torch.where(mask_index, x0_p, -np.inf)

            transfer_index = torch.zeros_like(x0, dtype=torch.bool, device=x0.device)
            for j in range(confidence.shape[0]):
                _, select_index = torch.topk(confidence[j], k=num_transfer_tokens[j, i])
                transfer_index[j, select_index] = True
            x[transfer_index] = x0[transfer_index]

        # Early stop: only the decoded answer region is inspected, so an
        # eos_id inside the prompt cannot end generation prematurely.
        if eos_id is not None and (x[:, code_start:block_end] == eos_id).any():
            return x[:, :block_end]

    return x
Text understanding generation function
Args: model: Mask predictor prompt: Input prompt tensor (1, L) steps: Sampling steps, less than or equal to gen_length gen_length: Generated answer length block_length: Block length, less than or equal to gen_length temperature: Categorical distribution sampling temperature cfg_scale: Unsupervised classifier-free guidance scale remasking: Remasking strategy 'low_confidence' or 'random' mask_id: The token id of [MASK] (default 126336) code_start: Index at which the predicted text tokens start