Unverified commit 30992ea0, authored by Leo Chen, committed by GitHub

[phi] move randperm to phi (#39816)

* move randperm to phi

* fix npu

* fix memory::Copy
Parent ad294a81
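For context: randperm(n) fills a tensor with a random permutation of the integers [0, n). A standalone C++ sketch (plain C++, not Paddle code) of the algorithm both new phi kernels use: fill the buffer in order, then shuffle it with a host random engine.

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <numeric>
#include <random>
#include <vector>

int main() {
  const int n = 10;
  std::vector<int64_t> perm(n);
  std::iota(perm.begin(), perm.end(), 0);  // 0, 1, ..., n-1
  // The real kernels draw the engine from the context's host generator;
  // a locally seeded engine stands in for it here.
  std::mt19937_64 engine(std::random_device{}());
  std::shuffle(perm.begin(), perm.end(), engine);
  for (int64_t v : perm) std::cout << v << ' ';
  std::cout << '\n';
  return 0;
}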
@@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/randperm_op.h"
#include <string>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
@@ -89,10 +88,3 @@ REGISTER_OPERATOR(
     paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
     paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
     paddle::operators::RandpermOpVarTypeInference);
-
-template <typename T>
-using kernel =
-    paddle::operators::RandpermKernel<paddle::platform::CPUDeviceContext, T>;
-
-REGISTER_OP_CPU_KERNEL(randperm, kernel<int64_t>, kernel<int>, kernel<float>,
-                       kernel<double>);
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/randperm_op.h"
template <typename T>
using kernel =
    paddle::operators::RandpermKernel<paddle::platform::CUDADeviceContext, T>;

REGISTER_OP_CUDA_KERNEL(randperm, kernel<int64_t>, kernel<int>, kernel<float>,
                        kernel<double>);
@@ -172,6 +172,7 @@ inline void EmplaceDeviceContext(
                 .get());
     dev_ctx->SetGenerator(framework::DefaultCPUGenerator().get());
   }
+  dev_ctx->SetHostGenerator(framework::DefaultCPUGenerator().get());
   dev_ctx->SetHostAllocator(
       memory::allocation::AllocatorFacade::Instance()
           .GetAllocator(platform::CPUPlace())
......
@@ -119,22 +119,39 @@ struct DeviceContext::Impl {
         gen,
         phi::errors::InvalidArgument(
             "Required generator shall not be nullptr, but received nullptr."));
-    generator_ = gen;
+    device_generator_ = gen;
   }

   Generator* GetGenerator() const {
     PADDLE_ENFORCE_NOT_NULL(
-        generator_,
+        device_generator_,
         phi::errors::InvalidArgument("Required generator_ shall not be "
                                      "nullptr, but received nullptr."));
-    return generator_;
+    return device_generator_;
   }

+  void SetHostGenerator(Generator* gen) {
+    PADDLE_ENFORCE_NOT_NULL(
+        gen,
+        phi::errors::InvalidArgument(
+            "Required generator shall not be nullptr, but received nullptr."));
+    host_generator_ = gen;
+  }
+
+  Generator* GetHostGenerator() const {
+    PADDLE_ENFORCE_NOT_NULL(
+        host_generator_,
+        phi::errors::InvalidArgument("Required generator_ shall not be "
+                                     "nullptr, but received nullptr."));
+    return host_generator_;
+  }
+
  private:
   const Allocator* device_allocator_{nullptr};
   const Allocator* host_allocator_{nullptr};
   const Allocator* zero_allocator_{nullptr};
-  Generator* generator_{nullptr};
+  Generator* device_generator_{nullptr};
+  Generator* host_generator_{nullptr};
 };
DeviceContext::DeviceContext() { impl_ = std::make_unique<Impl>(); }
@@ -143,6 +160,8 @@ DeviceContext::DeviceContext(const DeviceContext& other) {
   impl_->SetHostAllocator(&other.GetHostAllocator());
   impl_->SetAllocator(&other.GetAllocator());
   impl_->SetZeroAllocator(&other.GetZeroAllocator());
+  impl_->SetHostGenerator(other.GetHostGenerator());
+  impl_->SetGenerator(other.GetGenerator());
 }
DeviceContext::DeviceContext(DeviceContext&& other) {
@@ -224,4 +243,12 @@ void DeviceContext::SetGenerator(Generator* gen) { impl_->SetGenerator(gen); }
 Generator* DeviceContext::GetGenerator() const { return impl_->GetGenerator(); }

+void DeviceContext::SetHostGenerator(Generator* gen) {
+  impl_->SetHostGenerator(gen);
+}
+
+Generator* DeviceContext::GetHostGenerator() const {
+  return impl_->GetHostGenerator();
+}
+
 }  // namespace phi
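The net effect of the DeviceContext changes above is that a context now carries two independent generators: one for the device it targets and one for the host. A toy model (simplified types, not the real phi classes) of the resulting API:

#include <cassert>
#include <random>

struct Generator {
  std::mt19937_64 engine;
};

// Simplified stand-in for phi::DeviceContext: one generator for the
// device, one for the host, each settable and queryable on its own.
class ToyContext {
 public:
  void SetGenerator(Generator* gen) { device_generator_ = gen; }
  void SetHostGenerator(Generator* gen) { host_generator_ = gen; }
  Generator* GetGenerator() const { return device_generator_; }
  Generator* GetHostGenerator() const { return host_generator_; }

 private:
  Generator* device_generator_{nullptr};
  Generator* host_generator_{nullptr};
};

int main() {
  Generator device_gen, host_gen;
  ToyContext ctx;
  ctx.SetGenerator(&device_gen);
  ctx.SetHostGenerator(&host_gen);
  // A kernel that builds its data on the CPU (like randperm) asks for the
  // host generator even when the context itself targets a GPU.
  assert(ctx.GetHostGenerator() == &host_gen);
  return 0;
}

This is why EmplaceDeviceContext now calls SetHostGenerator(DefaultCPUGenerator().get()) for every context it creates, CPU and GPU alike.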
@@ -132,6 +132,19 @@ class DeviceContext {
    */
   Generator* GetGenerator() const;
+  /**
+   * @brief Set the host generator for special op.
+   *
+   * @param Generator
+   */
+  void SetHostGenerator(Generator*);
+
+  /**
+   * @brief Get the host generator object.
+   *
+   * @return Generator
+   */
+  Generator* GetHostGenerator() const;

  private:
   struct Impl;
   std::unique_ptr<Impl> impl_;
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/randperm_kernel.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/device_context.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
template <typename T, typename Context>
void RandpermKernel(const Context& ctx,
int n,
DataType dtype,
DenseTensor* out) {
T* out_data = ctx.template Alloc<T>(out);
auto gen_ptr = ctx.GetHostGenerator();
auto engine = gen_ptr->GetCPUEngine();
for (int i = 0; i < n; ++i) {
out_data[i] = static_cast<T>(i);
}
std::shuffle(out_data, out_data + n, *engine);
}
} // namespace phi
PD_REGISTER_KERNEL(randperm,
CPU,
ALL_LAYOUT,
phi::RandpermKernel,
float,
double,
int,
int64_t) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/device_context.h"
#include "paddle/phi/kernels/randperm_kernel.h"
// See Note [ Why still include the fluid headers? ]
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
template <typename T, typename Context>
void RandpermKernel(const Context& ctx,
int n,
DataType dtype,
DenseTensor* out) {
DenseTensor tmp;
tmp.Resize(phi::make_ddim({n}));
T* tmp_data = ctx.template HostAlloc<T>(&tmp);
auto gen_ptr = ctx.GetHostGenerator();
auto engine = gen_ptr->GetCPUEngine();
for (int i = 0; i < n; ++i) {
tmp_data[i] = static_cast<T>(i);
}
std::shuffle(tmp_data, tmp_data + n, *engine);
T* out_data = ctx.template Alloc<T>(out);
auto size = out->numel() * paddle::experimental::SizeOf(out->dtype());
paddle::memory::Copy<phi::GPUPlace, phi::Place>(
out->place(), out_data, tmp.place(), tmp_data, size, 0);
}
} // namespace phi
PD_REGISTER_KERNEL(randperm,
GPU,
ALL_LAYOUT,
phi::RandpermKernel,
float,
double,
int,
int64_t) {}
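Note that the GPU path above never shuffles on the device: it builds the permutation in a host-side tmp tensor and then copies it over (the "fix memory::Copy" item from the commit message). Stripped of Paddle's wrappers, the final step is roughly the following sketch (assumes the CUDA runtime; buffer names are illustrative):

#include <algorithm>
#include <cstdint>
#include <numeric>
#include <random>
#include <vector>

#include <cuda_runtime.h>

int main() {
  const int n = 1024;
  std::vector<int64_t> host_perm(n);
  std::iota(host_perm.begin(), host_perm.end(), 0);
  std::mt19937_64 engine(42);  // stands in for the host generator's engine
  std::shuffle(host_perm.begin(), host_perm.end(), engine);

  // Rough equivalent of memory::Copy(gpu_place, out_data, cpu_place,
  // tmp_data, size, /*stream=*/0): a host-to-device copy of the result.
  int64_t* dev_perm = nullptr;
  cudaMalloc(&dev_perm, n * sizeof(int64_t));
  cudaMemcpy(dev_perm, host_perm.data(), n * sizeof(int64_t),
             cudaMemcpyHostToDevice);
  cudaFree(dev_perm);
  return 0;
}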
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once

#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/device_context.h"

namespace phi {

template <typename T, typename Context>
void RandpermKernel(const Context& ctx,
                    int n,
                    DataType dtype,
                    DenseTensor* out);

}  // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature RandpermOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("randperm", {}, {"n", "dtype"}, {"Out"});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(randperm, phi::RandpermOpArgumentMapping);
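The mapping ties the legacy fluid op to the new kernel: randperm has no tensor inputs, two attributes, and one output, which lines up one-to-one with the phi kernel's (ctx, n, dtype, out) parameter list. A toy rendering (simplified types, not the real phi KernelSignature) of what the triple encodes:

#include <iostream>
#include <string>
#include <vector>

// Simplified stand-in for phi::KernelSignature: kernel name plus the op's
// input, attribute, and output names in kernel-argument order.
struct ToySignature {
  std::string name;
  std::vector<std::string> inputs;
  std::vector<std::string> attrs;
  std::vector<std::string> outputs;
};

int main() {
  ToySignature sig{"randperm", {}, {"n", "dtype"}, {"Out"}};
  // No tensor inputs; "n" and "dtype" feed `int n` and `DataType dtype`;
  // "Out" binds to `DenseTensor* out`.
  std::cout << sig.name << ": " << sig.inputs.size() << " inputs, "
            << sig.attrs.size() << " attrs, " << sig.outputs.size()
            << " outputs\n";
  return 0;
}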