Commit 6768b310 authored by: F fengjiayi

Fix compile error

Parent 3e11e4c6
@@ -50,7 +50,7 @@ static void TransOpArg(const OperatorBase* src_op, OperatorBase* dst_op,
   std::vector<std::string>& dst_inout =
       dst_type == OpArgType::IN ? dst_op->inputs_ : dst_op->outputs_;
   std::vector<int>* dst_format = GetOpFormat(dst_op, dst_type);
-  const OpProto& proto = OpRegistry::protos().at(src_op->type_);
+  const OpProto& proto = *(OpRegistry::op_info_map().at(src_op->type_).proto_);
   const auto& src_arg_list =
       src_type == OpArgType::IN ? proto.inputs() : proto.outputs();
@@ -76,13 +76,13 @@ static void TransOpArg(const OperatorBase* src_op, OperatorBase* dst_op,
 }
 
 OperatorBase* BuildGradOp(const OperatorBase* op) {
-  auto it = op_info_map().find(op->type_);
+  auto it = OpRegistry::op_info_map().find(op->type_);
   PADDLE_ENFORCE(it != OpRegistry::op_info_map().end(),
-                 "'%s' has not been registered.", op->type);
+                 "'%s' has not been registered.", op->type_);
   std::string grad_op_type = it->second.grad_op_type_;
   PADDLE_ENFORCE(!grad_op_type.empty(), "'%s' has no gradient operator.",
-                 op->type);
-  it = op_info_map().find(grad_op_type);
+                 op->type_);
+  it = OpRegistry::op_info_map().find(grad_op_type);
   PADDLE_ENFORCE(it != OpRegistry::op_info_map().end(),
                  "'%s' has not been registered.", grad_op_type);
   OperatorBase* grad_op = it->second.creator_();
......
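The two hunks above make BuildGradOp read the registry through the class-qualified OpRegistry::op_info_map() accessor and the renamed type_ member. The sketch below is a self-contained toy of that lookup-then-create flow; OperatorBase, OpInfo, op_info_map, and the "mul"/"mul_grad" entries are illustrative stand-ins, not PaddlePaddle's real types.

#include <cassert>
#include <functional>
#include <string>
#include <unordered_map>

// Toy stand-ins; names mirror the diff but are illustrative only.
struct OperatorBase {
  std::string type_;
};

struct OpInfo {
  std::function<OperatorBase*()> creator_;  // a call signature is required here
  std::string grad_op_type_;                // empty means "no gradient operator"
};

// Function-local static map behind an accessor, as in OpRegistry::op_info_map().
std::unordered_map<std::string, OpInfo>& op_info_map() {
  static std::unordered_map<std::string, OpInfo> op_info_map_;
  return op_info_map_;
}

// Lookup-then-create flow mirroring BuildGradOp in the hunk above.
OperatorBase* BuildGradOp(const OperatorBase* op) {
  auto it = op_info_map().find(op->type_);
  assert(it != op_info_map().end() && "op has not been registered");
  const std::string grad_op_type = it->second.grad_op_type_;
  assert(!grad_op_type.empty() && "op has no gradient operator");
  it = op_info_map().find(grad_op_type);
  assert(it != op_info_map().end() && "grad op has not been registered");
  OperatorBase* grad_op = it->second.creator_();  // build via the registered creator
  grad_op->type_ = grad_op_type;
  return grad_op;
}

int main() {
  op_info_map()["mul"] = {[] { return new OperatorBase{"mul"}; }, "mul_grad"};
  op_info_map()["mul_grad"] = {[] { return new OperatorBase{"mul_grad"}; }, ""};
  OperatorBase mul{"mul"};
  OperatorBase* grad = BuildGradOp(&mul);
  assert(grad->type_ == "mul_grad");
  delete grad;
  return 0;
}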
@@ -175,17 +175,20 @@ Add a mark to which output is temporary is helpful for future optimization.
   bool has_temporary_output_{false};
 };
 
-class NOPMaker : public OpProtoAndCheckerMaker {};
+class NOPMaker : public OpProtoAndCheckerMaker {
+ public:
+  NOPMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
+      : OpProtoAndCheckerMaker(proto, op_checker) {}
+};
 
 struct OpInfo {
-  std::function creator_;
+  std::function<OperatorBase*()> creator_;
   std::string grad_op_type_;
   OpProto* proto_;
   OpAttrChecker* checker_;
 };
 
 class OpRegistry {
-  using OpCreator = std::function<OperatorBase*()>;
   using VarIndexMap = std::unordered_map<std::string, int>;
   using VarNameList = std::vector<std::string>;
@@ -201,28 +204,28 @@ class OpRegistry {
     if (std::type_index(typeid(ProtoMakerType)) !=
         std::type_index(typeid(NOPMaker))) {
       op_info.proto_ = new OpProto;
-      op_info.op_checker_ = new OpAttrChecker;
-      auto maker = ProtoMakerType(op_info.proto_, op_info.op_checker_);
+      op_info.checker_ = new OpAttrChecker;
+      auto maker = ProtoMakerType(op_info.proto_, op_info.checker_);
       maker.Validate();
       *op_info.proto_->mutable_type() = op_type;
       PADDLE_ENFORCE(
           op_info.proto_->IsInitialized(),
           "Fail to initialize %s's OpProto, because %s is not initialized",
           op_type, op_info.proto_->InitializationErrorString());
-      //======will be refactored in following PRs============//
+      // ======will be refactored in following PRs============ //
       VarIndexMaps()[op_type].reset(new VarIndexMap());
       auto& varmap = *VarIndexMaps()[op_type];
       int idx = 0;
-      for (auto& var : op_proto.inputs()) {
+      for (auto& var : op_info.proto_->inputs()) {
         varmap[var.name()] = idx++;
       }
       idx = 0;
-      for (auto& var : op_proto.outputs()) {
+      for (auto& var : op_info.proto_->outputs()) {
         varmap[var.name()] = idx++;
       }
-      //================================================//
+      // ================================================ //
     }
-    op_info_map.insert(std::make_pair(op_type, op_info));
+    op_info_map().insert(std::make_pair(op_type, op_info));
   }
 
   static std::shared_ptr<OperatorBase> CreateOp(const std::string& type,
@@ -281,8 +284,8 @@ class OpRegistry {
     return grad_op;
   }
 
-  static std::unordered_map<const std::string, const OpInfo>& op_info_map() {
-    static std::unordered_map<const std::string, const OpInfo> op_info_map_;
+  static std::unordered_map<std::string, const OpInfo>& op_info_map() {
+    static std::unordered_map<std::string, const OpInfo> op_info_map_;
     return op_info_map_;
   }
@@ -321,7 +324,7 @@ class Registrar {
 template <typename OpType, typename ProtoMakerType>
 class OpRegistrar : public Registrar {
  public:
-  OpRegistrar(const char* op_type) { OpRegistrar(op_type, ""); }
+  explicit OpRegistrar(const char* op_type) { OpRegistrar(op_type, ""); }
   OpRegistrar(const char* op_type, const char* grad_op_type) {
     OpRegistry::RegisterOp<OpType, ProtoMakerType>(op_type, grad_op_type);
   }
......
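Several of the op_registry.h changes above are pure compile fixes: the OpInfo member needs a call signature on std::function, free-standing helpers must qualify op_info_map() with OpRegistry::, the checker_ field is referenced under one consistent name, and the map may no longer be keyed on const std::string, since std::hash has no specialization for a const-qualified key type. Below is a minimal sketch of the corrected accessor, with OpInfo reduced to a stand-in struct; the insert-rather-than-operator[] call mirrors the diff, since assignment through operator[] does not work with a const mapped type.

#include <string>
#include <unordered_map>
#include <utility>

// Illustrative stand-in for the registry entry; not the real framework type.
struct OpInfo {
  std::string grad_op_type_;
};

// Keyed on plain std::string: std::unordered_map<const std::string, const OpInfo>
// would not compile because the default hasher std::hash<const std::string> is
// not defined.  Only the mapped type stays const, as in the fixed declaration.
std::unordered_map<std::string, const OpInfo>& op_info_map() {
  static std::unordered_map<std::string, const OpInfo> op_info_map_;
  return op_info_map_;
}

int main() {
  OpInfo op_info{"mul_grad"};
  // Entries go in via insert(), as in the RegisterOp hunk above; operator[]
  // cannot be used for assignment when the mapped type is const.
  op_info_map().insert(std::make_pair(std::string("mul"), op_info));
  return 0;
}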
@@ -188,8 +188,9 @@ class CPUKernalMultiInputsTest : public OpKernel {
 }  // namespace framework
 }  // namespace paddle
 
-REGISTER_OP(op_with_kernel, paddle::framework::OpWithKernelTest,
-            paddle::framework::OpKernelTestProtoAndCheckerMaker);
+REGISTER_OP_WITHOUT_GRADIENT(
+    op_with_kernel, paddle::framework::OpWithKernelTest,
+    paddle::framework::OpKernelTestProtoAndCheckerMaker);
 
 REGISTER_OP_CPU_KERNEL(op_with_kernel,
                        paddle::framework::CPUKernelTest<float, float>);
......
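The test now registers op_with_kernel through REGISTER_OP_WITHOUT_GRADIENT, which matches the one-argument OpRegistrar constructor added above (it forwards an empty grad_op_type). The macro body itself is not part of this diff; a hypothetical expansion built on the OpRegistrar shown earlier could look roughly like this, with the registrar variable name made up for illustration:

// Hypothetical sketch only -- the real macro is defined elsewhere in op_registry.h.
// A registration macro of this kind typically defines a file-scope registrar whose
// constructor runs during static initialization and registers the operator with an
// empty grad_op_type:
#define REGISTER_OP_WITHOUT_GRADIENT(op_type, op_class, op_maker_class) \
  static ::paddle::framework::OpRegistrar<op_class, op_maker_class>     \
      op_type##_op_registrar(#op_type);

Because the one-argument constructor forwards an empty grad_op_type, an operator registered this way ends up in op_info_map() with an empty grad_op_type_, which is exactly the case BuildGradOp rejects with "'%s' has no gradient operator."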