Svane20 committed
Commit 863ed87 · Parent(s): c2fd7a9

Updated example images

app.py CHANGED
@@ -1,14 +1,15 @@
 import gradio as gr
 import torch
 from torchvision.transforms import Compose, Resize, ToTensor, Normalize
+import onnxruntime as ort
+
+import pymatting
 import numpy as np
 import os
-from timeit import default_timer as timer
 from PIL import Image
-import onnxruntime as ort
-
-from replacements.foreground_estimation import get_foreground_estimation
-from replacements.replacements import sky_replacement
+from typing import Tuple
+import random
+from pathlib import Path
 
 
 def _load_model(checkpoint):
@@ -37,29 +38,157 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 session, input_name, output_name = _load_model(checkpoint_path)
 
 
-def inference(image):
-    output = session.run([output_name], {input_name: image.cpu().numpy()})[0]
+def _get_foreground_estimation(image: np.ndarray, alpha: np.ndarray) -> np.ndarray:
+    """
+    Estimate the foreground using the image and the predicted alpha mask.
+
+    Args:
+        image (np.ndarray): The input image.
+        alpha (np.ndarray): The predicted alpha mask.
+
+    Returns:
+        np.ndarray: The estimated foreground.
+    """
+    # Normalize the image to [0, 1] range
+    normalized_image = np.array(image) / 255.0
+
+    # Invert the alpha mask since the pymatting library expects the sky to be the background
+    inverted_alpha = 1 - alpha
+
+    return pymatting.estimate_foreground_ml(image=normalized_image, alpha=inverted_alpha)
+
+
+def _sky_replacement(foreground: np.ndarray, alpha_mask: np.ndarray) -> np.ndarray:
+    """
+    Perform sky replacement using the estimated foreground and predicted alpha mask.
+
+    Args:
+        foreground (np.ndarray): The estimated foreground.
+        alpha_mask (np.ndarray): The predicted alpha mask.
+
+    Returns:
+        np.ndarray: The sky-replaced image.
+    """
+    new_sky_path = Path(__file__).parent / "assets/skies/francesco-ungaro-i75WTJn-RBY-unsplash.jpg"
+    new_sky_img = Image.open(new_sky_path).convert("RGB")
+
+    # Get the target size from the foreground image
+    h, w = foreground.shape[:2]
+
+    # Check the size of the sky image
+    sky_width, sky_height = new_sky_img.size
+
+    # If the sky image is smaller than the target size
+    if sky_width < w or sky_height < h:
+        scale = max(w / sky_width, h / sky_height)
+        new_size = (int(sky_width * scale), int(sky_height * scale))
+        new_sky_img = new_sky_img.resize(new_size, resample=Image.Resampling.LANCZOS)
+        sky_width, sky_height = new_sky_img.size
+
+    # Determine the maximum possible top-left coordinates for the crop
+    max_left = sky_width - w
+    max_top = sky_height - h
+
+    # Choose random offsets for left and top within the valid range
+    left = random.randint(a=0, b=max_left) if max_left > 0 else 0
+    top = random.randint(a=0, b=max_top) if max_top > 0 else 0
+
+    # Crop the sky image to the target size using the random offsets
+    new_sky_img = new_sky_img.crop((left, top, left + w, top + h))
+
+    new_sky = np.asarray(new_sky_img).astype(np.float32) / 255.0
+    if foreground.dtype != np.float32:
+        foreground = foreground.astype(np.float32) / 255.0
+    if foreground.shape[2] == 4:
+        foreground = foreground[:, :, :3]
+
+    # Ensure that the alpha mask values are within the range [0, 1]
+    alpha_mask = np.clip(alpha_mask, a_min=0, a_max=1)
+
+    # Blend the foreground with the new sky using the alpha mask
+    return (1 - alpha_mask[:, :, None]) * foreground + alpha_mask[:, :, None] * new_sky
+
+
+def _inference(image: Image) -> np.ndarray:
+    """
+    Perform inference on the input image using the ONNX model.
+
+    Args:
+        image (Image): The input image.
+
+    Returns:
+        np.ndarray: The predicted alpha mask.
+    """
+    output = session.run(output_names=[output_name], input_feed={input_name: image.cpu().numpy()})[0]
 
     # Ensure the output is in valid range [0, 1]
-    output = np.clip(output, 0, 1)
+    output = np.clip(output, a_min=0, a_max=1)
 
     return np.squeeze(output, axis=0).squeeze()
 
 
-def predict(image):
-    image_tensor = transforms(image).unsqueeze(0).to(device)
-
-    # Perform inference
-    predicted_alpha = inference(image_tensor)
+def predict(image: Image) -> Tuple[Image, Image]:
+    """
+    Perform sky replacement on the input image.
+
+    Args:
+        image (Image): The input image.
+
+    Returns:
+        Tuple[Image, Image]: The predicted alpha mask and the sky-replaced image.
+    """
+    image_tensor = transforms(image).unsqueeze(0).to(device)
+    predicted_alpha = _inference(image_tensor)
 
-    # Perform sky replacement
-    h, w = predicted_alpha.shape
-    downscaled_image = image.resize(size=(w, h), resample=Image.Resampling.LANCZOS)
-    foreground = get_foreground_estimation(downscaled_image, predicted_alpha)
-    replaced_sky = sky_replacement(foreground, predicted_alpha)
-
-    return predicted_alpha, replaced_sky
-
+    # Downscale the input image to match predicted_alpha
+    h, w = predicted_alpha.shape
+    downscaled_image = image.resize((w, h), Image.Resampling.LANCZOS)
+
+    # Estimate foreground and run sky_replacement
+    foreground = _get_foreground_estimation(downscaled_image, predicted_alpha)
+    replaced_sky = _sky_replacement(foreground, predicted_alpha)
+
+    # Resize the predicted alpha and replaced sky to original dimensions
+    predicted_alpha_pil = Image.fromarray((predicted_alpha * 255).astype(np.uint8), mode='L')
+    predicted_alpha_pil = predicted_alpha_pil.resize((h, w), Image.Resampling.LANCZOS)
+    replaced_sky_pil = Image.fromarray((replaced_sky * 255).astype(np.uint8))  # mode='RGB' typically
+    replaced_sky_pil = replaced_sky_pil.resize((h, w), Image.Resampling.LANCZOS)
+
+    return predicted_alpha_pil, replaced_sky_pil
+
+
+real_example_list = [
+    ["examples/real/1901.jpg", "Real", "Good"],
+    ["examples/real/2022.jpg", "Real", "Good"],
+    ["examples/real/2041.jpg", "Real", "Good"],
+    ["examples/real/2196.jpg", "Real", "Good"],
+    ["examples/real/2188.jpg", "Real", "Good"],
+    ["examples/real/0001.jpg", "Real", "Acceptable, missing minor detail around the lamppost"],
+    ["examples/real/0054.jpg", "Real", "Acceptable, missing sky details between the houses"],
+    ["examples/real/2043.jpg", "Real", "Acceptable, missing minor detail in the window in the background"],
+    ["examples/real/0211.jpg", "Real", "Okay, misclassified a cloud in the left corner as the sky"],
+    ["examples/real/0894.jpg", "Real", "Okay, missing details in the trees"],
+    ["examples/real/2184.jpg", "Real", "Okay, lacks tree details in the background"],
+    ["examples/real/2026.jpg", "Real", "Okay, lacks tree details in the left background"],
+    ["examples/real/1975.jpg", "Real", "Okay, lacks tree branch details"],
+    ["examples/real/0069.jpg", "Real", "Bad, didn't replace the sky between the houses"],
+    ["examples/real/2079.jpg", "Real", "Bad, couldn't get the complete details of the tree"],
+    ["examples/real/2038.jpg", "Real", "Bad, lacks overall details in both trees and tree branches"],
+]
+
+synthetic_example_list = [
+    ["examples/synthetic/0055.jpg", "Synthetic", "Good"],
+    ["examples/synthetic/0059.jpg", "Synthetic", "Good"],
+    ["examples/synthetic/0086.jpg", "Synthetic", "Good"],
+    ["examples/synthetic/10406.jpg", "Synthetic", "Good"],
+    ["examples/synthetic/10515.jpg", "Synthetic", "Good"],
+    ["examples/synthetic/10416.jpg", "Synthetic", "Acceptable, missing minor detail in the tree leaves"],
+    ["examples/synthetic/0150.jpg", "Synthetic", "Acceptable, missing minor detail in the tree"],
+    ["examples/synthetic/0096.jpg", "Synthetic", "Okay, missing minor detail in the trees"],
+    ["examples/synthetic/0124.jpg", "Synthetic", "Okay, missing minor detail in the trees"],
+    ["examples/synthetic/0127.jpg", "Synthetic", "Bad, missing many details in the trees"],
+    ["examples/synthetic/10467.jpg", "Synthetic", "Bad, misclassified the windows as sky"],
+]
 
 with gr.Blocks(theme=gr.themes.Default()) as demo:
     gr.Markdown(
@@ -71,55 +200,60 @@ with gr.Blocks(theme=gr.themes.Default()) as demo:
         """
     )
 
+    data_type = gr.Radio(choices=["Real", "Synthetic"], value="Real", label="Select Data Type for Examples")
+
     with gr.Row():
         # Left Column: Input Image and Run/Clear Buttons
         with gr.Column(scale=1):
             input_image = gr.Image(type="pil", label="Input Image")
-
            with gr.Row():
                clear_button = gr.Button("Clear")
                run_button = gr.Button("Submit", variant="primary")
 
        # Right Column: Output Images
        with gr.Column(scale=1):
-            output_mask = gr.Image(type="numpy", label="Predicted Mask")
-            output_sky = gr.Image(type="numpy", label="Sky Replacement")
+            output_mask = gr.Image(type="pil", label="Predicted Mask")
+            output_sky = gr.Image(type="pil", label="Sky Replacement")
 
            metadata_display = gr.Markdown(None)
 
-
-    def load_example(example_image, example_type, example_desc):
-        info = f"**Type:** {example_type}\n\n**Description:** {example_desc}"
-        return example_image, info
-
-
-    example_list = [
-        ["examples/real_0054.jpg", "Real", "Good"],
-        ["examples/real_0116.jpg", "Real", "Good"],
-        ["examples/real_0585.jpg", "Real", "Good"],
-        ["examples/synthetic_10635.jpg", "Synthetic", "Good"],
-        ["examples/synthetic_10512.jpg", "Synthetic", "Good"],
-        ["examples/real_0765.jpg", "Real", "Decent"],
-        ["examples/real_0822.jpg", "Real", "Decent"],
-        ["examples/synthetic_10795.jpg", "Synthetic", "Decent"],
-        ["examples/synthetic_10560.jpg", "Synthetic", "Decent"],
-        ["examples/synthetic_10679.jpg", "Synthetic", "Decent"],
-        ["examples/real_0823.jpg", "Real", "Decent, lacks some details in the trees"],
-        ["examples/real_bad_0007.jpg", "Real", "Bad, lacks details in the trees"],
-        ["examples/real_bad_0934.jpg", "Real", "Bad, lacks details in the trees"],
-    ]
-
-    examples_component = gr.Examples(
-        examples=example_list,
-        inputs=[
-            input_image,
-            gr.Textbox(label="Real or Synthetic", value="", interactive=False, visible=False),
-            gr.Textbox(label="Description", value="", interactive=False, visible=False),
-        ],
-        outputs=[input_image, metadata_display],
-        fn=load_example,
-        cache_examples=True,
-        label="Examples (click an image to load it and see details)"
+    with gr.Column(visible=True) as real_examples_container:
+        real_examples_component = gr.Examples(
+            examples=real_example_list,
+            inputs=[input_image,
+                    gr.Textbox(label="Data Type", value="", interactive=False, visible=False),
+                    gr.Textbox(label="Result", value="", interactive=False, visible=False)],
+            outputs=[input_image, metadata_display],
+            fn=lambda example, dtype, desc: (example, f"**Type:** {dtype}\n\n**Result:** {desc}"),
+            cache_examples=False,
+            label="Real Data Examples"
+        )
+
+    with gr.Column(visible=False) as synthetic_examples_container:
+        synthetic_examples_component = gr.Examples(
+            examples=synthetic_example_list,
+            inputs=[input_image,
+                    gr.Textbox(label="Data Type", value="", interactive=False, visible=False),
+                    gr.Textbox(label="Result", value="", interactive=False, visible=False)],
+            outputs=[input_image, metadata_display],
+            fn=lambda example, dtype, desc: (example, f"**Type:** {dtype}\n\n**Result:** {desc}"),
+            cache_examples=False,
+            label="Synthetic Data Examples"
+        )
+
+
+    # Callback to toggle the container visibility based on selection.
+    def switch_examples(selected):
+        if selected == "Real":
+            return gr.update(visible=True), gr.update(visible=False)
+        else:
+            return gr.update(visible=False), gr.update(visible=True)
+
+
+    data_type.change(
+        fn=switch_examples,
+        inputs=data_type,
+        outputs=[real_examples_container, synthetic_examples_container]
    )
 
 
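The new _inference helper relies on the module-level session, input_name, and output_name returned by _load_model(checkpoint_path), whose body sits outside the changed hunks. For orientation only, here is a minimal sketch of how such a loader is commonly written with ONNX Runtime; the function name _load_model_sketch and the provider list are assumptions, not code from this commit.

import onnxruntime as ort


def _load_model_sketch(checkpoint: str):
    # Hypothetical stand-in for app.py's _load_model: create an ONNX Runtime
    # session and look up the names of its first input and output tensors.
    session = ort.InferenceSession(checkpoint, providers=["CPUExecutionProvider"])
    input_name = session.get_inputs()[0].name
    output_name = session.get_outputs()[0].name
    return session, input_name, output_name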
examples/real_0822.jpg → assets/skies/francesco-ungaro-i75WTJn-RBY-unsplash.jpg RENAMED
File without changes
assets/skies/new_sky.webp DELETED
Binary file (12.5 kB)
 
examples/{real_0585.jpg → real/0001.jpg} RENAMED
File without changes
examples/{real_0054.jpg → real/0054.jpg} RENAMED
File without changes
examples/{real_0765.jpg → real/0069.jpg} RENAMED
File without changes
examples/{real_0116.jpg → real/0211.jpg} RENAMED
File without changes
examples/real/0894.jpg ADDED

Git LFS Details

  • SHA256: a801b4791e1453954fe420e59ed1e05422a4217e9b55eaf1df7b88518296e4dc
  • Pointer size: 132 Bytes
  • Size of remote file: 1.31 MB
examples/{real_bad_0007.jpg → real/1901.jpg} RENAMED
File without changes
examples/real/1975.jpg ADDED

Git LFS Details

  • SHA256: b5a5056ac3ade74eca8b8ee33d442ba7a3e4ef057f83c9e8b428a3129f770e06
  • Pointer size: 132 Bytes
  • Size of remote file: 1.35 MB
examples/real/2022.jpg ADDED

Git LFS Details

  • SHA256: a173ac354da1a9bd3567578f5ef2f81a323644ed15c78c4c970ffb6fda0e7dc5
  • Pointer size: 132 Bytes
  • Size of remote file: 1.05 MB
examples/real/2026.jpg ADDED

Git LFS Details

  • SHA256: 224769f2507265e7c42c9a32cff090b97de7c9039902d14b655df9b5937f2f85
  • Pointer size: 132 Bytes
  • Size of remote file: 1.39 MB
examples/real/2038.jpg ADDED

Git LFS Details

  • SHA256: b773cdf133507e9137ddc3fae0323440cc20a4952c10c4a328564fa5e738be4e
  • Pointer size: 132 Bytes
  • Size of remote file: 1.73 MB
examples/{real_bad_0934.jpg → real/2041.jpg} RENAMED
File without changes
examples/real/2043.jpg ADDED

Git LFS Details

  • SHA256: 74c9abf61485dfbe737bb25fe29ef86bab06cbbc9cd53c29b36899a9de2ba9f0
  • Pointer size: 131 Bytes
  • Size of remote file: 598 kB
examples/real/2079.jpg ADDED

Git LFS Details

  • SHA256: f5ae21dff87ce5c9781af74d46dfe0b7fe88357488ce40fcd4c4593ad5fbd8a1
  • Pointer size: 132 Bytes
  • Size of remote file: 1.09 MB
examples/real/2184.jpg ADDED

Git LFS Details

  • SHA256: bd395eda6c83ab37244144e8d0db37777e9799f08097404eb1449f9e227769af
  • Pointer size: 132 Bytes
  • Size of remote file: 1.4 MB
examples/real/2188.jpg ADDED

Git LFS Details

  • SHA256: 5f5bace92bafac1e38a2fa5ce2b3d4dc557f589e4a16234a60e1508bdc7d192d
  • Pointer size: 131 Bytes
  • Size of remote file: 523 kB
examples/real/2196.jpg ADDED

Git LFS Details

  • SHA256: 3477e92cbf08cc91eb68cadf56259f24fb7498a4ae2c5fb29f4a366bf83eccb5
  • Pointer size: 131 Bytes
  • Size of remote file: 601 kB
examples/real_0823.jpg DELETED

Git LFS Details

  • SHA256: b9dc011499ebd3054694e84666d846bcef41d4c4decd041897b700765334cda3
  • Pointer size: 132 Bytes
  • Size of remote file: 1.5 MB
examples/synthetic/0055.jpg ADDED

Git LFS Details

  • SHA256: f6702e87e927499d1951d7d1692e41b4d528fca06c76b2fc09a56b71a1f914e0
  • Pointer size: 131 Bytes
  • Size of remote file: 871 kB
examples/synthetic/0059.jpg ADDED

Git LFS Details

  • SHA256: 309aba280e1ed974cc328c2ac83d1967ce967621efc991a190236527b2963900
  • Pointer size: 132 Bytes
  • Size of remote file: 1.16 MB
examples/synthetic/0086.jpg ADDED

Git LFS Details

  • SHA256: 06b225e7f1bef4f9f9d8bdc751f510d12db21ae100ecc56e9ea72601a65fe3f8
  • Pointer size: 132 Bytes
  • Size of remote file: 1.07 MB
examples/synthetic/0097.jpg ADDED

Git LFS Details

  • SHA256: e617bfb74034886233cc0886183b8371005228c5d3823b16158eb27c86753953
  • Pointer size: 132 Bytes
  • Size of remote file: 1.41 MB
examples/synthetic/0124.jpg ADDED

Git LFS Details

  • SHA256: 16b748dbfdf1a7cc6192cc2a19df811cfd7831cd9b9432995a8e61e58f9d46bb
  • Pointer size: 132 Bytes
  • Size of remote file: 1.12 MB
examples/synthetic/0127.jpg ADDED

Git LFS Details

  • SHA256: a140a78d2851358c324db07a1be958ef8171d29f1d7f4ab2c95c4edb6df54d27
  • Pointer size: 132 Bytes
  • Size of remote file: 1.36 MB
examples/synthetic/0150.jpg ADDED

Git LFS Details

  • SHA256: d74794d02156cad4cf2eff1a4530bd3a1f473d1aeeafd563f01c920ce3273324
  • Pointer size: 132 Bytes
  • Size of remote file: 1.32 MB
examples/synthetic/10406.jpg ADDED

Git LFS Details

  • SHA256: 542d321ed7687475c55e21bdb72fa7445d580c36d7d239aa33f254dea82d9813
  • Pointer size: 132 Bytes
  • Size of remote file: 1.08 MB
examples/synthetic/10416.jpg ADDED

Git LFS Details

  • SHA256: 9f2329142531755b1abb16e9d597b44bf01185c2789ec95e52057d697a85d4a2
  • Pointer size: 132 Bytes
  • Size of remote file: 1.26 MB
examples/synthetic/10467.jpg ADDED

Git LFS Details

  • SHA256: 774b2aeabaeec93031c9f76def44cf6f87dbe174f3bdfc857eeee7ea7394072f
  • Pointer size: 132 Bytes
  • Size of remote file: 1.26 MB
examples/synthetic/10515.jpg ADDED

Git LFS Details

  • SHA256: d008291e26dbce54bfc9061136e710592700b902049c586d4e4fee122ec08fdf
  • Pointer size: 132 Bytes
  • Size of remote file: 1.41 MB
examples/synthetic_10512.jpg DELETED

Git LFS Details

  • SHA256: ec5248d1995b62fa6240ce13255925e1a307bd8b018473fdf6ee4e775940fb21
  • Pointer size: 132 Bytes
  • Size of remote file: 1.05 MB
examples/synthetic_10560.jpg DELETED

Git LFS Details

  • SHA256: 7aca6079ca0894d2e3f7276697c16e3d2cf88669f99330392535c495c46edf62
  • Pointer size: 132 Bytes
  • Size of remote file: 1.55 MB
examples/synthetic_10635.jpg DELETED

Git LFS Details

  • SHA256: 1cb70d51124a5d6bc5f52741dfa9e98f11ad89a02923b59d9500e18d3ebedb5d
  • Pointer size: 132 Bytes
  • Size of remote file: 1.47 MB
examples/synthetic_10679.jpg DELETED

Git LFS Details

  • SHA256: 947c2ef6c9f0d0d2622ac263ad1cc7c414d61cc797625d941cc213d25b7bb67e
  • Pointer size: 132 Bytes
  • Size of remote file: 1.4 MB
examples/synthetic_10795.jpg DELETED

Git LFS Details

  • SHA256: 4b76bfb92f7b37326ad5a1c3afd9ff8c1a678ff3528e000cbf3af0e5799fcc52
  • Pointer size: 132 Bytes
  • Size of remote file: 1.02 MB
replacements/__init__.py DELETED
File without changes
replacements/foreground_estimation.py DELETED
@@ -1,9 +0,0 @@
-import pymatting
-import numpy as np
-
-
-def get_foreground_estimation(image, alpha):
-    normalized_image = np.array(image) / 255.0
-    inverted_alpha = 1 - alpha
-
-    return pymatting.estimate_foreground_ml(image=normalized_image, alpha=inverted_alpha)
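Both the removed helper above and its replacement _get_foreground_estimation in app.py call pymatting.estimate_foreground_ml, which works on floating-point arrays with matching height and width: the image normalized to [0, 1] and a 2D alpha, inverted here because the predicted mask marks the sky rather than the foreground. A small self-contained sketch with made-up arrays, assuming those shapes:

import numpy as np
import pymatting

# Made-up inputs: a small RGB image normalized to [0, 1] and a sky mask in [0, 1].
rng = np.random.default_rng(0)
image = rng.random((64, 64, 3))
sky_alpha = rng.random((64, 64))

# The app inverts the predicted alpha so the non-sky region is treated as foreground.
foreground = pymatting.estimate_foreground_ml(image=image, alpha=1 - sky_alpha)
print(foreground.shape)  # (64, 64, 3), same spatial size as the input image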
replacements/replacements.py DELETED
@@ -1,28 +0,0 @@
-from pathlib import Path
-from PIL import Image
-import numpy as np
-
-
-def sky_replacement(foreground, alpha_mask):
-    # Load the new sky image
-    current_directory = Path(__file__).parent.parent
-    new_sky_path = current_directory / "assets/skies/new_sky.webp"
-    new_sky_img = Image.open(new_sky_path).convert("RGB")
-
-    # Resize to match foreground dimensions
-    h, w = foreground.shape[:2]
-    new_sky_img = new_sky_img.resize(size=(w, h), resample=Image.Resampling.LANCZOS)
-
-    # Convert new_sky and foreground to float32 numpy arrays in range [0, 1]
-    new_sky = np.asarray(new_sky_img).astype(np.float32) / 255.0
-    if foreground.dtype != np.float32:
-        foreground = foreground.astype(np.float32) / 255.0
-
-    # If foreground has an alpha channel, drop it
-    if foreground.shape[2] == 4:
-        foreground = foreground[:, :, :3]
-
-    # Ensure values are in [0, 1]
-    alpha_mask = np.clip(alpha_mask, a_min=0, a_max=1)
-
-    return (1 - alpha_mask[:, :, None]) * foreground + alpha_mask[:, :, None] * new_sky
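The final line of the deleted sky_replacement (and of the new _sky_replacement in app.py) is a plain alpha composite: the 2D mask is broadcast across the colour channels with [:, :, None], so pixels where the mask is 1 come from the new sky and pixels where it is 0 keep the estimated foreground. A tiny sketch of just that blending step, using made-up 2x2 arrays:

import numpy as np

# Stand-in images in [0, 1]; only the compositing step is shown here.
foreground = np.full((2, 2, 3), 0.2, dtype=np.float32)
new_sky = np.full((2, 2, 3), 0.9, dtype=np.float32)
alpha_mask = np.array([[0.0, 1.0],
                       [0.25, 0.75]], dtype=np.float32)

composite = (1 - alpha_mask[:, :, None]) * foreground + alpha_mask[:, :, None] * new_sky
print(composite[0, 0])  # [0.2 0.2 0.2] -> pure foreground where the mask is 0
print(composite[0, 1])  # [0.9 0.9 0.9] -> pure sky where the mask is 1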