未验证 提交 d2b0d63f 编写于 作者: Z zhangyuqin1998 提交者: GitHub

rename_SliceKernel (#52863)

上级 514d83de
...@@ -20,7 +20,7 @@ ...@@ -20,7 +20,7 @@
namespace phi { namespace phi {
template <typename T, typename Context> template <typename T, typename Context>
void AccuracyRawKernel(const Context& dev_ctx, void AccuracyKernel(const Context& dev_ctx,
const DenseTensor& out, const DenseTensor& out,
const DenseTensor& indices, const DenseTensor& indices,
const DenseTensor& label, const DenseTensor& label,
......
...@@ -21,7 +21,7 @@ ...@@ -21,7 +21,7 @@
namespace phi { namespace phi {
template <typename T, typename Context> template <typename T, typename Context>
void AccuracyRawKernel(const Context& dev_ctx, void AccuracyKernel(const Context& dev_ctx,
const DenseTensor& inference, const DenseTensor& inference,
const DenseTensor& indices, const DenseTensor& indices,
const DenseTensor& label, const DenseTensor& label,
...@@ -93,7 +93,7 @@ void AccuracyRawKernel(const Context& dev_ctx, ...@@ -93,7 +93,7 @@ void AccuracyRawKernel(const Context& dev_ctx,
// TODO(add supported dtype.) // TODO(add supported dtype.)
PD_REGISTER_KERNEL( PD_REGISTER_KERNEL(
accuracy, CPU, ALL_LAYOUT, phi::AccuracyRawKernel, float, double) { accuracy, CPU, ALL_LAYOUT, phi::AccuracyKernel, float, double) {
kernel->InputAt(1).SetDataType(phi::DataType::INT64); kernel->InputAt(1).SetDataType(phi::DataType::INT64);
kernel->InputAt(2).SetDataType(phi::DataType::INT64); kernel->InputAt(2).SetDataType(phi::DataType::INT64);
kernel->OutputAt(0).SetDataType(phi::DataType::FLOAT32); kernel->OutputAt(0).SetDataType(phi::DataType::FLOAT32);
......
...@@ -21,7 +21,7 @@ ...@@ -21,7 +21,7 @@
PD_REGISTER_KERNEL(slice, PD_REGISTER_KERNEL(slice,
CPU, CPU,
ALL_LAYOUT, ALL_LAYOUT,
phi::SliceRawKernel, phi::SliceKernel,
bool, bool,
uint8_t, uint8_t,
int, int,
......
...@@ -73,7 +73,7 @@ __global__ void AccuracyCudaKernel(const int N, ...@@ -73,7 +73,7 @@ __global__ void AccuracyCudaKernel(const int N,
} }
template <typename T, typename Context> template <typename T, typename Context>
void AccuracyRawKernel(const Context& dev_ctx, void AccuracyKernel(const Context& dev_ctx,
const DenseTensor& inference, const DenseTensor& inference,
const DenseTensor& indices, const DenseTensor& indices,
const DenseTensor& label, const DenseTensor& label,
...@@ -137,7 +137,7 @@ void AccuracyRawKernel(const Context& dev_ctx, ...@@ -137,7 +137,7 @@ void AccuracyRawKernel(const Context& dev_ctx,
PD_REGISTER_KERNEL(accuracy, PD_REGISTER_KERNEL(accuracy,
GPU, GPU,
ALL_LAYOUT, ALL_LAYOUT,
phi::AccuracyRawKernel, phi::AccuracyKernel,
phi::dtype::float16, phi::dtype::float16,
phi::dtype::bfloat16, phi::dtype::bfloat16,
float, float,
......
...@@ -101,8 +101,8 @@ void QrKernel(const Context& ctx, ...@@ -101,8 +101,8 @@ void QrKernel(const Context& ctx,
if (reduced_mode) { if (reduced_mode) {
auto trans_qr = TransposeLast2Dim<T, Context>(ctx, qr); auto trans_qr = TransposeLast2Dim<T, Context>(ctx, qr);
auto sliced_qr = SliceKernel<T, Context>( auto sliced_qr = Slice<T, Context>(
ctx, trans_qr, {trans_qr.dims().size() - 2}, {0}, {min_mn}, {1}, {}); ctx, trans_qr, {trans_qr.dims().size() - 2}, {0}, {min_mn});
auto tmp_r = TrilTriu<T, Context>(ctx, sliced_qr, 0, false); auto tmp_r = TrilTriu<T, Context>(ctx, sliced_qr, 0, false);
// Transpose 'tmp_r' to retore the original row-major order // Transpose 'tmp_r' to retore the original row-major order
phi::Copy(ctx, tmp_r, r->place(), false, r); phi::Copy(ctx, tmp_r, r->place(), false, r);
...@@ -128,8 +128,8 @@ void QrKernel(const Context& ctx, ...@@ -128,8 +128,8 @@ void QrKernel(const Context& ctx,
qr_stride, qr_stride,
tau_stride); tau_stride);
auto trans_q = TransposeLast2Dim<T, Context>(ctx, qr); auto trans_q = TransposeLast2Dim<T, Context>(ctx, qr);
auto sliced_q = SliceKernel<T, Context>( auto sliced_q = Slice<T, Context>(
ctx, trans_q, {trans_q.dims().size() - 1}, {0}, {min_mn}, {1}, {}); ctx, trans_q, {trans_q.dims().size() - 1}, {0}, {min_mn});
phi::Copy(ctx, sliced_q, q->place(), false, q); phi::Copy(ctx, sliced_q, q->place(), false, q);
} else { } else {
if (m > n) { if (m > n) {
...@@ -170,8 +170,8 @@ void QrKernel(const Context& ctx, ...@@ -170,8 +170,8 @@ void QrKernel(const Context& ctx,
qr_stride, qr_stride,
tau_stride); tau_stride);
auto trans_q = TransposeLast2Dim<T, Context>(ctx, qr); auto trans_q = TransposeLast2Dim<T, Context>(ctx, qr);
auto sliced_q = SliceKernel<T, Context>( auto sliced_q = Slice<T, Context>(
ctx, trans_q, {trans_q.dims().size() - 1}, {0}, {m}, {1}, {}); ctx, trans_q, {trans_q.dims().size() - 1}, {0}, {m});
phi::Copy(ctx, sliced_q, q->place(), false, q); phi::Copy(ctx, sliced_q, q->place(), false, q);
} }
} }
......
...@@ -21,7 +21,7 @@ ...@@ -21,7 +21,7 @@
PD_REGISTER_KERNEL(slice, PD_REGISTER_KERNEL(slice,
GPU, GPU,
ALL_LAYOUT, ALL_LAYOUT,
phi::SliceRawKernel, phi::SliceKernel,
bool, bool,
uint8_t, uint8_t,
int, int,
......
...@@ -149,17 +149,13 @@ void QrGradKernel(const Context& ctx, ...@@ -149,17 +149,13 @@ void QrGradKernel(const Context& ctx,
// Calculate dX and dY individually and concatenate them to get dA // Calculate dX and dY individually and concatenate them to get dA
ctx.template Alloc<phi::dtype::Real<T>>(&dA); ctx.template Alloc<phi::dtype::Real<T>>(&dA);
auto Y = SliceKernel<T, Context>( auto Y = Slice<T, Context>(ctx, A, {A.dims().size() - 1}, {m}, {n});
ctx, A, {A.dims().size() - 1}, {m}, {n}, {1}, {}); auto U = Slice<T, Context>(ctx, R, {R.dims().size() - 1}, {0}, {m});
auto U = SliceKernel<T, Context>(
ctx, R, {R.dims().size() - 1}, {0}, {m}, {1}, {});
DenseTensor dY, dX, dV, dR_tmp, dQ_prime; DenseTensor dY, dX, dV, dR_tmp, dQ_prime;
if (dR.initialized()) { if (dR.initialized()) {
dV = SliceKernel<T, Context>( dV = Slice<T, Context>(ctx, dR, {dR.dims().size() - 1}, {m}, {n});
ctx, dR, {dR.dims().size() - 1}, {m}, {n}, {1}, {}); dR_tmp = Slice<T, Context>(ctx, dR, {dR.dims().size() - 1}, {0}, {m});
dR_tmp = SliceKernel<T, Context>(
ctx, dR, {dR.dims().size() - 1}, {0}, {m}, {1}, {});
// Y * dV^H // Y * dV^H
dQ_prime = dQ_prime =
Matmul<T, Context>(ctx, Y, TransposeLast2Dim<T, Context>(ctx, dV)); Matmul<T, Context>(ctx, Y, TransposeLast2Dim<T, Context>(ctx, dV));
......
...@@ -100,7 +100,7 @@ void SliceCompute(const Context& ctx, ...@@ -100,7 +100,7 @@ void SliceCompute(const Context& ctx,
} }
template <typename T, typename Context> template <typename T, typename Context>
void SliceRawKernel(const Context& ctx, void SliceKernel(const Context& ctx,
const DenseTensor& input, const DenseTensor& input,
const std::vector<int64_t>& axes, const std::vector<int64_t>& axes,
const IntArray& starts_arr, const IntArray& starts_arr,
......
...@@ -83,27 +83,17 @@ void SvdGradKernel(const Context& dev_ctx, ...@@ -83,27 +83,17 @@ void SvdGradKernel(const Context& dev_ctx,
DenseTensor U, VH, dU, dV, dVH; DenseTensor U, VH, dU, dV, dVH;
if (full_matrices) { if (full_matrices) {
// if full_matrices is set, slice the U and VT to k columns // if full_matrices is set, slice the U and VT to k columns
U = SliceKernel<T, Context>( U = Slice<T, Context>(dev_ctx, u, {u.dims().size() - 1}, {0}, {k});
dev_ctx, u, {u.dims().size() - 1}, {0}, {k}, {1}, {}); // If m < n for input matrices A, we partition A = [X|Y] and R = [U|V]
VH = SliceKernel<T, Context>(
dev_ctx, vh, {vh.dims().size() - 2}, {0}, {k}, {1}, {}); VH = Slice<T, Context>(dev_ctx, vh, {vh.dims().size() - 2}, {0}, {k});
if (u_grad.get_ptr() != nullptr) { if (u_grad.get_ptr() != nullptr) {
dU = SliceKernel<T, Context>(dev_ctx, dU = Slice<T, Context>(
*(u_grad.get_ptr()), dev_ctx, *(u_grad.get_ptr()), {u.dims().size() - 1}, {0}, {k});
{u.dims().size() - 1},
{0},
{k},
{1},
{});
} }
if (vh_grad.get_ptr() != nullptr) { if (vh_grad.get_ptr() != nullptr) {
dVH = SliceKernel<T, Context>(dev_ctx, dVH = Slice<T, Context>(
*(vh_grad.get_ptr()), dev_ctx, *(vh_grad.get_ptr()), {vh.dims().size() - 2}, {0}, {k});
{vh.dims().size() - 2},
{0},
{k},
{1},
{});
} }
} else { } else {
U = u; U = u;
......
...@@ -20,7 +20,7 @@ ...@@ -20,7 +20,7 @@
namespace phi { namespace phi {
template <typename T, typename Context> template <typename T, typename Context>
void SliceRawKernel(const Context& dev_ctx, void SliceKernel(const Context& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
const std::vector<int64_t>& axes, const std::vector<int64_t>& axes,
const IntArray& starts, const IntArray& starts,
...@@ -102,7 +102,7 @@ void SliceRawKernel(const Context& dev_ctx, ...@@ -102,7 +102,7 @@ void SliceRawKernel(const Context& dev_ctx,
PD_REGISTER_KERNEL(slice, PD_REGISTER_KERNEL(slice,
OneDNN, OneDNN,
ONEDNN, ONEDNN,
phi::SliceRawKernel, phi::SliceKernel,
float, float,
int8_t, int8_t,
uint8_t, uint8_t,
......
...@@ -22,7 +22,7 @@ ...@@ -22,7 +22,7 @@
namespace phi { namespace phi {
template <typename T, typename Context> template <typename T, typename Context>
void SliceRawKernel(const Context& ctx, void SliceKernel(const Context& ctx,
const DenseTensor& input, const DenseTensor& input,
const std::vector<int64_t>& axes, const std::vector<int64_t>& axes,
const IntArray& starts, const IntArray& starts,
...@@ -45,18 +45,18 @@ void SliceArrayDenseKernel(const Context& dev_ctx, ...@@ -45,18 +45,18 @@ void SliceArrayDenseKernel(const Context& dev_ctx,
DenseTensor* out); DenseTensor* out);
template <typename T, typename Context> template <typename T, typename Context>
DenseTensor SliceKernel(const Context& ctx, DenseTensor Slice(const Context& ctx,
const DenseTensor& input, const DenseTensor& input,
const std::vector<int64_t>& axes, const std::vector<int64_t>& axes,
const IntArray& starts, const IntArray& starts,
const IntArray& ends, const IntArray& ends) {
const std::vector<int64_t>& infer_flags,
const std::vector<int64_t>& decrease_axis) {
DenseTensor dense_out; DenseTensor dense_out;
MetaTensor meta_out(&dense_out); MetaTensor meta_out(&dense_out);
std::vector<int64_t> infer_flags = {1};
std::vector<int64_t> decrease_axis = {};
SliceRawInferMeta( SliceRawInferMeta(
input, axes, starts, ends, infer_flags, decrease_axis, &meta_out); input, axes, starts, ends, infer_flags, decrease_axis, &meta_out);
SliceRawKernel<T, Context>( SliceKernel<T, Context>(
ctx, input, axes, starts, ends, infer_flags, decrease_axis, &dense_out); ctx, input, axes, starts, ends, infer_flags, decrease_axis, &dense_out);
return dense_out; return dense_out;
} }
......
...@@ -21,7 +21,7 @@ ...@@ -21,7 +21,7 @@
namespace phi { namespace phi {
template <typename T, typename Context> template <typename T, typename Context>
void SliceRawKernel(const Context& ctx, void SliceKernel(const Context& ctx,
const DenseTensor& input, const DenseTensor& input,
const std::vector<int64_t>& axes, const std::vector<int64_t>& axes,
const IntArray& starts_t, const IntArray& starts_t,
...@@ -110,7 +110,7 @@ void SliceRawKernel(const Context& ctx, ...@@ -110,7 +110,7 @@ void SliceRawKernel(const Context& ctx,
PD_REGISTER_KERNEL(slice, PD_REGISTER_KERNEL(slice,
XPU, XPU,
ALL_LAYOUT, ALL_LAYOUT,
phi::SliceRawKernel, phi::SliceKernel,
float, float,
int, int,
phi::dtype::float16, phi::dtype::float16,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册