From 8775545a7dd4e54d46ad5ba48db6d31e0aea0dd2 Mon Sep 17 00:00:00 2001 From: zyfncg Date: Thu, 27 Oct 2022 19:06:17 +0800 Subject: [PATCH] support prepare_data for selected_rows in c++ api (#47380) --- paddle/phi/api/lib/data_transform.cc | 46 +++++++++++++++++++++-- paddle/phi/api/lib/data_transform.h | 11 ++++++ paddle/phi/api/yaml/generator/api_base.py | 2 +- 3 files changed, 54 insertions(+), 5 deletions(-) diff --git a/paddle/phi/api/lib/data_transform.cc b/paddle/phi/api/lib/data_transform.cc index 048a24ff5e3..c6a773ebe5f 100644 --- a/paddle/phi/api/lib/data_transform.cc +++ b/paddle/phi/api/lib/data_transform.cc @@ -169,10 +169,6 @@ inline phi::DenseTensor TransDataPlace(const phi::DenseTensor& tensor, VLOG(3) << "DeviceTransform in, src_place " << tensor.place() << " dst_place: " << dst_place; - DefaultAllocator alloc(dst_place); - phi::DenseTensor out(&alloc, - {tensor.dtype(), tensor.dims(), tensor.layout()}); - #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) auto& pool = paddle::platform::DeviceContextPool::Instance(); // NOTE(yy): TransDataPlace should wait for computation of input. @@ -191,6 +187,7 @@ inline phi::DenseTensor TransDataPlace(const phi::DenseTensor& tensor, // the transforming is from CPU to GPU and the number of elements is little. // But the embarrassment is that this solution this solution makes training // slower. + phi::DenseTensor out; paddle::framework::TensorCopySync(tensor, dst_place, &out); return out; } @@ -305,6 +302,47 @@ paddle::optional> PrepareData( return paddle::none; } +std::shared_ptr PrepareDataForSelectedRows( + const Tensor& input, + const phi::TensorArgDef& target_args_def, + const TransformFlag& transform_flag) { + const auto& tensor_in = input.impl(); + if (tensor_in) { + phi::SelectedRows& selected_rows = + *static_cast(tensor_in.get()); + if (!transform_flag.NeedTransform() || !selected_rows.initialized() || + (!NeedTransformPlace( + selected_rows.place(), target_args_def.backend, transform_flag))) { + return std::static_pointer_cast(tensor_in); + } + + auto dense_out = TransDataPlace( + selected_rows.value(), phi::TransToPhiPlace(target_args_def.backend)); + if (selected_rows.place().GetType() == AllocationType::GPUPINNED) { + selected_rows.mutable_value()->ShareBufferWith(dense_out); + return std::static_pointer_cast(tensor_in); + } + + auto out_new = std::make_shared(selected_rows.rows(), + selected_rows.height()); + *out_new->mutable_value() = dense_out; + return out_new; + } + PADDLE_THROW(phi::errors::InvalidArgument( + "The impl() of input tensor is nullptr, it doesn't support for " + "selected_rows data transform now.")); +} + +paddle::optional PrepareDataForSelectedRows( + const paddle::optional& input, + const phi::TensorArgDef& target_args_def, + const TransformFlag& transform_flag) { + if (input) { + return *PrepareDataForSelectedRows(*input, target_args_def, transform_flag); + } + return paddle::none; +} + void TransDataBackend(const phi::DenseTensor* tensor, Backend target_backend, phi::DenseTensor* out) { diff --git a/paddle/phi/api/lib/data_transform.h b/paddle/phi/api/lib/data_transform.h index 7695855e30b..7a97bb01f61 100644 --- a/paddle/phi/api/lib/data_transform.h +++ b/paddle/phi/api/lib/data_transform.h @@ -82,6 +82,17 @@ paddle::optional> PrepareData( const phi::TensorArgDef& target_args_def, const TransformFlag& transform_flag); +// Only support transfering place for SelectedRows +std::shared_ptr PrepareDataForSelectedRows( + const Tensor& input, + const phi::TensorArgDef& target_args_def, + const TransformFlag& transform_flag); + +paddle::optional PrepareDataForSelectedRows( + const paddle::optional& input, + const phi::TensorArgDef& target_args_def, + const TransformFlag& transform_flag); + void TransDataBackend(const phi::DenseTensor* tensor, Backend target_backend, phi::DenseTensor* out); diff --git a/paddle/phi/api/yaml/generator/api_base.py b/paddle/phi/api/yaml/generator/api_base.py index e67023d2faf..53b950b63f0 100644 --- a/paddle/phi/api/yaml/generator/api_base.py +++ b/paddle/phi/api/yaml/generator/api_base.py @@ -715,7 +715,7 @@ PADDLE_API {self.get_return_type(inplace_flag=True)} {api_func_name}({self.get_d input_tensor_code = ( input_tensor_code + f""" -{code_indent} auto {PREFIX_TENSOR_NAME}{input_name} = TensorToSelectedRows({input_name}); +{code_indent} auto {PREFIX_TENSOR_NAME}{input_name} = PrepareDataForSelectedRows({input_name}, kernel.InputAt({kernel_param.index(input_name)}), {trans_flag}); """ ) return input_tensor_code -- GitLab