Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Oneflow-Inc
oneflow
提交
77865ab3
O
oneflow
项目概览
Oneflow-Inc
/
oneflow
上一次同步 接近 3 年
通知
13
Star
2733
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
O
oneflow
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
77865ab3
编写于
1月 02, 2020
作者:
L
lixinqi
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
avoid CHECK failed in SetCtrlInOpName4VariableOp
上级
ed0a7a38
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
12 addition
and
5 deletion
+12
-5
oneflow/core/job_completer/job_completer.cpp
oneflow/core/job_completer/job_completer.cpp
+9
-3
oneflow/python/test/ops/test_activations.py
oneflow/python/test/ops/test_activations.py
+2
-1
oneflow/python/test/ops/test_batch_normalization.py
oneflow/python/test/ops/test_batch_normalization.py
+1
-1
未找到文件。
oneflow/core/job_completer/job_completer.cpp
浏览文件 @
77865ab3
...
...
@@ -274,6 +274,7 @@ void SetCtrlInOpName4VariableOp(const OpGraph& op_graph, JobBuilder* job_builder
}
return
false
;
};
HashMap
<
const
OperatorConf
*
,
HashSet
<
std
::
string
>>
op_conf2ctrl_in_op_names
;
op_graph
.
ForEachNode
([
&
](
OpNode
*
op_node
)
{
if
(
op_node
->
op
().
op_conf
().
has_variable_conf
()
==
false
)
{
return
;
}
if
(
op_node
->
out_edges
().
size
()
<=
1
)
{
return
;
}
...
...
@@ -291,12 +292,17 @@ void SetCtrlInOpName4VariableOp(const OpGraph& op_graph, JobBuilder* job_builder
}
}
if
(
mutable_consumer
==
nullptr
)
{
return
;
}
OperatorConf
mut_mutable_consumer_op_conf
(
*
mutable_consumer
);
for
(
const
auto
*
fw_bw_op
:
naive_consumers
)
{
mut_mutable_consumer_op_conf
.
add_ctrl_in_op_name
(
fw_bw_op
->
name
());
op_conf2ctrl_in_op_names
[
mutable_consumer
].
insert
(
fw_bw_op
->
name
());
}
job_builder
->
MutOpsOnlyOnce
({
mut_mutable_consumer_op_conf
});
});
for
(
const
auto
&
pair
:
op_conf2ctrl_in_op_names
)
{
OperatorConf
mut_mutable_consumer_op_conf
(
*
pair
.
first
);
for
(
const
auto
&
fw_bw_op_name
:
pair
.
second
)
{
mut_mutable_consumer_op_conf
.
add_ctrl_in_op_name
(
fw_bw_op_name
);
}
job_builder
->
MutOpsOnlyOnce
({
mut_mutable_consumer_op_conf
});
}
}
void
SetOpTimeShape7BatchAxisLbis
(
const
OpGraph
&
op_graph
,
JobBuilder
*
job_builder
)
{
...
...
oneflow/python/test/ops/test_activations.py
浏览文件 @
77865ab3
...
...
@@ -73,7 +73,8 @@ def compare_with_tensorflow(device_type, activation_type, shape):
def
test_activations
(
test_case
):
arg_dict
=
OrderedDict
()
arg_dict
[
"device_type"
]
=
[
"gpu"
]
arg_dict
[
"activation_type"
]
=
[
"relu"
,
"sigmoid"
,
"tanh"
,
"gelu"
]
# arg_dict["activation_type"] = ["relu", "sigmoid", "tanh", "gelu"]
arg_dict
[
"activation_type"
]
=
[
"relu"
,
"sigmoid"
,
"tanh"
]
arg_dict
[
"shape"
]
=
[(
1024
,
1024
)]
for
arg
in
GenArgList
(
arg_dict
):
compare_with_tensorflow
(
*
arg
)
oneflow/python/test/ops/test_batch_normalization.py
浏览文件 @
77865ab3
...
...
@@ -43,7 +43,7 @@ def TODO_test_train(test_case):
flow
.
losses
.
add_loss
(
flow
.
math
.
reduce_sum
(
y
))
Foo
(
np
.
ones
((
2
,
8
,
32
,
32
),
dtype
=
np
.
float32
))
def
TODO_
test_watch_scope
(
test_case
):
def
test_watch_scope
(
test_case
):
func_config
=
flow
.
FunctionConfig
()
func_config
.
default_distribute_strategy
(
flow
.
distribute
.
consistent_strategy
())
func_config
.
default_data_type
(
flow
.
float32
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录