提交 d60fe75a 编写于 作者: D dangqingqing

follow comments.

上级 4098ce73
...@@ -246,25 +246,17 @@ class LSTMGradOp : public framework::OperatorWithKernel { ...@@ -246,25 +246,17 @@ class LSTMGradOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE(ctx->HasInput("BatchCellPreAct"), PADDLE_ENFORCE(ctx->HasInput("BatchCellPreAct"),
"Input(BatchGate) of LSTM should not be null."); "Input(BatchGate) of LSTM should not be null.");
auto in_g_name = framework::GradVarName("Input"); auto SetOutGradDim = [&ctx](const std::string& name) {
if (ctx->HasOutput(in_g_name)) auto g_name = framework::GradVarName(name);
ctx->SetOutputDim(in_g_name, ctx->GetInputDim("Input")); if (ctx->HasOutput(g_name))
ctx->SetOutputDim(g_name, ctx->GetInputDim(name));
auto w_g_name = framework::GradVarName("Weight"); };
if (ctx->HasOutput(w_g_name))
ctx->SetOutputDim(w_g_name, ctx->GetInputDim("Weight")); SetOutGradDim("Input");
SetOutGradDim("Weight");
auto b_g_name = framework::GradVarName("Bias"); SetOutGradDim("Bias");
if (ctx->HasOutput(b_g_name)) SetOutGradDim("H0");
ctx->SetOutputDim(b_g_name, ctx->GetInputDim("Bias")); SetOutGradDim("C0");
auto h0_g_name = framework::GradVarName("H0");
if (ctx->HasOutput(h0_g_name))
ctx->SetOutputDim(h0_g_name, ctx->GetInputDim("H0"));
auto c0_g_name = framework::GradVarName("C0");
if (ctx->HasOutput(c0_g_name))
ctx->SetOutputDim(c0_g_name, ctx->GetInputDim("C0"));
} }
protected: protected:
......
...@@ -28,6 +28,15 @@ template <typename T, int MajorType = Eigen::RowMajor, ...@@ -28,6 +28,15 @@ template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex> typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>; using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
// Allocates `dst` with the dims of `src` on the place reported by `ctx`, then
// copies rows from `src` into `dst` through math::CopyMatrixRowsFunctor,
// driven by the row indices in `index`.
//
// `indexed_src` is forwarded untouched to the functor. Callers in this file
// pass `true` when reordering H0/C0 before the batch computation and `false`
// when writing the reordered H0/C0 gradients back -- presumably it selects
// whether `index` addresses rows of `src` or of `dst`; confirm against
// CopyMatrixRowsFunctor.
template <typename Place, typename T>
inline void ReorderInitState(const platform::DeviceContext& ctx,
                             const framework::Tensor& src, const size_t* index,
                             framework::Tensor* dst, bool indexed_src) {
  // The destination must own storage of the right shape before rows are
  // copied into it.
  dst->mutable_data<T>(src.dims(), ctx.GetPlace());

  math::CopyMatrixRowsFunctor<Place, T> copy_rows;
  copy_rows(ctx, src, index, *dst, indexed_src);
}
template <typename Place, typename T> template <typename Place, typename T>
class LSTMKernel : public framework::OpKernel<T> { class LSTMKernel : public framework::OpKernel<T> {
public: public:
...@@ -83,11 +92,13 @@ class LSTMKernel : public framework::OpKernel<T> { ...@@ -83,11 +92,13 @@ class LSTMKernel : public framework::OpKernel<T> {
} }
lstm_value.prevStateValue = nullptr; lstm_value.prevStateValue = nullptr;
Tensor ordered_c0; Tensor ordered_c0;
const size_t* order = batch_gate->lod()[2].data();
if (cell_t0) { if (cell_t0) {
math::CopyMatrixRowsFunctor<Place, T> row_shuffle; // Since the batch computing for LSTM reorders the input sequence
ordered_c0.mutable_data<T>(cell_t0->dims(), ctx.GetPlace()); // according to their length. The initialized cell state also needs
const size_t* order = batch_gate->lod()[2].data(); // to reorder.
row_shuffle(device_ctx, *cell_t0, order, ordered_c0, true); ReorderInitState<Place, T>(device_ctx, *cell_t0, order, &ordered_c0,
true);
lstm_value.prevStateValue = ordered_c0.data<T>(); lstm_value.prevStateValue = ordered_c0.data<T>();
} }
...@@ -123,11 +134,16 @@ class LSTMKernel : public framework::OpKernel<T> { ...@@ -123,11 +134,16 @@ class LSTMKernel : public framework::OpKernel<T> {
static_cast<T>(1.0), &gate_t, static_cast<T>(1.0), &gate_t,
static_cast<T>(1.0)); static_cast<T>(1.0));
} else if (hidden_t0) { } else if (hidden_t0) {
math::CopyMatrixRowsFunctor<Place, T> row_shuffle; // If n == 0 and there is no initialized hidden state, that is to say
// the H0 is zeros, the calculation W_h * H0 will be skipped.
// If n == 0 and there is initialized hidden state, calculate W_h * H0.
// Because the batch computing for LSTM reorders the input sequences
// according to their length, the initialized hidden state also needs
// to be reordered.
Tensor ordered_h0; Tensor ordered_h0;
ordered_h0.mutable_data<T>(hidden_t0->dims(), ctx.GetPlace()); ReorderInitState<Place, T>(device_ctx, *hidden_t0, order, &ordered_h0,
const size_t* order = batch_gate->lod()[2].data(); true);
row_shuffle(device_ctx, *hidden_t0, order, ordered_h0, true);
math::matmul<Place, T>(device_ctx, ordered_h0, false, *weight, false, math::matmul<Place, T>(device_ctx, ordered_h0, false, *weight, false,
static_cast<T>(1.0), &gate_t, static_cast<T>(1.0), &gate_t,
static_cast<T>(1.0)); static_cast<T>(1.0));
...@@ -187,12 +203,16 @@ class LSTMGradKernel : public framework::OpKernel<T> { ...@@ -187,12 +203,16 @@ class LSTMGradKernel : public framework::OpKernel<T> {
zero(device_ctx, weight_g, static_cast<T>(0.0)); zero(device_ctx, weight_g, static_cast<T>(0.0));
} }
// ordered_h0/c0 is the reordered hidden/cell initialization.
// ordered_h0_g/c0_g is the reordered gradient of hidden/cell
// initialization.
Tensor ordered_h0, ordered_c0, ordered_h0_g, ordered_c0_g; Tensor ordered_h0, ordered_c0, ordered_h0_g, ordered_c0_g;
math::CopyMatrixRowsFunctor<Place, T> row_shuffle;
const size_t* order = batch_gate->lod()[2].data(); const size_t* order = batch_gate->lod()[2].data();
if (c0) { if (c0) {
ordered_c0.mutable_data<T>(c0->dims(), ctx.GetPlace()); ReorderInitState<Place, T>(device_ctx, *c0, order, &ordered_c0, true);
row_shuffle(device_ctx, *c0, order, ordered_c0, true); }
if (c0 && c0_g) {
ordered_c0_g.mutable_data<T>(c0_g->dims(), ctx.GetPlace());
} }
auto in_dims = input->dims(); auto in_dims = input->dims();
...@@ -231,30 +251,24 @@ class LSTMGradKernel : public framework::OpKernel<T> { ...@@ -231,30 +251,24 @@ class LSTMGradKernel : public framework::OpKernel<T> {
math::LoDTensor2BatchFunctor<Place, T> to_batch; math::LoDTensor2BatchFunctor<Place, T> to_batch;
// use the local variable as here. auto ToBatch = [&batch_gate, &to_batch](
LoDTensor batch_hidden; const platform::DeviceContext& ctx, const framework::LoDTensor& src,
batch_hidden.mutable_data<T>(out_dims, ctx.GetPlace()); const framework::DDim& dims, framework::LoDTensor& dst) {
batch_hidden.set_lod(batch_gate->lod()); dst.mutable_data<T>(dims, ctx.GetPlace());
to_batch(device_ctx, *hidden_out, batch_hidden, false); dst.set_lod(batch_gate->lod());
to_batch(ctx, src, dst, false);
LoDTensor batch_hidden_g; };
batch_hidden_g.mutable_data<T>(out_dims, ctx.GetPlace());
batch_hidden_g.set_lod(batch_gate->lod());
to_batch(device_ctx, *hidden_g, batch_hidden_g, false);
LoDTensor batch_cell; LoDTensor batch_hidden, batch_hidden_g, batch_cell;
batch_cell.mutable_data<T>(out_dims, ctx.GetPlace()); ToBatch(device_ctx, *hidden_out, out_dims, batch_hidden);
batch_cell.set_lod(batch_gate->lod()); ToBatch(device_ctx, *hidden_g, out_dims, batch_hidden_g);
to_batch(device_ctx, *cell_out, batch_cell, false); ToBatch(device_ctx, *cell_out, out_dims, batch_cell);
LoDTensor batch_cell_g; LoDTensor batch_cell_g, batch_gate_g;
batch_cell_g.mutable_data<T>(out_dims, ctx.GetPlace()); batch_cell_g.mutable_data<T>(out_dims, ctx.GetPlace());
batch_cell_g.set_lod(batch_gate->lod());
// TODO(qingqing) support the case output cell has gradient. // TODO(qingqing) support the case output cell has gradient.
// to_batch(device_ctx, *cell_g, batch_cell_g, false); // to_batch(device_ctx, *cell_g, batch_cell_g, false);
zero(device_ctx, &batch_cell_g, static_cast<T>(0.0)); zero(device_ctx, &batch_cell_g, static_cast<T>(0.0));
LoDTensor batch_gate_g;
batch_gate_g.mutable_data<T>(batch_gate->dims(), ctx.GetPlace()); batch_gate_g.mutable_data<T>(batch_gate->dims(), ctx.GetPlace());
batch_gate_g.set_lod(batch_gate->lod()); batch_gate_g.set_lod(batch_gate->lod());
...@@ -289,17 +303,8 @@ class LSTMGradKernel : public framework::OpKernel<T> { ...@@ -289,17 +303,8 @@ class LSTMGradKernel : public framework::OpKernel<T> {
lstm_value.prevStateValue = cell_pre.data<T>(); lstm_value.prevStateValue = cell_pre.data<T>();
lstm_grad.prevStateGrad = cell_pre_g.data<T>(); lstm_grad.prevStateGrad = cell_pre_g.data<T>();
} else { } else {
if (c0) { lstm_value.prevStateValue = c0 ? ordered_c0.data<T>() : nullptr;
lstm_value.prevStateValue = ordered_c0.data<T>(); lstm_grad.prevStateGrad = c0_g ? ordered_c0_g.data<T>() : nullptr;
} else {
lstm_value.prevStateValue = nullptr;
}
if (c0 && c0_g) {
ordered_c0_g.mutable_data<T>(c0_g->dims(), ctx.GetPlace());
lstm_grad.prevStateGrad = ordered_c0_g.data<T>();
} else {
lstm_grad.prevStateGrad = nullptr;
}
} }
int cur_batch_size = bend - bstart; int cur_batch_size = bend - bstart;
...@@ -323,8 +328,7 @@ class LSTMGradKernel : public framework::OpKernel<T> { ...@@ -323,8 +328,7 @@ class LSTMGradKernel : public framework::OpKernel<T> {
} }
} else { } else {
if (h0 && weight_g) { if (h0 && weight_g) {
ordered_h0.mutable_data<T>(h0->dims(), ctx.GetPlace()); ReorderInitState<Place, T>(device_ctx, *h0, order, &ordered_h0, true);
row_shuffle(device_ctx, *h0, order, ordered_h0, true);
math::matmul<Place, T>(device_ctx, ordered_h0, true, gate_g, false, math::matmul<Place, T>(device_ctx, ordered_h0, true, gate_g, false,
static_cast<T>(1.0), weight_g, static_cast<T>(1.0), weight_g,
static_cast<T>(1.0)); static_cast<T>(1.0));
...@@ -359,12 +363,10 @@ class LSTMGradKernel : public framework::OpKernel<T> { ...@@ -359,12 +363,10 @@ class LSTMGradKernel : public framework::OpKernel<T> {
} }
if (h0 && h0_g) { if (h0 && h0_g) {
h0_g->mutable_data<T>(ctx.GetPlace()); ReorderInitState<Place, T>(device_ctx, ordered_h0_g, order, h0_g, false);
row_shuffle(device_ctx, ordered_h0_g, order, *h0_g, false);
} }
if (c0 && c0_g) { if (c0 && c0_g) {
c0_g->mutable_data<T>(ctx.GetPlace()); ReorderInitState<Place, T>(device_ctx, ordered_c0_g, order, c0_g, false);
row_shuffle(device_ctx, ordered_c0_g, order, *c0_g, false);
} }
} }
}; };
......
...@@ -179,36 +179,6 @@ class TestLstmOp(OpTest): ...@@ -179,36 +179,6 @@ class TestLstmOp(OpTest):
self.check_grad( self.check_grad(
['Input', 'Weight', 'Bias'], ['Hidden'], max_relative_error=5e-4) ['Input', 'Weight', 'Bias'], ['Hidden'], max_relative_error=5e-4)
def test_check_grad_ingore_bias(self):
N = len(self.lod[0]) - 1
self.outputs['BatchGate'] = np.zeros((N, 4 * self.D)).astype('float64')
self.outputs['BatchCellPreAct'] = np.zeros(
(N, self.D)).astype('float64')
self.check_grad(
['Input', 'Weight'], ['Hidden'],
max_relative_error=5e-4,
no_grad_set=set('Bias'))
def test_check_grad_ingore_weight(self):
N = len(self.lod[0]) - 1
self.outputs['BatchGate'] = np.zeros((N, 4 * self.D)).astype('float64')
self.outputs['BatchCellPreAct'] = np.zeros(
(N, self.D)).astype('float64')
self.check_grad(
['Input', 'Bias'], ['Hidden'],
max_relative_error=5e-4,
no_grad_set=set('Weight'))
def test_check_grad_ingore_input(self):
N = len(self.lod[0]) - 1
self.outputs['BatchGate'] = np.zeros((N, 4 * self.D)).astype('float64')
self.outputs['BatchCellPreAct'] = np.zeros(
(N, self.D)).astype('float64')
self.check_grad(
['Weight', 'Bias'], ['Hidden'],
max_relative_error=5e-4,
no_grad_set=set('Input'))
class TestLstmOpHasInitial(TestLstmOp): class TestLstmOpHasInitial(TestLstmOp):
def set_argument(self): def set_argument(self):
...@@ -233,15 +203,35 @@ class TestLstmOpHasInitial(TestLstmOp): ...@@ -233,15 +203,35 @@ class TestLstmOpHasInitial(TestLstmOp):
['Input', 'Weight', 'Bias', 'H0', 'C0'], ['Hidden'], ['Input', 'Weight', 'Bias', 'H0', 'C0'], ['Hidden'],
max_relative_error=5e-4) max_relative_error=5e-4)
# In order to speed up, skip following testing
def test_check_grad_ingore_bias(self): def test_check_grad_ingore_bias(self):
return N = len(self.lod[0]) - 1
self.outputs['BatchGate'] = np.zeros((N, 4 * self.D)).astype('float64')
self.outputs['BatchCellPreAct'] = np.zeros(
(N, self.D)).astype('float64')
self.check_grad(
['Input', 'Weight'], ['Hidden'],
max_relative_error=5e-4,
no_grad_set=set('Bias'))
def test_check_grad_ingore_weight(self): def test_check_grad_ingore_weight(self):
return N = len(self.lod[0]) - 1
self.outputs['BatchGate'] = np.zeros((N, 4 * self.D)).astype('float64')
self.outputs['BatchCellPreAct'] = np.zeros(
(N, self.D)).astype('float64')
self.check_grad(
['Input', 'Bias'], ['Hidden'],
max_relative_error=5e-4,
no_grad_set=set('Weight'))
def test_check_grad_ingore_input(self): def test_check_grad_ingore_input(self):
return N = len(self.lod[0]) - 1
self.outputs['BatchGate'] = np.zeros((N, 4 * self.D)).astype('float64')
self.outputs['BatchCellPreAct'] = np.zeros(
(N, self.D)).astype('float64')
self.check_grad(
['Weight', 'Bias'], ['Hidden'],
max_relative_error=5e-4,
no_grad_set=set('Input'))
def test_check_grad_ingore_h0(self): def test_check_grad_ingore_h0(self):
N = len(self.lod[0]) - 1 N = len(self.lod[0]) - 1
...@@ -277,16 +267,6 @@ class TestLstmOpRerverse(TestLstmOp): ...@@ -277,16 +267,6 @@ class TestLstmOpRerverse(TestLstmOp):
self.is_reverse = True self.is_reverse = True
self.use_peepholes = True self.use_peepholes = True
# In order to speed up, skip following testing
def test_check_grad_ingore_bias(self):
return
def test_check_grad_ingore_weight(self):
return
def test_check_grad_ingore_input(self):
return
class TestLstmOpNotUsePeepholes(TestLstmOp): class TestLstmOpNotUsePeepholes(TestLstmOp):
def set_argument(self): def set_argument(self):
...@@ -301,16 +281,6 @@ class TestLstmOpNotUsePeepholes(TestLstmOp): ...@@ -301,16 +281,6 @@ class TestLstmOpNotUsePeepholes(TestLstmOp):
self.is_reverse = True self.is_reverse = True
self.use_peepholes = False self.use_peepholes = False
# In order to speed up, skip following testing
def test_check_grad_ingore_bias(self):
return
def test_check_grad_ingore_weight(self):
return
def test_check_grad_ingore_input(self):
return
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册