diff --git a/paddle/fluid/operators/randperm_op.cc b/paddle/fluid/operators/randperm_op.cc
index bdc2ea0b5bfbbfc45f02d4df3a7cf1dbae25bacf..1b28ab3c133f7d57250e3357b0d732603719ef99 100644
--- a/paddle/fluid/operators/randperm_op.cc
+++ b/paddle/fluid/operators/randperm_op.cc
@@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#include "paddle/fluid/operators/randperm_op.h"
 #include <string>
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/operator.h"
@@ -89,10 +88,3 @@ REGISTER_OPERATOR(
     paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
     paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
     paddle::operators::RandpermOpVarTypeInference);
-
-template <typename T>
-using kernel =
-    paddle::operators::RandpermKernel<paddle::platform::CPUDeviceContext, T>;
-
-REGISTER_OP_CPU_KERNEL(randperm, kernel<float>, kernel<double>, kernel<int>,
-                       kernel<int64_t>);
diff --git a/paddle/fluid/operators/randperm_op.cu b/paddle/fluid/operators/randperm_op.cu
deleted file mode 100644
index 7ed52a8fd25b104f50446082ff3a040e90bf44ea..0000000000000000000000000000000000000000
--- a/paddle/fluid/operators/randperm_op.cu
+++ /dev/null
@@ -1,24 +0,0 @@
-/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/framework/operator.h"
-#include "paddle/fluid/operators/randperm_op.h"
-
-template <typename T>
-using kernel =
-    paddle::operators::RandpermKernel<paddle::platform::CUDADeviceContext, T>;
-
-REGISTER_OP_CUDA_KERNEL(randperm, kernel<float>, kernel<double>, kernel<int>,
-                        kernel<int64_t>);
diff --git a/paddle/fluid/platform/device_context.cc b/paddle/fluid/platform/device_context.cc
index 4282ec20623c93b331ef21a4edea0d350cfe85c8..6a7956628f80464740e3cd812b0b663cc36d6fc6 100644
--- a/paddle/fluid/platform/device_context.cc
+++ b/paddle/fluid/platform/device_context.cc
@@ -172,6 +172,7 @@ inline void EmplaceDeviceContext(
             .get());
     dev_ctx->SetGenerator(framework::DefaultCPUGenerator().get());
   }
+  dev_ctx->SetHostGenerator(framework::DefaultCPUGenerator().get());
   dev_ctx->SetHostAllocator(
       memory::allocation::AllocatorFacade::Instance()
           .GetAllocator(platform::CPUPlace())
diff --git a/paddle/phi/core/device_context.cc b/paddle/phi/core/device_context.cc
index c3e0d2a75228b3211e5d76f95c2f8ff8089b6415..9c1d85251f8926141341ee6b8c15e29164894ee7 100644
--- a/paddle/phi/core/device_context.cc
+++ b/paddle/phi/core/device_context.cc
@@ -119,22 +119,39 @@ struct DeviceContext::Impl {
         gen,
         phi::errors::InvalidArgument(
             "Required generator shall not be nullptr, but received nullptr."));
-    generator_ = gen;
+    device_generator_ = gen;
   }
 
   Generator* GetGenerator() const {
     PADDLE_ENFORCE_NOT_NULL(
-        generator_,
+        device_generator_,
         phi::errors::InvalidArgument("Required generator_ shall not be "
                                      "nullptr, but received nullptr."));
-    return generator_;
+    return device_generator_;
+  }
+
+  void SetHostGenerator(Generator* gen) {
+    PADDLE_ENFORCE_NOT_NULL(
+        gen,
+        phi::errors::InvalidArgument(
+            "Required generator shall not be nullptr, but received nullptr."));
+    host_generator_ = gen;
+  }
+
+  Generator* GetHostGenerator() const {
+    PADDLE_ENFORCE_NOT_NULL(
+        host_generator_,
+        phi::errors::InvalidArgument("Required generator_ shall not be "
+                                     "nullptr, but received nullptr."));
+    return host_generator_;
   }
 
  private:
  const Allocator* device_allocator_{nullptr};
  const Allocator* host_allocator_{nullptr};
  const Allocator* zero_allocator_{nullptr};
- Generator* generator_{nullptr};
+ Generator* device_generator_{nullptr};
+ Generator* host_generator_{nullptr};
 };
 
 DeviceContext::DeviceContext() { impl_ = std::make_unique<Impl>(); }
@@ -143,6 +160,8 @@ DeviceContext::DeviceContext(const DeviceContext& other) {
   impl_->SetHostAllocator(&other.GetHostAllocator());
   impl_->SetAllocator(&other.GetAllocator());
   impl_->SetZeroAllocator(&other.GetZeroAllocator());
+  impl_->SetHostGenerator(other.GetHostGenerator());
+  impl_->SetGenerator(other.GetGenerator());
 }
 
 DeviceContext::DeviceContext(DeviceContext&& other) {
@@ -224,4 +243,12 @@ void DeviceContext::SetGenerator(Generator* gen) { impl_->SetGenerator(gen); }
 Generator* DeviceContext::GetGenerator() const {
   return impl_->GetGenerator();
 }
+void DeviceContext::SetHostGenerator(Generator* gen) {
+  impl_->SetHostGenerator(gen);
+}
+
+Generator* DeviceContext::GetHostGenerator() const {
+  return impl_->GetHostGenerator();
+}
+
 }  // namespace phi
diff --git a/paddle/phi/core/device_context.h b/paddle/phi/core/device_context.h
index 7c1411e3bef3740f11ff39947028ead4d0357771..689f4e4e66d15f60aec873a9e9b9c07797833487 100644
--- a/paddle/phi/core/device_context.h
+++ b/paddle/phi/core/device_context.h
@@ -132,6 +132,19 @@ class DeviceContext {
    */
   Generator* GetGenerator() const;
 
+  /**
+   * @brief Set the host generator for special op.
+   *
+   * @param Generator
+   */
+  void SetHostGenerator(Generator*);
+  /**
+   * @brief Get the host generator object.
+   *
+   * @return Generator
+   */
+  Generator* GetHostGenerator() const;
+
  private:
   struct Impl;
   std::unique_ptr<Impl> impl_;
diff --git a/paddle/phi/kernels/cpu/randperm_kernel.cc b/paddle/phi/kernels/cpu/randperm_kernel.cc
new file mode 100644
index 0000000000000000000000000000000000000000..28092c8df6d153c6b5e787027f0c2239bd257cc1
--- /dev/null
+++ b/paddle/phi/kernels/cpu/randperm_kernel.cc
@@ -0,0 +1,46 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/randperm_kernel.h"
+#include "paddle/phi/core/dense_tensor.h"
+#include "paddle/phi/core/device_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void RandpermKernel(const Context& ctx,
+                    int n,
+                    DataType dtype,
+                    DenseTensor* out) {
+  T* out_data = ctx.template Alloc<T>(out);
+  auto gen_ptr = ctx.GetHostGenerator();
+  auto engine = gen_ptr->GetCPUEngine();
+
+  for (int i = 0; i < n; ++i) {
+    out_data[i] = static_cast<T>(i);
+  }
+  std::shuffle(out_data, out_data + n, *engine);
+}
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(randperm,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::RandpermKernel,
+                   float,
+                   double,
+                   int,
+                   int64_t) {}
diff --git a/paddle/phi/kernels/gpu/randperm_kernel.cu b/paddle/phi/kernels/gpu/randperm_kernel.cu
new file mode 100644
index 0000000000000000000000000000000000000000..f75f768b633a31a9d3d6eadcf036640f50309a8b
--- /dev/null
+++ b/paddle/phi/kernels/gpu/randperm_kernel.cu
@@ -0,0 +1,57 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/core/dense_tensor.h"
+#include "paddle/phi/core/device_context.h"
+#include "paddle/phi/kernels/randperm_kernel.h"
+
+// See Note [ Why still include the fluid headers? ]
+#include "paddle/fluid/memory/memcpy.h"
+#include "paddle/phi/core/kernel_registry.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void RandpermKernel(const Context& ctx,
+                    int n,
+                    DataType dtype,
+                    DenseTensor* out) {
+  DenseTensor tmp;
+  tmp.Resize(phi::make_ddim({n}));
+  T* tmp_data = ctx.template HostAlloc<T>(&tmp);
+
+  auto gen_ptr = ctx.GetHostGenerator();
+  auto engine = gen_ptr->GetCPUEngine();
+
+  for (int i = 0; i < n; ++i) {
+    tmp_data[i] = static_cast<T>(i);
+  }
+  std::shuffle(tmp_data, tmp_data + n, *engine);
+
+  T* out_data = ctx.template Alloc<T>(out);
+  auto size = out->numel() * paddle::experimental::SizeOf(out->dtype());
+  paddle::memory::Copy(
+      out->place(), out_data, tmp.place(), tmp_data, size, 0);
+}
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(randperm,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::RandpermKernel,
+                   float,
+                   double,
+                   int,
+                   int64_t) {}
diff --git a/paddle/phi/kernels/randperm_kernel.h b/paddle/phi/kernels/randperm_kernel.h
new file mode 100644
index 0000000000000000000000000000000000000000..63bdac6da6fdc12955e3743f23941840696a0ce6
--- /dev/null
+++ b/paddle/phi/kernels/randperm_kernel.h
@@ -0,0 +1,28 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/core/dense_tensor.h"
+#include "paddle/phi/core/device_context.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void RandpermKernel(const Context& ctx,
+                    int n,
+                    DataType dtype,
+                    DenseTensor* out);
+
+}  // namespace phi
diff --git a/paddle/phi/ops/compat/randperm_sig.cc b/paddle/phi/ops/compat/randperm_sig.cc
new file mode 100644
index 0000000000000000000000000000000000000000..14b28512e402a377b4f2f8f7d8f1e90f7ef37b71
--- /dev/null
+++ b/paddle/phi/ops/compat/randperm_sig.cc
@@ -0,0 +1,25 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/core/compat/op_utils.h"
+
+namespace phi {
+
+KernelSignature RandpermOpArgumentMapping(const ArgumentMappingContext& ctx) {
+  return KernelSignature("randperm", {}, {"n", "dtype"}, {"Out"});
+}
+
+}  // namespace phi
+
+PD_REGISTER_ARG_MAPPING_FN(randperm, phi::RandpermOpArgumentMapping);
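
Note on the new phi CPU kernel in the patch above: it writes 0..n-1 into the output buffer and then runs std::shuffle with the engine returned by ctx.GetHostGenerator()->GetCPUEngine(). Below is a minimal standalone sketch of that fill-and-shuffle step; RandpermSketch is a hypothetical name, and the locally seeded std::mt19937_64 (seed 42) is only a stand-in for the shared host engine the real kernel uses.

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <random>
#include <vector>

// Fill 0..n-1 in order, then shuffle in place -- the same two steps the
// phi CPU kernel performs over the tensor's output buffer.
std::vector<int64_t> RandpermSketch(int n, std::mt19937_64* engine) {
  std::vector<int64_t> out(n);
  for (int i = 0; i < n; ++i) {
    out[i] = static_cast<int64_t>(i);
  }
  std::shuffle(out.begin(), out.end(), *engine);
  return out;
}

int main() {
  std::mt19937_64 engine(42);  // stand-in for the shared host CPU engine
  for (int64_t v : RandpermSketch(8, &engine)) {
    std::cout << v << " ";
  }
  std::cout << "\n";
  return 0;
}

Compiled as plain C++ this prints a permutation of 0..7; the registered kernel differs only in where the output buffer and the engine come from.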
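The GPU kernel cannot drive a host-side engine over device memory, so it stages the permutation in a host-allocated DenseTensor and copies the finished buffer to the device with paddle::memory::Copy. The sketch below shows the same staging pattern using the plain CUDA runtime API; cudaMalloc/cudaMemcpy and the host/device_out names are stand-ins for illustration, not the calls the kernel actually makes.

#include <algorithm>
#include <cstdint>
#include <random>
#include <vector>

#include <cuda_runtime.h>

// Build the permutation on the host, then copy the finished buffer to the
// device; no shuffling ever happens in device memory.
int main() {
  const int n = 1024;
  std::vector<int64_t> host(n);
  for (int i = 0; i < n; ++i) {
    host[i] = static_cast<int64_t>(i);
  }

  std::mt19937_64 engine(42);  // stand-in for the shared host CPU engine
  std::shuffle(host.begin(), host.end(), engine);

  int64_t* device_out = nullptr;
  cudaMalloc(&device_out, n * sizeof(int64_t));
  cudaMemcpy(device_out, host.data(), n * sizeof(int64_t),
             cudaMemcpyHostToDevice);
  cudaFree(device_out);
  return 0;
}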
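As for why both kernels pull their engine from GetHostGenerator() instead of constructing one per call: the likely intent (an inference, not stated in the patch) is determinism, since the permutation then follows the globally seeded CPU generator wired in via DefaultCPUGenerator(). The property being relied on is only that an identically seeded engine yields an identical shuffle; Perm below is a hypothetical helper illustrating that.

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <numeric>
#include <random>
#include <vector>

// An identically seeded engine produces an identical permutation, which is
// what lets one shared, seedable host generator make randperm reproducible.
std::vector<int64_t> Perm(int n, uint64_t seed) {
  std::vector<int64_t> out(n);
  std::iota(out.begin(), out.end(), 0);
  std::mt19937_64 engine(seed);
  std::shuffle(out.begin(), out.end(), engine);
  return out;
}

int main() {
  assert(Perm(16, 2022) == Perm(16, 2022));  // same seed -> same order
  return 0;
}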