转换后的 JAX return 类型

JAX return types after transformations

为什么 jax.grad 的 return 类型与以下场景中的其他 jax 转换不同?

考虑一个由 JAX 转换的函数,它将自定义容器作为参数

import jax
import jax.numpy as jnp
from collections import namedtuple

# Define NTContainer using namedtuple factory
NTContainer = namedtuple(
    'NTContainer',['name','flag','array']
)

# Instantiate container
container = NTContainer(
   name='composite-container',
   flag=0,
   array=jnp.reshape(jnp.arange(6,dtype='f4'),(3,2)),
)

# Test function
def test_func(container : NTContainer):
    if container.flag == 0:
        return container.array.sum()
    return container.array.prod()

需要通知 JAX 如何正确处理 NTContainer(默认 namedtuple pytree 不能使用)

# Register NTContainer pytree to handle static and traceable members

def unpack_NTC(c):
    active, passive = (c.array,), (c.name, c.flag)
    return active, passive

def repack_NTC(passive, active):
    (name, flag), (array,) = passive, active
    return NTContainer(name,flag,value)

jax.tree_util.register_pytree_node(
    NTContainer, unpack_NTC, repack_NTC,
)

现在执行几个 jax 转换并调用 container 结果

jax.jit(test_func)(container)
# DeviceArray(15., dtype=float32)

jax.vmap(test_func)(container)
# DeviceArray([1., 5., 9.], dtype=float32)

jax.grad(test_func)(container)
# NTContainer(name='composite-container',
#     flag=0,
#     array=DeviceArray([[1., 1.],
#         [1., 1.],
#         [1., 1.]], dtype=float32))

为什么 jax.grad 转换调用 return 一个 NTContainer 而不是 DeviceArray

简短的回答是,这是 return 值,因为这是 grad 的定义方式,老实说,我不确定它还能如何定义。

思考这些转变;

  • jit(f)(x) 计算应用于 xf 的输出:结果是一个标量。
  • vmap(f)(x) 计算应用于 x 中数组的每个元素的 f 的输出:结果是一个向量。
  • grad(f)(x)计算f相对于x的每个分量的梯度,结果是x的每个分量一个值。

所以grad(f)(x)在这种情况下必须return多个值的集合是有道理的,但是应该如何打包这个渐变集合?

JAX 可以定义 grad 以便它始终 return 将这些作为列表或元组,但这可能会让人难以理解输出梯度与输入的关系,特别是对于更复杂的结构,如嵌套列表、元组和字典。因此,JAX 将这些 per-component 梯度和 re-bundles 它们采用与输入参数相同的结构:因此,在这种情况下,结果是 NTContainer 包含那些 per-component 梯度。