# Copyright (c) Microsoft Corporation.

import math

import torch
from einops import rearrange
from torch import nn
from torch.nn import functional as F

from .positional_encoding import SphericalHarmonics


class LocationEncoder(nn.Module):
    """Encode geographic coordinates via spherical harmonics followed by a SIREN MLP."""

    def __init__(
        self,
        dim_hidden: int,
        num_layers: int,
        dim_out: int,
        legendre_polys: int = 10,
    ):
        super().__init__()
        self.posenc = SphericalHarmonics(legendre_polys=legendre_polys)
        self.nnet = SirenNet(
            dim_in=self.posenc.embedding_dim,
            dim_hidden=dim_hidden,
            num_layers=num_layers,
            dim_out=dim_out,
        )

    def forward(self, x):
        x = self.posenc(x)
        return self.nnet(x)
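
# Hedged usage sketch (illustrative, not from the original file): assuming
# SphericalHarmonics accepts a (batch, 2) tensor of longitude/latitude pairs
# and that dim_out is the embedding size, a forward pass would look like:
#
#   encoder = LocationEncoder(dim_hidden=256, num_layers=2, dim_out=512)
#   lonlat = torch.tensor([[12.49, 41.90], [-0.13, 51.51]])
#   embeddings = encoder(lonlat)  # expected shape: (2, 512)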


class SirenNet(nn.Module):
    """Sinusoidal Representation Network (SIREN)."""

    def __init__(
        self,
        dim_in,
        dim_hidden,
        dim_out,
        num_layers,
        w0=1.0,
        w0_initial=30.0,
        use_bias=True,
        final_activation=None,
        degreeinput=False,
        dropout=True,
    ):
        super().__init__()
        self.num_layers = num_layers
        self.dim_hidden = dim_hidden
        self.degreeinput = degreeinput

        self.layers = nn.ModuleList([])
        for ind in range(num_layers):
            is_first = ind == 0
            # The first layer uses a larger frequency w0_initial, as in the SIREN paper.
            layer_w0 = w0_initial if is_first else w0
            layer_dim_in = dim_in if is_first else dim_hidden
            self.layers.append(
                Siren(
                    dim_in=layer_dim_in,
                    dim_out=dim_hidden,
                    w0=layer_w0,
                    use_bias=use_bias,
                    is_first=is_first,
                    dropout=dropout,
                )
            )

        final_activation = (
            nn.Identity() if not exists(final_activation) else final_activation
        )
        self.last_layer = Siren(
            dim_in=dim_hidden,
            dim_out=dim_out,
            w0=w0,
            use_bias=use_bias,
            activation=final_activation,
            dropout=False,
        )

    def forward(self, x, mods=None):
        # Map degree-valued inputs into the [-pi, pi] range before the sine layers.
        if self.degreeinput:
            x = torch.deg2rad(x) - torch.pi

        mods = cast_tuple(mods, self.num_layers)
        for layer, mod in zip(self.layers, mods):
            x = layer(x)
            # Optional modulation: scale activations with a (dim_hidden,) vector.
            if exists(mod):
                x *= rearrange(mod, "d -> () d")
        return self.last_layer(x)
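
# Hedged sketch of the optional per-layer modulation (shapes are assumptions):
# a single (dim_hidden,) tensor passed as `mods` is reused for every hidden
# layer; it is broadcast over the batch and multiplied into the activations.
#
#   net = SirenNet(dim_in=2, dim_hidden=64, dim_out=16, num_layers=3)
#   out = net(torch.randn(4, 2), mods=torch.rand(64))  # -> (4, 16)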


class Sine(nn.Module):
    """Sine activation with frequency w0: x -> sin(w0 * x)."""

    def __init__(self, w0=1.0):
        super().__init__()
        self.w0 = w0

    def forward(self, x):
        return torch.sin(self.w0 * x)


class Siren(nn.Module):
    """A single SIREN layer: a linear map followed by a sine activation."""

    def __init__(
        self,
        dim_in,
        dim_out,
        w0=1.0,
        c=6.0,
        is_first=False,
        use_bias=True,
        activation=None,
        dropout=False,
    ):
        super().__init__()
        self.dim_in = dim_in
        self.is_first = is_first
        self.dim_out = dim_out
        self.dropout = dropout

        weight = torch.zeros(dim_out, dim_in)
        bias = torch.zeros(dim_out) if use_bias else None
        self.init_(weight, bias, c=c, w0=w0)

        self.weight = nn.Parameter(weight)
        self.bias = nn.Parameter(bias) if use_bias else None
        self.activation = Sine(w0) if activation is None else activation

    def init_(self, weight, bias, c, w0):
        # SIREN initialization: U(-1/dim_in, 1/dim_in) for the first layer,
        # U(-sqrt(c/dim_in)/w0, sqrt(c/dim_in)/w0) for the following layers.
        dim = self.dim_in
        w_std = (1 / dim) if self.is_first else (math.sqrt(c / dim) / w0)
        weight.uniform_(-w_std, w_std)
        if exists(bias):
            bias.uniform_(-w_std, w_std)

    def forward(self, x):
        out = F.linear(x, self.weight, self.bias)
        if self.dropout:
            # F.dropout uses its default p=0.5 and is active only in training mode.
            out = F.dropout(out, training=self.training)
        out = self.activation(out)
        return out
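
# Hedged sketch of a single Siren layer (sizes are arbitrary assumptions):
#
#   layer = Siren(dim_in=32, dim_out=64, w0=30.0, is_first=True)
#   y = layer(torch.randn(4, 32))  # -> (4, 64), values in [-1, 1] from the sine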


def exists(val):
    return val is not None


def cast_tuple(val, repeat=1):
    return val if isinstance(val, tuple) else ((val,) * repeat)
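

if __name__ == "__main__":
    # Hedged smoke test, not part of the original file. The relative import above
    # means this must be run as a module (e.g. python -m <package>.location_encoder);
    # the layer sizes below are arbitrary.
    net = SirenNet(dim_in=2, dim_hidden=64, dim_out=16, num_layers=3)
    coords = torch.randn(4, 2)
    print(net(coords).shape)  # torch.Size([4, 16])
    # degreeinput=True maps degree-valued inputs into [-pi, pi] before the net.
    deg_net = SirenNet(dim_in=2, dim_hidden=64, dim_out=16, num_layers=3, degreeinput=True)
    print(deg_net(torch.tensor([[12.49, 41.90]])).shape)  # torch.Size([1, 16])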