首页
Preview

JAX和可组合的程序转换

https://github.com/google/jax 的“关于”部分如下所述。

Composable transformations of Python+NumPy programs: differentiate, 
vectorize, JIT to GPU/TPU, and more

“可组合变换”是什么意思?在 NeurIPS2020: JAX 生态系统 Meetup 的视频中,DeepMind 的一位工程师进行了解释。

举个例子,考虑以下函数。

def fn(x, y):
  return x**2 + y

fn(1., 2.) # (1**2 + 2) = 3

如何使用 JAX 编写以计算此函数的梯度?

df_dx = grad(fn)
df_dx(1., 2.) # df_dx = 2*x = 2*1 = 2

在上述代码中,grad 是返回函数的函数,df_dx 也是一个函数。然后可以在普通函数调用中使用它。

接下来,我们如何编写二阶梯度?

df2_dx = grad(grad(fn))
df2_dx(1., 2.) # df2_dx = d(2*x)_dx = 2

由于 grad 是可组合的,所以只需添加另一个 grad 即可。

不仅 grad 是可组合的,其他变换也可以使用。以下代码允许我们计算编译后的二阶梯度。

df2_dx = jit(grad(grad(fn)))
df2_dx(1., 2.) # 2, compiled here.
df2_dx(1., 2.) # 2, execute pre-compiled code, so it can be executed faster than the first time

此外,批处理计算也可以作为可组合变换添加。(即批处理编译后的二阶梯度)

df2_dx = vmap(jit(grad(grad(fn))))
xs = jnp.ones((batch_size,))
df2_dx(xs, 2 * xs) # [2, 2] if batch_size=2

在运行多个加速器(例如 GPU)时,也可以将其作为可组合变换添加。(即多 GPU 批处理编译后的二阶梯度)

df2_dx = pmap(vmap(jit(grad(grad(fn)))))
xs = jnp.ones((num_gpus, batch_size,))
df2_dx(xs, 2 * xs) # [[2, 2], [2, 2]] if batch_size=2 and num_gpus=2

这就是在这个5分钟的演示中解释的内容,我很惊讶地了解了 JAX 中可组合变换的含义。视频还讨论了像 Haiku 和 Optax 这样的生态系统,以及各种 JAX 实现的其他示例,例如 GANs,这些都很有启发性。

译自:https://medium.com/@satojkovic/jax-and-composable-program-transformations-d54ecbb9c39b

版权声明:本文内容由TeHub注册用户自发贡献,版权归原作者所有,TeHub社区不拥有其著作权,亦不承担相应法律责任。 如果您发现本社区中有涉嫌抄袭的内容,填写侵权投诉表单进行举报,一经查实,本社区将立刻删除涉嫌侵权内容。

点赞(0)
收藏(0)
alivne
复杂的问题简单化

评论(0)

添加评论