divisor.acestep.ui.components

ACE-Step: A Step Towards Music Generation Foundation Model

https://github.com/ace-step/ACE-Step

Apache 2.0 License

  1"""
  2ACE-Step: A Step Towards Music Generation Foundation Model
  3
  4https://github.com/ace-step/ACE-Step
  5
  6Apache 2.0 License
  7"""
  8
  9import gradio as gr
 10import librosa
 11import os
 12
 13
# Default value of the "Tags" textbox in the text2music UI: a comma-separated
# list of style tags.
TAG_DEFAULT = "funk, pop, soul, rock, melodic, guitar, drums, bass, keyboard, percussion, 105 BPM, energetic, upbeat, groovy, vibrant, dynamic"
# Default value of the "Lyrics" textbox. Bracketed markers such as [verse],
# [chorus] and [bridge] separate the parts of the lyrics.
LYRIC_DEFAULT = """[verse]
Neon lights they flicker bright
City hums in dead of night
Rhythms pulse through concrete veins
Lost in echoes of refrains

[verse]
Bassline groovin' in my chest
Heartbeats match the city's zest
Electric whispers fill the air
Synthesized dreams everywhere

[chorus]
Turn it up and let it flow
Feel the fire let it grow
In this rhythm we belong
Hear the night sing out our song

[verse]
Guitar strings they start to weep
Wake the soul from silent sleep
Every note a story told
In this night we’re bold and gold

[bridge]
Voices blend in harmony
Lost in pure cacophony
Timeless echoes timeless cries
Soulful shouts beneath the skies

[verse]
Keyboard dances on the keys
Melodies on evening breeze
Catch the tune and hold it tight
In this moment we take flight
"""
 51
# Genre presets offered by the "Preset" dropdown. Each key is the name shown
# in the dropdown; each value is the comma-separated tag string that
# update_tags_from_preset() returns for that name.
GENRE_PRESETS = {
    "Modern Pop": "pop, synth, drums, guitar, 120 bpm, upbeat, catchy, vibrant, female vocals, polished vocals",
    "Rock": "rock, electric guitar, drums, bass, 130 bpm, energetic, rebellious, gritty, male vocals, raw vocals",
    "Hip Hop": "hip hop, 808 bass, hi-hats, synth, 90 bpm, bold, urban, intense, male vocals, rhythmic vocals",
    "Country": "country, acoustic guitar, steel guitar, fiddle, 100 bpm, heartfelt, rustic, warm, male vocals, twangy vocals",
    "EDM": "edm, synth, bass, kick drum, 128 bpm, euphoric, pulsating, energetic, instrumental",
    "Reggae": "reggae, guitar, bass, drums, 80 bpm, chill, soulful, positive, male vocals, smooth vocals",
    "Classical": "classical, orchestral, strings, piano, 60 bpm, elegant, emotive, timeless, instrumental",
    "Jazz": "jazz, saxophone, piano, double bass, 110 bpm, smooth, improvisational, soulful, male vocals, crooning vocals",
    "Metal": "metal, electric guitar, double kick drum, bass, 160 bpm, aggressive, intense, heavy, male vocals, screamed vocals",
    "R&B": "r&b, synth, bass, drums, 85 bpm, sultry, groovy, romantic, female vocals, silky vocals",
}
 65
 66
 67# Add this function to handle preset selection
 68def update_tags_from_preset(preset_name):
 69    if preset_name == "Custom":
 70        return ""
 71    return GENRE_PRESETS.get(preset_name, "")
 72
 73
 74def create_output_ui(task_name="Text2Music"):
 75    # For many consumer-grade GPU devices, only one batch can be run
 76    output_audio1 = gr.Audio(type="filepath", label=f"{task_name} Generated Audio 1")
 77    # output_audio2 = gr.Audio(type="filepath", label="Generated Audio 2")
 78    with gr.Accordion(f"{task_name} Parameters", open=False):
 79        input_params_json = gr.JSON(label=f"{task_name} Parameters")
 80    # outputs = [output_audio1, output_audio2]
 81    outputs = [output_audio1]
 82    return outputs, input_params_json
 83
 84
 85def dump_func(*args):
 86    print(args)
 87    return []
 88
 89
 90def create_text2music_ui(
 91    gr,
 92    text2music_process_func,
 93):
 94    with gr.Row(equal_height=True):
 95        # Get base output directory from environment variable, defaulting to CWD-relative 'outputs'.
 96        # This default (./outputs) is suitable for non-Docker local development.
 97        # For Docker, the ACE_OUTPUT_DIR environment variable should be set (e.g., to /app/outputs).
 98        output_file_dir = os.environ.get("ACE_OUTPUT_DIR", "./outputs")
 99        if not os.path.isdir(output_file_dir):
