# Source code for geometric_algebra_attention.pytorch.VectorAttention
import torch as pt
from .. import base
from . import geometric_algebra
from .internal import AttentionBase
# [docs]  (Sphinx cross-reference marker from the rendered documentation page)
class VectorAttention(AttentionBase, base.VectorAttention, pt.nn.Module):
    # Reuse the backend-agnostic documentation from the base implementation.
    __doc__ = base.VectorAttention.__doc__

    def __init__(self, n_dim, *args, **kwargs):
        """Initialize the PyTorch vector-attention module.

        :param n_dim: working dimensionality of the network's internal
            values — presumably the width of the value vectors; confirm
            against ``base.VectorAttention``.
        :param args: forwarded to ``base.VectorAttention.__init__``.
        :param kwargs: forwarded to ``base.VectorAttention.__init__``.
        """
        # nn.Module must be initialized before any parameters/buffers
        # are registered on self.
        pt.nn.Module.__init__(self)
        base.VectorAttention.__init__(self, *args, **kwargs)
        self.n_dim = n_dim

        # Build the layer's weights only when this exact class is being
        # instantiated; subclasses are expected to call init() themselves
        # after completing their own setup. Use `is` for the identity
        # comparison of types (PEP 8) rather than `==`.
        if type(self) is VectorAttention:
            self.init()