divisor.flux2.text_encoder
# SPDX-License-Identifier:Apache-2.0
# original BFL Flux code from https://github.com/black-forest-labs/flux2

from pathlib import Path

from PIL import Image
from einops import rearrange
from nnll.console import nfo
from divisor.registry import gfx_dtype
import torch
import torch.nn as nn
from transformers import AutoProcessor, Mistral3ForConditionalGeneration


from divisor.flux2.sampling import cap_pixels, concatenate_images
from divisor.flux2.system_messages import (
    PROMPT_IMAGE_INTEGRITY,
    PROMPT_IMAGE_INTEGRITY_FOLLOW_UP,
    PROMPT_TEXT_INTEGRITY,
    SYSTEM_MESSAGE,
    SYSTEM_MESSAGE_UPSAMPLING_I2I,
    SYSTEM_MESSAGE_UPSAMPLING_T2I,
    SYSTEM_PROMPT_CONTENT_FILTER,
)

# Indices into the model's hidden_states tuple that forward() concatenates.
OUTPUT_LAYERS = [10, 20, 30]
# Token length prompts are padded/truncated to in forward().
MAX_LENGTH = 512
# Pixel budget handed to cap_pixels for images used in prompt upsampling.
UPSAMPLING_MAX_IMAGE_SIZE = 768**2
precision: torch.dtype = gfx_dtype


