# Source code for geometric_algebra_attention.keras.MomentumLayerNormalization
import tensorflow as tf
from tensorflow import keras
from ..tensorflow.geometric_algebra import custom_norm
# [docs]
class MomentumLayerNormalization(keras.layers.Layer):
    """Exponential decay normalization.

    Calculates a running average of the L2 norm and scales inputs to
    have length (over the last axis) 1, on average.

    The running average is stored in a single non-trainable weight that
    starts at 1. It is only updated when the layer is called with a true
    `training` flag while the layer itself is trainable; otherwise the
    stored value is used as-is.

    :param momentum: Momentum of moving average, from 0 to 1
    :param epsilon: Minimum norm for normalization scaling factor
    """

    def __init__(self, momentum=.99, epsilon=1e-7, *args, **kwargs):
        # Initialize the base layer before assigning any attributes, so the
        # values set below cannot be reset by the base-class constructor.
        super().__init__(*args, **kwargs)
        self.momentum = momentum
        self.epsilon = epsilon
        self.supports_masking = True

    def build(self, input_shape):
        """Create the running-norm weight.

        The weight has shape ``[1]`` regardless of `input_shape`, is
        initialized to one, and is marked non-trainable.
        """
        self.norm = self.add_weight(
            name='norm', shape=[1], initializer='ones', trainable=False)

    def call(self, inputs, training=False, mask=None):
        """Rescale `inputs` by the running-average norm.

        :param inputs: Tensor to normalize
        :param training: If true (and the layer is trainable), fold the mean norm of `inputs` into the running average before scaling
        :param mask: Optional mask (expected to be boolean, as `tf.where` requires); entries where it is false are returned unscaled
        :returns: `inputs` divided by the larger of the running norm and `epsilon`
        """
        if training and self.trainable:
            # Reduce the norms of this batch to a single mean value.
            # NOTE(review): masked entries are included in this mean --
            # confirm that is intended.
            norm = custom_norm(inputs)
            norm = tf.math.reduce_mean(norm, keepdims=False)
            # Exponential moving average update of the stored norm.
            self.norm.assign(
                self.momentum*self.norm + (1 - self.momentum)*norm)

        # Clamp the divisor from below so a tiny running norm cannot blow
        # up the output.
        result = inputs/tf.maximum(self.norm, self.epsilon)

        if mask is not None:
            # The mask gets a trailing axis so it broadcasts over the last
            # axis of the data; masked-out entries pass through unchanged.
            return tf.where(mask[..., None], result, inputs)
        return result

    def compute_mask(self, inputs, mask=None):
        """Pass the incoming mask through unchanged."""
        return mask

    def get_config(self):
        """Return the base layer config extended with `momentum` and `epsilon`."""
        result = super().get_config()
        result['momentum'] = self.momentum
        result['epsilon'] = self.epsilon
        return result
# Register the layer in Keras' global custom-object table under its class
# name, so it can be looked up by that name.
keras.utils.get_custom_objects().update(
    {'MomentumLayerNormalization': MomentumLayerNormalization})