diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h
index 55435103489ace11868eed61c38018d8ba357e65..6a9057e5dbd1a80abaa55e3305fb7c0a768cb946 100644
--- a/paddle/framework/operator.h
+++ b/paddle/framework/operator.h
@@ -257,6 +257,10 @@ class ExecutionContext : public OperatorContext {
 
   platform::Place GetPlace() const { return device_context_.GetPlace(); }
 
+  const platform::DeviceContext& device_context() const {
+    return device_context_;
+  }
+
   const platform::DeviceContext& device_context_;
 };
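For reference, the new device_context() accessor is what lets an op kernel hand the active DeviceContext to helper libraries such as the math_function routines added below. A minimal sketch of the intended call pattern (the kernel itself is hypothetical; CUDADeviceContext and cublas_handle() are the same names the math_function.cu code below casts to):

    // Inside some OpKernel::Compute(const ExecutionContext& context):
    const platform::DeviceContext* dev_ctx = &context.device_context();
    // GPU-specialized helpers may then downcast to reach the cuBLAS handle,
    // exactly as math_function.cu does:
    auto handle =
        reinterpret_cast<const platform::CUDADeviceContext*>(dev_ctx)
            ->cublas_handle();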
diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt
index 6465deeec93100f0238ac850b92f7f7c5a60b795..6be90d91246dd923491b41112a9cec11f142c572 100644
--- a/paddle/operators/CMakeLists.txt
+++ b/paddle/operators/CMakeLists.txt
@@ -41,13 +41,15 @@ function(op_library TARGET)
     endif()
 endfunction()
 
+add_subdirectory(math)
+
 op_library(add_op SRCS add_op.cc add_op.cu)
 cc_test(add_op_test SRCS add_op_test.cc DEPS add_op)
 
 op_library(mean_op SRCS mean_op.cc mean_op.cu)
 cc_test(mean_op_test SRCS mean_op_test.cc DEPS mean_op)
 
-op_library(mul_op SRCS mul_op.cc mul_op.cu)
+op_library(mul_op SRCS mul_op.cc mul_op.cu DEPS math_function)
 op_library(rowwise_add_op SRCS rowwise_add_op.cu rowwise_add_op.cc)
 op_library(sigmoid_op SRCS sigmoid_op.cc sigmoid_op.cu)
diff --git a/paddle/operators/math/CMakeLists.txt b/paddle/operators/math/CMakeLists.txt
new file mode 100644
index 0000000000000000000000000000000000000000..586347668efeea24526e3d8cd7b83cf3c8855e0d
--- /dev/null
+++ b/paddle/operators/math/CMakeLists.txt
@@ -0,0 +1,5 @@
+if (WITH_GPU)
+  nv_library(math_function SRCS math_function.cc math_function.cu DEPS cblas device_context)
+else()
+  cc_library(math_function SRCS math_function.cc DEPS cblas device_context)
+endif()
diff --git a/paddle/operators/math/math_function.cc b/paddle/operators/math/math_function.cc
new file mode 100644
index 0000000000000000000000000000000000000000..0532e8f034ccc8ddb3a0a6fac37b0415de6056b4
--- /dev/null
+++ b/paddle/operators/math/math_function.cc
@@ -0,0 +1,121 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/operators/math/math_function.h"
+
+namespace paddle {
+namespace operators {
+namespace math {
+
+template <>
+void gemm<platform::CPUPlace, float>(const CBLAS_TRANSPOSE transA,
+                                     const CBLAS_TRANSPOSE transB,
+                                     const int M,
+                                     const int N,
+                                     const int K,
+                                     const float alpha,
+                                     const float* A,
+                                     const int lda,
+                                     const float* B,
+                                     const int ldb,
+                                     const float beta,
+                                     float* C,
+                                     const int ldc,
+                                     const platform::DeviceContext* context) {
+  cblas_sgemm(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, ldb,
+              beta, C, ldc);
+}
+
+template <>
+void gemm<platform::CPUPlace, double>(const CBLAS_TRANSPOSE transA,
+                                      const CBLAS_TRANSPOSE transB,
+                                      const int M,
+                                      const int N,
+                                      const int K,
+                                      const double alpha,
+                                      const double* A,
+                                      const int lda,
+                                      const double* B,
+                                      const int ldb,
+                                      const double beta,
+                                      double* C,
+                                      const int ldc,
+                                      const platform::DeviceContext* context) {
+  cblas_dgemm(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, ldb,
+              beta, C, ldc);
+}
+
+template <>
+void axpy<platform::CPUPlace, float>(const int n,
+                                     const float alpha,
+                                     const float* x,
+                                     float* y,
+                                     const platform::DeviceContext* context) {
+  cblas_saxpy(n, alpha, x, 1, y, 1);
+}
+
+template <>
+void axpy<platform::CPUPlace, double>(const int n,
+                                      const double alpha,
+                                      const double* x,
+                                      double* y,
+                                      const platform::DeviceContext* context) {
+  cblas_daxpy(n, alpha, x, 1, y, 1);
+}
+
+template <>
+float dotProduct<platform::CPUPlace, float>(
+    const int n,
+    const float* x,
+    const float* y,
+    const platform::DeviceContext* context) {
+  return cblas_sdot(n, x, 1, y, 1);
+}
+
+template <>
+double dotProduct<platform::CPUPlace, double>(
+    const int n,
+    const double* x,
+    const double* y,
+    const platform::DeviceContext* context) {
+  return cblas_ddot(n, x, 1, y, 1);
+}
+
+}  // namespace math
+}  // namespace operators
+}  // namespace paddle
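For reference, a minimal sketch of calling the CPU specialization directly (the buffers A, B, C are illustrative, not part of the patch). Since the CPU implementations above never touch the context argument, nullptr is passed for it here. With A of shape 2x3 and B of shape 3x4: M = 2, N = 4, K = 3, and because the data is dense row-major, the leading dimensions are the row widths, lda = 3, ldb = 4, ldc = 4:

    float A[6] = {1, 2, 3, 4, 5, 6};                      // 2x3, row-major
    float B[12] = {1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0};   // 3x4, row-major
    float C[8];                                           // 2x4 result
    paddle::operators::math::gemm<platform::CPUPlace, float>(
        CblasNoTrans, CblasNoTrans, 2, 4, 3, 1.0f, A, 3, B, 4, 0.0f, C, 4,
        nullptr);
    // C now holds A * B = {1, 2, 3, 0, 4, 5, 6, 0}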
diff --git a/paddle/operators/math/math_function.cu b/paddle/operators/math/math_function.cu
new file mode 100644
index 0000000000000000000000000000000000000000..46301df8f9d4f82f0d915a131965e5fd76038be6
--- /dev/null
+++ b/paddle/operators/math/math_function.cu
@@ -0,0 +1,146 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/operators/math/math_function.h"
+
+namespace paddle {
+namespace operators {
+namespace math {
+
+template <>
+void gemm<platform::GPUPlace, float>(const CBLAS_TRANSPOSE transA,
+                                     const CBLAS_TRANSPOSE transB,
+                                     const int M,
+                                     const int N,
+                                     const int K,
+                                     const float alpha,
+                                     const float* A,
+                                     const int lda,
+                                     const float* B,
+                                     const int ldb,
+                                     const float beta,
+                                     float* C,
+                                     const int ldc,
+                                     const platform::DeviceContext* context) {
+  // Note that cublas follows fortran order, so the order is different from
+  // the cblas convention.
+  cublasOperation_t cuTransA =
+      (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
+  cublasOperation_t cuTransB =
+      (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
+
+  PADDLE_ENFORCE(platform::dynload::cublasSgemm(
+      reinterpret_cast<const platform::CUDADeviceContext*>(context)
+          ->cublas_handle(),
+      cuTransB, cuTransA, N, M, K, &alpha, B, ldb, A, lda, &beta, C, ldc));
+}
+
+template <>
+void gemm<platform::GPUPlace, double>(const CBLAS_TRANSPOSE transA,
+                                      const CBLAS_TRANSPOSE transB,
+                                      const int M,
+                                      const int N,
+                                      const int K,
+                                      const double alpha,
+                                      const double* A,
+                                      const int lda,
+                                      const double* B,
+                                      const int ldb,
+                                      const double beta,
+                                      double* C,
+                                      const int ldc,
+                                      const platform::DeviceContext* context) {
+  // Note that cublas follows fortran order, so the order is different from
+  // the cblas convention.
+  cublasOperation_t cuTransA =
+      (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
+  cublasOperation_t cuTransB =
+      (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
+
+  PADDLE_ENFORCE(platform::dynload::cublasDgemm(
+      reinterpret_cast<const platform::CUDADeviceContext*>(context)
+          ->cublas_handle(),
+      cuTransB, cuTransA, N, M, K, &alpha, B, ldb, A, lda, &beta, C, ldc));
+}
+
+template <>
+void axpy<platform::GPUPlace, float>(const int n,
+                                     const float alpha,
+                                     const float* x,
+                                     float* y,
+                                     const platform::DeviceContext* context) {
+  PADDLE_ENFORCE(platform::dynload::cublasSaxpy(
+      reinterpret_cast<const platform::CUDADeviceContext*>(context)
+          ->cublas_handle(),
+      n, &alpha, x, 1, y, 1));
+}
+
+template <>
+void axpy<platform::GPUPlace, double>(const int n,
+                                      const double alpha,
+                                      const double* x,
+                                      double* y,
+                                      const platform::DeviceContext* context) {
+  PADDLE_ENFORCE(platform::dynload::cublasDaxpy(
+      reinterpret_cast<const platform::CUDADeviceContext*>(context)
+          ->cublas_handle(),
+      n, &alpha, x, 1, y, 1));
+}
+
+template <>
+float dotProduct<platform::GPUPlace, float>(
+    const int n,
+    const float* x,
+    const float* y,
+    const platform::DeviceContext* context) {
+  float result = 0.0f;
+  PADDLE_ENFORCE(platform::dynload::cublasSdot(
+      reinterpret_cast<const platform::CUDADeviceContext*>(context)
+          ->cublas_handle(),
+      n, x, 1, y, 1, &result));
+  return result;
+}
+
+template <>
+double dotProduct<platform::GPUPlace, double>(
+    const int n,
+    const double* x,
+    const double* y,
+    const platform::DeviceContext* context) {
+  double result = 0.0;
+  PADDLE_ENFORCE(platform::dynload::cublasDdot(
+      reinterpret_cast<const platform::CUDADeviceContext*>(context)
+          ->cublas_handle(),
+      n, x, 1, y, 1, &result));
+  return result;
+}
+
+}  // namespace math
+}  // namespace operators
+}  // namespace paddle
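The "fortran order" comments above are worth unpacking. cuBLAS assumes column-major storage, and a buffer holding a row-major M x N matrix is, when read column-major, exactly the N x M transpose: a row-major 2x3 A stored as {1, 2, 3, 4, 5, 6} is the 3x2 matrix A^T to cuBLAS. So rather than transposing any data, the calls above request C^T = B^T * A^T, the transpose of C = A * B:

    // Row-major goal:        C (M x N) = A (M x K) * B (K x N)
    // cuBLAS actually runs:  C^T (N x M) = B^T (N x K) * A^T (K x M)
    // hence the argument order (cuTransB, cuTransA, N, M, K, &alpha,
    // B, ldb, A, lda, &beta, C, ldc), and C can be read back row-major
    // with no extra transpose pass.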
diff --git a/paddle/operators/math/math_function.h b/paddle/operators/math/math_function.h
new file mode 100644
index 0000000000000000000000000000000000000000..c5b7fe8793c952afa4af7bae02434f6d1df86ca0
--- /dev/null
+++ b/paddle/operators/math/math_function.h
@@ -0,0 +1,78 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#ifdef PADDLE_USE_MKLML
+#include <mkl_cblas.h>
+#include <mkl_lapacke.h>
+#include <mkl_vml_functions.h>
+#endif
+
+#ifdef PADDLE_USE_MKL
+#include <mkl.h>
+#include <mkl_lapacke.h>
+#endif
+
+#ifdef PADDLE_USE_ATLAS
+extern "C" {
+#include <cblas.h>
+#include <clapack.h>
+}
+#endif
+
+#ifdef PADDLE_USE_OPENBLAS
+#include <cblas.h>
+#include <lapacke.h>
+#endif
+
+#include <cmath>
+#include "paddle/platform/device_context.h"
+
+namespace paddle {
+namespace operators {
+namespace math {
+
+template <typename Place, typename T>
+void gemm(const CBLAS_TRANSPOSE transA,
+          const CBLAS_TRANSPOSE transB,
+          const int M,
+          const int N,
+          const int K,
+          const T alpha,
+          const T* A,
+          const int lda,
+          const T* B,
+          const int ldb,
+          const T beta,
+          T* C,
+          const int ldc,
+          const platform::DeviceContext* context);
+
+template <typename Place, typename T>
+void axpy(const int n,
+          const T alpha,
+          const T* x,
+          T* y,
+          const platform::DeviceContext* context);
+
+template <typename Place, typename T>
+T dotProduct(const int n,
+             const T* x,
+             const T* y,
+             const platform::DeviceContext* context);
+
+}  // namespace math
+}  // namespace operators
+}  // namespace paddle
diff --git a/paddle/operators/mean_op.h b/paddle/operators/mean_op.h
index a89cb422f9b296dba6eb5358043f73d00aefc5d3..e712dee6a785749e51be7b233e85dbf39c835218 100644
--- a/paddle/operators/mean_op.h
+++ b/paddle/operators/mean_op.h
@@ -47,7 +47,7 @@ public:
 
     T ig_size = (T)framework::product(IG->dims());
 
-    EigenVector<T>::Flatten(*IG).device(*(context.GetEigenDevice<Place>())) =
+    EigenVector<T>::Flatten(*IG).device(context.GetEigenDevice<Place>()) =
         EigenScalar<T>::From(*OG) / ig_size;
   }
 };
diff --git a/paddle/operators/mul_op.cc b/paddle/operators/mul_op.cc
index d127f3a302a340fe7558f918d6eeb2ea0a3fafe7..eaf1d3266ca6790036985322f0333ac491e3b143 100644
--- a/paddle/operators/mul_op.cc
+++ b/paddle/operators/mul_op.cc
@@ -13,6 +13,7 @@ limitations under the License. */
 
 #include "paddle/operators/mul_op.h"
+#include "paddle/operators/math/math_function.h"
 
 namespace paddle {
 namespace operators {
diff --git a/paddle/operators/mul_op.cu b/paddle/operators/mul_op.cu
index dc9236701627dc9335b844d2a82e18eb1f7dfd42..ba0460550366bbc9a7180e7a0ae8de425e50d116 100644
--- a/paddle/operators/mul_op.cu
+++ b/paddle/operators/mul_op.cu
@@ -15,4 +15,6 @@
 #define EIGEN_USE_GPU
 
 #include "paddle/operators/mul_op.h"
+
+
 REGISTER_OP_GPU_KERNEL(mul, ops::MulKernel<ops::GPUPlace, float>);
\ No newline at end of file
diff --git a/paddle/operators/mul_op.h b/paddle/operators/mul_op.h
index c7b78ad39045d25d73bfc2c930063c255a514864..e1759d00c55ab9caf5e6714883d7b187deb05363 100644
--- a/paddle/operators/mul_op.h
+++ b/paddle/operators/mul_op.h
@@ -14,6 +14,7 @@
 
 #pragma once
 
+#include "paddle/operators/math/math_function.h"
 #include "paddle/operators/type_alias.h"
 
 namespace paddle {
@@ -23,22 +24,35 @@ template <typename Place, typename T>
 class MulKernel : public OpKernel {
 public:
   void Compute(const ExecutionContext& context) const override {
-    Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> dim_pair = {
-        {Eigen::IndexPair<Eigen::DenseIndex>(1, 0)}};
-
     auto input0 = context.Input<Tensor>("X");
     auto input1 = context.Input<Tensor>("Y");
     auto output = context.Output<Tensor>(0);
 
     output->mutable_data<T>(context.GetPlace());
 
-    auto X = EigenMatrix<T>::From(*input0);
-    auto Y = EigenMatrix<T>::From(*input1);
-    auto Z = EigenMatrix<T>::From(*output);
-    auto place = context.GetEigenDevice<Place>();
-
-    Z.device(place) = X.contract(Y, dim_pair);
+    auto out_dim = output->dims();
+    auto in0_dim = input0->dims();
+
+    int M = out_dim[0];
+    int N = out_dim[1];
+    int K = in0_dim[1];
+
+    paddle::operators::math::gemm<Place, T>(
+        CblasNoTrans, CblasNoTrans, M, N, K, 1, input0->data<T>(), K,
+        input1->data<T>(), N, 0, output->data<T>(), N,
+        &context.device_context());
   }
 };
+
 }  // namespace operators
 }  // namespace paddle
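To make the gemm arguments in MulKernel concrete: for X of shape [M, K] and Y of shape [K, N], the output has shape [M, N]; the tensors are dense and row-major, so each leading dimension is the row width (lda = K, ldb = N, ldc = N), and alpha = 1, beta = 0 yield Out = X * Y with no accumulation into the uninitialized output buffer. A sketch of the bookkeeping with hypothetical shapes:

    // X: [2, 3], Y: [3, 4]  =>  Out: [2, 4]
    // M = out_dim[0] = 2, N = out_dim[1] = 4, K = in0_dim[1] = 3
    // gemm<Place, T>(CblasNoTrans, CblasNoTrans, /*M=*/2, /*N=*/4, /*K=*/3,
    //                1, X_data, /*lda=*/3, Y_data, /*ldb=*/4,
    //                0, Out_data, /*ldc=*/4, &context.device_context());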
diff --git a/python/paddle/v2/framework/tests/op_test_util.py b/python/paddle/v2/framework/tests/op_test_util.py
index 98fae1b975ad6243b20e5c19ec6ff68d5536cd74..35d285e2e6ab09e96632807fad1adc94ba43268f 100644
--- a/python/paddle/v2/framework/tests/op_test_util.py
+++ b/python/paddle/v2/framework/tests/op_test_util.py
@@ -61,7 +61,7 @@ class OpTestMeta(type):
             for out_name in func.all_output_args:
                 actual = numpy.array(scope.find_var(out_name).get_tensor())
                 expect = getattr(self, out_name)
-                # TODO(qijun) The default decimal is 7, but numpy.dot and eigen.mul
+                # TODO(qijun) The default decimal is 7, but numpy.dot and blas.gemm
                 # has some diff, and could not pass unittest. So I set decimal 3 here.
                 # And I will check this in future.
                 numpy.testing.assert_almost_equal(actual, expect, decimal=3)