From 6768b31037161fa8a9979bd2b4294adbf11966c2 Mon Sep 17 00:00:00 2001
From: fengjiayi
Date: Fri, 11 Aug 2017 13:43:31 -0700
Subject: [PATCH] Fix compile error

---
 paddle/framework/grad_op_builder.cc | 10 +++++-----
 paddle/framework/op_registry.h      | 29 ++++++++++++++++-------------
 paddle/framework/operator_test.cc   |  5 +++--
 3 files changed, 24 insertions(+), 20 deletions(-)

diff --git a/paddle/framework/grad_op_builder.cc b/paddle/framework/grad_op_builder.cc
index ff8a5583afe..f534b2c3366 100644
--- a/paddle/framework/grad_op_builder.cc
+++ b/paddle/framework/grad_op_builder.cc
@@ -50,7 +50,7 @@ static void TransOpArg(const OperatorBase* src_op, OperatorBase* dst_op,
   std::vector& dst_inout =
       dst_type == OpArgType::IN ? dst_op->inputs_ : dst_op->outputs_;
   std::vector* dst_format = GetOpFormat(dst_op, dst_type);
-  const OpProto& proto = OpRegistry::protos().at(src_op->type_);
+  const OpProto& proto = *(OpRegistry::op_info_map().at(src_op->type_).proto_);
   const auto& src_arg_list =
       src_type == OpArgType::IN ? proto.inputs() : proto.outputs();
 
@@ -76,13 +76,13 @@ static void TransOpArg(const OperatorBase* src_op, OperatorBase* dst_op,
 }
 
 OperatorBase* BuildGradOp(const OperatorBase* op) {
-  auto it = op_info_map().find(op->type_);
+  auto it = OpRegistry::op_info_map().find(op->type_);
   PADDLE_ENFORCE(it != OpRegistry::op_info_map().end(),
-                 "'%s' has not been registered.", op->type);
+                 "'%s' has not been registered.", op->type_);
   std::string grad_op_type = it->second.grad_op_type_;
   PADDLE_ENFORCE(!grad_op_type.empty(), "'%s' has no gradient operator.",
-                 op->type);
-  it = op_info_map().find(grad_op_type);
+                 op->type_);
+  it = OpRegistry::op_info_map().find(grad_op_type);
   PADDLE_ENFORCE(it != OpRegistry::op_info_map().end(),
                  "'%s' has not been registered.", grad_op_type);
   OperatorBase* grad_op = it->second.creator_();
diff --git a/paddle/framework/op_registry.h b/paddle/framework/op_registry.h
index b88559f82b1..69c5f549e37 100644
--- a/paddle/framework/op_registry.h
+++ b/paddle/framework/op_registry.h
@@ -175,17 +175,20 @@ Add a mark to which output is temporary is helpful for future optimization.
   bool has_temporary_output_{false};
 };
 
-class NOPMaker : public OpProtoAndCheckerMaker {};
+class NOPMaker : public OpProtoAndCheckerMaker {
+ public:
+  NOPMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
+      : OpProtoAndCheckerMaker(proto, op_checker) {}
+};
 
 struct OpInfo {
-  std::function creator_;
+  std::function creator_;
   std::string grad_op_type_;
   OpProto* proto_;
   OpAttrChecker* checker_;
 };
 
 class OpRegistry {
-  using OpCreator = std::function;
   using VarIndexMap = std::unordered_map;
   using VarNameList = std::vector;
 
@@ -201,28 +204,28 @@ class OpRegistry {
     if (std::type_index(typeid(ProtoMakerType)) !=
         std::type_index(typeid(NOPMaker))) {
       op_info.proto_ = new OpProto;
-      op_info.op_checker_ = new OpAttrChecker;
-      auto maker = ProtoMakerType(op_info.proto_, op_info.op_checker_);
+      op_info.checker_ = new OpAttrChecker;
+      auto maker = ProtoMakerType(op_info.proto_, op_info.checker_);
       maker.Validate();
       *op_info.proto_->mutable_type() = op_type;
       PADDLE_ENFORCE(
           op_info.proto_->IsInitialized(),
           "Fail to initialize %s's OpProto, because %s is not initialized",
           op_type, op_info.proto_->InitializationErrorString());
-      //======will be refactored in following PRs============//
+      // ======will be refactored in following PRs============ //
       VarIndexMaps()[op_type].reset(new VarIndexMap());
       auto& varmap = *VarIndexMaps()[op_type];
       int idx = 0;
-      for (auto& var : op_proto.inputs()) {
+      for (auto& var : op_info.proto_->inputs()) {
         varmap[var.name()] = idx++;
       }
       idx = 0;
-      for (auto& var : op_proto.outputs()) {
+      for (auto& var : op_info.proto_->outputs()) {
         varmap[var.name()] = idx++;
       }
-      //================================================//
+      // ================================================ //
     }
-    op_info_map.insert(std::make_pair(op_type, op_info));
+    op_info_map().insert(std::make_pair(op_type, op_info));
   }
 
   static std::shared_ptr CreateOp(const std::string& type,
@@ -281,8 +284,8 @@ class OpRegistry {
     return grad_op;
   }
 
-  static std::unordered_map& op_info_map() {
-    static std::unordered_map op_info_map_;
+  static std::unordered_map& op_info_map() {
+    static std::unordered_map op_info_map_;
     return op_info_map_;
   }
 
@@ -321,7 +324,7 @@ class Registrar {
 template
 class OpRegistrar : public Registrar {
  public:
-  OpRegistrar(const char* op_type) { OpRegistrar(op_type, ""); }
+  explicit OpRegistrar(const char* op_type) { OpRegistrar(op_type, ""); }
   OpRegistrar(const char* op_type, const char* grad_op_type) {
     OpRegistry::RegisterOp(op_type, grad_op_type);
   }
diff --git a/paddle/framework/operator_test.cc b/paddle/framework/operator_test.cc
index b1976a65149..3887cadc60e 100644
--- a/paddle/framework/operator_test.cc
+++ b/paddle/framework/operator_test.cc
@@ -188,8 +188,9 @@ class CPUKernalMultiInputsTest : public OpKernel {
 }  // namespace framework
 }  // namespace paddle
 
-REGISTER_OP(op_with_kernel, paddle::framework::OpWithKernelTest,
-            paddle::framework::OpKernelTestProtoAndCheckerMaker);
+REGISTER_OP_WITHOUT_GRADIENT(
+    op_with_kernel, paddle::framework::OpWithKernelTest,
+    paddle::framework::OpKernelTestProtoAndCheckerMaker);
 REGISTER_OP_CPU_KERNEL(op_with_kernel, paddle::framework::CPUKernelTest);
 
-- 
GitLab
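
Note, not part of the patch itself: below is a minimal, self-contained C++ sketch of the registry pattern the change converges on. The names OperatorBase, OpInfo, and AddOp here are illustrative stand-ins, not PaddlePaddle's real classes. It illustrates why the registry accessor must be called as a function, i.e. op_info_map().insert(...) rather than op_info_map.insert(...), which is one of the compile errors this patch fixes.

// Standalone sketch with illustrative names, not PaddlePaddle code.
#include <functional>
#include <iostream>
#include <string>
#include <unordered_map>

struct OperatorBase {
  virtual ~OperatorBase() = default;
  virtual void Run() const = 0;
};

struct OpInfo {
  std::function<OperatorBase*()> creator_;  // factory for this op type
  std::string grad_op_type_;                // empty if the op has no gradient
};

// Accessor returning a function-local static map: a single registry instance
// per process, created on first use.
std::unordered_map<std::string, OpInfo>& op_info_map() {
  static std::unordered_map<std::string, OpInfo> op_info_map_;
  return op_info_map_;
}

struct AddOp : OperatorBase {
  void Run() const override { std::cout << "running add\n"; }
};

int main() {
  // Registration: store a creator under the op type name.
  op_info_map().insert({"add", OpInfo{[] { return new AddOp; }, ""}});

  // Lookup and instantiation, mirroring the find()/creator_() steps that
  // BuildGradOp performs in the patch.
  auto it = op_info_map().find("add");
  if (it != op_info_map().end()) {
    OperatorBase* op = it->second.creator_();
    op->Run();
    delete op;
  }
}

Keeping the map inside an accessor function (rather than as a class-level static data member) also sidesteps static-initialization-order issues when registrars run before main().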