提交 b65722d3 编写于 作者: P phlrain

fix uni test; test=develop

上级 2770ea1a
......@@ -279,12 +279,6 @@ class CudnnLSTMGPUKernel : public framework::OpKernel<T> {
int num_layers = ctx.Attr<int>("num_layers");
bool is_test = ctx.Attr<bool>("is_test");
/*
if (is_test) {
TensorCopy(*x, ctx.GetPlace(), out);
return;
}*/
auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
auto handle = dev_ctx.cudnn_handle();
auto *cache_var = ctx.InputVar("Cache");
......
......@@ -477,12 +477,10 @@ def lstm(input,
init_h,
init_c,
max_len,
dropout_prob,
input_size,
hidden_size,
num_layers,
dropout_prob=0.0,
is_bidirec=False,
dtype='float32',
is_test=False,
name=None,
default_initializer=None,
......@@ -531,13 +529,11 @@ def lstm(input,
This is a tensor with shape ( num_layers x batch_size x hidden_size )
if is_bidirec = True, shape should be ( num_layers*2 x batch_size x hidden_size)
max_len (int): max length of LSTM. the first dim of input tensor CAN NOT greater than max_len
dropout_prob(float): dropout prob, dropout ONLY work between rnn layers, NOT between time steps
There is NO dropout work on rnn output of the last RNN layers
input_size (int): hidden size of the input tensor
hidden_size (int): hidden size of the LSTM
num_layers (int): total layers number of the LSTM
dropout_prob(float|0.0): dropout prob, dropout ONLY work between rnn layers, NOT between time steps
There is NO dropout work on rnn output of the last RNN layers
is_bidirec (bool): If it is bidirectional
dtype (str): Data type. Choices = ["float32", "float64"], default "float32".
is_test (bool): If it is in test phrase
name (str|None): A name for this layer(optional). If set None, the layer
will be named automatically.
......@@ -577,6 +573,9 @@ def lstm(input,
helper = LayerHelper('cudnn_lstm', **locals())
dtype = input.dtype
input_shape = list(input.shape)
input_size = input_shape[-1]
weight_size = 0
for i in range(num_layers):
if i == 0:
......
......@@ -216,6 +216,15 @@ class OpTest(unittest.TestCase):
self.dtype)
outputs = append_input_output(block, op_proto, self.outputs, False,
self.dtype)
if hasattr(self, "cache_name_list"):
for name in self.cache_name_list:
inputs[name] = block.create_var(
name=name,
persistable=True,
type=core.VarDesc.VarType.RAW,
stop_gradient=True)
op = block.append_op(
type=self.op_type,
inputs=inputs,
......
......@@ -19,6 +19,11 @@ import numpy as np
import paddle.fluid.core as core
from op_test import OpTest
import paddle.fluid as fluid
SIGMOID_THRESHOLD_MIN = -40.0
SIGMOID_THRESHOLD_MAX = 13.0
EXP_MAX_INPUT = 40.0
def lstm_naive(
......@@ -70,10 +75,15 @@ def lstm_naive(
bo_2 = w[offset:offset + hidden_size]
def sigmoid(x):
return 1.0 / (1.0 + np.exp(-x))
y = np.copy(x)
y[x < SIGMOID_THRESHOLD_MIN] = SIGMOID_THRESHOLD_MIN
y[x > SIGMOID_THRESHOLD_MAX] = SIGMOID_THRESHOLD_MAX
return 1. / (1. + np.exp(-y))
def tanh(x):
return (np.exp(x) - np.exp(-x)) / (np.exp(x) + np.exp(-x))
y = -2. * x
y[y > EXP_MAX_INPUT] = EXP_MAX_INPUT
return (2. / (1. + np.exp(y))) - 1.
output = []
pre_h = np.zeros((batch_size, hidden_size), dtype=input.dtype)
......@@ -103,7 +113,7 @@ def lstm_naive(
output = output.transpose((1, 0, 2))
return output
return output, pre_h, pre_c
class TestCUDNNLstmOp(OpTest):
......@@ -120,20 +130,32 @@ class TestCUDNNLstmOp(OpTest):
weight_size = input_weight_size + hidden_weight_size
weight_size += hidden_size * 8
input = np.random.random(
(num_steps, batch_size, hidden_size)).astype(self.dtype)
flat_w = np.random.random((weight_size)).astype(self.dtype)
input = np.random.uniform(
low=-0.1, high=0.1, size=(num_steps, batch_size,
hidden_size)).astype(self.dtype)
flat_w = np.random.uniform(
low=-0.1, high=0.1, size=(weight_size)).astype(self.dtype)
output = lstm_naive(input, flat_w)
output, last_hidden, last_cell = lstm_naive(input, flat_w)
init_h = np.zeros((batch_size, hidden_size), dtype=np.float32)
init_c = np.zeros((batch_size, hidden_size), dtype=np.float32)
scope = core.Scope()
program = fluid.Program()
block = program.global_block()
cache_temp = block.create_var(
name="Cache",
persistable=True,
type=core.VarDesc.VarType.RAW,
stop_gradient=True)
self.inputs = {
'Input': OpTest.np_dtype_to_fluid_dtype(input),
'W': OpTest.np_dtype_to_fluid_dtype(flat_w),
'InitH': OpTest.np_dtype_to_fluid_dtype(init_h),
'InitC': OpTest.np_dtype_to_fluid_dtype(init_c),
}
self.cache_name_list = ['Cache']
self.attrs = {
'max_len': num_steps,
'dropout_prob': 0.0,
......@@ -142,13 +164,16 @@ class TestCUDNNLstmOp(OpTest):
'hidden_size': hidden_size,
'num_layers': 1,
}
self.outputs = {'Out': output}
def test_grad_with_place(self):
place = core.CUDAPlace(0)
self.check_grad_with_place(place, atol=1e-5)
self.outputs = {
'Out': output,
"last_h": last_hidden,
'last_c': last_cell
}
def test_output_with_place(self):
place = core.CUDAPlace(0)
self.check_output_with_place(
place, atol=1e-5, no_check_set=['last_h', 'last_c'])
self.check_output_with_place(place, atol=1e-5)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册