Unverified commit d2b0d63f authored by zhangyuqin1998, committed by GitHub

rename_SliceKernel (#52863)

Parent 514d83de
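The commit title is terse; concretely, the hunks below rename the device-level kernels phi::SliceRawKernel and phi::AccuracyRawKernel to phi::SliceKernel and phi::AccuracyKernel, and rename the DenseTensor-returning helper in the slice header from SliceKernel to Slice, which now fills in infer_flags and decrease_axis itself. A minimal before/after sketch of a call site (the tensor x and bound k are illustrative, not from the commit):

// Before this commit: the functional helper was named SliceKernel and took the raw arguments.
//   auto y = phi::SliceKernel<T, Context>(ctx, x, {0}, {0}, {k}, {1}, {});
// After this commit: the helper is phi::Slice; infer_flags and decrease_axis are defaulted inside it.
//   auto y = phi::Slice<T, Context>(ctx, x, {0}, {0}, {k});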
@@ -20,7 +20,7 @@
namespace phi {
template <typename T, typename Context>
-void AccuracyRawKernel(const Context& dev_ctx,
+void AccuracyKernel(const Context& dev_ctx,
const DenseTensor& out,
const DenseTensor& indices,
const DenseTensor& label,
......
@@ -21,7 +21,7 @@
namespace phi {
template <typename T, typename Context>
-void AccuracyRawKernel(const Context& dev_ctx,
+void AccuracyKernel(const Context& dev_ctx,
const DenseTensor& inference,
const DenseTensor& indices,
const DenseTensor& label,
@@ -93,7 +93,7 @@ void AccuracyRawKernel(const Context& dev_ctx,
// TODO(add supported dtype.)
PD_REGISTER_KERNEL(
-accuracy, CPU, ALL_LAYOUT, phi::AccuracyRawKernel, float, double) {
+accuracy, CPU, ALL_LAYOUT, phi::AccuracyKernel, float, double) {
kernel->InputAt(1).SetDataType(phi::DataType::INT64);
kernel->InputAt(2).SetDataType(phi::DataType::INT64);
kernel->OutputAt(0).SetDataType(phi::DataType::FLOAT32);
......
@@ -21,7 +21,7 @@
PD_REGISTER_KERNEL(slice,
CPU,
ALL_LAYOUT,
-phi::SliceRawKernel,
+phi::SliceKernel,
bool,
uint8_t,
int,
......
@@ -73,7 +73,7 @@ __global__ void AccuracyCudaKernel(const int N,
}
template <typename T, typename Context>
-void AccuracyRawKernel(const Context& dev_ctx,
+void AccuracyKernel(const Context& dev_ctx,
const DenseTensor& inference,
const DenseTensor& indices,
const DenseTensor& label,
@@ -137,7 +137,7 @@ void AccuracyRawKernel(const Context& dev_ctx,
PD_REGISTER_KERNEL(accuracy,
GPU,
ALL_LAYOUT,
-phi::AccuracyRawKernel,
+phi::AccuracyKernel,
phi::dtype::float16,
phi::dtype::bfloat16,
float,
......
@@ -101,8 +101,8 @@ void QrKernel(const Context& ctx,
if (reduced_mode) {
auto trans_qr = TransposeLast2Dim<T, Context>(ctx, qr);
-auto sliced_qr = SliceKernel<T, Context>(
-ctx, trans_qr, {trans_qr.dims().size() - 2}, {0}, {min_mn}, {1}, {});
+auto sliced_qr = Slice<T, Context>(
+ctx, trans_qr, {trans_qr.dims().size() - 2}, {0}, {min_mn});
auto tmp_r = TrilTriu<T, Context>(ctx, sliced_qr, 0, false);
// Transpose 'tmp_r' to retore the original row-major order
phi::Copy(ctx, tmp_r, r->place(), false, r);
@@ -128,8 +128,8 @@ void QrKernel(const Context& ctx,
qr_stride,
tau_stride);
auto trans_q = TransposeLast2Dim<T, Context>(ctx, qr);
-auto sliced_q = SliceKernel<T, Context>(
-ctx, trans_q, {trans_q.dims().size() - 1}, {0}, {min_mn}, {1}, {});
+auto sliced_q = Slice<T, Context>(
+ctx, trans_q, {trans_q.dims().size() - 1}, {0}, {min_mn});
phi::Copy(ctx, sliced_q, q->place(), false, q);
} else {
if (m > n) {
@@ -170,8 +170,8 @@ void QrKernel(const Context& ctx,
qr_stride,
tau_stride);
auto trans_q = TransposeLast2Dim<T, Context>(ctx, qr);
-auto sliced_q = SliceKernel<T, Context>(
-ctx, trans_q, {trans_q.dims().size() - 1}, {0}, {m}, {1}, {});
+auto sliced_q = Slice<T, Context>(
+ctx, trans_q, {trans_q.dims().size() - 1}, {0}, {m});
phi::Copy(ctx, sliced_q, q->place(), false, q);
}
}
......
@@ -21,7 +21,7 @@
PD_REGISTER_KERNEL(slice,
GPU,
ALL_LAYOUT,
-phi::SliceRawKernel,
+phi::SliceKernel,
bool,
uint8_t,
int,
......
@@ -149,17 +149,13 @@ void QrGradKernel(const Context& ctx,
// Calculate dX and dY individually and concatenate them to get dA
ctx.template Alloc<phi::dtype::Real<T>>(&dA);
-auto Y = SliceKernel<T, Context>(
-ctx, A, {A.dims().size() - 1}, {m}, {n}, {1}, {});
-auto U = SliceKernel<T, Context>(
-ctx, R, {R.dims().size() - 1}, {0}, {m}, {1}, {});
+auto Y = Slice<T, Context>(ctx, A, {A.dims().size() - 1}, {m}, {n});
+auto U = Slice<T, Context>(ctx, R, {R.dims().size() - 1}, {0}, {m});
DenseTensor dY, dX, dV, dR_tmp, dQ_prime;
if (dR.initialized()) {
-dV = SliceKernel<T, Context>(
-ctx, dR, {dR.dims().size() - 1}, {m}, {n}, {1}, {});
-dR_tmp = SliceKernel<T, Context>(
-ctx, dR, {dR.dims().size() - 1}, {0}, {m}, {1}, {});
+dV = Slice<T, Context>(ctx, dR, {dR.dims().size() - 1}, {m}, {n});
+dR_tmp = Slice<T, Context>(ctx, dR, {dR.dims().size() - 1}, {0}, {m});
// Y * dV^H
dQ_prime =
Matmul<T, Context>(ctx, Y, TransposeLast2Dim<T, Context>(ctx, dV));
......
@@ -100,7 +100,7 @@ void SliceCompute(const Context& ctx,
}
template <typename T, typename Context>
-void SliceRawKernel(const Context& ctx,
+void SliceKernel(const Context& ctx,
const DenseTensor& input,
const std::vector<int64_t>& axes,
const IntArray& starts_arr,
......
@@ -83,27 +83,17 @@ void SvdGradKernel(const Context& dev_ctx,
DenseTensor U, VH, dU, dV, dVH;
if (full_matrices) {
// if full_matrices is set, slice the U and VT to k columns
-U = SliceKernel<T, Context>(
-dev_ctx, u, {u.dims().size() - 1}, {0}, {k}, {1}, {});
-VH = SliceKernel<T, Context>(
-dev_ctx, vh, {vh.dims().size() - 2}, {0}, {k}, {1}, {});
+U = Slice<T, Context>(dev_ctx, u, {u.dims().size() - 1}, {0}, {k});
+// If m < n for input matrices A, we partition A = [X|Y] and R = [U|V]
+VH = Slice<T, Context>(dev_ctx, vh, {vh.dims().size() - 2}, {0}, {k});
if (u_grad.get_ptr() != nullptr) {
-dU = SliceKernel<T, Context>(dev_ctx,
-*(u_grad.get_ptr()),
-{u.dims().size() - 1},
-{0},
-{k},
-{1},
-{});
+dU = Slice<T, Context>(
+dev_ctx, *(u_grad.get_ptr()), {u.dims().size() - 1}, {0}, {k});
}
if (vh_grad.get_ptr() != nullptr) {
-dVH = SliceKernel<T, Context>(dev_ctx,
-*(vh_grad.get_ptr()),
-{vh.dims().size() - 2},
-{0},
-{k},
-{1},
-{});
+dVH = Slice<T, Context>(
+dev_ctx, *(vh_grad.get_ptr()), {vh.dims().size() - 2}, {0}, {k});
}
} else {
U = u;
......
@@ -20,7 +20,7 @@
namespace phi {
template <typename T, typename Context>
-void SliceRawKernel(const Context& dev_ctx,
+void SliceKernel(const Context& dev_ctx,
const DenseTensor& x,
const std::vector<int64_t>& axes,
const IntArray& starts,
@@ -102,7 +102,7 @@ void SliceRawKernel(const Context& dev_ctx,
PD_REGISTER_KERNEL(slice,
OneDNN,
ONEDNN,
-phi::SliceRawKernel,
+phi::SliceKernel,
float,
int8_t,
uint8_t,
......
@@ -22,7 +22,7 @@
namespace phi {
template <typename T, typename Context>
-void SliceRawKernel(const Context& ctx,
+void SliceKernel(const Context& ctx,
const DenseTensor& input,
const std::vector<int64_t>& axes,
const IntArray& starts,
@@ -45,18 +45,18 @@ void SliceArrayDenseKernel(const Context& dev_ctx,
DenseTensor* out);
template <typename T, typename Context>
-DenseTensor SliceKernel(const Context& ctx,
+DenseTensor Slice(const Context& ctx,
const DenseTensor& input,
const std::vector<int64_t>& axes,
const IntArray& starts,
-const IntArray& ends,
-const std::vector<int64_t>& infer_flags,
-const std::vector<int64_t>& decrease_axis) {
+const IntArray& ends) {
DenseTensor dense_out;
MetaTensor meta_out(&dense_out);
+std::vector<int64_t> infer_flags = {1};
+std::vector<int64_t> decrease_axis = {};
SliceRawInferMeta(
input, axes, starts, ends, infer_flags, decrease_axis, &meta_out);
-SliceRawKernel<T, Context>(
+SliceKernel<T, Context>(
ctx, input, axes, starts, ends, infer_flags, decrease_axis, &dense_out);
return dense_out;
}
......
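The wrapper above is the only signature change visible to callers; a minimal usage sketch, assuming the slice kernel header from the hunk above is included, with an illustrative function name (TakeFirstKColumns) and tensor names that are not part of the commit:

namespace example {  // illustrative namespace, not part of the commit

template <typename T, typename Context>
phi::DenseTensor TakeFirstKColumns(const Context& ctx,
                                   const phi::DenseTensor& x,
                                   int64_t k) {
  // Convenience form after the rename: Slice supplies infer_flags = {1}
  // and decrease_axis = {} internally, so only axes/starts/ends are passed.
  phi::DenseTensor cols =
      phi::Slice<T, Context>(ctx, x, {x.dims().size() - 1}, {0}, {k});

  // Equivalent call through the renamed device kernel, spelling out the raw
  // arguments that the wrapper defaults.
  phi::DenseTensor out;
  phi::SliceKernel<T, Context>(
      ctx, x, {x.dims().size() - 1}, {0}, {k}, {1}, {}, &out);

  return cols;
}

}  // namespace example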
@@ -21,7 +21,7 @@
namespace phi {
template <typename T, typename Context>
-void SliceRawKernel(const Context& ctx,
+void SliceKernel(const Context& ctx,
const DenseTensor& input,
const std::vector<int64_t>& axes,
const IntArray& starts_t,
@@ -110,7 +110,7 @@ void SliceRawKernel(const Context& ctx,
PD_REGISTER_KERNEL(slice,
XPU,
ALL_LAYOUT,
-phi::SliceRawKernel,
+phi::SliceKernel,
float,
int,
phi::dtype::float16,
......