提交 87b424e8 编写于 作者: Y Yang Yu

Follow comments

上级 9f731a60
...@@ -19,7 +19,7 @@ __all__ = ['monkey_patch_variable'] ...@@ -19,7 +19,7 @@ __all__ = ['monkey_patch_variable']
def monkey_patch_variable(): def monkey_patch_variable():
def new_name(): def unique_tmp_name():
return unique_name("tmp") return unique_name("tmp")
def safe_get_dtype(var): def safe_get_dtype(var):
...@@ -29,21 +29,9 @@ def monkey_patch_variable(): ...@@ -29,21 +29,9 @@ def monkey_patch_variable():
raise ValueError("Cannot get data type from %s", var.name) raise ValueError("Cannot get data type from %s", var.name)
return dtype return dtype
def create_scalar(block, value, dtype):
value = float(value)
tmp_name = new_name()
var = block.create_var(name=tmp_name, shape=[1], dtype=dtype)
block.append_op(
type="fill",
outputs={"Out": [var]},
attrs={"value": [value],
"shape": [1],
"dtype": dtype})
return var
def create_tensor(block, value, dtype, shape): def create_tensor(block, value, dtype, shape):
value = float(value) value = float(value)
tmp_name = new_name() tmp_name = unique_tmp_name()
var = block.create_var(name=tmp_name, shape=shape, dtype=dtype) var = block.create_var(name=tmp_name, shape=shape, dtype=dtype)
block.append_op( block.append_op(
type="fill_constant", type="fill_constant",
...@@ -53,10 +41,13 @@ def monkey_patch_variable(): ...@@ -53,10 +41,13 @@ def monkey_patch_variable():
'value': value}) 'value': value})
return var return var
def create_scalar(block, value, dtype):
return create_tensor(block, value, dtype, shape=[1])
def create_tensor_with_batchsize(ref_var, value, dtype): def create_tensor_with_batchsize(ref_var, value, dtype):
assert isinstance(ref_var, Variable) assert isinstance(ref_var, Variable)
value = float(value) value = float(value)
tmp_name = new_name() tmp_name = unique_tmp_name()
var = ref_var.block.create_var(name=tmp_name, dtype=dtype) var = ref_var.block.create_var(name=tmp_name, dtype=dtype)
ref_var.block.append_op( ref_var.block.append_op(
type='fill_constant_batch_size_like', type='fill_constant_batch_size_like',
...@@ -68,7 +59,7 @@ def monkey_patch_variable(): ...@@ -68,7 +59,7 @@ def monkey_patch_variable():
def astype(self, dtype): def astype(self, dtype):
""" """
Cast a variable to data type. Cast a variable to a specified data type.
NOTE: The variable must be a Tensor NOTE: The variable must be a Tensor
Args: Args:
self(Variable): The source variable self(Variable): The source variable
...@@ -77,7 +68,7 @@ def monkey_patch_variable(): ...@@ -77,7 +68,7 @@ def monkey_patch_variable():
Returns: Returns:
Variable with new dtype Variable with new dtype
""" """
tmp_name = new_name() tmp_name = unique_tmp_name()
out = self.block.create_var(name=tmp_name, dtype=dtype) out = self.block.create_var(name=tmp_name, dtype=dtype)
self.block.append_op( self.block.append_op(
type="cast", type="cast",
...@@ -120,7 +111,7 @@ def monkey_patch_variable(): ...@@ -120,7 +111,7 @@ def monkey_patch_variable():
self = other_var self = other_var
other_var = tmp other_var = tmp
tmp_name = new_name() tmp_name = unique_tmp_name()
out = self.block.create_var(name=tmp_name, dtype=lhs_dtype) out = self.block.create_var(name=tmp_name, dtype=lhs_dtype)
self.block.append_op( self.block.append_op(
type=op_type, type=op_type,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册