Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
40a5f3fd
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
40a5f3fd
编写于
6月 03, 2020
作者:
J
Jacek Czaja
提交者:
GitHub
6月 03, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[oneDNN] Clearing mkldnn cache in naiveexecutor destructor (#24756)
上级
8468dae2
变更
7
显示空白变更内容
内联
并排
Showing
7 changed file
with
378 addition
and
5 deletion
+378
-5
paddle/fluid/framework/executor.cc
paddle/fluid/framework/executor.cc
+1
-2
paddle/fluid/framework/ir/CMakeLists.txt
paddle/fluid/framework/ir/CMakeLists.txt
+5
-0
paddle/fluid/framework/ir/conv_bn_fuse_pass.cc
paddle/fluid/framework/ir/conv_bn_fuse_pass.cc
+43
-3
paddle/fluid/framework/ir/mkldnn/mkldnn_conv_bn_fuse_pass_tester.cc
...id/framework/ir/mkldnn/mkldnn_conv_bn_fuse_pass_tester.cc
+303
-0
paddle/fluid/framework/ir/pass.cc
paddle/fluid/framework/ir/pass.cc
+9
-0
paddle/fluid/framework/naive_executor.cc
paddle/fluid/framework/naive_executor.cc
+15
-0
paddle/fluid/framework/naive_executor.h
paddle/fluid/framework/naive_executor.h
+2
-0
未找到文件。
paddle/fluid/framework/executor.cc
浏览文件 @
40a5f3fd
...
...
@@ -81,8 +81,7 @@ Executor::Executor(const platform::Place& place) : place_(place) {}
Executor
::~
Executor
()
{
#ifdef PADDLE_WITH_MKLDNN
// Clear mkl-dnn cache, unless explicitly
// (as set in constructor) marked not to do so
// Clear mkl-dnn cache,
// this is needed to have mkl-dnn unit tests working
if
(
platform
::
is_cpu_place
(
place_
))
{
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
...
...
paddle/fluid/framework/ir/CMakeLists.txt
浏览文件 @
40a5f3fd
...
...
@@ -146,6 +146,11 @@ if (WITH_MKLDNN)
cc_test
(
test_conv_activation_mkldnn_fuse_pass SRCS mkldnn/conv_activation_mkldnn_fuse_pass_tester.cc DEPS conv_activation_mkldnn_fuse_pass
)
cc_test
(
test_conv_concat_relu_mkldnn_fuse_pass SRCS mkldnn/conv_concat_relu_mkldnn_fuse_pass_tester.cc DEPS conv_concat_relu_mkldnn_fuse_pass
)
cc_test
(
test_conv_elementwise_add_mkldnn_fuse_pass SRCS mkldnn/conv_elementwise_add_mkldnn_fuse_pass_tester.cc DEPS conv_elementwise_add_mkldnn_fuse_pass
)
set
(
TEST_CONV_BN_PASS_DEPS conv_bn_fuse_pass graph_to_program_pass conv_op conv_transpose_op math_function im2col vol2col batch_norm_op gelu_op activation_op elementwise_add_op concat_and_split naive_executor device_context
)
if
(
WITH_GPU
)
set
(
TEST_CONV_BN_PASS_DEPS
${
TEST_CONV_BN_PASS_DEPS
}
depthwise_conv
)
endif
()
cc_test
(
test_conv_batch_norm_mkldnn_fuse_pass SRCS mkldnn/mkldnn_conv_bn_fuse_pass_tester.cc DEPS
${
TEST_CONV_BN_PASS_DEPS
}
)
cc_test
(
test_scale_matmul_fuse_pass SRCS mkldnn/scale_matmul_fuse_pass_tester.cc DEPS scale_matmul_fuse_pass
)
cc_test
(
test_mkldnn_placement_pass SRCS mkldnn/mkldnn_placement_pass_tester.cc DEPS mkldnn_placement_pass
)
cc_test
(
test_mkldnn_inplace_pass SRCS mkldnn/mkldnn_inplace_pass_tester.cc DEPS mkldnn_inplace_pass
)
...
...
paddle/fluid/framework/ir/conv_bn_fuse_pass.cc
浏览文件 @
40a5f3fd
...
...
@@ -13,6 +13,7 @@
// limitations under the License.
#include "paddle/fluid/framework/ir/conv_bn_fuse_pass.h"
#include <algorithm>
#include <functional>
#include <string>
#include <vector>
...
...
@@ -278,9 +279,48 @@ void ConvEltwiseAddBNFusePass::ApplyImpl(ir::Graph* graph) const {
// update weights and biases
float
epsilon
=
BOOST_GET_CONST
(
float
,
batch_norm
->
Op
()
->
GetAttr
(
"epsilon"
));
// if bias is an input to other ops as well then we cannot overwrite it
// so we create separate elementwise Y in nodes
if
(
eltwise_y_in
->
outputs
.
size
()
>
1
)
{
// Make a copy of eltwise Y input tensor
// Create eltwise_y (conv bias) variable
VarDesc
eltwise_y_in_desc
(
patterns
::
PDNodeName
(
name_scope_
,
"eltwise_y_in"
+
std
::
to_string
(
found_conv_bn_count
)));
eltwise_y_in_desc
.
SetShape
(
framework
::
vectorize
(
eltwise_y_in_tensor
->
dims
()));
eltwise_y_in_desc
.
SetDataType
(
eltwise_y_in_tensor
->
type
());
eltwise_y_in_desc
.
SetLoDLevel
(
eltwise_y_in
->
Var
()
->
GetLoDLevel
());
eltwise_y_in_desc
.
SetPersistable
(
true
);
auto
*
eltwise_y_in_node
=
g
->
CreateVarNode
(
&
eltwise_y_in_desc
);
auto
*
eltwise_y_in_tensor_ex
=
scope
->
Var
(
eltwise_y_in_node
->
Name
())
->
GetMutable
<
LoDTensor
>
();
// Initialize eltwise_y
TensorCopy
(
*
eltwise_y_in_tensor
,
platform
::
CPUPlace
(),
eltwise_y_in_tensor_ex
);
recompute_bias_and_weights
(
scope
,
conv_weight
,
*
bn_scale
,
*
bn_bias_tensor
,
*
bn_mean
,
*
bn_variance
,
eltwise_y_in_tensor_ex
,
epsilon
,
conv_type
());
// Set new var
eltwise
->
Op
()
->
RenameInput
(
eltwise_y_in
->
Name
(),
eltwise_y_in_node
->
Name
());
// Link new bias node to eltwise
IR_NODE_LINK_TO
(
eltwise_y_in_node
,
eltwise
);
// unlink original bias from eltwise_op
eltwise_y_in
->
outputs
.
erase
(
std
::
remove_if
(
eltwise_y_in
->
outputs
.
begin
(),
eltwise_y_in
->
outputs
.
end
(),
[
&
](
Node
*&
n
)
{
return
n
->
id
()
==
eltwise
->
id
()
?
true
:
false
;
}),
eltwise_y_in
->
outputs
.
end
());
}
else
{
recompute_bias_and_weights
(
scope
,
conv_weight
,
*
bn_scale
,
*
bn_bias_tensor
,
*
bn_mean
,
*
bn_variance
,
eltwise_y_in_tensor
,
epsilon
,
conv_type
());
}
// Update the elementwise_add node
eltwise
->
Op
()
->
SetAttr
(
"axis"
,
1
);
...
...
paddle/fluid/framework/ir/mkldnn/mkldnn_conv_bn_fuse_pass_tester.cc
0 → 100644
浏览文件 @
40a5f3fd
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <string>
#include <gtest/gtest.h>
#include <boost/logic/tribool.hpp>
#include <random>
#include <unordered_set>
#include "paddle/fluid/framework/ir/graph_traits.h"
#include "paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.h"
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
#include "paddle/fluid/framework/naive_executor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/place.h"
USE_OP
(
batch_norm
);
USE_OP_DEVICE_KERNEL
(
batch_norm
,
MKLDNN
);
USE_OP
(
conv2d_transpose
);
USE_OP_DEVICE_KERNEL
(
conv2d_transpose
,
MKLDNN
);
USE_OP
(
elementwise_add
);
USE_OP_DEVICE_KERNEL
(
elementwise_add
,
MKLDNN
);
USE_OP
(
gelu
);
USE_OP_DEVICE_KERNEL
(
gelu
,
MKLDNN
);
namespace
paddle
{
namespace
framework
{
namespace
ir
{
class
MKLDNNConvBatchNormPassTest
{
private:
void
SetOp
(
ProgramDesc
*
prog
,
const
std
::
string
&
type
,
const
std
::
string
&
name
,
const
std
::
vector
<
std
::
string
>&
inputs
,
const
std
::
vector
<
std
::
string
>&
outputs
,
boost
::
tribool
use_mkldnn
)
{
auto
*
op
=
prog
->
MutableBlock
(
0
)
->
AppendOp
();
op
->
SetType
(
type
);
if
(
!
boost
::
indeterminate
(
use_mkldnn
))
op
->
SetAttr
(
"use_mkldnn"
,
use_mkldnn
);
if
(
type
==
"conv2d_transpose"
)
{
op
->
SetAttr
(
"name"
,
name
);
op
->
SetInput
(
"Input"
,
{
inputs
[
0
]});
op
->
SetInput
(
"Filter"
,
{
inputs
[
1
]});
op
->
SetOutput
(
"Output"
,
{
outputs
[
0
]});
op
->
SetAttr
(
"is_test"
,
true
);
op
->
SetAttr
(
"strides"
,
std
::
vector
<
int
>
(
2
,
2
));
}
else
if
(
std
::
unordered_set
<
std
::
string
>
{
"gelu"
,
"leaky_relu"
,
"relu"
,
"tanh"
}
.
count
(
type
))
{
op
->
SetInput
(
"X"
,
inputs
);
op
->
SetOutput
(
"Out"
,
{
outputs
[
0
]});
}
else
if
(
type
==
"elementwise_add"
)
{
op
->
SetAttr
(
"axis"
,
static_cast
<
int
>
(
1
));
op
->
SetInput
(
"X"
,
{
inputs
[
0
]});
op
->
SetInput
(
"Y"
,
{
inputs
[
1
]});
op
->
SetOutput
(
"Out"
,
{
outputs
[
0
]});
}
else
if
(
type
==
"batch_norm"
)
{
op
->
SetAttr
(
"is_test"
,
true
);
op
->
SetAttr
(
"epsilon"
,
static_cast
<
float
>
(
1e-5
));
op
->
SetInput
(
"X"
,
{
inputs
[
0
]});
op
->
SetInput
(
"Scale"
,
{
inputs
[
1
]});
op
->
SetInput
(
"Bias"
,
{
inputs
[
2
]});
op
->
SetInput
(
"Mean"
,
{
inputs
[
3
]});
op
->
SetInput
(
"Variance"
,
{
inputs
[
4
]});
op
->
SetOutput
(
"Y"
,
{
outputs
[
0
]});
op
->
SetOutput
(
"MeanOut"
,
{
outputs
[
1
]});
op
->
SetOutput
(
"VarianceOut"
,
{
outputs
[
2
]});
op
->
SetOutput
(
"SavedMean"
,
{
outputs
[
3
]});
op
->
SetOutput
(
"SavedVariance"
,
{
outputs
[
4
]});
}
else
{
FAIL
()
<<
"Unexpected operator type."
;
}
}
ProgramDesc
BuildProgramDesc
(
bool
is_elementwise_add
)
{
ProgramDesc
prog
;
// params
for
(
auto
&
v
:
std
::
vector
<
std
::
string
>
(
{
"weights"
,
"weights2"
,
"bias_bn"
,
"scale"
,
"mean"
,
"variance"
,
"saved_mean"
,
"saved_variance"
,
"bias_bn2"
,
"scale2"
,
"mean2"
,
"variance2"
,
"saved_mean2"
,
"saved_variance2"
}))
{
auto
*
var
=
prog
.
MutableBlock
(
0
)
->
Var
(
v
);
var
->
SetType
(
proto
::
VarType
::
LOD_TENSOR
);
var
->
SetPersistable
(
true
);
}
// inputs and non-persistant holders
for
(
auto
&
v
:
std
::
vector
<
std
::
string
>
(
{
"a"
,
"b"
,
"e"
,
"f"
,
"g"
,
"h"
,
"i"
,
"j"
,
"k"
,
"l"
,
"m"
}))
{
auto
*
var
=
prog
.
MutableBlock
(
0
)
->
Var
(
v
);
var
->
SetType
(
proto
::
VarType
::
LOD_TENSOR
);
}
SetOp
(
&
prog
,
"conv2d_transpose"
,
"conv1"
,
std
::
vector
<
std
::
string
>
({
"a"
,
"weights"
}),
std
::
vector
<
std
::
string
>
({
"f"
}),
true
);
if
(
is_elementwise_add
==
true
)
{
SetOp
(
&
prog
,
"conv2d_transpose"
,
"conv2"
,
std
::
vector
<
std
::
string
>
({
"b"
,
"weights2"
}),
std
::
vector
<
std
::
string
>
({
"e"
}),
true
);
SetOp
(
&
prog
,
"elementwise_add"
,
"elementwise_add1"
,
std
::
vector
<
std
::
string
>
({
"f"
,
"g"
}),
std
::
vector
<
std
::
string
>
({
"h"
}),
true
);
SetOp
(
&
prog
,
"elementwise_add"
,
"elementwise_add2"
,
std
::
vector
<
std
::
string
>
({
"e"
,
"g"
}),
std
::
vector
<
std
::
string
>
({
"j"
}),
true
);
SetOp
(
&
prog
,
"batch_norm"
,
"batch_norm1"
,
std
::
vector
<
std
::
string
>
(
{
"h"
,
"scale"
,
"bias_bn"
,
"mean"
,
"variance"
}),
std
::
vector
<
std
::
string
>
(
{
"i"
,
"mean"
,
"variance"
,
"saved_mean"
,
"saved_variance"
}),
true
);
SetOp
(
&
prog
,
"batch_norm"
,
"batch_norm2"
,
std
::
vector
<
std
::
string
>
(
{
"j"
,
"scale2"
,
"bias_bn2"
,
"mean2"
,
"variance2"
}),
std
::
vector
<
std
::
string
>
(
{
"k"
,
"mean2"
,
"variance2"
,
"saved_mean2"
,
"saved_variance2"
}),
true
);
SetOp
(
&
prog
,
"elementwise_add"
,
"elementwise_add3"
,
std
::
vector
<
std
::
string
>
({
"i"
,
"k"
}),
std
::
vector
<
std
::
string
>
({
"l"
}),
true
);
}
else
{
SetOp
(
&
prog
,
"batch_norm"
,
"batch_norm1"
,
std
::
vector
<
std
::
string
>
(
{
"f"
,
"scale"
,
"bias_bn"
,
"mean"
,
"variance"
}),
std
::
vector
<
std
::
string
>
(
{
"l"
,
"mean"
,
"variance"
,
"saved_mean"
,
"saved_variance"
}),
true
);
}
SetOp
(
&
prog
,
"gelu"
,
"gelu1"
,
std
::
vector
<
std
::
string
>
({
"l"
}),
std
::
vector
<
std
::
string
>
({
"m"
}),
true
);
return
prog
;
}
void
FillTensorWithRandomData
(
Tensor
*
tnsr
,
float
lowb
,
float
upb
,
platform
::
CPUPlace
place
)
{
float
*
ptr
=
tnsr
->
mutable_data
<
float
>
(
place
);
// Initialize input data
std
::
uniform_real_distribution
<
float
>
dist
(
static_cast
<
float
>
(
lowb
),
static_cast
<
float
>
(
upb
));
std
::
mt19937
engine
;
for
(
int
i
=
0
;
i
<
tnsr
->
numel
();
++
i
)
{
ptr
[
i
]
=
dist
(
engine
);
}
}
void
CompareTensors
(
Tensor
*
tensor1
,
Tensor
*
tensor2
)
{
// check dims
for
(
int
i
=
0
;
i
<
tensor1
->
numel
();
++
i
)
{
EXPECT_NEAR
(
tensor1
->
data
<
float
>
()[
i
],
tensor2
->
data
<
float
>
()[
i
],
1e-3
);
}
}
public:
void
MainTest
(
bool
is_elementwise_add
)
{
auto
base_prog
=
BuildProgramDesc
(
is_elementwise_add
);
std
::
unique_ptr
<
ir
::
Graph
>
graph
(
new
ir
::
Graph
(
base_prog
));
Scope
scope
;
auto
place
=
paddle
::
platform
::
CPUPlace
();
NaiveExecutor
exe
{
place
};
auto
pass
=
PassRegistry
::
Instance
().
Get
(
is_elementwise_add
?
"conv_transpose_eltwiseadd_bn_fuse_pass"
:
"conv_transpose_bn_fuse_pass"
);
graph
->
SetNotOwned
(
kParamScopeAttr
,
&
scope
);
auto
&
prog
=
graph
->
OriginProgram
();
exe
.
CreateVariables
(
prog
,
0
,
true
,
&
scope
);
exe
.
CreateVariables
(
prog
,
0
,
false
,
&
scope
);
exe
.
Prepare
(
&
scope
,
prog
,
0
,
false
);
std
::
cout
<<
GenScopeTreeDebugInfo
(
&
scope
);
auto
*
a_tensor
=
exe
.
FindTensor
(
"a"
);
auto
*
b_tensor
=
exe
.
FindTensor
(
"b"
);
auto
*
weights_tensor
=
exe
.
FindTensor
(
"weights"
);
auto
*
weights2_tensor
=
exe
.
FindTensor
(
"weights2"
);
auto
*
g_tensor
=
exe
.
FindTensor
(
"g"
);
// Batch Norm
auto
*
bias_bn_tensor
=
exe
.
FindTensor
(
"bias_bn"
);
// shift
auto
*
scale_tensor
=
exe
.
FindTensor
(
"scale"
);
auto
*
mean_tensor
=
exe
.
FindTensor
(
"mean"
);
auto
*
variance_tensor
=
exe
.
FindTensor
(
"variance"
);
auto
*
bias_bn2_tensor
=
exe
.
FindTensor
(
"bias_bn2"
);
// shift
auto
*
scale2_tensor
=
exe
.
FindTensor
(
"scale2"
);
auto
*
mean2_tensor
=
exe
.
FindTensor
(
"mean2"
);
auto
*
variance2_tensor
=
exe
.
FindTensor
(
"variance2"
);
int
ic
,
oc
,
iw
,
ih
,
n
,
fw
,
fh
;
n
=
1
;
fw
=
fh
=
2
;
oc
=
ic
=
24
;
iw
=
ih
=
160
;
// mb1_ic24oc24_ih8oh16kh2sh2dh0ph0_iw80ow160kw2sw2dw0pw0 deconv
a_tensor
->
Resize
({
n
,
ic
,
ih
,
iw
});
weights_tensor
->
Resize
({
oc
,
ic
,
fh
,
fw
});
g_tensor
->
Resize
({
oc
});
bias_bn_tensor
->
Resize
({
oc
});
scale_tensor
->
Resize
({
oc
});
mean_tensor
->
Resize
({
oc
});
variance_tensor
->
Resize
({
oc
});
if
(
is_elementwise_add
)
{
b_tensor
->
Resize
({
n
,
ic
,
ih
,
iw
});
weights2_tensor
->
Resize
({
oc
,
ic
,
fh
,
fw
});
bias_bn2_tensor
->
Resize
({
oc
});
scale2_tensor
->
Resize
({
oc
});
mean2_tensor
->
Resize
({
oc
});
variance2_tensor
->
Resize
({
oc
});
}
// Input and conv transpose
FillTensorWithRandomData
(
a_tensor
,
1.0
f
,
2.0
f
,
place
);
FillTensorWithRandomData
(
g_tensor
,
1.0
f
,
2.0
f
,
place
);
FillTensorWithRandomData
(
weights_tensor
,
1.0
f
,
2.0
f
,
place
);
if
(
is_elementwise_add
)
{
FillTensorWithRandomData
(
b_tensor
,
1.0
f
,
2.0
f
,
place
);
FillTensorWithRandomData
(
weights2_tensor
,
1.0
f
,
2.0
f
,
place
);
}
// First Batch_Norm
FillTensorWithRandomData
(
bias_bn_tensor
,
1.0
f
,
2.0
f
,
place
);
FillTensorWithRandomData
(
scale_tensor
,
1.0
f
,
2.0
f
,
place
);
FillTensorWithRandomData
(
mean_tensor
,
1.0
f
,
2.0
f
,
place
);
FillTensorWithRandomData
(
variance_tensor
,
1.0
f
,
2.0
f
,
place
);
// Second Batch Norm (exists only when elementwise_add is present)
if
(
is_elementwise_add
)
{
FillTensorWithRandomData
(
bias_bn2_tensor
,
1.0
f
,
2.0
f
,
place
);
FillTensorWithRandomData
(
scale2_tensor
,
1.0
f
,
2.0
f
,
place
);
FillTensorWithRandomData
(
mean2_tensor
,
1.0
f
,
2.0
f
,
place
);
FillTensorWithRandomData
(
variance2_tensor
,
1.0
f
,
2.0
f
,
place
);
}
exe
.
Run
();
// Get result without IR passes applied
// Need to copy result over as the same scope is used in both executors
// so first result will be overwritten by second
auto
*
m_tensor
=
exe
.
FindTensor
(
"m"
);
Tensor
no_ir_result
;
TensorCopy
(
*
m_tensor
,
place
,
&
no_ir_result
);
graph
.
reset
(
pass
->
Apply
(
graph
.
release
()));
// Get Program from graph
ProgramDesc
optimized_prog
;
auto
graph2program_pass
=
paddle
::
framework
::
ir
::
PassRegistry
::
Instance
().
Get
(
"graph_to_program_pass"
);
graph2program_pass
->
SetNotOwned
<
paddle
::
framework
::
ProgramDesc
>
(
"program"
,
&
optimized_prog
);
graph2program_pass
->
Apply
(
graph
.
release
());
exe
.
Prepare
(
&
scope
,
optimized_prog
,
0
,
false
);
exe
.
Run
();
auto
*
ir_result
=
exe
.
FindTensor
(
"m"
);
// Two graphs. Execute both and compare results
CompareTensors
(
&
no_ir_result
,
ir_result
);
VLOG
(
3
)
<<
DebugString
(
graph
);
}
};
TEST
(
MKLDNNConvBatchNormPassTest
,
conv_batch_norm
)
{
MKLDNNConvBatchNormPassTest
().
MainTest
(
false
);
}
TEST
(
MKLDNNConvBatchNormPassTest
,
conv_elementwise_add_batch_norm
)
{
MKLDNNConvBatchNormPassTest
().
MainTest
(
true
);
}
}
// namespace ir
}
// namespace framework
}
// namespace paddle
USE_PASS
(
conv_transpose_bn_fuse_pass
);
USE_PASS
(
conv_transpose_eltwiseadd_bn_fuse_pass
);
USE_PASS
(
graph_to_program_pass
);
paddle/fluid/framework/ir/pass.cc
浏览文件 @
40a5f3fd
...
...
@@ -18,6 +18,7 @@ limitations under the License. */
#include <utility>
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/platform/device_context.h"
namespace
paddle
{
namespace
framework
{
...
...
@@ -49,6 +50,14 @@ Graph* Pass::Apply(Graph* graph) const {
graph
->
Set
<
PassRecorder
>
(
kPassRecorder
,
new
PassRecorder
);
}
graph
->
Get
<
PassRecorder
>
(
kPassRecorder
).
insert
(
Type
());
#ifdef PADDLE_WITH_MKLDNN
// Clear mkl-dnn cache,
// Passes can change params, tensors, so caching need to be discarded
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
platform
::
MKLDNNDeviceContext
*
dev_ctx
=
(
platform
::
MKLDNNDeviceContext
*
)
pool
.
Get
(
paddle
::
platform
::
CPUPlace
());
dev_ctx
->
ResetBlobMap
();
#endif
return
graph
;
}
...
...
paddle/fluid/framework/naive_executor.cc
浏览文件 @
40a5f3fd
...
...
@@ -118,5 +118,20 @@ void NaiveExecutor::CleanFeedFetchOps() {
ops_
.
swap
(
ops
);
}
NaiveExecutor
::~
NaiveExecutor
()
{
#ifdef PADDLE_WITH_MKLDNN
// Clear mkl-dnn cache,
// this is needed to have mkl-dnn unit tests working
if
(
platform
::
is_cpu_place
(
place_
))
{
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
platform
::
MKLDNNDeviceContext
*
dev_ctx
=
(
platform
::
MKLDNNDeviceContext
*
)
pool
.
Get
(
place_
);
dev_ctx
->
ResetBlobMap
();
platform
::
MKLDNNDeviceContext
::
tls
().
set_cur_paddle_data_layout
(
paddle
::
framework
::
DataLayout
::
kNCHW
);
}
#endif
}
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/naive_executor.h
浏览文件 @
40a5f3fd
...
...
@@ -32,6 +32,8 @@ class NaiveExecutor {
public:
explicit
NaiveExecutor
(
const
platform
::
Place
&
place
)
:
place_
(
place
)
{}
~
NaiveExecutor
();
// Create child scope.
// Create variables.
// @with_feed_fetch_ops: whether to work with the feed and fetch operators.
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录