Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Oneflow-Inc
oneflow
提交
ab58fb90
O
oneflow
项目概览
Oneflow-Inc
/
oneflow
上一次同步 2 年多
通知
13
Star
2733
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
O
oneflow
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
ab58fb90
编写于
7月 06, 2017
作者:
W
Will Zhang
提交者:
GitHub
7月 06, 2017
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #182 from Oneflow-Inc/dev_chengcheng
fix wrong design of softmax_op
上级
f98bc438
52700be7
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
9 addition
and
8 deletion
+9
-8
oneflow/core/operator/op_conf.proto
oneflow/core/operator/op_conf.proto
+0
-1
oneflow/core/operator/softmax_op.cpp
oneflow/core/operator/softmax_op.cpp
+3
-3
oneflow/core/operator/softmax_op_test.cpp
oneflow/core/operator/softmax_op_test.cpp
+6
-4
未找到文件。
oneflow/core/operator/op_conf.proto
浏览文件 @
ab58fb90
...
...
@@ -128,7 +128,6 @@ message ReluOpConf {
message
SoftmaxOpConf
{
string
in
=
1
;
string
out
=
2
;
int32
axis
=
3
;
}
message
MultinomialLogisticLossOpConf
{
...
...
oneflow/core/operator/softmax_op.cpp
浏览文件 @
ab58fb90
...
...
@@ -8,6 +8,7 @@ void SoftmaxOp::InitFromOpConf(const OperatorConf& op_conf) {
EnrollInputBn
(
"in"
);
EnrollOutputBn
(
"out"
);
EnrollDataTmpBn
(
"tmp_max"
);
}
const
PbMessage
&
SoftmaxOp
::
GetSpecialConf
()
const
{
...
...
@@ -18,10 +19,9 @@ void SoftmaxOp::InferShape4FwBlobs(
std
::
function
<
Shape
*
(
const
std
::
string
&
)
>
GetShapePtr4BnInOp
,
ParallelPolicy
policy
,
int64_t
parallel_id
,
int64_t
parallel_num
)
const
{
std
::
vector
<
int64_t
>
vec
=
GetShapePtr4BnInOp
(
SoleIbn
())
->
dim_vec
();
CHECK_GT
(
vec
.
size
(),
1
);
int32_t
axis
=
(
op_conf
().
softmax_conf
().
axis
()
+
vec
.
size
())
%
vec
.
size
();
vec
.
erase
(
vec
.
begin
()
+
axis
);
CHECK_GE
(
vec
.
size
(),
2
);
*
GetShapePtr4BnInOp
(
SoleObn
())
=
Shape
(
vec
);
*
GetShapePtr4BnInOp
(
SoleDtbn
())
=
Shape
({
vec
[
0
]});
}
REGISTER_OP
(
OperatorConf
::
kSoftmaxConf
,
SoftmaxOp
);
...
...
oneflow/core/operator/softmax_op_test.cpp
浏览文件 @
ab58fb90
...
...
@@ -2,17 +2,17 @@
namespace
oneflow
{
TEST
(
SoftmaxOp
,
softmax_3x
4x
5
)
{
TEST
(
SoftmaxOp
,
softmax_3x5
)
{
// create softmax_op
OperatorConf
op_conf
;
op_conf
.
set_name
(
"softmax_test"
);
op_conf
.
mutable_softmax_conf
()
->
set_axis
(
1
);
op_conf
.
mutable_softmax_conf
()
->
set_in
(
"softmax/in"
);
op_conf
.
mutable_softmax_conf
()
->
set_out
(
"softmax/out"
);
auto
softmax_op
=
OpMgr
::
Singleton
()
->
ConstructOp
(
op_conf
);
HashMap
<
std
::
string
,
Shape
*>
bn2shape_ptr
{
{
softmax_op
->
SoleIbn
(),
new
Shape
({
3
,
4
,
5
})},
{
softmax_op
->
SoleObn
(),
new
Shape
}};
{
softmax_op
->
SoleIbn
(),
new
Shape
({
3
,
5
})},
{
softmax_op
->
SoleObn
(),
new
Shape
},
{
softmax_op
->
SoleDtbn
(),
new
Shape
}};
auto
fp
=
[
&
bn2shape_ptr
](
const
std
::
string
&
bn
)
{
return
bn2shape_ptr
.
at
(
bn
);
};
...
...
@@ -20,7 +20,9 @@ TEST(SoftmaxOp, softmax_3x4x5) {
softmax_op
->
InferShape4FwBlobs
(
fp
,
kDataParallel
,
0
,
1
);
// test
Shape
*
output_shape_ptr
=
fp
(
softmax_op
->
SoleObn
());
Shape
*
tmp_max_shape_ptr
=
fp
(
softmax_op
->
SoleDtbn
());
ASSERT_EQ
(
*
output_shape_ptr
,
Shape
({
3
,
5
}));
ASSERT_EQ
(
*
tmp_max_shape_ptr
,
Shape
({
3
}));
}
}
// namespace oneflow
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录