提交 77af58f8 编写于 作者: F fengjiayi

Change gradient Op registry mechanism

OLD: op_type -> grad_op_creator

NEW: grad_op_type -> grad_op_creator
     op_type -> grad_op_type
上级 02cde244
...@@ -12,20 +12,22 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,20 +12,22 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/framework/grad_op_creator.h" #include "paddle/framework/grad_op_builder.h"
#include "paddle/framework/op_registry.h" #include "paddle/framework/op_registry.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
OperatorBase* GradOpCreator::Create() { OperatorBase* GradOpBuilder::Build() {
BuildOpInOutArgList(); BuildOpInOutArgList();
OperatorBase* grad_op = OpRegistry::grad_creators().at(op_->type_)(); std::string grad_op_type = OpRegistry::grad_ops().at(op->type_);
OperatorBase* grad_op = OpRegistry::op_creators().at(grad_op_type)();
grad_op->type_ = grad_op_type;
CompleteGradOp(grad_op); CompleteGradOp(grad_op);
return grad_op; return grad_op;
} }
OpInOutArg* GradOpCreator::BuildArg(const VarProto& var, OpInOutArg* GradOpBuilder::BuildArg(const VarProto& var,
const VarIndexMap& var_map, const VarIndexMap& var_map,
const std::vector<int>& format, const std::vector<int>& format,
InOutType type) { InOutType type) {
...@@ -36,7 +38,7 @@ OpInOutArg* GradOpCreator::BuildArg(const VarProto& var, ...@@ -36,7 +38,7 @@ OpInOutArg* GradOpCreator::BuildArg(const VarProto& var,
end_idx); end_idx);
} }
void GradOpCreator::BuildOpInOutArgList() { void GradOpBuilder::BuildOpInOutArgList() {
const OpProto& op_proto = OpRegistry::protos().at(op_->type_); const OpProto& op_proto = OpRegistry::protos().at(op_->type_);
const auto& var_map = *(OpRegistry::VarIndexMaps().at(op_->type_)); const auto& var_map = *(OpRegistry::VarIndexMaps().at(op_->type_));
const std::vector<int>& in_format = const std::vector<int>& in_format =
...@@ -57,7 +59,7 @@ void GradOpCreator::BuildOpInOutArgList() { ...@@ -57,7 +59,7 @@ void GradOpCreator::BuildOpInOutArgList() {
} }
} }
void GradOpCreator::AddArgIntoGradOp(const OpInOutArg* arg, void GradOpBuilder::AddArgIntoGradOp(const OpInOutArg* arg,
std::vector<std::string>& in_out, std::vector<std::string>& in_out,
std::vector<int>& format, std::vector<int>& format,
VarIndexMap* varmap, int& idx, VarIndexMap* varmap, int& idx,
...@@ -80,8 +82,7 @@ void GradOpCreator::AddArgIntoGradOp(const OpInOutArg* arg, ...@@ -80,8 +82,7 @@ void GradOpCreator::AddArgIntoGradOp(const OpInOutArg* arg,
format.push_back(in_out.size()); format.push_back(in_out.size());
} }
void GradOpCreator::CompleteGradOp(OperatorBase* grad_op) const { void GradOpBuilder::CompleteGradOp(OperatorBase* grad_op) const {
grad_op->type_ = op_->type_ + "@GRAD"; // not necessary
grad_op->attrs_ = op_->attrs_; grad_op->attrs_ = op_->attrs_;
grad_op->attrs_.erase("input_format"); grad_op->attrs_.erase("input_format");
grad_op->attrs_.erase("output_format"); grad_op->attrs_.erase("output_format");
......
...@@ -25,12 +25,12 @@ struct OpInOutArg { ...@@ -25,12 +25,12 @@ struct OpInOutArg {
size_t end_idx_; size_t end_idx_;
}; };
class GradOpCreator { class GradOpBuilder {
using VarIndexMap = std::unordered_map<std::string, int>; using VarIndexMap = std::unordered_map<std::string, int>;
public: public:
GradOpCreator(const OperatorBase* op) : op_(op) {} GradOpBuilder(const OperatorBase* op) : op_(op) {}
OperatorBase* Create(); OperatorBase* Build();
private: private:
OpInOutArg* BuildArg(const VarProto& var, const VarIndexMap& var_map, OpInOutArg* BuildArg(const VarProto& var, const VarIndexMap& var_map,
......
...@@ -222,7 +222,7 @@ class OpRegistry { ...@@ -222,7 +222,7 @@ class OpRegistry {
public: public:
template <typename OpType, typename ProtoMakerType> template <typename OpType, typename ProtoMakerType>
static void RegisterOp(const std::string& op_type) { static void RegisterOp(const std::string& op_type) {
creators()[op_type] = [] { return new OpType; }; op_creators()[op_type] = [] { return new OpType; };
OpAttrChecker& op_checker = op_checkers()[op_type]; OpAttrChecker& op_checker = op_checkers()[op_type];
OpProto& op_proto = protos()[op_type]; OpProto& op_proto = protos()[op_type];
auto maker = ProtoMakerType(&op_proto, &op_checker); auto maker = ProtoMakerType(&op_proto, &op_checker);
...@@ -245,17 +245,19 @@ class OpRegistry { ...@@ -245,17 +245,19 @@ class OpRegistry {
} }
} }
template <typename OpType> template <typename GradOpType>
static void RegisterGradOp(const std::string& op_type) { static void RegisterGradOp(const std::string& op_type,
grad_creators()[op_type] = [] { return new OpType; }; const std::string& grad_op_type) {
op_creators()[grad_op_type] = [] { return new GradOpType; };
grad_ops()[op_type] = grad_op_type;
} }
static std::shared_ptr<OperatorBase> CreateOp(const std::string& type, static std::shared_ptr<OperatorBase> CreateOp(const std::string& type,
const VarNameList& inputs, const VarNameList& inputs,
const VarNameList& outputs, const VarNameList& outputs,
const AttributeMap& attrs) { const AttributeMap& attrs) {
auto op_create_it = creators().find(type); auto op_create_it = op_creators().find(type);
PADDLE_ENFORCE(op_create_it != creators().end(), PADDLE_ENFORCE(op_create_it != op_creators().end(),
"Operator %s cannot be found.", type); "Operator %s cannot be found.", type);
auto op = op_create_it->second(); auto op = op_create_it->second();
...@@ -300,8 +302,8 @@ class OpRegistry { ...@@ -300,8 +302,8 @@ class OpRegistry {
static std::shared_ptr<OperatorBase> CreateGradOp( static std::shared_ptr<OperatorBase> CreateGradOp(
std::shared_ptr<OperatorBase> op) { std::shared_ptr<OperatorBase> op) {
GradOpCreator creator(op.get()); GradOpBuilder builder(op.get());
std::shared_ptr<OperatorBase> grad_op(creator.Create()); std::shared_ptr<OperatorBase> grad_op(builder.Build());
grad_op->Init(); grad_op->Init();
return grad_op; return grad_op;
} }
...@@ -311,9 +313,9 @@ class OpRegistry { ...@@ -311,9 +313,9 @@ class OpRegistry {
return protos_; return protos_;
}; };
static std::unordered_map<std::string, OpCreator>& grad_creators() { static std::unordered_map<std::string, std::string>& grad_ops() {
static std::unordered_map<std::string, OpCreator> grad_creators_; static std::unordered_map<std::string, std::string> grad_ops_;
return grad_creators_; return grad_ops_;
} }
static std::unordered_map<std::string, std::shared_ptr<VarIndexMap>>& static std::unordered_map<std::string, std::shared_ptr<VarIndexMap>>&
...@@ -322,12 +324,12 @@ class OpRegistry { ...@@ -322,12 +324,12 @@ class OpRegistry {
return maps_; return maps_;
} }
private: static std::unordered_map<std::string, OpCreator>& op_creators() {
static std::unordered_map<std::string, OpCreator>& creators() { static std::unordered_map<std::string, OpCreator> op_creators_;
static std::unordered_map<std::string, OpCreator> creators_; return op_creators_;
return creators_;
} }
private:
static std::unordered_map<std::string, OpAttrChecker>& op_checkers() { static std::unordered_map<std::string, OpAttrChecker>& op_checkers() {
static std::unordered_map<std::string, OpAttrChecker> op_checkers_; static std::unordered_map<std::string, OpAttrChecker> op_checkers_;
return op_checkers_; return op_checkers_;
...@@ -353,11 +355,11 @@ class OpRegisterHelper { ...@@ -353,11 +355,11 @@ class OpRegisterHelper {
} }
}; };
template <typename OpType> template <typename GradOpType>
class GradOpRegisterHelper { class GradOpRegisterHelper {
public: public:
GradOpRegisterHelper(const char* op_type) { GradOpRegisterHelper(const char* op_type, const char* grad_op_type) {
OpRegistry::RegisterGradOp<OpType>(op_type); OpRegistry::RegisterGradOp<GradOpType>(op_type, grad_op_type);
} }
}; };
...@@ -383,13 +385,16 @@ class GradOpRegisterHelper { ...@@ -383,13 +385,16 @@ class GradOpRegisterHelper {
/** /**
* Macro to Register Gradient Operator. * Macro to Register Gradient Operator.
*/ */
#define REGISTER_GRADIENT_OP(__op_type, __op_class) \ #define REGISTER_GRADIENT_OP(__op_type, __grad_op_type, __grad_op_class) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \ STATIC_ASSERT_GLOBAL_NAMESPACE( \
__reg_gradient_op__##__op_type, \ __reg_gradient_op__##__op_type##__grad_op_type, \
"REGISTER_GRADIENT_OP must be in global namespace"); \ "REGISTER_GRADIENT_OP must be in global namespace"); \
static ::paddle::framework::GradOpRegisterHelper<__op_class> \ static ::paddle::framework::GradOpRegisterHelper<__grad_op_class> \
__op_gradient_register_##__op_type##__(#__op_type); \ __op_gradient_register_##__op_type##__grad_op_type##__(#__op_type, \
int __op_gradient_register_##__op_type##_handle__() { return 0; } #__grad_op_type); \
int __op_gradient_register_##__op_type##__grad_op_type##_handle__() { \
return 0; \
}
/** /**
* Macro to Register OperatorKernel. * Macro to Register OperatorKernel.
......
...@@ -65,6 +65,6 @@ protected: ...@@ -65,6 +65,6 @@ protected:
} // namespace paddle } // namespace paddle
REGISTER_OP(add_two, paddle::operators::AddOp, paddle::operators::AddOpMaker); REGISTER_OP(add_two, paddle::operators::AddOp, paddle::operators::AddOpMaker);
REGISTER_GRADIENT_OP(add_two, paddle::operators::AddOpGrad); REGISTER_GRADIENT_OP(add_two, add_two_grad, paddle::operators::AddOpGrad);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
add_two, paddle::operators::AddKernel<paddle::platform::CPUPlace, float>); add_two, paddle::operators::AddKernel<paddle::platform::CPUPlace, float>);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册