divisor.flux2.prompt
# SPDX-License-Identifier: MPL-2.0 AND LicenseRef-Commons-Clause-License-Condition-1.0
# <!-- // /* d a r k s h a p e s */ -->
# adapted BFL Flux code from https://github.com/black-forest-labs/flux2

from dataclasses import replace

import torch
from fire import Fire
from nnll.console import nfo
from PIL import Image

from divisor.controller import rng
from divisor.flux1.loading import load_ae, load_flow_model, load_mistral_small_embedder
from divisor.flux1.prompt import parse_prompt
from divisor.flux2.sampling import (
    batched_prc_img,
    batched_prc_txt,
    denoise_interactive,
    encode_image_refs,
    get_schedule,
)
from divisor.noise import get_noise
from divisor.registry import gfx_device, gfx_dtype, empty_cache
from divisor.spec import ModelSpec, flux_configs, get_model_spec
from divisor.state import InferenceState, MenuState


def main(
    mir_id: str = "model.dit.flux2-dev",
    ae_id: str = "model.vae.flux2-dev",
    width: int = 1360,
    height: int = 768,
    guidance: float = 4,
    seed: int = rng.next_seed(),
    prompt: str = "",
    quantization: bool = False,
    device: torch.device = gfx_device,
    num_steps: int = 50,
    upsample_prompt: bool = False,
    loop: bool = False,
    offload: bool = False,
    compile: bool = False,
    verbose: bool = False,
    input_images: list[str] | None = None,
) -> None:
    """Sample the flux model.

    Either interactively (set `--loop`) or run for a single image.

    :param mir_id: Identifier of the flow model; resolved with `get_model_spec` against `flux_configs`
    :param ae_id: Identifier of the autoencoder; resolved with `get_model_spec` against `flux_configs`
    :param width: width of the sample in pixels (should be a multiple of 16)
    :param height: height of the sample in pixels (should be a multiple of 16)
    :param guidance: guidance value used for guidance distillation
    :param seed: Set a seed for sampling (the default is drawn once, when the module is imported)
    :param prompt: Prompt used for sampling; several prompts may be joined with `|`, in which case
        the first is rendered first and the rest are rendered in order when `loop` is off
    :param quantization: Append `:*@fp8-sai` to `mir_id` before the model spec is looked up
    :param device: Pytorch device (a plain string such as the CLI supplies is also accepted)
    :param num_steps: number of sampling steps
    :param upsample_prompt: Rewrite the prompt with the text encoder's `upsample_prompt` before encoding
    :param loop: start an interactive session and sample multiple times
    :param offload: Load the flow model and autoencoder on CPU and move the flow model to `device`
        just before denoising
    :param compile: Wrap the flow model with `torch.compile`
    :param verbose: Forwarded to `load_flow_model`
    :param input_images: Paths of reference images, opened with PIL and encoded with `encode_image_refs`
    """

    precision = gfx_dtype
    # Fire passes command-line values through as plain strings; normalise so
    # `device.type` is available in the offload branch below.
    device = torch.device(device)

    # "a|b|c" -> render "a" first, queue "b" and "c" for later iterations.
    prompt, *queued_prompts = prompt.split("|")
    additional_prompts = queued_prompts or None

    mistral = load_mistral_small_embedder()
    if quantization:
        mir_id += ":*@fp8-sai"
    model_spec: ModelSpec = get_model_spec(mir_id, flux_configs)
    ae_spec = get_model_spec(ae_id, flux_configs)
    model = load_flow_model(
        model_spec,
        device=torch.device("cpu") if offload else device,
        verbose=verbose,
    )

    is_compiled = False
    if compile and not offload:  # Compiled models can't be easily moved between devices
        nfo("Compilation enabled.")
        model = torch.compile(model)  # type: ignore[assignment]
        is_compiled = True

    ae = load_ae(ae_spec, device=torch.device("cpu") if offload else device)
    ae.eval()
    mistral.eval()

    state = MenuState.from_cli_args(
        prompt=prompt,
        width=width,
        height=height,
        num_steps=num_steps,
        guidance=guidance,
        seed=seed,
    )

    if loop:
        state = parse_prompt(state)

    while state is not None:
        if state.seed is None:
            state = replace(state, seed=rng.next_seed())
        else:
            rng.next_seed(state.seed)
        # At this point, state.seed is guaranteed to be an int
        assert state.seed is not None, "Seed must be set"
        assert state.width is not None and state.height is not None, "Width and height must be set"
        assert state.num_steps is not None, "num_steps must be set"
        assert state.prompt is not None, "Prompt must be set"
        nfo(f"Generating with seed {rng.seed}: {state.prompt}")

        # NOTE(review): this tensor is overwritten by `batched_prc_img` below and
        # never read; kept only in case `get_noise` has side effects — confirm and drop.
        x = get_noise(
            1,
            state.height,
            state.width,
            dtype=precision,
            seed=rng.seed,  # type: ignore
            device=device,
            version_2=True,
        )

        with torch.no_grad():
            if mistral is None:
                # The encoder was released by a previous offloaded iteration.
                mistral = load_mistral_small_embedder()
                mistral.eval()

            if input_images:
                img_ctx = [Image.open(input_image) for input_image in input_images]
                ref_tokens, ref_ids = encode_image_refs(ae, img_ctx)  # type: ignore
            else:
                # Bind img_ctx on this path too: the upsampling branch reads it.
                img_ctx = None
                ref_tokens = None
                ref_ids = None

            if upsample_prompt:
                # Use local model for upsampling
                upsampled_prompts = mistral.upsample_prompt([state.prompt], img=[img_ctx] if img_ctx else None)  # type: ignore
                prompt = upsampled_prompts[0] if upsampled_prompts else state.prompt
                state = replace(state, prompt=prompt)
            else:
                prompt = state.prompt

            ctx = mistral([prompt]).to(precision)  # type: ignore
            ctx, ctx_ids = batched_prc_txt(ctx)

            randn = get_noise(
                1,
                state.height,  # type: ignore
                state.width,  # type: ignore
                dtype=torch.bfloat16,
                seed=state.seed,  # type: ignore
                device=device,
                version_2=True,
            )

            x, x_ids = batched_prc_img(randn)

            timesteps = get_schedule(state.num_steps, x.shape[1])  # type: ignore
            # Update state with runtime information
            state = state.with_runtime_state(
                current_timestep=0.0,
                current_sample=x,
                timestep_index=0,
                total_timesteps=len(timesteps),
            )

            if offload:
                if device.type == "cuda":
                    mistral = mistral.cpu()  # type: ignore
                else:
                    # Drop the reference rather than `del`-ing the name, so the
                    # next iteration can detect the release and reload.
                    mistral = None
                empty_cache()
                model = model.to(device)  # type: ignore[attr-defined]
                # Compile once, after the model has reached the target device;
                # later iterations reuse the compiled module.
                if compile and not is_compiled:
                    nfo("Compiling model on device.")
                    model = torch.compile(model, mode="max-autotune")  # type: ignore[assignment]
                    is_compiled = True

            x = denoise_interactive(
                model,  # type: ignore
                InferenceState(
                    img=x,
                    img_ids=x_ids,
                    txt=ctx,
                    txt_ids=ctx_ids,
                    timesteps=timesteps,
                    img_cond_seq=ref_tokens,
                    img_cond_seq_ids=ref_ids,
                    state=state,
                    ae=ae,
                ),
            )

        if loop:
            nfo("-" * 80)
            state = parse_prompt(state)
        elif additional_prompts:
            next_prompt = additional_prompts.pop(0)
            state = replace(state, prompt=next_prompt)
        else:
            state = None


