From 13e2e10cfddccc75fc483698f72a5cb78ea51111 Mon Sep 17 00:00:00 2001 From: zhangyuqin1998 <75946871+zhangyuqin1998@users.noreply.github.com> Date: Sat, 6 May 2023 10:04:01 +0800 Subject: [PATCH] move UniformRawKernel to legacy (#53158) * move UniformRawKernel to legacy * Update uniform_kernel.cc * Update uniform_kernel.cu * Update uniform_kernel.cc * Update uniform_kernel.cu * Update uniform_kernel.h * Update uniform_kernel.cc * Empty Commit to setup deployments --- paddle/phi/kernels/CMakeLists.txt | 1 + paddle/phi/kernels/cpu/uniform_kernel.cc | 38 ++------ paddle/phi/kernels/gpu/uniform_kernel.cu | 27 +++--- .../phi/kernels/legacy/cpu/uniform_kernel.cc | 70 ++++++++++++++ .../phi/kernels/legacy/gpu/uniform_kernel.cu | 96 +++++++++++++++++++ paddle/phi/kernels/legacy/uniform_kernel.h | 36 +++++++ .../phi/kernels/legacy/xpu/uniform_kernel.cc | 80 ++++++++++++++++ .../kernels/selected_rows/uniform_kernel.cc | 1 + paddle/phi/kernels/uniform_kernel.cc | 65 ------------- paddle/phi/kernels/uniform_kernel.h | 12 --- paddle/phi/kernels/xpu/uniform_kernel.cc | 23 +++-- 11 files changed, 316 insertions(+), 133 deletions(-) create mode 100644 paddle/phi/kernels/legacy/cpu/uniform_kernel.cc create mode 100644 paddle/phi/kernels/legacy/gpu/uniform_kernel.cu create mode 100644 paddle/phi/kernels/legacy/uniform_kernel.h create mode 100644 paddle/phi/kernels/legacy/xpu/uniform_kernel.cc delete mode 100644 paddle/phi/kernels/uniform_kernel.cc diff --git a/paddle/phi/kernels/CMakeLists.txt b/paddle/phi/kernels/CMakeLists.txt index de67958d5fe..9b78cbf4246 100644 --- a/paddle/phi/kernels/CMakeLists.txt +++ b/paddle/phi/kernels/CMakeLists.txt @@ -112,6 +112,7 @@ file( "gpudnn/*.cu" "kps/*.cu" "legacy/kps/*.cu" + "legacy/gpu/*.cu" "selected_rows/gpu/*.cu" "sparse/gpu/*.cu" "strings/gpu/*.cu" diff --git a/paddle/phi/kernels/cpu/uniform_kernel.cc b/paddle/phi/kernels/cpu/uniform_kernel.cc index 1b1503473d9..17ee5459188 100644 --- a/paddle/phi/kernels/cpu/uniform_kernel.cc +++ b/paddle/phi/kernels/cpu/uniform_kernel.cc @@ -20,16 +20,13 @@ namespace phi { template -void UniformRawKernel(const Context &dev_ctx, - const IntArray &shape, - DataType dtype, - const Scalar &min, - const Scalar &max, - int seed, - int diag_num, - int diag_step, - float diag_val, - DenseTensor *out) { +void UniformKernel(const Context &dev_ctx, + const IntArray &shape, + DataType dtype, + const Scalar &min, + const Scalar &max, + int seed, + DenseTensor *out) { out->Resize(phi::make_ddim(shape.GetData())); T *data = dev_ctx.template Alloc(out); auto size = out->numel(); @@ -42,31 +39,14 @@ void UniformRawKernel(const Context &dev_ctx, } UniformRealDistribution( data, size, min.to(), max.to(), engine); - if (diag_num > 0) { - PADDLE_ENFORCE_GT( - size, - (diag_num - 1) * (diag_step + 1), - phi::errors::InvalidArgument( - "ShapeInvalid: the diagonal's elements is equal (num-1) " - "* (step-1) with num %d, step %d," - "It should be smaller than %d, but received %d", - diag_num, - diag_step, - (diag_num - 1) * (diag_step + 1), - size)); - for (int64_t i = 0; i < diag_num; ++i) { - int64_t pos = i * diag_step + i; - data[pos] = diag_val; - } - } } } // namespace phi -PD_REGISTER_KERNEL(uniform_raw, +PD_REGISTER_KERNEL(uniform, CPU, ALL_LAYOUT, - phi::UniformRawKernel, + phi::UniformKernel, float, double, phi::dtype::bfloat16) {} diff --git a/paddle/phi/kernels/gpu/uniform_kernel.cu b/paddle/phi/kernels/gpu/uniform_kernel.cu index fe36fe5fc6e..1ba5847fa29 100644 --- a/paddle/phi/kernels/gpu/uniform_kernel.cu +++ b/paddle/phi/kernels/gpu/uniform_kernel.cu @@ -54,16 +54,13 @@ struct UniformGenerator { }; template -void UniformRawKernel(const Context& dev_ctx, - const IntArray& shape, - DataType dtype, - const Scalar& min, - const Scalar& max, - int seed, - int diag_num, - int diag_step, - float diag_val, - DenseTensor* out) { +void UniformKernel(const Context& dev_ctx, + const IntArray& shape, + DataType dtype, + const Scalar& min, + const Scalar& max, + int seed, + DenseTensor* out) { out->Resize(phi::make_ddim(shape.GetData())); dev_ctx.template Alloc(out); if (seed == 0) { @@ -77,19 +74,19 @@ void UniformRawKernel(const Context& dev_ctx, auto func = UniformGenerator(static_cast(min.to()), static_cast(max.to()), seed, - diag_num, - diag_step, - static_cast(diag_val)); + 0, + 0, + static_cast(0.0)); IndexKernel>(dev_ctx, out, func); } } } // namespace phi -PD_REGISTER_KERNEL(uniform_raw, +PD_REGISTER_KERNEL(uniform, GPU, ALL_LAYOUT, - phi::UniformRawKernel, + phi::UniformKernel, float, double, phi::dtype::float16, diff --git a/paddle/phi/kernels/legacy/cpu/uniform_kernel.cc b/paddle/phi/kernels/legacy/cpu/uniform_kernel.cc new file mode 100644 index 00000000000..ecea86874a7 --- /dev/null +++ b/paddle/phi/kernels/legacy/cpu/uniform_kernel.cc @@ -0,0 +1,70 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/uniform_real_distribution.h" + +namespace phi { + +template +void UniformRawKernel(const Context &dev_ctx, + const IntArray &shape, + DataType dtype, + const Scalar &min, + const Scalar &max, + int seed, + int diag_num, + int diag_step, + float diag_val, + DenseTensor *out) { + out->Resize(phi::make_ddim(shape.GetData())); + T *data = dev_ctx.template Alloc(out); + auto size = out->numel(); + std::shared_ptr engine; + if (seed) { + engine = std::make_shared(); + engine->seed(seed); + } else { + engine = dev_ctx.GetGenerator()->GetCPUEngine(); + } + UniformRealDistribution( + data, size, min.to(), max.to(), engine); + if (diag_num > 0) { + PADDLE_ENFORCE_GT( + size, + (diag_num - 1) * (diag_step + 1), + phi::errors::InvalidArgument( + "ShapeInvalid: the diagonal's elements is equal (num-1) " + "* (step-1) with num %d, step %d," + "It should be smaller than %d, but received %d", + diag_num, + diag_step, + (diag_num - 1) * (diag_step + 1), + size)); + for (int64_t i = 0; i < diag_num; ++i) { + int64_t pos = i * diag_step + i; + data[pos] = diag_val; + } + } +} + +} // namespace phi + +PD_REGISTER_KERNEL(uniform_raw, + CPU, + ALL_LAYOUT, + phi::UniformRawKernel, + float, + double, + phi::dtype::bfloat16) {} diff --git a/paddle/phi/kernels/legacy/gpu/uniform_kernel.cu b/paddle/phi/kernels/legacy/gpu/uniform_kernel.cu new file mode 100644 index 00000000000..211c7accf6f --- /dev/null +++ b/paddle/phi/kernels/legacy/gpu/uniform_kernel.cu @@ -0,0 +1,96 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/uniform_kernel.h" + +#include + +#include "gflags/gflags.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/distribution_helper.h" +#include "paddle/phi/kernels/funcs/index_impl.cu.h" + +namespace phi { + +template +struct UniformGenerator { + T min_, max_; + unsigned int seed_; + T diag_val_; + unsigned int diag_num_; + unsigned int diag_step_; + __host__ __device__ UniformGenerator( + T min, T max, int seed, int diag_num, int diag_step, T diag_val) + : min_(min), + max_(max), + seed_(seed), + diag_num_(diag_num), + diag_step_(diag_step), + diag_val_(diag_val) {} + + __host__ __device__ T operator()(const unsigned int n) const { + thrust::minstd_rand rng; + rng.seed(seed_); + thrust::uniform_real_distribution dist(min_, max_); + rng.discard(n); + T out = dist(rng); + unsigned int remainder = n % (diag_step_ + 1); + if (remainder == 0 && diag_num_ > n / (diag_step_ + 1)) { + out = diag_val_; + } + return out; + } +}; + +template +void UniformRawKernel(const Context& dev_ctx, + const IntArray& shape, + DataType dtype, + const Scalar& min, + const Scalar& max, + int seed, + int diag_num, + int diag_step, + float diag_val, + DenseTensor* out) { + out->Resize(phi::make_ddim(shape.GetData())); + dev_ctx.template Alloc(out); + if (seed == 0) { + // Use global Generator seed + using MT = typename kps::details::MPTypeTrait::Type; + funcs::uniform_distribution dist; + funcs::uniform_real_transform trans(min.to(), max.to()); + funcs::distribution_and_transform(dev_ctx, out, dist, trans); + } else { + // Use OP seed + auto func = UniformGenerator(static_cast(min.to()), + static_cast(max.to()), + seed, + diag_num, + diag_step, + static_cast(diag_val)); + IndexKernel>(dev_ctx, out, func); + } +} + +} // namespace phi + +PD_REGISTER_KERNEL(uniform_raw, + GPU, + ALL_LAYOUT, + phi::UniformRawKernel, + float, + double, + phi::dtype::float16, + phi::dtype::bfloat16) {} diff --git a/paddle/phi/kernels/legacy/uniform_kernel.h b/paddle/phi/kernels/legacy/uniform_kernel.h new file mode 100644 index 00000000000..5c3b9966fc9 --- /dev/null +++ b/paddle/phi/kernels/legacy/uniform_kernel.h @@ -0,0 +1,36 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/common/int_array.h" +#include "paddle/phi/common/scalar.h" +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/device_context.h" + +namespace phi { + +template +void UniformRawKernel(const Context& dev_ctx, + const IntArray& shape, + DataType dtype, + const Scalar& min, + const Scalar& max, + int seed, + int diag_num, + int diag_step, + float diag_val, + DenseTensor* out); + +} // namespace phi diff --git a/paddle/phi/kernels/legacy/xpu/uniform_kernel.cc b/paddle/phi/kernels/legacy/xpu/uniform_kernel.cc new file mode 100644 index 00000000000..f1907b13e5f --- /dev/null +++ b/paddle/phi/kernels/legacy/xpu/uniform_kernel.cc @@ -0,0 +1,80 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/phi/kernels/uniform_kernel.h" + +#include + +#include "paddle/phi/backends/xpu/xpu_context.h" +#include "paddle/phi/common/memory_utils.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/uniform_real_distribution.h" + +namespace phi { + +template +void UniformRawKernel(const Context &dev_ctx, + const IntArray &shape, + DataType dtype, + const Scalar &min, + const Scalar &max, + int seed, + int diag_num, + int diag_step, + float diag_val, + DenseTensor *out) { + out->Resize(phi::make_ddim(shape.GetData())); + T *data = dev_ctx.template Alloc(out); + int64_t size = out->numel(); + + std::unique_ptr data_cpu(new T[size]); + + std::shared_ptr engine; + if (seed) { + engine = std::make_shared(); + engine->seed(seed); + } else { + engine = dev_ctx.GetGenerator()->GetCPUEngine(); + } + UniformRealDistribution( + data_cpu.get(), size, min.to(), max.to(), engine); + if (diag_num > 0) { + PADDLE_ENFORCE_GT( + size, + (diag_num - 1) * (diag_step + 1), + phi::errors::InvalidArgument( + "ShapeInvalid: the diagonal's elements is equal (num-1) " + "* (step-1) with num %d, step %d," + "It should be smaller than %d, but received %d", + diag_num, + diag_step, + (diag_num - 1) * (diag_step + 1), + size)); + for (int64_t i = 0; i < diag_num; ++i) { + int64_t pos = i * diag_step + i; + data_cpu[pos] = diag_val; + } + } + + memory_utils::Copy(dev_ctx.GetPlace(), + data, + phi::CPUPlace(), + reinterpret_cast(data_cpu.get()), + size * sizeof(T)); +} + +} // namespace phi + +PD_REGISTER_KERNEL(uniform_raw, XPU, ALL_LAYOUT, phi::UniformRawKernel, float) { +} diff --git a/paddle/phi/kernels/selected_rows/uniform_kernel.cc b/paddle/phi/kernels/selected_rows/uniform_kernel.cc index 89707018002..0af5d8788c7 100644 --- a/paddle/phi/kernels/selected_rows/uniform_kernel.cc +++ b/paddle/phi/kernels/selected_rows/uniform_kernel.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/phi/kernels/selected_rows/uniform_kernel.h" +#include "paddle/phi/kernels/legacy/uniform_kernel.h" #include "paddle/phi/backends/cpu/cpu_context.h" #include "paddle/phi/backends/gpu/gpu_context.h" diff --git a/paddle/phi/kernels/uniform_kernel.cc b/paddle/phi/kernels/uniform_kernel.cc deleted file mode 100644 index 7e8138f6e1d..00000000000 --- a/paddle/phi/kernels/uniform_kernel.cc +++ /dev/null @@ -1,65 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/phi/kernels/uniform_kernel.h" - -#include "paddle/phi/common/int_array.h" -#include "paddle/phi/common/scalar.h" -#include "paddle/phi/core/dense_tensor.h" -#include "paddle/phi/core/device_context.h" -#include "paddle/phi/core/kernel_registry.h" -#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) -#include "paddle/phi/backends/gpu/gpu_context.h" -#endif -#ifdef PADDLE_WITH_XPU -#include "paddle/phi/backends/xpu/xpu_context.h" -#endif - -namespace phi { - -template -void UniformKernel(const Context& dev_ctx, - const IntArray& shape, - DataType dtype, - const Scalar& min, - const Scalar& max, - int seed, - DenseTensor* out) { - UniformRawKernel(dev_ctx, shape, dtype, min, max, seed, 0, 0, 0.0f, out); -} - -} // namespace phi - -PD_REGISTER_KERNEL(uniform, - CPU, - ALL_LAYOUT, - phi::UniformKernel, - float, - double, - phi::dtype::bfloat16) {} - -#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) -PD_REGISTER_KERNEL(uniform, - GPU, - ALL_LAYOUT, - phi::UniformKernel, - float, - double, - phi::dtype::float16, - phi::dtype::bfloat16) {} -#endif - -#ifdef PADDLE_WITH_XPU -PD_REGISTER_KERNEL(uniform, XPU, ALL_LAYOUT, phi::UniformKernel, float) {} -#endif diff --git a/paddle/phi/kernels/uniform_kernel.h b/paddle/phi/kernels/uniform_kernel.h index ef19c20b932..1d0160e4056 100644 --- a/paddle/phi/kernels/uniform_kernel.h +++ b/paddle/phi/kernels/uniform_kernel.h @@ -21,18 +21,6 @@ namespace phi { -template -void UniformRawKernel(const Context& dev_ctx, - const IntArray& shape, - DataType dtype, - const Scalar& min, - const Scalar& max, - int seed, - int diag_num, - int diag_step, - float diag_val, - DenseTensor* out); - template void UniformKernel(const Context& dev_ctx, const IntArray& shape, diff --git a/paddle/phi/kernels/xpu/uniform_kernel.cc b/paddle/phi/kernels/xpu/uniform_kernel.cc index 48f9a6e8d77..99388e31e58 100644 --- a/paddle/phi/kernels/xpu/uniform_kernel.cc +++ b/paddle/phi/kernels/xpu/uniform_kernel.cc @@ -24,16 +24,16 @@ limitations under the License. */ namespace phi { template -void UniformRawKernel(const Context &dev_ctx, - const IntArray &shape, - DataType dtype, - const Scalar &min, - const Scalar &max, - int seed, - int diag_num, - int diag_step, - float diag_val, - DenseTensor *out) { +void UniformKernel(const Context &dev_ctx, + const IntArray &shape, + DataType dtype, + const Scalar &min, + const Scalar &max, + int seed, + DenseTensor *out) { + int diag_num = 0; + int diag_step = 0; + float diag_val = 0.0f; out->Resize(phi::make_ddim(shape.GetData())); T *data = dev_ctx.template Alloc(out); int64_t size = out->numel(); @@ -76,5 +76,4 @@ void UniformRawKernel(const Context &dev_ctx, } // namespace phi -PD_REGISTER_KERNEL(uniform_raw, XPU, ALL_LAYOUT, phi::UniformRawKernel, float) { -} +PD_REGISTER_KERNEL(uniform, XPU, ALL_LAYOUT, phi::UniformKernel, float) {} -- GitLab