Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleDetection
提交
58ad40cc
P
PaddleDetection
项目概览
PaddlePaddle
/
PaddleDetection
大约 1 年 前同步成功
通知
695
Star
11112
Fork
2696
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
184
列表
看板
标记
里程碑
合并请求
40
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
184
Issue
184
列表
看板
标记
里程碑
合并请求
40
合并请求
40
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
58ad40cc
编写于
1月 30, 2019
作者:
X
xuezhong
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add sample_logits op
上级
b5ebca47
变更
14
展开全部
隐藏空白更改
内联
并排
Showing
14 changed file
with
2540 addition
and
2 deletion
+2540
-2
paddle/fluid/operators/CMakeLists.txt
paddle/fluid/operators/CMakeLists.txt
+1
-1
paddle/fluid/operators/math/CMakeLists.txt
paddle/fluid/operators/math/CMakeLists.txt
+1
-0
paddle/fluid/operators/math/sample_prob.cc
paddle/fluid/operators/math/sample_prob.cc
+26
-0
paddle/fluid/operators/math/sample_prob.cu
paddle/fluid/operators/math/sample_prob.cu
+188
-0
paddle/fluid/operators/math/sample_prob.h
paddle/fluid/operators/math/sample_prob.h
+118
-0
paddle/fluid/operators/sample_logits_op.cc
paddle/fluid/operators/sample_logits_op.cc
+248
-0
paddle/fluid/operators/sample_logits_op.cu
paddle/fluid/operators/sample_logits_op.cu
+321
-0
paddle/fluid/operators/sample_logits_op.h
paddle/fluid/operators/sample_logits_op.h
+275
-0
python/paddle/fluid/__init__.py
python/paddle/fluid/__init__.py
+1
-1
python/paddle/fluid/layers/nn.py
python/paddle/fluid/layers/nn.py
+99
-0
python/paddle/fluid/tests/unittests/op_test.py
python/paddle/fluid/tests/unittests/op_test.py
+1
-0
python/paddle/fluid/tests/unittests/test_layers.py
python/paddle/fluid/tests/unittests/test_layers.py
+10
-0
python/paddle/fluid/tests/unittests/test_sample_logits.py
python/paddle/fluid/tests/unittests/test_sample_logits.py
+1233
-0
python/paddle/fluid/tests/unittests/testsuite.py
python/paddle/fluid/tests/unittests/testsuite.py
+18
-0
未找到文件。
paddle/fluid/operators/CMakeLists.txt
浏览文件 @
58ad40cc
...
...
@@ -66,7 +66,7 @@ set(COMMON_OP_DEPS ${OP_HEADER_DEPS})
set
(
COMMON_OP_DEPS
${
COMMON_OP_DEPS
}
selected_rows_functor selected_rows lod_tensor maxouting unpooling pooling lod_rank_table context_project sequence_pooling executor
)
set
(
COMMON_OP_DEPS
${
COMMON_OP_DEPS
}
dynload_warpctc
)
set
(
COMMON_OP_DEPS
${
COMMON_OP_DEPS
}
sequence_padding sequence_scale cos_sim_functor memory jit_kernel_helper concat_and_split cross_entropy softmax vol2col im2col sampler tree2col
)
set
(
COMMON_OP_DEPS
${
COMMON_OP_DEPS
}
sequence_padding sequence_scale cos_sim_functor memory jit_kernel_helper concat_and_split cross_entropy softmax vol2col im2col sampler
sample_prob
tree2col
)
set
(
COMMON_OP_DEPS
${
COMMON_OP_DEPS
}
sequence2batch lstm_compute matrix_bit_code gru_compute activation_functions beam_search
)
if
(
WITH_GPU
)
set
(
COMMON_OP_DEPS
${
COMMON_OP_DEPS
}
depthwise_conv prelu
)
...
...
paddle/fluid/operators/math/CMakeLists.txt
浏览文件 @
58ad40cc
...
...
@@ -39,6 +39,7 @@ math_library(cross_entropy)
math_library
(
cos_sim_functor
)
math_library
(
depthwise_conv
)
math_library
(
im2col
)
math_library
(
sample_prob
)
math_library
(
sampler
)
math_library
(
gru_compute DEPS activation_functions math_function
)
...
...
paddle/fluid/operators/math/sample_prob.cc
0 → 100644
浏览文件 @
58ad40cc
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/math/sample_prob.h"
namespace paddle {
namespace operators {
namespace math {

// Explicit instantiations of the SampleWithProb functor for the CPU device
// context, for float and double element types.
template class SampleWithProb<platform::CPUDeviceContext, float>;
template class SampleWithProb<platform::CPUDeviceContext, double>;

}  // namespace math
}  // namespace operators
}  // namespace paddle
paddle/fluid/operators/math/sample_prob.cu
0 → 100644
浏览文件 @
58ad40cc
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <thrust/random.h>
#include <thrust/sort.h>
#include <iostream>
#include <vector>
#include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/sample_prob.h"
#include "paddle/fluid/operators/math/sampler.h"
namespace
paddle
{
namespace
operators
{
namespace
math
{
using
Tensor
=
framework
::
Tensor
;
// Device-side probability adjustment for sampling with rejection of
// duplicates.
//
// prob        : probability of the class under a single draw.
// num_samples : number of samples that were kept.
// num_tries   : number of draws that were needed to obtain them.
//
// When every draw was kept (num_samples == num_tries) the result is simply
// prob * num_samples. Otherwise it is 1 - (1 - prob)^num_tries, evaluated
// through expm1/log1p.
template <typename T>
__device__ T gpu_adjust_prob(const T prob, const int num_samples,
                             const int num_tries) {
  const bool some_draws_rejected = (num_samples != num_tries);
  if (some_draws_rejected) {
    return -expm1(num_tries * log1p(-prob));
  }
  return prob * num_samples;
}
// Stateless device-side helper for the log-uniform distribution. The range
// and its logarithm are passed to every call instead of being stored.
class GPULogUniformSampler {
 public:
  // Turns a uniform variate `random` into a class id (see definition below).
  __device__ int64_t Sample(float random, const int range,
                            const float log_range) const;

  // Probability mass assigned to class id `value`.
  __device__ float Probability(int64_t value, const float log_range) const;
};

__device__ int64_t GPULogUniformSampler::Sample(float random, const int range,
                                                const float log_range) const {
  // Inverse-transform sampling: a uniform variate is mapped onto the
  // log-uniform distribution.
  // NOTE(review): presumably `random` is expected in [0, 1) - confirm with
  // the caller.
  const float exponent = random * log_range;
  const int64_t drawn = static_cast<int64_t>(exp(exponent)) - 1;
  // Floating-point round-off may push the value past the range, so it is
  // wrapped back with a modulo.
  return drawn % range;
}

__device__ float GPULogUniformSampler::Probability(
    int64_t value, const float log_range) const {
  // With density f(x) = 1 / ((x + 1) * log_range), the mass of `value` is the
  // integral of f over [value, value + 1]:
  //   log((value + 2) / (value + 1)) / log_range.
  const double ratio = (value + 2.0) / (value + 1.0);
  return log(ratio) / log_range;
}
// Kernel that completes the [rows x (num_true + num_samples)] sample and
// probability matrices.
//
// n                  : total number of elements to process.
// num_tries          : number of draws the host needed to collect the samples.
// range              : unused by this kernel.
// log_range          : normaliser passed to GPULogUniformSampler::Probability.
// num_true           : number of true-label columns per row.
// num_samples        : number of sampled columns per row.
// label_data         : [rows x num_true] true labels.
// samples_data       : output sample ids. The sampled columns of row 0
//                      (offsets num_true .. num_true + num_samples - 1) are
//                      read as the source for every row, so they must be
//                      filled before the launch.
// probabilities_data : output adjusted probabilities, same layout as samples.
template <typename T>
__global__ void SamplingCondidate(
    const size_t n, const int num_tries, const int range,
    const float log_range, const int num_true, const std::size_t num_samples,
    const int64_t* label_data, int64_t* samples_data, T* probabilities_data) {
  const int num_sampled_classes = num_true + num_samples;

  int idx = blockDim.x * blockIdx.x + threadIdx.x;
  // Distance between two consecutive elements handled by the same thread.
  const int step_size = blockDim.x * gridDim.x;
  GPULogUniformSampler sampler;

  for (; idx < n; idx += step_size) {
    const int col_idx = idx % num_sampled_classes;
    const int row_idx = idx / num_sampled_classes;

    if (col_idx < num_true) {
      // Leading columns hold the row's own true labels.
      samples_data[idx] = label_data[row_idx * num_true + col_idx];
    } else {
      // Remaining columns are shared by all rows: replicate row 0's value.
      // For row 0 itself this is a self-assignment.
      samples_data[idx] = samples_data[col_idx];
    }

    probabilities_data[idx] =
        sampler.Probability(samples_data[idx], log_range);
    probabilities_data[idx] =
        gpu_adjust_prob(probabilities_data[idx], num_samples, num_tries);
  }
}
// Draws from `sampler` until `num_samples` distinct values have been
// collected, writing them to samples_data[0 .. num_samples) in the order
// they were first seen. The values are not guaranteed to be negatives.
//
// Returns the total number of draws, duplicates included.
template <typename T>
int UniqSampler(const Sampler& sampler, const std::size_t num_samples,
                int64_t* samples_data) {
  std::unordered_set<int64_t> seen;
  int attempts = 0;
  for (int filled = 0; filled < num_samples;) {
    ++attempts;
    const auto candidate = sampler.Sample();
    // insert().second is false when the value was already drawn.
    if (seen.insert(candidate).second) {
      samples_data[filled] = candidate;
      ++filled;
    }
  }
  return attempts;
}
/*
template <typename T>
void Print(Tensor & t, std::string name) {
if (!FLAGS_debug_print) {
return;
}
VLOG(1) << "qxz print "<< name;
VLOG(1) << name << "size = " << t.numel();
size_t size = t.numel();
type *d = t.data<type>();
#ifdef PADDLE_WITH_CUDA
std::vector<type> vec;
platform::DeviceContextPool::Instance().Get(t.place())->Wait();
if (platform::is_gpu_place(t.place())) {
vec.resize(size);
cudaMemcpy(vec.data(), d, sizeof(T) * size, cudaMemcpyDeviceToHost);
d = vec.data();
}
#endif
VLOG(1) << name << " data_ptr = " << static_cast<void*>(d);
std::string out;
for (size_t i = 0; i < size; i++) {
out += std::to_string(d[i]);
out += ",";
}
VLOG(1) << out;
}*/
// Produces sample ids (S) and adjusted probabilities (P) on the GPU for the
// labels in L.
//
// context     : CUDA device context; its stream is used for the kernel launch.
// seed        : seed for the host-side LogUniformSampler.
// dict_size   : sampler range.
// uniq        : not consulted here; samples are always drawn through
//               UniqSampler, i.e. without duplicates.
// num_samples : number of sampled classes shared by every row.
// L           : labels; dims()[0] is used as the batch size and dims()[1] as
//               the number of true labels per row.
// S, P        : outputs, accessed through data<>(), so presumably already
//               allocated by the caller - TODO confirm.
template <typename T>
void GPUSampleWithProb<T>::operator()(
    const platform::CUDADeviceContext& context, const int seed,
    const int dict_size, const bool uniq, const std::size_t num_samples,
    const Tensor* L, Tensor* S, Tensor* P) {
  // Derive the output layout: [batch_size x (num_true + num_samples)].
  const auto lbl_dim = L->dims();
  const int batch_size = lbl_dim[0];
  const int num_true = lbl_dim[1];
  const int num_sampled_classes = num_true + num_samples;
  framework::DDim ret_dim{batch_size, num_sampled_classes};

  // Raw pointers into the label, sample and probability buffers.
  const int64_t* label_data = L->data<int64_t>();
  int64_t* samples_data = S->data<int64_t>();
  T* probabilities_data = P->data<T>();

  // Host-side staging tensor that receives the num_samples drawn ids.
  int s_size = num_samples;
  framework::DDim s_dim{s_size};
  Tensor s;
  int64_t* s_data = s.mutable_data<int64_t>(s_dim, platform::CPUPlace());

  math::LogUniformSampler sampler(dict_size, seed);

  int range = dict_size;
  float log_range = log(range + 1);

  // Sampling itself happens on the host; num_tries counts every draw,
  // including rejected duplicates.
  int num_tries = UniqSampler<T>(sampler, num_samples, s_data);
  VLOG(1) << "num_tries: " << num_tries;
  // Copy the drawn ids into the sampled columns of row 0 only; the kernel
  // below replicates them into the remaining rows.
  PADDLE_ENFORCE(cudaMemcpy(samples_data + num_true, s_data,
                            sizeof(int64_t) * num_samples,
                            cudaMemcpyHostToDevice));

  // One thread per output element, rounded up to whole blocks of 512.
  int threads = 512;
  const size_t size = batch_size * num_sampled_classes;
  int grid = (batch_size * num_sampled_classes + threads - 1) / threads;
  SamplingCondidate<T><<<grid, threads, 0, context.stream()>>>(
      size, num_tries, range, log_range, num_true, num_samples, label_data,
      samples_data, probabilities_data);
}
// Explicit instantiations of the GPU sampling functor for float and double.
template class GPUSampleWithProb<float>;
template class GPUSampleWithProb<double>;
}
// namespace math
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/math/sample_prob.h
0 → 100644
浏览文件 @
58ad40cc
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <iostream>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/math/sampler.h"
namespace
paddle
{
namespace
operators
{
namespace
math
{
using
Tensor
=
framework
::
Tensor
;
/* Adjusts a single-draw probability for sampling in which duplicate draws are
   rejected.

   prob        : probability of the class under one draw.
   num_samples : number of samples that were kept.
   num_tries   : number of draws that were needed to obtain them.

   If no draw was rejected (num_samples == num_tries) the result is
   prob * num_samples; otherwise it is 1 - (1 - prob)^num_tries, computed with
   expm1/log1p. */
template <typename T>
static T adjust_prob(const T prob, const int num_samples,
                     const int num_tries) {
  if (num_samples != num_tries) {
    return -expm1(num_tries * log1p(-prob));
  }
  return prob * num_samples;
}
// Functor that fills S (sample ids) and P (adjusted probabilities) for the
// labels in L. Both outputs are allocated here with shape
// [batch_size x (num_true + num_samples)], where batch_size and num_true are
// taken from L->dims(). Each row starts with its own true labels, followed by
// num_samples distinct sampled ids that are shared by all rows.
template <typename DeviceContext, typename T>
class SampleWithProb {
 public:
  void operator()(const DeviceContext& context, const Sampler& sampler,
                  const std::size_t num_samples, const Tensor* L, Tensor* S,
                  Tensor* P) {
    // Output layout.
    const auto lbl_dim = L->dims();
    const int batch_size = lbl_dim[0];
    const int num_true = lbl_dim[1];
    const int num_sampled_classes = num_true + num_samples;
    framework::DDim ret_dim{batch_size, num_sampled_classes};

    // Raw buffers; the outputs are (re)allocated on the context's place.
    const int64_t* label_data = L->data<int64_t>();
    int64_t* samples_data =
        S->mutable_data<int64_t>(ret_dim, context.GetPlace());
    T* probabilities_data = P->mutable_data<T>(ret_dim, context.GetPlace());

    // Column cursor shared by the two filling phases below.
    int j = 0;

    // Phase 1: copy the true labels, column by column, together with their
    // single-draw probability.
    for (; j < num_true; ++j) {
      for (int row = 0; row < batch_size; ++row) {
        const auto dst = row * num_sampled_classes + j;
        const auto label = label_data[row * num_true + j];
        samples_data[dst] = label;
        probabilities_data[dst] = sampler.Probability(label);
      }
    }

    // Phase 2: draw until the remaining columns are filled with distinct
    // ids. Duplicates are discarded but still counted in num_tries. The ids
    // are not necessarily negatives.
    std::unordered_set<int64_t> drawn;
    int num_tries = 0;
    while (j < num_sampled_classes) {
      ++num_tries;
      const auto candidate = sampler.Sample();
      if (drawn.insert(candidate).second) {
        const auto prob = sampler.Probability(candidate);
        // The same sampled id is written into column j of every row.
        for (int row = 0; row < batch_size; ++row) {
          const auto dst = row * num_sampled_classes + j;
          samples_data[dst] = candidate;
          probabilities_data[dst] = prob;
        }
        ++j;
      }
    }

    // Phase 3: compute Q(y|x). Because duplicates were rejected, every
    // stored probability (true labels included) goes through adjust_prob.
    for (int row = 0; row < batch_size; ++row) {
      for (int col = 0; col < num_sampled_classes; ++col) {
        const auto pos = row * num_sampled_classes + col;
        probabilities_data[pos] =
            adjust_prob(probabilities_data[pos], num_samples, num_tries);
      }
    }
  }
};
#ifdef PADDLE_WITH_CUDA
// GPU sampling functor. Only the call operator is declared in this header;
// its body is not defined here.
//
// Parameters of operator():
//   context     : CUDA device context.
//   seed        : sampler seed.
//   dict_size   : sampler range.
//   uniq        : uniqueness flag.
//   num_samples : number of classes to sample.
//   L           : input label tensor.
//   S, P        : output tensors for sample ids and probabilities.
template <typename T>
class GPUSampleWithProb {
 public:
  void operator()(const platform::CUDADeviceContext& context, const int seed,
                  const int dict_size, const bool uniq,
                  const std::size_t num_samples, const Tensor* L, Tensor* S,
                  Tensor* P);
};
#endif
}
// namespace math
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/sample_logits_op.cc
0 → 100644
浏览文件 @
58ad40cc
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/sample_logits_op.h"
#include "paddle/fluid/operators/math/sample_prob.h"
namespace
paddle
{
namespace
operators
{
// Declares the inputs, outputs, attributes and user-facing documentation of
// the sample_logits operator.
//
// The description strings are built from adjacent string literals, so every
// fragment that continues a sentence must end with (or start with) a space.
class SampleLogitsOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("Logits",
             "(Tensor, default: Tensor<float>), The unscaled log probabilities "
             "which is a 2-D tensor with shape [N x K]. N is the batch_size, "
             "and K is the class number.");
    AddInput("Label",
             "(Tensor) The ground truth which is a 2-D tensor. Label is a "
             "Tensor<int64> with shape [N x NT], where NT is the number of "
             "true labels for each example.");
    AddInput("CustomSamples",
             "(Tensor, default: Tensor<int64_t>), A 2-D tensor with shape "
             "[N x S+NT]. "
             "The customized sample labels with true labels at first. This "
             "tensor is only used when use_custom_samples is true.")
        .AsDispensable();
    AddInput("CustomProbabilities",
             "(Tensor, default: Tensor<float>), A 2-D tensor with shape "
             "[N x S+NT]. "
             "The customized sample probabilities with true labels at first. "
             "This tensor is only used when use_custom_samples is true.")
        .AsDispensable();
    AddOutput("Samples",
              "(Tensor, default: Tensor<int64_t>), A 2-D tensor with shape "
              "[N x S+NT]. "
              "The outputs value of sampler by given the true label, where S "
              "is the number of negative sample for each example. So Samples "
              "includes NT true labels and S negative labels for each "
              "example. This will be used in backward calculation.")
        .AsIntermediate();
    AddOutput("Probabilities",
              "(Tensor, default: Tensor<float>), A 2-D tensor with shape "
              "[N x S+NT]. "
              "The outputs value of probabilities of samples by given the "
              "true label, where S is the number of negative sample for each "
              "example. So Samples includes NT true labels and S negative "
              "labels for each example.")
        .AsIntermediate();
    AddOutput("SampledLogits",
              "(Tensor, default: Tensor<float>), A 2-D tensor with shape "
              "[N x S+NT]. The outputs value of sampled logits, which will be "
              "used in backward calculation.")
        .AsIntermediate();
    AddOutput("SampledLabel",
              "(Tensor, default: Tensor<int64>), A 2-D tensor with shape "
              "[N x NT]. The positions of the true labels inside "
              "SampledLogits, i.e. 0, 1, ..., NT-1 for each example.");
    AddAttr<bool>(
        "use_custom_samples",
        "An indicator whether to use custom samples with probabilities, if "
        "True the operator will use custom samples and custom probabilities, "
        "otherwise, the operator will generate them by itself.")
        .SetDefault(false);
    AddAttr<bool>(
        "uniq",
        "An indicator whether to sample non-repetitive negative labels, if "
        "True the operator will sample negative labels without replacement, "
        "otherwise, the operator will sample negative labels with "
        "replacement.")
        .SetDefault(false);
    AddAttr<bool>(
        "remove_accidental_hits",
        "An indicator whether to remove accidental hits when samples hits "
        "true labels, the removal is implemented by subtracting a large "
        "constant (1e20) from the corresponding logits to suppress their "
        "softmax to be zero.")
        .SetDefault(true);
    AddAttr<int>("num_samples", "The number of negative samples.");
    AddAttr<int>("seed", "Random seed for generating samples").SetDefault(0);

    AddComment(R"DOC(
Sample Logits Operator.

Given Logits with shape [N x K] and Label with shape [N x NT], this operator
picks, for every example, the logits of its NT true classes followed by the
logits of S = num_samples sampled classes, and outputs them as SampledLogits
with shape [N x (NT + S)].

Samples holds the picked class ids (true labels first) and Probabilities holds
Q(y|x), the probability of each picked class under the sampler. They are
generated by the operator, or taken from CustomSamples and CustomProbabilities
when use_custom_samples is true.

The output is computed as

$$SampledLogits_{i,j} = Logits_{i,Samples_{i,j}} - \log Q(Samples_{i,j}|x_i)$$

If remove_accidental_hits is true, a large constant is first subtracted from
the logit of every sampled class that equals one of the true labels of the
same example.

SampledLabel is filled with 0, 1, ..., NT-1 for every example, which are the
positions of the true classes inside SampledLogits.

Because the result is meant to be fed to a softmax, this operator expects
unscaled logits and should not be used with the output of a softmax operator.
)DOC");
  }
};
// Forward operator for sample_logits: validates inputs/outputs, infers the
// output shapes and selects the kernel type.
class SampleLogitsOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  // Checks that all required variables are present and that Logits and Label
  // are 2-D, then sets the output shapes:
  //   Samples, Probabilities, SampledLogits : [N x (NT + num_samples)]
  //   SampledLabel                          : [N x NT]
  // where N = Logits.dims[0] and NT = Label.dims[1].
  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("Logits"),
                   "Input(Logits) should be not null.");
    PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should be not null.");

    PADDLE_ENFORCE(ctx->HasOutput("Samples"),
                   "Output(Samples) should be not null.");
    PADDLE_ENFORCE(ctx->HasOutput("Probabilities"),
                   "Output(Probabilities) should be not null.");
    PADDLE_ENFORCE(ctx->HasOutput("SampledLogits"),
                   "Output(SampledLogits) should be not null.");
    PADDLE_ENFORCE(ctx->HasOutput("SampledLabel"),
                   "Output(SampledLabel) should be not null.");

    auto logits_dims = ctx->GetInputDim("Logits");
    auto labels_dims = ctx->GetInputDim("Label");

    // The message names this operator's own input, consistent with the
    // wording used by the gradient operator.
    PADDLE_ENFORCE_EQ(logits_dims.size(), 2UL,
                      "The logits should be a 2-D tensor.");
    PADDLE_ENFORCE_EQ(labels_dims.size(), 2UL,
                      "The labels should be a 2-D tensor.");

    const int num_samples = ctx->Attrs().Get<int>("num_samples");
    const int num_sampled_classes = labels_dims[1] + num_samples;
    ctx->SetOutputDim("Samples", {logits_dims[0], num_sampled_classes});
    ctx->SetOutputDim("Probabilities", {logits_dims[0], num_sampled_classes});
    ctx->SetOutputDim("SampledLogits", {logits_dims[0], num_sampled_classes});
    ctx->SetOutputDim("SampledLabel", {logits_dims[0], labels_dims[1]});
  }

 protected:
  // The kernel is chosen from the data type of Logits and the device context
  // the operator runs on.
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("Logits"));
    framework::OpKernelType kt =
        framework::OpKernelType(data_type, ctx.device_context());
    return kt;
  }
};
// Gradient operator for sample_logits. Its only output, Logits@GRAD, takes
// the shape of the forward input Logits.
class SampleLogitsOpGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  // Verifies that every variable the backward pass needs is present and that
  // Label and Logits are 2-D, then sets the shape of Logits@GRAD.
  void InferShape(framework::InferShapeContext* ctx) const override {
    // Forward inputs and forward outputs reused by the backward pass.
    PADDLE_ENFORCE(ctx->HasInput("Logits"),
                   "Input(Logits) should not be null.");
    PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should be not null.");
    PADDLE_ENFORCE(ctx->HasInput("Samples"),
                   "Input(Samples) should be not null.");
    PADDLE_ENFORCE(ctx->HasInput("SampledLogits"),
                   "Input(SampledLogits) should be not null.");

    // Incoming gradient and the gradient to be produced.
    const auto sampled_logits_grad = framework::GradVarName("SampledLogits");
    PADDLE_ENFORCE(ctx->HasInput(sampled_logits_grad),
                   "Input(SampledLogits@Grad) should not be null.");
    const auto logits_grad = framework::GradVarName("Logits");
    PADDLE_ENFORCE(ctx->HasOutput(logits_grad),
                   "Output(Logits@Grad) should be not null.");

    auto logits_shape = ctx->GetInputDim("Logits");
    auto label_shape = ctx->GetInputDim("Label");
    PADDLE_ENFORCE_EQ(label_shape.size(), 2UL,
                      "The label should be a 2-D tensor.");
    PADDLE_ENFORCE_EQ(logits_shape.size(), 2UL,
                      "The logits should be a 2-D tensor.");

    ctx->SetOutputDim(logits_grad, logits_shape);
  }

 protected:
  // The kernel is chosen from the data type of SampledLogits@GRAD and the
  // device context the operator runs on.
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    const auto grad_name = framework::GradVarName("SampledLogits");
    auto data_type = framework::GetDataTypeOfVar(ctx.InputVar(grad_name));
    return framework::OpKernelType(data_type, ctx.device_context());
  }
};
// Builds the description of the sample_logits_grad operator from the forward
// operator: it wires the forward inputs Logits and Label, the forward outputs
// Samples and SampledLogits, and the gradient of SampledLogits as inputs, and
// the gradient of Logits as the output. The forward attributes are copied
// unchanged.
class SampleLogitsGradMaker : public framework::SingleGradOpDescMaker {
 public:
  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;

