贝博恩创新科技网

sonnet教程 deepmind

Sonnet 是 DeepMind 为其旗舰深度学习框架 JAX 设计的高级神经网络库,如果你了解 TensorFlow,可以把 Sonnet 类比为 TensorFlow 的 Keras API,但它与 JAX 的函数式编程范式结合得更加紧密和优雅。

sonnet教程 deepmind-图1
(图片来源网络,侵删)

教程目录

  1. 第一部分:Sonnet 是什么?为什么用它?
    • Sonnet 的定位与特点
    • Sonnet vs. Keras/Flax
  2. 第二部分:前置知识 - JAX 基础
    • jax.numpy (JAX 的 NumPy)
    • jax.jit (即时编译)
    • jax.grad (自动微分)
  3. 第三部分:Sonnet 核心组件详解
    • snt.Module: 所有模块的基类
    • snt.Linear: 全连接层
    • snt.Conv2D: 2D 卷积层
    • snt.BatchNorm: 批归一化
    • snt.RNN/LSTM/GRU: 循环神经网络
    • snt.Sequential: 顺序容器
  4. 第四部分:实战演练 - 构建一个简单的 MLP
    • 完整代码示例
    • 代码逐行解析
  5. 第五部分:实战演练 - 构建一个 CNN 图像分类器
    • 完整代码示例
    • 代码逐行解析
  6. 第六部分:进阶主题
    • 参数初始化
    • 自定义 Module
    • 训练循环与 optax 集成
  7. 第七部分:总结与资源

第一部分:Sonnet 是什么?为什么用它?

Sonnet 的定位与特点

  • 高级 API: Sonnet 提供了简洁、易于使用的 API 来构建复杂的神经网络模型,让你可以专注于模型结构,而不是底层的 JAX 操作。
  • JAX 原生: Sonnet 完全基于 JAX 构建,这意味着你可以无缝使用 JAX 的所有强大功能,如 jitgradpmapvmap 等,而无需任何额外配置。
  • 函数式与面向对象结合: Sonnet 采用面向对象的方式组织模型(通过继承 snt.Module),但在前向传播时,它鼓励使用函数式调用(module(inputs)),这使得模型状态(参数)和数据流分离,非常清晰。
  • 为研究而生: DeepMind 开发 Sonnet 的初衷就是为了支持其前沿的 AI 研究,它设计得非常灵活,易于扩展和实验新想法。

Sonnet vs. Keras/Flax

特性 Sonnet Keras (TensorFlow) Flax
底层框架 JAX TensorFlow JAX
编程范式 OO + Functional OO + Functional Functional (一等公民)
参数处理 明确的 variables (params, state, batch_stats) 隐式集成在模型中 明确的 paramsstate
易用性 对 JAX 用户友好 对初学者最友好 对函数式编程爱好者友好
核心优势 与 JAX 生态无缝集成,灵活 生态最庞大,文档最全 纯函数式设计,状态管理清晰

简单来说:如果你想在 JAX 的世界里快速搭建模型,Sonnet 是一个非常优雅且强大的选择。


第二部分:前置知识 - JAX 基础

在深入 Sonnet 之前,你必须理解 JAX 的三个核心概念。

jax.numpy (JAX 的 NumPy)

它和 NumPy 的 API 几乎完全一样,但它的操作是在“函数式”的转换器上执行的,这意味着它不会改变输入,而是返回一个新结果。

import jax.numpy as jnp
# JAX 的数组是不可变的
x = jnp.array([1, 2, 3])
y = x + 1 # x 不会改变
print(x) # 输出: [1 2 3]
print(y) # 输出: [2 3 4]

jax.jit (即时编译 - Just-In-Time Compilation)

这是 JAX 性能的关键。jit 会将你的 Python 函数编译成高效的 XLA 执行代码,并在首次运行时缓存,后续调用速度极快。

sonnet教程 deepmind-图2
(图片来源网络,侵删)
from jax import jit
# 一个普通的 Python 函数
def slow_fn(x):
    return x + x.T @ x
# 编译它!
fast_fn = jit(slow_fn)
# 首次运行会稍慢,因为需要编译
# 后续运行会非常快
x = jnp.ones((1000, 1000))
fast_fn(x)

jax.grad (自动微分)

grad 可以自动计算一个标量函数对其中一个输入的梯度,这是所有神经网络训练的基础。

from jax import grad
# 一个简单的二次函数
def f(x):
    return x ** 2
# 计算 f 对 x 的导数
df_dx = grad(f)
# 在 x = 3.0 处求导
print(df_dx(3.0)) # 输出: 6.0 (因为 2 * 3.0 = 6.0)

第三部分:Sonnet 核心组件详解

snt.Module: 所有模块的基类

