divisor.dimoo.modeling_xllmx_dimoo
# SPDX-License-Identifier: Apache-2.0
# adapted from https://github.com/Alpha-VLLM/Lumina-DiMOO
"""Wrapper around ``LLaDAModelLM`` that right-pads variable-length batches and
computes a position-wise cross-entropy loss."""

from typing import List

import torch
import torch.nn.functional as F

from divisor.mmada.configuration_llada import LLaDAConfig
from divisor.mmada.modeling_llada import LLaDAModelLM


def create_attention_mask(original_lengths, max_tokens, device):
    """Build a right-padding validity mask for a batch of sequences.

    Args:
        original_lengths: Unpadded length of each sample (a sequence of ints or a
            1-D integer tensor).
        max_tokens: Padded sequence length, i.e. the number of mask columns.
        device: Device on which to allocate the mask.

    Returns:
        A ``(batch_size, max_tokens)`` bool tensor where ``mask[i, j]`` is True
        iff ``j < original_lengths[i]`` (position ``j`` holds a real token).
    """
    lengths = torch.as_tensor(original_lengths, dtype=torch.long, device=device).reshape(-1)
    positions = torch.arange(max_tokens, device=device)
    # Valid (non-padding) positions are True; one broadcast comparison replaces
    # a per-row loop of slice assignments.
    return positions.unsqueeze(0) < lengths.unsqueeze(1)


class XLLMXDimooModelLM(LLaDAModelLM):
    """``LLaDAModelLM`` that pads ragged batches itself and returns a training loss."""

    config_class = LLaDAConfig
    base_model_prefix = "model"

    def __init__(self, config: LLaDAConfig, *args, **kwargs):
        """Construct the model from ``config``; extra arguments go to ``LLaDAModelLM.__init__``."""
        # Use the concrete class name so the message matches the class being built.
        print(f"Initializing {type(self).__name__} with config: {config}")
        super().__init__(config, *args, **kwargs)

    def forward(self, input_ids=None, labels=None, infer=False, use_cache=False, to_compute_mask=None, cat="", **kwargs):
        """Pad a batch, run ``LLaDAModelLM.forward``, and optionally compute the loss.

        Sequences are right-padded with token id ``0`` up to the longest one in the
        batch. The attention bias lets a position attend to another only when both
        are real (non-padding) tokens.

        Args:
            input_ids: When ``infer`` is True, a tensor supporting ``.tolist()``
                (presumably ``(batch, seq_len)`` token ids -- TODO confirm).
                Otherwise a sequence of per-sample token-id lists of varying length.
            labels: Per-sample target-id lists; required when ``infer`` is False.
                Each is right-padded with ``-100`` (ignored by the loss) to the
                padded input length, so none may be longer than the longest input.
            infer: If True, return the raw ``LLaDAModelLM.forward`` output and skip the loss.
            use_cache, to_compute_mask, cat: Forwarded unchanged to ``LLaDAModelLM.forward``.
            **kwargs: Accepted for call-site compatibility but ignored.

        Returns:
            The ``LLaDAModelLM.forward`` output when ``infer`` is True. Otherwise, the
            mean cross-entropy over positions whose label is not ``-100``. Logits and
            labels are compared at the same position (no shift).

        Raises:
            ValueError: If ``input_ids`` is empty, or (when ``infer`` is False) if
                ``labels`` is missing, has a different batch size than ``input_ids``,
                or contains a sequence longer than the longest input.
        """
        if infer:
            # Inference passes a tensor; convert to nested lists so both paths
            # share the padding code below.
            input_ids = input_ids.tolist()
        # len() instead of truthiness so a tensor argument cannot trip an ambiguous-bool error.
        if len(input_ids) == 0:
            raise ValueError("input_ids must contain at least one sequence")

        # ========================================================
        # padding input batch len & attention bias for attention mask
        # ========================================================
        original_lengths = [len(example) for example in input_ids]  # per-sample length, recorded for the mask
        max_tokens = max(original_lengths)

        if not infer:
            # Validate labels before the (expensive) model call so a bad batch fails fast.
            if labels is None:
                raise ValueError("labels must be provided when infer=False")
            if len(labels) != len(original_lengths):
                raise ValueError(
                    f"got {len(labels)} label sequences for a batch of {len(original_lengths)} inputs"
                )
            if any(len(label) > max_tokens for label in labels):
                raise ValueError(
                    f"every label sequence must be at most {max_tokens} tokens (the longest input)"
                )

        input_ids = torch.tensor(
            [example + [0] * (max_tokens - len(example)) for example in input_ids],  # right-pad with 0
            dtype=torch.int64,
            device=self.device,
        )
        attention_mask = create_attention_mask(original_lengths, max_tokens, self.device)
        # (batch, 1, seq, seq): True only where both the query and key positions are real tokens.
        attention_bias = (attention_mask[:, :, None] & attention_mask[:, None, :]).unsqueeze(1)

        # ========================================================
        # model output
        # ========================================================
        output = LLaDAModelLM.forward(
            self,
            input_ids=input_ids,
            attention_bias=attention_bias,
            use_cache=use_cache,
            to_compute_mask=to_compute_mask,
            cat=cat,
        )
        if infer:
            return output

        # ========================================================
        # padding label batch len & loss
        # ========================================================
        padded_labels = torch.tensor(
            [label + [-100] * (max_tokens - len(label)) for label in labels],  # right-pad with ignore_index
            dtype=torch.int64,
            device=self.device,
        )
        logits = output.logits
        return F.cross_entropy(
            logits.contiguous().view(-1, logits.shape[-1]),
            padded_labels.view(-1),
            ignore_index=-100,
        )

    def get_fsdp_wrap_module_list(self) -> List:
        """Return the submodules to wrap individually under FSDP.

        These are every ``transformer.blocks`` entry followed by ``transformer.ff_out``.
        """
        transformer = self.model.transformer
        return [*transformer.blocks, transformer.ff_out]
