Summary of this patch (leading free text before the first "diff --git" header):

  * user_job_completer.cpp: removes FixReturnOpParallelConf() and its call in
    UserJobCompleter::Complete().
  * return_op.cpp: ReturnOp::InferSbpSignature() now checks only that the input's
    parallel_num equals the op's parallel_num, instead of requiring equal
    ParallelDesc objects.
  * compile_context.py: GetOpConfAndParallelConf() assigns op_conf.device_type
    only when the field is not already set.
  * compiler.py: _CompileJob() receives the FunctionDesc and forwards
    function_attribute.allow_cpu_return_op to ops.RetOpByRemoteBlob() through
    _RecursiveMakeRetRemoteBlobs().
  * function_desc.py / function_util.py: adds FunctionAttribute.allow_cpu_return_op
    (initialised to True) and the 'allow_cpu_return_op' function-config setter.
  * ops/__init__.py: RetOpByRemoteBlob() copies the blob's parallel_conf and, when
    allow_cpu_return_op is true, sets the op's device_type from
    DeviceType4DeviceTag('cpu') and rewrites the ":<tag>:" part of every
    device_name to ":cpu:".

Review notes:
  * The regex in RetOpByRemoteBlob() is written as a raw string (r":\w+:") so that
    "\w" is not an invalid string escape; the "index" post-image hash of
    ops/__init__.py therefore no longer describes the patched blob.
  * NOTE(review): the template arguments in the return_op.cpp hunk context
    (Maybe<void>, std::function<Maybe<const SbpInferHint*>(...)>) and all
    indentation were reconstructed from a whitespace-collapsed copy of this
    patch -- confirm against the target tree before applying.

diff --git a/oneflow/core/job_completer/user_job_completer.cpp b/oneflow/core/job_completer/user_job_completer.cpp
index 3b9e63b95aed855bc2ee072cdb1b9547376f71e4..5bcfc0b4bc84b6b24ebbf7a27b8132c215b7ed4e 100644
--- a/oneflow/core/job_completer/user_job_completer.cpp
+++ b/oneflow/core/job_completer/user_job_completer.cpp
@@ -151,22 +151,12 @@ void FixInputOpParallelConf(Job* job) {
   });
 }
 
-void FixReturnOpParallelConf(Job* job) {
-  JobBuilder job_builder(job);
-  for (const auto& op_conf : job->net().op()) {
-    if (op_conf.has_return_conf() == false) { continue; }
-    LogicalBlobId lbi = GenLogicalBlobId(op_conf.return_conf().in());
-    job_builder.MutParallelConfOnlyOnce(op_conf.name(), job_builder.ParallelConf4Lbi(lbi));
-  }
-}
-
 } // namespace
 
 void UserJobCompleter::Complete(Job* job) const {
   SplitDecodeOps(job);
   AddRecordLoadOps(job);
   FixInputOpParallelConf(job);
-  FixReturnOpParallelConf(job);
 }
 
 } // namespace oneflow
