Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Oneflow-Inc
oneflow
提交
6be466ac
O
oneflow
项目概览
Oneflow-Inc
/
oneflow
上一次同步 2 年多
通知
13
Star
2733
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
O
oneflow
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
6be466ac
编写于
12月 31, 2019
作者:
L
lixinqi
浏览文件
操作
浏览文件
下载
差异文件
Merge branch 'dev_allow_cpu_return_op' into dev_python
上级
a0b87f1b
8640a7c2
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
31 addition
and
25 deletion
+31
-25
oneflow/core/job_completer/user_job_completer.cpp
oneflow/core/job_completer/user_job_completer.cpp
+0
-10
oneflow/core/operator/return_op.cpp
oneflow/core/operator/return_op.cpp
+1
-1
oneflow/python/framework/compile_context.py
oneflow/python/framework/compile_context.py
+2
-1
oneflow/python/framework/compiler.py
oneflow/python/framework/compiler.py
+13
-11
oneflow/python/framework/function_desc.py
oneflow/python/framework/function_desc.py
+1
-0
oneflow/python/framework/function_util.py
oneflow/python/framework/function_util.py
+4
-0
oneflow/python/ops/__init__.py
oneflow/python/ops/__init__.py
+10
-2
未找到文件。
oneflow/core/job_completer/user_job_completer.cpp
浏览文件 @
6be466ac
...
...
@@ -151,22 +151,12 @@ void FixInputOpParallelConf(Job* job) {
});
}
void
FixReturnOpParallelConf
(
Job
*
job
)
{
JobBuilder
job_builder
(
job
);
for
(
const
auto
&
op_conf
:
job
->
net
().
op
())
{
if
(
op_conf
.
has_return_conf
()
==
false
)
{
continue
;
}
LogicalBlobId
lbi
=
GenLogicalBlobId
(
op_conf
.
return_conf
().
in
());
job_builder
.
MutParallelConfOnlyOnce
(
op_conf
.
name
(),
job_builder
.
ParallelConf4Lbi
(
lbi
));
}
}
}
// namespace
void
UserJobCompleter
::
Complete
(
Job
*
job
)
const
{
SplitDecodeOps
(
job
);
AddRecordLoadOps
(
job
);
FixInputOpParallelConf
(
job
);
FixReturnOpParallelConf
(
job
);
}
}
// namespace oneflow
oneflow/core/operator/return_op.cpp
浏览文件 @
6be466ac
...
...
@@ -31,7 +31,7 @@ Maybe<void> ReturnOp::InferSbpSignature(
std
::
function
<
Maybe
<
const
SbpInferHint
*>
(
const
std
::
string
&
)
>
SbpInferHint4Ibn
,
const
ParallelDesc
&
parallel_desc
)
const
{
const
auto
&
in_sbp_infer_hint
=
*
JUST
(
SbpInferHint4Ibn
(
"in"
));
CHECK_OR_RETURN
(
in_sbp_infer_hint
.
parallel_desc
()
==
parallel_desc
);
OF_CHECK_EQ
(
in_sbp_infer_hint
.
parallel_desc
().
parallel_num
(),
parallel_desc
.
parallel_num
()
);
if
(
in_sbp_infer_hint
.
sbp_parallel
().
has_partial_sum_parallel
())
{
SbpSignatureBuilder
().
Broadcast
(
input_bns
()).
Broadcast
(
output_bns
()).
Build
(
sbp_signature
);
}
else
{
...
...
oneflow/python/framework/compile_context.py
浏览文件 @
6be466ac
...
...
@@ -32,7 +32,8 @@ def ResetCurJobContext():
def
GetOpConfAndParallelConf
(
op_conf
,
parallel_conf
=
None
):
_PrependOpNamePrefixIfNeed
(
op_conf
)
op_conf
.
device_type
=
placement_context
.
CurPlacementGroupGetDeviceType
(
op_conf
)
if
not
op_conf
.
HasField
(
'device_type'
):
op_conf
.
device_type
=
placement_context
.
CurPlacementGroupGetDeviceType
(
op_conf
)
if
parallel_conf
is
None
:
parallel_conf
=
placement_context
.
ParallelConf4OpConf
(
op_conf
)
return
op_conf
,
parallel_conf
...
...
oneflow/python/framework/compiler.py
浏览文件 @
6be466ac
...
...
@@ -31,12 +31,14 @@ def Compile(function_desc, config_proto):
compile_context
.
ResetCurJobContext
()
with
_JobBuildAndInferCtx
(
job_conf
.
job_name
),
placement_scope
,
distribute_strategy
:
c_api_util
.
CurJobBuildAndInferCtx_SetJobConf
(
job_conf
)
_CompileJob
(
function_desc
.
job_func
)
_CompileJob
(
function_desc
)
def
_CompileJob
(
func
):
def
_CompileJob
(
function_desc
):
func
=
function_desc
.
job_func
func
.
__oneflow_input_blob_defs__
=
_GetArgDefault
(
func
)
inputs
=
_RecursiveMakeInputBlobs
(
func
.
__oneflow_input_blob_defs__
)
func
.
__oneflow_output_remote_blobs__
=
_RecursiveMakeRetRemoteBlobs
(
func
(
*
inputs
))
kwarg
=
dict
(
allow_cpu_return_op
=
function_desc
.
function_attribute
.
allow_cpu_return_op
)
func
.
__oneflow_output_remote_blobs__
=
_RecursiveMakeRetRemoteBlobs
(
func
(
*
inputs
),
kwarg
)
@
contextmanager
def
_JobBuildAndInferCtx
(
job_name
):
...
...
@@ -64,13 +66,13 @@ def _RecursiveMakeInputBlobs(input_blob_def):
raise
NotImplementedError
(
"oneflow.function accepts "
+
"ArgBlobDefs or list/tuple/dict nested ArgBlobDefs as argument"
)
def
_RecursiveMakeRetRemoteBlobs
(
out_remote_blobs
):
if
out_
remote_blobs
is
None
:
return
None
if
isinstance
(
out_
remote_blobs
,
remote_blob_util
.
BlobDef
):
return
ops
.
RetOpByRemoteBlob
(
out_remote_blobs
)
if
isinstance
(
out_
remote_blobs
,
(
tuple
,
list
)):
return
type
(
out_remote_blobs
)(
_RecursiveMakeRetRemoteBlobs
(
x
)
for
x
in
out_
remote_blobs
)
if
isinstance
(
out_
remote_blobs
,
dict
):
return
{
k
:
_RecursiveMakeRetRemoteBlobs
(
v
)
for
k
,
v
in
out_
remote_blobs
.
items
()}
def
_RecursiveMakeRetRemoteBlobs
(
remote_blobs
,
kwarg
):
if
remote_blobs
is
None
:
return
None
if
isinstance
(
remote_blobs
,
remote_blob_util
.
BlobDef
):
return
ops
.
RetOpByRemoteBlob
(
remote_blobs
,
**
kwarg
)
if
isinstance
(
remote_blobs
,
(
tuple
,
list
)):
return
type
(
remote_blobs
)(
_RecursiveMakeRetRemoteBlobs
(
x
,
kwarg
)
for
x
in
remote_blobs
)
if
isinstance
(
remote_blobs
,
dict
):
return
{
k
:
_RecursiveMakeRetRemoteBlobs
(
v
,
kwarg
)
for
k
,
v
in
remote_blobs
.
items
()}
raise
NotImplementedError
(
"oneflow.function returns "
+
"RemoteBlob or list/tuple/dict nested RemoteBlob only"
)
oneflow/python/framework/function_desc.py
浏览文件 @
6be466ac
...
...
@@ -6,6 +6,7 @@ class FunctionAttribute(object):
def
__init__
(
self
):
self
.
default_placement_scope
=
None
self
.
default_distribute_strategy
=
None
self
.
allow_cpu_return_op
=
True
class
FunctionDesc
(
object
):
def
__init__
(
self
,
job_func
=
None
,
job_config_proto
=
None
,
function_attribute
=
None
):
...
...
oneflow/python/framework/function_util.py
浏览文件 @
6be466ac
...
...
@@ -299,3 +299,7 @@ def set_tensorrt_use_int8(func_desc, value = True):
def
set_default_distribute_strategy
(
func_desc
,
value
):
assert
isinstance
(
value
,
distribute_ctx
.
DistributeStrategy
)
func_desc
.
function_attribute
.
default_distribute_strategy
=
value
@
oneflow_function_config
(
'allow_cpu_return_op'
)
def
allow_cpu_return_op
(
func_desc
,
value
):
func_desc
.
function_attribute
.
allow_cpu_return_op
=
value
oneflow/python/ops/__init__.py
浏览文件 @
6be466ac
...
...
@@ -7,6 +7,8 @@ import oneflow.python.framework.id_util as id_util
import
oneflow.python.framework.c_api_util
as
c_api_util
import
oneflow.core.operator.op_conf_pb2
as
op_conf_util
import
oneflow.core.register.logical_blob_id_pb2
as
logical_blob_id_util
import
oneflow.core.job.placement_pb2
as
placement_proto_pb
import
re
def
InputOpByArgBlobDef
(
blob_def
):
assert
isinstance
(
blob_def
,
input_blob_util
.
ArgBlobDef
)
...
...
@@ -18,12 +20,18 @@ def InputOpByArgBlobDef(blob_def):
blob_def
.
AddAndInferOp
(
op_conf
)
return
remote_blob_util
.
RemoteBlob
(
blob_def
.
lbi
)
def
RetOpByRemoteBlob
(
remote_blob
):
def
RetOpByRemoteBlob
(
remote_blob
,
allow_cpu_return_op
=
True
):
op_conf
=
op_conf_util
.
OperatorConf
()
op_conf
.
name
=
id_util
.
UniqueStr
(
'Return_'
)
setattr
(
op_conf
.
return_conf
,
'in'
,
remote_blob
.
logical_blob_name
)
op_conf
.
return_conf
.
out
=
"out"
compile_context
.
CurJobAddOp
(
op_conf
,
remote_blob
.
parallel_conf
)
parallel_conf
=
placement_proto_pb
.
ParallelConf
()
parallel_conf
.
CopyFrom
(
remote_blob
.
parallel_conf
)
if
allow_cpu_return_op
:
op_conf
.
device_type
=
c_api_util
.
DeviceType4DeviceTag
(
'cpu'
)
for
i
in
range
(
len
(
parallel_conf
.
device_name
)):
parallel_conf
.
device_name
[
i
]
=
re
.
sub
(
":\w+:"
,
":cpu:"
,
parallel_conf
.
device_name
[
i
])
compile_context
.
CurJobAddOp
(
op_conf
,
parallel_conf
)
lbi
=
logical_blob_id_util
.
LogicalBlobId
()
lbi
.
op_name
=
op_conf
.
name
lbi
.
blob_name
=
"out"
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录