Skip to content

  • 体验新版
    • 正在加载...
  • 登录
  • PaddlePaddle
  • Paddle
  • 合并请求
  • !25024

P
Paddle
  • 项目概览

PaddlePaddle / Paddle
接近 2 年 前同步成功

通知 2323
Star 20933
Fork 5424
  • 代码
    • 文件
    • 提交
    • 分支
    • Tags
    • 贡献者
    • 分支图
    • Diff
  • Issue 1423
    • 列表
    • 看板
    • 标记
    • 里程碑
  • 合并请求 543
  • Wiki 0
    • Wiki
  • 分析
    • 仓库
    • DevOps
  • 项目成员
  • Pages
P
Paddle
  • 项目概览
    • 项目概览
    • 详情
    • 发布
  • 仓库
    • 仓库
    • 文件
    • 提交
    • 分支
    • 标签
    • 贡献者
    • 分支图
    • 比较
  • Issue 1,423
    • Issue 1,423
    • 列表
    • 看板
    • 标记
    • 里程碑
  • 合并请求 543
    • 合并请求 543
  • Pages
  • 分析
    • 分析
    • 仓库分析
    • DevOps
  • Wiki 0
    • Wiki
  • 成员
    • 成员
  • 收起侧边栏
  • 动态
  • 分支图
  • 创建新Issue
  • 提交
  • Issue看板

Fix model eval function !25024

  • Report abuse
!25024 开放中 6月 11, 2020 由 saxon_zh@saxon_zh 创建
#<User:0x000055935d1b8e88>
  • 概览 1
  • 提交 8
  • 变更 7

Created by: phlrain

PR types

Function optimization

PR changes

APIs

Describe

本pr一共有两个改动点(这两个修改点强耦合):

  1. eval mode和 no_grad解绑

  2. 解决多个model之前train/eval干扰的问题

  3. 在动态图中,eval mode和no_grad(不记录反向)功能是绑定的,切换到eval mode会自动打开no_grad,这个跟之前动态图下不调用 loss.backward,显存无法释放的问题有关系,目前不调用loss.backward显存能够释放,目前把eval mode和 no_grad解绑

  4. 动态图如果定义了多个model,train/eval model会相互干扰 model_1 = AModel() model_2 = TModel()

model_1.train() model_2.eval()

在这种情况下,model_1其实也被设置成了 eval mode,根本原因是目前mode的设置存储在了一个全局的tracer当中,每次修改都会全局修改; 目前.train/eval切换影响的op只有batch norm和 dropout两种, Dygraph.BatchNorm 和 class形势的 dygraph.Dropout由于都是获取自己内部的training状态,不会受到全局状态的影响。

唯一会有影响的是 fluid.layers.dropout() 在这种写法下,op的is_test属性会获取全局的信息, 为解决避免layers.dropout 受到全局的影响,有两种解决方案

  1. 将fluid.layers.dropout 替换为 dygraph.Dropout,需要先在 __init__定义 self.dropout = dygraph.Dropout 然后forward中 x = self.dropout(x)
  2. 采用如下写法 fluid.layers.dropout(x, is_test=not self.training) ,将is_test与Layer的 training状态绑定
指派人
分配到
审核者
Request review from
无
里程碑
无
分配里程碑
工时统计
标识: paddlepaddle/Paddle!25024
Source branch: github/fork/phlrain/fix_model_eval_function
渝ICP备2023009037号

京公网安备11010502055752号

网络110报警服务 Powered by GitLab CE v13.7
开源知识
Git 入门 Pro Git 电子书 在线学 Git
Markdown 基础入门 IT 技术知识开源图谱
帮助
使用手册 反馈建议 博客
《GitCode 隐私声明》 《GitCode 服务条款》 关于GitCode
Powered by GitLab CE v13.7