Main Module

cnn_trainer_local.define_transforms(height, width)

Define transforms for the images.

Parameters:
  • height (int) – Image height.

  • width (int) – Image width.

Returns:

Data transforms.

Return type:

(dict)

cnn_trainer_local.main()

Main function.

cnn_trainer_local.read_images(data_transforms, train_path, val_path, test_path)

Read images (train, validation and test) from their respective directories.

Parameters:
  • data_transforms (dict) – Transforms to be applied to the images.

  • train_path (str) – Path to train images directory.

  • val_path (str) – Path to validation images directory.

  • test_path (str) – Path to test images directory.

Returns:

A dict with the following keys:
  • 'train_data': (datasets.ImageFolder) Train data.

  • 'validation_data': (datasets.ImageFolder) Validation data.

  • 'test_data': (datasets.ImageFolder) Test data.

Return type:

(dict)
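Because the returned values are datasets.ImageFolder objects, each of the three directories is expected to contain one subdirectory per class, and torchvision derives the class list from the sorted subdirectory names. The sketch below builds such a layout in a temporary directory using only the standard library. The helper names, the split directory names ('train', 'val', 'test') and the class names ('cat', 'dog') are hypothetical and chosen only for illustration.

```python
import tempfile
from pathlib import Path


def make_image_folder_layout(root, class_names, splits=("train", "val", "test")):
    """Create <root>/<split>/<class>/ directories and return the split paths."""
    root = Path(root)
    paths = {}
    for split in splits:
        for name in class_names:
            (root / split / name).mkdir(parents=True, exist_ok=True)
        paths[split] = str(root / split)
    return paths


def list_classes(split_path):
    """Mimic how ImageFolder orders classes: sorted subdirectory names."""
    return sorted(p.name for p in Path(split_path).iterdir() if p.is_dir())


with tempfile.TemporaryDirectory() as tmp:
    paths = make_image_folder_layout(tmp, ["dog", "cat"])
    classes = list_classes(paths["train"])
```

Once the class subdirectories hold image files, paths["train"], paths["val"] and paths["test"] would be the kind of values passed as train_path, val_path and test_path.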