未验证 提交 adac38c5 编写于 作者: L Leo Chen 提交者: GitHub

add dispenable input for core.ops.reshape2/expand/slice (#30072)

* add dispenable input 'shape' for core.ops.reshape2

* add dispenable inputs for core.ops.reshape2/expand/slice

* add ut
上级 3be65939
......@@ -38,6 +38,9 @@ std::map<std::string, std::set<std::string>> op_ins_map = {
{"gru_unit", {"Input", "HiddenPrev", "Weight", "Bias"}},
{"label_smooth", {"X", "PriorDist"}},
{"assign", {"X"}},
{"reshape2", {"X", "Shape"}},
{"expand", {"X", "ExpandTimes"}},
{"slice", {"Input", "StartsTensor", "EndsTensor"}},
{"fake_quantize_dequantize_moving_average_abs_max",
{"X", "InScale", "InAccum", "InState"}},
{"nll_loss", {"X", "Label", "Weight"}},
......
......@@ -6148,8 +6148,12 @@ def reshape(x, shape, actual_shape=None, act=None, inplace=False, name=None):
item.numpy().item(0) if isinstance(item, Variable) else item
for item in shape
]
out, _ = core.ops.reshape2(x, 'shape', shape)
return dygraph_utils._append_activation_in_dygraph(out, act)
out, _ = core.ops.reshape2(x, None, 'shape', shape)
elif isinstance(shape, Variable):
shape.stop_gradient = True
out, _ = core.ops.reshape2(x, shape)
return dygraph_utils._append_activation_in_dygraph(out, act)
check_variable_and_dtype(
x, 'x', ['float16', 'float32', 'float64', 'int32', 'int64',
......@@ -10315,13 +10319,19 @@ def expand(x, expand_times, name=None):
# the shape of expanded_2 is [48, 56].
"""
if in_dygraph_mode():
attrs = ()
expand_times_tensor = None
if isinstance(expand_times, (list, tuple)):
expand_times = [
item.numpy().item(0) if isinstance(item, Variable) else item
for item in expand_times
]
attrs += ('expand_times', expand_times)
elif isinstance(expand_times, Variable):
expand_times_tensor = expand_times
expand_times_tensor.stop_gradient = True
return core.ops.expand(x, 'expand_times', expand_times)
return core.ops.expand(x, expand_times_tensor, *attrs)
inputs = {"X": [x]}
attrs = {}
......@@ -10925,20 +10935,35 @@ def slice(input, axes, starts, ends):
# sliced_2 is input[0:3, 0:2, 2:4].
"""
if in_dygraph_mode():
attrs = ()
starts_tensor = None
ends_tensor = None
infer_flags = list(1 for i in range(len(axes)))
if isinstance(starts, (list, tuple)) and isinstance(ends,
(list, tuple)):
if isinstance(starts, (list, tuple)):
starts = [
item.numpy().item(0) if isinstance(item, Variable) else item
for item in starts
]
attrs += ('starts', starts)
elif isinstance(starts, Variable):
starts_tensor = starts
starts.stop_gradient = True
infer_flags = list(-1 for i in range(len(axes)))
if isinstance(ends, (list, tuple)):
ends = [
item.numpy().item(0) if isinstance(item, Variable) else item
for item in ends
]
return core.ops.slice(input, 'axes', axes, 'starts', starts, 'ends',
ends, 'infer_flags', infer_flags)
attrs += ('ends', ends)
elif isinstance(ends, Variable):
ends_tensor = ends
ends_tensor.stop_gradient = True
infer_flags = list(-1 for i in range(len(axes)))
return core.ops.slice(input, starts_tensor, ends_tensor, 'axes', axes,
'infer_flags', infer_flags, *attrs)
if not isinstance(starts, (list, tuple, Variable)):
raise ValueError(
......
......@@ -1669,7 +1669,7 @@ def eye(num_rows,
expand_times = batch_shape + [1, 1]
if in_dygraph_mode():
out = core.ops.reshape(out, 'shape', re_shape)
return core.ops.expand(out, 'expand_times', expand_times)
return core.ops.expand(out, None, 'expand_times', expand_times)
if not isinstance(batch_shape, list):
raise TypeError("batch_shape should be a list")
......
......@@ -19,6 +19,7 @@ import numpy as np
from op_test import OpTest
import paddle.fluid as fluid
from paddle.fluid import compiler, Program, program_guard
import paddle
# Situation 1: expand_times is a list(without tensor)
......@@ -237,5 +238,19 @@ class TestExpandAPI(unittest.TestCase):
assert np.array_equal(res_3, np.tile(input, (1, 3)))
class TestExpandDygraphAPI(unittest.TestCase):
def test_expand_times_is_tensor(self):
with paddle.fluid.dygraph.guard():
a = paddle.rand([2, 5])
b = paddle.fluid.layers.expand(a, expand_times=[2, 3])
c = paddle.fluid.layers.expand(
a, expand_times=paddle.to_tensor(
[2, 3], dtype='int32'))
self.assertTrue(
np.array_equal(b.numpy(), np.tile(a.numpy(), [2, 3])))
self.assertTrue(
np.array_equal(c.numpy(), np.tile(a.numpy(), [2, 3])))
if __name__ == "__main__":
unittest.main()
......@@ -20,6 +20,7 @@ import paddle.fluid.core as core
from op_test import OpTest
import paddle.fluid as fluid
import paddle.fluid.layers as layers
import paddle
# Situation 1: starts(list, no tensor), ends(list, no tensor)
......@@ -532,6 +533,25 @@ class TestSliceAPI(unittest.TestCase):
assert np.array_equal(res_7, input[-1, 0:100, :, 2:-1])
class TestSliceApiWithTensor(unittest.TestCase):
def test_starts_ends_is_tensor(self):
with paddle.fluid.dygraph.guard():
a = paddle.rand(shape=[4, 5, 6], dtype='float32')
axes = [0, 1, 2]
starts = [-3, 0, 2]
ends = [3, 2, 4]
a_1 = paddle.slice(
a,
axes=axes,
starts=paddle.to_tensor(
starts, dtype='int32'),
ends=paddle.to_tensor(
ends, dtype='int32'))
a_2 = paddle.slice(a, axes=axes, starts=starts, ends=ends)
self.assertTrue(np.array_equal(a_1.numpy(), a_2.numpy()))
class TestSliceApiWithLoDTensorArray(unittest.TestCase):
def setUp(self):
self.shape = (3, 4)
......
......@@ -796,14 +796,14 @@ def nll_loss(input,
c = input_shape[1]
if in_dygraph_mode():
if input_dims != 2 and input_dims != 4:
input, _ = core.ops.reshape2(input, 'shape', [n, c, 1, -1])
label, _ = core.ops.reshape2(label, 'shape', [n, 1, -1])
input, _ = core.ops.reshape2(input, None, 'shape', [n, c, 1, -1])
label, _ = core.ops.reshape2(label, None, 'shape', [n, 1, -1])
out_shape = [n] + input_shape[2:]
out, total_weight = core.ops.nll_loss(input, label, weight,
'ignore_index', ignore_index,
'reduction', reduction)
if input_dims != 2 and input_dims != 4 and reduction == 'none':
out, _ = core.ops.reshape2(out, 'shape', out_shape)
out, _ = core.ops.reshape2(out, None, 'shape', out_shape)
return out
helper = LayerHelper('nll_loss', **locals())
......@@ -1225,8 +1225,8 @@ def cross_entropy(input,
if weight is not None:
weight_gather = core.ops.gather_nd(weight, label) #trans to sample
input_shape = list(label.shape)
weight_gather_reshape, _ = core.ops.reshape2(weight_gather, 'shape',
input_shape)
weight_gather_reshape, _ = core.ops.reshape2(weight_gather, None,
'shape', input_shape)
out = core.ops.elementwise_mul(out, weight_gather_reshape)
if reduction == "sum":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册