soumyaprabhamaiti commited on
Commit
8aed5f0
·
0 Parent(s):

Initial commit

Browse files
.github/workflows/check_file_size.yml ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: Check file size
2
+ on: # or directly `on: [push]` to run the action on every push on any branch
3
+ pull_request:
4
+ branches: [main]
5
+
6
+ # to run this workflow manually from the Actions tab
7
+ workflow_dispatch:
8
+
9
+ jobs:
10
+ check-file-size:
11
+ runs-on: ubuntu-latest
12
+ steps:
13
+ - name: Check large files
14
+ uses: ActionsDesk/lfs-warning@v2.0
15
+ with:
16
+ filesizelimit: 10485760 # this is 10MB so we can sync to HF Spaces
.github/workflows/sync_to_HF_hub.yml ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: Sync to Hugging Face hub
2
+ on:
3
+ push:
4
+ branches: [main]
5
+
6
+ # to run this workflow manually from the Actions tab
7
+ workflow_dispatch:
8
+
9
+ jobs:
10
+ sync-to-hub:
11
+ runs-on: ubuntu-latest
12
+ steps:
13
+ - uses: actions/checkout@v3
14
+ with:
15
+ fetch-depth: 0
16
+ lfs: true
17
+ - name: Push to hub
18
+ env:
19
+ HF: ${{ secrets.HF }}
20
+ run: git push --force https://soumyaprabhamaiti:$HF@huggingface.co/spaces/soumyaprabhamaiti/pet-image-segmentation-pytorch main
.github/workflows/sync_to_core_lib_github.yml ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: Sync to Core Library GitHub Repository
2
+ on:
3
+ push:
4
+ branches: [main]
5
+
6
+ # to run this workflow manually from the Actions tab
7
+ workflow_dispatch:
8
+
9
+ jobs:
10
+ sync-to-core-lib-gh:
11
+ runs-on: ubuntu-latest
12
+ steps:
13
+ - uses: actions/checkout@v4
14
+ with:
15
+ fetch-depth: 0
16
+ lfs: true
17
+ - name: Push to core lib repo
18
+ env:
19
+ CORE_LIB_GH_PAT: ${{ secrets.CORE_LIB_GH_PAT }}
20
+ CORE_LIB_REPO_NAME: pet_seg_core
21
+ run: rm -rf .git && cd $CORE_LIB_REPO_NAME && git config --global user.email "73134224+soumya-prabha-maiti@users.noreply.github.com" && git config --global user.name "Github Actions (on behalf of Soumya Prabha Maiti)" && git init --initial-branch=main && git add . && git commit -m "Modify core library" && git remote add origin https://$CORE_LIB_GH_PAT@github.com/soumya-prabha-maiti/$CORE_LIB_REPO_NAME.git && git push -u origin main --force
.gitignore ADDED
@@ -0,0 +1,170 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # poetry
98
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102
+ #poetry.lock
103
+
104
+ # pdm
105
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106
+ #pdm.lock
107
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108
+ # in version control.
109
+ # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
110
+ .pdm.toml
111
+ .pdm-python
112
+ .pdm-build/
113
+
114
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
115
+ __pypackages__/
116
+
117
+ # Celery stuff
118
+ celerybeat-schedule
119
+ celerybeat.pid
120
+
121
+ # SageMath parsed files
122
+ *.sage.py
123
+
124
+ # Environments
125
+ .env
126
+ .venv
127
+ env/
128
+ venv/
129
+ ENV/
130
+ env.bak/
131
+ venv.bak/
132
+
133
+ # Spyder project settings
134
+ .spyderproject
135
+ .spyproject
136
+
137
+ # Rope project settings
138
+ .ropeproject
139
+
140
+ # mkdocs documentation
141
+ /site
142
+
143
+ # mypy
144
+ .mypy_cache/
145
+ .dmypy.json
146
+ dmypy.json
147
+
148
+ # Pyre type checker
149
+ .pyre/
150
+
151
+ # pytype static type analyzer
152
+ .pytype/
153
+
154
+ # Cython debug symbols
155
+ cython_debug/
156
+
157
+ # PyCharm
158
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
159
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
160
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
161
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
162
+ #.idea/
163
+
164
+ data/
165
+ logs/
166
+ lightning_logs/
167
+ results/
168
+ *.ipynb
169
+ flagged/
170
+ *.ckpt
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2024 Soumya Prabha Maiti
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Pet Image Segmentation using PyTorch
3
+ emoji: 🌖
4
+ colorFrom: blue
5
+ colorTo: pink
6
+ sdk: gradio
7
+ sdk_version: 5.4.0
8
+ app_file: run_webapp.py
9
+ pinned: false
10
+ license: mit
11
+ ---
example_images/img1.jpg ADDED
example_images/img2.jpg ADDED
example_images/img3.jpg ADDED
pet_seg_core/.gitignore ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # poetry
98
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102
+ #poetry.lock
103
+
104
+ # pdm
105
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106
+ #pdm.lock
107
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108
+ # in version control.
109
+ # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
110
+ .pdm.toml
111
+ .pdm-python
112
+ .pdm-build/
113
+
114
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
115
+ __pypackages__/
116
+
117
+ # Celery stuff
118
+ celerybeat-schedule
119
+ celerybeat.pid
120
+
121
+ # SageMath parsed files
122
+ *.sage.py
123
+
124
+ # Environments
125
+ .env
126
+ .venv
127
+ env/
128
+ venv/
129
+ ENV/
130
+ env.bak/
131
+ venv.bak/
132
+
133
+ # Spyder project settings
134
+ .spyderproject
135
+ .spyproject
136
+
137
+ # Rope project settings
138
+ .ropeproject
139
+
140
+ # mkdocs documentation
141
+ /site
142
+
143
+ # mypy
144
+ .mypy_cache/
145
+ .dmypy.json
146
+ dmypy.json
147
+
148
+ # Pyre type checker
149
+ .pyre/
150
+
151
+ # pytype static type analyzer
152
+ .pytype/
153
+
154
+ # Cython debug symbols
155
+ cython_debug/
156
+
157
+ # PyCharm
158
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
159
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
160
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
161
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
162
+ #.idea/
pet_seg_core/README.md ADDED
File without changes
pet_seg_core/config.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ import os
3
+ from dotenv import load_dotenv
4
+
5
+ load_dotenv()
6
+
7
+
8
+ @dataclass
9
+ class PetSegTrainConfig:
10
+ EPOCHS = 5
11
+ BATCH_SIZE = 8
12
+ FAST_DEV_RUN = False
13
+ TOTAL_SAMPLES = 100
14
+ LEARNING_RATE = 1e-3
15
+ TRAIN_VAL_TEST_DATA_PATH = "./data/train_val_test"
16
+ DEPTHWISE_SEP = False
17
+ CHANNELS_LIST = [16, 32, 64, 128, 256]
18
+ DESCRIPTION_TEXT = None
19
+
20
+
21
+ @dataclass
22
+ class PetSegWebappConfig:
23
+ MODEL_WEIGHTS_GDRIVE_FILE_ID = os.environ.get("MODEL_WEIGHTS_GDRIVE_FILE_ID")
24
+ MODEL_WEIGHTS_LOCAL_PATH = os.environ.get(
25
+ "MODEL_WEIGHTS_LOCAL_PATH", "pet-segmentation-pytorch_epoch=4-step=1840.ckpt"
26
+ )
27
+ DOWNLOAD_MODEL_WEIGTHS_FROM_GDRIVE = (
28
+ os.environ.get("DOWNLOAD_MODEL_WEIGTHS_FROM_GDRIVE", "True") == "True"
29
+ )
pet_seg_core/data.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torchvision
3
+ from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
4
+ from torchvision import transforms as T
5
+
6
+ from pet_seg_core.config import PetSegTrainConfig
7
+
8
+ # Define the transforms
9
+ transform = T.Compose(
10
+ [
11
+ T.ToTensor(),
12
+ T.Resize((128, 128), interpolation=T.InterpolationMode.NEAREST),
13
+ ]
14
+ )
15
+
16
+ print(f"Downloading data")
17
+
18
+ # Download the dataset
19
+ train_val_ds = torchvision.datasets.OxfordIIITPet(
20
+ root=PetSegTrainConfig.TRAIN_VAL_TEST_DATA_PATH,
21
+ split="trainval",
22
+ target_types="segmentation",
23
+ transform=transform,
24
+ target_transform=transform,
25
+ download=True,
26
+ )
27
+
28
+ print(f"Downloaded data")
29
+
30
+ # Randomly sample some samples
31
+ if PetSegTrainConfig.TOTAL_SAMPLES > 0:
32
+ train_val_ds = torch.utils.data.Subset(
33
+ train_val_ds, torch.randperm(len(train_val_ds))[:PetSegTrainConfig.TOTAL_SAMPLES]
34
+ )
35
+
36
+ # Split the dataset into train val and test
37
+ train_ds, val_ds = torch.utils.data.random_split(
38
+ train_val_ds,
39
+ [int(0.8 * len(train_val_ds)), len(train_val_ds) - int(0.8 * len(train_val_ds))],
40
+ )
41
+
42
+ test_ds, val_ds = torch.utils.data.random_split(
43
+ val_ds,
44
+ [int(0.5 * len(val_ds)), len(val_ds) - int(0.5 * len(val_ds))],
45
+ )
46
+
47
+ train_dataloader = DataLoader(
48
+ train_ds, # The training samples.
49
+ sampler=RandomSampler(train_ds), # Select batches randomly
50
+ batch_size=PetSegTrainConfig.BATCH_SIZE, # Trains with this batch size.
51
+ num_workers=3,
52
+ persistent_workers=True,
53
+ )
54
+
55
+ # For validation the order doesn't matter, so we'll just read them sequentially.
56
+ val_dataloader = DataLoader(
57
+ val_ds, # The validation samples.
58
+ sampler=SequentialSampler(val_ds), # Pull out batches sequentially.
59
+ batch_size=PetSegTrainConfig.BATCH_SIZE, # Evaluate with this batch size.
60
+ num_workers=3,
61
+ persistent_workers=True,
62
+ )
63
+
64
+ # For validation the order doesn't matter, so we'll just read them sequentially.
65
+ test_dataloader = DataLoader(
66
+ test_ds, # The validation samples.
67
+ sampler = SequentialSampler(test_ds), # Pull out batches sequentially.
68
+ batch_size = PetSegTrainConfig.BATCH_SIZE, # Evaluate with this batch size.
69
+ num_workers=3,
70
+ persistent_workers=True,
71
+ )
pet_seg_core/gdrive_utils.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import datetime
2
+ import io
3
+ import json
4
+ import os
5
+
6
+ from google.oauth2 import service_account
7
+ from googleapiclient.discovery import build
8
+ from googleapiclient.http import MediaIoBaseDownload
9
+
10
+
11
+ class GDriveUtils:
12
+ LOG_EVENTS = True
13
+
14
+ @staticmethod
15
+ def get_gdrive_service(creds_stringified: str | None = None):
16
+ SCOPES = ["https://www.googleapis.com/auth/drive"]
17
+ if not creds_stringified:
18
+ print(
19
+ "Attempting to use google drive creds from environment variable"
20
+ ) if GDriveUtils.LOG_EVENTS else None
21
+ creds_stringified = os.getenv("GOOGLE_SERVICE_ACC_CREDS")
22
+ creds_dict = json.loads(creds_stringified)
23
+ creds = service_account.Credentials.from_service_account_info(
24
+ creds_dict, scopes=SCOPES
25
+ )
26
+ return build("drive", "v3", credentials=creds)
27
+
28
+ @staticmethod
29
+ def upload_file_to_gdrive(
30
+ local_file_path,
31
+ drive_parent_folder_id: str,
32
+ drive_filename: str | None = None,
33
+ creds_stringified: str | None = None,
34
+ ) -> str:
35
+ service = GDriveUtils.get_gdrive_service(creds_stringified)
36
+
37
+ if not drive_filename:
38
+ drive_filename = os.path.basename(local_file_path)
39
+
40
+ file_metadata = {
41
+ "name": drive_filename,
42
+ "parents": [drive_parent_folder_id],
43
+ }
44
+ file = (
45
+ service.files()
46
+ .create(body=file_metadata, media_body=local_file_path)
47
+ .execute()
48
+ )
49
+ print(
50
+ "File uploaded, drive file id: ", file.get("id")
51
+ ) if GDriveUtils.LOG_EVENTS else None
52
+ return file.get("id")
53
+
54
+ @staticmethod
55
+ def upload_file_to_gdrive_sanity_check(
56
+ drive_parent_folder_id: str,
57
+ creds_stringified: str | None = None,
58
+ ):
59
+ try:
60
+ curr_time_utc = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
61
+ file_name = f"gdrive_upload_test_{curr_time_utc}_UTC.txt"
62
+ print(
63
+ "Creating local file to upload: ", file_name
64
+ ) if GDriveUtils.LOG_EVENTS else None
65
+ with open(file_name, "w") as f:
66
+ f.write(f"gdrive_upload_test_{curr_time_utc}_UTC")
67
+ return GDriveUtils.upload_file_to_gdrive(
68
+ file_name, drive_parent_folder_id, creds_stringified=creds_stringified
69
+ )
70
+ except Exception as e:
71
+ raise e
72
+ finally:
73
+ if os.path.exists(file_name):
74
+ print(
75
+ "Deleting local file: ", file_name
76
+ ) if GDriveUtils.LOG_EVENTS else None
77
+ os.remove(file_name)
78
+
79
+ @staticmethod
80
+ def download_file_from_gdrive(
81
+ drive_file_id: str,
82
+ local_file_path: str | None = None,
83
+ creds_stringified: str | None = None,
84
+ ):
85
+ service = GDriveUtils.get_gdrive_service(creds_stringified)
86
+
87
+ drive_filename = service.files().get(fileId=drive_file_id, fields="name").execute().get('name')
88
+
89
+ if not local_file_path:
90
+ local_file_path = f"{drive_file_id}_{drive_filename}"
91
+
92
+ request = service.files().get_media(fileId=drive_file_id)
93
+ file = io.BytesIO()
94
+ downloader = MediaIoBaseDownload(file, request, chunksize= 25 * 1024 * 1024)
95
+ done = False
96
+ while done is False:
97
+ status, done = downloader.next_chunk()
98
+ print(f"Downloading gdrive file {drive_filename} to local file {local_file_path}: {int(status.progress() * 100)}%.") if GDriveUtils.LOG_EVENTS else None
99
+
100
+ if os.path.dirname(local_file_path):
101
+ os.makedirs(os.path.dirname(local_file_path), exist_ok=True)
102
+ with open(local_file_path, "wb") as f:
103
+ f.write(file.getvalue())
104
+ print(
105
+ "Downloaded file locally to: ", local_file_path
106
+ ) if GDriveUtils.LOG_EVENTS else None
107
+
108
+ @staticmethod
109
+ def download_file_from_gdrive_sanity_check(
110
+ drive_parent_folder_id: str,
111
+ creds_stringified: str | None = None,
112
+ ):
113
+ file_id = GDriveUtils.upload_file_to_gdrive_sanity_check(
114
+ drive_parent_folder_id, creds_stringified
115
+ )
116
+ GDriveUtils.download_file_from_gdrive(
117
+ file_id, creds_stringified=creds_stringified
118
+ )
119
+
120
+ def stringify_json_creds(json_file: str, txt_file: str) -> str:
121
+ with open(json_file, "r") as f:
122
+ creds_dict = json.load(f)
123
+ with open(txt_file, "w") as f:
124
+ f.write(json.dumps(creds_dict))
pet_seg_core/model.py ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import lightning as pl
2
+ import torch
3
+ import torchvision.transforms.functional as TF
4
+ from torch import nn
5
+ from torchmetrics.functional.segmentation import mean_iou
6
+ from torchmetrics.classification import MulticlassConfusionMatrix
7
+ from pet_seg_core.config import PetSegTrainConfig
8
+ from functools import partial
9
+
10
+
11
+ class DoubleConvOriginal(nn.Module):
12
+ def __init__(self, in_channels, out_channels):
13
+ super(DoubleConvOriginal, self).__init__()
14
+ self.double_conv = nn.Sequential(
15
+ nn.Conv2d(
16
+ in_channels,
17
+ out_channels,
18
+ kernel_size=3,
19
+ stride=1,
20
+ padding=1,
21
+ bias=False,
22
+ ),
23
+ nn.BatchNorm2d(out_channels),
24
+ nn.ReLU(inplace=True),
25
+ nn.Conv2d(
26
+ out_channels,
27
+ out_channels,
28
+ kernel_size=3,
29
+ stride=1,
30
+ padding=1,
31
+ bias=False,
32
+ ),
33
+ nn.BatchNorm2d(out_channels),
34
+ nn.ReLU(inplace=True),
35
+ )
36
+
37
+ def forward(self, x):
38
+ return self.double_conv(x)
39
+
40
+ class DoubleConvDepthwiseSep(nn.Module):
41
+ def __init__(self, in_channels, out_channels):
42
+ super(DoubleConvDepthwiseSep, self).__init__()
43
+ self.double_conv = nn.Sequential(
44
+ nn.Conv2d(
45
+ in_channels,
46
+ in_channels,
47
+ kernel_size=3,
48
+ stride=1,
49
+ padding=1,
50
+ groups=in_channels,
51
+ bias=False,
52
+ ),
53
+ nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False),
54
+ nn.BatchNorm2d(out_channels),
55
+ nn.ReLU(inplace=True),
56
+ nn.Conv2d(
57
+ out_channels,
58
+ out_channels,
59
+ kernel_size=3,
60
+ stride=1,
61
+ padding=1,
62
+ groups=out_channels,
63
+ bias=False,
64
+ ),
65
+ nn.Conv2d(out_channels, out_channels, kernel_size=1, stride=1, bias=False),
66
+ nn.BatchNorm2d(out_channels),
67
+ nn.ReLU(inplace=True),
68
+ )
69
+
70
+ def forward(self, x):
71
+ return self.double_conv(x)
72
+
73
+
74
+ class UNet(pl.LightningModule):
75
+ def __init__(
76
+ self,
77
+ in_channels,
78
+ out_channels,
79
+ channels_list=[64, 128, 256, 512],
80
+ depthwise_sep=False,
81
+ ):
82
+ super(UNet, self).__init__()
83
+ self.save_hyperparameters()
84
+
85
+ self.in_channels = in_channels
86
+ self.out_channels = out_channels
87
+
88
+ self.encoder = nn.ModuleList()
89
+ self.decoder = nn.ModuleList()
90
+ self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
91
+
92
+ if depthwise_sep:
93
+ DoubleConv = DoubleConvDepthwiseSep
94
+ else:
95
+ DoubleConv = DoubleConvOriginal
96
+
97
+ # Encoder
98
+ for channels in channels_list:
99
+ self.encoder.append(DoubleConv(in_channels, channels))
100
+ in_channels = channels
101
+
102
+ # Decoder
103
+ for channels in channels_list[::-1]:
104
+ self.decoder.append(
105
+ nn.ConvTranspose2d(channels * 2, channels, kernel_size=2, stride=2)
106
+ )
107
+ self.decoder.append(DoubleConv(channels * 2, channels))
108
+
109
+ self.bottleneck = DoubleConv(channels_list[-1], channels_list[-1] * 2)
110
+ self.out = nn.Conv2d(channels_list[0], out_channels, kernel_size=1)
111
+
112
+ self.loss_fn = nn.CrossEntropyLoss()
113
+
114
+ self.iou = partial(mean_iou, num_classes=out_channels)
115
+ self.conf_mat = MulticlassConfusionMatrix(num_classes=out_channels)
116
+
117
+ def forward(self, x):
118
+ skip_connections = []
119
+ for i, enc_block in enumerate(self.encoder):
120
+ x = enc_block(x)
121
+ skip_connections.append(x)
122
+ x = self.pool(x)
123
+
124
+ x = self.bottleneck(x)
125
+
126
+ skip_connections = skip_connections[::-1]
127
+
128
+ for i in range(0, len(self.decoder), 2):
129
+ x = self.decoder[i](x)
130
+ skip_connection = skip_connections[i // 2]
131
+
132
+ if x.shape != skip_connection.shape:
133
+ x = TF.resize(x, size=skip_connection.shape[2:])
134
+
135
+ concat_skip = torch.cat(
136
+ (skip_connection, x), dim=1
137
+ ) # Concatenate along the channel dimension
138
+ x = self.decoder[i + 1](concat_skip)
139
+
140
+ x = self.out(x)
141
+
142
+ return x
143
+
144
+ def _common_step(self, batch, batch_idx, prefix):
145
+ x, y = batch
146
+ y = (y * 255 - 1).long().squeeze(1) # move to dataloader
147
+ y_hat = self(x)
148
+
149
+ loss = self.loss_fn(y_hat, y)
150
+ self.log(f"{prefix}_loss", loss.item(), prog_bar=True)
151
+
152
+ y_hat_argmax = torch.argmax(y_hat, dim=1)
153
+
154
+ y_hat_argmax_onehot = torch.nn.functional.one_hot(y_hat_argmax, num_classes=self.out_channels).permute(0, 3, 1, 2)
155
+ y_onehot = torch.nn.functional.one_hot(y, num_classes=self.out_channels).permute(0, 3, 1, 2)
156
+ iou = self.iou(y_hat_argmax_onehot, y_onehot)
157
+ # self.log(f"{prefix}_iou", iou.mean().item(), prog_bar=True)
158
+
159
+ self.conf_mat.update(y_hat_argmax, y)
160
+
161
+ return y_hat, loss
162
+
163
+ def training_step(self, batch, batch_idx):
164
+ y_hat, loss = self._common_step(batch, batch_idx, "train")
165
+ return loss
166
+
167
+ def validation_step(self, batch, batch_idx):
168
+ y_hat, loss = self._common_step(batch, batch_idx, "val")
169
+
170
+ def test_step(self, batch, batch_idx):
171
+ y_hat, loss = self._common_step(batch, batch_idx, "test")
172
+
173
+ def _common_on_epoch_end(self, prefix):
174
+ confmat = self.conf_mat.compute()
175
+
176
+ for i in range(self.out_channels):
177
+ for j in range(self.out_channels):
178
+ self.log(f'{prefix}_confmat_true={i}_pred={j}', confmat[i][j].item(), prog_bar=True)
179
+
180
+ iou = torch.zeros(self.out_channels)
181
+ for i in range(self.out_channels):
182
+ true_positive = confmat[i, i]
183
+ false_positive = confmat.sum(dim=0)[i] - true_positive
184
+ false_negative = confmat.sum(dim=1)[i] - true_positive
185
+ union = true_positive + false_positive + false_negative
186
+ if union > 0:
187
+ iou[i] = true_positive / union
188
+ else:
189
+ iou[i] = float('nan')
190
+ self.log(f'{prefix}_iou_class={i}', iou[i].item(), prog_bar=True)
191
+
192
+ self.conf_mat.reset()
193
+
194
+ def on_train_epoch_end(self):
195
+ self._common_on_epoch_end("train")
196
+
197
+ def on_validation_epoch_end(self):
198
+ self._common_on_epoch_end("val")
199
+
200
+ def configure_optimizers(self):
201
+ return torch.optim.Adam(self.parameters(), lr=PetSegTrainConfig.LEARNING_RATE)
pet_seg_core/pytorch_device_utils.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ def get_best_device_pytorch():
4
+ device = None
5
+ if torch.cuda.is_available():
6
+ device = torch.device("cuda")
7
+ print(f'{torch.cuda.device_count()} GPU(s) available. Using the GPU: {torch.cuda.get_device_name(0)}')
8
+ elif torch.backends.mps.is_available():
9
+ device = torch.device("mps")
10
+ print("Using Mac ARM64 GPU")
11
+ else:
12
+ device = torch.device("cpu")
13
+ print('No GPU available, using CPU')
14
+ return device
pet_seg_core/requirements.txt ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ aiofiles==23.2.1
2
+ aiohappyeyeballs==2.4.3
3
+ aiohttp==3.10.10
4
+ aiosignal==1.3.1
5
+ annotated-types==0.7.0
6
+ anyio==4.6.2.post1
7
+ attrs==24.2.0
8
+ cachetools==5.5.0
9
+ certifi==2024.8.30
10
+ charset-normalizer==3.4.0
11
+ click==8.1.7
12
+ colorama==0.4.6
13
+ fastapi==0.115.4
14
+ ffmpy==0.4.0
15
+ filelock==3.16.1
16
+ frozenlist==1.5.0
17
+ fsspec==2024.10.0
18
+ google-api-core==2.22.0
19
+ google-api-python-client==2.149.0
20
+ google-auth==2.35.0
21
+ google-auth-httplib2==0.2.0
22
+ googleapis-common-protos==1.65.0
23
+ gradio==5.4.0
24
+ gradio_client==1.4.2
25
+ h11==0.14.0
26
+ httpcore==1.0.6
27
+ httplib2==0.22.0
28
+ httpx==0.27.2
29
+ huggingface-hub==0.26.2
30
+ idna==3.10
31
+ Jinja2==3.1.4
32
+ lightning==2.4.0
33
+ lightning-utilities==0.11.8
34
+ markdown-it-py==3.0.0
35
+ MarkupSafe==2.1.5
36
+ mdurl==0.1.2
37
+ mpmath==1.3.0
38
+ multidict==6.1.0
39
+ networkx==3.4.2
40
+ numpy==1.26.4
41
+ opencv-python==4.10.0.84
42
+ orjson==3.10.10
43
+ packaging==24.1
44
+ pandas==2.2.3
45
+ pillow==11.0.0
46
+ propcache==0.2.0
47
+ proto-plus==1.25.0
48
+ protobuf==5.28.3
49
+ pyasn1==0.6.1
50
+ pyasn1_modules==0.4.1
51
+ pydantic==2.9.2
52
+ pydantic_core==2.23.4
53
+ pydub==0.25.1
54
+ Pygments==2.18.0
55
+ pyparsing==3.2.0
56
+ python-dateutil==2.9.0.post0
57
+ python-dotenv==1.0.1
58
+ python-multipart==0.0.12
59
+ pytorch-lightning==2.4.0
60
+ pytz==2024.2
61
+ PyYAML==6.0.2
62
+ requests==2.32.3
63
+ rich==13.9.3
64
+ rsa==4.9
65
+ ruff==0.7.1
66
+ safehttpx==0.1.1
67
+ semantic-version==2.10.0
68
+ shellingham==1.5.4
69
+ six==1.16.0
70
+ sniffio==1.3.1
71
+ starlette==0.41.2
72
+ sympy==1.13.1
73
+ tomlkit==0.12.0
74
+ torch==2.5.1
75
+ torchmetrics==1.5.1
76
+ torchvision==0.20.1
77
+ tqdm==4.66.6
78
+ typer==0.12.5
79
+ typing_extensions==4.12.2
80
+ tzdata==2024.2
81
+ uritemplate==4.1.1
82
+ urllib3==2.2.3
83
+ uvicorn==0.32.0
84
+ websockets==12.0
85
+ yarl==1.17.0
pet_seg_core/train.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import lightning as pl
2
+ from lightning.pytorch.callbacks import ModelCheckpoint
3
+ from lightning.pytorch.loggers import CSVLogger
4
+ from datetime import datetime
5
+ import os
6
+
7
+ from pet_seg_core.config import PetSegTrainConfig
8
+ from pet_seg_core.data import train_dataloader, val_dataloader
9
+ from pet_seg_core.model import UNet
10
+
11
+ def train():
12
+ curr_time = datetime.now().strftime('%Y-%m-%d_%H:%M:%S.%f')
13
+ results_folder = f"results/{curr_time}"
14
+ os.mkdir(results_folder)
15
+ with open(f"{results_folder}/description.txt", "w") as f:
16
+ f.write(PetSegTrainConfig.DESCRIPTION_TEXT)
17
+
18
+ logger = CSVLogger(save_dir="", name=results_folder, version="")
19
+ checkpoint_callback = ModelCheckpoint(
20
+ dirpath=results_folder,
21
+ save_top_k=-1,
22
+ )
23
+ trainer = pl.Trainer(
24
+ max_epochs=PetSegTrainConfig.EPOCHS, fast_dev_run=PetSegTrainConfig.FAST_DEV_RUN, logger=logger, callbacks=[checkpoint_callback], gradient_clip_val=1.0
25
+ )
26
+ model = UNet(3, 3, channels_list=PetSegTrainConfig.CHANNELS_LIST, depthwise_sep=PetSegTrainConfig.DEPTHWISE_SEP)
27
+
28
+ trainer.fit(model, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader)
pet_seg_core/webapp.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pet_seg_core.model import UNet
2
+ from pet_seg_core.config import PetSegWebappConfig
3
+ from pet_seg_core.gdrive_utils import GDriveUtils
4
+
5
+ from torchvision import transforms as T
6
+ import torch
7
+ import gradio as gr
8
+ import numpy as np
9
+ import cv2
10
+ from dotenv import load_dotenv
11
+
12
+ load_dotenv()
13
+
14
+ device = torch.device("cpu")
15
+
16
+ if PetSegWebappConfig.DOWNLOAD_MODEL_WEIGTHS_FROM_GDRIVE:
17
+ GDriveUtils.download_file_from_gdrive(
18
+ PetSegWebappConfig.MODEL_WEIGHTS_GDRIVE_FILE_ID, PetSegWebappConfig.MODEL_WEIGHTS_LOCAL_PATH
19
+ )
20
+
21
+ model = UNet.load_from_checkpoint(PetSegWebappConfig.MODEL_WEIGHTS_LOCAL_PATH)
22
+ model.eval()
23
+
24
+ def segment_image(img):
25
+ img = T.ToTensor()(img).unsqueeze(0).to(device)
26
+ mask = model(img)
27
+ mask = torch.argmax(mask, dim = 1).squeeze().detach().cpu().numpy()
28
+ return mask
29
+
30
+ def overlay_mask(img, mask, alpha=0.5):
31
+ # Define color mapping
32
+ colors = {
33
+ 0: [255, 0, 0], # Class 0 - Red
34
+ 1: [0, 255, 0], # Class 1 - Green
35
+ 2: [0, 0, 255] # Class 2 - Blue
36
+ # Add more colors for additional classes if needed
37
+ }
38
+
39
+ # Create a blank colored overlay image
40
+ overlay = np.zeros_like(img)
41
+
42
+ # Map each mask value to the corresponding color
43
+ for class_id, color in colors.items():
44
+ overlay[mask == class_id] = color
45
+
46
+ # Blend the overlay with the original image
47
+ output = cv2.addWeighted(img, 1 - alpha, overlay, alpha, 0)
48
+
49
+ return output
50
+
51
+ def transform(img):
52
+ mask=segment_image(img)
53
+ blended_img = overlay_mask(img, mask)
54
+ return blended_img
55
+
56
+ app = gr.Interface(
57
+ fn=transform,
58
+ inputs=gr.Image(label="Input Image"),
59
+ outputs=gr.Image(label="Image with Segmentation Overlay"),
60
+ title="Image Segmentation on Pet Images",
61
+ description="Segment image of a pet animal into three classes: background, pet, and boundary.",
62
+ examples=[
63
+ "example_images/img1.jpg",
64
+ "example_images/img2.jpg",
65
+ "example_images/img3.jpg"
66
+ ]
67
+ )
requirements.txt ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ aiofiles==23.2.1
2
+ aiohappyeyeballs==2.4.3
3
+ aiohttp==3.10.10
4
+ aiosignal==1.3.1
5
+ annotated-types==0.7.0
6
+ anyio==4.6.2.post1
7
+ attrs==24.2.0
8
+ cachetools==5.5.0
9
+ certifi==2024.8.30
10
+ charset-normalizer==3.4.0
11
+ click==8.1.7
12
+ colorama==0.4.6
13
+ fastapi==0.115.4
14
+ ffmpy==0.4.0
15
+ filelock==3.16.1
16
+ frozenlist==1.5.0
17
+ fsspec==2024.10.0
18
+ google-api-core==2.22.0
19
+ google-api-python-client==2.149.0
20
+ google-auth==2.35.0
21
+ google-auth-httplib2==0.2.0
22
+ googleapis-common-protos==1.65.0
23
+ gradio==5.4.0
24
+ gradio_client==1.4.2
25
+ h11==0.14.0
26
+ httpcore==1.0.6
27
+ httplib2==0.22.0
28
+ httpx==0.27.2
29
+ huggingface-hub==0.26.2
30
+ idna==3.10
31
+ Jinja2==3.1.4
32
+ lightning==2.4.0
33
+ lightning-utilities==0.11.8
34
+ markdown-it-py==3.0.0
35
+ MarkupSafe==2.1.5
36
+ mdurl==0.1.2
37
+ mpmath==1.3.0
38
+ multidict==6.1.0
39
+ networkx==3.4.2
40
+ numpy==1.26.4
41
+ opencv-python==4.10.0.84
42
+ orjson==3.10.10
43
+ packaging==24.1
44
+ pandas==2.2.3
45
+ pillow==11.0.0
46
+ propcache==0.2.0
47
+ proto-plus==1.25.0
48
+ protobuf==5.28.3
49
+ pyasn1==0.6.1
50
+ pyasn1_modules==0.4.1
51
+ pydantic==2.9.2
52
+ pydantic_core==2.23.4
53
+ pydub==0.25.1
54
+ Pygments==2.18.0
55
+ pyparsing==3.2.0
56
+ python-dateutil==2.9.0.post0
57
+ python-dotenv==1.0.1
58
+ python-multipart==0.0.12
59
+ pytorch-lightning==2.4.0
60
+ pytz==2024.2
61
+ PyYAML==6.0.2
62
+ requests==2.32.3
63
+ rich==13.9.3
64
+ rsa==4.9
65
+ ruff==0.7.1
66
+ safehttpx==0.1.1
67
+ semantic-version==2.10.0
68
+ shellingham==1.5.4
69
+ six==1.16.0
70
+ sniffio==1.3.1
71
+ starlette==0.41.2
72
+ sympy==1.13.1
73
+ tomlkit==0.12.0
74
+ torch==2.5.1
75
+ torchmetrics==1.5.1
76
+ torchvision==0.20.1
77
+ tqdm==4.66.6
78
+ typer==0.12.5
79
+ typing_extensions==4.12.2
80
+ tzdata==2024.2
81
+ uritemplate==4.1.1
82
+ urllib3==2.2.3
83
+ uvicorn==0.32.0
84
+ websockets==12.0
85
+ yarl==1.17.0
run_training.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ if __name__ == '__main__':
2
+ from pet_seg_core.config import PetSegTrainConfig
3
+ PetSegTrainConfig.EPOCHS = 5
4
+ PetSegTrainConfig.TOTAL_SAMPLES = -1
5
+ PetSegTrainConfig.DESCRIPTION_TEXT = "UNET with RGB input and 3 channel(3 class) output, trained on all samples for 5 epochs. Will be used for webapp"
6
+ from pet_seg_core.train import train
7
+ train()
run_webapp.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from pet_seg_core.webapp import app
2
+ app.launch()