diff --git a/paddle/fluid/operators/gru_op.cc b/paddle/fluid/operators/gru_op.cc index 5c746878823b3dcde2573feec00d3d9dac5ceab8..4847eb362699053be59bfab9726a10b037c39c87 100644 --- a/paddle/fluid/operators/gru_op.cc +++ b/paddle/fluid/operators/gru_op.cc @@ -211,6 +211,139 @@ class GRUGradOp : public framework::OperatorWithKernel { } }; +template +class GRUCPUKernel : public framework::OpKernel { + public: + void BatchCompute(const framework::ExecutionContext& context) const { + using DeviceContext = paddle::platform::CPUDeviceContext; + auto* input = context.Input("Input"); + auto* h0 = context.Input("H0"); + auto* weight = context.Input("Weight"); + const T* weight_data = weight->data(); + auto* bias = context.Input("Bias"); + auto* batch_gate = context.Output("BatchGate"); + batch_gate->mutable_data(context.GetPlace()); + auto* batch_reset_hidden_prev = + context.Output("BatchResetHiddenPrev"); + batch_reset_hidden_prev->mutable_data(context.GetPlace()); + auto* batch_hidden = context.Output("BatchHidden"); + batch_hidden->mutable_data(context.GetPlace()); + auto* hidden = context.Output("Hidden"); + hidden->mutable_data(context.GetPlace()); + + auto hidden_dims = hidden->dims(); + + bool is_reverse = context.Attr("is_reverse"); + math::LoDTensor2BatchFunctor to_batch; + auto& dev_ctx = context.template device_context(); + to_batch(dev_ctx, *input, batch_gate, true, is_reverse); + + if (bias) { + math::RowwiseAdd add_bias; + add_bias(dev_ctx, *batch_gate, *bias, batch_gate); + } + + int frame_size = hidden_dims[1]; + math::GRUMetaValue gru_value; + gru_value.gate_weight = const_cast(weight_data); + gru_value.state_weight = + const_cast(weight_data + 2 * frame_size * frame_size); + Tensor ordered_h0; + + framework::Vector order(batch_gate->lod()[2]); + + if (h0) { + // Since the batch computing for GRU reorders the input sequences + // according to their length. The initialized cell state also needs + // to reorder. + ReorderInitState( + context.template device_context(), *h0, order, + &ordered_h0, true); + gru_value.prev_out_value = ordered_h0.data(); + } else { + gru_value.prev_out_value = nullptr; + } + auto batch_starts = batch_gate->lod()[0]; + size_t num_batch = batch_starts.size() - 1; + auto active_node = math::detail::GetActivationType( + context.Attr("activation")); + auto active_gate = math::detail::GetActivationType( + context.Attr("gate_activation")); + +#ifdef PADDLE_WITH_MKLML + auto blas = math::GetBlas(dev_ctx); + // TODO(TJ): make a class + T* packed_gate = blas.GEMM_ALLOC(CblasBMatrix, 1 /*height of C*/, + frame_size * 2 /*width of weight*/, + frame_size /*height of height*/); + PADDLE_ENFORCE(packed_gate); + blas.GEMM_PACK(CblasBMatrix, CblasNoTrans, 1 /*cur bs?*/, frame_size * 2, + frame_size, T(1.0), gru_value.gate_weight, frame_size * 2, + packed_gate); + T* packed_state = blas.GEMM_ALLOC(CblasBMatrix, 1 /*height of C*/, + frame_size /*width of weight*/, + frame_size /*height of height*/); + PADDLE_ENFORCE(packed_state); + blas.GEMM_PACK(CblasBMatrix, CblasNoTrans, 1 /*cur bs?*/, frame_size, + frame_size, T(1.0), gru_value.state_weight, frame_size, + packed_state); +#endif + for (size_t n = 0; n < num_batch; n++) { + int bstart = static_cast(batch_starts[n]); + int bend = static_cast(batch_starts[n + 1]); + int cur_batch_size = bend - bstart; + + Tensor gate_t = batch_gate->Slice(bstart, bend); + Tensor reset_hidden_prev_t = batch_reset_hidden_prev->Slice(bstart, bend); + Tensor hidden_t = batch_hidden->Slice(bstart, bend); + gru_value.output_value = hidden_t.data(); + gru_value.gate_value = gate_t.data(); + gru_value.reset_output_value = reset_hidden_prev_t.data(); + +#ifdef PADDLE_WITH_MKLML + if (gru_value.prev_out_value) { + blas.GEMM_COMPUTE(CblasNoTrans, CblasPacked, cur_batch_size, + frame_size * 2, frame_size, gru_value.prev_out_value, + frame_size, packed_gate, frame_size * 2, T(1), + gru_value.gate_value, frame_size * 3); + } + + math::detail::forward_reset_output( + math::detail::forward::gru_resetOutput(), gru_value, frame_size, + cur_batch_size, active_gate); + + if (gru_value.prev_out_value) { + blas.GEMM_COMPUTE( + CblasNoTrans, CblasPacked, cur_batch_size, frame_size, frame_size, + gru_value.reset_output_value, frame_size, packed_state, frame_size, + T(1), gru_value.gate_value + frame_size * 2, frame_size * 3); + } + + math::detail::forward_final_output( + math::detail::forward::gru_finalOutput(), gru_value, frame_size, + cur_batch_size, active_node); +#else + math::GRUUnitFunctor::compute( + dev_ctx, gru_value, frame_size, cur_batch_size, active_node, + active_gate); +#endif + gru_value.prev_out_value = gru_value.output_value; + } +#ifdef PADDLE_WITH_MKLML + blas.GEMM_FREE(packed_gate); + blas.GEMM_FREE(packed_state); +#endif + + math::Batch2LoDTensorFunctor to_seq; + batch_hidden->set_lod(batch_gate->lod()); + to_seq(dev_ctx, *batch_hidden, hidden); + } + + void Compute(const framework::ExecutionContext& context) const override { + BatchCompute(context); + } +}; + } // namespace operators } // namespace paddle @@ -218,9 +351,8 @@ namespace ops = paddle::operators; REGISTER_OPERATOR(gru, ops::GRUOp, ops::GRUOpMaker, paddle::framework::DefaultGradOpDescMaker); REGISTER_OPERATOR(gru_grad, ops::GRUGradOp); -REGISTER_OP_CPU_KERNEL( - gru, ops::GRUKernel, - ops::GRUKernel); +REGISTER_OP_CPU_KERNEL(gru, ops::GRUCPUKernel, + ops::GRUCPUKernel); REGISTER_OP_CPU_KERNEL( gru_grad, ops::GRUGradKernel, ops::GRUGradKernel); diff --git a/paddle/fluid/operators/gru_op.cu.cc b/paddle/fluid/operators/gru_op.cu.cc index baf455a840314d1ab94eb8e0a2e5c660ba4202da..55721c283dd18c2f9642563a9ce1eabfce16fd7b 100644 --- a/paddle/fluid/operators/gru_op.cu.cc +++ b/paddle/fluid/operators/gru_op.cu.cc @@ -14,6 +14,96 @@ limitations under the License. */ #include "paddle/fluid/operators/gru_op.h" +namespace paddle { +namespace operators { + +template +class GRUKernel : public framework::OpKernel { + public: + void BatchCompute(const framework::ExecutionContext& context) const { + auto* input = context.Input("Input"); + auto* h0 = context.Input("H0"); + auto* weight = context.Input("Weight"); + const T* weight_data = weight->data(); + auto* bias = context.Input("Bias"); + auto* batch_gate = context.Output("BatchGate"); + batch_gate->mutable_data(context.GetPlace()); + auto* batch_reset_hidden_prev = + context.Output("BatchResetHiddenPrev"); + batch_reset_hidden_prev->mutable_data(context.GetPlace()); + auto* batch_hidden = context.Output("BatchHidden"); + batch_hidden->mutable_data(context.GetPlace()); + auto* hidden = context.Output("Hidden"); + hidden->mutable_data(context.GetPlace()); + + auto hidden_dims = hidden->dims(); + + bool is_reverse = context.Attr("is_reverse"); + math::LoDTensor2BatchFunctor to_batch; + auto& dev_ctx = context.template device_context(); + to_batch(dev_ctx, *input, batch_gate, true, is_reverse); + + if (bias) { + math::RowwiseAdd add_bias; + add_bias(dev_ctx, *batch_gate, *bias, batch_gate); + } + + int frame_size = hidden_dims[1]; + math::GRUMetaValue gru_value; + gru_value.gate_weight = const_cast(weight_data); + gru_value.state_weight = + const_cast(weight_data + 2 * frame_size * frame_size); + Tensor ordered_h0; + + framework::Vector order(batch_gate->lod()[2]); + + if (h0) { + // Since the batch computing for GRU reorders the input sequences + // according to their length. The initialized cell state also needs + // to reorder. + ReorderInitState( + context.template device_context(), *h0, order, + &ordered_h0, true); + gru_value.prev_out_value = ordered_h0.data(); + } else { + gru_value.prev_out_value = nullptr; + } + auto batch_starts = batch_gate->lod()[0]; + size_t num_batch = batch_starts.size() - 1; + auto active_node = math::detail::GetActivationType( + context.Attr("activation")); + auto active_gate = math::detail::GetActivationType( + context.Attr("gate_activation")); + for (size_t n = 0; n < num_batch; n++) { + int bstart = static_cast(batch_starts[n]); + int bend = static_cast(batch_starts[n + 1]); + int cur_batch_size = bend - bstart; + + Tensor gate_t = batch_gate->Slice(bstart, bend); + Tensor reset_hidden_prev_t = batch_reset_hidden_prev->Slice(bstart, bend); + Tensor hidden_t = batch_hidden->Slice(bstart, bend); + gru_value.output_value = hidden_t.data(); + gru_value.gate_value = gate_t.data(); + gru_value.reset_output_value = reset_hidden_prev_t.data(); + math::GRUUnitFunctor::compute( + dev_ctx, gru_value, frame_size, cur_batch_size, active_node, + active_gate); + gru_value.prev_out_value = gru_value.output_value; + } + + math::Batch2LoDTensorFunctor to_seq; + batch_hidden->set_lod(batch_gate->lod()); + to_seq(dev_ctx, *batch_hidden, hidden); + } + + void Compute(const framework::ExecutionContext& context) const override { + BatchCompute(context); + } +}; + +} // namespace operators +} // namespace paddle + namespace ops = paddle::operators; REGISTER_OP_CUDA_KERNEL( gru, ops::GRUKernel, diff --git a/paddle/fluid/operators/gru_op.h b/paddle/fluid/operators/gru_op.h index a9450337e74fe89e16524ff03346bebe43bb2bd3..0bf4e6bc447ad02cc621dd6c45afafcf4aa992e5 100644 --- a/paddle/fluid/operators/gru_op.h +++ b/paddle/fluid/operators/gru_op.h @@ -40,129 +40,6 @@ inline void ReorderInitState(const DeviceContext& ctx, row_shuffle(ctx, src, index_lod, dst, indexed_src); } -template -class GRUKernel : public framework::OpKernel { - public: - void BatchCompute(const framework::ExecutionContext& context) const { - auto* input = context.Input("Input"); - auto* h0 = context.Input("H0"); - auto* weight = context.Input("Weight"); - const T* weight_data = weight->data(); - auto* bias = context.Input("Bias"); - auto* batch_gate = context.Output("BatchGate"); - batch_gate->mutable_data(context.GetPlace()); - auto* batch_reset_hidden_prev = - context.Output("BatchResetHiddenPrev"); - batch_reset_hidden_prev->mutable_data(context.GetPlace()); - auto* batch_hidden = context.Output("BatchHidden"); - batch_hidden->mutable_data(context.GetPlace()); - auto* hidden = context.Output("Hidden"); - hidden->mutable_data(context.GetPlace()); - - auto hidden_dims = hidden->dims(); - - bool is_reverse = context.Attr("is_reverse"); - math::LoDTensor2BatchFunctor to_batch; - auto& dev_ctx = context.template device_context(); - to_batch(dev_ctx, *input, batch_gate, true, is_reverse); - - if (bias) { - math::RowwiseAdd add_bias; - add_bias(dev_ctx, *batch_gate, *bias, batch_gate); - } - - int frame_size = hidden_dims[1]; - math::GRUMetaValue gru_value; - gru_value.gate_weight = const_cast(weight_data); - gru_value.state_weight = - const_cast(weight_data + 2 * frame_size * frame_size); - Tensor ordered_h0; - - framework::Vector order(batch_gate->lod()[2]); - - if (h0) { - // Since the batch computing for GRU reorders the input sequences - // according to their length. The initialized cell state also needs - // to reorder. - ReorderInitState( - context.template device_context(), *h0, order, - &ordered_h0, true); - gru_value.prev_out_value = ordered_h0.data(); - } else { - gru_value.prev_out_value = nullptr; - } - auto batch_starts = batch_gate->lod()[0]; - size_t num_batch = batch_starts.size() - 1; - auto active_node = math::detail::GetActivationType( - context.Attr("activation")); - auto active_gate = math::detail::GetActivationType( - context.Attr("gate_activation")); - auto blas = math::GetBlas(dev_ctx); - - // TODO(TJ): make a class, make one pack - T* packed_gate = blas.GEMM_ALLOC(CblasBMatrix, 1 /*height of C*/, - frame_size * 2 /*width of weight*/, - frame_size /*height of height*/); - PADDLE_ENFORCE(packed_gate); - blas.GEMM_PACK(CblasBMatrix, CblasNoTrans, 1 /*cur bs?*/, frame_size * 2, - frame_size, T(1.0), gru_value.gate_weight, frame_size * 2, - packed_gate); - T* packed_state = blas.GEMM_ALLOC(CblasBMatrix, 1 /*height of C*/, - frame_size /*width of weight*/, - frame_size /*height of height*/); - PADDLE_ENFORCE(packed_state); - blas.GEMM_PACK(CblasBMatrix, CblasNoTrans, 1 /*cur bs?*/, frame_size, - frame_size, T(1.0), gru_value.state_weight, frame_size, - packed_state); - - for (size_t n = 0; n < num_batch; n++) { - int bstart = static_cast(batch_starts[n]); - int bend = static_cast(batch_starts[n + 1]); - int cur_batch_size = bend - bstart; - - Tensor gate_t = batch_gate->Slice(bstart, bend); - Tensor reset_hidden_prev_t = batch_reset_hidden_prev->Slice(bstart, bend); - Tensor hidden_t = batch_hidden->Slice(bstart, bend); - gru_value.output_value = hidden_t.data(); - gru_value.gate_value = gate_t.data(); - gru_value.reset_output_value = reset_hidden_prev_t.data(); - if (gru_value.prev_out_value) { - blas.GEMM_COMPUTE(CblasNoTrans, CblasPacked, cur_batch_size, - frame_size * 2, frame_size, gru_value.prev_out_value, - frame_size, packed_gate, frame_size * 2, T(1), - gru_value.gate_value, frame_size * 3); - } - - math::detail::forward_reset_output( - math::detail::forward::gru_resetOutput(), gru_value, frame_size, - cur_batch_size, active_gate); - - if (gru_value.prev_out_value) { - blas.GEMM_COMPUTE( - CblasNoTrans, CblasPacked, cur_batch_size, frame_size, frame_size, - gru_value.reset_output_value, frame_size, packed_state, frame_size, - T(1), gru_value.gate_value + frame_size * 2, frame_size * 3); - } - - math::detail::forward_final_output( - math::detail::forward::gru_finalOutput(), gru_value, frame_size, - cur_batch_size, active_node); - - gru_value.prev_out_value = gru_value.output_value; - } - blas.GEMM_FREE(packed_gate); - blas.GEMM_FREE(packed_state); - - math::Batch2LoDTensorFunctor to_seq; - batch_hidden->set_lod(batch_gate->lod()); - to_seq(dev_ctx, *batch_hidden, hidden); - } - - void Compute(const framework::ExecutionContext& context) const override { - BatchCompute(context); - } -}; - template class GRUGradKernel : public framework::OpKernel { public: