Unverified commit ff92b6ba, authored by tensor-tang, committed by GitHub

Merge pull request #12531 from tensor-tang/refine/op/gru

Refine gru cpu forward
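This PR gives the CPU GRU forward its own GRUCPUKernel: when built with MKLML and run with FLAGS_paddle_num_threads >= 4, the gate and state weight matrices are packed once with MKL's packed-GEMM routines and each time step only issues the cheaper compute call; otherwise the kernel falls back to the existing GRUUnitFunctor path. The snippet below is a minimal standalone sketch (not code from this PR) of that MKL call sequence; it assumes an MKL/MKLML build, and packed_gemm_demo and its arguments are illustrative names only.

// Sketch: pack a weight matrix B once, reuse it for repeated small GEMMs.
#include <mkl.h>

void packed_gemm_demo(int m, int n, int k, const float* a, const float* b, float* c) {
  // Allocate packed storage for B and pack it outside the hot loop.
  float* packed_b = cblas_sgemm_alloc(CblasBMatrix, m, n, k);
  cblas_sgemm_pack(CblasRowMajor, CblasBMatrix, CblasNoTrans, m, n, k,
                   1.0f /*alpha*/, b, n, packed_b);
  // Per step: only A (the activations) changes; C is accumulated with beta = 1.
  cblas_sgemm_compute(CblasRowMajor, CblasNoTrans, CblasPacked, m, n, k,
                      a, k, packed_b, n, 1.0f /*beta*/, c, n);
  cblas_sgemm_free(packed_b);
}

Packing pays off here because the per-step GEMMs are small and all of them reuse the same two weight matrices, so the one-time packing cost is amortized over the whole sequence.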
...@@ -14,6 +14,11 @@ limitations under the License. */
#include "paddle/fluid/operators/gru_op.h"
#include <string>
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/detail/gru_cpu_kernel.h"
#include "paddle/fluid/operators/math/detail/gru_kernel.h"
DECLARE_int32(paddle_num_threads);
namespace paddle {
namespace operators {
...@@ -211,6 +216,158 @@ class GRUGradOp : public framework::OperatorWithKernel {
}
};
template <typename T>
class GRUCPUKernel : public framework::OpKernel<T> {
public:
void BatchCompute(const framework::ExecutionContext& context) const {
using DeviceContext = paddle::platform::CPUDeviceContext;
auto* input = context.Input<LoDTensor>("Input");
auto* h0 = context.Input<Tensor>("H0");
auto* weight = context.Input<Tensor>("Weight");
const T* weight_data = weight->data<T>();
auto* bias = context.Input<Tensor>("Bias");
auto* batch_gate = context.Output<LoDTensor>("BatchGate");
batch_gate->mutable_data<T>(context.GetPlace());
auto* batch_reset_hidden_prev =
context.Output<LoDTensor>("BatchResetHiddenPrev");
batch_reset_hidden_prev->mutable_data<T>(context.GetPlace());
auto* batch_hidden = context.Output<LoDTensor>("BatchHidden");
batch_hidden->mutable_data<T>(context.GetPlace());
auto* hidden = context.Output<LoDTensor>("Hidden");
hidden->mutable_data<T>(context.GetPlace());
auto hidden_dims = hidden->dims();
bool is_reverse = context.Attr<bool>("is_reverse");
math::LoDTensor2BatchFunctor<DeviceContext, T> to_batch;
auto& dev_ctx = context.template device_context<DeviceContext>();
to_batch(dev_ctx, *input, batch_gate, true, is_reverse);
if (bias) {
math::RowwiseAdd<DeviceContext, T> add_bias;
add_bias(dev_ctx, *batch_gate, *bias, batch_gate);
}
int frame_size = hidden_dims[1];
math::GRUMetaValue<T> gru_value;
gru_value.gate_weight = const_cast<T*>(weight_data);
gru_value.state_weight =
const_cast<T*>(weight_data + 2 * frame_size * frame_size);
Tensor ordered_h0;
framework::Vector<size_t> order(batch_gate->lod()[2]);
if (h0) {
// Since batch computing for GRU reorders the input sequences by length,
// the initial hidden state must be reordered the same way.
ReorderInitState<DeviceContext, T>(
context.template device_context<DeviceContext>(), *h0, order,
&ordered_h0, true);
gru_value.prev_out_value = ordered_h0.data<T>();
} else {
gru_value.prev_out_value = nullptr;
}
auto batch_starts = batch_gate->lod()[0];
size_t seq_len = batch_starts.size() - 1;
auto active_node = math::detail::GetActivationType(
context.Attr<std::string>("activation"));
auto active_gate = math::detail::GetActivationType(
context.Attr<std::string>("gate_activation"));
#ifdef PADDLE_WITH_MKLML
// use MKL packed GEMM to speed up the per-step matrix multiplications
if (FLAGS_paddle_num_threads >= 4) {
auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
T* packed_gate = blas.GEMM_ALLOC(CblasBMatrix, 1 /*height of C*/,
frame_size * 2 /*width of weight*/,
frame_size /*height of weight*/);
PADDLE_ENFORCE(packed_gate);
blas.GEMM_PACK(CblasBMatrix, CblasNoTrans, 1 /*rows of C; real batch given at compute time*/, frame_size * 2,
frame_size, T(1.0), gru_value.gate_weight, frame_size * 2,
packed_gate);
T* packed_state = blas.GEMM_ALLOC(CblasBMatrix, 1 /*height of C*/,
frame_size /*width of weight*/,
frame_size /*height of weight*/);
PADDLE_ENFORCE(packed_state);
blas.GEMM_PACK(CblasBMatrix, CblasNoTrans, 1 /*rows of C; real batch given at compute time*/, frame_size,
frame_size, T(1.0), gru_value.state_weight, frame_size,
packed_state);
for (size_t n = 0; n < seq_len; n++) {
int bstart = static_cast<int>(batch_starts[n]);
int bend = static_cast<int>(batch_starts[n + 1]);
int cur_batch_size = bend - bstart;
Tensor gate_t = batch_gate->Slice(bstart, bend);
Tensor reset_hidden_prev_t =
batch_reset_hidden_prev->Slice(bstart, bend);
Tensor hidden_t = batch_hidden->Slice(bstart, bend);
gru_value.output_value = hidden_t.data<T>();
gru_value.gate_value = gate_t.data<T>();
gru_value.reset_output_value = reset_hidden_prev_t.data<T>();
if (gru_value.prev_out_value) {
blas.GEMM_COMPUTE(
CblasNoTrans, CblasPacked, cur_batch_size, frame_size * 2,
frame_size, gru_value.prev_out_value, frame_size, packed_gate,
frame_size * 2, T(1), gru_value.gate_value, frame_size * 3);
}
math::detail::forward_reset_output(
math::detail::forward::gru_resetOutput<T>(), gru_value, frame_size,
cur_batch_size, active_gate);
if (gru_value.prev_out_value) {
blas.GEMM_COMPUTE(
CblasNoTrans, CblasPacked, cur_batch_size, frame_size, frame_size,
gru_value.reset_output_value, frame_size, packed_state,
frame_size, T(1), gru_value.gate_value + frame_size * 2,
frame_size * 3);
}
math::detail::forward_final_output(
math::detail::forward::gru_finalOutput<T>(), gru_value, frame_size,
cur_batch_size, active_node);
gru_value.prev_out_value = gru_value.output_value;
}
blas.GEMM_FREE(packed_gate);
blas.GEMM_FREE(packed_state);
} else {
#endif
for (size_t n = 0; n < seq_len; n++) {
int bstart = static_cast<int>(batch_starts[n]);
int bend = static_cast<int>(batch_starts[n + 1]);
int cur_batch_size = bend - bstart;
Tensor gate_t = batch_gate->Slice(bstart, bend);
Tensor reset_hidden_prev_t =
batch_reset_hidden_prev->Slice(bstart, bend);
Tensor hidden_t = batch_hidden->Slice(bstart, bend);
gru_value.output_value = hidden_t.data<T>();
gru_value.gate_value = gate_t.data<T>();
gru_value.reset_output_value = reset_hidden_prev_t.data<T>();
math::GRUUnitFunctor<DeviceContext, T>::compute(
dev_ctx, gru_value, frame_size, cur_batch_size, active_node,
active_gate);
gru_value.prev_out_value = gru_value.output_value;
}
#ifdef PADDLE_WITH_MKLML
}
#endif
math::Batch2LoDTensorFunctor<DeviceContext, T> to_seq;
batch_hidden->set_lod(batch_gate->lod());
to_seq(dev_ctx, *batch_hidden, hidden);
}
void Compute(const framework::ExecutionContext& context) const override {
BatchCompute(context);
}
};
} // namespace operators
} // namespace paddle
...@@ -218,9 +375,8 @@ namespace ops = paddle::operators;
REGISTER_OPERATOR(gru, ops::GRUOp, ops::GRUOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(gru_grad, ops::GRUGradOp);
-REGISTER_OP_CPU_KERNEL(
-gru, ops::GRUKernel<paddle::platform::CPUDeviceContext, float>,
-ops::GRUKernel<paddle::platform::CPUDeviceContext, double>);
+REGISTER_OP_CPU_KERNEL(gru, ops::GRUCPUKernel<float>,
+ops::GRUCPUKernel<double>);
REGISTER_OP_CPU_KERNEL(
gru_grad, ops::GRUGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::GRUGradKernel<paddle::platform::CPUDeviceContext, double>);
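For reference, the two GEMM_COMPUTE calls in the packed branch above are numerically equivalent to plain row-major GEMMs over the batched layout. The sketch below spells them out with cblas_sgemm as a reading aid only (it assumes float data and reuses the kernel's variable names):

// gate[:, 0:2*frame_size] += prev_out * gate_weight; rows of gate_value are
// 3*frame_size wide, hence the leading dimension frame_size * 3.
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans,
            cur_batch_size, frame_size * 2, frame_size,
            1.0f, gru_value.prev_out_value, frame_size,
            gru_value.gate_weight, frame_size * 2,
            1.0f, gru_value.gate_value, frame_size * 3);
// gate[:, 2*frame_size:3*frame_size] += reset_output * state_weight.
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans,
            cur_batch_size, frame_size, frame_size,
            1.0f, gru_value.reset_output_value, frame_size,
            gru_value.state_weight, frame_size,
            1.0f, gru_value.gate_value + frame_size * 2, frame_size * 3);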
...@@ -14,6 +14,96 @@ limitations under the License. */
#include "paddle/fluid/operators/gru_op.h"
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
class GRUKernel : public framework::OpKernel<T> {
public:
void BatchCompute(const framework::ExecutionContext& context) const {
auto* input = context.Input<LoDTensor>("Input");
auto* h0 = context.Input<Tensor>("H0");
auto* weight = context.Input<Tensor>("Weight");
const T* weight_data = weight->data<T>();
auto* bias = context.Input<Tensor>("Bias");
auto* batch_gate = context.Output<LoDTensor>("BatchGate");
batch_gate->mutable_data<T>(context.GetPlace());
auto* batch_reset_hidden_prev =
context.Output<LoDTensor>("BatchResetHiddenPrev");
batch_reset_hidden_prev->mutable_data<T>(context.GetPlace());
auto* batch_hidden = context.Output<LoDTensor>("BatchHidden");
batch_hidden->mutable_data<T>(context.GetPlace());
auto* hidden = context.Output<LoDTensor>("Hidden");
hidden->mutable_data<T>(context.GetPlace());
auto hidden_dims = hidden->dims();
bool is_reverse = context.Attr<bool>("is_reverse");
math::LoDTensor2BatchFunctor<DeviceContext, T> to_batch;
auto& dev_ctx = context.template device_context<DeviceContext>();
to_batch(dev_ctx, *input, batch_gate, true, is_reverse);
if (bias) {
math::RowwiseAdd<DeviceContext, T> add_bias;
add_bias(dev_ctx, *batch_gate, *bias, batch_gate);
}
int frame_size = hidden_dims[1];
math::GRUMetaValue<T> gru_value;
gru_value.gate_weight = const_cast<T*>(weight_data);
gru_value.state_weight =
const_cast<T*>(weight_data + 2 * frame_size * frame_size);
Tensor ordered_h0;
framework::Vector<size_t> order(batch_gate->lod()[2]);
if (h0) {
// Since batch computing for GRU reorders the input sequences by length,
// the initial hidden state must be reordered the same way.
ReorderInitState<DeviceContext, T>(
context.template device_context<DeviceContext>(), *h0, order,
&ordered_h0, true);
gru_value.prev_out_value = ordered_h0.data<T>();
} else {
gru_value.prev_out_value = nullptr;
}
auto batch_starts = batch_gate->lod()[0];
size_t num_batch = batch_starts.size() - 1;
auto active_node = math::detail::GetActivationType(
context.Attr<std::string>("activation"));
auto active_gate = math::detail::GetActivationType(
context.Attr<std::string>("gate_activation"));
for (size_t n = 0; n < num_batch; n++) {
int bstart = static_cast<int>(batch_starts[n]);
int bend = static_cast<int>(batch_starts[n + 1]);
int cur_batch_size = bend - bstart;
Tensor gate_t = batch_gate->Slice(bstart, bend);
Tensor reset_hidden_prev_t = batch_reset_hidden_prev->Slice(bstart, bend);
Tensor hidden_t = batch_hidden->Slice(bstart, bend);
gru_value.output_value = hidden_t.data<T>();
gru_value.gate_value = gate_t.data<T>();
gru_value.reset_output_value = reset_hidden_prev_t.data<T>();
math::GRUUnitFunctor<DeviceContext, T>::compute(
dev_ctx, gru_value, frame_size, cur_batch_size, active_node,
active_gate);
gru_value.prev_out_value = gru_value.output_value;
}
math::Batch2LoDTensorFunctor<DeviceContext, T> to_seq;
batch_hidden->set_lod(batch_gate->lod());
to_seq(dev_ctx, *batch_hidden, hidden);
}
void Compute(const framework::ExecutionContext& context) const override {
BatchCompute(context);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
gru, ops::GRUKernel<paddle::platform::CUDADeviceContext, float>,
......
...@@ -37,90 +37,6 @@ inline void ReorderInitState(const DeviceContext& ctx,
row_shuffle(ctx, src, index_lod, dst, indexed_src);
}
template <typename DeviceContext, typename T>
class GRUKernel : public framework::OpKernel<T> {
public:
void BatchCompute(const framework::ExecutionContext& context) const {
auto* input = context.Input<LoDTensor>("Input");
auto* h0 = context.Input<Tensor>("H0");
auto* weight = context.Input<Tensor>("Weight");
const T* weight_data = weight->data<T>();
auto* bias = context.Input<Tensor>("Bias");
auto* batch_gate = context.Output<LoDTensor>("BatchGate");
batch_gate->mutable_data<T>(context.GetPlace());
auto* batch_reset_hidden_prev =
context.Output<LoDTensor>("BatchResetHiddenPrev");
batch_reset_hidden_prev->mutable_data<T>(context.GetPlace());
auto* batch_hidden = context.Output<LoDTensor>("BatchHidden");
batch_hidden->mutable_data<T>(context.GetPlace());
auto* hidden = context.Output<LoDTensor>("Hidden");
hidden->mutable_data<T>(context.GetPlace());
auto hidden_dims = hidden->dims();
bool is_reverse = context.Attr<bool>("is_reverse");
math::LoDTensor2BatchFunctor<DeviceContext, T> to_batch;
auto& dev_ctx = context.template device_context<DeviceContext>();
to_batch(dev_ctx, *input, batch_gate, true, is_reverse);
if (bias) {
math::RowwiseAdd<DeviceContext, T> add_bias;
add_bias(dev_ctx, *batch_gate, *bias, batch_gate);
}
int frame_size = hidden_dims[1];
math::GRUMetaValue<T> gru_value;
gru_value.gate_weight = const_cast<T*>(weight_data);
gru_value.state_weight =
const_cast<T*>(weight_data + 2 * frame_size * frame_size);
Tensor ordered_h0;
framework::Vector<size_t> order(batch_gate->lod()[2]);
if (h0) {
// Since batch computing for GRU reorders the input sequences by length,
// the initial hidden state must be reordered the same way.
ReorderInitState<DeviceContext, T>(
context.template device_context<DeviceContext>(), *h0, order,
&ordered_h0, true);
gru_value.prev_out_value = ordered_h0.data<T>();
} else {
gru_value.prev_out_value = nullptr;
}
auto batch_starts = batch_gate->lod()[0];
size_t num_batch = batch_starts.size() - 1;
auto active_node = math::detail::GetActivationType(
context.Attr<std::string>("activation"));
auto active_gate = math::detail::GetActivationType(
context.Attr<std::string>("gate_activation"));
for (size_t n = 0; n < num_batch; n++) {
int bstart = static_cast<int>(batch_starts[n]);
int bend = static_cast<int>(batch_starts[n + 1]);
int cur_batch_size = bend - bstart;
Tensor gate_t = batch_gate->Slice(bstart, bend);
Tensor reset_hidden_prev_t = batch_reset_hidden_prev->Slice(bstart, bend);
Tensor hidden_t = batch_hidden->Slice(bstart, bend);
gru_value.output_value = hidden_t.data<T>();
gru_value.gate_value = gate_t.data<T>();
gru_value.reset_output_value = reset_hidden_prev_t.data<T>();
math::GRUUnitFunctor<DeviceContext, T>::compute(
dev_ctx, gru_value, frame_size, cur_batch_size, active_node,
active_gate);
gru_value.prev_out_value = gru_value.output_value;
}
math::Batch2LoDTensorFunctor<DeviceContext, T> to_seq;
batch_hidden->set_lod(batch_gate->lod());
to_seq(dev_ctx, *batch_hidden, hidden);
}
void Compute(const framework::ExecutionContext& context) const override {
BatchCompute(context);
}
};
template <typename DeviceContext, typename T>
class GRUGradKernel : public framework::OpKernel<T> {
public:
......
...@@ -90,6 +90,25 @@ class Blas {
void GEMM(bool transA, bool transB, int M, int N, int K, T alpha, const T* A,
int lda, const T* B, int ldb, T beta, T* C, int ldc) const;
#ifdef PADDLE_WITH_MKLML
template <typename T>
T* GEMM_ALLOC(const CBLAS_IDENTIFIER id, const int M, const int N,
const int K) const;
template <typename T>
void GEMM_PACK(const CBLAS_IDENTIFIER id, const CBLAS_TRANSPOSE trans, int M,
int N, int K, const T alpha, const T* src, const int ld,
T* dst) const;
template <typename T>
void GEMM_COMPUTE(int transA, int transB, int M, int N, int K, const T* A,
const int lda, const T* B, const int ldb, T beta, T* C,
const int ldc) const;
template <typename T>
void GEMM_FREE(T* data) const;
#endif
template <typename T>
void MatMul(const framework::Tensor& mat_a, bool trans_a,
const framework::Tensor& mat_b, bool trans_b, T alpha,
...@@ -146,6 +165,28 @@ class BlasT : private Blas<DeviceContext> {
Base()->template GEMM<T>(args...);
}
#ifdef PADDLE_WITH_MKLML
template <typename... ARGS>
T* GEMM_ALLOC(ARGS... args) const {
return Base()->template GEMM_ALLOC<T>(args...);
}
template <typename... ARGS>
void GEMM_PACK(ARGS... args) const {
Base()->template GEMM_PACK<T>(args...);
}
template <typename... ARGS>
void GEMM_COMPUTE(ARGS... args) const {
Base()->template GEMM_COMPUTE<T>(args...);
}
template <typename... ARGS>
void GEMM_FREE(ARGS... args) const {
Base()->template GEMM_FREE<T>(args...);
}
#endif
template <typename... ARGS>
void MatMul(ARGS... args) const {
Base()->template MatMul<T>(args...);
......
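With the forwarding shims above, a kernel reaches the packed routines through the usual GetBlas entry point. A minimal usage sketch for a PADDLE_WITH_MKLML build (m, n, k, weight, x, y, and dev_ctx are illustrative names, not part of the diff):

// Pack a k x n weight once, then reuse it for repeated (m x k) * (k x n) products.
auto blas = paddle::operators::math::GetBlas<paddle::platform::CPUDeviceContext, float>(dev_ctx);
float* packed_w = blas.GEMM_ALLOC(CblasBMatrix, 1 /*rows of C*/, n, k);
blas.GEMM_PACK(CblasBMatrix, CblasNoTrans, 1, n, k, 1.0f, weight, n, packed_w);
blas.GEMM_COMPUTE(CblasNoTrans, CblasPacked, m, n, k, x, k, packed_w, n,
                  1.0f /*beta*/, y, n);
blas.GEMM_FREE(packed_w);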
...@@ -31,6 +31,26 @@ struct CBlas<float> {
platform::dynload::cblas_sgemm(args...);
}
template <typename... ARGS>
static float *GEMM_ALLOC(ARGS... args) {
return platform::dynload::cblas_sgemm_alloc(args...);
}
template <typename... ARGS>
static void GEMM_PACK(ARGS... args) {
platform::dynload::cblas_sgemm_pack(args...);
}
template <typename... ARGS>
static void GEMM_COMPUTE(ARGS... args) {
platform::dynload::cblas_sgemm_compute(args...);
}
template <typename... ARGS>
static void GEMM_FREE(ARGS... args) {
platform::dynload::cblas_sgemm_free(args...);
}
#ifdef PADDLE_WITH_LIBXSMM
template <typename... ARGS>
static void SMM_GEMM(ARGS... args) {
...@@ -71,6 +91,26 @@ struct CBlas<double> {
platform::dynload::cblas_dgemm(args...);
}
template <typename... ARGS>
static double *GEMM_ALLOC(ARGS... args) {
return platform::dynload::cblas_dgemm_alloc(args...);
}
template <typename... ARGS>
static void GEMM_PACK(ARGS... args) {
platform::dynload::cblas_dgemm_pack(args...);
}
template <typename... ARGS>
static void GEMM_COMPUTE(ARGS... args) {
platform::dynload::cblas_dgemm_compute(args...);
}
template <typename... ARGS>
static void GEMM_FREE(ARGS... args) {
platform::dynload::cblas_dgemm_free(args...);
}
#ifdef PADDLE_WITH_LIBXSMM
template <typename... ARGS>
static void SMM_GEMM(ARGS... args) {
...@@ -224,6 +264,41 @@ inline void GEMM_WARP(CBLAS_ORDER order, CBLAS_TRANSPOSE transA,
beta, C, ldc);
}
#ifdef PADDLE_WITH_MKLML
template <>
template <typename T>
T *Blas<platform::CPUDeviceContext>::GEMM_ALLOC(const CBLAS_IDENTIFIER id,
const int M, const int N,
const int K) const {
return CBlas<T>::GEMM_ALLOC(id, M, N, K);
}
template <>
template <typename T>
void Blas<platform::CPUDeviceContext>::GEMM_PACK(const CBLAS_IDENTIFIER id,
const CBLAS_TRANSPOSE trans,
int M, int N, int K,
const T alpha, const T *src,
const int ld, T *dst) const {
CBlas<T>::GEMM_PACK(CblasRowMajor, id, trans, M, N, K, alpha, src, ld, dst);
}
template <>
template <typename T>
void Blas<platform::CPUDeviceContext>::GEMM_COMPUTE(
int transA, int transB, int M, int N, int K, const T *A, const int lda,
const T *B, const int ldb, T beta, T *C, const int ldc) const {
CBlas<T>::GEMM_COMPUTE(CblasRowMajor, transA, transB, M, N, K, A, lda, B, ldb,
beta, C, ldc);
}
template <>
template <typename T>
void Blas<platform::CPUDeviceContext>::GEMM_FREE(T *data) const {
CBlas<T>::GEMM_FREE(data);
}
#endif
template <>
template <typename T>
void Blas<platform::CPUDeviceContext>::GEMM(CBLAS_TRANSPOSE transA,
......
...@@ -60,6 +60,14 @@ extern void* mklml_dso_handle;
__macro(cblas_dgemm_batch); \
__macro(vsAdd); \
__macro(vdAdd); \
__macro(cblas_sgemm_alloc); \
__macro(cblas_sgemm_pack); \
__macro(cblas_sgemm_compute); \
__macro(cblas_sgemm_free); \
__macro(cblas_dgemm_alloc); \
__macro(cblas_dgemm_pack); \
__macro(cblas_dgemm_compute); \
__macro(cblas_dgemm_free); \
__macro(MKL_Set_Num_Threads)
MKLML_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_MKLML_WRAP);
......
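Adding the eight cblas_*gemm_{alloc,pack,compute,free} entries to MKLML_ROUTINE_EACH generates a lazy dynamic-load wrapper for each symbol, which is why blas_impl.h can call them through platform::dynload. The call form is the same as calling MKL directly, e.g. (sketch only, m/n/k illustrative):

// The wrapper resolves the symbol from the MKLML shared library on first use
// and forwards the arguments unchanged.
float* packed = paddle::platform::dynload::cblas_sgemm_alloc(CblasBMatrix, m, n, k);
paddle::platform::dynload::cblas_sgemm_free(packed);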