From 87b424e866b742b0173389ace4abfd6e66caa1ad Mon Sep 17 00:00:00 2001 From: Yang Yu Date: Mon, 22 Jan 2018 11:12:42 +0800 Subject: [PATCH] Follow comments --- .../paddle/v2/fluid/layers/math_op_patch.py | 27 +++++++------------ 1 file changed, 9 insertions(+), 18 deletions(-) diff --git a/python/paddle/v2/fluid/layers/math_op_patch.py b/python/paddle/v2/fluid/layers/math_op_patch.py index 5efc116db2e..11197b70a3d 100644 --- a/python/paddle/v2/fluid/layers/math_op_patch.py +++ b/python/paddle/v2/fluid/layers/math_op_patch.py @@ -19,7 +19,7 @@ __all__ = ['monkey_patch_variable'] def monkey_patch_variable(): - def new_name(): + def unique_tmp_name(): return unique_name("tmp") def safe_get_dtype(var): @@ -29,21 +29,9 @@ def monkey_patch_variable(): raise ValueError("Cannot get data type from %s", var.name) return dtype - def create_scalar(block, value, dtype): - value = float(value) - tmp_name = new_name() - var = block.create_var(name=tmp_name, shape=[1], dtype=dtype) - block.append_op( - type="fill", - outputs={"Out": [var]}, - attrs={"value": [value], - "shape": [1], - "dtype": dtype}) - return var - def create_tensor(block, value, dtype, shape): value = float(value) - tmp_name = new_name() + tmp_name = unique_tmp_name() var = block.create_var(name=tmp_name, shape=shape, dtype=dtype) block.append_op( type="fill_constant", @@ -53,10 +41,13 @@ def monkey_patch_variable(): 'value': value}) return var + def create_scalar(block, value, dtype): + return create_tensor(block, value, dtype, shape=[1]) + def create_tensor_with_batchsize(ref_var, value, dtype): assert isinstance(ref_var, Variable) value = float(value) - tmp_name = new_name() + tmp_name = unique_tmp_name() var = ref_var.block.create_var(name=tmp_name, dtype=dtype) ref_var.block.append_op( type='fill_constant_batch_size_like', @@ -68,7 +59,7 @@ def monkey_patch_variable(): def astype(self, dtype): """ - Cast a variable to data type. + Cast a variable to a specified data type. NOTE: The variable must be a Tensor Args: self(Variable): The source variable @@ -77,7 +68,7 @@ def monkey_patch_variable(): Returns: Variable with new dtype """ - tmp_name = new_name() + tmp_name = unique_tmp_name() out = self.block.create_var(name=tmp_name, dtype=dtype) self.block.append_op( type="cast", @@ -120,7 +111,7 @@ def monkey_patch_variable(): self = other_var other_var = tmp - tmp_name = new_name() + tmp_name = unique_tmp_name() out = self.block.create_var(name=tmp_name, dtype=lhs_dtype) self.block.append_op( type=op_type, -- GitLab