未验证 提交 8775545a 编写于 作者: Z zyfncg 提交者: GitHub

support prepare_data for selected_rows in c++ api (#47380)

上级 2096448b
......@@ -169,10 +169,6 @@ inline phi::DenseTensor TransDataPlace(const phi::DenseTensor& tensor,
VLOG(3) << "DeviceTransform in, src_place " << tensor.place()
<< " dst_place: " << dst_place;
DefaultAllocator alloc(dst_place);
phi::DenseTensor out(&alloc,
{tensor.dtype(), tensor.dims(), tensor.layout()});
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
auto& pool = paddle::platform::DeviceContextPool::Instance();
// NOTE(yy): TransDataPlace should wait for computation of input.
......@@ -191,6 +187,7 @@ inline phi::DenseTensor TransDataPlace(const phi::DenseTensor& tensor,
// the transforming is from CPU to GPU and the number of elements is little.
// But the embarrassment is that this solution makes training
// slower.
phi::DenseTensor out;
paddle::framework::TensorCopySync(tensor, dst_place, &out);
return out;
}
......@@ -305,6 +302,47 @@ paddle::optional<std::vector<phi::DenseTensor>> PrepareData(
return paddle::none;
}
// Prepares a SelectedRows input for a kernel by moving its value to the place
// implied by `target_args_def.backend`. Only the place is transformed.
//
// - Throws InvalidArgument when `input.impl()` is null.
// - Returns the input's own SelectedRows when no transform is requested, the
//   SelectedRows is not initialized, or no place transform is needed.
// - For a GPUPINNED source, the input's value is re-pointed at the transferred
//   buffer and the input's own SelectedRows is returned.
// - Otherwise a new SelectedRows with the same rows and height is returned,
//   holding the transferred value.
std::shared_ptr<phi::SelectedRows> PrepareDataForSelectedRows(
    const Tensor& input,
    const phi::TensorArgDef& target_args_def,
    const TransformFlag& transform_flag) {
  const auto& impl = input.impl();
  if (!impl) {
    PADDLE_THROW(phi::errors::InvalidArgument(
        "The impl() of input tensor is nullptr, it doesn't support for "
        "selected_rows data transform now."));
  }

  auto* src_rows = static_cast<phi::SelectedRows*>(impl.get());

  // Evaluated left to right with short-circuiting, so an uninitialized
  // SelectedRows is never asked for its place.
  const bool must_move =
      transform_flag.NeedTransform() && src_rows->initialized() &&
      NeedTransformPlace(
          src_rows->place(), target_args_def.backend, transform_flag);
  if (!must_move) {
    return std::static_pointer_cast<phi::SelectedRows>(impl);
  }

  auto moved_value = TransDataPlace(
      src_rows->value(), phi::TransToPhiPlace(target_args_def.backend));

  // Pinned-memory inputs are updated in place instead of being replaced.
  if (src_rows->place().GetType() == AllocationType::GPUPINNED) {
    src_rows->mutable_value()->ShareBufferWith(moved_value);
    return std::static_pointer_cast<phi::SelectedRows>(impl);
  }

  auto result = std::make_shared<phi::SelectedRows>(src_rows->rows(),
                                                    src_rows->height());
  *result->mutable_value() = moved_value;
  return result;
}
// Optional-input overload: yields paddle::none when `input` holds no Tensor;
// otherwise forwards to the Tensor overload and returns a copy of the
// SelectedRows it points to.
paddle::optional<phi::SelectedRows> PrepareDataForSelectedRows(
    const paddle::optional<Tensor>& input,
    const phi::TensorArgDef& target_args_def,
    const TransformFlag& transform_flag) {
  if (!input) {
    return paddle::none;
  }
  const auto prepared =
      PrepareDataForSelectedRows(*input, target_args_def, transform_flag);
  return *prepared;
}
void TransDataBackend(const phi::DenseTensor* tensor,
Backend target_backend,
phi::DenseTensor* out) {
......
......@@ -82,6 +82,17 @@ paddle::optional<std::vector<phi::DenseTensor>> PrepareData(
const phi::TensorArgDef& target_args_def,
const TransformFlag& transform_flag);
// Prepares a SelectedRows input for a kernel argument.
// Only transferring the place is supported for SelectedRows.
//
// input:           Tensor whose impl() is treated as a phi::SelectedRows.
// target_args_def: kernel argument definition; its `backend` selects the
//                  destination place.
// transform_flag:  controls whether any transform is performed at all.
//
// Returns the input's own SelectedRows when no place transform is needed,
// otherwise a SelectedRows whose value lives on the target place.
// Throws InvalidArgument when input.impl() is null.
std::shared_ptr<phi::SelectedRows> PrepareDataForSelectedRows(
const Tensor& input,
const phi::TensorArgDef& target_args_def,
const TransformFlag& transform_flag);
// Optional-input overload of PrepareDataForSelectedRows.
// Returns paddle::none when `input` is empty; otherwise returns a copy of the
// SelectedRows produced by the Tensor overload for the contained Tensor.
paddle::optional<phi::SelectedRows> PrepareDataForSelectedRows(
const paddle::optional<Tensor>& input,
const phi::TensorArgDef& target_args_def,
const TransformFlag& transform_flag);
// NOTE(review): the definition body is not visible here. Presumably transfers
// `tensor` to `target_backend` and writes the result into `out` — confirm
// against the definition.
void TransDataBackend(const phi::DenseTensor* tensor,
Backend target_backend,
phi::DenseTensor* out);
......
......@@ -715,7 +715,7 @@ PADDLE_API {self.get_return_type(inplace_flag=True)} {api_func_name}({self.get_d
input_tensor_code = (
input_tensor_code
+ f"""
{code_indent} auto {PREFIX_TENSOR_NAME}{input_name} = TensorToSelectedRows({input_name});
{code_indent} auto {PREFIX_TENSOR_NAME}{input_name} = PrepareDataForSelectedRows({input_name}, kernel.InputAt({kernel_param.index(input_name)}), {trans_flag});
"""
)
return input_tensor_code
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册