100            os.makedirs(output_file_dir, exist_ok=True)
101        json_files = [f for f in os.listdir(output_file_dir) if f.endswith(".json")]
102        json_files.sort(reverse=True, key=lambda x: int(x.split("_")[1]))
103        output_files = gr.Dropdown(choices=json_files, label="Select previous generated input params", scale=9, interactive=True)
104        load_bnt = gr.Button("Load", variant="primary", scale=1)
105
106    with gr.Row():
107        with gr.Column():
108            with gr.Row(equal_height=True):
109                # add markdown, tags and lyrics examples are from ai music generation community
110                audio_duration = gr.Slider(
111                    -1,
112                    240.0,
113                    step=0.00001,
114                    value=-1,
115                    label="Audio Duration",
116                    interactive=True,
117                    info="-1 means random duration (30 ~ 240).",
118                    scale=9,
119                )
120                format = gr.Dropdown(choices=["mp3", "ogg", "flac", "wav"], value="wav", label="Format")
121                sample_bnt = gr.Button("Sample", variant="secondary", scale=1)
122
123            # audio2audio
124            with gr.Row(equal_height=True):
125                audio2audio_enable = gr.Checkbox(
126                    label="Enable Audio2Audio", value=False, info="Check to enable Audio-to-Audio generation using a reference audio.", elem_id="audio2audio_checkbox"
127                )
128                lora_name_or_path = gr.Dropdown(
129                    label="Lora Name or Path", choices=["ACE-Step/ACE-Step-v1-chinese-rap-LoRA", "none"], value="none", allow_custom_value=True, min_width=300
130                )
131                lora_weight = gr.Number(value=1.0, label="Lora weight", step=0.1, maximum=3, minimum=-3)
132
133            ref_audio_input = gr.Audio(
134                type="filepath",
135                label="Reference Audio (for Audio2Audio)",
136                visible=False,
137                elem_id="ref_audio_input",
138                # show_download_button=True,
139            )
140            ref_audio_strength = gr.Slider(
141                label="Refer audio strength",
142                minimum=0.0,
143                maximum=1.0,
144                step=0.01,
145                value=0.5,
146                elem_id="ref_audio_strength",
147                visible=False,
148                interactive=True,
149            )
150
151            def toggle_ref_audio_visibility(is_checked):
152                return (
153                    gr.update(visible=is_checked, elem_id="ref_audio_input"),
154                    gr.update(visible=is_checked, elem_id="ref_audio_strength"),
155                )
156
157            audio2audio_enable.change(
158                fn=toggle_ref_audio_visibility,
159                inputs=[audio2audio_enable],
160                outputs=[ref_audio_input, ref_audio_strength],
161            )
162
163            with gr.Column(scale=2):
164                with gr.Group():
165                    gr.Markdown(
166                        """<center>Support tags, descriptions, and scene. Use commas to separate different tags.<br>Tags and lyrics examples are from AI music generation community.</center>"""
167                    )
168                    with gr.Row():
169                        genre_preset = gr.Dropdown(
170                            choices=["Custom"] + list(GENRE_PRESETS.keys()),
171                            value="Custom",
172                            label="Preset",
173                            scale=1,
174                        )
175                        prompt = gr.Textbox(
176                            lines=1,
177                            label="Tags",
178                            max_lines=4,
179                            value=TAG_DEFAULT,
180                            scale=9,
181                        )
182
183            # Add the change event for the preset dropdown
184            genre_preset.change(fn=update_tags_from_preset, inputs=[genre_preset], outputs=[prompt])
185            with gr.Group():
186                gr.Markdown(
187                    """<center>Support lyric structure tags like [verse], [chorus], and [bridge] to separate different parts of the lyrics.<br>Use [instrumental] or [inst] to generate instrumental music. Not support genre structure tag in lyrics</center>"""
188                )
189                lyrics = gr.Textbox(
190                    lines=9,
191                    label="Lyrics",
192                    max_lines=13,
193                    value=LYRIC_DEFAULT,
194                )
195
196            with gr.Accordion("Basic Settings", open=False):
197                infer_step = gr.Slider(
198                    minimum=1,
199                    maximum=200,
200                    step=1,
201                    value=60,
202                    label="Infer Steps",
203                    interactive=True,
204                )
205                guidance_scale = gr.Slider(
206                    minimum=0.0,
207                    maximum=30.0,
208                    step=0.1,
209                    value=15.0,
210                    label="Guidance Scale",
211                    interactive=True,
212                    info="When guidance_scale_lyric > 1 and guidance_scale_text > 1, the guidance scale will not be applied.",
213                )
214                guidance_scale_text = gr.Slider(
215                    minimum=0.0,
216                    maximum=10.0,
217                    step=0.1,
218                    value=0.0,
219                    label="Guidance Scale Text",
220                    interactive=True,
221                    info="Guidance scale for text condition. It can only apply to cfg. set guidance_scale_text=5.0, guidance_scale_lyric=1.5 for start",
222                )
223                guidance_scale_lyric = gr.Slider(
224                    minimum=0.0,
225                    maximum=10.0,
226                    step=0.1,
227                    value=0.0,
228                    label="Guidance Scale Lyric",
229                    interactive=True,
230                )
231
232                manual_seeds = gr.Textbox(
233                    label="manual seeds (default None)",
234                    placeholder="1,2,3,4",
235                    value=None,
236                    info="Seed for the generation",
237                )
238
239            with gr.Accordion("Advanced Settings", open=False):
240                scheduler_type = gr.Radio(
241                    ["euler", "heun", "pingpong"],
242                    value="euler",
243                    label="Scheduler Type",
244                    elem_id="scheduler_type",
245                    info="Scheduler type for the generation. euler is recommended. heun will take more time. pingpong use SDE",
246                )
247                cfg_type = gr.Radio(
248                    ["cfg", "apg", "cfg_star"],
249                    value="apg",
250                    label="CFG Type",
251                    elem_id="cfg_type",
252                    info="CFG type for the generation. apg is recommended. cfg and cfg_star are almost the same.",
253                )
254                use_erg_tag = gr.Checkbox(
255                    label="use ERG for tag",
256                    value=True,
257                    info="Use Entropy Rectifying Guidance for tag. It will multiple a temperature to the attention to make a weaker tag condition and make better diversity.",
258                )
259                use_erg_lyric = gr.Checkbox(
260                    label="use ERG for lyric",
261                    value=False,
262                    info="The same but apply to lyric encoder's attention.",
263                )
264                use_erg_diffusion = gr.Checkbox(
265                    label="use ERG for diffusion",
266                    value=True,
267                    info="The same but apply to diffusion model's attention.",
268                )
269
270                omega_scale = gr.Slider(
271                    minimum=-100.0,
272                    maximum=100.0,
273                    step=0.1,
274                    value=10.0,
275                    label="Granularity Scale",
276                    interactive=True,
277                    info="Granularity scale for the generation. Higher values can reduce artifacts",
278                )
279
280                guidance_interval = gr.Slider(
281                    minimum=0.0,
282                    maximum=1.0,
283                    step=0.01,
284                    value=0.5,
285                    label="Guidance Interval",
286                    interactive=True,
287                    info="Guidance interval for the generation. 0.5 means only apply guidance in the middle steps (0.25 * infer_steps to 0.75 * infer_steps)",
288                )
289                guidance_interval_decay = gr.Slider(
290                    minimum=0.0,
291                    maximum=1.0,
292                    step=0.01,
293                    value=0.0,
294                    label="Guidance Interval Decay",
295                    interactive=True,
296                    info="Guidance interval decay for the generation. Guidance scale will decay from guidance_scale to min_guidance_scale in the interval. 0.0 means no decay.",
297                )
298                min_guidance_scale = gr.Slider(
299                    minimum=0.0,
300                    maximum=200.0,
301                    step=0.1,
302                    value=3.0,
303                    label="Min Guidance Scale",
304                    interactive=True,
305                    info="Min guidance scale for guidance interval decay's end scale",
306                )
307                oss_steps = gr.Textbox(
308                    label="OSS Steps",
309                    placeholder="16, 29, 52, 96, 129, 158, 172, 183, 189, 200",
310                    value=None,
311                    info="Optimal Steps for the generation. But not test well",
312                )
313
314            text2music_bnt = gr.Button("Generate", variant="primary")
315
316        with gr.Column():
317            outputs, input_params_json = create_output_ui()
318            with gr.Tab("retake"):
319                retake_variance = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.2, label="variance")
320                retake_seeds = gr.Textbox(label="retake seeds (default None)", placeholder="", value=None)
321                retake_bnt = gr.Button("Retake", variant="primary")
322                retake_outputs, retake_input_params_json = create_output_ui("Retake")
323
324                def retake_process_func(json_data, retake_variance, retake_seeds):
325                    return text2music_process_func(
326                        json_data["format"],
327                        json_data["audio_duration"],
328                        json_data["prompt"],
329                        json_data["lyrics"],
330                        json_data["infer_step"],
331                        json_data["guidance_scale"],
332                        json_data["scheduler_type"],
333                        json_data["cfg_type"],
334                        json_data["omega_scale"],
335                        ", ".join(map(str, json_data["actual_seeds"])),
336                        json_data["guidance_interval"],
337                        json_data["guidance_interval_decay"],
338                        json_data["min_guidance_scale"],
339                        json_data["use_erg_tag"],
340                        json_data["use_erg_lyric"],
341                        json_data["use_erg_diffusion"],
342                        ", ".join(map(str, json_data["oss_steps"])),
343                        (json_data["guidance_scale_text"] if "guidance_scale_text" in json_data else 0.0),
344                        (json_data["guidance_scale_lyric"] if "guidance_scale_lyric" in json_data else 0.0),
345                        retake_seeds=retake_seeds,
346                        retake_variance=retake_variance,
347                        task="retake",
348                        lora_name_or_path="none" if "lora_name_or_path" not in json_data else json_data["lora_name_or_path"],
349                        lora_weight=1 if "lora_weight" not in json_data else json_data["lora_weight"],
350                    )
351
352                retake_bnt.click(
353                    fn=retake_process_func,
354                    inputs=[
355                        input_params_json,
356                        retake_variance,
357                        retake_seeds,
358                    ],
359                    outputs=retake_outputs + [retake_input_params_json],
360                )
361            with gr.Tab("repainting"):
362                retake_variance = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.2, label="variance")
363                retake_seeds = gr.Textbox(label="repaint seeds (default None)", placeholder="", value=None)
364                repaint_start = gr.Slider(
365                    minimum=0.0,
366                    maximum=240.0,
367                    step=0.01,
368                    value=0.0,
369                    label="Repaint Start Time",
370                    interactive=True,
371                )
372                repaint_end = gr.Slider(
373                    minimum=0.0,
374                    maximum=240.0,
375                    step=0.01,
376                    value=30.0,
377                    label="Repaint End Time",
378                    interactive=True,
379                )
380                repaint_source = gr.Radio(
381                    ["text2music", "last_repaint", "upload"],
382                    value="text2music",
383                    label="Repaint Source",
384                    elem_id="repaint_source",
385                )
386
387                repaint_source_audio_upload = gr.Audio(
388                    label="Upload Audio",
389                    type="filepath",
390                    visible=False,
391                    elem_id="repaint_source_audio_upload",
392                    # show_download_button=True,
393                )
394                repaint_source.change(
395                    fn=lambda x: gr.update(visible=x == "upload", elem_id="repaint_source_audio_upload"),
396                    inputs=[repaint_source],
397                    outputs=[repaint_source_audio_upload],
398                )
399
400                repaint_bnt = gr.Button("Repaint", variant="primary")
401                repaint_outputs, repaint_input_params_json = create_output_ui("Repaint")
402
403                def repaint_process_func(
404                    text2music_json_data,
405                    repaint_json_data,
406                    retake_variance,
407                    retake_seeds,
408                    repaint_start,
409                    repaint_end,
410                    repaint_source,
411                    repaint_source_audio_upload,
412                    prompt,
413                    lyrics,
414                    infer_step,
415                    guidance_scale,
416                    scheduler_type,
417                    cfg_type,
418                    omega_scale,
419                    manual_seeds,
420                    guidance_interval,
421                    guidance_interval_decay,
422                    min_guidance_scale,
423                    use_erg_tag,
424                    use_erg_lyric,
425                    use_erg_diffusion,
426                    oss_steps,
427                    guidance_scale_text,
428                    guidance_scale_lyric,
429                ):
430                    if repaint_source == "upload":
431                        src_audio_path = repaint_source_audio_upload
432                        audio_duration = librosa.get_duration(filename=src_audio_path)
433                        json_data = {"audio_duration": audio_duration}
434                    elif repaint_source == "text2music":
435                        json_data = text2music_json_data
436                        src_audio_path = json_data["audio_path"]
437                    elif repaint_source == "last_repaint":
438                        json_data = repaint_json_data
439                        src_audio_path = json_data["audio_path"]
440
441                    return text2music_process_func(
442                        format.value,
443                        json_data["audio_duration"],
444                        prompt,
445                        lyrics,
446                        infer_step,
447                        guidance_scale,
448                        scheduler_type,
449                        cfg_type,
450                        omega_scale,
451                        manual_seeds,
452                        guidance_interval,
453                        guidance_interval_decay,
454                        min_guidance_scale,
455                        use_erg_tag,
456                        use_erg_lyric,
457                        use_erg_diffusion,
458                        oss_steps,
459                        guidance_scale_text,
460                        guidance_scale_lyric,
461                        retake_seeds=retake_seeds,
462                        retake_variance=retake_variance,
463                        task="repaint",
464                        repaint_start=repaint_start,
465                        repaint_end=repaint_end,
466                        src_audio_path=src_audio_path,
467                        lora_name_or_path="none" if "lora_name_or_path" not in json_data else json_data["lora_name_or_path"],
468                        lora_weight=1 if "lora_weight" not in json_data else json_data["lora_weight"],
469                    )
470
471                repaint_bnt.click(
472                    fn=repaint_process_func,
473                    inputs=[
474                        input_params_json,
475                        repaint_input_params_json,
476                        retake_variance,
477                        retake_seeds,
478                        repaint_start,
479                        repaint_end,
480                        repaint_source,
481                        repaint_source_audio_upload,
482                        prompt,
483                        lyrics,
484                        infer_step,
485                        guidance_scale,
486                        scheduler_type,
487                        cfg_type,
488                        omega_scale,
489                        manual_seeds,
490                        guidance_interval,
491                        guidance_interval_decay,
492                        min_guidance_scale,
493                        use_erg_tag,
494                        use_erg_lyric,
495                        use_erg_diffusion,
496                        oss_steps,
497                        guidance_scale_text,
498                        guidance_scale_lyric,
499                    ],
500                    outputs=repaint_outputs + [repaint_input_params_json],
501                )
502            with gr.Tab("edit"):
503                edit_prompt = gr.Textbox(lines=2, label="Edit Tags", max_lines=4)
504                edit_lyrics = gr.Textbox(lines=9, label="Edit Lyrics", max_lines=13)
505                retake_seeds = gr.Textbox(label="edit seeds (default None)", placeholder="", value=None)
506
507                edit_type = gr.Radio(
508                    ["only_lyrics", "remix"],
509                    value="only_lyrics",
510                    label="Edit Type",
511                    elem_id="edit_type",
512                    info="`only_lyrics` will keep the whole song the same except lyrics difference. Make your diffrence smaller, e.g. one lyrc line change.\nremix can change the song melody and genre",
513                )
514                edit_n_min = gr.Slider(
515                    minimum=0.0,
516                    maximum=1.0,
517                    step=0.01,
518                    value=0.6,
519                    label="edit_n_min",
520                    interactive=True,
521                )
522                edit_n_max = gr.Slider(
523                    minimum=0.0,
524                    maximum=1.0,
525                    step=0.01,
526                    value=1.0,
527                    label="edit_n_max",
528                    interactive=True,
529                )
530
531                def edit_type_change_func(edit_type):
532                    if edit_type == "only_lyrics":
533                        n_min = 0.6
534                        n_max = 1.0
535                    elif edit_type == "remix":
536                        n_min = 0.2
537                        n_max = 0.4
538                    return n_min, n_max
539
540                edit_type.change(
541                    edit_type_change_func,
542                    inputs=[edit_type],
543                    outputs=[edit_n_min, edit_n_max],
544                )
545
546                edit_source = gr.Radio(
547                    ["text2music", "last_edit", "upload"],
548                    value="text2music",
549                    label="Edit Source",
550                    elem_id="edit_source",
551                )
552                edit_source_audio_upload = gr.Audio(
553                    label="Upload Audio",
554                    type="filepath",
555                    visible=False,
556                    elem_id="edit_source_audio_upload",
557                    # show_download_button=True,
558                )
559                edit_source.change(
560                    fn=lambda x: gr.update(visible=x == "upload", elem_id="edit_source_audio_upload"),
561                    inputs=[edit_source],
562                    outputs=[edit_source_audio_upload],
563                )
564
565                edit_bnt = gr.Button("Edit", variant="primary")
566                edit_outputs, edit_input_params_json = create_output_ui("Edit")
567
568                def edit_process_func(
569                    text2music_json_data,
570                    edit_input_params_json,
571                    edit_source,
572                    edit_source_audio_upload,
573                    prompt,
574                    lyrics,
575                    edit_prompt,
576                    edit_lyrics,
577                    edit_n_min,
578                    edit_n_max,
579                    infer_step,
580                    guidance_scale,
581                    scheduler_type,
582                    cfg_type,
583                    omega_scale,
584                    manual_seeds,
585                    guidance_interval,
586                    guidance_interval_decay,
587                    min_guidance_scale,
588                    use_erg_tag,
589                    use_erg_lyric,
590                    use_erg_diffusion,
591                    oss_steps,
592                    guidance_scale_text,
593                    guidance_scale_lyric,
594                    retake_seeds,
595                ):
596                    if edit_source == "upload":
597                        src_audio_path = edit_source_audio_upload
598                        audio_duration = librosa.get_duration(filename=src_audio_path)
599                        json_data = {"audio_duration": audio_duration}
600                    elif edit_source == "text2music":
601                        json_data = text2music_json_data
602                        src_audio_path = json_data["audio_path"]
603                    elif edit_source == "last_edit":
604                        json_data = edit_input_params_json
605                        src_audio_path = json_data["audio_path"]
606
607                    if not edit_prompt:
608                        edit_prompt = prompt
609                    if not edit_lyrics:
610                        edit_lyrics = lyrics
611
612                    return text2music_process_func(
613                        format.value,
614                        json_data["audio_duration"],
615                        prompt,
616                        lyrics,
617                        infer_step,
618                        guidance_scale,
619                        scheduler_type,
620                        cfg_type,
621                        omega_scale,
622                        manual_seeds,
623                        guidance_interval,
624                        guidance_interval_decay,
625                        min_guidance_scale,
626                        use_erg_tag,
627                        use_erg_lyric,
628                        use_erg_diffusion,
629                        oss_steps,
630                        guidance_scale_text,
631                        guidance_scale_lyric,
632                        task="edit",
633                        src_audio_path=src_audio_path,
634                        edit_target_prompt=edit_prompt,
635                        edit_target_lyrics=edit_lyrics,
636                        edit_n_min=edit_n_min,
637                        edit_n_max=edit_n_max,
638                        retake_seeds=retake_seeds,
639                        lora_name_or_path="none" if "lora_name_or_path" not in json_data else json_data["lora_name_or_path"],
640                        lora_weight=1 if "lora_weight" not in json_data else json_data["lora_weight"],
641                    )
642
643                edit_bnt.click(
644                    fn=edit_process_func,
645                    inputs=[
646                        input_params_json,
647                        edit_input_params_json,
648                        edit_source,
649                        edit_source_audio_upload,
650                        prompt,
651                        lyrics,
652                        edit_prompt,
653                        edit_lyrics,
654                        edit_n_min,
655                        edit_n_max,
656                        infer_step,
657                        guidance_scale,
658                        scheduler_type,
659                        cfg_type,
660                        omega_scale,
661                        manual_seeds,
662                        guidance_interval,
663                        guidance_interval_decay,
664                        min_guidance_scale,
665                        use_erg_tag,
666                        use_erg_lyric,
667                        use_erg_diffusion,
668                        oss_steps,
669                        guidance_scale_text,
670                        guidance_scale_lyric,
671                        retake_seeds,
672                    ],
673                    outputs=edit_outputs + [edit_input_params_json],
674                )
675            with gr.Tab("extend"):
676                extend_seeds = gr.Textbox(label="extend seeds (default None)", placeholder="", value=None)
677                left_extend_length = gr.Slider(
678                    minimum=0.0,
679                    maximum=240.0,
680                    step=0.01,
681                    value=0.0,
682                    label="Left Extend Length",
683                    interactive=True,
684                )
685                right_extend_length = gr.Slider(
686                    minimum=0.0,
687                    maximum=240.0,
688                    step=0.01,
689                    value=30.0,
690                    label="Right Extend Length",
691                    interactive=True,
692                )
693                extend_source = gr.Radio(
694                    ["text2music", "last_extend", "upload"],
695                    value="text2music",
696                    label="Extend Source",
697                    elem_id="extend_source",
698                )
699
700                extend_source_audio_upload = gr.Audio(
701                    label="Upload Audio",
702                    type="filepath",
703                    visible=False,
704                    elem_id="extend_source_audio_upload",
705                    # show_download_button=True,
706                )
707                extend_source.change(
708                    fn=lambda x: gr.update(visible=x == "upload", elem_id="extend_source_audio_upload"),
709                    inputs=[extend_source],
710                    outputs=[extend_source_audio_upload],
711                )
712
713                extend_bnt = gr.Button("Extend", variant="primary")
714                extend_outputs, extend_input_params_json = create_output_ui("Extend")
715
716                def extend_process_func(
717                    text2music_json_data,
718                    extend_input_params_json,
719                    extend_seeds,
720                    left_extend_length,
721                    right_extend_length,
722                    extend_source,
723                    extend_source_audio_upload,
724                    prompt,
725                    lyrics,
726                    infer_step,
727                    guidance_scale,
728                    scheduler_type,
729                    cfg_type,
730                    omega_scale,
731                    manual_seeds,
732                    guidance_interval,
733                    guidance_interval_decay,
734                    min_guidance_scale,
735                    use_erg_tag,
736                    use_erg_lyric,
737                    use_erg_diffusion,
738                    oss_steps,
739                    guidance_scale_text,
740                    guidance_scale_lyric,
741                ):
742                    if extend_source == "upload":
743                        src_audio_path = extend_source_audio_upload
744                        # get audio duration
745                        audio_duration = librosa.get_duration(filename=src_audio_path)
746                        json_data = {"audio_duration": audio_duration}
747                    elif extend_source == "text2music":
748                        json_data = text2music_json_data
749                        src_audio_path = json_data["audio_path"]
750                    elif extend_source == "last_extend":
751                        json_data = extend_input_params_json
752                        src_audio_path = json_data["audio_path"]
753
754                    repaint_start = -left_extend_length
755                    repaint_end = json_data["audio_duration"] + right_extend_length
756                    return text2music_process_func(
757                        format.value,
758                        json_data["audio_duration"],
759                        prompt,
760                        lyrics,
761                        infer_step,
762                        guidance_scale,
763                        scheduler_type,
764                        cfg_type,
765                        omega_scale,
766                        manual_seeds,
767                        guidance_interval,
768                        guidance_interval_decay,
769                        min_guidance_scale,
770                        use_erg_tag,
771                        use_erg_lyric,
772                        use_erg_diffusion,
773                        oss_steps,
774                        guidance_scale_text,
775                        guidance_scale_lyric,
776                        retake_seeds=extend_seeds,
777                        retake_variance=1.0,
778                        task="extend",
779                        repaint_start=repaint_start,
780                        repaint_end=repaint_end,
781                        src_audio_path=src_audio_path,
782                        lora_name_or_path=("none" if "lora_name_or_path" not in json_data else json_data["lora_name_or_path"]),
783                        lora_weight=(1 if "lora_weight" not in json_data else json_data["lora_weight"]),
784                    )
785
786                extend_bnt.click(
787                    fn=extend_process_func,
788                    inputs=[
789                        input_params_json,
790                        extend_input_params_json,
791                        extend_seeds,
792                        left_extend_length,
793                        right_extend_length,
794                        extend_source,
795                        extend_source_audio_upload,
796                        prompt,
797                        lyrics,
798                        infer_step,
799                        guidance_scale,
800                        scheduler_type,
801                        cfg_type,
802                        omega_scale,
803                        manual_seeds,
804                        guidance_interval,
805                        guidance_interval_decay,
806                        min_guidance_scale,
807                        use_erg_tag,
808                        use_erg_lyric,
809                        use_erg_diffusion,
810                        oss_steps,
811                        guidance_scale_text,
812                        guidance_scale_lyric,
813                    ],
814                    outputs=extend_outputs + [extend_input_params_json],
815                )
816
817        def json2output(json_data):
818            return (
819                json_data["audio_duration"],
820                json_data["prompt"],
821                json_data["lyrics"],
822                json_data["infer_step"],
823                json_data["guidance_scale"],
824                json_data["scheduler_type"],
825                json_data["cfg_type"],
826                json_data["omega_scale"],
827                ", ".join(map(str, json_data["actual_seeds"])),
828                json_data["guidance_interval"],
829                json_data["guidance_interval_decay"],
830                json_data["min_guidance_scale"],
831                json_data["use_erg_tag"],
832                json_data["use_erg_lyric"],
833                json_data["use_erg_diffusion"],
834                ", ".join(map(str, json_data["oss_steps"])),
835                (json_data["guidance_scale_text"] if "guidance_scale_text" in json_data else 0.0),
836                (json_data["guidance_scale_lyric"] if "guidance_scale_lyric" in json_data else 0.0),
837                (json_data["audio2audio_enable"] if "audio2audio_enable" in json_data else False),
838                (json_data["ref_audio_strength"] if "ref_audio_strength" in json_data else 0.5),
839                (json_data["ref_audio_input"] if "ref_audio_input" in json_data else None),
840            )
841
842        # def sample_data(lora_name_or_path_):
843        #     json_data = sample_data_func(lora_name_or_path_)
844        #     return json2output(json_data)
845
846        # sample_bnt.click(
847        #     sample_data,
848        #     inputs=[lora_name_or_path],
849        #     outputs=[
850        #         audio_duration,
851        #         prompt,
852        #         lyrics,
853        #         infer_step,
854        #         guidance_scale,
855        #         scheduler_type,
856        #         cfg_type,
857        #         omega_scale,
858        #         manual_seeds,
859        #         guidance_interval,
860        #         guidance_interval_decay,
861        #         min_guidance_scale,
862        #         use_erg_tag,
863        #         use_erg_lyric,
864        #         use_erg_diffusion,
865        #         oss_steps,
866        #         guidance_scale_text,
867        #         guidance_scale_lyric,
868        #         audio2audio_enable,
869        #         ref_audio_strength,
870        #         ref_audio_input,
871        #     ],
872        # )
873
874        # def load_data(json_file):
875        #     if isinstance(output_file_dir, str):
876        #         json_file = os.path.join(output_file_dir, json_file)
877        #     # json_data = load_data_func(json_file)
878        #     # return json2output(json_data)
879
880        # load_bnt.click(
881        #     fn=load_data,
882        #     inputs=[output_files],
883        #     outputs=[
884        #         audio_duration,
885        #         prompt,
886        #         lyrics,
887        #         infer_step,
888        #         guidance_scale,
889        #         scheduler_type,
890        #         cfg_type,
891        #         omega_scale,
892        #         manual_seeds,
893        #         guidance_interval,
894        #         guidance_interval_decay,
895        #         min_guidance_scale,
896        #         use_erg_tag,
897        #         use_erg_lyric,
898        #         use_erg_diffusion,
899        #         oss_steps,
900        #         guidance_scale_text,
901        #         guidance_scale_lyric,
902        #         audio2audio_enable,
903        #         ref_audio_strength,
904        #         ref_audio_input,
905        #     ],
906        # )
907
908    text2music_bnt.click(
909        fn=text2music_process_func,
910        inputs=[
911            format,
912            audio_duration,
913            prompt,
914            lyrics,
915            infer_step,
916            guidance_scale,
917            scheduler_type,
918            cfg_type,
919            omega_scale,
920            manual_seeds,
921            guidance_interval,
922            guidance_interval_decay,
923            min_guidance_scale,
924            use_erg_tag,
925            use_erg_lyric,
926            use_erg_diffusion,
927            oss_steps,
928            guidance_scale_text,
929            guidance_scale_lyric,
930            audio2audio_enable,
931            ref_audio_strength,
932            ref_audio_input,
933            lora_name_or_path,
934            lora_weight,
935        ],
936        outputs=outputs + [input_params_json],
937    )
938
939
def create_main_demo_ui(text2music_process_func=dump_func):
    """Assemble the top-level Gradio demo page.

    Args:
        text2music_process_func: Callback handed to ``create_text2music_ui``;
            defaults to ``dump_func``.

    Returns:
        The ``gr.Blocks`` app holding a heading and a single "text2music" tab.
    """
    page_title = "ACE-Step Model 1.0 DEMO"
    # Spelled out with explicit whitespace so the rendered markdown text is unchanged.
    heading_markdown = (
        "\n"
        '            <h1 style="text-align: center;">ACE-Step: A Step Towards Music Generation Foundation Model</h1>\n'
        "        "
    )

    with gr.Blocks(title=page_title) as demo:
        gr.Markdown(heading_markdown)
        with gr.Tab("text2music"):
            create_text2music_ui(gr=gr, text2music_process_func=text2music_process_func)
    return demo
