From ded33b58cb4d16685e30605071fd7186bb371bb1 Mon Sep 17 00:00:00 2001
From: HongyuJia
Date: Thu, 1 Sep 2022 14:17:58 +0800
Subject: [PATCH] [phi] Migrate uniform_random XPU kernel to PHI (#45583)

* copy kernel file to phi

* delete some code

* migrate uniform_random, test=kunlun

* fix input error, test=kunlun

* fix gpu register error, test=kunlun

* add include file, test=kunlun

* try fix error from CI, test=kunlun

* polish other PR

* fix CI-coverage error, test=kunlun
---
 .../new_executor/standalone_executor_test.cc  |   1 +
 .../fluid/operators/uniform_random_op_xpu.cc  | 113 ------------------
 .../phi/kernels/cpu/uniform_random_kernel.cc  |  46 +------
 .../kernels/funcs/uniform_real_distribution.h |  49 ++++++++
 .../phi/kernels/gpu/uniform_random_kernel.cu  |  15 ---
 .../selected_rows/uniform_random_kernel.cc    |  12 ++
 paddle/phi/kernels/uniform_random_kernel.cc   |  61 ++++++++++
 .../xpu/truncated_gaussian_random_kernel.cc   |   1 -
 .../phi/kernels/xpu/uniform_random_kernel.cc  |  80 +++++++++++++
 9 files changed, 204 insertions(+), 174 deletions(-)
 delete mode 100644 paddle/fluid/operators/uniform_random_op_xpu.cc
 create mode 100644 paddle/phi/kernels/funcs/uniform_real_distribution.h
 create mode 100644 paddle/phi/kernels/uniform_random_kernel.cc
 create mode 100644 paddle/phi/kernels/xpu/uniform_random_kernel.cc

diff --git a/paddle/fluid/framework/new_executor/standalone_executor_test.cc b/paddle/fluid/framework/new_executor/standalone_executor_test.cc
index 701c1edcafd..2531a8e7cd3 100644
--- a/paddle/fluid/framework/new_executor/standalone_executor_test.cc
+++ b/paddle/fluid/framework/new_executor/standalone_executor_test.cc
@@ -64,6 +64,7 @@ USE_OP_ITSELF(fetch_v2);
 
 PD_DECLARE_KERNEL(full, GPU, ALL_LAYOUT);
 PD_DECLARE_KERNEL(uniform_random_raw, GPU, ALL_LAYOUT);
+PD_DECLARE_KERNEL(uniform_random, GPU, ALL_LAYOUT);
 PD_DECLARE_KERNEL(transpose, GPU, ALL_LAYOUT);
 PD_DECLARE_KERNEL(reshape, GPU, ALL_LAYOUT);
 PD_DECLARE_KERNEL(split, GPU, ALL_LAYOUT);
diff --git a/paddle/fluid/operators/uniform_random_op_xpu.cc b/paddle/fluid/operators/uniform_random_op_xpu.cc
deleted file mode 100644
index 14e247894fd..00000000000
--- a/paddle/fluid/operators/uniform_random_op_xpu.cc
+++ /dev/null
@@ -1,113 +0,0 @@
-/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#ifdef PADDLE_WITH_XPU
-
-#include
-
-#include "paddle/fluid/framework/generator.h"
-#include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/framework/operator.h"
-#include "paddle/fluid/operators/uniform_random_op.h"
-
-namespace paddle {
-namespace operators {
-
-template <typename T>
-class XPUUniformRandomKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext &ctx) const override {
-    framework::Tensor *tensor = nullptr;
-    auto out_var = ctx.OutputVar("Out");
-    std::vector<int64_t> new_shape;
-    auto list_new_shape_tensor =
-        ctx.MultiInput<framework::Tensor>("ShapeTensorList");
-    if (list_new_shape_tensor.size() > 0 || ctx.HasInput("ShapeTensor")) {
-      if (ctx.HasInput("ShapeTensor")) {
-        auto *shape_tensor = ctx.Input<framework::Tensor>("ShapeTensor");
-        new_shape = GetNewDataFromShapeTensor(shape_tensor);
-      } else if (list_new_shape_tensor.size() > 0) {
-        new_shape = GetNewDataFromShapeTensorList(list_new_shape_tensor);
-      }
-    }
-
-    if (out_var->IsType<phi::SelectedRows>()) {
-      auto *selected_rows = out_var->GetMutable<phi::SelectedRows>();
-      tensor = selected_rows->mutable_value();
-      auto shape = ctx.Attr<std::vector<int64_t>>("shape");
-      if (!new_shape.empty()) shape = new_shape;
-      tensor->Resize(phi::make_ddim(shape));
-      selected_rows->mutable_rows()->reserve(shape[0]);
-    } else if (out_var->IsType<framework::LoDTensor>()) {
-      tensor = out_var->GetMutable<framework::LoDTensor>();
-      if (!new_shape.empty()) tensor->Resize(phi::make_ddim(new_shape));
-    } else {
-      PADDLE_THROW(platform::errors::InvalidArgument(
-          "Expected type of Output(out) in uniform_random_op must be Tensor, "
-          "SelectedRows. But got "
-          "unsupport type: %s.",
-          framework::ToTypeName(out_var->Type())));
-    }
-    T *data = tensor->mutable_data<T>(ctx.GetPlace());
-
-    int64_t size = tensor->numel();
-    std::unique_ptr<T[]> data_cpu(new T[size]);
-    std::uniform_real_distribution<T> dist(
-        static_cast<T>(ctx.Attr<float>("min")),
-        static_cast<T>(ctx.Attr<float>("max")));
-    unsigned int seed = static_cast<unsigned int>(ctx.Attr<int>("seed"));
-    auto engine = framework::GetCPURandomEngine(seed);
-
-    for (int64_t i = 0; i < size; ++i) {
-      data_cpu[i] = dist(*engine);
-    }
-
-    unsigned int diag_num =
-        static_cast<unsigned int>(ctx.Attr<int>("diag_num"));
-    unsigned int diag_step =
-        static_cast<unsigned int>(ctx.Attr<int>("diag_step"));
-    auto diag_val = static_cast<T>(ctx.Attr<float>("diag_val"));
-    if (diag_num > 0) {
-      PADDLE_ENFORCE_GT(
-          size,
-          (diag_num - 1) * (diag_step + 1),
-          platform::errors::InvalidArgument(
-              "ShapeInvalid: the diagonal's elements is equal (num-1) "
-              "* (step-1) with num %d, step %d,"
-              "It should be smaller than %d, but received %d",
-              diag_num,
-              diag_step,
-              (diag_num - 1) * (diag_step + 1),
-              size));
-      for (int64_t i = 0; i < diag_num; ++i) {
-        int64_t pos = i * diag_step + i;
-        data_cpu[pos] = diag_val;
-      }
-    }
-
-    memory::Copy(ctx.GetPlace(),
-                 data,
-                 platform::CPUPlace(),
-                 reinterpret_cast<void *>(data_cpu.get()),
-                 size * sizeof(T));
-  }
-};
-
-}  // namespace operators
-}  // namespace paddle
-
-REGISTER_OP_XPU_KERNEL(uniform_random,
-                       paddle::operators::XPUUniformRandomKernel<float>);
-
-#endif  // PADDLE_WITH_XPU
diff --git a/paddle/phi/kernels/cpu/uniform_random_kernel.cc b/paddle/phi/kernels/cpu/uniform_random_kernel.cc
index e5b25fc0554..a4e66a8f645 100644
--- a/paddle/phi/kernels/cpu/uniform_random_kernel.cc
+++ b/paddle/phi/kernels/cpu/uniform_random_kernel.cc
@@ -15,34 +15,10 @@
 #include "paddle/phi/kernels/uniform_random_kernel.h"
 
 #include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/funcs/uniform_real_distribution.h"
 
 namespace phi {
 
-template <typename T>
-inline void UniformRealDistribution(T *data,
-                                    const int64_t &size,
-                                    const float &min,
-                                    const float &max,
-                                    std::shared_ptr<std::mt19937_64> engine) {
-  std::uniform_real_distribution<T> dist(static_cast<T>(min),
-                                         static_cast<T>(max));
-  for (int64_t i = 0; i < size; ++i) {
-    data[i] = dist(*engine);
-  }
-}
-
-template <>
-inline void UniformRealDistribution(phi::dtype::bfloat16 *data,
-                                    const int64_t &size,
-                                    const float &min,
-                                    const float &max,
-                                    std::shared_ptr<std::mt19937_64> engine) {
-  std::uniform_real_distribution<float> dist(min, max);
-  for (int64_t i = 0; i < size; ++i) {
-    data[i] = static_cast<phi::dtype::bfloat16>(dist(*engine));
-  }
-}
-
 template <typename T, typename Context>
 void UniformRandomRawKernel(const Context &dev_ctx,
                             const IntArray &shape,
@@ -85,18 +61,6 @@ void UniformRandomRawKernel(const Context &dev_ctx,
   }
 }
 
-template <typename T, typename Context>
-void UniformRandomKernel(const Context &dev_ctx,
-                         const IntArray &shape,
-                         DataType dtype,
-                         const Scalar &min,
-                         const Scalar &max,
-                         int seed,
-                         DenseTensor *out) {
-  UniformRandomRawKernel<T>(
-      dev_ctx, shape, dtype, min, max, seed, 0, 0, 0.0f, out);
-}
-
 }  // namespace phi
 
 PD_REGISTER_KERNEL(uniform_random_raw,
@@ -106,11 +70,3 @@ PD_REGISTER_KERNEL(uniform_random_raw,
                    float,
                    double,
                    phi::dtype::bfloat16) {}
-
-PD_REGISTER_KERNEL(uniform_random,
-                   CPU,
-                   ALL_LAYOUT,
-                   phi::UniformRandomKernel,
-                   float,
-                   double,
-                   phi::dtype::bfloat16) {}
diff --git a/paddle/phi/kernels/funcs/uniform_real_distribution.h b/paddle/phi/kernels/funcs/uniform_real_distribution.h
new file mode 100644
index 00000000000..07318d4b6df
--- /dev/null
+++ b/paddle/phi/kernels/funcs/uniform_real_distribution.h
@@ -0,0 +1,49 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <memory>
+#include <random>
+
+#include "paddle/phi/common/data_type.h"
+
+namespace phi {
+
+template <typename T>
+inline void UniformRealDistribution(T *data,
+                                    const int64_t &size,
+                                    const float &min,
+                                    const float &max,
+                                    std::shared_ptr<std::mt19937_64> engine) {
+  std::uniform_real_distribution<T> dist(static_cast<T>(min),
+                                         static_cast<T>(max));
+  for (int64_t i = 0; i < size; ++i) {
+    data[i] = dist(*engine);
+  }
+}
+
+template <>
+inline void UniformRealDistribution(phi::dtype::bfloat16 *data,
+                                    const int64_t &size,
+                                    const float &min,
+                                    const float &max,
+                                    std::shared_ptr<std::mt19937_64> engine) {
+  std::uniform_real_distribution<float> dist(min, max);
+  for (int64_t i = 0; i < size; ++i) {
+    data[i] = static_cast<phi::dtype::bfloat16>(dist(*engine));
+  }
+}
+
+}  // namespace phi
diff --git a/paddle/phi/kernels/gpu/uniform_random_kernel.cu b/paddle/phi/kernels/gpu/uniform_random_kernel.cu
index 36382f40ea1..23232970e19 100644
--- a/paddle/phi/kernels/gpu/uniform_random_kernel.cu
+++ b/paddle/phi/kernels/gpu/uniform_random_kernel.cu
@@ -80,18 +80,6 @@ void UniformRandomRawKernel(const Context& dev_ctx,
   }
 }
 
