renhang committed
Commit 65e9daa · 1 Parent(s): 4931afc

fix model.py

Files changed (4)
  1. app.py +1 -1
  2. gt0.json +0 -1
  3. jam_infer.yaml +66 -0
  4. model.py +316 -142
app.py CHANGED
@@ -47,7 +47,7 @@ def generate_song(reference_audio, lyrics_text, style_prompt, duration):
             reference_audio_path=reference_audio,
             lyrics_json_path=lyrics_file,
             style_prompt=style_prompt,
-            duration_sec=duration
+            duration=duration
         )
         return output_path
     finally:
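The only change in app.py is the keyword rename at the predict() call, matching the new signature in model.py (predict(self, reference_audio_path, lyrics_json_path, style_prompt, duration)). A minimal sketch of the updated call site, where jamify_model stands in for the module-level Jamify instance that this hunk does not show:

    output_path = jamify_model.predict(
        reference_audio_path=reference_audio,
        lyrics_json_path=lyrics_file,
        style_prompt=style_prompt,
        duration=duration,  # was duration_sec before this commit
    )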
gt0.json DELETED
@@ -1 +0,0 @@
- [{"word": "Every", "start_offset": 259, "end_offset": 267, "start": 20.72, "end": 21.36, "phoneme": "\u025bv\u025di|_"}, {"word": "night", "start_offset": 267, "end_offset": 275, "start": 21.36, "end": 22.0, "phoneme": "na\u026at|_"}, {"word": "in", "start_offset": 279, "end_offset": 283, "start": 22.32, "end": 22.64, "phoneme": "\u026an|_"}, {"word": "my", "start_offset": 283, "end_offset": 287, "start": 22.64, "end": 22.96, "phoneme": "ma\u026a|_"}, {"word": "dreams,", "start_offset": 287, "end_offset": 301, "start": 22.96, "end": 24.080000000000002, "phoneme": "dri\u02d0mz,"}, {"word": "I", "start_offset": 309, "end_offset": 313, "start": 24.72, "end": 25.04, "phoneme": "a\u026a|_"}, {"word": "see", "start_offset": 317, "end_offset": 321, "start": 25.36, "end": 25.68, "phoneme": "si\u02d0|_"}, {"word": "you,", "start_offset": 321, "end_offset": 325, "start": 25.68, "end": 26.0, "phoneme": "ju\u02d0,"}, {"word": "I", "start_offset": 340, "end_offset": 344, "start": 27.2, "end": 27.52, "phoneme": "a\u026a|_"}, {"word": "feel", "start_offset": 348, "end_offset": 352, "start": 27.84, "end": 28.16, "phoneme": "fi\u02d0l|_"}, {"word": "you.", "start_offset": 358, "end_offset": 362, "start": 28.64, "end": 28.96, "phoneme": "ju\u02d0."}, {"word": "That", "start_offset": 377, "end_offset": 381, "start": 30.16, "end": 30.48, "phoneme": "\u00f0\u00e6t|_"}, {"word": "is", "start_offset": 385, "end_offset": 389, "start": 30.8, "end": 31.12, "phoneme": "\u026az"}, {"word": "how", "start_offset": 393, "end_offset": 397, "start": 31.44, "end": 31.76, "phoneme": "ha\u028a|_"}, {"word": "I", "start_offset": 401, "end_offset": 405, "start": 32.08, "end": 32.4, "phoneme": "a\u026a|_"}, {"word": "know", "start_offset": 405, "end_offset": 409, "start": 32.4, "end": 32.72, "phoneme": "no\u028a|_"}, {"word": "you", "start_offset": 413, "end_offset": 417, "start": 33.04, "end": 33.36, "phoneme": "ju\u02d0|_"}, {"word": "go", "start_offset": 428, "end_offset": 431, "start": 34.24, "end": 34.480000000000004, "phoneme": "go\u028a|_"}, {"word": "far", "start_offset": 495, "end_offset": 503, "start": 39.6, "end": 40.24, "phoneme": "f\u0251\u02d0r"}, {"word": "across", "start_offset": 507, "end_offset": 517, "start": 40.56, "end": 41.36, "phoneme": "\u0259kr\u0254s|_"}, {"word": "the", "start_offset": 519, "end_offset": 523, "start": 41.52, "end": 41.84, "phoneme": "\u00f0\u0259|_"}, {"word": "distance", "start_offset": 527, "end_offset": 538, "start": 42.160000000000004, "end": 43.04, "phoneme": "d\u026ast\u0259ns|_"}, {"word": "and", "start_offset": 552, "end_offset": 556, "start": 44.160000000000004, "end": 44.480000000000004, "phoneme": "\u0259nd"}, {"word": "spaces", "start_offset": 556, "end_offset": 572, "start": 44.480000000000004, "end": 45.76, "phoneme": "spe\u026as\u0259z"}, {"word": "between", "start_offset": 583, "end_offset": 587, "start": 46.64, "end": 46.96, "phoneme": "b\u026atwi\u02d0n|_"}, {"word": "us.", "start_offset": 602, "end_offset": 606, "start": 48.160000000000004, "end": 48.480000000000004, "phoneme": "\u028cs."}, {"word": "You", "start_offset": 621, "end_offset": 625, "start": 49.68, "end": 50.0, "phoneme": "ju\u02d0|_"}, {"word": "have", "start_offset": 629, "end_offset": 633, "start": 50.32, "end": 50.64, "phoneme": "h\u00e6v"}, {"word": "come", "start_offset": 633, "end_offset": 637, "start": 50.64, "end": 50.96, "phoneme": "k\u028cm|_"}, {"word": "to", "start_offset": 641, "end_offset": 645, "start": 51.28, "end": 51.6, "phoneme": "tu\u02d0|_"}, {"word": "show", "start_offset": 649, 
"end_offset": 653, "start": 51.92, "end": 52.24, "phoneme": "\u0283o\u028a|_"}, {"word": "you", "start_offset": 655, "end_offset": 659, "start": 52.4, "end": 52.72, "phoneme": "ju\u02d0|_"}, {"word": "go", "start_offset": 673, "end_offset": 676, "start": 53.84, "end": 54.08, "phoneme": "go\u028a|_"}, {"word": "near,", "start_offset": 738, "end_offset": 745, "start": 59.04, "end": 59.6, "phoneme": "n\u026ar,"}, {"word": "far,", "start_offset": 768, "end_offset": 776, "start": 61.44, "end": 62.08, "phoneme": "f\u0251\u02d0r,"}, {"word": "wherever", "start_offset": 794, "end_offset": 806, "start": 63.52, "end": 64.48, "phoneme": "w\u025br\u025bv\u025d"}, {"word": "you", "start_offset": 822, "end_offset": 826, "start": 65.76, "end": 66.08, "phoneme": "ju\u02d0|_"}, {"word": "are.", "start_offset": 826, "end_offset": 830, "start": 66.08, "end": 66.4, "phoneme": "\u0251\u02d0r."}, {"word": "I", "start_offset": 849, "end_offset": 852, "start": 67.92, "end": 68.16, "phoneme": "a\u026a|_"}, {"word": "believe", "start_offset": 856, "end_offset": 868, "start": 68.48, "end": 69.44, "phoneme": "b\u026ali\u02d0v"}, {"word": "that", "start_offset": 875, "end_offset": 878, "start": 70.0, "end": 70.24, "phoneme": "\u00f0\u00e6t|_"}, {"word": "the", "start_offset": 886, "end_offset": 890, "start": 70.88, "end": 71.2, "phoneme": "\u00f0\u0259|_"}, {"word": "heart", "start_offset": 890, "end_offset": 898, "start": 71.2, "end": 71.84, "phoneme": "h\u0251\u02d0rt|_"}, {"word": "does", "start_offset": 898, "end_offset": 901, "start": 71.84, "end": 72.08, "phoneme": "d\u028cz"}, {"word": "go", "start_offset": 916, "end_offset": 920, "start": 73.28, "end": 73.60000000000001, "phoneme": "go\u028a|_"}, {"word": "on", "start_offset": 982, "end_offset": 985, "start": 78.56, "end": 78.8, "phoneme": "\u0251\u02d0n|_"}, {"word": "small.", "start_offset": 1009, "end_offset": 1017, "start": 80.72, "end": 81.36, "phoneme": "sm\u0254l."}, {"word": "You", "start_offset": 1037, "end_offset": 1041, "start": 82.96000000000001, "end": 83.28, "phoneme": "ju\u02d0|_"}, {"word": "open", "start_offset": 1045, "end_offset": 1049, "start": 83.60000000000001, "end": 83.92, "phoneme": "o\u028ap\u0259n|_"}, {"word": "the", "start_offset": 1065, "end_offset": 1069, "start": 85.2, "end": 85.52, "phoneme": "\u00f0\u0259|_"}, {"word": "door,", "start_offset": 1069, "end_offset": 1076, "start": 85.52, "end": 86.08, "phoneme": "d\u0254r,"}, {"word": "and", "start_offset": 1090, "end_offset": 1094, "start": 87.2, "end": 87.52, "phoneme": "\u0259nd"}, {"word": "you'll", "start_offset": 1094, "end_offset": 1100, "start": 87.52, "end": 88.0, "phoneme": "j\u028c\u028al|_"}, {"word": "hear", "start_offset": 1103, "end_offset": 1108, "start": 88.24, "end": 88.64, "phoneme": "hi\u02d0r"}, {"word": "in", "start_offset": 1119, "end_offset": 1122, "start": 89.52, "end": 89.76, "phoneme": "\u026an|_"}, {"word": "my", "start_offset": 1126, "end_offset": 1130, "start": 90.08, "end": 90.4, "phoneme": "ma\u026a|_"}, {"word": "heart,", "start_offset": 1130, "end_offset": 1138, "start": 90.4, "end": 91.04, "phoneme": "h\u0251\u02d0rt,"}, {"word": "and", "start_offset": 1141, "end_offset": 1145, "start": 91.28, "end": 91.60000000000001, "phoneme": "\u0259nd"}, {"word": "my", "start_offset": 1157, "end_offset": 1161, "start": 92.56, "end": 92.88, "phoneme": "ma\u026a|_"}, {"word": "heart", "start_offset": 1165, "end_offset": 1173, "start": 93.2, "end": 93.84, "phoneme": "h\u0251\u02d0rt|_"}, {"word": "will", "start_offset": 1173, "end_offset": 1177, "start": 
93.84, "end": 94.16, "phoneme": "w\u026al|_"}, {"word": "go", "start_offset": 1185, "end_offset": 1189, "start": 94.8, "end": 95.12, "phoneme": "go\u028a|_"}, {"word": "and", "start_offset": 1211, "end_offset": 1215, "start": 96.88, "end": 97.2, "phoneme": "\u0259nd"}, {"word": "dawn.", "start_offset": 1223, "end_offset": 1233, "start": 97.84, "end": 98.64, "phoneme": "d\u0254n."}, {"word": "Love", "start_offset": 1345, "end_offset": 1353, "start": 107.60000000000001, "end": 108.24000000000001, "phoneme": "l\u028cv"}, {"word": "can", "start_offset": 1356, "end_offset": 1360, "start": 108.48, "end": 108.8, "phoneme": "k\u00e6n|_"}, {"word": "touch", "start_offset": 1360, "end_offset": 1366, "start": 108.8, "end": 109.28, "phoneme": "t\u028ct\u0283|_"}, {"word": "us", "start_offset": 1369, "end_offset": 1373, "start": 109.52, "end": 109.84, "phoneme": "\u028cs|_"}, {"word": "one", "start_offset": 1376, "end_offset": 1380, "start": 110.08, "end": 110.4, "phoneme": "w\u028cn|_"}, {"word": "time", "start_offset": 1384, "end_offset": 1388, "start": 110.72, "end": 111.04, "phoneme": "ta\u026am|_"}, {"word": "and", "start_offset": 1399, "end_offset": 1402, "start": 111.92, "end": 112.16, "phoneme": "\u0259nd"}, {"word": "last", "start_offset": 1406, "end_offset": 1410, "start": 112.48, "end": 112.8, "phoneme": "l\u00e6st|_"}, {"word": "for", "start_offset": 1416, "end_offset": 1420, "start": 113.28, "end": 113.60000000000001, "phoneme": "f\u0254r"}, {"word": "a", "start_offset": 1431, "end_offset": 1435, "start": 114.48, "end": 114.8, "phoneme": "\u0259|_"}, {"word": "lifetime", "start_offset": 1435, "end_offset": 1458, "start": 114.8, "end": 116.64, "phoneme": "la\u026afta\u026am|_"}, {"word": "and", "start_offset": 1471, "end_offset": 1475, "start": 117.68, "end": 118.0, "phoneme": "\u0259nd"}, {"word": "never", "start_offset": 1479, "end_offset": 1483, "start": 118.32000000000001, "end": 118.64, "phoneme": "n\u025bv\u025d"}, {"word": "let", "start_offset": 1487, "end_offset": 1491, "start": 118.96000000000001, "end": 119.28, "phoneme": "l\u025bt|_"}, {"word": "go", "start_offset": 1495, "end_offset": 1499, "start": 119.60000000000001, "end": 119.92, "phoneme": "go\u028a|_"}, {"word": "till", "start_offset": 1503, "end_offset": 1511, "start": 120.24000000000001, "end": 120.88, "phoneme": "t\u026al|_"}, {"word": "we're", "start_offset": 1521, "end_offset": 1528, "start": 121.68, "end": 122.24000000000001, "phoneme": "w\u025d\u02d0|_"}, {"word": "gone.", "start_offset": 1528, "end_offset": 1536, "start": 122.24000000000001, "end": 122.88, "phoneme": "g\u0254n."}, {"word": "Love", "start_offset": 1587, "end_offset": 1596, "start": 126.96000000000001, "end": 127.68, "phoneme": "l\u028cv"}, {"word": "was", "start_offset": 1599, "end_offset": 1603, "start": 127.92, "end": 128.24, "phoneme": "w\u0251\u02d0z"}, {"word": "when", "start_offset": 1607, "end_offset": 1611, "start": 128.56, "end": 128.88, "phoneme": "w\u025bn|_"}, {"word": "I", "start_offset": 1611, "end_offset": 1615, "start": 128.88, "end": 129.2, "phoneme": "a\u026a|_"}, {"word": "loved", "start_offset": 1615, "end_offset": 1626, "start": 129.2, "end": 130.08, "phoneme": "l\u028cvd"}, {"word": "you", "start_offset": 1626, "end_offset": 1630, "start": 130.08, "end": 130.4, "phoneme": "ju\u02d0|_"}, {"word": "one", "start_offset": 1641, "end_offset": 1644, "start": 131.28, "end": 131.52, "phoneme": "w\u028cn|_"}, {"word": "true", "start_offset": 1648, "end_offset": 1656, "start": 131.84, "end": 132.48, "phoneme": "tru\u02d0|_"}, {"word": 
"time.", "start_offset": 1656, "end_offset": 1660, "start": 132.48, "end": 132.8, "phoneme": "ta\u026am."}, {"word": "I", "start_offset": 1672, "end_offset": 1675, "start": 133.76, "end": 134.0, "phoneme": "a\u026a|_"}, {"word": "hold", "start_offset": 1679, "end_offset": 1687, "start": 134.32, "end": 134.96, "phoneme": "ho\u028ald"}, {"word": "to", "start_offset": 1691, "end_offset": 1693, "start": 135.28, "end": 135.44, "phoneme": "tu\u02d0|_"}, {"word": "in", "start_offset": 1712, "end_offset": 1716, "start": 136.96, "end": 137.28, "phoneme": "\u026an|_"}, {"word": "my", "start_offset": 1720, "end_offset": 1724, "start": 137.6, "end": 137.92000000000002, "phoneme": "ma\u026a|_"}, {"word": "life", "start_offset": 1724, "end_offset": 1728, "start": 137.92000000000002, "end": 138.24, "phoneme": "la\u026af|_"}, {"word": "will", "start_offset": 1731, "end_offset": 1733, "start": 138.48, "end": 138.64000000000001, "phoneme": "w\u026al|_"}, {"word": "always", "start_offset": 1743, "end_offset": 1747, "start": 139.44, "end": 139.76, "phoneme": "\u0254lwe\u026az"}, {"word": "go", "start_offset": 1763, "end_offset": 1767, "start": 141.04, "end": 141.36, "phoneme": "go\u028a|_"}, {"word": "near", "start_offset": 1830, "end_offset": 1836, "start": 146.4, "end": 146.88, "phoneme": "n\u026ar"}, {"word": "far", "start_offset": 1859, "end_offset": 1867, "start": 148.72, "end": 149.36, "phoneme": "f\u0251\u02d0r"}, {"word": "wherever", "start_offset": 1884, "end_offset": 1896, "start": 150.72, "end": 151.68, "phoneme": "w\u025br\u025bv\u025d"}, {"word": "you", "start_offset": 1914, "end_offset": 1918, "start": 153.12, "end": 153.44, "phoneme": "ju\u02d0|_"}, {"word": "are.", "start_offset": 1918, "end_offset": 1922, "start": 153.44, "end": 153.76, "phoneme": "\u0251\u02d0r."}, {"word": "I", "start_offset": 1940, "end_offset": 1943, "start": 155.20000000000002, "end": 155.44, "phoneme": "a\u026a|_"}, {"word": "believe", "start_offset": 1947, "end_offset": 1959, "start": 155.76, "end": 156.72, "phoneme": "b\u026ali\u02d0v"}, {"word": "that", "start_offset": 1966, "end_offset": 1970, "start": 157.28, "end": 157.6, "phoneme": "\u00f0\u00e6t|_"}, {"word": "the", "start_offset": 1974, "end_offset": 1977, "start": 157.92000000000002, "end": 158.16, "phoneme": "\u00f0\u0259|_"}, {"word": "heart", "start_offset": 1981, "end_offset": 1986, "start": 158.48, "end": 158.88, "phoneme": "h\u0251\u02d0rt|_"}, {"word": "does", "start_offset": 1990, "end_offset": 1993, "start": 159.20000000000002, "end": 159.44, "phoneme": "d\u028cz"}, {"word": "go", "start_offset": 2008, "end_offset": 2011, "start": 160.64000000000001, "end": 160.88, "phoneme": "go\u028a|_"}, {"word": "small.", "start_offset": 2099, "end_offset": 2111, "start": 167.92000000000002, "end": 168.88, "phoneme": "sm\u0254l."}, {"word": "You", "start_offset": 2127, "end_offset": 2131, "start": 170.16, "end": 170.48, "phoneme": "ju\u02d0|_"}, {"word": "open", "start_offset": 2136, "end_offset": 2140, "start": 170.88, "end": 171.20000000000002, "phoneme": "o\u028ap\u0259n|_"}, {"word": "the", "start_offset": 2156, "end_offset": 2160, "start": 172.48, "end": 172.8, "phoneme": "\u00f0\u0259|_"}, {"word": "door", "start_offset": 2160, "end_offset": 2167, "start": 172.8, "end": 173.36, "phoneme": "d\u0254r"}, {"word": "and", "start_offset": 2181, "end_offset": 2185, "start": 174.48, "end": 174.8, "phoneme": "\u0259nd"}, {"word": "you", "start_offset": 2185, "end_offset": 2187, "start": 174.8, "end": 174.96, "phoneme": "ju\u02d0|_"}, {"word": "hear", "start_offset": 
2195, "end_offset": 2203, "start": 175.6, "end": 176.24, "phoneme": "hi\u02d0r"}, {"word": "in", "start_offset": 2209, "end_offset": 2213, "start": 176.72, "end": 177.04, "phoneme": "\u026an|_"}, {"word": "my", "start_offset": 2217, "end_offset": 2221, "start": 177.36, "end": 177.68, "phoneme": "ma\u026a|_"}, {"word": "heart,", "start_offset": 2221, "end_offset": 2230, "start": 177.68, "end": 178.4, "phoneme": "h\u0251\u02d0rt,"}, {"word": "and", "start_offset": 2232, "end_offset": 2236, "start": 178.56, "end": 178.88, "phoneme": "\u0259nd"}, {"word": "my", "start_offset": 2248, "end_offset": 2251, "start": 179.84, "end": 180.08, "phoneme": "ma\u026a|_"}, {"word": "heart", "start_offset": 2255, "end_offset": 2263, "start": 180.4, "end": 181.04, "phoneme": "h\u0251\u02d0rt|_"}, {"word": "will", "start_offset": 2263, "end_offset": 2266, "start": 181.04, "end": 181.28, "phoneme": "w\u026al|_"}, {"word": "go", "start_offset": 2278, "end_offset": 2282, "start": 182.24, "end": 182.56, "phoneme": "go\u028a|_"}, {"word": "on.", "start_offset": 2286, "end_offset": 2289, "start": 182.88, "end": 183.12, "phoneme": "\u0251\u02d0n."}, {"word": "You", "start_offset": 2557, "end_offset": 2559, "start": 204.56, "end": 204.72, "phoneme": "ju\u02d0|_"}, {"word": "hear", "start_offset": 2587, "end_offset": 2594, "start": 206.96, "end": 207.52, "phoneme": "hi\u02d0r"}, {"word": "there's", "start_offset": 2610, "end_offset": 2620, "start": 208.8, "end": 209.6, "phoneme": "\u00f0\u025brz"}, {"word": "nothing", "start_offset": 2620, "end_offset": 2632, "start": 209.6, "end": 210.56, "phoneme": "n\u028c\u03b8\u026a\u014b|_"}, {"word": "I", "start_offset": 2640, "end_offset": 2644, "start": 211.20000000000002, "end": 211.52, "phoneme": "a\u026a|_"}, {"word": "fear,", "start_offset": 2644, "end_offset": 2651, "start": 211.52, "end": 212.08, "phoneme": "f\u026ar,"}, {"word": "and", "start_offset": 2666, "end_offset": 2669, "start": 213.28, "end": 213.52, "phoneme": "\u0259nd"}, {"word": "I", "start_offset": 2673, "end_offset": 2677, "start": 213.84, "end": 214.16, "phoneme": "a\u026a|_"}, {"word": "know", "start_offset": 2677, "end_offset": 2681, "start": 214.16, "end": 214.48000000000002, "phoneme": "no\u028a|_"}, {"word": "that", "start_offset": 2693, "end_offset": 2697, "start": 215.44, "end": 215.76, "phoneme": "\u00f0\u00e6t|_"}, {"word": "my", "start_offset": 2701, "end_offset": 2705, "start": 216.08, "end": 216.4, "phoneme": "ma\u026a|_"}, {"word": "heart", "start_offset": 2705, "end_offset": 2713, "start": 216.4, "end": 217.04, "phoneme": "h\u0251\u02d0rt|_"}, {"word": "will", "start_offset": 2717, "end_offset": 2721, "start": 217.36, "end": 217.68, "phoneme": "w\u026al|_"}, {"word": "go", "start_offset": 2733, "end_offset": 2736, "start": 218.64000000000001, "end": 218.88, "phoneme": "go\u028a|_"}, {"word": "forever", "start_offset": 2852, "end_offset": 2863, "start": 228.16, "end": 229.04, "phoneme": "f\u025d\u025bv\u025d"}, {"word": "this", "start_offset": 2881, "end_offset": 2883, "start": 230.48000000000002, "end": 230.64000000000001, "phoneme": "\u00f0\u026as|_"}, {"word": "way.", "start_offset": 2888, "end_offset": 2892, "start": 231.04, "end": 231.36, "phoneme": "we\u026a."}, {"word": "You", "start_offset": 2908, "end_offset": 2911, "start": 232.64000000000001, "end": 232.88, "phoneme": "ju\u02d0|_"}, {"word": "are", "start_offset": 2914, "end_offset": 2918, "start": 233.12, "end": 233.44, "phoneme": "\u0251\u02d0r"}, {"word": "safe", "start_offset": 2928, "end_offset": 2935, "start": 234.24, "end": 
234.8, "phoneme": "se\u026af|_"}, {"word": "in", "start_offset": 2938, "end_offset": 2942, "start": 235.04, "end": 235.36, "phoneme": "\u026an|_"}, {"word": "my", "start_offset": 2942, "end_offset": 2946, "start": 235.36, "end": 235.68, "phoneme": "ma\u026a|_"}, {"word": "heart,", "start_offset": 2950, "end_offset": 2957, "start": 236.0, "end": 236.56, "phoneme": "h\u0251\u02d0rt,"}, {"word": "and", "start_offset": 2959, "end_offset": 2963, "start": 236.72, "end": 237.04, "phoneme": "\u0259nd"}, {"word": "my", "start_offset": 2975, "end_offset": 2978, "start": 238.0, "end": 238.24, "phoneme": "ma\u026a|_"}, {"word": "heart", "start_offset": 2982, "end_offset": 2990, "start": 238.56, "end": 239.20000000000002, "phoneme": "h\u0251\u02d0rt|_"}, {"word": "will", "start_offset": 2990, "end_offset": 2994, "start": 239.20000000000002, "end": 239.52, "phoneme": "w\u026al|_"}, {"word": "go", "start_offset": 3002, "end_offset": 3005, "start": 240.16, "end": 240.4, "phoneme": "go\u028a|_"}, {"word": "on", "start_offset": 3009, "end_offset": 3012, "start": 240.72, "end": 240.96, "phoneme": "\u0251\u02d0n|_"}, {"word": "there.", "start_offset": 3028, "end_offset": 3032, "start": 242.24, "end": 242.56, "phoneme": "\u00f0\u025br."}]
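The deleted file is a flat JSON list of word-timing entries; the rewritten model.py reads the same shape from lyrics_json_path and wraps a bare list under a 'word' key before processing. As Python data, the structure looks roughly like this (values copied from the first entry above; offset units as in the file):

    words = [
        {"word": "Every", "start_offset": 259, "end_offset": 267,
         "start": 20.72, "end": 21.36, "phoneme": "ɛvɝi|_"},  # start/end are in seconds
        # ... one entry per word, as in the list above
    ]
    lrc_data = {"word": words}  # wrapping applied when 'word' is not already a key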
 
 
jam_infer.yaml ADDED
@@ -0,0 +1,66 @@
+ project_root: "."
+ evaluation:
+   checkpoint_path: ""
+   test_set_path: ""
+   negative_style_prompt: ${project_root}/public/vocal.npy
+   num_samples: null
+   batch_size: 1
+   random_crop_style: false
+   vae_type: 'diffrhythm'
+   num_style_secs: 30
+   ignore_style: false
+   use_prompt_style: false
+
+   dataset:
+     pattern: "placeholder"
+     shuffle: false
+     resample_by_duration_threshold: null
+     always_crop_from_beginning: true
+     always_use_style_index: 0
+
+   sample_kwargs:
+     batch_infer_num: 1
+     cfg_range:
+       - 0.05
+       - 1
+     dual_cfg:
+       - 4.7
+       - 2.5
+     steps: 50
+
+ model:
+   num_channels: 64
+   cfm:
+     max_frames: ${max_frames}
+     num_channels: ${model.num_channels}
+     dual_drop_prob: [0.1, 0.5]
+     no_edit: true
+
+   dit:
+     max_frames: ${max_frames}
+     mel_dim: ${model.num_channels}
+     dim: 1408
+     depth: 16
+     heads: 32
+     ff_mult: 4
+     text_dim: 512
+     conv_layers: 4
+     grad_ckpt: true
+     use_implicit_duration: true
+
+ data:
+   train_dataset:
+     max_frames: ${max_frames}
+     multiple_styles: true
+     sampling_rate: 44100
+     shuffle: true
+     silence_latent_path: ${project_root}/public/silience_latent.pt
+     tokenizer_path: ${project_root}/public/en_us_cmudict_ipa_forward.pt
+     lrc_upsample_factor: ${lrc_upsample_factor}
+     filler: average_sparse
+     phonemizer_checkpoint: ${project_root}/public/en_us_cmudict_ipa_forward.pt
+
+ # General settings
+ max_frames: 5000
+ lrc_upsample_factor: 4
+ seed: 42
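The ${...} values in this config are OmegaConf interpolations; the rewritten model.py loads the file with OmegaConf.load and calls OmegaConf.resolve before use. A minimal sketch of how they expand, assuming omegaconf is installed and the content above is saved as jam_infer.yaml:

    from omegaconf import OmegaConf

    cfg = OmegaConf.load("jam_infer.yaml")
    OmegaConf.resolve(cfg)  # expands ${max_frames}, ${model.num_channels}, ${project_root}, ...

    print(cfg.model.dit.max_frames)              # 5000, from the top-level max_frames
    print(cfg.model.cfm.num_channels)            # 64, from model.num_channels
    print(cfg.evaluation.negative_style_prompt)  # ./public/vocal.npy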
model.py CHANGED
@@ -1,194 +1,368 @@
  import torch
  import torchaudio
  from omegaconf import OmegaConf
- from huggingface_hub import snapshot_download
- import numpy as np
- import json
- import os
  from safetensors.torch import load_file
-
- # Imports from the jamify library
- from jam.model.cfm import CFM
- from jam.model.dit import DiT
- from jam.model.vae import StableAudioOpenVAE
- from jam.dataset import DiffusionWebDataset, enhance_webdataset_config
  from muq import MuQMuLan

- # Helper functions adapted from jamify/src/jam/infer.py
  def get_negative_style_prompt(device, file_path):
-     vocal_style = np.load(file_path)
-     vocal_style = torch.from_numpy(vocal_style).to(device)
-     return vocal_style.half()

- def normalize_audio(audio):
      audio = audio - audio.mean(-1, keepdim=True)
      audio = audio / (audio.abs().max(-1, keepdim=True).values + 1e-8)
-     return audio

- class Jamify:
-     def __init__(self, device="cuda" if torch.cuda.is_available() else "cpu"):
-         self.device = torch.device(device)

-         # --- FIX: Point to the local jamify repository for config and public files ---
-         #jamify_repo_path = "/Users/cy/Desktop/JAM/jamify"

-         print("Downloading main model checkpoint...")
-         model_repo_path = snapshot_download(repo_id="declare-lab/jam-0.5")
-         self.checkpoint_path = os.path.join(model_repo_path, "jam-0_5.safetensors")
-
-         # Use local config and data files
-         config_path = os.path.join(model_repo_path, "jam_infer.yaml")
-         self.negative_style_prompt_path = os.path.join(model_repo_path, "vocal.npy")
-         tokenizer_path = os.path.join(model_repo_path, "en_us_cmudict_ipa_forward.pt")
-         silence_latent_path = os.path.join(model_repo_path, "silience_latent.pt")
-         print("Loading configuration...")
-         self.config = OmegaConf.load(config_path)
-         self.config.data.train_dataset.silence_latent_path = silence_latent_path

-         # --- FIX: Override the relative paths in the config with absolute paths ---
-         self.config.data.train_dataset.tokenizer_path = tokenizer_path
-         self.config.evaluation.dataset.tokenizer_path = tokenizer_path
-         self.config.data.train_dataset.phonemizer_checkpoint = tokenizer_path

-         print("Loading VAE model...")
-         self.vae = StableAudioOpenVAE().to(self.device).eval()

-         print("Loading CFM model...")
-         self.cfm_model = self._load_cfm_model(self.config.model, self.checkpoint_path)

-         print("Loading MuQ style model...")
-         self.muq_model = MuQMuLan.from_pretrained("OpenMuQ/MuQ-MuLan-large").to(self.device).eval()
-
-         print("Setting up dataset processor...")
-         dataset_cfg = OmegaConf.merge(self.config.data.train_dataset, self.config.evaluation.dataset)
-         enhance_webdataset_config(dataset_cfg)
-         dataset_cfg.multiple_styles = False
-         self.dataset_processor = DiffusionWebDataset(**dataset_cfg)
-
-         print("Jamify model loaded successfully.")
-
-     def _load_cfm_model(self, model_config, checkpoint_path):
-         dit_config = model_config["dit"].copy()
-         if "text_num_embeds" not in dit_config:
-             dit_config["text_num_embeds"] = 256

-         model = CFM(
-             transformer=DiT(**dit_config),
-             **model_config["cfm"]
-         ).to(self.device)

-         state_dict = load_file(checkpoint_path)
-         model.load_state_dict(state_dict, strict=False)
-         return model.eval()
-
-     def _generate_style_embedding_from_audio(self, audio_path):
          waveform, sample_rate = torchaudio.load(audio_path)
          if sample_rate != 24000:
              resampler = torchaudio.transforms.Resample(sample_rate, 24000)
              waveform = resampler(waveform)
          if waveform.shape[0] > 1:
              waveform = waveform.mean(dim=0, keepdim=True)

-         waveform = waveform.squeeze(0).to(self.device)

          with torch.inference_mode():
-             style_embedding = self.muq_model(wavs=waveform.unsqueeze(0)[..., :24000 * 30])
          return style_embedding[0]

-     def _generate_style_embedding_from_prompt(self, prompt):
-         with torch.inference_mode():
-             style_embedding = self.muq_model(texts=[prompt]).squeeze(0)
-         return style_embedding

-     def predict(self, reference_audio_path, lyrics_json_path, style_prompt, duration_sec=30, steps=50):
-         print("Starting prediction...")
-
-         if reference_audio_path:
-             print(f"Generating style from audio: {reference_audio_path}")
-             style_embedding = self._generate_style_embedding_from_audio(reference_audio_path)
-         elif style_prompt:
-             print(f"Generating style from prompt: '{style_prompt}'")
-             style_embedding = self._generate_style_embedding_from_prompt(style_prompt)
-         else:
-             print("No style provided, using zero embedding.")
-             style_embedding = torch.zeros(512, device=self.device)

-         print(f"Loading lyrics from: {lyrics_json_path}")
-         with open(lyrics_json_path, 'r') as f:
-             lrc_data = json.load(f)
-         if 'word' not in lrc_data:
-             lrc_data = {'word': lrc_data}

-         frame_rate = 21.5
-         num_frames = int(duration_sec * frame_rate)
-         fake_latent = torch.randn(128, num_frames)

-         sample_tuple = ("user_song", fake_latent, style_embedding, lrc_data)

-         print("Processing sample...")
-         processed_sample = self.dataset_processor.process_sample_safely(sample_tuple)
-         if processed_sample is None:
-             raise ValueError("Failed to process the provided lyrics and style.")

-         batch = self.dataset_processor.custom_collate_fn([processed_sample])

-         for key, value in batch.items():
-             if isinstance(value, torch.Tensor):
-                 batch[key] = value.to(self.device)

-         print("Generating audio latent...")
-         with torch.inference_mode():
-             batch_size = 1
-             text = batch["lrc"]
-             style_prompt_tensor = batch["prompt"]
-             start_time = batch["start_time"]
-             duration_abs = batch["duration_abs"]
-             duration_rel = batch["duration_rel"]
-
-             cond = torch.zeros(batch_size, self.cfm_model.max_frames, 64).to(self.device)
-             pred_frames = [(0, self.cfm_model.max_frames)]
-
-             negative_style_prompt = get_negative_style_prompt(self.device, self.negative_style_prompt_path)
-             negative_style_prompt = negative_style_prompt.repeat(batch_size, 1)
-
-             sample_kwargs = self.config.evaluation.sample_kwargs
-             sample_kwargs.steps = steps
-             latents, _ = self.cfm_model.sample(
-                 cond=cond, text=text, style_prompt=style_prompt_tensor,
-                 duration_abs=duration_abs, duration_rel=duration_rel,
-                 negative_style_prompt=negative_style_prompt, start_time=start_time,
-                 latent_pred_segments=pred_frames, **sample_kwargs)
-
-             latent = latents[0][0]

-         print("Decoding latent to audio...")
-         latent_for_vae = latent.transpose(0, 1).unsqueeze(0)
-         pred_audio = self.vae.decode(latent_for_vae).sample.squeeze(0).detach().cpu()

-         pred_audio = normalize_audio(pred_audio)

-         sample_rate = 44100
-         trim_samples = int(duration_sec * sample_rate)
-         if pred_audio.shape[1] > trim_samples:
-             pred_audio = pred_audio[:, :trim_samples]
-
-         import time
-         import glob

          # Clean up old generated files (keep only last 5 files)
-         old_files = sorted(glob.glob("generated_song_*.mp3"))
-         if len(old_files) >= 5:
-             for old_file in old_files[:-4]:  # Keep last 4, delete older ones
                  try:
                      os.remove(old_file)
                      print(f"Cleaned up old file: {old_file}")
                  except OSError:
                      pass

-         timestamp = int(time.time() * 1000)  # Use milliseconds for uniqueness
-         output_path = f"generated_song_{timestamp}.mp3"
-         print(f"Saving audio to {output_path}")
-         torchaudio.save(output_path, pred_audio, sample_rate, format="mp3")

          return output_path
+ #!/usr/bin/env python3
+ """
+ Generate audio using JAM model
+ Reads from filtered test set and generates audio using CFM+DiT model.
+ """

+ import os
+ import glob
+ import time
+ import json
+ import random
+ import sys
+ from huggingface_hub import snapshot_download
  import torch
  import torchaudio
  from omegaconf import OmegaConf
+ from torch.utils.data import DataLoader, Dataset
+ from tqdm.auto import tqdm
+ import accelerate
+ import pyloudnorm as pyln
  from safetensors.torch import load_file
  from muq import MuQMuLan
+ import numpy as np
+ from accelerate import Accelerator
+
+ from jam.dataset import enhance_webdataset_config, DiffusionWebDataset
+ from jam.model.vae import StableAudioOpenVAE, DiffRhythmVAE
+
+ # DiffRhythm imports for CFM+DiT model
+ from jam.model import CFM, DiT

  def get_negative_style_prompt(device, file_path):
+     vocal_stlye = np.load(file_path)
+
+     vocal_stlye = torch.from_numpy(vocal_stlye).to(device)  # [1, 512]
+     vocal_stlye = vocal_stlye.half()
+
+     return vocal_stlye

+ def normalize_audio(audio, normalize_lufs=True):
      audio = audio - audio.mean(-1, keepdim=True)
      audio = audio / (audio.abs().max(-1, keepdim=True).values + 1e-8)
+     if normalize_lufs:
+         meter = pyln.Meter(rate=44100)
+         target_lufs = -14.0
+         loudness = meter.integrated_loudness(audio.transpose(0, 1).numpy())
+         normalised = pyln.normalize.loudness(audio.transpose(0, 1).numpy(), loudness, target_lufs)
+         normalised = torch.from_numpy(normalised).transpose(0, 1)
+     else:
+         normalised = audio
+     return normalised

+ class FilteredTestSetDataset(Dataset):
+     """Custom dataset for loading from filtered test set JSON"""
+     def __init__(self, test_set_path, diffusion_dataset, muq_model, num_samples=None, random_crop_style=False, num_style_secs=30, use_prompt_style=False):
+         with open(test_set_path, 'r') as f:
+             self.test_samples = json.load(f)

+         if num_samples is not None:
+             self.test_samples = self.test_samples[:num_samples]
+
+         self.diffusion_dataset = diffusion_dataset
+         self.muq_model = muq_model
+         self.random_crop_style = random_crop_style
+         self.num_style_secs = num_style_secs
+         self.use_prompt_style = use_prompt_style
+         if self.use_prompt_style:
+             print("Using prompt style instead of audio style.")
+
+     def __len__(self):
+         return len(self.test_samples)
+
+     def __getitem__(self, idx):
+         test_sample = self.test_samples[idx]
+         sample_id = test_sample["id"]

+         # Load LRC data
+         lrc_path = test_sample["lrc_path"]
+         with open(lrc_path, 'r') as f:
+             lrc_data = json.load(f)
+         if 'word' not in lrc_data:
+             data = {'word': lrc_data}
+             lrc_data = data

+         # Generate style embedding from original audio on-the-fly
+         audio_path = test_sample["audio_path"]
+         if self.use_prompt_style:
+             prompt_path = test_sample["prompt_path"]
+             prompt = open(prompt_path, 'r').read()
+             if len(prompt) > 300:
+                 print(f"Sample {sample_id} has prompt length {len(prompt)}")
+                 prompt = prompt[:300]
+             print(prompt)
+             style_embedding = self.muq_model(texts=[prompt]).squeeze(0)
+         else:
+             style_embedding = self.generate_style_embedding(audio_path)

+         duration = test_sample["duration"]

+         # Create fake latent with correct length
+         # Assuming frame_rate from config (typically 21.5 fps for 44.1kHz)
+         frame_rate = 21.5
+         num_frames = int(duration * frame_rate)
+         fake_latent = torch.randn(128, num_frames)  # 128 is latent dim

+         # Create sample tuple matching DiffusionWebDataset format
+         fake_sample = (
+             sample_id,
+             fake_latent,  # latent with correct duration
+             style_embedding,  # style from actual audio
+             lrc_data  # actual LRC data
+         )

+         # Process through DiffusionWebDataset's process_sample_safely
+         processed_sample = self.diffusion_dataset.process_sample_safely(fake_sample)

+         # Add metadata
+         if processed_sample is not None:
+             processed_sample['test_metadata'] = {
+                 'sample_id': sample_id,
+                 'audio_path': audio_path,
+                 'lrc_path': lrc_path,
+                 'duration': duration,
+                 'num_frames': num_frames
+             }
+
+         return processed_sample
+
+     def generate_style_embedding(self, audio_path):
+         """Generate style embedding using MuQ model on the whole music"""
+         # Load audio
          waveform, sample_rate = torchaudio.load(audio_path)
+
+         # Resample to 24kHz if needed (MuQ expects 24kHz)
          if sample_rate != 24000:
              resampler = torchaudio.transforms.Resample(sample_rate, 24000)
              waveform = resampler(waveform)
+
+         # Convert to mono if stereo
          if waveform.shape[0] > 1:
              waveform = waveform.mean(dim=0, keepdim=True)

+         # Ensure waveform is 2D (channels, time) - squeeze out channel dim for mono
+         waveform = waveform.squeeze(0)  # Now shape is (time,)
+
+         # Move to same device as model
+         waveform = waveform.to(self.muq_model.device)

+         # Generate embedding using MuQ model
          with torch.inference_mode():
+             # MuQ expects batch dimension and 1D audio, returns (batch, embedding_dim)
+             if self.random_crop_style:
+                 # Randomly crop 30 seconds from the waveform
+                 total_samples = waveform.shape[0]
+                 target_samples = 24000 * self.num_style_secs  # 30 seconds at 24kHz
+
+                 start_idx = random.randint(0, total_samples - target_samples)
+                 style_embedding = self.muq_model(wavs=waveform.unsqueeze(0)[..., start_idx:start_idx + target_samples])
+             else:
+                 style_embedding = self.muq_model(wavs=waveform.unsqueeze(0)[..., :24000 * self.num_style_secs])
+
+         # Keep shape as (embedding_dim,) not scalar
          return style_embedding[0]


+ def custom_collate_fn_with_metadata(batch, base_collate_fn):
+     """Custom collate function that preserves test_metadata"""
+     # Filter out None samples
+     batch = [item for item in batch if item is not None]
+     if not batch:
+         return None
+
+     # Extract test_metadata before collating
+     test_metadata = [item.pop('test_metadata') for item in batch]
+
+     # Use base collate function for the rest
+     collated = base_collate_fn(batch)
+
+     # Add test_metadata back
+     if collated is not None:
+         collated['test_metadata'] = test_metadata
+
+     return collated


+ def load_model(model_config, checkpoint_path, device):
+     """
+     Load JAM CFM model from checkpoint (follows infer.py pattern)
+     """
+     # Build CFM model from config
+     dit_config = model_config["dit"].copy()
+     # Add text_num_embeds if not specified - should be at least 64 for phoneme tokens
+     if "text_num_embeds" not in dit_config:
+         dit_config["text_num_embeds"] = 256  # Default value from DiT
+
+     cfm = CFM(
+         transformer=DiT(**dit_config),
+         **model_config["cfm"]
+     )
+     cfm = cfm.to(device)
+
+     # Load checkpoint - use the path from config
+     checkpoint = load_file(checkpoint_path)
+     cfm.load_state_dict(checkpoint, strict=False)
+
+     return cfm.eval()
+
+
+ def generate_latent(model, batch, sample_kwargs, negative_style_prompt_path=None, ignore_style=False, device='cuda'):
+     """
+     Generate latent from batch data (follows infer.py pattern)
+     """
+     with torch.inference_mode():
+         batch_size = len(batch["lrc"])
+         text = batch["lrc"].to(device)
+         style_prompt = batch["prompt"].to(device)
+         start_time = batch["start_time"].to(device)
+         duration_abs = batch["duration_abs"].to(device)
+         duration_rel = batch["duration_rel"].to(device)

+         # Create zero conditioning latent
+         # Handle case where model might be wrapped by accelerator
+         max_frames = model.max_frames
+         cond = torch.zeros(batch_size, max_frames, 64).to(text.device)
+         pred_frames = [(0, max_frames)]
+
+         default_sample_kwargs = {
+             "cfg_strength": 4,
+             "steps": 50,
+             "batch_infer_num": 1
+         }
+         sample_kwargs = {**default_sample_kwargs, **sample_kwargs}

+         if negative_style_prompt_path is None:
+             negative_style_prompt_path = 'public_checkpoints/vocal.npy'
+             negative_style_prompt = get_negative_style_prompt(text.device, negative_style_prompt_path)
+         elif negative_style_prompt_path == 'zeros':
+             negative_style_prompt = torch.zeros(1, 512).to(text.device)
+         else:
+             negative_style_prompt = get_negative_style_prompt(text.device, negative_style_prompt_path)
+
+         negative_style_prompt = negative_style_prompt.repeat(batch_size, 1)

+         latents, _ = model.sample(
+             cond=cond,
+             text=text,
+             style_prompt=negative_style_prompt if ignore_style else style_prompt,
+             duration_abs=duration_abs,
+             duration_rel=duration_rel,
+             negative_style_prompt=negative_style_prompt,
+             start_time=start_time,
+             latent_pred_segments=pred_frames,
+             **sample_kwargs
+         )

+         return latents


+ class Jamify:
+     def __init__(self):
+         os.makedirs('outputs', exist_ok=True)

+         device = 'cuda'
+         config_path = 'jam_infer.yaml'
+         self.config = OmegaConf.load(config_path)
+         OmegaConf.resolve(self.config)
+
+         # Override output directory for evaluation
+         print("Downloading main model checkpoint...")
+         model_repo_path = snapshot_download(repo_id="declare-lab/jam-0.5")
+         self.config.evaluation.checkpoint_path = os.path.join(model_repo_path, "jam-0_5.safetensors")

+         # Load VAE based on configuration
+         vae_type = self.config.evaluation.get('vae_type', 'stable_audio')
+         if vae_type == 'diffrhythm':
+             vae = DiffRhythmVAE(device=device).to(device)
+         else:
+             vae = StableAudioOpenVAE().to(device)

+         self.vae = vae
+         self.vae_type = vae_type
+         self.cfm_model = load_model(self.config.model, self.config.evaluation.checkpoint_path, device)
+         self.muq_model = MuQMuLan.from_pretrained("OpenMuQ/MuQ-MuLan-large").to(device).eval()
+
+         dataset_cfg = OmegaConf.merge(self.config.data.train_dataset, self.config.evaluation.dataset)
+         enhance_webdataset_config(dataset_cfg)
+         # Override multiple_styles to False since we're generating single style embeddings
+         dataset_cfg.multiple_styles = False
+         self.base_dataset = DiffusionWebDataset(**dataset_cfg)
+
+     def cleanup_old_files(self, sample_id):
          # Clean up old generated files (keep only last 5 files)
+         old_mp3_files = sorted(glob.glob("outputs/*.mp3"))
+         if len(old_mp3_files) >= 10:
+             for old_file in old_mp3_files[:-9]:  # Keep last 4, delete older ones
                  try:
                      os.remove(old_file)
                      print(f"Cleaned up old file: {old_file}")
                  except OSError:
                      pass
+         os.unlink(f"outputs/{sample_id}.json")
+
+     def predict(self, reference_audio_path, lyrics_json_path, style_prompt, duration):
+         sample_id = str(int(time.time() * 1000000))  # microsecond timestamp for uniqueness
+         test_set = [{
+             "id": sample_id,
+             "audio_path": reference_audio_path,
+             "lrc_path": lyrics_json_path,
+             "duration": duration,
+             "prompt_path": style_prompt
+         }]
+         json.dump(test_set, open(f"outputs/{sample_id}.json", "w"))
+
+         # Create filtered test set dataset
+         test_dataset = FilteredTestSetDataset(
+             test_set_path=f"outputs/{sample_id}.json",
+             diffusion_dataset=self.base_dataset,
+             muq_model=self.muq_model,
+             num_samples=1,
+             random_crop_style=self.config.evaluation.random_crop_style,
+             num_style_secs=self.config.evaluation.num_style_secs,
+             use_prompt_style=self.config.evaluation.use_prompt_style
+         )
+
+         # Create dataloader with custom collate function
+         dataloader = DataLoader(
+             test_dataset,
+             batch_size=1,
+             shuffle=False,
+             collate_fn=lambda batch: custom_collate_fn_with_metadata(batch, self.base_dataset.custom_collate_fn)
+         )

+         batch = next(iter(dataloader))
+         sample_kwargs = self.config.evaluation.sample_kwargs
+         latent = generate_latent(self.cfm_model, batch, sample_kwargs, self.config.evaluation.negative_style_prompt, self.config.evaluation.ignore_style)[0][0]

+         test_metadata = batch['test_metadata'][0]
+         sample_id = test_metadata['sample_id']
+         original_duration = test_metadata['duration']
+
+         # Decode audio
+         latent_for_vae = latent.transpose(0, 1).unsqueeze(0)
+
+         # Use chunked decoding if configured (only for DiffRhythm VAE)
+         use_chunked = self.config.evaluation.get('use_chunked_decoding', True)
+         if self.vae_type == 'diffrhythm' and use_chunked:
+             pred_audio = self.vae.decode(
+                 latent_for_vae,
+                 chunked=True,
+                 overlap=self.config.evaluation.get('chunked_overlap', 32),
+                 chunk_size=self.config.evaluation.get('chunked_size', 128)
+             ).sample.squeeze(0).detach().cpu()
+         else:
+             pred_audio = self.vae.decode(latent_for_vae).sample.squeeze(0).detach().cpu()
+
+         pred_audio = normalize_audio(pred_audio)
+         sample_rate = 44100
+         trim_samples = int(original_duration * sample_rate)
+         if pred_audio.shape[1] > trim_samples:
+             pred_audio_trimmed = pred_audio[:, :trim_samples]
+         else:
+             pred_audio_trimmed = pred_audio
+
+         output_path = f'outputs/{sample_id}.mp3'
+         torchaudio.save(output_path, pred_audio_trimmed, sample_rate, format="mp3")
+         self.cleanup_old_files(sample_id)
          return output_path
+
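For reference, a minimal sketch of how the rewritten class is driven from app.py after this commit; the paths are placeholders, and it assumes a CUDA device, the jam and muq packages, and jam_infer.yaml plus the public/ assets resolvable from the working directory:

    from model import Jamify

    jam = Jamify()  # downloads declare-lab/jam-0.5 and loads the VAE, CFM+DiT and MuQ models

    output_path = jam.predict(
        reference_audio_path="reference.wav",  # style embedding is computed from this audio
        lyrics_json_path="lyrics.json",        # word-timing JSON (same shape as the deleted gt0.json)
        style_prompt=None,                     # only read, as a prompt file path, when use_prompt_style is true
        duration=30,                           # seconds; keyword renamed from duration_sec in this commit
    )
    print(output_path)  # outputs/<timestamp>.mp3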