未验证 提交 5c5a3660 编写于 作者: L Leo Chen 提交者: GitHub

[new-exec] update the dependency of op that has inplace_back_map (#41009)

上级 77a455c7
......@@ -305,11 +305,6 @@ void ApplyDataTransform(const OpKernelType& expected_kernel_key,
var_scope->GetIdByName(var_name);
op_func_node->output_index[pair.first][j] =
var_scope->VarId(new_var_name);
// NOTE(zhiqiu): The inplace op with `transfer` also changes
// original output after that
// so add original output as well
op_func_node->output_index[pair.first].push_back(
var_scope->VarId(var_name));
}
}
}
......
......@@ -667,6 +667,23 @@ std::map<int, std::list<int>> build_op_downstream_map(
}
}
}
// NOTE(zhiqiu): The inplace op with `transfer` also changes
// original output after that so add original output as well
// original: a->op->a
// after: a->data_transfer->a'->op->a'->transfer_back->a
// which means op writes a and a'
if (!vec_instruction[op_idx].InplaceBackMap().empty()) {
auto& m = vec_instruction[op_idx].InplaceBackMap();
for (auto& p : m) {
auto var = p.second;
var2recent_write_op[var] = op_idx;
// var in input list and in output list, so remove it.
if (remove_duplicate.count(var) == 0) {
update_var_min_rw_op(op2dependences, &var2min_rw_op, op_idx, var);
}
}
}
}
return std::move(get_downstream_map(op2dependences));
}
......
......@@ -71,8 +71,8 @@ void AdamDenseKernel(const Context& dev_ctx,
phi::Copy(dev_ctx, param, dev_ctx.GetPlace(), false, param_out);
phi::Copy(dev_ctx, moment1, dev_ctx.GetPlace(), false, moment1_out);
phi::Copy(dev_ctx, moment2, dev_ctx.GetPlace(), false, moment2_out);
phi::Copy(dev_ctx, beta1_pow, dev_ctx.GetPlace(), false, beta1_pow_out);
phi::Copy(dev_ctx, beta2_pow, dev_ctx.GetPlace(), false, beta2_pow_out);
phi::Copy(dev_ctx, beta1_pow, beta1_pow.place(), false, beta1_pow_out);
phi::Copy(dev_ctx, beta2_pow, beta2_pow.place(), false, beta2_pow_out);
return;
}
......
......@@ -172,8 +172,8 @@ void AdamDenseKernel(const Context& dev_ctx,
phi::Copy(dev_ctx, param, dev_ctx.GetPlace(), false, param_out);
phi::Copy(dev_ctx, moment1, dev_ctx.GetPlace(), false, moment1_out);
phi::Copy(dev_ctx, moment2, dev_ctx.GetPlace(), false, moment2_out);
phi::Copy(dev_ctx, beta1_pow, dev_ctx.GetPlace(), false, beta1_pow_out);
phi::Copy(dev_ctx, beta2_pow, dev_ctx.GetPlace(), false, beta2_pow_out);
phi::Copy(dev_ctx, beta1_pow, beta1_pow.place(), false, beta1_pow_out);
phi::Copy(dev_ctx, beta2_pow, beta2_pow.place(), false, beta2_pow_out);
return;
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册