Unverified commit 30f5e39b authored by YuanRisheng, committed by GitHub

[PHI]Rename some PHI Kernel (#49470)

* rename kernel

* delete sig

* modify code according to review comments

* fix ci bugs
Parent 280677c5
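At a glance, this commit flips the PHI kernel naming convention: the full-featured variant (the one that also emits the legacy XShape output, or takes a reduce_all flag) now owns the plain kernel name, while the output-only variant is renamed with an Infer suffix and registered as *_infer. A minimal sketch of the resulting convention, using the squeeze pair; the declarations are copied from the header changes below, and the comments are editorial:

// Output-only variant: formerly SqueezeKernel, now SqueezeInferKernel,
// registered under the "squeeze_infer" key.
template <typename T, typename Context>
void SqueezeInferKernel(const Context& dev_ctx,
                        const DenseTensor& x,
                        const IntArray& axes,
                        DenseTensor* out);

// Variant that also records the input shape in XShape: formerly
// SqueezeWithXShapeKernel, now plain SqueezeKernel, registered as "squeeze".
template <typename T, typename Context>
void SqueezeKernel(const Context& dev_ctx,
                   const DenseTensor& x,
                   const IntArray& axes,
                   DenseTensor* out,
                   DenseTensor* xshape);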
@@ -424,27 +424,27 @@ class ReshapeKernel {
   }
   if (platform::is_cpu_place(ctx.GetPlace())) {
     auto &dev_ctx = ctx.device_context<phi::CPUContext>();
-    phi::ReshapeKernel(static_cast<const phi::CPUContext &>(dev_ctx),
-                       *in,
-                       pt_scalar_shape,
-                       out);
+    phi::ReshapeInferKernel(static_cast<const phi::CPUContext &>(dev_ctx),
+                            *in,
+                            pt_scalar_shape,
+                            out);
   }
 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
   if (platform::is_gpu_place(ctx.GetPlace())) {
     auto &dev_ctx = ctx.device_context<phi::GPUContext>();
-    phi::ReshapeKernel(static_cast<const phi::GPUContext &>(dev_ctx),
-                       *in,
-                       pt_scalar_shape,
-                       out);
+    phi::ReshapeInferKernel(static_cast<const phi::GPUContext &>(dev_ctx),
+                            *in,
+                            pt_scalar_shape,
+                            out);
   }
 #endif
 #ifdef PADDLE_WITH_XPU
   if (platform::is_xpu_place(ctx.GetPlace())) {
     auto &dev_ctx = ctx.device_context<platform::XPUDeviceContext>();
-    phi::ReshapeKernel(static_cast<const phi::XPUContext &>(dev_ctx),
-                       *in,
-                       pt_scalar_shape,
-                       out);
+    phi::ReshapeInferKernel(static_cast<const phi::XPUContext &>(dev_ctx),
+                            *in,
+                            pt_scalar_shape,
+                            out);
   }
 #endif
 }
...
@@ -561,7 +561,7 @@
     func : EinsumRawInferMeta
     param : [x, equation]
   kernel :
-    func : einsum_raw
+    func : einsum
   backward : einsum_grad

- op : elementwise_pow
@@ -677,7 +677,7 @@
   infer_meta :
     func : FlattenWithXShapeInferMeta
   kernel :
-    func : flatten_with_xshape
+    func : flatten
     backend : x
   inplace : (x -> out)
   view : (x -> out)
@@ -1391,7 +1391,7 @@
   infer_meta :
     func : ReduceIntArrayAxisInferMetaBase
   kernel :
-    func : prod_raw
+    func : prod
   backward : prod_grad

- op : psroi_pool
@@ -1473,7 +1473,7 @@
   infer_meta :
     func : ReshapeWithXShapeInferMeta
   kernel :
-    func : reshape_with_xshape
+    func : reshape
   inplace : (x -> out)
   view: (x -> out)
   intermediate : xshape
...
@@ -1150,7 +1150,7 @@
   infer_meta :
     func : SqueezeWithXShapeInferMeta
   kernel :
-    func : squeeze_with_xshape
+    func : squeeze
     data_type : x
   inplace : (x -> out)
   view: (x -> out)
@@ -1258,7 +1258,7 @@
   infer_meta :
     func : UnsqueezeWithXShapeInferMeta
   kernel :
-    func : unsqueeze_with_xshape
+    func : unsqueeze
     data_type : x
   inplace : (x -> out)
   view: (x -> out)
...
@@ -270,7 +270,7 @@ XPUOpMap& get_kl2_ops() {
                     phi::DataType::INT8,
                     phi::DataType::FLOAT16,
                     phi::DataType::FLOAT32})},
-    {"flatten_with_xshape",
+    {"flatten",
      XPUKernelSet({phi::DataType::INT64,
                    phi::DataType::INT32,
                    phi::DataType::INT8,
@@ -450,7 +450,7 @@ XPUOpMap& get_kl2_ops() {
                    phi::DataType::INT32,
                    phi::DataType::BOOL,
                    phi::DataType::FLOAT32})},
-    {"reshape_with_xshape",
+    {"reshape",
      XPUKernelSet({phi::DataType::FLOAT64,
                    phi::DataType::FLOAT16,
                    phi::DataType::INT64,
@@ -541,7 +541,7 @@ XPUOpMap& get_kl2_ops() {
                    phi::DataType::INT8,
                    phi::DataType::UINT8,
                    phi::DataType::FLOAT32})},
-    {"squeeze_with_xshape",
+    {"squeeze",
      XPUKernelSet({phi::DataType::FLOAT64,
                    phi::DataType::INT64,
                    phi::DataType::INT32,
@@ -655,7 +655,7 @@ XPUOpMap& get_kl2_ops() {
                    phi::DataType::UINT8,
                    phi::DataType::FLOAT16,
                    phi::DataType::FLOAT32})},
-    {"unsqueeze_with_xshape",
+    {"unsqueeze",
      XPUKernelSet({phi::DataType::FLOAT64,
                    phi::DataType::INT64,
                    phi::DataType::INT32,
...
@@ -18,19 +18,19 @@
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/impl/einsum_impl.h"

-PD_REGISTER_KERNEL(einsum_raw,
+PD_REGISTER_KERNEL(einsum,
                    CPU,
                    ALL_LAYOUT,
-                   phi::EinsumKernelRaw,
+                   phi::EinsumKernel,
                    float,
                    double,
                    phi::dtype::complex<float>,
                    phi::dtype::complex<double>) {}

-PD_REGISTER_KERNEL(einsum,
+PD_REGISTER_KERNEL(einsum_infer,
                    CPU,
                    ALL_LAYOUT,
-                   phi::EinsumKernel,
+                   phi::EinsumInferKernel,
                    float,
                    double,
                    phi::dtype::complex<float>,
...
@@ -22,12 +22,12 @@
 namespace phi {

 template <typename T, typename Context>
-void ProdRawKernel(const Context& dev_ctx,
-                   const DenseTensor& x,
-                   const IntArray& dims,
-                   bool keep_dim,
-                   bool reduce_all,
-                   DenseTensor* out) {
+void ProdKernel(const Context& dev_ctx,
+                const DenseTensor& x,
+                const IntArray& dims,
+                bool keep_dim,
+                bool reduce_all,
+                DenseTensor* out) {
   reduce_all = recompute_reduce_all(x, dims, reduce_all);
   auto out_dtype = x.dtype();
   phi::Reduce<CPUContext, T, phi::funcs::ProdFunctor>(
@@ -36,11 +36,5 @@ void ProdRawKernel(const Context& dev_ctx,
 }  // namespace phi

-PD_REGISTER_KERNEL(prod_raw,
-                   CPU,
-                   ALL_LAYOUT,
-                   phi::ProdRawKernel,
-                   float,
-                   double,
-                   int,
-                   int64_t) {}
+PD_REGISTER_KERNEL(
+    prod, CPU, ALL_LAYOUT, phi::ProdKernel, float, double, int, int64_t) {}
...
@@ -18,18 +18,18 @@
 namespace phi {

+template <typename T, typename Context>
+void EinsumInferKernel(const Context& dev_ctx,
+                       const std::vector<const DenseTensor*>& inputs,
+                       const std::string& equation,
+                       DenseTensor* out);
+
 template <typename T, typename Context>
 void EinsumKernel(const Context& dev_ctx,
                   const std::vector<const DenseTensor*>& inputs,
                   const std::string& equation,
-                  DenseTensor* out);
-
-template <typename T, typename Context>
-void EinsumKernelRaw(const Context& dev_ctx,
-                     const std::vector<const DenseTensor*>& inputs,
-                     const std::string& equation,
-                     DenseTensor* out,
-                     std::vector<DenseTensor*> inner_cache,
-                     std::vector<DenseTensor*> xshape);
+                  DenseTensor* out,
+                  std::vector<DenseTensor*> inner_cache,
+                  std::vector<DenseTensor*> xshape);

 }  // namespace phi
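For orientation, a hedged usage sketch of the two renamed einsum entry points declared above; dev_ctx, a, and b stand in for an initialized CPU context and input tensors and are not part of the diff:

// Single-output path: formerly EinsumKernel, now EinsumInferKernel.
phi::DenseTensor out;
phi::EinsumInferKernel<float, phi::CPUContext>(
    dev_ctx, {&a, &b}, "ij,jk->ik", &out);

// Full path that also materializes InnerCache/XShape: formerly
// EinsumKernelRaw, now EinsumKernel. Per the compatibility note in
// einsum_impl.h, the cache vector may hold nullptr entries.
std::vector<phi::DenseTensor*> inner_cache(2, nullptr);
std::vector<phi::DenseTensor*> xshape(2, nullptr);
phi::EinsumKernel<float, phi::CPUContext>(
    dev_ctx, {&a, &b}, "ij,jk->ik", &out, inner_cache, xshape);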
...
@@ -23,11 +23,11 @@
 namespace phi {

 template <typename T, typename Context>
-void FlattenKernel(const Context& dev_ctx,
-                   const DenseTensor& x,
-                   int start_axis,
-                   int stop_axis,
-                   DenseTensor* out) {
+void FlattenInferKernel(const Context& dev_ctx,
+                        const DenseTensor& x,
+                        int start_axis,
+                        int stop_axis,
+                        DenseTensor* out) {
   dev_ctx.Alloc(out, x.dtype());
   auto out_dims = out->dims();
   phi::Copy(dev_ctx, x, dev_ctx.GetPlace(), false, out);
@@ -38,21 +38,21 @@ void FlattenKernel(const Context& dev_ctx,
 // Output Tensor,
 // is there a more flexible way to deal with this case?
 template <typename T, typename Context>
-void FlattenWithXShape(const Context& dev_ctx,
-                       const DenseTensor& x,
-                       int start_axis,
-                       int stop_axis,
-                       DenseTensor* out,
-                       DenseTensor* xshape) {
-  FlattenKernel<T, Context>(dev_ctx, x, start_axis, stop_axis, out);
+void FlattenKernel(const Context& dev_ctx,
+                   const DenseTensor& x,
+                   int start_axis,
+                   int stop_axis,
+                   DenseTensor* out,
+                   DenseTensor* xshape) {
+  FlattenInferKernel<T, Context>(dev_ctx, x, start_axis, stop_axis, out);
 }

 }  // namespace phi

-PD_REGISTER_KERNEL(flatten,
+PD_REGISTER_KERNEL(flatten_infer,
                    CPU,
                    ALL_LAYOUT,
-                   phi::FlattenKernel,
+                   phi::FlattenInferKernel,
                    float,
                    phi::dtype::bfloat16,
                    double,
@@ -62,10 +62,10 @@ PD_REGISTER_KERNEL(flatten,
                    int,
                    int64_t) {}

-PD_REGISTER_KERNEL(flatten_with_xshape,
+PD_REGISTER_KERNEL(flatten,
                    CPU,
                    ALL_LAYOUT,
-                   phi::FlattenWithXShape,
+                   phi::FlattenKernel,
                    float,
                    phi::dtype::bfloat16,
                    double,
@@ -76,10 +76,10 @@ PD_REGISTER_KERNEL(flatten_with_xshape,
                    int64_t) {}

 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
-PD_REGISTER_KERNEL(flatten,
+PD_REGISTER_KERNEL(flatten_infer,
                    GPU,
                    ALL_LAYOUT,
-                   phi::FlattenKernel,
+                   phi::FlattenInferKernel,
                    float,
                    phi::dtype::float16,
                    phi::dtype::bfloat16,
@@ -90,10 +90,10 @@ PD_REGISTER_KERNEL(flatten,
                    int,
                    int64_t) {}

-PD_REGISTER_KERNEL(flatten_with_xshape,
+PD_REGISTER_KERNEL(flatten,
                    GPU,
                    ALL_LAYOUT,
-                   phi::FlattenWithXShape,
+                   phi::FlattenKernel,
                    float,
                    phi::dtype::float16,
                    phi::dtype::bfloat16,
@@ -106,10 +106,10 @@ PD_REGISTER_KERNEL(flatten_with_xshape,
 #endif

 #ifdef PADDLE_WITH_XPU
-PD_REGISTER_KERNEL(flatten,
+PD_REGISTER_KERNEL(flatten_infer,
                    XPU,
                    ALL_LAYOUT,
-                   phi::FlattenKernel,
+                   phi::FlattenInferKernel,
                    float,
                    phi::dtype::float16,
                    int8_t,
@@ -117,10 +117,10 @@ PD_REGISTER_KERNEL(flatten,
                    int,
                    int64_t) {}

-PD_REGISTER_KERNEL(flatten_with_xshape,
+PD_REGISTER_KERNEL(flatten,
                    XPU,
                    ALL_LAYOUT,
-                   phi::FlattenWithXShape,
+                   phi::FlattenKernel,
                    float,
                    phi::dtype::float16,
                    int8_t,
...
@@ -20,20 +20,20 @@ limitations under the License. */
 namespace phi {

+template <typename T, typename Context>
+void FlattenInferKernel(const Context& dev_ctx,
+                        const DenseTensor& x,
+                        int start_axis,
+                        int stop_axis,
+                        DenseTensor* out);
+
 template <typename T, typename Context>
 void FlattenKernel(const Context& dev_ctx,
                    const DenseTensor& x,
                    int start_axis,
                    int stop_axis,
-                   DenseTensor* out);
-
-template <typename T, typename Context>
-void FlattenWithXShape(const Context& dev_ctx,
-                       const DenseTensor& x,
-                       int start_axis,
-                       int stop_axis,
-                       DenseTensor* out,
-                       DenseTensor* xshape);
+                   DenseTensor* out,
+                   DenseTensor* xshape);

 template <typename T, typename Context>
 DenseTensor Flatten(const Context& dev_ctx,
@@ -43,7 +43,7 @@ DenseTensor Flatten(const Context& dev_ctx,
   DenseTensor dense_out;
   MetaTensor meta_out(&dense_out);
   FlattenInferMeta(x, start_axis, stop_axis, &meta_out);
-  FlattenKernel<T, Context>(dev_ctx, x, start_axis, stop_axis, &dense_out);
+  FlattenInferKernel<T, Context>(dev_ctx, x, start_axis, stop_axis, &dense_out);
   return dense_out;
 }
...
@@ -18,10 +18,10 @@
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/impl/einsum_impl.h"

-PD_REGISTER_KERNEL(einsum_raw,
+PD_REGISTER_KERNEL(einsum,
                    GPU,
                    ALL_LAYOUT,
-                   phi::EinsumKernelRaw,
+                   phi::EinsumKernel,
                    float,
                    double,
                    phi::dtype::float16,
@@ -29,10 +29,10 @@ PD_REGISTER_KERNEL(einsum_raw,
                    phi::dtype::complex<float>,
                    phi::dtype::complex<double>) {}

-PD_REGISTER_KERNEL(einsum,
+PD_REGISTER_KERNEL(einsum_infer,
                    GPU,
                    ALL_LAYOUT,
-                   phi::EinsumKernel,
+                   phi::EinsumInferKernel,
                    float,
                    double,
                    phi::dtype::float16,
...
@@ -103,7 +103,7 @@ DenseTensor PerformTileAndReduction(const Context& dev_ctx,
   // undiagonalize by einsum equation. only contain undiagonal operations.
   DenseTensor out;
   VLOG(5) << "Undiagonal by einsum with args: " << op_label + "->" + equ;
-  EinsumKernel<T, Context>(dev_ctx, {&ret}, op_label + "->" + equ, &out);
+  EinsumInferKernel<T, Context>(dev_ctx, {&ret}, op_label + "->" + equ, &out);
   return out;
 }
@@ -157,7 +157,8 @@ void EinsumGradKernel(const Context& dev_ctx,
     new_operands.push_back(&out_grad);
     DenseTensor before_tile;
     VLOG(5) << "new_equation is " << new_equation;
-    EinsumKernel<T, Context>(dev_ctx, new_operands, new_equation, &before_tile);
+    EinsumInferKernel<T, Context>(
+        dev_ctx, new_operands, new_equation, &before_tile);
     *(x_grad[0]) = PerformTileAndReduction<T, Context>(dev_ctx,
                                                        labeltype,
                                                        labelshape,
...
@@ -746,12 +746,12 @@ void EinsumKernelImpl(const Context& dev_ctx,
 }

 template <typename T, typename Context>
-void EinsumKernelRaw(const Context& dev_ctx,
-                     const std::vector<const DenseTensor*>& inputs,
-                     const std::string& equation,
-                     DenseTensor* out,
-                     std::vector<DenseTensor*> cache,
-                     std::vector<DenseTensor*> xshape) {
+void EinsumKernel(const Context& dev_ctx,
+                  const std::vector<const DenseTensor*>& inputs,
+                  const std::string& equation,
+                  DenseTensor* out,
+                  std::vector<DenseTensor*> cache,
+                  std::vector<DenseTensor*> xshape) {
   std::vector<char> tmp;
   // for the sake of compatibility, we may load and run v2.3 EinsumOp. Output
   // may have nullptr and the cache.size() is not equal to inputs.size(). refer
@@ -765,10 +765,10 @@ void EinsumKernelRaw(const Context& dev_ctx,
 }

 template <typename T, typename Context>
-void EinsumKernel(const Context& dev_ctx,
-                  const std::vector<const DenseTensor*>& inputs,
-                  const std::string& equation,
-                  DenseTensor* out) {
+void EinsumInferKernel(const Context& dev_ctx,
+                       const std::vector<const DenseTensor*>& inputs,
+                       const std::string& equation,
+                       DenseTensor* out) {
   std::vector<char> place_holder;
   std::vector<DenseTensor*> cache_tensor(
       inputs.size());  // set empty; TA, TB, TdC
...
@@ -169,7 +169,7 @@ static void linalg_solve(const Context& dev_ctx,
     out_tmp.Resize(out->dims());
     out_tmp = *out;
-    phi::SqueezeKernel<T, Context>(dev_ctx, out_tmp, {-1}, out);
+    phi::SqueezeInferKernel<T, Context>(dev_ctx, out_tmp, {-1}, out);
   } else {
     PADDLE_ENFORCE_EQ(
         x_dim[x_dim_size - 1],
...
@@ -19,12 +19,12 @@
 namespace phi {

 template <typename T, typename Context>
-void ProdRawKernel(const Context& dev_ctx,
-                   const DenseTensor& x,
-                   const IntArray& dims,
-                   bool keep_dim,
-                   bool reduce_all,
-                   DenseTensor* out) {
+void ProdKernel(const Context& dev_ctx,
+                const DenseTensor& x,
+                const IntArray& dims,
+                bool keep_dim,
+                bool reduce_all,
+                DenseTensor* out) {
   reduce_all = recompute_reduce_all(x, dims, reduce_all);
   auto out_dtype = x.dtype();
   phi::Reduce<T, kps::MulFunctor, kps::IdentityFunctor>(
@@ -33,14 +33,8 @@ void ProdRawKernel(const Context& dev_ctx,
 }  // namespace phi

 #ifdef PADDLE_WITH_XPU_KP
-PD_REGISTER_KERNEL(prod_raw, KPS, ALL_LAYOUT, phi::ProdRawKernel, float) {}
+PD_REGISTER_KERNEL(prod, KPS, ALL_LAYOUT, phi::ProdKernel, float) {}
 #else
-PD_REGISTER_KERNEL(prod_raw,
-                   KPS,
-                   ALL_LAYOUT,
-                   phi::ProdRawKernel,
-                   float,
-                   double,
-                   int,
-                   int64_t) {}
+PD_REGISTER_KERNEL(
+    prod, KPS, ALL_LAYOUT, phi::ProdKernel, float, double, int, int64_t) {}
 #endif
...
@@ -148,32 +148,32 @@ void ExecuteReshape(const Context& dev_ctx,
 }

 template <typename T, typename Context>
-void ReshapeKernel(const Context& dev_ctx,
-                   const DenseTensor& x,
-                   const IntArray& shape,
-                   DenseTensor* out) {
+void ReshapeInferKernel(const Context& dev_ctx,
+                        const DenseTensor& x,
+                        const IntArray& shape,
+                        DenseTensor* out) {
   auto x_dims = x.dims();
   ExecuteReshape<T, Context>(dev_ctx, x, shape, x_dims, out);
 }

 template <typename T, typename Context>
-void ReshapeWithXShape(const Context& dev_ctx,
-                       const DenseTensor& x,
-                       const IntArray& shape,
-                       DenseTensor* out,
-                       DenseTensor* xshape) {
+void ReshapeKernel(const Context& dev_ctx,
+                   const DenseTensor& x,
+                   const IntArray& shape,
+                   DenseTensor* out,
+                   DenseTensor* xshape) {
   auto x_dims = slice_ddim(xshape->dims(), 1, xshape->dims().size());
   ExecuteReshape<T, Context>(dev_ctx, x, shape, x_dims, out);
 }

 }  // namespace phi

-PD_REGISTER_KERNEL(
-    reshape, OneDNN, ONEDNN, phi::ReshapeKernel, float, phi::dtype::bfloat16) {}
-
-PD_REGISTER_KERNEL(reshape_with_xshape,
-                   OneDNN,
-                   ONEDNN,
-                   phi::ReshapeWithXShape,
-                   float,
-                   phi::dtype::bfloat16) {}
+PD_REGISTER_KERNEL(reshape_infer,
+                   OneDNN,
+                   ONEDNN,
+                   phi::ReshapeInferKernel,
+                   float,
+                   phi::dtype::bfloat16) {}
+
+PD_REGISTER_KERNEL(
+    reshape, OneDNN, ONEDNN, phi::ReshapeKernel, float, phi::dtype::bfloat16) {}
...
@@ -52,10 +52,10 @@ void ExecuteSqueeze(const Context& dev_ctx,
 }

 template <typename T, typename Context>
-void SqueezeKernel(const Context& dev_ctx,
-                   const DenseTensor& x,
-                   const IntArray& axes,
-                   DenseTensor* out) {
+void SqueezeInferKernel(const Context& dev_ctx,
+                        const DenseTensor& x,
+                        const IntArray& axes,
+                        DenseTensor* out) {
   auto x_dims = x.dims();
   std::vector<int32_t> tmp(axes.GetData().begin(), axes.GetData().end());
   auto out_dims = funcs::GetOutputSqueezeShape(tmp, x_dims, true);
@@ -63,13 +63,13 @@ void SqueezeKernel(const Context& dev_ctx,
 }

 template <typename T, typename Context>
-void SqueezeWithXShapeKernel(const Context& dev_ctx,
-                             const DenseTensor& x,
-                             const IntArray& axes,
-                             DenseTensor* out,
-                             DenseTensor* xshape) {
+void SqueezeKernel(const Context& dev_ctx,
+                   const DenseTensor& x,
+                   const IntArray& axes,
+                   DenseTensor* out,
+                   DenseTensor* xshape) {
   if (xshape == nullptr) {
-    SqueezeKernel<T, Context>(dev_ctx, x, axes, out);
+    SqueezeInferKernel<T, Context>(dev_ctx, x, axes, out);
   } else {
     auto x_dims = slice_ddim(xshape->dims(), 1, xshape->dims().size());
     auto out_dims = out->dims();
@@ -78,12 +78,12 @@ void SqueezeWithXShapeKernel(const Context& dev_ctx,
 }

 }  // namespace phi

-PD_REGISTER_KERNEL(
-    squeeze, OneDNN, ONEDNN, phi::SqueezeKernel, float, phi::dtype::bfloat16) {}
-
-PD_REGISTER_KERNEL(squeeze_with_xshape,
-                   OneDNN,
-                   ONEDNN,
-                   phi::SqueezeWithXShapeKernel,
-                   float,
-                   phi::dtype::bfloat16) {}
+PD_REGISTER_KERNEL(squeeze_infer,
+                   OneDNN,
+                   ONEDNN,
+                   phi::SqueezeInferKernel,
+                   float,
+                   phi::dtype::bfloat16) {}
+
+PD_REGISTER_KERNEL(
+    squeeze, OneDNN, ONEDNN, phi::SqueezeKernel, float, phi::dtype::bfloat16) {}
...
@@ -20,29 +20,41 @@
 namespace phi {

 template <typename T, typename Context>
-void ProdKernel(const Context& dev_ctx,
-                const DenseTensor& x,
-                const IntArray& dims,
-                bool keep_dim,
-                DenseTensor* out) {
+void ProdInferKernel(const Context& dev_ctx,
+                     const DenseTensor& x,
+                     const IntArray& dims,
+                     bool keep_dim,
+                     DenseTensor* out) {
   bool reduce_all = recompute_reduce_all(x, dims);
-  ProdRawKernel<T>(dev_ctx, x, dims, keep_dim, reduce_all, out);
+  ProdKernel<T>(dev_ctx, x, dims, keep_dim, reduce_all, out);
 }

 }  // namespace phi

-PD_REGISTER_KERNEL(
-    prod, CPU, ALL_LAYOUT, phi::ProdKernel, float, double, int, int64_t) {}
+PD_REGISTER_KERNEL(prod_infer,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::ProdInferKernel,
+                   float,
+                   double,
+                   int,
+                   int64_t) {}

 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
-PD_REGISTER_KERNEL(
-    prod, GPU, ALL_LAYOUT, phi::ProdKernel, float, double, int, int64_t) {}
+PD_REGISTER_KERNEL(prod_infer,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::ProdInferKernel,
+                   float,
+                   double,
+                   int,
+                   int64_t) {}
 #endif

 #if defined(PADDLE_WITH_XPU_KP) && !defined(PADDLE_WITH_XPU)
-PD_REGISTER_KERNEL(prod, KPS, ALL_LAYOUT, phi::ProdKernel, float) {}
+PD_REGISTER_KERNEL(prod_infer, KPS, ALL_LAYOUT, phi::ProdInferKernel, float) {}
 #endif

 #if defined(PADDLE_WITH_XPU)
-PD_REGISTER_KERNEL(prod, XPU, ALL_LAYOUT, phi::ProdKernel, float) {}
+PD_REGISTER_KERNEL(prod_infer, XPU, ALL_LAYOUT, phi::ProdInferKernel, float) {}
 #endif
...
@@ -18,19 +18,19 @@
 #include "paddle/phi/core/dense_tensor.h"

 namespace phi {

-template <typename T, typename Context>
-void ProdRawKernel(const Context& dev_ctx,
-                   const DenseTensor& x,
-                   const IntArray& dims,
-                   bool keep_dim,
-                   bool reduce_all,
-                   DenseTensor* out);
-
 template <typename T, typename Context>
 void ProdKernel(const Context& dev_ctx,
                 const DenseTensor& x,
                 const IntArray& dims,
                 bool keep_dim,
+                bool reduce_all,
                 DenseTensor* out);

+template <typename T, typename Context>
+void ProdInferKernel(const Context& dev_ctx,
+                     const DenseTensor& x,
+                     const IntArray& dims,
+                     bool keep_dim,
+                     DenseTensor* out);
+
 }  // namespace phi
...
@@ -26,10 +26,10 @@
 namespace phi {

 template <typename Context>
-void ReshapeKernel(const Context& dev_ctx,
-                   const DenseTensor& x,
-                   const IntArray& shape,
-                   DenseTensor* out) {
+void ReshapeInferKernel(const Context& dev_ctx,
+                        const DenseTensor& x,
+                        const IntArray& shape,
+                        DenseTensor* out) {
   MetaTensor meta_out(out);
   InferMetaFromVecValue(x, shape.GetData(), &meta_out);
   if (x.initialized() && x.Holder() == out->Holder()) {
@@ -47,10 +47,10 @@ void ReshapeKernel(const Context& dev_ctx,
 #ifdef PADDLE_WITH_XPU
 template <>
-void ReshapeKernel<phi::XPUContext>(const XPUContext& dev_ctx,
-                                    const DenseTensor& x,
-                                    const IntArray& shape,
-                                    DenseTensor* out) {
+void ReshapeInferKernel<phi::XPUContext>(const XPUContext& dev_ctx,
+                                         const DenseTensor& x,
+                                         const IntArray& shape,
+                                         DenseTensor* out) {
   MetaTensor meta_out(out);
   InferMetaFromVecValue(x, shape.GetData(), &meta_out);
   if (x.initialized() && x.Holder() == out->Holder()) {
@@ -73,40 +73,40 @@ void ReshapeKernel<phi::XPUContext>(const XPUContext& dev_ctx,
 #endif

 template <typename Context>
-void ReshapeWithXShape(const Context& dev_ctx,
-                       const DenseTensor& x,
-                       const IntArray& shape,
-                       DenseTensor* out,
-                       DenseTensor* xshape) {
-  ReshapeKernel(dev_ctx, x, shape, out);
+void ReshapeKernel(const Context& dev_ctx,
+                   const DenseTensor& x,
+                   const IntArray& shape,
+                   DenseTensor* out,
+                   DenseTensor* xshape) {
+  ReshapeInferKernel(dev_ctx, x, shape, out);
 }

 }  // namespace phi

-PD_REGISTER_GENERAL_KERNEL(
-    reshape, CPU, ALL_LAYOUT, phi::ReshapeKernel<phi::CPUContext>, ALL_DTYPE) {}
-PD_REGISTER_GENERAL_KERNEL(reshape_with_xshape,
-                           CPU,
-                           ALL_LAYOUT,
-                           phi::ReshapeWithXShape<phi::CPUContext>,
-                           ALL_DTYPE) {}
+PD_REGISTER_GENERAL_KERNEL(reshape_infer,
+                           CPU,
+                           ALL_LAYOUT,
+                           phi::ReshapeInferKernel<phi::CPUContext>,
+                           ALL_DTYPE) {}
+PD_REGISTER_GENERAL_KERNEL(
+    reshape, CPU, ALL_LAYOUT, phi::ReshapeKernel<phi::CPUContext>, ALL_DTYPE) {}

 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
-PD_REGISTER_GENERAL_KERNEL(
-    reshape, GPU, ALL_LAYOUT, phi::ReshapeKernel<phi::GPUContext>, ALL_DTYPE) {}
-PD_REGISTER_GENERAL_KERNEL(reshape_with_xshape,
-                           GPU,
-                           ALL_LAYOUT,
-                           phi::ReshapeWithXShape<phi::GPUContext>,
-                           ALL_DTYPE) {}
+PD_REGISTER_GENERAL_KERNEL(reshape_infer,
+                           GPU,
+                           ALL_LAYOUT,
+                           phi::ReshapeInferKernel<phi::GPUContext>,
+                           ALL_DTYPE) {}
+PD_REGISTER_GENERAL_KERNEL(
+    reshape, GPU, ALL_LAYOUT, phi::ReshapeKernel<phi::GPUContext>, ALL_DTYPE) {}
 #endif

 #ifdef PADDLE_WITH_XPU
-PD_REGISTER_GENERAL_KERNEL(
-    reshape, XPU, ALL_LAYOUT, phi::ReshapeKernel<phi::XPUContext>, ALL_DTYPE) {}
-PD_REGISTER_GENERAL_KERNEL(reshape_with_xshape,
-                           XPU,
-                           ALL_LAYOUT,
-                           phi::ReshapeWithXShape<phi::XPUContext>,
-                           ALL_DTYPE) {}
+PD_REGISTER_GENERAL_KERNEL(reshape_infer,
+                           XPU,
+                           ALL_LAYOUT,
+                           phi::ReshapeInferKernel<phi::XPUContext>,
+                           ALL_DTYPE) {}
+PD_REGISTER_GENERAL_KERNEL(
+    reshape, XPU, ALL_LAYOUT, phi::ReshapeKernel<phi::XPUContext>, ALL_DTYPE) {}
 #endif
...
@@ -21,18 +21,18 @@ limitations under the License. */
 namespace phi {

+template <typename Context>
+void ReshapeInferKernel(const Context& dev_ctx,
+                        const DenseTensor& x,
+                        const IntArray& shape,
+                        DenseTensor* out);
+
 template <typename Context>
 void ReshapeKernel(const Context& dev_ctx,
                    const DenseTensor& x,
                    const IntArray& shape,
-                   DenseTensor* out);
-
-template <typename Context>
-void ReshapeWithXShape(const Context& dev_ctx,
-                       const DenseTensor& x,
-                       const IntArray& shape,
-                       DenseTensor* out,
-                       DenseTensor* xshape);
+                   DenseTensor* out,
+                   DenseTensor* xshape);

 template <typename T, typename Context>
 DenseTensor Reshape(const Context& dev_ctx,
@@ -41,7 +41,7 @@ DenseTensor Reshape(const Context& dev_ctx,
   DenseTensor dense_out;
   MetaTensor meta_out(&dense_out);
   InferMetaFromVecValue(x, shape, &meta_out);
-  ReshapeKernel<Context>(dev_ctx, x, IntArray(shape), &dense_out);
+  ReshapeInferKernel<Context>(dev_ctx, x, IntArray(shape), &dense_out);
   return dense_out;
 }
...
@@ -21,10 +21,10 @@
 namespace phi {

 template <typename T, typename Context>
-void SqueezeKernel(const Context& dev_ctx,
-                   const DenseTensor& x,
-                   const IntArray& axes,
-                   DenseTensor* out) {
+void SqueezeInferKernel(const Context& dev_ctx,
+                        const DenseTensor& x,
+                        const IntArray& axes,
+                        DenseTensor* out) {
   auto x_dims = x.dims();
   std::vector<int32_t> tmp(axes.GetData().begin(), axes.GetData().end());
   auto out_dims = funcs::GetOutputSqueezeShape(tmp, x_dims, true);
@@ -36,20 +36,20 @@ void SqueezeKernel(const Context& dev_ctx,
 }

 template <typename T, typename Context>
-void SqueezeWithXShapeKernel(const Context& dev_ctx,
-                             const DenseTensor& x,
-                             const IntArray& axes,
-                             DenseTensor* out,
-                             DenseTensor* xshape) {
-  SqueezeKernel<T, Context>(dev_ctx, x, axes, out);
+void SqueezeKernel(const Context& dev_ctx,
+                   const DenseTensor& x,
+                   const IntArray& axes,
+                   DenseTensor* out,
+                   DenseTensor* xshape) {
+  SqueezeInferKernel<T, Context>(dev_ctx, x, axes, out);
 }

 }  // namespace phi

-PD_REGISTER_KERNEL(squeeze,
+PD_REGISTER_KERNEL(squeeze_infer,
                    CPU,
                    ALL_LAYOUT,
-                   phi::SqueezeKernel,
+                   phi::SqueezeInferKernel,
                    float,
                    double,
                    phi::dtype::bfloat16,
@@ -61,10 +61,10 @@ PD_REGISTER_KERNEL(squeeze,
                    phi::dtype::complex<float>,
                    phi::dtype::complex<double>) {}

-PD_REGISTER_KERNEL(squeeze_with_xshape,
+PD_REGISTER_KERNEL(squeeze,
                    CPU,
                    ALL_LAYOUT,
-                   phi::SqueezeWithXShapeKernel,
+                   phi::SqueezeKernel,
                    float,
                    double,
                    phi::dtype::bfloat16,
@@ -76,10 +76,10 @@ PD_REGISTER_KERNEL(squeeze_with_xshape,
                    phi::dtype::complex<float>,
                    phi::dtype::complex<double>) {}

 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
-PD_REGISTER_KERNEL(squeeze,
+PD_REGISTER_KERNEL(squeeze_infer,
                    GPU,
                    ALL_LAYOUT,
-                   phi::SqueezeKernel,
+                   phi::SqueezeInferKernel,
                    float,
                    double,
                    phi::dtype::float16,
@@ -92,10 +92,10 @@ PD_REGISTER_KERNEL(squeeze,
                    phi::dtype::complex<float>,
                    phi::dtype::complex<double>) {}

-PD_REGISTER_KERNEL(squeeze_with_xshape,
+PD_REGISTER_KERNEL(squeeze,
                    GPU,
                    ALL_LAYOUT,
-                   phi::SqueezeWithXShapeKernel,
+                   phi::SqueezeKernel,
                    float,
                    double,
                    phi::dtype::float16,
@@ -110,10 +110,10 @@ PD_REGISTER_KERNEL(squeeze_with_xshape,
 #endif

 #ifdef PADDLE_WITH_XPU
-PD_REGISTER_KERNEL(squeeze,
+PD_REGISTER_KERNEL(squeeze_infer,
                    XPU,
                    ALL_LAYOUT,
-                   phi::SqueezeKernel,
+                   phi::SqueezeInferKernel,
                    float,
                    double,
                    phi::dtype::float16,
@@ -123,10 +123,10 @@ PD_REGISTER_KERNEL(squeeze,
                    int8_t,
                    int64_t) {}

-PD_REGISTER_KERNEL(squeeze_with_xshape,
+PD_REGISTER_KERNEL(squeeze,
                    XPU,
                    ALL_LAYOUT,
-                   phi::SqueezeWithXShapeKernel,
+                   phi::SqueezeKernel,
                    float,
                    double,
                    phi::dtype::float16,
...
@@ -20,17 +20,17 @@
 namespace phi {

+template <typename T, typename Context>
+void SqueezeInferKernel(const Context& dev_ctx,
+                        const DenseTensor& x,
+                        const IntArray& axes,
+                        DenseTensor* out);
+
 template <typename T, typename Context>
 void SqueezeKernel(const Context& dev_ctx,
                    const DenseTensor& x,
                    const IntArray& axes,
-                   DenseTensor* out);
-
-template <typename T, typename Context>
-void SqueezeWithXShapeKernel(const Context& dev_ctx,
-                             const DenseTensor& x,
-                             const IntArray& axes,
-                             DenseTensor* out,
-                             DenseTensor* xshape);
+                   DenseTensor* out,
+                   DenseTensor* xshape);

 }  // namespace phi
...
@@ -21,10 +21,10 @@
 namespace phi {

 template <typename T, typename Context>
-void UnsqueezeKernel(const Context& dev_ctx,
-                     const DenseTensor& x,
-                     const IntArray& axes,
-                     DenseTensor* out) {
+void UnsqueezeInferKernel(const Context& dev_ctx,
+                          const DenseTensor& x,
+                          const IntArray& axes,
+                          DenseTensor* out) {
   auto x_dims = x.dims();
   auto out_dims = out->dims();
   if (axes.FromTensor()) {
@@ -42,19 +42,19 @@ void UnsqueezeKernel(const Context& dev_ctx,
 }

 template <typename T, typename Context>
-void UnsqueezeWithXShapeKernel(const Context& dev_ctx,
-                               const DenseTensor& x,
-                               const IntArray& axes,
-                               DenseTensor* out,
-                               DenseTensor* xshape) {
-  UnsqueezeKernel<T, Context>(dev_ctx, x, axes, out);
+void UnsqueezeKernel(const Context& dev_ctx,
+                     const DenseTensor& x,
+                     const IntArray& axes,
+                     DenseTensor* out,
+                     DenseTensor* xshape) {
+  UnsqueezeInferKernel<T, Context>(dev_ctx, x, axes, out);
 }

 }  // namespace phi

-PD_REGISTER_KERNEL(unsqueeze,
+PD_REGISTER_KERNEL(unsqueeze_infer,
                    CPU,
                    ALL_LAYOUT,
-                   phi::UnsqueezeKernel,
+                   phi::UnsqueezeInferKernel,
                    float,
                    double,
                    phi::dtype::bfloat16,
@@ -67,10 +67,10 @@ PD_REGISTER_KERNEL(unsqueeze,
                    phi::dtype::complex<float>,
                    phi::dtype::complex<double>) {}

-PD_REGISTER_KERNEL(unsqueeze_with_xshape,
+PD_REGISTER_KERNEL(unsqueeze,
                    CPU,
                    ALL_LAYOUT,
-                   phi::UnsqueezeWithXShapeKernel,
+                   phi::UnsqueezeKernel,
                    float,
                    double,
                    phi::dtype::bfloat16,
@@ -83,10 +83,10 @@ PD_REGISTER_KERNEL(unsqueeze_with_xshape,
                    phi::dtype::complex<float>,
                    phi::dtype::complex<double>) {}

 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
-PD_REGISTER_KERNEL(unsqueeze,
+PD_REGISTER_KERNEL(unsqueeze_infer,
                    GPU,
                    ALL_LAYOUT,
-                   phi::UnsqueezeKernel,
+                   phi::UnsqueezeInferKernel,
                    float,
                    double,
                    phi::dtype::float16,
@@ -100,10 +100,10 @@ PD_REGISTER_KERNEL(unsqueeze,
                    phi::dtype::complex<float>,
                    phi::dtype::complex<double>) {}

-PD_REGISTER_KERNEL(unsqueeze_with_xshape,
+PD_REGISTER_KERNEL(unsqueeze,
                    GPU,
                    ALL_LAYOUT,
-                   phi::UnsqueezeWithXShapeKernel,
+                   phi::UnsqueezeKernel,
                    float,
                    double,
                    phi::dtype::float16,
@@ -119,10 +119,10 @@ PD_REGISTER_KERNEL(unsqueeze_with_xshape,
 #endif

 #ifdef PADDLE_WITH_XPU
-PD_REGISTER_KERNEL(unsqueeze,
+PD_REGISTER_KERNEL(unsqueeze_infer,
                    XPU,
                    ALL_LAYOUT,
-                   phi::UnsqueezeKernel,
+                   phi::UnsqueezeInferKernel,
                    float,
                    double,
                    phi::dtype::float16,
@@ -132,10 +132,10 @@ PD_REGISTER_KERNEL(unsqueeze,
                    int8_t,
                    int64_t) {}

-PD_REGISTER_KERNEL(unsqueeze_with_xshape,
+PD_REGISTER_KERNEL(unsqueeze,
                    XPU,
                    ALL_LAYOUT,
-                   phi::UnsqueezeWithXShapeKernel,
+                   phi::UnsqueezeKernel,
                    float,
                    double,
                    phi::dtype::float16,
...
@@ -21,18 +21,18 @@
 namespace phi {

+template <typename T, typename Context>
+void UnsqueezeInferKernel(const Context& dev_ctx,
+                          const DenseTensor& x,
+                          const IntArray& axes,
+                          DenseTensor* out);
+
 template <typename T, typename Context>
 void UnsqueezeKernel(const Context& dev_ctx,
                      const DenseTensor& x,
                      const IntArray& axes,
-                     DenseTensor* out);
-
-template <typename T, typename Context>
-void UnsqueezeWithXShapeKernel(const Context& dev_ctx,
-                               const DenseTensor& x,
-                               const IntArray& axes,
-                               DenseTensor* out,
-                               DenseTensor* xshape);
+                     DenseTensor* out,
+                     DenseTensor* xshape);

 template <typename T, typename Context>
 void Unsqueeze(const Context& dev_ctx,
@@ -42,7 +42,7 @@ void Unsqueeze(const Context& dev_ctx,
                DenseTensor* xshape) {
   MetaTensor meta_out(out);
   UnsqueezeInferMeta(x, axes, &meta_out);
-  UnsqueezeKernel<T, Context>(dev_ctx, x, axes, out);
+  UnsqueezeInferKernel<T, Context>(dev_ctx, x, axes, out);
 }
 }  // namespace phi
...
@@ -22,12 +22,12 @@
 namespace phi {

 template <typename T, typename Context>
-void ProdRawKernel(const Context& dev_ctx,
-                   const DenseTensor& x,
-                   const IntArray& dims,
-                   bool keep_dim,
-                   bool reduce_all,
-                   DenseTensor* out) {
+void ProdKernel(const Context& dev_ctx,
+                const DenseTensor& x,
+                const IntArray& dims,
+                bool keep_dim,
+                bool reduce_all,
+                DenseTensor* out) {
   reduce_all = recompute_reduce_all(x, dims, reduce_all);
   using XPUType = typename XPUTypeTrait<T>::Type;
@@ -46,4 +46,4 @@ void ProdRawKernel(const Context& dev_ctx,
 }  // namespace phi

-PD_REGISTER_KERNEL(prod_raw, XPU, ALL_LAYOUT, phi::ProdRawKernel, float) {}
+PD_REGISTER_KERNEL(prod, XPU, ALL_LAYOUT, phi::ProdKernel, float) {}
...
@@ -17,10 +17,8 @@ limitations under the License. */
 namespace phi {

 KernelSignature EinsumOpArgumentMapping(const ArgumentMappingContext& ctx) {
-  return KernelSignature("einsum_raw",
-                         {"Operands"},
-                         {"equation"},
-                         {"Out", "InnerCache", "XShape"});
+  return KernelSignature(
+      "einsum", {"Operands"}, {"equation"}, {"Out", "InnerCache", "XShape"});
 }

 KernelSignature EinsumGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
@@ -31,7 +29,5 @@ KernelSignature EinsumGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
 }
 }  // namespace phi

-PD_REGISTER_BASE_KERNEL_NAME(einsum, einsum_raw);
-
 PD_REGISTER_ARG_MAPPING_FN(einsum, phi::EinsumOpArgumentMapping);
 PD_REGISTER_ARG_MAPPING_FN(einsum_grad, phi::EinsumGradOpArgumentMapping);
...
@@ -18,13 +18,11 @@ namespace phi {

 KernelSignature FlattenOpArgumentMapping(const ArgumentMappingContext& ctx) {
   if (ctx.HasOutput("XShape")) {
-    return KernelSignature("flatten_with_xshape",
-                           {"X"},
-                           {"start_axis", "stop_axis"},
-                           {"Out", "XShape"});
+    return KernelSignature(
+        "flatten", {"X"}, {"start_axis", "stop_axis"}, {"Out", "XShape"});
   } else {
     return KernelSignature(
-        "flatten", {"X"}, {"start_axis", "stop_axis"}, {"Out"});
+        "flatten_infer", {"X"}, {"start_axis", "stop_axis"}, {"Out"});
   }
 }
...
@@ -60,9 +60,9 @@ KernelSignature ReduceProdOpArgumentMapping(const ArgumentMappingContext& ctx) {
     // the "max_raw" KernelSignature
     if (ctx.IsForInferShape() || reduce_all) {
       return KernelSignature(
-          "prod_raw", {"X"}, {"dim", "keep_dim", "reduce_all"}, {"Out"});
+          "prod", {"X"}, {"dim", "keep_dim", "reduce_all"}, {"Out"});
     }
-    return KernelSignature("prod", {"X"}, {"dim", "keep_dim"}, {"Out"});
+    return KernelSignature("prod_infer", {"X"}, {"dim", "keep_dim"}, {"Out"});
   }
   return KernelSignature("unregistered", {}, {}, {});
 }
...
@@ -20,21 +20,19 @@ KernelSignature ReshapeOpArgumentMapping(const ArgumentMappingContext& ctx) {
   if (ctx.HasOutput("XShape")) {
     if (ctx.InputSize("ShapeTensor") > 0) {
       return KernelSignature(
-          "reshape_with_xshape", {"X"}, {"ShapeTensor"}, {"Out", "XShape"});
+          "reshape", {"X"}, {"ShapeTensor"}, {"Out", "XShape"});
     } else if (ctx.HasInput("Shape")) {
-      return KernelSignature(
-          "reshape_with_xshape", {"X"}, {"Shape"}, {"Out", "XShape"});
+      return KernelSignature("reshape", {"X"}, {"Shape"}, {"Out", "XShape"});
     } else {
-      return KernelSignature(
-          "reshape_with_xshape", {"X"}, {"shape"}, {"Out", "XShape"});
+      return KernelSignature("reshape", {"X"}, {"shape"}, {"Out", "XShape"});
     }
   } else {
     if (ctx.InputSize("ShapeTensor") > 0) {
-      return KernelSignature("reshape", {"X"}, {"ShapeTensor"}, {"Out"});
+      return KernelSignature("reshape_infer", {"X"}, {"ShapeTensor"}, {"Out"});
     } else if (ctx.HasInput("Shape")) {
-      return KernelSignature("reshape", {"X"}, {"Shape"}, {"Out"});
+      return KernelSignature("reshape_infer", {"X"}, {"Shape"}, {"Out"});
     } else {
-      return KernelSignature("reshape", {"X"}, {"shape"}, {"Out"});
+      return KernelSignature("reshape_infer", {"X"}, {"shape"}, {"Out"});
     }
   }
 }
...
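The same routing pattern would apply to any op migrated this way: the legacy op variant that produces XShape maps to the plain kernel name, and the output-only variant maps to *_infer. A hypothetical mapping for an imaginary operator, as illustration only (my_op, my_op_infer, and the axes attribute are invented names, not part of this diff):

KernelSignature MyOpArgumentMapping(const ArgumentMappingContext& ctx) {
  if (ctx.HasOutput("XShape")) {
    // Legacy variant that also records the input shape.
    return KernelSignature("my_op", {"X"}, {"axes"}, {"Out", "XShape"});
  }
  // Output-only variant resolves to the renamed *_infer kernel.
  return KernelSignature("my_op_infer", {"X"}, {"axes"}, {"Out"});
}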
@@ -618,18 +618,18 @@ TEST(ARG_MAP, reshape) {
   TestArgumentMappingContext arg_case1({"X", "ShapeTensor"}, {}, {}, {"Out"});
   auto signature1 =
       (*OpUtilsMap::Instance().GetArgumentMappingFn("reshape2"))(arg_case1);
-  EXPECT_STREQ(signature1.name, "reshape");
+  EXPECT_STREQ(signature1.name, "reshape_infer");

   TestArgumentMappingContext arg_case2({"X", "Shape"}, {}, {}, {"Out"});
   auto signature2 =
       (*OpUtilsMap::Instance().GetArgumentMappingFn("reshape2"))(arg_case2);
-  EXPECT_STREQ(signature2.name, "reshape");
+  EXPECT_STREQ(signature2.name, "reshape_infer");

   TestArgumentMappingContext arg_case3(
       {"X"}, {}, {{"shape", paddle::any(std::vector<int>({1, 2}))}}, {"Out"});
   auto signature3 =
       (*OpUtilsMap::Instance().GetArgumentMappingFn("reshape2"))(arg_case3);
-  EXPECT_STREQ(signature3.name, "reshape");
+  EXPECT_STREQ(signature3.name, "reshape_infer");
 }

 }  // namespace tests
...