Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
64661927
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
64661927
编写于
5月 28, 2023
作者:
K
kangguangli
提交者:
GitHub
5月 28, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[IR] add op name normalizer (#54143)
* add op name normalizer * disable unittest
上级
04d6afc9
变更
9
显示空白变更内容
内联
并排
Showing
9 changed file
with
205 addition
and
9 deletion
+205
-9
.gitignore
.gitignore
+1
-0
paddle/fluid/dialect/legacy_pd_op.h
paddle/fluid/dialect/legacy_pd_op.h
+13
-0
paddle/fluid/dialect/pd_dialect.cc
paddle/fluid/dialect/pd_dialect.cc
+14
-1
paddle/fluid/translator/CMakeLists.txt
paddle/fluid/translator/CMakeLists.txt
+12
-1
paddle/fluid/translator/op_compat_gen.py
paddle/fluid/translator/op_compat_gen.py
+86
-0
paddle/fluid/translator/op_compat_info.cc.j2
paddle/fluid/translator/op_compat_info.cc.j2
+15
-0
paddle/fluid/translator/op_compat_info.h
paddle/fluid/translator/op_compat_info.h
+50
-0
paddle/fluid/translator/op_translator.cc
paddle/fluid/translator/op_translator.cc
+10
-1
test/cpp/ir/program_translator_test.cc
test/cpp/ir/program_translator_test.cc
+4
-6
未找到文件。
.gitignore
浏览文件 @
64661927
...
@@ -96,3 +96,4 @@ paddle/phi/api/profiler/__init__.py
...
@@ -96,3 +96,4 @@ paddle/phi/api/profiler/__init__.py
python/paddle/incubate/fleet/parameter_server/pslib/ps_pb2.py
python/paddle/incubate/fleet/parameter_server/pslib/ps_pb2.py
paddle/phi/kernels/fusion/cutlass/conv2d/generated/*
paddle/phi/kernels/fusion/cutlass/conv2d/generated/*
python/paddle/fluid/incubate/fleet/parameter_server/pslib/ps_pb2.py
python/paddle/fluid/incubate/fleet/parameter_server/pslib/ps_pb2.py
paddle/fluid/translator/op_compat_info.cc
paddle/fluid/dialect/legacy_pd_op.h
浏览文件 @
64661927
...
@@ -79,6 +79,19 @@ REIGSTER_EMPTY_OP(batch_norm_grad,
...
@@ -79,6 +79,19 @@ REIGSTER_EMPTY_OP(batch_norm_grad,
REIGSTER_EMPTY_OP
(
conv2d_grad
,
Conv2DGradOp
);
// To be customized: conv2d_grad
REIGSTER_EMPTY_OP
(
conv2d_grad
,
Conv2DGradOp
);
// To be customized: conv2d_grad
REIGSTER_EMPTY_OP
(
sum
,
SumOp
);
// To be customized: sum(reduce_sum)
REIGSTER_EMPTY_OP
(
sum
,
SumOp
);
// To be customized: sum(reduce_sum)
REIGSTER_EMPTY_OP
(
fetch_v2
,
FetchV2Op
);
// To be customized: fetch_v2
REIGSTER_EMPTY_OP
(
fetch_v2
,
FetchV2Op
);
// To be customized: fetch_v2
REIGSTER_EMPTY_OP
(
add
,
AddOp
);
REIGSTER_EMPTY_OP
(
add_grad
,
AddGradOp
);
REIGSTER_EMPTY_OP
(
matmul
,
MatMulOp
);
REIGSTER_EMPTY_OP
(
matmul_grad
,
MatMulGradOp
);
REIGSTER_EMPTY_OP
(
reshape
,
ReshapeOp
);
REIGSTER_EMPTY_OP
(
reshape_grad
,
ReshapeGradOp
);
REIGSTER_EMPTY_OP
(
mean
,
MeanOp
);
REIGSTER_EMPTY_OP
(
cross_entropy_with_softmax
,
CrossEntropyOp
);
REIGSTER_EMPTY_OP
(
cross_entropy_with_softmax_grad
,
CrossEntropyGradOp
);
REIGSTER_EMPTY_OP
(
topk
,
TopKOp
);
REIGSTER_EMPTY_OP
(
topk_grad
,
TopKGradOp
);
REIGSTER_EMPTY_OP
(
full
,
FullOp
);
REIGSTER_EMPTY_OP
(
add_n
,
AddNOp
);
}
// namespace dialect
}
// namespace dialect
}
// namespace paddle
}
// namespace paddle
paddle/fluid/dialect/pd_dialect.cc
浏览文件 @
64661927
...
@@ -133,7 +133,20 @@ void PaddleDialect::initialize() {
...
@@ -133,7 +133,20 @@ void PaddleDialect::initialize() {
BatchNormGradOp
,
BatchNormGradOp
,
Conv2DGradOp
,
Conv2DGradOp
,
SumOp
,
SumOp
,
FetchV2Op
>
();
FetchV2Op
,
AddOp
,
MatMulOp
,
ReshapeOp
,
CrossEntropyOp
,
TopKOp
,
FullOp
,
MeanOp
,
AddNOp
,
AddGradOp
,
MatMulGradOp
,
ReshapeGradOp
,
CrossEntropyGradOp
,
TopKGradOp
>
();
}
}
void
PaddleDialect
::
PrintType
(
ir
::
Type
type
,
std
::
ostream
&
os
)
{
void
PaddleDialect
::
PrintType
(
ir
::
Type
type
,
std
::
ostream
&
os
)
{
...
...
paddle/fluid/translator/CMakeLists.txt
浏览文件 @
64661927
...
@@ -2,9 +2,20 @@ set(PD_PROGRAM_TRANSLATOR_SOURCE_DIR "${CMAKE_CURRENT_LIST_DIR}")
...
@@ -2,9 +2,20 @@ set(PD_PROGRAM_TRANSLATOR_SOURCE_DIR "${CMAKE_CURRENT_LIST_DIR}")
set
(
PD_PROGRAM_TRANSLATOR_BINARY_DIR
set
(
PD_PROGRAM_TRANSLATOR_BINARY_DIR
"
${
PADDLE_BINARY_DIR
}
/paddle/fluid/translator"
)
"
${
PADDLE_BINARY_DIR
}
/paddle/fluid/translator"
)
set
(
op_gen_file
${
PD_PROGRAM_TRANSLATOR_SOURCE_DIR
}
/op_compat_gen.py
)
set
(
op_compat_yaml_file
${
PADDLE_SOURCE_DIR
}
/paddle/phi/api/yaml/op_compat.yaml
)
set
(
op_compat_source_file
${
PD_PROGRAM_TRANSLATOR_SOURCE_DIR
}
/op_compat_info.cc
)
add_custom_command
(
OUTPUT
${
op_compat_source_file
}
COMMAND
${
PYTHON_EXECUTABLE
}
${
op_gen_file
}
--op_compat_yaml_file
${
op_compat_yaml_file
}
--output_source_file
${
op_compat_source_file
}
DEPENDS
${
op_gen_file
}
${
op_compat_yaml_file
}
VERBATIM
)
file
(
GLOB PD_PROGRAM_TRANSLATOR_SRCS
"*.cc"
)
file
(
GLOB PD_PROGRAM_TRANSLATOR_SRCS
"*.cc"
)
cc_library
(
cc_library
(
program_translator
program_translator
SRCS
${
PD_PROGRAM_TRANSLATOR_SRCS
}
SRCS
${
PD_PROGRAM_TRANSLATOR_SRCS
}
${
op_compat_source_file
}
DEPS proto_desc pd_dialect new_ir framework_proto
)
DEPS proto_desc pd_dialect new_ir framework_proto
)
paddle/fluid/translator/op_compat_gen.py
0 → 100644
浏览文件 @
64661927
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
argparse
from
pathlib
import
Path
import
yaml
from
jinja2
import
Environment
,
FileSystemLoader
,
StrictUndefined
file_loader
=
FileSystemLoader
(
Path
(
__file__
).
parent
)
env
=
Environment
(
loader
=
file_loader
,
keep_trailing_newline
=
True
,
trim_blocks
=
True
,
lstrip_blocks
=
True
,
undefined
=
StrictUndefined
,
extensions
=
[
'jinja2.ext.do'
],
)
def
OpNameNormalizerInitialization
(
op_compat_yaml_file
:
str
=
""
,
output_source_file
:
str
=
""
)
->
None
:
def
to_phi_and_fluid_op_name
(
op_item
):
# Templat: - op : phi_name (fluid_name)
names
=
op_item
.
split
(
'('
)
if
len
(
names
)
==
1
:
phi_fluid_name
=
names
[
0
].
strip
()
return
phi_fluid_name
,
phi_fluid_name
else
:
phi_name
=
names
[
0
].
strip
()
fluid_name
=
names
[
1
].
split
(
')'
)[
0
].
strip
()
return
phi_name
,
fluid_name
with
open
(
op_compat_yaml_file
,
"r"
)
as
f
:
op_compat_infos
=
yaml
.
safe_load
(
f
)
op_name_mappings
=
{}
for
op_compat_item
in
op_compat_infos
:
def
insert_new_mappings
(
op_name_str
):
normalized_name
,
legacy_name
=
to_phi_and_fluid_op_name
(
op_name_str
)
if
normalized_name
==
legacy_name
:
return
op_name_mappings
[
legacy_name
]
=
normalized_name
insert_new_mappings
(
op_compat_item
[
"op"
])
if
"backward"
in
op_compat_item
:
insert_new_mappings
(
op_compat_item
[
"backward"
])
op_name_normailzer_template
=
env
.
get_template
(
"op_compat_info.cc.j2"
)
with
open
(
output_source_file
,
'wt'
)
as
f
:
op_compat_definition
=
op_name_normailzer_template
.
render
(
op_name_paris
=
op_name_mappings
)
f
.
write
(
op_compat_definition
)
# =====================================
# Script parameter parsing
# =====================================
def
ParseArguments
():
parser
=
argparse
.
ArgumentParser
(
description
=
'Generate OP Compatiable info Files By Yaml'
)
parser
.
add_argument
(
'--op_compat_yaml_file'
,
type
=
str
)
parser
.
add_argument
(
'--output_source_file'
,
type
=
str
)
return
parser
.
parse_args
()
# =====================================
# Main
# =====================================
if
__name__
==
"__main__"
:
# parse arguments
args
=
ParseArguments
()
OpNameNormalizerInitialization
(
**
vars
(
args
))
paddle/fluid/translator/op_compat_info.cc.j2
0 → 100644
浏览文件 @
64661927
#include "paddle/fluid/translator/op_compat_info.h"
namespace paddle {
namespace translator {
OpNameNormalizer::OpNameNormalizer() {
op_name_mappings = {
{% for legacy_name, normalized_name in op_name_paris.items() %}
{ "{{legacy_name}}", "{{normalized_name}}" },
{% endfor %}
};
}
} // namespace translator
}// namespace paddle
paddle/fluid/translator/op_compat_info.h
0 → 100644
浏览文件 @
64661927
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <string>
#include <unordered_map>
#include "glog/logging.h"
#pragma once
namespace
paddle
{
namespace
translator
{
class
OpNameNormalizer
{
private:
OpNameNormalizer
();
// Disallow instantiation outside of the class.
std
::
unordered_map
<
std
::
string
,
std
::
string
>
op_name_mappings
;
public:
OpNameNormalizer
(
const
OpNameNormalizer
&
)
=
delete
;
OpNameNormalizer
&
operator
=
(
const
OpNameNormalizer
&
)
=
delete
;
OpNameNormalizer
(
OpNameNormalizer
&&
)
=
delete
;
OpNameNormalizer
&
operator
=
(
OpNameNormalizer
&&
)
=
delete
;
static
auto
&
instance
()
{
static
OpNameNormalizer
OpNameNormalizer
;
return
OpNameNormalizer
;
}
std
::
string
operator
[](
const
std
::
string
&
op_type
)
{
if
(
op_name_mappings
.
find
(
op_type
)
==
op_name_mappings
.
end
())
{
return
op_type
;
}
return
op_name_mappings
.
at
(
op_type
);
}
};
}
// namespace translator
}
// namespace paddle
paddle/fluid/translator/op_translator.cc
浏览文件 @
64661927
...
@@ -22,6 +22,7 @@
...
@@ -22,6 +22,7 @@
#include <vector>
#include <vector>
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/translator/op_compat_info.h"
#include "paddle/fluid/translator/program_translator.h"
#include "paddle/fluid/translator/program_translator.h"
#include "paddle/fluid/translator/type_translator.h"
#include "paddle/fluid/translator/type_translator.h"
#include "paddle/ir/builtin_op.h"
#include "paddle/ir/builtin_op.h"
...
@@ -70,11 +71,19 @@ inline bool IsInplace(const OpDesc& op_desc) {
...
@@ -70,11 +71,19 @@ inline bool IsInplace(const OpDesc& op_desc) {
return
inplace
;
return
inplace
;
}
}
inline
std
::
string
OpNamecompatibleMapping
(
std
::
string
op_name
)
{
auto
&
op_normalizer
=
OpNameNormalizer
::
instance
();
return
op_normalizer
[
op_name
];
}
inline
ir
::
OpInfo
LoopkUpOpInfo
(
ir
::
IrContext
*
ctx
,
const
OpDesc
&
op_desc
)
{
inline
ir
::
OpInfo
LoopkUpOpInfo
(
ir
::
IrContext
*
ctx
,
const
OpDesc
&
op_desc
)
{
std
::
string
target_op_name
=
kTargetDialectPrefix
+
op_desc
.
Type
();
std
::
string
target_op_name
=
kTargetDialectPrefix
+
OpNamecompatibleMapping
(
op_desc
.
Type
());
if
(
IsInplace
(
op_desc
))
{
if
(
IsInplace
(
op_desc
))
{
target_op_name
+=
"_"
;
target_op_name
+=
"_"
;
}
}
VLOG
(
6
)
<<
"[op name normalizing: "
<<
op_desc
.
Type
()
<<
" to "
<<
target_op_name
;
auto
op_info
=
ctx
->
GetRegisteredOpInfo
(
target_op_name
);
auto
op_info
=
ctx
->
GetRegisteredOpInfo
(
target_op_name
);
if
(
!
op_info
)
{
if
(
!
op_info
)
{
PADDLE_THROW
(
platform
::
errors
::
PreconditionNotMet
(
PADDLE_THROW
(
platform
::
errors
::
PreconditionNotMet
(
...
...
test/cpp/ir/program_translator_test.cc
浏览文件 @
64661927
...
@@ -49,8 +49,6 @@ ProgramDesc load_from_file(const std::string &file_name) {
...
@@ -49,8 +49,6 @@ ProgramDesc load_from_file(const std::string &file_name) {
TEST
(
PaddleDialectTest
,
Translator
)
{
TEST
(
PaddleDialectTest
,
Translator
)
{
LOG
(
WARNING
)
<<
"TODO"
;
LOG
(
WARNING
)
<<
"TODO"
;
// auto p = load_from_file("restnet50_main.prog");
// auto p = load_from_file("restnet50_main.prog");
// std::cout << p.Size() << std::endl;
// EXPECT_EQ(p.Size(), 1u);
// EXPECT_EQ(p.Size(), 1u);
// ir::IrContext *ctx = ir::IrContext::Instance();
// ir::IrContext *ctx = ir::IrContext::Instance();
...
@@ -58,8 +56,8 @@ TEST(PaddleDialectTest, Translator) {
...
@@ -58,8 +56,8 @@ TEST(PaddleDialectTest, Translator) {
// ctx->GetOrRegisterDialect<ir::BuiltinDialect>();
// ctx->GetOrRegisterDialect<ir::BuiltinDialect>();
// auto program = paddle::TranslateLegacyProgramToProgram(p);
// auto program = paddle::TranslateLegacyProgramToProgram(p);
// s
td::list<ir::Operation *> ops = program->ops
();
// s
ize_t op_size = program->block()->size
();
// ops.size() = op size in BlockDesc + get_parameter_op + combine op
//
//
ops.size() = op size in BlockDesc + get_parameter_op + combine op
// EXPECT_EQ(op
s.size(), p.Block(0).OpSize() + program->parameters_num() +
// EXPECT_EQ(op
_size, p.Block(0).OpSize() + program->parameters_num() + 20);
//
20); std::cout << *program << std::endl
;
//
VLOG(0) << *program
;
}
}
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录