未验证 提交 30f5e39b 编写于 作者: Y YuanRisheng 提交者: GitHub

[PHI]Rename some PHI Kernel (#49470)

* rename kernel

* delete sig

* modify code according comment

* fix ci bugs
上级 280677c5
......@@ -424,7 +424,7 @@ class ReshapeKernel {
}
if (platform::is_cpu_place(ctx.GetPlace())) {
auto &dev_ctx = ctx.device_context<phi::CPUContext>();
phi::ReshapeKernel(static_cast<const phi::CPUContext &>(dev_ctx),
phi::ReshapeInferKernel(static_cast<const phi::CPUContext &>(dev_ctx),
*in,
pt_scalar_shape,
out);
......@@ -432,7 +432,7 @@ class ReshapeKernel {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (platform::is_gpu_place(ctx.GetPlace())) {
auto &dev_ctx = ctx.device_context<phi::GPUContext>();
phi::ReshapeKernel(static_cast<const phi::GPUContext &>(dev_ctx),
phi::ReshapeInferKernel(static_cast<const phi::GPUContext &>(dev_ctx),
*in,
pt_scalar_shape,
out);
......@@ -441,7 +441,7 @@ class ReshapeKernel {
#ifdef PADDLE_WITH_XPU
if (platform::is_xpu_place(ctx.GetPlace())) {
auto &dev_ctx = ctx.device_context<platform::XPUDeviceContext>();
phi::ReshapeKernel(static_cast<const phi::XPUContext &>(dev_ctx),
phi::ReshapeInferKernel(static_cast<const phi::XPUContext &>(dev_ctx),
*in,
pt_scalar_shape,
out);
......
......@@ -561,7 +561,7 @@
func : EinsumRawInferMeta
param : [x, equation]
kernel :
func : einsum_raw
func : einsum
backward : einsum_grad
- op : elementwise_pow
......@@ -677,7 +677,7 @@
infer_meta :
func : FlattenWithXShapeInferMeta
kernel :
func : flatten_with_xshape
func : flatten
backend : x
inplace : (x -> out)
view : (x -> out)
......@@ -1391,7 +1391,7 @@
infer_meta :
func : ReduceIntArrayAxisInferMetaBase
kernel :
func : prod_raw
func : prod
backward : prod_grad
- op : psroi_pool
......@@ -1473,7 +1473,7 @@
infer_meta :
func : ReshapeWithXShapeInferMeta
kernel :
func : reshape_with_xshape
func : reshape
inplace : (x -> out)
view: (x -> out)
intermediate : xshape
......
......@@ -1150,7 +1150,7 @@
infer_meta :
func : SqueezeWithXShapeInferMeta
kernel :
func : squeeze_with_xshape
func : squeeze
data_type : x
inplace : (x -> out)
view: (x -> out)
......@@ -1258,7 +1258,7 @@
infer_meta :
func : UnsqueezeWithXShapeInferMeta
kernel :
func : unsqueeze_with_xshape
func : unsqueeze
data_type : x
inplace : (x -> out)
view: (x -> out)
......
......@@ -270,7 +270,7 @@ XPUOpMap& get_kl2_ops() {
phi::DataType::INT8,
phi::DataType::FLOAT16,
phi::DataType::FLOAT32})},
{"flatten_with_xshape",
{"flatten",
XPUKernelSet({phi::DataType::INT64,
phi::DataType::INT32,
phi::DataType::INT8,
......@@ -450,7 +450,7 @@ XPUOpMap& get_kl2_ops() {
phi::DataType::INT32,
phi::DataType::BOOL,
phi::DataType::FLOAT32})},
{"reshape_with_xshape",
{"reshape",
XPUKernelSet({phi::DataType::FLOAT64,
phi::DataType::FLOAT16,
phi::DataType::INT64,
......@@ -541,7 +541,7 @@ XPUOpMap& get_kl2_ops() {
phi::DataType::INT8,
phi::DataType::UINT8,
phi::DataType::FLOAT32})},
{"squeeze_with_xshape",
{"squeeze",
XPUKernelSet({phi::DataType::FLOAT64,
phi::DataType::INT64,
phi::DataType::INT32,
......@@ -655,7 +655,7 @@ XPUOpMap& get_kl2_ops() {
phi::DataType::UINT8,
phi::DataType::FLOAT16,
phi::DataType::FLOAT32})},
{"unsqueeze_with_xshape",
{"unsqueeze",
XPUKernelSet({phi::DataType::FLOAT64,
phi::DataType::INT64,
phi::DataType::INT32,
......
......@@ -18,19 +18,19 @@
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/einsum_impl.h"
PD_REGISTER_KERNEL(einsum_raw,
PD_REGISTER_KERNEL(einsum,
CPU,
ALL_LAYOUT,
phi::EinsumKernelRaw,
phi::EinsumKernel,
float,
double,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
PD_REGISTER_KERNEL(einsum,
PD_REGISTER_KERNEL(einsum_infer,
CPU,
ALL_LAYOUT,
phi::EinsumKernel,
phi::EinsumInferKernel,
float,
double,
phi::dtype::complex<float>,
......
......@@ -22,7 +22,7 @@
namespace phi {
template <typename T, typename Context>
void ProdRawKernel(const Context& dev_ctx,
void ProdKernel(const Context& dev_ctx,
const DenseTensor& x,
const IntArray& dims,
bool keep_dim,
......@@ -36,11 +36,5 @@ void ProdRawKernel(const Context& dev_ctx,
} // namespace phi
PD_REGISTER_KERNEL(prod_raw,
CPU,
ALL_LAYOUT,
phi::ProdRawKernel,
float,
double,
int,
int64_t) {}
PD_REGISTER_KERNEL(
prod, CPU, ALL_LAYOUT, phi::ProdKernel, float, double, int, int64_t) {}
......@@ -19,13 +19,13 @@
namespace phi {
template <typename T, typename Context>
void EinsumKernel(const Context& dev_ctx,
void EinsumInferKernel(const Context& dev_ctx,
const std::vector<const DenseTensor*>& inputs,
const std::string& equation,
DenseTensor* out);
template <typename T, typename Context>
void EinsumKernelRaw(const Context& dev_ctx,
void EinsumKernel(const Context& dev_ctx,
const std::vector<const DenseTensor*>& inputs,
const std::string& equation,
DenseTensor* out,
......
......@@ -23,7 +23,7 @@
namespace phi {
template <typename T, typename Context>
void FlattenKernel(const Context& dev_ctx,
void FlattenInferKernel(const Context& dev_ctx,
const DenseTensor& x,
int start_axis,
int stop_axis,
......@@ -38,21 +38,21 @@ void FlattenKernel(const Context& dev_ctx,
// Output Tensor,
// is there a more flexible way to deal with this case?
template <typename T, typename Context>
void FlattenWithXShape(const Context& dev_ctx,
void FlattenKernel(const Context& dev_ctx,
const DenseTensor& x,
int start_axis,
int stop_axis,
DenseTensor* out,
DenseTensor* xshape) {
FlattenKernel<T, Context>(dev_ctx, x, start_axis, stop_axis, out);
FlattenInferKernel<T, Context>(dev_ctx, x, start_axis, stop_axis, out);
}
} // namespace phi
PD_REGISTER_KERNEL(flatten,
PD_REGISTER_KERNEL(flatten_infer,
CPU,
ALL_LAYOUT,
phi::FlattenKernel,
phi::FlattenInferKernel,
float,
phi::dtype::bfloat16,
double,
......@@ -62,10 +62,10 @@ PD_REGISTER_KERNEL(flatten,
int,
int64_t) {}
PD_REGISTER_KERNEL(flatten_with_xshape,
PD_REGISTER_KERNEL(flatten,
CPU,
ALL_LAYOUT,
phi::FlattenWithXShape,
phi::FlattenKernel,
float,
phi::dtype::bfloat16,
double,
......@@ -76,10 +76,10 @@ PD_REGISTER_KERNEL(flatten_with_xshape,
int64_t) {}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PD_REGISTER_KERNEL(flatten,
PD_REGISTER_KERNEL(flatten_infer,
GPU,
ALL_LAYOUT,
phi::FlattenKernel,
phi::FlattenInferKernel,
float,
phi::dtype::float16,
phi::dtype::bfloat16,
......@@ -90,10 +90,10 @@ PD_REGISTER_KERNEL(flatten,
int,
int64_t) {}
PD_REGISTER_KERNEL(flatten_with_xshape,
PD_REGISTER_KERNEL(flatten,
GPU,
ALL_LAYOUT,
phi::FlattenWithXShape,
phi::FlattenKernel,
float,
phi::dtype::float16,
phi::dtype::bfloat16,
......@@ -106,10 +106,10 @@ PD_REGISTER_KERNEL(flatten_with_xshape,
#endif
#ifdef PADDLE_WITH_XPU
PD_REGISTER_KERNEL(flatten,
PD_REGISTER_KERNEL(flatten_infer,
XPU,
ALL_LAYOUT,
phi::FlattenKernel,
phi::FlattenInferKernel,
float,
phi::dtype::float16,
int8_t,
......@@ -117,10 +117,10 @@ PD_REGISTER_KERNEL(flatten,
int,
int64_t) {}
PD_REGISTER_KERNEL(flatten_with_xshape,
PD_REGISTER_KERNEL(flatten,
XPU,
ALL_LAYOUT,
phi::FlattenWithXShape,
phi::FlattenKernel,
float,
phi::dtype::float16,
int8_t,
......
......@@ -21,14 +21,14 @@ limitations under the License. */
namespace phi {
template <typename T, typename Context>
void FlattenKernel(const Context& dev_ctx,
void FlattenInferKernel(const Context& dev_ctx,
const DenseTensor& x,
int start_axis,
int stop_axis,
DenseTensor* out);
template <typename T, typename Context>
void FlattenWithXShape(const Context& dev_ctx,
void FlattenKernel(const Context& dev_ctx,
const DenseTensor& x,
int start_axis,
int stop_axis,
......@@ -43,7 +43,7 @@ DenseTensor Flatten(const Context& dev_ctx,
DenseTensor dense_out;
MetaTensor meta_out(&dense_out);
FlattenInferMeta(x, start_axis, stop_axis, &meta_out);
FlattenKernel<T, Context>(dev_ctx, x, start_axis, stop_axis, &dense_out);
FlattenInferKernel<T, Context>(dev_ctx, x, start_axis, stop_axis, &dense_out);
return dense_out;
}
......
......@@ -18,10 +18,10 @@
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/einsum_impl.h"
PD_REGISTER_KERNEL(einsum_raw,
PD_REGISTER_KERNEL(einsum,
GPU,
ALL_LAYOUT,
phi::EinsumKernelRaw,
phi::EinsumKernel,
float,
double,
phi::dtype::float16,
......@@ -29,10 +29,10 @@ PD_REGISTER_KERNEL(einsum_raw,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
PD_REGISTER_KERNEL(einsum,
PD_REGISTER_KERNEL(einsum_infer,
GPU,
ALL_LAYOUT,
phi::EinsumKernel,
phi::EinsumInferKernel,
float,
double,
phi::dtype::float16,
......
......@@ -103,7 +103,7 @@ DenseTensor PerformTileAndReduction(const Context& dev_ctx,
// undiagonalize by einsum equation. only contain undiagonal operations.
DenseTensor out;
VLOG(5) << "Undiagonal by einsum with args: " << op_label + "->" + equ;
EinsumKernel<T, Context>(dev_ctx, {&ret}, op_label + "->" + equ, &out);
EinsumInferKernel<T, Context>(dev_ctx, {&ret}, op_label + "->" + equ, &out);
return out;
}
......@@ -157,7 +157,8 @@ void EinsumGradKernel(const Context& dev_ctx,
new_operands.push_back(&out_grad);
DenseTensor before_tile;
VLOG(5) << "new_equation is " << new_equation;
EinsumKernel<T, Context>(dev_ctx, new_operands, new_equation, &before_tile);
EinsumInferKernel<T, Context>(
dev_ctx, new_operands, new_equation, &before_tile);
*(x_grad[0]) = PerformTileAndReduction<T, Context>(dev_ctx,
labeltype,
labelshape,
......
......@@ -746,7 +746,7 @@ void EinsumKernelImpl(const Context& dev_ctx,
}
template <typename T, typename Context>
void EinsumKernelRaw(const Context& dev_ctx,
void EinsumKernel(const Context& dev_ctx,
const std::vector<const DenseTensor*>& inputs,
const std::string& equation,
DenseTensor* out,
......@@ -765,7 +765,7 @@ void EinsumKernelRaw(const Context& dev_ctx,
}
template <typename T, typename Context>
void EinsumKernel(const Context& dev_ctx,
void EinsumInferKernel(const Context& dev_ctx,
const std::vector<const DenseTensor*>& inputs,
const std::string& equation,
DenseTensor* out) {
......
......@@ -169,7 +169,7 @@ static void linalg_solve(const Context& dev_ctx,
out_tmp.Resize(out->dims());
out_tmp = *out;
phi::SqueezeKernel<T, Context>(dev_ctx, out_tmp, {-1}, out);
phi::SqueezeInferKernel<T, Context>(dev_ctx, out_tmp, {-1}, out);
} else {
PADDLE_ENFORCE_EQ(
x_dim[x_dim_size - 1],
......
......@@ -19,7 +19,7 @@
namespace phi {
template <typename T, typename Context>
void ProdRawKernel(const Context& dev_ctx,
void ProdKernel(const Context& dev_ctx,
const DenseTensor& x,
const IntArray& dims,
bool keep_dim,
......@@ -33,14 +33,8 @@ void ProdRawKernel(const Context& dev_ctx,
} // namespace phi
#ifdef PADDLE_WITH_XPU_KP
PD_REGISTER_KERNEL(prod_raw, KPS, ALL_LAYOUT, phi::ProdRawKernel, float) {}
PD_REGISTER_KERNEL(prod, KPS, ALL_LAYOUT, phi::ProdKernel, float) {}
#else
PD_REGISTER_KERNEL(prod_raw,
KPS,
ALL_LAYOUT,
phi::ProdRawKernel,
float,
double,
int,
int64_t) {}
PD_REGISTER_KERNEL(
prod, KPS, ALL_LAYOUT, phi::ProdKernel, float, double, int, int64_t) {}
#endif
......@@ -148,7 +148,7 @@ void ExecuteReshape(const Context& dev_ctx,
}
template <typename T, typename Context>
void ReshapeKernel(const Context& dev_ctx,
void ReshapeInferKernel(const Context& dev_ctx,
const DenseTensor& x,
const IntArray& shape,
DenseTensor* out) {
......@@ -157,7 +157,7 @@ void ReshapeKernel(const Context& dev_ctx,
}
template <typename T, typename Context>
void ReshapeWithXShape(const Context& dev_ctx,
void ReshapeKernel(const Context& dev_ctx,
const DenseTensor& x,
const IntArray& shape,
DenseTensor* out,
......@@ -168,12 +168,12 @@ void ReshapeWithXShape(const Context& dev_ctx,
} // namespace phi
PD_REGISTER_KERNEL(
reshape, OneDNN, ONEDNN, phi::ReshapeKernel, float, phi::dtype::bfloat16) {}
PD_REGISTER_KERNEL(reshape_with_xshape,
PD_REGISTER_KERNEL(reshape_infer,
OneDNN,
ONEDNN,
phi::ReshapeWithXShape,
phi::ReshapeInferKernel,
float,
phi::dtype::bfloat16) {}
PD_REGISTER_KERNEL(
reshape, OneDNN, ONEDNN, phi::ReshapeKernel, float, phi::dtype::bfloat16) {}
......@@ -52,7 +52,7 @@ void ExecuteSqueeze(const Context& dev_ctx,
}
template <typename T, typename Context>
void SqueezeKernel(const Context& dev_ctx,
void SqueezeInferKernel(const Context& dev_ctx,
const DenseTensor& x,
const IntArray& axes,
DenseTensor* out) {
......@@ -63,13 +63,13 @@ void SqueezeKernel(const Context& dev_ctx,
}
template <typename T, typename Context>
void SqueezeWithXShapeKernel(const Context& dev_ctx,
void SqueezeKernel(const Context& dev_ctx,
const DenseTensor& x,
const IntArray& axes,
DenseTensor* out,
DenseTensor* xshape) {
if (xshape == nullptr) {
SqueezeKernel<T, Context>(dev_ctx, x, axes, out);
SqueezeInferKernel<T, Context>(dev_ctx, x, axes, out);
} else {
auto x_dims = slice_ddim(xshape->dims(), 1, xshape->dims().size());
auto out_dims = out->dims();
......@@ -78,12 +78,12 @@ void SqueezeWithXShapeKernel(const Context& dev_ctx,
}
} // namespace phi
PD_REGISTER_KERNEL(
squeeze, OneDNN, ONEDNN, phi::SqueezeKernel, float, phi::dtype::bfloat16) {}
PD_REGISTER_KERNEL(squeeze_with_xshape,
PD_REGISTER_KERNEL(squeeze_infer,
OneDNN,
ONEDNN,
phi::SqueezeWithXShapeKernel,
phi::SqueezeInferKernel,
float,
phi::dtype::bfloat16) {}
PD_REGISTER_KERNEL(
squeeze, OneDNN, ONEDNN, phi::SqueezeKernel, float, phi::dtype::bfloat16) {}
......@@ -20,29 +20,41 @@
namespace phi {
template <typename T, typename Context>
void ProdKernel(const Context& dev_ctx,
void ProdInferKernel(const Context& dev_ctx,
const DenseTensor& x,
const IntArray& dims,
bool keep_dim,
DenseTensor* out) {
bool reduce_all = recompute_reduce_all(x, dims);
ProdRawKernel<T>(dev_ctx, x, dims, keep_dim, reduce_all, out);
ProdKernel<T>(dev_ctx, x, dims, keep_dim, reduce_all, out);
}
} // namespace phi
PD_REGISTER_KERNEL(
prod, CPU, ALL_LAYOUT, phi::ProdKernel, float, double, int, int64_t) {}
PD_REGISTER_KERNEL(prod_infer,
CPU,
ALL_LAYOUT,
phi::ProdInferKernel,
float,
double,
int,
int64_t) {}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PD_REGISTER_KERNEL(
prod, GPU, ALL_LAYOUT, phi::ProdKernel, float, double, int, int64_t) {}
PD_REGISTER_KERNEL(prod_infer,
GPU,
ALL_LAYOUT,
phi::ProdInferKernel,
float,
double,
int,
int64_t) {}
#endif
#if defined(PADDLE_WITH_XPU_KP) && !defined(PADDLE_WITH_XPU)
PD_REGISTER_KERNEL(prod, KPS, ALL_LAYOUT, phi::ProdKernel, float) {}
PD_REGISTER_KERNEL(prod_infer, KPS, ALL_LAYOUT, phi::ProdInferKernel, float) {}
#endif
#if defined(PADDLE_WITH_XPU)
PD_REGISTER_KERNEL(prod, XPU, ALL_LAYOUT, phi::ProdKernel, float) {}
PD_REGISTER_KERNEL(prod_infer, XPU, ALL_LAYOUT, phi::ProdInferKernel, float) {}
#endif
......@@ -19,7 +19,7 @@
namespace phi {
template <typename T, typename Context>
void ProdRawKernel(const Context& dev_ctx,
void ProdKernel(const Context& dev_ctx,
const DenseTensor& x,
const IntArray& dims,
bool keep_dim,
......@@ -27,7 +27,7 @@ void ProdRawKernel(const Context& dev_ctx,
DenseTensor* out);
template <typename T, typename Context>
void ProdKernel(const Context& dev_ctx,
void ProdInferKernel(const Context& dev_ctx,
const DenseTensor& x,
const IntArray& dims,
bool keep_dim,
......
......@@ -26,7 +26,7 @@
namespace phi {
template <typename Context>
void ReshapeKernel(const Context& dev_ctx,
void ReshapeInferKernel(const Context& dev_ctx,
const DenseTensor& x,
const IntArray& shape,
DenseTensor* out) {
......@@ -47,7 +47,7 @@ void ReshapeKernel(const Context& dev_ctx,
#ifdef PADDLE_WITH_XPU
template <>
void ReshapeKernel<phi::XPUContext>(const XPUContext& dev_ctx,
void ReshapeInferKernel<phi::XPUContext>(const XPUContext& dev_ctx,
const DenseTensor& x,
const IntArray& shape,
DenseTensor* out) {
......@@ -73,40 +73,40 @@ void ReshapeKernel<phi::XPUContext>(const XPUContext& dev_ctx,
#endif
template <typename Context>
void ReshapeWithXShape(const Context& dev_ctx,
void ReshapeKernel(const Context& dev_ctx,
const DenseTensor& x,
const IntArray& shape,
DenseTensor* out,
DenseTensor* xshape) {
ReshapeKernel(dev_ctx, x, shape, out);
ReshapeInferKernel(dev_ctx, x, shape, out);
}
} // namespace phi
PD_REGISTER_GENERAL_KERNEL(
reshape, CPU, ALL_LAYOUT, phi::ReshapeKernel<phi::CPUContext>, ALL_DTYPE) {}
PD_REGISTER_GENERAL_KERNEL(reshape_with_xshape,
PD_REGISTER_GENERAL_KERNEL(reshape_infer,
CPU,
ALL_LAYOUT,
phi::ReshapeWithXShape<phi::CPUContext>,
phi::ReshapeInferKernel<phi::CPUContext>,
ALL_DTYPE) {}
PD_REGISTER_GENERAL_KERNEL(
reshape, CPU, ALL_LAYOUT, phi::ReshapeKernel<phi::CPUContext>, ALL_DTYPE) {}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PD_REGISTER_GENERAL_KERNEL(
reshape, GPU, ALL_LAYOUT, phi::ReshapeKernel<phi::GPUContext>, ALL_DTYPE) {}
PD_REGISTER_GENERAL_KERNEL(reshape_with_xshape,
PD_REGISTER_GENERAL_KERNEL(reshape_infer,
GPU,
ALL_LAYOUT,
phi::ReshapeWithXShape<phi::GPUContext>,
phi::ReshapeInferKernel<phi::GPUContext>,
ALL_DTYPE) {}
PD_REGISTER_GENERAL_KERNEL(
reshape, GPU, ALL_LAYOUT, phi::ReshapeKernel<phi::GPUContext>, ALL_DTYPE) {}
#endif
#ifdef PADDLE_WITH_XPU
PD_REGISTER_GENERAL_KERNEL(
reshape, XPU, ALL_LAYOUT, phi::ReshapeKernel<phi::XPUContext>, ALL_DTYPE) {}
PD_REGISTER_GENERAL_KERNEL(reshape_with_xshape,
PD_REGISTER_GENERAL_KERNEL(reshape_infer,
XPU,
ALL_LAYOUT,
phi::ReshapeWithXShape<phi::XPUContext>,
phi::ReshapeInferKernel<phi::XPUContext>,
ALL_DTYPE) {}
PD_REGISTER_GENERAL_KERNEL(
reshape, XPU, ALL_LAYOUT, phi::ReshapeKernel<phi::XPUContext>, ALL_DTYPE) {}
#endif
......@@ -22,13 +22,13 @@ limitations under the License. */
namespace phi {
template <typename Context>
void ReshapeKernel(const Context& dev_ctx,
void ReshapeInferKernel(const Context& dev_ctx,
const DenseTensor& x,
const IntArray& shape,
DenseTensor* out);
template <typename Context>
void ReshapeWithXShape(const Context& dev_ctx,
void ReshapeKernel(const Context& dev_ctx,
const DenseTensor& x,
const IntArray& shape,
DenseTensor* out,
......@@ -41,7 +41,7 @@ DenseTensor Reshape(const Context& dev_ctx,
DenseTensor dense_out;
MetaTensor meta_out(&dense_out);
InferMetaFromVecValue(x, shape, &meta_out);
ReshapeKernel<Context>(dev_ctx, x, IntArray(shape), &dense_out);
ReshapeInferKernel<Context>(dev_ctx, x, IntArray(shape), &dense_out);
return dense_out;
}
......
......@@ -21,7 +21,7 @@
namespace phi {
template <typename T, typename Context>
void SqueezeKernel(const Context& dev_ctx,
void SqueezeInferKernel(const Context& dev_ctx,
const DenseTensor& x,
const IntArray& axes,
DenseTensor* out) {
......@@ -36,20 +36,20 @@ void SqueezeKernel(const Context& dev_ctx,
}
template <typename T, typename Context>
void SqueezeWithXShapeKernel(const Context& dev_ctx,
void SqueezeKernel(const Context& dev_ctx,
const DenseTensor& x,
const IntArray& axes,
DenseTensor* out,
DenseTensor* xshape) {
SqueezeKernel<T, Context>(dev_ctx, x, axes, out);
SqueezeInferKernel<T, Context>(dev_ctx, x, axes, out);
}
} // namespace phi
PD_REGISTER_KERNEL(squeeze,
PD_REGISTER_KERNEL(squeeze_infer,
CPU,
ALL_LAYOUT,
phi::SqueezeKernel,
phi::SqueezeInferKernel,
float,
double,
phi::dtype::bfloat16,
......@@ -61,10 +61,10 @@ PD_REGISTER_KERNEL(squeeze,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
PD_REGISTER_KERNEL(squeeze_with_xshape,
PD_REGISTER_KERNEL(squeeze,
CPU,
ALL_LAYOUT,
phi::SqueezeWithXShapeKernel,
phi::SqueezeKernel,
float,
double,
phi::dtype::bfloat16,
......@@ -76,10 +76,10 @@ PD_REGISTER_KERNEL(squeeze_with_xshape,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PD_REGISTER_KERNEL(squeeze,
PD_REGISTER_KERNEL(squeeze_infer,
GPU,
ALL_LAYOUT,
phi::SqueezeKernel,
phi::SqueezeInferKernel,
float,
double,
phi::dtype::float16,
......@@ -92,10 +92,10 @@ PD_REGISTER_KERNEL(squeeze,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
PD_REGISTER_KERNEL(squeeze_with_xshape,
PD_REGISTER_KERNEL(squeeze,
GPU,
ALL_LAYOUT,
phi::SqueezeWithXShapeKernel,
phi::SqueezeKernel,
float,
double,
phi::dtype::float16,
......@@ -110,10 +110,10 @@ PD_REGISTER_KERNEL(squeeze_with_xshape,
#endif
#ifdef PADDLE_WITH_XPU
PD_REGISTER_KERNEL(squeeze,
PD_REGISTER_KERNEL(squeeze_infer,
XPU,
ALL_LAYOUT,
phi::SqueezeKernel,
phi::SqueezeInferKernel,
float,
double,
phi::dtype::float16,
......@@ -123,10 +123,10 @@ PD_REGISTER_KERNEL(squeeze,
int8_t,
int64_t) {}
PD_REGISTER_KERNEL(squeeze_with_xshape,
PD_REGISTER_KERNEL(squeeze,
XPU,
ALL_LAYOUT,
phi::SqueezeWithXShapeKernel,
phi::SqueezeKernel,
float,
double,
phi::dtype::float16,
......
......@@ -21,13 +21,13 @@
namespace phi {
template <typename T, typename Context>
void SqueezeKernel(const Context& dev_ctx,
void SqueezeInferKernel(const Context& dev_ctx,
const DenseTensor& x,
const IntArray& axes,
DenseTensor* out);
template <typename T, typename Context>
void SqueezeWithXShapeKernel(const Context& dev_ctx,
void SqueezeKernel(const Context& dev_ctx,
const DenseTensor& x,
const IntArray& axes,
DenseTensor* out,
......
......@@ -21,7 +21,7 @@
namespace phi {
template <typename T, typename Context>
void UnsqueezeKernel(const Context& dev_ctx,
void UnsqueezeInferKernel(const Context& dev_ctx,
const DenseTensor& x,
const IntArray& axes,
DenseTensor* out) {
......@@ -42,19 +42,19 @@ void UnsqueezeKernel(const Context& dev_ctx,
}
template <typename T, typename Context>
void UnsqueezeWithXShapeKernel(const Context& dev_ctx,
void UnsqueezeKernel(const Context& dev_ctx,
const DenseTensor& x,
const IntArray& axes,
DenseTensor* out,
DenseTensor* xshape) {
UnsqueezeKernel<T, Context>(dev_ctx, x, axes, out);
UnsqueezeInferKernel<T, Context>(dev_ctx, x, axes, out);
}
} // namespace phi
PD_REGISTER_KERNEL(unsqueeze,
PD_REGISTER_KERNEL(unsqueeze_infer,
CPU,
ALL_LAYOUT,
phi::UnsqueezeKernel,
phi::UnsqueezeInferKernel,
float,
double,
phi::dtype::bfloat16,
......@@ -67,10 +67,10 @@ PD_REGISTER_KERNEL(unsqueeze,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
PD_REGISTER_KERNEL(unsqueeze_with_xshape,
PD_REGISTER_KERNEL(unsqueeze,
CPU,
ALL_LAYOUT,
phi::UnsqueezeWithXShapeKernel,
phi::UnsqueezeKernel,
float,
double,
phi::dtype::bfloat16,
......@@ -83,10 +83,10 @@ PD_REGISTER_KERNEL(unsqueeze_with_xshape,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PD_REGISTER_KERNEL(unsqueeze,
PD_REGISTER_KERNEL(unsqueeze_infer,
GPU,
ALL_LAYOUT,
phi::UnsqueezeKernel,
phi::UnsqueezeInferKernel,
float,
double,
phi::dtype::float16,
......@@ -100,10 +100,10 @@ PD_REGISTER_KERNEL(unsqueeze,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
PD_REGISTER_KERNEL(unsqueeze_with_xshape,
PD_REGISTER_KERNEL(unsqueeze,
GPU,
ALL_LAYOUT,
phi::UnsqueezeWithXShapeKernel,
phi::UnsqueezeKernel,
float,
double,
phi::dtype::float16,
......@@ -119,10 +119,10 @@ PD_REGISTER_KERNEL(unsqueeze_with_xshape,
#endif
#ifdef PADDLE_WITH_XPU
PD_REGISTER_KERNEL(unsqueeze,
PD_REGISTER_KERNEL(unsqueeze_infer,
XPU,
ALL_LAYOUT,
phi::UnsqueezeKernel,
phi::UnsqueezeInferKernel,
float,
double,
phi::dtype::float16,
......@@ -132,10 +132,10 @@ PD_REGISTER_KERNEL(unsqueeze,
int8_t,
int64_t) {}
PD_REGISTER_KERNEL(unsqueeze_with_xshape,
PD_REGISTER_KERNEL(unsqueeze,
XPU,
ALL_LAYOUT,
phi::UnsqueezeWithXShapeKernel,
phi::UnsqueezeKernel,
float,
double,
phi::dtype::float16,
......
......@@ -22,13 +22,13 @@
namespace phi {
template <typename T, typename Context>
void UnsqueezeKernel(const Context& dev_ctx,
void UnsqueezeInferKernel(const Context& dev_ctx,
const DenseTensor& x,
const IntArray& axes,
DenseTensor* out);
template <typename T, typename Context>
void UnsqueezeWithXShapeKernel(const Context& dev_ctx,
void UnsqueezeKernel(const Context& dev_ctx,
const DenseTensor& x,
const IntArray& axes,
DenseTensor* out,
......@@ -42,7 +42,7 @@ void Unsqueeze(const Context& dev_ctx,
DenseTensor* xshape) {
MetaTensor meta_out(out);
UnsqueezeInferMeta(x, axes, &meta_out);
UnsqueezeKernel<T, Context>(dev_ctx, x, axes, out);
UnsqueezeInferKernel<T, Context>(dev_ctx, x, axes, out);
}
} // namespace phi
......@@ -22,7 +22,7 @@
namespace phi {
template <typename T, typename Context>
void ProdRawKernel(const Context& dev_ctx,
void ProdKernel(const Context& dev_ctx,
const DenseTensor& x,
const IntArray& dims,
bool keep_dim,
......@@ -46,4 +46,4 @@ void ProdRawKernel(const Context& dev_ctx,
} // namespace phi
PD_REGISTER_KERNEL(prod_raw, XPU, ALL_LAYOUT, phi::ProdRawKernel, float) {}
PD_REGISTER_KERNEL(prod, XPU, ALL_LAYOUT, phi::ProdKernel, float) {}
......@@ -17,10 +17,8 @@ limitations under the License. */
namespace phi {
KernelSignature EinsumOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("einsum_raw",
{"Operands"},
{"equation"},
{"Out", "InnerCache", "XShape"});
return KernelSignature(
"einsum", {"Operands"}, {"equation"}, {"Out", "InnerCache", "XShape"});
}
KernelSignature EinsumGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
......@@ -31,7 +29,5 @@ KernelSignature EinsumGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
}
} // namespace phi
PD_REGISTER_BASE_KERNEL_NAME(einsum, einsum_raw);
PD_REGISTER_ARG_MAPPING_FN(einsum, phi::EinsumOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(einsum_grad, phi::EinsumGradOpArgumentMapping);
......@@ -18,13 +18,11 @@ namespace phi {
KernelSignature FlattenOpArgumentMapping(const ArgumentMappingContext& ctx) {
if (ctx.HasOutput("XShape")) {
return KernelSignature("flatten_with_xshape",
{"X"},
{"start_axis", "stop_axis"},
{"Out", "XShape"});
return KernelSignature(
"flatten", {"X"}, {"start_axis", "stop_axis"}, {"Out", "XShape"});
} else {
return KernelSignature(
"flatten", {"X"}, {"start_axis", "stop_axis"}, {"Out"});
"flatten_infer", {"X"}, {"start_axis", "stop_axis"}, {"Out"});
}
}
......
......@@ -60,9 +60,9 @@ KernelSignature ReduceProdOpArgumentMapping(const ArgumentMappingContext& ctx) {
// the "max_raw" KernelSignature
if (ctx.IsForInferShape() || reduce_all) {
return KernelSignature(
"prod_raw", {"X"}, {"dim", "keep_dim", "reduce_all"}, {"Out"});
"prod", {"X"}, {"dim", "keep_dim", "reduce_all"}, {"Out"});
}
return KernelSignature("prod", {"X"}, {"dim", "keep_dim"}, {"Out"});
return KernelSignature("prod_infer", {"X"}, {"dim", "keep_dim"}, {"Out"});
}
return KernelSignature("unregistered", {}, {}, {});
}
......
......@@ -20,21 +20,19 @@ KernelSignature ReshapeOpArgumentMapping(const ArgumentMappingContext& ctx) {
if (ctx.HasOutput("XShape")) {
if (ctx.InputSize("ShapeTensor") > 0) {
return KernelSignature(
"reshape_with_xshape", {"X"}, {"ShapeTensor"}, {"Out", "XShape"});
"reshape", {"X"}, {"ShapeTensor"}, {"Out", "XShape"});
} else if (ctx.HasInput("Shape")) {
return KernelSignature(
"reshape_with_xshape", {"X"}, {"Shape"}, {"Out", "XShape"});
return KernelSignature("reshape", {"X"}, {"Shape"}, {"Out", "XShape"});
} else {
return KernelSignature(
"reshape_with_xshape", {"X"}, {"shape"}, {"Out", "XShape"});
return KernelSignature("reshape", {"X"}, {"shape"}, {"Out", "XShape"});
}
} else {
if (ctx.InputSize("ShapeTensor") > 0) {
return KernelSignature("reshape", {"X"}, {"ShapeTensor"}, {"Out"});
return KernelSignature("reshape_infer", {"X"}, {"ShapeTensor"}, {"Out"});
} else if (ctx.HasInput("Shape")) {
return KernelSignature("reshape", {"X"}, {"Shape"}, {"Out"});
return KernelSignature("reshape_infer", {"X"}, {"Shape"}, {"Out"});
} else {
return KernelSignature("reshape", {"X"}, {"shape"}, {"Out"});
return KernelSignature("reshape_infer", {"X"}, {"shape"}, {"Out"});
}
}
}
......
......@@ -618,18 +618,18 @@ TEST(ARG_MAP, reshape) {
TestArgumentMappingContext arg_case1({"X", "ShapeTensor"}, {}, {}, {"Out"});
auto signature1 =
(*OpUtilsMap::Instance().GetArgumentMappingFn("reshape2"))(arg_case1);
EXPECT_STREQ(signature1.name, "reshape");
EXPECT_STREQ(signature1.name, "reshape_infer");
TestArgumentMappingContext arg_case2({"X", "Shape"}, {}, {}, {"Out"});
auto signature2 =
(*OpUtilsMap::Instance().GetArgumentMappingFn("reshape2"))(arg_case2);
EXPECT_STREQ(signature2.name, "reshape");
EXPECT_STREQ(signature2.name, "reshape_infer");
TestArgumentMappingContext arg_case3(
{"X"}, {}, {{"shape", paddle::any(std::vector<int>({1, 2}))}}, {"Out"});
auto signature3 =
(*OpUtilsMap::Instance().GetArgumentMappingFn("reshape2"))(arg_case3);
EXPECT_STREQ(signature3.name, "reshape");
EXPECT_STREQ(signature3.name, "reshape_infer");
}
} // namespace tests
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册