Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
636fefd9
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
636fefd9
编写于
1月 29, 2021
作者:
G
gongweibao
提交者:
GitHub
1月 29, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
code style (#30781)
code style
上级
88dfd067
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
104 addition
and
54 deletion
+104
-54
python/paddle/distributed/fleet/meta_optimizers/ascend/ascend_optimizer.py
...tributed/fleet/meta_optimizers/ascend/ascend_optimizer.py
+49
-21
python/paddle/distributed/fleet/meta_optimizers/ascend/ascend_parser.py
...distributed/fleet/meta_optimizers/ascend/ascend_parser.py
+55
-33
未找到文件。
python/paddle/distributed/fleet/meta_optimizers/ascend/ascend_optimizer.py
浏览文件 @
636fefd9
...
@@ -24,6 +24,7 @@ from collections import namedtuple
...
@@ -24,6 +24,7 @@ from collections import namedtuple
HcomGroupConfig
=
namedtuple
(
'HcomGroupConfig'
,
[
'name'
,
'nranks'
,
'rank_ids'
])
HcomGroupConfig
=
namedtuple
(
'HcomGroupConfig'
,
[
'name'
,
'nranks'
,
'rank_ids'
])
class
AscendIRParser
(
object
):
class
AscendIRParser
(
object
):
def
__init__
(
self
):
def
__init__
(
self
):
self
.
graph_idx
=
0
self
.
graph_idx
=
0
...
@@ -34,19 +35,26 @@ class AscendIRParser(object):
...
@@ -34,19 +35,26 @@ class AscendIRParser(object):
ret_map
=
{}
ret_map
=
{}
ge_in_operator
=
[]
ge_in_operator
=
[]
for
id
,
var
in
enumerate
(
input_varlist
):
for
id
,
var
in
enumerate
(
input_varlist
):
if
var
.
is_data
:
# input data
if
var
.
is_data
:
# input data
ge_input
=
core
.
GEOperatorFactory
.
create_operator
(
var
.
name
,
"Data"
).
set_attr_int32
(
"index"
,
id
)
ge_input
=
core
.
GEOperatorFactory
.
create_operator
(
var
.
name
,
"Data"
).
set_attr_int32
(
"index"
,
id
)
ret_map
[
var
.
name
]
=
ge_input
ret_map
[
var
.
name
]
=
ge_input
ge_in_operator
.
append
(
ge_input
)
ge_in_operator
.
append
(
ge_input
)
else
:
# param, learning ...
else
:
# param, learning ...
ge_input
=
core
.
GEOperatorFactory
.
create_operator
(
var
.
name
,
"Variable"
)
ge_input
=
core
.
GEOperatorFactory
.
create_operator
(
var
.
name
,
ge_input
.
update_output_desc
(
"y"
,
core
.
GETensorDesc
(
core
.
GEShape
(
var
.
shape
),
core
.
GEFormat
.
FORMAT_ND
,
core
.
GEDataType
.
DT_FLOAT
))
"Variable"
)
ge_input
.
update_output_desc
(
"y"
,
core
.
GETensorDesc
(
core
.
GEShape
(
var
.
shape
),
core
.
GEFormat
.
FORMAT_ND
,
core
.
GEDataType
.
DT_FLOAT
))
ret_map
[
var
.
name
]
=
ge_input
ret_map
[
var
.
name
]
=
ge_input
return
ge_in_operator
,
ret_map
return
ge_in_operator
,
ret_map
def
_endpoint_to_world_rank_id
(
self
,
endpoint
):
def
_endpoint_to_world_rank_id
(
self
,
endpoint
):
world_endpoints
=
fleet
.
worker_endpoints
()
world_endpoints
=
fleet
.
worker_endpoints
()
assert
endpoint
in
world_endpoints
,
"endpoint (%s) not in worker_endpoints (%s) "
%
(
endpoint
,
fleet
.
world_device_ids
())
assert
endpoint
in
world_endpoints
,
"endpoint (%s) not in worker_endpoints (%s) "
%
(
endpoint
,
fleet
.
world_device_ids
())
return
world_endpoints
.
index
(
endpoint
)
return
world_endpoints
.
index
(
endpoint
)
def
parse_op
(
self
,
op
):
def
parse_op
(
self
,
op
):
...
@@ -62,26 +70,40 @@ class AscendIRParser(object):
...
@@ -62,26 +70,40 @@ class AscendIRParser(object):
self
.
hcom_endpoints
[
nccl_id
]
=
other_endpoints
[:]
self
.
hcom_endpoints
[
nccl_id
]
=
other_endpoints
[:]
self
.
hcom_endpoints
[
nccl_id
].
insert
(
rank
,
endpoint
)
self
.
hcom_endpoints
[
nccl_id
].
insert
(
rank
,
endpoint
)
print
(
"nccl_id (%s) registered endpoints %s"
%
(
nccl_id
,
self
.
hcom_endpoints
[
nccl_id
]))
print
(
"nccl_id (%s) registered endpoints %s"
%
(
nccl_id
,
self
.
hcom_endpoints
[
nccl_id
]))
elif
op
.
type
==
'c_comm_init'
:
elif
op
.
type
==
'c_comm_init'
:
nccl_id
=
op
.
input_arg_names
[
0
]
nccl_id
=
op
.
input_arg_names
[
0
]
nranks
=
op
.
attr
(
"nranks"
)
nranks
=
op
.
attr
(
"nranks"
)
assert
nranks
==
len
(
self
.
hcom_endpoints
[
nccl_id
]),
"nranks doesn't match endpoint count"
assert
nranks
==
len
(
self
.
hcom_endpoints
[
nccl_id
]),
"nranks doesn't match endpoint count"
rank
=
op
.
attr
(
"rank"
)
rank
=
op
.
attr
(
"rank"
)
ring_id
=
op
.
attr
(
"ring_id"
)
ring_id
=
op
.
attr
(
"ring_id"
)
group_name
=
"hcom_group_"
+
str
(
ring_id
)
group_name
=
"hcom_group_"
+
str
(
ring_id
)
global_rank_ids
=
[
self
.
_endpoint_to_world_rank_id
(
endpoint
)
for
endpoint
in
self
.
hcom_endpoints
[
nccl_id
]]
global_rank_ids
=
[
self
.
groups_to_create
.
append
(
HcomGroupConfig
(
name
=
group_name
,
nranks
=
nranks
,
rank_ids
=
global_rank_ids
))
self
.
_endpoint_to_world_rank_id
(
endpoint
)
print
(
"append to create group: %s, with rank_ids: %s"
%
(
group_name
,
global_rank_ids
))
for
endpoint
in
self
.
hcom_endpoints
[
nccl_id
]
]
self
.
groups_to_create
.
append
(
HcomGroupConfig
(
name
=
group_name
,
nranks
=
nranks
,
rank_ids
=
global_rank_ids
))
print
(
"append to create group: %s, with rank_ids: %s"
%
(
group_name
,
global_rank_ids
))
elif
op
.
type
in
ascend_parser
.
registerd_op
:
elif
op
.
type
in
ascend_parser
.
registerd_op
:
print
(
"Op[%s] has been registered, begin to parse it"
%
(
op
.
type
))
print
(
"Op[%s] has been registered, begin to parse it"
%
(
op
.
type
))
op_parser
=
self
.
parser_factory
.
create_parse
(
ascend_parser
.
registerd_op
[
op
.
type
])
op_parser
=
self
.
parser_factory
.
create_parse
(
ascend_parser
.
registerd_op
[
op
.
type
])
op_parser
.
apply
(
op
)
op_parser
.
apply
(
op
)
else
:
else
:
print
(
"Op[%s] has not been registered, so we have to skip it"
%
(
op
.
type
))
print
(
"Op[%s] has not been registered, so we have to skip it"
%
(
op
.
type
))
def
_parse_program
(
self
,
graph_name
,
program
,
input_varlist
=
[],
fetch_list
=
[]):
def
_parse_program
(
self
,
graph_name
,
program
,
input_varlist
=
[],
fetch_list
=
[]):
begin_graph_idx
=
self
.
graph_idx
begin_graph_idx
=
self
.
graph_idx
ge_in_operator
=
[]
ge_in_operator
=
[]
ge_out_operator
=
[]
ge_out_operator
=
[]
...
@@ -96,7 +118,8 @@ class AscendIRParser(object):
...
@@ -96,7 +118,8 @@ class AscendIRParser(object):
ge_in_operator
,
self
.
var2geop
=
self
.
_construct_input_map
(
input_varlist
)
ge_in_operator
,
self
.
var2geop
=
self
.
_construct_input_map
(
input_varlist
)
self
.
parser_factory
=
ascend_parser
.
AscendParserFactory
(
graph
,
self
.
var2geop
)
self
.
parser_factory
=
ascend_parser
.
AscendParserFactory
(
graph
,
self
.
var2geop
)
for
i
,
curop
in
list
(
enumerate
(
block
.
ops
)):
for
i
,
curop
in
list
(
enumerate
(
block
.
ops
)):
self
.
parse_op
(
curop
)
self
.
parse_op
(
curop
)
...
@@ -133,9 +156,11 @@ class AscendIRParser(object):
...
@@ -133,9 +156,11 @@ class AscendIRParser(object):
self
.
graph_idx
+=
1
self
.
graph_idx
+=
1
return
graph
return
graph
def
parse_program
(
self
,
startup_program
,
main_program
,
input_varlist
,
fetch_list
):
def
parse_program
(
self
,
startup_program
,
main_program
,
input_varlist
,
fetch_list
):
startup_graph
=
self
.
_parse_program
(
"startup"
,
startup_program
)
startup_graph
=
self
.
_parse_program
(
"startup"
,
startup_program
)
main_graph
=
self
.
_parse_program
(
"main"
,
main_program
,
input_varlist
,
fetch_list
)
main_graph
=
self
.
_parse_program
(
"main"
,
main_program
,
input_varlist
,
fetch_list
)
return
startup_graph
,
main_graph
return
startup_graph
,
main_graph
...
@@ -174,14 +199,16 @@ class AscendOptimizer(Optimizer):
...
@@ -174,14 +199,16 @@ class AscendOptimizer(Optimizer):
auto_dp
=
False
):
auto_dp
=
False
):
minimized
=
None
minimized
=
None
if
self
.
inner_opt
:
if
self
.
inner_opt
:
minimized
=
self
.
inner_opt
.
minimize
(
loss
,
startup_program
=
startup_program
)
minimized
=
self
.
inner_opt
.
minimize
(
loss
,
startup_program
=
startup_program
)
self
.
ascend_instance
=
core
.
AscendInstance
()
self
.
ascend_instance
=
core
.
AscendInstance
()
from
paddle.distributed
import
fleet
from
paddle.distributed
import
fleet
if
auto_dp
and
fleet
.
worker_num
()
>
1
:
if
auto_dp
and
fleet
.
worker_num
()
>
1
:
from
paddle.fluid.transpiler
import
ascend_transpiler
from
paddle.fluid.transpiler
import
ascend_transpiler
t
=
ascend_transpiler
.
AscendTranspiler
(
startup_program
,
loss
.
block
.
program
)
t
=
ascend_transpiler
.
AscendTranspiler
(
startup_program
,
loss
.
block
.
program
)
t
.
transpile
()
t
.
transpile
()
print
(
loss
.
block
.
program
)
print
(
loss
.
block
.
program
)
...
@@ -211,7 +238,8 @@ class AscendOptimizer(Optimizer):
...
@@ -211,7 +238,8 @@ class AscendOptimizer(Optimizer):
for
cfg
in
self
.
parser
.
groups_to_create
:
for
cfg
in
self
.
parser
.
groups_to_create
:
hccl
.
create_group
(
cfg
.
name
,
cfg
.
nranks
,
cfg
.
rank_ids
)
hccl
.
create_group
(
cfg
.
name
,
cfg
.
nranks
,
cfg
.
rank_ids
)
print
(
"create group (%s), nranks: %d, rank_ids: %s"
%
(
cfg
.
name
,
cfg
.
nranks
,
cfg
.
rank_ids
))
print
(
"create group (%s), nranks: %d, rank_ids: %s"
%
(
cfg
.
name
,
cfg
.
nranks
,
cfg
.
rank_ids
))
self
.
ascend_instance
.
add_ascend_subgraph
(
0
,
startup_graph
)
self
.
ascend_instance
.
add_ascend_subgraph
(
0
,
startup_graph
)
self
.
ascend_instance
.
add_ascend_subgraph
(
1
,
main_graph
)
self
.
ascend_instance
.
add_ascend_subgraph
(
1
,
main_graph
)
...
...
python/paddle/distributed/fleet/meta_optimizers/ascend/ascend_parser.py
浏览文件 @
636fefd9
...
@@ -69,11 +69,13 @@ class AscendHelper(object):
...
@@ -69,11 +69,13 @@ class AscendHelper(object):
}
}
def
dtype2ge
(
self
,
dtype
):
def
dtype2ge
(
self
,
dtype
):
assert
dtype
in
self
.
dtype2ge_map
,
"dtype[%d] is not supported %d"
%
(
dtype
)
assert
dtype
in
self
.
dtype2ge_map
,
"dtype[%d] is not supported %d"
%
(
dtype
)
return
self
.
dtype2ge_map
[
dtype
]
return
self
.
dtype2ge_map
[
dtype
]
def
dtype2np
(
self
,
index
):
def
dtype2np
(
self
,
index
):
assert
index
in
self
.
dtype2np_map
,
"index[%d] is not supported %d"
%
(
dtype
)
assert
index
in
self
.
dtype2np_map
,
"index[%d] is not supported %d"
%
(
dtype
)
return
self
.
dtype2np_map
[
index
]
return
self
.
dtype2np_map
[
index
]
...
@@ -98,7 +100,8 @@ class AscendParserBase(object):
...
@@ -98,7 +100,8 @@ class AscendParserBase(object):
self
.
ascend_helper
=
AscendHelper
()
self
.
ascend_helper
=
AscendHelper
()
def
_get_ge_input
(
self
,
input_var_name
):
def
_get_ge_input
(
self
,
input_var_name
):
assert
input_var_name
in
self
.
var2geop
,
"var %s not created before"
%
(
input_var_name
)
assert
input_var_name
in
self
.
var2geop
,
"var %s not created before"
%
(
input_var_name
)
return
self
.
var2geop
[
input_var_name
]
return
self
.
var2geop
[
input_var_name
]
def
update_output
(
self
,
geop_list
,
index_list
):
def
update_output
(
self
,
geop_list
,
index_list
):
...
@@ -119,7 +122,8 @@ class AscendParserBase(object):
...
@@ -119,7 +122,8 @@ class AscendParserBase(object):
for
i
in
range
(
len
(
arguments
)):
for
i
in
range
(
len
(
arguments
)):
print
(
"assgin index_list[%d][%d] to %s"
%
print
(
"assgin index_list[%d][%d] to %s"
%
(
output_id
,
i
,
arguments
[
i
]))
(
output_id
,
i
,
arguments
[
i
]))
self
.
var2geop
[
arguments
[
i
]]
=
geop_list
[
index_list
[
output_id
][
i
]]
self
.
var2geop
[
arguments
[
i
]]
=
geop_list
[
index_list
[
output_id
][
i
]]
for
geop
in
geop_list
:
for
geop
in
geop_list
:
self
.
graph
.
add_op
(
geop
)
self
.
graph
.
add_op
(
geop
)
...
@@ -483,11 +487,15 @@ class TruncatedNormalParser(AscendParserBase):
...
@@ -483,11 +487,15 @@ class TruncatedNormalParser(AscendParserBase):
"const"
+
self
.
_accumulated_op_id
(),
"Const"
).
set_attr_tensor
(
"const"
+
self
.
_accumulated_op_id
(),
"Const"
).
set_attr_tensor
(
"value"
,
tensor3
)
"value"
,
tensor3
)
tensor4
=
self
.
_create_ge_tensor
([
1
],
dtype
,
mean
-
2
*
std
)
tensor4
=
self
.
_create_ge_tensor
([
1
],
dtype
,
mean
-
2
*
std
)
min_tensor
=
core
.
GEOperatorFactory
.
create_operator
(
"const"
+
self
.
_accumulated_op_id
(),
"Const"
).
set_attr_tensor
(
"value"
,
tensor4
)
min_tensor
=
core
.
GEOperatorFactory
.
create_operator
(
"const"
+
self
.
_accumulated_op_id
(),
"Const"
).
set_attr_tensor
(
"value"
,
tensor4
)
tensor5
=
self
.
_create_ge_tensor
([
1
],
dtype
,
mean
+
2
*
std
)
tensor5
=
self
.
_create_ge_tensor
([
1
],
dtype
,
mean
+
2
*
std
)
max_tensor
=
core
.
GEOperatorFactory
.
create_operator
(
"const"
+
self
.
_accumulated_op_id
(),
"Const"
).
set_attr_tensor
(
"value"
,
tensor5
)
max_tensor
=
core
.
GEOperatorFactory
.
create_operator
(
"const"
+
self
.
_accumulated_op_id
(),
"Const"
).
set_attr_tensor
(
"value"
,
tensor5
)
self
.
_mark_as_input
(
shape_tensor
)
self
.
_mark_as_input
(
shape_tensor
)
self
.
_mark_as_input
(
mean_tensor
)
self
.
_mark_as_input
(
mean_tensor
)
...
@@ -546,10 +554,11 @@ class AllGatherParser(AscendParserBase):
...
@@ -546,10 +554,11 @@ class AllGatherParser(AscendParserBase):
"rank_size"
,
rank_size
).
set_attr_string
(
"group"
,
group
)
"rank_size"
,
rank_size
).
set_attr_string
(
"group"
,
group
)
return
[
allgather
],
[[
0
]]
return
[
allgather
],
[[
0
]]
class
AllReduceParser
(
AscendParserBase
):
class
AllReduceParser
(
AscendParserBase
):
def
__init__
(
self
,
graph
,
var2geop
,
reduction
):
def
__init__
(
self
,
graph
,
var2geop
,
reduction
):
super
(
AllReduceParser
,
self
).
__init__
(
graph
,
var2geop
)
super
(
AllReduceParser
,
self
).
__init__
(
graph
,
var2geop
)
self
.
parser_name
=
"c_allreduce_"
+
reduction
self
.
parser_name
=
"c_allreduce_"
+
reduction
self
.
reduction
=
reduction
self
.
reduction
=
reduction
def
_apply
(
self
):
def
_apply
(
self
):
...
@@ -557,8 +566,8 @@ class AllReduceParser(AscendParserBase):
...
@@ -557,8 +566,8 @@ class AllReduceParser(AscendParserBase):
reduction
=
self
.
reduction
reduction
=
self
.
reduction
ring_id
=
self
.
op
.
attr
(
"ring_id"
)
ring_id
=
self
.
op
.
attr
(
"ring_id"
)
group
=
"hcom_group_"
+
str
(
ring_id
)
group
=
"hcom_group_"
+
str
(
ring_id
)
fusion
=
None
#self.op.attr("fusion")
fusion
=
None
#self.op.attr("fusion")
fusion_id
=
None
#self.op.attr("fusion_id")
fusion_id
=
None
#self.op.attr("fusion_id")
allreduce
=
core
.
GEOperatorFactory
.
create_operator
(
allreduce
=
core
.
GEOperatorFactory
.
create_operator
(
"allreduce"
+
self
.
_accumulated_op_id
(),
"HcomAllReduce"
).
set_input
(
"allreduce"
+
self
.
_accumulated_op_id
(),
"HcomAllReduce"
).
set_input
(
...
@@ -611,10 +620,10 @@ class ReduceScatterParser(AscendParserBase):
...
@@ -611,10 +620,10 @@ class ReduceScatterParser(AscendParserBase):
rank_size
=
self
.
op
.
attr
(
"rank_size"
)
rank_size
=
self
.
op
.
attr
(
"rank_size"
)
reduce_scatter
=
core
.
GEOperatorFactory
.
create_operator
(
reduce_scatter
=
core
.
GEOperatorFactory
.
create_operator
(
"reducescatter"
+
self
.
_accumulated_op_id
(),
"HcomReduceScatter"
).
set_input
(
"reducescatter"
+
self
.
_accumulated_op_id
(),
"x"
,
x
).
set_attr_string
(
"HcomReduceScatter"
).
set_input
(
"x"
,
x
).
set_attr_string
(
"reduction"
,
reduction
).
set_attr_string
(
"reduction"
,
reduction
).
set_attr_string
(
"group"
,
group
).
set_attr_int32
(
"rank_size"
,
rank_size
)
"group"
,
group
).
set_attr_int32
(
"rank_size"
,
rank_size
)
return
[
reduce_scatter
],
[[
0
]]
return
[
reduce_scatter
],
[[
0
]]
...
@@ -631,9 +640,8 @@ class SendParser(AscendParserBase):
...
@@ -631,9 +640,8 @@ class SendParser(AscendParserBase):
send
=
core
.
GEOperatorFactory
.
create_operator
(
send
=
core
.
GEOperatorFactory
.
create_operator
(
"send"
+
self
.
_accumulated_op_id
(),
"HcomSend"
).
set_input
(
"send"
+
self
.
_accumulated_op_id
(),
"HcomSend"
).
set_input
(
"x"
,
x
).
set_attr_int32
(
"x"
,
x
).
set_attr_int32
(
"sr_tag"
,
sr_tag
).
set_attr_int32
(
"sr_tag"
,
sr_tag
).
set_attr_int32
(
"dest_rank"
,
dest_rank
).
set_attr_string
(
"group"
,
group
)
"dest_rank"
,
dest_rank
).
set_attr_string
(
"group"
,
group
)
return
[
send
],
[[
0
]]
return
[
send
],
[[
0
]]
...
@@ -652,11 +660,10 @@ class ReceiveParser(AscendParserBase):
...
@@ -652,11 +660,10 @@ class ReceiveParser(AscendParserBase):
receive
=
core
.
GEOperatorFactory
.
create_operator
(
receive
=
core
.
GEOperatorFactory
.
create_operator
(
"receive"
+
self
.
_accumulated_op_id
(),
"HcomReceive"
).
set_input
(
"receive"
+
self
.
_accumulated_op_id
(),
"HcomReceive"
).
set_input
(
"x"
,
x
).
set_attr_int32
(
"x"
,
x
).
set_attr_int32
(
"sr_tag"
,
sr_tag
).
set_attr_int32
(
"sr_tag"
,
sr_tag
).
set_attr_int32
(
"src_rank"
,
src_rank
).
set_attr_string
(
"src_rank"
,
src_rank
).
set_attr_string
(
"group"
,
group
).
set_attr_vec_int32
(
"group"
,
group
).
set_attr_vec_int32
(
"shape"
,
shape
).
set_attr_int32
(
"dtype"
,
dtype
)
"shape"
,
shape
).
set_attr_int32
(
"dtype"
,
dtype
)
return
[
receive
],
[[
0
]]
return
[
receive
],
[[
0
]]
...
@@ -667,18 +674,30 @@ class ScaleParser(AscendParserBase):
...
@@ -667,18 +674,30 @@ class ScaleParser(AscendParserBase):
def
_apply
(
self
):
def
_apply
(
self
):
x
=
self
.
_get_ge_input
(
self
.
op
.
input_arg_names
[
0
])
x
=
self
.
_get_ge_input
(
self
.
op
.
input_arg_names
[
0
])
scale
=
self
.
op
.
attr
(
"scale"
)
#self.get_ge_input(self.op.input_arg_names[1])
scale
=
self
.
op
.
attr
(
"scale"
)
#self.get_ge_input(self.op.input_arg_names[1])
bias
=
self
.
op
.
attr
(
"bias"
)
bias
=
self
.
op
.
attr
(
"bias"
)
bias_after_scale
=
self
.
op
.
attr
(
"bias_after_scale"
)
bias_after_scale
=
self
.
op
.
attr
(
"bias_after_scale"
)
if
bias_after_scale
:
if
bias_after_scale
:
scale_value
=
core
.
GEOperatorFactory
.
create_operator
(
"scale"
+
self
.
_accumulated_op_id
(),
"Power"
).
set_input
(
"x"
,
x
).
set_attr_float
(
"power"
,
1.0
).
set_attr_float
(
"scale"
,
scale
).
set_attr_float
(
"shift"
,
bias
)
scale_value
=
core
.
GEOperatorFactory
.
create_operator
(
"scale"
+
self
.
_accumulated_op_id
(),
"Power"
).
set_input
(
"x"
,
x
).
set_attr_float
(
"power"
,
1.0
).
set_attr_float
(
"scale"
,
scale
).
set_attr_float
(
"shift"
,
bias
)
else
:
else
:
x_add_bias
=
core
.
GEOperatorFactory
.
create_operator
(
"adds"
+
self
.
_accumulated_op_id
(),
"Adds"
).
set_input
(
"x"
,
x
).
set_attr_float
(
"value"
,
bias
)
#set_input("x2", bias)
x_add_bias
=
core
.
GEOperatorFactory
.
create_operator
(
scale_value
=
core
.
GEOperatorFactory
.
create_operator
(
"scale"
+
self
.
_accumulated_op_id
(),
"Power"
).
set_input
(
"x"
,
x_add_bias
).
set_attr_float
(
"power"
,
1.0
).
set_attr_float
(
"scale"
,
scale
).
set_attr_float
(
"shift"
,
0.0
)
"adds"
+
self
.
_accumulated_op_id
(),
"Adds"
).
set_input
(
"x"
,
x
).
set_attr_float
(
"value"
,
bias
)
#set_input("x2", bias)
scale_value
=
core
.
GEOperatorFactory
.
create_operator
(
"scale"
+
self
.
_accumulated_op_id
(),
"Power"
).
set_input
(
"x"
,
x_add_bias
).
set_attr_float
(
"power"
,
1.0
).
set_attr_float
(
"scale"
,
scale
).
set_attr_float
(
"shift"
,
0.0
)
#tensor_zeros = core.GEOperatorFactory.create_operator("zeroslike" + self.getid(), "ZerosLike").set_input("x", x)
#tensor_zeros = core.GEOperatorFactory.create_operator("zeroslike" + self.getid(), "ZerosLike").set_input("x", x)
#bias_ = self.create_ge_tensor([1], 5, bias)
#bias_ = self.create_ge_tensor([1], 5, bias)
#const_bias = core.GEOperatorFactory.create_operator("const" + self.getid(), "Const").set_attr_tensor("value", tensor_bias)
#const_bias = core.GEOperatorFactory.create_operator("const" + self.getid(), "Const").set_attr_tensor("value", tensor_bias)
return
[
scale_value
],[[
0
]]
return
[
scale_value
],
[[
0
]]
class
ReshapeParser
(
AscendParserBase
):
class
ReshapeParser
(
AscendParserBase
):
def
__init__
(
self
,
graph
,
var2geop
):
def
__init__
(
self
,
graph
,
var2geop
):
...
@@ -695,9 +714,12 @@ class ReshapeParser(AscendParserBase):
...
@@ -695,9 +714,12 @@ class ReshapeParser(AscendParserBase):
print
(
"shape: "
,
shape
)
print
(
"shape: "
,
shape
)
data_x1_shape
=
self
.
_get_ge_input
(
self
.
op
.
input_arg_names
[
0
])
data_x1_shape
=
self
.
_get_ge_input
(
self
.
op
.
input_arg_names
[
0
])
tensor
=
self
.
_create_ge_tensor
([
len
(
shape
)],
2
,
shape
)
tensor
=
self
.
_create_ge_tensor
([
len
(
shape
)],
2
,
shape
)
const_shape
=
core
.
GEOperatorFactory
.
create_operator
(
"shape"
+
self
.
_accumulated_op_id
(),
"Const"
).
set_attr_tensor
(
"value"
,
tensor
)
const_shape
=
core
.
GEOperatorFactory
.
create_operator
(
reshape
=
core
.
GEOperatorFactory
.
create_operator
(
"reshape"
+
self
.
_accumulated_op_id
(),
"Reshape"
).
set_input
(
"x"
,
data_x1_shape
).
set_input
(
"shape"
,
const_shape
).
set_attr_int32
(
"axis"
,
axis
)
"shape"
+
self
.
_accumulated_op_id
(),
"Const"
).
set_attr_tensor
(
"value"
,
tensor
)
return
[
reshape
,
reshape
],
[[
0
],[
1
]]
reshape
=
core
.
GEOperatorFactory
.
create_operator
(
"reshape"
+
self
.
_accumulated_op_id
(),
"Reshape"
).
set_input
(
"x"
,
data_x1_shape
).
set_input
(
"shape"
,
const_shape
).
set_attr_int32
(
"axis"
,
axis
)
return
[
reshape
,
reshape
],
[[
0
],
[
1
]]
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录