divisor.acestep.gradio

 1# SPDX-License-Identifier: Apache-2.0
 2# adapted from https://github.com/ace-step/ACE-Step
 3
 4
 5import os
 6
 7
 8def main(
 9    checkpoint_path="",
10    server_name="127.0.0.1",
11    port=7865,
12    share=False,
13    bf16=True,
14    torch_compile=False,
15    cpu_offload=False,
16    overlapped_decode=False,
17):
18    """Build the ACE Step pipeline and launch its Gradio demo UI.\n
19    :param checkpoint_path: Path to the checkpoint directory. Downloads automatically if empty.
20    :param server_name: The server name to use for the Gradio app.
21    :param port: The port to use for the Gradio app.
22    :param share: Whether to create a public, shareable link for the Gradio app.
23    :param bf16: Whether to use bfloat16 precision (float32 otherwise). Turn off if using MPS.
24    :param torch_compile: Whether to use torch.compile.
25    :param cpu_offload: Whether to use CPU offloading (only load the current stage's model to the GPU).
26    :param overlapped_decode: Whether to use overlapped decoding (run dcae and vocoder using sliding windows).
27    """
28
29    from divisor.acestep.pipeline_ace_step import ACEStepPipeline  # NOTE(review): function-local import, presumably to defer loading until launch - confirm
30    from divisor.acestep.ui.components import create_main_demo_ui
31
32    model_demo = ACEStepPipeline(
33        checkpoint_dir=checkpoint_path,
34        dtype="bfloat16" if bf16 else "float32",  # precision is passed as a dtype name string
35        torch_compile=torch_compile,
36        cpu_offload=cpu_offload,
37        overlapped_decode=overlapped_decode,
38    )
39
40    demo = create_main_demo_ui(
41        text2music_process_func=model_demo.__call__,  # the pipeline's bound __call__ is handed to the UI as its text2music handler
42    )
43
44    demo.launch(server_name=server_name, server_port=port, share=share)
45
46
47if __name__ == "__main__":  # when run as a script, start the demo with main()'s default arguments
48    main()
def main(checkpoint_path='', server_name='127.0.0.1', port=7865, share=False, bf16=True, torch_compile=False, cpu_offload=False, overlapped_decode=False):
 9def main(
10    checkpoint_path="",
11    server_name="127.0.0.1",
12    port=7865,
13    share=False,
14    bf16=True,
15    torch_compile=False,
16    cpu_offload=False,
17    overlapped_decode=False,
18):
19    """Build the ACE Step pipeline and launch its Gradio demo UI.\n
20    :param checkpoint_path: Path to the checkpoint directory. Downloads automatically if empty.
21    :param server_name: The server name to use for the Gradio app.
22    :param port: The port to use for the Gradio app.
23    :param share: Whether to create a public, shareable link for the Gradio app.
24    :param bf16: Whether to use bfloat16 precision (float32 otherwise). Turn off if using MPS.
25    :param torch_compile: Whether to use torch.compile.
26    :param cpu_offload: Whether to use CPU offloading (only load the current stage's model to the GPU).
27    :param overlapped_decode: Whether to use overlapped decoding (run dcae and vocoder using sliding windows).
28    """
29
30    from divisor.acestep.pipeline_ace_step import ACEStepPipeline  # NOTE(review): function-local import, presumably to defer loading until launch - confirm
31    from divisor.acestep.ui.components import create_main_demo_ui
32
33    model_demo = ACEStepPipeline(
34        checkpoint_dir=checkpoint_path,
35        dtype="bfloat16" if bf16 else "float32",  # precision is passed as a dtype name string
36        torch_compile=torch_compile,
37        cpu_offload=cpu_offload,
38        overlapped_decode=overlapped_decode,
39    )
40
41    demo = create_main_demo_ui(
42        text2music_process_func=model_demo.__call__,  # the pipeline's bound __call__ is handed to the UI as its text2music handler
43    )
44
45    demo.launch(server_name=server_name, server_port=port, share=share)

Main function to launch the ACE Step pipeline demo.

Parameters
  • checkpoint_path: Path to the checkpoint directory. Downloads automatically if empty.
  • server_name: The server name to use for the Gradio app.
  • port: The port to use for the Gradio app.
  • share: Whether to create a public, shareable link for the Gradio app.
  • bf16: Whether to use bfloat16 precision. Turn off if using MPS.
  • torch_compile: Whether to use torch.compile.
  • cpu_offload: Whether to use CPU offloading (only load the current stage's model to the GPU).
  • overlapped_decode: Whether to use overlapped decoding (run dcae and vocoder using sliding windows).