Source code for geometric_algebra_attention.jax.Multivector2MultivectorAttention

import functools
import operator

import jax

from .. import base
from .MultivectorAttention import MultivectorAttention

[docs] class Multivector2MultivectorAttention(base.Multivector2MultivectorAttention, MultivectorAttention): __doc__ = base.Multivector2MultivectorAttention.__doc__ def __init__(self, score_net, value_net, scale_net, reduce=True, merge_fun='mean', join_fun='mean', rank=2, invariant_mode='single', covariant_mode='partial', include_normalized_products=False, linear_mode='partial', linear_terms=0, convex_covariants=False, **kwargs): self.scale_net_params = self.scale_net_fn = None base.Multivector2MultivectorAttention.__init__( self, scale_net=scale_net, convex_covariants=convex_covariants) MultivectorAttention.__init__( self, score_net=score_net, value_net=value_net, reduce=reduce, merge_fun=merge_fun, join_fun=join_fun, rank=rank, invariant_mode=invariant_mode, covariant_mode=covariant_mode, include_normalized_products=include_normalized_products, linear_mode=linear_mode, linear_terms=linear_terms, **kwargs)
[docs] def stax_init(self, rng, input_shape): super().stax_init(rng, input_shape) rng, next_rng = jax.random.split(rng) _, self.scale_net_params = self.scale_net_init(next_rng, (self.n_dim,)) return input_shape, self.params
@property def params(self): result = super().params result['scale_net_params'] = self.scale_net_params return result @params.setter def params(self, values): for (name, val) in values.items(): setattr(self, name, val) @property def scale_net(self): """Get and set the parameters for the scale-generating function. See the main jax module documentation for more details about these functions. """ kwargs = {} if getattr(self, '_last_rng', None) is not None: mixin = functools.reduce(operator.mul, map(int, 'scale'.encode())) kwargs['rng'] = jax.random.fold_in(self._last_rng, mixin) return functools.partial(self.scale_net_fn, self.scale_net_params, **kwargs) @scale_net.setter def scale_net(self, value): self.scale_net_init, self.scale_net_fn = value