Example 1: Training a classifier based on image frames
======================================================

In this example we train a ResNet-18 model on images loaded from the dataset.
The model takes a single (one-channel) image as input and is trained in
batches of 16. The dataset is split into 80% train and 20% test by sequence,
so that all images from one sequence end up in the same set (train or test).
After every epoch the script prints the mean training loss and the accuracy
on the test set.

Run it with ``python <script>.py --dataroot /path/to/dataset``; the optional
flags ``--max_epochs``, ``--split_by`` and ``--num_workers`` override the
defaults.

The Code
--------

::

    import torch
    from torch import nn
    from torch import optim
    import torch.nn.functional as F
    import numpy as np
    import pandas as pd
    from tqdm import tqdm
    from torchvision import datasets, transforms, models
    import argparse
    from pathlib import Path
    from torch.utils.data import DataLoader

    from DigitalTyphoonDataloader.DigitalTyphoonDataset import DigitalTyphoonDataset


    def main(args):
        """Train a single-channel ResNet-18 grade classifier and report test accuracy.

        Args:
            args: parsed command-line arguments; reads ``dataroot``, ``split_by``,
                ``max_epochs`` and ``num_workers``.
        """
        ## Hyperparameters
        # Defined first: the transform function and the DataLoaders below read them.
        num_epochs = args.max_epochs
        batch_size = 16
        learning_rate = 0.001
        num_workers = args.num_workers
        standardize_range = (150, 350)  # pixel values are clipped to this range
        downsample_size = (224, 224)    # spatial size fed to the model
        num_classes = 7                 # number of output classes of the classifier

        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        ## Prepare the data
        # Specify the paths to the data
        data_path = args.dataroot
        images_path = data_path + '/image/'          # to the image folder
        metadata_path = data_path + '/metadata/'     # to the metadata folder
        json_path = data_path + '/metadata.json'     # to the metadata json

        # Define a filter to pass to the loader.
        # Only images for which the function returns True are included.
        # CrossEntropyLoss needs labels below num_classes, so grades >= 7 are dropped.
        def image_filter(image):
            return image.grade() < num_classes

        # Define a function to transform each image, to pass to the loader.
        # Crucially, this transform function is applied to each *image*, prior to any PyTorch processing.
        # So, image-by-image transforms (e.g. clipping, downsampling) can/should be done here.
        def transform_func(image_array):
            # Clip the pixel values between 150 and 350
            image_array = np.clip(image_array, standardize_range[0], standardize_range[1])
            # Rescale the clipped pixel values to [0, 1]
            image_array = (image_array - standardize_range[0]) / (standardize_range[1] - standardize_range[0])
            # Resize to downsample_size unless the image already has that size
            if tuple(image_array.shape) != tuple(downsample_size):
                # interpolate expects (N, C, H, W): add batch and channel dims, then drop them
                image_tensor = torch.as_tensor(image_array, dtype=torch.float32)[None, None]
                image_tensor = F.interpolate(image_tensor, size=downsample_size, mode='bilinear', align_corners=False)
                image_array = image_tensor[0, 0].numpy()
            return image_array

        # Load Dataset
        dataset = DigitalTyphoonDataset(str(images_path),
                                        str(metadata_path),
                                        str(json_path),
                                        'grade',  # the labels we'd like to retrieve from the dataset
                                        filter_func=image_filter,  # the filter function defined above
                                        transform_func=transform_func,  # the transform function defined above
                                        verbose=False)

        # Split the dataset into a training and test split (80% and 20% respectively).
        # With split_by='sequence' (the default), all images in one sequence land in the same bucket.
        train_set, test_set = dataset.random_split([0.8, 0.2], split_by=args.split_by)

        # Make PyTorch DataLoaders out of the returned sets. From here, all PyTorch functionality is retained.
        trainloader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=num_workers)
        testloader = DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=num_workers)

        ## Prepare the model
        # ResNet-18 with randomly initialized weights. To fine-tune pretrained weights instead,
        # pass e.g. weights=models.ResNet18_Weights.DEFAULT (torchvision >= 0.13).
        model = models.resnet18()
        # Modify the model to take single-channel images
        model.conv1 = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        # Modify the model to classify between num_classes classes
        model.fc = nn.Linear(in_features=512, out_features=num_classes, bias=True)
        model = model.to(device)

        # Loss and optimizer
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9)

        def to_model_inputs(images, labels):
            # Cast pixels to float32 and add the single channel dimension: (B, H, W) -> (B, 1, H, W).
            # Cast grades (labels) to int64 class indices, as CrossEntropyLoss requires.
            images = torch.as_tensor(images, dtype=torch.float32).unsqueeze(1).to(device)
            labels = torch.as_tensor(labels, dtype=torch.long).to(device)
            return images, labels

        ## Train the model
        for epoch in range(num_epochs):
            model.train()
            running_loss = 0.0
            # Each batch holds up to batch_size images and their labels, as an (images, labels) pair
            for images, labels in tqdm(trainloader, desc=f'Epoch {epoch + 1}/{num_epochs}'):
                images, labels = to_model_inputs(images, labels)

                optimizer.zero_grad()
                # Forward pass
                predictions = model(images)
                # Calculate the loss
                loss = criterion(predictions, labels)
                # Backward pass
                loss.backward()
                # Update weights
                optimizer.step()

                running_loss += loss.item()

            ## Evaluate on the held-out test set
            model.eval()
            correct, total = 0, 0
            with torch.no_grad():
                for images, labels in testloader:
                    images, labels = to_model_inputs(images, labels)
                    predictions = model(images)
                    correct += (predictions.argmax(dim=1) == labels).sum().item()
                    total += labels.size(0)

            train_loss = running_loss / max(len(trainloader), 1)
            test_accuracy = correct / total if total else float('nan')
            print(f'Epoch {epoch + 1}: train loss {train_loss:.4f}, test accuracy {test_accuracy:.2%}')


    if __name__ == '__main__':
        parser = argparse.ArgumentParser(description='Train a resnet model')
        parser.add_argument('--dataroot', required=True, type=str,
                            help='path to the root data directory')
        parser.add_argument('--split_by', default='sequence', type=str,
                            help="How to split the dataset (e.g. 'sequence' or 'frame')")
        # '--maxepochs' is kept as an alias for backward compatibility
        parser.add_argument('--max_epochs', '--maxepochs', dest='max_epochs', default=100, type=int,
                            help='number of training epochs')
        # Values > 0 may fail where workers are started with 'spawn' (e.g. Windows/macOS),
        # because the nested filter/transform functions cannot be pickled.
        parser.add_argument('--num_workers', default=0, type=int,
                            help='number of DataLoader worker processes')
        args = parser.parse_args()
        main(args)