diff --git a/oneflow/python/ops/get_variable.py b/oneflow/python/ops/get_variable.py index 1a2426f3043bdbfedc9af35b315b483cb7aca572..f8602fd43bfad5871b7915fe769460428381c059 100644 --- a/oneflow/python/ops/get_variable.py +++ b/oneflow/python/ops/get_variable.py @@ -20,6 +20,20 @@ def get_variable( random_seed=None, distribute=distribute_util.broadcast(), ): + op_conf = _GenerateVariableOpConf( + name=name, shape=shape, dtype=dtype, initializer=initializer, trainable=trainable, + model_name=model_name, random_seed=random_seed,distribute=distribute) + op_conf, parallel_conf = compile_context.GetOpConfAndParallelConf(op_conf) + return _CreateVariableBlob(op_conf, parallel_conf) + +def _GenerateVariableOpConf(name, + shape=None, + dtype=None, + initializer=None, + trainable=None, + model_name=None, + random_seed=None, + distribute=distribute_util.broadcast(),): assert isinstance(name, str) name = compile_context.GetVariablePrefix() + name @@ -58,11 +72,9 @@ def get_variable( op_conf.variable_conf.random_seed = random_seed op_conf.variable_conf.out = "out" + return op_conf - op_conf, parallel_conf = compile_context.GetOpConfAndParallelConf(op_conf) - return CreateVariableBlob(op_conf, parallel_conf) - -def CreateVariableBlob(op_conf, parallel_conf): +def _CreateVariableBlob(op_conf, parallel_conf): compile_context.CurJobAddConsistentOp(op_conf) lbi = logical_blob_id_util.LogicalBlobId() lbi.op_name = op_conf.name