未验证 提交 738c8464 编写于 作者: S songyouwei 提交者: GitHub

support tensor array create read write in dygraph (#23200)

* support tensor array create read write in dygraph
test=develop

* minor fix
test=develop

* support tensor_array_to_tensor
test=develop

* add while loop case and err msg
test=develop

* refine ut
test=develop
上级 f944b0f6
...@@ -1287,6 +1287,31 @@ def array_write(x, i, array=None): ...@@ -1287,6 +1287,31 @@ def array_write(x, i, array=None):
# and '__int64' on Windows. They both represent 64-bit integer variables. # and '__int64' on Windows. They both represent 64-bit integer variables.
""" """
if in_dygraph_mode():
assert isinstance(
x, Variable
), "The input data 'x' in array_write must be Variable in dygraph mode"
assert isinstance(
i, Variable
), "The index 'i' in array_write must be Variable in dygraph mode"
assert i.shape == [
1
], "The shape of index 'i' should be [1] in dygraph mode"
i = i.numpy()[0]
if array is None:
array = create_array(x.dtype)
assert isinstance(
array,
list), "The 'array' in array_write must be a list in dygraph mode"
assert i <= len(
array
), "The index 'i' should not be greater than the length of 'array' in dygraph mode"
if i < len(array):
array[i] = x
else:
array.append(x)
return array
helper = LayerHelper('array_write', **locals()) helper = LayerHelper('array_write', **locals())
if array is None: if array is None:
array = helper.create_variable( array = helper.create_variable(
...@@ -1322,6 +1347,9 @@ def create_array(dtype): ...@@ -1322,6 +1347,9 @@ def create_array(dtype):
data = fluid.layers.create_array(dtype='float32') # Create a float32 LoDTensorArray. data = fluid.layers.create_array(dtype='float32') # Create a float32 LoDTensorArray.
""" """
if in_dygraph_mode():
return []
helper = LayerHelper("array", **locals()) helper = LayerHelper("array", **locals())
return helper.create_variable( return helper.create_variable(
name="{0}.out".format(helper.name), name="{0}.out".format(helper.name),
...@@ -1643,6 +1671,19 @@ def array_read(array, i): ...@@ -1643,6 +1671,19 @@ def array_read(array, i):
# so the dtype value is typeid(int64_t).Name(), which is 'x' on MacOS, 'l' on Linux, # so the dtype value is typeid(int64_t).Name(), which is 'x' on MacOS, 'l' on Linux,
# and '__int64' on Windows. They both represent 64-bit integer variables. # and '__int64' on Windows. They both represent 64-bit integer variables.
""" """
if in_dygraph_mode():
assert isinstance(
array,
list), "The 'array' in array_read must be list in dygraph mode"
assert isinstance(
i, Variable
), "The index 'i' in array_read must be Variable in dygraph mode"
assert i.shape == [
1
], "The shape of index 'i' should be [1] in dygraph mode"
i = i.numpy()[0]
return array[i]
helper = LayerHelper('array_read', **locals()) helper = LayerHelper('array_read', **locals())
if not isinstance( if not isinstance(
array, array,
...@@ -1739,6 +1780,12 @@ def array_length(array): ...@@ -1739,6 +1780,12 @@ def array_length(array):
# so the dtype value is typeid(int64_t).Name(), which is 'x' on MacOS, 'l' on Linux, # so the dtype value is typeid(int64_t).Name(), which is 'x' on MacOS, 'l' on Linux,
# and '__int64' on Windows. They both represent 64-bit integer variables. # and '__int64' on Windows. They both represent 64-bit integer variables.
""" """
if in_dygraph_mode():
assert isinstance(
array,
list), "The 'array' in array_write must be a list in dygraph mode"
return len(array)
helper = LayerHelper('array_length', **locals()) helper = LayerHelper('array_length', **locals())
tmp = helper.create_variable_for_type_inference(dtype='int64') tmp = helper.create_variable_for_type_inference(dtype='int64')
tmp.stop_gradient = True tmp.stop_gradient = True
......
...@@ -381,6 +381,17 @@ def tensor_array_to_tensor(input, axis=1, name=None, use_stack=False): ...@@ -381,6 +381,17 @@ def tensor_array_to_tensor(input, axis=1, name=None, use_stack=False):
fluid.layers.array_write(x1, i + 1, array) fluid.layers.array_write(x1, i + 1, array)
output, output_index = fluid.layers.tensor_array_to_tensor(input=array) output, output_index = fluid.layers.tensor_array_to_tensor(input=array)
""" """
if in_dygraph_mode():
assert isinstance(
input, list), "The 'input' in tensor_array_to_tensor must be list"
from .nn import stack, concat
from ..dygraph import to_variable
op = stack if use_stack else concat
res = op(input, axis=axis)
sizes = to_variable(
numpy.array(list(map(lambda x: int(x.shape[axis]), input))))
return res, sizes
helper = LayerHelper('tensor_array_to_tensor', **locals()) helper = LayerHelper('tensor_array_to_tensor', **locals())
out = helper.create_variable_for_type_inference(dtype=helper.input_dtype()) out = helper.create_variable_for_type_inference(dtype=helper.input_dtype())
out_index = helper.create_variable_for_type_inference(dtype="int32") out_index = helper.create_variable_for_type_inference(dtype="int32")
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
from __future__ import print_function from __future__ import print_function
import unittest import unittest
import paddle.fluid as fluid
import paddle.fluid.core as core import paddle.fluid.core as core
import paddle.fluid.layers as layers import paddle.fluid.layers as layers
from paddle.fluid.executor import Executor from paddle.fluid.executor import Executor
...@@ -23,18 +24,7 @@ from paddle.fluid.framework import default_main_program ...@@ -23,18 +24,7 @@ from paddle.fluid.framework import default_main_program
import numpy import numpy
class TestArrayReadWrite(unittest.TestCase): def _test_read_write(x):
def test_read_write(self):
x = [
layers.data(
name='x0', shape=[100]), layers.data(
name='x1', shape=[100]), layers.data(
name='x2', shape=[100])
]
for each_x in x:
each_x.stop_gradient = False
i = layers.zeros(shape=[1], dtype='int64') i = layers.zeros(shape=[1], dtype='int64')
i.stop_gradient = False i.stop_gradient = False
arr = layers.array_write(x=x[0], i=i) arr = layers.array_write(x=x[0], i=i)
...@@ -63,18 +53,30 @@ class TestArrayReadWrite(unittest.TestCase): ...@@ -63,18 +53,30 @@ class TestArrayReadWrite(unittest.TestCase):
x_sum = layers.sums(input=[mean_x0, mean_x1, mean_x2]) x_sum = layers.sums(input=[mean_x0, mean_x1, mean_x2])
scope = core.Scope() return a_sum, x_sum
cpu = core.CPUPlace()
exe = Executor(cpu)
class TestArrayReadWrite(unittest.TestCase):
def test_read_write(self):
x = [
layers.data(
name='x0', shape=[100]), layers.data(
name='x1', shape=[100]), layers.data(
name='x2', shape=[100])
]
for each_x in x:
each_x.stop_gradient = False
tensor = numpy.random.random(size=(100, 100)).astype('float32') tensor = numpy.random.random(size=(100, 100)).astype('float32')
a_sum, x_sum = _test_read_write(x)
place = core.CPUPlace()
exe = Executor(place)
outs = exe.run(feed={'x0': tensor, outs = exe.run(feed={'x0': tensor,
'x1': tensor, 'x1': tensor,
'x2': tensor}, 'x2': tensor},
fetch_list=[a_sum, x_sum], fetch_list=[a_sum, x_sum],
scope=scope) scope=core.Scope())
self.assertEqual(outs[0], outs[1]) self.assertEqual(outs[0], outs[1])
total_sum = layers.sums(input=[a_sum, x_sum]) total_sum = layers.sums(input=[a_sum, x_sum])
...@@ -100,6 +102,28 @@ class TestArrayReadWrite(unittest.TestCase): ...@@ -100,6 +102,28 @@ class TestArrayReadWrite(unittest.TestCase):
# the input gradient should also be 1 # the input gradient should also be 1
self.assertAlmostEqual(1.0, g_out_sum, delta=0.1) self.assertAlmostEqual(1.0, g_out_sum, delta=0.1)
with fluid.dygraph.guard(place):
tensor1 = fluid.dygraph.to_variable(tensor)
tensor2 = fluid.dygraph.to_variable(tensor)
tensor3 = fluid.dygraph.to_variable(tensor)
x_dygraph = [tensor1, tensor2, tensor3]
for each_x in x_dygraph:
each_x.stop_gradient = False
a_sum_dygraph, x_sum_dygraph = _test_read_write(x_dygraph)
self.assertEqual(a_sum_dygraph, x_sum_dygraph)
total_sum_dygraph = layers.sums(
input=[a_sum_dygraph, x_sum_dygraph])
total_sum_scaled_dygraph = layers.scale(
x=total_sum_dygraph, scale=1 / 6.0)
total_sum_scaled_dygraph.backward()
g_out_dygraph = [
item._grad_ivar().numpy().sum() for item in x_dygraph
]
g_out_sum_dygraph = numpy.array(g_out_dygraph).sum()
self.assertAlmostEqual(1.0, g_out_sum_dygraph, delta=0.1)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -204,17 +204,61 @@ class TestLoDTensorArrayStack(unittest.TestCase): ...@@ -204,17 +204,61 @@ class TestLoDTensorArrayStack(unittest.TestCase):
class TestTensorArrayToTensorAPI(unittest.TestCase): class TestTensorArrayToTensorAPI(unittest.TestCase):
def test_case(self): def _test_case(self, inp1, inp2):
x0 = fluid.layers.assign(numpy.random.rand(2, 3, 4).astype("float32")) x0 = fluid.layers.assign(inp1)
x1 = fluid.layers.assign(numpy.random.rand(2, 3, 4).astype("float32")) x0.stop_gradient = False
x1 = fluid.layers.assign(inp2)
x1.stop_gradient = False
i = fluid.layers.fill_constant(shape=[1], dtype="int64", value=0) i = fluid.layers.fill_constant(shape=[1], dtype="int64", value=0)
array = fluid.layers.create_array(dtype='float32') array = fluid.layers.create_array(dtype='float32')
fluid.layers.array_write(x0, i, array) fluid.layers.array_write(x0, i, array)
fluid.layers.array_write(x1, i + 1, array) fluid.layers.array_write(x1, i + 1, array)
output, output_index = fluid.layers.tensor_array_to_tensor( output_stack, output_index_stack = fluid.layers.tensor_array_to_tensor(
input=array, axis=1, use_stack=True) input=array, axis=1, use_stack=True)
output, output_index = fluid.layers.tensor_array_to_tensor( output_concat, output_index_concat = fluid.layers.tensor_array_to_tensor(
input=array, axis=1, use_stack=False) input=array, axis=1, use_stack=False)
return output_stack, output_index_stack, output_concat, output_index_concat
def test_case(self):
inp0 = numpy.random.rand(2, 3, 4).astype("float32")
inp1 = numpy.random.rand(2, 3, 4).astype("float32")
_outs_static = self._test_case(inp0, inp1)
place = fluid.CPUPlace()
exe = fluid.Executor(place)
outs_static = exe.run(fetch_list=list(_outs_static))
with fluid.dygraph.guard(place):
outs_dynamic = self._test_case(inp0, inp1)
for s, d in zip(outs_static, outs_dynamic):
self.assertTrue(numpy.array_equal(s, d.numpy()))
def test_while_loop_case(self):
with fluid.dygraph.guard():
zero = fluid.layers.fill_constant(shape=[1], dtype='int64', value=0)
i = fluid.layers.fill_constant(shape=[1], dtype='int64', value=1)
ten = fluid.layers.fill_constant(shape=[1], dtype='int64', value=10)
array = fluid.layers.create_array(dtype='float32')
inp0 = numpy.random.rand(2, 3, 4).astype("float32")
x0 = fluid.layers.assign(inp0)
fluid.layers.array_write(x0, zero, array)
def cond(i, end, array):
return fluid.layers.less_than(i, end)
def body(i, end, array):
prev = fluid.layers.array_read(array, i - 1)
fluid.layers.array_write(prev, i, array)
return i + 1, end, array
_, _, array = fluid.layers.while_loop(cond, body, [i, ten, array])
self.assertTrue(fluid.layers.array_length(array), 10)
last = fluid.layers.fill_constant(shape=[1], dtype='int64', value=9)
self.assertTrue(
numpy.array_equal(
fluid.layers.array_read(array, last).numpy(), inp0))
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册