Math and science::INF ML AI
Attention layer
Do you remember the implementation of Attention in PyTorch? dummy cloze
import torch
import torch.nn as nn
import einops


class Attention(nn.Module):
    def __init__(self, dim, heads=4, dropout=0.0, project_out=True):
        super().__init__()
        self.project_out = project_out
        if dim % heads != 0:
            raise ValueError(
                f"dim ({dim}) must be divisible by heads ({heads})"
            )
        self.heads = heads
        self.dim_head = dim // heads
        self.scale = self.dim_head**-0.5  # 1/sqrt(dim_head) scaling for the dot products
        hidden_dim = self.dim_head * heads
        self.to_qkv = nn.Linear(dim, hidden_dim * 3, bias=False)  # fused Q/K/V projection
        self.to_out = nn.Sequential(
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        ) if self.project_out else nn.Identity()
        self.norm = nn.LayerNorm(dim)
        self.dropout = nn.Dropout(dropout)
    def forward(self, x):
        # x: (batch, sequence, dim) -- LayerNorm and Linear act on the last (channel) axis
        b, n, c = x.shape
        x = self.norm(x)
        qkv = self.to_qkv(x).chunk(3, dim=-1)  # three tensors of shape (b, n, heads * dim_head)
        q, k, v = map(
            lambda t: einops.rearrange(t, "b n (h d) -> b h d n", h=self.heads),
            qkv,
        )
        q = q * self.scale
        # similarity of every query position i with every key position j
        raw_attn = torch.einsum("b h d i, b h d j -> b h i j", q, k)
        attn = raw_attn.softmax(dim=-1)
        attn = self.dropout(attn)
        # attention-weighted sum of values for each query position
        out = torch.einsum("b h i j, b h d j -> b h i d", attn, v)
        out = einops.rearrange(out, "b h n d -> b n (h d)")
        return self.to_out(out)
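
A quick end-to-end check (a minimal sketch, not part of the original card; the batch, sequence, and channel sizes below are arbitrary, chosen only to satisfy the divisibility check):

# Tokens are laid out as (batch, sequence, dim), matching the forward pass above.
attn = Attention(dim=64, heads=4, dropout=0.1)
x = torch.randn(2, 16, 64)  # 2 sequences of 16 tokens with 64 channels each
y = attn(x)
print(y.shape)  # torch.Size([2, 16, 64]) -- the output shape matches the input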