Commit 5c776877 authored by minqiyang

Fix batch_norm's stop_gradient bug

test=develop
Parent 79d62c54
@@ -156,6 +156,8 @@ class Autograd {
     for (auto it : candidate->pre_ops_) {
       for (OpBase* pre_op : it.second) {
         if (!pre_op) continue;
+        VLOG(5) << "op dep " << candidate->op_desc_->Type() << " <---- "
+                << it.first << " <---- " << pre_op->op_desc_->Type();
         if (visited.find(pre_op) == visited.end()) {
           visited.insert(pre_op);
           queue.push_back(pre_op);
...
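For context, the hunk above only adds a VLOG(5) trace inside the existing backward dependency walk over each op's pre_ops_ map. A minimal Python sketch of that walk (toy classes, not Paddle's actual Autograd/OpBase types) shows what the new log line reports:

# Toy sketch of the breadth-first walk over pre-op dependencies; the print
# mirrors the new "op dep <candidate> <---- <slot> <---- <pre_op>" VLOG.
from collections import deque

class ToyOp:
    def __init__(self, name):
        self.name = name
        self.pre_ops = {}  # input slot name -> list of producing ops

def walk_dependencies(start):
    visited = {start}
    queue = deque([start])
    while queue:
        candidate = queue.popleft()
        for slot, pre_ops in candidate.pre_ops.items():
            for pre_op in pre_ops:
                if pre_op is None:
                    continue  # leaf input with no producing op
                print(f"op dep {candidate.name} <---- {slot} <---- {pre_op.name}")
                if pre_op not in visited:
                    visited.add(pre_op)
                    queue.append(pre_op)

conv, bn, relu = ToyOp("conv2d"), ToyOp("batch_norm"), ToyOp("relu")
bn.pre_ops["X"] = [conv]
relu.pre_ops["X"] = [bn]
walk_dependencies(relu)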
@@ -28,6 +28,7 @@
 #include "paddle/fluid/framework/var_desc.h"
 #include "paddle/fluid/platform/enforce.h"
 #include "paddle/fluid/platform/device_context.h"
+#include "paddle/fluid/operators/math/math_function.h"
 #include "paddle/fluid/imperative/type_defs.h"
@@ -148,8 +149,12 @@ class VarBase {
   }

   void ClearGradient() {
-    delete grads_;
-    grads_ = new VarBase(true);
+    VLOG(1) << "clear gradient of " << var_desc_->Name();
+    auto grads_t = grads_->var_->GetMutable<framework::LoDTensor>();
+    operators::math::set_constant(
+        *(platform::DeviceContextPool::Instance().Get(
+            grads_->var_->Get<framework::LoDTensor>().place())),
+        grads_t, 0.0);
   }

   framework::LoDTensor& GradValue();
...
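A rough reading of the ClearGradient change above: instead of deleting grads_ and allocating a fresh VarBase, the existing gradient LoDTensor is zero-filled in place via set_constant on whichever device holds it, so any code still referencing that tensor keeps a valid handle. A plain-Python analogy (not Paddle code) of the difference:

# Rough analogy: zero the existing gradient storage in place rather than
# swapping in a brand new object, so other holders of the buffer stay in sync.
grad = [1.0] * 6
cached = grad                      # e.g. some other component kept the buffer

# Old behaviour (delete grads_; grads_ = new VarBase(true)): replace the
# object; the cached handle silently goes stale.
grad = [1.0] * 6
assert cached is not grad          # stale reference

# New behaviour (set_constant on the existing LoDTensor): keep the object and
# overwrite its contents with zeros in place.
cached[:] = [0.0] * len(cached)
assert sum(cached) == 0.0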
@@ -83,11 +83,12 @@ void Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
   op->input_vars_ = inputs;
   for (auto it : op->input_vars_) {
     auto& invars = invars_map[it.first];
+    invars.reserve(it.second.size());
     for (VarBase* inp : it.second) {
       PADDLE_ENFORCE_NOT_NULL(inp->var_, "op %s input %s nullptr",
                               op->op_desc_->Type(), inp->var_desc_->Name());
-      invars.push_back(inp->var_);
+      invars.emplace_back(inp->var_);
       vars[inp->var_desc_->Name()] = inp;
       if (inp->PreOp()) {
         op->pre_ops_[it.first].push_back(inp->PreOp());
@@ -104,9 +105,10 @@ void Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
   for (auto it : op->output_vars_) {
     auto& outvars = outvars_map[it.first];
     const std::vector<VarBase*>& outputs = it.second;
+    outvars.reserve(outputs.size());
     for (size_t i = 0; i < outputs.size(); ++i) {
       VarBase* out = outputs[i];
-      outvars.push_back(out->var_);
+      outvars.emplace_back(out->var_);
       vars[out->var_desc_->Name()] = out;
       framework::VarDesc* var_desc = block->FindVar(out->var_desc_->Name());
...
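The Tracer hunks above are mostly bookkeeping: pre-size each per-slot vector with reserve() before filling it, and emplace the raw var_ pointers. A toy Python sketch (not the real Tracer API) of what those two loops accomplish, for readers following along:

# Toy sketch of the trace step: gather each slot's variables into
# invars_map/outvars_map, index every variable by name, and wire up the
# pre-op edges that the backward dependency walk shown earlier follows.
class ToyVar:
    def __init__(self, name, data=0.0):
        self.name, self.data, self.pre_op = name, data, None

class ToyOp:
    def __init__(self, name):
        self.name, self.pre_ops = name, {}

def trace(op, inputs, outputs):
    invars_map, outvars_map, vars_by_name = {}, {}, {}
    for slot, var_list in inputs.items():
        invars = invars_map.setdefault(slot, [])
        for var in var_list:
            assert var.data is not None, f"op {op.name} input {var.name} is None"
            invars.append(var.data)
            vars_by_name[var.name] = var
            if var.pre_op is not None:
                op.pre_ops.setdefault(slot, []).append(var.pre_op)
    for slot, var_list in outputs.items():
        outvars = outvars_map.setdefault(slot, [])
        for var in var_list:
            outvars.append(var.data)
            vars_by_name[var.name] = var
            var.pre_op = op            # this op now produces the variable
    return invars_map, outvars_map, vars_by_name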
@@ -334,6 +334,7 @@ class BatchNorm(layers.Layer):
             default_initializer=Constant(1.0))
         if use_global_stats and self._helper.param_attr.learning_rate == 0.:
             self._scale.stop_gradient = True
+            self._scale._stop_gradient = True

         self._bias = self._helper.create_parameter(
             attr=self._helper.bias_attr,
@@ -342,6 +343,7 @@ class BatchNorm(layers.Layer):
             is_bias=True)
         if use_global_stats and self._helper.bias_attr.learning_rate == 0.:
             self._bias.stop_gradient = True
+            self._bias._stop_gradient = True

         self._mean = self._helper.create_parameter(
             attr=ParamAttr(
@@ -352,6 +354,7 @@ class BatchNorm(layers.Layer):
             shape=param_shape,
             dtype=self._dtype)
         self._mean.stop_gradient = True
+        self._mean._stop_gradient = True

         self._variance = self._helper.create_parameter(
             attr=ParamAttr(
@@ -362,6 +365,7 @@ class BatchNorm(layers.Layer):
             shape=param_shape,
             dtype=self._dtype)
         self._variance.stop_gradient = True
+        self._variance._stop_gradient = True

         self._in_place = in_place
         self._momentum = momentum
...
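This last file is the fix the commit title refers to: besides the public stop_gradient attribute, each BatchNorm parameter now also sets the underscore-prefixed _stop_gradient field, which suggests the two were tracked separately in imperative mode and only the latter was consulted when gradients were produced. A toy illustration of that failure mode (plain Python, not Paddle's Variable/VarBase):

# Toy illustration: if the user-facing attribute and the flag the tracer
# actually checks are stored separately, setting only the attribute still
# lets gradients flow into the statistic; the fix sets both names.
class ToyParam:
    def __init__(self, name):
        self.name = name
        self.stop_gradient = False    # user-facing attribute
        self._stop_gradient = False   # flag the (toy) tracer checks
        self.grad = None

def toy_backward(params):
    for p in params:
        if not p._stop_gradient:      # only this flag matters here
            p.grad = 1.0

mean = ToyParam("batch_norm_mean")
mean.stop_gradient = True             # looks stopped, but...
toy_backward([mean])
print(mean.grad)                      # 1.0 -> gradient still produced

mean._stop_gradient = True            # ...the tracer-side flag must be set too
mean.grad = None
toy_backward([mean])
print(mean.grad)                      # None -> gradient correctly skipped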