Source code for geometric_algebra_attention.keras.LabeledVectorAttention


from tensorflow import keras

from .. import base
from .Vector2VectorAttention import Vector2VectorAttention

class LabeledVectorAttention(base.LabeledVectorAttention, Vector2VectorAttention):
    # Copy the docstring of the base implementation so both classes
    # present identical documentation.
    __doc__ = base.LabeledVectorAttention.__doc__

    def build(self, input_shape):
        """Build the layer using only the shape of its second input.

        The entry at index 1 of ``input_shape`` is forwarded, unchanged, to
        the next ``build`` in the method resolution order. Index 0 is not
        consulted here.
        """
        # NOTE(review): presumably the created weights depend only on the
        # second input's structure — confirm against base.LabeledVectorAttention.
        return super().build(input_shape[1])

    def compute_mask(self, inputs, mask=None):
        """Compute the output mask from the masks of this layer's inputs.

        When ``mask`` is None, None is returned. Otherwise ``mask`` must
        unpack into exactly two entries. The first entry is returned as the
        output mask and the second is discarded.
        """
        if mask is not None:
            child_mask, _ = mask
            return child_mask
        return None
# Register the layer under its class name so that Keras can resolve it by
# name (e.g. when deserializing a saved model or config).
keras.utils.get_custom_objects().update({'LabeledVectorAttention': LabeledVectorAttention})