divisor.dimoo.modeling_xllmx_dimoo

 1# SPDX-License-Identifier: Apache-2.0
 2# adapted from https://github.com/Alpha-VLLM/Lumina-DiMOO
 3
 4import torch.nn.functional as F
 5import torch
 6from divisor.mmada.modeling_llada import LLaDAModelLM
 7from divisor.mmada.configuration_llada import LLaDAConfig
 8from typing import List
 9
10
def create_attention_mask(original_lengths, max_tokens, device):
    """Build a boolean padding mask for a right-padded batch.

    Args:
        original_lengths: Unpadded length of each sample in the batch.
        max_tokens: Padded sequence length (number of mask columns).
        device: Device on which the mask is allocated.

    Returns:
        A ``(len(original_lengths), max_tokens)`` bool tensor whose row ``i`` is
        True on the first ``original_lengths[i]`` positions (the valid tokens)
        and False on the right padding. Lengths larger than ``max_tokens``
        simply mark the whole row valid.
    """
    # A single broadcasted comparison replaces one slice assignment per sample
    # (each of which is a separate device write when ``device`` is a GPU).
    positions = torch.arange(max_tokens, device=device)
    lengths = torch.as_tensor(original_lengths, dtype=torch.long, device=device)
    return positions.unsqueeze(0) < lengths.unsqueeze(1)
17
18
class XLLMXDimooModelLM(LLaDAModelLM):
    """LLaDA language-model wrapper adapted from Lumina-DiMOO.

    ``forward`` accepts ragged batches, right-pads them, builds the matching
    padding attention bias and, outside inference, computes the token-level
    cross-entropy loss itself.
    """

    config_class = LLaDAConfig
    base_model_prefix = "model"

    def __init__(self, config: LLaDAConfig, *args, **kwargs):
        """Build the model from ``config``.

        Extra positional and keyword arguments are forwarded unchanged to
        ``LLaDAModelLM.__init__``.
        """
        import logging

        # Use logging rather than print() so the (potentially long) config dump
        # respects the application's logging configuration.
        logging.getLogger(__name__).info("Initializing %s with config: %s", type(self).__name__, config)
        super().__init__(config, *args, **kwargs)

    def forward(self, input_ids=None, labels=None, infer=False, use_cache=False, to_compute_mask=None, cat="", **kwargs):
        """Pad a batch, run the LLaDA backbone and optionally compute the loss.

        Args:
            input_ids: Either a 2-D tensor of token ids (every row the same
                length), an array-like exposing ``tolist()``, or a sequence of
                per-sample token-id sequences of possibly different lengths.
            labels: Per-sample label sequences aligned position-by-position with
                ``input_ids`` (or a tensor already shaped like the padded ids).
                Required when ``infer`` is False.
            infer: If True, return the backbone output instead of the loss.
            use_cache, to_compute_mask, cat: Forwarded unchanged to
                ``LLaDAModelLM.forward``.
            **kwargs: Accepted for call-site compatibility and ignored.

        Returns:
            The ``LLaDAModelLM.forward`` output when ``infer`` is True;
            otherwise the mean cross-entropy over label positions != -100.

        Raises:
            ValueError: If ``input_ids`` is empty/None, or ``labels`` is missing
                while ``infer`` is False.
        """
        if not infer and labels is None:
            # Fail before the (expensive) backbone forward, not after it.
            raise ValueError("labels are required when infer=False")

        # ========================================================
        # padding input batch len & attention bias for attention mask
        # ========================================================
        if isinstance(input_ids, torch.Tensor):
            # A dense tensor is already rectangular: nothing to pad and every
            # position is valid, so stay on-device instead of round-tripping
            # the whole batch through Python lists. copy=True so the backbone
            # never aliases the caller's tensor.
            input_ids = input_ids.to(device=self.device, dtype=torch.int64, copy=True)
            max_tokens = input_ids.shape[1]
            attention_mask = torch.ones(input_ids.shape, dtype=torch.bool, device=self.device)
        else:
            if hasattr(input_ids, "tolist"):
                # Array-likes such as numpy arrays.
                input_ids = input_ids.tolist()
            if not input_ids:
                raise ValueError("input_ids must contain at least one sample")
            # Record every sample's true length for the attention mask.
            original_lengths = [len(example) for example in input_ids]
            max_tokens = max(original_lengths)
            # Right-pad with token id 0 up to the longest sample.
            padded = [list(example) + [0] * (max_tokens - len(example)) for example in input_ids]
            input_ids = torch.tensor(padded, dtype=torch.int64, device=self.device)
            attention_mask = create_attention_mask(original_lengths, max_tokens, self.device)

        # Shape (batch, 1, seq, seq): entry [b, 0, i, j] is True only when
        # positions i and j of sample b are both real (non-padding) tokens.
        attention_bias = (attention_mask[:, :, None] & attention_mask[:, None, :]).unsqueeze(1)

        # ========================================================
        # model output
        # ========================================================
        output = super().forward(
            input_ids=input_ids,
            attention_bias=attention_bias,
            use_cache=use_cache,
            to_compute_mask=to_compute_mask,
            cat=cat,
        )
        if infer:
            return output

        # ========================================================
        # padding label batch len & loss
        # ========================================================
        if isinstance(labels, torch.Tensor):
            labels = labels.to(device=self.device, dtype=torch.int64)
        else:
            # Right-pad with -100 so padded positions are ignored by the loss.
            labels = torch.tensor(
                [list(label) + [-100] * (max_tokens - len(label)) for label in labels],
                dtype=torch.int64,
                device=self.device,
            )
        logits = output.logits
        # No next-token shift: each position is scored against its own label.
        return F.cross_entropy(
            logits.reshape(-1, logits.shape[-1]),
            labels.reshape(-1),
            ignore_index=-100,
        )

    def get_fsdp_wrap_module_list(self) -> List:
        """Return the submodules FSDP should wrap individually.

        The list holds every entry of ``self.model.transformer.blocks`` (in
        order) followed by ``self.model.transformer.ff_out``.
        """
        transformer = self.model.transformer
        wrap_units = list(transformer.blocks)
        # NOTE(review): attribute lookup fails if the transformer has no
        # ``ff_out`` entry; confirm this holds for every config used with FSDP.
        wrap_units.append(transformer.ff_out)
        return wrap_units
def create_attention_mask(original_lengths, max_tokens, device):
def create_attention_mask(original_lengths, max_tokens, device):
    """Build a boolean padding mask for a right-padded batch.

    Args:
        original_lengths: Unpadded length of each sample in the batch.
        max_tokens: Padded sequence length (number of mask columns).
        device: Device on which the mask is allocated.

    Returns:
        A ``(len(original_lengths), max_tokens)`` bool tensor whose row ``i`` is
        True on the first ``original_lengths[i]`` positions (the valid tokens)
        and False on the right padding. Lengths larger than ``max_tokens``
        simply mark the whole row valid.
    """
    # A single broadcasted comparison replaces one slice assignment per sample
    # (each of which is a separate device write when ``device`` is a GPU).
    positions = torch.arange(max_tokens, device=device)
    lengths = torch.as_tensor(original_lengths, dtype=torch.long, device=device)
    return positions.unsqueeze(0) < lengths.unsqueeze(1)
class XLLMXDimooModelLM(divisor.mmada.modeling_llada.LLaDAModelLM):
class XLLMXDimooModelLM(LLaDAModelLM):
    """LLaDA language-model wrapper adapted from Lumina-DiMOO.

    ``forward`` accepts ragged batches, right-pads them, builds the matching
    padding attention bias and, outside inference, computes the token-level
    cross-entropy loss itself.
    """

    config_class = LLaDAConfig
    base_model_prefix = "model"

    def __init__(self, config: LLaDAConfig, *args, **kwargs):
        """Build the model from ``config``.

        Extra positional and keyword arguments are forwarded unchanged to
        ``LLaDAModelLM.__init__``.
        """
        import logging

        # Use logging rather than print() so the (potentially long) config dump
        # respects the application's logging configuration.
        logging.getLogger(__name__).info("Initializing %s with config: %s", type(self).__name__, config)
        super().__init__(config, *args, **kwargs)

    def forward(self, input_ids=None, labels=None, infer=False, use_cache=False, to_compute_mask=None, cat="", **kwargs):
        """Pad a batch, run the LLaDA backbone and optionally compute the loss.

        Args:
            input_ids: Either a 2-D tensor of token ids (every row the same
                length), an array-like exposing ``tolist()``, or a sequence of
                per-sample token-id sequences of possibly different lengths.
            labels: Per-sample label sequences aligned position-by-position with
                ``input_ids`` (or a tensor already shaped like the padded ids).
                Required when ``infer`` is False.
            infer: If True, return the backbone output instead of the loss.
            use_cache, to_compute_mask, cat: Forwarded unchanged to
                ``LLaDAModelLM.forward``.
            **kwargs: Accepted for call-site compatibility and ignored.

        Returns:
            The ``LLaDAModelLM.forward`` output when ``infer`` is True;
            otherwise the mean cross-entropy over label positions != -100.

        Raises:
            ValueError: If ``input_ids`` is empty/None, or ``labels`` is missing
                while ``infer`` is False.
        """
        if not infer and labels is None:
            # Fail before the (expensive) backbone forward, not after it.
            raise ValueError("labels are required when infer=False")

        # ========================================================
        # padding input batch len & attention bias for attention mask
        # ========================================================
        if isinstance(input_ids, torch.Tensor):
            # A dense tensor is already rectangular: nothing to pad and every
            # position is valid, so stay on-device instead of round-tripping
            # the whole batch through Python lists. copy=True so the backbone
            # never aliases the caller's tensor.
            input_ids = input_ids.to(device=self.device, dtype=torch.int64, copy=True)
            max_tokens = input_ids.shape[1]
            attention_mask = torch.ones(input_ids.shape, dtype=torch.bool, device=self.device)
        else:
            if hasattr(input_ids, "tolist"):
                # Array-likes such as numpy arrays.
                input_ids = input_ids.tolist()
            if not input_ids:
                raise ValueError("input_ids must contain at least one sample")
            # Record every sample's true length for the attention mask.
            original_lengths = [len(example) for example in input_ids]
            max_tokens = max(original_lengths)
            # Right-pad with token id 0 up to the longest sample.
            padded = [list(example) + [0] * (max_tokens - len(example)) for example in input_ids]
            input_ids = torch.tensor(padded, dtype=torch.int64, device=self.device)
            attention_mask = create_attention_mask(original_lengths, max_tokens, self.device)

        # Shape (batch, 1, seq, seq): entry [b, 0, i, j] is True only when
        # positions i and j of sample b are both real (non-padding) tokens.
        attention_bias = (attention_mask[:, :, None] & attention_mask[:, None, :]).unsqueeze(1)

        # ========================================================
        # model output
        # ========================================================
        output = super().forward(
            input_ids=input_ids,
            attention_bias=attention_bias,
            use_cache=use_cache,
            to_compute_mask=to_compute_mask,
            cat=cat,
        )
        if infer:
            return output

        # ========================================================
        # padding label batch len & loss
        # ========================================================
        if isinstance(labels, torch.Tensor):
            labels = labels.to(device=self.device, dtype=torch.int64)
        else:
            # Right-pad with -100 so padded positions are ignored by the loss.
            labels = torch.tensor(
                [list(label) + [-100] * (max_tokens - len(label)) for label in labels],
                dtype=torch.int64,
                device=self.device,
            )
        logits = output.logits
        # No next-token shift: each position is scored against its own label.
        return F.cross_entropy(
            logits.reshape(-1, logits.shape[-1]),
            labels.reshape(-1),
            ignore_index=-100,
        )

    def get_fsdp_wrap_module_list(self) -> List:
        """Return the submodules FSDP should wrap individually.

        The list holds every entry of ``self.model.transformer.blocks`` (in
        order) followed by ``self.model.transformer.ff_out``.
        """
        transformer = self.model.transformer
        wrap_units = list(transformer.blocks)
        # NOTE(review): attribute lookup fails if the transformer has no
        # ``ff_out`` entry; confirm this holds for every config used with FSDP.
        wrap_units.append(transformer.ff_out)
        return wrap_units

