Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
99c78b77
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
99c78b77
编写于
9月 16, 2019
作者:
K
Kaipeng Deng
提交者:
GitHub
9月 16, 2019
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix softmax axis!=-1. test=develop (#19800)
上级
6a1db204
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
16 addition
and
13 deletion
+16
-13
paddle/fluid/operators/math/softmax_impl.h
paddle/fluid/operators/math/softmax_impl.h
+16
-13
未找到文件。
paddle/fluid/operators/math/softmax_impl.h
浏览文件 @
99c78b77
...
...
@@ -41,6 +41,7 @@ void SoftmaxEigen(const DeviceContext& context, const int axis_dim,
const
framework
::
Tensor
*
X
,
framework
::
Tensor
*
Y
)
{
constexpr
int
kBatchDim
=
0
;
constexpr
int
kClassDim
=
1
;
constexpr
int
kAxisDim
=
1
;
auto
logits
=
EigenMatrix
<
T
>::
From
(
*
X
);
auto
softmax
=
EigenMatrix
<
T
>::
From
(
*
Y
);
...
...
@@ -49,26 +50,28 @@ void SoftmaxEigen(const DeviceContext& context, const int axis_dim,
const
int
num_classes
=
logits
.
dimension
(
kClassDim
);
const
int
num_remain
=
num_classes
/
axis_dim
;
Eigen
::
DSizes
<
int
,
1
>
along_class
(
kClassDim
);
Eigen
::
DSizes
<
int
,
2
>
batch_by_one
(
batch_size
,
1
);
Eigen
::
DSizes
<
int
,
2
>
one_by_class
(
1
,
num_classes
);
Eigen
::
DSizes
<
int
,
1
>
along_axis
(
kAxisDim
);
Eigen
::
DSizes
<
int
,
2
>
batch_classes
(
batch_size
,
num_classes
);
Eigen
::
DSizes
<
int
,
3
>
batch_one_remain
(
batch_size
,
1
,
num_remain
);
Eigen
::
DSizes
<
int
,
3
>
one_axis_one
(
1
,
axis_dim
,
1
);
Eigen
::
DSizes
<
int
,
3
>
batch_axis_remain
(
batch_size
,
axis_dim
,
num_remain
);
Eigen
::
DSizes
<
int
,
2
>
one_axis
(
1
,
axis_dim
);
auto
shifted_logits
=
(
logits
-
logits
.
maximum
(
along_class
)
auto
logits_reshape
=
logits
.
reshape
(
batch_axis_remain
);
auto
shifted_logits
=
(
logits_reshape
-
logits_reshape
.
maximum
(
along_axis
)
.
eval
()
.
reshape
(
batch_
by_one
)
.
broadcast
(
one_
by_class
))
.
reshape
(
batch_
one_remain
)
.
broadcast
(
one_
axis_one
))
.
unaryExpr
(
ValueClip
<
T
>
());
softmax
.
device
(
*
context
.
eigen_device
())
=
shifted_logits
.
exp
();
softmax
.
device
(
*
context
.
eigen_device
())
=
(
softmax
*
softmax
.
reshape
(
batch_axis_remain
)
.
sum
(
along_class
)
auto
exp
=
shifted_logits
.
exp
();
softmax
.
device
(
*
context
.
eigen_device
())
=
(
exp
*
exp
.
sum
(
along_axis
)
.
inverse
()
.
eval
()
.
broadcast
(
one_axis
));
.
reshape
(
batch_one_remain
)
.
broadcast
(
one_axis_one
))
.
reshape
(
batch_classes
);
}
template
<
typename
DeviceContext
,
typename
T
,
bool
is_test
,
typename
Enable
>
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录