divisor.acestep.gradio
# SPDX-License-Identifier: Apache-2.0
# adapted from https://github.com/ace-step/ACE-Step

# NOTE(review): `os` is not referenced anywhere in this module; confirm no
# tooling depends on it before removing.
import os


def main(
    checkpoint_path: str = "",
    server_name: str = "127.0.0.1",
    port: int = 7865,
    share: bool = False,
    bf16: bool = True,
    torch_compile: bool = False,
    cpu_offload: bool = False,
    overlapped_decode: bool = False,
) -> None:
    """Main function to launch the ACE Step pipeline demo.

    :param checkpoint_path: Path to the checkpoint directory. Downloads automatically if empty.
    :param server_name: The server name to use for the Gradio app.
    :param port: The port to use for the Gradio app.
    :param share: Whether to create a public, shareable link for the Gradio app.
    :param bf16: Whether to use bfloat16 precision (float32 is used otherwise).
        Turn off if using MPS.
    :param torch_compile: Whether to use torch.compile.
    :param cpu_offload: Whether to use CPU offloading (only load the current
        stage's model to GPU).
    :param overlapped_decode: Whether to use overlapped decoding (run dcae and
        vocoder using sliding windows).
    """
    # Imported inside the function so that merely importing this module does
    # not import the pipeline or UI packages.
    from divisor.acestep.pipeline_ace_step import ACEStepPipeline
    from divisor.acestep.ui.components import create_main_demo_ui

    model_demo = ACEStepPipeline(
        checkpoint_dir=checkpoint_path,
        dtype="bfloat16" if bf16 else "float32",
        torch_compile=torch_compile,
        cpu_offload=cpu_offload,
        overlapped_decode=overlapped_decode,
    )

    # The UI invokes the pipeline instance itself to generate audio.
    demo = create_main_demo_ui(
        text2music_process_func=model_demo.__call__,
    )

    demo.launch(server_name=server_name, server_port=port, share=share)


if __name__ == "__main__":
    main()
def
main( checkpoint_path='', server_name='127.0.0.1', port=7865, share=False, bf16=True, torch_compile=False, cpu_offload=False, overlapped_decode=False):
def main(
    checkpoint_path: str = "",
    server_name: str = "127.0.0.1",
    port: int = 7865,
    share: bool = False,
    bf16: bool = True,
    torch_compile: bool = False,
    cpu_offload: bool = False,
    overlapped_decode: bool = False,
) -> None:
    """Main function to launch the ACE Step pipeline demo.

    :param checkpoint_path: Path to the checkpoint directory. Downloads automatically if empty.
    :param server_name: The server name to use for the Gradio app.
    :param port: The port to use for the Gradio app.
    :param share: Whether to create a public, shareable link for the Gradio app.
    :param bf16: Whether to use bfloat16 precision (float32 is used otherwise).
        Turn off if using MPS.
    :param torch_compile: Whether to use torch.compile.
    :param cpu_offload: Whether to use CPU offloading (only load the current
        stage's model to GPU).
    :param overlapped_decode: Whether to use overlapped decoding (run dcae and
        vocoder using sliding windows).
    """
    # Imported inside the function so that merely importing the module does
    # not import the pipeline or UI packages.
    from divisor.acestep.pipeline_ace_step import ACEStepPipeline
    from divisor.acestep.ui.components import create_main_demo_ui

    model_demo = ACEStepPipeline(
        checkpoint_dir=checkpoint_path,
        dtype="bfloat16" if bf16 else "float32",
        torch_compile=torch_compile,
        cpu_offload=cpu_offload,
        overlapped_decode=overlapped_decode,
    )

    # The UI invokes the pipeline instance itself to generate audio.
    demo = create_main_demo_ui(
        text2music_process_func=model_demo.__call__,
    )

    demo.launch(server_name=server_name, server_port=port, share=share)
Main function to launch the ACE Step pipeline demo.
Parameters
- checkpoint_path: Path to the checkpoint directory. Downloads automatically if empty.
- server_name: The server name to use for the Gradio app.
- port: The port to use for the Gradio app.
- share: Whether to create a public, shareable link for the Gradio app.
- bf16: Whether to use bfloat16 precision. Turn off if using MPS.
- torch_compile: Whether to use torch.compile.
- cpu_offload: Whether to use CPU offloading (only load the current stage's model to GPU).
- overlapped_decode: Whether to use overlapped decoding (run dcae and vocoder using sliding windows).