Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
c4e783e5
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
c4e783e5
编写于
9月 18, 2017
作者:
Y
Yu Yang
提交者:
GitHub
9月 18, 2017
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #4174 from reyoung/feature/remove_lazy_init_in_dev_ctx
Remove lazy-initialization in device_context
上级
d4d4580d
847fe473
变更
10
隐藏空白更改
内联
并排
Showing
10 changed file
with
84 addition
and
112 deletion
+84
-112
.gitignore
.gitignore
+1
-0
paddle/framework/operator.cc
paddle/framework/operator.cc
+2
-2
paddle/framework/operator.h
paddle/framework/operator.h
+6
-5
paddle/operators/math/math_function.cc
paddle/operators/math/math_function.cc
+19
-24
paddle/operators/math/math_function.cu
paddle/operators/math/math_function.cu
+23
-26
paddle/operators/math/math_function.h
paddle/operators/math/math_function.h
+6
-6
paddle/operators/math/math_function_test.cc
paddle/operators/math/math_function_test.cc
+6
-6
paddle/operators/mul_op.h
paddle/operators/mul_op.h
+6
-10
paddle/platform/device_context.cc
paddle/platform/device_context.cc
+9
-23
paddle/platform/device_context.h
paddle/platform/device_context.h
+6
-10
未找到文件。
.gitignore
浏览文件 @
c4e783e5
...
...
@@ -27,3 +27,4 @@ CMakeFiles
cmake_install.cmake
paddle/.timestamp
python/paddlepaddle.egg-info/
paddle/pybind/pybind.h
paddle/framework/operator.cc
浏览文件 @
c4e783e5
...
...
@@ -22,14 +22,14 @@ namespace framework {
template
<
>
Eigen
::
DefaultDevice
&
ExecutionContext
::
GetEigenDevice
<
platform
::
CPUPlace
,
Eigen
::
DefaultDevice
>
()
const
{
return
*
device_context_
->
get_eigen_device
<
Eigen
::
DefaultDevice
>
();
return
*
device_context_
.
get_eigen_device
<
Eigen
::
DefaultDevice
>
();
}
#ifndef PADDLE_ONLY_CPU
template
<
>
Eigen
::
GpuDevice
&
ExecutionContext
::
GetEigenDevice
<
platform
::
GPUPlace
,
Eigen
::
GpuDevice
>
()
const
{
return
*
device_context_
->
get_eigen_device
<
Eigen
::
GpuDevice
>
();
return
*
device_context_
.
get_eigen_device
<
Eigen
::
GpuDevice
>
();
}
#endif
...
...
paddle/framework/operator.h
浏览文件 @
c4e783e5
...
...
@@ -366,7 +366,7 @@ struct EigenDeviceConverter<platform::GPUPlace> {
class
ExecutionContext
:
public
InferShapeContext
{
public:
ExecutionContext
(
const
OperatorBase
&
op
,
const
Scope
&
scope
,
const
platform
::
DeviceContext
*
device_context
)
const
platform
::
DeviceContext
&
device_context
)
:
InferShapeContext
(
op
,
scope
),
device_context_
(
device_context
)
{}
template
<
typename
PlaceType
,
...
...
@@ -374,9 +374,9 @@ class ExecutionContext : public InferShapeContext {
typename
EigenDeviceConverter
<
PlaceType
>::
EigenDeviceType
>
DeviceType
&
GetEigenDevice
()
const
;
platform
::
Place
GetPlace
()
const
{
return
device_context_
->
GetPlace
();
}
platform
::
Place
GetPlace
()
const
{
return
device_context_
.
GetPlace
();
}
const
platform
::
DeviceContext
*
device_context
()
const
{
const
platform
::
DeviceContext
&
device_context
()
const
{
return
device_context_
;
}
...
...
@@ -401,7 +401,8 @@ class ExecutionContext : public InferShapeContext {
return
res
;
}
const
platform
::
DeviceContext
*
device_context_
;
private:
const
platform
::
DeviceContext
&
device_context_
;
};
template
<>
...
...
@@ -461,7 +462,7 @@ class OperatorWithKernel : public OperatorBase {
void
Run
(
const
Scope
&
scope
,
const
platform
::
DeviceContext
&
dev_ctx
)
const
final
{
auto
&
opKernel
=
AllOpKernels
().
at
(
type_
).
at
(
OpKernelKey
(
dev_ctx
));
opKernel
->
Compute
(
ExecutionContext
(
*
this
,
scope
,
&
dev_ctx
));
opKernel
->
Compute
(
ExecutionContext
(
*
this
,
scope
,
dev_ctx
));
}
static
std
::
unordered_map
<
std
::
string
/* op_type */
,
OpKernelMap
>&
...
...
paddle/operators/math/math_function.cc
浏览文件 @
c4e783e5
...
...
@@ -19,12 +19,13 @@ namespace operators {
namespace
math
{
template
<
>
void
gemm
<
platform
::
CPUPlace
,
float
>
(
const
CBLAS_TRANSPOSE
transA
,
void
gemm
<
platform
::
CPUPlace
,
float
>
(
const
platform
::
DeviceContext
&
context
,
const
CBLAS_TRANSPOSE
transA
,
const
CBLAS_TRANSPOSE
transB
,
const
int
M
,
const
int
N
,
const
int
K
,
const
float
alpha
,
const
float
*
A
,
const
float
*
B
,
const
float
beta
,
float
*
C
,
platform
::
DeviceContext
*
context
)
{
const
float
*
B
,
const
float
beta
,
float
*
C
)
{
int
lda
=
(
transA
==
CblasNoTrans
)
?
K
:
M
;
int
ldb
=
(
transB
==
CblasNoTrans
)
?
N
:
K
;
int
ldc
=
N
;
...
...
@@ -33,13 +34,13 @@ void gemm<platform::CPUPlace, float>(const CBLAS_TRANSPOSE transA,
}
template
<
>
void
gemm
<
platform
::
CPUPlace
,
double
>
(
const
CBLAS_TRANSPOSE
transA
,
void
gemm
<
platform
::
CPUPlace
,
double
>
(
const
platform
::
DeviceContext
&
context
,
const
CBLAS_TRANSPOSE
transA
,
const
CBLAS_TRANSPOSE
transB
,
const
int
M
,
const
int
N
,
const
int
K
,
const
double
alpha
,
const
double
*
A
,
const
double
*
B
,
const
double
beta
,
double
*
C
,
platform
::
DeviceContext
*
context
)
{
double
*
C
)
{
int
lda
=
(
transA
==
CblasNoTrans
)
?
K
:
M
;
int
ldb
=
(
transB
==
CblasNoTrans
)
?
N
:
K
;
int
ldc
=
N
;
...
...
@@ -48,13 +49,10 @@ void gemm<platform::CPUPlace, double>(const CBLAS_TRANSPOSE transA,
}
template
<
>
void
matmul
<
platform
::
CPUPlace
,
float
>
(
const
framework
::
Tensor
&
matrix_a
,
bool
trans_a
,
const
framework
::
Tensor
&
matrix_b
,
bool
trans_b
,
float
alpha
,
framework
::
Tensor
*
matrix_out
,
float
beta
,
platform
::
DeviceContext
*
context
)
{
void
matmul
<
platform
::
CPUPlace
,
float
>
(
const
platform
::
DeviceContext
&
context
,
const
framework
::
Tensor
&
matrix_a
,
bool
trans_a
,
const
framework
::
Tensor
&
matrix_b
,
bool
trans_b
,
float
alpha
,
framework
::
Tensor
*
matrix_out
,
float
beta
)
{
auto
dim_a
=
matrix_a
.
dims
();
auto
dim_b
=
matrix_b
.
dims
();
auto
dim_out
=
matrix_out
->
dims
();
...
...
@@ -74,18 +72,15 @@ void matmul<platform::CPUPlace, float>(const framework::Tensor& matrix_a,
CBLAS_TRANSPOSE
transB
=
(
trans_b
==
false
)
?
CblasNoTrans
:
CblasTrans
;
gemm
<
platform
::
CPUPlace
,
float
>
(
transA
,
transB
,
M
,
N
,
K
,
alpha
,
matrix_a
.
data
<
float
>
(),
matrix_b
.
data
<
float
>
(),
beta
,
matrix_out
->
data
<
float
>
()
,
context
);
context
,
transA
,
transB
,
M
,
N
,
K
,
alpha
,
matrix_a
.
data
<
float
>
(),
matrix_b
.
data
<
float
>
(),
beta
,
matrix_out
->
data
<
float
>
());
}
template
<
>
void
matmul
<
platform
::
CPUPlace
,
double
>
(
const
framework
::
Tensor
&
matrix_a
,
bool
trans_a
,
const
framework
::
Tensor
&
matrix_b
,
bool
trans_b
,
double
alpha
,
framework
::
Tensor
*
matrix_out
,
double
beta
,
platform
::
DeviceContext
*
context
)
{
void
matmul
<
platform
::
CPUPlace
,
double
>
(
const
platform
::
DeviceContext
&
context
,
const
framework
::
Tensor
&
matrix_a
,
bool
trans_a
,
const
framework
::
Tensor
&
matrix_b
,
bool
trans_b
,
double
alpha
,
framework
::
Tensor
*
matrix_out
,
double
beta
)
{
auto
dim_a
=
matrix_a
.
dims
();
auto
dim_b
=
matrix_b
.
dims
();
auto
dim_out
=
matrix_out
->
dims
();
...
...
@@ -105,8 +100,8 @@ void matmul<platform::CPUPlace, double>(const framework::Tensor& matrix_a,
CBLAS_TRANSPOSE
transB
=
(
trans_b
==
false
)
?
CblasNoTrans
:
CblasTrans
;
gemm
<
platform
::
CPUPlace
,
double
>
(
transA
,
transB
,
M
,
N
,
K
,
alpha
,
matrix_a
.
data
<
double
>
(),
matrix_b
.
data
<
double
>
(),
beta
,
matrix_out
->
data
<
double
>
()
,
context
);
context
,
transA
,
transB
,
M
,
N
,
K
,
alpha
,
matrix_a
.
data
<
double
>
(),
matrix_b
.
data
<
double
>
(),
beta
,
matrix_out
->
data
<
double
>
());
}
}
// namespace math
...
...
paddle/operators/math/math_function.cu
浏览文件 @
c4e783e5
...
...
@@ -19,12 +19,13 @@ namespace operators {
namespace
math
{
template
<
>
void
gemm
<
platform
::
GPUPlace
,
float
>
(
const
CBLAS_TRANSPOSE
transA
,
void
gemm
<
platform
::
GPUPlace
,
float
>
(
const
platform
::
DeviceContext
&
context
,
const
CBLAS_TRANSPOSE
transA
,
const
CBLAS_TRANSPOSE
transB
,
const
int
M
,
const
int
N
,
const
int
K
,
const
float
alpha
,
const
float
*
A
,
const
float
*
B
,
const
float
beta
,
float
*
C
,
platform
::
DeviceContext
*
context
)
{
const
float
*
B
,
const
float
beta
,
float
*
C
)
{
// Note that cublas follows fortran order, so the order is different from
// the cblas convention.
int
lda
=
(
transA
==
CblasNoTrans
)
?
K
:
M
;
...
...
@@ -35,18 +36,19 @@ void gemm<platform::GPUPlace, float>(const CBLAS_TRANSPOSE transA,
(
transB
==
CblasNoTrans
)
?
CUBLAS_OP_N
:
CUBLAS_OP_T
;
PADDLE_ENFORCE
(
platform
::
dynload
::
cublasSgemm
(
reinterpret_cast
<
platform
::
CUDADeviceContext
*>
(
context
)
->
cublas_handle
(),
reinterpret_cast
<
const
platform
::
CUDADeviceContext
&>
(
context
)
.
cublas_handle
(),
cuTransB
,
cuTransA
,
N
,
M
,
K
,
&
alpha
,
B
,
ldb
,
A
,
lda
,
&
beta
,
C
,
N
));
}
template
<
>
void
gemm
<
platform
::
GPUPlace
,
double
>
(
const
CBLAS_TRANSPOSE
transA
,
void
gemm
<
platform
::
GPUPlace
,
double
>
(
const
platform
::
DeviceContext
&
context
,
const
CBLAS_TRANSPOSE
transA
,
const
CBLAS_TRANSPOSE
transB
,
const
int
M
,
const
int
N
,
const
int
K
,
const
double
alpha
,
const
double
*
A
,
const
double
*
B
,
const
double
beta
,
double
*
C
,
platform
::
DeviceContext
*
context
)
{
double
*
C
)
{
// Note that cublas follows fortran order, so the order is different from
// the cblas convention.
int
lda
=
(
transA
==
CblasNoTrans
)
?
K
:
M
;
...
...
@@ -56,18 +58,16 @@ void gemm<platform::GPUPlace, double>(const CBLAS_TRANSPOSE transA,
cublasOperation_t
cuTransB
=
(
transB
==
CblasNoTrans
)
?
CUBLAS_OP_N
:
CUBLAS_OP_T
;
PADDLE_ENFORCE
(
platform
::
dynload
::
cublasDgemm
(
reinterpret_cast
<
platform
::
CUDADeviceContext
*>
(
context
)
->
cublas_handle
(),
reinterpret_cast
<
const
platform
::
CUDADeviceContext
&>
(
context
)
.
cublas_handle
(),
cuTransB
,
cuTransA
,
N
,
M
,
K
,
&
alpha
,
B
,
ldb
,
A
,
lda
,
&
beta
,
C
,
N
));
}
template
<
>
void
matmul
<
platform
::
GPUPlace
,
float
>
(
const
framework
::
Tensor
&
matrix_a
,
bool
trans_a
,
const
framework
::
Tensor
&
matrix_b
,
bool
trans_b
,
float
alpha
,
framework
::
Tensor
*
matrix_out
,
float
beta
,
platform
::
DeviceContext
*
context
)
{
void
matmul
<
platform
::
GPUPlace
,
float
>
(
const
platform
::
DeviceContext
&
context
,
const
framework
::
Tensor
&
matrix_a
,
bool
trans_a
,
const
framework
::
Tensor
&
matrix_b
,
bool
trans_b
,
float
alpha
,
framework
::
Tensor
*
matrix_out
,
float
beta
)
{
auto
dim_a
=
matrix_a
.
dims
();
auto
dim_b
=
matrix_b
.
dims
();
auto
dim_out
=
matrix_out
->
dims
();
...
...
@@ -87,18 +87,15 @@ void matmul<platform::GPUPlace, float>(const framework::Tensor& matrix_a,
CBLAS_TRANSPOSE
transB
=
(
trans_b
==
false
)
?
CblasNoTrans
:
CblasTrans
;
gemm
<
platform
::
GPUPlace
,
float
>
(
transA
,
transB
,
M
,
N
,
K
,
alpha
,
matrix_a
.
data
<
float
>
(),
matrix_b
.
data
<
float
>
(),
beta
,
matrix_out
->
data
<
float
>
()
,
context
);
context
,
transA
,
transB
,
M
,
N
,
K
,
alpha
,
matrix_a
.
data
<
float
>
(),
matrix_b
.
data
<
float
>
(),
beta
,
matrix_out
->
data
<
float
>
());
}
template
<
>
void
matmul
<
platform
::
GPUPlace
,
double
>
(
const
framework
::
Tensor
&
matrix_a
,
bool
trans_a
,
const
framework
::
Tensor
&
matrix_b
,
bool
trans_b
,
double
alpha
,
framework
::
Tensor
*
matrix_out
,
double
beta
,
platform
::
DeviceContext
*
context
)
{
void
matmul
<
platform
::
GPUPlace
,
double
>
(
const
platform
::
DeviceContext
&
context
,
const
framework
::
Tensor
&
matrix_a
,
bool
trans_a
,
const
framework
::
Tensor
&
matrix_b
,
bool
trans_b
,
double
alpha
,
framework
::
Tensor
*
matrix_out
,
double
beta
)
{
auto
dim_a
=
matrix_a
.
dims
();
auto
dim_b
=
matrix_b
.
dims
();
auto
dim_out
=
matrix_out
->
dims
();
...
...
@@ -118,8 +115,8 @@ void matmul<platform::GPUPlace, double>(const framework::Tensor& matrix_a,
CBLAS_TRANSPOSE
transB
=
(
trans_b
==
false
)
?
CblasNoTrans
:
CblasTrans
;
gemm
<
platform
::
GPUPlace
,
double
>
(
transA
,
transB
,
M
,
N
,
K
,
alpha
,
matrix_a
.
data
<
double
>
(),
matrix_b
.
data
<
double
>
(),
beta
,
matrix_out
->
data
<
double
>
()
,
context
);
context
,
transA
,
transB
,
M
,
N
,
K
,
alpha
,
matrix_a
.
data
<
double
>
(),
matrix_b
.
data
<
double
>
(),
beta
,
matrix_out
->
data
<
double
>
());
}
}
// namespace math
...
...
paddle/operators/math/math_function.h
浏览文件 @
c4e783e5
...
...
@@ -66,16 +66,16 @@ namespace math {
// For more detailed info, please refer to
// http://www.netlib.org/lapack/explore-html/d4/de2/sgemm_8f.html
template
<
typename
Place
,
typename
T
>
void
gemm
(
const
CBLAS_TRANSPOSE
transA
,
const
CBLAS_TRANSPOSE
transB
,
const
int
M
,
const
int
N
,
const
int
K
,
const
T
alpha
,
const
T
*
A
,
const
T
*
B
,
const
T
beta
,
T
*
C
,
platform
::
DeviceContext
*
context
);
void
gemm
(
const
platform
::
DeviceContext
&
context
,
const
CBLAS_TRANSPOSE
transA
,
const
CBLAS_TRANSPOSE
transB
,
const
int
M
,
const
int
N
,
const
int
K
,
const
T
alpha
,
const
T
*
A
,
const
T
*
B
,
const
T
beta
,
T
*
C
);
// matrix multiply with continuous memory
template
<
typename
Place
,
typename
T
>
void
matmul
(
const
framework
::
Tensor
&
matrix_a
,
bool
trans_a
,
void
matmul
(
const
platform
::
DeviceContext
&
context
,
const
framework
::
Tensor
&
matrix_a
,
bool
trans_a
,
const
framework
::
Tensor
&
matrix_b
,
bool
trans_b
,
T
alpha
,
framework
::
Tensor
*
matrix_out
,
T
beta
,
platform
::
DeviceContext
*
context
);
framework
::
Tensor
*
matrix_out
,
T
beta
);
}
// namespace math
}
// namespace operators
...
...
paddle/operators/math/math_function_test.cc
浏览文件 @
c4e783e5
...
...
@@ -15,8 +15,7 @@ TEST(math_function, notrans_mul_trans) {
memcpy
(
input1_ptr
,
arr
,
6
*
sizeof
(
float
));
auto
*
gpu_place
=
new
paddle
::
platform
::
GPUPlace
(
0
);
paddle
::
platform
::
DeviceContext
*
context
=
new
paddle
::
platform
::
CUDADeviceContext
(
*
gpu_place
);
paddle
::
platform
::
CUDADeviceContext
context
(
*
gpu_place
);
input1_gpu
.
CopyFrom
<
float
>
(
input1
,
*
gpu_place
);
input2_gpu
.
CopyFrom
<
float
>
(
input1
,
*
gpu_place
);
...
...
@@ -24,7 +23,7 @@ TEST(math_function, notrans_mul_trans) {
out_gpu
.
mutable_data
<
float
>
({
2
,
2
},
*
gpu_place
);
paddle
::
operators
::
math
::
matmul
<
paddle
::
platform
::
GPUPlace
,
float
>
(
input1_gpu
,
false
,
input2_gpu
,
true
,
1
,
&
out_gpu
,
0
,
context
);
context
,
input1_gpu
,
false
,
input2_gpu
,
true
,
1
,
&
out_gpu
,
0
);
out
.
CopyFrom
<
float
>
(
out_gpu
,
*
cpu_place
);
...
...
@@ -33,6 +32,7 @@ TEST(math_function, notrans_mul_trans) {
EXPECT_EQ
(
out_ptr
[
1
],
14
);
EXPECT_EQ
(
out_ptr
[
2
],
14
);
EXPECT_EQ
(
out_ptr
[
3
],
50
);
delete
gpu_place
;
}
TEST
(
math_function
,
trans_mul_notrans
)
{
...
...
@@ -48,8 +48,7 @@ TEST(math_function, trans_mul_notrans) {
memcpy
(
input1_ptr
,
arr
,
6
*
sizeof
(
float
));
auto
*
gpu_place
=
new
paddle
::
platform
::
GPUPlace
(
0
);
paddle
::
platform
::
DeviceContext
*
context
=
new
paddle
::
platform
::
CUDADeviceContext
(
*
gpu_place
);
paddle
::
platform
::
CUDADeviceContext
context
(
*
gpu_place
);
input1_gpu
.
CopyFrom
<
float
>
(
input1
,
*
gpu_place
);
input2_gpu
.
CopyFrom
<
float
>
(
input1
,
*
gpu_place
);
...
...
@@ -57,7 +56,7 @@ TEST(math_function, trans_mul_notrans) {
out_gpu
.
mutable_data
<
float
>
({
3
,
3
},
*
gpu_place
);
paddle
::
operators
::
math
::
matmul
<
paddle
::
platform
::
GPUPlace
,
float
>
(
input1_gpu
,
true
,
input2_gpu
,
false
,
1
,
&
out_gpu
,
0
,
context
);
context
,
input1_gpu
,
true
,
input2_gpu
,
false
,
1
,
&
out_gpu
,
0
);
out
.
CopyFrom
<
float
>
(
out_gpu
,
*
cpu_place
);
...
...
@@ -71,5 +70,6 @@ TEST(math_function, trans_mul_notrans) {
EXPECT_EQ
(
out_ptr
[
6
],
15
);
EXPECT_EQ
(
out_ptr
[
7
],
22
);
EXPECT_EQ
(
out_ptr
[
8
],
29
);
delete
gpu_place
;
}
#endif
paddle/operators/mul_op.h
浏览文件 @
c4e783e5
...
...
@@ -46,10 +46,8 @@ class MulKernel : public framework::OpKernel {
:
*
y
;
z
->
mutable_data
<
T
>
(
context
.
GetPlace
());
auto
*
device_context
=
const_cast
<
platform
::
DeviceContext
*>
(
context
.
device_context_
);
math
::
matmul
<
Place
,
T
>
(
x_matrix
,
false
,
y_matrix
,
false
,
1
,
z
,
0
,
device_context
);
math
::
matmul
<
Place
,
T
>
(
context
.
device_context
(),
x_matrix
,
false
,
y_matrix
,
false
,
1
,
z
,
0
);
}
};
...
...
@@ -71,16 +69,14 @@ class MulGradKernel : public framework::OpKernel {
Tensor
*
dx
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"X"
));
Tensor
*
dy
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"Y"
));
auto
*
device_context
=
const_cast
<
platform
::
DeviceContext
*>
(
ctx
.
device_context_
);
if
(
dx
)
{
dx
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
Tensor
dx_matrix
=
dx
->
dims
().
size
()
>
2
?
framework
::
ReshapeToMatrix
<
T
>
(
*
dx
,
x_num_col_dims
)
:
*
dx
;
// dx = dout * y'. dx: M x K, dout : M x N, y : K x N
math
::
matmul
<
Place
,
T
>
(
*
dout
,
false
,
y_matrix
,
true
,
1
,
&
dx_matrix
,
0
,
device_context
);
math
::
matmul
<
Place
,
T
>
(
ctx
.
device_context
(),
*
dout
,
false
,
y_matrix
,
true
,
1
,
&
dx_matrix
,
0
);
}
if
(
dy
)
{
dy
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
...
...
@@ -88,8 +84,8 @@ class MulGradKernel : public framework::OpKernel {
*
dy
,
y_num_col_dims
)
:
*
dy
;
// dy = x' * dout. dy K x N, dout : M x N, x : M x K
math
::
matmul
<
Place
,
T
>
(
x_matrix
,
true
,
*
dout
,
false
,
1
,
&
dy_matrix
,
0
,
device_context
);
math
::
matmul
<
Place
,
T
>
(
ctx
.
device_context
(),
x_matrix
,
true
,
*
dout
,
false
,
1
,
&
dy_matrix
,
0
);
}
}
};
...
...
paddle/platform/device_context.cc
浏览文件 @
c4e783e5
...
...
@@ -101,19 +101,17 @@ CUDADeviceContext::CUDADeviceContext(GPUPlace place) : place_(place) {
eigen_stream_
.
reset
(
new
EigenCudaStreamDevice
());
eigen_stream_
->
Reinitialize
(
&
stream_
,
place
);
eigen_device_
.
reset
(
new
Eigen
::
GpuDevice
(
eigen_stream_
.
get
()));
PADDLE_ENFORCE
(
dynload
::
cublasCreate
(
&
cublas_handle_
));
PADDLE_ENFORCE
(
dynload
::
cublasSetStream
(
cublas_handle_
,
stream_
));
PADDLE_ENFORCE
(
dynload
::
cudnnCreate
(
&
cudnn_handle_
));
PADDLE_ENFORCE
(
dynload
::
cudnnSetStream
(
cudnn_handle_
,
stream_
));
}
CUDADeviceContext
::~
CUDADeviceContext
()
{
SetDeviceId
(
place_
.
device
);
Wait
();
if
(
cublas_handle_
)
{
PADDLE_ENFORCE
(
dynload
::
cublasDestroy
(
cublas_handle_
));
}
if
(
cudnn_handle_
)
{
PADDLE_ENFORCE
(
dynload
::
cudnnDestroy
(
cudnn_handle_
));
}
PADDLE_ENFORCE
(
dynload
::
cublasDestroy
(
cublas_handle_
));
PADDLE_ENFORCE
(
dynload
::
cudnnDestroy
(
cudnn_handle_
));
eigen_stream_
.
reset
();
eigen_device_
.
reset
();
PADDLE_ENFORCE
(
cudaStreamDestroy
(
stream_
));
...
...
@@ -129,25 +127,13 @@ Eigen::GpuDevice* CUDADeviceContext::eigen_device() const {
return
eigen_device_
.
get
();
}
cublasHandle_t
CUDADeviceContext
::
cublas_handle
()
{
if
(
!
cublas_handle_
)
{
SetDeviceId
(
place_
.
device
);
PADDLE_ENFORCE
(
dynload
::
cublasCreate
(
&
cublas_handle_
));
PADDLE_ENFORCE
(
dynload
::
cublasSetStream
(
cublas_handle_
,
stream_
));
}
cublasHandle_t
CUDADeviceContext
::
cublas_handle
()
const
{
return
cublas_handle_
;
}
cudnnHandle_t
CUDADeviceContext
::
cudnn_handle
()
{
if
(
!
cudnn_handle_
)
{
SetDeviceId
(
place_
.
device
);
PADDLE_ENFORCE
(
dynload
::
cudnnCreate
(
&
cudnn_handle_
));
PADDLE_ENFORCE
(
dynload
::
cudnnSetStream
(
cudnn_handle_
,
stream_
));
}
return
cudnn_handle_
;
}
cudnnHandle_t
CUDADeviceContext
::
cudnn_handle
()
const
{
return
cudnn_handle_
;
}
cudaStream_t
CUDADeviceContext
::
stream
()
{
return
stream_
;
}
cudaStream_t
CUDADeviceContext
::
stream
()
const
{
return
stream_
;
}
#endif // PADDLE_ONLY_CPU
...
...
paddle/platform/device_context.h
浏览文件 @
c4e783e5
...
...
@@ -67,16 +67,14 @@ class CUDADeviceContext : public DeviceContext {
/*! \brief Return eigen device in the device context. */
Eigen
::
GpuDevice
*
eigen_device
()
const
;
// clang-format off
/*! \brief Return cublas handle in the device context. */
cublasHandle_t
cublas_handle
()
;
cublasHandle_t
cublas_handle
()
const
;
/*! \brief Return cudnn handle in the device context. */
cudnnHandle_t
cudnn_handle
()
;
cudnnHandle_t
cudnn_handle
()
const
;
/*! \brief Return cuda stream in the device context. */
cudaStream_t
stream
();
// clang-format on
cudaStream_t
stream
()
const
;
private:
GPUPlace
place_
;
...
...
@@ -84,11 +82,9 @@ class CUDADeviceContext : public DeviceContext {
std
::
unique_ptr
<
Eigen
::
GpuDevice
>
eigen_device_
;
std
::
unique_ptr
<
EigenCudaStreamDevice
>
eigen_stream_
;
// clang-format off
cudaStream_t
stream_
{
nullptr
};
cudnnHandle_t
cudnn_handle_
{
nullptr
};
cublasHandle_t
cublas_handle_
{
nullptr
};
// clang-format on
cudaStream_t
stream_
;
cudnnHandle_t
cudnn_handle_
;
cublasHandle_t
cublas_handle_
;
};
#endif
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录