diff --git a/paddle/framework/grad_op_builder.cc b/paddle/framework/grad_op_builder.cc index 9f7856a79b678ab81046f22c526d17c33fff4b40..afb8a2cfe1441146a63fd585ad6528d9872e31fb 100644 --- a/paddle/framework/grad_op_builder.cc +++ b/paddle/framework/grad_op_builder.cc @@ -45,12 +45,10 @@ static void TransOpArg(const OperatorBase* src_op, OperatorBase* dst_op, int& idx, bool is_grad) { const std::vector& src_inout = src_type == IN ? src_op->inputs_ : src_op->outputs_; - const VarIndexMap& src_varmap = *src_op->in_out_idxs_; const std::vector* src_format = GetOpFormat(src_op, src_type); std::vector& dst_inout = dst_type == IN ? dst_op->inputs_ : dst_op->outputs_; - VarIndexMap& dst_varmap = *dst_op->in_out_idxs_; std::vector* dst_format = GetOpFormat(dst_op, dst_type); const OpProto& proto = OpRegistry::protos().at(src_op->type_); const auto& src_arg_list = src_type == IN ? proto.inputs() : proto.outputs(); @@ -59,8 +57,8 @@ static void TransOpArg(const OperatorBase* src_op, OperatorBase* dst_op, std::string src_name = arg.name(); std::string dst_name = is_grad ? src_name + OperatorBase::GRAD_VAR_SUFFIX() : src_name; - dst_varmap[dst_name] = idx++; - int src_arg_idx = src_varmap.at(src_name); + (*dst_op->in_out_idxs_)[dst_name] = idx++; + int src_arg_idx = src_op->in_out_idxs_->at(src_name); int src_begin = src_format == nullptr ? src_arg_idx : src_format->at(src_arg_idx); int src_end = src_format == nullptr ? src_arg_idx + 1