From 68e9b841ab69f8484944e77c486aa226e12ed5f2 Mon Sep 17 00:00:00 2001
From: minqiyang
Date: Thu, 27 Dec 2018 10:28:30 +0800
Subject: [PATCH] Add support for optimizer

---
 paddle/fluid/imperative/layer.cc                   |  2 +-
 paddle/fluid/imperative/layer.h                    |  9 +++
 paddle/fluid/imperative/tracer.h                   |  8 ++-
 paddle/fluid/operators/optimizers/sgd_op.h         |  5 ++
 paddle/fluid/pybind/pybind.cc                      | 13 +++++
 python/paddle/fluid/framework.py                   | 28 ++++++++-
 python/paddle/fluid/initializer.py                 |  1 +
 python/paddle/fluid/layer_helper.py                |  2 +-
 python/paddle/fluid/layers/tensor.py               | 57 ++++++++++++-------
 python/paddle/fluid/optimizer.py                   | 45 +++++++++++----
 .../tests/unittests/test_imperative_mnist.py       |  7 ++-
 11 files changed, 139 insertions(+), 38 deletions(-)

diff --git a/paddle/fluid/imperative/layer.cc b/paddle/fluid/imperative/layer.cc
index fcddcc4ed4..2c615275d1 100644
--- a/paddle/fluid/imperative/layer.cc
+++ b/paddle/fluid/imperative/layer.cc
@@ -104,7 +104,7 @@ class Autograd {
 framework::Variable* CreateVariable(const std::string& name,
                                     const framework::DDim& dim, float val,
                                     framework::Scope* scope,
-                                    bool random_name = true) {
+                                    bool random_name = false) {
   std::string varname = name;
   if (random_name) {
     std::mt19937 rng;
diff --git a/paddle/fluid/imperative/layer.h b/paddle/fluid/imperative/layer.h
index 90cc3ae1a9..56112f9a90 100644
--- a/paddle/fluid/imperative/layer.h
+++ b/paddle/fluid/imperative/layer.h
@@ -45,6 +45,15 @@ class VarBase {
 
   framework::LoDTensor& Grad();
 
+  inline framework::Variable* GradVar() { return grads_; }
+
+  inline std::string GradName() const {
+    PADDLE_ENFORCE(
+        var_desc_,
+        "Couldn't get gradient variable's name, please call backward() first");
+    return string::Sprintf("%s@IGrad", var_desc_->Name());
+  }
+
   OpBase* pre_op_;
   int pre_op_out_idx_;
 
diff --git a/paddle/fluid/imperative/tracer.h b/paddle/fluid/imperative/tracer.h
index f6dac762fd..c885f39ced 100644
--- a/paddle/fluid/imperative/tracer.h
+++ b/paddle/fluid/imperative/tracer.h
@@ -52,7 +52,7 @@ class Tracer {
              const std::vector<VarBase*>& outputs, framework::BlockDesc* block,
              const bool stop_gradient) {
     framework::OpDesc* op_desc = op->op_desc_;
-    VLOG(3) << "tracer tracing " << op_desc->Type();
+    LOG(ERROR) << "tracer tracing " << op_desc->Type();
     op_desc->InferShape(*block);
     op_desc->InferVarType(block);
     std::unique_ptr<framework::OperatorBase> op_base =
@@ -61,7 +61,10 @@ class Tracer {
     *op->input_vars_ = inputs;
     for (VarBase* input : inputs) {
       const std::string vname = input->var_desc_->Name();
+      LOG(ERROR) << "input: " << vname;
+      LOG(ERROR) << "input var: " << input->var_;
       framework::Variable* var = root_scope_->Var(vname);
+      LOG(ERROR) << "var_ in tracer pointer: " << var;
       input->var_ = var;
       if (!var->IsInitialized()) {
         framework::VarDesc* var_desc = block->FindVar(vname);
@@ -84,6 +87,7 @@ class Tracer {
     *op->output_vars_ = outputs;
     for (size_t i = 0; i < outputs.size(); ++i) {
       const std::string vname = outputs[i]->var_desc_->Name();
+      LOG(ERROR) << "output name: " << vname;
       framework::Variable* var = root_scope_->Var(vname);
       if (!var->IsInitialized()) {
         framework::VarDesc* var_desc = block->FindVar(vname);
@@ -98,7 +102,7 @@ class Tracer {
       outputs[i]->pre_op_out_idx_ = i;
     }
 
-    VLOG(3) << "tracer running " << op_desc->Type();
+    LOG(ERROR) << "tracer running " << op_desc->Type();
     op_base->Run(*root_scope_, platform::CPUPlace());
     if (!stop_gradient) {
       framework::OpDesc* grad_op_desc;
diff --git a/paddle/fluid/operators/optimizers/sgd_op.h b/paddle/fluid/operators/optimizers/sgd_op.h
index 98bae5e1d3..ec4218497a 100644
--- a/paddle/fluid/operators/optimizers/sgd_op.h
+++ b/paddle/fluid/operators/optimizers/sgd_op.h
@@ -29,6 +29,8 @@ class SGDOpKernel : public framework::OpKernel<T> {
     const auto *param_var = ctx.InputVar("Param");
     const auto *grad_var = ctx.InputVar("Grad");
 
+    LOG(ERROR) << "grad_var: " << grad_var;
+
     if (param_var->IsType<framework::LoDTensor>()) {
       const auto *param = ctx.Input<framework::Tensor>("Param");
       auto *param_out = ctx.Output<framework::Tensor>("ParamOut");
@@ -39,8 +41,11 @@ class SGDOpKernel : public framework::OpKernel<T> {
       const auto *grad = ctx.Input<framework::Tensor>("Grad");
 
       auto p = framework::EigenVector<T>::Flatten(*param);
+      LOG(ERROR) << "param flattened";
       auto g = framework::EigenVector<T>::Flatten(*grad);
+      LOG(ERROR) << "grad flattened";
       auto o = framework::EigenVector<T>::Flatten(*param_out);
+      LOG(ERROR) << "paramout flattened";
       auto *lr = learning_rate->data<T>();
 
       o = p - lr[0] * g;
diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc
index 9608aa9d69..c690d1b8b3 100644
--- a/paddle/fluid/pybind/pybind.cc
+++ b/paddle/fluid/pybind/pybind.cc
@@ -117,10 +117,23 @@ PYBIND11_MODULE(core, m) {
            [](imperative::VarBase &self, framework::Scope *scope) {
              self.RunBackward(scope);
            })
+      .def("_grad_var",
+           [](const imperative::VarBase &self) {
+             LOG(ERROR) << "grad_var_ pointer: " << self.grads_;
+             return self.grads_;
+           },
+           py::return_value_policy::reference)
+      .def("_grad_name", &imperative::VarBase::GradName)
       .def("_grad", &imperative::VarBase::Grad)
+      .def("_print_var_pointer",
+           [](const imperative::VarBase &self) {
+             LOG(ERROR) << self.var_desc_->Name()
+                        << " print_var pointer: " << self.var_;
+           })
       .def_property("value",
                     [](const imperative::VarBase &self) { return self.var_; },
                     [](imperative::VarBase &self, framework::Variable *var) {
+                      LOG(ERROR) << "set var to pointer: " << var;
                       self.var_ = var;
                     },
                     py::return_value_policy::reference)
diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py
index 3dc23bd060..9073fa79b0 100644
--- a/python/paddle/fluid/framework.py
+++ b/python/paddle/fluid/framework.py
@@ -19,7 +19,6 @@ import contextlib
 import os
 import re
 import six
-import sys
 
 import numpy as np
 
@@ -369,6 +368,7 @@ class Variable(object):
         self._ivar.stop_gradient = stop_gradient
 
     def _numpy(self):
+        print("get_variable_tensor", self.desc.name())
         scope = _imperative_tracer().get_scope()
         tensor = core.get_variable_tensor(scope, self.desc.name())
         return np.array(tensor)
@@ -380,6 +380,14 @@ class Variable(object):
     def _gradient(self):
         return np.array(self._ivar._grad())
 
+    @property
+    def _value(self):
+        return self._ivar.value
+
+    @_value.setter
+    def _value(self, v):
+        self._ivar.value = v
+
     def __str__(self):
         return self.to_string(True)
 
@@ -632,6 +640,7 @@ class Operator(object):
 
         if inputs is not None:
             for in_proto in proto.inputs:
+                print("create op: find_name", in_proto.name)
                 found = find_name(inputs, in_proto.name)
                 assert found or in_proto.dispensable, "Input {} not found".format(
                     in_proto.name)
@@ -695,9 +704,11 @@ class Operator(object):
                     self._update_desc_attr(attr_name, attr_val)
 
         self.desc.check_attrs()
+
         if self._has_kernel(type):
             self.desc.infer_var_type(self.block.desc)
             self.desc.infer_shape(self.block.desc)
+
         if _in_imperative_mode():
             self.iop = core.OpBase()
             self.iop.desc = self.desc
@@ -1167,6 +1178,7 @@ class Block(object):
     def create_var(self, *args, **kwargs):
         var = Variable(block=self, *args, **kwargs)
         if 'initializer' in kwargs:
+            print("initializer, ", type(kwargs['initializer']))
             kwargs['initializer'](var, self)
         return var
 
@@ -1281,6 +1293,16 @@ class Block(object):
         """
         op_desc = self.desc.append_op()
         op = Operator(block=self, desc=op_desc, *args, **kwargs)
+        print("op inputs: ", [v._numpy() for v in op.inputs])
+        print("op inputs: ", [v for v in op.inputs])
+        import sys
+        sys.stdout.flush()
+        for v in op.inputs:
+            v._ivar._print_var_pointer()
+        print("print var pointer end")
+        import sys
+        sys.stdout.flush()
+
         if _in_imperative_mode():
             _imperative_tracer().trace(op.iop, [v._ivar for v in op.inputs],
                                        [v._ivar for v in op.outputs], self.desc,
@@ -1338,6 +1360,10 @@ class Block(object):
             _imperative_tracer().trace(op.iop, [v._ivar for v in op.inputs],
                                        [v._ivar for v in op.outputs], self.desc,
                                        kwargs.get("stop_gradient", False))
+            print([v.name for v in op.outputs])
+            for v in op.outputs:
+                v._ivar._print_var_pointer()
+            print("fill_constant end")
         self.ops.insert(0, op)
         return op
 
diff --git a/python/paddle/fluid/initializer.py b/python/paddle/fluid/initializer.py
index 7acaed2250..fe8357aa06 100644
--- a/python/paddle/fluid/initializer.py
+++ b/python/paddle/fluid/initializer.py
@@ -153,6 +153,7 @@ class ConstantInitializer(Initializer):
         assert isinstance(var, framework.Variable)
         assert isinstance(block, framework.Block)
         # Initialization Ops should be prepended and not appended
+        print("fill_constant")
         op = block._prepend_op(
             type="fill_constant",
             outputs={"Out": var},
diff --git a/python/paddle/fluid/layer_helper.py b/python/paddle/fluid/layer_helper.py
index eba5417723..f3413d7296 100644
--- a/python/paddle/fluid/layer_helper.py
+++ b/python/paddle/fluid/layer_helper.py
@@ -369,7 +369,7 @@ class LayerHelper(object):
 
     def set_variable_initializer(self, var, initializer):
         assert isinstance(var, Variable)
-        self.startup_program.global_block().create_var(
+        return self.startup_program.global_block().create_var(
             name=var.name,
             type=var.type,
             dtype=var.dtype,
diff --git a/python/paddle/fluid/layers/tensor.py b/python/paddle/fluid/layers/tensor.py
index 49a486cf0c..a7565aa108 100644
--- a/python/paddle/fluid/layers/tensor.py
+++ b/python/paddle/fluid/layers/tensor.py
@@ -20,6 +20,7 @@ from ..framework import convert_np_dtype_to_dtype_
 from ..framework import Variable
 from ..initializer import Constant, force_init_on_cpu
 from ..core import VarDesc
+from ..imperative import base as imperative_base
 from .layer_function_generator import templatedoc
 import numpy
 
@@ -104,15 +105,15 @@ def create_global_var(shape,
 
     Args:
         shape(list[int]): shape of the variable
-        value(float): the value of the variable. The new created 
+        value(float): the value of the variable. The new created
             variable will be filled with it.
        dtype(string): data type of the variable
-        persistable(bool): if this variable is persistable. 
+        persistable(bool): if this variable is persistable.
            Default: False
-        force_cpu(bool): force this variable to be on CPU. 
+        force_cpu(bool): force this variable to be on CPU.
            Default: False
-        name(str|None): The name of the variable. If set to None the variable 
-            name will be generated automatically. 
+        name(str|None): The name of the variable. If set to None the variable
+            name will be generated automatically.
            Default: None
 
     Returns:
@@ -121,21 +122,33 @@ def create_global_var(shape,
     Examples:
         .. code-block:: python
 
-            var = fluid.create_global_var(shape=[2,3], value=1.0, dtype='float32', 
+            var = fluid.create_global_var(shape=[2,3], value=1.0, dtype='float32',
                                  persistable=True, force_cpu=True, name='new_var')
     """
     helper = LayerHelper("global_var", **locals())
     var = helper.create_global_variable(
-        dtype=dtype, shape=shape, persistable=persistable, name=name)
-    helper.set_variable_initializer(
-        var, initializer=Constant(
-            value=float(value), force_cpu=force_cpu))
+        dtype=dtype,
+        shape=shape,
+        persistable=persistable,
+        name=name,
+        stop_gradient=True)
+    print("set_variable_initializer, ", var.name)
+    if imperative_base.enabled():
+        var = helper.set_variable_initializer(
+            var, initializer=Constant(
+                value=float(value), force_cpu=force_cpu))
+        print("get var", var)
+    else:
+        helper.set_variable_initializer(
+            var, initializer=Constant(
+                value=float(value), force_cpu=force_cpu))
+
     return var
 
 
 def cast(x, dtype):
     """
-    This layer takes in the Variable :attr:`x` with :attr:`x.dtype` and casts 
+    This layer takes in the Variable :attr:`x` with :attr:`x.dtype` and casts
     it to the output with :attr:`dtype`.
 
     Args:
@@ -199,9 +212,9 @@ def tensor_array_to_tensor(input, axis=1, name=None):
     and returns that as the output.
 
     A simple example as below:
-    
+
     .. code-block:: text
-    
+
         Given:
 
             input.data = {[[0.6, 0.1, 0.3],
                            [0.9, 0.9, 0.1]],
                           [[1.3],
@@ -210,9 +223,9 @@ def tensor_array_to_tensor(input, axis=1, name=None):
                            [1.8]],
                           [[2.3, 2.1],
                            [2.5, 2.4]]}
-        
+
         axis = 1
-        
+
     Then:
 
         output.data = [[0.6, 0.1, 0.3, 1.3, 2.3, 2.1],
@@ -493,12 +506,12 @@ def argmax(x, axis=0):
 
 def argsort(input, axis=-1, name=None):
     """
-    Performs sorting on the input Variable along the given axis, and outputs 
-    sorted data Varibale and its corresponding index Variable with the same 
+    Performs sorting on the input Variable along the given axis, and outputs
+    sorted data Varibale and its corresponding index Variable with the same
     shape as :attr:`input`.
 
     .. code-block:: text
-    
+
         For example, the given axis is -1 and the input Variable

            input = [[0.15849551, 0.45865775, 0.8563702 ],
                     [0.37920311, 0.14170343, 0.06454497]],
@@ -511,15 +524,15 @@ def argsort(input, axis=-1, name=None):
        and the sorted indices along the given axis turn outs to be
 
-        indices = [[0, 1, 2], 
+        indices = [[0, 1, 2],
                   [0, 2, 1]]
 
     Args:
        input(Variable): The input Variable for sorting.
-        axis(int): The axis along which to sort the input Variable. When 
-            :attr:`axis` < 0, the actual axis will be :attr:`axis` + 
+        axis(int): The axis along which to sort the input Variable. When
+            :attr:`axis` < 0, the actual axis will be :attr:`axis` +
            rank(:attr:`input`). Default -1, the last dimension.
-        name(str|None): (optional) A name for this layer. If set None, the 
+        name(str|None): (optional) A name for this layer. If set None, the
            layer will be named automatically.
     Returns:
 
diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py
index 59c22d4e49..7e90d47870 100644
--- a/python/paddle/fluid/optimizer.py
+++ b/python/paddle/fluid/optimizer.py
@@ -30,6 +30,7 @@ from .initializer import Constant
 from .layer_helper import LayerHelper
 from .layers import ops
 from .regularizer import append_regularization_ops
+from .imperative import base as imperative_base
 
 __all__ = [
     'SGD', 'Momentum', 'Adagrad', 'Adam', 'Adamax', 'DecayedAdagrad', 'Ftrl',
@@ -108,6 +109,7 @@ class Optimizer(object):
         # create learning rate variable for every parameter
         param = param_and_grad[0]
         param_lr = param.optimize_attr['learning_rate']
+        print("param_lr: ", param_lr, self._global_learning_rate()._numpy())
         if type(param_lr) == Variable:
             return param_lr
         else:
@@ -301,19 +303,38 @@ class Optimizer(object):
         This method combines interface `append_backward()` and
         `create_optimization_pass()` into one.
         """
-        params_grads = append_backward(loss, parameter_list, no_grad_set,
-                                       [error_clip_callback])
+        if imperative_base.enabled:
+            if parameter_list is not None:
+                params_grads = parameter_list
+            else:
+                program = loss.block.program
+                parameters = program.global_block().all_parameters()
+                params_grads = []
+                for param in parameters:
+                    grad_var = Variable(
+                        block=loss.block,
+                        name=param._ivar._grad_name(),
+                        stop_gradient=True)
+                    grad_var._value = param._ivar._grad_var()
+                    print("create grad var: ", grad_var.name)
+                    print("grad_var value: ", grad_var._numpy())
+                    import sys
+                    sys.stdout.flush()
+                    params_grads.append((param, grad_var))
+        else:
+            params_grads = append_backward(loss, parameter_list, no_grad_set,
+                                           [error_clip_callback])
 
-        params_grads = sorted(params_grads, key=lambda x: x[0].name)
+            params_grads = sorted(params_grads, key=lambda x: x[0].name)
 
-        params_grads, table_param_and_grad, table_optimize_op = \
-            self._process_distribute_lookuptable(params_grads, loss, startup_program)
+            params_grads, table_param_and_grad, table_optimize_op = \
+                self._process_distribute_lookuptable(params_grads, loss, startup_program)
 
-        params_grads = append_gradient_clip_ops(params_grads)
+            params_grads = append_gradient_clip_ops(params_grads)
 
-        # Add regularization if any
-        params_grads = append_regularization_ops(params_grads,
-                                                 self.regularization)
+            # Add regularization if any
+            params_grads = append_regularization_ops(params_grads,
+                                                     self.regularization)
 
         optimize_ops = self._create_optimization_pass(params_grads, loss,
                                                       startup_program)
@@ -356,6 +377,10 @@ class SGDOptimizer(Optimizer):
 
     def _append_optimize_op(self, block, param_and_grad):
         assert isinstance(block, framework.Block)
+        print("append sgd")
+        import sys
+        sys.stdout.flush()
+
         # create the optimize op
         sgd_op = block.append_op(
             type=self.type,
@@ -477,7 +502,7 @@ class LarsMomentumOptimizer(Optimizer):
         regularization: A Regularizer, such as
                         fluid.regularizer.L2DecayRegularizer.
         name: A optional name prefix.
-    
+
     Examples:
         .. code-block:: python
 
diff --git a/python/paddle/fluid/tests/unittests/test_imperative_mnist.py b/python/paddle/fluid/tests/unittests/test_imperative_mnist.py
index 9d1e079998..12d605316c 100644
--- a/python/paddle/fluid/tests/unittests/test_imperative_mnist.py
+++ b/python/paddle/fluid/tests/unittests/test_imperative_mnist.py
@@ -18,6 +18,7 @@ import numpy as np
 
 import paddle.fluid as fluid
 from paddle.fluid import core
+from paddle.fluid.optimizer import SGDOptimizer
 from paddle.fluid.imperative.nn import Conv2D, Pool2D, FC
 from paddle.fluid.imperative.base import to_variable
 
@@ -119,7 +120,11 @@ class TestImperativeMnist(unittest.TestCase):
             out._backward()
             filter_grad = mnist._simple_img_conv_pool_1._conv2d._filter_param._gradient(
             )
-            print(filter_grad)
+            # print(filter_grad)
+
+            sgd = SGDOptimizer(learning_rate=1e-3)
+            sgd.minimize(out)
+
             # np_inp = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)
             # with fluid.imperative.guard():
             #     mlp = MLP()
--
GitLab
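
Usage sketch (not part of the patch): the test change above is the only end-to-end exercise of
the new imperative optimizer path, so a rough standalone version of that flow is shown below.
It only relies on names visible in this patch (fluid.imperative.guard, to_variable, FC,
Variable._backward, SGDOptimizer.minimize); the FC configuration and the reduce_mean loss are
hypothetical placeholders, and exact signatures may differ in this WIP state of the API.

    import numpy as np
    import paddle.fluid as fluid
    from paddle.fluid.imperative.nn import FC
    from paddle.fluid.imperative.base import to_variable
    from paddle.fluid.optimizer import SGDOptimizer

    with fluid.imperative.guard():
        fc = FC(size=10)                        # hypothetical layer setup
        sgd = SGDOptimizer(learning_rate=1e-3)

        x = to_variable(np.random.rand(4, 8).astype('float32'))
        loss = fluid.layers.reduce_mean(fc(x))  # any scalar output would do here

        loss._backward()                        # imperative backward pass
        sgd.minimize(loss)                      # per this patch: pairs each parameter
                                                # with its _grad_name()/_grad_var() gradient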