divisor.flux1.prompt

  1# SPDX-License-Identifier: MPL-2.0 AND LicenseRef-Commons-Clause-License-Condition-1.0
  2# <!-- // /*  d a r k s h a p e s */ -->
  3# adapted BFL Flux code from https://github.com/black-forest-labs/flux
  4
  5from dataclasses import replace
  6
  7from fire import Fire
  8from nnll.console import nfo
  9from divisor.registry import gfx_device, empty_cache
 10import torch
 11
 12from divisor.controller import rng
 13from divisor.flux1.loading import load_ae, load_clip, load_flow_model, load_t5
 14from divisor.flux1.sampling import denoise, get_schedule, prepare
 15from divisor.noise import prepare_4d_noise_for_3d_model
 16from divisor.spec import InitialParamsFlux, ModelSpec, flux_configs, get_model_spec
 17from divisor.state import MenuState
 18
 19
 20def parse_prompt(state: MenuState) -> MenuState | None:
 21    """Parse user input and update input display.
 22
 23    :param state: Current state to update
 24    :returns: Updated input display or None to quit"""
 25    user_question = "Next prompt (write /h for help, /q to quit and leave empty to repeat):\n"
 26    usage = (
 27        "Usage: Either write your prompt directly, leave this field empty "
 28        "to repeat the prompt or write a command starting with a slash:\n"
 29        "- '/w <width>' will set the width of the generated image\n"
 30        "- '/h <height>' will set the height of the generated image\n"
 31        "- '/s <seed>' sets the next seed\n"
 32        "- '/g <guidance>' sets the guidance (flux-dev only)\n"
 33        "- '/n <steps>' sets the number of steps\n"
 34        "- '/q' to quit"
 35    )
 36
 37    while (prompt := input(user_question)).startswith("/"):
 38        if prompt.startswith("/w"):
 39            if prompt.count(" ") != 1:
 40                nfo(f"Got invalid command '{prompt}'\n{usage}")
 41                continue
 42            _, width = prompt.split()
 43            width = 16 * (int(width) // 16)
 44            state = replace(state, width=width)
 45            nfo(f"Setting resolution to {state.width} x {state.height} ({state.height * state.width / 1e6:.2f}MP)")
 46        elif prompt.startswith("/h"):
 47            if prompt.count(" ") != 1:
 48                nfo(f"Got invalid command '{prompt}'\n{usage}")
 49                continue
 50            _, height = prompt.split()
 51            height = 16 * (int(height) // 16)
 52            state = replace(state, height=height)
 53            nfo(f"Setting resolution to {state.width} x {state.height} ({state.height * state.width / 1e6:.2f}MP)")
 54        elif prompt.startswith("/g"):
 55            if prompt.count(" ") != 1:
 56                nfo(f"Got invalid command '{prompt}'\n{usage}")
 57                continue
 58            _, guidance = prompt.split()
 59            state = replace(state, guidance=float(guidance))
 60            nfo(f"Setting guidance to {state.guidance}")
 61        elif prompt.startswith("/s"):
 62            if prompt.count(" ") != 1:
 63                nfo(f"Got invalid command '{prompt}'\n{usage}")
 64                continue
 65            _, seed = prompt.split()
 66            state = replace(state, seed=int(seed))
 67            nfo(f"Setting seed to {state.seed}")
 68        elif prompt.startswith("/n"):
 69            if prompt.count(" ") != 1:
 70                nfo(f"Got invalid command '{prompt}'\n{usage}")
 71                continue
 72            _, steps = prompt.split()
 73            state = replace(state, num_steps=int(steps))
 74            nfo(f"Setting number of steps to {state.num_steps}")
 75        elif prompt.startswith("/q"):
 76            nfo("Quitting")
 77            return None
 78        else:
 79            if not prompt.startswith("/h"):
 80                nfo(f"Got invalid command '{prompt}'\n{usage}")
 81            nfo(usage)
 82    if prompt != "":
 83        state = replace(state, prompt=prompt)
 84    return state
 85
 86
 87@torch.inference_mode()
 88def main(
 89    mir_id: str = "model.dit.flux1-dev",
 90    ae_id: str = "model.vae.flux1-dev",
 91    width: int = 1360,
 92    height: int = 768,
 93    guidance: float = 4.0,
 94    seed: int = rng.next_seed(),
 95    prompt: str = "",
 96    quantization: bool = False,
 97    device: torch.device = gfx_device,
 98    num_steps: int = 50,
 99    loop: bool = False,
100    offload: bool = False,
101    compile: bool = False,
102    verbose: bool = False,
103    input_images: list[str] | None = None,
104):
105    """Sample the flux model. Either interactively (set `--loop`) or run for a single image.\n
106    :param name: Name of the model to load
107    :param height: height of the sample in pixels (should be a multiple of 16)
108    :param width: width of the sample in pixels (should be a multiple of 16)
109    :param seed: Set a seed for sampling
110    :param output_name: where to save the output image, `{idx}` will be replaced by the index of the sample
111    :param prompt: Prompt used for sampling
112    :param device: Pytorch device
113    :param num_steps: number of sampling steps (default 4 for schnell, 28 for guidance distilled)
114    :param loop: start an interactive session and sample multiple times
115    :param guidance: guidance value used for guidance distillation"""
116
117    if quantization:
118        mir_id += ":*@fp8-sai"
119    model_spec: ModelSpec = get_model_spec(mir_id, flux_configs)
120    ae_spec: ModelSpec = get_model_spec(ae_id, flux_configs)
121    init: InitialParamsFlux = model_spec.init
122
123    prompt_parts = prompt.split("|")
124    if len(prompt_parts) == 1:
125        prompt = prompt_parts[0]
126        additional_prompts = None
127    else:
128        additional_prompts = prompt_parts[1:]
129        prompt = prompt_parts[0]
130
131    assert not ((additional_prompts is not None) and loop), "Do not provide additional prompts and set loop to True"
132
133    height = 16 * (height // 16)
134    width = 16 * (width // 16)
135
136    t5 = load_t5(device, init.max_length or 512)
137    clip = load_clip(device)
138    model = load_flow_model(
139        model_spec,
140        device=torch.device("cpu") if offload else device,
141        verbose=verbose,
142    )
143
144    is_compiled = False
145    if compile and not offload:  # (compiled models should be on target device)
146        nfo("Compilation enabled.")
147        model = torch.compile(model)  # type: ignore[assignment]
148        is_compiled = True
149
150    ae = load_ae(ae_spec, device=torch.device("cpu") if offload else device)
151
152    state = MenuState.from_cli_args(
153        prompt=prompt,
154        width=width,
155        height=height,
156        num_steps=num_steps or init.num_steps,
157        guidance=guidance,
158        seed=seed,
159    )
160
161    if loop:
162        state = parse_prompt(state)
163
164    while state is not None:
165        if state.seed is None:
166            seed = rng.next_seed()
167            state = replace(state, seed=seed)
168        else:
169            rng.next_seed(state.seed)
170        assert state.seed is not None, "Seed must be set"
171        assert state.width is not None and state.height is not None, "Width and height must be set"
172        assert state.num_steps is not None, "num_steps must be set"
173        assert state.prompt is not None, "Prompt must be set"
174        nfo(f"Generating with seed {rng.seed}: {state.prompt}")
175
176        from divisor.noise import get_noise
177
178        x_4d = get_noise(
179            1,
180            state.height,
181            state.width,
182            device=device,
183            dtype=torch.bfloat16,
184            seed=rng.seed,  # type: ignore
185        )
186
187        if offload:
188            ae = ae.cpu()
189            empty_cache
190
191            t5, clip = t5.to(device), clip.to(device)
192        x_3d = prepare_4d_noise_for_3d_model(
193            height=state.height,  # type: ignore
194            width=state.width,  # type: ignore
195            seed=rng.seed,  # type: ignore
196            t5=t5,
197            clip=clip,
198            prompt=state.prompt,
199            device=device,
200            dtype=torch.bfloat16,
201        )
202        inp = prepare(t5, clip, x_4d, prompt=state.prompt)
203        timesteps = get_schedule(state.num_steps, inp["img"].shape[1], shift=init.shift)
204        # Update state with runtime information (use 3D format for current_sample)
205        state = state.with_runtime_state(
206            current_timestep=0.0,
207            current_sample=x_3d,
208            timestep_index=0,
209            total_timesteps=len(timesteps),
210        )
211        # offload TEs to CPU, load model to gpu
212        if offload:
213            t5, clip = t5.cpu(), clip.cpu()
214            empty_cache
215            if is_compiled:
216                raise RuntimeError("Can't move compiled models. Compile on target device, or disable compilation when offloading.")
217            model = model.to(device)  # type: ignore[attr-defined]
218            if compile:
219                nfo("Compiling model on device.")
220                model = torch.compile(model, mode="max-autotune")  # type: ignore[assignment]
221                is_compiled = True
222
223        from divisor.state import InferenceState
224
225        settings = InferenceState(
226            img=inp["img"],
227            img_ids=inp["img_ids"],
228            txt=inp["txt"],
229            txt_ids=inp["txt_ids"],
230            vec=inp["vec"],
231            state=state,
232            ae=ae,
233            timesteps=timesteps,
234            device=device,
235            t5=t5,
236            clip=clip,
237        )
238        x_4d = denoise(
239            model,  # type: ignore[arg-type]
240            settings=settings,
241        )
242
243        if loop:
244            nfo("-" * 80)
245            state = parse_prompt(state)
246        elif additional_prompts:
247            next_prompt = additional_prompts.pop(0)
248            state = replace(state, prompt=next_prompt)
249        else:
250            state = None
251
252
# When run as a script, hand main() to Fire as the command-line entry point.
if __name__ == "__main__":
    Fire(main)
def parse_prompt(state: divisor.state.MenuState) -> divisor.state.MenuState | None:
def parse_prompt(state: MenuState) -> MenuState | None:
    """Read interactive input until a prompt (or quit) is entered and update the state.

    Lines starting with ``/`` are treated as commands and processed in a loop;
    the first line that does not start with ``/`` ends the loop. A non-empty
    line replaces ``state.prompt``; an empty line keeps the previous prompt.

    Malformed commands (wrong argument count, non-numeric value) are reported
    together with the usage text and the user is asked again instead of the
    session being aborted by an exception.

    :param state: Current state to update
    :returns: Updated state, or None when the user asked to quit"""
    user_question = "Next prompt (write /h for help, /q to quit and leave empty to repeat):\n"
    usage = (
        "Usage: Either write your prompt directly, leave this field empty "
        "to repeat the prompt or write a command starting with a slash:\n"
        "- '/w <width>' will set the width of the generated image\n"
        "- '/h <height>' will set the height of the generated image\n"
        "- '/s <seed>' sets the next seed\n"
        "- '/g <guidance>' sets the guidance (flux-dev only)\n"
        "- '/n <steps>' sets the number of steps\n"
        "- '/q' to quit"
    )
    value_commands = ("/w", "/h", "/g", "/s", "/n")

    while (prompt := input(user_question)).startswith("/"):
        # The line starts with "/", so splitting always yields at least one token.
        parts = prompt.split()
        # Commands are matched on their two-character prefix, as before ("/width 64" == "/w 64").
        key = parts[0][:2]

        if key == "/q":
            nfo("Quitting")
            return None

        # "/h" doubles as help (no argument) and height (one argument).
        if key == "/h" and len(parts) == 1:
            nfo(usage)
            continue

        if key not in value_commands or len(parts) != 2:
            nfo(f"Got invalid command '{prompt}'\n{usage}")
            continue

        try:
            number = float(parts[1]) if key == "/g" else int(parts[1])
        except ValueError:
            # Non-numeric argument: report it and ask again rather than crashing the session.
            nfo(f"Got invalid command '{prompt}'\n{usage}")
            continue

        if key == "/w":
            # Round down to a multiple of 16.
            state = replace(state, width=16 * (number // 16))
            nfo(f"Setting resolution to {state.width} x {state.height} ({state.height * state.width / 1e6:.2f}MP)")
        elif key == "/h":
            # Round down to a multiple of 16.
            state = replace(state, height=16 * (number // 16))
            nfo(f"Setting resolution to {state.width} x {state.height} ({state.height * state.width / 1e6:.2f}MP)")
        elif key == "/g":
            state = replace(state, guidance=number)
            nfo(f"Setting guidance to {state.guidance}")
        elif key == "/s":
            state = replace(state, seed=number)
            nfo(f"Setting seed to {state.seed}")
        else:  # "/n"
            state = replace(state, num_steps=number)
            nfo(f"Setting number of steps to {state.num_steps}")

    if prompt != "":
        state = replace(state, prompt=prompt)
    return state

Parse user input and update input display.

Parameters
  • state: Current state to update

Returns
  • Updated input display or None to quit
@torch.inference_mode()
def main( mir_id: str = 'model.dit.flux1-dev', ae_id: str = 'model.vae.flux1-dev', width: int = 1360, height: int = 768, guidance: float = 4.0, seed: int = 10283458677934723471, prompt: str = '', quantization: bool = False, device: torch.device = device(type='mps'), num_steps: int = 50, loop: bool = False, offload: bool = False, compile: bool = False, verbose: bool = False, input_images: list[str] | None = None):
 88@torch.inference_mode()
 89def main(
 90    mir_id: str = "model.dit.flux1-dev",
 91    ae_id: str = "model.vae.flux1-dev",
 92    width: int = 1360,
 93    height: int = 768,
 94    guidance: float = 4.0,
 95    seed: int = rng.next_seed(),
 96    prompt: str = "",
 97    quantization: bool = False,
 98    device: torch.device = gfx_device,
 99    num_steps: int = 50,
100    loop: bool = False,
101    offload: bool = False,
102    compile: bool = False,
103    verbose: bool = False,
104    input_images: list[str] | None = None,
105):
106    """Sample the flux model. Either interactively (set `--loop`) or run for a single image.\n
107    :param name: Name of the model to load
108    :param height: height of the sample in pixels (should be a multiple of 16)
109    :param width: width of the sample in pixels (should be a multiple of 16)
110    :param seed: Set a seed for sampling
111    :param output_name: where to save the output image, `{idx}` will be replaced by the index of the sample
112    :param prompt: Prompt used for sampling
113    :param device: Pytorch device
114    :param num_steps: number of sampling steps (default 4 for schnell, 28 for guidance distilled)
115    :param loop: start an interactive session and sample multiple times
116    :param guidance: guidance value used for guidance distillation"""
117
118    if quantization:
119        mir_id += ":*@fp8-sai"
120    model_spec: ModelSpec = get_model_spec(mir_id, flux_configs)
121    ae_spec: ModelSpec = get_model_spec(ae_id, flux_configs)
122    init: InitialParamsFlux = model_spec.init
123
124    prompt_parts = prompt.split("|")
125    if len(prompt_parts) == 1:
126        prompt = prompt_parts[0]
127        additional_prompts = None
128    else:
129        additional_prompts = prompt_parts[1:]
130        prompt = prompt_parts[0]
131
132    assert not ((additional_prompts is not None) and loop), "Do not provide additional prompts and set loop to True"
133
134    height = 16 * (height // 16)
135    width = 16 * (width // 16)
136
137    t5 = load_t5(device, init.max_length or 512)
138    clip = load_clip(device)
139    model = load_flow_model(
140        model_spec,
141        device=torch.device("cpu") if offload else device,
142        verbose=verbose,
143    )
144
145    is_compiled = False
146    if compile and not offload:  # (compiled models should be on target device)
147        nfo("Compilation enabled.")
148        model = torch.compile(model)  # type: ignore[assignment]
149        is_compiled = True
150
151    ae = load_ae(ae_spec, device=torch.device("cpu") if offload else device)
152
153    state = MenuState.from_cli_args(
154        prompt=prompt,
155        width=width,
156        height=height,
157        num_steps=num_steps or init.num_steps,
158        guidance=guidance,
159        seed=seed,
160    )
161
162    if loop:
163        state = parse_prompt(state)
164
165    while state is not None:
166        if state.seed is None:
167            seed = rng.next_seed()
168            state = replace(state, seed=seed)
169        else:
170            rng.next_seed(state.seed)
171        assert state.seed is not None, "Seed must be set"
172        assert state.width is not None and state.height is not None, "Width and height must be set"
173        assert state.num_steps is not None, "num_steps must be set"
174        assert state.prompt is not None, "Prompt must be set"
175        nfo(f"Generating with seed {rng.seed}: {state.prompt}")
176
177        from divisor.noise import get_noise
178
179        x_4d = get_noise(
180            1,
181            state.height,
182            state.width,
183            device=device,
184            dtype=torch.bfloat16,
185            seed=rng.seed,  # type: ignore
186        )
187
188        if offload:
189            ae = ae.cpu()
190            empty_cache
191
192            t5, clip = t5.to(device), clip.to(device)
193        x_3d = prepare_4d_noise_for_3d_model(
194            height=state.height,  # type: ignore
195            width=state.width,  # type: ignore
196            seed=rng.seed,  # type: ignore
197            t5=t5,
198            clip=clip,
199            prompt=state.prompt,
200            device=device,
201            dtype=torch.bfloat16,
202        )
203        inp = prepare(t5, clip, x_4d, prompt=state.prompt)
204        timesteps = get_schedule(state.num_steps, inp["img"].shape[1], shift=init.shift)
205        # Update state with runtime information (use 3D format for current_sample)
206        state = state.with_runtime_state(
207            current_timestep=0.0,
208            current_sample=x_3d,
209            timestep_index=0,
210            total_timesteps=len(timesteps),
211        )
212        # offload TEs to CPU, load model to gpu
213        if offload:
214            t5, clip = t5.cpu(), clip.cpu()
215            empty_cache
216            if is_compiled:
217                raise RuntimeError("Can't move compiled models. Compile on target device, or disable compilation when offloading.")
218            model = model.to(device)  # type: ignore[attr-defined]
219            if compile:
220                nfo("Compiling model on device.")
221                model = torch.compile(model, mode="max-autotune")  # type: ignore[assignment]
222                is_compiled = True
223
224        from divisor.state import InferenceState
225
226        settings = InferenceState(
227            img=inp["img"],
228            img_ids=inp["img_ids"],
229            txt=inp["txt"],
230            txt_ids=inp["txt_ids"],
231            vec=inp["vec"],
232            state=state,
233            ae=ae,
234            timesteps=timesteps,
235            device=device,
236            t5=t5,
237            clip=clip,
238        )
239        x_4d = denoise(
240            model,  # type: ignore[arg-type]
241            settings=settings,
242        )
243
244        if loop:
245            nfo("-" * 80)
246            state = parse_prompt(state)
247        elif additional_prompts:
248            next_prompt = additional_prompts.pop(0)
249            state = replace(state, prompt=next_prompt)
250        else:
251            state = None

Sample the flux model. Either interactively (set --loop) or run for a single image.

Parameters
  • mir_id: Name of the model to load (documented as `name` in the docstring; the signature calls it `mir_id`)
  • height: height of the sample in pixels (should be a multiple of 16)
  • width: width of the sample in pixels (should be a multiple of 16)
  • seed: Set a seed for sampling
  • output_name: where to save the output image, {idx} will be replaced by the index of the sample (listed in the docstring, but the signature has no such parameter)
  • prompt: Prompt used for sampling
  • device: Pytorch device
  • num_steps: number of sampling steps (the signature default is 50; the model spec's value is used when a falsy value is passed)
  • loop: start an interactive session and sample multiple times
  • guidance: guidance value used for guidance distillation