divisor.flux2.prompt

  1# SPDX-License-Identifier: MPL-2.0 AND LicenseRef-Commons-Clause-License-Condition-1.0
  2# <!-- // /*  d a r k s h a p e s */ -->
  3# adapted BFL Flux code from https://github.com/black-forest-labs/flux2
  4
  5from dataclasses import replace
  6
  7import torch
  8from fire import Fire
  9from nnll.console import nfo
 10from PIL import Image
 11
 12from divisor.controller import rng
 13from divisor.flux1.loading import load_ae, load_flow_model, load_mistral_small_embedder
 14from divisor.flux1.prompt import parse_prompt
 15from divisor.flux2.sampling import (
 16    batched_prc_img,
 17    batched_prc_txt,
 18    denoise_interactive,
 19    encode_image_refs,
 20    get_schedule,
 21)
 22from divisor.noise import get_noise
 23from divisor.registry import gfx_device, gfx_dtype, empty_cache
 24from divisor.spec import ModelSpec, flux_configs, get_model_spec
 25from divisor.state import InferenceState, MenuState
 26
 27
 28def main(
 29    mir_id: str = "model.dit.flux2-dev",
 30    ae_id: str = "model.vae.flux2-dev",
 31    width: int = 1360,
 32    height: int = 768,
 33    guidance: float = 4,
 34    seed: int = rng.next_seed(),
 35    prompt: str = "",
 36    quantization: bool = False,
 37    device: torch.device = gfx_device,
 38    num_steps: int = 50,
 39    upsample_prompt: bool = False,
 40    loop: bool = False,
 41    offload: bool = False,
 42    compile: bool = False,
 43    verbose: bool = False,
 44    input_images: list[str] | None = None,
 45) -> None:
 46    """Sample the flux model. Either interactively (set `--loop`) or run for a single image.\n
 47    :param name: Name of the model to load
 48    :param height: height of the sample in pixels (should be a multiple of 16)
 49    :param width: width of the sample in pixels (should be a multiple of 16)
 50    :param seed: Set a seed for sampling
 51    :param output_name: where to save the output image, `{idx}` will be replaced by the index of the sample
 52    :param prompt: Prompt used for sampling
 53    :param device: Pytorch device
 54    :param num_steps: number of sampling steps (default 4 for schnell, 28 for guidance distilled)
 55    :param loop: start an interactive session and sample multiple times
 56    :param guidance: guidance value used for guidance distillation
 57    """
 58
 59    precision = gfx_dtype
 60    prompt_parts = prompt.split("|")
 61    if len(prompt_parts) == 1:
 62        prompt = prompt_parts[0]
 63        additional_prompts = None
 64    else:
 65        additional_prompts = prompt_parts[1:]
 66        prompt = prompt_parts[0]
 67
 68    mistral = load_mistral_small_embedder()
 69    if quantization:
 70        mir_id += ":*@fp8-sai"
 71    model_spec: ModelSpec = get_model_spec(mir_id, flux_configs)
 72    ae_spec = get_model_spec(ae_id, flux_configs)
 73    model = load_flow_model(
 74        model_spec,
 75        device=torch.device("cpu") if offload else device,
 76        verbose=verbose,
 77    )
 78
 79    is_compiled = False
 80    if compile and not offload:  # Compiled models can't be easily moved between devices
 81        nfo("Compilation enabled.")
 82        model = torch.compile(model)  # type: ignore[assignment]
 83        is_compiled = True
 84
 85    ae = load_ae(ae_spec, device=torch.device("cpu") if offload else device)
 86    ae.eval()
 87    mistral.eval()
 88
 89    state = MenuState.from_cli_args(
 90        prompt=prompt,
 91        width=width,
 92        height=height,
 93        num_steps=num_steps,
 94        guidance=guidance,
 95        seed=seed,
 96    )
 97
 98    if loop:
 99        state = parse_prompt(state)
100
101    while state is not None:
102        if state.seed is None:
103            seed = rng.next_seed()
104            state = replace(state, seed=seed)
105        else:
106            rng.next_seed(state.seed)
107        # At this point, state.seed is guaranteed to be an int
108        assert state.seed is not None, "Seed must be set"
109        assert state.width is not None and state.height is not None, "Width and height must be set"
110        assert state.num_steps is not None, "num_steps must be set"
111        assert state.prompt is not None, "Prompt must be set"
112        nfo(f"Generating with seed {rng.seed}: {state.prompt}")
113
114        x = get_noise(
115            1,
116            state.height,
117            state.width,
118            dtype=precision,
119            seed=rng.seed,  # type: ignore
120            device=device,
121            version_2=True,
122        )
123
124        with torch.no_grad():
125            if input_images:
126                img_ctx = [Image.open(input_image) for input_image in input_images]
127                ref_tokens, ref_ids = encode_image_refs(ae, img_ctx)  # type: ignore
128            else:
129                ref_tokens = None
130                ref_ids = None
131
132            if upsample_prompt:
133                # Use local model for upsampling
134                upsampled_prompts = mistral.upsample_prompt([state.prompt], img=[img_ctx] if img_ctx else None)  # type: ignore
135                prompt = upsampled_prompts[0] if upsampled_prompts else state.prompt
136                state = replace(state, prompt=prompt)
137            else:
138                prompt = state.prompt
139
140            ctx = mistral([prompt]).to(precision)  # type: ignore
141            ctx, ctx_ids = batched_prc_txt(ctx)
142
143            randn = get_noise(
144                1,
145                state.height,  # type: ignore
146                state.width,  # type: ignore
147                dtype=torch.bfloat16,
148                seed=state.seed,  # type: ignore
149                device=device,
150                version_2=True,
151            )
152
153            x, x_ids = batched_prc_img(randn)
154
155            timesteps = get_schedule(state.num_steps, x.shape[1])  # type: ignore
156            # Update state with runtime information
157            state = state.with_runtime_state(
158                current_timestep=0.0,
159                current_sample=x,
160                timestep_index=0,
161                total_timesteps=len(timesteps),
162            )
163
164            if offload:
165                if device.type == "cuda":
166                    mistral = mistral.cpu()  # type: ignore
167                else:
168                    mistral = None
169                    del mistral
170                empty_cache
171                if is_compiled:
172                    raise RuntimeError("Cannot use offload=True with compile=True. Compile after model is on device, or disable compilation when offloading.")
173                model = model.to(device)  # type: ignore[attr-defined]
174                if compile:
175                    nfo("Compiling model on device.")
176                    model = torch.compile(model, mode="max-autotune")  # type: ignore[assignment]
177                    is_compiled = True
178
179            x = denoise_interactive(
180                model,  # type: ignore
181                InferenceState(
182                    img=x,
183                    img_ids=x_ids,
184                    txt=ctx,
185                    txt_ids=ctx_ids,
186                    timesteps=timesteps,
187                    img_cond_seq=ref_tokens,
188                    img_cond_seq_ids=ref_ids,
189                    state=state,
190                    ae=ae,
191                ),
192            )
193            if loop:
194                nfo("-" * 80)
195                state = parse_prompt(state)
196            elif additional_prompts:
197                next_prompt = additional_prompts.pop(0)
198                state = replace(state, prompt=next_prompt)
199            else:
200                state = None
201
202
if __name__ == "__main__":
    # Expose `main` as a command-line interface through python-fire.
    Fire(component=main)
def main( mir_id: str = 'model.dit.flux2-dev', ae_id: str = 'model.vae.flux2-dev', width: int = 1360, height: int = 768, guidance: float = 4, seed: int = 9245605556668510957, prompt: str = '', quantization: bool = False, device: torch.device = device(type='mps'), num_steps: int = 50, upsample_prompt: bool = False, loop: bool = False, offload: bool = False, compile: bool = False, verbose: bool = False, input_images: list[str] | None = None) -> None:
 29def main(
 30    mir_id: str = "model.dit.flux2-dev",
 31    ae_id: str = "model.vae.flux2-dev",
 32    width: int = 1360,
 33    height: int = 768,
 34    guidance: float = 4,
 35    seed: int = rng.next_seed(),
 36    prompt: str = "",
 37    quantization: bool = False,
 38    device: torch.device = gfx_device,
 39    num_steps: int = 50,
 40    upsample_prompt: bool = False,
 41    loop: bool = False,
 42    offload: bool = False,
 43    compile: bool = False,
 44    verbose: bool = False,
 45    input_images: list[str] | None = None,
 46) -> None:
 47    """Sample the flux model. Either interactively (set `--loop`) or run for a single image.\n
 48    :param name: Name of the model to load
 49    :param height: height of the sample in pixels (should be a multiple of 16)
 50    :param width: width of the sample in pixels (should be a multiple of 16)
 51    :param seed: Set a seed for sampling
 52    :param output_name: where to save the output image, `{idx}` will be replaced by the index of the sample
 53    :param prompt: Prompt used for sampling
 54    :param device: Pytorch device
 55    :param num_steps: number of sampling steps (default 4 for schnell, 28 for guidance distilled)
 56    :param loop: start an interactive session and sample multiple times
 57    :param guidance: guidance value used for guidance distillation
 58    """
 59
 60    precision = gfx_dtype
 61    prompt_parts = prompt.split("|")
 62    if len(prompt_parts) == 1:
 63        prompt = prompt_parts[0]
 64        additional_prompts = None
 65    else:
 66        additional_prompts = prompt_parts[1:]
 67        prompt = prompt_parts[0]
 68
 69    mistral = load_mistral_small_embedder()
 70    if quantization:
 71        mir_id += ":*@fp8-sai"
 72    model_spec: ModelSpec = get_model_spec(mir_id, flux_configs)
 73    ae_spec = get_model_spec(ae_id, flux_configs)
 74    model = load_flow_model(
 75        model_spec,
 76        device=torch.device("cpu") if offload else device,
 77        verbose=verbose,
 78    )
 79
 80    is_compiled = False
 81    if compile and not offload:  # Compiled models can't be easily moved between devices
 82        nfo("Compilation enabled.")
 83        model = torch.compile(model)  # type: ignore[assignment]
 84        is_compiled = True
 85
 86    ae = load_ae(ae_spec, device=torch.device("cpu") if offload else device)
 87    ae.eval()
 88    mistral.eval()
 89
 90    state = MenuState.from_cli_args(
 91        prompt=prompt,
 92        width=width,
 93        height=height,
 94        num_steps=num_steps,
 95        guidance=guidance,
 96        seed=seed,
 97    )
 98
 99    if loop:
100        state = parse_prompt(state)
101
102    while state is not None:
103        if state.seed is None:
104            seed = rng.next_seed()
105            state = replace(state, seed=seed)
106        else:
107            rng.next_seed(state.seed)
108        # At this point, state.seed is guaranteed to be an int
109        assert state.seed is not None, "Seed must be set"
110        assert state.width is not None and state.height is not None, "Width and height must be set"
111        assert state.num_steps is not None, "num_steps must be set"
112        assert state.prompt is not None, "Prompt must be set"
113        nfo(f"Generating with seed {rng.seed}: {state.prompt}")
114
115        x = get_noise(
116            1,
117            state.height,
118            state.width,
119            dtype=precision,
120            seed=rng.seed,  # type: ignore
121            device=device,
122            version_2=True,
123        )
124
125        with torch.no_grad():
126            if input_images:
127                img_ctx = [Image.open(input_image) for input_image in input_images]
128                ref_tokens, ref_ids = encode_image_refs(ae, img_ctx)  # type: ignore
129            else:
130                ref_tokens = None
131                ref_ids = None
132
133            if upsample_prompt:
134                # Use local model for upsampling
135                upsampled_prompts = mistral.upsample_prompt([state.prompt], img=[img_ctx] if img_ctx else None)  # type: ignore
136                prompt = upsampled_prompts[0] if upsampled_prompts else state.prompt
137                state = replace(state, prompt=prompt)
138            else:
139                prompt = state.prompt
140
141            ctx = mistral([prompt]).to(precision)  # type: ignore
142            ctx, ctx_ids = batched_prc_txt(ctx)
143
144            randn = get_noise(
145                1,
146                state.height,  # type: ignore
147                state.width,  # type: ignore
148                dtype=torch.bfloat16,
149                seed=state.seed,  # type: ignore
150                device=device,
151                version_2=True,
152            )
153
154            x, x_ids = batched_prc_img(randn)
155
156            timesteps = get_schedule(state.num_steps, x.shape[1])  # type: ignore
157            # Update state with runtime information
158            state = state.with_runtime_state(
159                current_timestep=0.0,
160                current_sample=x,
161                timestep_index=0,
162                total_timesteps=len(timesteps),
163            )
164
165            if offload:
166                if device.type == "cuda":
167                    mistral = mistral.cpu()  # type: ignore
168                else:
169                    mistral = None
170                    del mistral
171                empty_cache
172                if is_compiled:
173                    raise RuntimeError("Cannot use offload=True with compile=True. Compile after model is on device, or disable compilation when offloading.")
174                model = model.to(device)  # type: ignore[attr-defined]
175                if compile:
176                    nfo("Compiling model on device.")
177                    model = torch.compile(model, mode="max-autotune")  # type: ignore[assignment]
178                    is_compiled = True
179
180            x = denoise_interactive(
181                model,  # type: ignore
182                InferenceState(
183                    img=x,
184                    img_ids=x_ids,
185                    txt=ctx,
186                    txt_ids=ctx_ids,
187                    timesteps=timesteps,
188                    img_cond_seq=ref_tokens,
189                    img_cond_seq_ids=ref_ids,
190                    state=state,
191                    ae=ae,
192                ),
193            )
194            if loop:
195                nfo("-" * 80)
196                state = parse_prompt(state)
197            elif additional_prompts:
198                next_prompt = additional_prompts.pop(0)
199                state = replace(state, prompt=next_prompt)
200            else:
201                state = None

Sample the flux model. Either interactively (set --loop) or run for a single image.

Parameters
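  • mir_id: MIR identifier of the flow model to load (quantization selects the fp8 variant)
  • ae_id: MIR identifier of the autoencoder to load
  • width: width of the sample in pixels (should be a multiple of 16)
  • height: height of the sample in pixels (should be a multiple of 16)
  • guidance: guidance value used for guidance distillation
  • seed: Set a seed for sampling
  • prompt: Prompt used for sampling; outside --loop, several prompts separated by | are generated in turn
  • quantization: load the fp8 variant of the flow model
  • device: Pytorch device
  • num_steps: number of sampling steps (default 50)
  • upsample_prompt: rewrite the prompt with the text encoder before sampling
  • loop: start an interactive session and sample multiple times
  • offload: load the models on the CPU and move the flow model to the device only for denoising
  • compile: compile the flow model with torch.compile
  • verbose: enable verbose output from the flow-model loader
  • input_images: paths of reference images used to condition generation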