AtenciΓ³n dispersaΒΆ

El proceso de enmascaramiento en la atenciΓ³n permite ignorar las relaciones de las tΓ³kens actuales con los tΓ³kens futuros. Podemos ver que estΓ‘ el enmascaramiento estΓ‘ ligado a la concepciΓ³n de la atenciΓ³n como un mecanismo grΓ‘fico. En la atenciΓ³n enmascarada podemos suponer que tenemos una grΓ‘fica tal que un elemento $x_i$ estΓ‘ conectado con un elemento $x_j$ si y sΓ³lo si $j \leq i$, siendo $i$ y $j$ la posiciΓ³n absoluta de los tΓ³kens en la oraciΓ³n de entrada.

Child et al. (2019) proponen extender esta concepciΓ³n todavΓ­a mΓ‘s para limitar las conexiones de los tΓ³kens actuales (y por tanto la informaciΓ³n que pueden observar) a elementos que se encuentran en una posiciΓ³n determinada con respecto al elemento actual. Los autores proponer varias formas de conectar estos elementos con respecto a las relaciones de vecindad en la grΓ‘fica, de tal forma que se pueden determinar de manera especΓ­fica las vecindades $\mathcal{N}_i$ de los elementos $x_i$. De tal forma, que la matriz de pesos de atenciΓ³n sΓ³lo tendrΓ‘ valores en las entradas correspondientes a los elementos que estΓ‘n conectados entre sΓ­ y tendrΓ‘ 0 en las otras entradas, funcionando como una matriz de adyacencia de una grΓ‘fica pesada. De la presencia de esto 0's es por lo que los autores llaman atenciΓ³n dispersa a este mecanismo.

En general, podemos pensar a la atenciΓ³n dispersa como un mΓ©todo de auto-atenciΓ³n en grΓ‘ficas en donde se toma en cuenta sΓ³lo los elementos relacionados en una grΓ‘fica no necesariamente conectada por completo. Esto es:

$$h_i = \sum_{j \in\mathcal{N}_i} \alpha_(x_i, x_j) \psi_v(x_j)$$

Para simplificar el cΓ‘lculo de la atenciΓ³n dispersa, empero, Child et al. (2019) proponen el uso de conexiones simples, que no requieran de la especificaciΓ³n explΓ­cita de una estructura de grΓ‘fica. Por ejemplo, los autores proponen algunos casos particulares que se muestran a continuaciΓ³n:

No description has been provided for this image

El caso mΓ‘s sencillo es el tomar en cuenta sΓ³lo los $k$ elementos anteriores al tΓ³ken actual. Es decir, en lugar de relacionar todos los anteriores, se tomarΓ‘ una ventana de tamaΓ±o $k$ para tomar los elementos. De esta forma, la representaciΓ³n de un elemento $x_i$ dependerΓ‘ de la atenciΓ³n puesta sΓ³lo a los $k$ elementos anteriores, a esto le llaman stride attention. En la figura anterior (c) tambiΓ©n se puede ver otra versiΓ³n de atenciΓ³n fija. En este sentido, podemos definir la atenciΓ³n dispersa como sigue:

AtenciΓ³n dispersa: Es un mecanismo de auto-atenciΓ³n que genera matrices de atenciΓ³n dispersas, al considerar que las relaciones entre los datos de entrada $x_1, x_2,..., x_n$ no definen una grΓ‘fica completamente conectada. En particular, podemos definir la atenciΓ³n (Stride) dispersa, cuando las vecindades de los elementos estΓ‘n dadas como: $$\mathcal{N}_i = \{j : max(0, i-k) \leq j \leq i \}$$ Donde $k$ es un hiperparΓ‘metro que determina el nΓΊmero de elementos previos que se consideran.

Para la implementaciΓ³n de esta atenciΓ³n dispersa utilizaremos una mΓ‘scara, pero tomando en cuenta el considerar sΓ³lo los $k$ elementos anteriores. El nΓΊmero $k$ de elementos previos que debemos considerar queda determinado como un hiperparΓ‘metro que llamamos stride. La implementaciΓ³n es similar a la auto-atenciΓ³n enmascarada, aunque tiende a ser mΓ‘s costosa la creaciΓ³n de la mΓ‘scara.

InΒ [1]:
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from seaborn import heatmap as hm
import numpy as np

