Jax框架的性能分析——性能分析可视化

时间:2024-01-22 16:33:41

官方文档:

https://jax.readthedocs.io/en/latest/profiling.html



  1. 将jax代码的性能文件写入到文件夹中,并给出上传第三方网站的链接生成(https://ui.perfetto.dev/):


import jax

with jax.profiler.trace("/tmp/jax-trace", create_perfetto_link=True):
  # Run the operations to be profiled
  key = jax.random.PRNGKey(0)
  x = jax.random.normal(key, (5000, 5000))
  y = x @ x
  y.block_until_ready()


可以通过TensorFlow的tensorboard来查看指定生成的文件夹,来在本地进行显示;

也可以通过点击生成的链接将性能文件自动上传到第三方网站并查看(设置生成链接后会中断进程运行,直至手动打开该生成链接,然后进程才会继续执行):

Jax框架的性能分析——性能分析可视化_性能分析


使用TensorFlow的tensorboard进行性能文件的显示需要安装如下library:

pip install tensorflow tensorboard-plugin-profile


  1. 对部分jax的代码进行性能分析,并通过端口将性能文件转发给另一进程,然后再另一进程中生成上传链接,点击链接后上传第三方网站后查看,该种方式不会中断原进程的运行,也不需要等待原进程运行结束。(该种方式最大的不同就是生成第三方链接不是在原进程中,不会影响原进程的运行)

原进程运行代码(待性能分析的代码,需要设置服务端口号,这里是8877):

import jax

jax.profiler.start_server(8877)
key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (1000, 1000))
for _ in range(100000):
    y = x @ x
y.block_until_ready()
jax.profiler.stop_server()


另起一个进程,检测端口号8877,接受到性能分析信息后生成上传链接:

python -m jax.collect_profile 8877 1000

这里的8877是之前的服务端口号,这里需要对此进行监听,1000是指该进程的运行时间,这个时间可以设置的大一些,这里设置为1000秒。

这里需要注意,原进程启动后才可以启动链接生成进程,否则连接不到端口会报错,也就是说在生成链接进程生成成功之前原进程不能结束,因此我们可以在原进程的最终位置加入sleep函数。


原进程:

Jax框架的性能分析——性能分析可视化_性能分析_02


链接生成进程:

Jax框架的性能分析——性能分析可视化_性能分析_03


Jax框架的性能分析——性能分析可视化_性能分析_04


需要注意:

进行对jax的性能信息收集的时候,需要对显卡进行独占(只能运行一个CUDA进程),否则会报错,不过可以通过修改默认设置取消该特性,不过为保证性能分析的准确性(防止同时运行其他进程,对性能分析造成影响)不建议更改默认设置:

修改默认设置,允许其他进程运行的情况下启动性能分析进程,设置环境变量:

TF_GPU_CUPTI_FORCE_CONCURRENT_KERNEL=1


Jax框架的性能分析——性能分析可视化_端口号_05