JAX + Flower For Federated Learning Gives Machine Learning Researchers The Flexibility To Use The Deep Learning Framework For Their Projects

Google researchers created JAX to run NumPy computations on GPUs and TPUs. DeepMind uses it to support and accelerate its research, and it is steadily gaining popularity. JAX provides the composable function transformations that machine learning research relies on, such as automatic differentiation with grad(), vectorization with vmap(), and just-in-time (JIT) compilation with jit(). Adding a JAX-based workload to the Flower code examples was therefore a must-have: combining JAX and Flower lets ML and federated learning (FL) researchers use whichever deep learning framework their projects require. The new code example also serves as a template for migrating existing JAX projects to a federated setting.
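To make "composable" concrete, here is a minimal sketch (not taken from the Flower example) that stacks the three transformations on a toy one-parameter squared-error loss: grad() differentiates it with respect to the weight, vmap() maps that per-example gradient over a batch, and jit() compiles the composed function. The loss and the data values are illustrative assumptions.

```python
import jax
import jax.numpy as jnp


def loss(w, x, y):
    # Squared error of a 1-D linear model w * x for a single example
    return (w * x - y) ** 2


# grad differentiates with respect to the first argument (w)
grad_loss = jax.grad(loss)
# vmap maps the per-example gradient over batches of x and y, sharing w
batched_grad = jax.vmap(grad_loss, in_axes=(None, 0, 0))
# jit compiles the composed function with XLA
fast_batched_grad = jax.jit(batched_grad)

xs = jnp.array([1.0, 2.0, 3.0])
ys = jnp.array([2.0, 4.0, 6.0])
# d/dw (w*x - y)^2 = 2*x*(w*x - y); at w = 1 this gives -2, -8, -18
print(fast_batched_grad(1.0, xs, ys))
```

Each transformation returns an ordinary Python function, which is why they can be nested in any order.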

Setting up a centralized machine learning workload in JAX is straightforward, and the JAX developer documentation has multiple examples. Setting up the federated workload requires some knowledge of JAX, because the model parameters are stored as JAX DeviceArrays. To be compatible with Flower's NumPyClient, those parameters must be converted to NumPy ndarrays. The JAX meets Flower example below demonstrates how such a Flower setup works.

Let's start by setting up a very basic JAX training environment. The file jax_training.py uses scikit-learn to generate a random linear regression problem, and the load_data() function loads the resulting training and test data. The load_model() function defines a simple linear regression model, while train() and evaluation() implement training and evaluation of the model, respectively. The loss_fn() function computes the loss, and JAX's grad() transformation differentiates it to produce the gradient function used during training.
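The bodies of these helpers are not shown in this excerpt. The following is a hypothetical sketch that matches the call signatures used in main(), under several assumptions: the parameters are a dict with keys "w" and "b", training is plain full-batch gradient descent with an arbitrary learning rate and epoch count, and the data comes from NumPy's random generator instead of scikit-learn so the snippet stays self-contained.

```python
import jax
import jax.numpy as jnp
import numpy as np


def load_data(n_samples=100, n_features=3, seed=0):
    # Assumption: NumPy stand-in for scikit-learn's regression data,
    # a noisy linear problem y = X @ w_true + 0.5 + noise, split 75/25
    rng = np.random.default_rng(seed)
    X = rng.normal(size=(n_samples, n_features))
    w_true = rng.normal(size=(n_features,))
    y = X @ w_true + 0.5 + 0.1 * rng.normal(size=n_samples)
    split = int(0.75 * n_samples)
    return X[:split], y[:split], X[split:], y[split:]


def load_model(model_shape):
    # Hypothetical parameter layout: weight vector "w" and scalar bias "b"
    key = jax.random.PRNGKey(0)
    return {"b": jnp.zeros(()), "w": jax.random.uniform(key, model_shape)}


def loss_fn(params, X, y):
    # Mean squared error of the linear model
    preds = jnp.dot(X, params["w"]) + params["b"]
    return jnp.mean((preds - y) ** 2)


def train(params, grad_fn, X, y, num_epochs=50, lr=0.05):
    # Full-batch gradient descent; grad_fn is jax.grad(loss_fn)
    for _ in range(num_epochs):
        grads = grad_fn(params, X, y)
        params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    loss = loss_fn(params, X, y)
    return params, loss, len(X)


def evaluation(params, grad_fn, X_test, y_test):
    # grad_fn is unused here; it is kept only to match the call in main()
    loss_test = loss_fn(params, X_test, y_test)
    return loss_test, len(X_test)
```

Because jax.grad(loss_fn) differentiates with respect to the first argument, the gradient has the same dict structure as params, which is what lets tree_map apply the update leaf by leaf.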

# jax_training.py (excerpt): load_data, load_model, loss_fn, train and
# evaluation are defined in the same file
import jax


def main():
    # Load training and validation data
    X, y, X_test, y_test = load_data()
    model_shape = X.shape[1:]

    # Build the gradient function of the loss with jax.grad
    grad_fn = jax.grad(loss_fn)

    # Initialize the linear regression model parameters
    params = load_model(model_shape)

    # Train the model on the training set
    params, loss, num_examples = train(params, grad_fn, X, y)
    print("Training loss:", loss)

    # Evaluate the model (loss) on the test set
    loss, num_examples = evaluation(params, grad_fn, X_test, y_test)
    print("Evaluation loss:", loss)


if __name__ == "__main__":
    main()

In one round of federated learning, the server sends the global model parameters to a randomly selected set of clients. Each client trains the parameters on its local data and returns the updated parameters to the server, which aggregates the updates into a new (hopefully improved) global model. Rounds are repeated until the model converges.
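The round structure can be simulated in a few lines without any networking. This is a minimal sketch, not Flower code: NumPy stands in for JAX, and the number of clients, local dataset size, learning rate, and number of rounds are all illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
w_true = np.array([1.0, -2.0, 0.5])  # hypothetical ground-truth weights

# Hypothetical population: 10 clients, each holding 20 local examples
clients = []
for _ in range(10):
    X = rng.normal(size=(20, 3))
    y = X @ w_true + 0.01 * rng.normal(size=20)
    clients.append((X, y))


def local_train(w, X, y, epochs=5, lr=0.1):
    """Client side: a few epochs of gradient descent on the local MSE loss."""
    for _ in range(epochs):
        w = w - lr * 2 * X.T @ (X @ w - y) / len(y)
    return w


w_global = np.zeros(3)
for rnd in range(5):
    # 1. The server samples a subset of clients for this round
    selected = rng.choice(len(clients), size=3, replace=False)
    # 2-3. Each selected client trains on its own data and returns new weights
    updates = [local_train(w_global, *clients[i]) for i in selected]
    # 4. The server aggregates the updates into the next global model; every
    #    client holds 20 examples, so a weighted mean reduces to a plain mean
    w_global = np.mean(updates, axis=0)
    print(f"round {rnd + 1}: distance to w_true = {np.linalg.norm(w_global - w_true):.4f}")
```

Only model parameters cross the client/server boundary; the local (X, y) data never leaves its client.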

By default, the Flower server uses the Federated Averaging (FedAvg) strategy to aggregate the model parameter updates it receives from clients. The new global model based on the aggregated parameters is then sent to the next group of randomly selected clients to begin the next round of federated learning.
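FedAvg computes, for every parameter tensor, an average of the client values weighted by each client's number of training examples. The sketch below is a minimal illustration of that rule, not Flower's own implementation; it assumes each client result is a pair of (list of NumPy ndarrays, num_examples), mirroring the first two values a NumPyClient's fit() returns.

```python
from typing import List, Tuple

import numpy as np


def fedavg(results: List[Tuple[List[np.ndarray], int]]) -> List[np.ndarray]:
    """Weighted average of client parameters, computed tensor by tensor."""
    total = sum(num_examples for _, num_examples in results)
    num_tensors = len(results[0][0])
    return [
        sum(params[i] * num_examples for params, num_examples in results) / total
        for i in range(num_tensors)
    ]


# Two hypothetical clients, each sending [weights, bias] and its example count
client_a = ([np.array([1.0, 1.0]), np.array(0.0)], 100)
client_b = ([np.array([3.0, 5.0]), np.array(2.0)], 300)
new_global = fedavg([client_a, client_b])
# weights: (100*[1, 1] + 300*[3, 5]) / 400 = [2.5, 4.0]; bias: 600 / 400 = 1.5
```

Weighting by example count means a client with three times as much data pulls the global model three times as hard.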

To federate the example, simply reuse the functions from jax_training.py to perform local training on each client, and let Flower handle communication and aggregation. The federated training client code is described below.

To begin, import the necessary packages: Flower (package flwr), NumPy, and JAX, along with the jax_training module from the centralized example:

import flwr as fl
import numpy as np
import jax
import jax.numpy as jnp

from typing import Callable, Dict, List, Tuple

import jax_training

The main function in client.py is very similar to the centralized example: after loading the data and initializing the model parameters, it starts the Flower client with the local model and data.

def main() -> None:
    """Load data, start NumPyClient."""

    # Load data
    train_x, train_y, test_x, test_y = jax_training.load_data()

    # Build the gradient function of the loss with jax.grad
    grad_fn = jax.grad(jax_training.loss_fn)

    # Load model (from centralized training) and initialize parameters
    model_shape = train_x.shape[1:]
    params = jax_training.load_model(model_shape)

    # Start Flower client
    client = FlowerClient(params, grad_fn, train_x, train_y, test_x, test_y)
    fl.client.start_numpy_client("0.0.0.0:8080", client)


if __name__ == "__main__":
    main()

FlowerClient is the glue code that connects the local model and data to the Flower framework, so that Flower can call the regular training and evaluation routines. When the client is started (by calling start_client or start_numpy_client), it connects to the server, waits for messages from the server, handles each message by invoking the corresponding FlowerClient method, and sends the results back to the server for aggregation.
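As a toy illustration of this dispatch pattern (and not Flower's actual gRPC message handling), the sketch below routes hypothetical ("get_parameters" | "fit" | "evaluate", payload) messages to a duck-typed client. ToyClient, handle(), and the message format are all assumptions made for illustration, and the model is a NumPy linear regressor rather than the JAX one.

```python
import numpy as np


class ToyClient:
    """Hypothetical stand-in for FlowerClient with a NumPy linear model."""

    def __init__(self, X, y):
        self.X, self.y = X, y
        self.w = np.zeros(X.shape[1])

    def get_parameters(self):
        return [self.w.copy()]

    def fit(self, parameters, config):
        # Load the server's parameters, then run a few local gradient steps
        self.w = parameters[0].copy()
        for _ in range(10):
            self.w -= 0.1 * 2 * self.X.T @ (self.X @ self.w - self.y) / len(self.y)
        return self.get_parameters(), len(self.y), {}

    def evaluate(self, parameters, config):
        loss = float(np.mean((self.X @ parameters[0] - self.y) ** 2))
        return loss, len(self.y), {"loss": loss}


def handle(client, message):
    # Route a (hypothetical) server instruction to the matching client method
    kind, payload = message
    if kind == "get_parameters":
        return client.get_parameters()
    if kind == "fit":
        return client.fit(payload, {})
    if kind == "evaluate":
        return client.evaluate(payload, {})
    raise ValueError(f"unknown message type: {kind}")
```

For example, `handle(client, ("fit", params))` returns the updated parameters, the number of local examples, and a metrics dict, which is the same shape of result the real fit() method below hands back to Flower.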

This Flower client implementation consists of four methods: get_parameters(), set_parameters(), fit(), and evaluate(). get_parameters() collects the parameters of the locally defined model. Note that before the local model parameters can be sent to the Flower server to start server-side aggregation, they must be converted from JAX DeviceArrays to NumPy ndarrays using np.array().

The server's aggregation strategy (FedAvg by default) averages the parameters collected from the clients, weighted by each client's number of training examples, to produce the new global model parameters. The next set of clients receives the updated global parameters, and set_parameters() loads them into each client's local model. After training, the evaluation step runs, which completes a single round of federated learning.

class FlowerClient(fl.client.NumPyClient):
    """Flower client implementing linear regression using JAX"""

    def __init__(
        self,
        params: Dict,
        grad_fn: Callable,
        train_x: np.ndarray,
        train_y: np.ndarray,
        test_x: np.ndarray,
        test_y: np.ndarray,
    ) -> None:
        self.params = params
        self.grad_fn = grad_fn
        self.train_x = train_x
        self.train_y = train_y
        self.test_x = test_x
        self.test_y = test_y

    def get_parameters(self) -> List[np.ndarray]:
        # Convert the JAX arrays in the params dict to NumPy ndarrays,
        # keeping the dict's key order so set_parameters can map them back
        return [np.array(val) for val in self.params.values()]

    def set_parameters(self, parameters: List[np.ndarray]) -> Dict:
        # Write the received NumPy ndarrays back into the params dict, matched to
        # keys by position (same order as get_parameters), as JAX arrays
        for key, value in zip(list(self.params.keys()), parameters):
            self.params[key] = jnp.asarray(value)
        return self.params

    def fit(
        self, parameters: List[np.ndarray], config: Dict
    ) -> Tuple[List[np.ndarray], int, Dict]:
        # Set model parameters, train model, return updated model parameters
        print("Start local training")
        self.params = self.set_parameters(parameters)
        self.params, loss, num_examples = jax_training.train(self.params, self.grad_fn, self.train_x, self.train_y)
        results = {"loss": float(loss)}
        print("Training results", results)
        return self.get_parameters(), num_examples, results

    def evaluate(
        self, parameters: List[np.ndarray], config: Dict
    ) -> Tuple[float, int, Dict]:
        # Set model parameters, evaluate model on local test dataset, return result
        print("Start evaluation")
        self.params = self.set_parameters(parameters)
        loss, num_examples = jax_training.evaluation(self.params, self.grad_fn, self.test_x, self.test_y)
        print("Evaluation loss:", loss)
        return (
            float(loss),
            num_examples,
            {"loss": float(loss)},
        )

With server.py, a Flower server can now be set up.

import flwr as fl

if __name__ == "__main__":
    fl.server.start_server("0.0.0.0:8080", config={"num_rounds": 3})

Open a terminal window and start the server:

$ python server.py

Start the first client by opening a new terminal and typing:

$ python client.py

Finally, start the second client by opening a new terminal:

$ python client.py

This example uses Flower to federate the previously centralized JAX workload. To let Flower manage the complexity of federated learning, all that is required is to convert the JAX model parameters to and from NumPy ndarrays and to subclass NumPyClient.

The example can be extended in several directions, for instance by loading different data on each client, launching more clients, or using a different aggregation strategy.
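For the first of these extensions, one simple option is to give each client a disjoint shard of a shared dataset. The partition() helper below is hypothetical and not part of the Flower example: every client shuffles the indices with the same seed, so they all see the same permutation, then keeps only its own shard. How a client learns its client_id (for example, from a command-line argument to client.py) is likewise an assumption.

```python
import numpy as np


def partition(X, y, num_clients, client_id, seed=0):
    # Same seed on every client -> identical permutation -> disjoint shards
    rng = np.random.default_rng(seed)
    idx = rng.permutation(len(X))
    shard = np.array_split(idx, num_clients)[client_id]
    return X[shard], y[shard]


# Example: 10 rows split across 3 clients gives shards of 4, 3 and 3 rows
X = np.arange(20).reshape(10, 2)
y = np.arange(10)
shards = [partition(X, y, num_clients=3, client_id=i) for i in range(3)]
```

Inside client.py, the result of jax_training.load_data() could be passed through such a helper before constructing FlowerClient, so that each client trains on different data points.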

Check out the Advanced TensorFlow Example for a deeper dive into Flower's features.

Source: https://flower.dev/blog/2022-03-22-jax-meets-flower-federated-learning-with-jax/