class Mistral3SmallEmbedder(nn.Module):
    """
    Wrapper around a Mistral3 conditional-generation model and its processor.

    Provides prompt embedding (forward), prompt upsampling via generation
    (upsample_prompt) and single-token yes/no checks on images and text
    (test_image, test_txt).
    """

    def __init__(
        self,
        model_spec: str = "mistralai/Mistral-Small-3.2-24B-Instruct-2506",
        model_spec_processor: str = "mistralai/Mistral-Small-3.1-24B-Instruct-2503",
        torch_dtype: torch.dtype = precision,
    ):
        """
        Load the model and processor and cache the yes/no token ids.

        Args:
            model_spec: Identifier passed to Mistral3ForConditionalGeneration.from_pretrained
            model_spec_processor: Identifier passed to AutoProcessor.from_pretrained
            torch_dtype: dtype the model weights are loaded with
        """
        super().__init__()

        self.model: Mistral3ForConditionalGeneration = Mistral3ForConditionalGeneration.from_pretrained(
            model_spec,
            dtype=torch_dtype,
        )
        self.processor = AutoProcessor.from_pretrained(model_spec_processor, use_fast=False)
        # NOTE(review): relies on encode() returning exactly two ids for the
        # two-element list - confirm against the tokenizer in use.
        self.yes_token, self.no_token = self.processor.tokenizer.encode(["yes", "no"], add_special_tokens=False)

        self.max_length = MAX_LENGTH
        self.upsampling_max_image_size = UPSAMPLING_MAX_IMAGE_SIZE

    def _validate_and_process_images(self, img: list[list[Image.Image]] | list[Image.Image]) -> list[list[Image.Image]]:
        """
        Normalize images to one list of images per batch element.

        A flat list of images is wrapped so every element becomes a
        one-image list; groups with more than one image are merged with
        concatenate_images; every image is then passed through cap_pixels.

        Args:
            img: List of images or list of lists of images

        Returns:
            List of lists of images (empty list for empty input)
        """
        if not img:
            return []

        # Only the first element is inspected to tell the two layouts apart.
        if isinstance(img[0], Image.Image):
            img = [[im] for im in img]  # type: ignore

        # potentially concatenate multiple images to reduce the size
        img = [[concatenate_images(group)] if len(group) > 1 else group for group in img]  # type: ignore

        # cap the pixels
        img = [[cap_pixels(im, self.upsampling_max_image_size) for im in group] for group in img]  # type: ignore
        return img  # type: ignore

    def format_input(
        self,
        txt: list[str],
        system_message: str = SYSTEM_MESSAGE,
        img: list[Image.Image] | list[list[Image.Image]] | None = None,
    ) -> list[list[dict]]:
        """
        Format a batch of text prompts into the conversation format expected by apply_chat_template.
        Optionally, add images to the input.

        Args:
            txt: List of text prompts
            system_message: System message to use (default: SYSTEM_MESSAGE)
            img: Optional images, one entry per prompt. Each entry is a single
                image or a list of images.

        Returns:
            List of conversations, where each conversation is a list of message dicts

        Raises:
            AssertionError: If images are given and their count differs from the number of prompts.
        """
        # Remove [IMG] tokens from prompts to avoid Pixtral validation issues
        # when truncation is enabled. The processor counts [IMG] tokens and fails
        # if the count changes after truncation.
        cleaned_txt = [prompt.replace("[IMG]", "") for prompt in txt]

        def system_turn() -> dict:
            return {"role": "system", "content": [{"type": "text", "text": system_message}]}

        def text_turn(prompt: str) -> dict:
            return {"role": "user", "content": [{"type": "text", "text": prompt}]}

        if img is None or len(img) == 0:
            return [[system_turn(), text_turn(prompt)] for prompt in cleaned_txt]

        # Raised explicitly (not via `assert`) so the check survives `python -O`
        # while callers still see the same exception type and message.
        if len(img) != len(txt):
            raise AssertionError("Number of images must match number of prompts")
        processed = self._validate_and_process_images(img)

        messages = []
        for prompt, images in zip(cleaned_txt, processed):
            conversation = [system_turn()]
            # optionally add the images per batch element.
            if images is not None:
                conversation.append(
                    {
                        "role": "user",
                        "content": [{"type": "image", "image": image_obj} for image_obj in images],
                    }
                )
            # add the text as its own user turn.
            conversation.append(text_turn(prompt))
            messages.append(conversation)

        return messages

    @torch.no_grad()
    def upsample_prompt(
        self,
        txt: list[str],
        img: list[Image.Image] | list[list[Image.Image]] | None = None,
        temperature: float = 0.15,
    ) -> list[str]:
        """
        Upsample prompts using the model's generate method.

        Args:
            txt: List of input prompts to upsample
            img: Optional list of images or list of lists of images. If None, empty or all None, uses t2i mode, otherwise i2i mode.
            temperature: Sampling temperature passed to generate

        Returns:
            List of upsampled prompts, or the original txt if generation raises

        Raises:
            ValueError: Re-raised when apply_chat_template rejects the input.
        """
        # Set system message based on whether images are provided
        has_images = img is not None and any(entry is not None for entry in img)
        system_message = SYSTEM_MESSAGE_UPSAMPLING_I2I if has_images else SYSTEM_MESSAGE_UPSAMPLING_T2I

        # A list holding only None means "no images". Forward None so that
        # format_input takes its text-only path instead of failing on None entries.
        messages_batch = self.format_input(
            txt=txt,
            system_message=system_message,
            img=img if has_images else None,
        )

        # Process all messages at once
        # with image processing a too short max length can throw an error in here.
        try:
            inputs = self.processor.apply_chat_template(
                messages_batch,
                add_generation_prompt=True,
                tokenize=True,
                return_dict=True,
                return_tensors="pt",
                padding="max_length",
                truncation=True,
                max_length=2048,
            )
        except ValueError as e:
            nfo(f"Error processing input: {e}, your max length is probably too short, when you have images in the input.")
            raise

        # Move to device
        device = self.model.device
        inputs["input_ids"] = inputs["input_ids"].to(device)
        inputs["attention_mask"] = inputs["attention_mask"].to(device)

        if "pixel_values" in inputs:
            inputs["pixel_values"] = inputs["pixel_values"].to(device, self.model.dtype)

        # Generation is best-effort: on any failure the caller gets the
        # original prompts back rather than an exception.
        try:
            generated_ids = self.model.generate(
                **inputs,
                max_new_tokens=512,
                do_sample=True,
                temperature=temperature,
                use_cache=True,
            )

            # Decode only the newly generated tokens (skip input tokens)
            input_length = inputs["input_ids"].shape[1]
            generated_tokens = generated_ids[:, input_length:]

            return self.processor.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True, clean_up_tokenization_spaces=True)
        except Exception as e:
            nfo(f"Error generating upsampled prompt: {e}, returning original prompt")
            return txt

    @torch.no_grad()
    def forward(self, txt: list[str]):
        """
        Embed a batch of prompts from selected hidden layers of the model.

        Args:
            txt: List of text prompts

        Returns:
            Tensor in which the hidden states of the layers listed in
            OUTPUT_LAYERS are concatenated along the last dimension
            ("b c l d -> b l (c d)").
        """
        # Format input messages
        messages_batch = self.format_input(txt=txt)

        # Process all messages at once; every prompt is padded/truncated to self.max_length.
        inputs = self.processor.apply_chat_template(
            messages_batch,
            add_generation_prompt=False,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
        )

        # Move to device
        input_ids = inputs["input_ids"].to(self.model.device)
        attention_mask = inputs["attention_mask"].to(self.model.device)

        # Forward pass through the model
        output = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
            use_cache=False,
        )

        out = torch.stack([output.hidden_states[k] for k in OUTPUT_LAYERS], dim=1)
        return rearrange(out, "b c l d -> b l (c d)")

    def yes_no_logit_processor(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        """
        Sets all tokens but yes/no to one below the current minimum score.

        The scores tensor is modified in place and returned; input_ids is unused.
        """
        scores_yes_token = scores[:, self.yes_token].clone()
        scores_no_token = scores[:, self.no_token].clone()
        scores_min = scores.min()
        scores[:, :] = scores_min - 1
        scores[:, self.yes_token] = scores_yes_token
        scores[:, self.no_token] = scores_no_token
        return scores

    def test_image(self, image: Image.Image | str | Path | torch.Tensor) -> bool:
        """
        Ask the model a yes/no question about an image.

        Args:
            image: PIL image, path to an image file, or a tensor. Only the
                first element of a tensor is used; it is clamped to [-1, 1]
                and rearranged "c h w -> h w c".

        Returns:
            True if the single generated token is the "yes" token. What "yes"
            means is defined by the prompt constants - confirm there.
        """
        if isinstance(image, torch.Tensor):
            image = rearrange(image[0].clamp(-1.0, 1.0), "c h w -> h w c")
            image = Image.fromarray((127.5 * (image + 1.0)).cpu().byte().numpy())
        elif isinstance(image, (str, Path)):
            image = Image.open(image)

        # 512^2 pixels are enough for checking
        w, h = image.size
        f = (512**2 / (w * h)) ** 0.5
        # Clamp to 1 px: truncation can yield 0 for extreme aspect ratios,
        # which resize() does not accept.
        image = image.resize((max(1, int(f * w)), max(1, int(f * h))))

        chat = [
            {
                "role": "system",
                "content": [
                    {
                        "type": "text",
                        "text": SYSTEM_PROMPT_CONTENT_FILTER,
                    },
                ],
            },
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": PROMPT_IMAGE_INTEGRITY,
                    },
                    {
                        "type": "image",
                        "image": image,
                    },
                    {
                        "type": "text",
                        "text": PROMPT_IMAGE_INTEGRITY_FOLLOW_UP,
                    },
                ],
            },
        ]

        inputs = self.processor.apply_chat_template(
            chat,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
        ).to(self.model.device)
        inputs["pixel_values"] = inputs["pixel_values"].to(dtype=self.model.dtype)

        # Greedy single-token answer restricted to yes/no by the logits processor.
        generate_ids = self.model.generate(
            **inputs,
            max_new_tokens=1,
            logits_processor=[self.yes_no_logit_processor],  # type: ignore
            do_sample=False,
        )

        return generate_ids[0, -1].item() == self.yes_token

    def test_txt(self, txt: str) -> bool:
        """
        Ask the model a yes/no question about a text prompt.

        Args:
            txt: Prompt inserted into PROMPT_TEXT_INTEGRITY

        Returns:
            True if the single generated token is the "yes" token. What "yes"
            means is defined by the prompt constants - confirm there.
        """
        chat = [
            {
                "role": "system",
                "content": [
                    {
                        "type": "text",
                        "text": SYSTEM_PROMPT_CONTENT_FILTER,
                    },
                ],
            },
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": PROMPT_TEXT_INTEGRITY.format(prompt=txt),
                    },
                ],
            },
        ]

        inputs = self.processor.apply_chat_template(
            chat,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
        ).to(self.model.device)

        # Greedy single-token answer restricted to yes/no by the logits processor.
        generate_ids = self.model.generate(
            **inputs,
            max_new_tokens=1,
            logits_processor=[self.yes_no_logit_processor],  # type: ignore
            do_sample=False,
        )
        return generate_ids[0, -1].item() == self.yes_token
