Unverified commit 100db44f authored by Guoxia Wang, committed by GitHub

support class center sample of PartialFC (#34106)

* support class center sample of PartialFC
Parent c7070cb8
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/class_center_sample_op.h"
namespace paddle {
namespace operators {
class ClassCenterSampleOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("Label"), "Input", "Label",
"ClassCenterSample");
OP_INOUT_CHECK(ctx->HasOutput("RemappedLabel"), "Output", "RemappedLabel",
"ClassCenterSample");
OP_INOUT_CHECK(ctx->HasOutput("SampledLocalClassCenter"), "Output",
"SampledLocalClassCenter", "ClassCenterSample");
auto x_dims = ctx->GetInputDim("Label");
PADDLE_ENFORCE_EQ(x_dims.size(), 1,
platform::errors::InvalidArgument(
"Rank of Input(Label) should be equal to 1, "
"but the value given is %d.",
x_dims.size()));
ctx->SetOutputDim("RemappedLabel", x_dims);
auto num_samples = ctx->Attrs().Get<int>("num_samples");
ctx->SetOutputDim("SampledLocalClassCenter",
framework::make_ddim({num_samples}));
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "Label"),
ctx.device_context());
}
};
class ClassCenterSampleOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput(
"Label",
"(Tensor<int|int64>) The input of ClassCenterSample op. Each value "
"of Label is an integer label.");
AddOutput("RemappedLabel",
"(Tensor<int|int64>) Output tensor with same shape as Label. "
"Each label is remap using sampled class.");
AddOutput("SampledLocalClassCenter",
"(Tensor<int|int64>) The sampled class center for local rank,"
"value in [0, num_classes).");
AddAttr<int>(
"num_classes",
"A positive integer to specify the number of classes at local rank. "
"Note that num_classes of each GPU can be different.");
AddAttr<int>(
"num_samples",
"A positive integer to specify the number of class centers to sample.");
AddAttr<int>("ring_id", "(int default 0) nccl communication ring id.")
.SetDefault(0);
AddAttr<int>("nranks", "(int default 1) The total number of GPUs.")
.SetDefault(1);
AddAttr<int>("rank", "(int default 0) The rank id in nranks.")
.SetDefault(0);
AddAttr<bool>("fix_seed",
"A flag indicating whether to use a fixed seed to generate "
"random negative class center. NOTE: DO NOT set this flag to"
"true in training. Setting this flag to true is only useful "
"in unittest or for debug")
.SetDefault(false);
AddAttr<int>("seed",
"Random seed used to generate random negative class center. "
"[default 0].")
.SetDefault(0);
AddComment(R"DOC(
The class center sample method, proposed in the paper Partial FC, samples only a subset of the class centers.
The process of sampling the subset of class centers is straightforward: 1) first select the positive class centers;
2) then randomly sample negative class centers. Specifically, given a Label tensor of shape [batch_size], select all
the positive class centers and randomly sample negative class centers, then remap the input label tensor using
the sampled class centers. Note that if the number of positive class centers is greater than the input
num_samples, all the positive class centers are kept and the shape of SampledLocalClassCenter will be
[num_positive_class_centers]. The op supports CPU, single GPU and multi GPU.
For more information, see Partial FC: Training 10 Million Identities on a Single Machine,
arxiv: https://arxiv.org/abs/2010.05222
Examples:
For CPU or only one GPU
Given:
Label: [11, 5 , 1 , 3 , 12, 2 , 15, 19, 18, 19]
num_classes = 20
num_samples = 6
Then:
RemappedLabel: [4, 3, 0, 2, 5, 1, 6, 8, 7, 8]
SampledLocalClassCenter: [1 , 2 , 3 , 5 , 11, 12, 15, 18, 19]
For multi GPU
Given:
rank0:
Label: [10, 17, 15, 11, 9 , 12, 18, 18, 17, 18, 19, 2 , 8 , 13, 11, 13, 9 , 10, 0 , 4 ]
num_classes = 10
num_samples = 6
ring_id = 0
nranks = 2
rank = 0
rank1:
Label: [10, 17, 15, 11, 9 , 12, 18, 18, 17, 18, 19, 2 , 8 , 13, 11, 13, 9 , 10, 0 , 4 ]
num_classes = 10
num_samples = 6
ring_id = 0
nranks = 2
rank = 1
Then:
rank0:
RemappedLabel: [6 , 11, 10, 7 , 4 , 8 , 12, 12, 11, 12, 13, 1 , 3 , 9 , 7 , 9 , 4 , 6 , 0 , 2 ]
SampledLocalClassCenter: [0, 2, 4, 8, 9, 3]
rank1:
RemappedLabel: [6 , 11, 10, 7 , 4 , 8 , 12, 12, 11, 12, 13, 1 , 3 , 9 , 7 , 9 , 4 , 6 , 0 , 2 ]
SampledLocalClassCenter: [0, 1, 2, 3, 5, 7, 8]
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_WITHOUT_GRADIENT(class_center_sample, ops::ClassCenterSampleOp,
ops::ClassCenterSampleOpMaker);
REGISTER_OP_CPU_KERNEL(class_center_sample,
ops::ClassCenterSampleCPUKernel<int64_t>,
ops::ClassCenterSampleCPUKernel<int>);
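
A minimal NumPy sketch of the single-rank semantics described in the DOC comment above; the helper name class_center_sample_ref and its structure are illustrative only, not part of this patch:

import numpy as np

def class_center_sample_ref(label, num_classes, num_samples, seed=0):
    # Positive class centers are the unique labels, kept in ascending order.
    rng = np.random.default_rng(seed)
    positives = np.unique(label)
    sampled = list(positives)
    chosen = set(sampled)
    # Pad with distinct random negatives until num_samples centers are held;
    # if there are already >= num_samples positives, nothing is added.
    while len(chosen) < num_samples:
        neg = int(rng.integers(0, num_classes))
        if neg not in chosen:
            chosen.add(neg)
            sampled.append(neg)  # negatives stay in sampling order
    # Remap each label to its index among the ascending positives.
    remap = {int(c): i for i, c in enumerate(positives)}
    remapped = np.array([remap[int(l)] for l in label])
    return remapped, np.array(sampled)

# e.g. label = [11, 5, 1, 3, 12, 2, 15, 19, 18, 19] with num_classes = 20 and
# num_samples = 6 keeps all 9 positives, matching the example in the DOC above.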
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifdef PADDLE_WITH_HIP
#include <hiprand.h>
#include <hiprand_kernel.h>
#include <hipcub/hipcub.hpp>
typedef hiprandState curandState;
namespace cub = hipcub;
#else
#include <curand.h>
#include <curand_kernel.h>
#include <cub/cub.cuh>
#endif
#include <iterator>
#include <random>
#include "paddle/fluid/operators/class_center_sample_op.h"
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/nccl_helper.h"
#endif
namespace paddle {
namespace operators {
#define CUDA_KERNEL_LOOP(i, n) \
for (int32_t i = blockIdx.x * blockDim.x + threadIdx.x, \
step = blockDim.x * gridDim.x; \
i < (n); i += step)
using Tensor = framework::Tensor;
static constexpr int kNumCUDAThreads = 512;
static constexpr int kNumMaximumNumBlocks = 4096;
inline int32_t NumBlocks(const int32_t n) {
return std::min((n + kNumCUDAThreads - 1) / kNumCUDAThreads,
kNumMaximumNumBlocks);
}
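// Fills `buffer` with uniform random integers in [0, max_val). The seed is
// mixed with the thread id (and the call site passes seed + rank), so each
// thread and each rank draws different negative candidates.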
template <typename T>
__global__ void RandomSampleClassCenter(const int64_t n, int64_t seed,
int64_t increment,
const int64_t max_val, T* buffer) {
const int id = blockIdx.x * blockDim.x + threadIdx.x;
curandState localState;
size_t local_seed =
(static_cast<size_t>(seed) + 0x9E3779B9U +
(static_cast<size_t>(id) << 6U) + (static_cast<size_t>(id) >> 2U));
#ifdef PADDLE_WITH_HIP
hiprand_init(local_seed, id, increment, &localState);
CUDA_KERNEL_LOOP(i, n) {
buffer[i] = static_cast<T>(hiprand(&localState) % max_val);
}
#else
curand_init(local_seed, id, increment, &localState);
CUDA_KERNEL_LOOP(i, n) {
buffer[i] = static_cast<T>(curand(&localState) % max_val);
}
#endif
}
template <typename T>
__global__ void Range(const int64_t n, T* out) {
CUDA_KERNEL_LOOP(i, n) { out[i] = static_cast<T>(i); }
}
template <typename T>
__global__ void MarkPositiveClassCenter(const int64_t n, const int64_t rank,
const T* class_interval_ptr,
const int num_classes, const T* labels,
T* out) {
CUDA_KERNEL_LOOP(i, n) {
T label = labels[i] - class_interval_ptr[rank];
if (label >= 0 && label < num_classes) {
out[label] = label - num_classes;
}
}
}
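// Binary search over the (nranks + 1)-element prefix-sum array
// `class_interval_ptr` to find the index of the rank whose class interval
// contains `value`.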
template <typename T>
__device__ void FindIntervalIndex(const T* class_interval_ptr,
const int64_t nranks, const T value,
int64_t* find_index) {
int64_t start = 0;
int64_t end = nranks;
int64_t mid = ((end - start) >> 1) + start + 1;
while (start < end) {
if (class_interval_ptr[mid] == value) break;
if (class_interval_ptr[mid] > value)
end = mid - 1;
else
start = mid;
mid = ((end - start) >> 1) + start + 1;
}
*find_index = min(mid, end);
}
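// For the ascending-sorted keys, record for each rank the first position
// (`bound_index`) and the running unique count (`bound_value`) at which that
// rank's class interval starts; the second loop patches ranks whose interval
// contains no key.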
template <typename T>
__global__ void GetClassCenterBound(const int64_t n, const int64_t nranks,
const T* class_interval_ptr,
const T* key_ptr, const T* value_ptr,
T* bound_index, T* bound_value) {
CUDA_KERNEL_LOOP(i, n) {
if (i != 0) {
int64_t cur_index, pre_index;
FindIntervalIndex(class_interval_ptr, nranks, key_ptr[i], &cur_index);
FindIntervalIndex(class_interval_ptr, nranks, key_ptr[i - 1], &pre_index);
if (cur_index > pre_index) {
assert(cur_index < nranks);
#pragma unroll
for (int32_t j = pre_index + 1; j <= cur_index; ++j) {
bound_index[j] = static_cast<T>(i);
bound_value[j] = value_ptr[i];
}
}
}
}
CUDA_KERNEL_LOOP(i, nranks + 1) {
int64_t first_index, last_index;
FindIntervalIndex(class_interval_ptr, nranks, key_ptr[0], &first_index);
FindIntervalIndex(class_interval_ptr, nranks, key_ptr[n - 1], &last_index);
if (i <= first_index) {
bound_index[i] = 0;
bound_value[i] = value_ptr[0];
} else if (i > last_index) {
bound_index[i] = n;
bound_value[i] = value_ptr[n - 1] + 1;
}
}
}
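// Shift each label's per-rank unique index into the global sampled index
// space via `sampled_class_interval_ptr`, then scatter the result back to
// the original position recorded in `label_map_key`.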
template <typename T>
__global__ void GetRemappedLabel(const int64_t n, const int64_t nranks,
const T* sampled_class_interval_ptr,
const T* bound_index, const T* bound_value,
const T* label_map_key, T* label_map_value,
T* mapped_label) {
CUDA_KERNEL_LOOP(i, n) {
#pragma unroll
for (int64_t j = 0; j < nranks; j++) {
if (i >= bound_index[j] && i < bound_index[j + 1]) {
label_map_value[i] =
label_map_value[i] - bound_value[j] + sampled_class_interval_ptr[j];
}
}
mapped_label[label_map_key[i]] = label_map_value[i];
}
}
// aligned vector generates vectorized load/store on CUDA
template <typename T, int Size>
struct alignas(sizeof(T) * Size) AlignedVector {
T val[Size];
};
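// Returns 4 if `pointer` is aligned for AlignedVector<T, 4>, otherwise 1;
// used below to choose the per-thread curand increment granularity.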
template <typename T>
inline int VectorizedSize(const T* pointer) {
uint64_t address = reinterpret_cast<uint64_t>(pointer);
constexpr int vec4 = std::alignment_of<AlignedVector<T, 4>>::value; // NOLINT
if (address % vec4 == 0) {
return 4;
}
return 1;
}
#undef CUDA_KERNEL_LOOP
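// An input iterator that yields 0 at position 0 and, elsewhere, 1 iff
// arr[i] != arr[i - 1]. An InclusiveSum over it turns an ascending-sorted
// array into compact unique indices, e.g. [1, 1, 3, 5, 5] -> [0, 0, 1, 2, 2].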
template <typename T>
class NotEqualToPreviousAdjacentIterator {
public:
using self_type = NotEqualToPreviousAdjacentIterator;
using value_type = T;
using difference_type = std::ptrdiff_t;
using pointer = T*;
using reference = T;
using iterator_category = std::input_iterator_tag;
public:
__host__ __device__ __forceinline__
NotEqualToPreviousAdjacentIterator(const T* arr, int64_t offset)
: arr_(arr), offset_(offset) {}
__host__ __device__ __forceinline__ reference operator*() const {
return offset_ == 0 ? 0 : (arr_[offset_] == arr_[offset_ - 1] ? 0 : 1);
}
template <typename Distance>
__host__ __device__ __forceinline__ self_type operator+(Distance n) const {
self_type ret(arr_, offset_ + n);
return ret;
}
template <typename Distance>
__host__ __device__ __forceinline__ reference operator[](Distance n) const {
return *(*this + n);
}
private:
const T* arr_;
int64_t offset_;
};
template <typename T>
struct ActualNumSampledFunctor {
__host__ __device__ __forceinline__ T operator()(const T& a,
const T& b) const {
return max(num_samples, (b - a));
}
T num_samples;
explicit ActualNumSampledFunctor(const T num) : num_samples(num) {}
};
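// A single device allocation carved into the four sort buffers, the three
// (nranks + 1)-element bound/interval arrays, and cub's temporary storage,
// so the kernel allocates once instead of eight times.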
template <typename T>
class MemoryBuffer {
public:
MemoryBuffer(const int num_buffer_ele, const int num_temp_ele,
const int nranks, const platform::Place& place) {
offset1 = 0;
offset2 = offset1 + num_buffer_ele;
offset3 = offset2 + num_buffer_ele;
offset4 = offset3 + num_buffer_ele;
offset5 = offset4 + num_buffer_ele;
offset6 = offset5 + (nranks + 1);
offset7 = offset6 + (nranks + 1);
offset8 = offset7 + (nranks + 1);
offset9 = offset8 + num_temp_ele;
buffer_ptr = buffer.mutable_data<T>(
{4 * num_buffer_ele + 3 * (nranks + 1) + num_temp_ele}, place);
}
T* cub_sort_keys_ptr() { return buffer_ptr + offset1; }
T* cub_sort_keys_out_ptr() { return buffer_ptr + offset2; }
T* cub_sort_values_ptr() { return buffer_ptr + offset3; }
T* cub_sort_values_out_ptr() { return buffer_ptr + offset4; }
T* bound_index_ptr() { return buffer_ptr + offset5; }
T* bound_value_ptr() { return buffer_ptr + offset6; }
T* class_interval_ptr() { return buffer_ptr + offset7; }
void* cub_temp_storage_ptr() {
return reinterpret_cast<void*>(buffer_ptr + offset8);
}
private:
Tensor buffer;
T* buffer_ptr;
int offset1;
int offset2;
int offset3;
int offset4;
int offset5;
int offset6;
int offset7;
int offset8;
int offset9;
};
template <typename DeviceContext, typename T>
class ClassCenterSampleCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* label = ctx.Input<Tensor>("Label");
auto* remapped_label = ctx.Output<Tensor>("RemappedLabel");
auto* sampled_local_class_center =
ctx.Output<Tensor>("SampledLocalClassCenter");
int num_classes = ctx.Attr<int>("num_classes");
int num_samples = ctx.Attr<int>("num_samples");
int rid = ctx.Attr<int>("ring_id");
int nranks = ctx.Attr<int>("nranks");
int rank = ctx.Attr<int>("rank");
int seed = ctx.Attr<int>("seed");
bool fix_seed = ctx.Attr<bool>("fix_seed");
PADDLE_ENFORCE_GT(num_classes, 0,
platform::errors::InvalidArgument(
"The value 'num_classes' for Op(class_center_sample) "
"must be greater than 0, "
"but the value given is %d.",
num_classes));
PADDLE_ENFORCE_GT(num_samples, 0,
platform::errors::InvalidArgument(
"The value 'num_samples' for Op(class_center_sample) "
"must be greater than 0, "
"but the value given is %d.",
num_samples));
PADDLE_ENFORCE_LE(num_samples, num_classes,
platform::errors::InvalidArgument(
"The value 'num_samples' for Op(class_center_sample) "
"must be less than or equal to %d, "
"but the value given is %d.",
num_classes, num_samples));
auto& dev_ctx = ctx.template device_context<DeviceContext>();
auto place = BOOST_GET_CONST(platform::CUDAPlace, dev_ctx.GetPlace());
int batch_size = label->numel();
// Algorithm:
// We first generate a random value in [0, num_classes) at each position of
// an array of shape [num_classes]. Then we mark the elements corresponding
// to positive class centers (according to the input labels) with negative
// values, so sorting the array in ascending order always places the
// positive class centers at the front. The sampled class center indices
// can then be read from the sorted keys. Finally, we obtain the remapped
// label by remapping the input labels to the sampled class centers.
// step 1: Calculate num classes per device using nccl all reduce
std::vector<T> shard_dim_vec(nranks + 1, 0);
shard_dim_vec[rank + 1] = num_classes;
Tensor num_classes_per_device;
framework::TensorFromVector(shard_dim_vec, ctx.cuda_device_context(),
&num_classes_per_device);
T* num_classes_per_device_ptr = num_classes_per_device.data<T>();
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
if (nranks > 1) {
const auto& comm =
platform::NCCLCommContext::Instance().Get(rid, ctx.GetPlace());
// use global calculate stream
const auto calcu_stream =
static_cast<platform::CUDADeviceContext*>(
platform::DeviceContextPool::Instance().Get(ctx.GetPlace()))
->stream();
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclAllReduce(
num_classes_per_device_ptr, num_classes_per_device_ptr,
num_classes_per_device.numel(),
platform::ToNCCLDataType(num_classes_per_device.type()), ncclSum,
comm->comm(), calcu_stream));
}
#endif
// step 2: Determine temporary device storage requirements
int num_buffer_ele = std::max(batch_size, num_classes);
size_t cub_sort_temp_store_size = 0;
PADDLE_ENFORCE_CUDA_SUCCESS((cub::DeviceRadixSort::SortPairs<T, T>(
nullptr, cub_sort_temp_store_size, nullptr, nullptr, nullptr, nullptr,
num_buffer_ele, 0, sizeof(T) * 8, ctx.cuda_device_context().stream())));
size_t cub_sum_temp_store_size = 0;
NotEqualToPreviousAdjacentIterator<T> unique_counting_iter_temp(nullptr, 0);
PADDLE_ENFORCE_CUDA_SUCCESS(
(cub::DeviceScan::InclusiveSum<NotEqualToPreviousAdjacentIterator<T>,
T*>(
nullptr, cub_sum_temp_store_size, unique_counting_iter_temp,
nullptr, batch_size, ctx.cuda_device_context().stream())));
size_t cub_scan_temp_store_size = 0;
ActualNumSampledFunctor<T> actual_num_sampled_op_temp(num_samples);
PADDLE_ENFORCE_CUDA_SUCCESS((cub::DeviceScan::InclusiveScan(
nullptr, cub_scan_temp_store_size, num_classes_per_device_ptr,
num_classes_per_device_ptr, actual_num_sampled_op_temp, nranks + 1,
ctx.cuda_device_context().stream())));
size_t cub_temp_storage_bytes =
std::max(std::max(cub_sort_temp_store_size, cub_scan_temp_store_size),
cub_sum_temp_store_size);
int num_temp_ele = cub_temp_storage_bytes / sizeof(T) + 1;
// step 3: Alloc buffer memory so that we can reuse allocated memory
MemoryBuffer<T> memory_buffer =
MemoryBuffer<T>(num_buffer_ele, num_temp_ele, nranks, ctx.GetPlace());
T* cub_sort_keys_ptr = memory_buffer.cub_sort_keys_ptr();
T* cub_sort_keys_out_ptr = memory_buffer.cub_sort_keys_out_ptr();
T* cub_sort_values_ptr = memory_buffer.cub_sort_values_ptr();
T* cub_sort_values_out_ptr = memory_buffer.cub_sort_values_out_ptr();
T* bound_index_ptr = memory_buffer.bound_index_ptr();
T* bound_value_ptr = memory_buffer.bound_value_ptr();
T* class_interval_ptr = memory_buffer.class_interval_ptr();
void* cub_temp_storage_ptr = memory_buffer.cub_temp_storage_ptr();
// step 4: Calculate class interval among nranks
PADDLE_ENFORCE_CUDA_SUCCESS((cub::DeviceScan::InclusiveSum(
cub_temp_storage_ptr, cub_temp_storage_bytes,
num_classes_per_device_ptr, class_interval_ptr, nranks + 1,
ctx.cuda_device_context().stream())));
// step 5: randomly sample negative class centers
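// `increment` roughly upper-bounds the random draws per thread for this
// launch (rounded up to a multiple of vec_size) and is passed to
// curand_init/hiprand_init as the sequence offset.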
int vec_size = VectorizedSize<T>(cub_sort_keys_ptr);
int increment = ((num_classes - 1) /
(NumBlocks(num_classes) * kNumCUDAThreads * vec_size) +
1) *
vec_size;
if (!fix_seed) {
std::random_device rnd;
seed = rnd();
}
RandomSampleClassCenter<T><<<NumBlocks(num_classes), kNumCUDAThreads, 0,
ctx.cuda_device_context().stream()>>>(
num_classes, seed + rank, increment, num_classes, cub_sort_keys_ptr);
// step 6: mark positive class centers as negative values, and fill the
// sort values with the indices 0, 1, ..., num_buffer_ele - 1
MarkPositiveClassCenter<<<NumBlocks(batch_size), kNumCUDAThreads, 0,
ctx.cuda_device_context().stream()>>>(
batch_size, rank, class_interval_ptr, num_classes, label->data<T>(),
cub_sort_keys_ptr);
Range<T><<<NumBlocks(num_buffer_ele), kNumCUDAThreads, 0,
ctx.cuda_device_context().stream()>>>(num_buffer_ele,
cub_sort_values_ptr);
// step 7: sort class centers in ascending order so that positive class
// centers are always sampled.
PADDLE_ENFORCE_CUDA_SUCCESS((cub::DeviceRadixSort::SortPairs<T, T>(
cub_temp_storage_ptr, cub_temp_storage_bytes, cub_sort_keys_ptr,
cub_sort_keys_out_ptr, cub_sort_values_ptr, cub_sort_values_out_ptr,
num_classes, 0, sizeof(T) * 8, ctx.cuda_device_context().stream())));
// step 8: sort the input labels in ascending order
PADDLE_ENFORCE_CUDA_SUCCESS((cub::DeviceRadixSort::SortPairs<T, T>(
cub_temp_storage_ptr, cub_temp_storage_bytes, label->data<T>(),
cub_sort_keys_out_ptr, cub_sort_values_ptr, cub_sort_keys_ptr,
batch_size, 0, sizeof(T) * 8, ctx.cuda_device_context().stream())));
// step 9: Calculate the new (compact unique) indices via InclusiveSum over
// the ascending-sorted input labels
NotEqualToPreviousAdjacentIterator<T> unique_counting_iter(
cub_sort_keys_out_ptr, 0);
PADDLE_ENFORCE_CUDA_SUCCESS((cub::DeviceScan::InclusiveSum<
NotEqualToPreviousAdjacentIterator<T>, T*>(
cub_temp_storage_ptr, cub_temp_storage_bytes, unique_counting_iter,
cub_sort_values_ptr, batch_size, ctx.cuda_device_context().stream())));
// step 10: Calculate new class center bound among ranks
GetClassCenterBound<T><<<NumBlocks(batch_size), kNumCUDAThreads, 0,
ctx.cuda_device_context().stream()>>>(
batch_size, nranks, class_interval_ptr, cub_sort_keys_out_ptr,
cub_sort_values_ptr, bound_index_ptr, bound_value_ptr);
// step 11: Calculate the actual number of sampled classes per device.
// Since num_positive_class_centers may exceed num_samples, we must
// ensure that all positive class centers on each device are sampled.
ActualNumSampledFunctor<T> actual_num_sampled_op(num_samples);
PADDLE_ENFORCE_CUDA_SUCCESS((cub::DeviceScan::InclusiveScan(
cub_temp_storage_ptr, cub_temp_storage_bytes, bound_value_ptr,
num_classes_per_device_ptr, actual_num_sampled_op, nranks + 1,
ctx.cuda_device_context().stream())));
// step 12: Calculate actual sampled class interval among nranks
PADDLE_ENFORCE_CUDA_SUCCESS((cub::DeviceScan::InclusiveSum(
cub_temp_storage_ptr, cub_temp_storage_bytes,
num_classes_per_device_ptr, class_interval_ptr, nranks + 1,
ctx.cuda_device_context().stream())));
// step 13: Get remapped label for output
GetRemappedLabel<T><<<NumBlocks(batch_size), kNumCUDAThreads, 0,
ctx.cuda_device_context().stream()>>>(
batch_size, nranks, class_interval_ptr, bound_index_ptr,
bound_value_ptr, cub_sort_keys_ptr, cub_sort_values_ptr,
remapped_label->mutable_data<T>(ctx.GetPlace()));
// step 14: Get sampled class center for output
framework::TensorCopySync(num_classes_per_device, platform::CPUPlace(),
&num_classes_per_device);
T actual_num_samples = num_classes_per_device.data<T>()[rank + 1];
T* sampled_local_class_center_ptr =
sampled_local_class_center->mutable_data<T>({actual_num_samples},
ctx.GetPlace());
memory::Copy(place, sampled_local_class_center_ptr, place,
cub_sort_values_out_ptr, actual_num_samples * sizeof(T),
nullptr);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
class_center_sample,
ops::ClassCenterSampleCUDAKernel<paddle::platform::CUDADeviceContext,
int64_t>,
ops::ClassCenterSampleCUDAKernel<paddle::platform::CUDADeviceContext, int>);
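
The sort-based sampling in steps 5-7 above can be condensed into a NumPy sketch (names are ours; the real kernel additionally handles per-rank class intervals and the remapping in steps 8-13):

import numpy as np

def sample_by_sort(label, num_classes, num_samples, seed=0):
    rng = np.random.default_rng(seed)
    # step 5: draw one random key per class center
    keys = rng.integers(0, num_classes, size=num_classes)
    # step 6: positive class centers get negative keys, so an ascending
    # sort always places them at the front (in ascending class order)
    positives = np.unique(label)
    keys[positives] = positives - num_classes
    # step 7: the values sorted along with the keys are the class ids;
    # argsort yields them directly here
    sampled = np.argsort(keys, kind='stable')
    actual_num_samples = max(num_samples, len(positives))
    return sampled[:actual_num_samples]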
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <map>
#include <set>
#include <vector>
#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T>
class ClassCenterSampleCPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* label = ctx.Input<Tensor>("Label");
auto* remapped_label = ctx.Output<Tensor>("RemappedLabel");
auto* sampled_local_class_center =
ctx.Output<Tensor>("SampledLocalClassCenter");
int num_classes = ctx.Attr<int>("num_classes");
int num_samples = ctx.Attr<int>("num_samples");
int seed = ctx.Attr<int>("seed");
bool fix_seed = ctx.Attr<bool>("fix_seed");
PADDLE_ENFORCE_GT(num_classes, 0,
platform::errors::InvalidArgument(
"The value 'num_classes' for Op(class_center_sample) "
"must be greater than 0, "
"but the value given is %d.",
num_classes));
PADDLE_ENFORCE_GT(num_samples, 0,
platform::errors::InvalidArgument(
"The value 'num_samples' for Op(class_center_sample) "
"must be greater than 0, "
"but the value given is %d.",
num_samples));
PADDLE_ENFORCE_LE(num_samples, num_classes,
platform::errors::InvalidArgument(
"The value 'num_samples' for Op(class_center_sample) "
"must be less than or equal to %d, "
"but the value given is %d.",
num_classes, num_samples));
int64_t numel = label->numel();
auto* label_ptr = label->data<T>();
// get the unique positive class centers in ascending order
std::set<T, std::less<T>> unique_label;
for (int64_t i = 0; i < numel; ++i) {
unique_label.insert(label_ptr[i]);
}
// construct a lookup table and get sampled_local_class_center
std::vector<T> actual_sampled;
std::map<T, T> new_class_dict;
T idx = 0;
for (auto& t : unique_label) {
new_class_dict[t] = idx;
actual_sampled.push_back(t);
idx++;
}
if (!fix_seed) {
std::random_device rnd;
seed = rnd();
}
std::uniform_int_distribution<T> dist(0, num_classes - 1);
auto engine = framework::GetCPURandomEngine(seed);
// sample negative class center randomly
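// rejection sampling: keep drawing until there are num_samples distinct
// centers in total (the positives already in unique_label count)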
while (unique_label.size() < static_cast<size_t>(num_samples)) {
T neg = dist(*engine);
if (unique_label.find(neg) == unique_label.end()) {
unique_label.insert(neg);
// negative class centers are appended in sampling (unsorted) order
actual_sampled.push_back(neg);
}
}
int actual_num_samples = unique_label.size();
T* sampled_local_class_center_ptr =
sampled_local_class_center->mutable_data<T>({actual_num_samples},
ctx.GetPlace());
idx = 0;
for (auto& t : actual_sampled) {
sampled_local_class_center_ptr[idx] = t;
idx++;
}
// remap the input labels to the sampled classes
auto* remapped_label_ptr = remapped_label->mutable_data<T>(ctx.GetPlace());
for (int64_t i = 0; i < numel; ++i) {
remapped_label_ptr[i] = new_class_dict[label_ptr[i]];
}
}
};
} // namespace operators
} // namespace paddle
file(GLOB TEST_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "test_*.py")
string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}")
set(GC_ENVS FLAGS_eager_delete_tensor_gb=0.0 FLAGS_fast_eager_deletion_mode=1 FLAGS_memory_fraction_of_eager_deletion=1.0)
set(dist_ENVS http_proxy="" https_proxy="")
@@ -28,6 +29,7 @@ list(APPEND DIST_TEST_OPS test_parallel_dygraph_pipeline_parallel)
list(APPEND DIST_TEST_OPS test_parallel_dygraph_tensor_parallel)
list(APPEND DIST_TEST_OPS test_parallel_dygraph_sharding_parallel)
list(APPEND DIST_TEST_OPS test_parallel_dygraph_mp_layers)
list(APPEND DIST_TEST_OPS test_parallel_class_center_sample)
list(APPEND DIST_TEST_OPS test_parallel_margin_cross_entropy)
set(MIXED_DIST_TEST_OPS ${DIST_TEST_OPS})
#remove distribute unittests.
@@ -196,6 +198,7 @@ if ((NOT WITH_GPU) AND (NOT WITH_ROCM))
LIST(REMOVE_ITEM TEST_OPS test_mixed_precision)
LIST(REMOVE_ITEM TEST_OPS test_fleet_base_single)
LIST(REMOVE_ITEM TEST_OPS test_dygraph_recompute)
list(REMOVE_ITEM TEST_OPS test_parallel_class_center_sample)
LIST(REMOVE_ITEM TEST_OPS test_parallel_margin_cross_entropy)
elseif(WITH_GPU)
if (${CUDNN_VERSION} VERSION_LESS 7100)
@@ -908,6 +911,7 @@ if(WITH_DISTRIBUTE AND WITH_GPU AND WITH_NCCL)
set_tests_properties(test_parallel_dygraph_tensor_parallel PROPERTIES TIMEOUT 200)
set_tests_properties(test_parallel_dygraph_sharding_parallel PROPERTIES TIMEOUT 120)
set_tests_properties(test_parallel_dygraph_mp_layers PROPERTIES TIMEOUT 120)
set_tests_properties(test_parallel_class_center_sample PROPERTIES TIMEOUT 120)
set_tests_properties(test_parallel_margin_cross_entropy PROPERTIES TIMEOUT 120)
if(${NCCL_VERSION} VERSION_GREATER_EQUAL 2212)
set_tests_properties(test_parallel_dygraph_sparse_embedding PROPERTIES TIMEOUT 120)
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import division
from __future__ import print_function
import unittest
import paddle
import numpy as np
import random
import paddle.distributed as dist
import paddle.fluid as fluid
import paddle.distributed.fleet as fleet
from paddle import framework
def set_random_seed(seed):
"""Set random seed for reproducability."""
random.seed(seed)
np.random.seed(seed)
paddle.seed(seed)
fleet.meta_parallel.model_parallel_random_seed(seed)
def class_center_sample_numpy(label, classes_list, num_samples):
unique_label = np.unique(label)
nranks = len(classes_list)
class_interval = np.cumsum(np.insert(classes_list, 0, 0))
pos_class_center_per_device = []
unique_label_per_device = []
for i in range(nranks):
index = np.logical_and(unique_label >= class_interval[i],
unique_label < class_interval[i + 1])
pos_class_center_per_device.append(unique_label[index] - class_interval[
i])
unique_label_per_device.append(unique_label[index])
num_samples_per_device = []
for pos_class_center in pos_class_center_per_device:
num_samples_per_device.append(max(len(pos_class_center), num_samples))
sampled_class_interval = np.cumsum(np.insert(num_samples_per_device, 0, 0))
remapped_dict = {}
for i in range(nranks):
for idx, v in enumerate(unique_label_per_device[i],
sampled_class_interval[i]):
remapped_dict[v] = idx
remapped_label = []
for l in label:
remapped_label.append(remapped_dict[l])
return remapped_label, pos_class_center_per_device
class TestParallelClassCenterSampleOp(unittest.TestCase):
def setUp(self):
strategy = fleet.DistributedStrategy()
fleet.init(is_collective=True, strategy=strategy)
def test_class_center_sample(self):
rank_id = dist.get_rank()
nranks = dist.get_world_size()
seed = 1025
set_random_seed(seed)
paddle.seed(rank_id * 10)
random.seed(seed)
np.random.seed(seed)
batch_size = 20
num_samples = 6
for dtype in ('int32', 'int64'):
for _ in range(5):
classes_list = np.random.randint(10, 15, (nranks, ))
num_class = np.sum(classes_list)
np_label = np.random.randint(
0, num_class, (batch_size, ), dtype=dtype)
label = paddle.to_tensor(np_label, dtype=dtype)
np_remapped_label, np_sampled_class_center_per_device = class_center_sample_numpy(
np_label, classes_list, num_samples)
remapped_label, sampled_class_index = paddle.nn.functional.class_center_sample(
label, classes_list[rank_id], num_samples)
np.testing.assert_allclose(remapped_label.numpy(),
np_remapped_label)
np_sampled_class_index = np_sampled_class_center_per_device[
rank_id]
np.testing.assert_allclose(
sampled_class_index.numpy()[:len(np_sampled_class_index)],
np_sampled_class_index)
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import math
import random
import paddle
import paddle.fluid.core as core
from op_test import OpTest
from paddle.fluid import Program, program_guard
def class_center_sample_numpy(label, classes_list, num_samples):
unique_label = np.unique(label)
nranks = len(classes_list)
class_interval = np.cumsum(np.insert(classes_list, 0, 0))
pos_class_center_per_device = []
unique_label_per_device = []
for i in range(nranks):
index = np.logical_and(unique_label >= class_interval[i],
unique_label < class_interval[i + 1])
pos_class_center_per_device.append(unique_label[index] - class_interval[
i])
unique_label_per_device.append(unique_label[index])
num_samples_per_device = []
for pos_class_center in pos_class_center_per_device:
num_samples_per_device.append(max(len(pos_class_center), num_samples))
sampled_class_interval = np.cumsum(np.insert(num_samples_per_device, 0, 0))
remapped_dict = {}
for i in range(nranks):
for idx, v in enumerate(unique_label_per_device[i],
sampled_class_interval[i]):
remapped_dict[v] = idx
remapped_label = []
for l in label:
remapped_label.append(remapped_dict[l])
return np.array(remapped_label), np.array(pos_class_center_per_device)
class TestClassCenterSampleOp(OpTest):
def initParams(self):
self.op_type = "class_center_sample"
self.batch_size = 20
self.num_samples = 6
self.num_classes = 10
self.seed = 2021
def init_dtype(self):
self.dtype = np.int64
def init_fix_seed(self):
self.fix_seed = True
def setUp(self):
self.initParams()
self.init_dtype()
self.init_fix_seed()
label = np.random.randint(
0, self.num_classes, (self.batch_size, ), dtype=self.dtype)
remapped_label, sampled_class_center = class_center_sample_numpy(
label, [self.num_classes], self.num_samples)
self.inputs = {'Label': label}
self.outputs = {
'RemappedLabel': remapped_label.astype(self.dtype),
'SampledLocalClassCenter': sampled_class_center.astype(self.dtype)
}
self.attrs = {
'num_classes': self.num_classes,
'num_samples': self.num_samples,
'seed': self.seed,
'fix_seed': self.fix_seed,
}
def test_check_output(self):
self.check_output(no_check_set=['SampledLocalClassCenter'])
class TestClassCenterSampleOpINT32(TestClassCenterSampleOp):
def init_dtype(self):
self.dtype = np.int32
class TestClassCenterSampleOpFixSeed(TestClassCenterSampleOp):
def init_fix_seed(self):
self.fix_seed = True
class TestClassCenterSampleV2(unittest.TestCase):
def setUp(self):
self.initParams()
np.random.seed(self.seed)
paddle.framework.random._manual_program_seed(2021)
self.places = [paddle.fluid.CPUPlace()]
if core.is_compiled_with_cuda():
self.places.append(paddle.fluid.CUDAPlace(0))
def initParams(self):
self.batch_size = 10
self.num_samples = 6
self.num_classes = 20
self.seed = 0
self.init_dtype()
def init_dtype(self):
self.dtype = np.int64
def test_static(self):
for place in self.places:
self.check_static_result(place=place)
def check_static_result(self, place):
with program_guard(Program(), Program()):
label_np = np.random.randint(
0, self.num_classes, (self.batch_size, ), dtype=self.dtype)
label = paddle.static.data(
name='label', shape=[self.batch_size], dtype=self.dtype)
remapped_label, sampled_class_index = paddle.nn.functional.class_center_sample(
label, self.num_classes, self.num_samples, seed=self.seed)
remapped_label_np, sampled_class_center_np = class_center_sample_numpy(
label_np, [self.num_classes], self.num_samples)
exe = paddle.fluid.Executor(place)
[remapped_label_res, sampled_class_index_res] = exe.run(
paddle.fluid.default_main_program(),
feed={'label': label_np},
fetch_list=[remapped_label, sampled_class_index])
np.testing.assert_allclose(remapped_label_res, remapped_label_np)
np.testing.assert_allclose(
sampled_class_index_res[:len(sampled_class_center_np[0])],
sampled_class_center_np[0])
def test_dynamic(self):
for place in self.places:
self.check_dynamic_result(place=place)
def check_dynamic_result(self, place):
with paddle.fluid.dygraph.guard(place):
label_np = np.random.randint(
0, self.num_classes, (self.batch_size, ), dtype=self.dtype)
label = paddle.to_tensor(label_np, dtype=self.dtype)
remapped_label, sampled_class_index = paddle.nn.functional.class_center_sample(
label, self.num_classes, self.num_samples, seed=self.seed)
remapped_label_np, sampled_class_center_np = class_center_sample_numpy(
label_np, [self.num_classes], self.num_samples)
remapped_label_res = remapped_label.numpy()
sampled_class_index_res = sampled_class_index.numpy()
np.testing.assert_allclose(remapped_label_res, remapped_label_np)
np.testing.assert_allclose(
sampled_class_index_res[:len(sampled_class_center_np[0])],
sampled_class_center_np[0])
class TestClassCenterSampleV2INT32(TestClassCenterSampleV2):
def init_dtype(self):
self.dtype = np.int32
class TestClassCenterSampleAPIError(unittest.TestCase):
def setUp(self):
self.initParams()
np.random.seed(self.seed)
self.places = [paddle.fluid.CPUPlace()]
if core.is_compiled_with_cuda():
self.places.append(paddle.fluid.CUDAPlace(0))
def initParams(self):
self.batch_size = 20
self.num_samples = 15
self.num_classes = 10
self.seed = 2021
self.init_dtype()
def init_dtype(self):
self.dtype = np.int64
def test_dynamic_errors(self):
def test_num_samples():
for place in self.places:
with paddle.fluid.dygraph.guard(place):
label_np = np.random.randint(
0,
self.num_classes, (self.batch_size, ),
dtype=self.dtype)
label = paddle.to_tensor(label_np)
remapped_label, sampled_class_index = paddle.nn.functional.class_center_sample(
label,
self.num_classes,
self.num_samples,
seed=self.seed)
self.assertRaises(ValueError, test_num_samples)
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import paddle.fluid as fluid
from test_parallel_dygraph_dataparallel import TestMultipleGpus
class TestParallelClassCenterSample(TestMultipleGpus):
def test_parallel_class_center_sample(self):
self.run_mnist_2gpu('parallel_class_center_sample.py')
if __name__ == "__main__":
unittest.main()
@@ -31,4 +31,5 @@ no_check_set_white_list = [
'rnn',
'fusion_lstm',
'softmax_with_cross_entropy',
'class_center_sample',
]
@@ -55,6 +55,7 @@ from .common import unfold # noqa: F401
from .common import interpolate # noqa: F401
from .common import upsample # noqa: F401
from .common import bilinear # noqa: F401
from .common import class_center_sample # noqa: F401
from .conv import conv1d # noqa: F401
from .conv import conv1d_transpose # noqa: F401
from .common import linear # noqa: F401
@@ -200,5 +201,6 @@ __all__ = [ #noqa
'temporal_shift',
'batch_norm',
'layer_norm',
'instance_norm',
'class_center_sample',
]
@@ -1564,3 +1564,156 @@ def label_smooth(label, prior_dist=None, epsilon=0.1, name=None):
outputs={"Out": smooth_label},
attrs={"epsilon": float(epsilon)})
return smooth_label
def class_center_sample(label, num_classes, num_samples, group=None, seed=None):
"""
The class center sample method, proposed in the paper Partial FC, samples only a subset of the class centers.
The process of sampling the subset of class centers is straightforward:
1. First select the positive class centers;
2. Then randomly sample negative class centers.
Specifically, given a label tensor of shape [batch_size], select all the positive class centers and randomly
sample negative class centers, then remap the input label tensor using the sampled class centers.
For more information, see Partial FC: Training 10 Million Identities on a Single Machine,
arxiv: https://arxiv.org/abs/2010.05222
.. hint::
If the number of positive class centers is greater than the input num_samples, all the positive
class centers are kept and the shape of sampled_class_center will be [num_positive_class_centers].
The API supports CPU, single GPU and multi GPU.
Args:
label (Tensor): 1-D tensor with shape [N], each label in [0, num_classes)
num_classes (int): A positive integer to specify the number of classes at local rank.
Note that num_classes of each GPU can be different.
num_samples (int): A positive integer to specify the number of class centers to sample.
group (Group, optional): The abstract representation of group.
See paddle.distributed.collective.Group. Default is ``None``.
seed (int, optional): Random seed. Default is ``None``.
Returns:
Tuple of two ``Tensor``: (remapped_label, sampled_class_center). remapped_label is the input label
remapped using the sampled class centers; sampled_class_center holds the sampled class centers from [0, num_classes).
Examples:
.. code-block:: python
# CPU or single GPU
import paddle
num_classes = 20
batch_size = 10
num_samples = 6
label = paddle.randint(low=0, high=num_classes, shape=[batch_size], dtype='int64')
remapped_label, sampled_class_index = paddle.nn.functional.class_center_sample(label, num_classes, num_samples)
print(label)
print(remapped_label)
print(sampled_class_index)
# the output is
#Tensor(shape=[10], dtype=int64, place=CPUPlace, stop_gradient=True,
# [11, 5 , 1 , 3 , 12, 2 , 15, 19, 18, 19])
#Tensor(shape=[10], dtype=int64, place=CPUPlace, stop_gradient=True,
# [4, 3, 0, 2, 5, 1, 6, 8, 7, 8])
#Tensor(shape=[9], dtype=int64, place=CPUPlace, stop_gradient=True,
# [1 , 2 , 3 , 5 , 11, 12, 15, 18, 19])
.. code-block:: python
# required: distributed
# Multi GPU, test_class_center_sample.py
import paddle
import paddle.distributed as dist
strategy = dist.fleet.DistributedStrategy()
dist.fleet.init(is_collective=True, strategy=strategy)
batch_size = 10
num_samples = 6
rank_id = dist.get_rank()
# num_classes of each GPU can be different, e.g num_classes_list = [10, 8]
num_classes_list = [10, 10]
num_classes = paddle.sum(paddle.to_tensor(num_classes_list))
label = paddle.randint(low=0, high=num_classes.item(), shape=[batch_size], dtype='int64')
label_list = []
dist.all_gather(label_list, label)
label = paddle.concat(label_list, axis=0)
remapped_label, sampled_class_index = paddle.nn.functional.class_center_sample(label, num_classes_list[rank_id], num_samples)
print(label)
print(remapped_label)
print(sampled_class_index)
#python -m paddle.distributed.launch --gpus=0,1 test_class_center_sample.py
# rank 0 output:
#Tensor(shape=[20], dtype=int64, place=CUDAPlace(0), stop_gradient=True,
# [10, 17, 15, 11, 9 , 12, 18, 18, 17, 18, 19, 2 , 8 , 13, 11, 13, 9 , 10, 0 , 4 ])
#Tensor(shape=[20], dtype=int64, place=CUDAPlace(0), stop_gradient=True,
# [6 , 11, 10, 7 , 4 , 8 , 12, 12, 11, 12, 13, 1 , 3 , 9 , 7 , 9 , 4 , 6 , 0 , 2 ])
#Tensor(shape=[6], dtype=int64, place=CUDAPlace(0), stop_gradient=True,
# [0, 2, 4, 8, 9, 3])
# rank 1 output:
#Tensor(shape=[20], dtype=int64, place=CUDAPlace(1), stop_gradient=True,
# [10, 17, 15, 11, 9 , 12, 18, 18, 17, 18, 19, 2 , 8 , 13, 11, 13, 9 , 10, 0 , 4 ])
#Tensor(shape=[20], dtype=int64, place=CUDAPlace(1), stop_gradient=True,
# [6 , 11, 10, 7 , 4 , 8 , 12, 12, 11, 12, 13, 1 , 3 , 9 , 7 , 9 , 4 , 6 , 0 , 2 ])
#Tensor(shape=[7], dtype=int64, place=CUDAPlace(1), stop_gradient=True,
# [0, 1, 2, 3, 5, 7, 8])
"""
if group is not None and not group.is_member():
return
ring_id = 0 if group is None else group.id
rank = 0
nranks = 1
if core.is_compiled_with_dist():
parallel_env = paddle.distributed.ParallelEnv()
global_rank = parallel_env.rank
rank = global_rank if group is None else group.get_group_rank(
global_rank)
nranks = parallel_env.world_size if group is None else group.nranks
if num_samples > num_classes:
raise ValueError(
'Expected num_samples less than or equal to {}, got num_samples {}'.
format(num_classes, num_samples))
if (seed is None or seed == 0) and default_main_program().random_seed != 0:
seed = default_main_program().random_seed
if in_dygraph_mode():
remapped_label, sampled_class_center = core.ops.class_center_sample(
label, 'num_classes', num_classes, 'num_samples', num_samples,
'ring_id', ring_id, 'nranks', nranks, 'rank', rank, 'fix_seed',
seed is not None, 'seed', seed if seed is not None else 0)
return remapped_label, sampled_class_center
check_variable_and_dtype(label, 'label', ['int64', 'int32'],
'class_center_sample')
op_type = 'class_center_sample'
helper = LayerHelper(op_type, **locals())
remapped_label = helper.create_variable_for_type_inference(
dtype=label.dtype)
sampled_class_center = helper.create_variable_for_type_inference(
dtype=label.dtype)
helper.append_op(
type=op_type,
inputs={'Label': label},
outputs={
'RemappedLabel': remapped_label,
'SampledLocalClassCenter': sampled_class_center
},
attrs={
'num_classes': num_classes,
'num_samples': num_samples,
'ring_id': ring_id,
'nranks': nranks,
'rank': rank,
'fix_seed': seed is not None,
'seed': seed if seed is not None else 0
})
return remapped_label, sampled_class_center
@@ -719,5 +719,6 @@ STATIC_MODE_TESTING_LIST = [
'test_sgd_op_bf16',
'test_marker_op',
'test_c_embedding_op',
'test_class_center_sample_op',
'test_margin_cross_entropy_op',
]