Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
X2Paddle
提交
3c80edf9
X
X2Paddle
项目概览
PaddlePaddle
/
X2Paddle
大约 1 年 前同步成功
通知
328
Star
698
Fork
167
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
26
列表
看板
标记
里程碑
合并请求
4
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
X
X2Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
26
Issue
26
列表
看板
标记
里程碑
合并请求
4
合并请求
4
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
3c80edf9
编写于
12月 03, 2020
作者:
S
SunAhong1993
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix the gather
上级
e8df9aec
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
49 addition
and
27 deletion
+49
-27
x2paddle/core/program.py
x2paddle/core/program.py
+4
-1
x2paddle/decoder/pytorch_decoder.py
x2paddle/decoder/pytorch_decoder.py
+8
-8
x2paddle/decoder/tf_decoder.py
x2paddle/decoder/tf_decoder.py
+1
-1
x2paddle/op_mapper/dygraph/pytorch2paddle/pytorch_custom_layer/gather.py
...per/dygraph/pytorch2paddle/pytorch_custom_layer/gather.py
+35
-16
x2paddle/optimizer/code_optimizer/hierachical_tree.py
x2paddle/optimizer/code_optimizer/hierachical_tree.py
+1
-1
未找到文件。
x2paddle/core/program.py
浏览文件 @
3c80edf9
...
@@ -76,6 +76,7 @@ class PaddleGraph(object):
...
@@ -76,6 +76,7 @@ class PaddleGraph(object):
self
.
source_type
=
source_type
self
.
source_type
=
source_type
self
.
custom_code
=
None
self
.
custom_code
=
None
self
.
inputs_info
=
None
self
.
inputs_info
=
None
self
.
can_dygraph2static
=
True
def
set_name
(
self
,
name
):
def
set_name
(
self
,
name
):
...
@@ -166,6 +167,8 @@ class PaddleGraph(object):
...
@@ -166,6 +167,8 @@ class PaddleGraph(object):
self
.
clear_edges
()
self
.
clear_edges
()
outputs_from_nodes
=
dict
()
outputs_from_nodes
=
dict
()
for
layer_id
,
layer
in
self
.
layers
.
items
():
for
layer_id
,
layer
in
self
.
layers
.
items
():
if
layer
.
kernel
==
"custom_layer:Gather"
:
self
.
can_dygraph2static
=
False
for
input_key
,
input_var
in
layer
.
inputs
.
items
():
for
input_key
,
input_var
in
layer
.
inputs
.
items
():
vs
=
input_var
vs
=
input_var
if
not
isinstance
(
vs
,
list
):
if
not
isinstance
(
vs
,
list
):
...
@@ -283,7 +286,7 @@ class PaddleGraph(object):
...
@@ -283,7 +286,7 @@ class PaddleGraph(object):
self
.
gen_dygraph_code
(
save_dir
)
self
.
gen_dygraph_code
(
save_dir
)
self
.
dump_dygraph_parameter
(
save_dir
)
self
.
dump_dygraph_parameter
(
save_dir
)
# 动转静
# 动转静
if
len
(
self
.
inputs_info
)
>
0
:
if
len
(
self
.
inputs_info
)
>
0
and
self
.
can_dygraph2static
:
input_shapes
=
list
()
input_shapes
=
list
()
input_types
=
list
()
input_types
=
list
()
for
input_name
in
self
.
inputs
:
for
input_name
in
self
.
inputs
:
...
...
x2paddle/decoder/pytorch_decoder.py
浏览文件 @
3c80edf9
...
@@ -21,14 +21,14 @@ import numpy as np
...
@@ -21,14 +21,14 @@ import numpy as np
class
Decoder
(
object
):
class
Decoder
(
object
):
def
_optimize_graph
(
self
,
graph
):
def
_optimize_graph
(
self
,
graph
):
torch
.
_C
.
_jit_pass_constant_propagation
(
graph
)
torch
.
_C
.
_jit_pass_constant_propagation
(
graph
)
torch
.
_C
.
_jit_pass_dce
(
graph
)
#
torch._C._jit_pass_dce(graph)
torch
.
_C
.
_jit_pass_lint
(
graph
)
#
torch._C._jit_pass_lint(graph)
torch
.
_C
.
_jit_pass_peephole
(
graph
)
#
torch._C._jit_pass_peephole(graph)
torch
.
_C
.
_jit_pass_lint
(
graph
)
#
torch._C._jit_pass_lint(graph)
torch
.
_C
.
_jit_pass_dce
(
graph
)
#
torch._C._jit_pass_dce(graph)
torch
.
_C
.
_jit_pass_lint
(
graph
)
#
torch._C._jit_pass_lint(graph)
torch
.
_C
.
_jit_pass_canonicalize
(
graph
)
#
torch._C._jit_pass_canonicalize(graph)
torch
.
_C
.
_jit_pass_lint
(
graph
)
#
torch._C._jit_pass_lint(graph)
torch
.
_C
.
_jit_pass_constant_propagation
(
graph
)
torch
.
_C
.
_jit_pass_constant_propagation
(
graph
)
return
graph
return
graph
...
...
x2paddle/decoder/tf_decoder.py
浏览文件 @
3c80edf9
...
@@ -402,7 +402,7 @@ class TFDecoder(object):
...
@@ -402,7 +402,7 @@ class TFDecoder(object):
right_shape_been_input
=
False
right_shape_been_input
=
False
while
not
right_shape_been_input
:
while
not
right_shape_been_input
:
try
:
try
:
shape
=
raw_
input
(
shape
=
input
(
"Shape of Input(e.g. None,224,224,3): "
)
"Shape of Input(e.g. None,224,224,3): "
)
except
:
except
:
shape
=
input
(
"Shape of Input(e.g. None,224,224,3): "
)
shape
=
input
(
"Shape of Input(e.g. None,224,224,3): "
)
...
...
x2paddle/op_mapper/dygraph/pytorch2paddle/pytorch_custom_layer/gather.py
浏览文件 @
3c80edf9
...
@@ -20,22 +20,41 @@ import numpy as np
...
@@ -20,22 +20,41 @@ import numpy as np
class
Gather
(
object
):
class
Gather
(
object
):
def
__init__
(
self
,
dim
):
def
__init__
(
self
,
dim
):
self
.
dim
=
dim
self
.
dim
=
dim
self
.
dtype_mapping
=
{
"VarType.INT32"
:
"int32"
,
"VarType.INT64"
:
"int64"
}
def
__call__
(
self
,
x
,
index
):
def
__call__
(
self
,
x
,
index
):
out_list
=
list
()
if
self
.
dim
<
0
:
dims
=
list
()
self
.
dim
+=
len
(
x
.
shape
)
index_shape
=
index
.
shape
x_range
=
list
(
range
(
len
(
x
.
shape
)))
x_type
=
x
.
numpy
().
dtype
x_range
[
0
]
=
self
.
dim
for
s
in
index_shape
:
x_range
[
self
.
dim
]
=
0
dims
.
append
(
list
(
range
(
s
)))
x_swaped
=
paddle
.
transpose
(
x
,
perm
=
x_range
)
for
id
in
product
(
*
dims
):
index_range
=
list
(
range
(
len
(
index
.
shape
)))
id
=
list
(
id
)
index_range
[
0
]
=
self
.
dim
id_tensor
=
paddle
.
to_tensor
(
np
.
array
(
id
).
astype
(
'int32'
))
index_range
[
self
.
dim
]
=
0
dim_id
=
paddle
.
gather_nd
(
index
,
id_tensor
).
numpy
()
index_swaped
=
paddle
.
transpose
(
index
,
perm
=
index_range
)
id
[
self
.
dim
]
=
dim_id
dtype
=
self
.
dtype_mapping
[
str
(
index
.
dtype
)]
id_tensor
=
paddle
.
to_tensor
(
np
.
array
(
id
).
astype
(
'int32'
))
data
=
paddle
.
gather_nd
(
x
,
id_tensor
).
numpy
()
x_shape
=
paddle
.
shape
(
x_swaped
)
out_list
.
append
(
data
)
index_shape
=
paddle
.
shape
(
index_swaped
)
out
=
paddle
.
to_tensor
(
np
.
array
(
out_list
).
astype
(
x_type
))
out
=
paddle
.
reshape
(
out
,
index_shape
)
prod
=
paddle
.
prod
(
x_shape
,
dtype
=
dtype
)
/
x_shape
[
0
]
x_swaped_flattend
=
paddle
.
flatten
(
x_swaped
)
index_swaped_flattend
=
paddle
.
flatten
(
index_swaped
)
index_swaped_flattend
*=
prod
bias
=
paddle
.
arange
(
start
=
0
,
end
=
prod
,
dtype
=
dtype
)
bias
=
paddle
.
reshape
(
bias
,
x_shape
[
1
:])
bias
=
paddle
.
crop
(
bias
,
index_shape
[
1
:])
bias
=
paddle
.
flatten
(
bias
)
bias
=
paddle
.
tile
(
bias
,
[
index_shape
[
0
]])
index_swaped_flattend
+=
bias
gathered
=
paddle
.
index_select
(
x_swaped_flattend
,
index_swaped_flattend
)
gathered
=
paddle
.
reshape
(
gathered
,
index_swaped
.
shape
)
out
=
paddle
.
transpose
(
gathered
,
perm
=
x_range
)
return
out
return
out
x2paddle/optimizer/code_optimizer/hierachical_tree.py
浏览文件 @
3c80edf9
...
@@ -372,7 +372,7 @@ class HierarchicalTree(Tree):
...
@@ -372,7 +372,7 @@ class HierarchicalTree(Tree):
"import paddle.fluid as fluid"
,
"import paddle.fluid as fluid"
,
"from paddle.fluid.initializer import Constant"
,
"from paddle.fluid.initializer import Constant"
,
"from paddle.fluid.param_attr import ParamAttr"
,
"from paddle.fluid.param_attr import ParamAttr"
,
"imort math"
,
"im
p
ort math"
,
"from x2paddle.op_mapper.dygraph.pytorch2paddle "
+
\
"from x2paddle.op_mapper.dygraph.pytorch2paddle "
+
\
"import pytorch_custom_layer as x2paddle_nn"
"import pytorch_custom_layer as x2paddle_nn"
"
\n
"
,]
"
\n
"
,]
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录