diff --git a/oneflow/core/operator/return_op.cpp b/oneflow/core/operator/return_op.cpp
index 63b1c740bd23f1efd56b51c8e4941af6c6004c26..11fb8da1221e269acda6598bb5ea80f3eef9a339 100644
--- a/oneflow/core/operator/return_op.cpp
+++ b/oneflow/core/operator/return_op.cpp
@@ -31,7 +31,7 @@ Maybe<void> ReturnOp::InferSbpSignature(
     std::function<Maybe<const SbpInferHint*>(const std::string&)> SbpInferHint4Ibn,
     const ParallelDesc& parallel_desc) const {
   const auto& in_sbp_infer_hint = *JUST(SbpInferHint4Ibn("in"));
-  CHECK_OR_RETURN(in_sbp_infer_hint.parallel_desc() == parallel_desc);
+  OF_CHECK_EQ(in_sbp_infer_hint.parallel_desc().parallel_num(), parallel_desc.parallel_num());
   if (in_sbp_infer_hint.sbp_parallel().has_partial_sum_parallel()) {
     SbpSignatureBuilder().Broadcast(input_bns()).Broadcast(output_bns()).Build(sbp_signature);
   } else {
diff --git a/oneflow/python/framework/compile_context.py b/oneflow/python/framework/compile_context.py
index ed3cf51f145fff2274d2d8a0e5a53fbd5b33ad52..c8c5e91be2526ec07917b8b921c0377ac2035039 100644
--- a/oneflow/python/framework/compile_context.py
+++ b/oneflow/python/framework/compile_context.py
@@ -32,7 +32,8 @@ def ResetCurJobContext():
 
 def GetOpConfAndParallelConf(op_conf, parallel_conf=None):
     _PrependOpNamePrefixIfNeed(op_conf)
-    op_conf.device_type = placement_context.CurPlacementGroupGetDeviceType(op_conf)
+    if not op_conf.HasField('device_type'):
+        op_conf.device_type = placement_context.CurPlacementGroupGetDeviceType(op_conf)
     if parallel_conf is None:
         parallel_conf = placement_context.ParallelConf4OpConf(op_conf)
     return op_conf, parallel_conf
diff --git a/oneflow/python/framework/compiler.py b/oneflow/python/framework/compiler.py
index e1c1fff455cea889a14064091f4919bf4838e8cc..8797b2e6d5f741f1eea3913822f181cfeb81bba0 100644
--- a/oneflow/python/framework/compiler.py
+++ b/oneflow/python/framework/compiler.py
@@ -31,12 +31,14 @@ def Compile(function_desc, config_proto):
     compile_context.ResetCurJobContext()
     with _JobBuildAndInferCtx(job_conf.job_name), placement_scope, distribute_strategy:
         c_api_util.CurJobBuildAndInferCtx_SetJobConf(job_conf)
-        _CompileJob(function_desc.job_func)
+        _CompileJob(function_desc)
 
-def _CompileJob(func):
+def _CompileJob(function_desc):
+    func = function_desc.job_func
     func.__oneflow_input_blob_defs__ = _GetArgDefault(func)
     inputs = _RecursiveMakeInputBlobs(func.__oneflow_input_blob_defs__)
-    func.__oneflow_output_remote_blobs__ = _RecursiveMakeRetRemoteBlobs(func(*inputs))
+    kwarg = dict(allow_cpu_return_op=function_desc.function_attribute.allow_cpu_return_op)
+    func.__oneflow_output_remote_blobs__ = _RecursiveMakeRetRemoteBlobs(func(*inputs), kwarg)
 
 @contextmanager
 def _JobBuildAndInferCtx(job_name):
@@ -64,13 +66,13 @@ def _RecursiveMakeInputBlobs(input_blob_def):
     raise NotImplementedError("oneflow.function accepts "
                               + "ArgBlobDefs or list/tuple/dict nested ArgBlobDefs as argument")
 
-def _RecursiveMakeRetRemoteBlobs(out_remote_blobs):
-    if out_remote_blobs is None: return None
-    if isinstance(out_remote_blobs, remote_blob_util.BlobDef):
-        return ops.RetOpByRemoteBlob(out_remote_blobs)
-    if isinstance(out_remote_blobs, (tuple, list)):
-        return type(out_remote_blobs)(_RecursiveMakeRetRemoteBlobs(x) for x in out_remote_blobs)
-    if isinstance(out_remote_blobs, dict):
-        return {k : _RecursiveMakeRetRemoteBlobs(v) for k, v in out_remote_blobs.items()}
+def _RecursiveMakeRetRemoteBlobs(remote_blobs, kwarg):
+    if remote_blobs is None: return None
+    if isinstance(remote_blobs, remote_blob_util.BlobDef):
+        return ops.RetOpByRemoteBlob(remote_blobs, **kwarg)
+    if isinstance(remote_blobs, (tuple, list)):
+        return type(remote_blobs)(_RecursiveMakeRetRemoteBlobs(x, kwarg) for x in remote_blobs)
+    if isinstance(remote_blobs, dict):
+        return {k : _RecursiveMakeRetRemoteBlobs(v, kwarg) for k, v in remote_blobs.items()}
     raise NotImplementedError("oneflow.function returns "
                               + "RemoteBlob or list/tuple/dict nested RemoteBlob only")
diff --git a/oneflow/python/framework/function_desc.py b/oneflow/python/framework/function_desc.py
index 54e427a4dd331344023157c7549a1660ae16d0fd..1d057f9fbe41f0221c7d85e6853e4d9787769688 100644
--- a/oneflow/python/framework/function_desc.py
+++ b/oneflow/python/framework/function_desc.py
@@ -6,6 +6,7 @@ class FunctionAttribute(object):
     def __init__(self):
         self.default_placement_scope = None
         self.default_distribute_strategy = None
+        self.allow_cpu_return_op = True
 
 class FunctionDesc(object):
     def __init__(self, job_func=None, job_config_proto=None, function_attribute=None):
diff --git a/oneflow/python/framework/function_util.py b/oneflow/python/framework/function_util.py
index e133c77864ff9d33722c84a6adf20f173193fa83..f8a88aec1a432781a68a63ad77a5c4fd6005ff9b 100644
--- a/oneflow/python/framework/function_util.py
+++ b/oneflow/python/framework/function_util.py
@@ -299,3 +299,7 @@ def set_tensorrt_use_int8(func_desc, value = True):
 def set_default_distribute_strategy(func_desc, value):
     assert isinstance(value, distribute_ctx.DistributeStrategy)
     func_desc.function_attribute.default_distribute_strategy = value
+
+@oneflow_function_config('allow_cpu_return_op')
+def allow_cpu_return_op(func_desc, value):
+    func_desc.function_attribute.allow_cpu_return_op = value
diff --git a/oneflow/python/ops/__init__.py b/oneflow/python/ops/__init__.py
index 6de3302944dcb64e5d58e4e5ab1ddd7fbb451563..5f30e8defd8fb4ed997ed2158d9bd0c5f67dc3b7 100644
--- a/oneflow/python/ops/__init__.py
+++ b/oneflow/python/ops/__init__.py
@@ -7,6 +7,8 @@ import oneflow.python.framework.id_util as id_util
 import oneflow.python.framework.c_api_util as c_api_util
 import oneflow.core.operator.op_conf_pb2 as op_conf_util
 import oneflow.core.register.logical_blob_id_pb2 as logical_blob_id_util
+import oneflow.core.job.placement_pb2 as placement_proto_pb
+import re
 
 def InputOpByArgBlobDef(blob_def):
     assert isinstance(blob_def, input_blob_util.ArgBlobDef)
@@ -18,12 +20,18 @@ def InputOpByArgBlobDef(blob_def):
     blob_def.AddAndInferOp(op_conf)
     return remote_blob_util.RemoteBlob(blob_def.lbi)
 
-def RetOpByRemoteBlob(remote_blob):
+def RetOpByRemoteBlob(remote_blob, allow_cpu_return_op = True):
     op_conf = op_conf_util.OperatorConf()
     op_conf.name = id_util.UniqueStr('Return_')
     setattr(op_conf.return_conf, 'in', remote_blob.logical_blob_name)
     op_conf.return_conf.out = "out"
-    compile_context.CurJobAddOp(op_conf, remote_blob.parallel_conf)
+    parallel_conf = placement_proto_pb.ParallelConf()
+    parallel_conf.CopyFrom(remote_blob.parallel_conf)
+    if allow_cpu_return_op:
+        op_conf.device_type = c_api_util.DeviceType4DeviceTag('cpu')
+        for i in range(len(parallel_conf.device_name)):
+            parallel_conf.device_name[i] = re.sub(r":\w+:", ":cpu:", parallel_conf.device_name[i])
+    compile_context.CurJobAddOp(op_conf, parallel_conf)
     lbi = logical_blob_id_util.LogicalBlobId()
     lbi.op_name = op_conf.name
     lbi.blob_name = "out"