Commit f8c305b2 authored by T tensor-tang

Merge remote-tracking branch 'ups/develop' into fuse/seqpool_concat_2

test=develop
@@ -87,11 +87,12 @@ Variable* Scope::Var(const std::string& name) {
 }
 
 Variable* Scope::Var(std::string* name) {
-  auto new_name = string::Sprintf("%p.%d", this, vars_.size());
+  SCOPE_VARS_WRITER_LOCK
+  auto new_name = std::to_string(reinterpret_cast<uintptr_t>(this)) + "." +
+                  std::to_string(vars_.size());
   if (name != nullptr) {
     *name = new_name;
   }
-  SCOPE_VARS_WRITER_LOCK
   return VarInternal(new_name);
 }
...
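Note: this hunk takes the writer lock before vars_.size() is read and switches the generated variable name from the printf-style "%p.%d" format to an explicit decimal pointer value. A minimal standalone sketch of the new name format (illustrative only, not Paddle code; the struct and counter are hypothetical stand-ins):

#include <cstdint>
#include <iostream>
#include <string>

struct Dummy {};  // stands in for the Scope object; only its address is used

int main() {
  Dummy scope;
  std::size_t var_count = 3;  // stands in for vars_.size()
  // Name format after the change: "<decimal pointer value>.<count>"
  std::string new_name =
      std::to_string(reinterpret_cast<std::uintptr_t>(&scope)) + "." +
      std::to_string(var_count);
  std::cout << new_name << "\n";  // e.g. "140724603522816.3"
  return 0;
}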
@@ -105,13 +105,15 @@ struct VarIdToTypeIndexMapHolder {
 
 }  // namespace detail
 
-const std::type_index &ToTypeIndex(int var_id) {
+const std::type_index &VarTraitIdToTypeIndex(int var_id) {
   return detail::VarIdToTypeIndexMapHolder::ToTypeIndex(var_id);
 }
 
-const char *ToTypeName(int var_id) { return ToTypeIndex(var_id).name(); }
+const char *ToTypeName(int var_id) {
+  return VarTraitIdToTypeIndex(var_id).name();
+}
 
-int ToTypeId(const std::type_index &type) {
+int TypeIndexToVarTraitId(const std::type_index &type) {
   return detail::VarIdToTypeIndexMapHolder::ToTypeId(type);
 }
...
@@ -66,8 +66,8 @@ namespace paddle {
 namespace framework {
 
 const char *ToTypeName(int var_id);
-const std::type_index &ToTypeIndex(int var_id);
-int ToTypeId(const std::type_index &type);
+const std::type_index &VarTraitIdToTypeIndex(int var_id);
+int TypeIndexToVarTraitId(const std::type_index &type);
 
 namespace detail {
...
@@ -45,10 +45,11 @@ struct TypeIndexChecker {
     constexpr auto kId = VarTypeTrait<Type>::kId;
     std::type_index actual_type(typeid(Type));
     EXPECT_EQ(std::string(ToTypeName(kId)), std::string(actual_type.name()));
-    EXPECT_EQ(ToTypeIndex(kId), actual_type);
-    EXPECT_EQ(ToTypeId(actual_type), kId);
-    EXPECT_EQ(ToTypeIndex(ToTypeId(actual_type)), actual_type);
-    EXPECT_EQ(ToTypeId(ToTypeIndex(kId)), kId);
+    EXPECT_EQ(VarTraitIdToTypeIndex(kId), actual_type);
+    EXPECT_EQ(TypeIndexToVarTraitId(actual_type), kId);
+    EXPECT_EQ(VarTraitIdToTypeIndex(TypeIndexToVarTraitId(actual_type)),
+              actual_type);
+    EXPECT_EQ(TypeIndexToVarTraitId(VarTraitIdToTypeIndex(kId)), kId);
     EXPECT_TRUE(var_id_set->count(kId) == 0);               // NOLINT
     EXPECT_TRUE(type_index_set->count(actual_type) == 0);   // NOLINT
...
@@ -80,8 +80,8 @@ void TestWord2vecPrediction(const std::string& model_path) {
        i++) {
     LOG(INFO) << "data: " << static_cast<float*>(outputs.front().data.data())[i]
               << " result: " << result[i];
-    PADDLE_ENFORCE(static_cast<float*>(outputs.front().data.data())[i],
-                   result[i]);
+    EXPECT_NEAR(static_cast<float*>(outputs.front().data.data())[i], result[i],
+                1e-3);
   }
 }
...
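Note: PADDLE_ENFORCE only evaluates its first argument as the success condition (the remaining arguments feed the failure message), so the old line never actually compared the output against result[i]. EXPECT_NEAR asserts the two values differ by at most the given tolerance. A minimal, self-contained illustration of the tolerance check (hypothetical values, not Paddle code):

#include <gtest/gtest.h>

TEST(Word2vecCheck, NearTolerance) {
  float predicted = 0.10023f;  // hypothetical model output
  float reference = 0.10022f;  // hypothetical expected value
  // Passes as long as |predicted - reference| <= 1e-3.
  EXPECT_NEAR(predicted, reference, 1e-3);
}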
@@ -7,4 +7,5 @@ set(analysis_deps ${analysis_deps}
         ir_graph_build_pass
         ir_analysis_pass
         analysis_passes
+        subgraph_detector
         CACHE INTERNAL "")
@@ -62,27 +62,19 @@ struct CUBlas<float> {
                       cudaDataType_t Atype, int lda, const void *B,
                       cudaDataType_t Btype, int ldb, const float *beta, void *C,
                       cudaDataType_t Ctype, int ldc) {
     // Because the gcc 4.8 doesn't expand template parameter pack that
     // appears in a lambda-expression, I can not use template parameter pack
     // here.
-    auto cublas_call = [&]() {
 #if CUDA_VERSION >= 8000
     VLOG(5) << "use_tensor_op_math: "
-            << (platform::TensorCoreAvailable() ? "True" : "False");
+            << (dev_ctx->tensor_core_available() ? "True" : "False");
+    dev_ctx->TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) {
       PADDLE_ENFORCE(platform::dynload::cublasSgemmEx(
-          dev_ctx->cublas_handle(), transa, transb, m, n, k, alpha, A, Atype,
-          lda, B, Btype, ldb, beta, C, Ctype, ldc));
+          handle, transa, transb, m, n, k, alpha, A, Atype, lda, B, Btype, ldb,
+          beta, C, Ctype, ldc));
+    });
 #else
     PADDLE_THROW("cublasSgemmEx is supported on cuda >= 8.0");
-#endif
-    };
-
-#if CUDA_VERSION >= 9000
-    // NOTES: To use Tensor Core, we should change the cublas config,
-    // but the cublas may be hold by multi-thread.
-    dev_ctx->CublasCall(cublas_call, CUBLAS_TENSOR_OP_MATH);
-#else
-    cublas_call();
-#endif
 #endif
   }
 };
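Note: this and the following hunks drop the pattern of building a cublas_call lambda and forcing CUBLAS_TENSOR_OP_MATH onto the shared handle; instead the device context keeps a second handle permanently configured for tensor-op math and routes the callback to it when it exists. A rough, self-contained sketch of that dispatch idea (simplified stand-in types, not the actual Paddle classes):

#include <functional>
#include <iostream>
#include <string>

struct FakeHandle { std::string math_mode; };  // stands in for cublasHandle_t

struct FakeContext {
  FakeHandle default_handle{"CUBLAS_DEFAULT_MATH"};
  FakeHandle tensor_core_handle{"CUBLAS_TENSOR_OP_MATH"};
  bool has_tensor_core = true;  // false on GPUs without Tensor Cores

  // Mirrors TensorCoreCublasCallIfAvailable: prefer the tensor-core handle,
  // otherwise fall back to the default one.
  void TensorCoreCallIfAvailable(const std::function<void(FakeHandle&)>& cb) {
    cb(has_tensor_core ? tensor_core_handle : default_handle);
  }
};

int main() {
  FakeContext ctx;
  ctx.TensorCoreCallIfAvailable(
      [](FakeHandle& h) { std::cout << "running on " << h.math_mode << "\n"; });
  return 0;
}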
@@ -170,32 +162,24 @@ struct CUBlas<platform::float16> {
                       cudaDataType_t Btype, int ldb, const void *beta, void *C,
                       cudaDataType_t Ctype, int ldc,
                       cudaDataType_t computeType) {
-    auto cublas_call = [&]() {
 #if CUDA_VERSION >= 8000
     cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT;
 #if CUDA_VERSION >= 9000
-    bool use_tensor_op_math = platform::TensorCoreAvailable();
+    bool use_tensor_op_math = dev_ctx->tensor_core_available();
     if (use_tensor_op_math) {
       algo = CUBLAS_GEMM_DFALT_TENSOR_OP;
     }
     VLOG(5) << "use_tensor_op_math: "
             << (use_tensor_op_math ? "True" : "False");
 #endif  // CUDA_VERSION >= 9000
+    dev_ctx->TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) {
       PADDLE_ENFORCE(platform::dynload::cublasGemmEx(
-          dev_ctx->cublas_handle(), transa, transb, m, n, k, alpha, A, Atype,
-          lda, B, Btype, ldb, beta, C, Ctype, ldc, computeType, algo));
+          handle, transa, transb, m, n, k, alpha, A, Atype, lda, B, Btype, ldb,
+          beta, C, Ctype, ldc, computeType, algo));
+    });
 #else
     PADDLE_THROW("cublasGemmEx is supported on cuda >= 8.0");
-#endif
-    };
-
-#if CUDA_VERSION >= 9000
-    // NOTES: To use Tensor Core, we should change the cublas config,
-    // but the cublas may be hold by multi-thread.
-    dev_ctx->CublasCall(cublas_call, CUBLAS_TENSOR_OP_MATH);
-#else
-    cublas_call();
-#endif
 #endif
   }
 };
@@ -223,9 +207,10 @@ void Blas<platform::CUDADeviceContext>::GEMM(CBLAS_TRANSPOSE transA,
                       CUDA_R_32F, N);
   } else {
 #endif  // CUDA_VERSION >= 8000
-    CUBlas<T>::GEMM(context_.cublas_handle(), cuTransB, cuTransA, N, M, K,
-                    &alpha, B, ldb, A, lda, &beta, C, N);
+    context_.CublasCall([&](cublasHandle_t handle) {
+      CUBlas<T>::GEMM(handle, cuTransB, cuTransA, N, M, K, &alpha, B, ldb, A,
+                      lda, &beta, C, N);
+    });
 
 #if CUDA_VERSION >= 8000
   }
@@ -266,9 +251,12 @@ inline void Blas<platform::CUDADeviceContext>::GEMM(
       CUDA_R_16F, lda, &h_beta, C, CUDA_R_16F, N, CUDA_R_32F);
 #else
   // CUDA 7.5 does not support cublasGemmEx, hence we fall back to use hgemm
-  CUBlas<platform::float16>::GEMM(context_.cublas_handle(), cuTransB, cuTransA,
-                                  N, M, K, &h_alpha, h_B, ldb, h_A, lda,
-                                  &h_beta, h_C, N);
+  context_.CublasCall([&](cublasHandle_t handle) {
+    CUBlas<platform::float16>::GEMM(handle, cuTransB, cuTransA, N, M, K,
+                                    &h_alpha, h_B, ldb, h_A, lda, &h_beta, h_C,
+                                    N);
+  });
 #endif  // CUDA_VERSION >= 8000
 }
@@ -292,8 +280,10 @@ void Blas<platform::CUDADeviceContext>::GEMM(bool transA, bool transB, int M,
   } else {
 #endif  // CUDA_VERSION >= 8000
-    CUBlas<T>::GEMM(context_.cublas_handle(), cuTransB, cuTransA, N, M, K,
-                    &alpha, B, ldb, A, lda, &beta, C, ldc);
+    context_.CublasCall([&](cublasHandle_t handle) {
+      CUBlas<T>::GEMM(handle, cuTransB, cuTransA, N, M, K, &alpha, B, ldb, A,
+                      lda, &beta, C, ldc);
+    });
 
 #if CUDA_VERSION >= 8000
   }
@@ -311,16 +301,19 @@ inline void Blas<platform::CUDADeviceContext>::GEMM(
   cublasOperation_t cuTransA = transA ? CUBLAS_OP_T : CUBLAS_OP_N;
   cublasOperation_t cuTransB = transB ? CUBLAS_OP_T : CUBLAS_OP_N;
 
-  CUBlas<platform::float16>::GEMM(context_.cublas_handle(), cuTransB, cuTransA,
-                                  N, M, K, &alpha, B, ldb, A, lda, &beta, C,
-                                  ldc);
+  context_.CublasCall([&](cublasHandle_t handle) {
+    CUBlas<platform::float16>::GEMM(handle, cuTransB, cuTransA, N, M, K, &alpha,
+                                    B, ldb, A, lda, &beta, C, ldc);
+  });
 }
 
 template <>
 template <typename T>
 void Blas<platform::CUDADeviceContext>::AXPY(int n, T alpha, const T *x,
                                              T *y) const {
-  CUBlas<T>::AXPY(context_.cublas_handle(), n, &alpha, x, 1, y, 1);
+  context_.CublasCall([&](cublasHandle_t handle) {
+    CUBlas<T>::AXPY(handle, n, &alpha, x, 1, y, 1);
+  });
 }
 
 template <>
@@ -330,8 +323,9 @@ void Blas<platform::CUDADeviceContext>::GEMV(bool trans_a, int M, int N,
                                              T beta, T *C) const {
   cublasOperation_t cuTransA = !trans_a ? CUBLAS_OP_T : CUBLAS_OP_N;
 
-  CUBlas<T>::GEMV(context_.cublas_handle(), cuTransA, N, M, &alpha, A, N, B, 1,
-                  &beta, C, 1);
+  context_.CublasCall([&](cublasHandle_t handle) {
+    CUBlas<T>::GEMV(handle, cuTransA, N, M, &alpha, A, N, B, 1, &beta, C, 1);
+  });
 }
 
 template <>
@@ -353,28 +347,28 @@ void Blas<platform::CUDADeviceContext>::BatchedGEMM(
 
 #if CUDA_VERSION >= 9010
   if (FLAGS_enable_cublas_tensor_op_math && std::is_same<T, float>::value) {
-    auto cublas_call = [&]() {
-      cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT;
-      bool use_tensor_op_math = platform::TensorCoreAvailable();
-      if (use_tensor_op_math) {
-        algo = CUBLAS_GEMM_DFALT_TENSOR_OP;
-      }
-      VLOG(5) << "use_tensor_op_math: "
-              << (use_tensor_op_math ? "True" : "False");
+    cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT;
+    bool use_tensor_op_math = context_.tensor_core_available();
+    if (use_tensor_op_math) {
+      algo = CUBLAS_GEMM_DFALT_TENSOR_OP;
+    }
+    VLOG(5) << "use_tensor_op_math: "
+            << (use_tensor_op_math ? "True" : "False");
 
+    context_.TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) {
       PADDLE_ENFORCE(platform::dynload::cublasGemmStridedBatchedEx(
-          context_.cublas_handle(), cuTransB, cuTransA, N, M, K, &alpha, B,
-          CUDA_R_32F, ldb, strideB, A, CUDA_R_32F, lda, strideA, &beta, C,
-          CUDA_R_32F, ldc, strideC, batchCount, CUDA_R_32F, algo));
-    };
-    auto &dev_ctx = const_cast<platform::CUDADeviceContext &>(context_);
-    dev_ctx.CublasCall(cublas_call, CUBLAS_TENSOR_OP_MATH);
+          handle, cuTransB, cuTransA, N, M, K, &alpha, B, CUDA_R_32F, ldb,
+          strideB, A, CUDA_R_32F, lda, strideA, &beta, C, CUDA_R_32F, ldc,
+          strideC, batchCount, CUDA_R_32F, algo));
+    });
   } else {
 #endif  // CUDA_VERSION >= 9010
-    CUBlas<T>::GEMM_STRIDED_BATCH(context_.cublas_handle(), cuTransB, cuTransA,
-                                  N, M, K, &alpha, B, ldb, strideB, A, lda,
-                                  strideA, &beta, C, ldc, strideC, batchCount);
+    context_.CublasCall([&](cublasHandle_t handle) {
+      CUBlas<T>::GEMM_STRIDED_BATCH(handle, cuTransB, cuTransA, N, M, K, &alpha,
+                                    B, ldb, strideB, A, lda, strideA, &beta, C,
+                                    ldc, strideC, batchCount);
+    });
 
 #if CUDA_VERSION >= 9010
   }
...
@@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#ifdef PADDLE_WITH_NGRAPH
 #pragma once
 
 #include <string>
@@ -48,4 +47,3 @@ static void BuildUnaryNode(
 }  // namespace ngraphs
 }  // namespace operators
 }  // namespace paddle
-#endif
@@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#ifdef PADDLE_WITH_NGRAPH
 #pragma once
 
 #include <string>
@@ -58,4 +57,3 @@ std::shared_ptr<ngraph::Node> ElementwiseScalar(
 }  // namespace ngraphs
 }  // namespace operators
 }  // namespace paddle
-#endif
@@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#ifdef PADDLE_WITH_NGRAPH
 #pragma once
 
 #include <string>
@@ -58,4 +57,3 @@ void BuildFillConstantNode(
 }  // namespace ngraphs
 }  // namespace operators
 }  // namespace paddle
-#endif
@@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#ifdef PADDLE_WITH_NGRAPH
 #pragma once
 
 #include <functional>
@@ -65,4 +64,3 @@ void BuildMeanGradNode(
 }  // namespace ngraphs
 }  // namespace operators
 }  // namespace paddle
-#endif
@@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#ifdef PADDLE_WITH_NGRAPH
 #pragma once
 
 #include <string>
@@ -131,4 +130,3 @@ static void BuildMulGradNode(
 }  // namespace ngraphs
 }  // namespace operators
 }  // namespace paddle
-#endif
@@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#ifdef PADDLE_WITH_NGRAPH
 #pragma once
 
 #include <string>
@@ -38,4 +37,3 @@ void BuildScaleNode(
 }  // namespace ngraphs
 }  // namespace operators
 }  // namespace paddle
-#endif
@@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#ifdef PADDLE_WITH_NGRAPH
 #pragma once
 
 #include <string>
@@ -48,4 +47,3 @@ void BuildTopKNode(
 }  // namespace ngraphs
 }  // namespace operators
 }  // namespace paddle
-#endif
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <mutex> // NOLINT
#include "paddle/fluid/platform/dynload/cublas.h"
#include "paddle/fluid/platform/macros.h"
#if CUDA_VERSION < 9000
enum cublasMath_t { CUBLAS_DEFAULT_MATH = 0 };
#endif
namespace paddle {
namespace platform {
class CublasHandleHolder {
public:
CublasHandleHolder(cudaStream_t stream, cublasMath_t math_type) {
PADDLE_ENFORCE(dynload::cublasCreate(&handle_));
PADDLE_ENFORCE(dynload::cublasSetStream(handle_, stream));
#if CUDA_VERSION >= 9000
if (math_type == CUBLAS_TENSOR_OP_MATH) {
PADDLE_ENFORCE(
dynload::cublasSetMathMode(handle_, CUBLAS_TENSOR_OP_MATH));
}
#endif
}
~CublasHandleHolder() { PADDLE_ENFORCE(dynload::cublasDestroy(handle_)); }
template <typename Callback>
inline void Call(Callback &&callback) const {
std::lock_guard<std::mutex> guard(mtx_);
callback(handle_);
}
private:
DISABLE_COPY_AND_ASSIGN(CublasHandleHolder);
cublasHandle_t handle_;
mutable std::mutex mtx_;
};
} // namespace platform
} // namespace paddle
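Note: the new class above pairs one cublasHandle_t with a mutex; every use of the handle goes through Call(), so concurrent callers of the same handle are serialized. A hedged usage sketch (assumes a CUDA build and a valid stream; the function name and the placeholder lambda body are illustrative only, not part of the patch):

#include <cuda_runtime.h>
#include "paddle/fluid/platform/cuda_helper.h"

// Illustrative helper: exercise the holder directly. In the patch itself,
// CUDADeviceContext owns the holders (see the hunks below).
void ExampleUseOfHolder(cudaStream_t stream) {
  paddle::platform::CublasHandleHolder holder(stream, CUBLAS_DEFAULT_MATH);
  // Call() takes the holder's lock, hands the wrapped handle to the
  // callback, and releases the lock when the callback returns.
  holder.Call([&](cublasHandle_t handle) {
    // Issue cuBLAS calls on `handle` here; omitted to keep the sketch minimal.
    (void)handle;
  });
}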
@@ -245,8 +245,15 @@ CUDADeviceContext::CUDADeviceContext(CUDAPlace place)
   eigen_stream_.reset(new EigenCudaStreamDevice());
   eigen_stream_->Reinitialize(&stream_, place);
   eigen_device_.reset(new Eigen::GpuDevice(eigen_stream_.get()));
-  PADDLE_ENFORCE(dynload::cublasCreate(&cublas_handle_));
-  PADDLE_ENFORCE(dynload::cublasSetStream(cublas_handle_, stream_));
+  cublas_handle_.reset(new CublasHandleHolder(stream_, CUBLAS_DEFAULT_MATH));
+
+  if (TensorCoreAvailable()) {
+#if CUDA_VERSION >= 9000
+    cublas_tensor_core_handle_.reset(
+        new CublasHandleHolder(stream_, CUBLAS_TENSOR_OP_MATH));
+#endif
+  }
+
   if (dynload::HasCUDNN()) {
     cudnn_holder_.reset(new CudnnHolder(&stream_, place));
   }
@@ -306,7 +313,8 @@ CUDADeviceContext::~CUDADeviceContext() {
   SetDeviceId(place_.device);
   Wait();
   WaitStreamCallback();
-  PADDLE_ENFORCE(dynload::cublasDestroy(cublas_handle_));
+  cublas_handle_.reset();
+  cublas_tensor_core_handle_.reset();
   eigen_stream_.reset();
   eigen_device_.reset();
   PADDLE_ENFORCE(cudaStreamDestroy(stream_));
@@ -335,8 +343,8 @@ Eigen::GpuDevice* CUDADeviceContext::eigen_device() const {
   return eigen_device_.get();
 }
 
-cublasHandle_t CUDADeviceContext::cublas_handle() const {
-  return cublas_handle_;
+bool CUDADeviceContext::tensor_core_available() const {
+  return cublas_tensor_core_handle_ != nullptr;
 }
 
 cudnnHandle_t CUDADeviceContext::cudnn_handle() const {
...
@@ -20,6 +20,7 @@ limitations under the License. */
 #include "paddle/fluid/memory/malloc.h"
 #include "paddle/fluid/platform/temporary_allocator.h"
 #ifdef PADDLE_WITH_CUDA
+#include "paddle/fluid/platform/cuda_helper.h"
 #include "paddle/fluid/platform/dynload/cublas.h"
 #include "paddle/fluid/platform/dynload/cudnn.h"
 #include "paddle/fluid/platform/gpu_info.h"
@@ -209,39 +210,6 @@ class CudnnWorkspaceHandle {
   std::unique_ptr<std::lock_guard<std::mutex>> guard_;
 };
 
-#if CUDA_VERSION >= 9000
-class ScopedCublasMathMode {
- public:
-  ScopedCublasMathMode(cublasHandle_t handle, cublasMath_t new_math_mode)
-      : handle_(handle) {
-    need_reset = false;
-    PADDLE_ENFORCE(
-        platform::dynload::cublasGetMathMode(handle_, &old_math_mode_),
-        "Failed to get old cublas math mode");
-    if (old_math_mode_ != new_math_mode) {
-      PADDLE_ENFORCE(
-          platform::dynload::cublasSetMathMode(handle_, new_math_mode),
-          "Failed to set old cublas math mode");
-      need_reset = true;
-    }
-  }
-
-  ~ScopedCublasMathMode() {
-    if (need_reset) {
-      PADDLE_ENFORCE(
-          platform::dynload::cublasSetMathMode(handle_, old_math_mode_),
-          "Failed to set old cublas math mode");
-    }
-  }
-
- private:
-  cublasHandle_t handle_;
-  cublasMath_t old_math_mode_;
-  bool need_reset;
-};
-#endif
-
 class CUDADeviceContext : public DeviceContext {
  public:
   explicit CUDADeviceContext(CUDAPlace place);
@@ -262,8 +230,25 @@ class CUDADeviceContext : public DeviceContext {
   /*! \brief Return eigen device in the device context. */
   Eigen::GpuDevice* eigen_device() const;
 
-  /*! \brief Return cublas handle in the device context. */
-  cublasHandle_t cublas_handle() const;
+  /*! \brief Call cublas function safely. */
+  template <typename Callback>
+  inline void CublasCall(Callback&& callback) const {
+    cublas_handle_->Call(std::forward<Callback>(callback));
+  }
+
+  /*! \brief Check whether tensor core is supported */
+  bool tensor_core_available() const;
+
+  /*! \brief Call cublas function with Tensor Core safely. If
+      Tensor Core is not available, use DEFAULT_MATH instead. */
+  template <typename Callback>
+  inline void TensorCoreCublasCallIfAvailable(Callback&& callback) const {
+    if (cublas_tensor_core_handle_) {
+      cublas_tensor_core_handle_->Call(std::forward<Callback>(callback));
+    } else {
+      cublas_handle_->Call(std::forward<Callback>(callback));
+    }
+  }
 
   /*! \brief Return cudnn handle in the device context. */
   cudnnHandle_t cudnn_handle() const;
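Note: callers no longer pull a raw cublasHandle_t out of the context; they pass a callback instead, so locking and math-mode selection stay inside CUDADeviceContext. A hedged sketch of the two entry points (the wrapper function and placeholder lambda bodies are illustrative only):

#include <cublas_v2.h>
#include "paddle/fluid/platform/device_context.h"

// Illustrative only: how operator code is expected to use the new API.
void ExampleCublasUsage(const paddle::platform::CUDADeviceContext& dev_ctx) {
  // Plain cuBLAS work always goes through the default-math handle.
  dev_ctx.CublasCall([&](cublasHandle_t handle) {
    // issue cuBLAS calls on `handle` here
    (void)handle;
  });

  // Work that can benefit from Tensor Cores runs on the tensor-op-math
  // handle when the device has one, otherwise on the default handle.
  dev_ctx.TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) {
    (void)handle;
  });
}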
@@ -282,7 +267,6 @@ class CUDADeviceContext : public DeviceContext {
   template <typename Callback>
   void RecordEvent(cudaEvent_t ev, Callback callback) {
-    std::lock_guard<std::mutex> guard(mtx_);
     callback();
     PADDLE_ENFORCE(cudaEventRecord(ev, stream_));
   }
@@ -294,18 +278,6 @@ class CUDADeviceContext : public DeviceContext {
   void WaitStreamCallback() const { callback_manager_->Wait(); }
 
-#if CUDA_VERSION >= 9000
-  /*! \brief CublasCall may need to change cublas's config,
-   *  but the cublas may be hold by multi-thread, so we should
-   *  add lock here. */
-  template <typename Callback>
-  void CublasCall(Callback callback, cublasMath_t new_math) {
-    std::lock_guard<std::mutex> guard(cublas_mtx_);
-    ScopedCublasMathMode scoped_cublas_math(cublas_handle_, new_math);
-    callback();
-  }
-#endif
-
  private:
   CUDAPlace place_;
@@ -313,7 +285,9 @@ class CUDADeviceContext : public DeviceContext {
   std::unique_ptr<EigenCudaStreamDevice> eigen_stream_;
   std::unique_ptr<CudnnHolder> cudnn_holder_;
   cudaStream_t stream_;
-  cublasHandle_t cublas_handle_;
+
+  std::unique_ptr<CublasHandleHolder> cublas_handle_;
+  std::unique_ptr<CublasHandleHolder> cublas_tensor_core_handle_;
 
   int compute_capability_;
   int runtime_version_;
@@ -321,12 +295,10 @@ class CUDADeviceContext : public DeviceContext {
   int multi_process_;
   int max_threads_per_mp_;
 
-  mutable std::mutex mtx_;
-
   // StreamCallbackManager is thread-safe
   std::unique_ptr<StreamCallbackManager> callback_manager_;
 
-  mutable std::mutex cublas_mtx_;
+  DISABLE_COPY_AND_ASSIGN(CUDADeviceContext);
 };
 
 template <>
...
@@ -43,9 +43,6 @@ TEST(Device, CUDADeviceContext) {
     ASSERT_NE(nullptr, gpu_device);
     cudnnHandle_t cudnn_handle = device_context->cudnn_handle();
     ASSERT_NE(nullptr, cudnn_handle);
-    cublasHandle_t cublas_handle = device_context->cublas_handle();
-    ASSERT_NE(nullptr, cublas_handle);
-    ASSERT_NE(nullptr, device_context->stream());
     delete device_context;
   }
 }
...
@@ -155,7 +155,7 @@ def __bootstrap__():
             'fraction_of_gpu_memory_to_use', 'cudnn_deterministic',
             'enable_cublas_tensor_op_math', 'conv_workspace_size_limit',
             'cudnn_exhaustive_search', 'memory_optimize_debug', 'selected_gpus',
-            'cudnn_exhaustive_search_times', 'sync_nccl_allreduce'
+            'sync_nccl_allreduce'
         ]
 
         core.init_gflags([sys.argv[0]] +
...