Load a Dataset and Finetune a Model

Load a dataset and train your model

Once you have labeled your data on the Memri platform, you can use it to train your model in this Google Colab notebook.

In this guide you will:

  1. Load a labeled dataset from the POD
  2. Fine-tune a distilRoBERTa text classifier on the labeled dataset
  3. Upload a trained model to use in a plugin for a data app
  • If you are unfamiliar with Google Colab notebooks, have a look at this quick intro.
  • Make sure to run the cells below one by one, in the correct order, to avoid errors!
  • This guide helps you connect your own personal data from your Memri POD. Alternatively, you can use the Tweet eval emoji dataset, which is available from 🤗 Hugging Face.
  • If you don’t wish to use your personal data, or you don’t want to spend time training a model, you can simply use our sentiment-plugin, which uses a pre-trained model from 🤗 Hugging Face. Just paste the plugin repo address at the project set-up step on the Memri platform, and skip this process.

Setup

from IPython.display import clear_output
!pip install pandas transformers torch git+https://gitlab.memri.io/memri/pymemri.git@v0.0.23
clear_output()
print("Installed")
  1. Import the libraries needed to train your model
  • Make sure to run the installation step above first to avoid errors!
import os
import random
import textwrap

import pandas as pd
import torch
import transformers
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from transformers.utils import logging

from pymemri.data.itembase import Edge, Item
from pymemri.data.schema import Dataset, Message, CategoricalLabel
from pymemri.data.oauth import OauthFlow
from pymemri.data.loader import write_model_to_package_registry
from pymemri.pod.client import PodClient
from getpass import getpass
transformers.utils.logging.set_verbosity_error()
os.environ["WANDB_DISABLED"] = "true"

1. Load your dataset from the POD

  1. Run the cell
  2. Copy your Dataset Name, Login Key and Password Key from your app.memri.io screen, and paste them below as prompted to load your connected dataset from your POD.
### *Define your pod url here*, this is the one for dev.app.memri.io ####
pod_url = "https://uat.pod.memri.io"
### *Define your dataset here* ####
dataset_name = input("dataset_name:") if "dataset_name" not in locals() else dataset_name
### *Define your login key here* ####
owner_key = getpass("owner key:") if "owner_key" not in locals() else owner_key
### *Define your password key here* ####
database_key = getpass("database_key:") if "database_key" not in locals() else database_key
  1. Connect your POD to load your data
# Connect to pod
client = PodClient(
    url=pod_url,
    owner_key=owner_key,
    database_key=database_key,
)
client.add_to_schema(CategoricalLabel, Message, Dataset, OauthFlow);
  1. Download and inspect the dataset
  • All entries in the dataset can be found via the Dataset.entry edge
dataset = client.get_dataset(dataset_name)

num_entries = len(dataset.entry)
print(f"number of items in the dataset: {num_entries}")
  1. Export the dataset to a format compatible with Python and inspect in a table
  • In pymemri, the Dataset class can format your dataset to different datatypes using the Dataset.to method; here we will use Pandas.
  • The columns of the dataset are inferred automatically. If you want to use custom columns, you can use the columns argument. See the dataset documentation for more info.
data = dataset.to("pandas")
data.head()
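Before training, it can help to check how the labels are distributed and whether any rows are incomplete. The following is a minimal sketch that uses a small hypothetical DataFrame in place of the one returned by dataset.to("pandas"). The annotation.labelValue column name is taken from the training code in this guide; the id and content column names and all values are assumptions for illustration.

```python
import pandas as pd

# Hypothetical stand-in for the exported DataFrame; the real one comes from dataset.to("pandas").
example = pd.DataFrame(
    {
        "id": ["a1", "b2", "c3", "d4"],
        "content": ["great!", "awful", "love it", None],
        "annotation.labelValue": ["positive", "negative", "positive", "negative"],
    }
)

# Count examples per label to spot class imbalance.
label_counts = example["annotation.labelValue"].value_counts()
print(label_counts)

# Count rows with a missing text or label, which would cause problems during tokenization.
num_incomplete = int(example.isna().any(axis=1).sum())
print(f"rows with missing values: {num_incomplete}")
```

If the counts are heavily skewed or rows are incomplete, consider labeling more data or dropping the incomplete rows before you continue.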

2. Fine-tune a model

  1. Configure the distilRoBERTa model on your dataset

The transformers library contains all the code needed for training. You only need to define a torch Dataset that holds your data and handles tokenization.

# Hyperparameters
model_name = "distilroberta-base"
batch_size = 32
learning_rate = 1e-3

class TransformerDataset(torch.utils.data.Dataset):
    def __init__(self, data: pd.DataFrame, tokenizer: transformers.PreTrainedTokenizerBase):
        self.data = data
        self.label2idx, self.idx2label = self.get_label_map()
        self.num_labels = len(self.label2idx)
        self.tokenizer = tokenizer
        
    def tokenize(self, message, label=None):
        tokenized = self.tokenizer(message, padding="max_length", truncation=True)
        if label is not None:
            tokenized["label"] = self.label2idx[label]
        return tokenized

    def get_label_map(self):
        unique_labels = self.data["annotation.labelValue"].unique()
        return {l: i for i, l in enumerate(unique_labels)}, {i: l for i, l in enumerate(unique_labels)}
        
    def __len__(self):
        return len(self.data)
        
    def __getitem__(self, idx):
        # Get the row from self.data, and skip the first column (id).
        return self.tokenize(*self.data.iloc[idx][1:])

tokenizer = AutoTokenizer.from_pretrained(model_name)
dataset = TransformerDataset(data, tokenizer)
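To make the label mapping and the row unpacking in TransformerDataset more concrete, here is a small sketch on hypothetical data. It assumes the exported table has exactly three columns in the order id, message, label, which is what the __getitem__ method relies on. The content column name and the example values are assumptions, and the tokenizer is left out so the sketch runs without downloading a model.

```python
import pandas as pd

# Hypothetical stand-in data with the (id, message, label) column order the class expects.
example = pd.DataFrame(
    {
        "id": ["a1", "b2", "c3"],
        "content": ["so happy today", "this is terrible", "best day ever"],
        "annotation.labelValue": ["joy", "anger", "joy"],
    }
)

# Labels are numbered in order of first appearance, as in get_label_map.
unique_labels = example["annotation.labelValue"].unique()
label2idx = {label: idx for idx, label in enumerate(unique_labels)}
idx2label = {idx: label for idx, label in enumerate(unique_labels)}

# Skip the first column (id) and unpack the remaining two values.
message, label = example.iloc[1].iloc[1:]
print(message, "->", label, "->", label2idx[label])
```

If your dataset has extra columns or a different column order, the unpacking step fails or mixes up text and labels, so check the table from the previous step first.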
  1. Train and fine-tune the model
  • We use the Trainer class, as it handles training, monitoring, and integration with Weights & Biases (disabled in the setup cell through the WANDB_DISABLED environment variable)
# Load model
model = AutoModelForSequenceClassification.from_pretrained(
    model_name,
    num_labels=dataset.num_labels,
    id2label=dataset.idx2label
)

# To increase training speed, we will freeze all layers except the classifier head.
for param in model.base_model.parameters():
    param.requires_grad = False
training_args = transformers.TrainingArguments(
    "twitter-emoji-trainer",
    learning_rate=learning_rate,
    per_device_train_batch_size=batch_size,
    logging_steps=1,
    optim="adamw_torch"
)

trainer = transformers.Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset
)
logging.set_verbosity(40)  # 40 corresponds to the ERROR log level
trainer.train()
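The freezing step above means that only the classifier head receives gradient updates. One way to confirm this is to count trainable parameters. The sketch below does so on a tiny hypothetical stand-in module (TinyClassifier is not part of transformers or pymemri). The count_parameters helper is also an illustrative name, and it should work the same way when passed the real model, since that model is a torch module too.

```python
import torch
from torch import nn


class TinyClassifier(nn.Module):
    """Hypothetical stand-in: a small 'base' encoder plus a classifier head."""

    def __init__(self, hidden: int = 8, num_labels: int = 3):
        super().__init__()
        self.base_model = nn.Linear(hidden, hidden)
        self.classifier = nn.Linear(hidden, num_labels)

    def forward(self, x):
        return self.classifier(torch.relu(self.base_model(x)))


def count_parameters(module: nn.Module):
    """Return (trainable, total) parameter counts for a torch module."""
    total = sum(p.numel() for p in module.parameters())
    trainable = sum(p.numel() for p in module.parameters() if p.requires_grad)
    return trainable, total


tiny = TinyClassifier()

# Freeze the base, mirroring the loop used for the real model above.
for param in tiny.base_model.parameters():
    param.requires_grad = False

trainable, total = count_parameters(tiny)
print(f"trainable parameters: {trainable} / {total}")
```

On the real model, a trainable count that is much smaller than the total indicates that the base layers are frozen as intended.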

3. Upload your model to a data app plugin

Now that your model is trained, it will be uploaded to your new GitLab project.

  1. Run the cell
  2. Copy and paste the GitLab project name from your screen on app.memri.io
  • To avoid errors, make sure your GitLab project does not have any full stops in the name/URL
project_name = input("project name:") if "project_name" not in locals() else project_name
write_model_to_package_registry(model, project_name=project_name, client=client)
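Since full stops in the project name lead to errors, you may want to check the name before uploading. The helper below is a hypothetical sketch and not part of pymemri. It only checks for an empty name and for full stops, and it assumes that no other characters are restricted.

```python
def validate_project_name(name: str) -> str:
    """Return the trimmed project name, or raise ValueError if it is unusable."""
    cleaned = name.strip()
    if not cleaned:
        raise ValueError("project name must not be empty")
    if "." in cleaned:
        raise ValueError(
            f"project name {cleaned!r} contains a full stop; rename the project first"
        )
    return cleaned


print(validate_project_name("  my-model-plugin "))
```

You could call this on the value of project_name before passing it to write_model_to_package_registry.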

That’s it! 🎉

You have trained an ML model and made it accessible via the package registry, ready to be used in your data app.

Check out the next step to see how to build a plugin and deploy a data app.