update
Browse files- app.py +37 -52
- pth2onnx.py +1 -1
app.py
CHANGED
@@ -10,19 +10,8 @@ from urllib.parse import urlparse, unquote
|
|
10 |
import gdown
|
11 |
import sys
|
12 |
from typing import Optional
|
|
|
13 |
|
14 |
-
# def format_bytes(size: int) -> str:
|
15 |
-
# """将字节数格式化为更易读的单位 (KB, MB, GB)"""
|
16 |
-
# if size is None:
|
17 |
-
# return "未知大小"
|
18 |
-
# power = 1024
|
19 |
-
# n = 0
|
20 |
-
# power_labels = {0: '', 1: 'K', 2: 'M', 3: 'G', 4: 'T'}
|
21 |
-
# while size > power and n < len(power_labels) -1 :
|
22 |
-
# size /= power
|
23 |
-
# n += 1
|
24 |
-
# return f"{size:.2f} {power_labels[n]}B"
|
25 |
-
# 日志开关
|
26 |
log_to_terminal = True
|
27 |
task_counter = 0
|
28 |
download_cache = {} # 格式: {url: 文件路径}
|
@@ -44,8 +33,6 @@ def download_file(url: str, save_path: str):
|
|
44 |
with open(save_path, 'wb') as f:
|
45 |
f.write(response.content)
|
46 |
|
47 |
-
|
48 |
-
|
49 |
def download_gdrive_file(
|
50 |
url: str,
|
51 |
folder: str,
|
@@ -65,57 +52,53 @@ def download_gdrive_file(
|
|
65 |
Optional[str]: 如果下载成功,返回文件的完整路径;否则返回 None。
|
66 |
"""
|
67 |
print(f"--- 开始处理链接: {url} ---")
|
68 |
-
|
69 |
-
# 1. 获取文件信息
|
70 |
-
try:
|
71 |
-
info = gdown.get_file_info(url)
|
72 |
-
if not info:
|
73 |
-
print(f"错误:无法从链接获取文件信息。请检查链接是否有效且文件是否已公开分享。", file=sys.stderr)
|
74 |
-
return None
|
75 |
-
|
76 |
-
filename = info.get('name')
|
77 |
-
filesize = info.get('size')
|
78 |
|
79 |
-
if not filename or filesize is None:
|
80 |
-
print(f"错误:无法获取完整的文件名或文件大小。元数据: {info}", file=sys.stderr)
|
81 |
-
return None
|
82 |
-
|
83 |
-
except Exception as e:
|
84 |
-
print(f"错误:获取文件信息时发生异常: {e}", file=sys.stderr)
|
85 |
-
return None
|
86 |
-
|
87 |
-
# 2. 验证文件大小
|
88 |
-
if filesize < filesize_min or filesize > filesize_max:
|
89 |
-
return None
|
90 |
-
|
91 |
|
92 |
# 3. 准备下载路径并创建文件夹
|
93 |
try:
|
94 |
os.makedirs(folder, exist_ok=True)
|
95 |
-
output_path = os.path.join(folder, filename)
|
96 |
except OSError as e:
|
97 |
print(f"错误:创建文件夹 '{folder}' 失败: {e}", file=sys.stderr)
|
98 |
return None
|
99 |
|
100 |
# 4. 下载文件
|
101 |
try:
|
102 |
-
print(f"开始下载google drive文件到: {
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
107 |
else:
|
108 |
-
print("错误:gdown
|
109 |
return None
|
110 |
-
|
111 |
except Exception as e:
|
112 |
-
print(f"
|
113 |
-
# 如果下载中断,清理可能不完整的文件
|
114 |
-
if os.path.exists(output_path):
|
115 |
-
os.remove(output_path)
|
116 |
return None
|
117 |
-
|
118 |
-
|
|
|
|
|
|
|
119 |
|
120 |
def download_file2folder(url: str, folder: str, filesize_max: int, filesize_min: int) -> Optional[str]:
|
121 |
global download_cache
|
@@ -287,7 +270,9 @@ with gr.Blocks() as demo:
|
|
287 |
async def process_model(input_type, url_input, file_input, tilesize, fp16, try_run):
|
288 |
global task_counter
|
289 |
task_counter += 1
|
290 |
-
|
|
|
|
|
291 |
os.makedirs(output_dir, exist_ok=True)
|
292 |
log=""
|
293 |
|
@@ -296,7 +281,7 @@ with gr.Blocks() as demo:
|
|
296 |
print_log(task_counter, f'正在下载模型文件: {url_input}', '开始')
|
297 |
yield None, None, log
|
298 |
|
299 |
-
if url_input.startswith("https://drive.google.com/
|
300 |
model_input = download_gdrive_file(
|
301 |
url=url_input,
|
302 |
folder=output_dir,
|
|
|
10 |
import gdown
|
11 |
import sys
|
12 |
from typing import Optional
|
13 |
+
import datetime
|
14 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
15 |
log_to_terminal = True
|
16 |
task_counter = 0
|
17 |
download_cache = {} # 格式: {url: 文件路径}
|
|
|
33 |
with open(save_path, 'wb') as f:
|
34 |
f.write(response.content)
|
35 |
|
|
|
|
|
36 |
def download_gdrive_file(
|
37 |
url: str,
|
38 |
folder: str,
|
|
|
52 |
Optional[str]: 如果下载成功,返回文件的完整路径;否则返回 None。
|
53 |
"""
|
54 |
print(f"--- 开始处理链接: {url} ---")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
55 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
56 |
|
57 |
# 3. 准备下载路径并创建文件夹
|
58 |
try:
|
59 |
os.makedirs(folder, exist_ok=True)
|
|
|
60 |
except OSError as e:
|
61 |
print(f"错误:创建文件夹 '{folder}' 失败: {e}", file=sys.stderr)
|
62 |
return None
|
63 |
|
64 |
# 4. 下载文件
|
65 |
try:
|
66 |
+
print(f"开始下载google drive文件到: {folder}")
|
67 |
+
# 保存当前工作目录
|
68 |
+
original_dir = os.getcwd()
|
69 |
+
# 切换到目标文件夹
|
70 |
+
os.chdir(folder)
|
71 |
+
# 下载文件,不指定output_path参数
|
72 |
+
downloaded_filename = gdown.download(url, quiet=True, fuzzy=True)
|
73 |
+
# 恢复原始工作目录
|
74 |
+
os.chdir(original_dir)
|
75 |
+
|
76 |
+
# 检查下载结果和文件信息
|
77 |
+
if downloaded_filename:
|
78 |
+
downloaded_path = os.path.join(folder, downloaded_filename)
|
79 |
+
if os.path.exists(downloaded_path):
|
80 |
+
# 获取文件大小
|
81 |
+
file_size = os.path.getsize(downloaded_path)
|
82 |
+
if file_size< filesize_max and file_size> filesize_min:
|
83 |
+
print(f"下载成功!文件已保存至: {downloaded_path}")
|
84 |
+
return downloaded_path
|
85 |
+
else:
|
86 |
+
print(f"文件大小超出范围: {file_size} bytes")
|
87 |
+
return None
|
88 |
+
else:
|
89 |
+
print(f"错误:下载的文件 '{downloaded_path}' 不存在。", file=sys.stderr)
|
90 |
+
return None
|
91 |
else:
|
92 |
+
print("错误:gdown 下载过程未返回有效文件名。", file=sys.stderr)
|
93 |
return None
|
|
|
94 |
except Exception as e:
|
95 |
+
print(f"下载过程中发生错误: {e}", file=sys.stderr)
|
|
|
|
|
|
|
96 |
return None
|
97 |
+
finally:
|
98 |
+
# 确保工作目录恢复,即使发生异常
|
99 |
+
if 'original_dir' in locals():
|
100 |
+
os.chdir(original_dir)
|
101 |
+
|
102 |
|
103 |
def download_file2folder(url: str, folder: str, filesize_max: int, filesize_min: int) -> Optional[str]:
|
104 |
global download_cache
|
|
|
270 |
async def process_model(input_type, url_input, file_input, tilesize, fp16, try_run):
|
271 |
global task_counter
|
272 |
task_counter += 1
|
273 |
+
|
274 |
+
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
|
275 |
+
output_dir = os.path.join(os.getcwd(), f"output_{task_counter}_{timestamp}")
|
276 |
os.makedirs(output_dir, exist_ok=True)
|
277 |
log=""
|
278 |
|
|
|
281 |
print_log(task_counter, f'正在下载模型文件: {url_input}', '开始')
|
282 |
yield None, None, log
|
283 |
|
284 |
+
if url_input.startswith("https://drive.google.com/"):
|
285 |
model_input = download_gdrive_file(
|
286 |
url=url_input,
|
287 |
folder=output_dir,
|
pth2onnx.py
CHANGED
@@ -70,7 +70,7 @@ def convert_pth_to_onnx(pth_path: str, onnx_path: str=None, channel:int=0, tiles
|
|
70 |
filename = os.path.basename(pth_path).upper()
|
71 |
pattern = f'(^|[_-])({scale}X|X{scale})([_-]|$)'
|
72 |
if re.search(pattern, filename):
|
73 |
-
print(f'文件名 {filename}
|
74 |
else:
|
75 |
base_path = f"{base_path}-x{scale}"
|
76 |
# print("final use_fp16", str(use_fp16) )
|
|
|
70 |
filename = os.path.basename(pth_path).upper()
|
71 |
pattern = f'(^|[_-])({scale}X|X{scale})([_-]|$)'
|
72 |
if re.search(pattern, filename):
|
73 |
+
print(f'文件名 {filename} 包含倍率信息。')
|
74 |
else:
|
75 |
base_path = f"{base_path}-x{scale}"
|
76 |
# print("final use_fp16", str(use_fp16) )
|