This tutorial guides you through:
- Creating a Cloud TPU VM to deploy the Llama 2 family of large language models (LLMs), available in different sizes (7B, 13B, or 70B)
- Preparing checkpoints for the models and deploying them on SAX
- Interacting with the model through an HTTP endpoint
Serving for AGI Experiments (SAX) is an experimental system that serves Paxml, JAX, and PyTorch models for inference. Code and documentation for SAX are in the Saxml Git repository. The current stable version with TPU v5e support is v1.1.0.
About SAX cells
A SAX cell (or cluster) is the core unit for serving your models. It consists of two main parts:
- Admin server: This server keeps track of your model servers, assigns models to those model servers, and helps clients find the right model server to interact with.
- Model servers: These servers run your model. They're responsible for processing incoming requests and generating responses.
The following diagram shows a SAX cell:
Figure 1. SAX cell with admin server and model server.
You can interact with a SAX cell using clients written in Python, C++, or Go, or directly through an HTTP server. The following diagram shows how an external client can interact with a SAX cell:
Figure 2. Runtime architecture of an external client interacting with a SAX cell.
Objectives
- Set up TPU resources for serving
- Create a SAX cluster
- Publish the Llama 2 model
- Interact with the model
Costs
In this document, you use the following billable components of Google Cloud:
- Cloud TPU
- Compute Engine
- Cloud Storage
To generate a cost estimate based on your projected usage,
use the pricing calculator.
Before you begin
Set up your Google Cloud project, activate the Cloud TPU API, and create a service account by following the instructions in Set up the Cloud TPU environment.
Create a TPU
The following steps show how to create a TPU VM, which will serve your model.
Create environment variables:
export PROJECT_ID=PROJECT_ID
export ACCELERATOR_TYPE=ACCELERATOR_TYPE
export ZONE=ZONE
export RUNTIME_VERSION=v2-alpha-tpuv5-lite
export SERVICE_ACCOUNT=SERVICE_ACCOUNT
export TPU_NAME=TPU_NAME
export QUEUED_RESOURCE_ID=QUEUED_RESOURCE_ID
Environment variable descriptions
PROJECT_ID
- The ID of your Google Cloud project.
ACCELERATOR_TYPE
- The accelerator type specifies the version and size of the Cloud TPU you want to create. Different Llama 2 model sizes have different TPU size requirements:
  - 7B: v5litepod-4 or larger
  - 13B: v5litepod-8 or larger
  - 70B: v5litepod-16 or larger
ZONE
- The zone where you want to create your Cloud TPU.
SERVICE_ACCOUNT
- The service account you want to attach to your Cloud TPU.
TPU_NAME
- The name for your Cloud TPU.
QUEUED_RESOURCE_ID
- An identifier for your queued resource request.
Set the project ID and zone in your active Google Cloud CLI configuration:
gcloud config set project $PROJECT_ID && gcloud config set compute/zone $ZONE
Create the TPU VM:
gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \
  --node-id ${TPU_NAME} \
  --project ${PROJECT_ID} \
  --zone ${ZONE} \
  --accelerator-type ${ACCELERATOR_TYPE} \
  --runtime-version ${RUNTIME_VERSION} \
  --service-account ${SERVICE_ACCOUNT}
Check that the TPU is active:
gcloud compute tpus queued-resources list --project $PROJECT_ID --zone $ZONE
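The queued resource can take several minutes to become active. If you script your setup, you can poll the state instead of re-running the list command by hand. The following Python sketch shells out to the gcloud CLI; the state.state field path is an assumption, so verify it against your gcloud version:

import subprocess
import time

def wait_for_tpu(queued_resource_id: str, project: str, zone: str) -> None:
    """Polls the queued resource until it reports ACTIVE (hypothetical helper)."""
    while True:
        # Assumes the queued resource state is exposed as "state.state".
        state = subprocess.run(
            ["gcloud", "compute", "tpus", "queued-resources", "describe",
             queued_resource_id, "--project", project, "--zone", zone,
             "--format", "value(state.state)"],
            capture_output=True, text=True, check=True,
        ).stdout.strip()
        print("Queued resource state:", state)
        if state == "ACTIVE":
            return
        time.sleep(30)  # Re-check every 30 seconds.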
Set up checkpoint conversion node
To run the Llama models on a SAX cluster, you need to convert the original Llama checkpoints to a SAX-compatible format.
The conversion requires significant memory resources, depending on the model size:
Model | Required memory |
---|---|
7B | 50-60 GB |
13B | 120 GB |
70B | 500-600 GB (N2 or M1 machine type) |
For the 7B and the 13B model, you can run the conversion on the TPU VM. For the 70B model, you need to create a Compute Engine instance with approximately 1 TB of disk space:
gcloud compute instances create INSTANCE_NAME --project=$PROJECT_ID --zone=$ZONE \
  --machine-type=n2-highmem-128 \
  --network-interface=network-tier=PREMIUM,stack-type=IPV4_ONLY,subnet=default \
  --maintenance-policy=MIGRATE --provisioning-model=STANDARD \
  --service-account=$SERVICE_ACCOUNT \
  --scopes=https://www.googleapis.com/auth/cloud-platform \
  --tags=http-server,https-server \
  --create-disk=auto-delete=yes,boot=yes,device-name=bk-workday-dlvm,image=projects/ml-images/global/images/c0-deeplearning-common-cpu-v20230925-debian-10,mode=rw,size=500,type=projects/$PROJECT_ID/zones/$ZONE/diskTypes/pd-balanced \
  --no-shielded-secure-boot \
  --shielded-vtpm \
  --shielded-integrity-monitoring \
  --labels=goog-ec-src=vm_add-gcloud \
  --reservation-affinity=any
Whether you use a TPU or Compute Engine instance as your conversion server, set up your server to convert the Llama 2 checkpoints:
For the 7B and the 13B model, set the server name environment variable to the name of your TPU:
export CONV_SERVER_NAME=$TPU_NAME
For the 70B model, set the server name environment variable to the name of your Compute Engine instance:
export CONV_SERVER_NAME=INSTANCE_NAME
Connect to the conversion node using SSH.
If your conversion node is a TPU, connect to the TPU:
gcloud compute tpus tpu-vm ssh $CONV_SERVER_NAME --project=$PROJECT_ID --zone=$ZONE
If your conversion node is a Compute Engine instance, connect to the Compute Engine VM:
gcloud compute ssh $CONV_SERVER_NAME --project=$PROJECT_ID --zone=$ZONE
Install required packages on the conversion node:
sudo apt update
sudo apt-get install python3-pip
sudo apt-get install git-all
pip3 install paxml==1.1.0
pip3 install torch
pip3 install jaxlib==0.4.14
Download the Llama checkpoint conversion script:
gcloud storage cp gs://cloud-tpu-inference-public/sax-tokenizers/llama/convert_llama_ckpt.py .
Download Llama 2 weights
Before converting the model, you need to download the Llama 2 weights. For this tutorial, you must use the original Llama 2 weights (for example, meta-llama/Llama-2-7b) and not the weights that have been converted for the Hugging Face Transformers format (for example, meta-llama/Llama-2-7b-hf).
If you already have the Llama 2 weights, skip ahead to Convert the weights.
To download the weights from the Hugging Face hub, you need to set up a user access token and request access to the Llama 2 models. To request access, follow the instructions on the Hugging Face page for the model you want to use, for example, meta-llama/Llama-2-7b.
Create a directory for the weights:
sudo mkdir WEIGHTS_DIRECTORY
Get the Llama 2 weights from the Hugging Face hub:
Install the Hugging Face hub CLI:
pip install -U "huggingface_hub[cli]"
Change to the weights directory:
cd WEIGHTS_DIRECTORY
Download the Llama 2 files:
python3

from huggingface_hub import login
login()
from huggingface_hub import hf_hub_download, snapshot_download
import os
PATH=os.getcwd()
snapshot_download(repo_id="meta-llama/LLAMA2_REPO", local_dir_use_symlinks=False, local_dir=PATH)
Replace LLAMA2_REPO with the name of the Hugging Face repository you want to download from: Llama-2-7b, Llama-2-13b, or Llama-2-70b.
Convert the weights
Edit the conversion script, then run it to convert the model weights.
Create a directory to hold the converted weights:
sudo mkdir CONVERTED_WEIGHTS
Clone the Saxml GitHub repository in a directory where you have read, write, and execute permissions:
git clone https://github.com/google/saxml.git -b r1.1.0
Change to the saxml directory:
cd saxml
Open the saxml/tools/convert_llama_ckpt.py file and change line 169 from:
'scale': pytorch_vars[0]['layers.%d.attention_norm.weight' % (layer_idx)].type(torch.float16).numpy()
To:
'scale': pytorch_vars[0]['norm.weight'].type(torch.float16).numpy()
Run the saxml/tools/init_cloud_vm.sh script:
saxml/tools/init_cloud_vm.sh
For 70B only: Turn test mode off:
Open the saxml/server/pax/lm/params/lm_cloud.py file and change line 344 from:
return True
To:
return False
Convert the weights:
python3 saxml/tools/convert_llama_ckpt.py --base-model-path WEIGHTS_DIRECTORY \
  --pax-model-path CONVERTED_WEIGHTS \
  --model-size MODEL_SIZE
Replace the following:
- WEIGHTS_DIRECTORY: Directory for the original weights.
- CONVERTED_WEIGHTS: Target path for the converted weights.
- MODEL_SIZE: 7b, 13b, or 70b.
Prepare the checkpoint directory
After you convert the checkpoints, the checkpoint directory should have the following structure:
checkpoint_00000000
  metadata/
    metadata
  state/
    mdl_vars.params.lm*/
    ...
    step/
Create an empty file named commit_success.txt and put a copy of it in the checkpoint_00000000, metadata, and state directories. This lets SAX know that this checkpoint is fully converted and ready to load:
Change to the checkpoint directory:
cd CONVERTED_WEIGHTS/checkpoint_00000000
Create an empty file named commit_success.txt:
touch commit_success.txt
Change to the metadata directory and create an empty file named commit_success.txt:
cd metadata && touch commit_success.txt
Change to the state directory and create an empty file named commit_success.txt:
cd .. && cd state && touch commit_success.txt
The checkpoint directory should now have the following structure:
checkpoint_00000000
  commit_success.txt
  metadata/
    commit_success.txt
    metadata
  state/
    commit_success.txt
    mdl_vars.params.lm*/
    ...
    step/
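If you script the conversion, you can create the commit_success.txt markers in one step instead of changing directories. A minimal Python sketch, using the directory layout shown above (CONVERTED_WEIGHTS is the same placeholder used earlier):

import os

# Create an empty commit_success.txt marker in the checkpoint root and in the
# metadata and state subdirectories so SAX treats the checkpoint as complete.
ckpt_dir = "CONVERTED_WEIGHTS/checkpoint_00000000"
for subdir in ("", "metadata", "state"):
    marker = os.path.join(ckpt_dir, subdir, "commit_success.txt")
    open(marker, "a").close()  # Equivalent to touch: creates the file if missing.
    print("Created", marker)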
Create a Cloud Storage bucket
You need to store the converted checkpoints in a Cloud Storage bucket so that they're available when publishing the model.
Set an environment variable for the name of your Cloud Storage bucket:
export GSBUCKET=BUCKET_NAME
Create a bucket:
gcloud storage buckets create gs://${GSBUCKET}
Copy your converted checkpoint files to your bucket:
gcloud storage cp -r CONVERTED_WEIGHTS/checkpoint_00000000 gs://$GSBUCKET/sax_models/llama2/SAX_LLAMA2_DIR/
Replace SAX_LLAMA2_DIR with the appropriate value:
- 7B: saxml_llama27b
- 13B: saxml_llama213b
- 70B: saxml_llama270b
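To confirm that the copy succeeded, you can list the uploaded checkpoint files. A hedged sketch using the google-cloud-storage Python package (not installed in earlier steps; install it with pip3 install google-cloud-storage):

from google.cloud import storage

# List the converted checkpoint files under the SAX model directory.
# BUCKET_NAME and SAX_LLAMA2_DIR are the same placeholders used above.
client = storage.Client()
prefix = "sax_models/llama2/SAX_LLAMA2_DIR/checkpoint_00000000/"
for blob in client.list_blobs("BUCKET_NAME", prefix=prefix):
    print(blob.name)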
Create SAX cluster
To create a SAX cluster, you need to create an admin server and one or more model servers.
In a typical deployment, you would run the admin server on a Compute Engine instance, and the model server on a TPU or a GPU. For the purpose of this tutorial, you will deploy the admin server and the model server on the same TPU v5e instance.
Create admin server
Create the admin server Docker container:
On the conversion server, install Docker:
sudo apt-get update
sudo apt-get install docker.io
Launch the admin server Docker container:
sudo docker run --name sax-admin-server \
  -it \
  -d \
  --rm \
  --network host \
  --env GSBUCKET=${GSBUCKET} \
  us-docker.pkg.dev/cloud-tpu-images/inference/sax-admin-server:v1.1.0
You can run the docker run command without the -d option to view the logs and ensure that the admin server starts correctly.
Create model server
The following sections show how to create a model server.
7B model
Launch the model server Docker container:
sudo docker run --privileged \
-it \
-d \
--rm \
--network host \
--name "sax-model-server" \
--env SAX_ROOT=gs://${GSBUCKET}/sax-root \
us-docker.pkg.dev/cloud-tpu-images/inference/sax-model-server:v1.1.0 \
--sax_cell="/sax/test" \
--port=10001 \
--platform_chip=tpuv5e \
--platform_topology='4'
13B model
The configuration for LLaMA13BFP16TPUv5e is missing from lm_cloud.py. The following steps show how to update lm_cloud.py and commit a new Docker image.
Start the model server:
sudo docker run --privileged \
  -it \
  -d \
  --rm \
  --network host \
  --name "sax-model-server" \
  --env SAX_ROOT=gs://${GSBUCKET}/sax-root \
  us-docker.pkg.dev/cloud-tpu-images/inference/sax-model-server:v1.1.0 \
  --sax_cell="/sax/test" \
  --port=10001 \
  --platform_chip=tpuv5e \
  --platform_topology='8'
Open a shell in the Docker container:
sudo docker exec -it sax-model-server bash
Install Vim in the Docker image:
apt update
apt install vim
Open the saxml/server/pax/lm/params/lm_cloud.py file and search for LLaMA13B. You should see the following code:

@servable_model_registry.register
@quantization.for_transformer(quantize_on_the_fly=False)
class LLaMA13B(BaseLLaMA):
  """13B model on a A100-40GB.

  April 12, 2023
  Latency = 5.06s with 128 decoded tokens. 38ms per output token.
  """

  NUM_LAYERS = 40
  VOCAB_SIZE = 32000
  DIMS_PER_HEAD = 128
  NUM_HEADS = 40
  MODEL_DIMS = 5120
  HIDDEN_DIMS = 13824
  ICI_MESH_SHAPE = [1, 1, 1]

  @property
  def test_mode(self) -> bool:
    return True
Comment out or delete the line that begins with @quantization. After this change, the file should look like the following:

@servable_model_registry.register
class LLaMA13B(BaseLLaMA):
  """13B model on a A100-40GB.

  April 12, 2023
  Latency = 5.06s with 128 decoded tokens. 38ms per output token.
  """

  NUM_LAYERS = 40
  VOCAB_SIZE = 32000
  DIMS_PER_HEAD = 128
  NUM_HEADS = 40
  MODEL_DIMS = 5120
  HIDDEN_DIMS = 13824
  ICI_MESH_SHAPE = [1, 1, 1]

  @property
  def test_mode(self) -> bool:
    return True
Add the following code to support the TPU configuration:

@servable_model_registry.register
class LLaMA13BFP16TPUv5e(LLaMA13B):
  """13B model on TPU v5e-8."""

  BATCH_SIZE = [1]
  BUCKET_KEYS = [128]
  MAX_DECODE_STEPS = [32]
  ENABLE_GENERATE_STREAM = False
  ICI_MESH_SHAPE = [1, 1, 8]

  @property
  def test_mode(self) -> bool:
    return False
Exit the Docker container SSH session:
exit
Commit the changes to a new Docker image:
sudo docker commit sax-model-server sax-model-server:v1.1.0-mod
Check that the new Docker image is created:
sudo docker images
You can publish the Docker image to your project's Artifact Registry, but this tutorial will proceed with the local image.
Stop the model server. The rest of the tutorial will use the updated model server.
sudo docker stop sax-model-server
Start the model server using the updated Docker image. Be sure to specify the updated image name, sax-model-server:v1.1.0-mod:
sudo docker run --privileged \
  -it \
  -d \
  --rm \
  --network host \
  --name "sax-model-server" \
  --env SAX_ROOT=gs://${GSBUCKET}/sax-root \
  sax-model-server:v1.1.0-mod \
  --sax_cell="/sax/test" \
  --port=10001 \
  --platform_chip=tpuv5e \
  --platform_topology='8'
70B model
Connect to your TPU using SSH and start the model server:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project ${PROJECT_ID} \
--zone ${ZONE} \
--worker=all \
--command="
gcloud auth configure-docker \
us-docker.pkg.dev
# Pull SAX model server image
sudo docker pull us-docker.pkg.dev/cloud-tpu-images/inference/sax-model-server:v1.1.0
# Run model server
sudo docker run \
--privileged \
-it \
-d \
--rm \
--network host \
--name "sax-model-server" \
--env SAX_ROOT=gs://${GSBUCKET}/sax-root \
us-docker.pkg.dev/cloud-tpu-images/inference/sax-model-server:v1.1.0 \
--sax_cell="/sax/test" \
--port=10001 \
--platform_chip=tpuv5e \
--platform_topology='16'
"
Check logs
Check the model server logs to make sure that the model server has started properly:
docker logs -f sax-model-server
If the model server didn't start, see the Troubleshoot section for more information.
For the 70B model, repeat these steps for each TPU VM:
Connect to the TPU using SSH:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
  --project ${PROJECT_ID} \
  --zone ${ZONE} \
  --worker=WORKER_NUMBER
WORKER_NUMBER is a 0-based index, indicating which TPU VM you want to connect to.
Check the logs:
sudo docker logs -f sax-model-server
Three TPU VMs should show that they have connected to the other instances:
I1117 00:16:07.196594 140613973207936 multi_host_sync.py:152] Received SPMD peer address 10.182.0.3:10001
I1117 00:16:07.197484 140613973207936 multi_host_sync.py:152] Received SPMD peer address 10.182.0.87:10001
I1117 00:16:07.199437 140613973207936 multi_host_sync.py:152] Received SPMD peer address 10.182.0.13:10001
One of the TPU VMs should have logs that show the model server starting:
I1115 04:01:29.479170 139974275995200 model_service_base.py:867] Started joining SAX cell /sax/test
ERROR: logging before flag.Parse: I1115 04:01:31.479794 1 location.go:141] Calling Join due to address update
ERROR: logging before flag.Parse: I1115 04:01:31.814721 1 location.go:155] Joined 10.182.0.44:10000
Publish the model
SAX comes with a command-line tool called saxutil, which simplifies interacting with SAX model servers. In this tutorial, you use saxutil to publish the model. For the full list of saxutil commands, see the Saxml README file.
Change to the directory where you cloned the Saxml GitHub repository:
cd saxml
For the 70B model, connect to your conversion server:
gcloud compute ssh ${CONV_SERVER_NAME} \
  --project ${PROJECT_ID} \
  --zone ${ZONE}
Install Bazel:
sudo apt-get install bazel
Set an alias for running saxutil with your Cloud Storage bucket:
alias saxutil='bazel run saxml/bin:saxutil -- --sax_root=gs://${GSBUCKET}/sax-root'
Publish the model using saxutil. This takes about 10 minutes on a TPU v5litepod-8.
saxutil --sax_root=gs://${GSBUCKET}/sax-root publish '/sax/test/MODEL' \
  saxml.server.pax.lm.params.lm_cloud.PARAMETERS \
  gs://${GSBUCKET}/sax_models/llama2/SAX_LLAMA2_DIR/checkpoint_00000000/ \
  1
Replace the following variables:

Model size | MODEL | PARAMETERS | SAX_LLAMA2_DIR |
---|---|---|---|
7B | llama27b | saxml.server.pax.lm.params.lm_cloud.LLaMA7BFP16TPUv5e | saxml_llama27b |
13B | llama213b | saxml.server.pax.lm.params.lm_cloud.LLaMA13BFP16TPUv5e | saxml_llama213b |
70B | llama270b | saxml.server.pax.lm.params.lm_cloud.LLaMA70BFP16TPUv5e | saxml_llama270b |
Test deployment
To check if deployment has succeeded, use the saxutil ls command:
saxutil ls /sax/test/MODEL
A successful deployment should have a number of replicas greater than zero and look similar to the following:
INFO: Running command line: bazel-bin/saxml/bin/saxutil_/saxutil '--sax_root=gs://sax-admin2/sax-root' ls /sax/test/llama27b
+----------+-------------------------------------------------------+-----------------------------------------------------------------------+---------------+---------------------------+
| MODEL    | MODEL PATH                                            | CHECKPOINT PATH                                                        | # OF REPLICAS | (SELECTED) REPLICA ADDRESS |
+----------+-------------------------------------------------------+-----------------------------------------------------------------------+---------------+---------------------------+
| llama27b | saxml.server.pax.lm.params.lm_cloud.LLaMA7BFP16TPUv5e | gs://${MODEL_BUCKET}/sax_models/llama2/7b/pax_7B/checkpoint_00000000/ | 1 | 10.182.0.28:10001 |
+----------+-------------------------------------------------------+-----------------------------------------------------------------------+---------------+---------------------------+
The Docker logs for the model server will be similar to the following:
I1114 17:31:03.586631 140003787142720 model_service_base.py:532] Successfully loaded model for key: /sax/test/llama27b
Troubleshoot
If the deployment fails, check the model server logs:
sudo docker logs -f sax-model-server
For a successful deployment, you should see the following output:
Successfully loaded model for key: /sax/test/llama27b
If the logs don't show that the model was deployed, check the model configuration and the path to your model checkpoint.
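If you are checking many servers, the same check can be scripted. A hedged Python helper that scans the container logs for the success message shown above (it assumes the container name sax-model-server used in this tutorial):

import subprocess

def model_loaded(container: str = "sax-model-server") -> bool:
    """Returns True if the model server logs contain the success message."""
    # docker logs writes to both stdout and stderr, so capture both.
    result = subprocess.run(
        ["sudo", "docker", "logs", container],
        capture_output=True, text=True,
    )
    logs = result.stdout + result.stderr
    return "Successfully loaded model for key" in logs

print(model_loaded())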
Generate responses
You can use the saxutil
tool to generate responses to prompts.
Generate responses to a question:
saxutil lm.generate -extra="temperature:0.2" /sax/test/MODEL "Q: Who is Harry Potter's mother? A:"
The output should be similar to the following:
INFO: Running command line: bazel-bin/saxml/bin/saxutil_/saxutil '--sax_root=gs://sax-admin2/sax-root' lm.generate /sax/test/llama27b "Q: Who is Harry Potter's mother? A:"
+-------------------------------+------------+
| GENERATE | SCORE |
+-------------------------------+------------+
| 1. Harry Potter's mother is | -20.214787 |
| Lily Evans. 2. Harry Potter's | |
| mother is Petunia Evans | |
| (Dursley). | |
+-------------------------------+------------+
Interact with the model from a client
The SAX repository includes clients that you can use to interact with a SAX cell. Clients are available in C++, Python, and Go. The following example shows how to build a Python client.
Build the Python client:
bazel build saxml/client/python:sax.cc --compile_one_dependency
Add the client to PYTHONPATH. This example assumes that you have saxml under your home directory:
export PYTHONPATH=${PYTHONPATH}:$HOME/saxml/bazel-bin/saxml/client/python/
Interact with SAX from the Python shell:
$ python3
Python 3.10.12 (main, Jun 11 2023, 05:26:28) [GCC 11.4.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import sax
>>>
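From here, you can open the model you published and send it prompts. The following sketch follows the client usage shown in the Saxml README; treat the exact method names as assumptions and verify them against your checkout:

import sax

# Open the model published earlier. Replace MODEL with llama27b, llama213b,
# or llama270b to match your deployment.
model = sax.Model("/sax/test/MODEL")
lm = model.LM()

# Generate is assumed to return a list of (text, score) candidates.
for text, score in lm.Generate("Q: Who is Harry Potter's mother? A:"):
    print(score, text)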
Interact with the model from an HTTP endpoint
To interact with the model from an HTTP endpoint, create an HTTP client:
Create a Compute Engine VM:
export PROJECT_ID=PROJECT_ID
export ZONE=ZONE
export HTTP_SERVER_NAME=HTTP_SERVER_NAME
export SERVICE_ACCOUNT=SERVICE_ACCOUNT
export MACHINE_TYPE=e2-standard-8

gcloud compute instances create $HTTP_SERVER_NAME --project=$PROJECT_ID --zone=$ZONE \
  --machine-type=$MACHINE_TYPE \
  --network-interface=network-tier=PREMIUM,stack-type=IPV4_ONLY,subnet=default \
  --maintenance-policy=MIGRATE --provisioning-model=STANDARD \
  --service-account=$SERVICE_ACCOUNT \
  --scopes=https://www.googleapis.com/auth/cloud-platform \
  --tags=http-server,https-server \
  --create-disk=auto-delete=yes,boot=yes,device-name=$HTTP_SERVER_NAME,image=projects/ml-images/global/images/c0-deeplearning-common-cpu-v20230925-debian-10,mode=rw,size=500,type=projects/$PROJECT_ID/zones/$ZONE/diskTypes/pd-balanced \
  --no-shielded-secure-boot \
  --shielded-vtpm \
  --shielded-integrity-monitoring \
  --labels=goog-ec-src=vm_add-gcloud \
  --reservation-affinity=any
Connect to the Compute Engine VM using SSH:
gcloud compute ssh $HTTP_SERVER_NAME --project=$PROJECT_ID --zone=$ZONE
Clone the AI on GKE GitHub repository:
git clone https://github.com/GoogleCloudPlatform/ai-on-gke.git
Change to the HTTP server directory:
cd ai-on-gke/tools/saxml-on-gke/httpserver
Build the Docker image:
docker build -f Dockerfile -t sax-http .
Run the HTTP server:
docker run -e SAX_ROOT=gs://${GSBUCKET}/sax-root -p 8888:8888 -it sax-http
Test your endpoint from your local machine or another server with access to port 8888 using the following commands:
Export environment variables for your server's IP address and port:
export LB_IP=HTTP_SERVER_EXTERNAL_IP
export PORT=8888
Set the JSON payload, containing the model and query:
json_payload=$(cat << EOF
{
  "model": "/sax/test/MODEL",
  "query": "Example query"
}
EOF
)
Send the request:
curl --request POST --header "Content-type: application/json" -s $LB_IP:$PORT/generate --data "$json_payload"
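You can send the same request from Python. A hedged equivalent of the curl call above using the requests package (pip3 install requests); the exact response format depends on the HTTP server, so this sketch prints the raw body:

import requests

# Mirror the curl call: POST the model key and query to the /generate endpoint.
payload = {
    "model": "/sax/test/MODEL",
    "query": "Example query",
}
resp = requests.post("http://HTTP_SERVER_EXTERNAL_IP:8888/generate", json=payload)
resp.raise_for_status()
print(resp.text)  # Response format depends on the server implementation.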
Clean up
To avoid incurring charges to your Google Cloud account for the resources used in this tutorial, either delete the project that contains the resources, or keep the project and delete the individual resources.
When you are done with this tutorial, follow these steps to clean up your resources.
Delete your Cloud TPU.
gcloud compute tpus tpu-vm delete $TPU_NAME --zone $ZONE
Delete your Compute Engine instance, if you created one.
gcloud compute instances delete INSTANCE_NAME
Delete your Cloud Storage bucket and its contents.
gcloud storage rm --recursive gs://BUCKET_NAME
What's next
- All TPU tutorials
- Supported reference models
- Inference using v5e
- Convert a model for inference using v5e