Upload 80 files

This view is limited to 50 files because it contains too many changes.
- deepsvg/__init__.py +0 -0
- deepsvg/__pycache__/__init__.cpython-310.pyc +0 -0
- deepsvg/config.py +101 -0
- deepsvg/difflib/__pycache__/tensor.cpython-310.pyc +0 -0
- deepsvg/difflib/loss.py +51 -0
- deepsvg/difflib/tensor.py +249 -0
- deepsvg/difflib/utils.py +81 -0
- deepsvg/gui/README.md +2 -0
- deepsvg/gui/__init__.py +0 -0
- deepsvg/gui/config.py +5 -0
- deepsvg/gui/deepsvg.kv +380 -0
- deepsvg/gui/interpolate.py +126 -0
- deepsvg/gui/layout/__init__.py +0 -0
- deepsvg/gui/layout/aligned_textinput.py +52 -0
- deepsvg/gui/main.py +794 -0
- deepsvg/gui/res/down.png +0 -0
- deepsvg/gui/res/hand.png +0 -0
- deepsvg/gui/res/hand.svg +1 -0
- deepsvg/gui/res/pause.png +0 -0
- deepsvg/gui/res/pen.png +0 -0
- deepsvg/gui/res/pen.svg +1 -0
- deepsvg/gui/res/pencil.png +0 -0
- deepsvg/gui/res/pencil.svg +1 -0
- deepsvg/gui/res/play.png +0 -0
- deepsvg/gui/res/play.svg +3 -0
- deepsvg/gui/res/switch.png +0 -0
- deepsvg/gui/res/up.png +0 -0
- deepsvg/gui/state/__init__.py +0 -0
- deepsvg/gui/state/project.py +115 -0
- deepsvg/gui/state/state.py +78 -0
- deepsvg/gui/utils.py +66 -0
- deepsvg/model/basic_blocks.py +101 -0
- deepsvg/model/config.py +107 -0
- deepsvg/model/layers/__init__.py +0 -0
- deepsvg/model/layers/attention.py +161 -0
- deepsvg/model/layers/functional.py +256 -0
- deepsvg/model/layers/improved_transformer.py +141 -0
- deepsvg/model/layers/positional_encoding.py +43 -0
- deepsvg/model/layers/transformer.py +393 -0
- deepsvg/model/layers/utils.py +36 -0
- deepsvg/model/loss.py +104 -0
- deepsvg/model/model.py +690 -0
- deepsvg/model/utils.py +84 -0
- deepsvg/model/vector_quantize_pytorch.py +605 -0
- deepsvg/schedulers/warmup.py +67 -0
- deepsvg/svg_dataset.py +269 -0
- deepsvg/svglib/__init__.py +0 -0
- deepsvg/svglib/__pycache__/__init__.cpython-310.pyc +0 -0
- deepsvg/svglib/__pycache__/geom.cpython-310.pyc +0 -0
- deepsvg/svglib/__pycache__/svg.cpython-310.pyc +0 -0
deepsvg/__init__.py
ADDED
File without changes
deepsvg/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (139 Bytes)
deepsvg/config.py
ADDED
@@ -0,0 +1,101 @@
import torch.optim as optim
from deepsvg.schedulers.warmup import GradualWarmupScheduler


class _Config:
    """
    Training config.
    """
    def __init__(self, num_gpus=1):

        self.num_gpus = num_gpus                              # number of GPUs to train on

        self.dataloader_module = "deepsvg.svgtensor_dataset"  # module providing the dataset
        self.collate_fn = None                                # optional custom batch collation
        self.data_dir = "./dataset/icons_tensor/"             # directory of preprocessed tensors
        self.meta_filepath = "./dataset/icons_meta.csv"       # CSV with per-icon metadata
        self.loader_num_workers = 0                           # DataLoader worker processes

        self.pretrained_path = None                           # checkpoint to initialize from

        self.model_cfg = None                                 # model hyperparameter config

        self.num_epochs = None                                # stop after this many epochs...
        self.num_steps = None                                 # ...or after this many steps
        self.learning_rate = 1e-3
        self.batch_size = 100
        self.warmup_steps = 500                               # LR warmup duration

        # Dataset
        self.train_ratio = 1.0                                # fraction of data used for training
        self.nb_augmentations = 1                             # augmented copies per sample

        self.max_num_groups = 15                              # max number of paths per SVG
        self.max_seq_len = 30                                 # max commands per path
        self.max_total_len = None                             # max total commands per SVG

        self.filter_uni = None
        self.filter_category = None
        self.filter_platform = None

        self.filter_labels = None

        self.grad_clip = None                                 # max gradient norm, if set

        self.log_every = 20                                   # steps between log lines
        self.val_every = 1000                                 # steps between validations
        self.ckpt_every = 1000                                # steps between checkpoints

        self.stats_to_print = {
            "train": ["lr", "time"]
        }

        self.model_args = []                                  # batch keys fed to the model
        self.optimizer_starts = [0]                           # step at which each optimizer kicks in

    # Overridable methods
    def make_model(self):
        raise NotImplementedError

    def make_losses(self):
        raise NotImplementedError

    def make_optimizers(self, model):
        return [optim.AdamW(model.parameters(), self.learning_rate)]

    def make_schedulers(self, optimizers, epoch_size):
        return [None] * len(optimizers)

    def make_warmup_schedulers(self, optimizers, scheduler_lrs):
        return [GradualWarmupScheduler(optimizer, multiplier=1.0, total_epoch=self.warmup_steps, after_scheduler=scheduler_lr)
                for optimizer, scheduler_lr in zip(optimizers, scheduler_lrs)]

    def get_params(self, step, epoch):
        return {}

    def get_weights(self, step, epoch):
        return {}

    def set_train_vars(self, train_vars, dataloader):
        pass

    def visualize(self, model, output, train_vars, step, epoch, summary_writer, visualization_dir):
        pass

    # Utility methods
    def values(self):
        for key in dir(self):
            if not key.startswith("__") and not callable(getattr(self, key)):
                yield key, getattr(self, key)

    def to_dict(self):
        return {key: val for key, val in self.values()}

    def load_dict(self, dict):
        for key, val in dict.items():
            setattr(self, key, val)

    def print_params(self):
        for key, val in self.values():
            print(f"  {key} = {val}")
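Since `_Config` leaves `make_model` and `make_losses` abstract, a concrete training config subclasses it and supplies both. A minimal usage sketch (not part of this commit; `DummyModel` and the `MSELoss` placeholder are hypothetical stand-ins for a real DeepSVG model and loss):

# --- usage sketch (not part of this commit) ---
import torch.nn as nn
from deepsvg.config import _Config

class DummyModel(nn.Module):  # hypothetical stand-in for a real model
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(16, 16)

    def forward(self, x, params=None):
        return self.linear(x)

class Config(_Config):
    def make_model(self):
        return DummyModel()

    def make_losses(self):
        # a real config returns DeepSVG loss modules; MSELoss is a placeholder
        return [nn.MSELoss()]

cfg = Config()
cfg.print_params()  # dumps every non-callable attribute, e.g. "batch_size = 100"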
deepsvg/difflib/__pycache__/tensor.cpython-310.pyc
ADDED
Binary file (8.57 kB)
deepsvg/difflib/loss.py
ADDED
@@ -0,0 +1,51 @@
import numpy as np
from .utils import *


def chamfer_loss(x, y):
    d = torch.cdist(x, y)
    return d.min(dim=0).values.mean() + d.min(dim=1).values.mean()


def continuity_loss(x):
    d = (x[1:] - x[:-1]).norm(dim=-1, p=2)
    return d.mean()


def svg_length_loss(p_pred, p_target):
    pred_length, target_length = get_length(p_pred), get_length(p_target)

    return (target_length - pred_length).abs() / target_length


def svg_emd_loss(p_pred, p_target,
                 first_point_weight=False, return_matched_indices=False):
    n, m = len(p_pred), len(p_target)

    if n == 0:
        return 0.

    # Make target point list clockwise
    p_target = make_clockwise(p_target)

    # Compute length distribution
    distr_pred = torch.linspace(0., 1., n).to(p_pred.device)
    distr_target = get_length_distribution(p_target, normalize=True)
    d = torch.cdist(distr_pred.unsqueeze(-1), distr_target.unsqueeze(-1))
    matching = d.argmin(dim=-1)
    p_target_sub = p_target[matching]

    # EMD
    i = np.argmin([torch.norm(p_pred - reorder(p_target_sub, i), dim=-1).mean() for i in range(n)])

    losses = torch.norm(p_pred - reorder(p_target_sub, i), dim=-1)

    if first_point_weight:
        weights = torch.ones_like(losses)
        weights[0] = 10.
        losses = losses * weights

    if return_matched_indices:
        return losses.mean(), (p_pred, p_target, reorder(matching, i))

    return losses.mean()
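These losses operate on (N, 2) point tensors such as those produced by SVGTensor.sample_points. A quick sketch (not part of this commit; the shapes are assumptions read off the code above):

# --- usage sketch (not part of this commit) ---
import torch
from deepsvg.difflib.loss import chamfer_loss, svg_emd_loss

p_pred = torch.rand(20, 2)    # 20 predicted contour points
p_target = torch.rand(30, 2)  # 30 target contour points

print(chamfer_loss(p_pred, p_target))  # symmetric nearest-neighbour distance
print(svg_emd_loss(p_pred, p_target))  # cyclic-order-aware matching loss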
deepsvg/difflib/tensor.py
ADDED
@@ -0,0 +1,249 @@
from __future__ import annotations
import torch
import torch.utils.data
from typing import Union
Num = Union[int, float]


class SVGTensor:
    #                       0    1    2    3    4      5      6
    COMMANDS_SIMPLIFIED = ["m", "l", "c", "a", "EOS", "SOS", "z"]

    #                            radius  x-axis  large-arc  sweep  ctrl1  ctrl2  end-pos
    #                            (2)     rot     flag       flag   (2)    (2)    (2)
    CMD_ARGS_MASK = torch.tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1],   # m
                                  [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1],   # l
                                  [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1],   # c
                                  [1, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1],   # a
                                  [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],   # EOS
                                  [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],   # SOS
                                  [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])  # z

    class Index:
        COMMAND = 0
        RADIUS = slice(1, 3)
        X_AXIS_ROT = 3
        LARGE_ARC_FLG = 4
        SWEEP_FLG = 5
        START_POS = slice(6, 8)
        CONTROL1 = slice(8, 10)
        CONTROL2 = slice(10, 12)
        END_POS = slice(12, 14)

    class IndexArgs:
        RADIUS = slice(0, 2)
        X_AXIS_ROT = 2
        LARGE_ARC_FLG = 3
        SWEEP_FLG = 4
        CONTROL1 = slice(5, 7)
        CONTROL2 = slice(7, 9)
        END_POS = slice(9, 11)

    position_keys = ["control1", "control2", "end_pos"]
    all_position_keys = ["start_pos", *position_keys]
    arg_keys = ["radius", "x_axis_rot", "large_arc_flg", "sweep_flg", *position_keys]
    all_arg_keys = [*arg_keys[:4], "start_pos", *arg_keys[4:]]
    cmd_arg_keys = ["commands", *arg_keys]
    all_keys = ["commands", *all_arg_keys]

    def __init__(self, commands, radius, x_axis_rot, large_arc_flg, sweep_flg, control1, control2, end_pos,
                 seq_len=None, label=None, PAD_VAL=-1, ARGS_DIM=256, filling=0):

        self.commands = commands.reshape(-1, 1).float()

        self.radius = radius.float()
        self.x_axis_rot = x_axis_rot.reshape(-1, 1).float()
        self.large_arc_flg = large_arc_flg.reshape(-1, 1).float()
        self.sweep_flg = sweep_flg.reshape(-1, 1).float()

        self.control1 = control1.float()
        self.control2 = control2.float()
        self.end_pos = end_pos.float()

        self.seq_len = torch.tensor(len(commands)) if seq_len is None else seq_len
        self.label = label

        self.PAD_VAL = PAD_VAL
        self.ARGS_DIM = ARGS_DIM

        self.sos_token = torch.Tensor([self.COMMANDS_SIMPLIFIED.index("SOS")]).unsqueeze(-1)
        self.eos_token = self.pad_token = torch.Tensor([self.COMMANDS_SIMPLIFIED.index("EOS")]).unsqueeze(-1)

        self.filling = filling

    @property
    def start_pos(self):
        start_pos = self.end_pos[:-1]

        return torch.cat([
            start_pos.new_zeros(1, 2),
            start_pos
        ])

    @staticmethod
    def from_data(data, *args, **kwargs):
        return SVGTensor(data[:, SVGTensor.Index.COMMAND], data[:, SVGTensor.Index.RADIUS], data[:, SVGTensor.Index.X_AXIS_ROT],
                         data[:, SVGTensor.Index.LARGE_ARC_FLG], data[:, SVGTensor.Index.SWEEP_FLG], data[:, SVGTensor.Index.CONTROL1],
                         data[:, SVGTensor.Index.CONTROL2], data[:, SVGTensor.Index.END_POS], *args, **kwargs)

    @staticmethod
    def from_cmd_args(commands, args, *nargs, **kwargs):
        return SVGTensor(commands, args[:, SVGTensor.IndexArgs.RADIUS], args[:, SVGTensor.IndexArgs.X_AXIS_ROT],
                         args[:, SVGTensor.IndexArgs.LARGE_ARC_FLG], args[:, SVGTensor.IndexArgs.SWEEP_FLG], args[:, SVGTensor.IndexArgs.CONTROL1],
                         args[:, SVGTensor.IndexArgs.CONTROL2], args[:, SVGTensor.IndexArgs.END_POS], *nargs, **kwargs)

    def get_data(self, keys):
        return torch.cat([self.__getattribute__(key) for key in keys], dim=-1)

    @property
    def data(self):
        return self.get_data(self.all_keys)

    def copy(self):
        return SVGTensor(*[self.__getattribute__(key).clone() for key in self.cmd_arg_keys],
                         seq_len=self.seq_len.clone(), label=self.label, PAD_VAL=self.PAD_VAL, ARGS_DIM=self.ARGS_DIM,
                         filling=self.filling)

    def add_sos(self):
        self.commands = torch.cat([self.sos_token, self.commands])

        for key in self.arg_keys:
            v = self.__getattribute__(key)
            self.__setattr__(key, torch.cat([v.new_full((1, v.size(-1)), self.PAD_VAL), v]))

        self.seq_len += 1
        return self

    def drop_sos(self):
        for key in self.cmd_arg_keys:
            self.__setattr__(key, self.__getattribute__(key)[1:])

        self.seq_len -= 1
        return self

    def add_eos(self):
        self.commands = torch.cat([self.commands, self.eos_token])

        for key in self.arg_keys:
            v = self.__getattribute__(key)
            self.__setattr__(key, torch.cat([v, v.new_full((1, v.size(-1)), self.PAD_VAL)]))

        return self

    def pad(self, seq_len=51):
        pad_len = max(seq_len - len(self.commands), 0)

        self.commands = torch.cat([self.commands, self.pad_token.repeat(pad_len, 1)])

        for key in self.arg_keys:
            v = self.__getattribute__(key)
            self.__setattr__(key, torch.cat([v, v.new_full((pad_len, v.size(-1)), self.PAD_VAL)]))

        return self

    def unpad(self):
        # Remove EOS + padding
        for key in self.cmd_arg_keys:
            self.__setattr__(key, self.__getattribute__(key)[:self.seq_len])
        return self

    def draw(self, *args, **kwargs):
        from deepsvg.svglib.svg import SVGPath
        return SVGPath.from_tensor(self.data).draw(*args, **kwargs)

    def cmds(self):
        return self.commands.reshape(-1)

    def args(self, with_start_pos=False):
        if with_start_pos:
            return self.get_data(self.all_arg_keys)

        return self.get_data(self.arg_keys)

    def _get_real_commands_mask(self):
        mask = self.cmds() < self.COMMANDS_SIMPLIFIED.index("EOS")
        return mask

    def _get_args_mask(self):
        mask = SVGTensor.CMD_ARGS_MASK[self.cmds().long()].bool()
        return mask

    def get_relative_args(self):
        data = self.args().clone()

        real_commands = self._get_real_commands_mask()
        data_real_commands = data[real_commands]

        start_pos = data_real_commands[:-1, SVGTensor.IndexArgs.END_POS].clone()

        data_real_commands[1:, SVGTensor.IndexArgs.CONTROL1] -= start_pos
        data_real_commands[1:, SVGTensor.IndexArgs.CONTROL2] -= start_pos
        data_real_commands[1:, SVGTensor.IndexArgs.END_POS] -= start_pos
        data[real_commands] = data_real_commands

        mask = self._get_args_mask()
        data[mask] += self.ARGS_DIM - 1
        data[~mask] = self.PAD_VAL

        return data

    def sample_points(self, n=10):
        device = self.commands.device

        z = torch.linspace(0, 1, n, device=device)
        Z = torch.stack([torch.ones_like(z), z, z.pow(2), z.pow(3)], dim=1)

        Q = torch.tensor([
            [[0., 0., 0., 0.],  # "m"
             [0., 0., 0., 0.],
             [0., 0., 0., 0.],
             [0., 0., 0., 0.]],

            [[1., 0., 0., 0.],  # "l"
             [-1, 0., 0., 1.],
             [0., 0., 0., 0.],
             [0., 0., 0., 0.]],

            [[1., 0., 0., 0.],  # "c"
             [-3, 3., 0., 0.],
             [3., -6, 3., 0.],
             [-1, 3., -3, 1.]],

            torch.zeros(4, 4),  # "a", no support yet

            torch.zeros(4, 4),  # "EOS"
            torch.zeros(4, 4),  # "SOS"
            torch.zeros(4, 4),  # "z"
        ], device=device)

        commands, pos = self.commands.reshape(-1).long(), self.get_data(self.all_position_keys).reshape(-1, 4, 2)
        inds = (commands == self.COMMANDS_SIMPLIFIED.index("l")) | (commands == self.COMMANDS_SIMPLIFIED.index("c"))
        commands, pos = commands[inds], pos[inds]

        Z_coeffs = torch.matmul(Q[commands], pos)

        # The last point of each command is the first point of the next, so drop
        # duplicated points, keeping only the last command's endpoint
        sample_points = torch.matmul(Z, Z_coeffs)
        sample_points = torch.cat([sample_points[:, :-1].reshape(-1, 2), sample_points[-1, -1].unsqueeze(0)])

        return sample_points

    @staticmethod
    def get_length_distribution(p, normalize=True):
        start, end = p[:-1], p[1:]
        length_distr = torch.norm(end - start, dim=-1).cumsum(dim=0)
        length_distr = torch.cat([length_distr.new_zeros(1), length_distr])
        if normalize:
            length_distr = length_distr / length_distr[-1]
        return length_distr

    def sample_uniform_points(self, n=100):
        p = self.sample_points(n=n)

        distr_unif = torch.linspace(0., 1., n).to(p.device)
        distr = self.get_length_distribution(p, normalize=True)
        d = torch.cdist(distr_unif.unsqueeze(-1), distr.unsqueeze(-1))
        matching = d.argmin(dim=-1)

        return p[matching]
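SVGTensor expects one 14-value row per command, laid out according to the Index class above. A small round-trip sketch (not part of this commit; the values are made up for illustration):

# --- usage sketch (not part of this commit) ---
import torch
from deepsvg.difflib.tensor import SVGTensor

# Row layout per Index: [cmd, radius(2), x_axis_rot, large_arc_flg, sweep_flg,
#                        start_pos(2), control1(2), control2(2), end_pos(2)]
data = torch.tensor([
    [0] + [0.] * 11 + [10., 10.],                               # m: move to (10, 10)
    [1] + [0.] * 11 + [60., 10.],                               # l: line to (60, 10)
    [2] + [0.] * 5 + [60., 10., 70., 30., 80., 30., 90., 10.],  # c: cubic Bézier
])

svg_t = SVGTensor.from_data(data)
print(svg_t.cmds())             # tensor([0., 1., 2.])

svg_t.add_eos().pad(seq_len=8)  # append EOS, pad commands/args to length 8
print(svg_t.commands.shape)     # torch.Size([8, 1])

rel = svg_t.get_relative_args() # positions made relative, shifted by ARGS_DIM - 1
print(rel.shape)                # torch.Size([8, 11])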
deepsvg/difflib/utils.py
ADDED
@@ -0,0 +1,81 @@
import torch
import matplotlib.pyplot as plt
import PIL.Image
import io


def set_viewbox(viewbox):
    plt.xlim(0, viewbox[0])
    plt.ylim(viewbox[1], 0)


def plot_points(p, viewbox=None, show_color=False, show_colorbar=False, image_file=None, return_img=False):
    cm = plt.cm.get_cmap('RdYlBu')
    plt.gca().set_aspect('equal')
    plt.gca().invert_yaxis()
    plt.gca().axis('off')

    if viewbox is not None:
        set_viewbox(viewbox)

    kwargs = {"c": range(len(p)), "cmap": cm} if show_color else {}
    plt.scatter(p[:, 0], p[:, 1], **kwargs)

    if show_color and show_colorbar:
        plt.colorbar()

    if image_file is not None:
        plt.savefig(image_file, bbox_inches='tight')

    if return_img:
        buf = io.BytesIO()
        plt.gcf().savefig(buf)
        buf.seek(0)
        return PIL.Image.open(buf)


def plot_matching(p1, p2, matching, viewbox=None):
    plt.gca().set_aspect('equal')
    plt.gca().invert_yaxis()
    plt.axis("off")

    if viewbox is not None:
        set_viewbox(viewbox)

    plt.scatter(p1[:, 0], p1[:, 1], color="C0")
    plt.scatter(p2[:, 0], p2[:, 1], color="C1")

    for start, end in zip(p1[::10], p2[matching][::10]):
        plt.plot([start[0], end[0]], [start[1], end[1]], color="C2")


def is_clockwise(p):
    start, end = p[:-1], p[1:]
    return torch.stack([start, end], dim=-1).det().sum() > 0


def make_clockwise(p):
    if not is_clockwise(p):
        return p.flip(dims=[0])
    return p


def reorder(p, i):
    return torch.cat([p[i:], p[:i]])


def get_length(p):
    start, end = p[:-1], p[1:]
    return torch.norm(end - start, dim=-1).sum()


def get_length_distribution(p, normalize=True):
    start, end = p[:-1], p[1:]
    length_distr = torch.norm(end - start, dim=-1).cumsum(dim=0)
    length_distr = torch.cat([length_distr.new_zeros(1),
                              length_distr])

    if normalize:
        length_distr = length_distr / length_distr[-1]

    return length_distr
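The geometric helpers all take an (N, 2) tensor of polyline points. A quick sketch of the orientation and arc-length functions (not part of this commit; note that is_clockwise treats positive signed area as "clockwise", consistent with a y-down image coordinate system):

# --- usage sketch (not part of this commit) ---
import torch
from deepsvg.difflib.utils import is_clockwise, make_clockwise, get_length, get_length_distribution

# A closed unit square
square = torch.tensor([[0., 0.], [1., 0.], [1., 1.], [0., 1.], [0., 0.]])

print(is_clockwise(square))          # tensor(True): positive signed area, y-down convention
path = make_clockwise(square)        # already "clockwise", returned unchanged
print(get_length(path))              # tensor(4.) -- perimeter of the square
print(get_length_distribution(path)) # tensor([0.00, 0.25, 0.50, 0.75, 1.00])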
deepsvg/gui/README.md
ADDED
@@ -0,0 +1,2 @@
# DeepSVG Editor: a GUI for easy SVG animation
deepsvg/gui/__init__.py
ADDED
File without changes
deepsvg/gui/config.py
ADDED
@@ -0,0 +1,5 @@
import os

ROOT_DIR = "./gui_data"
STATE_PATH = os.path.join(ROOT_DIR, "state.pkl")
TMP_PATH = os.path.join(ROOT_DIR, "tmp")
deepsvg/gui/deepsvg.kv
ADDED
@@ -0,0 +1,380 @@
<DeepSVGWidget>:
    orientation: "vertical"

    Header:
        id: header

    BoxLayout:
        orientation: "horizontal"

        Sidebar:
            id: sidebar_scroll

        StencilView:
            size_hint: 1, 1

            canvas.before:
                Color:
                    rgb: 0.89, 0.89, 0.89
                Rectangle:
                    pos: self.pos
                    size: self.size

            EditorView:
                id: editor

    TimeLine:
        id: timeline_scroll


<Header>:
    orientation: "horizontal"
    size_hint_y: None
    height: 50

    canvas.before:
        Color:
            rgb: 0, 0, 0
        Rectangle:
            pos: self.pos
            size: self.size
    HeaderIcon:
        source: "deepsvg/gui/res/hand.png"
        index: 0

    HeaderIcon:
        source: "deepsvg/gui/res/pen.png"
        index: 1

    HeaderIcon:
        source: "deepsvg/gui/res/pencil.png"
        index: 2

    Padding

    HeaderButton:
        text: "Clear all"
        on_press: root.on_erase()

    Padding

    HeaderButton:
        text: "Done"
        on_press: root.on_done()

    Label

    Padding

    TitleWidget:
        text: root.title
        on_text: root.on_title(self.text)

    Padding

    Label

    HeaderButton:
        text: "Add frame"
        on_press: root.add_frame()

    Padding

    HeaderButton:
        text: "Interpolate"
        on_press: root.interpolate()

    Padding

    HeaderIcon:
        index: 3
        source: "deepsvg/gui/res/pause.png" if root.is_playing else "deepsvg/gui/res/play.png"
        on_press: root.pause_animation() if root.is_playing else root.play_animation()


<TitleWidget>:
    size_hint_x: None
    width: 150

    multiline: False
    background_color: 0, 0, 0, 1
    background_active: ""
    background_normal: ""
    halign: "center"
    valign: "middle"
    foreground_color: 1, 1, 1, 1
    hint_text_color: 1, 1, 1, 1
    cursor_color: 1, 1, 1, 1

<Sidebar>:
    do_scroll_x: False
    size_hint_x: None
    width: 225

    canvas.before:
        Color:
            rgb: 1, 1, 1
        Rectangle:
            pos: self.pos
            size: self.size

        Color:
            rgb: 0.8, 0.8, 0.8
        Line:
            width: 1
            rectangle: self.x, self.y, self.width, self.height

    BoxLayout:
        id: sidebar

        orientation: "vertical"
        size_hint_y: None
        height: self.children[0].height * len(self.children) if self.children else 0


<PathLayerView>
    orientation: "horizontal"
    size_hint_y: None
    height: 40

    canvas:
        Color:
            rgb: (0.08, 0.58, 0.97) if self.parent is not None and self.index == self.parent.parent.selected_path_idx else (1, 1, 1)
        Rectangle:
            pos: self.pos
            size: self.size
        Color:
            rgb: 0.8, 0.8, 0.8
        Line:
            width: 1
            rectangle: self.x, self.y, self.width, self.height

    Label:
        color: 0, 0, 0, 1
        size_hint_x: None
        text: str(root.index)
        width: self.texture_size[0]
        padding_x: 10

    Label

    Image:
        size_hint_x: None
        source: root.source
        nocache: True

    Label

    UpButton

    DownButton

    Padding

    ReverseButton

    Label


<UpButton>:
    size_hint: None, None
    height: 0.6 * self.parent.height
    pos_hint: {'top': 0.8}
    width: self.height

    background_normal: ""
    background_down: ""
    background_color: 0.3, 0.3, 0.3, 1

    Image:
        source: "deepsvg/gui/res/up.png"
        center: self.parent.center

<DownButton>:
    size_hint: None, None
    height: 0.6 * self.parent.height
    pos_hint: {'top': 0.8}
    width: self.height

    background_normal: ""
    background_down: ""
    background_color: 0.3, 0.3, 0.3, 1

    Image:
        source: "deepsvg/gui/res/down.png"
        center: self.parent.center

<ReverseButton>:
    size_hint: None, None
    height: 0.6 * self.parent.height
    pos_hint: {'top': 0.8}
    width: self.height

    background_normal: ""
    background_down: ""
    background_color: 0.3, 0.3, 0.3, 1

    Image:
        source: "deepsvg/gui/res/switch.png"
        center: self.parent.center


<BezierSegment>:
    canvas:
        Color:
            rgb: .769, .769, .769
        Line:
            points: [*self.p1, *self.q1] if root.parent and root.parent.selected and self.is_curved else []
            dash_length: 5
            dash_offset: 5
        Line:
            points: [*self.q2, *self.p2] if root.parent and root.parent.selected and self.is_curved else []
            dash_length: 5
            dash_offset: 5
        Color:
            rgb: tuple(root.parent.color) if root.parent is not None else (.043, .769, 1)
        Line:
            bezier: ([*self.p1, *self.q1, *self.q2, *self.p2] if self.is_curved else [*self.p1, *self.p2]) if self.is_finished else [-10000, -10000]
            width: 1.1
        Color:
            rgb: 1, .616, .043
        Point:
            points: [*self.p1, *self.p2] if root.parent and root.parent.selected else []
            pointsize: 1.5
        Color:
            rgb: .769, .769, .769
        Point:
            points: [*self.q1, *self.q2] if self.is_curved and root.parent and root.parent.selected else []
            pointsize: 1.5


<BezierPath>


<Sketch>:
    canvas:
        Color:
            rgb: root.color
        Line:
            points: root.points
            width: 1.2


<EditorView>:
    size_hint: None, None
    size: draw_viewbox.size
    center: self.parent.center
    scale: 1.5

    DrawViewbox:
        id: draw_viewbox


<DrawViewbox>
    size: 256, 256

    canvas.before:
        Color:
            rgb: 1, 1, 1
        Rectangle:
            pos: self.pos
            size: self.size

<TimeLine>:
    do_scroll_y: False
    size_hint_y: None
    height: 50

    canvas.before:
        Color:
            rgb: 1, 1, 1
        Rectangle:
            pos: self.pos
            size: self.size

        Color:
            rgb: 0.8, 0.8, 0.8
        Line:
            width: 1
            rectangle: self.x, self.y, self.width, self.height

    BoxLayout:
        id: timeline

        orientation: "horizontal"
        size_hint_x: None
        width: 50 * len(self.children) if self.children else 0


<FrameView>
    size_hint_x: None
    width: self.height

    color: 0, 0, 0, 1
    text: str(self.index)

    background_normal: ""
    background_down: ""
    background_color: (0.08, 0.58, 0.97, 1) if self.parent and self.index == self.parent.parent.selected_frame and self.keyframe else (0.48, 0.78, 1, 1) if self.parent and self.index == self.parent.parent.selected_frame and not self.keyframe else (1, 0.67, 0.19, 1) if self.keyframe else (1, 1, 1, 1)

    canvas:
        Color:
            rgb: 0.8, 0.8, 0.8
        Line:
            width: 1
            rectangle: self.x, self.y, self.width, self.height


<HeaderIcon>:
    size_hint_x: None
    width: self.height

    canvas:
        Color:
            rgb: (0.08, 0.58, 0.97) if self.index == self.parent.selected_tool else (0, 0, 0)
        Rectangle:
            pos: self.pos
            size: self.size

    Image:
        source: self.parent.source
        center: self.parent.center

<HeaderButton>:
    size_hint: None, None
    height: 0.8 * self.parent.height
    pos_hint: {'top': 0.9}
    width: self.texture_size[0] + 40

<Padding>:
    size_hint_x: None
    width: 10


<FileChoosePopup>:
    title: "Import SVG file"
    size_hint: .9, .9
    auto_dismiss: True

    BoxLayout:
        orientation: "vertical"
        FileChooserIconView:
            id: filechooser
            path: root.path

        BoxLayout:
            size_hint_x: 1
            size_hint_y: None
            height: 50
            pos_hint: {'center_x': .5, 'center_y': .5}
            spacing: 20
            Button:
                text: "Cancel"
                on_release: root.dismiss()
            Button:
                text: "Load"
                on_release: root.load(filechooser.selection)
                id: ldbtn
                disabled: True if filechooser.selection==[] else False
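Kivy applies a kv rule file either automatically (an app class named DeepSVGApp picks up a deepsvg.kv placed where Kivy looks for it) or explicitly via the Builder. A minimal sketch of the explicit route (not part of this commit):

# --- usage sketch (not part of this commit) ---
from kivy.lang import Builder

# Load the widget rules manually, e.g. from a script outside the gui package;
# the path assumes the repository root is the working directory.
Builder.load_file("deepsvg/gui/deepsvg.kv")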
deepsvg/gui/interpolate.py
ADDED
@@ -0,0 +1,126 @@
import torch
from torch.utils.data import DataLoader
import torch.nn as nn

from configs.deepsvg.hierarchical_ordered import Config

from deepsvg import utils
from deepsvg.svglib.svg import SVG
from deepsvg.difflib.tensor import SVGTensor
from deepsvg.svglib.geom import Bbox
from deepsvg.svgtensor_dataset import load_dataset, SVGFinetuneDataset
from deepsvg.utils.utils import batchify

from .state.project import DeepSVGProject, Frame
from .utils import easein_easeout


device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
pretrained_path = "./pretrained/hierarchical_ordered.pth.tar"

cfg = Config()
cfg.model_cfg.dropout = 0.  # for faster convergence
model = cfg.make_model().to(device)
model.eval()


dataset = load_dataset(cfg)


def decode(z):
    commands_y, args_y, _ = model.greedy_sample(z=z)
    tensor_pred = SVGTensor.from_cmd_args(commands_y[0].cpu(), args_y[0].cpu())
    svg_path_sample = SVG.from_tensor(tensor_pred.data, viewbox=Bbox(256))

    return svg_path_sample


def encode_svg(svg):
    data = dataset.get(model_args=[*cfg.model_args, "tensor_grouped"], svg=svg)
    model_args = batchify((data[key] for key in cfg.model_args), device)
    z = model(*model_args, encode_mode=True)
    return z


def interpolate_svg(svg1, svg2, n=10, ease=True):
    z1, z2 = encode_svg(svg1), encode_svg(svg2)

    alphas = torch.linspace(0., 1., n+2)[1:-1]
    if ease:
        alphas = easein_easeout(alphas)

    z_list = [(1 - a) * z1 + a * z2 for a in alphas]
    svgs = [decode(z) for z in z_list]

    return svgs


def finetune_model(project: DeepSVGProject, nb_augmentations=3500):
    keyframe_ids = [i for i, frame in enumerate(project.frames) if frame.keyframe]

    if len(keyframe_ids) < 2:
        return

    svgs = [project.frames[i].svg for i in keyframe_ids]

    utils.load_model(pretrained_path, model)
    print("Finetuning...")
    finetune_dataset = SVGFinetuneDataset(dataset, svgs, frac=1.0, nb_augmentations=nb_augmentations)
    dataloader = DataLoader(finetune_dataset, batch_size=cfg.batch_size, shuffle=True, drop_last=False,
                            num_workers=cfg.loader_num_workers, collate_fn=cfg.collate_fn)

    # Optimizer, lr & warmup schedulers
    optimizers = cfg.make_optimizers(model)
    scheduler_lrs = cfg.make_schedulers(optimizers, epoch_size=len(dataloader))
    scheduler_warmups = cfg.make_warmup_schedulers(optimizers, scheduler_lrs)

    loss_fns = [l.to(device) for l in cfg.make_losses()]

    epoch = 0
    for step, data in enumerate(dataloader):
        model.train()
        model_args = [data[arg].to(device) for arg in cfg.model_args]
        labels = data["label"].to(device) if "label" in data else None
        params_dict, weights_dict = cfg.get_params(step, epoch), cfg.get_weights(step, epoch)

        for i, (loss_fn, optimizer, scheduler_lr, scheduler_warmup, optimizer_start) in enumerate(
                zip(loss_fns, optimizers, scheduler_lrs, scheduler_warmups, cfg.optimizer_starts), 1):
            optimizer.zero_grad()

            output = model(*model_args, params=params_dict)
            loss_dict = loss_fn(output, labels, weights=weights_dict)

            loss_dict["loss"].backward()
            if cfg.grad_clip is not None:
                nn.utils.clip_grad_norm_(model.parameters(), cfg.grad_clip)

            optimizer.step()
            if scheduler_lr is not None:
                scheduler_lr.step()
            if scheduler_warmup is not None:
                scheduler_warmup.step()

        if step % 20 == 0:
            print(f"Step {step}: loss: {loss_dict['loss']}")

    print("Finetuning done.")


def compute_interpolation(project: DeepSVGProject):
    finetune_model(project)

    keyframe_ids = [i for i, frame in enumerate(project.frames) if frame.keyframe]

    if len(keyframe_ids) < 2:
        return

    model.eval()

    for i1, i2 in zip(keyframe_ids[:-1], keyframe_ids[1:]):
        frames_inbetween = i2 - i1 - 1
        if frames_inbetween == 0:
            continue

        svgs = interpolate_svg(project.frames[i1].svg, project.frames[i2].svg, n=frames_inbetween, ease=False)
        for di, svg in enumerate(svgs, 1):
            project.frames[i1 + di] = Frame(i1 + di, keyframe=False, svg=svg)
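Note that importing this module runs the module-level statements above: it builds the model, loads the pretrained weights path, and loads the dataset, so the checkpoint and dataset paths must exist. A sketch of standalone use outside the GUI (not part of this commit; "a.svg" and "b.svg" are placeholder file names, and real inputs likely need the same preprocessing as the dataset's SVGs):

# --- usage sketch (not part of this commit) ---
from deepsvg.svglib.svg import SVG
from deepsvg.gui.interpolate import interpolate_svg  # triggers model/dataset setup

svg1, svg2 = SVG.load_svg("a.svg"), SVG.load_svg("b.svg")

frames = interpolate_svg(svg1, svg2, n=8)  # 8 in-between frames, eased by default
for i, svg in enumerate(frames):
    svg.save_svg(f"frame_{i}.svg")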
deepsvg/gui/layout/__init__.py
ADDED
File without changes
deepsvg/gui/layout/aligned_textinput.py
ADDED
@@ -0,0 +1,52 @@
from kivy.uix.textinput import TextInput
from kivy.properties import StringProperty

DEFAULT_PADDING = 6


class AlignedTextInput(TextInput):
    halign = StringProperty('left')
    valign = StringProperty('top')

    def __init__(self, **kwargs):
        self.halign = kwargs.get("halign", "left")
        self.valign = kwargs.get("valign", "top")

        self.bind(on_text=self.on_text)

        super().__init__(**kwargs)

    def on_text(self, instance, value):
        self.redraw()

    def on_size(self, instance, value):
        self.redraw()

    def redraw(self):
        """
        Note: This method depends on internal variables of its TextInput
        base class (_lines_rects and _refresh_text())
        """

        self._refresh_text(self.text)

        max_size = max(self._lines_rects, key=lambda r: r.size[0]).size
        num_lines = len(self._lines_rects)

        px = [DEFAULT_PADDING, DEFAULT_PADDING]
        py = [DEFAULT_PADDING, DEFAULT_PADDING]

        if self.halign == 'center':
            d = (self.width - max_size[0]) / 2.0 - DEFAULT_PADDING
            px = [d, d]
        elif self.halign == 'right':
            px[0] = self.width - max_size[0] - DEFAULT_PADDING

        if self.valign == 'middle':
            d = (self.height - max_size[1] * num_lines) / 2.0 - DEFAULT_PADDING
            py = [d + 5, d - 5]
        elif self.valign == 'bottom':
            py[0] = self.height - max_size[1] * num_lines - DEFAULT_PADDING

        self.padding_x = px
        self.padding_y = py
deepsvg/gui/main.py
ADDED
@@ -0,0 +1,794 @@
from kivy.app import App
from kivy.uix.widget import Widget
from kivy.uix.boxlayout import BoxLayout
from kivy.uix.button import Button
from kivy.uix.scatter import Scatter
from kivy.uix.label import Label
from kivy.uix.scrollview import ScrollView
from kivy.properties import BooleanProperty, StringProperty, NumericProperty, ListProperty, ObjectProperty
from kivy.uix.behaviors import ButtonBehavior
from kivy.vector import Vector
from kivy.metrics import dp
from kivy.clock import Clock
from kivy.uix.popup import Popup

from kivy.config import Config
Config.set('graphics', 'width', '1400')
Config.set('graphics', 'height', '800')
from kivy.core.window import Window

import os
from typing import List

from deepsvg.svglib.geom import Point
from deepsvg.svglib.svg_command import SVGCommandMove, SVGCommandLine, SVGCommandBezier
from deepsvg.svgtensor_dataset import SVGTensorDataset

from .layout.aligned_textinput import AlignedTextInput
from .state.state import State, ToolMode, DrawMode, LoopMode, PlaybackMode
from .state.project import Frame
from .config import ROOT_DIR
from .interpolate import compute_interpolation
from .utils import *


if not os.path.exists(ROOT_DIR):
    os.makedirs(ROOT_DIR)

state = State()
state.load_state()
state.load_project()


class HeaderIcon(Button):
    index = NumericProperty(0)
    source = StringProperty("")

    def on_press(self):
        state.header.selected_tool = self.index


class Header(BoxLayout):
    selected_tool = NumericProperty(0)
    title = StringProperty(state.project.name)
    is_playing = BooleanProperty(False)
    delay = NumericProperty(state.delay)

    def on_selected_tool(self, *args):
        if self.selected_tool in [ToolMode.MOVE, ToolMode.PEN, ToolMode.PENCIL] and state.header.is_playing:
            state.header.pause_animation()

    def on_done(self, *args):
        if self.selected_tool == ToolMode.PEN and state.draw_mode == DrawMode.DRAW:
            path = state.current_path

            last_segment = path.children[-1]
            path.remove_widget(last_segment)

            state.draw_viewbox.on_path_done(state.current_path)

            state.draw_mode = DrawMode.STILL
            state.current_path = None
            self.selected_tool = ToolMode.MOVE

    def on_erase(self):
        state.modified = True
        state.draw_viewbox.clear()

        state.timeline.make_keyframe(False)

    def add_frame(self, keyframe=False):
        frame_idx = state.timeline._add_frame(keyframe=keyframe)

        state.project.frames.append(Frame(frame_idx, keyframe))

        self.load_next_frame(frame_idx=frame_idx)

    def play_animation(self):
        self.is_playing = True
        state.sidebar.selected_path_idx = -1
        self.clock = Clock.schedule_once(self.load_next_frame)

    def load_next_frame(self, dt=0, frame_idx=None, *args):
        if state.timeline.nb_frames > 0:
            if frame_idx is None:
                frame_idx_tmp = state.timeline.selected_frame + state.loop_orientation

                if frame_idx_tmp < 0 or frame_idx_tmp >= state.timeline.nb_frames:
                    if state.loop_mode in [LoopMode.NORMAL, LoopMode.REVERSE]:
                        frame_idx = frame_idx_tmp % state.timeline.nb_frames
                    else:  # LoopMode.PINGPONG
                        state.loop_orientation *= -1
                        frame_idx = (state.timeline.selected_frame + state.loop_orientation) % state.timeline.nb_frames
                else:
                    frame_idx = frame_idx_tmp

            state.timeline.selected_frame = frame_idx

        if self.is_playing:
            if state.playback_mode == PlaybackMode.EASE:
                t = frame_idx / state.timeline.nb_frames
                delay = 2 * state.delay / (1 + d_easein_easeout(t))
            else:
                delay = state.delay
            self.clock = Clock.schedule_once(self.load_next_frame, delay)

    def pause_animation(self):
        self.clock.cancel()
        state.sidebar.selected_path_idx = -1
        self.is_playing = False

        state.timeline.on_selected_frame()  # re-render frame to display sidebar layers

    def on_title(self, title):
        state.project.name = title

    def interpolate(self):
        state.draw_viewbox.save_frame()

        compute_interpolation(state.project)


class PathLayerView(ButtonBehavior, BoxLayout):
    index = NumericProperty(0)
    source = StringProperty("")

    def __init__(self, index, **kwargs):
        super().__init__(**kwargs)

        self.index = index
        self.source = os.path.join(state.project.cache_dir, f"{state.timeline.selected_frame}_{index}.png")

    def on_press(self):
        state.sidebar.selected_path_idx = self.index

    def move_up(self):
        if self.index > 0:
            state.sidebar.swap_paths(self.index, self.index - 1)

    def move_down(self):
        if self.index < state.sidebar.nb_paths - 1:
            state.sidebar.swap_paths(self.index, self.index + 1)

    def reverse(self):
        state.sidebar.reverse_path(self.index)


class Sidebar(ScrollView):
    selected_path_idx = NumericProperty(-1)

    @property
    def sidebar(self):
        return self.ids.sidebar

    @property
    def nb_paths(self):
        return len(self.sidebar.children)

    def on_selected_path_idx(self, *args):
        state.draw_viewbox.unselect_all()

        if self.selected_path_idx >= 0:
            state.draw_viewbox.get_path(self.selected_path_idx).selected = True

    def _add_path(self, idx=None):
        if idx is None:
            idx = self.nb_paths
        new_pathlayer = PathLayerView(idx)
        self.sidebar.add_widget(new_pathlayer)
        return idx

    def get_path(self, path_idx):
        index = self.nb_paths - 1 - path_idx
        return self.sidebar.children[index]

    def erase(self):
        self.sidebar.clear_widgets()
        self.selected_path_idx = -1

    def swap_paths(self, idx1, idx2):
        path_layer1, path_layer2 = self.get_path(idx1), self.get_path(idx2)
        path1, path2 = state.draw_viewbox.get_path(idx1), state.draw_viewbox.get_path(idx2)

        path_layer1.index, path_layer2.index = idx2, idx1
        path1.color, path2.color = path2.color, path1.color
        path1.index, path2.index = path2.index, path1.index

        id1, id2 = self.nb_paths - 1 - idx1, self.nb_paths - 1 - idx2
        self.sidebar.children[id1], self.sidebar.children[id2] = path_layer2, path_layer1
        state.draw_viewbox.children[id1], state.draw_viewbox.children[id2] = path2, path1

        self.selected_path_idx = idx2
        state.modified = True

    def reverse_path(self, idx):
        path = state.draw_viewbox.get_path(idx)
        svg_path = path.to_svg_path().reverse()
        new_path = BezierPath.from_svg_path(svg_path, color=path.color, index=path.index, selected=path.selected)

        id = self.nb_paths - 1 - idx
        state.draw_viewbox.remove_widget(path)
        state.draw_viewbox.add_widget(new_path, index=id)

        self.selected_path_idx = idx
        state.modified = True

    def select(self, path_idx):
        if self.selected_path_idx >= 0:
            state.draw_viewbox.get_path(state.sidebar.selected_path_idx).selected = False
        self.selected_path_idx = path_idx


class BezierSegment(Widget):
    is_curved = BooleanProperty(True)

    is_finished = BooleanProperty(True)
    select_dist = NumericProperty(3)

    p1 = ListProperty([0, 0])
    q1 = ListProperty([0, 0])
    q2 = ListProperty([0, 0])
    p2 = ListProperty([0, 0])

    def clone(self):
        segment = BezierSegment()
        segment.is_curved = self.is_curved
        segment.p1 = self.p1  # shallow copy
        segment.q1 = self.q1
        segment.q2 = self.q2
        segment.p2 = self.p2
        return segment

    @staticmethod
    def line(p1, p2):
        segment = BezierSegment()
        segment.is_curved = False
        segment.p1 = segment.q1 = p1
        segment.p2 = segment.q2 = p2
        return segment

    @staticmethod
    def bezier(p1, q1, q2, p2):
        segment = BezierSegment()
        segment.is_curved = True
        segment.q1, segment.q2 = q1, q2
        segment.p1, segment.p2 = p1, p2
        return segment

    def get_point(self, key):
        return getattr(self, key)

    def on_touch_down(self, touch):
        max_dist = dp(self.select_dist)

        if not self.parent.selected:
            return super().on_touch_down(touch)

        keys_to_test = ["p1", "q1", "q2", "p2"] if self.is_curved else ["p1", "p2"]
        for key in keys_to_test:
            if dist(touch.pos, getattr(self, key)) < max_dist:
                touch.ud['selected'] = key
                touch.grab(self)

                state.modified = True

        return True

    def on_touch_move(self, touch):
        if touch.grab_current is not self:
            return super().on_touch_move(touch)

        key = touch.ud['selected']
        setattr(self, key, touch.pos)

        if state.header.selected_tool == ToolMode.PEN:
            self.is_curved = True
            self.is_finished = False
            state.draw_mode = DrawMode.HOLDING_DOWN

            setattr(self, "p2", touch.pos)

        if key in ["p1", "p2"]:
            self.parent.move(self, key, touch.pos)

    def on_touch_up(self, touch):
        if touch.grab_current is not self:
            return super().on_touch_up(touch)

        touch.ungrab(self)

        if state.header.selected_tool == ToolMode.PEN:
            self.is_finished = True
            state.draw_mode = DrawMode.DRAW


class BezierPath(Widget):
    color = ListProperty([1, 1, 1])
    index = NumericProperty(0)
    selected = BooleanProperty(False)

    def __init__(self, segments: List[BezierSegment], color=None, index=None, selected=False, **kwargs):
        super().__init__(**kwargs)

        if color is not None:
            self.color = color

        if index is not None:
            self.index = index

        self.selected = selected

        for segment in segments:
            self.add_segment(segment)

    def clone(self):
        segments = [segment.clone() for segment in self.children]
        return BezierPath(segments, self.color, self.index, self.selected)

    def add_segment(self, segment: BezierSegment):
        self.add_widget(segment, index=len(self.children))

    def move(self, segment, key, pos):
        idx = self.children.index(segment)

        if not (idx == 0 and key == "p1") and not (idx == len(self.children) - 1 and key == "p2"):
            idx2, key2 = (idx-1, "p2") if key == "p1" else (idx+1, "p1")
            setattr(self.children[idx2], key2, pos)

    def add_widget(self, widget, index=0, canvas=None):
        super().add_widget(widget, index=index, canvas=canvas)

    def remove_widget(self, widget):
        super().remove_widget(widget)

    @staticmethod
    def from_svg_path(svg_path: SVGPath, *args, **kwargs):
        segments = []
        for command in svg_path.path_commands:
            if isinstance(command, SVGCommandBezier):
                segment = BezierSegment.bezier(flip_vertical(command.p1.tolist()), flip_vertical(command.q1.tolist()),
                                               flip_vertical(command.q2.tolist()), flip_vertical(command.p2.tolist()))
                segments.append(segment)
            elif isinstance(command, SVGCommandLine):
                segment = BezierSegment.line(flip_vertical(command.start_pos.tolist()),
                                             flip_vertical(command.end_pos.tolist()))
                segments.append(segment)

        path = BezierPath(segments, *args, **kwargs)
        return path

    def to_svg_path(self):
        path_commands = []
        for segment in self.children:
            if segment.is_curved:
                command = SVGCommandBezier(Point(*flip_vertical(segment.p1)), Point(*flip_vertical(segment.q1)),
                                           Point(*flip_vertical(segment.q2)), Point(*flip_vertical(segment.p2)))
            else:
                command = SVGCommandLine(Point(*flip_vertical(segment.p1)), Point(*flip_vertical(segment.p2)))
            path_commands.append(command)
        svg_path = SVGPath(path_commands)
        return svg_path


class Sketch(Widget):
    color = ListProperty([1, 1, 1])
    points = ListProperty([])

    def __init__(self, points, color=None, **kwargs):
        super().__init__(**kwargs)

        if color is not None:
            self.color = color

        self.points = points

    def on_touch_move(self, touch):
        if touch.grab_current is not self:
            return super().on_touch_move(touch)

        self.points.extend(touch.pos)

    def on_touch_up(self, touch):
        if touch.grab_current is not self:
            return super().on_touch_up(touch)

        touch.ungrab(self)

        self.parent.on_sketch_done(self)

    def to_svg_path(self):
        points = [Point(x, 255 - y) for x, y in zip(self.points[::2], self.points[1::2])]
        commands = [SVGCommandMove(points[0])] + [SVGCommandLine(p1, p2) for p1, p2 in zip(points[:-1], points[1:])]
        svg_path = SVGPath.from_commands(commands).path
        return svg_path


class EditorView(Scatter):
    def on_touch_down(self, touch):
        if self.collide_point(*touch.pos) and touch.is_mouse_scrolling:
            if touch.button == 'scrolldown':
                if self.scale < 10:
                    self.scale = self.scale * 1.1
            elif touch.button == 'scrollup':
                if self.scale > 1:
                    self.scale = self.scale * 0.8
            return True

        return super().on_touch_down(touch)


class DrawViewbox(Widget):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        Window.bind(mouse_pos=self.on_mouse_pos)

    @property
    def nb_paths(self):
        return len(self.children)

    def _get_color(self, idx):
        color = color_dict[colors[idx % len(colors)]]
        return color

    def on_mouse_pos(self, _, abs_pos):
        pos = (Vector(abs_pos) - Vector(self.parent.pos)) / self.parent.scale

        if state.header.selected_tool == ToolMode.PEN and state.draw_mode == DrawMode.DRAW:
            segment = state.current_path.children[-1]
            segment.p2 = segment.q2 = pos

    def on_sketch_done(self, sketch: Sketch):
        # Digitize points to a Bézier path
        svg_path = preprocess_svg_path(sketch.to_svg_path(), force_smooth=True)

        path_idx = state.sidebar.nb_paths
        path = BezierPath.from_svg_path(svg_path, color=sketch.color, index=path_idx, selected=True)
        self.remove_widget(sketch)

        self.add_new_path(path, svg_path)

    def on_path_done(self, path: BezierPath):
        svg_path = preprocess_svg_path(path.to_svg_path())

        path_idx = state.sidebar.nb_paths
        new_path = BezierPath.from_svg_path(svg_path, color=path.color, index=path_idx, selected=True)
        self.remove_widget(path)

        self.add_new_path(new_path, svg_path)
|
458 |
+
|
459 |
+
def paste(self, path: BezierPath):
|
460 |
+
path = path.clone()
|
461 |
+
|
462 |
+
path_idx = state.sidebar.nb_paths
|
463 |
+
path.color = self._get_color(path_idx)
|
464 |
+
path.selected = True
|
465 |
+
|
466 |
+
svg_path = path.to_svg_path()
|
467 |
+
|
468 |
+
self.add_new_path(path, svg_path)
|
469 |
+
|
470 |
+
def unselect_all(self):
|
471 |
+
for path in self.children:
|
472 |
+
path.selected = False
|
473 |
+
|
474 |
+
def get_path(self, path_idx):
|
475 |
+
index = self.nb_paths - 1 - path_idx
|
476 |
+
return self.children[index]
|
477 |
+
|
478 |
+
def add_new_path(self, path: BezierSegment, svg_path: SVGPath):
|
479 |
+
self.add_path(path, svg_path, force_rerender_miniature=True)
|
480 |
+
|
481 |
+
state.modified = True
|
482 |
+
state.timeline.make_keyframe(True)
|
483 |
+
state.sidebar.select(path.index)
|
484 |
+
|
485 |
+
def add_path(self, path: BezierPath, svg_path: SVGPath, force_rerender_miniature=False):
|
486 |
+
path_idx = state.sidebar.nb_paths
|
487 |
+
self.add_widget(path)
|
488 |
+
|
489 |
+
miniature_path = os.path.join(state.project.cache_dir, f"{state.timeline.selected_frame}_{path_idx}.png")
|
490 |
+
if not os.path.exists(miniature_path) or force_rerender_miniature:
|
491 |
+
svg_path = normalized_path(svg_path)
|
492 |
+
svg_path.draw(viewbox=svg_path.bbox().make_square(min_size=12),
|
493 |
+
file_path=os.path.join(state.project.cache_dir, f"{state.timeline.selected_frame}_{path_idx}.png"),
|
494 |
+
do_display=False)
|
495 |
+
|
496 |
+
if not state.header.is_playing:
|
497 |
+
state.sidebar._add_path()
|
498 |
+
|
499 |
+
def on_touch_down(self, touch):
|
500 |
+
if state.header.selected_tool == ToolMode.PLAY:
|
501 |
+
return False
|
502 |
+
|
503 |
+
if state.header.selected_tool == ToolMode.PEN and self.collide_point(*touch.pos):
|
504 |
+
state.draw_mode = DrawMode.DRAW
|
505 |
+
|
506 |
+
if state.current_path is None:
|
507 |
+
path = BezierPath([], color=self._get_color(len(self.children)), selected=True)
|
508 |
+
self.add_widget(path)
|
509 |
+
state.current_path = path
|
510 |
+
|
511 |
+
l = BezierSegment.line(touch.pos, touch.pos)
|
512 |
+
|
513 |
+
touch.ud["selected"] = "q1"
|
514 |
+
touch.grab(l)
|
515 |
+
|
516 |
+
state.current_path.add_segment(l)
|
517 |
+
|
518 |
+
state.modified = True
|
519 |
+
|
520 |
+
return True
|
521 |
+
|
522 |
+
if state.header.selected_tool == ToolMode.PENCIL and self.collide_point(*touch.pos):
|
523 |
+
l = Sketch([*touch.pos], color=self._get_color(len(self.children)))
|
524 |
+
self.add_widget(l)
|
525 |
+
touch.grab(l)
|
526 |
+
|
527 |
+
state.modified = True
|
528 |
+
|
529 |
+
return True
|
530 |
+
|
531 |
+
if super().on_touch_down(touch):
|
532 |
+
return True
|
533 |
+
|
534 |
+
def clear(self):
|
535 |
+
state.draw_viewbox.clear_widgets()
|
536 |
+
state.sidebar.erase()
|
537 |
+
|
538 |
+
def add_widget(self, widget, index=0, canvas=None):
|
539 |
+
super().add_widget(widget, index=index, canvas=canvas)
|
540 |
+
|
541 |
+
def remove_widget(self, widget):
|
542 |
+
super().remove_widget(widget)
|
543 |
+
|
544 |
+
def to_svg(self):
|
545 |
+
svg_path_groups = []
|
546 |
+
for path in reversed(self.children):
|
547 |
+
svg_path_groups.append(path.to_svg_path().to_group())
|
548 |
+
|
549 |
+
svg = SVG(svg_path_groups, viewbox=Bbox(256))
|
550 |
+
return svg
|
551 |
+
|
552 |
+
def load_svg(self, svg: SVG, frame_idx):
|
553 |
+
kivy_bezierpaths = []
|
554 |
+
for idx, svg_path in enumerate(svg.paths):
|
555 |
+
path = BezierPath.from_svg_path(svg_path, color=self._get_color(idx), index=idx, selected=False)
|
556 |
+
kivy_bezierpaths.append(path)
|
557 |
+
self.add_path(path, svg_path, force_rerender_miniature=True)
|
558 |
+
|
559 |
+
state.project.frames[frame_idx].svg = svg
|
560 |
+
state.project.frames[frame_idx].kivy_bezierpaths = kivy_bezierpaths
|
561 |
+
|
562 |
+
def load_cached(self, svg: SVG, kivy_bezierpaths: List[BezierPath]):
|
563 |
+
for path, svg_path in zip(kivy_bezierpaths, svg.paths):
|
564 |
+
self.add_path(path, svg_path)
|
565 |
+
|
566 |
+
def load_frame(self, frame_idx):
|
567 |
+
svg = state.project.frames[frame_idx].svg
|
568 |
+
kivy_bezierpaths = state.project.frames[frame_idx].kivy_bezierpaths
|
569 |
+
|
570 |
+
if kivy_bezierpaths is None:
|
571 |
+
self.load_svg(svg, frame_idx)
|
572 |
+
else:
|
573 |
+
self.load_cached(svg, kivy_bezierpaths)
|
574 |
+
|
575 |
+
self.unselect_all()
|
576 |
+
|
577 |
+
def save_frame(self):
|
578 |
+
svg = self.to_svg()
|
579 |
+
state.project.frames[state.current_frame].svg = svg
|
580 |
+
state.project.frames[state.current_frame].kivy_bezierpaths = [child for child in reversed(self.children) if isinstance(child, BezierPath)]
|
581 |
+
|
582 |
+
|
583 |
+
class HeaderButton(Button):
|
584 |
+
pass
|
585 |
+
|
586 |
+
|
587 |
+
class UpButton(Button):
|
588 |
+
def on_press(self):
|
589 |
+
self.parent.move_up()
|
590 |
+
|
591 |
+
|
592 |
+
class DownButton(Button):
|
593 |
+
def on_press(self):
|
594 |
+
self.parent.move_down()
|
595 |
+
|
596 |
+
|
597 |
+
class ReverseButton(Button):
|
598 |
+
def on_press(self):
|
599 |
+
self.parent.reverse()
|
600 |
+
|
601 |
+
|
602 |
+
class FrameView(Button):
|
603 |
+
index = NumericProperty(0)
|
604 |
+
keyframe = BooleanProperty(False)
|
605 |
+
|
606 |
+
def __init__(self, index, keyframe=False, **kwargs):
|
607 |
+
super().__init__(**kwargs)
|
608 |
+
|
609 |
+
self.index = index
|
610 |
+
self.keyframe = keyframe
|
611 |
+
|
612 |
+
def on_press(self):
|
613 |
+
state.timeline.selected_frame = self.index
|
614 |
+
|
615 |
+
|
616 |
+
class TimeLine(ScrollView):
|
617 |
+
selected_frame = NumericProperty(-1)
|
618 |
+
|
619 |
+
@property
|
620 |
+
def timeline(self):
|
621 |
+
return self.ids.timeline
|
622 |
+
|
623 |
+
@property
|
624 |
+
def nb_frames(self):
|
625 |
+
return len(self.timeline.children)
|
626 |
+
|
627 |
+
def on_selected_frame(self, *args):
|
628 |
+
self._update_frame(self.selected_frame)
|
629 |
+
|
630 |
+
def _update_frame(self, new_frame_idx):
|
631 |
+
if state.current_frame >= 0 and state.modified:
|
632 |
+
state.draw_viewbox.save_frame()
|
633 |
+
|
634 |
+
state.current_frame = new_frame_idx
|
635 |
+
state.draw_viewbox.clear()
|
636 |
+
state.modified = False
|
637 |
+
|
638 |
+
state.draw_viewbox.load_frame(new_frame_idx)
|
639 |
+
|
640 |
+
def _add_frame(self, keyframe=False):
|
641 |
+
idx = self.nb_frames
|
642 |
+
new_frame = FrameView(idx, keyframe=keyframe)
|
643 |
+
|
644 |
+
self.timeline.add_widget(new_frame)
|
645 |
+
return idx
|
646 |
+
|
647 |
+
def get_frame(self, frame_idx):
|
648 |
+
index = self.nb_frames - 1 - frame_idx
|
649 |
+
return self.timeline.children[index]
|
650 |
+
|
651 |
+
def make_keyframe(self, is_keyframe=None):
|
652 |
+
if is_keyframe is None:
|
653 |
+
is_keyframe = not self.get_frame(state.timeline.selected_frame).keyframe
|
654 |
+
|
655 |
+
self.get_frame(state.timeline.selected_frame).keyframe = is_keyframe
|
656 |
+
state.project.frames[state.timeline.selected_frame].keyframe = is_keyframe
|
657 |
+
|
658 |
+
|
659 |
+
class TitleWidget(AlignedTextInput):
|
660 |
+
pass
|
661 |
+
|
662 |
+
|
663 |
+
class Padding(Label):
|
664 |
+
pass
|
665 |
+
|
666 |
+
|
667 |
+
class FileChoosePopup(Popup):
|
668 |
+
load = ObjectProperty()
|
669 |
+
path = StringProperty(".")
|
670 |
+
|
671 |
+
|
672 |
+
class DeepSVGWidget(BoxLayout):
|
673 |
+
def __init__(self, **kwargs):
|
674 |
+
super().__init__(**kwargs)
|
675 |
+
|
676 |
+
state.main_widget = self
|
677 |
+
state.header = self.ids.header
|
678 |
+
state.sidebar = self.ids.sidebar_scroll
|
679 |
+
state.draw_viewbox = self.ids.editor.ids.draw_viewbox
|
680 |
+
state.timeline = self.ids.timeline_scroll
|
681 |
+
|
682 |
+
self._load_project()
|
683 |
+
|
684 |
+
def _load_project(self):
|
685 |
+
for frame in state.project.frames:
|
686 |
+
state.timeline._add_frame(keyframe=frame.keyframe)
|
687 |
+
|
688 |
+
state.timeline.selected_frame = 0
|
689 |
+
|
690 |
+
|
691 |
+
class DeepSVGApp(App):
|
692 |
+
def build(self):
|
693 |
+
self.title = 'DeepSVG Editor'
|
694 |
+
|
695 |
+
Window.bind(on_request_close=self.on_request_close)
|
696 |
+
Window.bind(on_keyboard=self.on_keyboard)
|
697 |
+
|
698 |
+
return DeepSVGWidget()
|
699 |
+
|
700 |
+
def save(self):
|
701 |
+
state.draw_viewbox.save_frame()
|
702 |
+
|
703 |
+
state.save_state()
|
704 |
+
state.project.save_project()
|
705 |
+
|
706 |
+
def on_request_close(self, *args, **kwargs):
|
707 |
+
self.save()
|
708 |
+
|
709 |
+
self.stop()
|
710 |
+
|
711 |
+
def on_keyboard(self, window, key, scancode, codepoint, modifier):
|
712 |
+
CTRL_PRESSED = (modifier == ['ctrl'] or modifier == ['meta'])
|
713 |
+
|
714 |
+
if codepoint == "h" and not CTRL_PRESSED:
|
715 |
+
# Hand tool
|
716 |
+
state.header.selected_tool = ToolMode.MOVE
|
717 |
+
|
718 |
+
elif codepoint == "p" and not CTRL_PRESSED:
|
719 |
+
# Pen tool
|
720 |
+
state.header.selected_tool = ToolMode.PEN
|
721 |
+
|
722 |
+
elif CTRL_PRESSED and codepoint == "p":
|
723 |
+
# Pencil tool
|
724 |
+
state.header.selected_tool = ToolMode.PENCIL
|
725 |
+
|
726 |
+
elif codepoint == "k" and not CTRL_PRESSED:
|
727 |
+
# Make keypoint
|
728 |
+
state.timeline.make_keyframe()
|
729 |
+
|
730 |
+
elif CTRL_PRESSED and codepoint == 'q':
|
731 |
+
# Quit
|
732 |
+
self.on_request_close()
|
733 |
+
|
734 |
+
elif CTRL_PRESSED and codepoint == 'i':
|
735 |
+
# Import
|
736 |
+
self.file_chooser = FileChoosePopup(load=self.on_file_chosen)
|
737 |
+
self.file_chooser.open()
|
738 |
+
|
739 |
+
elif CTRL_PRESSED and codepoint == "e":
|
740 |
+
# Export
|
741 |
+
state.project.export_to_gif(loop_mode=state.loop_mode)
|
742 |
+
|
743 |
+
elif CTRL_PRESSED and codepoint == 'c':
|
744 |
+
# Copy
|
745 |
+
if state.sidebar.selected_path_idx >= 0:
|
746 |
+
state.clipboard = state.draw_viewbox.get_path(state.sidebar.selected_path_idx).clone()
|
747 |
+
|
748 |
+
elif CTRL_PRESSED and codepoint == 'v':
|
749 |
+
# Paste
|
750 |
+
if isinstance(state.clipboard, BezierPath):
|
751 |
+
state.draw_viewbox.paste(state.clipboard)
|
752 |
+
|
753 |
+
elif CTRL_PRESSED and codepoint == 's':
|
754 |
+
# Save
|
755 |
+
self.save()
|
756 |
+
|
757 |
+
elif key == Keys.SPACEBAR:
|
758 |
+
# Play/Pause
|
759 |
+
state.header.selected_tool = ToolMode.PLAY
|
760 |
+
|
761 |
+
if state.header.is_playing:
|
762 |
+
state.header.pause_animation()
|
763 |
+
else:
|
764 |
+
state.header.play_animation()
|
765 |
+
|
766 |
+
elif key == Keys.LEFT:
|
767 |
+
# Previous frame
|
768 |
+
if state.current_frame > 0:
|
769 |
+
state.timeline.selected_frame = state.current_frame - 1
|
770 |
+
|
771 |
+
elif key == Keys.RIGHT:
|
772 |
+
# Next frame
|
773 |
+
if state.current_frame < state.timeline.nb_frames - 1:
|
774 |
+
state.timeline.selected_frame = state.current_frame + 1
|
775 |
+
|
776 |
+
def on_file_chosen(self, selection):
|
777 |
+
file_path = str(selection[0])
|
778 |
+
self.file_chooser.dismiss()
|
779 |
+
|
780 |
+
if file_path:
|
781 |
+
if not file_path.endswith(".svg"):
|
782 |
+
return
|
783 |
+
|
784 |
+
svg = SVG.load_svg(file_path)
|
785 |
+
svg = SVGTensorDataset.simplify(svg)
|
786 |
+
svg = SVGTensorDataset.preprocess(svg, mean=True)
|
787 |
+
|
788 |
+
state.draw_viewbox.load_svg(svg, frame_idx=state.timeline.selected_frame)
|
789 |
+
state.modified = True
|
790 |
+
state.timeline.make_keyframe(True)
|
791 |
+
|
792 |
+
|
793 |
+
if __name__ == "__main__":
|
794 |
+
DeepSVGApp().run()
|
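
The `on_keyboard` handler above maps h, p, Ctrl+P and k to the hand, pen, pencil and keyframe tools, Ctrl+Q/I/E/C/V/S to quit, import, export, copy, paste and save, and Space/Left/Right to playback and frame navigation. For reference, a minimal sketch of launching the editor from another script, assuming Kivy and the deepsvg package are importable:

from deepsvg.gui.main import DeepSVGApp

app = DeepSVGApp()   # build() creates DeepSVGWidget and binds the keyboard/close handlers
app.run()            # blocks until Ctrl+Q or a window close, which triggers save()
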
deepsvg/gui/res/down.png
ADDED
deepsvg/gui/res/hand.png
ADDED
deepsvg/gui/res/hand.svg
ADDED
deepsvg/gui/res/pause.png
ADDED
deepsvg/gui/res/pen.png
ADDED
deepsvg/gui/res/pen.svg
ADDED
deepsvg/gui/res/pencil.png
ADDED
deepsvg/gui/res/pencil.svg
ADDED
deepsvg/gui/res/play.png
ADDED
deepsvg/gui/res/play.svg
ADDED
deepsvg/gui/res/switch.png
ADDED
deepsvg/gui/res/up.png
ADDED
deepsvg/gui/state/__init__.py
ADDED
File without changes
deepsvg/gui/state/project.py
ADDED
@@ -0,0 +1,115 @@
import os
import uuid
import json
import numpy as np
from moviepy.editor import ImageClip, concatenate_videoclips
import shutil

from deepsvg.svglib.svg import SVG
from deepsvg.svglib.geom import Bbox

from ..config import ROOT_DIR


class Frame:
    def __init__(self, index, keyframe=False, svg=None):
        self.index = index
        self.keyframe = keyframe

        if svg is None:
            svg = SVG([], viewbox=Bbox(256))
        self.svg = svg

        self.kivy_bezierpaths = None

    def to_dict(self):
        return {
            "index": self.index,
            "keyframe": self.keyframe
        }

    @staticmethod
    def load_dict(frame):
        f = Frame(frame["index"], frame["keyframe"])
        return f


class DeepSVGProject:
    def __init__(self, name="Title"):
        self.name = name
        self.uid = str(uuid.uuid4())

        self.frames = [Frame(index=0)]

    @property
    def filename(self):
        return os.path.join(ROOT_DIR, f"{self.uid}.json")

    @property
    def base_dir(self):
        base_dir = os.path.join(ROOT_DIR, self.uid)

        if not os.path.exists(base_dir):
            os.makedirs(base_dir)

        return base_dir

    @property
    def cache_dir(self):
        cache_dir = os.path.join(self.base_dir, "cache")

        if not os.path.exists(cache_dir):
            os.makedirs(cache_dir)

        return cache_dir

    def load_project(self, file_path):
        with open(file_path, "r") as f:
            data = json.load(f)

        self.name = data["name"]
        self.uid = data["uid"]

        self.load_frames(data["frames"])

        shutil.rmtree(self.cache_dir)

    def load_frames(self, frames):
        self.frames = [Frame.load_dict(frame) for frame in frames]

        for frame in self.frames:
            frame.svg = SVG.load_svg(os.path.join(self.base_dir, f"{frame.index}.svg"))

    def save_project(self):
        with open(self.filename, "w") as f:
            data = {
                "name": self.name,
                "uid": self.uid,

                "frames": [frame.to_dict() for frame in self.frames]
            }

            json.dump(data, f)

        self.save_frames()

    def save_frames(self):
        for frame in self.frames:
            frame.svg.save_svg(os.path.join(self.base_dir, f"{frame.index}.svg"))

    def export_to_gif(self, frame_duration=0.1, loop_mode=0):
        from .state import LoopMode

        imgs = [frame.svg.copy().normalize().draw(do_display=False, return_png=True) for frame in self.frames]

        if loop_mode == LoopMode.REVERSE:
            imgs = imgs[::-1]
        elif loop_mode == LoopMode.PINGPONG:
            imgs = imgs + imgs[::-1]

        clips = [ImageClip(np.array(img)).set_duration(frame_duration) for img in imgs]

        clip = concatenate_videoclips(clips, method="compose", bg_color=(255, 255, 255))

        file_path = os.path.join(ROOT_DIR, f"{self.uid}.gif")
        clip.write_gif(file_path, fps=24, verbose=False, logger=None)
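
A minimal round-trip sketch of the project API above, assuming `ROOT_DIR` in `deepsvg/gui/config.py` points to a writable directory and moviepy is installed:

from deepsvg.gui.state.project import DeepSVGProject, Frame
from deepsvg.gui.state.state import LoopMode

project = DeepSVGProject(name="demo")
project.frames.append(Frame(index=1, keyframe=True))

project.save_project()                  # writes <uid>.json plus one <index>.svg per frame

restored = DeepSVGProject()
restored.load_project(project.filename)

restored.export_to_gif(loop_mode=LoopMode.PINGPONG)   # <uid>.gif, frames played forward then backward
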
deepsvg/gui/state/state.py
ADDED
@@ -0,0 +1,78 @@
from .project import DeepSVGProject
from ..config import STATE_PATH
import pickle
import os


class ToolMode:
    MOVE = 0
    PEN = 1
    PENCIL = 2
    PLAY = 3


class DrawMode:
    STILL = 0
    DRAW = 1
    HOLDING_DOWN = 2


class LoopMode:
    NORMAL = 0
    REVERSE = 1
    PINGPONG = 2


class PlaybackMode:
    NORMAL = 0
    EASE = 1


class LoopOrientation:
    FORWARD = 1
    BACKWARD = -1


class State:
    def __init__(self):
        self.project_file = None
        self.project = DeepSVGProject()

        self.loop_mode = LoopMode.PINGPONG
        self.loop_orientation = LoopOrientation.FORWARD
        self.playback_mode = PlaybackMode.EASE

        self.delay = 1 / 10.

        self.modified = False

        # Keep track of the previously selected current_frame, separately from the timeline's selected_frame attribute
        self.current_frame = -1

        self.current_path = None
        self.draw_mode = DrawMode.STILL

        self.clipboard = None

        # UI references
        self.main_widget = None
        self.header = None
        self.sidebar = None
        self.draw_viewbox = None
        self.timeline = None

    def save_state(self):
        with open(STATE_PATH, "wb") as f:
            state_dict = {k: v for k, v in self.__dict__.items() if k in ["project_file"]}
            pickle.dump(state_dict, f)

    def load_state(self):
        if os.path.exists(STATE_PATH):
            with open(STATE_PATH, "rb") as f:
                self.__dict__.update(pickle.load(f))

    def load_project(self):
        if self.project_file is not None:
            self.project.load_project(self.project_file)
        else:
            self.project_file = self.project.filename
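
Note that `save_state` deliberately pickles only `project_file`; everything else (tool, playback mode, clipboard) is session-local. A short sketch of the restart flow, assuming `STATE_PATH` in `deepsvg/gui/config.py` is configured:

from deepsvg.gui.state.state import State

state = State()
state.load_state()      # restores project_file from STATE_PATH, if the file exists
state.load_project()    # loads that project, or registers the fresh project's filename

state.save_state()      # persists only {"project_file": ...}
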
deepsvg/gui/utils.py
ADDED
@@ -0,0 +1,66 @@
from deepsvg.svglib.svg import SVG
from deepsvg.svglib.svg_path import SVGPath
from deepsvg.svglib.geom import Bbox


color_dict = {
    "deepskyblue": [0., 0.69, 0.97],
    "lime": [0.02, 1., 0.01],
    "deeppink": [1., 0.07, 0.53],
    "gold": [1., 0.81, 0.01],
    "coral": [1., 0.45, 0.27],
    "darkviolet": [0.53, 0.01, 0.8],
    "royalblue": [0.21, 0.36, 0.86],
    "darkmagenta": [0.5, 0., 0.5],
    "teal": [0., 0.45, 0.45],
    "green": [0., 0.45, 0.],
    "maroon": [0.45, 0., 0.],
    "aqua": [0., 1., 1.],
    "grey": [0.45, 0.45, 0.45],
    "steelblue": [0.24, 0.46, 0.67],
    "orange": [1., 0.6, 0.01]
}

colors = ["deepskyblue", "lime", "deeppink", "gold", "coral", "darkviolet", "royalblue", "darkmagenta", "teal",
          "gold", "green", "maroon", "aqua", "grey", "steelblue", "lime", "orange"]


class Keys:
    LEFT = 276
    UP = 273
    RIGHT = 275
    DOWN = 274

    SPACEBAR = 32


def dist(a, b):
    return ((a[0] - b[0]) ** 2 + (a[1] - b[1]) ** 2) ** .5


def preprocess_svg_path(svg_path: SVGPath, force_smooth=False):
    svg = SVG([svg_path.to_group()], viewbox=Bbox(256)).normalize()
    svg.canonicalize()
    svg.filter_duplicates()
    svg = svg.simplify_heuristic(force_smooth=force_smooth)
    svg.normalize()
    svg.numericalize(256)

    return svg[0].path


def normalized_path(svg_path):
    svg = SVG([svg_path.copy().to_group()], viewbox=Bbox(256)).normalize()
    return svg[0].path


def flip_vertical(p):
    return [p[0], 255 - p[1]]


def easein_easeout(t):
    return t * t / (2. * (t * t - t) + 1.)


def d_easein_easeout(t):
    return 3 * (1 - t) * t / (2 * t * t - 2 * t + 1) ** 2
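
`easein_easeout` is the smoothstep-like rational curve used for eased playback, and `d_easein_easeout` is its derivative; a quick numerical check of the helpers above:

from deepsvg.gui.utils import dist, easein_easeout, d_easein_easeout

print(dist((0, 0), (3, 4)))       # 5.0
print(easein_easeout(0.0))        # 0.0 (starts at rest)
print(easein_easeout(0.5))        # 0.5 (symmetric around the midpoint)
print(easein_easeout(1.0))        # 1.0
print(d_easein_easeout(0.0))      # 0.0 (zero velocity at both endpoints)
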
deepsvg/model/basic_blocks.py
ADDED
@@ -0,0 +1,101 @@
import torch
import torch.nn as nn


class FCN(nn.Module):
    def __init__(self, d_model, n_commands, n_args, args_dim=256, abs_targets=False):
        super().__init__()

        self.n_args = n_args
        self.args_dim = args_dim
        self.abs_targets = abs_targets

        self.command_fcn = nn.Linear(d_model, n_commands)

        if abs_targets:
            self.args_fcn = nn.Linear(d_model, n_args)
        else:
            self.args_fcn = nn.Linear(d_model, n_args * args_dim)

    def forward(self, out):
        S, N, _ = out.shape

        command_logits = self.command_fcn(out)  # Shape [S, N, n_commands]
        args_logits = self.args_fcn(out)  # Shape [S, N, n_args * args_dim]

        if not self.abs_targets:
            args_logits = args_logits.reshape(S, N, self.n_args, self.args_dim)  # Shape [S, N, n_args, args_dim]

        return command_logits, args_logits


class ArgumentFCN(nn.Module):
    def __init__(self, d_model, n_args, args_dim=256, abs_targets=False):
        super().__init__()

        self.n_args = n_args
        self.args_dim = args_dim
        self.abs_targets = abs_targets

        # classification -> regression
        if abs_targets:
            self.args_fcn = nn.Sequential(
                nn.Linear(d_model, n_args * args_dim),
                nn.Linear(n_args * args_dim, n_args)
            )
        else:
            self.args_fcn = nn.Linear(d_model, n_args * args_dim)

    def forward(self, out):
        S, N, _ = out.shape

        args_logits = self.args_fcn(out)  # Shape [S, N, n_args * args_dim]

        if not self.abs_targets:
            args_logits = args_logits.reshape(S, N, self.n_args, self.args_dim)  # Shape [S, N, n_args, args_dim]

        return args_logits


class HierarchFCN(nn.Module):
    def __init__(self, d_model, dim_z):
        super().__init__()

        # self.visibility_fcn = nn.Linear(d_model, 2)
        # self.z_fcn = nn.Linear(d_model, dim_z)
        self.visibility_fcn = nn.Linear(dim_z, 2)
        self.z_fcn = nn.Linear(dim_z, dim_z)

    def forward(self, out):
        G, N, _ = out.shape

        visibility_logits = self.visibility_fcn(out)  # Shape [G, N, 2]
        z = self.z_fcn(out)  # Shape [G, N, dim_z]

        return visibility_logits.unsqueeze(0), z.unsqueeze(0)


class ResNet(nn.Module):
    def __init__(self, d_model):
        super().__init__()

        self.linear1 = nn.Sequential(
            nn.Linear(d_model, d_model), nn.ReLU()
        )
        self.linear2 = nn.Sequential(
            nn.Linear(d_model, d_model), nn.ReLU()
        )
        self.linear3 = nn.Sequential(
            nn.Linear(d_model, d_model), nn.ReLU()
        )
        self.linear4 = nn.Sequential(
            nn.Linear(d_model, d_model), nn.ReLU()
        )

    def forward(self, z):
        z = z + self.linear1(z)
        z = z + self.linear2(z)
        z = z + self.linear3(z)
        z = z + self.linear4(z)

        return z
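
A shape sanity check for `FCN` under its default classification head (`abs_targets=False`); the dimensions below are illustrative, not the model defaults:

import torch
from deepsvg.model.basic_blocks import FCN

S, N = 30, 2                              # sequence length, batch size
fcn = FCN(d_model=256, n_commands=7, n_args=11)

out = torch.randn(S, N, 256)              # Transformer decoder output
command_logits, args_logits = fcn(out)

print(command_logits.shape)               # torch.Size([30, 2, 7])
print(args_logits.shape)                  # torch.Size([30, 2, 11, 256]), one 256-way bin per argument
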
deepsvg/model/config.py
ADDED
@@ -0,0 +1,107 @@
from deepsvg.difflib.tensor import SVGTensor


class _DefaultConfig:
    """
    Model config.
    """
    def __init__(self):
        self.args_dim = 256              # Coordinate numericalization, default: 256 (8-bit)
        self.n_args = 11                 # Tensor nb of arguments, default: 11 (rx,ry,phi,fA,fS,qx1,qy1,qx2,qy2,x1,x2)
        self.n_commands = len(SVGTensor.COMMANDS_SIMPLIFIED)  # m, l, c, a, EOS, SOS, z

        self.dropout = 0.1               # Dropout rate used in basic layers and Transformers

        self.model_type = "transformer"  # "transformer" ("lstm" implementation is work in progress)

        self.encode_stages = 1           # One-stage or two-stage: 1 | 2
        self.decode_stages = 1           # One-stage or two-stage: 1 | 2

        self.use_resnet = True           # Use extra fully-connected residual blocks after Encoder

        self.use_vae = True              # Sample latent vector (with reparametrization trick) or use encodings directly

        self.pred_mode = "one_shot"      # Feed-forward (one-shot) or autoregressive: "one_shot" | "autoregressive"
        self.rel_targets = False         # Predict coordinates in relative or absolute format

        self.label_condition = False     # Make all blocks conditional on the label
        self.n_labels = 100              # Number of labels (when used)
        self.dim_label = 64              # Label embedding dimensionality

        self.self_match = False          # Use Hungarian (self-match) or Ordered assignment

        self.n_layers = 4                # Number of Encoder blocks
        self.n_layers_decode = 4         # Number of Decoder blocks
        self.n_heads = 8                 # Transformer config: number of heads
        self.dim_feedforward = 512       # Transformer config: FF dimensionality
        self.d_model = 256               # Transformer config: model dimensionality

        self.dim_z = 256                 # Latent vector dimensionality

        self.max_num_groups = 8          # Number of paths (N_P)
        self.max_seq_len = 30            # Number of commands (N_C)
        self.max_total_len = self.max_num_groups * self.max_seq_len  # Concatenated sequence length for baselines

        self.num_groups_proposal = self.max_num_groups  # Number of predicted paths, default: N_P

    def get_model_args(self):
        model_args = []

        model_args += ["commands_grouped", "args_grouped"] if self.encode_stages <= 1 else ["commands", "args"]

        if self.rel_targets:
            model_args += ["commands_grouped", "args_rel_grouped"] if self.decode_stages == 1 else ["commands", "args_rel"]
        else:
            model_args += ["commands_grouped", "args_grouped"] if self.decode_stages == 1 else ["commands", "args"]

        if self.label_condition:
            model_args.append("label")

        return model_args


class SketchRNN(_DefaultConfig):
    # LSTM - Autoregressive - One-stage
    def __init__(self):
        super().__init__()

        self.model_type = "lstm"

        self.pred_mode = "autoregressive"
        self.rel_targets = True


class Sketchformer(_DefaultConfig):
    # Transformer - Autoregressive - One-stage
    def __init__(self):
        super().__init__()

        self.pred_mode = "autoregressive"
        self.rel_targets = True


class OneStageOneShot(_DefaultConfig):
    # Transformer - One-shot - One-stage
    def __init__(self):
        super().__init__()

        self.encode_stages = 1
        self.decode_stages = 1


class Hierarchical(_DefaultConfig):
    # Transformer - One-shot - Two-stage - Ordered
    def __init__(self):
        super().__init__()

        self.encode_stages = 2
        self.decode_stages = 2


class HierarchicalSelfMatching(_DefaultConfig):
    # Transformer - One-shot - Two-stage - Hungarian
    def __init__(self):
        super().__init__()
        self.encode_stages = 2
        self.decode_stages = 2
        self.self_match = True
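
For instance, the two-stage `Hierarchical` config resolves its model arguments as follows (a quick check, assuming the deepsvg package is on the path):

from deepsvg.model.config import Hierarchical

cfg = Hierarchical()
print(cfg.encode_stages, cfg.decode_stages)   # 2 2
print(cfg.get_model_args())
# ['commands', 'args', 'commands', 'args']  (encoder inputs, then decoder targets)
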
deepsvg/model/layers/__init__.py
ADDED
File without changes
deepsvg/model/layers/attention.py
ADDED
@@ -0,0 +1,161 @@
import torch
from torch.nn import Linear
from torch.nn.init import xavier_uniform_
from torch.nn.init import constant_
from torch.nn.init import xavier_normal_
from torch.nn.parameter import Parameter
from torch.nn.modules.module import Module

from .functional import multi_head_attention_forward


class MultiheadAttention(Module):
    r"""Allows the model to jointly attend to information
    from different representation subspaces.
    See reference: Attention Is All You Need

    .. math::
        \text{MultiHead}(Q, K, V) = \text{Concat}(head_1, \dots, head_h) W^O
        \quad \text{where} \quad head_i = \text{Attention}(Q W_i^Q, K W_i^K, V W_i^V)

    Args:
        embed_dim: total dimension of the model.
        num_heads: parallel attention heads.
        dropout: a Dropout layer on attn_output_weights. Default: 0.0.
        bias: add bias as module parameter. Default: True.
        add_bias_kv: add bias to the key and value sequences at dim=0.
        add_zero_attn: add a new batch of zeros to the key and
                       value sequences at dim=1.
        kdim: total number of features in key. Default: None.
        vdim: total number of features in value. Default: None.

    Note: if kdim and vdim are None, they will be set to embed_dim such that
    query, key, and value have the same number of features.

    Examples::

        >>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
        >>> attn_output, attn_output_weights = multihead_attn(query, key, value)
    """
    __annotations__ = {
        'bias_k': torch._jit_internal.Optional[torch.Tensor],
        'bias_v': torch._jit_internal.Optional[torch.Tensor],
    }
    __constants__ = ['q_proj_weight', 'k_proj_weight', 'v_proj_weight', 'in_proj_weight']

    def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None):
        super(MultiheadAttention, self).__init__()
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim

        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"

        if self._qkv_same_embed_dim is False:
            self.q_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
            self.k_proj_weight = Parameter(torch.Tensor(embed_dim, self.kdim))
            self.v_proj_weight = Parameter(torch.Tensor(embed_dim, self.vdim))
            self.register_parameter('in_proj_weight', None)
        else:
            self.in_proj_weight = Parameter(torch.empty(3 * embed_dim, embed_dim))
            self.register_parameter('q_proj_weight', None)
            self.register_parameter('k_proj_weight', None)
            self.register_parameter('v_proj_weight', None)

        if bias:
            self.in_proj_bias = Parameter(torch.empty(3 * embed_dim))
        else:
            self.register_parameter('in_proj_bias', None)
        self.out_proj = Linear(embed_dim, embed_dim, bias=bias)

        if add_bias_kv:
            self.bias_k = Parameter(torch.empty(1, 1, embed_dim))
            self.bias_v = Parameter(torch.empty(1, 1, embed_dim))
        else:
            self.bias_k = self.bias_v = None

        self.add_zero_attn = add_zero_attn

        self._reset_parameters()

    def _reset_parameters(self):
        if self._qkv_same_embed_dim:
            xavier_uniform_(self.in_proj_weight)
        else:
            xavier_uniform_(self.q_proj_weight)
            xavier_uniform_(self.k_proj_weight)
            xavier_uniform_(self.v_proj_weight)

        if self.in_proj_bias is not None:
            constant_(self.in_proj_bias, 0.)
            constant_(self.out_proj.bias, 0.)
        if self.bias_k is not None:
            xavier_normal_(self.bias_k)
        if self.bias_v is not None:
            xavier_normal_(self.bias_v)

    def __setstate__(self, state):
        # Support loading old MultiheadAttention checkpoints generated by v1.1.0
        if '_qkv_same_embed_dim' not in state:
            state['_qkv_same_embed_dim'] = True

        super(MultiheadAttention, self).__setstate__(state)

    def forward(self, query, key, value, key_padding_mask=None,
                need_weights=True, attn_mask=None):
        # type: (Tensor, Tensor, Tensor, Optional[Tensor], bool, Optional[Tensor]) -> Tuple[Tensor, Optional[Tensor]]
        r"""
        Args:
            query, key, value: map a query and a set of key-value pairs to an output.
                See "Attention Is All You Need" for more details.
            key_padding_mask: if provided, specified padding elements in the key will
                be ignored by the attention. This is a binary mask. When the value is True,
                the corresponding value on the attention layer will be filled with -inf.
            need_weights: output attn_output_weights.
            attn_mask: 2D or 3D mask that prevents attention to certain positions. This is an additive mask
                (i.e. the values will be added to the attention layer). A 2D mask will be broadcasted for all
                the batches while a 3D mask allows to specify a different mask for the entries of each batch.

        Shape:
            - Inputs:
            - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
              the embedding dimension.
            - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
              the embedding dimension.
            - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
              the embedding dimension.
            - key_padding_mask: :math:`(N, S)`, ByteTensor, where N is the batch size, S is the source sequence length.
            - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
              3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
              S is the source sequence length.

            - Outputs:
            - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
              E is the embedding dimension.
            - attn_output_weights: :math:`(N, L, S)` where N is the batch size,
              L is the target sequence length, S is the source sequence length.
        """
        if not self._qkv_same_embed_dim:
            return multi_head_attention_forward(
                query, key, value, self.embed_dim, self.num_heads,
                self.in_proj_weight, self.in_proj_bias,
                self.bias_k, self.bias_v, self.add_zero_attn,
                self.dropout, self.out_proj.weight, self.out_proj.bias,
                training=self.training,
                key_padding_mask=key_padding_mask, need_weights=need_weights,
                attn_mask=attn_mask, use_separate_proj_weight=True,
                q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight,
                v_proj_weight=self.v_proj_weight)
        else:
            return multi_head_attention_forward(
                query, key, value, self.embed_dim, self.num_heads,
                self.in_proj_weight, self.in_proj_bias,
                self.bias_k, self.bias_v, self.add_zero_attn,
                self.dropout, self.out_proj.weight, self.out_proj.bias,
                training=self.training,
                key_padding_mask=key_padding_mask, need_weights=need_weights,
                attn_mask=attn_mask)
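
This mirrors `torch.nn.MultiheadAttention` but routes through the local `multi_head_attention_forward`; a minimal usage sketch with the (L, N, E) layout from the docstring:

import torch
from deepsvg.model.layers.attention import MultiheadAttention

embed_dim, num_heads = 256, 8
attn = MultiheadAttention(embed_dim, num_heads)

L, S, N = 10, 20, 4                        # target len, source len, batch size
query = torch.randn(L, N, embed_dim)
key = torch.randn(S, N, embed_dim)
value = torch.randn(S, N, embed_dim)

attn_output, attn_weights = attn(query, key, value)
print(attn_output.shape)                   # torch.Size([10, 4, 256])
print(attn_weights.shape)                  # torch.Size([4, 10, 20]), averaged over heads
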
deepsvg/model/layers/functional.py
ADDED
@@ -0,0 +1,256 @@
from __future__ import division


import torch
import torch.nn.functional as F


def multi_head_attention_forward(query,                            # type: Tensor
                                 key,                              # type: Tensor
                                 value,                            # type: Tensor
                                 embed_dim_to_check,               # type: int
                                 num_heads,                        # type: int
                                 in_proj_weight,                   # type: Tensor
                                 in_proj_bias,                     # type: Tensor
                                 bias_k,                           # type: Optional[Tensor]
                                 bias_v,                           # type: Optional[Tensor]
                                 add_zero_attn,                    # type: bool
                                 dropout_p,                        # type: float
                                 out_proj_weight,                  # type: Tensor
                                 out_proj_bias,                    # type: Tensor
                                 training=True,                    # type: bool
                                 key_padding_mask=None,            # type: Optional[Tensor]
                                 need_weights=True,                # type: bool
                                 attn_mask=None,                   # type: Optional[Tensor]
                                 use_separate_proj_weight=False,   # type: bool
                                 q_proj_weight=None,               # type: Optional[Tensor]
                                 k_proj_weight=None,               # type: Optional[Tensor]
                                 v_proj_weight=None,               # type: Optional[Tensor]
                                 static_k=None,                    # type: Optional[Tensor]
                                 static_v=None                     # type: Optional[Tensor]
                                 ):
    # type: (...) -> Tuple[Tensor, Optional[Tensor]]
    r"""
    Args:
        query, key, value: map a query and a set of key-value pairs to an output.
            See "Attention Is All You Need" for more details.
        embed_dim_to_check: total dimension of the model.
        num_heads: parallel attention heads.
        in_proj_weight, in_proj_bias: input projection weight and bias.
        bias_k, bias_v: bias of the key and value sequences to be added at dim=0.
        add_zero_attn: add a new batch of zeros to the key and
                       value sequences at dim=1.
        dropout_p: probability of an element to be zeroed.
        out_proj_weight, out_proj_bias: the output projection weight and bias.
        training: apply dropout if is ``True``.
        key_padding_mask: if provided, specified padding elements in the key will
            be ignored by the attention. This is a binary mask. When the value is True,
            the corresponding value on the attention layer will be filled with -inf.
        need_weights: output attn_output_weights.
        attn_mask: 2D or 3D mask that prevents attention to certain positions. This is an additive mask
            (i.e. the values will be added to the attention layer). A 2D mask will be broadcasted for all
            the batches while a 3D mask allows to specify a different mask for the entries of each batch.
        use_separate_proj_weight: the function accepts the proj. weights for query, key,
            and value in different forms. If false, in_proj_weight will be used, which is
            a combination of q_proj_weight, k_proj_weight, v_proj_weight.
        q_proj_weight, k_proj_weight, v_proj_weight, in_proj_bias: input projection weight and bias.
        static_k, static_v: static key and value used for attention operators.
    Shape:
        Inputs:
        - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
          the embedding dimension.
        - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
          the embedding dimension.
        - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
          the embedding dimension.
        - key_padding_mask: :math:`(N, S)`, ByteTensor, where N is the batch size, S is the source sequence length.
        - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
          3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
          S is the source sequence length.
        - static_k: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
          N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.
        - static_v: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
          N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.
        Outputs:
        - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
          E is the embedding dimension.
        - attn_output_weights: :math:`(N, L, S)` where N is the batch size,
          L is the target sequence length, S is the source sequence length.
    """

    tgt_len, bsz, embed_dim = query.size()
    assert embed_dim == embed_dim_to_check
    assert key.size() == value.size()

    head_dim = embed_dim // num_heads
    assert head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads"
    scaling = float(head_dim) ** -0.5

    if not use_separate_proj_weight:
        if torch.equal(query, key) and torch.equal(key, value):
            # self-attention
            q, k, v = F.linear(query, in_proj_weight, in_proj_bias).chunk(3, dim=-1)

        elif torch.equal(key, value):
            # encoder-decoder attention
            # This is inline in_proj function with in_proj_weight and in_proj_bias
            _b = in_proj_bias
            _start = 0
            _end = embed_dim
            _w = in_proj_weight[_start:_end, :]
            if _b is not None:
                _b = _b[_start:_end]
            q = F.linear(query, _w, _b)

            if key is None:
                assert value is None
                k = None
                v = None
            else:

                # This is inline in_proj function with in_proj_weight and in_proj_bias
                _b = in_proj_bias
                _start = embed_dim
                _end = None
                _w = in_proj_weight[_start:, :]
                if _b is not None:
                    _b = _b[_start:]
                k, v = F.linear(key, _w, _b).chunk(2, dim=-1)

        else:
            # This is inline in_proj function with in_proj_weight and in_proj_bias
            _b = in_proj_bias
            _start = 0
            _end = embed_dim
            _w = in_proj_weight[_start:_end, :]
            if _b is not None:
                _b = _b[_start:_end]
            q = F.linear(query, _w, _b)

            # This is inline in_proj function with in_proj_weight and in_proj_bias
            _b = in_proj_bias
            _start = embed_dim
            _end = embed_dim * 2
            _w = in_proj_weight[_start:_end, :]
            if _b is not None:
                _b = _b[_start:_end]
            k = F.linear(key, _w, _b)

            # This is inline in_proj function with in_proj_weight and in_proj_bias
            _b = in_proj_bias
            _start = embed_dim * 2
            _end = None
            _w = in_proj_weight[_start:, :]
            if _b is not None:
                _b = _b[_start:]
            v = F.linear(value, _w, _b)
    else:
        q_proj_weight_non_opt = torch.jit._unwrap_optional(q_proj_weight)
        len1, len2 = q_proj_weight_non_opt.size()
        assert len1 == embed_dim and len2 == query.size(-1)

        k_proj_weight_non_opt = torch.jit._unwrap_optional(k_proj_weight)
        len1, len2 = k_proj_weight_non_opt.size()
        assert len1 == embed_dim and len2 == key.size(-1)

        v_proj_weight_non_opt = torch.jit._unwrap_optional(v_proj_weight)
        len1, len2 = v_proj_weight_non_opt.size()
        assert len1 == embed_dim and len2 == value.size(-1)

        if in_proj_bias is not None:
            q = F.linear(query, q_proj_weight_non_opt, in_proj_bias[0:embed_dim])
            k = F.linear(key, k_proj_weight_non_opt, in_proj_bias[embed_dim:(embed_dim * 2)])
            v = F.linear(value, v_proj_weight_non_opt, in_proj_bias[(embed_dim * 2):])
        else:
            q = F.linear(query, q_proj_weight_non_opt, in_proj_bias)
            k = F.linear(key, k_proj_weight_non_opt, in_proj_bias)
            v = F.linear(value, v_proj_weight_non_opt, in_proj_bias)
    q = q * scaling

    if attn_mask is not None:
        if attn_mask.dim() == 2:
            attn_mask = attn_mask.unsqueeze(0)
            if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
                raise RuntimeError('The size of the 2D attn_mask is not correct.')
        elif attn_mask.dim() == 3:
            if list(attn_mask.size()) != [bsz * num_heads, query.size(0), key.size(0)]:
                raise RuntimeError('The size of the 3D attn_mask is not correct.')
        else:
            raise RuntimeError("attn_mask's dimension {} is not supported".format(attn_mask.dim()))
        # attn_mask's dim is 3 now.

    if bias_k is not None and bias_v is not None:
        if static_k is None and static_v is None:
            k = torch.cat([k, bias_k.repeat(1, bsz, 1)])
            v = torch.cat([v, bias_v.repeat(1, bsz, 1)])
            if attn_mask is not None:
                attn_mask = F.pad(attn_mask, (0, 1))
            if key_padding_mask is not None:
                key_padding_mask = F.pad(key_padding_mask, (0, 1))
        else:
            assert static_k is None, "bias cannot be added to static key."
            assert static_v is None, "bias cannot be added to static value."
    else:
        assert bias_k is None
        assert bias_v is None

    q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
    if k is not None:
        k = k.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
    if v is not None:
        v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)

    if static_k is not None:
        assert static_k.size(0) == bsz * num_heads
        assert static_k.size(2) == head_dim
        k = static_k

    if static_v is not None:
        assert static_v.size(0) == bsz * num_heads
        assert static_v.size(2) == head_dim
        v = static_v

    src_len = k.size(1)

    if key_padding_mask is not None:
        assert key_padding_mask.size(0) == bsz
        assert key_padding_mask.size(1) == src_len

    if add_zero_attn:
        src_len += 1
        k = torch.cat([k, torch.zeros((k.size(0), 1) + k.size()[2:], dtype=k.dtype, device=k.device)], dim=1)
        v = torch.cat([v, torch.zeros((v.size(0), 1) + v.size()[2:], dtype=v.dtype, device=v.device)], dim=1)
        if attn_mask is not None:
            attn_mask = F.pad(attn_mask, (0, 1))
        if key_padding_mask is not None:
            key_padding_mask = F.pad(key_padding_mask, (0, 1))

    attn_output_weights = torch.bmm(q, k.transpose(1, 2))
    assert list(attn_output_weights.size()) == [bsz * num_heads, tgt_len, src_len]

    if attn_mask is not None:
        attn_output_weights += attn_mask

    if key_padding_mask is not None:
        attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
        attn_output_weights = attn_output_weights.masked_fill(
            key_padding_mask.unsqueeze(1).unsqueeze(2),
            float('-inf'),
        )
        attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, src_len)

    attn_output_weights = F.softmax(attn_output_weights, dim=-1)
    attn_output_weights = F.dropout(attn_output_weights, p=dropout_p, training=training)

    attn_output = torch.bmm(attn_output_weights, v)
    assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
    attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
    attn_output = F.linear(attn_output, out_proj_weight, out_proj_bias)

    if need_weights:
        # average attention weights over heads
        attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
        return attn_output, attn_output_weights.sum(dim=1) / num_heads
    else:
        return attn_output, None
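
The functional form can also be called directly with explicit projection weights; a sketch with randomly initialized (untrained) parameters, just to confirm the shapes:

import torch
from deepsvg.model.layers.functional import multi_head_attention_forward

L, S, N, E, H = 4, 6, 2, 16, 4             # target len, source len, batch, embed dim, heads
query, key, value = torch.randn(L, N, E), torch.randn(S, N, E), torch.randn(S, N, E)

in_proj_weight = torch.randn(3 * E, E)     # stacked q/k/v projections
in_proj_bias = torch.zeros(3 * E)
out_proj_weight, out_proj_bias = torch.randn(E, E), torch.zeros(E)

attn_output, attn_weights = multi_head_attention_forward(
    query, key, value, E, H,
    in_proj_weight, in_proj_bias,
    None, None, False,                     # bias_k, bias_v, add_zero_attn
    0.0, out_proj_weight, out_proj_bias,   # dropout_p, output projection
    training=False)

print(attn_output.shape)                   # torch.Size([4, 2, 16])
print(attn_weights.shape)                  # torch.Size([2, 4, 6])
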
deepsvg/model/layers/improved_transformer.py
ADDED
@@ -0,0 +1,141 @@
import torch
import copy

from torch.nn import functional as F
from torch.nn.modules.module import Module
from torch.nn.modules.container import ModuleList
from torch.nn.init import xavier_uniform_
from torch.nn.modules.dropout import Dropout
from torch.nn.modules.linear import Linear
from torch.nn.modules.normalization import LayerNorm

from .attention import MultiheadAttention
from .transformer import _get_activation_fn


class TransformerEncoderLayerImproved(Module):
    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu", d_global2=None):
        super(TransformerEncoderLayerImproved, self).__init__()
        self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout)

        if d_global2 is not None:
            self.linear_global2 = Linear(d_global2, d_model)

        # Implementation of Feedforward model
        self.linear1 = Linear(d_model, dim_feedforward)
        self.dropout = Dropout(dropout)
        self.linear2 = Linear(dim_feedforward, d_model)

        self.norm1 = LayerNorm(d_model)
        self.norm2 = LayerNorm(d_model)
        self.dropout1 = Dropout(dropout)
        self.dropout2_2 = Dropout(dropout)
        self.dropout2 = Dropout(dropout)

        self.activation = _get_activation_fn(activation)

    def __setstate__(self, state):
        if 'activation' not in state:
            state['activation'] = F.relu
        super(TransformerEncoderLayerImproved, self).__setstate__(state)

    def forward(self, src, memory2=None, src_mask=None, src_key_padding_mask=None):
        src1 = self.norm1(src)
        src2 = self.self_attn(src1, src1, src1, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0]
        src = src + self.dropout1(src2)

        if memory2 is not None:
            src2_2 = self.linear_global2(memory2)
            src = src + self.dropout2_2(src2_2)

        src1 = self.norm2(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src1))))
        src = src + self.dropout2(src2)
        return src


class TransformerDecoderLayerImproved(Module):
    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu"):
        super(TransformerDecoderLayerImproved, self).__init__()
        self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout)
        self.multihead_attn = MultiheadAttention(d_model, nhead, dropout=dropout)
        # Implementation of Feedforward model
        self.linear1 = Linear(d_model, dim_feedforward)
        self.dropout = Dropout(dropout)
        self.linear2 = Linear(dim_feedforward, d_model)

        self.norm1 = LayerNorm(d_model)
        self.norm2 = LayerNorm(d_model)
        self.norm3 = LayerNorm(d_model)
        self.dropout1 = Dropout(dropout)
        self.dropout2 = Dropout(dropout)
        self.dropout3 = Dropout(dropout)

        self.activation = _get_activation_fn(activation)

    def __setstate__(self, state):
        if 'activation' not in state:
            state['activation'] = F.relu
        super(TransformerDecoderLayerImproved, self).__setstate__(state)

    def forward(self, tgt, memory, tgt_mask=None, memory_mask=None,
                tgt_key_padding_mask=None, memory_key_padding_mask=None):
        tgt1 = self.norm1(tgt)
        tgt2 = self.self_attn(tgt1, tgt1, tgt1, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask)[0]
        tgt = tgt + self.dropout1(tgt2)

        tgt1 = self.norm2(tgt)
        tgt2 = self.multihead_attn(tgt1, memory, memory, attn_mask=memory_mask, key_padding_mask=memory_key_padding_mask)[0]
        tgt = tgt + self.dropout2(tgt2)

        tgt1 = self.norm3(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt1))))
        tgt = tgt + self.dropout3(tgt2)
        return tgt


class TransformerDecoderLayerGlobalImproved(Module):
    def __init__(self, d_model, d_global, nhead, dim_feedforward=2048, dropout=0.1, activation="relu", d_global2=None):
        super(TransformerDecoderLayerGlobalImproved, self).__init__()
        self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout)

        self.linear_global = Linear(d_global, d_model)

        if d_global2 is not None:
            self.linear_global2 = Linear(d_global2, d_model)

        # Implementation of Feedforward model
        self.linear1 = Linear(d_model, dim_feedforward)
        self.dropout = Dropout(dropout)
        self.linear2 = Linear(dim_feedforward, d_model)

        self.norm1 = LayerNorm(d_model)
        self.norm2 = LayerNorm(d_model)
        self.dropout1 = Dropout(dropout)
        self.dropout2 = Dropout(dropout)
        self.dropout2_2 = Dropout(dropout)
        self.dropout3 = Dropout(dropout)

        self.activation = _get_activation_fn(activation)

    def __setstate__(self, state):
        if 'activation' not in state:
            state['activation'] = F.relu
def __setstate__(self, state):
|
122 |
+
if 'activation' not in state:
|
123 |
+
state['activation'] = F.relu
|
124 |
+
super(TransformerDecoderLayerGlobalImproved, self).__setstate__(state)
|
125 |
+
|
126 |
+
def forward(self, tgt, memory, memory2=None, tgt_mask=None, tgt_key_padding_mask=None, *args, **kwargs):
|
127 |
+
tgt1 = self.norm1(tgt)
|
128 |
+
tgt2 = self.self_attn(tgt1, tgt1, tgt1, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask)[0]
|
129 |
+
tgt = tgt + self.dropout1(tgt2)
|
130 |
+
|
131 |
+
tgt2 = self.linear_global(memory)
|
132 |
+
tgt = tgt + self.dropout2(tgt2) # implicit broadcast
|
133 |
+
|
134 |
+
if memory2 is not None:
|
135 |
+
tgt2_2 = self.linear_global2(memory2)
|
136 |
+
tgt = tgt + self.dropout2_2(tgt2_2)
|
137 |
+
|
138 |
+
tgt1 = self.norm2(tgt)
|
139 |
+
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt1))))
|
140 |
+
tgt = tgt + self.dropout3(tgt2)
|
141 |
+
return tgt
|
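Unlike the stock layers in transformer.py below (post-norm), these Improved variants normalize before attention (pre-norm) and can inject a global conditioning vector through linear_global. A quick smoke test of the globally conditioned decoder layer, assuming the custom MultiheadAttention in .attention keeps the stock PyTorch call signature (which the forward above relies on):

import torch
from deepsvg.model.layers.improved_transformer import TransformerDecoderLayerGlobalImproved

layer = TransformerDecoderLayerGlobalImproved(d_model=256, d_global=64, nhead=8, dim_feedforward=512)

tgt = torch.randn(31, 4, 256)   # (seq_len, batch, d_model)
z = torch.randn(1, 4, 64)       # one global latent per sample
out = layer(tgt, z)             # z is projected by linear_global and broadcast-added
print(out.shape)                # torch.Size([31, 4, 256])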
deepsvg/model/layers/positional_encoding.py
ADDED
@@ -0,0 +1,43 @@
import math
import torch
import torch.nn as nn


class PositionalEncodingSinCos(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=250):
        super(PositionalEncodingSinCos, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)


class PositionalEncodingLUT(nn.Module):

    def __init__(self, d_model, dropout=0.1, max_len=250):
        super(PositionalEncodingLUT, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(0, max_len, dtype=torch.long).unsqueeze(1)
        self.register_buffer('position', position)

        self.pos_embed = nn.Embedding(max_len, d_model)

        self._init_embeddings()

    def _init_embeddings(self):
        nn.init.kaiming_normal_(self.pos_embed.weight, mode="fan_in")

    def forward(self, x):
        pos = self.position[:x.size(0)]
        x = x + self.pos_embed(pos)
        return self.dropout(x)
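Both modules add a position-dependent vector to a sequence-first tensor; SinCos uses fixed sinusoids, while the LUT variant learns an embedding table initialized with Kaiming normal. A shape check:

import torch
from deepsvg.model.layers.positional_encoding import PositionalEncodingSinCos, PositionalEncodingLUT

x = torch.zeros(30, 4, 256)   # (seq_len, batch, d_model), sequence-first

pe_fixed = PositionalEncodingSinCos(d_model=256, dropout=0.0)
pe_learned = PositionalEncodingLUT(d_model=256, dropout=0.0)

print(pe_fixed(x).shape)      # torch.Size([30, 4, 256])
print(pe_learned(x).shape)    # torch.Size([30, 4, 256])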
deepsvg/model/layers/transformer.py
ADDED
@@ -0,0 +1,393 @@
import torch
import copy

from torch.nn import functional as F
from torch.nn.modules.module import Module
from torch.nn.modules.container import ModuleList
from torch.nn.init import xavier_uniform_
from torch.nn.modules.dropout import Dropout
from torch.nn.modules.linear import Linear
from torch.nn.modules.normalization import LayerNorm

from .attention import MultiheadAttention


class Transformer(Module):
    r"""A transformer model. User is able to modify the attributes as needed. The architecture
    is based on the paper "Attention Is All You Need". Ashish Vaswani, Noam Shazeer,
    Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Lukasz Kaiser, and
    Illia Polosukhin. 2017. Attention is all you need. In Advances in Neural Information
    Processing Systems, pages 6000-6010. Users can build the BERT(https://arxiv.org/abs/1810.04805)
    model with corresponding parameters.

    Args:
        d_model: the number of expected features in the encoder/decoder inputs (default=512).
        nhead: the number of heads in the multiheadattention models (default=8).
        num_encoder_layers: the number of sub-encoder-layers in the encoder (default=6).
        num_decoder_layers: the number of sub-decoder-layers in the decoder (default=6).
        dim_feedforward: the dimension of the feedforward network model (default=2048).
        dropout: the dropout value (default=0.1).
        activation: the activation function of encoder/decoder intermediate layer, relu or gelu (default=relu).
        custom_encoder: custom encoder (default=None).
        custom_decoder: custom decoder (default=None).

    Examples::
        >>> transformer_model = nn.Transformer(nhead=16, num_encoder_layers=12)
        >>> src = torch.rand((10, 32, 512))
        >>> tgt = torch.rand((20, 32, 512))
        >>> out = transformer_model(src, tgt)

    Note: A full example to apply nn.Transformer module for the word language model is available in
    https://github.com/pytorch/examples/tree/master/word_language_model
    """

    def __init__(self, d_model=512, nhead=8, num_encoder_layers=6,
                 num_decoder_layers=6, dim_feedforward=2048, dropout=0.1,
                 activation="relu", custom_encoder=None, custom_decoder=None):
        super(Transformer, self).__init__()

        if custom_encoder is not None:
            self.encoder = custom_encoder
        else:
            encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout, activation)
            encoder_norm = LayerNorm(d_model)
            self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)

        if custom_decoder is not None:
            self.decoder = custom_decoder
        else:
            decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward, dropout, activation)
            decoder_norm = LayerNorm(d_model)
            self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm)

        self._reset_parameters()

        self.d_model = d_model
        self.nhead = nhead

    def forward(self, src, tgt, src_mask=None, tgt_mask=None,
                memory_mask=None, src_key_padding_mask=None,
                tgt_key_padding_mask=None, memory_key_padding_mask=None):
        # type: (Tensor, Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor], Optional[Tensor], Optional[Tensor], Optional[Tensor]) -> Tensor  # noqa
        r"""Take in and process masked source/target sequences.

        Args:
            src: the sequence to the encoder (required).
            tgt: the sequence to the decoder (required).
            src_mask: the additive mask for the src sequence (optional).
            tgt_mask: the additive mask for the tgt sequence (optional).
            memory_mask: the additive mask for the encoder output (optional).
            src_key_padding_mask: the ByteTensor mask for src keys per batch (optional).
            tgt_key_padding_mask: the ByteTensor mask for tgt keys per batch (optional).
            memory_key_padding_mask: the ByteTensor mask for memory keys per batch (optional).

        Shape:
            - src: :math:`(S, N, E)`.
            - tgt: :math:`(T, N, E)`.
            - src_mask: :math:`(S, S)`.
            - tgt_mask: :math:`(T, T)`.
            - memory_mask: :math:`(T, S)`.
            - src_key_padding_mask: :math:`(N, S)`.
            - tgt_key_padding_mask: :math:`(N, T)`.
            - memory_key_padding_mask: :math:`(N, S)`.

            Note: [src/tgt/memory]_mask should be filled with
            float('-inf') for the masked positions and float(0.0) else. These masks
            ensure that predictions for position i depend only on the unmasked positions
            j and are applied identically for each sequence in a batch.
            [src/tgt/memory]_key_padding_mask should be a ByteTensor where True values are positions
            that should be masked with float('-inf') and False values will be unchanged.
            This mask ensures that no information will be taken from position i if
            it is masked, and has a separate mask for each sequence in a batch.

            - output: :math:`(T, N, E)`.

            Note: Due to the multi-head attention architecture in the transformer model,
            the output sequence length of a transformer is the same as the input sequence
            (i.e. target) length of the decoder.

            where S is the source sequence length, T is the target sequence length, N is the
            batch size, E is the feature number

        Examples:
            >>> output = transformer_model(src, tgt, src_mask=src_mask, tgt_mask=tgt_mask)
        """

        if src.size(1) != tgt.size(1):
            raise RuntimeError("the batch number of src and tgt must be equal")

        if src.size(2) != self.d_model or tgt.size(2) != self.d_model:
            raise RuntimeError("the feature number of src and tgt must be equal to d_model")

        memory = self.encoder(src, mask=src_mask, src_key_padding_mask=src_key_padding_mask)
        output = self.decoder(tgt, memory, tgt_mask=tgt_mask, memory_mask=memory_mask,
                              tgt_key_padding_mask=tgt_key_padding_mask,
                              memory_key_padding_mask=memory_key_padding_mask)
        return output

    def generate_square_subsequent_mask(self, sz):
        r"""Generate a square mask for the sequence. The masked positions are filled with float('-inf').
        Unmasked positions are filled with float(0.0).
        """
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask

    def _reset_parameters(self):
        r"""Initiate parameters in the transformer model."""

        for p in self.parameters():
            if p.dim() > 1:
                xavier_uniform_(p)


class TransformerEncoder(Module):
    r"""TransformerEncoder is a stack of N encoder layers

    Args:
        encoder_layer: an instance of the TransformerEncoderLayer() class (required).
        num_layers: the number of sub-encoder-layers in the encoder (required).
        norm: the layer normalization component (optional).

    Examples::
        >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)
        >>> transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=6)
        >>> src = torch.rand(10, 32, 512)
        >>> out = transformer_encoder(src)
    """
    __constants__ = ['norm']

    def __init__(self, encoder_layer, num_layers, norm=None):
        super(TransformerEncoder, self).__init__()
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(self, src, memory2=None, mask=None, src_key_padding_mask=None):
        # type: (Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor]) -> Tensor
        r"""Pass the input through the encoder layers in turn.

        Args:
            src: the sequence to the encoder (required).
            mask: the mask for the src sequence (optional).
            src_key_padding_mask: the mask for the src keys per batch (optional).

        Shape:
            see the docs in Transformer class.
        """
        output = src

        for mod in self.layers:
            output = mod(output, memory2=memory2, src_mask=mask, src_key_padding_mask=src_key_padding_mask)

        if self.norm is not None:
            output = self.norm(output)

        return output


class TransformerDecoder(Module):
    r"""TransformerDecoder is a stack of N decoder layers

    Args:
        decoder_layer: an instance of the TransformerDecoderLayer() class (required).
        num_layers: the number of sub-decoder-layers in the decoder (required).
        norm: the layer normalization component (optional).

    Examples::
        >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8)
        >>> transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=6)
        >>> memory = torch.rand(10, 32, 512)
        >>> tgt = torch.rand(20, 32, 512)
        >>> out = transformer_decoder(tgt, memory)
    """
    __constants__ = ['norm']

    def __init__(self, decoder_layer, num_layers, norm=None):
        super(TransformerDecoder, self).__init__()
        self.layers = _get_clones(decoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(self, tgt, memory, memory2=None, tgt_mask=None,
                memory_mask=None, tgt_key_padding_mask=None,
                memory_key_padding_mask=None):
        # type: (Tensor, Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor], Optional[Tensor], Optional[Tensor]) -> Tensor
        r"""Pass the inputs (and mask) through the decoder layer in turn.

        Args:
            tgt: the sequence to the decoder (required).
            memory: the sequence from the last layer of the encoder (required).
            tgt_mask: the mask for the tgt sequence (optional).
            memory_mask: the mask for the memory sequence (optional).
            tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
            memory_key_padding_mask: the mask for the memory keys per batch (optional).

        Shape:
            see the docs in Transformer class.
        """
        output = tgt

        for mod in self.layers:
            output = mod(output, memory, memory2=memory2, tgt_mask=tgt_mask,
                         memory_mask=memory_mask,
                         tgt_key_padding_mask=tgt_key_padding_mask,
                         memory_key_padding_mask=memory_key_padding_mask)

        if self.norm is not None:
            output = self.norm(output)

        return output


class TransformerEncoderLayer(Module):
    r"""TransformerEncoderLayer is made up of self-attn and feedforward network.
    This standard encoder layer is based on the paper "Attention Is All You Need".
    Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez,
    Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in
    Neural Information Processing Systems, pages 6000-6010. Users may modify or implement
    in a different way during application.

    Args:
        d_model: the number of expected features in the input (required).
        nhead: the number of heads in the multiheadattention models (required).
        dim_feedforward: the dimension of the feedforward network model (default=2048).
        dropout: the dropout value (default=0.1).
        activation: the activation function of intermediate layer, relu or gelu (default=relu).

    Examples::
        >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)
        >>> src = torch.rand(10, 32, 512)
        >>> out = encoder_layer(src)
    """

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu"):
        super(TransformerEncoderLayer, self).__init__()
        self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout)
        # Implementation of Feedforward model
        self.linear1 = Linear(d_model, dim_feedforward)
        self.dropout = Dropout(dropout)
        self.linear2 = Linear(dim_feedforward, d_model)

        self.norm1 = LayerNorm(d_model)
        self.norm2 = LayerNorm(d_model)
        self.dropout1 = Dropout(dropout)
        self.dropout2 = Dropout(dropout)

        self.activation = _get_activation_fn(activation)

    def __setstate__(self, state):
        if 'activation' not in state:
            state['activation'] = F.relu
        super(TransformerEncoderLayer, self).__setstate__(state)

    def forward(self, src, src_mask=None, src_key_padding_mask=None):
        # type: (Tensor, Optional[Tensor], Optional[Tensor]) -> Tensor
        r"""Pass the input through the encoder layer.

        Args:
            src: the sequence to the encoder layer (required).
            src_mask: the mask for the src sequence (optional).
            src_key_padding_mask: the mask for the src keys per batch (optional).

        Shape:
            see the docs in Transformer class.
        """
        src2 = self.self_attn(src, src, src, attn_mask=src_mask,
                              key_padding_mask=src_key_padding_mask)[0]
        src = src + self.dropout1(src2)
        src = self.norm1(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
        src = src + self.dropout2(src2)
        src = self.norm2(src)
        return src


class TransformerDecoderLayer(Module):
    r"""TransformerDecoderLayer is made up of self-attn, multi-head-attn and feedforward network.
    This standard decoder layer is based on the paper "Attention Is All You Need".
    Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez,
    Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in
    Neural Information Processing Systems, pages 6000-6010. Users may modify or implement
    in a different way during application.

    Args:
        d_model: the number of expected features in the input (required).
        nhead: the number of heads in the multiheadattention models (required).
        dim_feedforward: the dimension of the feedforward network model (default=2048).
        dropout: the dropout value (default=0.1).
        activation: the activation function of intermediate layer, relu or gelu (default=relu).

    Examples::
        >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8)
        >>> memory = torch.rand(10, 32, 512)
        >>> tgt = torch.rand(20, 32, 512)
        >>> out = decoder_layer(tgt, memory)
    """

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu"):
        super(TransformerDecoderLayer, self).__init__()
        self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout)
        self.multihead_attn = MultiheadAttention(d_model, nhead, dropout=dropout)
        # Implementation of Feedforward model
        self.linear1 = Linear(d_model, dim_feedforward)
        self.dropout = Dropout(dropout)
        self.linear2 = Linear(dim_feedforward, d_model)

        self.norm1 = LayerNorm(d_model)
        self.norm2 = LayerNorm(d_model)
        self.norm3 = LayerNorm(d_model)
        self.dropout1 = Dropout(dropout)
        self.dropout2 = Dropout(dropout)
        self.dropout3 = Dropout(dropout)

        self.activation = _get_activation_fn(activation)

    def __setstate__(self, state):
        if 'activation' not in state:
            state['activation'] = F.relu
        super(TransformerDecoderLayer, self).__setstate__(state)

    def forward(self, tgt, memory, tgt_mask=None, memory_mask=None,
                tgt_key_padding_mask=None, memory_key_padding_mask=None):
        # type: (Tensor, Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor], Optional[Tensor]) -> Tensor
        r"""Pass the inputs (and mask) through the decoder layer.

        Args:
            tgt: the sequence to the decoder layer (required).
            memory: the sequence from the last layer of the encoder (required).
            tgt_mask: the mask for the tgt sequence (optional).
            memory_mask: the mask for the memory sequence (optional).
            tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
            memory_key_padding_mask: the mask for the memory keys per batch (optional).

        Shape:
            see the docs in Transformer class.
        """
        tgt2 = self.self_attn(tgt, tgt, tgt, attn_mask=tgt_mask,
                              key_padding_mask=tgt_key_padding_mask)[0]
        tgt = tgt + self.dropout1(tgt2)
        tgt = self.norm1(tgt)
        tgt2 = self.multihead_attn(tgt, memory, memory, attn_mask=memory_mask,
                                   key_padding_mask=memory_key_padding_mask)[0]
        tgt = tgt + self.dropout2(tgt2)
        tgt = self.norm2(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
        tgt = tgt + self.dropout3(tgt2)
        tgt = self.norm3(tgt)
        return tgt


def _get_clones(module, N):
    return ModuleList([copy.deepcopy(module) for i in range(N)])


def _get_activation_fn(activation):
    if activation == "relu":
        return F.relu
    elif activation == "gelu":
        return F.gelu

    raise RuntimeError("activation should be relu/gelu, not {}".format(activation))
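The additive mask convention documented above can be checked in isolation: adding a causal mask to the raw scores zeroes out attention to future positions after the softmax.

import torch
import torch.nn.functional as F

sz = 4
mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, 0.0)

scores = torch.randn(sz, sz)
weights = F.softmax(scores + mask, dim=-1)
print(weights)   # entries above the diagonal (future positions) are exactly zero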
deepsvg/model/layers/utils.py
ADDED
@@ -0,0 +1,36 @@
import torch


def to_negative_mask(mask):
    if mask is None:
        return

    mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
    return mask


def generate_square_subsequent_mask(sz):
    mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
    mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
    return mask


def generate_adj_subsequent_mask(sz):
    mask = torch.diag(torch.ones(sz), diagonal=0) + torch.diag(torch.ones(sz-1), diagonal=-1)

    if sz >= 2:
        mask = mask + torch.diag(torch.ones(sz-2), diagonal=-2)

    return to_negative_mask(mask)


def generate_adj_mask(sz):
    mask = torch.diag(torch.ones(sz), diagonal=0) +\
           torch.diag(torch.ones(sz - 1), diagonal=+1) +\
           torch.diag(torch.ones(sz - 1), diagonal=-1)

    if sz >= 2:
        mask = mask + torch.diag(torch.ones(sz - 2), diagonal=-2) +\
               torch.diag(torch.ones(sz - 2), diagonal=+2)

    return to_negative_mask(mask)
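generate_adj_subsequent_mask lets each position attend only to itself and the two preceding positions; generate_adj_mask is its symmetric counterpart (a band of width two around the diagonal). Printing small instances makes the band structure visible:

import torch
from deepsvg.model.layers.utils import generate_adj_subsequent_mask, generate_adj_mask

print(generate_adj_subsequent_mask(5))
# 0.0 on the main diagonal and the two diagonals below it, -inf everywhere else

print(generate_adj_mask(5))
# 0.0 on a symmetric band of width two around the diagonal, -inf elsewhere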
deepsvg/model/loss.py
ADDED
@@ -0,0 +1,104 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from deepsvg.difflib.tensor import SVGTensor
from .utils import _get_padding_mask, _get_visibility_mask
from .config import _DefaultConfig


class SVGLoss(nn.Module):
    def __init__(self, cfg: _DefaultConfig):
        super().__init__()

        self.cfg = cfg

        self.args_dim = 2 * cfg.args_dim if cfg.rel_targets else cfg.args_dim + 1

        self.register_buffer("cmd_args_mask", SVGTensor.CMD_ARGS_MASK)

    def forward(self, output, labels, weights):
        loss = 0.
        res = {}

        # VAE
        if self.cfg.use_vae:
            mu, logsigma = output["mu"], output["logsigma"]
            loss_kl = -0.5 * torch.mean(1 + logsigma - mu.pow(2) - torch.exp(logsigma))
            loss_kl = loss_kl.clamp(min=weights["kl_tolerance"])

            loss += weights["loss_kl_weight"] * loss_kl
            res["loss_kl"] = loss_kl

        # remove commitment loss
        # if self.cfg.use_vqvae:
        #     vqvae_loss = output["vqvae_loss"].mean()
        #     loss += vqvae_loss
        #     res["vqvae_loss"] = vqvae_loss

        # Target & predictions
        # tgt_commands.shape [batch_size, max_num_groups, max_seq_len + 2]
        # tgt_args.shape     [batch_size, max_num_groups, max_seq_len + 2, n_args]
        tgt_commands, tgt_args = output["tgt_commands"], output["tgt_args"]

        visibility_mask = _get_visibility_mask(tgt_commands, seq_dim=-1)
        padding_mask = _get_padding_mask(tgt_commands, seq_dim=-1, extended=True) * visibility_mask.unsqueeze(-1)

        command_logits, args_logits = output["command_logits"], output["args_logits"]

        # 2-stage visibility
        if self.cfg.decode_stages == 2:
            visibility_logits = output["visibility_logits"]
            loss_visibility = F.cross_entropy(visibility_logits.reshape(-1, 2), visibility_mask.reshape(-1).long())

            loss += weights["loss_visibility_weight"] * loss_visibility
            res["loss_visibility"] = loss_visibility

        # Commands & args
        if self.cfg.bin_targets:  # with bin_targets, each coordinate is represented by 8 bits, hence one extra dimension
            tgt_args = tgt_args[..., 1:, :, :]
        else:
            tgt_args = tgt_args[..., 1:, :]
        tgt_commands, padding_mask = tgt_commands[..., 1:], padding_mask[..., 1:]

        # mask.shape [batch_size, 8, 31, 11]
        # For a correctly predicted command, mask is multiplied by True, so the cmd_args_mask vector stays unchanged.
        # For a wrongly predicted command, mask is multiplied by False, zeroing cmd_args_mask, i.e. the corresponding args are not counted.
        # pred_cmd = torch.argmax(command_logits, dim=-1)
        # mask = self.cmd_args_mask[tgt_commands.long()] * (pred_cmd == tgt_commands).unsqueeze(-1)

        mask = self.cmd_args_mask[tgt_commands.long()]

        # padding_mask.shape   [batch_size, num_path, num_commands + 1]
        # command_logits.shape [batch_size, num_path, num_commands + 1, n_commands]
        # command_logits[padding_mask.bool()].shape [-1, n_commands]
        # The goal is to filter out the PAD positions.
        loss_cmd = F.cross_entropy(command_logits[padding_mask.bool()].reshape(-1, self.cfg.n_commands), tgt_commands[padding_mask.bool()].reshape(-1).long())

        if self.cfg.abs_targets:
            # l2 loss performs better than l1 loss
            loss_args = nn.MSELoss()(
                args_logits[mask.bool()].reshape(-1),
                tgt_args[mask.bool()].reshape(-1).float()
            )
        elif self.cfg.bin_targets:
            loss_args = nn.MSELoss()(
                args_logits[mask.bool()].reshape(-1),
                tgt_args[mask.bool()].reshape(-1).float()
            )
        else:
            loss_args = F.cross_entropy(
                args_logits[mask.bool()].reshape(-1, self.args_dim),
                tgt_args[mask.bool()].reshape(-1).long() + 1
            )  # shift due to -1 PAD_VAL

        loss += weights["loss_cmd_weight"] * loss_cmd \
                + weights["loss_args_weight"] * loss_args

        res.update({
            "loss": loss,
            "loss_cmd": loss_cmd,
            "loss_args": loss_args
        })

        return res
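The gather self.cmd_args_mask[tgt_commands.long()] turns the per-command-type argument-validity table into a per-position mask, so only the arguments a command actually uses contribute to loss_args. A toy sketch of the same indexing (the 3-command, 4-argument table below is made up for illustration):

import torch

# hypothetical validity table: command type -> which of 4 args it uses
cmd_args_mask = torch.tensor([[1., 1., 0., 0.],   # e.g. a "move" uses args 0-1
                              [1., 1., 1., 1.],   # a "curve" uses all four
                              [0., 0., 0., 0.]])  # EOS/PAD uses none

tgt_commands = torch.tensor([[0, 1, 2]])          # (batch, seq)
mask = cmd_args_mask[tgt_commands]                # (batch, seq, 4)
print(mask.shape)  # torch.Size([1, 3, 4])

per_arg_loss = torch.rand(1, 3, 4)
loss_args = per_arg_loss[mask.bool()].mean()      # only valid (command, arg) slots counted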
deepsvg/model/model.py
ADDED
@@ -0,0 +1,690 @@
import torch
import torch.nn as nn
import torch.nn.functional as F  # explicit imports; previously pulled in only via the star imports below

from deepsvg.difflib.tensor import SVGTensor
from deepsvg.utils.utils import _pack_group_batch, _unpack_group_batch, _make_seq_first, _make_batch_first, eval_decorator
from deepsvg.utils import bit2int

from .layers.transformer import *
from .layers.improved_transformer import *
from .layers.positional_encoding import *
from .vector_quantize_pytorch import VectorQuantize
from .basic_blocks import FCN, HierarchFCN, ResNet, ArgumentFCN
from .config import _DefaultConfig
from .utils import (_get_padding_mask, _get_key_padding_mask, _get_group_mask, _get_visibility_mask,
                    _get_key_visibility_mask, _generate_square_subsequent_mask, _sample_categorical, _threshold_sample)

from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence
from scipy.optimize import linear_sum_assignment
from einops import rearrange

from random import randint


class SVGEmbedding(nn.Module):
    def __init__(self, cfg: _DefaultConfig, seq_len, use_group=True, group_len=None):
        super().__init__()

        self.cfg = cfg

        # command embedding
        self.command_embed = nn.Embedding(cfg.n_commands, cfg.d_model)  # (7, 256)
        self.embed_fcn = nn.Linear(cfg.n_args, cfg.d_model)

        self.use_group = use_group
        if use_group:
            if group_len is None:
                group_len = cfg.max_num_groups
            self.group_embed = nn.Embedding(group_len+2, cfg.d_model)

        self.pos_encoding = PositionalEncodingLUT(cfg.d_model, max_len=seq_len+2, dropout=cfg.dropout)

        self.register_buffer("cmd_args_mask", SVGTensor.CMD_ARGS_MASK)

        self._init_embeddings()

    def _init_embeddings(self):
        nn.init.kaiming_normal_(self.command_embed.weight, mode="fan_in")
        nn.init.kaiming_normal_(self.embed_fcn.weight, mode="fan_in")

        # if not self.cfg.bin_targets:
        #     nn.init.kaiming_normal_(self.arg_embed.weight, mode="fan_in")

        if self.use_group:
            nn.init.kaiming_normal_(self.group_embed.weight, mode="fan_in")

    def forward(self, commands, args, groups=None):
        # commands.shape (32, 960) = (max_seq_len + 2, max_num_groups * batch_size)
        S, GN = commands.shape

        src = self.command_embed(commands.long()) + self.embed_fcn(args)

        if self.use_group:
            src = src + self.group_embed(groups.long())

        src = self.pos_encoding(src)
        return src

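SVGEmbedding sums three signals, a command-type embedding, a linear projection of the (already numeric) argument vector, and optionally a group embedding, before adding the positional LUT. The shape arithmetic, with sizes taken from the comment above:

import torch
import torch.nn as nn

S, GN, n_args, d_model, n_commands = 32, 960, 11, 256, 7

commands = torch.randint(0, n_commands, (S, GN))
args = torch.randn(S, GN, n_args)

command_embed = nn.Embedding(n_commands, d_model)
embed_fcn = nn.Linear(n_args, d_model)

src = command_embed(commands) + embed_fcn(args)
print(src.shape)  # torch.Size([32, 960, 256]) -- one d_model vector per token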
class ConstEmbedding(nn.Module):
    def __init__(self, cfg: _DefaultConfig, seq_len):
        super().__init__()

        self.cfg = cfg

        self.seq_len = seq_len

        self.PE = PositionalEncodingLUT(cfg.d_model, max_len=seq_len, dropout=cfg.dropout)

    def forward(self, z):
        N = z.size(1)
        src = self.PE(z.new_zeros(self.seq_len, N, self.cfg.d_model))
        return src


class LabelEmbedding(nn.Module):
    def __init__(self, cfg: _DefaultConfig):
        super().__init__()

        self.label_embedding = nn.Embedding(cfg.n_labels, cfg.dim_label)

        self._init_embeddings()

    def _init_embeddings(self):
        nn.init.kaiming_normal_(self.label_embedding.weight, mode="fan_in")

    def forward(self, label):
        src = self.label_embedding(label)
        return src


class Encoder(nn.Module):
    def __init__(self, cfg: _DefaultConfig):
        super().__init__()

        self.cfg = cfg

        seq_len = cfg.max_seq_len if cfg.encode_stages == 2 else cfg.max_total_len
        self.use_group = cfg.encode_stages == 1
        self.embedding = SVGEmbedding(cfg, seq_len, use_group=self.use_group)

        if cfg.label_condition:
            self.label_embedding = LabelEmbedding(cfg)
        dim_label = cfg.dim_label if cfg.label_condition else None

        if cfg.model_type == "transformer":
            encoder_layer = TransformerEncoderLayerImproved(cfg.d_model, cfg.n_heads, cfg.dim_feedforward, cfg.dropout, d_global2=dim_label)
            encoder_norm = LayerNorm(cfg.d_model)
            self.encoder = TransformerEncoder(encoder_layer, cfg.n_layers, encoder_norm)

        else:  # "lstm"
            self.encoder = nn.LSTM(cfg.d_model, cfg.d_model // 2, dropout=cfg.dropout, bidirectional=True)

        if cfg.encode_stages == 2:
            if not cfg.self_match:
                self.hierarchical_PE = PositionalEncodingLUT(cfg.d_model, max_len=cfg.max_num_groups)

            # hierarchical_encoder_layer = TransformerEncoderLayerImproved(cfg.d_model, cfg.n_heads, cfg.dim_feedforward, cfg.dropout, d_global2=dim_label)
            # hierarchical_encoder_norm = LayerNorm(cfg.d_model)
            # self.hierarchical_encoder = TransformerEncoder(hierarchical_encoder_layer, cfg.n_layers, hierarchical_encoder_norm)

    def forward(self, commands, args, label=None):
        # inputs are sequence-first here (see _make_seq_first in SVGTransformer.forward):
        # commands.shape: [max_seq_len + 2, max_num_groups, batch_size]
        # args.shape:     [max_seq_len + 2, max_num_groups, batch_size, n_args]
        S, G, N = commands.shape
        l = self.label_embedding(label).unsqueeze(0).unsqueeze(0).repeat(1, commands.size(1), 1, 1) if self.cfg.label_condition else None

        # if self.cfg.encode_stages == 2:
        #     visibility_mask, key_visibility_mask = _get_visibility_mask(commands, seq_dim=0), _get_key_visibility_mask(commands, seq_dim=0)

        commands, args, l = _pack_group_batch(commands, args, l)
        # commands.shape: [max_seq_len + 2, max_num_groups * batch_size]
        # key_padding_mask lets attention ignore the <PAD> positions
        padding_mask, key_padding_mask = _get_padding_mask(commands, seq_dim=0), _get_key_padding_mask(commands, seq_dim=0)
        group_mask = _get_group_mask(commands, seq_dim=0) if self.use_group else None

        # cmd_src, args_src = self.embedding(commands, args, group_mask)
        src = self.embedding(commands, args, group_mask)

        if self.cfg.model_type == "transformer":
            memory = self.encoder(src, mask=None, src_key_padding_mask=key_padding_mask, memory2=l)
            z = memory * padding_mask  # no averaging over the command sequence
        else:  # "lstm"
            hidden_cell = (src.new_zeros(2, N, self.cfg.d_model // 2),
                           src.new_zeros(2, N, self.cfg.d_model // 2))
            sequence_lengths = padding_mask.sum(dim=0).squeeze(-1)
            x = pack_padded_sequence(src, sequence_lengths, enforce_sorted=False)

            packed_output, _ = self.encoder(x, hidden_cell)

            memory, _ = pad_packed_sequence(packed_output)
            idx = (sequence_lengths - 1).long().view(1, -1, 1).repeat(1, 1, self.cfg.d_model)
            z = memory.gather(dim=0, index=idx)

        # cmd_z, args_z = _unpack_group_batch(N, cmd_z, args_z)
        z = _unpack_group_batch(N, z)

        # Why not use the encode_stages == 1 flag to implement a single encoder?
        # When encode_stages = 1, fetching the data involves a grouping step; we keep
        # the original code logic unchanged as far as possible.
        if self.cfg.one_encoder:
            return z.transpose(0, 1)

        if self.cfg.encode_stages == 2:
            assert False, 'not use E2'
            # src = z.transpose(0, 1)
            # src = _pack_group_batch(src)
            # l = self.label_embedding(label).unsqueeze(0) if self.cfg.label_condition else None

            # if not self.cfg.self_match:
            #     src = self.hierarchical_PE(src)

            # memory = self.hierarchical_encoder(src, mask=None, src_key_padding_mask=key_visibility_mask, memory2=l)

            # if self.cfg.quantize_path:
            #     z = (memory * visibility_mask)
            # else:
            #     z = (memory * visibility_mask).sum(dim=0, keepdim=True) / visibility_mask.sum(dim=0, keepdim=True)
            # z = _unpack_group_batch(N, z)

        return z

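_pack_group_batch (imported from deepsvg.utils.utils, not shown in this diff) folds the group axis into the batch axis so the per-path encoder treats each path as an independent sequence. A rearrange-based sketch of the equivalent reshape, under that assumption:

import torch
from einops import rearrange

S, G, N, d = 32, 8, 4, 256           # seq, groups, batch, width
x = torch.randn(S, G, N, d)

packed = rearrange(x, 's g n d -> s (g n) d')     # groups folded into the batch
unpacked = rearrange(packed, 's (g n) d -> s g n d', n=N)
assert torch.equal(x, unpacked)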
class VAE(nn.Module):
    def __init__(self, cfg: _DefaultConfig):
        super(VAE, self).__init__()

        self.enc_mu_fcn = nn.Linear(cfg.d_model, cfg.dim_z)
        self.enc_sigma_fcn = nn.Linear(cfg.d_model, cfg.dim_z)

        self._init_embeddings()

    def _init_embeddings(self):
        nn.init.normal_(self.enc_mu_fcn.weight, std=0.001)
        nn.init.constant_(self.enc_mu_fcn.bias, 0)
        nn.init.normal_(self.enc_sigma_fcn.weight, std=0.001)
        nn.init.constant_(self.enc_sigma_fcn.bias, 0)

    def forward(self, z):
        mu, logsigma = self.enc_mu_fcn(z), self.enc_sigma_fcn(z)
        sigma = torch.exp(logsigma / 2.)
        z = mu + sigma * torch.randn_like(sigma)

        return z, mu, logsigma

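The VAE head above uses the standard reparameterization trick: it predicts mu and log-variance and samples z = mu + sigma * eps with eps ~ N(0, I), which keeps the sample differentiable with respect to the encoder. In isolation:

import torch

mu = torch.zeros(4, 64, requires_grad=True)
logsigma = torch.zeros(4, 64, requires_grad=True)    # log-variance

sigma = torch.exp(logsigma / 2.)
z = mu + sigma * torch.randn_like(sigma)             # differentiable sample

z.sum().backward()                                   # gradients flow into mu and logsigma
print(mu.grad.shape, logsigma.grad.shape)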
class Bottleneck(nn.Module):
    def __init__(self, cfg: _DefaultConfig):
        super(Bottleneck, self).__init__()

        self.bottleneck = nn.Linear(cfg.d_model, cfg.dim_z)

    def forward(self, z):
        return self.bottleneck(z)


class Decoder(nn.Module):
    def __init__(self, cfg: _DefaultConfig):
        super(Decoder, self).__init__()

        self.cfg = cfg

        if cfg.label_condition:
            self.label_embedding = LabelEmbedding(cfg)
        dim_label = cfg.dim_label if cfg.label_condition else None

        if cfg.decode_stages == 2:
            # self.hierarchical_embedding = ConstEmbedding(cfg, cfg.num_groups_proposal)

            # hierarchical_decoder_layer = TransformerDecoderLayerGlobalImproved(cfg.d_model, cfg.dim_z, cfg.n_heads, cfg.dim_feedforward, cfg.dropout, d_global2=dim_label)
            # hierarchical_decoder_norm = LayerNorm(cfg.d_model)
            # self.hierarchical_decoder = TransformerDecoder(hierarchical_decoder_layer, cfg.n_layers_decode, hierarchical_decoder_norm)
            self.hierarchical_fcn = HierarchFCN(cfg.d_model, cfg.dim_z)

        if cfg.pred_mode == "autoregressive":
            # NOTE: the original call also passed rel_args=cfg.rel_targets, but SVGEmbedding
            # in this file takes no such argument, so the kwarg is dropped here.
            self.embedding = SVGEmbedding(cfg, cfg.max_total_len, use_group=True, group_len=cfg.max_total_len)

            square_subsequent_mask = _generate_square_subsequent_mask(self.cfg.max_total_len+1)
            self.register_buffer("square_subsequent_mask", square_subsequent_mask)
        else:  # "one_shot"
            seq_len = cfg.max_seq_len+1 if cfg.decode_stages == 2 else cfg.max_total_len+1
            self.embedding = ConstEmbedding(cfg, seq_len)
            if cfg.args_decoder:
                self.argument_embedding = ConstEmbedding(cfg, seq_len)

        if cfg.model_type == "transformer":
            decoder_layer = TransformerDecoderLayerGlobalImproved(cfg.d_model, cfg.dim_z, cfg.n_heads, cfg.dim_feedforward, cfg.dropout, d_global2=dim_label)
            decoder_norm = LayerNorm(cfg.d_model)
            self.decoder = TransformerDecoder(decoder_layer, cfg.n_layers_decode, decoder_norm)

        else:  # "lstm"
            self.fc_hc = nn.Linear(cfg.dim_z, 2 * cfg.d_model)
            self.decoder = nn.LSTM(cfg.d_model, cfg.d_model, dropout=cfg.dropout)

        if cfg.rel_targets:
            args_dim = 2 * cfg.args_dim
            if cfg.bin_targets:
                args_dim = 8
        else:
            args_dim = cfg.args_dim + 1

        self.fcn = FCN(cfg.d_model, cfg.n_commands, cfg.n_args, args_dim, cfg.abs_targets)

    def _get_initial_state(self, z):
        hidden, cell = torch.split(torch.tanh(self.fc_hc(z)), self.cfg.d_model, dim=2)
        hidden_cell = hidden.contiguous(), cell.contiguous()
        return hidden_cell

    def forward(self, z, commands, args, label=None, hierarch_logits=None, return_hierarch=False):
        N = z.size(2)
        l = self.label_embedding(label).unsqueeze(0) if self.cfg.label_condition else None
        if hierarch_logits is None:
            # z = _pack_group_batch(z)
            visibility_z = _pack_group_batch(torch.mean(z[:, 1:, ...], dim=1, keepdim=True))  # predicts visibility; the SOS position is removed

        if self.cfg.decode_stages == 2:
            if hierarch_logits is None:
                # src = self.hierarchical_embedding(z)
                # # print('D2 PE src', src.shape)
                # # print('D2 con z', z.shape)
                # out = self.hierarchical_decoder(src, z, tgt_mask=None, tgt_key_padding_mask=None, memory2=l)
                # # print('D2 out', out.shape)
                # hierarch_logits, _z = self.hierarchical_fcn(out)
                # # print('hierarch_logits origin', hierarch_logits.shape)

                # only a linear layer for visibility prediction
                hierarch_logits, _z = self.hierarchical_fcn(visibility_z)

            if self.cfg.label_condition: l = l.unsqueeze(0).repeat(1, z.size(1), 1, 1)

            hierarch_logits, l = _pack_group_batch(hierarch_logits, l)
            if not self.cfg.connect_through:
                z = _pack_group_batch(_z)

        if return_hierarch:
            return _unpack_group_batch(N, hierarch_logits, z)

        if self.cfg.pred_mode == "autoregressive":
            S = commands.size(0)
            commands, args = _pack_group_batch(commands, args)

            group_mask = _get_group_mask(commands, seq_dim=0)

            src = self.embedding(commands, args, group_mask)

            if self.cfg.model_type == "transformer":
                key_padding_mask = _get_key_padding_mask(commands, seq_dim=0)
                out = self.decoder(src, z, tgt_mask=self.square_subsequent_mask[:S, :S], tgt_key_padding_mask=key_padding_mask, memory2=l)
            else:  # "lstm"
                hidden_cell = self._get_initial_state(z)  # TODO: reinject intermediate state
                out, _ = self.decoder(src, hidden_cell)

        else:  # "one_shot"
            if self.cfg.connect_through:
                z = rearrange(z, 'p c b d -> c (p b) d')
                z = z[1:, ...]

            src = self.embedding(z)
            out = self.decoder(src, z, tgt_mask=None, tgt_key_padding_mask=None, memory2=l)
            # print('D1 out', out.shape)

        if self.cfg.args_decoder:
            # NOTE: this branch assumes command_fcn, argument_decoder and argument_fcn
            # are defined elsewhere; they are not created in __init__ above.
            command_logits = self.command_fcn(out)
            z = torch.argmax(command_logits, dim=-1).unsqueeze(-1).float()
            src = self.argument_embedding(z)
            # print('D0 PE src', src.shape)
            # print('D0 con z', z.shape)
            out = self.argument_decoder(src, z, tgt_mask=None, tgt_key_padding_mask=None, memory2=l)
            # print('D0 out', out.shape)
            args_logits = self.argument_fcn(out)
        else:
            # command_logits, args_logits = self.fcn(cmd_out, args_out)
            command_logits, args_logits = self.fcn(out)

        out_logits = (command_logits, args_logits) + ((hierarch_logits,) if self.cfg.decode_stages == 2 else ())

        return _unpack_group_batch(N, *out_logits)

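In the one_shot branch above, the decoder input is not a token sequence at all: ConstEmbedding produces a purely positional query sequence, and the latent enters only through the global conditioning path of TransformerDecoderLayerGlobalImproved. A shape-level sketch of that broadcast (hypothetical sizes):

import torch
import torch.nn as nn

seq_len, N, d_model, dim_z = 31, 4, 256, 64
z = torch.randn(1, N, dim_z)                     # a single latent vector per sample

pos_queries = torch.zeros(seq_len, N, d_model)   # ConstEmbedding output, before the positional LUT
linear_global = nn.Linear(dim_z, d_model)

# the latent is projected and broadcast-added at every decoding position
out = pos_queries + linear_global(z)
print(out.shape)  # torch.Size([31, 4, 256])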
345 |
+
class SVGTransformer(nn.Module):
|
346 |
+
def __init__(self, cfg: _DefaultConfig):
|
347 |
+
super(SVGTransformer, self).__init__()
|
348 |
+
|
349 |
+
self.cfg = cfg
|
350 |
+
# self.args_dim = 2 * cfg.args_dim if cfg.rel_targets else cfg.args_dim + 1 # 257
|
351 |
+
if cfg.rel_targets:
|
352 |
+
args_dim = 2 * cfg.args_dim
|
353 |
+
if cfg.bin_targets:
|
354 |
+
args_dim = 8
|
355 |
+
else:
|
356 |
+
args_dim = cfg.args_dim + 1
|
357 |
+
|
358 |
+
if self.cfg.encode_stages > 0:
|
359 |
+
|
360 |
+
self.encoder = Encoder(cfg)
|
361 |
+
|
362 |
+
if cfg.use_resnet:
|
363 |
+
self.resnet = ResNet(cfg.d_model)
|
364 |
+
|
365 |
+
if cfg.use_vae:
|
366 |
+
self.vae = VAE(cfg)
|
367 |
+
else:
|
368 |
+
self.bottleneck = Bottleneck(cfg)
|
369 |
+
# self.bottleneck2 = Bottleneck(cfg)
|
370 |
+
self.encoder_norm = LayerNorm(cfg.dim_z, elementwise_affine=False)
|
371 |
+
|
372 |
+
if cfg.use_vqvae:
|
373 |
+
self.vqvae = VectorQuantize(
|
374 |
+
dim = cfg.dim_z,
|
375 |
+
codebook_size = cfg.codebook_size,
|
376 |
+
decay = 0.8,
|
377 |
+
commitment_weight = 0.,
|
378 |
+
use_cosine_sim = cfg.use_cosine_sim,
|
379 |
+
)
|
380 |
+
|
381 |
+
self.decoder = Decoder(cfg)
|
382 |
+
|
383 |
+
# 定义 self.cmd_args_mask, 但是分配一块持久性缓冲区
|
384 |
+
self.register_buffer("cmd_args_mask", SVGTensor.CMD_ARGS_MASK)
|
385 |
+
|
386 |
+
def perfect_matching(self, command_logits, args_logits, hierarch_logits, tgt_commands, tgt_args):
|
387 |
+
with torch.no_grad():
|
388 |
+
N, G, S, n_args = tgt_args.shape
|
389 |
+
visibility_mask = _get_visibility_mask(tgt_commands, seq_dim=-1)
|
390 |
+
padding_mask = _get_padding_mask(tgt_commands, seq_dim=-1, extended=True) * visibility_mask.unsqueeze(-1)
|
391 |
+
|
392 |
+
# Unsqueeze
|
393 |
+
tgt_commands, tgt_args, tgt_hierarch = tgt_commands.unsqueeze(2), tgt_args.unsqueeze(2), visibility_mask.unsqueeze(2)
|
394 |
+
command_logits, args_logits, hierarch_logits = command_logits.unsqueeze(1), args_logits.unsqueeze(1), hierarch_logits.unsqueeze(1).squeeze(-2)
|
395 |
+
|
396 |
+
# Loss
|
397 |
+
tgt_hierarch, hierarch_logits = tgt_hierarch.repeat(1, 1, self.cfg.num_groups_proposal), hierarch_logits.repeat(1, G, 1, 1)
|
398 |
+
tgt_commands, command_logits = tgt_commands.repeat(1, 1, self.cfg.num_groups_proposal, 1), command_logits.repeat(1, G, 1, 1, 1)
|
399 |
+
tgt_args, args_logits = tgt_args.repeat(1, 1, self.cfg.num_groups_proposal, 1, 1), args_logits.repeat(1, G, 1, 1, 1, 1)
|
400 |
+
|
401 |
+
padding_mask, mask = padding_mask.unsqueeze(2).repeat(1, 1, self.cfg.num_groups_proposal, 1), self.cmd_args_mask[tgt_commands.long()]
|
402 |
+
|
403 |
+
loss_args = F.cross_entropy(args_logits.reshape(-1, self.args_dim), tgt_args.reshape(-1).long() + 1, reduction="none").reshape(N, G, self.cfg.num_groups_proposal, S, n_args) # shift due to -1 PAD_VAL
|
404 |
+
loss_cmd = F.cross_entropy(command_logits.reshape(-1, self.cfg.n_commands), tgt_commands.reshape(-1).long(), reduction="none").reshape(N, G, self.cfg.num_groups_proposal, S)
|
405 |
+
loss_hierarch = F.cross_entropy(hierarch_logits.reshape(-1, 2), tgt_hierarch.reshape(-1).long(), reduction="none").reshape(N, G, self.cfg.num_groups_proposal)
|
406 |
+
|
407 |
+
loss_args = (loss_args * mask).sum(dim=[-1, -2]) / mask.sum(dim=[-1, -2])
|
408 |
+
loss_cmd = (loss_cmd * padding_mask).sum(dim=-1) / padding_mask.sum(dim=-1)
|
409 |
+
|
410 |
+
loss = 2.0 * loss_args + 1.0 * loss_cmd + 1.0 * loss_hierarch
|
411 |
+
|
412 |
+
# Iterate over the batch-dimension
|
413 |
+
assignment_list = []
|
414 |
+
|
415 |
+
full_set = set(range(self.cfg.num_groups_proposal))
|
416 |
+
for i in range(N):
|
417 |
+
costs = loss[i]
|
418 |
+
mask = visibility_mask[i]
|
419 |
+
_, assign = linear_sum_assignment(costs[mask].cpu())
|
420 |
+
assign = assign.tolist()
|
421 |
+
assignment_list.append(assign + list(full_set - set(assign)))
|
422 |
+
|
423 |
+
assignment = torch.tensor(assignment_list, device=command_logits.device)
|
424 |
+
|
425 |
+
return assignment.unsqueeze(-1).unsqueeze(-1)
|
426 |
+
|
427 |
+
    @property
    def origin_empty_path(self):
        return torch.tensor([
            11, 16, 7, 23, 24, 10, 13, 5, 1, 8, 3, 3, 7, 15, 7, 18, 15, 31,
            21, 31, 16, 10, 2, 14, 26, 14, 6, 13, 7, 28, 11, 19, 9, 6, 7, 1,
            22, 31, 21, 4, 21, 6, 1, 4, 15, 13, 10, 19, 9, 13, 21, 29, 12, 13,
            10, 23, 15, 11, 1, 18, 19, 5, 23, 20, 7, 29, 13, 15, 22, 31, 17, 10,
            21, 28, 13, 20, 24, 30, 21, 28, 5, 22, 14, 15, 3, 7, 14, 1, 19, 23,
            30, 25, 26, 27, 11, 23, 8, 6, 3, 31, 28, 29, 11, 1, 3, 6, 4, 12,
            12, 25, 0, 18, 5, 26, 5, 12, 23, 14, 19, 25, 12, 20, 2, 3, 18, 11,
            1, 12
        ])

    # for dalle usage
    # indices = model.get_codebook_indices(*model_args)
    # commands_y, args_y = model.decode(indices)
    @torch.no_grad()
    @eval_decorator
    def get_codebook_indices(self, commands_enc, args_enc, commands_dec, args_dec):
        indices = self(commands_enc, args_enc, commands_dec, args_dec, return_indices=True)
        return indices

    @torch.no_grad()
    @eval_decorator
    def decode(self, codebook_indices):
        torch.set_printoptions(profile='full')
        print(codebook_indices.reshape(self.cfg.max_num_groups, self.cfg.max_seq_len + 2))
        z = self.vqvae.codebook[codebook_indices]  # shape [batch_size, num_of_indices, codebook_dim]
        # args_z = self.args_vqvae.codebook[codebook_indices]

        batch_size = z.shape[0]
        z = z.reshape(self.cfg.max_num_groups, -1, batch_size, self.cfg.dim_z)

        out_logits = self.decoder(z, None, None)
        out_logits = _make_batch_first(*out_logits)

        res = {
            "command_logits": out_logits[0],   # shape [batch_size, path_num, command_num + 1, 5]
            "args_logits": out_logits[1],      # shape [batch_size, path_num, command_num + 1, 6]
            "visibility_logits": out_logits[2]
        }

        # hack
        # commands_y, args_y, _ = self.greedy_sample(res=res, commands_dec=cmd_indices)
        commands_y, args_y, _ = self.greedy_sample(res=res)

        # visualization, but it is not responsible for decode()
        # tensor_pred = SVGTensor.from_cmd_args(commands_y[0].cpu(), args_y[0].cpu())
        # svg_path_sample = SVG.from_tensor(tensor_pred.data, viewbox=Bbox(256), allow_empty=True).normalize().zoom(1.5)
        # svg_path_sample.fill_(True)
        # svg_path_sample.save_svg('test.svg')

        return commands_y, args_y

    def forward(self, commands_enc, args_enc, commands_dec, args_dec, label=None,
                z=None, hierarch_logits=None,
                return_tgt=True, params=None, encode_mode=False, return_hierarch=False, return_indices=False):
        # commands_enc contains the command types
        # commands_enc.shape: [batch_size, max_num_groups, max_seq_len + 2]
        # args_enc.shape: [batch_size, max_num_groups, max_seq_len + 2, n_args]
        # commands_dec.shape: [batch_size, max_num_groups, max_seq_len + 2]
        # args_dec.shape: [batch_size, max_num_groups, max_seq_len + 2, n_args]
        # assert args_enc.equal(args_dec)
        commands_enc, args_enc = _make_seq_first(commands_enc, args_enc)  # Possibly None, None
        commands_dec_, args_dec_ = _make_seq_first(commands_dec, args_dec)
        # commands_enc.shape: [max_seq_len + 2, max_num_groups, batch_size]
        # args_enc.shape: [max_seq_len + 2, max_num_groups, batch_size, 11]

        if z is None:
            z = self.encoder(commands_enc, args_enc, label)
            # cmd_z, args_z = self.encoder(commands_enc, args_enc, label)
            # print('encoded z', z.shape)

            if self.cfg.use_resnet:
                z = self.resnet(z)

            if self.cfg.use_vae:
                z, mu, logsigma = self.vae(z)
            else:
                # z = self.bottleneck(z)
                z = self.encoder_norm(self.bottleneck(z))
                # cmd_z = self.encoder_norm(self.bottleneck(cmd_z))
                # args_z = self.encoder_norm(self.bottleneck2(args_z))
                # print('bottleneck z', z)
                # print('normed z', z, z.shape)

            if self.cfg.use_vqvae or self.cfg.use_rqvae:
                # initial z.shape [num_path, 1, batch_size, dim_z]
                # batch_size, max_num_groups = cmd_z.shape[2], cmd_z.shape[0]
                batch_size, max_num_groups = z.shape[2], z.shape[0]

                # print(z.shape)
                # z = z.reshape(batch_size, -1, self.cfg.dim_z)
                # z = z.reshape(max_num_groups, -1, self.cfg.dim_z)
                z = rearrange(z, 'p c b z -> b (p c) z')
                # cmd_z = cmd_z.reshape(batch_size, -1, self.cfg.dim_z)
                # args_z = args_z.reshape(batch_size, -1, self.cfg.dim_z)
                # print(z.shape)

                # z = rearrange(z, 'p 1 b d -> b 1 p d')  # p: num_of_path, b: batch_size, d: dim_z
                # z = self.conv_enc_layer(z)
                # z = rearrange(z, 'b c p d -> b (p d) c')  # b d c: batch_size, dim_z, num_channel

                if self.cfg.use_vqvae:
                    quantized, indices, commit_loss = self.vqvae(z)  # tokenization
                else:
                    quantized, indices, commit_loss = self.rqvae(z)

                if return_indices:
                    return indices

                # z = rearrange(quantized, 'b (p d) c -> b c p d', p = max_num_groups if self.cfg.quantize_path else 1)
                # z = self.conv_dec_layer(z)
                # z = rearrange(z, 'b 1 p d -> p 1 b d')
                # z = quantized.reshape(max_num_groups, -1, batch_size, self.cfg.dim_z)
                z = rearrange(quantized, 'b (p c) z -> p c b z', p = max_num_groups)

                # cmd_z = cmd_quantized.reshape(max_num_groups, -1, batch_size, self.cfg.dim_z)
                # args_z = args_quantized.reshape(max_num_groups, -1, batch_size, self.cfg.dim_z)
                # print(indices)
                # print('quantized z', z.shape)
        else:
            z = _make_seq_first(z)

        if encode_mode: return z

        if return_tgt:  # Train mode
            # remove EOS command
            # [max_seq_len + 1, max_num_groups, batch_size]
            commands_dec_, args_dec_ = commands_dec_[:-1], args_dec_[:-1]

        out_logits = self.decoder(z, commands_dec_, args_dec_, label, hierarch_logits=hierarch_logits,
                                  return_hierarch=return_hierarch)

        if return_hierarch:
            return out_logits

        out_logits = _make_batch_first(*out_logits)

        if return_tgt and self.cfg.self_match:  # Assignment
            assert self.cfg.decode_stages == 2  # Self-matching expects two-stage decoder
            command_logits, args_logits, hierarch_logits = out_logits

            assignment = self.perfect_matching(command_logits, args_logits, hierarch_logits, commands_dec[..., 1:], args_dec[..., 1:, :])

            command_logits = torch.gather(command_logits, dim=1, index=assignment.expand_as(command_logits))
            args_logits = torch.gather(args_logits, dim=1, index=assignment.unsqueeze(-1).expand_as(args_logits))
            hierarch_logits = torch.gather(hierarch_logits, dim=1, index=assignment.expand_as(hierarch_logits))

            out_logits = (command_logits, args_logits, hierarch_logits)

        res = {
            "command_logits": out_logits[0],
            "args_logits": out_logits[1]
        }

        if self.cfg.decode_stages == 2:
            res["visibility_logits"] = out_logits[2]

        if return_tgt:
            res["tgt_commands"] = commands_dec
            res["tgt_args"] = args_dec

            if self.cfg.use_vae:
                res["mu"] = _make_batch_first(mu)
                res["logsigma"] = _make_batch_first(logsigma)

            if self.cfg.use_vqvae:
                res["vqvae_loss"] = commit_loss
        return res
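
    # Shape sanity-check for the rearranges around the quantizer
    # (illustrative sizes, not the trained model's):
    #   z = torch.randn(8, 1, 4, 256)                    # [p, c, b, dim_z]
    #   flat = rearrange(z, 'p c b z -> b (p c) z')      # [4, 8, 256]: one token per path
    #   assert torch.equal(rearrange(flat, 'b (p c) z -> p c b z', p=8), z)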
    def greedy_sample(self, commands_enc=None, args_enc=None, commands_dec=None, args_dec=None, label=None,
                      z=None, hierarch_logits=None,
                      concat_groups=True, temperature=0.0001, res=None):
        if self.cfg.pred_mode == "one_shot":
            if res is None:
                res = self.forward(commands_enc, args_enc, commands_dec, args_dec, label=label, z=z, hierarch_logits=hierarch_logits, return_tgt=True)

            commands_y = _sample_categorical(temperature, res["command_logits"])
            # hack
            # commands_y = commands_dec.reshape(1, 8, 32)[..., 1:]
            if self.cfg.abs_targets:
                # in this case the args need no sampling
                # the model may output -1 directly, so args_y -= 1 is not needed
                # but SVG coordinates live in the range 0-255, so we still need to clamp and convert to integers by hand
                # positions that should hold "-1" get filtered out by the mask in _make_valid
                # args_y = torch.clamp(res['args_logits'], min=0, max=255).int()
                # args_y = torch.clamp(res['args_logits'], min=0, max=256)
                # args_y = (res['args_logits'] + 1) * 128 - 1
                args_y = (res['args_logits'] + 1) * 12
            elif self.cfg.bin_targets:
                # the args need no sampling here either
                # we need a threshold: logits < threshold map to 0, logits >= threshold map to 1
                threshold = 0.0
                args_logits = res['args_logits']
                args_y = torch.where(args_logits > threshold, torch.ones_like(args_logits), torch.zeros_like(args_logits))
                args_y = bit2int(args_y)
            else:
                args_y = _sample_categorical(temperature, res["args_logits"])
                args_y -= 1  # shift due to -1 PAD_VAL

            visibility_y = _threshold_sample(res["visibility_logits"], threshold=0.7).bool().squeeze(-1) if self.cfg.decode_stages == 2 else None
            commands_y, args_y = self._make_valid(commands_y, args_y, visibility_y)
        else:
            if z is None:
                z = self.forward(commands_enc, args_enc, None, None, label=label, encode_mode=True)

            PAD_VAL = 0
            commands_y, args_y = z.new_zeros(1, 1, 1).fill_(SVGTensor.COMMANDS_SIMPLIFIED.index("SOS")).long(), z.new_ones(1, 1, 1, self.cfg.n_args).fill_(PAD_VAL).long()

            for i in range(self.cfg.max_total_len):
                res = self.forward(None, None, commands_y, args_y, label=label, z=z, hierarch_logits=hierarch_logits, return_tgt=False)
                commands_new_y, args_new_y = _sample_categorical(temperature, res["command_logits"], res["args_logits"])
                args_new_y -= 1  # shift due to -1 PAD_VAL
                _, args_new_y = self._make_valid(commands_new_y, args_new_y)

                commands_y, args_y = torch.cat([commands_y, commands_new_y[..., -1:]], dim=-1), torch.cat([args_y, args_new_y[..., -1:, :]], dim=-2)

            commands_y, args_y = commands_y[..., 1:], args_y[..., 1:, :]  # Discard SOS token

        if self.cfg.rel_targets:
            args_y = self._make_absolute(commands_y, args_y)

        if concat_groups:
            N = commands_y.size(0)
            # must use commands_y here, not tgt_commands,
            # because commands_y may contain extra EOS tokens, and EOS cannot be visualized
            padding_mask_y = _get_padding_mask(commands_y, seq_dim=-1).bool()
            commands_y, args_y = commands_y[padding_mask_y].reshape(N, -1), args_y[padding_mask_y].reshape(N, -1, self.cfg.n_args)

        return commands_y, args_y, res

    def _make_valid(self, commands_y, args_y, visibility_y=None, PAD_VAL=0):
        if visibility_y is not None:
            S = commands_y.size(-1)
            commands_y[~visibility_y] = commands_y.new_tensor([SVGTensor.COMMANDS_SIMPLIFIED.index("m"), *[SVGTensor.COMMANDS_SIMPLIFIED.index("EOS")] * (S - 1)])
            args_y[~visibility_y] = PAD_VAL

        mask = self.cmd_args_mask[commands_y.long()].bool()
        args_y[~mask] = PAD_VAL

        return commands_y, args_y

    def _make_absolute(self, commands_y, args_y):

        mask = self.cmd_args_mask[commands_y.long()].bool()
        args_y[mask] -= self.cfg.args_dim - 1

        real_commands = commands_y < SVGTensor.COMMANDS_SIMPLIFIED.index("EOS")

        args_real_commands = args_y[real_commands]
        end_pos = args_real_commands[:-1, SVGTensor.IndexArgs.END_POS].cumsum(dim=0)

        args_real_commands[1:, SVGTensor.IndexArgs.CONTROL1] += end_pos
        args_real_commands[1:, SVGTensor.IndexArgs.CONTROL2] += end_pos
        args_real_commands[1:, SVGTensor.IndexArgs.END_POS] += end_pos

        args_y[real_commands] = args_real_commands

        _, args_y = self._make_valid(commands_y, args_y)

        return args_y
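
The commented "for dalle usage" recipe above amounts to the following round trip; a minimal sketch assuming `model` is a trained instance of this class and the four tensors come from the dataloader with the shapes documented in forward():

indices = model.get_codebook_indices(commands_enc, args_enc, commands_dec, args_dec)
commands_y, args_y = model.decode(indices)  # greedy-sampled commands and arguments
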
deepsvg/model/utils.py
ADDED
@@ -0,0 +1,84 @@
import torch
from deepsvg.difflib.tensor import SVGTensor
from torch.distributions.categorical import Categorical
import torch.nn.functional as F


def _get_key_padding_mask(commands, seq_dim=0):
    """
    Args:
        commands: Shape [S, ...]
    """
    with torch.no_grad():
        key_padding_mask = (commands == SVGTensor.COMMANDS_SIMPLIFIED.index("EOS")).cumsum(dim=seq_dim) > 0

        if seq_dim == 0:
            return key_padding_mask.transpose(0, 1)
        return key_padding_mask


def _get_padding_mask(commands, seq_dim=0, extended=False):
    with torch.no_grad():
        padding_mask = (commands == SVGTensor.COMMANDS_SIMPLIFIED.index("EOS")).cumsum(dim=seq_dim) == 0
        padding_mask = padding_mask.float()

        if extended:
            # padding_mask doesn't include the final EOS, extend by 1 position to include it in the loss
            S = commands.size(seq_dim)
            torch.narrow(padding_mask, seq_dim, 3, S-3).add_(torch.narrow(padding_mask, seq_dim, 0, S-3)).clamp_(max=1)

        if seq_dim == 0:
            return padding_mask.unsqueeze(-1)
        return padding_mask


def _get_group_mask(commands, seq_dim=0):
    """
    Args:
        commands: Shape [S, ...]
    """
    with torch.no_grad():
        group_mask = (commands == SVGTensor.COMMANDS_SIMPLIFIED.index("m")).cumsum(dim=seq_dim)
        return group_mask


def _get_visibility_mask(commands, seq_dim=0):
    """
    Args:
        commands: Shape [S, ...]
    """
    S = commands.size(seq_dim)
    with torch.no_grad():
        visibility_mask = (commands == SVGTensor.COMMANDS_SIMPLIFIED.index("EOS")).sum(dim=seq_dim) < S - 1

    if seq_dim == 0:
        return visibility_mask.unsqueeze(-1)
    return visibility_mask


def _get_key_visibility_mask(commands, seq_dim=0):
    S = commands.size(seq_dim)
    with torch.no_grad():
        key_visibility_mask = (commands == SVGTensor.COMMANDS_SIMPLIFIED.index("EOS")).sum(dim=seq_dim) >= S - 1

    if seq_dim == 0:
        return key_visibility_mask.transpose(0, 1)
    return key_visibility_mask


def _generate_square_subsequent_mask(sz):
    mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
    mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
    return mask


def _sample_categorical(temperature=0.0001, *args_logits):
    if len(args_logits) == 1:
        arg_logits, = args_logits
        return Categorical(logits=arg_logits / temperature).sample()
    return (*(Categorical(logits=arg_logits / temperature).sample() for arg_logits in args_logits),)


def _threshold_sample(arg_logits, threshold=0.5, temperature=1.0):
    scores = F.softmax(arg_logits / temperature, dim=-1)[..., 1]
    return scores > threshold
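
To make the EOS-based masking above concrete, here is a small sketch, run in the context of the module above (the literal 4 for the EOS index is an assumption for illustration; the real code looks it up via SVGTensor.COMMANDS_SIMPLIFIED.index("EOS")):

import torch

EOS = 4  # assumed EOS index, for illustration only
commands = torch.tensor([[0, 1, 1, EOS, EOS, EOS]])  # one sequence, seq_dim=-1

# everything before the first EOS counts as real content
padding_mask = ((commands == EOS).cumsum(dim=-1) == 0).float()
print(padding_mask)  # tensor([[1., 1., 1., 0., 0., 0.]])

# causal decoder mask: 0 on/below the diagonal, -inf above
print(_generate_square_subsequent_mask(4))
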
deepsvg/model/vector_quantize_pytorch.py
ADDED
@@ -0,0 +1,605 @@
import torch
from torch import nn, einsum
import torch.nn.functional as F
import torch.distributed as distributed
from torch.cuda.amp import autocast

from einops import rearrange, repeat
from contextlib import contextmanager

def exists(val):
    return val is not None

def default(val, d):
    return val if exists(val) else d

def noop(*args, **kwargs):
    pass

def l2norm(t):
    return F.normalize(t, p = 2, dim = -1)

def log(t, eps = 1e-20):
    return torch.log(t.clamp(min = eps))

def uniform_init(*shape):
    t = torch.empty(shape)
    nn.init.kaiming_uniform_(t)
    return t

def gumbel_noise(t):
    noise = torch.zeros_like(t).uniform_(0, 1)
    return -log(-log(noise))

def gumbel_sample(t, temperature = 1., dim = -1):
    if temperature == 0:
        return t.argmax(dim = dim)

    return ((t / temperature) + gumbel_noise(t)).argmax(dim = dim)

def ema_inplace(moving_avg, new, decay):
    moving_avg.data.mul_(decay).add_(new, alpha = (1 - decay))

def laplace_smoothing(x, n_categories, eps = 1e-5):
    return (x + eps) / (x.sum() + n_categories * eps)

def sample_vectors(samples, num):
    num_samples, device = samples.shape[0], samples.device
    if num_samples >= num:
        indices = torch.randperm(num_samples, device = device)[:num]
    else:
        indices = torch.randint(0, num_samples, (num,), device = device)

    return samples[indices]

def batched_sample_vectors(samples, num):
    return torch.stack([sample_vectors(sample, num) for sample in samples.unbind(dim = 0)], dim = 0)

def pad_shape(shape, size, dim = 0):
    return [size if i == dim else s for i, s in enumerate(shape)]

def sample_multinomial(total_count, probs):
    device = probs.device
    probs = probs.cpu()

    total_count = probs.new_full((), total_count)
    remainder = probs.new_ones(())
    sample = torch.empty_like(probs, dtype = torch.long)

    for i, p in enumerate(probs):
        s = torch.binomial(total_count, p / remainder)
        sample[i] = s
        total_count -= s
        remainder -= p

    return sample.to(device)

def all_gather_sizes(x, dim):
    size = torch.tensor(x.shape[dim], dtype = torch.long, device = x.device)
    all_sizes = [torch.empty_like(size) for _ in range(distributed.get_world_size())]
    distributed.all_gather(all_sizes, size)
    return torch.stack(all_sizes)

def all_gather_variably_sized(x, sizes, dim = 0):
    rank = distributed.get_rank()
    all_x = []

    for i, size in enumerate(sizes):
        t = x if i == rank else x.new_empty(pad_shape(x.shape, size, dim))
        distributed.broadcast(t, src = i, async_op = True)
        all_x.append(t)

    distributed.barrier()
    return all_x

def sample_vectors_distributed(local_samples, num):
    local_samples = rearrange(local_samples, '1 ... -> ...')

    rank = distributed.get_rank()
    all_num_samples = all_gather_sizes(local_samples, dim = 0)

    if rank == 0:
        samples_per_rank = sample_multinomial(num, all_num_samples / all_num_samples.sum())
    else:
        samples_per_rank = torch.empty_like(all_num_samples)

    distributed.broadcast(samples_per_rank, src = 0)
    samples_per_rank = samples_per_rank.tolist()

    local_samples = sample_vectors(local_samples, samples_per_rank[rank])
    all_samples = all_gather_variably_sized(local_samples, samples_per_rank, dim = 0)
    out = torch.cat(all_samples, dim = 0)

    return rearrange(out, '... -> 1 ...')

def batched_bincount(x, *, minlength):
    batch, dtype, device = x.shape[0], x.dtype, x.device
    target = torch.zeros(batch, minlength, dtype = dtype, device = device)
    values = torch.ones_like(x)
    target.scatter_add_(-1, x, values)
    return target

def kmeans(
    samples,
    num_clusters,
    num_iters = 10,
    use_cosine_sim = False,
    sample_fn = batched_sample_vectors,
    all_reduce_fn = noop
):
    num_codebooks, dim, dtype, device = samples.shape[0], samples.shape[-1], samples.dtype, samples.device

    means = sample_fn(samples, num_clusters)

    for _ in range(num_iters):
        if use_cosine_sim:
            dists = samples @ rearrange(means, 'h n d -> h d n')
        else:
            dists = -torch.cdist(samples, means, p = 2)

        buckets = torch.argmax(dists, dim = -1)
        bins = batched_bincount(buckets, minlength = num_clusters)
        all_reduce_fn(bins)

        zero_mask = bins == 0
        bins_min_clamped = bins.masked_fill(zero_mask, 1)

        new_means = buckets.new_zeros(num_codebooks, num_clusters, dim, dtype = dtype)

        new_means.scatter_add_(1, repeat(buckets, 'h n -> h n d', d = dim), samples)
        new_means = new_means / rearrange(bins_min_clamped, '... -> ... 1')
        all_reduce_fn(new_means)

        if use_cosine_sim:
            new_means = l2norm(new_means)

        means = torch.where(
            rearrange(zero_mask, '... -> ... 1'),
            means,
            new_means
        )

    return means, bins

def batched_embedding(indices, embeds):
    batch, dim = indices.shape[1], embeds.shape[-1]
    indices = repeat(indices, 'h b n -> h b n d', d = dim)
    embeds = repeat(embeds, 'h c d -> h b c d', b = batch)
    return embeds.gather(2, indices)

# regularization losses

def orthogonal_loss_fn(t):
    # eq (2) from https://arxiv.org/abs/2112.00384
    h, n = t.shape[:2]
    normed_codes = l2norm(t)
    cosine_sim = einsum('h i d, h j d -> h i j', normed_codes, normed_codes)
    return (cosine_sim ** 2).sum() / (h * n ** 2) - (1 / n)

# distance types

class EuclideanCodebook(nn.Module):
    def __init__(
        self,
        dim,
        codebook_size,
        num_codebooks = 1,
        kmeans_init = False,
        kmeans_iters = 10,
        sync_kmeans = True,
        decay = 0.8,
        eps = 1e-5,
        threshold_ema_dead_code = 2,
        use_ddp = False,
        learnable_codebook = False,
        sample_codebook_temp = 0
    ):
        super().__init__()
        self.decay = decay
        init_fn = uniform_init if not kmeans_init else torch.zeros
        embed = init_fn(num_codebooks, codebook_size, dim)

        self.codebook_size = codebook_size
        self.num_codebooks = num_codebooks

        self.kmeans_iters = kmeans_iters
        self.eps = eps
        self.threshold_ema_dead_code = threshold_ema_dead_code
        self.sample_codebook_temp = sample_codebook_temp

        assert not (use_ddp and num_codebooks > 1 and kmeans_init), 'kmeans init is not compatible with multiple codebooks in distributed environment for now'

        self.sample_fn = sample_vectors_distributed if use_ddp and sync_kmeans else batched_sample_vectors
        self.kmeans_all_reduce_fn = distributed.all_reduce if use_ddp and sync_kmeans else noop
        self.all_reduce_fn = distributed.all_reduce if use_ddp else noop

        self.register_buffer('initted', torch.Tensor([not kmeans_init]))
        self.register_buffer('cluster_size', torch.zeros(num_codebooks, codebook_size))
        self.register_buffer('embed_avg', embed.clone())

        self.learnable_codebook = learnable_codebook
        if learnable_codebook:
            self.embed = nn.Parameter(embed)
        else:
            self.register_buffer('embed', embed)

    @torch.jit.ignore
    def init_embed_(self, data):
        if self.initted:
            return

        embed, cluster_size = kmeans(
            data,
            self.codebook_size,
            self.kmeans_iters,
            sample_fn = self.sample_fn,
            all_reduce_fn = self.kmeans_all_reduce_fn
        )

        self.embed.data.copy_(embed)
        self.embed_avg.data.copy_(embed.clone())
        self.cluster_size.data.copy_(cluster_size)
        self.initted.data.copy_(torch.Tensor([True]))

    def replace(self, batch_samples, batch_mask):
        batch_samples = l2norm(batch_samples)

        for ind, (samples, mask) in enumerate(zip(batch_samples.unbind(dim = 0), batch_mask.unbind(dim = 0))):
            if not torch.any(mask):
                continue

            sampled = self.sample_fn(rearrange(samples, '... -> 1 ...'), mask.sum().item())
            self.embed.data[ind][mask] = rearrange(sampled, '1 ... -> ...')

    def expire_codes_(self, batch_samples):
        if self.threshold_ema_dead_code == 0:
            return

        expired_codes = self.cluster_size < self.threshold_ema_dead_code

        if not torch.any(expired_codes):
            return

        batch_samples = rearrange(batch_samples, 'h ... d -> h (...) d')
        self.replace(batch_samples, batch_mask = expired_codes)

    @autocast(enabled = False)
    def forward(self, x):
        needs_codebook_dim = x.ndim < 4

        x = x.float()

        if needs_codebook_dim:
            x = rearrange(x, '... -> 1 ...')

        shape, dtype = x.shape, x.dtype
        flatten = rearrange(x, 'h ... d -> h (...) d')

        self.init_embed_(flatten)

        embed = self.embed if not self.learnable_codebook else self.embed.detach()

        dist = -torch.cdist(flatten, embed, p = 2)

        embed_ind = gumbel_sample(dist, dim = -1, temperature = self.sample_codebook_temp)
        embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype)
        embed_ind = embed_ind.view(*shape[:-1])

        quantize = batched_embedding(embed_ind, self.embed)

        if self.training:
            cluster_size = embed_onehot.sum(dim = 1)

            self.all_reduce_fn(cluster_size)
            ema_inplace(self.cluster_size, cluster_size, self.decay)

            embed_sum = einsum('h n d, h n c -> h c d', flatten, embed_onehot)
            self.all_reduce_fn(embed_sum.contiguous())
            ema_inplace(self.embed_avg, embed_sum, self.decay)

            cluster_size = laplace_smoothing(self.cluster_size, self.codebook_size, self.eps) * self.cluster_size.sum()

            embed_normalized = self.embed_avg / rearrange(cluster_size, '... -> ... 1')
            self.embed.data.copy_(embed_normalized)
            self.expire_codes_(x)

        if needs_codebook_dim:
            quantize, embed_ind = map(lambda t: rearrange(t, '1 ... -> ...'), (quantize, embed_ind))

        return quantize, embed_ind

class CosineSimCodebook(nn.Module):
    def __init__(
        self,
        dim,
        codebook_size,
        num_codebooks = 1,
        kmeans_init = False,
        kmeans_iters = 10,
        sync_kmeans = True,
        decay = 0.8,
        eps = 1e-5,
        threshold_ema_dead_code = 2,
        use_ddp = False,
        learnable_codebook = False,
        sample_codebook_temp = 0.
    ):
        super().__init__()
        self.decay = decay

        if not kmeans_init:
            embed = l2norm(uniform_init(num_codebooks, codebook_size, dim))
        else:
            embed = torch.zeros(num_codebooks, codebook_size, dim)

        self.codebook_size = codebook_size
        self.num_codebooks = num_codebooks

        self.kmeans_iters = kmeans_iters
        self.eps = eps
        self.threshold_ema_dead_code = threshold_ema_dead_code
        self.sample_codebook_temp = sample_codebook_temp

        self.sample_fn = sample_vectors_distributed if use_ddp and sync_kmeans else batched_sample_vectors
        self.kmeans_all_reduce_fn = distributed.all_reduce if use_ddp and sync_kmeans else noop
        self.all_reduce_fn = distributed.all_reduce if use_ddp else noop

        self.register_buffer('initted', torch.Tensor([not kmeans_init]))
        self.register_buffer('cluster_size', torch.zeros(num_codebooks, codebook_size))

        self.learnable_codebook = learnable_codebook
        if learnable_codebook:
            self.embed = nn.Parameter(embed)
        else:
            self.register_buffer('embed', embed)

    @torch.jit.ignore
    def init_embed_(self, data):
        if self.initted:
            return

        embed, cluster_size = kmeans(
            data,
            self.codebook_size,
            self.kmeans_iters,
            use_cosine_sim = True,
            sample_fn = self.sample_fn,
            all_reduce_fn = self.kmeans_all_reduce_fn
        )

        self.embed.data.copy_(embed)
        self.cluster_size.data.copy_(cluster_size)
        self.initted.data.copy_(torch.Tensor([True]))

    def replace(self, batch_samples, batch_mask):
        batch_samples = l2norm(batch_samples)

        for ind, (samples, mask) in enumerate(zip(batch_samples.unbind(dim = 0), batch_mask.unbind(dim = 0))):
            if not torch.any(mask):
                continue

            sampled = self.sample_fn(rearrange(samples, '... -> 1 ...'), mask.sum().item())
            self.embed.data[ind][mask] = rearrange(sampled, '1 ... -> ...')

    def expire_codes_(self, batch_samples):
        if self.threshold_ema_dead_code == 0:
            return

        expired_codes = self.cluster_size < self.threshold_ema_dead_code

        if not torch.any(expired_codes):
            return

        batch_samples = rearrange(batch_samples, 'h ... d -> h (...) d')
        self.replace(batch_samples, batch_mask = expired_codes)

    @autocast(enabled = False)
    def forward(self, x):
        needs_codebook_dim = x.ndim < 4

        x = x.float()

        if needs_codebook_dim:
            x = rearrange(x, '... -> 1 ...')

        shape, dtype = x.shape, x.dtype

        flatten = rearrange(x, 'h ... d -> h (...) d')
        flatten = l2norm(flatten)

        self.init_embed_(flatten)

        embed = self.embed if not self.learnable_codebook else self.embed.detach()
        embed = l2norm(embed)

        dist = einsum('h n d, h c d -> h n c', flatten, embed)
        embed_ind = gumbel_sample(dist, dim = -1, temperature = self.sample_codebook_temp)
        embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype)
        embed_ind = embed_ind.view(*shape[:-1])

        quantize = batched_embedding(embed_ind, self.embed)

        if self.training:
            bins = embed_onehot.sum(dim = 1)
            self.all_reduce_fn(bins)

            ema_inplace(self.cluster_size, bins, self.decay)

            zero_mask = (bins == 0)
            bins = bins.masked_fill(zero_mask, 1.)

            embed_sum = einsum('h n d, h n c -> h c d', flatten, embed_onehot)
            self.all_reduce_fn(embed_sum)

            embed_normalized = embed_sum / rearrange(bins, '... -> ... 1')
            embed_normalized = l2norm(embed_normalized)

            embed_normalized = torch.where(
                rearrange(zero_mask, '... -> ... 1'),
                embed,
                embed_normalized
            )

            ema_inplace(self.embed, embed_normalized, self.decay)
            self.expire_codes_(x)

        if needs_codebook_dim:
            quantize, embed_ind = map(lambda t: rearrange(t, '1 ... -> ...'), (quantize, embed_ind))

        return quantize, embed_ind

# main class

class VectorQuantize(nn.Module):
    def __init__(
        self,
        dim,
        codebook_size,
        codebook_dim = None,
        heads = 1,
        separate_codebook_per_head = False,
        decay = 0.8,
        eps = 1e-5,
        kmeans_init = False,
        kmeans_iters = 10,
        sync_kmeans = True,
        use_cosine_sim = False,
        threshold_ema_dead_code = 0,
        channel_last = True,
        accept_image_fmap = False,
        commitment_weight = 1.,
        orthogonal_reg_weight = 0.,
        orthogonal_reg_active_codes_only = False,
        orthogonal_reg_max_codes = None,
        sample_codebook_temp = 0.,
        sync_codebook = False
    ):
        super().__init__()
        self.heads = heads
        self.separate_codebook_per_head = separate_codebook_per_head

        codebook_dim = default(codebook_dim, dim)
        codebook_input_dim = codebook_dim * heads

        requires_projection = codebook_input_dim != dim
        self.project_in = nn.Linear(dim, codebook_input_dim) if requires_projection else nn.Identity()
        self.project_out = nn.Linear(codebook_input_dim, dim) if requires_projection else nn.Identity()

        self.eps = eps
        self.commitment_weight = commitment_weight

        has_codebook_orthogonal_loss = orthogonal_reg_weight > 0
        self.orthogonal_reg_weight = orthogonal_reg_weight
        self.orthogonal_reg_active_codes_only = orthogonal_reg_active_codes_only
        self.orthogonal_reg_max_codes = orthogonal_reg_max_codes

        codebook_class = EuclideanCodebook if not use_cosine_sim else CosineSimCodebook

        self._codebook = codebook_class(
            dim = codebook_dim,
            num_codebooks = heads if separate_codebook_per_head else 1,
            codebook_size = codebook_size,
            kmeans_init = kmeans_init,
            kmeans_iters = kmeans_iters,
            sync_kmeans = sync_kmeans,
            decay = decay,
            eps = eps,
            threshold_ema_dead_code = threshold_ema_dead_code,
            use_ddp = sync_codebook,
            learnable_codebook = has_codebook_orthogonal_loss,
            sample_codebook_temp = sample_codebook_temp
        )

        self.codebook_size = codebook_size

        self.accept_image_fmap = accept_image_fmap
        self.channel_last = channel_last

    @property
    def codebook(self):
        codebook = self._codebook.embed
        if self.separate_codebook_per_head:
            return codebook

        return rearrange(codebook, '1 ... -> ...')

    def forward(
        self,
        x,
        mask = None
    ):
        shape, device, heads, is_multiheaded, codebook_size = x.shape, x.device, self.heads, self.heads > 1, self.codebook_size

        need_transpose = not self.channel_last and not self.accept_image_fmap

        if self.accept_image_fmap:
            height, width = x.shape[-2:]
            x = rearrange(x, 'b c h w -> b (h w) c')

        if need_transpose:
            x = rearrange(x, 'b d n -> b n d')

        x = self.project_in(x)

        if is_multiheaded:
            ein_rhs_eq = 'h b n d' if self.separate_codebook_per_head else '1 (b h) n d'
            x = rearrange(x, f'b n (h d) -> {ein_rhs_eq}', h = heads)

        quantize, embed_ind = self._codebook(x)

        if self.training:
            quantize = x + (quantize - x).detach()

        loss = torch.tensor([0.], device = device, requires_grad = self.training)

        if self.training:
            if self.commitment_weight > 0:
                detached_quantize = quantize.detach()

                if exists(mask):
                    # with variable lengthed sequences
                    commit_loss = F.mse_loss(detached_quantize, x, reduction = 'none')

                    if is_multiheaded:
                        mask = repeat(mask, 'b n -> c (b h) n', c = commit_loss.shape[0], h = commit_loss.shape[1] // mask.shape[0])

                    commit_loss = commit_loss[mask].mean()
                else:
                    commit_loss = F.mse_loss(detached_quantize, x)

                loss = loss + commit_loss * self.commitment_weight

            if self.orthogonal_reg_weight > 0:
                codebook = self._codebook.embed

                if self.orthogonal_reg_active_codes_only:
                    # only calculate orthogonal loss for the activated codes for this batch
                    unique_code_ids = torch.unique(embed_ind)
                    codebook = codebook[unique_code_ids]

                num_codes = codebook.shape[0]
                if exists(self.orthogonal_reg_max_codes) and num_codes > self.orthogonal_reg_max_codes:
                    rand_ids = torch.randperm(num_codes, device = device)[:self.orthogonal_reg_max_codes]
                    codebook = codebook[rand_ids]

                orthogonal_reg_loss = orthogonal_loss_fn(codebook)
                loss = loss + orthogonal_reg_loss * self.orthogonal_reg_weight

        if is_multiheaded:
            if self.separate_codebook_per_head:
                quantize = rearrange(quantize, 'h b n d -> b n (h d)', h = heads)
                embed_ind = rearrange(embed_ind, 'h b n -> b n h', h = heads)
            else:
                quantize = rearrange(quantize, '1 (b h) n d -> b n (h d)', h = heads)
                embed_ind = rearrange(embed_ind, '1 (b h) n -> b n h', h = heads)

        quantize = self.project_out(quantize)

        if need_transpose:
            quantize = rearrange(quantize, 'b n d -> b d n')

        if self.accept_image_fmap:
            quantize = rearrange(quantize, 'b (h w) c -> b c h w', h = height, w = width)
            embed_ind = rearrange(embed_ind, 'b (h w) ... -> b h w ...', h = height, w = width)

        return quantize, embed_ind, loss
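
A minimal usage sketch for the quantizer above, mirroring the upstream vector-quantize-pytorch README (the dimensions are arbitrary):

import torch

vq = VectorQuantize(dim=256, codebook_size=512, decay=0.8, commitment_weight=1.)

x = torch.randn(2, 1024, 256)    # [batch, sequence, dim]
quantize, indices, loss = vq(x)  # quantized vectors, code indices, commitment loss
# quantize: (2, 1024, 256), indices: (2, 1024), loss: tensor of shape (1,)
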
deepsvg/schedulers/warmup.py
ADDED
@@ -0,0 +1,67 @@
from torch.optim.lr_scheduler import _LRScheduler
from torch.optim.lr_scheduler import ReduceLROnPlateau


class GradualWarmupScheduler(_LRScheduler):
    """ Gradually warm-up (increasing) learning rate in optimizer.
    Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'.
    Args:
        optimizer (Optimizer): Wrapped optimizer.
        multiplier: target learning rate = base lr * multiplier if multiplier > 1.0. if multiplier = 1.0, lr starts from 0 and ends up with the base_lr.
        total_epoch: target learning rate is reached at total_epoch, gradually
        after_scheduler: after target_epoch, use this scheduler (e.g. ReduceLROnPlateau)
    """

    def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None):
        self.multiplier = multiplier
        if self.multiplier < 1.:
            raise ValueError('multiplier should be greater than or equal to 1.')
        self.total_epoch = total_epoch
        self.after_scheduler = after_scheduler
        self.finished = False
        super(GradualWarmupScheduler, self).__init__(optimizer)

    def get_lr(self):
        if self.last_epoch > self.total_epoch:
            if self.after_scheduler:
                if not self.finished:
                    self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs]
                    self.finished = True
                return self.after_scheduler.get_last_lr()
            return [base_lr * self.multiplier for base_lr in self.base_lrs]

        if self.multiplier == 1.0:
            return [base_lr * (float(self.last_epoch) / self.total_epoch) for base_lr in self.base_lrs]
        else:
            return [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs]

    def step_ReduceLROnPlateau(self, metrics, epoch=None):
        if epoch is None:
            epoch = self.last_epoch + 1
        self.last_epoch = epoch if epoch != 0 else 1  # ReduceLROnPlateau is called at the end of epoch, whereas others are called at beginning
        if self.last_epoch <= self.total_epoch:
            # warmup_lr = [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs]
            if self.multiplier == 1.0:
                warmup_lr = [base_lr * (float(self.last_epoch) / self.total_epoch) for base_lr in self.base_lrs]
            else:
                warmup_lr = [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs]
            for param_group, lr in zip(self.optimizer.param_groups, warmup_lr):
                param_group['lr'] = lr
        else:
            if epoch is None:
                self.after_scheduler.step(metrics, None)
            else:
                self.after_scheduler.step(metrics, epoch - self.total_epoch)

    def step(self, epoch=None, metrics=None):
        if type(self.after_scheduler) != ReduceLROnPlateau:
            if self.finished and self.after_scheduler:
                if epoch is None:
                    self.after_scheduler.step(None)
                else:
                    self.after_scheduler.step(epoch - self.total_epoch)
                self._last_lr = self.after_scheduler.get_last_lr()
            else:
                return super(GradualWarmupScheduler, self).step(epoch)
        else:
            self.step_ReduceLROnPlateau(metrics, epoch)
|
deepsvg/svg_dataset.py
ADDED
@@ -0,0 +1,269 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from deepsvg.config import _Config
from deepsvg.difflib.tensor import SVGTensor
from deepsvg.svglib.svg import SVG
from deepsvg.svglib.geom import Point, Angle
# from deepsvg import utils

import math
import torch
import torch.utils.data
import random
from typing import List, Union
import pandas as pd
import os
import pickle
from sklearn.model_selection import train_test_split
Num = Union[int, float]


class SVGDataset(torch.utils.data.Dataset):
    def __init__(self, df, data_dir, model_args, max_num_groups, max_seq_len, max_total_len=None, PAD_VAL=0,
                 nb_augmentations=1, already_preprocessed=True):
        self.data_dir = data_dir

        self.already_preprocessed = already_preprocessed

        self.MAX_NUM_GROUPS = max_num_groups
        self.MAX_SEQ_LEN = max_seq_len
        self.MAX_TOTAL_LEN = max_total_len

        if max_total_len is None:
            self.MAX_TOTAL_LEN = max_num_groups * max_seq_len

        # if df is None:
        #     df = pd.read_csv(meta_filepath)

        # if len(df) > 0:
        #     if filter_uni is not None:
        #         df = df[df.uni.isin(filter_uni)]

        #     if filter_platform is not None:
        #         df = df[df.platform.isin(filter_platform)]

        #     if filter_category is not None:
        #         df = df[df.category.isin(filter_category)]

        #     df = df[(df.nb_groups <= max_num_groups) & (df.max_len_group <= max_seq_len)]
        #     if max_total_len is not None:
        #         df = df[df.total_len <= max_total_len]

        # self.df = df.sample(frac=train_ratio) if train_ratio < 1.0 else df
        self.df = df

        self.model_args = model_args

        self.PAD_VAL = PAD_VAL

        self.nb_augmentations = nb_augmentations

    def search_name(self, name):
        return self.df[self.df.commonName.str.contains(name)]

    def _filter_categories(self, filter_category):
        self.df = self.df[self.df.category.isin(filter_category)]

    @staticmethod
    def _uni_to_label(uni):
        if 48 <= uni <= 57:
            return uni - 48
        elif 65 <= uni <= 90:
            return uni - 65 + 10
        return uni - 97 + 36

    @staticmethod
    def _label_to_uni(label_id):
        if 0 <= label_id <= 9:
            return label_id + 48
        elif 10 <= label_id <= 35:
            return label_id + 65 - 10
        return label_id + 97 - 36

    @staticmethod
    def _category_to_label(category):
        categories = ['characters', 'free-icons', 'logos', 'alphabet', 'animals', 'arrows', 'astrology', 'baby', 'beauty',
                      'business', 'cinema', 'city', 'clothing', 'computer-hardware', 'crime', 'cultures', 'data', 'diy',
                      'drinks', 'ecommerce', 'editing', 'files', 'finance', 'folders', 'food', 'gaming', 'hands', 'healthcare',
                      'holidays', 'household', 'industry', 'maps', 'media-controls', 'messaging', 'military', 'mobile',
                      'music', 'nature', 'network', 'photo-video', 'plants', 'printing', 'profile', 'programming', 'science',
                      'security', 'shopping', 'social-networks', 'sports', 'time-and-date', 'transport', 'travel', 'user-interface',
                      'users', 'weather', 'flags', 'emoji', 'men', 'women']
        return categories.index(category)

    def get_label(self, idx=0, entry=None):
        # if entry is None:
        #     entry = self.df.iloc[idx]

        # if "uni" in self.df.columns:  # Font dataset
        #     label = self._uni_to_label(entry.uni)
        #     return torch.tensor(label)
        # elif "category" in self.df.columns:  # Icons dataset
        #     label = self._category_to_label(entry.category)
        #     return torch.tensor(label)

        if "label" in self.df.columns:
            return self.df.iloc[idx]['label']

    def idx_to_id(self, idx):
        return self.df.iloc[idx].id

    def entry_from_id(self, id):
        return self.df[self.df.id == str(id)].iloc[0]

    def _load_svg(self, icon_id):
        svg = SVG.load_svg(os.path.join(self.data_dir, f"{icon_id}.svg"))

        if not self.already_preprocessed:
            svg.fill_(False)
            svg.normalize().zoom(0.9)
            svg.canonicalize()
            svg = svg.simplify_heuristic()

        return svg

    def __len__(self):
        return len(self.df) * self.nb_augmentations

    def random_icon(self):
        return self[random.randrange(0, len(self))]

    def random_id(self):
        idx = random.randrange(0, len(self)) % len(self.df)
        return self.idx_to_id(idx)

    def random_id_by_uni(self, uni):
        df = self.df[self.df.uni == uni]
        return df.id.sample().iloc[0]

    def __getitem__(self, idx):
        return self.get(idx, self.model_args)

    @staticmethod
    def _augment(svg, mean=False):
        # aug 2
        # dx = random.randint(0, 10)
        # dy = random.randint(0, 10)
        # factor = 0.02 * dx + 0.8

        # return svg.zoom(factor).translate(Point(dx / 6, dy / 6)).rotate(Angle((dx - 5) / 2))

        # aug 1
        n = random.randint(0, 9)  # [0, 9]; note: the original `random.random() % 10` always stays in [0, 1)
        dx, dy = (0, 0) if mean else (n / 9, n / 9)
        factor = 0.7 if mean else 0.02 * n + 0.8

        return svg.zoom(factor).translate(Point(dx, dy))
        # return svg.zoom(factor)

    @staticmethod
    def simplify(svg, normalize=True):
        svg.canonicalize(normalize=normalize)
        svg = svg.simplify_heuristic()
        return svg.normalize()

    @staticmethod
    def preprocess(svg, augment=True, numericalize=True, mean=False):
        if augment:
            svg = SVGDataset._augment(svg, mean=mean)
        if numericalize:
            return svg.numericalize(256)
        return svg

    def get(self, idx=0, model_args=None, random_aug=True, id=None, svg: SVG=None):
        if id is None:
            idx = idx % len(self.df)
            id = self.idx_to_id(idx)
        # utils.set_value('id', id)

        if svg is None:
            svg = self._load_svg(id)

        svg = SVGDataset.preprocess(svg, augment=random_aug, numericalize=False)

        t_sep, fillings = svg.to_tensor(concat_groups=False, PAD_VAL=self.PAD_VAL), svg.to_fillings()

        label = self.get_label(idx)

        return self.get_data(t_sep, fillings, model_args=model_args, label=label)

    def get_data(self, t_sep, fillings, model_args=None, label=None):
        res = {}

        if model_args is None:
            model_args = self.model_args

        pad_len = max(self.MAX_NUM_GROUPS - len(t_sep), 0)

        t_sep.extend([torch.empty(0, 9)] * pad_len)
        # t_sep.extend([torch.empty(0, 14)] * pad_len)
        fillings.extend([0] * pad_len)

        t_grouped = [SVGTensor.from_data(torch.cat(t_sep, dim=0), PAD_VAL=self.PAD_VAL).add_eos().add_sos().pad(
            seq_len=self.MAX_TOTAL_LEN + 2)]
        t_sep = [SVGTensor.from_data(t, PAD_VAL=self.PAD_VAL, filling=f).add_eos().add_sos().pad(seq_len=self.MAX_SEQ_LEN + 2) for
                 t, f in zip(t_sep, fillings)]

        for arg in set(model_args):
            if "_grouped" in arg:
                arg_ = arg.split("_grouped")[0]
                t_list = t_grouped
            else:
                arg_ = arg
                t_list = t_sep

            if arg_ == "tensor":
                res[arg] = t_list

            if arg_ == "commands":
                res[arg] = torch.stack([t.cmds() for t in t_list])

            if arg_ == "args_rel":
                res[arg] = torch.stack([t.get_relative_args() for t in t_list])
            if arg_ == "args_bin":
                res[arg] = torch.stack([t.get_binary_args() for t in t_list])
            if arg_ == "args":
                res[arg] = torch.stack([t.args() for t in t_list])

        if "filling" in model_args:
            res["filling"] = torch.stack([torch.tensor(t.filling) for t in t_sep]).unsqueeze(-1)

        if "label" in model_args:
            res["label"] = label

        return res


def load_dataset(cfg: _Config, already_preprocessed=True, train_split=False):

    df = pd.read_csv(cfg.meta_filepath)

    if len(df) > 0:
        if cfg.filter_uni is not None:
            df = df[df.uni.isin(cfg.filter_uni)]

        if cfg.filter_platform is not None:
            df = df[df.platform.isin(cfg.filter_platform)]

        if cfg.filter_category is not None:
            df = df[df.category.isin(cfg.filter_category)]

        df = df[(df.nb_groups <= cfg.max_num_groups) & (df.max_len_group <= cfg.max_seq_len)]
        if cfg.max_total_len is not None:
            df = df[df.total_len <= cfg.max_total_len]

    df = df.sample(frac=cfg.dataset_ratio) if cfg.dataset_ratio < 1.0 else df

    train_df, valid_df = train_test_split(df, train_size=cfg.train_ratio)
    if train_split:
        train_df, valid_df = train_test_split(train_df, train_size=cfg.train_ratio)

    train_dataset = SVGDataset(train_df, cfg.data_dir, cfg.model_args, cfg.max_num_groups, cfg.max_seq_len, cfg.max_total_len, nb_augmentations=cfg.nb_augmentations, already_preprocessed=already_preprocessed)
    valid_dataset = SVGDataset(valid_df, cfg.data_dir, cfg.model_args, cfg.max_num_groups, cfg.max_seq_len, cfg.max_total_len, nb_augmentations=cfg.nb_augmentations, already_preprocessed=already_preprocessed)

    print(f"Number of train SVGs: {len(train_df)}")
    # print(f"First SVG in train: {train_df.iloc[0]['id']} - {train_df.iloc[0]['category']} - {train_df.iloc[0]['subcategory']}")
    print(f"First SVG in train: {train_df.iloc[0]['id']}")
    print(f"Number of valid SVGs: {len(valid_df)}")
    # print(f"First SVG in valid: {valid_df.iloc[0]['id']} - {valid_df.iloc[0]['category']} - {valid_df.iloc[0]['subcategory']}")
    print(f"First SVG in valid: {valid_df.iloc[0]['id']}")

    return train_dataset, valid_dataset
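
As a usage sketch, load_dataset reads several fields off the config beyond the ones _Config initializes; the values below are placeholders that only illustrate the expected types:

from deepsvg.config import _Config

cfg = _Config()
cfg.filter_uni = cfg.filter_platform = cfg.filter_category = None  # assumed extra fields
cfg.max_num_groups, cfg.max_seq_len, cfg.max_total_len = 8, 30, 50
cfg.dataset_ratio = 1.0
cfg.train_ratio = 0.8
cfg.model_args = ["commands", "args"]

train_dataset, valid_dataset = load_dataset(cfg)
sample = train_dataset[0]
print(sample["commands"].shape)  # [max_num_groups, max_seq_len + 2]
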
deepsvg/svglib/__init__.py
ADDED
File without changes
deepsvg/svglib/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (146 Bytes). View file
deepsvg/svglib/__pycache__/geom.cpython-310.pyc
ADDED
Binary file (18.2 kB). View file
deepsvg/svglib/__pycache__/svg.cpython-310.pyc
ADDED
Binary file (20.3 kB). View file