提交 aab01c7b 编写于 作者: L lixinqi

refactor get_variable

上级 455aa602
......@@ -20,6 +20,20 @@ def get_variable(
random_seed=None,
distribute=distribute_util.broadcast(),
):
op_conf = _GenerateVariableOpConf(
name=name, shape=shape, dtype=dtype, initializer=initializer, trainable=trainable,
model_name=model_name, random_seed=random_seed,distribute=distribute)
op_conf, parallel_conf = compile_context.GetOpConfAndParallelConf(op_conf)
return _CreateVariableBlob(op_conf, parallel_conf)
def _GenerateVariableOpConf(name,
shape=None,
dtype=None,
initializer=None,
trainable=None,
model_name=None,
random_seed=None,
distribute=distribute_util.broadcast(),):
assert isinstance(name, str)
name = compile_context.GetVariablePrefix() + name
......@@ -58,11 +72,9 @@ def get_variable(
op_conf.variable_conf.random_seed = random_seed
op_conf.variable_conf.out = "out"
return op_conf
op_conf, parallel_conf = compile_context.GetOpConfAndParallelConf(op_conf)
return CreateVariableBlob(op_conf, parallel_conf)
def CreateVariableBlob(op_conf, parallel_conf):
def _CreateVariableBlob(op_conf, parallel_conf):
compile_context.CurJobAddConsistentOp(op_conf)
lbi = logical_blob_id_util.LogicalBlobId()
lbi.op_name = op_conf.name
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册