1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
|
diff --git a/openunmix/__init__.py b/openunmix/__init__.py
index 3a5324b..84b94fe 100644
--- a/openunmix/__init__.py
+++ b/openunmix/__init__.py
@@ -7,9 +7,15 @@ Please checkout [the open-unmix website](https://sigsep.github.io/open-unmix) fo
"""
from openunmix import utils
+from pathlib import Path
import torch.hub
+def packaged_model_dir(model_name):
+ path = Path('/usr/share') / f'open-unmix-{model_name}-weights'
+ return str(path) if path.exists() else None
+
+
def umxse_spec(targets=None, device="cpu", pretrained=True):
target_urls = {
"speech": "https://zenodo.org/records/3786908/files/speech_f5e0d9f9.pth",
@@ -31,7 +37,8 @@ def umxse_spec(targets=None, device="cpu", pretrained=True):
# enable centering of stft to minimize reconstruction error
if pretrained:
- state_dict = torch.hub.load_state_dict_from_url(target_urls[target], map_location=device)
+ state_dict = torch.hub.load_state_dict_from_url(
+ target_urls[target], model_dir=packaged_model_dir('umxse'), map_location=device)
target_unmix.load_state_dict(state_dict, strict=False)
target_unmix.eval()
@@ -116,7 +123,8 @@ def umxhq_spec(targets=None, device="cpu", pretrained=True):
# enable centering of stft to minimize reconstruction error
if pretrained:
- state_dict = torch.hub.load_state_dict_from_url(target_urls[target], map_location=device)
+ state_dict = torch.hub.load_state_dict_from_url(
+ target_urls[target], model_dir=packaged_model_dir('umxhq'), map_location=device)
target_unmix.load_state_dict(state_dict, strict=False)
target_unmix.eval()
@@ -203,7 +211,8 @@ def umx_spec(targets=None, device="cpu", pretrained=True):
# enable centering of stft to minimize reconstruction error
if pretrained:
- state_dict = torch.hub.load_state_dict_from_url(target_urls[target], map_location=device)
+ state_dict = torch.hub.load_state_dict_from_url(
+ target_urls[target], model_dir=packaged_model_dir('umx'), map_location=device)
target_unmix.load_state_dict(state_dict, strict=False)
target_unmix.eval()
@@ -290,7 +299,8 @@ def umxl_spec(targets=None, device="cpu", pretrained=True):
# enable centering of stft to minimize reconstruction error
if pretrained:
- state_dict = torch.hub.load_state_dict_from_url(target_urls[target], map_location=device)
+ state_dict = torch.hub.load_state_dict_from_url(
+ target_urls[target], model_dir=packaged_model_dir('umxl'), map_location=device)
target_unmix.load_state_dict(state_dict, strict=False)
target_unmix.eval()
|