divisor.xflux1.prompt
# SPDX-License-Identifier: MPL-2.0 AND LicenseRef-Commons-Clause-License-Condition-1.0
# <!-- // /* d a r k s h a p e s */ -->
# adapted XFlux code from https://github.com/TencentARC/FluxKits

from dataclasses import replace

import torch
from fire import Fire
from nnll.console import nfo

from divisor.controller import rng
from divisor.flux1.loading import load_ae, load_clip, load_flow_model, load_t5
from divisor.flux1.prompt import parse_prompt
from divisor.flux1.sampling import get_schedule, prepare
from divisor.noise import get_noise
from divisor.registry import empty_cache, gfx_device
from divisor.spec import InitialParamsFlux, ModelSpec, flux_configs, get_model_spec
from divisor.state import MenuState
from divisor.xflux1.sampling import denoise


@torch.inference_mode()
def main(
    mir_id: str = "model.dit.flux1-dev:mini",
    ae_id: str = "model.vae.flux1-dev",
    width: int = 1360,
    height: int = 768,
    guidance: float = 4,
    seed: int | None = None,
    prompt: str = "",
    quantization: bool = False,
    device: torch.device = gfx_device,
    num_steps: int = 50,
    loop: bool = False,
    offload: bool = False,
    compile: bool = False,
    verbose: bool = False,
    input_images: list[str] | None = None,
):
    """Sample the flux model. Either interactively (set `--loop`) or run for a single image.

    :param mir_id: MIR id of the flow model to load; ``":@fp8-sai"`` is appended when ``quantization`` is set
    :param ae_id: MIR id of the autoencoder to load
    :param width: width of the sample in pixels (rounded down to a multiple of 16)
    :param height: height of the sample in pixels (rounded down to a multiple of 16)
    :param guidance: guidance value used for guidance distillation
    :param seed: seed for sampling; when omitted, a fresh seed is drawn from ``rng`` at call time
    :param prompt: prompt used for sampling; ``|`` separates several prompts that are sampled one
        after another (only when ``loop`` is off)
    :param quantization: load the fp8 variant of the flow model
    :param device: PyTorch device used for inference
    :param num_steps: number of sampling steps (default 50); ``0`` falls back to the model spec's default
    :param loop: start an interactive session and sample multiple times
    :param offload: load the flow model and autoencoder on CPU and move the text encoders / flow
        model to ``device`` only when they are needed
    :param compile: compile the flow model with ``torch.compile``
    :param verbose: forwarded to the flow-model loader
    :param input_images: currently unused
    """
    # Imported once here rather than on every loop iteration.
    from divisor.state import InferenceState

    if seed is None:
        # Drawn at call time: a default of ``rng.next_seed()`` would be evaluated once at import.
        seed = rng.next_seed()

    if quantization:
        mir_id += ":@fp8-sai"

    model_spec: ModelSpec = get_model_spec(mir_id, flux_configs)
    ae_spec = get_model_spec(ae_id, flux_configs)

    # "a|b|c" -> first prompt "a", queued follow-ups ["b", "c"] (None when there are none).
    prompt, *queued_prompts = prompt.split("|")
    additional_prompts = queued_prompts or None

    init: InitialParamsFlux = model_spec.init

    height = 16 * (height // 16)
    width = 16 * (width // 16)

    t5 = load_t5(device, init.max_length)
    clip = load_clip(device)
    # Load model to final device if not offloading (compile requires model to be on target device)
    model = load_flow_model(
        model_spec,
        device=torch.device("cpu") if offload else device,
        verbose=verbose,
    )

    is_compiled = False
    if compile and not offload:
        # When offloading, compilation is deferred until the model has been moved to ``device``.
        nfo("Compilation enabled.")
        model = torch.compile(model)  # type: ignore[assignment]
        is_compiled = True

    ae = load_ae(ae_spec, device=torch.device("cpu") if offload else device)

    # Create initial state from CLI args
    state = MenuState.from_cli_args(
        prompt=prompt,
        width=width,
        height=height,
        num_steps=num_steps or init.num_steps,
        guidance=guidance,
        seed=seed,
    )

    if loop:
        state = parse_prompt(state)

    while state is not None:
        if state.seed is None:
            seed = rng.next_seed()
            state = replace(state, seed=seed)
        else:
            rng.next_seed(state.seed)
        # At this point, state.seed is guaranteed to be an int
        assert state.seed is not None, "Seed must be set"
        assert state.width is not None and state.height is not None, "Width and height must be set"
        assert state.num_steps is not None, "num_steps must be set"
        assert state.prompt is not None, "Prompt must be set"
        nfo(f"Generating with seed {rng.seed}: {state.prompt}")

        x = get_noise(
            1,
            state.height,
            state.width,
            device=device,
            dtype=torch.bfloat16,
            seed=rng.seed,  # type: ignore
        )

        # prepare input: free device memory for the text encoders
        if offload:
            ae = ae.cpu()
            empty_cache()

        t5, clip = t5.to(device), clip.to(device)
        inp = prepare(t5, clip, x, prompt=state.prompt)
        timesteps = get_schedule(state.num_steps, inp["img"].shape[1], shift=init.shift)
        # Update state with runtime information
        state = state.with_runtime_state(
            current_timestep=0.0,
            current_sample=x,
            timestep_index=0,
            total_timesteps=len(timesteps),
        )

        # offload TEs to CPU, load model to device
        if offload:
            t5, clip = t5.cpu(), clip.cpu()
            empty_cache()
            # The flow model is never moved back to CPU, so once it has been moved (and
            # possibly compiled) on the first pass it already lives on ``device``. Skipping
            # the move here avoids calling ``.to()`` on a compiled model in later iterations.
            if not is_compiled:
                model = model.to(device)  # type: ignore[attr-defined]
                if compile:
                    nfo("Compiling model on device.")
                    model = torch.compile(model, mode="max-autotune")  # type: ignore[assignment]
                    is_compiled = True

        # denoise initial noise
        settings = InferenceState(
            img=inp["img"],
            img_ids=inp["img_ids"],
            txt=inp["txt"],
            txt_ids=inp["txt_ids"],
            vec=inp["vec"],
            state=state,
            ae=ae,
            timesteps=timesteps,
            device=device,
            t5=t5,
            clip=clip,
        )
        denoise(
            model,  # type: ignore[arg-type]
            settings=settings,
        )

        # Pick the next prompt: interactive menu, queued "|" prompts, or stop.
        if loop:
            nfo("-" * 80)
            state = parse_prompt(state)
        elif additional_prompts:
            next_prompt = additional_prompts.pop(0)
            state = replace(state, prompt=next_prompt)
        else:
            state = None


if __name__ == "__main__":
    Fire(main)
@torch.inference_mode()
def
main( mir_id: str = 'model.dit.flux1-dev:mini', ae_id: str = 'model.vae.flux1-dev', width: int = 1360, height: int = 768, guidance: float = 4, seed: int = 15126361353060768187, prompt: str = '', quantization: bool = False, device: torch.device = device(type='mps'), num_steps: int = 50, loop: bool = False, offload: bool = False, compile: bool = False, verbose: bool = False, input_images: list[str] | None = None):
@torch.inference_mode()
def main(
    mir_id: str = "model.dit.flux1-dev:mini",
    ae_id: str = "model.vae.flux1-dev",
    width: int = 1360,
    height: int = 768,
    guidance: float = 4,
    seed: int | None = None,
    prompt: str = "",
    quantization: bool = False,
    device: torch.device = gfx_device,
    num_steps: int = 50,
    loop: bool = False,
    offload: bool = False,
    compile: bool = False,
    verbose: bool = False,
    input_images: list[str] | None = None,
):
    """Sample the flux model. Either interactively (set `--loop`) or run for a single image.

    :param mir_id: MIR id of the flow model to load; ``":@fp8-sai"`` is appended when ``quantization`` is set
    :param ae_id: MIR id of the autoencoder to load
    :param width: width of the sample in pixels (rounded down to a multiple of 16)
    :param height: height of the sample in pixels (rounded down to a multiple of 16)
    :param guidance: guidance value used for guidance distillation
    :param seed: seed for sampling; when omitted, a fresh seed is drawn from ``rng`` at call time
    :param prompt: prompt used for sampling; ``|`` separates several prompts that are sampled one
        after another (only when ``loop`` is off)
    :param quantization: load the fp8 variant of the flow model
    :param device: PyTorch device used for inference
    :param num_steps: number of sampling steps (default 50); ``0`` falls back to the model spec's default
    :param loop: start an interactive session and sample multiple times
    :param offload: load the flow model and autoencoder on CPU and move the text encoders / flow
        model to ``device`` only when they are needed
    :param compile: compile the flow model with ``torch.compile``
    :param verbose: forwarded to the flow-model loader
    :param input_images: currently unused
    """
    # Imported once here rather than on every loop iteration.
    from divisor.state import InferenceState

    if seed is None:
        # Drawn at call time: a default of ``rng.next_seed()`` would be evaluated once at import.
        seed = rng.next_seed()

    if quantization:
        mir_id += ":@fp8-sai"

    model_spec: ModelSpec = get_model_spec(mir_id, flux_configs)
    ae_spec = get_model_spec(ae_id, flux_configs)

    # "a|b|c" -> first prompt "a", queued follow-ups ["b", "c"] (None when there are none).
    prompt, *queued_prompts = prompt.split("|")
    additional_prompts = queued_prompts or None

    init: InitialParamsFlux = model_spec.init

    height = 16 * (height // 16)
    width = 16 * (width // 16)

    t5 = load_t5(device, init.max_length)
    clip = load_clip(device)
    # Load model to final device if not offloading (compile requires model to be on target device)
    model = load_flow_model(
        model_spec,
        device=torch.device("cpu") if offload else device,
        verbose=verbose,
    )

    is_compiled = False
    if compile and not offload:
        # When offloading, compilation is deferred until the model has been moved to ``device``.
        nfo("Compilation enabled.")
        model = torch.compile(model)  # type: ignore[assignment]
        is_compiled = True

    ae = load_ae(ae_spec, device=torch.device("cpu") if offload else device)

    # Create initial state from CLI args
    state = MenuState.from_cli_args(
        prompt=prompt,
        width=width,
        height=height,
        num_steps=num_steps or init.num_steps,
        guidance=guidance,
        seed=seed,
    )

    if loop:
        state = parse_prompt(state)

    while state is not None:
        if state.seed is None:
            seed = rng.next_seed()
            state = replace(state, seed=seed)
        else:
            rng.next_seed(state.seed)
        # At this point, state.seed is guaranteed to be an int
        assert state.seed is not None, "Seed must be set"
        assert state.width is not None and state.height is not None, "Width and height must be set"
        assert state.num_steps is not None, "num_steps must be set"
        assert state.prompt is not None, "Prompt must be set"
        nfo(f"Generating with seed {rng.seed}: {state.prompt}")

        x = get_noise(
            1,
            state.height,
            state.width,
            device=device,
            dtype=torch.bfloat16,
            seed=rng.seed,  # type: ignore
        )

        # prepare input: free device memory for the text encoders
        if offload:
            ae = ae.cpu()
            empty_cache()

        t5, clip = t5.to(device), clip.to(device)
        inp = prepare(t5, clip, x, prompt=state.prompt)
        timesteps = get_schedule(state.num_steps, inp["img"].shape[1], shift=init.shift)
        # Update state with runtime information
        state = state.with_runtime_state(
            current_timestep=0.0,
            current_sample=x,
            timestep_index=0,
            total_timesteps=len(timesteps),
        )

        # offload TEs to CPU, load model to device
        if offload:
            t5, clip = t5.cpu(), clip.cpu()
            empty_cache()
            # The flow model is never moved back to CPU, so once it has been moved (and
            # possibly compiled) on the first pass it already lives on ``device``. Skipping
            # the move here avoids calling ``.to()`` on a compiled model in later iterations.
            if not is_compiled:
                model = model.to(device)  # type: ignore[attr-defined]
                if compile:
                    nfo("Compiling model on device.")
                    model = torch.compile(model, mode="max-autotune")  # type: ignore[assignment]
                    is_compiled = True

        # denoise initial noise
        settings = InferenceState(
            img=inp["img"],
            img_ids=inp["img_ids"],
            txt=inp["txt"],
            txt_ids=inp["txt_ids"],
            vec=inp["vec"],
            state=state,
            ae=ae,
            timesteps=timesteps,
            device=device,
            t5=t5,
            clip=clip,
        )
        denoise(
            model,  # type: ignore[arg-type]
            settings=settings,
        )

        # Pick the next prompt: interactive menu, queued "|" prompts, or stop.
        if loop:
            nfo("-" * 80)
            state = parse_prompt(state)
        elif additional_prompts:
            next_prompt = additional_prompts.pop(0)
            state = replace(state, prompt=next_prompt)
        else:
            state = None
Sample the flux model. Either interactively (set --loop) or run for a single image.
Parameters
- mir_id: MIR id of the flow model to load
- height: height of the sample in pixels (should be a multiple of 16)
- width: width of the sample in pixels (should be a multiple of 16)
- seed: Set a seed for sampling
- output_name: where to save the output image; `{idx}` will be replaced by the index of the sample (note: `main` does not currently accept this parameter)
- prompt: Prompt used for sampling
- device: Pytorch device
- num_steps: number of sampling steps (default 50; passing 0 falls back to the model spec's default, typically 4 for schnell and 28 for guidance distilled)
- loop: start an interactive session and sample multiple times
- guidance: guidance value used for guidance distillation