提交 4d0df1fe 编写于 作者: X Xin Pan

add fields for autograd

test=develop
上级 81383916
......@@ -23,19 +23,29 @@
namespace paddle {
namespace imperative {
class OpBase;
class VarBase {
public:
VarBase() {}
virtual ~VarBase() {}
OpBase* pre_op_;
framework::VarDesc* var_desc_;
};
class OpBase {
public:
OpBase() {}
virtual ~OpBase() {}
OpBase()
: input_vars_(new std::vector<VarBase*>()),
output_vars_(new std::vector<VarBase*>()) {}
virtual ~OpBase() {
delete input_vars_;
delete output_vars_;
}
std::vector<VarBase*>* input_vars_;
std::vector<VarBase*>* output_vars_;
framework::OpDesc* op_desc_;
};
......
......@@ -31,15 +31,18 @@ class Tracer {
public:
Tracer() {}
void Trace(OpBase* op, const std::map<std::string, VarBase*>& inputs,
const std::map<std::string, VarBase*>& outputs) {
void Trace(OpBase* op, const std::vector<VarBase*>& inputs,
const std::vector<VarBase*>& outputs) {
framework::OpDesc* op_desc = op->op_desc_;
LOG(ERROR) << "tracer tracing " << op_desc->Type();
op_desc->InferShape(*block_);
op_desc->InferVarType(block_);
std::unique_ptr<framework::OperatorBase> op_base =
framework::OpRegistry::CreateOp(*op_desc);
for (const std::string& vname : op_desc->InputArgumentNames()) {
*op->input_vars_ = inputs;
for (VarBase* input : inputs) {
const std::string vname = input->var_desc_->Name();
framework::Variable* var = scope_->Var(vname);
if (!var->IsInitialized()) {
framework::VarDesc* var_desc = block_->FindVar(vname);
......@@ -50,7 +53,10 @@ class Tracer {
}
}
}
for (const std::string& vname : op_desc->OutputArgumentNames()) {
*op->output_vars_ = outputs;
for (auto output : outputs) {
const std::string vname = output->var_desc_->Name();
framework::Variable* var = scope_->Var(vname);
if (!var->IsInitialized()) {
framework::VarDesc* var_desc = block_->FindVar(vname);
......@@ -60,6 +66,7 @@ class Tracer {
LOG(ERROR) << "tracer doesn't support yet";
}
}
output->pre_op_ = op;
}
op_base->Run(*scope_, platform::CPUPlace());
}
......
......@@ -37,6 +37,11 @@ class PyLayer : public imperative::Layer {
}
};
class PyOpBase : public imperative::OpBase {
public:
using imperative::OpBase::OpBase; // Inherit constructors
};
void BindTracer(pybind11::module* m);
} // namespace pybind
......
......@@ -111,8 +111,9 @@ PYBIND11_MODULE(core, m) {
},
py::return_value_policy::reference);
py::class_<imperative::OpBase>(m, "OpBase",
R"DOC()DOC")
py::class_<imperative::OpBase, PyOpBase>(m, "OpBase",
R"DOC()DOC")
.def(py::init<>())
.def_property(
"desc", [](const imperative::OpBase &self) { return self.op_desc_; },
[](imperative::OpBase &self, framework::OpDesc *op_desc) {
......
......@@ -563,6 +563,7 @@ class Operator(core.OpBase):
inputs=None,
outputs=None,
attrs=None):
core.OpBase.__init__(self)
self.block = block
self.desc = desc
# note: not add self.attrs here:
......@@ -602,33 +603,32 @@ class Operator(core.OpBase):
return True
return False
self.inputs = [] if not inputs else inputs
for in_proto in proto.inputs:
found = find_name(self.inputs, in_proto.name)
assert found or in_proto.dispensable, "Input {} not found".format(
in_proto.name)
if found:
in_args = self.inputs[in_proto.name]
if not isinstance(in_args, list):
in_args = [in_args]
if not in_proto.duplicable and len(in_args) > 1:
raise ValueError(
"Input %s expects only one input, but %d are given." %
(in_proto.name, len(in_args)))
in_arg_names = []
for arg in in_args:
if isinstance(arg, six.string_types):
in_arg_names.append(arg)
elif isinstance(arg, six.binary_type):
in_arg_names.append(arg.decode())
else:
in_arg_names.append(cpt.to_text(arg.name))
self.desc.set_input(in_proto.name, in_arg_names)
else:
self.desc.set_input(in_proto.name, [])
if inputs is not None:
for in_proto in proto.inputs:
found = find_name(inputs, in_proto.name)
assert found or in_proto.dispensable, "Input {} not found".format(
in_proto.name)
if found:
in_args = inputs[in_proto.name]
if not isinstance(in_args, list):
in_args = [in_args]
if not in_proto.duplicable and len(in_args) > 1:
raise ValueError(
"Input %s expects only one input, but %d are given."
% (in_proto.name, len(in_args)))
in_arg_names = []
for arg in in_args:
if isinstance(arg, six.string_types):
in_arg_names.append(arg)
elif isinstance(arg, six.binary_type):
in_arg_names.append(arg.decode())
else:
in_arg_names.append(cpt.to_text(arg.name))
self.desc.set_input(in_proto.name, in_arg_names)
else:
self.desc.set_input(in_proto.name, [])
self.outputs = [] if not outputs else outputs
if outputs is not None:
given = set()
need = set()
......@@ -657,6 +657,21 @@ class Operator(core.OpBase):
arg.op = self
self.desc.set_output(out_proto.name, out_arg_names)
input_vars = []
for inp in inputs.values():
if isinstance(inp, Variable):
input_vars.append(inp)
elif isinstance(inp, list):
input_vars.extend(inp[:])
self.inputs = input_vars
output_vars = []
for out in outputs.values():
if isinstance(out, Variable):
output_vars.append(out)
elif isinstance(inp, list):
output_vars.extend(out[:])
self.outputs = output_vars
if op_attrs is not None:
if not isinstance(op_attrs, dict):
raise TypeError("'attrs' should be a dict.")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册