PaddlePaddle / Paddle, commit bfb07aaf (unverified)
Authored on Apr 02, 2020 by zhongpu; committed via GitHub on Apr 02, 2020.
Revert "Exhaustive search (#22821)", test=develop (#23401)
This reverts commit
48144e40
.
上级
7fda333a
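In substance, the revert trades the process-global, mutex-guarded ConvSearchCache that #22821 introduced for the older per-operator mechanism: every OperatorWithKernel again owns a kernel_configs_map_ from OpKernelType to a vector of type-erased KernelConfig entries, and a pointer to the matching vector travels inside ExecutionContext so a cuDNN kernel can fetch its AlgorithmsCache by slot index. Below is a minimal, self-contained sketch of that flow; the types are stand-ins (std::variant where the framework uses boost::variant, a string key where it uses OpKernelType), not the real Paddle classes.

#include <cassert>
#include <memory>
#include <string>
#include <unordered_map>
#include <variant>
#include <vector>

struct FwdAlgoCache { int lookups = 0; };  // stand-ins for AlgorithmsCache<...>
struct BwdAlgoCache { int lookups = 0; };

// Type-erased config slot, as in operator.h (there: a boost::variant).
using KernelConfig =
    std::variant<std::shared_ptr<FwdAlgoCache>, std::shared_ptr<BwdAlgoCache>>;

struct ExecCtx {
  std::vector<KernelConfig>* configs;  // may be nullptr, e.g. during kernel selection
  template <typename T>
  T& GetKernelConfig(size_t idx) const {
    assert(configs != nullptr && configs->size() > idx);
    return *std::get<std::shared_ptr<T>>((*configs)[idx]);
  }
};

int main() {
  // Operator side: GetExpectedKernelType() fills the slots once per kernel type.
  std::unordered_map<std::string, std::vector<KernelConfig>> kernel_configs_map;
  auto& configs = kernel_configs_map["conv2d:CUDNN"];
  if (configs.empty()) {
    configs.push_back(std::make_shared<FwdAlgoCache>());
    configs.push_back(std::make_shared<BwdAlgoCache>());
  }

  // Kernel side: RunImpl() hands the vector to the context; the kernel
  // addresses its cache by the same slot index it was registered under.
  ExecCtx ctx{&configs};
  ctx.GetKernelConfig<FwdAlgoCache>(0).lookups++;
  ctx.GetKernelConfig<BwdAlgoCache>(1).lookups++;
  return 0;
}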
Showing 15 changed files with 200 additions and 197 deletions (+200 −197).
  paddle/fluid/framework/operator.cc                                 +15  −3
  paddle/fluid/framework/operator.h                                  +34  −2
  paddle/fluid/framework/operator_kernel_configs.h                   +21  −61
  paddle/fluid/framework/operator_test.cc                             +1  −1
  paddle/fluid/imperative/execution_context.h                         +2  −1
  paddle/fluid/imperative/prepared_operator.cc                       +21  −11
  paddle/fluid/imperative/prepared_operator.h                         +3  −1
  paddle/fluid/imperative/tests/test_layer.cc                         +1  −1
  paddle/fluid/operators/beam_search_decode_op.cc                     +1  −1
  paddle/fluid/operators/conv_cudnn_helper.h                         +27  −69
  paddle/fluid/operators/conv_cudnn_op.cu                            +21  −38
  paddle/fluid/operators/conv_op.cc                                  +45  −0
  paddle/fluid/operators/elementwise/test_elementwise_mul_op_dim.cc   +2  −1
  paddle/fluid/operators/fused/conv_fusion_op.cu                      +3  −4
  paddle/fluid/operators/warpctc_op.cc                                +3  −3
paddle/fluid/framework/operator.cc
@@ -905,6 +905,16 @@ void OperatorWithKernel::RuntimeInferShape(const Scope& scope,
   this->InferShape(&infer_shape_ctx);
 }
 
+std::vector<KernelConfig>* OperatorWithKernel::GetKernelConfig(
+    const OpKernelType& key) const {
+  auto config_iter = kernel_configs_map_.find(key);
+  std::vector<KernelConfig>* kernel_configs = nullptr;
+  if (config_iter != kernel_configs_map_.end()) {
+    kernel_configs = &(config_iter->second);
+  }
+  return kernel_configs;
+}
+
 void OperatorWithKernel::RunImpl(const Scope& scope,
                                  const platform::Place& place) const {
   // To reduce the elapsed time of HasAttr, we use bool variable to record the
@@ -941,6 +951,8 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
     ChooseKernel(*runtime_ctx, scope, place);
   }
 
+  std::vector<KernelConfig>* kernel_configs = GetKernelConfig(*kernel_type_);
+
   // do data transformScope &transfer_scope;
   std::vector<std::string> transfered_inplace_vars;
   Scope* transfer_scope = nullptr;
@@ -976,8 +988,8 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
   {
     platform::RecordEvent record_event("compute",
                                        platform::EventRole::kInnerOp);
-    (*kernel_func_)(
-        ExecutionContext(*this, exec_scope, *dev_ctx, *runtime_ctx));
+    (*kernel_func_)(ExecutionContext(*this, exec_scope, *dev_ctx, *runtime_ctx,
+                                     kernel_configs));
   }
 
   if (!transfered_inplace_vars.empty()) {
@@ -1046,7 +1058,7 @@ void OperatorWithKernel::ChooseKernel(const RuntimeContext& ctx,
   OpKernelMap& kernels = kernels_iter->second;
 
   auto expected_kernel_key = this->GetExpectedKernelType(
-      ExecutionContext(*this, scope, *dev_ctx, ctx));
+      ExecutionContext(*this, scope, *dev_ctx, ctx, nullptr));
 
   if (HasAttr("op_device")) {
     if (Attr<std::string>("op_device") == "cpu") {
       expected_kernel_key.place_ = platform::CPUPlace();
paddle/fluid/framework/operator.h
@@ -31,6 +31,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/lod_tensor.h"
 #include "paddle/fluid/framework/op_info.h"
 #include "paddle/fluid/framework/op_kernel_type.h"
+#include "paddle/fluid/framework/operator_kernel_configs.h"
 #include "paddle/fluid/framework/scope.h"
 #include "paddle/fluid/framework/selected_rows.h"
 #include "paddle/fluid/framework/tensor.h"
@@ -215,12 +216,30 @@ class OperatorBase {
                      const platform::Place& place) const = 0;
 };
 
+#ifdef PADDLE_WITH_CUDA
+using KernelConfig = boost::variant<
+    std::shared_ptr<AlgorithmsCache<cudnnConvolutionFwdAlgo_t>>,
+    std::shared_ptr<AlgorithmsCache<cudnnConvolutionBwdDataAlgo_t>>,
+    std::shared_ptr<AlgorithmsCache<cudnnConvolutionBwdFilterAlgo_t>>>;
+#else
+using KernelConfig = boost::variant<boost::blank>;
+#endif
+
+using OpKernelConfigsMap =
+    std::unordered_map<OpKernelType, std::vector<KernelConfig>,
+                       OpKernelType::Hash>;
+
 class ExecutionContext {
  public:
   ExecutionContext(const OperatorBase& op, const Scope& scope,
                    const platform::DeviceContext& device_context,
-                   const RuntimeContext& ctx)
-      : op_(op), scope_(scope), device_context_(device_context), ctx_(ctx) {}
+                   const RuntimeContext& ctx,
+                   std::vector<KernelConfig>* configs)
+      : op_(op),
+        scope_(scope),
+        device_context_(device_context),
+        ctx_(ctx),
+        kernel_configs_(configs) {}
 
   virtual ~ExecutionContext() {}
 
   virtual std::string InputName(const std::string& name) const {
@@ -386,6 +405,15 @@ class ExecutionContext {
     return temp_tensor;
   }
 
+  template <typename T>
+  T& GetKernelConfig(size_t idx) const {
+    PADDLE_ENFORCE(
+        kernel_configs_ && kernel_configs_->size() > static_cast<size_t>(idx),
+        "%s selected kernel doesn't have kernel config %lu <= %lu",
+        op_.Type().c_str(), kernel_configs_->size(), idx);
+    return *boost::get<std::shared_ptr<T>>((*kernel_configs_)[idx]);
+  }
+
   const RuntimeContext Context() const { return ctx_; }
 
   std::string DebugString() const { return op_.DebugString(); }
@@ -395,6 +423,7 @@ class ExecutionContext {
   const Scope& scope_;
   const platform::DeviceContext& device_context_;
   const RuntimeContext& ctx_;
+  mutable std::vector<KernelConfig>* kernel_configs_;
 };
 
 template <>
@@ -470,6 +499,8 @@ class OperatorWithKernel : public OperatorBase {
   virtual OpKernelType GetExpectedKernelType(const ExecutionContext& ctx) const;
 
+  std::vector<KernelConfig>* GetKernelConfig(const OpKernelType& key) const;
+
   // change this to public so that in dygraph mode we can call it to check if we
   // need transform data
   virtual OpKernelType GetKernelTypeForVar(
@@ -506,6 +537,7 @@ class OperatorWithKernel : public OperatorBase {
                    const platform::Place& place) const;
 
  protected:
+  mutable OpKernelConfigsMap kernel_configs_map_;
   mutable std::unique_ptr<OpKernelType> kernel_type_;
   mutable std::unique_ptr<OpKernelFunc> kernel_func_;
   mutable std::unique_ptr<RuntimeContext> runtime_ctx_;
paddle/fluid/framework/operator_kernel_configs.h
@@ -21,21 +21,19 @@ limitations under the License. */
 namespace paddle {
 namespace framework {
 
-// thread-safe.
+// Not thread-safe. Should be owned per-kernel.
 template <typename TAlgorithm>
 class AlgorithmsCache {
  public:
   AlgorithmsCache() : search_times_(0) { hash_.clear(); }
   // Caches the best algorithm for a given
   // combination of tensor dimensions & compute data type.
-  // cudnn_dtype set for different data type
   TAlgorithm GetAlgorithm(
       const std::vector<int64_t>& dims1, const std::vector<int64_t>& dims2,
       const std::vector<int>& strides, const std::vector<int>& paddings,
-      const std::vector<int>& dilations, int algorithmFlags,
-      int64_t cudnn_dtype, std::function<TAlgorithm()> gen_func);
+      const std::vector<int>& dilations,
+      int algorithmFlags,  // can set for different data type
+      std::function<TAlgorithm()> gen_func);
 
   TAlgorithm GetAlgorithm(int64_t area, int search_times, int algorithmFlags,
                           std::function<TAlgorithm()> gen_func);
@@ -43,14 +41,13 @@ class AlgorithmsCache {
 
  private:
   std::unordered_map<int64_t, TAlgorithm> hash_;
   int search_times_;
-  std::mutex cache_mutex;
 };
 
 template <typename TAlgorithm>
 TAlgorithm framework::AlgorithmsCache<TAlgorithm>::GetAlgorithm(
     const std::vector<int64_t>& dims1, const std::vector<int64_t>& dims2,
     const std::vector<int>& strides, const std::vector<int>& paddings,
-    const std::vector<int>& dilations, int algorithmFlags, int64_t cudnn_dtype,
+    const std::vector<int>& dilations, int algorithmFlags,
     std::function<TAlgorithm()> gen_func) {
   int64_t seed = 0;
   // Hash all of the inputs, use to try and look up a previously
@@ -84,73 +81,36 @@ TAlgorithm framework::AlgorithmsCache<TAlgorithm>::GetAlgorithm(
   seed ^= hashFn(static_cast<int64_t>(algorithmFlags)) + 0x9e3779b9 +
           (seed << 6) + (seed >> 2) + 5;
-  seed ^= hashFn(static_cast<int64_t>(cudnn_dtype)) + 0x9e3779b9 +
-          (seed << 6) + (seed >> 2) + 6;
 
   VLOG(10) << "seed:" << seed << ", hash_.size:" << hash_.size();
 
   if (seed == 0) return gen_func();
 
-  TAlgorithm ret;
-  auto it = hash_.end();
-  bool have_found = false;
-  {
-    std::lock_guard<std::mutex> lock(cache_mutex);
-    it = hash_.find(seed);
-    if (it != hash_.end()) {
-      ret = it->second;
-      have_found = true;
-    }
-  }
-  if (!have_found) {
-    ret = gen_func();
-    std::lock_guard<std::mutex> lock(cache_mutex);
-    hash_[seed] = ret;
-  }
-  return ret;
+  if (hash_.find(seed) == hash_.end()) {
+    TAlgorithm value = gen_func();
+    hash_[seed] = value;
+  }
+  return hash_[seed];
 }
 
 template <typename TAlgorithm>
 TAlgorithm AlgorithmsCache<TAlgorithm>::GetAlgorithm(
     int64_t area, int search_times, int algorithmFlags,
     std::function<TAlgorithm()> gen_func) {
-  auto it = hash_.end();
-  {
-    std::lock_guard<std::mutex> lock(cache_mutex);
-    it = hash_.find(area);
-    if (it != hash_.end()) {
-      return it->second;
-    }
-  }
-
-  bool gene_flag = false;
-  {
-    std::lock_guard<std::mutex> lock(cache_mutex);
-    gene_flag = (search_times_ < search_times);
-  }
-  TAlgorithm algo{};
-  if (gene_flag) {
-    algo = gen_func();
-    std::lock_guard<std::mutex> lock(cache_mutex);
-    hash_[area] = algo;
-    ++search_times_;
-    return algo;
-  }
-
-  int64_t min = static_cast<uint64_t>(INT_MAX);
-  {
-    std::lock_guard<std::mutex> lock(cache_mutex);
-    for (const auto& m : hash_) {
-      if (m.first < min) {
-        min = m.first;
-        algo = m.second;
-      }
-    }
-  }
+  if (hash_.find(area) != hash_.end()) {
+    return hash_[area];
+  }
+  if (search_times_ < search_times) {
+    auto algo = gen_func();
+    hash_[area] = algo;
+    ++search_times_;
+    return algo;
+  }
+  TAlgorithm algo{};
+  int64_t min = static_cast<uint64_t>(INT_MAX);
+  for (const auto& m : hash_) {
+    if (m.first < min) {
+      min = m.first;
+      algo = m.second;
+    }
+  }
   return algo;
 }
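The restored GetAlgorithm bodies reduce to a common pattern: fold every dimension, stride, padding, dilation, and flag into a single 64-bit seed with a boost-style hash combiner, then run the expensive generator only on a cache miss. A reduced, runnable sketch of just that behavior, with the per-field salt constants simplified and a counting gen_func so the caching is observable:

#include <cstdint>
#include <functional>
#include <iostream>
#include <unordered_map>
#include <vector>

// Fold a list of values into one seed: golden-ratio constant plus shifts,
// the same combiner AlgorithmsCache::GetAlgorithm uses above.
static int64_t CombineHash(const std::vector<int64_t>& values) {
  std::hash<int64_t> hashFn;
  int64_t seed = 0;
  for (int64_t v : values) {
    seed ^= hashFn(v) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
  }
  return seed;
}

int main() {
  std::unordered_map<int64_t, int> hash_;  // seed -> cached "algorithm"
  int search_count = 0;
  auto gen_func = [&]() { return ++search_count; };  // the expensive search

  const std::vector<int64_t> key = {/*dims*/ 8, 3, 224, 224, /*strides*/ 1, 1};
  for (int i = 0; i < 3; ++i) {
    const int64_t seed = CombineHash(key);
    if (hash_.find(seed) == hash_.end()) {
      hash_[seed] = gen_func();
    }
    std::cout << "algo=" << hash_[seed] << "\n";  // gen_func ran exactly once
  }
  return 0;
}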
paddle/fluid/framework/operator_test.cc
@@ -525,7 +525,7 @@ TEST(ExecutionContextAttrAndInOut, new_api) {
   paddle::framework::RuntimeContext ctx({}, {});
   paddle::framework::ExecutionContext exe_context(*(op.get()), scope, *dev_ctx,
-                                                  ctx);
+                                                  ctx, nullptr);
 
   ASSERT_EQ(exe_context.InputSize("input"), 1u);
   ASSERT_EQ(exe_context.OutputSize("output"), 1u);
paddle/fluid/imperative/execution_context.h
@@ -33,10 +33,11 @@ class DygraphExecutionContext : public framework::ExecutionContext {
                           const framework::Scope& scope,
                           const platform::DeviceContext& device_context,
                           const framework::RuntimeContext& ctx,
+                          std::vector<framework::KernelConfig>* configs,
                           const NameVarMap<VarType>& var_base_map_in,
                           const NameVarMap<VarType>& var_base_map_out,
                           const framework::AttributeMap& attrs)
-      : ExecutionContext(op, scope, device_context, ctx),
+      : ExecutionContext(op, scope, device_context, ctx, configs),
         var_base_map_in_(var_base_map_in),
         var_base_map_out_(var_base_map_out),
         attrs_(attrs) {}
paddle/fluid/imperative/prepared_operator.cc
@@ -80,8 +80,13 @@ void PreparedOp::PrepareData(
 PreparedOp::PreparedOp(const framework::OperatorBase& op,
                        const framework::RuntimeContext& ctx,
                        const framework::OperatorWithKernel::OpKernelFunc& func,
-                       platform::DeviceContext* dev_ctx)
-    : op_(op), ctx_(ctx), func_(func), dev_ctx_(dev_ctx) {}
+                       platform::DeviceContext* dev_ctx,
+                       std::vector<framework::KernelConfig>* kernel_configs)
+    : op_(op),
+      ctx_(ctx),
+      func_(func),
+      dev_ctx_(dev_ctx),
+      kernel_configs_(kernel_configs) {}
 
 template <typename VarType>
 PreparedOp PrepareOpImpl(const NameVarMap<VarType>& ins,
@@ -106,7 +111,7 @@ PreparedOp PrepareOpImpl(const NameVarMap<VarType>& ins,
   framework::RuntimeContext ctx({}, {});
   auto expected_kernel_key =
       op.GetExpectedKernelType(DygraphExecutionContext<VarType>(
-          op, framework::Scope(), *dev_ctx, ctx, ins, outs, attrs));
+          op, framework::Scope(), *dev_ctx, ctx, nullptr, ins, outs, attrs));
   VLOG(3) << "expected_kernel_key:" << expected_kernel_key;
 
   auto kernel_iter = kernels.find(expected_kernel_key);
@@ -115,6 +120,8 @@ PreparedOp PrepareOpImpl(const NameVarMap<VarType>& ins,
     PADDLE_THROW("op %s does not have kernel for %s", op.Type(),
                  KernelTypeToString(expected_kernel_key));
   }
+  std::vector<framework::KernelConfig>* kernel_configs =
+      op.GetKernelConfig(expected_kernel_key);
 
   if (!(expected_kernel_key.place_ == place)) {
     dev_ctx = pool.Get(expected_kernel_key.place_);
@@ -122,7 +129,7 @@ PreparedOp PrepareOpImpl(const NameVarMap<VarType>& ins,
   }
 
   PrepareDataImpl<VarType>(place, ins, op, expected_kernel_key);
-  return PreparedOp(op, ctx, kernel_iter->second, dev_ctx);
+  return PreparedOp(op, ctx, kernel_iter->second, dev_ctx, kernel_configs);
 }
 
 PreparedOp PreparedOp::Prepare(const NameVarMap<VarBase>& ins,
@@ -145,8 +152,10 @@ template <typename VarType>
 static void PreparedOpRunImpl(
     const framework::OperatorBase& op, const framework::RuntimeContext& ctx,
     const framework::OperatorWithKernel::OpKernelFunc& func,
-    platform::DeviceContext* dev_ctx, const NameVarMap<VarType>& ins,
-    const NameVarMap<VarType>& outs, const framework::AttributeMap& attrs) {
+    platform::DeviceContext* dev_ctx,
+    std::vector<framework::KernelConfig>* kernel_configs,
+    const NameVarMap<VarType>& ins, const NameVarMap<VarType>& outs,
+    const framework::AttributeMap& attrs) {
   // TODO(zjl): remove scope in dygraph
   framework::Scope scope;
@@ -154,21 +163,22 @@ static void PreparedOpRunImpl(
   static_cast<const framework::OperatorWithKernel&>(op).InferShape(
       &infer_shape_ctx);
 
-  func(DygraphExecutionContext<VarType>(op, scope, *dev_ctx, ctx, ins, outs,
-                                        attrs));
+  func(DygraphExecutionContext<VarType>(op, scope, *dev_ctx, ctx,
+                                        kernel_configs, ins, outs, attrs));
 }
 
 void PreparedOp::Run(const NameVarMap<VarBase>& ins,
                      const NameVarMap<VarBase>& outs,
                      const framework::AttributeMap& attrs) {
-  PreparedOpRunImpl<VarBase>(op_, ctx_, func_, dev_ctx_, ins, outs, attrs);
+  PreparedOpRunImpl<VarBase>(op_, ctx_, func_, dev_ctx_, kernel_configs_, ins,
+                             outs, attrs);
 }
 
 void PreparedOp::Run(const NameVarMap<VariableWrapper>& ins,
                      const NameVarMap<VariableWrapper>& outs,
                      const framework::AttributeMap& attrs) {
-  PreparedOpRunImpl<VariableWrapper>(op_, ctx_, func_, dev_ctx_, ins, outs,
-                                     attrs);
+  PreparedOpRunImpl<VariableWrapper>(op_, ctx_, func_, dev_ctx_,
+                                     kernel_configs_, ins, outs, attrs);
 }
 
 }  // namespace imperative
paddle/fluid/imperative/prepared_operator.h
@@ -33,7 +33,8 @@ class PreparedOp {
   PreparedOp(const framework::OperatorBase& op,
              const framework::RuntimeContext& ctx,
             const framework::OperatorWithKernel::OpKernelFunc& func,
-             platform::DeviceContext* dev_ctx);
+             platform::DeviceContext* dev_ctx,
+             std::vector<framework::KernelConfig>* kernel_configs);
 
   static PreparedOp Prepare(const NameVarMap<VarBase>& ins,
                             const NameVarMap<VarBase>& outs,
@@ -71,6 +72,7 @@ class PreparedOp {
   const framework::RuntimeContext& ctx_;
   framework::OperatorWithKernel::OpKernelFunc func_;
   platform::DeviceContext* dev_ctx_;
+  std::vector<framework::KernelConfig>* kernel_configs_;
 };
 
 }  // namespace imperative
paddle/fluid/imperative/tests/test_layer.cc
@@ -235,7 +235,7 @@ TEST(test_layer, test_dygraph_execution_context) {
   framework::Scope scope;
 
   DygraphExecutionContext<imperative::VarBase> dy_exe_context(
-      *(op.get()), scope, *dev_ctx, ctx, ins, outs, concat_att_map);
+      *(op.get()), scope, *dev_ctx, ctx, nullptr, ins, outs, concat_att_map);
 
   ASSERT_EQ(dy_exe_context.InputSize("X"), 1u);
   ASSERT_EQ(dy_exe_context.InputName("X"), "vin");
paddle/fluid/operators/beam_search_decode_op.cc
@@ -123,7 +123,7 @@ class BeamSearchDecodeOp : public framework::OperatorBase {
     auto& dev_ctx = *pool.Get(dev_place);
 
     framework::RuntimeContext run_ctx(Inputs(), Outputs(), scope);
-    framework::ExecutionContext ctx(*this, scope, dev_ctx, run_ctx);
+    framework::ExecutionContext ctx(*this, scope, dev_ctx, run_ctx, nullptr);
 
     const LoDTensorArray* ids = ctx.Input<LoDTensorArray>("Ids");
     const LoDTensorArray* scores = ctx.Input<LoDTensorArray>("Scores");
paddle/fluid/operators/conv_cudnn_helper.h
@@ -21,7 +21,6 @@ limitations under the License. */
 #include "paddle/fluid/framework/operator_kernel_configs.h"
 #include "paddle/fluid/operators/conv_cudnn_op_cache.h"
 #include "paddle/fluid/platform/cudnn_desc.h"
-// #include "paddle/fluid/platform/device_context.h"
 
 namespace paddle {
 namespace operators {
@@ -90,43 +89,7 @@ std::ostream& operator<<(std::ostream& out, const std::vector<T>& v) {
   return out;
 }
 
-// ConvSearchCache using framework::AlgorithmsCache to search
-// cudnnConvolutionFwdAlgo_t, cudnnConvolutionBwdDataAlgo_t or
-// cudnnConvolutionBwdFilterAlgo_t
-class ConvSearchCache {
- public:
-  static ConvSearchCache& Instance() {
-    static ConvSearchCache instance;
-    return instance;
-  }
-
-  framework::AlgorithmsCache<cudnnConvolutionFwdAlgo_t>* GetForward() {
-    return &forward_cache_;
-  }
-  framework::AlgorithmsCache<cudnnConvolutionBwdDataAlgo_t>* GetBackwardData() {
-    return &backward_data_cache_;
-  }
-  framework::AlgorithmsCache<cudnnConvolutionBwdFilterAlgo_t>*
-  GetBackwardFilter() {
-    return &backward_filter_cache_;
-  }
-  framework::AlgorithmsCache<cudnnConvolutionFwdAlgo_t>* GetConvFusion() {
-    return &fusion_forward_cache_;
-  }
-
- private:
-  ConvSearchCache() {}
-  ~ConvSearchCache() {}
-  ConvSearchCache(const ConvSearchCache&) {}
-  ConvSearchCache& operator=(const ConvSearchCache&) {}
-
-  framework::AlgorithmsCache<cudnnConvolutionFwdAlgo_t> forward_cache_;
-  framework::AlgorithmsCache<cudnnConvolutionBwdDataAlgo_t>
-      backward_data_cache_;
-  framework::AlgorithmsCache<cudnnConvolutionBwdFilterAlgo_t>
-      backward_filter_cache_;
-  framework::AlgorithmsCache<cudnnConvolutionFwdAlgo_t> fusion_forward_cache_;
-};
+using framework::AlgorithmsCache;
 
 struct ConvArgs {
   cudnnHandle_t handle;
@@ -134,7 +97,6 @@ struct ConvArgs {
   platform::FilterDescriptor wdesc;
   platform::ConvolutionDescriptor cdesc;
   const framework::Tensor *x, *w, *o;
-  cudnnDataType_t cudnn_dtype;
 
   // strides
   std::vector<int> s;
@@ -145,9 +107,8 @@ struct ConvArgs {
   ConvArgs(const framework::Tensor* x, const framework::Tensor* w,
            const framework::Tensor* o, const std::vector<int> s,
-           const std::vector<int> p, const std::vector<int> d,
-           cudnnDataType_t dtype)
-      : x(x), w(w), o(o), s(s), p(p), d(d), cudnn_dtype(dtype) {}
+           const std::vector<int> p, const std::vector<int> d)
+      : x(x), w(w), o(o), s(s), p(p), d(d) {}
 };
 
 template <typename perf_t>
@@ -160,7 +121,7 @@ struct SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t> {
   template <typename T>
   static algo_t Find(const ConvArgs& args, bool exhaustive_search,
-                     bool deterministic,
+                     bool deterministic, int algo_cache_id,
                      const framework::ExecutionContext& ctx) {
     auto dtype = platform::CudnnDataType<T>::type;
     bool has_got_workspace_size = true;
@@ -222,24 +183,22 @@ struct SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t> {
 #endif
       VLOG(3) << "choose algo " << algo;
     } else {
+      AlgorithmsCache<algo_t>& algo_cache =
+          ctx.GetKernelConfig<AlgorithmsCache<algo_t>>(algo_cache_id);
       auto& dev_ctx =
           ctx.template device_context<platform::CUDADeviceContext>();
       auto workspace_handle = dev_ctx.cudnn_workspace_handle();
-      auto& temp = ctx.cuda_device_context();
-      AlgorithmsCache<algo_t>& algo_cache =
-          *(ConvSearchCache::Instance().GetForward());
       auto x_dims = framework::vectorize(args.x->dims());
       auto w_dims = framework::vectorize(args.w->dims());
-      VLOG(10) << "cudnnConvolutionFwdAlgoPerf_t:"
-               << ", x_dims:" << x_dims << ", w_dims:" << w_dims << ", args.s"
-               << args.s << ", args.p" << args.p << ", args.d" << args.d;
+      VLOG(10) << "cudnnConvolutionFwdAlgoPerf_t algo_cache_id:"
+               << algo_cache_id << ", x_dims:" << x_dims
+               << ", w_dims:" << w_dims << ", args.s" << args.s << ", args.p"
+               << args.p << ", args.d" << args.d;
       algo = algo_cache.GetAlgorithm(
-          x_dims, w_dims, args.s, args.p, args.d, 0,
-          static_cast<int64_t>(args.cudnn_dtype), [&]() {
+          x_dims, w_dims, args.s, args.p, args.d, 0, [&]() {
             int returned_algo_count;
             std::array<perf_t, kNUM_CUDNN_FWD_ALGS> perf_stat;
@@ -285,7 +244,7 @@ struct SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t> {
   template <typename T>
   static algo_t Find(const ConvArgs& args, bool exhaustive_search,
-                     bool deterministic,
+                     bool deterministic, int algo_cache_id,
                      const framework::ExecutionContext& ctx) {
     auto dtype = platform::CudnnDataType<T>::type;
     bool exhaustive = (exhaustive_search) & (dtype != CUDNN_DATA_HALF);
@@ -362,23 +321,22 @@ struct SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t> {
     } else if (deterministic) {
       return CUDNN_CONVOLUTION_BWD_DATA_ALGO_1;
     } else {
+      AlgorithmsCache<algo_t>& algo_cache =
+          ctx.GetKernelConfig<AlgorithmsCache<algo_t>>(algo_cache_id);
       auto& dev_ctx =
          ctx.template device_context<platform::CUDADeviceContext>();
       auto workspace_handle = dev_ctx.cudnn_workspace_handle();
-      AlgorithmsCache<algo_t>& algo_cache =
-          *(ConvSearchCache::Instance().GetBackwardData());
      auto x_dims = framework::vectorize(args.x->dims());
       auto w_dims = framework::vectorize(args.w->dims());
-      VLOG(10) << "cudnnConvolutionFwdAlgoPerf_t"
-               << ", x_dims:" << x_dims << ", w_dims:" << w_dims << ", args.s"
-               << args.s << ", args.p" << args.p << ", args.d" << args.d;
+      VLOG(10) << "cudnnConvolutionFwdAlgoPerf_t algo_cache_id:"
+               << algo_cache_id << ", x_dims:" << x_dims
+               << ", w_dims:" << w_dims << ", args.s" << args.s << ", args.p"
+               << args.p << ", args.d" << args.d;
       algo = algo_cache.GetAlgorithm(
-          x_dims, w_dims, args.s, args.p, args.d, 0,
-          static_cast<int64_t>(args.cudnn_dtype), [&]() {
+          x_dims, w_dims, args.s, args.p, args.d, 0, [&]() {
            int returned_algo_count;
            std::array<perf_t, kNUM_CUDNN_FWD_ALGS> perf_stat;
@@ -427,7 +385,7 @@ struct SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t> {
   template <typename T>
   static algo_t Find(const ConvArgs& args, bool exhaustive_search,
-                     bool deterministic,
+                     bool deterministic, int algo_cache_id,
                      const framework::ExecutionContext& ctx) {
     auto dtype = platform::CudnnDataType<T>::type;
     bool exhaustive = (exhaustive_search) & (dtype != CUDNN_DATA_HALF);
@@ -491,22 +449,22 @@ struct SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t> {
     } else if (deterministic) {
       return CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1;
     } else {
+      AlgorithmsCache<algo_t>& algo_cache =
+          ctx.GetKernelConfig<AlgorithmsCache<algo_t>>(algo_cache_id);
       auto& dev_ctx =
          ctx.template device_context<platform::CUDADeviceContext>();
       auto workspace_handle = dev_ctx.cudnn_workspace_handle();
-      AlgorithmsCache<algo_t>& algo_cache =
-          *(ConvSearchCache::Instance().GetBackwardFilter());
       auto x_dims = framework::vectorize(args.x->dims());
       auto w_dims = framework::vectorize(args.w->dims());
-      VLOG(10) << "cudnnConvolutionFwdAlgoPerf_t:"
-               << ", x_dims:" << x_dims << ", w_dims:" << w_dims << ", args.s"
-               << args.s << ", args.p" << args.p << ", args.d" << args.d;
+      VLOG(10) << "cudnnConvolutionFwdAlgoPerf_t algo_cache_id:"
+               << algo_cache_id << ", x_dims:" << x_dims
+               << ", w_dims:" << w_dims << ", args.s" << args.s << ", args.p"
+               << args.p << ", args.d" << args.d;
       algo = algo_cache.GetAlgorithm(
-          x_dims, w_dims, args.s, args.p, args.d, 0,
-          static_cast<int64_t>(args.cudnn_dtype), [&]() {
+          x_dims, w_dims, args.s, args.p, args.d, 0, [&]() {
            int returned_algo_count;
            std::array<perf_t, kNUM_CUDNN_FWD_ALGS> perf_stat;
            auto cudnn_find_func = [&](void* cudnn_workspace_ptr) {
paddle/fluid/operators/conv_cudnn_op.cu
@@ -216,13 +216,9 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
     const T* filter_data = transformed_filter_channel.data<T>();
 
     // ------------------- cudnn descriptors ---------------------
-    ConvArgs args{&transformed_input, &transformed_filter_channel,
-                  &transformed_output, strides, padding_common, dilations,
-                  dtype};
+    ConvArgs args{&transformed_input, &transformed_filter_channel,
+                  &transformed_output, strides, padding_common, dilations};
     auto handle = dev_ctx.cudnn_handle();
     auto workspace_handle = dev_ctx.cudnn_workspace_handle();
@@ -273,7 +269,7 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
     cudnnConvolutionFwdAlgo_t algo{};
     using search = SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t>;
-    algo = search::Find<T>(args, exhaustive_search, false, ctx);
+    algo = search::Find<T>(args, exhaustive_search, false, 0, ctx);
     workspace_size = search::GetWorkspaceSize(args, algo);
 
 #if CUDNN_VERSION_MIN(7, 0, 1)
@@ -522,15 +518,13 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
                    &transformed_output_grad_channel,
                    strides,
                    padding_common,
-                   dilations,
-                   dtype};
+                   dilations};
     ConvArgs args2{&transformed_input,
                    &transformed_filter_grad_channel,
                    &transformed_output_grad_channel,
                    strides,
                    padding_common,
-                   dilations,
-                   dtype};
+                   dilations};
 
     auto handle = dev_ctx.cudnn_handle();
     DataLayout layout = compute_format == DataLayout::kNHWC ? DataLayout::kNHWC
@@ -586,7 +580,7 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
       using search1 = SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t>;
       data_algo =
-          search1::Find<T>(args1, exhaustive_search, deterministic, ctx);
+          search1::Find<T>(args1, exhaustive_search, deterministic, 0, ctx);
       workspace_size =
           std::max(workspace_size, search1::GetWorkspaceSize(args1, data_algo));
     }
@@ -603,7 +597,7 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
      using search2 = SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t>;
      filter_algo =
-          search2::Find<T>(args2, exhaustive_search, deterministic, ctx);
+          search2::Find<T>(args2, exhaustive_search, deterministic, 1, ctx);
       workspace_size = std::max(workspace_size,
                                 search2::GetWorkspaceSize(args2, filter_algo));
     }
@@ -904,26 +898,15 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel<T> {
     auto handle = dev_ctx.cudnn_handle();
 
-    ConvArgs args1{&transformed_ddX,
-                   W,
-                   &transformed_ddO_channel,
-                   strides,
-                   padding_common,
-                   dilations,
-                   dtype};
-    ConvArgs args2{&transformed_X,
-                   ddW,
-                   &transformed_ddO_channel,
-                   strides,
-                   padding_common,
-                   dilations,
-                   dtype};
-    ConvArgs args3{&transformed_ddX,
-                   dW,
-                   &transformed_dO_channel,
-                   strides,
-                   padding_common,
-                   dilations,
-                   dtype};
-    ConvArgs args4{&transformed_dX,
-                   ddW,
-                   &transformed_dO_channel,
-                   strides,
-                   padding_common,
-                   dilations,
-                   dtype};
+    ConvArgs args1{&transformed_ddX, W, &transformed_ddO_channel, strides,
+                   padding_common, dilations};
+    ConvArgs args2{&transformed_X, ddW, &transformed_ddO_channel, strides,
+                   padding_common, dilations};
+    ConvArgs args3{&transformed_ddX, dW, &transformed_dO_channel, strides,
+                   padding_common, dilations};
+    ConvArgs args4{&transformed_dX, ddW, &transformed_dO_channel, strides,
+                   padding_common, dilations};
 
     cudnnConvolutionFwdAlgo_t fwd_algo1 =
         static_cast<cudnnConvolutionFwdAlgo_t>(0);
@@ -951,7 +934,7 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel<T> {
       args1.cdesc.set(dtype, padding_common, strides, dilations, c_group);
 
      using search1 = SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t>;
-      fwd_algo1 = search1::Find<T>(args1, exhaustive_search, false, ctx);
+      fwd_algo1 = search1::Find<T>(args1, exhaustive_search, false, 0, ctx);
       workspace_size = search1::GetWorkspaceSize(args1, fwd_algo1);
     }
@@ -966,7 +949,7 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel<T> {
       args2.cdesc.set(dtype, padding_common, strides, dilations, c_group);
 
      using search2 = SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t>;
-      fwd_algo2 = search2::Find<T>(args2, exhaustive_search, false, ctx);
+      fwd_algo2 = search2::Find<T>(args2, exhaustive_search, false, 0, ctx);
       workspace_size = std::max(workspace_size,
                                 search2::GetWorkspaceSize(args2, fwd_algo2));
     }
@@ -984,7 +967,7 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel<T> {
      using search3 = SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t>;
      filter_algo =
-          search3::Find<T>(args3, exhaustive_search, deterministic, ctx);
+          search3::Find<T>(args3, exhaustive_search, deterministic, 1, ctx);
       workspace_size = std::max(workspace_size,
                                 search3::GetWorkspaceSize(args3, filter_algo));
     }
@@ -1000,7 +983,7 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel<T> {
      using search4 = SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t>;
      data_algo =
-          search4::Find<T>(args4, exhaustive_search, deterministic, ctx);
+          search4::Find<T>(args4, exhaustive_search, deterministic, 2, ctx);
       workspace_size =
           std::max(workspace_size, search4::GetWorkspaceSize(args4, data_algo));
     }
paddle/fluid/operators/conv_op.cc
@@ -178,6 +178,17 @@ framework::OpKernelType ConvOp::GetExpectedKernelType(
   auto type = framework::OpKernelType(input_data_type, ctx.GetPlace(), layout,
                                       library, customized_type_value);
+#ifdef PADDLE_WITH_CUDA
+  std::vector<framework::KernelConfig>& configs = kernel_configs_map_[type];
+  // TODO(dangqingqing): Currently conv_fusion_op use cudnn but sets use_cudnn
+  // to false. It should be fixed and then here should only create if library
+  // is kCUDNN.
+  if (configs.empty()) {
+    std::shared_ptr<framework::AlgorithmsCache<cudnnConvolutionFwdAlgo_t>> p(
+        new framework::AlgorithmsCache<cudnnConvolutionFwdAlgo_t>());
+    configs.push_back(p);
+  }
+#endif
   return type;
 }
@@ -552,6 +563,21 @@ framework::OpKernelType ConvOpGrad::GetExpectedKernelType(
   auto type = framework::OpKernelType(
       OperatorWithKernel::IndicateVarDataType(ctx, "Input"), ctx.GetPlace(),
       layout_, library_, customized_type_value);
+#ifdef PADDLE_WITH_CUDA
+  if (library_ == framework::LibraryType::kCUDNN) {
+    std::vector<framework::KernelConfig>& configs = kernel_configs_map_[type];
+    if (configs.empty()) {
+      std::shared_ptr<
+          framework::AlgorithmsCache<cudnnConvolutionBwdDataAlgo_t>>
+          p(new framework::AlgorithmsCache<cudnnConvolutionBwdDataAlgo_t>());
+      configs.push_back(p);
+
+      std::shared_ptr<
+          framework::AlgorithmsCache<cudnnConvolutionBwdFilterAlgo_t>>
+          p2(new framework::AlgorithmsCache<cudnnConvolutionBwdFilterAlgo_t>());
+      configs.push_back(p2);
+    }
+  }
+#endif
   return type;
 }
@@ -728,6 +754,25 @@ framework::OpKernelType ConvOpDoubleGrad::GetExpectedKernelType(
   auto type = framework::OpKernelType(
       OperatorWithKernel::IndicateVarDataType(ctx, "Input"), ctx.GetPlace(),
       layout_, library_, customized_type_value);
+#ifdef PADDLE_WITH_CUDA
+  if (library_ == framework::LibraryType::kCUDNN) {
+    std::vector<framework::KernelConfig>& configs = kernel_configs_map_[type];
+    if (configs.empty()) {
+      std::shared_ptr<framework::AlgorithmsCache<cudnnConvolutionFwdAlgo_t>>
+          p0(new framework::AlgorithmsCache<cudnnConvolutionFwdAlgo_t>());
+      configs.push_back(p0);
+
+      std::shared_ptr<
+          framework::AlgorithmsCache<cudnnConvolutionBwdFilterAlgo_t>>
+          p1(new framework::AlgorithmsCache<cudnnConvolutionBwdFilterAlgo_t>());
+      configs.push_back(p1);
+
+      std::shared_ptr<
+          framework::AlgorithmsCache<cudnnConvolutionBwdDataAlgo_t>>
+          p2(new framework::AlgorithmsCache<cudnnConvolutionBwdDataAlgo_t>());
+      configs.push_back(p2);
+    }
+  }
+#endif
   return type;
 }
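Worth noting is the implicit contract these registrations create: the order in which each GetExpectedKernelType pushes caches is exactly the slot index that conv_cudnn_op.cu passes to SearchAlgorithm::Find as algo_cache_id (forward: Fwd at 0; grad: BwdData at 0, BwdFilter at 1; double grad: Fwd at 0, BwdFilter at 1, BwdData at 2). A hypothetical named-constant rendering of those indices follows; the actual code passes the bare literals 0/1/2.

// Hypothetical slot names for the cache indices wired up above; these enums
// do not exist in the tree, where the literals 0/1/2 are passed directly to
// SearchAlgorithm::Find as algo_cache_id.
namespace conv_grad_slots {
enum : int { kBwdDataAlgoCache = 0, kBwdFilterAlgoCache = 1 };
}  // namespace conv_grad_slots

namespace conv_double_grad_slots {
enum : int {
  kFwdAlgoCache = 0,
  kBwdFilterAlgoCache = 1,
  kBwdDataAlgoCache = 2
};
}  // namespace conv_double_grad_slots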
paddle/fluid/operators/elementwise/test_elementwise_mul_op_dim.cc
@@ -58,7 +58,8 @@ void MainTest(const TestData& test_data) {
   RuntimeContext runtime_ctx =
       RuntimeContext(op->Inputs(), op->Outputs(), scope);
-  ExecutionContext ctx = ExecutionContext(*op, scope, *dev_ctx, runtime_ctx);
+  ExecutionContext ctx =
+      ExecutionContext(*op, scope, *dev_ctx, runtime_ctx, nullptr);
 
   bool result = ElementwiseMulOp::AreDimsAndFormatCorrect(
       ctx, 16, MKLDNNMemoryFormat::nChw16c);
   if (test_data.supposed_to_fail)
paddle/fluid/operators/fused/conv_fusion_op.cu
@@ -14,10 +14,10 @@ limitations under the License. */
 #include <array>
 #include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/operators/conv_cudnn_helper.h"
 #include "paddle/fluid/operators/conv_cudnn_op_cache.h"
 #include "paddle/fluid/operators/conv_op.h"
 #include "paddle/fluid/operators/math/padding.h"
+#include "paddle/fluid/platform/cudnn_helper.h"
 
 DECLARE_int64(cudnn_exhaustive_search_times);
@@ -233,7 +233,7 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel<T> {
        return fwd_perf_stat[0].algo;
      };
      AlgorithmsCache<cudnnConvolutionFwdAlgo_t>& algo_cache =
-          *(ConvSearchCache::Instance().GetConvFusion());
+          ctx.GetKernelConfig<AlgorithmsCache<cudnnConvolutionFwdAlgo_t>>(0);
      int search_times = ctx.Attr<int>("search_times");
      search_times = std::max(
          static_cast<int>(FLAGS_cudnn_exhaustive_search_times), search_times);
@@ -245,9 +245,8 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel<T> {
        algo = algo_cache.GetAlgorithm(x_dims[2] * x_dims[3], search_times, 0,
                                       search_func);
      } else {
-        auto dtype = platform::CudnnDataType<T>::type;
        algo = algo_cache.GetAlgorithm(x_dims, f_dims, strides, paddings,
-                                       dilations, 0, dtype, search_func);
+                                       dilations, 0, search_func);
      }
      VLOG(3) << "choose algo " << algo;
    }
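The fused-conv path above relies on the second, area-keyed GetAlgorithm overload restored in operator_kernel_configs.h: it runs the expensive exhaustive search only for the first search_times distinct spatial areas, then reuses the entry cached for the smallest area. A reduced, stand-alone sketch of that policy (the real cache is the member template shown earlier):

#include <climits>
#include <cstdint>
#include <functional>
#include <iostream>
#include <unordered_map>

template <typename TAlgorithm>
class AreaKeyedCache {  // reduced model of AlgorithmsCache's second overload
 public:
  TAlgorithm GetAlgorithm(int64_t area, int search_times,
                          std::function<TAlgorithm()> gen_func) {
    auto it = hash_.find(area);
    if (it != hash_.end()) return it->second;  // this area was searched before
    if (search_times_ < search_times) {        // budget left: really search
      TAlgorithm algo = gen_func();
      hash_[area] = algo;
      ++search_times_;
      return algo;
    }
    // Budget exhausted: fall back to the entry cached for the smallest area.
    TAlgorithm algo{};
    int64_t min = static_cast<int64_t>(INT_MAX);
    for (const auto& m : hash_) {
      if (m.first < min) {
        min = m.first;
        algo = m.second;
      }
    }
    return algo;
  }

 private:
  std::unordered_map<int64_t, TAlgorithm> hash_;
  int search_times_ = 0;
};

int main() {
  AreaKeyedCache<int> cache;
  int searched = 0;
  auto search_func = [&]() { return ++searched; };
  std::cout << cache.GetAlgorithm(64, /*search_times=*/1, search_func) << "\n";  // searches: 1
  std::cout << cache.GetAlgorithm(16, /*search_times=*/1, search_func) << "\n";  // fallback: 1
  return 0;
}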
paddle/fluid/operators/warpctc_op.cc
@@ -61,8 +61,8 @@ class WarpCTCOp : public framework::OperatorWithKernel {
     framework::LibraryType library_{framework::LibraryType::kPlain};
     framework::DataLayout layout_ = framework::DataLayout::kAnyLayout;
     return framework::OpKernelType(
-        OperatorWithKernel::IndicateVarDataType(ctx, "Logits"), ctx.GetPlace(),
-        layout_, library_);
+        OperatorWithKernel::IndicateVarDataType(ctx, "Logits"),
+        ctx.device_context(), layout_, library_);
   }
 };
@@ -174,7 +174,7 @@ class WarpCTCGradOp : public framework::OperatorWithKernel {
       const framework::ExecutionContext& ctx) const override {
     return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
                                        ctx, framework::GradVarName("Loss")),
-                                   ctx.GetPlace());
+                                   ctx.device_context());
   }
 };