class Mistral3SmallEmbedder(nn.Module):
    """
    Wrapper around a Mistral3 conditional-generation model and its processor.

    Provides prompt embedding (forward), prompt upsampling via generation
    (upsample_prompt) and single-token yes/no checks on images and text
    (test_image, test_txt).
    """

    def __init__(
        self,
        model_spec: str = "mistralai/Mistral-Small-3.2-24B-Instruct-2506",
        model_spec_processor: str = "mistralai/Mistral-Small-3.1-24B-Instruct-2503",
        torch_dtype: torch.dtype = precision,
    ):
        """
        Load the model and processor and cache the yes/no token ids.

        Args:
            model_spec: Identifier passed to Mistral3ForConditionalGeneration.from_pretrained
            model_spec_processor: Identifier passed to AutoProcessor.from_pretrained
            torch_dtype: dtype the model weights are loaded with
        """
        super().__init__()

        self.model: Mistral3ForConditionalGeneration = Mistral3ForConditionalGeneration.from_pretrained(
            model_spec,
            dtype=torch_dtype,
        )
        self.processor = AutoProcessor.from_pretrained(model_spec_processor, use_fast=False)
        # NOTE(review): relies on encode() returning exactly two ids for the
        # two-element list - confirm against the tokenizer in use.
        self.yes_token, self.no_token = self.processor.tokenizer.encode(["yes", "no"], add_special_tokens=False)

        self.max_length = MAX_LENGTH
        self.upsampling_max_image_size = UPSAMPLING_MAX_IMAGE_SIZE

    def _validate_and_process_images(self, img: list[list[Image.Image]] | list[Image.Image]) -> list[list[Image.Image]]:
        """
        Normalize images to one list of images per batch element.

        A flat list of images is wrapped so every element becomes a
        one-image list; groups with more than one image are merged with
        concatenate_images; every image is then passed through cap_pixels.

        Args:
            img: List of images or list of lists of images

        Returns:
            List of lists of images (empty list for empty input)
        """
        if not img:
            return []

        # Only the first element is inspected to tell the two layouts apart.
        if isinstance(img[0], Image.Image):
            img = [[im] for im in img]  # type: ignore

        # potentially concatenate multiple images to reduce the size
        img = [[concatenate_images(group)] if len(group) > 1 else group for group in img]  # type: ignore

        # cap the pixels
        img = [[cap_pixels(im, self.upsampling_max_image_size) for im in group] for group in img]  # type: ignore
        return img  # type: ignore

    def format_input(
        self,
        txt: list[str],
        system_message: str = SYSTEM_MESSAGE,
        img: list[Image.Image] | list[list[Image.Image]] | None = None,
    ) -> list[list[dict]]:
        """
        Format a batch of text prompts into the conversation format expected by apply_chat_template.
        Optionally, add images to the input.

        Args:
            txt: List of text prompts
            system_message: System message to use (default: SYSTEM_MESSAGE)
            img: Optional images, one entry per prompt. Each entry is a single
                image or a list of images.

        Returns:
            List of conversations, where each conversation is a list of message dicts

        Raises:
            AssertionError: If images are given and their count differs from the number of prompts.
        """
        # Remove [IMG] tokens from prompts to avoid Pixtral validation issues
        # when truncation is enabled. The processor counts [IMG] tokens and fails
        # if the count changes after truncation.
        cleaned_txt = [prompt.replace("[IMG]", "") for prompt in txt]

        def system_turn() -> dict:
            return {"role": "system", "content": [{"type": "text", "text": system_message}]}

        def text_turn(prompt: str) -> dict:
            return {"role": "user", "content": [{"type": "text", "text": prompt}]}

        if img is None or len(img) == 0:
            return [[system_turn(), text_turn(prompt)] for prompt in cleaned_txt]

        # Raised explicitly (not via `assert`) so the check survives `python -O`
        # while callers still see the same exception type and message.
        if len(img) != len(txt):
            raise AssertionError("Number of images must match number of prompts")
        processed = self._validate_and_process_images(img)

        messages = []
        for prompt, images in zip(cleaned_txt, processed):
            conversation = [system_turn()]
            # optionally add the images per batch element.
            if images is not None:
                conversation.append(
                    {
                        "role": "user",
                        "content": [{"type": "image", "image": image_obj} for image_obj in images],
                    }
                )
            # add the text as its own user turn.
            conversation.append(text_turn(prompt))
            messages.append(conversation)

        return messages

    @torch.no_grad()
    def upsample_prompt(
        self,
        txt: list[str],
        img: list[Image.Image] | list[list[Image.Image]] | None = None,
        temperature: float = 0.15,
    ) -> list[str]:
        """
        Upsample prompts using the model's generate method.

        Args:
            txt: List of input prompts to upsample
            img: Optional list of images or list of lists of images. If None, empty or all None, uses t2i mode, otherwise i2i mode.
            temperature: Sampling temperature passed to generate

        Returns:
            List of upsampled prompts, or the original txt if generation raises

        Raises:
            ValueError: Re-raised when apply_chat_template rejects the input.
        """
        # Set system message based on whether images are provided
        has_images = img is not None and any(entry is not None for entry in img)
        system_message = SYSTEM_MESSAGE_UPSAMPLING_I2I if has_images else SYSTEM_MESSAGE_UPSAMPLING_T2I

        # A list holding only None means "no images". Forward None so that
        # format_input takes its text-only path instead of failing on None entries.
        messages_batch = self.format_input(
            txt=txt,
            system_message=system_message,
            img=img if has_images else None,
        )

        # Process all messages at once
        # with image processing a too short max length can throw an error in here.
        try:
            inputs = self.processor.apply_chat_template(
                messages_batch,
                add_generation_prompt=True,
                tokenize=True,
                return_dict=True,
                return_tensors="pt",
                padding="max_length",
                truncation=True,
                max_length=2048,
            )
        except ValueError as e:
            nfo(f"Error processing input: {e}, your max length is probably too short, when you have images in the input.")
            raise

        # Move to device
        device = self.model.device
        inputs["input_ids"] = inputs["input_ids"].to(device)
        inputs["attention_mask"] = inputs["attention_mask"].to(device)

        if "pixel_values" in inputs:
            inputs["pixel_values"] = inputs["pixel_values"].to(device, self.model.dtype)

        # Generation is best-effort: on any failure the caller gets the
        # original prompts back rather than an exception.
        try:
            generated_ids = self.model.generate(
                **inputs,
                max_new_tokens=512,
                do_sample=True,
                temperature=temperature,
                use_cache=True,
            )

            # Decode only the newly generated tokens (skip input tokens)
            input_length = inputs["input_ids"].shape[1]
            generated_tokens = generated_ids[:, input_length:]

            return self.processor.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True, clean_up_tokenization_spaces=True)
        except Exception as e:
            nfo(f"Error generating upsampled prompt: {e}, returning original prompt")
            return txt

    @torch.no_grad()
    def forward(self, txt: list[str]):
        """
        Embed a batch of prompts from selected hidden layers of the model.

        Args:
            txt: List of text prompts

        Returns:
            Tensor in which the hidden states of the layers listed in
            OUTPUT_LAYERS are concatenated along the last dimension
            ("b c l d -> b l (c d)").
        """
        # Format input messages
        messages_batch = self.format_input(txt=txt)

        # Process all messages at once; every prompt is padded/truncated to self.max_length.
        inputs = self.processor.apply_chat_template(
            messages_batch,
            add_generation_prompt=False,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
        )

        # Move to device
        input_ids = inputs["input_ids"].to(self.model.device)
        attention_mask = inputs["attention_mask"].to(self.model.device)

        # Forward pass through the model
        output = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
            use_cache=False,
        )

        out = torch.stack([output.hidden_states[k] for k in OUTPUT_LAYERS], dim=1)
        return rearrange(out, "b c l d -> b l (c d)")

    def yes_no_logit_processor(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        """
        Sets all tokens but yes/no to one below the current minimum score.

        The scores tensor is modified in place and returned; input_ids is unused.
        """
        scores_yes_token = scores[:, self.yes_token].clone()
        scores_no_token = scores[:, self.no_token].clone()
        scores_min = scores.min()
        scores[:, :] = scores_min - 1
        scores[:, self.yes_token] = scores_yes_token
        scores[:, self.no_token] = scores_no_token
        return scores

    def test_image(self, image: Image.Image | str | Path | torch.Tensor) -> bool:
        """
        Ask the model a yes/no question about an image.

        Args:
            image: PIL image, path to an image file, or a tensor. Only the
                first element of a tensor is used; it is clamped to [-1, 1]
                and rearranged "c h w -> h w c".

        Returns:
            True if the single generated token is the "yes" token. What "yes"
            means is defined by the prompt constants - confirm there.
        """
        if isinstance(image, torch.Tensor):
            image = rearrange(image[0].clamp(-1.0, 1.0), "c h w -> h w c")
            image = Image.fromarray((127.5 * (image + 1.0)).cpu().byte().numpy())
        elif isinstance(image, (str, Path)):
            image = Image.open(image)

        # 512^2 pixels are enough for checking
        w, h = image.size
        f = (512**2 / (w * h)) ** 0.5
        # Clamp to 1 px: truncation can yield 0 for extreme aspect ratios,
        # which resize() does not accept.
        image = image.resize((max(1, int(f * w)), max(1, int(f * h))))

        chat = [
            {
                "role": "system",
                "content": [
                    {
                        "type": "text",
                        "text": SYSTEM_PROMPT_CONTENT_FILTER,
                    },
                ],
            },
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": PROMPT_IMAGE_INTEGRITY,
                    },
                    {
                        "type": "image",
                        "image": image,
                    },
                    {
                        "type": "text",
                        "text": PROMPT_IMAGE_INTEGRITY_FOLLOW_UP,
                    },
                ],
            },
        ]

        inputs = self.processor.apply_chat_template(
            chat,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
        ).to(self.model.device)
        inputs["pixel_values"] = inputs["pixel_values"].to(dtype=self.model.dtype)

        # Greedy single-token answer restricted to yes/no by the logits processor.
        generate_ids = self.model.generate(
            **inputs,
            max_new_tokens=1,
            logits_processor=[self.yes_no_logit_processor],  # type: ignore
            do_sample=False,
        )

        return generate_ids[0, -1].item() == self.yes_token

    def test_txt(self, txt: str) -> bool:
        """
        Ask the model a yes/no question about a text prompt.

        Args:
            txt: Prompt inserted into PROMPT_TEXT_INTEGRITY

        Returns:
            True if the single generated token is the "yes" token. What "yes"
            means is defined by the prompt constants - confirm there.
        """
        chat = [
            {
                "role": "system",
                "content": [
                    {
                        "type": "text",
                        "text": SYSTEM_PROMPT_CONTENT_FILTER,
                    },
                ],
            },
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": PROMPT_TEXT_INTEGRITY.format(prompt=txt),
                    },
                ],
            },
        ]

        inputs = self.processor.apply_chat_template(
            chat,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
        ).to(self.model.device)

        # Greedy single-token answer restricted to yes/no by the logits processor.
        generate_ids = self.model.generate(
            **inputs,
            max_new_tokens=1,
            logits_processor=[self.yes_no_logit_processor],  # type: ignore
            do_sample=False,
        )
        return generate_ids[0, -1].item() == self.yes_token