class SparseAttention(nn.Module):
    #AtenciΓ³n enmascarando subsecuentes
    def __init__(self, d_model, stride=3):
        super(SparseAttention, self).__init__()
        self.d_model = d_model
        self.stride = stride
        self.Q = nn.Linear(d_model, d_model, bias=False)
        self.K = nn.Linear(d_model, d_model, bias=False)
        self.V = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x):
        query, key, value = self.Q(x), self.K(x), self.V(x)
        scores = torch.matmul(query, key.T)/np.sqrt(self.d_model)
        #Enmascaramiento de los scores
        mask  = self.masking(x)
        scores = scores.masked_fill(mask == 0, -1e9)
        att = nn.functional.softmax(scores, dim=-1)
        h = torch.matmul(att, value)

        return h, att

    def masking(self, x):
        #CreaciΓ³n de la mΓ‘scara
        n = x.size(0)
        mask = np.ones((n,n))
        for i in range(0,n):
            for j in range(0,self.stride):
                m = max(0,i-j)
                mask[i,m] = 0
        
        return torch.from_numpy(mask) == 0

Como podemos observar, el modelo de atenciΓ³n dispersa considerarΓ‘ sΓ³lo los $k$ elementos previos, ignorando los elementos subsiguientes y aquellos elementos previos que estΓ©n mΓ‘s allΓ‘ de distancia $k$ al tΓ³ken actual. Si visualizamos la matriz de atenciΓ³n, podemos ver que la matriz serΓ‘ dispersa, conteniendo sΓ³lo valores de probabilidades en la diagonal y las $k$ entradas previas a Γ©sta. Por ejemplo, si consideramos un stride igual a 3, estaremos asumiendo una estructura grΓ‘fica donde cada tΓ³ken conencta sΓ³lo con los 2 anteriores (y el mismo):

No description has been provided for this image

La matriz de atenciΓ³n serΓ­a como se muestra a continuaciΓ³n (donde todavΓ­a no se ha aprendido los pesos adecuados).

InΒ [2]:
model = SparseAttention(128, stride=3)

x = torch.rand(5,128)
labels = ['$w_1$','$w_2$','$w_3$','$w_4$','$w_5$']
h, att = model(x)
hm(att.detach().numpy(), annot=True, xticklabels=labels, yticklabels=labels)
plt.show()
No description has been provided for this image

La atenciΓ³n dispersa, como lo proponen los Child et al. (2019), busca factorizar la atenciΓ³n para enfocarse en los llamados modelos autoregresivos que generen secuencias tanto de texto, como de imΓ‘genes o audio. La atenciΓ³n dispersa tambiΓ©n busca disminuir la complejidad espacial de la atenciΓ³n comΓΊn, pues es claro que una matriz de atenciΓ³n requiere de una complejidad $O(n^2)$ en memoria, mientras que con la atenciΓ³n dispersa, el uso de memoria (usando representaciones dispersas de las matrices de atenciΓ³n) se redice a orden $O(kn)$; Los autres proponen que $k \approx \sqrt{n}$. Sin embargo, la creaciΓ³n de la mΓ‘scara puede tomar mΓ‘s tiempo, aumentando la complejidad temporal. Por ejemplo, en este caso la mΓ‘scara se crea en tiempo $O(kn)$.

AplicaciΓ³n de atenciΓ³n dispersaΒΆ

Para aplicar la atenciΓ³n dispersa podemos utilizar el mismo ejemplo que hemos venido trabajando sobre un modelo del lenguaje. Utilizamos multi-cabezas, tomando en cada cabeza un tamaΓ±o de stride incremental, de tal forma que la primera cabeza sΓ³lo considera al elemento para la representaciΓ³n, la segunda cabeza considera al elemento anterior, la tercera a los dos elementos anteriores, etc. Este es el ΓΊnico cambo que introducimos en la implementaciΓ³n:

InΒ [3]:
import copy

