Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleDetection
提交
3b44b849
P
PaddleDetection
项目概览
PaddlePaddle
/
PaddleDetection
大约 1 年 前同步成功
通知
694
Star
11112
Fork
2696
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
184
列表
看板
标记
里程碑
合并请求
40
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
184
Issue
184
列表
看板
标记
里程碑
合并请求
40
合并请求
40
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
3b44b849
编写于
3月 11, 2018
作者:
K
Kexin Zhao
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
address comments
上级
95de7617
变更
4
显示空白变更内容
内联
并排
Showing
4 changed file
with
37 addition
and
20 deletion
+37
-20
paddle/fluid/operators/math/math_function.cu
paddle/fluid/operators/math/math_function.cu
+9
-0
paddle/fluid/operators/math/math_function_test.cu
paddle/fluid/operators/math/math_function_test.cu
+20
-20
paddle/fluid/platform/device_context.cc
paddle/fluid/platform/device_context.cc
+5
-0
paddle/fluid/platform/device_context.h
paddle/fluid/platform/device_context.h
+3
-0
未找到文件。
paddle/fluid/operators/math/math_function.cu
浏览文件 @
3b44b849
...
@@ -45,6 +45,9 @@ void gemm<platform::CUDADeviceContext, float16>(
...
@@ -45,6 +45,9 @@ void gemm<platform::CUDADeviceContext, float16>(
const
half
*
h_B
=
reinterpret_cast
<
const
half
*>
(
B
);
const
half
*
h_B
=
reinterpret_cast
<
const
half
*>
(
B
);
half
*
h_C
=
reinterpret_cast
<
half
*>
(
C
);
half
*
h_C
=
reinterpret_cast
<
half
*>
(
C
);
// TODO(kexinzhao): add processing code for compute capability < 53 case
PADDLE_ENFORCE_GE
(
context
.
GetComputeCapability
(),
53
,
"cublas Hgemm requires GPU compute capability >= 53"
);
PADDLE_ENFORCE
(
platform
::
dynload
::
cublasHgemm
(
PADDLE_ENFORCE
(
platform
::
dynload
::
cublasHgemm
(
context
.
cublas_handle
(),
cuTransB
,
cuTransA
,
N
,
M
,
K
,
&
h_alpha
,
h_B
,
ldb
,
context
.
cublas_handle
(),
cuTransB
,
cuTransA
,
N
,
M
,
K
,
&
h_alpha
,
h_B
,
ldb
,
h_A
,
lda
,
&
h_beta
,
h_C
,
N
));
h_A
,
lda
,
&
h_beta
,
h_C
,
N
));
...
@@ -106,6 +109,9 @@ void gemm<platform::CUDADeviceContext, float16>(
...
@@ -106,6 +109,9 @@ void gemm<platform::CUDADeviceContext, float16>(
const
half
*
h_B
=
reinterpret_cast
<
const
half
*>
(
B
);
const
half
*
h_B
=
reinterpret_cast
<
const
half
*>
(
B
);
half
*
h_C
=
reinterpret_cast
<
half
*>
(
C
);
half
*
h_C
=
reinterpret_cast
<
half
*>
(
C
);
// TODO(kexinzhao): add processing code for compute capability < 53 case
PADDLE_ENFORCE_GE
(
context
.
GetComputeCapability
(),
53
,
"cublas Hgemm requires GPU compute capability >= 53"
);
PADDLE_ENFORCE
(
platform
::
dynload
::
cublasHgemm
(
PADDLE_ENFORCE
(
platform
::
dynload
::
cublasHgemm
(
context
.
cublas_handle
(),
cuTransB
,
cuTransA
,
N
,
M
,
K
,
&
h_alpha
,
h_B
,
ldb
,
context
.
cublas_handle
(),
cuTransB
,
cuTransA
,
N
,
M
,
K
,
&
h_alpha
,
h_B
,
ldb
,
h_A
,
lda
,
&
h_beta
,
h_C
,
ldc
));
h_A
,
lda
,
&
h_beta
,
h_C
,
ldc
));
...
@@ -251,6 +257,9 @@ void batched_gemm<platform::CUDADeviceContext, float16>(
...
@@ -251,6 +257,9 @@ void batched_gemm<platform::CUDADeviceContext, float16>(
const
half
*
h_B
=
reinterpret_cast
<
const
half
*>
(
B
);
const
half
*
h_B
=
reinterpret_cast
<
const
half
*>
(
B
);
half
*
h_C
=
reinterpret_cast
<
half
*>
(
C
);
half
*
h_C
=
reinterpret_cast
<
half
*>
(
C
);
// TODO(kexinzhao): add processing code for compute capability < 53 case
PADDLE_ENFORCE_GE
(
context
.
GetComputeCapability
(),
53
,
"cublas Hgemm requires GPU compute capability >= 53"
);
PADDLE_ENFORCE
(
platform
::
dynload
::
cublasHgemmStridedBatched
(
PADDLE_ENFORCE
(
platform
::
dynload
::
cublasHgemmStridedBatched
(
context
.
cublas_handle
(),
cuTransB
,
cuTransA
,
N
,
M
,
K
,
&
h_alpha
,
h_B
,
ldb
,
context
.
cublas_handle
(),
cuTransB
,
cuTransA
,
N
,
M
,
K
,
&
h_alpha
,
h_B
,
ldb
,
strideB
,
h_A
,
lda
,
strideA
,
&
h_beta
,
h_C
,
ldc
,
strideC
,
batchCount
));
strideB
,
h_A
,
lda
,
strideA
,
&
h_beta
,
h_C
,
ldc
,
strideC
,
batchCount
));
...
...
paddle/fluid/operators/math/math_function_test.cu
浏览文件 @
3b44b849
...
@@ -62,11 +62,6 @@ TEST(math_function, notrans_mul_trans_fp16) {
...
@@ -62,11 +62,6 @@ TEST(math_function, notrans_mul_trans_fp16) {
using
namespace
paddle
::
framework
;
using
namespace
paddle
::
framework
;
using
namespace
paddle
::
platform
;
using
namespace
paddle
::
platform
;
// fp16 GEMM in cublas requires GPU compute capability >= 53
if
(
GetCUDAComputeCapability
(
0
)
<
53
)
{
return
;
}
Tensor
input1
;
Tensor
input1
;
Tensor
input1_gpu
;
Tensor
input1_gpu
;
Tensor
input2_gpu
;
Tensor
input2_gpu
;
...
@@ -77,6 +72,11 @@ TEST(math_function, notrans_mul_trans_fp16) {
...
@@ -77,6 +72,11 @@ TEST(math_function, notrans_mul_trans_fp16) {
CUDAPlace
gpu_place
(
0
);
CUDAPlace
gpu_place
(
0
);
CUDADeviceContext
context
(
gpu_place
);
CUDADeviceContext
context
(
gpu_place
);
// fp16 GEMM in cublas requires GPU compute capability >= 53
if
(
context
.
GetComputeCapability
()
<
53
)
{
return
;
}
float16
*
input1_ptr
=
input1
.
mutable_data
<
float16
>
({
2
,
3
},
cpu_place
);
float16
*
input1_ptr
=
input1
.
mutable_data
<
float16
>
({
2
,
3
},
cpu_place
);
fill_fp16_data
(
input1_ptr
,
input1
.
numel
(),
{
0
,
1
,
2
,
3
,
4
,
5
});
fill_fp16_data
(
input1_ptr
,
input1
.
numel
(),
{
0
,
1
,
2
,
3
,
4
,
5
});
...
@@ -144,11 +144,6 @@ TEST(math_function, trans_mul_notrans_fp16) {
...
@@ -144,11 +144,6 @@ TEST(math_function, trans_mul_notrans_fp16) {
using
namespace
paddle
::
framework
;
using
namespace
paddle
::
framework
;
using
namespace
paddle
::
platform
;
using
namespace
paddle
::
platform
;
// fp16 GEMM in cublas requires GPU compute capability >= 53
if
(
GetCUDAComputeCapability
(
0
)
<
53
)
{
return
;
}
Tensor
input1
;
Tensor
input1
;
Tensor
input1_gpu
;
Tensor
input1_gpu
;
Tensor
input2_gpu
;
Tensor
input2_gpu
;
...
@@ -159,6 +154,11 @@ TEST(math_function, trans_mul_notrans_fp16) {
...
@@ -159,6 +154,11 @@ TEST(math_function, trans_mul_notrans_fp16) {
CUDAPlace
gpu_place
(
0
);
CUDAPlace
gpu_place
(
0
);
CUDADeviceContext
context
(
gpu_place
);
CUDADeviceContext
context
(
gpu_place
);
// fp16 GEMM in cublas requires GPU compute capability >= 53
if
(
context
.
GetComputeCapability
()
<
53
)
{
return
;
}
float16
*
input1_ptr
=
input1
.
mutable_data
<
float16
>
({
2
,
3
},
cpu_place
);
float16
*
input1_ptr
=
input1
.
mutable_data
<
float16
>
({
2
,
3
},
cpu_place
);
fill_fp16_data
(
input1_ptr
,
input1
.
numel
(),
{
0
,
1
,
2
,
3
,
4
,
5
});
fill_fp16_data
(
input1_ptr
,
input1
.
numel
(),
{
0
,
1
,
2
,
3
,
4
,
5
});
...
@@ -247,11 +247,6 @@ TEST(math_function, gemm_notrans_cublas_fp16) {
...
@@ -247,11 +247,6 @@ TEST(math_function, gemm_notrans_cublas_fp16) {
using
namespace
paddle
::
framework
;
using
namespace
paddle
::
framework
;
using
namespace
paddle
::
platform
;
using
namespace
paddle
::
platform
;
// fp16 GEMM in cublas requires GPU compute capability >= 53
if
(
GetCUDAComputeCapability
(
0
)
<
53
)
{
return
;
}
Tensor
input1
;
Tensor
input1
;
Tensor
input2
;
Tensor
input2
;
Tensor
input3
;
Tensor
input3
;
...
@@ -263,6 +258,11 @@ TEST(math_function, gemm_notrans_cublas_fp16) {
...
@@ -263,6 +258,11 @@ TEST(math_function, gemm_notrans_cublas_fp16) {
CUDAPlace
gpu_place
(
0
);
CUDAPlace
gpu_place
(
0
);
CUDADeviceContext
context
(
gpu_place
);
CUDADeviceContext
context
(
gpu_place
);
// fp16 GEMM in cublas requires GPU compute capability >= 53
if
(
context
.
GetComputeCapability
()
<
53
)
{
return
;
}
int
m
=
2
;
int
m
=
2
;
int
n
=
3
;
int
n
=
3
;
int
k
=
3
;
int
k
=
3
;
...
@@ -359,11 +359,6 @@ TEST(math_function, gemm_trans_cublas_fp16) {
...
@@ -359,11 +359,6 @@ TEST(math_function, gemm_trans_cublas_fp16) {
using
namespace
paddle
::
framework
;
using
namespace
paddle
::
framework
;
using
namespace
paddle
::
platform
;
using
namespace
paddle
::
platform
;
// fp16 GEMM in cublas requires GPU compute capability >= 53
if
(
GetCUDAComputeCapability
(
0
)
<
53
)
{
return
;
}
Tensor
input1
;
Tensor
input1
;
Tensor
input2
;
Tensor
input2
;
Tensor
input3
;
Tensor
input3
;
...
@@ -375,6 +370,11 @@ TEST(math_function, gemm_trans_cublas_fp16) {
...
@@ -375,6 +370,11 @@ TEST(math_function, gemm_trans_cublas_fp16) {
CUDAPlace
gpu_place
(
0
);
CUDAPlace
gpu_place
(
0
);
CUDADeviceContext
context
(
gpu_place
);
CUDADeviceContext
context
(
gpu_place
);
// fp16 GEMM in cublas requires GPU compute capability >= 53
if
(
context
.
GetComputeCapability
()
<
53
)
{
return
;
}
int
m
=
2
;
int
m
=
2
;
int
n
=
3
;
int
n
=
3
;
int
k
=
3
;
int
k
=
3
;
...
...
paddle/fluid/platform/device_context.cc
浏览文件 @
3b44b849
...
@@ -127,6 +127,7 @@ class EigenCudaStreamDevice : public Eigen::StreamInterface {
...
@@ -127,6 +127,7 @@ class EigenCudaStreamDevice : public Eigen::StreamInterface {
CUDADeviceContext
::
CUDADeviceContext
(
CUDAPlace
place
)
:
place_
(
place
)
{
CUDADeviceContext
::
CUDADeviceContext
(
CUDAPlace
place
)
:
place_
(
place
)
{
SetDeviceId
(
place_
.
device
);
SetDeviceId
(
place_
.
device
);
compute_capability
=
GetCUDAComputeCapability
(
place_
.
device
);
multi_process
=
GetCUDAMultiProcessors
(
place_
.
device
);
multi_process
=
GetCUDAMultiProcessors
(
place_
.
device
);
max_threads_per_mp
=
GetCUDAMaxThreadsPerMultiProcessor
(
place_
.
device
);
max_threads_per_mp
=
GetCUDAMaxThreadsPerMultiProcessor
(
place_
.
device
);
PADDLE_ENFORCE
(
cudaStreamCreate
(
&
stream_
));
PADDLE_ENFORCE
(
cudaStreamCreate
(
&
stream_
));
...
@@ -162,6 +163,10 @@ void CUDADeviceContext::Wait() const {
...
@@ -162,6 +163,10 @@ void CUDADeviceContext::Wait() const {
PADDLE_ENFORCE
(
cudaGetLastError
());
PADDLE_ENFORCE
(
cudaGetLastError
());
}
}
int
CUDADeviceContext
::
GetComputeCapability
()
const
{
return
compute_capability
;
}
int
CUDADeviceContext
::
GetMaxPhysicalThreadCount
()
const
{
int
CUDADeviceContext
::
GetMaxPhysicalThreadCount
()
const
{
return
multi_process
*
max_threads_per_mp
;
return
multi_process
*
max_threads_per_mp
;
}
}
...
...
paddle/fluid/platform/device_context.h
浏览文件 @
3b44b849
...
@@ -79,6 +79,8 @@ class CUDADeviceContext : public DeviceContext {
...
@@ -79,6 +79,8 @@ class CUDADeviceContext : public DeviceContext {
/*! \brief Return place in the device context. */
/*! \brief Return place in the device context. */
Place
GetPlace
()
const
override
;
Place
GetPlace
()
const
override
;
int
GetComputeCapability
()
const
;
/*! \brief Return the max physical thread count in the device context */
/*! \brief Return the max physical thread count in the device context */
int
GetMaxPhysicalThreadCount
()
const
;
int
GetMaxPhysicalThreadCount
()
const
;
...
@@ -104,6 +106,7 @@ class CUDADeviceContext : public DeviceContext {
...
@@ -104,6 +106,7 @@ class CUDADeviceContext : public DeviceContext {
cudnnHandle_t
cudnn_handle_
;
cudnnHandle_t
cudnn_handle_
;
cublasHandle_t
cublas_handle_
;
cublasHandle_t
cublas_handle_
;
int
compute_capability
;
int
multi_process
;
int
multi_process
;
int
max_threads_per_mp
;
int
max_threads_per_mp
;
};
};
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录