diff --git a/paddle/framework/op_registry.h b/paddle/framework/op_registry.h index 41bdb65f8e9bbbf63ec6e62e8c2243cc3d6f21de..7aa59f0b630d5a99fa15c00c9e32a22dd59b9a70 100644 --- a/paddle/framework/op_registry.h +++ b/paddle/framework/op_registry.h @@ -1,6 +1,7 @@ #pragma once #include +#include #include #include #include @@ -214,29 +215,61 @@ class OpRegistry { } static OperatorPtr CreateOp(const OpDesc& op_desc) { + //! Create a OpPtr by type. std::string op_type = op_desc.type(); OperatorPtr op(creators().at(op_type)()); + //! Fill op's data member. Not use constructor because it will be noising + //! for Op developer. + const OpProto& op_proto = protos().at(op_type); op->type_ = op_desc.type(); + // set op's inputs_ from desc. op->inputs_.reserve((size_t)op_desc.inputs_size()); std::copy(op_desc.inputs().begin(), op_desc.inputs().end(), std::back_inserter(op->inputs_)); + // set op's outputs_ from desc. op->outputs_.reserve((size_t)op_desc.outputs_size()); std::copy(op_desc.outputs().begin(), op_desc.outputs().end(), std::back_inserter(op->outputs_)); + + //! Fill attrs, and validate attrs. for (auto& attr : op_desc.attrs()) { op->attrs_[attr.name()] = AttrTypeHelper::GetAttrValue(attr); } op_checkers().at(op_type).Check(op->attrs_); + + //! Convert Temporary variable name to an unique variable name. + GenerateTempVariableName(op.get()); + + // set argument offsets stored in op. + CreateInOutOffsetMap(op, op_proto); + //! Other op's custom Init for a complex Op. For simple Op, the Init + //! method do nothing. op->Init(); return op; } + // init op.in_out_idxs_ to accelerate argument's offset lookup. + static void CreateInOutOffsetMap(OperatorPtr op, const OpProto& proto) { + op->CreateInOutOffsetMap(proto); + } + static std::unordered_map& protos() { static std::unordered_map protos_; return protos_; }; private: + static void GenerateTempVariableName(OperatorBase* op) { + static std::atomic gUniqId(0UL); + for (auto& outname : op->outputs_) { + if (outname == OperatorBase::TMP_VAR_NAME()) { + outname += op->type_; + outname += "@"; + outname += std::to_string(gUniqId.fetch_add(1)); + } + } + } + static std::unordered_map& creators() { static std::unordered_map creators_; return creators_; diff --git a/paddle/framework/operator.cc b/paddle/framework/operator.cc index 7756162a872db6de3f932a09e4b7525a2fde70da..50cb2d936274dcc046d5641ff276aae77358d1bf 100644 --- a/paddle/framework/operator.cc +++ b/paddle/framework/operator.cc @@ -12,30 +12,86 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include + #include "paddle/framework/operator.h" namespace paddle { namespace framework { +void OperatorBase::CreateInOutOffsetMap(const OpProto& proto) { + PADDLE_ENFORCE(in_out_idxs_.empty(), "duplicate call CreateInOutOffsetMap"); + for (int i = 0; i < proto.inputs_size(); i++) { + const auto& name = proto.inputs()[i].name(); + in_out_idxs_[name] = i; + } + for (int i = 0; i < proto.outputs_size(); i++) { + const auto& name = proto.outputs()[i].name(); + in_out_idxs_[name] = i; + } +} + +const std::string& OperatorBase::Input(const std::string& name) const { + auto it = in_out_idxs_.find(name); + PADDLE_ENFORCE(it != in_out_idxs_.end(), "no key [%s] in in_out_idxs_", name); + + if (attrs_.count("input_format") == 0) { + return inputs_[it->second]; + } else { + const auto& input_format = GetAttr>("input_format"); + int idx = input_format[it->second]; + return inputs_.at(idx); + } +} + +std::vector OperatorBase::Inputs(const std::string& name) const { + auto input_format = GetAttr>("input_format"); + auto offset = in_out_idxs_.at(name); + + return std::vector{ + inputs_.begin() + input_format.at(offset), + inputs_.begin() + input_format.at(offset + 1)}; +} + +const std::string& OperatorBase::Output(const std::string& name) const { + auto it = in_out_idxs_.find(name); + PADDLE_ENFORCE(it != in_out_idxs_.end(), "no key [%s] in in_out_idxs_", name); + + if (attrs_.count("output_format") == 0) { + return outputs_[it->second]; + } else { + const auto& output_format = GetAttr>("output_format"); + int idx = output_format[it->second]; + return outputs_.at(idx); + } +} + +std::vector OperatorBase::Outputs(const std::string& name) const { + auto output_format = GetAttr>("output_format"); + auto offset = in_out_idxs_.at(name); + + return std::vector{ + outputs_.begin() + output_format.at(offset), + outputs_.begin() + output_format.at(offset + 1)}; +} + std::string OperatorBase::DebugString() const { std::stringstream ss; - ss << "=================\n"; - ss << "type = " << type_ << "\n"; - ss << "inputs = ["; - for (auto& ipt : inputs_) { - ss << ipt << ", "; - } - ss << "]\n"; - ss << "outputs = ["; - for (auto& opt : outputs_) { - ss << opt << ", "; + ss << "Op(" << type_ << "), inputs:("; + for (size_t i = 0; i < inputs_.size(); ++i) { + ss << inputs_[i]; + if (i != inputs_.size() - 1) { + ss << ", "; + } } - ss << "]\n"; - ss << "attr_keys = ["; - for (auto& attr : attrs_) { - ss << attr.first << ", "; + ss << "), outputs:("; + for (size_t i = 0; i < outputs_.size(); ++i) { + ss << outputs_[i]; + if (i != outputs_.size() - 1) { + ss << ", "; + } } - ss << "]\n"; + ss << ")."; return ss.str(); } diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h index f7ed6e9f3d9422dd719721d4c3ab6b21d4e50e2d..2fe9670677c6041c0360d096b88e818676a8c929 100644 --- a/paddle/framework/operator.h +++ b/paddle/framework/operator.h @@ -14,18 +14,20 @@ limitations under the License. */ #pragma once -#include -#include -#include -#include -#include -#include -#include #include #include #include #include +#include "paddle/framework/attr_checker.h" +#include "paddle/framework/op_desc.pb.h" +#include "paddle/framework/op_proto.pb.h" +#include "paddle/framework/scope.h" +#include "paddle/framework/tensor.h" +#include "paddle/platform/device_context.h" +#include "paddle/platform/place.h" +#include "paddle/utils/Error.h" + namespace paddle { namespace framework { @@ -39,6 +41,13 @@ using OperatorPtr = std::shared_ptr; */ class OperatorBase { public: + /// If a variable is a empty variable, that name will be used. + static std::string EMPTY_VAR_NAME() { return "@EMPTY@"; } + + /// If a variable is a temporary variable, that name will be set in Python, + /// but it will be convert to a unique name in scope after OpCreator. + static std::string TMP_VAR_NAME() { return "@TEMP@"; } + virtual ~OperatorBase() {} template @@ -62,11 +71,72 @@ class OperatorBase { virtual void Run(const ScopePtr& scope, const platform::DeviceContext& dev_ctx) const = 0; + // Get a input with argument's name described in `op_proto` + const std::string& Input(const std::string& name) const; + // Get a input which has multiple variables. + // TODO add a vector_view to prevent memory copy. + std::vector Inputs(const std::string& name) const; + // Get a output with argument's name described in `op_proto` + const std::string& Output(const std::string& name) const; + // Get an output which has multiple variables. + // TODO add a vector_view to prevent memory copy. + std::vector Outputs(const std::string& name) const; + + // init in_out_idxs_ to accelerate argument's offset lookup. + void CreateInOutOffsetMap(const OpProto& proto); + public: std::string type_; std::vector inputs_; std::vector outputs_; AttributeMap attrs_; + // store the arguments' offset described in op_desc. + std::unordered_map in_out_idxs_; +}; + +class KernelContext { + public: + KernelContext(const OperatorBase* op, const std::shared_ptr& scope, + const platform::DeviceContext& device_context) + : op_(*op), scope_(scope), device_context_(device_context) {} + + const Variable* Input(int index) const { + return scope_->GetVariable(op_.inputs_[index]); + } + + Variable* Output(int index) const { + return scope_->GetVariable(op_.outputs_[index]); + } + + const Variable* Input(const std::string& name) const { + return scope_->GetVariable(op_.Input(name)); + } + + const Variable* Output(const std::string& name) const { + return scope_->GetVariable(op_.Output(name)); + } + + const std::vector Inputs(const std::string& name) const { + auto names = op_.Inputs(name); + std::vector res; + std::transform( + names.begin(), names.end(), res.begin(), + [this](const std::string& name) { return scope_->GetVariable(name); }); + return res; + } + + const std::vector Outputs(const std::string& name) const { + auto names = op_.Outputs(name); + std::vector res; + std::transform( + names.begin(), names.end(), res.begin(), + [this](const std::string& name) { return scope_->GetVariable(name); }); + return res; + } + + const OperatorBase& op_; + const std::shared_ptr& scope_; + const platform::DeviceContext& device_context_; }; class OpKernel { @@ -77,25 +147,6 @@ class OpKernel { * device resource such as CUDA stream, cublas handle, etc. from * KernelContext. User should construct it before run the Operator. */ - class KernelContext { - public: - KernelContext(const OperatorBase* op, const ScopePtr& scope, - const platform::DeviceContext& device_context) - : op_(*op), scope_(scope), device_context_(device_context) {} - - const Variable* Input(int index) const { - return scope_->GetVariable(op_.inputs_[index]); - } - - Variable* Output(int index) const { - return scope_->GetVariable(op_.outputs_[index]); - } - - const OperatorBase& op_; - const ScopePtr& scope_; - const platform::DeviceContext& device_context_; - }; - virtual void Compute(const KernelContext& context) const = 0; virtual ~OpKernel() {} @@ -140,7 +191,7 @@ class OperatorWithKernel : public OperatorBase { void Run(const ScopePtr& scope, const platform::DeviceContext& dev_ctx) const final { auto& opKernel = AllOpKernels().at(type_).at(OpKernelKey(dev_ctx)); - opKernel->Compute(OpKernel::KernelContext(this, scope, dev_ctx)); + opKernel->Compute(KernelContext(this, scope, dev_ctx)); } static std::unordered_map& @@ -148,6 +199,7 @@ class OperatorWithKernel : public OperatorBase { static std::unordered_map g_all_op_kernels; return g_all_op_kernels; } + void InferShape(const std::shared_ptr& scope) const final { std::vector ins; VarNamesToTensors(scope, inputs_, &ins); diff --git a/paddle/framework/operator_test.cc b/paddle/framework/operator_test.cc index 19ac4ecafa21d0a6fde57ef5e867670d7823fde0..6fa110f94ccc0c0a2f2e61316aa5dc271631a11c 100644 --- a/paddle/framework/operator_test.cc +++ b/paddle/framework/operator_test.cc @@ -30,7 +30,6 @@ class OpWithoutKernelTest : public OperatorBase { op_run_num++; ASSERT_EQ((int)inputs_.size(), 1); ASSERT_EQ((int)outputs_.size(), 1); - ASSERT_NEAR(GetAttr("scale"), 3.14, 1e-5); ASSERT_EQ(scope->GetVariable(inputs_[0]), nullptr); ASSERT_EQ(x, 1); ASSERT_NE(scope->GetVariable(outputs_[0]), nullptr); @@ -86,9 +85,11 @@ class OpKernelTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker { public: OpKernelTestProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { - AddInput("input", "input of test op"); - AddOutput("output", "output of test op"); - AddAttr("scale", "scale of cosine op"); + AddInput("x", "input of test op"); + AddOutput("y", "output of test op"); + AddAttr("scale", "scale of cosine op") + .SetDefault(1.0) + .LargerThan(0.0); AddComment("This is test op"); } }; @@ -103,11 +104,65 @@ class OpWithKernelTest : public OperatorWithKernel { class CPUKernelTest : public OpKernel { public: - void Compute(const KernelContext& context) const { + void Compute(const KernelContext& ctx) const { + std::cout << "this is cpu kernel" << std::endl; + std::cout << ctx.op_.DebugString() << std::endl; cpu_kernel_run_num++; - ASSERT_EQ((int)context.op_.inputs_.size(), 1); - ASSERT_EQ((int)context.op_.outputs_.size(), 1); - ASSERT_NEAR(context.op_.GetAttr("scale"), 3.14, 1e-5); + ASSERT_EQ(ctx.op_.Input("x"), "IN1"); + ASSERT_EQ(ctx.op_.Output("y"), "OUT1"); + } +}; + +// multiple inputs test +class OperatorMultiInputsTest : public OperatorBase { + public: + void Init() override { x = 1; } + void InferShape(const std::shared_ptr& scope) const override {} + void Run(const std::shared_ptr& scope, + const platform::DeviceContext& dev_ctx) const override { + ASSERT_EQ(scope->GetVariable(inputs_[0]), nullptr); + ASSERT_EQ(x, 1); + ASSERT_NE(scope->GetVariable(outputs_[0]), nullptr); + ASSERT_EQ(Input("x"), "IN1"); + ASSERT_EQ(Input("y"), "OUT1"); + } + + public: + float x = 0; +}; + +class OpKernelTestMultiInputsProtoAndCheckerMaker + : public OpProtoAndCheckerMaker { + public: + OpKernelTestMultiInputsProtoAndCheckerMaker(OpProto* proto, + OpAttrChecker* op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInputs("xs", "inputs of test op"); + AddInput("k", "input of test op"); + AddOutputs("ys", "outputs of test op"); + AddAttr("scale", "scale of cosine op") + .SetDefault(1.0) + .LargerThan(0.0); + AddComment("This is test op"); + } +}; + +class CPUKernalMultiInputsTest : public OpKernel { + public: + void Compute(const KernelContext& ctx) const { + auto xs = ctx.op_.Inputs("xs"); + ASSERT_EQ(xs.size(), 3UL); + ASSERT_EQ(xs[0], "x0"); + ASSERT_EQ(xs[1], "x1"); + ASSERT_EQ(xs[2], "x2"); + + auto k = ctx.op_.Input("k"); + ASSERT_EQ(k, "k0"); + + auto ys = ctx.op_.Outputs("ys"); + ASSERT_EQ(ys.size(), 2UL); + ASSERT_EQ(ys[0], "y0"); + ASSERT_EQ(ys[1], "y1"); } }; @@ -118,6 +173,7 @@ REGISTER_OP(op_with_kernel, paddle::framework::OpWithKernelTest, paddle::framework::OpKernelTestProtoAndCheckerMaker); REGISTER_OP_CPU_KERNEL(op_with_kernel, paddle::framework::CPUKernelTest); +// test with single input TEST(OpKernel, all) { paddle::framework::OpDesc op_desc; op_desc.set_type("op_with_kernel"); @@ -137,3 +193,47 @@ TEST(OpKernel, all) { op->Run(scope, cpu_device_context); ASSERT_EQ(paddle::framework::cpu_kernel_run_num, 1); } + +REGISTER_OP(op_multi_inputs_with_kernel, paddle::framework::OpWithKernelTest, + paddle::framework::OpKernelTestMultiInputsProtoAndCheckerMaker); +REGISTER_OP_CPU_KERNEL(op_multi_inputs_with_kernel, + paddle::framework::CPUKernalMultiInputsTest); + +// test with multi inputs +TEST(OpKernel, multi_inputs) { + using namespace paddle::framework; + + OpDesc op_desc; + op_desc.set_type("op_multi_inputs_with_kernel"); + *op_desc.mutable_inputs()->Add() = "x0"; + *op_desc.mutable_inputs()->Add() = "x1"; + *op_desc.mutable_inputs()->Add() = "x2"; + *op_desc.mutable_inputs()->Add() = "k0"; + *op_desc.mutable_outputs()->Add() = "y0"; + *op_desc.mutable_outputs()->Add() = "y1"; + auto attr = op_desc.mutable_attrs()->Add(); + attr->set_name("scale"); + attr->set_type(paddle::framework::AttrType::FLOAT); + attr->set_f(3.14); + + auto attr0 = op_desc.mutable_attrs()->Add(); + attr0->set_name("input_format"); + attr0->set_type(paddle::framework::AttrType::INTS); + auto input_format = attr0->mutable_ints(); + input_format->Add(0); // x0 + input_format->Add(3); // k + input_format->Add(4); // end + + auto attr1 = op_desc.mutable_attrs()->Add(); + attr1->set_name("output_format"); + attr1->set_type(paddle::framework::AttrType::INTS); + auto output_format = attr1->mutable_ints(); + output_format->Add(0); // y0 + output_format->Add(2); // y1 + + paddle::platform::CPUDeviceContext cpu_device_context; + auto scope = std::make_shared(); + + OperatorPtr op(paddle::framework::OpRegistry::CreateOp(op_desc)); + op->Run(scope, cpu_device_context); +} diff --git a/paddle/operators/add_op.h b/paddle/operators/add_op.h index 17d459dbc86af73fa86934a5ccce9da509c59c6b..000564f66dff2b20d17cee10e44aaca114c2c908 100644 --- a/paddle/operators/add_op.h +++ b/paddle/operators/add_op.h @@ -8,10 +8,10 @@ namespace operators { template class AddKernel : public framework::OpKernel { public: - void Compute(const KernelContext &context) const override { + void Compute(const framework::KernelContext &context) const override { LOG(INFO) << "Add kernel in " << typeid(Place).name(); } }; -} // namespace op +} // namespace operators } // namespace paddle diff --git a/paddle/pybind/pybind.cc b/paddle/pybind/pybind.cc index aa2b84799cc821a9bb06e68c1d5a77880862e040..9e480373874f7c6d24e12b4a657c8131adcc9278 100644 --- a/paddle/pybind/pybind.cc +++ b/paddle/pybind/pybind.cc @@ -67,6 +67,23 @@ All parameter, weight, gradient are variables in Paddle. } return ret_values; }); + m.def_submodule( + "var_names", + "The module will return special predefined variable name in Paddle") + .def("empty", pd::OperatorBase::EMPTY_VAR_NAME) + .def("temp", pd::OperatorBase::TMP_VAR_NAME); + + py::class_(m, "Operator") + .def("__str__", &pd::OperatorBase::DebugString) + .def_static("create", [](const std::string& protobin) { + pd::OpDesc desc; + PADDLE_ENFORCE(desc.ParsePartialFromString(protobin), + "Cannot parse user input to OpDesc"); + PADDLE_ENFORCE(desc.IsInitialized(), + "User OpDesc is not initialized, reason %s", + desc.InitializationErrorString()); + return pd::OpRegistry::CreateOp(desc); + }); return m.ptr(); } diff --git a/python/paddle/v2/framework/create_op_creation_methods.py b/python/paddle/v2/framework/create_op_creation_methods.py index 2fcdfead25414ccf44e9bfa964c83b98c852f6be..c2a7ae7692b08762ffbc91726be7bfa90e8ddedb 100644 --- a/python/paddle/v2/framework/create_op_creation_methods.py +++ b/python/paddle/v2/framework/create_op_creation_methods.py @@ -1,11 +1,246 @@ import paddle.v2.framework.core as core import paddle.v2.framework.proto.op_proto_pb2 as op_proto_pb2 +import paddle.v2.framework.proto.op_desc_pb2 as op_desc_pb2 +import paddle.v2.framework.proto.attr_type_pb2 as attr_type_pb2 +import cStringIO def get_all_op_protos(): + """ + Get all registered op proto from Paddle C++ + :return: list of OpProto + """ protostrs = core.get_all_op_protos() ret_values = [] for pbstr in protostrs: op_proto = op_proto_pb2.OpProto.FromString(str(pbstr)) ret_values.append(op_proto) return ret_values + + +class OpDescCreationMethod(object): + """ + A Functor object to convert user input(use key word args) to OpDesc based on + OpProto. + + :param op_proto: The OpProto object. + :type op_proto: op_proto_pb2.OpProto + """ + + def __init__(self, op_proto): + if not isinstance(op_proto, op_proto_pb2.OpProto): + raise TypeError("Argument should be OpProto") + self.__op_proto__ = op_proto + + def __call__(self, *args, **kwargs): + """ + Convert user input to OpDesc. Only key-word args are supported. + :return: OpDesc based on user input + :rtype: op_desc_pb2.OpDesc + """ + if len(args) != 0: + raise ValueError("Only keyword arguments is supported by Paddle") + op_desc = op_desc_pb2.OpDesc() + + # Inputs + ipts, ipt_format, _ = OpDescCreationMethod.extract_input_or_output( + "input", kwargs, self.__op_proto__.inputs) + op_desc.inputs.extend(ipts) + if ipt_format is not None: + op_desc.attrs.extend([ipt_format]) + + # Outputs + outs, out_format, tmp_index = OpDescCreationMethod.extract_input_or_output( + "output", kwargs, self.__op_proto__.outputs) + op_desc.outputs.extend(outs) + if out_format is not None: + op_desc.attrs.extend([out_format]) + if len(tmp_index) != 0: + tmp_index_attr = op_desc.attrs.add() + tmp_index_attr.type = attr_type_pb2.INTS + tmp_index_attr.name = "temporary_index" + tmp_index_attr.ints.extend(tmp_index) + + # Types + op_desc.type = self.__op_proto__.type + + # Attrs + for attr in self.__op_proto__.attrs: + if attr.generated: + continue + user_defined_attr = kwargs.get(attr.name, None) + if user_defined_attr is not None: + new_attr = op_desc.attrs.add() + new_attr.name = attr.name + new_attr.type = attr.type + if attr.type == attr_type_pb2.INT: + new_attr.i = user_defined_attr + elif attr.type == attr_type_pb2.FLOAT: + new_attr.f = user_defined_attr + elif attr.type == attr_type_pb2.STRING: + new_attr.s = user_defined_attr + elif attr.type == attr_type_pb2.INTS: + new_attr.ints.extend(user_defined_attr) + elif attr.type == attr_type_pb2.FLOATS: + new_attr.floats.extend(user_defined_attr) + elif attr.type == attr_type_pb2.STRINGS: + new_attr.strings.extend(user_defined_attr) + else: + raise NotImplementedError("Not support attribute type " + + attr.type) + + return op_desc + + @staticmethod + def extract_input_or_output(in_out, kwargs, meta): + """ + Extract input variable names or output variable names from key-word + arguments, which base on VarProtos. + + :param in_out: "input" or "output" + :param kwargs: key-word arguments that user inputted. + :param meta: a list of VarProto + :return: The three object will be return. The variable names. The + input_format or output_format attribute(None if the input or output is + not multiple). The temporary variable index list. + """ + multiple = OpDescCreationMethod.any_is_true((m.multiple for m in meta)) + tmp_index = [] + retv = [] + if multiple: + var_format = op_desc_pb2.AttrDesc() + var_format.type = attr_type_pb2.INTS + var_format.name = "%s_format" % in_out + var_format.ints.append(0) + + for var in meta: + var_name = var.name + + if var.temporary: + var_name = [core.var_names.temp()] + tmp_index.append(len(retv)) + else: + var_name = kwargs.get(var_name, []) + if not isinstance(var_name, list): + var_name = [var_name] + retv.extend(var_name) + var_format.ints.append(len(var_name) + var_format.ints[-1]) + return retv, var_format, tmp_index + else: + for var in meta: + if var.temporary: + retv.append(kwargs.get(var.name, core.var_names.temp())) + tmp_index.append(len(retv)) + else: + retv.append(kwargs.get(var.name, core.var_names.empty())) + return retv, None, tmp_index + + @staticmethod + def any_is_true(generator): + """ + Reduce a bool array to one. If any of them is True, then return True. + """ + for flag in generator: + if flag: + return True + return False + + +def get_docstring_from_op_proto(op_proto): + """ + Generate docstring from a OpProto + :param op_proto: a OpProto instance. + :type op_proto: op_proto_pb2.OpProto + :return: docstring + """ + if not isinstance(op_proto, op_proto_pb2.OpProto): + raise TypeError("Input must be OpProto") + f = cStringIO.StringIO() + f.write(op_proto.comment) + f.write("\n") + + def __append_param__(name, comment, type): + # Maybe replace the following line with template engine is better. + f.write(":param ") + f.write(name) + f.write(": ") + f.write(comment) + f.write("\n") + f.write(":type ") + f.write(name) + f.write(": ") + f.write(type) + f.write("\n") + + for ipt in op_proto.inputs: + __append_param__(ipt.name, ipt.comment, "list | basestr" + if ipt.multiple else "basestr") + + temp_var_prefix = \ + "This is a temporary variable. It does not have to set by user. " + for opt in op_proto.outputs: + __append_param__(opt.name, opt.comment if not opt.temporary else + temp_var_prefix + opt.comment, "list | basestr" + if opt.multiple else "basestr") + + for attr in op_proto.attrs: + attr_type = None + if attr.type == attr_type_pb2.INT: + attr_type = "int" + elif attr.type == attr_type_pb2.FLOAT: + attr_type = "float" + elif attr.type == attr_type_pb2.STRING: + attr_type = "basestr" + elif attr.type == attr_type_pb2.INTS: + attr_type = "list of int" + elif attr.type == attr_type_pb2.FLOATS: + attr_type = "list of float" + elif attr.type == attr_type_pb2.STRINGS: + attr_type = "list of basestr" + + if attr_type is None: + raise RuntimeError("Not supported attribute type " + attr.type) + + __append_param__(attr.name, attr.comment, attr_type) + + return f.getvalue() + + +def create_op_creation_method(op_proto): + """ + Generate op creation method for an OpProto + """ + method = OpDescCreationMethod(op_proto) + + def __impl__(*args, **kwargs): + opdesc = method(*args, **kwargs) + return core.Operator.create(opdesc.SerializeToString()) + + __impl__.__doc__ = get_docstring_from_op_proto(op_proto) + return __impl__ + + +class OpCreationsHolder(object): + """ + A object will holds all op creation methods. + + Use `op_creations.xxx_op` to access them. + """ + pass + + +op_creations = OpCreationsHolder() + + +def __bootstrap__(): + """ + Bootstrap function for this module. It will dynamic create all op creation + methods in runtime. + """ + for op_proto in get_all_op_protos(): + func = create_op_creation_method(op_proto) + func.__name__ = str(op_proto.type) + setattr(op_creations, func.__name__, func) + + +__bootstrap__() diff --git a/python/paddle/v2/framework/tests/test_op_creation_methods.py b/python/paddle/v2/framework/tests/test_op_creation_methods.py index b205e2cabb99ab08604ab3c3ce073bcb95ec4bb3..41db7c0d535aa920b34d6cc346090a8c15bfb110 100644 --- a/python/paddle/v2/framework/tests/test_op_creation_methods.py +++ b/python/paddle/v2/framework/tests/test_op_creation_methods.py @@ -1,9 +1,13 @@ import unittest import paddle.v2.framework.create_op_creation_methods as creation +import paddle.v2.framework.core as core +import paddle.v2.framework.proto.op_proto_pb2 as op_proto_pb2 +import paddle.v2.framework.proto.op_desc_pb2 as op_desc_pb2 +import paddle.v2.framework.proto.attr_type_pb2 as attr_type_pb2 -class TestOpCreationsMethods(unittest.TestCase): - def test_all_protos(self): +class TestGetAllProtos(unittest.TestCase): + def test_all(self): all_protos = creation.get_all_op_protos() self.assertNotEqual(0, len(all_protos)) @@ -11,5 +15,240 @@ class TestOpCreationsMethods(unittest.TestCase): self.assertTrue(each.IsInitialized()) +class TestOpDescCreationMethod(unittest.TestCase): + def test_plain_input_output(self): + op = op_proto_pb2.OpProto() + op.type = "test" + ipt = op.inputs.add() + ipt.name = "X" + ipt.comment = "not matter" + + ipt = op.inputs.add() + ipt.name = "Y" + ipt.comment = "not matter" + + opt = op.outputs.add() + opt.name = "Z" + opt.comment = "not matter" + + op.comment = "not matter" + + self.assertTrue(op.IsInitialized()) + + method = creation.OpDescCreationMethod(op) + output = method(X="a", Y="b", Z="c") + + expected = op_desc_pb2.OpDesc() + expected.type = "test" + expected.inputs.extend(["a", "b"]) + expected.outputs.append("c") + self.assertEqual(expected, output) + + def test_multiple_input_plain_output(self): + op = op_proto_pb2.OpProto() + op.type = "fc" + ipt = op.inputs.add() + ipt.name = "X" + ipt.comment = "" + ipt.multiple = True + + ipt = op.inputs.add() + ipt.name = "W" + ipt.comment = "" + ipt.multiple = True + + ipt = op.inputs.add() + ipt.name = "b" + ipt.comment = "" + + out = op.outputs.add() + out.name = "Y" + out.comment = "" + + op.comment = "" + self.assertTrue(op.IsInitialized()) + method = creation.OpDescCreationMethod(op) + + generated1 = method(X="x", W="w", b="b", Y="y") + expected1 = op_desc_pb2.OpDesc() + expected1.inputs.extend(['x', 'w', 'b']) + expected1.outputs.extend(['y']) + expected1.type = 'fc' + attr = expected1.attrs.add() + attr.name = 'input_format' + attr.type = attr_type_pb2.INTS + attr.ints.extend([0, 1, 2, 3]) + self.assertEqual(expected1, generated1) + + generated2 = method( + X=['x1', 'x2', 'x3'], b='b', W=['w1', 'w2', 'w3'], Y='y') + expected2 = op_desc_pb2.OpDesc() + expected2.inputs.extend(['x1', 'x2', 'x3', 'w1', 'w2', 'w3', 'b']) + expected2.outputs.extend(['y']) + expected2.type = 'fc' + attr = expected2.attrs.add() + attr.name = 'input_format' + attr.type = attr_type_pb2.INTS + attr.ints.extend([0, 3, 6, 7]) + self.assertEqual(expected2, generated2) + + def test_attrs(self): + op = op_proto_pb2.OpProto() + op.type = "test" + ipt = op.inputs.add() + ipt.name = 'X' + ipt.comment = "" + + def __add_attr__(name, type): + attr = op.attrs.add() + attr.name = name + attr.comment = "" + attr.type = type + + __add_attr__("int_attr", attr_type_pb2.INT) + __add_attr__("float_attr", attr_type_pb2.FLOAT) + __add_attr__("string_attr", attr_type_pb2.STRING) + __add_attr__("ints_attr", attr_type_pb2.INTS) + __add_attr__("floats_attr", attr_type_pb2.FLOATS) + __add_attr__("strings_attr", attr_type_pb2.STRINGS) + + op.comment = "" + self.assertTrue(op.IsInitialized()) + + method = creation.OpDescCreationMethod(op) + + generated = method( + X="a", + int_attr=10, + float_attr=3.2, + string_attr="test_str", + ints_attr=[0, 1, 2, 3, 4], + floats_attr=[0.2, 3.2, 4.5], + strings_attr=["a", "b", "c"]) + + expected = op_desc_pb2.OpDesc() + expected.type = "test" + expected.inputs.extend(['a']) + attr = expected.attrs.add() + attr.name = "int_attr" + attr.type = attr_type_pb2.INT + attr.i = 10 + + attr = expected.attrs.add() + attr.name = "float_attr" + attr.type = attr_type_pb2.FLOAT + attr.f = 3.2 + + attr = expected.attrs.add() + attr.name = "string_attr" + attr.type = attr_type_pb2.STRING + attr.s = "test_str" + + attr = expected.attrs.add() + attr.name = "ints_attr" + attr.type = attr_type_pb2.INTS + attr.ints.extend([0, 1, 2, 3, 4]) + + attr = expected.attrs.add() + attr.name = 'floats_attr' + attr.type = attr_type_pb2.FLOATS + attr.floats.extend([0.2, 3.2, 4.5]) + + attr = expected.attrs.add() + attr.name = 'strings_attr' + attr.type = attr_type_pb2.STRINGS + attr.strings.extend(['a', 'b', 'c']) + + self.assertEqual(expected, generated) + + def test_input_temporary_output(self): + op = op_proto_pb2.OpProto() + op.type = "test" + out = op.outputs.add() + out.name = "OUT" + out.comment = "" + + out = op.outputs.add() + out.name = "TMP" + out.comment = "" + out.temporary = True + + out = op.outputs.add() + out.name = "OUT2" + out.comment = "" + op.comment = "" + + method = creation.OpDescCreationMethod(op) + generated = method(OUT="a", OUT2="b") + desc = op_desc_pb2.OpDesc() + desc.outputs.extend(["a", core.var_names.temp(), "b"]) + desc.type = "test" + attr = desc.attrs.add() + attr.name = "temporary_index" + attr.type = attr_type_pb2.INTS + attr.ints.append(2) + self.assertEqual(generated, desc) + + +class TestOpCreationDocStr(unittest.TestCase): + def test_all(self): + op = op_proto_pb2.OpProto() + op.type = "test" + op.comment = """Test Op. + +This op is used for unit test, not a real op. +""" + a = op.inputs.add() + a.name = "a" + a.comment = "Input a for test op" + a.multiple = True + + b = op.inputs.add() + b.name = "b" + b.comment = "Input b for test op" + self.assertTrue(op.IsInitialized()) + + o1 = op.outputs.add() + o1.name = "output" + o1.comment = "The output of test op" + + o2 = op.outputs.add() + o2.name = "temp output" + o2.comment = "The temporary output of test op" + o2.temporary = True + + test_str = op.attrs.add() + test_str.name = "str_attr" + test_str.type = attr_type_pb2.STRING + test_str.comment = "A string attribute for test op" + + actual = creation.get_docstring_from_op_proto(op) + expected_docstring = '''Test Op. + +This op is used for unit test, not a real op. + +:param a: Input a for test op +:type a: list | basestr +:param b: Input b for test op +:type b: basestr +:param output: The output of test op +:type output: basestr +:param temp output: This is a temporary variable. It does not have to set by user. The temporary output of test op +:type temp output: basestr +:param str_attr: A string attribute for test op +:type str_attr: basestr +''' + self.assertEqual(expected_docstring, actual) + + +class TestOpCreations(unittest.TestCase): + def test_all(self): + add_op = creation.op_creations.add_two(X="a", Y="b", Out="z") + self.assertIsNotNone(add_op) + # Invoke C++ DebugString() + self.assertEqual('Op(add_two), inputs:(a, b), outputs:(z).', + str(add_op)) + + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/v2/optimizer.py b/python/paddle/v2/optimizer.py index 7e8a3bece9acb93ffba634e7b3656f2b580a48f2..ba581980334fec6226a537af2cf53b3465d32c1e 100644 --- a/python/paddle/v2/optimizer.py +++ b/python/paddle/v2/optimizer.py @@ -1,4 +1,3 @@ -import py_paddle.swig_paddle as swig_api import paddle.trainer_config_helpers.config_parser_utils as config_parser_utils import paddle.trainer_config_helpers.optimizers as v1_optimizers """ @@ -17,6 +16,7 @@ __all__ = [ class Optimizer(object): def __init__(self, **kwargs): + import py_paddle.swig_paddle as swig_api if 'batch_size' in kwargs: del kwargs['batch_size'] # not important for python library. @@ -35,18 +35,22 @@ class Optimizer(object): For each optimizer(SGD, Adam), GradientMachine should enable different buffers. """ + import py_paddle.swig_paddle as swig_api tmp = swig_api.ParameterOptimizer.create(self.__opt_conf__) assert isinstance(tmp, swig_api.ParameterOptimizer) return tmp.getParameterTypes() def __create_local_updater__(self): + import py_paddle.swig_paddle as swig_api return swig_api.ParameterUpdater.createLocalUpdater(self.__opt_conf__) def __create_remote_updater__(self, pass_num, use_sparse_updater): + import py_paddle.swig_paddle as swig_api return swig_api.ParameterUpdater.createRemoteUpdater( self.__opt_conf__, pass_num, use_sparse_updater) def __create_new_remote_updater__(self, pserver_spec, use_etcd): + import py_paddle.swig_paddle as swig_api return swig_api.ParameterUpdater.createNewRemoteUpdater( self.__opt_conf__, pserver_spec, use_etcd)