Extremely barebones HF model wrapper.

XLLMXDimooModelLM( config: divisor.mmada.configuration_llada.LLaDAConfig, *args, **kwargs)
24    def __init__(self, config: LLaDAConfig, *args, **kwargs):
25        print(f"Initializing MMadaModelLM with config: {config}")
26        super().__init__(config, *args, **kwargs)

Args: config (`LLaDAConfig`): Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the configuration. Use the `PreTrainedModel.from_pretrained` method to load the model weights.

config_class = <class 'divisor.mmada.configuration_llada.LLaDAConfig'>
base_model_prefix = 'model'
def forward( self, input_ids=None, labels=None, infer=False, use_cache=False, to_compute_mask=None, cat='', **kwargs):
28    def forward(self, input_ids=None, labels=None, infer=False, use_cache=False, to_compute_mask=None, cat="", **kwargs):
29        if infer:
30            input_ids = input_ids.tolist()
31        # ========================================================
32        # padding input batch len & attention bias for attention mask
33        # ========================================================
34        max_tokens = max([len(_) for _ in input_ids])
35        original_lengths = [len(example) for example in input_ids]  # every sample len --> record for attention mask
36        input_ids = [example + [0] * (max_tokens - len(example)) for example in input_ids]  # padding 0 to right --> max length
37        input_ids = torch.tensor(input_ids, dtype=torch.int64, device=self.device)
38        # attn mask
39        attention_mask = create_attention_mask(original_lengths, max_tokens, self.device)
40        attention_bias = (attention_mask[:, :, None] & attention_mask[:, None, :]).bool().unsqueeze(1)
41        # ========================================================
42        # model output
43        # ========================================================
44        output = LLaDAModelLM.forward(self, input_ids=input_ids, attention_bias=attention_bias, use_cache=use_cache, to_compute_mask=to_compute_mask, cat=cat)
45        if infer:
46            return output
47
48        # ========================================================
49        # padding label batch len & loss
50        # ========================================================
51        labels = [label + [-100] * (max_tokens - len(label)) for label in labels]  # padding -100 to right --> max length
52        labels = torch.tensor(labels, dtype=torch.int64, device=self.device)
53        logits = output.logits
54        loss = F.cross_entropy(
55            logits.contiguous().view(-1, logits.shape[-1]),
56            labels.contiguous().view(-1),
57            ignore_index=-100,
58        )
59        return loss

Pad a ragged batch of token ids, run the LLaDA backbone, and return its output (when `infer=True`) or the cross-entropy training loss.

This overrides `LLaDAModelLM.forward`.

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

def get_fsdp_wrap_module_list(self) -> List:
61    def get_fsdp_wrap_module_list(self) -> List:
62        modules = [*list(self.model.transformer.blocks), self.model.transformer.ff_out]
63        return modules