if __name__ == "__main__":
    Fire(main)
def
main( mir_id: str = 'model.dit.flux2-dev', ae_id: str = 'model.vae.flux2-dev', width: int = 1360, height: int = 768, guidance: float = 4, seed: int = 9245605556668510957, prompt: str = '', quantization: bool = False, device: torch.device = device(type='mps'), num_steps: int = 50, upsample_prompt: bool = False, loop: bool = False, offload: bool = False, compile: bool = False, verbose: bool = False, input_images: list[str] | None = None) -> None:
def main(
    mir_id: str = "model.dit.flux2-dev",
    ae_id: str = "model.vae.flux2-dev",
    width: int = 1360,
    height: int = 768,
    guidance: float = 4,
    seed: int = rng.next_seed(),
    prompt: str = "",
    quantization: bool = False,
    device: torch.device = gfx_device,
    num_steps: int = 50,
    upsample_prompt: bool = False,
    loop: bool = False,
    offload: bool = False,
    compile: bool = False,
    verbose: bool = False,
    input_images: list[str] | None = None,
) -> None:
    """Sample the flux model.

    Either interactively (set `--loop`) or run for a single image.

    :param mir_id: Identifier of the flow model; resolved with `get_model_spec` against `flux_configs`
    :param ae_id: Identifier of the autoencoder; resolved with `get_model_spec` against `flux_configs`
    :param width: width of the sample in pixels (should be a multiple of 16)
    :param height: height of the sample in pixels (should be a multiple of 16)
    :param guidance: guidance value used for guidance distillation
    :param seed: Set a seed for sampling (the default is drawn once, when the module is imported)
    :param prompt: Prompt used for sampling; several prompts may be joined with `|`, in which case
        the first is rendered first and the rest are rendered in order when `loop` is off
    :param quantization: Append `:*@fp8-sai` to `mir_id` before the model spec is looked up
    :param device: Pytorch device (a plain string such as the CLI supplies is also accepted)
    :param num_steps: number of sampling steps
    :param upsample_prompt: Rewrite the prompt with the text encoder's `upsample_prompt` before encoding
    :param loop: start an interactive session and sample multiple times
    :param offload: Load the flow model and autoencoder on CPU and move the flow model to `device`
        just before denoising
    :param compile: Wrap the flow model with `torch.compile`
    :param verbose: Forwarded to `load_flow_model`
    :param input_images: Paths of reference images, opened with PIL and encoded with `encode_image_refs`
    """

    precision = gfx_dtype
    # Fire passes command-line values through as plain strings; normalise so
    # `device.type` is available in the offload branch below.
    device = torch.device(device)

    # "a|b|c" -> render "a" first, queue "b" and "c" for later iterations.
    prompt, *queued_prompts = prompt.split("|")
    additional_prompts = queued_prompts or None

    mistral = load_mistral_small_embedder()
    if quantization:
        mir_id += ":*@fp8-sai"
    model_spec: ModelSpec = get_model_spec(mir_id, flux_configs)
    ae_spec = get_model_spec(ae_id, flux_configs)
    model = load_flow_model(
        model_spec,
        device=torch.device("cpu") if offload else device,
        verbose=verbose,
    )

    is_compiled = False
    if compile and not offload:  # Compiled models can't be easily moved between devices
        nfo("Compilation enabled.")
        model = torch.compile(model)  # type: ignore[assignment]
        is_compiled = True

    ae = load_ae(ae_spec, device=torch.device("cpu") if offload else device)
    ae.eval()
    mistral.eval()

    state = MenuState.from_cli_args(
        prompt=prompt,
        width=width,
        height=height,
        num_steps=num_steps,
        guidance=guidance,
        seed=seed,
    )

    if loop:
        state = parse_prompt(state)

    while state is not None:
        if state.seed is None:
            state = replace(state, seed=rng.next_seed())
        else:
            rng.next_seed(state.seed)
        # At this point, state.seed is guaranteed to be an int
        assert state.seed is not None, "Seed must be set"
        assert state.width is not None and state.height is not None, "Width and height must be set"
        assert state.num_steps is not None, "num_steps must be set"
        assert state.prompt is not None, "Prompt must be set"
        nfo(f"Generating with seed {rng.seed}: {state.prompt}")

        # NOTE(review): this tensor is overwritten by `batched_prc_img` below and
        # never read; kept only in case `get_noise` has side effects — confirm and drop.
        x = get_noise(
            1,
            state.height,
            state.width,
            dtype=precision,
            seed=rng.seed,  # type: ignore
            device=device,
            version_2=True,
        )

        with torch.no_grad():
            if mistral is None:
                # The encoder was released by a previous offloaded iteration.
                mistral = load_mistral_small_embedder()
                mistral.eval()

            if input_images:
                img_ctx = [Image.open(input_image) for input_image in input_images]
                ref_tokens, ref_ids = encode_image_refs(ae, img_ctx)  # type: ignore
            else:
                # Bind img_ctx on this path too: the upsampling branch reads it.
                img_ctx = None
                ref_tokens = None
                ref_ids = None

            if upsample_prompt:
                # Use local model for upsampling
                upsampled_prompts = mistral.upsample_prompt([state.prompt], img=[img_ctx] if img_ctx else None)  # type: ignore
                prompt = upsampled_prompts[0] if upsampled_prompts else state.prompt
                state = replace(state, prompt=prompt)
            else:
                prompt = state.prompt

            ctx = mistral([prompt]).to(precision)  # type: ignore
            ctx, ctx_ids = batched_prc_txt(ctx)

            randn = get_noise(
                1,
                state.height,  # type: ignore
                state.width,  # type: ignore
                dtype=torch.bfloat16,
                seed=state.seed,  # type: ignore
                device=device,
                version_2=True,
            )

            x, x_ids = batched_prc_img(randn)

            timesteps = get_schedule(state.num_steps, x.shape[1])  # type: ignore
            # Update state with runtime information
            state = state.with_runtime_state(
                current_timestep=0.0,
                current_sample=x,
                timestep_index=0,
                total_timesteps=len(timesteps),
            )

            if offload:
                if device.type == "cuda":
                    mistral = mistral.cpu()  # type: ignore
                else:
                    # Drop the reference rather than `del`-ing the name, so the
                    # next iteration can detect the release and reload.
                    mistral = None
                empty_cache()
                model = model.to(device)  # type: ignore[attr-defined]
                # Compile once, after the model has reached the target device;
                # later iterations reuse the compiled module.
                if compile and not is_compiled:
                    nfo("Compiling model on device.")
                    model = torch.compile(model, mode="max-autotune")  # type: ignore[assignment]
                    is_compiled = True

            x = denoise_interactive(
                model,  # type: ignore
                InferenceState(
                    img=x,
                    img_ids=x_ids,
                    txt=ctx,
                    txt_ids=ctx_ids,
                    timesteps=timesteps,
                    img_cond_seq=ref_tokens,
                    img_cond_seq_ids=ref_ids,
                    state=state,
                    ae=ae,
                ),
            )

        if loop:
            nfo("-" * 80)
            state = parse_prompt(state)
        elif additional_prompts:
            next_prompt = additional_prompts.pop(0)
            state = replace(state, prompt=next_prompt)
        else:
            state = None
Sample the flux model. Either interactively (set --loop) or run for a single image.
Parameters
- mir_id: Identifier of the model to load (default `model.dit.flux2-dev`)
- height: height of the sample in pixels (should be a multiple of 16)
- width: width of the sample in pixels (should be a multiple of 16)
- seed: Set a seed for sampling
- output_name: where to save the output image, `{idx}` will be replaced by the index of the sample
- prompt: Prompt used for sampling
- device: Pytorch device
- num_steps: number of sampling steps (default 50)
- loop: start an interactive session and sample multiple times
- guidance: guidance value used for guidance distillation