From 8138391631aed37b2832b66ede3f70a9aff1ea0e Mon Sep 17 00:00:00 2001 From: Xin Pan Date: Wed, 28 Nov 2018 16:13:13 +0800 Subject: [PATCH] add OpBase and unify with VarBase test=develop --- paddle/fluid/imperative/CMakeLists.txt | 2 +- paddle/fluid/imperative/layer.h | 20 ++++++--- paddle/fluid/imperative/tracer.h | 10 +++-- paddle/fluid/pybind/imperative.h | 6 +-- paddle/fluid/pybind/pybind.cc | 19 ++++++--- python/paddle/fluid/framework.py | 58 +++++++++++++------------- 6 files changed, 69 insertions(+), 46 deletions(-) diff --git a/paddle/fluid/imperative/CMakeLists.txt b/paddle/fluid/imperative/CMakeLists.txt index dff80dff0..fb57eca65 100644 --- a/paddle/fluid/imperative/CMakeLists.txt +++ b/paddle/fluid/imperative/CMakeLists.txt @@ -1,3 +1,3 @@ -cc_library(layer SRCS layer.cc) +cc_library(layer SRCS layer.cc DEPS proto_desc) cc_library(tracer SRCS tracer.cc DEPS proto_desc) cc_library(engine SRCS engine.cc) diff --git a/paddle/fluid/imperative/layer.h b/paddle/fluid/imperative/layer.h index a83535af9..cf352ebe5 100644 --- a/paddle/fluid/imperative/layer.h +++ b/paddle/fluid/imperative/layer.h @@ -15,6 +15,7 @@ #pragma once #include +#include "paddle/fluid/framework/op_desc.h" #include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/var_desc.h" #include "paddle/fluid/platform/enforce.h" @@ -22,21 +23,28 @@ namespace paddle { namespace imperative { -class VariableBase { +class VarBase { public: - VariableBase() {} - virtual ~VariableBase() {} + VarBase() {} + virtual ~VarBase() {} framework::VarDesc* var_desc_; }; +class OpBase { + public: + OpBase() {} + virtual ~OpBase() {} + + framework::OpDesc* op_desc_; +}; + class Layer { public: virtual ~Layer() {} - virtual std::vector Forward( - const std::vector& inputs) { - std::vector vars; + virtual std::vector Forward(const std::vector& inputs) { + std::vector vars; return vars; } diff --git a/paddle/fluid/imperative/tracer.h b/paddle/fluid/imperative/tracer.h index 8a7a2d700..c5dea9686 100644 --- a/paddle/fluid/imperative/tracer.h +++ b/paddle/fluid/imperative/tracer.h @@ -14,6 +14,7 @@ #pragma once +#include #include #include @@ -21,6 +22,7 @@ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/scope.h" #include "paddle/fluid/imperative/engine.h" +#include "paddle/fluid/imperative/layer.h" namespace paddle { namespace imperative { @@ -29,11 +31,13 @@ class Tracer { public: Tracer() {} - void Trace(framework::OpDesc* op_desc) { + void Trace(OpBase* op, const std::map& inputs, + const std::map& outputs) { + framework::OpDesc* op_desc = op->op_desc_; LOG(ERROR) << "tracer tracing " << op_desc->Type(); op_desc->InferShape(*block_); op_desc->InferVarType(block_); - std::unique_ptr op = + std::unique_ptr op_base = framework::OpRegistry::CreateOp(*op_desc); for (const std::string& vname : op_desc->InputArgumentNames()) { framework::Variable* var = scope_->Var(vname); @@ -57,7 +61,7 @@ class Tracer { } } } - op->Run(*scope_, platform::CPUPlace()); + op_base->Run(*scope_, platform::CPUPlace()); } void SetScope(framework::Scope* scope) { scope_ = scope; } diff --git a/paddle/fluid/pybind/imperative.h b/paddle/fluid/pybind/imperative.h index 9a558fbdb..bf01ed3a2 100644 --- a/paddle/fluid/pybind/imperative.h +++ b/paddle/fluid/pybind/imperative.h @@ -26,9 +26,9 @@ class PyLayer : public imperative::Layer { public: using imperative::Layer::Layer; // Inherit constructors - std::vector Forward( - const std::vector& inputs) override { - PYBIND11_OVERLOAD(std::vector, Layer, Forward, + std::vector Forward( + const std::vector& inputs) override { + PYBIND11_OVERLOAD(std::vector, Layer, Forward, inputs); // NOLINT } diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index 3cf1ec34a..cd0e550b9 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -101,21 +101,30 @@ PYBIND11_MODULE(core, m) { BindException(&m); - py::class_(m, "VariableBase", - R"DOC()DOC") + py::class_(m, "VarBase", + R"DOC()DOC") .def_property( "desc", - [](const imperative::VariableBase &self) { return self.var_desc_; }, - [](imperative::VariableBase &self, framework::VarDesc *var_desc) { + [](const imperative::VarBase &self) { return self.var_desc_; }, + [](imperative::VarBase &self, framework::VarDesc *var_desc) { self.var_desc_ = var_desc; }, py::return_value_policy::reference); + py::class_(m, "OpBase", + R"DOC()DOC") + .def_property( + "desc", [](const imperative::OpBase &self) { return self.op_desc_; }, + [](imperative::OpBase &self, framework::OpDesc *op_desc) { + self.op_desc_ = op_desc; + }, + py::return_value_policy::reference); + py::class_ layer(m, "Layer"); layer.def(py::init<>()) .def("forward", [](imperative::Layer &self, - const std::vector &inputs) { + const std::vector &inputs) { return self.Forward(inputs); }) .def("backward", &imperative::Layer::Backward); diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index 3d3263e7c..51da8260e 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -212,7 +212,7 @@ def _debug_string_(proto, throw_on_error=True): return proto.__str__() -class Variable(core.VariableBase): +class Variable(core.VarBase): """ In Fluid, every input and output of an operator is a variable. In most cases, variables are used for holding different kinds of data or training @@ -507,7 +507,7 @@ class OpProtoHolder(object): } -class Operator(object): +class Operator(core.OpBase): """ In Fluid, all the operation are represented by Operator, and Operator is regarded as a build in an instruction of a Block. Users can use the @@ -602,32 +602,33 @@ class Operator(object): return True return False - if inputs is not None: - for in_proto in proto.inputs: - found = find_name(inputs, in_proto.name) - assert found or in_proto.dispensable, "Input {} not found".format( - in_proto.name) - - if found: - in_args = inputs[in_proto.name] - if not isinstance(in_args, list): - in_args = [in_args] - if not in_proto.duplicable and len(in_args) > 1: - raise ValueError( - "Input %s expects only one input, but %d are given." - % (in_proto.name, len(in_args))) - in_arg_names = [] - for arg in in_args: - if isinstance(arg, six.string_types): - in_arg_names.append(arg) - elif isinstance(arg, six.binary_type): - in_arg_names.append(arg.decode()) - else: - in_arg_names.append(cpt.to_text(arg.name)) - self.desc.set_input(in_proto.name, in_arg_names) - else: - self.desc.set_input(in_proto.name, []) + self.inputs = [] if not inputs else inputs + for in_proto in proto.inputs: + found = find_name(self.inputs, in_proto.name) + assert found or in_proto.dispensable, "Input {} not found".format( + in_proto.name) + + if found: + in_args = self.inputs[in_proto.name] + if not isinstance(in_args, list): + in_args = [in_args] + if not in_proto.duplicable and len(in_args) > 1: + raise ValueError( + "Input %s expects only one input, but %d are given." % + (in_proto.name, len(in_args))) + in_arg_names = [] + for arg in in_args: + if isinstance(arg, six.string_types): + in_arg_names.append(arg) + elif isinstance(arg, six.binary_type): + in_arg_names.append(arg.decode()) + else: + in_arg_names.append(cpt.to_text(arg.name)) + self.desc.set_input(in_proto.name, in_arg_names) + else: + self.desc.set_input(in_proto.name, []) + self.outputs = [] if not outputs else outputs if outputs is not None: given = set() need = set() @@ -1222,7 +1223,8 @@ class Block(object): if _in_imperative_mode(): op_desc = core.OpDesc() op = Operator(block=self, desc=op_desc, *args, **kwargs) - _imperative_tracer().trace(op.desc) + sys.stderr.write('%s %s!!!\n' % (type(op.inputs), type(op.outputs))) + _imperative_tracer().trace(op, op.inputs, op.outputs) return op_desc = self.desc.append_op() -- GitLab