Implementación de Word2Vec (Softmax)¶

En primer lugar, definimos las funciones necesarias: para stemming y para indexar el vocabulario.

In [31]:
#-*- encoding:utf-8 -*-
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from collections import defaultdict, Counter
from nltk.stem.snowball import SnowballStemmer
from itertools import chain
from re import sub
from sklearn.decomposition import PCA
from operator import itemgetter
from nltk.corpus import cess_esp
from sklearn.model_selection import train_test_split
from nltk.tokenize import sent_tokenize, word_tokenize
from re import compile, sub

#Declaramos el stemeer
stemizer = SnowballStemmer('spanish')

#Funcion para stemming
def stems(string):
  stem_string = []
  for w in string:
    stem_string.append(stemizer.stem(w))
  return stem_string

#Funcion que crea un vocabulario de palabras con un indice numerico
def vocab():
    vocab = defaultdict()
    vocab.default_factory = lambda: len(vocab)
    return vocab    

#Funcion que pasa la cadena de simbolos a una secuencia con indices numericos
def text2numba(corpus, vocab):
    for doc in corpus:
        yield [vocab[w] for w in doc] #.split()]
  

Abrimos el corpus que bamos a utilizar y obtenemos los contextos.

In [32]:
#corpus = ['el perro come un hueso', 'un muchacho jugaba', 'el muchacho saltaba la cuerda',
#          'un perro come croquetas', 'un muchacho juega']

corpus = sent_tokenize(open('borges_aleph.txt','r', encoding='utf8').read().strip())
print(corpus[0])
La candente mañana de febrero en que Beatriz Viterbo murió, después de una imperiosa agonía que no se rebajó un solo instante ni al sentimentalismo ni al miedo, noté que las carteleras de fierro de la Plaza Constitución habían renovado no sé qué aviso de cigarrillos rubios; el hecho me dolió, pues comprendí que el incesante y vasto universo ya se apartaba de ella y que ese cambio era el primero de una serie infinita.

Posteriormente, limpiamos y stemizamos el corpus. Asimismo, creamos el vocabulario con índices numéricos y transformamos las cadenas de palabras encadenas de índices numéricos. Se crean venanas de $1\times 1$

In [33]:
#Abrimos el documento, lo limpiamos y separamos las cadenas
#corpus = sub(r'[^\w\s]','',uploaded['corpus.es'].decode('utf8').strip().lower()).split('\n')
regex = compile('[^a-zA-Z| |ñáéíóú]')
#Stemizamos el documento
#corpus = [sent.split() for sent in corpus] #[stems(sent) for sent in corpus]
corpus = [stems(word_tokenize(regex.sub('', sent).lower())) for sent in corpus]

#Llamamos la funcion para crear el vocabulario
idx = vocab()
#Creamos el vocabulario y le asignamos un indice a cada simbolo segun su aparicion
cads_idx = list(text2numba(corpus,idx))

print(corpus[0])
print(cads_idx[0])
['la', 'candent', 'mañan', 'de', 'febrer', 'en', 'que', 'beatriz', 'viterb', 'mur', 'despues', 'de', 'una', 'imperi', 'agon', 'que', 'no', 'se', 'rebaj', 'un', 'sol', 'instant', 'ni', 'al', 'sentimental', 'ni', 'al', 'mied', 'not', 'que', 'las', 'carteler', 'de', 'fierr', 'de', 'la', 'plaz', 'constitu', 'hab', 'renov', 'no', 'se', 'que', 'avis', 'de', 'cigarrill', 'rubi', 'el', 'hech', 'me', 'dol', 'pues', 'comprend', 'que', 'el', 'inces', 'y', 'vast', 'univers', 'ya', 'se', 'apart', 'de', 'ella', 'y', 'que', 'ese', 'cambi', 'era', 'el', 'primer', 'de', 'una', 'seri', 'infinit']
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 3, 11, 12, 13, 6, 14, 15, 16, 17, 18, 19, 20, 21, 22, 20, 21, 23, 24, 6, 25, 26, 3, 27, 3, 0, 28, 29, 30, 31, 14, 15, 6, 32, 3, 33, 34, 35, 36, 37, 38, 39, 40, 6, 35, 41, 42, 43, 44, 45, 15, 46, 3, 47, 42, 6, 48, 49, 50, 35, 51, 3, 11, 52, 53]
In [4]:
cadenas = cads_idx 

#Se obtiene la longitud del alfabeto
N = len(idx)

