divisor.xflux1.prompt

  1# SPDX-License-Identifier: MPL-2.0 AND LicenseRef-Commons-Clause-License-Condition-1.0
  2# <!-- // /*  d a r k s h a p e s */ -->
  3# adapted XFlux code from https://github.com/TencentARC/FluxKits
  4
  5from dataclasses import replace
  6
  7import torch
  8from fire import Fire
  9from nnll.console import nfo
 10
 11from divisor.controller import rng
 12from divisor.flux1.loading import load_ae, load_clip, load_flow_model, load_t5
 13from divisor.flux1.prompt import parse_prompt
 14from divisor.flux1.sampling import get_schedule, prepare
 15from divisor.noise import get_noise
 16from divisor.registry import empty_cache, gfx_device
 17from divisor.spec import InitialParamsFlux, ModelSpec, flux_configs, get_model_spec
 18from divisor.state import MenuState
 19from divisor.xflux1.sampling import denoise
 20
 21
 22@torch.inference_mode()
 23def main(
 24    mir_id: str = "model.dit.flux1-dev:mini",
 25    ae_id: str = "model.vae.flux1-dev",
 26    width: int = 1360,
 27    height: int = 768,
 28    guidance: float = 4,
 29    seed: int = rng.next_seed(),
 30    prompt: str = "",
 31    quantization: bool = False,
 32    device: torch.device = gfx_device,
 33    num_steps: int = 50,
 34    loop: bool = False,
 35    offload: bool = False,
 36    compile: bool = False,
 37    verbose: bool = False,
 38    input_images: list[str] | None = None,
 39):
 40    """Sample the flux model. Either interactively (set `--loop`) or run for a single image.\n
 41    :param name: Name of the model to load
 42    :param height: height of the sample in pixels (should be a multiple of 16)
 43    :param width: width of the sample in pixels (should be a multiple of 16)
 44    :param seed: Set a seed for sampling
 45    :param output_name: where to save the output image, `{idx}` will be replaced by the index of the sample
 46    :param prompt: Prompt used for sampling
 47    :param device: Pytorch device
 48    :param num_steps: number of sampling steps (default 4 for schnell, 28 for guidance distilled)
 49    :param loop: start an interactive session and sample multiple times
 50    :param guidance: guidance value used for guidance distillation
 51    """
 52
 53    if quantization:
 54        mir_id += ":@fp8-sai"
 55
 56    model_spec: ModelSpec = get_model_spec(mir_id, flux_configs)
 57    ae_spec = get_model_spec(ae_id, flux_configs)
 58
 59    prompt_parts = prompt.split("|")
 60    if len(prompt_parts) == 1:
 61        prompt = prompt_parts[0]
 62        additional_prompts = None
 63    else:
 64        additional_prompts = prompt_parts[1:]
 65        prompt = prompt_parts[0]
 66
 67    init: InitialParamsFlux = model_spec.init
 68
 69    height = 16 * (height // 16)
 70    width = 16 * (width // 16)
 71
 72    t5 = load_t5(device, init.max_length)
 73    clip = load_clip(device)
 74    # Load model to final device if not offloading (compile requires model to be on target device)
 75    model = load_flow_model(
 76        model_spec,
 77        device=torch.device("cpu") if offload else device,
 78        verbose=verbose,
 79    )
 80
 81    is_compiled = False
 82    if compile and not offload:
 83        # Compile only if not offloading (compiled models can't be easily moved between devices)
 84        nfo("Compilation enabled.")
 85        model = torch.compile(model)  # type: ignore[assignment]
 86        is_compiled = True
 87
 88    ae = load_ae(ae_spec, device=torch.device("cpu") if offload else device)
 89
 90    # Create initial state from CLI args
 91    state = MenuState.from_cli_args(
 92        prompt=prompt,
 93        width=width,
 94        height=height,
 95        num_steps=num_steps or init.num_steps,
 96        guidance=guidance,
 97        seed=seed,
 98    )
 99
100    if loop:
101        state = parse_prompt(state)
102
103    while state is not None:
104        if state.seed is None:
105            seed = rng.next_seed()
106            state = replace(state, seed=seed)
107        else:
108            rng.next_seed(state.seed)
109        # At this point, state.seed is guaranteed to be an int
110        assert state.seed is not None, "Seed must be set"
111        assert state.width is not None and state.height is not None, "Width and height must be set"
112        assert state.num_steps is not None, "num_steps must be set"
113        assert state.prompt is not None, "Prompt must be set"
114        nfo(f"Generating with seed {rng.seed}: {state.prompt}")
115
116        x = get_noise(
117            1,
118            state.height,
119            state.width,
120            device=device,
121            dtype=torch.bfloat16,
122            seed=rng.seed,  # type: ignore
123        )
124
125        # prepare input
126        if offload:
127            ae = ae.cpu()
128            empty_cache
129
130            t5, clip = t5.to(device), clip.to(device)
131        inp = prepare(t5, clip, x, prompt=state.prompt)
132        timesteps = get_schedule(state.num_steps, inp["img"].shape[1], shift=init.shift)
133        # Update state with runtime information
134        state = state.with_runtime_state(
135            current_timestep=0.0,
136            current_sample=x,
137            timestep_index=0,
138            total_timesteps=len(timesteps),
139        )
140        # offload TEs to CPU, load model to gpu
141        if offload:
142            t5, clip = t5.cpu(), clip.cpu()
143            empty_cache
144            # Move model to device
145            if is_compiled:
146                # Can't move compiled models, so recompile after moving
147                # This requires getting the underlying model, which is tricky
148                # For now, just move and recompile
149                # Note: This is a limitation - ideally compile after moving
150                raise RuntimeError("Cannot use offload=True with compile=True. Compile after model is on device, or disable compilation when offloading.")
151            # At this point, model is not compiled, so .to() is safe
152            model = model.to(device)  # type: ignore[attr-defined]
153            # Compile after moving to device if requested
154            if compile:
155                nfo("Compiling model on device.")
156
157                model = torch.compile(model, mode="max-autotune")  # type: ignore[assignment]
158                is_compiled = True
159
160        # denoise initial noise
161        from divisor.state import InferenceState
162
163        settings = InferenceState(
164            img=inp["img"],
165            img_ids=inp["img_ids"],
166            txt=inp["txt"],
167            txt_ids=inp["txt_ids"],
168            vec=inp["vec"],
169            state=state,
170            ae=ae,
171            timesteps=timesteps,
172            device=device,
173            t5=t5,
174            clip=clip,
175        )
176        x = denoise(
177            model,  # type: ignore[arg-type]
178            settings=settings,
179        )
180
181        if loop:
182            nfo("-" * 80)
183            state = parse_prompt(state)
184        elif additional_prompts:
185            next_prompt = additional_prompts.pop(0)
186            state = replace(state, prompt=next_prompt)
187        else:
188            state = None
189
190
# Script entry point: hand `main` to Fire so its parameters can be supplied from the command line.
if __name__ == "__main__":
    Fire(main)
@torch.inference_mode()
def main( mir_id: str = 'model.dit.flux1-dev:mini', ae_id: str = 'model.vae.flux1-dev', width: int = 1360, height: int = 768, guidance: float = 4, seed: int = 15126361353060768187, prompt: str = '', quantization: bool = False, device: torch.device = device(type='mps'), num_steps: int = 50, loop: bool = False, offload: bool = False, compile: bool = False, verbose: bool = False, input_images: list[str] | None = None):
 23@torch.inference_mode()
 24def main(
 25    mir_id: str = "model.dit.flux1-dev:mini",
 26    ae_id: str = "model.vae.flux1-dev",
 27    width: int = 1360,
 28    height: int = 768,
 29    guidance: float = 4,
 30    seed: int = rng.next_seed(),
 31    prompt: str = "",
 32    quantization: bool = False,
 33    device: torch.device = gfx_device,
 34    num_steps: int = 50,
 35    loop: bool = False,
 36    offload: bool = False,
 37    compile: bool = False,
 38    verbose: bool = False,
 39    input_images: list[str] | None = None,
 40):
 41    """Sample the flux model. Either interactively (set `--loop`) or run for a single image.\n
 42    :param name: Name of the model to load
 43    :param height: height of the sample in pixels (should be a multiple of 16)
 44    :param width: width of the sample in pixels (should be a multiple of 16)
 45    :param seed: Set a seed for sampling
 46    :param output_name: where to save the output image, `{idx}` will be replaced by the index of the sample
 47    :param prompt: Prompt used for sampling
 48    :param device: Pytorch device
 49    :param num_steps: number of sampling steps (default 4 for schnell, 28 for guidance distilled)
 50    :param loop: start an interactive session and sample multiple times
 51    :param guidance: guidance value used for guidance distillation
 52    """
 53
 54    if quantization:
 55        mir_id += ":@fp8-sai"
 56
 57    model_spec: ModelSpec = get_model_spec(mir_id, flux_configs)
 58    ae_spec = get_model_spec(ae_id, flux_configs)
 59
 60    prompt_parts = prompt.split("|")
 61    if len(prompt_parts) == 1:
 62        prompt = prompt_parts[0]
 63        additional_prompts = None
 64    else:
 65        additional_prompts = prompt_parts[1:]
 66        prompt = prompt_parts[0]
 67
 68    init: InitialParamsFlux = model_spec.init
 69
 70    height = 16 * (height // 16)
 71    width = 16 * (width // 16)
 72
 73    t5 = load_t5(device, init.max_length)
 74    clip = load_clip(device)
 75    # Load model to final device if not offloading (compile requires model to be on target device)
 76    model = load_flow_model(
 77        model_spec,
 78        device=torch.device("cpu") if offload else device,
 79        verbose=verbose,
 80    )
 81
 82    is_compiled = False
 83    if compile and not offload:
 84        # Compile only if not offloading (compiled models can't be easily moved between devices)
 85        nfo("Compilation enabled.")
 86        model = torch.compile(model)  # type: ignore[assignment]
 87        is_compiled = True
 88
 89    ae = load_ae(ae_spec, device=torch.device("cpu") if offload else device)
 90
 91    # Create initial state from CLI args
 92    state = MenuState.from_cli_args(
 93        prompt=prompt,
 94        width=width,
 95        height=height,
 96        num_steps=num_steps or init.num_steps,
 97        guidance=guidance,
 98        seed=seed,
 99    )
100
101    if loop:
102        state = parse_prompt(state)
103
104    while state is not None:
105        if state.seed is None:
106            seed = rng.next_seed()
107            state = replace(state, seed=seed)
108        else:
109            rng.next_seed(state.seed)
110        # At this point, state.seed is guaranteed to be an int
111        assert state.seed is not None, "Seed must be set"
112        assert state.width is not None and state.height is not None, "Width and height must be set"
113        assert state.num_steps is not None, "num_steps must be set"
114        assert state.prompt is not None, "Prompt must be set"
115        nfo(f"Generating with seed {rng.seed}: {state.prompt}")
116
117        x = get_noise(
118            1,
119            state.height,
120            state.width,
121            device=device,
122            dtype=torch.bfloat16,
123            seed=rng.seed,  # type: ignore
124        )
125
126        # prepare input
127        if offload:
128            ae = ae.cpu()
129            empty_cache
130
131            t5, clip = t5.to(device), clip.to(device)
132        inp = prepare(t5, clip, x, prompt=state.prompt)
133        timesteps = get_schedule(state.num_steps, inp["img"].shape[1], shift=init.shift)
134        # Update state with runtime information
135        state = state.with_runtime_state(
136            current_timestep=0.0,
137            current_sample=x,
138            timestep_index=0,
139            total_timesteps=len(timesteps),
140        )
141        # offload TEs to CPU, load model to gpu
142        if offload:
143            t5, clip = t5.cpu(), clip.cpu()
144            empty_cache
145            # Move model to device
146            if is_compiled:
147                # Can't move compiled models, so recompile after moving
148                # This requires getting the underlying model, which is tricky
149                # For now, just move and recompile
150                # Note: This is a limitation - ideally compile after moving
151                raise RuntimeError("Cannot use offload=True with compile=True. Compile after model is on device, or disable compilation when offloading.")
152            # At this point, model is not compiled, so .to() is safe
153            model = model.to(device)  # type: ignore[attr-defined]
154            # Compile after moving to device if requested
155            if compile:
156                nfo("Compiling model on device.")
157
158                model = torch.compile(model, mode="max-autotune")  # type: ignore[assignment]
159                is_compiled = True
160
161        # denoise initial noise
162        from divisor.state import InferenceState
163
164        settings = InferenceState(
165            img=inp["img"],
166            img_ids=inp["img_ids"],
167            txt=inp["txt"],
168            txt_ids=inp["txt_ids"],
169            vec=inp["vec"],
170            state=state,
171            ae=ae,
172            timesteps=timesteps,
173            device=device,
174            t5=t5,
175            clip=clip,
176        )
177        x = denoise(
178            model,  # type: ignore[arg-type]
179            settings=settings,
180        )
181
182        if loop:
183            nfo("-" * 80)
184            state = parse_prompt(state)
185        elif additional_prompts:
186            next_prompt = additional_prompts.pop(0)
187            state = replace(state, prompt=next_prompt)
188        else:
189            state = None

Sample the flux model. Either interactively (set --loop) or run for a single image.

Parameters
  • mir_id (documented as "name"): Name of the model to load
  • height: height of the sample in pixels (should be a multiple of 16)
  • width: width of the sample in pixels (should be a multiple of 16)
  • seed: Set a seed for sampling
  • output_name: where to save the output image, {idx} will be replaced by the index of the sample (note: not a parameter of the current signature)
  • prompt: Prompt used for sampling
  • device: Pytorch device
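  • num_steps: number of sampling steps (the signature default is 50; a value of 0 falls back to the model spec's own step count, which the docstring gives as 4 for schnell and 28 for guidance distilled)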
  • loop: start an interactive session and sample multiple times
  • guidance: guidance value used for guidance distillation