divisor.dimoo.text_understanding_generator

Text understanding generator

  1# SPDX-License-Identifier: Apache-2.0
  2# Adapted from https://github.com/Alpha-VLLM/Lumina-DiMOO
  3
  4"""
  5Text understanding generator
  6"""
  7
  8from typing import Optional
  9
 10import numpy as np
 11import torch
 12import torch.nn.functional as F
 13
 14from divisor.registry import gfx_dtype, gfx_device
 15from divisor.mmada.live_token import get_num_transfer_tokens
 16from divisor.noise import add_gumbel_noise
 17
 18
 19@torch.no_grad()
 20def generate_text_understanding(
 21    model,
 22    prompt,
 23    steps=32,
 24    gen_length=32,
 25    block_length=32,
 26    temperature=1.0,
 27    cfg_scale=0.0,
 28    remasking="low_confidence",
 29    mask_id=126336,
 30    code_start: Optional[int] = None,
 31):
 32    """
 33    Text understanding generation function
 34
 35    Args:
 36        model: Mask predictor
 37        prompt: Input prompt tensor (1, L)
 38        steps: Sampling steps, less than or equal to gen_length
 39        gen_length: Generated answer length
 40        block_length: Block length, less than or equal to gen_length
 41        temperature: Categorical distribution sampling temperature
 42        cfg_scale: Unsupervised classifier-free guidance scale
 43        remasking: Remasking strategy 'low_confidence' or 'random'
 44        mask_id: The token id of [MASK] is 126336
 45        code_start: Prediction text token satrt index
 46    """
 47    device = next(model.parameters()).device or device
 48    precision = gfx_dtype
 49    x = prompt
 50
 51    prompt_index = x != mask_id
 52
 53    assert gen_length % block_length == 0
 54    num_blocks = gen_length // block_length
 55
 56    assert steps % num_blocks == 0
 57    steps = steps // num_blocks
 58
 59    for num_block in range(num_blocks):
 60        block_mask_index = x[:, code_start + num_block * block_length : code_start + (num_block + 1) * block_length :] == mask_id
 61        num_transfer_tokens = get_num_transfer_tokens(block_mask_index, steps)
 62
 63        for i in range(steps):
 64            mask_index = x == mask_id
 65            if cfg_scale > 0.0:
 66                un_x = x.clone()
 67                un_x[prompt_index] = mask_id
 68                x_ = torch.cat([x, un_x], dim=0)
 69                logits = model(x_).logits
 70                logits, un_logits = torch.chunk(logits, 2, dim=0)
 71                logits = un_logits + (cfg_scale + 1) * (logits - un_logits)
 72            else:
 73                logits = model(x).logits
 74
 75            logits_with_noise = add_gumbel_noise(logits, temperature=temperature)
 76            x0 = torch.argmax(logits_with_noise, dim=-1)  # b, l
 77
 78            if remasking == "low_confidence":
 79                p = F.softmax(logits.to(precision), dim=-1)
 80                x0_p = torch.squeeze(torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)), -1)  # b, l
 81            elif remasking == "random":
 82                x0_p = torch.rand((x0.shape[0], x0.shape[1]), device=x0.device)
 83            else:
 84                raise NotImplementedError(remasking)
 85
 86            x0_p[:, code_start + (num_block + 1) * block_length :] = -np.inf
 87
 88            x0 = torch.where(mask_index, x0, x)
 89            confidence = torch.where(mask_index, x0_p, -np.inf)
 90
 91            transfer_index = torch.zeros_like(x0, dtype=torch.bool, device=x0.device)
 92            for j in range(confidence.shape[0]):
 93                _, select_index = torch.topk(confidence[j], k=num_transfer_tokens[j, i])
 94                transfer_index[j, select_index] = True
 95            x[transfer_index] = x0[transfer_index]
 96
 97        # early stop
 98        if (x == 126081).sum().item() > 0 and num_blocks > 0:
 99            return x[:, : input_p_len + prompt.shape[0] + 3 + rows + (num_block + 1) * block_length]
