Commit 8c23f7c4 authored by tensor-tang

fix blas and use packed weight

Parent d9cc6b18
@@ -98,6 +98,23 @@ class GRUKernel : public framework::OpKernel<T> {
auto active_gate = math::detail::GetActivationType(
context.Attr<std::string>("gate_activation"));
auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
// TODO(TJ): make a class, make one pack
T* packed_gate = blas.GEMM_ALLOC(CblasBMatrix, 1 /*height of C*/,
frame_size * 2 /*width of weight*/,
frame_size /*height of weight*/);
PADDLE_ENFORCE(packed_gate);
blas.GEMM_PACK(CblasBMatrix, CblasNoTrans, 1 /*cur bs?*/, frame_size * 2,
frame_size, T(1.0), gru_value.gate_weight, frame_size * 2,
packed_gate);
T* packed_state = blas.GEMM_ALLOC(CblasBMatrix, 1 /*height of C*/,
frame_size /*width of weight*/,
frame_size /*height of weight*/);
PADDLE_ENFORCE(packed_state);
blas.GEMM_PACK(CblasBMatrix, CblasNoTrans, 1 /*cur bs?*/, frame_size,
frame_size, T(1.0), gru_value.state_weight, frame_size,
packed_state);
for (size_t n = 0; n < num_batch; n++) {
int bstart = static_cast<int>(batch_starts[n]);
int bend = static_cast<int>(batch_starts[n + 1]);
@@ -110,9 +127,10 @@ class GRUKernel : public framework::OpKernel<T> {
gru_value.gate_value = gate_t.data<T>();
gru_value.reset_output_value = reset_hidden_prev_t.data<T>();
if (gru_value.prev_out_value) {
blas.GEMM(false, false, cur_batch_size, frame_size * 2, frame_size, 1,
gru_value.prev_out_value, frame_size, gru_value.gate_weight,
frame_size * 2, 1, gru_value.gate_value, frame_size * 3);
blas.GEMM_COMPUTE(CblasNoTrans, CblasPacked, cur_batch_size,
frame_size * 2, frame_size, gru_value.prev_out_value,
frame_size, packed_gate, frame_size * 2, T(1),
gru_value.gate_value, frame_size * 3);
}
math::detail::forward_reset_output(
@@ -120,10 +138,10 @@ class GRUKernel : public framework::OpKernel<T> {
cur_batch_size, active_gate);
if (gru_value.prev_out_value) {
blas.GEMM(false, false, cur_batch_size, frame_size, frame_size, 1,
gru_value.reset_output_value, frame_size,
gru_value.state_weight, frame_size, 1,
gru_value.gate_value + frame_size * 2, frame_size * 3);
blas.GEMM_COMPUTE(
CblasNoTrans, CblasPacked, cur_batch_size, frame_size, frame_size,
gru_value.reset_output_value, frame_size, packed_state, frame_size,
T(1), gru_value.gate_value + frame_size * 2, frame_size * 3);
}
math::detail::forward_final_output(
@@ -132,6 +150,8 @@ class GRUKernel : public framework::OpKernel<T> {
gru_value.prev_out_value = gru_value.output_value;
}
blas.GEMM_FREE(packed_gate);
blas.GEMM_FREE(packed_state);
math::Batch2LoDTensorFunctor<DeviceContext, T> to_seq;
batch_hidden->set_lod(batch_gate->lod());
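The pattern introduced above is MKL's packed GEMM workflow: the gate and state weight matrices play the role of the B operand, so they are packed once before the batch loop (GEMM_ALLOC plus GEMM_PACK), every step then multiplies the changing activations against the packed buffers (GEMM_COMPUTE with CblasPacked in place of the weight operand), and the buffers are released after the loop (GEMM_FREE). Packing pays off because the weight data is rearranged into MKL's internal layout a single time instead of on every call. Below is a minimal, self-contained sketch of that lifecycle written directly against MKL's cblas_sgemm_* routines rather than Paddle's Blas wrapper; the sizes and variable names are illustrative only, not taken from the kernel.

```cpp
// Minimal sketch of the pack-once / compute-many pattern (assumes MKL is available).
#include <mkl.h>
#include <vector>

int main() {
  const int m = 4, n = 8, k = 8;  // C is m x n, A is m x k, B (the "weight") is k x n
  std::vector<float> A(m * k, 1.0f), B(k * n, 0.5f), C(m * n, 0.0f);

  // Allocate a buffer sized for B in MKL's internal packed layout.
  float* packed_B = cblas_sgemm_alloc(CblasBMatrix, m, n, k);

  // Pack B once; alpha is folded into the packed data at this point.
  cblas_sgemm_pack(CblasRowMajor, CblasBMatrix, CblasNoTrans, m, n, k, 1.0f,
                   B.data(), n, packed_B);

  // Reuse the packed weight across many steps; only A (and C) change.
  for (int step = 0; step < 3; ++step) {
    cblas_sgemm_compute(CblasRowMajor, CblasNoTrans, CblasPacked, m, n, k,
                        A.data(), k, packed_B, n, /*beta=*/1.0f, C.data(), n);
  }

  cblas_sgemm_free(packed_B);
  return 0;
}
```

In the kernel above, prev_out_value and reset_output_value take the place of A, and packed_gate / packed_state take the place of packed_B; since the weights do not change across batch steps, the packing cost is amortized over the whole sequence.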
@@ -165,7 +165,7 @@ class BlasT : private Blas<DeviceContext> {
template <typename... ARGS>
T* GEMM_ALLOC(ARGS... args) const {
Base()->template GEMM_ALLOC<T>(args...);
return Base()->template GEMM_ALLOC<T>(args...);
}
template <typename... ARGS>
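This last hunk is the "fix blas" half of the commit message: the BlasT::GEMM_ALLOC wrapper forwarded the call to the underlying Blas implementation but never returned its result, and falling off the end of a non-void function is undefined behaviour, so callers received an indeterminate pointer. A minimal sketch of the corrected forwarding pattern, using hypothetical stand-in classes rather than the Paddle ones:

```cpp
// Hypothetical stand-ins for the Blas/BlasT pair, only to illustrate the fix.
#include <cstdlib>

class Blas {
 public:
  template <typename T, typename... ARGS>
  T* GEMM_ALLOC(ARGS... /*args*/) const {
    // Placeholder allocation standing in for the real packed-buffer allocator.
    return static_cast<T*>(std::malloc(64 * sizeof(T)));
  }
};

template <typename T>
class BlasT : private Blas {
 public:
  template <typename... ARGS>
  T* GEMM_ALLOC(ARGS... args) const {
    // The one-line fix: forward the call *and* return the resulting pointer.
    return Base()->template GEMM_ALLOC<T>(args...);
  }

 private:
  const Blas* Base() const { return static_cast<const Blas*>(this); }
};

int main() {
  BlasT<float> blas;
  float* packed = blas.GEMM_ALLOC(/*identifier=*/0, /*m=*/1, /*n=*/16, /*k=*/8);
  std::free(packed);
  return 0;
}
```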