Sonnet 是 DeepMind 为其旗舰深度学习框架 JAX 设计的高级神经网络库,如果你了解 TensorFlow,可以把 Sonnet 类比为 TensorFlow 的 Keras API,但它与 JAX 的函数式编程范式结合得更加紧密和优雅。

教程目录
- 第一部分:Sonnet 是什么?为什么用它?
- Sonnet 的定位与特点
- Sonnet vs. Keras/Flax
- 第二部分:前置知识 - JAX 基础
jax.numpy(JAX 的 NumPy)jax.jit(即时编译)jax.grad(自动微分)
- 第三部分:Sonnet 核心组件详解
snt.Module: 所有模块的基类snt.Linear: 全连接层snt.Conv2D: 2D 卷积层snt.BatchNorm: 批归一化snt.RNN/LSTM/GRU: 循环神经网络snt.Sequential: 顺序容器
- 第四部分:实战演练 - 构建一个简单的 MLP
- 完整代码示例
- 代码逐行解析
- 第五部分:实战演练 - 构建一个 CNN 图像分类器
- 完整代码示例
- 代码逐行解析
- 第六部分:进阶主题
- 参数初始化
- 自定义 Module
- 训练循环与
optax集成
- 第七部分:总结与资源
第一部分:Sonnet 是什么?为什么用它?
Sonnet 的定位与特点
- 高级 API: Sonnet 提供了简洁、易于使用的 API 来构建复杂的神经网络模型,让你可以专注于模型结构,而不是底层的 JAX 操作。
- JAX 原生: Sonnet 完全基于 JAX 构建,这意味着你可以无缝使用 JAX 的所有强大功能,如
jit、grad、pmap、vmap等,而无需任何额外配置。 - 函数式与面向对象结合: Sonnet 采用面向对象的方式组织模型(通过继承
snt.Module),但在前向传播时,它鼓励使用函数式调用(module(inputs)),这使得模型状态(参数)和数据流分离,非常清晰。 - 为研究而生: DeepMind 开发 Sonnet 的初衷就是为了支持其前沿的 AI 研究,它设计得非常灵活,易于扩展和实验新想法。
Sonnet vs. Keras/Flax
| 特性 | Sonnet | Keras (TensorFlow) | Flax |
|---|---|---|---|
| 底层框架 | JAX | TensorFlow | JAX |
| 编程范式 | OO + Functional | OO + Functional | Functional (一等公民) |
| 参数处理 | 明确的 variables (params, state, batch_stats) |
隐式集成在模型中 | 明确的 params 和 state |
| 易用性 | 对 JAX 用户友好 | 对初学者最友好 | 对函数式编程爱好者友好 |
| 核心优势 | 与 JAX 生态无缝集成,灵活 | 生态最庞大,文档最全 | 纯函数式设计,状态管理清晰 |
简单来说:如果你想在 JAX 的世界里快速搭建模型,Sonnet 是一个非常优雅且强大的选择。
第二部分:前置知识 - JAX 基础
在深入 Sonnet 之前,你必须理解 JAX 的三个核心概念。
jax.numpy (JAX 的 NumPy)
它和 NumPy 的 API 几乎完全一样,但它的操作是在“函数式”的转换器上执行的,这意味着它不会改变输入,而是返回一个新结果。
import jax.numpy as jnp # JAX 的数组是不可变的 x = jnp.array([1, 2, 3]) y = x + 1 # x 不会改变 print(x) # 输出: [1 2 3] print(y) # 输出: [2 3 4]
jax.jit (即时编译 - Just-In-Time Compilation)
这是 JAX 性能的关键。jit 会将你的 Python 函数编译成高效的 XLA 执行代码,并在首次运行时缓存,后续调用速度极快。

from jax import jit
# 一个普通的 Python 函数
def slow_fn(x):
return x + x.T @ x
# 编译它!
fast_fn = jit(slow_fn)
# 首次运行会稍慢,因为需要编译
# 后续运行会非常快
x = jnp.ones((1000, 1000))
fast_fn(x)
jax.grad (自动微分)
grad 可以自动计算一个标量函数对其中一个输入的梯度,这是所有神经网络训练的基础。
from jax import grad
# 一个简单的二次函数
def f(x):
return x ** 2
# 计算 f 对 x 的导数
df_dx = grad(f)
# 在 x = 3.0 处求导
print(df_dx(3.0)) # 输出: 6.0 (因为 2 * 3.0 = 6.0)
第三部分:Sonnet 核心组件详解
snt.Module: 所有模块的基类
这是 Sonnet 的核心,你创建的任何自定义网络层或完整模型都应该继承 snt.Module。
- *`call(self, args, kwargs)`: 定义模型的前向传播逻辑,当你像
model(inputs)一样调用模块时,这个方法会被执行。 variables: 一个重要的属性,包含了模块的所有可训练参数、状态(如 RNN 的隐藏状态)和批量统计信息(如 BatchNorm 的均值和方差)。
snt.Linear: 全连接层
最简单的层之一。
import sonnet as snt # 创建一个输入维度 10,输出维度 5 的全连接层 linear = snt.Linear(output_size=5) # 假设输入数据是 (batch_size, input_dim) inputs = jnp.ones((32, 10)) # batch_size=32 # 前向传播 outputs = linear(inputs) print(outputs.shape) # 输出: (32, 5) # 查看参数 print(linear.variables) # 包含 'kernel' (权重) 和 'bias' (偏置)
snt.Conv2D: 2D 卷积层
用于处理图像数据。

