提交 8de4e3bd 编写于 作者: Q qijun

disable gpu implementation temporarily

上级 a821fec1
......@@ -26,6 +26,7 @@ void gemm<platform::GPUPlace, float>(
platform::DeviceContext* context) {
// Note that cublas follows fortran order, so the order is different from
// the cblas convention.
/*
cublasOperation_t cuTransA =
(transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
cublasOperation_t cuTransB =
......@@ -34,6 +35,8 @@ void gemm<platform::GPUPlace, float>(
PADDLE_ENFORCE(platform::dynload::cublasSgemm(
reinterpret_cast<platform::CUDADeviceContext*>(context)->cublas_handle(),
cuTransB, cuTransA, N, M, K, &alpha, B, ldb, A, lda, &beta, C, ldc));
*/
PADDLE_THROW("not implemented now");
}
template <>
......@@ -44,6 +47,7 @@ void gemm<platform::GPUPlace, double>(
const int ldc, platform::DeviceContext* context) {
// Note that cublas follows fortran order, so the order is different from
// the cblas convention.
/*
cublasOperation_t cuTransA =
(transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
cublasOperation_t cuTransB =
......@@ -51,6 +55,8 @@ void gemm<platform::GPUPlace, double>(
PADDLE_ENFORCE(platform::dynload::cublasDgemm(
reinterpret_cast<platform::CUDADeviceContext*>(context)->cublas_handle(),
cuTransB, cuTransA, N, M, K, &alpha, B, ldb, A, lda, &beta, C, ldc));
*/
PADDLE_THROW("not implemented now");
}
template <>
......
......@@ -40,36 +40,23 @@ extern "C" {
#include <cmath>
#include "paddle/framework/tensor.h"
#include "paddle/platform/device_context.h"
#include "paddle/platform/enforce.h"
namespace paddle {
namespace operators {
namespace math {
template <typename Place, typename T>
void gemm(const CBLAS_TRANSPOSE transA,
const CBLAS_TRANSPOSE transB,
const int M,
const int N,
const int K,
const T alpha,
const T* A,
const int lda,
const T* B,
const int ldb,
const T beta,
T* C,
const int ldc,
platform::DeviceContext* context);
void gemm(const CBLAS_TRANSPOSE transA, const CBLAS_TRANSPOSE transB,
const int M, const int N, const int K, const T alpha, const T* A,
const int lda, const T* B, const int ldb, const T beta, T* C,
const int ldc, platform::DeviceContext* context);
// matrix multiply with continuous memory
template <typename Place, typename T>
void matmul(const framework::Tensor& in1,
bool in1_T,
const framework::Tensor& in2,
bool in2_T,
float alpha,
framework::Tensor* out,
float beta,
void matmul(const framework::Tensor& in1, bool in1_T,
const framework::Tensor& in2, bool in2_T, float alpha,
framework::Tensor* out, float beta,
platform::DeviceContext* context);
} // namespace math
......
......@@ -15,4 +15,5 @@
#define EIGEN_USE_GPU
#include "paddle/operators/mul_op.h"
REGISTER_OP_GPU_KERNEL(mul, ops::MulKernel<paddle::platform::GPUPlace, float>);
// REGISTER_OP_GPU_KERNEL(mul, ops::MulKernel<paddle::platform::GPUPlace,
// float>);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册