diff --git a/source/advanced/parameter_more_setting.rst b/source/advanced/parameter_more_setting.rst index 645d8cd3e7f8bc771beda4a9757c3d014b0019dd..8a1cbc4d1344586b690e025651ac0853a267e3c5 100644 --- a/source/advanced/parameter_more_setting.rst +++ b/source/advanced/parameter_more_setting.rst @@ -3,7 +3,7 @@ 更细粒度的参数优化设置 ============================== -在 :ref:`train_and_evaluation` 当中网络使用如下优化器进行训练: +在 :ref:`train_and_evaluation` 中网络使用如下优化器进行训练: .. testcode:: @@ -13,7 +13,7 @@ lr=0.05, # 学习速率 ) -这个优化器对所有参数都使用同一学习速率进行优化,我们将在本章中介绍如何做到对不同的参数采用不同的学习速率。 +这个优化器对所有参数都使用同一学习速率进行优化,而在本章中我们将介绍如何做到对不同的参数采用不同的学习速率。 本章我们沿用 :ref:`network_build` 中创建的 ``LeNet`` ,下述的优化器相关代码可以用于取代 :ref:`train_and_evaluation` 中对应的代码。 diff --git a/source/advanced/sublinear.rst b/source/advanced/sublinear.rst index 2e5f3ccd85175b14e2e3ceb8c0109d0c15f6089e..859dd14e20e34203eb3569b6f010cac699358863 100644 --- a/source/advanced/sublinear.rst +++ b/source/advanced/sublinear.rst @@ -22,7 +22,7 @@ 亚线性内存技术仅适用于 MegEngine 静态图模式。这种内存优化方式在编译计算图和训练模型时会有少量的额外时间开销。下面我们以 `ResNet50 `_ 为例,说明使用亚线性内存优化能够大幅节约网络训练显存使用。 .. testcode:: - + import os import megengine as mge @@ -43,9 +43,6 @@ lr=0.1, ) - data = mge.tensor() - label = mge.tensor(dtype="int32") - # symbolic参数说明请参见 静态图的两种模式 @trace(symbolic=True) def train_func(data, label, *, net, optimizer): @@ -59,10 +56,8 @@ # 使用假数据 batch_data = np.random.randn(batch_size, 3, 224, 224).astype(np.float32) batch_label = np.random.randint(1000, size=(batch_size,)).astype(np.float32) - data.set_value(batch_data) - label.set_value(batch_label) optimizer.zero_grad() - train_func(data, label, net=resnet, optimizer=optimizer) + train_func(batch_data, batch_label, net=resnet, optimizer=optimizer) optimizer.step() # 设置使用单卡 GPU ,显存容量为 11 GB diff --git a/source/advanced/two_static_mode.rst b/source/advanced/two_static_mode.rst index 0ced9315ef2fb28500bdb33aa04d26e52d7d6f85..821d7954baadce98b879503f3c95cfbebcb93989 100644 --- a/source/advanced/two_static_mode.rst +++ b/source/advanced/two_static_mode.rst @@ -26,13 +26,15 @@ # @trace(symbolic=False) # “动态构造” @trace(symbolic=True) # “静态构造” - def train_func(data, label, *, opt, net): + def train_func(data, label, *, opt, net): logits = net(data) print(logits[0]) # 因网络输出太多,此处仅打印部分 loss = F.cross_entropy_with_softmax(logits, label) opt.backward(loss) return logits, loss +输出为: + .. testoutput:: Tensor(None) diff --git a/source/basic/dynamic_and_static_graph.rst b/source/basic/dynamic_and_static_graph.rst index 61059000155ceca3522ac427e1e36ec0436f232a..1e4fd0cc7db1239fa5302ad901d0f34e2919e0ea 100644 --- a/source/basic/dynamic_and_static_graph.rst +++ b/source/basic/dynamic_and_static_graph.rst @@ -20,9 +20,9 @@ MegEngine支持 **静态计算图** 模式。该模式将计算图的构建和 在上图左侧的计算图中,为了存储 ``x`` 、 ``w`` 、 ``p`` 、 ``b``, ``y`` 五个变量,动态图需要 ``40`` 个字节(假设每个变量占用 8 字节的内存)。在静态图中,由于我们只需要知道结果 ``y`` ,可以让 ``y`` 复用中间变量 ``p`` 的内存,实现“原地”(inplace)修改。这样,静态图所占用的内存就减少为 ``32`` 个字节。 -MegEngine 还采用 **算子融合** (Operator Fuse)的方式减少计算开销。以上图为例,我们可以将乘法和加法融合为一个三元操作(假设硬件支持) **乘加** ,降低计算量。 +MegEngine 还采用 **算子融合** (Operator Fuse)的方式减少计算开销。以上图为例,我们可以将乘法和加法融合为一个三元操作(假设底层支持) **乘加** ,降低计算量。 -注意,只有了解了完整的计算流程后才能进行上述优化。 +注意,框架只有获取了完整的计算流程后才能进行上述优化。 动态图转静态图 ------------------------------ @@ -38,14 +38,14 @@ MegEngine 提供了很方便的动静态图转换的方法,几乎无需代码 total_loss = 0 for step, (batch_data, batch_label) in enumerate(dataloader): optimizer.zero_grad() # 将参数的梯度置零 - + # 以下五行代码为网络的计算和优化,后续转静态图时将进行处理 data.set_value(batch_data) label.set_value(batch_label) logits = le_net(data) loss = F.cross_entropy_with_softmax(logits, label) optimizer.backward(loss) # 反传计算梯度 - + optimizer.step() # 根据梯度更新参数值 total_loss += loss.numpy().item() print("epoch: {}, loss {}".format(epoch, total_loss/len(dataloader))) @@ -133,11 +133,11 @@ MegEngine 提供了很方便的动静态图转换的方法,几乎无需代码 total_loss = 0 for step, (data, label) in enumerate(dataloader): optimizer.zero_grad() # 将参数的梯度置零 - - label = label.astype('int32') # 交叉熵损失的label需要int32类型 + + label = label.astype('int32') # 交叉熵损失的label需要int32类型 # 调用被 trace 装饰后的函数 logits, loss = train_func(data, label, opt=optimizer, net=le_net) - + optimizer.step() # 根据梯度更新参数值 total_loss += loss.numpy().item() print("epoch: {}, loss {}".format(epoch, total_loss/len(dataloader)))