Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
b316437a
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
b316437a
编写于
11月 02, 2018
作者:
Z
Zeng Jinle
提交者:
GitHub
11月 02, 2018
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #14087 from sneaxiy/add_use_cudnn_in_softmax_with_xe
Add numeric_stable_mode parameters to softmax_with_xe op
上级
fcf370e6
366ebb93
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
222 addition
and
26 deletion
+222
-26
paddle/fluid/API.spec
paddle/fluid/API.spec
+1
-1
paddle/fluid/operators/softmax_with_cross_entropy_op.cc
paddle/fluid/operators/softmax_with_cross_entropy_op.cc
+6
-0
paddle/fluid/operators/softmax_with_cross_entropy_op.cu
paddle/fluid/operators/softmax_with_cross_entropy_op.cu
+166
-21
python/paddle/fluid/layers/nn.py
python/paddle/fluid/layers/nn.py
+26
-3
python/paddle/fluid/tests/unittests/test_softmax_with_cross_entropy_op.py
...uid/tests/unittests/test_softmax_with_cross_entropy_op.py
+23
-1
未找到文件。
paddle/fluid/API.spec
浏览文件 @
b316437a
...
...
@@ -103,7 +103,7 @@ paddle.fluid.layers.beam_search ArgSpec(args=['pre_ids', 'pre_scores', 'ids', 's
paddle.fluid.layers.row_conv ArgSpec(args=['input', 'future_context_size', 'param_attr', 'act'], varargs=None, keywords=None, defaults=(None, None))
paddle.fluid.layers.multiplex ArgSpec(args=['inputs', 'index'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.layer_norm ArgSpec(args=['input', 'scale', 'shift', 'begin_norm_axis', 'epsilon', 'param_attr', 'bias_attr', 'act', 'name'], varargs=None, keywords=None, defaults=(True, True, 1, 1e-05, None, None, None, None))
paddle.fluid.layers.softmax_with_cross_entropy ArgSpec(args=['logits', 'label', 'soft_label', 'ignore_index'
], varargs=None, keywords=None, defaults=(False, -100
))
paddle.fluid.layers.softmax_with_cross_entropy ArgSpec(args=['logits', 'label', 'soft_label', 'ignore_index'
, 'numeric_stable_mode'], varargs=None, keywords=None, defaults=(False, -100, False
))
paddle.fluid.layers.smooth_l1 ArgSpec(args=['x', 'y', 'inside_weight', 'outside_weight', 'sigma'], varargs=None, keywords=None, defaults=(None, None, None))
paddle.fluid.layers.one_hot ArgSpec(args=['input', 'depth'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.autoincreased_step_counter ArgSpec(args=['counter_name', 'begin', 'step'], varargs=None, keywords=None, defaults=(None, 1, 1))
...
...
paddle/fluid/operators/softmax_with_cross_entropy_op.cc
浏览文件 @
b316437a
...
...
@@ -44,6 +44,12 @@ class SoftmaxWithCrossEntropyOpMaker
"(bool, default: false), A flag to indicate whether to interpretate "
"the given labels as soft labels."
)
.
SetDefault
(
false
);
AddAttr
<
bool
>
(
"numeric_stable_mode"
,
"(bool, default: false), A flag to indicate whether to use more "
"numerically stable algorithm. This flag is only valid when "
"soft_label is false and GPU is used."
)
.
SetDefault
(
false
);
AddAttr
<
int
>
(
"ignore_index"
,
"(int, default -100), Specifies a target value that is ignored and"
...
...
paddle/fluid/operators/softmax_with_cross_entropy_op.cu
浏览文件 @
b316437a
...
...
@@ -17,6 +17,7 @@ limitations under the License. */
#include <cub/cub.cuh>
#include "paddle/fluid/operators/math/cross_entropy.h"
#include "paddle/fluid/operators/softmax_with_cross_entropy_op.h"
#include "paddle/fluid/platform/for_range.h"
namespace
paddle
{
namespace
operators
{
...
...
@@ -117,8 +118,8 @@ using BlockReduceTempStorage = typename BlockReduce<T, BlockDim>::TempStorage;
// Make sure that BlockDim <= feature_size
// This kernel is used to calculate the max element of each row
template
<
typename
T
,
int
BlockDim
>
__global__
void
RowReductionForMax
(
const
T
*
logits_data
,
T
*
max_data
,
int
feature_size
)
{
static
__global__
void
RowReductionForMax
(
const
T
*
logits_data
,
T
*
max_data
,
int
feature_size
)
{
__shared__
BlockReduceTempStorage
<
T
,
BlockDim
>
temp_storage
;
auto
beg_idx
=
feature_size
*
blockIdx
.
x
+
threadIdx
.
x
;
...
...
@@ -141,9 +142,10 @@ __global__ void RowReductionForMax(const T* logits_data, T* max_data,
}
// Make sure that BlockDim <= feature_size
template
<
typename
T
,
int
BlockDim
>
__global__
void
RowReductionForDiffMaxSum
(
const
T
*
logits_data
,
T
*
max_data
,
T
*
softmax
,
int
feature_size
)
{
template
<
typename
T
,
int
BlockDim
,
bool
CalculateLogSoftmax
=
false
>
static
__global__
void
RowReductionForDiffMaxSum
(
const
T
*
logits_data
,
T
*
max_data
,
T
*
softmax
,
int
feature_size
)
{
__shared__
BlockReduceTempStorage
<
T
,
BlockDim
>
temp_storage
;
auto
beg_idx
=
feature_size
*
blockIdx
.
x
+
threadIdx
.
x
;
...
...
@@ -153,24 +155,34 @@ __global__ void RowReductionForDiffMaxSum(const T* logits_data, T* max_data,
softmax
[
beg_idx
]
=
logits_data
[
beg_idx
]
-
block_max
;
T
diff_max_sum
=
real_exp
(
softmax
[
beg_idx
]);
beg_idx
+=
BlockDim
;
while
(
beg_
idx
<
end_idx
)
{
softmax
[
beg_idx
]
=
logits_data
[
beg_
idx
]
-
block_max
;
diff_max_sum
+=
real_exp
(
softmax
[
beg_
idx
]);
beg_
idx
+=
BlockDim
;
auto
idx
=
beg_idx
+
BlockDim
;
while
(
idx
<
end_idx
)
{
softmax
[
idx
]
=
logits_data
[
idx
]
-
block_max
;
diff_max_sum
+=
real_exp
(
softmax
[
idx
]);
idx
+=
BlockDim
;
}
diff_max_sum
=
BlockReduce
<
T
,
BlockDim
>
(
temp_storage
).
Reduce
(
diff_max_sum
,
cub
::
Sum
());
if
(
threadIdx
.
x
==
0
)
max_data
[
blockIdx
.
x
]
=
real_log
(
diff_max_sum
);
if
(
!
CalculateLogSoftmax
)
return
;
__syncthreads
();
diff_max_sum
=
max_data
[
blockIdx
.
x
];
softmax
[
beg_idx
]
-=
diff_max_sum
;
beg_idx
+=
BlockDim
;
while
(
beg_idx
<
end_idx
)
{
softmax
[
beg_idx
]
-=
diff_max_sum
;
beg_idx
+=
BlockDim
;
}
if
(
threadIdx
.
x
==
0
)
max_data
[
blockIdx
.
x
]
=
0
;
}
// Make sure that BlockDim <= feature_size
template
<
typename
T
,
int
BlockDim
>
__global__
void
RowReductionForSoftmaxAndCrossEntropy
(
const
T
*
logits_data
,
const
T
*
labels_data
,
T
*
loss_data
,
T
*
softmax
,
int
feature_size
)
{
static
__global__
void
RowReductionForSoftmaxAndCrossEntropy
(
const
T
*
logits_data
,
const
T
*
labels_data
,
T
*
loss_data
,
T
*
softmax
,
int
feature_size
)
{
__shared__
BlockReduceTempStorage
<
T
,
BlockDim
>
temp_storage
;
auto
beg_idx
=
feature_size
*
blockIdx
.
x
+
threadIdx
.
x
;
...
...
@@ -194,11 +206,134 @@ __global__ void RowReductionForSoftmaxAndCrossEntropy(const T* logits_data,
}
template
<
typename
T
>
__global__
void
SetSoftmaxToOneWhenFeatureSizeIsOne
(
T
*
out
,
int
batch_size
)
{
struct
HardLabelSoftmaxWithCrossEntropyFunctor
{
public:
HardLabelSoftmaxWithCrossEntropyFunctor
(
const
T
*
logits
,
const
int64_t
*
labels
,
T
*
loss
,
T
*
log_softmax
,
int
feature_size
)
:
logits_
(
logits
),
labels_
(
labels
),
loss_
(
loss
),
log_softmax_
(
log_softmax
),
feature_size_
(
feature_size
)
{}
__device__
void
operator
()(
int
idx
)
const
{
auto
row_idx
=
idx
/
feature_size_
;
auto
col_idx
=
idx
%
feature_size_
;
if
(
col_idx
!=
labels_
[
row_idx
])
{
log_softmax_
[
idx
]
=
real_exp
(
log_softmax_
[
idx
]);
}
else
{
auto
softmax
=
log_softmax_
[
idx
];
log_softmax_
[
idx
]
=
real_exp
(
softmax
);
loss_
[
row_idx
]
=
-
softmax
;
}
}
private:
const
T
*
logits_
;
const
int64_t
*
labels_
;
T
*
loss_
;
T
*
log_softmax_
;
int
feature_size_
;
};
template
<
typename
T
>
struct
HardLabelSoftmaxWithCrossEntropyFunctorWithIgnoreIdx
{
public:
HardLabelSoftmaxWithCrossEntropyFunctorWithIgnoreIdx
(
const
T
*
logits
,
const
int64_t
*
labels
,
T
*
loss
,
T
*
log_softmax
,
int
feature_size
,
int
ignore_idx
)
:
logits_
(
logits
),
labels_
(
labels
),
loss_
(
loss
),
log_softmax_
(
log_softmax
),
feature_size_
(
feature_size
),
ignore_idx_
(
ignore_idx
)
{}
__device__
void
operator
()(
int
idx
)
const
{
auto
row_idx
=
idx
/
feature_size_
;
auto
col_idx
=
idx
%
feature_size_
;
if
(
col_idx
!=
labels_
[
row_idx
]
||
col_idx
==
ignore_idx_
)
{
log_softmax_
[
idx
]
=
real_exp
(
log_softmax_
[
idx
]);
}
else
{
auto
softmax
=
log_softmax_
[
idx
];
log_softmax_
[
idx
]
=
real_exp
(
softmax
);
loss_
[
row_idx
]
=
-
softmax
;
}
}
private:
const
T
*
logits_
;
const
int64_t
*
labels_
;
T
*
loss_
;
T
*
log_softmax_
;
int
feature_size_
;
int
ignore_idx_
;
};
template
<
typename
T
>
static
__global__
void
SetSoftmaxToOneWhenFeatureSizeIsOne
(
T
*
out
,
int
batch_size
)
{
auto
idx
=
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
;
if
(
idx
<
batch_size
)
out
[
idx
]
=
static_cast
<
T
>
(
1
);
}
template
<
typename
T
>
static
void
HardLabelSoftmaxWithCrossEntropy
(
const
platform
::
CUDADeviceContext
&
ctx
,
const
T
*
logits_data
,
const
int64_t
*
labels_data
,
T
*
loss_data
,
T
*
softmax_data
,
int
batch_size
,
int
feature_size
,
int
ignore_idx
)
{
constexpr
int
kMaxBlockDim
=
512
;
int
block_dim
=
feature_size
>=
kMaxBlockDim
?
kMaxBlockDim
:
(
1
<<
static_cast
<
int
>
(
std
::
log2
(
feature_size
)));
auto
stream
=
ctx
.
stream
();
#define CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(BlockDim) \
case BlockDim: { \
RowReductionForMax<T, BlockDim><<<batch_size, BlockDim, 0, stream>>>( \
logits_data, loss_data, feature_size); \
RowReductionForDiffMaxSum<T, BlockDim, \
true><<<batch_size, BlockDim, 0, stream>>>( \
logits_data, loss_data, softmax_data, feature_size); \
platform::ForRange<platform::CUDADeviceContext> for_range( \
ctx, batch_size* feature_size); \
if (ignore_idx >= 0 && ignore_idx < feature_size) { \
for_range(HardLabelSoftmaxWithCrossEntropyFunctorWithIgnoreIdx<T>( \
logits_data, labels_data, loss_data, softmax_data, feature_size, \
ignore_idx)); \
} else { \
for_range(HardLabelSoftmaxWithCrossEntropyFunctor<T>( \
logits_data, labels_data, loss_data, softmax_data, feature_size)); \
} \
} break
switch
(
block_dim
)
{
CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL
(
512
);
CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL
(
256
);
CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL
(
128
);
CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL
(
64
);
CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL
(
32
);
CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL
(
16
);
CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL
(
8
);
CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL
(
4
);
CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL
(
2
);
case
1
:
SetSoftmaxToOneWhenFeatureSizeIsOne
<<<
(
batch_size
+
kMaxBlockDim
-
1
)
/
kMaxBlockDim
,
kMaxBlockDim
,
0
,
stream
>>>
(
softmax_data
,
batch_size
);
cudaMemsetAsync
(
loss_data
,
0
,
batch_size
*
sizeof
(
T
),
stream
);
break
;
default:
PADDLE_THROW
(
"BlockDim must be 2^n in softmax_with_cross_entropy_op"
);
break
;
}
#undef CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL
}
template
<
typename
T
>
static
void
SoftmaxWithCrossEntropyFusedKernel
(
const
T
*
logits_data
,
const
T
*
labels_data
,
...
...
@@ -237,7 +372,7 @@ static void SoftmaxWithCrossEntropyFusedKernel(const T* logits_data,
kMaxBlockDim
,
kMaxBlockDim
,
0
,
stream
>>>
(
softmax_data
,
batch_size
);
cudaMemsetAsync
(
loss_data
,
0
,
batch_size
,
stream
);
cudaMemsetAsync
(
loss_data
,
0
,
batch_size
*
sizeof
(
T
)
,
stream
);
break
;
default:
PADDLE_THROW
(
"BlockDim must be 2^n in softmax_with_cross_entropy_op"
);
...
...
@@ -272,11 +407,21 @@ class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel<T> {
logits_data
,
labels_data
,
softmax_data
,
loss_data
,
batch_size
,
feature_size
,
context
.
cuda_device_context
().
stream
());
}
else
{
math
::
SoftmaxCUDNNFunctor
<
T
>
()(
context
.
cuda_device_context
(),
logits
,
softmax
);
math
::
CrossEntropyFunctor
<
platform
::
CUDADeviceContext
,
T
>
()(
context
.
cuda_device_context
(),
loss
,
softmax
,
labels
,
false
,
ignore_index
);
if
(
!
context
.
Attr
<
bool
>
(
"numeric_stable_mode"
))
{
math
::
SoftmaxCUDNNFunctor
<
T
>
()(
context
.
cuda_device_context
(),
logits
,
softmax
);
math
::
CrossEntropyFunctor
<
platform
::
CUDADeviceContext
,
T
>
()(
context
.
cuda_device_context
(),
loss
,
softmax
,
labels
,
false
,
ignore_index
);
}
else
{
int
batch_size
=
logits
->
dims
()[
0
];
int
feature_size
=
logits
->
dims
()[
1
];
auto
*
logits_data
=
logits
->
data
<
T
>
();
auto
*
labels_data
=
labels
->
data
<
int64_t
>
();
HardLabelSoftmaxWithCrossEntropy
<
T
>
(
context
.
cuda_device_context
(),
logits_data
,
labels_data
,
loss_data
,
softmax_data
,
batch_size
,
feature_size
,
ignore_index
);
}
}
}
};
...
...
python/paddle/fluid/layers/nn.py
浏览文件 @
b316437a
...
...
@@ -4713,7 +4713,8 @@ def multiplex(inputs, index):
def
softmax_with_cross_entropy
(
logits
,
label
,
soft_label
=
False
,
ignore_index
=-
100
):
ignore_index
=-
100
,
numeric_stable_mode
=
False
):
"""
**Softmax With Cross Entropy Operator.**
...
...
@@ -4747,6 +4748,18 @@ def softmax_with_cross_entropy(logits,
\\
left(
\\
text{logit}_i -
\\
log
\\
left(
\\
sum_{i=0}^{K}
\\
exp(
\\
text{logit}_i)
\\
right)
\\
right), j = 1,...,K
3) If numeric_stable_mode is True, softmax is calculated first by:
.. math::
max_j =
\\
max_{i=0}^{K}{
\\
text{logit}_i}
log
\\
_max
\\
_sum_j =
\\
log
\\
sum_{i=0}^{K}
\\
exp(logit_i - max_j)
softmax_j =
\\
exp(logit_j - max_j - {log
\\
_max
\\
_sum}_j)
and then cross entropy loss is calculated by softmax and label.
Args:
logits (Variable): The unscaled log probabilities, which is a 2-D tensor
with shape [N x K]. N is the batch_size, and K is the class number.
...
...
@@ -4758,6 +4771,13 @@ def softmax_with_cross_entropy(logits,
ignore_index (int): Specifies a target value that is ignored and does
not contribute to the input gradient. Only valid
if soft_label is set to False. Default: -100
numeric_stable_mode (bool): A flag to indicate whether to use a more
numerically stable algorithm. Only valid
when soft_label is False and GPU is used.
When soft_label is True or CPU is used,
the algorithm is always numerically stable.
Note that the speed may be slower when use
stable algorithm. Default: False
Returns:
Variable: The cross entropy loss is a 2-D tensor with shape [N x 1].
...
...
@@ -4780,8 +4800,11 @@ def softmax_with_cross_entropy(logits,
'Label'
:
label
},
outputs
=
{
'Softmax'
:
softmax
,
'Loss'
:
loss
},
attrs
=
{
'soft_label'
:
soft_label
,
'ignore_index'
:
ignore_index
})
attrs
=
{
'soft_label'
:
soft_label
,
'ignore_index'
:
ignore_index
,
'numeric_stable_mode'
:
numeric_stable_mode
})
return
loss
...
...
python/paddle/fluid/tests/unittests/test_softmax_with_cross_entropy_op.py
浏览文件 @
b316437a
...
...
@@ -26,7 +26,11 @@ class TestSoftmaxWithCrossEntropyOp(OpTest):
Test softmax with cross entropy operator with discreate one-hot labels.
"""
def
initParams
(
self
):
self
.
numeric_stable_mode
=
False
def
setUp
(
self
):
self
.
initParams
()
self
.
op_type
=
"softmax_with_cross_entropy"
batch_size
=
41
class_num
=
37
...
...
@@ -46,6 +50,7 @@ class TestSoftmaxWithCrossEntropyOp(OpTest):
"Softmax"
:
softmax
.
astype
(
"float64"
),
"Loss"
:
cross_entropy
.
astype
(
"float64"
)
}
self
.
attrs
=
{
"numeric_stable_mode"
:
self
.
numeric_stable_mode
}
def
test_check_output
(
self
):
self
.
check_output
()
...
...
@@ -54,6 +59,11 @@ class TestSoftmaxWithCrossEntropyOp(OpTest):
self
.
check_grad
([
"Logits"
],
"Loss"
)
class
TestSoftmaxWithCrossEntropyOpNoCudnn
(
TestSoftmaxWithCrossEntropyOp
):
def
initParams
(
self
):
self
.
numeric_stable_mode
=
True
class
TestSoftmaxWithCrossEntropyOp2
(
OpTest
):
"""
Test softmax with cross entropy operator with soft labels.
...
...
@@ -93,7 +103,11 @@ class TestSoftmaxWithCrossEntropyOp3(OpTest):
Test softmax with cross entropy operator with ignore_index.
"""
def
initParams
(
self
):
self
.
numeric_stable_mode
=
False
def
setUp
(
self
):
self
.
initParams
()
self
.
op_type
=
"softmax_with_cross_entropy"
batch_size
=
41
class_num
=
37
...
...
@@ -114,7 +128,10 @@ class TestSoftmaxWithCrossEntropyOp3(OpTest):
"Softmax"
:
softmax
.
astype
(
"float64"
),
"Loss"
:
cross_entropy
.
astype
(
"float64"
)
}
self
.
attrs
=
{
"ignore_index"
:
ignore_index
}
self
.
attrs
=
{
"ignore_index"
:
ignore_index
,
"numeric_stable_mode"
:
self
.
numeric_stable_mode
}
def
test_check_output
(
self
):
self
.
check_output
()
...
...
@@ -123,5 +140,10 @@ class TestSoftmaxWithCrossEntropyOp3(OpTest):
self
.
check_grad
([
"Logits"
],
"Loss"
)
class
TestSoftmaxWithCrossEntropyOp3NoCudnn
(
TestSoftmaxWithCrossEntropyOp3
):
def
initParams
(
self
):
self
.
numeric_stable_mode
=
True
if
__name__
==
"__main__"
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录