tumuyan2 commited on
Commit
a1b158c
·
1 Parent(s): 62053b6
Files changed (2) hide show
  1. app.py +37 -52
  2. pth2onnx.py +1 -1
app.py CHANGED
@@ -10,19 +10,8 @@ from urllib.parse import urlparse, unquote
10
  import gdown
11
  import sys
12
  from typing import Optional
 
13
 
14
- # def format_bytes(size: int) -> str:
15
- # """将字节数格式化为更易读的单位 (KB, MB, GB)"""
16
- # if size is None:
17
- # return "未知大小"
18
- # power = 1024
19
- # n = 0
20
- # power_labels = {0: '', 1: 'K', 2: 'M', 3: 'G', 4: 'T'}
21
- # while size > power and n < len(power_labels) -1 :
22
- # size /= power
23
- # n += 1
24
- # return f"{size:.2f} {power_labels[n]}B"
25
- # 日志开关
26
  log_to_terminal = True
27
  task_counter = 0
28
  download_cache = {} # 格式: {url: 文件路径}
@@ -44,8 +33,6 @@ def download_file(url: str, save_path: str):
44
  with open(save_path, 'wb') as f:
45
  f.write(response.content)
46
 
47
-
48
-
49
  def download_gdrive_file(
50
  url: str,
51
  folder: str,
@@ -65,57 +52,53 @@ def download_gdrive_file(
65
  Optional[str]: 如果下载成功,返回文件的完整路径;否则返回 None。
66
  """
67
  print(f"--- 开始处理链接: {url} ---")
68
-
69
- # 1. 获取文件信息
70
- try:
71
- info = gdown.get_file_info(url)
72
- if not info:
73
- print(f"错误:无法从链接获取文件信息。请检查链接是否有效且文件是否已公开分享。", file=sys.stderr)
74
- return None
75
-
76
- filename = info.get('name')
77
- filesize = info.get('size')
78
 
79
- if not filename or filesize is None:
80
- print(f"错误:无法获取完整的文件名或文件大小。元数据: {info}", file=sys.stderr)
81
- return None
82
-
83
- except Exception as e:
84
- print(f"错误:获取文件信息时发生异常: {e}", file=sys.stderr)
85
- return None
86
-
87
- # 2. 验证文件大小
88
- if filesize < filesize_min or filesize > filesize_max:
89
- return None
90
-
91
 
92
  # 3. 准备下载路径并创建文件夹
93
  try:
94
  os.makedirs(folder, exist_ok=True)
95
- output_path = os.path.join(folder, filename)
96
  except OSError as e:
97
  print(f"错误:创建文件夹 '{folder}' 失败: {e}", file=sys.stderr)
98
  return None
99
 
100
  # 4. 下载文件
101
  try:
102
- print(f"开始下载google drive文件到: {output_path}")
103
- downloaded_path = gdown.download(url, output_path, quiet=True)
104
- if downloaded_path and os.path.exists(downloaded_path):
105
- print(f"下载成功!文件已保存至: {downloaded_path}")
106
- return downloaded_path
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
107
  else:
108
- print("错误:gdown 下载过程未返回有效路径或文件下载后不存在。", file=sys.stderr)
109
  return None
110
-
111
  except Exception as e:
112
- print(f"错误:下载过程中发生异常: {e}", file=sys.stderr)
113
- # 如果下载中断,清理可能不完整的文件
114
- if os.path.exists(output_path):
115
- os.remove(output_path)
116
  return None
117
-
118
-
 
 
 
119
 
120
  def download_file2folder(url: str, folder: str, filesize_max: int, filesize_min: int) -> Optional[str]:
121
  global download_cache
@@ -287,7 +270,9 @@ with gr.Blocks() as demo:
287
  async def process_model(input_type, url_input, file_input, tilesize, fp16, try_run):
288
  global task_counter
289
  task_counter += 1
290
- output_dir = os.path.join(os.getcwd(), f"output_{task_counter}")
 
 
291
  os.makedirs(output_dir, exist_ok=True)
292
  log=""
293
 
@@ -296,7 +281,7 @@ with gr.Blocks() as demo:
296
  print_log(task_counter, f'正在下载模型文件: {url_input}', '开始')
297
  yield None, None, log
298
 
299
- if url_input.startswith("https://drive.google.com/file/"):
300
  model_input = download_gdrive_file(
301
  url=url_input,
302
  folder=output_dir,
 
10
  import gdown
11
  import sys
12
  from typing import Optional
13
+ import datetime
14
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  log_to_terminal = True
16
  task_counter = 0
17
  download_cache = {} # 格式: {url: 文件路径}
 
33
  with open(save_path, 'wb') as f:
34
  f.write(response.content)
35
 
 
 
36
  def download_gdrive_file(
37
  url: str,
38
  folder: str,
 
52
  Optional[str]: 如果下载成功,返回文件的完整路径;否则返回 None。
53
  """
54
  print(f"--- 开始处理链接: {url} ---")
 
 
 
 
 
 
 
 
 
 
55
 
 
 
 
 
 
 
 
 
 
 
 
 
56
 
57
  # 3. 准备下载路径并创建文件夹
58
  try:
59
  os.makedirs(folder, exist_ok=True)
 
60
  except OSError as e:
61
  print(f"错误:创建文件夹 '{folder}' 失败: {e}", file=sys.stderr)
62
  return None
63
 
64
  # 4. 下载文件
65
  try:
66
+ print(f"开始下载google drive文件到: {folder}")
67
+ # 保存当前工作目录
68
+ original_dir = os.getcwd()
69
+ # 切换到目标文件夹
70
+ os.chdir(folder)
71
+ # 下载文件,不指定output_path参数
72
+ downloaded_filename = gdown.download(url, quiet=True, fuzzy=True)
73
+ # 恢复原始工作目录
74
+ os.chdir(original_dir)
75
+
76
+ # 检查下载结果和文件信息
77
+ if downloaded_filename:
78
+ downloaded_path = os.path.join(folder, downloaded_filename)
79
+ if os.path.exists(downloaded_path):
80
+ # 获取文件大小
81
+ file_size = os.path.getsize(downloaded_path)
82
+ if file_size< filesize_max and file_size> filesize_min:
83
+ print(f"下载成功!文件已保存至: {downloaded_path}")
84
+ return downloaded_path
85
+ else:
86
+ print(f"文件大小超出范围: {file_size} bytes")
87
+ return None
88
+ else:
89
+ print(f"错误:下载的文件 '{downloaded_path}' 不存在。", file=sys.stderr)
90
+ return None
91
  else:
92
+ print("错误:gdown 下载过程未返回有效文件名。", file=sys.stderr)
93
  return None
 
94
  except Exception as e:
95
+ print(f"下载过程中发生错误: {e}", file=sys.stderr)
 
 
 
96
  return None
97
+ finally:
98
+ # 确保工作目录恢复,即使发生异常
99
+ if 'original_dir' in locals():
100
+ os.chdir(original_dir)
101
+
102
 
103
  def download_file2folder(url: str, folder: str, filesize_max: int, filesize_min: int) -> Optional[str]:
104
  global download_cache
 
270
  async def process_model(input_type, url_input, file_input, tilesize, fp16, try_run):
271
  global task_counter
272
  task_counter += 1
273
+
274
+ timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
275
+ output_dir = os.path.join(os.getcwd(), f"output_{task_counter}_{timestamp}")
276
  os.makedirs(output_dir, exist_ok=True)
277
  log=""
278
 
 
281
  print_log(task_counter, f'正在下载模型文件: {url_input}', '开始')
282
  yield None, None, log
283
 
284
+ if url_input.startswith("https://drive.google.com/"):
285
  model_input = download_gdrive_file(
286
  url=url_input,
287
  folder=output_dir,
pth2onnx.py CHANGED
@@ -70,7 +70,7 @@ def convert_pth_to_onnx(pth_path: str, onnx_path: str=None, channel:int=0, tiles
70
  filename = os.path.basename(pth_path).upper()
71
  pattern = f'(^|[_-])({scale}X|X{scale})([_-]|$)'
72
  if re.search(pattern, filename):
73
- print(f'文件名 {filename} 包含匹配模式。')
74
  else:
75
  base_path = f"{base_path}-x{scale}"
76
  # print("final use_fp16", str(use_fp16) )
 
70
  filename = os.path.basename(pth_path).upper()
71
  pattern = f'(^|[_-])({scale}X|X{scale})([_-]|$)'
72
  if re.search(pattern, filename):
73
+ print(f'文件名 {filename} 包含倍率信息。')
74
  else:
75
  base_path = f"{base_path}-x{scale}"
76
  # print("final use_fp16", str(use_fp16) )