Source code for geometric_algebra_attention.keras.MomentumNormalization


import tensorflow as tf
from tensorflow import keras

class MomentumNormalization(keras.layers.Layer):
    """Exponential decay normalization.

    Keeps exponential moving averages of the mean and standard deviation
    computed over all axes but the last, and uses them to normalize values
    to have mean 0 and variance 1; suitable for normalizing a vector of
    real-valued quantities with differing units.

    The moving averages are stored as non-trainable weights (``mu`` and
    ``sigma``, one entry per element of the last axis) and are only updated
    when the layer is called with ``training=True`` and the layer itself is
    marked trainable. When a mask is supplied, masked-out elements are
    excluded from the statistics and are passed through unchanged.

    :param momentum: Momentum of moving average, from 0 to 1
    :param epsilon: Small constant added to the running standard deviation before dividing, keeping the scaling factor away from zero
    :param use_mean: If True (default), calculate and apply a mean shift
    :param use_std: If True (default), calculate and apply a standard deviation scaling factor

    """

    def __init__(self, momentum=.99, epsilon=1e-7, use_mean=True,
                 use_std=True, *args, **kwargs):
        # Initialize the base layer first so that Keras' attribute tracking
        # is in place before any attributes are assigned.
        super().__init__(*args, **kwargs)
        self.momentum = momentum
        self.epsilon = epsilon
        self.use_mean = use_mean
        self.use_std = use_std
        self.supports_masking = True

    def build(self, input_shape):
        """Create the running statistics, one entry per last-axis element."""
        feature_shape = [input_shape[-1]]
        self.mu = self.add_weight(
            name='mu', shape=feature_shape, initializer='zeros',
            trainable=False)
        self.sigma = self.add_weight(
            name='sigma', shape=feature_shape, initializer='ones',
            trainable=False)
        # Statistics are reduced over every axis except the last one.
        self._summary_axes = tuple(range(len(input_shape) - 1))

    def call(self, inputs, training=False, mask=None):
        """Normalize `inputs` with the running mean and standard deviation.

        :param inputs: Tensor whose last axis indexes the quantities to normalize
        :param training: If truthy (and the layer is trainable), update the running statistics from `inputs` before normalizing
        :param mask: Optional boolean mask with one fewer axis than `inputs`; False entries are ignored for statistics and returned unnormalized
        :returns: Tensor with the same shape as `inputs`
        """
        if training and self.trainable:
            # Drop masked-out elements so they do not bias the statistics.
            if mask is None:
                values = inputs
            else:
                values = tf.ragged.boolean_mask(inputs, mask=mask)

            mean = tf.math.reduce_mean(
                values, axis=self._summary_axes, keepdims=False)
            std = tf.math.reduce_std(
                values, axis=self._summary_axes, keepdims=False)

            decay = self.momentum
            self.mu.assign(decay*self.mu + (1 - decay)*mean)
            self.sigma.assign(decay*self.sigma + (1 - decay)*std)

        # `use_mean`/`use_std` are static Python settings, so branch on them
        # directly rather than multiplying by float32 0/1 switches: this
        # keeps the layer usable with non-float32 dtypes and prevents a
        # disabled (possibly non-finite) statistic from touching the output.
        result = inputs
        if self.use_mean:
            result = result - self.mu
        if self.use_std:
            result = result/(self.sigma + self.epsilon)

        if mask is not None:
            # Leave masked-out elements exactly as they came in.
            return tf.where(mask[..., None], result, inputs)
        return result

    def compute_mask(self, inputs, mask=None):
        """Propagate the incoming mask unchanged."""
        return mask

    def get_config(self):
        """Return the base layer config extended with this layer's settings."""
        result = super().get_config()
        result['momentum'] = self.momentum
        result['epsilon'] = self.epsilon
        result['use_mean'] = self.use_mean
        result['use_std'] = self.use_std
        return result

# Add the layer to Keras' global custom-object registry under the key
# 'MomentumNormalization'.
keras.utils.get_custom_objects().update(
    {'MomentumNormalization': MomentumNormalization})