提交 edb541f2 编写于 作者: F fengjiayi

fix compile errors

上级 3e6e5c92
......@@ -25,8 +25,9 @@ static void TransOpArg(const OperatorBase* src_op, const OpArgType& src_type,
const auto& src_inout =
src_type == OpArgType::IN ? src_op->inputs_ : src_op->outputs_;
auto& dst_inout = *vars;
const OpProto* proto = OpRegistry::op_info_map().at(src_op->type_).proto_;
const auto& src_arg_list =
src_type == OpArgType::IN ? proto.inputs() : proto.outputs();
src_type == OpArgType::IN ? proto->inputs() : proto->outputs();
for (const auto& arg : src_arg_list) {
if (arg.no_gradient() && !is_grad) continue;
const std::string src_name = arg.name();
......@@ -43,6 +44,8 @@ OperatorBase* BuildGradOp(const OperatorBase* op) {
auto it = OpRegistry::op_info_map().find(op->type_);
PADDLE_ENFORCE(it != OpRegistry::op_info_map().end(),
"'%s' has not been registered.", op->type_);
PADDLE_ENFORCE(it->second.proto_ != nullptr, "'%s' has no OpProto.",
op->type_);
std::string grad_op_type = it->second.grad_op_type_;
PADDLE_ENFORCE(!grad_op_type.empty(), "'%s' has no gradient operator.",
op->type_);
......
......@@ -126,13 +126,6 @@ class NOPMaker : public OpProtoAndCheckerMaker {
: OpProtoAndCheckerMaker(proto, op_checker) {}
};
struct OpInfo {
std::function<OperatorBase*()> creator_;
std::string grad_op_type_;
OpProto* proto_;
OpAttrChecker* checker_;
};
class OpRegistry {
using VarNameMap = OperatorBase::VarNameMap;
using OpCreator = std::function<OperatorBase*(
......@@ -140,6 +133,13 @@ class OpRegistry {
const VarNameMap& /*outputs*/, const AttributeMap& /*attrs*/)>;
public:
struct OpInfo {
OpCreator creator_;
std::string grad_op_type_;
OpProto* proto_;
OpAttrChecker* checker_;
};
template <typename OpType, typename ProtoMakerType, typename GradOpType>
static void RegisterOp(const std::string& op_type,
const std::string& grad_op_type) {
......@@ -175,9 +175,9 @@ class OpRegistry {
}
static std::shared_ptr<OperatorBase> CreateOp(const std::string& type,
const VarNameList& inputs,
const VarNameList& outputs,
const AttributeMap& attrs) {
const VarNameMap& inputs,
const VarNameMap& outputs,
AttributeMap attrs) {
auto it = op_info_map().find(type);
PADDLE_ENFORCE(it != op_info_map().end(),
"Operator '%s' has not been registered.", type);
......
......@@ -152,7 +152,7 @@ std::vector<std::string> OperatorBase::OutputVars(bool has_intermediate) const {
type_);
// get all OpProto::Var for outputs
for (auto& o : it->second.proto_.outputs()) {
for (auto& o : it->second.proto_->outputs()) {
// ignore all intermediate output
if (o.intermediate()) continue;
auto out = outputs_.find(o.name());
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册