 protected:
  std::unique_ptr<framework::OpDesc> Apply() const override {
    // Own the descriptor from the start so that it is released if any of the
    // setters below throws.
    std::unique_ptr<framework::OpDesc> grad_op(new framework::OpDesc());
    grad_op->SetType("sample_logits_grad");

    grad_op->SetInput("Logits", Input("Logits"));
    grad_op->SetInput("Label", Input("Label"));
    grad_op->SetInput("Samples", Output("Samples"));
    grad_op->SetInput("SampledLogits", Output("SampledLogits"));
    grad_op->SetInput(framework::GradVarName("SampledLogits"),
                      OutputGrad("SampledLogits"));

    grad_op->SetOutput(framework::GradVarName("Logits"), InputGrad("Logits"));

    grad_op->SetAttrMap(Attrs());
    return grad_op;
  }
};
}
// namespace operators
}
// namespace paddle
namespace ops = paddle::operators;

// Register the forward operator together with its proto maker and the maker
// of its gradient-operator description.
REGISTER_OPERATOR(sample_logits, ops::SampleLogitsOp, ops::SampleLogitsOpMaker,
                  ops::SampleLogitsGradMaker);
// Register the gradient operator.
REGISTER_OPERATOR(sample_logits_grad, ops::SampleLogitsOpGrad);
// CPU kernels for the forward and gradient operators, for float and double.
REGISTER_OP_CPU_KERNEL(sample_logits, ops::SampleLogitsKernel<float>,
                       ops::SampleLogitsKernel<double>);
