
numpyro - Probabilistic programming with NumPy powered by JAX for autograd and JIT compilation to GPU/TPU/CPU

  •    Python

Probabilistic programming with NumPy powered by JAX for autograd and JIT compilation to GPU/TPU/CPU. NumPyro is a small probabilistic programming library that provides a NumPy backend for Pyro. We rely on JAX for automatic differentiation and JIT compilation to GPU / CPU. This is an alpha release under active development, so beware of brittleness, bugs, and changes to the API as the design evolves.

big_transfer - Official repository for the "Big Transfer (BiT): General Visual Representation Learning" paper

  •    Python

Update 18/06/2021: We release new high performing BiT-R50x1 models, which were distilled from BiT-M-R152x2, see this section. More details in our paper "Knowledge distillation: A good teacher is patient and consistent". Update 08/02/2021: We also release ALL BiT-M models fine-tuned on ALL 19 VTAB-1k datasets, see below.

dm-haiku - JAX-based neural network library

  •    Python

Haiku is a simple neural network library for JAX developed by some of the authors of Sonnet, a neural network library for TensorFlow. Documentation on Haiku can be found at https://dm-haiku.readthedocs.io/.

transformers - 🤗Transformers: State-of-the-art Natural Language Processing for Pytorch, TensorFlow, and JAX

  •    Python

🤗 Transformers provides thousands of pretrained models to perform tasks on text such as classification, information extraction, question answering, summarization, translation, text generation, and more, in over 100 languages. Its aim is to make cutting-edge NLP easier to use for everyone. 🤗 Transformers provides APIs to quickly download and use those pretrained models on a given text, fine-tune them on your own datasets, and then share them with the community on our model hub. At the same time, each Python module defining an architecture is fully standalone and can be modified to enable quick research experiments.

tensorly - TensorLy: Tensor Learning in Python.

  •    Python

TensorLy is a Python library that aims to make tensor learning simple and accessible. It lets you easily perform tensor decomposition, tensor learning, and tensor algebra. Its backend system lets you seamlessly run computations with NumPy, PyTorch, JAX, MXNet, TensorFlow, or CuPy, and execute methods at scale on CPU or GPU. The only prerequisite is Python 3; the easiest way to install is via the Anaconda distribution.

datasets - TFDS is a collection of datasets ready to use with TensorFlow, Jax, ...

  •    Python

TensorFlow Datasets provides many public datasets as tf.data.Datasets. To install and use TFDS, we strongly encourage you to start with our getting-started guide. Try it interactively in a Colab notebook.

jraph - A Graph Neural Network Library in Jax

  •    Python

Jraph (pronounced "giraffe") is a lightweight library for working with graph neural networks in JAX. It provides a data structure for graphs, a set of utilities for working with graphs, and a 'zoo' of forkable graph neural network models. Jraph is designed to provide utilities for working with graphs in JAX, but doesn't prescribe a way to write or develop graph neural networks.

awesome-jax - A curated list of JAX resources https://github.com/google/jax

  •    

JAX brings automatic differentiation and the XLA compiler together through a NumPy-like API for high-performance machine learning research on accelerators like GPUs and TPUs. The list highlights libraries that are well made and useful but have not necessarily been battle-tested by a large userbase yet.


fast-soft-sort - Fast Differentiable Sorting and Ranking

  •    Python

Differentiable sorting and ranking operations in O(n log n). Run python setup.py install or copy the fast_soft_sort/ folder to your project.

long-range-arena - Long Range Arena for Benchmarking Efficient Transformers

  •    Python

Long-range arena is an effort toward the systematic evaluation of efficient transformer models. The project aims to establish benchmark tasks/datasets with which we can evaluate transformer-based models in a systematic way, assessing their generalization power, computational efficiency, memory footprint, etc. Long-range arena also implements different variants of Transformer models in JAX, using Flax.

Newt - A Gaussian process library in JAX.

  •    Python

Newt is a Gaussian process (GP) library built in JAX (with objax), built and actively maintained by Will Wilkinson. Newt provides a unifying view of approximate Bayesian inference for GPs, and allows for the combination of many models (e.g. GPs, sparse GPs, Markov GPs, sparse Markov GPs) with the inference method of your choice (VI, EP, Laplace, Linearisation). For a full list of the methods implemented scroll down to the bottom of this page.

GPJax - A didactic Gaussian process package for researchers in Jax.

  •    Python

GPJax aims to provide a low-level interface to Gaussian process models. Code is written entirely in JAX to enhance readability and structured to allow researchers to easily extend it to suit their own needs. When defining a GP prior in GPJax, the user need only specify a mean and a kernel function. A GP posterior can then be realised by computing the product of the prior with a likelihood function. The idea is that the code should be as close as possible to the maths we would write on paper when working with GP models.
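GPJax's API has evolved across versions, so rather than guess at its current surface, here is the underlying maths the blurb alludes to (a posterior mean from a prior kernel plus Gaussian observation noise), written directly in plain JAX:

```python
import jax.numpy as jnp

def rbf(x1, x2, lengthscale=1.0):
    # Squared-exponential (RBF) kernel matrix between two 1-D input sets.
    d = x1[:, None] - x2[None, :]
    return jnp.exp(-0.5 * (d / lengthscale) ** 2)

x = jnp.linspace(0.0, 1.0, 5)      # training inputs
y = jnp.sin(2.0 * jnp.pi * x)      # training targets
x_test = jnp.array([0.25, 0.75])   # query points

noise = 1e-3
K = rbf(x, x) + noise * jnp.eye(x.shape[0])  # prior covariance + noise
K_star = rbf(x, x_test)

# Posterior mean: K_*^T (K + sigma^2 I)^{-1} y
mean = K_star.T @ jnp.linalg.solve(K, y)
```

This is exactly the prior-times-likelihood calculation the package wraps in higher-level objects.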

extending-jax - Extending JAX with custom C++ and CUDA code

  •    Python

The motivation for this is that in my work I want to use libraries like JAX to fit models to data in astrophysics. In these models, there is often at least one part of the model specification that is physically motivated and while there are generally existing implementations of these model elements, it is often inefficient or impractical to re-implement these as a high-level JAX function. Instead, I want to expose a well-tested and optimized implementation in C directly to JAX. In my work, this often includes things like iterative algorithms or special functions that are not well suited to implementation using JAX directly. So, as part of updating my exoplanet library to interface with JAX, I had to learn what infrastructure was required to support this use case, and since I couldn't find a tutorial that covered all the pieces that I needed in one place, I wanted to put this together. Pretty much everything that I'll talk about is covered in more detail somewhere else (even if that somewhere is just a comment in some source code), but hopefully this summary can point you in the right direction if you have a use case like this.

funsor - Functional tensors for probabilistic programming

  •    Python

Funsor is a tensor-like library for functions and distributions. See Functional tensors for probabilistic programming for a system description.

SymJAX - Symbolic CPU/GPU/TPU programming built on JAX

  •    Python

SymJAX is, most importantly, a symbolic/declarative programming environment allowing concise, explicit, and optimized computations. For a deep-network-oriented imperative library built on JAX with a JAX syntax, check out Flax.

scenic - Scenic: A Jax Library for Computer Vision and Beyond

  •    Python

Scenic is a codebase with a focus on research around attention-based models for computer vision. Scenic has been successfully used to develop classification, segmentation, and detection models for multiple modalities, including images, video, audio, and multimodal combinations of them. More precisely, Scenic is (i) a set of shared lightweight libraries solving tasks commonly encountered when training large-scale (i.e. multi-device, multi-host) vision models; and (ii) a number of projects containing fully fleshed-out, problem-specific training and evaluation loops using these libraries.

dm_pix - PIX is an image processing library in JAX, for JAX.

  •    Python

PIX is an image processing library in JAX, for JAX. JAX is a library resulting from the union of Autograd and XLA for high-performance machine learning research. It provides NumPy, SciPy, automatic differentiation and first-class GPU/TPU support.

brax - Massively parallel rigidbody physics simulation on accelerator hardware.

  •    Jupyter

Brax is a differentiable physics engine that simulates environments made up of rigid bodies, joints, and actuators. It's also a suite of learning algorithms to train agents to operate in these environments (PPO, SAC, evolutionary strategy, and direct trajectory optimization are implemented). Brax is written in JAX and designed for use on accelerator hardware. It is both efficient for single-device training and scalable to massively parallel simulation, without the need for pesky datacenters.

robustness-vit - Contains code for the paper "Vision Transformers are Robust Learners".

  •    Jupyter

This repository contains the code for the paper Vision Transformers are Robust Learners by Sayak Paul* and Pin-Yu Chen*. *Equal contribution.

veros - The versatile ocean simulator, in pure Python, powered by JAX.

  •    Python

Veros is the versatile ocean simulator -- it aims to be a powerful tool that makes high-performance ocean modeling approachable and fun. Because Veros is a pure Python module, the days of struggling with complicated model setup workflows, ancient programming environments, and obscure legacy code are finally over. Veros supports a NumPy backend for small-scale problems and a high-performance JAX backend with CPU and GPU support. It is fully parallelized via MPI and supports distributed execution.
