{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "This notebook demonstrates how to run parallel machine learning training using [`dask-ml`](https://ml.dask.org/) and motrainer.\n", "\n", "The example dataset `./example1_data.zarr/` can be generated with the following Jupyter Notebook:\n", "- [Convert a nested DataFrame to a Dataset](https://vegewaterdynamics.github.io/motrainer/notebooks/example_read_from_one_df/)" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import motrainer\n", "import numpy as np\n", "import xarray as xr" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Load data" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
<xarray.Dataset>\n",
       "Dimensions:    (space: 5, time: 8506)\n",
       "Coordinates:\n",
       "    latitude   (space) float64 dask.array<chunksize=(5,), meta=np.ndarray>\n",
       "    longitude  (space) float64 dask.array<chunksize=(5,), meta=np.ndarray>\n",
       "  * time       (time) datetime64[ns] 2007-01-02 ... 2020-01-01T01:00:00\n",
       "Dimensions without coordinates: space\n",
       "Data variables:\n",
       "    BIOMA1     (space, time) float64 dask.array<chunksize=(3, 8506), meta=np.ndarray>\n",
       "    BIOMA2     (space, time) float64 dask.array<chunksize=(3, 8506), meta=np.ndarray>\n",
       "    TG1        (space, time) float64 dask.array<chunksize=(3, 8506), meta=np.ndarray>\n",
       "    TG2        (space, time) float64 dask.array<chunksize=(3, 8506), meta=np.ndarray>\n",
       "    TG3        (space, time) float64 dask.array<chunksize=(3, 8506), meta=np.ndarray>\n",
       "    WG1        (space, time) float64 dask.array<chunksize=(3, 8506), meta=np.ndarray>\n",
       "    WG2        (space, time) float64 dask.array<chunksize=(3, 8506), meta=np.ndarray>\n",
       "    WG3        (space, time) float64 dask.array<chunksize=(3, 8506), meta=np.ndarray>\n",
       "    curv       (space, time) float64 dask.array<chunksize=(3, 8506), meta=np.ndarray>\n",
       "    sig        (space, time) float64 dask.array<chunksize=(3, 8506), meta=np.ndarray>\n",
       "    slop       (space, time) float64 dask.array<chunksize=(3, 8506), meta=np.ndarray>\n",
       "Attributes:\n",
       "    license:  data license\n",
       "    source:   data source
" ], "text/plain": [ "\n", "Dimensions: (space: 5, time: 8506)\n", "Coordinates:\n", " latitude (space) float64 dask.array\n", " longitude (space) float64 dask.array\n", " * time (time) datetime64[ns] 2007-01-02 ... 2020-01-01T01:00:00\n", "Dimensions without coordinates: space\n", "Data variables:\n", " BIOMA1 (space, time) float64 dask.array\n", " BIOMA2 (space, time) float64 dask.array\n", " TG1 (space, time) float64 dask.array\n", " TG2 (space, time) float64 dask.array\n", " TG3 (space, time) float64 dask.array\n", " WG1 (space, time) float64 dask.array\n", " WG2 (space, time) float64 dask.array\n", " WG3 (space, time) float64 dask.array\n", " curv (space, time) float64 dask.array\n", " sig (space, time) float64 dask.array\n", " slop (space, time) float64 dask.array\n", "Attributes:\n", " license: data license\n", " source: data source" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "data_path = \"./example1_data.zarr\"\n", "ds = xr.open_zarr(data_path)\n", "ds" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Split per gridcell" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "True" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Check if the dataset is splitable\n", "motrainer.is_splitable(ds)" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "dask.bag" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# split the dataset per grid cell\n", "bags = motrainer.dataset_split(ds, \"space\")\n", "bags" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Train Test Split" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "def to_dataframe(ds):\n", " return ds.to_dask_dataframe()\n", "\n", "def chunk(ds, chunks):\n", " return ds.chunk(chunks)" ] }, { 
"cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "# Train test split, mapped to each element of the bag\n", "train_test_bags = bags.map(\n", " motrainer.train_test_split, split={\"time\": np.datetime64(\"2016-01-01\")}\n", ")\n", "\n", "# # Or split by mask\n", "# mask = ds[\"time\"]