Commit 5377edd2 authored by tensor-tang

refine packed condition

Parent 3bf3e77a
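The diff below gates the MKL packed-GEMM fast path in GRUCPUKernel on the thread count: weights are packed and GEMM_COMPUTE is used only when FLAGS_paddle_num_threads is at least 4; with fewer threads the kernel falls back to the plain math::GRUUnitFunctor loop. The following is a minimal standalone C++ sketch of that dispatch, for illustration only; kPackedThreshold, RunPackedGemmPath and RunGenericPath are hypothetical stand-ins, not Paddle APIs.

// Illustrative sketch only: mirrors the `FLAGS_paddle_num_threads >= 4`
// condition introduced in the diff below. The two Run* functions stand in
// for the packed-GEMM loop and the GRUUnitFunctor loop.
#include <cstdio>

namespace {
constexpr int kPackedThreshold = 4;  // assumed stand-in for the flag check

void RunPackedGemmPath() {
  // pack gate/state weights once, then call GEMM_COMPUTE per time step
  std::puts("packed path");
}

void RunGenericPath() {
  // plain per-step GRU compute, no packing overhead
  std::puts("generic path");
}
}  // namespace

void GruForward(int num_threads) {
  if (num_threads >= kPackedThreshold) {
    RunPackedGemmPath();
  } else {
    RunGenericPath();
  }
}

int main() {
  GruForward(2);  // few threads: generic path
  GruForward(8);  // enough threads: packed path
  return 0;
}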
@@ -14,6 +14,11 @@ limitations under the License. */
 
 #include "paddle/fluid/operators/gru_op.h"
 #include <string>
+#include "paddle/fluid/operators/math/blas.h"
+#include "paddle/fluid/operators/math/detail/gru_cpu_kernel.h"
+#include "paddle/fluid/operators/math/detail/gru_kernel.h"
 
+DECLARE_int32(paddle_num_threads);
+
 namespace paddle {
 namespace operators {

@@ -264,76 +269,94 @@ class GRUCPUKernel : public framework::OpKernel<T> {
       gru_value.prev_out_value = nullptr;
     }
     auto batch_starts = batch_gate->lod()[0];
-    size_t num_batch = batch_starts.size() - 1;
+    size_t seq_len = batch_starts.size() - 1;
     auto active_node = math::detail::GetActivationType(
         context.Attr<std::string>("activation"));
     auto active_gate = math::detail::GetActivationType(
         context.Attr<std::string>("gate_activation"));
 
 #ifdef PADDLE_WITH_MKLML
-    auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
-    // TODO(TJ): make a class
-    T* packed_gate = blas.GEMM_ALLOC(CblasBMatrix, 1 /*height of C*/,
-                                     frame_size * 2 /*width of weight*/,
-                                     frame_size /*height of height*/);
-    PADDLE_ENFORCE(packed_gate);
-    blas.GEMM_PACK(CblasBMatrix, CblasNoTrans, 1 /*cur bs?*/, frame_size * 2,
-                   frame_size, T(1.0), gru_value.gate_weight, frame_size * 2,
-                   packed_gate);
-    T* packed_state = blas.GEMM_ALLOC(CblasBMatrix, 1 /*height of C*/,
-                                      frame_size /*width of weight*/,
-                                      frame_size /*height of height*/);
-    PADDLE_ENFORCE(packed_state);
-    blas.GEMM_PACK(CblasBMatrix, CblasNoTrans, 1 /*cur bs?*/, frame_size,
-                   frame_size, T(1.0), gru_value.state_weight, frame_size,
-                   packed_state);
-#endif
-    for (size_t n = 0; n < num_batch; n++) {
-      int bstart = static_cast<int>(batch_starts[n]);
-      int bend = static_cast<int>(batch_starts[n + 1]);
-      int cur_batch_size = bend - bstart;
-
-      Tensor gate_t = batch_gate->Slice(bstart, bend);
-      Tensor reset_hidden_prev_t = batch_reset_hidden_prev->Slice(bstart, bend);
-      Tensor hidden_t = batch_hidden->Slice(bstart, bend);
-      gru_value.output_value = hidden_t.data<T>();
-      gru_value.gate_value = gate_t.data<T>();
-      gru_value.reset_output_value = reset_hidden_prev_t.data<T>();
-
-#ifdef PADDLE_WITH_MKLML
-      if (gru_value.prev_out_value) {
-        blas.GEMM_COMPUTE(CblasNoTrans, CblasPacked, cur_batch_size,
-                          frame_size * 2, frame_size, gru_value.prev_out_value,
-                          frame_size, packed_gate, frame_size * 2, T(1),
-                          gru_value.gate_value, frame_size * 3);
-      }
-
-      math::detail::forward_reset_output(
-          math::detail::forward::gru_resetOutput<T>(), gru_value, frame_size,
-          cur_batch_size, active_gate);
-
-      if (gru_value.prev_out_value) {
-        blas.GEMM_COMPUTE(
-            CblasNoTrans, CblasPacked, cur_batch_size, frame_size, frame_size,
-            gru_value.reset_output_value, frame_size, packed_state, frame_size,
-            T(1), gru_value.gate_value + frame_size * 2, frame_size * 3);
-      }
-
-      math::detail::forward_final_output(
-          math::detail::forward::gru_finalOutput<T>(), gru_value, frame_size,
-          cur_batch_size, active_node);
-#else
-      math::GRUUnitFunctor<DeviceContext, T>::compute(
-          dev_ctx, gru_value, frame_size, cur_batch_size, active_node,
-          active_gate);
-#endif
-      gru_value.prev_out_value = gru_value.output_value;
-    }
-#ifdef PADDLE_WITH_MKLML
-    blas.GEMM_FREE(packed_gate);
-    blas.GEMM_FREE(packed_state);
-#endif
+    if (FLAGS_paddle_num_threads >= 4) {
+      auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
+      T* packed_gate = blas.GEMM_ALLOC(CblasBMatrix, 1 /*height of C*/,
+                                       frame_size * 2 /*width of weight*/,
+                                       frame_size /*height of height*/);
+      PADDLE_ENFORCE(packed_gate);
+      blas.GEMM_PACK(CblasBMatrix, CblasNoTrans, 1 /*cur bs?*/, frame_size * 2,
+                     frame_size, T(1.0), gru_value.gate_weight, frame_size * 2,
+                     packed_gate);
+      T* packed_state = blas.GEMM_ALLOC(CblasBMatrix, 1 /*height of C*/,
+                                        frame_size /*width of weight*/,
+                                        frame_size /*height of height*/);
+      PADDLE_ENFORCE(packed_state);
+      blas.GEMM_PACK(CblasBMatrix, CblasNoTrans, 1 /*cur bs?*/, frame_size,
+                     frame_size, T(1.0), gru_value.state_weight, frame_size,
+                     packed_state);
+      for (size_t n = 0; n < seq_len; n++) {
+        int bstart = static_cast<int>(batch_starts[n]);
+        int bend = static_cast<int>(batch_starts[n + 1]);
+        int cur_batch_size = bend - bstart;
+
+        Tensor gate_t = batch_gate->Slice(bstart, bend);
+        Tensor reset_hidden_prev_t =
+            batch_reset_hidden_prev->Slice(bstart, bend);
+        Tensor hidden_t = batch_hidden->Slice(bstart, bend);
+        gru_value.output_value = hidden_t.data<T>();
+        gru_value.gate_value = gate_t.data<T>();
+        gru_value.reset_output_value = reset_hidden_prev_t.data<T>();
+
+        if (gru_value.prev_out_value) {
+          blas.GEMM_COMPUTE(
+              CblasNoTrans, CblasPacked, cur_batch_size, frame_size * 2,
+              frame_size, gru_value.prev_out_value, frame_size, packed_gate,
+              frame_size * 2, T(1), gru_value.gate_value, frame_size * 3);
+        }
+
+        math::detail::forward_reset_output(
+            math::detail::forward::gru_resetOutput<T>(), gru_value, frame_size,
+            cur_batch_size, active_gate);
+
+        if (gru_value.prev_out_value) {
+          blas.GEMM_COMPUTE(
+              CblasNoTrans, CblasPacked, cur_batch_size, frame_size, frame_size,
+              gru_value.reset_output_value, frame_size, packed_state,
+              frame_size, T(1), gru_value.gate_value + frame_size * 2,
+              frame_size * 3);
+        }
+
+        math::detail::forward_final_output(
+            math::detail::forward::gru_finalOutput<T>(), gru_value, frame_size,
+            cur_batch_size, active_node);
+
+        gru_value.prev_out_value = gru_value.output_value;
+      }
+
+      blas.GEMM_FREE(packed_gate);
+      blas.GEMM_FREE(packed_state);
+    } else {
+#endif
+      for (size_t n = 0; n < seq_len; n++) {
+        int bstart = static_cast<int>(batch_starts[n]);
+        int bend = static_cast<int>(batch_starts[n + 1]);
+        int cur_batch_size = bend - bstart;
+
+        Tensor gate_t = batch_gate->Slice(bstart, bend);
+        Tensor reset_hidden_prev_t =
+            batch_reset_hidden_prev->Slice(bstart, bend);
+        Tensor hidden_t = batch_hidden->Slice(bstart, bend);
+        gru_value.output_value = hidden_t.data<T>();
+        gru_value.gate_value = gate_t.data<T>();
+        gru_value.reset_output_value = reset_hidden_prev_t.data<T>();
+
+        math::GRUUnitFunctor<DeviceContext, T>::compute(
+            dev_ctx, gru_value, frame_size, cur_batch_size, active_node,
+            active_gate);
+
+        gru_value.prev_out_value = gru_value.output_value;
+      }
+#ifdef PADDLE_WITH_MKLML
+    }
+#endif
     math::Batch2LoDTensorFunctor<DeviceContext, T> to_seq;
     batch_hidden->set_lod(batch_gate->lod());
     to_seq(dev_ctx, *batch_hidden, hidden);
@@ -16,10 +16,7 @@ limitations under the License. */
 #include <string>
 #include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/operators/math/blas.h"
 #include "paddle/fluid/operators/math/detail/activation_functions.h"
-#include "paddle/fluid/operators/math/detail/gru_cpu_kernel.h"
-#include "paddle/fluid/operators/math/detail/gru_kernel.h"
 #include "paddle/fluid/operators/math/gru_compute.h"
 #include "paddle/fluid/operators/math/math_function.h"
 #include "paddle/fluid/operators/math/sequence2batch.h"