REGISTER_OP_CPU_KERNEL(sample_logits_grad, ops::SampleLogitsGradKernel<float>,
                       ops::SampleLogitsGradKernel<double>);
paddle/fluid/operators/sample_logits_op.cu
0 → 100644
浏览文件 @
58ad40cc
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/sample_prob.h"
#include "paddle/fluid/operators/math/softmax.h"
#include "paddle/fluid/operators/sample_logits_op.h"
namespace
paddle
{
namespace
operators
{
// Enables the Print() helpers of the CUDA kernels in this file. Each Print()
// call waits for the device and copies the tensor back to the host, so the
// flag is off by default and must be switched on explicitly for debugging.
DEFINE_bool(debug_print, false, "run debug mode");
// Row-wise gather, similar to numpy's take_along_axis on axis 1:
//   p_value[i][j] = p_array[i][p_index[i][j]]
//
// size             : number of elements in p_index / p_value.
// batch_size       : unused by this kernel.
// array_slice_size : row length of p_array.
// idx_slice_size   : row length of p_index and p_value.
template <typename T>
__global__ void GPUTakeAlongD1(size_t size, const int batch_size,
                               const int array_slice_size,
                               const int idx_slice_size, const T* p_array,
                               const int64_t* p_index, T* p_value) {
  int idx = blockDim.x * blockIdx.x + threadIdx.x;
  const int step_size = blockDim.x * gridDim.x;

  for (; idx < size; idx += step_size) {
    // Row of the flat element idx.
    const int i = idx / idx_slice_size;
    const auto array_index = p_index[idx];
    p_value[idx] = p_array[i * array_slice_size + array_index];
  }
}
// Row-wise scatter-add, similar to numpy's put_along_axis on axis 1 except
// that values landing on the same index are accumulated:
//   p_array[i][p_index[i][j]] += p_value[i][j]
//
// size             : number of rows to process (the original code notes
//                    size == batch_size).
// batch_size       : unused by this kernel.
// array_slice_size : row length of p_array.
// idx_slice_size   : row length of p_index and p_value.
//
// Each thread walks a whole row serially, so repeated indices inside one row
// are accumulated by a single thread.
template <typename T>
__global__ void GPUPutAlongD1(size_t size, const int batch_size,
                              const int array_slice_size,
                              const int idx_slice_size, T* p_array,
                              const int64_t* p_index, const T* p_value) {
  int idx = blockDim.x * blockIdx.x + threadIdx.x;
  const int step_size = blockDim.x * gridDim.x;

  // Here idx is a row number, not a flat element offset.
  for (; idx < size; idx += step_size) {
    for (int j = 0; j < idx_slice_size; ++j) {
      const auto array_index = p_index[idx * idx_slice_size + j];
      p_array[idx * array_slice_size + array_index] +=
          p_value[idx * idx_slice_size + j];
    }
  }
}
// Fills p_array[0 .. size) with the repeating pattern 0, 1, ..., num_true - 1,
// i.e. element k receives k % num_true.
template <typename T>
__global__ void GPUSetLabel(size_t size, const int num_true,
                            int64_t* p_array) {
  const int stride = blockDim.x * gridDim.x;
  for (int pos = blockDim.x * blockIdx.x + threadIdx.x; pos < size;
       pos += stride) {
    p_array[pos] = pos % num_true;
  }
}
// Suppresses "accidental hits": a sampled class that equals one of the true
// labels of the same row has 1e20 subtracted from its value.
//
// size           : number of elements in p_index / p_value.
// num_true       : number of leading true-label columns in each row.
// idx_slice_size : row length of p_index and p_value.
// p_index        : class ids, true labels first in every row.
// p_value        : values to be adjusted in place.
template <typename T>
__global__ void gpu_compute_remove_accidental_hits(const int size,
                                                   const int num_true,
                                                   const int idx_slice_size,
                                                   const int64_t* p_index,
                                                   T* p_value) {
  int idx = blockDim.x * blockIdx.x + threadIdx.x;
  const int step_size = blockDim.x * gridDim.x;

  for (; idx < size; idx += step_size) {
    const int i = idx / idx_slice_size;
    // The true-label columns themselves are never penalised.
    if (idx % idx_slice_size < num_true) continue;

    for (int j = 0; j < num_true; ++j) {
      const auto true_idx = i * idx_slice_size + j;
      if (p_index[true_idx] == p_index[idx]) {
        p_value[idx] -= 1e20;
        // One penalty is enough even if several true labels match.
        break;
      }
    }
  }
}
// CUDA forward kernel of the sample_logits operator.
//
// For every row it gathers the logits of the classes listed in Samples (true
// labels first, then sampled classes), optionally penalises sampled classes
// that collide with a true label, and subtracts log Q(y|x) taken from
// Probabilities. SampledLabel is filled with 0 .. num_true - 1 per row.
template <typename T>
class SampleLogitsCUDAKernel : public framework::OpKernel<T> {
 public:
  using Tensor = framework::Tensor;

  // Debug helper: logs the contents of `t`, interpreted as elements of
  // `type`, at VLOG level 1. Does nothing unless FLAGS_debug_print is set.
  // For tensors on the GPU the data is first copied to the host.
  template <typename type>
  void Print(const Tensor& t, std::string name) const {
    if (!FLAGS_debug_print) {
      return;
    }
    VLOG(1) << "qxz print " << name;
    VLOG(1) << name << "size = " << t.numel();
    size_t size = t.numel();
    // `t` is const, so its data is only reachable through a const pointer.
    const type* d = t.data<type>();
#ifdef PADDLE_WITH_CUDA
    std::vector<type> vec;
    platform::DeviceContextPool::Instance().Get(t.place())->Wait();
    if (platform::is_gpu_place(t.place())) {
      vec.resize(size);
      // The byte count must use the element type being printed, which can
      // differ from the kernel's T (e.g. int64_t samples with T = float).
      cudaMemcpy(vec.data(), d, sizeof(type) * size, cudaMemcpyDeviceToHost);
      d = vec.data();
    }
#endif
    VLOG(1) << name << " data_ptr = " << static_cast<const void*>(d);
    std::string out;
    for (size_t i = 0; i < size; i++) {
      out += std::to_string(d[i]);
      out += ",";
    }
    VLOG(1) << out;
  }

  void Compute(const framework::ExecutionContext& context) const override {
    // get necessary inputs
    const Tensor* logits = context.Input<Tensor>("Logits");
    const Tensor* label = context.Input<Tensor>("Label");
    VLOG(3) << "Enter SampleLogitsCUDAKernel";

    // get necessary outputs
    Tensor* samples = context.Output<Tensor>("Samples");
    Tensor* probabilities = context.Output<Tensor>("Probabilities");
    Tensor* sampled_logits = context.Output<Tensor>("SampledLogits");
    Tensor* sampled_label = context.Output<Tensor>("SampledLabel");

    // shapes
    const auto batch_size = logits->dims()[0];
    const auto num_classes = logits->dims()[1];
    const auto label_dim = label->dims();
    const auto num_true = label_dim[1];
    const auto samples_dim = samples->dims();

    // attrs
    const auto num_samples = context.Attr<int>("num_samples");
    const bool use_custom_samples = context.Attr<bool>("use_custom_samples");
    const bool uniq = context.Attr<bool>("uniq");
    const bool remove_accidental_hits =
        context.Attr<bool>("remove_accidental_hits");

    // device contexts
    auto& dev_ctx = context.cuda_device_context();

    // Allocate SampledLogits with the shape of Samples and zero it.
    sampled_logits->mutable_data<T>(samples_dim, context.GetPlace());
    math::SetConstant<platform::CUDADeviceContext, T> set_zero;
    set_zero(dev_ctx, sampled_logits, static_cast<T>(0));

    // SampledLabel[i][j] = j: the true classes occupy the leading columns.
    auto sampled_label_data =
        sampled_label->mutable_data<int64_t>(label_dim, context.GetPlace());
    int threads = 512;
    size_t size = batch_size * num_true;
    int grid = (size + threads - 1) / threads;
    GPUSetLabel<
        T><<<grid, threads, 0, context.cuda_device_context().stream()>>>(
        size, num_true, sampled_label_data);

    if (use_custom_samples) {
      // Reuse the caller-provided samples and probabilities without copying.
      const Tensor* custom_samples = context.Input<Tensor>("CustomSamples");
      const Tensor* custom_probabilities =
          context.Input<Tensor>("CustomProbabilities");
      samples->ShareDataWith(*custom_samples);
      probabilities->ShareDataWith(*custom_probabilities);
    } else {
      samples->mutable_data<int64_t>(context.GetPlace());
      probabilities->mutable_data<T>(samples_dim, context.GetPlace());
      // Generate samples and their probabilities on the device.
      const auto seed = context.Attr<int>("seed");
      auto sampler_with_prob = math::GPUSampleWithProb<T>();
      // NOTE: at this point Samples has been allocated but not yet filled.
      Print<int64_t>(*samples, std::string("samples1"));
      sampler_with_prob(context.cuda_device_context(), seed, num_classes, uniq,
                        num_samples, label, samples, probabilities);
    }
    Print<int64_t>(*samples, std::string("samples2"));
    Print<T>(*probabilities, std::string("probabilities"));

    // Gather the logits of the selected classes:
    //   sampled_logits[i][j] = logits[i][samples[i][j]]
    const auto num_take = samples->dims()[1];
    const auto array_dims = logits->dims();
    const auto idx_dims = samples->dims();

    const T* p_array = logits->data<T>();
    const int64_t* p_index = samples->data<int64_t>();
    T* p_value = sampled_logits->data<T>();

    // src slice size
    const auto array_slice_size = array_dims[1];
    // index slice size
    const auto idx_slice_size = idx_dims[1];

    size = batch_size * num_take;
    grid = (size + threads - 1) / threads;
    GPUTakeAlongD1<
        T><<<grid, threads, 0, context.cuda_device_context().stream()>>>(
        size, batch_size, array_slice_size, idx_slice_size, p_array, p_index,
        p_value);
    Print<T>(*sampled_logits, std::string("sampled_logits"));

    if (remove_accidental_hits) {
      // Distinct names avoid shadowing the outer size/grid.
      const size_t hits_size = batch_size * (num_true + num_samples);
      int hits_grid = (hits_size + threads - 1) / threads;
      gpu_compute_remove_accidental_hits<
          T><<<hits_grid, threads, 0,
               context.cuda_device_context().stream()>>>(
          hits_size, num_true, idx_slice_size, p_index, p_value);
      Print<T>(*sampled_logits,
               std::string("sampled_logits_remove_accidental_hits"));
    }

    // subtracted sampled logits with logQ(y|x)
    auto probs = EigenMatrix<T>::From(*probabilities);
    auto smp_logits = EigenMatrix<T>::From(*sampled_logits);
    smp_logits.device(*dev_ctx.eigen_device()) =
        (smp_logits - probs.log().unaryExpr(TolerableValue<T>()))
            .unaryExpr(TolerableValue<T>());
    Print<T>(*sampled_logits, std::string("sampled_logits_res"));
  }
};
// CUDA kernel for the gradient of sample_logits: scatters the gradient of the
// sampled logits back into a zero-initialised gradient of the full logits.
template <typename T>
class SampleLogitsGradCUDAKernel : public framework::OpKernel<T> {
 public:
  using Tensor = framework::Tensor;

  // Debug helper. When FLAGS_debug_print is set, logs the name, element
  // count, data pointer and every element of `t` through VLOG(1). Tensors
  // placed on a GPU are first copied into a host-side buffer.
  //
  // `type` is the element type stored in `t`; it may differ from the kernel's
  // T (e.g. Print<int64_t> for the sample indices).
  template <typename type>
  void Print(const Tensor& t, std::string name) const {
    if (!FLAGS_debug_print) {
      return;
    }
    VLOG(1) << "qxz print " << name;
    VLOG(1) << name << "size = " << t.numel();
    size_t size = t.numel();
    const type* d = t.data<type>();
#ifdef PADDLE_WITH_CUDA
    std::vector<type> vec;
    // Wait on the device context of the tensor's place before reading it.
    platform::DeviceContextPool::Instance().Get(t.place())->Wait();
    if (platform::is_gpu_place(t.place())) {
      vec.resize(size);
      // The byte count must use the printed element type, not the kernel's
      // T: with sizeof(T) a Print<int64_t> call inside the float kernel
      // would copy only half of the buffer.
      cudaMemcpy(vec.data(), d, sizeof(type) * size, cudaMemcpyDeviceToHost);
      d = vec.data();
    }
#endif
    VLOG(1) << name << " data_ptr = " << static_cast<const void*>(d);
    std::string out;
    for (size_t i = 0; i < size; i++) {
      out += std::to_string(d[i]);
      out += ",";
    }
    VLOG(1) << out;
  }

  // Reads "Samples" and the gradient of "SampledLogits", zero-fills the
  // gradient of "Logits" and launches GPUPutAlongD1 to scatter into it.
  void Compute(const framework::ExecutionContext& context) const override {
    auto logits_grad = context.Output<Tensor>(framework::GradVarName("Logits"));
    const Tensor* samples = context.Input<Tensor>("Samples");
    const Tensor* sampled_logits_grad =
        context.Input<Tensor>(framework::GradVarName("SampledLogits"));
    logits_grad->mutable_data<T>(context.GetPlace());

    // The scatter accumulates into logits_grad, so it has to start at zero.
    auto& dev_ctx = context.cuda_device_context();
    math::SetConstant<platform::CUDADeviceContext, T> set_zero;
    set_zero(dev_ctx, logits_grad, static_cast<T>(0));

    // UNDERSTAND: scatter it back to logit_grad
    const auto batch_size = samples->dims()[0];
    const auto array_dims = logits_grad->dims();
    const auto idx_dims = samples->dims();

    T* p_array = logits_grad->data<T>();
    const int64_t* p_index = samples->data<int64_t>();
    const T* p_value = sampled_logits_grad->data<T>();

    // src slice size
    const auto array_slice_size = array_dims[1];
    // index slice size
    const auto idx_slice_size = idx_dims[1];

    // One work item per batch row.
    // NOTE(review): presumably GPUPutAlongD1 iterates over the
    // idx_slice_size entries of each row internally - confirm against the
    // kernel definition.
    int threads = 128;
    const size_t size = batch_size;
    int grid = (size + threads - 1) / threads;

    Print<T>(*sampled_logits_grad, std::string("sampled_logits_grad"));
    Print<int64_t>(*samples, std::string("samples"));
    GPUPutAlongD1<T><<<grid, threads, 0, dev_ctx.stream()>>>(
        size, batch_size, array_slice_size, idx_slice_size, p_array, p_index,
        p_value);
    Print<T>(*logits_grad, std::string("logits_grad"));
  }
};
}
// namespace operators
}
// namespace paddle
// Short alias for the operator namespace, used by the registrations below.
namespace
ops
=
paddle
::
operators
;
// Register the CUDA forward kernel of sample_logits for float and double.
REGISTER_OP_CUDA_KERNEL
(
sample_logits
,
ops
::
SampleLogitsCUDAKernel
<
float
>
,
ops
::
SampleLogitsCUDAKernel
<
double
>
);
// Register the CUDA backward kernel of sample_logits for float and double.
REGISTER_OP_CUDA_KERNEL
(
sample_logits_grad
,
ops
::
SampleLogitsGradCUDAKernel
<
float
>
,
ops
::
SampleLogitsGradCUDAKernel
<
double
>
);
paddle/fluid/operators/sample_logits_op.h
0 → 100644
浏览文件 @
58ad40cc
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <vector>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/sample_prob.h"
#include "paddle/fluid/operators/math/softmax.h"
namespace
paddle
{
namespace
operators
{
// Convenience alias for the framework tensor type.
using
Tensor
=
framework
::
Tensor
;
// Alias of framework::EigenMatrix whose layout defaults to Eigen::RowMajor
// and whose index type defaults to Eigen::DenseIndex.
template
<
typename
T
,
int
MajorType
=
Eigen
::
RowMajor
,
typename
IndexType
=
Eigen
::
DenseIndex
>
using
EigenMatrix
=
framework
::
EigenMatrix
<
T
,
MajorType
,
IndexType
>
;
// Element-wise functor that replaces +/-infinity with +/-1e20 so that later
// arithmetic stays finite. Every other value (including NaN, which compares
// unequal to both infinities) is returned unchanged.
template <typename T>
struct TolerableValue {
  HOSTDEVICE T operator()(const T& x) const {
    // The element type is a compile-time property, so it is checked at
    // compile time instead of with a runtime assertion on every element.
    static_assert(std::is_floating_point<T>::value,
                  "TolerableValue requires a floating point type.");
    const T kApproInf = 1e20;
    if (x == INFINITY) return kApproInf;
    if (x == -INFINITY) return -kApproInf;
    return x;
  }
};
// UNDERSTAND: row-wise gather, similar to take_along_axis in numpy on axis 1:
//   value[r, c] = array[r, index[r, c]]
// `array` is (B, C), `index` and `value` are (B, K). `value` must already be
// allocated; nothing is allocated here.
template <typename T>
static void CPUTakeAlongD1(const platform::DeviceContext& ctx,
                           const framework::Tensor& array,
                           const framework::Tensor& index,
                           framework::Tensor* value) {
  PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()));
  // UNDERSTAND: check shape src(B, C), index(B, K), out should also be (B, K)
  PADDLE_ENFORCE(index.dims().size() == 2 && array.dims().size() == 2 &&
                 index.dims()[0] == array.dims()[0] &&
                 index.dims() == value->dims());

  const auto rows = index.dims()[0];
  const auto cols = index.dims()[1];
  // Row stride of the source; index and value rows are both `cols` wide.
  const auto src_stride = array.dims()[1];

  const T* src = array.data<T>();
  const int64_t* idx = index.data<int64_t>();
  T* dst = value->data<T>();

  for (int r = 0; r < rows; ++r) {
    const T* src_row = src + r * src_stride;
    const int64_t* idx_row = idx + r * cols;
    T* dst_row = dst + r * cols;
    for (int c = 0; c < cols; ++c) {
      dst_row[c] = src_row[idx_row[c]];
    }
  }
}
// UNDERSTAND: row-wise scatter, similar to put_along_axis in numpy on axis 1,
// except that values are accumulated, so duplicate indices add up:
//   array[r, index[r, c]] += value[r, c]
// `array` is (B, C), `index` and `value` are (B, K). Nothing is allocated
// here.
template <typename T>
static void CPUPutAlongD1(const platform::DeviceContext& ctx,
                          framework::Tensor* array,
                          const framework::Tensor& index,
                          const framework::Tensor& value) {
  PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()));
  // UNDERSTAND: check shape array(B, C), index(B, K), value should also be
  // (B, K)
  PADDLE_ENFORCE(index.dims().size() == 2 && array->dims().size() == 2 &&
                 index.dims()[0] == array->dims()[0] &&
                 index.dims() == value.dims());

  const auto rows = index.dims()[0];
  const auto cols = index.dims()[1];
  // Row stride of the destination; index and value rows are `cols` wide.
  const auto dst_stride = array->dims()[1];

  T* dst = array->data<T>();
  const int64_t* idx = index.data<int64_t>();
  const T* src = value.data<T>();

  for (int r = 0; r < rows; ++r) {
    T* dst_row = dst + r * dst_stride;
    const int64_t* idx_row = idx + r * cols;
    const T* src_row = src + r * cols;
    for (int c = 0; c < cols; ++c) {
      dst_row[idx_row[c]] += src_row[c];
    }
  }
}
// UNDERSTAND: find accidental hits and push their logits down by 1e20.
// In every row of `samples` the first `num_true` entries are treated as the
// true labels; any later entry whose class id equals one of them is an
// accidental hit and has 1e20 subtracted from its entry in `sampled_logits`.
template <typename T>
static void compute_remove_accidental_hits(
    const platform::DeviceContext& ctx, framework::Tensor* sampled_logits,
    const framework::Tensor& samples, const int num_true) {
  const auto rows = sampled_logits->dims()[0];
  const auto row_width = sampled_logits->dims()[1];
  T* logits_data = sampled_logits->data<T>();
  const auto ids_data = samples.data<int64_t>();

  // Reused across rows; holds the true labels of the current row.
  std::unordered_set<int64_t> true_ids;
  for (int r = 0; r < rows; ++r) {
    const auto row_ids = ids_data + r * row_width;
    T* row_logits = logits_data + r * row_width;

    true_ids.clear();
    true_ids.insert(row_ids, row_ids + num_true);

    for (int c = num_true; c < row_width; ++c) {
      if (true_ids.count(row_ids[c]) != 0) {
        row_logits[c] -= 1e20;
      }
    }
  }
}
// CPU forward kernel of sample_logits. Produces the sample ids and their
// probabilities (either sampled or taken from the custom inputs), gathers the
// matching logits, optionally removes accidental hits and finally subtracts
// log(probabilities) from the gathered logits.
template <typename T>
class SampleLogitsKernel : public framework::OpKernel<T> {
 public:
  using Tensor = framework::Tensor;