print(cadenas[0])
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 3, 11, 12, 13, 6, 14, 15, 16, 17, 18, 19, 20, 21, 22, 20, 21, 23, 24, 6, 25, 26, 3, 27, 3, 0, 28, 29, 30, 31, 14, 15, 6, 32, 3, 33, 34, 35, 36, 37, 38, 39, 40, 6, 35, 41, 42, 43, 44, 45, 15, 46, 3, 47, 42, 6, 48, 49, 50, 35, 51, 3, 11, 52, 53]

Hecho esto, extraemos los bigramas del texto.

In [5]:
#Se crean los bigramas
contexts = list(chain(*[zip(cad,cad[1:]) for cad in cadenas])) + list(chain(*[zip(cad[1:],cad) for cad in cadenas]))

#Se obtiene la frecuencia de cada bigrama
frecContexts = Counter(contexts)

print(contexts[:100])
[(0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (5, 6), (6, 7), (7, 8), (8, 9), (9, 10), (10, 3), (3, 11), (11, 12), (12, 13), (13, 6), (6, 14), (14, 15), (15, 16), (16, 17), (17, 18), (18, 19), (19, 20), (20, 21), (21, 22), (22, 20), (20, 21), (21, 23), (23, 24), (24, 6), (6, 25), (25, 26), (26, 3), (3, 27), (27, 3), (3, 0), (0, 28), (28, 29), (29, 30), (30, 31), (31, 14), (14, 15), (15, 6), (6, 32), (32, 3), (3, 33), (33, 34), (34, 35), (35, 36), (36, 37), (37, 38), (38, 39), (39, 40), (40, 6), (6, 35), (35, 41), (41, 42), (42, 43), (43, 44), (44, 45), (45, 15), (15, 46), (46, 3), (3, 47), (47, 42), (42, 6), (6, 48), (48, 49), (49, 50), (50, 35), (35, 51), (51, 3), (3, 11), (11, 52), (52, 53), (49, 35), (35, 44), (44, 54), (54, 55), (55, 14), (14, 56), (56, 57), (57, 58), (58, 59), (59, 60), (60, 61), (61, 62), (62, 15), (15, 63), (63, 64), (64, 65), (65, 0), (0, 30), (30, 66), (66, 67), (67, 55), (55, 68), (68, 69), (69, 70), (70, 71), (71, 72)]

Ahora, paamos a la rd neuronla. Primero, inicializamos los parámetros de la red. Requerimos de dos matrices U (la matriz de embedding) y W (la matriz de la capa de salida).

In [6]:
np.random.seed(0)
#El número de rasgos que representan cada vector
nn_input_dim = N
#El total de clases que arrojará
output_dim = N
#El número de ejmplos
num_examples = len(contexts)

#Dimensiones de los vectores-palabra
dim = 2

#Embedding
C = np.random.randn(dim, N) / np.sqrt(N)

#Capa de salida
W = np.random.randn(N,dim) / np.sqrt(dim)

print(C.shape)
print(W.shape)
(2, 1524)
(1524, 2)

Ahora entrenamos la red con el algoritmo de backpropagation y de gradiente descendiente.

In [7]:
%%time
#Guarda el riesgo por época
R = []

#Hiperparámetros
its = 100
eta = 0.1
for i in range(0,its):
    #Acumula el riesgo de los 
    #ejemplos en la época actual
    R_x = 0
    for ex in contexts:
        #Forward
        #Embedimiento
        u_w = C.T[ex[0]]
        #salida
        a = np.dot(W,u_w)
        out = np.exp(a) # - np.max(a))
        #Softmax
        f = out/out.sum(0)
        #Suma riesgo por ejemplo
        R_x += -np.log(f)[ex[1]]

        #Backprop
        #Variable de salida
        d_out = f
        d_out[ex[1]] -= 1
        
        #Variable de embedding
        d_emb = np.dot(d_out,W)
        
        #Actualizacion de salida
        W -= eta*np.outer(d_out,u_w)

        #Actualizacion de embedding
        C.T[ex[0]] -= eta*d_emb
    
    #Guarda el riesgo en la época
    R.append(R_x)
    #Imprime información de época
    print('Fin de la iteración {}. Riesgo: {}'.format(i, R_x))
Fin de la iteración 0. Riesgo: 62381.43560463933
Fin de la iteración 1. Riesgo: 57249.781347191805
Fin de la iteración 2. Riesgo: 53047.141059796355
Fin de la iteración 3. Riesgo: 51374.9493898501
Fin de la iteración 4. Riesgo: 50607.57749381352
Fin de la iteración 5. Riesgo: 50132.26891223573
Fin de la iteración 6. Riesgo: 49791.11830223971
Fin de la iteración 7. Riesgo: 49526.6929912567
Fin de la iteración 8. Riesgo: 49313.773906421826
Fin de la iteración 9. Riesgo: 49138.70523447623
Fin de la iteración 10. Riesgo: 48992.95779203332
Fin de la iteración 11. Riesgo: 48870.602307142675
Fin de la iteración 12. Riesgo: 48767.20061005804
Fin de la iteración 13. Riesgo: 48679.342081658906
Fin de la iteración 14. Riesgo: 48604.330886072734
Fin de la iteración 15. Riesgo: 48539.96536081687
Fin de la iteración 16. Riesgo: 48484.455329277014
Fin de la iteración 17. Riesgo: 48436.38487175208
Fin de la iteración 18. Riesgo: 48394.63049358595
Fin de la iteración 19. Riesgo: 48358.25661567158
Fin de la iteración 20. Riesgo: 48326.44047140278
Fin de la iteración 21. Riesgo: 48298.44808681481
Fin de la iteración 22. Riesgo: 48273.64519033824
Fin de la iteración 23. Riesgo: 48251.50904325735
Fin de la iteración 24. Riesgo: 48231.62120671019
Fin de la iteración 25. Riesgo: 48213.646837867396
Fin de la iteración 26. Riesgo: 48197.313259680705
Fin de la iteración 27. Riesgo: 48182.39393444583
Fin de la iteración 28. Riesgo: 48168.698235765376
Fin de la iteración 29. Riesgo: 48156.06568571495
Fin de la iteración 30. Riesgo: 48144.36360087332
Fin de la iteración 31. Riesgo: 48133.48689044062
Fin de la iteración 32. Riesgo: 48123.357237646516
Fin de la iteración 33. Riesgo: 48113.91785562135
Fin de la iteración 34. Riesgo: 48105.12282918956
Fin de la iteración 35. Riesgo: 48096.92645282405
Fin de la iteración 36. Riesgo: 48089.2796727821
Fin de la iteración 37. Riesgo: 48082.133345631366
Fin de la iteración 38. Riesgo: 48075.44213597053
Fin de la iteración 39. Riesgo: 48069.16567485178
Fin de la iteración 40. Riesgo: 48063.26793337794
Fin de la iteración 41. Riesgo: 48057.7163912017
Fin de la iteración 42. Riesgo: 48052.481533124876
Fin de la iteración 43. Riesgo: 48047.53659158593
Fin de la iteración 44. Riesgo: 48042.85737382608
Fin de la iteración 45. Riesgo: 48038.4220916769
Fin de la iteración 46. Riesgo: 48034.21117507606
Fin de la iteración 47. Riesgo: 48030.20707572185
Fin de la iteración 48. Riesgo: 48026.3940716228
Fin de la iteración 49. Riesgo: 48022.75808086226
Fin de la iteración 50. Riesgo: 48019.28648970376
Fin de la iteración 51. Riesgo: 48015.967997692125
Fin de la iteración 52. Riesgo: 48012.79248039494
Fin de la iteración 53. Riesgo: 48009.75086865333
Fin de la iteración 54. Riesgo: 48006.835041687584
Fin de la iteración 55. Riesgo: 48004.03773041518
Fin de la iteración 56. Riesgo: 48001.35242709017
Fin de la iteración 57. Riesgo: 47998.77329790933
Fin de la iteración 58. Riesgo: 47996.29509643762
Fin de la iteración 59. Riesgo: 47993.91307728436
Fin de la iteración 60. Riesgo: 47991.622911058286
Fin de la iteración 61. Riesgo: 47989.420602919214
Fin de la iteración 62. Riesgo: 47987.30241777354
Fin de la iteración 63. Riesgo: 47985.26481518958
Fin de la iteración 64. Riesgo: 47983.304396499196
Fin de la iteración 65. Riesgo: 47981.41786543814
Fin de la iteración 66. Riesgo: 47979.60200238597
Fin de la iteración 67. Riesgo: 47977.85365100051
Fin de la iteración 68. Riesgo: 47976.169715105
Fin de la iteración 69. Riesgo: 47974.547163151656
Fin de la iteración 70. Riesgo: 47972.98303750817
Fin de la iteración 71. Riesgo: 47971.47446604892
Fin de la iteración 72. Riesgo: 47970.018674049184
Fin de la iteración 73. Riesgo: 47968.61299493079
Fin de la iteración 74. Riesgo: 47967.254878989916
Fin de la iteración 75. Riesgo: 47965.94189969099
Fin de la iteración 76. Riesgo: 47964.67175747292
Fin de la iteración 77. Riesgo: 47963.44228124322
Fin de la iteración 78. Riesgo: 47962.25142786518
Fin de la iteración 79. Riesgo: 47961.097279993504
Fin de la iteración 80. Riesgo: 47959.97804261523
Fin de la iteración 81. Riesgo: 47958.892038613
Fin de la iteración 82. Riesgo: 47957.837703628385
Fin de la iteración 83. Riesgo: 47956.813580436574
Fin de la iteración 84. Riesgo: 47955.8183130103
Fin de la iteración 85. Riesgo: 47954.85064038987
Fin de la iteración 86. Riesgo: 47953.90939045614
Fin de la iteración 87. Riesgo: 47952.99347367351
Fin de la iteración 88. Riesgo: 47952.101876845525
Fin de la iteración 89. Riesgo: 47951.233656932905
Fin de la iteración 90. Riesgo: 47950.38793496217
Fin de la iteración 91. Riesgo: 47949.56389006493
Fin de la iteración 92. Riesgo: 47948.76075369099
Fin de la iteración 93. Riesgo: 47947.97780403872
Fin de la iteración 94. Riesgo: 47947.21436075266
Fin de la iteración 95. Riesgo: 47946.46977994515
Fin de la iteración 96. Riesgo: 47945.74344959325
Fin de la iteración 97. Riesgo: 47945.03478536264
Fin de la iteración 98. Riesgo: 47944.34322690278
Fin de la iteración 99. Riesgo: 47943.66823464552
CPU times: user 1min 17s, sys: 43.7 ms, total: 1min 17s
Wall time: 1min 17s

Podemos visualizar como se minimiza la función de riesgo a través de las iteraciones.

In [8]:
#Ploteo del riesgo
plt.plot(R, 'v-')
plt.title('Riesgo a través de las iteraciones')
plt.xlabel('Iteración')
plt.ylabel('Riesgo')
plt.show()

Aplicación de la red¶

Entrenada la red, definimos una función forward para obtener las probabilidades a partir de la red ya entrenada.

In [9]:
#Forward
def forward(x):    
    #Embedimiento
    u_w = C.T[x]
    #Capa de salida
    out = np.exp(np.dot(W,u_w))
    p = out/out.sum(0)
    return p

Podemos probar cómo son las probabilidades de la red. En este caso, lo hacemos para el símbolo BOS.

In [28]:
#for word in idx.keys():
probs = sorted(list(zip(idx.keys(),forward(idx['infinit']))), key=itemgetter(1), reverse=True)
probs[:20]
Out[28]:
[('de', 0.038176918142471886),
 ('en', 0.017909391041895956),
 ('a', 0.014557162648726535),
 ('y', 0.014334084561012865),
 ('que', 0.011315246430016352),
 ('vi', 0.007330155355420444),
 ('tod', 0.006928234128033684),
 ('aleph', 0.006317570202325492),
 ('par', 0.006088867848752325),
 ('es', 0.005744297763202286),
 ('por', 0.0057261053981883915),
 ('era', 0.004747870099240432),
 ('del', 0.004549816290149445),
 ('con', 0.004353379088178406),
 ('argentin', 0.004123669943301754),
 ('espej', 0.004077811849812892),
 ('per', 0.003991795942585374),
 ('sol', 0.0038395116084809127),
 ('o', 0.0035962397547391212),
 ('se', 0.0035133773986019854)]

Vectores distribuidos¶

Los vectores de word embeddings se almacenan en la matriz de la capa de embedding (capa oculta). De esta forma, cada columna de la matriz corresponde a un vector que representa una palabra.

In [11]:
pd.DataFrame(data=C.T, index=list(idx.keys()))
Out[11]:
0 1
la 0.071946 1.571248
candent 2.120305 -0.557875
mañan -2.338513 2.880450
de 0.013222 0.371907
febrer 1.694601 3.931954
... ... ...
ment 0.957419 1.031564
poros 1.093793 1.628693
tragic 3.008104 -1.405820
erosion -0.072291 1.547840
estel -1.130263 1.987682

1524 rows × 2 columns

Podemos, entonces, visualizar los datos en un espacio vectorial.

In [12]:
#Función para visualizar palabras
def plot_words(Z,ids):
    Z = PCA(2).fit_transform(Z)
    r=0
    plt.scatter(Z[:,0],Z[:,1], marker='o', c='blue')
    for label,x,y in zip(ids, Z[:,0], Z[:,1]):
        plt.annotate(label, xy=(x,y), xytext=(-1,1), textcoords='offset points', ha='center', va='bottom')
        r+=1
In [36]:
#Visualización de los embeddings
plot_words(C.T[:100], list(idx.keys())[:100])
plt.title('Embeddings')
plt.show()

Implementación con Pytorch¶

La paquetería de Pytorch permite implementar modelos basados en redes neuronales. Por tanto, podemos implementar una red para obtener embeddings como en el algoritmo de Word2Vec.

In [14]:
import torch
import torch.nn as nn

Definimos la red. En este caso, Pytorch ya cuenta con capas de embeddings, por lo que únicamente bastará señalar que queremos una capa de tipo Embedding.

In [15]:
network = nn.Sequential(nn.Embedding(N, 2),nn.Linear(2, N, bias=False), nn.Softmax(dim=1))

Definimos la función de riesgo y el optimizador:

In [16]:
risk = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(network.parameters(), lr=0.1)

Procedemos a entrenar la red con los datos:

In [17]:
%%time
R2 = []
for i in range(0,its):
    R_it = 0
    for ex in contexts:
        optimizer.zero_grad()
        probs = network(torch.tensor([ex[0]]))
        loss = risk(probs, torch.LongTensor([ex[1]]))
        loss.backward()
        optimizer.step()
        R_it += loss.detach()
        
    #Guarda el riesgo en la época
    R2.append(R_it)
    #Imprime información de época
    print('Fin de la iteración {}. Riesgo: {}'.format(i, R_it))
Fin de la iteración 0. Riesgo: 63715.2265625
Fin de la iteración 1. Riesgo: 63715.22265625
Fin de la iteración 2. Riesgo: 63715.22265625
Fin de la iteración 3. Riesgo: 63715.21875
Fin de la iteración 4. Riesgo: 63715.2109375
Fin de la iteración 5. Riesgo: 63715.2109375
Fin de la iteración 6. Riesgo: 63715.203125
Fin de la iteración 7. Riesgo: 63715.1875
Fin de la iteración 8. Riesgo: 63715.18359375
Fin de la iteración 9. Riesgo: 63715.18359375
Fin de la iteración 10. Riesgo: 63715.18359375
Fin de la iteración 11. Riesgo: 63715.1796875
Fin de la iteración 12. Riesgo: 63715.1640625
Fin de la iteración 13. Riesgo: 63715.15234375
Fin de la iteración 14. Riesgo: 63715.1484375
Fin de la iteración 15. Riesgo: 63715.13671875
Fin de la iteración 16. Riesgo: 63715.12109375
Fin de la iteración 17. Riesgo: 63715.11328125
Fin de la iteración 18. Riesgo: 63715.0859375
Fin de la iteración 19. Riesgo: 63715.06640625
Fin de la iteración 20. Riesgo: 63715.03125
Fin de la iteración 21. Riesgo: 63714.99609375
Fin de la iteración 22. Riesgo: 63714.96484375
Fin de la iteración 23. Riesgo: 63714.859375
Fin de la iteración 24. Riesgo: 63714.7734375
Fin de la iteración 25. Riesgo: 63714.625
Fin de la iteración 26. Riesgo: 63714.45703125
Fin de la iteración 27. Riesgo: 63714.08984375
Fin de la iteración 28. Riesgo: 63712.9609375
Fin de la iteración 29. Riesgo: 63703.78515625
Fin de la iteración 30. Riesgo: 63658.359375
Fin de la iteración 31. Riesgo: 63621.3828125
Fin de la iteración 32. Riesgo: 63594.75
Fin de la iteración 33. Riesgo: 63571.46875
Fin de la iteración 34. Riesgo: 63526.625
Fin de la iteración 35. Riesgo: 63514.59765625
Fin de la iteración 36. Riesgo: 63502.90625
Fin de la iteración 37. Riesgo: 63494.09375
Fin de la iteración 38. Riesgo: 63481.578125
Fin de la iteración 39. Riesgo: 63471.625
Fin de la iteración 40. Riesgo: 63466.54296875
Fin de la iteración 41. Riesgo: 63459.5625
Fin de la iteración 42. Riesgo: 63453.41015625
Fin de la iteración 43. Riesgo: 63449.34765625
Fin de la iteración 44. Riesgo: 63447.5546875
Fin de la iteración 45. Riesgo: 63446.2265625
Fin de la iteración 46. Riesgo: 63444.2265625
Fin de la iteración 47. Riesgo: 63442.70703125
Fin de la iteración 48. Riesgo: 63440.9296875
Fin de la iteración 49. Riesgo: 63440.46875
Fin de la iteración 50. Riesgo: 63440.23828125
Fin de la iteración 51. Riesgo: 63439.515625
Fin de la iteración 52. Riesgo: 63438.84375
Fin de la iteración 53. Riesgo: 63438.27734375
Fin de la iteración 54. Riesgo: 63437.69921875
Fin de la iteración 55. Riesgo: 63436.89453125
Fin de la iteración 56. Riesgo: 63436.3984375
Fin de la iteración 57. Riesgo: 63436.26953125
Fin de la iteración 58. Riesgo: 63435.86328125
Fin de la iteración 59. Riesgo: 63435.06640625
Fin de la iteración 60. Riesgo: 63434.03125
Fin de la iteración 61. Riesgo: 63433.32421875
Fin de la iteración 62. Riesgo: 63433.00390625
Fin de la iteración 63. Riesgo: 63432.28125
Fin de la iteración 64. Riesgo: 63431.77734375
Fin de la iteración 65. Riesgo: 63430.52734375
Fin de la iteración 66. Riesgo: 63429.86328125
Fin de la iteración 67. Riesgo: 63429.40234375
Fin de la iteración 68. Riesgo: 63429.3515625
Fin de la iteración 69. Riesgo: 63429.29296875
Fin de la iteración 70. Riesgo: 63429.14453125
Fin de la iteración 71. Riesgo: 63428.60546875
Fin de la iteración 72. Riesgo: 63428.1015625
Fin de la iteración 73. Riesgo: 63427.6640625
Fin de la iteración 74. Riesgo: 63426.8828125
Fin de la iteración 75. Riesgo: 63426.19921875
Fin de la iteración 76. Riesgo: 63425.37109375
Fin de la iteración 77. Riesgo: 63424.07421875
Fin de la iteración 78. Riesgo: 63421.55078125
Fin de la iteración 79. Riesgo: 63420.09765625
Fin de la iteración 80. Riesgo: 63418.921875
Fin de la iteración 81. Riesgo: 63417.0078125
Fin de la iteración 82. Riesgo: 63415.640625
Fin de la iteración 83. Riesgo: 63414.37109375
Fin de la iteración 84. Riesgo: 63412.89453125
Fin de la iteración 85. Riesgo: 63411.828125
Fin de la iteración 86. Riesgo: 63410.1796875
Fin de la iteración 87. Riesgo: 63408.87109375
Fin de la iteración 88. Riesgo: 63406.9765625
Fin de la iteración 89. Riesgo: 63405.359375
Fin de la iteración 90. Riesgo: 63404.3984375
Fin de la iteración 91. Riesgo: 63403.70703125
Fin de la iteración 92. Riesgo: 63400.0234375
Fin de la iteración 93. Riesgo: 63395.16015625
Fin de la iteración 94. Riesgo: 63394.74609375
Fin de la iteración 95. Riesgo: 63394.4921875
Fin de la iteración 96. Riesgo: 63394.17578125
Fin de la iteración 97. Riesgo: 63393.8046875
Fin de la iteración 98. Riesgo: 63393.421875
Fin de la iteración 99. Riesgo: 63393.03125
CPU times: user 33min 24s, sys: 14 s, total: 33min 38s
Wall time: 5min 35s

Podemos visualizar el riesgo:

In [18]:
#Ploteo del riesgo
plt.plot(R2, 'v-')
plt.title('Riesgo a través de las iteraciones')
plt.xlabel('Iteración')
plt.ylabel('Riesgo')
plt.show()

Finalmente, observamos el comportamiento de los embeddings:

In [19]:
C2 = network[0].weight.detach()
plot_words(C2[:100], list(idx.keys())[:100])
plt.title('Embeddings obtenidos con pytorch')
plt.show()