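# sawaw/scripts/train_cnn.py
#
# Train SentimentAspectCNN on the SemEval restaurant training data, save the
# weights, and sweep decision thresholds for the sigmoid-style output.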

from pathlib import Path

import numpy as np
import torch
from loguru import logger
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm

from methods.model import SentimentAspectCNN
from methods.tokenizer import to_vec
from sawaw import SAWAWEntry, SentimentResult
# Load the raw SemEval restaurant training data.
path = Path("./data/restaurant_train.raw")
# Decode explicitly: without `encoding`, read_text() uses the locale's
# preferred encoding, which varies by platform (e.g. cp1252 on Windows) and can
# garble or fail on non-ASCII review text.
# NOTE(review): assumes the dataset file is UTF-8 — confirm.
content = path.read_text(encoding="utf-8")
def parse_content(content: str) -> list:
    """Parse SemEval-style ``.raw`` text into ``SAWAWEntry`` objects.

    The text is a sequence of three-line records, for example::

        I 'm partial to the $T$ .
        Gnocchi
        1

    The lines are a sentence whose aspect term is replaced by the ``$T$``
    placeholder, then the aspect term itself, then an integer sentiment label.
    Each record becomes one ``SAWAWEntry``, which holds:

    * the sentence with ``$T$`` substituted back,
    * a one-element aspect list,
    * a one-element list containing ``SentimentResult(label + 1)``.

    A trailing partial record (fewer than three remaining lines, e.g. the
    empty string left after a final newline) is ignored. Both LF and CRLF
    line endings are accepted.

    Args:
        content: Full text of the ``.raw`` file.

    Returns:
        A list of ``SAWAWEntry`` objects, one per complete record.

    Raises:
        ValueError: If a label line cannot be parsed by ``int``.
    """
    # Drop a trailing "\r" per line so CRLF files don't leak carriage returns
    # into the aspect term (and from there into the reconstructed sentence).
    lines = [line.rstrip("\r") for line in content.split("\n")]
    entries = []
    # Stop before any record that would run past the end of `lines`.
    for start in range(0, len(lines) - 2, 3):
        sentence, aspect_word, sentiment = lines[start:start + 3]
        sentence_replaced = sentence.replace("$T$", aspect_word)
        # The raw label is shifted by +1 before building SentimentResult.
        # NOTE(review): presumably maps -1/0/1 to 0/1/2 — confirm against SentimentResult.
        entries.append(
            SAWAWEntry(
                sentence_replaced,
                [aspect_word],
                [SentimentResult(int(sentiment) + 1)],
            )
        )
    return entries
entries = parse_content(content)
logger.info("Loaded {} entries from {}", len(entries), path)
if not entries:
    # torch.cat below would otherwise fail with a less helpful message.
    raise ValueError(f"No entries parsed from {path}")

# Vectorise every entry into fixed-length (max_len) feature matrices.
max_len = 80
data_vectors, sentiment_gts = [], []
for entry in tqdm(entries):
    # Shapes per the original note (TODO confirm against to_vec):
    # data_vector (num_of_aspect_words, 80, 26); sentiment_gt (num_of_aspect_words,)
    data_vector, sentiment_gt = to_vec(entry, max_len=max_len, should_return_sentiment=True)
    data_vectors.append(data_vector)
    # Flatten to 1-D so an entry with several aspect words contributes one
    # label per feature row, mirroring the row-wise torch.cat of data_vectors.
    sentiment_gts.append(torch.as_tensor(sentiment_gt, dtype=torch.float32).reshape(-1))
data_vectors = torch.cat(data_vectors, dim=0)
sentiment_gts = torch.cat(sentiment_gts).unsqueeze(1)  # shape: (total_aspect_words, 1)
if sentiment_gts.shape[0] != data_vectors.shape[0]:
    raise ValueError(
        f"Got {sentiment_gts.shape[0]} labels for {data_vectors.shape[0]} feature rows"
    )
# Build the CNN, optimiser, loss and mini-batch loader.
embedding_dim = 26  # 25 for word embeddings + 1 for aspect indicator
if data_vectors.shape[-1] != embedding_dim:
    # Fail early with a clear message instead of a shape error inside the model.
    raise ValueError(
        f"Feature size {data_vectors.shape[-1]} does not match embedding_dim={embedding_dim}"
    )
num_filters = 88
# NOTE(review): filter size 3 appears twice — intentional, or meant [3, 4, 5]?
filter_sizes = [3, 4, 3]
output_dim = 1
dropout = 0.2
model = SentimentAspectCNN(embedding_dim, num_filters, filter_sizes, output_dim, dropout)
model.train()  # put modules such as dropout into training mode
optimizer = torch.optim.Adam(model.parameters())
# BCELoss expects inputs that are probabilities in [0, 1] and targets in [0, 1].
# NOTE(review): this assumes SentimentAspectCNN ends in a sigmoid and that
# to_vec yields binary labels. parse_content shifts raw labels by +1, which
# could produce a target of 2 — confirm.
criterion = torch.nn.BCELoss()
batch_size = 16
dataset = TensorDataset(data_vectors, sentiment_gts)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
# Optimise the model. Ctrl+C ends training early; the current weights are
# still saved below.
epochs = 100
model_path = "./data/model.pt"
try:
    for epoch in range(epochs):
        epoch_loss = 0.0
        # Dedicated batch names: rebinding `data_vectors` / `sentiment_gts`
        # here would silently replace the full dataset with the last batch.
        for batch_inputs, batch_targets in tqdm(dataloader):
            optimizer.zero_grad()
            outputs = model(batch_inputs)
            loss = criterion(outputs, batch_targets)
            loss.backward()
            optimizer.step()
            # BCELoss uses mean reduction by default, so this accumulates the
            # sum of per-batch mean losses.
            epoch_loss += loss.item()
        logger.info("Epoch {}: loss={}", epoch, epoch_loss)
except KeyboardInterrupt:
    logger.info("Training stopped by user")
# Save the model
torch.save(model.state_dict(), model_path)
logger.info("Model saved to {}", model_path)
# Sweep decision thresholds on the model output and report the best one.
# NOTE(review): this reuses the training DataLoader, so the accuracies are
# training accuracies and the chosen threshold may overfit. A held-out split
# would give a fairer estimate.
model.eval()

# Run inference once and reuse the outputs for every threshold. no_grad()
# skips building autograd graphs that are never backpropagated.
eval_outputs, eval_targets = [], []
with torch.no_grad():
    for batch_inputs, batch_targets in tqdm(dataloader):
        eval_outputs.append(model(batch_inputs))
        eval_targets.append(batch_targets)
eval_outputs = torch.cat(eval_outputs)
eval_targets = torch.cat(eval_targets)

best_threshold, best_accuracy = None, -1.0
# Integer steps divided by 10 avoid the float drift of np.arange(0.1, 1, 0.1)
# (e.g. 0.30000000000000004).
for threshold in np.arange(1, 10) / 10:
    logger.info("Testing with threshold={}", threshold)
    predictions = eval_outputs > threshold
    # Count per element (numel), so the ratio stays a true accuracy for any output_dim.
    accuracy = (predictions == eval_targets).sum().item() / eval_targets.numel()
    logger.info("Accuracy: {}", accuracy)
    if accuracy > best_accuracy:
        best_threshold, best_accuracy = threshold, accuracy
logger.info("Best threshold={} (accuracy={})", best_threshold, best_accuracy)