957
958
if __name__ == "__main__":
    # Script entry point: build the demo and serve it on all interfaces, port 7860.
    demo = create_main_demo_ui()
    launch_options = {"server_name": "0.0.0.0", "server_port": 7860}
    demo.launch(**launch_options)
# Default value of the "Tags" textbox.
TAG_DEFAULT = (
    "funk, pop, soul, rock, melodic, guitar, drums, bass, keyboard, percussion, "
    "105 BPM, energetic, upbeat, groovy, vibrant, dynamic"
)

# Default value of the "Lyrics" textbox; stanzas are introduced by structure tags.
LYRIC_DEFAULT = """[verse]
Neon lights they flicker bright
City hums in dead of night
Rhythms pulse through concrete veins
Lost in echoes of refrains

[verse]
Bassline groovin' in my chest
Heartbeats match the city's zest
Electric whispers fill the air
Synthesized dreams everywhere

[chorus]
Turn it up and let it flow
Feel the fire let it grow
In this rhythm we belong
Hear the night sing out our song

[verse]
Guitar strings they start to weep
Wake the soul from silent sleep
Every note a story told
In this night we’re bold and gold

[bridge]
Voices blend in harmony
Lost in pure cacophony
Timeless echoes timeless cries
Soulful shouts beneath the skies

[verse]
Keyboard dances on the keys
Melodies on evening breeze
Catch the tune and hold it tight
In this moment we take flight
"""

