Scale ML workloads using Ray

Introduction

The Cloud TPU Ray tool combines the Cloud TPU API and Ray Jobs to improve the development experience on Cloud TPU. This user guide provides minimal examples of how you can use Ray with Cloud TPUs. These examples are for illustrative purposes only and are not meant to be used in production services.

What's included in this tool?

For your convenience, the tool provides:

  • Generic abstractions that hide away boilerplate for common TPU actions
  • Toy examples that you can fork for your own basic workflows

Specifically:

  • tpu_api.py: Python wrapper for basic TPU operations using the Cloud TPU API.
  • tpu_controller.py: Class representation of a TPU. This is essentially a wrapper for tpu_api.py.
  • ray_tpu_controller.py: Class representation of a TPU controller with Ray functionality. This abstracts away boilerplate for the Ray cluster and Ray Jobs.
  • run_basic_jax.py: Basic example that shows how to use RayTpuController for print(jax.device_count()).
  • run_hp_search.py: Basic example that shows how Ray Tune can be used with JAX/Flax on MNIST.
  • run_pax_autoresume.py: Example that showcases how you can use RayTpuController for fault tolerant training using PAX as an example workload.

Setting up Ray cluster head node

One of the basic ways to use Ray with a TPU Pod is to set up the TPU Pod as a Ray cluster, with a separate CPU VM acting as the coordinator VM. The following graphic shows an example Ray cluster configuration:

An example Ray cluster configuration

The following commands show how you can set up a Ray cluster using the Google Cloud CLI:

$ gcloud compute instances create my-tpu-admin --machine-type=n1-standard-4 ...
$ gcloud compute ssh my-tpu-admin

$ (vm) pip3 install ray[default]
$ (vm) ray start --head --port=6379 --num-cpus=0
...
# (Ray returns the IP address of the HEAD node, for example, RAY_HEAD_IP)
$ (vm) gcloud compute tpus tpu-vm create $TPU_NAME ... --metadata startup-script="pip3 install ray && ray start --address=$RAY_HEAD_IP --resources='{\"tpu_host\": 1}'"
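
The --resources='{"tpu_host": 1}' flag in the startup script tags each TPU VM with a custom Ray resource. The following is a minimal sketch (the task itself is illustrative) of how a Ray task can be pinned to a TPU host by requesting that resource:

import ray

# Connect to the existing cluster from the head node.
ray.init(address="auto")

# Requesting the custom "tpu_host" resource schedules this task on a TPU VM
# rather than on the CPU-only head node (which was started with --num-cpus=0).
@ray.remote(resources={"tpu_host": 1})
def report_hostname():
    import socket
    return socket.gethostname()

print(ray.get(report_hostname.remote()))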

For your convenience, we also provide basic scripts for creating a coordinator VM and deploying the contents of this folder to your coordinator VM. For source code, see create_cpu.sh and deploy.sh.

These scripts set some default values:

  • create_cpu.sh creates a VM named $USER-admin and uses the project and zone set in your gcloud config defaults. Run gcloud config list to see those defaults.
  • create_cpu.sh by default allocates a boot disk size of 200GB.
  • deploy.sh assumes that your VM name is $USER-admin. If you change that value in create_cpu.sh, be sure to change it in deploy.sh too.

To use the convenience scripts:

  1. Clone the GitHub repository to your local machine and enter the ray_tpu folder:

    $ git clone https://github.com/tensorflow/tpu.git
    $ cd tpu/tools/ray_tpu/
  2. If you do not already have a dedicated service account for TPU administration, set one up (highly recommended):

    $ ./create_tpu_service_account.sh
  3. Create a coordinator VM:

    $ ./create_cpu.sh

    This script installs dependencies on the VM by using a startup script and automatically blocks until the startup script is complete.

  4. Deploy local code to the coordinator VM:

    $ ./deploy.sh
  5. SSH to the VM:

    $ gcloud compute ssh $USER-admin -- -L8265:localhost:8265

    Port forwarding is enabled here because Ray automatically starts a dashboard on port 8265. From the machine where you run SSH to your coordinator VM, you can access this dashboard at http://127.0.0.1:8265/.

  6. If you skipped step 2, set up your gcloud credentials within the CPU VM:

    $ (vm) gcloud auth login --update-adc

    This step sets your project ID information and allows the Cloud TPU API to be called from the coordinator VM.

  7. Install requirements:

    $ (vm) pip3 install -r src/requirements.txt
  8. Start Ray on the coordinator VM. The coordinator VM becomes the head node of the Ray cluster (a quick verification sketch follows these steps):

    $ (vm) ray start --head --port=6379 --num-cpus=0
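
To confirm that the head node is running before you move on to the examples, you can use the following minimal verification sketch on the coordinator VM. At this point only the head node is registered; the TPU hosts join later, when the example scripts create the TPU and its startup script runs ray start.

import ray

# Attach to the Ray cluster started above on this VM.
ray.init(address="auto")

# List the nodes and aggregate resources currently in the cluster.
# Only the CPU-only head node appears until TPU hosts join.
print(ray.nodes())
print(ray.cluster_resources())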

Usage examples

Basic JAX example

run_basic_jax.py is a minimal example that demonstrates how you can use Ray Jobs and the Ray runtime environment on a Ray cluster with TPU VMs to run a JAX workload.
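
The payload that each queued job runs can be as small as a device count check. The following is a minimal sketch of such a JAX payload (assuming jax[tpu] is installed on each TPU VM):

import jax

# Prints the global device count across all hosts. On the 2x2x2
# (8-chip) TPU created by this example, each host's job prints 8,
# which matches the output shown later in this section.
print(jax.device_count())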

For ML frameworks compatible with Cloud TPUs that use a multi-controller programming model, such as JAX and PyTorch/XLA PJRT, you must run at least one process per host. For more information, see Multi-process programming model. In practice, this might look like:

$ gcloud compute tpus tpu-vm scp my_bug_free_python_code my-tpu:~/ --worker=all --recurse
$ gcloud compute tpus tpu-vm ssh my-tpu --worker=all --command="python3 ~/my_bug_free_python_code/main.py"

If you have around 16 or more hosts (for example, a v4-128), you will run into SSH scalability issues, and your commands might have to change to:

$ gcloud compute tpus tpu-vm scp my_bug_free_python_code my-tpu:~/ --worker=all --recurse --batch-size=8
$ gcloud compute tpus tpu-vm ssh my-tpu --worker=all --command="python3 ~/my_bug_free_python_code/main.py &" --batch-size=8

This can slow developer velocity if my_bug_free_python_code contains bugs. One way to solve this problem is to use an orchestrator such as Kubernetes or Ray. Ray includes the concept of a runtime environment that, when applied, deploys code and dependencies when the Ray application runs.
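
As a rough sketch of the mechanism (the entrypoint, directory name, and pip list are illustrative assumptions, not the exact calls the tool makes), submitting a Ray Job with a runtime environment from the head node looks like this:

from ray.job_submission import JobSubmissionClient

# Connect to the Ray Jobs API served by the dashboard on the head node.
client = JobSubmissionClient("http://127.0.0.1:8265")

# The runtime environment uploads the local working directory and installs
# pip dependencies on the worker that runs the job, replacing the manual
# SCP/SSH cycle shown above.
job_id = client.submit_job(
    entrypoint="python3 main.py",              # illustrative entrypoint
    runtime_env={
        "working_dir": "./my_bug_free_python_code",
        "pip": ["jax[tpu]"],                   # illustrative dependency
    },
)
print(client.get_job_status(job_id))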

Combining the Ray runtime environment with the Ray cluster and Ray Jobs lets you bypass the SCP/SSH cycle. Assuming you followed the preceding steps, you can run this with:

$ python3 legacy/run_basic_jax.py

The output is similar to the following:

2023-03-01 22:12:10,065   INFO worker.py:1364 -- Connecting to existing Ray cluster at address: 10.130.0.19:6379...
2023-03-01 22:12:10,072   INFO worker.py:1544 -- Connected to Ray cluster. View the dashboard at http://127.0.0.1:8265
W0301 22:12:11.148555 140341931026240 ray_tpu_controller.py:143] TPU is not found, create tpu...
Creating TPU:  $USER-ray-test
Request:  {'accelerator_config': {'topology': '2x2x2', 'type': 'V4'}, 'runtimeVersion': 'tpu-ubuntu2204-base', 'networkConfig': {'enableExternalIps': True}, 'metadata': {'startup-script': '#! /bin/bash\necho "hello world"\nmkdir -p /dev/shm\nsudo mount -t tmpfs -o size=100g tmpfs /dev/shm\n pip3 install ray[default]\nray start --resources=\'{"tpu_host": 1}\' --address=10.130.0.19:6379'}}
Create TPU operation still running...
...
Create TPU operation complete.
I0301 22:13:17.795493 140341931026240 ray_tpu_controller.py:121] Detected 0 TPU hosts in cluster, expecting 2 hosts in total
I0301 22:13:17.795823 140341931026240 ray_tpu_controller.py:160] Waiting for 30s for TPU hosts to join cluster...

I0301 22:15:17.986352 140341931026240 ray_tpu_controller.py:121] Detected 2 TPU hosts in cluster, expecting 2 hosts in total
I0301 22:15:17.986503 140341931026240 ray_tpu_controller.py:90] Ray already started on each host.
2023-03-01 22:15:18,010   INFO dashboard_sdk.py:315 -- Uploading package gcs://_ray_pkg_3599972ae38ce933.zip.
2023-03-01 22:15:18,010   INFO packaging.py:503 -- Creating a file package for local directory '/home/$USER/src'.
2023-03-01 22:15:18,080   INFO dashboard_sdk.py:362 -- Package gcs://_ray_pkg_3599972ae38ce933.zip already exists, skipping upload.
I0301 22:15:18.455581 140341931026240 ray_tpu_controller.py:169] Queued 2 jobs.
...
I0301 22:15:48.523541 140341931026240 ray_tpu_controller.py:254] [ADMIN]: raysubmit_WRUtVB7nMaRTgK39: Status is SUCCEEDED
I0301 22:15:48.561111 140341931026240 ray_tpu_controller.py:256] [raysubmit_WRUtVB7nMaRTgK39]: E0301 22:15:36.294834089   21286 credentials_generic.cc:35]            Could not get HOME environment variable.
8

I0301 22:15:58.575289 140341931026240 ray_tpu_controller.py:254] [ADMIN]: raysubmit_yPCPXHiFgaCK2rBY: Status is SUCCEEDED
I0301 22:15:58.584667 140341931026240 ray_tpu_controller.py:256] [raysubmit_yPCPXHiFgaCK2rBY]: E0301 22:15:35.720800499    8561 credentials_generic.cc:35]            Could not get HOME environment variable.
8

Fault tolerant training

This example showcases how you can use RayTpuController to implement fault-tolerant training. For this example, we pretrain a simple LLM with PAX on a v4-16, but note that you can replace this PAX workload with any other long-running workload. For source code, see run_pax_autoresume.py.
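
The general auto-resume pattern is sketched below. This is a hedged illustration of the idea rather than the exact implementation in run_pax_autoresume.py: the entrypoint is hypothetical, and the TPU recreation step is only described in comments. The key point is that because the training job checkpoints regularly, a resubmitted job resumes from the latest checkpoint.

import time

from ray.job_submission import JobSubmissionClient, JobStatus

client = JobSubmissionClient("http://127.0.0.1:8265")

def submit_training():
    # Hypothetical entrypoint; the real example launches PAX training.
    # Because checkpoints are written to the model directory, a restarted
    # job picks up from the latest one.
    return client.submit_job(
        entrypoint="python3 train.py",
        runtime_env={"working_dir": "."},
    )

job_id = submit_training()
while True:
    status = client.get_job_status(job_id)
    if status == JobStatus.SUCCEEDED:
        break
    if status in (JobStatus.FAILED, JobStatus.STOPPED):
        # Recovery step (described, not implemented here): recreate the TPU,
        # for example with the Cloud TPU API wrapper in tpu_api.py, wait for
        # its hosts to rejoin the Ray cluster, then resubmit the job.
        job_id = submit_training()
    time.sleep(30)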

To run this example:

  1. Clone paxml to your coordinator VM:

    $ git clone https://github.com/google/paxml.git

    To demonstrate the ease-of-use that the Ray Runtime Environment provides for making and deploying JAX changes, this example requires you to modify PAX.

  2. Add a new experiment config:

    $ cat <<EOT >> paxml/paxml/tasks/lm/params/lm_cloud.py
    
    @experiment_registry.register
    class TestModel(LmCloudSpmd2BLimitSteps):
      ICI_MESH_SHAPE = [1, 4, 2]
      CHECKPOINT_POLICY = layers.AutodiffCheckpointType.SAVE_CONTEXT_AND_OUT_PROJ
    
      def task(self) -> tasks_lib.SingleTask.HParams:
        task_p = super().task()
        task_p.train.num_train_steps = 1000
        task_p.train.save_interval_steps = 100
        return task_p
    EOT
    
  3. Run run_pax_autoresume.py:

    $ python3 legacy/run_pax_autoresume.py --model_dir=gs://your/gcs/bucket
  4. As the workload runs, experiment with what happens when you delete your TPU (named $USER-tpu-ray by default):

    $ gcloud compute tpus tpu-vm delete -q $USER-tpu-ray --zone=us-central2-b

    Ray detects that the TPU is down with the following message:

    I0303 05:12:47.384248 140280737294144 checkpointer.py:64] Saving item to gs://$USER-us-central2/pax/v4-16-autoresume-test/checkpoints/checkpoint_00000200/metadata.
    W0303 05:15:17.707648 140051311609600 ray_tpu_controller.py:127] TPU is not found, create tpu...
    2023-03-03 05:15:30,774 WARNING worker.py:1866 -- The node with node id: 9426f44574cce4866be798cfed308f2d3e21ba69487d422872cdd6e3 and address: 10.130.0.113 and node name: 10.130.0.113 has been marked dead because the detector has missed too many heartbeats from it. This can happen when a       (1) raylet crashes unexpectedly (OOM, preempted node, etc.)
          (2) raylet has lagging heartbeats due to slow network or busy workload.
    2023-03-03 05:15:33,243 WARNING worker.py:1866 -- The node with node id: 214f5e4656d1ef48f99148ddde46448253fe18672534467ee94b02ba and address: 10.130.0.114 and node name: 10.130.0.114 has been marked dead because the detector has missed too many heartbeats from it. This can happen when a       (1) raylet crashes unexpectedly (OOM, preempted node, etc.)
          (2) raylet has lagging heartbeats due to slow network or busy workload.

    The job then automatically recreates the TPU VM and restarts the training job so that it can resume training from the latest checkpoint (step 200 in this example):

    I0303 05:22:43.141277 140226398705472 train.py:1149] Training loop starting...
    I0303 05:22:43.141381 140226398705472 summary_utils.py:267] Opening SummaryWriter `gs://$USER-us-central2/pax/v4-16-autoresume-test/summaries/train`...
    I0303 05:22:43.353654 140226398705472 summary_utils.py:267] Opening SummaryWriter `gs://$USER-us-central2/pax/v4-16-autoresume-test/summaries/eval_train`...
    I0303 05:22:44.008952 140226398705472 py_utils.py:350] Starting sync_global_devices Start training loop from step: 200 across 8 devices globally

Hyperparameter tuning

This example showcases how you can use Ray Tune from Ray AIR to tune the hyperparameters of an MNIST training job written in JAX/Flax. For source code, see run_hp_search.py.
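
As a rough sketch of the technique (the trainable body, search-space bounds, and metric name are illustrative assumptions, not the contents of run_hp_search.py), a Ray Tune search over learning rate and momentum looks like this:

from ray import tune
from ray.air import session

def train_mnist(config):
    # Hypothetical trainable: build the Flax model, train it with
    # config["learning_rate"] and config["momentum"], then report
    # accuracy back to Tune. The actual training loop is elided.
    accuracy = 0.0
    session.report({"acc": accuracy})

tuner = tune.Tuner(
    train_mnist,
    param_space={
        "learning_rate": tune.loguniform(1e-9, 1e-3),  # illustrative bounds
        "momentum": tune.uniform(0.5, 0.9),
    },
    tune_config=tune.TuneConfig(num_samples=3),
)
results = tuner.fit()
print(results.get_best_result(metric="acc", mode="max").config)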

To run this example:

  1. Install the requirements:

    $ pip3 install -r src/tune/requirements.txt
  2. Run run_hp_search.py:

    $ python3 src/tune/run_hp_search.py

    The output is similar to the following:

    Number of trials: 3/3 (3 TERMINATED)
    +-----------------------------+------------+-------------------+-----------------+------------+--------+--------+------------------+
    | Trial name                  | status     | loc               |   learning_rate |   momentum |    acc |   iter |   total time (s) |
    |-----------------------------+------------+-------------------+-----------------+------------+--------+--------+------------------|
    | hp_search_mnist_8cbbb_00000 | TERMINATED | 10.130.0.84:21340 |     1.15258e-09 |   0.897988 | 0.0982 |      3 |          82.4525 |
    | hp_search_mnist_8cbbb_00001 | TERMINATED | 10.130.0.84:21340 |     0.000219523 |   0.825463 | 0.1009 |      3 |          73.1168 |
    | hp_search_mnist_8cbbb_00002 | TERMINATED | 10.130.0.84:21340 |     1.08035e-08 |   0.660416 | 0.098  |      3 |          71.6813 |
    +-----------------------------+------------+-------------------+-----------------+------------+--------+--------+------------------+
    
    2023-03-02 21:50:47,378   INFO tune.py:798 -- Total run time: 318.07 seconds (318.01 seconds for the tuning loop).
    ...

Troubleshooting

Ray head node is unable to connect

If you run a workload that creates and deletes TPUs as part of its lifecycle, this sometimes fails to disconnect the TPU hosts from the Ray cluster. This may show up as gRPC errors indicating that the Ray head node is unable to connect to a set of IP addresses.

As a result, you may need to terminate your ray session (ray stop) and restart it (ray start --head --port=6379 --num-cpus=0).

Ray Job fails immediately without any log output

PAX is experimental and this example may break due to pip dependencies. If that happens, you may see something like this:

I0303 20:50:36.084963 140306486654720 ray_tpu_controller.py:174] Queued 2 jobs.
I0303 20:50:36.136786 140306486654720 ray_tpu_controller.py:238] Requested to clean up 1 stale jobs from previous failures.
I0303 20:50:36.148653 140306486654720 ray_tpu_controller.py:253] Job status: Counter({<JobStatus.FAILED: 'FAILED'>: 2})
I0303 20:51:38.582798 140306486654720 ray_tpu_controller.py:126] Detected 2 TPU hosts in cluster, expecting 2 hosts in total
W0303 20:51:38.589029 140306486654720 ray_tpu_controller.py:196] Detected job raysubmit_8j85YLdHH9pPrmuz FAILED.
2023-03-03 20:51:38,641   INFO dashboard_sdk.py:362 -- Package gcs://_ray_pkg_ae3cacd575e24531.zip already exists, skipping upload.
2023-03-03 20:51:38,706   INFO dashboard_sdk.py:362 -- Package gcs://_ray_pkg_ae3cacd575e24531.zip already exists, skipping upload.

To see the root cause of the error, go to the Ray dashboard at http://127.0.0.1:8265/ and view the logs of the running and failed jobs, which provide more information. runtime_env_agent.log shows all the error information related to runtime_env setup, for example:

INFO: pip is looking at multiple versions of  to determine which version is compatible with other requirements. This could take a while.
INFO: pip is looking at multiple versions of orbax to determine which version is compatible with other requirements. This could take a while.
ERROR: Cannot install paxml because these package versions have conflicting dependencies.

The conflict is caused by:
    praxis 0.3.0 depends on t5x
    praxis 0.2.1 depends on t5x
    praxis 0.2.0 depends on t5x
    praxis 0.1 depends on t5x

To fix this you could try to:
1. loosen the range of package versions you've specified
2. remove package versions to allow pip attempt to solve the dependency conflict

ERROR: ResolutionImpossible: for help visit https://pip.pypa.io/en/latest/topics/dependency-resolution/#dealing-with-dependency-conflicts
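
You can also inspect job status and logs programmatically from the coordinator VM. The following is a minimal sketch using Ray's job submission SDK (the job ID passed to get_job_logs is the failed submission from the example output above):

from ray.job_submission import JobSubmissionClient

client = JobSubmissionClient("http://127.0.0.1:8265")

# List recent jobs and their statuses, then pull the logs of a failed job,
# which surface errors like the pip dependency conflict shown above.
for job in client.list_jobs():
    print(job.submission_id, job.status)

print(client.get_job_logs("raysubmit_8j85YLdHH9pPrmuz"))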