Commit 74551758 authored by minqiyang

Polish code

test=develop
Parent 1f0ef42e
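This commit removes the name argument from imperative::VarBase end to end: the three C++ constructors drop their std::string name parameters and the name_ member, the pybind binding shrinks to py::init<bool>(), and Python's Variable.__init__ now calls core.VarBase(stop_gradient). The sketch below is a minimal standalone model of the resulting constructor set, not Paddle's real class: Variable stands in for framework::Variable, the tracer bookkeeping members are elided, and the destructor body (collapsed in the diff) is assumed to delete the owned pointers.

// Minimal sketch of the VarBase constructors after this commit.
struct Variable {};  // stand-in for framework::Variable

class VarBase {
 public:
  // Default: owns a fresh Variable plus a gradient VarBase.
  VarBase() : VarBase(new Variable(), new VarBase(true)) {}

  // Owns `var` and `grad`; PyLayer::Apply wraps each output this way,
  // as new VarBase(v, new VarBase(true)).
  VarBase(Variable* var, VarBase* grad)
      : var_(var), grads_(grad), stop_gradient_(false) {}

  // stop_gradient == true terminates the chain: the gradient node itself
  // gets no grads_, so the delegating default constructor cannot recurse.
  explicit VarBase(bool stop_gradient)
      : var_(new Variable()),
        grads_(stop_gradient ? nullptr : new VarBase(true)),
        stop_gradient_(stop_gradient) {}

  virtual ~VarBase() {  // assumption: the owned pointers are freed here
    delete var_;
    delete grads_;
  }

 private:
  Variable* var_;
  VarBase* grads_;
  bool stop_gradient_;
};

int main() {
  VarBase trainable;   // grads_ points at a stop-gradient VarBase
  VarBase leaf(true);  // grads_ is nullptr
}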
@@ -175,7 +175,7 @@ std::unique_ptr<VarBase> VarBase::NewVarBase(const platform::Place& dst_place,
   PADDLE_ENFORCE(var_->IsInitialized(),
                  "Variable must be initialized when getting numpy tensor");

-  std::unique_ptr<VarBase> new_var(new VarBase("NewVarBase"));
+  std::unique_ptr<VarBase> new_var(new VarBase());
   framework::LoDTensor* tensor =
       new_var->var_->GetMutable<framework::LoDTensor>();
   tensor->Resize(var_->Get<framework::LoDTensor>().dims());
@@ -303,7 +303,7 @@ std::vector<VarBase*> PyLayer::Apply(int func_id,
   std::vector<Variable*> outvars = CallPythonFunc(py_funcs_[func_id], invars);
   std::vector<VarBase*> ret;
   for (Variable* v : outvars) {
-    ret.push_back(new VarBase(v, new VarBase("PYLAYER_XGRAD", true), ""));
+    ret.push_back(new VarBase(v, new VarBase(true)));
   }
   return ret;
 }
......
@@ -103,28 +103,24 @@ class OpBase;
  */
 class VarBase {
  public:
-  explicit VarBase(std::string name)
-      : VarBase(new framework::Variable(), new VarBase(name + "XGRAD", true),
-                name) {}
+  VarBase() : VarBase(new framework::Variable(), new VarBase(true)) {}

   // Owns `var` and `grad`
-  VarBase(framework::Variable* var, VarBase* grad, std::string name)
+  VarBase(framework::Variable* var, VarBase* grad)
       : var_desc_(nullptr),
         var_(var),
         grads_(grad),
         stop_gradient_(false),
         pre_op_(nullptr),
-        pre_op_out_idx_(-1),
-        name_(name) {}
+        pre_op_out_idx_(-1) {}

-  explicit VarBase(std::string name, bool stop_gradient)
+  explicit VarBase(bool stop_gradient)
       : var_desc_(nullptr),
         var_(new framework::Variable()),
-        grads_(stop_gradient ? nullptr : new VarBase(name + "XGRAD", true)),
+        grads_(stop_gradient ? nullptr : new VarBase(true)),
         stop_gradient_(stop_gradient),
         pre_op_(nullptr),
-        pre_op_out_idx_(-1),
-        name_(name) {}
+        pre_op_out_idx_(-1) {}

   virtual ~VarBase() {
     if (var_) {
@@ -187,7 +183,6 @@ class VarBase {
   OpBase* pre_op_;
   std::string pre_op_out_name_;
   int pre_op_out_idx_;
-  std::string name_;
 };

 /* The wrapper for OpDesc which holds a OpDesc and a OpDesc of its
......
@@ -66,33 +66,12 @@ platform::Place GetExpectedPlace(platform::Place place, VarBasePtrMap inputs) {
   return result;
 }

-// framework::BlockDesc* InferShapeAndVarType(OpBase* op, const VarBasePtrMap&
-// inputs, const VarBasePtrMap& outputs) {
-//   std::unique_ptr<BlockDesc> block(new BlockDesc());
-//   // construct op desc
-//   op->op_desc_ = block.AppendOp();
-//   // construct op inputs and outputs
-//   // for
-//   //
-//   for (auto it = )
-//   op->op_desc_->SetInput()
-//   op->op_desc_->InferShape(*block);
-//   op->op_desc_->InferVarType(block.get());
-//   return block.release();
-// }
-
 void Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
                    const VarBasePtrMap& outputs, framework::BlockDesc* block,
                    const platform::Place expected_place,
                    const bool stop_gradient) {
   std::map<std::string, VarBase*> vars;

-  // framework::BlockDesc* block = InferShapeAndVarType(op, inputs, outputs);
   framework::OpDesc* op_desc = op->op_desc_;
   VLOG(3) << "tracer tracing " << op_desc->Type();
   op_desc->InferShape(*block);
......
@@ -137,7 +137,7 @@ PYBIND11_MODULE(core, m) {
   py::class_<imperative::VarBase>(m, "VarBase", R"DOC()DOC")
       // .def(py::init<>())
-      .def(py::init<std::string, bool>(), py::arg("stop_gradient") = false, py::arg("name") = "")
+      .def(py::init<bool>(), py::arg("stop_gradient") = false)
       .def("_run_backward",
            [](imperative::VarBase &self) { self.RunBackward(); })
       .def("_grad_name", &imperative::VarBase::GradName)
......
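With the binding reduced to py::init<bool>(), Python now constructs a VarBase from the stop-gradient flag alone, which is exactly the call framework.py switches to in the hunk below. Here is a compile-ready sketch of such a binding, assuming pybind11; the module name core_sketch and the reduced class are illustrative stand-ins, not Paddle's actual targets:

#include <pybind11/pybind11.h>

namespace py = pybind11;

// Stand-in for imperative::VarBase, reduced to the constructor being bound.
class VarBase {
 public:
  explicit VarBase(bool stop_gradient) : stop_gradient_(stop_gradient) {}
  bool stop_gradient_;
};

PYBIND11_MODULE(core_sketch, m) {
  py::class_<VarBase>(m, "VarBase")
      // Only the bool overload survives this commit, so Python callers
      // write VarBase(stop_gradient=...) with no name argument.
      .def(py::init<bool>(), py::arg("stop_gradient") = false)
      .def_readwrite("stop_gradient", &VarBase::stop_gradient_);
}

From Python this is exercised as core_sketch.VarBase(), which defaults to stop_gradient=False, or core_sketch.VarBase(stop_gradient=True).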
@@ -306,10 +306,6 @@ class Variable(object):
         if name is None:
             name = unique_name.generate('_generated_var')
-        # print("create var", name)
-        # import sys
-        # sys.stdout.flush()
-
         is_new_var = False
         name = cpt.to_text(name)

         self.desc = self.block.desc.find_var(cpt.to_bytes(name))
@@ -387,9 +383,8 @@ class Variable(object):
         if _in_imperative_mode():
             self._ivar = kwargs.get("ivar", None)
             if not self._ivar:
-                self._ivar = core.VarBase(name, stop_gradient)
+                self._ivar = core.VarBase(stop_gradient)
             self._ivar.desc = self.desc
-            self._ivar.stop_gradient = stop_gradient

     def _numpy(self):
         new_ivar = self._ivar._copy_to(core.CPUPlace(), True)
......