未验证 提交 6f86c96b 编写于 作者: Z zhangyuqin1998 提交者: GitHub

Delete randperm raw op (#51631)

* Delete randperm raw op

* fix
上级 3734e89a
......@@ -514,11 +514,6 @@ XPUOpMap& get_kl2_ops() {
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"prod_raw", XPUKernelSet({phi::DataType::FLOAT32})},
{"range", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::INT64})},
{"randperm_raw",
XPUKernelSet({phi::DataType::INT32,
phi::DataType::INT64,
phi::DataType::FLOAT32,
phi::DataType::FLOAT64})},
{"randperm",
XPUKernelSet({phi::DataType::INT32,
phi::DataType::INT64,
......
......@@ -19,9 +19,12 @@
namespace phi {
template <typename T, typename Context>
void RandpermRawKernel(
const Context& dev_ctx, int n, DataType dtype, int seed, DenseTensor* out) {
void RandpermKernel(const Context& dev_ctx,
int n,
DataType dtype,
DenseTensor* out) {
T* out_data = dev_ctx.template Alloc<T>(out);
int seed = 0;
std::shared_ptr<std::mt19937_64> engine;
if (seed) {
......@@ -37,25 +40,8 @@ void RandpermRawKernel(
std::shuffle(out_data, out_data + n, *engine);
}
template <typename T, typename Context>
void RandpermKernel(const Context& dev_ctx,
int n,
DataType dtype,
DenseTensor* out) {
RandpermRawKernel<T>(dev_ctx, n, dtype, 0, out);
}
} // namespace phi
PD_REGISTER_KERNEL(randperm_raw,
CPU,
ALL_LAYOUT,
phi::RandpermRawKernel,
float,
double,
int,
int64_t) {}
PD_REGISTER_KERNEL(randperm,
CPU,
ALL_LAYOUT,
......
......@@ -85,9 +85,12 @@ __global__ void SwapRepeatKernel(keyT* key_out_data,
}
template <typename T, typename Context>
void RandpermRawKernel(
const Context& dev_ctx, int n, DataType dtype, int seed, DenseTensor* out) {
void RandpermKernel(const Context& dev_ctx,
int n,
DataType dtype,
DenseTensor* out) {
DenseTensor key;
int seed = 0;
RandintKernel<int, Context>(dev_ctx,
std::numeric_limits<int>::min(),
std::numeric_limits<int>::max(),
......@@ -151,25 +154,8 @@ void RandpermRawKernel(
key_out.data<int>(), out_data, n, seed_offset.first, seed_offset.second);
}
template <typename T, typename Context>
void RandpermKernel(const Context& dev_ctx,
int n,
DataType dtype,
DenseTensor* out) {
RandpermRawKernel<T>(dev_ctx, n, dtype, 0, out);
}
} // namespace phi
PD_REGISTER_KERNEL(randperm_raw,
GPU,
ALL_LAYOUT,
phi::RandpermRawKernel,
float,
double,
int,
int64_t) {}
PD_REGISTER_KERNEL(randperm,
GPU,
ALL_LAYOUT,
......
......@@ -19,10 +19,6 @@
namespace phi {
template <typename T, typename Context>
void RandpermRawKernel(
const Context& dev_ctx, int n, DataType dtype, int seed, DenseTensor* out);
template <typename T, typename Context>
void RandpermKernel(const Context& dev_ctx,
int n,
......
......@@ -21,10 +21,12 @@
namespace phi {
template <typename T, typename Context>
void RandpermRawKernel(
const Context& dev_ctx, int n, DataType dtype, int seed, DenseTensor* out) {
void RandpermKernel(const Context& dev_ctx,
int n,
DataType dtype,
DenseTensor* out) {
std::shared_ptr<std::mt19937_64> engine;
int seed = 0;
if (seed) {
engine = std::make_shared<std::mt19937_64>();
engine->seed(seed);
......@@ -51,25 +53,8 @@ void RandpermRawKernel(
}
}
template <typename T, typename Context>
void RandpermKernel(const Context& dev_ctx,
int n,
DataType dtype,
DenseTensor* out) {
RandpermRawKernel<T, Context>(dev_ctx, n, dtype, 0, out);
}
} // namespace phi
PD_REGISTER_KERNEL(randperm_raw,
XPU,
ALL_LAYOUT,
phi::RandpermRawKernel,
int,
int64_t,
float,
double) {}
PD_REGISTER_KERNEL(randperm,
XPU,
ALL_LAYOUT,
......
......@@ -17,12 +17,7 @@
namespace phi {
KernelSignature RandpermOpArgumentMapping(const ArgumentMappingContext& ctx) {
int seed = paddle::any_cast<int>(ctx.Attr("seed"));
if (seed) {
return KernelSignature("randperm", {}, {"n", "dtype", "seed"}, {"Out"});
} else {
return KernelSignature("randperm", {}, {"n", "dtype"}, {"Out"});
}
}
} // namespace phi
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册