diff --git a/paddle/fluid/imperative/layer.cc b/paddle/fluid/imperative/layer.cc
index d0aaa00c49ff1efa9d7a95b5fc63abc4fec4a4be..7594670cd2608802bdf41682ef5724a7a965d754 100644
--- a/paddle/fluid/imperative/layer.cc
+++ b/paddle/fluid/imperative/layer.cc
@@ -44,7 +44,7 @@ void AddTo(Variable* src, Variable* dst) {
                  src_tensor->numel());
   float* dst_data = dst_tensor->mutable_data<float>(platform::CPUPlace());
   const float* src_data = src_tensor->data<float>();
-  for (size_t i = 0; i < src_tensor->numel(); ++i) {
+  for (int64_t i = 0; i < src_tensor->numel(); ++i) {
     dst_data[i] += src_data[i];
   }
 }
@@ -117,9 +117,9 @@ class Autograd {
   }
 };
 
-framework::LoDTensor& VarBase::Grad() {
+framework::LoDTensor& VarBase::GradValue() {
   VLOG(3) << "get var grad " << var_desc_->Name();
-  return *grads_->GetMutable<framework::LoDTensor>();
+  return *(grads_->var_->GetMutable<framework::LoDTensor>());
 }
 
 std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad() {
@@ -183,7 +183,7 @@ void VarBase::RunBackward() {
   if (!pre_op_) return;
   VLOG(3) << "start backward";
 
-  auto grads_t = grads_->GetMutable<framework::LoDTensor>();
+  auto grads_t = grads_->var_->GetMutable<framework::LoDTensor>();
   float* data = grads_t->mutable_data<float>(platform::CPUPlace());
   std::fill(data, data + grads_t->numel(), 1.0);
 
@@ -209,7 +209,7 @@ std::vector<VarBase*> PyLayer::Apply(int func_id,
   std::vector<Variable*> outvars = CallPythonFunc(py_funcs_[func_id], invars);
   std::vector<VarBase*> ret;
   for (Variable* v : outvars) {
-    ret.push_back(new VarBase(v, new Variable()));
+    ret.push_back(new VarBase(v, new VarBase(true)));
   }
   return ret;
 }
diff --git a/paddle/fluid/imperative/layer.h b/paddle/fluid/imperative/layer.h
index 4be0614c7e9b95d03db6b3c5bb1c4e9f978cc91b..86c2dc3fa4a7d03aa8f0a89a25c17656e1cd708c 100644
--- a/paddle/fluid/imperative/layer.h
+++ b/paddle/fluid/imperative/layer.h
@@ -17,13 +17,14 @@
 #include <map>
 #include <string>
 #include <vector>
 
-#include "pybind11/pybind11.h"
-#include "Python.h"
 #include "paddle/fluid/framework/op_desc.h"
 #include "paddle/fluid/framework/operator.h"
 #include "paddle/fluid/framework/var_desc.h"
 #include "paddle/fluid/platform/enforce.h"
+#include "pybind11/pybind11.h"
+
+#include "paddle/fluid/imperative/type_defs.h"
 
 namespace paddle {
 namespace imperative {
@@ -85,13 +86,19 @@ class PreparedOp {
 
 class OpBase;
 
+/* The wrapper for Variable which holds a Variable and a VarBase of its
+ * gradient. This object should be managed totally by the Python interpreter.
+ *
+ * Nearly all interfaces should be implemented in C++.
+ */
 class VarBase {
  public:
-  VarBase() : VarBase(new framework::Variable(), new framework::Variable()) {}
+  VarBase() : VarBase(new framework::Variable(), new VarBase(true)) {}
 
   // Owns `var` and `grad`
-  VarBase(framework::Variable* var, framework::Variable* grad)
+  VarBase(framework::Variable* var, VarBase* grad)
       : pre_op_(nullptr),
+        pre_op_out_name_(),
         pre_op_out_idx_(-1),
         var_desc_(nullptr),
         var_(var),
@@ -100,17 +107,26 @@ class VarBase {
 
   explicit VarBase(bool stop_gradient)
       : pre_op_(nullptr),
+        pre_op_out_name_(),
        pre_op_out_idx_(-1),
         var_desc_(nullptr),
         var_(new framework::Variable()),
-        grads_(new framework::Variable()),
+        grads_(stop_gradient ? nullptr : new VarBase(true)),
         stop_gradient_(stop_gradient) {}
 
-  virtual ~VarBase() {}
+  virtual ~VarBase() {
+    if (var_) {
+      delete var_;
+    }
+
+    if (grads_) {
+      delete grads_;
+    }
+  }
 
   void RunBackward();
 
-  framework::LoDTensor& Grad();
+  framework::LoDTensor& GradValue();
 
   inline std::string GradName() const {
     PADDLE_ENFORCE(
@@ -124,12 +140,16 @@ class VarBase {
   int pre_op_out_idx_;
 
   framework::VarDesc* var_desc_;
+
   framework::Variable* var_;
-  framework::Variable* grads_;
+  VarBase* grads_;
 
   bool stop_gradient_;
 };
 
+/* The wrapper for OpDesc which holds an OpDesc and the OpDesc of its
+ * gradient. This object should be managed totally by the Python interpreter.
+ */
 class OpBase {
  public:
   OpBase()
@@ -153,13 +173,13 @@ class OpBase {
   framework::OpDesc* grad_op_desc_;
   int backward_id_;
 
-  std::map<std::string, std::vector<VarBase*>> input_vars_;
-  std::map<std::string, std::vector<VarBase*>> output_vars_;
-  std::map<std::string, std::vector<OpBase*>> pre_ops_;
+  VarBasePtrMap input_vars_;
+  VarBasePtrMap output_vars_;
+  OpBasePtrMap pre_ops_;
   std::map<std::string, std::vector<int>> pre_ops_out_idx_;
 
-  std::map<std::string, std::vector<framework::Variable*>> grad_input_vars_;
-  std::map<std::string, std::vector<framework::Variable*>> grad_output_vars_;
+  framework::VariableValueMap grad_input_vars_;
+  framework::VariableValueMap grad_output_vars_;
 
   framework::BlockDesc* block_;
 };
diff --git a/paddle/fluid/imperative/tracer.cc b/paddle/fluid/imperative/tracer.cc
index f64f9e72c4a23528948183b909d65e90783a4463..a01225ccee4a82f77ec2a23df75d1cf7b719bdb7 100644
--- a/paddle/fluid/imperative/tracer.cc
+++ b/paddle/fluid/imperative/tracer.cc
@@ -15,5 +15,199 @@
 #include "paddle/fluid/imperative/tracer.h"
 
 namespace paddle {
-namespace imperative {}  // namespace imperative
+namespace imperative {
+
+void CreateGradOp(const framework::OpDesc& op_desc,
+                  const std::unordered_set<std::string>& no_grad_set,
+                  const std::vector<framework::BlockDesc*>& grad_sub_block,
+                  framework::OpDesc** grad_op_desc,
+                  std::unordered_map<std::string, std::string>* grad_to_var) {
+  std::vector<std::unique_ptr<framework::OpDesc>> grad_op_descs =
+      framework::OpInfoMap::Instance()
+          .Get(op_desc.Type())
+          .GradOpMaker()(op_desc, no_grad_set, grad_to_var, grad_sub_block);
+  PADDLE_ENFORCE(grad_op_descs.size() == 1, "Only support 1 grad op now.");
+  // TODO(panyx0718): Leak?
+  *grad_op_desc = grad_op_descs[0].release();
+}
+
+void InitVar(framework::Variable* var, framework::Variable* grad_var) {
+  auto& var_t = var->Get<framework::LoDTensor>();
+  float* data =
+      grad_var->GetMutable<framework::LoDTensor>()->mutable_data<float>(
+          var_t.dims(), platform::CPUPlace());
+  std::fill(data, data + var_t.numel(), 0.0);
+}
+
+void Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
+                   const VarBasePtrMap& outputs, framework::BlockDesc* block,
+                   const bool stop_gradient) {
+  std::map<std::string, VarBase*> vars;
+
+  framework::OpDesc* op_desc = op->op_desc_;
+  VLOG(3) << "tracer tracing " << op_desc->Type();
+  op_desc->InferShape(*block);
+  op_desc->InferVarType(block);
+  std::unique_ptr<framework::OperatorBase> op_base =
+      framework::OpRegistry::CreateOp(*op_desc);
+
+  framework::VariableValueMap invars_map;
+  framework::VariableValueMap outvars_map;
+
+  op->input_vars_ = inputs;
+  for (auto it : op->input_vars_) {
+    auto& invars = invars_map[it.first];
+    for (VarBase* inp : it.second) {
+      PADDLE_ENFORCE_NOT_NULL(inp->var_, "op %s input %s nullptr",
+                              op->op_desc_->Type(), inp->var_desc_->Name());
+
+      invars.push_back(inp->var_);
+      vars[inp->var_desc_->Name()] = inp;
+      if (inp->pre_op_) {
+        op->pre_ops_[it.first].push_back(inp->pre_op_);
+        op->pre_ops_out_idx_[it.first].push_back(inp->pre_op_out_idx_);
+      } else {
+        op->pre_ops_[it.first].push_back(nullptr);
+      }
+      VLOG(3) << "input vname " << inp->var_desc_->Name() << " "
+              << inp->var_->IsInitialized();
+    }
+  }
+
+  op->output_vars_ = outputs;
+  for (auto it : op->output_vars_) {
+    auto& outvars = outvars_map[it.first];
+    const std::vector<VarBase*>& outputs = it.second;
+    for (size_t i = 0; i < outputs.size(); ++i) {
+      VarBase* out = outputs[i];
+      outvars.push_back(out->var_);
+      vars[out->var_desc_->Name()] = out;
+
+      framework::VarDesc* var_desc = block->FindVar(out->var_desc_->Name());
+      if (var_desc->GetType() == framework::proto::VarType::LOD_TENSOR) {
+        out->var_->GetMutable<framework::LoDTensor>();
+      } else {
+        LOG(ERROR) << "tracer doesn't support yet";
+      }
+      out->stop_gradient_ = stop_gradient;
+      out->pre_op_ = op;
+      out->pre_op_out_name_ = it.first;
+      out->pre_op_out_idx_ = i;
+
+      VLOG(3) << "output vname " << out->var_desc_->Name() << " "
+              << out->var_->IsInitialized();
+    }
+  }
+
+  VLOG(3) << "tracer running " << op_desc->Type();
+  framework::RuntimeContext ctx(invars_map, outvars_map);
+
+  // TODO(panyx0718): Cache p.
+  framework::OperatorWithKernel* op_kernel =
+      dynamic_cast<framework::OperatorWithKernel*>(op_base.get());
+  PADDLE_ENFORCE_NOT_NULL(op_kernel, "only support op with kernel");
+
+  framework::Scope scope;
+  platform::CPUPlace place;
+  PreparedOp p = PreparedOp::Prepare(ctx, *op_kernel, place);
+  p.op.RuntimeInferShape(scope, place, ctx);
+  p.func(framework::ExecutionContext(p.op, scope, *p.dev_ctx, p.ctx));
+
+  if (!stop_gradient) {
+    framework::OpDesc* grad_op_desc;
+    // TODO(panyx): Is this leaked?
+    std::unique_ptr<std::unordered_map<std::string, std::string>> grad_to_var(
+        new std::unordered_map<std::string, std::string>());
+    CreateGradOp(*op_desc, {}, {block}, &grad_op_desc, grad_to_var.get());
+    op->grad_op_desc_ = grad_op_desc;
+
+    for (auto it : grad_op_desc->Inputs()) {
+      auto& grad_in_vars = op->grad_input_vars_[it.first];
+      for (const std::string& grad_invar : it.second) {
+        block->FindRecursiveOrCreateVar(grad_invar);
+        auto var_it = grad_to_var->find(grad_invar);
+        if (var_it == grad_to_var->end()) {
+          auto fwd_var_it = vars.find(grad_invar);
+          PADDLE_ENFORCE(fwd_var_it != vars.end());
+          // Forward inputs or outputs.
+          grad_in_vars.push_back(fwd_var_it->second->var_);
+        } else {
+          VarBase* var = vars[var_it->second];
+          if (!var->grads_->var_->IsInitialized()) {
+            InitVar(var->var_, var->grads_->var_);
+          }
+          // Douts.
+          grad_in_vars.push_back(var->grads_->var_);
+        }
+      }
+    }
+
+    for (auto it : grad_op_desc->Outputs()) {
+      auto& grad_out_vars = op->grad_output_vars_[it.first];
+      for (const std::string& grad_outvar : it.second) {
+        block->FindRecursiveOrCreateVar(grad_outvar);
+        auto var_it = grad_to_var->find(grad_outvar);
+        PADDLE_ENFORCE(var_it != grad_to_var->end());
+        VarBase* var = vars[var_it->second];
+        if (!var->grads_->var_->IsInitialized()) {
+          InitVar(var->var_, var->grads_->var_);
+        }
+        grad_out_vars.push_back(var->grads_->var_);
+      }
+    }
+  }
+
+  op->block_ = block;
+}
+
+std::vector<VarBase*> Tracer::PyTrace(OpBase* op,
+                                      const std::vector<VarBase*>& inputs,
+                                      bool stop_gradient) {
+  VLOG(3) << "py_trace";
+  op->input_vars_["X"] = inputs;
+  op->output_vars_["Out"] = PyLayer::Apply(op->forward_id_, inputs);
+  for (VarBase* inp : inputs) {
+    if (inp->pre_op_) {
+      op->pre_ops_["X"].push_back(inp->pre_op_);
+      op->pre_ops_out_idx_["X"].push_back(inp->pre_op_out_idx_);
+    } else {
+      op->pre_ops_["X"].push_back(nullptr);
+    }
+  }
+
+  auto& outputs = op->output_vars_["Out"];
+  for (size_t i = 0; i < outputs.size(); ++i) {
+    VarBase* out = outputs[i];
+    out->stop_gradient_ = stop_gradient;
+    out->pre_op_ = op;
+    out->pre_op_out_name_ = "Out";
+    out->pre_op_out_idx_ = i;
+  }
+  if (!stop_gradient) {
+    auto& grad_input_vars = op->grad_input_vars_["X@GRAD"];
+    auto& grad_output_vars = op->grad_output_vars_["Out@GRAD"];
+
+    for (const VarBase* inp : inputs) {
+      grad_input_vars.push_back(inp->var_);
+    }
+    for (VarBase* out : outputs) {
+      grad_input_vars.push_back(out->var_);
+    }
+    for (VarBase* out : outputs) {
+      grad_input_vars.push_back(out->grads_->var_);
+      if (!grad_input_vars.back()->IsInitialized()) {
+        InitVar(out->var_, grad_input_vars.back());
+      }
+    }
+    for (const VarBase* inp : inputs) {
+      grad_output_vars.push_back(inp->grads_->var_);
+      if (!grad_output_vars.back()->IsInitialized()) {
+        InitVar(inp->var_, grad_output_vars.back());
+      }
+    }
+  }
+  return outputs;
+}
+
+}  // namespace imperative
 }  // namespace paddle
diff --git a/paddle/fluid/imperative/tracer.h b/paddle/fluid/imperative/tracer.h
index f68a67e5d745edb04da441fdb0a64edfdd4230ed..f225d8abe6c0635d2bdd8dba0b12c7fc3a4110db 100644
--- a/paddle/fluid/imperative/tracer.h
+++ b/paddle/fluid/imperative/tracer.h
@@ -30,23 +30,9 @@ void CreateGradOp(const framework::OpDesc& op_desc,
                   const std::unordered_set<std::string>& no_grad_set,
                   const std::vector<framework::BlockDesc*>& grad_sub_block,
                   framework::OpDesc** grad_op_desc,
-                  std::unordered_map<std::string, std::string>* grad_to_var) {
-  std::vector<std::unique_ptr<framework::OpDesc>> grad_op_descs =
-      framework::OpInfoMap::Instance()
-          .Get(op_desc.Type())
-          .GradOpMaker()(op_desc, no_grad_set, grad_to_var, grad_sub_block);
-  PADDLE_ENFORCE(grad_op_descs.size() == 1, "Only support 1 grad op now.");
-  // TODO(panyx0718): Leak?
-  *grad_op_desc = grad_op_descs[0].release();
-}
+                  std::unordered_map<std::string, std::string>* grad_to_var);
 
-void InitVar(framework::Variable* var, framework::Variable* grad_var) {
-  auto& var_t = var->Get<framework::LoDTensor>();
-  float* data =
-      grad_var->GetMutable<framework::LoDTensor>()->mutable_data<float>(
-          var_t.dims(), platform::CPUPlace());
-  std::fill(data, data + var_t.numel(), 0.0);
-}
+void InitVar(framework::Variable* var, framework::Variable* grad_var);
 
 class Tracer {
  public:
@@ -57,172 +43,10 @@ class Tracer {
 
   void Trace(OpBase* op,
              const std::map<std::string, std::vector<VarBase*>>& inputs,
             const std::map<std::string, std::vector<VarBase*>>& outputs,
-             framework::BlockDesc* block, const bool stop_gradient = false) {
-    std::map<std::string, VarBase*> vars;
-
-    framework::OpDesc* op_desc = op->op_desc_;
-    VLOG(3) << "tracer tracing " << op_desc->Type();
-    op_desc->InferShape(*block);
-    op_desc->InferVarType(block);
-    std::unique_ptr<framework::OperatorBase> op_base =
-        framework::OpRegistry::CreateOp(*op_desc);
-
-    framework::VariableValueMap invars_map;
-    framework::VariableValueMap outvars_map;
-
-    op->input_vars_ = inputs;
-    for (auto it : op->input_vars_) {
-      auto& invars = invars_map[it.first];
-      for (VarBase* inp : it.second) {
-        PADDLE_ENFORCE_NOT_NULL(inp->var_, "op %s input %s nullptr",
-                                op->op_desc_->Type(), inp->var_desc_->Name());
-
-        invars.push_back(inp->var_);
-        vars[inp->var_desc_->Name()] = inp;
-        if (inp->pre_op_) {
-          op->pre_ops_[it.first].push_back(inp->pre_op_);
-          op->pre_ops_out_idx_[it.first].push_back(inp->pre_op_out_idx_);
-        } else {
-          op->pre_ops_[it.first].push_back(nullptr);
-        }
-        VLOG(3) << "input vname " << inp->var_desc_->Name() << " "
-                << inp->var_->IsInitialized();
-      }
-    }
-
-    op->output_vars_ = outputs;
-    for (auto it : op->output_vars_) {
-      auto& outvars = outvars_map[it.first];
-      const std::vector<VarBase*>& outputs = it.second;
-      for (size_t i = 0; i < outputs.size(); ++i) {
-        VarBase* out = outputs[i];
-        outvars.push_back(out->var_);
-        vars[out->var_desc_->Name()] = out;
-
-        framework::VarDesc* var_desc = block->FindVar(out->var_desc_->Name());
-        if (var_desc->GetType() == framework::proto::VarType::LOD_TENSOR) {
-          out->var_->GetMutable<framework::LoDTensor>();
-        } else {
-          LOG(ERROR) << "tracer doesn't support yet";
-        }
-        out->stop_gradient_ = stop_gradient;
-        out->pre_op_ = op;
-        out->pre_op_out_name_ = it.first;
-        out->pre_op_out_idx_ = i;
-
-        VLOG(3) << "output vname " << out->var_desc_->Name() << " "
-                << out->var_->IsInitialized();
-      }
-    }
-
-    VLOG(3) << "tracer running " << op_desc->Type();
-    framework::RuntimeContext ctx(invars_map, outvars_map);
-
-    // TODO(panyx0718): Cache p.
-    framework::OperatorWithKernel* op_kernel =
-        dynamic_cast<framework::OperatorWithKernel*>(op_base.get());
-    PADDLE_ENFORCE_NOT_NULL(op_kernel, "only support op with kernel");
-
-    framework::Scope scope;
-    platform::CPUPlace place;
-    PreparedOp p = PreparedOp::Prepare(ctx, *op_kernel, place);
-    p.op.RuntimeInferShape(scope, place, ctx);
-    p.func(framework::ExecutionContext(p.op, scope, *p.dev_ctx, p.ctx));
-
-    if (!stop_gradient) {
-      framework::OpDesc* grad_op_desc;
-      // TODO(panyx): Is this leaked?
-      std::unique_ptr<std::unordered_map<std::string, std::string>> grad_to_var(
-          new std::unordered_map<std::string, std::string>());
-      CreateGradOp(*op_desc, {}, {block}, &grad_op_desc, grad_to_var.get());
-      op->grad_op_desc_ = grad_op_desc;
-
-      for (auto it : grad_op_desc->Inputs()) {
-        auto& grad_in_vars = op->grad_input_vars_[it.first];
-        for (const std::string& grad_invar : it.second) {
-          block->FindRecursiveOrCreateVar(grad_invar);
-          auto var_it = grad_to_var->find(grad_invar);
-          if (var_it == grad_to_var->end()) {
-            auto fwd_var_it = vars.find(grad_invar);
-            PADDLE_ENFORCE(fwd_var_it != vars.end());
-            // Forward inputs or outputs.
-            grad_in_vars.push_back(fwd_var_it->second->var_);
-          } else {
-            VarBase* var = vars[var_it->second];
-            if (!var->grads_->IsInitialized()) {
-              InitVar(var->var_, var->grads_);
-            }
-            // Douts.
-            grad_in_vars.push_back(var->grads_);
-          }
-        }
-      }
-
-      for (auto it : grad_op_desc->Outputs()) {
-        auto& grad_out_vars = op->grad_output_vars_[it.first];
-        for (const std::string& grad_outvar : it.second) {
-          block->FindRecursiveOrCreateVar(grad_outvar);
-          auto var_it = grad_to_var->find(grad_outvar);
-          PADDLE_ENFORCE(var_it != grad_to_var->end());
-          VarBase* var = vars[var_it->second];
-          if (!var->grads_->IsInitialized()) {
-            InitVar(var->var_, var->grads_);
-          }
-          grad_out_vars.push_back(var->grads_);
-        }
-      }
-    }
-
-    op->block_ = block;
-  }
+             framework::BlockDesc* block, const bool stop_gradient = false);
 
   std::vector<VarBase*> PyTrace(OpBase* op, const std::vector<VarBase*>& inputs,
-                                bool stop_gradient = false) {
-    VLOG(3) << "py_trace";
-    op->input_vars_["X"] = inputs;
-    op->output_vars_["Out"] = PyLayer::Apply(op->forward_id_, inputs);
-    for (VarBase* inp : inputs) {
-      if (inp->pre_op_) {
-        op->pre_ops_["X"].push_back(inp->pre_op_);
-        op->pre_ops_out_idx_["X"].push_back(inp->pre_op_out_idx_);
-      } else {
-        op->pre_ops_["X"].push_back(nullptr);
-      }
-    }
-
-    auto& outputs = op->output_vars_["Out"];
-    for (size_t i = 0; i < outputs.size(); ++i) {
-      VarBase* out = outputs[i];
-      out->stop_gradient_ = stop_gradient;
-      out->pre_op_ = op;
-      out->pre_op_out_name_ = "Out";
-      out->pre_op_out_idx_ = i;
-    }
-    if (!stop_gradient) {
-      auto& grad_input_vars = op->grad_input_vars_["X@GRAD"];
-      auto& grad_output_vars = op->grad_output_vars_["Out@GRAD"];
-
-      for (const VarBase* inp : inputs) {
-        grad_input_vars.push_back(inp->var_);
-      }
-      for (VarBase* out : outputs) {
-        grad_input_vars.push_back(out->var_);
-      }
-      for (VarBase* out : outputs) {
-        grad_input_vars.push_back(out->grads_);
-        if (!grad_input_vars.back()->IsInitialized()) {
-          InitVar(out->var_, grad_input_vars.back());
-        }
-      }
-      for (const VarBase* inp : inputs) {
-        grad_output_vars.push_back(inp->grads_);
-        if (!grad_output_vars.back()->IsInitialized()) {
-          InitVar(inp->var_, grad_output_vars.back());
-        }
-      }
-    }
-    return outputs;
-  }
+                                bool stop_gradient = false);
 
  private:
   framework::BlockDesc* root_block_;
diff --git a/paddle/fluid/imperative/type_defs.h b/paddle/fluid/imperative/type_defs.h
new file mode 100644
index 0000000000000000000000000000000000000000..fc9e42f8d0e9996176a5cbab7d8c7cf08ddce1af
--- /dev/null
+++ b/paddle/fluid/imperative/type_defs.h
@@ -0,0 +1,31 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include <map>
+#include <string>
+#include <vector>
+
+namespace paddle {
+namespace imperative {
+
+class VarBase;
+class OpBase;
+
+typedef std::map<std::string, std::vector<VarBase*>> VarBasePtrMap;
+typedef std::map<std::string, std::vector<OpBase*>> OpBasePtrMap;
+
+}  // namespace imperative
+}  // namespace paddle
diff --git a/paddle/fluid/pybind/CMakeLists.txt b/paddle/fluid/pybind/CMakeLists.txt
index ca2764e64f6b85450efb257f94498169d40ccff8..9a91ea38caef50b9a7ad970a3d08ca28c497e419 100644
--- a/paddle/fluid/pybind/CMakeLists.txt
+++ b/paddle/fluid/pybind/CMakeLists.txt
@@ -1,5 +1,6 @@
-
-set(PYBIND_DEPS pybind python proto_desc memory executor async_executor prune feed_fetch_method pass_builder parallel_executor profiler layer scope_pool)
+set(PYBIND_DEPS pybind python proto_desc memory executor async_executor prune
+    feed_fetch_method pass_builder parallel_executor profiler layer scope_pool
+    tracer)
 if(WITH_PYTHON)
   list(APPEND PYBIND_DEPS py_func_op)
 endif()
diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc
index 72689bd6068a357d7c0338490ca0c507fe7b40fd..f3f4854a9efbcf5ab325e7f6aec81135c018dcd5 100644
--- a/paddle/fluid/pybind/pybind.cc
+++ b/paddle/fluid/pybind/pybind.cc
@@ -126,26 +126,18 @@ PYBIND11_MODULE(core, m) {
   m.add_object("_cleanup",
                py::capsule([]() { ScopePool::Instance().Clear(); }));
 
-  py::class_<imperative::VarBase, std::shared_ptr<imperative::VarBase>>(
-      m, "VarBase", R"DOC()DOC")
+  py::class_<imperative::VarBase>(m, "VarBase", R"DOC()DOC")
       // .def(py::init<>())
       .def(py::init<bool>(), py::arg("stop_gradient") = false)
       .def("_run_backward",
           [](imperative::VarBase &self) { self.RunBackward(); })
       .def("_grad_name", &imperative::VarBase::GradName)
-      .def("_grad", &imperative::VarBase::Grad)
-      .def_property("grad_value",
-                    [](const imperative::VarBase &self) { return self.grads_; },
-                    [](imperative::VarBase &self, framework::Variable *grad) {
-                      self.grads_ = grad;
-                    },
-                    py::return_value_policy::reference)
-      .def_property("value",
-                    [](const imperative::VarBase &self) { return self.var_; },
-                    [](imperative::VarBase &self, framework::Variable *var) {
-                      self.var_ = var;
-                    },
-                    py::return_value_policy::reference)
+      .def("_grad_value", &imperative::VarBase::GradValue)
+      .def("_grad_ivar",
+           [](const imperative::VarBase &self) { return self.grads_; },
+           py::return_value_policy::reference)
+      .def("value", [](const imperative::VarBase &self) { return self.var_; },
+           py::return_value_policy::reference)
       .def_property(
           "desc",
          [](const imperative::VarBase &self) { return self.var_desc_; },
diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py
index 8e18dffac300b8081f880bf408b9727071798c17..8d061f41f09a88d06a6b0018d95793e8cadbcdf3 100644
--- a/python/paddle/fluid/framework.py
+++ b/python/paddle/fluid/framework.py
@@ -372,30 +372,21 @@ class Variable(object):
         self.stop_gradient = stop_gradient
         self.is_data = is_data
         if _in_imperative_mode():
-            if 'ivar' in kwargs:
-                self._ivar = kwargs['ivar']
-            else:
+            self._ivar = kwargs.get("ivar", None)
+            if not self._ivar:
                 self._ivar = core.VarBase()
             self._ivar.desc = self.desc
             self._ivar.stop_gradient = stop_gradient
 
     def _numpy(self):
-        tensor = self._ivar.value.get_tensor()
+        tensor = self._ivar.value().get_tensor()
         return np.array(tensor)
 
     def _backward(self):
         self._ivar._run_backward()
 
     def _gradient(self):
-        return np.array(self._ivar._grad())
-
-    @property
-    def _value(self):
-        return self._ivar.value
-
-    @_value.setter
-    def _value(self, v):
-        self._ivar.value = v
+        return np.array(self._ivar._grad_value())
 
     def __str__(self):
         return self.to_string(True)
diff --git a/python/paddle/fluid/imperative/base.py b/python/paddle/fluid/imperative/base.py
index c04dcc7e39be9946b561fd725647e87d7712d8b9..5d3ebb25a935cea6ec376e6bc044281dcba37337 100644
--- a/python/paddle/fluid/imperative/base.py
+++ b/python/paddle/fluid/imperative/base.py
@@ -45,7 +45,7 @@ def to_variable(value, block=None):
             name=None,
             shape=value.shape,
             dtype=value.dtype)
-        var = py_var._ivar.value
+        var = py_var._ivar.value()
         tensor = var.get_tensor()
         tensor.set(value, core.CPUPlace())
         return py_var
diff --git a/python/paddle/fluid/imperative/layers.py b/python/paddle/fluid/imperative/layers.py
index 8027d9ba3bcf4d37f3573bc928faf574dcde1038..6d3987c9d5437463960910834a2202be9fb32cfe 100644
--- a/python/paddle/fluid/imperative/layers.py
+++ b/python/paddle/fluid/imperative/layers.py
@@ -55,18 +55,18 @@ class PyLayer(core.PyLayer):
         super(PyLayer, self).__init__()
 
     @staticmethod
-    def forward(inputs):
+    def forward(*inputs):
         raise NotImplementedError
 
     @staticmethod
-    def backward(douts):
+    def backward(*douts):
         raise NotImplementedError
 
     @classmethod
-    def __call__(cls, inputs):
+    def __call__(cls, *inputs):
         tracer = framework._imperative_tracer()
         block = framework.default_main_program().current_block()
-        inputs = [x._ivar for x in inputs]
+        ivar_inputs = [x._ivar for x in inputs]
 
         if not hasattr(cls, 'forward_id'):
             cls.forward_id = core.PyLayer.num_funcs() + 1
@@ -78,11 +78,11 @@ class PyLayer(core.PyLayer):
             iop.forward_id = cls.forward_id
             iop.backward_id = cls.backward_id
         block.ops.append(iop)
-        ivars = tracer.py_trace(iop, inputs, False)
+        ivars = tracer.py_trace(iop, ivar_inputs, False)
         # ivars = core.PyLayer.apply(cls.forward, inputs)
         ret = []
         for ivar in ivars:
-            tensor = ivar.value.get_tensor()
+            tensor = ivar.value().get_tensor()
             py_var = framework.Variable(
                 block,
                 type=core.VarDesc.VarType.LOD_TENSOR,
diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py
index bf3730ce51fdd59af85184ce5f5b0cad8ef0e6d3..f01a0eda9a711abb3265fe5bb86ecb702a6ac6aa 100644
--- a/python/paddle/fluid/optimizer.py
+++ b/python/paddle/fluid/optimizer.py
@@ -390,8 +390,8 @@ class Optimizer(object):
                 grad_var = Variable(
                     block=loss.block,
                     name=param._ivar._grad_name(),
-                    stop_gradient=True)
-                grad_var._value = param._ivar.grad_value
+                    stop_gradient=True,
+                    ivar=param._ivar._grad_ivar())
                 params_grads.append((param, grad_var))
             with program_guard(program, startup_program):
                 optimize_ops = self._create_optimization_pass(params_grads)
diff --git a/python/paddle/fluid/tests/unittests/test_imperative.py b/python/paddle/fluid/tests/unittests/test_imperative.py
index e3e1ce7ca3127969e9c4430649a18b08e0e71889..86baff3c589d7b8a14938886b3e2104b0beb1cc9 100644
--- a/python/paddle/fluid/tests/unittests/test_imperative.py
+++ b/python/paddle/fluid/tests/unittests/test_imperative.py
@@ -97,35 +97,35 @@ class TestImperative(unittest.TestCase):
                     super(PyLayer1, self).__init__()
 
                 @staticmethod
-                def forward(inputs):
-                    return inputs
+                def forward(input):
+                    return input
 
                 @staticmethod
-                def backward(inputs):
-                    return inputs
+                def backward(input):
+                    return input
 
             class PyLayer2(fluid.imperative.PyLayer):
                 def __init__(self):
                     super(PyLayer2, self).__init__()
 
                @staticmethod
-                def forward(inputs):
-                    return inputs
+                def forward(input):
+                    return input
 
                 @staticmethod
-                def backward(inputs):
-                    return inputs
+                def backward(input):
+                    return input
 
            py_layer_1 = PyLayer1()
            py_layer_2 = PyLayer2()
-            py_layer_1([fluid.imperative.base.to_variable(np.ones([2, 2]))])
-            py_layer_2([fluid.imperative.base.to_variable(np.ones([2, 2]))])
+            py_layer_1(fluid.imperative.base.to_variable(np.ones([2, 2])))
+            py_layer_2(fluid.imperative.base.to_variable(np.ones([2, 2])))
             id = py_layer_1.forward_id
             self.assertGreater(id, 0)
             self.assertEqual(py_layer_1.backward_id, id + 1)
             self.assertEqual(py_layer_2.forward_id, id + 2)
             self.assertEqual(py_layer_2.backward_id, id + 3)
-            py_layer_1([fluid.imperative.base.to_variable(np.ones([2, 2]))])
+            py_layer_1(fluid.imperative.base.to_variable(np.ones([2, 2])))
             self.assertEqual(py_layer_1.forward_id, id)
 
     def test_pylayer(self):
@@ -133,7 +133,7 @@
         with fluid.imperative.guard():
             my_py_layer = MyPyLayer()
             var_inp = fluid.imperative.base.to_variable(np_inp)
-            outs = my_py_layer([var_inp])
+            outs = my_py_layer(var_inp)
             dy_out = np.sum(outs[0]._numpy())
             outs[0]._backward()
             dy_grad = var_inp._gradient()
diff --git a/python/paddle/fluid/tests/unittests/test_imperative_optimizer.py b/python/paddle/fluid/tests/unittests/test_imperative_optimizer.py
index 0549f50fe2673e87a87cdcf1929b4ca2ebce70f9..63eeae4b712c2064309b664b91d5f0347b67817d 100644
--- a/python/paddle/fluid/tests/unittests/test_imperative_optimizer.py
+++ b/python/paddle/fluid/tests/unittests/test_imperative_optimizer.py
@@ -105,7 +105,6 @@ class TestImperativeMnist(unittest.TestCase):
             fluid.default_startup_program().random_seed = seed
             fluid.default_main_program().random_seed = seed
 
-            # mnist = Conv2D(1, 20, 5)
             mnist = MNIST()
             sgd = SGDOptimizer(learning_rate=1e-3)
             train_reader = paddle.batch(
@@ -126,16 +125,17 @@ class TestImperativeMnist(unittest.TestCase):
                 label._stop_gradient = True
 
                 cost = mnist(img)
-                loss = fluid.layers.reduce_mean(cost)
-                dy_out = loss._numpy()
+                loss = fluid.layers.cross_entropy(cost, label)
+                avg_loss = fluid.layers.mean(loss)
+                dy_out = avg_loss._numpy()
 
                 if batch_id == 0:
                     for param in fluid.default_main_program().global_block(
                     ).all_parameters():
                         dy_param_init_value[param.name] = param._numpy()
 
-                loss._backward()
-                sgd.minimize(loss)
+                avg_loss._backward()
+                sgd.minimize(avg_loss)
                 dy_param_value = {}
                 for param in fluid.default_main_program().global_block(
                 ).all_parameters():
@@ -147,7 +147,6 @@ class TestImperativeMnist(unittest.TestCase):
 
             exe = fluid.Executor(fluid.CPUPlace())
 
-            # mnist = Conv2D(1, 20, 5)
             mnist = MNIST()
             sgd = SGDOptimizer(learning_rate=1e-3)
             train_reader = paddle.batch(
@@ -157,8 +156,9 @@ class TestImperativeMnist(unittest.TestCase):
                 name='pixel', shape=[1, 28, 28], dtype='float32')
             label = fluid.layers.data(name='label', shape=[1], dtype='int64')
             cost = mnist(img)
-            loss = fluid.layers.reduce_mean(cost)
-            sgd.minimize(loss)
+            loss = fluid.layers.cross_entropy(cost, label)
+            avg_loss = fluid.layers.mean(loss)
+            sgd.minimize(avg_loss)
 
             # initialize params and fetch them
             static_param_init_value = {}
@@ -182,7 +182,7 @@ class TestImperativeMnist(unittest.TestCase):
                 y_data = np.array([x[1] for x in data]).astype('int64').reshape(
                     [128, 1])
 
-                fetch_list = [loss.name]
+                fetch_list = [avg_loss.name]
                 fetch_list.extend(static_param_name_list)
                 out = exe.run(fluid.default_main_program(),
                               feed={"pixel": x_data,
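Usage sketch (illustrative only, not part of the patch): the new PyLayer calling convention below mirrors PyLayer1 and test_pylayer from the updated test_imperative.py — forward/backward are per-argument static methods and the layer is called with Variables directly instead of a list; the Identity class name is an assumption for this example, and _numpy() reads the tensor through the new ivar.value() accessor.

import numpy as np
import paddle.fluid as fluid


class Identity(fluid.imperative.PyLayer):
    """Hypothetical layer mirroring PyLayer1 from test_imperative.py."""

    def __init__(self):
        super(Identity, self).__init__()

    @staticmethod
    def forward(input):
        # Receives one unpacked input now that the base class declares
        # forward(*inputs); returns it unchanged.
        return input

    @staticmethod
    def backward(input):
        # Identity gradient: pass the incoming dout straight through.
        return input


with fluid.imperative.guard():
    layer = Identity()
    # Variables are passed directly, no longer wrapped in a list.
    var_inp = fluid.imperative.base.to_variable(np.ones([2, 2]))
    outs = layer(var_inp)
    print(outs[0]._numpy())  # fetches the tensor via the new value() binding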