Commit 2739096e
Authored Jan 27, 2019 by dzhwinter
Parent: 8f3b2523

compatible with python side mem_opt
Showing 11 changed files with 633 additions and 158 deletions (+633 / -158)
paddle/fluid/framework/details/CMakeLists.txt                        +4    -2
paddle/fluid/framework/details/build_strategy.cc                     +29   -0
paddle/fluid/framework/details/graph_print_pass.cc                   +125  -0
paddle/fluid/framework/details/graph_print_pass.h                    +66   -0
paddle/fluid/framework/details/graph_print_pass_test.cc              +79   -0
paddle/fluid/framework/details/graph_test_base.h                     +80   -0
paddle/fluid/framework/details/inplace_op_pass.cc                    +121  -37
paddle/fluid/framework/details/memory_optimize_pass_test.cc          +1    -54
paddle/fluid/framework/details/multi_devices_graph_print_pass.h      +1    -9
python/paddle/fluid/tests/unittests/parallel_executor_test_base.py   +58   -56
python/paddle/fluid/tests/unittests/test_ir_inplace_pass.py          +69   -0
paddle/fluid/framework/details/CMakeLists.txt
@@ -51,7 +51,8 @@ cc_library(gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope d
cc_library(fuse_vars_op_handle SRCS fuse_vars_op_handle.cc DEPS op_handle_base scope)
cc_library(memory_optimize_pass SRCS memory_optimize_pass.cc memory_optimize_helper.cc DEPS graph graph_helper pass)
cc_library(inplace_op_pass SRCS inplace_op_pass DEPS memory_optimize_pass op_info)
cc_library(graph_print_pass SRCS graph_print_pass.cc DEPS graph_helper pass)
cc_library(inplace_op_pass SRCS inplace_op_pass.cc DEPS memory_optimize_pass op_info graph_print_pass)
cc_library(modify_op_lock_and_record_event_pass SRCS modify_op_lock_and_record_event_pass.cc DEPS computation_op_handle op_graph_view multi_devices_helper)
cc_library(memory_early_delete_pass SRCS memory_early_delete_pass.cc DEPS memory_optimize_pass computation_op_handle scale_loss_grad_op_handle rpc_op_handle
        all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle graph graph_helper pass)
@@ -72,6 +73,7 @@ if (WITH_GPU)
endif()
cc_test(memory_optimize_helper_test SRCS memory_optimize_helper_test.cc memory_optimize_helper.cc DEPS framework_proto graph)
cc_test(memory_optimize_pass_test SRCS memory_optimize_pass_test.cc memory_optimize_pass.cc memory_optimize_helper.cc DEPS framework_proto graph graph_helper op_registry pass)
cc_test(graph_print_pass_test SRCS graph_print_pass_test.cc DEPS graph_print_pass framework_proto graph graph_helper op_registry pass)
cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS ${SSA_GRAPH_EXECUTOR_DEPS})
@@ -96,4 +98,4 @@ cc_library(build_strategy SRCS build_strategy.cc DEPS
        multi_devices_graph_print_pass multi_devices_graph_check_pass
        fuse_elewise_add_act_pass multi_batch_merge_pass
        fuse_relu_depthwise_conv_pass
        memory_optimize_pass lock_free_optimize_pass)
        memory_optimize_pass lock_free_optimize_pass graph_print_pass)
paddle/fluid/framework/details/build_strategy.cc
@@ -17,6 +17,7 @@ limitations under the License. */
#include <glog/logging.h>
#include <memory>
#include "paddle/fluid/framework/details/graph_print_pass.h"
#include "paddle/fluid/framework/details/memory_optimize_helper.h"
#include "paddle/fluid/framework/details/multi_devices_graph_pass.h"
#include "paddle/fluid/framework/details/multi_devices_graph_print_pass.h"
@@ -43,8 +44,25 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
  explicit ParallelExecutorPassBuilder(const BuildStrategy &strategy)
      : ir::PassBuilder(), strategy_(strategy) {
    if (strategy_.enable_inplace_) {
      // before inplaced
      // if (!strategy_.debug_graphviz_path_.empty()) {
      //   const std::string path = strategy_.debug_graphviz_path_ +
      //   "before_inplaced";
      //   auto pass = AppendPass("graph_print_pass");
      //   pass->Set<std::string>(kGraphvizPath, new std::string(path));
      // }
      AppendPass("inplace_pass");
      // after inplaced
      // if (!strategy_.debug_graphviz_path_.empty()) {
      //   const std::string path = strategy_.debug_graphviz_path_ +
      //   "after_inplaced";
      //   auto pass = AppendPass("graph_print_pass");
      //   pass->Set<std::string>(details::kGraphvizPath, new
      //   std::string(path));
      // }
    }
    if (strategy_.enable_sequential_execution_) {
      AppendPass("sequential_execution_pass");
    }
@@ -189,6 +207,9 @@ std::unique_ptr<ir::Graph> BuildStrategy::Apply(
      pass->SetNotOwned<platform::NCCLContextMap>("nccl_ctxs", nctx);
#endif
    } else if (pass->Type() == "memory_optimize_pass") {
      if (graph->Has(kAllOpDescs)) {
        graph->Erase(kAllOpDescs);
      }
      const std::vector<OpDesc *> *all_op_descs =
          new std::vector<OpDesc *>(main_program.Block(0).AllOps());
      graph->Set<const std::vector<OpDesc *>>(kAllOpDescs,
@@ -219,6 +240,9 @@ std::unique_ptr<ir::Graph> BuildStrategy::Apply(
      if (graph->Has(kAllOpDescs)) {
        graph->Erase(kAllOpDescs);
      }
      if (!graph->Has(kGraphviz)) {
        graph->Set<GraphvizNodes>(kGraphviz, new GraphvizNodes);
      }
      graph->Set<const std::vector<OpDesc *>>(
          kAllOpDescs, new std::vector<OpDesc *>(main_program.Block(0).AllOps()));
@@ -228,6 +252,10 @@ std::unique_ptr<ir::Graph> BuildStrategy::Apply(
          "GPU, skipped.";
          continue;
        }
      } else if (pass->Type() == "graph_print_path") {
        if (!graph->Has(kGraphviz)) {
          graph->Set<GraphvizNodes>(kGraphviz, new GraphvizNodes);
        }
      }
      graph = pass->Apply(std::move(graph));
    }
@@ -253,3 +281,4 @@ USE_PASS(all_reduce_deps_pass);
USE_PASS(modify_op_lock_and_record_event_pass);
USE_PASS(inplace_pass);
USE_PASS(lock_free_optimize_pass);
USE_PASS(graph_print_pass);
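Note: the graph dumps around inplace_pass are left commented out in this commit. As a minimal sketch (not part of the change itself), enabling them inside ParallelExecutorPassBuilder would look roughly like the following; kGraphvizPath comes from graph_print_pass.h, and the pass takes ownership of the string attribute, mirroring the commented code above.

```cpp
// Illustrative only: the commented-out hookup above, written out. Assumes the
// ParallelExecutorPassBuilder context shown in this diff.
if (strategy_.enable_inplace_) {
  if (!strategy_.debug_graphviz_path_.empty()) {
    // dump the SSA graph to <debug_graphviz_path>before_inplaced
    const std::string path = strategy_.debug_graphviz_path_ + "before_inplaced";
    auto pass = AppendPass("graph_print_pass");
    // the pass owns the string attribute; see SSAGraphPrintPass::ApplyImpl
    pass->Set<std::string>(details::kGraphvizPath, new std::string(path));
  }
  AppendPass("inplace_pass");
}
```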
paddle/fluid/framework/details/graph_print_pass.cc
new file mode 100644

// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/details/graph_print_pass.h"
#include <string>
#include <vector>

namespace paddle {
namespace framework {
namespace details {

class GraphvizVar : public GraphvizNode {
 public:
  GraphvizVar(ir::Node* n, const int& i) : GraphvizNode(n, i) {}

  friend std::ostream& operator<<(std::ostream& sout, const GraphvizVar& var) {
    sout << "var_" << var.id_ << " [label=\"" << var.node_->Name() << "\"]"
         << std::endl;
    return sout;
  }
};

class GraphvizOp : public GraphvizNode {
 public:
  GraphvizOp(ir::Node* n, const int& i) : GraphvizNode(n, i) {}

  friend std::ostream& operator<<(std::ostream& sout, const GraphvizOp& op) {
    sout << "op_" + std::to_string(op.id_) << " [label=\"" << op.node_->Name()
         << "\", shape=rect]" << std::endl;
    PADDLE_ENFORCE(op.stream_.rdbuf()->in_avail() != 0,
                   "No inputs outputs. Please call AddEdge first!");
    sout << op.stream_.str();
    return sout;
  }

  template <typename Callback>
  void AddEdge(const Callback& cb) {
    std::string op_name = "op_" + std::to_string(id_);
    for (auto var : node_->inputs) {
      std::string var_name = "var_" + std::to_string(cb(var));
      stream_ << var_name << "->" << op_name << std::endl;
    }
    for (auto var : node_->outputs) {
      std::string var_name = "var_" + std::to_string(cb(var));
      stream_ << op_name << "->" << var_name << std::endl;
    }
  }

 private:
  std::ostringstream stream_;
};

template <typename T, typename Container>
std::vector<T*> FilterByNodeWrapper(const Container& con) {
  std::vector<T*> ret;
  for (auto& node : con) {
    auto i = dynamic_cast<T*>(node.get());
    if (i != nullptr) ret.emplace_back(i);
  }
  return ret;
}

std::unordered_map<ir::Node*, int> SSAGraphPrinterImpl::ToGraphvizNode(
    const ir::Graph& graph) const {
  // Convert to GraphvizNode format
  auto& graphviz_nodes = graph.Get<GraphvizNodes>(kGraphviz);
  graphviz_nodes.clear();
  std::unordered_map<ir::Node*, int> vars;
  int var_id = 0;
  int op_id = 0;
  for (auto& node : graph.Nodes()) {
    if (node->IsVar()) {
      graphviz_nodes.emplace(new GraphvizVar(node, var_id));
      vars.emplace(std::make_pair(node, var_id++));
    } else if (node->IsOp()) {
      graphviz_nodes.emplace(new GraphvizOp(node, op_id++));
    } else {
      PADDLE_THROW("Unknown op type");
    }
  }
  return vars;
}

void SSAGraphPrinterImpl::Print(const ir::Graph& graph,
                                std::ostream& sout) const {
  auto vars = ToGraphvizNode(graph);
  auto& nodes = graph.Get<GraphvizNodes>(kGraphviz);
  sout << "digraph G {\n";
  for (auto& var : FilterByNodeWrapper<GraphvizVar>(nodes)) {
    sout << *var;
  }
  for (auto& op : FilterByNodeWrapper<GraphvizOp>(nodes)) {
    op->AddEdge([&vars](ir::Node* var) { return vars.at(var); });
    sout << *op;
  }
  sout << "}\n";
}

std::unique_ptr<ir::Graph> SSAGraphPrintPass::ApplyImpl(
    std::unique_ptr<ir::Graph> graph) const {
  printer_.reset(new SSAGraphPrinterImpl());
  std::unique_ptr<std::ostream> fout(
      new std::ofstream(Get<std::string>(kGraphvizPath)));
  PADDLE_ENFORCE(fout->good() == true, "Failed to open file.");
  printer_->Print(*graph, *fout);
  return graph;
}

}  // namespace details
}  // namespace framework
}  // namespace paddle

REGISTER_PASS(graph_print_pass, paddle::framework::details::SSAGraphPrintPass)
    .RequirePassAttr(paddle::framework::details::kGraphvizPath);
paddle/fluid/framework/details/graph_print_pass.h
new file mode 100644

// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include <fstream>
#include <memory>
#include <unordered_map>
#include "paddle/fluid/framework/details/multi_devices_helper.h"

namespace paddle {
namespace framework {
namespace details {

constexpr char kGraphvizPath[] = "debug_graphviz_path";
constexpr char kGraphviz[] = "graphviz";

class GraphvizNode {
 public:
  GraphvizNode(ir::Node* n, const int& i) : node_(n), id_(i) {}
  virtual ~GraphvizNode() = default;

 protected:
  ir::Node* node_;
  int id_;
};
class GraphvizNode;
typedef std::unordered_set<std::unique_ptr<GraphvizNode>> GraphvizNodes;

class SSAGraphPrinter {
 public:
  virtual ~SSAGraphPrinter() {}
  virtual void Print(const ir::Graph& graph, std::ostream& sout) const = 0;
};

class SSAGraphPrinterImpl : public SSAGraphPrinter {
 public:
  void Print(const ir::Graph& graph, std::ostream& sout) const override;

 private:
  std::unordered_map<ir::Node*, int> ToGraphvizNode(const ir::Graph& graph) const;
};

class SSAGraphPrintPass : public ir::Pass {
 protected:
  std::unique_ptr<ir::Graph> ApplyImpl(
      std::unique_ptr<ir::Graph> graph) const override;

 private:
  mutable std::unique_ptr<SSAGraphPrinter> printer_;
};

}  // namespace details
}  // namespace framework
}  // namespace paddle
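For reference, a minimal sketch of driving the new pass by hand, assuming the usual ir::PassRegistry interface and that graph_print_pass has been linked in (USE_PASS in build_strategy.cc). The kGraphviz attribute must be set on the graph first, because SSAGraphPrinterImpl::Print reads GraphvizNodes out of it, and kGraphvizPath is enforced by RequirePassAttr. This is not part of the commit.

```cpp
// Sketch only; the PassRegistry calls are the standard framework API, not
// something introduced by this diff.
#include <memory>
#include <string>
#include <utility>
#include "paddle/fluid/framework/details/graph_print_pass.h"
#include "paddle/fluid/framework/ir/pass.h"

namespace fw = paddle::framework;

std::unique_ptr<fw::ir::Graph> DumpSSAGraph(
    std::unique_ptr<fw::ir::Graph> graph, const std::string& dot_path) {
  // SSAGraphPrinterImpl::Print reads the GraphvizNodes attribute, so make
  // sure it exists before the pass runs (BuildStrategy::Apply does the same).
  if (!graph->Has(fw::details::kGraphviz)) {
    graph->Set<fw::details::GraphvizNodes>(fw::details::kGraphviz,
                                           new fw::details::GraphvizNodes);
  }
  auto pass = fw::ir::PassRegistry::Instance().Get("graph_print_pass");
  // required attribute, see REGISTER_PASS(...).RequirePassAttr(kGraphvizPath)
  pass->Set<std::string>(fw::details::kGraphvizPath, new std::string(dot_path));
  return pass->Apply(std::move(graph));
}
```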
paddle/fluid/framework/details/graph_print_pass_test.cc
new file mode 100644

// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/details/graph_print_pass.h"
#include "paddle/fluid/framework/details/graph_test_base.h"

REGISTER_OPERATOR(sum, paddle::framework::DummyOp,
                  paddle::framework::SumOpMaker);
REGISTER_OPERATOR(split, paddle::framework::DummyOp,
                  paddle::framework::SplitOpMaker);

/*
  a @ b
   c
  d @ e
*/

using paddle::framework::ProgramDesc;
using paddle::framework::proto::VarType;

inline static ProgramDesc FillProgramDesc() {
  ProgramDesc prog;
  prog.MutableBlock(0)->Var("a")->SetType(VarType::LOD_TENSOR);
  prog.MutableBlock(0)->Var("b")->SetType(VarType::LOD_TENSOR);
  prog.MutableBlock(0)->Var("c")->SetType(VarType::LOD_TENSOR);
  prog.MutableBlock(0)->Var("d")->SetType(VarType::LOD_TENSOR);
  prog.MutableBlock(0)->Var("e")->SetType(VarType::LOD_TENSOR);
  {
    auto* op = prog.MutableBlock(0)->AppendOp();
    op->SetType("sum");
    op->SetInput("X", {"a", "b"});
    op->SetOutput("Out", {"c"});
  }
  {
    auto* op = prog.MutableBlock(0)->AppendOp();
    op->SetType("split");
    op->SetInput("X", {"c"});
    op->SetOutput("Out", {"d", "e"});
  }
  {
    auto* op = prog.MutableBlock(0)->AppendOp();
    op->SetType("sum");
    op->SetInput("X", {"d", "e"});
    op->SetOutput("Out", {"d"});
  }
  return prog;
}

namespace paddle {
namespace framework {
namespace details {

TEST(SSAGraphPrinter, Normal) {
  auto program = FillProgramDesc();
  std::unique_ptr<ir::Graph> graph(new ir::Graph(program));
  graph->Set<GraphvizNodes>(kGraphviz, new GraphvizNodes);
  std::unique_ptr<SSAGraphPrinter> printer(new SSAGraphPrinterImpl);

  // redirect debug graph to a file.
  constexpr char graph_path[] = "graph_print_pass.txt";
  std::unique_ptr<std::ostream> fout(new std::ofstream(graph_path));
  PADDLE_ENFORCE(fout->good());
  printer->Print(*graph, *fout);
}

}  // namespace details
}  // namespace framework
}  // namespace paddle
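A possible companion check, not in this commit: print into an in-memory stream and assert on the markers SSAGraphPrinterImpl always emits, instead of only writing a file. It assumes the same includes and namespace as the test above, plus <sstream>.

```cpp
// Sketch only: same setup as TEST(SSAGraphPrinter, Normal), but printing into
// an in-memory stream and asserting on markers emitted by SSAGraphPrinterImpl.
#include <sstream>

TEST(SSAGraphPrinter, ContainsGraphvizMarkers) {
  auto program = FillProgramDesc();
  std::unique_ptr<ir::Graph> graph(new ir::Graph(program));
  graph->Set<GraphvizNodes>(kGraphviz, new GraphvizNodes);

  SSAGraphPrinterImpl printer;
  std::ostringstream sout;
  printer.Print(*graph, sout);

  const std::string dot = sout.str();
  // every dump starts with "digraph G {" and renders operators as rectangles
  EXPECT_NE(dot.find("digraph G {"), std::string::npos);
  EXPECT_NE(dot.find("shape=rect"), std::string::npos);
}
```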
paddle/fluid/framework/details/graph_test_base.h
new file mode 100644

// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include <algorithm>
#include <iostream>
#include <iterator>
#include <string>
#include "glog/logging.h"
#include "gtest/gtest.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/program_desc.h"

namespace paddle {
namespace framework {

class DummyOp : public OperatorBase {
 public:
  DummyOp(const std::string& type, const VariableNameMap& inputs,
          const VariableNameMap& outputs, const AttributeMap& attrs)
      : OperatorBase(type, inputs, outputs, attrs) {}

 private:
  void RunImpl(const Scope& scope,
               const platform::Place& place) const override {}
};

class SumOpMaker : public OpProtoAndCheckerMaker {
 public:
  void Make() {
    AddInput("X", "").AsDuplicable();
    AddOutput("Out", "");
    AddComment("");
  }
};

class AssignOpMaker : public OpProtoAndCheckerMaker {
 public:
  void Make() {
    AddInput("X", "").AsDuplicable();
    AddOutput("Out", "");
    AddComment("");
  }
};

class SplitOpMaker : public OpProtoAndCheckerMaker {
 public:
  void Make() {
    AddInput("X", "");
    AddOutput("Out", "").AsDuplicable();
    AddComment("");
  }
};

class DummyVarTypeInference : public VarTypeInference {
 public:
  void operator()(const OpDesc& op_desc, BlockDesc* block) const override {
    auto& inputs = op_desc.Input("X");
    auto type = block->Var(inputs.front())->GetType();
    auto out_var_name = op_desc.Output("Out").front();
    block->Var(out_var_name)->SetType(type);
  }
};

}  // namespace framework
}  // namespace paddle
paddle/fluid/framework/details/inplace_op_pass.cc
@@ -21,6 +21,7 @@
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/details/graph_print_pass.h"
#include "paddle/fluid/framework/details/memory_optimize_pass.h"
#include "paddle/fluid/framework/op_info.h"
@@ -76,42 +77,92 @@ namespace paddle {
namespace framework {
namespace details {

static inline ir::Node* GetNextInplacedOpOutput(ir::Node* var) {
static inline std::string NodeDebugString(ir::Node* var) {
  std::ostringstream os;
  if (var->IsCtrlVar()) {
    os << "kControlDepVarName" << " ";
  } else if (var->IsOp()) {
    os << "kOperation" << " " << var->Name();
    PADDLE_ENFORCE(var->Op() != nullptr && var->Op()->Type() == var->Name());
  } else if (var->IsVar()) {
    os << "kVariable" << " " << var->Name();
    PADDLE_ENFORCE(var->Var() != nullptr && var->Var()->Name() == var->Name());
  } else {
    PADDLE_THROW("Unknown node type.");
  }
  return os.str();
}

static inline std::string OpDebugString(ir::Node* var) {
  ir::Node* op = var;
  if (var->IsVar()) op = var->inputs.at(0);
  std::stringstream os;
  os << op->Name() << " : ";
  os << "Input ";
  VLOG(3) << op->Name();
  for (auto* var : op->inputs) {
    if (var->IsVar() && !var->IsCtrlVar()) {
      PADDLE_ENFORCE(var->Var() != nullptr && var->Var()->Name() == var->Name(),
                     "unmatched desc and var");
      // os << var << ":" << var->Name() << " ";
      os << var->Name() << " ";
    }
  }
  os << "Output ";
  VLOG(3) << op->Name();
  for (auto* var : op->outputs) {
    VLOG(3) << var;
    VLOG(3) << var->Name();
    if (!var->IsVar()) {
      VLOG(3) << "error";
    }
    // VLOG(3) << var->Var()->Name();
    if (var->IsVar() && !var->IsCtrlVar()) {
      PADDLE_ENFORCE(var->Var() != nullptr && var->Var()->Name() == var->Name(),
                     "unmatched desc and var");
      // os << var << ":" << var->Name() << " ";
      os << var->Name() << " ";
    }
    if (var->Name() == "fc_10.tmp_0") {
      VLOG(3) << NodeDebugString(var);
    }
  }
  return os.str();
}

static inline ir::Node* GetNextCascadeInplacedVar(ir::Node* var) {
  // if next op is inplaced, then return the output var
  // otherwise return nullptr
  PADDLE_ENFORCE(var && var->IsVar() && !var->IsCtrlVar());
  ir::Node* inplaced_var = nullptr;
  // only has one output op can be inplaced
  if (var->outputs.size() == 1 && var->outputs[0]->IsOp()) {
    auto* op = var->outputs[0];
    for (auto* out_var : op->outputs) {
      if (!out_var->IsVar() || out_var->IsCtrlVar() || out_var->Var() == nullptr)
        continue;
      if (out_var->Name() == var->Name()) {
        inplaced_var = out_var;
        break;
  for (auto* next_op : var->outputs) {
    for (auto* output : next_op->outputs) {
      if (output->IsVar() && !output->IsCtrlVar() &&
          output->Name() == var->Name()) {
        inplaced_var = output;
      }
    }
  }
  return inplaced_var;
}

static inline ir::Node* GetPrevInplacedOpInput(ir::Node* var) {
static inline ir::Node* GetPrevCascadeInplacedVar(ir::Node* var) {
  PADDLE_ENFORCE(var && var->IsVar() && !var->IsCtrlVar());
  ir::Node* inplaced_var = nullptr;
  if (var->inputs.size() == 1 && var->inputs[0]->IsOp()) {
    auto* op = var->inputs[0];
    for (auto* in_var : op->inputs) {
      if (!in_var->IsVar() || in_var->IsCtrlVar() || in_var->Var() == nullptr)
        continue;
      if (in_var->Name() == var->Name()) {
        inplaced_var = in_var;
        break;
      }
    }
  }
  return inplaced_var;
  auto* prev_op = var->inputs.at(0);
  auto input_it = std::find_if(prev_op->inputs.begin(), prev_op->inputs.end(),
                               [&](ir::Node* node) {
                                 if (node->IsVar() && !node->IsCtrlVar() &&
                                     node->Name() == var->Name()) {
                                   return true;
                                 } else {
                                   return false;
                                 }
                               });
  return input_it == prev_op->inputs.end() ? nullptr : *input_it;
}

template <typename Container>
@@ -166,12 +217,22 @@ std::unique_ptr<ir::Graph> InplacePass::ApplyImpl(
  view_.Build(graph.get());
  InitSSAGraphNodes();

  std::unique_ptr<SSAGraphPrinter> printer(new SSAGraphPrinterImpl);
  for (auto* op : view_.AllOps()) {
    if (FLAGS_enable_inplace_whitelist && !whitelist_.count(op->Name()))
      continue;
    TryInplaceOpInputOutput(op, graph.get());
  }
  graph->ResolveHazard(var_nodes_);

  constexpr char graph_path[] = "ir_graph_inplaced.txt";
  std::unique_ptr<std::ostream> fout(new std::ofstream(graph_path));
  PADDLE_ENFORCE(fout->good());
  printer->Print(*graph, *fout);
  // for(auto* op : view_.AllOps()) {
  //   VLOG(3) << OpDebugString(op);
  // }

  return graph;
}
@@ -179,7 +240,7 @@ void InplacePass::InplaceModifyDesc(const std::string& var,
                                    const std::string& cache_var,
                                    const size_t& idx) const {
  for (size_t i = idx; i < view_.AllOps().size(); ++i) {
    auto* op = view_.AllOps()[i];
    ir::Node* op = view_.AllOps()[i];
    PADDLE_ENFORCE(op->IsOp() && op->Op());
    auto* op_desc = op->Op();
    op_desc->RenameInput(var, cache_var);
@@ -203,14 +264,28 @@ void InplacePass::InplaceModifyVar(const std::string& var,
  // redirect the input to the latest version of cache_var
  for (auto* node : op->inputs) {
    if (node->Name() == var) {
      ir::Node* cache_node = var_nodes_[cache_var].back();
      ir::Node* cache_node = graph->CreateVarNode(var_desc.get());
      var_nodes_[cache_var].emplace_back(cache_node);

      // swap node to cache_node
      cache_node->outputs.insert(cache_node->outputs.end(),
                                 node->outputs.begin(), node->outputs.end());
      PADDLE_ENFORCE(node->inputs.size() == 1 && node->inputs[0]->IsOp());
      auto* prev_op = node->inputs[0];
      std::replace(prev_op->outputs.begin(), prev_op->outputs.end(), node,
                   cache_node);
      cache_node->inputs.emplace_back(prev_op);
      for (auto* next_op : node->outputs) {
        std::replace(next_op->inputs.begin(), next_op->inputs.end(), node,
                     cache_node);
      }

      // release unused var in graph. Because python side memory optimize
      // may reused the var in same name, so we only clear the var node
      // after current inplaced index.
      graph->RemoveNode(node);
      auto& nodes = var_nodes_.at(var);
      nodes.erase(std::remove(nodes.begin(), nodes.end(), node), nodes.end());
    }
  }
@@ -220,7 +295,6 @@ void InplacePass::InplaceModifyVar(const std::string& var,
    if (node->Name() == var) {
      ir::Node* cache_node = graph->CreateVarNode(var_desc.get());
      var_nodes_[cache_var].emplace_back(cache_node);

      // swap node to cache node
      cache_node->outputs.insert(cache_node->outputs.end(),
                                 node->outputs.begin(), node->outputs.end());
@@ -230,15 +304,14 @@ void InplacePass::InplaceModifyVar(const std::string& var,
        std::replace(next_op->inputs.begin(), next_op->inputs.end(), node,
                     cache_node);
      }
      // release unsed var in graph
      graph->RemoveNode(node);
      auto& nodes = var_nodes_.at(var);
      nodes.erase(std::remove(nodes.begin(), nodes.end(), node), nodes.end());
    }
  }
}
  // release node of unused var in graph
  for (auto* node : var_nodes_[var]) {
    graph->RemoveNode(node);
  }
  var_nodes_.at(var).clear();
}
@@ -260,6 +333,7 @@ void InplacePass::TryInplaceOpInputOutput(ir::Node* op,
  auto& all_ops = view_.AllOps();
  auto cursor = std::find(all_ops.begin(), all_ops.end(), op);
  size_t idx = std::distance(all_ops.begin(), cursor);
  VLOG(3) << op->Name() << idx;

  for (auto& pair : in_to_outs) {
    auto& in_var_name = pair.first;
@@ -286,6 +360,7 @@ void InplacePass::TryInplaceOpInputOutput(ir::Node* op,
  }
  VLOG(3) << string::Sprintf("!!! %s, %s => %s inplaced", op->Name(),
                             out_var_name, in_var_name);
  // VLOG(3) << "Out " << OpDebugString(op);
  InplaceModifyDesc(out_var_name, in_var_name, idx);
  InplaceModifyVar(out_var_name, in_var_name, idx, graph);
}
@@ -319,7 +394,16 @@ ir::Node* GraphView::GetNodeByName(const std::string& name,
}

std::vector<ir::Node*> GraphView::PendingOpsOnVar(ir::Node* node) {
  return node->outputs;
  // get the pending ops depends on same var node.
  // because node also maybe a inplaced variable, so need to backtrack all the
  // previous inplaced vars.
  std::vector<ir::Node*> pending_ops;
  ir::Node* p = node;
  while (p != nullptr) {
    pending_ops.insert(pending_ops.end(), p->outputs.begin(), p->outputs.end());
    p = GetPrevCascadeInplacedVar(p);
  }
  return pending_ops;
}

void GraphView::Build(ir::Graph* g) { ops_ = SortOpLikeDescOrder(*g); }
@@ -354,14 +438,14 @@ bool GraphView::OutConnectInputByCtrlVar(ir::Node* in_var, ir::Node* out_var) {
  // get the ops with same output name
  while (out != nullptr) {
    out_var_set.emplace(out);
    out = GetNextInplacedOpOutput(out);
    out = GetNextCascadeInplacedVar(out);
  }

  // get ops with same input name
  ir::Node* in = in_var;
  while (in != nullptr) {
    in_var_set.emplace(in);
    in = GetPrevInplacedOpInput(in);
    in = GetPrevCascadeInplacedVar(in);
  }
  // find if there is path with control dep var connect the in_var_set and
  // out_var_set
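The backtracking in GraphView::PendingOpsOnVar relies on GetPrevCascadeInplacedVar to walk from an inplaced variable back through its earlier same-named versions and collect the consumer ops of each. A toy, self-contained sketch of that idea follows; it uses plain structs rather than Paddle's ir::Node and is only meant to illustrate the traversal.

```cpp
// Toy illustration (not Paddle code) of the idea behind
// GetPrevCascadeInplacedVar + GraphView::PendingOpsOnVar: when a variable has
// been inplaced repeatedly under the same name, the pending ops of the newest
// version must also include the consumers of every earlier version.
#include <string>
#include <vector>

struct ToyNode {
  std::string name;
  std::vector<ToyNode*> inputs;   // for a var node: its producer ops
  std::vector<ToyNode*> outputs;  // for a var node: its consumer ops
};

// Collect consumer ops of `var` and of all earlier same-named versions.
std::vector<ToyNode*> PendingOps(ToyNode* var) {
  std::vector<ToyNode*> pending;
  ToyNode* p = var;
  while (p != nullptr) {
    pending.insert(pending.end(), p->outputs.begin(), p->outputs.end());
    // previous cascade version: an input of the producer op sharing p's name
    ToyNode* prev = nullptr;
    if (!p->inputs.empty()) {
      ToyNode* producer = p->inputs.front();
      for (ToyNode* in : producer->inputs) {
        if (in->name == p->name) {
          prev = in;
          break;
        }
      }
    }
    p = prev;
  }
  return pending;
}
```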
paddle/fluid/framework/details/memory_optimize_pass_test.cc
@@ -18,57 +18,13 @@
#include <iterator>
#include "glog/logging.h"
#include "gtest/gtest.h"
#include "paddle/fluid/framework/details/graph_test_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"

namespace paddle {
namespace framework {

class DummyOp : public OperatorBase {
 public:
  DummyOp(const std::string& type, const VariableNameMap& inputs,
          const VariableNameMap& outputs, const AttributeMap& attrs)
      : OperatorBase(type, inputs, outputs, attrs) {}

 private:
  void RunImpl(const Scope& scope,
               const platform::Place& place) const override {}
};

class SumOpMaker : public OpProtoAndCheckerMaker {
 public:
  void Make() {
    AddInput("X", "").AsDuplicable();
    AddOutput("Out", "");
    AddComment("");
  }
};

class AssignOpMaker : public OpProtoAndCheckerMaker {
 public:
  void Make() {
    AddInput("X", "").AsDuplicable();
    AddOutput("Out", "");
    AddComment("");
  }
};

class DummyVarTypeInference : public VarTypeInference {
 public:
  void operator()(const OpDesc& op_desc, BlockDesc* block) const override {
    auto& inputs = op_desc.Input("X");
    auto type = block->Var(inputs.front())->GetType();
    auto out_var_name = op_desc.Output("Out").front();
    block->Var(out_var_name)->SetType(type);
  }
};

}  // namespace framework
}  // namespace paddle

REGISTER_OPERATOR(sum, paddle::framework::DummyOp,
                  paddle::framework::SumOpMaker,
                  paddle::framework::DummyVarTypeInference);
@@ -141,15 +97,6 @@ inline static ProgramDesc FillProgramDesc() {
  return prog;
}

template <typename Container>
inline static std::string DebugString(const Container& c) {
  std::stringstream ss;
  for (auto& item : c) {
    ss << item << " ";
  }
  return ss.str();
}

TEST(CFGGraph, IRGraph) {
  // prepare ir graph
  auto prog = FillProgramDesc();
paddle/fluid/framework/details/multi_devices_graph_print_pass.h
@@ -19,20 +19,12 @@
#include <iosfwd>
#include <ostream>
#include <string>
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/details/graph_print_pass.h"

namespace paddle {
namespace framework {
namespace details {

constexpr char kGraphvizPath[] = "debug_graphviz_path";

class SSAGraphPrinter {
 public:
  virtual ~SSAGraphPrinter() {}
  virtual void Print(const ir::Graph& graph, std::ostream& sout) const = 0;
};

class GraphvizSSAGraphPrinter : public SSAGraphPrinter {
 public:
  void Print(const ir::Graph& graph, std::ostream& sout) const override;
python/paddle/fluid/tests/unittests/parallel_executor_test_base.py
@@ -40,7 +40,7 @@ class TestParallelExecutorBase(unittest.TestCase):
                                  seed=None,
                                  use_parallel_executor=True,
                                  use_reduce=False,
                                  use_ir_memory_optimize=False,
                                  use_ir_memory_optimize=True,
                                  enable_inplace=True,
                                  fuse_elewise_add_act_ops=False,
                                  fuse_relu_depthwise_conv=False,
@@ -61,64 +61,66 @@
            main.random_seed = seed

            loss = method(use_feed=feed_dict is not None)
            if optimizer:
                optimizer().minimize(loss)

            if memory_opt:
                fluid.memory_optimize(main)

            place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
            exe = fluid.Executor(place)
            exe.run(startup)
            exec_strategy = fluid.ExecutionStrategy()
            exec_strategy.allow_op_delay = allow_op_delay
            if use_fast_executor:
                exec_strategy.use_experimental_executor = True
            build_strategy = fluid.BuildStrategy()
            build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce \
                if use_reduce else fluid.BuildStrategy.ReduceStrategy.AllReduce
            build_strategy.fuse_elewise_add_act_ops = fuse_elewise_add_act_ops
            build_strategy.fuse_relu_depthwise_conv = fuse_relu_depthwise_conv
            build_strategy.memory_optimize = use_ir_memory_optimize
            build_strategy.enable_inplace = enable_inplace
            build_strategy.enable_sequential_execution = enable_sequential_execution
            if use_cuda and core.is_compiled_with_cuda():
                build_strategy.remove_unnecessary_lock = True
            if use_parallel_executor:
                binary = compiler.CompiledProgram(main).with_data_parallel(
                    loss_name=loss.name,
                    build_strategy=build_strategy,
                    exec_strategy=exec_strategy)
            else:
                binary = compiler.CompiledProgram(main)

            if batch_size is not None:
                batch_size *= fluid.core.get_cuda_device_count(
                ) if use_cuda else int(
                    os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
            begin = time.time()
            first_loss, = run_executor(
                exe=exe, binary=binary, feed=feed_dict, fetch_list=[loss.name])

            for i in range(iter):
                run_executor(exe=exe, binary=binary, feed=feed_dict, fetch_list=[])

            last_loss, = run_executor(
                exe=exe, binary=binary, feed=feed_dict, fetch_list=[loss.name])
            end = time.time()

            if batch_size is not None:
                print("%.4f Instance per second" % (
                    (batch_size * iter + 2) / (end - begin)))

            avg_last_loss_val = np.array(last_loss).mean()
            avg_first_loss_val = np.array(first_loss).mean()
            if math.isnan(float(avg_last_loss_val)) or math.isnan(
                    float(avg_first_loss_val)):
                sys.exit("got NaN loss, training failed.")

            print(first_loss, last_loss)
            # self.assertGreater(first_loss[0], last_loss[0])
            return first_loss, last_loss

            with open("program_model.txt", "w") as f:
                f.write(str(main))

            place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
            exe = fluid.Executor(place)
            exe.run(startup)
            exec_strategy = fluid.ExecutionStrategy()
            exec_strategy.allow_op_delay = allow_op_delay
            if use_fast_executor:
                exec_strategy.use_experimental_executor = True
            build_strategy = fluid.BuildStrategy()
            build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce \
                if use_reduce else fluid.BuildStrategy.ReduceStrategy.AllReduce
            build_strategy.fuse_elewise_add_act_ops = fuse_elewise_add_act_ops
            build_strategy.fuse_relu_depthwise_conv = fuse_relu_depthwise_conv
            build_strategy.memory_optimize = use_ir_memory_optimize
            build_strategy.enable_inplace = enable_inplace
            build_strategy.enable_sequential_execution = enable_sequential_execution
            build_strategy.debug_graphviz_path = "debug_ir_graph_"

            if use_cuda and core.is_compiled_with_cuda():
                build_strategy.remove_unnecessary_lock = True
            if use_parallel_executor:
                binary = compiler.CompiledProgram(main).with_data_parallel(
                    loss_name=loss.name,
                    build_strategy=build_strategy,
                    exec_strategy=exec_strategy)
            else:
                binary = compiler.CompiledProgram(main)

            if batch_size is not None:
                batch_size *= fluid.core.get_cuda_device_count(
                ) if use_cuda else int(
                    os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
            begin = time.time()
            first_loss, = run_executor(
                exe=exe, binary=binary, feed=feed_dict, fetch_list=[loss.name])

            for i in range(iter):
                run_executor(exe=exe, binary=binary, feed=feed_dict, fetch_list=[])

            last_loss, = run_executor(
                exe=exe, binary=binary, feed=feed_dict, fetch_list=[loss.name])
            end = time.time()

            if batch_size is not None:
                print("%.4f Instance per second" % (
                    (batch_size * iter + 2) / (end - begin)))

            avg_last_loss_val = np.array(last_loss).mean()
            avg_first_loss_val = np.array(first_loss).mean()
            if math.isnan(float(avg_last_loss_val)) or math.isnan(
                    float(avg_first_loss_val)):
                sys.exit("got NaN loss, training failed.")

            print(first_loss, last_loss)
            # self.assertGreater(first_loss[0], last_loss[0])
            return first_loss, last_loss
python/paddle/fluid/tests/unittests/test_ir_inplace_pass.py
new file mode 100644

# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function
import os
import unittest
import numpy as np
import paddle.fluid as fluid
from parallel_executor_test_base import TestParallelExecutorBase


def fc_with_batchnorm(use_feed):
    img = fluid.layers.data(name='image', shape=[784], dtype='float32')
    label = fluid.layers.data(name='label', shape=[1], dtype='int64')

    hidden = img
    for _ in range(3):
        hidden = fluid.layers.fc(
            hidden,
            size=200,
            act='tanh',
            bias_attr=fluid.ParamAttr(
                initializer=fluid.initializer.Constant(value=1.0)))

        hidden = fluid.layers.batch_norm(input=hidden)
    prediction = fluid.layers.fc(hidden, size=10, act='softmax')
    loss = fluid.layers.cross_entropy(input=prediction, label=label)
    loss = fluid.layers.mean(loss)
    return loss


class TestIrInplace(TestParallelExecutorBase):
    @classmethod
    def setUpClass(cls):
        os.environ['CPU_NUM'] = str(4)

    def _fc_with_batchnorm(self, ir_memory_optimize, enable_inplace):
        np.random.seed(5)
        img = np.random.random(size=[32, 784]).astype(np.float32)
        label = np.ones(shape=[32, 1], dtype='int64')
        self.check_network_convergence(
            fc_with_batchnorm,
            feed_dict={"image": img,
                       "label": label},
            use_cuda=True,
            memory_opt=False,  # inplace is conflict with memory opt
            use_ir_memory_optimize=ir_memory_optimize,
            enable_inplace=enable_inplace)

    def test_fc_with_batchnorm(self, delta=1e-3):
        loss00 = self._fc_with_batchnorm(False, False)
        loss10 = self._fc_with_batchnorm(True, False)
        loss01 = self._fc_with_batchnorm(False, True)
        loss11 = self._fc_with_batchnorm(True, True)
        self.assertAlmostEqual(loss00, loss10, delta=delta)
        self.assertAlmostEqual(loss00, loss01, delta=delta)
        self.assertAlmostEqual(loss00, loss11, delta=delta)