  void Compute(const framework::ExecutionContext& context) const override {
    PADDLE_ENFORCE(platform::is_cpu_place(context.GetPlace()),
                   "This kernel only runs on CPU.");
    VLOG(3) << "Enter SampleLogitsKernel";

    // Inputs.
    const Tensor* logits = context.Input<Tensor>("Logits");
    const Tensor* label = context.Input<Tensor>("Label");

    // Outputs.
    Tensor* samples = context.Output<Tensor>("Samples");
    Tensor* probabilities = context.Output<Tensor>("Probabilities");
    Tensor* sampled_logits = context.Output<Tensor>("SampledLogits");
    Tensor* sampled_label = context.Output<Tensor>("SampledLabel");

    // Shapes.
    const auto rows = logits->dims()[0];
    const auto class_count = logits->dims()[1];
    const auto label_shape = label->dims();
    const auto true_count = label_shape[1];
    const auto samples_shape = samples->dims();

    // Attributes.
    const auto sample_count = context.Attr<int>("num_samples");
    const bool custom_samples_given = context.Attr<bool>("use_custom_samples");
    const bool drop_accidental_hits =
        context.Attr<bool>("remove_accidental_hits");

    auto& cpu_ctx =
        context.template device_context<platform::CPUDeviceContext>();

    // UNDERSTAND: allocate the outputs that are always written here.
    sampled_logits->mutable_data<T>(samples_shape, context.GetPlace());
    auto relabel =
        sampled_label->mutable_data<int64_t>(label_shape, context.GetPlace());
    // Column j of every row of SampledLabel is set to j.
    for (int r = 0; r < rows; ++r) {
      auto relabel_row = relabel + r * true_count;
      for (int c = 0; c < true_count; ++c) {
        relabel_row[c] = c;
      }
    }

    if (!custom_samples_given) {
      samples->mutable_data<int64_t>(context.GetPlace());
      probabilities->mutable_data<T>(samples_shape, context.GetPlace());
      // UNDERSTAND: sampling
      const auto seed = context.Attr<int>("seed");
      auto draw_samples =
          math::SampleWithProb<platform::CPUDeviceContext, T>();
      draw_samples(cpu_ctx, math::LogUniformSampler(class_count, seed),
                   sample_count, label, samples, probabilities);
    } else {
      // Reuse the caller supplied samples and probabilities as outputs.
      const Tensor* custom_samples = context.Input<Tensor>("CustomSamples");
      const Tensor* custom_probabilities =
          context.Input<Tensor>("CustomProbabilities");
      samples->ShareDataWith(*custom_samples);
      probabilities->ShareDataWith(*custom_probabilities);
    }

    // UNDERSTAND: gather sampled logits and remove accidental hits if needed
    CPUTakeAlongD1<T>(cpu_ctx, *logits, *samples, sampled_logits);
    if (drop_accidental_hits) {
      compute_remove_accidental_hits<T>(cpu_ctx, sampled_logits, *samples,
                                        true_count);
    }

    // Subtract logQ(y|x) from the sampled logits; infinities produced by the
    // log or by the subtraction are clamped by TolerableValue.
    auto probs = EigenMatrix<T>::From(*probabilities);
    auto smp_logits = EigenMatrix<T>::From(*sampled_logits);
    smp_logits.device(*cpu_ctx.eigen_device()) =
        (smp_logits - probs.log().unaryExpr(TolerableValue<T>()))
            .unaryExpr(TolerableValue<T>());
  }
};
// CPU backward kernel of sample_logits: zero-fills the gradient of "Logits"
// and accumulates the gradient of "SampledLogits" into it at the positions
// given by "Samples".
template <typename T>
class SampleLogitsGradKernel : public framework::OpKernel<T> {
 public:
  using Tensor = framework::Tensor;

