未验证 提交 54497c47 编写于 作者: R Ruibiao Chen 提交者: GitHub

Skip device transfer when arg-defs is set to Allbackend (#52294)

上级 41f0e3c3
......@@ -29,9 +29,8 @@ namespace paddle {
namespace framework {
namespace interpreter {
bool DataTranferHelper::apply(
const phi::KernelKey& kernel_type_for_var,
const framework::OpKernelType& expected_kernel_key,
bool DataTranferHelper::apply(const phi::KernelKey& kernel_type_for_var,
const phi::KernelKey& expected_kernel_key,
const phi::DenseTensor* tensor,
const std::string& var_name,
std::string* new_var_name,
......@@ -43,13 +42,11 @@ bool DataTranferHelper::apply(
auto* src_var_name = &var_name;
// 1. layout transform
if (need_layout_transform(
kernel_type_for_var,
TransOpKernelTypeToPhiKernelKey(expected_kernel_key))) {
if (need_layout_transform(kernel_type_for_var, expected_kernel_key)) {
auto op = TransferLayout(*src_var_name,
new_var_name,
kernel_type_for_var.layout(),
expected_kernel_key.data_layout_,
expected_kernel_key.layout(),
var_scope_,
scope_,
is_fetch_v2);
......@@ -61,15 +58,14 @@ bool DataTranferHelper::apply(
src_var_name = new_var_name;
is_transferred = true;
}
// 2. dype transform
if (need_dtype_transform(
kernel_type_for_var,
TransOpKernelTypeToPhiKernelKey(expected_kernel_key))) {
if (need_dtype_transform(kernel_type_for_var, expected_kernel_key)) {
auto op = TransferDtype(
*src_var_name,
new_var_name,
framework::TransToProtoVarType(kernel_type_for_var.dtype()),
expected_kernel_key.data_type_,
framework::TransToProtoVarType(expected_kernel_key.dtype()),
var_scope_,
scope_);
if (op) {
......@@ -80,11 +76,12 @@ bool DataTranferHelper::apply(
src_var_name = new_var_name;
is_transferred = true;
}
// 3. device transform
if (need_device_transform(
kernel_type_for_var, tensor, expected_kernel_key.place_)) {
phi::Backend expected_backend = expected_kernel_key.backend();
if (need_device_transform(kernel_type_for_var, tensor, expected_backend)) {
auto src_place = tensor->place();
auto dst_place = expected_kernel_key.place_;
auto dst_place = phi::TransToPhiPlace(expected_backend);
auto op = TransferDevice(
*src_var_name, new_var_name, src_place, dst_place, var_scope_, scope_);
......@@ -575,8 +572,7 @@ void ApplyDataTransform(const OpKernelType& expected_kernel_key,
}
std::unique_ptr<phi::KernelKey>
expected_kernel_key_for_argument_def = nullptr;
if (argument_def &&
argument_def->backend != phi::Backend::ALL_BACKEND) {
if (argument_def) {
const phi::Backend& tensor_backend =
phi::TransToPhiBackend(tensor_in->place());
const phi::Backend& def_backend = argument_def->backend;
......@@ -607,9 +603,8 @@ void ApplyDataTransform(const OpKernelType& expected_kernel_key,
is_transferred = data_transfer_helper.apply(
kernel_key_for_var,
(expected_kernel_key_for_argument_def
? TransPhiKernelKeyToOpKernelType(
*expected_kernel_key_for_argument_def.get())
: expected_kernel_key),
? *expected_kernel_key_for_argument_def.get()
: TransOpKernelTypeToPhiKernelKey(expected_kernel_key)),
tensor_in,
var_name,
&new_var_name,
......
......@@ -35,7 +35,7 @@ class DataTranferHelper {
: place_(place), var_scope_(var_scope), scope_(local_scope) {}
bool apply(const phi::KernelKey& kernel_type_for_var,
const framework::OpKernelType& expected_kernel_key,
const phi::KernelKey& expected_kernel_key,
const phi::DenseTensor* tensor,
const std::string& var_name,
std::string* new_var_name,
......@@ -82,9 +82,14 @@ void HandleComplexGradToRealGrad(const OpFuncNode& op_func_node,
inline bool need_device_transform(const phi::KernelKey& kernel_type_for_var,
const phi::DenseTensor* tensor,
const phi::Place& expected_place) {
const phi::Backend& expected_backend) {
if (kernel_type_for_var.backend() == phi::Backend::ALL_BACKEND ||
platform::is_same_place(tensor->place(), expected_place) ||
expected_backend == phi::Backend::ALL_BACKEND) {
return false;
}
phi::Place expected_place = phi::TransToPhiPlace(expected_backend);
if (platform::is_same_place(tensor->place(), expected_place) ||
(platform::is_cuda_pinned_place(tensor->place()) &&
platform::is_cpu_place(expected_place))) {
return false;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册