提交 6768b310 编写于 作者: F fengjiayi

Fix compile error

上级 3e11e4c6
......@@ -50,7 +50,7 @@ static void TransOpArg(const OperatorBase* src_op, OperatorBase* dst_op,
std::vector<std::string>& dst_inout =
dst_type == OpArgType::IN ? dst_op->inputs_ : dst_op->outputs_;
std::vector<int>* dst_format = GetOpFormat(dst_op, dst_type);
const OpProto& proto = OpRegistry::protos().at(src_op->type_);
const OpProto& proto = *(OpRegistry::op_info_map().at(src_op->type_).proto_);
const auto& src_arg_list =
src_type == OpArgType::IN ? proto.inputs() : proto.outputs();
......@@ -76,13 +76,13 @@ static void TransOpArg(const OperatorBase* src_op, OperatorBase* dst_op,
}
OperatorBase* BuildGradOp(const OperatorBase* op) {
auto it = op_info_map().find(op->type_);
auto it = OpRegistry::op_info_map().find(op->type_);
PADDLE_ENFORCE(it != OpRegistry::op_info_map().end(),
"'%s' has not been registered.", op->type);
"'%s' has not been registered.", op->type_);
std::string grad_op_type = it->second.grad_op_type_;
PADDLE_ENFORCE(!grad_op_type.empty(), "'%s' has no gradient operator.",
op->type);
it = op_info_map().find(grad_op_type);
op->type_);
it = OpRegistry::op_info_map().find(grad_op_type);
PADDLE_ENFORCE(it != OpRegistry::op_info_map().end(),
"'%s' has not been registered.", grad_op_type);
OperatorBase* grad_op = it->second.creator_();
......
......@@ -175,17 +175,20 @@ Add a mark to which output is temporary is helpful for future optimization.
bool has_temporary_output_{false};
};
class NOPMaker : public OpProtoAndCheckerMaker {};
class NOPMaker : public OpProtoAndCheckerMaker {
public:
NOPMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {}
};
struct OpInfo {
std::function creator_;
std::function<OperatorBase*()> creator_;
std::string grad_op_type_;
OpProto* proto_;
OpAttrChecker* checker_;
};
class OpRegistry {
using OpCreator = std::function<OperatorBase*()>;
using VarIndexMap = std::unordered_map<std::string, int>;
using VarNameList = std::vector<std::string>;
......@@ -201,28 +204,28 @@ class OpRegistry {
if (std::type_index(typeid(ProtoMakerType)) !=
std::type_index(typeid(NOPMaker))) {
op_info.proto_ = new OpProto;
op_info.op_checker_ = new OpAttrChecker;
auto maker = ProtoMakerType(op_info.proto_, op_info.op_checker_);
op_info.checker_ = new OpAttrChecker;
auto maker = ProtoMakerType(op_info.proto_, op_info.checker_);
maker.Validate();
*op_info.proto_->mutable_type() = op_type;
PADDLE_ENFORCE(
op_info.proto_->IsInitialized(),
"Fail to initialize %s's OpProto, because %s is not initialized",
op_type, op_info.proto_->InitializationErrorString());
//======will be refactored in following PRs============//
// ======will be refactored in following PRs============ //
VarIndexMaps()[op_type].reset(new VarIndexMap());
auto& varmap = *VarIndexMaps()[op_type];
int idx = 0;
for (auto& var : op_proto.inputs()) {
for (auto& var : op_info.proto_->inputs()) {
varmap[var.name()] = idx++;
}
idx = 0;
for (auto& var : op_proto.outputs()) {
for (auto& var : op_info.proto_->outputs()) {
varmap[var.name()] = idx++;
}
//================================================//
// ================================================ //
}
op_info_map.insert(std::make_pair(op_type, op_info));
op_info_map().insert(std::make_pair(op_type, op_info));
}
static std::shared_ptr<OperatorBase> CreateOp(const std::string& type,
......@@ -281,8 +284,8 @@ class OpRegistry {
return grad_op;
}
static std::unordered_map<const std::string, const OpInfo>& op_info_map() {
static std::unordered_map<const std::string, const OpInfo> op_info_map_;
static std::unordered_map<std::string, const OpInfo>& op_info_map() {
static std::unordered_map<std::string, const OpInfo> op_info_map_;
return op_info_map_;
}
......@@ -321,7 +324,7 @@ class Registrar {
template <typename OpType, typename ProtoMakerType>
class OpRegistrar : public Registrar {
public:
OpRegistrar(const char* op_type) { OpRegistrar(op_type, ""); }
explicit OpRegistrar(const char* op_type) { OpRegistrar(op_type, ""); }
OpRegistrar(const char* op_type, const char* grad_op_type) {
OpRegistry::RegisterOp<OpType, ProtoMakerType>(op_type, grad_op_type);
}
......
......@@ -188,8 +188,9 @@ class CPUKernalMultiInputsTest : public OpKernel {
} // namespace framework
} // namespace paddle
REGISTER_OP(op_with_kernel, paddle::framework::OpWithKernelTest,
paddle::framework::OpKernelTestProtoAndCheckerMaker);
REGISTER_OP_WITHOUT_GRADIENT(
op_with_kernel, paddle::framework::OpWithKernelTest,
paddle::framework::OpKernelTestProtoAndCheckerMaker);
REGISTER_OP_CPU_KERNEL(op_with_kernel,
paddle::framework::CPUKernelTest<float, float>);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册