93 lines
3.0 KiB
Python
93 lines
3.0 KiB
Python
import numpy as np
|
|
from sawaw import SAWAWEntry, SentimentResult
|
|
from pathlib import Path
|
|
import torch
|
|
from loguru import logger
|
|
from tqdm import tqdm
|
|
|
|
from methods.tokenizer import to_vec
|
|
from methods.model import SentimentAspectCNN
|
|
# Load the raw SemEval restaurant training split.
# Decode explicitly instead of relying on the platform's locale encoding
# (which differs between Linux and Windows).
# NOTE(review): assumes the .raw file is UTF-8 encoded — confirm.
path = Path("./data/restaurant_train.raw")
content = path.read_text(encoding="utf-8")
|
|
|
|
def parse_content(content: str):
    """Parse SemEval ``.raw`` text into a list of ``SAWAWEntry`` objects.

    The text is a flat sequence of three-line records, for example::

        I 'm partial to the $T$ .
        Gnocchi
        1

    Each record has three lines:

    * a sentence with the aspect term masked as ``$T$``;
    * the aspect term itself;
    * an integer polarity label.

    Each record becomes one ``SAWAWEntry``. Its sentence has every ``$T$``
    replaced by the aspect term. Its single ``SentimentResult`` is built
    from ``label + 1``, so e.g. -1/0/1 become 0/1/2. A trailing incomplete
    record, such as the empty string left by a final newline, is ignored.

    Args:
        content: Full text of a ``.raw`` file.

    Returns:
        A list with one ``SAWAWEntry`` per complete record.

    Raises:
        ValueError: If a polarity line is not an integer. The message
            includes its 1-based line number.
    """
    # Split on "\n" only. str.splitlines() would also split on rarer
    # separators such as \x0b and \u2028, which could shift the 3-line
    # grouping. Also drop a trailing "\r" so CRLF files parse like LF files
    # instead of leaking "\r" into sentences and aspect words.
    lines = [line.rstrip("\r") for line in content.split("\n")]

    entries = []
    # `start` walks record boundaries. Stopping below len(lines) - 2
    # guarantees all three lines of the record exist.
    for start in range(0, len(lines) - 2, 3):
        sentence, aspect_word, sentiment = lines[start:start + 3]
        try:
            label = int(sentiment)
        except ValueError as err:
            raise ValueError(
                f"Line {start + 3}: expected an integer polarity, got {sentiment!r}"
            ) from err
        sentence_replaced = sentence.replace("$T$", aspect_word)
        entries.append(
            SAWAWEntry(sentence_replaced, [aspect_word], [SentimentResult(label + 1)])
        )
    return entries
|
|
|
|
# Parse the raw text into entries (one per sentence/aspect record).
entries = parse_content(content)
logger.info("Loaded {} entries from {}", len(entries), path)
if not entries:
    # torch.cat below would otherwise fail with an opaque "non-empty list" error.
    raise ValueError(f"No entries parsed from {path}")

# Vectorize every entry: a fixed-length token sequence for the CNN plus the
# matching sentiment target(s).
max_len = 80
data_vectors, sentiment_gts = [], []
for entry in tqdm(entries):
    # Author's shape note: (num_of_aspect_words, 80, 26); (num_of_aspect_words, )
    data_vector, sentiment_gt = to_vec(entry, max_len=max_len, should_return_sentiment=True)
    data_vectors.append(data_vector)
    sentiment_gts.append(sentiment_gt)

data_vectors = torch.cat(data_vectors, dim=0)
# Flatten and concatenate the targets exactly as the inputs are concatenated.
# Row i of sentiment_gts then belongs to row i of data_vectors even when an
# entry yields several aspect words. as_tensor accepts scalars, lists, arrays
# and tensors alike; float32 matches the legacy torch.Tensor constructor.
sentiment_gts = torch.cat(
    [torch.as_tensor(gt, dtype=torch.float32).reshape(-1) for gt in sentiment_gts]
).unsqueeze(1)  # shape: (num_of_aspect_words, 1)
|
|
|
|
# ---------------------------------------------------------------------------
# Model, optimizer and data pipeline
# ---------------------------------------------------------------------------
from torch.utils.data import TensorDataset, DataLoader

