diff --git a/paddle/fluid/framework/new_executor/interpreter/data_transfer.cc b/paddle/fluid/framework/new_executor/interpreter/data_transfer.cc index 9e494eb28dd38dbf2f92709faec2630f16014cbe..7024a57bb84b12e5210bfe694eee75a734850449 100644 --- a/paddle/fluid/framework/new_executor/interpreter/data_transfer.cc +++ b/paddle/fluid/framework/new_executor/interpreter/data_transfer.cc @@ -29,27 +29,24 @@ namespace paddle { namespace framework { namespace interpreter { -bool DataTranferHelper::apply( - const phi::KernelKey& kernel_type_for_var, - const framework::OpKernelType& expected_kernel_key, - const phi::DenseTensor* tensor, - const std::string& var_name, - std::string* new_var_name, - std::vector<OpFuncNode>* op_func_nodes, - bool use_local_scope, - bool is_fetch_v2, - bool skip_run) { +bool DataTranferHelper::apply(const phi::KernelKey& kernel_type_for_var, + const phi::KernelKey& expected_kernel_key, + const phi::DenseTensor* tensor, + const std::string& var_name, + std::string* new_var_name, + std::vector<OpFuncNode>* op_func_nodes, + bool use_local_scope, + bool is_fetch_v2, + bool skip_run) { bool is_transferred = false; auto* src_var_name = &var_name; // 1. layout transform - if (need_layout_transform( - kernel_type_for_var, - TransOpKernelTypeToPhiKernelKey(expected_kernel_key))) { + if (need_layout_transform(kernel_type_for_var, expected_kernel_key)) { auto op = TransferLayout(*src_var_name, new_var_name, kernel_type_for_var.layout(), - expected_kernel_key.data_layout_, + expected_kernel_key.layout(), var_scope_, scope_, is_fetch_v2); @@ -61,15 +58,14 @@ bool DataTranferHelper::apply( src_var_name = new_var_name; is_transferred = true; } + // 2. 
dtype transform - if (need_dtype_transform( - kernel_type_for_var, - TransOpKernelTypeToPhiKernelKey(expected_kernel_key))) { + if (need_dtype_transform(kernel_type_for_var, expected_kernel_key)) { auto op = TransferDtype( *src_var_name, new_var_name, framework::TransToProtoVarType(kernel_type_for_var.dtype()), - expected_kernel_key.data_type_, + framework::TransToProtoVarType(expected_kernel_key.dtype()), var_scope_, scope_); if (op) { @@ -80,11 +76,12 @@ bool DataTranferHelper::apply( src_var_name = new_var_name; is_transferred = true; } + // 3. device transform - if (need_device_transform( - kernel_type_for_var, tensor, expected_kernel_key.place_)) { + phi::Backend expected_backend = expected_kernel_key.backend(); + if (need_device_transform(kernel_type_for_var, tensor, expected_backend)) { auto src_place = tensor->place(); - auto dst_place = expected_kernel_key.place_; + auto dst_place = phi::TransToPhiPlace(expected_backend); auto op = TransferDevice( *src_var_name, new_var_name, src_place, dst_place, var_scope_, scope_); @@ -575,8 +572,7 @@ void ApplyDataTransform(const OpKernelType& expected_kernel_key, } std::unique_ptr<phi::KernelKey> expected_kernel_key_for_argument_def = nullptr; - if (argument_def && - argument_def->backend != phi::Backend::ALL_BACKEND) { + if (argument_def) { const phi::Backend& tensor_backend = phi::TransToPhiBackend(tensor_in->place()); const phi::Backend& def_backend = argument_def->backend; @@ -607,9 +603,8 @@ void ApplyDataTransform(const OpKernelType& expected_kernel_key, is_transferred = data_transfer_helper.apply( kernel_key_for_var, (expected_kernel_key_for_argument_def - ? TransPhiKernelKeyToOpKernelType( - *expected_kernel_key_for_argument_def.get()) - : expected_kernel_key), + ? 
*expected_kernel_key_for_argument_def.get() + : TransOpKernelTypeToPhiKernelKey(expected_kernel_key)), tensor_in, var_name, &new_var_name, diff --git a/paddle/fluid/framework/new_executor/interpreter/data_transfer.h b/paddle/fluid/framework/new_executor/interpreter/data_transfer.h index 616f625075262cab722d2d96afc11c5109b11edf..19089b11942f784964edf18049d05e883f2f82b4 100644 --- a/paddle/fluid/framework/new_executor/interpreter/data_transfer.h +++ b/paddle/fluid/framework/new_executor/interpreter/data_transfer.h @@ -35,7 +35,7 @@ class DataTranferHelper { : place_(place), var_scope_(var_scope), scope_(local_scope) {} bool apply(const phi::KernelKey& kernel_type_for_var, - const framework::OpKernelType& expected_kernel_key, + const phi::KernelKey& expected_kernel_key, const phi::DenseTensor* tensor, const std::string& var_name, std::string* new_var_name, @@ -82,9 +82,14 @@ void HandleComplexGradToRealGrad(const OpFuncNode& op_func_node, inline bool need_device_transform(const phi::KernelKey& kernel_type_for_var, const phi::DenseTensor* tensor, - const phi::Place& expected_place) { + const phi::Backend& expected_backend) { if (kernel_type_for_var.backend() == phi::Backend::ALL_BACKEND || - platform::is_same_place(tensor->place(), expected_place) || + expected_backend == phi::Backend::ALL_BACKEND) { + return false; + } + + phi::Place expected_place = phi::TransToPhiPlace(expected_backend); + if (platform::is_same_place(tensor->place(), expected_place) || (platform::is_cuda_pinned_place(tensor->place()) && platform::is_cpu_place(expected_place))) { return false;