diff --git a/oneflow/python/eager/boxing_hob.py b/oneflow/python/eager/boxing_hob.py index 59a1eec0e95f4ac336ac03df519e3116f1ae6139..709e9e1a6495e6bc2b5260d2a8b5b7258d1ad91c 100644 --- a/oneflow/python/eager/boxing_hob.py +++ b/oneflow/python/eager/boxing_hob.py @@ -101,8 +101,8 @@ class ComposeHob(BoolFunctor): return ctx.composer2middle_op_arg_parallel_attr[self] -@bool_functor("MasterMachineOnly") -def MasterMachineOnly(ctx): +@bool_functor("SingleMachine") +def SingleMachine(ctx): blob_device_ids = ( ctx.produced_blob_object.parallel_desc_symbol.machine_id2device_id_list ) diff --git a/oneflow/python/eager/boxing_util.py b/oneflow/python/eager/boxing_util.py index 2f3a0e07f49a1f2f6784859f83bd4eb6c30d179e..3d847c5f47a641cb2f975b3fb5aa1c64e8dac5d8 100644 --- a/oneflow/python/eager/boxing_util.py +++ b/oneflow/python/eager/boxing_util.py @@ -232,6 +232,43 @@ def CopyHD(builder, produced_blob_object, consumer_op_arg_parallel_attr): return BuildCopyHdInstruction(builder, produced_blob_object, op_device_tag) +BlobIsPartialSum = boxing_hob.producer_sbp_parallel.HasField("partial_sum_parallel") +OpArgIsBroadcast = boxing_hob.consumer_sbp_parallel.HasField("broadcast_parallel") + + +MatchInterNodeOneToMany = ( + ~boxing_hob.SingleMachine + & (boxing_hob.producer_parallel_desc.device_tag == "cpu") + & (boxing_hob.consumer_parallel_desc.device_tag == "cpu") + & (boxing_hob.producer_parallel_desc.parallel_num == 1) + & (boxing_hob.consumer_parallel_desc.parallel_num > 1) + & OpArgIsBroadcast +) + + +@boxing_condition(MatchInterNodeOneToMany) +def InterNodeOneToMany(builder, produced_blob_object, consumer_op_arg_parallel_attr): + out_blobs = [] + consumer_dev_ids = ( + consumer_op_arg_parallel_attr.parallel_desc_symbol.machine_id2device_id_list + ) + for machine_id, device_ids in consumer_dev_ids.items(): + for device_id in device_ids: + parallel_conf = placement_pb.ParallelConf() + parallel_conf.device_tag = "cpu" + parallel_conf.device_name.append("%s:%s" % (machine_id, device_id)) + parallel_desc_symbol = builder.GetParallelDescSymbol(parallel_conf) + out_blob = builder.Build121To(produced_blob_object, parallel_desc_symbol) + out_blobs.append(out_blob) + + return PackPhysicalBoxingBlobObjectsToLogical( + builder, + out_blobs, + consumer_op_arg_parallel_attr, + produced_blob_object.op_arg_blob_attr, + ) + + MatchInterNodeOneToOne = ( (boxing_hob.producer_parallel_desc.device_tag == "cpu") & (boxing_hob.consumer_parallel_desc.device_tag == "cpu") @@ -250,13 +287,9 @@ MatchInterNodeOneToOne = ( @boxing_condition(MatchInterNodeOneToOne) def InterNodeOneToOne(builder, produced_blob_object, consumer_op_arg_parallel_attr): - receive_blob_object = _MakeNewBlobObjectLike( - builder, - produced_blob_object, - consumer_op_arg_parallel_attr.parallel_desc_symbol, + return builder.Build121To( + produced_blob_object, consumer_op_arg_parallel_attr.parallel_desc_symbol ) - builder.Build121AssignInstruction(receive_blob_object, produced_blob_object) - return receive_blob_object MatchCpuBroadcastOneToOne = ( @@ -316,12 +349,8 @@ def VerboseOptionalBoxing(boxing_method): return opt_boxing_method -BlobIsPartialSum = boxing_hob.producer_sbp_parallel.HasField("partial_sum_parallel") -OpArgIsBroadcast = boxing_hob.consumer_sbp_parallel.HasField("broadcast_parallel") - - MatchNcclAllReduce = ( - boxing_hob.MasterMachineOnly + boxing_hob.SingleMachine & (boxing_hob.producer_parallel_desc.device_tag == "gpu") & (boxing_hob.producer_parallel_desc == boxing_hob.consumer_parallel_desc) & (boxing_hob.consumer_parallel_desc.parallel_num > 1) @@ -335,7 +364,7 @@ def GpuNcclAllReduce(builder, produced_blob_object, consumer_op_arg_parallel_att parallel_conf = consumer_op_arg_parallel_attr.parallel_desc_symbol.parallel_conf bn_in_op2blob_object = dict(in_0=produced_blob_object) op_attribute = _GetEagerNcclAllReduce(parallel_conf, bn_in_op2blob_object) - builder.BoxingStatelessCall( + builder.NoBoxingStatelessCall( op_attribute, parallel_conf=parallel_conf, bn_in_op2blob_object=bn_in_op2blob_object, @@ -373,8 +402,7 @@ MatchConcatManyToSplitMany = ( MatchNaiveCpuSplitToSplit = ( - boxing_hob.MasterMachineOnly - & (boxing_hob.producer_parallel_desc.device_tag == "cpu") + (boxing_hob.producer_parallel_desc.device_tag == "cpu") & (boxing_hob.consumer_parallel_desc.device_tag == "cpu") & (MatchSplitOneToMany | MatchConcatManyToOne | MatchConcatManyToSplitMany) ) @@ -391,8 +419,7 @@ def NaiveCpuSplitToSplit(builder, produced_blob_object, consumer_op_arg_parallel MatchNaiveCpuPartialSumToSplit = ( - boxing_hob.MasterMachineOnly - & (boxing_hob.producer_parallel_desc.device_tag == "cpu") + (boxing_hob.producer_parallel_desc.device_tag == "cpu") & (boxing_hob.consumer_parallel_desc.device_tag == "cpu") & (boxing_hob.producer_parallel_desc.parallel_num > 1) & boxing_hob.producer_sbp_parallel.HasField("partial_sum_parallel") @@ -424,17 +451,13 @@ def NaiveCpuRefPhysicalBlobObjectsScope( physical_in_blob_objects = UnpackLogicalBoxingBlobObjectToPhysical( builder, produced_blob_object ) - out_parallel_num = consumer_op_arg_parallel_attr.parallel_desc_symbol.parallel_num + consumer_parallel_desc_symbol = consumer_op_arg_parallel_attr.parallel_desc_symbol + out_parallel_num = consumer_parallel_desc_symbol.parallel_num boxing_parallel_desc_symbol = GetConcatSplitBoxingParallelDescSymbol( builder, - produced_blob_object.parallel_desc_symbol, + consumer_parallel_desc_symbol, max(len(physical_in_blob_objects), out_parallel_num), ) - physical_in_blob_objects = RefBlobObjectWithParallelDesc( - builder, - physical_in_blob_objects, - [boxing_parallel_desc_symbol] * len(physical_in_blob_objects), - ) physical_output_blob_objects = get_physical_out_blob_objects( builder=builder, produced_blob_object=produced_blob_object, @@ -520,7 +543,7 @@ def BuildNaiveCpuBoxing( bn_in_op2blob_object = {} for i in range(len(physical_in_blob_objects)): bn_in_op2blob_object["in_%s" % i] = physical_in_blob_objects[i] - builder.BoxingStatelessCall( + builder.NoBoxingStatelessCall( op_attribute, parallel_conf=boxing_parallel_desc_symbol.parallel_conf, bn_in_op2blob_object=bn_in_op2blob_object, @@ -585,7 +608,7 @@ def UnpackLogicalBoxingBlobObjectToPhysical(builder, produced_blob_object): MatchCpuBroadcastOneToMany = ( - boxing_hob.MasterMachineOnly + boxing_hob.SingleMachine & (boxing_hob.producer_parallel_desc.device_tag == "cpu") & (boxing_hob.consumer_parallel_desc.device_tag == "cpu") & boxing_hob.ProducerDevicesContainedInConsumerDevices @@ -605,8 +628,7 @@ def CpuBroadcastOneToMany(builder, produced_blob_object, consumer_op_arg_paralle MatchBroadcastManyToOne = ( - boxing_hob.MasterMachineOnly - & ( + ( boxing_hob.producer_parallel_desc.device_tag == boxing_hob.consumer_parallel_desc.device_tag ) @@ -673,7 +695,7 @@ def _BuildCopyInstruction(builder, produced_blob_object, op_conf, to_device_tag) assert to_device_tag != x_device_tag, (to_device_tag, x_device_tag) if to_device_tag == "cpu" and x_device_tag == "gpu": x_parallel_conf = produced_blob_object.parallel_desc_symbol.parallel_conf - builder.BoxingCudaD2HStatelessCall( + builder.NoBoxingCudaD2HStatelessCall( op_attribute, x_parallel_conf, bn_in_op2blob_object=bn_in_op2blob_object ) elif to_device_tag == "gpu" and x_device_tag == "cpu": @@ -682,7 +704,7 @@ def _BuildCopyInstruction(builder, produced_blob_object, op_conf, to_device_tag) ) out_parallel_conf = out_parallel_desc_symbol.parallel_conf with builder.CudaHostPinBlob(produced_blob_object): - builder.BoxingCudaH2DStatelessCall( + builder.NoBoxingCudaH2DStatelessCall( op_attribute, out_parallel_conf, bn_in_op2blob_object=bn_in_op2blob_object, @@ -721,19 +743,19 @@ def BuildAssignInstruction(builder, ref_blob_object, value_blob_object, op_conf) bn_in_op2blob_object = {"ref": ref_blob_object, "value": value_blob_object} op_attribute = op_infer_util.Infer(op_conf, bn_in_op2blob_object) if ref_device_tag == value_device_tag: - builder.BoxingStatelessCall( + builder.NoBoxingStatelessCall( op_attribute, parallel_conf=ref_parallel_conf, bn_in_op2blob_object=bn_in_op2blob_object, ) elif ref_device_tag == "cpu" and value_device_tag == "gpu": value_parallel_conf = value_blob_object.parallel_desc_symbol.parallel_conf - builder.BoxingCudaD2HStatelessCall( + builder.NoBoxingCudaD2HStatelessCall( op_attribute, value_parallel_conf, bn_in_op2blob_object=bn_in_op2blob_object ) elif ref_device_tag == "gpu" and value_device_tag == "cpu": with builder.CudaHostPinBlob(value_blob_object): - builder.BoxingCudaH2DStatelessCall( + builder.NoBoxingCudaH2DStatelessCall( op_attribute, ref_parallel_conf, bn_in_op2blob_object=bn_in_op2blob_object, @@ -766,26 +788,6 @@ def _GetEagerNcclAllReduce(parallel_conf, ibn2blob_object): return op_infer_util.Infer(op_conf, ibn2blob_object) -def _MakeNewBlobObjectLike(builder, blob_object, new_parallel_desc_symbol): - op_conf = op_conf_pb.OperatorConf() - op_conf.name = id_util.UniqueStr("Input") - op_conf.device_tag = new_parallel_desc_symbol.device_tag - op_conf.input_conf.out = "out" - blob_object.op_arg_parallel_attr.DumpToToInterfaceBlobConf( - op_conf.input_conf.blob_conf - ) - blob_object.op_arg_blob_attr.DumpToToInterfaceBlobConf(op_conf.input_conf.blob_conf) - op_conf.scope_symbol_id = oneflow.current_scope().symbol_id - upstream_signature = op_attribute_pb.OpNodeSignature() - op_attribute = c_api_util.InferOpConf(op_conf, upstream_signature) - parallel_conf = new_parallel_desc_symbol.parallel_conf - bn_in_op2blob_object = {} - builder.BoxingStatelessCall( - op_attribute, parallel_conf, bn_in_op2blob_object=bn_in_op2blob_object - ) - return bn_in_op2blob_object["out"] - - NcclAllReduce = Sequential( boxing_middle.BoxingToMiddle( GpuNcclAllReduce, @@ -823,6 +825,20 @@ BoxingInterNodeOneToOne = Sequential( OptionalBoxing(CopyH2D), ) +BoxingInterNodeOneToMany = Sequential( + boxing_middle.BoxingToMiddle( + OptionalBoxing(CopyD2H), + boxing_middle.ReplaceProducerDeviceTag("cpu"), + boxing_middle.ProducerSbpParallel, + ), + boxing_middle.BoxingToMiddle( + InterNodeOneToMany, + boxing_middle.ReplaceConsumerDeviceTag("cpu"), + boxing_middle.ConsumerSbpParallel, + ), + OptionalBoxing(CopyH2D), +) + conditional_function_table = [ CopyH2D, CopyD2H, @@ -830,6 +846,7 @@ conditional_function_table = [ # one to one BoxingIntraNodeOneToOne, BoxingInterNodeOneToOne, + BoxingInterNodeOneToMany, # B -> B BroadcastManyToOne, Sequential( diff --git a/oneflow/python/eager/vm_util.py b/oneflow/python/eager/vm_util.py index abc9a71dc97813467b8d27c28deffd7b2dfa00d7..ad5c950c94058499cb692dbcf3aef40a7447bc93 100644 --- a/oneflow/python/eager/vm_util.py +++ b/oneflow/python/eager/vm_util.py @@ -18,12 +18,12 @@ from __future__ import absolute_import import re from contextlib import contextmanager -import oneflow.core.eager.eager_symbol_pb2 as eager_symbol_util -import oneflow.core.job.placement_pb2 as placement_pb_util -import oneflow.core.operator.op_conf_pb2 as op_conf_util +import oneflow.core.eager.eager_symbol_pb2 as eager_symbol_pb +import oneflow.core.job.placement_pb2 as placement_pb +import oneflow.core.operator.op_conf_pb2 as op_conf_pb import oneflow.core.operator.op_attribute_pb2 as op_attribute_pb +import oneflow.core.vm.instruction_pb2 as instr_pb import oneflow.core.register.blob_desc_pb2 as blob_desc_pb -import oneflow.core.vm.instruction_pb2 as instr_util import oneflow.python.eager.blob_cache as blob_cache_util import oneflow.python.eager.boxing_util as boxing_util import oneflow.python.eager.object as object_util @@ -86,8 +86,8 @@ class InstructionsBuilder(object): ): self.id_generator_ = id_generator self.release_object_ = release_object - assert isinstance(instruction_list, instr_util.InstructionListProto) - assert isinstance(eager_symbol_list, eager_symbol_util.EagerSymbolList) + assert isinstance(instruction_list, instr_pb.InstructionListProto) + assert isinstance(eager_symbol_list, eager_symbol_pb.EagerSymbolList) self.instruction_list_ = instruction_list self.eager_symbol_list_ = eager_symbol_list @@ -99,9 +99,12 @@ class InstructionsBuilder(object): bn_in_op2blob_object=bn_in_op2blob_object, ) + def FetchDelegateBlobObject(x_blob_object, op_arg_parallel_attr): + return boxing_util.BoxingTo(self, x_blob_object, op_arg_parallel_attr) + def GetDelegateBlobObject(blob_object, op_arg_parallel_attr): return _FindOrCreateDelegateBlobObject( - self, blob_object, op_arg_parallel_attr + self, FetchDelegateBlobObject, blob_object, op_arg_parallel_attr ) self._StatelessCall( @@ -113,7 +116,9 @@ class InstructionsBuilder(object): get_delegate_blob_object=GetDelegateBlobObject, ) - def BoxingStatelessCall(self, op_attribute, parallel_conf, bn_in_op2blob_object={}): + def NoBoxingStatelessCall( + self, op_attribute, parallel_conf, bn_in_op2blob_object={} + ): op_parallel_desc_sym = self.GetParallelDescSymbol(parallel_conf) self._CheckRefInBlobObjectParallelDesc( op_attribute, @@ -121,8 +126,27 @@ class InstructionsBuilder(object): bn_in_op2blob_object=bn_in_op2blob_object, ) - def GetDirectBlobObject(blob_object, op_arg_parallel_attr): - return blob_object + def FetchDelegateBlobObject(blob_object, op_arg_parallel_attr): + from_pd = blob_object.parallel_desc_symbol + to_pd = op_arg_parallel_attr.parallel_desc_symbol + if from_pd == to_pd: + return blob_object + assert from_pd.device_tag == "cpu" + assert to_pd.device_tag == "cpu" + assert from_pd.parallel_num == to_pd.parallel_num + from_machine_ids = from_pd.machine_id2device_id_list.keys() + to_machine_ids = to_pd.machine_id2device_id_list.keys() + if ( + len(from_pd.machine_id2device_id_list) == from_pd.parallel_num + and from_machine_ids == to_machine_ids + ): + return self.BroadcastBlobReference(blob_object, to_pd) + return self.Build121To(blob_object, to_pd) + + def GetDirectOr121BlobObject(blob_object, op_arg_parallel_attr): + return _FindOrCreateDelegateBlobObject( + self, FetchDelegateBlobObject, blob_object, op_arg_parallel_attr + ) self._StatelessCall( "compute", @@ -130,10 +154,10 @@ class InstructionsBuilder(object): op_parallel_desc_sym=op_parallel_desc_sym, blob_parallel_desc_sym=op_parallel_desc_sym, bn_in_op2blob_object=bn_in_op2blob_object, - get_delegate_blob_object=GetDirectBlobObject, + get_delegate_blob_object=GetDirectOr121BlobObject, ) - def BoxingCudaD2HStatelessCall( + def NoBoxingCudaD2HStatelessCall( self, op_attribute, in_parallel_conf, bn_in_op2blob_object={} ): op_parallel_desc_sym = self.GetParallelDescSymbol(in_parallel_conf) @@ -158,7 +182,7 @@ class InstructionsBuilder(object): get_delegate_blob_object=GetDirectBlobObject, ) - def BoxingCudaH2DStatelessCall( + def NoBoxingCudaH2DStatelessCall( self, op_attribute, out_parallel_conf, bn_in_op2blob_object={} ): op_parallel_desc_sym = self.GetParallelDescSymbol(out_parallel_conf) @@ -180,6 +204,26 @@ class InstructionsBuilder(object): get_delegate_blob_object=GetDirectBlobObject, ) + def RawStatelessCall(self, op_attribute, parallel_conf, bn_in_op2blob_object={}): + op_parallel_desc_sym = self.GetParallelDescSymbol(parallel_conf) + self._CheckRefInBlobObjectParallelDesc( + op_attribute, + op_parallel_desc_sym, + bn_in_op2blob_object=bn_in_op2blob_object, + ) + + def GetDirectBlobObject(blob_object, op_arg_parallel_attr): + return blob_object + + self._StatelessCall( + "compute", + op_attribute, + op_parallel_desc_sym=op_parallel_desc_sym, + blob_parallel_desc_sym=op_parallel_desc_sym, + bn_in_op2blob_object=bn_in_op2blob_object, + get_delegate_blob_object=GetDirectBlobObject, + ) + def StatefulCall(self, op_attribute, opkernel_object, bn_in_op2blob_object={}): op_parallel_desc_sym = opkernel_object.parallel_desc_symbol parallel_sig = op_attribute.parallel_signature @@ -191,9 +235,12 @@ class InstructionsBuilder(object): bn_in_op2blob_object=bn_in_op2blob_object, ) + def FetchDelegateBlobObject(x_blob_object, op_arg_parallel_attr): + return boxing_util.BoxingTo(self, x_blob_object, op_arg_parallel_attr) + def GetDelegateBlobObject(blob_object, op_arg_parallel_attr): return _FindOrCreateDelegateBlobObject( - self, blob_object, op_arg_parallel_attr + self, FetchDelegateBlobObject, blob_object, op_arg_parallel_attr ) self._StatefulCall( @@ -209,7 +256,7 @@ class InstructionsBuilder(object): def InsertRemoveForeignCallbackInstruction(self, object_id, callback): unique_callback_id = python_callback.GetIdForRegisteredCallback(callback) - instruction = instr_util.InstructionProto() + instruction = instr_pb.InstructionProto() instruction.instr_type_name = "RemoveForeignCallback" instruction.operand.append(_DelObjectOperand(object_id)) instruction.operand.append(_Int64Operand(unique_callback_id)) @@ -264,7 +311,7 @@ class InstructionsBuilder(object): phy_parallel_desc_symbols = [] def AppendPhyParallelDescSymbol(machine_id, device_id): - parallel_conf = placement_pb_util.ParallelConf() + parallel_conf = placement_pb.ParallelConf() parallel_conf.device_tag = device_tag parallel_conf.device_name.append("%d:%d" % (machine_id, device_id)) phy_parallel_desc_symbols.append(self.GetParallelDescSymbol(parallel_conf)) @@ -450,6 +497,13 @@ class InstructionsBuilder(object): ) return OpKernelObject(object_id, op_conf, self.release_object_) + def Build121To(self, blob_object, parallel_desc_symbol): + ref_blob_object = _MakeNewBlobObjectLike( + self, blob_object, parallel_desc_symbol + ) + self.Build121AssignInstruction(ref_blob_object, blob_object) + return ref_blob_object + def Build121AssignInstruction(self, ref_blob_object, value_blob_object): parallel_num = ref_blob_object.parallel_desc_symbol.parallel_num assert parallel_num == value_blob_object.parallel_desc_symbol.parallel_num @@ -467,7 +521,7 @@ class InstructionsBuilder(object): def _BuildSendInstruction( self, dst_parallel_desc_symbol, src_blob_object, token_ids ): - instruction = instr_util.InstructionProto() + instruction = instr_pb.InstructionProto() instruction.instr_type_name = "SendBlob" instruction.parallel_desc_symbol_id = ( src_blob_object.parallel_desc_symbol.symbol_id @@ -485,7 +539,7 @@ class InstructionsBuilder(object): def _BuildRecvInstruction( self, src_parallel_desc_symbol, dst_blob_object, token_ids ): - instruction = instr_util.InstructionProto() + instruction = instr_pb.InstructionProto() instruction.instr_type_name = "ReceiveBlob" instruction.parallel_desc_symbol_id = ( dst_blob_object.parallel_desc_symbol.symbol_id @@ -502,7 +556,7 @@ class InstructionsBuilder(object): def _NewOpKernelObject(self, parallel_desc_symbol, job_desc_sym, op_conf_sym): object_id = self._NewObjectId(parallel_desc_symbol) - instruction = instr_util.InstructionProto() + instruction = instr_pb.InstructionProto() instruction.instr_type_name = "InitOpKernelObject" instruction.parallel_desc_symbol_id = parallel_desc_symbol.symbol_id instruction.operand.append(_SymbolOperand(job_desc_sym.symbol_id)) @@ -623,14 +677,14 @@ class InstructionsBuilder(object): ) def _CudaHostRegisterBlob(self, blob_object): - instruction = instr_util.InstructionProto() + instruction = instr_pb.InstructionProto() instruction.instr_type_name = "CudaHostRegisterBlob" instruction.parallel_desc_symbol_id = blob_object.parallel_desc_symbol.symbol_id instruction.operand.append(_MutOperand(blob_object.object_id)) self.instruction_list_.instruction.append(instruction) def _CudaHostUnregisterBlob(self, blob_object): - instruction = instr_util.InstructionProto() + instruction = instr_pb.InstructionProto() instruction.instr_type_name = "CudaHostUnregisterBlob" instruction.parallel_desc_symbol_id = blob_object.parallel_desc_symbol.symbol_id instruction.operand.append(_MutOperand(blob_object.object_id)) @@ -834,7 +888,7 @@ class InstructionsBuilder(object): mut1_operand_blob_objects, mut2_operand_blob_objects, ): - instruction = instr_util.InstructionProto() + instruction = instr_pb.InstructionProto() instruction.instr_type_name = "%s.%s" % ( parallel_desc_sym.device_tag, instr_name, @@ -877,7 +931,7 @@ class InstructionsBuilder(object): mut1_operand_blob_objects, mut2_operand_blob_objects, ): - instruction = instr_util.InstructionProto() + instruction = instr_pb.InstructionProto() instruction.instr_type_name = "%s.%s" % ( parallel_desc_sym.device_tag, instr_name, @@ -909,7 +963,7 @@ class InstructionsBuilder(object): def _NewSymbolId(self): symbol_id = self.id_generator_.NewSymbolId() - instruction = instr_util.InstructionProto() + instruction = instr_pb.InstructionProto() instruction.instr_type_name = "NewSymbol" instruction.operand.append(_Int64Operand(symbol_id)) self.instruction_list_.instruction.append(instruction) @@ -917,7 +971,7 @@ class InstructionsBuilder(object): def _NewObjectId(self, parallel_desc_sym): object_id = self.id_generator_.NewObjectId() - instruction = instr_util.InstructionProto() + instruction = instr_pb.InstructionProto() instruction.instr_type_name = "NewObject" instruction.parallel_desc_symbol_id = parallel_desc_sym.symbol_id instruction.operand.append(_Int64Operand(object_id)) @@ -925,7 +979,7 @@ class InstructionsBuilder(object): return object_id def _LazyReference(self, blob_object, interface_op_name): - instruction = instr_util.InstructionProto() + instruction = instr_pb.InstructionProto() device_tag = blob_object.parallel_desc_symbol.device_tag instruction.instr_type_name = "{}.LazyReference".format(device_tag) instruction.parallel_desc_symbol_id = blob_object.parallel_desc_symbol.symbol_id @@ -938,7 +992,7 @@ class InstructionsBuilder(object): def _BroadcastObjectReference(self, sole_mirrored_object, parallel_desc_sym): object_id = self.id_generator_.NewObjectId() - instruction = instr_util.InstructionProto() + instruction = instr_pb.InstructionProto() instruction.instr_type_name = "BroadcastObjectReference" instruction.parallel_desc_symbol_id = parallel_desc_sym.symbol_id instruction.operand.append(_Int64Operand(object_id)) @@ -947,68 +1001,68 @@ class InstructionsBuilder(object): return object_id def _InitStringSymbol(self, symbol_id, string): - instruction = instr_util.InstructionProto() + instruction = instr_pb.InstructionProto() instruction.instr_type_name = "InitStringSymbol" instruction.operand.append(_InitSymbolOperand(symbol_id)) self.instruction_list_.instruction.append(instruction) - eager_symbol = eager_symbol_util.EagerSymbol() + eager_symbol = eager_symbol_pb.EagerSymbol() eager_symbol.symbol_id = symbol_id eager_symbol.string_symbol = string self.eager_symbol_list_.eager_symbol.append(eager_symbol) def _NewParallelConfSymbol(self, symbol_id, parallel_conf): - instruction = instr_util.InstructionProto() + instruction = instr_pb.InstructionProto() instruction.instr_type_name = "NewParallelDescSymbol" instruction.operand.append(_Int64Operand(symbol_id)) self.instruction_list_.instruction.append(instruction) - eager_symbol = eager_symbol_util.EagerSymbol() + eager_symbol = eager_symbol_pb.EagerSymbol() eager_symbol.symbol_id = symbol_id eager_symbol.parallel_conf_symbol.CopyFrom(parallel_conf) self.eager_symbol_list_.eager_symbol.append(eager_symbol) def _NewScopeSymbol(self, scope_proto): - instruction = instr_util.InstructionProto() + instruction = instr_pb.InstructionProto() instruction.instr_type_name = "InitScopeSymbol" instruction.operand.append(_InitSymbolOperand(scope_proto.symbol_id)) self.instruction_list_.instruction.append(instruction) - eager_symbol = eager_symbol_util.EagerSymbol() + eager_symbol = eager_symbol_pb.EagerSymbol() eager_symbol.symbol_id = scope_proto.symbol_id eager_symbol.scope_symbol.CopyFrom(scope_proto) self.eager_symbol_list_.eager_symbol.append(eager_symbol) def _InitJobConfSymbol(self, symbol_id, job_conf): - instruction = instr_util.InstructionProto() + instruction = instr_pb.InstructionProto() instruction.instr_type_name = "InitJobDescSymbol" instruction.operand.append(_InitSymbolOperand(symbol_id)) self.instruction_list_.instruction.append(instruction) - eager_symbol = eager_symbol_util.EagerSymbol() + eager_symbol = eager_symbol_pb.EagerSymbol() eager_symbol.symbol_id = symbol_id eager_symbol.job_conf_symbol.CopyFrom(job_conf) self.eager_symbol_list_.eager_symbol.append(eager_symbol) def _InitOpConfSymbol(self, symbol_id, op_conf): - instruction = instr_util.InstructionProto() + instruction = instr_pb.InstructionProto() instruction.instr_type_name = "InitOperatorConfSymbol" instruction.operand.append(_InitSymbolOperand(symbol_id)) self.instruction_list_.instruction.append(instruction) - eager_symbol = eager_symbol_util.EagerSymbol() + eager_symbol = eager_symbol_pb.EagerSymbol() eager_symbol.symbol_id = symbol_id eager_symbol.op_conf_symbol.CopyFrom(op_conf) self.eager_symbol_list_.eager_symbol.append(eager_symbol) def _InitOpNodeSignatureDescSymbol(self, symbol_id, op_node_signature): - instruction = instr_util.InstructionProto() + instruction = instr_pb.InstructionProto() instruction.instr_type_name = "InitOpNodeSignatureDescSymbol" instruction.operand.append(_InitSymbolOperand(symbol_id)) self.instruction_list_.instruction.append(instruction) - eager_symbol = eager_symbol_util.EagerSymbol() + eager_symbol = eager_symbol_pb.EagerSymbol() eager_symbol.symbol_id = symbol_id eager_symbol.op_node_signature_symbol.CopyFrom(op_node_signature) self.eager_symbol_list_.eager_symbol.append(eager_symbol) def _FetchBlob(self, instruction_name, blob_object, fetcher): unique_callback_id = python_callback.GetIdForRegisteredCallback(fetcher) - instruction = instr_util.InstructionProto() + instruction = instr_pb.InstructionProto() device_tag = blob_object.parallel_desc_symbol.device_tag instruction.instr_type_name = "%s.%s" % (device_tag, instruction_name) instruction.parallel_desc_symbol_id = blob_object.parallel_desc_symbol.symbol_id @@ -1018,7 +1072,7 @@ class InstructionsBuilder(object): def FeedBlob(self, blob_object, feeder): unique_callback_id = python_callback.GetIdForRegisteredCallback(feeder) - instruction = instr_util.InstructionProto() + instruction = instr_pb.InstructionProto() device_tag = blob_object.parallel_desc_symbol.device_tag instruction.instr_type_name = "%s.%s" % (device_tag, "FeedBlob") instruction.parallel_desc_symbol_id = blob_object.parallel_desc_symbol.symbol_id @@ -1027,21 +1081,21 @@ class InstructionsBuilder(object): self.instruction_list_.instruction.append(instruction) def _TryClearObject(self, obj): - instruction = instr_util.InstructionProto() + instruction = instr_pb.InstructionProto() instruction.instr_type_name = "TryClearObject" instruction.parallel_desc_symbol_id = obj.parallel_desc_symbol.symbol_id instruction.operand.append(_MutOperand(obj.object_id)) self.instruction_list_.instruction.append(instruction) def _DeleteObject(self, blob_object): - instruction = instr_util.InstructionProto() + instruction = instr_pb.InstructionProto() instruction.instr_type_name = "DeleteObject" instruction.parallel_desc_symbol_id = blob_object.parallel_desc_symbol.symbol_id instruction.operand.append(_DelObjectOperand(blob_object.object_id)) self.instruction_list_.instruction.append(instruction) def _ReplaceMirrored(self, parallel_desc_sym, lhs_objects, rhs_objects): - instruction = instr_util.InstructionProto() + instruction = instr_pb.InstructionProto() instruction.instr_type_name = "ReplaceMirrored" instruction.parallel_desc_symbol_id = parallel_desc_sym.symbol_id for lhs_object in lhs_objects: @@ -1052,56 +1106,76 @@ class InstructionsBuilder(object): self.instruction_list_.instruction.append(instruction) +def _MakeNewBlobObjectLike(builder, blob_object, new_parallel_desc_symbol): + op_conf = op_conf_pb.OperatorConf() + op_conf.name = id_util.UniqueStr("Input") + op_conf.device_tag = new_parallel_desc_symbol.device_tag + op_conf.input_conf.out = "out" + blob_object.op_arg_parallel_attr.DumpToToInterfaceBlobConf( + op_conf.input_conf.blob_conf + ) + blob_object.op_arg_blob_attr.DumpToToInterfaceBlobConf(op_conf.input_conf.blob_conf) + op_conf.scope_symbol_id = oneflow.current_scope().symbol_id + upstream_signature = op_attribute_pb.OpNodeSignature() + op_attribute = c_api_util.InferOpConf(op_conf, upstream_signature) + parallel_conf = new_parallel_desc_symbol.parallel_conf + bn_in_op2blob_object = {} + builder.RawStatelessCall( + op_attribute, parallel_conf, bn_in_op2blob_object=bn_in_op2blob_object + ) + return bn_in_op2blob_object["out"] + + def _SymbolOperand(val): - operand = instr_util.InstructionOperandProto() + operand = instr_pb.InstructionOperandProto() _SetSoleMirroredOperand(operand.symbol_operand, val) return operand def _InitSymbolOperand(val): - operand = instr_util.InstructionOperandProto() + operand = instr_pb.InstructionOperandProto() _SetSoleMirroredOperand(operand.init_symbol_operand, val) return operand def _ConstOperand(val): - operand = instr_util.InstructionOperandProto() + operand = instr_pb.InstructionOperandProto() _SetMirroredOperand(operand.const_operand, val) return operand def _MutOperand(val): - operand = instr_util.InstructionOperandProto() + operand = instr_pb.InstructionOperandProto() _SetMirroredOperand(operand.mut_operand, val) return operand def _Mut2Operand(val): - operand = instr_util.InstructionOperandProto() + operand = instr_pb.InstructionOperandProto() _SetMirroredOperand(operand.mut2_operand, val) return operand def _DelObjectOperand(val): - operand = instr_util.InstructionOperandProto() + operand = instr_pb.InstructionOperandProto() _SetAllMirroredOperand(operand.mut_operand, val) return operand def _Int64Operand(val): - operand = instr_util.InstructionOperandProto() + operand = instr_pb.InstructionOperandProto() operand.int64_operand = val return operand def _Uint64Operand(val): - operand = instr_util.InstructionOperandProto() + operand = instr_pb.InstructionOperandProto() operand.uint64_operand = val return operand def _OperandSeparator(): - operand = instr_util.InstructionOperandProto() + operand = instr_pb.InstructionOperandProto() operand.separator.SetInParent() return operand @@ -1121,14 +1195,12 @@ def _SetAllMirroredOperand(operand, val): operand.all_mirrored_object.SetInParent() -def _FindOrCreateDelegateBlobObject(builder, x_blob_object, op_arg_parallel_attr): +def _FindOrCreateDelegateBlobObject( + builder, Fetch, x_blob_object, op_arg_parallel_attr +): if x_blob_object.op_arg_parallel_attr == op_arg_parallel_attr: return x_blob_object blob_cache = blob_cache_util.FindOrCreateBlobCache(x_blob_object) - - def Fetch(x_blob_object, op_arg_parallel_attr): - return boxing_util.BoxingTo(builder, x_blob_object, op_arg_parallel_attr) - return blob_cache.GetCachedDelegateBlobObject(op_arg_parallel_attr, Fetch)