# Source code for geometric_algebra_attention.keras.LabeledMultivectorAttention
from tensorflow import keras
from .. import base
from .Multivector2MultivectorAttention import Multivector2MultivectorAttention
class LabeledMultivectorAttention(base.LabeledMultivectorAttention, Multivector2MultivectorAttention):
    # The class documentation is shared with the backend-agnostic base
    # implementation rather than duplicated here (an inline docstring would be
    # overwritten by this assignment anyway).
    __doc__ = base.LabeledMultivectorAttention.__doc__

    def build(self, input_shape):
        """Build the layer from the shape of the second input only.

        `input_shape` is indexed like a sequence; its element at index 1 is
        handed to the parent class's `build`, and whatever that returns is
        returned here.

        NOTE(review): element 0 is presumably the shape of the child/label
        values, which the parent build does not need -- confirm against
        `base.LabeledMultivectorAttention`.
        """
        return super().build(input_shape[1])

    def compute_mask(self, inputs, mask=None):
        """Propagate the mask that belongs to the first (child) input.

        Args:
            inputs: Unused; present to match the Keras `compute_mask` signature.
            mask: Either None or a two-element sequence
                `(child_mask, other_mask)`.

        Returns:
            None when no mask is given, otherwise the first (child) mask.
            A `mask` that does not unpack into exactly two entries raises
            `ValueError` from the tuple unpacking.
        """
        if mask is not None:
            child_mask, _ = mask
            return child_mask
        return None
# Register the layer by name in Keras' global custom-object registry so it
# can be resolved by name without passing `custom_objects` explicitly.
keras.utils.get_custom_objects().update(
    {'LabeledMultivectorAttention': LabeledMultivectorAttention}
)