TensorFlow, PyTorch, Caffe, Keras, Theano, and many more. There’s already an abundance of deep learning frameworks, so why should you care about Trax? Well, most deep learning libraries have two major drawbacks:
- They require you to write verbose code, even for simple tasks.
- Their language/API can be quite complex and hard to understand, especially for complicated architectures.
PyTorch Lightning and Keras solve this issue to a great extent, but they are just high-level wrapper APIs around complicated packages. On the other hand, Trax is built from the ground up for speed and clear, concise code, even when dealing with large, complex models. As the developers put it, Trax is “Your path to advanced deep learning”. Also, it’s actively used and maintained by the Google Brain team.
The codebase is organized around SOLID architecture and design principles, and it provides well-formatted logging. Trax is built on the JAX library, which provides high-performance code acceleration through Autograd and XLA: Autograd lets JAX automatically differentiate native Python and NumPy code, while XLA just-in-time compiles and executes programs on GPU and Cloud TPU accelerators. Trax can be used as a library in Python scripts and notebooks, or as a binary from the shell, which makes training larger models more convenient. One thing to note is that Trax is oriented more towards natural language processing than computer vision.
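As a quick illustration of those two JAX features (plain JAX, not Trax-specific, and not from the original article), automatic differentiation and XLA JIT compilation look like this:

# Minimal JAX sketch: Autograd-style differentiation and XLA JIT compilation.
import jax
import jax.numpy as jnp

def loss(w, x):
    return jnp.sum((w * x - 1.0) ** 2)

grad_loss = jax.grad(loss)   # automatic differentiation of plain Python/NumPy code
fast_loss = jax.jit(loss)    # just-in-time compiled via XLA, runs on GPU/TPU if available

w = jnp.array([0.5, 2.0])
x = jnp.array([1.0, 3.0])
print(fast_loss(w, x))       # compiled on first call
print(grad_loss(w, x))       # gradient of the loss with respect to w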
A brief introduction to Trax’s high-level syntax
- Install Trax from PyPI:
!pip install trax
- To work with layers in Trax, import the layers module. A basic sigmoid layer can be instantiated with activation_fns.Sigmoid(); you can find the details of all layers here.
# Make a sigmoid activation layer
from trax import layers as ly

sigmoid = ly.activation_fns.Sigmoid()

# Some attributes
print("name :", sigmoid.name)
print("weights :", sigmoid.weights)
print("# of inputs :", sigmoid.n_in)
print("# of outputs :", sigmoid.n_out)
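Like any Trax layer, the object is callable, so it can be applied directly to an array (a small illustrative example, not from the original snippet):

import numpy as np

# Stateless activation layers need no weight initialization before being called.
x = np.array([-2.0, 0.0, 2.0])
print("sigmoid(x) :", sigmoid(x))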
- Trax also lets you define custom layers dynamically by wrapping a plain Python function with trax.layers.base.Fn:
import numpy as np
from trax import layers as ly

# Define a custom layer
def Custom_layer():
    # Set a name
    layer_name = "custom_layer"

    # Custom function
    def func(x):
        return x + x**2

    return ly.base.Fn(layer_name, func)

# Create the layer object
custom_layer = Custom_layer()

# Check properties
print("name :", custom_layer.name)
print("expected inputs :", custom_layer.n_in)
print("promised outputs :", custom_layer.n_out)

# Inputs
x = np.array([0, -1, 1])

# Outputs
print("outputs :", custom_layer(x))
- Models are built from layers using combinators like trax.layers.combinators.Serial, trax.layers.combinators.Parallel, and trax.layers.combinators.Branch (a short sketch of Parallel and Branch follows the example below). Here’s a simple text-classification model implemented in Trax:
model = ly.Serial(
    ly.Embedding(vocab_size=8192, d_feature=256),
    ly.Mean(axis=1),  # Average on axis 1 (length of sentence).
    ly.Dense(2),      # Classify 2 classes.
)

# Print model structure.
print(model)
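The other combinators work in a similar way. Here is a rough sketch (not from the original article) of Branch, which feeds copies of one input to several sub-layers, and Parallel, which gives each sub-layer its own input:

import numpy as np
from trax import layers as ly

x = np.array([-2.0, -1.0, 0.0, 1.0, 2.0])

# Branch: one input is copied to each sub-layer (1 input -> 2 outputs here).
branch = ly.Branch(ly.Relu(), ly.activation_fns.Sigmoid())
relu_out, sigmoid_out = branch(x)
print("branch outputs :", relu_out, sigmoid_out)

# Parallel: each sub-layer gets its own input (2 inputs -> 2 outputs here).
parallel = ly.Parallel(ly.Relu(), ly.activation_fns.Sigmoid())
out_a, out_b = parallel((x, x))
print("parallel outputs :", out_a, out_b)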
- Trax has access to a large number of datasets, including Tensor2Tensor and TensorFlow Datasets. Data streams in Trax are represented as Python iterators; here’s the code to import the TFDS IMDb reviews dataset using trax.data:
import trax

train_stream = trax.data.TFDS('imdb_reviews', keys=('text', 'label'), train=True)()
eval_stream = trax.data.TFDS('imdb_reviews', keys=('text', 'label'), train=False)()
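These raw streams yield individual (text, label) examples, so they are typically tokenized and batched before training. The pipeline below is a sketch adapted from the Trax README (the vocabulary file and bucket sizes are illustrative); it produces the train_batches_stream and eval_batches_stream used in the training example that follows.

# Sketch of a tokenize / shuffle / batch pipeline (values are illustrative).
data_pipeline = trax.data.Serial(
    trax.data.Tokenize(vocab_file='en_8k.subword', keys=[0]),
    trax.data.Shuffle(),
    trax.data.FilterByLength(max_length=2048, length_keys=[0]),
    trax.data.BucketByLength(boundaries=[32, 128, 512, 2048],
                             batch_sizes=[256, 64, 16, 4, 1],
                             length_keys=[0]),
    trax.data.AddLossWeights(),
)
train_batches_stream = data_pipeline(train_stream)
eval_batches_stream = data_pipeline(eval_stream)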
- You can train supervised and reinforcement learning models in Trax using trax.supervised.training and trax.rl, respectively. Here’s an example of training a supervised learning model:
import os
from trax import layers as tl
from trax.supervised import training

# Training task
train_task = training.TrainTask(
    labeled_data=train_batches_stream,
    loss_layer=tl.WeightedCategoryCrossEntropy(),
    optimizer=trax.optimizers.Adam(0.01),
    n_steps_per_checkpoint=500,
)

# Evaluation task
eval_task = training.EvalTask(
    labeled_data=eval_batches_stream,
    metrics=[tl.WeightedCategoryCrossEntropy(), tl.WeightedCategoryAccuracy()],
    n_eval_batches=20,  # For less variance in eval numbers.
)

# Training loop saves checkpoints to output_dir.
output_dir = os.path.expanduser('~/output_dir/')
!rm -rf {output_dir}
training_loop = training.Loop(model,
                              train_task,
                              eval_tasks=[eval_task],
                              output_dir=output_dir)

# Run 2000 steps (batches).
training_loop.run(2000)
After training, the models can be run like any function:
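As a hedged sketch (names follow the snippets above, and the exact inference code may differ), running the trained classifier on one example from the eval stream could look like this:

import numpy as np

# Take one tokenized example from the eval stream and add a batch dimension.
example_input = next(eval_batches_stream)[0][0]
sentiment_log_probs = model(example_input[None, :])
print("log-probabilities :", sentiment_log_probs)
print("probabilities :", np.exp(sentiment_log_probs))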