From 438975fd18d3e62bee778e1f7ed7314e4a59dfa5 Mon Sep 17 00:00:00 2001
From: Leo Guo <58431564+ZibinGuo@users.noreply.github.com>
Date: Thu, 12 Jan 2023 11:23:47 +0800
Subject: [PATCH] Fix the bugs of set_value and set_value_grad ops and add
 register in xpu2_op_list.cc. test=kunlun (#49750)

---
 paddle/phi/backends/xpu/xpu2_op_list.cc       |  10 +
 .../phi/kernels/xpu/set_value_grad_kernel.cc  | 428 ++++++++++++++----
 paddle/phi/kernels/xpu/set_value_kernel.cc    | 384 +++++++++++++---
 3 files changed, 679 insertions(+), 143 deletions(-)

diff --git a/paddle/phi/backends/xpu/xpu2_op_list.cc b/paddle/phi/backends/xpu/xpu2_op_list.cc
index 9c42a8b550..39ca3764b1 100644
--- a/paddle/phi/backends/xpu/xpu2_op_list.cc
+++ b/paddle/phi/backends/xpu/xpu2_op_list.cc
@@ -478,6 +478,16 @@ XPUOpMap& get_kl2_ops() {
                       phi::DataType::FLOAT32})},
       {"sampling_id",
        XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT64})},
+      {"set_value",
+       XPUKernelSet({phi::DataType::FLOAT32,
+                     phi::DataType::INT32,
+                     phi::DataType::INT64,
+                     phi::DataType::FLOAT16})},
+      {"set_value_grad",
+       XPUKernelSet({phi::DataType::FLOAT32,
+                     phi::DataType::INT32,
+                     phi::DataType::INT64,
+                     phi::DataType::FLOAT16})},
       {"sgd", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
       {"sgd_dense_param_sparse_grad",
        XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
diff --git a/paddle/phi/kernels/xpu/set_value_grad_kernel.cc b/paddle/phi/kernels/xpu/set_value_grad_kernel.cc
index f4ce203145..7cee1b1d88 100644
--- a/paddle/phi/kernels/xpu/set_value_grad_kernel.cc
+++ b/paddle/phi/kernels/xpu/set_value_grad_kernel.cc
@@ -1,4 +1,4 @@
-// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
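For orientation before the main hunk: `set_value_grad` receives the output
gradient `out_grad` and routes it two ways. The gradient w.r.t. the input is
`out_grad` with the assigned slice zeroed out, and the gradient w.r.t. the
value tensor is the sliced-out region of `out_grad`, summed back over any
broadcast axes and flipped where a negative step reversed the slice. A toy
1-D illustration of these semantics (standalone C++, not Paddle code; the
names dy/dx/dv are invented for the example):

    // Forward: y = x; y[1:3] = v.  Backward, given dy:
    //   dx = dy with dy[1:3] zeroed out, dv = dy[1:3].
    #include <cstdio>

    int main() {
      const double dy[4] = {1.0, 2.0, 3.0, 4.0};
      double dx[4], dv[2];
      for (int i = 0; i < 4; ++i) dx[i] = dy[i];
      for (int i = 1; i < 3; ++i) {  // the assigned region [1, 3)
        dv[i - 1] = dy[i];           // gradient flows into the value tensor
        dx[i] = 0.0;                 // and is cut off from the input
      }
      std::printf("dx = [%g, %g, %g, %g]\n", dx[0], dx[1], dx[2], dx[3]);
      std::printf("dv = [%g, %g]\n", dv[0], dv[1]);
      return 0;
    }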
@@ -19,103 +19,379 @@
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/common/int_array.h"
-#include "paddle/phi/common/scalar.h"
 #include "paddle/phi/core/dense_tensor.h"
 #include "paddle/phi/core/tensor_utils.h"
-#include "paddle/phi/kernels/empty_kernel.h"
-#include "paddle/phi/kernels/funcs/broadcast_function.h"
+#include "paddle/phi/kernels/full_kernel.h"
 #include "paddle/phi/kernels/funcs/eigen/common.h"
-#include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
-#include "paddle/phi/kernels/funcs/elementwise_functor.h"
-#include "paddle/phi/kernels/funcs/slice_utils.h"
+#include "paddle/phi/kernels/funcs/math_function.h"
+#include "paddle/phi/kernels/funcs/strided_slice.h"
 
 namespace phi {
 
-template <typename T, typename Context>
-void SetValueGradKernel(const Context& dev_ctx,
-                        const DenseTensor& out_grad,
-                        const IntArray& starts,
-                        const IntArray& ends,
-                        const IntArray& steps,
-                        const std::vector<int64_t>& axes,
-                        const std::vector<int64_t>& decrease_axes,
-                        const std::vector<int64_t>& none_axes,
-                        DenseTensor* x_grad,
-                        DenseTensor* value_grad) {
-  using XPUType = typename XPUTypeTrait<T>::Type;
-  x_grad->Resize(out_grad.dims());
-  dev_ctx.template Alloc<T>(x_grad);
-  dev_ctx.template Alloc<T>(value_grad);
-
-  const XPUType* dy_data = reinterpret_cast<const XPUType*>(out_grad.data<T>());
-  XPUType* dx_data = reinterpret_cast<XPUType*>(x_grad->data<T>());
-  XPUType* dv_data = reinterpret_cast<XPUType*>(value_grad->data<T>());
-
-  std::vector<int64_t> starts_vec = starts.GetData();
-  std::vector<int64_t> ends_vec = ends.GetData();
-  std::vector<int64_t> steps_vec = steps.GetData();
-
-  auto dy_dims = out_grad.dims();
-  std::vector<int> dy_shape;
-  for (int i = 0; i < dy_dims.size(); ++i) {
-    dy_shape.push_back(dy_dims[i]);
+inline void GetOffsets(const DDim& big_dim,
+                       const DDim& small_dim,
+                       DDim start_offset,
+                       int cur_dim,
+                       std::vector<DDim>* offsets) {
+  if (cur_dim == big_dim.size()) {
+    offsets->push_back(start_offset);
+    return;
   }
-
-  auto dv_dims = value_grad->dims();
-  std::vector<int> dv_shape;
-  for (int i = 0; i < dv_dims.size(); ++i) {
-    dv_shape.push_back(dv_dims[i]);
+  if (small_dim[cur_dim] == big_dim[cur_dim]) {
+    GetOffsets(big_dim, small_dim, start_offset, cur_dim + 1, offsets);
+  } else {
+    for (int i = 0; i < big_dim[cur_dim]; i++) {
+      GetOffsets(big_dim, small_dim, start_offset, cur_dim + 1, offsets);
+      start_offset[cur_dim] += 1;
+    }
   }
+}
 
-  auto dx_dims = x_grad->dims();
-  std::vector<int> dx_shape;
-  for (int i = 0; i < dx_dims.size(); ++i) {
-    dx_shape.push_back(dx_dims[i]);
-  }
+template <typename T, typename Context, size_t RANK>
+void SetValueGradImpl(const Context& dev_ctx,
+                      const DenseTensor& out_grad,
+                      const IntArray& starts,
+                      const IntArray& ends,
+                      const IntArray& steps,
+                      const std::vector<int64_t>& axes,
+                      const std::vector<int64_t>& decrease_axes,
+                      const std::vector<int64_t>& none_axes,
+                      DenseTensor* x_grad,
+                      DenseTensor* value_grad) {
+  using XPUType = typename XPUTypeTrait<T>::Type;
+  PADDLE_ENFORCE_EQ(
+      out_grad.IsInitialized(),
+      true,
+      errors::PermissionDenied(
+          "The input of `set_value_grad`(out_grad) has not been initialized"));
+
+  auto in_dims = out_grad.dims();
+  auto in_dims_vector = phi::vectorize(in_dims);
+
+  std::vector<int> decrease_axis_int32(decrease_axes.begin(),
+                                       decrease_axes.end());
+  std::vector<int> axes_int32(axes.begin(), axes.end());
+  std::vector<int> infer_flags(axes.size(), 1);
+  std::vector<int64_t> out_dims_vector(in_dims.size(), -1);
+  std::vector<int64_t> starts_local = starts.GetData();
+  std::vector<int64_t> ends_local = ends.GetData();
+  std::vector<int64_t> steps_local = steps.GetData();
+  funcs::StridedSliceOutDims(starts_local,
+                             ends_local,
+                             steps_local,
+                             axes_int32,
+                             infer_flags,
+                             in_dims,
+                             decrease_axis_int32,
+                             out_dims_vector.data(),
+                             axes.size(),
+                             false);
+
+  DDim out_dims(phi::make_ddim(out_dims_vector));
 
-  std::vector<int> starts_vec_int32;
-  for (size_t i = 0; i < starts_vec.size(); ++i) {
-    starts_vec_int32.push_back(starts_vec[i]);
+  std::vector<int> reverse_vector(starts_local.size(), 0);
+  funcs::StridedSliceFunctor(starts_local.data(),
+                             ends_local.data(),
+                             steps_local.data(),
+                             axes_int32.data(),
+                             reverse_vector.data(),
+                             in_dims,
+                             infer_flags,
+                             decrease_axis_int32,
+                             starts_local.size());
+
+  std::vector<int64_t> starts_indices(RANK, 0);
+  std::vector<int64_t> ends_indices(RANK, 0);
+  std::vector<int64_t> steps_indices(RANK, 0);
+  std::vector<bool> reverse_axis(RANK, 0);
+  std::vector<int64_t> flip_axis;
+
+  for (size_t axis = 0; axis < RANK; axis++) {
+    starts_indices[axis] = 0;
+    ends_indices[axis] = out_dims[axis];
+    steps_indices[axis] = 1;
+    reverse_axis[axis] = false;
   }
-
-  std::vector<int> ends_vec_int32;
-  for (size_t i = 0; i < ends_vec.size(); ++i) {
-    ends_vec_int32.push_back(ends_vec[i]);
+  for (size_t axis = 0; axis < axes.size(); axis++) {
+    int axis_index = axes[axis];
+    starts_indices[axis_index] = starts_local[axis];
+    ends_indices[axis_index] = ends_local[axis];
+    steps_indices[axis_index] = steps_local[axis];
+    reverse_axis[axis_index] = (reverse_vector[axis] == 1) ? true : false;
   }
-
-  std::vector<int> steps_vec_int32;
-  for (size_t i = 0; i < steps_vec.size(); ++i) {
-    steps_vec_int32.push_back(steps_vec[i]);
+  for (size_t axis = 0; axis < RANK; axis++) {
+    if (reverse_axis[axis]) {
+      flip_axis.push_back(axis);
+    }
+    if (ends_indices[axis] > in_dims[axis]) {
+      ends_indices[axis] = in_dims[axis];
+    }
   }
 
-  std::vector<int> axes_int32;
-  for (size_t i = 0; i < axes.size(); ++i) {
-    axes_int32.push_back(axes[i]);
+  bool need_reverse = false;
+  for (size_t axis = 0; axis < axes.size(); axis++) {
+    if (reverse_vector[axis] == 1) {
+      need_reverse = true;
+      break;
+    }
   }
 
-  std::vector<int> decrease_axes_int32;
-  for (size_t i = 0; i < decrease_axes.size(); ++i) {
-    decrease_axes_int32.push_back(decrease_axes[i]);
+  phi::funcs::SetConstant<Context, T> set_zero;
+  int r = XPU_SUCCESS;
+
+  if (x_grad) {
+    // Set gradient of `Input`
+    Copy(dev_ctx, out_grad, dev_ctx.GetPlace(), false, x_grad);
+
+    DenseTensor tmp = Full<T>(dev_ctx, out_dims_vector, static_cast<T>(0));
+
+    r = xpu::strided_slice_view_update(
+        dev_ctx.x_context(),
+        reinterpret_cast<const XPUType*>(tmp.data<T>()),
+        reinterpret_cast<XPUType*>(x_grad->data<T>()),
+        out_dims_vector,
+        phi::vectorize(x_grad->dims()),
+        starts_indices,
+        ends_indices,
+        steps_indices);
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "strided_slice_view_update");
   }
+  if (value_grad) {
+    dev_ctx.template Alloc<T>(value_grad);
+    set_zero(dev_ctx, value_grad, static_cast<T>(0));
+
+    if (value_grad->dims() == out_dims) {
+      if (need_reverse) {
+        r = xpu::strided_slice(
+            dev_ctx.x_context(),
+            reinterpret_cast<const XPUType*>(out_grad.data<T>()),
+            reinterpret_cast<XPUType*>(value_grad->data<T>()),
+            in_dims_vector,
+            starts_indices,
+            ends_indices,
+            steps_indices);
+        PADDLE_ENFORCE_XDNN_SUCCESS(r, "strided_slice");
 
-  std::vector<int> none_axes_int32;
-  for (size_t i = 0; i < none_axes.size(); ++i) {
-    none_axes_int32.push_back(none_axes[i]);
-  }
+        r = xpu::flip(dev_ctx.x_context(),
+                      reinterpret_cast<const XPUType*>(value_grad->data<T>()),
+                      reinterpret_cast<XPUType*>(value_grad->data<T>()),
+                      out_dims_vector,
+                      flip_axis);
+        PADDLE_ENFORCE_XDNN_SUCCESS(r, "flip");
+      } else {
+        r = xpu::strided_slice(
+            dev_ctx.x_context(),
+            reinterpret_cast<const XPUType*>(out_grad.data<T>()),
+            reinterpret_cast<XPUType*>(value_grad->data<T>()),
+            in_dims_vector,
+            starts_indices,
+            ends_indices,
+            steps_indices);
+        PADDLE_ENFORCE_XDNN_SUCCESS(r, "strided_slice");
+      }
+    } else {
+      int out_dims_size = out_dims.size();
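+      // NOTE: `value` was broadcast up to the slice shape in the forward
+      // pass, so its gradient has to be obtained by summing the sliced-out
+      // gradient back over every broadcast axis. The "fake" dims built
+      // below align value_grad's shape with the slice shape so that the
+      // slice/broadcast_add loop further down can perform that reduction.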
+      auto value_grad_dims = value_grad->dims();
+      auto fake_value_grad_dims = out_dims;
+
+      // Create an extended shape according to the rules of broadcast.
+      auto value_grad_dims_size = value_grad_dims.size();
+
+      int num_decrease = 0;
+
+      int decrease_axis_size = decrease_axes.size();
+      for (int i = 0; i < out_dims_size; i++) {
+        if (decrease_axes.end() !=
+            std::find(decrease_axes.begin(), decrease_axes.end(), i)) {
+          fake_value_grad_dims[i] = 1;
+          num_decrease++;
+        } else if (i < out_dims_size - (value_grad_dims_size +
+                                        decrease_axis_size - num_decrease)) {
+          fake_value_grad_dims[i] = 1;
+        } else {
+          auto index_grad =
+              i - (out_dims_size -
+                   (value_grad_dims_size + decrease_axis_size - num_decrease));
+          fake_value_grad_dims[i] = value_grad_dims[index_grad];
+
+          PADDLE_ENFORCE_EQ((out_dims[i] == value_grad_dims[index_grad]) ||
+                                (value_grad_dims[index_grad] == 1),
+                            true,
+                            errors::InvalidArgument(
+                                "An error occurred while calculating %s: "
+                                "[%s] can not be accumulated into [%s].",
+                                paddle::framework::GradVarName("ValueTensor"),
+                                out_dims,
+                                value_grad_dims));
+        }
+      }
+
+      VLOG(3) << "Dimensions of "
+              << paddle::framework::GradVarName("ValueTensor") << "(["
+              << value_grad_dims << "]) is broadcasted into ["
+              << fake_value_grad_dims << "].";
+
+      std::vector<int64_t> slice_end(RANK, 0);
+      auto offset = out_dims;
+      for (int i = 0; i < out_dims_size; i++) {
+        offset[i] = 0;
+      }
+      std::vector<DDim> offsets;
+      GetOffsets(out_dims, fake_value_grad_dims, offset, 0, &offsets);
+
+      DenseTensor tmp = Full<T>(dev_ctx, out_dims_vector, static_cast<T>(0));
+
+      r = xpu::strided_slice(
+          dev_ctx.x_context(),
+          reinterpret_cast<const XPUType*>(out_grad.data<T>()),
+          reinterpret_cast<XPUType*>(tmp.data<T>()),
+          in_dims_vector,
+          starts_indices,
+          ends_indices,
+          steps_indices);
+      PADDLE_ENFORCE_XDNN_SUCCESS(r, "strided_slice");
+
+      // accumulate gradient
+      DenseTensor tmp2 =
+          Full<T>(dev_ctx,
+                  {fake_value_grad_dims.Get(), fake_value_grad_dims.size()},
+                  static_cast<T>(0));
+      auto value_grad_dims_vec = phi::vectorize(value_grad_dims);
+      for (auto offset : offsets) {
+        for (int i = 0; i < out_dims_size; i++) {
+          slice_end[i] = offset[i] + fake_value_grad_dims[i];
+        }
+        r = xpu::slice(dev_ctx.x_context(),
+                       reinterpret_cast<const XPUType*>(tmp.data<T>()),
+                       reinterpret_cast<XPUType*>(tmp2.data<T>()),
+                       out_dims_vector,
+                       phi::vectorize(offset),
+                       slice_end);
+        PADDLE_ENFORCE_XDNN_SUCCESS(r, "slice");
+        r = xpu::broadcast_add(
+            dev_ctx.x_context(),
+            reinterpret_cast<const XPUType*>(value_grad->data<T>()),
+            reinterpret_cast<const XPUType*>(tmp2.data<T>()),
+            reinterpret_cast<XPUType*>(value_grad->data<T>()),
+            value_grad_dims_vec,
+            value_grad_dims_vec);
+        PADDLE_ENFORCE_XDNN_SUCCESS(r, "broadcast_add");
+      }
+      if (need_reverse) {
+        r = xpu::flip(dev_ctx.x_context(),
+                      reinterpret_cast<const XPUType*>(value_grad->data<T>()),
+                      reinterpret_cast<XPUType*>(value_grad->data<T>()),
+                      value_grad_dims_vec,
+                      flip_axis);
+        PADDLE_ENFORCE_XDNN_SUCCESS(r, "flip");
+      }
+    }
+  }
+}
+
+template <typename T, typename Context>
+void SetValueGradKernel(const Context& dev_ctx,
+                        const DenseTensor& out_grad,
+                        const IntArray& starts,
+                        const IntArray& ends,
+                        const IntArray& steps,
+                        const std::vector<int64_t>& axes,
+                        const std::vector<int64_t>& decrease_axes,
+                        const std::vector<int64_t>& none_axes,
+                        DenseTensor* x_grad,
+                        DenseTensor* value_grad) {
+  const int rank = out_grad.dims().size();
 
-  int r = xpu::set_value_grad(dev_ctx.x_context(),
-                              dy_data,
-                              dx_data,
-                              dv_data,
-                              dy_shape,
-                              dv_shape,
-                              starts_vec_int32,
-                              ends_vec_int32,
-                              steps_vec_int32,
-                              axes_int32,
-                              decrease_axes_int32,
-                              none_axes_int32);
-  PADDLE_ENFORCE_XDNN_SUCCESS(r, "set_value_grad");
+  switch (rank) {
+    case 1:
+      SetValueGradImpl<T, Context, 1>(dev_ctx,
+                                      out_grad,
+                                      starts,
+                                      ends,
+                                      steps,
+                                      axes,
+                                      decrease_axes,
+                                      none_axes,
+                                      x_grad,
+                                      value_grad);
+      break;
+    case 2:
+      SetValueGradImpl<T, Context, 2>(dev_ctx,
+                                      out_grad,
+                                      starts,
+                                      ends,
+                                      steps,
+                                      axes,
+                                      decrease_axes,
+                                      none_axes,
+                                      x_grad,
+                                      value_grad);
+      break;
+    case 3:
+      SetValueGradImpl<T, Context, 3>(dev_ctx,
+                                      out_grad,
+                                      starts,
+                                      ends,
+                                      steps,
+                                      axes,
+                                      decrease_axes,
+                                      none_axes,
+                                      x_grad,
+                                      value_grad);
+      break;
+    case 4:
+      SetValueGradImpl<T, Context, 4>(dev_ctx,
+                                      out_grad,
+                                      starts,
+                                      ends,
+                                      steps,
+                                      axes,
+                                      decrease_axes,
+                                      none_axes,
+                                      x_grad,
+                                      value_grad);
+      break;
+    case 5:
+      SetValueGradImpl<T, Context, 5>(dev_ctx,
+                                      out_grad,
+                                      starts,
+                                      ends,
+                                      steps,
+                                      axes,
+                                      decrease_axes,
+                                      none_axes,
+                                      x_grad,
+                                      value_grad);
+      break;
+    case 6:
+      SetValueGradImpl<T, Context, 6>(dev_ctx,
+                                      out_grad,
+                                      starts,
+                                      ends,
+                                      steps,
+                                      axes,
+                                      decrease_axes,
+                                      none_axes,
+                                      x_grad,
+                                      value_grad);
+      break;
+    default:
+      PADDLE_THROW(phi::errors::InvalidArgument(
+          "The rank of set_value_grad's input should be less than 7, but "
+          "received %d.",
+          rank));
+  }
 }
 
 }  // namespace phi
+
+PD_REGISTER_KERNEL(set_value_grad,
+                   XPU,
+                   ALL_LAYOUT,
+                   phi::SetValueGradKernel,
+                   float,
+                   phi::dtype::float16,
+                   int,
+                   int64_t) {}
diff --git a/paddle/phi/kernels/xpu/set_value_kernel.cc b/paddle/phi/kernels/xpu/set_value_kernel.cc
index 7d50c57025..4ad350a57c 100644
--- a/paddle/phi/kernels/xpu/set_value_kernel.cc
+++ b/paddle/phi/kernels/xpu/set_value_kernel.cc
@@ -1,4 +1,4 @@
-// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -23,92 +23,324 @@
 #include "paddle/phi/core/dense_tensor.h"
 #include "paddle/phi/core/tensor_utils.h"
 #include "paddle/phi/kernels/empty_kernel.h"
-#include "paddle/phi/kernels/funcs/broadcast_function.h"
-#include "paddle/phi/kernels/funcs/eigen/common.h"
-#include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
-#include "paddle/phi/kernels/funcs/elementwise_functor.h"
+#include "paddle/phi/kernels/funcs/elementwise_base.h"
 #include "paddle/phi/kernels/funcs/slice_utils.h"
+#include "paddle/phi/kernels/xpu/elementwise.h"
 
 namespace phi {
 
-template <typename T, typename Context>
-void SetTensorValueKernel(const Context& dev_ctx,
-                          const DenseTensor& x,
-                          const DenseTensor& value,
-                          const IntArray& starts,
-                          const IntArray& ends,
-                          const IntArray& steps,
-                          const std::vector<int64_t>& axes,
-                          const std::vector<int64_t>& decrease_axes,
-                          const std::vector<int64_t>& none_axes,
-                          DenseTensor* out) {
+// Check whether the tensor with dimensions `second` can be assigned to the
+// tensor with dimensions `first`.
+inline void CheckIsDimsMatch(const DDim& first, const DDim& second) {
+  int ignore_axis1 = 0, ignore_axis2 = 0;
+  for (; ignore_axis1 < first.size(); ++ignore_axis1) {
+    if (first[ignore_axis1] != 1) {
+      break;
+    }
+  }
+  for (; ignore_axis2 < second.size(); ++ignore_axis2) {
+    if (second[ignore_axis2] != 1) {
+      break;
+    }
+  }
+
+  if (second.size() == ignore_axis2) {
+    // second tensor has only one value
+    return;
+  }
+
+  if (first.size() - ignore_axis1 >= second.size() - ignore_axis2) {
+    auto idx1 = first.size() - 1;
+    auto idx2 = second.size() - 1;
+    bool is_match = true;
+    for (; idx2 >= ignore_axis2; idx2--) {
+      if (first[idx1--] != second[idx2] && second[idx2] != 1) {
+        is_match = false;
+        break;
+      }
+    }
+    if (is_match) {
+      return;
+    }
+  }
+  PADDLE_THROW(errors::InvalidArgument(
+      "The shape of the assigned value %s cannot be broadcast to the "
+      "shape of the target slice %s.",
+      second.to_str(),
+      first.to_str()));
+}
+
+template <typename T, typename Context, size_t RANK>
+void SetValueImpl(const Context& dev_ctx,
+                  const DenseTensor& in,
+                  const DenseTensor& value,
+                  const IntArray& starts,
+                  const IntArray& ends,
+                  const IntArray& steps,
+                  const std::vector<int64_t>& axes,
+                  const std::vector<int64_t>& decrease_axes,
+                  const std::vector<int64_t>& none_axes,
+                  DenseTensor* out) {
   using XPUType = typename XPUTypeTrait<T>::Type;
-  out->Resize(x.dims());
-  dev_ctx.template Alloc<T>(out);
+  auto in_dims = in.dims();
+  std::vector<int64_t> starts_local = starts.GetData();
+  std::vector<int64_t> ends_local = ends.GetData();
+  std::vector<int64_t> steps_local = steps.GetData();
+  phi::funcs::CheckAndUpdateSliceAttrs(
+      in_dims, axes, &starts_local, &ends_local, &steps_local);
+  auto slice_dims = phi::funcs::GetSliceDims(
+      in_dims, axes, starts_local, ends_local, &steps_local);
+  auto decrease_slice_dims =
+      phi::funcs::GetDecreasedDims(slice_dims, decrease_axes);
 
-  const XPUType* x_data = reinterpret_cast<const XPUType*>(x.data<T>());
-  const XPUType* v_data = reinterpret_cast<const XPUType*>(value.data<T>());
-  XPUType* y_data = reinterpret_cast<XPUType*>(out->data<T>());
+  auto slice_dims_for_assign = decrease_slice_dims;
+  if (!none_axes.empty()) {
+    std::vector<int64_t> slice_dims_with_none;
 
-  std::vector<int64_t> starts_vec = starts.GetData();
-  std::vector<int64_t> ends_vec = ends.GetData();
-  std::vector<int64_t> steps_vec = steps.GetData();
+    size_t none_axes_cur = 0, decrease_axes_cur = 0;
+    for (int i = 0; i < slice_dims.size(); ++i) {
+      while (none_axes_cur < none_axes.size() &&
+             none_axes[none_axes_cur] <= i) {
+        slice_dims_with_none.push_back(1);
+        none_axes_cur++;
+      }
+      if (decrease_axes_cur < decrease_axes.size() &&
+          decrease_axes[decrease_axes_cur] == i) {
+        decrease_axes_cur++;
+      } else {
+        slice_dims_with_none.push_back(slice_dims[i]);
+      }
+    }
+    while (none_axes_cur < none_axes.size()) {
+      slice_dims_with_none.push_back(1);
+      none_axes_cur++;
+    }
 
-  std::vector<int> starts_vec_int32;
-  for (size_t i = 0; i < starts_vec.size(); ++i) {
-    starts_vec_int32.push_back(starts_vec[i]);
+    slice_dims_for_assign = phi::make_ddim(slice_dims_with_none);
   }
 
-  std::vector<int> ends_vec_int32;
-  for (size_t i = 0; i < ends_vec.size(); ++i) {
-    ends_vec_int32.push_back(ends_vec[i]);
-  }
+  auto place = dev_ctx.GetPlace();
 
-  std::vector<int> steps_vec_int32;
-  for (size_t i = 0; i < steps_vec.size(); ++i) {
-    steps_vec_int32.push_back(steps_vec[i]);
-  }
+  // Here copy data from input to avoid data loss at PE and Graph level.
+  // TODO(liym27): Speed up in a future version.
+  // - Q: Why not call ShareDataWith to speed this up?
+  // - A: Because ShareDataWith is not supported on an OP's input and output:
+  // https://github.com/PaddlePaddle/Paddle/wiki/ShareDataWith-and-ShareBufferWith-are-prohibited-in-OP
+  // - Q: Why not delete Input, since the input and output are the same
+  //   Tensor at program level?
+  // - A: If Input were deleted, the graph would become more complex; there
+  //   would be two ops pointing to the output in the graph:
+  //   op1 -> output <- set_value. In that case, we would have to find a way
+  //   to ensure set_value runs in the intended order.
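+  // Step 1: Copy the input to out. Steps 2 and 3 below then build the
+  // target slice, fill it with `value`, and scatter it back into out.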
+  Copy(dev_ctx, in, place, false, out);
 
-  std::vector<int> axes_int32;
-  for (size_t i = 0; i < axes.size(); ++i) {
-    axes_int32.push_back(axes[i]);
-  }
+  DenseTensor slice_tensor =
+      Empty<T>(dev_ctx, IntArray{slice_dims.Get(), slice_dims.size()});
 
-  std::vector<int> decrease_axes_int32;
-  for (size_t i = 0; i < decrease_axes.size(); ++i) {
-    decrease_axes_int32.push_back(decrease_axes[i]);
-  }
+  int in_size = in_dims.size();
+  std::vector<int64_t> starts_indices(in_size, 0);
+  std::vector<int64_t> ends_indices(in_size, 0);
+  std::vector<int64_t> strides_indices(in_size, 0);
+  std::vector<int64_t> flip_axis;
 
-  std::vector<int> none_axes_int32;
-  for (size_t i = 0; i < none_axes.size(); ++i) {
-    none_axes_int32.push_back(none_axes[i]);
-  }
+  for (size_t i = 0; i < RANK; ++i) {
+    starts_indices[i] = 0;
+    ends_indices[i] = slice_dims[i];
+    strides_indices[i] = 1;
+  }
+  for (size_t i = 0; i < axes.size(); i++) {
+    int axis_index = axes[i];
+    starts_indices[axis_index] = starts_local[i];
+    ends_indices[axis_index] = ends_local[i];
+    strides_indices[axis_index] = steps_local[i];
+    if (starts_local[i] ==
+        ends_local[i]) {  // slice is empty, data will not be changed
+      return;
+    }
   }
 
-  auto x_dims = x.dims();
-  std::vector<int> x_shape;
-  for (int i = 0; i < x_dims.size(); ++i) {
-    x_shape.push_back(x_dims[i]);
+  // Because strided_slice does not support the case of stride < 0
+  // temporarily, the coordinates of starts_indices, ends_indices
+  // and strides_indices need to be converted.
+  // This logic may be deleted in the future.
+  bool need_flip = false;
+  for (size_t i = 0; i < RANK; ++i) {
+    if (strides_indices[i] < 0) {
+      if (!need_flip) {
+        need_flip = true;
+      }
+      flip_axis.push_back(i);
+      strides_indices[i] = strides_indices[i] * (-1);
+      ends_indices[i] = starts_indices[i] + 1;
+      starts_indices[i] =
+          starts_indices[i] - (slice_dims[i] - 1) * strides_indices[i];
+    }
   }
 
-  auto v_dims = value.dims();
-  std::vector<int> v_shape;
-  for (int i = 0; i < v_dims.size(); ++i) {
-    v_shape.push_back(v_dims[i]);
+  auto out_shape = phi::vectorize(out->dims());
+  auto slice_shape = phi::vectorize(slice_dims);
+  int r = XPU_SUCCESS;
+  r = xpu::strided_slice(dev_ctx.x_context(),
+                         reinterpret_cast<const XPUType*>(out->data<T>()),
+                         reinterpret_cast<XPUType*>(slice_tensor.data<T>()),
+                         out_shape,
+                         starts_indices,
+                         ends_indices,
+                         strides_indices);
+  PADDLE_ENFORCE_XDNN_SUCCESS(r, "strided_slice");
+
+  r = xpu::constant(dev_ctx.x_context(),
+                    reinterpret_cast<XPUType*>(slice_tensor.data<T>()),
+                    slice_tensor.numel(),
+                    XPUType(0));
+  PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant");
+
+  // Step 2: Set a tensor with the same shape as out. Its data at '_index'
+  // is the same as value, and the data out of '_index' is zero.
+
+  // - Step 2.1 Set slice tensor with value
+
+  // NOTE(liym27): [ Why resize slice_tensor here? ]
+  // A: When doing broadcasting on slice_tensor and value, the shape of
+  // slice_tensor should be the decreased dims.
+  // e.g.
+  //  x[:,0] = value
+  //  x's shape = [3, 4], value's shape = [3]
+  //  We get slice_dims = [3, 1], decrease_slice_dims = [3]
+  //  If broadcasting is done on tensors with shapes [3, 1] and [3], the
+  //  result's shape is [3, 3], which is out of range;
+  //  if it is done on tensors with shapes [3] and [3], the result's shape
+  //  is [3], which is right.
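+
+  // slice_tensor was zero-filled by the constant() call above, so the
+  // broadcast_add below amounts to a broadcast copy of `value` into the
+  // (possibly decreased) slice shape.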
+  slice_tensor.Resize(slice_dims_for_assign);
+  CheckIsDimsMatch(slice_dims_for_assign, value.dims());
+  // XPUElementwise can do broadcasting
+  auto f = [](xpu::Context* ctx,
+              const XPUType* x,
+              const XPUType* y,
+              XPUType* z,
+              const std::vector<int>& xshape,
+              const std::vector<int>& yshape) {
+    return xpu::broadcast_add<XPUType>(ctx, x, y, z, xshape, yshape);
+  };
+  XPUElementwise<T, XPUType>(
+      dev_ctx, slice_tensor, value, -1, &slice_tensor, f);
+
+  slice_tensor.Resize(slice_dims);
+
+  // - Step 2.2 If stride < 0, flip the slice_tensor.
+  if (need_flip) {
+    r = xpu::flip(dev_ctx.x_context(),
+                  reinterpret_cast<const XPUType*>(slice_tensor.data<T>()),
+                  reinterpret_cast<XPUType*>(slice_tensor.data<T>()),
+                  slice_shape,
+                  flip_axis);
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "flip");
   }
+  // Step 3: Set out tensor with value
+  r = xpu::strided_slice_view_update(
+      dev_ctx.x_context(),
+      reinterpret_cast<const XPUType*>(slice_tensor.data<T>()),
+      reinterpret_cast<XPUType*>(out->data<T>()),
+      slice_shape,
+      out_shape,
+      starts_indices,
+      ends_indices,
+      strides_indices);
+  PADDLE_ENFORCE_XDNN_SUCCESS(r, "strided_slice_view_update");
+}
 
-  int r = xpu::set_value(dev_ctx.x_context(),
-                         x_data,
-                         v_data,
-                         y_data,
-                         x_shape,
-                         v_shape,
-                         starts_vec_int32,
-                         ends_vec_int32,
-                         steps_vec_int32,
-                         axes_int32,
-                         decrease_axes_int32,
-                         none_axes_int32);
-  PADDLE_ENFORCE_XDNN_SUCCESS(r, "set_value");
+template <typename T, typename Context>
+void SetTensorValueKernel(const Context& dev_ctx,
+                          const DenseTensor& x,
+                          const DenseTensor& value,
+                          const IntArray& starts,
+                          const IntArray& ends,
+                          const IntArray& steps,
+                          const std::vector<int64_t>& axes,
+                          const std::vector<int64_t>& decrease_axes,
+                          const std::vector<int64_t>& none_axes,
+                          DenseTensor* out) {
+  // `rank` is the number of dimensions of the input tensor x.
+  const int rank = x.dims().size();
+
+  switch (rank) {
+    case 1:
+      SetValueImpl<T, Context, 1>(dev_ctx,
+                                  x,
+                                  value,
+                                  starts,
+                                  ends,
+                                  steps,
+                                  axes,
+                                  decrease_axes,
+                                  none_axes,
+                                  out);
+      break;
+    case 2:
+      SetValueImpl<T, Context, 2>(dev_ctx,
+                                  x,
+                                  value,
+                                  starts,
+                                  ends,
+                                  steps,
+                                  axes,
+                                  decrease_axes,
+                                  none_axes,
+                                  out);
+      break;
+    case 3:
+      SetValueImpl<T, Context, 3>(dev_ctx,
+                                  x,
+                                  value,
+                                  starts,
+                                  ends,
+                                  steps,
+                                  axes,
+                                  decrease_axes,
+                                  none_axes,
+                                  out);
+      break;
+    case 4:
+      SetValueImpl<T, Context, 4>(dev_ctx,
+                                  x,
+                                  value,
+                                  starts,
+                                  ends,
+                                  steps,
+                                  axes,
+                                  decrease_axes,
+                                  none_axes,
+                                  out);
+      break;
+    case 5:
+      SetValueImpl<T, Context, 5>(dev_ctx,
+                                  x,
+                                  value,
+                                  starts,
+                                  ends,
+                                  steps,
+                                  axes,
+                                  decrease_axes,
+                                  none_axes,
+                                  out);
+      break;
+    case 6:
+      SetValueImpl<T, Context, 6>(dev_ctx,
+                                  x,
+                                  value,
+                                  starts,
+                                  ends,
+                                  steps,
+                                  axes,
+                                  decrease_axes,
+                                  none_axes,
+                                  out);
+      break;
+    default:
+      PADDLE_THROW(errors::InvalidArgument(
+          "The rank of input should be less than 7, but received %d.", rank));
+  }
 }
 
 template <typename T, typename Context>
@@ -145,3 +377,21 @@ void SetValueKernel(const Context& dev_ctx,
 }
 
 }  // namespace phi
+
+PD_REGISTER_KERNEL(set_value,
+                   XPU,
+                   ALL_LAYOUT,
+                   phi::SetValueKernel,
+                   float,
+                   phi::dtype::float16,
+                   int,
+                   int64_t) {}
+
+PD_REGISTER_KERNEL(set_value_with_tensor,
+                   XPU,
+                   ALL_LAYOUT,
+                   phi::SetTensorValueKernel,
+                   float,
+                   phi::dtype::float16,
+                   int,
+                   int64_t) {}
--
GitLab
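A closing note on the dispatch pattern shared by both new kernels: the
runtime rank of the tensor is mapped onto the compile-time RANK template
parameter through an explicit switch, which is also why both kernels reject
ranks above 6. A minimal, self-contained sketch of the same pattern
(illustrative C++ only, not Paddle code; ImplForRank and Dispatch are
invented names):

    #include <cstdio>
    #include <stdexcept>

    template <size_t RANK>
    void ImplForRank() {
      // RANK is a compile-time constant, so it can size arrays and index
      // vectors the way SetValueImpl<T, Context, RANK> does.
      int starts_indices[RANK] = {};
      std::printf("instantiated for rank %zu (first index %d)\n", RANK,
                  starts_indices[0]);
    }

    void Dispatch(int rank) {
      switch (rank) {
        case 1: ImplForRank<1>(); break;
        case 2: ImplForRank<2>(); break;
        case 3: ImplForRank<3>(); break;
        // ... the real kernels continue up to case 6 ...
        default:
          throw std::invalid_argument("unsupported rank in this sketch");
      }
    }

    int main() { Dispatch(2); }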