Commit 5c776877, authored by minqiyang

Fix batch_norm's stop_gradient bug

test=develop

Parent commit: 79d62c54
......@@ -156,6 +156,8 @@ class Autograd {
for (auto it : candidate->pre_ops_) {
for (OpBase* pre_op : it.second) {
if (!pre_op) continue;
VLOG(5) << "op dep " << candidate->op_desc_->Type() << " <---- "
<< it.first << " <---- " << pre_op->op_desc_->Type();
if (visited.find(pre_op) == visited.end()) {
visited.insert(pre_op);
queue.push_back(pre_op);
......
......@@ -28,6 +28,7 @@
#include "paddle/fluid/framework/var_desc.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/imperative/type_defs.h"
......@@ -148,8 +149,12 @@ class VarBase {
}
void ClearGradient() {
delete grads_;
grads_ = new VarBase(true);
VLOG(1) << "clear gradient of " << var_desc_->Name();
auto grads_t = grads_->var_->GetMutable<framework::LoDTensor>();
operators::math::set_constant(
*(platform::DeviceContextPool::Instance().Get(
grads_->var_->Get<framework::LoDTensor>().place())),
grads_t, 0.0);
}
framework::LoDTensor& GradValue();
......
......@@ -83,11 +83,12 @@ void Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
op->input_vars_ = inputs;
for (auto it : op->input_vars_) {
auto& invars = invars_map[it.first];
invars.reserve(it.second.size());
for (VarBase* inp : it.second) {
PADDLE_ENFORCE_NOT_NULL(inp->var_, "op %s input %s nullptr",
op->op_desc_->Type(), inp->var_desc_->Name());
invars.push_back(inp->var_);
invars.emplace_back(inp->var_);
vars[inp->var_desc_->Name()] = inp;
if (inp->PreOp()) {
op->pre_ops_[it.first].push_back(inp->PreOp());
......@@ -104,9 +105,10 @@ void Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
for (auto it : op->output_vars_) {
auto& outvars = outvars_map[it.first];
const std::vector<VarBase*>& outputs = it.second;
outvars.reserve(outputs.size());
for (size_t i = 0; i < outputs.size(); ++i) {
VarBase* out = outputs[i];
outvars.push_back(out->var_);
outvars.emplace_back(out->var_);
vars[out->var_desc_->Name()] = out;
framework::VarDesc* var_desc = block->FindVar(out->var_desc_->Name());
......
......@@ -334,6 +334,7 @@ class BatchNorm(layers.Layer):
default_initializer=Constant(1.0))
if use_global_stats and self._helper.param_attr.learning_rate == 0.:
self._scale.stop_gradient = True
self._scale._stop_gradient = True
self._bias = self._helper.create_parameter(
attr=self._helper.bias_attr,
......@@ -342,6 +343,7 @@ class BatchNorm(layers.Layer):
is_bias=True)
if use_global_stats and self._helper.bias_attr.learning_rate == 0.:
self._bias.stop_gradient = True
self._bias._stop_gradient = True
self._mean = self._helper.create_parameter(
attr=ParamAttr(
......@@ -352,6 +354,7 @@ class BatchNorm(layers.Layer):
shape=param_shape,
dtype=self._dtype)
self._mean.stop_gradient = True
self._mean._stop_gradient = True
self._variance = self._helper.create_parameter(
attr=ParamAttr(
......@@ -362,6 +365,7 @@ class BatchNorm(layers.Layer):
shape=param_shape,
dtype=self._dtype)
self._variance.stop_gradient = True
self._variance._stop_gradient = True
self._in_place = in_place
self._momentum = momentum
......
Markdown is supported.
Upload progress: 0%.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register to comment.