未验证 提交 cad594bd 编写于 作者: O OuYang Yu 提交者: GitHub

Dev replace py parallel conf proto to cfg (#3810)

* Replace the py instruction with CFG Instruction

* move RunInstruction to pybind & refactor EagerOneflow's interface by cfg

* use forward declaration

* Remove unused import

* fix code style

* Adjust import order

* replace py parallel_conf proto to cfg

* fix test_cpu_only_user_op parallel_conf

* move RunInstruction api to oneflow_api.vm

* fix MakeMachineId2DeviceIdList

* remove useless line in oneflow_internal.i

* replace args str to cfg_obj in python callback

* add forward declear of InstructionListProto

* cancel forward declear of InstructionListProto

* fix a name spelling mistake

* use the CFG object in the Python Callback

* use cfg in the py Callback

* fix redundant conversions

* fix template bug

* fix template

* virtual const tmplate cfg

* fix template.cfg.h

* update cfg

* update cfg
Co-authored-by: qq_22305325's avatarclackhan <han_binbin@163.com>
Co-authored-by: Noneflow-bot <69100618+oneflow-bot@users.noreply.github.com>
Co-authored-by: Noneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
上级 9f504dfc
......@@ -16,9 +16,9 @@ limitations under the License.
from __future__ import absolute_import
import oneflow.python.framework.op_arg_util as op_arg_util
import oneflow.core.job.placement_pb2 as placement_pb
import oneflow.python.eager.symbol as symbol_util
import oneflow.core.job.sbp_parallel_pb2 as sbp_parallel_pb
import oneflow_api.oneflow.core.job.placement as placement_cfg
import random
......@@ -137,10 +137,10 @@ def TryReplaceDeviceTag(builder, parallel_desc_symbol, device_tag):
def ReplaceDeviceTag(parallel_desc_symbol, device_tag, builder=None):
assert parallel_desc_symbol.device_tag != device_tag
parallel_conf = placement_pb.ParallelConf()
parallel_conf.device_tag = device_tag
for device_name in parallel_desc_symbol.parallel_conf.device_name:
parallel_conf.device_name.append(device_name)
parallel_conf = placement_cfg.ParallelConf()
parallel_conf.set_device_tag(device_tag)
for device_name in parallel_desc_symbol.parallel_conf.device_name():
parallel_conf.add_device_name(device_name)
if builder is None:
return symbol_util.ParallelDescSymbol(
parallel_desc_symbol.symbol_id, parallel_conf
......@@ -151,13 +151,13 @@ def ReplaceDeviceTag(parallel_desc_symbol, device_tag, builder=None):
def RandomParallelIdPerMachine(parallel_desc_symbol, device_tag=None, builder=None):
if device_tag is None:
device_tag = parallel_desc_symbol.parallel_conf.device_tag
device_tag = parallel_desc_symbol.parallel_conf.device_tag()
assert device_tag is not None
parallel_conf = placement_pb.ParallelConf()
parallel_conf.device_tag = device_tag
parallel_conf = placement_cfg.ParallelConf()
parallel_conf.set_device_tag(device_tag)
for machine_id, dev_ids in parallel_desc_symbol.machine_id2device_id_list.items():
dev_id = dev_ids[random.randint(0, len(dev_ids) - 1)]
parallel_conf.device_name.append("%s:%s" % (machine_id, dev_id))
parallel_conf.add_device_name("%s:%s" % (machine_id, dev_id))
if builder is None:
return symbol_util.ParallelDescSymbol(
parallel_desc_symbol.symbol_id, parallel_conf
......
......@@ -19,7 +19,6 @@ import oneflow.python.eager.symbol as symbol_util
import oneflow.core.operator.op_conf_pb2 as op_conf_pb
import oneflow.core.operator.op_attribute_pb2 as op_attribute_pb
import oneflow.core.job.sbp_parallel_pb2 as sbp_parallel_pb
import oneflow.core.job.placement_pb2 as placement_pb
import oneflow.python.framework.id_util as id_util
import oneflow.python.framework.c_api_util as c_api_util
import oneflow.python.framework.op_arg_util as op_arg_util
......@@ -27,7 +26,6 @@ import oneflow.python.framework.balanced_splitter as balanced_splitter
import oneflow.python.lib.core.enable_if as enable_if
import oneflow.python.lib.core.high_order_bool as high_order_bool
import oneflow.core.register.logical_blob_id_pb2 as logical_blob_id_util
import oneflow.core.job.placement_pb2 as placement_pb
import oneflow.python.eager.blob_cache as blob_cache_util
import oneflow.python.eager.boxing_hob as boxing_hob
import oneflow.python.eager.op_infer_util as op_infer_util
......@@ -35,6 +33,7 @@ from oneflow.python.eager.boxing_hob import BoxingHobContext
import oneflow.python.eager.boxing_middle as boxing_middle
import random
import oneflow
import oneflow_api.oneflow.core.job.placement as placement_cfg
def BoxingTo(builder, produced_blob_object, consumer_op_arg_parallel_attr):
......@@ -254,9 +253,9 @@ def InterNodeOneToMany(builder, produced_blob_object, consumer_op_arg_parallel_a
)
for machine_id, device_ids in consumer_dev_ids.items():
for device_id in device_ids:
parallel_conf = placement_pb.ParallelConf()
parallel_conf.device_tag = "cpu"
parallel_conf.device_name.append("%s:%s" % (machine_id, device_id))
parallel_conf = placement_cfg.ParallelConf()
parallel_conf.set_device_tag("cpu")
parallel_conf.add_device_name("%s:%s" % (machine_id, device_id))
parallel_desc_symbol = builder.GetParallelDescSymbol(parallel_conf)
out_blob = builder.Build121To(produced_blob_object, parallel_desc_symbol)
out_blobs.append(out_blob)
......@@ -594,10 +593,10 @@ def GetConcatSplitBoxingParallelDescSymbol(
builder, blob_parallel_desc_symbol, max_parallel_num
):
random_rank_id = random.randint(0, max_parallel_num - 1)
parallel_conf = placement_pb.ParallelConf()
parallel_conf.device_tag = "cpu"
parallel_conf = placement_cfg.ParallelConf()
parallel_conf.set_device_tag("cpu")
for machine_id, _ in blob_parallel_desc_symbol.machine_id2device_id_list.items():
parallel_conf.device_name.append("%s:%s" % (machine_id, random_rank_id))
parallel_conf.add_device_name("%s:%s" % (machine_id, random_rank_id))
return builder.GetParallelDescSymbol(parallel_conf)
......
......@@ -41,16 +41,17 @@ def AddScopeToStorage(scope_symbol_id, scope_proto_str):
symbol_storage.SetSymbol4SerializedScopeProto(scope_proto_str, symbol)
def MakeScopeSymbol(job_conf_str, parallel_conf_str, is_mirrored):
def MakeScopeSymbol(job_conf_str, parallel_conf, is_mirrored):
job_conf = text_format.Parse(job_conf_str, job_conf_pb.JobConfigProto())
parallel_conf = text_format.Parse(parallel_conf_str, placement_pb.ParallelConf())
return scope_util.MakeInitialScope(
job_conf, parallel_conf.device_tag, list(parallel_conf.device_name), is_mirrored
job_conf,
parallel_conf.device_tag(),
list(parallel_conf.device_name()),
is_mirrored,
).symbol_id
def MakeParallelDescSymbol(parallel_conf_str):
parallel_conf = text_format.Parse(parallel_conf_str, placement_pb.ParallelConf())
def MakeParallelDescSymbol(parallel_conf):
symbol_id = None
def BuildInstruction(builder):
......@@ -61,7 +62,7 @@ def MakeParallelDescSymbol(parallel_conf_str):
return symbol_id
def MirroredCast(op_attribute_str, parallel_conf_str):
def MirroredCast(op_attribute_str, parallel_conf):
op_attribute = text_format.Parse(op_attribute_str, op_attribute_pb.OpAttribute())
blob_register = blob_register_util.GetDefaultBlobRegister()
is_cast_to_mirrored = op_attribute.op_conf.HasField("cast_to_mirrored_conf")
......@@ -74,15 +75,15 @@ def MirroredCast(op_attribute_str, parallel_conf_str):
)
def InterpretCompletedOp(op_attribute_str, parallel_conf_str):
def InterpretCompletedOp(op_attribute_str, parallel_conf):
op_attribute = text_format.Parse(op_attribute_str, op_attribute_pb.OpAttribute())
blob_register = gradient_util.GetDefaultBackwardBlobRegister()
_InterpretCompletedOp(op_attribute, parallel_conf_str, blob_register)
_InterpretCompletedOp(op_attribute, parallel_conf, blob_register)
gradient_util.ReleaseUnusedBlobObject(op_attribute, blob_register)
def _InterpretCompletedOp(op_attribute, parallel_conf_str, blob_register):
return op_executor.Interpret(op_attribute, parallel_conf_str, blob_register)
def _InterpretCompletedOp(op_attribute, parallel_conf, blob_register):
return op_executor.Interpret(op_attribute, parallel_conf, blob_register)
def _MirroredCastAndAddOutputBlobReleaser(op_attribute, blob_register):
......
......@@ -27,9 +27,9 @@ import oneflow.python.framework.op_arg_util as op_arg_util
import oneflow.python.experimental.name_scope as name_scope
import oneflow.python.framework.session_context as session_ctx
import oneflow.python.framework.scope_util as scope_util
import oneflow.core.job.placement_pb2 as placement_pb
import oneflow.python.framework.dtype as dtype_util
import oneflow.python.eager.op_infer_util as op_infer_util
import oneflow_api.oneflow.core.job.placement as placement_cfg
from google.protobuf import text_format
import oneflow
......@@ -42,10 +42,7 @@ def Interpret(op_attribute, parallel_conf, blob_register):
return MirroredCast(op_attribute, blob_register)
if op_attribute.op_conf.HasField("cast_from_mirrored_conf"):
return MirroredCast(op_attribute, blob_register)
if type(parallel_conf) is str:
parallel_conf = text_format.Parse(parallel_conf, placement_pb.ParallelConf())
else:
assert isinstance(parallel_conf, placement_pb.ParallelConf)
assert isinstance(parallel_conf, placement_cfg.ParallelConf)
if op_attribute.op_conf.HasField("distribute_split_conf"):
return DistributeSplitOrClone(op_attribute, parallel_conf, blob_register)
if op_attribute.op_conf.HasField("distribute_clone_conf"):
......
......@@ -16,6 +16,7 @@ limitations under the License.
from __future__ import absolute_import
import oneflow.python.framework.c_api_util as c_api_util
import oneflow.core.job.placement_pb2 as placement_pb
import functools
......@@ -36,7 +37,7 @@ class Symbol(object):
class ParallelDescSymbol(Symbol):
def __init__(self, symbol_id, parallel_conf):
Symbol.__init__(self, symbol_id, parallel_conf)
self.device_tag_ = parallel_conf.device_tag
self.device_tag_ = parallel_conf.device_tag()
self.machine_id2device_id_list_ = MakeMachineId2DeviceIdList(parallel_conf)
sub_parallel_nums = [len(v) for k, v in self.machine_id2device_id_list_.items()]
self.parallel_num_ = functools.reduce(lambda a, b: a + b, sub_parallel_nums, 0)
......@@ -90,7 +91,7 @@ def _GlobalDeviceIdsContaining(bigger, smaller):
def MakeMachineId2DeviceIdList(parallel_conf):
parallel_conf_str = parallel_conf.SerializeToString()
parallel_conf_str = str(parallel_conf)
global _parallel_conf_str2ofrecord
if parallel_conf_str not in _parallel_conf_str2ofrecord:
ofrecord = c_api_util.GetMachine2DeviceIdListOFRecordFromParallelConf(
......
......@@ -41,6 +41,8 @@ from oneflow.python.eager.opkernel_object import OpKernelObject
import oneflow.python.vm.id_util as vm_id_util
import oneflow
import oneflow_api.oneflow.core.vm.instruction as instr_cfg
import oneflow_api.oneflow.core.job.placement as placement_cfg
from google.protobuf import text_format
oneflow_api = oneflow.oneflow_api
......@@ -273,7 +275,7 @@ class InstructionsBuilder(object):
):
parallel_desc_symbol = op_arg_parallel_attr.parallel_desc_symbol
machine_id2device_ids = parallel_desc_symbol.machine_id2device_id_list
device_tag = parallel_desc_symbol.parallel_conf.device_tag
device_tag = parallel_desc_symbol.parallel_conf.device_tag()
machine_device_ids = set()
for physical_blob_object in physical_blob_objects:
phy_paralle_desc_sym = physical_blob_object.parallel_desc_symbol
......@@ -307,13 +309,13 @@ class InstructionsBuilder(object):
def GetPhysicalParallelDescSymbols(self, parallel_desc_symbol):
machine_id2device_ids = parallel_desc_symbol.machine_id2device_id_list
device_tag = parallel_desc_symbol.parallel_conf.device_tag
device_tag = parallel_desc_symbol.parallel_conf.device_tag()
phy_parallel_desc_symbols = []
def AppendPhyParallelDescSymbol(machine_id, device_id):
parallel_conf = placement_pb.ParallelConf()
parallel_conf.device_tag = device_tag
parallel_conf.device_name.append("%d:%d" % (machine_id, device_id))
parallel_conf = placement_cfg.ParallelConf()
parallel_conf.set_device_tag(device_tag)
parallel_conf.add_device_name("%d:%d" % (machine_id, device_id))
phy_parallel_desc_symbols.append(self.GetParallelDescSymbol(parallel_conf))
for machine_id, device_ids in machine_id2device_ids.items():
......@@ -420,7 +422,8 @@ class InstructionsBuilder(object):
return symbol
def GetParallelDescSymbol(self, parallel_conf):
serialized_parallel_conf = parallel_conf.SerializeToString()
# parallel_conf is cfg
serialized_parallel_conf = str(parallel_conf)
if symbol_storage.HasSymbol4SerializedParallelConf(serialized_parallel_conf):
return symbol_storage.GetSymbol4SerializedParallelConf(
serialized_parallel_conf
......@@ -1078,7 +1081,10 @@ class InstructionsBuilder(object):
self.instruction_list_.mutable_instruction().Add().CopyFrom(instruction)
eager_symbol = eager_symbol_pb.EagerSymbol()
eager_symbol.symbol_id = symbol_id
eager_symbol.parallel_conf_symbol.CopyFrom(parallel_conf)
# Temporary transformation
eager_symbol.parallel_conf_symbol.CopyFrom(
text_format.Parse(str(parallel_conf), placement_pb.ParallelConf())
)
self.eager_symbol_list_.eager_symbol.append(eager_symbol)
def _NewScopeSymbol(self, symbol_id, scope_proto):
......
......@@ -32,6 +32,7 @@ from oneflow.core.framework.config_def_pb2 import ConfigDef
from oneflow.core.job.inter_user_job_info_pb2 import InterUserJobInfo
from oneflow.python.framework.job_build_and_infer_error import JobBuildAndInferError
import oneflow
import oneflow_api.oneflow.core.job.placement as placement_cfg
oneflow_api = oneflow.oneflow_api
......@@ -441,7 +442,14 @@ def JobBuildAndInferCtx_MirroredBlobGetParallelConfFromProducerView(job_name, lb
error = text_format.Parse(error_str, error_util.ErrorProto())
if error.HasField("error_type"):
raise JobBuildAndInferError(error)
return text_format.Parse(parallel_conf_str, placement_pb.ParallelConf())
parallel_conf = text_format.Parse(parallel_conf_str, placement_pb.ParallelConf())
# Temporary transformation
parallel_conf_cfg = placement_cfg.ParallelConf()
parallel_conf_cfg.set_device_tag(parallel_conf.device_tag)
for device_name in parallel_conf.device_name:
parallel_conf_cfg.add_device_name(device_name)
return parallel_conf_cfg
def JobBuildAndInferCtx_GetStaticShape(job_name, lbn):
......@@ -541,11 +549,18 @@ def JobBuildAndInferCtx_GetParallelConfFromProducerView(job_name, lbn):
error = text_format.Parse(error_str, error_util.ErrorProto())
if error.HasField("error_type"):
raise JobBuildAndInferError(error)
return text_format.Parse(parallel_conf, placement_pb.ParallelConf())
parallel_conf = text_format.Parse(parallel_conf, placement_pb.ParallelConf())
# Temporary transformation
parallel_conf_cfg = placement_cfg.ParallelConf()
parallel_conf_cfg.set_device_tag(parallel_conf.device_tag)
for device_name in parallel_conf.device_name:
parallel_conf_cfg.add_device_name(device_name)
return parallel_conf_cfg
def GetMachine2DeviceIdListOFRecordFromParallelConf(parallel_conf):
serialized_parallel_conf = str(text_format.MessageToString(parallel_conf))
serialized_parallel_conf = str(parallel_conf)
(
ofrecord,
error_str,
......
......@@ -17,7 +17,7 @@ limitations under the License.
def GetDeviceTagAndMachineDeviceIds(parallel_conf):
machine_device_ids = []
for device_name in parallel_conf.device_name:
for device_name in list(parallel_conf.device_name()):
machine_device_ids.append(device_name)
device_tag = parallel_conf.device_tag
device_tag = parallel_conf.device_tag()
return device_tag, machine_device_ids
......@@ -24,6 +24,7 @@ import oneflow.python.framework.op_util as op_util
import oneflow.python.framework.session_context as session_ctx
import oneflow.python.framework.scope_symbol as scope_symbol
import oneflow
import oneflow_api.oneflow.core.job.placement as placement_cfg
class PlacementScope(object):
......@@ -79,7 +80,9 @@ def MakeParallelConf4Resource(device_tag, resource):
def MakeParallelConf(device_tag, machine_device_ids):
assert isinstance(machine_device_ids, collections.Sized)
device_names = []
parallel_conf = placement_cfg.ParallelConf()
parallel_conf.set_device_tag(device_tag)
for machine_device_id in machine_device_ids:
assert isinstance(
machine_device_id, str
......@@ -88,16 +91,13 @@ def MakeParallelConf(device_tag, machine_device_ids):
"machine_device_id: %s is not valid" % machine_device_id
)
pair = machine_device_id.split(":")
device_names.append("%s:%s" % (pair[0], pair[1]))
parallel_conf.add_device_name("%s:%s" % (pair[0], pair[1]))
parallel_conf = placement_pb.ParallelConf()
parallel_conf.device_tag = device_tag
parallel_conf.device_name.extend(device_names)
return parallel_conf
def MakeMachineId2DeviceIdList(parallel_conf):
parallel_conf_str = parallel_conf.SerializeToString()
parallel_conf_str = str(parallel_conf)
global _parallel_conf_str2ofrecord
if parallel_conf_str not in _parallel_conf_str2ofrecord:
ofrecord = c_api_util.GetMachine2DeviceIdListOFRecordFromParallelConf(
......
......@@ -56,9 +56,7 @@ class PythonCallback(oneflow_api.ForeignCallback):
def EagerInterpretCompletedOp(self, op_attribute, parallel_conf):
try:
# TODO(hanbinbin): str() will be removed after proto obj is replaced with cfg obj in python side
interpreter_callback.InterpretCompletedOp(
str(op_attribute), str(parallel_conf)
)
interpreter_callback.InterpretCompletedOp(str(op_attribute), parallel_conf)
except Exception as e:
print(traceback.format_exc())
raise e
......@@ -66,7 +64,7 @@ class PythonCallback(oneflow_api.ForeignCallback):
def EagerMirroredCast(self, op_attribute, parallel_conf):
try:
# TODO(hanbinbin): str() will be removed after proto obj is replaced with cfg obj in python side
interpreter_callback.MirroredCast(str(op_attribute), str(parallel_conf))
interpreter_callback.MirroredCast(str(op_attribute), parallel_conf)
except Exception as e:
print(traceback.format_exc())
raise e
......@@ -74,7 +72,7 @@ class PythonCallback(oneflow_api.ForeignCallback):
def EagerCastFromMirrored(self, op_attribute, parallel_conf):
try:
# TODO(hanbinbin): str() will be removed after proto obj is replaced with cfg obj in python side
interpreter_callback.CastFromMirrored(str(op_attribute), str(parallel_conf))
interpreter_callback.CastFromMirrored(str(op_attribute), parallel_conf)
except Exception as e:
print(traceback.format_exc())
raise e
......@@ -92,7 +90,7 @@ class PythonCallback(oneflow_api.ForeignCallback):
try:
# TODO(hanbinbin): str() will be removed after proto obj is replaced with cfg obj in python side
return interpreter_callback.MakeScopeSymbol(
str(job_conf), str(parallel_conf), is_mirrored
str(job_conf), parallel_conf, is_mirrored
)
except Exception as e:
print(traceback.format_exc())
......@@ -100,8 +98,7 @@ class PythonCallback(oneflow_api.ForeignCallback):
def MakeParallelDescSymbol(self, parallel_conf):
try:
# TODO(hanbinbin): str() will be removed after proto obj is replaced with cfg obj in python side
return interpreter_callback.MakeParallelDescSymbol(str(parallel_conf))
return interpreter_callback.MakeParallelDescSymbol(parallel_conf)
except Exception as e:
print(traceback.format_exc())
raise e
......
......@@ -31,7 +31,7 @@ import oneflow.python.eager.vm_util as vm_util
import oneflow.python.eager.gradient_util as gradient_util
import oneflow.python.eager.boxing_util as boxing_util
import oneflow.python.framework.op_arg_util as op_arg_util
import oneflow.core.job.placement_pb2 as placement_pb
import oneflow_api.oneflow.core.job.placement as placement_cfg
import traceback
import sys
......@@ -367,9 +367,11 @@ class EagerBlobTrait(object):
consistent_blob_name = None
def BoxingToSingleDevice(builder):
parallel_conf = placement_pb.ParallelConf()
parallel_conf.device_tag = blob_object.parallel_desc_symbol.device_tag
parallel_conf.device_name.append("{}:{}".format(0, 0))
parallel_conf = placement_cfg.ParallelConf()
parallel_conf.set_device_tag(
blob_object.parallel_desc_symbol.device_tag
)
parallel_conf.add_device_name("{}:{}".format(0, 0))
tmp_parallel_desc_symbol = builder.GetParallelDescSymbol(parallel_conf)
tmp_op_arg_parallel_attr = op_arg_util.OpArgParallelAttribute(
tmp_parallel_desc_symbol,
......@@ -377,7 +379,8 @@ class EagerBlobTrait(object):
blob_object.op_arg_parallel_attr.opt_mirrored_parallel,
)
with oneflow.scope.placement(
self.parallel_conf.device_tag, list(self.parallel_conf.device_name)
self.parallel_conf.device_tag(),
list(self.parallel_conf.device_name()),
):
tmp_blob_object = boxing_util.BoxingTo(
builder, blob_object, tmp_op_arg_parallel_attr
......
......@@ -18,8 +18,8 @@ from __future__ import absolute_import
from oneflow.python.eager.symbol import Symbol
import oneflow.python.eager.symbol_storage as symbol_storage
import oneflow.python.framework.parallel_conf_util as parallel_conf_util
import oneflow.core.job.placement_pb2 as placement_pb
import oneflow.core.job.scope_pb2 as scope_pb
import oneflow_api.oneflow.core.job.placement as placement_cfg
import collections
import re
......@@ -140,7 +140,9 @@ def BuildInitialScope(
def MakeParallelConf(device_tag, machine_device_ids):
assert isinstance(machine_device_ids, (list, tuple))
device_names = []
parallel_conf = placement_cfg.ParallelConf()
parallel_conf.set_device_tag(device_tag)
for machine_device_id in machine_device_ids:
assert isinstance(
machine_device_id, str
......@@ -148,9 +150,6 @@ def MakeParallelConf(device_tag, machine_device_ids):
assert re.match("^\d+:\d+(-\d+)?$", machine_device_id) is not None, (
"machine_device_id: %s is not valid" % machine_device_id
)
device_names.append(machine_device_id)
parallel_conf.add_device_name(machine_device_id)
parallel_conf = placement_pb.ParallelConf()
parallel_conf.device_tag = device_tag
parallel_conf.device_name.extend(device_names)
return parallel_conf
......@@ -17,7 +17,6 @@ from __future__ import absolute_import
import re
import oneflow.core.job.placement_pb2 as placement_proto_pb
import oneflow.core.operator.op_conf_pb2 as op_conf_util
import oneflow.core.register.logical_blob_id_pb2 as logical_blob_id_util
import oneflow.python.framework.c_api_util as c_api_util
......@@ -31,6 +30,7 @@ import oneflow.python.framework.session_context as session_ctx
import oneflow.python.framework.scope_util as scope_util
import oneflow.python.eager.vm_util as vm_util
import oneflow.python.eager.blob_register as blob_register_util
import oneflow_api.oneflow.core.job.placement as placement_cfg
blob_register = blob_register_util.GetDefaultBlobRegister()
......@@ -102,7 +102,7 @@ def _GetReturnOpConfAndOutLbiAndScope(remote_blob, allow_cpu_return_op=True):
lbi.op_name = op_conf.name
lbi.blob_name = "out"
parallel_conf = placement_proto_pb.ParallelConf()
parallel_conf = placement_cfg.ParallelConf()
parallel_conf.CopyFrom(remote_blob.parallel_conf)
def BuildScope(old_scope, builder):
......
......@@ -159,7 +159,7 @@ def LazyConsistentWatch(blob_watched, handler):
op_conf.name = id_util.UniqueStr("ForeignWatch_")
setattr(op_conf.foreign_watch_conf, "in", blob_watched.unique_name)
op_conf.foreign_watch_conf.handler_uuid = handler_uuid
device_name = blob_watched.parallel_conf.device_name[0]
device_name = blob_watched.parallel_conf.device_name(0)
with oneflow.scope.placement("cpu", "0:0"):
compile_context.CurJobAddOp(op_conf)
watcher_util.BindUuidAndHandler(handler_uuid, blob_watched, handler)
......
......@@ -41,8 +41,8 @@ def _check_cpu_only_relu_device(test_case, verbose=False):
def cpu_only_relu_job(x_def: oft.Numpy.Placeholder(shape=(2, 5), dtype=flow.float)):
y = _cpu_only_relu(x_def)
if verbose:
print("cpu_only_relu output device", y.parallel_conf.device_tag)
test_case.assertTrue("cpu" in y.parallel_conf.device_tag)
print("cpu_only_relu output device", y.parallel_conf.device_tag())
test_case.assertTrue("cpu" in y.parallel_conf.device_tag())
return y
cpu_only_relu_job(np.random.rand(2, 5).astype(np.single)).get()
......@@ -59,7 +59,7 @@ def _check_non_cpu_only_relu_device(test_case):
with flow.scope.placement("gpu", "0:0"):
y = flow.math.relu(x_def)
test_case.assertTrue("gpu" in y.parallel_conf.device_tag)
test_case.assertTrue("gpu" in y.parallel_conf.device_tag())
return y
......
......@@ -598,8 +598,8 @@ bool Const{{ util.class_name(cls) }}::_{{ util.class_name(cls) }}_::operator<(co
}
using _{{ util.class_name(cls) }}_ = Const{{ util.class_name(cls) }}::_{{ util.class_name(cls) }}_;
Const{{ util.class_name(cls) }}::Const{{ util.class_name(cls) }}(const ::std::shared_ptr<::std::unique_ptr<_{{ util.class_name(cls) }}_>>& data): data_(data) {}
Const{{ util.class_name(cls) }}::Const{{ util.class_name(cls) }}(): data_(::std::make_shared<::std::unique_ptr<_{{ util.class_name(cls) }}_>>()) {}
Const{{ util.class_name(cls) }}::Const{{ util.class_name(cls) }}(const ::std::shared_ptr<_{{ util.class_name(cls) }}_>& data): data_(data) {}
Const{{ util.class_name(cls) }}::Const{{ util.class_name(cls) }}(): data_(::std::make_shared<_{{ util.class_name(cls) }}_>()) {}
Const{{ util.class_name(cls) }}::Const{{ util.class_name(cls) }}(const {{ util.module_package_namespace(module) }}::{{ util.class_name(cls) }}& proto_{{ util.class_name(cls).lower() }}) {
BuildFromProto(proto_{{ util.class_name(cls).lower() }});
}
......@@ -616,7 +616,7 @@ void Const{{ util.class_name(cls) }}::ToProto(PbMessage* proto_{{ util.class_nam
}
bool Const{{ util.class_name(cls) }}::__Empty__() const {
return !*data_;
return !data_;
}
int Const{{ util.class_name(cls) }}::FieldNumber4FieldName(const ::std::string& field_name) const {
......@@ -778,24 +778,21 @@ bool Const{{ util.class_name(cls) }}::operator<(const Const{{ util.class_name(cl
return *__SharedPtrOrDefault__() < *other.__SharedPtrOrDefault__();
}
const ::std::unique_ptr<_{{ util.class_name(cls) }}_>& Const{{ util.class_name(cls) }}::__SharedPtrOrDefault__() const {
if (*data_) { return *data_; }
static const ::std::unique_ptr<_{{ util.class_name(cls) }}_> default_ptr(new _{{ util.class_name(cls) }}_());
const ::std::shared_ptr<_{{ util.class_name(cls) }}_>& Const{{ util.class_name(cls) }}::__SharedPtrOrDefault__() const {
if (data_) { return data_; }
static const ::std::shared_ptr<_{{ util.class_name(cls) }}_> default_ptr = std::make_shared<_{{ util.class_name(cls) }}_>();
return default_ptr;
}
const ::std::unique_ptr<_{{ util.class_name(cls) }}_>& Const{{ util.class_name(cls) }}::__SharedPtr__() {
return *__SharedUniquePtr__();
}
const ::std::shared_ptr<::std::unique_ptr<_{{ util.class_name(cls) }}_>>& Const{{ util.class_name(cls) }}::__SharedUniquePtr__() {
if (!*data_) { data_->reset(new _{{ util.class_name(cls) }}_()); }
const ::std::shared_ptr<_{{ util.class_name(cls) }}_>& Const{{ util.class_name(cls) }}::__SharedPtr__() {
if (!data_) { data_.reset(new _{{ util.class_name(cls) }}_()); }
return data_;
}
// use a protected member method to avoid someone change member variable(data_) by Const{{ util.class_name(cls) }}
void Const{{ util.class_name(cls) }}::BuildFromProto(const PbMessage& proto_{{ util.class_name(cls).lower() }}) {
data_ = ::std::make_shared<::std::unique_ptr<_{{ util.class_name(cls) }}_>>(new _{{ util.class_name(cls) }}_(dynamic_cast<const {{ util.module_package_namespace(module) }}::{{ util.class_name(cls) }}&>(proto_{{ util.class_name(cls).lower() }})));
data_ = ::std::make_shared<_{{ util.class_name(cls) }}_>(dynamic_cast<const {{ util.module_package_namespace(module) }}::{{ util.class_name(cls) }}&>(proto_{{ util.class_name(cls).lower() }}));
}
{{ util.class_name(cls) }}::{{ util.class_name(cls) }}(const ::std::shared_ptr<::std::unique_ptr<_{{ util.class_name(cls) }}_>>& data)
{{ util.class_name(cls) }}::{{ util.class_name(cls) }}(const ::std::shared_ptr<_{{ util.class_name(cls) }}_>& data)
: Const{{ util.class_name(cls) }}(data) {}
{{ util.class_name(cls) }}::{{ util.class_name(cls) }}(const {{ util.class_name(cls) }}& other) { CopyFrom(other); }
// enable nothrow for ::std::vector<{{ util.class_name(cls) }}> resize
......@@ -828,13 +825,13 @@ bool {{ util.class_name(cls) }}::operator<(const {{ util.class_name(cls) }}& oth
return *__SharedPtrOrDefault__() < *other.__SharedPtrOrDefault__();
}
void {{ util.class_name(cls) }}::Clear() {
if (data_) { data_->reset(); }
if (data_) { data_.reset(); }
}
void {{ util.class_name(cls) }}::CopyFrom(const {{ util.class_name(cls) }}& other) {
if (other.__Empty__()) {
Clear();
} else {
__SharedPtr__()->CopyFrom(**other.data_);
__SharedPtr__()->CopyFrom(*other.data_);
}
}
{{ util.class_name(cls) }}& {{ util.class_name(cls) }}::operator=(const {{ util.class_name(cls) }}& other) {
......@@ -939,7 +936,7 @@ const {{ util.field_map_container_name(field) }} & {{ util.class_name(cls) }}::{
{% endfor %}{# field #}
::std::shared_ptr<{{ util.class_name(cls) }}> {{ util.class_name(cls) }}::__SharedMutable__() {
return ::std::make_shared<{{ util.class_name(cls) }}>(__SharedUniquePtr__());
return ::std::make_shared<{{ util.class_name(cls) }}>(__SharedPtr__());
}
{% endif %}{# cls is not entry #}
{% endfor %}{# cls #}
......
......@@ -229,12 +229,12 @@ class Const{{ util.class_name(cls) }} : public ::oneflow::cfg::Message {
bool operator<(const _{{ util.class_name(cls) }}_& other) const;
};
Const{{ util.class_name(cls) }}(const ::std::shared_ptr<::std::unique_ptr<_{{ util.class_name(cls) }}_>>& data);
Const{{ util.class_name(cls) }}(const ::std::shared_ptr<_{{ util.class_name(cls) }}_>& data);
Const{{ util.class_name(cls) }}(const Const{{ util.class_name(cls) }}&);
Const{{ util.class_name(cls) }}(Const{{ util.class_name(cls) }}&&) noexcept;
Const{{ util.class_name(cls) }}();
Const{{ util.class_name(cls) }}(const {{ util.module_package_namespace(module) }}::{{ util.class_name(cls) }}& proto_{{ util.class_name(cls).lower() }});
~Const{{ util.class_name(cls) }}() override;
virtual ~Const{{ util.class_name(cls) }}() override;
using PbMessage = ::google::protobuf::Message;
void ToProto(PbMessage* proto_{{ util.class_name(cls).lower() }}) const override;
......@@ -302,33 +302,29 @@ class Const{{ util.class_name(cls) }} : public ::oneflow::cfg::Message {
::std::shared_ptr<Const{{ util.class_name(cls) }}> __SharedConst__() const;
int64_t __Id__() const;
// the data of `this` will be moved to the result which is mutable
::std::shared_ptr<{{ util.class_name(cls) }}> __Move__();
public:
bool operator==(const Const{{ util.class_name(cls) }}& other) const;
bool operator<(const Const{{ util.class_name(cls) }}& other) const;
protected:
const ::std::unique_ptr<_{{ util.class_name(cls) }}_>& __SharedPtrOrDefault__() const;
const ::std::unique_ptr<_{{ util.class_name(cls) }}_>& __SharedPtr__();
const ::std::shared_ptr<::std::unique_ptr<_{{ util.class_name(cls) }}_>>& __SharedUniquePtr__();
const ::std::shared_ptr<_{{ util.class_name(cls) }}_>& __SharedPtrOrDefault__() const;
const ::std::shared_ptr<_{{ util.class_name(cls) }}_>& __SharedPtr__();
// use a protected member method to avoid someone change member variable(data_) by Const{{ util.class_name(cls) }}
void BuildFromProto(const PbMessage& proto_{{ util.class_name(cls).lower() }});
// use ::std::shared_ptr for sharing reference between mutable object and const object
// use ::std::unique_ptr for moving ownership
::std::shared_ptr<::std::unique_ptr<_{{ util.class_name(cls) }}_>> data_;
::std::shared_ptr<_{{ util.class_name(cls) }}_> data_;
};
class {{ util.class_name(cls) }} final : public Const{{ util.class_name(cls) }} {
public:
{{ util.class_name(cls) }}(const ::std::shared_ptr<::std::unique_ptr<_{{ util.class_name(cls) }}_>>& data);
{{ util.class_name(cls) }}(const ::std::shared_ptr<_{{ util.class_name(cls) }}_>& data);
{{ util.class_name(cls) }}(const {{ util.class_name(cls) }}& other);
// enable nothrow for ::std::vector<{{ util.class_name(cls) }}> resize
{{ util.class_name(cls) }}({{ util.class_name(cls) }}&&) noexcept;
{{ util.class_name(cls) }}();
{{ util.class_name(cls) }}(const {{ util.module_package_namespace(module) }}::{{ util.class_name(cls) }}& proto_{{ util.class_name(cls).lower() }});
~{{ util.class_name(cls) }}();
~{{ util.class_name(cls) }}() override;
void InitFromProto(const PbMessage& proto_{{ util.class_name(cls).lower() }}) override;
......@@ -498,12 +494,6 @@ class {{ util.field_map_container_name(field) }} final : public Const{{ util.fie
{% endfor %}{# field #}
inline ::std::shared_ptr<{{ util.class_name(cls) }}> Const{{ util.class_name(cls) }}::__Move__() {
if (__Empty__()) { return ::std::make_shared<{{ util.class_name(cls) }}>(); }
auto data = ::std::make_shared<::std::unique_ptr<_{{ util.class_name(cls) }}_>>();
*data = ::std::move(*data_);
return ::std::make_shared<{{ util.class_name(cls) }}>(data);
}
{% endif %}{# cls is not entry #}
{% endfor %}{# cls #}
......
......@@ -133,8 +133,6 @@ ONEFLOW_CFG_PYBIND11_MODULE("{{ util.module_get_python_module_path(module) }}",
{% if not util.class_is_map_entry(cls) %}
{
pybind11::class_<Const{{ util.class_name(cls) }}, std::shared_ptr<Const{{ util.class_name(cls) }}>> registry(m, "Const{{ util.class_name(cls) }}");
// the data of `self` will be moved to the result which is always mutable
registry.def("Move", &Const{{ util.class_name(cls) }}::__Move__);
registry.def("__id__", &{{ util.module_package_cfg_namespace(module) }}::{{ util.class_name(cls) }}::__Id__);
registry.def(pybind11::self == pybind11:: self);
registry.def(pybind11::self < pybind11:: self);
......@@ -198,7 +196,6 @@ ONEFLOW_CFG_PYBIND11_MODULE("{{ util.module_get_python_module_path(module) }}",
registry.def("Clear", &{{ util.module_package_cfg_namespace(module) }}::{{ util.class_name(cls) }}::Clear);
registry.def("CopyFrom", (void ({{ util.module_package_cfg_namespace(module) }}::{{ util.class_name(cls) }}::*)(const Const{{ util.class_name(cls) }}&))&{{ util.module_package_cfg_namespace(module) }}::{{ util.class_name(cls) }}::CopyFrom);
registry.def("CopyFrom", (void ({{ util.module_package_cfg_namespace(module) }}::{{ util.class_name(cls) }}::*)(const {{ util.module_package_cfg_namespace(module) }}::{{ util.class_name(cls) }}&))&{{ util.module_package_cfg_namespace(module) }}::{{ util.class_name(cls) }}::CopyFrom);
registry.def("Move", &{{ util.module_package_cfg_namespace(module) }}::{{ util.class_name(cls) }}::__Move__);
registry.def("__id__", &{{ util.module_package_cfg_namespace(module) }}::{{ util.class_name(cls) }}::__Id__);
registry.def(pybind11::self == pybind11:: self);
registry.def(pybind11::self < pybind11:: self);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册