Source code for geometric_algebra_attention.jax.Vector2Multivector
from .. import base
from .internal import AttentionBase
[docs]
class Vector2Multivector(base.Vector2Multivector):
    # Reuse the docstring of the base implementation so both classes
    # present the same documentation.
    __doc__ = base.Vector2Multivector.__doc__

    # Share the `math` attribute defined on this package's AttentionBase.
    math = AttentionBase.math

    @classmethod
    def stax_init(cls, rng, input_shape):
        """Return the ``(output_shape, params)`` pair for this layer.

        ``rng`` is accepted but not used. The given ``input_shape`` is
        returned unchanged as the output shape, together with an empty
        parameter list.
        """
        # NOTE(review): the output shape is reported as identical to the
        # input shape -- confirm this matches what `_evaluate` produces.
        return input_shape, []

    @classmethod
    def stax_apply(cls, params, inputs, rng=None):
        """Evaluate the layer on ``inputs``.

        ``params`` and ``rng`` are accepted but ignored; the result is
        whatever ``cls._evaluate(inputs)`` returns (``_evaluate`` is not
        defined in this class -- presumably inherited from the base class).
        """
        return cls._evaluate(inputs)

    @property
    def stax_functions(self):
        """Return the ``(stax_init, stax_apply)`` pair of bound classmethods."""
        # Both functions must be looked up through `self`: names defined in
        # a class body are not visible as bare names inside its methods, so
        # a bare `stax_apply` here would raise NameError.
        return self.stax_init, self.stax_apply

    def __call__(self, inputs):
        """Apply the layer directly to ``inputs``, passing ``None`` as params."""
        return self.stax_apply(None, inputs)