未验证 提交 ba574c8e 编写于 作者: W wangchaochaohu 提交者: GitHub

refine the usage of numpy element fetch for Ops test=develop (#26194)

上级 b9828bdf
...@@ -1511,7 +1511,7 @@ def array_write(x, i, array=None): ...@@ -1511,7 +1511,7 @@ def array_write(x, i, array=None):
assert i.shape == [ assert i.shape == [
1 1
], "The shape of index 'i' should be [1] in dygraph mode" ], "The shape of index 'i' should be [1] in dygraph mode"
i = i.numpy()[0] i = i.numpy().item(0)
if array is None: if array is None:
array = create_array(x.dtype) array = create_array(x.dtype)
assert isinstance( assert isinstance(
...@@ -1976,7 +1976,7 @@ def array_read(array, i): ...@@ -1976,7 +1976,7 @@ def array_read(array, i):
assert i.shape == [ assert i.shape == [
1 1
], "The shape of index 'i' should be [1] in dygraph mode" ], "The shape of index 'i' should be [1] in dygraph mode"
i = i.numpy()[0] i = i.numpy().item(0)
return array[i] return array[i]
check_variable_and_dtype(i, 'i', ['int64'], 'array_read') check_variable_and_dtype(i, 'i', ['int64'], 'array_read')
......
...@@ -4841,7 +4841,7 @@ def split(input, num_or_sections, dim=-1, name=None): ...@@ -4841,7 +4841,7 @@ def split(input, num_or_sections, dim=-1, name=None):
if isinstance(dim, Variable): if isinstance(dim, Variable):
dim = dim.numpy() dim = dim.numpy()
dim = dim[0] dim = dim.item(0)
dim = (len(input.shape) + dim) if dim < 0 else dim dim = (len(input.shape) + dim) if dim < 0 else dim
attrs += ('axis', dim) attrs += ('axis', dim)
...@@ -5885,7 +5885,7 @@ def one_hot(input, depth, allow_out_of_range=False): ...@@ -5885,7 +5885,7 @@ def one_hot(input, depth, allow_out_of_range=False):
depth = depth.numpy() depth = depth.numpy()
assert depth.shape == ( assert depth.shape == (
1, ), "depth of type Variable should have shape [1]" 1, ), "depth of type Variable should have shape [1]"
depth = depth[0] depth = depth.item(0)
out = core.ops.one_hot(input, 'depth', depth, 'allow_out_of_range', out = core.ops.one_hot(input, 'depth', depth, 'allow_out_of_range',
allow_out_of_range) allow_out_of_range)
out.stop_gradient = True out.stop_gradient = True
...@@ -6067,7 +6067,7 @@ def reshape(x, shape, actual_shape=None, act=None, inplace=False, name=None): ...@@ -6067,7 +6067,7 @@ def reshape(x, shape, actual_shape=None, act=None, inplace=False, name=None):
) )
if isinstance(shape, (list, tuple)): if isinstance(shape, (list, tuple)):
shape = [ shape = [
item.numpy()[0] if isinstance(item, Variable) else item item.numpy().item(0) if isinstance(item, Variable) else item
for item in shape for item in shape
] ]
out, _ = core.ops.reshape2(x, 'shape', shape) out, _ = core.ops.reshape2(x, 'shape', shape)
...@@ -10195,7 +10195,7 @@ def expand(x, expand_times, name=None): ...@@ -10195,7 +10195,7 @@ def expand(x, expand_times, name=None):
if in_dygraph_mode(): if in_dygraph_mode():
if isinstance(expand_times, (list, tuple)): if isinstance(expand_times, (list, tuple)):
expand_times = [ expand_times = [
item.numpy()[0] if isinstance(item, Variable) else item item.numpy().item(0) if isinstance(item, Variable) else item
for item in expand_times for item in expand_times
] ]
...@@ -10806,11 +10806,11 @@ def slice(input, axes, starts, ends): ...@@ -10806,11 +10806,11 @@ def slice(input, axes, starts, ends):
if isinstance(starts, (list, tuple)) and isinstance(ends, if isinstance(starts, (list, tuple)) and isinstance(ends,
(list, tuple)): (list, tuple)):
starts = [ starts = [
item.numpy()[0] if isinstance(item, Variable) else item item.numpy().item(0) if isinstance(item, Variable) else item
for item in starts for item in starts
] ]
ends = [ ends = [
item.numpy()[0] if isinstance(item, Variable) else item item.numpy().item(0) if isinstance(item, Variable) else item
for item in ends for item in ends
] ]
......
...@@ -317,7 +317,7 @@ def concat(input, axis=0, name=None): ...@@ -317,7 +317,7 @@ def concat(input, axis=0, name=None):
if in_dygraph_mode(): if in_dygraph_mode():
if isinstance(axis, Variable): if isinstance(axis, Variable):
axis = axis.numpy() axis = axis.numpy()
axis = axis[0] axis = axis.item(0)
return core.ops.concat(input, 'axis', axis) return core.ops.concat(input, 'axis', axis)
check_type(input, 'input', (list, tuple, Variable), 'concat') check_type(input, 'input', (list, tuple, Variable), 'concat')
...@@ -699,9 +699,9 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None, name=None): ...@@ -699,9 +699,9 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None, name=None):
if isinstance(value, Variable): if isinstance(value, Variable):
if dtype in ['int64', 'int32']: if dtype in ['int64', 'int32']:
attrs['str_value'] = str(int(value.numpy())) attrs['str_value'] = str(int(value.numpy().item(0)))
else: else:
attrs['str_value'] = str(float(value.numpy())) attrs['str_value'] = str(float(value.numpy().item(0)))
core.ops.fill_constant(out, 'value', core.ops.fill_constant(out, 'value',
float(value), 'force_cpu', force_cpu, 'dtype', float(value), 'force_cpu', force_cpu, 'dtype',
......
...@@ -305,14 +305,18 @@ class TestFillConstantImperative(unittest.TestCase): ...@@ -305,14 +305,18 @@ class TestFillConstantImperative(unittest.TestCase):
with fluid.dygraph.guard(): with fluid.dygraph.guard():
data1 = np.array([1, 2]).astype('int32') data1 = np.array([1, 2]).astype('int32')
data2 = np.array([1.1]).astype('float32') data2 = np.array([1.1]).astype('float32')
data3 = np.array([88]).astype('int32')
shape = fluid.dygraph.to_variable(data1) shape = fluid.dygraph.to_variable(data1)
val = fluid.dygraph.to_variable(data2) val = fluid.dygraph.to_variable(data2)
value = fluid.dygraph.to_variable(data3)
res1 = fluid.layers.fill_constant( res1 = fluid.layers.fill_constant(
shape=[1, 2], dtype='float32', value=1.1) shape=[1, 2], dtype='float32', value=1.1)
res2 = fluid.layers.fill_constant( res2 = fluid.layers.fill_constant(
shape=shape, dtype='float32', value=1.1) shape=shape, dtype='float32', value=1.1)
res3 = fluid.layers.fill_constant( res3 = fluid.layers.fill_constant(
shape=shape, dtype='float32', value=val) shape=shape, dtype='float32', value=val)
res4 = fluid.layers.fill_constant(
shape=shape, dtype='int32', value=value)
assert np.array_equal( assert np.array_equal(
res1.numpy(), np.full( res1.numpy(), np.full(
[1, 2], 1.1, dtype="float32")) [1, 2], 1.1, dtype="float32"))
...@@ -322,6 +326,9 @@ class TestFillConstantImperative(unittest.TestCase): ...@@ -322,6 +326,9 @@ class TestFillConstantImperative(unittest.TestCase):
assert np.array_equal( assert np.array_equal(
res3.numpy(), np.full( res3.numpy(), np.full(
[1, 2], 1.1, dtype="float32")) [1, 2], 1.1, dtype="float32"))
assert np.array_equal(
res4.numpy(), np.full(
[1, 2], 88, dtype="int32"))
class TestFillConstantOpError(unittest.TestCase): class TestFillConstantOpError(unittest.TestCase):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册