提交 b53eb7dc 编写于 作者: Q Qiao Longfei

add init once for assign layer

上级 dc8eca82
...@@ -5010,10 +5010,12 @@ def nce(input, ...@@ -5010,10 +5010,12 @@ def nce(input,
alias_probs_[little[0]] = 1.0 alias_probs_[little[0]] = 1.0
alias_[little[0]] = -1 alias_[little[0]] = -1
probs = assign(input=np.array(custom_dist).astype('float32')) probs = assign(
custom_alias = assign(input=np.array(alias_).astype('int32')) input=np.array(custom_dist).astype('float32'), init_once=True)
custom_alias = assign(
input=np.array(alias_).astype('int32'), init_once=True)
custom_alias_probs = assign( custom_alias_probs = assign(
input=np.array(alias_probs_).astype('float32')) input=np.array(alias_probs_).astype('float32'), init_once=True)
inputs['CustomDistProbs'] = probs inputs['CustomDistProbs'] = probs
inputs['CustomDistAlias'] = custom_alias inputs['CustomDistAlias'] = custom_alias
......
...@@ -285,7 +285,7 @@ def sums(input, out=None): ...@@ -285,7 +285,7 @@ def sums(input, out=None):
return out return out
def assign(input, output=None): def assign(input, output=None, init_once=False):
""" """
**Assign** **Assign**
...@@ -294,6 +294,7 @@ def assign(input, output=None): ...@@ -294,6 +294,7 @@ def assign(input, output=None):
Args: Args:
input(Variable|numpy.ndarray): The source variable input(Variable|numpy.ndarray): The source variable
output(Variable|None): The destination variable output(Variable|None): The destination variable
init_once(bool|false): assign value into global var only in startup program.
Returns: Returns:
Variable: The destination variable that was supplied as the *output*. Variable: The destination variable that was supplied as the *output*.
...@@ -307,10 +308,18 @@ def assign(input, output=None): ...@@ -307,10 +308,18 @@ def assign(input, output=None):
""" """
helper = LayerHelper('assign', **locals()) helper = LayerHelper('assign', **locals())
if output is None: if output is None:
output = helper.create_variable_for_type_inference(dtype=input.dtype) if init_once:
output = helper.create_parameter(
attr=ParamAttr(), shape=input.shape, dtype=input.dtype)
else:
output = helper.create_variable_for_type_inference(
dtype=input.dtype)
if isinstance(input, Variable): if isinstance(input, Variable):
if init_once:
raise ValueError("init once only support numpy assign!")
helper.append_op( helper.append_op(
type='assign', inputs={'X': [input]}, outputs={'Out': [output]}) type='assign', inputs={'X': [input]}, outputs={'Out': [output]})
elif isinstance(input, numpy.ndarray): elif isinstance(input, numpy.ndarray):
dtype = convert_np_dtype_to_dtype_(input.dtype) dtype = convert_np_dtype_to_dtype_(input.dtype)
if dtype == VarDesc.VarType.FP32: if dtype == VarDesc.VarType.FP32:
...@@ -325,14 +334,24 @@ def assign(input, output=None): ...@@ -325,14 +334,24 @@ def assign(input, output=None):
raise ValueError("The size of input is too big. Please consider " raise ValueError("The size of input is too big. Please consider "
"saving it to file and 'load_op' to load it") "saving it to file and 'load_op' to load it")
helper.append_op( if init_once:
type='assign_value', helper.startup_program.global_block().append_op(
outputs={'Out': [output]}, type='assign_value',
attrs={ outputs={'Out': [output]},
'dtype': dtype, attrs={
'shape': list(input.shape), 'dtype': dtype,
value_name: values 'shape': list(input.shape),
}) value_name: values
})
else:
helper.append_op(
type='assign_value',
outputs={'Out': [output]},
attrs={
'dtype': dtype,
'shape': list(input.shape),
value_name: values
})
else: else:
raise ValueError("Wrong type for assign input: %s" % type(input)) raise ValueError("Wrong type for assign input: %s" % type(input))
......
...@@ -1015,6 +1015,18 @@ class TestBook(unittest.TestCase): ...@@ -1015,6 +1015,18 @@ class TestBook(unittest.TestCase):
print(str(program)) print(str(program))
def test_assign(self):
import numpy as np
startup = Program()
main = Program()
with program_guard(main, startup):
probs = layers.assign(
input=np.random.random([1, 2]).astype('float32'),
init_once=True)
print(str(main))
print(str(startup))
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册