章节 01
导读 / 主楼:Blox:拥抱函数式编程的JAX神经网络库
Blox是一个轻量级函数式神经网络库,完全拥抱JAX的函数式特性,通过显式状态传递实现透明的模型定义,支持JIT编译、自动微分和并行计算,无需任何魔法装饰器。
正文
Blox是一个轻量级函数式神经网络库,完全拥抱JAX的函数式特性,通过显式状态传递实现透明的模型定义,支持JIT编译、自动微分和并行计算,无需任何魔法装饰器。
章节 01
Blox是一个轻量级函数式神经网络库,完全拥抱JAX的函数式特性,通过显式状态传递实现透明的模型定义,支持JIT编译、自动微分和并行计算,无需任何魔法装饰器。
章节 02
python\noutputs, params = model(params, inputs)\n\n\n参数进去,输出和更新后的参数出来。这就是JAX中有状态计算的标准模式。由于状态显式地流经代码,所有JAX变换——jax.jit、jax.grad、jax.vmap、jax.checkpoint——都能开箱即用,无需包装器、装饰器或任何惊喜。\n\n相比之下,大多数深度学习框架依赖隐式全局状态或线程本地上下文来隐藏参数。虽然这节省了几个字符,却创造了"黑盒"——你不知道函数实际访问了什么数据,也无法完全控制执行流程。Blox通过接受稍微冗长的函数签名,换取了完全的透明度和控制力。\n\n## 为什么函数式?\n\n面向对象编程(OOP)在深度学习中的流行有其历史原因:它符合人类对"层"和"模型"的直觉认知。然而,这种直觉在分布式训练、自动微分和硬件加速的复杂场景下往往成为负担。\n\nJAX的函数式特性提供了更强大的抽象:\n- 纯函数:相同的输入总是产生相同的输出,便于推理和优化\n- 可组合变换:jit、grad、vmap等变换可以自由组合\n- 无隐藏状态:所有状态变化都是显式的,避免副作用\n- 硬件无关:代码自动在CPU/GPU/TPU上运行\n\nBlox充分利用这些特性,让开发者直接操作JAX的核心概念,而不是被框架的抽象层所隔离。\n\n## 核心组件\n\nBlox的设计极其精简,只有几个核心概念:\n\n### Graph:模型层次结构\n\nGraph定义了模型的层次命名空间。通过graph.child('name')创建子节点,为每个模块提供唯一路径用于参数命名。Graph还存储对所有创建模块的引用,提供graph.walk()用于遍历——这对于应用LoRA适配器或切换训练模式非常有用。\n\n### Module:可组合的计算单元\n\nModule在Graph中有唯一路径,提供便捷方法(get_param、set_param)来自动管理自身参数。模块只是Python对象,可以嵌套、注入或动态生成。\n\n### Params:不可变的状态容器\n\nParams是一个不可变容器,以扁平字典形式存储所有状态,键为tuple路径(如('net', 'mlp', 'hidden', 'kernel'))。使用split()方法可分离可训练参数和非训练参数(如RNG状态、批归一化统计量)。\n\n### Param:带元数据的参数包装器\n\nParam包装每个参数值,包含trainable标志和任意元数据。trainable标志决定参数是否可微,也支持分片元数据用于分布式训练。\n\n### Rng:确定性随机数生成\n\nRng模块生成确定性随机密钥。由于它用于随机初始化所有其他参数并提供运行时随机性,必须通过rng.seed(params, seed=42)首先进行种子设置。\n\n## 完整示例:构建MLP\n\n让我们通过一个多层感知机(MLP)示例来了解Blox的使用方式:\n\npython\nimport jax\nimport jax.numpy as jnp\nimport blox as bx\n\n# 定义线性层\nclass Linear(bx.Module):\n def __init__(self, graph: bx.Graph, output_size: int, rng: bx.Rng):\n super().__init__(graph)\n self.output_size = output_size\n self.rng = rng\n\n def __call__(self, params: bx.Params, x: jax.Array):\n # 惰性参数创建——无需预先指定输入形状!\n kernel, params = self.get_param(\n params,\n name='kernel',\n shape=(x.shape[-1], self.output_size),\n init=jax.nn.initializers.normal(),\n rng=self.rng,\n )\n bias, params = self.get_param(\n params,\n name='bias',\n shape=(self.output_size,),\n init=jax.nn.initializers.zeros,\n rng=self.rng,\n )\n return x @ kernel + bias, params\n\n# 定义MLP\nclass MLP(bx.Module):\n def __init__(self, graph: bx.Graph, hidden_size: int, output_size: int, rng: bx.Rng):\n super().__init__(graph)\n self.hidden = Linear(graph.child('hidden'), hidden_size, rng=rng)\n self.output = Linear(graph.child('output'), output_size, rng=rng)\n\n def __call__(self, params: bx.Params, x: jax.Array):\n x, params = self.hidden(params, x)\n x = jax.nn.relu(x)\n return self.output(params, x)\n\n\n## 初始化与JIT编译\n\nBlox将"初始化"(遍历图创建参数)与"运行时" cleanly 分离:\n\npython\n# 定义结构\ngraph = bx.Graph('net')\nrng = bx.Rng(graph.child('rng'))\nmodel = MLP(graph.child('mlp'), hidden_size=128, output_size=10, rng=rng)\n\n# 初始化参数容器和RNG状态\nparams = bx.Params()\nparams = rng.seed(params, seed=42)\n\n# 运行前向传播触发惰性参数初始化\ndummy_input = jnp.ones((1, 784))\n_, params = model(params, dummy_input)\n\n# 锁定以防止训练期间意外创建参数\nparams = params.locked()\n\n# JIT编译并调用——无需特殊装饰器\noutputs, params = jax.jit(model)(params, inputs)\n\n\n## 训练模式:参数分割与合并\n\n训练时需要区分可训练参数(权重)和非训练参数(RNG、批归一化统计量等)。Blox的split()方法实现了这一分离:\n\npython\n@jax.jit(donate_argnames='params')\ndef train_step(params, inputs, targets):\n # 分割为可训练和非训练参数\n trainable, non_trainable = params.split()\n\n def loss_fn(t, nt):\n # 合并以运行前向传播\n preds, new_params = model(t.merge(nt), inputs)\n loss = jnp.mean((preds - targets) ** 2)\n # 提取前向过程中更新的非训练参数\n _, new_nt = new_params.split()\n return loss, new_nt\n\n # 可训练参数的梯度,非训练参数的更新状态\n grads, new_non_trainable = jax.grad(loss_fn, has_aux=True)(\n trainable, non_trainable\n )\n\n # 使用SGD更新可训练参数\n new_trainable = jax.tree.map(\n lambda w, g: w - 0.01 * g, trainable, grads\n )\n\n # 合并更新后的可训练和非训练参数\n return new_trainable.merge(new_non_trainable)\n\n\n## 分布式训练支持\n\nBlox原生支持分布式训练,通过参数元数据指定分片策略:\n\npython\nfrom jax.sharding import NamedSharding, PartitionSpec as P\n\n# 定义带分片元数据的层\nlinear = bx.Linear(\n graph.child('linear'),\n output_size=4096,\n rng=rng,\n kernel_metadata={'sharding': (None, 'model')}, # 沿model轴分片权重\n bias_metadata={'sharding': ('model',)}, # 沿model轴分片偏置\n)\n\n\n## Blox vs 其他JAX库\n\n| 特性 | OOP风格包装器 | Blox |\n|------|-------------|------|\n| 函数调用 | out = layer(x) | outputs, params = layer(params, inputs) |\n| 状态管理 | 隐式全局状态 | 显式状态传递 |\n| 变量作用域 | 不透明 | 显式bx.Graph路径 |\n| JIT/并行 | 自定义包装器 | 标准jax.jit/jax.vmap |\n\n## 适用场景\n\n学习者:这里没有"框架魔法"。所见即所得,是理解JAX层面神经网络实际工作原理的最佳方式。\n\n实践者:如果你厌倦了与隐藏重要细节的框架作斗争,Blox提供完全透明。无论是构建自定义训练循环、实现新颖架构还是扩展规模,你都可以直接访问完整的执行栈。\n\n研究者:Blox的显式设计特别适合需要精细控制的研究场景,如强化学习(Actor与Learner共享权重)、元学习、神经架构搜索等。\n\n## 总结\n\nBlox代表了深度学习框架设计的一种返璞归真。它拒绝为了易用性而牺牲透明度的诱惑,选择拥抱函数式编程的强大能力。这种设计哲学可能不适合追求快速原型的初学者,但对于希望深入理解JAX、需要精细控制训练流程的研究者和工程师来说,Blox提供了一个理想的工具。\n\n在深度学习框架日益复杂、抽象层不断堆叠的今天,Blox提醒我们:有时候,简单和透明才是最高级的抽象。章节 03
Blox:拥抱函数式编程的JAX神经网络库\n\n在深度学习框架的生态中,PyTorch以其易用性占据主导地位,TensorFlow以其工业级部署能力立足,而JAX则以其函数式编程范式和强大的可组合变换能力吸引了越来越多的研究者。然而,大多数JAX神经网络库试图让JAX"看起来像PyTorch",通过面向对象的包装和隐式状态管理来降低学习曲线。这种妥协虽然降低了入门门槛,却牺牲了JAX最核心的优势。\n\nBlox选择了一条不同的道路:完全拥抱JAX的函数式本质,打造一个透明、轻量、零魔法的神经网络库。\n\n核心理念:显式优于隐式\n\nBlox的设计哲学可以用一句话概括:\n\npython\noutputs, params = model(params, inputs)\n\n\n参数进去,输出和更新后的参数出来。这就是JAX中有状态计算的标准模式。由于状态显式地流经代码,所有JAX变换——jax.jit、jax.grad、jax.vmap、jax.checkpoint——都能开箱即用,无需包装器、装饰器或任何惊喜。\n\n相比之下,大多数深度学习框架依赖隐式全局状态或线程本地上下文来隐藏参数。虽然这节省了几个字符,却创造了"黑盒"——你不知道函数实际访问了什么数据,也无法完全控制执行流程。Blox通过接受稍微冗长的函数签名,换取了完全的透明度和控制力。\n\n为什么函数式?\n\n面向对象编程(OOP)在深度学习中的流行有其历史原因:它符合人类对"层"和"模型"的直觉认知。然而,这种直觉在分布式训练、自动微分和硬件加速的复杂场景下往往成为负担。\n\nJAX的函数式特性提供了更强大的抽象:\n- 纯函数:相同的输入总是产生相同的输出,便于推理和优化\n- 可组合变换:jit、grad、vmap等变换可以自由组合\n- 无隐藏状态:所有状态变化都是显式的,避免副作用\n- 硬件无关:代码自动在CPU/GPU/TPU上运行\n\nBlox充分利用这些特性,让开发者直接操作JAX的核心概念,而不是被框架的抽象层所隔离。\n\n核心组件\n\nBlox的设计极其精简,只有几个核心概念:\n\nGraph:模型层次结构\n\nGraph定义了模型的层次命名空间。通过graph.child('name')创建子节点,为每个模块提供唯一路径用于参数命名。Graph还存储对所有创建模块的引用,提供graph.walk()用于遍历——这对于应用LoRA适配器或切换训练模式非常有用。\n\nModule:可组合的计算单元\n\nModule在Graph中有唯一路径,提供便捷方法(get_param、set_param)来自动管理自身参数。模块只是Python对象,可以嵌套、注入或动态生成。\n\nParams:不可变的状态容器\n\nParams是一个不可变容器,以扁平字典形式存储所有状态,键为tuple路径(如('net', 'mlp', 'hidden', 'kernel'))。使用split()方法可分离可训练参数和非训练参数(如RNG状态、批归一化统计量)。\n\nParam:带元数据的参数包装器\n\nParam包装每个参数值,包含trainable标志和任意元数据。trainable标志决定参数是否可微,也支持分片元数据用于分布式训练。\n\nRng:确定性随机数生成\n\nRng模块生成确定性随机密钥。由于它用于随机初始化所有其他参数并提供运行时随机性,必须通过rng.seed(params, seed=42)首先进行种子设置。\n\n完整示例:构建MLP\n\n让我们通过一个多层感知机(MLP)示例来了解Blox的使用方式:\n\npython\nimport jax\nimport jax.numpy as jnp\nimport blox as bx\n\n定义线性层\nclass Linear(bx.Module):\n def __init__(self, graph: bx.Graph, output_size: int, rng: bx.Rng):\n super().__init__(graph)\n self.output_size = output_size\n self.rng = rng\n\n def __call__(self, params: bx.Params, x: jax.Array):\n 惰性参数创建——无需预先指定输入形状!\n kernel, params = self.get_param(\n params,\n name='kernel',\n shape=(x.shape[-1], self.output_size),\n init=jax.nn.initializers.normal(),\n rng=self.rng,\n )\n bias, params = self.get_param(\n params,\n name='bias',\n shape=(self.output_size,),\n init=jax.nn.initializers.zeros,\n rng=self.rng,\n )\n return x @ kernel + bias, params\n\n定义MLP\nclass MLP(bx.Module):\n def __init__(self, graph: bx.Graph, hidden_size: int, output_size: int, rng: bx.Rng):\n super().__init__(graph)\n self.hidden = Linear(graph.child('hidden'), hidden_size, rng=rng)\n self.output = Linear(graph.child('output'), output_size, rng=rng)\n\n def __call__(self, params: bx.Params, x: jax.Array):\n x, params = self.hidden(params, x)\n x = jax.nn.relu(x)\n return self.output(params, x)\n\n\n初始化与JIT编译\n\nBlox将"初始化"(遍历图创建参数)与"运行时" cleanly 分离:\n\npython\n定义结构\ngraph = bx.Graph('net')\nrng = bx.Rng(graph.child('rng'))\nmodel = MLP(graph.child('mlp'), hidden_size=128, output_size=10, rng=rng)\n\n初始化参数容器和RNG状态\nparams = bx.Params()\nparams = rng.seed(params, seed=42)\n\n运行前向传播触发惰性参数初始化\ndummy_input = jnp.ones((1, 784))\n_, params = model(params, dummy_input)\n\n锁定以防止训练期间意外创建参数\nparams = params.locked()\n\nJIT编译并调用——无需特殊装饰器\noutputs, params = jax.jit(model)(params, inputs)\n\n\n训练模式:参数分割与合并\n\n训练时需要区分可训练参数(权重)和非训练参数(RNG、批归一化统计量等)。Blox的split()方法实现了这一分离:\n\npython\n@jax.jit(donate_argnames='params')\ndef train_step(params, inputs, targets):\n 分割为可训练和非训练参数\n trainable, non_trainable = params.split()\n\n def loss_fn(t, nt):\n 合并以运行前向传播\n preds, new_params = model(t.merge(nt), inputs)\n loss = jnp.mean((preds - targets) ** 2)\n 提取前向过程中更新的非训练参数\n _, new_nt = new_params.split()\n return loss, new_nt\n\n 可训练参数的梯度,非训练参数的更新状态\n grads, new_non_trainable = jax.grad(loss_fn, has_aux=True)(\n trainable, non_trainable\n )\n\n 使用SGD更新可训练参数\n new_trainable = jax.tree.map(\n lambda w, g: w - 0.01 * g, trainable, grads\n )\n\n 合并更新后的可训练和非训练参数\n return new_trainable.merge(new_non_trainable)\n\n\n分布式训练支持\n\nBlox原生支持分布式训练,通过参数元数据指定分片策略:\n\npython\nfrom jax.sharding import NamedSharding, PartitionSpec as P\n\n定义带分片元数据的层\nlinear = bx.Linear(\n graph.child('linear'),\n output_size=4096,\n rng=rng,\n kernel_metadata={'sharding': (None, 'model')}, 沿model轴分片权重\n bias_metadata={'sharding': ('model',)}, 沿model轴分片偏置\n)\n\n\nBlox vs 其他JAX库\n\n| 特性 | OOP风格包装器 | Blox |\n|------|-------------|------|\n| 函数调用 | out = layer(x) | outputs, params = layer(params, inputs) |\n| 状态管理 | 隐式全局状态 | 显式状态传递 |\n| 变量作用域 | 不透明 | 显式bx.Graph路径 |\n| JIT/并行 | 自定义包装器 | 标准jax.jit/jax.vmap |\n\n适用场景\n\n学习者:这里没有"框架魔法"。所见即所得,是理解JAX层面神经网络实际工作原理的最佳方式。\n\n实践者:如果你厌倦了与隐藏重要细节的框架作斗争,Blox提供完全透明。无论是构建自定义训练循环、实现新颖架构还是扩展规模,你都可以直接访问完整的执行栈。\n\n研究者:Blox的显式设计特别适合需要精细控制的研究场景,如强化学习(Actor与Learner共享权重)、元学习、神经架构搜索等。\n\n总结\n\nBlox代表了深度学习框架设计的一种返璞归真。它拒绝为了易用性而牺牲透明度的诱惑,选择拥抱函数式编程的强大能力。这种设计哲学可能不适合追求快速原型的初学者,但对于希望深入理解JAX、需要精细控制训练流程的研究者和工程师来说,Blox提供了一个理想的工具。\n\n在深度学习框架日益复杂、抽象层不断堆叠的今天,Blox提醒我们:有时候,简单和透明才是最高级的抽象。