未验证 提交 93027d9f 编写于 作者: H heyanru 提交者: GitHub

[Fluid Clean] remove nn.topk, nn.ctc_greedy_decoder, nn.im2sequence,...

[Fluid Clean] remove nn.topk, nn.ctc_greedy_decoder, nn.im2sequence, nn.multiplex, nn.smooth_l1 (#48289)
上级 d6aa0d43
......@@ -475,7 +475,7 @@ void SwapDim1And2InNarrow(const phi::GPUContext& d,
CeilOrFloor<int, false>(input_long_edge, proposed_tile_long_edge) *
proposed_tile_long_edge;
int num_full_tiles =
int num_full_tiles =
CeilOrFloor<int, false>(input_long_edge, proposed_tile_long_edge);
float cost = num_wasted_threads;
......
......@@ -1688,7 +1688,8 @@ def ssd_loss(
location = __reshape_to_2d(location)
target_bbox = __reshape_to_2d(target_bbox)
loc_loss = nn.smooth_l1(location, target_bbox)
smooth_l1_loss = paddle.nn.loss.SmoothL1Loss()
loc_loss = smooth_l1_loss(location, target_bbox)
target_loc_weight = __reshape_to_2d(target_loc_weight)
loc_loss = loc_loss * target_loc_weight
......
此差异已折叠。
......@@ -1833,8 +1833,8 @@ def fast_decode(
)
logits = paddle.reshape(logits, (-1, trg_vocab_size))
topk_scores, topk_indices = layers.topk(
input=paddle.nn.functional.softmax(logits), k=beam_size
topk_scores, topk_indices = paddle.topk(
x=paddle.nn.functional.softmax(logits), k=beam_size
)
accu_scores = layers.elementwise_add(
x=paddle.log(topk_scores),
......
......@@ -459,9 +459,7 @@ class BaseModel(fluid.dygraph.Layer):
scores = paddle.reshape(
log_probs, [-1, self.beam_size * self.tar_vocab_size]
)
topk_scores, topk_indices = fluid.layers.topk(
input=scores, k=self.beam_size
)
topk_scores, topk_indices = paddle.topk(x=scores, k=self.beam_size)
beam_indices = paddle.floor_divide(topk_indices, vocab_size_tensor)
token_indices = paddle.remainder(topk_indices, vocab_size_tensor)
......
......@@ -853,9 +853,7 @@ class Transformer(Layer):
log_probs, [-1, beam_size * self.trg_vocab_size]
)
scores = log_probs
topk_scores, topk_indices = fluid.layers.topk(
input=scores, k=beam_size
)
topk_scores, topk_indices = paddle.topk(x=scores, k=beam_size)
beam_indices = paddle.floor_divide(topk_indices, vocab_size_tensor)
token_indices = paddle.remainder(topk_indices, vocab_size_tensor)
......
......@@ -31,7 +31,7 @@ class TestTopKOp(IPUOpTest):
self.set_op_attrs()
def set_test_op(self):
self.op = paddle.fluid.layers.topk
self.op = paddle.topk
def set_data_feed(self):
data = np.random.uniform(size=[3, 5])
......
......@@ -138,22 +138,5 @@ class TestSmoothL1LossOp2(OpTest):
)
class TestSmoothL1LossOpError(unittest.TestCase):
def test_errors(self):
with fluid.program_guard(fluid.Program(), fluid.Program()):
# The input type of accuracy_op must be Variable.
x1 = fluid.create_lod_tensor(
np.array([[-1]]), [[1]], fluid.NPUPlace(0)
)
y1 = fluid.create_lod_tensor(
np.array([[-1]]), [[1]], fluid.NPUPlace(0)
)
self.assertRaises(TypeError, fluid.layers.smooth_l1, x1, y1)
# The input dtype of accuracy_op must be float32 or float64.
x2 = fluid.layers.data(name='x2', shape=[4], dtype="int32")
y2 = fluid.layers.data(name='x2', shape=[4], dtype="int32")
self.assertRaises(TypeError, fluid.layers.smooth_l1, x2, y2)
if __name__ == '__main__':
unittest.main()
......@@ -312,7 +312,7 @@ class TestBeamSearchOpError(unittest.TestCase):
name='pre_scores', shape=[1], lod_level=2, dtype='float32'
)
probs = fluid.data(name='probs', shape=[10000], dtype='float32')
topk_scores, topk_indices = fluid.layers.topk(probs, k=4)
topk_scores, topk_indices = paddle.topk(probs, k=4)
accu_scores = fluid.layers.elementwise_add(
x=paddle.log(x=topk_scores),
y=paddle.reshape(pre_scores, shape=[-1]),
......
......@@ -18,7 +18,6 @@ import numpy as np
from op_test import OpTest
import paddle
import paddle.fluid as fluid
def CTCAlign(input, lod, blank, merge_repeated, padding=0, input_length=None):
......@@ -226,50 +225,6 @@ class TestCTCAlignOpCase5(TestCTCAlignPaddingOp):
)
class TestCTCAlignOpApi(unittest.TestCase):
def test_api(self):
x = fluid.layers.data('x', shape=[4], dtype='float32')
y = fluid.layers.ctc_greedy_decoder(x, blank=0)
x_pad = fluid.layers.data('x_pad', shape=[4, 4], dtype='float32')
x_pad_len = fluid.layers.data('x_pad_len', shape=[1], dtype='int64')
y_pad, y_pad_len = fluid.layers.ctc_greedy_decoder(
x_pad, blank=0, input_length=x_pad_len
)
place = fluid.CPUPlace()
x_tensor = fluid.create_lod_tensor(
np.random.rand(8, 4).astype("float32"), [[4, 4]], place
)
x_pad_tensor = np.random.rand(2, 4, 4).astype("float32")
x_pad_len_tensor = np.array([[4], [4]]).reshape([2, 1]).astype("int64")
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
ret = exe.run(
feed={
'x': x_tensor,
'x_pad': x_pad_tensor,
'x_pad_len': x_pad_len_tensor,
},
fetch_list=[y, y_pad, y_pad_len],
return_numpy=False,
)
class BadInputTestCTCAlignr(unittest.TestCase):
def test_error(self):
with fluid.program_guard(fluid.Program()):
def test_bad_x():
x = fluid.layers.data(name='x', shape=[8], dtype='int64')
cost = fluid.layers.ctc_greedy_decoder(input=x, blank=0)
self.assertRaises(TypeError, test_bad_x)
if __name__ == "__main__":
paddle.enable_static()
unittest.main()
......@@ -1519,8 +1519,8 @@ class TestLayer(LayerTest):
with self.dynamic_graph():
with _test_eager_guard():
input = fluid.dygraph.to_variable(np.random.random((13, 11)))
top5_values1, top5_indices1 = layers.topk(input, k=5)
top5_values2, top5_indices2 = layers.topk(
top5_values1, top5_indices1 = paddle.topk(input, k=5)
top5_values2, top5_indices2 = paddle.topk(
input, k=fluid.dygraph.to_variable(np.array([5]))
)
np.testing.assert_array_equal(
......@@ -1531,8 +1531,8 @@ class TestLayer(LayerTest):
)
input = fluid.dygraph.to_variable(np.random.random((13, 11)))
top5_values1, top5_indices1 = layers.topk(input, k=5)
top5_values2, top5_indices2 = layers.topk(
top5_values1, top5_indices1 = paddle.topk(input, k=5)
top5_values2, top5_indices2 = paddle.topk(
input, k=fluid.dygraph.to_variable(np.array([5]))
)
np.testing.assert_array_equal(
......@@ -3104,7 +3104,7 @@ class TestBook(LayerTest):
x1 = self._get_data(name='x1', shape=[4], dtype='float32')
x2 = self._get_data(name='x2', shape=[4], dtype='float32')
index = self._get_data(name='index', shape=[1], dtype='int32')
out = layers.multiplex(inputs=[x1, x2], index=index)
out = paddle.multiplex(inputs=[x1, x2], index=index)
return out
def make_softmax_with_cross_entropy(self):
......@@ -3144,15 +3144,6 @@ class TestBook(LayerTest):
self.assertIsNotNone(loss4)
return loss4
def make_smooth_l1(self):
with program_guard(
fluid.default_main_program(), fluid.default_startup_program()
):
x = self._get_data(name='x', shape=[4], dtype='float32')
y = self._get_data(name='label', shape=[4], dtype='float32')
loss = layers.smooth_l1(x, y)
return loss
def make_scatter(self):
with program_guard(
fluid.default_main_program(), fluid.default_startup_program()
......@@ -3192,7 +3183,7 @@ class TestBook(LayerTest):
fluid.default_main_program(), fluid.default_startup_program()
):
data = self._get_data(name="label", shape=[200], dtype="float32")
values, indices = layers.topk(data, k=5)
values, indices = paddle.topk(data, k=5)
return values
return indices
......@@ -3559,20 +3550,6 @@ class TestBook(LayerTest):
)
)
def test_im2sequence(self):
# TODO(minqiyang): dygraph do not support lod now
with self.static_graph():
x = layers.data(name='x', shape=[3, 128, 128], dtype='float32')
y = layers.data(name='y', shape=[], dtype='float32')
output = layers.im2sequence(
input=x,
input_image_size=y,
stride=[1, 1],
filter_size=[2, 2],
out_stride=[1, 1],
)
return output
def test_lod_reset(self):
# TODO(minqiyang): dygraph do not support lod now
with self.static_graph():
......
......@@ -17,8 +17,6 @@ import unittest
import numpy as np
from op_test import OpTest
import paddle.fluid as fluid
def smooth_l1_loss_forward(val, sigma2):
abs_val = abs(val)
......@@ -124,22 +122,5 @@ class TestSmoothL1LossOp2(OpTest):
)
class TestSmoothL1LossOpError(unittest.TestCase):
def test_errors(self):
with fluid.program_guard(fluid.Program(), fluid.Program()):
# The input type of accuracy_op must be Variable.
x1 = fluid.create_lod_tensor(
np.array([[-1]]), [[1]], fluid.CPUPlace()
)
y1 = fluid.create_lod_tensor(
np.array([[-1]]), [[1]], fluid.CPUPlace()
)
self.assertRaises(TypeError, fluid.layers.smooth_l1, x1, y1)
# The input dtype of accuracy_op must be float32 or float64.
x2 = fluid.layers.data(name='x2', shape=[4], dtype="int32")
y2 = fluid.layers.data(name='x2', shape=[4], dtype="int32")
self.assertRaises(TypeError, fluid.layers.smooth_l1, x2, y2)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册