Unverified · Commit 54497c47 authored by Ruibiao Chen, committed by GitHub

Skip device transfer when arg-defs is set to Allbackend (#52294)

Parent 41f0e3c3
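In short: need_device_transform now receives the expected phi::Backend instead of a concrete phi::Place, and returns false early when either the variable's kernel key or the kernel's argument definition declares phi::Backend::ALL_BACKEND; only after that early-out is the backend resolved to a place. A minimal self-contained sketch of the new decision rule (toy enum and simplified place check, not the actual phi types):

    #include <cassert>

    // Toy stand-in for phi::Backend; illustration only, not the Paddle API.
    enum class Backend { ALL_BACKEND, CPU, GPU };

    // Sketch of the rewritten predicate: ALL_BACKEND on either side means any
    // device is acceptable, so no transfer is needed and the backend is never
    // resolved to a concrete place.
    bool need_device_transform(Backend var_backend,
                               Backend tensor_backend,
                               Backend expected_backend) {
      if (var_backend == Backend::ALL_BACKEND ||
          expected_backend == Backend::ALL_BACKEND) {
        return false;  // the early-out this commit adds
      }
      return tensor_backend != expected_backend;  // simplified place comparison
    }

    int main() {
      // A GPU tensor feeding an ALL_BACKEND argument definition no longer
      // triggers a device copy.
      assert(!need_device_transform(Backend::GPU, Backend::GPU,
                                    Backend::ALL_BACKEND));
      // A genuine backend mismatch still does.
      assert(need_device_transform(Backend::GPU, Backend::CPU, Backend::GPU));
      return 0;
    }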
@@ -29,9 +29,8 @@ namespace paddle {
 namespace framework {
 namespace interpreter {

-bool DataTranferHelper::apply(
-    const phi::KernelKey& kernel_type_for_var,
-    const framework::OpKernelType& expected_kernel_key,
+bool DataTranferHelper::apply(const phi::KernelKey& kernel_type_for_var,
+                              const phi::KernelKey& expected_kernel_key,
     const phi::DenseTensor* tensor,
     const std::string& var_name,
     std::string* new_var_name,
@@ -43,13 +42,11 @@ bool DataTranferHelper::apply(
   auto* src_var_name = &var_name;

   // 1. layout transform
-  if (need_layout_transform(
-          kernel_type_for_var,
-          TransOpKernelTypeToPhiKernelKey(expected_kernel_key))) {
+  if (need_layout_transform(kernel_type_for_var, expected_kernel_key)) {
     auto op = TransferLayout(*src_var_name,
                              new_var_name,
                              kernel_type_for_var.layout(),
-                             expected_kernel_key.data_layout_,
+                             expected_kernel_key.layout(),
                              var_scope_,
                              scope_,
                              is_fetch_v2);
@@ -61,15 +58,14 @@ bool DataTranferHelper::apply(
     src_var_name = new_var_name;
     is_transferred = true;
   }

   // 2. dype transform
-  if (need_dtype_transform(
-          kernel_type_for_var,
-          TransOpKernelTypeToPhiKernelKey(expected_kernel_key))) {
+  if (need_dtype_transform(kernel_type_for_var, expected_kernel_key)) {
     auto op = TransferDtype(
         *src_var_name,
         new_var_name,
         framework::TransToProtoVarType(kernel_type_for_var.dtype()),
-        expected_kernel_key.data_type_,
+        framework::TransToProtoVarType(expected_kernel_key.dtype()),
         var_scope_,
         scope_);
     if (op) {
@@ -80,11 +76,12 @@ bool DataTranferHelper::apply(
     src_var_name = new_var_name;
     is_transferred = true;
   }

   // 3. device transform
-  if (need_device_transform(
-          kernel_type_for_var, tensor, expected_kernel_key.place_)) {
+  phi::Backend expected_backend = expected_kernel_key.backend();
+  if (need_device_transform(kernel_type_for_var, tensor, expected_backend)) {
     auto src_place = tensor->place();
-    auto dst_place = expected_kernel_key.place_;
+    auto dst_place = phi::TransToPhiPlace(expected_backend);
     auto op = TransferDevice(
         *src_var_name, new_var_name, src_place, dst_place, var_scope_, scope_);
@@ -575,8 +572,7 @@ void ApplyDataTransform(const OpKernelType& expected_kernel_key,
       }
       std::unique_ptr<phi::KernelKey>
           expected_kernel_key_for_argument_def = nullptr;
-      if (argument_def &&
-          argument_def->backend != phi::Backend::ALL_BACKEND) {
+      if (argument_def) {
         const phi::Backend& tensor_backend =
             phi::TransToPhiBackend(tensor_in->place());
         const phi::Backend& def_backend = argument_def->backend;
@@ -607,9 +603,8 @@ void ApplyDataTransform(const OpKernelType& expected_kernel_key,
         is_transferred = data_transfer_helper.apply(
             kernel_key_for_var,
             (expected_kernel_key_for_argument_def
-                 ? TransPhiKernelKeyToOpKernelType(
-                       *expected_kernel_key_for_argument_def.get())
-                 : expected_kernel_key),
+                 ? *expected_kernel_key_for_argument_def.get()
+                 : TransOpKernelTypeToPhiKernelKey(expected_kernel_key)),
             tensor_in,
             var_name,
             &new_var_name,
......
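A side effect of apply() now taking a phi::KernelKey is visible in the ternary above: the conversion runs forward (TransOpKernelTypeToPhiKernelKey on the legacy fluid key) instead of backward (TransPhiKernelKeyToOpKernelType on the freshly built phi key), so the phi key computed for the argument definition is passed through untouched. A compilable toy of that convert-once-at-the-boundary pattern (all types below are hypothetical stand-ins, not the Paddle classes):

    #include <iostream>
    #include <optional>
    #include <string>

    // Hypothetical stand-ins for the legacy fluid and phi kernel descriptors.
    struct OpKernelType { std::string layout; };
    struct KernelKey { std::string layout; };

    // Forward conversion only, in the spirit of TransOpKernelTypeToPhiKernelKey.
    KernelKey ToPhi(const OpKernelType& t) { return KernelKey{t.layout}; }

    void Apply(const KernelKey& expected) { std::cout << expected.layout << '\n'; }

    int main() {
      std::optional<KernelKey> key_for_argument_def;  // set only when an arg-def applies
      OpKernelType legacy{"NCHW"};
      // Mirrors the new call site: pass the phi key through untouched when it
      // exists, otherwise convert the legacy key forward exactly once.
      Apply(key_for_argument_def ? *key_for_argument_def : ToPhi(legacy));
      return 0;
    }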
@@ -35,7 +35,7 @@ class DataTranferHelper {
       : place_(place), var_scope_(var_scope), scope_(local_scope) {}

   bool apply(const phi::KernelKey& kernel_type_for_var,
-             const framework::OpKernelType& expected_kernel_key,
+             const phi::KernelKey& expected_kernel_key,
              const phi::DenseTensor* tensor,
              const std::string& var_name,
              std::string* new_var_name,
@@ -82,9 +82,14 @@ void HandleComplexGradToRealGrad(const OpFuncNode& op_func_node,
 inline bool need_device_transform(const phi::KernelKey& kernel_type_for_var,
                                   const phi::DenseTensor* tensor,
-                                  const phi::Place& expected_place) {
+                                  const phi::Backend& expected_backend) {
   if (kernel_type_for_var.backend() == phi::Backend::ALL_BACKEND ||
-      platform::is_same_place(tensor->place(), expected_place) ||
+      expected_backend == phi::Backend::ALL_BACKEND) {
+    return false;
+  }
+
+  phi::Place expected_place = phi::TransToPhiPlace(expected_backend);
+  if (platform::is_same_place(tensor->place(), expected_place) ||
       (platform::is_cuda_pinned_place(tensor->place()) &&
        platform::is_cpu_place(expected_place))) {
     return false;
   }
......
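Extending the sketch near the top of this commit with the remaining branches visible in the header diff: the rewritten predicate skips the transfer when either backend is ALL_BACKEND, when the tensor already lives on the expected place, or when a CUDA-pinned tensor feeds a CPU kernel. Note the ordering: TransToPhiPlace is only reached once the ALL_BACKEND cases are ruled out, so it is never asked to resolve a non-concrete backend. A self-contained toy version (the real code compares phi::Place values via platform::is_same_place, is_cuda_pinned_place, and is_cpu_place):

    #include <cassert>

    enum class Backend { ALL_BACKEND, CPU, GPU };  // toy phi::Backend
    enum class Place { CPU, GPU, CUDA_PINNED };    // toy phi::Place

    Place ToPlace(Backend b) {                     // toy TransToPhiPlace
      return b == Backend::CPU ? Place::CPU : Place::GPU;
    }

    bool need_device_transform(Backend var_backend,
                               Place tensor_place,
                               Backend expected_backend) {
      // 1. Either side accepts any device: skip (the change in this commit).
      if (var_backend == Backend::ALL_BACKEND ||
          expected_backend == Backend::ALL_BACKEND) {
        return false;
      }
      // 2. Only now resolve the backend to a place and compare.
      Place expected_place = ToPlace(expected_backend);
      if (tensor_place == expected_place ||
          (tensor_place == Place::CUDA_PINNED && expected_place == Place::CPU)) {
        return false;  // already resident, or pinned host memory a CPU kernel can read
      }
      return true;
    }

    int main() {
      assert(!need_device_transform(Backend::GPU, Place::GPU, Backend::ALL_BACKEND));
      assert(!need_device_transform(Backend::GPU, Place::CUDA_PINNED, Backend::CPU));
      assert(need_device_transform(Backend::GPU, Place::CPU, Backend::GPU));
      return 0;
    }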