Base class for all neural network modules.
Your models should also subclass this class.
Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::
    import torch.nn as nn
    import torch.nn.functional as F

    class Model(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.conv1 = nn.Conv2d(1, 20, 5)
            self.conv2 = nn.Conv2d(20, 20, 5)

        def forward(self, x):
            x = F.relu(self.conv1(x))
            return F.relu(self.conv2(x))
Submodules assigned in this way will be registered, and will also have their
parameters converted when you call to(), etc.
As per the example above, an __init__() call to the parent class
must be made before assignment on the child.
:ivar training: Boolean represents whether this module is in training or evaluation mode. :vartype training: bool
def __init__(
    self,
    model_spec: str = "mistralai/Mistral-Small-3.2-24B-Instruct-2506",
    model_spec_processor: str = "mistralai/Mistral-Small-3.1-24B-Instruct-2503",
    torch_dtype: torch.dtype = precision,
):
    """
    Load the model and processor and cache the yes/no token ids.

    Args:
        model_spec: Identifier passed to Mistral3ForConditionalGeneration.from_pretrained
        model_spec_processor: Identifier passed to AutoProcessor.from_pretrained
        torch_dtype: dtype the model weights are loaded with
    """
    super().__init__()

    # Fixed limits taken from the module-level constants.
    self.max_length = MAX_LENGTH
    self.upsampling_max_image_size = UPSAMPLING_MAX_IMAGE_SIZE

    self.model: Mistral3ForConditionalGeneration = Mistral3ForConditionalGeneration.from_pretrained(model_spec, dtype=torch_dtype)
    self.processor = AutoProcessor.from_pretrained(model_spec_processor, use_fast=False)

    # NOTE(review): relies on encode() returning exactly two ids for the
    # two-element list - confirm against the tokenizer in use.
    answer_ids = self.processor.tokenizer.encode(["yes", "no"], add_special_tokens=False)
    self.yes_token, self.no_token = answer_ids
Initialize internal Module state, shared by both nn.Module and ScriptModule.
def format_input(
    self,
    txt: list[str],
    system_message: str = SYSTEM_MESSAGE,
    img: list[Image.Image] | list[list[Image.Image]] | None = None,
) -> list[list[dict]]:
    """
    Format a batch of text prompts into the conversation format expected by apply_chat_template.
    Optionally, add images to the input.

    Args:
        txt: List of text prompts
        system_message: System message to use (default: SYSTEM_MESSAGE)
        img: Optional images, one entry per prompt. Each entry is a single
            image or a list of images.

    Returns:
        List of conversations, where each conversation is a list of message dicts

    Raises:
        AssertionError: If images are given and their count differs from the number of prompts.
    """
    # Remove [IMG] tokens from prompts to avoid Pixtral validation issues
    # when truncation is enabled. The processor counts [IMG] tokens and fails
    # if the count changes after truncation.
    cleaned_txt = [prompt.replace("[IMG]", "") for prompt in txt]

    def system_turn() -> dict:
        return {"role": "system", "content": [{"type": "text", "text": system_message}]}

    def text_turn(prompt: str) -> dict:
        return {"role": "user", "content": [{"type": "text", "text": prompt}]}

    if img is None or len(img) == 0:
        return [[system_turn(), text_turn(prompt)] for prompt in cleaned_txt]

    # Raised explicitly (not via `assert`) so the check survives `python -O`
    # while callers still see the same exception type and message.
    if len(img) != len(txt):
        raise AssertionError("Number of images must match number of prompts")
    processed = self._validate_and_process_images(img)

    messages = []
    for prompt, images in zip(cleaned_txt, processed):
        conversation = [system_turn()]
        # optionally add the images per batch element.
        if images is not None:
            conversation.append(
                {
                    "role": "user",
                    "content": [{"type": "image", "image": image_obj} for image_obj in images],
                }
            )
        # add the text as its own user turn.
        conversation.append(text_turn(prompt))
        messages.append(conversation)

    return messages
Format a batch of text prompts into the conversation format expected by apply_chat_template. Optionally, add images to the input.
Args: txt: List of text prompts system_message: System message to use (default: SYSTEM_MESSAGE) img: List of images to add to the input.
Returns: List of conversations, where each conversation is a list of message dicts
@torch.no_grad()
def upsample_prompt(
    self,
    txt: list[str],
    img: list[Image.Image] | list[list[Image.Image]] | None = None,
    temperature: float = 0.15,
) -> list[str]:
    """
    Upsample prompts using the model's generate method.

    Args:
        txt: List of input prompts to upsample
        img: Optional list of images or list of lists of images. If None, empty or all None, uses t2i mode, otherwise i2i mode.
        temperature: Sampling temperature passed to generate

    Returns:
        List of upsampled prompts, or the original txt if generation raises

    Raises:
        ValueError: Re-raised when apply_chat_template rejects the input.
    """
    # Set system message based on whether images are provided
    has_images = img is not None and any(entry is not None for entry in img)
    system_message = SYSTEM_MESSAGE_UPSAMPLING_I2I if has_images else SYSTEM_MESSAGE_UPSAMPLING_T2I

    # A list holding only None means "no images". Forward None so that
    # format_input takes its text-only path instead of failing on None entries.
    messages_batch = self.format_input(
        txt=txt,
        system_message=system_message,
        img=img if has_images else None,
    )

    # Process all messages at once
    # with image processing a too short max length can throw an error in here.
    try:
        inputs = self.processor.apply_chat_template(
            messages_batch,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=2048,
        )
    except ValueError as e:
        nfo(f"Error processing input: {e}, your max length is probably too short, when you have images in the input.")
        raise

    # Move to device
    device = self.model.device
    inputs["input_ids"] = inputs["input_ids"].to(device)
    inputs["attention_mask"] = inputs["attention_mask"].to(device)

    if "pixel_values" in inputs:
        inputs["pixel_values"] = inputs["pixel_values"].to(device, self.model.dtype)

    # Generation is best-effort: on any failure the caller gets the
    # original prompts back rather than an exception.
    try:
        generated_ids = self.model.generate(
            **inputs,
            max_new_tokens=512,
            do_sample=True,
            temperature=temperature,
            use_cache=True,
        )

        # Decode only the newly generated tokens (skip input tokens)
        input_length = inputs["input_ids"].shape[1]
        generated_tokens = generated_ids[:, input_length:]

        return self.processor.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True, clean_up_tokenization_spaces=True)
    except Exception as e:
        nfo(f"Error generating upsampled prompt: {e}, returning original prompt")
        return txt
