Commit ad8fa77c authored by Yu Yang

Merge branch 'develop' into feature/add_some_skeletons_of_ops

#pragma once
#include <algorithm>
#include <atomic>
#include <type_traits>
#include <unordered_map>
#include <unordered_set>
@@ -214,29 +215,61 @@ class OpRegistry {
  }

  static OperatorPtr CreateOp(const OpDesc& op_desc) {
    //! Create an OperatorPtr by type.
    std::string op_type = op_desc.type();
    OperatorPtr op(creators().at(op_type)());
    //! Fill the op's data members. We do not use a constructor because it
    //! would be noisy for Op developers.
    const OpProto& op_proto = protos().at(op_type);
    op->type_ = op_desc.type();
    // Set op's inputs_ from the desc.
    op->inputs_.reserve((size_t)op_desc.inputs_size());
    std::copy(op_desc.inputs().begin(), op_desc.inputs().end(),
              std::back_inserter(op->inputs_));
    // Set op's outputs_ from the desc.
    op->outputs_.reserve((size_t)op_desc.outputs_size());
    std::copy(op_desc.outputs().begin(), op_desc.outputs().end(),
              std::back_inserter(op->outputs_));
    //! Fill and validate attrs.
    for (auto& attr : op_desc.attrs()) {
      op->attrs_[attr.name()] = AttrTypeHelper::GetAttrValue(attr);
    }
    op_checkers().at(op_type).Check(op->attrs_);
    //! Convert temporary variable names to unique variable names.
    GenerateTempVariableName(op.get());
    // Set the argument offsets stored in the op.
    CreateInOutOffsetMap(op, op_proto);
    //! Run the op's custom Init for complex Ops; for simple Ops, Init does
    //! nothing.
    op->Init();
    return op;
  }

  // Initialize op->in_out_idxs_ to accelerate argument offset lookup.
  static void CreateInOutOffsetMap(OperatorPtr op, const OpProto& proto) {
    op->CreateInOutOffsetMap(proto);
  }
  static std::unordered_map<std::string, OpProto>& protos() {
    static std::unordered_map<std::string, OpProto> protos_;
    return protos_;
  }

 private:
  static void GenerateTempVariableName(OperatorBase* op) {
    static std::atomic<size_t> gUniqId(0UL);
    for (auto& outname : op->outputs_) {
      if (outname == OperatorBase::TMP_VAR_NAME()) {
        outname += op->type_;
        outname += "@";
        outname += std::to_string(gUniqId.fetch_add(1));
      }
    }
  }
  static std::unordered_map<std::string, OpCreator>& creators() {
    static std::unordered_map<std::string, OpCreator> creators_;
    return creators_;
...
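The CreateOp path above copies the desc into the operator and renames `@TEMP@` placeholder outputs. A small Python sketch of that renaming rule (illustrative only, not part of this commit; the helper name is hypothetical):

```python
import itertools

_uniq_id = itertools.count(0)  # mirrors the C++ static atomic counter

def generate_temp_variable_names(op_type, outputs, tmp_var_name="@TEMP@"):
    # Each "@TEMP@" placeholder becomes "@TEMP@<op_type>@<unique id>",
    # matching the string appends in GenerateTempVariableName.
    return [name + op_type + "@" + str(next(_uniq_id))
            if name == tmp_var_name else name
            for name in outputs]

assert generate_temp_variable_names("add_two", ["out", "@TEMP@"]) == \
    ["out", "@TEMP@add_two@0"]
```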
@@ -12,30 +12,86 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <algorithm>
#include "paddle/framework/operator.h" #include "paddle/framework/operator.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
void OperatorBase::CreateInOutOffsetMap(const OpProto& proto) {
  PADDLE_ENFORCE(in_out_idxs_.empty(),
                 "duplicate call of CreateInOutOffsetMap");
  for (int i = 0; i < proto.inputs_size(); i++) {
    const auto& name = proto.inputs()[i].name();
    in_out_idxs_[name] = i;
  }
  for (int i = 0; i < proto.outputs_size(); i++) {
    const auto& name = proto.outputs()[i].name();
    in_out_idxs_[name] = i;
  }
}
const std::string& OperatorBase::Input(const std::string& name) const {
  auto it = in_out_idxs_.find(name);
  PADDLE_ENFORCE(it != in_out_idxs_.end(), "no key [%s] in in_out_idxs_", name);

  if (attrs_.count("input_format") == 0) {
    return inputs_[it->second];
  } else {
    const auto& input_format = GetAttr<std::vector<int>>("input_format");
    int idx = input_format[it->second];
    return inputs_.at(idx);
  }
}

std::vector<std::string> OperatorBase::Inputs(const std::string& name) const {
  auto input_format = GetAttr<std::vector<int>>("input_format");
  auto offset = in_out_idxs_.at(name);

  return std::vector<std::string>{
      inputs_.begin() + input_format.at(offset),
      inputs_.begin() + input_format.at(offset + 1)};
}

const std::string& OperatorBase::Output(const std::string& name) const {
  auto it = in_out_idxs_.find(name);
  PADDLE_ENFORCE(it != in_out_idxs_.end(), "no key [%s] in in_out_idxs_", name);

  if (attrs_.count("output_format") == 0) {
    return outputs_[it->second];
  } else {
    const auto& output_format = GetAttr<std::vector<int>>("output_format");
    int idx = output_format[it->second];
    return outputs_.at(idx);
  }
}

std::vector<std::string> OperatorBase::Outputs(const std::string& name) const {
  auto output_format = GetAttr<std::vector<int>>("output_format");
  auto offset = in_out_idxs_.at(name);

  return std::vector<std::string>{
      outputs_.begin() + output_format.at(offset),
      outputs_.begin() + output_format.at(offset + 1)};
}
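The lookup scheme is easiest to see with a worked example: in_out_idxs_ maps an argument name to its slot in the OpProto, and input_format (when present) holds cumulative offsets, so a multiple-variable argument occupies the slice [format[i], format[i+1]) of inputs_. A small Python model of the same logic (illustrative, not the C++ API):

```python
# An fc-like op: X and W are multiple-variable arguments, b is single.
inputs = ["x1", "x2", "x3", "w1", "w2", "w3", "b0"]
in_out_idxs = {"X": 0, "W": 1, "b": 2}   # argument name -> proto slot
input_format = [0, 3, 6, 7]              # cumulative slice boundaries

def get_inputs(name):
    # Mirrors OperatorBase::Inputs: an argument's variables are the
    # slice between two consecutive boundaries.
    i = in_out_idxs[name]
    return inputs[input_format[i]:input_format[i + 1]]

def get_input(name):
    # Mirrors OperatorBase::Input when input_format is present: the
    # first variable of the argument's slice.
    i = in_out_idxs[name]
    return inputs[input_format[i]]

assert get_inputs("X") == ["x1", "x2", "x3"]
assert get_inputs("W") == ["w1", "w2", "w3"]
assert get_input("b") == "b0"
```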
std::string OperatorBase::DebugString() const {
  std::stringstream ss;
  ss << "Op(" << type_ << "), inputs:(";
  for (size_t i = 0; i < inputs_.size(); ++i) {
    ss << inputs_[i];
    if (i != inputs_.size() - 1) {
      ss << ", ";
    }
  }
  ss << "), outputs:(";
  for (size_t i = 0; i < outputs_.size(); ++i) {
    ss << outputs_[i];
    if (i != outputs_.size() - 1) {
      ss << ", ";
    }
  }
  ss << ").";
  return ss.str();
}
...
@@ -14,18 +14,20 @@ limitations under the License. */
#pragma once

#include <algorithm>
#include <boost/variant.hpp>
#include <string>
#include <unordered_map>
#include <vector>

#include "paddle/framework/attr_checker.h"
#include "paddle/framework/op_desc.pb.h"
#include "paddle/framework/op_proto.pb.h"
#include "paddle/framework/scope.h"
#include "paddle/framework/tensor.h"
#include "paddle/platform/device_context.h"
#include "paddle/platform/place.h"
#include "paddle/utils/Error.h"
namespace paddle {
namespace framework {

@@ -39,6 +41,13 @@ using OperatorPtr = std::shared_ptr<OperatorBase>;
 */
class OperatorBase {
 public:
  /// If a variable is an empty variable, this name will be used.
  static std::string EMPTY_VAR_NAME() { return "@EMPTY@"; }

  /// If a variable is a temporary variable, this name is set in Python,
  /// and it is converted to a unique name in the scope after the OpCreator
  /// runs.
  static std::string TMP_VAR_NAME() { return "@TEMP@"; }
  virtual ~OperatorBase() {}

  template <typename T>
@@ -62,11 +71,72 @@ class OperatorBase {
  virtual void Run(const ScopePtr& scope,
                   const platform::DeviceContext& dev_ctx) const = 0;
  // Get an input with the argument name described in `op_proto`.
  const std::string& Input(const std::string& name) const;
  // Get an input that holds multiple variables.
  // TODO: add a vector_view to prevent a memory copy.
  std::vector<std::string> Inputs(const std::string& name) const;

  // Get an output with the argument name described in `op_proto`.
  const std::string& Output(const std::string& name) const;
  // Get an output that holds multiple variables.
  // TODO: add a vector_view to prevent a memory copy.
  std::vector<std::string> Outputs(const std::string& name) const;

  // Initialize in_out_idxs_ to accelerate argument offset lookup.
  void CreateInOutOffsetMap(const OpProto& proto);

 public:
  std::string type_;
  std::vector<std::string> inputs_;
  std::vector<std::string> outputs_;
  AttributeMap attrs_;
  // Stores the argument offsets described in the op_desc.
  std::unordered_map<std::string, int> in_out_idxs_;
};
class KernelContext {
 public:
  KernelContext(const OperatorBase* op, const std::shared_ptr<Scope>& scope,
                const platform::DeviceContext& device_context)
      : op_(*op), scope_(scope), device_context_(device_context) {}

  const Variable* Input(int index) const {
    return scope_->GetVariable(op_.inputs_[index]);
  }

  Variable* Output(int index) const {
    return scope_->GetVariable(op_.outputs_[index]);
  }

  const Variable* Input(const std::string& name) const {
    return scope_->GetVariable(op_.Input(name));
  }

  const Variable* Output(const std::string& name) const {
    return scope_->GetVariable(op_.Output(name));
  }

  const std::vector<const Variable*> Inputs(const std::string& name) const {
    auto names = op_.Inputs(name);
    std::vector<const Variable*> res;
    res.reserve(names.size());
    // Note: writing through res.begin() on an empty vector is undefined
    // behavior, so collect the results with std::back_inserter instead.
    std::transform(
        names.begin(), names.end(), std::back_inserter(res),
        [this](const std::string& name) { return scope_->GetVariable(name); });
    return res;
  }

  const std::vector<const Variable*> Outputs(const std::string& name) const {
    auto names = op_.Outputs(name);
    std::vector<const Variable*> res;
    res.reserve(names.size());
    std::transform(
        names.begin(), names.end(), std::back_inserter(res),
        [this](const std::string& name) { return scope_->GetVariable(name); });
    return res;
  }

  const OperatorBase& op_;
  const std::shared_ptr<Scope>& scope_;
  const platform::DeviceContext& device_context_;
};
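KernelContext therefore supports two lookup styles: positional (the raw slot in inputs_/outputs_) and by argument name (resolved through OperatorBase::Input first, then through the scope). A toy Python model of that resolution order (illustrative only, not the C++ API):

```python
class ToyKernelContext(object):
    """Toy model of KernelContext name resolution."""

    def __init__(self, op, scope):
        self.op = op        # has .inputs (list) and .get_input(name)
        self.scope = scope  # dict: variable name -> variable

    def input(self, key):
        if isinstance(key, int):
            # Positional form: index straight into the flattened inputs_.
            return self.scope[self.op.inputs[key]]
        # Named form: argument name -> variable name -> variable.
        return self.scope[self.op.get_input(key)]
```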
class OpKernel {
@@ -77,25 +147,6 @@ class OpKernel {
   * device resources such as the CUDA stream and cublas handle from the
   * KernelContext. Users should construct it before running the Operator.
   */
  virtual void Compute(const KernelContext& context) const = 0;

  virtual ~OpKernel() {}
@@ -140,7 +191,7 @@ class OperatorWithKernel : public OperatorBase {
  void Run(const ScopePtr& scope,
           const platform::DeviceContext& dev_ctx) const final {
    auto& opKernel = AllOpKernels().at(type_).at(OpKernelKey(dev_ctx));
    opKernel->Compute(KernelContext(this, scope, dev_ctx));
  }
  static std::unordered_map<std::string /* op_type */, OpKernelMap>&
@@ -148,6 +199,7 @@ class OperatorWithKernel : public OperatorBase {
    static std::unordered_map<std::string, OpKernelMap> g_all_op_kernels;
    return g_all_op_kernels;
  }
  void InferShape(const std::shared_ptr<Scope>& scope) const final {
    std::vector<const Tensor*> ins;
    VarNamesToTensors(scope, inputs_, &ins);
...
@@ -30,7 +30,6 @@ class OpWithoutKernelTest : public OperatorBase {
    op_run_num++;
    ASSERT_EQ((int)inputs_.size(), 1);
    ASSERT_EQ((int)outputs_.size(), 1);
    ASSERT_EQ(scope->GetVariable(inputs_[0]), nullptr);
    ASSERT_EQ(x, 1);
    ASSERT_NE(scope->GetVariable(outputs_[0]), nullptr);
@@ -86,9 +85,11 @@ class OpKernelTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
 public:
  OpKernelTestProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("x", "input of test op");
    AddOutput("y", "output of test op");
    AddAttr<float>("scale", "scale of cosine op")
        .SetDefault(1.0)
        .LargerThan(0.0);
    AddComment("This is test op");
  }
};
@@ -103,11 +104,65 @@ class OpWithKernelTest : public OperatorWithKernel {
class CPUKernelTest : public OpKernel {
 public:
  void Compute(const KernelContext& ctx) const {
    std::cout << "this is cpu kernel" << std::endl;
    std::cout << ctx.op_.DebugString() << std::endl;
    cpu_kernel_run_num++;
    ASSERT_EQ(ctx.op_.Input("x"), "IN1");
    ASSERT_EQ(ctx.op_.Output("y"), "OUT1");
  }
};
// multiple inputs test
class OperatorMultiInputsTest : public OperatorBase {
 public:
  void Init() override { x = 1; }
  void InferShape(const std::shared_ptr<Scope>& scope) const override {}
  void Run(const std::shared_ptr<Scope>& scope,
           const platform::DeviceContext& dev_ctx) const override {
    ASSERT_EQ(scope->GetVariable(inputs_[0]), nullptr);
    ASSERT_EQ(x, 1);
    ASSERT_NE(scope->GetVariable(outputs_[0]), nullptr);
    ASSERT_EQ(Input("x"), "IN1");
    ASSERT_EQ(Input("y"), "OUT1");
  }

 public:
  float x = 0;
};
class OpKernelTestMultiInputsProtoAndCheckerMaker
    : public OpProtoAndCheckerMaker {
 public:
  OpKernelTestMultiInputsProtoAndCheckerMaker(OpProto* proto,
                                              OpAttrChecker* op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInputs("xs", "inputs of test op");
    AddInput("k", "input of test op");
    AddOutputs("ys", "outputs of test op");
    AddAttr<float>("scale", "scale of cosine op")
        .SetDefault(1.0)
        .LargerThan(0.0);
    AddComment("This is test op");
  }
};
class CPUKernelMultiInputsTest : public OpKernel {
 public:
  void Compute(const KernelContext& ctx) const {
    auto xs = ctx.op_.Inputs("xs");
    ASSERT_EQ(xs.size(), 3UL);
    ASSERT_EQ(xs[0], "x0");
    ASSERT_EQ(xs[1], "x1");
    ASSERT_EQ(xs[2], "x2");

    auto k = ctx.op_.Input("k");
    ASSERT_EQ(k, "k0");

    auto ys = ctx.op_.Outputs("ys");
    ASSERT_EQ(ys.size(), 2UL);
    ASSERT_EQ(ys[0], "y0");
    ASSERT_EQ(ys[1], "y1");
  }
};
@@ -118,6 +173,7 @@ REGISTER_OP(op_with_kernel, paddle::framework::OpWithKernelTest,
            paddle::framework::OpKernelTestProtoAndCheckerMaker);
REGISTER_OP_CPU_KERNEL(op_with_kernel, paddle::framework::CPUKernelTest);
// test with single input
TEST(OpKernel, all) {
  paddle::framework::OpDesc op_desc;
  op_desc.set_type("op_with_kernel");
@@ -137,3 +193,47 @@ TEST(OpKernel, all) {
  op->Run(scope, cpu_device_context);
  ASSERT_EQ(paddle::framework::cpu_kernel_run_num, 1);
}
REGISTER_OP(op_multi_inputs_with_kernel, paddle::framework::OpWithKernelTest,
            paddle::framework::OpKernelTestMultiInputsProtoAndCheckerMaker);
REGISTER_OP_CPU_KERNEL(op_multi_inputs_with_kernel,
                       paddle::framework::CPUKernelMultiInputsTest);

// test with multiple inputs
TEST(OpKernel, multi_inputs) {
  using namespace paddle::framework;

  OpDesc op_desc;
  op_desc.set_type("op_multi_inputs_with_kernel");
  *op_desc.mutable_inputs()->Add() = "x0";
  *op_desc.mutable_inputs()->Add() = "x1";
  *op_desc.mutable_inputs()->Add() = "x2";
  *op_desc.mutable_inputs()->Add() = "k0";
  *op_desc.mutable_outputs()->Add() = "y0";
  *op_desc.mutable_outputs()->Add() = "y1";

  auto attr = op_desc.mutable_attrs()->Add();
  attr->set_name("scale");
  attr->set_type(paddle::framework::AttrType::FLOAT);
  attr->set_f(3.14);

  auto attr0 = op_desc.mutable_attrs()->Add();
  attr0->set_name("input_format");
  attr0->set_type(paddle::framework::AttrType::INTS);
  auto input_format = attr0->mutable_ints();
  input_format->Add(0);  // start of xs
  input_format->Add(3);  // start of k
  input_format->Add(4);  // end

  auto attr1 = op_desc.mutable_attrs()->Add();
  attr1->set_name("output_format");
  attr1->set_type(paddle::framework::AttrType::INTS);
  auto output_format = attr1->mutable_ints();
  output_format->Add(0);  // start of ys
  output_format->Add(2);  // end

  paddle::platform::CPUDeviceContext cpu_device_context;
  auto scope = std::make_shared<Scope>();

  OperatorPtr op(paddle::framework::OpRegistry::CreateOp(op_desc));
  op->Run(scope, cpu_device_context);
}
@@ -8,10 +8,10 @@ namespace operators {
template <typename Place>
class AddKernel : public framework::OpKernel {
 public:
  void Compute(const framework::KernelContext &context) const override {
    LOG(INFO) << "Add kernel in " << typeid(Place).name();
  }
};

}  // namespace operators
}  // namespace paddle
@@ -67,6 +67,23 @@ All parameter, weight, gradient are variables in Paddle.
    }
    return ret_values;
  });
  m.def_submodule(
       "var_names",
       "The module will return special predefined variable names in Paddle")
      .def("empty", pd::OperatorBase::EMPTY_VAR_NAME)
      .def("temp", pd::OperatorBase::TMP_VAR_NAME);

  py::class_<pd::OperatorBase, pd::OperatorPtr>(m, "Operator")
      .def("__str__", &pd::OperatorBase::DebugString)
      .def_static("create", [](const std::string& protobin) {
        pd::OpDesc desc;
        PADDLE_ENFORCE(desc.ParsePartialFromString(protobin),
                       "Cannot parse user input to OpDesc");
        PADDLE_ENFORCE(desc.IsInitialized(),
                       "User OpDesc is not initialized, reason %s",
                       desc.InitializationErrorString());
        return pd::OpRegistry::CreateOp(desc);
      });
  return m.ptr();
}
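With these bindings, an Operator can be created from Python by serializing an OpDesc protobuf. A quick sketch (assuming the add_two op is registered, as in the tests later in this change):

```python
import paddle.v2.framework.core as core
import paddle.v2.framework.proto.op_desc_pb2 as op_desc_pb2

# The predefined placeholder names exposed by the var_names submodule.
assert core.var_names.temp() == "@TEMP@"
assert core.var_names.empty() == "@EMPTY@"

# Build an OpDesc by hand and pass its serialized form to C++.
desc = op_desc_pb2.OpDesc()
desc.type = "add_two"
desc.inputs.extend(["a", "b"])
desc.outputs.append("z")
op = core.Operator.create(desc.SerializeToString())

# __str__ is bound to OperatorBase::DebugString.
assert str(op) == "Op(add_two), inputs:(a, b), outputs:(z)."
```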
import paddle.v2.framework.core as core
import paddle.v2.framework.proto.op_proto_pb2 as op_proto_pb2
import paddle.v2.framework.proto.op_desc_pb2 as op_desc_pb2
import paddle.v2.framework.proto.attr_type_pb2 as attr_type_pb2
import cStringIO
def get_all_op_protos():
    """
    Get all registered op protos from the Paddle C++ side.
    :return: list of OpProto
    """
    protostrs = core.get_all_op_protos()
    ret_values = []
    for pbstr in protostrs:
        op_proto = op_proto_pb2.OpProto.FromString(str(pbstr))
        ret_values.append(op_proto)
    return ret_values
class OpDescCreationMethod(object):
    """
    A functor object that converts user input (keyword arguments) to an OpDesc
    based on the OpProto.

    :param op_proto: The OpProto object.
    :type op_proto: op_proto_pb2.OpProto
    """

    def __init__(self, op_proto):
        if not isinstance(op_proto, op_proto_pb2.OpProto):
            raise TypeError("Argument should be OpProto")
        self.__op_proto__ = op_proto
    def __call__(self, *args, **kwargs):
        """
        Convert user input to an OpDesc. Only keyword arguments are supported.
        :return: OpDesc based on user input
        :rtype: op_desc_pb2.OpDesc
        """
        if len(args) != 0:
            raise ValueError("Only keyword arguments are supported by Paddle")
        op_desc = op_desc_pb2.OpDesc()

        # Inputs
        ipts, ipt_format, _ = OpDescCreationMethod.extract_input_or_output(
            "input", kwargs, self.__op_proto__.inputs)
        op_desc.inputs.extend(ipts)
        if ipt_format is not None:
            op_desc.attrs.extend([ipt_format])

        # Outputs
        outs, out_format, tmp_index = OpDescCreationMethod.extract_input_or_output(
            "output", kwargs, self.__op_proto__.outputs)
        op_desc.outputs.extend(outs)
        if out_format is not None:
            op_desc.attrs.extend([out_format])
        if len(tmp_index) != 0:
            tmp_index_attr = op_desc.attrs.add()
            tmp_index_attr.type = attr_type_pb2.INTS
            tmp_index_attr.name = "temporary_index"
            tmp_index_attr.ints.extend(tmp_index)

        # Types
        op_desc.type = self.__op_proto__.type

        # Attrs
        for attr in self.__op_proto__.attrs:
            if attr.generated:
                continue
            user_defined_attr = kwargs.get(attr.name, None)
            if user_defined_attr is not None:
                new_attr = op_desc.attrs.add()
                new_attr.name = attr.name
                new_attr.type = attr.type
                if attr.type == attr_type_pb2.INT:
                    new_attr.i = user_defined_attr
                elif attr.type == attr_type_pb2.FLOAT:
                    new_attr.f = user_defined_attr
                elif attr.type == attr_type_pb2.STRING:
                    new_attr.s = user_defined_attr
                elif attr.type == attr_type_pb2.INTS:
                    new_attr.ints.extend(user_defined_attr)
                elif attr.type == attr_type_pb2.FLOATS:
                    new_attr.floats.extend(user_defined_attr)
                elif attr.type == attr_type_pb2.STRINGS:
                    new_attr.strings.extend(user_defined_attr)
                else:
                    raise NotImplementedError("Unsupported attribute type " +
                                              str(attr.type))
        return op_desc
    @staticmethod
    def extract_input_or_output(in_out, kwargs, meta):
        """
        Extract input variable names or output variable names from the
        keyword arguments, based on the VarProtos.

        :param in_out: "input" or "output"
        :param kwargs: keyword arguments that the user passed in.
        :param meta: a list of VarProto
        :return: Three objects are returned: the variable names; the
        input_format or output_format attribute (None if the input or output
        is not multiple); and the temporary variable index list.
        """
        multiple = OpDescCreationMethod.any_is_true((m.multiple for m in meta))
        tmp_index = []
        retv = []
        if multiple:
            var_format = op_desc_pb2.AttrDesc()
            var_format.type = attr_type_pb2.INTS
            var_format.name = "%s_format" % in_out
            var_format.ints.append(0)

            for var in meta:
                var_name = var.name

                if var.temporary:
                    var_name = [core.var_names.temp()]
                    tmp_index.append(len(retv))
                else:
                    var_name = kwargs.get(var_name, [])
                if not isinstance(var_name, list):
                    var_name = [var_name]
                retv.extend(var_name)
                var_format.ints.append(len(var_name) + var_format.ints[-1])
            return retv, var_format, tmp_index
        else:
            for var in meta:
                if var.temporary:
                    # Record the index before appending, so that it points at
                    # the temporary variable itself (matching the multiple
                    # branch above).
                    tmp_index.append(len(retv))
                    retv.append(kwargs.get(var.name, core.var_names.temp()))
                else:
                    retv.append(kwargs.get(var.name, core.var_names.empty()))
            return retv, None, tmp_index
    @staticmethod
    def any_is_true(generator):
        """
        Reduce a boolean generator to a single value: return True if any
        element is True.
        """
        for flag in generator:
            if flag:
                return True
        return False
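A condensed usage sketch of OpDescCreationMethod with a multiple-variable input, showing the generated input_format attribute (illustrative; it mirrors the fc-style tests below):

```python
import paddle.v2.framework.proto.op_proto_pb2 as op_proto_pb2
import paddle.v2.framework.create_op_creation_methods as creation

proto = op_proto_pb2.OpProto()
proto.type = "fc"
proto.comment = ""
for name, multiple in [("X", True), ("W", True), ("b", False)]:
    ipt = proto.inputs.add()
    ipt.name = name
    ipt.comment = ""
    ipt.multiple = multiple
out = proto.outputs.add()
out.name = "Y"
out.comment = ""

method = creation.OpDescCreationMethod(proto)
desc = method(X=["x1", "x2"], W="w", b="b", Y="y")
# Flattened names plus cumulative slice boundaries [0, 2, 3, 4]:
assert list(desc.inputs) == ["x1", "x2", "w", "b"]
assert list(desc.attrs[0].ints) == [0, 2, 3, 4]
```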
def get_docstring_from_op_proto(op_proto):
    """
    Generate a docstring from an OpProto.

    :param op_proto: an OpProto instance.
    :type op_proto: op_proto_pb2.OpProto
    :return: docstring
    """
    if not isinstance(op_proto, op_proto_pb2.OpProto):
        raise TypeError("Input must be OpProto")
    f = cStringIO.StringIO()
    f.write(op_proto.comment)
    f.write("\n")

    def __append_param__(name, comment, type):
        # Maybe replacing the following lines with a template engine would be
        # better.
        f.write(":param ")
        f.write(name)
        f.write(": ")
        f.write(comment)
        f.write("\n")
        f.write(":type ")
        f.write(name)
        f.write(": ")
        f.write(type)
        f.write("\n")

    for ipt in op_proto.inputs:
        __append_param__(ipt.name, ipt.comment, "list | basestr"
                         if ipt.multiple else "basestr")

    temp_var_prefix = \
        "This is a temporary variable. It does not have to be set by the user. "
    for opt in op_proto.outputs:
        __append_param__(opt.name, opt.comment if not opt.temporary else
                         temp_var_prefix + opt.comment, "list | basestr"
                         if opt.multiple else "basestr")

    for attr in op_proto.attrs:
        attr_type = None
        if attr.type == attr_type_pb2.INT:
            attr_type = "int"
        elif attr.type == attr_type_pb2.FLOAT:
            attr_type = "float"
        elif attr.type == attr_type_pb2.STRING:
            attr_type = "basestr"
        elif attr.type == attr_type_pb2.INTS:
            attr_type = "list of int"
        elif attr.type == attr_type_pb2.FLOATS:
            attr_type = "list of float"
        elif attr.type == attr_type_pb2.STRINGS:
            attr_type = "list of basestr"

        if attr_type is None:
            raise RuntimeError("Unsupported attribute type " + str(attr.type))
        __append_param__(attr.name, attr.comment, attr_type)

    return f.getvalue()
def create_op_creation_method(op_proto):
    """
    Generate an op creation method for an OpProto.
    """
    method = OpDescCreationMethod(op_proto)

    def __impl__(*args, **kwargs):
        opdesc = method(*args, **kwargs)
        return core.Operator.create(opdesc.SerializeToString())

    __impl__.__doc__ = get_docstring_from_op_proto(op_proto)
    return __impl__
class OpCreationsHolder(object):
    """
    An object that holds all op creation methods.
    Use `op_creations.xxx_op` to access them.
    """
    pass


op_creations = OpCreationsHolder()
def __bootstrap__():
    """
    Bootstrap function for this module. It dynamically creates all op creation
    methods at runtime.
    """
    for op_proto in get_all_op_protos():
        func = create_op_creation_method(op_proto)
        func.__name__ = str(op_proto.type)
        setattr(op_creations, func.__name__, func)


__bootstrap__()
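After __bootstrap__() runs at import time, each registered op is available as an attribute of op_creations, with a generated docstring. For example (assuming add_two is registered with arguments X, Y, Out, as in the test below):

```python
import paddle.v2.framework.create_op_creation_methods as creation

# Every registered OpProto became a creation method at import time.
add_two = creation.op_creations.add_two
op = add_two(X="a", Y="b", Out="z")
assert str(op) == "Op(add_two), inputs:(a, b), outputs:(z)."

# The docstring was generated from the OpProto by
# get_docstring_from_op_proto().
assert ":param X:" in add_two.__doc__
```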
import unittest
import paddle.v2.framework.create_op_creation_methods as creation
import paddle.v2.framework.core as core
import paddle.v2.framework.proto.op_proto_pb2 as op_proto_pb2
import paddle.v2.framework.proto.op_desc_pb2 as op_desc_pb2
import paddle.v2.framework.proto.attr_type_pb2 as attr_type_pb2
class TestGetAllProtos(unittest.TestCase):
    def test_all(self):
        all_protos = creation.get_all_op_protos()
        self.assertNotEqual(0, len(all_protos))
@@ -11,5 +15,240 @@ class TestOpCreationsMethods(unittest.TestCase):
            self.assertTrue(each.IsInitialized())
class TestOpDescCreationMethod(unittest.TestCase):
    def test_plain_input_output(self):
        op = op_proto_pb2.OpProto()
        op.type = "test"
        ipt = op.inputs.add()
        ipt.name = "X"
        ipt.comment = "not matter"

        ipt = op.inputs.add()
        ipt.name = "Y"
        ipt.comment = "not matter"

        opt = op.outputs.add()
        opt.name = "Z"
        opt.comment = "not matter"

        op.comment = "not matter"

        self.assertTrue(op.IsInitialized())

        method = creation.OpDescCreationMethod(op)
        output = method(X="a", Y="b", Z="c")

        expected = op_desc_pb2.OpDesc()
        expected.type = "test"
        expected.inputs.extend(["a", "b"])
        expected.outputs.append("c")
        self.assertEqual(expected, output)
    def test_multiple_input_plain_output(self):
        op = op_proto_pb2.OpProto()
        op.type = "fc"
        ipt = op.inputs.add()
        ipt.name = "X"
        ipt.comment = ""
        ipt.multiple = True

        ipt = op.inputs.add()
        ipt.name = "W"
        ipt.comment = ""
        ipt.multiple = True

        ipt = op.inputs.add()
        ipt.name = "b"
        ipt.comment = ""

        out = op.outputs.add()
        out.name = "Y"
        out.comment = ""

        op.comment = ""
        self.assertTrue(op.IsInitialized())

        method = creation.OpDescCreationMethod(op)

        generated1 = method(X="x", W="w", b="b", Y="y")
        expected1 = op_desc_pb2.OpDesc()
        expected1.inputs.extend(['x', 'w', 'b'])
        expected1.outputs.extend(['y'])
        expected1.type = 'fc'
        attr = expected1.attrs.add()
        attr.name = 'input_format'
        attr.type = attr_type_pb2.INTS
        attr.ints.extend([0, 1, 2, 3])
        self.assertEqual(expected1, generated1)

        generated2 = method(
            X=['x1', 'x2', 'x3'], b='b', W=['w1', 'w2', 'w3'], Y='y')
        expected2 = op_desc_pb2.OpDesc()
        expected2.inputs.extend(['x1', 'x2', 'x3', 'w1', 'w2', 'w3', 'b'])
        expected2.outputs.extend(['y'])
        expected2.type = 'fc'
        attr = expected2.attrs.add()
        attr.name = 'input_format'
        attr.type = attr_type_pb2.INTS
        attr.ints.extend([0, 3, 6, 7])
        self.assertEqual(expected2, generated2)
    def test_attrs(self):
        op = op_proto_pb2.OpProto()
        op.type = "test"
        ipt = op.inputs.add()
        ipt.name = 'X'
        ipt.comment = ""

        def __add_attr__(name, type):
            attr = op.attrs.add()
            attr.name = name
            attr.comment = ""
            attr.type = type

        __add_attr__("int_attr", attr_type_pb2.INT)
        __add_attr__("float_attr", attr_type_pb2.FLOAT)
        __add_attr__("string_attr", attr_type_pb2.STRING)
        __add_attr__("ints_attr", attr_type_pb2.INTS)
        __add_attr__("floats_attr", attr_type_pb2.FLOATS)
        __add_attr__("strings_attr", attr_type_pb2.STRINGS)

        op.comment = ""
        self.assertTrue(op.IsInitialized())

        method = creation.OpDescCreationMethod(op)

        generated = method(
            X="a",
            int_attr=10,
            float_attr=3.2,
            string_attr="test_str",
            ints_attr=[0, 1, 2, 3, 4],
            floats_attr=[0.2, 3.2, 4.5],
            strings_attr=["a", "b", "c"])

        expected = op_desc_pb2.OpDesc()
        expected.type = "test"
        expected.inputs.extend(['a'])
        attr = expected.attrs.add()
        attr.name = "int_attr"
        attr.type = attr_type_pb2.INT
        attr.i = 10
        attr = expected.attrs.add()
        attr.name = "float_attr"
        attr.type = attr_type_pb2.FLOAT
        attr.f = 3.2
        attr = expected.attrs.add()
        attr.name = "string_attr"
        attr.type = attr_type_pb2.STRING
        attr.s = "test_str"
        attr = expected.attrs.add()
        attr.name = "ints_attr"
        attr.type = attr_type_pb2.INTS
        attr.ints.extend([0, 1, 2, 3, 4])
        attr = expected.attrs.add()
        attr.name = 'floats_attr'
        attr.type = attr_type_pb2.FLOATS
        attr.floats.extend([0.2, 3.2, 4.5])
        attr = expected.attrs.add()
        attr.name = 'strings_attr'
        attr.type = attr_type_pb2.STRINGS
        attr.strings.extend(['a', 'b', 'c'])

        self.assertEqual(expected, generated)
    def test_input_temporary_output(self):
        op = op_proto_pb2.OpProto()
        op.type = "test"
        out = op.outputs.add()
        out.name = "OUT"
        out.comment = ""

        out = op.outputs.add()
        out.name = "TMP"
        out.comment = ""
        out.temporary = True

        out = op.outputs.add()
        out.name = "OUT2"
        out.comment = ""
        op.comment = ""

        method = creation.OpDescCreationMethod(op)
        generated = method(OUT="a", OUT2="b")
        desc = op_desc_pb2.OpDesc()
        desc.outputs.extend(["a", core.var_names.temp(), "b"])
        desc.type = "test"
        attr = desc.attrs.add()
        attr.name = "temporary_index"
        attr.type = attr_type_pb2.INTS
        # The temporary variable sits at index 1 of the outputs.
        attr.ints.append(1)
        self.assertEqual(generated, desc)
class TestOpCreationDocStr(unittest.TestCase):
    def test_all(self):
        op = op_proto_pb2.OpProto()
        op.type = "test"
        op.comment = """Test Op.
This op is used for unit test, not a real op.
"""
        a = op.inputs.add()
        a.name = "a"
        a.comment = "Input a for test op"
        a.multiple = True

        b = op.inputs.add()
        b.name = "b"
        b.comment = "Input b for test op"
        self.assertTrue(op.IsInitialized())

        o1 = op.outputs.add()
        o1.name = "output"
        o1.comment = "The output of test op"

        o2 = op.outputs.add()
        o2.name = "temp output"
        o2.comment = "The temporary output of test op"
        o2.temporary = True

        test_str = op.attrs.add()
        test_str.name = "str_attr"
        test_str.type = attr_type_pb2.STRING
        test_str.comment = "A string attribute for test op"

        actual = creation.get_docstring_from_op_proto(op)
        expected_docstring = '''Test Op.
This op is used for unit test, not a real op.

:param a: Input a for test op
:type a: list | basestr
:param b: Input b for test op
:type b: basestr
:param output: The output of test op
:type output: basestr
:param temp output: This is a temporary variable. It does not have to be set by the user. The temporary output of test op
:type temp output: basestr
:param str_attr: A string attribute for test op
:type str_attr: basestr
'''
        self.assertEqual(expected_docstring, actual)
class TestOpCreations(unittest.TestCase):
    def test_all(self):
        add_op = creation.op_creations.add_two(X="a", Y="b", Out="z")
        self.assertIsNotNone(add_op)
        # Invoke C++ DebugString()
        self.assertEqual('Op(add_two), inputs:(a, b), outputs:(z).',
                         str(add_op))
if __name__ == "__main__":
    unittest.main()
import paddle.trainer_config_helpers.config_parser_utils as config_parser_utils
import paddle.trainer_config_helpers.optimizers as v1_optimizers
"""
@@ -17,6 +16,7 @@ __all__ = [
class Optimizer(object):
    def __init__(self, **kwargs):
        import py_paddle.swig_paddle as swig_api
        if 'batch_size' in kwargs:
            del kwargs['batch_size']  # not important for python library.
@@ -35,18 +35,22 @@ class Optimizer(object):
        For each optimizer(SGD, Adam), GradientMachine should enable different
        buffers.
        """
        import py_paddle.swig_paddle as swig_api
        tmp = swig_api.ParameterOptimizer.create(self.__opt_conf__)
        assert isinstance(tmp, swig_api.ParameterOptimizer)
        return tmp.getParameterTypes()

    def __create_local_updater__(self):
        import py_paddle.swig_paddle as swig_api
        return swig_api.ParameterUpdater.createLocalUpdater(self.__opt_conf__)

    def __create_remote_updater__(self, pass_num, use_sparse_updater):
        import py_paddle.swig_paddle as swig_api
        return swig_api.ParameterUpdater.createRemoteUpdater(
            self.__opt_conf__, pass_num, use_sparse_updater)

    def __create_new_remote_updater__(self, pserver_spec, use_etcd):
        import py_paddle.swig_paddle as swig_api
        return swig_api.ParameterUpdater.createNewRemoteUpdater(
            self.__opt_conf__, pserver_spec, use_etcd)
...
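The optimizer change above is purely about import time: py_paddle.swig_paddle is now imported inside each method instead of at module scope, so importing this module no longer loads the swig extension. A minimal sketch of the pattern (function name illustrative; the swig call is the one used above):

```python
def create_local_updater(opt_conf):
    # Deferred import: the heavyweight swig extension is loaded only when
    # an updater is actually created, not when this module is imported.
    import py_paddle.swig_paddle as swig_api
    return swig_api.ParameterUpdater.createLocalUpdater(opt_conf)
```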