# Genre preset name -> comma-separated tag string used to fill the "Tags" textbox.
GENRE_PRESETS = {
    "Modern Pop": "pop, synth, drums, guitar, 120 bpm, upbeat, catchy, vibrant, female vocals, polished vocals",
    "Rock": "rock, electric guitar, drums, bass, 130 bpm, energetic, rebellious, gritty, male vocals, raw vocals",
    "Hip Hop": "hip hop, 808 bass, hi-hats, synth, 90 bpm, bold, urban, intense, male vocals, rhythmic vocals",
    "Country": "country, acoustic guitar, steel guitar, fiddle, 100 bpm, heartfelt, rustic, warm, male vocals, twangy vocals",
    "EDM": "edm, synth, bass, kick drum, 128 bpm, euphoric, pulsating, energetic, instrumental",
    "Reggae": "reggae, guitar, bass, drums, 80 bpm, chill, soulful, positive, male vocals, smooth vocals",
    "Classical": "classical, orchestral, strings, piano, 60 bpm, elegant, emotive, timeless, instrumental",
    "Jazz": "jazz, saxophone, piano, double bass, 110 bpm, smooth, improvisational, soulful, male vocals, crooning vocals",
    "Metal": "metal, electric guitar, double kick drum, bass, 160 bpm, aggressive, intense, heavy, male vocals, screamed vocals",
    "R&B": "r&b, synth, bass, drums, 85 bpm, sultry, groovy, romantic, female vocals, silky vocals",
}
def update_tags_from_preset(preset_name: str) -> str:
    """Return the tag string that belongs to a genre preset.

    Args:
        preset_name: Label selected in the "Preset" dropdown.

    Returns:
        The ``GENRE_PRESETS`` entry for ``preset_name``; an empty string for
        ``"Custom"`` or for any name that is not a known preset.
    """
    # "Custom" always clears the tags, whatever GENRE_PRESETS contains.
    if preset_name == "Custom":
        return ""
    return GENRE_PRESETS.get(preset_name, "")
