Parallel DNN Training with JackknifeGPI
The data split section explains how to partition a dataset. After partitioning, the dataset is divided into two data bags: train_bags and test_bags.
This section demonstrates how to perform distributed DNN training using TensorFlow, with Jackknife resampling as the cross-validation method. In this approach, each year is iteratively left out as cross-validation data, and the best model is selected based on the lowest Root Mean Square Error (RMSE).
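The sketch below illustrates the leave-one-year-out idea in plain NumPy/pandas. It is not MOTrainer's internal implementation: fit_dnn is a hypothetical stand-in for whatever model-fitting routine is used, and target_column names the variable being predicted.
# Conceptual sketch of leave-one-year-out jackknife selection, assuming a pandas
# DataFrame `gpi_data` with a DatetimeIndex and a hypothetical `fit_dnn` callable
# that returns a fitted model with a .predict() method.
import numpy as np

def jackknife_by_year(gpi_data, fit_dnn, target_column):
    best_model, best_rmse = None, np.inf
    for year in sorted(gpi_data.index.year.unique()):
        val_df = gpi_data[gpi_data.index.year == year]    # left-out year
        train_df = gpi_data[gpi_data.index.year != year]  # remaining years
        model = fit_dnn(train_df)
        pred = np.asarray(model.predict(val_df.drop(columns=[target_column])))
        rmse = float(np.sqrt(np.mean((pred.ravel() - val_df[target_column].to_numpy()) ** 2)))
        if rmse < best_rmse:  # keep the model with the lowest RMSE
            best_model, best_rmse = model, rmse
    return best_model, best_rmse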
A more comprehensive example can be found in this Example Notebook.
Training a Single Grid Point
To train a DNN for a single grid cell, you can use the JackknifeGPI object:
from motrainer.jackknife import JackknifeGPI

# Initiate a Jackknife GPI from one grid cell.
# take() returns a tuple, so grab the first element before computing.
df = train_bags.take(1)[0]
gpi_data = df.compute()
gpi = JackknifeGPI(gpi_data, outpath='./results')

# Perform training and export the best model
results = gpi.train()
gpi.export_best()
The training results will be exported to the ./results path.
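To see what was written, you can list the contents of the output directory; the exact file layout depends on the MOTrainer version and the training configuration.
# List everything exported under ./results
from pathlib import Path

for path in sorted(Path("./results").rglob("*")):
    print(path)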
Training Multiple Grid Points
To train multiple grid points in parallel, you can define a training function as follows:
def training_func(gpi_num, df):
    # Compute the lazy DataFrame of this grid cell, then train and
    # export the best model to a per-grid-cell output directory.
    gpi_data = df.compute()
    gpi = JackknifeGPI(gpi_data,
                       outpath=f"results/gpi{gpi_num}")
    gpi.train()
    gpi.export_best()
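Before scaling out, it can be useful to smoke-test the function on a single grid cell, reusing the first element of train_bags as in the single-grid-point example above.
# Optional: run the training function once locally before submitting it
# to a cluster. take(1) returns a tuple, so grab the first element.
df_first = train_bags.take(1)[0]
training_func(0, df_first)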
Then, map the training function to each grid cell:
from dask.distributed import Client, wait
# Use client to parallelize the loop across workers
client = Client()
futures = [
client.submit(training_func, gpi_num, df) for gpi_num, df in enumerate(train_bags)
]
# Wait for all computations to finish
wait(futures)
# Get the results
results = client.gather(futures)
The above example uses a local Dask cluster (created by Client()) to parallelize the tasks. When executing training on an HPC system, we recommend using a Dask SLURM cluster for distributed training. For more information on different Dask clusters, please check the Dask Documentation.
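As a sketch, the same loop can be run on a SLURM-based system with dask-jobqueue; the queue name, resource values, and number of jobs below are placeholders that you should adapt to your cluster.
# Sketch of distributed training with a Dask SLURM cluster (dask-jobqueue).
from dask.distributed import Client, wait
from dask_jobqueue import SLURMCluster

cluster = SLURMCluster(queue="normal", cores=4, memory="16GB", walltime="02:00:00")
cluster.scale(jobs=10)  # launch 10 SLURM jobs as Dask workers
client = Client(cluster)

futures = [
    client.submit(training_func, gpi_num, df) for gpi_num, df in enumerate(train_bags)
]
wait(futures)
results = client.gather(futures)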
You can also submit training jobs directly as SLURM jobs, instead of using a Dask SLURM cluster. You can find an example of using SLURM here.