  void Compute(const framework::ExecutionContext& context) const override {
    Tensor* d_logits = context.Output<Tensor>(framework::GradVarName("Logits"));
    const Tensor* sample_ids = context.Input<Tensor>("Samples");
    const Tensor* d_sampled_logits =
        context.Input<Tensor>(framework::GradVarName("SampledLogits"));
    d_logits->mutable_data<T>(context.GetPlace());

    // The scatter below accumulates, so the target has to start at zero.
    const auto& cpu_ctx =
        context.template device_context<platform::CPUDeviceContext>();
    math::SetConstant<platform::CPUDeviceContext, T> fill_zero;
    fill_zero(cpu_ctx, d_logits, static_cast<T>(0));

    // UNDERSTAND: scatter it back to logit_grad
    CPUPutAlongD1<T>(cpu_ctx, d_logits, *sample_ids, *d_sampled_logits);
  }
};
}
// namespace operators
}
// namespace paddle
python/paddle/fluid/__init__.py
浏览文件 @
58ad40cc
...
...
@@ -131,7 +131,7 @@ def __bootstrap__():
'eager_delete_tensor_gb'
,
'fast_eager_deletion_mode'
,
'allocator_strategy'
,
'reader_queue_speed_test_mode'
,
'print_sub_graph_dir'
,
'pe_profile_fname'
,
'warpctc_dir'
,
'inner_op_parallelism'
,
'enable_parallel_graph'
'inner_op_parallelism'
,
'enable_parallel_graph'
,
'debug_print'
]
if
'Darwin'
not
in
sysstr
:
read_env_flags
.
append
(
'use_pinned_memory'
)
...
...
python/paddle/fluid/layers/nn.py
浏览文件 @
58ad40cc
...
...
@@ -87,6 +87,7 @@ __all__ = [
'transpose'
,
'im2sequence'
,
'nce'
,
'sample_logits'
,
'hsigmoid'
,
'beam_search'
,
'row_conv'
,
...
...
@@ -5764,6 +5765,104 @@ def softmax_with_cross_entropy(logits,
return
loss
def sample_logits(logits,
                  label,
                  num_samples,
                  uniq=True,
                  remove_accidental_hits=True,
                  use_custom_samples=False,
                  custom_samples=None,
                  custom_probabilities=None,
                  seed=0):
    """
    **Sample Logits Operator.**

    Sampling step of a sampled softmax. Sampled softmax is used as the output
    layer when the number of classes is large: instead of normalizing over all
    K classes, only the true classes plus a number of sampled classes are
    considered for each example (row).

    This operator expects unscaled logits. It should not be used with the
    output of a softmax operator since that would produce incorrect results.

    For each example with T true labels (T >= 1), S samples are generated
    using a log uniform distribution. True labels are concatenated with these
    samples to form T + S samples for each example. So, if the shape of logits
    is [N x K], the shape of samples is [N x (T+S)]. For each sampled label a
    probability is calculated, which corresponds to the Q(y|x) in
    [Jean et al., 2014](http://arxiv.org/abs/1412.2007).

    Logits are gathered according to the sampled labels. Then, if
    remove_accidental_hits is True and a sample[i, j] accidentally hits a true
    label, 1e20 is subtracted from the corresponding sampled_logits[i, j] to
    make its softmax result close to zero. Finally logQ(y|x) is subtracted
    from the sampled logits. The resulting sampled logits and the re-indexed
    labels are meant to be fed to a softmax with cross entropy.

    Args:
        logits (Variable): The unscaled log probabilities, which is a 2-D
            tensor with shape [N x K]. N is the batch_size, and K is the
            class number.
        label (Variable): The ground truth which is a 2-D tensor. Label is a
            Tensor<int64> with shape [N x T], where T is the number of true
            labels per example.
        num_samples (int): The number of samples S drawn for each example.
            num_samples should be less than the number of classes.
        uniq (bool): Forwarded to the operator as the ``uniq`` attribute.
            Default is True.
        remove_accidental_hits (bool): A flag indicating whether to remove
            accidental hits when sampling. If True and if a sample[i, j]
            accidentally hits true labels, then 1e20 is subtracted from the
            corresponding sampled_logits[i, j] to make its softmax result
            close to zero. Default is True.
        use_custom_samples (bool): If True, no sampling is done and
            custom_samples / custom_probabilities are used as the samples
            and their probabilities. Default is False.
        custom_samples (Variable): The samples to use when
            use_custom_samples is True. Default is None.
        custom_probabilities (Variable): The probabilities of
            custom_samples, used when use_custom_samples is True.
            Default is None.
        seed (int): The random seed for generating random number, which is
            used in the process of sampling. Default is 0.

    Returns:
        tuple: ``(sampled_logits, sampled_label, samples, probabilities)``.
        ``sampled_logits`` holds the gathered and corrected logits,
        ``sampled_label`` holds the positions of the true labels inside
        ``samples``, ``samples`` holds the class ids (true labels first) and
        ``probabilities`` holds Q(y|x) for every entry of ``samples``.

    Raises:
        ValueError: If use_custom_samples is True but custom_samples or
            custom_probabilities is None.

    Examples:
        .. code-block:: python

            data = fluid.layers.data(name='data', shape=[256], dtype='float32')
            label = fluid.layers.data(name='label', shape=[5], dtype='int64')
            fc = fluid.layers.fc(input=data, size=100)
            sampled_logits, sampled_label, samples, probabilities = \\
                fluid.layers.sample_logits(
                    logits=fc, label=label, num_samples=25)
    """
    # Fail fast instead of handing None inputs to the operator. This check
    # binds no local name, so the `locals()` forwarded to LayerHelper below
    # still contains exactly the function parameters.
    if use_custom_samples and (custom_samples is None or
                               custom_probabilities is None):
        raise ValueError(
            "use_custom_samples=True requires both custom_samples and "
            "custom_probabilities to be provided.")

    helper = LayerHelper('sample_logits', **locals())
    samples = helper.create_variable_for_type_inference(dtype='int64')
    probabilities = helper.create_variable_for_type_inference(
        dtype=logits.dtype)
    sampled_logits \
        = helper.create_variable_for_type_inference(dtype=logits.dtype)
    sampled_label = helper.create_variable_for_type_inference(dtype='int64')

    helper.append_op(
        type='sample_logits',
        inputs={
            'Logits': logits,
            'Label': label,
            'CustomSamples': custom_samples,
            'CustomProbabilities': custom_probabilities
        },
        outputs={
            'Samples': samples,
            'Probabilities': probabilities,
            'SampledLabel': sampled_label,
            'SampledLogits': sampled_logits
        },
        attrs={
            'use_custom_samples': use_custom_samples,
            'uniq': uniq,
            'remove_accidental_hits': remove_accidental_hits,
            'num_samples': num_samples,
            'seed': seed
        })
    return sampled_logits, sampled_label, samples, probabilities
def
smooth_l1
(
x
,
y
,
inside_weight
=
None
,
outside_weight
=
None
,
sigma
=
None
):
"""
This layer computes the smooth L1 loss for Variable :attr:`x` and :attr:`y`.
...
...
python/paddle/fluid/tests/unittests/op_test.py
浏览文件 @
58ad40cc
...
...
@@ -350,6 +350,7 @@ class OpTest(unittest.TestCase):
actual_t
=
np
.
array
(
actual
)
expect
=
self
.
outputs
[
out_name
]
expect_t
=
expect
[
0
]
if
isinstance
(
expect
,
tuple
)
else
expect
#import pdb; pdb.set_trace()
self
.
assertTrue
(
np
.
allclose
(
actual_t
,
expect_t
,
atol
=
atol
,
equal_nan
=
equal_nan
),
...
...
python/paddle/fluid/tests/unittests/test_layers.py
浏览文件 @
58ad40cc
...
...
@@ -374,6 +374,16 @@ class TestBook(unittest.TestCase):
self
.
assertIsNotNone
(
output
)
print
(
str
(
program
))
def test_sample_logits(self):
    """Build a program containing a sample_logits layer and print it."""
    main_program = Program()
    with program_guard(main_program):
        logits_var = layers.data(name='Logits', shape=[256], dtype='float64')
        label_var = layers.data(name='Label', shape=[5], dtype='int64')
        # 25 samples are requested per example.
        result = layers.sample_logits(logits_var, label_var, 25)
        self.assertIsNotNone(result)
    print(str(main_program))
@
decorators
.
prog_scope
()
def
test_nce
(
self
):
window_size
=
5
...
...
python/paddle/fluid/tests/unittests/test_sample_logits.py
0 → 100644
浏览文件 @
58ad40cc
此差异已折叠。
点击以展开。
python/paddle/fluid/tests/unittests/testsuite.py
浏览文件 @
58ad40cc
...
...
@@ -156,8 +156,26 @@ def append_input_output(block, op_proto, np_list, is_input, dtype):
return
var_dict
def var_cast(block, input):
    """Return ``input`` as an FP32 variable, appending a cast op if needed.

    A variable whose dtype is already FP32 is returned unchanged. Otherwise a
    new float32 variable is created in ``block``, a ``cast`` op from
    ``input`` to it is appended, and the new variable is returned after its
    var type and shape have been inferred.
    """
    fp32 = core.VarDesc.VarType.FP32
    # NOTE(review): this check used to be spelled as the same FP32 comparison
    # written twice. FP64 may have been intended for the second one - confirm
    # before changing it, since that would stop casting float64 variables.
    if input.dtype == fp32:
        return input

    cast_out = block.create_var(dtype="float32", shape=[1])
    cast_op = block.append_op(
        type='cast',
        inputs={"X": input},
        outputs={"Out": cast_out},
        attrs={'out_dtype': fp32,
               'in_dtype': input.dtype})
    # Let the op desc fill in the real var type and shape of the output.
    cast_op.desc.infer_var_type(block.desc)
    cast_op.desc.infer_shape(block.desc)
    return cast_out
def
append_loss_ops
(
block
,
output_names
):
mean_inputs
=
list
(
map
(
block
.
var
,
output_names
))
mean_inputs
=
[
var_cast
(
block
,
x
)
for
x
in
mean_inputs
]
if
len
(
mean_inputs
)
==
1
:
loss
=
block
.
create_var
(
dtype
=
mean_inputs
[
0
].
dtype
,
shape
=
[
1
])
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录