# 创建一个 3x3 卷积核,输入通道 3,输出通道 64 conv = snt.Conv2D(output_channels=64, kernel_shape=3) # 输入数据通常是 (batch_size, height, width, channels) images = jnp.ones((16, 224, 224, 3)) # batch_size=16 features = conv(images) print(features.shape) # 输出: (16, 222, 222, 64)
snt.BatchNorm: 批量归一化
# 创建一个批归一化层,适用于数据格式 NCHW # Sonnet 也支持 NHWC 格式 batch_norm = snt.BatchNorm(axis=(1, 2, 3)) # 对 H, W, C 通道做归一化 # 第一次调用时,它会计算全局统计量 normalized_features = batch_norm(features) print(normalized_features.shape) # 输出: (16, 222, 222, 64) # 查看批量统计信息 print(batch_norm.variables['batch_stats'])
snt.RNN/LSTM/GRU: 循环神经网络
# 创建一个 LSTM 单元 lstm = snt.LSTM(hidden_size=128) # 初始隐藏状态 batch_size = 32 initial_state = lstm.initial_state(batch_size) # 输入数据 (sequence_length, batch_size, input_size) inputs = jnp.ones((10, 32, 64)) # 10个时间步 # 前向传播 # lstm 会返回 (output_sequence, final_state) output_sequence, final_state = lstm(inputs, initial_state) print(output_sequence.shape) # 输出: (10, 32, 128) print(final_state.h.shape) # 输出: (32, 128)
snt.Sequential: 顺序容器
当你想将多个层按顺序堆叠时,Sequential 非常有用。
mlp = snt.Sequential([
snt.Linear(256),
jax.nn.relu,
snt.Linear(128),
jax.nn.relu,
snt.Linear(10) # 10个类别的输出
])
inputs = jnp.ones((32, 784)) # MNIST 数据
outputs = mlp(inputs)
print(outputs.shape) # 输出: (32, 10)
第四部分:实战演练 - 构建一个简单的 MLP
我们将构建一个多层感知机来分类手写数字(类似 MNIST 任务)。
import jax
import jax.numpy as jnp
import sonnet as snt
import optax # JAX 的优化器库
# 1. 定义模型
class MLP(snt.Module):
"""一个简单的多层感知机"""
def __init__(self, hidden_sizes, output_size):
super().__init__()
self.layers = snt.Sequential([
snt.Linear(hidden_sizes[0]),
jax.nn.relu,
snt.Linear(hidden_sizes[1]),
jax.nn.relu,
snt.Linear(output_size)
])
def __call__(self, x):
# 展平输入图像 (batch_size, 28, 28) -> (batch_size, 784)
x = x.reshape((x.shape[0], -1))
return self.layers(x)
# 2. 初始化模型和优化器
model = MLP(hidden_sizes=[128, 64], output_size=10)
optimizer = optax.adam(learning_rate=1e-3)
# 初始化模型参数
dummy_input = jnp.ones((1, 28, 28, 1)) # 模拟一个 batch 的输入
params = model.init(jax.random.PRNGKey(0), dummy_input)['params']
# 3. 定义损失函数和训练步骤
@jax.jit
def loss_fn(params, batch):
images, labels = batch
logits = model.apply({'params': params}, images)
# 使用 optax 的交叉熵损失
loss = optax.softmax_cross_entropy_with_integer_logits(
logits=logits, labels=labels
)
return jnp.mean(loss)
# 使用 `jax.value_and_grad` 一次性计算损失和梯度
@jax.jit
def train_step(params, opt_state, batch):
loss, grads = jax.value_and_grad(loss_fn)(params, batch)
# 应用梯度更新参数
updates, new_opt_state = optimizer.update(grads, opt_state)
new_params = optax.apply_updates(params, updates)
return new_params, new_opt_state, loss
# 4. 模拟训练循环
opt_state = optimizer.init(params)
# 假设我们有一些数据
train_data = (jnp.ones((64, 28, 28, 1)), jnp.ones((64,), dtype=jnp.int32)) # batch_size=64
# 训练一个步骤
new_params, new_opt_state, loss = train_step(params, opt_state, train_data)
print(f"Initial Loss: {loss}")
第五部分:实战演练 - 构建一个 CNN 图像分类器
下面是一个更完整的 CNN 示例,用于 CIFAR-10 数据集分类。
import jax
import jax.numpy as jnp
import sonnet as snt
import optax
# 1. 定义 CNN 模型
class CNN(snt.Module):
def __init__(self, num_classes=10):
super().__init__()
self.conv_block1 = snt.Sequential([
snt.Conv2D(32, kernel_shape=3, padding='SAME'),
jax.nn.relu,
snt.MaxPool2D(pool_size=2, strides=2, padding='SAME')
])
self.conv_block2 = snt.Sequential([
snt.Conv2D(64, kernel_shape=3, padding='SAME'),
jax.nn.relu,
snt.MaxPool2D(pool_size=2, strides=2, padding='SAME')
])
self.classifier = snt.Sequential([
snt.Flatten(),
snt.Linear(256),
jax.nn.relu,
snt.Linear(num_classes)
])
def __call__(self, x):
x = self.conv_block1(x)
x = self.conv_block2(x)
x = self.classifier(x)
return x
# 2. 初始化
model = CNN(num_classes=10)
dummy_input = jnp.ones((1, 32, 32, 3)) # CIFAR-10 图像尺寸
params = model.init(jax.random.PRNGKey(42), dummy_input)['params']
# 3. 损失和训练步骤 (与 MLP 类似)
@jax.jit
def loss_fn(params, batch):
images, labels = batch
logits = model.apply({'params': params}, images)
loss = optax.softmax_cross_entropy_with_integer_logits(
logits=logits, labels=labels
)
accuracy = jnp.mean(jnp.argmax(logits, axis=-1) == labels)
return loss, accuracy
@jax.jit
def train_step(params, opt_state, batch):
(loss, accuracy), grads = jax.value_and_grad(loss_fn, has_aux=True)(params, batch)
updates, new_opt_state = optimizer.update(grads, opt_state)
new_params = optax.apply_updates(params, updates)
return new_params, new_opt_state, loss, accuracy
# 4. 优化器
optimizer = optax.adamw(learning_rate=1e-3, weight_decay=1e-4)
opt_state = optimizer.init(params)
# 5. 模拟训练
train_batch = (jnp.ones((32, 32, 32, 3)), jnp.ones((32,), dtype=jnp.int32))
for epoch in range(5):
new_params, new_opt_state, loss, accuracy = train_step(params, opt_state, train_batch)
params = new_params
opt_state = new_opt_state
print(f"Epoch {epoch+1}, Loss: {loss:.4f}, Accuracy: {accuracy:.4f}")
第六部分:进阶主题
参数初始化
Sonnet 允许你为每个层指定自定义的初始化器。
import jax.nn.initializers as init
# 创建一个线性层,并指定权重和偏置的初始化方法
linear = snt.Linear(
output_size=5,
w_init=init.Glorot(), # Glorot/Xavier 初始化
b_init=init.Zeros() # 零初始化
)
自定义 Module
创建你自己的复杂模块非常简单。
class MyAttentionBlock(snt.Module):
def __init__(self, embed_dim, num_heads):
super().__init__()
self.attention = snt.MultiHeadAttention(num_heads=num_heads, key_size=embed_dim)
self.layer_norm1 = snt.LayerNorm(axis=-1)
self.mlp = snt.Sequential([
snt.Linear(embed_dim * 4),
jax.nn.relu,
snt.Linear(embed_dim)
])
self.layer_norm2 = snt.LayerNorm(axis=-1)
def __call__(self, x):
# Self-attention
attn_output = self.attention(q=x, k=x, v=x)
x = self.layer_norm1(x + attn_output) # 残差连接
# MLP
mlp_output = self.mlp(x)
x = self.layer_norm2(x + mlp_output) # 残差连接
return x
训练循环与 optax 集成
现代 JAX/Sonnet 训练流程通常使用 optax 来管理优化器状态和梯度更新,如上面的示例所示,这种模式非常清晰和高效。
第七部分:总结与资源
- Sonnet 是 JAX 的 Keras: 它为 JAX 提供了一个高级、易用的神经网络 API。
- 核心是
snt.Module: 通过继承它来构建你的模型,并实现__call__方法。 - 拥抱 JAX 生态: 记得总是用
@jax.jit来加速你的计算,用jax.grad来计算梯度。 - 与
optax携手: Sonnet 和optax是 JAX 生态中的“黄金搭档”,一起用于构建和训练模型。
官方资源
- Sonnet GitHub 仓库: https://github.com/deepmind/sonnet (包含所有源代码和示例)
- Sonnet API 文档: https://sonnet.readthedocs.io/ (查找所有模块和函数的详细说明)
- JAX 官方教程: https://jax.readthedocs.io/en/latest/jax-101/ (在学 Sonnet 之前,强烈建议先通读 JAX 101)
希望这份详细的教程能帮助你快速上手 DeepMind 的 Sonnet!
