Unverified commit a46d7fe6, authored by duanboqiang, committed by GitHub

[phi]migrate class center sample kernel (#44949)

* migrate class center sample kernel

* fix Resize ddim error

* set buffer ptr

* add header

* add header

* remove comment

* remove header
Parent commit: ecc3098e
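
The diff below follows the usual fluid-to-phi kernel migration pattern: the framework::OpKernel subclass that read inputs and attributes from an ExecutionContext becomes a free function template taking a device Context and explicit arguments, and PD_REGISTER_KERNEL replaces the per-device REGISTER_OP_*_KERNEL macros. A minimal sketch of that pattern with a hypothetical "example" op (illustrative only, not code from this commit):

// Minimal sketch of the fluid-to-phi migration pattern (hypothetical "example"
// op; illustrative only, not code from this commit).
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"

namespace phi {

// The OpKernel::Compute(ExecutionContext) method becomes a free function
// templated on the data type T and the device Context; inputs, attributes
// and outputs are passed as explicit parameters.
template <typename T, typename Context>
void ExampleKernel(const Context& dev_ctx,
                   const DenseTensor& x,
                   int bias,
                   DenseTensor* out) {
  out->Resize(x.dims());
  // dev_ctx.Alloc<T>() replaces Tensor::mutable_data<T>() for outputs.
  T* out_ptr = dev_ctx.template Alloc<T>(out);
  const T* x_ptr = x.data<T>();
  for (int64_t i = 0; i < x.numel(); ++i) {
    out_ptr[i] = x_ptr[i] + static_cast<T>(bias);
  }
}

}  // namespace phi

// PD_REGISTER_KERNEL replaces REGISTER_OP_CPU_KERNEL / REGISTER_OP_CUDA_KERNEL.
PD_REGISTER_KERNEL(example, CPU, ALL_LAYOUT, phi::ExampleKernel, float, double) {}
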
@@ -12,7 +12,8 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include "paddle/fluid/operators/class_center_sample_op.h"
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/framework/op_version_registry.h"
 
 namespace paddle {
 namespace operators {
@@ -143,10 +144,6 @@ class ClassCenterSampleOpMaker : public framework::OpProtoAndCheckerMaker {
 }  // namespace paddle
 
 namespace ops = paddle::operators;
-namespace plat = paddle::platform;
 REGISTER_OP_WITHOUT_GRADIENT(class_center_sample,
                              ops::ClassCenterSampleOp,
                              ops::ClassCenterSampleOpMaker);
-REGISTER_OP_CPU_KERNEL(class_center_sample,
-                       ops::ClassCenterSampleCPUKernel<int64_t>,
-                       ops::ClassCenterSampleCPUKernel<int>);
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <map>
#include <set>
#include <vector>
#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T>
class ClassCenterSampleCPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* label = ctx.Input<Tensor>("Label");
auto* remapped_label = ctx.Output<Tensor>("RemappedLabel");
auto* sampled_local_class_center =
ctx.Output<Tensor>("SampledLocalClassCenter");
int num_classes = ctx.Attr<int>("num_classes");
int num_samples = ctx.Attr<int>("num_samples");
int seed = ctx.Attr<int>("seed");
bool fix_seed = ctx.Attr<bool>("fix_seed");
PADDLE_ENFORCE_GT(num_classes,
0,
platform::errors::InvalidArgument(
"The value 'num_classes' for Op(class_center_sample) "
"must be greater than 0, "
"but the value given is %d.",
num_classes));
PADDLE_ENFORCE_GT(num_samples,
0,
platform::errors::InvalidArgument(
"The value 'num_samples' for Op(class_center_sample) "
"must be greater than 0, "
"but the value given is %d.",
num_samples));
PADDLE_ENFORCE_LE(num_samples,
num_classes,
platform::errors::InvalidArgument(
"The value 'num_samples' for Op(class_center_sample) "
"must be less than or equal to %d, "
"but the value given is %d.",
num_classes,
num_samples));
int64_t numel = label->numel();
auto* label_ptr = label->data<T>();
// get unique positive class center by ascending
std::set<T, std::less<T>> unique_label;
for (int64_t i = 0; i < numel; ++i) {
unique_label.insert(label_ptr[i]);
}
    // construct a lookup table and get sampled_local_class_center
std::vector<T> actual_sampled;
std::map<T, T> new_class_dict;
T idx = 0;
for (auto& t : unique_label) {
new_class_dict[t] = idx;
actual_sampled.push_back(t);
idx++;
}
if (!fix_seed) {
std::random_device rnd;
seed = rnd();
}
std::uniform_int_distribution<T> dist(0, num_classes - 1);
auto engine = framework::GetCPURandomEngine(seed);
// sample negative class center randomly
while (unique_label.size() < static_cast<size_t>(num_samples)) {
T neg = dist(*engine);
if (unique_label.find(neg) == unique_label.end()) {
unique_label.insert(neg);
        // negative class centers are kept in sampling order (unsorted)
actual_sampled.push_back(neg);
}
}
int actual_num_samples = unique_label.size();
T* sampled_local_class_center_ptr =
sampled_local_class_center->mutable_data<T>({actual_num_samples},
ctx.GetPlace());
idx = 0;
for (auto& t : actual_sampled) {
sampled_local_class_center_ptr[idx] = t;
idx++;
}
// remap the input label to sampled class
auto* remmaped_label_ptr = remapped_label->mutable_data<T>(ctx.GetPlace());
for (int64_t i = 0; i < numel; ++i) {
remmaped_label_ptr[i] = new_class_dict[label_ptr[i]];
}
}
};
} // namespace operators
} // namespace paddle
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void ClassCenterSampleKernel(const Context& dev_ctx,
const DenseTensor& label,
int num_classes,
int num_samples,
int ring_id,
int rank,
int nranks,
bool fix_seed,
int seed,
DenseTensor* remapped_label,
DenseTensor* sampled_local_class_center);
} // namespace phi
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <map>
#include <set>
#include <vector>
#include "paddle/fluid/framework/generator.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
template <typename T, typename Context>
void ClassCenterSampleKernel(const Context& dev_ctx,
const DenseTensor& label,
int num_classes,
int num_samples,
int ring_id,
int rank,
int nranks,
bool fix_seed,
int seed,
DenseTensor* remapped_label,
DenseTensor* sampled_local_class_center) {
PADDLE_ENFORCE_GT(num_classes,
0,
errors::InvalidArgument(
"The value 'num_classes' for Op(class_center_sample) "
"must be greater than 0, "
"but the value given is %d.",
num_classes));
PADDLE_ENFORCE_GT(num_samples,
0,
errors::InvalidArgument(
"The value 'num_samples' for Op(class_center_sample) "
"must be greater than 0, "
"but the value given is %d.",
num_samples));
PADDLE_ENFORCE_LE(num_samples,
num_classes,
errors::InvalidArgument(
"The value 'num_samples' for Op(class_center_sample) "
"must be less than or equal to %d, "
"but the value given is %d.",
num_classes,
num_samples));
int64_t numel = label.numel();
auto* label_ptr = label.data<T>();
// get unique positive class center by ascending
std::set<T, std::less<T>> unique_label;
for (int64_t i = 0; i < numel; ++i) {
unique_label.insert(label_ptr[i]);
}
  // construct a lookup table and get sampled_local_class_center
std::vector<T> actual_sampled;
std::map<T, T> new_class_dict;
T idx = 0;
for (auto& t : unique_label) {
new_class_dict[t] = idx;
actual_sampled.push_back(t);
idx++;
}
if (!fix_seed) {
std::random_device rnd;
seed = rnd();
}
std::uniform_int_distribution<T> dist(0, num_classes - 1);
auto engine = paddle::framework::GetCPURandomEngine(seed);
// sample negative class center randomly
while (unique_label.size() < static_cast<size_t>(num_samples)) {
T neg = dist(*engine);
if (unique_label.find(neg) == unique_label.end()) {
unique_label.insert(neg);
      // negative class centers are kept in sampling order (unsorted)
actual_sampled.push_back(neg);
}
}
int actual_num_samples = unique_label.size();
sampled_local_class_center->Resize({actual_num_samples});
T* sampled_local_class_center_ptr =
dev_ctx.template Alloc<T>(sampled_local_class_center);
idx = 0;
for (auto& t : actual_sampled) {
sampled_local_class_center_ptr[idx] = t;
idx++;
}
// remap the input label to sampled class
auto* remmaped_label_ptr = dev_ctx.template Alloc<T>(remapped_label);
for (int64_t i = 0; i < numel; ++i) {
remmaped_label_ptr[i] = new_class_dict[label_ptr[i]];
}
}
} // namespace phi
PD_REGISTER_KERNEL(class_center_sample,
CPU,
ALL_LAYOUT,
phi::ClassCenterSampleKernel,
int64_t,
int) {}
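
The CPU kernel above reduces to a small host-side procedure: keep every class id that occurs in the label (the positive centers, in ascending order), draw uniform random negatives until num_samples distinct centers exist, and remap each label to the index of its center. A self-contained sketch of that logic outside Paddle (illustrative only; std::mt19937_64 stands in for the framework's CPU generator):

#include <cstdint>
#include <iostream>
#include <map>
#include <random>
#include <set>
#include <utility>
#include <vector>

// Standalone version of the sampling logic above: positive centers (classes
// present in `labels`) are always kept, in ascending order; extra negative
// centers are drawn uniformly until num_samples distinct centers exist.
std::pair<std::vector<int64_t>, std::vector<int64_t>> SampleClassCenters(
    const std::vector<int64_t>& labels, int num_classes, int num_samples,
    uint64_t seed) {
  std::set<int64_t> unique_label(labels.begin(), labels.end());
  std::vector<int64_t> sampled(unique_label.begin(), unique_label.end());

  // lookup table: positive class id -> new (remapped) index
  std::map<int64_t, int64_t> new_class_dict;
  int64_t idx = 0;
  for (int64_t c : unique_label) new_class_dict[c] = idx++;

  // sample negative class centers until num_samples distinct centers exist
  std::mt19937_64 engine(seed);
  std::uniform_int_distribution<int64_t> dist(0, num_classes - 1);
  while (unique_label.size() < static_cast<size_t>(num_samples)) {
    int64_t neg = dist(engine);
    if (unique_label.insert(neg).second) sampled.push_back(neg);
  }

  // remap every input label to its index among the positive centers
  std::vector<int64_t> remapped;
  remapped.reserve(labels.size());
  for (int64_t l : labels) remapped.push_back(new_class_dict[l]);
  return {remapped, sampled};
}

int main() {
  auto result = SampleClassCenters({20, 30, 20, 5}, /*num_classes=*/100,
                                   /*num_samples=*/6, /*seed=*/42);
  for (int64_t r : result.first) std::cout << r << ' ';   // remapped labels: 1 2 1 0
  std::cout << '\n';
  for (int64_t c : result.second) std::cout << c << ' ';  // centers: 5 20 30, then 3 random negatives
  std::cout << '\n';
}

With labels {20, 30, 20, 5}, the positive centers 5, 20, 30 always lead the sampled list and the labels remap to 1, 2, 1, 0, which is exactly what the kernel writes into SampledLocalClassCenter and RemappedLabel.
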
-// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -29,25 +29,24 @@ namespace cub = hipcub;
 #include <iterator>
 #include <random>
 
-#include "paddle/fluid/operators/class_center_sample_op.h"
-#include "paddle/phi/api/include/tensor.h"
+#include "paddle/fluid/framework/tensor_util.h"
+#include "paddle/fluid/platform/enforce.h"
 #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
 #include "paddle/fluid/distributed/collective/ProcessGroup.h"
 #include "paddle/fluid/platform/collective_helper.h"
 #include "paddle/fluid/platform/device/gpu/nccl_helper.h"
 #endif
+#include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/core/kernel_registry.h"
 
-namespace paddle {
-namespace operators {
+namespace phi {
 
 #define CUDA_KERNEL_LOOP(i, n)                              \
   for (int32_t i = blockIdx.x * blockDim.x + threadIdx.x,   \
                step = blockDim.x * gridDim.x;               \
        i < (n);                                             \
        i += step)
 
-using Tensor = framework::Tensor;
 
 static constexpr int kNumCUDAThreads = 512;
 static constexpr int kNumMaxinumNumBlocks = 4096;
@@ -246,13 +245,13 @@ struct ActualNumSampledFunctor {
   explicit ActualNumSampledFunctor(const T num) : num_samples(num) {}
 };
 
-template <typename T>
+template <typename T, typename Context>
 class MemoryBuffer {
  public:
   MemoryBuffer(const int num_buffer_ele,
                const int num_temp_ele,
                const int nranks,
-               const platform::Place& place) {
+               const Context& dev_ctx) {
     offset1 = 0;
     offset2 = offset1 + num_buffer_ele;
     offset3 = offset2 + num_buffer_ele;
@@ -263,8 +262,8 @@ class MemoryBuffer {
     offset8 = offset7 + (nranks + 1);
     offset9 = offset8 + num_temp_ele;
 
-    buffer_ptr = buffer.mutable_data<T>(
-        {4 * num_buffer_ele + 3 * (nranks + 1) + num_temp_ele}, place);
+    buffer.Resize({4 * num_buffer_ele + 3 * (nranks + 1) + num_temp_ele});
+    buffer_ptr = dev_ctx.template Alloc<T>(&buffer);
   }
 
   T* cub_sort_keys_ptr() { return buffer_ptr + offset1; }
@@ -279,7 +278,7 @@
   }
 
  private:
-  Tensor buffer;
+  DenseTensor buffer;
   T* buffer_ptr;
   int offset1;
   int offset2;
@@ -292,320 +291,310 @@ class MemoryBuffer {
   int offset9;
 };

(Removed in this hunk: the fluid ClassCenterSampleCUDAKernel<DeviceContext, T>::Compute implementation, which read Label, RemappedLabel, SampledLocalClassCenter and the num_classes / num_samples / ring_id / rank / nranks / seed / fix_seed attributes from the ExecutionContext, launched the same steps 1-14 through ctx.cuda_device_context(), and was registered via REGISTER_OP_CUDA_KERNEL(class_center_sample, ...) for int64_t and int. Added in its place is the phi GPU kernel:)

template <typename T, typename Context>
void ClassCenterSampleKernel(const Context& dev_ctx,
                             const DenseTensor& label,
                             int num_classes,
                             int num_samples,
                             int ring_id,
                             int rank,
                             int nranks,
                             bool fix_seed,
                             int seed,
                             DenseTensor* remapped_label,
                             DenseTensor* sampled_local_class_center) {
  PADDLE_ENFORCE_GT(num_classes,
                    0,
                    errors::InvalidArgument(
                        "The value 'num_classes' for Op(class_center_sample) "
                        "must be greater than 0, "
                        "but the value given is %d.",
                        num_classes));

  PADDLE_ENFORCE_GT(num_samples,
                    0,
                    errors::InvalidArgument(
                        "The value 'num_samples' for Op(class_center_sample) "
                        "must be greater than 0, "
                        "but the value given is %d.",
                        num_samples));

  PADDLE_ENFORCE_LE(num_samples,
                    num_classes,
                    errors::InvalidArgument(
                        "The value 'num_samples' for Op(class_center_sample) "
                        "must be less than or equal to %d, "
                        "but the value given is %d.",
                        num_classes,
                        num_samples));

  auto place = dev_ctx.GetPlace();

  int batch_size = label.numel();
  // Algorithm:
  // We first randomly generate a value in [0, num_classes) on each position
  // in a array(shape[num_classes]). Then, we mark the element as negative
  // value in the array according input label. Now, we can sort the array
  // by ascending to ensure that the positive class center always in the
  // front of the sorted array. So, we can get the sampled class center
  // index by sorted keys. Finally, we can get the rempped label by remap
  // the input label according sampled class center.

  // step 1: Calculate num classes per device using nccl all reduce
  std::vector<T> shard_dim_vec(nranks + 1, 0);
  shard_dim_vec[rank + 1] = num_classes;
  DenseTensor num_classes_per_device;
  paddle::framework::TensorFromVector(
      shard_dim_vec, dev_ctx, &num_classes_per_device);
  T* num_classes_per_device_ptr = num_classes_per_device.data<T>();

#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
  if (nranks > 1) {
    auto map = paddle::distributed::ProcessGroupMapFromGid::getInstance();
    if (map->has(ring_id)) {
      // Use ProcessGroup
      paddle::distributed::ProcessGroup* pg = map->get(ring_id);
      std::vector<phi::DenseTensor> in_tensor;
      std::vector<phi::DenseTensor> out_tensor;
      in_tensor.push_back(num_classes_per_device);
      out_tensor.push_back(num_classes_per_device);

      paddle::distributed::AllreduceOptions opts;
      opts.reduce_op = paddle::distributed::ReduceOp::SUM;
      auto task = pg->AllReduce(in_tensor, out_tensor, opts);
      task->Wait();
    } else {
      const auto& comm = paddle::platform::NCCLCommContext::Instance().Get(
          ring_id, dev_ctx.GetPlace());
      // use global calculate stream
      const auto calcu_stream =
          static_cast<GPUContext*>(
              paddle::platform::DeviceContextPool::Instance().Get(
                  dev_ctx.GetPlace()))
              ->stream();
      PADDLE_ENFORCE_GPU_SUCCESS(paddle::platform::dynload::ncclAllReduce(
          num_classes_per_device_ptr,
          num_classes_per_device_ptr,
          num_classes_per_device.numel(),
          paddle::platform::ToNCCLDataType(
              paddle::framework::TransToProtoVarType(
                  num_classes_per_device.dtype())),
          ncclSum,
          comm->comm(),
          calcu_stream));
    }
  }
#endif

  // step 2: Determine temporary device storage requirements
  int num_buffer_ele = std::max(batch_size, num_classes);
  size_t cub_sort_temp_store_size = 0;
  PADDLE_ENFORCE_GPU_SUCCESS(
      (cub::DeviceRadixSort::SortPairs<T, T>(nullptr,
                                             cub_sort_temp_store_size,
                                             nullptr,
                                             nullptr,
                                             nullptr,
                                             nullptr,
                                             num_buffer_ele,
                                             0,
                                             sizeof(T) * 8,
                                             dev_ctx.stream())));

  size_t cub_sum_temp_store_size = 0;
  NotEqualToPreviousAdjacentIterator<T> unique_counting_iter_temp(nullptr, 0);
  PADDLE_ENFORCE_GPU_SUCCESS(
      (cub::DeviceScan::InclusiveSum<NotEqualToPreviousAdjacentIterator<T>, T*>(
          nullptr,
          cub_sum_temp_store_size,
          unique_counting_iter_temp,
          nullptr,
          batch_size,
          dev_ctx.stream())));

  size_t cub_scan_temp_store_size = 0;
  ActualNumSampledFunctor<T> actual_num_sampled_op_temp(num_samples);
  PADDLE_ENFORCE_GPU_SUCCESS(
      (cub::DeviceScan::InclusiveScan(nullptr,
                                      cub_scan_temp_store_size,
                                      num_classes_per_device_ptr,
                                      num_classes_per_device_ptr,
                                      actual_num_sampled_op_temp,
                                      nranks + 1,
                                      dev_ctx.stream())));

  size_t cub_temp_storage_bytes =
      std::max(std::max(cub_sort_temp_store_size, cub_scan_temp_store_size),
               cub_sum_temp_store_size);
  int num_temp_ele = cub_temp_storage_bytes / sizeof(T) + 1;

  // step 3: Alloc buffer memory so that we can reuse allocated memory
  MemoryBuffer<T, Context> memory_buffer =
      MemoryBuffer<T, Context>(num_buffer_ele, num_temp_ele, nranks, dev_ctx);

  T* cub_sort_keys_ptr = memory_buffer.cub_sort_keys_ptr();
  T* cub_sort_keys_out_ptr = memory_buffer.cub_sort_keys_out_ptr();
  T* cub_sort_values_ptr = memory_buffer.cub_sort_values_ptr();
  T* cub_sort_values_out_ptr = memory_buffer.cub_sort_values_out_ptr();
  T* bound_index_ptr = memory_buffer.bound_index_ptr();
  T* bound_value_ptr = memory_buffer.bound_value_ptr();
  T* class_interval_ptr = memory_buffer.class_interval_ptr();
  void* cub_temp_storage_ptr = memory_buffer.cub_temp_storage_ptr();

  // step 4: Calculate class interval among nranks
  PADDLE_ENFORCE_GPU_SUCCESS(
      (cub::DeviceScan::InclusiveSum(cub_temp_storage_ptr,
                                     cub_temp_storage_bytes,
                                     num_classes_per_device_ptr,
                                     class_interval_ptr,
                                     nranks + 1,
                                     dev_ctx.stream())));

  // step 5: random sample negative class center
  uint64_t seed_data;
  uint64_t increment;
  int vec_size = VectorizedSize<T>(cub_sort_keys_ptr);
  auto offset = ((num_classes - 1) /
                     (NumBlocks(num_classes) * kNumCUDAThreads * vec_size) +
                 1) *
                vec_size;
  // auto gen_cuda = paddle::framework::DefaultCUDAGenerator(device_id);
  auto gen_cuda = dev_ctx.GetGenerator();
  if (!fix_seed) {
    auto seed_offset = gen_cuda->IncrementOffset(offset);
    seed_data = seed_offset.first;
    increment = seed_offset.second;
  } else {
    seed_data = seed + rank;
    increment = offset;
  }
  RandomSampleClassCenter<T>
      <<<NumBlocks(num_classes), kNumCUDAThreads, 0, dev_ctx.stream()>>>(
          num_classes, seed_data, increment, num_classes, cub_sort_keys_ptr);

  // step 6: mark positive class center as negative value
  // fill the sort values to index 0, 1, ..., batch_size-1
  MarkPositiveClassCenter<T>
      <<<NumBlocks(batch_size), kNumCUDAThreads, 0, dev_ctx.stream()>>>(
          batch_size,
          rank,
          class_interval_ptr,
          num_classes,
          label.data<T>(),
          cub_sort_keys_ptr);
  Range<T><<<NumBlocks(num_buffer_ele), kNumCUDAThreads, 0, dev_ctx.stream()>>>(
      num_buffer_ele, cub_sort_values_ptr);

  // step 7: sort class center by ascending, so that positive class center
  // always be sampled.
  PADDLE_ENFORCE_GPU_SUCCESS(
      (cub::DeviceRadixSort::SortPairs<T, T>(cub_temp_storage_ptr,
                                             cub_temp_storage_bytes,
                                             cub_sort_keys_ptr,
                                             cub_sort_keys_out_ptr,
                                             cub_sort_values_ptr,
                                             cub_sort_values_out_ptr,
                                             num_classes,
                                             0,
                                             sizeof(T) * 8,
                                             dev_ctx.stream())));

  // step 8: sort input label ascending
  PADDLE_ENFORCE_GPU_SUCCESS(
      (cub::DeviceRadixSort::SortPairs<T, T>(cub_temp_storage_ptr,
                                             cub_temp_storage_bytes,
                                             label.data<T>(),
                                             cub_sort_keys_out_ptr,
                                             cub_sort_values_ptr,
                                             cub_sort_keys_ptr,
                                             batch_size,
                                             0,
                                             sizeof(T) * 8,
                                             dev_ctx.stream())));

  // step 9: Calculate new index using InclusiveSum on ascending sorted input
  // label
  NotEqualToPreviousAdjacentIterator<T> unique_counting_iter(
      cub_sort_keys_out_ptr, 0);
  PADDLE_ENFORCE_GPU_SUCCESS(
      (cub::DeviceScan::InclusiveSum<NotEqualToPreviousAdjacentIterator<T>, T*>(
          cub_temp_storage_ptr,
          cub_temp_storage_bytes,
          unique_counting_iter,
          cub_sort_values_ptr,
          batch_size,
          dev_ctx.stream())));

  // step 10: Calculate new class center bound among ranks
  GetClassCenterBound<T>
      <<<NumBlocks(batch_size), kNumCUDAThreads, 0, dev_ctx.stream()>>>(
          batch_size,
          nranks,
          class_interval_ptr,
          cub_sort_keys_out_ptr,
          cub_sort_values_ptr,
          bound_index_ptr,
          bound_value_ptr);

  // step 11: Calculate actual number of sampled class per device.
  // Since maybe num_positive_class_center > num_samples,
  // we need to ensure all positive class center per device are sampled.
  ActualNumSampledFunctor<T> actual_num_sampled_op(num_samples);
  PADDLE_ENFORCE_GPU_SUCCESS(
      (cub::DeviceScan::InclusiveScan(cub_temp_storage_ptr,
                                      cub_temp_storage_bytes,
                                      bound_value_ptr,
                                      num_classes_per_device_ptr,
                                      actual_num_sampled_op,
                                      nranks + 1,
                                      dev_ctx.stream())));

  // step 12: Calculate actual sampled class interval among nranks
  PADDLE_ENFORCE_GPU_SUCCESS(
      (cub::DeviceScan::InclusiveSum(cub_temp_storage_ptr,
                                     cub_temp_storage_bytes,
                                     num_classes_per_device_ptr,
                                     class_interval_ptr,
                                     nranks + 1,
                                     dev_ctx.stream())));

  // step 13: Get remapped label for output
  GetRemappedLabel<T>
      <<<NumBlocks(batch_size), kNumCUDAThreads, 0, dev_ctx.stream()>>>(
          batch_size,
          nranks,
          class_interval_ptr,
          bound_index_ptr,
          bound_value_ptr,
          cub_sort_keys_ptr,
          cub_sort_values_ptr,
          dev_ctx.template Alloc<T>(remapped_label));

  // step 14: Get sampled class center for output
  paddle::framework::TensorCopySync(
      num_classes_per_device, phi::CPUPlace(), &num_classes_per_device);
  // phi::Copy<Context>(dev_ctx,
  //                    num_classes_per_device,
  //                    phi::CPUPlace(),
  //                    true,
  //                    &num_classes_per_device);
  T actual_num_samples = num_classes_per_device.data<T>()[rank + 1];
  sampled_local_class_center->Resize(phi::make_ddim({actual_num_samples}));
  T* sampled_local_class_center_ptr =
      dev_ctx.template Alloc<T>(sampled_local_class_center);
  paddle::memory::Copy(dev_ctx.GetPlace(),
                       sampled_local_class_center_ptr,
                       dev_ctx.GetPlace(),
                       cub_sort_values_out_ptr,
                       actual_num_samples * sizeof(T),
                       nullptr);
}
}  // namespace phi

PD_REGISTER_KERNEL(class_center_sample,
                   GPU,
                   ALL_LAYOUT,
                   phi::ClassCenterSampleKernel,
                   int64_t,
                   int) {}
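
The GPU kernel keeps the same contract but reformulates sampling so it maps onto cub primitives: draw one random key per class (step 5), force the keys of positive classes negative (step 6), then radix-sort class indices by key (step 7) so positives land at the front and the first num_samples sorted indices are the sampled centers. A host-side illustration of that mark-and-sort idea (plain C++, no CUDA; illustrative only, not the kernel code):

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <numeric>
#include <random>
#include <set>
#include <vector>

// Host-side illustration of the mark-and-sort sampling used by the GPU kernel:
// positive class centers get strictly negative keys, so an ascending key sort
// places them before every randomly keyed negative class.
int main() {
  const int num_classes = 10, num_samples = 4;
  const std::vector<int64_t> labels = {7, 2, 7};  // positive centers: 2 and 7

  // step 5: one random key in [0, num_classes) per class
  std::mt19937_64 engine(0);
  std::uniform_int_distribution<int64_t> dist(0, num_classes - 1);
  std::vector<int64_t> keys(num_classes);
  for (auto& k : keys) k = dist(engine);

  // step 6: mark each positive class center with a strictly negative key
  std::set<int64_t> positives(labels.begin(), labels.end());
  for (int64_t c : positives) keys[c] = -1 - keys[c];

  // values carry the class indices that get reordered alongside the keys
  std::vector<int64_t> values(num_classes);
  std::iota(values.begin(), values.end(), 0);

  // step 7: sort class indices by key ascending
  // (cub::DeviceRadixSort::SortPairs plays this role on the device)
  std::sort(values.begin(), values.end(),
            [&](int64_t a, int64_t b) { return keys[a] < keys[b]; });

  // the first num_samples sorted indices are the sampled centers;
  // the positive centers 2 and 7 are always among them
  for (int i = 0; i < num_samples; ++i) std::cout << values[i] << ' ';
  std::cout << '\n';
}
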
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature ClassCenterSampleOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("class_center_sample",
{"Label"},
{"num_classes",
"num_samples",
"ring_id",
"rank",
"nranks",
"fix_seed",
"seed"},
{"RemappedLabel", "SampledLocalClassCenter"});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(class_center_sample,
phi::ClassCenterSampleOpArgumentMapping);