100
101    return x
@torch.no_grad()
def generate_text_understanding( model, prompt, steps=32, gen_length=32, block_length=32, temperature=1.0, cfg_scale=0.0, remasking='low_confidence', mask_id=126336, code_start: Optional[int] = None):
 20@torch.no_grad()
 21def generate_text_understanding(
 22    model,
 23    prompt,
 24    steps=32,
 25    gen_length=32,
 26    block_length=32,
 27    temperature=1.0,
 28    cfg_scale=0.0,
 29    remasking="low_confidence",
 30    mask_id=126336,
 31    code_start: Optional[int] = None,
 32):
 33    """
 34    Text understanding generation function
 35
 36    Args:
 37        model: Mask predictor
 38        prompt: Input prompt tensor (1, L)
 39        steps: Sampling steps, less than or equal to gen_length
 40        gen_length: Generated answer length
 41        block_length: Block length, less than or equal to gen_length
 42        temperature: Categorical distribution sampling temperature
 43        cfg_scale: Unsupervised classifier-free guidance scale
 44        remasking: Remasking strategy 'low_confidence' or 'random'
 45        mask_id: The token id of [MASK] is 126336
 46        code_start: Prediction text token satrt index
 47    """
 48    device = next(model.parameters()).device or device
 49    precision = gfx_dtype
 50    x = prompt
 51
 52    prompt_index = x != mask_id
 53
 54    assert gen_length % block_length == 0
 55    num_blocks = gen_length // block_length
 56
 57    assert steps % num_blocks == 0
 58    steps = steps // num_blocks
 59
 60    for num_block in range(num_blocks):
 61        block_mask_index = x[:, code_start + num_block * block_length : code_start + (num_block + 1) * block_length :] == mask_id
 62        num_transfer_tokens = get_num_transfer_tokens(block_mask_index, steps)
 63
 64        for i in range(steps):
 65            mask_index = x == mask_id
 66            if cfg_scale > 0.0:
 67                un_x = x.clone()
 68                un_x[prompt_index] = mask_id
 69                x_ = torch.cat([x, un_x], dim=0)
 70                logits = model(x_).logits
 71                logits, un_logits = torch.chunk(logits, 2, dim=0)
 72                logits = un_logits + (cfg_scale + 1) * (logits - un_logits)
 73            else:
 74                logits = model(x).logits
 75
 76            logits_with_noise = add_gumbel_noise(logits, temperature=temperature)
 77            x0 = torch.argmax(logits_with_noise, dim=-1)  # b, l
 78
 79            if remasking == "low_confidence":
 80                p = F.softmax(logits.to(precision), dim=-1)
 81                x0_p = torch.squeeze(torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)), -1)  # b, l
 82            elif remasking == "random":
 83                x0_p = torch.rand((x0.shape[0], x0.shape[1]), device=x0.device)
 84            else:
 85                raise NotImplementedError(remasking)
 86
 87            x0_p[:, code_start + (num_block + 1) * block_length :] = -np.inf
 88
 89            x0 = torch.where(mask_index, x0, x)
 90            confidence = torch.where(mask_index, x0_p, -np.inf)
 91
 92            transfer_index = torch.zeros_like(x0, dtype=torch.bool, device=x0.device)
 93            for j in range(confidence.shape[0]):
 94                _, select_index = torch.topk(confidence[j], k=num_transfer_tokens[j, i])
 95                transfer_index[j, select_index] = True
 96            x[transfer_index] = x0[transfer_index]
 97
 98        # early stop
 99        if (x == 126081).sum().item() > 0 and num_blocks > 0:
100            return x[:, : input_p_len + prompt.shape[0] + 3 + rows + (num_block + 1) * block_length]
101
102    return x

Text understanding generation function

Args:
- model: Mask predictor
- prompt: Input prompt tensor (1, L)
- steps: Sampling steps, less than or equal to gen_length
- gen_length: Generated answer length
- block_length: Block length, less than or equal to gen_length
- temperature: Categorical distribution sampling temperature
- cfg_scale: Unsupervised classifier-free guidance scale
- remasking: Remasking strategy 'low_confidence' or 'random'
- mask_id: The token id of [MASK] is 126336
- code_start: Prediction text token start index