class MultiHeadMaskAttention(nn.Module):
    def __init__(self, in_size, d_model, hidden=128, heads=3, dropout=0.3):
        super(MultiHeadMaskAttention, self).__init__()
        self.d_model = d_model
        self.enc = Encoding(in_size, d_model)
        #Uso de atenciΓ³n dispersa con stride incremental
        self.att = nn.ModuleList([copy.deepcopy(SparseAttention(d_model, stride=i+1)) for i, _ in enumerate(range(heads))])
        self.lin = nn.Linear(heads*d_model, d_model, bias=True)
        self.norm = LayerNorm(d_model)
        self.ffw = nn.Sequential(nn.Linear(d_model, hidden), nn.ReLU(),
                                nn.Linear(hidden, d_model))
        self.drop1 = nn.Dropout(p=dropout)
        self.drop2 = nn.Dropout(p=dropout)
        self.drop3 = nn.Dropout(p=dropout)
    
    def forward(self, x):
        x_e = self.enc(x)
        x_e = self.drop1(x_e)
        head_att = [head(x_e) for head in self.att]
        self.att_weights = [head[1] for head in head_att]
        heads = [head[0] for head in head_att]
        multi_heads = torch.cat(heads, dim=-1)
        h = self.lin(multi_heads)
        h_norm = x_e + self.norm(h)
        h_norm = self.drop2(h_norm)
        out = self.ffw(h)
        
        return self.drop3(h_norm + self.norm(out))

Datos para el entrenamientoΒΆ

Los datos de entrenamiento tienen las mismas caracterΓ­sticas que anteriormente: se busca predecir la siguiente palabra, pero ahora el contexto de predicciΓ³n serΓ‘ limitado por el tamaΓ±o del estride; sin embargo, debe notarse que esto no estima una probabilidad de ngramas, pues la informaciΓ³n se transmite entre los diferentes elementos. Por ejemplo, si el stride es 3, entonces la representaciΓ³n del cuarto tΓ³ken $w_4$ dependerΓ‘ de $w_3$ y $w_2$, pero no de $w_1$. Pero tanto la representaciΓ³n de $w_3$ y $w_2$ sΓ­ dependen de $w_1$ por lo que la informaciΓ³n de este tΓ³ken se pasa de manera indirecta.

InΒ [4]:
import pandas as pd
from tqdm import tqdm
from transformers import *

#Corpus a utilizar
corpus = ['el perro come un hueso', 'un muchacho jugaba', 'el muchacho saltaba la cuerda',
          'un perro come croquetas', 'el perro come', 'el gato come croquetas', 
          'un gato come', 'un muchacho jugaba con la cuerda', 'el muchacho jugaba con la cuerda']
corpus = [w.split() for w in corpus]
#CreaciΓ³n del vocabulario
voc = vocab()
voc['[bos]'] = 0
voc['[eos]'] = 1
#IndexaciΓ³n de cadenas
sents = list(index(corpus, voc))

#Pares de entrenamiento
x = [torch.cat((torch.tensor([voc['[bos]']]),s), axis=0) for s in sents]
y = [torch.cat((s, torch.tensor([voc['[eos]']])), axis=0) for s in sents]
print(x[0], y[0])
tensor([0, 2, 3, 4, 5, 6]) tensor([2, 3, 4, 5, 6, 1])

Entrenamiento del modeloΒΆ

Ya con los datos de entrenamiento creados, podemos definir el modelo. Ya que en la arquitectura del modelo hemos determinado que el stride sea incremental en base al nΓΊmero de la cabeza, elejimos 4 cabezales para que el stride tome a lo mΓ‘s 4 elementos previos en la ΓΊltima cabeza.

InΒ [5]:
len_voc = len(voc)
model = nn.Sequential(MultiHeadMaskAttention(len_voc, 128, heads=4), 
                      nn.Linear(128,len_voc), nn.Softmax(1)) 

#Carga del modelo
model.load_state_dict(torch.load('sparse.model'))
model.eval()
Out[5]:
Sequential(
  (0): MultiHeadMaskAttention(
    (enc): Encoding(
      (emb): Embedding(15, 128)
      (pe): PositionalEncoding()
    )
    (att): ModuleList(
      (0-3): 4 x SparseAttention(
        (Q): Linear(in_features=128, out_features=128, bias=False)
        (K): Linear(in_features=128, out_features=128, bias=False)
        (V): Linear(in_features=128, out_features=128, bias=False)
      )
    )
    (lin): Linear(in_features=512, out_features=128, bias=True)
    (norm): LayerNorm()
    (ffw): Sequential(
      (0): Linear(in_features=128, out_features=128, bias=True)
      (1): ReLU()
      (2): Linear(in_features=128, out_features=128, bias=True)
    )
    (drop1): Dropout(p=0.3, inplace=False)
    (drop2): Dropout(p=0.3, inplace=False)
    (drop3): Dropout(p=0.3, inplace=False)
  )
  (1): Linear(in_features=128, out_features=15, bias=True)
  (2): Softmax(dim=1)
)

