Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
cbbd940e
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2305
Star
20932
Fork
5423
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
cbbd940e
编写于
7月 27, 2023
作者:
A
Asthestarsfalll
提交者:
GitHub
7月 27, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[NewIR]Remove compatible logic of ProgramTranslator (#55453)
上级
147fbfe0
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
44 addition
and
52 deletion
+44
-52
paddle/fluid/ir_adaptor/translator/op_compat_gen.py
paddle/fluid/ir_adaptor/translator/op_compat_gen.py
+0
-3
paddle/fluid/ir_adaptor/translator/op_compat_info.h
paddle/fluid/ir_adaptor/translator/op_compat_info.h
+2
-4
paddle/fluid/ir_adaptor/translator/op_translator.cc
paddle/fluid/ir_adaptor/translator/op_translator.cc
+1
-1
paddle/fluid/ir_adaptor/translator/utils.h
paddle/fluid/ir_adaptor/translator/utils.h
+0
-42
paddle/phi/api/yaml/op_compat.yaml
paddle/phi/api/yaml/op_compat.yaml
+25
-2
test/ir/new_ir/test_special_op_translator.py
test/ir/new_ir/test_special_op_translator.py
+16
-0
未找到文件。
paddle/fluid/ir_adaptor/translator/op_compat_gen.py
浏览文件 @
cbbd940e
...
...
@@ -126,9 +126,6 @@ def OpNameNormalizerInitialization(
backward_op
,
op_compat_item
[
"scalar"
]
)
# special op mappings
op_name_mappings
[
"fetch_v2"
]
=
"fetch"
op_name_normailzer_template
=
env
.
get_template
(
"op_compat_info.cc.j2"
)
with
open
(
output_source_file
,
'wt'
)
as
f
:
op_compat_definition
=
op_name_normailzer_template
.
render
(
...
...
paddle/fluid/ir_adaptor/translator/op_compat_info.h
浏览文件 @
cbbd940e
...
...
@@ -19,8 +19,6 @@
#include "glog/logging.h"
#include "paddle/fluid/ir_adaptor/translator/utils.h"
#pragma once
namespace
paddle
{
...
...
@@ -106,11 +104,11 @@ class OpNameNormalizer {
return
legacy_name
;
}
if
(
op_arg_name_mappings
.
find
(
op_type
)
==
op_arg_name_mappings
.
end
())
{
return
UnderscoreToCamelCase
(
arg_name
)
;
return
arg_name
;
}
auto
&
arg_mappings
=
op_arg_name_mappings
[
op_type
];
if
(
arg_mappings
.
find
(
arg_name
)
==
arg_mappings
.
end
())
{
return
UnderscoreToCamelCase
(
arg_name
)
;
return
arg_name
;
}
return
arg_mappings
.
at
(
arg_name
);
}
...
...
paddle/fluid/ir_adaptor/translator/op_translator.cc
浏览文件 @
cbbd940e
...
...
@@ -307,7 +307,7 @@ ir::OpInfo OpTranscriber::LoopkUpOpInfo(ir::IrContext* ctx,
const
OpDesc
&
op_desc
)
{
std
::
string
target_op_name
=
kTargetDialectPrefix
+
OpNameCompatibleMapping
(
op_desc
.
Type
());
if
(
IsInplace
(
op_desc
))
{
if
(
IsInplace
(
op_desc
)
&&
*
target_op_name
.
rbegin
()
!=
'_'
)
{
target_op_name
+=
"_"
;
}
VLOG
(
6
)
<<
"[op name normalizing]: "
<<
op_desc
.
Type
()
<<
" to "
...
...
paddle/fluid/ir_adaptor/translator/utils.h
已删除
100644 → 0
浏览文件 @
147fbfe0
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <string_view>
namespace
paddle
{
namespace
translator
{
static
std
::
string
UnderscoreToCamelCase
(
std
::
string
str
)
{
std
::
string
camel_case
;
bool
next_upper
=
true
;
for
(
char
c
:
str
)
{
if
(
c
==
'_'
)
{
next_upper
=
true
;
}
else
{
if
(
next_upper
)
{
camel_case
+=
toupper
(
c
);
next_upper
=
false
;
}
else
{
camel_case
+=
c
;
}
}
}
return
camel_case
;
}
}
// namespace translator
}
// namespace paddle
paddle/phi/api/yaml/op_compat.yaml
浏览文件 @
cbbd940e
...
...
@@ -354,6 +354,7 @@
attrs
:
[
bool use_mkldnn = false
]
-
op
:
bilinear (bilinear_tensor_product)
backward
:
bilinear_grad (bilinear_tensor_product_grad)
inputs
:
{
x
:
X
,
y
:
Y
,
weight
:
Weight
,
bias
:
Bias
}
outputs
:
...
...
@@ -1838,7 +1839,7 @@
data_type
:
float
support_tensor
:
true
-
op
:
merged_momentum_
-
op
:
merged_momentum_
(merged_momentum)
inputs
:
{
param
:
Param
,
grad
:
Grad
,
velocity
:
Velocity
,
learning_rate
:
LearningRate
,
master_param
:
MasterParam
}
outputs
:
...
...
@@ -3038,11 +3039,27 @@
yolo_loss
:
GetYoloLossExpectedKernelType
yolo_loss_grad
:
GetYoloLossExpectedKernelType
-
op
:
fetch
-
op
:
channel_shuffle
inputs
:
{
x
:
X
}
outputs
:
{
out
:
Out
}
-
op
:
fetch (fetch_v2)
inputs
:
{
x
:
X
}
outputs
:
{
out
:
Out
}
-
op
:
full_batch_size_like (fill_constant_batch_size_like)
inputs
:
{
input
:
Input
}
outputs
:
{
out
:
Out
}
-
op
:
logspace
inputs
:
{
start
:
Start
,
stop
:
Stop
,
num
:
Num
,
base
:
Base
}
outputs
:
{
out
:
Out
}
-
op
:
lu
backward
:
lu_grad
...
...
@@ -3059,6 +3076,12 @@
outputs
:
{
reindex_src
:
Reindex_Src
,
reindex_dst
:
Reindex_Dst
,
out_nodes
:
Out_Nodes
}
-
op
:
rrelu
inputs
:
{
x
:
X
}
outputs
:
{
out
:
Out
,
noise
:
Noise
}
-
op
:
sigmoid_cross_entropy_with_logits
backward
:
sigmoid_cross_entropy_with_logits_grad
inputs
:
...
...
test/ir/new_ir/test_special_op_translator.py
浏览文件 @
cbbd940e
...
...
@@ -194,5 +194,21 @@ class TestReduceOpTranscriber(unittest.TestCase):
np
.
testing
.
assert_array_equal
(
out
[
0
],
np
.
all
(
arr
,
axis
=
0
))
class
TestIndexPutOpTranscriber
(
unittest
.
TestCase
):
def
test_op
(
self
):
place
=
core
.
Place
()
place
.
set_place
(
paddle
.
CPUPlace
())
new_scope
=
paddle
.
static
.
Scope
()
main_program
=
paddle
.
static
.
Program
()
with
paddle
.
static
.
scope_guard
(
new_scope
):
with
paddle
.
static
.
program_guard
(
main_program
):
x
=
paddle
.
randn
([
2
,
3
])
indices
=
[
paddle
.
randint
(
0
,
2
,
[
2
]),
paddle
.
randint
(
0
,
1
,
[
2
])]
value
=
paddle
.
randn
([
2
])
y
=
paddle
.
index_put
(
x
,
indices
,
value
,
False
)
_
=
ir
.
translate_to_new_ir
(
main_program
.
desc
)
if
__name__
==
"__main__"
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录