divisor.flux1.prompt
# SPDX-License-Identifier: MPL-2.0 AND LicenseRef-Commons-Clause-License-Condition-1.0
# <!-- // /*  d a r k s h a p e s */ -->
# adapted BFL Flux code from https://github.com/black-forest-labs/flux

from dataclasses import replace

from fire import Fire
from nnll.console import nfo
from divisor.registry import gfx_device, empty_cache
import torch

from divisor.controller import rng
from divisor.flux1.loading import load_ae, load_clip, load_flow_model, load_t5
from divisor.flux1.sampling import denoise, get_schedule, prepare
from divisor.noise import prepare_4d_noise_for_3d_model
from divisor.spec import InitialParamsFlux, ModelSpec, flux_configs, get_model_spec
from divisor.state import MenuState


def parse_prompt(state: MenuState) -> MenuState | None:
    """Read interactive input until a prompt is entered or the user quits.

    Lines starting with a slash are commands: they update `state` and ask
    again. Any other line ends the dialogue; an empty line keeps the prompt
    already stored in `state`.

    :param state: Current state to update
    :returns: Updated state, or None when the user quits with `/q`"""
    user_question = "Next prompt (write /h for help, /q to quit and leave empty to repeat):\n"
    usage = (
        "Usage: Either write your prompt directly, leave this field empty "
        "to repeat the prompt or write a command starting with a slash:\n"
        "- '/w <width>' will set the width of the generated image\n"
        "- '/h <height>' will set the height of the generated image\n"
        "- '/s <seed>' sets the next seed\n"
        "- '/g <guidance>' sets the guidance (flux-dev only)\n"
        "- '/n <steps>' sets the number of steps\n"
        "- '/q' to quit"
    )

    def reject(command_line: str) -> None:
        """Report a malformed command together with the usage text."""
        nfo(f"Got invalid command '{command_line}'\n{usage}")

    while (prompt := input(user_question)).startswith("/"):
        # Commands are recognised by their first two characters, so longer
        # spellings such as "/quit" or "/width 512" keep working.
        command = prompt[:2]
        parts = prompt.split()

        if command == "/q":
            nfo("Quitting")
            return None
        if command == "/h" and len(parts) == 1:
            # Bare "/h" is the help request advertised in the question;
            # "/h <height>" is handled further down.
            nfo(usage)
            continue
        if command not in ("/w", "/h", "/g", "/s", "/n") or len(parts) != 2:
            reject(prompt)
            continue

        try:
            value = float(parts[1]) if command == "/g" else int(parts[1])
        except ValueError:
            # Non-numeric argument: report it instead of ending the session.
            reject(prompt)
            continue

        if command in ("/w", "/h"):
            # Dimensions are snapped down to a multiple of 16; anything that
            # snaps to zero or below cannot describe an image.
            value = 16 * (value // 16)
            if value <= 0:
                reject(prompt)
                continue
            if command == "/w":
                state = replace(state, width=value)
            else:
                state = replace(state, height=value)
            nfo(f"Setting resolution to {state.width} x {state.height} ({state.height * state.width / 1e6:.2f}MP)")
        elif command == "/g":
            state = replace(state, guidance=value)
            nfo(f"Setting guidance to {state.guidance}")
        elif command == "/s":
            state = replace(state, seed=value)
            nfo(f"Setting seed to {state.seed}")
        else:  # "/n"
            state = replace(state, num_steps=value)
            nfo(f"Setting number of steps to {state.num_steps}")

    if prompt:
        state = replace(state, prompt=prompt)
    return state


@torch.inference_mode()
def main(
    mir_id: str = "model.dit.flux1-dev",
    ae_id: str = "model.vae.flux1-dev",
    width: int = 1360,
    height: int = 768,
    guidance: float = 4.0,
    seed: int = rng.next_seed(),
    prompt: str = "",
    quantization: bool = False,
    device: torch.device = gfx_device,
    num_steps: int = 50,
    loop: bool = False,
    offload: bool = False,
    compile: bool = False,
    verbose: bool = False,
    input_images: list[str] | None = None,
):
    """Sample the flux model. Either interactively (set `--loop`) or run for a single image.\n
    :param mir_id: Identifier of the flow model spec to load (looked up in `flux_configs`)
    :param ae_id: Identifier of the autoencoder spec to load (looked up in `flux_configs`)
    :param width: width of the sample in pixels (rounded down to a multiple of 16)
    :param height: height of the sample in pixels (rounded down to a multiple of 16)
    :param guidance: guidance value used for guidance distillation
    :param seed: Set a seed for sampling. NOTE: the default is evaluated once, when the
        function is defined, so every call that omits `seed` shares that value.
    :param prompt: Prompt used for sampling; `|` separates additional prompts that are
        sampled one after another (cannot be combined with `loop`)
    :param quantization: append the `:*@fp8-sai` suffix to `mir_id` before the spec lookup
    :param device: Pytorch device
    :param num_steps: number of sampling steps (50 by default); a falsy value falls back
        to the model spec's `init.num_steps`
    :param loop: start an interactive session and sample multiple times
    :param offload: load the flow model and autoencoder on CPU and shuttle the text
        encoders / flow model to `device` only while they are needed
    :param compile: wrap the flow model with `torch.compile`
    :param verbose: forwarded to `load_flow_model`
    :param input_images: accepted but not used by this function
    :raises AssertionError: if additional prompts are combined with `loop`"""
    # Function-scope imports, resolved once instead of on every loop iteration.
    from divisor.noise import get_noise
    from divisor.state import InferenceState

    if quantization:
        mir_id += ":*@fp8-sai"
    model_spec: ModelSpec = get_model_spec(mir_id, flux_configs)
    ae_spec: ModelSpec = get_model_spec(ae_id, flux_configs)
    init: InitialParamsFlux = model_spec.init

    # "a|b|c" -> first prompt "a", queue ["b", "c"]; no "|" -> no queue.
    prompt, *queued_prompts = prompt.split("|")
    additional_prompts = queued_prompts or None

    assert not ((additional_prompts is not None) and loop), "Do not provide additional prompts and set loop to True"

    height = 16 * (height // 16)
    width = 16 * (width // 16)

    load_device = torch.device("cpu") if offload else device
    t5 = load_t5(device, init.max_length or 512)
    clip = load_clip(device)
    model = load_flow_model(
        model_spec,
        device=load_device,
        verbose=verbose,
    )

    is_compiled = False
    if compile and not offload:  # (compiled models should be on target device)
        nfo("Compilation enabled.")
        model = torch.compile(model)  # type: ignore[assignment]
        is_compiled = True

    ae = load_ae(ae_spec, device=load_device)

    state = MenuState.from_cli_args(
        prompt=prompt,
        width=width,
        height=height,
        num_steps=num_steps or init.num_steps,
        guidance=guidance,
        seed=seed,
    )

    if loop:
        state = parse_prompt(state)

    while state is not None:
        if state.seed is None:
            seed = rng.next_seed()
            state = replace(state, seed=seed)
        else:
            rng.next_seed(state.seed)
        assert state.seed is not None, "Seed must be set"
        assert state.width is not None and state.height is not None, "Width and height must be set"
        assert state.num_steps is not None, "num_steps must be set"
        assert state.prompt is not None, "Prompt must be set"
        nfo(f"Generating with seed {rng.seed}: {state.prompt}")

        x_4d = get_noise(
            1,
            state.height,
            state.width,
            device=device,
            dtype=torch.bfloat16,
            seed=rng.seed,  # type: ignore
        )

        if offload:
            ae = ae.cpu()
            # NOTE(review): assumes empty_cache is a zero-argument callable;
            # the bare name used previously was a no-op. Confirm against
            # divisor.registry.
            empty_cache()

        t5, clip = t5.to(device), clip.to(device)
        x_3d = prepare_4d_noise_for_3d_model(
            height=state.height,  # type: ignore
            width=state.width,  # type: ignore
            seed=rng.seed,  # type: ignore
            t5=t5,
            clip=clip,
            prompt=state.prompt,
            device=device,
            dtype=torch.bfloat16,
        )
        inp = prepare(t5, clip, x_4d, prompt=state.prompt)
        timesteps = get_schedule(state.num_steps, inp["img"].shape[1], shift=init.shift)
        # Update state with runtime information (use 3D format for current_sample)
        state = state.with_runtime_state(
            current_timestep=0.0,
            current_sample=x_3d,
            timestep_index=0,
            total_timesteps=len(timesteps),
        )
        # offload TEs to CPU, load model to gpu
        if offload:
            t5, clip = t5.cpu(), clip.cpu()
            empty_cache()
            # The model is moved (and optionally compiled) on the first pass
            # only. It is never sent back to the CPU afterwards, so later
            # passes reuse it in place rather than moving a compiled module.
            if not is_compiled:
                model = model.to(device)  # type: ignore[attr-defined]
                if compile:
                    nfo("Compiling model on device.")
                    model = torch.compile(model, mode="max-autotune")  # type: ignore[assignment]
                    is_compiled = True

        settings = InferenceState(
            img=inp["img"],
            img_ids=inp["img_ids"],
            txt=inp["txt"],
            txt_ids=inp["txt_ids"],
            vec=inp["vec"],
            state=state,
            ae=ae,
            timesteps=timesteps,
            device=device,
            t5=t5,
            clip=clip,
        )
        x_4d = denoise(
            model,  # type: ignore[arg-type]
            settings=settings,
        )

        if loop:
            nfo("-" * 80)
            state = parse_prompt(state)
        elif additional_prompts:
            next_prompt = additional_prompts.pop(0)
            state = replace(state, prompt=next_prompt)
        else:
            state = None


if __name__ == "__main__":
    Fire(main)
def parse_prompt(state: MenuState) -> MenuState | None:
    """Read interactive input until a prompt is entered or the user quits.

    Lines starting with a slash are commands: they update `state` and ask
    again. Any other line ends the dialogue; an empty line keeps the prompt
    already stored in `state`.

    :param state: Current state to update
    :returns: Updated state, or None when the user quits with `/q`"""
    user_question = "Next prompt (write /h for help, /q to quit and leave empty to repeat):\n"
    usage = (
        "Usage: Either write your prompt directly, leave this field empty "
        "to repeat the prompt or write a command starting with a slash:\n"
        "- '/w <width>' will set the width of the generated image\n"
        "- '/h <height>' will set the height of the generated image\n"
        "- '/s <seed>' sets the next seed\n"
        "- '/g <guidance>' sets the guidance (flux-dev only)\n"
        "- '/n <steps>' sets the number of steps\n"
        "- '/q' to quit"
    )

    def reject(command_line: str) -> None:
        """Report a malformed command together with the usage text."""
        nfo(f"Got invalid command '{command_line}'\n{usage}")

    while (prompt := input(user_question)).startswith("/"):
        # Commands are recognised by their first two characters, so longer
        # spellings such as "/quit" or "/width 512" keep working.
        command = prompt[:2]
        parts = prompt.split()

        if command == "/q":
            nfo("Quitting")
            return None
        if command == "/h" and len(parts) == 1:
            # Bare "/h" is the help request advertised in the question;
            # "/h <height>" is handled further down.
            nfo(usage)
            continue
        if command not in ("/w", "/h", "/g", "/s", "/n") or len(parts) != 2:
            reject(prompt)
            continue

        try:
            value = float(parts[1]) if command == "/g" else int(parts[1])
        except ValueError:
            # Non-numeric argument: report it instead of ending the session.
            reject(prompt)
            continue

        if command in ("/w", "/h"):
            # Dimensions are snapped down to a multiple of 16; anything that
            # snaps to zero or below cannot describe an image.
            value = 16 * (value // 16)
            if value <= 0:
                reject(prompt)
                continue
            if command == "/w":
                state = replace(state, width=value)
            else:
                state = replace(state, height=value)
            nfo(f"Setting resolution to {state.width} x {state.height} ({state.height * state.width / 1e6:.2f}MP)")
        elif command == "/g":
            state = replace(state, guidance=value)
            nfo(f"Setting guidance to {state.guidance}")
        elif command == "/s":
            state = replace(state, seed=value)
            nfo(f"Setting seed to {state.seed}")
        else:  # "/n"
            state = replace(state, num_steps=value)
            nfo(f"Setting number of steps to {state.num_steps}")

    if prompt:
        state = replace(state, prompt=prompt)
    return state
Parse user input and update input display.
Parameters
- state: Current state to update

Returns: the updated state, or None to quit
@torch.inference_mode()
def
main( mir_id: str = 'model.dit.flux1-dev', ae_id: str = 'model.vae.flux1-dev', width: int = 1360, height: int = 768, guidance: float = 4.0, seed: int = 10283458677934723471, prompt: str = '', quantization: bool = False, device: torch.device = device(type='mps'), num_steps: int = 50, loop: bool = False, offload: bool = False, compile: bool = False, verbose: bool = False, input_images: list[str] | None = None):
@torch.inference_mode()
def main(
    mir_id: str = "model.dit.flux1-dev",
    ae_id: str = "model.vae.flux1-dev",
    width: int = 1360,
    height: int = 768,
    guidance: float = 4.0,
    seed: int = rng.next_seed(),
    prompt: str = "",
    quantization: bool = False,
    device: torch.device = gfx_device,
    num_steps: int = 50,
    loop: bool = False,
    offload: bool = False,
    compile: bool = False,
    verbose: bool = False,
    input_images: list[str] | None = None,
):
    """Sample the flux model. Either interactively (set `--loop`) or run for a single image.\n
    :param mir_id: Identifier of the flow model spec to load (looked up in `flux_configs`)
    :param ae_id: Identifier of the autoencoder spec to load (looked up in `flux_configs`)
    :param width: width of the sample in pixels (rounded down to a multiple of 16)
    :param height: height of the sample in pixels (rounded down to a multiple of 16)
    :param guidance: guidance value used for guidance distillation
    :param seed: Set a seed for sampling. NOTE: the default is evaluated once, when the
        function is defined, so every call that omits `seed` shares that value.
    :param prompt: Prompt used for sampling; `|` separates additional prompts that are
        sampled one after another (cannot be combined with `loop`)
    :param quantization: append the `:*@fp8-sai` suffix to `mir_id` before the spec lookup
    :param device: Pytorch device
    :param num_steps: number of sampling steps (50 by default); a falsy value falls back
        to the model spec's `init.num_steps`
    :param loop: start an interactive session and sample multiple times
    :param offload: load the flow model and autoencoder on CPU and shuttle the text
        encoders / flow model to `device` only while they are needed
    :param compile: wrap the flow model with `torch.compile`
    :param verbose: forwarded to `load_flow_model`
    :param input_images: accepted but not used by this function
    :raises AssertionError: if additional prompts are combined with `loop`"""
    # Function-scope imports, resolved once instead of on every loop iteration.
    from divisor.noise import get_noise
    from divisor.state import InferenceState

    if quantization:
        mir_id += ":*@fp8-sai"
    model_spec: ModelSpec = get_model_spec(mir_id, flux_configs)
    ae_spec: ModelSpec = get_model_spec(ae_id, flux_configs)
    init: InitialParamsFlux = model_spec.init

    # "a|b|c" -> first prompt "a", queue ["b", "c"]; no "|" -> no queue.
    prompt, *queued_prompts = prompt.split("|")
    additional_prompts = queued_prompts or None

    assert not ((additional_prompts is not None) and loop), "Do not provide additional prompts and set loop to True"

    height = 16 * (height // 16)
    width = 16 * (width // 16)

    load_device = torch.device("cpu") if offload else device
    t5 = load_t5(device, init.max_length or 512)
    clip = load_clip(device)
    model = load_flow_model(
        model_spec,
        device=load_device,
        verbose=verbose,
    )

    is_compiled = False
    if compile and not offload:  # (compiled models should be on target device)
        nfo("Compilation enabled.")
        model = torch.compile(model)  # type: ignore[assignment]
        is_compiled = True

    ae = load_ae(ae_spec, device=load_device)

    state = MenuState.from_cli_args(
        prompt=prompt,
        width=width,
        height=height,
        num_steps=num_steps or init.num_steps,
        guidance=guidance,
        seed=seed,
    )

    if loop:
        state = parse_prompt(state)

    while state is not None:
        if state.seed is None:
            seed = rng.next_seed()
            state = replace(state, seed=seed)
        else:
            rng.next_seed(state.seed)
        assert state.seed is not None, "Seed must be set"
        assert state.width is not None and state.height is not None, "Width and height must be set"
        assert state.num_steps is not None, "num_steps must be set"
        assert state.prompt is not None, "Prompt must be set"
        nfo(f"Generating with seed {rng.seed}: {state.prompt}")

        x_4d = get_noise(
            1,
            state.height,
            state.width,
            device=device,
            dtype=torch.bfloat16,
            seed=rng.seed,  # type: ignore
        )

        if offload:
            ae = ae.cpu()
            # NOTE(review): assumes empty_cache is a zero-argument callable;
            # the bare name used previously was a no-op. Confirm against
            # divisor.registry.
            empty_cache()

        t5, clip = t5.to(device), clip.to(device)
        x_3d = prepare_4d_noise_for_3d_model(
            height=state.height,  # type: ignore
            width=state.width,  # type: ignore
            seed=rng.seed,  # type: ignore
            t5=t5,
            clip=clip,
            prompt=state.prompt,
            device=device,
            dtype=torch.bfloat16,
        )
        inp = prepare(t5, clip, x_4d, prompt=state.prompt)
        timesteps = get_schedule(state.num_steps, inp["img"].shape[1], shift=init.shift)
        # Update state with runtime information (use 3D format for current_sample)
        state = state.with_runtime_state(
            current_timestep=0.0,
            current_sample=x_3d,
            timestep_index=0,
            total_timesteps=len(timesteps),
        )
        # offload TEs to CPU, load model to gpu
        if offload:
            t5, clip = t5.cpu(), clip.cpu()
            empty_cache()
            # The model is moved (and optionally compiled) on the first pass
            # only. It is never sent back to the CPU afterwards, so later
            # passes reuse it in place rather than moving a compiled module.
            if not is_compiled:
                model = model.to(device)  # type: ignore[attr-defined]
                if compile:
                    nfo("Compiling model on device.")
                    model = torch.compile(model, mode="max-autotune")  # type: ignore[assignment]
                    is_compiled = True

        settings = InferenceState(
            img=inp["img"],
            img_ids=inp["img_ids"],
            txt=inp["txt"],
            txt_ids=inp["txt_ids"],
            vec=inp["vec"],
            state=state,
            ae=ae,
            timesteps=timesteps,
            device=device,
            t5=t5,
            clip=clip,
        )
        x_4d = denoise(
            model,  # type: ignore[arg-type]
            settings=settings,
        )

        if loop:
            nfo("-" * 80)
            state = parse_prompt(state)
        elif additional_prompts:
            next_prompt = additional_prompts.pop(0)
            state = replace(state, prompt=next_prompt)
        else:
            state = None
Sample the flux model. Either interactively (set --loop) or run for a single image.
Parameters
- mir_id (documented as `name`): Name of the model to load
- height: height of the sample in pixels (should be a multiple of 16)
- width: width of the sample in pixels (should be a multiple of 16)
- seed: Set a seed for sampling
- output_name: where to save the output image, `{idx}` will be replaced by the index of the sample
- prompt: Prompt used for sampling
- device: Pytorch device
- num_steps: number of sampling steps (default 4 for schnell, 28 for guidance distilled)
- loop: start an interactive session and sample multiple times
- guidance: guidance value used for guidance distillation