%load ../utils/djl-imports
%load ../utils/plot-utils
%load ../utils/Functions.java
%load ../utils/PlotUtils.java
%load ../utils/StopWatch.java
%load ../utils/Accumulator.java
%load ../utils/Animator.java
%load ../utils/Training.java
%load ../utils/timemachine/Vocab.java
%load ../utils/timemachine/RNNModel.java
%load ../utils/timemachine/RNNModelScratch.java
%load ../utils/timemachine/TimeMachine.java
%load ../utils/timemachine/TimeMachineDataset.java
// Top-level notebook state shared by all later cells: the base NDManager,
// the minibatch configuration, and the Time Machine dataset/vocabulary.
NDManager manager = NDManager.newBaseManager();
// batchSize: sequences per minibatch; numSteps: tokens per subsequence.
int batchSize = 32;
int numSteps = 35;
// Build the character-level Time Machine dataset, capped at 10000 tokens.
// setSampling(batchSize, false) selects sequential (non-random) partitioning.
TimeMachineDataset dataset =
new TimeMachineDataset.Builder()
.setManager(manager)
.setMaxTokens(10000)
.setSampling(batchSize, false)
.setSteps(numSteps)
.build();
// Materialize the data before use; also populates the vocabulary.
dataset.prepare();
Vocab vocab = dataset.getVocab();
9.1.2.1. 初始化模型参数
下一步是初始化模型参数。
我们从标准差为\(0.01\)的高斯分布中提取权重,
并将偏置项设为\(0\),超参数numHiddens定义隐藏单元的数量,
实例化与更新门、重置门、候选隐状态和输出层相关的所有权重和偏置。
public static NDArray normal(Shape shape, Device device) {
return manager.randomNormal(0, 0.01f, shape, DataType.FLOAT32, device);
public static NDList three(int numInputs
, int numHiddens, Device device) {
return new NDList(
normal(new Shape(numInputs, numHiddens), device),
normal(new Shape(numHiddens, numHiddens), device),
manager.zeros(new Shape(numHiddens), DataType.FLOAT32, device));
public static NDList getParams(int vocabSize, int numHiddens, Device device) {
int numInputs = vocabSize;
int numOutputs = vocabSize;
// Update gate parameters
NDList temp = three(numInputs, numHiddens, device);
NDArray W_xz = temp.get(0);
NDArray W_hz = temp.get(1);
NDArray b_z = temp.get(2);
// Reset gate parameters
temp = three(numInputs, numHiddens, device);
NDArray W_xr = temp.get(0);
NDArray W_hr = temp.get(1);
NDArray b_r = temp.get(2);
// Candidate hidden state parameters
temp = three(numInputs, numHiddens, device);
NDArray W_xh = temp.get(0);
NDArray W_hh = temp.get(1);
NDArray b_h = temp.get(2);
// Output layer parameters
NDArray W_hq = normal(new Shape(numHiddens, numOutputs), device);
NDArray b_q = manager.zeros(new Shape(numOutputs), DataType.FLOAT32, device);
// Attach gradients
NDList params = new NDList(W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q);
for (NDArray param : params) {
param.setRequiresGradient(true);
return params;
现在我们将定义隐状态的初始化函数 initGruState()。 与
Section 8.5中定义的initRnnState()函数一样,
此函数返回一个形状为(批量大小, 隐藏单元个数)、值全部为零的 NDArray,
作为初始隐状态。
public static NDList initGruState(int batchSize, int numHiddens, Device device) {
return new NDList(manager.zeros(new Shape(batchSize, numHiddens), DataType.FLOAT32, device));
现在我们准备定义门控循环单元模型,
模型的架构与基本的循环神经网络单元是相同的, 只是权重更新公式更为复杂。
public static Pair<NDArray, NDList> gru(NDArray inputs, NDList state, NDList params) {
NDArray W_xz = params.get(0);
NDArray W_hz = params.get(1);
NDArray b_z = params.get(2);
NDArray W_xr = params.get(3);
NDArray W_hr = params.get(4);
NDArray b_r = params.get(5);
NDArray W_xh = params.get(6);
NDArray W_hh = params.get(7);
NDArray b_h = params.get(8);
NDArray W_hq = params.get(9);
NDArray b_q = params.get(10);
NDArray H = state.get(0);
NDList outputs = new NDList();
NDArray X, Y, Z, R, H_tilda;
for (int i = 0; i < inputs.size(0); i++) {
X = inputs.get(i);
Z = Activation.sigmoid(X.dot(W_xz).add(H.dot(W_hz).add(b_z)));
R = Activation.sigmoid(X.dot(W_xr).add(H.dot(W_hr).add(b_r)));
H_tilda = Activation.tanh(X.dot(W_xh).add(R.mul(H).dot(W_hh).add(b_h)));
H = Z.mul(H).add(Z.mul(-1).add(1).mul(H_tilda));
Y = H.dot(W_hq).add(b_q);
outputs.add(Y);
return new Pair(outputs.size() > 1 ? NDArrays.concat(outputs) : outputs.get(0), new NDList(H));
训练和预测的工作方式与 Section 8.5完全相同。
训练结束后,我们打印输出训练集的困惑度,
以及根据前缀“time traveller”和“traveller”生成的预测序列。
// Hyperparameters and training driver for the from-scratch GRU model.
int vocabSize = vocab.length();
int numHiddens = 256;
Device device = manager.getDevice();
// Epoch count can be overridden via the MAX_EPOCH system property.
int numEpochs = Integer.getInteger("MAX_EPOCH", 500);
int lr = 1;
// Wrap the top-level functions as TriFunction objects; JShell top-level
// methods have no enclosing class, so method references are not available.
Functions.TriFunction<Integer, Integer, Device, NDList> getParamsFn = (a, b, c) -> getParams(a, b, c);
Functions.TriFunction<Integer, Integer, Device, NDList> initGruStateFn =
(a, b, c) -> initGruState(a, b, c);
Functions.TriFunction<NDArray, NDList, NDList, Pair<NDArray, NDList>> gruFn = (a, b, c) -> gru(a, b, c);
// Assemble the scratch model and train it with the Chapter 8 training loop.
RNNModelScratch model =
new RNNModelScratch(vocabSize, numHiddens, device,
getParamsFn, initGruStateFn, gruFn);
TimeMachine.trainCh8(model, dataset, vocab, lr, numEpochs, device, false, manager);
perplexity: 1.0, 11663.9 tokens/sec on gpu(0)
time traveller for so it will be conveniettingstingstin timupauc
travellersais a fourth timethere is however a sended and an
高级API包含了前文介绍的所有实现细节,因此这段代码的运行速度要快得多:
它使用的是编译好的运算符,而不是用单个的 NDArray
运算来处理之前阐述的许多细节。
// Concise implementation: use DJL's built-in GRU block instead of the
// scratch forward function, then train with the same Chapter 8 loop.
GRU gruBlock =
        GRU.builder()
                .setNumLayers(1)
                .setStateSize(numHiddens)
                .optReturnState(true)
                .optBatchFirst(false)
                .build();
RNNModel conciseModel = new RNNModel(gruBlock, vocab.length());
TimeMachine.trainCh8(conciseModel, dataset, vocab, lr, numEpochs, device, false, manager);
INFO Training on: 1 GPUs.
INFO Load MXNet Engine Version 1.9.0 in 0.083 ms.
perplexity: 1.0, 82348.6 tokens/sec on gpu(0)
time traveller for so it will be convenient to speak of himwas e
traveller of ht right ang hasd and wlank and we sage but yo
假设我们只想使用时间步\(t'\)的输入来预测时间步\(t > t'\)的输出。对于每个时间步,重置门和更新门的最佳值是什么?
调整和分析超参数对运行时间、困惑度和输出序列的影响。
比较rnn.RNN和rnn.GRU的不同实现对运行时间、困惑度和输出字符串的影响。
如果仅仅实现门控循环单元的一部分,例如,只有一个重置门或一个更新门会怎样?