diff --git a/paddle/fluid/imperative/layer.cc b/paddle/fluid/imperative/layer.cc index 83fc6ee2e299f5fa18d5cc6f220c0be6a66e709d..47488d4dea79f285769f29c93f7888a7f783f070 100644 --- a/paddle/fluid/imperative/layer.cc +++ b/paddle/fluid/imperative/layer.cc @@ -156,6 +156,8 @@ class Autograd { for (auto it : candidate->pre_ops_) { for (OpBase* pre_op : it.second) { if (!pre_op) continue; + VLOG(5) << "op dep " << candidate->op_desc_->Type() << " <---- " + << it.first << " <---- " << pre_op->op_desc_->Type(); if (visited.find(pre_op) == visited.end()) { visited.insert(pre_op); queue.push_back(pre_op); diff --git a/paddle/fluid/imperative/layer.h b/paddle/fluid/imperative/layer.h index dc97433a5102b39d03ea5cac3157c027f9d67c98..78205486c5534ac0c61cc6d545bdafa4dfc95695 100644 --- a/paddle/fluid/imperative/layer.h +++ b/paddle/fluid/imperative/layer.h @@ -28,6 +28,7 @@ #include "paddle/fluid/framework/var_desc.h" #include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/device_context.h" +#include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/imperative/type_defs.h" @@ -140,16 +141,24 @@ class VarBase { void RunBackward(); void TrackPreOp(OpBase* pre_op, const std::string& pre_op_out_name, - int pre_op_out_idx, bool stop_gradient) { + int pre_op_out_idx, bool pre_op_stop_gradient) { pre_op_ = pre_op; pre_op_out_name_ = pre_op_out_name; pre_op_out_idx_ = pre_op_out_idx; - stop_gradient_ = stop_gradient; + if (pre_op_stop_gradient) { + stop_gradient_ = pre_op_stop_gradient; + } } void ClearGradient() { - delete grads_; - grads_ = new VarBase(true); + VLOG(1) << "clear gradient of " << var_desc_->Name(); + if (grads_ && grads_->var_ && grads_->var_->IsInitialized()) { + auto grads_t = grads_->var_->GetMutable(); + operators::math::set_constant( + *(platform::DeviceContextPool::Instance().Get( + grads_->var_->Get().place())), + grads_t, 0.0); + } } framework::LoDTensor& GradValue(); diff --git a/paddle/fluid/imperative/tracer.cc b/paddle/fluid/imperative/tracer.cc index cd62807a5532e6b2309cb5a8f679c3097b51c9e9..51bbac6d2a1cf2bd64f3e1f2d420e104569273c8 100644 --- a/paddle/fluid/imperative/tracer.cc +++ b/paddle/fluid/imperative/tracer.cc @@ -84,11 +84,12 @@ void Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs, op->input_vars_ = inputs; for (auto it : op->input_vars_) { auto& invars = invars_map[it.first]; + invars.reserve(it.second.size()); for (VarBase* inp : it.second) { PADDLE_ENFORCE_NOT_NULL(inp->var_, "op %s input %s nullptr", op->op_desc_->Type(), inp->var_desc_->Name()); - invars.push_back(inp->var_); + invars.emplace_back(inp->var_); vars[inp->var_desc_->Name()] = inp; if (inp->PreOp()) { op->pre_ops_[it.first].push_back(inp->PreOp()); @@ -105,9 +106,10 @@ void Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs, for (auto it : op->output_vars_) { auto& outvars = outvars_map[it.first]; const std::vector& outputs = it.second; + outvars.reserve(outputs.size()); for (size_t i = 0; i < outputs.size(); ++i) { VarBase* out = outputs[i]; - outvars.push_back(out->var_); + outvars.emplace_back(out->var_); vars[out->var_desc_->Name()] = out; framework::VarDesc* var_desc = block->FindVar(out->var_desc_->Name()); diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index 2bdae60db347b3d42fded138a20a505486e48dbc..96587b6e904f681a71182ffdb03608b5edde5e46 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -445,11 +445,16 @@ class Variable(object): @property def _stop_gradient(self): - return self._ivar.stop_gradient + if _in_imperative_mode(): + return self._ivar.stop_gradient + else: + return self.stop_gradient @_stop_gradient.setter def _stop_gradient(self, s): - self._ivar.stop_gradient = s + if _in_imperative_mode(): + self._ivar.stop_gradient = s + self.stop_gradient = s @property def persistable(self): @@ -1310,6 +1315,9 @@ class Block(object): outputs=kwargs.get("outputs", None), attrs=kwargs.get("attrs", None)) self.ops.append(op) + + # TODO(minqiyang): add stop_gradient support in static mode too. + # currently, we only support stop_gradient in imperative mode. self._trace_op(op, kwargs.get("stop_gradient", False)) return op diff --git a/python/paddle/fluid/imperative/layers.py b/python/paddle/fluid/imperative/layers.py index f457f56203eb2c1da62f4d8ad8915c322c822e0a..71ff95bdea36967c1fa6b5c94cc7ca305e7a544a 100644 --- a/python/paddle/fluid/imperative/layers.py +++ b/python/paddle/fluid/imperative/layers.py @@ -15,6 +15,7 @@ import contextlib import sys import numpy as np +import collections from paddle.fluid import core from paddle.fluid import framework @@ -31,7 +32,23 @@ class Layer(core.Layer): self._dtype = dtype def parameters(self): - return [] + params = [] + for key in self.__dict__.keys(): + value = self.__dict__[key] + if isinstance(value, framework.Parameter): + params.append(value) + elif isinstance(value, core.Layer): + params.extend(value.parameters()) + elif isinstance(value, collections.Container): + if len(value) == 0: + continue + if isinstance(value[0], framework.Parameter): + params.extend(value) + elif isinstance(value[0], core.Layer): + for v in value: + params.extend(v.parameters()) + + return params def clear_gradients(self): for p in self.parameters(): diff --git a/python/paddle/fluid/imperative/nn.py b/python/paddle/fluid/imperative/nn.py index 140c0ff037d453641cc119301269121025e17cbd..dc90603c3716bef347a5212a3c6fb66522e49c14 100644 --- a/python/paddle/fluid/imperative/nn.py +++ b/python/paddle/fluid/imperative/nn.py @@ -332,21 +332,16 @@ class BatchNorm(layers.Layer): shape=param_shape, dtype=self._dtype, default_initializer=Constant(1.0)) - - # TODO(minqiyang): change stop_gradient sign to trainable to align with static graph - # # setting stop_gradient=True to reduce computation - # if use_global_stats and self._helper.param_attr.learning_rate == 0.: - # self._scale.stop_gradient = True + if use_global_stats and self._helper.param_attr.learning_rate == 0.: + self._scale._stop_gradient = True self._bias = self._helper.create_parameter( attr=self._helper.bias_attr, shape=param_shape, dtype=self._dtype, is_bias=True) - # TODO(minqiyang): change stop_gradient sign to trainable to align with static graph - # # setting stop_gradient=True to reduce computation - # if use_global_stats and self._helper.bias_attr.learning_rate == 0.: - # self._bias.stop_gradient = True + if use_global_stats and self._helper.bias_attr.learning_rate == 0.: + self._bias._stop_gradient = True self._mean = self._helper.create_parameter( attr=ParamAttr( @@ -356,7 +351,7 @@ class BatchNorm(layers.Layer): do_model_average=do_model_average_for_mean_and_var), shape=param_shape, dtype=self._dtype) - self._mean.stop_gradient = True + self._mean._stop_gradient = True self._variance = self._helper.create_parameter( attr=ParamAttr( @@ -366,7 +361,7 @@ class BatchNorm(layers.Layer): do_model_average=do_model_average_for_mean_and_var), shape=param_shape, dtype=self._dtype) - self._variance.stop_gradient = True + self._variance._stop_gradient = True self._in_place = in_place self._momentum = momentum diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py index 14f4276e2f4fc4a24d701ef05c94b88c4f0336da..e0e781a322b3eb68e3f54a66252a8d8b11a9a56f 100644 --- a/python/paddle/fluid/optimizer.py +++ b/python/paddle/fluid/optimizer.py @@ -387,7 +387,7 @@ class Optimizer(object): params_grads = [] for param in parameters: - if param.stop_gradient: + if param.stop_gradient or not param.trainable: continue # create gradient variable grad_var = Variable( diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index c23dfa01e76c21d0d162f2fed986e2eaf3a70a6d..7e693c6a41f71f11fd702e2cfc26aa4a21cd2de7 100644 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -85,6 +85,7 @@ list(REMOVE_ITEM TEST_OPS test_image_classification_resnet) list(REMOVE_ITEM TEST_OPS test_bilinear_interp_op) list(REMOVE_ITEM TEST_OPS test_nearest_interp_op) list(REMOVE_ITEM TEST_OPS test_imperative_resnet) +list(REMOVE_ITEM TEST_OPS test_imperative_optimizer) foreach(TEST_OP ${TEST_OPS}) py_test_modules(${TEST_OP} MODULES ${TEST_OP}) endforeach(TEST_OP) @@ -94,6 +95,8 @@ py_test_modules(test_bilinear_interp_op MODULES test_bilinear_interp_op SERIAL) py_test_modules(test_nearest_interp_op MODULES test_nearest_interp_op SERIAL) py_test_modules(test_imperative_resnet MODULES test_imperative_resnet ENVS FLAGS_cudnn_deterministic=1) +py_test_modules(test_imperative_optimizer MODULES test_imperative_optimizer ENVS + FLAGS_cudnn_deterministic=1) if(WITH_DISTRIBUTE) py_test_modules(test_dist_train MODULES test_dist_train SERIAL) set_tests_properties(test_listen_and_serv_op PROPERTIES TIMEOUT 20) diff --git a/python/paddle/fluid/tests/unittests/test_imperative_optimizer.py b/python/paddle/fluid/tests/unittests/test_imperative_optimizer.py index d0a5a883174cb33a035b344f9489b2ba02ba99f1..08b155acc657c3a4a73f5b1d72ac356fc7e83a58 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_optimizer.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_optimizer.py @@ -82,13 +82,14 @@ class MNIST(fluid.imperative.Layer): self._simple_img_conv_pool_2 = SimpleImgConvPool( 20, 50, 5, 2, 2, act="relu") - pool_2_shape = 50 * 8 * 8 + pool_2_shape = 50 * 4 * 4 SIZE = 10 scale = (2.0 / (pool_2_shape**2 * SIZE))**0.5 self._fc = FC(10, param_attr=fluid.param_attr.ParamAttr( initializer=fluid.initializer.NormalInitializer( - loc=0.0, scale=scale))) + loc=0.0, scale=scale)), + act="softmax") def forward(self, inputs): x = self._simple_img_conv_pool_1(inputs) @@ -98,9 +99,9 @@ class MNIST(fluid.imperative.Layer): class TestImperativeMnist(unittest.TestCase): - def test_mnist_cpu_float32(self): + def test_mnist_float32(self): seed = 90 - + batch_num = 2 with fluid.imperative.guard(): fluid.default_startup_program().random_seed = seed fluid.default_main_program().random_seed = seed @@ -112,15 +113,15 @@ class TestImperativeMnist(unittest.TestCase): dy_param_init_value = {} for batch_id, data in enumerate(train_reader()): - if batch_id >= 2: + if batch_id >= batch_num: break - x_data = np.array( + dy_x_data = np.array( [x[0].reshape(1, 28, 28) for x in data]).astype('float32') y_data = np.array([x[1] for x in data]).astype('int64').reshape( 128, 1) - img = to_variable(x_data) + img = to_variable(dy_x_data) label = to_variable(y_data) label._stop_gradient = True @@ -136,6 +137,7 @@ class TestImperativeMnist(unittest.TestCase): avg_loss._backward() sgd.minimize(avg_loss) + mnist.clear_gradients() dy_param_value = {} for param in fluid.default_main_program().global_block( ).all_parameters(): @@ -175,10 +177,10 @@ class TestImperativeMnist(unittest.TestCase): static_param_init_value[static_param_name_list[i]] = out[i] for batch_id, data in enumerate(train_reader()): - if batch_id >= 2: + if batch_id >= batch_num: break - x_data = np.array( + static_x_data = np.array( [x[0].reshape(1, 28, 28) for x in data]).astype('float32') y_data = np.array([x[1] for x in data]).astype('int64').reshape( [128, 1]) @@ -186,7 +188,7 @@ class TestImperativeMnist(unittest.TestCase): fetch_list = [avg_loss.name] fetch_list.extend(static_param_name_list) out = exe.run(fluid.default_main_program(), - feed={"pixel": x_data, + feed={"pixel": static_x_data, "label": y_data}, fetch_list=fetch_list) @@ -196,11 +198,12 @@ class TestImperativeMnist(unittest.TestCase): static_param_value[static_param_name_list[i - 1]] = out[i] for key, value in six.iteritems(static_param_init_value): - self.assertTrue( - np.allclose(value.all(), dy_param_init_value[key].all())) - self.assertTrue(np.allclose(static_out.all(), dy_out.all())) + self.assertTrue(np.allclose(value, dy_param_init_value[key])) + + self.assertTrue(np.allclose(static_out, dy_out)) + for key, value in six.iteritems(static_param_value): - self.assertTrue(np.allclose(value.all(), dy_param_value[key].all())) + self.assertTrue(np.allclose(value, dy_param_value[key])) if __name__ == '__main__': diff --git a/python/paddle/fluid/tests/unittests/test_imperative_resnet.py b/python/paddle/fluid/tests/unittests/test_imperative_resnet.py index 87a72dd04e376cf9225e275d862b0cbbb9774e2c..c27fd0b8024a8fa3310a62de34299fb621e2902f 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_resnet.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_resnet.py @@ -264,6 +264,7 @@ class TestImperativeResnet(unittest.TestCase): )] = np_array optimizer.minimize(avg_loss) + resnet.clear_gradients() dy_param_value = {} for param in fluid.default_main_program().global_block(