def create_output_ui(task_name="Text2Music"):
    """Create the output widgets for one task.

    Args:
        task_name: Prefix used in the widget labels.

    Returns:
        A ``(outputs, input_params_json)`` pair: ``outputs`` is a one-element
        list holding the generated-audio ``gr.Audio`` component, and
        ``input_params_json`` is the ``gr.JSON`` component placed inside a
        collapsed "Parameters" accordion.
    """
    # For many consumer-grade GPU devices, only one batch can be run,
    # so a single audio output is exposed.
    output_audio1 = gr.Audio(type="filepath", label=f"{task_name} Generated Audio 1")
    with gr.Accordion(f"{task_name} Parameters", open=False):
        input_params_json = gr.JSON(label=f"{task_name} Parameters")
    outputs = [output_audio1]
    return outputs, input_params_json
def dump_func(*args):
    """Placeholder callback: print the received arguments and return an empty list.

    Args:
        *args: Any positional arguments; they are only printed.

    Returns:
        An empty list.
    """
    print(args)
    return []
def create_text2music_ui(gr, text2music_process_func):
 91def create_text2music_ui(
 92    gr,
 93    text2music_process_func,
 94):
 95    with gr.Row(equal_height=True):
 96        # Get base output directory from environment variable, defaulting to CWD-relative 'outputs'.
 97        # This default (./outputs) is suitable for non-Docker local development.
 98        # For Docker, the ACE_OUTPUT_DIR environment variable should be set (e.g., to /app/outputs).
 99        output_file_dir = os.environ.get("ACE_OUTPUT_DIR", "./outputs")
