Source code for geometric_algebra_attention.pytorch.MomentumLayerNormalization

import torch as pt

from .geometric_algebra import custom_norm

class MomentumLayerNormalization(pt.nn.Module):
    """Exponential decay normalization.

    Calculates a running average of the L2 norm and scales inputs to have
    length (over the last axis) 1, on average.

    The running average is only updated while the module is in training
    mode (``self.training``). In evaluation mode the stored average is
    used unchanged, analogous to the running statistics of batch
    normalization.

    :param momentum: Momentum of moving average, from 0 to 1
    :param epsilon: Minimum norm for normalization scaling factor

    """
    def __init__(self, momentum=0.99, epsilon=1e-7):
        super().__init__()
        # Registered as buffers rather than parameters: they are not
        # trained, but they move with the module and appear in its
        # state dict.
        self.register_buffer('momentum', pt.as_tensor(momentum))
        self.register_buffer('epsilon', pt.as_tensor(epsilon))
        # Running average of the mean input norm; starts at 1, so an
        # untrained module in eval mode leaves its input unscaled.
        self.register_buffer('norm', pt.ones(1))

    def forward(self, x):
        """Scale ``x`` by the inverse of the running-average norm.

        In training mode the running average is first blended with the
        mean of ``custom_norm(x)`` using ``momentum``.

        :param x: Input tensor (its norm is computed by ``custom_norm``;
            presumably over the last axis, per the class docstring)
        :returns: ``x`` divided by ``max(running_norm, epsilon)``
        """
        if self.training:
            # The running statistic is stored detached from autograd, so
            # there is no reason to record a graph while computing it.
            with pt.no_grad():
                norm = pt.mean(custom_norm(x))
                mixture = self.momentum*self.norm + (1.0 - self.momentum)*norm
                self.norm.copy_(mixture)
        # Clamp the divisor to at least epsilon to avoid division by ~0;
        # the scale factor itself receives no gradient.
        return x/pt.maximum(self.norm, self.epsilon).detach()