CNN Module
- class cnn.CNN(train_data, validation_data, test_data, batch_size)
Bases: object
CNN Trainer class.
This class is responsible for training a CNN model.
- Parameters:
train_data (torchvision.datasets.ImageFolder) – Training data.
validation_data (torchvision.datasets.ImageFolder) – Validation data.
test_data (torchvision.datasets.ImageFolder) – Test data.
batch_size (int) – Batch size.
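The constructor takes three datasets and a batch size, which suggests the datasets are wrapped in `DataLoader`s somewhere inside the class. The sketch below shows one plausible way to do that. The helper name `make_loaders` is hypothetical, and shuffling only the training loader is an assumption rather than documented behavior.

```python
from torch.utils.data import DataLoader, Dataset


def make_loaders(train_data: Dataset, validation_data: Dataset,
                 test_data: Dataset, batch_size: int):
    """Hypothetical helper: wrap each dataset in a DataLoader.

    Assumption: only the training data is shuffled; validation and test
    data keep their original order so that evaluation is reproducible.
    """
    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
    validation_loader = DataLoader(validation_data, batch_size=batch_size)
    test_loader = DataLoader(test_data, batch_size=batch_size)
    return train_loader, validation_loader, test_loader
```

With `torchvision.datasets.ImageFolder`, each dataset would typically be built from a directory with one subfolder per class and a transform such as `transforms.ToTensor()`; any directory layout used with it is up to the caller.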
- create_and_train_cnn(model_name, num_epochs, learning_rate, weight_decay, replications)
Create and train a CNN model.
- Parameters:
model_name (str) – Model name to be trained.
num_epochs (int) – Number of epochs to be trained.
learning_rate (float) – Learning rate used during training.
weight_decay (float) – Weight decay used during training.
replications (int) – Number of times each model is trained (replications).
- Returns:
- A dict with the following keys:
'result_name': (str) Result name.
'acc_avg': (float) Average accuracy.
'iter_acc_max': (int) Iteration with the maximum accuracy.
'duration': (float) Training duration.
- Return type:
(dict)
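Because the method trains each model several times, the returned dict is presumably an aggregate over the replications. The following sketch shows one way such a summary could be computed. The helper `summarize_replications` is hypothetical, and so are its interpretations: that `iter_acc_max` is the 1-based index of the best replication and that `duration` is the total time summed over replications. The docs confirm neither.

```python
import statistics


def summarize_replications(result_name, accuracies, durations):
    """Hypothetical aggregation of per-replication results.

    Assumptions (not confirmed by the CNN docs):
    - 'iter_acc_max' is the 1-based index of the replication with the
      highest accuracy;
    - 'duration' is the sum of the per-replication training times.
    """
    # Index of the replication that reached the highest accuracy.
    best_index = max(range(len(accuracies)), key=lambda i: accuracies[i])
    return {
        "result_name": result_name,
        "acc_avg": statistics.mean(accuracies),
        "iter_acc_max": best_index + 1,
        "duration": sum(durations),
    }
```

For example, accuracies of `[0.8, 0.9, 0.85]` give an average of about 0.85 and, under the assumption above, an `iter_acc_max` of 2.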
- create_criterion()
Create a loss criterion.
- Parameters:
None
- Returns:
Cross-entropy loss object.
- Return type:
(object)
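In PyTorch, a cross-entropy loss object is normally an instance of `torch.nn.CrossEntropyLoss`. The sketch below assumes the class uses it with default arguments, which the docs do not state explicitly.

```python
import torch.nn as nn


def create_criterion() -> nn.Module:
    # Cross-entropy loss for multi-class classification. It expects raw
    # logits of shape (N, C) and integer class labels of shape (N,);
    # the softmax is applied internally, so the model should not add one.
    return nn.CrossEntropyLoss()
```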
- create_model(model_name)
Create the CNN model to be trained.
Note
At the moment, the available models are: [VGG11, Alexnet, MobilenetV3Large].
- Parameters:
model_name (str) – CNN model name.
- Returns:
The selected CNN model.
- Return type:
(function)
- create_optimizer(model, learning_rate, weight_decay)
Create an optimizer.
- Parameters:
model (function) – CNN model.
learning_rate (float) – Learning rate.
weight_decay (float) – Weight decay.
- Returns:
Optimizer object.
- Return type:
(object)
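The docs do not say which optimizer is created. The sketch below uses `torch.optim.Adam` purely as an assumption, to show how `learning_rate` and `weight_decay` map onto a standard PyTorch optimizer.

```python
import torch
import torch.nn as nn


def create_optimizer(model: nn.Module, learning_rate: float,
                     weight_decay: float) -> torch.optim.Optimizer:
    # Adam is an assumption: the CNN docs do not name the optimizer.
    # weight_decay adds an L2 penalty term (weight_decay * param) to each
    # parameter's gradient before the update.
    return torch.optim.Adam(model.parameters(), lr=learning_rate,
                            weight_decay=weight_decay)
```

Swapping in `torch.optim.SGD` (for example, with a `momentum` argument) would leave the rest of the training code unchanged, because every optimizer exposes the same `zero_grad()` and `step()` interface.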
- evaluate_model(model, loader)
Evaluate a model.
- Parameters:
model (function) – Model function.
loader (DataLoader) – Data loader.
- Returns:
Accuracy of the (trained) model on the data provided by loader.
- Return type:
(float)
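A typical PyTorch accuracy computation over a `DataLoader` looks like the sketch below. Two choices here are assumptions: the accuracy is returned as a fraction in [0, 1] rather than a percentage, and everything runs on the CPU with no device handling.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader


def evaluate_model(model: nn.Module, loader: DataLoader) -> float:
    """Return the fraction of samples in loader that model classifies correctly."""
    model.eval()  # disable dropout and use running BatchNorm statistics
    correct, total = 0, 0
    with torch.no_grad():  # gradients are not needed for evaluation
        for images, labels in loader:
            # The predicted class is the index of the largest logit.
            predictions = model(images).argmax(dim=1)
            correct += (predictions == labels).sum().item()
            total += labels.size(0)
    return correct / total
```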
- train_epoch(model, trainLoader, optimizer, criterion)
Train an epoch.
- Parameters:
model (function) – Model function.
trainLoader (DataLoader) – Training data loader.
optimizer (object) – Optimizer object.
criterion (object) – Cross-entropy loss (CEL) object.
- Returns:
Mean of the losses over the epoch.
- Return type:
(float)
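One epoch of training is the standard PyTorch loop: zero the gradients, run a forward pass, backpropagate, then step the optimizer. The sketch below follows that loop. It assumes the returned "mean of losses" is the unweighted mean of the per-batch losses, which the docs do not specify.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader


def train_epoch(model: nn.Module, trainLoader: DataLoader,
                optimizer: torch.optim.Optimizer, criterion: nn.Module) -> float:
    """Run one pass over trainLoader and return the mean batch loss."""
    model.train()  # enable dropout and BatchNorm statistic updates
    losses = []
    for images, labels in trainLoader:
        optimizer.zero_grad()                    # clear gradients from the last step
        loss = criterion(model(images), labels)  # forward pass and loss
        loss.backward()                          # backpropagate
        optimizer.step()                         # update the weights
        losses.append(loss.item())
    # Assumption: unweighted mean over batches, so a smaller final batch
    # counts as much as a full one.
    return sum(losses) / len(losses)
```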
- train_model(model, train_loader, optimizer, criterion, model_name, num_epochs, learning_rate, weight_decay, replication)
Train a CNN model.
Train a CNN model and save it as a PTH file in the 'models' directory.
- Parameters:
model (function) – Model function.
train_loader (DataLoader) – Training data loader.
optimizer (object) – Optimizer object.
criterion (object) – Cross-entropy loss (CEL) object.
model_name (str) – Model name.
num_epochs (int) – Number of epochs.
learning_rate (float) – Learning rate.
weight_decay (float) – Weight decay.
replication (int) – Replication number.
- Returns:
None
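The signature receives an optimizer that has already been built, but it also takes `learning_rate`, `weight_decay` and `replication`. That suggests these values are used mainly to name the saved file. The sketch below trains for `num_epochs` epochs and then saves the `state_dict` with `torch.save`. The file-name scheme and the `models_dir` parameter are hypothetical, and `models_dir` defaults to the documented 'models' directory.

```python
import os

import torch


def train_epoch(model, loader, optimizer, criterion):
    # One pass over the data; returns the mean batch loss.
    model.train()
    losses = []
    for images, labels in loader:
        optimizer.zero_grad()
        loss = criterion(model(images), labels)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
    return sum(losses) / len(losses)


def train_model(model, train_loader, optimizer, criterion, model_name,
                num_epochs, learning_rate, weight_decay, replication,
                models_dir="models"):
    """Train for num_epochs epochs, then save the weights as a .pth file."""
    for _ in range(num_epochs):
        train_epoch(model, train_loader, optimizer, criterion)
    os.makedirs(models_dir, exist_ok=True)
    # Hypothetical file-name scheme that encodes the hyperparameters and the
    # replication number so that separate runs do not overwrite each other.
    filename = (f"{model_name}_ep{num_epochs}_lr{learning_rate}"
                f"_wd{weight_decay}_rep{replication}.pth")
    # Saving the state_dict (not the whole module) is the usual PyTorch practice.
    torch.save(model.state_dict(), os.path.join(models_dir, filename))
```

To reload a saved file, rebuild the same architecture and call `model.load_state_dict(torch.load(path))`.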