提交 74551758 编写于 作者: M minqiyang

Polish code

test=develop
上级 1f0ef42e
...@@ -175,7 +175,7 @@ std::unique_ptr<VarBase> VarBase::NewVarBase(const platform::Place& dst_place, ...@@ -175,7 +175,7 @@ std::unique_ptr<VarBase> VarBase::NewVarBase(const platform::Place& dst_place,
PADDLE_ENFORCE(var_->IsInitialized(), PADDLE_ENFORCE(var_->IsInitialized(),
"Variable must be initialized when getting numpy tensor"); "Variable must be initialized when getting numpy tensor");
std::unique_ptr<VarBase> new_var(new VarBase("NewVarBase")); std::unique_ptr<VarBase> new_var(new VarBase());
framework::LoDTensor* tensor = framework::LoDTensor* tensor =
new_var->var_->GetMutable<framework::LoDTensor>(); new_var->var_->GetMutable<framework::LoDTensor>();
tensor->Resize(var_->Get<framework::LoDTensor>().dims()); tensor->Resize(var_->Get<framework::LoDTensor>().dims());
...@@ -303,7 +303,7 @@ std::vector<VarBase*> PyLayer::Apply(int func_id, ...@@ -303,7 +303,7 @@ std::vector<VarBase*> PyLayer::Apply(int func_id,
std::vector<Variable*> outvars = CallPythonFunc(py_funcs_[func_id], invars); std::vector<Variable*> outvars = CallPythonFunc(py_funcs_[func_id], invars);
std::vector<VarBase*> ret; std::vector<VarBase*> ret;
for (Variable* v : outvars) { for (Variable* v : outvars) {
ret.push_back(new VarBase(v, new VarBase("PYLAYER_XGRAD", true), "")); ret.push_back(new VarBase(v, new VarBase(true)));
} }
return ret; return ret;
} }
......
...@@ -103,28 +103,24 @@ class OpBase; ...@@ -103,28 +103,24 @@ class OpBase;
*/ */
class VarBase { class VarBase {
public: public:
explicit VarBase(std::string name) VarBase() : VarBase(new framework::Variable(), new VarBase(true)) {}
: VarBase(new framework::Variable(), new VarBase(name + "XGRAD", true),
name) {}
// Owns `var` and `grad` // Owns `var` and `grad`
VarBase(framework::Variable* var, VarBase* grad, std::string name) VarBase(framework::Variable* var, VarBase* grad)
: var_desc_(nullptr), : var_desc_(nullptr),
var_(var), var_(var),
grads_(grad), grads_(grad),
stop_gradient_(false), stop_gradient_(false),
pre_op_(nullptr), pre_op_(nullptr),
pre_op_out_idx_(-1), pre_op_out_idx_(-1) {}
name_(name) {}
explicit VarBase(std::string name, bool stop_gradient) explicit VarBase(bool stop_gradient)
: var_desc_(nullptr), : var_desc_(nullptr),
var_(new framework::Variable()), var_(new framework::Variable()),
grads_(stop_gradient ? nullptr : new VarBase(name + "XGRAD", true)), grads_(stop_gradient ? nullptr : new VarBase(true)),
stop_gradient_(stop_gradient), stop_gradient_(stop_gradient),
pre_op_(nullptr), pre_op_(nullptr),
pre_op_out_idx_(-1), pre_op_out_idx_(-1) {}
name_(name) {}
virtual ~VarBase() { virtual ~VarBase() {
if (var_) { if (var_) {
...@@ -187,7 +183,6 @@ class VarBase { ...@@ -187,7 +183,6 @@ class VarBase {
OpBase* pre_op_; OpBase* pre_op_;
std::string pre_op_out_name_; std::string pre_op_out_name_;
int pre_op_out_idx_; int pre_op_out_idx_;
std::string name_;
}; };
/* The wrapper for OpDesc which holds a OpDesc and a OpDesc of its /* The wrapper for OpDesc which holds a OpDesc and a OpDesc of its
......
...@@ -66,33 +66,12 @@ platform::Place GetExpectedPlace(platform::Place place, VarBasePtrMap inputs) { ...@@ -66,33 +66,12 @@ platform::Place GetExpectedPlace(platform::Place place, VarBasePtrMap inputs) {
return result; return result;
} }
// framework::BlockDesc* InferShapeAndVarType(OpBase* op, const VarBasePtrMap&
// inputs, const VarBasePtrMap& outputs) {
// std::unique_ptr<BlockDesc> block(new BlockDesc());
// // construct op desc
// op->op_desc_ = block.AppendOp();
// // construct op inputs and outputs
// // for
// //
// for (auto it = )
// op->op_desc_->SetInput()
// op->op_desc_->InferShape(*block);
// op->op_desc_->InferVarType(block.get());
// return block.release();
// }
void Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs, void Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
const VarBasePtrMap& outputs, framework::BlockDesc* block, const VarBasePtrMap& outputs, framework::BlockDesc* block,
const platform::Place expected_place, const platform::Place expected_place,
const bool stop_gradient) { const bool stop_gradient) {
std::map<std::string, VarBase*> vars; std::map<std::string, VarBase*> vars;
// framework::BlockDesc* block = InferShapeAndVarType(op, inputs, outputs);
framework::OpDesc* op_desc = op->op_desc_; framework::OpDesc* op_desc = op->op_desc_;
VLOG(3) << "tracer tracing " << op_desc->Type(); VLOG(3) << "tracer tracing " << op_desc->Type();
op_desc->InferShape(*block); op_desc->InferShape(*block);
......
...@@ -137,7 +137,7 @@ PYBIND11_MODULE(core, m) { ...@@ -137,7 +137,7 @@ PYBIND11_MODULE(core, m) {
py::class_<imperative::VarBase>(m, "VarBase", R"DOC()DOC") py::class_<imperative::VarBase>(m, "VarBase", R"DOC()DOC")
// .def(py::init<>()) // .def(py::init<>())
.def(py::init<std::string, bool>(), py::arg("stop_gradient") = false, py::arg("name") = "") .def(py::init<bool>(), py::arg("stop_gradient") = false)
.def("_run_backward", .def("_run_backward",
[](imperative::VarBase &self) { self.RunBackward(); }) [](imperative::VarBase &self) { self.RunBackward(); })
.def("_grad_name", &imperative::VarBase::GradName) .def("_grad_name", &imperative::VarBase::GradName)
......
...@@ -306,10 +306,6 @@ class Variable(object): ...@@ -306,10 +306,6 @@ class Variable(object):
if name is None: if name is None:
name = unique_name.generate('_generated_var') name = unique_name.generate('_generated_var')
# print("create var", name)
# import sys
# sys.stdout.flush()
is_new_var = False is_new_var = False
name = cpt.to_text(name) name = cpt.to_text(name)
self.desc = self.block.desc.find_var(cpt.to_bytes(name)) self.desc = self.block.desc.find_var(cpt.to_bytes(name))
...@@ -387,9 +383,8 @@ class Variable(object): ...@@ -387,9 +383,8 @@ class Variable(object):
if _in_imperative_mode(): if _in_imperative_mode():
self._ivar = kwargs.get("ivar", None) self._ivar = kwargs.get("ivar", None)
if not self._ivar: if not self._ivar:
self._ivar = core.VarBase(name, stop_gradient) self._ivar = core.VarBase(stop_gradient)
self._ivar.desc = self.desc self._ivar.desc = self.desc
self._ivar.stop_gradient = stop_gradient
def _numpy(self): def _numpy(self):
new_ivar = self._ivar._copy_to(core.CPUPlace(), True) new_ivar = self._ivar._copy_to(core.CPUPlace(), True)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册