未验证 提交 ba574c8e 编写于 作者: W wangchaochaohu 提交者: GitHub

refine the usage of numpy element fetch for Ops test=develop (#26194)

上级 b9828bdf
......@@ -1511,7 +1511,7 @@ def array_write(x, i, array=None):
assert i.shape == [
1
], "The shape of index 'i' should be [1] in dygraph mode"
i = i.numpy()[0]
i = i.numpy().item(0)
if array is None:
array = create_array(x.dtype)
assert isinstance(
......@@ -1976,7 +1976,7 @@ def array_read(array, i):
assert i.shape == [
1
], "The shape of index 'i' should be [1] in dygraph mode"
i = i.numpy()[0]
i = i.numpy().item(0)
return array[i]
check_variable_and_dtype(i, 'i', ['int64'], 'array_read')
......
......@@ -4841,7 +4841,7 @@ def split(input, num_or_sections, dim=-1, name=None):
if isinstance(dim, Variable):
dim = dim.numpy()
dim = dim[0]
dim = dim.item(0)
dim = (len(input.shape) + dim) if dim < 0 else dim
attrs += ('axis', dim)
......@@ -5885,7 +5885,7 @@ def one_hot(input, depth, allow_out_of_range=False):
depth = depth.numpy()
assert depth.shape == (
1, ), "depth of type Variable should have shape [1]"
depth = depth[0]
depth = depth.item(0)
out = core.ops.one_hot(input, 'depth', depth, 'allow_out_of_range',
allow_out_of_range)
out.stop_gradient = True
......@@ -6067,7 +6067,7 @@ def reshape(x, shape, actual_shape=None, act=None, inplace=False, name=None):
)
if isinstance(shape, (list, tuple)):
shape = [
item.numpy()[0] if isinstance(item, Variable) else item
item.numpy().item(0) if isinstance(item, Variable) else item
for item in shape
]
out, _ = core.ops.reshape2(x, 'shape', shape)
......@@ -10195,7 +10195,7 @@ def expand(x, expand_times, name=None):
if in_dygraph_mode():
if isinstance(expand_times, (list, tuple)):
expand_times = [
item.numpy()[0] if isinstance(item, Variable) else item
item.numpy().item(0) if isinstance(item, Variable) else item
for item in expand_times
]
......@@ -10806,11 +10806,11 @@ def slice(input, axes, starts, ends):
if isinstance(starts, (list, tuple)) and isinstance(ends,
(list, tuple)):
starts = [
item.numpy()[0] if isinstance(item, Variable) else item
item.numpy().item(0) if isinstance(item, Variable) else item
for item in starts
]
ends = [
item.numpy()[0] if isinstance(item, Variable) else item
item.numpy().item(0) if isinstance(item, Variable) else item
for item in ends
]
......
......@@ -317,7 +317,7 @@ def concat(input, axis=0, name=None):
if in_dygraph_mode():
if isinstance(axis, Variable):
axis = axis.numpy()
axis = axis[0]
axis = axis.item(0)
return core.ops.concat(input, 'axis', axis)
check_type(input, 'input', (list, tuple, Variable), 'concat')
......@@ -699,9 +699,9 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None, name=None):
if isinstance(value, Variable):
if dtype in ['int64', 'int32']:
attrs['str_value'] = str(int(value.numpy()))
attrs['str_value'] = str(int(value.numpy().item(0)))
else:
attrs['str_value'] = str(float(value.numpy()))
attrs['str_value'] = str(float(value.numpy().item(0)))
core.ops.fill_constant(out, 'value',
float(value), 'force_cpu', force_cpu, 'dtype',
......
......@@ -305,14 +305,18 @@ class TestFillConstantImperative(unittest.TestCase):
with fluid.dygraph.guard():
data1 = np.array([1, 2]).astype('int32')
data2 = np.array([1.1]).astype('float32')
data3 = np.array([88]).astype('int32')
shape = fluid.dygraph.to_variable(data1)
val = fluid.dygraph.to_variable(data2)
value = fluid.dygraph.to_variable(data3)
res1 = fluid.layers.fill_constant(
shape=[1, 2], dtype='float32', value=1.1)
res2 = fluid.layers.fill_constant(
shape=shape, dtype='float32', value=1.1)
res3 = fluid.layers.fill_constant(
shape=shape, dtype='float32', value=val)
res4 = fluid.layers.fill_constant(
shape=shape, dtype='int32', value=value)
assert np.array_equal(
res1.numpy(), np.full(
[1, 2], 1.1, dtype="float32"))
......@@ -322,6 +326,9 @@ class TestFillConstantImperative(unittest.TestCase):
assert np.array_equal(
res3.numpy(), np.full(
[1, 2], 1.1, dtype="float32"))
assert np.array_equal(
res4.numpy(), np.full(
[1, 2], 88, dtype="int32"))
class TestFillConstantOpError(unittest.TestCase):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册