Upsample prompts using the model's generate method.
Args: txt: List of input prompts to upsample img: Optional list of images or list of lists of images. If None or all None, uses t2i mode, otherwise i2i mode. temperature: Sampling temperature passed to the model's generate method (default: 0.15).
Returns: List of upsampled prompts
@torch.no_grad()
def forward(self, txt: list[str]):
    """
    Embed a batch of prompts from selected hidden layers of the model.

    Args:
        txt: List of text prompts

    Returns:
        Tensor in which the hidden states of the layers listed in
        OUTPUT_LAYERS are concatenated along the last dimension
        ("b c l d -> b l (c d)").
    """
    conversations = self.format_input(txt=txt)

    # Tokenize the whole batch; every prompt is padded/truncated to self.max_length.
    encoded = self.processor.apply_chat_template(
        conversations,
        add_generation_prompt=False,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=self.max_length,
    )

    # Inputs are moved to the model's device one at a time, ids first.
    token_ids = encoded["input_ids"].to(self.model.device)
    mask = encoded["attention_mask"].to(self.model.device)

    # Plain forward pass (no KV cache) that also returns every layer's hidden state.
    result = self.model(
        input_ids=token_ids,
        attention_mask=mask,
        output_hidden_states=True,
        use_cache=False,
    )

    # Stack the chosen layers on a new axis, then fold that axis into the feature axis.
    selected = [result.hidden_states[layer] for layer in OUTPUT_LAYERS]
    stacked = torch.stack(selected, dim=1)
    return rearrange(stacked, "b c l d -> b l (c d)")
