From 5e37872462c7dfec33f8da80335520a645beb1b8 Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Wed, 2 Aug 2017 16:56:40 -0700 Subject: [PATCH] Refine code --- paddle/framework/grad_op_builder.cc | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/paddle/framework/grad_op_builder.cc b/paddle/framework/grad_op_builder.cc index 9f7856a79b6..afb8a2cfe14 100644 --- a/paddle/framework/grad_op_builder.cc +++ b/paddle/framework/grad_op_builder.cc @@ -45,12 +45,10 @@ static void TransOpArg(const OperatorBase* src_op, OperatorBase* dst_op, int& idx, bool is_grad) { const std::vector& src_inout = src_type == IN ? src_op->inputs_ : src_op->outputs_; - const VarIndexMap& src_varmap = *src_op->in_out_idxs_; const std::vector* src_format = GetOpFormat(src_op, src_type); std::vector& dst_inout = dst_type == IN ? dst_op->inputs_ : dst_op->outputs_; - VarIndexMap& dst_varmap = *dst_op->in_out_idxs_; std::vector* dst_format = GetOpFormat(dst_op, dst_type); const OpProto& proto = OpRegistry::protos().at(src_op->type_); const auto& src_arg_list = src_type == IN ? proto.inputs() : proto.outputs(); @@ -59,8 +57,8 @@ static void TransOpArg(const OperatorBase* src_op, OperatorBase* dst_op, std::string src_name = arg.name(); std::string dst_name = is_grad ? src_name + OperatorBase::GRAD_VAR_SUFFIX() : src_name; - dst_varmap[dst_name] = idx++; - int src_arg_idx = src_varmap.at(src_name); + (*dst_op->in_out_idxs_)[dst_name] = idx++; + int src_arg_idx = src_op->in_out_idxs_->at(src_name); int src_begin = src_format == nullptr ? src_arg_idx : src_format->at(src_arg_idx); int src_end = src_format == nullptr ? src_arg_idx + 1 -- GitLab