Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
8b30c1ec
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
8b30c1ec
编写于
12月 10, 2021
作者:
J
jianghaicheng
提交者:
GitHub
12月 10, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add popart_canonicalization p2 (#37965)
上级
8adaa0f0
变更
13
隐藏空白更改
内联
并排
Showing
13 changed file
with
893 addition
and
1 deletion
+893
-1
paddle/fluid/framework/ir/ipu/ipu_runtime_replacer_pass.cc
paddle/fluid/framework/ir/ipu/ipu_runtime_replacer_pass.cc
+97
-0
paddle/fluid/framework/ir/ipu/ipu_runtime_replacer_pass.h
paddle/fluid/framework/ir/ipu/ipu_runtime_replacer_pass.h
+31
-0
paddle/fluid/framework/ir/ipu/optimizer_extract_pass.cc
paddle/fluid/framework/ir/ipu/optimizer_extract_pass.cc
+91
-0
paddle/fluid/framework/ir/ipu/optimizer_extract_pass.h
paddle/fluid/framework/ir/ipu/optimizer_extract_pass.h
+31
-0
paddle/fluid/framework/ir/ipu/optimizer_state_align_pass.cc
paddle/fluid/framework/ir/ipu/optimizer_state_align_pass.cc
+79
-0
paddle/fluid/framework/ir/ipu/optimizer_state_align_pass.h
paddle/fluid/framework/ir/ipu/optimizer_state_align_pass.h
+36
-0
paddle/fluid/framework/ir/ipu/popart_canonicalization_pass.cc
...le/fluid/framework/ir/ipu/popart_canonicalization_pass.cc
+68
-0
paddle/fluid/framework/ir/ipu/popart_canonicalization_pass.h
paddle/fluid/framework/ir/ipu/popart_canonicalization_pass.h
+30
-0
paddle/fluid/platform/device/ipu/CMakeLists.txt
paddle/fluid/platform/device/ipu/CMakeLists.txt
+1
-1
paddle/fluid/platform/device/ipu/popart_canonicalization/activation_ops.cc
...form/device/ipu/popart_canonicalization/activation_ops.cc
+72
-0
paddle/fluid/platform/device/ipu/popart_canonicalization/canonicalization_utils.cc
...ice/ipu/popart_canonicalization/canonicalization_utils.cc
+185
-0
paddle/fluid/platform/device/ipu/popart_canonicalization/canonicalization_utils.h
...vice/ipu/popart_canonicalization/canonicalization_utils.h
+64
-0
paddle/fluid/platform/device/ipu/popart_canonicalization/elementwise_ops.cc
...orm/device/ipu/popart_canonicalization/elementwise_ops.cc
+108
-0
未找到文件。
paddle/fluid/framework/ir/ipu/ipu_runtime_replacer_pass.cc
0 → 100644
浏览文件 @
8b30c1ec
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/ipu/ipu_runtime_replacer_pass.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
void
IpuRuntimeReplacerPass
::
ApplyImpl
(
ir
::
Graph
*
graph
)
const
{
VLOG
(
10
)
<<
"enter IpuRuntimeReplacerPass::ApplyImpl"
;
VLOG
(
10
)
<<
"Raw Graph: "
;
VLOG
(
10
)
<<
DebugString
(
graph
);
std
::
vector
<
std
::
string
>
feed_list
;
feed_list
=
Get
<
std
::
vector
<
std
::
string
>>
(
"feed_list"
);
std
::
vector
<
std
::
string
>
fetch_list
;
fetch_list
=
Get
<
std
::
vector
<
std
::
string
>>
(
"fetch_list"
);
framework
::
OpDesc
ipu_rt_op_desc
;
ipu_rt_op_desc
.
SetType
(
"ipu_runtime"
);
ipu_rt_op_desc
.
SetInput
(
"FeedList"
,
feed_list
);
ipu_rt_op_desc
.
SetOutput
(
"FetchList"
,
fetch_list
);
ipu_rt_op_desc
.
Flush
();
// Create a new node for the ipu_runtime_op.
auto
*
ipu_rt_node
=
graph
->
CreateOpNode
(
&
ipu_rt_op_desc
);
for
(
auto
*
node
:
graph
->
Nodes
())
{
if
(
node
->
IsVar
())
{
for
(
auto
feed
:
feed_list
)
{
if
(
node
->
Name
()
==
feed
)
{
IR_NODE_LINK_TO
(
node
,
ipu_rt_node
);
}
}
for
(
auto
fetch
:
fetch_list
)
{
if
(
node
->
Name
()
==
fetch
)
{
IR_NODE_LINK_TO
(
ipu_rt_node
,
node
);
}
}
}
}
// set ipu_runtime_op dtype attr
if
(
fetch_list
.
size
()
==
1
)
{
for
(
auto
*
node
:
graph
->
Nodes
())
{
if
(
node
->
IsVar
())
{
for
(
auto
fetch
:
fetch_list
)
{
if
(
node
->
Name
()
==
fetch
)
{
ipu_rt_node
->
Op
()
->
SetAttr
(
"dtype"
,
node
->
Var
()
->
GetDataType
());
}
}
}
}
}
// Remove unneeded nodes.
std
::
unordered_set
<
const
Node
*>
marked_nodes
;
for
(
auto
*
node
:
graph
->
Nodes
())
{
if
(
node
->
IsOp
())
{
auto
*
op_desc
=
node
->
Op
();
if
(
op_desc
->
Type
()
!=
"ipu_runtime"
)
{
marked_nodes
.
insert
(
node
);
}
}
}
GraphSafeRemoveNodes
(
graph
,
marked_nodes
);
VLOG
(
10
)
<<
"Post Graph: "
;
VLOG
(
10
)
<<
DebugString
(
graph
);
VLOG
(
10
)
<<
"leave IpuRuntimeReplacerPass::ApplyImpl"
;
}
}
// namespace ir
}
// namespace framework
}
// namespace paddle
REGISTER_PASS
(
ipu_runtime_replacer_pass
,
paddle
::
framework
::
ir
::
IpuRuntimeReplacerPass
)
.
RequirePassAttr
(
"feed_list"
)
.
RequirePassAttr
(
"fetch_list"
);
paddle/fluid/framework/ir/ipu/ipu_runtime_replacer_pass.h
0 → 100644
浏览文件 @
8b30c1ec
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/ipu/ipu_pass_base.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
class
IpuRuntimeReplacerPass
:
public
IPUPassBase
{
protected:
void
ApplyImpl
(
ir
::
Graph
*
graph
)
const
override
;
};
}
// namespace ir
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/ir/ipu/optimizer_extract_pass.cc
0 → 100644
浏览文件 @
8b30c1ec
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/ipu/optimizer_extract_pass.h"
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
#include "paddle/fluid/platform/device/ipu/ipu_backend.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
void
IpuOptimizerExtractPass
::
ApplyImpl
(
ir
::
Graph
*
graph
)
const
{
VLOG
(
10
)
<<
"enter IpuOptimizerExtractPass::ApplyImpl"
;
VLOG
(
10
)
<<
"Raw Graph: "
;
VLOG
(
10
)
<<
DebugString
(
graph
);
auto
ipu_backend
=
paddle
::
platform
::
ipu
::
IpuBackend
::
GetInstance
();
for
(
auto
*
node
:
graph
->
Nodes
())
{
if
(
node
->
IsOp
()
&&
node
->
Op
())
{
int
op_role
=
BOOST_GET_CONST
(
int
,
node
->
Op
()
->
GetAttr
(
framework
::
OpProtoAndCheckerMaker
::
OpRoleAttrName
()));
// graph usually have multiple optimizer node for different parameter,
// and these node have the same type and attr value usually
if
((
op_role
==
static_cast
<
int
>
(
framework
::
OpRole
::
kOptimize
)))
{
ipu_backend
->
GetExecutor
().
SetOptimizerType
(
node
->
Op
()
->
Type
());
VLOG
(
10
)
<<
"found optimizer type: "
<<
node
->
Op
()
->
Type
();
for
(
const
std
::
string
&
attr_name
:
node
->
Op
()
->
AttrNames
())
{
auto
attr_type
=
node
->
Op
()
->
GetAttrType
(
attr_name
);
// with adam, attr are float
if
(
attr_type
==
proto
::
AttrType
::
FLOAT
)
{
auto
attr_value
=
BOOST_GET_CONST
(
float
,
node
->
Op
()
->
GetAttr
(
attr_name
));
ipu_backend
->
GetExecutor
().
SetOptimizerAttr
(
attr_name
,
attr_value
);
}
else
{
VLOG
(
10
)
<<
"Skip "
<<
attr_type
;
}
}
auto
lr_var_name
=
node
->
Op
()
->
Input
(
"LearningRate"
);
PADDLE_ENFORCE_EQ
(
lr_var_name
.
size
(),
1u
,
platform
::
errors
::
InvalidArgument
(
"In op(%s), find input(LearningRate) failed."
,
node
->
Op
()
->
Type
()));
ipu_backend
->
GetExecutor
().
SetLRVarName
(
lr_var_name
[
0
]);
}
if
((
op_role
==
static_cast
<
int
>
(
framework
::
OpRole
::
kLoss
)))
{
VLOG
(
10
)
<<
"found loss op type: "
<<
node
->
Op
()
->
Type
();
auto
outputs
=
node
->
Op
()
->
Outputs
();
PADDLE_ENFORCE_EQ
(
outputs
.
size
(),
1
,
platform
::
errors
::
InvalidArgument
(
"Can only support one loss key"
));
auto
losses_name
=
outputs
.
begin
()
->
second
;
PADDLE_ENFORCE_EQ
(
losses_name
.
size
(),
1
,
platform
::
errors
::
InvalidArgument
(
"Can only support one loss name"
));
ipu_backend
->
GetExecutor
().
SetLoss
(
losses_name
[
0
]);
}
}
}
VLOG
(
10
)
<<
"Post Graph: "
;
VLOG
(
10
)
<<
DebugString
(
graph
);
VLOG
(
10
)
<<
"leave IpuOptimizerExtractPass::ApplyImpl"
;
}
}
// namespace ir
}
// namespace framework
}
// namespace paddle
REGISTER_PASS
(
optimizer_extract_pass
,
paddle
::
framework
::
ir
::
IpuOptimizerExtractPass
);
paddle/fluid/framework/ir/ipu/optimizer_extract_pass.h
0 → 100644
浏览文件 @
8b30c1ec
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/ipu/ipu_pass_base.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
class
IpuOptimizerExtractPass
:
public
IPUPassBase
{
protected:
void
ApplyImpl
(
ir
::
Graph
*
graph
)
const
override
;
};
}
// namespace ir
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/ir/ipu/optimizer_state_align_pass.cc
0 → 100644
浏览文件 @
8b30c1ec
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/ipu/optimizer_state_align_pass.h"
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
#include "paddle/fluid/platform/device/ipu/common.h"
#include "paddle/fluid/platform/device/ipu/ipu_backend.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
using
paddle
::
platform
::
ipu
::
IpuBackend
;
using
framework
::
ir
::
Graph
;
using
framework
::
ir
::
Node
;
void
IpuOptimizerStateAlignPass
::
ApplyImpl
(
ir
::
Graph
*
graph
)
const
{
VLOG
(
10
)
<<
"enter IpuOptimizerStateAlignPass::ApplyImpl"
;
VLOG
(
10
)
<<
"Raw Graph: "
;
VLOG
(
10
)
<<
DebugString
(
graph
);
auto
ipu_backend
=
IpuBackend
::
GetInstance
();
const
auto
*
scope_
=
ipu_backend
->
GetScope
();
for
(
auto
*
node
:
graph
->
Nodes
())
{
if
(
node
->
IsOp
()
&&
node
->
Op
())
{
int
op_role
=
BOOST_GET_CONST
(
int
,
node
->
Op
()
->
GetAttr
(
framework
::
OpProtoAndCheckerMaker
::
OpRoleAttrName
()));
if
((
op_role
==
static_cast
<
int
>
(
framework
::
OpRole
::
kOptimize
)))
{
auto
inputs
=
node
->
Op
()
->
Inputs
();
if
(
inputs
.
count
(
platform
::
ipu
::
sBeta1Pow
))
{
auto
var
=
scope_
->
GetVar
(
inputs
.
at
(
platform
::
ipu
::
sBeta1Pow
)[
0
]);
auto
data
=
var
->
GetMutable
<
framework
::
LoDTensor
>
()
->
data
<
float
>
();
auto
beta
=
BOOST_GET_CONST
(
float
,
node
->
Op
()
->
GetAttr
(
platform
::
ipu
::
sBeta1
));
// ensure current save with beta1pow, rather than step.
// beta1pow = beta1 ^ (step + 1). Just set beta1pow because popart
// support single Step__
bool
save_with_beta1pow
=
(
data
[
0
]
<
1.0
f
)
&&
(
data
[
0
]
>
0.0
f
);
float
step
=
0
;
float
beta_acc
=
beta
;
while
(
beta_acc
>
data
[
0
]
&&
save_with_beta1pow
)
{
beta_acc
*=
beta
;
step
+=
1
;
}
if
(
save_with_beta1pow
)
{
data
[
0
]
=
step
;
}
}
}
}
}
VLOG
(
10
)
<<
"Post Graph: "
;
VLOG
(
10
)
<<
DebugString
(
graph
);
VLOG
(
10
)
<<
"leave IpuOptimizerStateAlignPass::ApplyImpl"
;
}
}
// namespace ir
}
// namespace framework
}
// namespace paddle
REGISTER_PASS
(
optimizer_state_align_pass
,
paddle
::
framework
::
ir
::
IpuOptimizerStateAlignPass
);
paddle/fluid/framework/ir/ipu/optimizer_state_align_pass.h
0 → 100644
浏览文件 @
8b30c1ec
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/ipu/ipu_pass_base.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
/*
* This pass should only affect optimizer that need bias correction,
* include Adam/Lamb.
*/
class
IpuOptimizerStateAlignPass
:
public
IPUPassBase
{
protected:
void
ApplyImpl
(
ir
::
Graph
*
graph
)
const
override
;
};
}
// namespace ir
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/ir/ipu/popart_canonicalization_pass.cc
0 → 100644
浏览文件 @
8b30c1ec
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/ipu/popart_canonicalization_pass.h"
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
#include "paddle/fluid/platform/device/ipu/popart_canonicalization/canonicalization_utils.h"
#include "paddle/fluid/platform/device/ipu/popart_canonicalization/post_canonicalization.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
using
framework
::
ir
::
Graph
;
using
framework
::
ir
::
Node
;
using
platform
::
ipu
::
SymbolHandler
;
void
PopartCanonicalizationPass
::
ApplyImpl
(
ir
::
Graph
*
graph
)
const
{
VLOG
(
10
)
<<
"enter PopartCanonicalizationPass::ApplyImpl"
;
VLOG
(
10
)
<<
"Raw Graph: "
;
VLOG
(
10
)
<<
DebugString
(
graph
);
auto
nodes
=
graph
->
Nodes
();
for
(
auto
*
node
:
nodes
)
{
if
(
!
node
->
IsOp
())
{
continue
;
}
auto
*
op
=
node
->
Op
();
auto
op_type
=
op
->
Type
();
ir
::
Node
*
new_node
=
nullptr
;
SymbolHandler
handler
=
platform
::
ipu
::
GetHandler
(
op_type
);
if
(
handler
)
{
VLOG
(
11
)
<<
"Raw Paddle Node:"
;
VLOG
(
11
)
<<
node
->
Op
()
->
Proto
()
->
DebugString
();
new_node
=
handler
(
graph
,
node
);
VLOG
(
11
)
<<
"Post Popart Node:"
;
VLOG
(
11
)
<<
new_node
->
Op
()
->
Proto
()
->
DebugString
();
platform
::
ipu
::
ClearNode
(
node
);
graph
->
RemoveNode
(
node
);
}
else
{
LOG
(
ERROR
)
<<
"Can not find OpHandler for op_type: "
<<
op_type
;
}
}
VLOG
(
10
)
<<
"Post Graph: "
;
VLOG
(
10
)
<<
DebugString
(
graph
);
VLOG
(
10
)
<<
"leave PopartCanonicalizationPass::ApplyImpl"
;
}
}
// namespace ir
}
// namespace framework
}
// namespace paddle
REGISTER_PASS
(
popart_canonicalization_pass
,
paddle
::
framework
::
ir
::
PopartCanonicalizationPass
);
paddle/fluid/framework/ir/ipu/popart_canonicalization_pass.h
0 → 100644
浏览文件 @
8b30c1ec
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/ir/ipu/ipu_pass_base.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
class
PopartCanonicalizationPass
:
public
IPUPassBase
{
protected:
void
ApplyImpl
(
ir
::
Graph
*
graph
)
const
override
;
};
}
// namespace ir
}
// namespace framework
}
// namespace paddle
paddle/fluid/platform/device/ipu/CMakeLists.txt
浏览文件 @
8b30c1ec
# IPU
IF
(
WITH_IPU
)
FILE
(
GLOB POPART_CANONICALIZATION_SRC
${
PADDLE_SOURCE_DIR
}
/paddle/fluid/platform/device/ipu/popart_canonicalization/*.cc
)
cc_library
(
ipu_device SRCS device.cc DEPS enforce popart
)
cc_library
(
ipu_utils SRCS ipu_utils.cc DEPS memory framework_proto popart
)
cc_library
(
ipu_strategy SRCS ipu_strategy.cc DEPS popart graph framework_proto enforce
)
...
...
paddle/fluid/platform/device/ipu/popart_canonicalization/activation_ops.cc
0 → 100644
浏览文件 @
8b30c1ec
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/platform/device/ipu/popart_canonicalization/canonicalization_utils.h"
#include "paddle/fluid/platform/device/ipu/popart_canonicalization/op_builder.h"
#include "paddle/fluid/platform/enforce.h"
namespace
paddle
{
namespace
platform
{
namespace
ipu
{
namespace
{
Node
*
activation_op_handler
(
Graph
*
graph
,
Node
*
node
,
const
std
::
string
&
type
)
{
auto
new_node
=
CreateBaseOp
(
graph
,
node
,
type
,
{
GetInputVarNode
(
"X"
,
node
)},
node
->
outputs
);
return
new_node
;
}
Node
*
relu_handler
(
Graph
*
graph
,
Node
*
node
)
{
return
activation_op_handler
(
graph
,
node
,
"popart_relu"
);
}
Node
*
tanh_handler
(
Graph
*
graph
,
Node
*
node
)
{
return
activation_op_handler
(
graph
,
node
,
"popart_tanh"
);
}
Node
*
log_handler
(
Graph
*
graph
,
Node
*
node
)
{
return
activation_op_handler
(
graph
,
node
,
"popart_log"
);
}
Node
*
sigmoid_handler
(
Graph
*
graph
,
Node
*
node
)
{
return
activation_op_handler
(
graph
,
node
,
"popart_sigmoid"
);
}
Node
*
sqrt_handler
(
Graph
*
graph
,
Node
*
node
)
{
return
activation_op_handler
(
graph
,
node
,
"popart_sqrt"
);
}
Node
*
gelu_handler
(
Graph
*
graph
,
Node
*
node
)
{
return
activation_op_handler
(
graph
,
node
,
"popart_gelu_v2"
);
}
Node
*
log_softmax_handler
(
Graph
*
graph
,
Node
*
node
)
{
auto
axis
=
BOOST_GET_CONST
(
int
,
node
->
Op
()
->
GetAttr
(
"axis"
));
auto
new_softmax
=
CreateSoftmaxOpset11
(
graph
,
node
,
node
->
inputs
,
{},
axis
);
return
CreateBaseOp
(
graph
,
node
,
"popart_log"
,
new_softmax
->
outputs
,
node
->
outputs
);
}
REGISTER_HANDLER
(
relu
,
relu_handler
);
REGISTER_HANDLER
(
tanh
,
tanh_handler
);
REGISTER_HANDLER
(
log
,
log_handler
);
REGISTER_HANDLER
(
sigmoid
,
sigmoid_handler
);
REGISTER_HANDLER
(
sqrt
,
sqrt_handler
);
REGISTER_HANDLER
(
gelu
,
gelu_handler
);
REGISTER_HANDLER
(
log_softmax
,
log_softmax_handler
);
}
// namespace
}
// namespace ipu
}
// namespace platform
}
// namespace paddle
paddle/fluid/platform/device/ipu/popart_canonicalization/canonicalization_utils.cc
0 → 100644
浏览文件 @
8b30c1ec
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/platform/device/ipu/popart_canonicalization/canonicalization_utils.h"
namespace
paddle
{
namespace
platform
{
namespace
ipu
{
// This avoids the static initialisation order fiasco,
std
::
unordered_map
<
std
::
string
,
SymbolHandler
>
&
SymbolHandlers
()
{
static
std
::
unordered_map
<
std
::
string
,
SymbolHandler
>
symbol_handlers
;
return
symbol_handlers
;
}
bool
RegisterHandler
(
const
std
::
string
&
symbol
,
const
SymbolHandler
&
handler
)
{
if
(
SymbolHandlers
().
count
(
symbol
)
!=
0
)
{
LOG
(
WARNING
)
<<
"Trying to register popart handler twice for operator: "
<<
symbol
;
return
false
;
}
bool
new_handler
=
SymbolHandlers
().
emplace
(
symbol
,
handler
).
second
;
return
new_handler
;
}
// Return a pointer to a handler if one is registered for this kind of node or
// an empty std::function otherwise.
SymbolHandler
GetHandler
(
const
std
::
string
&
kind
)
{
auto
it
=
SymbolHandlers
().
find
(
kind
);
if
(
it
!=
SymbolHandlers
().
end
())
{
return
it
->
second
;
}
return
{};
}
void
ConnectNodes
(
Node
*
first_node
,
Node
*
next_node
)
{
first_node
->
outputs
.
push_back
(
next_node
);
next_node
->
inputs
.
push_back
(
first_node
);
}
void
DisConnectNodes
(
Node
*
first_node
,
Node
*
next_node
)
{
auto
rm_by_value
=
[
&
](
std
::
vector
<
Node
*>
&
vec
,
Node
*
n
)
{
vec
.
erase
(
std
::
remove
(
vec
.
begin
(),
vec
.
end
(),
n
),
vec
.
end
());
};
rm_by_value
(
first_node
->
outputs
,
next_node
);
rm_by_value
(
next_node
->
inputs
,
first_node
);
rm_by_value
(
first_node
->
inputs
,
next_node
);
rm_by_value
(
next_node
->
outputs
,
first_node
);
}
void
ClearNode
(
Node
*
node
)
{
auto
rm_by_value
=
[
&
](
std
::
vector
<
Node
*>
&
vec
,
Node
*
n
)
{
vec
.
erase
(
std
::
remove
(
vec
.
begin
(),
vec
.
end
(),
n
),
vec
.
end
());
};
for
(
auto
*
node_in
:
node
->
inputs
)
{
rm_by_value
(
node_in
->
outputs
,
node
);
}
for
(
auto
*
node_out
:
node
->
outputs
)
{
rm_by_value
(
node_out
->
inputs
,
node
);
}
}
void
CopyOpAttr
(
const
std
::
string
&
attr_name
,
OpDesc
*
op
,
OpDesc
*
new_op
,
bool
override
)
{
if
(
new_op
->
HasAttr
(
attr_name
)
&&
!
override
)
{
return
;
}
if
(
op
->
HasAttr
(
attr_name
))
{
VLOG
(
10
)
<<
"Copying attr: "
<<
attr_name
<<
" from "
<<
op
->
Type
()
<<
" to "
<<
new_op
->
Type
();
new_op
->
SetAttr
(
attr_name
,
op
->
GetAttr
(
attr_name
));
new_op
->
Flush
();
}
}
const
int
VarType2OnnxDtype
(
const
int
type
)
{
auto
dtype
=
static_cast
<
framework
::
proto
::
VarType
::
Type
>
(
type
);
switch
(
dtype
)
{
case
framework
::
proto
::
VarType
::
BOOL
:
return
static_cast
<
int
>
(
ONNXDataType
::
BOOL
);
case
framework
::
proto
::
VarType
::
INT16
:
return
static_cast
<
int
>
(
ONNXDataType
::
INT16
);
case
framework
::
proto
::
VarType
::
INT32
:
return
static_cast
<
int
>
(
ONNXDataType
::
INT32
);
case
framework
::
proto
::
VarType
::
INT64
:
return
static_cast
<
int
>
(
ONNXDataType
::
INT64
);
case
framework
::
proto
::
VarType
::
FP16
:
return
static_cast
<
int
>
(
ONNXDataType
::
FLOAT16
);
case
framework
::
proto
::
VarType
::
FP32
:
return
static_cast
<
int
>
(
ONNXDataType
::
FLOAT
);
case
framework
::
proto
::
VarType
::
FP64
:
return
static_cast
<
int
>
(
ONNXDataType
::
DOUBLE
);
case
framework
::
proto
::
VarType
::
UINT8
:
return
static_cast
<
int
>
(
ONNXDataType
::
UINT8
);
case
framework
::
proto
::
VarType
::
INT8
:
return
static_cast
<
int
>
(
ONNXDataType
::
INT8
);
case
framework
::
proto
::
VarType
::
BF16
:
return
static_cast
<
int
>
(
ONNXDataType
::
BFLOAT16
);
case
framework
::
proto
::
VarType
::
COMPLEX64
:
return
static_cast
<
int
>
(
ONNXDataType
::
COMPLEX64
);
case
framework
::
proto
::
VarType
::
COMPLEX128
:
return
static_cast
<
int
>
(
ONNXDataType
::
COMPLEX128
);
default:
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
"Unsupported data type: %d."
,
dtype
));
}
}
const
std
::
string
VarType2PopStr
(
const
int
type
)
{
auto
dtype
=
static_cast
<
framework
::
proto
::
VarType
::
Type
>
(
type
);
switch
(
dtype
)
{
case
framework
::
proto
::
VarType
::
UINT8
:
return
"UINT8"
;
case
framework
::
proto
::
VarType
::
INT8
:
return
"INT8"
;
case
framework
::
proto
::
VarType
::
INT16
:
return
"INT16"
;
case
framework
::
proto
::
VarType
::
INT32
:
return
"INT32"
;
case
framework
::
proto
::
VarType
::
INT64
:
return
"INT64"
;
case
framework
::
proto
::
VarType
::
BOOL
:
return
"BOOL"
;
case
framework
::
proto
::
VarType
::
FP64
:
return
"DOUBLE"
;
case
framework
::
proto
::
VarType
::
FP32
:
return
"FLOAT"
;
case
framework
::
proto
::
VarType
::
FP16
:
return
"FLOAT16"
;
default:
PADDLE_THROW
(
paddle
::
platform
::
errors
::
Unavailable
(
"Unsupported data type."
));
}
}
Node
*
GetInputVarNode
(
const
std
::
string
&
input_name
,
const
Node
*
op_node
,
const
int
id
)
{
auto
var_name
=
op_node
->
Op
()
->
Input
(
input_name
).
at
(
id
);
return
GetInputVarNodeByVarName
(
var_name
,
op_node
);
}
Node
*
GetOutputVarNode
(
const
std
::
string
&
output_name
,
const
Node
*
op_node
,
const
int
id
)
{
auto
var_name
=
op_node
->
Op
()
->
Output
(
output_name
).
at
(
id
);
return
GetOutputVarNodeByVarName
(
var_name
,
op_node
);
}
Node
*
GetInputVarNodeByVarName
(
const
std
::
string
&
var_name
,
const
Node
*
op_node
)
{
for
(
auto
*
var
:
op_node
->
inputs
)
{
if
(
var
->
Name
()
==
var_name
)
{
return
var
;
}
}
return
nullptr
;
}
Node
*
GetOutputVarNodeByVarName
(
const
std
::
string
&
var_name
,
const
Node
*
op_node
)
{
for
(
auto
*
var
:
op_node
->
outputs
)
{
if
(
var
->
Name
()
==
var_name
)
{
return
var
;
}
}
return
nullptr
;
}
const
bool
is_float_equal
(
float
a
,
float
b
,
float
eps
)
{
return
std
::
fabs
(
a
-
b
)
<=
eps
;
}
}
// namespace ipu
}
// namespace platform
}
// namespace paddle
paddle/fluid/platform/device/ipu/popart_canonicalization/canonicalization_utils.h
0 → 100644
浏览文件 @
8b30c1ec
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/node.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/platform/device/ipu/ipu_utils.h"
namespace
paddle
{
namespace
platform
{
namespace
ipu
{
using
framework
::
ir
::
Graph
;
using
framework
::
ir
::
Node
;
using
framework
::
OpDesc
;
#define REGISTER_HANDLER(name, func) \
static bool __UNUSED_##name = \
paddle::platform::ipu::RegisterHandler(#name, func)
using
SymbolHandler
=
std
::
function
<
Node
*
(
Graph
*
,
Node
*
)
>
;
std
::
unordered_map
<
std
::
string
,
SymbolHandler
>
&
SymbolHandlers
();
bool
RegisterHandler
(
const
std
::
string
&
,
const
SymbolHandler
&
);
SymbolHandler
GetHandler
(
const
std
::
string
&
);
void
ConnectNodes
(
Node
*
first_node
,
Node
*
next_node
);
void
DisConnectNodes
(
Node
*
first_node
,
Node
*
next_node
);
void
ClearNode
(
Node
*
node
);
void
CopyOpAttr
(
const
std
::
string
&
attr_name
,
OpDesc
*
op
,
OpDesc
*
new_op
,
bool
override
=
false
);
const
int
VarType2OnnxDtype
(
const
int
type
);
const
std
::
string
VarType2PopStr
(
const
int
type
);
Node
*
GetInputVarNode
(
const
std
::
string
&
input_name
,
const
Node
*
op_node
,
const
int
id
=
0
);
Node
*
GetOutputVarNode
(
const
std
::
string
&
output_name
,
const
Node
*
op_node
,
const
int
id
=
0
);
Node
*
GetInputVarNodeByVarName
(
const
std
::
string
&
var_name
,
const
Node
*
op_node
);
Node
*
GetOutputVarNodeByVarName
(
const
std
::
string
&
var_name
,
const
Node
*
op_node
);
const
bool
is_float_equal
(
float
a
,
float
b
,
float
eps
=
1e-8
);
}
// namespace ipu
}
// namespace platform
}
// namespace paddle
paddle/fluid/platform/device/ipu/popart_canonicalization/elementwise_ops.cc
0 → 100644
浏览文件 @
8b30c1ec
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/platform/device/ipu/popart_canonicalization/canonicalization_utils.h"
#include "paddle/fluid/platform/device/ipu/popart_canonicalization/op_builder.h"
#include "paddle/fluid/platform/enforce.h"
namespace
paddle
{
namespace
platform
{
namespace
ipu
{
namespace
{
Node
*
elementwise_op_handler
(
Graph
*
graph
,
Node
*
node
,
const
std
::
string
&
type
)
{
auto
*
op
=
node
->
Op
();
auto
x_shape
=
GetInputVarNode
(
"X"
,
node
)
->
Var
()
->
GetShape
();
int64_t
x_rank
=
x_shape
.
size
();
auto
y_shape
=
GetInputVarNode
(
"Y"
,
node
)
->
Var
()
->
GetShape
();
int64_t
y_rank
=
y_shape
.
size
();
auto
axis
=
BOOST_GET_CONST
(
int
,
op
->
GetAttr
(
"axis"
));
if
(
axis
==
-
1
||
axis
==
x_rank
-
1
||
x_rank
==
y_rank
)
{
auto
new_node
=
CreateBaseOp
(
graph
,
node
,
type
,
{
GetInputVarNode
(
"X"
,
node
),
GetInputVarNode
(
"Y"
,
node
)},
node
->
outputs
);
return
new_node
;
}
else
{
auto
y_new_shape
=
std
::
vector
<
int64_t
>
(
x_rank
,
1
);
for
(
int
i
=
axis
;
i
<
axis
+
y_rank
;
++
i
)
{
y_new_shape
[
i
]
=
y_shape
[
i
-
axis
];
}
auto
attrs
=
AttributeMap
{
{
"value"
,
y_new_shape
},
{
"dims"
,
std
::
vector
<
int64_t
>
{
x_rank
}},
{
"dtype"
,
ONNXDataType
::
INT64
},
};
// constant
auto
new_node_const
=
CreateConst
(
graph
,
node
,
{},
{},
attrs
);
// reshape
auto
new_node_reshape
=
CreateBaseOp
(
graph
,
node
,
"popart_reshape"
,
{
GetInputVarNode
(
"Y"
,
node
),
new_node_const
->
outputs
[
0
]},
{});
// elementwise_op
auto
new_node
=
CreateBaseOp
(
graph
,
node
,
type
,
{
GetInputVarNode
(
"X"
,
node
),
new_node_reshape
->
outputs
[
0
]},
node
->
outputs
);
return
new_node
;
}
}
Node
*
elementwise_add_handler
(
Graph
*
graph
,
Node
*
node
)
{
return
elementwise_op_handler
(
graph
,
node
,
"popart_add"
);
}
Node
*
elementwise_sub_handler
(
Graph
*
graph
,
Node
*
node
)
{
return
elementwise_op_handler
(
graph
,
node
,
"popart_sub"
);
}
Node
*
elementwise_div_handler
(
Graph
*
graph
,
Node
*
node
)
{
return
elementwise_op_handler
(
graph
,
node
,
"popart_div"
);
}
Node
*
elementwise_mul_handler
(
Graph
*
graph
,
Node
*
node
)
{
return
elementwise_op_handler
(
graph
,
node
,
"popart_mul"
);
}
Node
*
elementwise_min_handler
(
Graph
*
graph
,
Node
*
node
)
{
return
elementwise_op_handler
(
graph
,
node
,
"popart_min"
);
}
Node
*
elementwise_max_handler
(
Graph
*
graph
,
Node
*
node
)
{
return
elementwise_op_handler
(
graph
,
node
,
"popart_max"
);
}
Node
*
elementwise_pow_handler
(
Graph
*
graph
,
Node
*
node
)
{
return
elementwise_op_handler
(
graph
,
node
,
"popart_pow"
);
}
Node
*
elementwise_mod_handler
(
Graph
*
graph
,
Node
*
node
)
{
return
elementwise_op_handler
(
graph
,
node
,
"popart_mod"
);
}
REGISTER_HANDLER
(
elementwise_add
,
elementwise_add_handler
);
REGISTER_HANDLER
(
elementwise_sub
,
elementwise_sub_handler
);
REGISTER_HANDLER
(
elementwise_div
,
elementwise_div_handler
);
REGISTER_HANDLER
(
elementwise_mul
,
elementwise_mul_handler
);
REGISTER_HANDLER
(
elementwise_min
,
elementwise_min_handler
);
REGISTER_HANDLER
(
elementwise_max
,
elementwise_max_handler
);
REGISTER_HANDLER
(
elementwise_pow
,
elementwise_pow_handler
);
REGISTER_HANDLER
(
elementwise_mod
,
elementwise_mod_handler
);
}
// namespace
}
// namespace ipu
}
// namespace platform
}
// namespace paddle
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录