Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
8adaa0f0
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
8adaa0f0
编写于
12月 10, 2021
作者:
J
jianghaicheng
提交者:
GitHub
12月 10, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add popart_canonicalization p1 (#37964)
上级
89069af5
变更
16
隐藏空白更改
内联
并排
Showing
16 changed file
with
929 addition
and
0 deletion
+929
-0
paddle/fluid/framework/ir/ipu/avg_shard_pass.cc
paddle/fluid/framework/ir/ipu/avg_shard_pass.cc
+56
-0
paddle/fluid/framework/ir/ipu/avg_shard_pass.h
paddle/fluid/framework/ir/ipu/avg_shard_pass.h
+30
-0
paddle/fluid/framework/ir/ipu/forward_graph_extract_pass.cc
paddle/fluid/framework/ir/ipu/forward_graph_extract_pass.cc
+133
-0
paddle/fluid/framework/ir/ipu/forward_graph_extract_pass.h
paddle/fluid/framework/ir/ipu/forward_graph_extract_pass.h
+31
-0
paddle/fluid/framework/ir/ipu/infer_shape_pass.cc
paddle/fluid/framework/ir/ipu/infer_shape_pass.cc
+108
-0
paddle/fluid/framework/ir/ipu/infer_shape_pass.h
paddle/fluid/framework/ir/ipu/infer_shape_pass.h
+30
-0
paddle/fluid/framework/ir/ipu/inference_postprocess_pass.cc
paddle/fluid/framework/ir/ipu/inference_postprocess_pass.cc
+89
-0
paddle/fluid/framework/ir/ipu/inference_postprocess_pass.h
paddle/fluid/framework/ir/ipu/inference_postprocess_pass.h
+30
-0
paddle/fluid/framework/ir/ipu/inference_process_pass.cc
paddle/fluid/framework/ir/ipu/inference_process_pass.cc
+129
-0
paddle/fluid/framework/ir/ipu/inference_process_pass.h
paddle/fluid/framework/ir/ipu/inference_process_pass.h
+30
-0
paddle/fluid/framework/ir/ipu/ipu_graph_builder_pass.cc
paddle/fluid/framework/ir/ipu/ipu_graph_builder_pass.cc
+52
-0
paddle/fluid/framework/ir/ipu/ipu_graph_builder_pass.h
paddle/fluid/framework/ir/ipu/ipu_graph_builder_pass.h
+31
-0
paddle/fluid/framework/ir/ipu/ipu_inplace_pass.cc
paddle/fluid/framework/ir/ipu/ipu_inplace_pass.cc
+85
-0
paddle/fluid/framework/ir/ipu/ipu_inplace_pass.h
paddle/fluid/framework/ir/ipu/ipu_inplace_pass.h
+30
-0
paddle/fluid/framework/ir/ipu/ipu_pass_base.cc
paddle/fluid/framework/ir/ipu/ipu_pass_base.cc
+28
-0
paddle/fluid/framework/ir/ipu/ipu_pass_base.h
paddle/fluid/framework/ir/ipu/ipu_pass_base.h
+37
-0
未找到文件。
paddle/fluid/framework/ir/ipu/avg_shard_pass.cc
0 → 100644
浏览文件 @
8adaa0f0
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/ipu/avg_shard_pass.h"
#include "paddle/fluid/platform/device/ipu/ipu_backend.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
void
AvgShardPass
::
ApplyImpl
(
ir
::
Graph
*
graph
)
const
{
VLOG
(
10
)
<<
"enter AvgShardPass::ApplyImpl"
;
std
::
shared_ptr
<
platform
::
ipu
::
IpuBackend
>
ipu_backend
=
platform
::
ipu
::
IpuBackend
::
GetInstance
();
if
(
ipu_backend
->
GetIpuStrategy
()
->
need_avg_shard
)
{
VLOG
(
10
)
<<
"start AvgShardPass"
;
auto
nodes
=
ir
::
TopologySortOperations
(
*
graph
);
auto
num_ipus
=
ipu_backend
->
GetIpuStrategy
()
->
num_ipus
;
int
shard_position
=
nodes
.
size
()
/
num_ipus
;
int
index_and_stage
=
-
1
;
for
(
int
i
=
0
;
i
<
nodes
.
size
();
i
++
)
{
if
((
i
%
shard_position
)
==
0
&&
index_and_stage
<
num_ipus
-
1
)
{
index_and_stage
++
;
}
nodes
[
i
]
->
Op
()
->
SetAttr
(
"ipu_index"
,
index_and_stage
);
nodes
[
i
]
->
Op
()
->
SetAttr
(
"ipu_stage"
,
index_and_stage
);
}
VLOG
(
10
)
<<
"end AvgShardPass"
;
}
VLOG
(
10
)
<<
"leave AvgShardPass::ApplyImpl"
;
}
}
// namespace ir
}
// namespace framework
}
// namespace paddle
REGISTER_PASS
(
avg_shard_pass
,
paddle
::
framework
::
ir
::
AvgShardPass
);
paddle/fluid/framework/ir/ipu/avg_shard_pass.h
0 → 100644
浏览文件 @
8adaa0f0
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/ir/ipu/ipu_pass_base.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
class
AvgShardPass
:
public
IPUPassBase
{
protected:
void
ApplyImpl
(
ir
::
Graph
*
graph
)
const
override
;
};
}
// namespace ir
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/ir/ipu/forward_graph_extract_pass.cc
0 → 100644
浏览文件 @
8adaa0f0
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/ipu/forward_graph_extract_pass.h"
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
void
ForwardGraphExtractPass
::
ApplyImpl
(
ir
::
Graph
*
graph
)
const
{
VLOG
(
10
)
<<
"enter ForwardGraphExtractPass::ApplyImpl"
;
std
::
unordered_map
<
OpRole
,
std
::
unordered_set
<
ir
::
Node
*>>
all_ops
{
{
OpRole
::
kForward
,
{}},
{
OpRole
::
kBackward
,
{}},
{
OpRole
::
kOptimize
,
{}},
{
OpRole
::
kRPC
,
{}},
{
OpRole
::
kDist
,
{}},
{
OpRole
::
kLRSched
,
{}},
{
OpRole
::
kLoss
,
{}},
{
OpRole
::
kNotSpecified
,
{}}};
for
(
auto
*
node
:
graph
->
Nodes
())
{
if
(
!
node
->
IsOp
())
{
continue
;
}
auto
op_role
=
BOOST_GET_MUTABLE
(
int
,
node
->
Op
()
->
GetAttr
(
"op_role"
));
if
(
op_role
==
static_cast
<
int
>
(
OpRole
::
kForward
))
{
all_ops
[
OpRole
::
kForward
].
insert
(
node
);
}
else
if
(
op_role
==
static_cast
<
int
>
(
OpRole
::
kBackward
))
{
all_ops
[
OpRole
::
kBackward
].
insert
(
node
);
}
else
if
(
op_role
==
static_cast
<
int
>
(
OpRole
::
kOptimize
))
{
all_ops
[
OpRole
::
kOptimize
].
insert
(
node
);
}
else
if
(
op_role
==
static_cast
<
int
>
(
OpRole
::
kRPC
))
{
}
else
if
(
op_role
==
static_cast
<
int
>
(
OpRole
::
kDist
))
{
}
else
if
(
op_role
==
static_cast
<
int
>
(
OpRole
::
kLRSched
))
{
}
else
if
(
op_role
==
static_cast
<
int
>
(
OpRole
::
kLoss
))
{
all_ops
[
OpRole
::
kLoss
].
insert
(
node
);
}
else
if
(
op_role
==
static_cast
<
int
>
(
OpRole
::
kNotSpecified
))
{
LOG
(
WARNING
)
<<
"Op: "
<<
node
->
Name
()
<<
" OpRole is NotSpecified "
;
}
}
std
::
unordered_set
<
ir
::
Node
*>
forward_vars
;
std
::
unordered_set
<
ir
::
Node
*>
backward_vars
;
std
::
unordered_set
<
ir
::
Node
*>
control_vars
;
// forward_vars
for
(
auto
&
nodes
:
std
::
array
<
std
::
unordered_set
<
ir
::
Node
*>
,
2
>
{
all_ops
[
OpRole
::
kForward
],
all_ops
[
OpRole
::
kLoss
]})
{
for
(
auto
*
node
:
nodes
)
{
for
(
auto
*
in_node
:
node
->
inputs
)
{
forward_vars
.
insert
(
in_node
);
}
for
(
auto
*
out_node
:
node
->
outputs
)
{
forward_vars
.
insert
(
out_node
);
}
}
}
// control_vars & backward_vars
for
(
auto
*
node
:
graph
->
Nodes
())
{
if
(
!
node
->
IsVar
())
{
continue
;
}
if
(
node
->
IsCtrlVar
())
{
control_vars
.
insert
(
node
);
}
for
(
auto
*
in_node
:
node
->
inputs
)
{
if
(
all_ops
[
OpRole
::
kOptimize
].
count
(
in_node
))
{
backward_vars
.
insert
(
node
);
}
}
}
// all removed node
std
::
unordered_set
<
ir
::
Node
*>
rm_nodes
;
for
(
auto
*
node
:
graph
->
Nodes
())
{
if
(
backward_vars
.
count
(
node
))
{
rm_nodes
.
insert
(
node
);
}
else
if
(
control_vars
.
count
(
node
))
{
rm_nodes
.
insert
(
node
);
}
else
if
(
all_ops
[
OpRole
::
kBackward
].
count
(
node
))
{
rm_nodes
.
insert
(
node
);
}
else
if
(
all_ops
[
OpRole
::
kForward
].
count
(
node
)
==
0
&&
all_ops
[
OpRole
::
kLoss
].
count
(
node
)
==
0
&&
forward_vars
.
count
(
node
)
==
0
)
{
rm_nodes
.
insert
(
node
);
}
else
if
(
node
->
Name
()
==
"feed"
||
node
->
Name
()
==
"fetch"
)
{
rm_nodes
.
insert
(
node
);
}
}
VLOG
(
10
)
<<
"Remove Node: "
;
for
(
auto
*
node
:
rm_nodes
)
{
// rm node releations
for
(
auto
*
node_in
:
node
->
inputs
)
{
for
(
size_t
i
=
0
;
i
<
node_in
->
outputs
.
size
();
++
i
)
{
if
(
node_in
->
outputs
[
i
]
==
node
)
{
node_in
->
outputs
.
erase
(
node_in
->
outputs
.
begin
()
+
i
);
break
;
}
}
}
for
(
auto
*
node_out
:
node
->
outputs
)
{
for
(
size_t
i
=
0
;
i
<
node_out
->
inputs
.
size
();
++
i
)
{
if
(
node_out
->
inputs
[
i
]
==
node
)
{
node_out
->
inputs
.
erase
(
node_out
->
inputs
.
begin
()
+
i
);
break
;
}
}
}
VLOG
(
10
)
<<
"
\t
"
<<
node
->
Name
();
graph
->
RemoveNode
(
node
);
}
VLOG
(
10
)
<<
"Post Graph: "
;
VLOG
(
10
)
<<
DebugString
(
graph
);
VLOG
(
10
)
<<
"leave ForwardGraphExtractPass::ApplyImpl"
;
}
}
// namespace ir
}
// namespace framework
}
// namespace paddle
REGISTER_PASS
(
forward_graph_extract_pass
,
paddle
::
framework
::
ir
::
ForwardGraphExtractPass
);
paddle/fluid/framework/ir/ipu/forward_graph_extract_pass.h
0 → 100644
浏览文件 @
8adaa0f0
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/ipu/ipu_pass_base.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
class
ForwardGraphExtractPass
:
public
IPUPassBase
{
protected:
void
ApplyImpl
(
ir
::
Graph
*
graph
)
const
override
;
};
}
// namespace ir
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/ir/ipu/infer_shape_pass.cc
0 → 100644
浏览文件 @
8adaa0f0
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/ipu/infer_shape_pass.h"
#include "paddle/fluid/platform/device/ipu/ipu_backend.h"
#include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/variable_helper.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
void
InferShapePass
::
ApplyImpl
(
ir
::
Graph
*
graph
)
const
{
VLOG
(
10
)
<<
"enter InferShapePass::ApplyImpl"
;
VLOG
(
10
)
<<
"Raw Graph: "
;
VLOG
(
10
)
<<
DebugString
(
graph
);
std
::
shared_ptr
<
platform
::
ipu
::
IpuBackend
>
ipu_backend
=
platform
::
ipu
::
IpuBackend
::
GetInstance
();
auto
batch_size
=
ipu_backend
->
GetIpuStrategy
()
->
batch_size
;
auto
feed_list
=
Get
<
std
::
vector
<
std
::
string
>>
(
"feed_list"
);
for
(
auto
node
:
graph
->
Nodes
())
{
if
(
!
node
->
IsVar
())
{
continue
;
}
bool
is_feed
=
std
::
find
(
feed_list
.
begin
(),
feed_list
.
end
(),
node
->
Name
())
!=
feed_list
.
end
();
if
(
is_feed
)
{
auto
input_shape
=
node
->
Var
()
->
GetShape
();
if
(
input_shape
[
0
]
<=
-
1
)
{
input_shape
[
0
]
=
batch_size
;
node
->
Var
()
->
SetShape
(
input_shape
);
}
// int64->int32
if
(
node
->
Var
()
->
GetDataType
()
==
proto
::
VarType
::
INT64
)
{
node
->
Var
()
->
SetDataType
(
proto
::
VarType
::
INT32
);
}
}
}
// temp scope for shape inference
std
::
shared_ptr
<
paddle
::
framework
::
Scope
>
scope
(
new
paddle
::
framework
::
Scope
());
for
(
auto
node
:
graph
->
Nodes
())
{
if
(
!
node
->
IsVar
())
{
continue
;
}
auto
var_desc
=
node
->
Var
();
auto
*
ptr
=
scope
->
Var
(
var_desc
->
Name
());
paddle
::
framework
::
InitializeVariable
(
ptr
,
var_desc
->
GetType
());
auto
tensor
=
ptr
->
GetMutable
<
paddle
::
framework
::
LoDTensor
>
();
tensor
->
Resize
(
paddle
::
framework
::
make_ddim
(
var_desc
->
GetShape
()));
}
// infer shape
auto
nodes
=
ir
::
TopologySortOperations
(
*
graph
);
for
(
auto
node
:
nodes
)
{
auto
op_desc
=
node
->
Op
();
auto
op
=
paddle
::
framework
::
OpRegistry
::
CreateOp
(
*
op_desc
);
paddle
::
framework
::
RuntimeContext
ctx
(
op
->
Inputs
(),
op
->
Outputs
(),
*
scope
);
op
->
RuntimeInferShape
(
*
scope
,
paddle
::
platform
::
CPUPlace
(),
ctx
);
for
(
auto
it
=
ctx
.
outputs
.
begin
();
it
!=
ctx
.
outputs
.
end
();
it
++
)
{
for
(
int
i
=
0
;
i
<
it
->
second
.
size
();
i
++
)
{
auto
output_name
=
op_desc
->
Output
(
it
->
first
)[
i
];
auto
dim
=
it
->
second
[
i
]
->
GetMutable
<
paddle
::
framework
::
LoDTensor
>
()
->
dims
();
auto
new_shape
=
paddle
::
framework
::
vectorize
(
dim
);
for
(
auto
output_node
:
node
->
outputs
)
{
if
(
output_node
->
Name
()
==
output_name
)
{
output_node
->
Var
()
->
SetShape
(
new_shape
);
}
}
}
}
}
// release the temp scope
scope
.
reset
();
VLOG
(
10
)
<<
"Post Graph: "
;
VLOG
(
10
)
<<
DebugString
(
graph
);
VLOG
(
10
)
<<
"leave InferShapePass::ApplyImpl"
;
}
}
// namespace ir
}
// namespace framework
}
// namespace paddle
REGISTER_PASS
(
infer_shape_pass
,
paddle
::
framework
::
ir
::
InferShapePass
)
.
RequirePassAttr
(
"feed_list"
);
paddle/fluid/framework/ir/ipu/infer_shape_pass.h
0 → 100644
浏览文件 @
8adaa0f0
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/ir/ipu/ipu_pass_base.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
class
InferShapePass
:
public
IPUPassBase
{
protected:
void
ApplyImpl
(
ir
::
Graph
*
graph
)
const
override
;
};
}
// namespace ir
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/ir/ipu/inference_postprocess_pass.cc
0 → 100644
浏览文件 @
8adaa0f0
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/ipu/inference_postprocess_pass.h"
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
#include "paddle/fluid/platform/device/ipu/ipu_backend.h"
#include "paddle/fluid/platform/device/ipu/ipu_strategy.h"
#include "paddle/fluid/platform/enforce.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
void
InferencePostprocessPass
::
ApplyImpl
(
ir
::
Graph
*
graph
)
const
{
VLOG
(
10
)
<<
"enter InferencePostprocessPass::ApplyImpl"
;
std
::
vector
<
std
::
string
>
feed_list
;
feed_list
=
Get
<
std
::
vector
<
std
::
string
>>
(
"feed_list"
);
std
::
vector
<
std
::
string
>
fetch_list
;
fetch_list
=
Get
<
std
::
vector
<
std
::
string
>>
(
"fetch_list"
);
auto
*
feed_var
=
new
paddle
::
framework
::
VarDesc
(
"feed"
);
feed_var
->
SetType
(
proto
::
VarType
::
FEED_MINIBATCH
);
auto
*
feed_var_node
=
graph
->
CreateVarNode
(
feed_var
);
auto
*
fetch_var
=
new
paddle
::
framework
::
VarDesc
(
"fetch"
);
fetch_var
->
SetType
(
proto
::
VarType
::
FETCH_LIST
);
auto
*
fetch_var_node
=
graph
->
CreateVarNode
(
fetch_var
);
for
(
int
i
=
0
;
i
<
feed_list
.
size
();
i
++
)
{
for
(
auto
node
:
graph
->
Nodes
())
{
if
(
node
->
Name
()
==
feed_list
[
i
])
{
auto
*
op
=
new
paddle
::
framework
::
OpDesc
();
op
->
SetType
(
"feed"
);
op
->
SetInput
(
"X"
,
{
"feed"
});
op
->
SetOutput
(
"Out"
,
{
node
->
Name
()});
op
->
SetAttr
(
"col"
,
i
);
auto
*
op_node
=
graph
->
CreateOpNode
(
op
);
node
->
inputs
.
push_back
(
op_node
);
op_node
->
outputs
.
push_back
(
node
);
feed_var_node
->
outputs
.
push_back
(
op_node
);
op_node
->
inputs
.
push_back
(
feed_var_node
);
break
;
}
}
}
for
(
int
i
=
0
;
i
<
fetch_list
.
size
();
i
++
)
{
for
(
auto
node
:
graph
->
Nodes
())
{
if
(
node
->
Name
()
==
fetch_list
[
i
])
{
auto
*
op
=
new
paddle
::
framework
::
OpDesc
();
op
->
SetType
(
"fetch"
);
op
->
SetInput
(
"X"
,
{
node
->
Name
()});
op
->
SetOutput
(
"Out"
,
{
"fetch"
});
op
->
SetAttr
(
"col"
,
i
);
auto
*
op_node
=
graph
->
CreateOpNode
(
op
);
node
->
outputs
.
push_back
(
op_node
);
op_node
->
inputs
.
push_back
(
node
);
fetch_var_node
->
inputs
.
push_back
(
op_node
);
op_node
->
outputs
.
push_back
(
fetch_var_node
);
break
;
}
}
}
VLOG
(
10
)
<<
"leave InferencePostprocessPass::ApplyImpl"
;
}
}
// namespace ir
}
// namespace framework
}
// namespace paddle
REGISTER_PASS
(
inference_postprocess_pass
,
paddle
::
framework
::
ir
::
InferencePostprocessPass
)
.
RequirePassAttr
(
"feed_list"
)
.
RequirePassAttr
(
"fetch_list"
);
paddle/fluid/framework/ir/ipu/inference_postprocess_pass.h
0 → 100644
浏览文件 @
8adaa0f0
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/ir/ipu/ipu_pass_base.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
class
InferencePostprocessPass
:
public
IPUPassBase
{
protected:
void
ApplyImpl
(
ir
::
Graph
*
graph
)
const
override
;
};
}
// namespace ir
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/ir/ipu/inference_process_pass.cc
0 → 100644
浏览文件 @
8adaa0f0
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/ipu/inference_process_pass.h"
#include "paddle/fluid/platform/device/ipu/ipu_backend.h"
#include "paddle/fluid/platform/device/ipu/ipu_strategy.h"
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
#include "paddle/fluid/platform/enforce.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
void
InferenceProcessPass
::
ApplyImpl
(
ir
::
Graph
*
graph
)
const
{
VLOG
(
10
)
<<
"enter InferenceProcessPass::ApplyImpl"
;
// Get a new instance of ipu_backend
std
::
shared_ptr
<
platform
::
ipu
::
IpuBackend
>
ipu_backend
=
platform
::
ipu
::
IpuBackend
::
GetNewInstance
();
// Set scope
auto
&
scope
=
graph
->
Get
<
Scope
>
(
kParamScopeAttr
);
ipu_backend
->
SetScope
(
scope
);
// Set ipu_strategy
static
std
::
shared_ptr
<
platform
::
ipu
::
IpuStrategy
>
ipu_strategy_instance_
(
new
platform
::
ipu
::
IpuStrategy
());
ipu_strategy_instance_
->
is_training
=
false
;
auto
num_ipus
=
graph
->
Get
<
int
>
(
"num_ipus"
);
ipu_strategy_instance_
->
num_ipus
=
num_ipus
;
if
(
num_ipus
>
1
)
{
ipu_strategy_instance_
->
popart_options_
.
virtualGraphMode
=
platform
::
ipu
::
VirtualGraphMode
::
Manual
;
}
else
{
ipu_strategy_instance_
->
popart_options_
.
virtualGraphMode
=
platform
::
ipu
::
VirtualGraphMode
::
Off
;
}
auto
enable_pipelining
=
graph
->
Get
<
bool
>
(
"enable_pipelining"
);
ipu_strategy_instance_
->
popart_options_
.
enablePipelining
=
enable_pipelining
;
if
(
enable_pipelining
)
{
auto
batches_per_step
=
graph
->
Get
<
int
>
(
"batches_per_step"
);
PADDLE_ENFORCE_GE
(
batches_per_step
,
num_ipus
,
platform
::
errors
::
InvalidArgument
(
"Batched per step should be equal or "
"greater than the number of IPUs"
));
ipu_strategy_instance_
->
batches_per_step
=
batches_per_step
;
}
ipu_strategy_instance_
->
batch_size
=
graph
->
Get
<
int
>
(
"batch_size"
);
ipu_strategy_instance_
->
need_avg_shard
=
graph
->
Get
<
bool
>
(
"need_avg_shard"
);
ipu_backend
->
SetIpuStrategy
(
*
(
ipu_strategy_instance_
.
get
()));
// Get feed_list and fetch list
std
::
vector
<
std
::
string
>
feed_list
=
{};
std
::
vector
<
std
::
string
>
fetch_list
=
{};
for
(
auto
node
:
graph
->
Nodes
())
{
if
(
node
->
Name
()
==
"feed"
)
{
if
(
node
->
IsOp
())
{
feed_list
.
push_back
(
""
);
}
}
else
if
(
node
->
Name
()
==
"fetch"
)
{
if
(
node
->
IsOp
())
{
fetch_list
.
push_back
(
""
);
}
}
}
for
(
auto
node
:
graph
->
Nodes
())
{
if
(
node
->
Name
()
==
"feed"
)
{
if
(
node
->
IsOp
())
{
feed_list
[
BOOST_GET_CONST
(
int
,
node
->
Op
()
->
GetAttr
(
"col"
))]
=
node
->
outputs
[
0
]
->
Name
();
}
}
else
if
(
node
->
Name
()
==
"fetch"
)
{
if
(
node
->
IsOp
())
{
fetch_list
[
BOOST_GET_CONST
(
int
,
node
->
Op
()
->
GetAttr
(
"col"
))]
=
node
->
inputs
[
0
]
->
Name
();
}
}
}
// Run passes
std
::
vector
<
std
::
string
>
graph_pass
=
{
"forward_graph_extract_pass"
,
"infer_shape_pass"
,
"avg_shard_pass"
,
"popart_canonicalization_pass"
};
std
::
vector
<
std
::
string
>
compile_pass
=
{
"ipu_inplace_pass"
,
"ipu_graph_builder_pass"
,
"ipu_runtime_replacer_pass"
,
"inference_postprocess_pass"
};
for
(
auto
pass_name
:
graph_pass
)
{
auto
pass
=
PassRegistry
::
Instance
().
Get
(
pass_name
);
if
(
pass_name
==
"infer_shape_pass"
)
{
pass
->
Set
(
"feed_list"
,
new
std
::
vector
<
std
::
string
>
(
feed_list
.
begin
(),
feed_list
.
end
()));
}
pass
->
Apply
(
graph
);
}
for
(
auto
pass_name
:
compile_pass
)
{
auto
pass
=
PassRegistry
::
Instance
().
Get
(
pass_name
);
pass
->
Set
(
"feed_list"
,
new
std
::
vector
<
std
::
string
>
(
feed_list
.
begin
(),
feed_list
.
end
()));
pass
->
Set
(
"fetch_list"
,
new
std
::
vector
<
std
::
string
>
(
fetch_list
.
begin
(),
fetch_list
.
end
()));
pass
->
Apply
(
graph
);
}
VLOG
(
10
)
<<
"leave InferenceProcessPass::ApplyImpl"
;
}
}
// namespace ir
}
// namespace framework
}
// namespace paddle
REGISTER_PASS
(
inference_process_pass
,
paddle
::
framework
::
ir
::
InferenceProcessPass
);
paddle/fluid/framework/ir/ipu/inference_process_pass.h
0 → 100644
浏览文件 @
8adaa0f0
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/ir/ipu/ipu_pass_base.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
class
InferenceProcessPass
:
public
IPUPassBase
{
protected:
void
ApplyImpl
(
ir
::
Graph
*
graph
)
const
override
;
};
}
// namespace ir
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/ir/ipu/ipu_graph_builder_pass.cc
0 → 100644
浏览文件 @
8adaa0f0
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/ipu/ipu_graph_builder_pass.h"
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
#include "paddle/fluid/platform/device/ipu/ipu_backend.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
void
IpuGraphBuilderPass
::
ApplyImpl
(
ir
::
Graph
*
graph
)
const
{
VLOG
(
10
)
<<
"enter IpuGraphBuilderPass::ApplyImpl"
;
VLOG
(
10
)
<<
"Raw Graph: "
;
VLOG
(
10
)
<<
DebugString
(
graph
);
std
::
vector
<
std
::
string
>
feed_list
;
feed_list
=
Get
<
std
::
vector
<
std
::
string
>>
(
"feed_list"
);
std
::
vector
<
std
::
string
>
fetch_list
;
fetch_list
=
Get
<
std
::
vector
<
std
::
string
>>
(
"fetch_list"
);
std
::
shared_ptr
<
platform
::
ipu
::
IpuBackend
>
ipu_backend
=
platform
::
ipu
::
IpuBackend
::
GetInstance
();
ipu_backend
->
Compile
(
graph
,
feed_list
,
fetch_list
);
VLOG
(
10
)
<<
"Post Graph: "
;
VLOG
(
10
)
<<
DebugString
(
graph
);
VLOG
(
10
)
<<
"leave IpuGraphBuilderPass::ApplyImpl"
;
}
}
// namespace ir
}
// namespace framework
}
// namespace paddle
REGISTER_PASS
(
ipu_graph_builder_pass
,
paddle
::
framework
::
ir
::
IpuGraphBuilderPass
)
.
RequirePassAttr
(
"feed_list"
)
.
RequirePassAttr
(
"fetch_list"
);
paddle/fluid/framework/ir/ipu/ipu_graph_builder_pass.h
0 → 100644
浏览文件 @
8adaa0f0
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/ipu/ipu_pass_base.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
class
IpuGraphBuilderPass
:
public
IPUPassBase
{
protected:
void
ApplyImpl
(
ir
::
Graph
*
graph
)
const
override
;
};
}
// namespace ir
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/ir/ipu/ipu_inplace_pass.cc
0 → 100644
浏览文件 @
8adaa0f0
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/ipu/ipu_inplace_pass.h"
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
std
::
string
GenerateVarName
(
Node
*
node
)
{
return
node
->
Name
()
+
"_"
+
std
::
to_string
(
node
->
id
());
}
void
IpuInplacePass
::
ApplyImpl
(
ir
::
Graph
*
graph
)
const
{
// use this pass after forward_graph_extract_pass
// raise error if the inplaced var both in feed_list & fetch_list
VLOG
(
10
)
<<
"enter IpuInplacePass::ApplyImpl"
;
VLOG
(
10
)
<<
"Raw Graph: "
;
VLOG
(
10
)
<<
DebugString
(
graph
);
std
::
vector
<
std
::
string
>
feed_list
;
feed_list
=
Get
<
std
::
vector
<
std
::
string
>>
(
"feed_list"
);
std
::
vector
<
std
::
string
>
fetch_list
;
fetch_list
=
Get
<
std
::
vector
<
std
::
string
>>
(
"fetch_list"
);
std
::
map
<
std
::
string
,
int
>
var_name
;
for
(
auto
*
node
:
graph
->
Nodes
())
{
if
(
node
->
IsVar
())
{
if
(
var_name
.
find
(
node
->
Name
())
==
var_name
.
end
())
{
var_name
.
emplace
(
node
->
Name
(),
1
);
}
else
{
var_name
[
node
->
Name
()]
++
;
}
}
}
for
(
auto
*
node
:
graph
->
Nodes
())
{
if
(
node
->
IsVar
())
{
if
(
var_name
[
node
->
Name
()]
>
1
)
{
auto
is_feed
=
(
std
::
find
(
feed_list
.
begin
(),
feed_list
.
end
(),
node
->
Name
())
!=
feed_list
.
end
())
&&
(
node
->
inputs
.
size
()
==
0
);
auto
is_fetch
=
(
std
::
find
(
fetch_list
.
begin
(),
fetch_list
.
end
(),
node
->
Name
())
!=
fetch_list
.
end
())
&&
(
node
->
outputs
.
size
()
==
0
);
if
(
!
is_feed
&&
!
is_fetch
&&
!
node
->
Var
()
->
Persistable
())
{
auto
old_name
=
node
->
Name
();
auto
new_name
=
GenerateVarName
(
node
);
node
->
RenameVar
(
new_name
);
for
(
auto
*
op_in
:
node
->
inputs
)
{
op_in
->
Op
()
->
RenameOutput
(
old_name
,
new_name
);
}
for
(
auto
*
op_out
:
node
->
outputs
)
{
op_out
->
Op
()
->
RenameInput
(
old_name
,
new_name
);
}
}
}
}
}
VLOG
(
10
)
<<
"Post Graph: "
;
VLOG
(
10
)
<<
DebugString
(
graph
);
VLOG
(
10
)
<<
"leave IpuInplacePass::ApplyImpl"
;
}
}
// namespace ir
}
// namespace framework
}
// namespace paddle
REGISTER_PASS
(
ipu_inplace_pass
,
paddle
::
framework
::
ir
::
IpuInplacePass
)
.
RequirePassAttr
(
"feed_list"
)
.
RequirePassAttr
(
"fetch_list"
);
paddle/fluid/framework/ir/ipu/ipu_inplace_pass.h
0 → 100644
浏览文件 @
8adaa0f0
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/ir/ipu/ipu_pass_base.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
class
IpuInplacePass
:
public
IPUPassBase
{
protected:
void
ApplyImpl
(
ir
::
Graph
*
graph
)
const
override
;
};
}
// namespace ir
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/ir/ipu/ipu_pass_base.cc
0 → 100644
浏览文件 @
8adaa0f0
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/ipu/ipu_pass_base.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
void
IPUPassBase
::
Init
(
const
std
::
string
&
repr
,
Graph
*
graph
)
const
{
repr_
=
repr
;
graph_
=
graph
;
}
}
// namespace ir
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/ir/ipu/ipu_pass_base.h
0 → 100644
浏览文件 @
8adaa0f0
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/scope.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
class
IPUPassBase
:
public
Pass
{
public:
void
Init
(
const
std
::
string
&
repr
,
Graph
*
graph
)
const
;
virtual
~
IPUPassBase
()
{}
protected:
mutable
Graph
*
graph_
;
mutable
std
::
string
repr_
;
};
}
// namespace ir
}
// namespace framework
}
// namespace paddle
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录