def
create_attention_mask(original_lengths, max_tokens, device):
def create_attention_mask(original_lengths, max_tokens, device):
    """Build a right-padding validity mask for a batch of sequences.

    Args:
        original_lengths: Unpadded length of each sample (a sequence of ints or a
            1-D integer tensor).
        max_tokens: Padded sequence length, i.e. the number of mask columns.
        device: Device on which to allocate the mask.

    Returns:
        A ``(batch_size, max_tokens)`` bool tensor where ``mask[i, j]`` is True
        iff ``j < original_lengths[i]`` (position ``j`` holds a real token).
    """
    lengths = torch.as_tensor(original_lengths, dtype=torch.long, device=device).reshape(-1)
    positions = torch.arange(max_tokens, device=device)
    # Valid (non-padding) positions are True; one broadcast comparison replaces
    # a per-row loop of slice assignments.
    return positions.unsqueeze(0) < lengths.unsqueeze(1)
class
XLLMXDimooModelLM(divisor.mmada.modeling_llada.LLaDAModelLM):
class XLLMXDimooModelLM(LLaDAModelLM):
    """``LLaDAModelLM`` that pads ragged batches itself and returns a training loss."""

    config_class = LLaDAConfig
    base_model_prefix = "model"

    def __init__(self, config: LLaDAConfig, *args, **kwargs):
        """Construct the model from ``config``; extra arguments go to ``LLaDAModelLM.__init__``."""
        # Use the concrete class name so the message matches the class being built.
        print(f"Initializing {type(self).__name__} with config: {config}")
        super().__init__(config, *args, **kwargs)

    def forward(self, input_ids=None, labels=None, infer=False, use_cache=False, to_compute_mask=None, cat="", **kwargs):
        """Pad a batch, run ``LLaDAModelLM.forward``, and optionally compute the loss.

        Sequences are right-padded with token id ``0`` up to the longest one in the
        batch. The attention bias lets a position attend to another only when both
        are real (non-padding) tokens.

        Args:
            input_ids: When ``infer`` is True, a tensor supporting ``.tolist()``
                (presumably ``(batch, seq_len)`` token ids -- TODO confirm).
                Otherwise a sequence of per-sample token-id lists of varying length.
            labels: Per-sample target-id lists; required when ``infer`` is False.
                Each is right-padded with ``-100`` (ignored by the loss) to the
                padded input length, so none may be longer than the longest input.
            infer: If True, return the raw ``LLaDAModelLM.forward`` output and skip the loss.
            use_cache, to_compute_mask, cat: Forwarded unchanged to ``LLaDAModelLM.forward``.
            **kwargs: Accepted for call-site compatibility but ignored.

        Returns:
            The ``LLaDAModelLM.forward`` output when ``infer`` is True. Otherwise, the
            mean cross-entropy over positions whose label is not ``-100``. Logits and
            labels are compared at the same position (no shift).

        Raises:
            ValueError: If ``input_ids`` is empty, or (when ``infer`` is False) if
                ``labels`` is missing, has a different batch size than ``input_ids``,
                or contains a sequence longer than the longest input.
        """
        if infer:
            # Inference passes a tensor; convert to nested lists so both paths
            # share the padding code below.
            input_ids = input_ids.tolist()
        # len() instead of truthiness so a tensor argument cannot trip an ambiguous-bool error.
        if len(input_ids) == 0:
            raise ValueError("input_ids must contain at least one sequence")

        # ========================================================
        # padding input batch len & attention bias for attention mask
        # ========================================================
        original_lengths = [len(example) for example in input_ids]  # per-sample length, recorded for the mask
        max_tokens = max(original_lengths)

        if not infer:
            # Validate labels before the (expensive) model call so a bad batch fails fast.
            if labels is None:
                raise ValueError("labels must be provided when infer=False")
            if len(labels) != len(original_lengths):
                raise ValueError(
                    f"got {len(labels)} label sequences for a batch of {len(original_lengths)} inputs"
                )
            if any(len(label) > max_tokens for label in labels):
                raise ValueError(
                    f"every label sequence must be at most {max_tokens} tokens (the longest input)"
                )

        input_ids = torch.tensor(
            [example + [0] * (max_tokens - len(example)) for example in input_ids],  # right-pad with 0
            dtype=torch.int64,
            device=self.device,
        )
        attention_mask = create_attention_mask(original_lengths, max_tokens, self.device)
        # (batch, 1, seq, seq): True only where both the query and key positions are real tokens.
        attention_bias = (attention_mask[:, :, None] & attention_mask[:, None, :]).unsqueeze(1)

        # ========================================================
        # model output
        # ========================================================
        output = LLaDAModelLM.forward(
            self,
            input_ids=input_ids,
            attention_bias=attention_bias,
            use_cache=use_cache,
            to_compute_mask=to_compute_mask,
            cat=cat,
        )
        if infer:
            return output

        # ========================================================
        # padding label batch len & loss
        # ========================================================
        padded_labels = torch.tensor(
            [label + [-100] * (max_tokens - len(label)) for label in labels],  # right-pad with ignore_index
            dtype=torch.int64,
            device=self.device,
        )
        logits = output.logits
        return F.cross_entropy(
            logits.contiguous().view(-1, logits.shape[-1]),
            padded_labels.view(-1),
            ignore_index=-100,
        )

    def get_fsdp_wrap_module_list(self) -> List:
        """Return the submodules to wrap individually under FSDP.

        These are every ``transformer.blocks`` entry followed by ``transformer.ff_out``.
        """
        transformer = self.model.transformer
        return [*transformer.blocks, transformer.ff_out]
