提交 17e33738 编写于 作者: D dangqingqing

Enhance unit testing and fix bug.

上级 8bec26be
......@@ -56,10 +56,6 @@ class LSTMKernel : public framework::OpKernel<T> {
framework::DDim dims({in_dims[0], frame_size});
if (bias) {
// framework::Tensor cpu_t;
// cpu_t.mutable_data<T>(in_dims, platform::CPUPlace());
// cpu_t.CopyFrom<T>(*batch_gate, platform::CPUPlace(),
// ctx.device_context());
Eigen::array<int, 2> extents({{1, 4 * frame_size}});
Eigen::array<int, 2> offsets({{0, 0}});
auto b = EigenMatrix<T>::From(*bias);
......@@ -105,14 +101,14 @@ class LSTMKernel : public framework::OpKernel<T> {
int cur_batch_size = bend - bstart;
if (n != 0) {
int pre_end = batch_lod[n - 1];
auto pre_hidden_t = batch_out.Slice<T>(pre_end, bstart);
int pre_h_start = batch_lod[n - 1];
int pre_h_end = pre_h_start + cur_batch_size;
auto pre_hidden_t = batch_out.Slice<T>(pre_h_start, pre_h_end);
math::matmul<Place, T>(ctx.device_context(), pre_hidden_t, false,
*weight, false, static_cast<T>(1.0), &gate_t,
static_cast<T>(0.0));
static_cast<T>(1.0));
}
// else if : how to pass the state from
// last mini-batch will be supported later
// else if : support the initial hidden and cell
lstm_value.gateValue = gate_t.data<T>();
lstm_value.outputValue = out_t.data<T>();
......@@ -132,9 +128,6 @@ class LSTMKernel : public framework::OpKernel<T> {
batch_cell.set_lod(batch_gate->lod());
// restore the output cell state in LoDTensor from the batch cell
to_seq(ctx.device_context(), batch_cell, *cell_out);
auto t = framework::EigenVector<T>::Flatten(*batch_gate);
t.device(ctx.GetEigenDevice<Place>()) = t.constant(static_cast<T>(0));
}
};
......
......@@ -30,7 +30,9 @@ __device__ static float sigmoid(const float a) {
}
__device__ static float tanh(const float a) {
return __fdividef(2.0f, (1.0f + __expf(-2.0f * a))) - 1.0f;
float tmp = -2.0 * a;
tmp = (tmp > EXP_MAX_INPUT) ? EXP_MAX_INPUT : tmp;
return __fdividef(2.0f, (1.0f + __expf(-2.0f * tmp))) - 1.0f;
}
__device__ static float linear(const float a) { return a; }
......@@ -63,6 +65,8 @@ __device__ static double sigmoid(const double a) {
}
__device__ static double tanh(const double a) {
double tmp = -2.0 * a;
tmp = (tmp > EXP_MAX_INPUT) ? EXP_MAX_INPUT : tmp;
return (2.0 / (1.0 + exp(-2.0 * a))) - 1.0;
}
......
......@@ -205,11 +205,13 @@ void gpu_lstm_forward(const platform::DeviceContext& context, Op op,
if (batchSize == 1) {
KeLstmForward<T, Op,
/* isBatch= */ false><<<grid, threads, 0, stream>>>(
op, value, frameSize, batchSize, active_node, active_gate, active_gate);
op, value, frameSize, batchSize, active_node, active_gate,
active_state);
} else {
KeLstmForward<T, Op,
/* isBatch= */ true><<<grid, threads, 0, stream>>>(
op, value, frameSize, batchSize, active_node, active_gate, active_gate);
op, value, frameSize, batchSize, active_node, active_gate,
active_state);
}
}
......
......@@ -60,7 +60,7 @@ inline activation_mode_t ActiveType(const std::string &type) {
return HL_ACTIVATION_RELU;
} else if (type == "tanh") {
return HL_ACTIVATION_TANH;
} else if (type == "linear" || type == "") {
} else if (type == "linear" || type == "identity" || type == "") {
return HL_ACTIVATION_LINEAR;
} else {
PADDLE_THROW("Do not support activation type.");
......
......@@ -242,7 +242,7 @@ class OpTest(unittest.TestCase):
self.assertTrue(
np.allclose(
actual, expect, atol=atol),
"output name: " + out_name + " has diff.")
"Output (" + out_name + ") has diff at " + str(place))
else:
actual = np.array(self.scope.find_var(out_name).get_tensor())
expect = self.outputs[out_name]
......@@ -250,7 +250,7 @@ class OpTest(unittest.TestCase):
self.assertTrue(
np.allclose(
actual, expect, atol=atol),
"output name: " + out_name + " has diff.")
"Output (" + out_name + ") has diff at " + str(place))
def check_output(self, atol=1e-5):
places = [core.CPUPlace()]
......
......@@ -28,6 +28,14 @@ def relu(x):
return np.maximum(x, 0)
ACTVATION = {
'identity': identity,
'sigmoid': sigmoid,
'tanh': tanh,
'relu': relu
}
def lstm(
input, # T x 4D
lod, # 1 x N
......@@ -37,37 +45,45 @@ def lstm(
w_b=None, # 1 x 4D
w_c=None, # 1 x 3D
is_reverse=False,
gate_act=None,
cell_act=None,
cand_act=None):
def _step(x, w_h, w_c, h_pre, c_pre, gate_act, cell_act, cand_act):
act_gate=None,
act_cell=None,
act_cand=None):
def _step(x, w_h, w_c, h_pre, c_pre, act_gate, act_cell, act_cand):
g = np.dot(h_pre, w_h) # 1 x 4D
g = g + x
g = np.reshape(g, (1, g.size))
c_tmp, g_i, g_f, g_o = np.split(g, 4, axis=1)
if w_c is None:
g_i = gate_act(g_i) # 1 x D
g_f = gate_act(g_f) # 1 x D
g_i = act_gate(g_i) # 1 x D
g_f = act_gate(g_f) # 1 x D
else:
w_ic, w_fc, w_oc = np.split(w_c, 3, axis=1)
g_i = gate_act(g_i + w_ic * c_pre) # 1 x D
g_f = gate_act(g_f + w_fc * c_pre) # 1 x D
c = g_f * c_pre + g_i * cand_act(c_tmp) # 1 x D
g_i = act_gate(g_i + w_ic * c_pre) # 1 x D
g_f = act_gate(g_f + w_fc * c_pre) # 1 x D
c = g_f * c_pre + g_i * act_cand(c_tmp) # 1 x D
if w_c is None:
g_o = gate_act(g_o) # 1 x D
g_o = act_gate(g_o) # 1 x D
else:
_, _, w_oc = np.split(w_c, 3, axis=1)
g_o = gate_act(g_o + w_oc * c) # 1 x D
h = g_o * cell_act(c)
bg = np.concatenate((cand_act(c_tmp), g_i, g_f, g_o), axis=1)
g_o = act_gate(g_o + w_oc * c) # 1 x D
h = g_o * act_cell(c)
bg = np.concatenate((act_cand(c_tmp), g_i, g_f, g_o), axis=1)
return h, c, bg
def _reverse(x, lod):
y = np.zeros_like(x)
for i in range(len(lod) - 1):
b, e = lod[i], lod[i + 1]
y[b:e, :] = np.flip(x[b:e, :], 0)
return y
offset = lod[0]
batch_size = len(offset) - 1
hidden = []
cell = []
gate = []
input = _reverse(input, offset) if is_reverse else input
if w_b is not None:
input = input + np.tile(w_b, (offset[-1], 1))
for i in range(batch_size):
......@@ -78,8 +94,8 @@ def lstm(
c_pre = c0[i] # 1 x D
for j in range(seq_len):
# compute one step
h_pre, c_pre, g_pre = _step(x[j], w_h, w_c, h_pre, c_pre, gate_act,
cell_act, cand_act)
h_pre, c_pre, g_pre = _step(x[j], w_h, w_c, h_pre, c_pre, act_gate,
act_cell, act_cand)
hidden.append(h_pre.flatten())
cell.append(c_pre.flatten())
gate.append(g_pre.flatten())
......@@ -87,38 +103,53 @@ def lstm(
hidden = np.array(hidden).astype("float64")
cell = np.array(cell).astype("float64")
gate = np.array(gate).astype("float64")
hidden = _reverse(hidden, offset) if is_reverse else hidden
cell = _reverse(cell, offset) if is_reverse else cell
assert gate.shape == input.shape
assert hidden.shape == (input.shape[0], input.shape[1] / 4)
assert cell.shape == (input.shape[0], input.shape[1] / 4)
return hidden, cell, gate
class LstmUnitTest(OpTest):
class TestLstmOp(OpTest):
def set_data(self):
D = 4
#lod = [[0, 2, 6, 9]]
lod = [[0, 1]]
shape = (1, D)
x = np.random.normal(size=(1, 4 * D)).astype("float64")
h0 = np.zeros((4, D)).astype("float64")
c0 = np.zeros((4, D)).astype("float64")
w = np.random.normal(size=(D, 4 * D)).astype("float64")
b = np.random.normal(size=(1, 7 * D)).astype("float64")
w_b = b[:, 0:4 * D]
w_c = b[:, 4 * D:]
#h, c, g = lstm(x, lod, h0, c0, w, w_b, w_c, False, sigmoid, tanh, tanh)
h, c, g = lstm(x, lod, h0, c0, w, w_b, w_c, False, identity, identity,
identity)
self.lod = [[0, 2, 6, 9]]
self.D = 64
self.sort_idx = [2, 6, 0, 3, 7, 1, 4, 8, 5]
self.act_gate = "sigmoid"
self.act_cell = "tanh"
self.act_cand = "tanh"
self.is_reverse = False
def setUp(self):
self.set_data()
self.op_type = "lstm"
T = self.lod[0][-1]
N = len(self.lod[0]) - 1
x = np.random.normal(size=(T, 4 * self.D)).astype("float64")
h0 = np.zeros((N, self.D)).astype("float64")
c0 = np.zeros((N, self.D)).astype("float64")
w = np.random.normal(size=(self.D, 4 * self.D)).astype("float64")
b = np.random.normal(size=(1, 7 * self.D)).astype("float64")
w_b = b[:, 0:4 * self.D]
w_c = b[:, 4 * self.D:]
h, c, g = lstm(x, self.lod, h0, c0, w, w_b, w_c, self.is_reverse,
ACTVATION[self.act_gate], ACTVATION[self.act_cell],
ACTVATION[self.act_cand])
g_sort = np.zeros_like(x)
#idx = [2,6,0,3,7,1,4,8,5]
#for i, j in enumerate(idx):
# g_sort[i, :] = g[j, :]
for i, j in enumerate(self.sort_idx):
g_sort[i, :] = g[j, :]
self.inputs = {
'Input': (x, lod),
'Input': (x, self.lod),
'H0': h0,
'C0': c0,
'Weight': w,
......@@ -127,19 +158,28 @@ class LstmUnitTest(OpTest):
self.outputs = {'Hidden': h, 'Cell': c, 'BatchGate': g_sort}
self.attrs = {
'usePeepholes': True,
'isReverse': False,
'gateActivation': 'linear',
'cellActivation': 'linear',
'candidateActivation': 'linear'
'isReverse': self.is_reverse,
'gateActivation': 'sigmoid',
'cellActivation': 'tanh',
'candidateActivation': 'tanh'
}
def setUp(self):
self.set_data()
self.op_type = "lstm"
def test_check_output(self):
self.check_output()
class TestLstmOpRerverse(TestLstmOp):
def set_data(self):
self.lod = [[0, 2, 6, 9]]
self.D = 64
self.sort_idx = [2, 6, 0, 3, 7, 1, 4, 8, 5]
self.act_gate = "sigmoid"
self.act_cell = "tanh"
self.act_cand = "tanh"
self.is_reverse = True
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册