Source code for geometric_algebra_attention.keras.TiedVectorAttention

from tensorflow import keras

from .. import base
from .Vector2VectorAttention import Vector2VectorAttention


class TiedVectorAttention(base.TiedVectorAttention, Vector2VectorAttention):
    # Reuse the docstring of the base implementation rather than repeating it.
    __doc__ = base.TiedVectorAttention.__doc__

    def compute_mask(self, *args, **kwargs):
        """Return the parent class's mask twice, as a 2-tuple.

        All positional and keyword arguments are forwarded unchanged to the
        parent ``compute_mask``; whatever it returns is used for both entries
        of the tuple.

        NOTE(review): the duplication presumably corresponds to this layer
        producing two outputs -- confirm against ``base.TiedVectorAttention``.
        """
        mask = super().compute_mask(*args, **kwargs)
        return (mask, mask)
# Register the class in Keras's global custom-object dictionary under the
# key "TiedVectorAttention" so it can be looked up by that name.
keras.utils.get_custom_objects().update(
    {"TiedVectorAttention": TiedVectorAttention}
)