Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
0099da2c
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
0099da2c
编写于
8月 21, 2020
作者:
H
huangdongrun
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add support for tuple parameter transform
add support for pynative pass add testcases
上级
1d0e0ae2
变更
10
隐藏空白更改
内联
并排
Showing
10 changed file
with
923 addition
and
4 deletion
+923
-4
mindspore/ccsrc/frontend/optimizer/graph_transform.cc
mindspore/ccsrc/frontend/optimizer/graph_transform.cc
+144
-0
mindspore/ccsrc/frontend/optimizer/graph_transform.h
mindspore/ccsrc/frontend/optimizer/graph_transform.h
+108
-0
mindspore/ccsrc/frontend/optimizer/irpass.cc
mindspore/ccsrc/frontend/optimizer/irpass.cc
+5
-0
mindspore/ccsrc/frontend/optimizer/irpass.h
mindspore/ccsrc/frontend/optimizer/irpass.h
+3
-0
mindspore/ccsrc/frontend/optimizer/irpass/call_graph_tuple_transform.h
...rc/frontend/optimizer/irpass/call_graph_tuple_transform.h
+246
-0
mindspore/ccsrc/pipeline/jit/action.cc
mindspore/ccsrc/pipeline/jit/action.cc
+1
-0
mindspore/ccsrc/pipeline/jit/pass.cc
mindspore/ccsrc/pipeline/jit/pass.cc
+41
-3
mindspore/ccsrc/pipeline/pynative/pynative_execute.cc
mindspore/ccsrc/pipeline/pynative/pynative_execute.cc
+38
-1
tests/st/pynative/test_graph_param_transform.py
tests/st/pynative/test_graph_param_transform.py
+201
-0
tests/ut/python/pynative_mode/test_graph_param_cases.py
tests/ut/python/pynative_mode/test_graph_param_cases.py
+136
-0
未找到文件。
mindspore/ccsrc/frontend/optimizer/graph_transform.cc
0 → 100644
浏览文件 @
0099da2c
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "frontend/optimizer/graph_transform.h"
#include <vector>
#include <algorithm>
#include <string>
#include "ir/graph_utils.h"
namespace
mindspore
{
/* namespace to support opt */
namespace
opt
{
// check cnode input values, whether it is tuple input
bool
CNodeHasTupleInput
(
const
CNodePtr
&
cnode
)
{
auto
&
inputs
=
cnode
->
inputs
();
for
(
size_t
i
=
1
;
i
<
inputs
.
size
();
i
++
)
{
if
(
IsValueNode
<
FuncGraph
>
(
inputs
[
i
]))
{
continue
;
}
if
(
IsValueNode
<
Primitive
>
(
inputs
[
i
]))
{
// unexpected high order primitvie as cnode input when transform graph
MS_LOG
(
WARNING
)
<<
"CheckTupleInput, got unexpected primitve as input"
<<
cnode
->
DebugString
();
return
false
;
}
auto
abs
=
inputs
[
i
]
->
abstract
();
if
(
abs
==
nullptr
)
{
MS_LOG
(
WARNING
)
<<
"CheckTupleInput, got abstract nullptr for node:"
<<
cnode
->
DebugString
();
return
false
;
}
if
(
abs
->
isa
<
abstract
::
AbstractTuple
>
())
{
return
true
;
}
}
return
false
;
}
bool
FuncGraphHasTupleInput
(
const
FuncGraphPtr
&
fg
)
{
auto
&
params
=
fg
->
parameters
();
for
(
auto
&
param
:
params
)
{
if
(
param
->
abstract
()
->
isa
<
abstract
::
AbstractTuple
>
())
{
return
true
;
}
}
return
false
;
}
std
::
vector
<
AnfNodePtr
>
TransformTupleArgument
(
const
FuncGraphPtr
&
fg
,
const
AnfNodePtr
&
node
,
const
abstract
::
AbstractTuplePtr
&
abs
)
{
auto
&
elements
=
abs
->
elements
();
std
::
vector
<
AnfNodePtr
>
tuple_node_expanded
;
for
(
size_t
i
=
0
;
i
<
elements
.
size
();
i
++
)
{
auto
elem_node
=
fg
->
NewCNode
({
NewValueNode
(
prim
::
kPrimTupleGetItem
),
node
,
NewValueNode
(
SizeToInt
(
i
))});
elem_node
->
set_abstract
(
elements
[
i
]);
if
(
elements
[
i
]
->
isa
<
abstract
::
AbstractTuple
>
())
{
auto
nodes
=
TransformTupleArgument
(
fg
,
elem_node
,
elements
[
i
]
->
cast
<
abstract
::
AbstractTuplePtr
>
());
tuple_node_expanded
.
insert
(
tuple_node_expanded
.
end
(),
nodes
.
begin
(),
nodes
.
end
());
}
else
{
tuple_node_expanded
.
push_back
(
elem_node
);
}
}
return
tuple_node_expanded
;
}
AnfNodePtr
TransformCallGraph
(
const
FuncGraphPtr
&
trans_fg
,
const
CNodePtr
&
cnode
)
{
auto
&
cinputs
=
cnode
->
inputs
();
auto
fg
=
cnode
->
func_graph
();
std
::
vector
<
AnfNodePtr
>
inputs
;
inputs
.
push_back
(
NewValueNode
(
trans_fg
));
for
(
size_t
i
=
1
;
i
<
cinputs
.
size
();
i
++
)
{
auto
abs
=
cinputs
[
i
]
->
abstract
();
if
(
abs
==
nullptr
)
{
MS_LOG
(
EXCEPTION
)
<<
"TransformCallGraph:Node abstract should not be nullptr"
<<
cinputs
[
i
]
->
DebugString
();
}
if
(
abs
->
isa
<
abstract
::
AbstractTuple
>
())
{
auto
nodes
=
TransformTupleArgument
(
fg
,
cinputs
[
i
],
abs
->
cast
<
abstract
::
AbstractTuplePtr
>
());
inputs
.
insert
(
inputs
.
end
(),
nodes
.
begin
(),
nodes
.
end
());
}
else
{
inputs
.
push_back
(
cinputs
[
i
]);
}
}
auto
new_node
=
fg
->
NewCNode
(
inputs
);
new_node
->
set_abstract
(
cnode
->
abstract
());
return
new_node
;
}
AnfNodePtr
TransformPartial
(
const
FuncGraphPtr
&
trans_fg
,
const
CNodePtr
&
cnode
)
{
auto
&
cinputs
=
cnode
->
inputs
();
auto
fg
=
cnode
->
func_graph
();
std
::
vector
<
AnfNodePtr
>
inputs
;
inputs
.
push_back
(
NewValueNode
(
prim
::
kPrimPartial
));
inputs
.
push_back
(
NewValueNode
(
trans_fg
));
for
(
size_t
i
=
2
;
i
<
cinputs
.
size
();
i
++
)
{
auto
abs
=
cinputs
[
i
]
->
abstract
();
if
(
abs
==
nullptr
)
{
MS_LOG
(
EXCEPTION
)
<<
"TransformPartial:Node abstract should not be nullptr"
<<
cinputs
[
i
]
->
DebugString
();
}
if
(
abs
->
isa
<
abstract
::
AbstractTuple
>
())
{
auto
nodes
=
TransformTupleArgument
(
fg
,
cinputs
[
i
],
abs
->
cast
<
abstract
::
AbstractTuplePtr
>
());
inputs
.
insert
(
inputs
.
end
(),
nodes
.
begin
(),
nodes
.
end
());
}
else
{
inputs
.
push_back
(
cinputs
[
i
]);
}
}
auto
new_node
=
fg
->
NewCNode
(
inputs
);
new_node
->
set_abstract
(
cnode
->
abstract
());
return
new_node
;
}
AnfNodePtr
TransformSwitchCall
(
const
AnfNodePtr
&
swtich_node
,
const
CNodePtr
&
cnode
)
{
auto
&
cinputs
=
cnode
->
inputs
();
auto
fg
=
cnode
->
func_graph
();
std
::
vector
<
AnfNodePtr
>
inputs
;
inputs
.
push_back
(
swtich_node
);
for
(
size_t
i
=
1
;
i
<
cinputs
.
size
();
i
++
)
{
auto
abs
=
cinputs
[
i
]
->
abstract
();
if
(
abs
==
nullptr
)
{
MS_LOG
(
EXCEPTION
)
<<
"TransformSwitchCall:Node abstract should not be nullptr"
<<
cinputs
[
i
]
->
DebugString
();
}
if
(
abs
->
isa
<
abstract
::
AbstractTuple
>
())
{
auto
nodes
=
TransformTupleArgument
(
fg
,
cinputs
[
i
],
abs
->
cast
<
abstract
::
AbstractTuplePtr
>
());
inputs
.
insert
(
inputs
.
end
(),
nodes
.
begin
(),
nodes
.
end
());
}
else
{
inputs
.
push_back
(
cinputs
[
i
]);
}
}
auto
new_node
=
fg
->
NewCNode
(
inputs
);
new_node
->
set_abstract
(
cnode
->
abstract
());
return
new_node
;
}
}
// namespace opt
}
// namespace mindspore
mindspore/ccsrc/frontend/optimizer/graph_transform.h
0 → 100644
浏览文件 @
0099da2c
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_GRAPH_TRANSFORM_H
#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_GRAPH_TRANSFORM_H
#include <unordered_map>
#include <string>
#include <vector>
#include <algorithm>
#include <memory>
#include "frontend/optimizer/optimizer.h"
namespace
mindspore
{
namespace
opt
{
bool
CNodeHasTupleInput
(
const
CNodePtr
&
cnode
);
bool
FuncGraphHasTupleInput
(
const
FuncGraphPtr
&
fg
);
std
::
vector
<
AnfNodePtr
>
TransformTupleArgument
(
const
FuncGraphPtr
&
fg
,
const
AnfNodePtr
&
node
,
const
abstract
::
AbstractTuplePtr
&
abs
);
AnfNodePtr
TransformCallGraph
(
const
FuncGraphPtr
&
trans_fg
,
const
CNodePtr
&
cnode
);
AnfNodePtr
TransformPartial
(
const
FuncGraphPtr
&
trans_fg
,
const
CNodePtr
&
cnode
);
AnfNodePtr
TransformSwitchCall
(
const
AnfNodePtr
&
swtich_node
,
const
CNodePtr
&
cnode
);
class
GraphTupleParamTransform
{
public:
GraphTupleParamTransform
()
:
cache_
()
{}
~
GraphTupleParamTransform
()
{
cache_
.
clear
();
}
FuncGraphPtr
operator
()(
const
FuncGraphPtr
&
fg
,
const
FuncGraphManagerPtr
&
mng
)
{
if
(
cache_
.
find
(
fg
)
!=
cache_
.
end
())
{
return
cache_
[
fg
];
}
auto
new_fg
=
TransformGraphParam
(
fg
,
mng
);
cache_
[
fg
]
=
new_fg
;
return
new_fg
;
}
AnfNodePtr
GenerateTupleParams
(
const
abstract
::
AbstractTuplePtr
&
tuple_abs
,
const
FuncGraphPtr
&
fg
,
std
::
vector
<
AnfNodePtr
>
*
params
)
{
std
::
vector
<
AnfNodePtr
>
inputs
;
inputs
.
push_back
(
NewValueNode
(
prim
::
kPrimMakeTuple
));
auto
&
elements
=
tuple_abs
->
elements
();
for
(
auto
&
item
:
elements
)
{
if
(
item
->
isa
<
abstract
::
AbstractTuple
>
())
{
inputs
.
push_back
(
GenerateTupleParams
(
item
->
cast
<
abstract
::
AbstractTuplePtr
>
(),
fg
,
params
));
}
else
{
auto
p
=
std
::
make_shared
<
Parameter
>
(
fg
);
p
->
set_abstract
(
item
);
params
->
push_back
(
p
);
inputs
.
push_back
(
params
->
back
());
}
}
auto
node
=
fg
->
NewCNode
(
inputs
);
node
->
set_abstract
(
tuple_abs
);
return
node
;
}
FuncGraphPtr
TransformGraphParam
(
const
FuncGraphPtr
&
fg
,
const
FuncGraphManagerPtr
&
mng
)
{
Cloner
cloner
({
fg
},
false
,
false
,
false
,
std
::
make_shared
<
TraceCopy
>
(),
std
::
make_shared
<
TraceCopy
>
());
auto
new_fg
=
cloner
[
fg
];
auto
&
params
=
new_fg
->
parameters
();
std
::
vector
<
AnfNodePtr
>
new_params
;
std
::
unordered_map
<
AnfNodePtr
,
AnfNodePtr
>
repl
;
for
(
auto
&
param
:
params
)
{
auto
abs
=
param
->
abstract
();
if
(
abs
!=
nullptr
&&
abs
->
isa
<
abstract
::
AbstractTuple
>
())
{
auto
tuple_abs
=
abs
->
cast
<
abstract
::
AbstractTuplePtr
>
();
std
::
vector
<
AnfNodePtr
>
tuple_params
;
repl
.
emplace
(
param
,
GenerateTupleParams
(
tuple_abs
,
new_fg
,
&
tuple_params
));
std
::
transform
(
tuple_params
.
begin
(),
tuple_params
.
end
(),
std
::
back_inserter
(
new_params
),
[](
AnfNodePtr
p
)
{
return
p
;
});
}
else
{
new_params
.
push_back
(
param
);
}
}
auto
tmp_mng
=
mindspore
::
Manage
(
new_fg
,
false
);
auto
tr
=
tmp_mng
->
Transact
();
for
(
auto
&
item
:
repl
)
{
bool
ret
=
tr
.
Replace
(
item
.
first
,
item
.
second
);
if
(
ret
==
false
)
{
MS_LOG
(
ERROR
)
<<
"replace failed"
<<
item
.
first
->
DebugString
()
<<
" with__"
<<
item
.
second
->
DebugString
(
2
);
}
}
tr
.
SetParameters
(
new_fg
,
new_params
);
tr
.
Commit
();
mng
->
AddFuncGraph
(
new_fg
);
return
new_fg
;
}
std
::
unordered_map
<
FuncGraphPtr
,
FuncGraphPtr
>
cache_
;
};
}
// namespace opt
}
// namespace mindspore
#endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_GRAPH_TRANSFORM_H
mindspore/ccsrc/frontend/optimizer/irpass.cc
浏览文件 @
0099da2c
...
...
@@ -44,6 +44,7 @@
#include "frontend/optimizer/irpass/row_tensor_eliminate.h"
#include "frontend/optimizer/irpass/sparse_tensor_eliminate.h"
#include "frontend/optimizer/irpass/switch_layer_defer_inline.h"
#include "frontend/optimizer/irpass/call_graph_tuple_transform.h"
namespace
mindspore
{
namespace
opt
{
...
...
@@ -158,6 +159,10 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
unused_output_eliminate_
=
MakeSubstitution
(
std
::
make_shared
<
UnusedOutputEliminater
>
(),
"unused_output_eliminate"
,
IsCNodeGraphKernel
);
// tuple parameter graph transform
call_graph_tuple_transform_
=
MakeSubstitution
(
std
::
make_shared
<
CallGraphTupleTransform
>
(),
"graph_param_transorm"
,
IsCNode
);
// AddN eliminate
addn_eliminate_
=
MakeSubstitution
(
std
::
make_shared
<
AddNEliminater
>
(),
"addn_eliminate"
,
IsCNodeGraphKernel
);
...
...
mindspore/ccsrc/frontend/optimizer/irpass.h
浏览文件 @
0099da2c
...
...
@@ -103,6 +103,9 @@ class OptimizeIRPassLib {
SubstitutionPtr
unused_parameter_eliminate_
;
SubstitutionPtr
unused_output_eliminate_
;
// tuple parameter graph transform
SubstitutionPtr
call_graph_tuple_transform_
;
// AddN eliminate
SubstitutionPtr
addn_eliminate_
;
...
...
mindspore/ccsrc/frontend/optimizer/irpass/call_graph_tuple_transform.h
0 → 100644
浏览文件 @
0099da2c
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_CALL_GRAPH_TRANSFORM_H_
#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_CALL_GRAPH_TRANSFORM_H_
#include <algorithm>
#include <memory>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "ir/func_graph.h"
#include "ir/func_graph_cloner.h"
#include "frontend/optimizer/optimizer_caller.h"
#include "frontend/optimizer/anf_visitor.h"
#include "frontend/operator/ops.h"
#include "frontend/optimizer/irpass.h"
#include "frontend/optimizer/optimizer.h"
#include "frontend/optimizer/graph_transform.h"
namespace
mindspore
{
namespace
opt
{
namespace
irpass
{
// {G, Xs}-->transform graph call tuple inputs to flat inputs.
class
GraphCallTupleTransform
:
public
AnfVisitor
{
public:
explicit
GraphCallTupleTransform
(
GraphTupleParamTransform
&
transformer
)
:
graph_transform_
(
transformer
)
{}
~
GraphCallTupleTransform
()
override
=
default
;
AnfNodePtr
operator
()(
const
OptimizerPtr
&
optimizer
,
const
AnfNodePtr
&
node
)
override
{
if
(
!
node
->
isa
<
CNode
>
()
||
node
->
func_graph
()
==
nullptr
)
{
return
nullptr
;
}
auto
cnode
=
node
->
cast
<
CNodePtr
>
();
MS_EXCEPTION_IF_NULL
(
cnode
);
auto
&
inputs
=
cnode
->
inputs
();
auto
fg
=
GetValueNode
<
FuncGraphPtr
>
(
inputs
[
0
]);
if
(
fg
==
nullptr
)
{
return
nullptr
;
}
if
(
!
CNodeHasTupleInput
(
node
->
cast
<
CNodePtr
>
()))
{
return
nullptr
;
}
FuncGraphPtr
transformed_fg
=
graph_transform_
(
fg
,
optimizer
->
manager
());
auto
new_node
=
TransformCallGraph
(
transformed_fg
,
node
->
cast
<
CNodePtr
>
());
return
new_node
;
}
private:
GraphTupleParamTransform
&
graph_transform_
;
};
// {{switch, cond, true_branch, false_branch}, Xs} -->transform switch graph call tuple inputs to flat inputs.
class
SwitchCallTupleTransform
:
public
AnfVisitor
{
public:
explicit
SwitchCallTupleTransform
(
GraphTupleParamTransform
&
transformer
)
:
graph_transform_
(
transformer
)
{}
~
SwitchCallTupleTransform
()
override
=
default
;
AnfNodePtr
operator
()(
const
OptimizerPtr
&
optimizer
,
const
AnfNodePtr
&
node
)
override
{
if
(
!
node
->
isa
<
CNode
>
()
||
node
->
func_graph
()
==
nullptr
)
{
return
nullptr
;
}
auto
switch_call_cnode
=
node
->
cast
<
CNodePtr
>
();
auto
call_inputs
=
switch_call_cnode
->
inputs
();
if
(
call_inputs
.
size
()
<
1
)
{
return
nullptr
;
}
if
(
!
IsPrimitiveCNode
(
call_inputs
[
0
],
prim
::
kPrimSwitch
))
{
return
nullptr
;
}
auto
swich_cnode
=
call_inputs
[
0
]
->
cast
<
CNodePtr
>
();
auto
switch_inputs
=
swich_cnode
->
inputs
();
if
(
switch_inputs
.
size
()
!=
4
)
{
return
nullptr
;
}
AnfNodePtr
transformed
=
nullptr
;
bool
true_br_changed
=
TransformBranchNode
(
switch_inputs
[
2
],
optimizer
->
manager
(),
&
transformed
);
if
(
true_br_changed
)
{
switch_inputs
[
2
]
=
transformed
;
}
bool
false_br_changed
=
TransformBranchNode
(
switch_inputs
[
3
],
optimizer
->
manager
(),
&
transformed
);
if
(
false_br_changed
)
{
switch_inputs
[
3
]
=
transformed
;
}
if
(
true_br_changed
||
false_br_changed
)
{
call_inputs
[
0
]
=
swich_cnode
->
func_graph
()
->
NewCNode
(
switch_inputs
);
}
if
(
CNodeHasTupleInput
(
switch_call_cnode
))
{
return
TransformSwitchCall
(
call_inputs
[
0
],
switch_call_cnode
);
}
if
(
true_br_changed
||
false_br_changed
)
{
return
switch_call_cnode
->
func_graph
()
->
NewCNode
(
call_inputs
);
}
return
nullptr
;
}
bool
TransformBranchNode
(
AnfNodePtr
node
,
FuncGraphManagerPtr
mng
,
AnfNodePtr
*
trans_node
)
{
if
(
IsValueNode
<
FuncGraph
>
(
node
))
{
FuncGraphPtr
fg
=
GetValueNode
<
FuncGraphPtr
>
(
node
);
if
(
FuncGraphHasTupleInput
(
fg
))
{
FuncGraphPtr
transformed_fg
=
graph_transform_
(
fg
,
mng
);
*
trans_node
=
NewValueNode
(
transformed_fg
);
return
true
;
}
return
false
;
}
if
(
IsPrimitiveCNode
(
node
,
prim
::
kPrimPartial
))
{
auto
partial_inputs
=
node
->
cast
<
CNodePtr
>
()
->
inputs
();
if
(
IsValueNode
<
FuncGraph
>
(
partial_inputs
[
1
]))
{
FuncGraphPtr
fg
=
GetValueNode
<
FuncGraphPtr
>
(
partial_inputs
[
1
]);
if
(
FuncGraphHasTupleInput
(
fg
))
{
fg
=
graph_transform_
(
fg
,
mng
);
}
if
(
CNodeHasTupleInput
(
node
->
cast
<
CNodePtr
>
()))
{
*
trans_node
=
TransformPartial
(
fg
,
node
->
cast
<
CNodePtr
>
());
return
true
;
}
}
return
false
;
}
MS_LOG
(
WARNING
)
<<
"Got unexpected switch branch node "
<<
node
->
DebugString
();
return
false
;
}
private:
GraphTupleParamTransform
&
graph_transform_
;
};
// {{switch_layer, index, {make_tuple, br1, br2,...,}}, Xs} ->
// transform switch layer graph call tuple inputs to flat inputs.
class
SwitchLayerCallTupleTransform
:
public
AnfVisitor
{
public:
explicit
SwitchLayerCallTupleTransform
(
GraphTupleParamTransform
&
transformer
)
:
graph_transform_
(
transformer
)
{}
~
SwitchLayerCallTupleTransform
()
override
=
default
;
AnfNodePtr
operator
()(
const
OptimizerPtr
&
optimizer
,
const
AnfNodePtr
&
node
)
override
{
if
(
!
node
->
isa
<
CNode
>
()
||
node
->
func_graph
()
==
nullptr
)
{
return
nullptr
;
}
auto
switch_layer_call_cnode
=
node
->
cast
<
CNodePtr
>
();
auto
call_inputs
=
switch_layer_call_cnode
->
inputs
();
if
(
call_inputs
.
size
()
<
1
)
{
return
nullptr
;
}
if
(
!
IsPrimitiveCNode
(
call_inputs
[
0
],
prim
::
kPrimSwitchLayer
))
{
return
nullptr
;
}
auto
swich_layer_cnode
=
call_inputs
[
0
]
->
cast
<
CNodePtr
>
();
auto
switch_layer_inputs
=
swich_layer_cnode
->
inputs
();
if
(
switch_layer_inputs
.
size
()
!=
3
)
{
return
nullptr
;
}
AnfNodePtr
transformed
=
nullptr
;
bool
layer_changed
=
TransformLayerNode
(
switch_layer_inputs
[
2
],
optimizer
->
manager
(),
&
transformed
);
if
(
layer_changed
)
{
switch_layer_inputs
[
2
]
=
transformed
;
call_inputs
[
0
]
=
switch_layer_call_cnode
->
func_graph
()
->
NewCNode
(
switch_layer_inputs
);
}
if
(
CNodeHasTupleInput
(
switch_layer_call_cnode
))
{
return
TransformSwitchCall
(
call_inputs
[
0
],
switch_layer_call_cnode
);
}
if
(
layer_changed
)
{
return
switch_layer_call_cnode
->
func_graph
()
->
NewCNode
(
call_inputs
);
}
return
nullptr
;
}
bool
TransformLayerNode
(
AnfNodePtr
node
,
FuncGraphManagerPtr
mng
,
AnfNodePtr
*
trans_node
)
{
if
(
!
IsPrimitiveCNode
(
node
,
prim
::
kPrimMakeTuple
))
{
MS_LOG
(
WARNING
)
<<
"SwitchLayer input is not MakeTuple"
;
return
false
;
}
auto
tuple_inputs
=
node
->
cast
<
CNodePtr
>
()
->
inputs
();
bool
changed
=
false
;
for
(
size_t
i
=
1
;
i
<
tuple_inputs
.
size
();
i
++
)
{
if
(
!
IsValueNode
<
FuncGraph
>
(
tuple_inputs
[
i
]))
{
MS_LOG
(
WARNING
)
<<
"SwitchLayer input is not FuncGraph"
;
return
false
;
}
FuncGraphPtr
fg
=
GetValueNode
<
FuncGraphPtr
>
(
tuple_inputs
[
i
]);
if
(
FuncGraphHasTupleInput
(
fg
))
{
FuncGraphPtr
transformed_fg
=
graph_transform_
(
fg
,
mng
);
tuple_inputs
[
i
]
=
NewValueNode
(
transformed_fg
);
changed
=
true
;
}
}
if
(
changed
)
{
*
trans_node
=
node
->
func_graph
()
->
NewCNode
(
tuple_inputs
);
}
return
changed
;
}
private:
GraphTupleParamTransform
&
graph_transform_
;
};
class
CallGraphTupleTransform
:
public
OptimizerCaller
{
public:
CallGraphTupleTransform
()
:
graph_transformer_
(),
graph_call_transform_
(
std
::
make_shared
<
GraphCallTupleTransform
>
(
graph_transformer_
)),
switch_call_transform_
(
std
::
make_shared
<
SwitchCallTupleTransform
>
(
graph_transformer_
)),
switch_layer_call_transform_
(
std
::
make_shared
<
SwitchLayerCallTupleTransform
>
(
graph_transformer_
))
{
transformers_
.
emplace_back
(
graph_call_transform_
);
transformers_
.
emplace_back
(
switch_call_transform_
);
transformers_
.
emplace_back
(
switch_layer_call_transform_
);
}
~
CallGraphTupleTransform
()
=
default
;
AnfNodePtr
operator
()(
const
OptimizerPtr
&
optimizer
,
const
AnfNodePtr
&
node
)
override
{
AnfNodePtr
new_node
;
for
(
auto
&
transform
:
transformers_
)
{
new_node
=
(
*
transform
)(
optimizer
,
node
);
if
(
new_node
!=
nullptr
)
{
return
new_node
;
}
}
return
nullptr
;
}
private:
GraphTupleParamTransform
graph_transformer_
;
OptimizerCallerPtr
graph_call_transform_
;
OptimizerCallerPtr
switch_call_transform_
;
OptimizerCallerPtr
switch_layer_call_transform_
;
std
::
vector
<
OptimizerCallerPtr
>
transformers_
{};
};
}
// namespace irpass
}
// namespace opt
}
// namespace mindspore
#endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_CALL_GRAPH_TRANSFORM_H_
mindspore/ccsrc/pipeline/jit/action.cc
浏览文件 @
0099da2c
...
...
@@ -277,6 +277,7 @@ bool OptimizeAction(const ResourcePtr &res, const std::vector<PassItem> &passes)
MS_EXCEPTION_IF_NULL
(
func_graph
);
func_graph
->
DumpFuncGraph
(
fg_name
);
DumpIR
(
fg_name
+
".ir"
,
func_graph
);
ExportIR
(
fg_name
+
".dat"
,
""
,
func_graph
);
MS_LOG
(
DEBUG
)
<<
"Dump "
<<
fg_name
<<
" func graph."
;
}
counter
++
;
...
...
mindspore/ccsrc/pipeline/jit/pass.cc
浏览文件 @
0099da2c
...
...
@@ -33,6 +33,7 @@
#include "frontend/optimizer/clean.h"
#include "frontend/optimizer/irpass.h"
#include "frontend/optimizer/control_depend.h"
#include "frontend/optimizer/graph_transform.h"
#include "frontend/parallel/step_parallel.h"
#include "frontend/parallel/step_auto_parallel.h"
#include "frontend/parallel/allreduce_fusion/step_allreduce_fusion.h"
...
...
@@ -166,12 +167,23 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
OptPassGroupMap
GetOptPassesAfterCconv
(
const
opt
::
irpass
::
OptimizeIRPassLib
&
irpass
)
{
opt
::
OptPassConfig
c_1
=
opt
::
OptPassConfig
({
// Safe inlining
// Safe inlining
,
irpass
.
inline_
,
irpass
.
partial_eliminate_
,
});
OptPassGroupMap
map_a
({{
"c_1"
,
c_1
},
{
"renormalize"
,
opt
::
OptPassConfig
::
Renormalize
()}});
OptPassGroupMap
map_a
({{
"c_1"
,
c_1
},
{
"cse"
,
opt
::
OptPassConfig
(
opt
::
CSEPass
(
false
))},
{
"renormalize"
,
opt
::
OptPassConfig
::
Renormalize
()}});
return
map_a
;
}
OptPassGroupMap
GetOptPassesTransformGraph
(
const
opt
::
irpass
::
OptimizeIRPassLib
&
irpass
)
{
opt
::
OptPassConfig
d_1
=
opt
::
OptPassConfig
({
// Safe inlining
irpass
.
call_graph_tuple_transform_
,
irpass
.
item_tuple_eliminate_
});
OptPassGroupMap
map_a
({{
"d_1"
,
d_1
},
{
"renormalize"
,
opt
::
OptPassConfig
::
Renormalize
()}});
return
map_a
;
}
...
...
@@ -262,6 +274,8 @@ void InitOpt(const ResourcePtr &res) {
g_pass_opts
[
"opt_b"
]
=
Optimizer
::
MakeOptimizer
(
"opt_b"
,
res
,
GetOptPassesB
(
irpass
),
false
,
true
);
g_pass_opts
[
"opt_after_cconv"
]
=
Optimizer
::
MakeOptimizer
(
"opt_after_cconv"
,
res
,
GetOptPassesAfterCconv
(
irpass
),
false
,
true
);
g_pass_opts
[
"opt_trans_graph"
]
=
Optimizer
::
MakeOptimizer
(
"opt_trans_graph"
,
res
,
GetOptPassesTransformGraph
(
irpass
),
true
,
true
);
g_pass_opts
[
"opt_graph_kernel_a"
]
=
Optimizer
::
MakeOptimizer
(
"opt_graph_kernel_a"
,
res
,
GetOptPassesGraphKernelA
(
irpass
),
true
);
g_pass_opts
[
"opt_graph_kernel_b"
]
=
...
...
@@ -307,6 +321,7 @@ bool OptPassGroup(const ResourcePtr &res, const std::string &name) {
bool
OptPassAGroup
(
const
ResourcePtr
&
res
)
{
return
OptPassGroup
(
res
,
"opt_a"
);
}
bool
OptPassBGroup
(
const
ResourcePtr
&
res
)
{
return
OptPassGroup
(
res
,
"opt_b"
);
}
bool
OptPassAfterCconvGroup
(
const
ResourcePtr
&
res
)
{
return
OptPassGroup
(
res
,
"opt_after_cconv"
);
}
bool
OptPassTransformGraphGroup
(
const
ResourcePtr
&
res
)
{
return
OptPassGroup
(
res
,
"opt_trans_graph"
);
}
bool
OptPassGraphKernelGroupA
(
const
ResourcePtr
&
res
)
{
return
OptPassGroup
(
res
,
"opt_graph_kernel_a"
);
}
bool
OptPassGraphKernelGroupB
(
const
ResourcePtr
&
res
)
{
return
OptPassGroup
(
res
,
"opt_graph_kernel_b"
);
}
bool
ControlGroup
(
const
ResourcePtr
&
res
)
{
return
OptPassGroup
(
res
,
"opt_control"
);
}
...
...
@@ -365,6 +380,24 @@ bool CconvPass(const ResourcePtr &res) {
return
true
;
}
bool
TransformTopGraphPass
(
const
ResourcePtr
&
res
)
{
if
(
res
->
func_graph
()
==
nullptr
)
{
MS_LOG
(
EXCEPTION
)
<<
"Transform top graph error."
;
}
FuncGraphPtr
func_graph
=
res
->
func_graph
();
if
(
opt
::
FuncGraphHasTupleInput
(
func_graph
))
{
opt
::
GraphTupleParamTransform
graph_trans
;
func_graph
=
graph_trans
(
func_graph
,
res
->
manager
());
res
->
set_func_graph
(
func_graph
);
AbstractBasePtrList
abs_spec_list
;
auto
&
params
=
func_graph
->
parameters
();
std
::
transform
(
params
.
begin
(),
params
.
end
(),
std
::
back_inserter
(
abs_spec_list
),
[](
AnfNodePtr
node
)
{
return
node
->
abstract
();
});
res
->
set_args_spec
(
abs_spec_list
);
}
return
true
;
}
bool
ValidatePass
(
const
ResourcePtr
&
res
)
{
MS_EXCEPTION_IF_NULL
(
res
->
func_graph
());
FuncGraphPtr
func_graph
=
res
->
func_graph
();
...
...
@@ -388,6 +421,7 @@ std::vector<PassItem> kVmPasses = {{"simplify_data_structures", SimplifyDataStru
{
"cconv"
,
CconvPass
},
{
"opt_after_cconv"
,
OptPassAfterCconvGroup
},
{
"remove_dup_value"
,
RemoveValueNodeDuplicationsPass
},
{
"tuple_transform"
,
OptPassTransformGraphGroup
},
{
"opt_graph_kernel_a"
,
OptPassGraphKernelGroupA
},
{
"opt_graph_kernel_b"
,
OptPassGraphKernelGroupB
},
{
"add_control_depend"
,
AddControlDependPass
}};
...
...
@@ -401,6 +435,10 @@ std::vector<PassItem> kGePasses = {{"simplify_data_structures", SimplifyDataStru
{
"opt_prepare"
,
PrepareGroup
},
{
"cconv"
,
CconvPass
}};
std
::
vector
<
PassItem
>
kPynativePasses
=
{{
"opt_a"
,
OptPassAGroup
},
{
"opt_b"
,
OptPassBGroup
},
{
"cconv"
,
CconvPass
}};
std
::
vector
<
PassItem
>
kPynativePasses
=
{{
"opt_a"
,
OptPassAGroup
},
{
"opt_b"
,
OptPassBGroup
},
{
"cconv"
,
CconvPass
},
{
"transform_top"
,
TransformTopGraphPass
},
{
"transform_graph"
,
OptPassTransformGraphGroup
}};
}
// namespace pipeline
}
// namespace mindspore
mindspore/ccsrc/pipeline/pynative/pynative_execute.cc
浏览文件 @
0099da2c
...
...
@@ -1351,9 +1351,46 @@ void PynativeExecutor::ClearRes() {
resource_
.
reset
();
}
size_t
GetTupleSize
(
const
py
::
tuple
&
args
)
{
size_t
count
=
0
;
for
(
size_t
i
=
0
;
i
<
args
.
size
();
i
++
)
{
if
(
py
::
isinstance
<
py
::
tuple
>
(
args
[
i
]))
{
count
+=
GetTupleSize
(
args
[
i
]);
}
else
{
count
+=
1
;
}
}
return
count
;
}
void
ConvertTupleArg
(
py
::
tuple
*
res
,
size_t
*
index
,
const
py
::
tuple
&
arg
)
{
for
(
size_t
i
=
0
;
i
<
arg
.
size
();
i
++
)
{
if
(
py
::
isinstance
<
py
::
tuple
>
(
arg
[
i
]))
{
ConvertTupleArg
(
res
,
index
,
arg
[
i
]);
}
else
{
(
*
res
)[(
*
index
)
++
]
=
arg
[
i
];
}
}
}
py
::
tuple
ConvertArgs
(
const
py
::
tuple
&
args
)
{
size_t
tuple_size
=
GetTupleSize
(
args
);
py
::
tuple
res
(
tuple_size
);
size_t
index
=
0
;
for
(
size_t
i
=
0
;
i
<
args
.
size
();
i
++
)
{
if
(
py
::
isinstance
<
py
::
tuple
>
(
args
[
i
]))
{
ConvertTupleArg
(
&
res
,
&
index
,
args
[
i
]);
}
else
{
res
[
index
++
]
=
args
[
i
];
}
}
return
res
;
}
py
::
object
PynativeExecutor
::
Run
(
const
py
::
tuple
&
args
,
const
py
::
object
&
phase
)
{
VectorRef
arg_list
;
pipeline
::
ProcessVmArgInner
(
args
,
resource_
,
&
arg_list
);
py
::
tuple
converted_args
=
ConvertArgs
(
args
);
pipeline
::
ProcessVmArgInner
(
converted_args
,
resource_
,
&
arg_list
);
if
(
resource_
->
results
().
find
(
pipeline
::
kOutput
)
==
resource_
->
results
().
end
()
||
!
resource_
->
results
()[
pipeline
::
kOutput
].
is
<
compile
::
VmEvalFuncPtr
>
())
{
MS_LOG
(
EXCEPTION
)
<<
"Can't find run graph func for "
;
...
...
tests/st/pynative/test_graph_param_transform.py
0 → 100644
浏览文件 @
0099da2c
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import
pytest
import
numpy
as
np
from
mindspore
import
RowTensor
from
mindspore
import
context
,
nn
,
Tensor
,
ParameterTuple
from
mindspore.common
import
dtype
as
mstype
from
mindspore.common
import
ms_function
from
mindspore.ops
import
operations
as
P
from
mindspore.ops
import
composite
as
C
def
setup_module
():
context
.
set_context
(
mode
=
context
.
PYNATIVE_MODE
,
enable_sparse
=
False
)
class
_Grad
(
nn
.
Cell
):
def
__init__
(
self
,
grad
,
network
,
wrt_params
=
False
,
real_inputs_count
=
None
):
super
().
__init__
()
self
.
network
=
network
self
.
grad
=
grad
self
.
sens_param
=
self
.
grad
.
sens_param
self
.
wrt_params
=
wrt_params
self
.
real_inputs_count
=
real_inputs_count
if
self
.
wrt_params
:
self
.
params
=
ParameterTuple
(
self
.
network
.
trainable_params
())
def
construct
(
self
,
*
inputs
):
if
self
.
wrt_params
:
if
self
.
real_inputs_count
is
None
or
self
.
sens_param
is
False
:
return
self
.
grad
(
self
.
network
,
self
.
params
)(
*
inputs
)
real_inputs
=
inputs
[:
self
.
real_inputs_count
]
sense_param_inputs
=
inputs
[
self
.
real_inputs_count
:]
return
self
.
grad
(
self
.
network
,
self
.
params
)(
*
real_inputs
,
sense_param_inputs
)
if
self
.
real_inputs_count
is
None
or
self
.
sens_param
is
False
:
return
self
.
grad
(
self
.
network
)(
*
inputs
)
real_inputs
=
inputs
[:
self
.
real_inputs_count
]
sense_param_inputs
=
inputs
[
self
.
real_inputs_count
:]
return
self
.
grad
(
self
.
network
)(
*
real_inputs
,
sense_param_inputs
)
class
GradOfFirstInput
(
_Grad
):
"""
get grad of first input
"""
def
__init__
(
self
,
network
,
sens_param
=
True
,
real_inputs_count
=
None
):
super
().
__init__
(
grad
=
C
.
GradOperation
(
sens_param
=
sens_param
),
network
=
network
,
real_inputs_count
=
real_inputs_count
)
class
GradOfAllInputs
(
_Grad
):
"""
get grad of first input
"""
def
__init__
(
self
,
network
,
sens_param
=
True
,
real_inputs_count
=
None
):
super
().
__init__
(
grad
=
C
.
GradOperation
(
get_all
=
True
,
sens_param
=
sens_param
),
network
=
network
,
real_inputs_count
=
real_inputs_count
)
@
pytest
.
mark
.
level0
@
pytest
.
mark
.
platform_arm_ascend_training
@
pytest
.
mark
.
platform_x86_ascend_training
@
pytest
.
mark
.
env_onecard
def
test_row_tensor_in_while
():
class
RowTensorValuesDouble
(
nn
.
Cell
):
def
construct
(
self
,
x
):
indices
=
x
.
indices
values
=
x
.
values
*
2
dense_shape
=
x
.
dense_shape
return
RowTensor
(
indices
,
values
,
dense_shape
)
class
RowTensorValuesAdd2
(
nn
.
Cell
):
def
construct
(
self
,
x
):
indices
=
x
.
indices
values
=
x
.
values
+
2
dense_shape
=
x
.
dense_shape
return
RowTensor
(
indices
,
values
,
dense_shape
)
class
RowTensorWithControlWhile
(
nn
.
Cell
):
def
__init__
(
self
,
dense_shape
):
super
().
__init__
()
self
.
op1
=
RowTensorValuesDouble
()
self
.
op2
=
RowTensorValuesAdd2
()
self
.
dense_shape
=
dense_shape
@
ms_function
def
construct
(
self
,
a
,
b
,
indices
,
values
):
x
=
RowTensor
(
indices
,
values
,
self
.
dense_shape
)
x
=
self
.
op2
(
x
)
while
a
>
b
:
x
=
self
.
op1
(
x
)
b
=
b
+
1
return
x
.
indices
,
x
.
values
,
x
.
dense_shape
a
=
Tensor
(
np
.
array
(
3
).
astype
(
np
.
int32
))
b
=
Tensor
(
np
.
array
(
0
).
astype
(
np
.
int32
))
indices
=
Tensor
(
np
.
array
([
0
,
2
]).
astype
(
np
.
int32
))
values
=
Tensor
(
np
.
ones
([
2
,
2
]).
astype
(
np
.
float32
))
dense_shape
=
(
5
,
2
)
net
=
RowTensorWithControlWhile
(
dense_shape
)
out
=
net
(
a
,
b
,
indices
,
values
)
assert
np
.
allclose
(
indices
.
asnumpy
(),
out
[
0
].
asnumpy
(),
.
0
,
.
0
)
assert
np
.
allclose
(
values
.
asnumpy
()
*
24
,
out
[
1
].
asnumpy
(),
.
0
,
.
0
)
assert
dense_shape
==
out
[
2
]
@
pytest
.
mark
.
level0
@
pytest
.
mark
.
platform_arm_ascend_training
@
pytest
.
mark
.
platform_x86_ascend_training
@
pytest
.
mark
.
env_onecard
def
test_parser_switch_layer_inputs_tuple
():
class
Add
(
nn
.
Cell
):
def
__init__
(
self
):
super
().
__init__
()
self
.
op
=
P
.
TensorAdd
()
def
construct
(
self
,
x
):
y
=
self
.
op
(
x
[
0
],
x
[
1
])
return
self
.
op
(
x
[
0
],
y
)
class
Mul
(
nn
.
Cell
):
def
__init__
(
self
):
super
().
__init__
()
self
.
op
=
P
.
Mul
()
def
construct
(
self
,
x
):
y
=
self
.
op
(
x
[
0
],
x
[
1
])
return
self
.
op
(
x
[
0
],
y
)
class
MulTwoInput
(
nn
.
Cell
):
def
__init__
(
self
):
super
().
__init__
()
self
.
op
=
P
.
Mul
()
@
ms_function
def
construct
(
self
,
x
,
y
):
y
=
self
.
op
(
x
,
y
)
return
self
.
op
(
x
,
y
)
class
TwoInputTupleFinalNet
(
nn
.
Cell
):
def
__init__
(
self
,
funcs
):
super
().
__init__
()
self
.
funcs
=
funcs
@
ms_function
def
construct
(
self
,
i
,
inputa
,
inputb
):
inputs
=
(
inputa
,
inputb
)
x
=
self
.
funcs
[
i
](
inputs
)
return
x
func1
=
Add
()
func2
=
Mul
()
funcs
=
(
func1
,
func2
)
net
=
TwoInputTupleFinalNet
(
funcs
)
input_data
=
Tensor
(
np
.
random
.
randn
(
2
,
3
,
4
,
5
).
astype
(
np
.
float32
))
input2
=
Tensor
(
np
.
random
.
randn
(
2
,
3
,
4
,
5
).
astype
(
np
.
float32
))
i
=
Tensor
(
1
,
mstype
.
int32
)
netout
=
net
(
i
,
input_data
,
input2
)
net_good
=
MulTwoInput
()
goodout
=
net_good
(
input_data
,
input2
)
assert
np
.
allclose
(
goodout
.
asnumpy
(),
netout
.
asnumpy
(),
0
,
0
)
@
pytest
.
mark
.
level0
@
pytest
.
mark
.
platform_arm_ascend_training
@
pytest
.
mark
.
platform_x86_ascend_training
@
pytest
.
mark
.
env_onecard
def
test_imagenet
():
class
ImageGradients
(
nn
.
Cell
):
def
__init__
(
self
):
super
().
__init__
()
self
.
imagegradients
=
nn
.
ImageGradients
()
def
construct
(
self
,
inputs
):
return
self
.
imagegradients
(
inputs
)
net
=
ImageGradients
()
net_me
=
GradOfFirstInput
(
net
,
real_inputs_count
=
1
)
net_me
.
set_train
()
input_data
=
Tensor
(
np
.
ones
([
32
,
16
,
8
,
8
]),
dtype
=
mstype
.
float32
)
output_grad
=
(
Tensor
(
np
.
ones
([
32
,
16
,
8
,
8
]),
dtype
=
mstype
.
float32
),
Tensor
(
np
.
ones
([
32
,
16
,
8
,
8
]),
dtype
=
mstype
.
float32
))
net_me
(
input_data
,
*
output_grad
)
tests/ut/python/pynative_mode/test_graph_param_cases.py
0 → 100644
浏览文件 @
0099da2c
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import
numpy
as
np
from
mindspore
import
RowTensor
from
mindspore
import
context
,
nn
,
Tensor
,
ParameterTuple
from
mindspore.common
import
dtype
as
mstype
from
mindspore.common
import
ms_function
from
mindspore.ops
import
composite
as
C
def
setup_module
():
context
.
set_context
(
mode
=
context
.
PYNATIVE_MODE
,
enable_sparse
=
False
)
class
_Grad
(
nn
.
Cell
):
def
__init__
(
self
,
grad
,
network
,
wrt_params
=
False
,
real_inputs_count
=
None
):
super
().
__init__
()
self
.
network
=
network
self
.
grad
=
grad
self
.
sens_param
=
self
.
grad
.
sens_param
self
.
wrt_params
=
wrt_params
self
.
real_inputs_count
=
real_inputs_count
if
self
.
wrt_params
:
self
.
params
=
ParameterTuple
(
self
.
network
.
trainable_params
())
def
construct
(
self
,
*
inputs
):
if
self
.
wrt_params
:
if
self
.
real_inputs_count
is
None
or
self
.
sens_param
is
False
:
return
self
.
grad
(
self
.
network
,
self
.
params
)(
*
inputs
)
real_inputs
=
inputs
[:
self
.
real_inputs_count
]
sense_param_inputs
=
inputs
[
self
.
real_inputs_count
:]
return
self
.
grad
(
self
.
network
,
self
.
params
)(
*
real_inputs
,
sense_param_inputs
)
if
self
.
real_inputs_count
is
None
or
self
.
sens_param
is
False
:
return
self
.
grad
(
self
.
network
)(
*
inputs
)
real_inputs
=
inputs
[:
self
.
real_inputs_count
]
sense_param_inputs
=
inputs
[
self
.
real_inputs_count
:]
return
self
.
grad
(
self
.
network
)(
*
real_inputs
,
sense_param_inputs
)
class
GradOfFirstInput
(
_Grad
):
"""
get grad of first input
"""
def
__init__
(
self
,
network
,
sens_param
=
True
,
real_inputs_count
=
None
):
super
().
__init__
(
grad
=
C
.
GradOperation
(
sens_param
=
sens_param
),
network
=
network
,
real_inputs_count
=
real_inputs_count
)
class
GradOfAllInputs
(
_Grad
):
"""
get grad of first input
"""
def
__init__
(
self
,
network
,
sens_param
=
True
,
real_inputs_count
=
None
):
super
().
__init__
(
grad
=
C
.
GradOperation
(
get_all
=
True
,
sens_param
=
sens_param
),
network
=
network
,
real_inputs_count
=
real_inputs_count
)
def
test_row_tensor_in_while
():
class
RowTensorValuesDouble
(
nn
.
Cell
):
def
__init__
(
self
):
super
().
__init__
()
def
construct
(
self
,
x
):
indices
=
x
.
indices
values
=
x
.
values
*
2
dense_shape
=
x
.
dense_shape
return
RowTensor
(
indices
,
values
,
dense_shape
)
class
RowTensorValuesAdd2
(
nn
.
Cell
):
def
__init__
(
self
):
super
().
__init__
()
def
construct
(
self
,
x
):
indices
=
x
.
indices
values
=
x
.
values
+
2
dense_shape
=
x
.
dense_shape
return
RowTensor
(
indices
,
values
,
dense_shape
)
class
RowTensorWithControlWhile
(
nn
.
Cell
):
def
__init__
(
self
,
dense_shape
):
super
().
__init__
()
self
.
op1
=
RowTensorValuesDouble
()
self
.
op2
=
RowTensorValuesAdd2
()
self
.
dense_shape
=
dense_shape
@
ms_function
def
construct
(
self
,
a
,
b
,
indices
,
values
):
x
=
RowTensor
(
indices
,
values
,
self
.
dense_shape
)
x
=
self
.
op2
(
x
)
while
(
a
>
b
):
x
=
self
.
op1
(
x
)
b
=
b
+
1
return
x
.
indices
,
x
.
values
,
x
.
dense_shape
a
=
Tensor
(
np
.
array
(
3
).
astype
(
np
.
int32
))
b
=
Tensor
(
np
.
array
(
0
).
astype
(
np
.
int32
))
indices
=
Tensor
(
np
.
array
([
0
,
2
]).
astype
(
np
.
int32
))
values
=
Tensor
(
np
.
ones
([
2
,
2
]).
astype
(
np
.
float32
))
dense_shape
=
(
5
,
2
)
net
=
RowTensorWithControlWhile
(
dense_shape
)
net
(
a
,
b
,
indices
,
values
)
def
test_multi_out_sens
():
class
ImageGradients
(
nn
.
Cell
):
def
__init__
(
self
):
super
().
__init__
()
def
construct
(
self
,
x
,
y
,
z
):
resa
=
x
*
y
resb
=
y
*
z
resc
=
x
*
z
return
resa
,
(
resb
,
resc
)
net
=
ImageGradients
()
net_me
=
GradOfAllInputs
(
net
,
real_inputs_count
=
3
)
net_me
.
set_train
()
input_data
=
Tensor
(
np
.
ones
([
32
]),
dtype
=
mstype
.
float32
)
output_grad
=
(
Tensor
(
np
.
ones
([
32
]),
dtype
=
mstype
.
float32
),
(
Tensor
(
np
.
ones
([
32
]),
dtype
=
mstype
.
float32
),
Tensor
(
np
.
ones
([
32
]),
dtype
=
mstype
.
float32
)))
net_me
(
input_data
,
input_data
,
input_data
,
*
output_grad
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录