Extremely barebones HF model wrapper.
XLLMXDimooModelLM( config: divisor.mmada.configuration_llada.LLaDAConfig, *args, **kwargs)
def __init__(self, config: LLaDAConfig, *args, **kwargs):
    """Construct the model from ``config``; extra arguments go to ``LLaDAModelLM.__init__``."""
    # Use the concrete class name so the message matches the class being built.
    print(f"Initializing {type(self).__name__} with config: {config}")
    super().__init__(config, *args, **kwargs)
Args:
config ([`PretrainedConfig`]):
Model configuration class with all the parameters of the model. Initializing with a config file does not
load the weights associated with the model, only the configuration. Check out the
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
def
forward( self, input_ids=None, labels=None, infer=False, use_cache=False, to_compute_mask=None, cat='', **kwargs):
def forward(self, input_ids=None, labels=None, infer=False, use_cache=False, to_compute_mask=None, cat="", **kwargs):
    """Pad a batch, run ``LLaDAModelLM.forward``, and optionally compute the loss.

    Sequences are right-padded with token id ``0`` up to the longest one in the
    batch. The attention bias lets a position attend to another only when both
    are real (non-padding) tokens.

    Args:
        input_ids: When ``infer`` is True, a tensor supporting ``.tolist()``
            (presumably ``(batch, seq_len)`` token ids -- TODO confirm).
            Otherwise a sequence of per-sample token-id lists of varying length.
        labels: Per-sample target-id lists; required when ``infer`` is False.
            Each is right-padded with ``-100`` (ignored by the loss) to the
            padded input length, so none may be longer than the longest input.
        infer: If True, return the raw ``LLaDAModelLM.forward`` output and skip the loss.
        use_cache, to_compute_mask, cat: Forwarded unchanged to ``LLaDAModelLM.forward``.
        **kwargs: Accepted for call-site compatibility but ignored.

    Returns:
        The ``LLaDAModelLM.forward`` output when ``infer`` is True. Otherwise, the
        mean cross-entropy over positions whose label is not ``-100``. Logits and
        labels are compared at the same position (no shift).

    Raises:
        ValueError: If ``input_ids`` is empty, or (when ``infer`` is False) if
            ``labels`` is missing, has a different batch size than ``input_ids``,
            or contains a sequence longer than the longest input.
    """
    if infer:
        # Inference passes a tensor; convert to nested lists so both paths
        # share the padding code below.
        input_ids = input_ids.tolist()
    # len() instead of truthiness so a tensor argument cannot trip an ambiguous-bool error.
    if len(input_ids) == 0:
        raise ValueError("input_ids must contain at least one sequence")

    # ========================================================
    # padding input batch len & attention bias for attention mask
    # ========================================================
    original_lengths = [len(example) for example in input_ids]  # per-sample length, recorded for the mask
    max_tokens = max(original_lengths)

    if not infer:
        # Validate labels before the (expensive) model call so a bad batch fails fast.
        if labels is None:
            raise ValueError("labels must be provided when infer=False")
        if len(labels) != len(original_lengths):
            raise ValueError(
                f"got {len(labels)} label sequences for a batch of {len(original_lengths)} inputs"
            )
        if any(len(label) > max_tokens for label in labels):
            raise ValueError(
                f"every label sequence must be at most {max_tokens} tokens (the longest input)"
            )

    input_ids = torch.tensor(
        [example + [0] * (max_tokens - len(example)) for example in input_ids],  # right-pad with 0
        dtype=torch.int64,
        device=self.device,
    )
    attention_mask = create_attention_mask(original_lengths, max_tokens, self.device)
    # (batch, 1, seq, seq): True only where both the query and key positions are real tokens.
    attention_bias = (attention_mask[:, :, None] & attention_mask[:, None, :]).unsqueeze(1)

    # ========================================================
    # model output
    # ========================================================
    output = LLaDAModelLM.forward(
        self,
        input_ids=input_ids,
        attention_bias=attention_bias,
        use_cache=use_cache,
        to_compute_mask=to_compute_mask,
        cat=cat,
    )
    if infer:
        return output

    # ========================================================
    # padding label batch len & loss
    # ========================================================
    padded_labels = torch.tensor(
        [label + [-100] * (max_tokens - len(label)) for label in labels],  # right-pad with ignore_index
        dtype=torch.int64,
        device=self.device,
    )
    logits = output.logits
    return F.cross_entropy(
        logits.contiguous().view(-1, logits.shape[-1]),
        padded_labels.view(-1),
        ignore_index=-100,
    )
Define the computation performed at every call.
Should be overridden by all subclasses.
Although the recipe for forward pass needs to be defined within
this function, one should call the Module instance afterwards
instead of this since the former takes care of running the
registered hooks while the latter silently ignores them.