【文件属性】:
文件名称:基于Numpy/JAX/JIT的Pyro(深度概率编程)-python
文件大小:3.43MB
文件格式:ZIP
更新时间:2021-06-18 21:01:46
机器学习
基于Numpy/JAX/JIT的Pyro(深度概率编程)
NumPyro 使用由 JAX 提供支持的 NumPy 进行概率编程,用于自动梯度和 JIT 编译到 GPU/TPU/CPU。
文档 |
示例 |
论坛 什么是 NumPyro?
NumPyro 是一个小型概率编程库,它为 Pyro 提供了一个 NumPy 后端。
我们依靠 JAX 进行自动微分和 JIT 编译到 GPU/CPU。
这是一个正在积极开发中的 alpha 版本,因此请注意随着设计的发展,API 的脆弱性、错误和更改。
NumPyro 设计为轻量级,并专注于提供用户可以在其上构建的灵活基础: Pyro 原语:除了示例和参数等 Pyro 原语之外,NumPyro 程序还可以包含常规 Python 和 NumPy 代码。
除了 PyTorch 和 Numpy 的 API 之间的一些细微差别外,模型代码应该与 Pyro 非常相似。
请参阅下面的示例。
推理算法:NumPyro 目前支持 Hamiltonian Monte Carlo,包括 No U-Turn Sampler 的实现。
NumPyro