Source code for geometric_algebra_attention.pytorch.MomentumNormalization
import torch as pt
[docs]
class MomentumNormalization(pt.nn.Module):
    """Exponential decay normalization.

    Tracks an exponential moving average of the mean and standard
    deviation of the input, computed over all axes but the last, and
    uses those running statistics to shift and scale the input toward
    mean 0 and variance 1; suitable for normalizing a vector of
    real-valued quantities with differing units.

    In training mode each call first folds the statistics of the
    current input into the running buffers and then normalizes with the
    updated buffers. In evaluation mode the buffers are left untouched.
    No gradient flows through the statistics.

    :param n_dim: Last dimension of the layer input
    :param momentum: Momentum of moving average, from 0 to 1

    """

    def __init__(self, n_dim, momentum=.99):
        super().__init__()
        self.n_dim = n_dim

        # Registered as buffers so they are part of the module state
        # without being trainable parameters.
        self.register_buffer('momentum', pt.as_tensor(momentum))
        self.register_buffer('mu', pt.zeros(n_dim))
        self.register_buffer('sigma', pt.ones(n_dim))

    def forward(self, x):
        """Normalize ``x`` using the running mean and standard deviation.

        :param x: Input tensor whose last axis has length ``n_dim``
        :returns: ``(x - mu) / max(sigma, 1e-7)``, with ``mu`` and
            ``sigma`` treated as constants for differentiation
        """
        if self.training:
            # The statistics are only ever used detached, so compute and
            # store them without recording an autograd graph.
            with pt.no_grad():
                axes = tuple(range(x.ndim - 1))
                mu_calc = pt.mean(x, axes, keepdim=False)
                sigma_calc = pt.std(x, axes, keepdim=False, unbiased=False)

                new_mu = self.momentum*self.mu + (1 - self.momentum)*mu_calc
                new_sigma = (self.momentum*self.sigma +
                             (1 - self.momentum)*sigma_calc)

                # Update in place so the registered buffers keep their
                # identity (and device/dtype).
                self.mu.copy_(new_mu)
                self.sigma.copy_(new_sigma)

        # Floor sigma to avoid dividing by zero for constant features.
        sigma = pt.maximum(
            self.sigma, pt.as_tensor(1e-7, dtype=x.dtype, device=x.device))
        return (x - self.mu.detach())/sigma.detach()