提交 81383916 编写于 作者: X Xin Pan

add OpBase and unify with VarBase

test=develop
上级 f6f06924
cc_library(layer SRCS layer.cc) cc_library(layer SRCS layer.cc DEPS proto_desc)
cc_library(tracer SRCS tracer.cc DEPS proto_desc) cc_library(tracer SRCS tracer.cc DEPS proto_desc)
cc_library(engine SRCS engine.cc) cc_library(engine SRCS engine.cc)
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#pragma once #pragma once
#include <vector> #include <vector>
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/var_desc.h" #include "paddle/fluid/framework/var_desc.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
...@@ -22,21 +23,28 @@ ...@@ -22,21 +23,28 @@
namespace paddle { namespace paddle {
namespace imperative { namespace imperative {
class VariableBase { class VarBase {
public: public:
VariableBase() {} VarBase() {}
virtual ~VariableBase() {} virtual ~VarBase() {}
framework::VarDesc* var_desc_; framework::VarDesc* var_desc_;
}; };
class OpBase {
public:
OpBase() {}
virtual ~OpBase() {}
framework::OpDesc* op_desc_;
};
class Layer { class Layer {
public: public:
virtual ~Layer() {} virtual ~Layer() {}
virtual std::vector<VariableBase> Forward( virtual std::vector<VarBase> Forward(const std::vector<VarBase>& inputs) {
const std::vector<VariableBase>& inputs) { std::vector<VarBase> vars;
std::vector<VariableBase> vars;
return vars; return vars;
} }
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#pragma once #pragma once
#include <map>
#include <string> #include <string>
#include <vector> #include <vector>
...@@ -21,6 +22,7 @@ ...@@ -21,6 +22,7 @@
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/imperative/engine.h" #include "paddle/fluid/imperative/engine.h"
#include "paddle/fluid/imperative/layer.h"
namespace paddle { namespace paddle {
namespace imperative { namespace imperative {
...@@ -29,11 +31,13 @@ class Tracer { ...@@ -29,11 +31,13 @@ class Tracer {
public: public:
Tracer() {} Tracer() {}
void Trace(framework::OpDesc* op_desc) { void Trace(OpBase* op, const std::map<std::string, VarBase*>& inputs,
const std::map<std::string, VarBase*>& outputs) {
framework::OpDesc* op_desc = op->op_desc_;
LOG(ERROR) << "tracer tracing " << op_desc->Type(); LOG(ERROR) << "tracer tracing " << op_desc->Type();
op_desc->InferShape(*block_); op_desc->InferShape(*block_);
op_desc->InferVarType(block_); op_desc->InferVarType(block_);
std::unique_ptr<framework::OperatorBase> op = std::unique_ptr<framework::OperatorBase> op_base =
framework::OpRegistry::CreateOp(*op_desc); framework::OpRegistry::CreateOp(*op_desc);
for (const std::string& vname : op_desc->InputArgumentNames()) { for (const std::string& vname : op_desc->InputArgumentNames()) {
framework::Variable* var = scope_->Var(vname); framework::Variable* var = scope_->Var(vname);
...@@ -57,7 +61,7 @@ class Tracer { ...@@ -57,7 +61,7 @@ class Tracer {
} }
} }
} }
op->Run(*scope_, platform::CPUPlace()); op_base->Run(*scope_, platform::CPUPlace());
} }
void SetScope(framework::Scope* scope) { scope_ = scope; } void SetScope(framework::Scope* scope) { scope_ = scope; }
......
...@@ -26,9 +26,9 @@ class PyLayer : public imperative::Layer { ...@@ -26,9 +26,9 @@ class PyLayer : public imperative::Layer {
public: public:
using imperative::Layer::Layer; // Inherit constructors using imperative::Layer::Layer; // Inherit constructors
std::vector<imperative::VariableBase> Forward( std::vector<imperative::VarBase> Forward(
const std::vector<imperative::VariableBase>& inputs) override { const std::vector<imperative::VarBase>& inputs) override {
PYBIND11_OVERLOAD(std::vector<imperative::VariableBase>, Layer, Forward, PYBIND11_OVERLOAD(std::vector<imperative::VarBase>, Layer, Forward,
inputs); // NOLINT inputs); // NOLINT
} }
......
...@@ -101,21 +101,30 @@ PYBIND11_MODULE(core, m) { ...@@ -101,21 +101,30 @@ PYBIND11_MODULE(core, m) {
BindException(&m); BindException(&m);
py::class_<imperative::VariableBase>(m, "VariableBase", py::class_<imperative::VarBase>(m, "VarBase",
R"DOC()DOC") R"DOC()DOC")
.def_property( .def_property(
"desc", "desc",
[](const imperative::VariableBase &self) { return self.var_desc_; }, [](const imperative::VarBase &self) { return self.var_desc_; },
[](imperative::VariableBase &self, framework::VarDesc *var_desc) { [](imperative::VarBase &self, framework::VarDesc *var_desc) {
self.var_desc_ = var_desc; self.var_desc_ = var_desc;
}, },
py::return_value_policy::reference); py::return_value_policy::reference);
py::class_<imperative::OpBase>(m, "OpBase",
R"DOC()DOC")
.def_property(
"desc", [](const imperative::OpBase &self) { return self.op_desc_; },
[](imperative::OpBase &self, framework::OpDesc *op_desc) {
self.op_desc_ = op_desc;
},
py::return_value_policy::reference);
py::class_<imperative::Layer, PyLayer /* <--- trampoline*/> layer(m, "Layer"); py::class_<imperative::Layer, PyLayer /* <--- trampoline*/> layer(m, "Layer");
layer.def(py::init<>()) layer.def(py::init<>())
.def("forward", .def("forward",
[](imperative::Layer &self, [](imperative::Layer &self,
const std::vector<imperative::VariableBase> &inputs) { const std::vector<imperative::VarBase> &inputs) {
return self.Forward(inputs); return self.Forward(inputs);
}) })
.def("backward", &imperative::Layer::Backward); .def("backward", &imperative::Layer::Backward);
......
...@@ -212,7 +212,7 @@ def _debug_string_(proto, throw_on_error=True): ...@@ -212,7 +212,7 @@ def _debug_string_(proto, throw_on_error=True):
return proto.__str__() return proto.__str__()
class Variable(core.VariableBase): class Variable(core.VarBase):
""" """
In Fluid, every input and output of an operator is a variable. In most In Fluid, every input and output of an operator is a variable. In most
cases, variables are used for holding different kinds of data or training cases, variables are used for holding different kinds of data or training
...@@ -507,7 +507,7 @@ class OpProtoHolder(object): ...@@ -507,7 +507,7 @@ class OpProtoHolder(object):
} }
class Operator(object): class Operator(core.OpBase):
""" """
In Fluid, all the operation are represented by Operator, and Operator In Fluid, all the operation are represented by Operator, and Operator
is regarded as a build in an instruction of a Block. Users can use the is regarded as a build in an instruction of a Block. Users can use the
...@@ -602,20 +602,20 @@ class Operator(object): ...@@ -602,20 +602,20 @@ class Operator(object):
return True return True
return False return False
if inputs is not None: self.inputs = [] if not inputs else inputs
for in_proto in proto.inputs: for in_proto in proto.inputs:
found = find_name(inputs, in_proto.name) found = find_name(self.inputs, in_proto.name)
assert found or in_proto.dispensable, "Input {} not found".format( assert found or in_proto.dispensable, "Input {} not found".format(
in_proto.name) in_proto.name)
if found: if found:
in_args = inputs[in_proto.name] in_args = self.inputs[in_proto.name]
if not isinstance(in_args, list): if not isinstance(in_args, list):
in_args = [in_args] in_args = [in_args]
if not in_proto.duplicable and len(in_args) > 1: if not in_proto.duplicable and len(in_args) > 1:
raise ValueError( raise ValueError(
"Input %s expects only one input, but %d are given." "Input %s expects only one input, but %d are given." %
% (in_proto.name, len(in_args))) (in_proto.name, len(in_args)))
in_arg_names = [] in_arg_names = []
for arg in in_args: for arg in in_args:
if isinstance(arg, six.string_types): if isinstance(arg, six.string_types):
...@@ -628,6 +628,7 @@ class Operator(object): ...@@ -628,6 +628,7 @@ class Operator(object):
else: else:
self.desc.set_input(in_proto.name, []) self.desc.set_input(in_proto.name, [])
self.outputs = [] if not outputs else outputs
if outputs is not None: if outputs is not None:
given = set() given = set()
need = set() need = set()
...@@ -1222,7 +1223,8 @@ class Block(object): ...@@ -1222,7 +1223,8 @@ class Block(object):
if _in_imperative_mode(): if _in_imperative_mode():
op_desc = core.OpDesc() op_desc = core.OpDesc()
op = Operator(block=self, desc=op_desc, *args, **kwargs) op = Operator(block=self, desc=op_desc, *args, **kwargs)
_imperative_tracer().trace(op.desc) sys.stderr.write('%s %s!!!\n' % (type(op.inputs), type(op.outputs)))
_imperative_tracer().trace(op, op.inputs, op.outputs)
return return
op_desc = self.desc.append_op() op_desc = self.desc.append_op()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册