From 74551758cca02c28e536728f1ca308cd13a7086e Mon Sep 17 00:00:00 2001 From: minqiyang Date: Thu, 21 Feb 2019 11:01:27 +0800 Subject: [PATCH] Polish code test=develop --- paddle/fluid/imperative/layer.cc | 4 ++-- paddle/fluid/imperative/layer.h | 17 ++++++----------- paddle/fluid/imperative/tracer.cc | 21 --------------------- paddle/fluid/pybind/pybind.cc | 2 +- python/paddle/fluid/framework.py | 7 +------ 5 files changed, 10 insertions(+), 41 deletions(-) diff --git a/paddle/fluid/imperative/layer.cc b/paddle/fluid/imperative/layer.cc index 827473ec82..47488d4dea 100644 --- a/paddle/fluid/imperative/layer.cc +++ b/paddle/fluid/imperative/layer.cc @@ -175,7 +175,7 @@ std::unique_ptr VarBase::NewVarBase(const platform::Place& dst_place, PADDLE_ENFORCE(var_->IsInitialized(), "Variable must be initialized when getting numpy tensor"); - std::unique_ptr new_var(new VarBase("NewVarBase")); + std::unique_ptr new_var(new VarBase()); framework::LoDTensor* tensor = new_var->var_->GetMutable(); tensor->Resize(var_->Get().dims()); @@ -303,7 +303,7 @@ std::vector PyLayer::Apply(int func_id, std::vector outvars = CallPythonFunc(py_funcs_[func_id], invars); std::vector ret; for (Variable* v : outvars) { - ret.push_back(new VarBase(v, new VarBase("PYLAYER_XGRAD", true), "")); + ret.push_back(new VarBase(v, new VarBase(true))); } return ret; } diff --git a/paddle/fluid/imperative/layer.h b/paddle/fluid/imperative/layer.h index f42ceb5027..78205486c5 100644 --- a/paddle/fluid/imperative/layer.h +++ b/paddle/fluid/imperative/layer.h @@ -103,28 +103,24 @@ class OpBase; */ class VarBase { public: - explicit VarBase(std::string name) - : VarBase(new framework::Variable(), new VarBase(name + "XGRAD", true), - name) {} + VarBase() : VarBase(new framework::Variable(), new VarBase(true)) {} // Owns `var` and `grad` - VarBase(framework::Variable* var, VarBase* grad, std::string name) + VarBase(framework::Variable* var, VarBase* grad) : var_desc_(nullptr), var_(var), grads_(grad), stop_gradient_(false), pre_op_(nullptr), - pre_op_out_idx_(-1), - name_(name) {} + pre_op_out_idx_(-1) {} - explicit VarBase(std::string name, bool stop_gradient) + explicit VarBase(bool stop_gradient) : var_desc_(nullptr), var_(new framework::Variable()), - grads_(stop_gradient ? nullptr : new VarBase(name + "XGRAD", true)), + grads_(stop_gradient ? nullptr : new VarBase(true)), stop_gradient_(stop_gradient), pre_op_(nullptr), - pre_op_out_idx_(-1), - name_(name) {} + pre_op_out_idx_(-1) {} virtual ~VarBase() { if (var_) { @@ -187,7 +183,6 @@ class VarBase { OpBase* pre_op_; std::string pre_op_out_name_; int pre_op_out_idx_; - std::string name_; }; /* The wrapper for OpDesc which holds a OpDesc and a OpDesc of its diff --git a/paddle/fluid/imperative/tracer.cc b/paddle/fluid/imperative/tracer.cc index c8244e22fd..ef275a361f 100644 --- a/paddle/fluid/imperative/tracer.cc +++ b/paddle/fluid/imperative/tracer.cc @@ -66,33 +66,12 @@ platform::Place GetExpectedPlace(platform::Place place, VarBasePtrMap inputs) { return result; } -// framework::BlockDesc* InferShapeAndVarType(OpBase* op, const VarBasePtrMap& -// inputs, const VarBasePtrMap& outputs) { -// std::unique_ptr block(new BlockDesc()); - -// // construct op desc -// op->op_desc_ = block.AppendOp(); - -// // construct op inputs and outputs -// // for -// // -// for (auto it = ) -// op->op_desc_->SetInput() - -// op->op_desc_->InferShape(*block); -// op->op_desc_->InferVarType(block.get()); - -// return block.release(); -// } - void Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs, const VarBasePtrMap& outputs, framework::BlockDesc* block, const platform::Place expected_place, const bool stop_gradient) { std::map vars; - // framework::BlockDesc* block = InferShapeAndVarType(op, inputs, outputs); - framework::OpDesc* op_desc = op->op_desc_; VLOG(3) << "tracer tracing " << op_desc->Type(); op_desc->InferShape(*block); diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index 26ebacc13f..351513712c 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -137,7 +137,7 @@ PYBIND11_MODULE(core, m) { py::class_(m, "VarBase", R"DOC()DOC") // .def(py::init<>()) - .def(py::init(), py::arg("stop_gradient") = false, py::arg("name") = "") + .def(py::init(), py::arg("stop_gradient") = false) .def("_run_backward", [](imperative::VarBase &self) { self.RunBackward(); }) .def("_grad_name", &imperative::VarBase::GradName) diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index 4ff769dd48..708d4880a1 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -306,10 +306,6 @@ class Variable(object): if name is None: name = unique_name.generate('_generated_var') - # print("create var", name) - # import sys - # sys.stdout.flush() - is_new_var = False name = cpt.to_text(name) self.desc = self.block.desc.find_var(cpt.to_bytes(name)) @@ -387,9 +383,8 @@ class Variable(object): if _in_imperative_mode(): self._ivar = kwargs.get("ivar", None) if not self._ivar: - self._ivar = core.VarBase(name, stop_gradient) + self._ivar = core.VarBase(stop_gradient) self._ivar.desc = self.desc - self._ivar.stop_gradient = stop_gradient def _numpy(self): new_ivar = self._ivar._copy_to(core.CPUPlace(), True) -- GitLab