Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
343a9e95
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2299
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
343a9e95
编写于
6月 03, 2023
作者:
K
kangguangli
提交者:
GitHub
6月 03, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Revert "[IR] Support op attribute and refactor for new op definition (#54068)"
This reverts commit
37930a69
.
上级
37930a69
变更
10
隐藏空白更改
内联
并排
Showing
10 changed file
with
77 addition
and
650 deletion
+77
-650
paddle/fluid/translator/attribute_translator.cc
paddle/fluid/translator/attribute_translator.cc
+0
-231
paddle/fluid/translator/attribute_translator.h
paddle/fluid/translator/attribute_translator.h
+0
-54
paddle/fluid/translator/op_compat_gen.py
paddle/fluid/translator/op_compat_gen.py
+6
-41
paddle/fluid/translator/op_compat_info.cc.j2
paddle/fluid/translator/op_compat_info.cc.j2
+1
-13
paddle/fluid/translator/op_compat_info.h
paddle/fluid/translator/op_compat_info.h
+0
-48
paddle/fluid/translator/op_translator.cc
paddle/fluid/translator/op_translator.cc
+54
-205
paddle/fluid/translator/program_translator.cc
paddle/fluid/translator/program_translator.cc
+1
-1
paddle/fluid/translator/program_translator.h
paddle/fluid/translator/program_translator.h
+2
-2
paddle/fluid/translator/utils.h
paddle/fluid/translator/utils.h
+0
-42
test/cpp/ir/core/program_translator_test.cc
test/cpp/ir/core/program_translator_test.cc
+13
-13
未找到文件。
paddle/fluid/translator/attribute_translator.cc
已删除
100644 → 0
浏览文件 @
37930a69
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/translator/attribute_translator.h"
#include <string>
#include <vector>
#include "paddle/fluid/dialect/pd_attribute.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/int_array.h"
#include "paddle/phi/common/layout.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/common/scalar.h"
#include "paddle/utils/variant.h"
namespace
paddle
{
namespace
translator
{
class
AttributeVisitor
{
public:
ir
::
IrContext
*
ctx
;
AttributeVisitor
()
{
ctx
=
ir
::
IrContext
::
Instance
();
}
~
AttributeVisitor
()
{}
public:
virtual
ir
::
Attribute
operator
()(
int
i
)
{
VLOG
(
10
)
<<
"translating int"
;
return
ir
::
Int32_tAttribute
::
get
(
ctx
,
i
);
}
virtual
ir
::
Attribute
operator
()(
float
f
)
{
VLOG
(
10
)
<<
"translating float"
;
return
ir
::
FloatAttribute
::
get
(
ctx
,
f
);
}
virtual
ir
::
Attribute
operator
()(
bool
b
)
{
VLOG
(
10
)
<<
"translating bool"
;
return
ir
::
BoolAttribute
::
get
(
ctx
,
b
);
}
virtual
ir
::
Attribute
operator
()(
double
d
)
{
VLOG
(
10
)
<<
"translating double"
;
return
ir
::
DoubleAttribute
::
get
(
ctx
,
d
);
}
virtual
ir
::
Attribute
operator
()(
std
::
string
str
)
{
VLOG
(
10
)
<<
"translating string"
;
return
ir
::
StrAttribute
::
get
(
ctx
,
str
);
}
virtual
ir
::
Attribute
operator
()(
const
paddle
::
experimental
::
Scalar
&
scalar
)
{
VLOG
(
10
)
<<
"translating scalar"
;
return
paddle
::
dialect
::
ScalarAttribute
::
get
(
ctx
,
scalar
);
}
virtual
ir
::
Attribute
operator
()(
const
std
::
vector
<
std
::
string
>&
strs
)
{
VLOG
(
10
)
<<
"translating vector<string>"
;
std
::
vector
<
ir
::
Attribute
>
attrs
;
attrs
.
reserve
(
strs
.
size
());
for
(
const
auto
&
v
:
strs
)
{
attrs
.
push_back
(
ir
::
StrAttribute
::
get
(
ctx
,
v
));
}
return
ir
::
ArrayAttribute
::
get
(
ctx
,
attrs
);
}
virtual
ir
::
Attribute
operator
()(
const
std
::
vector
<
float
>&
fs
)
{
VLOG
(
10
)
<<
"translating vector<float>"
;
std
::
vector
<
ir
::
Attribute
>
attrs
;
attrs
.
reserve
(
fs
.
size
());
for
(
const
auto
&
v
:
fs
)
{
attrs
.
push_back
(
ir
::
FloatAttribute
::
get
(
ctx
,
v
));
}
return
ir
::
ArrayAttribute
::
get
(
ctx
,
attrs
);
}
virtual
ir
::
Attribute
operator
()(
const
std
::
vector
<
int
>&
is
)
{
VLOG
(
10
)
<<
"translating vector<int>"
;
std
::
vector
<
ir
::
Attribute
>
attrs
;
attrs
.
reserve
(
is
.
size
());
for
(
const
auto
&
v
:
is
)
{
attrs
.
push_back
(
ir
::
Int32_tAttribute
::
get
(
ctx
,
v
));
}
return
ir
::
ArrayAttribute
::
get
(
ctx
,
attrs
);
}
virtual
ir
::
Attribute
operator
()(
const
std
::
vector
<
bool
>&
bs
)
{
VLOG
(
10
)
<<
"translating vector<bool>"
;
std
::
vector
<
ir
::
Attribute
>
attrs
;
attrs
.
reserve
(
bs
.
size
());
for
(
const
auto
&
v
:
bs
)
{
attrs
.
push_back
(
ir
::
BoolAttribute
::
get
(
ctx
,
v
));
}
return
ir
::
ArrayAttribute
::
get
(
ctx
,
attrs
);
}
virtual
ir
::
Attribute
operator
()(
const
std
::
vector
<
int64_t
>&
i64s
)
{
VLOG
(
10
)
<<
"translating vector<int64>"
;
std
::
vector
<
ir
::
Attribute
>
attrs
;
attrs
.
reserve
(
i64s
.
size
());
for
(
const
auto
&
v
:
i64s
)
{
attrs
.
push_back
(
ir
::
Int64_tAttribute
::
get
(
ctx
,
v
));
}
return
ir
::
ArrayAttribute
::
get
(
ctx
,
attrs
);
}
virtual
ir
::
Attribute
operator
()(
const
std
::
vector
<
double
>&
ds
)
{
VLOG
(
10
)
<<
"translating vector<double>"
;
std
::
vector
<
ir
::
Attribute
>
attrs
;
attrs
.
reserve
(
ds
.
size
());
for
(
const
auto
&
v
:
ds
)
{
attrs
.
push_back
(
ir
::
DoubleAttribute
::
get
(
ctx
,
v
));
}
return
ir
::
ArrayAttribute
::
get
(
ctx
,
attrs
);
}
virtual
ir
::
Attribute
operator
()(
const
std
::
vector
<
paddle
::
experimental
::
Scalar
>&
ss
)
{
VLOG
(
10
)
<<
"translating vector<scalar>"
;
std
::
vector
<
ir
::
Attribute
>
attrs
;
attrs
.
reserve
(
ss
.
size
());
for
(
const
auto
&
v
:
ss
)
{
attrs
.
push_back
(
paddle
::
dialect
::
ScalarAttribute
::
get
(
ctx
,
v
));
}
return
ir
::
ArrayAttribute
::
get
(
ctx
,
attrs
);
}
virtual
ir
::
Attribute
operator
()(
const
paddle
::
blank
&
blank
)
{
VLOG
(
10
)
<<
"translating paddle::blank"
;
return
ir
::
Attribute
(
nullptr
);
}
template
<
typename
T
>
ir
::
Attribute
operator
()(
T
attr
)
{
VLOG
(
10
)
<<
"translating null type"
;
return
ir
::
Attribute
(
nullptr
);
}
};
class
IntArrayAttributeVisitor
:
public
AttributeVisitor
{
public:
using
AttributeVisitor
::
AttributeVisitor
;
ir
::
Attribute
operator
()(
const
std
::
vector
<
int
>&
is
)
override
{
VLOG
(
10
)
<<
"translating vector<int> to IntArray"
;
phi
::
IntArray
data
(
is
);
return
paddle
::
dialect
::
IntArrayAttribute
::
get
(
ctx
,
data
);
}
ir
::
Attribute
operator
()(
const
std
::
vector
<
int64_t
>&
is
)
override
{
VLOG
(
10
)
<<
"translating vector<int> to IntArray"
;
phi
::
IntArray
data
(
is
);
return
paddle
::
dialect
::
IntArrayAttribute
::
get
(
ctx
,
data
);
}
};
class
ScalarAttributeVisitor
:
public
AttributeVisitor
{
public:
using
AttributeVisitor
::
AttributeVisitor
;
ir
::
Attribute
operator
()(
int
i
)
override
{
VLOG
(
10
)
<<
"translating int to Scalar"
;
phi
::
Scalar
data
(
i
);
return
paddle
::
dialect
::
ScalarAttribute
::
get
(
ctx
,
data
);
}
ir
::
Attribute
operator
()(
float
f
)
override
{
VLOG
(
10
)
<<
"translating float to Scalar"
;
phi
::
Scalar
data
(
f
);
return
paddle
::
dialect
::
ScalarAttribute
::
get
(
ctx
,
data
);
}
};
class
DataTypeAttributeVisitor
:
public
AttributeVisitor
{
public:
using
AttributeVisitor
::
AttributeVisitor
;
ir
::
Attribute
operator
()(
int
i
)
override
{
VLOG
(
10
)
<<
"translating int to DataType: "
<<
i
;
phi
::
DataType
data
=
static_cast
<
phi
::
DataType
>
(
i
);
return
paddle
::
dialect
::
DataTypeAttribute
::
get
(
ctx
,
data
);
}
};
class
PlaceAttributeVisitor
:
public
AttributeVisitor
{
public:
using
AttributeVisitor
::
AttributeVisitor
;
ir
::
Attribute
operator
()(
const
paddle
::
blank
&
blank
)
override
{
VLOG
(
10
)
<<
"translating paddle::blank"
;
phi
::
Place
data
(
phi
::
AllocationType
::
CPU
);
return
paddle
::
dialect
::
PlaceAttribute
::
get
(
ctx
,
data
);
}
};
AttributeTranslator
::
AttributeTranslator
()
{
general_visitor
=
new
AttributeVisitor
();
special_visitors
[
"paddle::dialect::IntArrayAttribute"
]
=
new
IntArrayAttributeVisitor
();
special_visitors
[
"paddle::dialect::ScalarAttribute"
]
=
new
ScalarAttributeVisitor
();
special_visitors
[
"paddle::dialect::DataTypeAttribute"
]
=
new
DataTypeAttributeVisitor
();
special_visitors
[
"paddle::dialect::PlaceAttribute"
]
=
new
PlaceAttributeVisitor
();
}
ir
::
Attribute
AttributeTranslator
::
operator
()(
const
framework
::
Attribute
&
attr
)
{
return
paddle
::
visit
(
*
general_visitor
,
attr
);
}
ir
::
Attribute
AttributeTranslator
::
operator
()(
const
std
::
string
&
target_type
,
const
framework
::
Attribute
&
attr
)
{
if
(
special_visitors
.
find
(
target_type
)
==
special_visitors
.
end
())
{
VLOG
(
10
)
<<
"["
<<
target_type
<<
"] not found"
;
return
paddle
::
visit
(
*
general_visitor
,
attr
);
}
return
paddle
::
visit
(
*
(
special_visitors
.
at
(
target_type
)),
attr
);
}
}
// namespace translator
}
// namespace paddle
paddle/fluid/translator/attribute_translator.h
已删除
100644 → 0
浏览文件 @
37930a69
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <string>
#include <unordered_map>
#include "paddle/fluid/framework/attribute.h"
#include "paddle/fluid/framework/type_defs.h"
#include "paddle/ir/core/attribute.h"
#include "paddle/ir/core/builtin_attribute.h"
#include "paddle/ir/core/ir_context.h"
#pragma once
namespace
paddle
{
namespace
translator
{
class
AttributeVisitor
;
class
AttributeTranslator
{
private:
AttributeTranslator
();
AttributeVisitor
*
general_visitor
;
std
::
unordered_map
<
std
::
string
,
AttributeVisitor
*>
special_visitors
;
public:
AttributeTranslator
(
const
AttributeTranslator
&
)
=
delete
;
AttributeTranslator
&
operator
=
(
const
AttributeTranslator
&
)
=
delete
;
AttributeTranslator
(
AttributeTranslator
&&
)
=
delete
;
AttributeTranslator
&
operator
=
(
AttributeTranslator
&&
)
=
delete
;
static
auto
&
instance
()
{
static
AttributeTranslator
attribute_translator
;
return
attribute_translator
;
}
ir
::
Attribute
operator
()(
const
framework
::
Attribute
&
attr
);
ir
::
Attribute
operator
()(
const
std
::
string
&
target_type
,
const
framework
::
Attribute
&
attr
);
};
}
// namespace translator
}
// namespace paddle
paddle/fluid/translator/op_compat_gen.py
浏览文件 @
343a9e95
...
...
@@ -14,7 +14,6 @@
import
argparse
from
pathlib
import
Path
from
typing
import
Dict
import
yaml
from
jinja2
import
Environment
,
FileSystemLoader
,
StrictUndefined
...
...
@@ -34,7 +33,7 @@ def OpNameNormalizerInitialization(
op_compat_yaml_file
:
str
=
""
,
output_source_file
:
str
=
""
)
->
None
:
def
to_phi_and_fluid_op_name
(
op_item
):
# Templat
e
: - op : phi_name (fluid_name)
# Templat: - op : phi_name (fluid_name)
names
=
op_item
.
split
(
'('
)
if
len
(
names
)
==
1
:
phi_fluid_name
=
names
[
0
].
strip
()
...
...
@@ -47,55 +46,21 @@ def OpNameNormalizerInitialization(
with
open
(
op_compat_yaml_file
,
"r"
)
as
f
:
op_compat_infos
=
yaml
.
safe_load
(
f
)
op_name_mappings
=
{}
op_arg_name_mappings
=
{}
for
op_compat_item
in
op_compat_infos
:
def
insert_new_mappings
(
op_name_str
:
str
)
->
str
:
def
insert_new_mappings
(
op_name_str
)
:
normalized_name
,
legacy_name
=
to_phi_and_fluid_op_name
(
op_name_str
)
if
normalized_name
==
legacy_name
:
return
normalized_name
,
legacy_name
op_name_mappings
[
legacy_name
]
=
normalized_name
return
normalized_name
,
legacy_name
def
insert_new_arg_mappings
(
op_name
:
str
,
arg_mapping
:
Dict
[
str
,
str
]):
if
op_name
is
None
:
return
if
op_name
not
in
op_arg_name_mappings
:
op_arg_name_mappings
[
op_name
]
=
{}
op_arg_name_mappings
[
op_name
].
update
(
arg_mapping
)
op_name_mappings
[
legacy_name
]
=
normalized_name
_
,
legacy_name
=
insert_new_mappings
(
op_compat_item
[
"op"
])
legacy_backward_op_names
=
[]
insert_new_mappings
(
op_compat_item
[
"op"
])
if
"backward"
in
op_compat_item
:
backward_op_name_mapping_paris
=
op_compat_item
[
"backward"
].
split
(
","
)
for
pair
in
backward_op_name_mapping_paris
:
_
,
legacy_backward_op_name
=
insert_new_mappings
(
pair
)
legacy_backward_op_names
.
append
(
legacy_backward_op_name
)
if
"inputs"
in
op_compat_item
:
insert_new_arg_mappings
(
legacy_name
,
op_compat_item
[
"inputs"
])
for
backward_op
in
legacy_backward_op_names
:
insert_new_arg_mappings
(
backward_op
,
op_compat_item
[
"inputs"
])
if
"attrs"
in
op_compat_item
:
insert_new_arg_mappings
(
legacy_name
,
op_compat_item
[
"attrs"
])
for
backward_op
in
legacy_backward_op_names
:
insert_new_arg_mappings
(
backward_op
,
op_compat_item
[
"attrs"
])
if
"outputs"
in
op_compat_item
:
insert_new_arg_mappings
(
legacy_name
,
op_compat_item
[
"outputs"
])
for
backward_op
in
legacy_backward_op_names
:
insert_new_arg_mappings
(
backward_op
,
op_compat_item
[
"outputs"
])
# special op mappings
op_name_mappings
[
"fetch_v2"
]
=
"fetch"
insert_new_mappings
(
op_compat_item
[
"backward"
])
op_name_normailzer_template
=
env
.
get_template
(
"op_compat_info.cc.j2"
)
with
open
(
output_source_file
,
'wt'
)
as
f
:
op_compat_definition
=
op_name_normailzer_template
.
render
(
op_name_pairs
=
op_name_mappings
,
op_arg_name_pairs
=
op_arg_name_mappings
,
op_name_paris
=
op_name_mappings
)
f
.
write
(
op_compat_definition
)
...
...
paddle/fluid/translator/op_compat_info.cc.j2
浏览文件 @
343a9e95
...
...
@@ -5,22 +5,10 @@ namespace translator {
OpNameNormalizer::OpNameNormalizer() {
op_name_mappings = {
{% for legacy_name, normalized_name in op_name_pa
ir
s.items() %}
{% for legacy_name, normalized_name in op_name_pa
ri
s.items() %}
{ "{{legacy_name}}", "{{normalized_name}}" },
{% endfor %}
};
op_arg_name_mappings = {
{% for op_name, arg_name_mappings in op_arg_name_pairs.items() %}
{
"{{op_name}}",
{
{% for normalized_name, legacy_name in arg_name_mappings.items() %}
{ "{{normalized_name}}", "{{legacy_name}}" },
{% endfor %}
},
},
{% endfor %}
};
}
} // namespace translator
...
...
paddle/fluid/translator/op_compat_info.h
浏览文件 @
343a9e95
...
...
@@ -12,14 +12,11 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <functional>
#include <string>
#include <unordered_map>
#include "glog/logging.h"
#include "paddle/fluid/translator/utils.h"
#pragma once
namespace
paddle
{
...
...
@@ -29,8 +26,6 @@ class OpNameNormalizer {
private:
OpNameNormalizer
();
// Disallow instantiation outside of the class.
std
::
unordered_map
<
std
::
string
,
std
::
string
>
op_name_mappings
;
std
::
unordered_map
<
std
::
string
,
std
::
unordered_map
<
std
::
string
,
std
::
string
>>
op_arg_name_mappings
;
public:
OpNameNormalizer
(
const
OpNameNormalizer
&
)
=
delete
;
...
...
@@ -49,49 +44,6 @@ class OpNameNormalizer {
}
return
op_name_mappings
.
at
(
op_type
);
}
std
::
string
GetLegacyArgName
(
const
std
::
string
&
op_type
,
const
std
::
string
&
arg_name
)
{
bool
is_grad_op
=
(
op_type
.
find
(
"grad"
)
!=
std
::
string
::
npos
);
bool
is_grad_arg
=
(
arg_name
.
find
(
"grad"
)
!=
std
::
string
::
npos
);
if
(
is_grad_op
&&
is_grad_arg
)
{
std
::
string
target
=
"_grad"
;
std
::
string
data
=
"@GRAD"
;
size_t
first_grad_pos
=
arg_name
.
find_first_of
(
target
);
std
::
string
legacy_name
=
this
->
GetLegacyArgName
(
op_type
,
arg_name
.
substr
(
0
,
first_grad_pos
));
legacy_name
+=
arg_name
.
substr
(
first_grad_pos
);
for
(
size_t
pos
=
0
;
legacy_name
.
npos
!=
(
pos
=
legacy_name
.
find
(
target
,
pos
));
pos
+=
data
.
length
())
{
legacy_name
.
replace
(
pos
,
target
.
length
(),
data
);
}
return
legacy_name
;
}
if
(
op_arg_name_mappings
.
find
(
op_type
)
==
op_arg_name_mappings
.
end
())
{
return
UnderscoreToCamelCase
(
arg_name
);
}
auto
&
arg_mappings
=
op_arg_name_mappings
[
op_type
];
if
(
arg_mappings
.
find
(
arg_name
)
==
arg_mappings
.
end
())
{
return
UnderscoreToCamelCase
(
arg_name
);
}
return
arg_mappings
.
at
(
arg_name
);
}
std
::
string
GetLegacyAttrName
(
const
std
::
string
&
op_type
,
const
std
::
string
&
arg_name
)
{
if
(
op_arg_name_mappings
.
find
(
op_type
)
==
op_arg_name_mappings
.
end
())
{
VLOG
(
10
)
<<
"["
<<
op_type
<<
"] not found"
;
return
arg_name
;
}
auto
&
arg_mappings
=
op_arg_name_mappings
[
op_type
];
if
(
arg_mappings
.
find
(
arg_name
)
==
arg_mappings
.
end
())
{
VLOG
(
10
)
<<
"["
<<
op_type
<<
"]["
<<
arg_name
<<
"] not found"
;
return
arg_name
;
}
return
arg_mappings
.
at
(
arg_name
);
}
};
}
// namespace translator
...
...
paddle/fluid/translator/op_translator.cc
浏览文件 @
343a9e95
...
...
@@ -15,23 +15,19 @@
#include "paddle/fluid/translator/op_translator.h"
#include <algorithm>
#include <cctype>
#include <numeric>
#include <string>
#include <tuple>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/dialect/pd_interface.h"
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/translator/attribute_translator.h"
#include "paddle/fluid/translator/op_compat_info.h"
#include "paddle/fluid/translator/program_translator.h"
#include "paddle/fluid/translator/type_translator.h"
#include "paddle/ir/core/builtin_op.h"
#include "paddle/ir/core/builtin_type.h"
#include "paddle/ir/core/ir_context.h"
#include "paddle/ir/core/operation.h"
#include "paddle/ir/core/value.h"
#include "paddle/phi/core/enforce.h"
...
...
@@ -46,24 +42,11 @@ using BlockDesc = paddle::framework::BlockDesc;
using
VarDesc
=
paddle
::
framework
::
VarDesc
;
using
OpOutputTypeList
=
std
::
vector
<
ir
::
Type
>
;
using
OpOutputMapping
=
std
::
unordered_map
<
std
::
string
,
ResultIdx
>
;
using
OpInputInfo
=
paddle
::
dialect
::
OpInputInfo
;
using
OpInputInfoList
=
std
::
vector
<
paddle
::
dialect
::
OpInputInfo
>
;
using
OpAttributeInfo
=
paddle
::
dialect
::
OpAttributeInfo
;
using
OpAttributeInfoList
=
std
::
vector
<
paddle
::
dialect
::
OpAttributeInfo
>
;
using
OpOutputInfo
=
paddle
::
dialect
::
OpOutputInfo
;
using
OpOutputInfoList
=
std
::
vector
<
paddle
::
dialect
::
OpOutputInfo
>
;
static
const
char
kTargetDialectPrefix
[]
=
"pd."
;
static
const
std
::
unordered_set
<
std
::
string
>
special_inplace_ops
=
{
"batch_norm"
,
};
inline
bool
IsInplace
(
const
OpDesc
&
op_desc
)
{
bool
inplace
=
false
;
if
(
special_inplace_ops
.
count
(
op_desc
.
Type
()))
{
return
inplace
;
}
auto
input_names
=
op_desc
.
InputArgumentNames
();
auto
output_names
=
op_desc
.
OutputArgumentNames
();
...
...
@@ -146,7 +129,7 @@ inline ir::Operation* InsertCombineOperationForTarget(
std
::
vector
<
ir
::
OpResult
>
src_values
;
std
::
vector
<
ir
::
Type
>
types_in_vec
;
for
(
const
auto
&
arg_name
:
args
)
{
for
(
auto
arg_name
:
args
)
{
auto
defining_info
=
param_map
->
at
(
arg_name
);
src_values
.
push_back
(
defining_info
.
value
);
types_in_vec
.
push_back
(
defining_info
.
value
.
type
());
...
...
@@ -158,25 +141,13 @@ inline ir::Operation* InsertCombineOperationForTarget(
return
operation
;
}
inline
ir
::
Operation
*
InsertConstantOperationForOptionalArg
(
ir
::
IrContext
*
ctx
,
ir
::
Program
*
program
)
{
std
::
string
constant_op_name
(
ir
::
ConstantOp
::
name
());
ir
::
OpInfo
op_info
=
ctx
->
GetRegisteredOpInfo
(
constant_op_name
);
ir
::
Type
null_type
=
ir
::
Type
(
nullptr
);
ir
::
Operation
*
operation
=
ir
::
Operation
::
create
({},
{},
{
null_type
},
op_info
);
program
->
block
()
->
push_back
(
operation
);
return
operation
;
}
inline
std
::
vector
<
ir
::
OpResult
>
GenerateOperationInput
(
ir
::
IrContext
*
ctx
,
TranslationContext
*
param_map
,
ir
::
Program
*
program
,
const
OpDesc
&
op_desc
,
const
std
::
string
&
normalized_op_name
,
const
OpInputInfoList
&
input_infos
)
{
const
OpDesc
&
op_desc
)
{
std
::
vector
<
ir
::
OpResult
>
op_inputs
=
{};
// scan all inputs to see if any of them is generated as a vector<Tensor>
// so need an additional `SliceOp` to take it out.
for
(
const
auto
&
n
:
op_desc
.
Inputs
())
{
...
...
@@ -188,7 +159,7 @@ inline std::vector<ir::OpResult> GenerateOperationInput(
param_map
->
count
(
arg_name
),
0
,
platform
::
errors
::
PreconditionNotMet
(
"arg %s.%s as input should be exists before prasing %
s
"
,
"arg %s.%s as input should be exists before prasing %
d
"
,
name
,
arg_name
,
op_desc
.
Type
()));
...
...
@@ -200,116 +171,73 @@ inline std::vector<ir::OpResult> GenerateOperationInput(
}
}
std
::
vector
<
ir
::
OpResult
>
op_inputs
;
auto
&
op_normalizer
=
OpNameNormalizer
::
instance
();
for
(
const
auto
&
info
:
input_infos
)
{
std
::
string
legacy_input_name
=
op_normalizer
.
GetLegacyArgName
(
op_desc
.
Type
(),
info
.
name
);
// return empty OpResult if this arg is optional and not shown in OpDesc
// TODO(lyk): HasInput doesnot consider variadic attribute
if
(
!
op_desc
.
HasInput
(
legacy_input_name
))
{
PADDLE_ENFORCE
(
info
.
optional
,
platform
::
errors
::
PreconditionNotMet
(
"Op %s arg %s should be optional if it can be empty"
,
op_desc
.
Type
(),
legacy_input_name
));
op_inputs
.
push_back
(
ir
::
OpResult
(
nullptr
));
continue
;
}
const
auto
&
legacy_input_vars
=
op_desc
.
Input
(
legacy_input_name
,
true
);
bool
is_vector
=
(
info
.
type_name
.
find
(
"VectorType"
)
!=
std
::
string
::
npos
);
for
(
const
auto
&
n
:
op_desc
.
Inputs
())
{
auto
&
name
=
n
.
first
;
VLOG
(
10
)
<<
"[input retriving]"
<<
"["
<<
op_desc
.
Type
()
<<
"]"
<<
name
;
auto
&
args
=
n
.
second
;
// if src type is Tensor
if
(
!
is_vector
)
{
auto
defining_info
=
(
*
param_map
)[
legacy_input_vars
[
0
]];
op_inputs
.
push_back
(
defining_info
.
value
);
// if src type is Tensor or a Vector<Tensor> with size <= 1
if
(
args
.
size
()
<=
1
)
{
for
(
const
auto
&
arg_name
:
args
)
{
auto
defining_info
=
(
*
param_map
)[
arg_name
];
op_inputs
.
push_back
(
defining_info
.
value
);
}
// if src type is Vector<Tesnor> , need an additional `CombineOp` to
// assemble them.
}
else
{
auto
*
combine_op
=
InsertCombineOperationForTarget
(
ctx
,
param_map
,
program
,
legacy_input_var
s
);
auto
*
combine_op
=
InsertCombineOperationForTarget
(
ctx
,
param_map
,
program
,
arg
s
);
op_inputs
.
push_back
(
combine_op
->
GetResultByIndex
(
0
));
}
}
return
op_inputs
;
}
inline
std
::
tuple
<
OpOutputTypeList
,
OpOutputMapping
>
GenerateOperationOutput
(
ir
::
IrContext
*
ctx
,
const
OpDesc
&
op_desc
,
const
OpOutputInfoList
&
output_infos
)
{
ir
::
IrContext
*
ctx
,
const
OpDesc
&
op_desc
)
{
OpOutputMapping
arg_to_idx
;
OpOutputTypeList
op_output_types
=
{};
auto
&
type_translator
=
TypeTranslator
::
instance
();
auto
&
op_normalizer
=
OpNameNormalizer
::
instance
();
const
BlockDesc
*
block
=
op_desc
.
Block
();
for
(
const
auto
&
n
:
op_desc
.
Outputs
())
{
auto
&
name
=
n
.
first
;
VLOG
(
10
)
<<
"[output translating]"
<<
"["
<<
op_desc
.
Type
()
<<
"]"
<<
name
;
auto
&
args
=
n
.
second
;
for
(
const
auto
&
info
:
output_infos
)
{
size_t
cur_output_idx
=
op_output_types
.
size
();
std
::
string
legacy_output_name
=
op_normalizer
.
GetLegacyArgName
(
op_desc
.
Type
(),
info
.
name
);
// return empty type if this arg is optional and not shown in OpDesc
// TODO(lyk): HasOutput doesnot consider variadic attribute
if
(
!
op_desc
.
HasOutput
(
legacy_output_name
))
{
VLOG
(
10
)
<<
"[output translating]"
<<
"["
<<
op_desc
.
Type
()
<<
"] optional "
<<
info
.
name
<<
" :"
<<
info
.
type_name
<<
" "
<<
legacy_output_name
;
PADDLE_ENFORCE
(
info
.
optional
,
platform
::
errors
::
PreconditionNotMet
(
"Op %s arg %s should be optional if it can be empty"
,
op_desc
.
Type
(),
legacy_output_name
));
op_output_types
.
push_back
(
ir
::
Type
(
nullptr
));
continue
;
}
const
auto
&
legacy_output_vars
=
op_desc
.
Output
(
legacy_output_name
);
bool
is_vector
=
(
info
.
type_name
.
find
(
"VectorType"
)
!=
std
::
string
::
npos
);
// if src type is Tensor
if
(
!
is_vector
)
{
VLOG
(
10
)
<<
"[output translating]"
<<
"["
<<
op_desc
.
Type
()
<<
"]"
<<
info
.
name
<<
" :"
<<
info
.
type_name
<<
" "
<<
legacy_output_name
;
if
(
legacy_output_vars
.
size
()
==
0
)
{
op_output_types
.
push_back
(
ir
::
Type
(
nullptr
));
continue
;
}
auto
&
var_name
=
legacy_output_vars
[
0
];
VarDesc
*
var
=
block
->
FindVarRecursive
(
var_name
);
VLOG
(
10
)
<<
"[output translating]"
<<
"["
<<
op_desc
.
Type
()
<<
"]"
<<
info
.
name
<<
" "
<<
var_name
<<
" "
<<
var
->
GetType
();
// if src type is Tensor or a Vector<Tensor> with size <= 1
if
(
args
.
size
()
<=
1
)
{
for
(
const
auto
&
arg_name
:
args
)
{
VarDesc
*
var
=
block
->
FindVarRecursive
(
arg_name
);
VLOG
(
10
)
<<
"[output translating]"
<<
"["
<<
op_desc
.
Type
()
<<
"]"
<<
name
<<
" "
<<
arg_name
<<
" "
<<
var
->
GetType
();
ir
::
Type
translated_var_type
=
type_translator
[
var
->
GetType
()](
ctx
,
*
var
);
ir
::
Type
translated_var_type
=
type_translator
[
var
->
GetType
()](
ctx
,
*
var
);
arg_to_idx
[
var_name
]
=
cur_output_idx
;
op_output_types
.
push_back
(
translated_var_type
);
arg_to_idx
[
arg_name
]
=
cur_output_idx
;
op_output_types
.
push_back
(
translated_var_type
);
}
// if src type is Vector<Tesnor>
}
else
{
VLOG
(
10
)
<<
"[output translating]"
<<
"["
<<
op_desc
.
Type
()
<<
"]"
<<
info
.
name
<<
" :"
<<
info
.
type_name
<<
" "
<<
legacy_output_name
;
std
::
vector
<
ir
::
Type
>
types
;
for
(
const
auto
&
var_name
:
legacy_output_var
s
)
{
VarDesc
*
var
=
block
->
FindVarRecursive
(
var
_name
);
for
(
const
auto
&
arg_name
:
arg
s
)
{
VarDesc
*
var
=
block
->
FindVarRecursive
(
arg
_name
);
VLOG
(
10
)
<<
"[output translating]"
<<
"["
<<
op_desc
.
Type
()
<<
"]"
<<
info
.
name
<<
" "
<<
var
_name
<<
"["
<<
op_desc
.
Type
()
<<
"]"
<<
name
<<
" "
<<
arg
_name
<<
" "
<<
var
->
GetType
();
ir
::
Type
translated_var_type
=
type_translator
[
var
->
GetType
()](
ctx
,
*
var
);
types
.
push_back
(
translated_var_type
);
arg_to_idx
[
var
_name
]
=
cur_output_idx
;
arg_to_idx
[
arg
_name
]
=
cur_output_idx
;
}
ir
::
Type
vec_type
=
ir
::
VectorType
::
get
(
ctx
,
types
);
op_output_types
.
push_back
(
vec_type
);
...
...
@@ -318,38 +246,6 @@ inline std::tuple<OpOutputTypeList, OpOutputMapping> GenerateOperationOutput(
return
{
op_output_types
,
arg_to_idx
};
}
inline
ir
::
AttributeMap
TranslateOpAttribute
(
std
::
string
normalized_op_name
,
const
OpAttributeInfoList
&
op_attr_infos
,
const
OpDesc
&
op_desc
)
{
auto
&
attribute_translator
=
AttributeTranslator
::
instance
();
auto
&
op_normalizer
=
OpNameNormalizer
::
instance
();
ir
::
AttributeMap
attribute_map
=
{};
for
(
const
auto
&
info
:
op_attr_infos
)
{
auto
legacy_attr_name
=
op_normalizer
.
GetLegacyAttrName
(
op_desc
.
Type
(),
info
.
name
);
paddle
::
framework
::
Attribute
legacy_attr
;
if
(
op_desc
.
HasAttr
(
legacy_attr_name
))
{
legacy_attr
=
op_desc
.
GetAttr
(
legacy_attr_name
);
}
VLOG
(
10
)
<<
"attribute in "
<<
op_desc
.
Type
()
<<
" name: "
<<
legacy_attr_name
<<
" "
<<
legacy_attr
.
index
();
ir
::
Attribute
new_attr
=
attribute_translator
(
info
.
type_name
,
legacy_attr
);
attribute_map
[
info
.
name
]
=
new_attr
;
if
(
!
new_attr
)
{
VLOG
(
0
)
<<
"empty attribute in "
<<
op_desc
.
Type
()
<<
" name: "
<<
info
.
name
;
}
else
{
VLOG
(
10
)
<<
"new attribute in "
<<
op_desc
.
Type
()
<<
" name: "
<<
info
.
name
<<
" "
<<
new_attr
.
storage
();
}
}
return
attribute_map
;
}
inline
void
RecordOpResultMapping
(
TranslationContext
*
param_map
,
const
OpDesc
&
op_desc
,
ir
::
Operation
*
operation
,
...
...
@@ -378,34 +274,15 @@ ir::Operation* GeneralOpHandler(ir::IrContext* ctx,
TranslationContext
*
param_map
,
ir
::
Program
*
program
,
const
OpDesc
&
op_desc
)
{
auto
op_info
=
LoopkUpOpInfo
(
ctx
,
op_desc
);
auto
*
op_info_concept
=
op_info
.
GetInterfaceImpl
<
paddle
::
dialect
::
GetOpInfoInterface
>
();
OpInputInfoList
input_infos
;
OpAttributeInfoList
attr_infos
;
OpOutputInfoList
output_infos
;
std
::
tie
(
input_infos
,
attr_infos
,
output_infos
,
std
::
ignore
)
=
op_info_concept
->
get_op_info_
();
auto
op_inputs
=
GenerateOperationInput
(
ctx
,
param_map
,
program
,
op_desc
,
op_info
.
name
(),
input_infos
);
auto
op_inputs
=
GenerateOperationInput
(
ctx
,
param_map
,
program
,
op_desc
);
OpOutputMapping
arg_to_idx
;
OpOutputTypeList
op_output_types
;
std
::
tie
(
op_output_types
,
arg_to_idx
)
=
GenerateOperationOutput
(
ctx
,
op_desc
,
output_infos
);
auto
attribute_map
=
TranslateOpAttribute
(
op_info
.
name
(),
attr_infos
,
op_desc
);
VLOG
(
4
)
<<
"[general op]["
<<
op_desc
.
Type
()
<<
"] preparation end."
;
OpOutputTypeList
op_output_types
=
{};
std
::
tie
(
op_output_types
,
arg_to_idx
)
=
GenerateOperationOutput
(
ctx
,
op_desc
);
auto
op_info
=
LoopkUpOpInfo
(
ctx
,
op_desc
);
ir
::
Operation
*
operation
=
ir
::
Operation
::
create
(
op_inputs
,
attribute_map
,
op_output_types
,
op_info
);
VLOG
(
4
)
<<
"[general op]["
<<
op_desc
.
Type
()
<<
"] opearation creation end."
;
ir
::
Operation
::
create
(
op_inputs
,
{},
op_output_types
,
op_info
);
program
->
block
()
->
push_back
(
operation
);
VLOG
(
4
)
<<
"[general op]["
<<
op_desc
.
Type
()
<<
"] opearation insertion end."
;
RecordOpResultMapping
(
param_map
,
op_desc
,
operation
,
arg_to_idx
);
return
operation
;
...
...
@@ -415,28 +292,14 @@ ir::Operation* FeedOpHandler(ir::IrContext* ctx,
TranslationContext
*
param_map
,
ir
::
Program
*
program
,
const
OpDesc
&
op_desc
)
{
auto
op_info
=
LoopkUpOpInfo
(
ctx
,
op_desc
);
auto
*
op_info_concept
=
op_info
.
GetInterfaceImpl
<
paddle
::
dialect
::
GetOpInfoInterface
>
();
OpInputInfoList
input_infos
;
OpAttributeInfoList
attr_infos
;
OpOutputInfoList
output_infos
;
std
::
tie
(
input_infos
,
attr_infos
,
output_infos
,
std
::
ignore
)
=
op_info_concept
->
get_op_info_
();
std
::
vector
<
ir
::
OpResult
>
op_inputs
;
std
::
vector
<
ir
::
OpResult
>
op_inputs
=
{};
OpOutputMapping
arg_to_idx
;
OpOutputTypeList
op_output_types
;
std
::
tie
(
op_output_types
,
arg_to_idx
)
=
GenerateOperationOutput
(
ctx
,
op_desc
,
output_infos
);
ir
::
AttributeMap
attribute_map
=
{
{
"name"
,
ir
::
StrAttribute
::
get
(
ctx
,
op_desc
.
OutputArgumentNames
()[
0
])},
};
OpOutputTypeList
op_output_types
=
{};
std
::
tie
(
op_output_types
,
arg_to_idx
)
=
GenerateOperationOutput
(
ctx
,
op_desc
);
auto
op_info
=
LoopkUpOpInfo
(
ctx
,
op_desc
);
ir
::
Operation
*
operation
=
ir
::
Operation
::
create
(
op_inputs
,
attribute_map
,
op_output_types
,
op_info
);
ir
::
Operation
::
create
(
op_inputs
,
{}
,
op_output_types
,
op_info
);
program
->
block
()
->
push_back
(
operation
);
RecordOpResultMapping
(
param_map
,
op_desc
,
operation
,
arg_to_idx
);
...
...
@@ -447,26 +310,12 @@ ir::Operation* FetchOpHandler(ir::IrContext* ctx,
TranslationContext
*
param_map
,
ir
::
Program
*
program
,
const
OpDesc
&
op_desc
)
{
auto
op_info
=
LoopkUpOpInfo
(
ctx
,
op_desc
);
auto
*
op_info_concept
=
op_info
.
GetInterfaceImpl
<
paddle
::
dialect
::
GetOpInfoInterface
>
();
OpInputInfoList
input_infos
;
OpAttributeInfoList
attr_infos
;
OpOutputInfoList
output_infos
;
std
::
tie
(
input_infos
,
attr_infos
,
output_infos
,
std
::
ignore
)
=
op_info_concept
->
get_op_info_
();
auto
op_inputs
=
GenerateOperationInput
(
ctx
,
param_map
,
program
,
op_desc
,
op_info
.
name
(),
input_infos
);
OpOutputTypeList
op_output_types
;
ir
::
AttributeMap
attribute_map
=
{
{
"name"
,
ir
::
StrAttribute
::
get
(
ctx
,
op_desc
.
InputArgumentNames
()[
0
])},
};
auto
op_inputs
=
GenerateOperationInput
(
ctx
,
param_map
,
program
,
op_desc
);
OpOutputTypeList
op_output_types
=
{};
auto
op_info
=
LoopkUpOpInfo
(
ctx
,
op_desc
);
ir
::
Operation
*
operation
=
ir
::
Operation
::
create
(
op_inputs
,
attribute_map
,
op_output_types
,
op_info
);
ir
::
Operation
::
create
(
op_inputs
,
{}
,
op_output_types
,
op_info
);
program
->
block
()
->
push_back
(
operation
);
return
operation
;
...
...
paddle/fluid/translator/program_translator.cc
浏览文件 @
343a9e95
...
...
@@ -76,7 +76,7 @@ void ProgramTranslator::ExtractParameterFromSingleBlock(
std
::
string
get_parameter_op_name
(
ir
::
GetParameterOp
::
name
());
ir
::
OpInfo
op_info
=
ctx
->
GetRegisteredOpInfo
(
get_parameter_op_name
);
std
::
unordered_map
<
std
::
string
,
ir
::
Attribute
>
op_attribute_map
=
{
{
"parameter_name"
,
ir
::
StrAttribute
::
get
(
ctx
,
var
->
Name
())},
{
var
->
Name
()
,
ir
::
StrAttribute
::
get
(
ctx
,
var
->
Name
())},
};
ir
::
Type
translated_var_type
=
type_translator
[
var
->
GetType
()](
ctx
,
*
var
);
ir
::
Operation
*
operation
=
ir
::
Operation
::
create
(
...
...
paddle/fluid/translator/program_translator.h
浏览文件 @
343a9e95
...
...
@@ -39,9 +39,9 @@ struct VariableDefiningInfo {
ir
::
OpResult
value
;
bool
generated_by_vector
=
false
;
// true if target variab
l
e is generated by Vector<Tensor>
false
;
// true if target variabe is generated by Vector<Tensor>
int
idx_in_vector
=
-
1
;
// positive if target variab
l
e is generated by Vector<Tensor>
-
1
;
// positive if target variabe is generated by Vector<Tensor>
};
using
TranslationContext
=
...
...
paddle/fluid/translator/utils.h
已删除
100644 → 0
浏览文件 @
37930a69
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <string_view>
namespace
paddle
{
namespace
translator
{
static
std
::
string
UnderscoreToCamelCase
(
std
::
string
str
)
{
std
::
string
camel_case
;
bool
next_upper
=
true
;
for
(
char
c
:
str
)
{
if
(
c
==
'_'
)
{
next_upper
=
true
;
}
else
{
if
(
next_upper
)
{
camel_case
+=
toupper
(
c
);
next_upper
=
false
;
}
else
{
camel_case
+=
c
;
}
}
}
return
camel_case
;
}
}
// namespace translator
}
// namespace paddle
test/cpp/ir/core/program_translator_test.cc
浏览文件 @
343a9e95
...
...
@@ -47,17 +47,17 @@ ProgramDesc load_from_file(const std::string &file_name) {
}
TEST
(
PaddleDialectTest
,
Translator
)
{
auto
p
=
load_from_file
(
"restnet50_main.prog"
)
;
EXPECT_EQ
(
p
.
Size
(),
1u
);
ir
::
IrContext
*
ctx
=
ir
::
IrContext
::
Instance
();
ctx
->
GetOrRegisterDialect
<
PaddleDialect
>
();
ctx
->
GetOrRegisterDialect
<
ir
::
Builtin
Dialect
>
();
auto
program
=
paddle
::
TranslateLegacyProgramToProgram
(
p
);
size_t
op_size
=
program
->
block
()
->
size
();
//
ops.size() = op size in BlockDesc + get_parameter_op + combine op
EXPECT_EQ
(
op_size
,
p
.
Block
(
0
).
OpSize
()
+
program
->
parameters_num
()
+
21
);
std
::
cout
<<
*
program
<<
std
::
endl
;
LOG
(
WARNING
)
<<
"TODO"
;
// auto p = load_from_file("restnet50_main.prog"
);
// EXPECT_EQ(p.Size(), 1u);
// ir::IrContext *ctx = ir::IrContext::Instance
();
// ctx->GetOrRegisterDialect<Paddle
Dialect>();
// ctx->GetOrRegisterDialect<ir::BuiltinDialect>(
);
// auto program = paddle::TranslateLegacyProgramToProgram(p);
//
size_t op_size = program->block()->size();
// // ops.size() = op size in BlockDesc + get_parameter_op + combine op
// EXPECT_EQ(op_size, p.Block(0).OpSize() + program->parameters_num() + 20);
// VLOG(0) << *program
;
}
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录