# Source code for geometric_algebra_attention.keras.TiedMultivectorAttention
from tensorflow import keras
from .. import base
from .Multivector2MultivectorAttention import Multivector2MultivectorAttention
class TiedMultivectorAttention(
    base.TiedMultivectorAttention, Multivector2MultivectorAttention
):
    # Keras flavor of base.TiedMultivectorAttention, assembled by multiple
    # inheritance: the framework-agnostic base class comes first in the MRO,
    # followed by the Keras Multivector2MultivectorAttention layer.
    # The user-facing docstring is reused verbatim from the base class.
    __doc__ = base.TiedMultivectorAttention.__doc__

    def compute_mask(self, *args, **kwargs):
        """Compute the output mask and return it once per output.

        All arguments are forwarded unchanged to the next ``compute_mask``
        in the MRO. The single mask obtained there is returned twice as a
        2-tuple, so both outputs share the same mask.

        Returns:
            tuple: ``(mask, mask)``, where ``mask`` is whatever the parent
            ``compute_mask`` returned (possibly ``None``).
        """
        result = super().compute_mask(*args, **kwargs)
        # NOTE(review): a 2-tuple mask implies this layer emits two outputs
        # that are masked identically ("tied"); confirm the output order
        # against base.TiedMultivectorAttention.
        return result, result
# Register the layer under its plain class name so that saved Keras models
# referencing "TiedMultivectorAttention" can be deserialized.
keras.utils.get_custom_objects().update(
    {"TiedMultivectorAttention": TiedMultivectorAttention}
)