Probabilistic programming with NumPy powered by JAX for autograd and JIT compilation to GPU/TPU/CPU. NumPyro is a small probabilistic programming library that provides a NumPy backend for Pyro. We rely on JAX for automatic differentiation and JIT compilation to GPU / CPU. This is an alpha release under active development, so beware of brittleness, bugs, and changes to the API as the design evolves.
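To give a flavour of the library, here is a minimal sketch of a NumPyro model and a NUTS run (the data here are made up purely for illustration):

```python
import jax.numpy as jnp
from jax import random
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS

def model(data):
    # Prior over the unknown mean, Gaussian likelihood for the observations.
    mu = numpyro.sample("mu", dist.Normal(0.0, 10.0))
    numpyro.sample("obs", dist.Normal(mu, 1.0), obs=data)

data = jnp.array([0.3, -0.1, 0.8, 0.2])  # toy data, illustrative only
mcmc = MCMC(NUTS(model), num_warmup=500, num_samples=1000)
mcmc.run(random.PRNGKey(0), data)
mcmc.print_summary()
```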
numpy probabilistic-programming bayesian-inference hmc pyro jax inference-algorithms
Update 18/06/2021: We release new high-performing BiT-R50x1 models, which were distilled from BiT-M-R152x2; see this section. More details in our paper "Knowledge distillation: A good teacher is patient and consistent". Update 08/02/2021: We also release ALL BiT-M models fine-tuned on ALL 19 VTAB-1k datasets, see below.
deep-learning pytorch imagenet convolutional-neural-networks transfer-learning jax tensorflow2
Haiku is a simple neural network library for JAX developed by some of the authors of Sonnet, a neural network library for TensorFlow. Documentation on Haiku can be found at https://dm-haiku.readthedocs.io/.
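As a minimal sketch of the core idea: hk.transform turns an object-oriented forward function into a pair of pure init/apply functions (the layer sizes and batch shape here are arbitrary):

```python
import haiku as hk
import jax
import jax.numpy as jnp

def forward(x):
    # Modules are created inside the transformed function.
    return hk.nets.MLP([32, 10])(x)

net = hk.transform(forward)
x = jnp.ones([8, 128])                        # dummy batch
params = net.init(jax.random.PRNGKey(42), x)  # pure: returns parameters
logits = net.apply(params, None, x)           # pure: consumes parameters
```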
machine-learning deep-neural-networks deep-learning neural-networks jax
🤗 Transformers provides thousands of pretrained models to perform tasks on texts such as classification, information extraction, question answering, summarization, translation, text generation and more in over 100 languages. Its aim is to make cutting-edge NLP easier to use for everyone. 🤗 Transformers provides APIs to quickly download and use those pretrained models on a given text, fine-tune them on your own datasets and then share them with the community on our model hub. At the same time, each Python module defining an architecture is fully standalone and can be modified to enable quick research experiments.
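The quickest entry point is the pipeline API; a minimal sketch (the input sentence is arbitrary):

```python
from transformers import pipeline

# Downloads a default pretrained model and tokenizer on first use.
classifier = pipeline("sentiment-analysis")
print(classifier("We are very happy to show you the 🤗 Transformers library."))
# Prints something like: [{'label': 'POSITIVE', 'score': 0.99...}]
```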
nlp natural-language-processing tensorflow pytorch transformer speech-recognition seq2seq flax gpt pretrained-models language-models natural-language-generation nlp-library language-model bert natural-language-understanding jax xlnet pytorch-transformers model-hub
TensorLy is a Python library that aims at making tensor learning simple and accessible. It lets you easily perform tensor decomposition, tensor learning and tensor algebra. Its backend system lets you seamlessly run computations with NumPy, PyTorch, JAX, MXNet, TensorFlow or CuPy, and run methods at scale on CPU or GPU. The only prerequisite is to have Python 3 installed; the easiest way to get it is via the Anaconda distribution.
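A minimal sketch of the backend system and a CP decomposition (the random tensor and the rank are chosen arbitrarily):

```python
import numpy as np
import tensorly as tl
from tensorly.decomposition import parafac

tl.set_backend('jax')  # could also be 'numpy', 'pytorch', 'tensorflow', ...
X = tl.tensor(np.random.rand(10, 10, 10))
cp_tensor = parafac(X, rank=3)       # CP/PARAFAC decomposition into rank-3 factors
X_hat = tl.cp_to_tensor(cp_tensor)   # reconstruct the tensor from its factors
```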
machine-learning mxnet tensorflow numpy pytorch decomposition tensor-factorization tensor tensor-algebra tensorly tensor-learning tensor-decomposition cupy tensor-regressions tensor-methods jax
TensorFlow Datasets provides many public datasets as tf.data.Datasets. To install and use TFDS, we strongly encourage you to start with our getting started guide. Try it interactively in a Colab notebook.
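A minimal sketch of loading a dataset (MNIST here, chosen arbitrarily):

```python
import tensorflow_datasets as tfds

ds = tfds.load('mnist', split='train', shuffle_files=True)
for example in ds.take(1):          # ds is a standard tf.data.Dataset
    image, label = example['image'], example['label']
    print(image.shape, label)       # (28, 28, 1) and a scalar label
```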
data machine-learning tensorflow numpy dataset datasets jax
Jraph (pronounced "giraffe") is a lightweight library for working with graph neural networks in JAX. It provides a data structure for graphs, a set of utilities for working with graphs, and a 'zoo' of forkable graph neural network models. Jraph is designed to provide utilities for working with graphs in JAX, but doesn't prescribe a way to write or develop graph neural networks.
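A minimal sketch of the core data structure, GraphsTuple (the feature sizes here are arbitrary):

```python
import jax.numpy as jnp
import jraph

# A single graph with 3 nodes and 2 directed edges.
graph = jraph.GraphsTuple(
    nodes=jnp.ones((3, 4)),       # 3 nodes with 4 features each
    edges=jnp.ones((2, 5)),       # 2 edges with 5 features each
    senders=jnp.array([0, 1]),    # edge i goes from senders[i] ...
    receivers=jnp.array([1, 2]),  # ... to receivers[i]
    n_node=jnp.array([3]),
    n_edge=jnp.array([2]),
    globals=jnp.ones((1, 6)),     # one global feature vector per graph
)
```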
machine-learning deep-learning jax graph-neural-networks
JAX brings automatic differentiation and the XLA compiler together through a NumPy-like API for high performance machine learning research on accelerators like GPUs and TPUs. This section contains libraries that are well-made and useful, but have not necessarily been battle-tested by a large user base yet.
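In practice that means composable function transformations over NumPy-like code; a minimal sketch:

```python
import jax.numpy as jnp
from jax import grad, jit

def loss(w, x):
    return jnp.sum((x @ w) ** 2)

fast_grad = jit(grad(loss))  # differentiate, then JIT-compile via XLA
w = jnp.ones(3)
x = jnp.ones((5, 3))
print(fast_grad(w, x))       # gradient of loss with respect to w
```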
machine-learning awesome deep-learning neural-network numpy autograd awesome-list jax xla
Differentiable sorting and ranking operations in O(n log n). Run python setup.py install or copy the fast_soft_sort/ folder to your project.
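A minimal sketch of the JAX operators, assuming the repository's fast_soft_sort.jax_ops module (inputs are 2D, one row per instance to rank):

```python
import jax.numpy as jnp
from fast_soft_sort import jax_ops  # assumes the repo layout described above

values = jnp.array([[5.0, 1.0, 3.0]])  # one row per instance
# Soft (differentiable) relaxations of rank and sort.
soft_ranks = jax_ops.soft_rank(values, regularization_strength=1.0)
soft_sorted = jax_ops.soft_sort(values, regularization_strength=1.0)
```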
sorting tensorflow pytorch ranking differentiable jax
Long-range arena is an effort toward systematic evaluation of efficient transformer models. The project aims to establish benchmark tasks/datasets with which transformer-based models can be evaluated systematically, by assessing their generalization power, computational efficiency, memory footprint, etc. Long-range arena also implements different variants of Transformer models in JAX, using Flax.
nlp deep-learning transformers attention flax jax
Newt is a Gaussian process (GP) library built in JAX (with objax), actively maintained by Will Wilkinson. Newt provides a unifying view of approximate Bayesian inference for GPs, and allows for the combination of many models (e.g. GPs, sparse GPs, Markov GPs, sparse Markov GPs) with the inference method of your choice (VI, EP, Laplace, Linearisation). For a full list of the methods implemented, scroll down to the bottom of this page.
machine-learning signal-processing gaussian-processes state-space-models jax approximate-bayesian-inference sparse-gps markov-gps
GPJax aims to provide a low-level interface to Gaussian process models. Code is written entirely in JAX to enhance readability, and structured to allow researchers to easily extend the code to suit their own needs. When defining a GP prior in GPJax, the user need only specify a mean and kernel function. A GP posterior can then be realised by computing the product of our prior with a likelihood function. The idea behind this is that the code should be as close as possible to the maths that we would write on paper when working with GP models. After importing the necessary dependencies, we'll first simulate some data.
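A minimal sketch of that workflow; the exact class names (Prior, RBF, Gaussian) are assumed from one release of the library and vary across versions, so treat this as illustrative of the prior-times-likelihood idea rather than as the current API:

```python
import jax.numpy as jnp
import jax.random as jr
import gpjax as gpx

# Simulate some toy data: a noisy sine wave.
key = jr.PRNGKey(123)
x = jnp.linspace(-3.0, 3.0, 50).reshape(-1, 1)
y = jnp.sin(x) + 0.1 * jr.normal(key, x.shape)

# Prior = mean + kernel; posterior = prior * likelihood (names assumed).
prior = gpx.Prior(kernel=gpx.RBF())
likelihood = gpx.Gaussian(num_datapoints=y.shape[0])
posterior = prior * likelihood
```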
probabilistic-programming gaussian-processes jax
The motivation for this is that in my work I want to use libraries like JAX to fit models to data in astrophysics. In these models, there is often at least one part of the model specification that is physically motivated, and while there are generally existing implementations of these model elements, it is often inefficient or impractical to re-implement them as high-level JAX functions. Instead, I want to expose a well-tested and optimized implementation in C directly to JAX. In my work, this often includes things like iterative algorithms or special functions that are not well suited to implementation using JAX directly. So, as part of updating my exoplanet library to interface with JAX, I had to learn what infrastructure was required to support this use case, and since I couldn't find a tutorial that covered all the pieces I needed in one place, I wanted to put this together. Pretty much everything that I'll talk about is covered in more detail somewhere else (even if that somewhere is just a comment in some source code), but hopefully this summary can point you in the right direction if you have a use case like this.
cuda jax xla
Funsor is a tensor-like library for functions and distributions. See "Functional tensors for probabilistic programming" for a system description.
machine-learning numpy pytorch symbolic probabilistic-programming pyro jax
SymJAX offers, most importantly, a symbolic/declarative programming environment allowing concise, explicit, and optimized computations. For a deep-network-oriented imperative library built on JAX and with a JAX syntax, check out FLAX.
deep-neural-networks lasagne theano deep-learning tensorflow numpy dataset jax
Scenic is a codebase with a focus on research around attention-based models for computer vision. Scenic has been successfully used to develop classification, segmentation, and detection models for multiple modalities including images, video, audio, and multimodal combinations of them. More precisely, Scenic is (i) a set of shared lightweight libraries solving tasks commonly encountered when training large-scale (i.e. multi-device, multi-host) vision models; and (ii) a number of projects containing fully fleshed out problem-specific training and evaluation loops using these libraries.
research computer-vision deep-learning transformers attention jax vision-transformer
PIX is an image processing library in JAX, for JAX. JAX is a library resulting from the union of Autograd and XLA for high-performance machine learning research. It provides NumPy, SciPy, automatic differentiation and first-class GPU/TPU support.
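A minimal sketch, assuming dm_pix as the import name (the dummy image and the specific operations are chosen for illustration):

```python
import jax.numpy as jnp
import dm_pix as pix

image = jnp.ones((128, 128, 3))        # dummy HWC image
flipped = pix.flip_left_right(image)   # image ops are differentiable and jit-able
blurred = pix.gaussian_blur(image, sigma=2.0, kernel_size=5)
```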
machine-learning image computer-vision image-processing jax
Brax is a differentiable physics engine that simulates environments made up of rigid bodies, joints, and actuators. It's also a suite of learning algorithms to train agents to operate in these environments (PPO, SAC, evolutionary strategy, and direct trajectory optimization are implemented). Brax is written in JAX and is designed for use on acceleration hardware. It is both efficient for single-core training and scalable to massively parallel simulation, without the need for pesky datacenters.
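A minimal interaction-loop sketch, assuming the v1-era brax.envs API (the environment name and the zero action are arbitrary placeholders):

```python
import jax
import jax.numpy as jnp
from brax import envs

env = envs.create(env_name='ant')             # environment name assumed
state = env.reset(rng=jax.random.PRNGKey(0))
for _ in range(10):
    action = jnp.zeros(env.action_size)       # placeholder policy
    state = env.step(state, action)           # pure, jit-able step function
```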
reinforcement-learning robotics physics-simulation jax
This repository contains the code for the paper "Vision Transformers are Robust Learners" by Sayak Paul* and Pin-Yu Chen*. *Equal contribution.
computer-vision tensorflow transformers pytorch robustness self-attention jax
Veros is the versatile ocean simulator -- it aims to be a powerful tool that makes high-performance ocean modeling approachable and fun. Because Veros is a pure Python module, the days of struggling with complicated model setup workflows, ancient programming environments, and obscure legacy code are finally over. Veros supports a NumPy backend for small-scale problems and a high-performance JAX backend with CPU and GPU support. It is fully parallelized via MPI and supports distributed execution.
gpu climate parallel distributed geophysics oceanography multi-core jax