Training DNNs in parallel with TensorFlow
This notebook demonstrates how to split data into training and test sets and how to run DNN trainings in parallel.
The example dataset ./example1_data.zarr/ can be generated using the following Jupyter Notebook:
Import libraries
import xarray as xr
import motrainer
import dask_ml.model_selection as dcv
from motrainer.jackknife import JackknifeGPI
import numpy as np
2024-05-10 15:01:18.501959: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`. 2024-05-10 15:01:18.511230: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used. 2024-05-10 15:01:18.616501: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered 2024-05-10 15:01:18.616572: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered 2024-05-10 15:01:18.624625: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered 2024-05-10 15:01:18.644226: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used. 2024-05-10 15:01:18.646772: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations. To enable the following instructions: AVX2 AVX_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags. 2024-05-10 15:01:20.706954: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
Read data and split it into training and test datasets
# Read the data (lazily, from the example Zarr store)
zarr_file_path = "./example1_data.zarr"
ds = xr.open_zarr(zarr_file_path)

# Use less data to reduce training time: keep only the time steps on or
# after the cutoff date.
start_date = np.datetime64('2015-01-01')
ds = ds.isel(time=ds.time >= start_date)
def to_dataframe(ds):
    """Return ``ds.to_dask_dataframe()``.

    A named wrapper so the conversion can be handed to ``Bag.map``. ``ds`` is
    presumably an xarray Dataset; it only needs a ``to_dask_dataframe`` method.
    """
    converted = ds.to_dask_dataframe()
    return converted
def chunk(ds, chunks):
    """Return ``ds`` rechunked with ``chunks`` (e.g. ``{"space": 100}``).

    A named wrapper so rechunking can be handed to ``Bag.map``; ``ds`` only
    needs a ``chunk`` method.
    """
    rechunked = ds.chunk(chunks)
    return rechunked
# Split the dataset along the "space" dimension into a bag, rechunk every
# element to 100 points along "space", and convert it to a dask DataFrame.
bags = (
    motrainer.dataset_split(ds, "space")
    .map(chunk, {"space": 100})
    .map(to_dataframe)
)

# Split every bag element into a (train, test) pair: 33% test rows,
# shuffled, with a fixed seed for reproducibility.
test_size = 0.33
f_shuffle = True
train_test_bags = bags.map(
    dcv.train_test_split, test_size=test_size, shuffle=f_shuffle, random_state=1
)

# Element 0 of every pair is the training part, element 1 the test part.
train_bags, test_bags = (train_test_bags.pluck(i) for i in (0, 1))
Define training parameters
# JackKnife parameters, consumed by training_func / JackknifeGPI
JackKnife = dict(
    val_split_year=2017,
    output_list=['sig', 'slop', 'curv'],
    input_list=['TG1', 'TG2', 'TG3', 'WG1', 'WG2', 'WG3', 'BIOMA1', 'BIOMA2'],
    # Each grid point writes its results to <out_path>/gpi<N>
    out_path='./dnn_examples/results',
)

# Hyperparameter search space for the DNN.
# NOTE(review): the optimized hyperparameters printed later include values
# such as 19 or 25 dense nodes, so these lists appear to act as [low, high]
# bounds rather than discrete choices -- confirm in motrainer.
searching_space = dict(
    num_dense_layers=[1, 2],
    num_input_nodes=[2, 3],
    num_dense_nodes=[16, 32],
    learning_rate=[1e-3, 1e-2],
    activation=['relu'],
)

# Settings for the hyperparameter optimization.
# NOTE(review): n_calls / noise / kappa / x0 match scikit-optimize's
# gp_minimize argument names; presumably forwarded there -- confirm.
optimize_space = dict(
    best_loss=2,
    n_calls=11,
    epochs=2,
    noise=0.1,
    kappa=5,
    validation_split=0.2,
    # Initial point. Judging by the stored hyperparameters and the model
    # summary, the order appears to be: learning_rate, num_dense_layers,
    # num_input_nodes, num_dense_nodes, activation, and a sixth value
    # (presumably batch size) -- confirm in motrainer.
    x0=[1e-3, 1, 2, 16, 'relu', 32],
)
Run the training
In this example, we demonstrate how to run the training in parallel per grid point (partition) on a Dask cluster.
def training_func(gpi_num, df, JackKnife, searching_space, optimize_space):
    """Train a DNN for one grid point (partition) and export the best model.

    Parameters
    ----------
    gpi_num : int
        Zero-based partition index; results are written to
        ``<JackKnife['out_path']>/gpi{gpi_num + 1}``.
    df : lazy dataframe
        Training data of the partition. It must provide ``.compute()``
        returning a pandas DataFrame with a ``time`` column plus the columns
        named in ``JackKnife['input_list']`` / ``['output_list']``
        (presumably a dask DataFrame from ``train_bags``).
    JackKnife : dict
        Needs keys ``val_split_year``, ``input_list``, ``output_list`` and
        ``out_path``.
    searching_space, optimize_space : dict
        Passed unchanged to ``JackknifeGPI.train``.

    Returns
    -------
    tuple
        ``(gpi.apr_perf, gpi.post_perf)`` as reported by ``JackknifeGPI``
        (RMSE, since ``performance_method='rmse'``).
    """
    # Materialize the partition, drop rows with missing values, index by time.
    frame = df.compute().dropna().set_index("time")

    jk = JackKnife
    gpi = JackknifeGPI(
        frame,
        jk['val_split_year'],
        jk['input_list'],
        jk['output_list'],
        outpath=f"{jk['out_path']}/gpi{gpi_num + 1}",
    )
    gpi.train(
        searching_space=searching_space,
        optimize_space=optimize_space,
        normalize_method='standard',
        training_method='dnn',
        performance_method='rmse',
        verbose=2,
    )
    gpi.export_best()
    return gpi.apr_perf, gpi.post_perf
Here, `Client()` without arguments starts a local Dask cluster (a `LocalCluster`) on this machine and uses it to parallelize the tasks. If the training job runs on other infrastructure, other types of clusters can be set up instead. The choice of cluster does not change the syntax of the data split or the training jobs. For more information on the different Dask clusters, see the Dask documentation.
from dask.distributed import Client
from dask.distributed import wait

# Use the client as a context manager so it is closed even when one of the
# training tasks fails (client.gather re-raises the task's exception; a bare
# client.close() after it would then never run).
with Client() as client:
    # Submit one training task per training partition (grid point).
    futures = [
        client.submit(
            training_func, gpi_num, df, JackKnife, searching_space, optimize_space
        )
        for gpi_num, df in enumerate(train_bags)
    ]
    # Wait for all computations to finish
    wait(futures)
    # Collect the (a priori, post) performance tuple of every task
    results = client.gather(futures)

# print the results
from pathlib import Path
# NOTE(review): nothing below writes into ./results; the training outputs go
# to JackKnife['out_path'] -- confirm whether this directory is still needed.
Path('./results').mkdir(exist_ok=True)
for gpi_num, (apr_perf, post_perf) in enumerate(results):
    print(f"GPI {gpi_num + 1}")
    print(" aprior performance(RMSE):")
    print(apr_perf)
    print("post performance(RMSE):")
    print(post_perf)
    print("=========================================")
GPI 1 aprior performance(RMSE): [[0.03349] [0.07387] [0.22965]] post performance(RMSE): [[0.32417] [0.08714] [0.81276]] ========================================= GPI 2 aprior performance(RMSE): [[0.10507] [0.03492] [0.09597]] post performance(RMSE): [[0.0185 ] [0.77383] [0.20172]] ========================================= GPI 3 aprior performance(RMSE): [[0.32753] [0.36519] [0.26186]] post performance(RMSE): [[0.17438] [0.26897] [0.16316]] ========================================= GPI 4 aprior performance(RMSE): [[0.22702] [0.50275] [0.12853]] post performance(RMSE): [[0.48915] [0.08741] [0.4903 ]] ========================================= GPI 5 aprior performance(RMSE): [[0.34393] [0.1475 ] [0.25872]] post performance(RMSE): [[2.02418] [0.47361] [0.67175]] =========================================
To free up resources, shut down the cluster, for example by clicking SHUTDOWN in the Dask JupyterLab extension.
Inspect the best model file
import h5py
import tensorflow as tf
# Path of the best model saved for the first grid point (gpi1), presumably
# written by gpi.export_best() in training_func. The meaning of the "2015"
# suffix in the file name is not visible here -- confirm in motrainer.
best_model = "./dnn_examples/results/gpi1/best_optimized_model_2015.h5"
model = tf.keras.models.load_model(best_model)
# Print the layer structure and parameter counts of the loaded model
model.summary()
Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 dense (Dense)               (None, 2)                 18        
                                                                 
 layer_dense_1 (Dense)       (None, 19)                57        
                                                                 
 layer_dense_2 (Dense)       (None, 19)                380       
                                                                 
 dense_1 (Dense)             (None, 3)                 60        
                                                                 
=================================================================
Total params: 515 (2.01 KB)
Trainable params: 515 (2.01 KB)
Non-trainable params: 0 (0.00 Byte)
_________________________________________________________________
# Add more info to the model file, e.g. the path to the input data.
# Reuse zarr_file_path (set where the data is read) instead of repeating the
# literal, so the recorded path cannot drift from the data actually used.
with h5py.File(best_model, 'a') as f:
    f.attrs['input_file_path'] = zarr_file_path
# Inspect the hyperparameters, input_list and the path stored above
with h5py.File(best_model, 'r') as f:
    hyperparameters = f.attrs['hyperparameters']
    input_list = f.attrs['input_list']
    input_file_path = f.attrs['input_file_path']
# The 'hyperparameters' attribute is stored as text; eval() turns it back
# into a list of (score, [hyperparameter values]) pairs.
# NOTE(review): eval() executes whatever code is stored in the file -- only
# use it on model files you trust. ast.literal_eval is the safe alternative
# when the text contains plain literals only (it rejects reprs such as
# "np.float64(...)", which eval accepts here because numpy is imported as np).
print(eval(hyperparameters))
[(0.5740020871162415, [0.009248990393121144, 2, 2, 19, 'relu', 219]), (0.6381517052650452, [0.0034642993935407942, 2, 3, 19, 'relu', 89]), (0.7027523517608643, [0.00943332323418086, 2, 2, 25, 'relu', 306]), (0.7084829211235046, [0.006084925609967853, 1, 2, 27, 'relu', 123]), (0.7407179474830627, [0.002975229104493618, 1, 2, 31, 'relu', 122]), (0.7598923444747925, [0.0018202265892128732, 1, 2, 22, 'relu', 59]), (0.7676792144775391, [0.0012450228647412386, 1, 2, 18, 'relu', 330]), (0.7816160917282104, [0.0023138826233784302, 2, 2, 25, 'relu', 38]), (0.8949440121650696, [0.001439410716322745, 1, 3, 17, 'relu', 103]), (0.9745810627937317, [0.001, 1, 2, 16, 'relu', 32]), (1.0652109384536743, [0.001534802942029055, 2, 3, 16, 'relu', 295])]
# Input variable names stored in the model file's 'input_list' attribute
print(input_list)
['TG1' 'TG2' 'TG3' 'WG1' 'WG2' 'WG3' 'BIOMA1' 'BIOMA2']
# Data path stored in the model file's 'input_file_path' attribute
print(input_file_path)
./example1_data.zarr