sawaw/scripts/train_cnn.py

93 lines
3.0 KiB
Python
Raw Permalink Normal View History

2023-11-17 08:01:58 +01:00
import numpy as np
from sawaw import SAWAWEntry, SentimentResult
from pathlib import Path
import torch
from loguru import logger
from tqdm import tqdm
from methods.tokenizer import to_vec
from methods.model import SentimentAspectCNN
# Load the raw SemEval restaurant training split: consecutive groups of three
# lines (masked sentence, aspect term, polarity), parsed by parse_content.
path = Path("./data/restaurant_train.raw")
# Decode explicitly instead of relying on the platform's locale encoding, so the
# script reads the file identically on every OS.
# NOTE(review): assumes the .raw file is UTF-8 encoded — confirm with the dataset source.
content = path.read_text(encoding="utf-8")
def parse_content(content: str):
    """Turn SemEval-style raw text into a list of SAWAWEntry objects.

    The text is read as consecutive 3-line records, for example::

        I 'm partial to the $T$ .
        Gnocchi
        1

    Line 1 is the sentence with the aspect term masked as ``$T$``. Line 2 is the
    aspect term. Line 3 is an integer polarity. Trailing lines that do not make
    up a complete record are ignored.

    Each record becomes ``SAWAWEntry(sentence, [aspect], [SentimentResult(p + 1)])``,
    where ``sentence`` has ``$T$`` replaced by the aspect term.
    """
    # Sharing one iterator three times inside zip() yields the lines in
    # non-overlapping triples; zip stops at the first incomplete triple.
    line_iter = iter(content.split("\n"))
    parsed = []
    for masked_sentence, aspect, polarity in zip(line_iter, line_iter, line_iter):
        full_sentence = masked_sentence.replace("$T$", aspect)
        # Polarity is shifted by one before wrapping; presumably this maps
        # SemEval's -1/0/1 onto 0/1/2 — confirm against SentimentResult.
        label = SentimentResult(int(polarity) + 1)
        parsed.append(SAWAWEntry(full_sentence, [aspect], [label]))
    return parsed
entries = parse_content(content)
logger.info("Loaded {} entries from {}", len(entries), path)
# Vectorize every entry with to_vec (from methods.tokenizer), truncating or
# padding to max_len tokens.
max_len = 80
data_vectors, sentiment_gts = [], []
for entry in tqdm(entries):
    # The original shape note: data_vector is (num_of_aspect_words, max_len, 26)
    # and sentiment_gt is (num_of_aspect_words,).
    data_vector, sentiment_gt = to_vec(entry, max_len=max_len, should_return_sentiment=True)
    data_vectors.append(data_vector)
    # Flatten each entry's labels to 1-D float tensors, whatever to_vec returns
    # (scalar, list or tensor). They then concatenate row-for-row with
    # data_vectors below, even when an entry has several aspect words.
    sentiment_gts.append(torch.as_tensor(sentiment_gt, dtype=torch.float32).reshape(-1))
data_vectors = torch.cat(data_vectors, dim=0)
sentiment_gts = torch.cat(sentiment_gts, dim=0).unsqueeze(1)  # shape: (num_of_aspect_words, 1)
# ---- Hyper-parameters -------------------------------------------------------
embedding_dim = 26  # 25 for word embeddings + 1 for aspect indicator
num_filters = 88
filter_sizes = [3, 4, 3]
output_dim = 1
dropout = 0.2
batch_size = 16

# ---- Model, optimizer, loss -------------------------------------------------
model = SentimentAspectCNN(embedding_dim, num_filters, filter_sizes, output_dim, dropout)
# Training mode; this affects layers such as dropout.
model.train()
optimizer = torch.optim.Adam(params=model.parameters())
# NOTE(review): BCELoss expects inputs and targets in [0, 1]. Presumably
# SentimentAspectCNN ends in a sigmoid and to_vec yields binary labels — confirm.
criterion = torch.nn.BCELoss()

# ---- Data pipeline ----------------------------------------------------------
from torch.utils.data import TensorDataset, DataLoader

dataset = TensorDataset(data_vectors, sentiment_gts)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
# Train for up to `epochs` passes. Ctrl+C stops training early, and the model
# is saved either way.
try:
    epochs = 100
    for epoch in range(epochs):
        epoch_loss = 0.0
        # Use distinct names for the batch tensors. Reusing `data_vectors` /
        # `sentiment_gts` would overwrite the full-dataset tensors with the
        # last batch.
        for batch_vectors, batch_gts in tqdm(dataloader):
            optimizer.zero_grad()
            outputs = model(batch_vectors)
            loss = criterion(outputs, batch_gts)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
        # Sum of per-batch losses. With BCELoss's default 'mean' reduction, this
        # is a sum of batch means, not a per-sample average.
        logger.info("Epoch {}: loss={}", epoch, epoch_loss)
except KeyboardInterrupt:
    logger.info("Training stopped by user")
# Save the model
model_path = "./data/model.pt"
torch.save(model.state_dict(), model_path)
logger.info("Model saved to {}", model_path)
# Sweep decision thresholds 0.1 ... 0.9, log the accuracy of each, and keep
# the best one.
# NOTE(review): this reuses the training dataloader, so the threshold is fitted
# on training data. A held-out split would give an unbiased estimate.
model.eval()
best_threshold, best_accuracy = None, -1.0
# Inference only: skip building autograd graphs to save time and memory.
with torch.no_grad():
    # Integer steps give exact-looking thresholds (0.3, not 0.30000000000000004)
    # and avoid float-step arange endpoint pitfalls.
    for threshold in [step / 10 for step in range(1, 10)]:
        logger.info("Testing with threshold={}", threshold)
        num_correct = 0
        num_total = 0
        for batch_vectors, batch_gts in tqdm(dataloader):
            predictions = model(batch_vectors) > threshold
            num_correct += torch.sum(predictions.to(batch_gts.dtype) == batch_gts).item()
            num_total += len(batch_gts)
        accuracy = num_correct / num_total
        logger.info("Accuracy: {}", accuracy)
        if accuracy > best_accuracy:
            best_threshold, best_accuracy = threshold, accuracy
logger.info("Best threshold={} (accuracy={})", best_threshold, best_accuracy)