divisor.flux2.text_encoder

  1# SPDX-License-Identifier:Apache-2.0
  2# original BFL Flux code from https://github.com/black-forest-labs/flux2
  3
  4from pathlib import Path
  5
  6from PIL import Image
  7from einops import rearrange
  8from nnll.console import nfo
  9from divisor.registry import gfx_dtype
 10import torch
 11import torch.nn as nn
 12from transformers import AutoProcessor, Mistral3ForConditionalGeneration
 13
 14
 15from divisor.flux2.sampling import cap_pixels, concatenate_images
 16from divisor.flux2.system_messages import (
 17    PROMPT_IMAGE_INTEGRITY,
 18    PROMPT_IMAGE_INTEGRITY_FOLLOW_UP,
 19    PROMPT_TEXT_INTEGRITY,
 20    SYSTEM_MESSAGE,
 21    SYSTEM_MESSAGE_UPSAMPLING_I2I,
 22    SYSTEM_MESSAGE_UPSAMPLING_T2I,
 23    SYSTEM_PROMPT_CONTENT_FILTER,
 24)
 25
 26OUTPUT_LAYERS = [10, 20, 30]
 27MAX_LENGTH = 512
 28UPSAMPLING_MAX_IMAGE_SIZE = 768**2
 29precision: torch.dtype = gfx_dtype
 30
 31
 32class Mistral3SmallEmbedder(nn.Module):
 33    def __init__(
 34        self,
 35        model_spec: str = "mistralai/Mistral-Small-3.2-24B-Instruct-2506",
 36        model_spec_processor: str = "mistralai/Mistral-Small-3.1-24B-Instruct-2503",
 37        torch_dtype: torch.dtype = precision,
 38    ):
 39        super().__init__()
 40
 41        self.model: Mistral3ForConditionalGeneration = Mistral3ForConditionalGeneration.from_pretrained(
 42            model_spec,
 43            dtype=torch_dtype,
 44        )
 45        self.processor = AutoProcessor.from_pretrained(model_spec_processor, use_fast=False)
 46        self.yes_token, self.no_token = self.processor.tokenizer.encode(["yes", "no"], add_special_tokens=False)
 47
 48        self.max_length = MAX_LENGTH
 49        self.upsampling_max_image_size = UPSAMPLING_MAX_IMAGE_SIZE
 50
 51    def _validate_and_process_images(self, img: list[list[Image.Image]] | list[Image.Image]) -> list[list[Image.Image]]:
 52        # Simple validation: ensure it's a list of PIL images or list of lists of PIL images
 53        if not img:
 54            return []
 55
 56        # Check if it's a list of lists or a list of images
 57        if isinstance(img[0], Image.Image):
 58            # It's a list of images, convert to list of lists
 59            img = [[im] for im in img]  # type: ignore
 60
 61        # potentially concatenate multiple images to reduce the size
 62        img = [[concatenate_images(img_i)] if len(img_i) > 1 else img_i for img_i in img]  # type: ignore
 63
 64        # cap the pixels
 65        img = [[cap_pixels(img_i, self.upsampling_max_image_size) for img_i in img_i] for img_i in img]  # type: ignore
 66        return img  # type: ignore
 67
 68    def format_input(
 69        self,
 70        txt: list[str],
 71        system_message: str = SYSTEM_MESSAGE,
 72        img: list[Image.Image] | list[list[Image.Image]] | None = None,
 73    ) -> list[list[dict]]:
 74        """
 75        Format a batch of text prompts into the conversation format expected by apply_chat_template.
 76        Optionally, add images to the input.
 77
 78        Args:
 79            txt: List of text prompts
 80            system_message: System message to use (default: CREATIVE_SYSTEM_MESSAGE)
 81            img: List of images to add to the input.
 82
 83        Returns:
 84            List of conversations, where each conversation is a list of message dicts
 85        """
 86        # Remove [IMG] tokens from prompts to avoid Pixtral validation issues
 87        # when truncation is enabled. The processor counts [IMG] tokens and fails
 88        # if the count changes after truncation.
 89        cleaned_txt = [prompt.replace("[IMG]", "") for prompt in txt]
 90
 91        if img is None or len(img) == 0:
 92            return [
 93                [
 94                    {
 95                        "role": "system",
 96                        "content": [{"type": "text", "text": system_message}],
 97                    },
 98                    {"role": "user", "content": [{"type": "text", "text": prompt}]},
 99                ]
100                for prompt in cleaned_txt
101            ]
102        else:
103            assert len(img) == len(txt), "Number of images must match number of prompts"
104            img = self._validate_and_process_images(img)
105
106            messages = [
107                [
108                    {
109                        "role": "system",
110                        "content": [{"type": "text", "text": system_message}],
111                    },
112                ]
113                for _ in cleaned_txt
114            ]
115
116            for i, (el, images) in enumerate(zip(messages, img)):
117                # optionally add the images per batch element.
118                if images is not None:
119                    el.append(
120                        {
121                            "role": "user",
122                            "content": [{"type": "image", "image": image_obj} for image_obj in images],
123                        }
124                    )
125                # add the text.
126                el.append(
127                    {
128                        "role": "user",
129                        "content": [{"type": "text", "text": cleaned_txt[i]}],
130                    }
131                )
132
133            return messages
134
135    @torch.no_grad()
136    def upsample_prompt(
137        self,
138        txt: list[str],
139        img: list[Image.Image] | list[list[Image.Image]] | None = None,
140        temperature: float = 0.15,
141    ) -> list[str]:
142        """
143        Upsample prompts using the model's generate method.
144
145        Args:
146            txt: List of input prompts to upsample
147            img: Optional list of images or list of lists of images. If None or all None, uses t2i mode, otherwise i2i mode.
148
149        Returns:
150            List of upsampled prompts
151        """
152        # Set system message based on whether images are provided
153        if img is None or len(img) == 0 or img[0] is None:
154            system_message = SYSTEM_MESSAGE_UPSAMPLING_T2I
155        else:
156            system_message = SYSTEM_MESSAGE_UPSAMPLING_I2I
157
158        # Format input messages
159        messages_batch = self.format_input(txt=txt, system_message=system_message, img=img)
160
161        # Process all messages at once
162        # with image processing a too short max length can throw an error in here.
163        try:
164            inputs = self.processor.apply_chat_template(
165                messages_batch,
166                add_generation_prompt=True,
167                tokenize=True,
168                return_dict=True,
169                return_tensors="pt",
170                padding="max_length",
171                truncation=True,
172                max_length=2048,
173            )
174        except ValueError as e:
175            nfo(f"Error processing input: {e}, your max length is probably too short, when you have images in the input.")
176            raise e
177
178        # Move to device
179        inputs["input_ids"] = inputs["input_ids"].to(self.model.device)
180        inputs["attention_mask"] = inputs["attention_mask"].to(self.model.device)
181
182        if "pixel_values" in inputs:
183            inputs["pixel_values"] = inputs["pixel_values"].to(self.model.device, self.model.dtype)
184
185        # Generate text using the model's generate method
186        try:
187            generated_ids = self.model.generate(
188                **inputs,
189                max_new_tokens=512,
190                do_sample=True,
191                temperature=temperature,
192                use_cache=True,
193            )
194
195            # Decode only the newly generated tokens (skip input tokens)
196            # Extract only the generated portion
197            input_length = inputs["input_ids"].shape[1]
198            generated_tokens = generated_ids[:, input_length:]
199
200            raw_txt = self.processor.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True, clean_up_tokenization_spaces=True)
201            return raw_txt
202        except Exception as e:
203            nfo(f"Error generating upsampled prompt: {e}, returning original prompt")
204            return txt
205
206    @torch.no_grad()
207    def forward(self, txt: list[str]):
208        # Format input messages
209        messages_batch = self.format_input(txt=txt)
210
211        # Process all messages at once
212        # with image processing a too short max length can throw an error in here.
213        inputs = self.processor.apply_chat_template(
214            messages_batch,
215            add_generation_prompt=False,
216            tokenize=True,
217            return_dict=True,
218            return_tensors="pt",
219            padding="max_length",
220            truncation=True,
221            max_length=self.max_length,
222        )
223
224        # Move to device
225        input_ids = inputs["input_ids"].to(self.model.device)
226        attention_mask = inputs["attention_mask"].to(self.model.device)
227
228        # Forward pass through the model
229        output = self.model(
230            input_ids=input_ids,
231            attention_mask=attention_mask,
232            output_hidden_states=True,
233            use_cache=False,
234        )
235
236        out = torch.stack([output.hidden_states[k] for k in OUTPUT_LAYERS], dim=1)
237        return rearrange(out, "b c l d -> b l (c d)")
238
239    def yes_no_logit_processor(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
240        """
241        Sets all tokens but yes/no to the minimum.
242        """
243        scores_yes_token = scores[:, self.yes_token].clone()
244        scores_no_token = scores[:, self.no_token].clone()
245        scores_min = scores.min()
246        scores[:, :] = scores_min - 1
247        scores[:, self.yes_token] = scores_yes_token
248        scores[:, self.no_token] = scores_no_token
249        return scores
250
251    def test_image(self, image: Image.Image | str | Path | torch.Tensor) -> bool:
252        if isinstance(image, torch.Tensor):
253            image = rearrange(image[0].clamp(-1.0, 1.0), "c h w -> h w c")
254            image = Image.fromarray((127.5 * (image + 1.0)).cpu().byte().numpy())
255        elif isinstance(image, (str, Path)):
256            image = Image.open(image)
257
258        # 512^2 pixels are enough for checking
259        w, h = image.size
260        f = (512**2 / (w * h)) ** 0.5
261        image = image.resize((int(f * w), int(f * h)))
262
263        chat = [
264            {
265                "role": "system",
266                "content": [
267                    {
268                        "type": "text",
269                        "text": SYSTEM_PROMPT_CONTENT_FILTER,
270                    },
271                ],
272            },
273            {
274                "role": "user",
275                "content": [
276                    {
277                        "type": "text",
278                        "text": PROMPT_IMAGE_INTEGRITY,
279                    },
280                    {
281                        "type": "image",
282                        "image": image,
283                    },
284                    {
285                        "type": "text",
286                        "text": PROMPT_IMAGE_INTEGRITY_FOLLOW_UP,
287                    },
288                ],
289            },
290        ]
291
292        inputs = self.processor.apply_chat_template(
293            chat,
294            add_generation_prompt=True,
295            tokenize=True,
296            return_dict=True,
297            return_tensors="pt",
298        ).to(self.model.device)
299        inputs["pixel_values"] = inputs["pixel_values"].to(dtype=self.model.dtype)
300
301        generate_ids = self.model.generate(
302            **inputs,
303            max_new_tokens=1,
304            logits_processor=[self.yes_no_logit_processor],  # type: ignore
305            do_sample=False,
306        )
307
308        return generate_ids[0, -1].item() == self.yes_token
309
310    def test_txt(self, txt: str) -> bool:
311        chat = [
312            {
313                "role": "system",
314                "content": [
315                    {
316                        "type": "text",
317                        "text": SYSTEM_PROMPT_CONTENT_FILTER,
318                    },
319                ],
320            },
321            {
322                "role": "user",
323                "content": [
324                    {
325                        "type": "text",
326                        "text": PROMPT_TEXT_INTEGRITY.format(prompt=txt),
327                    },
328                ],
329            },
330        ]
331
332        inputs = self.processor.apply_chat_template(
333            chat,
334            add_generation_prompt=True,
335            tokenize=True,
336            return_dict=True,
337            return_tensors="pt",
338        ).to(self.model.device)
339
340        generate_ids = self.model.generate(
341            **inputs,
342            max_new_tokens=1,
343            logits_processor=[self.yes_no_logit_processor],  # type: ignore
344            do_sample=False,
345        )
346        return generate_ids[0, -1].item() == self.yes_token
OUTPUT_LAYERS = [10, 20, 30]
MAX_LENGTH = 512
UPSAMPLING_MAX_IMAGE_SIZE = 589824
precision: torch.dtype = torch.bfloat16
class Mistral3SmallEmbedder(torch.nn.modules.module.Module):
 33class Mistral3SmallEmbedder(nn.Module):
 34    def __init__(
 35        self,
 36        model_spec: str = "mistralai/Mistral-Small-3.2-24B-Instruct-2506",
 37        model_spec_processor: str = "mistralai/Mistral-Small-3.1-24B-Instruct-2503",
 38        torch_dtype: torch.dtype = precision,
 39    ):
 40        super().__init__()
 41
 42        self.model: Mistral3ForConditionalGeneration = Mistral3ForConditionalGeneration.from_pretrained(
 43            model_spec,
 44            dtype=torch_dtype,
 45        )
 46        self.processor = AutoProcessor.from_pretrained(model_spec_processor, use_fast=False)
 47        self.yes_token, self.no_token = self.processor.tokenizer.encode(["yes", "no"], add_special_tokens=False)
 48
 49        self.max_length = MAX_LENGTH
 50        self.upsampling_max_image_size = UPSAMPLING_MAX_IMAGE_SIZE
 51
 52    def _validate_and_process_images(self, img: list[list[Image.Image]] | list[Image.Image]) -> list[list[Image.Image]]:
 53        # Simple validation: ensure it's a list of PIL images or list of lists of PIL images
 54        if not img:
 55            return []
 56
 57        # Check if it's a list of lists or a list of images
 58        if isinstance(img[0], Image.Image):
 59            # It's a list of images, convert to list of lists
 60            img = [[im] for im in img]  # type: ignore
 61
 62        # potentially concatenate multiple images to reduce the size
 63        img = [[concatenate_images(img_i)] if len(img_i) > 1 else img_i for img_i in img]  # type: ignore
 64
 65        # cap the pixels
 66        img = [[cap_pixels(img_i, self.upsampling_max_image_size) for img_i in img_i] for img_i in img]  # type: ignore
 67        return img  # type: ignore
 68
 69    def format_input(
 70        self,
 71        txt: list[str],
 72        system_message: str = SYSTEM_MESSAGE,
 73        img: list[Image.Image] | list[list[Image.Image]] | None = None,
 74    ) -> list[list[dict]]:
 75        """
 76        Format a batch of text prompts into the conversation format expected by apply_chat_template.
 77        Optionally, add images to the input.
 78
 79        Args:
 80            txt: List of text prompts
 81            system_message: System message to use (default: CREATIVE_SYSTEM_MESSAGE)
 82            img: List of images to add to the input.
 83
 84        Returns:
 85            List of conversations, where each conversation is a list of message dicts
 86        """
 87        # Remove [IMG] tokens from prompts to avoid Pixtral validation issues
 88        # when truncation is enabled. The processor counts [IMG] tokens and fails
 89        # if the count changes after truncation.
 90        cleaned_txt = [prompt.replace("[IMG]", "") for prompt in txt]
 91
 92        if img is None or len(img) == 0:
 93            return [
 94                [
 95                    {
 96                        "role": "system",
 97                        "content": [{"type": "text", "text": system_message}],
 98                    },
 99                    {"role": "user", "content": [{"type": "text", "text": prompt}]},
100                ]
101                for prompt in cleaned_txt
102            ]
103        else:
104            assert len(img) == len(txt), "Number of images must match number of prompts"
105            img = self._validate_and_process_images(img)
106
107            messages = [
108                [
109                    {
110                        "role": "system",
111                        "content": [{"type": "text", "text": system_message}],
112                    },
113                ]
114                for _ in cleaned_txt
115            ]
116
117            for i, (el, images) in enumerate(zip(messages, img)):
118                # optionally add the images per batch element.
119                if images is not None:
120                    el.append(
121                        {
122                            "role": "user",
123                            "content": [{"type": "image", "image": image_obj} for image_obj in images],
124                        }
125                    )
126                # add the text.
127                el.append(
128                    {
129                        "role": "user",
130                        "content": [{"type": "text", "text": cleaned_txt[i]}],
131                    }
132                )
133
134            return messages
135
136    @torch.no_grad()
137    def upsample_prompt(
138        self,
139        txt: list[str],
140        img: list[Image.Image] | list[list[Image.Image]] | None = None,
141        temperature: float = 0.15,
142    ) -> list[str]:
143        """
144        Upsample prompts using the model's generate method.
145
146        Args:
147            txt: List of input prompts to upsample
148            img: Optional list of images or list of lists of images. If None or all None, uses t2i mode, otherwise i2i mode.
149
150        Returns:
151            List of upsampled prompts
152        """
153        # Set system message based on whether images are provided
154        if img is None or len(img) == 0 or img[0] is None:
155            system_message = SYSTEM_MESSAGE_UPSAMPLING_T2I
156        else:
157            system_message = SYSTEM_MESSAGE_UPSAMPLING_I2I
158
159        # Format input messages
160        messages_batch = self.format_input(txt=txt, system_message=system_message, img=img)
161
162        # Process all messages at once
163        # with image processing a too short max length can throw an error in here.
164        try:
165            inputs = self.processor.apply_chat_template(
166                messages_batch,
167                add_generation_prompt=True,
168                tokenize=True,
169                return_dict=True,
170                return_tensors="pt",
171                padding="max_length",
172                truncation=True,
173                max_length=2048,
174            )
175        except ValueError as e:
176            nfo(f"Error processing input: {e}, your max length is probably too short, when you have images in the input.")
177            raise e
178
179        # Move to device
180        inputs["input_ids"] = inputs["input_ids"].to(self.model.device)
181        inputs["attention_mask"] = inputs["attention_mask"].to(self.model.device)
182
183        if "pixel_values" in inputs:
184            inputs["pixel_values"] = inputs["pixel_values"].to(self.model.device, self.model.dtype)
185
186        # Generate text using the model's generate method
187        try:
188            generated_ids = self.model.generate(
189                **inputs,
190                max_new_tokens=512,
191                do_sample=True,
192                temperature=temperature,
193                use_cache=True,
194            )
195
196            # Decode only the newly generated tokens (skip input tokens)
197            # Extract only the generated portion
198            input_length = inputs["input_ids"].shape[1]
199            generated_tokens = generated_ids[:, input_length:]
200
201            raw_txt = self.processor.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True, clean_up_tokenization_spaces=True)
202            return raw_txt
203        except Exception as e:
204            nfo(f"Error generating upsampled prompt: {e}, returning original prompt")
205            return txt
206
207    @torch.no_grad()
208    def forward(self, txt: list[str]):
209        # Format input messages
210        messages_batch = self.format_input(txt=txt)
211
212        # Process all messages at once
213        # with image processing a too short max length can throw an error in here.
214        inputs = self.processor.apply_chat_template(
215            messages_batch,
216            add_generation_prompt=False,
217            tokenize=True,
218            return_dict=True,
219            return_tensors="pt",
220            padding="max_length",
221            truncation=True,
222            max_length=self.max_length,
223        )
224
225        # Move to device
226        input_ids = inputs["input_ids"].to(self.model.device)
227        attention_mask = inputs["attention_mask"].to(self.model.device)
228
229        # Forward pass through the model
230        output = self.model(
231            input_ids=input_ids,
232            attention_mask=attention_mask,
233            output_hidden_states=True,
234            use_cache=False,
235        )
236
237        out = torch.stack([output.hidden_states[k] for k in OUTPUT_LAYERS], dim=1)
238        return rearrange(out, "b c l d -> b l (c d)")
239
240    def yes_no_logit_processor(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
241        """
242        Sets all tokens but yes/no to the minimum.
243        """
244        scores_yes_token = scores[:, self.yes_token].clone()
245        scores_no_token = scores[:, self.no_token].clone()
246        scores_min = scores.min()
247        scores[:, :] = scores_min - 1
248        scores[:, self.yes_token] = scores_yes_token
249        scores[:, self.no_token] = scores_no_token
250        return scores
251
252    def test_image(self, image: Image.Image | str | Path | torch.Tensor) -> bool:
253        if isinstance(image, torch.Tensor):
254            image = rearrange(image[0].clamp(-1.0, 1.0), "c h w -> h w c")
255            image = Image.fromarray((127.5 * (image + 1.0)).cpu().byte().numpy())
256        elif isinstance(image, (str, Path)):
257            image = Image.open(image)
258
259        # 512^2 pixels are enough for checking
260        w, h = image.size
261        f = (512**2 / (w * h)) ** 0.5
262        image = image.resize((int(f * w), int(f * h)))
263
264        chat = [
265            {
266                "role": "system",
267                "content": [
268                    {
269                        "type": "text",
270                        "text": SYSTEM_PROMPT_CONTENT_FILTER,
271                    },
272                ],
273            },
274            {
275                "role": "user",
276                "content": [
277                    {
278                        "type": "text",
279                        "text": PROMPT_IMAGE_INTEGRITY,
280                    },
281                    {
282                        "type": "image",
283                        "image": image,
284                    },
285                    {
286                        "type": "text",
287                        "text": PROMPT_IMAGE_INTEGRITY_FOLLOW_UP,
288                    },
289                ],
290            },
291        ]
292
293        inputs = self.processor.apply_chat_template(
294            chat,
295            add_generation_prompt=True,
296            tokenize=True,
297            return_dict=True,
298            return_tensors="pt",
299        ).to(self.model.device)
300        inputs["pixel_values"] = inputs["pixel_values"].to(dtype=self.model.dtype)
301
302        generate_ids = self.model.generate(
303            **inputs,
304            max_new_tokens=1,
305            logits_processor=[self.yes_no_logit_processor],  # type: ignore
306            do_sample=False,
307        )
308
309        return generate_ids[0, -1].item() == self.yes_token
310
311    def test_txt(self, txt: str) -> bool:
312        chat = [
313            {
314                "role": "system",
315                "content": [
316                    {
317                        "type": "text",
318                        "text": SYSTEM_PROMPT_CONTENT_FILTER,
319                    },
320                ],
321            },
322            {
323                "role": "user",
324                "content": [
325                    {
326                        "type": "text",
327                        "text": PROMPT_TEXT_INTEGRITY.format(prompt=txt),
328                    },
329                ],
330            },
331        ]
332
333        inputs = self.processor.apply_chat_template(
334            chat,
335            add_generation_prompt=True,
336            tokenize=True,
337            return_dict=True,
338            return_tensors="pt",
339        ).to(self.model.device)
340
341        generate_ids = self.model.generate(
342            **inputs,
343            max_new_tokens=1,
344            logits_processor=[self.yes_no_logit_processor],  # type: ignore
345            do_sample=False,
346        )
347        return generate_ids[0, -1].item() == self.yes_token

Base class for all neural network modules.

Your models should also subclass this class.

Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::

import torch.nn as nn
import torch.nn.functional as F


class Model(nn.Module):
    """Example module that registers two convolution layers as submodules."""

    def __init__(self) -> None:
        # The parent initializer must run before submodules are assigned.
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        # Apply each convolution followed by a ReLU.
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and will also have their parameters converted when you call to(), etc.

As per the example above, an __init__() call to the parent class must be made before assignment on the child.

:ivar training: Boolean that represents whether this module is in training or evaluation mode. :vartype training: bool

Mistral3SmallEmbedder( model_spec: str = 'mistralai/Mistral-Small-3.2-24B-Instruct-2506', model_spec_processor: str = 'mistralai/Mistral-Small-3.1-24B-Instruct-2503', torch_dtype: torch.dtype = torch.bfloat16)
34    def __init__(
35        self,
36        model_spec: str = "mistralai/Mistral-Small-3.2-24B-Instruct-2506",
37        model_spec_processor: str = "mistralai/Mistral-Small-3.1-24B-Instruct-2503",
38        torch_dtype: torch.dtype = precision,
39    ):
40        super().__init__()
41
42        self.model: Mistral3ForConditionalGeneration = Mistral3ForConditionalGeneration.from_pretrained(
43            model_spec,
44            dtype=torch_dtype,
45        )
46        self.processor = AutoProcessor.from_pretrained(model_spec_processor, use_fast=False)
47        self.yes_token, self.no_token = self.processor.tokenizer.encode(["yes", "no"], add_special_tokens=False)
48
49        self.max_length = MAX_LENGTH
50        self.upsampling_max_image_size = UPSAMPLING_MAX_IMAGE_SIZE

Initialize internal Module state, shared by both nn.Module and ScriptModule.

model: transformers.models.mistral3.modeling_mistral3.Mistral3ForConditionalGeneration
processor
max_length
upsampling_max_image_size
def format_input( self, txt: list[str], system_message: str = 'You are an AI that reasons about image descriptions. You give structured responses focusing on object relationships, object\nattribution and actions without speculation.', img: list[PIL.Image.Image] | list[list[PIL.Image.Image]] | None = None) -> list[list[dict]]:
 69    def format_input(
 70        self,
 71        txt: list[str],
 72        system_message: str = SYSTEM_MESSAGE,
 73        img: list[Image.Image] | list[list[Image.Image]] | None = None,
 74    ) -> list[list[dict]]:
 75        """
 76        Format a batch of text prompts into the conversation format expected by apply_chat_template.
 77        Optionally, add images to the input.
 78
 79        Args:
 80            txt: List of text prompts
 81            system_message: System message to use (default: CREATIVE_SYSTEM_MESSAGE)
 82            img: List of images to add to the input.
 83
 84        Returns:
 85            List of conversations, where each conversation is a list of message dicts
 86        """
 87        # Remove [IMG] tokens from prompts to avoid Pixtral validation issues
 88        # when truncation is enabled. The processor counts [IMG] tokens and fails
 89        # if the count changes after truncation.
 90        cleaned_txt = [prompt.replace("[IMG]", "") for prompt in txt]
 91
 92        if img is None or len(img) == 0:
 93            return [
 94                [
 95                    {
 96                        "role": "system",
 97                        "content": [{"type": "text", "text": system_message}],
 98                    },
 99                    {"role": "user", "content": [{"type": "text", "text": prompt}]},
100                ]
101                for prompt in cleaned_txt
102            ]
103        else:
104            assert len(img) == len(txt), "Number of images must match number of prompts"
105            img = self._validate_and_process_images(img)
106
107            messages = [
108                [
109                    {
110                        "role": "system",
111                        "content": [{"type": "text", "text": system_message}],
112                    },
113                ]
114                for _ in cleaned_txt
115            ]
116
117            for i, (el, images) in enumerate(zip(messages, img)):
118                # optionally add the images per batch element.
119                if images is not None:
120                    el.append(
121                        {
122                            "role": "user",
123                            "content": [{"type": "image", "image": image_obj} for image_obj in images],
124                        }
125                    )
126                # add the text.
127                el.append(
128                    {
129                        "role": "user",
130                        "content": [{"type": "text", "text": cleaned_txt[i]}],
131                    }
132                )
133
134            return messages

Format a batch of text prompts into the conversation format expected by apply_chat_template. Optionally, add images to the input.

Args: txt: List of text prompts. system_message: System message to use (default: SYSTEM_MESSAGE). img: List of images to add to the input.

Returns: List of conversations, where each conversation is a list of message dicts

@torch.no_grad()
def upsample_prompt( self, txt: list[str], img: list[PIL.Image.Image] | list[list[PIL.Image.Image]] | None = None, temperature: float = 0.15) -> list[str]:
136    @torch.no_grad()
137    def upsample_prompt(
138        self,
139        txt: list[str],
140        img: list[Image.Image] | list[list[Image.Image]] | None = None,
141        temperature: float = 0.15,
142    ) -> list[str]:
143        """
144        Upsample prompts using the model's generate method.
145
146        Args:
147            txt: List of input prompts to upsample
148            img: Optional list of images or list of lists of images. If None or all None, uses t2i mode, otherwise i2i mode.
149
150        Returns:
151            List of upsampled prompts
152        """
153        # Set system message based on whether images are provided
154        if img is None or len(img) == 0 or img[0] is None:
155            system_message = SYSTEM_MESSAGE_UPSAMPLING_T2I
156        else:
157            system_message = SYSTEM_MESSAGE_UPSAMPLING_I2I
158
159        # Format input messages
160        messages_batch = self.format_input(txt=txt, system_message=system_message, img=img)
161
162        # Process all messages at once
163        # with image processing a too short max length can throw an error in here.
164        try:
165            inputs = self.processor.apply_chat_template(
166                messages_batch,
167                add_generation_prompt=True,
168                tokenize=True,
169                return_dict=True,
170                return_tensors="pt",
171                padding="max_length",
172                truncation=True,
173                max_length=2048,
174            )
175        except ValueError as e:
176            nfo(f"Error processing input: {e}, your max length is probably too short, when you have images in the input.")
177            raise e
178
179        # Move to device
180        inputs["input_ids"] = inputs["input_ids"].to(self.model.device)
181        inputs["attention_mask"] = inputs["attention_mask"].to(self.model.device)
182
183        if "pixel_values" in inputs:
184            inputs["pixel_values"] = inputs["pixel_values"].to(self.model.device, self.model.dtype)
185
186        # Generate text using the model's generate method
187        try:
188            generated_ids = self.model.generate(
189                **inputs,
190                max_new_tokens=512,
191                do_sample=True,
192                temperature=temperature,
193                use_cache=True,
194            )
195
196            # Decode only the newly generated tokens (skip input tokens)
197            # Extract only the generated portion
198            input_length = inputs["input_ids"].shape[1]
199            generated_tokens = generated_ids[:, input_length:]
200
201            raw_txt = self.processor.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True, clean_up_tokenization_spaces=True)
202            return raw_txt
203        except Exception as e:
204            nfo(f"Error generating upsampled prompt: {e}, returning original prompt")
205            return txt

