diff --git a/paddle/fluid/operators/randint_op.cc b/paddle/fluid/operators/randint_op.cc index 09c58cd7d4cda396d60a94b02cc8a705bb3c3b01..548e28716dd9108ffd55463cccf9f91ad3b9a941 100644 --- a/paddle/fluid/operators/randint_op.cc +++ b/paddle/fluid/operators/randint_op.cc @@ -24,37 +24,6 @@ namespace paddle { namespace operators { -template -class CPURandintKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - std::vector new_shape; - auto list_new_shape_tensor = - ctx.MultiInput("ShapeTensorList"); - if (list_new_shape_tensor.size() > 0 || ctx.HasInput("ShapeTensor")) { - if (ctx.HasInput("ShapeTensor")) { - auto* shape_tensor = ctx.Input("ShapeTensor"); - new_shape = GetNewDataFromShapeTensor(shape_tensor); - } else if (list_new_shape_tensor.size() > 0) { - new_shape = GetNewDataFromShapeTensorList(list_new_shape_tensor); - } - } - auto* out = ctx.Output("Out"); - if (!new_shape.empty()) out->Resize(phi::make_ddim(new_shape)); - T* data = out->mutable_data(ctx.GetPlace()); - int64_t size = out->numel(); - - std::uniform_int_distribution dist(ctx.Attr("low"), - ctx.Attr("high") - 1); - unsigned int seed = static_cast(ctx.Attr("seed")); - auto engine = framework::GetCPURandomEngine(seed); - - for (int64_t i = 0; i < size; ++i) { - data[i] = dist(*engine); - } - } -}; - class RandintOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; @@ -176,6 +145,3 @@ REGISTER_OPERATOR( randint, ops::RandintOp, ops::RandintOpMaker, paddle::framework::EmptyGradOpMaker, paddle::framework::EmptyGradOpMaker) - -REGISTER_OP_CPU_KERNEL(randint, ops::CPURandintKernel, - ops::CPURandintKernel) diff --git a/paddle/fluid/operators/randint_op.cu b/paddle/fluid/operators/randint_op.cu deleted file mode 100644 index 2f9a8cfd142ec7a3d0175b91bd79f239f654c126..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/randint_op.cu +++ /dev/null @@ -1,84 +0,0 @@ -// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -#include -#include -#include "paddle/fluid/framework/convert_utils.h" -#include "paddle/fluid/framework/generator.h" -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/operators/uniform_random_op.h" - -namespace paddle { -namespace operators { - -template -class GPURandintKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - std::vector new_shape; - auto list_new_shape_tensor = - context.MultiInput("ShapeTensorList"); - if (list_new_shape_tensor.size() > 0 || context.HasInput("ShapeTensor")) { - if (context.HasInput("ShapeTensor")) { - auto* shape_tensor = context.Input("ShapeTensor"); - new_shape = GetNewDataFromShapeTensor(shape_tensor); - } else if (list_new_shape_tensor.size() > 0) { - new_shape = GetNewDataFromShapeTensorList(list_new_shape_tensor); - } - } - - platform::CPUPlace cpu; - auto dtype = static_cast( - context.Attr("dtype")); - auto* out = context.Output("Out"); - if (!new_shape.empty()) out->Resize(phi::make_ddim(new_shape)); - T low = static_cast(context.Attr("low")); - T high = static_cast(context.Attr("high")) - 1; - framework::LoDTensor tensor; - tensor.Resize(out->dims()); - tensor.mutable_data(cpu, framework::TransToPtenDataType(dtype)); - T* data = tensor.mutable_data(cpu); - - int64_t size = out->numel(); - unsigned int seed = static_cast(context.Attr("seed")); - - /* - std::minstd_rand engine; - if (seed == 0) { - std::random_device rd; - seed = rd(); - } - engine.seed(seed); - */ - - std::uniform_int_distribution<> dist(context.Attr("low"), - context.Attr("high") - 1); - auto engine = framework::GetCPURandomEngine(seed); - - for (int64_t i = 0; i < size; ++i) { - data[i] = dist(*engine); - } - - if (platform::is_gpu_place(context.GetPlace())) { - // Copy tensor to out - framework::TensorCopy(tensor, context.GetPlace(), out); - } - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -REGISTER_OP_CUDA_KERNEL(randint, ops::GPURandintKernel, - ops::GPURandintKernel) diff --git a/paddle/phi/kernels/cpu/randint_kernel.cc b/paddle/phi/kernels/cpu/randint_kernel.cc new file mode 100644 index 0000000000000000000000000000000000000000..5fe56b57452d5a1afd718d85a038310e27f0ff50 --- /dev/null +++ b/paddle/phi/kernels/cpu/randint_kernel.cc @@ -0,0 +1,63 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/randint_kernel.h" + +#include + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { + +template +void RandintRawKernel(const Context& ctx, + int low, + int high, + const ScalarArray& shape, + DataType dtype, + int seed, + DenseTensor* out) { + out->ResizeAndAllocate(phi::make_ddim(shape.GetData())); + auto size = out->numel(); + std::shared_ptr engine; + if (seed) { + engine = std::make_shared(); + engine->seed(seed); + } else { + engine = ctx.GetGenerator()->GetCPUEngine(); + } + std::uniform_int_distribution dist(low, high - 1); + auto data = out->data(); + for (int64_t i = 0; i < size; ++i) { + data[i] = dist(*engine); + } +} + +template +void RandintKernel(const Context& ctx, + int low, + int high, + const ScalarArray& shape, + DataType dtype, + DenseTensor* out) { + RandintRawKernel(ctx, low, high, shape, dtype, 0, out); +} + +} // namespace phi + +PD_REGISTER_KERNEL( + randint_raw, CPU, ALL_LAYOUT, phi::RandintRawKernel, int, int64_t) {} +PD_REGISTER_KERNEL(randint, CPU, ALL_LAYOUT, phi::RandintKernel, int, int64_t) { +} diff --git a/paddle/phi/kernels/gpu/randint_kernel.cu b/paddle/phi/kernels/gpu/randint_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..b89b714c73d92c896cde3beef182703018b6aa12 --- /dev/null +++ b/paddle/phi/kernels/gpu/randint_kernel.cu @@ -0,0 +1,80 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/randint_kernel.h" + +#include + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" + +// See Note [ Why still include the fluid headers? ] +#include "paddle/fluid/memory/memcpy.h" + +namespace phi { + +template +void RandintRawKernel(const Context& ctx, + int low, + int high, + const ScalarArray& shape, + DataType dtype, + int seed, + DenseTensor* out) { + DenseTensor tmp; + tmp.Resize(phi::make_ddim(shape.GetData())); + T* tmp_data = ctx.template HostAlloc(&tmp); + + out->ResizeAndAllocate(tmp.dims()); + auto size = out->numel(); + + std::shared_ptr engine; + if (seed) { + engine = std::make_shared(); + engine->seed(seed); + } else { + engine = ctx.GetHostGenerator()->GetCPUEngine(); + } + std::uniform_int_distribution dist(low, high - 1); + auto data = out->data(); + for (int64_t i = 0; i < size; ++i) { + tmp_data[i] = dist(*engine); + } + + paddle::memory::Copy( + out->place(), + data, + tmp.place(), + tmp_data, + size * paddle::experimental::SizeOf(out->dtype()), + 0); +} + +template +void RandintKernel(const Context& ctx, + int low, + int high, + const ScalarArray& shape, + DataType dtype, + DenseTensor* out) { + RandintRawKernel(ctx, low, high, shape, dtype, 0, out); +} + +} // namespace phi + +PD_REGISTER_KERNEL( + randint_raw, GPU, ALL_LAYOUT, phi::RandintRawKernel, int, int64_t) {} + +PD_REGISTER_KERNEL(randint, GPU, ALL_LAYOUT, phi::RandintKernel, int, int64_t) { +} diff --git a/paddle/phi/kernels/randint_kernel.h b/paddle/phi/kernels/randint_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..1a78e73d863e33619d10b72bf9e368aba4c856c5 --- /dev/null +++ b/paddle/phi/kernels/randint_kernel.h @@ -0,0 +1,39 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/common/scalar_array.h" +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void RandintKernel(const Context& ctx, + int low, + int high, + const ScalarArray& shape, + DataType dtype, + DenseTensor* out); + +template +void RandintRawKernel(const Context& ctx, + int low, + int high, + const ScalarArray& shape, + DataType dtype, + int seed, + DenseTensor* out); + +} // namespace phi diff --git a/paddle/phi/ops/compat/randint_sig.cc b/paddle/phi/ops/compat/randint_sig.cc new file mode 100644 index 0000000000000000000000000000000000000000..eb6da78a258bc415b54fd128655bae422b3b711c --- /dev/null +++ b/paddle/phi/ops/compat/randint_sig.cc @@ -0,0 +1,63 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/phi/core/compat/op_utils.h" + +namespace phi { + +KernelSignature RandintOpArgumentMapping(const ArgumentMappingContext& ctx) { + int seed = paddle::any_cast(ctx.Attr("seed")); + if (seed) { + if (ctx.InputSize("ShapeTensorList") > 0) { + return KernelSignature( + "randint_raw", + {}, + {"low", "high", "ShapeTensorList", "seed", "dtype"}, + {"Out"}); + } else { + const auto& shape = + paddle::any_cast>(ctx.Attr("shape")); + if (ctx.HasInput("ShapeTensor") && shape.empty()) { + return KernelSignature("randint_raw", + {}, + {"low", "high", "ShapeTensor", "seed", "dtype"}, + {"Out"}); + } else { + return KernelSignature("randint_raw", + {}, + {"low", "high", "shape", "seed", "dtype"}, + {"Out"}); + } + } + } else { + if (ctx.InputSize("ShapeTensorList") > 0) { + return KernelSignature( + "randint", {}, {"low", "high", "ShapeTensorList", "dtype"}, {"Out"}); + } else { + const auto& shape = + paddle::any_cast>(ctx.Attr("shape")); + if (ctx.HasInput("ShapeTensor") && shape.empty()) { + return KernelSignature( + "randint", {}, {"low", "high", "ShapeTensor", "dtype"}, {"Out"}); + } else { + return KernelSignature( + "randint", {}, {"low", "high", "shape", "dtype"}, {"Out"}); + } + } + } +} + +} // namespace phi + +PD_REGISTER_ARG_MAPPING_FN(randint, phi::RandintOpArgumentMapping);