提交 d60fe75a 编写于 作者: D dangqingqing

follow comments.

上级 4098ce73
......@@ -246,25 +246,17 @@ class LSTMGradOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE(ctx->HasInput("BatchCellPreAct"),
"Input(BatchGate) of LSTM should not be null.");
auto in_g_name = framework::GradVarName("Input");
if (ctx->HasOutput(in_g_name))
ctx->SetOutputDim(in_g_name, ctx->GetInputDim("Input"));
auto w_g_name = framework::GradVarName("Weight");
if (ctx->HasOutput(w_g_name))
ctx->SetOutputDim(w_g_name, ctx->GetInputDim("Weight"));
auto b_g_name = framework::GradVarName("Bias");
if (ctx->HasOutput(b_g_name))
ctx->SetOutputDim(b_g_name, ctx->GetInputDim("Bias"));
auto h0_g_name = framework::GradVarName("H0");
if (ctx->HasOutput(h0_g_name))
ctx->SetOutputDim(h0_g_name, ctx->GetInputDim("H0"));
auto c0_g_name = framework::GradVarName("C0");
if (ctx->HasOutput(c0_g_name))
ctx->SetOutputDim(c0_g_name, ctx->GetInputDim("C0"));
auto SetOutGradDim = [&ctx](const std::string& name) {
auto g_name = framework::GradVarName(name);
if (ctx->HasOutput(g_name))
ctx->SetOutputDim(g_name, ctx->GetInputDim(name));
};
SetOutGradDim("Input");
SetOutGradDim("Weight");
SetOutGradDim("Bias");
SetOutGradDim("H0");
SetOutGradDim("C0");
}
protected:
......
......@@ -28,6 +28,15 @@ template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
template <typename Place, typename T>
inline void ReorderInitState(const platform::DeviceContext& ctx,
const framework::Tensor& src, const size_t* index,
framework::Tensor* dst, bool indexed_src) {
math::CopyMatrixRowsFunctor<Place, T> row_shuffle;
dst->mutable_data<T>(src.dims(), ctx.GetPlace());
row_shuffle(ctx, src, index, *dst, indexed_src);
}
template <typename Place, typename T>
class LSTMKernel : public framework::OpKernel<T> {
public:
......@@ -83,11 +92,13 @@ class LSTMKernel : public framework::OpKernel<T> {
}
lstm_value.prevStateValue = nullptr;
Tensor ordered_c0;
const size_t* order = batch_gate->lod()[2].data();
if (cell_t0) {
math::CopyMatrixRowsFunctor<Place, T> row_shuffle;
ordered_c0.mutable_data<T>(cell_t0->dims(), ctx.GetPlace());
const size_t* order = batch_gate->lod()[2].data();
row_shuffle(device_ctx, *cell_t0, order, ordered_c0, true);
// Since the batch computing for LSTM reorders the input sequence
// according to their length. The initialized cell state also needs
// to reorder.
ReorderInitState<Place, T>(device_ctx, *cell_t0, order, &ordered_c0,
true);
lstm_value.prevStateValue = ordered_c0.data<T>();
}
......@@ -123,11 +134,16 @@ class LSTMKernel : public framework::OpKernel<T> {
static_cast<T>(1.0), &gate_t,
static_cast<T>(1.0));
} else if (hidden_t0) {
math::CopyMatrixRowsFunctor<Place, T> row_shuffle;
// If n == 0 and there is no initialized hidden state, that is to say
// the H0 is zeros, the calculation W_h * H0 will be skiped.
// If n == 0 and there is initialized hidden state, calculate W_h * H0.
// Since the batch computing for LSTM reorders the input sequence
// according to their length. The initialized hidden state also needs
// to reorder.
Tensor ordered_h0;
ordered_h0.mutable_data<T>(hidden_t0->dims(), ctx.GetPlace());
const size_t* order = batch_gate->lod()[2].data();
row_shuffle(device_ctx, *hidden_t0, order, ordered_h0, true);
ReorderInitState<Place, T>(device_ctx, *hidden_t0, order, &ordered_h0,
true);
math::matmul<Place, T>(device_ctx, ordered_h0, false, *weight, false,
static_cast<T>(1.0), &gate_t,
static_cast<T>(1.0));
......@@ -187,12 +203,16 @@ class LSTMGradKernel : public framework::OpKernel<T> {
zero(device_ctx, weight_g, static_cast<T>(0.0));
}
// ordered_h0/c0 is the reordered hidden/cell initialization.
// ordered_h0_g/c0_g is the reordered gradient of hidden/cell
// initialization.
Tensor ordered_h0, ordered_c0, ordered_h0_g, ordered_c0_g;
math::CopyMatrixRowsFunctor<Place, T> row_shuffle;
const size_t* order = batch_gate->lod()[2].data();
if (c0) {
ordered_c0.mutable_data<T>(c0->dims(), ctx.GetPlace());
row_shuffle(device_ctx, *c0, order, ordered_c0, true);
ReorderInitState<Place, T>(device_ctx, *c0, order, &ordered_c0, true);
}
if (c0 && c0_g) {
ordered_c0_g.mutable_data<T>(c0_g->dims(), ctx.GetPlace());
}
auto in_dims = input->dims();
......@@ -231,30 +251,24 @@ class LSTMGradKernel : public framework::OpKernel<T> {
math::LoDTensor2BatchFunctor<Place, T> to_batch;
// use the local variable as here.
LoDTensor batch_hidden;
batch_hidden.mutable_data<T>(out_dims, ctx.GetPlace());
batch_hidden.set_lod(batch_gate->lod());
to_batch(device_ctx, *hidden_out, batch_hidden, false);
LoDTensor batch_hidden_g;
batch_hidden_g.mutable_data<T>(out_dims, ctx.GetPlace());
batch_hidden_g.set_lod(batch_gate->lod());
to_batch(device_ctx, *hidden_g, batch_hidden_g, false);
auto ToBatch = [&batch_gate, &to_batch](
const platform::DeviceContext& ctx, const framework::LoDTensor& src,
const framework::DDim& dims, framework::LoDTensor& dst) {
dst.mutable_data<T>(dims, ctx.GetPlace());
dst.set_lod(batch_gate->lod());
to_batch(ctx, src, dst, false);
};
LoDTensor batch_cell;
batch_cell.mutable_data<T>(out_dims, ctx.GetPlace());
batch_cell.set_lod(batch_gate->lod());
to_batch(device_ctx, *cell_out, batch_cell, false);
LoDTensor batch_hidden, batch_hidden_g, batch_cell;
ToBatch(device_ctx, *hidden_out, out_dims, batch_hidden);
ToBatch(device_ctx, *hidden_g, out_dims, batch_hidden_g);
ToBatch(device_ctx, *cell_out, out_dims, batch_cell);
LoDTensor batch_cell_g;
LoDTensor batch_cell_g, batch_gate_g;
batch_cell_g.mutable_data<T>(out_dims, ctx.GetPlace());
batch_cell_g.set_lod(batch_gate->lod());
// TODO(qingqing) support the case output cell has gradient.
// to_batch(device_ctx, *cell_g, batch_cell_g, false);
zero(device_ctx, &batch_cell_g, static_cast<T>(0.0));
LoDTensor batch_gate_g;
batch_gate_g.mutable_data<T>(batch_gate->dims(), ctx.GetPlace());
batch_gate_g.set_lod(batch_gate->lod());
......@@ -289,17 +303,8 @@ class LSTMGradKernel : public framework::OpKernel<T> {
lstm_value.prevStateValue = cell_pre.data<T>();
lstm_grad.prevStateGrad = cell_pre_g.data<T>();
} else {
if (c0) {
lstm_value.prevStateValue = ordered_c0.data<T>();
} else {
lstm_value.prevStateValue = nullptr;
}
if (c0 && c0_g) {
ordered_c0_g.mutable_data<T>(c0_g->dims(), ctx.GetPlace());
lstm_grad.prevStateGrad = ordered_c0_g.data<T>();
} else {
lstm_grad.prevStateGrad = nullptr;
}
lstm_value.prevStateValue = c0 ? ordered_c0.data<T>() : nullptr;
lstm_grad.prevStateGrad = c0_g ? ordered_c0_g.data<T>() : nullptr;
}
int cur_batch_size = bend - bstart;
......@@ -323,8 +328,7 @@ class LSTMGradKernel : public framework::OpKernel<T> {
}
} else {
if (h0 && weight_g) {
ordered_h0.mutable_data<T>(h0->dims(), ctx.GetPlace());
row_shuffle(device_ctx, *h0, order, ordered_h0, true);
ReorderInitState<Place, T>(device_ctx, *h0, order, &ordered_h0, true);
math::matmul<Place, T>(device_ctx, ordered_h0, true, gate_g, false,
static_cast<T>(1.0), weight_g,
static_cast<T>(1.0));
......@@ -359,12 +363,10 @@ class LSTMGradKernel : public framework::OpKernel<T> {
}
if (h0 && h0_g) {
h0_g->mutable_data<T>(ctx.GetPlace());
row_shuffle(device_ctx, ordered_h0_g, order, *h0_g, false);
ReorderInitState<Place, T>(device_ctx, ordered_h0_g, order, h0_g, false);
}
if (c0 && c0_g) {
c0_g->mutable_data<T>(ctx.GetPlace());
row_shuffle(device_ctx, ordered_c0_g, order, *c0_g, false);
ReorderInitState<Place, T>(device_ctx, ordered_c0_g, order, c0_g, false);
}
}
};
......
......@@ -179,36 +179,6 @@ class TestLstmOp(OpTest):
self.check_grad(
['Input', 'Weight', 'Bias'], ['Hidden'], max_relative_error=5e-4)
def test_check_grad_ingore_bias(self):
N = len(self.lod[0]) - 1
self.outputs['BatchGate'] = np.zeros((N, 4 * self.D)).astype('float64')
self.outputs['BatchCellPreAct'] = np.zeros(
(N, self.D)).astype('float64')
self.check_grad(
['Input', 'Weight'], ['Hidden'],
max_relative_error=5e-4,
no_grad_set=set('Bias'))
def test_check_grad_ingore_weight(self):
N = len(self.lod[0]) - 1
self.outputs['BatchGate'] = np.zeros((N, 4 * self.D)).astype('float64')
self.outputs['BatchCellPreAct'] = np.zeros(
(N, self.D)).astype('float64')
self.check_grad(
['Input', 'Bias'], ['Hidden'],
max_relative_error=5e-4,
no_grad_set=set('Weight'))
def test_check_grad_ingore_input(self):
N = len(self.lod[0]) - 1
self.outputs['BatchGate'] = np.zeros((N, 4 * self.D)).astype('float64')
self.outputs['BatchCellPreAct'] = np.zeros(
(N, self.D)).astype('float64')
self.check_grad(
['Weight', 'Bias'], ['Hidden'],
max_relative_error=5e-4,
no_grad_set=set('Input'))
class TestLstmOpHasInitial(TestLstmOp):
def set_argument(self):
......@@ -233,15 +203,35 @@ class TestLstmOpHasInitial(TestLstmOp):
['Input', 'Weight', 'Bias', 'H0', 'C0'], ['Hidden'],
max_relative_error=5e-4)
# In order to speed up, skip following testing
def test_check_grad_ingore_bias(self):
return
N = len(self.lod[0]) - 1
self.outputs['BatchGate'] = np.zeros((N, 4 * self.D)).astype('float64')
self.outputs['BatchCellPreAct'] = np.zeros(
(N, self.D)).astype('float64')
self.check_grad(
['Input', 'Weight'], ['Hidden'],
max_relative_error=5e-4,
no_grad_set=set('Bias'))
def test_check_grad_ingore_weight(self):
return
N = len(self.lod[0]) - 1
self.outputs['BatchGate'] = np.zeros((N, 4 * self.D)).astype('float64')
self.outputs['BatchCellPreAct'] = np.zeros(
(N, self.D)).astype('float64')
self.check_grad(
['Input', 'Bias'], ['Hidden'],
max_relative_error=5e-4,
no_grad_set=set('Weight'))
def test_check_grad_ingore_input(self):
return
N = len(self.lod[0]) - 1
self.outputs['BatchGate'] = np.zeros((N, 4 * self.D)).astype('float64')
self.outputs['BatchCellPreAct'] = np.zeros(
(N, self.D)).astype('float64')
self.check_grad(
['Weight', 'Bias'], ['Hidden'],
max_relative_error=5e-4,
no_grad_set=set('Input'))
def test_check_grad_ingore_h0(self):
N = len(self.lod[0]) - 1
......@@ -277,16 +267,6 @@ class TestLstmOpRerverse(TestLstmOp):
self.is_reverse = True
self.use_peepholes = True
# In order to speed up, skip following testing
def test_check_grad_ingore_bias(self):
return
def test_check_grad_ingore_weight(self):
return
def test_check_grad_ingore_input(self):
return
class TestLstmOpNotUsePeepholes(TestLstmOp):
def set_argument(self):
......@@ -301,16 +281,6 @@ class TestLstmOpNotUsePeepholes(TestLstmOp):
self.is_reverse = True
self.use_peepholes = False
# In order to speed up, skip following testing
def test_check_grad_ingore_bias(self):
return
def test_check_grad_ingore_weight(self):
return
def test_check_grad_ingore_input(self):
return
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册