En primer lugar, definimos las funciones necesarias: para stemming y para indexar el vocabulario.
#-*- encoding:utf-8 -*-
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from collections import defaultdict, Counter
from nltk.stem.snowball import SnowballStemmer
from itertools import chain
from re import sub
from sklearn.decomposition import PCA
from operator import itemgetter
from nltk.corpus import cess_esp
from sklearn.model_selection import train_test_split
from nltk.tokenize import sent_tokenize, word_tokenize
from re import compile, sub
#Declaramos el stemeer
stemizer = SnowballStemmer('spanish')
#Funcion para stemming
def stems(string):
stem_string = []
for w in string:
stem_string.append(stemizer.stem(w))
return stem_string
#Funcion que crea un vocabulario de palabras con un indice numerico
def vocab():
vocab = defaultdict()
vocab.default_factory = lambda: len(vocab)
return vocab
#Funcion que pasa la cadena de simbolos a una secuencia con indices numericos
def text2numba(corpus, vocab):
for doc in corpus:
yield [vocab[w] for w in doc] #.split()]
Abrimos el corpus que bamos a utilizar y obtenemos los contextos.
#corpus = ['el perro come un hueso', 'un muchacho jugaba', 'el muchacho saltaba la cuerda',
# 'un perro come croquetas', 'un muchacho juega']
corpus = sent_tokenize(open('borges_aleph.txt','r', encoding='utf8').read().strip())
print(corpus[0])
La candente mañana de febrero en que Beatriz Viterbo murió, después de una imperiosa agonía que no se rebajó un solo instante ni al sentimentalismo ni al miedo, noté que las carteleras de fierro de la Plaza Constitución habían renovado no sé qué aviso de cigarrillos rubios; el hecho me dolió, pues comprendí que el incesante y vasto universo ya se apartaba de ella y que ese cambio era el primero de una serie infinita.
Posteriormente, limpiamos y stemizamos el corpus. Asimismo, creamos el vocabulario con índices numéricos y transformamos las cadenas de palabras encadenas de índices numéricos. Se crean venanas de $1\times 1$
#Abrimos el documento, lo limpiamos y separamos las cadenas
#corpus = sub(r'[^\w\s]','',uploaded['corpus.es'].decode('utf8').strip().lower()).split('\n')
regex = compile('[^a-zA-Z| |ñáéíóú]')
#Stemizamos el documento
#corpus = [sent.split() for sent in corpus] #[stems(sent) for sent in corpus]
corpus = [stems(word_tokenize(regex.sub('', sent).lower())) for sent in corpus]
#Llamamos la funcion para crear el vocabulario
idx = vocab()
#Creamos el vocabulario y le asignamos un indice a cada simbolo segun su aparicion
cads_idx = list(text2numba(corpus,idx))
print(corpus[0])
print(cads_idx[0])
['la', 'candent', 'mañan', 'de', 'febrer', 'en', 'que', 'beatriz', 'viterb', 'mur', 'despues', 'de', 'una', 'imperi', 'agon', 'que', 'no', 'se', 'rebaj', 'un', 'sol', 'instant', 'ni', 'al', 'sentimental', 'ni', 'al', 'mied', 'not', 'que', 'las', 'carteler', 'de', 'fierr', 'de', 'la', 'plaz', 'constitu', 'hab', 'renov', 'no', 'se', 'que', 'avis', 'de', 'cigarrill', 'rubi', 'el', 'hech', 'me', 'dol', 'pues', 'comprend', 'que', 'el', 'inces', 'y', 'vast', 'univers', 'ya', 'se', 'apart', 'de', 'ella', 'y', 'que', 'ese', 'cambi', 'era', 'el', 'primer', 'de', 'una', 'seri', 'infinit'] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 3, 11, 12, 13, 6, 14, 15, 16, 17, 18, 19, 20, 21, 22, 20, 21, 23, 24, 6, 25, 26, 3, 27, 3, 0, 28, 29, 30, 31, 14, 15, 6, 32, 3, 33, 34, 35, 36, 37, 38, 39, 40, 6, 35, 41, 42, 43, 44, 45, 15, 46, 3, 47, 42, 6, 48, 49, 50, 35, 51, 3, 11, 52, 53]
cadenas = cads_idx
#Se obtiene la longitud del alfabeto
N = len(idx)
print(cadenas[0])
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 3, 11, 12, 13, 6, 14, 15, 16, 17, 18, 19, 20, 21, 22, 20, 21, 23, 24, 6, 25, 26, 3, 27, 3, 0, 28, 29, 30, 31, 14, 15, 6, 32, 3, 33, 34, 35, 36, 37, 38, 39, 40, 6, 35, 41, 42, 43, 44, 45, 15, 46, 3, 47, 42, 6, 48, 49, 50, 35, 51, 3, 11, 52, 53]
Hecho esto, extraemos los bigramas del texto.
#Se crean los bigramas
contexts = list(chain(*[zip(cad,cad[1:]) for cad in cadenas])) + list(chain(*[zip(cad[1:],cad) for cad in cadenas]))
#Se obtiene la frecuencia de cada bigrama
frecContexts = Counter(contexts)
print(contexts[:100])
[(0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (5, 6), (6, 7), (7, 8), (8, 9), (9, 10), (10, 3), (3, 11), (11, 12), (12, 13), (13, 6), (6, 14), (14, 15), (15, 16), (16, 17), (17, 18), (18, 19), (19, 20), (20, 21), (21, 22), (22, 20), (20, 21), (21, 23), (23, 24), (24, 6), (6, 25), (25, 26), (26, 3), (3, 27), (27, 3), (3, 0), (0, 28), (28, 29), (29, 30), (30, 31), (31, 14), (14, 15), (15, 6), (6, 32), (32, 3), (3, 33), (33, 34), (34, 35), (35, 36), (36, 37), (37, 38), (38, 39), (39, 40), (40, 6), (6, 35), (35, 41), (41, 42), (42, 43), (43, 44), (44, 45), (45, 15), (15, 46), (46, 3), (3, 47), (47, 42), (42, 6), (6, 48), (48, 49), (49, 50), (50, 35), (35, 51), (51, 3), (3, 11), (11, 52), (52, 53), (49, 35), (35, 44), (44, 54), (54, 55), (55, 14), (14, 56), (56, 57), (57, 58), (58, 59), (59, 60), (60, 61), (61, 62), (62, 15), (15, 63), (63, 64), (64, 65), (65, 0), (0, 30), (30, 66), (66, 67), (67, 55), (55, 68), (68, 69), (69, 70), (70, 71), (71, 72)]
Ahora, paamos a la rd neuronla. Primero, inicializamos los parámetros de la red. Requerimos de dos matrices U (la matriz de embedding) y W (la matriz de la capa de salida).
np.random.seed(0)
#El número de rasgos que representan cada vector
nn_input_dim = N
#El total de clases que arrojará
output_dim = N
#El número de ejmplos
num_examples = len(contexts)
#Dimensiones de los vectores-palabra
dim = 2
#Embedding
C = np.random.randn(dim, N) / np.sqrt(N)
#Capa de salida
W = np.random.randn(N,dim) / np.sqrt(dim)
print(C.shape)
print(W.shape)
(2, 1524) (1524, 2)
Ahora entrenamos la red con el algoritmo de backpropagation y de gradiente descendiente.
%%time
#Guarda el riesgo por época
R = []
#Hiperparámetros
its = 100
eta = 0.1
for i in range(0,its):
#Acumula el riesgo de los
#ejemplos en la época actual
R_x = 0
for ex in contexts:
#Forward
#Embedimiento
u_w = C.T[ex[0]]
#salida
a = np.dot(W,u_w)
out = np.exp(a) # - np.max(a))
#Softmax
f = out/out.sum(0)
#Suma riesgo por ejemplo
R_x += -np.log(f)[ex[1]]
#Backprop
#Variable de salida
d_out = f
d_out[ex[1]] -= 1
#Variable de embedding
d_emb = np.dot(d_out,W)
#Actualizacion de salida
W -= eta*np.outer(d_out,u_w)
#Actualizacion de embedding
C.T[ex[0]] -= eta*d_emb
#Guarda el riesgo en la época
R.append(R_x)
#Imprime información de época
print('Fin de la iteración {}. Riesgo: {}'.format(i, R_x))
Fin de la iteración 0. Riesgo: 62381.43560463933 Fin de la iteración 1. Riesgo: 57249.781347191805 Fin de la iteración 2. Riesgo: 53047.141059796355 Fin de la iteración 3. Riesgo: 51374.9493898501 Fin de la iteración 4. Riesgo: 50607.57749381352 Fin de la iteración 5. Riesgo: 50132.26891223573 Fin de la iteración 6. Riesgo: 49791.11830223971 Fin de la iteración 7. Riesgo: 49526.6929912567 Fin de la iteración 8. Riesgo: 49313.773906421826 Fin de la iteración 9. Riesgo: 49138.70523447623 Fin de la iteración 10. Riesgo: 48992.95779203332 Fin de la iteración 11. Riesgo: 48870.602307142675 Fin de la iteración 12. Riesgo: 48767.20061005804 Fin de la iteración 13. Riesgo: 48679.342081658906 Fin de la iteración 14. Riesgo: 48604.330886072734 Fin de la iteración 15. Riesgo: 48539.96536081687 Fin de la iteración 16. Riesgo: 48484.455329277014 Fin de la iteración 17. Riesgo: 48436.38487175208 Fin de la iteración 18. Riesgo: 48394.63049358595 Fin de la iteración 19. Riesgo: 48358.25661567158 Fin de la iteración 20. Riesgo: 48326.44047140278 Fin de la iteración 21. Riesgo: 48298.44808681481 Fin de la iteración 22. Riesgo: 48273.64519033824 Fin de la iteración 23. Riesgo: 48251.50904325735 Fin de la iteración 24. Riesgo: 48231.62120671019 Fin de la iteración 25. Riesgo: 48213.646837867396 Fin de la iteración 26. Riesgo: 48197.313259680705 Fin de la iteración 27. Riesgo: 48182.39393444583 Fin de la iteración 28. Riesgo: 48168.698235765376 Fin de la iteración 29. Riesgo: 48156.06568571495 Fin de la iteración 30. Riesgo: 48144.36360087332 Fin de la iteración 31. Riesgo: 48133.48689044062 Fin de la iteración 32. Riesgo: 48123.357237646516 Fin de la iteración 33. Riesgo: 48113.91785562135 Fin de la iteración 34. Riesgo: 48105.12282918956 Fin de la iteración 35. Riesgo: 48096.92645282405 Fin de la iteración 36. Riesgo: 48089.2796727821 Fin de la iteración 37. Riesgo: 48082.133345631366 Fin de la iteración 38. Riesgo: 48075.44213597053 Fin de la iteración 39. Riesgo: 48069.16567485178 Fin de la iteración 40. Riesgo: 48063.26793337794 Fin de la iteración 41. Riesgo: 48057.7163912017 Fin de la iteración 42. Riesgo: 48052.481533124876 Fin de la iteración 43. Riesgo: 48047.53659158593 Fin de la iteración 44. Riesgo: 48042.85737382608 Fin de la iteración 45. Riesgo: 48038.4220916769 Fin de la iteración 46. Riesgo: 48034.21117507606 Fin de la iteración 47. Riesgo: 48030.20707572185 Fin de la iteración 48. Riesgo: 48026.3940716228 Fin de la iteración 49. Riesgo: 48022.75808086226 Fin de la iteración 50. Riesgo: 48019.28648970376 Fin de la iteración 51. Riesgo: 48015.967997692125 Fin de la iteración 52. Riesgo: 48012.79248039494 Fin de la iteración 53. Riesgo: 48009.75086865333 Fin de la iteración 54. Riesgo: 48006.835041687584 Fin de la iteración 55. Riesgo: 48004.03773041518 Fin de la iteración 56. Riesgo: 48001.35242709017 Fin de la iteración 57. Riesgo: 47998.77329790933 Fin de la iteración 58. Riesgo: 47996.29509643762 Fin de la iteración 59. Riesgo: 47993.91307728436 Fin de la iteración 60. Riesgo: 47991.622911058286 Fin de la iteración 61. Riesgo: 47989.420602919214 Fin de la iteración 62. Riesgo: 47987.30241777354 Fin de la iteración 63. Riesgo: 47985.26481518958 Fin de la iteración 64. Riesgo: 47983.304396499196 Fin de la iteración 65. Riesgo: 47981.41786543814 Fin de la iteración 66. Riesgo: 47979.60200238597 Fin de la iteración 67. Riesgo: 47977.85365100051 Fin de la iteración 68. Riesgo: 47976.169715105 Fin de la iteración 69. Riesgo: 47974.547163151656 Fin de la iteración 70. Riesgo: 47972.98303750817 Fin de la iteración 71. Riesgo: 47971.47446604892 Fin de la iteración 72. Riesgo: 47970.018674049184 Fin de la iteración 73. Riesgo: 47968.61299493079 Fin de la iteración 74. Riesgo: 47967.254878989916 Fin de la iteración 75. Riesgo: 47965.94189969099 Fin de la iteración 76. Riesgo: 47964.67175747292 Fin de la iteración 77. Riesgo: 47963.44228124322 Fin de la iteración 78. Riesgo: 47962.25142786518 Fin de la iteración 79. Riesgo: 47961.097279993504 Fin de la iteración 80. Riesgo: 47959.97804261523 Fin de la iteración 81. Riesgo: 47958.892038613 Fin de la iteración 82. Riesgo: 47957.837703628385 Fin de la iteración 83. Riesgo: 47956.813580436574 Fin de la iteración 84. Riesgo: 47955.8183130103 Fin de la iteración 85. Riesgo: 47954.85064038987 Fin de la iteración 86. Riesgo: 47953.90939045614 Fin de la iteración 87. Riesgo: 47952.99347367351 Fin de la iteración 88. Riesgo: 47952.101876845525 Fin de la iteración 89. Riesgo: 47951.233656932905 Fin de la iteración 90. Riesgo: 47950.38793496217 Fin de la iteración 91. Riesgo: 47949.56389006493 Fin de la iteración 92. Riesgo: 47948.76075369099 Fin de la iteración 93. Riesgo: 47947.97780403872 Fin de la iteración 94. Riesgo: 47947.21436075266 Fin de la iteración 95. Riesgo: 47946.46977994515 Fin de la iteración 96. Riesgo: 47945.74344959325 Fin de la iteración 97. Riesgo: 47945.03478536264 Fin de la iteración 98. Riesgo: 47944.34322690278 Fin de la iteración 99. Riesgo: 47943.66823464552 CPU times: user 1min 17s, sys: 43.7 ms, total: 1min 17s Wall time: 1min 17s
Podemos visualizar como se minimiza la función de riesgo a través de las iteraciones.
#Ploteo del riesgo
plt.plot(R, 'v-')
plt.title('Riesgo a través de las iteraciones')
plt.xlabel('Iteración')
plt.ylabel('Riesgo')
plt.show()
Entrenada la red, definimos una función forward para obtener las probabilidades a partir de la red ya entrenada.
#Forward
def forward(x):
#Embedimiento
u_w = C.T[x]
#Capa de salida
out = np.exp(np.dot(W,u_w))
p = out/out.sum(0)
return p
Podemos probar cómo son las probabilidades de la red. En este caso, lo hacemos para el símbolo BOS.
#for word in idx.keys():
probs = sorted(list(zip(idx.keys(),forward(idx['infinit']))), key=itemgetter(1), reverse=True)
probs[:20]
[('de', 0.038176918142471886), ('en', 0.017909391041895956), ('a', 0.014557162648726535), ('y', 0.014334084561012865), ('que', 0.011315246430016352), ('vi', 0.007330155355420444), ('tod', 0.006928234128033684), ('aleph', 0.006317570202325492), ('par', 0.006088867848752325), ('es', 0.005744297763202286), ('por', 0.0057261053981883915), ('era', 0.004747870099240432), ('del', 0.004549816290149445), ('con', 0.004353379088178406), ('argentin', 0.004123669943301754), ('espej', 0.004077811849812892), ('per', 0.003991795942585374), ('sol', 0.0038395116084809127), ('o', 0.0035962397547391212), ('se', 0.0035133773986019854)]
Los vectores de word embeddings se almacenan en la matriz de la capa de embedding (capa oculta). De esta forma, cada columna de la matriz corresponde a un vector que representa una palabra.
pd.DataFrame(data=C.T, index=list(idx.keys()))
0 | 1 | |
---|---|---|
la | 0.071946 | 1.571248 |
candent | 2.120305 | -0.557875 |
mañan | -2.338513 | 2.880450 |
de | 0.013222 | 0.371907 |
febrer | 1.694601 | 3.931954 |
... | ... | ... |
ment | 0.957419 | 1.031564 |
poros | 1.093793 | 1.628693 |
tragic | 3.008104 | -1.405820 |
erosion | -0.072291 | 1.547840 |
estel | -1.130263 | 1.987682 |
1524 rows × 2 columns
Podemos, entonces, visualizar los datos en un espacio vectorial.
#Función para visualizar palabras
def plot_words(Z,ids):
Z = PCA(2).fit_transform(Z)
r=0
plt.scatter(Z[:,0],Z[:,1], marker='o', c='blue')
for label,x,y in zip(ids, Z[:,0], Z[:,1]):
plt.annotate(label, xy=(x,y), xytext=(-1,1), textcoords='offset points', ha='center', va='bottom')
r+=1
#Visualización de los embeddings
plot_words(C.T[:100], list(idx.keys())[:100])
plt.title('Embeddings')
plt.show()
La paquetería de Pytorch permite implementar modelos basados en redes neuronales. Por tanto, podemos implementar una red para obtener embeddings como en el algoritmo de Word2Vec.
import torch
import torch.nn as nn
Definimos la red. En este caso, Pytorch ya cuenta con capas de embeddings, por lo que únicamente bastará señalar que queremos una capa de tipo Embedding.
network = nn.Sequential(nn.Embedding(N, 2),nn.Linear(2, N, bias=False), nn.Softmax(dim=1))
Definimos la función de riesgo y el optimizador:
risk = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(network.parameters(), lr=0.1)
Procedemos a entrenar la red con los datos:
%%time
R2 = []
for i in range(0,its):
R_it = 0
for ex in contexts:
optimizer.zero_grad()
probs = network(torch.tensor([ex[0]]))
loss = risk(probs, torch.LongTensor([ex[1]]))
loss.backward()
optimizer.step()
R_it += loss.detach()
#Guarda el riesgo en la época
R2.append(R_it)
#Imprime información de época
print('Fin de la iteración {}. Riesgo: {}'.format(i, R_it))
Fin de la iteración 0. Riesgo: 63715.2265625 Fin de la iteración 1. Riesgo: 63715.22265625 Fin de la iteración 2. Riesgo: 63715.22265625 Fin de la iteración 3. Riesgo: 63715.21875 Fin de la iteración 4. Riesgo: 63715.2109375 Fin de la iteración 5. Riesgo: 63715.2109375 Fin de la iteración 6. Riesgo: 63715.203125 Fin de la iteración 7. Riesgo: 63715.1875 Fin de la iteración 8. Riesgo: 63715.18359375 Fin de la iteración 9. Riesgo: 63715.18359375 Fin de la iteración 10. Riesgo: 63715.18359375 Fin de la iteración 11. Riesgo: 63715.1796875 Fin de la iteración 12. Riesgo: 63715.1640625 Fin de la iteración 13. Riesgo: 63715.15234375 Fin de la iteración 14. Riesgo: 63715.1484375 Fin de la iteración 15. Riesgo: 63715.13671875 Fin de la iteración 16. Riesgo: 63715.12109375 Fin de la iteración 17. Riesgo: 63715.11328125 Fin de la iteración 18. Riesgo: 63715.0859375 Fin de la iteración 19. Riesgo: 63715.06640625 Fin de la iteración 20. Riesgo: 63715.03125 Fin de la iteración 21. Riesgo: 63714.99609375 Fin de la iteración 22. Riesgo: 63714.96484375 Fin de la iteración 23. Riesgo: 63714.859375 Fin de la iteración 24. Riesgo: 63714.7734375 Fin de la iteración 25. Riesgo: 63714.625 Fin de la iteración 26. Riesgo: 63714.45703125 Fin de la iteración 27. Riesgo: 63714.08984375 Fin de la iteración 28. Riesgo: 63712.9609375 Fin de la iteración 29. Riesgo: 63703.78515625 Fin de la iteración 30. Riesgo: 63658.359375 Fin de la iteración 31. Riesgo: 63621.3828125 Fin de la iteración 32. Riesgo: 63594.75 Fin de la iteración 33. Riesgo: 63571.46875 Fin de la iteración 34. Riesgo: 63526.625 Fin de la iteración 35. Riesgo: 63514.59765625 Fin de la iteración 36. Riesgo: 63502.90625 Fin de la iteración 37. Riesgo: 63494.09375 Fin de la iteración 38. Riesgo: 63481.578125 Fin de la iteración 39. Riesgo: 63471.625 Fin de la iteración 40. Riesgo: 63466.54296875 Fin de la iteración 41. Riesgo: 63459.5625 Fin de la iteración 42. Riesgo: 63453.41015625 Fin de la iteración 43. Riesgo: 63449.34765625 Fin de la iteración 44. Riesgo: 63447.5546875 Fin de la iteración 45. Riesgo: 63446.2265625 Fin de la iteración 46. Riesgo: 63444.2265625 Fin de la iteración 47. Riesgo: 63442.70703125 Fin de la iteración 48. Riesgo: 63440.9296875 Fin de la iteración 49. Riesgo: 63440.46875 Fin de la iteración 50. Riesgo: 63440.23828125 Fin de la iteración 51. Riesgo: 63439.515625 Fin de la iteración 52. Riesgo: 63438.84375 Fin de la iteración 53. Riesgo: 63438.27734375 Fin de la iteración 54. Riesgo: 63437.69921875 Fin de la iteración 55. Riesgo: 63436.89453125 Fin de la iteración 56. Riesgo: 63436.3984375 Fin de la iteración 57. Riesgo: 63436.26953125 Fin de la iteración 58. Riesgo: 63435.86328125 Fin de la iteración 59. Riesgo: 63435.06640625 Fin de la iteración 60. Riesgo: 63434.03125 Fin de la iteración 61. Riesgo: 63433.32421875 Fin de la iteración 62. Riesgo: 63433.00390625 Fin de la iteración 63. Riesgo: 63432.28125 Fin de la iteración 64. Riesgo: 63431.77734375 Fin de la iteración 65. Riesgo: 63430.52734375 Fin de la iteración 66. Riesgo: 63429.86328125 Fin de la iteración 67. Riesgo: 63429.40234375 Fin de la iteración 68. Riesgo: 63429.3515625 Fin de la iteración 69. Riesgo: 63429.29296875 Fin de la iteración 70. Riesgo: 63429.14453125 Fin de la iteración 71. Riesgo: 63428.60546875 Fin de la iteración 72. Riesgo: 63428.1015625 Fin de la iteración 73. Riesgo: 63427.6640625 Fin de la iteración 74. Riesgo: 63426.8828125 Fin de la iteración 75. Riesgo: 63426.19921875 Fin de la iteración 76. Riesgo: 63425.37109375 Fin de la iteración 77. Riesgo: 63424.07421875 Fin de la iteración 78. Riesgo: 63421.55078125 Fin de la iteración 79. Riesgo: 63420.09765625 Fin de la iteración 80. Riesgo: 63418.921875 Fin de la iteración 81. Riesgo: 63417.0078125 Fin de la iteración 82. Riesgo: 63415.640625 Fin de la iteración 83. Riesgo: 63414.37109375 Fin de la iteración 84. Riesgo: 63412.89453125 Fin de la iteración 85. Riesgo: 63411.828125 Fin de la iteración 86. Riesgo: 63410.1796875 Fin de la iteración 87. Riesgo: 63408.87109375 Fin de la iteración 88. Riesgo: 63406.9765625 Fin de la iteración 89. Riesgo: 63405.359375 Fin de la iteración 90. Riesgo: 63404.3984375 Fin de la iteración 91. Riesgo: 63403.70703125 Fin de la iteración 92. Riesgo: 63400.0234375 Fin de la iteración 93. Riesgo: 63395.16015625 Fin de la iteración 94. Riesgo: 63394.74609375 Fin de la iteración 95. Riesgo: 63394.4921875 Fin de la iteración 96. Riesgo: 63394.17578125 Fin de la iteración 97. Riesgo: 63393.8046875 Fin de la iteración 98. Riesgo: 63393.421875 Fin de la iteración 99. Riesgo: 63393.03125 CPU times: user 33min 24s, sys: 14 s, total: 33min 38s Wall time: 5min 35s
Podemos visualizar el riesgo:
#Ploteo del riesgo
plt.plot(R2, 'v-')
plt.title('Riesgo a través de las iteraciones')
plt.xlabel('Iteración')
plt.ylabel('Riesgo')
plt.show()
Finalmente, observamos el comportamiento de los embeddings:
C2 = network[0].weight.detach()
plot_words(C2[:100], list(idx.keys())[:100])
plt.title('Embeddings obtenidos con pytorch')
plt.show()