Commit bd680f15 authored by dangqingqing

fix compiling warning.

Parent: bcc0dad7
@@ -155,7 +155,7 @@ class LSTMGradKernel : public framework::OpKernel<T> {
     auto* batch_cell_pre_act = ctx.Input<LoDTensor>("BatchCellPreAct");
     auto* hidden_g = ctx.Input<LoDTensor>(framework::GradVarName("Hidden"));
-    auto* cell_g = ctx.Input<LoDTensor>(framework::GradVarName("Cell"));
+    // auto* cell_g = ctx.Input<LoDTensor>(framework::GradVarName("Cell"));
     auto* in_g = ctx.Output<LoDTensor>(framework::GradVarName("Input"));
     auto* weight_g = ctx.Output<Tensor>(framework::GradVarName("Weight"));
@@ -219,8 +219,8 @@ class LSTMGradKernel : public framework::OpKernel<T> {
     LoDTensor batch_cell_g;
     batch_cell_g.mutable_data<T>(out_dims, ctx.GetPlace());
     batch_cell_g.set_lod(batch_gate->lod());
-    to_batch(device_ctx, *cell_g, batch_cell_g, false);
+    // TODO(qingqing) support the case output cell has gradient.
+    // to_batch(device_ctx, *cell_g, batch_cell_g, false);
+    zero(device_ctx, &batch_cell_g, static_cast<T>(0.0));
     LoDTensor batch_gate_g;
......
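Note on the LSTMGradKernel hunks above: once the `to_batch(device_ctx, *cell_g, ...)` call is commented out, `cell_g` has no remaining use, so leaving its declaration in place would only trade one compiler warning for another (`-Wunused-variable`). Commenting out both lines and zero-filling `batch_cell_g` keeps the build quiet until the Cell-gradient path is implemented, per the TODO. A minimal standalone sketch of that warning pattern (illustrative only, not PaddlePaddle code; the function names are made up):

// sketch.cc -- compile with: g++ -Wall -c sketch.cc
void before_fix() {
  int cell_g = 0;  // warning: unused variable 'cell_g' [-Wunused-variable]
}

void after_fix() {
  // int cell_g = 0;  // commented out together with its consumer; no warning
}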
@@ -58,7 +58,8 @@ class LoDTensor2BatchFunctor {
     if (!is_cal_batch_lod) {
       auto lods = batch.lod();
       PADDLE_ENFORCE_EQ(lods.size(), 2UL);
-      PADDLE_ENFORCE_EQ(lods[1].size(), lod_tensor.dims()[0]);
+      PADDLE_ENFORCE_EQ(lods[1].size(),
+                        static_cast<size_t>(lod_tensor.dims()[0]));
       CopyMatrixRowsFunctor<Place, T> to_batch;
       to_batch(context, lod_tensor, lods[1].data(), batch, true);
       return;
@@ -111,10 +112,10 @@ class LoDTensor2BatchFunctor {
     size_t* batch_starts = batch_lods[0].data();
     size_t* seq2batch_idx = batch_lods[1].data();
     batch_starts[0] = 0;
-    for (size_t n = 0; n < num_batch; n++) {
+    for (int n = 0; n < num_batch; n++) {
       auto batch_id = static_cast<int>(batch_starts[n]);
       for (size_t i = 0; i < seq_info.size(); ++i) {
-        size_t seq_len = seq_info[i].length;
+        int seq_len = seq_info[i].length;
         int start = seq_info[i].start;
         if (n < seq_len) {
           seq2batch_idx[batch_id] =
......
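Both LoDTensor2BatchFunctor hunks above fix the same class of warning: `lods[1].size()` is an unsigned `size_t`, while `lod_tensor.dims()[0]` and the loop bound are signed, and comparing operands of different signedness triggers `-Wsign-compare`. The first hunk resolves it with a `static_cast`; the second settles on one index type for `n` and `seq_len`. A minimal sketch of the pattern (illustrative only, not PaddlePaddle code; `dims_match` is a made-up name):

// sketch.cc -- compile with: g++ -Wall -Wextra -c sketch.cc
#include <cstddef>
#include <vector>

bool dims_match(const std::vector<size_t>& lod, long long dim0) {
  // return lod.size() == dim0;  // warning: comparison of integer
  //                             // expressions of different signedness
  return lod.size() == static_cast<size_t>(dim0);  // both sides unsigned: OK
}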
@@ -52,7 +52,7 @@ def lstm(
         g = np.dot(h_pre, w_h)  # 1 x 4D
         g = g + x
         g = np.reshape(g, (1, g.size))
-        c_tmp, g_i, g_f, g_o = np.split(g, 4, axis=1)
+        c, g_i, g_f, g_o = np.split(g, 4, axis=1)
         if w_c is None:
             g_i = act_gate(g_i)  # 1 x D
             g_f = act_gate(g_f)  # 1 x D
@@ -60,7 +60,7 @@ def lstm(
             w_ic, w_fc, w_oc = np.split(w_c, 3, axis=1)
             g_i = act_gate(g_i + w_ic * c_pre)  # 1 x D
             g_f = act_gate(g_f + w_fc * c_pre)  # 1 x D
-        c = g_f * c_pre + g_i * act_cand(c_tmp)  # 1 x D
+        c = g_f * c_pre + g_i * act_cand(c)  # 1 x D
         if w_c is None:
             g_o = act_gate(g_o)  # 1 x D
@@ -68,8 +68,7 @@ def lstm(
             _, _, w_oc = np.split(w_c, 3, axis=1)
             g_o = act_gate(g_o + w_oc * c)  # 1 x D
         h = g_o * act_cell(c)
-        bg = np.concatenate((act_cand(c_tmp), g_i, g_f, g_o), axis=1)
-        return h, c, bg
+        return h, c

     def _reverse(x, lod):
         y = np.zeros_like(x)
@@ -82,7 +81,6 @@ def lstm(
     batch_size = len(offset) - 1
     hidden = []
     cell = []
-    gate = []
     input = _reverse(input, offset) if is_reverse else input
     if w_b is not None:
         input = input + np.tile(w_b, (offset[-1], 1))
@@ -94,30 +92,26 @@ def lstm(
         c_pre = c0[i]  # 1 x D
         for j in range(seq_len):
             # compute one step
-            h_pre, c_pre, g_pre = _step(x[j], w_h, w_c, h_pre, c_pre, act_gate,
-                                        act_cell, act_cand)
+            h_pre, c_pre = _step(x[j], w_h, w_c, h_pre, c_pre, act_gate,
+                                 act_cell, act_cand)
             hidden.append(h_pre.flatten())
             cell.append(c_pre.flatten())
-            gate.append(g_pre.flatten())

     hidden = np.array(hidden).astype('float64')
     cell = np.array(cell).astype('float64')
-    gate = np.array(gate).astype('float64')
     hidden = _reverse(hidden, offset) if is_reverse else hidden
     cell = _reverse(cell, offset) if is_reverse else cell

-    assert gate.shape == input.shape
     assert hidden.shape == (input.shape[0], input.shape[1] / 4)
     assert cell.shape == (input.shape[0], input.shape[1] / 4)
-    return hidden, cell, gate
+    return hidden, cell


 class TestLstmOp(OpTest):
     def set_argument(self):
-        self.lod = [[0, 2, 6, 9]]
+        self.lod = [[0, 2, 6]]
         self.D = 16
-        self.sort_idx = [2, 6, 0, 3, 7, 1, 4, 8, 5]

         self.act_gate = 'sigmoid'
         self.act_cell = 'tanh'
@@ -141,22 +135,18 @@ class TestLstmOp(OpTest):
         w_b = b[:, 0:4 * self.D]
         w_c = b[:, 4 * self.D:]
-        h, c, g = lstm(x, self.lod, h0, c0, w, w_b, w_c, self.is_reverse,
-                       ACTVATION[self.act_gate], ACTVATION[self.act_cell],
-                       ACTVATION[self.act_cand])
-        g_sort = np.zeros_like(x)
-        for i, j in enumerate(self.sort_idx):
-            g_sort[i, :] = g[j, :]
+        h, c = lstm(x, self.lod, h0, c0, w, w_b, w_c, self.is_reverse,
+                    ACTVATION[self.act_gate], ACTVATION[self.act_cell],
+                    ACTVATION[self.act_cand])

         self.inputs = {'Input': (x, self.lod), 'Weight': w, 'Bias': b}
         if self.has_initial_state:
             self.inputs['H0'] = h0
             self.inputs['C0'] = c0

         self.outputs = {
             'Hidden': (h, self.lod),
             'Cell': (c, self.lod),
-            'BatchGate': g_sort,
         }
         self.attrs = {
             'usePeepholes': True,
@@ -179,9 +169,8 @@ class TestLstmOp(OpTest):
 class TestLstmOpHasNoInitial(TestLstmOp):
     def set_argument(self):
-        self.lod = [[0, 2, 6, 9]]
-        self.D = 64
-        self.sort_idx = [2, 6, 0, 3, 7, 1, 4, 8, 5]
+        self.lod = [[0, 2, 6]]
+        self.D = 16

         self.act_gate = 'sigmoid'
         self.act_cell = 'tanh'
@@ -193,9 +182,8 @@ class TestLstmOpHasNoInitial(TestLstmOp):
 class TestLstmOpRerverse(TestLstmOp):
     def set_argument(self):
-        self.lod = [[0, 2, 6, 9]]
-        self.D = 64
-        self.sort_idx = [2, 6, 0, 3, 7, 1, 4, 8, 5]
+        self.lod = [[0, 2, 6]]
+        self.D = 16

         self.act_gate = 'sigmoid'
         self.act_cell = 'tanh'
......
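For reference, the `_step` helper edited in the test above implements a peephole LSTM cell. Reading the equations directly off the new code (assuming `act_gate` is the sigmoid and `act_cell`/`act_cand` the tanh configured in `set_argument`; $\odot$ denotes elementwise multiplication, the peephole terms $w_{ic}, w_{fc}, w_{oc}$ apply only when `w_c` is given, and $x_t$ already carries the bias `w_b` added before the loop):

$$
\begin{aligned}
[g_c,\; g_i,\; g_f,\; g_o] &= x_t + h_{t-1} W_h \\
i_t &= \sigma(g_i + w_{ic} \odot c_{t-1}) \\
f_t &= \sigma(g_f + w_{fc} \odot c_{t-1}) \\
c_t &= f_t \odot c_{t-1} + i_t \odot \tanh(g_c) \\
o_t &= \sigma(g_o + w_{oc} \odot c_t) \\
h_t &= o_t \odot \tanh(c_t)
\end{aligned}
$$

The candidate chunk $g_c$ is consumed exactly once now that the gate concatenation `bg` is gone, which is why the temporary `c_tmp` could be folded into `c`.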