From cc2059a0d56b5bd146680a0f5d9a8fa956592edf Mon Sep 17 00:00:00 2001
From: jiangfan06 <117341294+MuShangCC@users.noreply.github.com>
Date: Mon, 3 Jul 2023 19:07:33 +0800
Subject: [PATCH] [XPU] Fix the topk and set_value ops that use temporary
 tensors, avoiding memory overlaps during multi-stream inference (#54851)

---
 paddle/fluid/framework/operator.cc            |   8 +-
 paddle/phi/backends/xpu/xpu2_op_list.cc       |   3 +-
 paddle/phi/kernels/xpu/elementwise.h          |  38 +++--
 .../phi/kernels/xpu/instance_norm_kernel.cc   |  44 ++++--
 paddle/phi/kernels/xpu/scatter_kernel.cc      |   2 +-
 paddle/phi/kernels/xpu/set_value_kernel.cc    | 145 ++++++++++++------
 paddle/phi/kernels/xpu/stride_slice_kernel.cc |   1 +
 paddle/phi/kernels/xpu/top_k_kernel.cc        |  15 +-
 8 files changed, 174 insertions(+), 82 deletions(-)

diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc
index 1ca558a8d89..74e4b04d053 100644
--- a/paddle/fluid/framework/operator.cc
+++ b/paddle/fluid/framework/operator.cc
@@ -2621,7 +2621,9 @@ Scope* OperatorWithKernel::PrepareData(
       if (kernel_type_for_var.backend() == phi::Backend::GPU ||
           kernel_type_for_var.backend() == phi::Backend::GPUDNN ||
           new_expected_kernel_key->backend() == phi::Backend::GPU ||
-          new_expected_kernel_key->backend() == phi::Backend::GPUDNN) {
+          new_expected_kernel_key->backend() == phi::Backend::GPUDNN ||
+          kernel_type_for_var.backend() == phi::Backend::XPU ||
+          new_expected_kernel_key->backend() == phi::Backend::XPU) {
         new_scope = TryCreateTransferScope(
             kernel_type_for_var, *new_expected_kernel_key, &scope);
         enable_cache_transfer_scope_ = true;
@@ -2629,7 +2631,9 @@ Scope* OperatorWithKernel::PrepareData(
       } else if (kernel_type_for_var.backend() == phi::Backend::GPU ||
                  kernel_type_for_var.backend() == phi::Backend::GPUDNN ||
                  expected_kernel_key.backend() == phi::Backend::GPU ||
-                 expected_kernel_key.backend() == phi::Backend::GPUDNN) {
+                 expected_kernel_key.backend() == phi::Backend::GPUDNN ||
+                 kernel_type_for_var.backend() == phi::Backend::XPU ||
+                 expected_kernel_key.backend() == phi::Backend::XPU) {
         new_scope = TryCreateTransferScope(
             kernel_type_for_var, expected_kernel_key, &scope);
         enable_cache_transfer_scope_ = true;
diff --git a/paddle/phi/backends/xpu/xpu2_op_list.cc b/paddle/phi/backends/xpu/xpu2_op_list.cc
index 63d7a1b3ced..cbd495135db 100644
--- a/paddle/phi/backends/xpu/xpu2_op_list.cc
+++ b/paddle/phi/backends/xpu/xpu2_op_list.cc
@@ -780,7 +780,8 @@ XPUOpMap& get_kl2_ops() {
      XPUKernelSet({phi::DataType::FLOAT32,
                    phi::DataType::FLOAT16,
                    phi::DataType::INT16,
-                   phi::DataType::INT32})},
+                   phi::DataType::INT32,
+                   phi::DataType::INT64})},
     {"strided_slice_grad",
      XPUKernelSet({phi::DataType::FLOAT32,
                    phi::DataType::FLOAT16,
diff --git a/paddle/phi/kernels/xpu/elementwise.h b/paddle/phi/kernels/xpu/elementwise.h
index 3af7a034069..efc62bc3ffc 100644
--- a/paddle/phi/kernels/xpu/elementwise.h
+++ b/paddle/phi/kernels/xpu/elementwise.h
@@ -29,19 +29,18 @@ namespace phi {
 
 template <typename T, typename XPUType>
 void XPUElementwise(const XPUContext& dev_ctx,
-                    const DenseTensor& x,
-                    const DenseTensor& y,
+                    const T* x_data,
+                    const DDim& x_dims,
+                    const T* y_data,
+                    const DDim& y_dims,
                     int axis,
-                    DenseTensor* z,
+                    T* z_data,
                     std::function<int(xpu::Context*,
                                       const XPUType*,
                                       const XPUType*,
                                       XPUType*,
                                       const std::vector<int>&,
                                       const std::vector<int>&)> func) {
-  dev_ctx.template Alloc<T>(z);
-  auto x_dims = x.dims();
-  auto y_dims = y.dims();
   int max_dim = std::max(x_dims.size(), y_dims.size());
   axis = (axis == -1 ? std::abs(x_dims.size() - y_dims.size()) : axis);
@@ -78,9 +77,6 @@ void XPUElementwise(const XPUContext& dev_ctx,
       y_dims_vec[i + axis] = y_dims[i];
     }
   }
-  const T* x_data = x.data<T>();
-  const T* y_data = y.data<T>();
-  T* z_data = z->data<T>();
 
   int ret = xpu::SUCCESS;
 
@@ -104,6 +100,30 @@ void XPUElementwise(const XPUContext& dev_ctx,
   PADDLE_ENFORCE_XDNN_SUCCESS(ret, "elementwise");
 }
 
+template <typename T, typename XPUType>
+void XPUElementwise(const XPUContext& dev_ctx,
+                    const DenseTensor& x,
+                    const DenseTensor& y,
+                    int axis,
+                    DenseTensor* z,
+                    std::function<int(xpu::Context*,
+                                      const XPUType*,
+                                      const XPUType*,
+                                      XPUType*,
+                                      const std::vector<int>&,
+                                      const std::vector<int>&)> func) {
+  dev_ctx.template Alloc<T>(z);
+  const DDim& x_dims = x.dims();
+  const DDim& y_dims = y.dims();
+
+  const T* x_data = x.data<T>();
+  const T* y_data = y.data<T>();
+  T* z_data = z->data<T>();
+
+  XPUElementwise<T, XPUType>(
+      dev_ctx, x_data, x_dims, y_data, y_dims, axis, z_data, func);
+}
+
 template <typename T, typename XPUType>
 void XPUElementwiseGrad(const XPUContext& dev_ctx,
                         const DenseTensor& x,
diff --git a/paddle/phi/kernels/xpu/instance_norm_kernel.cc b/paddle/phi/kernels/xpu/instance_norm_kernel.cc
index 1631d0ccbee..4302d6ed900 100644
--- a/paddle/phi/kernels/xpu/instance_norm_kernel.cc
+++ b/paddle/phi/kernels/xpu/instance_norm_kernel.cc
@@ -38,29 +38,53 @@ void InstanceNormKernel(const Context& dev_ctx,
   dev_ctx.template Alloc<T>(y);
   dev_ctx.template Alloc<float>(saved_mean);
   dev_ctx.template Alloc<float>(saved_var);
+
+  xpu::ctx_guard RAII_GUARD(dev_ctx.x_context());
+
   // scale
   const auto scale_ptr = scale.get_ptr();
   const float* scale_data_fp32 = nullptr;
-  DenseTensor scale_data;
   if (scale_ptr == nullptr) {
-    scale_data.Resize({c});
-    dev_ctx.template Alloc<float>(&scale_data);
-    phi::funcs::set_constant(dev_ctx, &scale_data, static_cast<float>(1));
-    scale_data_fp32 = scale_data.data<float>();
+    float* scale_data_temp = RAII_GUARD.alloc_l3_or_gm<float>(c);
+    int r = xpu::constant(dev_ctx.x_context(), scale_data_temp, c, 1.f);
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant");
+    scale_data_fp32 = scale_data_temp;
+  } else if (scale_ptr->dtype() ==
+             phi::CppTypeToDataType<T>::Type()) {
+    float* scale_data_temp =
+        RAII_GUARD.alloc_l3_or_gm<float>(scale_ptr->numel());
+    int r = xpu::cast<XPUType, float>(
+        dev_ctx.x_context(),
+        reinterpret_cast<const XPUType*>(scale_ptr->data<T>()),
+        scale_data_temp,
+        scale_ptr->numel());
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast");
+    scale_data_fp32 = scale_data_temp;
   } else {
     // no need to cast
     scale_data_fp32 = scale_ptr->data<float>();
   }
+
   // bias
   const float* bias_data_fp32 = nullptr;
   const auto* bias_ptr = bias.get_ptr();
-  DenseTensor bias_data;
   if (bias_ptr == nullptr) {
-    bias_data.Resize({c});
-    dev_ctx.template Alloc<float>(&bias_data);
-    phi::funcs::set_constant(dev_ctx, &bias_data, static_cast<float>(0));
-    bias_data_fp32 = bias_data.data<float>();
+    float* bias_data_temp = RAII_GUARD.alloc_l3_or_gm<float>(c);
+    int r = xpu::constant(dev_ctx.x_context(), bias_data_temp, c, 0.f);
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant");
+    bias_data_fp32 = bias_data_temp;
+  } else if (bias_ptr->dtype() ==
+             phi::CppTypeToDataType<T>::Type()) {
+    float* bias_data_temp = RAII_GUARD.alloc_l3_or_gm<float>(bias_ptr->numel());
+    int r = xpu::cast<XPUType, float>(
+        dev_ctx.x_context(),
+        reinterpret_cast<const XPUType*>(bias_ptr->data<T>()),
+        bias_data_temp,
+        bias_ptr->numel());
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast");
+    bias_data_fp32 = bias_data_temp;
   } else {
+    // no need to cast
     bias_data_fp32 = bias_ptr->data<float>();
   }
diff --git a/paddle/phi/kernels/xpu/scatter_kernel.cc b/paddle/phi/kernels/xpu/scatter_kernel.cc
index 4856c05ebf7..9052cd5b5f5 100644
--- a/paddle/phi/kernels/xpu/scatter_kernel.cc
+++ b/paddle/phi/kernels/xpu/scatter_kernel.cc
@@ -83,7 +83,7 @@ void ScatterKernel(const Context &ctx,
       static_cast<int>(
           phi::product(phi::slice_ddim(x_dims, 1, x_dims.size())));
   DenseTensor indices_cpu(index.type());
-  phi::Copy(ctx, index, phi::CPUPlace(), false, &indices_cpu);
+  phi::Copy(ctx, index, phi::CPUPlace(), true, &indices_cpu);
 
   int r = 0;
   if (index_type == phi::DataType::INT32) {
diff --git a/paddle/phi/kernels/xpu/set_value_kernel.cc b/paddle/phi/kernels/xpu/set_value_kernel.cc
index 3d372043379..0d984be8b2f 100644
--- a/paddle/phi/kernels/xpu/set_value_kernel.cc
+++ b/paddle/phi/kernels/xpu/set_value_kernel.cc
@@ -73,7 +73,8 @@ inline void CheckIsDimsMatch(const DDim& first, const DDim& second) {
 template <typename T, typename Context, size_t RANK>
 void SetValueImpl(const Context& dev_ctx,
                   const DenseTensor& in,
-                  const DenseTensor& value,
+                  const T* value_data,
+                  const DDim& value_dims,
                   const IntArray& starts,
                   const IntArray& ends,
                   const IntArray& steps,
@@ -139,8 +140,9 @@ void SetValueImpl(const Context& dev_ctx,
                 in.numel());
   PADDLE_ENFORCE_XDNN_SUCCESS(r, "copy");
 
-  DenseTensor slice_tensor =
-      Empty<T>(dev_ctx, IntArray{slice_dims.Get(), slice_dims.size()});
+  xpu::ctx_guard RAII_GUARD(dev_ctx.x_context());
+  int64_t slice_numels = phi::product(slice_dims);
+  XPUType* slice_data = RAII_GUARD.alloc_l3_or_gm<XPUType>(slice_numels);
 
   int in_size = in_dims.size();
   std::vector<int> starts_indices(in_size, 0);
@@ -186,17 +188,14 @@ void SetValueImpl(const Context& dev_ctx,
   auto slice_shape = phi::vectorize<int>(slice_dims);
   r = xpu::strided_slice(dev_ctx.x_context(),
                          reinterpret_cast<const XPUType*>(out->data<T>()),
-                         reinterpret_cast<XPUType*>(slice_tensor.data<T>()),
+                         slice_data,
                          out_shape,
                          starts_indices,
                          ends_indices,
                          strides_indices);
   PADDLE_ENFORCE_XDNN_SUCCESS(r, "strided_slice");
 
-  r = xpu::constant(dev_ctx.x_context(),
-                    reinterpret_cast<XPUType*>(slice_tensor.data<T>()),
-                    slice_tensor.numel(),
-                    XPUType(0));
+  r = xpu::constant(dev_ctx.x_context(), slice_data, slice_numels, XPUType(0));
   PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant");
 
   // Step 2: Set a tensor with the same shape as out tensor. And its data at
@@ -216,8 +215,7 @@ void SetValueImpl(const Context& dev_ctx,
   // If do broadcasting on Tensor with shape [3] and [3], the result's shape
   // is [3], which is right.
 
-  slice_tensor.Resize(slice_dims_for_assign);
-  CheckIsDimsMatch(slice_dims_for_assign, value.dims());
+  CheckIsDimsMatch(slice_dims_for_assign, value_dims);
   // XPUElementwise can do broadcasting
   auto f = [](xpu::Context* ctx,
               const XPUType* x,
@@ -227,16 +225,20 @@ void SetValueImpl(const Context& dev_ctx,
               const std::vector<int>& yshape) {
     return xpu::broadcast_add<XPUType>(ctx, x, y, z, xshape, yshape);
   };
-  XPUElementwise<T, XPUType>(
-      dev_ctx, slice_tensor, value, -1, &slice_tensor, f);
-
-  slice_tensor.Resize(slice_dims);
+  XPUElementwise<T, XPUType>(dev_ctx,
+                             reinterpret_cast<const T*>(slice_data),
+                             slice_dims_for_assign,
+                             value_data,
+                             value_dims,
+                             -1,
+                             reinterpret_cast<T*>(slice_data),
+                             f);
 
   // - Step 2.2 If stride < 0, flip the slice_tensor.
   if (need_flip) {
     r = xpu::flip(dev_ctx.x_context(),
-                  reinterpret_cast<const XPUType*>(slice_tensor.data<T>()),
-                  reinterpret_cast<XPUType*>(slice_tensor.data<T>()),
+                  reinterpret_cast<const XPUType*>(slice_data),
+                  slice_data,
                   slice_shape,
                   flip_axis);
     PADDLE_ENFORCE_XDNN_SUCCESS(r, "flip");
@@ -244,7 +246,7 @@ void SetValueImpl(const Context& dev_ctx,
   // Step 3: Set out tensor with value
   r = xpu::strided_slice_view_update(
       dev_ctx.x_context(),
-      reinterpret_cast<const XPUType*>(slice_tensor.data<T>()),
+      reinterpret_cast<const XPUType*>(slice_data),
       reinterpret_cast<XPUType*>(out->data<T>()),
       slice_shape,
       out_shape,
@@ -255,16 +257,17 @@ void SetValueImpl(const Context& dev_ctx,
 }
 
 template <typename T, typename Context>
-void SetTensorValueKernel(const Context& dev_ctx,
-                          const DenseTensor& x,
-                          const DenseTensor& value,
-                          const IntArray& starts,
-                          const IntArray& ends,
-                          const IntArray& steps,
-                          const std::vector<int64_t>& axes,
-                          const std::vector<int64_t>& decrease_axes,
-                          const std::vector<int64_t>& none_axes,
-                          DenseTensor* out) {
+void SetValueKernelImpl(const Context& dev_ctx,
+                        const DenseTensor& x,
+                        const T* value_data,
+                        const DDim& value_dims,
+                        const IntArray& starts,
+                        const IntArray& ends,
+                        const IntArray& steps,
+                        const std::vector<int64_t>& axes,
+                        const std::vector<int64_t>& decrease_axes,
+                        const std::vector<int64_t>& none_axes,
+                        DenseTensor* out) {
   // rank is the number of dimensions of the input tensor x
   const int rank = x.dims().size();
 
@@ -272,7 +275,8 @@ void SetTensorValueKernel(const Context& dev_ctx,
     case 1:
       SetValueImpl<T, Context, 1>(dev_ctx,
                                   x,
-                                  value,
+                                  value_data,
+                                  value_dims,
                                   starts,
                                   ends,
                                   steps,
@@ -284,7 +288,8 @@ void SetTensorValueKernel(const Context& dev_ctx,
     case 2:
       SetValueImpl<T, Context, 2>(dev_ctx,
                                   x,
-                                  value,
+                                  value_data,
+                                  value_dims,
                                   starts,
                                   ends,
                                   steps,
@@ -296,7 +301,8 @@ void SetTensorValueKernel(const Context& dev_ctx,
     case 3:
       SetValueImpl<T, Context, 3>(dev_ctx,
                                   x,
-                                  value,
+                                  value_data,
+                                  value_dims,
                                   starts,
                                   ends,
                                   steps,
@@ -308,7 +314,8 @@ void SetTensorValueKernel(const Context& dev_ctx,
    case 4:
      SetValueImpl<T, Context, 4>(dev_ctx,
                                  x,
-                                 value,
+                                 value_data,
+                                 value_dims,
                                  starts,
                                  ends,
                                  steps,
@@ -320,7 +327,8 @@ void SetTensorValueKernel(const Context& dev_ctx,
     case 5:
       SetValueImpl<T, Context, 5>(dev_ctx,
                                   x,
-                                  value,
+                                  value_data,
+                                  value_dims,
                                   starts,
                                   ends,
                                   steps,
@@ -332,7 +340,8 @@ void SetTensorValueKernel(const Context& dev_ctx,
     case 6:
       SetValueImpl<T, Context, 6>(dev_ctx,
                                   x,
-                                  value,
+                                  value_data,
+                                  value_dims,
                                   starts,
                                   ends,
                                   steps,
@@ -347,6 +356,30 @@ void SetTensorValueKernel(const Context& dev_ctx,
   }
 }
 
+template <typename T, typename Context>
+void SetTensorValueKernel(const Context& dev_ctx,
+                          const DenseTensor& x,
+                          const DenseTensor& value,
+                          const IntArray& starts,
+                          const IntArray& ends,
+                          const IntArray& steps,
+                          const std::vector<int64_t>& axes,
+                          const std::vector<int64_t>& decrease_axes,
+                          const std::vector<int64_t>& none_axes,
+                          DenseTensor* out) {
+  SetValueKernelImpl<T, Context>(dev_ctx,
+                                 x,
+                                 value.data<T>(),
+                                 value.dims(),
+                                 starts,
+                                 ends,
+                                 steps,
+                                 axes,
+                                 decrease_axes,
+                                 none_axes,
+                                 out);
+}
+
 template <typename T, typename Context>
 void SetValueKernel(const Context& dev_ctx,
                     const DenseTensor& x,
@@ -359,25 +392,37 @@ void SetValueKernel(const Context& dev_ctx,
                     const std::vector<int64_t>& shape,
                     const std::vector<Scalar>& values,
                     DenseTensor* out) {
-  std::vector<T> assgin_values;
-  assgin_values.reserve(values.size());
+  using XPUType = typename XPUTypeTrait<T>::Type;
+  std::vector<T> assign_values;
+  assign_values.reserve(values.size());
   for (const auto& val : values) {
-    assgin_values.push_back(val.to<T>());
+    assign_values.push_back(val.to<T>());
   }
-  DenseTensor value_tensor = Empty<T>(dev_ctx, shape);
-  phi::TensorFromVector(assgin_values, dev_ctx, &value_tensor);
-  value_tensor.Resize(phi::make_ddim(shape));
-
-  SetTensorValueKernel<T, Context>(dev_ctx,
-                                   x,
-                                   value_tensor,
-                                   starts,
-                                   ends,
-                                   steps,
-                                   axes,
-                                   decrease_axes,
-                                   none_axes,
-                                   out);
+
+  xpu::ctx_guard RAII_GUARD(dev_ctx.x_context());
+  auto value_dims = phi::make_ddim(shape);
+  XPUType* value_data =
+      RAII_GUARD.alloc_l3_or_gm<XPUType>(phi::product(value_dims));
+
+  phi::CPUPlace src_place;
+  auto dst_place = dev_ctx.GetPlace();
+  memory_utils::Copy(dst_place,
+                     value_data,
+                     src_place,
+                     assign_values.data(),
+                     assign_values.size() * sizeof(T));
+
+  SetValueKernelImpl<T, Context>(dev_ctx,
+                                 x,
+                                 reinterpret_cast<const T*>(value_data),
+                                 value_dims,
+                                 starts,
+                                 ends,
+                                 steps,
+                                 axes,
+                                 decrease_axes,
+                                 none_axes,
+                                 out);
 }
 
 }  // namespace phi
diff --git a/paddle/phi/kernels/xpu/stride_slice_kernel.cc b/paddle/phi/kernels/xpu/stride_slice_kernel.cc
index 517445bb713..da181376997 100644
--- a/paddle/phi/kernels/xpu/stride_slice_kernel.cc
+++ b/paddle/phi/kernels/xpu/stride_slice_kernel.cc
@@ -117,5 +117,6 @@ PD_REGISTER_KERNEL(strided_slice_raw,
                    phi::StridedSliceRawKernel,
                    int,
                    int16_t,
+                   int64_t,
                    float,
                    phi::dtype::float16) {}
diff --git a/paddle/phi/kernels/xpu/top_k_kernel.cc b/paddle/phi/kernels/xpu/top_k_kernel.cc
index 5d77e9c4dc8..a3a37db5e6e 100644
--- a/paddle/phi/kernels/xpu/top_k_kernel.cc
+++ b/paddle/phi/kernels/xpu/top_k_kernel.cc
@@ -60,8 +60,8 @@ void TopkKernel(const Context& dev_ctx,
   size_t k = k_scalar.to<int>();
   if (axis + 1 == in_dims.size()) {
     xpu::ctx_guard RAII_GUARD(dev_ctx.x_context());
-    int32_t* indices_int_data = Alloc_l3_or_gm<int32_t>(
-        dev_ctx, &RAII_GUARD, indices->numel());
+    int32_t* indices_int_data =
+        RAII_GUARD.alloc_l3_or_gm<int32_t>(indices->numel());
     PADDLE_ENFORCE_XDNN_NOT_NULL(indices_int_data);
 
     const size_t row =
@@ -106,8 +106,7 @@ void TopkKernel(const Context& dev_ctx,
   }
 
   xpu::ctx_guard RAII_GUARD(dev_ctx.x_context());
-  XPUType* trans_in_data =
-      Alloc_l3_or_gm<XPUType>(dev_ctx, &RAII_GUARD, x.numel());
+  XPUType* trans_in_data = RAII_GUARD.alloc_l3_or_gm<XPUType>(x.numel());
   PADDLE_ENFORCE_XDNN_NOT_NULL(trans_in_data);
 
   // Transpose and save interval output to trans_in
@@ -123,16 +122,14 @@ void TopkKernel(const Context& dev_ctx,
                         r,
                         XPUAPIErrorMsg[r]));
 
-  XPUType* trans_out_data =
-      Alloc_l3_or_gm<XPUType>(dev_ctx, &RAII_GUARD, out->numel());
+  XPUType* trans_out_data = RAII_GUARD.alloc_l3_or_gm<XPUType>(out->numel());
  PADDLE_ENFORCE_XDNN_NOT_NULL(trans_out_data);
 
-  int64_t* trans_idx_data =
-      Alloc_l3_or_gm<int64_t>(dev_ctx, &RAII_GUARD, out->numel());
+  int64_t* trans_idx_data = RAII_GUARD.alloc_l3_or_gm<int64_t>(out->numel());
   PADDLE_ENFORCE_XDNN_NOT_NULL(trans_idx_data);
 
   int32_t* trans_idx_int32_data =
-      Alloc_l3_or_gm<int32_t>(dev_ctx, &RAII_GUARD, out->numel());
+      RAII_GUARD.alloc_l3_or_gm<int32_t>(out->numel());
   PADDLE_ENFORCE_XDNN_NOT_NULL(trans_idx_int32_data);
 
   const size_t row =
       phi::product(phi::slice_ddim(trans_dims, 0, trans_dims.size() - 1));
-- 
GitLab
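
Note for reviewers unfamiliar with the scratch-memory pattern this patch standardizes on: the sketch below shows the idea in isolation. It is illustrative only and not part of the patch; the MakeConstantScratch helper name and the include paths are assumptions, while xpu::ctx_guard, alloc_l3_or_gm, xpu::constant, and the PADDLE_ENFORCE_XDNN_* macros are the same facilities used in the diffs above. Scratch buffers obtained this way belong to the kernel's own XPU context rather than to a separately allocated DenseTensor, which is how the patch avoids temporary-tensor memory overlapping across streams during multi-stream inference.

// Illustrative sketch, not part of the patch. Assumes Paddle's internal XPU
// headers; include paths are approximate.
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/backends/xpu/xpu_context.h"

namespace phi {

// Hypothetical helper: allocate an n-element float scratch buffer from the
// kernel's ctx_guard (L3 cache if it fits, otherwise global memory) and fill
// it with `value`. The buffer is owned by `guard`, so it is released when the
// guard goes out of scope at the end of the calling kernel.
inline float* MakeConstantScratch(const XPUContext& dev_ctx,
                                  xpu::ctx_guard* guard,
                                  int64_t n,
                                  float value) {
  float* buf = guard->alloc_l3_or_gm<float>(n);
  PADDLE_ENFORCE_XDNN_NOT_NULL(buf);
  int r = xpu::constant(dev_ctx.x_context(), buf, n, value);
  PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant");
  return buf;
}

}  // namespace phi

A kernel creates one xpu::ctx_guard RAII_GUARD(dev_ctx.x_context()) up front and routes every temporary through it, which is the pattern the instance_norm, set_value, and top_k changes above follow.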