From 5c7768776c2a0b0a3b7c39e618897d17bb5bf882 Mon Sep 17 00:00:00 2001
From: minqiyang
Date: Mon, 28 Jan 2019 17:00:04 +0800
Subject: [PATCH] Fix batch_norm's stop_gradient bug

test=develop
---
 paddle/fluid/imperative/layer.cc     | 2 ++
 paddle/fluid/imperative/layer.h      | 9 +++++++--
 paddle/fluid/imperative/tracer.cc    | 6 ++++--
 python/paddle/fluid/imperative/nn.py | 4 ++++
 4 files changed, 17 insertions(+), 4 deletions(-)

diff --git a/paddle/fluid/imperative/layer.cc b/paddle/fluid/imperative/layer.cc
index 8029129b9a..64d4d999d1 100644
--- a/paddle/fluid/imperative/layer.cc
+++ b/paddle/fluid/imperative/layer.cc
@@ -156,6 +156,8 @@ class Autograd {
       for (auto it : candidate->pre_ops_) {
         for (OpBase* pre_op : it.second) {
           if (!pre_op) continue;
+          VLOG(5) << "op dep " << candidate->op_desc_->Type() << " <---- "
+                  << it.first << " <---- " << pre_op->op_desc_->Type();
           if (visited.find(pre_op) == visited.end()) {
             visited.insert(pre_op);
             queue.push_back(pre_op);
diff --git a/paddle/fluid/imperative/layer.h b/paddle/fluid/imperative/layer.h
index 633924aa41..0151a80816 100644
--- a/paddle/fluid/imperative/layer.h
+++ b/paddle/fluid/imperative/layer.h
@@ -28,6 +28,7 @@
 #include "paddle/fluid/framework/var_desc.h"
 #include "paddle/fluid/platform/enforce.h"
 #include "paddle/fluid/platform/device_context.h"
+#include "paddle/fluid/operators/math/math_function.h"
 
 #include "paddle/fluid/imperative/type_defs.h"
 
@@ -148,8 +149,12 @@ class VarBase {
   }
 
   void ClearGradient() {
-    delete grads_;
-    grads_ = new VarBase(true);
+    VLOG(1) << "clear gradient of " << var_desc_->Name();
+    auto grads_t = grads_->var_->GetMutable<framework::LoDTensor>();
+    operators::math::set_constant(
+        *(platform::DeviceContextPool::Instance().Get(
+            grads_->var_->Get<framework::LoDTensor>().place())),
+        grads_t, 0.0);
   }
 
   framework::LoDTensor& GradValue();
diff --git a/paddle/fluid/imperative/tracer.cc b/paddle/fluid/imperative/tracer.cc
index 5b87839f45..c8af936c33 100644
--- a/paddle/fluid/imperative/tracer.cc
+++ b/paddle/fluid/imperative/tracer.cc
@@ -83,11 +83,12 @@ void Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
   op->input_vars_ = inputs;
   for (auto it : op->input_vars_) {
     auto& invars = invars_map[it.first];
+    invars.reserve(it.second.size());
     for (VarBase* inp : it.second) {
       PADDLE_ENFORCE_NOT_NULL(inp->var_, "op %s input %s nullptr",
                               op->op_desc_->Type(), inp->var_desc_->Name());
 
-      invars.push_back(inp->var_);
+      invars.emplace_back(inp->var_);
       vars[inp->var_desc_->Name()] = inp;
       if (inp->PreOp()) {
         op->pre_ops_[it.first].push_back(inp->PreOp());
@@ -104,9 +105,10 @@ void Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
   for (auto it : op->output_vars_) {
     auto& outvars = outvars_map[it.first];
     const std::vector<VarBase*>& outputs = it.second;
+    outvars.reserve(outputs.size());
     for (size_t i = 0; i < outputs.size(); ++i) {
       VarBase* out = outputs[i];
-      outvars.push_back(out->var_);
+      outvars.emplace_back(out->var_);
       vars[out->var_desc_->Name()] = out;
 
       framework::VarDesc* var_desc = block->FindVar(out->var_desc_->Name());
diff --git a/python/paddle/fluid/imperative/nn.py b/python/paddle/fluid/imperative/nn.py
index fe5014f5e6..543f573890 100644
--- a/python/paddle/fluid/imperative/nn.py
+++ b/python/paddle/fluid/imperative/nn.py
@@ -334,6 +334,7 @@ class BatchNorm(layers.Layer):
             default_initializer=Constant(1.0))
         if use_global_stats and self._helper.param_attr.learning_rate == 0.:
             self._scale.stop_gradient = True
+            self._scale._stop_gradient = True
 
         self._bias = self._helper.create_parameter(
             attr=self._helper.bias_attr,
@@ -342,6 +343,7 @@ class BatchNorm(layers.Layer):
             is_bias=True)
         if use_global_stats and self._helper.bias_attr.learning_rate == 0.:
             self._bias.stop_gradient = True
+            self._bias._stop_gradient = True
 
         self._mean = self._helper.create_parameter(
             attr=ParamAttr(
@@ -352,6 +354,7 @@ class BatchNorm(layers.Layer):
             shape=param_shape,
             dtype=self._dtype)
         self._mean.stop_gradient = True
+        self._mean._stop_gradient = True
 
         self._variance = self._helper.create_parameter(
             attr=ParamAttr(
@@ -362,6 +365,7 @@ class BatchNorm(layers.Layer):
             shape=param_shape,
             dtype=self._dtype)
         self._variance.stop_gradient = True
+        self._variance._stop_gradient = True
 
         self._in_place = in_place
         self._momentum = momentum
-- 
GitLab
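
Note on the nn.py hunks: the imperative (dygraph) mode of this era tracked two flags per variable, the Python-level stop_gradient attribute used when building the static graph, and an underscore-prefixed _stop_gradient flag consulted on the imperative side. The patch sets both on BatchNorm's scale, bias, mean, and variance so the moving statistics stop receiving gradients in imperative runs as well. The sketch below is a minimal stand-alone illustration of that two-flag pitfall under those assumptions; the names ImperativeVar and needs_grad_in_imperative_mode are hypothetical, not Paddle APIs.

# Minimal sketch of the two-flag pitfall the nn.py hunks work around.
# All names here are hypothetical illustrations, not Paddle APIs.
class ImperativeVar(object):
    """Stand-in for a parameter tracked by the imperative tracer."""

    def __init__(self, name):
        self.name = name
        self.stop_gradient = False   # read when building the static graph
        self._stop_gradient = False  # read by the imperative backward pass

    def needs_grad_in_imperative_mode(self):
        return not self._stop_gradient

mean = ImperativeVar('batch_norm_0.w_1')  # hypothetical variable name

# Before the patch, only the static-graph flag was set, so the imperative
# backward pass still produced gradients for the moving mean/variance.
mean.stop_gradient = True
assert mean.needs_grad_in_imperative_mode()   # the bug: grads still flow

# The patch sets the imperative flag as well, which stops the gradient.
mean._stop_gradient = True
assert not mean.needs_grad_in_imperative_mode()

The C++ side complements this: ClearGradient now zeroes the existing gradient tensor in place via set_constant instead of deleting and reallocating the VarBase, so a cleared gradient keeps its buffer and device placement.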