这是 Sonnet 的核心,你创建的任何自定义网络层或完整模型都应该继承 snt.Module

  • *`call(self, args, kwargs)`: 定义模型的前向传播逻辑,当你像 model(inputs) 一样调用模块时,这个方法会被执行。
  • variables: 一个重要的属性,包含了模块的所有可训练参数、状态(如 RNN 的隐藏状态)和批量统计信息(如 BatchNorm 的均值和方差)。

snt.Linear: 全连接层

最简单的层之一。

import sonnet as snt
# 创建一个输入维度 10,输出维度 5 的全连接层
linear = snt.Linear(output_size=5)
# 假设输入数据是 (batch_size, input_dim)
inputs = jnp.ones((32, 10)) # batch_size=32
# 前向传播
outputs = linear(inputs)
print(outputs.shape) # 输出: (32, 5)
# 查看参数
print(linear.variables) # 包含 'kernel' (权重) 和 'bias' (偏置)

snt.Conv2D: 2D 卷积层

用于处理图像数据。

sonnet教程 deepmind-图3
(图片来源网络,侵删)
# 创建一个 3x3 卷积核,输入通道 3,输出通道 64
conv = snt.Conv2D(output_channels=64, kernel_shape=3)
# 输入数据通常是 (batch_size, height, width, channels)
images = jnp.ones((16, 224, 224, 3)) # batch_size=16
features = conv(images)
print(features.shape) # 输出: (16, 222, 222, 64)

snt.BatchNorm: 批量归一化

# 创建一个批归一化层,适用于数据格式 NCHW
# Sonnet 也支持 NHWC 格式
batch_norm = snt.BatchNorm(axis=(1, 2, 3)) # 对 H, W, C 通道做归一化
# 第一次调用时,它会计算全局统计量
normalized_features = batch_norm(features)
print(normalized_features.shape) # 输出: (16, 222, 222, 64)
# 查看批量统计信息
print(batch_norm.variables['batch_stats'])

snt.RNN/LSTM/GRU: 循环神经网络

# 创建一个 LSTM 单元
lstm = snt.LSTM(hidden_size=128)
# 初始隐藏状态
batch_size = 32
initial_state = lstm.initial_state(batch_size)
# 输入数据 (sequence_length, batch_size, input_size)
inputs = jnp.ones((10, 32, 64)) # 10个时间步
# 前向传播
# lstm 会返回 (output_sequence, final_state)
output_sequence, final_state = lstm(inputs, initial_state)
print(output_sequence.shape) # 输出: (10, 32, 128)
print(final_state.h.shape)    # 输出: (32, 128)

snt.Sequential: 顺序容器

当你想将多个层按顺序堆叠时,Sequential 非常有用。

mlp = snt.Sequential([
    snt.Linear(256),
    jax.nn.relu,
    snt.Linear(128),
    jax.nn.relu,
    snt.Linear(10) # 10个类别的输出
])
inputs = jnp.ones((32, 784)) # MNIST 数据
outputs = mlp(inputs)
print(outputs.shape) # 输出: (32, 10)

第四部分:实战演练 - 构建一个简单的 MLP

我们将构建一个多层感知机来分类手写数字(类似 MNIST 任务)。

import jax
import jax.numpy as jnp
import sonnet as snt
import optax  # JAX 的优化器库
# 1. 定义模型
class MLP(snt.Module):
    """一个简单的多层感知机"""
    def __init__(self, hidden_sizes, output_size):
        super().__init__()
        self.layers = snt.Sequential([
            snt.Linear(hidden_sizes[0]),
            jax.nn.relu,
            snt.Linear(hidden_sizes[1]),
            jax.nn.relu,
            snt.Linear(output_size)
        ])
    def __call__(self, x):
        # 展平输入图像 (batch_size, 28, 28) -> (batch_size, 784)
        x = x.reshape((x.shape[0], -1))
        return self.layers(x)
# 2. 初始化模型和优化器
model = MLP(hidden_sizes=[128, 64], output_size=10)
optimizer = optax.adam(learning_rate=1e-3)
# 初始化模型参数
dummy_input = jnp.ones((1, 28, 28, 1)) # 模拟一个 batch 的输入
params = model.init(jax.random.PRNGKey(0), dummy_input)['params']
# 3. 定义损失函数和训练步骤
@jax.jit
def loss_fn(params, batch):
    images, labels = batch
    logits = model.apply({'params': params}, images)
    # 使用 optax 的交叉熵损失
    loss = optax.softmax_cross_entropy_with_integer_logits(
        logits=logits, labels=labels
    )
    return jnp.mean(loss)
# 使用 `jax.value_and_grad` 一次性计算损失和梯度
@jax.jit
def train_step(params, opt_state, batch):
    loss, grads = jax.value_and_grad(loss_fn)(params, batch)
    # 应用梯度更新参数
    updates, new_opt_state = optimizer.update(grads, opt_state)
    new_params = optax.apply_updates(params, updates)
    return new_params, new_opt_state, loss
