diff --git a/paddle/fluid/framework/scope.cc b/paddle/fluid/framework/scope.cc index a5742dbd3d66a47ca108768d875e5764a0e62f4f..953618560913229cd1e47659ad61e621efc10ed1 100644 --- a/paddle/fluid/framework/scope.cc +++ b/paddle/fluid/framework/scope.cc @@ -87,11 +87,12 @@ Variable* Scope::Var(const std::string& name) { } Variable* Scope::Var(std::string* name) { - auto new_name = string::Sprintf("%p.%d", this, vars_.size()); + SCOPE_VARS_WRITER_LOCK + auto new_name = std::to_string(reinterpret_cast(this)) + "." + + std::to_string(vars_.size()); if (name != nullptr) { *name = new_name; } - SCOPE_VARS_WRITER_LOCK return VarInternal(new_name); } diff --git a/paddle/fluid/framework/var_type_traits.cc b/paddle/fluid/framework/var_type_traits.cc index c3c5bab23b92a0274cf786ea2f18d8246706162f..a37b1fbab8cfd0642beaf725c02941002b2176b3 100644 --- a/paddle/fluid/framework/var_type_traits.cc +++ b/paddle/fluid/framework/var_type_traits.cc @@ -105,13 +105,15 @@ struct VarIdToTypeIndexMapHolder { } // namespace detail -const std::type_index &ToTypeIndex(int var_id) { +const std::type_index &VarTraitIdToTypeIndex(int var_id) { return detail::VarIdToTypeIndexMapHolder::ToTypeIndex(var_id); } -const char *ToTypeName(int var_id) { return ToTypeIndex(var_id).name(); } +const char *ToTypeName(int var_id) { + return VarTraitIdToTypeIndex(var_id).name(); +} -int ToTypeId(const std::type_index &type) { +int TypeIndexToVarTraitId(const std::type_index &type) { return detail::VarIdToTypeIndexMapHolder::ToTypeId(type); } diff --git a/paddle/fluid/framework/var_type_traits.h b/paddle/fluid/framework/var_type_traits.h index cc68cf2ab8e1bbc8a57cf97a2084610440a75f85..733542e4972b16a71f9e76c3076b424b7a901066 100644 --- a/paddle/fluid/framework/var_type_traits.h +++ b/paddle/fluid/framework/var_type_traits.h @@ -66,8 +66,8 @@ namespace paddle { namespace framework { const char *ToTypeName(int var_id); -const std::type_index &ToTypeIndex(int var_id); -int ToTypeId(const std::type_index &type); +const std::type_index &VarTraitIdToTypeIndex(int var_id); +int TypeIndexToVarTraitId(const std::type_index &type); namespace detail { diff --git a/paddle/fluid/framework/var_type_traits_test.cc b/paddle/fluid/framework/var_type_traits_test.cc index 00840d634d802cfe17fbff127a75606cb5e2cf79..a47275e1ca25a4f66e67b4986ec78e49ea952a51 100644 --- a/paddle/fluid/framework/var_type_traits_test.cc +++ b/paddle/fluid/framework/var_type_traits_test.cc @@ -45,10 +45,11 @@ struct TypeIndexChecker { constexpr auto kId = VarTypeTrait::kId; std::type_index actual_type(typeid(Type)); EXPECT_EQ(std::string(ToTypeName(kId)), std::string(actual_type.name())); - EXPECT_EQ(ToTypeIndex(kId), actual_type); - EXPECT_EQ(ToTypeId(actual_type), kId); - EXPECT_EQ(ToTypeIndex(ToTypeId(actual_type)), actual_type); - EXPECT_EQ(ToTypeId(ToTypeIndex(kId)), kId); + EXPECT_EQ(VarTraitIdToTypeIndex(kId), actual_type); + EXPECT_EQ(TypeIndexToVarTraitId(actual_type), kId); + EXPECT_EQ(VarTraitIdToTypeIndex(TypeIndexToVarTraitId(actual_type)), + actual_type); + EXPECT_EQ(TypeIndexToVarTraitId(VarTraitIdToTypeIndex(kId)), kId); EXPECT_TRUE(var_id_set->count(kId) == 0); // NOLINT EXPECT_TRUE(type_index_set->count(actual_type) == 0); // NOLINT diff --git a/paddle/fluid/inference/analysis/analyzer_tester.cc b/paddle/fluid/inference/analysis/analyzer_tester.cc index f84e1ab6b827b3b96d0a503394d95b06ed25a3d2..4c84d02d8679c4d42c0d02ae83e7f869c0f5ce8b 100644 --- a/paddle/fluid/inference/analysis/analyzer_tester.cc +++ b/paddle/fluid/inference/analysis/analyzer_tester.cc @@ -80,8 +80,8 @@ void TestWord2vecPrediction(const std::string& model_path) { i++) { LOG(INFO) << "data: " << static_cast(outputs.front().data.data())[i] << " result: " << result[i]; - PADDLE_ENFORCE(static_cast(outputs.front().data.data())[i], - result[i]); + EXPECT_NEAR(static_cast(outputs.front().data.data())[i], result[i], + 1e-3); } } diff --git a/paddle/fluid/inference/analysis/passes/CMakeLists.txt b/paddle/fluid/inference/analysis/passes/CMakeLists.txt index d3ea511d8f4d8cbec1be57633391f00e29a3e6e9..add9b70f2cd960a94232b35edb928ab4115cbff0 100644 --- a/paddle/fluid/inference/analysis/passes/CMakeLists.txt +++ b/paddle/fluid/inference/analysis/passes/CMakeLists.txt @@ -7,4 +7,5 @@ set(analysis_deps ${analysis_deps} ir_graph_build_pass ir_analysis_pass analysis_passes + subgraph_detector CACHE INTERNAL "") diff --git a/paddle/fluid/operators/math/blas_impl.cu.h b/paddle/fluid/operators/math/blas_impl.cu.h index d35073029a3440d8a17e383ce97fcfc582663888..58f7be12ce6b5d447e93cf86c4954a86fccf48ef 100644 --- a/paddle/fluid/operators/math/blas_impl.cu.h +++ b/paddle/fluid/operators/math/blas_impl.cu.h @@ -62,27 +62,19 @@ struct CUBlas { cudaDataType_t Atype, int lda, const void *B, cudaDataType_t Btype, int ldb, const float *beta, void *C, cudaDataType_t Ctype, int ldc) { - // Because the gcc 4.8 doesn't expand template parameter pack that - // appears in a lambda-expression, I can not use template parameter pack - // here. - auto cublas_call = [&]() { +// Because the gcc 4.8 doesn't expand template parameter pack that +// appears in a lambda-expression, I can not use template parameter pack +// here. #if CUDA_VERSION >= 8000 - VLOG(5) << "use_tensor_op_math: " - << (platform::TensorCoreAvailable() ? "True" : "False"); + VLOG(5) << "use_tensor_op_math: " + << (dev_ctx->tensor_core_available() ? "True" : "False"); + dev_ctx->TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) { PADDLE_ENFORCE(platform::dynload::cublasSgemmEx( - dev_ctx->cublas_handle(), transa, transb, m, n, k, alpha, A, Atype, - lda, B, Btype, ldb, beta, C, Ctype, ldc)); + handle, transa, transb, m, n, k, alpha, A, Atype, lda, B, Btype, ldb, + beta, C, Ctype, ldc)); + }); #else - PADDLE_THROW("cublasSgemmEx is supported on cuda >= 8.0"); -#endif - }; - -#if CUDA_VERSION >= 9000 - // NOTES: To use Tensor Core, we should change the cublas config, - // but the cublas may be hold by multi-thread. - dev_ctx->CublasCall(cublas_call, CUBLAS_TENSOR_OP_MATH); -#else - cublas_call(); + PADDLE_THROW("cublasSgemmEx is supported on cuda >= 8.0"); #endif } }; @@ -170,32 +162,24 @@ struct CUBlas { cudaDataType_t Btype, int ldb, const void *beta, void *C, cudaDataType_t Ctype, int ldc, cudaDataType_t computeType) { - auto cublas_call = [&]() { #if CUDA_VERSION >= 8000 - cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT; + cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT; #if CUDA_VERSION >= 9000 - bool use_tensor_op_math = platform::TensorCoreAvailable(); - if (use_tensor_op_math) { - algo = CUBLAS_GEMM_DFALT_TENSOR_OP; - } - VLOG(5) << "use_tensor_op_math: " - << (use_tensor_op_math ? "True" : "False"); + bool use_tensor_op_math = dev_ctx->tensor_core_available(); + if (use_tensor_op_math) { + algo = CUBLAS_GEMM_DFALT_TENSOR_OP; + } + VLOG(5) << "use_tensor_op_math: " + << (use_tensor_op_math ? "True" : "False"); #endif // CUDA_VERSION >= 9000 + dev_ctx->TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) { PADDLE_ENFORCE(platform::dynload::cublasGemmEx( - dev_ctx->cublas_handle(), transa, transb, m, n, k, alpha, A, Atype, - lda, B, Btype, ldb, beta, C, Ctype, ldc, computeType, algo)); + handle, transa, transb, m, n, k, alpha, A, Atype, lda, B, Btype, ldb, + beta, C, Ctype, ldc, computeType, algo)); + }); #else - PADDLE_THROW("cublasGemmEx is supported on cuda >= 8.0"); -#endif - }; - -#if CUDA_VERSION >= 9000 - // NOTES: To use Tensor Core, we should change the cublas config, - // but the cublas may be hold by multi-thread. - dev_ctx->CublasCall(cublas_call, CUBLAS_TENSOR_OP_MATH); -#else - cublas_call(); + PADDLE_THROW("cublasGemmEx is supported on cuda >= 8.0"); #endif } }; @@ -223,9 +207,10 @@ void Blas::GEMM(CBLAS_TRANSPOSE transA, CUDA_R_32F, N); } else { #endif // CUDA_VERSION >= 8000 - - CUBlas::GEMM(context_.cublas_handle(), cuTransB, cuTransA, N, M, K, - &alpha, B, ldb, A, lda, &beta, C, N); + context_.CublasCall([&](cublasHandle_t handle) { + CUBlas::GEMM(handle, cuTransB, cuTransA, N, M, K, &alpha, B, ldb, A, + lda, &beta, C, N); + }); #if CUDA_VERSION >= 8000 } @@ -266,9 +251,12 @@ inline void Blas::GEMM( CUDA_R_16F, lda, &h_beta, C, CUDA_R_16F, N, CUDA_R_32F); #else // CUDA 7.5 does not support cublasGemmEx, hence we fall back to use hgemm - CUBlas::GEMM(context_.cublas_handle(), cuTransB, cuTransA, - N, M, K, &h_alpha, h_B, ldb, h_A, lda, - &h_beta, h_C, N); + + context_.CublasCall([&](cublasHandle_t handle) { + CUBlas::GEMM(handle, cuTransB, cuTransA, N, M, K, + &h_alpha, h_B, ldb, h_A, lda, &h_beta, h_C, + N); + }); #endif // CUDA_VERSION >= 8000 } @@ -292,8 +280,10 @@ void Blas::GEMM(bool transA, bool transB, int M, } else { #endif // CUDA_VERSION >= 8000 - CUBlas::GEMM(context_.cublas_handle(), cuTransB, cuTransA, N, M, K, - &alpha, B, ldb, A, lda, &beta, C, ldc); + context_.CublasCall([&](cublasHandle_t handle) { + CUBlas::GEMM(handle, cuTransB, cuTransA, N, M, K, &alpha, B, ldb, A, + lda, &beta, C, ldc); + }); #if CUDA_VERSION >= 8000 } @@ -311,16 +301,19 @@ inline void Blas::GEMM( cublasOperation_t cuTransA = transA ? CUBLAS_OP_T : CUBLAS_OP_N; cublasOperation_t cuTransB = transB ? CUBLAS_OP_T : CUBLAS_OP_N; - CUBlas::GEMM(context_.cublas_handle(), cuTransB, cuTransA, - N, M, K, &alpha, B, ldb, A, lda, &beta, C, - ldc); + context_.CublasCall([&](cublasHandle_t handle) { + CUBlas::GEMM(handle, cuTransB, cuTransA, N, M, K, &alpha, + B, ldb, A, lda, &beta, C, ldc); + }); } template <> template void Blas::AXPY(int n, T alpha, const T *x, T *y) const { - CUBlas::AXPY(context_.cublas_handle(), n, &alpha, x, 1, y, 1); + context_.CublasCall([&](cublasHandle_t handle) { + CUBlas::AXPY(handle, n, &alpha, x, 1, y, 1); + }); } template <> @@ -330,8 +323,9 @@ void Blas::GEMV(bool trans_a, int M, int N, T beta, T *C) const { cublasOperation_t cuTransA = !trans_a ? CUBLAS_OP_T : CUBLAS_OP_N; - CUBlas::GEMV(context_.cublas_handle(), cuTransA, N, M, &alpha, A, N, B, 1, - &beta, C, 1); + context_.CublasCall([&](cublasHandle_t handle) { + CUBlas::GEMV(handle, cuTransA, N, M, &alpha, A, N, B, 1, &beta, C, 1); + }); } template <> @@ -353,28 +347,28 @@ void Blas::BatchedGEMM( #if CUDA_VERSION >= 9010 if (FLAGS_enable_cublas_tensor_op_math && std::is_same::value) { - auto cublas_call = [&]() { - cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT; - bool use_tensor_op_math = platform::TensorCoreAvailable(); - if (use_tensor_op_math) { - algo = CUBLAS_GEMM_DFALT_TENSOR_OP; - } - VLOG(5) << "use_tensor_op_math: " - << (use_tensor_op_math ? "True" : "False"); - + cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT; + bool use_tensor_op_math = context_.tensor_core_available(); + if (use_tensor_op_math) { + algo = CUBLAS_GEMM_DFALT_TENSOR_OP; + } + VLOG(5) << "use_tensor_op_math: " + << (use_tensor_op_math ? "True" : "False"); + + context_.TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) { PADDLE_ENFORCE(platform::dynload::cublasGemmStridedBatchedEx( - context_.cublas_handle(), cuTransB, cuTransA, N, M, K, &alpha, B, - CUDA_R_32F, ldb, strideB, A, CUDA_R_32F, lda, strideA, &beta, C, - CUDA_R_32F, ldc, strideC, batchCount, CUDA_R_32F, algo)); - }; - auto &dev_ctx = const_cast(context_); - dev_ctx.CublasCall(cublas_call, CUBLAS_TENSOR_OP_MATH); + handle, cuTransB, cuTransA, N, M, K, &alpha, B, CUDA_R_32F, ldb, + strideB, A, CUDA_R_32F, lda, strideA, &beta, C, CUDA_R_32F, ldc, + strideC, batchCount, CUDA_R_32F, algo)); + }); } else { #endif // CUDA_VERSION >= 9010 - CUBlas::GEMM_STRIDED_BATCH(context_.cublas_handle(), cuTransB, cuTransA, - N, M, K, &alpha, B, ldb, strideB, A, lda, - strideA, &beta, C, ldc, strideC, batchCount); + context_.CublasCall([&](cublasHandle_t handle) { + CUBlas::GEMM_STRIDED_BATCH(handle, cuTransB, cuTransA, N, M, K, &alpha, + B, ldb, strideB, A, lda, strideA, &beta, C, + ldc, strideC, batchCount); + }); #if CUDA_VERSION >= 9010 } diff --git a/paddle/fluid/operators/ngraph/ops/binary_unnary_op.h b/paddle/fluid/operators/ngraph/ops/binary_unnary_op.h index 6610380fcf432d0019f7e844fa9304e151b20efd..0c0d25d0cd1ae536618057ce80388b8eeb81c68a 100644 --- a/paddle/fluid/operators/ngraph/ops/binary_unnary_op.h +++ b/paddle/fluid/operators/ngraph/ops/binary_unnary_op.h @@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#ifdef PADDLE_WITH_NGRAPH #pragma once #include @@ -48,4 +47,3 @@ static void BuildUnaryNode( } // namespace ngraphs } // namespace operators } // namespace paddle -#endif diff --git a/paddle/fluid/operators/ngraph/ops/elementwise_scalar_op.h b/paddle/fluid/operators/ngraph/ops/elementwise_scalar_op.h index 15fbd58b02d2b13a8f5401f7cbe291da35748e83..8f5092963c8b79501ea68c1f521c4678977635ea 100644 --- a/paddle/fluid/operators/ngraph/ops/elementwise_scalar_op.h +++ b/paddle/fluid/operators/ngraph/ops/elementwise_scalar_op.h @@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#ifdef PADDLE_WITH_NGRAPH #pragma once #include @@ -58,4 +57,3 @@ std::shared_ptr ElementwiseScalar( } // namespace ngraphs } // namespace operators } // namespace paddle -#endif diff --git a/paddle/fluid/operators/ngraph/ops/fill_constant_op.h b/paddle/fluid/operators/ngraph/ops/fill_constant_op.h index 5eff69e7b165fa19c775926914b7b3e8fcb043e5..406a4314f89810df192280cc97de245553d5520f 100644 --- a/paddle/fluid/operators/ngraph/ops/fill_constant_op.h +++ b/paddle/fluid/operators/ngraph/ops/fill_constant_op.h @@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#ifdef PADDLE_WITH_NGRAPH #pragma once #include @@ -58,4 +57,3 @@ void BuildFillConstantNode( } // namespace ngraphs } // namespace operators } // namespace paddle -#endif diff --git a/paddle/fluid/operators/ngraph/ops/mean_op.h b/paddle/fluid/operators/ngraph/ops/mean_op.h index 7fcf8f09cd346db8cf6706014e0d4573ced7a86c..4c44bc4c112f401c2707f7babd49a33f238a768f 100644 --- a/paddle/fluid/operators/ngraph/ops/mean_op.h +++ b/paddle/fluid/operators/ngraph/ops/mean_op.h @@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#ifdef PADDLE_WITH_NGRAPH #pragma once #include @@ -65,4 +64,3 @@ void BuildMeanGradNode( } // namespace ngraphs } // namespace operators } // namespace paddle -#endif diff --git a/paddle/fluid/operators/ngraph/ops/mul_op.h b/paddle/fluid/operators/ngraph/ops/mul_op.h index 9e12e5d7c3da04706907c7ae63ce8046ce667f25..4a6cbebe245f891c6c33b2116330a41d89d50e25 100644 --- a/paddle/fluid/operators/ngraph/ops/mul_op.h +++ b/paddle/fluid/operators/ngraph/ops/mul_op.h @@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#ifdef PADDLE_WITH_NGRAPH #pragma once #include @@ -131,4 +130,3 @@ static void BuildMulGradNode( } // namespace ngraphs } // namespace operators } // namespace paddle -#endif diff --git a/paddle/fluid/operators/ngraph/ops/scale_op.h b/paddle/fluid/operators/ngraph/ops/scale_op.h index 24ab0702aa50861b34fe1af7ccaf37d4e1dffc41..91a57d0be606373e985a30b7ac9c73648062d8e4 100644 --- a/paddle/fluid/operators/ngraph/ops/scale_op.h +++ b/paddle/fluid/operators/ngraph/ops/scale_op.h @@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#ifdef PADDLE_WITH_NGRAPH #pragma once #include @@ -38,4 +37,3 @@ void BuildScaleNode( } // namespace ngraphs } // namespace operators } // namespace paddle -#endif diff --git a/paddle/fluid/operators/ngraph/ops/top_k_op.h b/paddle/fluid/operators/ngraph/ops/top_k_op.h index 2b7254497c0e1aab2e653e69e6461f262b929703..ea66953a125860ab1ce8309819b6c433ff32eaaa 100644 --- a/paddle/fluid/operators/ngraph/ops/top_k_op.h +++ b/paddle/fluid/operators/ngraph/ops/top_k_op.h @@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#ifdef PADDLE_WITH_NGRAPH #pragma once #include @@ -48,4 +47,3 @@ void BuildTopKNode( } // namespace ngraphs } // namespace operators } // namespace paddle -#endif diff --git a/paddle/fluid/platform/cuda_helper.h b/paddle/fluid/platform/cuda_helper.h new file mode 100644 index 0000000000000000000000000000000000000000..122de72e15d587cf33b5d9856ac8b1243f666881 --- /dev/null +++ b/paddle/fluid/platform/cuda_helper.h @@ -0,0 +1,58 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include // NOLINT + +#include "paddle/fluid/platform/dynload/cublas.h" +#include "paddle/fluid/platform/macros.h" + +#if CUDA_VERSION < 9000 +enum cublasMath_t { CUBLAS_DEFAULT_MATH = 0 }; +#endif + +namespace paddle { +namespace platform { + +class CublasHandleHolder { + public: + CublasHandleHolder(cudaStream_t stream, cublasMath_t math_type) { + PADDLE_ENFORCE(dynload::cublasCreate(&handle_)); + PADDLE_ENFORCE(dynload::cublasSetStream(handle_, stream)); +#if CUDA_VERSION >= 9000 + if (math_type == CUBLAS_TENSOR_OP_MATH) { + PADDLE_ENFORCE( + dynload::cublasSetMathMode(handle_, CUBLAS_TENSOR_OP_MATH)); + } +#endif + } + + ~CublasHandleHolder() { PADDLE_ENFORCE(dynload::cublasDestroy(handle_)); } + + template + inline void Call(Callback &&callback) const { + std::lock_guard guard(mtx_); + callback(handle_); + } + + private: + DISABLE_COPY_AND_ASSIGN(CublasHandleHolder); + + cublasHandle_t handle_; + mutable std::mutex mtx_; +}; + +} // namespace platform +} // namespace paddle diff --git a/paddle/fluid/platform/device_context.cc b/paddle/fluid/platform/device_context.cc index 6f38dbb7a20dae4c4ea1e448c8572d98800b0213..09f3d3de54e4388f7090621a0fead96b3043d918 100644 --- a/paddle/fluid/platform/device_context.cc +++ b/paddle/fluid/platform/device_context.cc @@ -245,8 +245,15 @@ CUDADeviceContext::CUDADeviceContext(CUDAPlace place) eigen_stream_.reset(new EigenCudaStreamDevice()); eigen_stream_->Reinitialize(&stream_, place); eigen_device_.reset(new Eigen::GpuDevice(eigen_stream_.get())); - PADDLE_ENFORCE(dynload::cublasCreate(&cublas_handle_)); - PADDLE_ENFORCE(dynload::cublasSetStream(cublas_handle_, stream_)); + cublas_handle_.reset(new CublasHandleHolder(stream_, CUBLAS_DEFAULT_MATH)); + + if (TensorCoreAvailable()) { +#if CUDA_VERSION >= 9000 + cublas_tensor_core_handle_.reset( + new CublasHandleHolder(stream_, CUBLAS_TENSOR_OP_MATH)); +#endif + } + if (dynload::HasCUDNN()) { cudnn_holder_.reset(new CudnnHolder(&stream_, place)); } @@ -306,7 +313,8 @@ CUDADeviceContext::~CUDADeviceContext() { SetDeviceId(place_.device); Wait(); WaitStreamCallback(); - PADDLE_ENFORCE(dynload::cublasDestroy(cublas_handle_)); + cublas_handle_.reset(); + cublas_tensor_core_handle_.reset(); eigen_stream_.reset(); eigen_device_.reset(); PADDLE_ENFORCE(cudaStreamDestroy(stream_)); @@ -335,8 +343,8 @@ Eigen::GpuDevice* CUDADeviceContext::eigen_device() const { return eigen_device_.get(); } -cublasHandle_t CUDADeviceContext::cublas_handle() const { - return cublas_handle_; +bool CUDADeviceContext::tensor_core_available() const { + return cublas_tensor_core_handle_ != nullptr; } cudnnHandle_t CUDADeviceContext::cudnn_handle() const { diff --git a/paddle/fluid/platform/device_context.h b/paddle/fluid/platform/device_context.h index 7e875801893f3b73f8efaf33af690f8c855beee4..c81d17380cf894631d06588c007c2e11ce5c7836 100644 --- a/paddle/fluid/platform/device_context.h +++ b/paddle/fluid/platform/device_context.h @@ -20,6 +20,7 @@ limitations under the License. */ #include "paddle/fluid/memory/malloc.h" #include "paddle/fluid/platform/temporary_allocator.h" #ifdef PADDLE_WITH_CUDA +#include "paddle/fluid/platform/cuda_helper.h" #include "paddle/fluid/platform/dynload/cublas.h" #include "paddle/fluid/platform/dynload/cudnn.h" #include "paddle/fluid/platform/gpu_info.h" @@ -209,39 +210,6 @@ class CudnnWorkspaceHandle { std::unique_ptr> guard_; }; -#if CUDA_VERSION >= 9000 -class ScopedCublasMathMode { - public: - ScopedCublasMathMode(cublasHandle_t handle, cublasMath_t new_math_mode) - : handle_(handle) { - need_reset = false; - PADDLE_ENFORCE( - platform::dynload::cublasGetMathMode(handle_, &old_math_mode_), - "Failed to get old cublas math mode"); - if (old_math_mode_ != new_math_mode) { - PADDLE_ENFORCE( - platform::dynload::cublasSetMathMode(handle_, new_math_mode), - "Failed to set old cublas math mode"); - need_reset = true; - } - } - - ~ScopedCublasMathMode() { - if (need_reset) { - PADDLE_ENFORCE( - platform::dynload::cublasSetMathMode(handle_, old_math_mode_), - "Failed to set old cublas math mode"); - } - } - - private: - cublasHandle_t handle_; - cublasMath_t old_math_mode_; - bool need_reset; -}; - -#endif - class CUDADeviceContext : public DeviceContext { public: explicit CUDADeviceContext(CUDAPlace place); @@ -262,8 +230,25 @@ class CUDADeviceContext : public DeviceContext { /*! \brief Return eigen device in the device context. */ Eigen::GpuDevice* eigen_device() const; - /*! \brief Return cublas handle in the device context. */ - cublasHandle_t cublas_handle() const; + /*! \brief Call cublas function safely. */ + template + inline void CublasCall(Callback&& callback) const { + cublas_handle_->Call(std::forward(callback)); + } + + /*! \brief Check whether tensor core is supported */ + bool tensor_core_available() const; + + /*! \brief Call cublas function with Tensor Core safely. If + Tensor Core is not available, use DEFAULT_MATH instead. */ + template + inline void TensorCoreCublasCallIfAvailable(Callback&& callback) const { + if (cublas_tensor_core_handle_) { + cublas_tensor_core_handle_->Call(std::forward(callback)); + } else { + cublas_handle_->Call(std::forward(callback)); + } + } /*! \brief Return cudnn handle in the device context. */ cudnnHandle_t cudnn_handle() const; @@ -282,7 +267,6 @@ class CUDADeviceContext : public DeviceContext { template void RecordEvent(cudaEvent_t ev, Callback callback) { - std::lock_guard guard(mtx_); callback(); PADDLE_ENFORCE(cudaEventRecord(ev, stream_)); } @@ -294,18 +278,6 @@ class CUDADeviceContext : public DeviceContext { void WaitStreamCallback() const { callback_manager_->Wait(); } -#if CUDA_VERSION >= 9000 - /*! \brief CublasCall may need to change cublas's config, - * but the cublas may be hold by multi-thread, so we should - * add lock here. */ - template - void CublasCall(Callback callback, cublasMath_t new_math) { - std::lock_guard guard(cublas_mtx_); - ScopedCublasMathMode scoped_cublas_math(cublas_handle_, new_math); - callback(); - } -#endif - private: CUDAPlace place_; @@ -313,7 +285,9 @@ class CUDADeviceContext : public DeviceContext { std::unique_ptr eigen_stream_; std::unique_ptr cudnn_holder_; cudaStream_t stream_; - cublasHandle_t cublas_handle_; + + std::unique_ptr cublas_handle_; + std::unique_ptr cublas_tensor_core_handle_; int compute_capability_; int runtime_version_; @@ -321,12 +295,10 @@ class CUDADeviceContext : public DeviceContext { int multi_process_; int max_threads_per_mp_; - mutable std::mutex mtx_; - // StreamCallbackManager is thread-safe std::unique_ptr callback_manager_; - mutable std::mutex cublas_mtx_; + DISABLE_COPY_AND_ASSIGN(CUDADeviceContext); }; template <> diff --git a/paddle/fluid/platform/device_context_test.cu b/paddle/fluid/platform/device_context_test.cu index 171d2979a0218ad5e22112190a59866b3e0b617f..5b3aa98efb46b51d6c3edb6d2cbd4200bd0a35c6 100644 --- a/paddle/fluid/platform/device_context_test.cu +++ b/paddle/fluid/platform/device_context_test.cu @@ -43,9 +43,6 @@ TEST(Device, CUDADeviceContext) { ASSERT_NE(nullptr, gpu_device); cudnnHandle_t cudnn_handle = device_context->cudnn_handle(); ASSERT_NE(nullptr, cudnn_handle); - cublasHandle_t cublas_handle = device_context->cublas_handle(); - ASSERT_NE(nullptr, cublas_handle); - ASSERT_NE(nullptr, device_context->stream()); delete device_context; } } diff --git a/python/paddle/fluid/__init__.py b/python/paddle/fluid/__init__.py index f9f3807b1567eaf0be20b522154552a8b157583f..2c17716500ababfab3216a5ec47fecca30065ff1 100644 --- a/python/paddle/fluid/__init__.py +++ b/python/paddle/fluid/__init__.py @@ -155,7 +155,7 @@ def __bootstrap__(): 'fraction_of_gpu_memory_to_use', 'cudnn_deterministic', 'enable_cublas_tensor_op_math', 'conv_workspace_size_limit', 'cudnn_exhaustive_search', 'memory_optimize_debug', 'selected_gpus', - 'cudnn_exhaustive_search_times', 'sync_nccl_allreduce' + 'sync_nccl_allreduce' ] core.init_gflags([sys.argv[0]] +