diff --git a/doc/howto/dev/new_op_cn.md b/doc/howto/dev/new_op_cn.md index 9d3d02ffc3116ebec537ab9b890eafccad196ed0..c823d7e9fcd63dd7719ac1403952b03c2d2f03c0 100644 --- a/doc/howto/dev/new_op_cn.md +++ b/doc/howto/dev/new_op_cn.md @@ -206,7 +206,7 @@ MulOp(const std::string &type, const framework::VariableNameMap &inputs, - `REGISTER_OP` : 注册`ops::MulOp`类,类型名为`mul`,该类的`ProtoMaker`为`ops::MulOpMaker`,注册`ops::MulOpGrad`,类型名为`mul_grad`。 - `REGISTER_OP_WITHOUT_GRADIENT` : 用于注册没有反向的Op。 - - `REGISTER_OP_CPU_KERNEL` :注册`ops::MulKernel`类,并特化模板参数为`paddle::platform::CPUPlace`和`float`类型,同理,注册`ops::MulKernel`类。 + - `REGISTER_OP_CPU_KERNEL` :注册`ops::MulKernel`类,并特化模板参数为`paddle::platform::CPUPlace`和`float`类型,同理,注册`ops::MulGradKernel`类。 - 在 `.cu`文件中注册GPU Kernel。 diff --git a/doc/howto/dev/new_op_en.md b/doc/howto/dev/new_op_en.md index 57ff7caad19cc6bf2e4a052d306d4fc303c8875d..1e88e1f5b4df710f1b69f0305d8d8a2921c4249a 100644 --- a/doc/howto/dev/new_op_en.md +++ b/doc/howto/dev/new_op_en.md @@ -205,7 +205,7 @@ The definition of its corresponding backward operator, if applicable, is similar - `REGISTER_OP` registers the `ops::MulOp` class, type named `mul`, its type `ProtoMaker` is `ops::MulOpMaker`, registering `ops::MulOpGrad` as `mul_grad`. - `REGISTER_OP_WITHOUT_GRADIENT` registers an operator without gradient. - - `REGISTER_OP_CPU_KERNEL` registers `ops::MulKernel` class and specialized template types `paddle::platform::CPUPlace` and `float`, which also registers `ops::MulKernel`. + - `REGISTER_OP_CPU_KERNEL` registers `ops::MulKernel` class and specialized template types `paddle::platform::CPUPlace` and `float`, which also registers `ops::MulGradKernel`. - Registering GPU Kernel in `.cu` files diff --git a/paddle/framework/attribute.h b/paddle/framework/attribute.h index c7559cefb6415ee141f32e4357459653564cd2ac..d13530e3408a54c7ecab87c3bd9e6288e342f9af 100644 --- a/paddle/framework/attribute.h +++ b/paddle/framework/attribute.h @@ -21,20 +21,12 @@ limitations under the License. */ #include #include "paddle/framework/framework.pb.h" +#include "paddle/framework/type_defs.h" #include "paddle/platform/enforce.h" -#include "paddle/platform/variant.h" namespace paddle { namespace framework { -// The order should be as same as framework.proto -typedef boost::variant, - std::vector, std::vector, bool, - std::vector, BlockDesc*> - Attribute; - -typedef std::unordered_map AttributeMap; - ProgramDesc& GetProgramDesc(); template diff --git a/paddle/framework/op_desc.cc b/paddle/framework/op_desc.cc index 0c12c55dc09f6aa064066b5c73bc5e985a57343f..33a064890cc62a4e3f04c79bc000951e76fd477c 100644 --- a/paddle/framework/op_desc.cc +++ b/paddle/framework/op_desc.cc @@ -112,6 +112,30 @@ const std::unordered_map &OpDescBind::GetAttrMap() return attrs_; } +struct SetAttrDescVisitor : public boost::static_visitor { + explicit SetAttrDescVisitor(OpDesc::Attr *attr) : attr_(attr) {} + mutable OpDesc::Attr *attr_; + void operator()(int v) const { attr_->set_i(v); } + void operator()(float v) const { attr_->set_f(v); } + void operator()(const std::string &v) const { attr_->set_s(v); } + void operator()(bool b) const { attr_->set_b(b); } + + void operator()(const std::vector &v) const { + VectorToRepeated(v, attr_->mutable_ints()); + } + void operator()(const std::vector &v) const { + VectorToRepeated(v, attr_->mutable_floats()); + } + void operator()(const std::vector &v) const { + VectorToRepeated(v, attr_->mutable_strings()); + } + void operator()(const std::vector &v) const { + VectorToRepeated(v, attr_->mutable_bools()); + } + void operator()(BlockDesc *desc) const { attr_->set_block_idx(desc->idx()); } + void operator()(boost::blank) const { PADDLE_THROW("Unexpected branch"); } +}; + void OpDescBind::Sync() { if (need_update_) { this->op_desc_.mutable_inputs()->Clear(); @@ -134,7 +158,8 @@ void OpDescBind::Sync() { attr_desc->set_name(attr.first); attr_desc->set_type( static_cast(attr.second.which() - 1)); - boost::apply_visitor(SetAttrDescVisitor(attr_desc), attr.second); + SetAttrDescVisitor visitor(attr_desc); + boost::apply_visitor(visitor, attr.second); } need_update_ = false; diff --git a/paddle/framework/op_desc.h b/paddle/framework/op_desc.h index 0cf7d13971675eb825bcd0c7636896f0862d6ebb..0af416971556c812db238caeb7959a0a488c6e1b 100644 --- a/paddle/framework/op_desc.h +++ b/paddle/framework/op_desc.h @@ -17,6 +17,7 @@ limitations under the License. */ #include #include #include "paddle/framework/attribute.h" +#include "paddle/framework/type_defs.h" #include "paddle/framework/var_desc.h" namespace paddle { @@ -61,48 +62,22 @@ class OpDescBind { void SetBlockAttr(const std::string &name, BlockDescBind &block); // Only be used in C++ - void SetAttrMap(const std::unordered_map &attr_map); + void SetAttrMap(const AttributeMap &attr_map); Attribute GetAttr(const std::string &name) const; int GetBlockAttr(const std::string &name) const; // Only be used in C++ - const std::unordered_map &GetAttrMap() const; + const AttributeMap &GetAttrMap() const; private: - struct SetAttrDescVisitor : public boost::static_visitor { - explicit SetAttrDescVisitor(OpDesc::Attr *attr) : attr_(attr) {} - mutable OpDesc::Attr *attr_; - void operator()(int v) const { attr_->set_i(v); } - void operator()(float v) const { attr_->set_f(v); } - void operator()(const std::string &v) const { attr_->set_s(v); } - void operator()(bool b) const { attr_->set_b(b); } - - void operator()(const std::vector &v) const { - VectorToRepeated(v, attr_->mutable_ints()); - } - void operator()(const std::vector &v) const { - VectorToRepeated(v, attr_->mutable_floats()); - } - void operator()(const std::vector &v) const { - VectorToRepeated(v, attr_->mutable_strings()); - } - void operator()(const std::vector &v) const { - VectorToRepeated(v, attr_->mutable_bools()); - } - void operator()(BlockDesc *desc) const { - attr_->set_block_idx(desc->idx()); - } - void operator()(boost::blank) const { PADDLE_THROW("Unexpected branch"); } - }; - void Sync(); OpDesc op_desc_; - std::unordered_map> inputs_; - std::unordered_map> outputs_; - std::unordered_map attrs_; + VariableNameMap inputs_; + VariableNameMap outputs_; + AttributeMap attrs_; // need_update_ indicate there some local changes not be synchronized. If // local changes should be synchronized, need_update_ should be set to true. diff --git a/paddle/framework/op_info.h b/paddle/framework/op_info.h index 5df309331856ad891b495c7d505863c2b18c7c41..7940922b0974d47e8defb85edae69a8e72470c59 100644 --- a/paddle/framework/op_info.h +++ b/paddle/framework/op_info.h @@ -19,16 +19,11 @@ #include #include "paddle/framework/attribute.h" #include "paddle/framework/op_desc.h" -#include "paddle/platform/macros.h" +#include "paddle/framework/type_defs.h" + namespace paddle { namespace framework { -class OperatorBase; -using VariableNameMap = std::map>; - -using OpCreator = std::function; class GradOpDescMakerBase { public: diff --git a/paddle/framework/type_defs.h b/paddle/framework/type_defs.h new file mode 100644 index 0000000000000000000000000000000000000000..dec5066f1e649a27c9f47fc5f93faa1498b37de7 --- /dev/null +++ b/paddle/framework/type_defs.h @@ -0,0 +1,38 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#pragma once +#include +#include +#include "paddle/platform/variant.h" + +namespace paddle { +namespace framework { +class OperatorBase; +using VariableNameMap = std::map>; + +// The order should be as same as framework.proto +using Attribute = + boost::variant, + std::vector, std::vector, bool, + std::vector, BlockDesc*>; + +using AttributeMap = std::unordered_map; + +using OpCreator = std::function; + +} // namespace framework +} // namespace paddle diff --git a/paddle/platform/enforce.h b/paddle/platform/enforce.h index b523ef03c0053622bfda5b4bf07515c1b480b4af..52bd23039b82fc0290edb640f5d5851ba9abee61 100644 --- a/paddle/platform/enforce.h +++ b/paddle/platform/enforce.h @@ -185,7 +185,7 @@ inline void throw_on_error(T e) { std::make_exception_ptr( \ std::runtime_error(paddle::string::Sprintf(__VA_ARGS__))), \ __FILE__, __LINE__); \ - } while (0) + } while (false) #define PADDLE_ENFORCE(...) \ do { \ @@ -195,7 +195,7 @@ inline void throw_on_error(T e) { throw ::paddle::platform::EnforceNotMet(std::current_exception(), \ __FILE__, __LINE__); \ } \ - } while (0) + } while (false) /* * Some enforce helpers here, usage: