Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
3a2afbf0
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
3a2afbf0
编写于
12月 25, 2018
作者:
S
sneaxiy
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
polish code
test=develop
上级
7f6e513b
变更
13
隐藏空白更改
内联
并排
Showing
13 changed file
with
56 addition
and
36 deletion
+56
-36
paddle/fluid/framework/operator.h
paddle/fluid/framework/operator.h
+0
-12
paddle/fluid/framework/var_type.h
paddle/fluid/framework/var_type.h
+5
-5
paddle/fluid/framework/var_type_inference_test.cc
paddle/fluid/framework/var_type_inference_test.cc
+1
-1
paddle/fluid/framework/var_type_traits.h
paddle/fluid/framework/var_type_traits.h
+3
-3
paddle/fluid/framework/var_type_traits_test.cc
paddle/fluid/framework/var_type_traits_test.cc
+17
-0
paddle/fluid/framework/variable.h
paddle/fluid/framework/variable.h
+6
-4
paddle/fluid/operators/affine_grid_op.cc
paddle/fluid/operators/affine_grid_op.cc
+2
-2
paddle/fluid/operators/conv_op.cc
paddle/fluid/operators/conv_op.cc
+2
-2
paddle/fluid/operators/grid_sampler_op.cc
paddle/fluid/operators/grid_sampler_op.cc
+2
-2
paddle/fluid/operators/pool_op.cc
paddle/fluid/operators/pool_op.cc
+2
-2
paddle/fluid/operators/softmax_op.cc
paddle/fluid/operators/softmax_op.cc
+2
-2
paddle/fluid/operators/warpctc_op.cc
paddle/fluid/operators/warpctc_op.cc
+1
-1
paddle/fluid/platform/cudnn_helper.h
paddle/fluid/platform/cudnn_helper.h
+13
-0
未找到文件。
paddle/fluid/framework/operator.h
浏览文件 @
3a2afbf0
...
@@ -310,18 +310,6 @@ class ExecutionContext {
...
@@ -310,18 +310,6 @@ class ExecutionContext {
const
RuntimeContext
&
ctx_
;
const
RuntimeContext
&
ctx_
;
};
};
inline
bool
CanCUDNNBeUsed
(
const
framework
::
ExecutionContext
&
ctx
)
{
bool
use_cudnn
=
ctx
.
Attr
<
bool
>
(
"use_cudnn"
);
use_cudnn
&=
paddle
::
platform
::
is_gpu_place
(
ctx
.
GetPlace
());
#ifdef PADDLE_WITH_CUDA
if
(
use_cudnn
)
{
auto
&
dev_ctx
=
ctx
.
device_context
<
platform
::
CUDADeviceContext
>
();
use_cudnn
&=
dev_ctx
.
cudnn_handle
()
!=
nullptr
;
}
#endif
return
use_cudnn
;
}
template
<>
template
<>
const
Tensor
*
ExecutionContext
::
Input
<
Tensor
>
(
const
std
::
string
&
name
)
const
;
const
Tensor
*
ExecutionContext
::
Input
<
Tensor
>
(
const
std
::
string
&
name
)
const
;
...
...
paddle/fluid/framework/var_type.h
浏览文件 @
3a2afbf0
...
@@ -46,19 +46,19 @@ inline proto::VarType::Type ToVarType(int type) {
...
@@ -46,19 +46,19 @@ inline proto::VarType::Type ToVarType(int type) {
template
<
typename
Visitor
>
template
<
typename
Visitor
>
inline
void
VisitVarType
(
const
framework
::
Variable
&
var
,
Visitor
visitor
)
{
inline
void
VisitVarType
(
const
framework
::
Variable
&
var
,
Visitor
visitor
)
{
switch
(
var
.
Type
())
{
switch
(
var
.
Type
())
{
case
proto
::
VarType
_Type_
LOD_TENSOR
:
case
proto
::
VarType
::
LOD_TENSOR
:
visitor
(
var
.
Get
<
LoDTensor
>
());
visitor
(
var
.
Get
<
LoDTensor
>
());
return
;
return
;
case
proto
::
VarType
_Type_
LOD_RANK_TABLE
:
case
proto
::
VarType
::
LOD_RANK_TABLE
:
visitor
(
var
.
Get
<
LoDRankTable
>
());
visitor
(
var
.
Get
<
LoDRankTable
>
());
return
;
return
;
case
proto
::
VarType
_Type_
LOD_TENSOR_ARRAY
:
case
proto
::
VarType
::
LOD_TENSOR_ARRAY
:
visitor
(
var
.
Get
<
LoDTensorArray
>
());
visitor
(
var
.
Get
<
LoDTensorArray
>
());
return
;
return
;
case
proto
::
VarType
_Type_
SELECTED_ROWS
:
case
proto
::
VarType
::
SELECTED_ROWS
:
visitor
(
var
.
Get
<
SelectedRows
>
());
visitor
(
var
.
Get
<
SelectedRows
>
());
return
;
return
;
case
proto
::
VarType
_Type_
READER
:
case
proto
::
VarType
::
READER
:
visitor
(
var
.
Get
<
ReaderHolder
>
());
visitor
(
var
.
Get
<
ReaderHolder
>
());
return
;
return
;
default:
default:
...
...
paddle/fluid/framework/var_type_inference_test.cc
浏览文件 @
3a2afbf0
...
@@ -108,7 +108,7 @@ TEST(InferVarType, sum_op_without_infer_var_type) {
...
@@ -108,7 +108,7 @@ TEST(InferVarType, sum_op_without_infer_var_type) {
op
->
InferVarType
(
prog
.
MutableBlock
(
0
));
op
->
InferVarType
(
prog
.
MutableBlock
(
0
));
ASSERT_EQ
(
proto
::
VarType
_Type_
LOD_TENSOR
,
ASSERT_EQ
(
proto
::
VarType
::
LOD_TENSOR
,
prog
.
MutableBlock
(
0
)
->
Var
(
"test2_out"
)
->
GetType
());
prog
.
MutableBlock
(
0
)
->
Var
(
"test2_out"
)
->
GetType
());
}
}
...
...
paddle/fluid/framework/var_type_traits.h
浏览文件 @
3a2afbf0
...
@@ -17,7 +17,7 @@
...
@@ -17,7 +17,7 @@
#include <map>
#include <map>
#include <string>
#include <string>
#include <tuple>
#include <tuple>
#include <typein
fo
>
#include <typein
dex
>
#include <vector>
#include <vector>
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/framework/lod_tensor_array.h"
...
@@ -136,8 +136,6 @@ struct VarTypeRegistryImpl {
...
@@ -136,8 +136,6 @@ struct VarTypeRegistryImpl {
// Users should add other variable types below.
// Users should add other variable types below.
// Paddle would generate unique Ids for each registered variable types.
// Paddle would generate unique Ids for each registered variable types.
class
Scope
;
using
VarTypeRegistry
=
detail
::
VarTypeRegistryImpl
<
using
VarTypeRegistry
=
detail
::
VarTypeRegistryImpl
<
Tensor
,
LoDTensor
,
SelectedRows
,
std
::
vector
<
Scope
*>
,
LoDRankTable
,
Tensor
,
LoDTensor
,
SelectedRows
,
std
::
vector
<
Scope
*>
,
LoDRankTable
,
LoDTensorArray
,
platform
::
PlaceList
,
ReaderHolder
,
std
::
string
,
Scope
*
,
LoDTensorArray
,
platform
::
PlaceList
,
ReaderHolder
,
std
::
string
,
Scope
*
,
...
@@ -171,6 +169,8 @@ REG_PROTO_VAR_TYPE_TRAIT(LoDRankTable, proto::VarType::LOD_RANK_TABLE);
...
@@ -171,6 +169,8 @@ REG_PROTO_VAR_TYPE_TRAIT(LoDRankTable, proto::VarType::LOD_RANK_TABLE);
REG_PROTO_VAR_TYPE_TRAIT
(
LoDTensorArray
,
proto
::
VarType
::
LOD_TENSOR_ARRAY
);
REG_PROTO_VAR_TYPE_TRAIT
(
LoDTensorArray
,
proto
::
VarType
::
LOD_TENSOR_ARRAY
);
REG_PROTO_VAR_TYPE_TRAIT
(
platform
::
PlaceList
,
proto
::
VarType
::
PLACE_LIST
);
REG_PROTO_VAR_TYPE_TRAIT
(
platform
::
PlaceList
,
proto
::
VarType
::
PLACE_LIST
);
REG_PROTO_VAR_TYPE_TRAIT
(
ReaderHolder
,
proto
::
VarType
::
READER
);
REG_PROTO_VAR_TYPE_TRAIT
(
ReaderHolder
,
proto
::
VarType
::
READER
);
REG_PROTO_VAR_TYPE_TRAIT
(
int
,
proto
::
VarType
::
INT32
);
REG_PROTO_VAR_TYPE_TRAIT
(
float
,
proto
::
VarType
::
FP32
);
/** End of variable type registration */
/** End of variable type registration */
...
...
paddle/fluid/framework/var_type_traits_test.cc
浏览文件 @
3a2afbf0
...
@@ -88,6 +88,23 @@ TEST(var_type_traits, check_proto_type_id) {
...
@@ -88,6 +88,23 @@ TEST(var_type_traits, check_proto_type_id) {
ASSERT_TRUE
(
CheckVarId
<
LoDTensorArray
>
(
proto
::
VarType
::
LOD_TENSOR_ARRAY
));
ASSERT_TRUE
(
CheckVarId
<
LoDTensorArray
>
(
proto
::
VarType
::
LOD_TENSOR_ARRAY
));
ASSERT_TRUE
(
CheckVarId
<
platform
::
PlaceList
>
(
proto
::
VarType
::
PLACE_LIST
));
ASSERT_TRUE
(
CheckVarId
<
platform
::
PlaceList
>
(
proto
::
VarType
::
PLACE_LIST
));
ASSERT_TRUE
(
CheckVarId
<
ReaderHolder
>
(
proto
::
VarType
::
READER
));
ASSERT_TRUE
(
CheckVarId
<
ReaderHolder
>
(
proto
::
VarType
::
READER
));
ASSERT_TRUE
(
CheckVarId
<
int
>
(
proto
::
VarType
::
INT32
));
ASSERT_TRUE
(
CheckVarId
<
float
>
(
proto
::
VarType
::
FP32
));
ASSERT_EQ
(
proto
::
VarType_Type_LOD_TENSOR
,
proto
::
VarType
::
LOD_TENSOR
);
ASSERT_EQ
(
proto
::
VarType_Type_SELECTED_ROWS
,
proto
::
VarType
::
SELECTED_ROWS
);
ASSERT_EQ
(
proto
::
VarType_Type_STEP_SCOPES
,
proto
::
VarType
::
STEP_SCOPES
);
ASSERT_EQ
(
proto
::
VarType_Type_LOD_RANK_TABLE
,
proto
::
VarType
::
LOD_RANK_TABLE
);
ASSERT_EQ
(
proto
::
VarType_Type_LOD_TENSOR_ARRAY
,
proto
::
VarType
::
LOD_TENSOR_ARRAY
);
ASSERT_EQ
(
proto
::
VarType_Type_PLACE_LIST
,
proto
::
VarType
::
PLACE_LIST
);
ASSERT_EQ
(
proto
::
VarType_Type_READER
,
proto
::
VarType
::
READER
);
ASSERT_EQ
(
proto
::
VarType_Type_FEED_MINIBATCH
,
proto
::
VarType
::
FEED_MINIBATCH
);
ASSERT_EQ
(
proto
::
VarType_Type_FETCH_LIST
,
proto
::
VarType
::
FETCH_LIST
);
ASSERT_EQ
(
proto
::
VarType_Type_RAW
,
proto
::
VarType
::
RAW
);
ASSERT_EQ
(
proto
::
VarType_Type_TUPLE
,
proto
::
VarType
::
TUPLE
);
ASSERT_EQ
(
proto
::
VarType_Type_INT32
,
proto
::
VarType
::
INT32
);
ASSERT_EQ
(
proto
::
VarType_Type_FP32
,
proto
::
VarType
::
FP32
);
}
}
TEST
(
var_type_traits
,
test_registry
)
{
TEST
(
var_type_traits
,
test_registry
)
{
...
...
paddle/fluid/framework/variable.h
浏览文件 @
3a2afbf0
...
@@ -67,7 +67,6 @@ class Variable {
...
@@ -67,7 +67,6 @@ class Variable {
private:
private:
struct
Placeholder
{
struct
Placeholder
{
explicit
Placeholder
(
int
type
)
:
type_
(
type
)
{}
virtual
~
Placeholder
()
=
default
;
virtual
~
Placeholder
()
=
default
;
inline
int
Type
()
const
{
return
type_
;
}
inline
int
Type
()
const
{
return
type_
;
}
...
@@ -75,6 +74,11 @@ class Variable {
...
@@ -75,6 +74,11 @@ class Variable {
inline
void
*
Ptr
()
{
return
ptr_
;
}
inline
void
*
Ptr
()
{
return
ptr_
;
}
protected:
protected:
inline
void
Init
(
void
*
p
,
int
type
)
{
ptr_
=
p
;
type_
=
type
;
}
void
*
ptr_
;
void
*
ptr_
;
int
type_
;
int
type_
;
};
};
...
@@ -86,9 +90,7 @@ class Variable {
...
@@ -86,9 +90,7 @@ class Variable {
static_assert
(
static_assert
(
IsRegisteredVarType
<
T
>
(),
IsRegisteredVarType
<
T
>
(),
"Not registered type. Please register T inside var_type_traits.h"
);
"Not registered type. Please register T inside var_type_traits.h"
);
PlaceholderImpl
()
:
Placeholder
(
VarTypeTrait
<
T
>::
kId
)
{
PlaceholderImpl
()
{
this
->
Init
(
&
obj_
,
VarTypeTrait
<
T
>::
kId
);
}
this
->
ptr_
=
&
obj_
;
}
private:
private:
T
obj_
;
T
obj_
;
...
...
paddle/fluid/operators/affine_grid_op.cc
浏览文件 @
3a2afbf0
...
@@ -74,7 +74,7 @@ class AffineGridOp : public framework::OperatorWithKernel {
...
@@ -74,7 +74,7 @@ class AffineGridOp : public framework::OperatorWithKernel {
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
framework
::
LibraryType
library
{
framework
::
LibraryType
::
kPlain
};
framework
::
LibraryType
library
{
framework
::
LibraryType
::
kPlain
};
#ifdef PADDLE_WITH_CUDA
#ifdef PADDLE_WITH_CUDA
if
(
framework
::
CanCUDNNBeUsed
(
ctx
))
{
if
(
platform
::
CanCUDNNBeUsed
(
ctx
))
{
library
=
framework
::
LibraryType
::
kCUDNN
;
library
=
framework
::
LibraryType
::
kCUDNN
;
}
}
#endif
#endif
...
@@ -184,7 +184,7 @@ class AffineGridOpGrad : public framework::OperatorWithKernel {
...
@@ -184,7 +184,7 @@ class AffineGridOpGrad : public framework::OperatorWithKernel {
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
framework
::
LibraryType
library_
{
framework
::
LibraryType
::
kPlain
};
framework
::
LibraryType
library_
{
framework
::
LibraryType
::
kPlain
};
#ifdef PADDLE_WITH_CUDA
#ifdef PADDLE_WITH_CUDA
if
(
framework
::
CanCUDNNBeUsed
(
ctx
))
{
if
(
platform
::
CanCUDNNBeUsed
(
ctx
))
{
library_
=
framework
::
LibraryType
::
kCUDNN
;
library_
=
framework
::
LibraryType
::
kCUDNN
;
}
}
#endif
#endif
...
...
paddle/fluid/operators/conv_op.cc
浏览文件 @
3a2afbf0
...
@@ -84,7 +84,7 @@ framework::OpKernelType ConvOp::GetExpectedKernelType(
...
@@ -84,7 +84,7 @@ framework::OpKernelType ConvOp::GetExpectedKernelType(
framework
::
DataLayout
layout
=
framework
::
StringToDataLayout
(
data_format
);
framework
::
DataLayout
layout
=
framework
::
StringToDataLayout
(
data_format
);
#ifdef PADDLE_WITH_CUDA
#ifdef PADDLE_WITH_CUDA
if
(
framework
::
CanCUDNNBeUsed
(
ctx
))
{
if
(
platform
::
CanCUDNNBeUsed
(
ctx
))
{
library
=
framework
::
LibraryType
::
kCUDNN
;
library
=
framework
::
LibraryType
::
kCUDNN
;
}
}
#endif
#endif
...
@@ -369,7 +369,7 @@ framework::OpKernelType ConvOpGrad::GetExpectedKernelType(
...
@@ -369,7 +369,7 @@ framework::OpKernelType ConvOpGrad::GetExpectedKernelType(
framework
::
DataLayout
layout_
=
framework
::
StringToDataLayout
(
data_format
);
framework
::
DataLayout
layout_
=
framework
::
StringToDataLayout
(
data_format
);
#ifdef PADDLE_WITH_CUDA
#ifdef PADDLE_WITH_CUDA
if
(
framework
::
CanCUDNNBeUsed
(
ctx
))
{
if
(
platform
::
CanCUDNNBeUsed
(
ctx
))
{
library_
=
framework
::
LibraryType
::
kCUDNN
;
library_
=
framework
::
LibraryType
::
kCUDNN
;
}
}
#endif
#endif
...
...
paddle/fluid/operators/grid_sampler_op.cc
浏览文件 @
3a2afbf0
...
@@ -59,7 +59,7 @@ class GridSampleOp : public framework::OperatorWithKernel {
...
@@ -59,7 +59,7 @@ class GridSampleOp : public framework::OperatorWithKernel {
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
framework
::
LibraryType
library_
{
framework
::
LibraryType
::
kPlain
};
framework
::
LibraryType
library_
{
framework
::
LibraryType
::
kPlain
};
#ifdef PADDLE_WITH_CUDA
#ifdef PADDLE_WITH_CUDA
if
(
framework
::
CanCUDNNBeUsed
(
ctx
))
{
if
(
platform
::
CanCUDNNBeUsed
(
ctx
))
{
library_
=
framework
::
LibraryType
::
kCUDNN
;
library_
=
framework
::
LibraryType
::
kCUDNN
;
}
}
#endif
#endif
...
@@ -155,7 +155,7 @@ class GridSampleOpGrad : public framework::OperatorWithKernel {
...
@@ -155,7 +155,7 @@ class GridSampleOpGrad : public framework::OperatorWithKernel {
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
framework
::
LibraryType
library_
{
framework
::
LibraryType
::
kPlain
};
framework
::
LibraryType
library_
{
framework
::
LibraryType
::
kPlain
};
#ifdef PADDLE_WITH_CUDA
#ifdef PADDLE_WITH_CUDA
if
(
framework
::
CanCUDNNBeUsed
(
ctx
))
{
if
(
platform
::
CanCUDNNBeUsed
(
ctx
))
{
library_
=
framework
::
LibraryType
::
kCUDNN
;
library_
=
framework
::
LibraryType
::
kCUDNN
;
}
}
#endif
#endif
...
...
paddle/fluid/operators/pool_op.cc
浏览文件 @
3a2afbf0
...
@@ -92,7 +92,7 @@ framework::OpKernelType PoolOp::GetExpectedKernelType(
...
@@ -92,7 +92,7 @@ framework::OpKernelType PoolOp::GetExpectedKernelType(
framework
::
DataLayout
layout_
=
framework
::
StringToDataLayout
(
data_format
);
framework
::
DataLayout
layout_
=
framework
::
StringToDataLayout
(
data_format
);
#ifdef PADDLE_WITH_CUDA
#ifdef PADDLE_WITH_CUDA
if
(
framework
::
CanCUDNNBeUsed
(
ctx
))
{
if
(
platform
::
CanCUDNNBeUsed
(
ctx
))
{
library_
=
framework
::
LibraryType
::
kCUDNN
;
library_
=
framework
::
LibraryType
::
kCUDNN
;
}
}
#endif
#endif
...
@@ -122,7 +122,7 @@ framework::OpKernelType PoolOpGrad::GetExpectedKernelType(
...
@@ -122,7 +122,7 @@ framework::OpKernelType PoolOpGrad::GetExpectedKernelType(
framework
::
DataLayout
layout_
=
framework
::
StringToDataLayout
(
data_format
);
framework
::
DataLayout
layout_
=
framework
::
StringToDataLayout
(
data_format
);
#ifdef PADDLE_WITH_CUDA
#ifdef PADDLE_WITH_CUDA
if
(
framework
::
CanCUDNNBeUsed
(
ctx
))
{
if
(
platform
::
CanCUDNNBeUsed
(
ctx
))
{
library_
=
framework
::
LibraryType
::
kCUDNN
;
library_
=
framework
::
LibraryType
::
kCUDNN
;
}
}
#endif
#endif
...
...
paddle/fluid/operators/softmax_op.cc
浏览文件 @
3a2afbf0
...
@@ -50,7 +50,7 @@ class SoftmaxOp : public framework::OperatorWithKernel {
...
@@ -50,7 +50,7 @@ class SoftmaxOp : public framework::OperatorWithKernel {
framework
::
DataLayout
layout_
=
framework
::
StringToDataLayout
(
data_format
);
framework
::
DataLayout
layout_
=
framework
::
StringToDataLayout
(
data_format
);
#ifdef PADDLE_WITH_CUDA
#ifdef PADDLE_WITH_CUDA
if
(
framework
::
CanCUDNNBeUsed
(
ctx
))
{
if
(
platform
::
CanCUDNNBeUsed
(
ctx
))
{
library_
=
framework
::
LibraryType
::
kCUDNN
;
library_
=
framework
::
LibraryType
::
kCUDNN
;
}
}
#endif
#endif
...
@@ -157,7 +157,7 @@ class SoftmaxOpGrad : public framework::OperatorWithKernel {
...
@@ -157,7 +157,7 @@ class SoftmaxOpGrad : public framework::OperatorWithKernel {
framework
::
DataLayout
layout_
=
framework
::
StringToDataLayout
(
data_format
);
framework
::
DataLayout
layout_
=
framework
::
StringToDataLayout
(
data_format
);
#ifdef PADDLE_WITH_CUDA
#ifdef PADDLE_WITH_CUDA
if
(
framework
::
CanCUDNNBeUsed
(
ctx
))
{
if
(
platform
::
CanCUDNNBeUsed
(
ctx
))
{
library_
=
framework
::
LibraryType
::
kCUDNN
;
library_
=
framework
::
LibraryType
::
kCUDNN
;
}
}
#endif
#endif
...
...
paddle/fluid/operators/warpctc_op.cc
浏览文件 @
3a2afbf0
...
@@ -51,7 +51,7 @@ class WarpCTCOp : public framework::OperatorWithKernel {
...
@@ -51,7 +51,7 @@ class WarpCTCOp : public framework::OperatorWithKernel {
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
framework
::
LibraryType
library_
{
framework
::
LibraryType
::
kPlain
};
framework
::
LibraryType
library_
{
framework
::
LibraryType
::
kPlain
};
#ifdef PADDLE_WITH_CUDA
#ifdef PADDLE_WITH_CUDA
if
(
framework
::
CanCUDNNBeUsed
(
ctx
))
{
if
(
platform
::
CanCUDNNBeUsed
(
ctx
))
{
library_
=
framework
::
LibraryType
::
kCUDNN
;
library_
=
framework
::
LibraryType
::
kCUDNN
;
}
}
#endif
#endif
...
...
paddle/fluid/platform/cudnn_helper.h
浏览文件 @
3a2afbf0
...
@@ -17,6 +17,7 @@ limitations under the License. */
...
@@ -17,6 +17,7 @@ limitations under the License. */
#include <string>
#include <string>
#include <vector>
#include <vector>
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/dynload/cudnn.h"
#include "paddle/fluid/platform/dynload/cudnn.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/fluid/platform/float16.h"
...
@@ -450,6 +451,18 @@ class ScopedActivationDescriptor {
...
@@ -450,6 +451,18 @@ class ScopedActivationDescriptor {
DISABLE_COPY_AND_ASSIGN
(
ScopedActivationDescriptor
);
DISABLE_COPY_AND_ASSIGN
(
ScopedActivationDescriptor
);
};
};
inline
bool
CanCUDNNBeUsed
(
const
framework
::
ExecutionContext
&
ctx
)
{
bool
use_cudnn
=
ctx
.
Attr
<
bool
>
(
"use_cudnn"
);
use_cudnn
&=
paddle
::
platform
::
is_gpu_place
(
ctx
.
GetPlace
());
#ifdef PADDLE_WITH_CUDA
if
(
use_cudnn
)
{
auto
&
dev_ctx
=
ctx
.
device_context
<
platform
::
CUDADeviceContext
>
();
use_cudnn
&=
dev_ctx
.
cudnn_handle
()
!=
nullptr
;
}
#endif
return
use_cudnn
;
}
#if CUDNN_VERSION >= 7001
#if CUDNN_VERSION >= 7001
class
ScopedCTCLossDescriptor
{
class
ScopedCTCLossDescriptor
{
public:
public:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录