提交 81d56ca8 编写于 作者: Y Yu Yang

Remove lazy-initialization in device_context

* Also use `const DeviceContext&` all the time, to prevent `const_cast`

Fix #4169
Fix #3468
Fix #3475
上级 47b211de
......@@ -26,3 +26,4 @@ CMakeFiles
cmake_install.cmake
paddle/.timestamp
python/paddlepaddle.egg-info/
paddle/pybind/pybind.h
......@@ -22,14 +22,14 @@ namespace framework {
template <>
Eigen::DefaultDevice& ExecutionContext::GetEigenDevice<
platform::CPUPlace, Eigen::DefaultDevice>() const {
return *device_context_->get_eigen_device<Eigen::DefaultDevice>();
return *device_context_.get_eigen_device<Eigen::DefaultDevice>();
}
#ifndef PADDLE_ONLY_CPU
template <>
Eigen::GpuDevice&
ExecutionContext::GetEigenDevice<platform::GPUPlace, Eigen::GpuDevice>() const {
return *device_context_->get_eigen_device<Eigen::GpuDevice>();
return *device_context_.get_eigen_device<Eigen::GpuDevice>();
}
#endif
......
......@@ -349,7 +349,7 @@ struct EigenDeviceConverter<platform::GPUPlace> {
class ExecutionContext : public InferShapeContext {
public:
ExecutionContext(const OperatorBase& op, const Scope& scope,
const platform::DeviceContext* device_context)
const platform::DeviceContext& device_context)
: InferShapeContext(op, scope), device_context_(device_context) {}
template <typename PlaceType,
......@@ -357,13 +357,14 @@ class ExecutionContext : public InferShapeContext {
typename EigenDeviceConverter<PlaceType>::EigenDeviceType>
DeviceType& GetEigenDevice() const;
platform::Place GetPlace() const { return device_context_->GetPlace(); }
platform::Place GetPlace() const { return device_context_.GetPlace(); }
const platform::DeviceContext* device_context() const {
const platform::DeviceContext& device_context() const {
return device_context_;
}
const platform::DeviceContext* device_context_;
private:
const platform::DeviceContext& device_context_;
};
class OpKernel {
......@@ -416,7 +417,7 @@ class OperatorWithKernel : public OperatorBase {
void Run(const Scope& scope,
const platform::DeviceContext& dev_ctx) const final {
auto& opKernel = AllOpKernels().at(type_).at(OpKernelKey(dev_ctx));
opKernel->Compute(ExecutionContext(*this, scope, &dev_ctx));
opKernel->Compute(ExecutionContext(*this, scope, dev_ctx));
}
static std::unordered_map<std::string /* op_type */, OpKernelMap>&
......
......@@ -19,12 +19,13 @@ namespace operators {
namespace math {
template <>
void gemm<platform::CPUPlace, float>(const CBLAS_TRANSPOSE transA,
void gemm<platform::CPUPlace, float>(const platform::DeviceContext& context,
const CBLAS_TRANSPOSE transA,
const CBLAS_TRANSPOSE transB, const int M,
const int N, const int K,
const float alpha, const float* A,
const float* B, const float beta, float* C,
platform::DeviceContext* context) {
const float* B, const float beta,
float* C) {
int lda = (transA == CblasNoTrans) ? K : M;
int ldb = (transB == CblasNoTrans) ? N : K;
int ldc = N;
......@@ -33,13 +34,13 @@ void gemm<platform::CPUPlace, float>(const CBLAS_TRANSPOSE transA,
}
template <>
void gemm<platform::CPUPlace, double>(const CBLAS_TRANSPOSE transA,
void gemm<platform::CPUPlace, double>(const platform::DeviceContext& context,
const CBLAS_TRANSPOSE transA,
const CBLAS_TRANSPOSE transB, const int M,
const int N, const int K,
const double alpha, const double* A,
const double* B, const double beta,
double* C,
platform::DeviceContext* context) {
double* C) {
int lda = (transA == CblasNoTrans) ? K : M;
int ldb = (transB == CblasNoTrans) ? N : K;
int ldc = N;
......@@ -48,13 +49,10 @@ void gemm<platform::CPUPlace, double>(const CBLAS_TRANSPOSE transA,
}
template <>
void matmul<platform::CPUPlace, float>(const framework::Tensor& matrix_a,
bool trans_a,
const framework::Tensor& matrix_b,
bool trans_b, float alpha,
framework::Tensor* matrix_out,
float beta,
platform::DeviceContext* context) {
void matmul<platform::CPUPlace, float>(
const platform::DeviceContext& context, const framework::Tensor& matrix_a,
bool trans_a, const framework::Tensor& matrix_b, bool trans_b, float alpha,
framework::Tensor* matrix_out, float beta) {
auto dim_a = matrix_a.dims();
auto dim_b = matrix_b.dims();
auto dim_out = matrix_out->dims();
......@@ -74,18 +72,15 @@ void matmul<platform::CPUPlace, float>(const framework::Tensor& matrix_a,
CBLAS_TRANSPOSE transB = (trans_b == false) ? CblasNoTrans : CblasTrans;
gemm<platform::CPUPlace, float>(
transA, transB, M, N, K, alpha, matrix_a.data<float>(),
matrix_b.data<float>(), beta, matrix_out->data<float>(), context);
context, transA, transB, M, N, K, alpha, matrix_a.data<float>(),
matrix_b.data<float>(), beta, matrix_out->data<float>());
}
template <>
void matmul<platform::CPUPlace, double>(const framework::Tensor& matrix_a,
bool trans_a,
const framework::Tensor& matrix_b,
bool trans_b, double alpha,
framework::Tensor* matrix_out,
double beta,
platform::DeviceContext* context) {
void matmul<platform::CPUPlace, double>(
const platform::DeviceContext& context, const framework::Tensor& matrix_a,
bool trans_a, const framework::Tensor& matrix_b, bool trans_b, double alpha,
framework::Tensor* matrix_out, double beta) {
auto dim_a = matrix_a.dims();
auto dim_b = matrix_b.dims();
auto dim_out = matrix_out->dims();
......@@ -105,8 +100,8 @@ void matmul<platform::CPUPlace, double>(const framework::Tensor& matrix_a,
CBLAS_TRANSPOSE transB = (trans_b == false) ? CblasNoTrans : CblasTrans;
gemm<platform::CPUPlace, double>(
transA, transB, M, N, K, alpha, matrix_a.data<double>(),
matrix_b.data<double>(), beta, matrix_out->data<double>(), context);
context, transA, transB, M, N, K, alpha, matrix_a.data<double>(),
matrix_b.data<double>(), beta, matrix_out->data<double>());
}
} // namespace math
......
......@@ -19,12 +19,13 @@ namespace operators {
namespace math {
template <>
void gemm<platform::GPUPlace, float>(const CBLAS_TRANSPOSE transA,
void gemm<platform::GPUPlace, float>(const platform::DeviceContext& context,
const CBLAS_TRANSPOSE transA,
const CBLAS_TRANSPOSE transB, const int M,
const int N, const int K,
const float alpha, const float* A,
const float* B, const float beta, float* C,
platform::DeviceContext* context) {
const float* B, const float beta,
float* C) {
// Note that cublas follows fortran order, so the order is different from
// the cblas convention.
int lda = (transA == CblasNoTrans) ? K : M;
......@@ -35,18 +36,19 @@ void gemm<platform::GPUPlace, float>(const CBLAS_TRANSPOSE transA,
(transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
PADDLE_ENFORCE(platform::dynload::cublasSgemm(
reinterpret_cast<platform::CUDADeviceContext*>(context)->cublas_handle(),
reinterpret_cast<const platform::CUDADeviceContext&>(context)
.cublas_handle(),
cuTransB, cuTransA, N, M, K, &alpha, B, ldb, A, lda, &beta, C, N));
}
template <>
void gemm<platform::GPUPlace, double>(const CBLAS_TRANSPOSE transA,
void gemm<platform::GPUPlace, double>(const platform::DeviceContext& context,
const CBLAS_TRANSPOSE transA,
const CBLAS_TRANSPOSE transB, const int M,
const int N, const int K,
const double alpha, const double* A,
const double* B, const double beta,
double* C,
platform::DeviceContext* context) {
double* C) {
// Note that cublas follows fortran order, so the order is different from
// the cblas convention.
int lda = (transA == CblasNoTrans) ? K : M;
......@@ -56,18 +58,16 @@ void gemm<platform::GPUPlace, double>(const CBLAS_TRANSPOSE transA,
cublasOperation_t cuTransB =
(transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
PADDLE_ENFORCE(platform::dynload::cublasDgemm(
reinterpret_cast<platform::CUDADeviceContext*>(context)->cublas_handle(),
reinterpret_cast<const platform::CUDADeviceContext&>(context)
.cublas_handle(),
cuTransB, cuTransA, N, M, K, &alpha, B, ldb, A, lda, &beta, C, N));
}
template <>
void matmul<platform::GPUPlace, float>(const framework::Tensor& matrix_a,
bool trans_a,
const framework::Tensor& matrix_b,
bool trans_b, float alpha,
framework::Tensor* matrix_out,
float beta,
platform::DeviceContext* context) {
void matmul<platform::GPUPlace, float>(
const platform::DeviceContext& context, const framework::Tensor& matrix_a,
bool trans_a, const framework::Tensor& matrix_b, bool trans_b, float alpha,
framework::Tensor* matrix_out, float beta) {
auto dim_a = matrix_a.dims();
auto dim_b = matrix_b.dims();
auto dim_out = matrix_out->dims();
......@@ -87,18 +87,15 @@ void matmul<platform::GPUPlace, float>(const framework::Tensor& matrix_a,
CBLAS_TRANSPOSE transB = (trans_b == false) ? CblasNoTrans : CblasTrans;
gemm<platform::GPUPlace, float>(
transA, transB, M, N, K, alpha, matrix_a.data<float>(),
matrix_b.data<float>(), beta, matrix_out->data<float>(), context);
context, transA, transB, M, N, K, alpha, matrix_a.data<float>(),
matrix_b.data<float>(), beta, matrix_out->data<float>());
}
template <>
void matmul<platform::GPUPlace, double>(const framework::Tensor& matrix_a,
bool trans_a,
const framework::Tensor& matrix_b,
bool trans_b, double alpha,
framework::Tensor* matrix_out,
double beta,
platform::DeviceContext* context) {
void matmul<platform::GPUPlace, double>(
const platform::DeviceContext& context, const framework::Tensor& matrix_a,
bool trans_a, const framework::Tensor& matrix_b, bool trans_b, double alpha,
framework::Tensor* matrix_out, double beta) {
auto dim_a = matrix_a.dims();
auto dim_b = matrix_b.dims();
auto dim_out = matrix_out->dims();
......@@ -118,8 +115,8 @@ void matmul<platform::GPUPlace, double>(const framework::Tensor& matrix_a,
CBLAS_TRANSPOSE transB = (trans_b == false) ? CblasNoTrans : CblasTrans;
gemm<platform::GPUPlace, double>(
transA, transB, M, N, K, alpha, matrix_a.data<double>(),
matrix_b.data<double>(), beta, matrix_out->data<double>(), context);
context, transA, transB, M, N, K, alpha, matrix_a.data<double>(),
matrix_b.data<double>(), beta, matrix_out->data<double>());
}
} // namespace math
......
......@@ -66,16 +66,16 @@ namespace math {
// For more detailed info, please refer to
// http://www.netlib.org/lapack/explore-html/d4/de2/sgemm_8f.html
template <typename Place, typename T>
void gemm(const CBLAS_TRANSPOSE transA, const CBLAS_TRANSPOSE transB,
const int M, const int N, const int K, const T alpha, const T* A,
const T* B, const T beta, T* C, platform::DeviceContext* context);
void gemm(const platform::DeviceContext& context, const CBLAS_TRANSPOSE transA,
const CBLAS_TRANSPOSE transB, const int M, const int N, const int K,
const T alpha, const T* A, const T* B, const T beta, T* C);
// matrix multiply with continuous memory
template <typename Place, typename T>
void matmul(const framework::Tensor& matrix_a, bool trans_a,
void matmul(const platform::DeviceContext& context,
const framework::Tensor& matrix_a, bool trans_a,
const framework::Tensor& matrix_b, bool trans_b, T alpha,
framework::Tensor* matrix_out, T beta,
platform::DeviceContext* context);
framework::Tensor* matrix_out, T beta);
} // namespace math
} // namespace operators
......
......@@ -15,8 +15,7 @@ TEST(math_function, notrans_mul_trans) {
memcpy(input1_ptr, arr, 6 * sizeof(float));
auto* gpu_place = new paddle::platform::GPUPlace(0);
paddle::platform::DeviceContext* context =
new paddle::platform::CUDADeviceContext(*gpu_place);
paddle::platform::CUDADeviceContext context(*gpu_place);
input1_gpu.CopyFrom<float>(input1, *gpu_place);
input2_gpu.CopyFrom<float>(input1, *gpu_place);
......@@ -24,7 +23,7 @@ TEST(math_function, notrans_mul_trans) {
out_gpu.mutable_data<float>({2, 2}, *gpu_place);
paddle::operators::math::matmul<paddle::platform::GPUPlace, float>(
input1_gpu, false, input2_gpu, true, 1, &out_gpu, 0, context);
context, input1_gpu, false, input2_gpu, true, 1, &out_gpu, 0);
out.CopyFrom<float>(out_gpu, *cpu_place);
......@@ -33,6 +32,7 @@ TEST(math_function, notrans_mul_trans) {
EXPECT_EQ(out_ptr[1], 14);
EXPECT_EQ(out_ptr[2], 14);
EXPECT_EQ(out_ptr[3], 50);
delete gpu_place;
}
TEST(math_function, trans_mul_notrans) {
......@@ -48,8 +48,7 @@ TEST(math_function, trans_mul_notrans) {
memcpy(input1_ptr, arr, 6 * sizeof(float));
auto* gpu_place = new paddle::platform::GPUPlace(0);
paddle::platform::DeviceContext* context =
new paddle::platform::CUDADeviceContext(*gpu_place);
paddle::platform::CUDADeviceContext context(*gpu_place);
input1_gpu.CopyFrom<float>(input1, *gpu_place);
input2_gpu.CopyFrom<float>(input1, *gpu_place);
......@@ -57,7 +56,7 @@ TEST(math_function, trans_mul_notrans) {
out_gpu.mutable_data<float>({3, 3}, *gpu_place);
paddle::operators::math::matmul<paddle::platform::GPUPlace, float>(
input1_gpu, true, input2_gpu, false, 1, &out_gpu, 0, context);
context, input1_gpu, true, input2_gpu, false, 1, &out_gpu, 0);
out.CopyFrom<float>(out_gpu, *cpu_place);
......@@ -71,5 +70,6 @@ TEST(math_function, trans_mul_notrans) {
EXPECT_EQ(out_ptr[6], 15);
EXPECT_EQ(out_ptr[7], 22);
EXPECT_EQ(out_ptr[8], 29);
delete gpu_place;
}
#endif
......@@ -46,10 +46,8 @@ class MulKernel : public framework::OpKernel {
: *y;
z->mutable_data<T>(context.GetPlace());
auto* device_context =
const_cast<platform::DeviceContext*>(context.device_context_);
math::matmul<Place, T>(x_matrix, false, y_matrix, false, 1, z, 0,
device_context);
math::matmul<Place, T>(context.device_context(), x_matrix, false, y_matrix,
false, 1, z, 0);
}
};
......@@ -71,16 +69,14 @@ class MulGradKernel : public framework::OpKernel {
Tensor* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
Tensor* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
auto* device_context =
const_cast<platform::DeviceContext*>(ctx.device_context_);
if (dx) {
dx->mutable_data<T>(ctx.GetPlace());
Tensor dx_matrix = dx->dims().size() > 2 ? framework::ReshapeToMatrix<T>(
*dx, x_num_col_dims)
: *dx;
// dx = dout * y'. dx: M x K, dout : M x N, y : K x N
math::matmul<Place, T>(*dout, false, y_matrix, true, 1, &dx_matrix, 0,
device_context);
math::matmul<Place, T>(ctx.device_context(), *dout, false, y_matrix, true,
1, &dx_matrix, 0);
}
if (dy) {
dy->mutable_data<T>(ctx.GetPlace());
......@@ -88,8 +84,8 @@ class MulGradKernel : public framework::OpKernel {
*dy, y_num_col_dims)
: *dy;
// dy = x' * dout. dy K x N, dout : M x N, x : M x K
math::matmul<Place, T>(x_matrix, true, *dout, false, 1, &dy_matrix, 0,
device_context);
math::matmul<Place, T>(ctx.device_context(), x_matrix, true, *dout, false,
1, &dy_matrix, 0);
}
}
};
......
......@@ -101,19 +101,17 @@ CUDADeviceContext::CUDADeviceContext(GPUPlace place) : place_(place) {
eigen_stream_.reset(new EigenCudaStreamDevice());
eigen_stream_->Reinitialize(&stream_, place);
eigen_device_.reset(new Eigen::GpuDevice(eigen_stream_.get()));
PADDLE_ENFORCE(dynload::cublasCreate(&cublas_handle_));
PADDLE_ENFORCE(dynload::cublasSetStream(cublas_handle_, stream_));
PADDLE_ENFORCE(dynload::cudnnCreate(&cudnn_handle_));
PADDLE_ENFORCE(dynload::cudnnSetStream(cudnn_handle_, stream_));
}
CUDADeviceContext::~CUDADeviceContext() {
SetDeviceId(place_.device);
Wait();
if (cublas_handle_) {
PADDLE_ENFORCE(dynload::cublasDestroy(cublas_handle_));
}
if (cudnn_handle_) {
PADDLE_ENFORCE(dynload::cudnnDestroy(cudnn_handle_));
}
PADDLE_ENFORCE(dynload::cublasDestroy(cublas_handle_));
PADDLE_ENFORCE(dynload::cudnnDestroy(cudnn_handle_));
eigen_stream_.reset();
eigen_device_.reset();
PADDLE_ENFORCE(cudaStreamDestroy(stream_));
......@@ -129,25 +127,13 @@ Eigen::GpuDevice* CUDADeviceContext::eigen_device() const {
return eigen_device_.get();
}
cublasHandle_t CUDADeviceContext::cublas_handle() {
if (!cublas_handle_) {
SetDeviceId(place_.device);
PADDLE_ENFORCE(dynload::cublasCreate(&cublas_handle_));
PADDLE_ENFORCE(dynload::cublasSetStream(cublas_handle_, stream_));
}
cublasHandle_t CUDADeviceContext::cublas_handle() const {
return cublas_handle_;
}
cudnnHandle_t CUDADeviceContext::cudnn_handle() {
if (!cudnn_handle_) {
SetDeviceId(place_.device);
PADDLE_ENFORCE(dynload::cudnnCreate(&cudnn_handle_));
PADDLE_ENFORCE(dynload::cudnnSetStream(cudnn_handle_, stream_));
}
return cudnn_handle_;
}
cudnnHandle_t CUDADeviceContext::cudnn_handle() const { return cudnn_handle_; }
cudaStream_t CUDADeviceContext::stream() { return stream_; }
cudaStream_t CUDADeviceContext::stream() const { return stream_; }
#endif // PADDLE_ONLY_CPU
......
......@@ -67,16 +67,14 @@ class CUDADeviceContext : public DeviceContext {
/*! \brief Return eigen device in the device context. */
Eigen::GpuDevice* eigen_device() const;
// clang-format off
/*! \brief Return cublas handle in the device context. */
cublasHandle_t cublas_handle();
cublasHandle_t cublas_handle() const;
/*! \brief Return cudnn handle in the device context. */
cudnnHandle_t cudnn_handle();
cudnnHandle_t cudnn_handle() const;
/*! \brief Return cuda stream in the device context. */
cudaStream_t stream();
// clang-format on
cudaStream_t stream() const;
private:
GPUPlace place_;
......@@ -84,11 +82,9 @@ class CUDADeviceContext : public DeviceContext {
std::unique_ptr<Eigen::GpuDevice> eigen_device_;
std::unique_ptr<EigenCudaStreamDevice> eigen_stream_;
// clang-format off
cudaStream_t stream_{nullptr};
cudnnHandle_t cudnn_handle_{nullptr};
cublasHandle_t cublas_handle_{nullptr};
// clang-format on
cudaStream_t stream_;
cudnnHandle_t cudnn_handle_;
cublasHandle_t cublas_handle_;
};
#endif
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册