Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
c5360a3f
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
c5360a3f
编写于
2月 19, 2019
作者:
X
xuezhong
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
refine code
上级
44240216
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
163 addition
and
158 deletion
+163
-158
paddle/fluid/operators/sample_logits_op.cc
paddle/fluid/operators/sample_logits_op.cc
+49
-49
paddle/fluid/operators/sample_logits_op.cu
paddle/fluid/operators/sample_logits_op.cu
+18
-16
paddle/fluid/operators/sample_logits_op.h
paddle/fluid/operators/sample_logits_op.h
+22
-18
python/paddle/fluid/layers/nn.py
python/paddle/fluid/layers/nn.py
+14
-12
python/paddle/fluid/tests/unittests/test_sample_logits.py
python/paddle/fluid/tests/unittests/test_sample_logits.py
+60
-63
未找到文件。
paddle/fluid/operators/sample_logits_op.cc
浏览文件 @
c5360a3f
...
@@ -25,63 +25,64 @@ class SampleLogitsOpMaker : public framework::OpProtoAndCheckerMaker {
...
@@ -25,63 +25,64 @@ class SampleLogitsOpMaker : public framework::OpProtoAndCheckerMaker {
"(Tensor, default: Tensor<float>), The unscaled log probabilities "
"(Tensor, default: Tensor<float>), The unscaled log probabilities "
"which is a 2-D tensor with shape [N x K]. N is the batch_size, "
"which is a 2-D tensor with shape [N x K]. N is the batch_size, "
"and K is the class number."
);
"and K is the class number."
);
AddInput
(
"Label"
,
AddInput
(
"Label
s
"
,
"(Tensor) The ground truth which is a 2-D tensor. Label is a "
"(Tensor) The ground truth which is a 2-D tensor. Label
s
is a "
"Tensor<int64> with shape [N x NT], where NT is the number of"
"Tensor<int64> with shape [N x NT], where NT is the number of"
"true labels for each example."
);
"true labels for each example."
);
AddInput
(
AddInput
(
"CustomizedSamples"
,
"CustomSamples"
,
"(Tensor, default: Tensor<int64_t>), A 2-D tensor with shape [N, "
"(Tensor, default: Tensor<int64_t>), A 2-D tensor with shaoe [N x "
"NT + S],"
"S+NT]."
" where N is the batch size, NT is the number of true labels "
"The customized sample labels with true labels at first. This tensor"
"and S is the number of negtive sample for each example."
"is only use_custom_samples is true."
)
"The first NT elements of each row should be the same with true "
"labels, "
"followed by S custom negtive samples. This tensor"
"is only used when use_customized_samples is true."
)
.
AsDispensable
();
.
AsDispensable
();
AddInput
(
AddInput
(
"CustomProbabilities"
,
"CustomizedProbabilities"
,
"(Tensor, default: Tensor<float>), A 2-D tensor with shaoe [N x S+NT]."
"(Tensor, default: Tensor<float>), A 2-D tensor with shape [N, NT + S]."
"The customized sample probabilities with true labels at first. This "
"The tensor has the same shape with CustomSamples,"
"tensor is only use_custom_samples is true."
)
"and each element represents probability of element in CustomSamples. "
"This "
"tensor is only used when use_customized_samples is true."
)
.
AsDispensable
();
.
AsDispensable
();
AddOutput
(
AddOutput
(
"Samples"
,
"Samples"
,
"(Tensor, default: Tensor<int64_t>), A 2-D tensor with shape [N, "
"(Tensor, default: Tensor<int64_t>), A 2-D tensor with shape [N x "
"NT + S]."
"S+NT]."
"The outputs value of sampler, including NT true lables and S "
"The outputs value of sampler by given the true label, where S is the "
"negetive samples "
"number of negative sample for each example. So Samples includes NT "
"for each example. This will be used in"
"true"
"backward calculation."
)
"labels and S negative labels for each example. This will be used in"
"backward calculation."
)
.
AsIntermediate
();
.
AsIntermediate
();
AddOutput
(
AddOutput
(
"Probabilities"
,
"Probabilities"
,
"(Tensor, default: Tensor<float>), A 2-D tensor with shape [N x "
"(Tensor, default: Tensor<float>), A 2-D tensor with shape [N, NT + S]."
"S+NT]."
"The probabilites of sampled positive and negtive labels."
)
"The outputs value of progabilites of samples by given the true label, "
"where S is the "
"number of negative sample for each example. So Samples includes NT "
"true"
"labels and S negative labels for each example."
)
.
AsIntermediate
();
.
AsIntermediate
();
AddOutput
(
"SampledLogits"
,
AddOutput
(
"SampledLogits"
,
"(Tensor, default: Tensor<float>), A 2-D tensor with shape"
"(Tensor, default: Tensor<float>), A 2-D tensor with shape"
"[N
x S+NT]. The outputs value of sample
logits, which will be"
"[N
, NT + S]. The outputs value of sampled
logits, which will be"
"used in backward
calcul
ation."
)
"used in backward
propag
ation."
)
.
AsIntermediate
();
.
AsIntermediate
();
AddOutput
(
AddOutput
(
"SampledLabel"
,
"SampledLabels"
,
"(Tensor, default: Tensor<int64>), A 2-D tensor. The sampled label"
"(Tensor, default: Tensor<int64>), A 2-D tensor. The sampled labels"
"with shape [N x S + NT]."
);
"with shape [N, NT]. The tonsor contains hard labels as input to "
" softmax op, that is 0, 1, …, NT-1 because of the first NT elements"
" of Sampels are positive lables."
);
AddAttr
<
bool
>
(
AddAttr
<
bool
>
(
"use_custom_samples"
,
"use_customized_samples"
,
"An indicator whether to use custom samples with probabilities, if True"
"An indicator whether to use customized samples with probabilities, if "
"the operator will use custom samples and custom probabilities"
"True"
"the operator will use customized samples and customized probabilities"
"otherwise, the operator will generate them by itself."
)
"otherwise, the operator will generate them by itself."
)
.
SetDefault
(
false
);
.
SetDefault
(
false
);
AddAttr
<
bool
>
(
AddAttr
<
bool
>
(
"uniq"
,
"uniq"
,
"An indicator whether to sample non-repetitive negtive labels, if True"
"An indicator whether to sample non-repetitive negtive labels, if True"
"the operator will sample negtive labels without replacement."
"the operator will sample negtive labels without replacement."
"
o
therwise, the operator will sample negtive labels with replacement."
)
"
O
therwise, the operator will sample negtive labels with replacement."
)
.
SetDefault
(
true
);
.
SetDefault
(
true
);
AddAttr
<
bool
>
(
AddAttr
<
bool
>
(
"remove_accidental_hits"
,
"remove_accidental_hits"
,
...
@@ -95,8 +96,7 @@ class SampleLogitsOpMaker : public framework::OpProtoAndCheckerMaker {
...
@@ -95,8 +96,7 @@ class SampleLogitsOpMaker : public framework::OpProtoAndCheckerMaker {
AddComment
(
R"DOC(
AddComment
(
R"DOC(
"""
"""
Computes sampled output training logits and labels suitable for implementing
Computes sampled output training logits and labels suitable for implementing
sampled softmax.
sampled softmax.
"""
"""
)DOC"
);
)DOC"
);
...
@@ -110,7 +110,8 @@ class SampleLogitsOp : public framework::OperatorWithKernel {
...
@@ -110,7 +110,8 @@ class SampleLogitsOp : public framework::OperatorWithKernel {
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Logits"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Logits"
),
"Input(Logits) should be not null."
);
"Input(Logits) should be not null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Label"
),
"Input(Label) should be not null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Labels"
),
"Input(Labels) should be not null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Samples"
),
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Samples"
),
"Output(Samples) should be not null."
);
"Output(Samples) should be not null."
);
...
@@ -118,11 +119,11 @@ class SampleLogitsOp : public framework::OperatorWithKernel {
...
@@ -118,11 +119,11 @@ class SampleLogitsOp : public framework::OperatorWithKernel {
"Output(Probabilities) should be not null."
);
"Output(Probabilities) should be not null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"SampledLogits"
),
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"SampledLogits"
),
"Output(SampledLogits) should be not null."
);
"Output(SampledLogits) should be not null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"SampledLabel"
),
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"SampledLabel
s
"
),
"Output(SampledLabel) should be not null."
);
"Output(SampledLabel
s
) should be not null."
);
auto
logits_dims
=
ctx
->
GetInputDim
(
"Logits"
);
auto
logits_dims
=
ctx
->
GetInputDim
(
"Logits"
);
auto
labels_dims
=
ctx
->
GetInputDim
(
"Label"
);
auto
labels_dims
=
ctx
->
GetInputDim
(
"Label
s
"
);
PADDLE_ENFORCE_EQ
(
PADDLE_ENFORCE_EQ
(
logits_dims
.
size
(),
2UL
,
logits_dims
.
size
(),
2UL
,
...
@@ -135,7 +136,7 @@ class SampleLogitsOp : public framework::OperatorWithKernel {
...
@@ -135,7 +136,7 @@ class SampleLogitsOp : public framework::OperatorWithKernel {
ctx
->
SetOutputDim
(
"Samples"
,
{
logits_dims
[
0
],
num_sampled_classes
});
ctx
->
SetOutputDim
(
"Samples"
,
{
logits_dims
[
0
],
num_sampled_classes
});
ctx
->
SetOutputDim
(
"Probabilities"
,
{
logits_dims
[
0
],
num_sampled_classes
});
ctx
->
SetOutputDim
(
"Probabilities"
,
{
logits_dims
[
0
],
num_sampled_classes
});
ctx
->
SetOutputDim
(
"SampledLogits"
,
{
logits_dims
[
0
],
num_sampled_classes
});
ctx
->
SetOutputDim
(
"SampledLogits"
,
{
logits_dims
[
0
],
num_sampled_classes
});
ctx
->
SetOutputDim
(
"SampledLabel"
,
{
logits_dims
[
0
],
labels_dims
[
1
]});
ctx
->
SetOutputDim
(
"SampledLabel
s
"
,
{
logits_dims
[
0
],
labels_dims
[
1
]});
}
}
protected:
protected:
...
@@ -144,7 +145,6 @@ class SampleLogitsOp : public framework::OperatorWithKernel {
...
@@ -144,7 +145,6 @@ class SampleLogitsOp : public framework::OperatorWithKernel {
auto
data_type
=
framework
::
GetDataTypeOfVar
(
ctx
.
InputVar
(
"Logits"
));
auto
data_type
=
framework
::
GetDataTypeOfVar
(
ctx
.
InputVar
(
"Logits"
));
framework
::
OpKernelType
kt
=
framework
::
OpKernelType
kt
=
framework
::
OpKernelType
(
data_type
,
ctx
.
device_context
());
framework
::
OpKernelType
(
data_type
,
ctx
.
device_context
());
// kt.place_ = platform::CPUPlace();
return
kt
;
return
kt
;
}
}
};
};
...
@@ -157,7 +157,8 @@ class SampleLogitsOpGrad : public framework::OperatorWithKernel {
...
@@ -157,7 +157,8 @@ class SampleLogitsOpGrad : public framework::OperatorWithKernel {
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Logits"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Logits"
),
"Input(Logits) should not be null."
);
"Input(Logits) should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Label"
),
"Input(Label) should be not null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Labels"
),
"Input(Labels) should be not null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Samples"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Samples"
),
"Input(Samples) should be not null."
);
"Input(Samples) should be not null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"SampledLogits"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"SampledLogits"
),
...
@@ -168,7 +169,7 @@ class SampleLogitsOpGrad : public framework::OperatorWithKernel {
...
@@ -168,7 +169,7 @@ class SampleLogitsOpGrad : public framework::OperatorWithKernel {
"Output(Logits@Grad) should be not null."
);
"Output(Logits@Grad) should be not null."
);
auto
logit_dims
=
ctx
->
GetInputDim
(
"Logits"
);
auto
logit_dims
=
ctx
->
GetInputDim
(
"Logits"
);
auto
label_dims
=
ctx
->
GetInputDim
(
"Label"
);
auto
label_dims
=
ctx
->
GetInputDim
(
"Label
s
"
);
PADDLE_ENFORCE_EQ
(
label_dims
.
size
(),
2UL
,
PADDLE_ENFORCE_EQ
(
label_dims
.
size
(),
2UL
,
"The label should be a 2-D tensor."
);
"The label should be a 2-D tensor."
);
PADDLE_ENFORCE_EQ
(
logit_dims
.
size
(),
2UL
,
PADDLE_ENFORCE_EQ
(
logit_dims
.
size
(),
2UL
,
...
@@ -185,7 +186,6 @@ class SampleLogitsOpGrad : public framework::OperatorWithKernel {
...
@@ -185,7 +186,6 @@ class SampleLogitsOpGrad : public framework::OperatorWithKernel {
ctx
.
InputVar
(
framework
::
GradVarName
(
"SampledLogits"
)));
ctx
.
InputVar
(
framework
::
GradVarName
(
"SampledLogits"
)));
framework
::
OpKernelType
kt
=
framework
::
OpKernelType
kt
=
framework
::
OpKernelType
(
data_type
,
ctx
.
device_context
());
framework
::
OpKernelType
(
data_type
,
ctx
.
device_context
());
// kt.place_ = platform::CPUPlace();
return
kt
;
return
kt
;
}
}
};
};
...
@@ -200,7 +200,7 @@ class SampleLogitsGradMaker : public framework::SingleGradOpDescMaker {
...
@@ -200,7 +200,7 @@ class SampleLogitsGradMaker : public framework::SingleGradOpDescMaker {
auto
*
grad_op
=
new
framework
::
OpDesc
();
auto
*
grad_op
=
new
framework
::
OpDesc
();
grad_op
->
SetType
(
"sample_logits_grad"
);
grad_op
->
SetType
(
"sample_logits_grad"
);
grad_op
->
SetInput
(
"Logits"
,
Input
(
"Logits"
));
grad_op
->
SetInput
(
"Logits"
,
Input
(
"Logits"
));
grad_op
->
SetInput
(
"Label
"
,
Input
(
"Label
"
));
grad_op
->
SetInput
(
"Label
s"
,
Input
(
"Labels
"
));
grad_op
->
SetInput
(
"Samples"
,
Output
(
"Samples"
));
grad_op
->
SetInput
(
"Samples"
,
Output
(
"Samples"
));
grad_op
->
SetInput
(
"SampledLogits"
,
Output
(
"SampledLogits"
));
grad_op
->
SetInput
(
"SampledLogits"
,
Output
(
"SampledLogits"
));
grad_op
->
SetInput
(
framework
::
GradVarName
(
"SampledLogits"
),
grad_op
->
SetInput
(
framework
::
GradVarName
(
"SampledLogits"
),
...
...
paddle/fluid/operators/sample_logits_op.cu
浏览文件 @
c5360a3f
...
@@ -109,25 +109,26 @@ class SampleLogitsCUDAKernel : public framework::OpKernel<T> {
...
@@ -109,25 +109,26 @@ class SampleLogitsCUDAKernel : public framework::OpKernel<T> {
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
// get necessary inputs
// get necessary inputs
const
Tensor
*
logits
=
context
.
Input
<
Tensor
>
(
"Logits"
);
const
Tensor
*
logits
=
context
.
Input
<
Tensor
>
(
"Logits"
);
const
Tensor
*
label
=
context
.
Input
<
Tensor
>
(
"Label
"
);
const
Tensor
*
label
s
=
context
.
Input
<
Tensor
>
(
"Labels
"
);
VLOG
(
3
)
<<
"Enter SampleLogitsCUDAKernel"
;
VLOG
(
3
)
<<
"Enter SampleLogitsCUDAKernel"
;
// get necessary outputs
// get necessary outputs
Tensor
*
samples
=
context
.
Output
<
Tensor
>
(
"Samples"
);
Tensor
*
samples
=
context
.
Output
<
Tensor
>
(
"Samples"
);
Tensor
*
probabilities
=
context
.
Output
<
Tensor
>
(
"Probabilities"
);
Tensor
*
probabilities
=
context
.
Output
<
Tensor
>
(
"Probabilities"
);
Tensor
*
sampled_logits
=
context
.
Output
<
Tensor
>
(
"SampledLogits"
);
Tensor
*
sampled_logits
=
context
.
Output
<
Tensor
>
(
"SampledLogits"
);
Tensor
*
sampled_label
=
context
.
Output
<
Tensor
>
(
"SampledLabel
"
);
Tensor
*
sampled_label
s
=
context
.
Output
<
Tensor
>
(
"SampledLabels
"
);
// shapes
// shapes
const
auto
batch_size
=
logits
->
dims
()[
0
];
const
auto
batch_size
=
logits
->
dims
()[
0
];
const
auto
num_classes
=
logits
->
dims
()[
1
];
const
auto
num_classes
=
logits
->
dims
()[
1
];
const
auto
label
_dim
=
label
->
dims
();
const
auto
label
s_dim
=
labels
->
dims
();
const
auto
num_true
=
label_dim
[
1
];
const
auto
num_true
=
label
s
_dim
[
1
];
const
auto
samples_dim
=
samples
->
dims
();
const
auto
samples_dim
=
samples
->
dims
();
// attrs
// attrs
const
auto
num_samples
=
context
.
Attr
<
int
>
(
"num_samples"
);
const
auto
num_samples
=
context
.
Attr
<
int
>
(
"num_samples"
);
const
bool
use_custom_samples
=
context
.
Attr
<
bool
>
(
"use_custom_samples"
);
const
bool
use_customized_samples
=
context
.
Attr
<
bool
>
(
"use_customized_samples"
);
const
bool
uniq
=
context
.
Attr
<
bool
>
(
"uniq"
);
const
bool
uniq
=
context
.
Attr
<
bool
>
(
"uniq"
);
const
bool
remove_accidental_hits
=
const
bool
remove_accidental_hits
=
context
.
Attr
<
bool
>
(
"remove_accidental_hits"
);
context
.
Attr
<
bool
>
(
"remove_accidental_hits"
);
...
@@ -140,21 +141,22 @@ class SampleLogitsCUDAKernel : public framework::OpKernel<T> {
...
@@ -140,21 +141,22 @@ class SampleLogitsCUDAKernel : public framework::OpKernel<T> {
math
::
SetConstant
<
platform
::
CUDADeviceContext
,
T
>
set_zero
;
math
::
SetConstant
<
platform
::
CUDADeviceContext
,
T
>
set_zero
;
set_zero
(
dev_ctx
,
sampled_logits
,
static_cast
<
T
>
(
0
));
set_zero
(
dev_ctx
,
sampled_logits
,
static_cast
<
T
>
(
0
));
auto
sampled_label_data
=
auto
sampled_label
s
_data
=
sampled_label
->
mutable_data
<
int64_t
>
(
label
_dim
,
context
.
GetPlace
());
sampled_label
s
->
mutable_data
<
int64_t
>
(
labels
_dim
,
context
.
GetPlace
());
int
threads
=
512
;
int
threads
=
512
;
size_t
size
=
batch_size
*
num_true
;
size_t
size
=
batch_size
*
num_true
;
int
grid
=
(
size
+
threads
-
1
)
/
threads
;
int
grid
=
(
size
+
threads
-
1
)
/
threads
;
GPUSetLabel
<
GPUSetLabel
<
T
><<<
grid
,
threads
,
0
,
context
.
cuda_device_context
().
stream
()
>>>
(
T
><<<
grid
,
threads
,
0
,
context
.
cuda_device_context
().
stream
()
>>>
(
size
,
num_true
,
sampled_label_data
);
size
,
num_true
,
sampled_labels_data
);
if
(
use_custom_samples
)
{
if
(
use_customized_samples
)
{
const
Tensor
*
custom_samples
=
context
.
Input
<
Tensor
>
(
"CustomSamples"
);
const
Tensor
*
customized_samples
=
const
Tensor
*
custom_probabilities
=
context
.
Input
<
Tensor
>
(
"CustomizedSamples"
);
context
.
Input
<
Tensor
>
(
"CustomProbabilities"
);
const
Tensor
*
customized_probabilities
=
samples
->
ShareDataWith
(
*
custom_samples
);
context
.
Input
<
Tensor
>
(
"CustomizedProbabilities"
);
probabilities
->
ShareDataWith
(
*
custom_probabilities
);
samples
->
ShareDataWith
(
*
customized_samples
);
probabilities
->
ShareDataWith
(
*
customized_probabilities
);
}
else
{
}
else
{
samples
->
mutable_data
<
int64_t
>
(
context
.
GetPlace
());
samples
->
mutable_data
<
int64_t
>
(
context
.
GetPlace
());
probabilities
->
mutable_data
<
T
>
(
samples_dim
,
context
.
GetPlace
());
probabilities
->
mutable_data
<
T
>
(
samples_dim
,
context
.
GetPlace
());
...
@@ -162,7 +164,7 @@ class SampleLogitsCUDAKernel : public framework::OpKernel<T> {
...
@@ -162,7 +164,7 @@ class SampleLogitsCUDAKernel : public framework::OpKernel<T> {
const
auto
seed
=
context
.
Attr
<
int
>
(
"seed"
);
const
auto
seed
=
context
.
Attr
<
int
>
(
"seed"
);
auto
sampler_with_prob
=
math
::
GPUSampleWithProb
<
T
>
();
auto
sampler_with_prob
=
math
::
GPUSampleWithProb
<
T
>
();
sampler_with_prob
(
context
.
cuda_device_context
(),
seed
,
num_classes
,
uniq
,
sampler_with_prob
(
context
.
cuda_device_context
(),
seed
,
num_classes
,
uniq
,
num_samples
,
label
,
samples
,
probabilities
);
num_samples
,
label
s
,
samples
,
probabilities
);
}
}
// UNDERSTAND: gather sampled logits and remove accidental hits if needed
// UNDERSTAND: gather sampled logits and remove accidental hits if needed
...
...
paddle/fluid/operators/sample_logits_op.h
浏览文件 @
c5360a3f
...
@@ -150,24 +150,25 @@ class SampleLogitsKernel : public framework::OpKernel<T> {
...
@@ -150,24 +150,25 @@ class SampleLogitsKernel : public framework::OpKernel<T> {
VLOG
(
3
)
<<
"Enter SampleLogitsKernel"
;
VLOG
(
3
)
<<
"Enter SampleLogitsKernel"
;
// get necessary inputs
// get necessary inputs
const
Tensor
*
logits
=
context
.
Input
<
Tensor
>
(
"Logits"
);
const
Tensor
*
logits
=
context
.
Input
<
Tensor
>
(
"Logits"
);
const
Tensor
*
label
=
context
.
Input
<
Tensor
>
(
"Label
"
);
const
Tensor
*
label
s
=
context
.
Input
<
Tensor
>
(
"Labels
"
);
// get necessary outputs
// get necessary outputs
Tensor
*
samples
=
context
.
Output
<
Tensor
>
(
"Samples"
);
Tensor
*
samples
=
context
.
Output
<
Tensor
>
(
"Samples"
);
Tensor
*
probabilities
=
context
.
Output
<
Tensor
>
(
"Probabilities"
);
Tensor
*
probabilities
=
context
.
Output
<
Tensor
>
(
"Probabilities"
);
Tensor
*
sampled_logits
=
context
.
Output
<
Tensor
>
(
"SampledLogits"
);
Tensor
*
sampled_logits
=
context
.
Output
<
Tensor
>
(
"SampledLogits"
);
Tensor
*
sampled_label
=
context
.
Output
<
Tensor
>
(
"SampledLabel
"
);
Tensor
*
sampled_label
s
=
context
.
Output
<
Tensor
>
(
"SampledLabels
"
);
// shapes
// shapes
const
auto
batch_size
=
logits
->
dims
()[
0
];
const
auto
batch_size
=
logits
->
dims
()[
0
];
const
auto
num_classes
=
logits
->
dims
()[
1
];
const
auto
num_classes
=
logits
->
dims
()[
1
];
const
auto
label
_dim
=
label
->
dims
();
const
auto
label
s_dim
=
labels
->
dims
();
const
auto
num_true
=
label_dim
[
1
];
const
auto
num_true
=
label
s
_dim
[
1
];
const
auto
samples_dim
=
samples
->
dims
();
const
auto
samples_dim
=
samples
->
dims
();
// attrs
// attrs
const
auto
num_samples
=
context
.
Attr
<
int
>
(
"num_samples"
);
const
auto
num_samples
=
context
.
Attr
<
int
>
(
"num_samples"
);
const
bool
use_custom_samples
=
context
.
Attr
<
bool
>
(
"use_custom_samples"
);
const
bool
use_customized_samples
=
context
.
Attr
<
bool
>
(
"use_customized_samples"
);
const
bool
remove_accidental_hits
=
const
bool
remove_accidental_hits
=
context
.
Attr
<
bool
>
(
"remove_accidental_hits"
);
context
.
Attr
<
bool
>
(
"remove_accidental_hits"
);
...
@@ -177,18 +178,21 @@ class SampleLogitsKernel : public framework::OpKernel<T> {
...
@@ -177,18 +178,21 @@ class SampleLogitsKernel : public framework::OpKernel<T> {
// UNDERSTAND: allocate memories for temporaries
// UNDERSTAND: allocate memories for temporaries
sampled_logits
->
mutable_data
<
T
>
(
samples_dim
,
context
.
GetPlace
());
sampled_logits
->
mutable_data
<
T
>
(
samples_dim
,
context
.
GetPlace
());
auto
sampled_label_data
=
auto
sampled_labels_data
=
sampled_label
->
mutable_data
<
int64_t
>
(
label_dim
,
context
.
GetPlace
());
sampled_labels
->
mutable_data
<
int64_t
>
(
labels_dim
,
context
.
GetPlace
());
for
(
int
i
=
0
;
i
<
batch_size
;
++
i
)
for
(
int
i
=
0
;
i
<
batch_size
;
++
i
)
{
for
(
int
j
=
0
;
j
<
num_true
;
++
j
)
for
(
int
j
=
0
;
j
<
num_true
;
++
j
)
{
sampled_label_data
[
i
*
num_true
+
j
]
=
j
;
sampled_labels_data
[
i
*
num_true
+
j
]
=
j
;
}
if
(
use_custom_samples
)
{
}
const
Tensor
*
custom_samples
=
context
.
Input
<
Tensor
>
(
"CustomSamples"
);
const
Tensor
*
custom_probabilities
=
if
(
use_customized_samples
)
{
context
.
Input
<
Tensor
>
(
"CustomProbabilities"
);
const
Tensor
*
customized_samples
=
samples
->
ShareDataWith
(
*
custom_samples
);
context
.
Input
<
Tensor
>
(
"CustomizedSamples"
);
probabilities
->
ShareDataWith
(
*
custom_probabilities
);
const
Tensor
*
customized_probabilities
=
context
.
Input
<
Tensor
>
(
"CustomizedProbabilities"
);
samples
->
ShareDataWith
(
*
customized_samples
);
probabilities
->
ShareDataWith
(
*
customized_probabilities
);
}
else
{
}
else
{
samples
->
mutable_data
<
int64_t
>
(
context
.
GetPlace
());
samples
->
mutable_data
<
int64_t
>
(
context
.
GetPlace
());
probabilities
->
mutable_data
<
T
>
(
samples_dim
,
context
.
GetPlace
());
probabilities
->
mutable_data
<
T
>
(
samples_dim
,
context
.
GetPlace
());
...
@@ -197,7 +201,7 @@ class SampleLogitsKernel : public framework::OpKernel<T> {
...
@@ -197,7 +201,7 @@ class SampleLogitsKernel : public framework::OpKernel<T> {
auto
sampler_with_prob
=
auto
sampler_with_prob
=
math
::
SampleWithProb
<
platform
::
CPUDeviceContext
,
T
>
();
math
::
SampleWithProb
<
platform
::
CPUDeviceContext
,
T
>
();
sampler_with_prob
(
dev_ctx
,
math
::
LogUniformSampler
(
num_classes
,
seed
),
sampler_with_prob
(
dev_ctx
,
math
::
LogUniformSampler
(
num_classes
,
seed
),
num_samples
,
label
,
samples
,
probabilities
);
num_samples
,
label
s
,
samples
,
probabilities
);
}
}
// UNDERSTAND: gather sampled logits and remove accidental hits if needed
// UNDERSTAND: gather sampled logits and remove accidental hits if needed
...
...
python/paddle/fluid/layers/nn.py
浏览文件 @
c5360a3f
...
@@ -5771,9 +5771,9 @@ def sampled_softmax_with_cross_entropy(logits,
...
@@ -5771,9 +5771,9 @@ def sampled_softmax_with_cross_entropy(logits,
num_samples
,
num_samples
,
num_true
=
1
,
num_true
=
1
,
remove_accidental_hits
=
True
,
remove_accidental_hits
=
True
,
use_custom_samples
=
False
,
use_custom
ized
_samples
=
False
,
custom_samples
=
None
,
custom
ized
_samples
=
None
,
custom_probabilities
=
None
,
custom
ized
_probabilities
=
None
,
seed
=
0
):
seed
=
0
):
"""
"""
**Sampled Softmax With Cross Entropy Operator.**
**Sampled Softmax With Cross Entropy Operator.**
...
@@ -5789,7 +5789,7 @@ def sampled_softmax_with_cross_entropy(logits,
...
@@ -5789,7 +5789,7 @@ def sampled_softmax_with_cross_entropy(logits,
For examples with T true labels (T >= 1), we assume that each true label has
For examples with T true labels (T >= 1), we assume that each true label has
a probability of 1/T. For each sample, S samples are generated using a
a probability of 1/T. For each sample, S samples are generated using a
log uniform distribution. True labels are concatenated with hese samples to
log uniform distribution. True labels are concatenated with
t
hese samples to
form T + S samples for each example. So, assume the shape of logits is
form T + S samples for each example. So, assume the shape of logits is
[N x K], the shape for samples is [N x (T+S)]. For each sampled label, a
[N x K], the shape for samples is [N x (T+S)]. For each sampled label, a
probability is calculated, which corresponds to the Q(y|x) in
probability is calculated, which corresponds to the Q(y|x) in
...
@@ -5798,7 +5798,7 @@ def sampled_softmax_with_cross_entropy(logits,
...
@@ -5798,7 +5798,7 @@ def sampled_softmax_with_cross_entropy(logits,
Logits are sampled according to the sampled labels. Then if
Logits are sampled according to the sampled labels. Then if
remove_accidental_hits is True, if a sample[i, j] accidentally hits true
remove_accidental_hits is True, if a sample[i, j] accidentally hits true
labels, then the corresponding sampled_logits[i, j] is minus by 1e20 to
labels, then the corresponding sampled_logits[i, j] is minus by 1e20 to
make its softmax result close to zero. Then samled logits are subtracted by
make its softmax result close to zero. Then sam
p
led logits are subtracted by
logQ(y|x), these sampled logits and re-indexed labels are used to compute
logQ(y|x), these sampled logits and re-indexed labels are used to compute
a softmax with cross entropy.
a softmax with cross entropy.
...
@@ -5816,14 +5816,16 @@ def sampled_softmax_with_cross_entropy(logits,
...
@@ -5816,14 +5816,16 @@ def sampled_softmax_with_cross_entropy(logits,
accidentally hits true labels, then the corresponding
accidentally hits true labels, then the corresponding
sampled_logits[i, j] is minus by 1e20 to make its softmax result
sampled_logits[i, j] is minus by 1e20 to make its softmax result
close to zero. Default is True.
close to zero. Default is True.
use_custom_samples (bool): Whether to use custom samples and probabities to sample
use_custom
ized
_samples (bool): Whether to use custom samples and probabities to sample
logits.
logits.
custom_samples (Variable): User defined samples, which is a 1-D tensor with shape [S]. S is the num_samples.
customized_samples (Variable): User defined samples, which is a 2-D tensor
custom_probabilities (Variable): User defined probabilities of samples, a 1-D tensor which has the same shape with custom_samples.
with shape [N, T + S]. S is the num_samples, and T is the number of true
labels per example.
customized_probabilities (Variable): User defined probabilities of samples,
a 2-D tensor which has the same shape with customized_samples.
seed (int): The random seed for generating random number, which is used
seed (int): The random seed for generating random number, which is used
in the process of sampling. Default is 0.
in the process of sampling. Default is 0.
Returns:
Returns:
Variable: Return the cross entropy loss which is a 2-D tensor with shape
Variable: Return the cross entropy loss which is a 2-D tensor with shape
[N x 1].
[N x 1].
...
@@ -5849,18 +5851,18 @@ def sampled_softmax_with_cross_entropy(logits,
...
@@ -5849,18 +5851,18 @@ def sampled_softmax_with_cross_entropy(logits,
type
=
'sample_logits'
,
type
=
'sample_logits'
,
inputs
=
{
inputs
=
{
'Logits'
:
logits
,
'Logits'
:
logits
,
'Label'
:
label
,
'Label
s
'
:
label
,
'CustomSamples'
:
custom_samples
,
'CustomSamples'
:
custom_samples
,
'CustomProbabilities'
:
custom_probabilities
'CustomProbabilities'
:
custom_probabilities
},
},
outputs
=
{
outputs
=
{
'Samples'
:
samples
,
'Samples'
:
samples
,
'Probabilities'
:
probabilities
,
'Probabilities'
:
probabilities
,
'SampledLabel'
:
sampled_label
,
'SampledLabel
s
'
:
sampled_label
,
'SampledLogits'
:
sampled_logits
'SampledLogits'
:
sampled_logits
},
},
attrs
=
{
attrs
=
{
'use_custom
_samples'
:
use_custom
_samples
,
'use_custom
ized_samples'
:
use_customized
_samples
,
'uniq'
:
True
,
'uniq'
:
True
,
'remove_accidental_hits'
:
remove_accidental_hits
,
'remove_accidental_hits'
:
remove_accidental_hits
,
'num_samples'
:
num_samples
,
'num_samples'
:
num_samples
,
...
...
python/paddle/fluid/tests/unittests/test_sample_logits.py
浏览文件 @
c5360a3f
...
@@ -61,8 +61,8 @@ def take_along_axis1(array, index):
...
@@ -61,8 +61,8 @@ def take_along_axis1(array, index):
return
out
return
out
def
sample_prob
(
sampler
,
num_samples
,
label
):
def
sample_prob
(
sampler
,
num_samples
,
label
s
):
batch_size
,
num_true
=
label
.
shape
batch_size
,
num_true
=
label
s
.
shape
num_sampled_classes
=
num_samples
+
num_true
num_sampled_classes
=
num_samples
+
num_true
samples
=
np
.
zeros
((
batch_size
,
num_sampled_classes
),
dtype
=
np
.
int64
)
samples
=
np
.
zeros
((
batch_size
,
num_sampled_classes
),
dtype
=
np
.
int64
)
...
@@ -74,8 +74,8 @@ def sample_prob(sampler, num_samples, label):
...
@@ -74,8 +74,8 @@ def sample_prob(sampler, num_samples, label):
j
=
0
j
=
0
while
j
<
num_true
:
while
j
<
num_true
:
for
i
in
range
(
batch_size
):
for
i
in
range
(
batch_size
):
samples
[
i
,
j
]
=
label
[
i
,
j
]
samples
[
i
,
j
]
=
label
s
[
i
,
j
]
probabilities
[
i
,
j
]
=
sampler
.
probability
(
label
[
i
,
j
])
probabilities
[
i
,
j
]
=
sampler
.
probability
(
label
s
[
i
,
j
])
j
+=
1
j
+=
1
while
j
<
num_sampled_classes
:
while
j
<
num_sampled_classes
:
v
=
sampler
.
sample
()
v
=
sampler
.
sample
()
...
@@ -103,33 +103,30 @@ def compute_remove_accidental_hits(sampled_logits, samples, num_true):
...
@@ -103,33 +103,30 @@ def compute_remove_accidental_hits(sampled_logits, samples, num_true):
def
sample_logits
(
logits
,
def
sample_logits
(
logits
,
label
,
label
s
,
num_samples
,
num_samples
,
seed
,
seed
,
remove_accidental_hits
,
remove_accidental_hits
,
use_custom_samples
,
use_custom
ized
_samples
,
custom_samples
=
None
,
custom
ized
_samples
=
None
,
custom_probabilities
=
None
):
custom
ized
_probabilities
=
None
):
batch_size
,
num_classes
=
logits
.
shape
batch_size
,
num_classes
=
logits
.
shape
num_true
=
label
.
shape
[
1
]
num_true
=
label
s
.
shape
[
1
]
num_sampled_classes
=
num_true
+
num_samples
num_sampled_classes
=
num_true
+
num_samples
if
use_custom_samples
:
if
use_custom
ized
_samples
:
samples
=
custom_samples
samples
=
custom
ized
_samples
probabilities
=
custom_probabilities
probabilities
=
custom
ized
_probabilities
else
:
else
:
sampler
=
LogUniformSampler
(
num_classes
,
seed
)
sampler
=
LogUniformSampler
(
num_classes
,
seed
)
samples
,
probabilities
=
sample_prob
(
sampler
,
num_samples
,
label
)
samples
,
probabilities
=
sample_prob
(
sampler
,
num_samples
,
label
s
)
sampled_logits
=
take_along_axis1
(
logits
,
samples
)
sampled_logits
=
take_along_axis1
(
logits
,
samples
)
#print(samples)
#print(probabilities)
#print(sampled_logits)
if
remove_accidental_hits
:
if
remove_accidental_hits
:
compute_remove_accidental_hits
(
sampled_logits
,
samples
,
num_true
)
compute_remove_accidental_hits
(
sampled_logits
,
samples
,
num_true
)
sampled_logits
-=
np
.
log
(
probabilities
)
sampled_logits
-=
np
.
log
(
probabilities
)
sampled_label
=
np
.
tile
(
np
.
arange
(
num_true
),
(
batch_size
,
1
))
sampled_label
s
=
np
.
tile
(
np
.
arange
(
num_true
),
(
batch_size
,
1
))
return
(
sampled_logits
,
samples
,
sampled_label
,
probabilities
)
return
(
sampled_logits
,
samples
,
sampled_label
s
,
probabilities
)
class
TestSampleLogitsOp
(
OpTest
):
class
TestSampleLogitsOp
(
OpTest
):
...
@@ -138,51 +135,51 @@ class TestSampleLogitsOp(OpTest):
...
@@ -138,51 +135,51 @@ class TestSampleLogitsOp(OpTest):
in python and just test the non-random part.
in python and just test the non-random part.
'''
'''
def
generate_data
(
self
,
logits
,
label
,
num_samples
,
seed
,
def
generate_data
(
self
,
logits
,
label
s
,
num_samples
,
seed
,
remove_accidental_hits
,
use_custom_samples
,
remove_accidental_hits
,
use_custom
ized
_samples
,
custom
_samples
,
custom
_probabilities
):
custom
ized_samples
,
customized
_probabilities
):
self
.
attrs
=
{
self
.
attrs
=
{
'num_samples'
:
num_samples
,
'num_samples'
:
num_samples
,
'use_custom
_samples'
:
use_custom
_samples
,
'use_custom
ized_samples'
:
use_customized
_samples
,
'remove_accidental_hits'
:
remove_accidental_hits
,
'remove_accidental_hits'
:
remove_accidental_hits
,
'seed'
:
seed
'seed'
:
seed
}
}
self
.
inputs
=
{
self
.
inputs
=
{
'Logits'
:
logits
,
'Logits'
:
logits
,
'Label
'
:
label
,
'Label
s'
:
labels
,
'Custom
Samples'
:
custom
_samples
,
'Custom
izedSamples'
:
customized
_samples
,
'Custom
Probabilities'
:
custom
_probabilities
'Custom
izedProbabilities'
:
customized
_probabilities
}
}
def
set_data
(
self
,
batch_size
,
num_classes
,
num_true
,
num_samples
,
seed
,
def
set_data
(
self
,
batch_size
,
num_classes
,
num_true
,
num_samples
,
seed
,
remove_accidental_hits
):
remove_accidental_hits
):
logits
=
np
.
random
.
randn
(
batch_size
,
num_classes
)
logits
=
np
.
random
.
randn
(
batch_size
,
num_classes
)
label
=
np
.
stack
([
label
s
=
np
.
stack
([
np
.
random
.
choice
(
np
.
random
.
choice
(
range
(
0
,
num_classes
),
num_true
,
replace
=
False
)
range
(
0
,
num_classes
),
num_true
,
replace
=
False
)
for
_
in
range
(
batch_size
)
for
_
in
range
(
batch_size
)
])
])
sampler
=
LogUniformSampler
(
num_classes
,
seed
)
sampler
=
LogUniformSampler
(
num_classes
,
seed
)
custom
_samples
,
custom
_probabilities
=
\
custom
ized_samples
,
customized
_probabilities
=
\
sample_prob
(
sampler
,
num_samples
,
label
)
sample_prob
(
sampler
,
num_samples
,
label
s
)
use_custom_samples
=
True
use_custom
ized
_samples
=
True
remove_accidental_hits
=
remove_accidental_hits
remove_accidental_hits
=
remove_accidental_hits
self
.
generate_data
(
logits
,
label
,
num_samples
,
seed
,
self
.
generate_data
(
logits
,
label
s
,
num_samples
,
seed
,
remove_accidental_hits
,
use_custom_samples
,
remove_accidental_hits
,
use_custom
ized
_samples
,
custom
_samples
,
custom
_probabilities
)
custom
ized_samples
,
customized
_probabilities
)
def
compute
(
self
):
def
compute
(
self
):
out
=
sample_logits
(
self
.
inputs
[
"Logits"
],
self
.
inputs
[
"Label"
],
out
=
sample_logits
(
self
.
inputs
[
"Logits"
],
self
.
inputs
[
"Label
s
"
],
self
.
attrs
[
"num_samples"
],
self
.
attrs
[
"seed"
],
self
.
attrs
[
"num_samples"
],
self
.
attrs
[
"seed"
],
self
.
attrs
[
"remove_accidental_hits"
],
self
.
attrs
[
"remove_accidental_hits"
],
self
.
attrs
[
"use_custom_samples"
],
self
.
attrs
[
"use_custom
ized
_samples"
],
self
.
inputs
[
"CustomSamples"
],
self
.
inputs
[
"Custom
ized
Samples"
],
self
.
inputs
[
"CustomProbabilities"
])
self
.
inputs
[
"Custom
ized
Probabilities"
])
self
.
outputs
=
{
self
.
outputs
=
{
'SampledLogits'
:
out
[
0
],
'SampledLogits'
:
out
[
0
],
'Samples'
:
out
[
1
],
'Samples'
:
out
[
1
],
'SampledLabel'
:
out
[
2
],
'SampledLabel
s
'
:
out
[
2
],
'Probabilities'
:
out
[
3
]
'Probabilities'
:
out
[
3
]
}
}
...
@@ -255,29 +252,29 @@ class TestSampleLogitsOpV2(OpTest):
...
@@ -255,29 +252,29 @@ class TestSampleLogitsOpV2(OpTest):
in C++ and copied to python and just test the non-random part.
in C++ and copied to python and just test the non-random part.
'''
'''
def
generate_data
(
self
,
logits
,
label
,
num_samples
,
seed
,
def
generate_data
(
self
,
logits
,
label
s
,
num_samples
,
seed
,
remove_accidental_hits
,
use_custom_samples
):
remove_accidental_hits
,
use_custom
ized
_samples
):
self
.
attrs
=
{
self
.
attrs
=
{
'num_samples'
:
num_samples
,
'num_samples'
:
num_samples
,
'use_custom
_samples'
:
use_custom
_samples
,
'use_custom
ized_samples'
:
use_customized
_samples
,
'remove_accidental_hits'
:
remove_accidental_hits
,
'remove_accidental_hits'
:
remove_accidental_hits
,
'seed'
:
seed
'seed'
:
seed
}
}
self
.
inputs
=
{
'Logits'
:
logits
,
'Label
'
:
label
.
astype
(
np
.
int64
)}
self
.
inputs
=
{
'Logits'
:
logits
,
'Label
s'
:
labels
.
astype
(
np
.
int64
)}
def
set_data
(
self
,
num_classes
,
num_samples
,
seed
,
remove_accidental_hits
):
def
set_data
(
self
,
num_classes
,
num_samples
,
seed
,
remove_accidental_hits
):
label
=
np
.
array
([[
6
,
12
,
15
,
5
,
1
],
[
0
,
9
,
4
,
1
,
10
],
label
s
=
np
.
array
([[
6
,
12
,
15
,
5
,
1
],
[
0
,
9
,
4
,
1
,
10
],
[
0
,
2
,
10
,
16
,
13
],
[
14
,
4
,
7
,
2
,
1
],
[
0
,
2
,
10
,
16
,
13
],
[
14
,
4
,
7
,
2
,
1
],
[
3
,
18
,
11
,
8
,
14
]])
[
3
,
18
,
11
,
8
,
14
]])
batch_size
,
num_true
=
label
.
shape
batch_size
,
num_true
=
label
s
.
shape
use_custom_samples
=
False
use_custom
ized
_samples
=
False
num_sampled_classes
=
num_samples
+
num_true
num_sampled_classes
=
num_samples
+
num_true
logits
=
np
.
random
.
randn
(
batch_size
,
num_classes
)
logits
=
np
.
random
.
randn
(
batch_size
,
num_classes
)
remove_accidental_hits
=
remove_accidental_hits
remove_accidental_hits
=
remove_accidental_hits
self
.
generate_data
(
logits
,
label
,
num_samples
,
seed
,
self
.
generate_data
(
logits
,
label
s
,
num_samples
,
seed
,
remove_accidental_hits
,
use_custom_samples
)
remove_accidental_hits
,
use_custom
ized
_samples
)
# python and c++ use different random generator
# python and c++ use different random generator
# use fetched samples from c++ for python code
# use fetched samples from c++ for python code
...
@@ -302,7 +299,7 @@ class TestSampleLogitsOpV2(OpTest):
...
@@ -302,7 +299,7 @@ class TestSampleLogitsOpV2(OpTest):
self
.
probabilities
=
probabilities
self
.
probabilities
=
probabilities
def
compute
(
self
):
def
compute
(
self
):
out
=
sample_logits
(
self
.
inputs
[
"Logits"
],
self
.
inputs
[
"Label"
],
out
=
sample_logits
(
self
.
inputs
[
"Logits"
],
self
.
inputs
[
"Label
s
"
],
self
.
attrs
[
"num_samples"
],
self
.
attrs
[
"seed"
],
self
.
attrs
[
"num_samples"
],
self
.
attrs
[
"seed"
],
self
.
attrs
[
"remove_accidental_hits"
],
True
,
self
.
attrs
[
"remove_accidental_hits"
],
True
,
self
.
fetched_samples
.
astype
(
np
.
int64
),
self
.
fetched_samples
.
astype
(
np
.
int64
),
...
@@ -310,7 +307,7 @@ class TestSampleLogitsOpV2(OpTest):
...
@@ -310,7 +307,7 @@ class TestSampleLogitsOpV2(OpTest):
self
.
outputs
=
{
self
.
outputs
=
{
'SampledLogits'
:
out
[
0
],
'SampledLogits'
:
out
[
0
],
'Samples'
:
out
[
1
],
'Samples'
:
out
[
1
],
'SampledLabel'
:
out
[
2
],
'SampledLabel
s
'
:
out
[
2
],
'Probabilities'
:
out
[
3
]
'Probabilities'
:
out
[
3
]
}
}
...
@@ -339,18 +336,18 @@ class TestSampleLogitsOpV3(OpTest):
...
@@ -339,18 +336,18 @@ class TestSampleLogitsOpV3(OpTest):
in C++ and copied to python and just test the non-random part.
in C++ and copied to python and just test the non-random part.
'''
'''
def
generate_data
(
self
,
logits
,
label
,
num_samples
,
seed
,
def
generate_data
(
self
,
logits
,
label
s
,
num_samples
,
seed
,
remove_accidental_hits
,
use_custom_samples
):
remove_accidental_hits
,
use_custom
ized
_samples
):
self
.
attrs
=
{
self
.
attrs
=
{
'num_samples'
:
num_samples
,
'num_samples'
:
num_samples
,
'use_custom
_samples'
:
use_custom
_samples
,
'use_custom
ized_samples'
:
use_customized
_samples
,
'remove_accidental_hits'
:
remove_accidental_hits
,
'remove_accidental_hits'
:
remove_accidental_hits
,
'seed'
:
seed
'seed'
:
seed
}
}
self
.
inputs
=
{
'Logits'
:
logits
,
'Label
'
:
label
.
astype
(
np
.
int64
)}
self
.
inputs
=
{
'Logits'
:
logits
,
'Label
s'
:
labels
.
astype
(
np
.
int64
)}
def
set_data
(
self
,
num_classes
,
num_samples
,
seed
,
remove_accidental_hits
):
def
set_data
(
self
,
num_classes
,
num_samples
,
seed
,
remove_accidental_hits
):
label
=
[
52
,
2
,
2
,
17
,
96
,
2
,
17
,
96
,
37
,
2
]
label
s
=
[
52
,
2
,
2
,
17
,
96
,
2
,
17
,
96
,
37
,
2
]
samples
=
[
samples
=
[
3
,
12
,
74
,
28
,
1
,
79
,
2
,
42
,
8
,
13
,
0
,
18
,
88
,
49
,
14
,
46
,
39
,
57
,
3
,
12
,
74
,
28
,
1
,
79
,
2
,
42
,
8
,
13
,
0
,
18
,
88
,
49
,
14
,
46
,
39
,
57
,
26
,
75
,
9
,
50
,
16
,
66
,
6
,
23
,
5
,
11
,
17
,
54
,
35
,
20
,
53
,
10
,
47
,
80
,
26
,
75
,
9
,
50
,
16
,
66
,
6
,
23
,
5
,
11
,
17
,
54
,
35
,
20
,
53
,
10
,
47
,
80
,
...
@@ -359,19 +356,19 @@ class TestSampleLogitsOpV3(OpTest):
...
@@ -359,19 +356,19 @@ class TestSampleLogitsOpV3(OpTest):
63
,
81
,
59
,
48
,
91
,
68
,
72
,
61
,
52
,
86
63
,
81
,
59
,
48
,
91
,
68
,
72
,
61
,
52
,
86
]
]
self
.
fetched_samples
=
np
.
array
([[
x
]
+
samples
for
x
in
label
])
self
.
fetched_samples
=
np
.
array
([[
x
]
+
samples
for
x
in
label
s
])
fectched_num_tries
=
323
fectched_num_tries
=
323
label
=
self
.
fetched_samples
[:,
0
:
1
]
label
s
=
self
.
fetched_samples
[:,
0
:
1
]
batch_size
,
num_true
=
label
.
shape
batch_size
,
num_true
=
label
s
.
shape
use_custom_samples
=
False
use_custom
ized
_samples
=
False
num_sampled_classes
=
num_samples
+
num_true
num_sampled_classes
=
num_samples
+
num_true
logits
=
np
.
random
.
randn
(
batch_size
,
num_classes
)
logits
=
np
.
random
.
randn
(
batch_size
,
num_classes
)
remove_accidental_hits
=
remove_accidental_hits
remove_accidental_hits
=
remove_accidental_hits
self
.
generate_data
(
logits
,
label
,
num_samples
,
seed
,
self
.
generate_data
(
logits
,
label
s
,
num_samples
,
seed
,
remove_accidental_hits
,
use_custom_samples
)
remove_accidental_hits
,
use_custom
ized
_samples
)
# python and c++ use different random generator
# python and c++ use different random generator
# use fetched samples from c++ for python code
# use fetched samples from c++ for python code
...
@@ -388,7 +385,7 @@ class TestSampleLogitsOpV3(OpTest):
...
@@ -388,7 +385,7 @@ class TestSampleLogitsOpV3(OpTest):
self
.
probabilities
=
probabilities
self
.
probabilities
=
probabilities
def
compute
(
self
):
def
compute
(
self
):
out
=
sample_logits
(
self
.
inputs
[
"Logits"
],
self
.
inputs
[
"Label"
],
out
=
sample_logits
(
self
.
inputs
[
"Logits"
],
self
.
inputs
[
"Label
s
"
],
self
.
attrs
[
"num_samples"
],
self
.
attrs
[
"seed"
],
self
.
attrs
[
"num_samples"
],
self
.
attrs
[
"seed"
],
self
.
attrs
[
"remove_accidental_hits"
],
True
,
self
.
attrs
[
"remove_accidental_hits"
],
True
,
self
.
fetched_samples
.
astype
(
np
.
int64
),
self
.
fetched_samples
.
astype
(
np
.
int64
),
...
@@ -396,7 +393,7 @@ class TestSampleLogitsOpV3(OpTest):
...
@@ -396,7 +393,7 @@ class TestSampleLogitsOpV3(OpTest):
self
.
outputs
=
{
self
.
outputs
=
{
'SampledLogits'
:
out
[
0
],
'SampledLogits'
:
out
[
0
],
'Samples'
:
out
[
1
],
'Samples'
:
out
[
1
],
'SampledLabel'
:
out
[
2
],
'SampledLabel
s
'
:
out
[
2
],
'Probabilities'
:
out
[
3
]
'Probabilities'
:
out
[
3
]
}
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录