From 60e1eccbddd562eca7489fa21a61066f43f42cca Mon Sep 17 00:00:00 2001 From: wanghuancoder Date: Mon, 29 Aug 2022 15:50:03 +0800 Subject: [PATCH] [Phi] gather gather_grad gather_nd gaussian_random xpu to Phi (#45465) * gather gather_grad gather_nd gaussian_random xpu to phi --- paddle/fluid/operators/gather_nd_op_xpu.cc | 96 -------- paddle/fluid/operators/gather_op_xpu.cc | 228 ------------------ .../fluid/operators/gaussian_random_op_xpu.cc | 58 ----- paddle/phi/kernels/xpu/gather_grad_kernel.cc | 111 +++++++++ paddle/phi/kernels/xpu/gather_kernel.cc | 86 +++++++ paddle/phi/kernels/xpu/gather_nd_kernel.cc | 83 +++++++ .../phi/kernels/xpu/gaussian_random_kernel.cc | 55 +++++ 7 files changed, 335 insertions(+), 382 deletions(-) delete mode 100644 paddle/fluid/operators/gather_nd_op_xpu.cc delete mode 100644 paddle/fluid/operators/gather_op_xpu.cc delete mode 100644 paddle/fluid/operators/gaussian_random_op_xpu.cc create mode 100644 paddle/phi/kernels/xpu/gather_grad_kernel.cc create mode 100644 paddle/phi/kernels/xpu/gather_kernel.cc create mode 100644 paddle/phi/kernels/xpu/gather_nd_kernel.cc create mode 100644 paddle/phi/kernels/xpu/gaussian_random_kernel.cc diff --git a/paddle/fluid/operators/gather_nd_op_xpu.cc b/paddle/fluid/operators/gather_nd_op_xpu.cc deleted file mode 100644 index ab7e0bc1ad..0000000000 --- a/paddle/fluid/operators/gather_nd_op_xpu.cc +++ /dev/null @@ -1,96 +0,0 @@ -/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#ifdef PADDLE_WITH_XPU - -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/framework/operator.h" -#include "paddle/fluid/framework/tensor.h" -#include "paddle/fluid/platform/device_context.h" - -namespace paddle { -namespace operators { - -template -class GatherNdXPUKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext &ctx) const override { - auto *x = ctx.Input("X"); - auto *index = ctx.Input("Index"); - auto *out = ctx.Output("Out"); - - out->template mutable_data(ctx.GetPlace()); - if (x->numel() == 0) return; - - if (index->numel() == 0) { - framework::TensorCopy(*x, ctx.GetPlace(), ctx.device_context(), out); - return; - } - - const auto &index_type = framework::TransToProtoVarType(index->dtype()); - bool index_type_match = index_type == framework::proto::VarType::INT32 || - index_type == framework::proto::VarType::INT64; - PADDLE_ENFORCE_EQ(index_type_match, - true, - platform::errors::InvalidArgument( - "Index holds the wrong type, it holds [%s]," - "but desires to be [%s] or [%s]", - paddle::framework::DataTypeToString(index_type), - paddle::framework::DataTypeToString( - framework::proto::VarType::INT32), - paddle::framework::DataTypeToString( - framework::proto::VarType::INT64))); - - auto x_shape = phi::vectorize(x->dims()); - auto index_shape = phi::vectorize(index->dims()); - if (index_shape.size() == 1) { - index_shape.insert(index_shape.begin(), 1); - } - xpu::VectorParam x_vec = { - x_shape.data(), static_cast(x_shape.size()), nullptr}; - - auto &dev_ctx = - ctx.template device_context(); - int ret = XPU_SUCCESS; - if (index_type == framework::proto::VarType::INT32) { - ret = xpu::gather_nd(dev_ctx.x_context(), - x->data(), - index->data(), - out->data(), - x_vec, - index_shape); - } else { - ret = xpu::gather_nd(dev_ctx.x_context(), - x->data(), - index->data(), - out->data(), - x_vec, - index_shape); - } - PADDLE_ENFORCE_EQ(ret, - XPU_SUCCESS, - platform::errors::External( - "XPU gather_nd kernel return wrong value[%d %s]", - ret, - XPUAPIErrorMsg[ret])); - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -REGISTER_OP_XPU_KERNEL(gather_nd, - ops::GatherNdXPUKernel, - ops::GatherNdXPUKernel, - ops::GatherNdXPUKernel); - -#endif diff --git a/paddle/fluid/operators/gather_op_xpu.cc b/paddle/fluid/operators/gather_op_xpu.cc deleted file mode 100644 index 9a3cdc8def..0000000000 --- a/paddle/fluid/operators/gather_op_xpu.cc +++ /dev/null @@ -1,228 +0,0 @@ -/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#ifdef PADDLE_WITH_XPU -#include -#include -#include - -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/framework/op_version_registry.h" -#include "paddle/phi/core/ddim.h" -namespace paddle { -namespace operators { - -using Tensor = framework::Tensor; - -template -class GatherOpXPUKernel : public framework::OpKernel { - using XPUType = typename XPUTypeTrait::Type; - - public: - void Compute(const framework::ExecutionContext &ctx) const override { - PADDLE_ENFORCE_EQ( - platform::is_xpu_place(ctx.GetPlace()), - true, - platform::errors::PreconditionNotMet("This kernel only runs on XPU.")); - - auto *x = ctx.Input("X"); - auto *index = ctx.Input("Index"); - auto *output = ctx.Output("Out"); - - int axis = ctx.Attr("axis"); - if (ctx.HasInput("Axis")) { - Tensor cpu_axis; - const Tensor *axis_tensor = ctx.Input("Axis"); - framework::TensorCopy(*axis_tensor, platform::CPUPlace(), &cpu_axis); - const auto &axis_type = axis_tensor->dtype(); - if (framework::TransToProtoVarType(axis_type) == - framework::proto::VarType::INT32) { - axis = static_cast(cpu_axis.data()[0]); - } else if (framework::TransToProtoVarType(axis_type) == - framework::proto::VarType::INT64) { - axis = static_cast(cpu_axis.data()[0]); - } - } - - output->mutable_data(ctx.GetPlace()); - if (x->numel() == 0) return; - - const auto index_dims = index->dims(); - if (index_dims.size() == 2) { - PADDLE_ENFORCE_EQ( - index_dims[1], - 1, - platform::errors::InvalidArgument( - "The last dim of index should be 1 when it is 2D, but we get %d", - index_dims[1])); - } else { - PADDLE_ENFORCE_EQ( - index_dims.size(), - 1, - platform::errors::InvalidArgument( - "The index should be 1D, when it is not 2D, but we get %d", - index_dims.size())); - } - std::vector xshape(x->dims().size()); - for (int i = 0; i < x->dims().size(); ++i) { - xshape[i] = x->dims()[i]; - } - - auto &dev_ctx = ctx.template device_context(); - int r = XPU_SUCCESS; - if (framework::TransToProtoVarType(index->dtype()) == - framework::proto::VarType::INT32) { - r = xpu::gather( - dev_ctx.x_context(), - reinterpret_cast(x->data()), - index->data(), - reinterpret_cast(output->data()), - xshape, - index->dims()[0], - axis); - } else { - r = xpu::gather( - dev_ctx.x_context(), - reinterpret_cast(x->data()), - index->data(), - reinterpret_cast(output->data()), - xshape, - index->dims()[0], - axis); - } - PADDLE_ENFORCE_EQ(r, - xpu::Error_t::SUCCESS, - platform::errors::External( - "XPU gather kernel return wrong value[%d %s]", - r, - XPUAPIErrorMsg[r])); - } -}; - -template -class GatherGradOpXPUKernel : public framework::OpKernel { - using XPUType = typename XPUTypeTrait::Type; - - public: - void Compute(const framework::ExecutionContext &ctx) const override { - PADDLE_ENFORCE_EQ( - platform::is_xpu_place(ctx.GetPlace()), - true, - platform::errors::PreconditionNotMet("This kernel only runs on XPU.")); - - auto *index = ctx.Input("Index"); - auto *dx = ctx.Output(framework::GradVarName("X")); - auto *dout = ctx.Input(framework::GradVarName("Out")); - auto &dev_ctx = ctx.template device_context(); - - int axis = ctx.Attr("axis"); - if (ctx.HasInput("Axis")) { - Tensor cpu_axis; - const Tensor *axis_tensor = ctx.Input("Axis"); - framework::TensorCopy(*axis_tensor, platform::CPUPlace(), &cpu_axis); - const auto &axis_type = axis_tensor->dtype(); - if (framework::TransToProtoVarType(axis_type) == - framework::proto::VarType::INT32) { - axis = static_cast(cpu_axis.data()[0]); - } else if (framework::TransToProtoVarType(axis_type) == - framework::proto::VarType::INT64) { - axis = static_cast(cpu_axis.data()[0]); - } - } - if (dout->numel() == 0) { - return; - } - - bool overwrite = ctx.Attr("overwrite"); - const auto index_dims = index->dims(); - if (index_dims.size() == 2) { - PADDLE_ENFORCE_EQ( - index_dims[1], - 1, - platform::errors::InvalidArgument( - "The last dim of index should be 1 when it is 2D, but we get %d", - index_dims[1])); - } else { - PADDLE_ENFORCE_EQ( - index_dims.size(), - 1, - platform::errors::InvalidArgument( - "The index should be 1D, when it is not 2D, but we get %d", - index_dims.size())); - } - std::vector xshape(dx->dims().size()); - for (int i = 0; i < dx->dims().size(); ++i) { - xshape[i] = dx->dims()[i]; - } - - dx->mutable_data(ctx.GetPlace()); - - int r = XPU_SUCCESS; - if (framework::TransToProtoVarType(index->dtype()) == - framework::proto::VarType::INT32) { - r = xpu::gather_grad( - dev_ctx.x_context(), - reinterpret_cast(dout->data()), - index->data(), - reinterpret_cast(dx->data()), - xshape, - index->dims()[0], - axis, - overwrite); - } else { - xpu::ctx_guard RAII_GUARD(dev_ctx.x_context()); - int *index_int_ptr_l3 = - RAII_GUARD.alloc_l3_or_gm(index->numel()); - r = xpu::cast_v2(dev_ctx.x_context(), - index->data(), - index_int_ptr_l3, - index->numel()); - PADDLE_ENFORCE_EQ( - r, - XPU_SUCCESS, - platform::errors::External("XPU API(cast_v2) return wrong " - "value[%d %s]", - r, - XPUAPIErrorMsg[r])); - - r = xpu::gather_grad( - dev_ctx.x_context(), - reinterpret_cast(dout->data()), - index_int_ptr_l3, - reinterpret_cast(dx->data()), - xshape, - index->dims()[0], - axis, - overwrite); - } - PADDLE_ENFORCE_EQ(r, - xpu::Error_t::SUCCESS, - platform::errors::External( - "XPU gather grad kernel return wrong value[%d %s]", - r, - XPUAPIErrorMsg[r])); - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -REGISTER_OP_XPU_KERNEL(gather, - ops::GatherOpXPUKernel, - ops::GatherOpXPUKernel); -REGISTER_OP_XPU_KERNEL(gather_grad, - ops::GatherGradOpXPUKernel, - ops::GatherGradOpXPUKernel); -#endif diff --git a/paddle/fluid/operators/gaussian_random_op_xpu.cc b/paddle/fluid/operators/gaussian_random_op_xpu.cc deleted file mode 100644 index a8d6081915..0000000000 --- a/paddle/fluid/operators/gaussian_random_op_xpu.cc +++ /dev/null @@ -1,58 +0,0 @@ -/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#ifdef PADDLE_WITH_XPU - -#include - -#include "paddle/fluid/framework/generator.h" -#include "paddle/fluid/framework/op_registry.h" - -namespace paddle { -namespace operators { - -template -class XPUGaussianRandomKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - float mean = context.Attr("mean"); - float std = context.Attr("std"); - auto* tensor = context.Output("Out"); - - std::normal_distribution dist(mean, std); - int64_t size = tensor->numel(); - T* data = tensor->mutable_data(context.GetPlace()); - unsigned int seed = static_cast(context.Attr("seed")); - // TODO(pangyoki): implement GetXPURandomEngine to set different seeds on - // corresponding XPU device. - auto engine = framework::GetCPURandomEngine(seed); - - std::unique_ptr data_cpu(new T[size]); - for (int64_t i = 0; i < size; ++i) { - data_cpu[i] = dist(*engine); - } - memory::Copy(context.GetPlace(), - data, - platform::CPUPlace(), - reinterpret_cast(data_cpu.get()), - size * sizeof(T)); - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -REGISTER_OP_XPU_KERNEL(gaussian_random, ops::XPUGaussianRandomKernel); -#endif diff --git a/paddle/phi/kernels/xpu/gather_grad_kernel.cc b/paddle/phi/kernels/xpu/gather_grad_kernel.cc new file mode 100644 index 0000000000..b1c1731130 --- /dev/null +++ b/paddle/phi/kernels/xpu/gather_grad_kernel.cc @@ -0,0 +1,111 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/gather_kernel.h" + +#include "paddle/phi/backends/xpu/enforce_xpu.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { + +template +void GatherGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& index, + const DenseTensor& out_grad, + const Scalar& axis, + bool overwrite, + DenseTensor* x_grad) { + auto axis_v = axis.to(); + const auto& index_type = index.dtype(); + + if (out_grad.numel() == 0) { + return; + } + + const auto index_dims = index.dims(); + if (index_dims.size() == 2) { + PADDLE_ENFORCE_EQ( + index_dims[1], + 1, + phi::errors::InvalidArgument( + "The last dim of index should be 1 when it is 2D, but we get %d", + index_dims[1])); + } else { + PADDLE_ENFORCE_EQ( + index_dims.size(), + 1, + phi::errors::InvalidArgument( + "The index should be 1D, when it is not 2D, but we get %d", + index_dims.size())); + } + std::vector xshape(x_grad->dims().size()); + for (int i = 0; i < x_grad->dims().size(); ++i) { + xshape[i] = x_grad->dims()[i]; + } + + dev_ctx.template Alloc(x_grad); + using XPUType = typename XPUTypeTrait::Type; + + int r = XPU_SUCCESS; + if (index_type == DataType::INT32) { + r = xpu::gather_grad( + dev_ctx.x_context(), + reinterpret_cast(out_grad.data()), + index.data(), + reinterpret_cast(x_grad->data()), + xshape, + index.dims()[0], + axis_v, + overwrite); + } else { + xpu::ctx_guard RAII_GUARD(dev_ctx.x_context()); + int* index_int_ptr_l3 = RAII_GUARD.alloc_l3_or_gm(index.numel()); + r = xpu::cast_v2(dev_ctx.x_context(), + index.data(), + index_int_ptr_l3, + index.numel()); + PADDLE_ENFORCE_EQ(r, + XPU_SUCCESS, + phi::errors::External("XPU API(cast_v2) return wrong " + "value[%d %s]", + r, + XPUAPIErrorMsg[r])); + + r = xpu::gather_grad( + dev_ctx.x_context(), + reinterpret_cast(out_grad.data()), + index_int_ptr_l3, + reinterpret_cast(x_grad->data()), + xshape, + index.dims()[0], + axis_v, + overwrite); + } + PADDLE_ENFORCE_EQ( + r, + xpu::Error_t::SUCCESS, + phi::errors::External("XPU gather grad kernel return wrong value[%d %s]", + r, + XPUAPIErrorMsg[r])); +} + +} // namespace phi + +PD_REGISTER_KERNEL(gather_grad, + XPU, + ALL_LAYOUT, + phi::GatherGradKernel, + float, + phi::dtype::float16) {} diff --git a/paddle/phi/kernels/xpu/gather_kernel.cc b/paddle/phi/kernels/xpu/gather_kernel.cc new file mode 100644 index 0000000000..c3520178d1 --- /dev/null +++ b/paddle/phi/kernels/xpu/gather_kernel.cc @@ -0,0 +1,86 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/gather_kernel.h" + +#include "paddle/phi/backends/xpu/enforce_xpu.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { + +template +void GatherKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& index, + const Scalar& axis, + DenseTensor* out) { + auto axis_v = axis.to(); + const auto& index_type = index.dtype(); + + dev_ctx.template Alloc(out); + if (x.numel() == 0) return; + + const auto index_dims = index.dims(); + if (index_dims.size() == 2) { + PADDLE_ENFORCE_EQ( + index_dims[1], + 1, + phi::errors::InvalidArgument( + "The last dim of index should be 1 when it is 2D, but we get %d", + index_dims[1])); + } else { + PADDLE_ENFORCE_EQ( + index_dims.size(), + 1, + phi::errors::InvalidArgument( + "The index should be 1D, when it is not 2D, but we get %d", + index_dims.size())); + } + std::vector xshape(x.dims().size()); + for (int i = 0; i < x.dims().size(); ++i) { + xshape[i] = x.dims()[i]; + } + + using XPUType = typename XPUTypeTrait::Type; + + int r = XPU_SUCCESS; + if (index_type == DataType::INT32) { + r = xpu::gather(dev_ctx.x_context(), + reinterpret_cast(x.data()), + index.data(), + reinterpret_cast(out->data()), + xshape, + index.dims()[0], + axis_v); + } else { + r = xpu::gather( + dev_ctx.x_context(), + reinterpret_cast(x.data()), + index.data(), + reinterpret_cast(out->data()), + xshape, + index.dims()[0], + axis_v); + } + PADDLE_ENFORCE_EQ( + r, + xpu::Error_t::SUCCESS, + phi::errors::External( + "XPU gather kernel return wrong value[%d %s]", r, XPUAPIErrorMsg[r])); +} + +} // namespace phi + +PD_REGISTER_KERNEL( + gather, XPU, ALL_LAYOUT, phi::GatherKernel, float, phi::dtype::float16) {} diff --git a/paddle/phi/kernels/xpu/gather_nd_kernel.cc b/paddle/phi/kernels/xpu/gather_nd_kernel.cc new file mode 100644 index 0000000000..d7d23fa17c --- /dev/null +++ b/paddle/phi/kernels/xpu/gather_nd_kernel.cc @@ -0,0 +1,83 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/gather_nd_kernel.h" + +#include "paddle/phi/backends/xpu/enforce_xpu.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { + +template +void GatherNdKernel(const Context &ctx, + const DenseTensor &x, + const DenseTensor &index, + DenseTensor *out) { + ctx.template Alloc(out); + const auto &index_type = index.dtype(); + + if (x.numel() == 0) return; + + if (index.numel() == 0) { + phi::Copy(ctx, x, phi::XPUPlace(), true, out); + return; + } + + bool index_type_match = + index_type == DataType::INT32 || index_type == DataType::INT64; + PADDLE_ENFORCE_EQ( + index_type_match, + true, + phi::errors::InvalidArgument("Index holds the wrong type, it holds [%s]," + "but desires to be [%s] or [%s]", + index_type, + DataType::INT32, + DataType::INT64)); + + auto x_shape = phi::vectorize(x.dims()); + auto index_shape = phi::vectorize(index.dims()); + if (index_shape.size() == 1) { + index_shape.insert(index_shape.begin(), 1); + } + xpu::VectorParam x_vec = { + x_shape.data(), static_cast(x_shape.size()), nullptr}; + + int ret = XPU_SUCCESS; + if (index_type == DataType::INT32) { + ret = xpu::gather_nd(ctx.x_context(), + x.data(), + index.data(), + out->data(), + x_vec, + index_shape); + } else { + ret = xpu::gather_nd(ctx.x_context(), + x.data(), + index.data(), + out->data(), + x_vec, + index_shape); + } + PADDLE_ENFORCE_EQ( + ret, + XPU_SUCCESS, + phi::errors::External("XPU gather_nd kernel return wrong value[%d %s]", + ret, + XPUAPIErrorMsg[ret])); +} + +} // namespace phi + +PD_REGISTER_KERNEL( + gather_nd, XPU, ALL_LAYOUT, phi::GatherNdKernel, float, int64_t, int) {} diff --git a/paddle/phi/kernels/xpu/gaussian_random_kernel.cc b/paddle/phi/kernels/xpu/gaussian_random_kernel.cc new file mode 100644 index 0000000000..ee216e7588 --- /dev/null +++ b/paddle/phi/kernels/xpu/gaussian_random_kernel.cc @@ -0,0 +1,55 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/gaussian_random_kernel.h" + +#include "paddle/fluid/framework/generator.h" +#include "paddle/fluid/memory/memcpy.h" +#include "paddle/phi/backends/xpu/enforce_xpu.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { + +template +void GaussianRandomKernel(const Context& ctx, + const IntArray& shape, + float mean, + float std, + int seed, + DataType dtype, + DenseTensor* out) { + std::normal_distribution dist(mean, std); + int64_t size = out->numel(); + ctx.template Alloc(out); + auto* data = out->data(); + uint64_t seed_v = static_cast(seed); + // TODO(pangyoki): implement GetXPURandomEngine to set different seeds on + // corresponding XPU device. + auto engine = paddle::framework::GetCPURandomEngine(seed_v); + + std::unique_ptr data_cpu(new T[size]); + for (int64_t i = 0; i < size; ++i) { + data_cpu[i] = dist(*engine); + } + paddle::memory::Copy(phi::XPUPlace(), + data, + phi::CPUPlace(), + reinterpret_cast(data_cpu.get()), + size * sizeof(T)); +} + +} // namespace phi + +PD_REGISTER_KERNEL( + gaussian_random, XPU, ALL_LAYOUT, phi::GaussianRandomKernel, float) {} -- GitLab