Commit 6be466ac authored by: lixinqi

Merge branch 'dev_allow_cpu_return_op' into dev_python

......
@@ -151,22 +151,12 @@ void FixInputOpParallelConf(Job* job) {
});
}
void FixReturnOpParallelConf(Job* job) {
JobBuilder job_builder(job);
for (const auto& op_conf : job->net().op()) {
if (op_conf.has_return_conf() == false) { continue; }
LogicalBlobId lbi = GenLogicalBlobId(op_conf.return_conf().in());
job_builder.MutParallelConfOnlyOnce(op_conf.name(), job_builder.ParallelConf4Lbi(lbi));
}
}
} // namespace
void UserJobCompleter::Complete(Job* job) const {
SplitDecodeOps(job);
AddRecordLoadOps(job);
FixInputOpParallelConf(job);
FixReturnOpParallelConf(job);
}
} // namespace oneflow
......
@@ -31,7 +31,7 @@ Maybe<void> ReturnOp::InferSbpSignature(
std::function<Maybe<const SbpInferHint*>(const std::string&)> SbpInferHint4Ibn,
const ParallelDesc& parallel_desc) const {
const auto& in_sbp_infer_hint = *JUST(SbpInferHint4Ibn("in"));
CHECK_OR_RETURN(in_sbp_infer_hint.parallel_desc() == parallel_desc);
OF_CHECK_EQ(in_sbp_infer_hint.parallel_desc().parallel_num(), parallel_desc.parallel_num());
if (in_sbp_infer_hint.sbp_parallel().has_partial_sum_parallel()) {
SbpSignatureBuilder().Broadcast(input_bns()).Broadcast(output_bns()).Build(sbp_signature);
} else {
......
......
@@ -32,7 +32,8 @@ def ResetCurJobContext():
def GetOpConfAndParallelConf(op_conf, parallel_conf=None):
_PrependOpNamePrefixIfNeed(op_conf)
op_conf.device_type = placement_context.CurPlacementGroupGetDeviceType(op_conf)
if not op_conf.HasField('device_type'):
op_conf.device_type = placement_context.CurPlacementGroupGetDeviceType(op_conf)
if parallel_conf is None: parallel_conf = placement_context.ParallelConf4OpConf(op_conf)
return op_conf, parallel_conf
......
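The HasField guard above means the placement scope only fills in a device type when the op has not already pinned one itself, which is what the CPU return op further down now does. A minimal sketch of that protobuf behavior, reusing only helpers that already appear in this patch (it assumes an initialized oneflow environment, since DeviceType4DeviceTag talks to the C API):

    import oneflow.python.framework.c_api_util as c_api_util
    import oneflow.core.operator.op_conf_pb2 as op_conf_util

    op_conf = op_conf_util.OperatorConf()
    # Nothing pinned yet, so the placement scope's device type would be applied.
    assert not op_conf.HasField('device_type')

    # An op such as the CPU return op can pin its device type up front ...
    op_conf.device_type = c_api_util.DeviceType4DeviceTag('cpu')
    # ... and the HasField guard then leaves that choice untouched.
    assert op_conf.HasField('device_type')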
......
@@ -31,12 +31,14 @@ def Compile(function_desc, config_proto):
compile_context.ResetCurJobContext()
with _JobBuildAndInferCtx(job_conf.job_name), placement_scope, distribute_strategy:
c_api_util.CurJobBuildAndInferCtx_SetJobConf(job_conf)
_CompileJob(function_desc.job_func)
_CompileJob(function_desc)
def _CompileJob(func):
def _CompileJob(function_desc):
func = function_desc.job_func
func.__oneflow_input_blob_defs__ = _GetArgDefault(func)
inputs = _RecursiveMakeInputBlobs(func.__oneflow_input_blob_defs__)
func.__oneflow_output_remote_blobs__ = _RecursiveMakeRetRemoteBlobs(func(*inputs))
kwarg = dict(allow_cpu_return_op=function_desc.function_attribute.allow_cpu_return_op)
func.__oneflow_output_remote_blobs__ = _RecursiveMakeRetRemoteBlobs(func(*inputs), kwarg)
@contextmanager
def _JobBuildAndInferCtx(job_name):
......
@@ -64,13 +66,13 @@ def _RecursiveMakeInputBlobs(input_blob_def):
raise NotImplementedError("oneflow.function accepts "
+ "ArgBlobDefs or list/tuple/dict nested ArgBlobDefs as argument")
def _RecursiveMakeRetRemoteBlobs(out_remote_blobs):
if out_remote_blobs is None: return None
if isinstance(out_remote_blobs, remote_blob_util.BlobDef):
return ops.RetOpByRemoteBlob(out_remote_blobs)
if isinstance(out_remote_blobs, (tuple, list)):
return type(out_remote_blobs)(_RecursiveMakeRetRemoteBlobs(x) for x in out_remote_blobs)
if isinstance(out_remote_blobs, dict):
return {k : _RecursiveMakeRetRemoteBlobs(v) for k, v in out_remote_blobs.items()}
def _RecursiveMakeRetRemoteBlobs(remote_blobs, kwarg):
if remote_blobs is None: return None
if isinstance(remote_blobs, remote_blob_util.BlobDef):
return ops.RetOpByRemoteBlob(remote_blobs, **kwarg)
if isinstance(remote_blobs, (tuple, list)):
return type(remote_blobs)(_RecursiveMakeRetRemoteBlobs(x, kwarg) for x in remote_blobs)
if isinstance(remote_blobs, dict):
return {k : _RecursiveMakeRetRemoteBlobs(v, kwarg) for k, v in remote_blobs.items()}
raise NotImplementedError("oneflow.function returns "
+ "RemoteBlob or list/tuple/dict nested RemoteBlob only")
......
@@ -6,6 +6,7 @@ class FunctionAttribute(object):
def __init__(self):
self.default_placement_scope = None
self.default_distribute_strategy = None
self.allow_cpu_return_op = True
class FunctionDesc(object):
def __init__(self, job_func=None, job_config_proto=None, function_attribute=None):
......
......
@@ -299,3 +299,7 @@ def set_tensorrt_use_int8(func_desc, value = True):
def set_default_distribute_strategy(func_desc, value):
assert isinstance(value, distribute_ctx.DistributeStrategy)
func_desc.function_attribute.default_distribute_strategy = value
@oneflow_function_config('allow_cpu_return_op')
def allow_cpu_return_op(func_desc, value):
func_desc.function_attribute.allow_cpu_return_op = value
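Once registered, the flag can be toggled per job function through the function-config mechanism. A hedged usage sketch, assuming this API generation's flow.function_config() object exposes registered configs as setter methods (the surrounding job-function decorator call is not shown here):

    import oneflow as flow

    func_config = flow.function_config()
    # Keep return ops on their input's device instead of moving them to CPU;
    # the flag defaults to True, matching FunctionAttribute above.
    func_config.allow_cpu_return_op(False)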
......
@@ -7,6 +7,8 @@ import oneflow.python.framework.id_util as id_util
import oneflow.python.framework.c_api_util as c_api_util
import oneflow.core.operator.op_conf_pb2 as op_conf_util
import oneflow.core.register.logical_blob_id_pb2 as logical_blob_id_util
import oneflow.core.job.placement_pb2 as placement_proto_pb
import re
def InputOpByArgBlobDef(blob_def):
assert isinstance(blob_def, input_blob_util.ArgBlobDef)
......
@@ -18,12 +20,18 @@ def InputOpByArgBlobDef(blob_def):
blob_def.AddAndInferOp(op_conf)
return remote_blob_util.RemoteBlob(blob_def.lbi)
def RetOpByRemoteBlob(remote_blob):
def RetOpByRemoteBlob(remote_blob, allow_cpu_return_op = True):
op_conf = op_conf_util.OperatorConf()
op_conf.name = id_util.UniqueStr('Return_')
setattr(op_conf.return_conf, 'in', remote_blob.logical_blob_name)
op_conf.return_conf.out = "out"
compile_context.CurJobAddOp(op_conf, remote_blob.parallel_conf)
parallel_conf = placement_proto_pb.ParallelConf()
parallel_conf.CopyFrom(remote_blob.parallel_conf)
if allow_cpu_return_op:
op_conf.device_type = c_api_util.DeviceType4DeviceTag('cpu')
for i in range(len(parallel_conf.device_name)):
parallel_conf.device_name[i] = re.sub(":\w+:", ":cpu:", parallel_conf.device_name[i])
compile_context.CurJobAddOp(op_conf, parallel_conf)
lbi = logical_blob_id_util.LogicalBlobId()
lbi.op_name = op_conf.name
lbi.blob_name = "out"
......
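The re.sub call above only rewrites the device tag inside each device_name entry. A self-contained illustration, assuming the usual machine_id:device_tag:device_ids layout of ParallelConf.device_name strings (the example values are hypothetical):

    import re

    device_names = ["0:gpu:0-3", "1:gpu:0-3"]  # hypothetical entries
    # r":\w+:" is the raw-string form of the pattern used in the patch.
    cpu_device_names = [re.sub(r":\w+:", ":cpu:", name) for name in device_names]
    print(cpu_device_names)  # ['0:cpu:0-3', '1:cpu:0-3']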