from setuptools import setup | |
from torch.utils.cpp_extension import BuildExtension, CUDAExtension | |
import os | |
import torch | |
print("Building gscuda")

# Directory holding the extension's C++/CUDA sources. The path is relative, so
# it is resolved against the current working directory: run this script from
# the directory that contains `utils/`.
file_path = "utils/gs_cuda_dmax"

setup(
    name="gscuda",  # distribution name
    ext_modules=[
        CUDAExtension(
            # Extension module name; the built module is importable as `gscuda`.
            name="gscuda",
            sources=[
                os.path.join(file_path, "gswrapper.cpp"),
                os.path.join(file_path, "gs.cu"),
            ],
            # Link-time library search path pointing at PyTorch's bundled `lib`
            # directory. This does NOT set a runtime search path (that would be
            # `runtime_library_dirs` / rpath). NOTE(review): CUDAExtension may
            # already add this directory itself — verify before removing.
            library_dirs=[os.path.join(os.path.dirname(torch.__file__), 'lib')],
        )
    ],
    # BuildExtension supplies the compiler/nvcc handling for the .cu source.
    cmdclass={
        "build_ext": BuildExtension
    },
)