Commit 22dac40c authored by qijun

add gemm for both cpu and gpu

Parent: dd249a50
...@@ -257,6 +257,10 @@ class ExecutionContext : public OperatorContext {
   platform::Place GetPlace() const { return device_context_.GetPlace(); }
+  const platform::DeviceContext& device_context() const {
+    return device_context_;
+  }
   const platform::DeviceContext& device_context_;
 };
......
...@@ -41,13 +41,15 @@ function(op_library TARGET)
 endif()
 endfunction()
+add_subdirectory(math)
 op_library(add_op SRCS add_op.cc add_op.cu)
 cc_test(add_op_test SRCS add_op_test.cc DEPS add_op)
 op_library(mean_op SRCS mean_op.cc mean_op.cu)
 cc_test(mean_op_test SRCS mean_op_test.cc DEPS mean_op)
-op_library(mul_op SRCS mul_op.cc mul_op.cu)
+op_library(mul_op SRCS mul_op.cc mul_op.cu DEPS math_function)
 op_library(rowwise_add_op SRCS rowwise_add_op.cu rowwise_add_op.cc)
 op_library(sigmoid_op SRCS sigmoid_op.cc sigmoid_op.cu)
......
paddle/operators/math/CMakeLists.txt:
if (WITH_GPU)
  nv_library(math_function SRCS math_function.cc math_function.cu DEPS cblas device_context)
else()
  cc_library(math_function SRCS math_function.cc DEPS cblas device_context)
endif()
paddle/operators/math/math_function.cc:
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/math/math_function.h"
namespace paddle {
namespace operators {
namespace math {
template <>
void gemm<platform::CPUPlace, float>(const CBLAS_TRANSPOSE transA,
const CBLAS_TRANSPOSE transB,
const int M,
const int N,
const int K,
const float alpha,
const float* A,
const int lda,
const float* B,
const int ldb,
const float beta,
float* C,
const int ldc,
const platform::DeviceContext* context) {
cblas_sgemm(CblasRowMajor,
transA,
transB,
M,
N,
K,
alpha,
A,
lda,
B,
ldb,
beta,
C,
ldc);
}
template <>
void gemm<platform::CPUPlace, double>(const CBLAS_TRANSPOSE transA,
const CBLAS_TRANSPOSE transB,
const int M,
const int N,
const int K,
const double alpha,
const double* A,
const int lda,
const double* B,
const int ldb,
const double beta,
double* C,
const int ldc,
const platform::DeviceContext* context) {
cblas_dgemm(CblasRowMajor,
transA,
transB,
M,
N,
K,
alpha,
A,
lda,
B,
ldb,
beta,
C,
ldc);
}
template <>
void axpy<platform::CPUPlace, float>(const int n,
const float alpha,
const float* x,
float* y,
const platform::DeviceContext* context) {
cblas_saxpy(n, alpha, x, 1, y, 1);
}
template <>
void axpy<platform::CPUPlace, double>(const int n,
const double alpha,
const double* x,
double* y,
const platform::DeviceContext* context) {
cblas_daxpy(n, alpha, x, 1, y, 1);
}
template <>
float dotProduct<platform::CPUPlace, float>(
const int n,
const float* x,
const float* y,
const platform::DeviceContext* context) {
return cblas_sdot(n, x, 1, y, 1);
}
template <>
double dotProduct<platform::CPUPlace, double>(
const int n,
const double* x,
const double* y,
const platform::DeviceContext* context) {
return cblas_ddot(n, x, 1, y, 1);
}
} // namespace math
} // namespace operators
} // namespace paddle
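As a quick illustration (not part of this commit), here is a minimal sketch of how the CPU specialization above might be called directly. The function name and the matrix values are made up for the example; matrices are assumed row-major, matching the cblas calls above, and the DeviceContext argument, unused on the CPU path, is passed as nullptr.

#include "paddle/operators/math/math_function.h"

// Hypothetical sketch: compute C = A * B for a 2x3 by 3x2 product.
void cpu_gemm_example() {
  const int M = 2, N = 2, K = 3;
  float A[6] = {1, 2, 3, 4, 5, 6};  // 2x3, row-major
  float B[6] = {1, 0, 0, 1, 1, 1};  // 3x2, row-major
  float C[4] = {0, 0, 0, 0};        // 2x2 result
  paddle::operators::math::gemm<paddle::platform::CPUPlace, float>(
      CblasNoTrans, CblasNoTrans, M, N, K,
      1.0f, A, K,   // lda = K for a row-major MxK matrix
      B, N,         // ldb = N for a row-major KxN matrix
      0.0f, C, N,   // ldc = N; beta = 0 discards the old contents of C
      nullptr);     // context is unused in the CPU specialization
  // C now holds {4, 5, 10, 11}, i.e. [[4, 5], [10, 11]].
}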
paddle/operators/math/math_function.cu:
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/math/math_function.h"
namespace paddle {
namespace operators {
namespace math {
template <>
void gemm<platform::GPUPlace, float>(const CBLAS_TRANSPOSE transA,
const CBLAS_TRANSPOSE transB,
const int M,
const int N,
const int K,
const float alpha,
const float* A,
const int lda,
const float* B,
const int ldb,
const float beta,
float* C,
const int ldc,
const platform::DeviceContext* context) {
// Note that cublas follows fortran order, so the order is different from
// the cblas convention.
cublasOperation_t cuTransA =
    (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
cublasOperation_t cuTransB =
    (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
PADDLE_ENFORCE(platform::dynload::cublasSgemm(
reinterpret_cast<const platform::CUDADeviceContext*>(context)->
cublas_handle(),
cuTransB,
cuTransA,
N,
M,
K,
&alpha,
B,
ldb,
A,
lda,
&beta,
C,
ldc));
}
template <>
void gemm<platform::GPUPlace, double>(const CBLAS_TRANSPOSE transA,
const CBLAS_TRANSPOSE transB,
const int M,
const int N,
const int K,
const double alpha,
const double* A,
const int lda,
const double* B,
const int ldb,
const double beta,
double* C,
const int ldc,
const platform::DeviceContext* context) {
// Note that cublas follows fortran order, so the order is different from
// the cblas convention.
cublasOperation_t cuTransA =
    (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
cublasOperation_t cuTransB =
    (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
PADDLE_ENFORCE(platform::dynload::cublasDgemm(
reinterpret_cast<const platform::CUDADeviceContext*>(context)->
cublas_handle(),
cuTransB,
cuTransA,
N,
M,
K,
&alpha,
B,
ldb,
A,
lda,
&beta,
C,
ldc));
}
template <>
void axpy<platform::GPUPlace, float>(const int n,
const float alpha,
const float* x,
float* y,
const platform::DeviceContext* context) {
PADDLE_ENFORCE(platform::dynload::cublasSaxpy(
    reinterpret_cast<const platform::CUDADeviceContext*>(context)->
        cublas_handle(), n, &alpha, x, 1, y, 1));
}
template <>
void axpy<platform::GPUPlace, double>(const int n,
const double alpha,
const double* x,
double* y,
const platform::DeviceContext* context) {
PADDLE_ENFORCE(platform::dynload::cublasDaxpy(
    reinterpret_cast<const platform::CUDADeviceContext*>(context)->
        cublas_handle(), n, &alpha, x, 1, y, 1));
}
template <>
float dotProduct<platform::GPUPlace, float>(const int n,
const float* x,
const float* y,
const platform::DeviceContext* context) {
  float result;
  PADDLE_ENFORCE(platform::dynload::cublasSdot(
      reinterpret_cast<const platform::CUDADeviceContext*>(context)->
          cublas_handle(), n, x, 1, y, 1, &result));
  return result;
}
template <>
double dotProduct<platform::GPUPlace, double>(const int n,
const double* x,
const double* y,
const platform::DeviceContext* context) {
  double result;
  PADDLE_ENFORCE(platform::dynload::cublasDdot(
      reinterpret_cast<const platform::CUDADeviceContext*>(context)->
          cublas_handle(), n, x, 1, y, 1, &result));
  return result;
}
} // namespace math
} // namespace operators
} // namespace paddle
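A note on the argument order in the cuBLAS calls above (an explanatory aside, not part of the diff): cuBLAS assumes column-major (Fortran) storage, while the buffers here are row-major. A row-major M x N matrix is bit-for-bit identical to a column-major N x M matrix holding its transpose, so the row-major product C = A * B can be obtained by asking cuBLAS for C^T = B^T * A^T, using the identity (A * B)^T = B^T * A^T. That is why B, ldb and cuTransB are passed before A, lda and cuTransA, and why N and M are swapped relative to the cblas calls in the CPU file.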
paddle/operators/math/math_function.h:
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#ifdef PADDLE_USE_MKLML
#include <mkl_cblas.h>
#include <mkl_lapacke.h>
#include <mkl_vml_functions.h>
#endif
#ifdef PADDLE_USE_MKL
#include <mkl.h>
#include <mkl_lapacke.h>
#endif
#ifdef PADDLE_USE_ATLAS
extern "C" {
#include <cblas.h>
#include <clapack.h>
}
#endif
#ifdef PADDLE_USE_OPENBLAS
#include <cblas.h>
#include <lapacke.h>
#endif
#include <cmath>
#include "paddle/platform/device_context.h"
namespace paddle {
namespace operators {
namespace math {
template <typename Place, typename T>
void gemm(const CBLAS_TRANSPOSE transA,
const CBLAS_TRANSPOSE transB,
const int M,
const int N,
const int K,
const T alpha,
const T* A,
const int lda,
const T* B,
const int ldb,
const T beta,
T* C,
const int ldc,
const platform::DeviceContext* context);
template <typename Place, typename T>
void axpy(const int n,
const T alpha,
const T* x,
T* y,
const platform::DeviceContext* context);
template <typename Place, typename T>
T dotProduct(const int n,
const T* x,
const T* y,
const platform::DeviceContext* context);
} // namespace math
} // namespace operators
} // namespace paddle
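For readers less familiar with the BLAS naming (a clarifying note, not part of the diff): gemm computes C = alpha * op(A) * op(B) + beta * C, where op is the optional transpose selected by transA/transB; axpy computes y = alpha * x + y elementwise; and dotProduct returns sum_i x[i] * y[i]. The Place template parameter selects either the CPU (cblas) or GPU (cuBLAS) specialization defined in the source files above.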
...@@ -47,7 +47,7 @@ public:
     T ig_size = (T)framework::product(IG->dims());
-    EigenVector<T>::Flatten(*IG).device(*(context.GetEigenDevice<Place>())) =
+    EigenVector<T>::Flatten(*IG).device(context.GetEigenDevice<Place>()) =
         EigenScalar<T>::From(*OG) / ig_size;
   }
 };
......
...@@ -13,6 +13,7 @@
 limitations under the License. */
 #include "paddle/operators/mul_op.h"
+#include "paddle/operators/math/math_function.h"
 namespace paddle {
 namespace operators {
......
...@@ -15,4 +15,6 @@
 #define EIGEN_USE_GPU
 #include "paddle/operators/mul_op.h"
 REGISTER_OP_GPU_KERNEL(mul, ops::MulKernel<ops::GPUPlace, float>);
\ No newline at end of file
...@@ -14,6 +14,7 @@
 #pragma once
+#include "paddle/operators/math/math_function.h"
 #include "paddle/operators/type_alias.h"
 namespace paddle {
...@@ -23,22 +24,35 @@ template <typename Place, typename T>
 class MulKernel : public OpKernel {
 public:
   void Compute(const ExecutionContext& context) const override {
-    Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> dim_pair = {
-        {Eigen::IndexPair<Eigen::DenseIndex>(1, 0)}};
     auto input0 = context.Input<Tensor>("X");
     auto input1 = context.Input<Tensor>("Y");
     auto output = context.Output<Tensor>(0);
     output->mutable_data<T>(context.GetPlace());
-    auto X = EigenMatrix<T>::From(*input0);
-    auto Y = EigenMatrix<T>::From(*input1);
-    auto Z = EigenMatrix<T>::From(*output);
-    auto place = context.GetEigenDevice<Place>();
-    Z.device(place) = X.contract(Y, dim_pair);
+    auto out_dim = output->dims();
+    auto in0_dim = input0->dims();
+    int M = out_dim[0];
+    int N = out_dim[1];
+    int K = in0_dim[1];
+    paddle::operators::math::gemm<Place, T>(CblasNoTrans,
+                                            CblasNoTrans,
+                                            M,
+                                            N,
+                                            K,
+                                            1,
+                                            input0->data<T>(),
+                                            K,
+                                            input1->data<T>(),
+                                            N,
+                                            0,
+                                            output->data<T>(),
+                                            N,
+                                            &context.device_context());
   }
 };
 } // namespace operators
 } // namespace paddle
...@@ -61,7 +61,7 @@ class OpTestMeta(type):
         for out_name in func.all_output_args:
             actual = numpy.array(scope.find_var(out_name).get_tensor())
             expect = getattr(self, out_name)
-            # TODO(qijun) The default decimal is 7, but numpy.dot and eigen.mul
+            # TODO(qijun) The default decimal is 7, but numpy.dot and blas.gemm
             # has some diff, and could not pass unittest. So I set decimal 3 here.
             # And I will check this in future.
             numpy.testing.assert_almost_equal(actual, expect, decimal=3)
......