100        if not os.path.isdir(output_file_dir):
101            os.makedirs(output_file_dir, exist_ok=True)
102        json_files = [f for f in os.listdir(output_file_dir) if f.endswith(".json")]
103        json_files.sort(reverse=True, key=lambda x: int(x.split("_")[1]))
104        output_files = gr.Dropdown(choices=json_files, label="Select previous generated input params", scale=9, interactive=True)
105        load_bnt = gr.Button("Load", variant="primary", scale=1)
106
107    with gr.Row():
108        with gr.Column():
109            with gr.Row(equal_height=True):
110                # add markdown, tags and lyrics examples are from ai music generation community
111                audio_duration = gr.Slider(
112                    -1,
113                    240.0,
114                    step=0.00001,
115                    value=-1,
116                    label="Audio Duration",
117                    interactive=True,
118                    info="-1 means random duration (30 ~ 240).",
119                    scale=9,
120                )
121                format = gr.Dropdown(choices=["mp3", "ogg", "flac", "wav"], value="wav", label="Format")
122                sample_bnt = gr.Button("Sample", variant="secondary", scale=1)
123
124            # audio2audio
125            with gr.Row(equal_height=True):
126                audio2audio_enable = gr.Checkbox(
127                    label="Enable Audio2Audio", value=False, info="Check to enable Audio-to-Audio generation using a reference audio.", elem_id="audio2audio_checkbox"
128                )
129                lora_name_or_path = gr.Dropdown(
130                    label="Lora Name or Path", choices=["ACE-Step/ACE-Step-v1-chinese-rap-LoRA", "none"], value="none", allow_custom_value=True, min_width=300
131                )
132                lora_weight = gr.Number(value=1.0, label="Lora weight", step=0.1, maximum=3, minimum=-3)
133
134            ref_audio_input = gr.Audio(
135                type="filepath",
136                label="Reference Audio (for Audio2Audio)",
137                visible=False,
138                elem_id="ref_audio_input",
139                # show_download_button=True,
140            )
141            ref_audio_strength = gr.Slider(
142                label="Refer audio strength",
143                minimum=0.0,
144                maximum=1.0,
145                step=0.01,
146                value=0.5,
147                elem_id="ref_audio_strength",
148                visible=False,
149                interactive=True,
150            )
151
152            def toggle_ref_audio_visibility(is_checked):
153                return (
154                    gr.update(visible=is_checked, elem_id="ref_audio_input"),
155                    gr.update(visible=is_checked, elem_id="ref_audio_strength"),
156                )
157
158            audio2audio_enable.change(
159                fn=toggle_ref_audio_visibility,
160                inputs=[audio2audio_enable],
161                outputs=[ref_audio_input, ref_audio_strength],
162            )
163
164            with gr.Column(scale=2):
165                with gr.Group():
166                    gr.Markdown(
167                        """<center>Support tags, descriptions, and scene. Use commas to separate different tags.<br>Tags and lyrics examples are from AI music generation community.</center>"""
168                    )
169                    with gr.Row():
170                        genre_preset = gr.Dropdown(
171                            choices=["Custom"] + list(GENRE_PRESETS.keys()),
172                            value="Custom",
173                            label="Preset",
174                            scale=1,
175                        )
176                        prompt = gr.Textbox(
177                            lines=1,
178                            label="Tags",
179                            max_lines=4,
180                            value=TAG_DEFAULT,
181                            scale=9,
182                        )
183
184            # Add the change event for the preset dropdown
185            genre_preset.change(fn=update_tags_from_preset, inputs=[genre_preset], outputs=[prompt])
186            with gr.Group():
187                gr.Markdown(
188                    """<center>Support lyric structure tags like [verse], [chorus], and [bridge] to separate different parts of the lyrics.<br>Use [instrumental] or [inst] to generate instrumental music. Not support genre structure tag in lyrics</center>"""
189                )
190                lyrics = gr.Textbox(
191                    lines=9,
192                    label="Lyrics",
193                    max_lines=13,
194                    value=LYRIC_DEFAULT,
195                )
196
197            with gr.Accordion("Basic Settings", open=False):
198                infer_step = gr.Slider(
199                    minimum=1,
200                    maximum=200,
201                    step=1,
202                    value=60,
203                    label="Infer Steps",
204                    interactive=True,
205                )
206                guidance_scale = gr.Slider(
207                    minimum=0.0,
208                    maximum=30.0,
209                    step=0.1,
210                    value=15.0,
211                    label="Guidance Scale",
212                    interactive=True,
213                    info="When guidance_scale_lyric > 1 and guidance_scale_text > 1, the guidance scale will not be applied.",
214                )
215                guidance_scale_text = gr.Slider(
216                    minimum=0.0,
217                    maximum=10.0,
218                    step=0.1,
219                    value=0.0,
220                    label="Guidance Scale Text",
221                    interactive=True,
222                    info="Guidance scale for text condition. It can only apply to cfg. set guidance_scale_text=5.0, guidance_scale_lyric=1.5 for start",
223                )
224                guidance_scale_lyric = gr.Slider(
225                    minimum=0.0,
226                    maximum=10.0,
227                    step=0.1,
228                    value=0.0,
229                    label="Guidance Scale Lyric",
230                    interactive=True,
231                )
232
233                manual_seeds = gr.Textbox(
234                    label="manual seeds (default None)",
235                    placeholder="1,2,3,4",
236                    value=None,
237                    info="Seed for the generation",
238                )
239
240            with gr.Accordion("Advanced Settings", open=False):
241                scheduler_type = gr.Radio(
242                    ["euler", "heun", "pingpong"],
243                    value="euler",
244                    label="Scheduler Type",
245                    elem_id="scheduler_type",
246                    info="Scheduler type for the generation. euler is recommended. heun will take more time. pingpong use SDE",
247                )
248                cfg_type = gr.Radio(
249                    ["cfg", "apg", "cfg_star"],
250                    value="apg",
251                    label="CFG Type",
252                    elem_id="cfg_type",
253                    info="CFG type for the generation. apg is recommended. cfg and cfg_star are almost the same.",
254                )
255                use_erg_tag = gr.Checkbox(
256                    label="use ERG for tag",
257                    value=True,
258                    info="Use Entropy Rectifying Guidance for tag. It will multiple a temperature to the attention to make a weaker tag condition and make better diversity.",
259                )
260                use_erg_lyric = gr.Checkbox(
261                    label="use ERG for lyric",
262                    value=False,
263                    info="The same but apply to lyric encoder's attention.",
264                )
265                use_erg_diffusion = gr.Checkbox(
266                    label="use ERG for diffusion",
267                    value=True,
268                    info="The same but apply to diffusion model's attention.",
269                )
270
271                omega_scale = gr.Slider(
272                    minimum=-100.0,
273                    maximum=100.0,
274                    step=0.1,
275                    value=10.0,
276                    label="Granularity Scale",
277                    interactive=True,
278                    info="Granularity scale for the generation. Higher values can reduce artifacts",
279                )
280
281                guidance_interval = gr.Slider(
282                    minimum=0.0,
283                    maximum=1.0,
284                    step=0.01,
285                    value=0.5,
286                    label="Guidance Interval",
287                    interactive=True,
288                    info="Guidance interval for the generation. 0.5 means only apply guidance in the middle steps (0.25 * infer_steps to 0.75 * infer_steps)",
289                )
290                guidance_interval_decay = gr.Slider(
291                    minimum=0.0,
292                    maximum=1.0,
293                    step=0.01,
294                    value=0.0,
295                    label="Guidance Interval Decay",
296                    interactive=True,
297                    info="Guidance interval decay for the generation. Guidance scale will decay from guidance_scale to min_guidance_scale in the interval. 0.0 means no decay.",
298                )
299                min_guidance_scale = gr.Slider(
300                    minimum=0.0,
301                    maximum=200.0,
302                    step=0.1,
303                    value=3.0,
304                    label="Min Guidance Scale",
305                    interactive=True,
306                    info="Min guidance scale for guidance interval decay's end scale",
307                )
308                oss_steps = gr.Textbox(
309                    label="OSS Steps",
310                    placeholder="16, 29, 52, 96, 129, 158, 172, 183, 189, 200",
311                    value=None,
312                    info="Optimal Steps for the generation. But not test well",
313                )
314
315            text2music_bnt = gr.Button("Generate", variant="primary")
316
317        with gr.Column():
318            outputs, input_params_json = create_output_ui()
319            with gr.Tab("retake"):
320                retake_variance = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.2, label="variance")
321                retake_seeds = gr.Textbox(label="retake seeds (default None)", placeholder="", value=None)
322                retake_bnt = gr.Button("Retake", variant="primary")
323                retake_outputs, retake_input_params_json = create_output_ui("Retake")
324
325                def retake_process_func(json_data, retake_variance, retake_seeds):
326                    return text2music_process_func(
327                        json_data["format"],
328                        json_data["audio_duration"],
329                        json_data["prompt"],
330                        json_data["lyrics"],
331                        json_data["infer_step"],
332                        json_data["guidance_scale"],
333                        json_data["scheduler_type"],
334                        json_data["cfg_type"],
335                        json_data["omega_scale"],
336                        ", ".join(map(str, json_data["actual_seeds"])),
337                        json_data["guidance_interval"],
338                        json_data["guidance_interval_decay"],
339                        json_data["min_guidance_scale"],
340                        json_data["use_erg_tag"],
341                        json_data["use_erg_lyric"],
342                        json_data["use_erg_diffusion"],
343                        ", ".join(map(str, json_data["oss_steps"])),
344                        (json_data["guidance_scale_text"] if "guidance_scale_text" in json_data else 0.0),
345                        (json_data["guidance_scale_lyric"] if "guidance_scale_lyric" in json_data else 0.0),
346                        retake_seeds=retake_seeds,
347                        retake_variance=retake_variance,
348                        task="retake",
349                        lora_name_or_path="none" if "lora_name_or_path" not in json_data else json_data["lora_name_or_path"],
350                        lora_weight=1 if "lora_weight" not in json_data else json_data["lora_weight"],
351                    )
352
353                retake_bnt.click(
354                    fn=retake_process_func,
355                    inputs=[
356                        input_params_json,
357                        retake_variance,
358                        retake_seeds,
359                    ],
360                    outputs=retake_outputs + [retake_input_params_json],
361                )
362            with gr.Tab("repainting"):
363                retake_variance = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.2, label="variance")
364                retake_seeds = gr.Textbox(label="repaint seeds (default None)", placeholder="", value=None)
365                repaint_start = gr.Slider(
366                    minimum=0.0,
367                    maximum=240.0,
368                    step=0.01,
369                    value=0.0,
370                    label="Repaint Start Time",
371                    interactive=True,
372                )
373                repaint_end = gr.Slider(
374                    minimum=0.0,
375                    maximum=240.0,
376                    step=0.01,
377                    value=30.0,
378                    label="Repaint End Time",
379                    interactive=True,
380                )
381                repaint_source = gr.Radio(
382                    ["text2music", "last_repaint", "upload"],
383                    value="text2music",
384                    label="Repaint Source",
385                    elem_id="repaint_source",
386                )
387
388                repaint_source_audio_upload = gr.Audio(
389                    label="Upload Audio",
390                    type="filepath",
391                    visible=False,
392                    elem_id="repaint_source_audio_upload",
393                    # show_download_button=True,
394                )
395                repaint_source.change(
396                    fn=lambda x: gr.update(visible=x == "upload", elem_id="repaint_source_audio_upload"),
397                    inputs=[repaint_source],
398                    outputs=[repaint_source_audio_upload],
399                )
400
401                repaint_bnt = gr.Button("Repaint", variant="primary")
402                repaint_outputs, repaint_input_params_json = create_output_ui("Repaint")
403
404                def repaint_process_func(
405                    text2music_json_data,
406                    repaint_json_data,
407                    retake_variance,
408                    retake_seeds,
409                    repaint_start,
410                    repaint_end,
411                    repaint_source,
412                    repaint_source_audio_upload,
413                    prompt,
414                    lyrics,
415                    infer_step,
416                    guidance_scale,
417                    scheduler_type,
418                    cfg_type,
419                    omega_scale,
420                    manual_seeds,
421                    guidance_interval,
422                    guidance_interval_decay,
423                    min_guidance_scale,
424                    use_erg_tag,
425                    use_erg_lyric,
426                    use_erg_diffusion,
427                    oss_steps,
428                    guidance_scale_text,
429                    guidance_scale_lyric,
430                ):
431                    if repaint_source == "upload":
432                        src_audio_path = repaint_source_audio_upload
433                        audio_duration = librosa.get_duration(filename=src_audio_path)
434                        json_data = {"audio_duration": audio_duration}
435                    elif repaint_source == "text2music":
436                        json_data = text2music_json_data
437                        src_audio_path = json_data["audio_path"]
438                    elif repaint_source == "last_repaint":
439                        json_data = repaint_json_data
440                        src_audio_path = json_data["audio_path"]
441
442                    return text2music_process_func(
443                        format.value,
444                        json_data["audio_duration"],
445                        prompt,
446                        lyrics,
447                        infer_step,
448                        guidance_scale,
449                        scheduler_type,
450                        cfg_type,
451                        omega_scale,
452                        manual_seeds,
453                        guidance_interval,
454                        guidance_interval_decay,
455                        min_guidance_scale,
456                        use_erg_tag,
457                        use_erg_lyric,
458                        use_erg_diffusion,
459                        oss_steps,
460                        guidance_scale_text,
461                        guidance_scale_lyric,
462                        retake_seeds=retake_seeds,
463                        retake_variance=retake_variance,
464                        task="repaint",
465                        repaint_start=repaint_start,
466                        repaint_end=repaint_end,
467                        src_audio_path=src_audio_path,
468                        lora_name_or_path="none" if "lora_name_or_path" not in json_data else json_data["lora_name_or_path"],
469                        lora_weight=1 if "lora_weight" not in json_data else json_data["lora_weight"],
470                    )
471
472                repaint_bnt.click(
473                    fn=repaint_process_func,
474                    inputs=[
475                        input_params_json,
476                        repaint_input_params_json,
477                        retake_variance,
478                        retake_seeds,
479                        repaint_start,
480                        repaint_end,
481                        repaint_source,
482                        repaint_source_audio_upload,
483                        prompt,
484                        lyrics,
485                        infer_step,
486                        guidance_scale,
487                        scheduler_type,
488                        cfg_type,
489                        omega_scale,
490                        manual_seeds,
491                        guidance_interval,
492                        guidance_interval_decay,
493                        min_guidance_scale,
494                        use_erg_tag,
495                        use_erg_lyric,
496                        use_erg_diffusion,
497                        oss_steps,
498                        guidance_scale_text,
499                        guidance_scale_lyric,
500                    ],
501                    outputs=repaint_outputs + [repaint_input_params_json],
502                )
503            with gr.Tab("edit"):
504                edit_prompt = gr.Textbox(lines=2, label="Edit Tags", max_lines=4)
505                edit_lyrics = gr.Textbox(lines=9, label="Edit Lyrics", max_lines=13)
506                retake_seeds = gr.Textbox(label="edit seeds (default None)", placeholder="", value=None)
507
508                edit_type = gr.Radio(
509                    ["only_lyrics", "remix"],
510                    value="only_lyrics",
511                    label="Edit Type",
512                    elem_id="edit_type",
513                    info="`only_lyrics` will keep the whole song the same except lyrics difference. Make your diffrence smaller, e.g. one lyrc line change.\nremix can change the song melody and genre",
514                )
515                edit_n_min = gr.Slider(
516                    minimum=0.0,
517                    maximum=1.0,
518                    step=0.01,
519                    value=0.6,
520                    label="edit_n_min",
521                    interactive=True,
522                )
523                edit_n_max = gr.Slider(
524                    minimum=0.0,
525                    maximum=1.0,
526                    step=0.01,
527                    value=1.0,
528                    label="edit_n_max",
529                    interactive=True,
530                )
531
def edit_type_change_func(edit_type):
    """Return the preset ``(edit_n_min, edit_n_max)`` values for an edit type.

    Registered on ``edit_type.change`` so that switching the radio option
    resets both sliders to that option's preset.

    Args:
        edit_type: Selected "Edit Type" value, ``"only_lyrics"`` or
            ``"remix"``.

    Returns:
        A ``(n_min, n_max)`` tuple of floats.

    Raises:
        ValueError: If ``edit_type`` is not one of the known options. The
            previous implementation failed with ``UnboundLocalError`` on the
            return statement in that case.
    """
    presets = {
        "only_lyrics": (0.6, 1.0),
        "remix": (0.2, 0.4),
    }
    try:
        return presets[edit_type]
    except (KeyError, TypeError):
        # TypeError covers unhashable values such as a list.
        raise ValueError(f"Unknown edit type: {edit_type!r}") from None
