未验证 提交 2abcf379 编写于 作者: Y Yu Yang 提交者: GitHub

Merge pull request #10327 from reyoung/feature/clean_blas

Feature/clean blas
......@@ -61,9 +61,9 @@ class BilinearTensorProductKernel : public framework::OpKernel<T> {
auto output_col_vec = output_mat.chip(i, 1);
Tensor weight_mat =
weight->Slice(i, i + 1).Resize(framework::make_ddim({x_dim, y_dim}));
math::gemm<DeviceContext, T>(dev_ctx, CblasNoTrans, CblasNoTrans,
batch_size, y_dim, x_dim, 1, x->data<T>(),
weight_mat.data<T>(), 0, left_mul.data<T>());
math::GetBlas<DeviceContext, T>(dev_ctx).GEMM(
CblasNoTrans, CblasNoTrans, batch_size, y_dim, x_dim, 1, x->data<T>(),
weight_mat.data<T>(), 0, left_mul.data<T>());
output_col_vec.device(place) =
(left_mul_mat * y_mat).sum(Eigen::DSizes<int, 1>(1));
}
......@@ -125,6 +125,8 @@ class BilinearTensorProductGradKernel : public framework::OpKernel<T> {
set_zero(dev_ctx, d_y, static_cast<T>(0));
}
auto blas = math::GetBlas<DeviceContext, T>(ctx);
// Caculate the Output(X@Grad) and Output(Y@Grad).
if (d_x || d_y) {
Eigen::DSizes<int, 2> bcast_for_x(1, y_dim);
......@@ -138,18 +140,16 @@ class BilinearTensorProductGradKernel : public framework::OpKernel<T> {
output_vec.reshape(Eigen::DSizes<int, 2>(batch_size, 1))
.broadcast(bcast_for_x) *
y_mat;
math::gemm<DeviceContext, T>(
dev_ctx, CblasNoTrans, CblasTrans, batch_size, x_dim, y_dim, 1,
y_scale.data<T>(), weight_i.data<T>(), 1, d_x->data<T>());
blas.GEMM(CblasNoTrans, CblasTrans, batch_size, x_dim, y_dim, 1,
y_scale.data<T>(), weight_i.data<T>(), 1, d_x->data<T>());
}
if (d_y) {
x_scale_mat.device(place) =
output_vec.reshape(Eigen::DSizes<int, 2>(batch_size, 1))
.broadcast(bcast_for_y) *
x_mat;
math::gemm<DeviceContext, T>(
dev_ctx, CblasNoTrans, CblasNoTrans, batch_size, y_dim, x_dim, 1,
x_scale.data<T>(), weight_i.data<T>(), 1, d_y->data<T>());
blas.GEMM(CblasNoTrans, CblasNoTrans, batch_size, y_dim, x_dim, 1,
x_scale.data<T>(), weight_i.data<T>(), 1, d_y->data<T>());
}
}
}
......@@ -166,9 +166,8 @@ class BilinearTensorProductGradKernel : public framework::OpKernel<T> {
output_vec.reshape(Eigen::DSizes<int, 2>(batch_size, 1))
.broadcast(bcast_for_weight) *
x_mat;
math::gemm<DeviceContext, T>(dev_ctx, CblasTrans, CblasNoTrans, x_dim,
y_dim, batch_size, 1, x_scale.data<T>(),
y->data<T>(), 0, d_weight_i.data<T>());
blas.GEMM(CblasTrans, CblasNoTrans, x_dim, y_dim, batch_size, 1,
x_scale.data<T>(), y->data<T>(), 0, d_weight_i.data<T>());
}
}
......
......@@ -87,10 +87,10 @@ class GRUUnitKernel : public framework::OpKernel<T> {
const T* weight_data = weight->data<T>();
T* gate_data = gate->data<T>();
T* reset_hidden_prev_data = reset_hidden_prev->data<T>();
math::gemm<DeviceContext, T>(
context.template device_context<DeviceContext>(), false, false,
batch_size, 2 * frame_size, frame_size, 1, hidden_prev_data, frame_size,
weight_data, frame_size * 2, 1, gate_data, frame_size * 3);
auto blas = math::GetBlas<DeviceContext, T>(context);
blas.GEMM(false, false, batch_size, 2 * frame_size, frame_size, 1,
hidden_prev_data, frame_size, weight_data, frame_size * 2, 1,
gate_data, frame_size * 3);
// calculate activited gate
Eigen::array<int, 2> extents({{batch_size, frame_size}});
......@@ -103,11 +103,10 @@ class GRUUnitKernel : public framework::OpKernel<T> {
g.slice(r_offsets, extents), g.slice(r_offsets, extents));
auto r = g.slice(r_offsets, extents); // reset gate
r_h_p.device(place) = r * h_p; // reset previous hidden state
math::gemm<DeviceContext, T>(
context.template device_context<DeviceContext>(), false, false,
batch_size, frame_size, frame_size, 1, reset_hidden_prev_data,
frame_size, weight_data + frame_size * frame_size * 2, frame_size, 1,
gate_data + frame_size * 2, frame_size * 3);
blas.GEMM(false, false, batch_size, frame_size, frame_size, 1,
reset_hidden_prev_data, frame_size,
weight_data + frame_size * frame_size * 2, frame_size, 1,
gate_data + frame_size * 2, frame_size * 3);
Eigen::array<int, 2> c_offsets({{0, frame_size * 2}});
ActCompute(context.Attr<int>("activation"), place,
......@@ -188,11 +187,11 @@ class GRUUnitGradKernel : public framework::OpKernel<T> {
ActGradCompute(context.Attr<int>("activation"), place, c, c,
d_g.slice(c_offsets, extents), d_h * u);
// backward for reset_hidden_prev
math::gemm<DeviceContext, T>(
context.template device_context<DeviceContext>(), false, true,
batch_size, frame_size, frame_size, 1, gate_grad_data + frame_size * 2,
frame_size * 3, weight_data + frame_size * frame_size * 2, frame_size,
0, reset_hidden_prev_grad_data, frame_size);
auto blas = math::GetBlas<DeviceContext, T>(context);
blas.GEMM(false, true, batch_size, frame_size, frame_size, 1,
gate_grad_data + frame_size * 2, frame_size * 3,
weight_data + frame_size * frame_size * 2, frame_size, 0,
reset_hidden_prev_grad_data, frame_size);
// backward for unactivated reset gate
ActGradCompute(context.Attr<int>("gate_activation"), place, r, r,
d_g.slice(r_offsets, extents), d_r_h_p * h_p);
......@@ -200,18 +199,15 @@ class GRUUnitGradKernel : public framework::OpKernel<T> {
if (weight_grad) {
T* weight_grad_data = weight_grad->mutable_data<T>(context.GetPlace());
// backward for state_weight
math::gemm<DeviceContext, T>(
context.template device_context<DeviceContext>(), true, false,
frame_size, frame_size, batch_size, 1, reset_hidden_prev_data,
frame_size, gate_grad_data + frame_size * 2, frame_size * 3, 0,
weight_grad_data + frame_size * frame_size * 2, frame_size);
blas.GEMM(true, false, frame_size, frame_size, batch_size, 1,
reset_hidden_prev_data, frame_size,
gate_grad_data + frame_size * 2, frame_size * 3, 0,
weight_grad_data + frame_size * frame_size * 2, frame_size);
// backward for update_gate_weight and reset_gate_weight
math::gemm<DeviceContext, T>(
context.template device_context<DeviceContext>(), true, false,
frame_size, frame_size * 2, batch_size, 1, hidden_prev_data,
frame_size, gate_grad_data, frame_size * 3, 0, weight_grad_data,
frame_size * 2);
blas.GEMM(true, false, frame_size, frame_size * 2, batch_size, 1,
hidden_prev_data, frame_size, gate_grad_data, frame_size * 3, 0,
weight_grad_data, frame_size * 2);
}
// backward for hidden_prev
if (hidden_prev_grad) {
......@@ -219,11 +215,9 @@ class GRUUnitGradKernel : public framework::OpKernel<T> {
hidden_prev_grad->mutable_data<T>(context.GetPlace());
auto d_h_p = EigenMatrix<T>::From(*hidden_prev_grad);
d_h_p.device(place) = d_r_h_p * r + d_h * (u.constant(T(1)) - u);
math::gemm<DeviceContext, T>(
context.template device_context<DeviceContext>(), false, true,
batch_size, frame_size, frame_size * 2, 1, gate_grad_data,
frame_size * 3, weight_data, frame_size * 2, 1, hidden_prev_grad_data,
frame_size);
blas.GEMM(false, true, batch_size, frame_size, frame_size * 2, 1,
gate_grad_data, frame_size * 3, weight_data, frame_size * 2, 1,
hidden_prev_grad_data, frame_size);
}
// backward for input
if (input_grad) {
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/dynload/cublas.h"
namespace paddle {
namespace operators {
namespace math {
template <typename T>
struct CUBlas;
template <>
struct CUBlas<float> {
template <typename... ARGS>
static void GEMM(ARGS... args) {
PADDLE_ENFORCE(platform::dynload::cublasSgemm(args...));
}
};
template <>
struct CUBlas<double> {
template <typename... ARGS>
static void GEMM(ARGS... args) {
PADDLE_ENFORCE(platform::dynload::cublasDgemm(args...));
}
};
template <>
struct CUBlas<platform::float16> {
using float16 = platform::float16;
static void GEMM(cublasHandle_t handle, cublasOperation_t transa,
cublasOperation_t transb, int m, int n, int k,
const float16 *alpha, const float16 *A, int lda,
const float16 *B, int ldb, const float16 *beta, float16 *C,
int ldc) {
PADDLE_ENFORCE(
platform::dynload::cublasHgemm(handle, transa, transb, m, n, k,
reinterpret_cast<const __half *>(alpha),
reinterpret_cast<const __half *>(A), lda,
reinterpret_cast<const __half *>(B), ldb,
reinterpret_cast<const __half *>(beta),
reinterpret_cast<__half *>(C), ldc));
}
};
template <>
template <typename T>
void Blas<platform::CUDADeviceContext>::GEMM(const CBLAS_TRANSPOSE transA,
const CBLAS_TRANSPOSE transB,
const int M, const int N,
const int K, const T alpha,
const T *A, const T *B,
const T beta, T *C) const {
// Note that cublas follows fortran order, so the order is different from
// the cblas convention.
int lda = (transA == CblasNoTrans) ? K : M;
int ldb = (transB == CblasNoTrans) ? N : K;
cublasOperation_t cuTransA =
(transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
cublasOperation_t cuTransB =
(transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
CUBlas<T>::GEMM(context_.cublas_handle(), cuTransB, cuTransA, N, M, K, &alpha,
B, ldb, A, lda, &beta, C, N);
}
template <>
template <>
inline void Blas<platform::CUDADeviceContext>::GEMM(
const CBLAS_TRANSPOSE transA, const CBLAS_TRANSPOSE transB, const int M,
const int N, const int K, const platform::float16 alpha,
const platform::float16 *A, const platform::float16 *B,
const platform::float16 beta, platform::float16 *C) const {
// Note that cublas follows fortran order, so the order is different from
// the cblas convention.
int lda = (transA == CblasNoTrans) ? K : M;
int ldb = (transB == CblasNoTrans) ? N : K;
cublasOperation_t cuTransA =
(transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
cublasOperation_t cuTransB =
(transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
// TODO(kexinzhao): add processing code for compute capability < 53 case
PADDLE_ENFORCE_GE(context_.GetComputeCapability(), 53,
"cublas fp16 gemm requires GPU compute capability >= 53");
#if CUDA_VERSION >= 8000
float h_alpha = static_cast<float>(alpha);
float h_beta = static_cast<float>(beta);
cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT;
#if CUDA_VERSION >= 9000
if (context_.GetComputeCapability() >= 70) {
PADDLE_ENFORCE(platform::dynload::cublasSetMathMode(
context_.cublas_handle(), CUBLAS_TENSOR_OP_MATH));
algo = CUBLAS_GEMM_DFALT_TENSOR_OP;
} else {
PADDLE_ENFORCE(platform::dynload::cublasSetMathMode(
context_.cublas_handle(), CUBLAS_DEFAULT_MATH));
}
#endif // CUDA_VERSION >= 9000
// cublasHgemm does true FP16 computation which is slow for non-Volta
// GPUs. So use cublasGemmEx instead which does pesudo FP16 computation:
// input/output in fp16, computation in fp32, which can also be accelerated
// using tensor cores in volta GPUs.
PADDLE_ENFORCE(platform::dynload::cublasGemmEx(
context_.cublas_handle(), cuTransB, cuTransA, N, M, K, &h_alpha, B,
CUDA_R_16F, ldb, A, CUDA_R_16F, lda, &h_beta, C, CUDA_R_16F, N,
CUDA_R_32F, algo));
#else
// CUDA 7.5 does not support cublasGemmEx, hence we fall back to use hgemm
CUBlas<platform::float16>::GEMM(context_.cublas_handle(), cuTransB, cuTransA,
N, M, K, &h_alpha, h_B, ldb, h_A, lda,
&h_beta, h_C, N);
#endif // CUDA_VERSION >= 8000
}
template <>
template <typename T>
void Blas<platform::CUDADeviceContext>::GEMM(
const bool transA, const bool transB, const int M, const int N, const int K,
const T alpha, const T *A, const int lda, const T *B, const int ldb,
const T beta, T *C, const int ldc) const {
// Note that cublas follows fortran order, so the order is different from
// the cblas convention.
cublasOperation_t cuTransA = transA == false ? CUBLAS_OP_N : CUBLAS_OP_T;
cublasOperation_t cuTransB = transB == false ? CUBLAS_OP_N : CUBLAS_OP_T;
CUBlas<T>::GEMM(context_.cublas_handle(), cuTransB, cuTransA, N, M, K, &alpha,
B, ldb, A, lda, &beta, C, ldc);
}
} // namespace math
} // namespace operators
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/operators/math/math_function.h"
namespace paddle {
namespace operators {
namespace math {
template <typename T>
struct CBlas;
template <>
struct CBlas<float> {
template <typename... ARGS>
static void GEMM(ARGS... args) {
cblas_sgemm(args...);
}
};
template <>
struct CBlas<double> {
template <typename... ARGS>
static void GEMM(ARGS... args) {
cblas_dgemm(args...);
}
};
template <>
struct CBlas<platform::float16> {
static void GEMM(...) { PADDLE_THROW("float16 GEMM not supported on CPU"); }
};
template <>
template <typename T>
void Blas<platform::CPUDeviceContext>::GEMM(const CBLAS_TRANSPOSE transA,
const CBLAS_TRANSPOSE transB,
const int M, const int N,
const int K, const T alpha,
const T *A, const T *B,
const T beta, T *C) const {
int lda = (transA == CblasNoTrans) ? K : M;
int ldb = (transB == CblasNoTrans) ? N : K;
int ldc = N;
CBlas<T>::GEMM(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, ldb,
beta, C, ldc);
}
template <>
template <typename T>
void Blas<platform::CPUDeviceContext>::GEMM(
const bool transA, const bool transB, const int M, const int N, const int K,
const T alpha, const T *A, const int lda, const T *B, const int ldb,
const T beta, T *C, const int ldc) const {
CBlas<T>::GEMM(CblasRowMajor, transA == false ? CblasNoTrans : CblasTrans,
transB == false ? CblasNoTrans : CblasTrans, M, N, K, alpha, A,
lda, B, ldb, beta, C, ldc);
}
} // namespace math
} // namespace operators
} // namespace paddle
......@@ -25,21 +25,21 @@ struct GRUUnitFunctor<platform::CPUDeviceContext, T> {
const detail::ActivationType active_node,
const detail::ActivationType active_gate) {
#ifndef __NVCC__
auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);
if (value.prev_out_value) {
math::gemm<platform::CPUDeviceContext, T>(
context, false, false, batch_size, frame_size * 2, frame_size, 1,
value.prev_out_value, frame_size, value.gate_weight, frame_size * 2,
1, value.gate_value, frame_size * 3);
blas.GEMM(false, false, batch_size, frame_size * 2, frame_size, 1,
value.prev_out_value, frame_size, value.gate_weight,
frame_size * 2, 1, value.gate_value, frame_size * 3);
}
detail::forward_reset_output(detail::forward::gru_resetOutput<T>(), value,
frame_size, batch_size, active_gate);
if (value.prev_out_value) {
math::gemm<platform::CPUDeviceContext, T>(
context, false, false, batch_size, frame_size, frame_size, 1,
value.reset_output_value, frame_size, value.state_weight, frame_size,
1, value.gate_value + frame_size * 2, frame_size * 3);
blas.GEMM(false, false, batch_size, frame_size, frame_size, 1,
value.reset_output_value, frame_size, value.state_weight,
frame_size, 1, value.gate_value + frame_size * 2,
frame_size * 3);
}
detail::forward_final_output(detail::forward::gru_finalOutput<T>(), value,
......@@ -58,36 +58,32 @@ struct GRUUnitGradFunctor<platform::CPUDeviceContext, T> {
#ifndef __NVCC__
detail::backward_state_grad(detail::backward::gru_stateGrad<T>(), value,
grad, frame_size, batch_size, active_node);
auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);
if (value.prev_out_value && grad.prev_out_grad) {
math::gemm<platform::CPUDeviceContext, T>(
context, false, true, batch_size, frame_size, frame_size, 1,
grad.gate_grad + frame_size * 2, frame_size * 3, value.state_weight,
frame_size, 0, grad.reset_output_grad, frame_size);
blas.GEMM(false, true, batch_size, frame_size, frame_size, 1,
grad.gate_grad + frame_size * 2, frame_size * 3,
value.state_weight, frame_size, 0, grad.reset_output_grad,
frame_size);
if (grad.state_weight_grad) {
math::gemm<platform::CPUDeviceContext, T>(
context, true, false, frame_size, frame_size, batch_size, 1,
value.reset_output_value, frame_size,
grad.gate_grad + frame_size * 2, frame_size * 3, 1,
grad.state_weight_grad, frame_size);
blas.GEMM(true, false, frame_size, frame_size, batch_size, 1,
value.reset_output_value, frame_size,
grad.gate_grad + frame_size * 2, frame_size * 3, 1,
grad.state_weight_grad, frame_size);
}
}
detail::backward_reset_grad(detail::backward::gru_resetGrad<T>(), value,
grad, frame_size, batch_size, active_gate);
if (grad.prev_out_grad && value.prev_out_value) {
math::gemm<platform::CPUDeviceContext, T>(
context, false, true, batch_size, frame_size, frame_size * 2, 1,
grad.gate_grad, frame_size * 3, value.gate_weight, frame_size * 2, 1,
grad.prev_out_grad, frame_size);
blas.GEMM(false, true, batch_size, frame_size, frame_size * 2, 1,
grad.gate_grad, frame_size * 3, value.gate_weight,
frame_size * 2, 1, grad.prev_out_grad, frame_size);
if (grad.gate_weight_grad) {
math::gemm<platform::CPUDeviceContext, T>(
context, true, false, frame_size, frame_size * 2, batch_size, 1,
value.prev_out_value, frame_size, grad.gate_grad, frame_size * 3, 1,
grad.gate_weight_grad, frame_size * 2);
blas.GEMM(true, false, frame_size, frame_size * 2, batch_size, 1,
value.prev_out_value, frame_size, grad.gate_grad,
frame_size * 3, 1, grad.gate_weight_grad, frame_size * 2);
}
}
#endif
......
......@@ -9,6 +9,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <paddle/fluid/platform/device_context.h>
#include "paddle/fluid/operators/math/detail/gru_gpu_kernel.h"
#include "paddle/fluid/operators/math/detail/gru_kernel.h"
#include "paddle/fluid/operators/math/gru_compute.h"
......@@ -36,12 +37,11 @@ struct GRUUnitFunctor<platform::CUDADeviceContext, T> {
threads = dim3(32, 32);
grid = dim3((frame_size + 32 - 1) / 32, (batch_size + 32 - 1) / 32);
}
auto blas = math::GetBlas<platform::CUDADeviceContext, T>(context);
if (value.prev_out_value) {
math::gemm<platform::CUDADeviceContext, T>(
context, false, false, batch_size, frame_size * 2, frame_size, 1,
value.prev_out_value, frame_size, value.gate_weight, frame_size * 2,
1, value.gate_value, frame_size * 3);
blas.GEMM(false, false, batch_size, frame_size * 2, frame_size, 1,
value.prev_out_value, frame_size, value.gate_weight,
frame_size * 2, 1, value.gate_value, frame_size * 3);
}
if (batch_size == 1) {
......@@ -61,10 +61,10 @@ struct GRUUnitFunctor<platform::CUDADeviceContext, T> {
}
if (value.prev_out_value) {
math::gemm<platform::CUDADeviceContext, T>(
context, false, false, batch_size, frame_size, frame_size, 1,
value.reset_output_value, frame_size, value.state_weight, frame_size,
1, value.gate_value + frame_size * 2, frame_size * 3);
blas.GEMM(false, false, batch_size, frame_size, frame_size, 1,
value.reset_output_value, frame_size, value.state_weight,
frame_size, 1, value.gate_value + frame_size * 2,
frame_size * 3);
}
if (batch_size == 1) {
......@@ -121,18 +121,19 @@ struct GRUUnitGradFunctor<platform::CUDADeviceContext, T> {
grad.output_grad, frame_size, batch_size, active_node);
}
auto blas = math::GetBlas<platform::CUDADeviceContext, T>(context);
if (value.prev_out_value && grad.prev_out_grad) {
math::gemm<platform::CUDADeviceContext, T>(
context, false, true, batch_size, frame_size, frame_size, 1,
grad.gate_grad + frame_size * 2, frame_size * 3, value.state_weight,
frame_size, 0, grad.reset_output_grad, frame_size);
blas.GEMM(false, true, batch_size, frame_size, frame_size, 1,
grad.gate_grad + frame_size * 2, frame_size * 3,
value.state_weight, frame_size, 0, grad.reset_output_grad,
frame_size);
if (grad.state_weight_grad) {
math::gemm<platform::CUDADeviceContext, T>(
context, true, false, frame_size, frame_size, batch_size, 1,
value.reset_output_value, frame_size,
grad.gate_grad + frame_size * 2, frame_size * 3, 1,
grad.state_weight_grad, frame_size);
blas.GEMM(true, false, frame_size, frame_size, batch_size, 1,
value.reset_output_value, frame_size,
grad.gate_grad + frame_size * 2, frame_size * 3, 1,
grad.state_weight_grad, frame_size);
}
}
......@@ -153,16 +154,14 @@ struct GRUUnitGradFunctor<platform::CUDADeviceContext, T> {
}
if (grad.prev_out_grad && value.prev_out_value) {
math::gemm<platform::CUDADeviceContext, T>(
context, false, true, batch_size, frame_size, frame_size * 2, 1,
grad.gate_grad, frame_size * 3, value.gate_weight, frame_size * 2, 1,
grad.prev_out_grad, frame_size);
blas.GEMM(false, true, batch_size, frame_size, frame_size * 2, 1,
grad.gate_grad, frame_size * 3, value.gate_weight,
frame_size * 2, 1, grad.prev_out_grad, frame_size);
if (grad.gate_weight_grad) {
math::gemm<platform::CUDADeviceContext, T>(
context, true, false, frame_size, frame_size * 2, batch_size, 1,
value.prev_out_value, frame_size, grad.gate_grad, frame_size * 3, 1,
grad.gate_weight_grad, frame_size * 2);
blas.GEMM(true, false, frame_size, frame_size * 2, batch_size, 1,
value.prev_out_value, frame_size, grad.gate_grad,
frame_size * 3, 1, grad.gate_weight_grad, frame_size * 2);
}
}
}
......
......@@ -24,72 +24,6 @@ namespace math {
using float16 = paddle::platform::float16;
template <>
void gemm<platform::CPUDeviceContext, float16>(
const platform::CPUDeviceContext& context, const CBLAS_TRANSPOSE transA,
const CBLAS_TRANSPOSE transB, const int M, const int N, const int K,
const float16 alpha, const float16* A, const float16* B, const float16 beta,
float16* C) {
PADDLE_THROW("float16 GEMM not supported on CPU");
}
template <>
void gemm<platform::CPUDeviceContext, float>(
const platform::CPUDeviceContext& context, const CBLAS_TRANSPOSE transA,
const CBLAS_TRANSPOSE transB, const int M, const int N, const int K,
const float alpha, const float* A, const float* B, const float beta,
float* C) {
int lda = (transA == CblasNoTrans) ? K : M;
int ldb = (transB == CblasNoTrans) ? N : K;
int ldc = N;
cblas_sgemm(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, ldb,
beta, C, ldc);
}
template <>
void gemm<platform::CPUDeviceContext, double>(
const platform::CPUDeviceContext& context, const CBLAS_TRANSPOSE transA,
const CBLAS_TRANSPOSE transB, const int M, const int N, const int K,
const double alpha, const double* A, const double* B, const double beta,
double* C) {
int lda = (transA == CblasNoTrans) ? K : M;
int ldb = (transB == CblasNoTrans) ? N : K;
int ldc = N;
cblas_dgemm(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, ldb,
beta, C, ldc);
}
template <>
void gemm<platform::CPUDeviceContext, float16>(
const platform::CPUDeviceContext& context, const bool transA,
const bool transB, const int M, const int N, const int K,
const float16 alpha, const float16* A, const int lda, const float16* B,
const int ldb, const float16 beta, float16* C, const int ldc) {
PADDLE_THROW("float16 GEMM not supported on CPU");
}
template <>
void gemm<platform::CPUDeviceContext, float>(
const platform::CPUDeviceContext& context, const bool transA,
const bool transB, const int M, const int N, const int K, const float alpha,
const float* A, const int lda, const float* B, const int ldb,
const float beta, float* C, const int ldc) {
cblas_sgemm(CblasRowMajor, transA == false ? CblasNoTrans : CblasTrans,
transB == false ? CblasNoTrans : CblasTrans, M, N, K, alpha, A,
lda, B, ldb, beta, C, ldc);
}
template <>
void gemm<platform::CPUDeviceContext, double>(
const platform::CPUDeviceContext& context, const bool transA,
const bool transB, const int M, const int N, const int K,
const double alpha, const double* A, const int lda, const double* B,
const int ldb, const double beta, double* C, const int ldc) {
cblas_dgemm(CblasRowMajor, transA == false ? CblasNoTrans : CblasTrans,
transB == false ? CblasNoTrans : CblasTrans, M, N, K, alpha, A,
lda, B, ldb, beta, C, ldc);
}
template <>
void matmul<platform::CPUDeviceContext, float16>(
const platform::CPUDeviceContext& context,
......@@ -123,8 +57,8 @@ void matmul<platform::CPUDeviceContext, float>(
CBLAS_TRANSPOSE transA = (trans_a == false) ? CblasNoTrans : CblasTrans;
CBLAS_TRANSPOSE transB = (trans_b == false) ? CblasNoTrans : CblasTrans;
gemm<platform::CPUDeviceContext, float>(
context, transA, transB, M, N, K, alpha, matrix_a.data<float>(),
Blas<platform::CPUDeviceContext>(context).GEMM(
transA, transB, M, N, K, alpha, matrix_a.data<float>(),
matrix_b.data<float>(), beta, matrix_out->data<float>());
}
......@@ -152,8 +86,8 @@ void matmul<platform::CPUDeviceContext, double>(
CBLAS_TRANSPOSE transA = (trans_a == false) ? CblasNoTrans : CblasTrans;
CBLAS_TRANSPOSE transB = (trans_b == false) ? CblasNoTrans : CblasTrans;
gemm<platform::CPUDeviceContext, double>(
context, transA, transB, M, N, K, alpha, matrix_a.data<double>(),
Blas<platform::CPUDeviceContext>(context).GEMM(
transA, transB, M, N, K, alpha, matrix_a.data<double>(),
matrix_b.data<double>(), beta, matrix_out->data<double>());
}
......@@ -230,8 +164,8 @@ void batched_gemm<platform::CPUDeviceContext, float>(
const float* Ak = &A[k * strideA];
const float* Bk = &B[k * strideB];
float* Ck = &C[k * M * N];
gemm<platform::CPUDeviceContext, float>(context, transA, transB, M, N, K,
alpha, Ak, Bk, beta, Ck);
Blas<platform::CPUDeviceContext>(context).GEMM(transA, transB, M, N, K,
alpha, Ak, Bk, beta, Ck);
}
}
......@@ -246,8 +180,8 @@ void batched_gemm<platform::CPUDeviceContext, double>(
const double* Ak = &A[k * strideA];
const double* Bk = &B[k * strideB];
double* Ck = &C[k * M * N];
gemm<platform::CPUDeviceContext, double>(context, transA, transB, M, N, K,
alpha, Ak, Bk, beta, Ck);
Blas<platform::CPUDeviceContext>(context).GEMM(transA, transB, M, N, K,
alpha, Ak, Bk, beta, Ck);
}
}
#endif
......
......@@ -25,157 +25,6 @@ namespace math {
using float16 = paddle::platform::float16;
template <>
void gemm<platform::CUDADeviceContext, float16>(
const platform::CUDADeviceContext& context, const CBLAS_TRANSPOSE transA,
const CBLAS_TRANSPOSE transB, const int M, const int N, const int K,
const float16 alpha, const float16* A, const float16* B, const float16 beta,
float16* C) {
// Note that cublas follows fortran order, so the order is different from
// the cblas convention.
int lda = (transA == CblasNoTrans) ? K : M;
int ldb = (transB == CblasNoTrans) ? N : K;
cublasOperation_t cuTransA =
(transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
cublasOperation_t cuTransB =
(transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
// TODO(kexinzhao): add processing code for compute capability < 53 case
PADDLE_ENFORCE_GE(context.GetComputeCapability(), 53,
"cublas fp16 gemm requires GPU compute capability >= 53");
#if CUDA_VERSION >= 8000
float h_alpha = static_cast<float>(alpha);
float h_beta = static_cast<float>(beta);
cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT;
#if CUDA_VERSION >= 9000
if (context.GetComputeCapability() >= 70) {
PADDLE_ENFORCE(platform::dynload::cublasSetMathMode(context.cublas_handle(),
CUBLAS_TENSOR_OP_MATH));
algo = CUBLAS_GEMM_DFALT_TENSOR_OP;
} else {
PADDLE_ENFORCE(platform::dynload::cublasSetMathMode(context.cublas_handle(),
CUBLAS_DEFAULT_MATH));
}
#endif // CUDA_VERSION >= 9000
// cublasHgemm does true FP16 computation which is slow for non-Volta
// GPUs. So use cublasGemmEx instead which does pesudo FP16 computation:
// input/output in fp16, computation in fp32, which can also be accelerated
// using tensor cores in volta GPUs.
PADDLE_ENFORCE(platform::dynload::cublasGemmEx(
context.cublas_handle(), cuTransB, cuTransA, N, M, K, &h_alpha, B,
CUDA_R_16F, ldb, A, CUDA_R_16F, lda, &h_beta, C, CUDA_R_16F, N,
CUDA_R_32F, algo));
#else
// CUDA 7.5 does not support cublasGemmEx, hence we fall back to use hgemm
const half h_alpha = static_cast<const half>(alpha);
const half h_beta = static_cast<const half>(beta);
const half* h_A = reinterpret_cast<const half*>(A);
const half* h_B = reinterpret_cast<const half*>(B);
half* h_C = reinterpret_cast<half*>(C);
PADDLE_ENFORCE(platform::dynload::cublasHgemm(
context.cublas_handle(), cuTransB, cuTransA, N, M, K, &h_alpha, h_B, ldb,
h_A, lda, &h_beta, h_C, N));
#endif // CUDA_VERSION >= 8000
}
template <>
void gemm<platform::CUDADeviceContext, float>(
const platform::CUDADeviceContext& context, const CBLAS_TRANSPOSE transA,
const CBLAS_TRANSPOSE transB, const int M, const int N, const int K,
const float alpha, const float* A, const float* B, const float beta,
float* C) {
// Note that cublas follows fortran order, so the order is different from
// the cblas convention.
int lda = (transA == CblasNoTrans) ? K : M;
int ldb = (transB == CblasNoTrans) ? N : K;
cublasOperation_t cuTransA =
(transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
cublasOperation_t cuTransB =
(transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
PADDLE_ENFORCE(platform::dynload::cublasSgemm(
context.cublas_handle(), cuTransB, cuTransA, N, M, K, &alpha, B, ldb, A,
lda, &beta, C, N));
}
template <>
void gemm<platform::CUDADeviceContext, double>(
const platform::CUDADeviceContext& context, const CBLAS_TRANSPOSE transA,
const CBLAS_TRANSPOSE transB, const int M, const int N, const int K,
const double alpha, const double* A, const double* B, const double beta,
double* C) {
// Note that cublas follows fortran order, so the order is different from
// the cblas convention.
int lda = (transA == CblasNoTrans) ? K : M;
int ldb = (transB == CblasNoTrans) ? N : K;
cublasOperation_t cuTransA =
(transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
cublasOperation_t cuTransB =
(transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
PADDLE_ENFORCE(platform::dynload::cublasDgemm(
context.cublas_handle(), cuTransB, cuTransA, N, M, K, &alpha, B, ldb, A,
lda, &beta, C, N));
}
template <>
void gemm<platform::CUDADeviceContext, float16>(
const platform::CUDADeviceContext& context, const bool transA,
const bool transB, const int M, const int N, const int K,
const float16 alpha, const float16* A, const int lda, const float16* B,
const int ldb, const float16 beta, float16* C, const int ldc) {
// Note that cublas follows fortran order, so the order is different from
// the cblas convention.
cublasOperation_t cuTransA = transA == false ? CUBLAS_OP_N : CUBLAS_OP_T;
cublasOperation_t cuTransB = transB == false ? CUBLAS_OP_N : CUBLAS_OP_T;
const half h_alpha = static_cast<const half>(alpha);
const half h_beta = static_cast<const half>(beta);
const half* h_A = reinterpret_cast<const half*>(A);
const half* h_B = reinterpret_cast<const half*>(B);
half* h_C = reinterpret_cast<half*>(C);
// TODO(kexinzhao): add processing code for compute capability < 53 case
PADDLE_ENFORCE_GE(context.GetComputeCapability(), 53,
"cublas Hgemm requires GPU compute capability >= 53");
PADDLE_ENFORCE(platform::dynload::cublasHgemm(
context.cublas_handle(), cuTransB, cuTransA, N, M, K, &h_alpha, h_B, ldb,
h_A, lda, &h_beta, h_C, ldc));
}
template <>
void gemm<platform::CUDADeviceContext, float>(
const platform::CUDADeviceContext& context, const bool transA,
const bool transB, const int M, const int N, const int K, const float alpha,
const float* A, const int lda, const float* B, const int ldb,
const float beta, float* C, const int ldc) {
// Note that cublas follows fortran order, so the order is different from
// the cblas convention.
cublasOperation_t cuTransA = transA == false ? CUBLAS_OP_N : CUBLAS_OP_T;
cublasOperation_t cuTransB = transB == false ? CUBLAS_OP_N : CUBLAS_OP_T;
PADDLE_ENFORCE(platform::dynload::cublasSgemm(
context.cublas_handle(), cuTransB, cuTransA, N, M, K, &alpha, B, ldb, A,
lda, &beta, C, ldc));
}
template <>
void gemm<platform::CUDADeviceContext, double>(
const platform::CUDADeviceContext& context, const bool transA,
const bool transB, const int M, const int N, const int K,
const double alpha, const double* A, const int lda, const double* B,
const int ldb, const double beta, double* C, const int ldc) {
// Note that cublas follows fortran order, so the order is different from
// the cblas convention.
cublasOperation_t cuTransA = transA == false ? CUBLAS_OP_N : CUBLAS_OP_T;
cublasOperation_t cuTransB = transB == false ? CUBLAS_OP_N : CUBLAS_OP_T;
PADDLE_ENFORCE(platform::dynload::cublasDgemm(
context.cublas_handle(), cuTransB, cuTransA, N, M, K, &alpha, B, ldb, A,
lda, &beta, C, ldc));
}
template <>
void matmul<platform::CUDADeviceContext, float16>(
const platform::CUDADeviceContext& context,
......@@ -200,8 +49,8 @@ void matmul<platform::CUDADeviceContext, float16>(
CBLAS_TRANSPOSE transA = (trans_a == false) ? CblasNoTrans : CblasTrans;
CBLAS_TRANSPOSE transB = (trans_b == false) ? CblasNoTrans : CblasTrans;
gemm<platform::CUDADeviceContext, float16>(
context, transA, transB, M, N, K, alpha, matrix_a.data<float16>(),
Blas<platform::CUDADeviceContext>(context).GEMM(
transA, transB, M, N, K, alpha, matrix_a.data<float16>(),
matrix_b.data<float16>(), beta, matrix_out->data<float16>());
}
......@@ -229,8 +78,8 @@ void matmul<platform::CUDADeviceContext, float>(
CBLAS_TRANSPOSE transA = (trans_a == false) ? CblasNoTrans : CblasTrans;
CBLAS_TRANSPOSE transB = (trans_b == false) ? CblasNoTrans : CblasTrans;
gemm<platform::CUDADeviceContext, float>(
context, transA, transB, M, N, K, alpha, matrix_a.data<float>(),
Blas<platform::CUDADeviceContext>(context).GEMM(
transA, transB, M, N, K, alpha, matrix_a.data<float>(),
matrix_b.data<float>(), beta, matrix_out->data<float>());
}
......@@ -258,8 +107,8 @@ void matmul<platform::CUDADeviceContext, double>(
CBLAS_TRANSPOSE transA = (trans_a == false) ? CblasNoTrans : CblasTrans;
CBLAS_TRANSPOSE transB = (trans_b == false) ? CblasNoTrans : CblasTrans;
gemm<platform::CUDADeviceContext, double>(
context, transA, transB, M, N, K, alpha, matrix_a.data<double>(),
Blas<platform::CUDADeviceContext>(context).GEMM(
transA, transB, M, N, K, alpha, matrix_a.data<double>(),
matrix_b.data<double>(), beta, matrix_out->data<double>());
}
......
......@@ -42,6 +42,7 @@ int LAPACKE_dgetri(int matrix_layout, int n, double* a, int lda,
#include <vector>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/platform/device_context.h"
......@@ -56,17 +57,48 @@ namespace math {
// Then matrixA: M * K, matrixB: K * N, matrixC : M * N
// For more detailed info, please refer to
// http://www.netlib.org/lapack/explore-html/d4/de2/sgemm_8f.html
template <typename DeviceContext>
class Blas {
public:
explicit Blas(const DeviceContext& context) : context_(context) {}
template <typename T>
void GEMM(const CBLAS_TRANSPOSE transA, const CBLAS_TRANSPOSE transB,
const int M, const int N, const int K, const T alpha, const T* A,
const T* B, const T beta, T* C) const;
template <typename T>
void GEMM(const bool transA, const bool transB, const int M, const int N,
const int K, const T alpha, const T* A, const int lda, const T* B,
const int ldb, const T beta, T* C, const int ldc) const;
private:
const DeviceContext& context_;
};
template <typename DeviceContext, typename T>
void gemm(const DeviceContext& context, const CBLAS_TRANSPOSE transA,
const CBLAS_TRANSPOSE transB, const int M, const int N, const int K,
const T alpha, const T* A, const T* B, const T beta, T* C);
class BlasT : private Blas<DeviceContext> {
public:
using Blas<DeviceContext>::Blas;
template <typename... ARGS>
void GEMM(ARGS... args) const {
static_cast<const Blas<DeviceContext>*>(this)->template GEMM<T>(args...);
}
};
// gemm wrapper with stride args for matrix uncontinuous in memory
template <typename DeviceContext, typename T>
void gemm(const DeviceContext& context, const bool transA, const bool transB,
const int M, const int N, const int K, const T alpha, const T* A,
const int lda, const T* B, const int ldb, const T beta, T* C,
const int ldc);
inline BlasT<DeviceContext, T> GetBlas(
const framework::ExecutionContext& exe_ctx) {
return BlasT<DeviceContext, T>(
exe_ctx.template device_context<DeviceContext>());
}
template <typename DeviceContext, typename T>
inline BlasT<DeviceContext, T> GetBlas(const DeviceContext& dev_ctx) {
return BlasT<DeviceContext, T>(dev_ctx);
}
// matrix multiply with continuous memory
template <typename DeviceContext, typename T>
......@@ -137,3 +169,8 @@ struct RowwiseMean {
} // namespace math
} // namespace operators
} // namespace paddle
#include "paddle/fluid/operators/math/blas_impl.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/operators/math/blas_impl.cu.h"
#endif
......@@ -14,6 +14,13 @@
#include "paddle/fluid/operators/math/math_function.h"
#include "gtest/gtest.h"
template <typename T>
inline paddle::operators::math::BlasT<paddle::platform::CPUDeviceContext, T>
GetBlas(const paddle::platform::CPUDeviceContext& context) {
return paddle::operators::math::GetBlas<paddle::platform::CPUDeviceContext,
T>(context);
}
TEST(math_function, gemm_notrans_cblas) {
paddle::framework::Tensor input1;
paddle::framework::Tensor input2;
......@@ -34,9 +41,8 @@ TEST(math_function, gemm_notrans_cblas) {
memcpy(input3_ptr, arr3, 8 * sizeof(float));
paddle::platform::CPUDeviceContext context(*cpu_place);
paddle::operators::math::gemm<paddle::platform::CPUDeviceContext, float>(
context, false, false, m, n, k, 1, input1_ptr, 3, input2_ptr + 1, 4, 1,
input3_ptr + 1, 4);
GetBlas<float>(context).GEMM(false, false, m, n, k, 1, input1_ptr, 3,
input2_ptr + 1, 4, 1, input3_ptr + 1, 4);
EXPECT_EQ(input3_ptr[0], 0);
EXPECT_EQ(input3_ptr[1], 24);
......@@ -68,9 +74,8 @@ TEST(math_function, gemm_trans_clbas) {
memcpy(input3_ptr, arr3, 8 * sizeof(float));
paddle::platform::CPUDeviceContext context(*cpu_place);
paddle::operators::math::gemm<paddle::platform::CPUDeviceContext, float>(
context, false, true, m, n, k, 1, input1_ptr, 3, input2_ptr + 3, 3, 1,
input3_ptr + 1, 4);
GetBlas<float>(context).GEMM(false, true, m, n, k, 1, input1_ptr, 3,
input2_ptr + 3, 3, 1, input3_ptr + 1, 4);
EXPECT_EQ(input3_ptr[0], 0);
EXPECT_EQ(input3_ptr[1], 24);
......
......@@ -13,6 +13,7 @@
// limitations under the License.
#include "gtest/gtest.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/device_context.h"
void fill_fp16_data(paddle::platform::float16* in_ptr, size_t size,
const std::vector<float>& data) {
......@@ -178,6 +179,13 @@ TEST(math_function, trans_mul_notrans_fp16) {
EXPECT_EQ(static_cast<float>(out_ptr[8]), 29);
}
template <typename T>
inline paddle::operators::math::BlasT<paddle::platform::CUDADeviceContext, T>
GetBlas(const paddle::platform::CUDADeviceContext& context) {
return paddle::operators::math::GetBlas<paddle::platform::CUDADeviceContext,
T>(context);
}
TEST(math_function, gemm_notrans_cublas_fp32) {
paddle::framework::Tensor input1;
paddle::framework::Tensor input2;
......@@ -210,8 +218,8 @@ TEST(math_function, gemm_notrans_cublas_fp32) {
float* b = input2_gpu.data<float>();
float* c = input3_gpu.mutable_data<float>(gpu_place);
paddle::operators::math::gemm<paddle::platform::CUDADeviceContext, float>(
context, false, false, m, n, k, 1, a, 3, b + 1, 4, 1, c + 1, 4);
GetBlas<float>(context).GEMM(false, false, m, n, k, 1, a, 3, b + 1, 4, 1,
c + 1, 4);
paddle::framework::TensorCopySync(input3_gpu, cpu_place, &input3);
......@@ -271,10 +279,9 @@ TEST(math_function, gemm_notrans_cublas_fp16) {
paddle::platform::float16* c =
input3_gpu.mutable_data<paddle::platform::float16>(gpu_place);
paddle::operators::math::gemm<paddle::platform::CUDADeviceContext,
paddle::platform::float16>(
context, false, false, m, n, k, paddle::platform::float16(1), a, 3, b + 1,
4, paddle::platform::float16(1), c + 1, 4);
GetBlas<paddle::platform::float16>(context).GEMM(
false, false, m, n, k, static_cast<paddle::platform::float16>(1), a, 3,
b + 1, 4, static_cast<paddle::platform::float16>(1), c + 1, 4);
paddle::framework::TensorCopySync(input3_gpu, cpu_place, &input3);
......@@ -327,8 +334,8 @@ TEST(math_function, gemm_trans_cublas_fp32) {
float* b = input2_gpu.data<float>();
float* c = input3_gpu.mutable_data<float>(gpu_place);
paddle::operators::math::gemm<paddle::platform::CUDADeviceContext, float>(
context, false, true, m, n, k, 1, a, 3, b + 3, 3, 1, c + 1, 4);
GetBlas<float>(context).GEMM(false, true, m, n, k, 1, a, 3, b + 3, 3, 1,
c + 1, 4);
paddle::framework::TensorCopySync(input3_gpu, cpu_place, &input3);
......@@ -382,10 +389,9 @@ TEST(math_function, gemm_trans_cublas_fp16) {
paddle::platform::float16* c =
input3_gpu.mutable_data<paddle::platform::float16>(gpu_place);
paddle::operators::math::gemm<paddle::platform::CUDADeviceContext,
paddle::platform::float16>(
context, false, true, m, n, k, paddle::platform::float16(1), a, 3, b + 3,
3, paddle::platform::float16(1), c + 1, 4);
GetBlas<paddle::platform::float16>(context).GEMM(
false, true, m, n, k, static_cast<paddle::platform::float16>(1), a, 3,
b + 3, 3, static_cast<paddle::platform::float16>(1), c + 1, 4);
paddle::framework::TensorCopySync(input3_gpu, cpu_place, &input3);
......
......@@ -131,8 +131,9 @@ class MatMulFunctor {
if (!batchCount) {
// regular matrix multiplication
gemm<DeviceContext, T>(context, transA, transB, M, N, kA, alpha,
a.data<T>(), b.data<T>(), beta, out->data<T>());
Blas<DeviceContext>(context).GEMM(transA, transB, M, N, kA, alpha,
a.data<T>(), b.data<T>(), beta,
out->data<T>());
} else {
// batched matrix multiplication
batched_gemm<DeviceContext, T>(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册