Source code for geometric_algebra_attention.keras.MultivectorAttention


from tensorflow import keras

from .. import base
from .internal import AttentionBase

class MultivectorAttention(
        AttentionBase, base.MultivectorAttention, keras.layers.Layer):
    # Reuse the framework-agnostic documentation from the base implementation
    # rather than duplicating it here.
    __doc__ = base.MultivectorAttention.__doc__

    def __init__(self, score_net, value_net, reduce=True,
                 merge_fun='mean', join_fun='mean', rank=2,
                 invariant_mode='single', covariant_mode='single',
                 include_normalized_products=False,
                 linear_mode='partial', linear_terms=0,
                 **kwargs):
        """Initialize the Keras layer, then the shared attention base.

        All named parameters are forwarded unchanged (positionally, in the
        order listed in the signature) to
        ``base.MultivectorAttention.__init__``; see the class documentation
        for their meaning. Any additional ``**kwargs`` are passed to
        ``keras.layers.Layer.__init__``.
        """
        # The Keras layer is initialized first, before the base class
        # initializer runs.
        keras.layers.Layer.__init__(self, **kwargs)
        base.MultivectorAttention.__init__(
            self,
            score_net,
            value_net,
            reduce,
            merge_fun,
            join_fun,
            rank,
            invariant_mode,
            covariant_mode,
            include_normalized_products,
            linear_mode,
            linear_terms,
        )
# Register the layer in Keras' global custom-object registry under its class
# name at import time.
keras.utils.get_custom_objects().update(
    {'MultivectorAttention': MultivectorAttention})