Source code for geometric_algebra_attention.tensorflow.internal

import tensorflow as tf

from .. import base
from . import geometric_algebra

class AttentionBase:
    algebra = geometric_algebra

    math = base.Namespace(
        all=tf.reduce_all,
        any=tf.reduce_any,
        asarray=lambda x, ref=None: tf.convert_to_tensor(x),
        bool_to_int=lambda x: tf.cast(x, tf.int8),
        clip=tf.clip_by_value,
        concat=tf.concat,
        logical_and=tf.logical_and,
        named_constant=lambda name, x, ref: tf.constant(x),
        pow=tf.pow,
        product=tf.reduce_prod,
        reshape=tf.reshape,
        shape=tf.shape,
        softmax=tf.nn.softmax,
        sqrt=tf.sqrt,
        sum=tf.reduce_sum,
        tensordot=tf.tensordot,
        where=tf.where,
        zeros_like=tf.zeros_like,
    )

    def __init__(self, n_dim, *args, **kwargs):
        self.n_dim = n_dim

        super().__init__(*args, **kwargs)

        weight_sets = self._build_weight_definitions(n_dim)
        for (name, defs) in weight_sets.groups.items():
            weights = [tf.Variable(
                tf.random.normal(def_.shape, stddev=def_.stdev), name=def_.name,
                trainable=True)
                       for def_ in defs]
            setattr(self, name, weights)

        for (name, def_) in weight_sets.singles.items():
            weight = tf.Variable(
                tf.random.normal(def_.shape, stddev=def_.stdev), name=def_.name,
                trainable=True)
            setattr(self, name, weight)

    def __call__(self, inputs, return_attention=False):
        """Evaluate the attention calculation for this layer."""
        intermediates = self._evaluate(inputs)
        result = [intermediates.output]

        if return_attention:
            result.append(intermediates.attention)

        if len(result) == 1:
            return result[0]

        return tuple(result)