Commit c888e016 authored by Yu Yang

Refactor GEMM in blas

Parent c93a624b
@@ -61,8 +61,8 @@ class BilinearTensorProductKernel : public framework::OpKernel<T> {
       auto output_col_vec = output_mat.chip(i, 1);
       Tensor weight_mat =
           weight->Slice(i, i + 1).Resize(framework::make_ddim({x_dim, y_dim}));
-      math::gemm<DeviceContext, T>(dev_ctx, CblasNoTrans, CblasNoTrans,
-                                   batch_size, y_dim, x_dim, 1, x->data<T>(),
-                                   weight_mat.data<T>(), 0, left_mul.data<T>());
+      math::GetBlas<DeviceContext, T>(dev_ctx).GEMM(
+          CblasNoTrans, CblasNoTrans, batch_size, y_dim, x_dim, 1, x->data<T>(),
+          weight_mat.data<T>(), 0, left_mul.data<T>());
       output_col_vec.device(place) =
           (left_mul_mat * y_mat).sum(Eigen::DSizes<int, 1>(1));
@@ -125,6 +125,8 @@ class BilinearTensorProductGradKernel : public framework::OpKernel<T> {
       set_zero(dev_ctx, d_y, static_cast<T>(0));
     }
 
+    auto blas = math::GetBlas<DeviceContext, T>(ctx);
+
     // Calculate the Output(X@Grad) and Output(Y@Grad).
     if (d_x || d_y) {
       Eigen::DSizes<int, 2> bcast_for_x(1, y_dim);
@@ -138,8 +140,7 @@ class BilinearTensorProductGradKernel : public framework::OpKernel<T> {
             output_vec.reshape(Eigen::DSizes<int, 2>(batch_size, 1))
                 .broadcast(bcast_for_x) *
             y_mat;
-        math::gemm<DeviceContext, T>(
-            dev_ctx, CblasNoTrans, CblasTrans, batch_size, x_dim, y_dim, 1,
-            y_scale.data<T>(), weight_i.data<T>(), 1, d_x->data<T>());
+        blas.GEMM(CblasNoTrans, CblasTrans, batch_size, x_dim, y_dim, 1,
+                  y_scale.data<T>(), weight_i.data<T>(), 1, d_x->data<T>());
       }
       if (d_y) {
@@ -147,8 +148,7 @@ class BilinearTensorProductGradKernel : public framework::OpKernel<T> {
            output_vec.reshape(Eigen::DSizes<int, 2>(batch_size, 1))
                .broadcast(bcast_for_y) *
            x_mat;
-        math::gemm<DeviceContext, T>(
-            dev_ctx, CblasNoTrans, CblasNoTrans, batch_size, y_dim, x_dim, 1,
-            x_scale.data<T>(), weight_i.data<T>(), 1, d_y->data<T>());
+        blas.GEMM(CblasNoTrans, CblasNoTrans, batch_size, y_dim, x_dim, 1,
+                  x_scale.data<T>(), weight_i.data<T>(), 1, d_y->data<T>());
       }
     }
@@ -166,9 +166,8 @@ class BilinearTensorProductGradKernel : public framework::OpKernel<T> {
            output_vec.reshape(Eigen::DSizes<int, 2>(batch_size, 1))
                .broadcast(bcast_for_weight) *
            x_mat;
-        math::gemm<DeviceContext, T>(dev_ctx, CblasTrans, CblasNoTrans, x_dim,
-                                     y_dim, batch_size, 1, x_scale.data<T>(),
-                                     y->data<T>(), 0, d_weight_i.data<T>());
+        blas.GEMM(CblasTrans, CblasNoTrans, x_dim, y_dim, batch_size, 1,
+                  x_scale.data<T>(), y->data<T>(), 0, d_weight_i.data<T>());
      }
    }
...
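This call-site change is the pattern the whole commit follows: the free function template `math::gemm<DeviceContext, T>(dev_ctx, ...)` becomes a method call on a BLAS handle obtained once via `math::GetBlas<DeviceContext, T>`. A minimal before/after sketch (the names `M`, `N`, `K`, `A`, `B`, `C` are placeholders, not lines from this commit):

// Before: free function template; the device context and the element type
// are spelled out on every call.
math::gemm<DeviceContext, T>(dev_ctx, CblasNoTrans, CblasNoTrans, M, N, K,
                             static_cast<T>(1), A, B, static_cast<T>(0), C);

// After: fetch a BlasT<DeviceContext, T> handle once, then issue calls on it;
// T is fixed by GetBlas, so it is not repeated at each GEMM call.
auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
blas.GEMM(CblasNoTrans, CblasNoTrans, M, N, K, static_cast<T>(1), A, B,
          static_cast<T>(0), C);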
@@ -87,10 +87,10 @@ class GRUUnitKernel : public framework::OpKernel<T> {
     const T* weight_data = weight->data<T>();
     T* gate_data = gate->data<T>();
     T* reset_hidden_prev_data = reset_hidden_prev->data<T>();
-    math::gemm<DeviceContext, T>(
-        context.template device_context<DeviceContext>(), false, false,
-        batch_size, 2 * frame_size, frame_size, 1, hidden_prev_data, frame_size,
-        weight_data, frame_size * 2, 1, gate_data, frame_size * 3);
+    auto blas = math::GetBlas<DeviceContext, T>(context);
+    blas.GEMM(false, false, batch_size, 2 * frame_size, frame_size, 1,
+              hidden_prev_data, frame_size, weight_data, frame_size * 2, 1,
+              gate_data, frame_size * 3);
 
     // calculate activated gate
     Eigen::array<int, 2> extents({{batch_size, frame_size}});
@@ -103,10 +103,9 @@ class GRUUnitKernel : public framework::OpKernel<T> {
                g.slice(r_offsets, extents), g.slice(r_offsets, extents));
     auto r = g.slice(r_offsets, extents);  // reset gate
     r_h_p.device(place) = r * h_p;  // reset previous hidden state
-    math::gemm<DeviceContext, T>(
-        context.template device_context<DeviceContext>(), false, false,
-        batch_size, frame_size, frame_size, 1, reset_hidden_prev_data,
-        frame_size, weight_data + frame_size * frame_size * 2, frame_size, 1,
-        gate_data + frame_size * 2, frame_size * 3);
+    blas.GEMM(false, false, batch_size, frame_size, frame_size, 1,
+              reset_hidden_prev_data, frame_size,
+              weight_data + frame_size * frame_size * 2, frame_size, 1,
+              gate_data + frame_size * 2, frame_size * 3);
 
     Eigen::array<int, 2> c_offsets({{0, frame_size * 2}});
@@ -188,11 +187,11 @@ class GRUUnitGradKernel : public framework::OpKernel<T> {
     ActGradCompute(context.Attr<int>("activation"), place, c, c,
                    d_g.slice(c_offsets, extents), d_h * u);
     // backward for reset_hidden_prev
-    math::gemm<DeviceContext, T>(
-        context.template device_context<DeviceContext>(), false, true,
-        batch_size, frame_size, frame_size, 1, gate_grad_data + frame_size * 2,
-        frame_size * 3, weight_data + frame_size * frame_size * 2, frame_size,
-        0, reset_hidden_prev_grad_data, frame_size);
+    auto blas = math::GetBlas<DeviceContext, T>(context);
+    blas.GEMM(false, true, batch_size, frame_size, frame_size, 1,
+              gate_grad_data + frame_size * 2, frame_size * 3,
+              weight_data + frame_size * frame_size * 2, frame_size, 0,
+              reset_hidden_prev_grad_data, frame_size);
     // backward for unactivated reset gate
     ActGradCompute(context.Attr<int>("gate_activation"), place, r, r,
                    d_g.slice(r_offsets, extents), d_r_h_p * h_p);
@@ -200,18 +199,15 @@ class GRUUnitGradKernel : public framework::OpKernel<T> {
     if (weight_grad) {
       T* weight_grad_data = weight_grad->mutable_data<T>(context.GetPlace());
       // backward for state_weight
-      math::gemm<DeviceContext, T>(
-          context.template device_context<DeviceContext>(), true, false,
-          frame_size, frame_size, batch_size, 1, reset_hidden_prev_data,
-          frame_size, gate_grad_data + frame_size * 2, frame_size * 3, 0,
-          weight_grad_data + frame_size * frame_size * 2, frame_size);
+      blas.GEMM(true, false, frame_size, frame_size, batch_size, 1,
+                reset_hidden_prev_data, frame_size,
+                gate_grad_data + frame_size * 2, frame_size * 3, 0,
+                weight_grad_data + frame_size * frame_size * 2, frame_size);
       // backward for update_gate_weight and reset_gate_weight
-      math::gemm<DeviceContext, T>(
-          context.template device_context<DeviceContext>(), true, false,
-          frame_size, frame_size * 2, batch_size, 1, hidden_prev_data,
-          frame_size, gate_grad_data, frame_size * 3, 0, weight_grad_data,
-          frame_size * 2);
+      blas.GEMM(true, false, frame_size, frame_size * 2, batch_size, 1,
+                hidden_prev_data, frame_size, gate_grad_data, frame_size * 3, 0,
+                weight_grad_data, frame_size * 2);
     }
     // backward for hidden_prev
     if (hidden_prev_grad) {
@@ -219,11 +215,9 @@ class GRUUnitGradKernel : public framework::OpKernel<T> {
           hidden_prev_grad->mutable_data<T>(context.GetPlace());
       auto d_h_p = EigenMatrix<T>::From(*hidden_prev_grad);
       d_h_p.device(place) = d_r_h_p * r + d_h * (u.constant(T(1)) - u);
-      math::gemm<DeviceContext, T>(
-          context.template device_context<DeviceContext>(), false, true,
-          batch_size, frame_size, frame_size * 2, 1, gate_grad_data,
-          frame_size * 3, weight_data, frame_size * 2, 1, hidden_prev_grad_data,
-          frame_size);
+      blas.GEMM(false, true, batch_size, frame_size, frame_size * 2, 1,
+                gate_grad_data, frame_size * 3, weight_data, frame_size * 2, 1,
+                hidden_prev_grad_data, frame_size);
     }
     // backward for input
     if (input_grad) {
...
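The GRU kernel above leans on the boolean-transpose `GEMM` overload that takes explicit leading dimensions: the gate buffer is laid out as `batch_size x (3 * frame_size)`, so updating only the candidate block means passing `gate_data + frame_size * 2` as the output pointer with `ldc = frame_size * 3`. A self-contained sketch of what those stride arguments mean, using a naive row-major GEMM (illustrative only, not Paddle code):

#include <cstdio>
#include <vector>

// C[M x N] = alpha * A[M x K] * B[K x N] + beta * C, where each matrix is
// embedded in a larger row-major buffer whose row stride is its leading
// dimension (lda / ldb / ldc).
void NaiveGemm(int M, int N, int K, float alpha, const float* A, int lda,
               const float* B, int ldb, float beta, float* C, int ldc) {
  for (int m = 0; m < M; ++m) {
    for (int n = 0; n < N; ++n) {
      float acc = 0.f;
      for (int k = 0; k < K; ++k) acc += A[m * lda + k] * B[k * ldb + n];
      C[m * ldc + n] = alpha * acc + beta * C[m * ldc + n];
    }
  }
}

int main() {
  const int batch_size = 2, frame_size = 3;
  std::vector<float> h(batch_size * frame_size, 1.f);   // reset hidden state
  std::vector<float> w(frame_size * frame_size, 0.5f);  // state weight
  // The gate buffer holds three frame_size-wide blocks per row; only the
  // third block is written, hence the pointer offset and ldc below.
  std::vector<float> gate(batch_size * 3 * frame_size, 0.f);
  NaiveGemm(batch_size, frame_size, frame_size, 1.f, h.data(), frame_size,
            w.data(), frame_size, 1.f, gate.data() + frame_size * 2,
            frame_size * 3);
  std::printf("gate[0][2 * frame_size] = %.1f\n", gate[frame_size * 2]);
  return 0;
}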
New file: paddle/fluid/operators/math/blas_impl.cu.h

// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/dynload/cublas.h"
namespace paddle {
namespace operators {
namespace math {
template <typename T>
struct CUBlas;
template <>
struct CUBlas<float> {
template <typename... ARGS>
static void GEMM(ARGS... args) {
PADDLE_ENFORCE(platform::dynload::cublasSgemm(args...));
}
};
template <>
struct CUBlas<double> {
template <typename... ARGS>
static void GEMM(ARGS... args) {
PADDLE_ENFORCE(platform::dynload::cublasDgemm(args...));
}
};
template <>
struct CUBlas<platform::float16> {
template <typename... ARGS>
static void GEMM(ARGS... args) {
PADDLE_ENFORCE(platform::dynload::cublasHgemm(args...));
}
};
template <>
template <typename T>
void Blas<platform::CUDADeviceContext>::GEMM(const CBLAS_TRANSPOSE transA,
const CBLAS_TRANSPOSE transB,
const int M, const int N,
const int K, const T alpha,
const T *A, const T *B,
const T beta, T *C) const {
// Note that cublas follows fortran order, so the order is different from
// the cblas convention.
int lda = (transA == CblasNoTrans) ? K : M;
int ldb = (transB == CblasNoTrans) ? N : K;
cublasOperation_t cuTransA =
(transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
cublasOperation_t cuTransB =
(transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
CUBlas<T>::GEMM(context_.cublas_handle(), cuTransB, cuTransA, N, M, K, &alpha,
B, ldb, A, lda, &beta, C, N);
}
template <>
template <>
inline void Blas<platform::CUDADeviceContext>::GEMM(
const CBLAS_TRANSPOSE transA, const CBLAS_TRANSPOSE transB, const int M,
const int N, const int K, const platform::float16 alpha,
const platform::float16 *A, const platform::float16 *B,
const platform::float16 beta, platform::float16 *C) const {
// Note that cublas follows fortran order, so the order is different from
// the cblas convention.
int lda = (transA == CblasNoTrans) ? K : M;
int ldb = (transB == CblasNoTrans) ? N : K;
cublasOperation_t cuTransA =
(transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
cublasOperation_t cuTransB =
(transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
// TODO(kexinzhao): add processing code for compute capability < 53 case
PADDLE_ENFORCE_GE(context_.GetComputeCapability(), 53,
"cublas fp16 gemm requires GPU compute capability >= 53");
#if CUDA_VERSION >= 8000
float h_alpha = static_cast<float>(alpha);
float h_beta = static_cast<float>(beta);
cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT;
#if CUDA_VERSION >= 9000
if (context_.GetComputeCapability() >= 70) {
PADDLE_ENFORCE(platform::dynload::cublasSetMathMode(
context_.cublas_handle(), CUBLAS_TENSOR_OP_MATH));
algo = CUBLAS_GEMM_DFALT_TENSOR_OP;
} else {
PADDLE_ENFORCE(platform::dynload::cublasSetMathMode(
context_.cublas_handle(), CUBLAS_DEFAULT_MATH));
}
#endif // CUDA_VERSION >= 9000
// cublasHgemm does true FP16 computation which is slow for non-Volta
  // GPUs. So use cublasGemmEx instead, which does pseudo FP16 computation:
// input/output in fp16, computation in fp32, which can also be accelerated
// using tensor cores in volta GPUs.
PADDLE_ENFORCE(platform::dynload::cublasGemmEx(
context_.cublas_handle(), cuTransB, cuTransA, N, M, K, &h_alpha, B,
CUDA_R_16F, ldb, A, CUDA_R_16F, lda, &h_beta, C, CUDA_R_16F, N,
CUDA_R_32F, algo));
#else
  // CUDA 7.5 does not support cublasGemmEx, hence we fall back to cublasHgemm.
const half h_alpha = static_cast<const half>(alpha);
const half h_beta = static_cast<const half>(beta);
const half *h_A = reinterpret_cast<const half *>(A);
const half *h_B = reinterpret_cast<const half *>(B);
half *h_C = reinterpret_cast<half *>(C);
  CUBlas<platform::float16>::GEMM(context_.cublas_handle(), cuTransB, cuTransA,
                                  N, M, K, &h_alpha, h_B, ldb, h_A, lda,
                                  &h_beta, h_C, N);
#endif // CUDA_VERSION >= 8000
}
template <>
template <typename T>
void Blas<platform::CUDADeviceContext>::GEMM(
const bool transA, const bool transB, const int M, const int N, const int K,
const T alpha, const T *A, const int lda, const T *B, const int ldb,
const T beta, T *C, const int ldc) const {
// Note that cublas follows fortran order, so the order is different from
// the cblas convention.
cublasOperation_t cuTransA = transA == false ? CUBLAS_OP_N : CUBLAS_OP_T;
cublasOperation_t cuTransB = transB == false ? CUBLAS_OP_N : CUBLAS_OP_T;
CUBlas<T>::GEMM(context_.cublas_handle(), cuTransB, cuTransA, N, M, K, &alpha,
B, ldb, A, lda, &beta, C, ldc);
}
} // namespace math
} // namespace operators
} // namespace paddle
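A note on why every wrapper above passes `cuTransB, cuTransA, N, M, K` and hands cuBLAS `B` before `A`: cuBLAS assumes column-major (Fortran) storage, while these buffers are row-major. A row-major matrix reinterpreted in column-major order is the transpose of the matrix intended, so the wrappers rely on the identity

    C = A * B    <=>    C^T = B^T * A^T

and ask cuBLAS for C^T = B^T * A^T in its native layout. The column-major C^T it writes occupies memory bit-for-bit like the row-major C, with leading dimension N.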
New file: paddle/fluid/operators/math/blas_impl.h

// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/operators/math/math_function.h"
namespace paddle {
namespace operators {
namespace math {
template <typename T>
struct CBlas;
template <>
struct CBlas<float> {
static constexpr auto GEMM = cblas_sgemm;
};
template <>
struct CBlas<double> {
static constexpr auto GEMM = cblas_dgemm;
};
template <>
struct CBlas<platform::float16> {
  static void GEMM(...) { PADDLE_THROW("float16 GEMM not supported on CPU"); }
};
template <>
template <typename T>
void Blas<platform::CPUDeviceContext>::GEMM(const CBLAS_TRANSPOSE transA,
const CBLAS_TRANSPOSE transB,
const int M, const int N,
const int K, const T alpha,
const T *A, const T *B,
const T beta, T *C) const {
int lda = (transA == CblasNoTrans) ? K : M;
int ldb = (transB == CblasNoTrans) ? N : K;
int ldc = N;
CBlas<T>::GEMM(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, ldb,
beta, C, ldc);
}
template <>
template <typename T>
void Blas<platform::CPUDeviceContext>::GEMM(
const bool transA, const bool transB, const int M, const int N, const int K,
const T alpha, const T *A, const int lda, const T *B, const int ldb,
const T beta, T *C, const int ldc) const {
CBlas<T>::GEMM(CblasRowMajor, transA == false ? CblasNoTrans : CblasTrans,
transB == false ? CblasNoTrans : CblasTrans, M, N, K, alpha, A,
lda, B, ldb, beta, C, ldc);
}
} // namespace math
} // namespace operators
} // namespace paddle
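The `CBlas<T>` trait above is a compile-time dispatch table: for `float` and `double`, `static constexpr auto GEMM = ...` binds the member directly to the matching CBLAS entry point, so `CBlas<T>::GEMM(...)` resolves with no runtime branching, while the `float16` specialization throws. A reduced, self-contained sketch of the same idiom with stand-in functions instead of the real CBLAS symbols (compile as C++17):

#include <cstdio>

// Stand-ins for cblas_sgemm / cblas_dgemm (hypothetical, for illustration).
void sgemm_impl(float alpha) { std::printf("float path, alpha=%.1f\n", alpha); }
void dgemm_impl(double alpha) { std::printf("double path, alpha=%.1f\n", alpha); }

template <typename T>
struct CBlasLike;

template <>
struct CBlasLike<float> {
  // The member is a constexpr function pointer bound at compile time.
  static constexpr auto GEMM = sgemm_impl;
};

template <>
struct CBlasLike<double> {
  static constexpr auto GEMM = dgemm_impl;
};

template <typename T>
void Gemm(T alpha) {
  CBlasLike<T>::GEMM(alpha);  // resolves to the right symbol per type
}

int main() {
  Gemm(1.0f);  // float path
  Gemm(2.0);   // double path
  return 0;
}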
@@ -25,21 +25,21 @@ struct GRUUnitFunctor<platform::CPUDeviceContext, T> {
                       const detail::ActivationType active_node,
                       const detail::ActivationType active_gate) {
 #ifndef __NVCC__
+    auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);
     if (value.prev_out_value) {
-      math::gemm<platform::CPUDeviceContext, T>(
-          context, false, false, batch_size, frame_size * 2, frame_size, 1,
-          value.prev_out_value, frame_size, value.gate_weight, frame_size * 2,
-          1, value.gate_value, frame_size * 3);
+      blas.GEMM(false, false, batch_size, frame_size * 2, frame_size, 1,
+                value.prev_out_value, frame_size, value.gate_weight,
+                frame_size * 2, 1, value.gate_value, frame_size * 3);
     }
 
     detail::forward_reset_output(detail::forward::gru_resetOutput<T>(), value,
                                  frame_size, batch_size, active_gate);
 
     if (value.prev_out_value) {
-      math::gemm<platform::CPUDeviceContext, T>(
-          context, false, false, batch_size, frame_size, frame_size, 1,
-          value.reset_output_value, frame_size, value.state_weight, frame_size,
-          1, value.gate_value + frame_size * 2, frame_size * 3);
+      blas.GEMM(false, false, batch_size, frame_size, frame_size, 1,
+                value.reset_output_value, frame_size, value.state_weight,
+                frame_size, 1, value.gate_value + frame_size * 2,
+                frame_size * 3);
     }
 
     detail::forward_final_output(detail::forward::gru_finalOutput<T>(), value,
@@ -58,16 +58,15 @@ struct GRUUnitGradFunctor<platform::CPUDeviceContext, T> {
 #ifndef __NVCC__
     detail::backward_state_grad(detail::backward::gru_stateGrad<T>(), value,
                                 grad, frame_size, batch_size, active_node);
+    auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);
     if (value.prev_out_value && grad.prev_out_grad) {
-      math::gemm<platform::CPUDeviceContext, T>(
-          context, false, true, batch_size, frame_size, frame_size, 1,
-          grad.gate_grad + frame_size * 2, frame_size * 3, value.state_weight,
-          frame_size, 0, grad.reset_output_grad, frame_size);
+      blas.GEMM(false, true, batch_size, frame_size, frame_size, 1,
+                grad.gate_grad + frame_size * 2, frame_size * 3,
+                value.state_weight, frame_size, 0, grad.reset_output_grad,
+                frame_size);
       if (grad.state_weight_grad) {
-        math::gemm<platform::CPUDeviceContext, T>(
-            context, true, false, frame_size, frame_size, batch_size, 1,
+        blas.GEMM(true, false, frame_size, frame_size, batch_size, 1,
                   value.reset_output_value, frame_size,
                   grad.gate_grad + frame_size * 2, frame_size * 3, 1,
                   grad.state_weight_grad, frame_size);
@@ -76,18 +75,15 @@ struct GRUUnitGradFunctor<platform::CPUDeviceContext, T> {
     detail::backward_reset_grad(detail::backward::gru_resetGrad<T>(), value,
                                 grad, frame_size, batch_size, active_gate);
     if (grad.prev_out_grad && value.prev_out_value) {
-      math::gemm<platform::CPUDeviceContext, T>(
-          context, false, true, batch_size, frame_size, frame_size * 2, 1,
-          grad.gate_grad, frame_size * 3, value.gate_weight, frame_size * 2, 1,
-          grad.prev_out_grad, frame_size);
+      blas.GEMM(false, true, batch_size, frame_size, frame_size * 2, 1,
+                grad.gate_grad, frame_size * 3, value.gate_weight,
+                frame_size * 2, 1, grad.prev_out_grad, frame_size);
       if (grad.gate_weight_grad) {
-        math::gemm<platform::CPUDeviceContext, T>(
-            context, true, false, frame_size, frame_size * 2, batch_size, 1,
-            value.prev_out_value, frame_size, grad.gate_grad, frame_size * 3, 1,
-            grad.gate_weight_grad, frame_size * 2);
+        blas.GEMM(true, false, frame_size, frame_size * 2, batch_size, 1,
+                  value.prev_out_value, frame_size, grad.gate_grad,
+                  frame_size * 3, 1, grad.gate_weight_grad, frame_size * 2);
       }
     }
 #endif
...
@@ -9,6 +9,7 @@
 See the License for the specific language governing permissions and
 limitations under the License. */
 
+#include <paddle/fluid/platform/device_context.h>
 #include "paddle/fluid/operators/math/detail/gru_gpu_kernel.h"
 #include "paddle/fluid/operators/math/detail/gru_kernel.h"
 #include "paddle/fluid/operators/math/gru_compute.h"
@@ -36,12 +37,11 @@ struct GRUUnitFunctor<platform::CUDADeviceContext, T> {
       threads = dim3(32, 32);
       grid = dim3((frame_size + 32 - 1) / 32, (batch_size + 32 - 1) / 32);
     }
+    auto blas = math::GetBlas<platform::CUDADeviceContext, T>(context);
     if (value.prev_out_value) {
-      math::gemm<platform::CUDADeviceContext, T>(
-          context, false, false, batch_size, frame_size * 2, frame_size, 1,
-          value.prev_out_value, frame_size, value.gate_weight, frame_size * 2,
-          1, value.gate_value, frame_size * 3);
+      blas.GEMM(false, false, batch_size, frame_size * 2, frame_size, 1,
+                value.prev_out_value, frame_size, value.gate_weight,
+                frame_size * 2, 1, value.gate_value, frame_size * 3);
     }
 
     if (batch_size == 1) {
@@ -61,10 +61,10 @@ struct GRUUnitFunctor<platform::CUDADeviceContext, T> {
     }
 
     if (value.prev_out_value) {
-      math::gemm<platform::CUDADeviceContext, T>(
-          context, false, false, batch_size, frame_size, frame_size, 1,
-          value.reset_output_value, frame_size, value.state_weight, frame_size,
-          1, value.gate_value + frame_size * 2, frame_size * 3);
+      blas.GEMM(false, false, batch_size, frame_size, frame_size, 1,
+                value.reset_output_value, frame_size, value.state_weight,
+                frame_size, 1, value.gate_value + frame_size * 2,
+                frame_size * 3);
     }
 
     if (batch_size == 1) {
@@ -121,15 +121,16 @@ struct GRUUnitGradFunctor<platform::CUDADeviceContext, T> {
           grad.output_grad, frame_size, batch_size, active_node);
     }
+    auto blas = math::GetBlas<platform::CUDADeviceContext, T>(context);
     if (value.prev_out_value && grad.prev_out_grad) {
-      math::gemm<platform::CUDADeviceContext, T>(
-          context, false, true, batch_size, frame_size, frame_size, 1,
-          grad.gate_grad + frame_size * 2, frame_size * 3, value.state_weight,
-          frame_size, 0, grad.reset_output_grad, frame_size);
+      blas.GEMM(false, true, batch_size, frame_size, frame_size, 1,
+                grad.gate_grad + frame_size * 2, frame_size * 3,
+                value.state_weight, frame_size, 0, grad.reset_output_grad,
+                frame_size);
       if (grad.state_weight_grad) {
-        math::gemm<platform::CUDADeviceContext, T>(
-            context, true, false, frame_size, frame_size, batch_size, 1,
+        blas.GEMM(true, false, frame_size, frame_size, batch_size, 1,
                   value.reset_output_value, frame_size,
                   grad.gate_grad + frame_size * 2, frame_size * 3, 1,
                   grad.state_weight_grad, frame_size);
@@ -153,16 +154,14 @@ struct GRUUnitGradFunctor<platform::CUDADeviceContext, T> {
     }
     if (grad.prev_out_grad && value.prev_out_value) {
-      math::gemm<platform::CUDADeviceContext, T>(
-          context, false, true, batch_size, frame_size, frame_size * 2, 1,
-          grad.gate_grad, frame_size * 3, value.gate_weight, frame_size * 2, 1,
-          grad.prev_out_grad, frame_size);
+      blas.GEMM(false, true, batch_size, frame_size, frame_size * 2, 1,
+                grad.gate_grad, frame_size * 3, value.gate_weight,
+                frame_size * 2, 1, grad.prev_out_grad, frame_size);
       if (grad.gate_weight_grad) {
-        math::gemm<platform::CUDADeviceContext, T>(
-            context, true, false, frame_size, frame_size * 2, batch_size, 1,
-            value.prev_out_value, frame_size, grad.gate_grad, frame_size * 3, 1,
-            grad.gate_weight_grad, frame_size * 2);
+        blas.GEMM(true, false, frame_size, frame_size * 2, batch_size, 1,
+                  value.prev_out_value, frame_size, grad.gate_grad,
+                  frame_size * 3, 1, grad.gate_weight_grad, frame_size * 2);
       }
     }
   }
...
@@ -24,72 +24,6 @@ namespace math {
 using float16 = paddle::platform::float16;
 
-template <>
-void gemm<platform::CPUDeviceContext, float16>(
-    const platform::CPUDeviceContext& context, const CBLAS_TRANSPOSE transA,
-    const CBLAS_TRANSPOSE transB, const int M, const int N, const int K,
-    const float16 alpha, const float16* A, const float16* B, const float16 beta,
-    float16* C) {
-  PADDLE_THROW("float16 GEMM not supported on CPU");
-}
-
-template <>
-void gemm<platform::CPUDeviceContext, float>(
-    const platform::CPUDeviceContext& context, const CBLAS_TRANSPOSE transA,
-    const CBLAS_TRANSPOSE transB, const int M, const int N, const int K,
-    const float alpha, const float* A, const float* B, const float beta,
-    float* C) {
-  int lda = (transA == CblasNoTrans) ? K : M;
-  int ldb = (transB == CblasNoTrans) ? N : K;
-  int ldc = N;
-  cblas_sgemm(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, ldb,
-              beta, C, ldc);
-}
-
-template <>
-void gemm<platform::CPUDeviceContext, double>(
-    const platform::CPUDeviceContext& context, const CBLAS_TRANSPOSE transA,
-    const CBLAS_TRANSPOSE transB, const int M, const int N, const int K,
-    const double alpha, const double* A, const double* B, const double beta,
-    double* C) {
-  int lda = (transA == CblasNoTrans) ? K : M;
-  int ldb = (transB == CblasNoTrans) ? N : K;
-  int ldc = N;
-  cblas_dgemm(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, ldb,
-              beta, C, ldc);
-}
-
-template <>
-void gemm<platform::CPUDeviceContext, float16>(
-    const platform::CPUDeviceContext& context, const bool transA,
-    const bool transB, const int M, const int N, const int K,
-    const float16 alpha, const float16* A, const int lda, const float16* B,
-    const int ldb, const float16 beta, float16* C, const int ldc) {
-  PADDLE_THROW("float16 GEMM not supported on CPU");
-}
-
-template <>
-void gemm<platform::CPUDeviceContext, float>(
-    const platform::CPUDeviceContext& context, const bool transA,
-    const bool transB, const int M, const int N, const int K, const float alpha,
-    const float* A, const int lda, const float* B, const int ldb,
-    const float beta, float* C, const int ldc) {
-  cblas_sgemm(CblasRowMajor, transA == false ? CblasNoTrans : CblasTrans,
-              transB == false ? CblasNoTrans : CblasTrans, M, N, K, alpha, A,
-              lda, B, ldb, beta, C, ldc);
-}
-
-template <>
-void gemm<platform::CPUDeviceContext, double>(
-    const platform::CPUDeviceContext& context, const bool transA,
-    const bool transB, const int M, const int N, const int K,
-    const double alpha, const double* A, const int lda, const double* B,
-    const int ldb, const double beta, double* C, const int ldc) {
-  cblas_dgemm(CblasRowMajor, transA == false ? CblasNoTrans : CblasTrans,
-              transB == false ? CblasNoTrans : CblasTrans, M, N, K, alpha, A,
-              lda, B, ldb, beta, C, ldc);
-}
-
 template <>
 void matmul<platform::CPUDeviceContext, float16>(
     const platform::CPUDeviceContext& context,
@@ -123,8 +57,8 @@ void matmul<platform::CPUDeviceContext, float>(
   CBLAS_TRANSPOSE transA = (trans_a == false) ? CblasNoTrans : CblasTrans;
   CBLAS_TRANSPOSE transB = (trans_b == false) ? CblasNoTrans : CblasTrans;
 
-  gemm<platform::CPUDeviceContext, float>(
-      context, transA, transB, M, N, K, alpha, matrix_a.data<float>(),
-      matrix_b.data<float>(), beta, matrix_out->data<float>());
+  Blas<platform::CPUDeviceContext>(context).GEMM(
+      transA, transB, M, N, K, alpha, matrix_a.data<float>(),
+      matrix_b.data<float>(), beta, matrix_out->data<float>());
 }
 
@@ -152,8 +86,8 @@ void matmul<platform::CPUDeviceContext, double>(
   CBLAS_TRANSPOSE transA = (trans_a == false) ? CblasNoTrans : CblasTrans;
   CBLAS_TRANSPOSE transB = (trans_b == false) ? CblasNoTrans : CblasTrans;
 
-  gemm<platform::CPUDeviceContext, double>(
-      context, transA, transB, M, N, K, alpha, matrix_a.data<double>(),
-      matrix_b.data<double>(), beta, matrix_out->data<double>());
+  Blas<platform::CPUDeviceContext>(context).GEMM(
+      transA, transB, M, N, K, alpha, matrix_a.data<double>(),
+      matrix_b.data<double>(), beta, matrix_out->data<double>());
 }
 
@@ -230,7 +164,7 @@ void batched_gemm<platform::CPUDeviceContext, float>(
     const float* Ak = &A[k * strideA];
     const float* Bk = &B[k * strideB];
     float* Ck = &C[k * M * N];
-    gemm<platform::CPUDeviceContext, float>(context, transA, transB, M, N, K,
-                                            alpha, Ak, Bk, beta, Ck);
+    Blas<platform::CPUDeviceContext>(context).GEMM(transA, transB, M, N, K,
+                                                   alpha, Ak, Bk, beta, Ck);
   }
 }
@@ -246,7 +180,7 @@ void batched_gemm<platform::CPUDeviceContext, double>(
     const double* Ak = &A[k * strideA];
     const double* Bk = &B[k * strideB];
     double* Ck = &C[k * M * N];
-    gemm<platform::CPUDeviceContext, double>(context, transA, transB, M, N, K,
-                                             alpha, Ak, Bk, beta, Ck);
+    Blas<platform::CPUDeviceContext>(context).GEMM(transA, transB, M, N, K,
+                                                   alpha, Ak, Bk, beta, Ck);
   }
 }
...
@@ -25,157 +25,6 @@ namespace math {
 using float16 = paddle::platform::float16;
 
-template <>
-void gemm<platform::CUDADeviceContext, float16>(
-    const platform::CUDADeviceContext& context, const CBLAS_TRANSPOSE transA,
-    const CBLAS_TRANSPOSE transB, const int M, const int N, const int K,
-    const float16 alpha, const float16* A, const float16* B, const float16 beta,
-    float16* C) {
-  // Note that cublas follows fortran order, so the order is different from
-  // the cblas convention.
-  int lda = (transA == CblasNoTrans) ? K : M;
-  int ldb = (transB == CblasNoTrans) ? N : K;
-  cublasOperation_t cuTransA =
-      (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
-  cublasOperation_t cuTransB =
-      (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
-  // TODO(kexinzhao): add processing code for compute capability < 53 case
-  PADDLE_ENFORCE_GE(context.GetComputeCapability(), 53,
-                    "cublas fp16 gemm requires GPU compute capability >= 53");
-#if CUDA_VERSION >= 8000
-  float h_alpha = static_cast<float>(alpha);
-  float h_beta = static_cast<float>(beta);
-  cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT;
-#if CUDA_VERSION >= 9000
-  if (context.GetComputeCapability() >= 70) {
-    PADDLE_ENFORCE(platform::dynload::cublasSetMathMode(context.cublas_handle(),
-                                                        CUBLAS_TENSOR_OP_MATH));
-    algo = CUBLAS_GEMM_DFALT_TENSOR_OP;
-  } else {
-    PADDLE_ENFORCE(platform::dynload::cublasSetMathMode(context.cublas_handle(),
-                                                        CUBLAS_DEFAULT_MATH));
-  }
-#endif  // CUDA_VERSION >= 9000
-  // cublasHgemm does true FP16 computation which is slow for non-Volta
-  // GPUs. So use cublasGemmEx instead which does pesudo FP16 computation:
-  // input/output in fp16, computation in fp32, which can also be accelerated
-  // using tensor cores in volta GPUs.
-  PADDLE_ENFORCE(platform::dynload::cublasGemmEx(
-      context.cublas_handle(), cuTransB, cuTransA, N, M, K, &h_alpha, B,
-      CUDA_R_16F, ldb, A, CUDA_R_16F, lda, &h_beta, C, CUDA_R_16F, N,
-      CUDA_R_32F, algo));
-#else
-  // CUDA 7.5 does not support cublasGemmEx, hence we fall back to use hgemm
-  const half h_alpha = static_cast<const half>(alpha);
-  const half h_beta = static_cast<const half>(beta);
-  const half* h_A = reinterpret_cast<const half*>(A);
-  const half* h_B = reinterpret_cast<const half*>(B);
-  half* h_C = reinterpret_cast<half*>(C);
-  PADDLE_ENFORCE(platform::dynload::cublasHgemm(
-      context.cublas_handle(), cuTransB, cuTransA, N, M, K, &h_alpha, h_B, ldb,
-      h_A, lda, &h_beta, h_C, N));
-#endif  // CUDA_VERSION >= 8000
-}
-
-template <>
-void gemm<platform::CUDADeviceContext, float>(
-    const platform::CUDADeviceContext& context, const CBLAS_TRANSPOSE transA,
-    const CBLAS_TRANSPOSE transB, const int M, const int N, const int K,
-    const float alpha, const float* A, const float* B, const float beta,
-    float* C) {
-  // Note that cublas follows fortran order, so the order is different from
-  // the cblas convention.
-  int lda = (transA == CblasNoTrans) ? K : M;
-  int ldb = (transB == CblasNoTrans) ? N : K;
-  cublasOperation_t cuTransA =
-      (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
-  cublasOperation_t cuTransB =
-      (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
-  PADDLE_ENFORCE(platform::dynload::cublasSgemm(
-      context.cublas_handle(), cuTransB, cuTransA, N, M, K, &alpha, B, ldb, A,
-      lda, &beta, C, N));
-}
-
-template <>
-void gemm<platform::CUDADeviceContext, double>(
-    const platform::CUDADeviceContext& context, const CBLAS_TRANSPOSE transA,
-    const CBLAS_TRANSPOSE transB, const int M, const int N, const int K,
-    const double alpha, const double* A, const double* B, const double beta,
-    double* C) {
-  // Note that cublas follows fortran order, so the order is different from
-  // the cblas convention.
-  int lda = (transA == CblasNoTrans) ? K : M;
-  int ldb = (transB == CblasNoTrans) ? N : K;
-  cublasOperation_t cuTransA =
-      (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
-  cublasOperation_t cuTransB =
-      (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
-  PADDLE_ENFORCE(platform::dynload::cublasDgemm(
-      context.cublas_handle(), cuTransB, cuTransA, N, M, K, &alpha, B, ldb, A,
-      lda, &beta, C, N));
-}
-
-template <>
-void gemm<platform::CUDADeviceContext, float16>(
-    const platform::CUDADeviceContext& context, const bool transA,
-    const bool transB, const int M, const int N, const int K,
-    const float16 alpha, const float16* A, const int lda, const float16* B,
-    const int ldb, const float16 beta, float16* C, const int ldc) {
-  // Note that cublas follows fortran order, so the order is different from
-  // the cblas convention.
-  cublasOperation_t cuTransA = transA == false ? CUBLAS_OP_N : CUBLAS_OP_T;
-  cublasOperation_t cuTransB = transB == false ? CUBLAS_OP_N : CUBLAS_OP_T;
-  const half h_alpha = static_cast<const half>(alpha);
-  const half h_beta = static_cast<const half>(beta);
-  const half* h_A = reinterpret_cast<const half*>(A);
-  const half* h_B = reinterpret_cast<const half*>(B);
-  half* h_C = reinterpret_cast<half*>(C);
-  // TODO(kexinzhao): add processing code for compute capability < 53 case
-  PADDLE_ENFORCE_GE(context.GetComputeCapability(), 53,
-                    "cublas Hgemm requires GPU compute capability >= 53");
-  PADDLE_ENFORCE(platform::dynload::cublasHgemm(
-      context.cublas_handle(), cuTransB, cuTransA, N, M, K, &h_alpha, h_B, ldb,
-      h_A, lda, &h_beta, h_C, ldc));
-}
-
-template <>
-void gemm<platform::CUDADeviceContext, float>(
-    const platform::CUDADeviceContext& context, const bool transA,
-    const bool transB, const int M, const int N, const int K, const float alpha,
-    const float* A, const int lda, const float* B, const int ldb,
-    const float beta, float* C, const int ldc) {
-  // Note that cublas follows fortran order, so the order is different from
-  // the cblas convention.
-  cublasOperation_t cuTransA = transA == false ? CUBLAS_OP_N : CUBLAS_OP_T;
-  cublasOperation_t cuTransB = transB == false ? CUBLAS_OP_N : CUBLAS_OP_T;
-  PADDLE_ENFORCE(platform::dynload::cublasSgemm(
-      context.cublas_handle(), cuTransB, cuTransA, N, M, K, &alpha, B, ldb, A,
-      lda, &beta, C, ldc));
-}
-
-template <>
-void gemm<platform::CUDADeviceContext, double>(
-    const platform::CUDADeviceContext& context, const bool transA,
-    const bool transB, const int M, const int N, const int K,
-    const double alpha, const double* A, const int lda, const double* B,
-    const int ldb, const double beta, double* C, const int ldc) {
-  // Note that cublas follows fortran order, so the order is different from
-  // the cblas convention.
-  cublasOperation_t cuTransA = transA == false ? CUBLAS_OP_N : CUBLAS_OP_T;
-  cublasOperation_t cuTransB = transB == false ? CUBLAS_OP_N : CUBLAS_OP_T;
-  PADDLE_ENFORCE(platform::dynload::cublasDgemm(
-      context.cublas_handle(), cuTransB, cuTransA, N, M, K, &alpha, B, ldb, A,
-      lda, &beta, C, ldc));
-}
-
 template <>
 void matmul<platform::CUDADeviceContext, float16>(
     const platform::CUDADeviceContext& context,
@@ -200,8 +49,8 @@ void matmul<platform::CUDADeviceContext, float16>(
   CBLAS_TRANSPOSE transA = (trans_a == false) ? CblasNoTrans : CblasTrans;
   CBLAS_TRANSPOSE transB = (trans_b == false) ? CblasNoTrans : CblasTrans;
 
-  gemm<platform::CUDADeviceContext, float16>(
-      context, transA, transB, M, N, K, alpha, matrix_a.data<float16>(),
-      matrix_b.data<float16>(), beta, matrix_out->data<float16>());
+  Blas<platform::CUDADeviceContext>(context).GEMM(
+      transA, transB, M, N, K, alpha, matrix_a.data<float16>(),
+      matrix_b.data<float16>(), beta, matrix_out->data<float16>());
 }
 
@@ -229,8 +78,8 @@ void matmul<platform::CUDADeviceContext, float>(
   CBLAS_TRANSPOSE transA = (trans_a == false) ? CblasNoTrans : CblasTrans;
   CBLAS_TRANSPOSE transB = (trans_b == false) ? CblasNoTrans : CblasTrans;
 
-  gemm<platform::CUDADeviceContext, float>(
-      context, transA, transB, M, N, K, alpha, matrix_a.data<float>(),
-      matrix_b.data<float>(), beta, matrix_out->data<float>());
+  Blas<platform::CUDADeviceContext>(context).GEMM(
+      transA, transB, M, N, K, alpha, matrix_a.data<float>(),
+      matrix_b.data<float>(), beta, matrix_out->data<float>());
 }
 
@@ -258,8 +107,8 @@ void matmul<platform::CUDADeviceContext, double>(
   CBLAS_TRANSPOSE transA = (trans_a == false) ? CblasNoTrans : CblasTrans;
   CBLAS_TRANSPOSE transB = (trans_b == false) ? CblasNoTrans : CblasTrans;
 
-  gemm<platform::CUDADeviceContext, double>(
-      context, transA, transB, M, N, K, alpha, matrix_a.data<double>(),
-      matrix_b.data<double>(), beta, matrix_out->data<double>());
+  Blas<platform::CUDADeviceContext>(context).GEMM(
+      transA, transB, M, N, K, alpha, matrix_a.data<double>(),
+      matrix_b.data<double>(), beta, matrix_out->data<double>());
 }
...
@@ -42,6 +42,7 @@ int LAPACKE_dgetri(int matrix_layout, int n, double* a, int lda,
 #include <vector>
 
 #include "paddle/fluid/framework/eigen.h"
+#include "paddle/fluid/framework/operator.h"
 #include "paddle/fluid/framework/tensor.h"
 #include "paddle/fluid/framework/tensor_util.h"
 #include "paddle/fluid/platform/device_context.h"
@@ -56,17 +57,48 @@ namespace math {
 // Then matrixA: M * K, matrixB: K * N, matrixC : M * N
 // For more detailed info, please refer to
 // http://www.netlib.org/lapack/explore-html/d4/de2/sgemm_8f.html
+template <typename DeviceContext>
+class Blas {
+ public:
+  explicit Blas(const DeviceContext& context) : context_(context) {}
+
+  template <typename T>
+  void GEMM(const CBLAS_TRANSPOSE transA, const CBLAS_TRANSPOSE transB,
+            const int M, const int N, const int K, const T alpha, const T* A,
+            const T* B, const T beta, T* C) const;
+
+  template <typename T>
+  void GEMM(const bool transA, const bool transB, const int M, const int N,
+            const int K, const T alpha, const T* A, const int lda, const T* B,
+            const int ldb, const T beta, T* C, const int ldc) const;
+
+ private:
+  const DeviceContext& context_;
+};
+
 template <typename DeviceContext, typename T>
-void gemm(const DeviceContext& context, const CBLAS_TRANSPOSE transA,
-          const CBLAS_TRANSPOSE transB, const int M, const int N, const int K,
-          const T alpha, const T* A, const T* B, const T beta, T* C);
+class BlasT : private Blas<DeviceContext> {
+ public:
+  using Blas<DeviceContext>::Blas;
+
+  template <typename... ARGS>
+  void GEMM(ARGS... args) const {
+    static_cast<const Blas<DeviceContext>*>(this)->template GEMM<T>(args...);
+  }
+};
 
-// gemm wrapper with stride args for matrix uncontinuous in memory
 template <typename DeviceContext, typename T>
-void gemm(const DeviceContext& context, const bool transA, const bool transB,
-          const int M, const int N, const int K, const T alpha, const T* A,
-          const int lda, const T* B, const int ldb, const T beta, T* C,
-          const int ldc);
+inline BlasT<DeviceContext, T> GetBlas(
+    const framework::ExecutionContext& exe_ctx) {
+  return BlasT<DeviceContext, T>(
+      exe_ctx.template device_context<DeviceContext>());
+}
+
+template <typename DeviceContext, typename T>
+inline BlasT<DeviceContext, T> GetBlas(const DeviceContext& dev_ctx) {
+  return BlasT<DeviceContext, T>(dev_ctx);
+}
 
 // matrix multiply with continuous memory
 template <typename DeviceContext, typename T>
@@ -137,3 +169,8 @@ struct RowwiseMean {
 }  // namespace math
 }  // namespace operators
 }  // namespace paddle
+
+#include "paddle/fluid/operators/math/blas_impl.h"
+#ifdef PADDLE_WITH_CUDA
+#include "paddle/fluid/operators/math/blas_impl.cu.h"
+#endif
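The design intent of the pair of classes above: `Blas<DeviceContext>` owns the per-device `GEMM` implementations, while `BlasT<DeviceContext, T>` fixes the element type once so call sites can write `blas.GEMM(...)` instead of `blas.template GEMM<T>(...)`; its variadic `GEMM` simply forwards back to `Blas<DeviceContext>::GEMM<T>`. The two `GetBlas` overloads serve the two kinds of callers seen in this commit: op kernels holding a `framework::ExecutionContext`, and math functors that already hold a device context. A reduced sketch of the forwarding idiom with stand-in types (not Paddle code):

#include <iostream>
#include <typeinfo>

struct Context {};  // stand-in for a device context

class Blas {
 public:
  explicit Blas(const Context& ctx) : ctx_(ctx) {}

  template <typename T>
  void GEMM(T alpha) const {  // imagine the full GEMM parameter list here
    std::cout << "GEMM<" << typeid(T).name() << ">, alpha = " << alpha << "\n";
  }

 private:
  const Context& ctx_;
};

// Fixes T once and forwards any argument list to Blas::GEMM<T>.
template <typename T>
class BlasT : private Blas {
 public:
  using Blas::Blas;

  template <typename... ARGS>
  void GEMM(ARGS... args) const {
    static_cast<const Blas*>(this)->template GEMM<T>(args...);
  }
};

int main() {
  Context ctx;
  BlasT<float> blas(ctx);
  blas.GEMM(1.0f);  // no need to spell GEMM<float> at the call site
  return 0;
}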
@@ -131,8 +131,9 @@ class MatMulFunctor {
     if (!batchCount) {
       // regular matrix multiplication
-      gemm<DeviceContext, T>(context, transA, transB, M, N, kA, alpha,
-                             a.data<T>(), b.data<T>(), beta, out->data<T>());
+      Blas<DeviceContext>(context).GEMM(transA, transB, M, N, kA, alpha,
+                                        a.data<T>(), b.data<T>(), beta,
+                                        out->data<T>());
     } else {
       // batched matrix multiplication
       batched_gemm<DeviceContext, T>(
...