PaddlePaddle / Paddle
Commit 343a9e95, authored Jun 03, 2023 by kangguangli, committed via GitHub on Jun 03, 2023.

Revert "[IR] Support op attribute and refactor for new op definition (#54068)"

This reverts commit 37930a69.

Parent: 37930a69

Showing 10 changed files with 77 additions and 650 deletions (+77, -650).
Changed files:

  paddle/fluid/translator/attribute_translator.cc   +0   -231
  paddle/fluid/translator/attribute_translator.h    +0   -54
  paddle/fluid/translator/op_compat_gen.py          +6   -41
  paddle/fluid/translator/op_compat_info.cc.j2      +1   -13
  paddle/fluid/translator/op_compat_info.h          +0   -48
  paddle/fluid/translator/op_translator.cc          +54  -205
  paddle/fluid/translator/program_translator.cc     +1   -1
  paddle/fluid/translator/program_translator.h      +2   -2
  paddle/fluid/translator/utils.h                   +0   -42
  test/cpp/ir/core/program_translator_test.cc       +13  -13
paddle/fluid/translator/attribute_translator.cc (deleted, 100644 → 0)
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/translator/attribute_translator.h"
#include <string>
#include <vector>
#include "paddle/fluid/dialect/pd_attribute.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/int_array.h"
#include "paddle/phi/common/layout.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/common/scalar.h"
#include "paddle/utils/variant.h"
namespace paddle {
namespace translator {

class AttributeVisitor {
 public:
  ir::IrContext* ctx;
  AttributeVisitor() { ctx = ir::IrContext::Instance(); }
  ~AttributeVisitor() {}

 public:
  virtual ir::Attribute operator()(int i) {
    VLOG(10) << "translating int";
    return ir::Int32_tAttribute::get(ctx, i);
  }

  virtual ir::Attribute operator()(float f) {
    VLOG(10) << "translating float";
    return ir::FloatAttribute::get(ctx, f);
  }

  virtual ir::Attribute operator()(bool b) {
    VLOG(10) << "translating bool";
    return ir::BoolAttribute::get(ctx, b);
  }

  virtual ir::Attribute operator()(double d) {
    VLOG(10) << "translating double";
    return ir::DoubleAttribute::get(ctx, d);
  }

  virtual ir::Attribute operator()(std::string str) {
    VLOG(10) << "translating string";
    return ir::StrAttribute::get(ctx, str);
  }

  virtual ir::Attribute operator()(const paddle::experimental::Scalar& scalar) {
    VLOG(10) << "translating scalar";
    return paddle::dialect::ScalarAttribute::get(ctx, scalar);
  }

  virtual ir::Attribute operator()(const std::vector<std::string>& strs) {
    VLOG(10) << "translating vector<string>";
    std::vector<ir::Attribute> attrs;
    attrs.reserve(strs.size());
    for (const auto& v : strs) {
      attrs.push_back(ir::StrAttribute::get(ctx, v));
    }
    return ir::ArrayAttribute::get(ctx, attrs);
  }

  virtual ir::Attribute operator()(const std::vector<float>& fs) {
    VLOG(10) << "translating vector<float>";
    std::vector<ir::Attribute> attrs;
    attrs.reserve(fs.size());
    for (const auto& v : fs) {
      attrs.push_back(ir::FloatAttribute::get(ctx, v));
    }
    return ir::ArrayAttribute::get(ctx, attrs);
  }

  virtual ir::Attribute operator()(const std::vector<int>& is) {
    VLOG(10) << "translating vector<int>";
    std::vector<ir::Attribute> attrs;
    attrs.reserve(is.size());
    for (const auto& v : is) {
      attrs.push_back(ir::Int32_tAttribute::get(ctx, v));
    }
    return ir::ArrayAttribute::get(ctx, attrs);
  }

  virtual ir::Attribute operator()(const std::vector<bool>& bs) {
    VLOG(10) << "translating vector<bool>";
    std::vector<ir::Attribute> attrs;
    attrs.reserve(bs.size());
    for (const auto& v : bs) {
      attrs.push_back(ir::BoolAttribute::get(ctx, v));
    }
    return ir::ArrayAttribute::get(ctx, attrs);
  }

  virtual ir::Attribute operator()(const std::vector<int64_t>& i64s) {
    VLOG(10) << "translating vector<int64>";
    std::vector<ir::Attribute> attrs;
    attrs.reserve(i64s.size());
    for (const auto& v : i64s) {
      attrs.push_back(ir::Int64_tAttribute::get(ctx, v));
    }
    return ir::ArrayAttribute::get(ctx, attrs);
  }

  virtual ir::Attribute operator()(const std::vector<double>& ds) {
    VLOG(10) << "translating vector<double>";
    std::vector<ir::Attribute> attrs;
    attrs.reserve(ds.size());
    for (const auto& v : ds) {
      attrs.push_back(ir::DoubleAttribute::get(ctx, v));
    }
    return ir::ArrayAttribute::get(ctx, attrs);
  }

  virtual ir::Attribute operator()(
      const std::vector<paddle::experimental::Scalar>& ss) {
    VLOG(10) << "translating vector<scalar>";
    std::vector<ir::Attribute> attrs;
    attrs.reserve(ss.size());
    for (const auto& v : ss) {
      attrs.push_back(paddle::dialect::ScalarAttribute::get(ctx, v));
    }
    return ir::ArrayAttribute::get(ctx, attrs);
  }

  virtual ir::Attribute operator()(const paddle::blank& blank) {
    VLOG(10) << "translating paddle::blank";
    return ir::Attribute(nullptr);
  }

  template <typename T>
  ir::Attribute operator()(T attr) {
    VLOG(10) << "translating null type";
    return ir::Attribute(nullptr);
  }
};

class IntArrayAttributeVisitor : public AttributeVisitor {
 public:
  using AttributeVisitor::AttributeVisitor;
  ir::Attribute operator()(const std::vector<int>& is) override {
    VLOG(10) << "translating vector<int> to IntArray";
    phi::IntArray data(is);
    return paddle::dialect::IntArrayAttribute::get(ctx, data);
  }

  ir::Attribute operator()(const std::vector<int64_t>& is) override {
    VLOG(10) << "translating vector<int> to IntArray";
    phi::IntArray data(is);
    return paddle::dialect::IntArrayAttribute::get(ctx, data);
  }
};

class ScalarAttributeVisitor : public AttributeVisitor {
 public:
  using AttributeVisitor::AttributeVisitor;
  ir::Attribute operator()(int i) override {
    VLOG(10) << "translating int to Scalar";
    phi::Scalar data(i);
    return paddle::dialect::ScalarAttribute::get(ctx, data);
  }

  ir::Attribute operator()(float f) override {
    VLOG(10) << "translating float to Scalar";
    phi::Scalar data(f);
    return paddle::dialect::ScalarAttribute::get(ctx, data);
  }
};

class DataTypeAttributeVisitor : public AttributeVisitor {
 public:
  using AttributeVisitor::AttributeVisitor;
  ir::Attribute operator()(int i) override {
    VLOG(10) << "translating int to DataType: " << i;
    phi::DataType data = static_cast<phi::DataType>(i);
    return paddle::dialect::DataTypeAttribute::get(ctx, data);
  }
};

class PlaceAttributeVisitor : public AttributeVisitor {
 public:
  using AttributeVisitor::AttributeVisitor;
  ir::Attribute operator()(const paddle::blank& blank) override {
    VLOG(10) << "translating paddle::blank";
    phi::Place data(phi::AllocationType::CPU);
    return paddle::dialect::PlaceAttribute::get(ctx, data);
  }
};

AttributeTranslator::AttributeTranslator() {
  general_visitor = new AttributeVisitor();
  special_visitors["paddle::dialect::IntArrayAttribute"] =
      new IntArrayAttributeVisitor();
  special_visitors["paddle::dialect::ScalarAttribute"] =
      new ScalarAttributeVisitor();
  special_visitors["paddle::dialect::DataTypeAttribute"] =
      new DataTypeAttributeVisitor();
  special_visitors["paddle::dialect::PlaceAttribute"] =
      new PlaceAttributeVisitor();
}

ir::Attribute AttributeTranslator::operator()(const framework::Attribute& attr) {
  return paddle::visit(*general_visitor, attr);
}

ir::Attribute AttributeTranslator::operator()(const std::string& target_type,
                                              const framework::Attribute& attr) {
  if (special_visitors.find(target_type) == special_visitors.end()) {
    VLOG(10) << "[" << target_type << "] not found";
    return paddle::visit(*general_visitor, attr);
  }
  return paddle::visit(*(special_visitors.at(target_type)), attr);
}

}  // namespace translator
}  // namespace paddle
paddle/fluid/translator/attribute_translator.h (deleted, 100644 → 0)
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <string>
#include <unordered_map>
#include "paddle/fluid/framework/attribute.h"
#include "paddle/fluid/framework/type_defs.h"
#include "paddle/ir/core/attribute.h"
#include "paddle/ir/core/builtin_attribute.h"
#include "paddle/ir/core/ir_context.h"
#pragma once
namespace paddle {
namespace translator {

class AttributeVisitor;

class AttributeTranslator {
 private:
  AttributeTranslator();
  AttributeVisitor* general_visitor;
  std::unordered_map<std::string, AttributeVisitor*> special_visitors;

 public:
  AttributeTranslator(const AttributeTranslator&) = delete;
  AttributeTranslator& operator=(const AttributeTranslator&) = delete;
  AttributeTranslator(AttributeTranslator&&) = delete;
  AttributeTranslator& operator=(AttributeTranslator&&) = delete;

  static auto& instance() {
    static AttributeTranslator attribute_translator;
    return attribute_translator;
  }

  ir::Attribute operator()(const framework::Attribute& attr);
  ir::Attribute operator()(const std::string& target_type,
                           const framework::Attribute& attr);
};

}  // namespace translator
}  // namespace paddle
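For orientation, here is a minimal usage sketch of the deleted singleton; it is not part of this commit, and the attribute value and target-type string are illustrative. A legacy framework::Attribute is a paddle::variant over the types handled by AttributeVisitor, and paddle::visit dispatches it either through the general visitor or through the special visitor registered for the requested dialect type:

// Sketch only. Assumes the two deleted files above were still on the
// include path; nothing here is asserted about the post-revert tree.
#include <vector>
#include "paddle/fluid/translator/attribute_translator.h"

void AttributeTranslationExample() {
  auto& translator = paddle::translator::AttributeTranslator::instance();
  paddle::framework::Attribute legacy_attr = std::vector<int>{1, 2, 3};

  // General path: AttributeVisitor::operator()(const std::vector<int>&)
  // yields an ir::ArrayAttribute of ir::Int32_tAttribute elements.
  ir::Attribute generic = translator(legacy_attr);

  // Special path: IntArrayAttributeVisitor packs the same vector<int>
  // into a single paddle::dialect::IntArrayAttribute.
  ir::Attribute int_array =
      translator("paddle::dialect::IntArrayAttribute", legacy_attr);
}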
paddle/fluid/translator/op_compat_gen.py
@@ -14,7 +14,6 @@
 import argparse
 from pathlib import Path
-from typing import Dict

 import yaml
 from jinja2 import Environment, FileSystemLoader, StrictUndefined
@@ -34,7 +33,7 @@ def OpNameNormalizerInitialization(
     op_compat_yaml_file: str = "", output_source_file: str = ""
 ) -> None:
     def to_phi_and_fluid_op_name(op_item):
-        # Template: - op : phi_name (fluid_name)
+        # Templat: - op : phi_name (fluid_name)
         names = op_item.split('(')
         if len(names) == 1:
             phi_fluid_name = names[0].strip()
@@ -47,55 +46,21 @@ def OpNameNormalizerInitialization(
     with open(op_compat_yaml_file, "r") as f:
         op_compat_infos = yaml.safe_load(f)
     op_name_mappings = {}
-    op_arg_name_mappings = {}
     for op_compat_item in op_compat_infos:

-        def insert_new_mappings(op_name_str: str) -> str:
+        def insert_new_mappings(op_name_str):
             normalized_name, legacy_name = to_phi_and_fluid_op_name(op_name_str)
             if normalized_name == legacy_name:
-                return normalized_name, legacy_name
+                return
             op_name_mappings[legacy_name] = normalized_name
-            return normalized_name, legacy_name
-
-        def insert_new_arg_mappings(op_name: str, arg_mapping: Dict[str, str]):
-            if op_name is None:
-                return
-            if op_name not in op_arg_name_mappings:
-                op_arg_name_mappings[op_name] = {}
-            op_arg_name_mappings[op_name].update(arg_mapping)
-
-        _, legacy_name = insert_new_mappings(op_compat_item["op"])
-        legacy_backward_op_names = []
+        insert_new_mappings(op_compat_item["op"])
         if "backward" in op_compat_item:
-            backward_op_name_mapping_paris = op_compat_item["backward"].split(",")
-            for pair in backward_op_name_mapping_paris:
-                _, legacy_backward_op_name = insert_new_mappings(pair)
-                legacy_backward_op_names.append(legacy_backward_op_name)
-
-        if "inputs" in op_compat_item:
-            insert_new_arg_mappings(legacy_name, op_compat_item["inputs"])
-            for backward_op in legacy_backward_op_names:
-                insert_new_arg_mappings(backward_op, op_compat_item["inputs"])
-
-        if "attrs" in op_compat_item:
-            insert_new_arg_mappings(legacy_name, op_compat_item["attrs"])
-            for backward_op in legacy_backward_op_names:
-                insert_new_arg_mappings(backward_op, op_compat_item["attrs"])
-
-        if "outputs" in op_compat_item:
-            insert_new_arg_mappings(legacy_name, op_compat_item["outputs"])
-            for backward_op in legacy_backward_op_names:
-                insert_new_arg_mappings(backward_op, op_compat_item["outputs"])
-
-    # special op mappings
-    op_name_mappings["fetch_v2"] = "fetch"
+            insert_new_mappings(op_compat_item["backward"])

     op_name_normailzer_template = env.get_template("op_compat_info.cc.j2")
     with open(output_source_file, 'wt') as f:
         op_compat_definition = op_name_normailzer_template.render(
-            op_name_pairs=op_name_mappings,
-            op_arg_name_pairs=op_arg_name_mappings,
+            op_name_paris=op_name_mappings
         )
         f.write(op_compat_definition)
paddle/fluid/translator/op_compat_info.cc.j2
@@ -5,22 +5,10 @@ namespace translator {
 OpNameNormalizer::OpNameNormalizer() {
   op_name_mappings = {
-    {% for legacy_name, normalized_name in op_name_pairs.items() %}
+    {% for legacy_name, normalized_name in op_name_paris.items() %}
     { "{{legacy_name}}", "{{normalized_name}}" },
     {% endfor %}
   };
-  op_arg_name_mappings = {
-    {% for op_name, arg_name_mappings in op_arg_name_pairs.items() %}
-    {
-      "{{op_name}}",
-      {
-        {% for normalized_name, legacy_name in arg_name_mappings.items() %}
-        { "{{normalized_name}}", "{{legacy_name}}" },
-        {% endfor %}
-      },
-    },
-    {% endfor %}
-  };
 }
 } // namespace translator
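For concreteness, here is a hand-rendered sketch of the C++ this template produces after the revert, assuming a single illustrative pair {"matmul_v2", "matmul"}; the dropped op_arg_name_mappings initializer is exactly the block removed above:

// Hypothetical output of op_compat_info.cc.j2 after the revert.
namespace paddle {
namespace translator {

OpNameNormalizer::OpNameNormalizer() {
  op_name_mappings = {
      { "matmul_v2", "matmul" },
  };
  // Before the revert, an op_arg_name_mappings initializer was also
  // rendered here from op_arg_name_pairs.
}

}  // namespace translator
}  // namespace paddle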
paddle/fluid/translator/op_compat_info.h
@@ -12,14 +12,11 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.

-#include <functional>
 #include <string>
 #include <unordered_map>

 #include "glog/logging.h"
-#include "paddle/fluid/translator/utils.h"

 #pragma once

 namespace paddle {
@@ -29,8 +26,6 @@ class OpNameNormalizer {
  private:
   OpNameNormalizer();  // Disallow instantiation outside of the class.
   std::unordered_map<std::string, std::string> op_name_mappings;
-  std::unordered_map<std::string, std::unordered_map<std::string, std::string>>
-      op_arg_name_mappings;

  public:
   OpNameNormalizer(const OpNameNormalizer&) = delete;
@@ -49,49 +44,6 @@ class OpNameNormalizer {
     }
     return op_name_mappings.at(op_type);
   }

-  std::string GetLegacyArgName(const std::string& op_type,
-                               const std::string& arg_name) {
-    bool is_grad_op = (op_type.find("grad") != std::string::npos);
-    bool is_grad_arg = (arg_name.find("grad") != std::string::npos);
-    if (is_grad_op && is_grad_arg) {
-      std::string target = "_grad";
-      std::string data = "@GRAD";
-      size_t first_grad_pos = arg_name.find_first_of(target);
-      std::string legacy_name =
-          this->GetLegacyArgName(op_type, arg_name.substr(0, first_grad_pos));
-      legacy_name += arg_name.substr(first_grad_pos);
-      for (size_t pos = 0;
-           legacy_name.npos != (pos = legacy_name.find(target, pos));
-           pos += data.length()) {
-        legacy_name.replace(pos, target.length(), data);
-      }
-      return legacy_name;
-    }
-    if (op_arg_name_mappings.find(op_type) == op_arg_name_mappings.end()) {
-      return UnderscoreToCamelCase(arg_name);
-    }
-    auto& arg_mappings = op_arg_name_mappings[op_type];
-    if (arg_mappings.find(arg_name) == arg_mappings.end()) {
-      return UnderscoreToCamelCase(arg_name);
-    }
-    return arg_mappings.at(arg_name);
-  }
-
-  std::string GetLegacyAttrName(const std::string& op_type,
-                                const std::string& arg_name) {
-    if (op_arg_name_mappings.find(op_type) == op_arg_name_mappings.end()) {
-      VLOG(10) << "[" << op_type << "] not found";
-      return arg_name;
-    }
-    auto& arg_mappings = op_arg_name_mappings[op_type];
-    if (arg_mappings.find(arg_name) == arg_mappings.end()) {
-      VLOG(10) << "[" << op_type << "][" << arg_name << "] not found";
-      return arg_name;
-    }
-    return arg_mappings.at(arg_name);
-  }
 };

 }  // namespace translator
paddle/fluid/translator/op_translator.cc
@@ -15,23 +15,19 @@
 #include "paddle/fluid/translator/op_translator.h"

 #include <algorithm>
-#include <cctype>
 #include <numeric>
 #include <string>
 #include <tuple>
 #include <unordered_map>
 #include <vector>

-#include "paddle/fluid/dialect/pd_interface.h"
 #include "paddle/fluid/framework/op_desc.h"
-#include "paddle/fluid/translator/attribute_translator.h"
 #include "paddle/fluid/translator/op_compat_info.h"
 #include "paddle/fluid/translator/program_translator.h"
 #include "paddle/fluid/translator/type_translator.h"
 #include "paddle/ir/core/builtin_op.h"
 #include "paddle/ir/core/builtin_type.h"
 #include "paddle/ir/core/ir_context.h"
-#include "paddle/ir/core/operation.h"
 #include "paddle/ir/core/value.h"
 #include "paddle/phi/core/enforce.h"
@@ -46,24 +42,11 @@ using BlockDesc = paddle::framework::BlockDesc;
 using VarDesc = paddle::framework::VarDesc;
 using OpOutputTypeList = std::vector<ir::Type>;
 using OpOutputMapping = std::unordered_map<std::string, ResultIdx>;
-using OpInputInfo = paddle::dialect::OpInputInfo;
-using OpInputInfoList = std::vector<paddle::dialect::OpInputInfo>;
-using OpAttributeInfo = paddle::dialect::OpAttributeInfo;
-using OpAttributeInfoList = std::vector<paddle::dialect::OpAttributeInfo>;
-using OpOutputInfo = paddle::dialect::OpOutputInfo;
-using OpOutputInfoList = std::vector<paddle::dialect::OpOutputInfo>;

 static const char kTargetDialectPrefix[] = "pd.";

-static const std::unordered_set<std::string> special_inplace_ops = {
-    "batch_norm",
-};
-
 inline bool IsInplace(const OpDesc& op_desc) {
   bool inplace = false;
-  if (special_inplace_ops.count(op_desc.Type())) {
-    return inplace;
-  }
   auto input_names = op_desc.InputArgumentNames();
   auto output_names = op_desc.OutputArgumentNames();
@@ -146,7 +129,7 @@ inline ir::Operation* InsertCombineOperationForTarget(
   std::vector<ir::OpResult> src_values;
   std::vector<ir::Type> types_in_vec;
-  for (const auto& arg_name : args) {
+  for (auto arg_name : args) {
     auto defining_info = param_map->at(arg_name);
     src_values.push_back(defining_info.value);
     types_in_vec.push_back(defining_info.value.type());
@@ -158,25 +141,13 @@ inline ir::Operation* InsertCombineOperationForTarget(
   return operation;
 }

-inline ir::Operation* InsertConstantOperationForOptionalArg(
-    ir::IrContext* ctx, ir::Program* program) {
-  std::string constant_op_name(ir::ConstantOp::name());
-  ir::OpInfo op_info = ctx->GetRegisteredOpInfo(constant_op_name);
-
-  ir::Type null_type = ir::Type(nullptr);
-  ir::Operation* operation =
-      ir::Operation::create({}, {}, {null_type}, op_info);
-  program->block()->push_back(operation);
-  return operation;
-}
-
 inline std::vector<ir::OpResult> GenerateOperationInput(
     ir::IrContext* ctx,
     TranslationContext* param_map,
     ir::Program* program,
-    const OpDesc& op_desc,
-    const std::string& normalized_op_name,
-    const OpInputInfoList& input_infos) {
+    const OpDesc& op_desc) {
+  std::vector<ir::OpResult> op_inputs = {};
+
   // scan all inputs to see if any of them is generated as a vector<Tensor>
   // so need an additional `SliceOp` to take it out.
   for (const auto& n : op_desc.Inputs()) {
@@ -188,7 +159,7 @@ inline std::vector<ir::OpResult> GenerateOperationInput(
         param_map->count(arg_name),
         0,
         platform::errors::PreconditionNotMet(
-            "arg %s.%s as input should be exists before prasing %s",
+            "arg %s.%s as input should be exists before prasing %d",
             name,
             arg_name,
             op_desc.Type()));
@@ -200,116 +171,73 @@ inline std::vector<ir::OpResult> GenerateOperationInput(
     }
   }

-  std::vector<ir::OpResult> op_inputs;
-  auto& op_normalizer = OpNameNormalizer::instance();
-
-  for (const auto& info : input_infos) {
-    std::string legacy_input_name =
-        op_normalizer.GetLegacyArgName(op_desc.Type(), info.name);
-
-    // return empty OpResult if this arg is optional and not shown in OpDesc
-    // TODO(lyk): HasInput doesnot consider variadic attribute
-    if (!op_desc.HasInput(legacy_input_name)) {
-      PADDLE_ENFORCE(info.optional,
-                     platform::errors::PreconditionNotMet(
-                         "Op %s arg %s should be optional if it can be empty",
-                         op_desc.Type(),
-                         legacy_input_name));
-      op_inputs.push_back(ir::OpResult(nullptr));
-      continue;
-    }
-
-    const auto& legacy_input_vars = op_desc.Input(legacy_input_name, true);
-    bool is_vector = (info.type_name.find("VectorType") != std::string::npos);
-
-    // if src type is Tensor
-    if (!is_vector) {
-      auto defining_info = (*param_map)[legacy_input_vars[0]];
-      op_inputs.push_back(defining_info.value);
-
-    // if src type is Vector<Tesnor> , need an additional `CombineOp` to
-    // assemble them.
-    } else {
-      auto* combine_op = InsertCombineOperationForTarget(
-          ctx, param_map, program, legacy_input_vars);
-      op_inputs.push_back(combine_op->GetResultByIndex(0));
-    }
-  }
+  for (const auto& n : op_desc.Inputs()) {
+    auto& name = n.first;
+    VLOG(10) << "[input retriving]"
+             << "[" << op_desc.Type() << "]" << name;
+    auto& args = n.second;
+
+    // if src type is Tensor or a Vector<Tensor> with size <= 1
+    if (args.size() <= 1) {
+      for (const auto& arg_name : args) {
+        auto defining_info = (*param_map)[arg_name];
+        op_inputs.push_back(defining_info.value);
+      }
+
+    // if src type is Vector<Tesnor> , need an additional `CombineOp` to
+    // assemble them.
+    } else {
+      auto* combine_op =
+          InsertCombineOperationForTarget(ctx, param_map, program, args);
+      op_inputs.push_back(combine_op->GetResultByIndex(0));
+    }
+  }

   return op_inputs;
 }

 inline std::tuple<OpOutputTypeList, OpOutputMapping> GenerateOperationOutput(
-    ir::IrContext* ctx,
-    const OpDesc& op_desc,
-    const OpOutputInfoList& output_infos) {
+    ir::IrContext* ctx, const OpDesc& op_desc) {
   OpOutputMapping arg_to_idx;
   OpOutputTypeList op_output_types = {};

   auto& type_translator = TypeTranslator::instance();
-  auto& op_normalizer = OpNameNormalizer::instance();

   const BlockDesc* block = op_desc.Block();

-  for (const auto& info : output_infos) {
+  for (const auto& n : op_desc.Outputs()) {
+    auto& name = n.first;
+    VLOG(10) << "[output translating]"
+             << "[" << op_desc.Type() << "]" << name;
+    auto& args = n.second;
     size_t cur_output_idx = op_output_types.size();
-    std::string legacy_output_name =
-        op_normalizer.GetLegacyArgName(op_desc.Type(), info.name);
-
-    // return empty type if this arg is optional and not shown in OpDesc
-    // TODO(lyk): HasOutput doesnot consider variadic attribute
-    if (!op_desc.HasOutput(legacy_output_name)) {
-      VLOG(10) << "[output translating]"
-               << "[" << op_desc.Type() << "] optional " << info.name << " :"
-               << info.type_name << " " << legacy_output_name;
-      PADDLE_ENFORCE(info.optional,
-                     platform::errors::PreconditionNotMet(
-                         "Op %s arg %s should be optional if it can be empty",
-                         op_desc.Type(),
-                         legacy_output_name));
-      op_output_types.push_back(ir::Type(nullptr));
-      continue;
-    }
-
-    const auto& legacy_output_vars = op_desc.Output(legacy_output_name);
-    bool is_vector = (info.type_name.find("VectorType") != std::string::npos);
-
-    // if src type is Tensor
-    if (!is_vector) {
-      VLOG(10) << "[output translating]"
-               << "[" << op_desc.Type() << "]" << info.name << " :"
-               << info.type_name << " " << legacy_output_name;
-      if (legacy_output_vars.size() == 0) {
-        op_output_types.push_back(ir::Type(nullptr));
-        continue;
-      }
-
-      auto& var_name = legacy_output_vars[0];
-      VarDesc* var = block->FindVarRecursive(var_name);
-      VLOG(10) << "[output translating]"
-               << "[" << op_desc.Type() << "]" << info.name << " " << var_name
-               << " " << var->GetType();
-
-      ir::Type translated_var_type = type_translator[var->GetType()](ctx, *var);
-
-      arg_to_idx[var_name] = cur_output_idx;
-      op_output_types.push_back(translated_var_type);
-
-    // if src type is Vector<Tesnor>
-    } else {
-      VLOG(10) << "[output translating]"
-               << "[" << op_desc.Type() << "]" << info.name << " :"
-               << info.type_name << " " << legacy_output_name;
-
-      std::vector<ir::Type> types;
-      for (const auto& var_name : legacy_output_vars) {
-        VarDesc* var = block->FindVarRecursive(var_name);
-        VLOG(10) << "[output translating]"
-                 << "[" << op_desc.Type() << "]" << info.name << " "
-                 << var_name << " " << var->GetType();
-        ir::Type translated_var_type =
-            type_translator[var->GetType()](ctx, *var);
-        types.push_back(translated_var_type);
-        arg_to_idx[var_name] = cur_output_idx;
-      }
-      ir::Type vec_type = ir::VectorType::get(ctx, types);
-      op_output_types.push_back(vec_type);
+    // if src type is Tensor or a Vector<Tensor> with size <= 1
+    if (args.size() <= 1) {
+      for (const auto& arg_name : args) {
+        VarDesc* var = block->FindVarRecursive(arg_name);
+        VLOG(10) << "[output translating]"
+                 << "[" << op_desc.Type() << "]" << name << " " << arg_name
+                 << " " << var->GetType();
+        ir::Type translated_var_type =
+            type_translator[var->GetType()](ctx, *var);
+        arg_to_idx[arg_name] = cur_output_idx;
+        op_output_types.push_back(translated_var_type);
+      }
+
+    // if src type is Vector<Tesnor>
+    } else {
+      std::vector<ir::Type> types;
+      for (const auto& arg_name : args) {
+        VarDesc* var = block->FindVarRecursive(arg_name);
+        VLOG(10) << "[output translating]"
+                 << "[" << op_desc.Type() << "]" << name << " " << arg_name
+                 << " " << var->GetType();
+        ir::Type translated_var_type =
+            type_translator[var->GetType()](ctx, *var);
+        types.push_back(translated_var_type);
+        arg_to_idx[arg_name] = cur_output_idx;
+      }
+      ir::Type vec_type = ir::VectorType::get(ctx, types);
+      op_output_types.push_back(vec_type);
@@ -318,38 +246,6 @@ inline std::tuple<OpOutputTypeList, OpOutputMapping> GenerateOperationOutput(
   return {op_output_types, arg_to_idx};
 }

-inline ir::AttributeMap TranslateOpAttribute(
-    std::string normalized_op_name,
-    const OpAttributeInfoList& op_attr_infos,
-    const OpDesc& op_desc) {
-  auto& attribute_translator = AttributeTranslator::instance();
-  auto& op_normalizer = OpNameNormalizer::instance();
-  ir::AttributeMap attribute_map = {};
-
-  for (const auto& info : op_attr_infos) {
-    auto legacy_attr_name =
-        op_normalizer.GetLegacyAttrName(op_desc.Type(), info.name);
-
-    paddle::framework::Attribute legacy_attr;
-    if (op_desc.HasAttr(legacy_attr_name)) {
-      legacy_attr = op_desc.GetAttr(legacy_attr_name);
-    }
-    VLOG(10) << "attribute in " << op_desc.Type()
-             << " name: " << legacy_attr_name << " " << legacy_attr.index();
-    ir::Attribute new_attr = attribute_translator(info.type_name, legacy_attr);
-    attribute_map[info.name] = new_attr;
-    if (!new_attr) {
-      VLOG(0) << "empty attribute in " << op_desc.Type()
-              << " name: " << info.name;
-    } else {
-      VLOG(10) << "new attribute in " << op_desc.Type()
-               << " name: " << info.name << " " << new_attr.storage();
-    }
-  }
-
-  return attribute_map;
-}
-
 inline void RecordOpResultMapping(TranslationContext* param_map,
                                   const OpDesc& op_desc,
                                   ir::Operation* operation,
@@ -378,34 +274,15 @@ ir::Operation* GeneralOpHandler(ir::IrContext* ctx,
                                 TranslationContext* param_map,
                                 ir::Program* program,
                                 const OpDesc& op_desc) {
-  auto op_info = LoopkUpOpInfo(ctx, op_desc);
-
-  auto* op_info_concept =
-      op_info.GetInterfaceImpl<paddle::dialect::GetOpInfoInterface>();
-  OpInputInfoList input_infos;
-  OpAttributeInfoList attr_infos;
-  OpOutputInfoList output_infos;
-  std::tie(input_infos, attr_infos, output_infos, std::ignore) =
-      op_info_concept->get_op_info_();
-
-  auto op_inputs = GenerateOperationInput(
-      ctx, param_map, program, op_desc, op_info.name(), input_infos);
+  auto op_inputs = GenerateOperationInput(ctx, param_map, program, op_desc);

   OpOutputMapping arg_to_idx;
-  OpOutputTypeList op_output_types;
-  std::tie(op_output_types, arg_to_idx) =
-      GenerateOperationOutput(ctx, op_desc, output_infos);
-
-  auto attribute_map = TranslateOpAttribute(op_info.name(), attr_infos, op_desc);
-  VLOG(4) << "[general op][" << op_desc.Type() << "] preparation end.";
+  OpOutputTypeList op_output_types = {};
+  std::tie(op_output_types, arg_to_idx) = GenerateOperationOutput(ctx, op_desc);

+  auto op_info = LoopkUpOpInfo(ctx, op_desc);
   ir::Operation* operation =
-      ir::Operation::create(op_inputs, attribute_map, op_output_types, op_info);
-  VLOG(4) << "[general op][" << op_desc.Type() << "] opearation creation end.";
+      ir::Operation::create(op_inputs, {}, op_output_types, op_info);
   program->block()->push_back(operation);
-  VLOG(4) << "[general op][" << op_desc.Type() << "] opearation insertion end.";

   RecordOpResultMapping(param_map, op_desc, operation, arg_to_idx);

   return operation;
@@ -415,28 +292,14 @@ ir::Operation* FeedOpHandler(ir::IrContext* ctx,
                              TranslationContext* param_map,
                              ir::Program* program,
                              const OpDesc& op_desc) {
-  auto op_info = LoopkUpOpInfo(ctx, op_desc);
-
-  std::vector<ir::OpResult> op_inputs = {};
-  auto* op_info_concept =
-      op_info.GetInterfaceImpl<paddle::dialect::GetOpInfoInterface>();
-  OpInputInfoList input_infos;
-  OpAttributeInfoList attr_infos;
-  OpOutputInfoList output_infos;
-  std::tie(input_infos, attr_infos, output_infos, std::ignore) =
-      op_info_concept->get_op_info_();
+  std::vector<ir::OpResult> op_inputs;

   OpOutputMapping arg_to_idx;
-  OpOutputTypeList op_output_types;
-  std::tie(op_output_types, arg_to_idx) =
-      GenerateOperationOutput(ctx, op_desc, output_infos);
-
-  ir::AttributeMap attribute_map = {
-      {"name", ir::StrAttribute::get(ctx, op_desc.OutputArgumentNames()[0])},
-  };
+  OpOutputTypeList op_output_types = {};
+  std::tie(op_output_types, arg_to_idx) = GenerateOperationOutput(ctx, op_desc);
+  auto op_info = LoopkUpOpInfo(ctx, op_desc);

   ir::Operation* operation =
-      ir::Operation::create(op_inputs, attribute_map, op_output_types, op_info);
+      ir::Operation::create(op_inputs, {}, op_output_types, op_info);
   program->block()->push_back(operation);
   RecordOpResultMapping(param_map, op_desc, operation, arg_to_idx);
@@ -447,26 +310,12 @@ ir::Operation* FetchOpHandler(ir::IrContext* ctx,
                               TranslationContext* param_map,
                               ir::Program* program,
                               const OpDesc& op_desc) {
-  auto op_info = LoopkUpOpInfo(ctx, op_desc);
-
-  auto* op_info_concept =
-      op_info.GetInterfaceImpl<paddle::dialect::GetOpInfoInterface>();
-  OpInputInfoList input_infos;
-  OpAttributeInfoList attr_infos;
-  OpOutputInfoList output_infos;
-  std::tie(input_infos, attr_infos, output_infos, std::ignore) =
-      op_info_concept->get_op_info_();
-
-  auto op_inputs = GenerateOperationInput(
-      ctx, param_map, program, op_desc, op_info.name(), input_infos);
-
-  OpOutputTypeList op_output_types;
-  ir::AttributeMap attribute_map = {
-      {"name", ir::StrAttribute::get(ctx, op_desc.InputArgumentNames()[0])},
-  };
+  auto op_inputs = GenerateOperationInput(ctx, param_map, program, op_desc);

+  OpOutputTypeList op_output_types = {};
+  auto op_info = LoopkUpOpInfo(ctx, op_desc);
   ir::Operation* operation =
-      ir::Operation::create(op_inputs, attribute_map, op_output_types, op_info);
+      ir::Operation::create(op_inputs, {}, op_output_types, op_info);
   program->block()->push_back(operation);

   return operation;
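One helper worth unpacking: IsInplace flags an op whose input and output argument name lists overlap, and the removed special_inplace_ops set short-circuited that check for batch_norm. Its body is elided in the diff above, so the following is only a plausible self-contained sketch of such an overlap test (HasInplaceArgument is a hypothetical name, not a Paddle API):

#include <algorithm>
#include <iterator>
#include <string>
#include <vector>

// Hypothetical stand-in for the elided overlap test that IsInplace
// performs on op_desc.InputArgumentNames() / OutputArgumentNames().
bool HasInplaceArgument(std::vector<std::string> input_names,
                        std::vector<std::string> output_names) {
  // std::set_intersection needs sorted ranges; the parameters are taken
  // by value so the caller's vectors stay untouched.
  std::sort(input_names.begin(), input_names.end());
  std::sort(output_names.begin(), output_names.end());
  std::vector<std::string> shared;
  std::set_intersection(input_names.begin(), input_names.end(),
                        output_names.begin(), output_names.end(),
                        std::back_inserter(shared));
  return !shared.empty();
}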
paddle/fluid/translator/program_translator.cc
@@ -76,7 +76,7 @@ void ProgramTranslator::ExtractParameterFromSingleBlock(
       std::string get_parameter_op_name(ir::GetParameterOp::name());
       ir::OpInfo op_info = ctx->GetRegisteredOpInfo(get_parameter_op_name);
       std::unordered_map<std::string, ir::Attribute> op_attribute_map = {
-          {"parameter_name", ir::StrAttribute::get(ctx, var->Name())},
+          {var->Name(), ir::StrAttribute::get(ctx, var->Name())},
       };
       ir::Type translated_var_type = type_translator[var->GetType()](ctx, *var);
       ir::Operation* operation = ir::Operation::create(
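The one-line change swaps the attribute key on the builtin get_parameter op: a fixed "parameter_name" key before the revert, the variable's own name as the key after it. A minimal sketch of the two shapes, with std::string standing in for ir::Attribute and "fc_0.w_0" as an illustrative variable name:

#include <string>
#include <unordered_map>

using AttrMapSketch = std::unordered_map<std::string, std::string>;

// Before the revert: consumers read a fixed key.
AttrMapSketch BeforeRevert(const std::string& var_name) {
  return {{"parameter_name", var_name}};  // {"parameter_name", "fc_0.w_0"}
}

// After the revert: the variable name is itself the key, so a consumer
// must already know the name in order to look the attribute up.
AttrMapSketch AfterRevert(const std::string& var_name) {
  return {{var_name, var_name}};  // {"fc_0.w_0", "fc_0.w_0"}
}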
paddle/fluid/translator/program_translator.h
@@ -39,9 +39,9 @@ struct VariableDefiningInfo {
   ir::OpResult value;

   bool generated_by_vector =
-      false;  // true if target variable is generated by Vector<Tensor>
+      false;  // true if target variabe is generated by Vector<Tensor>
   int idx_in_vector =
-      -1;  // positive if target variable is generated by Vector<Tensor>
+      -1;  // positive if target variabe is generated by Vector<Tensor>
 };

 using TranslationContext =
paddle/fluid/translator/utils.h (deleted, 100644 → 0)
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <string_view>
namespace paddle {
namespace translator {

static std::string UnderscoreToCamelCase(std::string str) {
  std::string camel_case;
  bool next_upper = true;
  for (char c : str) {
    if (c == '_') {
      next_upper = true;
    } else {
      if (next_upper) {
        camel_case += toupper(c);
        next_upper = false;
      } else {
        camel_case += c;
      }
    }
  }
  return camel_case;
}

}  // namespace translator
}  // namespace paddle
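For reference, the deleted helper upper-cases the first character and any character that follows an underscore, and drops the underscores themselves. A hand-traced check (a sketch, assuming the deleted header were still includable; the argument strings are illustrative):

#include <cassert>
#include <string>

void UnderscoreToCamelCaseExamples() {
  using paddle::translator::UnderscoreToCamelCase;
  assert(UnderscoreToCamelCase("x") == "X");
  assert(UnderscoreToCamelCase("drop_out") == "DropOut");
  assert(UnderscoreToCamelCase("x_grad") == "XGrad");
}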
test/cpp/ir/core/program_translator_test.cc
@@ -47,17 +47,17 @@ ProgramDesc load_from_file(const std::string &file_name) {
 }

 TEST(PaddleDialectTest, Translator) {
-  auto p = load_from_file("restnet50_main.prog");
-  EXPECT_EQ(p.Size(), 1u);
-
-  ir::IrContext *ctx = ir::IrContext::Instance();
-  ctx->GetOrRegisterDialect<PaddleDialect>();
-  ctx->GetOrRegisterDialect<ir::BuiltinDialect>();
-  auto program = paddle::TranslateLegacyProgramToProgram(p);
-
-  size_t op_size = program->block()->size();
-  // ops.size() = op size in BlockDesc + get_parameter_op + combine op
-  EXPECT_EQ(op_size, p.Block(0).OpSize() + program->parameters_num() + 21);
-
-  std::cout << *program << std::endl;
+  LOG(WARNING) << "TODO";
+  // auto p = load_from_file("restnet50_main.prog");
+  // EXPECT_EQ(p.Size(), 1u);
+
+  // ir::IrContext *ctx = ir::IrContext::Instance();
+  // ctx->GetOrRegisterDialect<PaddleDialect>();
+  // ctx->GetOrRegisterDialect<ir::BuiltinDialect>();
+  // auto program = paddle::TranslateLegacyProgramToProgram(p);
+
+  // size_t op_size = program->block()->size();
+  // // ops.size() = op size in BlockDesc + get_parameter_op + combine op
+  // EXPECT_EQ(op_size, p.Block(0).OpSize() + program->parameters_num() + 20);
+
+  // VLOG(0) << *program;
 }