# Per-token feature size: 25 word-embedding features + 1 aspect-indicator flag.
embedding_dim = 26
num_filters = 88
# NOTE(review): filter width 3 is listed twice — confirm this is intended.
filter_sizes = [3, 4, 3]
output_dim = 1  # one output unit per sample
dropout = 0.2
batch_size = 16

model = SentimentAspectCNN(
    embedding_dim,
    num_filters,
    filter_sizes,
    output_dim,
    dropout,
)
model.train()

optimizer = torch.optim.Adam(model.parameters())
# NOTE(review): BCELoss needs model outputs in [0, 1] (presumably the CNN ends
# in a sigmoid) and targets in [0, 1]. parse_content builds labels as
# polarity + 1, so confirm to_vec maps them into that range.
criterion = torch.nn.BCELoss()

# Inputs and targets are paired row-by-row and reshuffled every epoch.
dataloader = DataLoader(
    TensorDataset(data_vectors, sentiment_gts),
    batch_size=batch_size,
    shuffle=True,
)
dataset = dataloader.dataset
|
|
|
|
# Train for `epochs` passes over the data. Ctrl-C ends training early. In
# either case execution falls through to the save step, so partial progress
# is kept.
epochs = 100
try:
    for epoch in range(epochs):
        epoch_loss = 0.0
        # Batch-local names avoid clobbering the module-level dataset tensors
        # `data_vectors` / `sentiment_gts`. Reusing those names left them
        # holding only the last mini-batch after training.
        for batch_inputs, batch_targets in tqdm(dataloader):
            optimizer.zero_grad()
            outputs = model(batch_inputs)
            loss = criterion(outputs, batch_targets)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
        # Logged value is the sum of the per-batch losses for this epoch.
        logger.info("Epoch {}: loss={}", epoch, epoch_loss)
except KeyboardInterrupt:
    logger.info("Training stopped by user")
|
|
# Persist only the learned parameters (state_dict). Reloading requires
# constructing SentimentAspectCNN with the same hyper-parameters as above.
model_path = "./data/model.pt"
torch.save(model.state_dict(), model_path)
logger.info("Model saved to {}", model_path)
|
|
|
|
# Sweep decision thresholds and report the one with the highest accuracy.
# NOTE(review): this evaluates on the *training* dataloader, so the chosen
# threshold and its accuracy are optimistic. Prefer a held-out split.
model.eval()

# Run the network once and reuse its outputs for every threshold.
# no_grad skips building autograd graphs, which evaluation never needs.
eval_outputs, eval_targets = [], []
with torch.no_grad():
    for batch_inputs, batch_targets in tqdm(dataloader):
        eval_outputs.append(model(batch_inputs))
        eval_targets.append(batch_targets)

if eval_outputs:
    eval_outputs = torch.cat(eval_outputs)
    eval_targets = torch.cat(eval_targets)
    num_total = len(eval_targets)

    best_threshold, best_accuracy = None, -1.0
    # Exact tenths 0.1 ... 0.9. This avoids np.arange float drift such as
    # 0.30000000000000004.
    for threshold in np.arange(1, 10) / 10:
        logger.info("Testing with threshold={}", threshold)
        predictions = eval_outputs > threshold
        num_correct = torch.sum(predictions == eval_targets).item()
        accuracy = num_correct / num_total
        logger.info("Accuracy: {}", accuracy)
        # Strict ">" keeps the lowest threshold on ties.
        if accuracy > best_accuracy:
            best_threshold, best_accuracy = threshold, accuracy
    logger.info("Best threshold={} (accuracy={})", best_threshold, best_accuracy)
else:
    logger.warning("No evaluation data; skipping threshold search")
|