|
|
|
|
|
import math |
|
|
|
import torch |
|
from torch import nn |
|
|
|
from .spherical_armonics import SH as SH_analytic |
|
|
|
|
|
class SphericalHarmonics(nn.Module):
    """
    Spherical Harmonics location encoder

    Evaluates one spherical-harmonic basis function for every degree
    ``l`` in ``range(legendre_polys)`` and every order ``m`` in
    ``range(-l, l + 1)`` and stacks the results along the last dimension.
    """

    def __init__(self, legendre_polys: int = 10, harmonics_calculation="analytic"):
        """
        legendre_polys: determines the number of legendre polynomials.
                        more polynomials lead to more fine-grained resolutions
        harmonics_calculation: how the spherical harmonics are evaluated:
            "analytic" uses pre-computed equations. This is exact, but works only up to degree 50,
            "closed-form" uses one equation but is computationally slower (especially for high degrees)

        Raises ValueError when harmonics_calculation is neither of the two options.
        """
        super().__init__()
        self.L, self.M = int(legendre_polys), int(legendre_polys)
        # sum over l in range(L) of (2l + 1) orders == L * L == L * M
        self.embedding_dim = self.L * self.M

        if harmonics_calculation == "closed-form":
            self.SH = SH_closed_form
        elif harmonics_calculation == "analytic":
            self.SH = SH_analytic
        else:
            # Fail at construction time instead of leaving self.SH unset,
            # which would only surface later as an AttributeError in forward().
            raise ValueError(
                "harmonics_calculation must be 'analytic' or 'closed-form', "
                f"got {harmonics_calculation!r}"
            )

    def forward(self, lonlat):
        """
        lonlat: tensor read as lonlat[:, 0] (longitude) and lonlat[:, 1]
                (latitude); values are converted with deg2rad, so they are
                treated as degrees.

        Returns the harmonics stacked along the last dimension, one entry per
        (l, m) pair in the order l = 0..L-1, m = -l..l.
        """
        lon, lat = lonlat[:, 0], lonlat[:, 1]

        # Offset longitude by 180 and latitude by 90 degrees before converting
        # to radians (presumably to map [-180, 180] x [-90, 90] onto
        # non-negative angles -- confirm against the harmonic functions used).
        phi = torch.deg2rad(lon + 180)
        theta = torch.deg2rad(lat + 90)

        SH = self.SH

        Y = []
        for l in range(self.L):
            for m in range(-l, l + 1):
                y = SH(m, l, phi, theta)
                # Some harmonics are constant and may come back as a plain
                # Python number (float or int); broadcast them to phi's shape.
                if isinstance(y, (int, float)):
                    y = y * torch.ones_like(phi)
                # Diagnostic output: report which (m, l) produced NaNs.
                if y.isnan().any():
                    print(m, l, y)
                Y.append(y)

        return torch.stack(Y, dim=-1)
|
|
|
|
|
|
|
|
|
|
|
|
|
def associated_legendre_polynomial(l, m, x):
    """
    Evaluate the associated Legendre polynomial P_l^m element-wise at x.

    l: degree, m: order (the recurrence is written for 0 <= m <= l),
    x: tensor of evaluation points; the result has the same shape and dtype.

    The sign factor (-1)^m is included in the result. When l < m the
    recurrence loop never runs and a tensor of zeros is returned.
    """
    # Seed value P_m^m: start from 1 and multiply m times by
    # -(2k - 1) * sqrt(1 - x^2).
    lower = torch.ones_like(x)
    if m > 0:
        root = torch.sqrt((1 - x) * (1 + x))
        odd = 1.0
        for _ in range(m):
            lower = lower * (-odd) * root
            odd += 2.0
    if l == m:
        return lower

    # Second seed value P_{m+1}^m = x * (2m + 1) * P_m^m.
    upper = x * (2.0 * m + 1.0) * lower
    if l == m + 1:
        return upper

    # Upward recurrence in the degree, from m + 2 up to l.
    current = torch.zeros_like(x)
    for degree in range(m + 2, l + 1):
        current = (
            (2.0 * degree - 1.0) * x * upper - (degree + m - 1.0) * lower
        ) / (degree - m)
        lower, upper = upper, current
    return current
|
|
|
|
|
def SH_renormalization(l, m):
    """
    Return sqrt((2l + 1) * (l - m)! / (4 * pi * (l + m)!)) as a float.

    l: degree, m: order; both integers with 0 <= m <= l
    (math.factorial raises ValueError for a negative argument).

    For l + m > 170 the factorial no longer fits into a float and the direct
    formula raises OverflowError, even though the result is representable.
    In that case the same quantity is evaluated in the log domain.
    """
    try:
        # Direct formula: kept as the primary path so results are unchanged
        # for every input that worked before.
        return math.sqrt(
            (2.0 * l + 1.0) * math.factorial(l - m) / (4 * math.pi * math.factorial(l + m))
        )
    except OverflowError:
        # lgamma(n + 1) == log(n!), so the factorial ratio becomes a difference
        # of logs; halving the log is the square root.
        log_squared = (
            math.log(2.0 * l + 1.0)
            - math.log(4 * math.pi)
            + math.lgamma(l - m + 1)
            - math.lgamma(l + m + 1)
        )
        return math.exp(0.5 * log_squared)
|
|
|
|
|
def SH_closed_form(m, l, phi, theta):
    """
    Evaluate one real spherical harmonic from the closed-form expression.

    m: order, l: degree, phi / theta: angle tensors in radians.

    The result is SH_renormalization(l, |m|) times the associated Legendre
    polynomial of cos(theta); for m != 0 it is additionally multiplied by
    sqrt(2) and by cos(m * phi) when m > 0 or sin(|m| * phi) when m < 0.
    """
    order = abs(m)
    scale = SH_renormalization(l, order)
    legendre = associated_legendre_polynomial(l, order, torch.cos(theta))

    # Zonal harmonic: no dependence on phi.
    if m == 0:
        return scale * legendre

    # Positive orders use the cosine term, negative orders the sine term.
    if m > 0:
        azimuthal = torch.cos(order * phi)
    else:
        azimuthal = torch.sin(order * phi)
    return math.sqrt(2.0) * scale * azimuthal * legendre
|
|