提交 5e378724 编写于 作者: F fengjiayi

Refine code

上级 ab18947e
......@@ -45,12 +45,10 @@ static void TransOpArg(const OperatorBase* src_op, OperatorBase* dst_op,
int& idx, bool is_grad) {
const std::vector<std::string>& src_inout =
src_type == IN ? src_op->inputs_ : src_op->outputs_;
const VarIndexMap& src_varmap = *src_op->in_out_idxs_;
const std::vector<int>* src_format = GetOpFormat(src_op, src_type);
std::vector<std::string>& dst_inout =
dst_type == IN ? dst_op->inputs_ : dst_op->outputs_;
VarIndexMap& dst_varmap = *dst_op->in_out_idxs_;
std::vector<int>* dst_format = GetOpFormat(dst_op, dst_type);
const OpProto& proto = OpRegistry::protos().at(src_op->type_);
const auto& src_arg_list = src_type == IN ? proto.inputs() : proto.outputs();
......@@ -59,8 +57,8 @@ static void TransOpArg(const OperatorBase* src_op, OperatorBase* dst_op,
std::string src_name = arg.name();
std::string dst_name =
is_grad ? src_name + OperatorBase::GRAD_VAR_SUFFIX() : src_name;
dst_varmap[dst_name] = idx++;
int src_arg_idx = src_varmap.at(src_name);
(*dst_op->in_out_idxs_)[dst_name] = idx++;
int src_arg_idx = src_op->in_out_idxs_->at(src_name);
int src_begin =
src_format == nullptr ? src_arg_idx : src_format->at(src_arg_idx);
int src_end = src_format == nullptr ? src_arg_idx + 1
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册