Commit 81d56ca8 authored by Yu Yang

Remove lazy-initialization in device_context

* Also use `const DeviceContext&` throughout, to prevent `const_cast`

Fix #4169
Fix #3468
Fix #3475
Parent 47b211de
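The idea behind the change, as a minimal standalone sketch (the `FakeHandle`, `Context`, and `DoWork` names below are illustrative stand-ins, not the actual Paddle classes): once the cuBLAS/cuDNN handles are created eagerly in the constructor, the accessors can be `const`, and callers can receive the device context as `const DeviceContext&` without any `const_cast`.

```cpp
#include <cassert>
#include <iostream>

// Stand-in for a library handle such as cublasHandle_t / cudnnHandle_t.
struct FakeHandle {
  int id;
};

class Context {
 public:
  // Eager initialization: the handle exists as soon as the context does,
  // so the accessor needs no lazy "create on first use" branch.
  Context() : handle_{42} {}

  // Because nothing is created lazily, the accessor can be const.
  const FakeHandle& handle() const { return handle_; }

 private:
  FakeHandle handle_;
};

// Mirrors the new math::gemm / math::matmul shape: the context is the first
// parameter and is taken by const reference, so callers never need const_cast.
void DoWork(const Context& ctx) {
  assert(ctx.handle().id == 42);
  std::cout << "handle id: " << ctx.handle().id << std::endl;
}

int main() {
  Context ctx;  // handles are ready as soon as the context is constructed
  DoWork(ctx);  // const& all the way down
  return 0;
}
```

Constructing the handles up front means a `const CUDADeviceContext&` is all that kernels and the math functions ever need, which is what the signature changes in the diff below implement.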
@@ -26,3 +26,4 @@ CMakeFiles
 cmake_install.cmake
 paddle/.timestamp
 python/paddlepaddle.egg-info/
+paddle/pybind/pybind.h
@@ -22,14 +22,14 @@ namespace framework {
 template <>
 Eigen::DefaultDevice& ExecutionContext::GetEigenDevice<
     platform::CPUPlace, Eigen::DefaultDevice>() const {
-  return *device_context_->get_eigen_device<Eigen::DefaultDevice>();
+  return *device_context_.get_eigen_device<Eigen::DefaultDevice>();
 }
 #ifndef PADDLE_ONLY_CPU
 template <>
 Eigen::GpuDevice&
 ExecutionContext::GetEigenDevice<platform::GPUPlace, Eigen::GpuDevice>() const {
-  return *device_context_->get_eigen_device<Eigen::GpuDevice>();
+  return *device_context_.get_eigen_device<Eigen::GpuDevice>();
 }
 #endif
...
@@ -349,7 +349,7 @@ struct EigenDeviceConverter<platform::GPUPlace> {
 class ExecutionContext : public InferShapeContext {
  public:
   ExecutionContext(const OperatorBase& op, const Scope& scope,
-                   const platform::DeviceContext* device_context)
+                   const platform::DeviceContext& device_context)
       : InferShapeContext(op, scope), device_context_(device_context) {}
   template <typename PlaceType,
@@ -357,13 +357,14 @@ class ExecutionContext : public InferShapeContext {
             typename EigenDeviceConverter<PlaceType>::EigenDeviceType>
   DeviceType& GetEigenDevice() const;
-  platform::Place GetPlace() const { return device_context_->GetPlace(); }
-  const platform::DeviceContext* device_context() const {
+  platform::Place GetPlace() const { return device_context_.GetPlace(); }
+  const platform::DeviceContext& device_context() const {
     return device_context_;
   }
-  const platform::DeviceContext* device_context_;
+ private:
+  const platform::DeviceContext& device_context_;
 };
 class OpKernel {
@@ -416,7 +417,7 @@ class OperatorWithKernel : public OperatorBase {
   void Run(const Scope& scope,
            const platform::DeviceContext& dev_ctx) const final {
     auto& opKernel = AllOpKernels().at(type_).at(OpKernelKey(dev_ctx));
-    opKernel->Compute(ExecutionContext(*this, scope, &dev_ctx));
+    opKernel->Compute(ExecutionContext(*this, scope, dev_ctx));
   }
   static std::unordered_map<std::string /* op_type */, OpKernelMap>&
...
@@ -19,12 +19,13 @@ namespace operators {
 namespace math {
 template <>
-void gemm<platform::CPUPlace, float>(const CBLAS_TRANSPOSE transA,
+void gemm<platform::CPUPlace, float>(const platform::DeviceContext& context,
+                                     const CBLAS_TRANSPOSE transA,
                                      const CBLAS_TRANSPOSE transB, const int M,
                                      const int N, const int K,
                                      const float alpha, const float* A,
-                                     const float* B, const float beta, float* C,
-                                     platform::DeviceContext* context) {
+                                     const float* B, const float beta,
+                                     float* C) {
   int lda = (transA == CblasNoTrans) ? K : M;
   int ldb = (transB == CblasNoTrans) ? N : K;
   int ldc = N;
@@ -33,13 +34,13 @@ void gemm<platform::CPUPlace, float>(const CBLAS_TRANSPOSE transA,
 }
 template <>
-void gemm<platform::CPUPlace, double>(const CBLAS_TRANSPOSE transA,
+void gemm<platform::CPUPlace, double>(const platform::DeviceContext& context,
+                                      const CBLAS_TRANSPOSE transA,
                                       const CBLAS_TRANSPOSE transB, const int M,
                                       const int N, const int K,
                                       const double alpha, const double* A,
                                       const double* B, const double beta,
-                                      double* C,
-                                      platform::DeviceContext* context) {
+                                      double* C) {
   int lda = (transA == CblasNoTrans) ? K : M;
   int ldb = (transB == CblasNoTrans) ? N : K;
   int ldc = N;
@@ -48,13 +49,10 @@ void gemm<platform::CPUPlace, double>(const CBLAS_TRANSPOSE transA,
 }
 template <>
-void matmul<platform::CPUPlace, float>(const framework::Tensor& matrix_a,
-                                       bool trans_a,
-                                       const framework::Tensor& matrix_b,
-                                       bool trans_b, float alpha,
-                                       framework::Tensor* matrix_out,
-                                       float beta,
-                                       platform::DeviceContext* context) {
+void matmul<platform::CPUPlace, float>(
+    const platform::DeviceContext& context, const framework::Tensor& matrix_a,
+    bool trans_a, const framework::Tensor& matrix_b, bool trans_b, float alpha,
+    framework::Tensor* matrix_out, float beta) {
   auto dim_a = matrix_a.dims();
   auto dim_b = matrix_b.dims();
   auto dim_out = matrix_out->dims();
@@ -74,18 +72,15 @@ void matmul<platform::CPUPlace, float>(const framework::Tensor& matrix_a,
   CBLAS_TRANSPOSE transB = (trans_b == false) ? CblasNoTrans : CblasTrans;
   gemm<platform::CPUPlace, float>(
-      transA, transB, M, N, K, alpha, matrix_a.data<float>(),
-      matrix_b.data<float>(), beta, matrix_out->data<float>(), context);
+      context, transA, transB, M, N, K, alpha, matrix_a.data<float>(),
+      matrix_b.data<float>(), beta, matrix_out->data<float>());
 }
 template <>
-void matmul<platform::CPUPlace, double>(const framework::Tensor& matrix_a,
-                                        bool trans_a,
-                                        const framework::Tensor& matrix_b,
-                                        bool trans_b, double alpha,
-                                        framework::Tensor* matrix_out,
-                                        double beta,
-                                        platform::DeviceContext* context) {
+void matmul<platform::CPUPlace, double>(
+    const platform::DeviceContext& context, const framework::Tensor& matrix_a,
+    bool trans_a, const framework::Tensor& matrix_b, bool trans_b, double alpha,
+    framework::Tensor* matrix_out, double beta) {
   auto dim_a = matrix_a.dims();
   auto dim_b = matrix_b.dims();
   auto dim_out = matrix_out->dims();
@@ -105,8 +100,8 @@ void matmul<platform::CPUPlace, double>(const framework::Tensor& matrix_a,
   CBLAS_TRANSPOSE transB = (trans_b == false) ? CblasNoTrans : CblasTrans;
   gemm<platform::CPUPlace, double>(
-      transA, transB, M, N, K, alpha, matrix_a.data<double>(),
-      matrix_b.data<double>(), beta, matrix_out->data<double>(), context);
+      context, transA, transB, M, N, K, alpha, matrix_a.data<double>(),
+      matrix_b.data<double>(), beta, matrix_out->data<double>());
 }
 }  // namespace math
...
@@ -19,12 +19,13 @@ namespace operators {
 namespace math {
 template <>
-void gemm<platform::GPUPlace, float>(const CBLAS_TRANSPOSE transA,
+void gemm<platform::GPUPlace, float>(const platform::DeviceContext& context,
+                                     const CBLAS_TRANSPOSE transA,
                                      const CBLAS_TRANSPOSE transB, const int M,
                                      const int N, const int K,
                                      const float alpha, const float* A,
-                                     const float* B, const float beta, float* C,
-                                     platform::DeviceContext* context) {
+                                     const float* B, const float beta,
+                                     float* C) {
   // Note that cublas follows fortran order, so the order is different from
   // the cblas convention.
   int lda = (transA == CblasNoTrans) ? K : M;
@@ -35,18 +36,19 @@ void gemm<platform::GPUPlace, float>(const CBLAS_TRANSPOSE transA,
       (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
   PADDLE_ENFORCE(platform::dynload::cublasSgemm(
-      reinterpret_cast<platform::CUDADeviceContext*>(context)->cublas_handle(),
+      reinterpret_cast<const platform::CUDADeviceContext&>(context)
+          .cublas_handle(),
       cuTransB, cuTransA, N, M, K, &alpha, B, ldb, A, lda, &beta, C, N));
 }
 template <>
-void gemm<platform::GPUPlace, double>(const CBLAS_TRANSPOSE transA,
+void gemm<platform::GPUPlace, double>(const platform::DeviceContext& context,
+                                      const CBLAS_TRANSPOSE transA,
                                       const CBLAS_TRANSPOSE transB, const int M,
                                       const int N, const int K,
                                       const double alpha, const double* A,
                                       const double* B, const double beta,
-                                      double* C,
-                                      platform::DeviceContext* context) {
+                                      double* C) {
   // Note that cublas follows fortran order, so the order is different from
   // the cblas convention.
   int lda = (transA == CblasNoTrans) ? K : M;
@@ -56,18 +58,16 @@ void gemm<platform::GPUPlace, double>(const CBLAS_TRANSPOSE transA,
   cublasOperation_t cuTransB =
       (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
   PADDLE_ENFORCE(platform::dynload::cublasDgemm(
-      reinterpret_cast<platform::CUDADeviceContext*>(context)->cublas_handle(),
+      reinterpret_cast<const platform::CUDADeviceContext&>(context)
+          .cublas_handle(),
       cuTransB, cuTransA, N, M, K, &alpha, B, ldb, A, lda, &beta, C, N));
 }
 template <>
-void matmul<platform::GPUPlace, float>(const framework::Tensor& matrix_a,
-                                       bool trans_a,
-                                       const framework::Tensor& matrix_b,
-                                       bool trans_b, float alpha,
-                                       framework::Tensor* matrix_out,
-                                       float beta,
-                                       platform::DeviceContext* context) {
+void matmul<platform::GPUPlace, float>(
+    const platform::DeviceContext& context, const framework::Tensor& matrix_a,
+    bool trans_a, const framework::Tensor& matrix_b, bool trans_b, float alpha,
+    framework::Tensor* matrix_out, float beta) {
   auto dim_a = matrix_a.dims();
   auto dim_b = matrix_b.dims();
   auto dim_out = matrix_out->dims();
@@ -87,18 +87,15 @@ void matmul<platform::GPUPlace, float>(const framework::Tensor& matrix_a,
   CBLAS_TRANSPOSE transB = (trans_b == false) ? CblasNoTrans : CblasTrans;
   gemm<platform::GPUPlace, float>(
-      transA, transB, M, N, K, alpha, matrix_a.data<float>(),
-      matrix_b.data<float>(), beta, matrix_out->data<float>(), context);
+      context, transA, transB, M, N, K, alpha, matrix_a.data<float>(),
+      matrix_b.data<float>(), beta, matrix_out->data<float>());
 }
 template <>
-void matmul<platform::GPUPlace, double>(const framework::Tensor& matrix_a,
-                                        bool trans_a,
-                                        const framework::Tensor& matrix_b,
-                                        bool trans_b, double alpha,
-                                        framework::Tensor* matrix_out,
-                                        double beta,
-                                        platform::DeviceContext* context) {
+void matmul<platform::GPUPlace, double>(
+    const platform::DeviceContext& context, const framework::Tensor& matrix_a,
+    bool trans_a, const framework::Tensor& matrix_b, bool trans_b, double alpha,
+    framework::Tensor* matrix_out, double beta) {
   auto dim_a = matrix_a.dims();
   auto dim_b = matrix_b.dims();
   auto dim_out = matrix_out->dims();
@@ -118,8 +115,8 @@ void matmul<platform::GPUPlace, double>(const framework::Tensor& matrix_a,
   CBLAS_TRANSPOSE transB = (trans_b == false) ? CblasNoTrans : CblasTrans;
   gemm<platform::GPUPlace, double>(
-      transA, transB, M, N, K, alpha, matrix_a.data<double>(),
-      matrix_b.data<double>(), beta, matrix_out->data<double>(), context);
+      context, transA, transB, M, N, K, alpha, matrix_a.data<double>(),
+      matrix_b.data<double>(), beta, matrix_out->data<double>());
 }
 }  // namespace math
...
@@ -66,16 +66,16 @@ namespace math {
 // For more detailed info, please refer to
 // http://www.netlib.org/lapack/explore-html/d4/de2/sgemm_8f.html
 template <typename Place, typename T>
-void gemm(const CBLAS_TRANSPOSE transA, const CBLAS_TRANSPOSE transB,
-          const int M, const int N, const int K, const T alpha, const T* A,
-          const T* B, const T beta, T* C, platform::DeviceContext* context);
+void gemm(const platform::DeviceContext& context, const CBLAS_TRANSPOSE transA,
+          const CBLAS_TRANSPOSE transB, const int M, const int N, const int K,
+          const T alpha, const T* A, const T* B, const T beta, T* C);
 // matrix multiply with continuous memory
 template <typename Place, typename T>
-void matmul(const framework::Tensor& matrix_a, bool trans_a,
+void matmul(const platform::DeviceContext& context,
+            const framework::Tensor& matrix_a, bool trans_a,
             const framework::Tensor& matrix_b, bool trans_b, T alpha,
-            framework::Tensor* matrix_out, T beta,
-            platform::DeviceContext* context);
+            framework::Tensor* matrix_out, T beta);
 }  // namespace math
 }  // namespace operators
...
@@ -15,8 +15,7 @@ TEST(math_function, notrans_mul_trans) {
   memcpy(input1_ptr, arr, 6 * sizeof(float));
   auto* gpu_place = new paddle::platform::GPUPlace(0);
-  paddle::platform::DeviceContext* context =
-      new paddle::platform::CUDADeviceContext(*gpu_place);
+  paddle::platform::CUDADeviceContext context(*gpu_place);
   input1_gpu.CopyFrom<float>(input1, *gpu_place);
   input2_gpu.CopyFrom<float>(input1, *gpu_place);
@@ -24,7 +23,7 @@ TEST(math_function, notrans_mul_trans) {
   out_gpu.mutable_data<float>({2, 2}, *gpu_place);
   paddle::operators::math::matmul<paddle::platform::GPUPlace, float>(
-      input1_gpu, false, input2_gpu, true, 1, &out_gpu, 0, context);
+      context, input1_gpu, false, input2_gpu, true, 1, &out_gpu, 0);
   out.CopyFrom<float>(out_gpu, *cpu_place);
@@ -33,6 +32,7 @@ TEST(math_function, notrans_mul_trans) {
   EXPECT_EQ(out_ptr[1], 14);
   EXPECT_EQ(out_ptr[2], 14);
   EXPECT_EQ(out_ptr[3], 50);
+  delete gpu_place;
 }
 TEST(math_function, trans_mul_notrans) {
@@ -48,8 +48,7 @@ TEST(math_function, trans_mul_notrans) {
   memcpy(input1_ptr, arr, 6 * sizeof(float));
   auto* gpu_place = new paddle::platform::GPUPlace(0);
-  paddle::platform::DeviceContext* context =
-      new paddle::platform::CUDADeviceContext(*gpu_place);
+  paddle::platform::CUDADeviceContext context(*gpu_place);
   input1_gpu.CopyFrom<float>(input1, *gpu_place);
   input2_gpu.CopyFrom<float>(input1, *gpu_place);
@@ -57,7 +56,7 @@ TEST(math_function, trans_mul_notrans) {
   out_gpu.mutable_data<float>({3, 3}, *gpu_place);
   paddle::operators::math::matmul<paddle::platform::GPUPlace, float>(
-      input1_gpu, true, input2_gpu, false, 1, &out_gpu, 0, context);
+      context, input1_gpu, true, input2_gpu, false, 1, &out_gpu, 0);
   out.CopyFrom<float>(out_gpu, *cpu_place);
@@ -71,5 +70,6 @@ TEST(math_function, trans_mul_notrans) {
   EXPECT_EQ(out_ptr[6], 15);
   EXPECT_EQ(out_ptr[7], 22);
   EXPECT_EQ(out_ptr[8], 29);
+  delete gpu_place;
 }
 #endif
@@ -46,10 +46,8 @@ class MulKernel : public framework::OpKernel {
                                 : *y;
     z->mutable_data<T>(context.GetPlace());
-    auto* device_context =
-        const_cast<platform::DeviceContext*>(context.device_context_);
-    math::matmul<Place, T>(x_matrix, false, y_matrix, false, 1, z, 0,
-                           device_context);
+    math::matmul<Place, T>(context.device_context(), x_matrix, false, y_matrix,
+                           false, 1, z, 0);
   }
 };
@@ -71,16 +69,14 @@ class MulGradKernel : public framework::OpKernel {
     Tensor* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
     Tensor* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
-    auto* device_context =
-        const_cast<platform::DeviceContext*>(ctx.device_context_);
     if (dx) {
       dx->mutable_data<T>(ctx.GetPlace());
       Tensor dx_matrix = dx->dims().size() > 2 ? framework::ReshapeToMatrix<T>(
                                                      *dx, x_num_col_dims)
                                                : *dx;
       // dx = dout * y'. dx: M x K, dout : M x N, y : K x N
-      math::matmul<Place, T>(*dout, false, y_matrix, true, 1, &dx_matrix, 0,
-                             device_context);
+      math::matmul<Place, T>(ctx.device_context(), *dout, false, y_matrix, true,
+                             1, &dx_matrix, 0);
     }
     if (dy) {
       dy->mutable_data<T>(ctx.GetPlace());
@@ -88,8 +84,8 @@ class MulGradKernel : public framework::OpKernel {
                                                      *dy, y_num_col_dims)
                                                : *dy;
       // dy = x' * dout. dy K x N, dout : M x N, x : M x K
-      math::matmul<Place, T>(x_matrix, true, *dout, false, 1, &dy_matrix, 0,
-                             device_context);
+      math::matmul<Place, T>(ctx.device_context(), x_matrix, true, *dout, false,
+                             1, &dy_matrix, 0);
     }
   }
 };
...
@@ -101,19 +101,17 @@ CUDADeviceContext::CUDADeviceContext(GPUPlace place) : place_(place) {
   eigen_stream_.reset(new EigenCudaStreamDevice());
   eigen_stream_->Reinitialize(&stream_, place);
   eigen_device_.reset(new Eigen::GpuDevice(eigen_stream_.get()));
+  PADDLE_ENFORCE(dynload::cublasCreate(&cublas_handle_));
+  PADDLE_ENFORCE(dynload::cublasSetStream(cublas_handle_, stream_));
+  PADDLE_ENFORCE(dynload::cudnnCreate(&cudnn_handle_));
+  PADDLE_ENFORCE(dynload::cudnnSetStream(cudnn_handle_, stream_));
 }
 CUDADeviceContext::~CUDADeviceContext() {
   SetDeviceId(place_.device);
   Wait();
-  if (cublas_handle_) {
-    PADDLE_ENFORCE(dynload::cublasDestroy(cublas_handle_));
-  }
-  if (cudnn_handle_) {
-    PADDLE_ENFORCE(dynload::cudnnDestroy(cudnn_handle_));
-  }
+  PADDLE_ENFORCE(dynload::cublasDestroy(cublas_handle_));
+  PADDLE_ENFORCE(dynload::cudnnDestroy(cudnn_handle_));
   eigen_stream_.reset();
   eigen_device_.reset();
   PADDLE_ENFORCE(cudaStreamDestroy(stream_));
@@ -129,25 +127,13 @@ Eigen::GpuDevice* CUDADeviceContext::eigen_device() const {
   return eigen_device_.get();
 }
-cublasHandle_t CUDADeviceContext::cublas_handle() {
-  if (!cublas_handle_) {
-    SetDeviceId(place_.device);
-    PADDLE_ENFORCE(dynload::cublasCreate(&cublas_handle_));
-    PADDLE_ENFORCE(dynload::cublasSetStream(cublas_handle_, stream_));
-  }
+cublasHandle_t CUDADeviceContext::cublas_handle() const {
   return cublas_handle_;
 }
-cudnnHandle_t CUDADeviceContext::cudnn_handle() {
-  if (!cudnn_handle_) {
-    SetDeviceId(place_.device);
-    PADDLE_ENFORCE(dynload::cudnnCreate(&cudnn_handle_));
-    PADDLE_ENFORCE(dynload::cudnnSetStream(cudnn_handle_, stream_));
-  }
-  return cudnn_handle_;
-}
+cudnnHandle_t CUDADeviceContext::cudnn_handle() const { return cudnn_handle_; }
-cudaStream_t CUDADeviceContext::stream() { return stream_; }
+cudaStream_t CUDADeviceContext::stream() const { return stream_; }
 #endif  // PADDLE_ONLY_CPU
...
@@ -67,16 +67,14 @@ class CUDADeviceContext : public DeviceContext {
   /*! \brief Return eigen device in the device context. */
   Eigen::GpuDevice* eigen_device() const;
-  // clang-format off
   /*! \brief Return cublas handle in the device context. */
-  cublasHandle_t cublas_handle();
+  cublasHandle_t cublas_handle() const;
   /*! \brief Return cudnn handle in the device context. */
-  cudnnHandle_t cudnn_handle();
+  cudnnHandle_t cudnn_handle() const;
   /*! \brief Return cuda stream in the device context. */
-  cudaStream_t stream();
+  cudaStream_t stream() const;
-  // clang-format on
  private:
   GPUPlace place_;
@@ -84,11 +82,9 @@ class CUDADeviceContext : public DeviceContext {
   std::unique_ptr<Eigen::GpuDevice> eigen_device_;
   std::unique_ptr<EigenCudaStreamDevice> eigen_stream_;
-  // clang-format off
-  cudaStream_t stream_{nullptr};
-  cudnnHandle_t cudnn_handle_{nullptr};
-  cublasHandle_t cublas_handle_{nullptr};
-  // clang-format on
+  cudaStream_t stream_;
+  cudnnHandle_t cudnn_handle_;
+  cublasHandle_t cublas_handle_;
 };
 #endif
...