提交 fa54fb33 编写于 作者: Q QI JUN 提交者: GitHub

Merge pull request #3209 from QiJune/port_blas

port gemm to new framework
......@@ -264,6 +264,10 @@ class ExecutionContext : public InferShapeContext {
platform::Place GetPlace() const { return device_context_->GetPlace(); }
const platform::DeviceContext* device_context() const {
return device_context_;
}
const platform::DeviceContext* device_context_;
};
......
......@@ -105,6 +105,8 @@ class Tensor {
template <typename T>
inline Tensor Slice(const int& begin_idx, const int& end_idx) const;
platform::Place place() const { return holder_->place(); }
private:
template <typename T>
inline void check_memory_size() const;
......
......@@ -41,6 +41,7 @@ function(op_library TARGET)
endif()
endfunction()
add_subdirectory(math)
cc_test(gather_test SRCS gather_test.cc DEPS tensor)
cc_library(net_op SRCS net_op.cc DEPS op_registry)
......@@ -50,7 +51,7 @@ op_library(add_op SRCS add_op.cc add_op.cu)
op_library(mean_op SRCS mean_op.cc mean_op.cu)
op_library(mul_op SRCS mul_op.cc mul_op.cu)
op_library(mul_op SRCS mul_op.cc mul_op.cu DEPS math_function)
op_library(rowwise_add_op SRCS rowwise_add_op.cu rowwise_add_op.cc)
op_library(sigmoid_op SRCS sigmoid_op.cc sigmoid_op.cu)
......
if(WITH_MKLML)
set(BLAS_LIB mklml)
else()
set(BLAS_LIB cblas)
endif()
if(WITH_GPU)
nv_library(math_function SRCS math_function.cc math_function.cu DEPS ${BLAS_LIB} device_context)
else()
cc_library(math_function SRCS math_function.cc DEPS ${BLAS_LIB} device_context)
endif()
nv_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor)
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/math/math_function.h"
namespace paddle {
namespace operators {
namespace math {
template <>
void gemm<platform::CPUPlace, float>(const CBLAS_TRANSPOSE transA,
const CBLAS_TRANSPOSE transB, const int M,
const int N, const int K,
const float alpha, const float* A,
const float* B, const float beta, float* C,
platform::DeviceContext* context) {
int lda = K;
int ldb = N;
int ldc = N;
cblas_sgemm(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, ldb,
beta, C, ldc);
}
template <>
void gemm<platform::CPUPlace, double>(const CBLAS_TRANSPOSE transA,
const CBLAS_TRANSPOSE transB, const int M,
const int N, const int K,
const double alpha, const double* A,
const double* B, const double beta,
double* C,
platform::DeviceContext* context) {
int lda = K;
int ldb = N;
int ldc = N;
cblas_dgemm(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, ldb,
beta, C, ldc);
}
template <>
void matmul<platform::CPUPlace, float>(const framework::Tensor& matrix_a,
bool trans_a,
const framework::Tensor& matrix_b,
bool trans_b, float alpha,
framework::Tensor* matrix_out,
float beta,
platform::DeviceContext* context) {
auto dim_a = matrix_a.dims();
auto dim_b = matrix_b.dims();
auto dim_out = matrix_out->dims();
PADDLE_ENFORCE(dim_a.size() == 2 && dim_b.size() == 2 && dim_out.size() == 2,
"The input and output of matmul be matrix");
PADDLE_ENFORCE(platform::is_cpu_place(matrix_a.place()) &&
platform::is_cpu_place(matrix_b.place()) &&
platform::is_cpu_place(matrix_out->place()),
"Matrix must all be in CPUPlace");
int M = dim_out[0];
int N = dim_out[1];
int K = (trans_a == false) ? dim_a[1] : dim_a[0];
CBLAS_TRANSPOSE transA = (trans_a == false) ? CblasNoTrans : CblasTrans;
CBLAS_TRANSPOSE transB = (trans_b == false) ? CblasNoTrans : CblasTrans;
gemm<platform::CPUPlace, float>(
transA, transB, M, N, K, alpha, matrix_a.data<float>(),
matrix_b.data<float>(), beta, matrix_out->data<float>(), context);
}
template <>
void matmul<platform::CPUPlace, double>(const framework::Tensor& matrix_a,
bool trans_a,
const framework::Tensor& matrix_b,
bool trans_b, double alpha,
framework::Tensor* matrix_out,
double beta,
platform::DeviceContext* context) {
auto dim_a = matrix_a.dims();
auto dim_b = matrix_b.dims();
auto dim_out = matrix_out->dims();
PADDLE_ENFORCE(dim_a.size() == 2 && dim_b.size() == 2 && dim_out.size() == 2,
"The input and output of matmul be matrix");
PADDLE_ENFORCE(platform::is_cpu_place(matrix_a.place()) &&
platform::is_cpu_place(matrix_b.place()) &&
platform::is_cpu_place(matrix_out->place()),
"Matrix must all be in CPUPlace");
int M = dim_out[0];
int N = dim_out[1];
int K = (trans_a == false) ? dim_a[1] : dim_a[0];
CBLAS_TRANSPOSE transA = (trans_a == false) ? CblasNoTrans : CblasTrans;
CBLAS_TRANSPOSE transB = (trans_b == false) ? CblasNoTrans : CblasTrans;
gemm<platform::CPUPlace, double>(
transA, transB, M, N, K, alpha, matrix_a.data<double>(),
matrix_b.data<double>(), beta, matrix_out->data<double>(), context);
}
} // namespace math
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/math/math_function.h"
namespace paddle {
namespace operators {
namespace math {
template <>
void gemm<platform::GPUPlace, float>(const CBLAS_TRANSPOSE transA,
const CBLAS_TRANSPOSE transB, const int M,
const int N, const int K,
const float alpha, const float* A,
const float* B, const float beta, float* C,
platform::DeviceContext* context) {
// Note that cublas follows fortran order, so the order is different from
// the cblas convention.
int lda = (transA == CblasNoTrans) ? K : M;
int ldb = (transB == CblasNoTrans) ? N : K;
cublasOperation_t cuTransA =
(transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
cublasOperation_t cuTransB =
(transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
PADDLE_ENFORCE(platform::dynload::cublasSgemm(
reinterpret_cast<platform::CUDADeviceContext*>(context)->cublas_handle(),
cuTransB, cuTransA, N, M, K, &alpha, B, ldb, A, lda, &beta, C, N));
}
template <>
void gemm<platform::GPUPlace, double>(const CBLAS_TRANSPOSE transA,
const CBLAS_TRANSPOSE transB, const int M,
const int N, const int K,
const double alpha, const double* A,
const double* B, const double beta,
double* C,
platform::DeviceContext* context) {
// Note that cublas follows fortran order, so the order is different from
// the cblas convention.
int lda = (transA == CblasNoTrans) ? K : M;
int ldb = (transB == CblasNoTrans) ? N : K;
cublasOperation_t cuTransA =
(transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
cublasOperation_t cuTransB =
(transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
PADDLE_ENFORCE(platform::dynload::cublasDgemm(
reinterpret_cast<platform::CUDADeviceContext*>(context)->cublas_handle(),
cuTransB, cuTransA, N, M, K, &alpha, B, ldb, A, lda, &beta, C, N));
}
template <>
void matmul<platform::GPUPlace, float>(const framework::Tensor& matrix_a,
bool trans_a,
const framework::Tensor& matrix_b,
bool trans_b, float alpha,
framework::Tensor* matrix_out,
float beta,
platform::DeviceContext* context) {
auto dim_a = matrix_a.dims();
auto dim_b = matrix_b.dims();
auto dim_out = matrix_out->dims();
PADDLE_ENFORCE(dim_a.size() == 2 && dim_b.size() == 2 && dim_out.size() == 2,
"The input and output of matmul be matrix");
PADDLE_ENFORCE(platform::is_gpu_place(matrix_a.place()) &&
platform::is_gpu_place(matrix_b.place()) &&
platform::is_gpu_place(matrix_out->place()),
"Matrix must all be in GPUPlace");
int M = dim_out[0];
int N = dim_out[1];
int K = (trans_a == false) ? dim_a[1] : dim_a[0];
CBLAS_TRANSPOSE transA = (trans_a == false) ? CblasNoTrans : CblasTrans;
CBLAS_TRANSPOSE transB = (trans_b == false) ? CblasNoTrans : CblasTrans;
gemm<platform::GPUPlace, float>(
transA, transB, M, N, K, alpha, matrix_a.data<float>(),
matrix_b.data<float>(), beta, matrix_out->data<float>(), context);
}
template <>
void matmul<platform::GPUPlace, double>(const framework::Tensor& matrix_a,
bool trans_a,
const framework::Tensor& matrix_b,
bool trans_b, double alpha,
framework::Tensor* matrix_out,
double beta,
platform::DeviceContext* context) {
auto dim_a = matrix_a.dims();
auto dim_b = matrix_b.dims();
auto dim_out = matrix_out->dims();
PADDLE_ENFORCE(dim_a.size() == 2 && dim_b.size() == 2 && dim_out.size() == 2,
"The input and output of matmul be matrix");
PADDLE_ENFORCE(platform::is_gpu_place(matrix_a.place()) &&
platform::is_gpu_place(matrix_b.place()) &&
platform::is_gpu_place(matrix_out->place()),
"Matrix must all be in GPUPlace");
int M = dim_out[0];
int N = dim_out[1];
int K = (trans_a == false) ? dim_a[1] : dim_a[0];
CBLAS_TRANSPOSE transA = (trans_a == false) ? CblasNoTrans : CblasTrans;
CBLAS_TRANSPOSE transB = (trans_b == false) ? CblasNoTrans : CblasTrans;
gemm<platform::GPUPlace, double>(
transA, transB, M, N, K, alpha, matrix_a.data<double>(),
matrix_b.data<double>(), beta, matrix_out->data<double>(), context);
}
} // namespace math
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#ifdef PADDLE_USE_MKLML
#include <mkl_cblas.h>
#include <mkl_lapacke.h>
#include <mkl_vml_functions.h>
#endif
#ifdef PADDLE_USE_MKL
#include <mkl.h>
#include <mkl_lapacke.h>
#endif
#ifdef PADDLE_USE_ATLAS
extern "C" {
#include <cblas.h>
#include <clapack.h>
}
#endif
#ifdef PADDLE_USE_OPENBLAS
#include <cblas.h>
#include <lapacke.h>
#endif
#ifndef LAPACK_FOUND
extern "C" {
#include <cblas.h>
int LAPACKE_sgetrf(int matrix_layout, int m, int n, float* a, int lda,
int* ipiv);
int LAPACKE_dgetrf(int matrix_layout, int m, int n, double* a, int lda,
int* ipiv);
int LAPACKE_sgetri(int matrix_layout, int n, float* a, int lda,
const int* ipiv);
int LAPACKE_dgetri(int matrix_layout, int n, double* a, int lda,
const int* ipiv);
}
#endif
#include <cmath>
#include "paddle/framework/tensor.h"
#include "paddle/platform/device_context.h"
#include "paddle/platform/enforce.h"
namespace paddle {
namespace operators {
namespace math {
// Support continuous memory now
// If transA = N, and transB = N
// Then matrixA: M * K, matrixB: K * N matrixC : M * N
// For more detailed info, please refer to
// http://www.netlib.org/lapack/explore-html/d4/de2/sgemm_8f.html
template <typename Place, typename T>
void gemm(const CBLAS_TRANSPOSE transA, const CBLAS_TRANSPOSE transB,
const int M, const int N, const int K, const T alpha, const T* A,
const T* B, const T beta, T* C, platform::DeviceContext* context);
// matrix multiply with continuous memory
template <typename Place, typename T>
void matmul(const framework::Tensor& matrix_a, bool trans_a,
const framework::Tensor& matrix_b, bool trans_b, T alpha,
framework::Tensor* matrix_out, T beta,
platform::DeviceContext* context);
} // namespace math
} // namespace operators
} // namespace paddle
#include "paddle/operators/math/math_function.h"
#include "gtest/gtest.h"
#ifndef PADDLE_ONLY_CPU
TEST(math_function, notrans_mul_trans) {
paddle::framework::Tensor input1;
paddle::framework::Tensor input1_gpu;
paddle::framework::Tensor input2_gpu;
paddle::framework::Tensor out_gpu;
paddle::framework::Tensor out;
auto* cpu_place = new paddle::platform::CPUPlace();
float* input1_ptr = input1.mutable_data<float>({2, 3}, *cpu_place);
float arr[6] = {0, 1, 2, 3, 4, 5};
memcpy(input1_ptr, arr, 6 * sizeof(float));
auto* gpu_place = new paddle::platform::GPUPlace(0);
paddle::platform::DeviceContext* context =
new paddle::platform::CUDADeviceContext(*gpu_place);
input1_gpu.CopyFrom<float>(input1, *gpu_place);
input2_gpu.CopyFrom<float>(input1, *gpu_place);
out_gpu.mutable_data<float>({2, 2}, *gpu_place);
paddle::operators::math::matmul<paddle::platform::GPUPlace, float>(
input1_gpu, false, input2_gpu, true, 1, &out_gpu, 0, context);
out.CopyFrom<float>(out_gpu, *cpu_place);
float* out_ptr = out.data<float>();
EXPECT_EQ(out_ptr[0], 5);
EXPECT_EQ(out_ptr[1], 14);
EXPECT_EQ(out_ptr[2], 14);
EXPECT_EQ(out_ptr[3], 50);
}
TEST(math_function, trans_mul_notrans) {
paddle::framework::Tensor input1;
paddle::framework::Tensor input1_gpu;
paddle::framework::Tensor input2_gpu;
paddle::framework::Tensor out_gpu;
paddle::framework::Tensor out;
auto* cpu_place = new paddle::platform::CPUPlace();
float* input1_ptr = input1.mutable_data<float>({2, 3}, *cpu_place);
float arr[6] = {0, 1, 2, 3, 4, 5};
memcpy(input1_ptr, arr, 6 * sizeof(float));
auto* gpu_place = new paddle::platform::GPUPlace(0);
paddle::platform::DeviceContext* context =
new paddle::platform::CUDADeviceContext(*gpu_place);
input1_gpu.CopyFrom<float>(input1, *gpu_place);
input2_gpu.CopyFrom<float>(input1, *gpu_place);
out_gpu.mutable_data<float>({3, 3}, *gpu_place);
paddle::operators::math::matmul<paddle::platform::GPUPlace, float>(
input1_gpu, true, input2_gpu, false, 1, &out_gpu, 0, context);
out.CopyFrom<float>(out_gpu, *cpu_place);
float* out_ptr = out.data<float>();
EXPECT_EQ(out_ptr[0], 9);
EXPECT_EQ(out_ptr[1], 12);
EXPECT_EQ(out_ptr[2], 15);
EXPECT_EQ(out_ptr[3], 12);
EXPECT_EQ(out_ptr[4], 17);
EXPECT_EQ(out_ptr[5], 22);
EXPECT_EQ(out_ptr[6], 15);
EXPECT_EQ(out_ptr[7], 22);
EXPECT_EQ(out_ptr[8], 29);
}
#endif
......@@ -13,6 +13,7 @@
limitations under the License. */
#include "paddle/operators/mul_op.h"
#include "paddle/operators/math/math_function.h"
namespace paddle {
namespace operators {
......
......@@ -16,5 +16,4 @@
#include "paddle/operators/mul_op.h"
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(mul, ops::MulKernel<paddle::platform::GPUPlace, float>);
......@@ -13,6 +13,9 @@
limitations under the License. */
#pragma once
#include "paddle/operators/math/math_function.h"
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
......
......@@ -62,12 +62,12 @@ extern void *cublas_dso_handle;
DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP(__name)
#define CUBLAS_BLAS_ROUTINE_EACH(__macro) \
__macro(cublasSgemv); \
__macro(cublasDgemv); \
__macro(cublasSgemm); \
__macro(cublasDgemm); \
__macro(cublasSgeam); \
__macro(cublasDgeam); \
__macro(cublasSgemv_v2); \
__macro(cublasDgemv_v2); \
__macro(cublasSgemm_v2); \
__macro(cublasDgemm_v2); \
__macro(cublasSgeam_v2); \
__macro(cublasDgeam_v2); \
__macro(cublasCreate_v2); \
__macro(cublasDestroy_v2); \
__macro(cublasSetStream_v2); \
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册