未验证 提交 b69996c2 编写于 作者: Q Qiyang Min 提交者: GitHub

Merge pull request #15558 from velconia/imperative_resnet

Refine Batch Norm
...@@ -156,6 +156,8 @@ class Autograd { ...@@ -156,6 +156,8 @@ class Autograd {
for (auto it : candidate->pre_ops_) { for (auto it : candidate->pre_ops_) {
for (OpBase* pre_op : it.second) { for (OpBase* pre_op : it.second) {
if (!pre_op) continue; if (!pre_op) continue;
VLOG(5) << "op dep " << candidate->op_desc_->Type() << " <---- "
<< it.first << " <---- " << pre_op->op_desc_->Type();
if (visited.find(pre_op) == visited.end()) { if (visited.find(pre_op) == visited.end()) {
visited.insert(pre_op); visited.insert(pre_op);
queue.push_back(pre_op); queue.push_back(pre_op);
......
...@@ -28,6 +28,7 @@ ...@@ -28,6 +28,7 @@
#include "paddle/fluid/framework/var_desc.h" #include "paddle/fluid/framework/var_desc.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/imperative/type_defs.h" #include "paddle/fluid/imperative/type_defs.h"
...@@ -140,16 +141,24 @@ class VarBase { ...@@ -140,16 +141,24 @@ class VarBase {
void RunBackward(); void RunBackward();
void TrackPreOp(OpBase* pre_op, const std::string& pre_op_out_name, void TrackPreOp(OpBase* pre_op, const std::string& pre_op_out_name,
int pre_op_out_idx, bool stop_gradient) { int pre_op_out_idx, bool pre_op_stop_gradient) {
pre_op_ = pre_op; pre_op_ = pre_op;
pre_op_out_name_ = pre_op_out_name; pre_op_out_name_ = pre_op_out_name;
pre_op_out_idx_ = pre_op_out_idx; pre_op_out_idx_ = pre_op_out_idx;
stop_gradient_ = stop_gradient; if (pre_op_stop_gradient) {
stop_gradient_ = pre_op_stop_gradient;
}
} }
void ClearGradient() { void ClearGradient() {
delete grads_; VLOG(1) << "clear gradient of " << var_desc_->Name();
grads_ = new VarBase(true); if (grads_ && grads_->var_ && grads_->var_->IsInitialized()) {
auto grads_t = grads_->var_->GetMutable<framework::LoDTensor>();
operators::math::set_constant(
*(platform::DeviceContextPool::Instance().Get(
grads_->var_->Get<framework::LoDTensor>().place())),
grads_t, 0.0);
}
} }
framework::LoDTensor& GradValue(); framework::LoDTensor& GradValue();
......
...@@ -84,11 +84,12 @@ void Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs, ...@@ -84,11 +84,12 @@ void Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
op->input_vars_ = inputs; op->input_vars_ = inputs;
for (auto it : op->input_vars_) { for (auto it : op->input_vars_) {
auto& invars = invars_map[it.first]; auto& invars = invars_map[it.first];
invars.reserve(it.second.size());
for (VarBase* inp : it.second) { for (VarBase* inp : it.second) {
PADDLE_ENFORCE_NOT_NULL(inp->var_, "op %s input %s nullptr", PADDLE_ENFORCE_NOT_NULL(inp->var_, "op %s input %s nullptr",
op->op_desc_->Type(), inp->var_desc_->Name()); op->op_desc_->Type(), inp->var_desc_->Name());
invars.push_back(inp->var_); invars.emplace_back(inp->var_);
vars[inp->var_desc_->Name()] = inp; vars[inp->var_desc_->Name()] = inp;
if (inp->PreOp()) { if (inp->PreOp()) {
op->pre_ops_[it.first].push_back(inp->PreOp()); op->pre_ops_[it.first].push_back(inp->PreOp());
...@@ -105,9 +106,10 @@ void Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs, ...@@ -105,9 +106,10 @@ void Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
for (auto it : op->output_vars_) { for (auto it : op->output_vars_) {
auto& outvars = outvars_map[it.first]; auto& outvars = outvars_map[it.first];
const std::vector<VarBase*>& outputs = it.second; const std::vector<VarBase*>& outputs = it.second;
outvars.reserve(outputs.size());
for (size_t i = 0; i < outputs.size(); ++i) { for (size_t i = 0; i < outputs.size(); ++i) {
VarBase* out = outputs[i]; VarBase* out = outputs[i];
outvars.push_back(out->var_); outvars.emplace_back(out->var_);
vars[out->var_desc_->Name()] = out; vars[out->var_desc_->Name()] = out;
framework::VarDesc* var_desc = block->FindVar(out->var_desc_->Name()); framework::VarDesc* var_desc = block->FindVar(out->var_desc_->Name());
......
...@@ -445,11 +445,16 @@ class Variable(object): ...@@ -445,11 +445,16 @@ class Variable(object):
@property @property
def _stop_gradient(self): def _stop_gradient(self):
if _in_imperative_mode():
return self._ivar.stop_gradient return self._ivar.stop_gradient
else:
return self.stop_gradient
@_stop_gradient.setter @_stop_gradient.setter
def _stop_gradient(self, s): def _stop_gradient(self, s):
if _in_imperative_mode():
self._ivar.stop_gradient = s self._ivar.stop_gradient = s
self.stop_gradient = s
@property @property
def persistable(self): def persistable(self):
...@@ -1310,6 +1315,9 @@ class Block(object): ...@@ -1310,6 +1315,9 @@ class Block(object):
outputs=kwargs.get("outputs", None), outputs=kwargs.get("outputs", None),
attrs=kwargs.get("attrs", None)) attrs=kwargs.get("attrs", None))
self.ops.append(op) self.ops.append(op)
# TODO(minqiyang): add stop_gradient support in static mode too.
# currently, we only support stop_gradient in imperative mode.
self._trace_op(op, kwargs.get("stop_gradient", False)) self._trace_op(op, kwargs.get("stop_gradient", False))
return op return op
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
import contextlib import contextlib
import sys import sys
import numpy as np import numpy as np
import collections
from paddle.fluid import core from paddle.fluid import core
from paddle.fluid import framework from paddle.fluid import framework
...@@ -31,7 +32,23 @@ class Layer(core.Layer): ...@@ -31,7 +32,23 @@ class Layer(core.Layer):
self._dtype = dtype self._dtype = dtype
def parameters(self): def parameters(self):
return [] params = []
for key in self.__dict__.keys():
value = self.__dict__[key]
if isinstance(value, framework.Parameter):
params.append(value)
elif isinstance(value, core.Layer):
params.extend(value.parameters())
elif isinstance(value, collections.Container):
if len(value) == 0:
continue
if isinstance(value[0], framework.Parameter):
params.extend(value)
elif isinstance(value[0], core.Layer):
for v in value:
params.extend(v.parameters())
return params
def clear_gradients(self): def clear_gradients(self):
for p in self.parameters(): for p in self.parameters():
......
...@@ -332,21 +332,16 @@ class BatchNorm(layers.Layer): ...@@ -332,21 +332,16 @@ class BatchNorm(layers.Layer):
shape=param_shape, shape=param_shape,
dtype=self._dtype, dtype=self._dtype,
default_initializer=Constant(1.0)) default_initializer=Constant(1.0))
if use_global_stats and self._helper.param_attr.learning_rate == 0.:
# TODO(minqiyang): change stop_gradient sign to trainable to align with static graph self._scale._stop_gradient = True
# # setting stop_gradient=True to reduce computation
# if use_global_stats and self._helper.param_attr.learning_rate == 0.:
# self._scale.stop_gradient = True
self._bias = self._helper.create_parameter( self._bias = self._helper.create_parameter(
attr=self._helper.bias_attr, attr=self._helper.bias_attr,
shape=param_shape, shape=param_shape,
dtype=self._dtype, dtype=self._dtype,
is_bias=True) is_bias=True)
# TODO(minqiyang): change stop_gradient sign to trainable to align with static graph if use_global_stats and self._helper.bias_attr.learning_rate == 0.:
# # setting stop_gradient=True to reduce computation self._bias._stop_gradient = True
# if use_global_stats and self._helper.bias_attr.learning_rate == 0.:
# self._bias.stop_gradient = True
self._mean = self._helper.create_parameter( self._mean = self._helper.create_parameter(
attr=ParamAttr( attr=ParamAttr(
...@@ -356,7 +351,7 @@ class BatchNorm(layers.Layer): ...@@ -356,7 +351,7 @@ class BatchNorm(layers.Layer):
do_model_average=do_model_average_for_mean_and_var), do_model_average=do_model_average_for_mean_and_var),
shape=param_shape, shape=param_shape,
dtype=self._dtype) dtype=self._dtype)
self._mean.stop_gradient = True self._mean._stop_gradient = True
self._variance = self._helper.create_parameter( self._variance = self._helper.create_parameter(
attr=ParamAttr( attr=ParamAttr(
...@@ -366,7 +361,7 @@ class BatchNorm(layers.Layer): ...@@ -366,7 +361,7 @@ class BatchNorm(layers.Layer):
do_model_average=do_model_average_for_mean_and_var), do_model_average=do_model_average_for_mean_and_var),
shape=param_shape, shape=param_shape,
dtype=self._dtype) dtype=self._dtype)
self._variance.stop_gradient = True self._variance._stop_gradient = True
self._in_place = in_place self._in_place = in_place
self._momentum = momentum self._momentum = momentum
......
...@@ -387,7 +387,7 @@ class Optimizer(object): ...@@ -387,7 +387,7 @@ class Optimizer(object):
params_grads = [] params_grads = []
for param in parameters: for param in parameters:
if param.stop_gradient: if param.stop_gradient or not param.trainable:
continue continue
# create gradient variable # create gradient variable
grad_var = Variable( grad_var = Variable(
......
...@@ -85,6 +85,7 @@ list(REMOVE_ITEM TEST_OPS test_image_classification_resnet) ...@@ -85,6 +85,7 @@ list(REMOVE_ITEM TEST_OPS test_image_classification_resnet)
list(REMOVE_ITEM TEST_OPS test_bilinear_interp_op) list(REMOVE_ITEM TEST_OPS test_bilinear_interp_op)
list(REMOVE_ITEM TEST_OPS test_nearest_interp_op) list(REMOVE_ITEM TEST_OPS test_nearest_interp_op)
list(REMOVE_ITEM TEST_OPS test_imperative_resnet) list(REMOVE_ITEM TEST_OPS test_imperative_resnet)
list(REMOVE_ITEM TEST_OPS test_imperative_optimizer)
foreach(TEST_OP ${TEST_OPS}) foreach(TEST_OP ${TEST_OPS})
py_test_modules(${TEST_OP} MODULES ${TEST_OP}) py_test_modules(${TEST_OP} MODULES ${TEST_OP})
endforeach(TEST_OP) endforeach(TEST_OP)
...@@ -94,6 +95,8 @@ py_test_modules(test_bilinear_interp_op MODULES test_bilinear_interp_op SERIAL) ...@@ -94,6 +95,8 @@ py_test_modules(test_bilinear_interp_op MODULES test_bilinear_interp_op SERIAL)
py_test_modules(test_nearest_interp_op MODULES test_nearest_interp_op SERIAL) py_test_modules(test_nearest_interp_op MODULES test_nearest_interp_op SERIAL)
py_test_modules(test_imperative_resnet MODULES test_imperative_resnet ENVS py_test_modules(test_imperative_resnet MODULES test_imperative_resnet ENVS
FLAGS_cudnn_deterministic=1) FLAGS_cudnn_deterministic=1)
py_test_modules(test_imperative_optimizer MODULES test_imperative_optimizer ENVS
FLAGS_cudnn_deterministic=1)
if(WITH_DISTRIBUTE) if(WITH_DISTRIBUTE)
py_test_modules(test_dist_train MODULES test_dist_train SERIAL) py_test_modules(test_dist_train MODULES test_dist_train SERIAL)
set_tests_properties(test_listen_and_serv_op PROPERTIES TIMEOUT 20) set_tests_properties(test_listen_and_serv_op PROPERTIES TIMEOUT 20)
......
...@@ -82,13 +82,14 @@ class MNIST(fluid.imperative.Layer): ...@@ -82,13 +82,14 @@ class MNIST(fluid.imperative.Layer):
self._simple_img_conv_pool_2 = SimpleImgConvPool( self._simple_img_conv_pool_2 = SimpleImgConvPool(
20, 50, 5, 2, 2, act="relu") 20, 50, 5, 2, 2, act="relu")
pool_2_shape = 50 * 8 * 8 pool_2_shape = 50 * 4 * 4
SIZE = 10 SIZE = 10
scale = (2.0 / (pool_2_shape**2 * SIZE))**0.5 scale = (2.0 / (pool_2_shape**2 * SIZE))**0.5
self._fc = FC(10, self._fc = FC(10,
param_attr=fluid.param_attr.ParamAttr( param_attr=fluid.param_attr.ParamAttr(
initializer=fluid.initializer.NormalInitializer( initializer=fluid.initializer.NormalInitializer(
loc=0.0, scale=scale))) loc=0.0, scale=scale)),
act="softmax")
def forward(self, inputs): def forward(self, inputs):
x = self._simple_img_conv_pool_1(inputs) x = self._simple_img_conv_pool_1(inputs)
...@@ -98,9 +99,9 @@ class MNIST(fluid.imperative.Layer): ...@@ -98,9 +99,9 @@ class MNIST(fluid.imperative.Layer):
class TestImperativeMnist(unittest.TestCase): class TestImperativeMnist(unittest.TestCase):
def test_mnist_cpu_float32(self): def test_mnist_float32(self):
seed = 90 seed = 90
batch_num = 2
with fluid.imperative.guard(): with fluid.imperative.guard():
fluid.default_startup_program().random_seed = seed fluid.default_startup_program().random_seed = seed
fluid.default_main_program().random_seed = seed fluid.default_main_program().random_seed = seed
...@@ -112,15 +113,15 @@ class TestImperativeMnist(unittest.TestCase): ...@@ -112,15 +113,15 @@ class TestImperativeMnist(unittest.TestCase):
dy_param_init_value = {} dy_param_init_value = {}
for batch_id, data in enumerate(train_reader()): for batch_id, data in enumerate(train_reader()):
if batch_id >= 2: if batch_id >= batch_num:
break break
x_data = np.array( dy_x_data = np.array(
[x[0].reshape(1, 28, 28) for x in data]).astype('float32') [x[0].reshape(1, 28, 28) for x in data]).astype('float32')
y_data = np.array([x[1] for x in data]).astype('int64').reshape( y_data = np.array([x[1] for x in data]).astype('int64').reshape(
128, 1) 128, 1)
img = to_variable(x_data) img = to_variable(dy_x_data)
label = to_variable(y_data) label = to_variable(y_data)
label._stop_gradient = True label._stop_gradient = True
...@@ -136,6 +137,7 @@ class TestImperativeMnist(unittest.TestCase): ...@@ -136,6 +137,7 @@ class TestImperativeMnist(unittest.TestCase):
avg_loss._backward() avg_loss._backward()
sgd.minimize(avg_loss) sgd.minimize(avg_loss)
mnist.clear_gradients()
dy_param_value = {} dy_param_value = {}
for param in fluid.default_main_program().global_block( for param in fluid.default_main_program().global_block(
).all_parameters(): ).all_parameters():
...@@ -175,10 +177,10 @@ class TestImperativeMnist(unittest.TestCase): ...@@ -175,10 +177,10 @@ class TestImperativeMnist(unittest.TestCase):
static_param_init_value[static_param_name_list[i]] = out[i] static_param_init_value[static_param_name_list[i]] = out[i]
for batch_id, data in enumerate(train_reader()): for batch_id, data in enumerate(train_reader()):
if batch_id >= 2: if batch_id >= batch_num:
break break
x_data = np.array( static_x_data = np.array(
[x[0].reshape(1, 28, 28) for x in data]).astype('float32') [x[0].reshape(1, 28, 28) for x in data]).astype('float32')
y_data = np.array([x[1] for x in data]).astype('int64').reshape( y_data = np.array([x[1] for x in data]).astype('int64').reshape(
[128, 1]) [128, 1])
...@@ -186,7 +188,7 @@ class TestImperativeMnist(unittest.TestCase): ...@@ -186,7 +188,7 @@ class TestImperativeMnist(unittest.TestCase):
fetch_list = [avg_loss.name] fetch_list = [avg_loss.name]
fetch_list.extend(static_param_name_list) fetch_list.extend(static_param_name_list)
out = exe.run(fluid.default_main_program(), out = exe.run(fluid.default_main_program(),
feed={"pixel": x_data, feed={"pixel": static_x_data,
"label": y_data}, "label": y_data},
fetch_list=fetch_list) fetch_list=fetch_list)
...@@ -196,11 +198,12 @@ class TestImperativeMnist(unittest.TestCase): ...@@ -196,11 +198,12 @@ class TestImperativeMnist(unittest.TestCase):
static_param_value[static_param_name_list[i - 1]] = out[i] static_param_value[static_param_name_list[i - 1]] = out[i]
for key, value in six.iteritems(static_param_init_value): for key, value in six.iteritems(static_param_init_value):
self.assertTrue( self.assertTrue(np.allclose(value, dy_param_init_value[key]))
np.allclose(value.all(), dy_param_init_value[key].all()))
self.assertTrue(np.allclose(static_out.all(), dy_out.all())) self.assertTrue(np.allclose(static_out, dy_out))
for key, value in six.iteritems(static_param_value): for key, value in six.iteritems(static_param_value):
self.assertTrue(np.allclose(value.all(), dy_param_value[key].all())) self.assertTrue(np.allclose(value, dy_param_value[key]))
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -264,6 +264,7 @@ class TestImperativeResnet(unittest.TestCase): ...@@ -264,6 +264,7 @@ class TestImperativeResnet(unittest.TestCase):
)] = np_array )] = np_array
optimizer.minimize(avg_loss) optimizer.minimize(avg_loss)
resnet.clear_gradients()
dy_param_value = {} dy_param_value = {}
for param in fluid.default_main_program().global_block( for param in fluid.default_main_program().global_block(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册