Unverified · Commit 84bae277 · authored by Jack Zhou, committed by GitHub

fix wmt14 doc, remove backward, add bidirect direction in rnn api (#29633)

* fix wmt14 doc, remove backward, add bidirect direction in rnn api

* fix rnn unittest

* fix test_rnn_nets_static.py bug
Parent 613c46bc
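In short, the `direction` argument of the RNN layers now accepts "forward", "bidirect", and "bidirectional" ("bidirect" is the new alias), and "backward" is removed. A minimal sketch of the resulting dispatch, distilled from the hunks below (the helper name is illustrative, not part of the diff):

bidirectional_list = ["bidirectional", "bidirect"]

def num_directions(direction):
    # "backward" is no longer accepted anywhere in the RNN stack.
    if direction == "forward":
        return 1  # a single forward layer per depth level
    elif direction in bidirectional_list:
        return 2  # a forward/backward pair (BiRNN) per depth level
    raise ValueError(
        "direction should be forward or bidirect (or bidirectional), "
        "received direction = {}".format(direction))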
@@ -414,9 +414,9 @@ class SimpleRNN(RNNMixin):
              time_major=False,
              dtype="float64"):
         super(SimpleRNN, self).__init__()
-        if direction in ["forward", "backward"]:
-            is_reverse = direction == "backward"
+        bidirectional_list = ["bidirectional", "bidirect"]
+        if direction in ["forward"]:
+            is_reverse = False
             cell = SimpleRNNCell(
                 input_size, hidden_size, nonlinearity=nonlinearity, dtype=dtype)
             self.append(RNN(cell, is_reverse, time_major))
@@ -427,7 +427,7 @@ class SimpleRNN(RNNMixin):
                 nonlinearity=nonlinearity,
                 dtype=dtype)
             self.append(RNN(cell, is_reverse, time_major))
-        elif direction == "bidirectional":
+        elif direction in bidirectional_list:
             cell_fw = SimpleRNNCell(
                 input_size, hidden_size, nonlinearity=nonlinearity, dtype=dtype)
             cell_bw = SimpleRNNCell(
@@ -447,7 +447,7 @@ class SimpleRNN(RNNMixin):
         self.input_size = input_size
         self.hidden_size = hidden_size
         self.dropout = dropout
-        self.num_directions = 2 if direction == "bidirectional" else 1
+        self.num_directions = 2 if direction in bidirectional_list else 1
         self.time_major = time_major
         self.num_layers = num_layers
         self.state_components = 1
@@ -464,14 +464,15 @@ class LSTM(RNNMixin):
              dtype="float64"):
         super(LSTM, self).__init__()
-        if direction in ["forward", "backward"]:
-            is_reverse = direction == "backward"
+        bidirectional_list = ["bidirectional", "bidirect"]
+        if direction in ["forward"]:
+            is_reverse = False
             cell = LSTMCell(input_size, hidden_size, dtype=dtype)
             self.append(RNN(cell, is_reverse, time_major))
             for i in range(1, num_layers):
                 cell = LSTMCell(hidden_size, hidden_size, dtype=dtype)
                 self.append(RNN(cell, is_reverse, time_major))
-        elif direction == "bidirectional":
+        elif direction in bidirectional_list:
             cell_fw = LSTMCell(input_size, hidden_size, dtype=dtype)
             cell_bw = LSTMCell(input_size, hidden_size, dtype=dtype)
             self.append(BiRNN(cell_fw, cell_bw, time_major))
@@ -487,7 +488,7 @@ class LSTM(RNNMixin):
         self.input_size = input_size
         self.hidden_size = hidden_size
         self.dropout = dropout
-        self.num_directions = 2 if direction == "bidirectional" else 1
+        self.num_directions = 2 if direction in bidirectional_list else 1
         self.time_major = time_major
         self.num_layers = num_layers
         self.state_components = 2
@@ -504,14 +505,15 @@ class GRU(RNNMixin):
              dtype="float64"):
         super(GRU, self).__init__()
-        if direction in ["forward", "backward"]:
-            is_reverse = direction == "backward"
+        bidirectional_list = ["bidirectional", "bidirect"]
+        if direction in ["forward"]:
+            is_reverse = False
             cell = GRUCell(input_size, hidden_size, dtype=dtype)
             self.append(RNN(cell, is_reverse, time_major))
             for i in range(1, num_layers):
                 cell = GRUCell(hidden_size, hidden_size, dtype=dtype)
                 self.append(RNN(cell, is_reverse, time_major))
-        elif direction == "bidirectional":
+        elif direction in bidirectional_list:
             cell_fw = GRUCell(input_size, hidden_size, dtype=dtype)
             cell_bw = GRUCell(input_size, hidden_size, dtype=dtype)
             self.append(BiRNN(cell_fw, cell_bw, time_major))
@@ -527,7 +529,7 @@ class GRU(RNNMixin):
         self.input_size = input_size
         self.hidden_size = hidden_size
         self.dropout = dropout
-        self.num_directions = 2 if direction == "bidirectional" else 1
+        self.num_directions = 2 if direction in bidirectional_list else 1
         self.time_major = time_major
         self.num_layers = num_layers
         self.state_components = 1
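The numpy classes above are the references that the unit tests below compare against. A hedged sketch of the pairing, runnable only inside the test directory (convert_params_for_net is the helper imported in the next file; its argument order is assumed here):

import paddle
from convert import convert_params_for_net  # test-suite helper; argument order assumed
import rnn_numpy

# Build the numpy reference and the paddle layer with the new alias, then
# sync their parameters so outputs can be compared elementwise.
rnn_np = rnn_numpy.SimpleRNN(16, 32, direction="bidirect")
rnn_pd = paddle.nn.SimpleRNN(16, 32, direction="bidirect")
convert_params_for_net(rnn_np, rnn_pd)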
@@ -22,13 +22,15 @@ import unittest
 from convert import convert_params_for_net
 from rnn_numpy import SimpleRNN, LSTM, GRU
 
+bidirectional_list = ["bidirectional", "bidirect"]
+
 class TestSimpleRNN(unittest.TestCase):
     def __init__(self, time_major=True, direction="forward", place="cpu"):
         super(TestSimpleRNN, self).__init__("runTest")
         self.time_major = time_major
         self.direction = direction
-        self.num_directions = 2 if direction == "bidirectional" else 1
+        self.num_directions = 2 if direction in bidirectional_list else 1
         self.place = place
 
     def setUp(self):
@@ -109,7 +111,7 @@ class TestGRU(unittest.TestCase):
         super(TestGRU, self).__init__("runTest")
         self.time_major = time_major
         self.direction = direction
-        self.num_directions = 2 if direction == "bidirectional" else 1
+        self.num_directions = 2 if direction in bidirectional_list else 1
         self.place = place
 
     def setUp(self):
@@ -196,7 +198,7 @@ class TestLSTM(unittest.TestCase):
         super(TestLSTM, self).__init__("runTest")
         self.time_major = time_major
         self.direction = direction
-        self.num_directions = 2 if direction == "bidirectional" else 1
+        self.num_directions = 2 if direction in bidirectional_list else 1
         self.place = place
 
     def setUp(self):
@@ -339,7 +341,7 @@ def load_tests(loader, tests, pattern):
     suite = unittest.TestSuite()
     devices = ["cpu", "gpu"] if paddle.fluid.is_compiled_with_cuda() \
         else ["cpu"]
-    for direction in ["forward", "backward", "bidirectional"]:
+    for direction in ["forward", "bidirectional", "bidirect"]:
         for time_major in [True, False]:
             for device in devices:
                 for test_class in [TestSimpleRNN, TestLSTM, TestGRU]:
...
@@ -23,13 +23,15 @@ import unittest
 from convert import convert_params_for_net_static
 from rnn_numpy import SimpleRNN, LSTM, GRU
 
+bidirectional_list = ["bidirectional", "bidirect"]
+
 class TestSimpleRNN(unittest.TestCase):
     def __init__(self, time_major=True, direction="forward", place="cpu"):
         super(TestSimpleRNN, self).__init__("runTest")
         self.time_major = time_major
         self.direction = direction
-        self.num_directions = 2 if direction == "bidirectional" else 1
+        self.num_directions = 2 if direction in bidirectional_list else 1
         self.place = place
 
     def setUp(self):
@@ -173,7 +175,7 @@ class TestGRU(unittest.TestCase):
         super(TestGRU, self).__init__("runTest")
         self.time_major = time_major
         self.direction = direction
-        self.num_directions = 2 if direction == "bidirectional" else 1
+        self.num_directions = 2 if direction in bidirectional_list else 1
         self.place = place
 
     def setUp(self):
@@ -319,7 +321,7 @@ class TestLSTM(unittest.TestCase):
         super(TestLSTM, self).__init__("runTest")
         self.time_major = time_major
         self.direction = direction
-        self.num_directions = 2 if direction == "bidirectional" else 1
+        self.num_directions = 2 if direction in bidirectional_list else 1
         self.place = place
 
     def setUp(self):
@@ -469,9 +471,13 @@ def load_tests(loader, tests, pattern):
     suite = unittest.TestSuite()
     devices = ["cpu", "gpu"] if paddle.fluid.is_compiled_with_cuda() \
         else ["cpu"]
-    for direction in ["forward", "backward", "bidirectional"]:
+    for direction in ["forward", "bidirectional", "bidirect"]:
         for time_major in [True, False]:
             for device in devices:
                 for test_class in [TestSimpleRNN, TestLSTM, TestGRU]:
                     suite.addTest(test_class(time_major, direction, device))
     return suite
+
+
+if __name__ == "__main__":
+    unittest.main()
@@ -858,11 +858,12 @@ class RNNBase(LayerList):
              bias_ih_attr=None,
              bias_hh_attr=None):
         super(RNNBase, self).__init__()
+        bidirectional_list = ["bidirectional", "bidirect"]
         self.mode = mode
         self.input_size = input_size
         self.hidden_size = hidden_size
         self.dropout = dropout
-        self.num_directions = 2 if direction == "bidirectional" else 1
+        self.num_directions = 2 if direction in bidirectional_list else 1
         self.time_major = time_major
         self.num_layers = num_layers
         self.state_components = 2 if mode == "LSTM" else 1
@@ -882,14 +883,14 @@ class RNNBase(LayerList):
             rnn_cls = SimpleRNNCell
             kwargs["activation"] = self.activation
 
-        if direction in ["forward", "backward"]:
-            is_reverse = direction == "backward"
+        if direction in ["forward"]:
+            is_reverse = False
             cell = rnn_cls(input_size, hidden_size, **kwargs)
             self.append(RNN(cell, is_reverse, time_major))
             for i in range(1, num_layers):
                 cell = rnn_cls(hidden_size, hidden_size, **kwargs)
                 self.append(RNN(cell, is_reverse, time_major))
-        elif direction == "bidirectional":
+        elif direction in bidirectional_list:
             cell_fw = rnn_cls(input_size, hidden_size, **kwargs)
             cell_bw = rnn_cls(input_size, hidden_size, **kwargs)
             self.append(BiRNN(cell_fw, cell_bw, time_major))
@@ -899,13 +900,12 @@ class RNNBase(LayerList):
                 self.append(BiRNN(cell_fw, cell_bw, time_major))
         else:
             raise ValueError(
-                "direction should be forward, backward or bidirectional, "
+                "direction should be forward or bidirect (or bidirectional), "
                 "received direction = {}".format(direction))
 
         self.could_use_cudnn = True
-        self.could_use_cudnn &= direction != "backward"
         self.could_use_cudnn &= len(self.parameters()) == num_layers * 4 * (
-            2 if direction == "bidirectional" else 1)
+            2 if direction in bidirectional_list else 1)
 
         # Expose params as RNN's attribute, which can make it compatible when
         # replacing small ops composed rnn with cpp rnn kernel.
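With the `direction != "backward"` guard gone, fused-kernel eligibility rests solely on the parameter count. An illustrative standalone restatement of the check kept above (not the diff itself):

def could_use_cudnn(parameters, num_layers, direction):
    # Each layer holds 4 parameter tensors per direction: weight_ih,
    # weight_hh, bias_ih, bias_hh. The old "backward" guard is unnecessary
    # because "backward" is no longer a legal direction at all.
    bidirectional_list = ["bidirectional", "bidirect"]
    num_directions = 2 if direction in bidirectional_list else 1
    return len(parameters) == num_layers * 4 * num_directions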
@@ -1079,8 +1079,8 @@ class SimpleRNN(RNNBase):
         input_size (int): The input size for the first layer's cell.
         hidden_size (int): The hidden size for each layer's cell.
         num_layers (int, optional): Number of layers. Defaults to 1.
-        direction (str, optional): The direction of the network. It can be "forward",
-            "backward" and "bidirectional". When "bidirectional", the way to merge
+        direction (str, optional): The direction of the network. It can be "forward"
+            or "bidirect"(or "bidirectional"). When "bidirect", the way to merge
             outputs of forward and backward is concatenating. Defaults to "forward".
         time_major (bool, optional): Whether the first dimension of the input means the
             time steps. Defaults to False.
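A minimal dygraph usage sketch of the aliased direction (shapes follow the docstring conventions; the example itself is not part of the diff):

import paddle

rnn = paddle.nn.SimpleRNN(16, 32, num_layers=2, direction="bidirect")
x = paddle.randn((4, 23, 16))  # [batch, time, input_size] since time_major=False
y, h = rnn(x)
print(y.shape)  # [4, 23, 64]: hidden_size doubled by the two directions
print(h.shape)  # [4, 4, 32]:  num_layers * num_directions state slots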
@@ -1195,8 +1195,8 @@ class LSTM(RNNBase):
         input_size (int): The input size for the first layer's cell.
         hidden_size (int): The hidden size for each layer's cell.
         num_layers (int, optional): Number of layers. Defaults to 1.
-        direction (str, optional): The direction of the network. It can be "forward",
-            "backward" and "bidirectional". When "bidirectional", the way to merge
+        direction (str, optional): The direction of the network. It can be "forward"
+            or "bidirect"(or "bidirectional"). When "bidirect", the way to merge
             outputs of forward and backward is concatenating. Defaults to "forward".
         time_major (bool, optional): Whether the first dimension of the input
             means the time steps. Defaults to False.
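The same direction values apply to LSTM, but its final state is an (h, c) pair since state_components is 2. A brief sketch under the same assumptions as the SimpleRNN example above:

import paddle

lstm = paddle.nn.LSTM(16, 32, num_layers=2, direction="bidirectional")
x = paddle.randn((4, 23, 16))
y, (h, c) = lstm(x)  # two state tensors, each [num_layers * 2, batch, hidden_size]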
@@ -1300,8 +1300,8 @@ class GRU(RNNBase):
         input_size (int): The input size for the first layer's cell.
         hidden_size (int): The hidden size for each layer's cell.
         num_layers (int, optional): Number of layers. Defaults to 1.
-        direction (str, optional): The direction of the network. It can be "forward",
-            "backward" and "bidirectional". When "bidirectional", the way to merge
+        direction (str, optional): The direction of the network. It can be "forward"
+            or "bidirect"(or "bidirectional"). When "bidirect", the way to merge
             outputs of forward and backward is concatenating. Defaults to "forward".
         time_major (bool, optional): Whether the first dimension of the input
             means the time steps. Defaults to False.
...
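For GRU, "bidirect" and "bidirectional" build identical networks, while "backward" now raises the ValueError shown in the RNNBase hunk above. A brief sketch:

import paddle

gru_a = paddle.nn.GRU(16, 32, direction="bidirect")
gru_b = paddle.nn.GRU(16, 32, direction="bidirectional")  # same topology as gru_a
# paddle.nn.GRU(16, 32, direction="backward")  # would now raise ValueError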
@@ -43,7 +43,7 @@ class WMT14(Dataset):
     Implementation of `WMT14 <http://www.statmt.org/wmt14/>`_ test dataset.
     The original WMT14 dataset is too large and a small set of data for set is
     provided. This module will download dataset from
-    http://paddlepaddle.bj.bcebos.com/demo/wmt_shrinked_data/wmt14.tgz
+    http://paddlemodels.bj.bcebos.com/wmt/wmt14.tgz .
 
     Args:
         data_file(str): path to data tar file, can be set None if
@@ -70,8 +70,6 @@ class WMT14(Dataset):
                 def forward(self, src_ids, trg_ids, trg_ids_next):
                     return paddle.sum(src_ids), paddle.sum(trg_ids), paddle.sum(trg_ids_next)
 
-            paddle.disable_static()
-
             wmt14 = WMT14(mode='train', dict_size=50)
 
             for i in range(10):
...
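The corrected download URL is exercised simply by constructing the dataset. A minimal sketch along the lines of the docstring example above (field unpacking assumed from the docstring's forward signature):

from paddle.text.datasets import WMT14

# Downloads the shrunk corpus from the corrected bcebos URL when data_file is None.
wmt14 = WMT14(mode='train', dict_size=50)
src_ids, trg_ids, trg_ids_next = wmt14[0]  # per-sample fields, as in the docstring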