Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
2bb15f43
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
2bb15f43
编写于
8月 28, 2018
作者:
X
Xin Pan
提交者:
GitHub
8月 28, 2018
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #12791 from panyx0718/ir3
graph to program pass
上级
a22309af
cf547e27
变更
13
显示空白变更内容
内联
并排
Showing
13 changed file
with
254 addition
and
10 deletion
+254
-10
paddle/fluid/framework/CMakeLists.txt
paddle/fluid/framework/CMakeLists.txt
+2
-2
paddle/fluid/framework/ir/CMakeLists.txt
paddle/fluid/framework/ir/CMakeLists.txt
+2
-0
paddle/fluid/framework/ir/graph_to_program_pass.cc
paddle/fluid/framework/ir/graph_to_program_pass.cc
+65
-0
paddle/fluid/framework/ir/graph_to_program_pass.h
paddle/fluid/framework/ir/graph_to_program_pass.h
+30
-0
paddle/fluid/framework/ir/graph_to_program_pass_test.cc
paddle/fluid/framework/ir/graph_to_program_pass_test.cc
+110
-0
paddle/fluid/framework/op_desc.cc
paddle/fluid/framework/op_desc.cc
+3
-1
paddle/fluid/framework/program_desc.cc
paddle/fluid/framework/program_desc.cc
+20
-2
paddle/fluid/framework/program_desc.h
paddle/fluid/framework/program_desc.h
+2
-0
paddle/fluid/inference/CMakeLists.txt
paddle/fluid/inference/CMakeLists.txt
+1
-1
paddle/fluid/inference/tests/test_helper.h
paddle/fluid/inference/tests/test_helper.h
+14
-0
paddle/fluid/operators/parallel_do_op.cc
paddle/fluid/operators/parallel_do_op.cc
+1
-0
python/paddle/fluid/tests/book_memory_optimization/test_memopt_image_classification_train.py
...ry_optimization/test_memopt_image_classification_train.py
+2
-2
python/paddle/fluid/tests/book_memory_optimization/test_memopt_machine_translation.py
...ok_memory_optimization/test_memopt_machine_translation.py
+2
-2
未找到文件。
paddle/fluid/framework/CMakeLists.txt
浏览文件 @
2bb15f43
...
@@ -107,11 +107,11 @@ cc_library(lod_rank_table SRCS lod_rank_table.cc DEPS lod_tensor)
...
@@ -107,11 +107,11 @@ cc_library(lod_rank_table SRCS lod_rank_table.cc DEPS lod_tensor)
cc_library
(
feed_fetch_method SRCS feed_fetch_method.cc DEPS lod_tensor scope glog
)
cc_library
(
feed_fetch_method SRCS feed_fetch_method.cc DEPS lod_tensor scope glog
)
if
(
WITH_DISTRIBUTE
)
if
(
WITH_DISTRIBUTE
)
cc_library
(
executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method sendrecvop_grpc cares grpc++_unsecure grpc_unsecure gpr
)
cc_library
(
executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method sendrecvop_grpc cares grpc++_unsecure grpc_unsecure gpr
graph_to_program_pass
)
set
(
DISTRIBUTE_COMPILE_FLAGS
"-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor"
)
set
(
DISTRIBUTE_COMPILE_FLAGS
"-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor"
)
set_source_files_properties
(
executor.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
set_source_files_properties
(
executor.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
else
()
else
()
cc_library
(
executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method
)
cc_library
(
executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method
graph_to_program_pass
)
endif
()
endif
()
if
(
NOT WIN32
)
if
(
NOT WIN32
)
...
...
paddle/fluid/framework/ir/CMakeLists.txt
浏览文件 @
2bb15f43
...
@@ -3,6 +3,7 @@ cc_library(graph SRCS graph.cc DEPS node)
...
@@ -3,6 +3,7 @@ cc_library(graph SRCS graph.cc DEPS node)
cc_library
(
graph_helper SRCS graph_helper.cc DEPS graph
)
cc_library
(
graph_helper SRCS graph_helper.cc DEPS graph
)
cc_library
(
pass SRCS pass.cc DEPS graph node graph_helper
)
cc_library
(
pass SRCS pass.cc DEPS graph node graph_helper
)
cc_library
(
graph_viz_pass SRCS graph_viz_pass.cc DEPS graph pass graph_helper
)
cc_library
(
graph_viz_pass SRCS graph_viz_pass.cc DEPS graph pass graph_helper
)
cc_library
(
graph_to_program_pass SRCS graph_to_program_pass.cc DEPS graph pass graph_helper
)
cc_library
(
graph_traits SRCS graph_traits.cc DEPS graph
)
cc_library
(
graph_traits SRCS graph_traits.cc DEPS graph
)
cc_library
(
graph_pattern_detecter SRCS graph_pattern_detecter.cc DEPS graph graph_helper graph_traits
)
cc_library
(
graph_pattern_detecter SRCS graph_pattern_detecter.cc DEPS graph graph_helper graph_traits
)
cc_library
(
fc_fuse_pass SRCS fc_fuse_pass.cc DEPS graph graph_pattern_detecter
)
cc_library
(
fc_fuse_pass SRCS fc_fuse_pass.cc DEPS graph graph_pattern_detecter
)
...
@@ -12,5 +13,6 @@ cc_library(infer_clean_graph_pass SRCS infer_clean_graph_pass.cc DEPS graph pass
...
@@ -12,5 +13,6 @@ cc_library(infer_clean_graph_pass SRCS infer_clean_graph_pass.cc DEPS graph pass
cc_test
(
pass_test SRCS pass_test.cc DEPS graph pass graph_helper
)
cc_test
(
pass_test SRCS pass_test.cc DEPS graph pass graph_helper
)
cc_test
(
graph_test SRCS graph_test.cc DEPS graph graph_helper op_registry
)
cc_test
(
graph_test SRCS graph_test.cc DEPS graph graph_helper op_registry
)
cc_test
(
graph_helper_test SRCS graph_helper_test.cc DEPS graph graph_helper op_registry
)
cc_test
(
graph_helper_test SRCS graph_helper_test.cc DEPS graph graph_helper op_registry
)
cc_test
(
graph_to_program_pass_test SRCS graph_to_program_pass_test.cc DEPS graph_to_program_pass
)
cc_test
(
test_graph_pattern_detecter SRCS graph_pattern_detecter_tester.cc DEPS graph_pattern_detecter
)
cc_test
(
test_graph_pattern_detecter SRCS graph_pattern_detecter_tester.cc DEPS graph_pattern_detecter
)
cc_test
(
test_fc_fuse_pass SRCS fc_fuse_pass_tester.cc DEPS fc_fuse_pass graph_pattern_detecter graph pass graph_traits framework_proto
)
cc_test
(
test_fc_fuse_pass SRCS fc_fuse_pass_tester.cc DEPS fc_fuse_pass graph_pattern_detecter graph pass graph_traits framework_proto
)
paddle/fluid/framework/ir/graph_to_program_pass.cc
0 → 100644
浏览文件 @
2bb15f43
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/ir/graph_to_program_pass.h"
#include <map>
#include <string>
#include <vector>
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/program_desc.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
std
::
unique_ptr
<
Graph
>
GraphToProgramPass
::
ApplyImpl
(
std
::
unique_ptr
<
Graph
>
graph
)
const
{
ProgramDesc
&
program
=
Get
<
ProgramDesc
>
(
"program"
);
std
::
unique_ptr
<
proto
::
ProgramDesc
>
program_pb
(
new
proto
::
ProgramDesc
(
*
program
.
Proto
()));
auto
block
=
program_pb
->
mutable_blocks
(
kRootBlockIndex
);
block
->
clear_vars
();
std
::
unordered_set
<
std
::
string
>
visited_vars
;
for
(
ir
::
Node
*
n
:
graph
->
Nodes
())
{
if
(
n
->
NodeType
()
==
ir
::
Node
::
Type
::
kVariable
)
{
if
(
n
->
Var
()
&&
visited_vars
.
count
(
n
->
Var
()
->
Name
())
==
0
)
{
visited_vars
.
insert
(
n
->
Var
()
->
Name
());
block
->
add_vars
()
->
MergeFrom
(
*
n
->
Var
()
->
Proto
());
}
}
}
block
->
clear_ops
();
std
::
vector
<
ir
::
Node
*>
nodes
=
TopologySortOperations
(
*
graph
);
for
(
ir
::
Node
*
n
:
nodes
)
{
if
(
!
n
->
Op
())
{
continue
;
}
block
->
add_ops
()
->
MergeFrom
(
*
n
->
Op
()
->
Proto
());
}
program
.
CopyFrom
(
*
program_pb
);
return
graph
;
}
}
// namespace ir
}
// namespace framework
}
// namespace paddle
REGISTER_PASS
(
graph_to_program_pass
,
paddle
::
framework
::
ir
::
GraphToProgramPass
);
paddle/fluid/framework/ir/graph_to_program_pass.h
0 → 100644
浏览文件 @
2bb15f43
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/ir/pass.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
class
GraphToProgramPass
:
public
Pass
{
protected:
std
::
unique_ptr
<
Graph
>
ApplyImpl
(
std
::
unique_ptr
<
Graph
>
graph
)
const
override
;
};
}
// namespace ir
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/ir/graph_to_program_pass_test.cc
0 → 100644
浏览文件 @
2bb15f43
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/ir/graph_to_program_pass.h"
#include <string>
#include <vector>
#include "gtest/gtest.h"
#include "paddle/fluid/framework/program_desc.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
void
BuildNoCircleGraph
(
Graph
*
g
)
{
OpDesc
op1
;
op1
.
SetType
(
"op1"
);
OpDesc
op2
;
op2
.
SetType
(
"op2"
);
OpDesc
op3
;
op3
.
SetType
(
"op3"
);
OpDesc
op4
;
op4
.
SetType
(
"op4"
);
OpDesc
op5
;
op5
.
SetType
(
"op5"
);
VarDesc
var1
(
"var1"
);
VarDesc
var2
(
"var2"
);
VarDesc
var3
(
"var3"
);
VarDesc
var4
(
"var4"
);
ir
::
Node
*
o1
=
g
->
CreateOpNode
(
&
op1
);
ir
::
Node
*
o2
=
g
->
CreateOpNode
(
&
op2
);
ir
::
Node
*
o3
=
g
->
CreateOpNode
(
&
op3
);
ir
::
Node
*
o4
=
g
->
CreateOpNode
(
&
op4
);
ir
::
Node
*
o5
=
g
->
CreateOpNode
(
&
op5
);
ir
::
Node
*
v1
=
g
->
CreateVarNode
(
&
var1
);
ir
::
Node
*
v2
=
g
->
CreateVarNode
(
&
var2
);
ir
::
Node
*
v3
=
g
->
CreateVarNode
(
&
var3
);
ir
::
Node
*
v4
=
g
->
CreateVarNode
(
&
var4
);
// o1->v1->o2
o1
->
outputs
.
push_back
(
v1
);
o2
->
inputs
.
push_back
(
v1
);
v1
->
inputs
.
push_back
(
o1
);
v1
->
outputs
.
push_back
(
o2
);
// o2->v2->o3
// o2->v2->o4
o2
->
outputs
.
push_back
(
v2
);
o3
->
inputs
.
push_back
(
v2
);
o4
->
inputs
.
push_back
(
v2
);
v2
->
outputs
.
push_back
(
o3
);
v2
->
outputs
.
push_back
(
o4
);
v2
->
inputs
.
push_back
(
o2
);
// o2->v3->o5
o2
->
outputs
.
push_back
(
v3
);
o5
->
inputs
.
push_back
(
v3
);
v3
->
inputs
.
push_back
(
o2
);
v3
->
outputs
.
push_back
(
o5
);
// o3-v4->o5
o3
->
outputs
.
push_back
(
v4
);
o5
->
inputs
.
push_back
(
v4
);
v4
->
inputs
.
push_back
(
o3
);
v4
->
outputs
.
push_back
(
o5
);
}
TEST
(
GraphToProgramPass
,
Basic
)
{
ProgramDesc
prog
;
std
::
unique_ptr
<
Graph
>
g
(
new
Graph
(
prog
));
BuildNoCircleGraph
(
g
.
get
());
auto
pass
=
paddle
::
framework
::
ir
::
PassRegistry
::
Instance
().
Get
(
"graph_to_program_pass"
);
ProgramDesc
compiled_prog
;
pass
->
SetNotOwned
<
paddle
::
framework
::
ProgramDesc
>
(
"program"
,
&
compiled_prog
);
pass
->
Apply
(
std
::
move
(
g
));
std
::
vector
<
OpDesc
*>
ops
=
compiled_prog
.
Block
(
0
).
AllOps
();
EXPECT_EQ
(
ops
[
0
]
->
Type
(),
"op1"
);
EXPECT_EQ
(
ops
[
1
]
->
Type
(),
"op2"
);
if
(
ops
[
2
]
->
Type
()
==
"op3"
)
{
EXPECT_EQ
(
ops
[
3
]
->
Type
(),
"op4"
);
}
else
if
(
ops
[
2
]
->
Type
()
==
"op4"
)
{
EXPECT_EQ
(
ops
[
3
]
->
Type
(),
"op3"
);
}
EXPECT_EQ
(
ops
[
4
]
->
Type
(),
"op5"
);
std
::
unordered_set
<
std
::
string
>
vars
;
for
(
VarDesc
*
v
:
compiled_prog
.
Block
(
0
).
AllVars
())
{
vars
.
insert
(
v
->
Name
());
}
EXPECT_TRUE
(
vars
.
find
(
"var1"
)
!=
vars
.
end
());
EXPECT_TRUE
(
vars
.
find
(
"var2"
)
!=
vars
.
end
());
EXPECT_TRUE
(
vars
.
find
(
"var3"
)
!=
vars
.
end
());
}
}
// namespace ir
}
// namespace framework
}
// namespace paddle
USE_PASS
(
graph_to_program_pass
);
paddle/fluid/framework/op_desc.cc
浏览文件 @
2bb15f43
...
@@ -132,7 +132,9 @@ OpDesc::OpDesc(const proto::OpDesc &desc, BlockDesc *block)
...
@@ -132,7 +132,9 @@ OpDesc::OpDesc(const proto::OpDesc &desc, BlockDesc *block)
std
::
string
attr_name
=
attr
.
name
();
std
::
string
attr_name
=
attr
.
name
();
// The sub_block referred to by the BLOCK attr hasn't been added
// The sub_block referred to by the BLOCK attr hasn't been added
// to ProgramDesc class yet, we skip setting BLOCK attr here.
// to ProgramDesc class yet, we skip setting BLOCK attr here.
if
(
attr
.
type
()
!=
proto
::
AttrType
::
BLOCK
)
{
// TODO(paddle-dev): Need copy fix this to copy Block as well.
if
(
attr
.
type
()
!=
proto
::
AttrType
::
BLOCK
&&
attr
.
type
()
!=
proto
::
AttrType
::
BLOCKS
)
{
attrs_
[
attr_name
]
=
GetAttrValue
(
attr
);
attrs_
[
attr_name
]
=
GetAttrValue
(
attr
);
}
}
}
}
...
...
paddle/fluid/framework/program_desc.cc
浏览文件 @
2bb15f43
...
@@ -80,6 +80,12 @@ ProgramDesc::ProgramDesc(const proto::ProgramDesc &desc) {
...
@@ -80,6 +80,12 @@ ProgramDesc::ProgramDesc(const proto::ProgramDesc &desc) {
InitFromProto
();
InitFromProto
();
}
}
void
ProgramDesc
::
CopyFrom
(
const
proto
::
ProgramDesc
&
desc
)
{
blocks_
.
clear
();
desc_
=
desc
;
InitFromProto
();
}
ProgramDesc
::
ProgramDesc
(
const
std
::
string
&
binary_str
)
{
ProgramDesc
::
ProgramDesc
(
const
std
::
string
&
binary_str
)
{
PADDLE_ENFORCE
(
desc_
.
ParseFromString
(
binary_str
),
PADDLE_ENFORCE
(
desc_
.
ParseFromString
(
binary_str
),
"Fail to parse program_desc from binary string."
);
"Fail to parse program_desc from binary string."
);
...
@@ -111,10 +117,16 @@ void ProgramDesc::InitFromProto() {
...
@@ -111,10 +117,16 @@ void ProgramDesc::InitFromProto() {
const
std
::
vector
<
std
::
string
>
ProgramDesc
::
GetFeedTargetNames
()
{
const
std
::
vector
<
std
::
string
>
ProgramDesc
::
GetFeedTargetNames
()
{
auto
&
global_block
=
Block
(
0
);
auto
&
global_block
=
Block
(
0
);
// The order of feed_target_names must follow the index specified in `col`.
// since feed operator's order doesn't necessary follow 'col'.
std
::
vector
<
std
::
string
>
feed_target_names
;
std
::
vector
<
std
::
string
>
feed_target_names
;
for
(
auto
*
op
:
global_block
.
AllOps
())
{
for
(
auto
*
op
:
global_block
.
AllOps
())
{
if
(
op
->
Type
()
==
kFeedOpType
)
{
if
(
op
->
Type
()
==
kFeedOpType
)
{
feed_target_names
.
insert
(
feed_target_names
.
begin
(),
op
->
Output
(
"Out"
)[
0
]);
int
col
=
boost
::
get
<
int
>
(
op
->
GetAttr
(
"col"
));
if
(
col
>=
feed_target_names
.
size
())
{
feed_target_names
.
resize
(
col
+
1
);
}
feed_target_names
[
col
]
=
op
->
Output
(
"Out"
)[
0
];
}
}
}
}
return
feed_target_names
;
return
feed_target_names
;
...
@@ -122,10 +134,16 @@ const std::vector<std::string> ProgramDesc::GetFeedTargetNames() {
...
@@ -122,10 +134,16 @@ const std::vector<std::string> ProgramDesc::GetFeedTargetNames() {
const
std
::
vector
<
std
::
string
>
ProgramDesc
::
GetFetchTargetNames
()
{
const
std
::
vector
<
std
::
string
>
ProgramDesc
::
GetFetchTargetNames
()
{
auto
&
global_block
=
Block
(
0
);
auto
&
global_block
=
Block
(
0
);
// The order of fetch_target_names must follow the index specified in `col`.
// since fetch operator's order doesn't necessary follow 'col'.
std
::
vector
<
std
::
string
>
fetch_target_names
;
std
::
vector
<
std
::
string
>
fetch_target_names
;
for
(
auto
*
op
:
global_block
.
AllOps
())
{
for
(
auto
*
op
:
global_block
.
AllOps
())
{
if
(
op
->
Type
()
==
kFetchOpType
)
{
if
(
op
->
Type
()
==
kFetchOpType
)
{
fetch_target_names
.
push_back
(
op
->
Input
(
"X"
)[
0
]);
int
col
=
boost
::
get
<
int
>
(
op
->
GetAttr
(
"col"
));
if
(
col
>=
fetch_target_names
.
size
())
{
fetch_target_names
.
resize
(
col
+
1
);
}
fetch_target_names
[
col
]
=
op
->
Input
(
"X"
)[
0
];
}
}
}
}
return
fetch_target_names
;
return
fetch_target_names
;
...
...
paddle/fluid/framework/program_desc.h
浏览文件 @
2bb15f43
...
@@ -53,6 +53,8 @@ class ProgramDesc {
...
@@ -53,6 +53,8 @@ class ProgramDesc {
void
Flush
();
void
Flush
();
void
CopyFrom
(
const
proto
::
ProgramDesc
&
desc
);
proto
::
ProgramDesc
*
Proto
();
proto
::
ProgramDesc
*
Proto
();
// The output variable of feed_op is referenced as feed_target.
// The output variable of feed_op is referenced as feed_target.
...
...
paddle/fluid/inference/CMakeLists.txt
浏览文件 @
2bb15f43
...
@@ -10,7 +10,7 @@ set(FLUID_CORE_MODULES proto_desc memory lod_tensor executor)
...
@@ -10,7 +10,7 @@ set(FLUID_CORE_MODULES proto_desc memory lod_tensor executor)
# TODO(panyx0718): Should this be called paddle_fluid_inference_api_internal?
# TODO(panyx0718): Should this be called paddle_fluid_inference_api_internal?
cc_library
(
paddle_fluid_api
cc_library
(
paddle_fluid_api
SRCS io.cc
SRCS io.cc
DEPS
${
FLUID_CORE_MODULES
}
${
GLOB_OP_LIB
}
)
DEPS
${
FLUID_CORE_MODULES
}
${
GLOB_OP_LIB
}
graph_to_program_pass
)
get_property
(
fluid_modules GLOBAL PROPERTY FLUID_MODULES
)
get_property
(
fluid_modules GLOBAL PROPERTY FLUID_MODULES
)
...
...
paddle/fluid/inference/tests/test_helper.h
浏览文件 @
2bb15f43
...
@@ -18,6 +18,7 @@ limitations under the License. */
...
@@ -18,6 +18,7 @@ limitations under the License. */
#include <string>
#include <string>
#include <vector>
#include <vector>
#include "paddle/fluid/framework/ir/graph_to_program_pass.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/inference/io.h"
#include "paddle/fluid/inference/io.h"
#include "paddle/fluid/platform/profiler.h"
#include "paddle/fluid/platform/profiler.h"
...
@@ -135,6 +136,15 @@ std::vector<std::vector<int64_t>> GetFeedTargetShapes(
...
@@ -135,6 +136,15 @@ std::vector<std::vector<int64_t>> GetFeedTargetShapes(
return
feed_target_shapes
;
return
feed_target_shapes
;
}
}
void
Compile
(
paddle
::
framework
::
ProgramDesc
*
program
)
{
std
::
unique_ptr
<
paddle
::
framework
::
ir
::
Graph
>
g
(
new
paddle
::
framework
::
ir
::
Graph
(
*
program
));
auto
pass
=
paddle
::
framework
::
ir
::
PassRegistry
::
Instance
().
Get
(
"graph_to_program_pass"
);
pass
->
SetNotOwned
<
paddle
::
framework
::
ProgramDesc
>
(
"program"
,
program
);
pass
->
Apply
(
std
::
move
(
g
));
}
template
<
typename
Place
,
bool
CreateVars
=
true
,
bool
PrepareContext
=
false
>
template
<
typename
Place
,
bool
CreateVars
=
true
,
bool
PrepareContext
=
false
>
void
TestInference
(
const
std
::
string
&
dirname
,
void
TestInference
(
const
std
::
string
&
dirname
,
const
std
::
vector
<
paddle
::
framework
::
LoDTensor
*>&
cpu_feeds
,
const
std
::
vector
<
paddle
::
framework
::
LoDTensor
*>&
cpu_feeds
,
...
@@ -172,6 +182,8 @@ void TestInference(const std::string& dirname,
...
@@ -172,6 +182,8 @@ void TestInference(const std::string& dirname,
paddle
::
platform
::
DeviceContextPool
::
Instance
().
Get
(
place
));
paddle
::
platform
::
DeviceContextPool
::
Instance
().
Get
(
place
));
inference_program
=
InitProgram
(
&
executor
,
scope
,
dirname
,
is_combined
);
inference_program
=
InitProgram
(
&
executor
,
scope
,
dirname
,
is_combined
);
}
}
Compile
(
inference_program
.
get
());
// Disable the profiler and print the timing information
// Disable the profiler and print the timing information
paddle
::
platform
::
DisableProfiler
(
paddle
::
platform
::
EventSortingKey
::
kDefault
,
paddle
::
platform
::
DisableProfiler
(
paddle
::
platform
::
EventSortingKey
::
kDefault
,
"load_program_profiler"
);
"load_program_profiler"
);
...
@@ -249,3 +261,5 @@ void TestInference(const std::string& dirname,
...
@@ -249,3 +261,5 @@ void TestInference(const std::string& dirname,
delete
scope
;
delete
scope
;
}
}
USE_PASS
(
graph_to_program_pass
);
paddle/fluid/operators/parallel_do_op.cc
浏览文件 @
2bb15f43
...
@@ -355,6 +355,7 @@ class ParallelDoGradOpDescMaker : public framework::SingleGradOpDescMaker {
...
@@ -355,6 +355,7 @@ class ParallelDoGradOpDescMaker : public framework::SingleGradOpDescMaker {
grad
->
SetInput
(
framework
::
GradVarName
(
output_param
),
og_names
);
grad
->
SetInput
(
framework
::
GradVarName
(
output_param
),
og_names
);
}
}
}
}
grad
->
SetInput
(
"Communicator"
,
{
"nccl_com__do_not_change_"
});
grad
->
SetAttrMap
(
this
->
Attrs
());
grad
->
SetAttrMap
(
this
->
Attrs
());
grad
->
SetBlockAttr
(
kParallelBlock
,
grad_block_
[
0
]);
grad
->
SetBlockAttr
(
kParallelBlock
,
grad_block_
[
0
]);
...
...
python/paddle/fluid/tests/book_memory_optimization/test_memopt_image_classification_train.py
浏览文件 @
2bb15f43
...
@@ -125,8 +125,8 @@ opts = optimizer.minimize(avg_cost)
...
@@ -125,8 +125,8 @@ opts = optimizer.minimize(avg_cost)
batch_size
=
fluid
.
layers
.
create_tensor
(
dtype
=
'int64'
)
batch_size
=
fluid
.
layers
.
create_tensor
(
dtype
=
'int64'
)
batch_acc
=
fluid
.
layers
.
accuracy
(
input
=
predict
,
label
=
label
,
total
=
batch_size
)
batch_acc
=
fluid
.
layers
.
accuracy
(
input
=
predict
,
label
=
label
,
total
=
batch_size
)
#
fluid.memory_optimize(fluid.default_main_program(), level=0)
fluid
.
memory_optimize
(
fluid
.
default_main_program
(),
level
=
0
)
fluid
.
release_memory
(
fluid
.
default_main_program
())
#
fluid.release_memory(fluid.default_main_program())
BATCH_SIZE
=
16
BATCH_SIZE
=
16
PASS_NUM
=
1
PASS_NUM
=
1
...
...
python/paddle/fluid/tests/book_memory_optimization/test_memopt_machine_translation.py
浏览文件 @
2bb15f43
...
@@ -92,8 +92,8 @@ def main():
...
@@ -92,8 +92,8 @@ def main():
optimizer
=
fluid
.
optimizer
.
Adagrad
(
learning_rate
=
1e-4
)
optimizer
=
fluid
.
optimizer
.
Adagrad
(
learning_rate
=
1e-4
)
optimizer
.
minimize
(
avg_cost
)
optimizer
.
minimize
(
avg_cost
)
#
fluid.memory_optimize(fluid.default_main_program())
fluid
.
memory_optimize
(
fluid
.
default_main_program
())
fluid
.
release_memory
(
fluid
.
default_main_program
())
#
fluid.release_memory(fluid.default_main_program())
# fix the order of training data
# fix the order of training data
train_data
=
paddle
.
batch
(
train_data
=
paddle
.
batch
(
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录