540
541                edit_type.change(
542                    edit_type_change_func,
543                    inputs=[edit_type],
544                    outputs=[edit_n_min, edit_n_max],
545                )
546
547                edit_source = gr.Radio(
548                    ["text2music", "last_edit", "upload"],
549                    value="text2music",
550                    label="Edit Source",
551                    elem_id="edit_source",
552                )
553                edit_source_audio_upload = gr.Audio(
554                    label="Upload Audio",
555                    type="filepath",
556                    visible=False,
557                    elem_id="edit_source_audio_upload",
558                    # show_download_button=True,
559                )
560                edit_source.change(
561                    fn=lambda x: gr.update(visible=x == "upload", elem_id="edit_source_audio_upload"),
562                    inputs=[edit_source],
563                    outputs=[edit_source_audio_upload],
564                )
565
566                edit_bnt = gr.Button("Edit", variant="primary")
567                edit_outputs, edit_input_params_json = create_output_ui("Edit")
568
def edit_process_func(
    text2music_json_data,
    edit_input_params_json,
    edit_source,
    edit_source_audio_upload,
    prompt,
    lyrics,
    edit_prompt,
    edit_lyrics,
    edit_n_min,
    edit_n_max,
    infer_step,
    guidance_scale,
    scheduler_type,
    cfg_type,
    omega_scale,
    manual_seeds,
    guidance_interval,
    guidance_interval_decay,
    min_guidance_scale,
    use_erg_tag,
    use_erg_lyric,
    use_erg_diffusion,
    oss_steps,
    guidance_scale_text,
    guidance_scale_lyric,
    retake_seeds,
):
    """Resolve the source clip and call ``text2music_process_func`` with ``task="edit"``.

    The source is selected by ``edit_source``:

    * ``"upload"``: ``edit_source_audio_upload`` is the source path and its
      duration is measured with ``librosa.get_duration``.
    * ``"text2music"``: the source path and duration come from
      ``text2music_json_data``.
    * ``"last_edit"``: the source path and duration come from
      ``edit_input_params_json``.

    Args:
        text2music_json_data: Parameter dict of the text2music result; read
            for ``"audio_path"``, ``"audio_duration"`` and the optional LoRA
            keys.
        edit_input_params_json: Parameter dict of the previous edit result,
            read the same way.
        edit_source: One of ``"upload"``, ``"text2music"``, ``"last_edit"``.
        edit_source_audio_upload: File path of the uploaded audio.
        prompt: Original tags; also the edit target when ``edit_prompt`` is
            empty.
        lyrics: Original lyrics; also the edit target when ``edit_lyrics`` is
            empty.
        edit_prompt: Target tags for the edit.
        edit_lyrics: Target lyrics for the edit.
        edit_n_min: Forwarded as ``edit_n_min``.
        edit_n_max: Forwarded as ``edit_n_max``.
        retake_seeds: Forwarded as ``retake_seeds``.

        All remaining parameters are forwarded positionally, unchanged, to
        ``text2music_process_func``.

    Returns:
        Whatever ``text2music_process_func`` returns.

    Raises:
        gr.Error: If the selected source has no usable audio (no upload, or
            no parameter dict with an ``"audio_path"``), or if
            ``edit_source`` is not a known option. These cases previously
            failed with ``TypeError``/``KeyError``/``UnboundLocalError``.
    """
    if edit_source == "upload":
        if not edit_source_audio_upload:
            raise gr.Error("Please upload an audio file to use as the edit source.")
        src_audio_path = edit_source_audio_upload
        # An upload carries no saved parameters, so only its duration is known.
        json_data = {"audio_duration": librosa.get_duration(filename=src_audio_path)}
    elif edit_source in ("text2music", "last_edit"):
        json_data = text2music_json_data if edit_source == "text2music" else edit_input_params_json
        # Presumably the JSON component is empty until a run has finished — confirm.
        if not json_data or "audio_path" not in json_data:
            raise gr.Error(
                f"No '{edit_source}' result is available yet; generate one first or upload audio."
            )
        src_audio_path = json_data["audio_path"]
    else:
        raise gr.Error(f"Unknown edit source: {edit_source!r}")

    # Empty target fields fall back to the original tags / lyrics.
    edit_prompt = edit_prompt or prompt
    edit_lyrics = edit_lyrics or lyrics

    # NOTE(review): `format.value` is read from the component object rather
    # than received as a click input, so it is presumably the construction-time
    # default and not the live selection — confirm.
    return text2music_process_func(
        format.value,
        json_data["audio_duration"],
        prompt,
        lyrics,
        infer_step,
        guidance_scale,
        scheduler_type,
        cfg_type,
        omega_scale,
        manual_seeds,
        guidance_interval,
        guidance_interval_decay,
        min_guidance_scale,
        use_erg_tag,
        use_erg_lyric,
        use_erg_diffusion,
        oss_steps,
        guidance_scale_text,
        guidance_scale_lyric,
        task="edit",
        src_audio_path=src_audio_path,
        edit_target_prompt=edit_prompt,
        edit_target_lyrics=edit_lyrics,
        edit_n_min=edit_n_min,
        edit_n_max=edit_n_max,
        retake_seeds=retake_seeds,
        # Sources without LoRA settings fall back to "none" / weight 1.
        lora_name_or_path=json_data.get("lora_name_or_path", "none"),
        lora_weight=json_data.get("lora_weight", 1),
    )
643
644                edit_bnt.click(
645                    fn=edit_process_func,
646                    inputs=[
647                        input_params_json,
648                        edit_input_params_json,
649                        edit_source,
650                        edit_source_audio_upload,
651                        prompt,
652                        lyrics,
653                        edit_prompt,
654                        edit_lyrics,
655                        edit_n_min,
656                        edit_n_max,
657                        infer_step,
658                        guidance_scale,
659                        scheduler_type,
660                        cfg_type,
661                        omega_scale,
662                        manual_seeds,
663                        guidance_interval,
664                        guidance_interval_decay,
665                        min_guidance_scale,
666                        use_erg_tag,
667                        use_erg_lyric,
668                        use_erg_diffusion,
669                        oss_steps,
670                        guidance_scale_text,
671                        guidance_scale_lyric,
672                        retake_seeds,
673                    ],
674                    outputs=edit_outputs + [edit_input_params_json],
675                )
676            with gr.Tab("extend"):
677                extend_seeds = gr.Textbox(label="extend seeds (default None)", placeholder="", value=None)
678                left_extend_length = gr.Slider(
679                    minimum=0.0,
680                    maximum=240.0,
681                    step=0.01,
682                    value=0.0,
683                    label="Left Extend Length",
684                    interactive=True,
685                )
686                right_extend_length = gr.Slider(
687                    minimum=0.0,
688                    maximum=240.0,
689                    step=0.01,
690                    value=30.0,
691                    label="Right Extend Length",
692                    interactive=True,
693                )
694                extend_source = gr.Radio(
695                    ["text2music", "last_extend", "upload"],
696                    value="text2music",
697                    label="Extend Source",
698                    elem_id="extend_source",
699                )
700
701                extend_source_audio_upload = gr.Audio(
702                    label="Upload Audio",
703                    type="filepath",
704                    visible=False,
705                    elem_id="extend_source_audio_upload",
706                    # show_download_button=True,
707                )
708                extend_source.change(
709                    fn=lambda x: gr.update(visible=x == "upload", elem_id="extend_source_audio_upload"),
710                    inputs=[extend_source],
711                    outputs=[extend_source_audio_upload],
712                )
713
714                extend_bnt = gr.Button("Extend", variant="primary")
715                extend_outputs, extend_input_params_json = create_output_ui("Extend")
716
def extend_process_func(
    text2music_json_data,
    extend_input_params_json,
    extend_seeds,
    left_extend_length,
    right_extend_length,
    extend_source,
    extend_source_audio_upload,
    prompt,
    lyrics,
    infer_step,
    guidance_scale,
    scheduler_type,
    cfg_type,
    omega_scale,
    manual_seeds,
    guidance_interval,
    guidance_interval_decay,
    min_guidance_scale,
    use_erg_tag,
    use_erg_lyric,
    use_erg_diffusion,
    oss_steps,
    guidance_scale_text,
    guidance_scale_lyric,
):
    """Resolve the source clip and call ``text2music_process_func`` with ``task="extend"``.

    The source is selected by ``extend_source``:

    * ``"upload"``: ``extend_source_audio_upload`` is the source path and its
      duration is measured with ``librosa.get_duration``.
    * ``"text2music"``: the source path and duration come from
      ``text2music_json_data``.
    * ``"last_extend"``: the source path and duration come from
      ``extend_input_params_json``.

    The extension window is passed as ``repaint_start = -left_extend_length``
    and ``repaint_end = audio_duration + right_extend_length``.

    Args:
        text2music_json_data: Parameter dict of the text2music result; read
            for ``"audio_path"``, ``"audio_duration"`` and the optional LoRA
            keys.
        extend_input_params_json: Parameter dict of the previous extend
            result, read the same way.
        extend_seeds: Forwarded as ``retake_seeds``.
        left_extend_length: Length added before the clip; negated to form
            ``repaint_start``.
        right_extend_length: Length added after the clip; added to the source
            duration to form ``repaint_end``.
        extend_source: One of ``"upload"``, ``"text2music"``,
            ``"last_extend"``.
        extend_source_audio_upload: File path of the uploaded audio.

        All remaining parameters are forwarded positionally, unchanged, to
        ``text2music_process_func``.

    Returns:
        Whatever ``text2music_process_func`` returns.

    Raises:
        gr.Error: If the selected source has no usable audio (no upload, or
            no parameter dict with an ``"audio_path"``), or if
            ``extend_source`` is not a known option. These cases previously
            failed with ``TypeError``/``KeyError``/``UnboundLocalError``.
    """
    if extend_source == "upload":
        if not extend_source_audio_upload:
            raise gr.Error("Please upload an audio file to use as the extend source.")
        src_audio_path = extend_source_audio_upload
        # An upload carries no saved parameters, so only its duration is known.
        json_data = {"audio_duration": librosa.get_duration(filename=src_audio_path)}
    elif extend_source in ("text2music", "last_extend"):
        json_data = text2music_json_data if extend_source == "text2music" else extend_input_params_json
        # Presumably the JSON component is empty until a run has finished — confirm.
        if not json_data or "audio_path" not in json_data:
            raise gr.Error(
                f"No '{extend_source}' result is available yet; generate one first or upload audio."
            )
        src_audio_path = json_data["audio_path"]
    else:
        raise gr.Error(f"Unknown extend source: {extend_source!r}")

    source_duration = json_data["audio_duration"]
    # The window runs from before the clip start to past its end.
    repaint_start = -left_extend_length
    repaint_end = source_duration + right_extend_length

    # NOTE(review): `format.value` is read from the component object rather
    # than received as a click input, so it is presumably the construction-time
    # default and not the live selection — confirm.
    return text2music_process_func(
        format.value,
        source_duration,
        prompt,
        lyrics,
        infer_step,
        guidance_scale,
        scheduler_type,
        cfg_type,
        omega_scale,
        manual_seeds,
        guidance_interval,
        guidance_interval_decay,
        min_guidance_scale,
        use_erg_tag,
        use_erg_lyric,
        use_erg_diffusion,
        oss_steps,
        guidance_scale_text,
        guidance_scale_lyric,
        retake_seeds=extend_seeds,
        retake_variance=1.0,
        task="extend",
        repaint_start=repaint_start,
        repaint_end=repaint_end,
        src_audio_path=src_audio_path,
        # Sources without LoRA settings fall back to "none" / weight 1.
        lora_name_or_path=json_data.get("lora_name_or_path", "none"),
        lora_weight=json_data.get("lora_weight", 1),
    )
