Unverified commit a46d7fe6, authored by duanboqiang, committed by GitHub

[phi]migrate class center sample kernel (#44949)

* migrate class center sample kernel

* fix Resize ddim error

* set buffer ptr

* add header

* add header

* remove comment

* remove header
Parent commit: ecc3098e
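For orientation, the change follows the standard fluid-to-phi migration pattern: the per-device OpKernel classes are replaced by a functional kernel template per backend, registered through PD_REGISTER_KERNEL instead of REGISTER_OP_CPU_KERNEL / REGISTER_OP_CUDA_KERNEL. A condensed sketch of the before/after shape, summarizing the files in this diff:

// Before (fluid): a per-device OpKernel class, registered per backend.
//   template <typename T>
//   class ClassCenterSampleCPUKernel : public framework::OpKernel<T> { ... };
//   REGISTER_OP_CPU_KERNEL(class_center_sample,
//                          ops::ClassCenterSampleCPUKernel<int64_t>,
//                          ops::ClassCenterSampleCPUKernel<int>);

// After (phi): a free function template that takes the device context, input
// tensor, attributes, and output tensors explicitly, shared across backends.
namespace phi {
template <typename T, typename Context>
void ClassCenterSampleKernel(const Context& dev_ctx,
                             const DenseTensor& label,
                             int num_classes, int num_samples, int ring_id,
                             int rank, int nranks, bool fix_seed, int seed,
                             DenseTensor* remapped_label,
                             DenseTensor* sampled_local_class_center);
}  // namespace phi
PD_REGISTER_KERNEL(class_center_sample, CPU, ALL_LAYOUT,
                   phi::ClassCenterSampleKernel, int64_t, int) {}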
......@@ -12,7 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/class_center_sample_op.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
namespace paddle {
namespace operators {
......@@ -143,10 +144,6 @@ class ClassCenterSampleOpMaker : public framework::OpProtoAndCheckerMaker {
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_WITHOUT_GRADIENT(class_center_sample,
ops::ClassCenterSampleOp,
ops::ClassCenterSampleOpMaker);
REGISTER_OP_CPU_KERNEL(class_center_sample,
ops::ClassCenterSampleCPUKernel<int64_t>,
ops::ClassCenterSampleCPUKernel<int>);
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <map>
#include <set>
#include <vector>
#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T>
class ClassCenterSampleCPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* label = ctx.Input<Tensor>("Label");
auto* remapped_label = ctx.Output<Tensor>("RemappedLabel");
auto* sampled_local_class_center =
ctx.Output<Tensor>("SampledLocalClassCenter");
int num_classes = ctx.Attr<int>("num_classes");
int num_samples = ctx.Attr<int>("num_samples");
int seed = ctx.Attr<int>("seed");
bool fix_seed = ctx.Attr<bool>("fix_seed");
PADDLE_ENFORCE_GT(num_classes,
0,
platform::errors::InvalidArgument(
"The value 'num_classes' for Op(class_center_sample) "
"must be greater than 0, "
"but the value given is %d.",
num_classes));
PADDLE_ENFORCE_GT(num_samples,
0,
platform::errors::InvalidArgument(
"The value 'num_samples' for Op(class_center_sample) "
"must be greater than 0, "
"but the value given is %d.",
num_samples));
PADDLE_ENFORCE_LE(num_samples,
num_classes,
platform::errors::InvalidArgument(
"The value 'num_samples' for Op(class_center_sample) "
"must be less than or equal to %d, "
"but the value given is %d.",
num_classes,
num_samples));
int64_t numel = label->numel();
auto* label_ptr = label->data<T>();
// get unique positive class centers in ascending order
std::set<T, std::less<T>> unique_label;
for (int64_t i = 0; i < numel; ++i) {
unique_label.insert(label_ptr[i]);
}
// construct a lookup table and get sampled_local_class_center
std::vector<T> actual_sampled;
std::map<T, T> new_class_dict;
T idx = 0;
for (auto& t : unique_label) {
new_class_dict[t] = idx;
actual_sampled.push_back(t);
idx++;
}
if (!fix_seed) {
std::random_device rnd;
seed = rnd();
}
std::uniform_int_distribution<T> dist(0, num_classes - 1);
auto engine = framework::GetCPURandomEngine(seed);
// sample negative class centers randomly
while (unique_label.size() < static_cast<size_t>(num_samples)) {
T neg = dist(*engine);
if (unique_label.find(neg) == unique_label.end()) {
unique_label.insert(neg);
// negative class centers are kept unordered
actual_sampled.push_back(neg);
}
}
int actual_num_samples = unique_label.size();
T* sampled_local_class_center_ptr =
sampled_local_class_center->mutable_data<T>({actual_num_samples},
ctx.GetPlace());
idx = 0;
for (auto& t : actual_sampled) {
sampled_local_class_center_ptr[idx] = t;
idx++;
}
// remap the input label to sampled class
auto* remmaped_label_ptr = remapped_label->mutable_data<T>(ctx.GetPlace());
for (int64_t i = 0; i < numel; ++i) {
remmaped_label_ptr[i] = new_class_dict[label_ptr[i]];
}
}
};
} // namespace operators
} // namespace paddle
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void ClassCenterSampleKernel(const Context& dev_ctx,
const DenseTensor& label,
int num_classes,
int num_samples,
int ring_id,
int rank,
int nranks,
bool fix_seed,
int seed,
DenseTensor* remapped_label,
DenseTensor* sampled_local_class_center);
} // namespace phi
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <map>
#include <set>
#include <vector>
#include "paddle/fluid/framework/generator.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
template <typename T, typename Context>
void ClassCenterSampleKernel(const Context& dev_ctx,
const DenseTensor& label,
int num_classes,
int num_samples,
int ring_id,
int rank,
int nranks,
bool fix_seed,
int seed,
DenseTensor* remapped_label,
DenseTensor* sampled_local_class_center) {
PADDLE_ENFORCE_GT(num_classes,
0,
errors::InvalidArgument(
"The value 'num_classes' for Op(class_center_sample) "
"must be greater than 0, "
"but the value given is %d.",
num_classes));
PADDLE_ENFORCE_GT(num_samples,
0,
errors::InvalidArgument(
"The value 'num_samples' for Op(class_center_sample) "
"must be greater than 0, "
"but the value given is %d.",
num_samples));
PADDLE_ENFORCE_LE(num_samples,
num_classes,
errors::InvalidArgument(
"The value 'num_samples' for Op(class_center_sample) "
"must be less than or equal to %d, "
"but the value given is %d.",
num_classes,
num_samples));
int64_t numel = label.numel();
auto* label_ptr = label.data<T>();
// get unique positive class centers in ascending order
std::set<T, std::less<T>> unique_label;
for (int64_t i = 0; i < numel; ++i) {
unique_label.insert(label_ptr[i]);
}
// construct a lookup table and get sampled_local_class_center
std::vector<T> actual_sampled;
std::map<T, T> new_class_dict;
T idx = 0;
for (auto& t : unique_label) {
new_class_dict[t] = idx;
actual_sampled.push_back(t);
idx++;
}
if (!fix_seed) {
std::random_device rnd;
seed = rnd();
}
std::uniform_int_distribution<T> dist(0, num_classes - 1);
auto engine = paddle::framework::GetCPURandomEngine(seed);
// sample negative class centers randomly
while (unique_label.size() < static_cast<size_t>(num_samples)) {
T neg = dist(*engine);
if (unique_label.find(neg) == unique_label.end()) {
unique_label.insert(neg);
// negative class centers are kept unordered
actual_sampled.push_back(neg);
}
}
int actual_num_samples = unique_label.size();
sampled_local_class_center->Resize({actual_num_samples});
T* sampled_local_class_center_ptr =
dev_ctx.template Alloc<T>(sampled_local_class_center);
idx = 0;
for (auto& t : actual_sampled) {
sampled_local_class_center_ptr[idx] = t;
idx++;
}
// remap the input label to sampled class
auto* remmaped_label_ptr = dev_ctx.template Alloc<T>(remapped_label);
for (int64_t i = 0; i < numel; ++i) {
remmaped_label_ptr[i] = new_class_dict[label_ptr[i]];
}
}
} // namespace phi
PD_REGISTER_KERNEL(class_center_sample,
CPU,
ALL_LAYOUT,
phi::ClassCenterSampleKernel,
int64_t,
int) {}
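To make the CPU sampling path above concrete, here is a small worked example; the negative class centers shown are illustrative only, since they are drawn from the random engine:

// Inputs: label = [3, 7, 3], num_classes = 10, num_samples = 4
// Unique positive centers (ascending): {3, 7}  ->  new_class_dict = {3: 0, 7: 1}
// Negative centers are drawn uniformly from [0, 10) until 4 distinct centers
// exist; suppose 1 and 8 are drawn (any values not already present qualify).
// sampled_local_class_center = [3, 7, 1, 8]  // positives first, negatives in draw order
// remapped_label             = [0, 1, 0]     // each label mapped through new_class_dict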
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
......@@ -29,25 +29,24 @@ namespace cub = hipcub;
#include <iterator>
#include <random>
#include "paddle/fluid/operators/class_center_sample_op.h"
#include "paddle/phi/api/include/tensor.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/platform/enforce.h"
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#include "paddle/fluid/distributed/collective/ProcessGroup.h"
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/device/gpu/nccl_helper.h"
#endif
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
namespace paddle {
namespace operators {
namespace phi {
#define CUDA_KERNEL_LOOP(i, n) \
for (int32_t i = blockIdx.x * blockDim.x + threadIdx.x, \
step = blockDim.x * gridDim.x; \
i < (n); \
i += step)
using Tensor = framework::Tensor;
static constexpr int kNumCUDAThreads = 512;
static constexpr int kNumMaxinumNumBlocks = 4096;
......@@ -246,13 +245,13 @@ struct ActualNumSampledFunctor {
explicit ActualNumSampledFunctor(const T num) : num_samples(num) {}
};
template <typename T>
template <typename T, typename Context>
class MemoryBuffer {
public:
MemoryBuffer(const int num_buffer_ele,
const int num_temp_ele,
const int nranks,
const platform::Place& place) {
const Context& dev_ctx) {
offset1 = 0;
offset2 = offset1 + num_buffer_ele;
offset3 = offset2 + num_buffer_ele;
......@@ -263,8 +262,8 @@ class MemoryBuffer {
offset8 = offset7 + (nranks + 1);
offset9 = offset8 + num_temp_ele;
buffer_ptr = buffer.mutable_data<T>(
{4 * num_buffer_ele + 3 * (nranks + 1) + num_temp_ele}, place);
buffer.Resize({4 * num_buffer_ele + 3 * (nranks + 1) + num_temp_ele});
buffer_ptr = dev_ctx.template Alloc<T>(&buffer);
}
T* cub_sort_keys_ptr() { return buffer_ptr + offset1; }
......@@ -279,7 +278,7 @@ class MemoryBuffer {
}
private:
Tensor buffer;
DenseTensor buffer;
T* buffer_ptr;
int offset1;
int offset2;
......@@ -292,26 +291,21 @@ class MemoryBuffer {
int offset9;
};
template <typename DeviceContext, typename T>
class ClassCenterSampleCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* label = ctx.Input<Tensor>("Label");
auto* remapped_label = ctx.Output<Tensor>("RemappedLabel");
auto* sampled_local_class_center =
ctx.Output<Tensor>("SampledLocalClassCenter");
int num_classes = ctx.Attr<int>("num_classes");
int num_samples = ctx.Attr<int>("num_samples");
int rid = ctx.Attr<int>("ring_id");
int nranks = ctx.Attr<int>("nranks");
int rank = ctx.Attr<int>("rank");
int seed = ctx.Attr<int>("seed");
bool fix_seed = ctx.Attr<bool>("fix_seed");
template <typename T, typename Context>
void ClassCenterSampleKernel(const Context& dev_ctx,
const DenseTensor& label,
int num_classes,
int num_samples,
int ring_id,
int rank,
int nranks,
bool fix_seed,
int seed,
DenseTensor* remapped_label,
DenseTensor* sampled_local_class_center) {
PADDLE_ENFORCE_GT(num_classes,
0,
platform::errors::InvalidArgument(
errors::InvalidArgument(
"The value 'num_classes' for Op(class_center_sample) "
"must be greater than 0, "
"but the value given is %d.",
......@@ -319,7 +313,7 @@ class ClassCenterSampleCUDAKernel : public framework::OpKernel<T> {
PADDLE_ENFORCE_GT(num_samples,
0,
platform::errors::InvalidArgument(
errors::InvalidArgument(
"The value 'num_samples' for Op(class_center_sample) "
"must be greater than 0, "
"but the value given is %d.",
......@@ -327,17 +321,16 @@ class ClassCenterSampleCUDAKernel : public framework::OpKernel<T> {
PADDLE_ENFORCE_LE(num_samples,
num_classes,
platform::errors::InvalidArgument(
errors::InvalidArgument(
"The value 'num_samples' for Op(class_center_sample) "
"must be less than or equal to %d, "
"but the value given is %d.",
num_classes,
num_samples));
auto& dev_ctx = ctx.template device_context<DeviceContext>();
auto place = dev_ctx.GetPlace();
int batch_size = label->numel();
int batch_size = label.numel();
// Algorithm:
// We first randomly generate a value in [0, num_classes) on each position
// in an array (shape [num_classes]). Then, we mark the element as negative
......@@ -350,40 +343,42 @@ class ClassCenterSampleCUDAKernel : public framework::OpKernel<T> {
// step 1: Calculate num classes per device using nccl all reduce
std::vector<T> shard_dim_vec(nranks + 1, 0);
shard_dim_vec[rank + 1] = num_classes;
Tensor num_classes_per_device;
framework::TensorFromVector(
shard_dim_vec, ctx.cuda_device_context(), &num_classes_per_device);
DenseTensor num_classes_per_device;
paddle::framework::TensorFromVector(
shard_dim_vec, dev_ctx, &num_classes_per_device);
T* num_classes_per_device_ptr = num_classes_per_device.data<T>();
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
if (nranks > 1) {
auto map = distributed::ProcessGroupMapFromGid::getInstance();
if (map->has(rid)) {
auto map = paddle::distributed::ProcessGroupMapFromGid::getInstance();
if (map->has(ring_id)) {
// Use ProcessGroup
distributed::ProcessGroup* pg = map->get(rid);
paddle::distributed::ProcessGroup* pg = map->get(ring_id);
std::vector<phi::DenseTensor> in_tensor;
std::vector<phi::DenseTensor> out_tensor;
in_tensor.push_back(num_classes_per_device);
out_tensor.push_back(num_classes_per_device);
distributed::AllreduceOptions opts;
opts.reduce_op = distributed::ReduceOp::SUM;
paddle::distributed::AllreduceOptions opts;
opts.reduce_op = paddle::distributed::ReduceOp::SUM;
auto task = pg->AllReduce(in_tensor, out_tensor, opts);
task->Wait();
} else {
const auto& comm =
platform::NCCLCommContext::Instance().Get(rid, ctx.GetPlace());
const auto& comm = paddle::platform::NCCLCommContext::Instance().Get(
ring_id, dev_ctx.GetPlace());
// use global calculate stream
const auto calcu_stream =
static_cast<phi::GPUContext*>(
platform::DeviceContextPool::Instance().Get(ctx.GetPlace()))
static_cast<GPUContext*>(
paddle::platform::DeviceContextPool::Instance().Get(
dev_ctx.GetPlace()))
->stream();
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce(
PADDLE_ENFORCE_GPU_SUCCESS(paddle::platform::dynload::ncclAllReduce(
num_classes_per_device_ptr,
num_classes_per_device_ptr,
num_classes_per_device.numel(),
platform::ToNCCLDataType(
framework::TransToProtoVarType(num_classes_per_device.dtype())),
paddle::platform::ToNCCLDataType(
paddle::framework::TransToProtoVarType(
num_classes_per_device.dtype())),
ncclSum,
comm->comm(),
calcu_stream));
......@@ -394,8 +389,8 @@ class ClassCenterSampleCUDAKernel : public framework::OpKernel<T> {
// step 2: Determine temporary device storage requirements
int num_buffer_ele = std::max(batch_size, num_classes);
size_t cub_sort_temp_store_size = 0;
PADDLE_ENFORCE_GPU_SUCCESS((cub::DeviceRadixSort::SortPairs<T, T>(
nullptr,
PADDLE_ENFORCE_GPU_SUCCESS(
(cub::DeviceRadixSort::SortPairs<T, T>(nullptr,
cub_sort_temp_store_size,
nullptr,
nullptr,
......@@ -404,18 +399,18 @@ class ClassCenterSampleCUDAKernel : public framework::OpKernel<T> {
num_buffer_ele,
0,
sizeof(T) * 8,
ctx.cuda_device_context().stream())));
dev_ctx.stream())));
size_t cub_sum_temp_store_size = 0;
NotEqualToPreviousAdjacentIterator<T> unique_counting_iter_temp(nullptr, 0);
PADDLE_ENFORCE_GPU_SUCCESS((
cub::DeviceScan::InclusiveSum<NotEqualToPreviousAdjacentIterator<T>,
T*>(nullptr,
PADDLE_ENFORCE_GPU_SUCCESS(
(cub::DeviceScan::InclusiveSum<NotEqualToPreviousAdjacentIterator<T>, T*>(
nullptr,
cub_sum_temp_store_size,
unique_counting_iter_temp,
nullptr,
batch_size,
ctx.cuda_device_context().stream())));
dev_ctx.stream())));
size_t cub_scan_temp_store_size = 0;
ActualNumSampledFunctor<T> actual_num_sampled_op_temp(num_samples);
......@@ -426,7 +421,7 @@ class ClassCenterSampleCUDAKernel : public framework::OpKernel<T> {
num_classes_per_device_ptr,
actual_num_sampled_op_temp,
nranks + 1,
ctx.cuda_device_context().stream())));
dev_ctx.stream())));
size_t cub_temp_storage_bytes =
std::max(std::max(cub_sort_temp_store_size, cub_scan_temp_store_size),
......@@ -434,8 +429,8 @@ class ClassCenterSampleCUDAKernel : public framework::OpKernel<T> {
int num_temp_ele = cub_temp_storage_bytes / sizeof(T) + 1;
// step 3: Alloc buffer memory so that we can reuse allocated memory
MemoryBuffer<T> memory_buffer =
MemoryBuffer<T>(num_buffer_ele, num_temp_ele, nranks, ctx.GetPlace());
MemoryBuffer<T, Context> memory_buffer =
MemoryBuffer<T, Context>(num_buffer_ele, num_temp_ele, nranks, dev_ctx);
T* cub_sort_keys_ptr = memory_buffer.cub_sort_keys_ptr();
T* cub_sort_keys_out_ptr = memory_buffer.cub_sort_keys_out_ptr();
......@@ -453,7 +448,7 @@ class ClassCenterSampleCUDAKernel : public framework::OpKernel<T> {
num_classes_per_device_ptr,
class_interval_ptr,
nranks + 1,
ctx.cuda_device_context().stream())));
dev_ctx.stream())));
// step 5: random sample negative class center
uint64_t seed_data;
......@@ -463,8 +458,8 @@ class ClassCenterSampleCUDAKernel : public framework::OpKernel<T> {
(NumBlocks(num_classes) * kNumCUDAThreads * vec_size) +
1) *
vec_size;
int device_id = ctx.GetPlace().GetDeviceId();
auto gen_cuda = framework::DefaultCUDAGenerator(device_id);
// auto gen_cuda = paddle::framework::DefaultCUDAGenerator(device_id);
auto gen_cuda = dev_ctx.GetGenerator();
if (!fix_seed) {
auto seed_offset = gen_cuda->IncrementOffset(offset);
seed_data = seed_offset.first;
......@@ -473,34 +468,27 @@ class ClassCenterSampleCUDAKernel : public framework::OpKernel<T> {
seed_data = seed + rank;
increment = offset;
}
RandomSampleClassCenter<T><<<NumBlocks(num_classes),
kNumCUDAThreads,
0,
ctx.cuda_device_context().stream()>>>(
RandomSampleClassCenter<T>
<<<NumBlocks(num_classes), kNumCUDAThreads, 0, dev_ctx.stream()>>>(
num_classes, seed_data, increment, num_classes, cub_sort_keys_ptr);
// step 6: mark positive class center as negative value
// fill the sort values to index 0, 1, ..., batch_size-1
MarkPositiveClassCenter<<<NumBlocks(batch_size),
kNumCUDAThreads,
0,
ctx.cuda_device_context().stream()>>>(
MarkPositiveClassCenter<T>
<<<NumBlocks(batch_size), kNumCUDAThreads, 0, dev_ctx.stream()>>>(
batch_size,
rank,
class_interval_ptr,
num_classes,
label->data<T>(),
label.data<T>(),
cub_sort_keys_ptr);
Range<T><<<NumBlocks(num_buffer_ele),
kNumCUDAThreads,
0,
ctx.cuda_device_context().stream()>>>(num_buffer_ele,
cub_sort_values_ptr);
Range<T><<<NumBlocks(num_buffer_ele), kNumCUDAThreads, 0, dev_ctx.stream()>>>(
num_buffer_ele, cub_sort_values_ptr);
// step 7: sort class centers in ascending order, so that positive class
// centers are always sampled.
PADDLE_ENFORCE_GPU_SUCCESS((cub::DeviceRadixSort::SortPairs<T, T>(
cub_temp_storage_ptr,
PADDLE_ENFORCE_GPU_SUCCESS(
(cub::DeviceRadixSort::SortPairs<T, T>(cub_temp_storage_ptr,
cub_temp_storage_bytes,
cub_sort_keys_ptr,
cub_sort_keys_out_ptr,
......@@ -509,40 +497,38 @@ class ClassCenterSampleCUDAKernel : public framework::OpKernel<T> {
num_classes,
0,
sizeof(T) * 8,
ctx.cuda_device_context().stream())));
dev_ctx.stream())));
// step 8: sort input label ascending
PADDLE_ENFORCE_GPU_SUCCESS((cub::DeviceRadixSort::SortPairs<T, T>(
cub_temp_storage_ptr,
PADDLE_ENFORCE_GPU_SUCCESS(
(cub::DeviceRadixSort::SortPairs<T, T>(cub_temp_storage_ptr,
cub_temp_storage_bytes,
label->data<T>(),
label.data<T>(),
cub_sort_keys_out_ptr,
cub_sort_values_ptr,
cub_sort_keys_ptr,
batch_size,
0,
sizeof(T) * 8,
ctx.cuda_device_context().stream())));
dev_ctx.stream())));
// step 9: Calculate new index using InclusiveSum on ascending sorted input
// label
NotEqualToPreviousAdjacentIterator<T> unique_counting_iter(
cub_sort_keys_out_ptr, 0);
PADDLE_ENFORCE_GPU_SUCCESS((
cub::DeviceScan::InclusiveSum<NotEqualToPreviousAdjacentIterator<T>,
T*>(cub_temp_storage_ptr,
PADDLE_ENFORCE_GPU_SUCCESS(
(cub::DeviceScan::InclusiveSum<NotEqualToPreviousAdjacentIterator<T>, T*>(
cub_temp_storage_ptr,
cub_temp_storage_bytes,
unique_counting_iter,
cub_sort_values_ptr,
batch_size,
ctx.cuda_device_context().stream())));
dev_ctx.stream())));
// step 10: Calculate new class center bound among ranks
GetClassCenterBound<T>
<<<NumBlocks(batch_size),
kNumCUDAThreads,
0,
ctx.cuda_device_context().stream()>>>(batch_size,
<<<NumBlocks(batch_size), kNumCUDAThreads, 0, dev_ctx.stream()>>>(
batch_size,
nranks,
class_interval_ptr,
cub_sort_keys_out_ptr,
......@@ -561,7 +547,7 @@ class ClassCenterSampleCUDAKernel : public framework::OpKernel<T> {
num_classes_per_device_ptr,
actual_num_sampled_op,
nranks + 1,
ctx.cuda_device_context().stream())));
dev_ctx.stream())));
// step 12: Calculate actual sampled class interval among nranks
PADDLE_ENFORCE_GPU_SUCCESS(
......@@ -570,13 +556,11 @@ class ClassCenterSampleCUDAKernel : public framework::OpKernel<T> {
num_classes_per_device_ptr,
class_interval_ptr,
nranks + 1,
ctx.cuda_device_context().stream())));
dev_ctx.stream())));
// step 13: Get remapped label for output
GetRemappedLabel<T><<<NumBlocks(batch_size),
kNumCUDAThreads,
0,
ctx.cuda_device_context().stream()>>>(
GetRemappedLabel<T>
<<<NumBlocks(batch_size), kNumCUDAThreads, 0, dev_ctx.stream()>>>(
batch_size,
nranks,
class_interval_ptr,
......@@ -584,28 +568,33 @@ class ClassCenterSampleCUDAKernel : public framework::OpKernel<T> {
bound_value_ptr,
cub_sort_keys_ptr,
cub_sort_values_ptr,
remapped_label->mutable_data<T>(ctx.GetPlace()));
dev_ctx.template Alloc<T>(remapped_label));
// step 14: Get sampled class center for output
framework::TensorCopySync(
num_classes_per_device, platform::CPUPlace(), &num_classes_per_device);
paddle::framework::TensorCopySync(
num_classes_per_device, phi::CPUPlace(), &num_classes_per_device);
// phi::Copy<Context>(dev_ctx,
// num_classes_per_device,
// phi::CPUPlace(),
// true,
// &num_classes_per_device);
T actual_num_samples = num_classes_per_device.data<T>()[rank + 1];
sampled_local_class_center->Resize(phi::make_ddim({actual_num_samples}));
T* sampled_local_class_center_ptr =
sampled_local_class_center->mutable_data<T>({actual_num_samples},
ctx.GetPlace());
memory::Copy(place,
dev_ctx.template Alloc<T>(sampled_local_class_center);
paddle::memory::Copy(dev_ctx.GetPlace(),
sampled_local_class_center_ptr,
place,
dev_ctx.GetPlace(),
cub_sort_values_out_ptr,
actual_num_samples * sizeof(T),
nullptr);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
class_center_sample,
ops::ClassCenterSampleCUDAKernel<phi::GPUContext, int64_t>,
ops::ClassCenterSampleCUDAKernel<phi::GPUContext, int>);
}
} // namespace phi
PD_REGISTER_KERNEL(class_center_sample,
GPU,
ALL_LAYOUT,
phi::ClassCenterSampleKernel,
int64_t,
int) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature ClassCenterSampleOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("class_center_sample",
{"Label"},
{"num_classes",
"num_samples",
"ring_id",
"rank",
"nranks",
"fix_seed",
"seed"},
{"RemappedLabel", "SampledLocalClassCenter"});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(class_center_sample,
phi::ClassCenterSampleOpArgumentMapping);