Updated model to use PyTorch instead of ONNX
Browse files- app.py +4 -0
- requirements.txt +2 -0
app.py
CHANGED
@@ -57,6 +57,10 @@ checkpoint_path = "swin_small_patch4_window7_224_512_v1_latest.pt"
|
|
57 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
58 |
model = _load_model(checkpoint_path, device)
|
59 |
|
|
|
|
|
|
|
|
|
60 |
|
61 |
def _get_foreground_estimation(image, alpha):
|
62 |
"""
|
|
|
57 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
58 |
model = _load_model(checkpoint_path, device)
|
59 |
|
60 |
+
print(f"Using device: {device}")
|
61 |
+
if device.type == "cuda":
|
62 |
+
print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
|
63 |
+
|
64 |
|
65 |
def _get_foreground_estimation(image, alpha):
|
66 |
"""
|
requirements.txt
CHANGED
@@ -1,3 +1,5 @@
|
|
|
|
|
|
1 |
gradio
|
2 |
torch
|
3 |
torchvision
|
|
|
1 |
+
--extra-index-url https://download.pytorch.org/whl/cu113
|
2 |
+
|
3 |
gradio
|
4 |
torch
|
5 |
torchvision
|