Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
8b30c1ec
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
8b30c1ec
编写于
12月 10, 2021
作者:
J
jianghaicheng
提交者:
GitHub
12月 10, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add popart_canonicalization p2 (#37965)
上级
8adaa0f0
变更
13
隐藏空白更改
内联
并排
Showing
13 changed file
with
893 addition
and
1 deletion
+893
-1
paddle/fluid/framework/ir/ipu/ipu_runtime_replacer_pass.cc
paddle/fluid/framework/ir/ipu/ipu_runtime_replacer_pass.cc
+97
-0
paddle/fluid/framework/ir/ipu/ipu_runtime_replacer_pass.h
paddle/fluid/framework/ir/ipu/ipu_runtime_replacer_pass.h
+31
-0
paddle/fluid/framework/ir/ipu/optimizer_extract_pass.cc
paddle/fluid/framework/ir/ipu/optimizer_extract_pass.cc
+91
-0
paddle/fluid/framework/ir/ipu/optimizer_extract_pass.h
paddle/fluid/framework/ir/ipu/optimizer_extract_pass.h
+31
-0
paddle/fluid/framework/ir/ipu/optimizer_state_align_pass.cc
paddle/fluid/framework/ir/ipu/optimizer_state_align_pass.cc
+79
-0
paddle/fluid/framework/ir/ipu/optimizer_state_align_pass.h
paddle/fluid/framework/ir/ipu/optimizer_state_align_pass.h
+36
-0
paddle/fluid/framework/ir/ipu/popart_canonicalization_pass.cc
...le/fluid/framework/ir/ipu/popart_canonicalization_pass.cc
+68
-0
paddle/fluid/framework/ir/ipu/popart_canonicalization_pass.h
paddle/fluid/framework/ir/ipu/popart_canonicalization_pass.h
+30
-0
paddle/fluid/platform/device/ipu/CMakeLists.txt
paddle/fluid/platform/device/ipu/CMakeLists.txt
+1
-1
paddle/fluid/platform/device/ipu/popart_canonicalization/activation_ops.cc
...form/device/ipu/popart_canonicalization/activation_ops.cc
+72
-0
paddle/fluid/platform/device/ipu/popart_canonicalization/canonicalization_utils.cc
...ice/ipu/popart_canonicalization/canonicalization_utils.cc
+185
-0
paddle/fluid/platform/device/ipu/popart_canonicalization/canonicalization_utils.h
...vice/ipu/popart_canonicalization/canonicalization_utils.h
+64
-0
paddle/fluid/platform/device/ipu/popart_canonicalization/elementwise_ops.cc
...orm/device/ipu/popart_canonicalization/elementwise_ops.cc
+108
-0
未找到文件。
paddle/fluid/framework/ir/ipu/ipu_runtime_replacer_pass.cc
0 → 100644
浏览文件 @
8b30c1ec
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/ipu/ipu_runtime_replacer_pass.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
void
IpuRuntimeReplacerPass
::
ApplyImpl
(
ir
::
Graph
*
graph
)
const
{
VLOG
(
10
)
<<
"enter IpuRuntimeReplacerPass::ApplyImpl"
;
VLOG
(
10
)
<<
"Raw Graph: "
;
VLOG
(
10
)
<<
DebugString
(
graph
);
std
::
vector
<
std
::
string
>
feed_list
;
feed_list
=
Get
<
std
::
vector
<
std
::
string
>>
(
"feed_list"
);
std
::
vector
<
std
::
string
>
fetch_list
;
fetch_list
=
Get
<
std
::
vector
<
std
::
string
>>
(
"fetch_list"
);
framework
::
OpDesc
ipu_rt_op_desc
;
ipu_rt_op_desc
.
SetType
(
"ipu_runtime"
);
ipu_rt_op_desc
.
SetInput
(
"FeedList"
,
feed_list
);
ipu_rt_op_desc
.
SetOutput
(
"FetchList"
,
fetch_list
);
ipu_rt_op_desc
.
Flush
();
// Create a new node for the ipu_runtime_op.
auto
*
ipu_rt_node
=
graph
->
CreateOpNode
(
&
ipu_rt_op_desc
);
for
(
auto
*
node
:
graph
->
Nodes
())
{
if
(
node
->
IsVar
())
{
for
(
auto
feed
:
feed_list
)
{
if
(
node
->
Name
()
==
feed
)
{
IR_NODE_LINK_TO
(
node
,
ipu_rt_node
);
}
}
for
(
auto
fetch
:
fetch_list
)
{
if
(
node
->
Name
()
==
fetch
)
{
IR_NODE_LINK_TO
(
ipu_rt_node
,
node
);
}
}
}
}
// set ipu_runtime_op dtype attr
if
(
fetch_list
.
size
()
==
1
)
{
for
(
auto
*
node
:
graph
->
Nodes
())
{
if
(
node
->
IsVar
())
{
for
(
auto
fetch
:
fetch_list
)
{
if
(
node
->
Name
()
==
fetch
)
{
ipu_rt_node
->
Op
()
->
SetAttr
(
"dtype"
,
node
->
Var
()
->
GetDataType
());
}
}
}
}
}
// Remove unneeded nodes.
std
::
unordered_set
<
const
Node
*>
marked_nodes
;
for
(
auto
*
node
:
graph
->
Nodes
())
{
if
(
node
->
IsOp
())
{
auto
*
op_desc
=
node
->
Op
();
if
(
op_desc
->
Type
()
!=
"ipu_runtime"
)
{
marked_nodes
.
insert
(
node
);
}
}
}
GraphSafeRemoveNodes
(
graph
,
marked_nodes
);
VLOG
(
10
)
<<
"Post Graph: "
;
VLOG
(
10
)
<<
DebugString
(
graph
);
VLOG
(
10
)
<<
"leave IpuRuntimeReplacerPass::ApplyImpl"
;
}
}
// namespace ir
}
// namespace framework
}
// namespace paddle
REGISTER_PASS
(
ipu_runtime_replacer_pass
,
paddle
::
framework
::
ir
::
IpuRuntimeReplacerPass
)
.
RequirePassAttr
(
"feed_list"
)
.
RequirePassAttr
(
"fetch_list"
);
paddle/fluid/framework/ir/ipu/ipu_runtime_replacer_pass.h
0 → 100644
浏览文件 @
8b30c1ec
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/ipu/ipu_pass_base.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
class
IpuRuntimeReplacerPass
:
public
IPUPassBase
{
protected:
void
ApplyImpl
(
ir
::
Graph
*
graph
)
const
override
;
};
}
// namespace ir
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/ir/ipu/optimizer_extract_pass.cc
0 → 100644
浏览文件 @
8b30c1ec
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/ipu/optimizer_extract_pass.h"
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
#include "paddle/fluid/platform/device/ipu/ipu_backend.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
void
IpuOptimizerExtractPass
::
ApplyImpl
(
ir
::
Graph
*
graph
)
const
{
VLOG
(
10
)
<<
"enter IpuOptimizerExtractPass::ApplyImpl"
;
VLOG
(
10
)
<<
"Raw Graph: "
;
VLOG
(
10
)
<<
DebugString
(
graph
);
auto
ipu_backend
=
paddle
::
platform
::
ipu
::
IpuBackend
::
GetInstance
();
for
(
auto
*
node
:
graph
->
Nodes
())
{
if
(
node
->
IsOp
()
&&
node
->
Op
())
{
int
op_role
=
BOOST_GET_CONST
(
int
,
node
->
Op
()
->
GetAttr
(
framework
::
OpProtoAndCheckerMaker
::
OpRoleAttrName
()));
// graph usually have multiple optimizer node for different parameter,
// and these node have the same type and attr value usually
if
((
op_role
==
static_cast
<
int
>
(
framework
::
OpRole
::
kOptimize
)))
{
ipu_backend
->
GetExecutor
().
SetOptimizerType
(
node
->
Op
()
->
Type
());
VLOG
(
10
)
<<
"found optimizer type: "
<<
node
->
Op
()
->
Type
();
for
(
const
std
::
string
&
attr_name
:
node
->
Op
()
->
AttrNames
())
{
auto
attr_type
=
node
->
Op
()
->
GetAttrType
(
attr_name
);
// with adam, attr are float
if
(
attr_type
==
proto
::
AttrType
::
FLOAT
)
{
auto
attr_value
=
BOOST_GET_CONST
(
float
,
node
->
Op
()
->
GetAttr
(
attr_name
));
ipu_backend
->
GetExecutor
().
SetOptimizerAttr
(
attr_name
,
attr_value
);
}
else
{
VLOG
(
10
)
<<
"Skip "
<<
attr_type
;
}
}
auto
lr_var_name
=
node
->
Op
()
->
Input
(
"LearningRate"
);
PADDLE_ENFORCE_EQ
(
lr_var_name
.
size
(),
1u
,
platform
::
errors
::
InvalidArgument
(
"In op(%s), find input(LearningRate) failed."
,
node
->
Op
()
->
Type
()));
ipu_backend
->
GetExecutor
().
SetLRVarName
(
lr_var_name
[
0
]);
}
if
((
op_role
==
static_cast
<
int
>
(
framework
::
OpRole
::
kLoss
)))
{
VLOG
(
10
)
<<
"found loss op type: "
<<
node
->
Op
()
->
Type
();
auto
outputs
=
node
->
Op
()
->
Outputs
();
PADDLE_ENFORCE_EQ
(
outputs
.
size
(),
1
,
platform
::
errors
::
InvalidArgument
(
"Can only support one loss key"
));
auto
losses_name
=
outputs
.
begin
()
->
second
;
PADDLE_ENFORCE_EQ
(
losses_name
.
size
(),
1
,
platform
::
errors
::
InvalidArgument
(
"Can only support one loss name"
));
ipu_backend
->
GetExecutor
().
SetLoss
(
losses_name
[
0
]);
}
}
}
VLOG
(
10
)
<<
"Post Graph: "
;
VLOG
(
10
)
<<
DebugString
(
graph
);
VLOG
(
10
)
<<
"leave IpuOptimizerExtractPass::ApplyImpl"
;
}
}
// namespace ir
}
// namespace framework
}
// namespace paddle
REGISTER_PASS
(
optimizer_extract_pass
,
paddle
::
framework
::
ir
::
IpuOptimizerExtractPass
);
paddle/fluid/framework/ir/ipu/optimizer_extract_pass.h
0 → 100644
浏览文件 @
8b30c1ec
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/ipu/ipu_pass_base.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
class
IpuOptimizerExtractPass
:
public
IPUPassBase
{
protected:
void
ApplyImpl
(
ir
::
Graph
*
graph
)
const
override
;
};
}
// namespace ir
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/ir/ipu/optimizer_state_align_pass.cc
0 → 100644
浏览文件 @
8b30c1ec
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/ipu/optimizer_state_align_pass.h"
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
#include "paddle/fluid/platform/device/ipu/common.h"
#include "paddle/fluid/platform/device/ipu/ipu_backend.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
using
paddle
::
platform
::
ipu
::
IpuBackend
;
using
framework
::
ir
::
Graph
;
using
framework
::
ir
::
Node
;
void
IpuOptimizerStateAlignPass
::
ApplyImpl
(
ir
::
Graph
*
graph
)
const
{
VLOG
(
10
)
<<
"enter IpuOptimizerStateAlignPass::ApplyImpl"
;
VLOG
(
10
)
<<
"Raw Graph: "
;
VLOG
(
10
)
<<
DebugString
(
graph
);
auto
ipu_backend
=
IpuBackend
::
GetInstance
();
const
auto
*
scope_
=
ipu_backend
->
GetScope
();
for
(
auto
*
node
:
graph
->
Nodes
())
{
if
(
node
->
IsOp
()
&&
node
->
Op
())
{
int
op_role
=
BOOST_GET_CONST
(
int
,
node
->
Op
()
->
GetAttr
(
framework
::
OpProtoAndCheckerMaker
::
OpRoleAttrName
()));
if
((
op_role
==
static_cast
<
int
>
(
framework
::
OpRole
::
kOptimize
)))
{
auto
inputs
=
node
->
Op
()
->
Inputs
();
if
(
inputs
.
count
(
platform
::
ipu
::
sBeta1Pow
))
{
auto
var
=
scope_
->
GetVar
(
inputs
.
at
(
platform
::
ipu
::
sBeta1Pow
)[
0
]);
auto
data
=
var
->
GetMutable
<
framework
::
LoDTensor
>
()
->
data
<
float
>
();
auto
beta
=
BOOST_GET_CONST
(
float
,
node
->
Op
()
->
GetAttr
(
platform
::
ipu
::
sBeta1
));
// ensure current save with beta1pow, rather than step.
// beta1pow = beta1 ^ (step + 1). Just set beta1pow because popart
// support single Step__
bool
save_with_beta1pow
=
(
data
[
0
]
<
1.0
f
)
&&
(
data
[
0
]
>
0.0
f
);
float
step
=
0
;
float
beta_acc
=
beta
;
while
(
beta_acc
>
data
[
0
]
&&
save_with_beta1pow
)
{
beta_acc
*=
beta
;
step
+=
1
;
}
if
(
save_with_beta1pow
)
{
data
[
0
]
=
step
;
}
}
}
}
}
VLOG
(
10
)
<<
"Post Graph: "
;
VLOG
(
10
)
<<
DebugString
(
graph
);
VLOG
(
10
)
<<
"leave IpuOptimizerStateAlignPass::ApplyImpl"
;
}
}
// namespace ir
}
// namespace framework
}
// namespace paddle
REGISTER_PASS
(
optimizer_state_align_pass
,
paddle
::
framework
::
ir
::
IpuOptimizerStateAlignPass
);
paddle/fluid/framework/ir/ipu/optimizer_state_align_pass.h
0 → 100644
浏览文件 @
8b30c1ec
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/ipu/ipu_pass_base.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
/*
* This pass should only affect optimizer that need bias correction,
* include Adam/Lamb.
*/
class
IpuOptimizerStateAlignPass
:
public
IPUPassBase
{
protected:
void
ApplyImpl
(
ir
::
Graph
*
graph
)
const
override
;
};
}
// namespace ir
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/ir/ipu/popart_canonicalization_pass.cc
0 → 100644
浏览文件 @
8b30c1ec
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/ipu/popart_canonicalization_pass.h"
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
#include "paddle/fluid/platform/device/ipu/popart_canonicalization/canonicalization_utils.h"
#include "paddle/fluid/platform/device/ipu/popart_canonicalization/post_canonicalization.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
using
framework
::
ir
::
Graph
;
using
framework
::
ir
::
Node
;
using
platform
::
ipu
::
SymbolHandler
;
void
PopartCanonicalizationPass
::
ApplyImpl
(
ir
::
Graph
*
graph
)
const
{
VLOG
(
10
)
<<
"enter PopartCanonicalizationPass::ApplyImpl"
;
VLOG
(
10
)
<<
"Raw Graph: "
;
VLOG
(
10
)
<<
DebugString
(
graph
);
auto
nodes
=
graph
->
Nodes
();
for
(
auto
*
node
:
nodes
)
{
if
(
!
node
->
IsOp
())
{
continue
;
}
auto
*
op
=
node
->
Op
();
auto
op_type
=
op
->
Type
();
ir
::
Node
*
new_node
=
nullptr
;
SymbolHandler
handler
=
platform
::
ipu
::
GetHandler
(
op_type
);
if
(
handler
)
{
VLOG
(
11
)
<<
"Raw Paddle Node:"
;
VLOG
(
11
)
<<
node
->
Op
()
->
Proto
()
->
DebugString
();
new_node
=
handler
(
graph
,
node
);
VLOG
(
11
)
<<
"Post Popart Node:"
;
VLOG
(
11
)
<<
new_node
->
Op
()
->
Proto
()
->
DebugString
();
platform
::
ipu
::
ClearNode
(
node
);
graph
->
RemoveNode
(
node
);
}
else
{
LOG
(
ERROR
)
<<
"Can not find OpHandler for op_type: "
<<
op_type
;
}
}
VLOG
(
10
)
<<
"Post Graph: "
;
VLOG
(
10
)
<<
DebugString
(
graph
);
VLOG
(
10
)
<<
"leave PopartCanonicalizationPass::ApplyImpl"
;
}
}
// namespace ir
}
// namespace framework
}
// namespace paddle
REGISTER_PASS
(
popart_canonicalization_pass
,
paddle
::
framework
::
ir
::
PopartCanonicalizationPass
);
paddle/fluid/framework/ir/ipu/popart_canonicalization_pass.h
0 → 100644
浏览文件 @
8b30c1ec
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/ir/ipu/ipu_pass_base.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
class
PopartCanonicalizationPass
:
public
IPUPassBase
{
protected:
void
ApplyImpl
(
ir
::
Graph
*
graph
)
const
override
;
};
}
// namespace ir
}
// namespace framework
}
// namespace paddle
paddle/fluid/platform/device/ipu/CMakeLists.txt
浏览文件 @
8b30c1ec
# IPU
IF
(
WITH_IPU
)
IF
(
WITH_IPU
)
FILE
(
GLOB POPART_CANONICALIZATION_SRC
${
PADDLE_SOURCE_DIR
}
/paddle/fluid/platform/device/ipu/popart_canonicalization/*.cc
)
cc_library
(
ipu_device SRCS device.cc DEPS enforce popart
)
cc_library
(
ipu_device SRCS device.cc DEPS enforce popart
)
cc_library
(
ipu_utils SRCS ipu_utils.cc DEPS memory framework_proto popart
)
cc_library
(
ipu_utils SRCS ipu_utils.cc DEPS memory framework_proto popart
)
cc_library
(
ipu_strategy SRCS ipu_strategy.cc DEPS popart graph framework_proto enforce
)
cc_library
(
ipu_strategy SRCS ipu_strategy.cc DEPS popart graph framework_proto enforce
)
...
...
paddle/fluid/platform/device/ipu/popart_canonicalization/activation_ops.cc
0 → 100644
浏览文件 @
8b30c1ec
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/platform/device/ipu/popart_canonicalization/canonicalization_utils.h"
#include "paddle/fluid/platform/device/ipu/popart_canonicalization/op_builder.h"
#include "paddle/fluid/platform/enforce.h"
namespace
paddle
{
namespace
platform
{
namespace
ipu
{
namespace
{
Node
*
activation_op_handler
(
Graph
*
graph
,
Node
*
node
,
const
std
::
string
&
type
)
{
auto
new_node
=
CreateBaseOp
(
graph
,
node
,
type
,
{
GetInputVarNode
(
"X"
,
node
)},
node
->
outputs
);
return
new_node
;
}
Node
*
relu_handler
(
Graph
*
graph
,
Node
*
node
)
{
return
activation_op_handler
(
graph
,
node
,
"popart_relu"
);
}
Node
*
tanh_handler
(
Graph
*
graph
,
Node
*
node
)
{
return
activation_op_handler
(
graph
,
node
,
"popart_tanh"
);
}
Node
*
log_handler
(
Graph
*
graph
,
Node
*
node
)
{
return
activation_op_handler
(
graph
,
node
,
"popart_log"
);
}
Node
*
sigmoid_handler
(
Graph
*
graph
,
Node
*
node
)
{
return
activation_op_handler
(
graph
,
node
,
"popart_sigmoid"
);
}
Node
*
sqrt_handler
(
Graph
*
graph
,
Node
*
node
)
{
return
activation_op_handler
(
graph
,
node
,
"popart_sqrt"
);
}
Node
*
gelu_handler
(
Graph
*
graph
,
Node
*
node
)
{
return
activation_op_handler
(
graph
,
node
,
"popart_gelu_v2"
);
}
Node
*
log_softmax_handler
(
Graph
*
graph
,
Node
*
node
)
{
auto
axis
=
BOOST_GET_CONST
(
int
,
node
->
Op
()
->
GetAttr
(
"axis"
));
auto
new_softmax
=
CreateSoftmaxOpset11
(
graph
,
node
,
node
->
inputs
,
{},
axis
);
return
CreateBaseOp
(
graph
,
node
,
"popart_log"
,
new_softmax
->
outputs
,
node
->
outputs
);
}
REGISTER_HANDLER
(
relu
,
relu_handler
);
REGISTER_HANDLER
(
tanh
,
tanh_handler
);
REGISTER_HANDLER
(
log
,
log_handler
);
REGISTER_HANDLER
(
sigmoid
,
sigmoid_handler
);
REGISTER_HANDLER
(
sqrt
,
sqrt_handler
);
REGISTER_HANDLER
(
gelu
,
gelu_handler
);
REGISTER_HANDLER
(
log_softmax
,
log_softmax_handler
);
}
// namespace
}
// namespace ipu
}
// namespace platform
}
// namespace paddle
paddle/fluid/platform/device/ipu/popart_canonicalization/canonicalization_utils.cc
0 → 100644
浏览文件 @
8b30c1ec
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/platform/device/ipu/popart_canonicalization/canonicalization_utils.h"
namespace
paddle
{
namespace
platform
{
namespace
ipu
{
// This avoids the static initialisation order fiasco,
std
::
unordered_map
<
std
::
string
,
SymbolHandler
>
&
SymbolHandlers
()
{
static
std
::
unordered_map
<
std
::
string
,
SymbolHandler
>
symbol_handlers
;
return
symbol_handlers
;
}
bool
RegisterHandler
(
const
std
::
string
&
symbol
,
const
SymbolHandler
&
handler
)
{
if
(
SymbolHandlers
().
count
(
symbol
)
!=
0
)
{
LOG
(
WARNING
)
<<
"Trying to register popart handler twice for operator: "
<<
symbol
;
return
false
;
}
bool
new_handler
=
SymbolHandlers
().
emplace
(
symbol
,
handler
).
second
;
return
new_handler
;
}
// Return a pointer to a handler if one is registered for this kind of node or
// an empty std::function otherwise.
SymbolHandler
GetHandler
(
const
std
::
string
&
kind
)
{
auto
it
=
SymbolHandlers
().
find
(
kind
);
if
(
it
!=
SymbolHandlers
().
end
())
{
return
it
->
second
;
}
return
{};
}
void
ConnectNodes
(
Node
*
first_node
,
Node
*
next_node
)
{
first_node
->
outputs
.
push_back
(
next_node
);
next_node
->
inputs
.
push_back
(
first_node
);
}
void
DisConnectNodes
(
Node
*
first_node
,
Node
*
next_node
)
{
auto
rm_by_value
=
[
&
](
std
::
vector
<
Node
*>
&
vec
,
Node
*
n
)
{
vec
.
erase
(
std
::
remove
(
vec
.
begin
(),
vec
.
end
(),
n
),
vec
.
end
());
};
rm_by_value
(
first_node
->
outputs
,
next_node
);
rm_by_value
(
next_node
->
inputs
,
first_node
);
rm_by_value
(
first_node
->
inputs
,
next_node
);
rm_by_value
(
next_node
->
outputs
,
first_node
);
}
void
ClearNode
(
Node
*
node
)
{
auto
rm_by_value
=
[
&
](
std
::
vector
<
Node
*>
&
vec
,
Node
*
n
)
{
vec
.
erase
(
std
::
remove
(
vec
.
begin
(),
vec
.
end
(),
n
),
vec
.
end
());
};
for
(
auto
*
node_in
:
node
->
inputs
)
{
rm_by_value
(
node_in
->
outputs
,
node
);
}
for
(
auto
*
node_out
:
node
->
outputs
)
{
rm_by_value
(
node_out
->
inputs
,
node
);
}
}
void
CopyOpAttr
(
const
std
::
string
&
attr_name
,
OpDesc
*
op
,
OpDesc
*
new_op
,
bool
override
)
{
if
(
new_op
->
HasAttr
(
attr_name
)
&&
!
override
)
{
return
;
}
if
(
op
->
HasAttr
(
attr_name
))
{
VLOG
(
10
)
<<
"Copying attr: "
<<
attr_name
<<
" from "
<<
op
->
Type
()
<<
" to "
<<
new_op
->
Type
();
new_op
->
SetAttr
(
attr_name
,
op
->
GetAttr
(
attr_name
));
new_op
->
Flush
();
}
}
const
int
VarType2OnnxDtype
(
const
int
type
)
{
auto
dtype
=
static_cast
<
framework
::
proto
::
VarType
::
Type
>
(
type
);
switch
(
dtype
)
{
case
framework
::
proto
::
VarType
::
BOOL
:
return
static_cast
<
int
>
(
ONNXDataType
::
BOOL
);
case
framework
::
proto
::
VarType
::
INT16
:
return
static_cast
<
int
>
(
ONNXDataType
::
INT16
);
case
framework
::
proto
::
VarType
::
INT32
:
return
static_cast
<
int
>
(
ONNXDataType
::
INT32
);
case
framework
::
proto
::
VarType
::
INT64
:
return
static_cast
<
int
>
(
ONNXDataType
::
INT64
);
case
framework
::
proto
::
VarType
::
FP16
:
return
static_cast
<
int
>
(
ONNXDataType
::
FLOAT16
);
case
framework
::
proto
::
VarType
::
FP32
:
return
static_cast
<
int
>
(
ONNXDataType
::
FLOAT
);
case
framework
::
proto
::
VarType
::
FP64
:
return
static_cast
<
int
>
(
ONNXDataType
::
DOUBLE
);
case
framework
::
proto
::
VarType
::
UINT8
:
return
static_cast
<
int
>
(
ONNXDataType
::
UINT8
);
case
framework
::
proto
::
VarType
::
INT8
:
return
static_cast
<
int
>
(
ONNXDataType
::
INT8
);
case
framework
::
proto
::
VarType
::
BF16
:
return
static_cast
<
int
>
(
ONNXDataType
::
BFLOAT16
);
case
framework
::
proto
::
VarType
::
COMPLEX64
:
return
static_cast
<
int
>
(
ONNXDataType
::
COMPLEX64
);
case
framework
::
proto
::
VarType
::
COMPLEX128
:
return
static_cast
<
int
>
(
ONNXDataType
::
COMPLEX128
);
default:
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
"Unsupported data type: %d."
,
dtype
));
}
}
const
std
::
string
VarType2PopStr
(
const
int
type
)
{
auto
dtype
=
static_cast
<
framework
::
proto
::
VarType
::
Type
>
(
type
);
switch
(
dtype
)
{
case
framework
::
proto
::
VarType
::
UINT8
:
return
"UINT8"
;
case
framework
::
proto
::
VarType
::
INT8
:
return
"INT8"
;
case
framework
::
proto
::
VarType
::
INT16
:
return
"INT16"
;
case
framework
::
proto
::
VarType
::
INT32
:
return
"INT32"
;
case
framework
::
proto
::
VarType
::
INT64
:
return
"INT64"
;
case
framework
::
proto
::
VarType
::
BOOL
:
return
"BOOL"
;
case
framework
::
proto
::
VarType
::
FP64
:
return
"DOUBLE"
;
case
framework
::
proto
::
VarType
::
FP32
:
return
"FLOAT"
;
case
framework
::
proto
::
VarType
::
FP16
:
return
"FLOAT16"
;
default:
PADDLE_THROW
(
paddle
::
platform
::
errors
::
Unavailable
(
"Unsupported data type."
));
}
}
Node
*
GetInputVarNode
(
const
std
::
string
&
input_name
,
const
Node
*
op_node
,
const
int
id
)
{
auto
var_name
=
op_node
->
Op
()
->
Input
(
input_name
).
at
(
id
);
return
GetInputVarNodeByVarName
(
var_name
,
op_node
);
}
Node
*
GetOutputVarNode
(
const
std
::
string
&
output_name
,
const
Node
*
op_node
,
const
int
id
)
{
auto
var_name
=
op_node
->
Op
()
->
Output
(
output_name
).
at
(
id
);
return
GetOutputVarNodeByVarName
(
var_name
,
op_node
);
}
Node
*
GetInputVarNodeByVarName
(
const
std
::
string
&
var_name
,
const
Node
*
op_node
)
{
for
(
auto
*
var
:
op_node
->
inputs
)
{
if
(
var
->
Name
()
==
var_name
)
{
return
var
;
}
}
return
nullptr
;
}
Node
*
GetOutputVarNodeByVarName
(
const
std
::
string
&
var_name
,
const
Node
*
op_node
)
{
for
(
auto
*
var
:
op_node
->
outputs
)
{
if
(
var
->
Name
()
==
var_name
)
{
return
var
;
}
}
return
nullptr
;
}
const
bool
is_float_equal
(
float
a
,
float
b
,
float
eps
)
{
return
std
::
fabs
(
a
-
b
)
<=
eps
;
}
}
// namespace ipu
}
// namespace platform
}
// namespace paddle
paddle/fluid/platform/device/ipu/popart_canonicalization/canonicalization_utils.h
0 → 100644
浏览文件 @
8b30c1ec
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/node.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/platform/device/ipu/ipu_utils.h"
namespace
paddle
{
namespace
platform
{
namespace
ipu
{
using
framework
::
ir
::
Graph
;
using
framework
::
ir
::
Node
;
using
framework
::
OpDesc
;
#define REGISTER_HANDLER(name, func) \
static bool __UNUSED_##name = \
paddle::platform::ipu::RegisterHandler(#name, func)
using
SymbolHandler
=
std
::
function
<
Node
*
(
Graph
*
,
Node
*
)
>
;
std
::
unordered_map
<
std
::
string
,
SymbolHandler
>
&
SymbolHandlers
();
bool
RegisterHandler
(
const
std
::
string
&
,
const
SymbolHandler
&
);
SymbolHandler
GetHandler
(
const
std
::
string
&
);
void
ConnectNodes
(
Node
*
first_node
,
Node
*
next_node
);
void
DisConnectNodes
(
Node
*
first_node
,
Node
*
next_node
);
void
ClearNode
(
Node
*
node
);
void
CopyOpAttr
(
const
std
::
string
&
attr_name
,
OpDesc
*
op
,
OpDesc
*
new_op
,
bool
override
=
false
);
const
int
VarType2OnnxDtype
(
const
int
type
);
const
std
::
string
VarType2PopStr
(
const
int
type
);
Node
*
GetInputVarNode
(
const
std
::
string
&
input_name
,
const
Node
*
op_node
,
const
int
id
=
0
);
Node
*
GetOutputVarNode
(
const
std
::
string
&
output_name
,
const
Node
*
op_node
,
const
int
id
=
0
);
Node
*
GetInputVarNodeByVarName
(
const
std
::
string
&
var_name
,
const
Node
*
op_node
);
Node
*
GetOutputVarNodeByVarName
(
const
std
::
string
&
var_name
,
const
Node
*
op_node
);
const
bool
is_float_equal
(
float
a
,
float
b
,
float
eps
=
1e-8
);
}
// namespace ipu
}
// namespace platform
}
// namespace paddle
paddle/fluid/platform/device/ipu/popart_canonicalization/elementwise_ops.cc
0 → 100644
浏览文件 @
8b30c1ec
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/platform/device/ipu/popart_canonicalization/canonicalization_utils.h"
#include "paddle/fluid/platform/device/ipu/popart_canonicalization/op_builder.h"
#include "paddle/fluid/platform/enforce.h"
namespace
paddle
{
namespace
platform
{
namespace
ipu
{
namespace
{
Node
*
elementwise_op_handler
(
Graph
*
graph
,
Node
*
node
,
const
std
::
string
&
type
)
{
auto
*
op
=
node
->
Op
();
auto
x_shape
=
GetInputVarNode
(
"X"
,
node
)
->
Var
()
->
GetShape
();
int64_t
x_rank
=
x_shape
.
size
();
auto
y_shape
=
GetInputVarNode
(
"Y"
,
node
)
->
Var
()
->
GetShape
();
int64_t
y_rank
=
y_shape
.
size
();
auto
axis
=
BOOST_GET_CONST
(
int
,
op
->
GetAttr
(
"axis"
));
if
(
axis
==
-
1
||
axis
==
x_rank
-
1
||
x_rank
==
y_rank
)
{
auto
new_node
=
CreateBaseOp
(
graph
,
node
,
type
,
{
GetInputVarNode
(
"X"
,
node
),
GetInputVarNode
(
"Y"
,
node
)},
node
->
outputs
);
return
new_node
;
}
else
{
auto
y_new_shape
=
std
::
vector
<
int64_t
>
(
x_rank
,
1
);
for
(
int
i
=
axis
;
i
<
axis
+
y_rank
;
++
i
)
{
y_new_shape
[
i
]
=
y_shape
[
i
-
axis
];
}
auto
attrs
=
AttributeMap
{
{
"value"
,
y_new_shape
},
{
"dims"
,
std
::
vector
<
int64_t
>
{
x_rank
}},
{
"dtype"
,
ONNXDataType
::
INT64
},
};
// constant
auto
new_node_const
=
CreateConst
(
graph
,
node
,
{},
{},
attrs
);
// reshape
auto
new_node_reshape
=
CreateBaseOp
(
graph
,
node
,
"popart_reshape"
,
{
GetInputVarNode
(
"Y"
,
node
),
new_node_const
->
outputs
[
0
]},
{});
// elementwise_op
auto
new_node
=
CreateBaseOp
(
graph
,
node
,
type
,
{
GetInputVarNode
(
"X"
,
node
),
new_node_reshape
->
outputs
[
0
]},
node
->
outputs
);
return
new_node
;
}
}
Node
*
elementwise_add_handler
(
Graph
*
graph
,
Node
*
node
)
{
return
elementwise_op_handler
(
graph
,
node
,
"popart_add"
);
}
Node
*
elementwise_sub_handler
(
Graph
*
graph
,
Node
*
node
)
{
return
elementwise_op_handler
(
graph
,
node
,
"popart_sub"
);
}
Node
*
elementwise_div_handler
(
Graph
*
graph
,
Node
*
node
)
{
return
elementwise_op_handler
(
graph
,
node
,
"popart_div"
);
}
Node
*
elementwise_mul_handler
(
Graph
*
graph
,
Node
*
node
)
{
return
elementwise_op_handler
(
graph
,
node
,
"popart_mul"
);
}
Node
*
elementwise_min_handler
(
Graph
*
graph
,
Node
*
node
)
{
return
elementwise_op_handler
(
graph
,
node
,
"popart_min"
);
}
Node
*
elementwise_max_handler
(
Graph
*
graph
,
Node
*
node
)
{
return
elementwise_op_handler
(
graph
,
node
,
"popart_max"
);
}
Node
*
elementwise_pow_handler
(
Graph
*
graph
,
Node
*
node
)
{
return
elementwise_op_handler
(
graph
,
node
,
"popart_pow"
);
}
Node
*
elementwise_mod_handler
(
Graph
*
graph
,
Node
*
node
)
{
return
elementwise_op_handler
(
graph
,
node
,
"popart_mod"
);
}
REGISTER_HANDLER
(
elementwise_add
,
elementwise_add_handler
);
REGISTER_HANDLER
(
elementwise_sub
,
elementwise_sub_handler
);
REGISTER_HANDLER
(
elementwise_div
,
elementwise_div_handler
);
REGISTER_HANDLER
(
elementwise_mul
,
elementwise_mul_handler
);
REGISTER_HANDLER
(
elementwise_min
,
elementwise_min_handler
);
REGISTER_HANDLER
(
elementwise_max
,
elementwise_max_handler
);
REGISTER_HANDLER
(
elementwise_pow
,
elementwise_pow_handler
);
REGISTER_HANDLER
(
elementwise_mod
,
elementwise_mod_handler
);
}
// namespace
}
// namespace ipu
}
// namespace platform
}
// namespace paddle
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录