Define the computation performed at every call.
Should be overridden by all subclasses.
Although the recipe for forward pass needs to be defined within
this function, one should call the Module instance afterwards
instead of this since the former takes care of running the
registered hooks while the latter silently ignores them.
def yes_no_logit_processor(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
    """
    Push every token score except yes/no to one below the current minimum.

    The scores tensor is modified in place and returned; input_ids is unused.
    """
    answer_ids = [self.yes_token, self.no_token]
    # Save the two answer columns before the whole tensor is overwritten.
    kept = scores[:, answer_ids].clone()
    floor = scores.min() - 1
    scores.fill_(floor)
    scores[:, answer_ids] = kept
    return scores
Sets the scores of all tokens except yes/no to one below the current minimum score.
def test_image(self, image: Image.Image | str | Path | torch.Tensor) -> bool:
    """
    Ask the model a yes/no question about an image.

    Args:
        image: PIL image, path to an image file, or a tensor. Only the
            first element of a tensor is used; it is clamped to [-1, 1]
            and rearranged "c h w -> h w c".

    Returns:
        True if the single generated token is the "yes" token. What "yes"
        means is defined by the prompt constants - confirm there.
    """
    if isinstance(image, torch.Tensor):
        image = rearrange(image[0].clamp(-1.0, 1.0), "c h w -> h w c")
        image = Image.fromarray((127.5 * (image + 1.0)).cpu().byte().numpy())
    elif isinstance(image, (str, Path)):
        image = Image.open(image)

    # 512^2 pixels are enough for checking
    w, h = image.size
    f = (512**2 / (w * h)) ** 0.5
    # Clamp to 1 px: truncation can yield 0 for extreme aspect ratios,
    # which resize() does not accept.
    image = image.resize((max(1, int(f * w)), max(1, int(f * h))))

    chat = [
        {
            "role": "system",
            "content": [
                {
                    "type": "text",
                    "text": SYSTEM_PROMPT_CONTENT_FILTER,
                },
            ],
        },
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": PROMPT_IMAGE_INTEGRITY,
                },
                {
                    "type": "image",
                    "image": image,
                },
                {
                    "type": "text",
                    "text": PROMPT_IMAGE_INTEGRITY_FOLLOW_UP,
                },
            ],
        },
    ]

    inputs = self.processor.apply_chat_template(
        chat,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(self.model.device)
    inputs["pixel_values"] = inputs["pixel_values"].to(dtype=self.model.dtype)

    # Greedy single-token answer restricted to yes/no by the logits processor.
    generate_ids = self.model.generate(
        **inputs,
        max_new_tokens=1,
        logits_processor=[self.yes_no_logit_processor],  # type: ignore
        do_sample=False,
    )

    return generate_ids[0, -1].item() == self.yes_token
def test_txt(self, txt: str) -> bool:
    """
    Ask the model a yes/no question about a text prompt.

    Args:
        txt: Prompt inserted into PROMPT_TEXT_INTEGRITY

    Returns:
        True if the single generated token is the "yes" token. What "yes"
        means is defined by the prompt constants - confirm there.
    """
    system_turn = {
        "role": "system",
        "content": [{"type": "text", "text": SYSTEM_PROMPT_CONTENT_FILTER}],
    }
    question = PROMPT_TEXT_INTEGRITY.format(prompt=txt)
    user_turn = {
        "role": "user",
        "content": [{"type": "text", "text": question}],
    }

    encoded = self.processor.apply_chat_template(
        [system_turn, user_turn],
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    )
    encoded = encoded.to(self.model.device)

    # Greedy single-token answer restricted to yes/no by the logits processor.
    answer_ids = self.model.generate(
        **encoded,
        max_new_tokens=1,
        logits_processor=[self.yes_no_logit_processor],  # type: ignore
        do_sample=False,
    )
    last_token = answer_ids[0, -1].item()
    return last_token == self.yes_token