Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
100db44f
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2299
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
100db44f
编写于
8月 18, 2021
作者:
G
Guoxia Wang
提交者:
GitHub
8月 18, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
support class center sample of PartialFC (#34106)
* support class center sample of PartialFC
上级
c7070cb8
变更
11
显示空白变更内容
内联
并排
Showing
11 changed file
with
1271 addition
and
2 deletion
+1271
-2
paddle/fluid/operators/class_center_sample_op.cc
paddle/fluid/operators/class_center_sample_op.cc
+147
-0
paddle/fluid/operators/class_center_sample_op.cu
paddle/fluid/operators/class_center_sample_op.cu
+486
-0
paddle/fluid/operators/class_center_sample_op.h
paddle/fluid/operators/class_center_sample_op.h
+114
-0
python/paddle/fluid/tests/unittests/CMakeLists.txt
python/paddle/fluid/tests/unittests/CMakeLists.txt
+5
-1
python/paddle/fluid/tests/unittests/parallel_class_center_sample.py
...dle/fluid/tests/unittests/parallel_class_center_sample.py
+110
-0
python/paddle/fluid/tests/unittests/test_class_center_sample_op.py
...ddle/fluid/tests/unittests/test_class_center_sample_op.py
+222
-0
python/paddle/fluid/tests/unittests/test_parallel_class_center_sample.py
...luid/tests/unittests/test_parallel_class_center_sample.py
+29
-0
python/paddle/fluid/tests/unittests/white_list/no_check_set_white_list.py
...uid/tests/unittests/white_list/no_check_set_white_list.py
+1
-0
python/paddle/nn/functional/__init__.py
python/paddle/nn/functional/__init__.py
+3
-1
python/paddle/nn/functional/common.py
python/paddle/nn/functional/common.py
+153
-0
tools/static_mode_white_list.py
tools/static_mode_white_list.py
+1
-0
未找到文件。
paddle/fluid/operators/class_center_sample_op.cc
0 → 100644
浏览文件 @
100db44f
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/class_center_sample_op.h"
namespace
paddle
{
namespace
operators
{
class
ClassCenterSampleOp
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"Label"
),
"Input"
,
"Label"
,
"ClassCenterSample"
);
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"RemappedLabel"
),
"Output"
,
"RemappedLabel"
,
"ClassCenterSample"
);
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"SampledLocalClassCenter"
),
"Output"
,
"SampledLocalClassCenter"
,
"ClassCenterSample"
);
auto
x_dims
=
ctx
->
GetInputDim
(
"Label"
);
PADDLE_ENFORCE_EQ
(
x_dims
.
size
(),
1
,
platform
::
errors
::
InvalidArgument
(
"Rank of Input(Label) should be equal to 1, "
"but the value given is %d."
,
x_dims
.
size
()));
ctx
->
SetOutputDim
(
"RemappedLabel"
,
x_dims
);
auto
num_samples
=
ctx
->
Attrs
().
Get
<
int
>
(
"num_samples"
);
ctx
->
SetOutputDim
(
"SampledLocalClassCenter"
,
framework
::
make_ddim
({
num_samples
}));
}
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
OperatorWithKernel
::
IndicateVarDataType
(
ctx
,
"Label"
),
ctx
.
device_context
());
}
};
class
ClassCenterSampleOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
void
Make
()
override
{
AddInput
(
"Label"
,
"(Tensor<int|int64>) The input of ClassCenterSample op. Each value "
"of Label is an integer label."
);
AddOutput
(
"RemappedLabel"
,
"(Tensor<int|int64>) Output tensor with same shape as Label. "
"Each label is remap using sampled class."
);
AddOutput
(
"SampledLocalClassCenter"
,
"(Tensor<int|int64>) The sampled class center for local rank,"
"value in [0, num_classes)."
);
AddAttr
<
int
>
(
"num_classes"
,
"A positive integer to specify the number of classes at local rank. "
"Note that num_classes of each GPU can be different."
);
AddAttr
<
int
>
(
"num_samples"
,
"A positive integer to specify the number of class center to sample."
);
AddAttr
<
int
>
(
"ring_id"
,
"(int default 0) nccl communication ring id."
)
.
SetDefault
(
0
);
AddAttr
<
int
>
(
"nranks"
,
"(int default 1) The total number of GPUs."
)
.
SetDefault
(
1
);
AddAttr
<
int
>
(
"rank"
,
"(int default 0) The rank id in nranks."
)
.
SetDefault
(
0
);
AddAttr
<
bool
>
(
"fix_seed"
,
"A flag indicating whether to use a fixed seed to generate "
"random negative class center. NOTE: DO NOT set this flag to"
"true in training. Setting this flag to true is only useful "
"in unittest or for debug"
)
.
SetDefault
(
false
);
AddAttr
<
int
>
(
"seed"
,
"Random seed used to generate random negative class center. "
"[default 0]."
)
.
SetDefault
(
0
);
AddComment
(
R"DOC(
Class center sample method is proposed from the paper PartialFC that only sample a subset of the class centers.
The process of sampling subset class centers is straightforward: 1) First select the positive class centers;
2) Randomly sample negative class centers. Specifically, given a Label tensor, shape [batch_size], select all
the positive class centers and randomly sample negative class centers, then remap the input label tensor using
the sampled class centers. Note that if the number of the positive class centers is greater than the input
num_samples, it keeps all the positive class centers and the shape of SampledLocalClassCenter will be
[num_positive_class_centers]. The op supports CPU, single GPU and multi GPU.
For more information, Partial FC: Training 10 Million Identities on a Single Machine
arxiv: https://arxiv.org/abs/2010.05222
Examples:
For CPU or only one GPU
Given:
Label: [11, 5 , 1 , 3 , 12, 2 , 15, 19, 18, 19]
num_classes = 20
num_samples = 6
Then:
RemappedLabel: [4, 3, 0, 2, 5, 1, 6, 8, 7, 8]
SampledLocalClassCenter: [1 , 2 , 3 , 5 , 11, 12, 15, 18, 19]
For multi GPU
Given:
rank0:
Label: [10, 17, 15, 11, 9 , 12, 18, 18, 17, 18, 19, 2 , 8 , 13, 11, 13, 9 , 10, 0 , 4 ]
num_classes = 10
num_samples = 6
ring_id = 0
nranks = 2
rank = 0
rank1:
Label: [10, 17, 15, 11, 9 , 12, 18, 18, 17, 18, 19, 2 , 8 , 13, 11, 13, 9 , 10, 0 , 4 ]
num_classes = 10
num_samples = 6
ring_id = 0
nranks = 2
rank = 1
Then:
rank0:
RemappedLabel: [6 , 11, 10, 7 , 4 , 8 , 12, 12, 11, 12, 13, 1 , 3 , 9 , 7 , 9 , 4 , 6 , 0 , 2 ]
SampledLocalClassCenter: [0, 2, 4, 8, 9, 3]
rank1:
RemappedLabel: [6 , 11, 10, 7 , 4 , 8 , 12, 12, 11, 12, 13, 1 , 3 , 9 , 7 , 9 , 4 , 6 , 0 , 2 ]
SampledLocalClassCenter: [0, 1, 2, 3, 5, 7, 8]
)DOC"
);
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
namespace
plat
=
paddle
::
platform
;
REGISTER_OP_WITHOUT_GRADIENT
(
class_center_sample
,
ops
::
ClassCenterSampleOp
,
ops
::
ClassCenterSampleOpMaker
);
REGISTER_OP_CPU_KERNEL
(
class_center_sample
,
ops
::
ClassCenterSampleCPUKernel
<
int64_t
>
,
ops
::
ClassCenterSampleCPUKernel
<
int
>
);
paddle/fluid/operators/class_center_sample_op.cu
0 → 100644
浏览文件 @
100db44f
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifdef PADDLE_WITH_HIP
#include <hiprand.h>
#include <hiprand_kernel.h>
#include <hipcub/hipcub.hpp>
typedef
hiprandState
curandState
;
namespace
cub
=
hipcub
;
#else
#include <curand.h>
#include <curand_kernel.h>
#include <cub/cub.cuh>
#endif
#include <iterator>
#include <random>
#include "paddle/fluid/operators/class_center_sample_op.h"
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/nccl_helper.h"
#endif
namespace
paddle
{
namespace
operators
{
#define CUDA_KERNEL_LOOP(i, n) \
for (int32_t i = blockIdx.x * blockDim.x + threadIdx.x, \
step = blockDim.x * gridDim.x; \
i < (n); i += step)
using
Tensor
=
framework
::
Tensor
;
static
constexpr
int
kNumCUDAThreads
=
512
;
static
constexpr
int
kNumMaxinumNumBlocks
=
4096
;
inline
int32_t
NumBlocks
(
const
int32_t
n
)
{
return
std
::
min
((
n
+
kNumCUDAThreads
-
1
)
/
kNumCUDAThreads
,
kNumMaxinumNumBlocks
);
}
template
<
typename
T
>
__global__
void
RandomSampleClassCenter
(
const
int64_t
n
,
int64_t
seed
,
int64_t
increment
,
const
int64_t
max_val
,
T
*
buffer
)
{
const
int
id
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
curandState
localState
;
size_t
local_seed
=
(
static_cast
<
size_t
>
(
seed
)
+
0x9E3779B9U
+
(
static_cast
<
size_t
>
(
id
)
<<
6U
)
+
(
static_cast
<
size_t
>
(
id
)
>>
2U
));
#ifdef PADDLE_WITH_HIP
hiprand_init
(
local_seed
,
id
,
increment
,
&
localState
);
CUDA_KERNEL_LOOP
(
i
,
n
)
{
buffer
[
i
]
=
static_cast
<
T
>
(
hiprand
(
&
localState
)
%
max_val
);
}
#else
curand_init
(
local_seed
,
id
,
increment
,
&
localState
);
CUDA_KERNEL_LOOP
(
i
,
n
)
{
buffer
[
i
]
=
static_cast
<
T
>
(
curand
(
&
localState
)
%
max_val
);
}
#endif
}
template
<
typename
T
>
__global__
void
Range
(
const
int64_t
n
,
T
*
out
)
{
CUDA_KERNEL_LOOP
(
i
,
n
)
{
out
[
i
]
=
static_cast
<
T
>
(
i
);
}
}
template
<
typename
T
>
__global__
void
MarkPositiveClassCenter
(
const
int64_t
n
,
const
int64_t
rank
,
const
T
*
class_interval_ptr
,
const
int
num_classes
,
const
T
*
labels
,
T
*
out
)
{
CUDA_KERNEL_LOOP
(
i
,
n
)
{
T
label
=
labels
[
i
]
-
class_interval_ptr
[
rank
];
if
(
label
>=
0
&&
label
<
num_classes
)
{
out
[
label
]
=
label
-
num_classes
;
}
}
}
template
<
typename
T
>
__device__
void
FindIntervalIndex
(
const
T
*
class_interval_ptr
,
const
int64_t
nranks
,
const
T
value
,
int64_t
*
find_index
)
{
int64_t
start
=
0
;
int64_t
end
=
nranks
;
int64_t
mid
=
((
end
-
start
)
>>
1
)
+
start
+
1
;
while
(
start
<
end
)
{
if
(
class_interval_ptr
[
mid
]
==
value
)
break
;
if
(
class_interval_ptr
[
mid
]
>
value
)
end
=
mid
-
1
;
else
start
=
mid
;
mid
=
((
end
-
start
)
>>
1
)
+
start
+
1
;
}
*
find_index
=
min
(
mid
,
end
);
}
template
<
typename
T
>
__global__
void
GetClassCenterBound
(
const
int64_t
n
,
const
int64_t
nranks
,
const
T
*
class_interval_ptr
,
const
T
*
key_ptr
,
const
T
*
value_ptr
,
T
*
bound_index
,
T
*
bound_value
)
{
CUDA_KERNEL_LOOP
(
i
,
n
)
{
if
(
i
!=
0
)
{
int64_t
cur_index
,
pre_index
;
FindIntervalIndex
(
class_interval_ptr
,
nranks
,
key_ptr
[
i
],
&
cur_index
);
FindIntervalIndex
(
class_interval_ptr
,
nranks
,
key_ptr
[
i
-
1
],
&
pre_index
);
if
(
cur_index
>
pre_index
)
{
assert
(
cur_index
<
nranks
);
#pragma unroll
for
(
int32_t
j
=
pre_index
+
1
;
j
<=
cur_index
;
++
j
)
{
bound_index
[
j
]
=
static_cast
<
T
>
(
i
);
bound_value
[
j
]
=
value_ptr
[
i
];
}
}
}
}
CUDA_KERNEL_LOOP
(
i
,
nranks
+
1
)
{
int64_t
first_index
,
last_index
;
FindIntervalIndex
(
class_interval_ptr
,
nranks
,
key_ptr
[
0
],
&
first_index
);
FindIntervalIndex
(
class_interval_ptr
,
nranks
,
key_ptr
[
n
-
1
],
&
last_index
);
if
(
i
<=
first_index
)
{
bound_index
[
i
]
=
0
;
bound_value
[
i
]
=
value_ptr
[
0
];
}
else
if
(
i
>
last_index
)
{
bound_index
[
i
]
=
n
;
bound_value
[
i
]
=
value_ptr
[
n
-
1
]
+
1
;
}
}
}
template
<
typename
T
>
__global__
void
GetRemappedLabel
(
const
int64_t
n
,
const
int64_t
nranks
,
const
T
*
sampled_class_interval_ptr
,
const
T
*
bound_index
,
const
T
*
bound_value
,
const
T
*
label_map_key
,
T
*
label_map_value
,
T
*
mapped_label
)
{
CUDA_KERNEL_LOOP
(
i
,
n
)
{
#pragma unroll
for
(
int64_t
j
=
0
;
j
<
nranks
;
j
++
)
{
if
(
i
>=
bound_index
[
j
]
&&
i
<
bound_index
[
j
+
1
])
{
label_map_value
[
i
]
=
label_map_value
[
i
]
-
bound_value
[
j
]
+
sampled_class_interval_ptr
[
j
];
}
}
mapped_label
[
label_map_key
[
i
]]
=
label_map_value
[
i
];
}
}
// aligned vector generates vectorized load/store on CUDA
template
<
typename
T
,
int
Size
>
struct
alignas
(
sizeof
(
T
)
*
Size
)
AlignedVector
{
T
val
[
Size
];
};
template
<
typename
T
>
inline
int
VectorizedSize
(
const
T
*
pointer
)
{
uint64_t
address
=
reinterpret_cast
<
uint64_t
>
(
pointer
);
constexpr
int
vec4
=
std
::
alignment_of
<
AlignedVector
<
T
,
4
>>::
value
;
// NOLINT
if
(
address
%
vec4
==
0
)
{
return
4
;
}
return
1
;
}
#undef CUDA_KERNEL_LOOP
template
<
typename
T
>
class
NotEqualToPreviousAdjacentIterator
{
public:
using
self_type
=
NotEqualToPreviousAdjacentIterator
;
using
value_type
=
T
;
using
difference_type
=
std
::
ptrdiff_t
;
using
pointer
=
T
*
;
using
reference
=
T
;
using
iterator_category
=
std
::
input_iterator_tag
;
public:
__host__
__device__
__forceinline__
NotEqualToPreviousAdjacentIterator
(
const
T
*
arr
,
int64_t
offset
)
:
arr_
(
arr
),
offset_
(
offset
)
{}
__host__
__device__
__forceinline__
reference
operator
*
()
const
{
return
offset_
==
0
?
0
:
(
arr_
[
offset_
]
==
arr_
[
offset_
-
1
]
?
0
:
1
);
}
template
<
typename
Distance
>
__host__
__device__
__forceinline__
self_type
operator
+
(
Distance
n
)
const
{
self_type
ret
(
arr_
,
offset_
+
n
);
return
ret
;
}
template
<
typename
Distance
>
__host__
__device__
__forceinline__
reference
operator
[](
Distance
n
)
const
{
return
*
(
*
this
+
n
);
}
private:
const
T
*
arr_
;
int64_t
offset_
;
};
template
<
typename
T
>
struct
ActualNumSampledFunctor
{
__host__
__device__
__forceinline__
T
operator
()(
const
T
&
a
,
const
T
&
b
)
const
{
return
max
(
num_samples
,
(
b
-
a
));
}
T
num_samples
;
explicit
ActualNumSampledFunctor
(
const
T
num
)
:
num_samples
(
num
)
{}
};
template
<
typename
T
>
class
MemoryBuffer
{
public:
MemoryBuffer
(
const
int
num_buffer_ele
,
const
int
num_temp_ele
,
const
int
nranks
,
const
platform
::
Place
&
place
)
{
offset1
=
0
;
offset2
=
offset1
+
num_buffer_ele
;
offset3
=
offset2
+
num_buffer_ele
;
offset4
=
offset3
+
num_buffer_ele
;
offset5
=
offset4
+
num_buffer_ele
;
offset6
=
offset5
+
(
nranks
+
1
);
offset7
=
offset6
+
(
nranks
+
1
);
offset8
=
offset7
+
(
nranks
+
1
);
offset9
=
offset8
+
num_temp_ele
;
buffer_ptr
=
buffer
.
mutable_data
<
T
>
(
{
4
*
num_buffer_ele
+
3
*
(
nranks
+
1
)
+
num_temp_ele
},
place
);
}
T
*
cub_sort_keys_ptr
()
{
return
buffer_ptr
+
offset1
;
}
T
*
cub_sort_keys_out_ptr
()
{
return
buffer_ptr
+
offset2
;
}
T
*
cub_sort_values_ptr
()
{
return
buffer_ptr
+
offset3
;
}
T
*
cub_sort_values_out_ptr
()
{
return
buffer_ptr
+
offset4
;
}
T
*
bound_index_ptr
()
{
return
buffer_ptr
+
offset5
;
}
T
*
bound_value_ptr
()
{
return
buffer_ptr
+
offset6
;
}
T
*
class_interval_ptr
()
{
return
buffer_ptr
+
offset7
;
}
void
*
cub_temp_storage_ptr
()
{
return
reinterpret_cast
<
void
*>
(
buffer_ptr
+
offset8
);
}
private:
Tensor
buffer
;
T
*
buffer_ptr
;
int
offset1
;
int
offset2
;
int
offset3
;
int
offset4
;
int
offset5
;
int
offset6
;
int
offset7
;
int
offset8
;
int
offset9
;
};
template
<
typename
DeviceContext
,
typename
T
>
class
ClassCenterSampleCUDAKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
*
label
=
ctx
.
Input
<
Tensor
>
(
"Label"
);
auto
*
remapped_label
=
ctx
.
Output
<
Tensor
>
(
"RemappedLabel"
);
auto
*
sampled_local_class_center
=
ctx
.
Output
<
Tensor
>
(
"SampledLocalClassCenter"
);
int
num_classes
=
ctx
.
Attr
<
int
>
(
"num_classes"
);
int
num_samples
=
ctx
.
Attr
<
int
>
(
"num_samples"
);
int
rid
=
ctx
.
Attr
<
int
>
(
"ring_id"
);
int
nranks
=
ctx
.
Attr
<
int
>
(
"nranks"
);
int
rank
=
ctx
.
Attr
<
int
>
(
"rank"
);
int
seed
=
ctx
.
Attr
<
int
>
(
"seed"
);
bool
fix_seed
=
ctx
.
Attr
<
bool
>
(
"fix_seed"
);
PADDLE_ENFORCE_GT
(
num_classes
,
0
,
platform
::
errors
::
InvalidArgument
(
"The value 'num_classes' for Op(class_center_sample) "
"must be greater than 0, "
"but the value given is %d."
,
num_classes
));
PADDLE_ENFORCE_GT
(
num_samples
,
0
,
platform
::
errors
::
InvalidArgument
(
"The value 'num_samples' for Op(class_center_sample) "
"must be greater than 0, "
"but the value given is %d."
,
num_samples
));
PADDLE_ENFORCE_LE
(
num_samples
,
num_classes
,
platform
::
errors
::
InvalidArgument
(
"The value 'num_samples' for Op(class_center_sample) "
"must be less than or equal to %d, "
"but the value given is %d."
,
num_classes
,
num_samples
));
auto
&
dev_ctx
=
ctx
.
template
device_context
<
DeviceContext
>();
auto
place
=
BOOST_GET_CONST
(
platform
::
CUDAPlace
,
dev_ctx
.
GetPlace
());
int
batch_size
=
label
->
numel
();
// Algorithm:
// We first randomly generate a value in [0, num_classes) on each position
// in a array(shape[num_classes]). Then, we mark the element as negative
// value in the array according input label. Now, we can sort the array
// by ascending to ensure that the positive class center always in the
// front of the sorted array. So, we can get the sampled class center
// index by sorted keys. Finally, we can get the rempped label by remap
// the input label according sampled class center.
// step 1: Calculate num classes per device using nccl all reduce
std
::
vector
<
T
>
shard_dim_vec
(
nranks
+
1
,
0
);
shard_dim_vec
[
rank
+
1
]
=
num_classes
;
Tensor
num_classes_per_device
;
framework
::
TensorFromVector
(
shard_dim_vec
,
ctx
.
cuda_device_context
(),
&
num_classes_per_device
);
T
*
num_classes_per_device_ptr
=
num_classes_per_device
.
data
<
T
>
();
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
if
(
nranks
>
1
)
{
const
auto
&
comm
=
platform
::
NCCLCommContext
::
Instance
().
Get
(
rid
,
ctx
.
GetPlace
());
// use global calculate stream
const
auto
calcu_stream
=
static_cast
<
platform
::
CUDADeviceContext
*>
(
platform
::
DeviceContextPool
::
Instance
().
Get
(
ctx
.
GetPlace
()))
->
stream
();
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
ncclAllReduce
(
num_classes_per_device_ptr
,
num_classes_per_device_ptr
,
num_classes_per_device
.
numel
(),
platform
::
ToNCCLDataType
(
num_classes_per_device
.
type
()),
ncclSum
,
comm
->
comm
(),
calcu_stream
));
}
#endif
// step 2: Determine temporary device storage requirements
int
num_buffer_ele
=
std
::
max
(
batch_size
,
num_classes
);
size_t
cub_sort_temp_store_size
=
0
;
PADDLE_ENFORCE_CUDA_SUCCESS
((
cub
::
DeviceRadixSort
::
SortPairs
<
T
,
T
>
(
nullptr
,
cub_sort_temp_store_size
,
nullptr
,
nullptr
,
nullptr
,
nullptr
,
num_buffer_ele
,
0
,
sizeof
(
T
)
*
8
,
ctx
.
cuda_device_context
().
stream
())));
size_t
cub_sum_temp_store_size
=
0
;
NotEqualToPreviousAdjacentIterator
<
T
>
unique_counting_iter_temp
(
nullptr
,
0
);
PADDLE_ENFORCE_CUDA_SUCCESS
(
(
cub
::
DeviceScan
::
InclusiveSum
<
NotEqualToPreviousAdjacentIterator
<
T
>
,
T
*>
(
nullptr
,
cub_sum_temp_store_size
,
unique_counting_iter_temp
,
nullptr
,
batch_size
,
ctx
.
cuda_device_context
().
stream
())));
size_t
cub_scan_temp_store_size
=
0
;
ActualNumSampledFunctor
<
T
>
actual_num_sampled_op_temp
(
num_samples
);
PADDLE_ENFORCE_CUDA_SUCCESS
((
cub
::
DeviceScan
::
InclusiveScan
(
nullptr
,
cub_scan_temp_store_size
,
num_classes_per_device_ptr
,
num_classes_per_device_ptr
,
actual_num_sampled_op_temp
,
nranks
+
1
,
ctx
.
cuda_device_context
().
stream
())));
size_t
cub_temp_storage_bytes
=
std
::
max
(
std
::
max
(
cub_sort_temp_store_size
,
cub_scan_temp_store_size
),
cub_sum_temp_store_size
);
int
num_temp_ele
=
cub_temp_storage_bytes
/
sizeof
(
T
)
+
1
;
// step 3: Alloc buffer memory so that we can reuse allocated memory
MemoryBuffer
<
T
>
memory_buffer
=
MemoryBuffer
<
T
>
(
num_buffer_ele
,
num_temp_ele
,
nranks
,
ctx
.
GetPlace
());
T
*
cub_sort_keys_ptr
=
memory_buffer
.
cub_sort_keys_ptr
();
T
*
cub_sort_keys_out_ptr
=
memory_buffer
.
cub_sort_keys_out_ptr
();
T
*
cub_sort_values_ptr
=
memory_buffer
.
cub_sort_values_ptr
();
T
*
cub_sort_values_out_ptr
=
memory_buffer
.
cub_sort_values_out_ptr
();
T
*
bound_index_ptr
=
memory_buffer
.
bound_index_ptr
();
T
*
bound_value_ptr
=
memory_buffer
.
bound_value_ptr
();
T
*
class_interval_ptr
=
memory_buffer
.
class_interval_ptr
();
void
*
cub_temp_storage_ptr
=
memory_buffer
.
cub_temp_storage_ptr
();
// step 4: Calculate class interval among nranks
PADDLE_ENFORCE_CUDA_SUCCESS
((
cub
::
DeviceScan
::
InclusiveSum
(
cub_temp_storage_ptr
,
cub_temp_storage_bytes
,
num_classes_per_device_ptr
,
class_interval_ptr
,
nranks
+
1
,
ctx
.
cuda_device_context
().
stream
())));
// step 5: random sample negative class center
int
vec_size
=
VectorizedSize
<
T
>
(
cub_sort_keys_ptr
);
int
increment
=
((
num_classes
-
1
)
/
(
NumBlocks
(
num_classes
)
*
kNumCUDAThreads
*
vec_size
)
+
1
)
*
vec_size
;
if
(
!
fix_seed
)
{
std
::
random_device
rnd
;
seed
=
rnd
();
}
RandomSampleClassCenter
<
T
><<<
NumBlocks
(
num_classes
),
kNumCUDAThreads
,
0
,
ctx
.
cuda_device_context
().
stream
()
>>>
(
num_classes
,
seed
+
rank
,
increment
,
num_classes
,
cub_sort_keys_ptr
);
// step 6: mark positive class center as negative value
// fill the sort values to index 0, 1, ..., batch_size-1
MarkPositiveClassCenter
<<<
NumBlocks
(
batch_size
),
kNumCUDAThreads
,
0
,
ctx
.
cuda_device_context
().
stream
()
>>>
(
batch_size
,
rank
,
class_interval_ptr
,
num_classes
,
label
->
data
<
T
>
(),
cub_sort_keys_ptr
);
Range
<
T
><<<
NumBlocks
(
num_buffer_ele
),
kNumCUDAThreads
,
0
,
ctx
.
cuda_device_context
().
stream
()
>>>
(
num_buffer_ele
,
cub_sort_values_ptr
);
// step 7: sort class center by ascending, so that positive class center
// always be sampled.
PADDLE_ENFORCE_CUDA_SUCCESS
((
cub
::
DeviceRadixSort
::
SortPairs
<
T
,
T
>
(
cub_temp_storage_ptr
,
cub_temp_storage_bytes
,
cub_sort_keys_ptr
,
cub_sort_keys_out_ptr
,
cub_sort_values_ptr
,
cub_sort_values_out_ptr
,
num_classes
,
0
,
sizeof
(
T
)
*
8
,
ctx
.
cuda_device_context
().
stream
())));
// step 8: sort input label ascending
PADDLE_ENFORCE_CUDA_SUCCESS
((
cub
::
DeviceRadixSort
::
SortPairs
<
T
,
T
>
(
cub_temp_storage_ptr
,
cub_temp_storage_bytes
,
label
->
data
<
T
>
(),
cub_sort_keys_out_ptr
,
cub_sort_values_ptr
,
cub_sort_keys_ptr
,
batch_size
,
0
,
sizeof
(
T
)
*
8
,
ctx
.
cuda_device_context
().
stream
())));
// step 9: Calculate new index using InclusiveSum on ascending sorted input
// label
NotEqualToPreviousAdjacentIterator
<
T
>
unique_counting_iter
(
cub_sort_keys_out_ptr
,
0
);
PADDLE_ENFORCE_CUDA_SUCCESS
((
cub
::
DeviceScan
::
InclusiveSum
<
NotEqualToPreviousAdjacentIterator
<
T
>
,
T
*>
(
cub_temp_storage_ptr
,
cub_temp_storage_bytes
,
unique_counting_iter
,
cub_sort_values_ptr
,
batch_size
,
ctx
.
cuda_device_context
().
stream
())));
// step 10: Calculate new class center bound among ranks
GetClassCenterBound
<
T
><<<
NumBlocks
(
batch_size
),
kNumCUDAThreads
,
0
,
ctx
.
cuda_device_context
().
stream
()
>>>
(
batch_size
,
nranks
,
class_interval_ptr
,
cub_sort_keys_out_ptr
,
cub_sort_values_ptr
,
bound_index_ptr
,
bound_value_ptr
);
// step 11: Calculate actual number of sampled class per device.
// Since maybe num_positive_class_center > num_samples,
// we need to ensure all positive class center per device are sampled.
ActualNumSampledFunctor
<
T
>
actual_num_sampled_op
(
num_samples
);
PADDLE_ENFORCE_CUDA_SUCCESS
((
cub
::
DeviceScan
::
InclusiveScan
(
cub_temp_storage_ptr
,
cub_temp_storage_bytes
,
bound_value_ptr
,
num_classes_per_device_ptr
,
actual_num_sampled_op
,
nranks
+
1
,
ctx
.
cuda_device_context
().
stream
())));
// step 12: Calculate actual sampled class interval among nranks
PADDLE_ENFORCE_CUDA_SUCCESS
((
cub
::
DeviceScan
::
InclusiveSum
(
cub_temp_storage_ptr
,
cub_temp_storage_bytes
,
num_classes_per_device_ptr
,
class_interval_ptr
,
nranks
+
1
,
ctx
.
cuda_device_context
().
stream
())));
// step 13: Get remapped label for output
GetRemappedLabel
<
T
><<<
NumBlocks
(
batch_size
),
kNumCUDAThreads
,
0
,
ctx
.
cuda_device_context
().
stream
()
>>>
(
batch_size
,
nranks
,
class_interval_ptr
,
bound_index_ptr
,
bound_value_ptr
,
cub_sort_keys_ptr
,
cub_sort_values_ptr
,
remapped_label
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()));
// step 14: Get sampled class center for output
framework
::
TensorCopySync
(
num_classes_per_device
,
platform
::
CPUPlace
(),
&
num_classes_per_device
);
T
actual_num_samples
=
num_classes_per_device
.
data
<
T
>
()[
rank
+
1
];
T
*
sampled_local_class_center_ptr
=
sampled_local_class_center
->
mutable_data
<
T
>
({
actual_num_samples
},
ctx
.
GetPlace
());
memory
::
Copy
(
place
,
sampled_local_class_center_ptr
,
place
,
cub_sort_values_out_ptr
,
actual_num_samples
*
sizeof
(
T
),
nullptr
);
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OP_CUDA_KERNEL
(
class_center_sample
,
ops
::
ClassCenterSampleCUDAKernel
<
paddle
::
platform
::
CUDADeviceContext
,
int64_t
>
,
ops
::
ClassCenterSampleCUDAKernel
<
paddle
::
platform
::
CUDADeviceContext
,
int
>
);
paddle/fluid/operators/class_center_sample_op.h
0 → 100644
浏览文件 @
100db44f
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <map>
#include <set>
#include <vector>
#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/framework/op_registry.h"
namespace
paddle
{
namespace
operators
{
using
Tensor
=
framework
::
Tensor
;
template
<
typename
T
>
class
ClassCenterSampleCPUKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
*
label
=
ctx
.
Input
<
Tensor
>
(
"Label"
);
auto
*
remapped_label
=
ctx
.
Output
<
Tensor
>
(
"RemappedLabel"
);
auto
*
sampled_local_class_center
=
ctx
.
Output
<
Tensor
>
(
"SampledLocalClassCenter"
);
int
num_classes
=
ctx
.
Attr
<
int
>
(
"num_classes"
);
int
num_samples
=
ctx
.
Attr
<
int
>
(
"num_samples"
);
int
seed
=
ctx
.
Attr
<
int
>
(
"seed"
);
bool
fix_seed
=
ctx
.
Attr
<
bool
>
(
"fix_seed"
);
PADDLE_ENFORCE_GT
(
num_classes
,
0
,
platform
::
errors
::
InvalidArgument
(
"The value 'num_classes' for Op(class_center_sample) "
"must be greater than 0, "
"but the value given is %d."
,
num_classes
));
PADDLE_ENFORCE_GT
(
num_samples
,
0
,
platform
::
errors
::
InvalidArgument
(
"The value 'num_samples' for Op(class_center_sample) "
"must be greater than 0, "
"but the value given is %d."
,
num_samples
));
PADDLE_ENFORCE_LE
(
num_samples
,
num_classes
,
platform
::
errors
::
InvalidArgument
(
"The value 'num_samples' for Op(class_center_sample) "
"must be less than or equal to %d, "
"but the value given is %d."
,
num_classes
,
num_samples
));
int64_t
numel
=
label
->
numel
();
auto
*
label_ptr
=
label
->
data
<
T
>
();
// get unique positive class center by ascending
std
::
set
<
T
,
std
::
less
<
T
>>
unique_label
;
for
(
int64_t
i
=
0
;
i
<
numel
;
++
i
)
{
unique_label
.
insert
(
label_ptr
[
i
]);
}
// constrcut a lookup table and get sampled_local_class_center
std
::
vector
<
T
>
actual_sampled
;
std
::
map
<
T
,
T
>
new_class_dict
;
T
idx
=
0
;
for
(
auto
&
t
:
unique_label
)
{
new_class_dict
[
t
]
=
idx
;
actual_sampled
.
push_back
(
t
);
idx
++
;
}
if
(
!
fix_seed
)
{
std
::
random_device
rnd
;
seed
=
rnd
();
}
std
::
uniform_int_distribution
<
T
>
dist
(
0
,
num_classes
-
1
);
auto
engine
=
framework
::
GetCPURandomEngine
(
seed
);
// sample negative class center randomly
while
(
unique_label
.
size
()
<
static_cast
<
size_t
>
(
num_samples
))
{
T
neg
=
dist
(
*
engine
);
if
(
unique_label
.
find
(
neg
)
==
unique_label
.
end
())
{
unique_label
.
insert
(
neg
);
// unorder for negative class center
actual_sampled
.
push_back
(
neg
);
}
}
int
actual_num_samples
=
unique_label
.
size
();
T
*
sampled_local_class_center_ptr
=
sampled_local_class_center
->
mutable_data
<
T
>
({
actual_num_samples
},
ctx
.
GetPlace
());
idx
=
0
;
for
(
auto
&
t
:
actual_sampled
)
{
sampled_local_class_center_ptr
[
idx
]
=
t
;
idx
++
;
}
// remap the input label to sampled class
auto
*
remmaped_label_ptr
=
remapped_label
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
for
(
int64_t
i
=
0
;
i
<
numel
;
++
i
)
{
remmaped_label_ptr
[
i
]
=
new_class_dict
[
label_ptr
[
i
]];
}
}
};
}
// namespace operators
}
// namespace paddle
python/paddle/fluid/tests/unittests/CMakeLists.txt
浏览文件 @
100db44f
file
(
GLOB TEST_OPS RELATIVE
"
${
CMAKE_CURRENT_SOURCE_DIR
}
"
"test_*.py"
)
file
(
GLOB TEST_OPS RELATIVE
"
${
CMAKE_CURRENT_SOURCE_DIR
}
"
"test_*.py"
)
string
(
REPLACE
".py"
""
TEST_OPS
"
${
TEST_OPS
}
"
)
set
(
GC_ENVS FLAGS_eager_delete_tensor_gb=0.0 FLAGS_fast_eager_deletion_mode=1 FLAGS_memory_fraction_of_eager_deletion=1.0
)
set
(
dist_ENVS http_proxy=
""
https_proxy=
""
)
...
...
@@ -28,6 +29,7 @@ list(APPEND DIST_TEST_OPS test_parallel_dygraph_pipeline_parallel)
list
(
APPEND DIST_TEST_OPS test_parallel_dygraph_tensor_parallel
)
list
(
APPEND DIST_TEST_OPS test_parallel_dygraph_sharding_parallel
)
list
(
APPEND DIST_TEST_OPS test_parallel_dygraph_mp_layers
)
list
(
APPEND DIST_TEST_OPS test_parallel_class_center_sample
)
list
(
APPEND DIST_TEST_OPS test_parallel_margin_cross_entropy
)
set
(
MIXED_DIST_TEST_OPS
${
DIST_TEST_OPS
}
)
#remove distribute unittests.
...
...
@@ -196,6 +198,7 @@ if ((NOT WITH_GPU) AND (NOT WITH_ROCM))
LIST
(
REMOVE_ITEM TEST_OPS test_mixed_precision
)
LIST
(
REMOVE_ITEM TEST_OPS test_fleet_base_single
)
LIST
(
REMOVE_ITEM TEST_OPS test_dygraph_recompute
)
list
(
REMOVE_ITEM TEST_OPS test_parallel_class_center_sample
)
LIST
(
REMOVE_ITEM TEST_OPS test_parallel_margin_cross_entropy
)
elseif
(
WITH_GPU
)
if
(
${
CUDNN_VERSION
}
VERSION_LESS 7100
)
...
...
@@ -908,6 +911,7 @@ if(WITH_DISTRIBUTE AND WITH_GPU AND WITH_NCCL)
set_tests_properties
(
test_parallel_dygraph_tensor_parallel PROPERTIES TIMEOUT 200
)
set_tests_properties
(
test_parallel_dygraph_sharding_parallel PROPERTIES TIMEOUT 120
)
set_tests_properties
(
test_parallel_dygraph_mp_layers PROPERTIES TIMEOUT 120
)
set_tests_properties
(
test_parallel_class_center_sample PROPERTIES TIMEOUT 120
)
set_tests_properties
(
test_parallel_margin_cross_entropy PROPERTIES TIMEOUT 120
)
if
(
${
NCCL_VERSION
}
VERSION_GREATER_EQUAL 2212
)
set_tests_properties
(
test_parallel_dygraph_sparse_embedding PROPERTIES TIMEOUT 120
)
...
...
python/paddle/fluid/tests/unittests/parallel_class_center_sample.py
0 → 100644
浏览文件 @
100db44f
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
division
from
__future__
import
print_function
import
unittest
import
paddle
import
numpy
as
np
import
random
import
paddle.distributed
as
dist
import
paddle.fluid
as
fluid
import
paddle.distributed.fleet
as
fleet
from
paddle
import
framework
def
set_random_seed
(
seed
):
"""Set random seed for reproducability."""
random
.
seed
(
seed
)
np
.
random
.
seed
(
seed
)
paddle
.
seed
(
seed
)
fleet
.
meta_parallel
.
model_parallel_random_seed
(
seed
)
def
class_center_sample_numpy
(
label
,
classes_list
,
num_samples
):
unique_label
=
np
.
unique
(
label
)
nranks
=
len
(
classes_list
)
class_interval
=
np
.
cumsum
(
np
.
insert
(
classes_list
,
0
,
0
))
pos_class_center_per_device
=
[]
unique_label_per_device
=
[]
for
i
in
range
(
nranks
):
index
=
np
.
logical_and
(
unique_label
>=
class_interval
[
i
],
unique_label
<
class_interval
[
i
+
1
])
pos_class_center_per_device
.
append
(
unique_label
[
index
]
-
class_interval
[
i
])
unique_label_per_device
.
append
(
unique_label
[
index
])
num_samples_per_device
=
[]
for
pos_class_center
in
pos_class_center_per_device
:
num_samples_per_device
.
append
(
max
(
len
(
pos_class_center
),
num_samples
))
sampled_class_interval
=
np
.
cumsum
(
np
.
insert
(
num_samples_per_device
,
0
,
0
))
remapped_dict
=
{}
for
i
in
range
(
nranks
):
for
idx
,
v
in
enumerate
(
unique_label_per_device
[
i
],
sampled_class_interval
[
i
]):
remapped_dict
[
v
]
=
idx
remapped_label
=
[]
for
l
in
label
:
remapped_label
.
append
(
remapped_dict
[
l
])
return
remapped_label
,
pos_class_center_per_device
class
TestParallelClassCenterSampleOp
(
unittest
.
TestCase
):
def
setUp
(
self
):
strategy
=
fleet
.
DistributedStrategy
()
fleet
.
init
(
is_collective
=
True
,
strategy
=
strategy
)
def
test_class_center_sample
(
self
):
rank_id
=
dist
.
get_rank
()
nranks
=
dist
.
get_world_size
()
seed
=
1025
set_random_seed
(
seed
)
paddle
.
seed
(
rank_id
*
10
)
random
.
seed
(
seed
)
np
.
random
.
seed
(
seed
)
batch_size
=
20
num_samples
=
6
for
dtype
in
(
'int32'
,
'int64'
):
for
_
in
range
(
5
):
classes_list
=
np
.
random
.
randint
(
10
,
15
,
(
nranks
,
))
num_class
=
np
.
sum
(
classes_list
)
np_label
=
np
.
random
.
randint
(
0
,
num_class
,
(
batch_size
,
),
dtype
=
dtype
)
label
=
paddle
.
to_tensor
(
np_label
,
dtype
=
dtype
)
np_remapped_label
,
np_sampled_class_center_per_device
=
class_center_sample_numpy
(
np_label
,
classes_list
,
num_samples
)
remapped_label
,
sampled_class_index
=
paddle
.
nn
.
functional
.
class_center_sample
(
label
,
classes_list
[
rank_id
],
num_samples
)
np
.
testing
.
assert_allclose
(
remapped_label
.
numpy
(),
np_remapped_label
)
np_sampled_class_index
=
np_sampled_class_center_per_device
[
rank_id
]
np
.
testing
.
assert_allclose
(
sampled_class_index
.
numpy
()[:
len
(
np_sampled_class_index
)],
np_sampled_class_index
)
if
__name__
==
'__main__'
:
unittest
.
main
()
python/paddle/fluid/tests/unittests/test_class_center_sample_op.py
0 → 100644
浏览文件 @
100db44f
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
import
unittest
import
numpy
as
np
import
math
import
random
import
paddle
import
paddle.fluid.core
as
core
from
op_test
import
OpTest
from
paddle.fluid
import
Program
,
program_guard
def
class_center_sample_numpy
(
label
,
classes_list
,
num_samples
):
unique_label
=
np
.
unique
(
label
)
nranks
=
len
(
classes_list
)
class_interval
=
np
.
cumsum
(
np
.
insert
(
classes_list
,
0
,
0
))
pos_class_center_per_device
=
[]
unique_label_per_device
=
[]
for
i
in
range
(
nranks
):
index
=
np
.
logical_and
(
unique_label
>=
class_interval
[
i
],
unique_label
<
class_interval
[
i
+
1
])
pos_class_center_per_device
.
append
(
unique_label
[
index
]
-
class_interval
[
i
])
unique_label_per_device
.
append
(
unique_label
[
index
])
num_samples_per_device
=
[]
for
pos_class_center
in
pos_class_center_per_device
:
num_samples_per_device
.
append
(
max
(
len
(
pos_class_center
),
num_samples
))
sampled_class_interval
=
np
.
cumsum
(
np
.
insert
(
num_samples_per_device
,
0
,
0
))
remapped_dict
=
{}
for
i
in
range
(
nranks
):
for
idx
,
v
in
enumerate
(
unique_label_per_device
[
i
],
sampled_class_interval
[
i
]):
remapped_dict
[
v
]
=
idx
remapped_label
=
[]
for
l
in
label
:
remapped_label
.
append
(
remapped_dict
[
l
])
return
np
.
array
(
remapped_label
),
np
.
array
(
pos_class_center_per_device
)
class
TestClassCenterSampleOp
(
OpTest
):
def
initParams
(
self
):
self
.
op_type
=
"class_center_sample"
self
.
batch_size
=
20
self
.
num_samples
=
6
self
.
num_classes
=
10
self
.
seed
=
2021
def
init_dtype
(
self
):
self
.
dtype
=
np
.
int64
def
init_fix_seed
(
self
):
self
.
fix_seed
=
True
def
setUp
(
self
):
self
.
initParams
()
self
.
init_dtype
()
self
.
init_fix_seed
()
label
=
np
.
random
.
randint
(
0
,
self
.
num_classes
,
(
self
.
batch_size
,
),
dtype
=
self
.
dtype
)
remapped_label
,
sampled_class_center
=
class_center_sample_numpy
(
label
,
[
self
.
num_classes
],
self
.
num_samples
)
self
.
inputs
=
{
'Label'
:
label
}
self
.
outputs
=
{
'RemappedLabel'
:
remapped_label
.
astype
(
self
.
dtype
),
'SampledLocalClassCenter'
:
sampled_class_center
.
astype
(
self
.
dtype
)
}
self
.
attrs
=
{
'num_classes'
:
self
.
num_classes
,
'num_samples'
:
self
.
num_samples
,
'seed'
:
self
.
seed
,
'fix_seed'
:
self
.
fix_seed
,
}
def
test_check_output
(
self
):
self
.
check_output
(
no_check_set
=
[
'SampledLocalClassCenter'
])
class
TestClassCenterSampleOpINT32
(
TestClassCenterSampleOp
):
def
init_dtype
(
self
):
self
.
dtype
=
np
.
int32
class
TestClassCenterSampleOpFixSeed
(
TestClassCenterSampleOp
):
def
init_fix_seed
(
self
):
self
.
fix_seed
=
True
class
TestClassCenterSampleV2
(
unittest
.
TestCase
):
def
setUp
(
self
):
self
.
initParams
()
np
.
random
.
seed
(
self
.
seed
)
paddle
.
framework
.
random
.
_manual_program_seed
(
2021
)
self
.
places
=
[
paddle
.
fluid
.
CPUPlace
()]
if
core
.
is_compiled_with_cuda
():
self
.
places
.
append
(
paddle
.
fluid
.
CUDAPlace
(
0
))
def
initParams
(
self
):
self
.
batch_size
=
10
self
.
num_samples
=
6
self
.
num_classes
=
20
self
.
seed
=
0
self
.
init_dtype
()
def
init_dtype
(
self
):
self
.
dtype
=
np
.
int64
def
test_static
(
self
):
for
place
in
self
.
places
:
self
.
check_static_result
(
place
=
place
)
def
check_static_result
(
self
,
place
):
with
program_guard
(
Program
(),
Program
()):
label_np
=
np
.
random
.
randint
(
0
,
self
.
num_classes
,
(
self
.
batch_size
,
),
dtype
=
self
.
dtype
)
label
=
paddle
.
static
.
data
(
name
=
'label'
,
shape
=
[
self
.
batch_size
],
dtype
=
self
.
dtype
)
remapped_label
,
sampled_class_index
=
paddle
.
nn
.
functional
.
class_center_sample
(
label
,
self
.
num_classes
,
self
.
num_samples
,
seed
=
self
.
seed
)
remapped_label_np
,
sampled_class_center_np
=
class_center_sample_numpy
(
label_np
,
[
self
.
num_classes
],
self
.
num_samples
)
exe
=
paddle
.
fluid
.
Executor
(
place
)
[
remapped_label_res
,
sampled_class_index_res
]
=
exe
.
run
(
paddle
.
fluid
.
default_main_program
(),
feed
=
{
'label'
:
label_np
},
fetch_list
=
[
remapped_label
,
sampled_class_index
])
np
.
testing
.
assert_allclose
(
remapped_label_res
,
remapped_label_np
)
np
.
testing
.
assert_allclose
(
sampled_class_index_res
[:
len
(
sampled_class_center_np
[
0
])],
sampled_class_center_np
[
0
])
def
test_dynamic
(
self
):
for
place
in
self
.
places
:
self
.
check_dynamic_result
(
place
=
place
)
def
check_dynamic_result
(
self
,
place
):
with
paddle
.
fluid
.
dygraph
.
guard
(
place
):
label_np
=
np
.
random
.
randint
(
0
,
self
.
num_classes
,
(
self
.
batch_size
,
),
dtype
=
self
.
dtype
)
label
=
paddle
.
to_tensor
(
label_np
,
dtype
=
self
.
dtype
)
remapped_label
,
sampled_class_index
=
paddle
.
nn
.
functional
.
class_center_sample
(
label
,
self
.
num_classes
,
self
.
num_samples
,
seed
=
self
.
seed
)
remapped_label_np
,
sampled_class_center_np
=
class_center_sample_numpy
(
label_np
,
[
self
.
num_classes
],
self
.
num_samples
)
remapped_label_res
=
remapped_label
.
numpy
()
sampled_class_index_res
=
sampled_class_index
.
numpy
()
np
.
testing
.
assert_allclose
(
remapped_label_res
,
remapped_label_np
)
np
.
testing
.
assert_allclose
(
sampled_class_index_res
[:
len
(
sampled_class_center_np
[
0
])],
sampled_class_center_np
[
0
])
class
TestClassCenterSampleV2INT32
(
TestClassCenterSampleV2
):
def
init_dtype
(
self
):
self
.
dtype
=
np
.
int32
class
TestClassCenterSampleAPIError
(
unittest
.
TestCase
):
def
setUp
(
self
):
self
.
initParams
()
np
.
random
.
seed
(
self
.
seed
)
self
.
places
=
[
paddle
.
fluid
.
CPUPlace
()]
if
core
.
is_compiled_with_cuda
():
self
.
places
.
append
(
paddle
.
fluid
.
CUDAPlace
(
0
))
def
initParams
(
self
):
self
.
batch_size
=
20
self
.
num_samples
=
15
self
.
num_classes
=
10
self
.
seed
=
2021
self
.
init_dtype
()
def
init_dtype
(
self
):
self
.
dtype
=
np
.
int64
def
test_dynamic_errors
(
self
):
def
test_num_samples
():
for
place
in
self
.
places
:
with
paddle
.
fluid
.
dygraph
.
guard
(
place
):
label_np
=
np
.
random
.
randint
(
0
,
self
.
num_classes
,
(
self
.
batch_size
,
),
dtype
=
self
.
dtype
)
label
=
paddle
.
to_tensor
(
label_np
)
remapped_label
,
sampled_class_index
=
paddle
.
nn
.
functional
.
class_center_sample
(
label
,
self
.
num_classes
,
self
.
num_samples
,
seed
=
self
.
seed
)
self
.
assertRaises
(
ValueError
,
test_num_samples
)
if
__name__
==
'__main__'
:
unittest
.
main
()
python/paddle/fluid/tests/unittests/test_parallel_class_center_sample.py
0 → 100644
浏览文件 @
100db44f
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
import
unittest
import
paddle.fluid
as
fluid
from
test_parallel_dygraph_dataparallel
import
TestMultipleGpus
class
TestParallelClassCenterSample
(
TestMultipleGpus
):
def
test_parallel_class_center_sample
(
self
):
self
.
run_mnist_2gpu
(
'parallel_class_center_sample.py'
)
if
__name__
==
"__main__"
:
unittest
.
main
()
python/paddle/fluid/tests/unittests/white_list/no_check_set_white_list.py
浏览文件 @
100db44f
...
...
@@ -31,4 +31,5 @@ no_check_set_white_list = [
'rnn'
,
'fusion_lstm'
,
'softmax_with_cross_entropy'
,
'class_center_sample'
,
]
python/paddle/nn/functional/__init__.py
浏览文件 @
100db44f
...
...
@@ -55,6 +55,7 @@ from .common import unfold # noqa: F401
from
.common
import
interpolate
# noqa: F401
from
.common
import
upsample
# noqa: F401
from
.common
import
bilinear
# noqa: F401
from
.common
import
class_center_sample
# noqa: F401
from
.conv
import
conv1d
# noqa: F401
from
.conv
import
conv1d_transpose
# noqa: F401
from
.common
import
linear
# noqa: F401
...
...
@@ -200,5 +201,6 @@ __all__ = [ #noqa
'temporal_shift'
,
'batch_norm'
,
'layer_norm'
,
'instance_norm'
'instance_norm'
,
'class_center_sample'
,
]
python/paddle/nn/functional/common.py
浏览文件 @
100db44f
...
...
@@ -1564,3 +1564,156 @@ def label_smooth(label, prior_dist=None, epsilon=0.1, name=None):
outputs
=
{
"Out"
:
smooth_label
},
attrs
=
{
"epsilon"
:
float
(
epsilon
)})
return
smooth_label
def
class_center_sample
(
label
,
num_classes
,
num_samples
,
group
=
None
,
seed
=
None
):
"""
Class center sample method is proposed from the paper PartialFC that only sample a subset of the class centers.
The process of sampling subset class centers is straightforward:
1. First select the positive class centers;
2. Then randomly sample negative class centers.
Specifically, given a label tensor, shape [batch_size], select all the positive class centers and randomly
sample negative class centers, then remap the input label tensor using the sampled class centers.
For more information, Partial FC: Training 10 Million Identities on a Single Machine
arxiv: https://arxiv.org/abs/2010.05222
.. hint::
If the number of the positive class centers is greater than the input num_samples, it keeps all the positive
class centers and the shape of sampled_class_center will be [num_positive_class_centers].
The API supports CPU, single GPU and multi GPU.
Args:
label (Tensor): 1-D tensor with shape [N], each label in [0, num_classes)
num_classes (int): A positive integer to specify the number of classes at local rank.
Note that num_classes of each GPU can be different.
num_samples (int): A positive integer to specify the number of class center to sample.
group (Group, optional): The abstract representation of group.
See paddle.distributed.collective.Group. Default is ``None``.
seed (int, optional): Random seed. Default is ``None``.
Returns:
Tuple of two ``Tensor`` : (remapped_label, sampled_class_center), remapped label using sampled class center,
sampled class center from [0, num_classes).
Examples:
.. code-block:: python
# CPU or single GPU
import paddle
num_classes = 20
batch_size = 10
num_samples = 6
label = paddle.randint(low=0, high=num_classes, shape=[batch_size], dtype='int64')
remapped_label, sampled_class_index = paddle.nn.functional.class_center_sample(label, num_classes, num_samples)
print(label)
print(remapped_label)
print(sampled_class_index)
# the output is
#Tensor(shape=[10], dtype=int64, place=CPUPlace, stop_gradient=True,
# [11, 5 , 1 , 3 , 12, 2 , 15, 19, 18, 19])
#Tensor(shape=[10], dtype=int64, place=CPUPlace, stop_gradient=True,
# [4, 3, 0, 2, 5, 1, 6, 8, 7, 8])
#Tensor(shape=[9], dtype=int64, place=CPUPlace, stop_gradient=True,
# [1 , 2 , 3 , 5 , 11, 12, 15, 18, 19])
.. code-block:: python
# required: distributed
# Multi GPU, test_class_center_sample.py
import paddle
import paddle.distributed as dist
strategy = dist.fleet.DistributedStrategy()
dist.fleet.init(is_collective=True, strategy=strategy)
batch_size = 10
num_samples = 6
rank_id = dist.get_rank()
# num_classes of each GPU can be different, e.g num_classes_list = [10, 8]
num_classes_list = [10, 10]
num_classes = paddle.sum(paddle.to_tensor(num_classes_list))
label = paddle.randint(low=0, high=num_classes.item(), shape=[batch_size], dtype='int64')
label_list = []
dist.all_gather(label_list, label)
label = paddle.concat(label_list, axis=0)
remapped_label, sampled_class_index = paddle.nn.functional.class_center_sample(label, num_classes_list[rank_id], num_samples)
print(label)
print(remapped_label)
print(sampled_class_index)
#python -m paddle.distributed.launch --gpus=0,1 test_class_center_sample.py
# rank 0 output:
#Tensor(shape=[20], dtype=int64, place=CUDAPlace(0), stop_gradient=True,
# [10, 17, 15, 11, 9 , 12, 18, 18, 17, 18, 19, 2 , 8 , 13, 11, 13, 9 , 10, 0 , 4 ])
#Tensor(shape=[20], dtype=int64, place=CUDAPlace(0), stop_gradient=True,
# [6 , 11, 10, 7 , 4 , 8 , 12, 12, 11, 12, 13, 1 , 3 , 9 , 7 , 9 , 4 , 6 , 0 , 2 ])
#Tensor(shape=[6], dtype=int64, place=CUDAPlace(0), stop_gradient=True,
# [0, 2, 4, 8, 9, 3])
# rank 1 output:
#Tensor(shape=[20], dtype=int64, place=CUDAPlace(1), stop_gradient=True,
# [10, 17, 15, 11, 9 , 12, 18, 18, 17, 18, 19, 2 , 8 , 13, 11, 13, 9 , 10, 0 , 4 ])
#Tensor(shape=[20], dtype=int64, place=CUDAPlace(1), stop_gradient=True,
# [6 , 11, 10, 7 , 4 , 8 , 12, 12, 11, 12, 13, 1 , 3 , 9 , 7 , 9 , 4 , 6 , 0 , 2 ])
#Tensor(shape=[7], dtype=int64, place=CUDAPlace(1), stop_gradient=True,
# [0, 1, 2, 3, 5, 7, 8])
"""
if
group
is
not
None
and
not
group
.
is_member
():
return
ring_id
=
0
if
group
is
None
else
group
.
id
rank
=
0
nranks
=
1
if
core
.
is_compiled_with_dist
():
parallel_env
=
paddle
.
distributed
.
ParallelEnv
()
global_rank
=
parallel_env
.
rank
rank
=
global_rank
if
group
is
None
else
group
.
get_group_rank
(
global_rank
)
nranks
=
parallel_env
.
world_size
if
group
is
None
else
group
.
nranks
if
num_samples
>
num_classes
:
raise
ValueError
(
'Expected num_samples less than or equal to {}, got num_samples {}'
.
format
(
num_classes
,
num_samples
))
if
(
seed
is
None
or
seed
==
0
)
and
default_main_program
().
random_seed
!=
0
:
seed
=
default_main_program
().
random_seed
if
in_dygraph_mode
():
remapped_label
,
sampled_class_center
=
core
.
ops
.
class_center_sample
(
label
,
'num_classes'
,
num_classes
,
'num_samples'
,
num_samples
,
'ring_id'
,
ring_id
,
'nranks'
,
nranks
,
'rank'
,
rank
,
'fix_seed'
,
seed
is
not
None
,
'seed'
,
seed
if
seed
is
not
None
else
0
)
return
remapped_label
,
sampled_class_center
check_variable_and_dtype
(
label
,
'label'
,
[
'int64'
,
'int32'
],
'class_center_sample'
)
op_type
=
'class_center_sample'
helper
=
LayerHelper
(
op_type
,
**
locals
())
remapped_label
=
helper
.
create_variable_for_type_inference
(
dtype
=
label
.
dtype
)
sampled_class_center
=
helper
.
create_variable_for_type_inference
(
dtype
=
label
.
dtype
)
helper
.
append_op
(
type
=
op_type
,
inputs
=
{
'Label'
:
label
},
outputs
=
{
'RemappedLabel'
:
remapped_label
,
'SampledLocalClassCenter'
:
sampled_class_center
},
attrs
=
{
'num_classes'
:
num_classes
,
'num_samples'
:
num_samples
,
'ring_id'
:
ring_id
,
'nranks'
:
nranks
,
'rank'
:
rank
,
'fix_seed'
:
seed
is
not
None
,
'seed'
:
seed
if
seed
is
not
None
else
0
})
return
remapped_label
,
sampled_class_center
tools/static_mode_white_list.py
浏览文件 @
100db44f
...
...
@@ -719,5 +719,6 @@ STATIC_MODE_TESTING_LIST = [
'test_sgd_op_bf16'
,
'test_marker_op'
,
'test_c_embedding_op'
,
'test_class_center_sample_op'
,
'test_margin_cross_entropy_op'
,
]
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录