786
787                extend_bnt.click(
788                    fn=extend_process_func,
789                    inputs=[
790                        input_params_json,
791                        extend_input_params_json,
792                        extend_seeds,
793                        left_extend_length,
794                        right_extend_length,
795                        extend_source,
796                        extend_source_audio_upload,
797                        prompt,
798                        lyrics,
799                        infer_step,
800                        guidance_scale,
801                        scheduler_type,
802                        cfg_type,
803                        omega_scale,
804                        manual_seeds,
805                        guidance_interval,
806                        guidance_interval_decay,
807                        min_guidance_scale,
808                        use_erg_tag,
809                        use_erg_lyric,
810                        use_erg_diffusion,
811                        oss_steps,
812                        guidance_scale_text,
813                        guidance_scale_lyric,
814                    ],
815                    outputs=extend_outputs + [extend_input_params_json],
816                )
817
def json2output(json_data):
    """Flatten a parameter dict into a 21-item tuple of input values.

    The tuple order follows the ``outputs`` lists of the (currently
    commented-out) sample/load click handlers: generation settings first,
    then the optional guidance and reference-audio settings.

    Args:
        json_data: Dict holding the generation parameters. The seed and
            ``oss_steps`` entries must be iterables; they are rendered as
            ``", "``-joined strings.

    Returns:
        A tuple of values. The five trailing entries fall back to
        ``0.0``, ``0.0``, ``False``, ``0.5`` and ``None`` when their keys are
        absent.

    Raises:
        KeyError: If any of the mandatory keys is missing.
    """
    leading_keys = (
        "audio_duration",
        "prompt",
        "lyrics",
        "infer_step",
        "guidance_scale",
        "scheduler_type",
        "cfg_type",
        "omega_scale",
    )
    middle_keys = (
        "guidance_interval",
        "guidance_interval_decay",
        "min_guidance_scale",
        "use_erg_tag",
        "use_erg_lyric",
        "use_erg_diffusion",
    )
    optional_defaults = (
        ("guidance_scale_text", 0.0),
        ("guidance_scale_lyric", 0.0),
        ("audio2audio_enable", False),
        ("ref_audio_strength", 0.5),
        ("ref_audio_input", None),
    )

    # Built in output order so a missing key raises for the same field as before.
    leading = [json_data[key] for key in leading_keys]
    seeds_text = ", ".join(str(seed) for seed in json_data["actual_seeds"])
    middle = [json_data[key] for key in middle_keys]
    oss_steps_text = ", ".join(str(step) for step in json_data["oss_steps"])
    trailing = [json_data.get(key, fallback) for key, fallback in optional_defaults]

    return (*leading, seeds_text, *middle, oss_steps_text, *trailing)
842
843        # def sample_data(lora_name_or_path_):
844        #     json_data = sample_data_func(lora_name_or_path_)
845        #     return json2output(json_data)
846
847        # sample_bnt.click(
848        #     sample_data,
849        #     inputs=[lora_name_or_path],
850        #     outputs=[
851        #         audio_duration,
852        #         prompt,
853        #         lyrics,
854        #         infer_step,
855        #         guidance_scale,
856        #         scheduler_type,
857        #         cfg_type,
858        #         omega_scale,
859        #         manual_seeds,
860        #         guidance_interval,
861        #         guidance_interval_decay,
862        #         min_guidance_scale,
863        #         use_erg_tag,
864        #         use_erg_lyric,
865        #         use_erg_diffusion,
866        #         oss_steps,
867        #         guidance_scale_text,
868        #         guidance_scale_lyric,
869        #         audio2audio_enable,
870        #         ref_audio_strength,
871        #         ref_audio_input,
872        #     ],
873        # )
874
875        # def load_data(json_file):
876        #     if isinstance(output_file_dir, str):
877        #         json_file = os.path.join(output_file_dir, json_file)
878        #     # json_data = load_data_func(json_file)
879        #     # return json2output(json_data)
880
881        # load_bnt.click(
882        #     fn=load_data,
883        #     inputs=[output_files],
884        #     outputs=[
885        #         audio_duration,
886        #         prompt,
887        #         lyrics,
888        #         infer_step,
889        #         guidance_scale,
890        #         scheduler_type,
891        #         cfg_type,
892        #         omega_scale,
893        #         manual_seeds,
894        #         guidance_interval,
895        #         guidance_interval_decay,
896        #         min_guidance_scale,
897        #         use_erg_tag,
898        #         use_erg_lyric,
899        #         use_erg_diffusion,
900        #         oss_steps,
901        #         guidance_scale_text,
902        #         guidance_scale_lyric,
903        #         audio2audio_enable,
904        #         ref_audio_strength,
905        #         ref_audio_input,
906        #     ],
907        # )
908
909    text2music_bnt.click(
910        fn=text2music_process_func,
911        inputs=[
912            format,
913            audio_duration,
914            prompt,
915            lyrics,
916            infer_step,
917            guidance_scale,
918            scheduler_type,
919            cfg_type,
920            omega_scale,
921            manual_seeds,
922            guidance_interval,
923            guidance_interval_decay,
924            min_guidance_scale,
925            use_erg_tag,
926            use_erg_lyric,
927            use_erg_diffusion,
928            oss_steps,
929            guidance_scale_text,
930            guidance_scale_lyric,
931            audio2audio_enable,
932            ref_audio_strength,
933            ref_audio_input,
934            lora_name_or_path,
935            lora_weight,
936        ],
937        outputs=outputs + [input_params_json],
938    )
def create_main_demo_ui(text2music_process_func=<function dump_func>):
def create_main_demo_ui(
    text2music_process_func=dump_func,
):
    """Assemble the top-level Gradio Blocks app for the demo.

    The page consists of a centred HTML heading followed by a single
    "text2music" tab whose contents are built by ``create_text2music_ui``.

    Args:
        text2music_process_func: Callable forwarded to
            ``create_text2music_ui``; defaults to ``dump_func``.

    Returns:
        The ``gr.Blocks`` object that was built.
    """
    page_title = "ACE-Step Model 1.0 DEMO"
    # Same text, including the surrounding whitespace, as the original
    # triple-quoted literal.
    heading_markdown = (
        "\n"
        '            <h1 style="text-align: center;">'
        "ACE-Step: A Step Towards Music Generation Foundation Model</h1>\n"
        "        "
    )

    with gr.Blocks(title=page_title) as demo:
        gr.Markdown(heading_markdown)
        with gr.Tab("text2music"):
            create_text2music_ui(
                gr=gr,
                text2music_process_func=text2music_process_func,
            )

    return demo
954                gr=gr,
955                text2music_process_func=text2music_process_func,
956            )
957    return demo