Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
X2Paddle
比较版本
49d90ae4c5350e0b08686afa3ca0a25371038905...ce96b0f5874a83f5b57cac8415fe3f056c8441a7
X
X2Paddle
项目概览
PaddlePaddle
/
X2Paddle
8 个月 前同步成功
通知
327
Star
698
Fork
167
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
26
列表
看板
标记
里程碑
合并请求
4
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
X
X2Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
26
Issue
26
列表
看板
标记
里程碑
合并请求
4
合并请求
4
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
源分支
ce96b0f5874a83f5b57cac8415fe3f056c8441a7
选择Git版本
...
目标分支
49d90ae4c5350e0b08686afa3ca0a25371038905
选择Git版本
比较
Commits (2)
https://gitcode.net/paddlepaddle/X2Paddle/-/commit/7c61e52187d7e859b88aacf23707797a82cf49f9
[Bug] Fixed ONNX Gemm bug (#917)
2022-12-08T11:16:01+08:00
WJJ1995
wjjisloser@163.com
* fixed Gemm bug * re-lint
https://gitcode.net/paddlepaddle/X2Paddle/-/commit/ce96b0f5874a83f5b57cac8415fe3f056c8441a7
[Bug] Fixed typo error (#920)
2022-12-08T17:36:42+08:00
WJJ1995
wjjisloser@163.com
* fixed Gemm bug * re-lint * fixed typo error
隐藏空白更改
内联
并排
Showing
2 changed file
with
60 addition
and
19 deletion
+60
-19
x2paddle/op_mapper/onnx2paddle/opset_legacy.py
x2paddle/op_mapper/onnx2paddle/opset_legacy.py
+59
-18
x2paddle/project_convertor/pytorch/torch2paddle/io.py
x2paddle/project_convertor/pytorch/torch2paddle/io.py
+1
-1
未找到文件。
x2paddle/op_mapper/onnx2paddle/opset_legacy.py
浏览文件 @
ce96b0f5
...
...
@@ -1637,29 +1637,70 @@ class OpSet():
"transpose_x"
:
trans_a
,
"transpose_y"
:
trans_b
,
}
self
.
paddle_graph
.
add_layer
(
'paddle.matmul'
,
inputs
=
matmul_inputs
,
outputs
=
[
val_mm
],
**
attr_matmul
)
self
.
paddle_graph
.
add_layer
(
"paddle.scale"
,
inputs
=
{
"x"
:
val_mm
},
outputs
=
[
val_mm
],
scale
=
alpha
)
if
beta
!=
0
:
if
beta
==
1.
:
add_inputs
=
{
"x"
:
val_mm
,
"y"
:
val_c
.
name
}
if
abs
(
alpha
-
1.0
)
<
1e-5
:
if
abs
(
beta
-
0.0
)
<
1e-5
:
self
.
paddle_graph
.
add_layer
(
"paddle.add"
,
inputs
=
add_inputs
,
outputs
=
[
node
.
name
])
'paddle.matmul'
,
inputs
=
matmul_inputs
,
outputs
=
[
node
.
name
],
**
attr_matmul
)
else
:
var_beta
=
node
.
name
+
'_beta'
self
.
paddle_graph
.
add_layer
(
'paddle.matmul'
,
inputs
=
matmul_inputs
,
outputs
=
[
val_mm
],
**
attr_matmul
)
if
abs
(
beta
-
1.0
)
<
1e-5
:
add_inputs
=
{
"x"
:
val_mm
,
"y"
:
val_c
.
name
}
self
.
paddle_graph
.
add_layer
(
"paddle.add"
,
inputs
=
add_inputs
,
outputs
=
[
node
.
name
])
else
:
var_beta
=
node
.
name
+
'_beta'
self
.
paddle_graph
.
add_layer
(
"paddle.scale"
,
inputs
=
{
"x"
:
val_c
.
name
},
outputs
=
[
var_beta
],
scale
=
beta
)
add_inputs
=
{
"x"
:
val_mm
,
"y"
:
var_beta
}
self
.
paddle_graph
.
add_layer
(
"paddle.add"
,
inputs
=
add_inputs
,
outputs
=
[
node
.
name
])
else
:
if
abs
(
beta
-
0.0
)
<
1e-5
:
self
.
paddle_graph
.
add_layer
(
'paddle.matmul'
,
inputs
=
matmul_inputs
,
outputs
=
[
val_mm
],
**
attr_matmul
)
self
.
paddle_graph
.
add_layer
(
"paddle.scale"
,
inputs
=
{
"x"
:
val_
c
.
name
},
outputs
=
[
var_beta
],
scale
=
bet
a
)
add_inputs
=
{
"x"
:
val_mm
,
"y"
:
var_beta
}
inputs
=
{
"x"
:
val_
mm
},
outputs
=
[
node
.
name
],
scale
=
alph
a
)
else
:
self
.
paddle_graph
.
add_layer
(
"paddle.add"
,
inputs
=
add_inputs
,
outputs
=
[
node
.
name
])
'paddle.matmul'
,
inputs
=
[
matmul_inputs
],
outputs
=
[
val_mm
],
**
attr_matmul
)
self
.
paddle_graph
.
add_layer
(
"paddle.scale"
,
inputs
=
{
"x"
:
val_mm
},
outputs
=
[
val_mm
],
scale
=
alpha
)
if
abs
(
beta
-
1.0
)
<
1e-5
:
add_inputs
=
{
"x"
:
val_mm
,
"y"
:
val_c
.
name
}
self
.
paddle_graph
.
add_layer
(
"paddle.add"
,
inputs
=
add_inputs
,
outputs
=
[
node
.
name
])
else
:
var_beta
=
node
.
name
+
'_beta'
self
.
paddle_graph
.
add_layer
(
"paddle.scale"
,
inputs
=
{
"x"
:
val_c
.
name
},
outputs
=
[
var_beta
],
scale
=
beta
)
add_inputs
=
{
"x"
:
val_mm
,
"y"
:
var_beta
}
self
.
paddle_graph
.
add_layer
(
"paddle.add"
,
inputs
=
add_inputs
,
outputs
=
[
node
.
name
])
@
print_mapping_info
def
Sum
(
self
,
node
):
...
...
x2paddle/project_convertor/pytorch/torch2paddle/io.py
浏览文件 @
ce96b0f5
...
...
@@ -169,7 +169,7 @@ class DataLoader(paddle.io.DataLoader):
timeout
=
timeout
,
worker_init_fn
=
worker_init_fn
)
if
sampler
is
not
None
:
sel
d
.
batch_sampler
.
sampler
=
sampler
sel
f
.
batch_sampler
.
sampler
=
sampler
class
DistributedSampler
(
paddle
.
io
.
DistributedBatchSampler
):
...
...