# 4. 模拟训练循环
opt_state = optimizer.init(params)
# 假设我们有一些数据
train_data = (jnp.ones((64, 28, 28, 1)), jnp.ones((64,), dtype=jnp.int32)) # batch_size=64
# 训练一个步骤
new_params, new_opt_state, loss = train_step(params, opt_state, train_data)
print(f"Initial Loss: {loss}")

第五部分:实战演练 - 构建一个 CNN 图像分类器

下面是一个更完整的 CNN 示例,用于 CIFAR-10 数据集分类。

import jax
import jax.numpy as jnp
import sonnet as snt
import optax
# 1. 定义 CNN 模型
class CNN(snt.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        self.conv_block1 = snt.Sequential([
            snt.Conv2D(32, kernel_shape=3, padding='SAME'),
            jax.nn.relu,
            snt.MaxPool2D(pool_size=2, strides=2, padding='SAME')
        ])
        self.conv_block2 = snt.Sequential([
            snt.Conv2D(64, kernel_shape=3, padding='SAME'),
            jax.nn.relu,
            snt.MaxPool2D(pool_size=2, strides=2, padding='SAME')
        ])
        self.classifier = snt.Sequential([
            snt.Flatten(),
            snt.Linear(256),
            jax.nn.relu,
            snt.Linear(num_classes)
        ])
    def __call__(self, x):
        x = self.conv_block1(x)
        x = self.conv_block2(x)
        x = self.classifier(x)
        return x
# 2. 初始化
model = CNN(num_classes=10)
dummy_input = jnp.ones((1, 32, 32, 3)) # CIFAR-10 图像尺寸
params = model.init(jax.random.PRNGKey(42), dummy_input)['params']
# 3. 损失和训练步骤 (与 MLP 类似)
@jax.jit
def loss_fn(params, batch):
    images, labels = batch
    logits = model.apply({'params': params}, images)
    loss = optax.softmax_cross_entropy_with_integer_logits(
        logits=logits, labels=labels
    )
    accuracy = jnp.mean(jnp.argmax(logits, axis=-1) == labels)
    return loss, accuracy
@jax.jit
def train_step(params, opt_state, batch):
    (loss, accuracy), grads = jax.value_and_grad(loss_fn, has_aux=True)(params, batch)
    updates, new_opt_state = optimizer.update(grads, opt_state)
    new_params = optax.apply_updates(params, updates)
    return new_params, new_opt_state, loss, accuracy
# 4. 优化器
optimizer = optax.adamw(learning_rate=1e-3, weight_decay=1e-4)
opt_state = optimizer.init(params)
# 5. 模拟训练
train_batch = (jnp.ones((32, 32, 32, 3)), jnp.ones((32,), dtype=jnp.int32))
for epoch in range(5):
    new_params, new_opt_state, loss, accuracy = train_step(params, opt_state, train_batch)
    params = new_params
    opt_state = new_opt_state
    print(f"Epoch {epoch+1}, Loss: {loss:.4f}, Accuracy: {accuracy:.4f}")

第六部分:进阶主题

参数初始化

Sonnet 允许你为每个层指定自定义的初始化器。

import jax.nn.initializers as init
# 创建一个线性层,并指定权重和偏置的初始化方法
linear = snt.Linear(
    output_size=5,
    w_init=init.Glorot(),    # Glorot/Xavier 初始化
    b_init=init.Zeros()      # 零初始化
)

自定义 Module

创建你自己的复杂模块非常简单。

class MyAttentionBlock(snt.Module):
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.attention = snt.MultiHeadAttention(num_heads=num_heads, key_size=embed_dim)
        self.layer_norm1 = snt.LayerNorm(axis=-1)
        self.mlp = snt.Sequential([
            snt.Linear(embed_dim * 4),
            jax.nn.relu,
            snt.Linear(embed_dim)
        ])
        self.layer_norm2 = snt.LayerNorm(axis=-1)
    def __call__(self, x):
        # Self-attention
        attn_output = self.attention(q=x, k=x, v=x)
        x = self.layer_norm1(x + attn_output) # 残差连接
        # MLP
        mlp_output = self.mlp(x)
        x = self.layer_norm2(x + mlp_output) # 残差连接
        return x

训练循环与 optax 集成

现代 JAX/Sonnet 训练流程通常使用 optax 来管理优化器状态和梯度更新,如上面的示例所示,这种模式非常清晰和高效。


第七部分:总结与资源

  • Sonnet 是 JAX 的 Keras: 它为 JAX 提供了一个高级、易用的神经网络 API。
  • 核心是 snt.Module: 通过继承它来构建你的模型,并实现 __call__ 方法。
  • 拥抱 JAX 生态: 记得总是用 @jax.jit 来加速你的计算,用 jax.grad 来计算梯度。
  • optax 携手: Sonnet 和 optax 是 JAX 生态中的“黄金搭档”,一起用于构建和训练模型。

官方资源

希望这份详细的教程能帮助你快速上手 DeepMind 的 Sonnet!

分享:
扫描分享到社交APP
上一篇
下一篇