-template <typename T, typename Context>
-void UniformRandomKernel(const Context& dev_ctx,
-                         const IntArray& shape,
-                         DataType dtype,
-                         const Scalar& min,
-                         const Scalar& max,
-                         int seed,
-                         DenseTensor* out) {
-  UniformRandomRawKernel<T>(
-      dev_ctx, shape, dtype, min, max, seed, 0, 0, 0.0f, out);
-}
-
 }  // namespace phi
 
 PD_REGISTER_KERNEL(uniform_random_raw,
@@ -100,6 +88,3 @@ PD_REGISTER_KERNEL(uniform_random_raw,
                    phi::UniformRandomRawKernel,
                    float,
                    double) {}
-
-PD_REGISTER_KERNEL(
-    uniform_random, GPU, ALL_LAYOUT, phi::UniformRandomKernel, float, double) {}
diff --git a/paddle/phi/kernels/selected_rows/uniform_random_kernel.cc b/paddle/phi/kernels/selected_rows/uniform_random_kernel.cc
index c304fcf6770..d6037da45f6 100644
--- a/paddle/phi/kernels/selected_rows/uniform_random_kernel.cc
+++ b/paddle/phi/kernels/selected_rows/uniform_random_kernel.cc
@@ -92,3 +92,15 @@ PD_REGISTER_KERNEL(uniform_random_sr,
                    float,
                    double) {}
 #endif
+
+#if defined(PADDLE_WITH_XPU)
+
+PD_REGISTER_KERNEL(uniform_random_raw_sr,
+                   XPU,
+                   ALL_LAYOUT,
+                   phi::sr::UniformRandomRawKernel,
+                   float) {}
+
+PD_REGISTER_KERNEL(
+    uniform_random_sr, XPU, ALL_LAYOUT, phi::sr::UniformRandomKernel, float) {}
+#endif
diff --git a/paddle/phi/kernels/uniform_random_kernel.cc b/paddle/phi/kernels/uniform_random_kernel.cc
new file mode 100644
index 00000000000..11f61e5b4a0
--- /dev/null
+++ b/paddle/phi/kernels/uniform_random_kernel.cc
@@ -0,0 +1,61 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/uniform_random_kernel.h"
+
+#include "paddle/phi/common/int_array.h"
+#include "paddle/phi/common/scalar.h"
+#include "paddle/phi/core/dense_tensor.h"
+#include "paddle/phi/core/device_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
+#include "paddle/phi/backends/gpu/gpu_context.h"
+#endif
+#ifdef PADDLE_WITH_XPU
+#include "paddle/phi/backends/xpu/xpu_context.h"
+#endif
+
+namespace phi {
+
+template <typename T, typename Context>
+void UniformRandomKernel(const Context& dev_ctx,
+                         const IntArray& shape,
+                         DataType dtype,
+                         const Scalar& min,
+                         const Scalar& max,
+                         int seed,
+                         DenseTensor* out) {
+  UniformRandomRawKernel<T>(
+      dev_ctx, shape, dtype, min, max, seed, 0, 0, 0.0f, out);
+}
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(uniform_random,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::UniformRandomKernel,
+                   float,
+                   double,
+                   phi::dtype::bfloat16) {}
+
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
+PD_REGISTER_KERNEL(
+    uniform_random, GPU, ALL_LAYOUT, phi::UniformRandomKernel, float, double) {}
+#endif
+
+#ifdef PADDLE_WITH_XPU
+PD_REGISTER_KERNEL(
+    uniform_random, XPU, ALL_LAYOUT, phi::UniformRandomKernel, float) {}
+#endif
diff --git a/paddle/phi/kernels/xpu/truncated_gaussian_random_kernel.cc b/paddle/phi/kernels/xpu/truncated_gaussian_random_kernel.cc
index 25a19d11ef5..38aa8eb67bf 100644
--- a/paddle/phi/kernels/xpu/truncated_gaussian_random_kernel.cc
+++ b/paddle/phi/kernels/xpu/truncated_gaussian_random_kernel.cc
@@ -18,7 +18,6 @@ limitations under the License. */
 #include
 
 #include "paddle/fluid/memory/memcpy.h"
-#include "paddle/phi/backends/cpu/cpu_context.h"
 #include "paddle/phi/backends/xpu/xpu_context.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/funcs/truncated_normal.h"
diff --git a/paddle/phi/kernels/xpu/uniform_random_kernel.cc b/paddle/phi/kernels/xpu/uniform_random_kernel.cc
new file mode 100644
index 00000000000..3bc346ab957
--- /dev/null
+++ b/paddle/phi/kernels/xpu/uniform_random_kernel.cc
@@ -0,0 +1,80 @@
+/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/phi/kernels/uniform_random_kernel.h"
+
+#include
+
+#include "paddle/fluid/memory/memcpy.h"
+#include "paddle/phi/backends/xpu/xpu_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/funcs/uniform_real_distribution.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void UniformRandomRawKernel(const Context &dev_ctx,
+                            const IntArray &shape,
+                            DataType dtype,
+                            const Scalar &min,
+                            const Scalar &max,
+                            int seed,
+                            int diag_num,
+                            int diag_step,
+                            float diag_val,
+                            DenseTensor *out) {
+  out->Resize(phi::make_ddim(shape.GetData()));
+  T *data = dev_ctx.template Alloc<T>(out);
+  int64_t size = out->numel();
+
+  std::unique_ptr<T[]> data_cpu(new T[size]);
+
+  std::shared_ptr<std::mt19937_64> engine;
+  if (seed) {
+    engine = std::make_shared<std::mt19937_64>();
+    engine->seed(seed);
+  } else {
+    engine = dev_ctx.GetGenerator()->GetCPUEngine();
+  }
+  UniformRealDistribution<T>(
+      data_cpu.get(), size, min.to<float>(), max.to<float>(), engine);
+  if (diag_num > 0) {
+    PADDLE_ENFORCE_GT(
+        size,
+        (diag_num - 1) * (diag_step + 1),
+        phi::errors::InvalidArgument(
+            "ShapeInvalid: the diagonal's elements is equal (num-1) "
+            "* (step-1) with num %d, step %d,"
+            "It should be smaller than %d, but received %d",
+            diag_num,
+            diag_step,
+            (diag_num - 1) * (diag_step + 1),
+            size));
+    for (int64_t i = 0; i < diag_num; ++i) {
+      int64_t pos = i * diag_step + i;
+      data_cpu[pos] = diag_val;
+    }
+  }
+
+  paddle::memory::Copy(dev_ctx.GetPlace(),
+                       data,
+                       phi::CPUPlace(),
+                       reinterpret_cast<void *>(data_cpu.get()),
+                       size * sizeof(T));
+}
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(
+    uniform_random_raw, XPU, ALL_LAYOUT, phi::UniformRandomRawKernel, float) {}
-- 
GitLab