You need to sign in or sign up before continuing.
提交 ad8fa77c 编写于 作者: Y Yu Yang

Merge branch 'develop' into feature/add_some_skeletons_of_ops

#pragma once
#include <algorithm>
#include <atomic>
#include <type_traits>
#include <unordered_map>
#include <unordered_set>
......@@ -214,29 +215,61 @@ class OpRegistry {
}
static OperatorPtr CreateOp(const OpDesc& op_desc) {
//! Create a OpPtr by type.
std::string op_type = op_desc.type();
OperatorPtr op(creators().at(op_type)());
//! Fill op's data member. Not use constructor because it will be noising
//! for Op developer.
const OpProto& op_proto = protos().at(op_type);
op->type_ = op_desc.type();
// set op's inputs_ from desc.
op->inputs_.reserve((size_t)op_desc.inputs_size());
std::copy(op_desc.inputs().begin(), op_desc.inputs().end(),
std::back_inserter(op->inputs_));
// set op's outputs_ from desc.
op->outputs_.reserve((size_t)op_desc.outputs_size());
std::copy(op_desc.outputs().begin(), op_desc.outputs().end(),
std::back_inserter(op->outputs_));
//! Fill attrs, and validate attrs.
for (auto& attr : op_desc.attrs()) {
op->attrs_[attr.name()] = AttrTypeHelper::GetAttrValue(attr);
}
op_checkers().at(op_type).Check(op->attrs_);
//! Convert Temporary variable name to an unique variable name.
GenerateTempVariableName(op.get());
// set argument offsets stored in op.
CreateInOutOffsetMap(op, op_proto);
//! Other op's custom Init for a complex Op. For simple Op, the Init
//! method do nothing.
op->Init();
return op;
}
// init op.in_out_idxs_ to accelerate argument's offset lookup.
static void CreateInOutOffsetMap(OperatorPtr op, const OpProto& proto) {
op->CreateInOutOffsetMap(proto);
}
static std::unordered_map<std::string, OpProto>& protos() {
static std::unordered_map<std::string, OpProto> protos_;
return protos_;
};
private:
static void GenerateTempVariableName(OperatorBase* op) {
static std::atomic<size_t> gUniqId(0UL);
for (auto& outname : op->outputs_) {
if (outname == OperatorBase::TMP_VAR_NAME()) {
outname += op->type_;
outname += "@";
outname += std::to_string(gUniqId.fetch_add(1));
}
}
}
static std::unordered_map<std::string, OpCreator>& creators() {
static std::unordered_map<std::string, OpCreator> creators_;
return creators_;
......
......@@ -12,30 +12,86 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <algorithm>
#include "paddle/framework/operator.h"
namespace paddle {
namespace framework {
void OperatorBase::CreateInOutOffsetMap(const OpProto& proto) {
PADDLE_ENFORCE(in_out_idxs_.empty(), "duplicate call CreateInOutOffsetMap");
for (int i = 0; i < proto.inputs_size(); i++) {
const auto& name = proto.inputs()[i].name();
in_out_idxs_[name] = i;
}
for (int i = 0; i < proto.outputs_size(); i++) {
const auto& name = proto.outputs()[i].name();
in_out_idxs_[name] = i;
}
}
const std::string& OperatorBase::Input(const std::string& name) const {
auto it = in_out_idxs_.find(name);
PADDLE_ENFORCE(it != in_out_idxs_.end(), "no key [%s] in in_out_idxs_", name);
if (attrs_.count("input_format") == 0) {
return inputs_[it->second];
} else {
const auto& input_format = GetAttr<std::vector<int>>("input_format");
int idx = input_format[it->second];
return inputs_.at(idx);
}
}
std::vector<std::string> OperatorBase::Inputs(const std::string& name) const {
auto input_format = GetAttr<std::vector<int>>("input_format");
auto offset = in_out_idxs_.at(name);
return std::vector<std::string>{
inputs_.begin() + input_format.at(offset),
inputs_.begin() + input_format.at(offset + 1)};
}
const std::string& OperatorBase::Output(const std::string& name) const {
auto it = in_out_idxs_.find(name);
PADDLE_ENFORCE(it != in_out_idxs_.end(), "no key [%s] in in_out_idxs_", name);
if (attrs_.count("output_format") == 0) {
return outputs_[it->second];
} else {
const auto& output_format = GetAttr<std::vector<int>>("output_format");
int idx = output_format[it->second];
return outputs_.at(idx);
}
}
std::vector<std::string> OperatorBase::Outputs(const std::string& name) const {
auto output_format = GetAttr<std::vector<int>>("output_format");
auto offset = in_out_idxs_.at(name);
return std::vector<std::string>{
outputs_.begin() + output_format.at(offset),
outputs_.begin() + output_format.at(offset + 1)};
}
std::string OperatorBase::DebugString() const {
std::stringstream ss;
ss << "=================\n";
ss << "type = " << type_ << "\n";
ss << "inputs = [";
for (auto& ipt : inputs_) {
ss << ipt << ", ";
}
ss << "]\n";
ss << "outputs = [";
for (auto& opt : outputs_) {
ss << opt << ", ";
ss << "Op(" << type_ << "), inputs:(";
for (size_t i = 0; i < inputs_.size(); ++i) {
ss << inputs_[i];
if (i != inputs_.size() - 1) {
ss << ", ";
}
}
ss << "]\n";
ss << "attr_keys = [";
for (auto& attr : attrs_) {
ss << attr.first << ", ";
ss << "), outputs:(";
for (size_t i = 0; i < outputs_.size(); ++i) {
ss << outputs_[i];
if (i != outputs_.size() - 1) {
ss << ", ";
}
}
ss << "]\n";
ss << ").";
return ss.str();
}
......
......@@ -14,18 +14,20 @@ limitations under the License. */
#pragma once
#include <paddle/framework/attr_checker.h>
#include <paddle/framework/op_desc.pb.h>
#include <paddle/framework/scope.h>
#include <paddle/framework/tensor.h>
#include <paddle/platform/device_context.h>
#include <paddle/platform/place.h>
#include <paddle/utils/Error.h>
#include <boost/variant.hpp>
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/framework/attr_checker.h"
#include "paddle/framework/op_desc.pb.h"
#include "paddle/framework/op_proto.pb.h"
#include "paddle/framework/scope.h"
#include "paddle/framework/tensor.h"
#include "paddle/platform/device_context.h"
#include "paddle/platform/place.h"
#include "paddle/utils/Error.h"
namespace paddle {
namespace framework {
......@@ -39,6 +41,13 @@ using OperatorPtr = std::shared_ptr<OperatorBase>;
*/
class OperatorBase {
public:
/// If a variable is a empty variable, that name will be used.
static std::string EMPTY_VAR_NAME() { return "@EMPTY@"; }
/// If a variable is a temporary variable, that name will be set in Python,
/// but it will be convert to a unique name in scope after OpCreator.
static std::string TMP_VAR_NAME() { return "@TEMP@"; }
virtual ~OperatorBase() {}
template <typename T>
......@@ -62,11 +71,72 @@ class OperatorBase {
virtual void Run(const ScopePtr& scope,
const platform::DeviceContext& dev_ctx) const = 0;
// Get a input with argument's name described in `op_proto`
const std::string& Input(const std::string& name) const;
// Get a input which has multiple variables.
// TODO add a vector_view to prevent memory copy.
std::vector<std::string> Inputs(const std::string& name) const;
// Get a output with argument's name described in `op_proto`
const std::string& Output(const std::string& name) const;
// Get an output which has multiple variables.
// TODO add a vector_view to prevent memory copy.
std::vector<std::string> Outputs(const std::string& name) const;
// init in_out_idxs_ to accelerate argument's offset lookup.
void CreateInOutOffsetMap(const OpProto& proto);
public:
std::string type_;
std::vector<std::string> inputs_;
std::vector<std::string> outputs_;
AttributeMap attrs_;
// store the arguments' offset described in op_desc.
std::unordered_map<std::string, int> in_out_idxs_;
};
class KernelContext {
public:
KernelContext(const OperatorBase* op, const std::shared_ptr<Scope>& scope,
const platform::DeviceContext& device_context)
: op_(*op), scope_(scope), device_context_(device_context) {}
const Variable* Input(int index) const {
return scope_->GetVariable(op_.inputs_[index]);
}
Variable* Output(int index) const {
return scope_->GetVariable(op_.outputs_[index]);
}
const Variable* Input(const std::string& name) const {
return scope_->GetVariable(op_.Input(name));
}
const Variable* Output(const std::string& name) const {
return scope_->GetVariable(op_.Output(name));
}
const std::vector<const Variable*> Inputs(const std::string& name) const {
auto names = op_.Inputs(name);
std::vector<const Variable*> res;
std::transform(
names.begin(), names.end(), res.begin(),
[this](const std::string& name) { return scope_->GetVariable(name); });
return res;
}
const std::vector<const Variable*> Outputs(const std::string& name) const {
auto names = op_.Outputs(name);
std::vector<const Variable*> res;
std::transform(
names.begin(), names.end(), res.begin(),
[this](const std::string& name) { return scope_->GetVariable(name); });
return res;
}
const OperatorBase& op_;
const std::shared_ptr<Scope>& scope_;
const platform::DeviceContext& device_context_;
};
class OpKernel {
......@@ -77,25 +147,6 @@ class OpKernel {
* device resource such as CUDA stream, cublas handle, etc. from
* KernelContext. User should construct it before run the Operator.
*/
class KernelContext {
public:
KernelContext(const OperatorBase* op, const ScopePtr& scope,
const platform::DeviceContext& device_context)
: op_(*op), scope_(scope), device_context_(device_context) {}
const Variable* Input(int index) const {
return scope_->GetVariable(op_.inputs_[index]);
}
Variable* Output(int index) const {
return scope_->GetVariable(op_.outputs_[index]);
}
const OperatorBase& op_;
const ScopePtr& scope_;
const platform::DeviceContext& device_context_;
};
virtual void Compute(const KernelContext& context) const = 0;
virtual ~OpKernel() {}
......@@ -140,7 +191,7 @@ class OperatorWithKernel : public OperatorBase {
void Run(const ScopePtr& scope,
const platform::DeviceContext& dev_ctx) const final {
auto& opKernel = AllOpKernels().at(type_).at(OpKernelKey(dev_ctx));
opKernel->Compute(OpKernel::KernelContext(this, scope, dev_ctx));
opKernel->Compute(KernelContext(this, scope, dev_ctx));
}
static std::unordered_map<std::string /* op_type */, OpKernelMap>&
......@@ -148,6 +199,7 @@ class OperatorWithKernel : public OperatorBase {
static std::unordered_map<std::string, OpKernelMap> g_all_op_kernels;
return g_all_op_kernels;
}
void InferShape(const std::shared_ptr<Scope>& scope) const final {
std::vector<const Tensor*> ins;
VarNamesToTensors(scope, inputs_, &ins);
......
......@@ -30,7 +30,6 @@ class OpWithoutKernelTest : public OperatorBase {
op_run_num++;
ASSERT_EQ((int)inputs_.size(), 1);
ASSERT_EQ((int)outputs_.size(), 1);
ASSERT_NEAR(GetAttr<float>("scale"), 3.14, 1e-5);
ASSERT_EQ(scope->GetVariable(inputs_[0]), nullptr);
ASSERT_EQ(x, 1);
ASSERT_NE(scope->GetVariable(outputs_[0]), nullptr);
......@@ -86,9 +85,11 @@ class OpKernelTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
public:
OpKernelTestProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("input", "input of test op");
AddOutput("output", "output of test op");
AddAttr<float>("scale", "scale of cosine op");
AddInput("x", "input of test op");
AddOutput("y", "output of test op");
AddAttr<float>("scale", "scale of cosine op")
.SetDefault(1.0)
.LargerThan(0.0);
AddComment("This is test op");
}
};
......@@ -103,11 +104,65 @@ class OpWithKernelTest : public OperatorWithKernel {
class CPUKernelTest : public OpKernel {
public:
void Compute(const KernelContext& context) const {
void Compute(const KernelContext& ctx) const {
std::cout << "this is cpu kernel" << std::endl;
std::cout << ctx.op_.DebugString() << std::endl;
cpu_kernel_run_num++;
ASSERT_EQ((int)context.op_.inputs_.size(), 1);
ASSERT_EQ((int)context.op_.outputs_.size(), 1);
ASSERT_NEAR(context.op_.GetAttr<float>("scale"), 3.14, 1e-5);
ASSERT_EQ(ctx.op_.Input("x"), "IN1");
ASSERT_EQ(ctx.op_.Output("y"), "OUT1");
}
};
// multiple inputs test
class OperatorMultiInputsTest : public OperatorBase {
public:
void Init() override { x = 1; }
void InferShape(const std::shared_ptr<Scope>& scope) const override {}
void Run(const std::shared_ptr<Scope>& scope,
const platform::DeviceContext& dev_ctx) const override {
ASSERT_EQ(scope->GetVariable(inputs_[0]), nullptr);
ASSERT_EQ(x, 1);
ASSERT_NE(scope->GetVariable(outputs_[0]), nullptr);
ASSERT_EQ(Input("x"), "IN1");
ASSERT_EQ(Input("y"), "OUT1");
}
public:
float x = 0;
};
class OpKernelTestMultiInputsProtoAndCheckerMaker
: public OpProtoAndCheckerMaker {
public:
OpKernelTestMultiInputsProtoAndCheckerMaker(OpProto* proto,
OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInputs("xs", "inputs of test op");
AddInput("k", "input of test op");
AddOutputs("ys", "outputs of test op");
AddAttr<float>("scale", "scale of cosine op")
.SetDefault(1.0)
.LargerThan(0.0);
AddComment("This is test op");
}
};
class CPUKernalMultiInputsTest : public OpKernel {
public:
void Compute(const KernelContext& ctx) const {
auto xs = ctx.op_.Inputs("xs");
ASSERT_EQ(xs.size(), 3UL);
ASSERT_EQ(xs[0], "x0");
ASSERT_EQ(xs[1], "x1");
ASSERT_EQ(xs[2], "x2");
auto k = ctx.op_.Input("k");
ASSERT_EQ(k, "k0");
auto ys = ctx.op_.Outputs("ys");
ASSERT_EQ(ys.size(), 2UL);
ASSERT_EQ(ys[0], "y0");
ASSERT_EQ(ys[1], "y1");
}
};
......@@ -118,6 +173,7 @@ REGISTER_OP(op_with_kernel, paddle::framework::OpWithKernelTest,
paddle::framework::OpKernelTestProtoAndCheckerMaker);
REGISTER_OP_CPU_KERNEL(op_with_kernel, paddle::framework::CPUKernelTest);
// test with single input
TEST(OpKernel, all) {
paddle::framework::OpDesc op_desc;
op_desc.set_type("op_with_kernel");
......@@ -137,3 +193,47 @@ TEST(OpKernel, all) {
op->Run(scope, cpu_device_context);
ASSERT_EQ(paddle::framework::cpu_kernel_run_num, 1);
}
REGISTER_OP(op_multi_inputs_with_kernel, paddle::framework::OpWithKernelTest,
paddle::framework::OpKernelTestMultiInputsProtoAndCheckerMaker);
REGISTER_OP_CPU_KERNEL(op_multi_inputs_with_kernel,
paddle::framework::CPUKernalMultiInputsTest);
// test with multi inputs
TEST(OpKernel, multi_inputs) {
using namespace paddle::framework;
OpDesc op_desc;
op_desc.set_type("op_multi_inputs_with_kernel");
*op_desc.mutable_inputs()->Add() = "x0";
*op_desc.mutable_inputs()->Add() = "x1";
*op_desc.mutable_inputs()->Add() = "x2";
*op_desc.mutable_inputs()->Add() = "k0";
*op_desc.mutable_outputs()->Add() = "y0";
*op_desc.mutable_outputs()->Add() = "y1";
auto attr = op_desc.mutable_attrs()->Add();
attr->set_name("scale");
attr->set_type(paddle::framework::AttrType::FLOAT);
attr->set_f(3.14);
auto attr0 = op_desc.mutable_attrs()->Add();
attr0->set_name("input_format");
attr0->set_type(paddle::framework::AttrType::INTS);
auto input_format = attr0->mutable_ints();
input_format->Add(0); // x0
input_format->Add(3); // k
input_format->Add(4); // end
auto attr1 = op_desc.mutable_attrs()->Add();
attr1->set_name("output_format");
attr1->set_type(paddle::framework::AttrType::INTS);
auto output_format = attr1->mutable_ints();
output_format->Add(0); // y0
output_format->Add(2); // y1
paddle::platform::CPUDeviceContext cpu_device_context;
auto scope = std::make_shared<Scope>();
OperatorPtr op(paddle::framework::OpRegistry::CreateOp(op_desc));
op->Run(scope, cpu_device_context);
}
......@@ -8,10 +8,10 @@ namespace operators {
template <typename Place>
class AddKernel : public framework::OpKernel {
public:
void Compute(const KernelContext &context) const override {
void Compute(const framework::KernelContext &context) const override {
LOG(INFO) << "Add kernel in " << typeid(Place).name();
}
};
} // namespace op
} // namespace operators
} // namespace paddle
......@@ -67,6 +67,23 @@ All parameter, weight, gradient are variables in Paddle.
}
return ret_values;
});
m.def_submodule(
"var_names",
"The module will return special predefined variable name in Paddle")
.def("empty", pd::OperatorBase::EMPTY_VAR_NAME)
.def("temp", pd::OperatorBase::TMP_VAR_NAME);
py::class_<pd::OperatorBase, pd::OperatorPtr>(m, "Operator")
.def("__str__", &pd::OperatorBase::DebugString)
.def_static("create", [](const std::string& protobin) {
pd::OpDesc desc;
PADDLE_ENFORCE(desc.ParsePartialFromString(protobin),
"Cannot parse user input to OpDesc");
PADDLE_ENFORCE(desc.IsInitialized(),
"User OpDesc is not initialized, reason %s",
desc.InitializationErrorString());
return pd::OpRegistry::CreateOp(desc);
});
return m.ptr();
}
import paddle.v2.framework.core as core
import paddle.v2.framework.proto.op_proto_pb2 as op_proto_pb2
import paddle.v2.framework.proto.op_desc_pb2 as op_desc_pb2
import paddle.v2.framework.proto.attr_type_pb2 as attr_type_pb2
import cStringIO
def get_all_op_protos():
"""
Get all registered op proto from Paddle C++
:return: list of OpProto
"""
protostrs = core.get_all_op_protos()
ret_values = []
for pbstr in protostrs:
op_proto = op_proto_pb2.OpProto.FromString(str(pbstr))
ret_values.append(op_proto)
return ret_values
class OpDescCreationMethod(object):
"""
A Functor object to convert user input(use key word args) to OpDesc based on
OpProto.
:param op_proto: The OpProto object.
:type op_proto: op_proto_pb2.OpProto
"""
def __init__(self, op_proto):
if not isinstance(op_proto, op_proto_pb2.OpProto):
raise TypeError("Argument should be OpProto")
self.__op_proto__ = op_proto
def __call__(self, *args, **kwargs):
"""
Convert user input to OpDesc. Only key-word args are supported.
:return: OpDesc based on user input
:rtype: op_desc_pb2.OpDesc
"""
if len(args) != 0:
raise ValueError("Only keyword arguments is supported by Paddle")
op_desc = op_desc_pb2.OpDesc()
# Inputs
ipts, ipt_format, _ = OpDescCreationMethod.extract_input_or_output(
"input", kwargs, self.__op_proto__.inputs)
op_desc.inputs.extend(ipts)
if ipt_format is not None:
op_desc.attrs.extend([ipt_format])
# Outputs
outs, out_format, tmp_index = OpDescCreationMethod.extract_input_or_output(
"output", kwargs, self.__op_proto__.outputs)
op_desc.outputs.extend(outs)
if out_format is not None:
op_desc.attrs.extend([out_format])
if len(tmp_index) != 0:
tmp_index_attr = op_desc.attrs.add()
tmp_index_attr.type = attr_type_pb2.INTS
tmp_index_attr.name = "temporary_index"
tmp_index_attr.ints.extend(tmp_index)
# Types
op_desc.type = self.__op_proto__.type
# Attrs
for attr in self.__op_proto__.attrs:
if attr.generated:
continue
user_defined_attr = kwargs.get(attr.name, None)
if user_defined_attr is not None:
new_attr = op_desc.attrs.add()
new_attr.name = attr.name
new_attr.type = attr.type
if attr.type == attr_type_pb2.INT:
new_attr.i = user_defined_attr
elif attr.type == attr_type_pb2.FLOAT:
new_attr.f = user_defined_attr
elif attr.type == attr_type_pb2.STRING:
new_attr.s = user_defined_attr
elif attr.type == attr_type_pb2.INTS:
new_attr.ints.extend(user_defined_attr)
elif attr.type == attr_type_pb2.FLOATS:
new_attr.floats.extend(user_defined_attr)
elif attr.type == attr_type_pb2.STRINGS:
new_attr.strings.extend(user_defined_attr)
else:
raise NotImplementedError("Not support attribute type " +
attr.type)
return op_desc
@staticmethod
def extract_input_or_output(in_out, kwargs, meta):
"""
Extract input variable names or output variable names from key-word
arguments, which base on VarProtos.
:param in_out: "input" or "output"
:param kwargs: key-word arguments that user inputted.
:param meta: a list of VarProto
:return: The three object will be return. The variable names. The
input_format or output_format attribute(None if the input or output is
not multiple). The temporary variable index list.
"""
multiple = OpDescCreationMethod.any_is_true((m.multiple for m in meta))
tmp_index = []
retv = []
if multiple:
var_format = op_desc_pb2.AttrDesc()
var_format.type = attr_type_pb2.INTS
var_format.name = "%s_format" % in_out
var_format.ints.append(0)
for var in meta:
var_name = var.name
if var.temporary:
var_name = [core.var_names.temp()]
tmp_index.append(len(retv))
else:
var_name = kwargs.get(var_name, [])
if not isinstance(var_name, list):
var_name = [var_name]
retv.extend(var_name)
var_format.ints.append(len(var_name) + var_format.ints[-1])
return retv, var_format, tmp_index
else:
for var in meta:
if var.temporary:
retv.append(kwargs.get(var.name, core.var_names.temp()))
tmp_index.append(len(retv))
else:
retv.append(kwargs.get(var.name, core.var_names.empty()))
return retv, None, tmp_index
@staticmethod
def any_is_true(generator):
"""
Reduce a bool array to one. If any of them is True, then return True.
"""
for flag in generator:
if flag:
return True
return False
def get_docstring_from_op_proto(op_proto):
"""
Generate docstring from a OpProto
:param op_proto: a OpProto instance.
:type op_proto: op_proto_pb2.OpProto
:return: docstring
"""
if not isinstance(op_proto, op_proto_pb2.OpProto):
raise TypeError("Input must be OpProto")
f = cStringIO.StringIO()
f.write(op_proto.comment)
f.write("\n")
def __append_param__(name, comment, type):
# Maybe replace the following line with template engine is better.
f.write(":param ")
f.write(name)
f.write(": ")
f.write(comment)
f.write("\n")
f.write(":type ")
f.write(name)
f.write(": ")
f.write(type)
f.write("\n")
for ipt in op_proto.inputs:
__append_param__(ipt.name, ipt.comment, "list | basestr"
if ipt.multiple else "basestr")
temp_var_prefix = \
"This is a temporary variable. It does not have to set by user. "
for opt in op_proto.outputs:
__append_param__(opt.name, opt.comment if not opt.temporary else
temp_var_prefix + opt.comment, "list | basestr"
if opt.multiple else "basestr")
for attr in op_proto.attrs:
attr_type = None
if attr.type == attr_type_pb2.INT:
attr_type = "int"
elif attr.type == attr_type_pb2.FLOAT:
attr_type = "float"
elif attr.type == attr_type_pb2.STRING:
attr_type = "basestr"
elif attr.type == attr_type_pb2.INTS:
attr_type = "list of int"
elif attr.type == attr_type_pb2.FLOATS:
attr_type = "list of float"
elif attr.type == attr_type_pb2.STRINGS:
attr_type = "list of basestr"
if attr_type is None:
raise RuntimeError("Not supported attribute type " + attr.type)
__append_param__(attr.name, attr.comment, attr_type)
return f.getvalue()
def create_op_creation_method(op_proto):
"""
Generate op creation method for an OpProto
"""
method = OpDescCreationMethod(op_proto)
def __impl__(*args, **kwargs):
opdesc = method(*args, **kwargs)
return core.Operator.create(opdesc.SerializeToString())
__impl__.__doc__ = get_docstring_from_op_proto(op_proto)
return __impl__
class OpCreationsHolder(object):
"""
A object will holds all op creation methods.
Use `op_creations.xxx_op` to access them.
"""
pass
op_creations = OpCreationsHolder()
def __bootstrap__():
"""
Bootstrap function for this module. It will dynamic create all op creation
methods in runtime.
"""
for op_proto in get_all_op_protos():
func = create_op_creation_method(op_proto)
func.__name__ = str(op_proto.type)
setattr(op_creations, func.__name__, func)
__bootstrap__()
import unittest
import paddle.v2.framework.create_op_creation_methods as creation
import paddle.v2.framework.core as core
import paddle.v2.framework.proto.op_proto_pb2 as op_proto_pb2
import paddle.v2.framework.proto.op_desc_pb2 as op_desc_pb2
import paddle.v2.framework.proto.attr_type_pb2 as attr_type_pb2
class TestOpCreationsMethods(unittest.TestCase):
def test_all_protos(self):
class TestGetAllProtos(unittest.TestCase):
def test_all(self):
all_protos = creation.get_all_op_protos()
self.assertNotEqual(0, len(all_protos))
......@@ -11,5 +15,240 @@ class TestOpCreationsMethods(unittest.TestCase):
self.assertTrue(each.IsInitialized())
class TestOpDescCreationMethod(unittest.TestCase):
def test_plain_input_output(self):
op = op_proto_pb2.OpProto()
op.type = "test"
ipt = op.inputs.add()
ipt.name = "X"
ipt.comment = "not matter"
ipt = op.inputs.add()
ipt.name = "Y"
ipt.comment = "not matter"
opt = op.outputs.add()
opt.name = "Z"
opt.comment = "not matter"
op.comment = "not matter"
self.assertTrue(op.IsInitialized())
method = creation.OpDescCreationMethod(op)
output = method(X="a", Y="b", Z="c")
expected = op_desc_pb2.OpDesc()
expected.type = "test"
expected.inputs.extend(["a", "b"])
expected.outputs.append("c")
self.assertEqual(expected, output)
def test_multiple_input_plain_output(self):
op = op_proto_pb2.OpProto()
op.type = "fc"
ipt = op.inputs.add()
ipt.name = "X"
ipt.comment = ""
ipt.multiple = True
ipt = op.inputs.add()
ipt.name = "W"
ipt.comment = ""
ipt.multiple = True
ipt = op.inputs.add()
ipt.name = "b"
ipt.comment = ""
out = op.outputs.add()
out.name = "Y"
out.comment = ""
op.comment = ""
self.assertTrue(op.IsInitialized())
method = creation.OpDescCreationMethod(op)
generated1 = method(X="x", W="w", b="b", Y="y")
expected1 = op_desc_pb2.OpDesc()
expected1.inputs.extend(['x', 'w', 'b'])
expected1.outputs.extend(['y'])
expected1.type = 'fc'
attr = expected1.attrs.add()
attr.name = 'input_format'
attr.type = attr_type_pb2.INTS
attr.ints.extend([0, 1, 2, 3])
self.assertEqual(expected1, generated1)
generated2 = method(
X=['x1', 'x2', 'x3'], b='b', W=['w1', 'w2', 'w3'], Y='y')
expected2 = op_desc_pb2.OpDesc()
expected2.inputs.extend(['x1', 'x2', 'x3', 'w1', 'w2', 'w3', 'b'])
expected2.outputs.extend(['y'])
expected2.type = 'fc'
attr = expected2.attrs.add()
attr.name = 'input_format'
attr.type = attr_type_pb2.INTS
attr.ints.extend([0, 3, 6, 7])
self.assertEqual(expected2, generated2)
def test_attrs(self):
op = op_proto_pb2.OpProto()
op.type = "test"
ipt = op.inputs.add()
ipt.name = 'X'
ipt.comment = ""
def __add_attr__(name, type):
attr = op.attrs.add()
attr.name = name
attr.comment = ""
attr.type = type
__add_attr__("int_attr", attr_type_pb2.INT)
__add_attr__("float_attr", attr_type_pb2.FLOAT)
__add_attr__("string_attr", attr_type_pb2.STRING)
__add_attr__("ints_attr", attr_type_pb2.INTS)
__add_attr__("floats_attr", attr_type_pb2.FLOATS)
__add_attr__("strings_attr", attr_type_pb2.STRINGS)
op.comment = ""
self.assertTrue(op.IsInitialized())
method = creation.OpDescCreationMethod(op)
generated = method(
X="a",
int_attr=10,
float_attr=3.2,
string_attr="test_str",
ints_attr=[0, 1, 2, 3, 4],
floats_attr=[0.2, 3.2, 4.5],
strings_attr=["a", "b", "c"])
expected = op_desc_pb2.OpDesc()
expected.type = "test"
expected.inputs.extend(['a'])
attr = expected.attrs.add()
attr.name = "int_attr"
attr.type = attr_type_pb2.INT
attr.i = 10
attr = expected.attrs.add()
attr.name = "float_attr"
attr.type = attr_type_pb2.FLOAT
attr.f = 3.2
attr = expected.attrs.add()
attr.name = "string_attr"
attr.type = attr_type_pb2.STRING
attr.s = "test_str"
attr = expected.attrs.add()
attr.name = "ints_attr"
attr.type = attr_type_pb2.INTS
attr.ints.extend([0, 1, 2, 3, 4])
attr = expected.attrs.add()
attr.name = 'floats_attr'
attr.type = attr_type_pb2.FLOATS
attr.floats.extend([0.2, 3.2, 4.5])
attr = expected.attrs.add()
attr.name = 'strings_attr'
attr.type = attr_type_pb2.STRINGS
attr.strings.extend(['a', 'b', 'c'])
self.assertEqual(expected, generated)
def test_input_temporary_output(self):
op = op_proto_pb2.OpProto()
op.type = "test"
out = op.outputs.add()
out.name = "OUT"
out.comment = ""
out = op.outputs.add()
out.name = "TMP"
out.comment = ""
out.temporary = True
out = op.outputs.add()
out.name = "OUT2"
out.comment = ""
op.comment = ""
method = creation.OpDescCreationMethod(op)
generated = method(OUT="a", OUT2="b")
desc = op_desc_pb2.OpDesc()
desc.outputs.extend(["a", core.var_names.temp(), "b"])
desc.type = "test"
attr = desc.attrs.add()
attr.name = "temporary_index"
attr.type = attr_type_pb2.INTS
attr.ints.append(2)
self.assertEqual(generated, desc)
class TestOpCreationDocStr(unittest.TestCase):
def test_all(self):
op = op_proto_pb2.OpProto()
op.type = "test"
op.comment = """Test Op.
This op is used for unit test, not a real op.
"""
a = op.inputs.add()
a.name = "a"
a.comment = "Input a for test op"
a.multiple = True
b = op.inputs.add()
b.name = "b"
b.comment = "Input b for test op"
self.assertTrue(op.IsInitialized())
o1 = op.outputs.add()
o1.name = "output"
o1.comment = "The output of test op"
o2 = op.outputs.add()
o2.name = "temp output"
o2.comment = "The temporary output of test op"
o2.temporary = True
test_str = op.attrs.add()
test_str.name = "str_attr"
test_str.type = attr_type_pb2.STRING
test_str.comment = "A string attribute for test op"
actual = creation.get_docstring_from_op_proto(op)
expected_docstring = '''Test Op.
This op is used for unit test, not a real op.
:param a: Input a for test op
:type a: list | basestr
:param b: Input b for test op
:type b: basestr
:param output: The output of test op
:type output: basestr
:param temp output: This is a temporary variable. It does not have to set by user. The temporary output of test op
:type temp output: basestr
:param str_attr: A string attribute for test op
:type str_attr: basestr
'''
self.assertEqual(expected_docstring, actual)
class TestOpCreations(unittest.TestCase):
def test_all(self):
add_op = creation.op_creations.add_two(X="a", Y="b", Out="z")
self.assertIsNotNone(add_op)
# Invoke C++ DebugString()
self.assertEqual('Op(add_two), inputs:(a, b), outputs:(z).',
str(add_op))
if __name__ == "__main__":
unittest.main()
import py_paddle.swig_paddle as swig_api
import paddle.trainer_config_helpers.config_parser_utils as config_parser_utils
import paddle.trainer_config_helpers.optimizers as v1_optimizers
"""
......@@ -17,6 +16,7 @@ __all__ = [
class Optimizer(object):
def __init__(self, **kwargs):
import py_paddle.swig_paddle as swig_api
if 'batch_size' in kwargs:
del kwargs['batch_size'] # not important for python library.
......@@ -35,18 +35,22 @@ class Optimizer(object):
For each optimizer(SGD, Adam), GradientMachine should enable different
buffers.
"""
import py_paddle.swig_paddle as swig_api
tmp = swig_api.ParameterOptimizer.create(self.__opt_conf__)
assert isinstance(tmp, swig_api.ParameterOptimizer)
return tmp.getParameterTypes()
def __create_local_updater__(self):
import py_paddle.swig_paddle as swig_api
return swig_api.ParameterUpdater.createLocalUpdater(self.__opt_conf__)
def __create_remote_updater__(self, pass_num, use_sparse_updater):
import py_paddle.swig_paddle as swig_api
return swig_api.ParameterUpdater.createRemoteUpdater(
self.__opt_conf__, pass_num, use_sparse_updater)
def __create_new_remote_updater__(self, pserver_spec, use_etcd):
import py_paddle.swig_paddle as swig_api
return swig_api.ParameterUpdater.createNewRemoteUpdater(
self.__opt_conf__, pserver_spec, use_etcd)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册