提交 815d8884 编写于 作者: Y Yu Yang

Clean MatMul

上级 9d7279b9
......@@ -161,6 +161,7 @@ class GemmConvKernel : public framework::OpKernel<T> {
math::Im2ColFunctor<math::ColFormat::kCFO, DeviceContext, T> im2col;
auto& dev_ctx = context.template device_context<DeviceContext>();
auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
for (int i = 0; i < batch_size; i++) {
Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape);
Tensor out_batch = output->Slice(i, i + 1).Resize(output_matrix_shape);
......@@ -186,8 +187,7 @@ class GemmConvKernel : public framework::OpKernel<T> {
// gemm
Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step);
Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step);
math::matmul<DeviceContext, T>(dev_ctx, filter_slice, false, col_matrix,
false, T(1.0), &out_slice, T(0.0));
blas.MatMul(filter_slice, col_matrix, &out_slice);
}
}
}
......@@ -274,6 +274,7 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
math::SetConstant<DeviceContext, T> set_zero;
auto& dev_ctx = context.template device_context<DeviceContext>();
auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
if (input_grad) {
input_grad->mutable_data<T>(context.GetPlace());
......@@ -303,9 +304,7 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
col_matrix.ShareDataWith(in_grad_slice);
col_matrix.Resize(col_matrix_shape);
}
math::matmul<DeviceContext, T>(dev_ctx, filter_slice, true,
out_grad_slice, false, T(1.0),
&col_matrix, T(0.0));
blas.MatMul(filter_slice, true, out_grad_slice, false, &col_matrix);
if (is_expand && data_dim == 2U) {
col2im(dev_ctx, col, dilations, strides,
......@@ -352,9 +351,8 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
// gemm
Tensor filter_grad_slice =
filter_grad_.Slice(g * out_step, (g + 1) * out_step);
math::matmul<DeviceContext, T>(dev_ctx, out_grad_slice, false,
col_matrix, true, T(1.0),
&filter_grad_slice, T(1.0));
blas.MatMul(out_grad_slice, false, col_matrix, true,
&filter_grad_slice);
}
}
}
......
......@@ -118,6 +118,7 @@ class GemmConvTransposeKernel : public framework::OpKernel<T> {
output->mutable_data<T>(context.GetPlace());
math::SetConstant<DeviceContext, T> set_zero;
auto& dev_ctx = context.template device_context<DeviceContext>();
auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
set_zero(dev_ctx, output, static_cast<T>(0));
math::Col2ImFunctor<math::ColFormat::kCFO, DeviceContext, T> col2im;
......@@ -134,9 +135,7 @@ class GemmConvTransposeKernel : public framework::OpKernel<T> {
// col_matrix = filter * input_batch
// of shape (c * k_h * k_w, h * w) or (c * k_d * k_h * k_w, d * h * w)
math::matmul<DeviceContext, T>(dev_ctx, filter, true, input_batch, false,
static_cast<T>(1.0), &col_matrix,
static_cast<T>(0.0));
blas.MatMul(filter, true, input_batch, false, &col_matrix);
if (data_dim == 2U) {
// col2im: col_matrix -> dy
......@@ -213,6 +212,7 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> {
// im2col + gemm (similar to conv-forward)
// input need to compute gradient
auto& dev_ctx = context.template device_context<DeviceContext>();
auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
if (input_grad || filter_grad) {
Tensor col;
col.mutable_data<T>(col_shape, context.GetPlace());
......@@ -267,9 +267,7 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> {
// or
// (m, c * k_d * k_h * k_w) * (c * k_d * k_h * k_w, d * h * w) -> (m,
// d, h, w)
math::matmul<DeviceContext, T>(
dev_ctx, filter, false, col_matrix, false, static_cast<T>(1.0),
&input_grad_batch, static_cast<T>(0.0));
blas.MatMul(filter, false, col_matrix, false, &input_grad_batch);
}
if (filter_grad) {
// input batch
......@@ -279,9 +277,7 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> {
// or
// (m, d * h * w) * (d * h * w, c * k_d * k_h * k_w) -> (m, c * k_d *
// k_h * k_w)
math::matmul<DeviceContext, T>(dev_ctx, in_batch, false, col_matrix,
true, static_cast<T>(1.0),
&filter_grad_, static_cast<T>(1.0));
blas.MatMul(in_batch, false, col_matrix, true, &filter_grad_);
}
}
}
......
......@@ -114,6 +114,7 @@ class LSTMKernel : public framework::OpKernel<T> {
auto cand_act = math::detail::GetActivationType(
ctx.Attr<std::string>("candidate_activation"));
auto blas = math::GetBlas<DeviceContext, T>(device_ctx);
for (size_t n = 0; n < num_batch; n++) {
int bstart = static_cast<int>(batch_starts[n]);
int bend = static_cast<int>(batch_starts[n + 1]);
......@@ -129,9 +130,8 @@ class LSTMKernel : public framework::OpKernel<T> {
int pre_h_start = static_cast<int>(batch_starts[n - 1]);
int pre_h_end = pre_h_start + cur_batch_size;
auto pre_hidden_t = batch_hidden.Slice(pre_h_start, pre_h_end);
math::matmul<DeviceContext, T>(device_ctx, pre_hidden_t, false, *weight,
false, static_cast<T>(1.0), &gate_t,
static_cast<T>(1.0));
blas.MatMul(pre_hidden_t, false, *weight, false, static_cast<T>(1.0),
&gate_t, static_cast<T>(1.0));
} else if (hidden_t0) {
// If n == 0 and there is no initialized hidden state, that is to say
// the H0 is zeros, the calculation W_h * H0 will be skiped.
......@@ -143,9 +143,8 @@ class LSTMKernel : public framework::OpKernel<T> {
Tensor ordered_h0;
ReorderInitState<DeviceContext, T>(device_ctx, *hidden_t0, order,
&ordered_h0, true);
math::matmul<DeviceContext, T>(device_ctx, ordered_h0, false, *weight,
false, static_cast<T>(1.0), &gate_t,
static_cast<T>(1.0));
blas.MatMul(ordered_h0, false, *weight, false, static_cast<T>(1.0),
&gate_t, static_cast<T>(1.0));
}
lstm_value.gate_value = gate_t.data<T>();
......@@ -282,6 +281,7 @@ class LSTMGradKernel : public framework::OpKernel<T> {
auto batch_starts = batch_gate->lod()[0];
size_t num_batch = batch_starts.size() - 1;
auto blas = math::GetBlas<DeviceContext, T>(device_ctx);
for (int n = static_cast<int>(num_batch) - 1; n >= 0; n--) {
int bstart = static_cast<int>(batch_starts[n]);
int bend = static_cast<int>(batch_starts[n + 1]);
......@@ -320,29 +320,25 @@ class LSTMGradKernel : public framework::OpKernel<T> {
int pre_h_start = static_cast<int>(batch_starts[n - 1]);
int pre_h_end = pre_h_start + cur_batch_size;
auto pre_hidden_g = batch_hidden_g.Slice(pre_h_start, pre_h_end);
math::matmul<DeviceContext, T>(device_ctx, gate_g, false, *weight, true,
static_cast<T>(1.0), &pre_hidden_g,
static_cast<T>(1.0));
blas.MatMul(gate_g, false, *weight, true, static_cast<T>(1.0),
&pre_hidden_g, static_cast<T>(1.0));
if (weight_g) {
/* backward weight */
auto pre_hidden = batch_hidden.Slice(pre_h_start, pre_h_end);
math::matmul<DeviceContext, T>(device_ctx, pre_hidden, true, gate_g,
false, static_cast<T>(1.0), weight_g,
static_cast<T>(1.0));
blas.MatMul(pre_hidden, true, gate_g, false, static_cast<T>(1.0),
weight_g, static_cast<T>(1.0));
}
} else {
if (h0 && weight_g) {
ReorderInitState<DeviceContext, T>(device_ctx, *h0, order,
&ordered_h0, true);
math::matmul<DeviceContext, T>(device_ctx, ordered_h0, true, gate_g,
false, static_cast<T>(1.0), weight_g,
static_cast<T>(1.0));
blas.MatMul(ordered_h0, true, gate_g, false, static_cast<T>(1.0),
weight_g, static_cast<T>(1.0));
}
if (h0 && h0_g) {
ordered_h0_g.mutable_data<T>(h0_g->dims(), ctx.GetPlace());
math::matmul<DeviceContext, T>(device_ctx, gate_g, false, *weight,
true, static_cast<T>(1.0),
&ordered_h0_g, static_cast<T>(0.0));
blas.MatMul(gate_g, false, *weight, true, static_cast<T>(1.0),
&ordered_h0_g, static_cast<T>(0.0));
}
}
}
......
......@@ -143,7 +143,7 @@ class LSTMPKernel : public framework::OpKernel<T> {
auto proj_act = math::detail::GetActivationType(
ctx.Attr<std::string>("proj_activation"));
auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
auto blas = math::GetBlas<DeviceContext, T>(device_ctx);
for (size_t n = 0; n < num_batch; n++) {
int bstart = static_cast<int>(batch_starts[n]);
int bend = static_cast<int>(batch_starts[n + 1]);
......@@ -160,9 +160,8 @@ class LSTMPKernel : public framework::OpKernel<T> {
int pre_h_start = static_cast<int>(batch_starts[n - 1]);
int pre_h_end = pre_h_start + cur_batch_size;
auto pre_proj_t = batch_proj.Slice(pre_h_start, pre_h_end);
math::matmul<DeviceContext, T>(device_ctx, pre_proj_t, false, *weight,
false, static_cast<T>(1.0), &gate_t,
static_cast<T>(1.0));
blas.MatMul(pre_proj_t, false, *weight, false, static_cast<T>(1.0),
&gate_t, static_cast<T>(1.0));
} else if (hidden_t0) {
// If n == 0 and there is no initialized hidden state, that is to say
// the H0 is zeros, the calculation W_h * H0 will be skiped.
......@@ -176,16 +175,14 @@ class LSTMPKernel : public framework::OpKernel<T> {
ordered_proj0->mutable_data<T>(ctx.GetPlace());
ReorderInitState<DeviceContext, T>(device_ctx, *hidden_t0, order,
&ordered_h0, true);
math::matmul<DeviceContext, T>(device_ctx, ordered_h0, false,
*proj_weight, false, static_cast<T>(1.0),
ordered_proj0, static_cast<T>(0.0));
blas.MatMul(ordered_h0, false, *proj_weight, false, static_cast<T>(1.0),
ordered_proj0, static_cast<T>(0.0));
if (proj_act != math::detail::ActivationType::kIdentity) {
auto proj0_dev = EigenMatrix<T>::From(*ordered_proj0);
ActCompute(cell_act, place, proj0_dev, proj0_dev);
}
math::matmul<DeviceContext, T>(device_ctx, *ordered_proj0, false,
*weight, false, static_cast<T>(1.0),
&gate_t, static_cast<T>(1.0));
blas.MatMul(*ordered_proj0, false, *weight, false, static_cast<T>(1.0),
&gate_t, static_cast<T>(1.0));
}
lstmp_value.gate_value = gate_t.data<T>();
......@@ -196,9 +193,8 @@ class LSTMPKernel : public framework::OpKernel<T> {
device_ctx, lstmp_value, frame_size, cur_batch_size, gate_act,
cell_act, cand_act);
lstmp_value.prev_state_value = lstmp_value.state_value;
math::matmul<DeviceContext, T>(device_ctx, hidden_t, false, *proj_weight,
false, static_cast<T>(1.0), &proj_t,
static_cast<T>(0.0));
blas.MatMul(hidden_t, false, *proj_weight, false, static_cast<T>(1.0),
&proj_t, static_cast<T>(0.0));
if (proj_act != math::detail::ActivationType::kIdentity) {
auto proj_t_dev = EigenMatrix<T>::From(proj_t);
ActCompute(cell_act, place, proj_t_dev, proj_t_dev);
......@@ -361,6 +357,7 @@ class LSTMPGradKernel : public framework::OpKernel<T> {
auto batch_starts = batch_gate->lod()[0];
size_t num_batch = batch_starts.size() - 1;
auto blas = math::GetBlas<DeviceContext, T>(device_ctx);
for (int n = static_cast<int>(num_batch) - 1; n >= 0; n--) {
int bstart = static_cast<int>(batch_starts[n]);
int bend = static_cast<int>(batch_starts[n + 1]);
......@@ -375,15 +372,13 @@ class LSTMPGradKernel : public framework::OpKernel<T> {
}
/* hidden state backwarad */
Tensor out_g = batch_hidden_g.Slice(bstart, bend);
math::matmul<DeviceContext, T>(device_ctx, proj_g, false, *proj_weight,
true, static_cast<T>(1.0), &out_g,
static_cast<T>(0.0));
blas.MatMul(proj_g, false, *proj_weight, true, static_cast<T>(1.0),
&out_g, static_cast<T>(0.0));
/* projection weight backward*/
if (proj_weight_g) {
Tensor hidden_t = batch_hidden->Slice(bstart, bend);
math::matmul<DeviceContext, T>(device_ctx, hidden_t, true, proj_g,
false, static_cast<T>(1.0),
proj_weight_g, static_cast<T>(1.0));
blas.MatMul(hidden_t, true, proj_g, false, static_cast<T>(1.0),
proj_weight_g, static_cast<T>(1.0));
}
Tensor gate = batch_gate->Slice(bstart, bend);
......@@ -419,24 +414,21 @@ class LSTMPGradKernel : public framework::OpKernel<T> {
int pre_h_start = static_cast<int>(batch_starts[n - 1]);
int pre_h_end = pre_h_start + cur_batch_size;
auto pre_proj_g = batch_proj_g.Slice(pre_h_start, pre_h_end);
math::matmul<DeviceContext, T>(device_ctx, gate_g, false, *weight, true,
static_cast<T>(1.0), &pre_proj_g,
static_cast<T>(1.0));
blas.MatMul(gate_g, false, *weight, true, static_cast<T>(1.0),
&pre_proj_g, static_cast<T>(1.0));
if (weight_g) {
/* weight backward*/
auto pre_proj = batch_proj.Slice(pre_h_start, pre_h_end);
math::matmul<DeviceContext, T>(device_ctx, pre_proj, true, gate_g,
false, static_cast<T>(1.0), weight_g,
static_cast<T>(1.0));
blas.MatMul(pre_proj, true, gate_g, false, static_cast<T>(1.0),
weight_g, static_cast<T>(1.0));
}
} else {
if (h0 && weight_g) {
ReorderInitState<DeviceContext, T>(device_ctx, *h0, order,
&ordered_h0, true);
if (weight_g) {
math::matmul<DeviceContext, T>(device_ctx, *ordered_proj0, true,
gate_g, false, static_cast<T>(1.0),
weight_g, static_cast<T>(1.0));
blas.MatMul(*ordered_proj0, true, gate_g, false,
static_cast<T>(1.0), weight_g, static_cast<T>(1.0));
}
}
if (h0 && (h0_g || proj_weight_g)) {
......@@ -444,9 +436,8 @@ class LSTMPGradKernel : public framework::OpKernel<T> {
Tensor proj0_g;
proj0_g.Resize({in_dims[0], proj_weight->dims()[1]});
proj0_g.mutable_data<T>(ctx.GetPlace());
math::matmul<DeviceContext, T>(device_ctx, gate_g, false, *weight,
true, static_cast<T>(1.0), &proj0_g,
static_cast<T>(0.0));
blas.MatMul(gate_g, false, *weight, true, static_cast<T>(1.0),
&proj0_g, static_cast<T>(0.0));
if (proj_act != math::detail::ActivationType::kIdentity) {
auto proj0_dev = EigenMatrix<T>::From(*ordered_proj0);
auto proj0_g_dev = EigenMatrix<T>::From(proj0_g);
......@@ -454,14 +445,12 @@ class LSTMPGradKernel : public framework::OpKernel<T> {
proj0_g_dev);
}
if (h0_g) {
math::matmul<DeviceContext, T>(
device_ctx, proj0_g, false, *proj_weight, true,
static_cast<T>(1.0), &ordered_h0_g, static_cast<T>(0.0));
blas.MatMul(proj0_g, false, *proj_weight, true, static_cast<T>(1.0),
&ordered_h0_g, static_cast<T>(0.0));
}
if (proj_weight_g) {
math::matmul<DeviceContext, T>(device_ctx, ordered_h0, true,
proj0_g, false, static_cast<T>(1.0),
proj_weight_g, static_cast<T>(1.0));
blas.MatMul(ordered_h0, true, proj0_g, false, static_cast<T>(1.0),
proj_weight_g, static_cast<T>(1.0));
}
}
}
......
......@@ -61,12 +61,10 @@ struct CUBlas<platform::float16> {
template <>
template <typename T>
void Blas<platform::CUDADeviceContext>::GEMM(const CBLAS_TRANSPOSE transA,
const CBLAS_TRANSPOSE transB,
const int M, const int N,
const int K, const T alpha,
const T *A, const T *B,
const T beta, T *C) const {
void Blas<platform::CUDADeviceContext>::GEMM(CBLAS_TRANSPOSE transA,
CBLAS_TRANSPOSE transB, int M,
int N, int K, T alpha, const T *A,
const T *B, T beta, T *C) const {
// Note that cublas follows fortran order, so the order is different from
// the cblas convention.
int lda = (transA == CblasNoTrans) ? K : M;
......@@ -83,10 +81,10 @@ void Blas<platform::CUDADeviceContext>::GEMM(const CBLAS_TRANSPOSE transA,
template <>
template <>
inline void Blas<platform::CUDADeviceContext>::GEMM(
const CBLAS_TRANSPOSE transA, const CBLAS_TRANSPOSE transB, const int M,
const int N, const int K, const platform::float16 alpha,
const platform::float16 *A, const platform::float16 *B,
const platform::float16 beta, platform::float16 *C) const {
CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K,
platform::float16 alpha, const platform::float16 *A,
const platform::float16 *B, platform::float16 beta,
platform::float16 *C) const {
// Note that cublas follows fortran order, so the order is different from
// the cblas convention.
int lda = (transA == CblasNoTrans) ? K : M;
......@@ -134,14 +132,14 @@ inline void Blas<platform::CUDADeviceContext>::GEMM(
template <>
template <typename T>
void Blas<platform::CUDADeviceContext>::GEMM(
const bool transA, const bool transB, const int M, const int N, const int K,
const T alpha, const T *A, const int lda, const T *B, const int ldb,
const T beta, T *C, const int ldc) const {
void Blas<platform::CUDADeviceContext>::GEMM(bool transA, bool transB, int M,
int N, int K, T alpha, const T *A,
int lda, const T *B, int ldb,
T beta, T *C, int ldc) const {
// Note that cublas follows fortran order, so the order is different from
// the cblas convention.
cublasOperation_t cuTransA = transA == false ? CUBLAS_OP_N : CUBLAS_OP_T;
cublasOperation_t cuTransB = transB == false ? CUBLAS_OP_N : CUBLAS_OP_T;
cublasOperation_t cuTransA = transA ? CUBLAS_OP_T : CUBLAS_OP_N;
cublasOperation_t cuTransB = transB ? CUBLAS_OP_T : CUBLAS_OP_N;
CUBlas<T>::GEMM(context_.cublas_handle(), cuTransB, cuTransA, N, M, K, &alpha,
B, ldb, A, lda, &beta, C, ldc);
}
......
......@@ -45,12 +45,10 @@ struct CBlas<platform::float16> {
template <>
template <typename T>
void Blas<platform::CPUDeviceContext>::GEMM(const CBLAS_TRANSPOSE transA,
const CBLAS_TRANSPOSE transB,
const int M, const int N,
const int K, const T alpha,
const T *A, const T *B,
const T beta, T *C) const {
void Blas<platform::CPUDeviceContext>::GEMM(CBLAS_TRANSPOSE transA,
CBLAS_TRANSPOSE transB, int M,
int N, int K, T alpha, const T *A,
const T *B, T beta, T *C) const {
int lda = (transA == CblasNoTrans) ? K : M;
int ldb = (transB == CblasNoTrans) ? N : K;
int ldc = N;
......@@ -60,15 +58,41 @@ void Blas<platform::CPUDeviceContext>::GEMM(const CBLAS_TRANSPOSE transA,
template <>
template <typename T>
void Blas<platform::CPUDeviceContext>::GEMM(
const bool transA, const bool transB, const int M, const int N, const int K,
const T alpha, const T *A, const int lda, const T *B, const int ldb,
const T beta, T *C, const int ldc) const {
void Blas<platform::CPUDeviceContext>::GEMM(bool transA, bool transB, int M,
int N, int K, T alpha, const T *A,
int lda, const T *B, int ldb,
T beta, T *C, int ldc) const {
CBlas<T>::GEMM(CblasRowMajor, transA == false ? CblasNoTrans : CblasTrans,
transB == false ? CblasNoTrans : CblasTrans, M, N, K, alpha, A,
lda, B, ldb, beta, C, ldc);
}
template <typename DeviceContext>
template <typename T>
void Blas<DeviceContext>::MatMul(const framework::Tensor &mat_a, bool trans_a,
const framework::Tensor &mat_b, bool trans_b,
T alpha, framework::Tensor *mat_out,
T beta) const {
auto dim_a = mat_a.dims();
auto dim_b = mat_b.dims();
auto dim_out = mat_out->dims();
PADDLE_ENFORCE(dim_a.size() == 2 && dim_b.size() == 2 && dim_out.size() == 2,
"The input and output of matmul be matrix");
PADDLE_ENFORCE(
mat_a.place() == mat_b.place() && mat_a.place() == mat_out->place(),
"The places of matrices must be same");
int M = dim_out[0];
int N = dim_out[1];
int K = !trans_a ? dim_a[1] : dim_a[0];
CBLAS_TRANSPOSE transA = !trans_a ? CblasNoTrans : CblasTrans;
CBLAS_TRANSPOSE transB = !trans_b ? CblasNoTrans : CblasTrans;
this->GEMM(transA, transB, M, N, K, alpha, mat_a.data<T>(), mat_b.data<T>(),
beta, mat_out->data<T>());
}
} // namespace math
} // namespace operators
} // namespace paddle
......@@ -24,73 +24,6 @@ namespace math {
using float16 = paddle::platform::float16;
template <>
void matmul<platform::CPUDeviceContext, float16>(
const platform::CPUDeviceContext& context,
const framework::Tensor& matrix_a, bool trans_a,
const framework::Tensor& matrix_b, bool trans_b, float16 alpha,
framework::Tensor* matrix_out, float16 beta) {
PADDLE_THROW("float16 matmul not supported on CPU");
}
template <>
void matmul<platform::CPUDeviceContext, float>(
const platform::CPUDeviceContext& context,
const framework::Tensor& matrix_a, bool trans_a,
const framework::Tensor& matrix_b, bool trans_b, float alpha,
framework::Tensor* matrix_out, float beta) {
auto dim_a = matrix_a.dims();
auto dim_b = matrix_b.dims();
auto dim_out = matrix_out->dims();
PADDLE_ENFORCE(dim_a.size() == 2 && dim_b.size() == 2 && dim_out.size() == 2,
"The input and output of matmul be matrix");
PADDLE_ENFORCE(platform::is_cpu_place(matrix_a.place()) &&
platform::is_cpu_place(matrix_b.place()) &&
platform::is_cpu_place(matrix_out->place()),
"Matrix must all be in CPUPlace");
int M = dim_out[0];
int N = dim_out[1];
int K = (trans_a == false) ? dim_a[1] : dim_a[0];
CBLAS_TRANSPOSE transA = (trans_a == false) ? CblasNoTrans : CblasTrans;
CBLAS_TRANSPOSE transB = (trans_b == false) ? CblasNoTrans : CblasTrans;
Blas<platform::CPUDeviceContext>(context).GEMM(
transA, transB, M, N, K, alpha, matrix_a.data<float>(),
matrix_b.data<float>(), beta, matrix_out->data<float>());
}
template <>
void matmul<platform::CPUDeviceContext, double>(
const platform::CPUDeviceContext& context,
const framework::Tensor& matrix_a, bool trans_a,
const framework::Tensor& matrix_b, bool trans_b, double alpha,
framework::Tensor* matrix_out, double beta) {
auto dim_a = matrix_a.dims();
auto dim_b = matrix_b.dims();
auto dim_out = matrix_out->dims();
PADDLE_ENFORCE(dim_a.size() == 2 && dim_b.size() == 2 && dim_out.size() == 2,
"The input and output of matmul be matrix");
PADDLE_ENFORCE(platform::is_cpu_place(matrix_a.place()) &&
platform::is_cpu_place(matrix_b.place()) &&
platform::is_cpu_place(matrix_out->place()),
"Matrix must all be in CPUPlace");
int M = dim_out[0];
int N = dim_out[1];
int K = (trans_a == false) ? dim_a[1] : dim_a[0];
CBLAS_TRANSPOSE transA = (trans_a == false) ? CblasNoTrans : CblasTrans;
CBLAS_TRANSPOSE transB = (trans_b == false) ? CblasNoTrans : CblasTrans;
Blas<platform::CPUDeviceContext>(context).GEMM(
transA, transB, M, N, K, alpha, matrix_a.data<double>(),
matrix_b.data<double>(), beta, matrix_out->data<double>());
}
template <>
void batched_gemm<platform::CPUDeviceContext, float16>(
const platform::CPUDeviceContext& context, const CBLAS_TRANSPOSE transA,
......
......@@ -25,93 +25,6 @@ namespace math {
using float16 = paddle::platform::float16;
template <>
void matmul<platform::CUDADeviceContext, float16>(
const platform::CUDADeviceContext& context,
const framework::Tensor& matrix_a, bool trans_a,
const framework::Tensor& matrix_b, bool trans_b, float16 alpha,
framework::Tensor* matrix_out, float16 beta) {
auto dim_a = matrix_a.dims();
auto dim_b = matrix_b.dims();
auto dim_out = matrix_out->dims();
PADDLE_ENFORCE(dim_a.size() == 2 && dim_b.size() == 2 && dim_out.size() == 2,
"The input and output of matmul be matrix");
PADDLE_ENFORCE(platform::is_gpu_place(matrix_a.place()) &&
platform::is_gpu_place(matrix_b.place()) &&
platform::is_gpu_place(matrix_out->place()),
"Matrix must all be in CUDAPlace");
int M = dim_out[0];
int N = dim_out[1];
int K = (trans_a == false) ? dim_a[1] : dim_a[0];
CBLAS_TRANSPOSE transA = (trans_a == false) ? CblasNoTrans : CblasTrans;
CBLAS_TRANSPOSE transB = (trans_b == false) ? CblasNoTrans : CblasTrans;
Blas<platform::CUDADeviceContext>(context).GEMM(
transA, transB, M, N, K, alpha, matrix_a.data<float16>(),
matrix_b.data<float16>(), beta, matrix_out->data<float16>());
}
template <>
void matmul<platform::CUDADeviceContext, float>(
const platform::CUDADeviceContext& context,
const framework::Tensor& matrix_a, bool trans_a,
const framework::Tensor& matrix_b, bool trans_b, float alpha,
framework::Tensor* matrix_out, float beta) {
auto dim_a = matrix_a.dims();
auto dim_b = matrix_b.dims();
auto dim_out = matrix_out->dims();
PADDLE_ENFORCE(dim_a.size() == 2 && dim_b.size() == 2 && dim_out.size() == 2,
"The input and output of matmul be matrix");
PADDLE_ENFORCE(platform::is_gpu_place(matrix_a.place()) &&
platform::is_gpu_place(matrix_b.place()) &&
platform::is_gpu_place(matrix_out->place()),
"Matrix must all be in CUDAPlace");
int M = dim_out[0];
int N = dim_out[1];
int K = (trans_a == false) ? dim_a[1] : dim_a[0];
CBLAS_TRANSPOSE transA = (trans_a == false) ? CblasNoTrans : CblasTrans;
CBLAS_TRANSPOSE transB = (trans_b == false) ? CblasNoTrans : CblasTrans;
Blas<platform::CUDADeviceContext>(context).GEMM(
transA, transB, M, N, K, alpha, matrix_a.data<float>(),
matrix_b.data<float>(), beta, matrix_out->data<float>());
}
template <>
void matmul<platform::CUDADeviceContext, double>(
const platform::CUDADeviceContext& context,
const framework::Tensor& matrix_a, bool trans_a,
const framework::Tensor& matrix_b, bool trans_b, double alpha,
framework::Tensor* matrix_out, double beta) {
auto dim_a = matrix_a.dims();
auto dim_b = matrix_b.dims();
auto dim_out = matrix_out->dims();
PADDLE_ENFORCE(dim_a.size() == 2 && dim_b.size() == 2 && dim_out.size() == 2,
"The input and output of matmul be matrix");
PADDLE_ENFORCE(platform::is_gpu_place(matrix_a.place()) &&
platform::is_gpu_place(matrix_b.place()) &&
platform::is_gpu_place(matrix_out->place()),
"Matrix must all be in CUDAPlace");
int M = dim_out[0];
int N = dim_out[1];
int K = (trans_a == false) ? dim_a[1] : dim_a[0];
CBLAS_TRANSPOSE transA = (trans_a == false) ? CblasNoTrans : CblasTrans;
CBLAS_TRANSPOSE transB = (trans_b == false) ? CblasNoTrans : CblasTrans;
Blas<platform::CUDADeviceContext>(context).GEMM(
transA, transB, M, N, K, alpha, matrix_a.data<double>(),
matrix_b.data<double>(), beta, matrix_out->data<double>());
}
template <>
void batched_gemm<platform::CUDADeviceContext, float16>(
const platform::CUDADeviceContext& context, const CBLAS_TRANSPOSE transA,
......
......@@ -64,14 +64,31 @@ class Blas {
explicit Blas(const DeviceContext& context) : context_(context) {}
template <typename T>
void GEMM(const CBLAS_TRANSPOSE transA, const CBLAS_TRANSPOSE transB,
const int M, const int N, const int K, const T alpha, const T* A,
const T* B, const T beta, T* C) const;
void GEMM(CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K,
T alpha, const T* A, const T* B, T beta, T* C) const;
template <typename T>
void GEMM(const bool transA, const bool transB, const int M, const int N,
const int K, const T alpha, const T* A, const int lda, const T* B,
const int ldb, const T beta, T* C, const int ldc) const;
void GEMM(bool transA, bool transB, int M, int N, int K, T alpha, const T* A,
int lda, const T* B, int ldb, T beta, T* C, int ldc) const;
template <typename T>
void MatMul(const framework::Tensor& mat_a, bool trans_a,
const framework::Tensor& mat_b, bool trans_b, T alpha,
framework::Tensor* mat_out, T beta) const;
template <typename T>
void MatMul(const framework::Tensor& mat_a, bool trans_a,
const framework::Tensor& mat_b, bool trans_b,
framework::Tensor* mat_out) const {
MatMul(mat_a, trans_a, mat_b, trans_b, static_cast<T>(1.0), mat_out,
static_cast<T>(0.0));
}
template <typename T>
void MatMul(const framework::Tensor& mat_a, const framework::Tensor& mat_b,
framework::Tensor* mat_out) const {
this->template MatMul<T>(mat_a, false, mat_b, false, mat_out);
}
private:
const DeviceContext& context_;
......@@ -86,6 +103,11 @@ class BlasT : private Blas<DeviceContext> {
void GEMM(ARGS... args) const {
static_cast<const Blas<DeviceContext>*>(this)->template GEMM<T>(args...);
}
template <typename... ARGS>
void MatMul(ARGS... args) const {
static_cast<const Blas<DeviceContext>*>(this)->template MatMul<T>(args...);
}
};
template <typename DeviceContext, typename T>
......@@ -100,12 +122,6 @@ inline BlasT<DeviceContext, T> GetBlas(const DeviceContext& dev_ctx) {
return BlasT<DeviceContext, T>(dev_ctx);
}
// matrix multiply with continuous memory
template <typename DeviceContext, typename T>
void matmul(const DeviceContext& context, const framework::Tensor& matrix_a,
bool trans_a, const framework::Tensor& matrix_b, bool trans_b,
T alpha, framework::Tensor* matrix_out, T beta);
// Batched gemm
template <typename DeviceContext, typename T>
void batched_gemm(const DeviceContext& context, const CBLAS_TRANSPOSE transA,
......
......@@ -23,6 +23,13 @@ void fill_fp16_data(paddle::platform::float16* in_ptr, size_t size,
}
}
template <typename T>
inline paddle::operators::math::BlasT<paddle::platform::CUDADeviceContext, T>
GetBlas(const paddle::platform::CUDADeviceContext& context) {
return paddle::operators::math::GetBlas<paddle::platform::CUDADeviceContext,
T>(context);
}
TEST(math_function, notrans_mul_trans_fp32) {
paddle::framework::Tensor input1;
paddle::framework::Tensor input1_gpu;
......@@ -42,9 +49,8 @@ TEST(math_function, notrans_mul_trans_fp32) {
paddle::framework::TensorCopySync(input1, gpu_place, &input2_gpu);
out_gpu.mutable_data<float>({2, 2}, gpu_place);
paddle::operators::math::matmul<paddle::platform::CUDADeviceContext, float>(
context, input1_gpu, false, input2_gpu, true, 1, &out_gpu, 0);
GetBlas<float>(context).MatMul(input1_gpu, false, input2_gpu, true, 1,
&out_gpu, 0);
paddle::framework::TensorCopySync(out_gpu, cpu_place, &out);
......@@ -81,10 +87,9 @@ TEST(math_function, notrans_mul_trans_fp16) {
out_gpu.mutable_data<paddle::platform::float16>({2, 2}, gpu_place);
paddle::operators::math::matmul<paddle::platform::CUDADeviceContext,
paddle::platform::float16>(
context, input1_gpu, false, input2_gpu, true,
paddle::platform::float16(1), &out_gpu, paddle::platform::float16(0));
GetBlas<paddle::platform::float16>(context).MatMul(
input1_gpu, false, input2_gpu, true, paddle::platform::float16(1),
&out_gpu, paddle::platform::float16(0));
paddle::framework::TensorCopySync(out_gpu, cpu_place, &out);
......@@ -116,8 +121,8 @@ TEST(math_function, trans_mul_notrans_fp32) {
out_gpu.mutable_data<float>({3, 3}, gpu_place);
paddle::operators::math::matmul<paddle::platform::CUDADeviceContext, float>(
context, input1_gpu, true, input2_gpu, false, 1, &out_gpu, 0);
GetBlas<float>(context).MatMul(input1_gpu, true, input2_gpu, false, 1,
&out_gpu, 0);
paddle::framework::TensorCopySync(out_gpu, cpu_place, &out);
......@@ -159,10 +164,9 @@ TEST(math_function, trans_mul_notrans_fp16) {
out_gpu.mutable_data<paddle::platform::float16>({3, 3}, gpu_place);
paddle::operators::math::matmul<paddle::platform::CUDADeviceContext,
paddle::platform::float16>(
context, input1_gpu, true, input2_gpu, false,
paddle::platform::float16(1), &out_gpu, paddle::platform::float16(0));
GetBlas<paddle::platform::float16>(context).MatMul(
input1_gpu, true, input2_gpu, false, paddle::platform::float16(1),
&out_gpu, paddle::platform::float16(0));
paddle::framework::TensorCopySync(out_gpu, cpu_place, &out);
......@@ -179,13 +183,6 @@ TEST(math_function, trans_mul_notrans_fp16) {
EXPECT_EQ(static_cast<float>(out_ptr[8]), 29);
}
template <typename T>
inline paddle::operators::math::BlasT<paddle::platform::CUDADeviceContext, T>
GetBlas(const paddle::platform::CUDADeviceContext& context) {
return paddle::operators::math::GetBlas<paddle::platform::CUDADeviceContext,
T>(context);
}
TEST(math_function, gemm_notrans_cublas_fp32) {
paddle::framework::Tensor input1;
paddle::framework::Tensor input2;
......
......@@ -46,9 +46,10 @@ class MulKernel : public framework::OpKernel<T> {
if (z_dim.size() != 2) {
z->Resize({x_matrix.dims()[0], y_matrix.dims()[1]});
}
math::matmul<DeviceContext, T>(
context.template device_context<DeviceContext>(), x_matrix, false,
y_matrix, false, static_cast<T>(1), z, static_cast<T>(0));
auto blas = math::GetBlas<DeviceContext, T>(context);
blas.MatMul(x_matrix, y_matrix, z);
if (z_dim.size() != 2) {
z->Resize(z_dim);
}
......@@ -79,6 +80,7 @@ class MulGradKernel : public framework::OpKernel<T> {
Tensor* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
Tensor* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
auto& dev_ctx = ctx.template device_context<DeviceContext>();
auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
if (dx) {
dx->mutable_data<T>(ctx.GetPlace());
Tensor dx_matrix = dx->dims().size() > 2
......@@ -86,8 +88,7 @@ class MulGradKernel : public framework::OpKernel<T> {
: *dx;
// dx = dout * y'. dx: M x K, dout : M x N, y : K x N
math::matmul<DeviceContext, T>(dev_ctx, dout_mat, false, y_matrix, true,
1, &dx_matrix, 0);
blas.MatMul(dout_mat, false, y_matrix, true, &dx_matrix);
}
if (dy) {
dy->mutable_data<T>(ctx.GetPlace());
......@@ -95,8 +96,7 @@ class MulGradKernel : public framework::OpKernel<T> {
? framework::ReshapeToMatrix(*dy, y_num_col_dims)
: *dy;
// dy = x' * dout. dy K x N, dout : M x N, x : M x K
math::matmul<DeviceContext, T>(dev_ctx, x_matrix, true, dout_mat, false,
1, &dy_matrix, 0);
blas.MatMul(x_matrix, true, dout_mat, false, &dy_matrix);
}
}
};
......
......@@ -58,17 +58,15 @@ class SequenceConvKernel : public framework::OpKernel<T> {
// Because if padding_trainable is false, padding data should be zeros.
math::SetConstant<DeviceContext, T> set_zero;
auto& dev_ctx = context.template device_context<DeviceContext>();
auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
set_zero(dev_ctx, &col, static_cast<T>(0));
math::ContextProjectFunctor<DeviceContext, T> seq_project_functor;
seq_project_functor(dev_ctx, *in, *padding_data, padding_trainable,
context_start, context_length, context_stride, up_pad,
down_pad, &col);
math::matmul<DeviceContext, T>(dev_ctx, col, false, filter, false,
static_cast<T>(1.0), out,
static_cast<T>(0.0));
blas.MatMul(col, filter, out);
}
};
......@@ -99,6 +97,7 @@ class SequenceConvGradKernel : public framework::OpKernel<T> {
math::SetConstant<DeviceContext, T> set_zero;
auto& dev_ctx = context.template device_context<DeviceContext>();
auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
// use col_shape in the im2col calculation
framework::DDim col_shape = {in->dims()[0],
sequence_width * context_length};
......@@ -108,8 +107,7 @@ class SequenceConvGradKernel : public framework::OpKernel<T> {
col.mutable_data<T>(col_shape, context.GetPlace());
// Because if padding_trainable is false, padding data should be zeros.
set_zero(dev_ctx, &col, static_cast<T>(0));
math::matmul<DeviceContext, T>(dev_ctx, *out_g, false, *filter, true,
T(1.0), &col, T(1.0));
blas.MatMul(*out_g, false, *filter, true, &col);
}
math::ContextProjectFunctor<DeviceContext, T> seq_project_functor;
math::ContextProjectGradFunctor<DeviceContext, T> seq_project_grad_functor;
......@@ -150,8 +148,7 @@ class SequenceConvGradKernel : public framework::OpKernel<T> {
context_start, context_length, context_stride, up_pad,
down_pad, &col);
math::matmul<DeviceContext, T>(dev_ctx, col, true, out_grad, false,
T(1.0), &filter_grad, T(1.0));
blas.MatMul(col, true, out_grad, false, &filter_grad);
}
}
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册