diff --git a/paddle/fluid/imperative/layer.h b/paddle/fluid/imperative/layer.h index cf352ebe53b835a6b12725abea4146fc4bf773ce..c400b19700094e40a3d234c2c063e27a0596e1bd 100644 --- a/paddle/fluid/imperative/layer.h +++ b/paddle/fluid/imperative/layer.h @@ -23,19 +23,29 @@ namespace paddle { namespace imperative { +class OpBase; + class VarBase { public: VarBase() {} virtual ~VarBase() {} + OpBase* pre_op_; framework::VarDesc* var_desc_; }; class OpBase { public: - OpBase() {} - virtual ~OpBase() {} + OpBase() + : input_vars_(new std::vector<VarBase*>()), + output_vars_(new std::vector<VarBase*>()) {} + virtual ~OpBase() { + delete input_vars_; + delete output_vars_; + } + std::vector<VarBase*>* input_vars_; + std::vector<VarBase*>* output_vars_; framework::OpDesc* op_desc_; }; diff --git a/paddle/fluid/imperative/tracer.h b/paddle/fluid/imperative/tracer.h index c5dea9686314a1d321cc85c1b280c190c160b67e..9d7bdda8ccc83aedb25b4aba5868a57d49506370 100644 --- a/paddle/fluid/imperative/tracer.h +++ b/paddle/fluid/imperative/tracer.h @@ -31,15 +31,18 @@ class Tracer { public: Tracer() {} - void Trace(OpBase* op, const std::map& inputs, - const std::map& outputs) { + void Trace(OpBase* op, const std::vector<VarBase*>& inputs, + const std::vector<VarBase*>& outputs) { framework::OpDesc* op_desc = op->op_desc_; LOG(ERROR) << "tracer tracing " << op_desc->Type(); op_desc->InferShape(*block_); op_desc->InferVarType(block_); std::unique_ptr<framework::OperatorBase> op_base = framework::OpRegistry::CreateOp(*op_desc); - for (const std::string& vname : op_desc->InputArgumentNames()) { + + *op->input_vars_ = inputs; + for (VarBase* input : inputs) { + const std::string vname = input->var_desc_->Name(); framework::Variable* var = scope_->Var(vname); if (!var->IsInitialized()) { framework::VarDesc* var_desc = block_->FindVar(vname); @@ -50,7 +53,10 @@ class Tracer { } } } - for (const std::string& vname : op_desc->OutputArgumentNames()) { + + *op->output_vars_ = outputs; + for (auto output : outputs) { + const std::string vname = output->var_desc_->Name(); 
framework::Variable* var = scope_->Var(vname); if (!var->IsInitialized()) { framework::VarDesc* var_desc = block_->FindVar(vname); @@ -60,6 +66,7 @@ class Tracer { LOG(ERROR) << "tracer doesn't support yet"; } } + output->pre_op_ = op; } op_base->Run(*scope_, platform::CPUPlace()); } diff --git a/paddle/fluid/pybind/imperative.h b/paddle/fluid/pybind/imperative.h index bf01ed3a2558de66a6c8f7cde7ed7059c1ac75d4..5834b83df904bda8389782c5dd65e6f4a23ee4dd 100644 --- a/paddle/fluid/pybind/imperative.h +++ b/paddle/fluid/pybind/imperative.h @@ -37,6 +37,11 @@ class PyLayer : public imperative::Layer { } }; +class PyOpBase : public imperative::OpBase { + public: + using imperative::OpBase::OpBase; // Inherit constructors +}; + void BindTracer(pybind11::module* m); } // namespace pybind diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index cd0e550b90e4db967db9577a069e44a8efafab54..656d28eb2a1b30eff3793a9fc42ebe06fff8d419 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -111,8 +111,9 @@ PYBIND11_MODULE(core, m) { }, py::return_value_policy::reference); - py::class_<imperative::OpBase>(m, "OpBase", - R"DOC()DOC") + py::class_<imperative::OpBase, PyOpBase>(m, "OpBase", + R"DOC()DOC") + .def(py::init<>()) .def_property( "desc", [](const imperative::OpBase &self) { return self.op_desc_; }, [](imperative::OpBase &self, framework::OpDesc *op_desc) { diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index 51da8260e253e6b1b7aae67df847305cae5bf8fa..d4ca6901d5a51b5947fc7a4fcfaaddfaa94c4a40 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -563,6 +563,7 @@ class Operator(core.OpBase): inputs=None, outputs=None, attrs=None): + core.OpBase.__init__(self) self.block = block self.desc = desc # note: not add self.attrs here: @@ -602,33 +603,32 @@ class Operator(core.OpBase): return True return False - self.inputs = [] if not inputs else inputs - for in_proto in proto.inputs: - found = find_name(self.inputs, 
in_proto.name) - assert found or in_proto.dispensable, "Input {} not found".format( - in_proto.name) - - if found: - in_args = self.inputs[in_proto.name] - if not isinstance(in_args, list): - in_args = [in_args] - if not in_proto.duplicable and len(in_args) > 1: - raise ValueError( - "Input %s expects only one input, but %d are given." % - (in_proto.name, len(in_args))) - in_arg_names = [] - for arg in in_args: - if isinstance(arg, six.string_types): - in_arg_names.append(arg) - elif isinstance(arg, six.binary_type): - in_arg_names.append(arg.decode()) - else: - in_arg_names.append(cpt.to_text(arg.name)) - self.desc.set_input(in_proto.name, in_arg_names) - else: - self.desc.set_input(in_proto.name, []) + if inputs is not None: + for in_proto in proto.inputs: + found = find_name(inputs, in_proto.name) + assert found or in_proto.dispensable, "Input {} not found".format( + in_proto.name) + + if found: + in_args = inputs[in_proto.name] + if not isinstance(in_args, list): + in_args = [in_args] + if not in_proto.duplicable and len(in_args) > 1: + raise ValueError( + "Input %s expects only one input, but %d are given." 
+ % (in_proto.name, len(in_args))) + in_arg_names = [] + for arg in in_args: + if isinstance(arg, six.string_types): + in_arg_names.append(arg) + elif isinstance(arg, six.binary_type): + in_arg_names.append(arg.decode()) + else: + in_arg_names.append(cpt.to_text(arg.name)) + self.desc.set_input(in_proto.name, in_arg_names) + else: + self.desc.set_input(in_proto.name, []) - self.outputs = [] if not outputs else outputs if outputs is not None: given = set() need = set() @@ -657,6 +657,21 @@ class Operator(core.OpBase): arg.op = self self.desc.set_output(out_proto.name, out_arg_names) + input_vars = [] + for inp in inputs.values(): + if isinstance(inp, Variable): + input_vars.append(inp) + elif isinstance(inp, list): + input_vars.extend(inp[:]) + self.inputs = input_vars + output_vars = [] + for out in outputs.values(): + if isinstance(out, Variable): + output_vars.append(out) + elif isinstance(out, list): + output_vars.extend(out[:]) + self.outputs = output_vars + if op_attrs is not None: if not isinstance(op_attrs, dict): raise TypeError("'attrs' should be a dict.")