From b65722d3cf0ac525eaf39fd026013f1aaf718531 Mon Sep 17 00:00:00 2001
From: phlrain
Date: Sat, 1 Dec 2018 16:03:42 +0800
Subject: [PATCH] fix uni test; test=develop

---
 paddle/fluid/operators/cudnn_lstm_op.cu.cc    |  6 ---
 python/paddle/fluid/layers/nn.py              | 13 +++--
 .../paddle/fluid/tests/unittests/op_test.py   |  9 ++++
 .../tests/unittests/test_lstm_cudnn_op.py     | 53 ++++++++++++++-----
 4 files changed, 54 insertions(+), 27 deletions(-)

diff --git a/paddle/fluid/operators/cudnn_lstm_op.cu.cc b/paddle/fluid/operators/cudnn_lstm_op.cu.cc
index cad62de754..e01070c7b8 100644
--- a/paddle/fluid/operators/cudnn_lstm_op.cu.cc
+++ b/paddle/fluid/operators/cudnn_lstm_op.cu.cc
@@ -279,12 +279,6 @@ class CudnnLSTMGPUKernel : public framework::OpKernel<T> {
     int num_layers = ctx.Attr<int>("num_layers");
     bool is_test = ctx.Attr<bool>("is_test");
 
-    /*
-    if (is_test) {
-      TensorCopy(*x, ctx.GetPlace(), out);
-      return;
-    }*/
-
     auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
     auto handle = dev_ctx.cudnn_handle();
     auto *cache_var = ctx.InputVar("Cache");
diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py
index f9e3da68d7..dbc39afccb 100644
--- a/python/paddle/fluid/layers/nn.py
+++ b/python/paddle/fluid/layers/nn.py
@@ -477,12 +477,10 @@ def lstm(input,
          init_h,
          init_c,
          max_len,
-         dropout_prob,
-         input_size,
          hidden_size,
          num_layers,
+         dropout_prob=0.0,
          is_bidirec=False,
-         dtype='float32',
          is_test=False,
          name=None,
         default_initializer=None,
@@ -531,13 +529,11 @@ def lstm(input,
             This is a tensor with shape ( num_layers x batch_size x hidden_size )
             if is_bidirec = True, shape should be ( num_layers*2 x batch_size x hidden_size)
         max_len (int): max length of LSTM. the first dim of input tensor CAN NOT greater than max_len
-        dropout_prob(float): dropout prob, dropout ONLY work between rnn layers, NOT between time steps
-                             There is NO dropout work on rnn output of the last RNN layers
-        input_size (int): hidden size of the input tensor
         hidden_size (int): hidden size of the LSTM
         num_layers (int): total layers number of the LSTM
+        dropout_prob(float|0.0): dropout probability; dropout only works between RNN layers, not between time steps.
+                             No dropout is applied to the output of the last RNN layer.
         is_bidirec (bool): If it is bidirectional
-        dtype (str): Data type. Choices = ["float32", "float64"], default "float32".
         is_test (bool): If it is in test phrase
         name (str|None): A name for this layer(optional). If set None, the layer
                          will be named automatically.
@@ -577,6 +573,9 @@ def lstm(input,
 
     helper = LayerHelper('cudnn_lstm', **locals())
 
+    dtype = input.dtype
+    input_shape = list(input.shape)
+    input_size = input_shape[-1]
     weight_size = 0
     for i in range(num_layers):
         if i == 0:
diff --git a/python/paddle/fluid/tests/unittests/op_test.py b/python/paddle/fluid/tests/unittests/op_test.py
index 271b9c740f..0200d74136 100644
--- a/python/paddle/fluid/tests/unittests/op_test.py
+++ b/python/paddle/fluid/tests/unittests/op_test.py
@@ -216,6 +216,15 @@ class OpTest(unittest.TestCase):
                                       self.dtype)
         outputs = append_input_output(block, op_proto, self.outputs, False,
                                       self.dtype)
+
+        if hasattr(self, "cache_name_list"):
+            for name in self.cache_name_list:
+                inputs[name] = block.create_var(
+                    name=name,
+                    persistable=True,
+                    type=core.VarDesc.VarType.RAW,
+                    stop_gradient=True)
+
         op = block.append_op(
             type=self.op_type,
             inputs=inputs,
diff --git a/python/paddle/fluid/tests/unittests/test_lstm_cudnn_op.py b/python/paddle/fluid/tests/unittests/test_lstm_cudnn_op.py
index 2741bf167b..8d313970cc 100644
--- a/python/paddle/fluid/tests/unittests/test_lstm_cudnn_op.py
+++ b/python/paddle/fluid/tests/unittests/test_lstm_cudnn_op.py
@@ -19,6 +19,11 @@ import numpy as np
 import paddle.fluid.core as core
 from op_test import OpTest
+import paddle.fluid as fluid
+
+SIGMOID_THRESHOLD_MIN = -40.0
+SIGMOID_THRESHOLD_MAX = 13.0
+EXP_MAX_INPUT = 40.0
 
 
 def lstm_naive(
         input,
@@ -70,10 +75,15 @@ def lstm_naive(
     bo_2 = w[offset:offset + hidden_size]
 
     def sigmoid(x):
-        return 1.0 / (1.0 + np.exp(-x))
+        y = np.copy(x)
+        y[x < SIGMOID_THRESHOLD_MIN] = SIGMOID_THRESHOLD_MIN
+        y[x > SIGMOID_THRESHOLD_MAX] = SIGMOID_THRESHOLD_MAX
+        return 1. / (1. + np.exp(-y))
 
     def tanh(x):
-        return (np.exp(x) - np.exp(-x)) / (np.exp(x) + np.exp(-x))
+        y = -2. * x
+        y[y > EXP_MAX_INPUT] = EXP_MAX_INPUT
+        return (2. / (1. + np.exp(y))) - 1.
 
     output = []
     pre_h = np.zeros((batch_size, hidden_size), dtype=input.dtype)
@@ -103,7 +113,7 @@ def lstm_naive(
 
     output = output.transpose((1, 0, 2))
 
-    return output
+    return output, pre_h, pre_c
 
 
 class TestCUDNNLstmOp(OpTest):
@@ -120,20 +130,32 @@ class TestCUDNNLstmOp(OpTest):
         weight_size = input_weight_size + hidden_weight_size
         weight_size += hidden_size * 8
 
-        input = np.random.random(
-            (num_steps, batch_size, hidden_size)).astype(self.dtype)
-        flat_w = np.random.random((weight_size)).astype(self.dtype)
+        input = np.random.uniform(
+            low=-0.1, high=0.1, size=(num_steps, batch_size,
+                                      hidden_size)).astype(self.dtype)
+        flat_w = np.random.uniform(
+            low=-0.1, high=0.1, size=(weight_size)).astype(self.dtype)
 
-        output = lstm_naive(input, flat_w)
+        output, last_hidden, last_cell = lstm_naive(input, flat_w)
 
         init_h = np.zeros((batch_size, hidden_size), dtype=np.float32)
         init_c = np.zeros((batch_size, hidden_size), dtype=np.float32)
+        scope = core.Scope()
+        program = fluid.Program()
+        block = program.global_block()
+
+        cache_temp = block.create_var(
+            name="Cache",
+            persistable=True,
+            type=core.VarDesc.VarType.RAW,
+            stop_gradient=True)
         self.inputs = {
             'Input': OpTest.np_dtype_to_fluid_dtype(input),
             'W': OpTest.np_dtype_to_fluid_dtype(flat_w),
             'InitH': OpTest.np_dtype_to_fluid_dtype(init_h),
             'InitC': OpTest.np_dtype_to_fluid_dtype(init_c),
         }
+        self.cache_name_list = ['Cache']
         self.attrs = {
             'max_len': num_steps,
             'dropout_prob': 0.0,
@@ -142,13 +164,16 @@ class TestCUDNNLstmOp(OpTest):
             'hidden_size': hidden_size,
             'num_layers': 1,
         }
-        self.outputs = {'Out': output}
-
-    def test_grad_with_place(self):
-        place = core.CUDAPlace(0)
-        self.check_grad_with_place(place, atol=1e-5)
+        self.outputs = {
+            'Out': output,
+            "last_h": last_hidden,
+            'last_c': last_cell
+        }
 
     def test_output_with_place(self):
         place = core.CUDAPlace(0)
-        self.check_output_with_place(
-            place, atol=1e-5, no_check_set=['last_h', 'last_c'])
+        self.check_output_with_place(place, atol=1e-5)
+
+
+if __name__ == '__main__':
+    unittest.main()
-- 
GitLab
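
Editor's note on the API change in python/paddle/fluid/layers/nn.py: the patch drops the
explicit input_size and dtype arguments of fluid.layers.lstm (both are now inferred from
the input variable) and turns dropout_prob into a keyword argument defaulting to 0.0. A
minimal usage sketch of the new signature follows; it is not part of the patch. The
(rnn_out, last_h, last_c) return triple and the [seq_len, batch_size, input_size] input
layout are assumptions based on the docstring context quoted above, not on code shown in
this diff.

    import paddle.fluid as fluid
    import paddle.fluid.layers as layers

    batch_size, max_len, input_size, hidden_size, num_layers = 32, 100, 128, 256, 1

    # Input layout assumed [seq_len, batch_size, input_size]; per the docstring,
    # seq_len must not exceed max_len.
    x = layers.data(name='x', shape=[max_len, batch_size, input_size],
                    dtype='float32', append_batch_size=False)

    # Initial hidden/cell states: [num_layers, batch_size, hidden_size]
    # (double num_layers when is_bidirec=True).
    init_h = layers.fill_constant(
        [num_layers, batch_size, hidden_size], 'float32', 0.0)
    init_c = layers.fill_constant(
        [num_layers, batch_size, hidden_size], 'float32', 0.0)

    # No input_size or dtype arguments anymore; dropout_prob defaults to 0.0.
    rnn_out, last_h, last_c = layers.lstm(x, init_h, init_c, max_len,
                                          hidden_size, num_layers)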