Commit 8de4e3bd authored by Q qijun

disable gpu implementation temporarily

Parent a821fec1
@@ -26,6 +26,7 @@ void gemm<platform::GPUPlace, float>(
     platform::DeviceContext* context) {
   // Note that cublas follows fortran order, so the order is different from
   // the cblas convention.
+  /*
   cublasOperation_t cuTransA =
       (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
   cublasOperation_t cuTransB =
@@ -34,6 +35,8 @@ void gemm<platform::GPUPlace, float>(
   PADDLE_ENFORCE(platform::dynload::cublasSgemm(
       reinterpret_cast<platform::CUDADeviceContext*>(context)->cublas_handle(),
       cuTransB, cuTransA, N, M, K, &alpha, B, ldb, A, lda, &beta, C, ldc));
+  */
+  PADDLE_THROW("not implemented now");
 }
 
 template <>
@@ -44,6 +47,7 @@ void gemm<platform::GPUPlace, double>(
     const int ldc, platform::DeviceContext* context) {
   // Note that cublas follows fortran order, so the order is different from
   // the cblas convention.
+  /*
   cublasOperation_t cuTransA =
       (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
   cublasOperation_t cuTransB =
@@ -51,6 +55,8 @@ void gemm<platform::GPUPlace, double>(
   PADDLE_ENFORCE(platform::dynload::cublasDgemm(
       reinterpret_cast<platform::CUDADeviceContext*>(context)->cublas_handle(),
       cuTransB, cuTransA, N, M, K, &alpha, B, ldb, A, lda, &beta, C, ldc));
+  */
+  PADDLE_THROW("not implemented now");
 }
 
 template <>
...
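A note on the code being disabled here: cuBLAS, following Fortran, stores matrices column-major, while the cblas-style API above is row-major. A row-major matrix read column-major is its transpose, so the disabled calls compute C^T = B^T * A^T by passing B before A and swapping M and N. A minimal standalone sketch of the same trick, with an assumed helper name and no Paddle wrappers:

#include <cublas_v2.h>

// Minimal sketch (assumed helper, not from the commit): row-major
// C(MxN) = alpha * A(MxK) * B(KxN) + beta * C via column-major cuBLAS.
// Requesting C^T = B^T * A^T (operands swapped, M and N swapped) makes
// cuBLAS write exactly the row-major C we want.
cublasStatus_t RowMajorSgemm(cublasHandle_t handle, int M, int N, int K,
                             float alpha, const float* A, const float* B,
                             float beta, float* C) {
  const int lda = K;  // row stride of row-major A
  const int ldb = N;  // row stride of row-major B
  const int ldc = N;  // row stride of row-major C
  return cublasSgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, N, M, K, &alpha, B,
                     ldb, A, lda, &beta, C, ldc);
}

A production wrapper would check the returned status, which is what the PADDLE_ENFORCE in the commented-out code does.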
@@ -40,36 +40,23 @@ extern "C" {
 #include <cmath>
 #include "paddle/framework/tensor.h"
 #include "paddle/platform/device_context.h"
+#include "paddle/platform/enforce.h"
 
 namespace paddle {
 namespace operators {
 namespace math {
 
 template <typename Place, typename T>
-void gemm(const CBLAS_TRANSPOSE transA,
-          const CBLAS_TRANSPOSE transB,
-          const int M,
-          const int N,
-          const int K,
-          const T alpha,
-          const T* A,
-          const int lda,
-          const T* B,
-          const int ldb,
-          const T beta,
-          T* C,
-          const int ldc,
-          platform::DeviceContext* context);
+void gemm(const CBLAS_TRANSPOSE transA, const CBLAS_TRANSPOSE transB,
+          const int M, const int N, const int K, const T alpha, const T* A,
+          const int lda, const T* B, const int ldb, const T beta, T* C,
+          const int ldc, platform::DeviceContext* context);
 
 // matrix multiply with continuous memory
 template <typename Place, typename T>
-void matmul(const framework::Tensor& in1,
-            bool in1_T,
-            const framework::Tensor& in2,
-            bool in2_T,
-            float alpha,
-            framework::Tensor* out,
-            float beta,
-            platform::DeviceContext* context);
+void matmul(const framework::Tensor& in1, bool in1_T,
+            const framework::Tensor& in2, bool in2_T, float alpha,
+            framework::Tensor* out, float beta,
+            platform::DeviceContext* context);
 
 }  // namespace math
...
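The reflow above changes only formatting, not the interface. For orientation, a hypothetical caller of the declared gemm (the wrapper name and header path below are illustrative, not from the commit) could look like:

#include "paddle/operators/math/math_function.h"  // assumed header path

// Hypothetical caller (sketch): computes row-major C(MxN) = A(MxK) * B(KxN)
// through the CPUPlace specialization declared above.
void MulRowMajor(const float* a_data, const float* b_data, float* c_data,
                 int M, int N, int K, paddle::platform::DeviceContext* ctx) {
  paddle::operators::math::gemm<paddle::platform::CPUPlace, float>(
      CblasNoTrans, CblasNoTrans, M, N, K, 1.0f, a_data, /*lda=*/K, b_data,
      /*ldb=*/N, 0.0f, c_data, /*ldc=*/N, ctx);
}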
@@ -15,4 +15,5 @@
 #define EIGEN_USE_GPU
 #include "paddle/operators/mul_op.h"
 
-REGISTER_OP_GPU_KERNEL(mul, ops::MulKernel<paddle::platform::GPUPlace, float>);
+// REGISTER_OP_GPU_KERNEL(mul, ops::MulKernel<paddle::platform::GPUPlace,
+// float>);
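With the registration commented out, mul simply has no GPU kernel, and any code that still reaches the GPUPlace gemm hits the PADDLE_THROW added above. A hedged sketch of what a caller-side fallback could look like (all names, buffers, and the exception type below are assumptions, not part of the commit):

#include <exception>
#include "paddle/operators/math/math_function.h"  // assumed header path

// Hypothetical fallback (illustrative only): the GPUPlace specialization now
// throws via PADDLE_THROW("not implemented now"), so a caller holding both
// device and host buffers could retry on the still-working CPUPlace path.
// The thrown type is assumed to derive from std::exception.
void GemmWithFallback(int M, int N, int K, const float* d_a, const float* d_b,
                      float* d_c, const float* h_a, const float* h_b,
                      float* h_c, paddle::platform::DeviceContext* gpu_ctx,
                      paddle::platform::DeviceContext* cpu_ctx) {
  try {
    paddle::operators::math::gemm<paddle::platform::GPUPlace, float>(
        CblasNoTrans, CblasNoTrans, M, N, K, 1.0f, d_a, K, d_b, N, 0.0f, d_c,
        N, gpu_ctx);
  } catch (const std::exception&) {
    // GPU path disabled by this commit -- run the CPU kernel on host copies.
    paddle::operators::math::gemm<paddle::platform::CPUPlace, float>(
        CblasNoTrans, CblasNoTrans, M, N, K, 1.0f, h_a, K, h_b, N, 0.0f, h_c,
        N, cpu_ctx);
  }
}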