Gradientes en capas de atención¶

Para comprende de mejor manera la derivación sobre la capa de auto-atención, conviene visualizar estas capas como gráficas computacionales; es decir, como aplicación de funciones que requieren del cálculo de funciones previas. Por ejemplo, en la auto-atención se realizan varios cómputos:

  • Proyecciones lineales: En primer lugar, tenemos que los vectores de query, key y value se obtienen a partir de capas lineales: capas donde se multiplica por una matriz. En general, estas capas no suelen tener un bias. Por lo que las capas de este tipo son de la forma: $$\psi(x_i) = Wx_i$$ Cada una de estas capas se debe estimar para los valores de query, key y value de manera independiente. Para simplicar, denotaremos cada una de estas capas como $q_i, k_i$ y $v_i$, respectivamente.
  • Producto escalar escalado: Una vez que se han estimado los valores para query y key, se realiza el cálculo de los valores que entrarán en el cálculo de las probabilidades Softmax por medio de un producto punto escalado, el cual es de la forma: $$\epsilon_{i,j} = \frac{q_i \cdot k_j}{\sqrt{d}}$$ El valor de $d$ corresponde a la dimensión del modelo; es decir, la dimensión de los vectores del value.
  • Probabilidad Softmax: Una vez que se han obtenido los valores del producto punto escalado, se obtienen las probabilidades con la función Softmax, el cálculo es de la siguiente forma: $$\alpha\big(q_i, k_j \big) = \frac{e^{\epsilon_{i,j}}}{\sum_l e^{\epsilon_{i,l}}}$$ A estos valores se les conoce como pesos de atención. Los pesos de atención conforman una matriz de atención.
  • Representación: Finalmente, se obtiene la representación de cada uno de las entradas a partir de los pesos de atención de la siguiente forma: $$h_i = \sum_s \alpha\big(q_i, k_j) \big) v_j$$ Es decir, la representación que se obtiene de la capa de auto-atención es una combinación lineal de los elementos de la entrada ponderados por los pesos de atención.

En la capa de auto-atención, tenemos que derivar desde los valores de salida, que corresponden a la representación $h_t$ y actualizar los pesos de la capa, los cuáles corresponden a las capas lineales en cada uno de los valores de query, key y value. Estos son los únicos pesos con los que la capa cuenta. Por tanto, debemos retropropagar la derivada hasta estos pesos.

La capa de auto-atención trabaja con varios valores de salida, y por tanto, podemos pensar su derivada de forma secuencial. Por lo que podemos pensar que tiene una función de riesgo similar a las redes recurrentes:

$$R(\theta) = \sum_t \sum_y L_t(f_y(x_t), y)$$

Recuérdese que $L_t$ es la función de pérdida, $f_y(x_t)$ es la salida en el tiempo $t$ en la neurona $y$.

En la forma en que hemos construido la capa de atención, los pesos más próximos corresponden a los valores del value. En el elemento $t$, tenemos:

$$\frac{\partial L_t(f_y(x_t), y)}{\partial W_{i,j}^v} = \frac{\partial L_t(f_y(x_t), y)}{\partial h_t} \frac{\partial h_{t}}{\partial W^v_{i,j}}$$

Como en casos previamente vistos, $\frac{\partial L_t(f_y(x_{t}), y)}{\partial h_{t}}$ denota la derivada que se retropropaga desde las capas superiores. Por tanto, debemos enfocarnos en la derivada en $\frac{\partial h_{t}}{\partial W^v_{i,j}}$, donde tenemos que:

\begin{align*} \frac{\partial h_{t}}{\partial W^v_{i,j}} &= \frac{\partial h_{t}}{\partial v_{s,i}} \frac{\partial v_{s,i}}{\partial W_{i,j}^v} \\ &= \sum_s \alpha\big(q_t, k_{s} \big) x_{s,j} \end{align*}

La derivada en este caso es simple, pues el valor de value se obtiene de forma lineal. En el caso del key y el query, la derivada tiene que pasar por la función Softmax y el producto punto escadalado. Ambas derivadas son similares. Para el caso del key tenemos:

\begin{align*} \frac{\partial h_{t}}{\partial W^k_{i,j}} &= \frac{\partial h_{t}}{\alpha(\epsilon_{t,s})} \frac{\partial \alpha(\epsilon_{t,s})}{\partial \epsilon_{t,s}} \frac{\partial \epsilon_{t,s}}{ \partial W_{i,j}^k} \\ &= \frac{1}{\sqrt{d}}\sum_s v_{s,i} \alpha(\epsilon_{t,s})\big( \delta_{t,s} - \alpha(\epsilon_{t,s}) \big) q_{t,i} x_{t,j} \end{align*}

Ya que el query multiplica al key en el producto punto escalado, se puede ver que la derivada para el query es:

\begin{align*} \frac{\partial h_{t}}{\partial W^q_{i,j}} &= \frac{\partial h_{t}}{\alpha(\epsilon_{t,s})} \frac{\partial \alpha(\epsilon_{t,s})}{\partial \epsilon_{t,s}} \frac{\partial \epsilon_{t,s}}{ \partial W_{i,j}^q} \\ &= \frac{1}{\sqrt{d}}\sum_s v_{s,i} \alpha(\epsilon_{t,s})\big( \delta_{t,s} - \alpha(\epsilon_{t,s}) \big) k_{t,i} x_{t,j} \end{align*}

Asimismo, las derivadas en la capa de auto-atención pueden retropropagar a una capa anterior tomando las derivadas anteriores (omitiendo los valores de $x_{s,j}$. De esta forma, las capa de auto-atención pueden integrarse en arquitecturas más complejas. En particular, este tipo de capas conforman la parte central de las arquitecturas de Transformers (transformadores), las cuales revisamos a continuación.

Referencias¶

Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A. N., ... & Polosukhin, I. (2017). Attention is all you need. Advances in neural information processing systems, 30.


Principal