Commit 81383916 authored by Xin Pan

add OpBase and unify with VarBase

test=develop
Parent f6f06924
cc_library(layer SRCS layer.cc)
cc_library(layer SRCS layer.cc DEPS proto_desc)
cc_library(tracer SRCS tracer.cc DEPS proto_desc)
cc_library(engine SRCS engine.cc)
@@ -15,6 +15,7 @@
#pragma once
#include <vector>
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/var_desc.h"
#include "paddle/fluid/platform/enforce.h"
@@ -22,21 +23,28 @@
namespace paddle {
namespace imperative {
class VariableBase {
class VarBase {
public:
VariableBase() {}
virtual ~VariableBase() {}
VarBase() {}
virtual ~VarBase() {}
framework::VarDesc* var_desc_;
};
class OpBase {
public:
OpBase() {}
virtual ~OpBase() {}
framework::OpDesc* op_desc_;
};
class Layer {
public:
virtual ~Layer() {}
virtual std::vector<VariableBase> Forward(
const std::vector<VariableBase>& inputs) {
std::vector<VariableBase> vars;
virtual std::vector<VarBase> Forward(const std::vector<VarBase>& inputs) {
std::vector<VarBase> vars;
return vars;
}
......
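A quick illustrative sketch (not part of this commit; the helper and variable names are hypothetical) of how the renamed VarBase and the new OpBase are intended to be used on the C++ side: each imperative object simply carries a pointer to a descriptor owned by a block.

#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/imperative/layer.h"

// Hypothetical helper, for illustration only: wire an OpBase and its
// VarBase operands to descriptors owned by `block`.
void BuildReluOp(paddle::framework::BlockDesc* block,
                 paddle::imperative::VarBase* x,
                 paddle::imperative::VarBase* out,
                 paddle::imperative::OpBase* op) {
  x->var_desc_ = block->Var("x");      // a VarBase wraps a VarDesc owned by the block
  out->var_desc_ = block->Var("out");
  op->op_desc_ = block->AppendOp();    // an OpBase wraps an OpDesc owned by the block
  op->op_desc_->SetType("relu");
  op->op_desc_->SetInput("X", {"x"});
  op->op_desc_->SetOutput("Out", {"out"});
}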
@@ -14,6 +14,7 @@
#pragma once
#include <map>
#include <string>
#include <vector>
@@ -21,6 +22,7 @@
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/imperative/engine.h"
#include "paddle/fluid/imperative/layer.h"
namespace paddle {
namespace imperative {
@@ -29,11 +31,13 @@ class Tracer {
public:
Tracer() {}
void Trace(framework::OpDesc* op_desc) {
void Trace(OpBase* op, const std::map<std::string, VarBase*>& inputs,
const std::map<std::string, VarBase*>& outputs) {
framework::OpDesc* op_desc = op->op_desc_;
LOG(ERROR) << "tracer tracing " << op_desc->Type();
op_desc->InferShape(*block_);
op_desc->InferVarType(block_);
std::unique_ptr<framework::OperatorBase> op =
std::unique_ptr<framework::OperatorBase> op_base =
framework::OpRegistry::CreateOp(*op_desc);
for (const std::string& vname : op_desc->InputArgumentNames()) {
framework::Variable* var = scope_->Var(vname);
@@ -57,7 +61,7 @@ class Tracer {
}
}
}
op->Run(*scope_, platform::CPUPlace());
op_base->Run(*scope_, platform::CPUPlace());
}
void SetScope(framework::Scope* scope) { scope_ = scope; }
......
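A matching sketch of the new Trace interface (again illustrative, not from the commit): the tracer now receives the OpBase plus name-to-VarBase maps instead of a bare OpDesc. The SetBlock call below is an assumption; the diff only shows SetScope and the block_ member.

#include <map>
#include <string>
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/imperative/tracer.h"

// Illustration only: trace one op with the new signature.
void TraceOnce(paddle::imperative::Tracer* tracer,
               paddle::framework::Scope* scope,
               paddle::imperative::OpBase* op,
               paddle::imperative::VarBase* x,
               paddle::imperative::VarBase* out) {
  tracer->SetScope(scope);
  // tracer->SetBlock(block);  // assumed setter for block_, analogous to SetScope
  std::map<std::string, paddle::imperative::VarBase*> ins{{"X", x}};
  std::map<std::string, paddle::imperative::VarBase*> outs{{"Out", out}};
  // Trace() runs InferShape/InferVarType on op->op_desc_, creates the op's
  // variables in the scope, then executes op_base->Run(*scope, CPUPlace()).
  tracer->Trace(op, ins, outs);
}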
@@ -26,9 +26,9 @@ class PyLayer : public imperative::Layer {
public:
using imperative::Layer::Layer; // Inherit constructors
std::vector<imperative::VariableBase> Forward(
const std::vector<imperative::VariableBase>& inputs) override {
PYBIND11_OVERLOAD(std::vector<imperative::VariableBase>, Layer, Forward,
std::vector<imperative::VarBase> Forward(
const std::vector<imperative::VarBase>& inputs) override {
PYBIND11_OVERLOAD(std::vector<imperative::VarBase>, Layer, Forward,
inputs); // NOLINT
}
......
@@ -101,21 +101,30 @@ PYBIND11_MODULE(core, m) {
BindException(&m);
py::class_<imperative::VariableBase>(m, "VariableBase",
R"DOC()DOC")
py::class_<imperative::VarBase>(m, "VarBase",
R"DOC()DOC")
.def_property(
"desc",
[](const imperative::VariableBase &self) { return self.var_desc_; },
[](imperative::VariableBase &self, framework::VarDesc *var_desc) {
[](const imperative::VarBase &self) { return self.var_desc_; },
[](imperative::VarBase &self, framework::VarDesc *var_desc) {
self.var_desc_ = var_desc;
},
py::return_value_policy::reference);
py::class_<imperative::OpBase>(m, "OpBase",
R"DOC()DOC")
.def_property(
"desc", [](const imperative::OpBase &self) { return self.op_desc_; },
[](imperative::OpBase &self, framework::OpDesc *op_desc) {
self.op_desc_ = op_desc;
},
py::return_value_policy::reference);
py::class_<imperative::Layer, PyLayer /* <--- trampoline*/> layer(m, "Layer");
layer.def(py::init<>())
.def("forward",
[](imperative::Layer &self,
const std::vector<imperative::VariableBase> &inputs) {
const std::vector<imperative::VarBase> &inputs) {
return self.Forward(inputs);
})
.def("backward", &imperative::Layer::Backward);
......
@@ -212,7 +212,7 @@ def _debug_string_(proto, throw_on_error=True):
return proto.__str__()
class Variable(core.VariableBase):
class Variable(core.VarBase):
"""
In Fluid, every input and output of an operator is a variable. In most
cases, variables are used for holding different kinds of data or training
@@ -507,7 +507,7 @@ class OpProtoHolder(object):
}
class Operator(object):
class Operator(core.OpBase):
"""
In Fluid, all the operation are represented by Operator, and Operator
is regarded as a build in an instruction of a Block. Users can use the
@@ -602,32 +602,33 @@ class Operator(object):
return True
return False
if inputs is not None:
for in_proto in proto.inputs:
found = find_name(inputs, in_proto.name)
assert found or in_proto.dispensable, "Input {} not found".format(
in_proto.name)
if found:
in_args = inputs[in_proto.name]
if not isinstance(in_args, list):
in_args = [in_args]
if not in_proto.duplicable and len(in_args) > 1:
raise ValueError(
"Input %s expects only one input, but %d are given."
% (in_proto.name, len(in_args)))
in_arg_names = []
for arg in in_args:
if isinstance(arg, six.string_types):
in_arg_names.append(arg)
elif isinstance(arg, six.binary_type):
in_arg_names.append(arg.decode())
else:
in_arg_names.append(cpt.to_text(arg.name))
self.desc.set_input(in_proto.name, in_arg_names)
else:
self.desc.set_input(in_proto.name, [])
self.inputs = [] if not inputs else inputs
for in_proto in proto.inputs:
found = find_name(self.inputs, in_proto.name)
assert found or in_proto.dispensable, "Input {} not found".format(
in_proto.name)
if found:
in_args = self.inputs[in_proto.name]
if not isinstance(in_args, list):
in_args = [in_args]
if not in_proto.duplicable and len(in_args) > 1:
raise ValueError(
"Input %s expects only one input, but %d are given." %
(in_proto.name, len(in_args)))
in_arg_names = []
for arg in in_args:
if isinstance(arg, six.string_types):
in_arg_names.append(arg)
elif isinstance(arg, six.binary_type):
in_arg_names.append(arg.decode())
else:
in_arg_names.append(cpt.to_text(arg.name))
self.desc.set_input(in_proto.name, in_arg_names)
else:
self.desc.set_input(in_proto.name, [])
self.outputs = [] if not outputs else outputs
if outputs is not None:
given = set()
need = set()
@@ -1222,7 +1223,8 @@ class Block(object):
if _in_imperative_mode():
op_desc = core.OpDesc()
op = Operator(block=self, desc=op_desc, *args, **kwargs)
_imperative_tracer().trace(op.desc)
sys.stderr.write('%s %s!!!\n' % (type(op.inputs), type(op.outputs)))
_imperative_tracer().trace(op, op.inputs, op.outputs)
return
op_desc = self.desc.append_op()
......