Source code for geometric_algebra_attention.pytorch.MultivectorAttention

import torch as pt

from .. import base
from . import geometric_algebra
from .internal import AttentionBase

[docs] class MultivectorAttention(AttentionBase, base.MultivectorAttention, pt.nn.Module): __doc__ = base.MultivectorAttention.__doc__ def __init__(self, n_dim, *args, **kwargs): pt.nn.Module.__init__(self) base.MultivectorAttention.__init__(self, *args, **kwargs) self.n_dim = n_dim if type(self) == MultivectorAttention: self.init()