Unverified commit 438975fd, authored by Leo Guo, committed by GitHub

Fix bugs in the set_value and set_value_grad ops and add their registration in xpu2_op_list.cc (#49750). test=kunlun
Parent: 8fabf417
@@ -478,6 +478,16 @@ XPUOpMap& get_kl2_ops() {
          phi::DataType::FLOAT32})},
     {"sampling_id",
      XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT64})},
     {"set_value",
      XPUKernelSet({phi::DataType::FLOAT32,
                    phi::DataType::INT32,
                    phi::DataType::INT64,
                    phi::DataType::FLOAT16})},
     {"set_value_grad",
      XPUKernelSet({phi::DataType::FLOAT32,
                    phi::DataType::INT32,
                    phi::DataType::INT64,
                    phi::DataType::FLOAT16})},
     {"sgd", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
     {"sgd_dense_param_sparse_grad",
      XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
...
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.

@@ -19,20 +19,36 @@

#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/common/int_array.h"
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/full_kernel.h"
#include "paddle/phi/kernels/funcs/broadcast_function.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/strided_slice.h"
#include "paddle/phi/kernels/funcs/slice_utils.h"

namespace phi {
inline void GetOffsets(const DDim& big_dim,
                       const DDim& small_dim,
                       DDim start_offset,
                       int cur_dim,
                       std::vector<DDim>* offsets) {
  if (cur_dim == big_dim.size()) {
    offsets->push_back(start_offset);
    return;
  }
  if (small_dim[cur_dim] == big_dim[cur_dim]) {
    GetOffsets(big_dim, small_dim, start_offset, cur_dim + 1, offsets);
  } else {
    for (int i = 0; i < big_dim[cur_dim]; i++) {
      GetOffsets(big_dim, small_dim, start_offset, cur_dim + 1, offsets);
      start_offset[cur_dim] += 1;
    }
  }
}
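// Example: with big_dim = [2, 3], small_dim = [1, 3] and start_offset = [0, 0],
// dimension 1 is left fixed (the sizes match) while dimension 0 is iterated,
// producing offsets = {[0, 0], [1, 0]}: one offset per slice of the big tensor
// that the broadcast small tensor maps onto.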
template <typename T, typename Context, size_t RANK>
void SetValueGradImpl(const Context& dev_ctx,
                      const DenseTensor& out_grad,
                      const IntArray& starts,
                      const IntArray& ends,
                      const IntArray& steps,
                      const std::vector<int64_t>& axes,
                      const std::vector<int64_t>& decrease_axes,
                      const std::vector<int64_t>& none_axes,
                      DenseTensor* x_grad,
                      DenseTensor* value_grad) {
  using XPUType = typename XPUTypeTrait<T>::Type;
  PADDLE_ENFORCE_EQ(
      out_grad.IsInitialized(),
      true,
      errors::PermissionDenied(
          "The input of `set_value_grad`(out_grad) has not been initialized"));

  auto in_dims = out_grad.dims();
  auto in_dims_vector = phi::vectorize<int64_t>(in_dims);
  std::vector<int> decrease_axis_int32(decrease_axes.begin(),
                                       decrease_axes.end());
  std::vector<int> axes_int32(axes.begin(), axes.end());
  std::vector<int> infer_flags(axes.size(), 1);
  std::vector<int64_t> out_dims_vector(in_dims.size(), -1);
  std::vector<int64_t> starts_local = starts.GetData();
  std::vector<int64_t> ends_local = ends.GetData();
  std::vector<int64_t> steps_local = steps.GetData();
  funcs::StridedSliceOutDims(starts_local,
                             ends_local,
                             steps_local,
                             axes_int32,
                             infer_flags,
                             in_dims,
                             decrease_axis_int32,
                             out_dims_vector.data(),
                             axes.size(),
                             false);
  DDim out_dims(phi::make_ddim(out_dims_vector));

  std::vector<int> reverse_vector(starts_local.size(), 0);
  funcs::StridedSliceFunctor(starts_local.data(),
                             ends_local.data(),
                             steps_local.data(),
                             axes_int32.data(),
                             reverse_vector.data(),
                             in_dims,
                             infer_flags,
                             decrease_axis_int32,
                             starts_local.size());

  std::vector<int64_t> starts_indices(RANK, 0);
  std::vector<int64_t> ends_indices(RANK, 0);
  std::vector<int64_t> steps_indices(RANK, 0);
  std::vector<bool> reverse_axis(RANK, 0);
  std::vector<int64_t> flip_axis;

  for (size_t axis = 0; axis < RANK; axis++) {
    starts_indices[axis] = 0;
    ends_indices[axis] = out_dims[axis];
    steps_indices[axis] = 1;
    reverse_axis[axis] = false;
  }
  for (size_t axis = 0; axis < axes.size(); axis++) {
    int axis_index = axes[axis];
    starts_indices[axis_index] = starts_local[axis];
    ends_indices[axis_index] = ends_local[axis];
    steps_indices[axis_index] = steps_local[axis];
    reverse_axis[axis_index] = (reverse_vector[axis] == 1) ? true : false;
  }
  for (size_t axis = 0; axis < RANK; axis++) {
    if (reverse_axis[axis]) {
      flip_axis.push_back(axis);
    }
    if (ends_indices[axis] > in_dims[axis]) {
      ends_indices[axis] = in_dims[axis];
    }
  }

  bool need_reverse = false;
  for (size_t axis = 0; axis < axes.size(); axis++) {
    if (reverse_vector[axis] == 1) {
      need_reverse = true;
      break;
    }
  }

  phi::funcs::SetConstant<Context, T> set_zero;
  int r = XPU_SUCCESS;

  if (x_grad) {
    // Set gradient of `Input`
    Copy(dev_ctx, out_grad, dev_ctx.GetPlace(), false, x_grad);

    DenseTensor tmp = Full<T>(dev_ctx, out_dims_vector, static_cast<T>(0));

    r = xpu::strided_slice_view_update(
        dev_ctx.x_context(),
        reinterpret_cast<const XPUType*>(tmp.data<T>()),
        reinterpret_cast<XPUType*>(x_grad->data<T>()),
        out_dims_vector,
        phi::vectorize<int64_t>(x_grad->dims()),
        starts_indices,
        ends_indices,
        steps_indices);
    PADDLE_ENFORCE_XDNN_SUCCESS(r, "strided_slice_view_update");
  }
  if (value_grad) {
    dev_ctx.template Alloc<T>(value_grad);
    set_zero(dev_ctx, value_grad, static_cast<T>(0));

    if (value_grad->dims() == out_dims) {
      if (need_reverse) {
        r = xpu::strided_slice(
            dev_ctx.x_context(),
            reinterpret_cast<const XPUType*>(out_grad.data<T>()),
            reinterpret_cast<XPUType*>(value_grad->data<T>()),
            in_dims_vector,
            starts_indices,
            ends_indices,
            steps_indices);
        PADDLE_ENFORCE_XDNN_SUCCESS(r, "strided_slice");

        r = xpu::flip(dev_ctx.x_context(),
                      reinterpret_cast<const XPUType*>(value_grad->data<T>()),
                      reinterpret_cast<XPUType*>(value_grad->data<T>()),
                      out_dims_vector,
                      flip_axis);
        PADDLE_ENFORCE_XDNN_SUCCESS(r, "flip");
      } else {
        r = xpu::strided_slice(
            dev_ctx.x_context(),
            reinterpret_cast<const XPUType*>(out_grad.data<T>()),
            reinterpret_cast<XPUType*>(value_grad->data<T>()),
            in_dims_vector,
            starts_indices,
            ends_indices,
            steps_indices);
        PADDLE_ENFORCE_XDNN_SUCCESS(r, "strided_slice");
      }
    } else {
      int out_dims_size = out_dims.size();
      auto value_grad_dims = value_grad->dims();
      auto fake_value_grad_dims = out_dims;

      // Create an extended shape according to the rules of broadcast.
      auto value_grad_dims_size = value_grad_dims.size();

      int num_decrease = 0;
      int decrease_axis_size = decrease_axes.size();
      for (int i = 0; i < out_dims_size; i++) {
        if (decrease_axes.end() !=
            std::find(decrease_axes.begin(), decrease_axes.end(), i)) {
          fake_value_grad_dims[i] = 1;
          num_decrease++;
        } else if (i < out_dims_size - (value_grad_dims_size +
                                        decrease_axis_size - num_decrease)) {
          fake_value_grad_dims[i] = 1;
        } else {
          auto index_grad =
              i - (out_dims_size -
                   (value_grad_dims_size + decrease_axis_size - num_decrease));
          fake_value_grad_dims[i] = value_grad_dims[index_grad];

          PADDLE_ENFORCE_EQ((out_dims[i] == value_grad_dims[index_grad]) ||
                                (value_grad_dims[index_grad] == 1),
                            true,
                            errors::InvalidArgument(
                                "An error occurred while calculating %s: "
                                "[%s] can not be accumulated into [%s].",
                                paddle::framework::GradVarName("ValueTensor"),
                                out_dims,
                                value_grad_dims));
        }
      }

      VLOG(3) << "Dimensions of "
              << paddle::framework::GradVarName("ValueTensor") << " (["
              << value_grad_dims << "]) is broadcasted into ["
              << fake_value_grad_dims << "].";

      std::vector<int64_t> slice_end(RANK, 0);
      auto offset = out_dims;
      for (int i = 0; i < out_dims_size; i++) {
        offset[i] = 0;
      }
      std::vector<DDim> offsets;
      GetOffsets(out_dims, fake_value_grad_dims, offset, 0, &offsets);

      DenseTensor tmp = Full<T>(dev_ctx, out_dims_vector, static_cast<T>(0));

      r = xpu::strided_slice(
          dev_ctx.x_context(),
          reinterpret_cast<const XPUType*>(out_grad.data<T>()),
          reinterpret_cast<XPUType*>(tmp.data<T>()),
          in_dims_vector,
          starts_indices,
          ends_indices,
          steps_indices);
      PADDLE_ENFORCE_XDNN_SUCCESS(r, "strided_slice");

      // accumulate gradient
      DenseTensor tmp2 =
          Full<T>(dev_ctx,
                  {fake_value_grad_dims.Get(), fake_value_grad_dims.size()},
                  static_cast<T>(0));
      auto value_grad_dims_vec = phi::vectorize<int64_t>(value_grad_dims);
      for (auto offset : offsets) {
        for (int i = 0; i < out_dims_size; i++) {
          slice_end[i] = offset[i] + fake_value_grad_dims[i];
        }
        r = xpu::slice(dev_ctx.x_context(),
                       reinterpret_cast<const XPUType*>(tmp.data<T>()),
                       reinterpret_cast<XPUType*>(tmp2.data<T>()),
                       out_dims_vector,
                       phi::vectorize<int64_t>(offset),
                       slice_end);
        PADDLE_ENFORCE_XDNN_SUCCESS(r, "slice");

        r = xpu::broadcast_add(
            dev_ctx.x_context(),
            reinterpret_cast<const XPUType*>(value_grad->data<T>()),
            reinterpret_cast<const XPUType*>(tmp2.data<T>()),
            reinterpret_cast<XPUType*>(value_grad->data<T>()),
            value_grad_dims_vec,
            value_grad_dims_vec);
        PADDLE_ENFORCE_XDNN_SUCCESS(r, "broadcast_add");
      }

      if (need_reverse) {
        r = xpu::flip(dev_ctx.x_context(),
                      reinterpret_cast<const XPUType*>(value_grad->data<T>()),
                      reinterpret_cast<XPUType*>(value_grad->data<T>()),
                      value_grad_dims_vec,
                      flip_axis);
        PADDLE_ENFORCE_XDNN_SUCCESS(r, "flip");
      }
    }
  }
}
template <typename T, typename Context>
void SetValueGradKernel(const Context& dev_ctx,
                        const DenseTensor& out_grad,
                        const IntArray& starts,
                        const IntArray& ends,
                        const IntArray& steps,
                        const std::vector<int64_t>& axes,
                        const std::vector<int64_t>& decrease_axes,
                        const std::vector<int64_t>& none_axes,
                        DenseTensor* x_grad,
                        DenseTensor* value_grad) {
  const int rank = out_grad.dims().size();

  switch (rank) {
    case 1:
      SetValueGradImpl<T, Context, 1>(dev_ctx,
                                      out_grad,
                                      starts,
                                      ends,
                                      steps,
                                      axes,
                                      decrease_axes,
                                      none_axes,
                                      x_grad,
                                      value_grad);
      break;
    case 2:
      SetValueGradImpl<T, Context, 2>(dev_ctx,
                                      out_grad,
                                      starts,
                                      ends,
                                      steps,
                                      axes,
                                      decrease_axes,
                                      none_axes,
                                      x_grad,
                                      value_grad);
      break;
    case 3:
      SetValueGradImpl<T, Context, 3>(dev_ctx,
                                      out_grad,
                                      starts,
                                      ends,
                                      steps,
                                      axes,
                                      decrease_axes,
                                      none_axes,
                                      x_grad,
                                      value_grad);
      break;
    case 4:
      SetValueGradImpl<T, Context, 4>(dev_ctx,
                                      out_grad,
                                      starts,
                                      ends,
                                      steps,
                                      axes,
                                      decrease_axes,
                                      none_axes,
                                      x_grad,
                                      value_grad);
      break;
    case 5:
      SetValueGradImpl<T, Context, 5>(dev_ctx,
                                      out_grad,
                                      starts,
                                      ends,
                                      steps,
                                      axes,
                                      decrease_axes,
                                      none_axes,
                                      x_grad,
                                      value_grad);
      break;
    case 6:
      SetValueGradImpl<T, Context, 6>(dev_ctx,
                                      out_grad,
                                      starts,
                                      ends,
                                      steps,
                                      axes,
                                      decrease_axes,
                                      none_axes,
                                      x_grad,
                                      value_grad);
      break;
    default:
      PADDLE_THROW(phi::errors::InvalidArgument(
          "The rank of set_value_grad's input should be less than 7, but "
          "received %d.",
          rank));
  }
}

}  // namespace phi

PD_REGISTER_KERNEL(set_value_grad,
                   XPU,
                   ALL_LAYOUT,
                   phi::SetValueGradKernel,
                   float,
                   phi::dtype::float16,
                   int,
                   int64_t) {}
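For reference, a minimal standalone sketch of the gradient semantics implemented above, in plain C++ with no Paddle or XPU types; it assumes a 1-D input, a positive step, and a value tensor whose shape equals the slice shape, in which case x_grad is out_grad with the assigned positions zeroed and value_grad is the assigned positions gathered from out_grad:

#include <cstdio>
#include <vector>

int main() {
  // Forward pass assumed: out = x; out[1:7:2] = value (positions 1, 3, 5).
  const int start = 1, end = 7, step = 2;
  std::vector<float> out_grad = {1, 2, 3, 4, 5, 6, 7, 8};

  // x_grad: copy of out_grad with the assigned positions zeroed
  // (mirrors Copy + strided_slice_view_update with a zero tensor).
  std::vector<float> x_grad = out_grad;
  for (int i = start; i < end; i += step) x_grad[i] = 0.0f;

  // value_grad: the assigned positions gathered from out_grad
  // (mirrors the xpu::strided_slice call on out_grad).
  std::vector<float> value_grad;
  for (int i = start; i < end; i += step) value_grad.push_back(out_grad[i]);

  for (float v : x_grad) std::printf("%g ", v);      // 1 0 3 0 5 0 7 8
  std::printf("\n");
  for (float v : value_grad) std::printf("%g ", v);  // 2 4 6
  std::printf("\n");
  return 0;
}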
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.

@@ -23,17 +23,56 @@

#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/empty_kernel.h"
#include "paddle/phi/kernels/funcs/elementwise_base.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
#include "paddle/phi/kernels/funcs/elementwise_functor.h"
#include "paddle/phi/kernels/funcs/slice_utils.h"
#include "paddle/phi/kernels/xpu/elementwise.h"

namespace phi {
// Check whether a tensor with the dims of `second` can be assigned to a
// tensor slice with the dims of `first`.
inline void CheckIsDimsMatch(const DDim& first, const DDim& second) {
  int ignore_axis1 = 0, ignore_axis2 = 0;
  for (; ignore_axis1 < first.size(); ++ignore_axis1) {
    if (first[ignore_axis1] != 1) {
      break;
    }
  }
  for (; ignore_axis2 < second.size(); ++ignore_axis2) {
    if (second[ignore_axis2] != 1) {
      break;
    }
  }

  if (second.size() == ignore_axis2) {
    // second tensor has only one value
    return;
  }

  if (first.size() - ignore_axis1 >= second.size() - ignore_axis2) {
    auto idx1 = first.size() - 1;
    auto idx2 = second.size() - 1;
    bool is_match = true;
    for (; idx2 >= ignore_axis2; idx2--) {
      if (first[idx1--] != second[idx2] && second[idx2] != 1) {
        is_match = false;
        break;
      }
    }
    if (is_match) {
      return;
    }
  }
  PADDLE_THROW(errors::InvalidArgument(
      "The shape of the assigned value must match the target shape: "
      "value shape is %s, but the target shape is %s.",
      second.to_str(),
      first.to_str()));
}
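// Example for CheckIsDimsMatch: assigning a value of shape [3] to a slice of
// shape [1, 3] passes (leading 1s are ignored and the trailing dimensions
// match), while assigning a value of shape [2, 3] to a slice of shape [3, 4]
// throws, since 3 != 4 in the last dimension and the value dimension is not 1.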
template <typename T, typename Context, size_t RANK>
void SetValueImpl(const Context& dev_ctx,
                  const DenseTensor& in,
                  const DenseTensor& value,
                  const IntArray& starts,
                  const IntArray& ends,
                  const IntArray& steps,
                  const std::vector<int64_t>& axes,
                  const std::vector<int64_t>& decrease_axes,
                  const std::vector<int64_t>& none_axes,
                  DenseTensor* out) {
  using XPUType = typename XPUTypeTrait<T>::Type;
  auto in_dims = in.dims();
  std::vector<int64_t> starts_local = starts.GetData();
  std::vector<int64_t> ends_local = ends.GetData();
  std::vector<int64_t> steps_local = steps.GetData();
  phi::funcs::CheckAndUpdateSliceAttrs(
      in_dims, axes, &starts_local, &ends_local, &steps_local);
  auto slice_dims = phi::funcs::GetSliceDims(
      in_dims, axes, starts_local, ends_local, &steps_local);
  auto decrease_slice_dims =
      phi::funcs::GetDecreasedDims(slice_dims, decrease_axes);

  auto slice_dims_for_assign = decrease_slice_dims;
  if (!none_axes.empty()) {
    std::vector<int64_t> slice_dims_with_none;
    size_t none_axes_cur = 0, decrease_axes_cur = 0;
    for (int i = 0; i < slice_dims.size(); ++i) {
      while (none_axes_cur < none_axes.size() &&
             none_axes[none_axes_cur] <= i) {
        slice_dims_with_none.push_back(1);
        none_axes_cur++;
      }
      if (decrease_axes_cur < decrease_axes.size() &&
          decrease_axes[decrease_axes_cur] == i) {
        decrease_axes_cur++;
      } else {
        slice_dims_with_none.push_back(slice_dims[i]);
      }
    }
    while (none_axes_cur < none_axes.size()) {
      slice_dims_with_none.push_back(1);
      none_axes_cur++;
    }
    slice_dims_for_assign = phi::make_ddim(slice_dims_with_none);
  }

  auto place = dev_ctx.GetPlace();

  // Here copy data from input to avoid data loss at PE and Graph level.
  // TODO(liym27): Speed up in the future version.
  // - Q: Why don't call ShareDataWith to speed up?
  // - A: Because it's not supported to ShareDataWith on OP's input and output
  // https://github.com/PaddlePaddle/Paddle/wiki/ShareDataWith-and-ShareBufferWith-are-prohibited-in-OP
  // - Q: Why don't delete Input, after all, the input and output are the same
  // Tensor at program level?
  // - A: If deleting Input, the graph will be complex, such as there will
  // be two ops points to the output in graph: op1 -> output <- set_value.
  // In this case, we have to find a way to handle the running order of
  // set_value is what we want.
  Copy(dev_ctx, in, place, false, out);

  DenseTensor slice_tensor =
      Empty<T>(dev_ctx, IntArray{slice_dims.Get(), slice_dims.size()});

  int in_size = in_dims.size();
  std::vector<int> starts_indices(in_size, 0);
  std::vector<int> ends_indices(in_size, 0);
  std::vector<int> strides_indices(in_size, 0);
  std::vector<int> flip_axis;

  for (size_t i = 0; i < RANK; ++i) {
    starts_indices[i] = 0;
    ends_indices[i] = slice_dims[i];
    strides_indices[i] = 1;
  }
  for (size_t i = 0; i < axes.size(); i++) {
    int axis_index = axes[i];
    starts_indices[axis_index] = starts_local[i];
    ends_indices[axis_index] = ends_local[i];
    strides_indices[axis_index] = steps_local[i];
    if (starts_local[i] ==
        ends_local[i]) {  // slice is empty, data will not be changed
      return;
    }
  }

  // Because strided_slice does not support the case of stride < 0
  // temporarily, the coordinates of starts_indices, ends_indices
  // and strides_indices need to be converted.
  // This logic may be deleted in the future.
  bool need_flip = false;
  for (size_t i = 0; i < RANK; ++i) {
    if (strides_indices[i] < 0) {
      if (!need_flip) {
        need_flip = true;
      }
      flip_axis.push_back(i);
      strides_indices[i] = strides_indices[i] * (-1);
      ends_indices[i] = starts_indices[i] + 1;
      starts_indices[i] =
          starts_indices[i] - (slice_dims[i] - 1) * strides_indices[i];
    }
  }
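  // Example: for a dimension of size 6 sliced with starts = 5, ends = 0,
  // stride = -2 (slice_dims = 3, i.e. elements 5, 3, 1), the loop above
  // rewrites this to starts = 1, ends = 6, stride = 2; strided_slice then
  // reads elements 1, 3, 5 and the later flip restores the order 5, 3, 1.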
  auto out_shape = phi::vectorize<int>(out->dims());
  auto slice_shape = phi::vectorize<int>(slice_dims);

  int r = XPU_SUCCESS;
  r = xpu::strided_slice(dev_ctx.x_context(),
                         reinterpret_cast<const XPUType*>(out->data<T>()),
                         reinterpret_cast<XPUType*>(slice_tensor.data<T>()),
                         out_shape,
                         starts_indices,
                         ends_indices,
                         strides_indices);
  PADDLE_ENFORCE_XDNN_SUCCESS(r, "strided_slice");

  r = xpu::constant(dev_ctx.x_context(),
                    reinterpret_cast<XPUType*>(slice_tensor.data<T>()),
                    slice_tensor.numel(),
                    XPUType(0));
  PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant");

  // Step 2: Set a tensor with the same shape as out tensor. And its data at
  // '_index' is the same as value, and data out of '_index' to zero.
  // - Step 2.1 Set slice tensor with value

  // NOTE(liym27): [ Why resize slice_tensor here? ]
  // A: When do broadcasting on slice_tensor and value, the shape of
  // slice_tensor should be decreased dims.
  // e.g.
  //  x[:,0] = value
  //  x's shape = [3, 4], value's shape = [3]
  //  We get slice_dims = [3, 1], decrease_slice_dims = [3]
  //  If do broadcasting on Tensor with shape [3, 1] and [3], the result's
  //  shape is [3, 3], which cross the border;
  //  If do broadcasting on Tensor with shape [3] and [3], the result's shape
  //  is [3], which is right.
  slice_tensor.Resize(slice_dims_for_assign);
  CheckIsDimsMatch(slice_dims_for_assign, value.dims());
  // XPUElementwise can do broadcasting
  auto f = [](xpu::Context* ctx,
              const XPUType* x,
              const XPUType* y,
              XPUType* z,
              const std::vector<int>& xshape,
              const std::vector<int>& yshape) {
    return xpu::broadcast_add<XPUType>(ctx, x, y, z, xshape, yshape);
  };
  XPUElementwise<T, XPUType>(
      dev_ctx, slice_tensor, value, -1, &slice_tensor, f);
  slice_tensor.Resize(slice_dims);

  // - Step 2.2 If stride < 0, flip the slice_tensor.
  if (need_flip) {
    r = xpu::flip(dev_ctx.x_context(),
                  reinterpret_cast<const XPUType*>(slice_tensor.data<T>()),
                  reinterpret_cast<XPUType*>(slice_tensor.data<T>()),
                  slice_shape,
                  flip_axis);
    PADDLE_ENFORCE_XDNN_SUCCESS(r, "flip");
  }

  // Step 3: Set out tensor with value
  r = xpu::strided_slice_view_update(
      dev_ctx.x_context(),
      reinterpret_cast<const XPUType*>(slice_tensor.data<T>()),
      reinterpret_cast<XPUType*>(out->data<T>()),
      slice_shape,
      out_shape,
      starts_indices,
      ends_indices,
      strides_indices);
  PADDLE_ENFORCE_XDNN_SUCCESS(r, "strided_slice_view_update");
}
template <typename T, typename Context>
void SetTensorValueKernel(const Context& dev_ctx,
                          const DenseTensor& x,
                          const DenseTensor& value,
                          const IntArray& starts,
                          const IntArray& ends,
                          const IntArray& steps,
                          const std::vector<int64_t>& axes,
                          const std::vector<int64_t>& decrease_axes,
                          const std::vector<int64_t>& none_axes,
                          DenseTensor* out) {
  // rank is the number of dimensions of the input tensor x.
  const int rank = x.dims().size();
  switch (rank) {
    case 1:
      SetValueImpl<T, Context, 1>(dev_ctx,
                                  x,
                                  value,
                                  starts,
                                  ends,
                                  steps,
                                  axes,
                                  decrease_axes,
                                  none_axes,
                                  out);
      break;
    case 2:
      SetValueImpl<T, Context, 2>(dev_ctx,
                                  x,
                                  value,
                                  starts,
                                  ends,
                                  steps,
                                  axes,
                                  decrease_axes,
                                  none_axes,
                                  out);
      break;
    case 3:
      SetValueImpl<T, Context, 3>(dev_ctx,
                                  x,
                                  value,
                                  starts,
                                  ends,
                                  steps,
                                  axes,
                                  decrease_axes,
                                  none_axes,
                                  out);
      break;
    case 4:
      SetValueImpl<T, Context, 4>(dev_ctx,
                                  x,
                                  value,
                                  starts,
                                  ends,
                                  steps,
                                  axes,
                                  decrease_axes,
                                  none_axes,
                                  out);
      break;
    case 5:
      SetValueImpl<T, Context, 5>(dev_ctx,
                                  x,
                                  value,
                                  starts,
                                  ends,
                                  steps,
                                  axes,
                                  decrease_axes,
                                  none_axes,
                                  out);
      break;
    case 6:
      SetValueImpl<T, Context, 6>(dev_ctx,
                                  x,
                                  value,
                                  starts,
                                  ends,
                                  steps,
                                  axes,
                                  decrease_axes,
                                  none_axes,
                                  out);
      break;
    default:
      PADDLE_THROW(errors::InvalidArgument(
          "The rank of input should be less than 7, but received %d.", rank));
  }
}
template <typename T, typename Context>

@@ -145,3 +377,21 @@ void SetValueKernel(const Context& dev_ctx,

}

}  // namespace phi
PD_REGISTER_KERNEL(set_value,
                   XPU,
                   ALL_LAYOUT,
                   phi::SetValueKernel,
                   float,
                   phi::dtype::float16,
                   int,
                   int64_t) {}

PD_REGISTER_KERNEL(set_value_with_tensor,
                   XPU,
                   ALL_LAYOUT,
                   phi::SetTensorValueKernel,
                   float,
                   phi::dtype::float16,
                   int,
                   int64_t) {}
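For reference, a minimal standalone sketch of the three strided-slice steps in SetValueImpl above, in plain C++ with no Paddle or XPU types; it assumes a 1-D input, a positive step, and a value already broadcast to the slice shape:

#include <cstdio>
#include <vector>

int main() {
  // out = x; we then emulate out[1:7:2] = value (positions 1, 3, 5).
  const int start = 1, end = 7, step = 2;
  std::vector<float> out = {10, 20, 30, 40, 50, 60, 70, 80};
  std::vector<float> value = {-1, -2, -3};  // already broadcast to slice shape

  // Step 1: slice_tensor = strided_slice(out), then zeroed with constant(0).
  std::vector<float> slice_tensor;
  for (int i = start; i < end; i += step) slice_tensor.push_back(0.0f);

  // Step 2: slice_tensor = slice_tensor + value (broadcast_add onto zeros).
  for (size_t i = 0; i < slice_tensor.size(); ++i) slice_tensor[i] += value[i];

  // Step 3: strided_slice_view_update writes slice_tensor back into out.
  for (size_t i = 0; i < slice_tensor.size(); ++i)
    out[start + static_cast<int>(i) * step] = slice_tensor[i];

  for (float v : out) std::printf("%g ", v);  // 10 -1 30 -2 50 -3 70 80
  std::printf("\n");
  return 0;
}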