Upsample prompts using the model's generate method.

Args:
    txt: List of input prompts to upsample.
    img: Optional list of images or list of lists of images. If None, empty, or its first entry is None, uses t2i mode; otherwise i2i mode.
    temperature: Sampling temperature passed to the model's generate method.

Returns: List of upsampled prompts

@torch.no_grad()
def forward(self, txt: list[str]):
207    @torch.no_grad()
208    def forward(self, txt: list[str]):
209        # Format input messages
210        messages_batch = self.format_input(txt=txt)
211
212        # Process all messages at once
213        # with image processing a too short max length can throw an error in here.
214        inputs = self.processor.apply_chat_template(
215            messages_batch,
216            add_generation_prompt=False,
217            tokenize=True,
218            return_dict=True,
219            return_tensors="pt",
220            padding="max_length",
221            truncation=True,
222            max_length=self.max_length,
223        )
224
225        # Move to device
226        input_ids = inputs["input_ids"].to(self.model.device)
227        attention_mask = inputs["attention_mask"].to(self.model.device)
228
229        # Forward pass through the model
230        output = self.model(
231            input_ids=input_ids,
232            attention_mask=attention_mask,
233            output_hidden_states=True,
234            use_cache=False,
235        )
236
237        out = torch.stack([output.hidden_states[k] for k in OUTPUT_LAYERS], dim=1)
238        return rearrange(out, "b c l d -> b l (c d)")