Ahora entrenamos el modelo con los datos de entrenamiento. Utilizamos el optimizador Noam que se revisarΓ‘ mΓ‘s adelante. Podemos guardar el modelo para cargarlo en posteriores pruebas.

InΒ [6]:
criterion = nn.CrossEntropyLoss()
optimizer = NoamOptimizer(model.parameters(), model[0].d_model, decay=0.01)
epochs = range(100)

#Entrenamiento
model.train()
for t in tqdm(epochs):
    for i in torch.randperm(len(x)):
        prediction = model(x[i])
        optimizer.zero_grad()
        loss_value = criterion(prediction, y[i])
        loss_value.backward()
        optimizer.step()

#torch.save(model.state_dict(), 'model.model')
/home/cienciasia/anaconda3/lib/python3.11/site-packages/torch/cuda/__init__.py:619: UserWarning: Can't initialize NVML
  warnings.warn("Can't initialize NVML")
100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 100/100 [00:09<00:00, 10.02it/s]

ExploraciΓ³n del modeloΒΆ

Al igual que en otros ejemplos, podemos comenzar a explorar quΓ© es lo que estΓ‘ aprendiendo el modelo; podemos observar las probabilidades de este modelo dado un contexto previo. Este tipo de modelos nos permitirΓ‘ predecir una palabra siguiente que se adapta al contexto pevio. En el caso de la auto-atenciΓ³n enmascarada se toma todo los elementos previos para cada caso; aquΓ­ sΓ³lo se toman aquellos elementos que estΓ©n en la ventana para la representaciΓ³n. Esto puede tener cierta influencia en los cΓ‘lculos de las probabilidades.

InΒ [7]:
devoc = {i:t for t,i in voc.items()}
def result(text, model):
    #FunciΓ³n para predecir la siguiente palabra dado el contexto
    tokens = text.split()
    x = torch.tensor([voc[t] for t in tokens])
    pred = model(x)
    max_token = pred.argmax(axis=1).detach().numpy()
    
    return pred.detach().numpy(), ' '.join([devoc[i] for i in max_token])

p, pred_text = result('[bos]', model)
print('Palabra siguiente con mayor prob: {}'.format(pred_text))

#VisualizaciΓ³n de probabilidades mΓ‘s altas
args = np.argsort(p[-1])[::-1]
probs = np.sort(p[-1])[::-1]
pd.DataFrame(data=probs, columns=['prob. tΓ³ken'], index=[devoc[j] for j in args]).plot.bar()
plt.show()
Palabra siguiente con mayor prob: el
No description has been provided for this image

Finalmente, nuestro interΓ©s radica en explorar el tipo de matrices de atenciΓ³n que se estΓ‘n obteniendo. Como lo deciamos, hemos definido un stride incremental, por lo que la primera cabeza tendrΓ‘ una matriz diagonal. En este caso, realmente no se estΓ‘ haciendo un proceso de atenciΓ³n, pues el elemento se estΓ‘ representando sΓ³lo por sΓ­ mismo o, de forma mΓ‘s especΓ­fica, por su proyecciΓ³n en el espacio de valores; es decir, tenemos que: $h_i = \psi_v(x_i)$ con un peso de atenciΓ³n de probabilidad 1. Los otros casos son mΓ‘s interesante, pues en la segunda cabeza, cada elemento se representarΓ‘ por sΓ­ mismo y el elemento anterior. En este caso particular, parece que todos las probabilidades de los pesos de atenciΓ³n son cercanas a $\frac{1}{2}$; es decir, cada elemento se representa con la informaciΓ³n tanto de sΓ­ mismo como del elemento anterior por igual. En la cabeza 3 se representarΓ‘ por los 2 elementos anteriores, y finalmente en la cabeza 4 se representarΓ‘ por los 4 elementos anteriores, que en el ejemplo particular que aquΓ­ usamos corresponde a todos los elementos previos.

InΒ [8]:
text = '[bos] un gato come'
result(text, model)

for i, att_w in enumerate(model[0].att_weights):
    hm(att_w.detach().numpy(), xticklabels=text.split(), yticklabels=text.split(), vmin=0, vmax=1)
    plt.title('AtenciΓ³n en cabeza %i' %i)
    plt.show()
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image

ReferenciasΒΆ

Child, R., Gray, S., Radford, A., & Sutskever, I. (2019). Generating long sequences with sparse transformers. arXiv preprint arXiv:1904.10509.

Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A. N., ... & Polosukhin, I. (2017). Attention is all you need. Advances in neural information processing systems, 30.


Principal