theano.scan 起什么作用
theano.scan
是 Theano 库中一个极其强大和核心的函数,它的作用是实现循环(looping)或递归(recursion)操作,并将其无缝地集成到符号计算图中。
在深度学习中,许多模型都包含重复性的结构,例如:
- RNN/LSTM/GRU:在每个时间步
t
重复执行相同的更新规则。 - 序列生成:从模型中一步步采样生成一个序列。
- 注意力机制:对输入序列的每个元素进行相似的计算。
theano.scan
就是为了优雅、高效地处理这类问题而设计的。它允许你定义一个“扫描体”(scan body),然后让Theano自动处理循环的展开、内存管理、梯度计算等复杂细节。
theano.scan
的核心作用
- 将循环编译为计算图:你可以用Python写一个函数来描述单个时间步的操作(比如RNN的一个step),然后通过
scan
告诉Theano:“请把这个函数在序列上运行N次”。Theano会将这个循环展开成一个巨大的、静态的计算图,并对其进行优化。 - 自动微分:由于整个循环被编译成了一个计算图,Theano可以像对待普通函数一样,使用反向传播(backpropagation)自动计算整个循环过程的梯度。这对于训练RNN至关重要。
- GPU加速:编译后的计算图可以被转换为高效的C或CUDA代码,在GPU上并行执行,极大地提升了性能。
- 处理变长序列:
scan
可以轻松处理不同长度的输入序列。
theano.scan
的基本语法
sequences,
outputs_info,
non_sequences,
n_steps,
...)
让我们通过一个构建简单RNN的例子来解释这些参数:
示例:用 theano.scan
实现一个简单的RNN
import theano
import theano.tensor as T
import numpy as np# 1. 定义符号变量
x = T.matrix('x') # 输入序列: (T x D) 矩阵,T是时间步数,D是维度
h0 = T.vector('h0') # 初始隐藏状态
W_hh = T.matrix('W_hh') # 隐藏层到隐藏层的权重
W_xh = T.matrix('W_xh') # 输入层到隐藏层的权重
b_h = T.vector('b_h') # 隐藏层偏置# 2. 定义 "扫描体" 函数 (Scan Body)
# 这个函数定义了在单个时间步 t 上要执行的操作
def rnn_step(x_t, h_tm1, W_hh, W_xh, b_h):"""参数:x_t: 当前时间步的输入 (D维向量)h_tm1: 上一时刻的隐藏状态 (H维向量),"tm1" 表示 t minus 1W_hh, W_xh, b_h: 非时序的共享参数返回:h_t: 当前时刻的隐藏状态 (H维向量)"""h_t = T.tanh(T.dot(h_tm1, W_hh) + T.dot(x_t, W_xh) + b_h)return h_t# 3. 调用 theano.scan
# 我们需要告诉 scan:
# - 扫描体是什么 (fn=rnn_step)
# - 循环的序列是什么 (sequences=[x]) -> x 会被拆分成 x[0], x[1], ...
# - 初始状态是什么 (outputs_info=[h0])
# - 其他不随时间变化的参数 (non_sequences=[W_hh, W_xh, b_h])
# - 循环次数 (n_steps=x.shape[0],即输入序列的长度)results, _ = theano.scan(fn=rnn_step, # 要重复执行的函数sequences=[x], # 沿时间轴变化的输入outputs_info=[h0], # 输出的初始值 (这里是隐藏状态)non_sequences=[W_hh, W_xh, b_h], # 不随时间变化的参数n_steps=x.shape[0] # 循环的次数
)# 4. 结果
# 'results' 是一个张量,包含了所有时间步的输出。
# 在这个例子中,results 的形状是 (T x H),其中每一行 results[t] 就是 rnn_step 在时间步 t 的返回值 h_t。
all_hidden_states = results # (T x H) 矩阵# 5. 创建可调用的函数
# 现在我们可以编译一个完整的函数
compute_rnn = theano.function(inputs=[x, h0, W_hh, W_xh, b_h],outputs=all_hidden_states
)# 6. 使用 (示例数据)
x_data = np.random.randn(5, 10).astype(theano.config.floatX) # 5个时间步,每步10维
h0_data = np.zeros(20, dtype=theano.config.floatX) # 20维的初始隐藏状态
W_hh_data = np.random.randn(20, 20).astype(theano.config.floatX)
W_xh_data = np.random.randn(10, 20).astype(theano.config.floatX)
b_h_data = np.zeros(20, dtype=theano.config.floatX)# 调用函数,得到所有时间步的隐藏状态
hidden_states = compute_rnn(x_data, h0_data, W_hh_data, W_xh_data, b_h_data)
print(hidden_states.shape) # 输出: (5, 20)
关键参数详解
参数 | 描述 |
---|---|
fn | 必需。这是“扫描体”函数,定义了单个时间步的逻辑。 |
sequences | 可选。一个列表,包含沿时间轴变化的输入序列。scan 会在每个时间步从这些序列中取出一个切片作为 fn 的输入。 |
outputs_info | 必需。一个列表,指定了 fn 的输出应该是什么样子的。通常用于提供初始状态(如 h0 )。如果 fn 有多个输出,这里就需要多个初始值。 |
non_sequences | 可选。一个列表,包含那些在所有时间步都保持不变的额外参数(如模型权重 W , b )。它们会被原封不动地传递给 fn 。 |
n_steps | 必需。指定循环要执行多少次。 |