Define the computation performed at every call.

Should be overridden by all subclasses.

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

def yes_no_logit_processor( self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
240    def yes_no_logit_processor(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
241        """
242        Sets all tokens but yes/no to the minimum.
243        """
244        scores_yes_token = scores[:, self.yes_token].clone()
245        scores_no_token = scores[:, self.no_token].clone()
246        scores_min = scores.min()
247        scores[:, :] = scores_min - 1
248        scores[:, self.yes_token] = scores_yes_token
249        scores[:, self.no_token] = scores_no_token
250        return scores

Sets all tokens but yes/no to the minimum.

def test_image( self, image: PIL.Image.Image | str | pathlib._local.Path | torch.Tensor) -> bool:
252    def test_image(self, image: Image.Image | str | Path | torch.Tensor) -> bool:
253        if isinstance(image, torch.Tensor):
254            image = rearrange(image[0].clamp(-1.0, 1.0), "c h w -> h w c")
255            image = Image.fromarray((127.5 * (image + 1.0)).cpu().byte().numpy())
256        elif isinstance(image, (str, Path)):
257            image = Image.open(image)
258
259        # 512^2 pixels are enough for checking
260        w, h = image.size
261        f = (512**2 / (w * h)) ** 0.5
262        image = image.resize((int(f * w), int(f * h)))
263
264        chat = [
265            {
266                "role": "system",
267                "content": [
268                    {
269                        "type": "text",
270                        "text": SYSTEM_PROMPT_CONTENT_FILTER,
271                    },
272                ],
273            },
274            {
275                "role": "user",
276                "content": [
277                    {
278                        "type": "text",
279                        "text": PROMPT_IMAGE_INTEGRITY,
280                    },
281                    {
282                        "type": "image",
283                        "image": image,
284                    },
285                    {
286                        "type": "text",
287                        "text": PROMPT_IMAGE_INTEGRITY_FOLLOW_UP,
288                    },
289                ],
290            },
291        ]
292
293        inputs = self.processor.apply_chat_template(
294            chat,
295            add_generation_prompt=True,
296            tokenize=True,
297            return_dict=True,
298            return_tensors="pt",
299        ).to(self.model.device)
300        inputs["pixel_values"] = inputs["pixel_values"].to(dtype=self.model.dtype)
301
302        generate_ids = self.model.generate(
303            **inputs,
304            max_new_tokens=1,
305            logits_processor=[self.yes_no_logit_processor],  # type: ignore
306            do_sample=False,
307        )
308
309        return generate_ids[0, -1].item() == self.yes_token
def test_txt(self, txt: str) -> bool:
311    def test_txt(self, txt: str) -> bool:
312        chat = [
313            {
314                "role": "system",
315                "content": [
316                    {
317                        "type": "text",
318                        "text": SYSTEM_PROMPT_CONTENT_FILTER,
319                    },
320                ],
321            },
322            {
323                "role": "user",
324                "content": [
325                    {
326                        "type": "text",
327                        "text": PROMPT_TEXT_INTEGRITY.format(prompt=txt),
328                    },
329                ],
330            },
331        ]
332
333        inputs = self.processor.apply_chat_template(
334            chat,
335            add_generation_prompt=True,
336            tokenize=True,
337            return_dict=True,
338            return_tensors="pt",
339        ).to(self.model.device)
340
341        generate_ids = self.model.generate(
342            **inputs,
343            max_new_tokens=1,
344            logits_processor=[self.yes_no_logit_processor],  # type: ignore
345            do_sample=False,
346        )
347        return generate_ids[0, -1].item() == self.yes_token