magicwindyyd / mindspore (forked from MindSpore / mindspore; in sync with the fork source project)
Commit c2fddb56
Authored on Sep 01, 2020 by mindspore-ci-bot; committed by Gitee on Sep 01, 2020
!4922 Transform tuple parameter to multiple parameters
Merge pull request !4922 from amongo/TupleTransform
Parents: 24f00cc6, 0099da2c
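In short: before this change, a function graph called with a tuple argument kept a single tuple-typed parameter; the new pass clones such callees with one parameter per (possibly nested) tuple element and rewrites the call sites to pass the flattened elements. A minimal sketch of the kind of network this concerns, modeled on the tests added below (the cell names TupleArgNet and AddFromTuple are illustrative, not part of the change):

import numpy as np
from mindspore import context, nn, Tensor
from mindspore.common import ms_function
from mindspore.ops import operations as P

context.set_context(mode=context.PYNATIVE_MODE)

class AddFromTuple(nn.Cell):
    """A cell whose construct receives a single tuple argument."""
    def __init__(self):
        super().__init__()
        self.add = P.TensorAdd()

    def construct(self, inputs):
        return self.add(inputs[0], inputs[1])

class TupleArgNet(nn.Cell):
    """Packs two tensors into a tuple and calls a sub-cell with it."""
    def __init__(self, func):
        super().__init__()
        self.func = func

    @ms_function
    def construct(self, x, y):
        inputs = (x, y)           # the called graph would keep a tuple parameter ...
        return self.func(inputs)  # ... which this change flattens into separate parameters

net = TupleArgNet(AddFromTuple())
out = net(Tensor(np.ones([2, 2]).astype(np.float32)),
          Tensor(np.ones([2, 2]).astype(np.float32)))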
Showing 10 changed files with 923 additions and 4 deletions (+923, -4).
mindspore/ccsrc/frontend/optimizer/graph_transform.cc                     +144  -0
mindspore/ccsrc/frontend/optimizer/graph_transform.h                      +108  -0
mindspore/ccsrc/frontend/optimizer/irpass.cc                                +5  -0
mindspore/ccsrc/frontend/optimizer/irpass.h                                 +3  -0
mindspore/ccsrc/frontend/optimizer/irpass/call_graph_tuple_transform.h    +246  -0
mindspore/ccsrc/pipeline/jit/action.cc                                      +1  -0
mindspore/ccsrc/pipeline/jit/pass.cc                                       +41  -3
mindspore/ccsrc/pipeline/pynative/pynative_execute.cc                      +38  -1
tests/st/pynative/test_graph_param_transform.py                           +201  -0
tests/ut/python/pynative_mode/test_graph_param_cases.py                   +136  -0
mindspore/ccsrc/frontend/optimizer/graph_transform.cc
0 → 100644
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "frontend/optimizer/graph_transform.h"
#include <vector>
#include <algorithm>
#include <string>
#include "ir/graph_utils.h"

namespace mindspore {
/* namespace to support opt */
namespace opt {
// Check cnode input values, whether it is tuple input
bool CNodeHasTupleInput(const CNodePtr &cnode) {
  auto &inputs = cnode->inputs();
  for (size_t i = 1; i < inputs.size(); i++) {
    if (IsValueNode<FuncGraph>(inputs[i])) {
      continue;
    }
    if (IsValueNode<Primitive>(inputs[i])) {
      // unexpected high order primitive as cnode input when transform graph
      MS_LOG(WARNING) << "CheckTupleInput, got unexpected primitive as input" << cnode->DebugString();
      return false;
    }
    auto abs = inputs[i]->abstract();
    if (abs == nullptr) {
      MS_LOG(WARNING) << "CheckTupleInput, got abstract nullptr for node:" << cnode->DebugString();
      return false;
    }
    if (abs->isa<abstract::AbstractTuple>()) {
      return true;
    }
  }
  return false;
}

bool FuncGraphHasTupleInput(const FuncGraphPtr &fg) {
  auto &params = fg->parameters();
  for (auto &param : params) {
    if (param->abstract()->isa<abstract::AbstractTuple>()) {
      return true;
    }
  }
  return false;
}

std::vector<AnfNodePtr> TransformTupleArgument(const FuncGraphPtr &fg, const AnfNodePtr &node,
                                               const abstract::AbstractTuplePtr &abs) {
  auto &elements = abs->elements();
  std::vector<AnfNodePtr> tuple_node_expanded;
  for (size_t i = 0; i < elements.size(); i++) {
    auto elem_node = fg->NewCNode({NewValueNode(prim::kPrimTupleGetItem), node, NewValueNode(SizeToInt(i))});
    elem_node->set_abstract(elements[i]);
    if (elements[i]->isa<abstract::AbstractTuple>()) {
      auto nodes = TransformTupleArgument(fg, elem_node, elements[i]->cast<abstract::AbstractTuplePtr>());
      tuple_node_expanded.insert(tuple_node_expanded.end(), nodes.begin(), nodes.end());
    } else {
      tuple_node_expanded.push_back(elem_node);
    }
  }
  return tuple_node_expanded;
}

AnfNodePtr TransformCallGraph(const FuncGraphPtr &trans_fg, const CNodePtr &cnode) {
  auto &cinputs = cnode->inputs();
  auto fg = cnode->func_graph();
  std::vector<AnfNodePtr> inputs;
  inputs.push_back(NewValueNode(trans_fg));
  for (size_t i = 1; i < cinputs.size(); i++) {
    auto abs = cinputs[i]->abstract();
    if (abs == nullptr) {
      MS_LOG(EXCEPTION) << "TransformCallGraph:Node abstract should not be nullptr" << cinputs[i]->DebugString();
    }
    if (abs->isa<abstract::AbstractTuple>()) {
      auto nodes = TransformTupleArgument(fg, cinputs[i], abs->cast<abstract::AbstractTuplePtr>());
      inputs.insert(inputs.end(), nodes.begin(), nodes.end());
    } else {
      inputs.push_back(cinputs[i]);
    }
  }
  auto new_node = fg->NewCNode(inputs);
  new_node->set_abstract(cnode->abstract());
  return new_node;
}

AnfNodePtr TransformPartial(const FuncGraphPtr &trans_fg, const CNodePtr &cnode) {
  auto &cinputs = cnode->inputs();
  auto fg = cnode->func_graph();
  std::vector<AnfNodePtr> inputs;
  inputs.push_back(NewValueNode(prim::kPrimPartial));
  inputs.push_back(NewValueNode(trans_fg));
  for (size_t i = 2; i < cinputs.size(); i++) {
    auto abs = cinputs[i]->abstract();
    if (abs == nullptr) {
      MS_LOG(EXCEPTION) << "TransformPartial:Node abstract should not be nullptr" << cinputs[i]->DebugString();
    }
    if (abs->isa<abstract::AbstractTuple>()) {
      auto nodes = TransformTupleArgument(fg, cinputs[i], abs->cast<abstract::AbstractTuplePtr>());
      inputs.insert(inputs.end(), nodes.begin(), nodes.end());
    } else {
      inputs.push_back(cinputs[i]);
    }
  }
  auto new_node = fg->NewCNode(inputs);
  new_node->set_abstract(cnode->abstract());
  return new_node;
}

AnfNodePtr TransformSwitchCall(const AnfNodePtr &swtich_node, const CNodePtr &cnode) {
  auto &cinputs = cnode->inputs();
  auto fg = cnode->func_graph();
  std::vector<AnfNodePtr> inputs;
  inputs.push_back(swtich_node);
  for (size_t i = 1; i < cinputs.size(); i++) {
    auto abs = cinputs[i]->abstract();
    if (abs == nullptr) {
      MS_LOG(EXCEPTION) << "TransformSwitchCall:Node abstract should not be nullptr" << cinputs[i]->DebugString();
    }
    if (abs->isa<abstract::AbstractTuple>()) {
      auto nodes = TransformTupleArgument(fg, cinputs[i], abs->cast<abstract::AbstractTuplePtr>());
      inputs.insert(inputs.end(), nodes.begin(), nodes.end());
    } else {
      inputs.push_back(cinputs[i]);
    }
  }
  auto new_node = fg->NewCNode(inputs);
  new_node->set_abstract(cnode->abstract());
  return new_node;
}
}  // namespace opt
}  // namespace mindspore
mindspore/ccsrc/frontend/optimizer/graph_transform.h
0 → 100644
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_GRAPH_TRANSFORM_H
#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_GRAPH_TRANSFORM_H

#include <unordered_map>
#include <string>
#include <vector>
#include <algorithm>
#include <memory>
#include "frontend/optimizer/optimizer.h"

namespace mindspore {
namespace opt {
bool CNodeHasTupleInput(const CNodePtr &cnode);
bool FuncGraphHasTupleInput(const FuncGraphPtr &fg);
std::vector<AnfNodePtr> TransformTupleArgument(const FuncGraphPtr &fg, const AnfNodePtr &node,
                                               const abstract::AbstractTuplePtr &abs);
AnfNodePtr TransformCallGraph(const FuncGraphPtr &trans_fg, const CNodePtr &cnode);
AnfNodePtr TransformPartial(const FuncGraphPtr &trans_fg, const CNodePtr &cnode);
AnfNodePtr TransformSwitchCall(const AnfNodePtr &swtich_node, const CNodePtr &cnode);

class GraphTupleParamTransform {
 public:
  GraphTupleParamTransform() : cache_() {}
  ~GraphTupleParamTransform() { cache_.clear(); }

  FuncGraphPtr operator()(const FuncGraphPtr &fg, const FuncGraphManagerPtr &mng) {
    if (cache_.find(fg) != cache_.end()) {
      return cache_[fg];
    }
    auto new_fg = TransformGraphParam(fg, mng);
    cache_[fg] = new_fg;
    return new_fg;
  }

  AnfNodePtr GenerateTupleParams(const abstract::AbstractTuplePtr &tuple_abs, const FuncGraphPtr &fg,
                                 std::vector<AnfNodePtr> *params) {
    std::vector<AnfNodePtr> inputs;
    inputs.push_back(NewValueNode(prim::kPrimMakeTuple));
    auto &elements = tuple_abs->elements();
    for (auto &item : elements) {
      if (item->isa<abstract::AbstractTuple>()) {
        inputs.push_back(GenerateTupleParams(item->cast<abstract::AbstractTuplePtr>(), fg, params));
      } else {
        auto p = std::make_shared<Parameter>(fg);
        p->set_abstract(item);
        params->push_back(p);
        inputs.push_back(params->back());
      }
    }
    auto node = fg->NewCNode(inputs);
    node->set_abstract(tuple_abs);
    return node;
  }

  FuncGraphPtr TransformGraphParam(const FuncGraphPtr &fg, const FuncGraphManagerPtr &mng) {
    Cloner cloner({fg}, false, false, false, std::make_shared<TraceCopy>(), std::make_shared<TraceCopy>());
    auto new_fg = cloner[fg];
    auto &params = new_fg->parameters();
    std::vector<AnfNodePtr> new_params;
    std::unordered_map<AnfNodePtr, AnfNodePtr> repl;
    for (auto &param : params) {
      auto abs = param->abstract();
      if (abs != nullptr && abs->isa<abstract::AbstractTuple>()) {
        auto tuple_abs = abs->cast<abstract::AbstractTuplePtr>();
        std::vector<AnfNodePtr> tuple_params;
        repl.emplace(param, GenerateTupleParams(tuple_abs, new_fg, &tuple_params));
        std::transform(tuple_params.begin(), tuple_params.end(), std::back_inserter(new_params),
                       [](AnfNodePtr p) { return p; });
      } else {
        new_params.push_back(param);
      }
    }
    auto tmp_mng = mindspore::Manage(new_fg, false);
    auto tr = tmp_mng->Transact();
    for (auto &item : repl) {
      bool ret = tr.Replace(item.first, item.second);
      if (ret == false) {
        MS_LOG(ERROR) << "replace failed" << item.first->DebugString() << " with__" << item.second->DebugString(2);
      }
    }
    tr.SetParameters(new_fg, new_params);
    tr.Commit();
    mng->AddFuncGraph(new_fg);
    return new_fg;
  }

  std::unordered_map<FuncGraphPtr, FuncGraphPtr> cache_;
};
}  // namespace opt
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_GRAPH_TRANSFORM_H
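For intuition, here is a rough Python analogue of GenerateTupleParams above (illustration only; plain Python stand-ins, not MindSpore APIs): each leaf of a tuple-typed parameter's abstract becomes a fresh parameter, and a make_tuple node is rebuilt inside the cloned graph so existing users of the old tuple parameter can be re-pointed at the rebuilt tuple.

def generate_tuple_params(tuple_abs, params):
    # tuple_abs: a nested tuple of leaf "abstracts"; params: the flat parameter list being built
    inputs = ["make_tuple"]
    for item in tuple_abs:
        if isinstance(item, tuple):
            inputs.append(generate_tuple_params(item, params))  # recurse into nested tuples
        else:
            p = "param_%d" % len(params)  # stands in for std::make_shared<Parameter>(fg)
            params.append(p)
            inputs.append(p)
    return tuple(inputs)

params = []
rebuilt = generate_tuple_params((("f32", "f32"), "f32"), params)
assert params == ["param_0", "param_1", "param_2"]
assert rebuilt == ("make_tuple", ("make_tuple", "param_0", "param_1"), "param_2")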
mindspore/ccsrc/frontend/optimizer/irpass.cc
...
@@ -44,6 +44,7 @@
 #include "frontend/optimizer/irpass/row_tensor_eliminate.h"
 #include "frontend/optimizer/irpass/sparse_tensor_eliminate.h"
 #include "frontend/optimizer/irpass/switch_layer_defer_inline.h"
+#include "frontend/optimizer/irpass/call_graph_tuple_transform.h"
 namespace mindspore {
 namespace opt {
...
@@ -158,6 +159,10 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
   unused_output_eliminate_ =
     MakeSubstitution(std::make_shared<UnusedOutputEliminater>(), "unused_output_eliminate", IsCNodeGraphKernel);
+  // tuple parameter graph transform
+  call_graph_tuple_transform_ =
+    MakeSubstitution(std::make_shared<CallGraphTupleTransform>(), "graph_param_transorm", IsCNode);
   // AddN eliminate
   addn_eliminate_ =
     MakeSubstitution(std::make_shared<AddNEliminater>(), "addn_eliminate", IsCNodeGraphKernel);
...
mindspore/ccsrc/frontend/optimizer/irpass.h
...
@@ -103,6 +103,9 @@ class OptimizeIRPassLib {
   SubstitutionPtr unused_parameter_eliminate_;
   SubstitutionPtr unused_output_eliminate_;
+  // tuple parameter graph transform
+  SubstitutionPtr call_graph_tuple_transform_;
   // AddN eliminate
   SubstitutionPtr addn_eliminate_;
...
mindspore/ccsrc/frontend/optimizer/irpass/call_graph_tuple_transform.h
0 → 100644
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_CALL_GRAPH_TRANSFORM_H_
#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_CALL_GRAPH_TRANSFORM_H_

#include <algorithm>
#include <memory>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "ir/func_graph.h"
#include "ir/func_graph_cloner.h"
#include "frontend/optimizer/optimizer_caller.h"
#include "frontend/optimizer/anf_visitor.h"
#include "frontend/operator/ops.h"
#include "frontend/optimizer/irpass.h"
#include "frontend/optimizer/optimizer.h"
#include "frontend/optimizer/graph_transform.h"

namespace mindspore {
namespace opt {
namespace irpass {
// {G, Xs} --> transform graph call tuple inputs to flat inputs.
class GraphCallTupleTransform : public AnfVisitor {
 public:
  explicit GraphCallTupleTransform(GraphTupleParamTransform &transformer) : graph_transform_(transformer) {}
  ~GraphCallTupleTransform() override = default;
  AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
    if (!node->isa<CNode>() || node->func_graph() == nullptr) {
      return nullptr;
    }
    auto cnode = node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
    auto &inputs = cnode->inputs();
    auto fg = GetValueNode<FuncGraphPtr>(inputs[0]);
    if (fg == nullptr) {
      return nullptr;
    }
    if (!CNodeHasTupleInput(node->cast<CNodePtr>())) {
      return nullptr;
    }
    FuncGraphPtr transformed_fg = graph_transform_(fg, optimizer->manager());
    auto new_node = TransformCallGraph(transformed_fg, node->cast<CNodePtr>());
    return new_node;
  }

 private:
  GraphTupleParamTransform &graph_transform_;
};

// {{switch, cond, true_branch, false_branch}, Xs} --> transform switch graph call tuple inputs to flat inputs.
class SwitchCallTupleTransform : public AnfVisitor {
 public:
  explicit SwitchCallTupleTransform(GraphTupleParamTransform &transformer) : graph_transform_(transformer) {}
  ~SwitchCallTupleTransform() override = default;
  AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
    if (!node->isa<CNode>() || node->func_graph() == nullptr) {
      return nullptr;
    }
    auto switch_call_cnode = node->cast<CNodePtr>();
    auto call_inputs = switch_call_cnode->inputs();
    if (call_inputs.size() < 1) {
      return nullptr;
    }
    if (!IsPrimitiveCNode(call_inputs[0], prim::kPrimSwitch)) {
      return nullptr;
    }
    auto swich_cnode = call_inputs[0]->cast<CNodePtr>();
    auto switch_inputs = swich_cnode->inputs();
    if (switch_inputs.size() != 4) {
      return nullptr;
    }
    AnfNodePtr transformed = nullptr;
    bool true_br_changed = TransformBranchNode(switch_inputs[2], optimizer->manager(), &transformed);
    if (true_br_changed) {
      switch_inputs[2] = transformed;
    }
    bool false_br_changed = TransformBranchNode(switch_inputs[3], optimizer->manager(), &transformed);
    if (false_br_changed) {
      switch_inputs[3] = transformed;
    }
    if (true_br_changed || false_br_changed) {
      call_inputs[0] = swich_cnode->func_graph()->NewCNode(switch_inputs);
    }
    if (CNodeHasTupleInput(switch_call_cnode)) {
      return TransformSwitchCall(call_inputs[0], switch_call_cnode);
    }
    if (true_br_changed || false_br_changed) {
      return switch_call_cnode->func_graph()->NewCNode(call_inputs);
    }
    return nullptr;
  }

  bool TransformBranchNode(AnfNodePtr node, FuncGraphManagerPtr mng, AnfNodePtr *trans_node) {
    if (IsValueNode<FuncGraph>(node)) {
      FuncGraphPtr fg = GetValueNode<FuncGraphPtr>(node);
      if (FuncGraphHasTupleInput(fg)) {
        FuncGraphPtr transformed_fg = graph_transform_(fg, mng);
        *trans_node = NewValueNode(transformed_fg);
        return true;
      }
      return false;
    }
    if (IsPrimitiveCNode(node, prim::kPrimPartial)) {
      auto partial_inputs = node->cast<CNodePtr>()->inputs();
      if (IsValueNode<FuncGraph>(partial_inputs[1])) {
        FuncGraphPtr fg = GetValueNode<FuncGraphPtr>(partial_inputs[1]);
        if (FuncGraphHasTupleInput(fg)) {
          fg = graph_transform_(fg, mng);
        }
        if (CNodeHasTupleInput(node->cast<CNodePtr>())) {
          *trans_node = TransformPartial(fg, node->cast<CNodePtr>());
          return true;
        }
      }
      return false;
    }
    MS_LOG(WARNING) << "Got unexpected switch branch node " << node->DebugString();
    return false;
  }

 private:
  GraphTupleParamTransform &graph_transform_;
};

// {{switch_layer, index, {make_tuple, br1, br2, ...}}, Xs} -->
// transform switch layer graph call tuple inputs to flat inputs.
class SwitchLayerCallTupleTransform : public AnfVisitor {
 public:
  explicit SwitchLayerCallTupleTransform(GraphTupleParamTransform &transformer) : graph_transform_(transformer) {}
  ~SwitchLayerCallTupleTransform() override = default;
  AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
    if (!node->isa<CNode>() || node->func_graph() == nullptr) {
      return nullptr;
    }
    auto switch_layer_call_cnode = node->cast<CNodePtr>();
    auto call_inputs = switch_layer_call_cnode->inputs();
    if (call_inputs.size() < 1) {
      return nullptr;
    }
    if (!IsPrimitiveCNode(call_inputs[0], prim::kPrimSwitchLayer)) {
      return nullptr;
    }
    auto swich_layer_cnode = call_inputs[0]->cast<CNodePtr>();
    auto switch_layer_inputs = swich_layer_cnode->inputs();
    if (switch_layer_inputs.size() != 3) {
      return nullptr;
    }
    AnfNodePtr transformed = nullptr;
    bool layer_changed = TransformLayerNode(switch_layer_inputs[2], optimizer->manager(), &transformed);
    if (layer_changed) {
      switch_layer_inputs[2] = transformed;
      call_inputs[0] = switch_layer_call_cnode->func_graph()->NewCNode(switch_layer_inputs);
    }
    if (CNodeHasTupleInput(switch_layer_call_cnode)) {
      return TransformSwitchCall(call_inputs[0], switch_layer_call_cnode);
    }
    if (layer_changed) {
      return switch_layer_call_cnode->func_graph()->NewCNode(call_inputs);
    }
    return nullptr;
  }

  bool TransformLayerNode(AnfNodePtr node, FuncGraphManagerPtr mng, AnfNodePtr *trans_node) {
    if (!IsPrimitiveCNode(node, prim::kPrimMakeTuple)) {
      MS_LOG(WARNING) << "SwitchLayer input is not MakeTuple";
      return false;
    }
    auto tuple_inputs = node->cast<CNodePtr>()->inputs();
    bool changed = false;
    for (size_t i = 1; i < tuple_inputs.size(); i++) {
      if (!IsValueNode<FuncGraph>(tuple_inputs[i])) {
        MS_LOG(WARNING) << "SwitchLayer input is not FuncGraph";
        return false;
      }
      FuncGraphPtr fg = GetValueNode<FuncGraphPtr>(tuple_inputs[i]);
      if (FuncGraphHasTupleInput(fg)) {
        FuncGraphPtr transformed_fg = graph_transform_(fg, mng);
        tuple_inputs[i] = NewValueNode(transformed_fg);
        changed = true;
      }
    }
    if (changed) {
      *trans_node = node->func_graph()->NewCNode(tuple_inputs);
    }
    return changed;
  }

 private:
  GraphTupleParamTransform &graph_transform_;
};

class CallGraphTupleTransform : public OptimizerCaller {
 public:
  CallGraphTupleTransform()
      : graph_transformer_(),
        graph_call_transform_(std::make_shared<GraphCallTupleTransform>(graph_transformer_)),
        switch_call_transform_(std::make_shared<SwitchCallTupleTransform>(graph_transformer_)),
        switch_layer_call_transform_(std::make_shared<SwitchLayerCallTupleTransform>(graph_transformer_)) {
    transformers_.emplace_back(graph_call_transform_);
    transformers_.emplace_back(switch_call_transform_);
    transformers_.emplace_back(switch_layer_call_transform_);
  }
  ~CallGraphTupleTransform() = default;

  AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
    AnfNodePtr new_node;
    for (auto &transform : transformers_) {
      new_node = (*transform)(optimizer, node);
      if (new_node != nullptr) {
        return new_node;
      }
    }
    return nullptr;
  }

 private:
  GraphTupleParamTransform graph_transformer_;
  OptimizerCallerPtr graph_call_transform_;
  OptimizerCallerPtr switch_call_transform_;
  OptimizerCallerPtr switch_layer_call_transform_;
  std::vector<OptimizerCallerPtr> transformers_{};
};
}  // namespace irpass
}  // namespace opt
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_CALL_GRAPH_TRANSFORM_H_
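Taken together, the three visitors plus the dispatching CallGraphTupleTransform cover the call shapes that can reach a graph with tuple parameters. Written out as pseudo-IR (plain Python tuples used purely as notation, not MindSpore syntax):

direct_call = ("G", "Xs")                                                      # GraphCallTupleTransform
switch_call = (("switch", "cond", "true_g", "false_g"), "Xs")                  # SwitchCallTupleTransform
switch_layer_call = (("switch_layer", "i", ("make_tuple", "g1", "g2")), "Xs")  # SwitchLayerCallTupleTransform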
mindspore/ccsrc/pipeline/jit/action.cc
...
@@ -277,6 +277,7 @@ bool OptimizeAction(const ResourcePtr &res, const std::vector<PassItem> &passes)
      MS_EXCEPTION_IF_NULL(func_graph);
      func_graph->DumpFuncGraph(fg_name);
      DumpIR(fg_name + ".ir", func_graph);
      ExportIR(fg_name + ".dat", "", func_graph);
      MS_LOG(DEBUG) << "Dump " << fg_name << " func graph.";
    }
    counter++;
...
mindspore/ccsrc/pipeline/jit/pass.cc
...
@@ -33,6 +33,7 @@
 #include "frontend/optimizer/clean.h"
 #include "frontend/optimizer/irpass.h"
 #include "frontend/optimizer/control_depend.h"
+#include "frontend/optimizer/graph_transform.h"
 #include "frontend/parallel/step_parallel.h"
 #include "frontend/parallel/step_auto_parallel.h"
 #include "frontend/parallel/allreduce_fusion/step_allreduce_fusion.h"
...
@@ -166,12 +167,23 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
 OptPassGroupMap GetOptPassesAfterCconv(const opt::irpass::OptimizeIRPassLib &irpass) {
   opt::OptPassConfig c_1 = opt::OptPassConfig({
     // Safe inlining
     irpass.inline_,
     irpass.partial_eliminate_,
   });
-  OptPassGroupMap map_a({{"c_1", c_1}, {"renormalize", opt::OptPassConfig::Renormalize()}});
+  OptPassGroupMap map_a({{"c_1", c_1},
+                         {"cse", opt::OptPassConfig(opt::CSEPass(false))},
+                         {"renormalize", opt::OptPassConfig::Renormalize()}});
   return map_a;
 }
+
+OptPassGroupMap GetOptPassesTransformGraph(const opt::irpass::OptimizeIRPassLib &irpass) {
+  opt::OptPassConfig d_1 = opt::OptPassConfig({
+    // Safe inlining
+    irpass.call_graph_tuple_transform_,
+    irpass.item_tuple_eliminate_});
+  OptPassGroupMap map_a({{"d_1", d_1}, {"renormalize", opt::OptPassConfig::Renormalize()}});
+  return map_a;
+}
...
@@ -262,6 +274,8 @@ void InitOpt(const ResourcePtr &res) {
     g_pass_opts["opt_b"] = Optimizer::MakeOptimizer("opt_b", res, GetOptPassesB(irpass), false, true);
     g_pass_opts["opt_after_cconv"] =
       Optimizer::MakeOptimizer("opt_after_cconv", res, GetOptPassesAfterCconv(irpass), false, true);
+    g_pass_opts["opt_trans_graph"] =
+      Optimizer::MakeOptimizer("opt_trans_graph", res, GetOptPassesTransformGraph(irpass), true, true);
     g_pass_opts["opt_graph_kernel_a"] =
       Optimizer::MakeOptimizer("opt_graph_kernel_a", res, GetOptPassesGraphKernelA(irpass), true);
     g_pass_opts["opt_graph_kernel_b"] =
...
@@ -307,6 +321,7 @@ bool OptPassGroup(const ResourcePtr &res, const std::string &name) {
 bool OptPassAGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_a"); }
 bool OptPassBGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_b"); }
 bool OptPassAfterCconvGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_after_cconv"); }
+bool OptPassTransformGraphGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_trans_graph"); }
 bool OptPassGraphKernelGroupA(const ResourcePtr &res) { return OptPassGroup(res, "opt_graph_kernel_a"); }
 bool OptPassGraphKernelGroupB(const ResourcePtr &res) { return OptPassGroup(res, "opt_graph_kernel_b"); }
 bool ControlGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_control"); }
...
@@ -365,6 +380,24 @@ bool CconvPass(const ResourcePtr &res) {
   return true;
 }
+
+bool TransformTopGraphPass(const ResourcePtr &res) {
+  if (res->func_graph() == nullptr) {
+    MS_LOG(EXCEPTION) << "Transform top graph error.";
+  }
+  FuncGraphPtr func_graph = res->func_graph();
+  if (opt::FuncGraphHasTupleInput(func_graph)) {
+    opt::GraphTupleParamTransform graph_trans;
+    func_graph = graph_trans(func_graph, res->manager());
+    res->set_func_graph(func_graph);
+    AbstractBasePtrList abs_spec_list;
+    auto &params = func_graph->parameters();
+    std::transform(params.begin(), params.end(), std::back_inserter(abs_spec_list),
+                   [](AnfNodePtr node) { return node->abstract(); });
+    res->set_args_spec(abs_spec_list);
+  }
+  return true;
+}
+
 bool ValidatePass(const ResourcePtr &res) {
   MS_EXCEPTION_IF_NULL(res->func_graph());
   FuncGraphPtr func_graph = res->func_graph();
...
@@ -388,6 +421,7 @@ std::vector<PassItem> kVmPasses = {{"simplify_data_structures", SimplifyDataStru
                                    {"cconv", CconvPass},
                                    {"opt_after_cconv", OptPassAfterCconvGroup},
                                    {"remove_dup_value", RemoveValueNodeDuplicationsPass},
+                                   {"tuple_transform", OptPassTransformGraphGroup},
                                    {"opt_graph_kernel_a", OptPassGraphKernelGroupA},
                                    {"opt_graph_kernel_b", OptPassGraphKernelGroupB},
                                    {"add_control_depend", AddControlDependPass}};
...
@@ -401,6 +435,10 @@ std::vector<PassItem> kGePasses = {{"simplify_data_structures", SimplifyDataStru
                                    {"opt_prepare", PrepareGroup},
                                    {"cconv", CconvPass}};
-std::vector<PassItem> kPynativePasses = {{"opt_a", OptPassAGroup}, {"opt_b", OptPassBGroup}, {"cconv", CconvPass}};
+std::vector<PassItem> kPynativePasses = {{"opt_a", OptPassAGroup},
+                                         {"opt_b", OptPassBGroup},
+                                         {"cconv", CconvPass},
+                                         {"transform_top", TransformTopGraphPass},
+                                         {"transform_graph", OptPassTransformGraphGroup}};
 }  // namespace pipeline
 }  // namespace mindspore
mindspore/ccsrc/pipeline/pynative/pynative_execute.cc
...
@@ -1387,9 +1387,46 @@ void PynativeExecutor::ClearRes() {
   resource_.reset();
 }
+
+size_t GetTupleSize(const py::tuple &args) {
+  size_t count = 0;
+  for (size_t i = 0; i < args.size(); i++) {
+    if (py::isinstance<py::tuple>(args[i])) {
+      count += GetTupleSize(args[i]);
+    } else {
+      count += 1;
+    }
+  }
+  return count;
+}
+
+void ConvertTupleArg(py::tuple *res, size_t *index, const py::tuple &arg) {
+  for (size_t i = 0; i < arg.size(); i++) {
+    if (py::isinstance<py::tuple>(arg[i])) {
+      ConvertTupleArg(res, index, arg[i]);
+    } else {
+      (*res)[(*index)++] = arg[i];
+    }
+  }
+}
+
+py::tuple ConvertArgs(const py::tuple &args) {
+  size_t tuple_size = GetTupleSize(args);
+  py::tuple res(tuple_size);
+  size_t index = 0;
+  for (size_t i = 0; i < args.size(); i++) {
+    if (py::isinstance<py::tuple>(args[i])) {
+      ConvertTupleArg(&res, &index, args[i]);
+    } else {
+      res[index++] = args[i];
+    }
+  }
+  return res;
+}
+
 py::object PynativeExecutor::Run(const py::tuple &args, const py::object &phase) {
   VectorRef arg_list;
-  pipeline::ProcessVmArgInner(args, resource_, &arg_list);
+  py::tuple converted_args = ConvertArgs(args);
+  pipeline::ProcessVmArgInner(converted_args, resource_, &arg_list);
   if (resource_->results().find(pipeline::kOutput) == resource_->results().end() ||
       !resource_->results()[pipeline::kOutput].is<compile::VmEvalFuncPtr>()) {
     MS_LOG(EXCEPTION) << "Can't find run graph func for ";
...
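A hedged Python rendering of roughly what the new GetTupleSize/ConvertTupleArg/ConvertArgs helpers do (illustration only, not a MindSpore API): nested tuples among the arguments handed to PynativeExecutor::Run are flattened so they line up with the flattened parameters of the transformed top graph.

def convert_args(args):
    flat = []
    for a in args:
        if isinstance(a, tuple):
            flat.extend(convert_args(a))  # recurse, mirroring ConvertTupleArg
        else:
            flat.append(a)
    return tuple(flat)

assert convert_args((1, (2, (3, 4)), 5)) == (1, 2, 3, 4, 5)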
tests/st/pynative/test_graph_param_transform.py
0 → 100644
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import pytest
import numpy as np
from mindspore import RowTensor
from mindspore import context, nn, Tensor, ParameterTuple
from mindspore.common import dtype as mstype
from mindspore.common import ms_function
from mindspore.ops import operations as P
from mindspore.ops import composite as C


def setup_module():
    context.set_context(mode=context.PYNATIVE_MODE, enable_sparse=False)


class _Grad(nn.Cell):
    def __init__(self, grad, network, wrt_params=False, real_inputs_count=None):
        super().__init__()
        self.network = network
        self.grad = grad
        self.sens_param = self.grad.sens_param
        self.wrt_params = wrt_params
        self.real_inputs_count = real_inputs_count
        if self.wrt_params:
            self.params = ParameterTuple(self.network.trainable_params())

    def construct(self, *inputs):
        if self.wrt_params:
            if self.real_inputs_count is None or self.sens_param is False:
                return self.grad(self.network, self.params)(*inputs)
            real_inputs = inputs[:self.real_inputs_count]
            sense_param_inputs = inputs[self.real_inputs_count:]
            return self.grad(self.network, self.params)(*real_inputs, sense_param_inputs)
        if self.real_inputs_count is None or self.sens_param is False:
            return self.grad(self.network)(*inputs)
        real_inputs = inputs[:self.real_inputs_count]
        sense_param_inputs = inputs[self.real_inputs_count:]
        return self.grad(self.network)(*real_inputs, sense_param_inputs)


class GradOfFirstInput(_Grad):
    """
    get grad of first input
    """

    def __init__(self, network, sens_param=True, real_inputs_count=None):
        super().__init__(grad=C.GradOperation(sens_param=sens_param),
                         network=network, real_inputs_count=real_inputs_count)


class GradOfAllInputs(_Grad):
    """
    get grads of all inputs
    """

    def __init__(self, network, sens_param=True, real_inputs_count=None):
        super().__init__(grad=C.GradOperation(get_all=True, sens_param=sens_param),
                         network=network, real_inputs_count=real_inputs_count)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_row_tensor_in_while():
    class RowTensorValuesDouble(nn.Cell):
        def construct(self, x):
            indices = x.indices
            values = x.values * 2
            dense_shape = x.dense_shape
            return RowTensor(indices, values, dense_shape)

    class RowTensorValuesAdd2(nn.Cell):
        def construct(self, x):
            indices = x.indices
            values = x.values + 2
            dense_shape = x.dense_shape
            return RowTensor(indices, values, dense_shape)

    class RowTensorWithControlWhile(nn.Cell):
        def __init__(self, dense_shape):
            super().__init__()
            self.op1 = RowTensorValuesDouble()
            self.op2 = RowTensorValuesAdd2()
            self.dense_shape = dense_shape

        @ms_function
        def construct(self, a, b, indices, values):
            x = RowTensor(indices, values, self.dense_shape)
            x = self.op2(x)
            while a > b:
                x = self.op1(x)
                b = b + 1
            return x.indices, x.values, x.dense_shape

    a = Tensor(np.array(3).astype(np.int32))
    b = Tensor(np.array(0).astype(np.int32))
    indices = Tensor(np.array([0, 2]).astype(np.int32))
    values = Tensor(np.ones([2, 2]).astype(np.float32))
    dense_shape = (5, 2)
    net = RowTensorWithControlWhile(dense_shape)
    out = net(a, b, indices, values)
    assert np.allclose(indices.asnumpy(), out[0].asnumpy(), .0, .0)
    assert np.allclose(values.asnumpy() * 24, out[1].asnumpy(), .0, .0)
    assert dense_shape == out[2]


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_parser_switch_layer_inputs_tuple():
    class Add(nn.Cell):
        def __init__(self):
            super().__init__()
            self.op = P.TensorAdd()

        def construct(self, x):
            y = self.op(x[0], x[1])
            return self.op(x[0], y)

    class Mul(nn.Cell):
        def __init__(self):
            super().__init__()
            self.op = P.Mul()

        def construct(self, x):
            y = self.op(x[0], x[1])
            return self.op(x[0], y)

    class MulTwoInput(nn.Cell):
        def __init__(self):
            super().__init__()
            self.op = P.Mul()

        @ms_function
        def construct(self, x, y):
            y = self.op(x, y)
            return self.op(x, y)

    class TwoInputTupleFinalNet(nn.Cell):
        def __init__(self, funcs):
            super().__init__()
            self.funcs = funcs

        @ms_function
        def construct(self, i, inputa, inputb):
            inputs = (inputa, inputb)
            x = self.funcs[i](inputs)
            return x

    func1 = Add()
    func2 = Mul()
    funcs = (func1, func2)
    net = TwoInputTupleFinalNet(funcs)

    input_data = Tensor(np.random.randn(2, 3, 4, 5).astype(np.float32))
    input2 = Tensor(np.random.randn(2, 3, 4, 5).astype(np.float32))
    i = Tensor(1, mstype.int32)
    netout = net(i, input_data, input2)
    net_good = MulTwoInput()
    goodout = net_good(input_data, input2)
    assert np.allclose(goodout.asnumpy(), netout.asnumpy(), 0, 0)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_imagenet():
    class ImageGradients(nn.Cell):
        def __init__(self):
            super().__init__()
            self.imagegradients = nn.ImageGradients()

        def construct(self, inputs):
            return self.imagegradients(inputs)

    net = ImageGradients()
    net_me = GradOfFirstInput(net, real_inputs_count=1)
    net_me.set_train()
    input_data = Tensor(np.ones([32, 16, 8, 8]), dtype=mstype.float32)
    output_grad = (Tensor(np.ones([32, 16, 8, 8]), dtype=mstype.float32),
                   Tensor(np.ones([32, 16, 8, 8]), dtype=mstype.float32))
    net_me(input_data, *output_grad)
tests/ut/python/pynative_mode/test_graph_param_cases.py
0 → 100644
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
from mindspore import RowTensor
from mindspore import context, nn, Tensor, ParameterTuple
from mindspore.common import dtype as mstype
from mindspore.common import ms_function
from mindspore.ops import composite as C


def setup_module():
    context.set_context(mode=context.PYNATIVE_MODE, enable_sparse=False)


class _Grad(nn.Cell):
    def __init__(self, grad, network, wrt_params=False, real_inputs_count=None):
        super().__init__()
        self.network = network
        self.grad = grad
        self.sens_param = self.grad.sens_param
        self.wrt_params = wrt_params
        self.real_inputs_count = real_inputs_count
        if self.wrt_params:
            self.params = ParameterTuple(self.network.trainable_params())

    def construct(self, *inputs):
        if self.wrt_params:
            if self.real_inputs_count is None or self.sens_param is False:
                return self.grad(self.network, self.params)(*inputs)
            real_inputs = inputs[:self.real_inputs_count]
            sense_param_inputs = inputs[self.real_inputs_count:]
            return self.grad(self.network, self.params)(*real_inputs, sense_param_inputs)
        if self.real_inputs_count is None or self.sens_param is False:
            return self.grad(self.network)(*inputs)
        real_inputs = inputs[:self.real_inputs_count]
        sense_param_inputs = inputs[self.real_inputs_count:]
        return self.grad(self.network)(*real_inputs, sense_param_inputs)


class GradOfFirstInput(_Grad):
    """
    get grad of first input
    """

    def __init__(self, network, sens_param=True, real_inputs_count=None):
        super().__init__(grad=C.GradOperation(sens_param=sens_param),
                         network=network, real_inputs_count=real_inputs_count)


class GradOfAllInputs(_Grad):
    """
    get grads of all inputs
    """

    def __init__(self, network, sens_param=True, real_inputs_count=None):
        super().__init__(grad=C.GradOperation(get_all=True, sens_param=sens_param),
                         network=network, real_inputs_count=real_inputs_count)


def test_row_tensor_in_while():
    class RowTensorValuesDouble(nn.Cell):
        def __init__(self):
            super().__init__()

        def construct(self, x):
            indices = x.indices
            values = x.values * 2
            dense_shape = x.dense_shape
            return RowTensor(indices, values, dense_shape)

    class RowTensorValuesAdd2(nn.Cell):
        def __init__(self):
            super().__init__()

        def construct(self, x):
            indices = x.indices
            values = x.values + 2
            dense_shape = x.dense_shape
            return RowTensor(indices, values, dense_shape)

    class RowTensorWithControlWhile(nn.Cell):
        def __init__(self, dense_shape):
            super().__init__()
            self.op1 = RowTensorValuesDouble()
            self.op2 = RowTensorValuesAdd2()
            self.dense_shape = dense_shape

        @ms_function
        def construct(self, a, b, indices, values):
            x = RowTensor(indices, values, self.dense_shape)
            x = self.op2(x)
            while a > b:
                x = self.op1(x)
                b = b + 1
            return x.indices, x.values, x.dense_shape

    a = Tensor(np.array(3).astype(np.int32))
    b = Tensor(np.array(0).astype(np.int32))
    indices = Tensor(np.array([0, 2]).astype(np.int32))
    values = Tensor(np.ones([2, 2]).astype(np.float32))
    dense_shape = (5, 2)
    net = RowTensorWithControlWhile(dense_shape)
    net(a, b, indices, values)


def test_multi_out_sens():
    class ImageGradients(nn.Cell):
        def __init__(self):
            super().__init__()

        def construct(self, x, y, z):
            resa = x * y
            resb = y * z
            resc = x * z
            return resa, (resb, resc)

    net = ImageGradients()
    net_me = GradOfAllInputs(net, real_inputs_count=3)
    net_me.set_train()
    input_data = Tensor(np.ones([32]), dtype=mstype.float32)
    output_grad = (Tensor(np.ones([32]), dtype=mstype.float32),
                   (Tensor(np.ones([32]), dtype=mstype.float32), Tensor(np.ones([32]), dtype=mstype.float32)))
    net_me(input_data, input_data, input_data, *output_grad)