{ "cells": [ { "cell_type": "markdown", "id": "47cee8d5-1d5a-4e60-8bcd-fdf6496faadf", "metadata": {}, "source": [ "## Xarray: using CuPy" ] }, { "cell_type": "markdown", "id": "ec94158a-eb91-4cb0-a316-748b91f9696c", "metadata": {}, "source": [ "This notebook demonstrates how to use Xarray on a GPU with CuPy. CuPy is not a dependency of earthkit-data, so it has to be installed separately. A working CUDA-based GPU environment is also required to run the notebook." ] }, { "cell_type": "code", "execution_count": 1, "id": "ebb1cb85-5e14-45dd-b956-432d970eed1c", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ " " ] } ], "source": [ "# Get sample GRIB data on pressure levels\n", "import earthkit.data as ekd\n", "ds = ekd.from_source(\"sample\", \"pl.grib\")" ] }, { "cell_type": "code", "execution_count": 2, "id": "8bd9660a-9d96-4bed-9c6e-4b3f1fdaeaea", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<xarray.Dataset> Size: 176kB\n",
       "Dimensions:                  (forecast_reference_time: 4, step: 2, level: 2,\n",
       "                              latitude: 19, longitude: 36)\n",
       "Coordinates:\n",
       "  * forecast_reference_time  (forecast_reference_time) datetime64[ns] 32B 202...\n",
       "  * step                     (step) timedelta64[ns] 16B 00:00:00 06:00:00\n",
       "  * level                    (level) int64 16B 500 700\n",
       "  * latitude                 (latitude) float64 152B 90.0 80.0 ... -80.0 -90.0\n",
       "  * longitude                (longitude) float64 288B 0.0 10.0 ... 340.0 350.0\n",
       "Data variables:\n",
       "    r                        (forecast_reference_time, step, level, latitude, longitude) float64 88kB ...\n",
       "    t                        (forecast_reference_time, step, level, latitude, longitude) float64 88kB ...\n",
       "Attributes:\n",
       "    class:        od\n",
       "    stream:       oper\n",
       "    levtype:      pl\n",
       "    type:         fc\n",
       "    expver:       0001\n",
       "    date:         20240603\n",
       "    time:         0\n",
       "    domain:       g\n",
       "    number:       0\n",
       "    Conventions:  CF-1.8\n",
       "    institution:  ECMWF
" ], "text/plain": [ "<xarray.Dataset> Size: 176kB\n", "Dimensions:                  (forecast_reference_time: 4, step: 2, level: 2,\n", "                              latitude: 19, longitude: 36)\n", "Coordinates:\n", "  * forecast_reference_time  (forecast_reference_time) datetime64[ns] 32B 202...\n", "  * step                     (step) timedelta64[ns] 16B 00:00:00 06:00:00\n", "  * level                    (level) int64 16B 500 700\n", "  * latitude                 (latitude) float64 152B 90.0 80.0 ... -80.0 -90.0\n", "  * longitude                (longitude) float64 288B 0.0 10.0 ... 340.0 350.0\n", "Data variables:\n", "    r                        (forecast_reference_time, step, level, latitude, longitude) float64 88kB ...\n", "    t                        (forecast_reference_time, step, level, latitude, longitude) float64 88kB ...\n", "Attributes:\n", "    class:        od\n", "    stream:       oper\n", "    levtype:      pl\n", "    type:         fc\n", "    expver:       0001\n", "    date:         20240603\n", "    time:         0\n", "    domain:       g\n", "    number:       0\n", "    Conventions:  CF-1.8\n", "    institution:  ECMWF" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Create a lazily loaded Xarray dataset backed by NumPy arrays\n", "r = ds.to_xarray()\n", "r" ] }, { "cell_type": "code", "execution_count": 3, "id": "debcf314-41e8-4dcc-916b-ac602e936e16", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "numpy.ndarray" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "type(r.t.data)" ] }, { "cell_type": "markdown", "id": "28901396-3ded-48c8-9ed1-8ef70ab8df26", "metadata": {}, "source": [ "#### Move to the GPU as CuPy" ] }, { "cell_type": "markdown", "id": "b6b9b07b-44bc-4fd4-b312-d3c5bee787a7", "metadata": {}, "source": [ "We use the ``to_device()`` method, which is available on the ``earthkit`` Xarray accessor. The first argument specifies the device. When the device is not \"cpu\" and the ``array_backend`` keyword argument is not specified, it is automatically set to \"cupy\"." 
] }, { "cell_type": "code", "execution_count": 4, "id": "e6e251f2-75ac-4de5-8b9a-98cd12780063", "metadata": {}, "outputs": [], "source": [ "# Move the dataset to the first CUDA device (GPU 0)\n", "r_cp = r.earthkit.to_device(\"cuda:0\")\n", "# Equivalent code:\n", "# r_cp = r.earthkit.to_device(\"cuda:0\", array_backend=\"cupy\")" ] }, { "cell_type": "code", "execution_count": 5, "id": "2716e8e3-4e53-4ba9-ae9d-57fa1f983fa3", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "cupy.ndarray" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "type(r_cp.t.data)" ] }, { "cell_type": "code", "execution_count": 6, "id": "6257255f-59e3-4e05-a63c-7e3e381f1258", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<xarray.DataArray 't' ()> Size: 8B\n",
       "array(261.56490497)
" ], "text/plain": [ "<xarray.DataArray 't' ()> Size: 8B\n", "array(261.56490497)" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Xarray computations work directly on the CuPy arrays\n", "r_cp.t.mean()" ] }, { "cell_type": "code", "execution_count": 7, "id": "1106e56a-d9db-4484-88c6-43a7458892f7", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "cupy.ndarray" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Alter the values on the GPU: the data remains a CuPy array\n", "r_cp += 1\n", "type(r_cp.t.data)" ] }, { "cell_type": "markdown", "id": "126a7326-8c24-45bf-b5f2-a195a341a621", "metadata": {}, "source": [ "#### Move back to the CPU as NumPy" ] }, { "cell_type": "markdown", "id": "6b08cf88-3d20-4bc3-9589-52f098f502b2", "metadata": {}, "source": [ "We use ``to_device()`` again to move the dataset back to the CPU. When the device is \"cpu\" and the ``array_backend`` keyword argument is not specified, it is automatically set to \"numpy\"." ] }, { "cell_type": "code", "execution_count": 8, "id": "4e35953f-c075-4118-85a3-2ac108af1bcf", "metadata": {}, "outputs": [], "source": [ "r_np = r_cp.earthkit.to_device(\"cpu\")\n", "# Equivalent code:\n", "# r_np = r_cp.earthkit.to_device(\"cpu\", array_backend=\"numpy\")" ] }, { "cell_type": "code", "execution_count": 9, "id": "701838e0-0be8-4066-99db-adc07b79e197", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "numpy.ndarray" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "type(r_np.t.data)" ] }, { "cell_type": "code", "execution_count": 10, "id": "5e565c0b-f847-40d6-8869-22091780bc98", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<xarray.DataArray 't' ()> Size: 8B\n",
       "array(262.56490497)
" ], "text/plain": [ "<xarray.DataArray 't' ()> Size: 8B\n", "array(262.56490497)" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# The mean is 1 higher than before: the CPU dataset holds the values altered on the GPU\n", "r_np.t.mean()" ] } ], "metadata": { "kernelspec": { "display_name": "Python (conda-ek-gpu)", "language": "python", "name": "earthkit-gpu" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.13.5" } }, "nbformat": 4, "nbformat_minor": 5 }