提交 edb541f2 编写于 作者: F fengjiayi

fix compile errors

上级 3e6e5c92
...@@ -25,8 +25,9 @@ static void TransOpArg(const OperatorBase* src_op, const OpArgType& src_type, ...@@ -25,8 +25,9 @@ static void TransOpArg(const OperatorBase* src_op, const OpArgType& src_type,
const auto& src_inout = const auto& src_inout =
src_type == OpArgType::IN ? src_op->inputs_ : src_op->outputs_; src_type == OpArgType::IN ? src_op->inputs_ : src_op->outputs_;
auto& dst_inout = *vars; auto& dst_inout = *vars;
const OpProto* proto = OpRegistry::op_info_map().at(src_op->type_).proto_;
const auto& src_arg_list = const auto& src_arg_list =
src_type == OpArgType::IN ? proto.inputs() : proto.outputs(); src_type == OpArgType::IN ? proto->inputs() : proto->outputs();
for (const auto& arg : src_arg_list) { for (const auto& arg : src_arg_list) {
if (arg.no_gradient() && !is_grad) continue; if (arg.no_gradient() && !is_grad) continue;
const std::string src_name = arg.name(); const std::string src_name = arg.name();
...@@ -43,6 +44,8 @@ OperatorBase* BuildGradOp(const OperatorBase* op) { ...@@ -43,6 +44,8 @@ OperatorBase* BuildGradOp(const OperatorBase* op) {
auto it = OpRegistry::op_info_map().find(op->type_); auto it = OpRegistry::op_info_map().find(op->type_);
PADDLE_ENFORCE(it != OpRegistry::op_info_map().end(), PADDLE_ENFORCE(it != OpRegistry::op_info_map().end(),
"'%s' has not been registered.", op->type_); "'%s' has not been registered.", op->type_);
PADDLE_ENFORCE(it->second.proto_ != nullptr, "'%s' has no OpProto.",
op->type_);
std::string grad_op_type = it->second.grad_op_type_; std::string grad_op_type = it->second.grad_op_type_;
PADDLE_ENFORCE(!grad_op_type.empty(), "'%s' has no gradient operator.", PADDLE_ENFORCE(!grad_op_type.empty(), "'%s' has no gradient operator.",
op->type_); op->type_);
......
...@@ -126,13 +126,6 @@ class NOPMaker : public OpProtoAndCheckerMaker { ...@@ -126,13 +126,6 @@ class NOPMaker : public OpProtoAndCheckerMaker {
: OpProtoAndCheckerMaker(proto, op_checker) {} : OpProtoAndCheckerMaker(proto, op_checker) {}
}; };
struct OpInfo {
std::function<OperatorBase*()> creator_;
std::string grad_op_type_;
OpProto* proto_;
OpAttrChecker* checker_;
};
class OpRegistry { class OpRegistry {
using VarNameMap = OperatorBase::VarNameMap; using VarNameMap = OperatorBase::VarNameMap;
using OpCreator = std::function<OperatorBase*( using OpCreator = std::function<OperatorBase*(
...@@ -140,6 +133,13 @@ class OpRegistry { ...@@ -140,6 +133,13 @@ class OpRegistry {
const VarNameMap& /*outputs*/, const AttributeMap& /*attrs*/)>; const VarNameMap& /*outputs*/, const AttributeMap& /*attrs*/)>;
public: public:
struct OpInfo {
OpCreator creator_;
std::string grad_op_type_;
OpProto* proto_;
OpAttrChecker* checker_;
};
template <typename OpType, typename ProtoMakerType, typename GradOpType> template <typename OpType, typename ProtoMakerType, typename GradOpType>
static void RegisterOp(const std::string& op_type, static void RegisterOp(const std::string& op_type,
const std::string& grad_op_type) { const std::string& grad_op_type) {
...@@ -175,9 +175,9 @@ class OpRegistry { ...@@ -175,9 +175,9 @@ class OpRegistry {
} }
static std::shared_ptr<OperatorBase> CreateOp(const std::string& type, static std::shared_ptr<OperatorBase> CreateOp(const std::string& type,
const VarNameList& inputs, const VarNameMap& inputs,
const VarNameList& outputs, const VarNameMap& outputs,
const AttributeMap& attrs) { AttributeMap attrs) {
auto it = op_info_map().find(type); auto it = op_info_map().find(type);
PADDLE_ENFORCE(it != op_info_map().end(), PADDLE_ENFORCE(it != op_info_map().end(),
"Operator '%s' has not been registered.", type); "Operator '%s' has not been registered.", type);
......
...@@ -152,7 +152,7 @@ std::vector<std::string> OperatorBase::OutputVars(bool has_intermediate) const { ...@@ -152,7 +152,7 @@ std::vector<std::string> OperatorBase::OutputVars(bool has_intermediate) const {
type_); type_);
// get all OpProto::Var for outputs // get all OpProto::Var for outputs
for (auto& o : it->second.proto_.outputs()) { for (auto& o : it->second.proto_->outputs()) {
// ignore all intermediate output // ignore all intermediate output
if (o.intermediate()) continue; if (o.intermediate()) continue;
auto out = outputs_.find(o.name()); auto out = outputs_.find(o.name());
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册