Commit 7307b439, authored by: Q qijun

fix gpu build error

Parent a821fec1
CMakeLists.txt
@@ -36,8 +36,8 @@ include(simd)
 ################################ Configurations #######################################
 option(WITH_GPU "Compile PaddlePaddle with NVIDIA GPU" ${CUDA_FOUND})
 option(WITH_AVX "Compile PaddlePaddle with AVX intrinsics" ${AVX_FOUND})
-option(WITH_MKLDNN "Compile PaddlePaddle with mkl-dnn support." ${AVX_FOUND})
-option(WITH_MKLML "Compile PaddlePaddle with mklml package." ${AVX_FOUND})
+option(WITH_MKLDNN "Compile PaddlePaddle with mkl-dnn support." OFF)
+option(WITH_MKLML "Compile PaddlePaddle with mklml package." OFF)
 option(WITH_DSO "Compile PaddlePaddle with dynamic linked CUDA" ON)
 option(WITH_TESTING "Compile PaddlePaddle with unit testing" ON)
 option(WITH_SWIG_PY "Compile PaddlePaddle with inference api" ON)
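Both defaults move from ${AVX_FOUND} to OFF, so the mkl-dnn and mklml integrations become strictly opt-in; a build that wants them can still enable them explicitly by passing -DWITH_MKLDNN=ON and -DWITH_MKLML=ON on the cmake command line.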
paddle/operators/math/math_function.cu
@@ -13,7 +13,6 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/operators/math/math_function.h"
-
 namespace paddle {
 namespace operators {
 namespace math {
@@ -26,6 +25,8 @@ void gemm<platform::GPUPlace, float>(
     platform::DeviceContext* context) {
   // Note that cublas follows fortran order, so the order is different from
   // the cblas convention.
+  int lda = (transA == CblasNoTrans) ? K : M;
+  int ldb = (transB == CblasNoTrans) ? N : K;
   cublasOperation_t cuTransA =
       (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
   cublasOperation_t cuTransB =
@@ -44,6 +45,8 @@ void gemm<platform::GPUPlace, double>(
     const int ldc, platform::DeviceContext* context) {
   // Note that cublas follows fortran order, so the order is different from
   // the cblas convention.
+  lda = (transA == CblasNoTrans) ? K : M;
+  ldb = (transB == CblasNoTrans) ? N : K;
   cublasOperation_t cuTransA =
       (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
   cublasOperation_t cuTransB =
@@ -118,7 +121,6 @@ void matmul<platform::GPUPlace, double>(const framework::Tensor& in1,
                               in1.data<double>(), K, in2.data<double>(), N,
                               beta, out->data<double>(), N, context);
 }
-
 }  // namespace math
 }  // namespace operators
 }  // namespace paddle
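The lda/ldb lines added in both gemm specializations exist because cuBLAS assumes column-major (Fortran-order) storage while this interface is row-major, as the in-code comment notes. A minimal sketch of the underlying trick, assuming plain C = A * B with no transposes (illustrative only, not code from this commit):

#include <cublas_v2.h>

// Row-major C (MxN) = A (MxK) * B (KxN). Viewed column-major this is
// C^T = B^T * A^T, so the operands are passed in swapped order with M and N
// exchanged; each row-major matrix's leading dimension is its column count.
void RowMajorSgemm(cublasHandle_t handle, int M, int N, int K, const float* A,
                   const float* B, float* C) {
  const float alpha = 1.0f;
  const float beta = 0.0f;
  cublasSgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, N, M, K, &alpha, B, /*ldb=*/N,
              A, /*lda=*/K, &beta, C, /*ldc=*/N);
}

The matmul call visible above follows the same pattern, passing K and N as the leading dimensions of the two inputs and N for the output.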
paddle/operators/math/math_function.h
@@ -37,6 +37,20 @@ extern "C" {
 #include <lapacke.h>
 #endif
 
+#ifndef LAPACK_FOUND
+extern "C" {
+#include <cblas.h>
+int LAPACKE_sgetrf(
+    int matrix_layout, int m, int n, float* a, int lda, int* ipiv);
+int LAPACKE_dgetrf(
+    int matrix_layout, int m, int n, double* a, int lda, int* ipiv);
+int LAPACKE_sgetri(
+    int matrix_layout, int n, float* a, int lda, const int* ipiv);
+int LAPACKE_dgetri(
+    int matrix_layout, int n, double* a, int lda, const int* ipiv);
+}
+#endif
+
 #include <cmath>
 #include "paddle/framework/tensor.h"
 #include "paddle/platform/device_context.h"
@@ -61,7 +75,7 @@ void gemm(const CBLAS_TRANSPOSE transA,
           const int ldc,
           platform::DeviceContext* context);
 
-// matrix multiply with continous memory
+// matrix multiply with continuous memory
 template <typename Place, typename T>
 void matmul(const framework::Tensor& in1,
             bool in1_T,
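The new #ifndef LAPACK_FOUND block declares the four LAPACKE entry points by hand, so translation units still compile when no lapacke.h header is installed; the symbols are then resolved at link time by whichever LAPACK implementation is linked in. A hedged usage sketch of the sgetrf/sgetri pair, which together invert a square row-major matrix in place (not part of this commit; error handling kept minimal):

#include <vector>

// LU-factorize a, then form the inverse from the factors. 101 is the value
// of LAPACK_ROW_MAJOR in the standard LAPACKE headers, written as a literal
// here because <lapacke.h> may be absent on this build path.
int InvertInPlace(float* a, int n) {
  std::vector<int> ipiv(n);  // pivot indices from the LU factorization
  int info = LAPACKE_sgetrf(101, n, n, a, n, ipiv.data());
  if (info != 0) return info;  // nonzero: singular input or bad argument
  return LAPACKE_sgetri(101, n, a, n, ipiv.data());
}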
paddle/operators/mul_op.cu
@@ -15,4 +15,5 @@
 #define EIGEN_USE_GPU
 #include "paddle/operators/mul_op.h"
 
+namespace ops = paddle::operators;
 REGISTER_OP_GPU_KERNEL(mul, ops::MulKernel<paddle::platform::GPUPlace, float>);
paddle/operators/mul_op.h
@@ -31,9 +31,6 @@ template <typename Place, typename T>
 class MulKernel : public framework::OpKernel {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
-    Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> dim_pair = {
-        {Eigen::IndexPair<Eigen::DenseIndex>(1, 0)}};
-
     auto input0 = context.Input<Tensor>("X");
     auto input1 = context.Input<Tensor>("Y");
     auto output = context.Output<Tensor>(0);
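Two small fixes close out the commit: mul_op.cu gains the ops namespace alias that the REGISTER_OP_GPU_KERNEL line expands against, and mul_op.h drops a dim_pair that is no longer used, left over from an earlier Eigen-based multiply. For context, a sketch of what such a pair expresses in Eigen's unsupported tensor module, where contracting dimension 1 of A against dimension 0 of B is a matrix product (illustrative only, not this kernel's code):

#include <unsupported/Eigen/CXX11/Tensor>

// C = A * B written as an Eigen tensor contraction: pair A's dimension 1
// (columns) with B's dimension 0 (rows).
Eigen::Tensor<float, 2> MatMul(const Eigen::Tensor<float, 2>& a,
                               const Eigen::Tensor<float, 2>& b) {
  Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> dim_pair = {
      {Eigen::IndexPair<Eigen::DenseIndex>(1, 0)}};
  return a.contract(b, dim_pair);
}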