# PyTorch implementation of geometric_algebra_attention.pytorch.TiedMultivectorAttention
from .. import base
from .Multivector2MultivectorAttention import Multivector2MultivectorAttention
class TiedMultivectorAttention(
    base.TiedMultivectorAttention, Multivector2MultivectorAttention
):
    # Reuse the base class's documentation verbatim for this PyTorch layer.
    __doc__ = base.TiedMultivectorAttention.__doc__

    def __init__(
        self,
        n_dim,
        score_net,
        value_net,
        scale_net,
        reduce=True,
        merge_fun="mean",
        join_fun="mean",
        rank=2,
        invariant_mode="single",
        covariant_mode="partial",
        include_normalized_products=False,
        convex_covariants=False,
        linear_mode='partial',
        linear_terms=0,
        **kwargs
    ):
        # Gather every named constructor option under its own parameter
        # name; any extra keyword arguments are appended afterwards as-is.
        options = dict(
            n_dim=n_dim,
            score_net=score_net,
            value_net=value_net,
            scale_net=scale_net,
            reduce=reduce,
            merge_fun=merge_fun,
            join_fun=join_fun,
            rank=rank,
            invariant_mode=invariant_mode,
            covariant_mode=covariant_mode,
            include_normalized_products=include_normalized_products,
            convex_covariants=convex_covariants,
            linear_mode=linear_mode,
            linear_terms=linear_terms,
        )
        options.update(kwargs)

        # All setup is delegated to the PyTorch
        # Multivector2MultivectorAttention initializer, invoked explicitly
        # on that class rather than through super().
        Multivector2MultivectorAttention.__init__(self, **options)

        # init() is triggered only when this exact class is instantiated;
        # subclasses (whose exact type differs) skip it here, presumably so
        # they can complete their own setup first -- TODO confirm.
        if type(self) is TiedMultivectorAttention:
            self.init()