Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
24f6b9d7
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
24f6b9d7
编写于
6月 08, 2020
作者:
Y
yujianfeng
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add input2output pass
上级
f5dc6fbe
变更
9
显示空白变更内容
内联
并排
Showing
9 changed file
with
479 addition
and
0 deletion
+479
-0
mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc
.../ccsrc/pre_activate/ascend/ascend_backend_optimization.cc
+2
-0
mindspore/ccsrc/pre_activate/ascend/ascend_helper.h
mindspore/ccsrc/pre_activate/ascend/ascend_helper.h
+15
-0
mindspore/ccsrc/pre_activate/ascend/ir_fusion/add_input_to_output.cc
...csrc/pre_activate/ascend/ir_fusion/add_input_to_output.cc
+115
-0
mindspore/ccsrc/pre_activate/ascend/ir_fusion/add_input_to_output.h
...ccsrc/pre_activate/ascend/ir_fusion/add_input_to_output.h
+39
-0
mindspore/ccsrc/pre_activate/ascend/ir_fusion/input_to_output_registry.cc
...pre_activate/ascend/ir_fusion/input_to_output_registry.cc
+122
-0
mindspore/ccsrc/pre_activate/ascend/ir_fusion/input_to_output_registry.h
.../pre_activate/ascend/ir_fusion/input_to_output_registry.h
+64
-0
mindspore/ccsrc/utils/utils.h
mindspore/ccsrc/utils/utils.h
+9
-0
tests/ut/cpp/pre_activate/ascend/ir_fusion/add_input_to_output_test.cc
...pre_activate/ascend/ir_fusion/add_input_to_output_test.cc
+74
-0
tests/ut/cpp/python_input/gtest_input/pre_activate/add_input_to_output_test.py
...nput/gtest_input/pre_activate/add_input_to_output_test.py
+39
-0
未找到文件。
mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc
浏览文件 @
24f6b9d7
...
...
@@ -94,6 +94,7 @@
#include "pre_activate/ascend/ir_fission/split_fission.h"
#include "pre_activate/ascend/format_type/modify_ops_attrs.h"
#include "pre_activate/ascend/format_type/remove_no_use_reshape_op.h"
#include "pre_activate/ascend/ir_fusion/add_input_to_output.h"
#include "utils/context/ms_context.h"
#include "utils/config_manager.h"
#include "debug/anf_ir_dump.h"
...
...
@@ -259,6 +260,7 @@ void AscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGrap
ir_fusion_pm
->
AddPass
(
std
::
make_shared
<
EraseVisitAttr
>
());
}
ir_fusion_pm
->
AddPass
(
std
::
make_shared
<
InsertMemcpyAsyncForHcclOp
>
());
ir_fusion_pm
->
AddPass
(
std
::
make_shared
<
AddInputToOutput
>
());
optimizer
->
AddPassManager
(
ir_fusion_pm
);
(
void
)
optimizer
->
Optimize
(
kernel_graph
);
kernel_graph
->
SetExecOrderByDefault
();
...
...
mindspore/ccsrc/pre_activate/ascend/ascend_helper.h
浏览文件 @
24f6b9d7
...
...
@@ -70,6 +70,21 @@ class KernelQuery {
}
};
using
KernelQueryPtr
=
std
::
shared_ptr
<
KernelQuery
>
;
class
OpFinder
{
public:
OpFinder
()
=
default
;
virtual
~
OpFinder
()
=
default
;
virtual
int
GetOpRegisteredOutputNum
(
const
std
::
string
&
op_name
)
{
auto
op_info
=
kernel
::
OpLib
::
FindOp
(
op_name
,
kernel
::
kTBE
);
if
(
op_info
==
nullptr
)
{
return
-
1
;
}
return
op_info
->
outputs_ptr
().
size
();
}
};
using
OpFinderPtr
=
std
::
shared_ptr
<
OpFinder
>
;
void
RefreshKernelBuildInfo
(
const
std
::
string
&
input_format
,
const
std
::
string
&
output_format
,
const
AnfNodePtr
&
trans_data
,
const
std
::
vector
<
kernel
::
Axis
>
&
reshape_type
=
{});
...
...
mindspore/ccsrc/pre_activate/ascend/ir_fusion/add_input_to_output.cc
0 → 100644
浏览文件 @
24f6b9d7
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "pre_activate/ascend/ir_fusion/add_input_to_output.h"
#include <vector>
#include <algorithm>
#include "pre_activate/ascend/ir_fusion/input_to_output_registry.h"
#include "session/anf_runtime_algorithm.h"
#include "kernel/oplib/oplib.h"
namespace
mindspore
{
namespace
opt
{
namespace
{
void
GetInputOrOutputNames
(
const
CNodePtr
&
cnode
,
const
std
::
string
&
attr_name
,
std
::
vector
<
std
::
string
>
*
names_vec
)
{
MS_EXCEPTION_IF_NULL
(
names_vec
);
auto
primitive
=
AnfAlgo
::
GetCNodePrimitive
(
cnode
);
MS_EXCEPTION_IF_NULL
(
primitive
);
ValuePtr
names_value
=
primitive
->
GetAttr
(
attr_name
);
if
(
names_value
==
nullptr
)
{
return
;
}
*
names_vec
=
GetValue
<
std
::
vector
<
std
::
string
>>
(
names_value
);
}
void
AddOutputs
(
const
CNodePtr
&
cnode
,
const
std
::
vector
<
size_t
>
&
input_indices
)
{
MS_EXCEPTION_IF_NULL
(
cnode
);
std
::
vector
<
std
::
string
>
input_names_vec
;
GetInputOrOutputNames
(
cnode
,
kAttrInputNames
,
&
input_names_vec
);
std
::
vector
<
std
::
string
>
output_names_vec
;
GetInputOrOutputNames
(
cnode
,
kAttrOutputNames
,
&
output_names_vec
);
AbstractBasePtrList
abstract_list
;
auto
origin_abstract
=
cnode
->
abstract
();
MS_EXCEPTION_IF_NULL
(
origin_abstract
);
if
(
origin_abstract
->
isa
<
abstract
::
AbstractTuple
>
())
{
auto
origin_abstract_tuple
=
dyn_cast
<
abstract
::
AbstractTuple
>
(
origin_abstract
);
MS_EXCEPTION_IF_NULL
(
origin_abstract_tuple
);
AbstractBasePtrList
origin_abstract_list
=
origin_abstract_tuple
->
elements
();
(
void
)
std
::
copy
(
origin_abstract_list
.
begin
(),
origin_abstract_list
.
end
(),
std
::
back_inserter
(
abstract_list
));
}
else
{
abstract_list
.
emplace_back
(
origin_abstract
);
}
for
(
size_t
i
=
0
;
i
<
input_indices
.
size
();
++
i
)
{
size_t
index
=
input_indices
[
i
];
if
(
index
+
1
>=
cnode
->
inputs
().
size
())
{
MS_LOG
(
INFO
)
<<
"The input index "
<<
index
<<
" for converting to output is out of range, "
<<
"node: "
<<
cnode
->
DebugString
();
continue
;
}
auto
node_to_output
=
cnode
->
input
(
index
+
1
);
MS_EXCEPTION_IF_NULL
(
node_to_output
);
abstract_list
.
emplace_back
(
node_to_output
->
abstract
());
if
(
!
input_names_vec
.
empty
()
&&
!
output_names_vec
.
empty
()
&&
index
<
input_names_vec
.
size
())
{
output_names_vec
.
emplace_back
(
input_names_vec
[
index
]);
}
}
if
(
!
output_names_vec
.
empty
())
{
AnfAlgo
::
SetNodeAttr
(
kAttrOutputNames
,
MakeValue
(
output_names_vec
),
cnode
);
}
auto
abstract_tuple
=
std
::
make_shared
<
abstract
::
AbstractTuple
>
(
abstract_list
);
cnode
->
set_abstract
(
abstract_tuple
);
}
}
// namespace
const
AnfNodePtr
AddInputToOutput
::
Process
(
const
FuncGraphPtr
&
func_graph
,
const
AnfNodePtr
&
node
,
const
EquivPtr
&
)
const
{
if
(
node
==
nullptr
||
!
AnfAlgo
::
IsRealCNodeKernel
(
node
))
{
return
nullptr
;
}
auto
cnode
=
node
->
cast
<
CNodePtr
>
();
MS_EXCEPTION_IF_NULL
(
cnode
);
std
::
string
op_name
=
AnfAlgo
::
GetCNodeName
(
cnode
);
InputToOutputRegister
reg
;
if
(
!
InputToOutputRegistry
::
Instance
().
GetRegisterByOpName
(
op_name
,
&
reg
))
{
return
nullptr
;
}
int
output_num
=
op_finder_
->
GetOpRegisteredOutputNum
(
op_name
);
// No need add output when it is not a tbe op.
if
(
output_num
==
-
1
)
{
return
nullptr
;
}
// No need add output if the output num matches the registered output num for tbe.
if
(
AnfAlgo
::
GetOutputTensorNum
(
cnode
)
>=
IntToSize
(
output_num
))
{
return
nullptr
;
}
bool
is_origin_tuple_output
=
AnfAlgo
::
IsTupleOutput
(
cnode
);
AddOutputs
(
cnode
,
reg
.
input_indices
());
// No need to create tuple_getitem if the origin output is a tuple because there has already been some tuple_getitems
// pointed to the outputs.
if
(
is_origin_tuple_output
)
{
return
nullptr
;
}
std
::
vector
<
AnfNodePtr
>
new_outputs
;
auto
new_abstract_tuple
=
dyn_cast
<
abstract
::
AbstractTuple
>
(
cnode
->
abstract
());
MS_EXCEPTION_IF_NULL
(
new_abstract_tuple
);
CreateMultipleOutputsOfAnfNode
(
func_graph
,
cnode
,
new_abstract_tuple
->
size
(),
&
new_outputs
);
if
(
new_outputs
.
size
()
!=
new_abstract_tuple
->
size
())
{
MS_LOG
(
EXCEPTION
)
<<
"Failed to create outputs of "
<<
cnode
->
DebugString
();
}
return
new_outputs
[
0
];
}
}
// namespace opt
}
// namespace mindspore
mindspore/ccsrc/pre_activate/ascend/ir_fusion/add_input_to_output.h
0 → 100644
浏览文件 @
24f6b9d7
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_ADD_INPUT_TO_OUTPUT_H_
#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_ADD_INPUT_TO_OUTPUT_H_
#include <string>
#include <memory>
#include "pre_activate/common/optimizer.h"
#include "pre_activate/ascend/ascend_helper.h"
namespace
mindspore
{
namespace
opt
{
class
AddInputToOutput
:
public
PatternProcessPass
{
public:
explicit
AddInputToOutput
(
bool
multigraph
=
true
)
:
PatternProcessPass
(
"add_input_to_output"
,
multigraph
),
op_finder_
(
std
::
make_shared
<
OpFinder
>
())
{}
~
AddInputToOutput
()
override
=
default
;
const
AnfNodePtr
Process
(
const
FuncGraphPtr
&
,
const
AnfNodePtr
&
,
const
EquivPtr
&
)
const
override
;
private:
OpFinderPtr
op_finder_
;
};
}
// namespace opt
}
// namespace mindspore
#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_ADD_INPUT_TO_OUTPUT_H_
mindspore/ccsrc/pre_activate/ascend/ir_fusion/input_to_output_registry.cc
0 → 100644
浏览文件 @
24f6b9d7
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "pre_activate/ascend/ir_fusion/input_to_output_registry.h"
#include <utility>
#include "utils/utils.h"
#include "session/anf_runtime_algorithm.h"
namespace
mindspore
{
namespace
opt
{
namespace
{
bool
ApplyRMSPropPreCheck
(
const
CNodePtr
&
node
)
{
return
!
(
AnfAlgo
::
GetPrevNodeOutputInferDataType
(
node
,
0
)
!=
kNumberTypeFloat32
);
}
bool
FusedMulApplyMomentumPreCheck
(
const
CNodePtr
&
node
)
{
TypeId
data_type
=
AnfAlgo
::
GetPrevNodeOutputInferDataType
(
node
,
0
);
return
!
(
data_type
!=
kNumberTypeFloat32
&&
data_type
!=
kNumberTypeFloat16
);
}
bool
SparseApplyRMSPropPreCheck
(
const
CNodePtr
&
node
)
{
return
!
(
AnfAlgo
::
GetPrevNodeOutputInferDataType
(
node
,
0
)
!=
kNumberTypeFloat32
);
}
bool
ApplyAdagradV2PreCheck
(
const
CNodePtr
&
node
)
{
TypeId
data_type
=
AnfAlgo
::
GetPrevNodeOutputInferDataType
(
node
,
0
);
return
!
(
data_type
!=
kNumberTypeFloat32
&&
data_type
!=
kNumberTypeFloat16
);
}
bool
ApplyKerasMomentumPreCheck
(
const
CNodePtr
&
node
)
{
TypeId
data_type
=
AnfAlgo
::
GetPrevNodeOutputInferDataType
(
node
,
0
);
return
!
(
data_type
!=
kNumberTypeFloat32
&&
data_type
!=
kNumberTypeFloat16
);
}
bool
SparseApplyFtrlPreCheck
(
const
CNodePtr
&
node
)
{
return
!
(
AnfAlgo
::
GetPrevNodeOutputInferDataType
(
node
,
0
)
!=
kNumberTypeFloat32
);
}
bool
SparseApplyFtrlV2PreCheck
(
const
CNodePtr
&
node
)
{
return
!
(
AnfAlgo
::
GetPrevNodeOutputInferDataType
(
node
,
0
)
!=
kNumberTypeFloat32
);
}
bool
SparseApplyAdagradV2PreCheck
(
const
CNodePtr
&
node
)
{
return
!
(
AnfAlgo
::
GetPrevNodeOutputInferDataType
(
node
,
0
)
!=
kNumberTypeFloat32
);
}
bool
SparseApplyAdadeltaPreCheck
(
const
CNodePtr
&
node
)
{
return
!
(
AnfAlgo
::
GetPrevNodeOutputInferDataType
(
node
,
0
)
!=
kNumberTypeFloat32
);
}
}
// namespace
InputToOutputRegistry
::
InputToOutputRegistry
()
{
Register
(
kApplyRMSPropOpName
,
{
1
,
2
},
ApplyRMSPropPreCheck
);
Register
(
kFusedMulApplyMomentumOpName
,
{
1
},
FusedMulApplyMomentumPreCheck
);
Register
(
kApplyAdagradOpName
,
{
1
});
Register
(
kApplyAdagradDAName
,
{
1
,
2
});
Register
(
kApplyAdadeltaOpName
,
{
1
,
2
});
Register
(
kApplyPowerSignOpName
,
{
1
});
Register
(
kApplyProximalAdagradOpName
,
{
1
});
Register
(
kApplyAdaMaxOpName
,
{
1
,
2
});
Register
(
kApplyAdagradV2OpName
,
{
1
},
ApplyAdagradV2PreCheck
);
Register
(
kApplyKerasMomentumOpName
,
{
1
},
ApplyKerasMomentumPreCheck
);
Register
(
kSparseApplyFtrlOpName
,
{
1
,
2
},
SparseApplyFtrlPreCheck
);
Register
(
kSparseApplyFtrlV2OpName
,
{
1
,
2
},
SparseApplyFtrlV2PreCheck
);
Register
(
kSparseApplyAdagradV2OpName
,
{
1
},
SparseApplyAdagradV2PreCheck
);
Register
(
kSparseApplyProximalAdagradOpName
,
{
1
});
Register
(
kSparseApplyAdagradOpName
,
{
1
});
Register
(
kApplyFtrlV2OpName
,
{
1
,
2
});
Register
(
kApplyMomentumOpName
,
{
1
});
Register
(
kApplyFtrlOpName
,
{
1
,
2
});
Register
(
kApplyAdamOpName
,
{
1
,
2
});
Register
(
kApplyCenteredRMSPropOpName
,
{
1
,
2
,
3
});
Register
(
kApplyAddSignOpName
,
{
1
});
Register
(
kSparseApplyRMSPropOpName
,
{
1
,
2
},
SparseApplyRMSPropPreCheck
);
Register
(
kSparseApplyAdadeltaOpName
,
{
1
,
2
},
SparseApplyAdadeltaPreCheck
);
Register
(
kApplyAdamWithAmsgradOpName
,
{
1
,
2
});
}
InputToOutputRegistry
&
InputToOutputRegistry
::
Instance
()
{
static
InputToOutputRegistry
instance
;
return
instance
;
}
void
InputToOutputRegistry
::
Register
(
const
InputToOutputRegister
&
reg
)
{
auto
op_name
=
reg
.
op_name
();
if
(
op_input_to_output_map_
.
find
(
op_name
)
==
op_input_to_output_map_
.
end
())
{
(
void
)
op_input_to_output_map_
.
insert
(
make_pair
(
op_name
,
reg
));
MS_LOG
(
DEBUG
)
<<
op_name
<<
" input2output register successfully!"
;
}
}
void
InputToOutputRegistry
::
Register
(
const
std
::
string
&
op_name
,
const
std
::
vector
<
size_t
>
&
input_indices
,
const
PreCheckFunc
&
pre_check_func
)
{
if
(
op_input_to_output_map_
.
find
(
op_name
)
==
op_input_to_output_map_
.
end
())
{
InputToOutputRegister
reg
(
op_name
,
pre_check_func
);
reg
.
set_input_indices
(
input_indices
);
(
void
)
op_input_to_output_map_
.
insert
(
make_pair
(
op_name
,
reg
));
MS_LOG
(
DEBUG
)
<<
op_name
<<
" input2output register successfully!"
;
}
}
bool
InputToOutputRegistry
::
GetRegisterByOpName
(
const
std
::
string
&
op_name
,
InputToOutputRegister
*
reg
)
const
{
if
(
op_input_to_output_map_
.
find
(
op_name
)
!=
op_input_to_output_map_
.
end
())
{
*
reg
=
op_input_to_output_map_
.
at
(
op_name
);
MS_LOG
(
DEBUG
)
<<
op_name
<<
" input2output find in registry."
;
return
true
;
}
return
false
;
}
}
// namespace opt
}
// namespace mindspore
mindspore/ccsrc/pre_activate/ascend/ir_fusion/input_to_output_registry.h
0 → 100644
浏览文件 @
24f6b9d7
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_IR_FUSION_INPUT_TO_OUTPUT_REGISTRY_H_
#define MINDSPORE_CCSRC_PRE_ACTIVATE_IR_FUSION_INPUT_TO_OUTPUT_REGISTRY_H_
#include <string>
#include <unordered_map>
#include <vector>
#include <utility>
#include "ir/anf.h"
#include "common/utils.h"
namespace
mindspore
{
namespace
opt
{
using
PreCheckFunc
=
std
::
function
<
bool
(
const
CNodePtr
&
node
)
>
;
class
InputToOutputRegister
{
public:
explicit
InputToOutputRegister
(
const
std
::
string
&
op_name
=
""
,
const
PreCheckFunc
&
pre_check_func
=
[](
const
CNodePtr
&
node
)
{
return
true
;
})
:
op_name_
(
op_name
),
pre_check_func_
(
pre_check_func
)
{}
virtual
~
InputToOutputRegister
()
=
default
;
void
set_input_indices
(
const
std
::
vector
<
size_t
>
&
input_indices
)
{
input_indices_
=
input_indices
;
}
const
std
::
vector
<
size_t
>
&
input_indices
()
const
{
return
input_indices_
;
}
const
std
::
string
&
op_name
()
const
{
return
op_name_
;
}
private:
std
::
string
op_name_
;
std
::
vector
<
size_t
>
input_indices_
;
PreCheckFunc
pre_check_func_
;
};
class
InputToOutputRegistry
{
public:
static
InputToOutputRegistry
&
Instance
();
void
Register
(
const
InputToOutputRegister
&
reg
);
void
Register
(
const
std
::
string
&
op_name
,
const
std
::
vector
<
size_t
>
&
input_indices
,
const
PreCheckFunc
&
pre_check_func
=
[](
const
CNodePtr
&
node
)
{
return
true
;
});
bool
GetRegisterByOpName
(
const
std
::
string
&
op_name
,
InputToOutputRegister
*
reg
)
const
;
private:
InputToOutputRegistry
();
~
InputToOutputRegistry
()
=
default
;
DISABLE_COPY_AND_ASSIGN
(
InputToOutputRegistry
)
std
::
unordered_map
<
std
::
string
,
InputToOutputRegister
>
op_input_to_output_map_
;
};
}
// namespace opt
}
// namespace mindspore
#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_IR_FUSION_INPUT_TO_OUTPUT_REGISTRY_H_
mindspore/ccsrc/utils/utils.h
浏览文件 @
24f6b9d7
...
...
@@ -164,6 +164,15 @@ constexpr auto kStridedReadOpName = "StridedRead";
constexpr
auto
kStridedWriteOpName
=
"StridedWrite"
;
constexpr
auto
kFusedAdamWeightDecayName
=
"FusedAdamWeightDecay"
;
constexpr
auto
kFusedAdamName
=
"FusedAdam"
;
constexpr
auto
kApplyAdagradV2OpName
=
"ApplyAdagradV2"
;
constexpr
auto
kSparseApplyAdagradV2OpName
=
"SparseApplyAdagradV2"
;
constexpr
auto
kSparseApplyFtrlOpName
=
"SparseApplyFtrl"
;
constexpr
auto
kSparseApplyFtrlV2OpName
=
"SparseApplyFtrlV2"
;
constexpr
auto
kApplyKerasMomentumOpName
=
"ApplyKerasMomentum"
;
constexpr
auto
kSparseApplyProximalAdagradOpName
=
"SparseApplyProximalAdagrad"
;
constexpr
auto
kSparseApplyRMSPropOpName
=
"SparseApplyRMSProp"
;
constexpr
auto
kSparseApplyAdadeltaOpName
=
"SparseApplyAdadelta"
;
constexpr
auto
kApplyAdamWithAmsgradOpName
=
"ApplyAdamWithAmsgrad"
;
// attr key name
constexpr
auto
kAttrInputNames
=
"input_names"
;
...
...
tests/ut/cpp/pre_activate/ascend/ir_fusion/add_input_to_output_test.cc
0 → 100644
浏览文件 @
24f6b9d7
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "common/backend_common_test.h"
#include "common/py_func_graph_fetcher.h"
#include "debug/anf_ir_dump.h"
#define private public
#define protected public
#include "pre_activate/ascend/ir_fusion/add_input_to_output.h"
#undef private
#undef protected
namespace
mindspore
{
namespace
opt
{
class
TestHWAddInputToOutput
:
public
BackendCommon
{
public:
TestHWAddInputToOutput
()
:
getPyFun_
(
"gtest_input.pre_activate.add_input_to_output_test"
,
true
)
{}
~
TestHWAddInputToOutput
()
override
=
default
;
public:
UT
::
PyFuncGraphFetcher
getPyFun_
;
};
class
MockOpFinder
:
public
OpFinder
{
public:
MockOpFinder
()
=
default
;
~
MockOpFinder
()
override
=
default
;
int
GetOpRegisteredOutputNum
(
const
std
::
string
&
op_name
)
override
{
return
2
;
}
};
TEST_F
(
TestHWAddInputToOutput
,
test_add_input_to_output
)
{
FuncGraphPtr
g
=
getPyFun_
.
CallAndParseRet
(
"test_add_input_to_output"
,
"before"
);
EXPECT_NE
(
g
,
nullptr
);
std
::
vector
<
int
>
shp
{
2
,
32
,
224
,
224
};
auto
x_abstract
=
std
::
make_shared
<
abstract
::
AbstractTensor
>
(
kFloat32
,
shp
);
AbstractBasePtrList
args_spec_list
;
for
(
size_t
i
=
0
;
i
<
5
;
++
i
)
{
args_spec_list
.
push_back
(
x_abstract
);
}
auto
kg
=
GetKernelGraph
(
g
,
args_spec_list
);
EXPECT_NE
(
kg
,
nullptr
);
auto
ret
=
kg
->
get_return
();
EXPECT_NE
(
ret
,
nullptr
);
auto
make_tuple
=
ret
->
input
(
1
);
EXPECT_NE
(
make_tuple
,
nullptr
);
auto
momentum
=
make_tuple
->
cast
<
CNodePtr
>
()
->
input
(
1
);
EXPECT_NE
(
momentum
,
nullptr
);
EXPECT_NE
(
momentum
->
abstract
(),
nullptr
);
EXPECT_FALSE
(
momentum
->
abstract
()
->
isa
<
abstract
::
AbstractTuple
>
());
auto
optimizer
=
std
::
make_shared
<
opt
::
GraphOptimizer
>
();
auto
pm
=
std
::
make_shared
<
opt
::
PassManager
>
();
auto
pass
=
std
::
make_shared
<
opt
::
AddInputToOutput
>
();
pass
->
op_finder_
=
std
::
make_shared
<
MockOpFinder
>
();
pm
->
AddPass
(
pass
);
optimizer
->
AddPassManager
(
pm
);
(
void
)
optimizer
->
Optimize
(
kg
);
EXPECT_TRUE
(
momentum
->
abstract
()
->
isa
<
abstract
::
AbstractTuple
>
());
}
}
// namespace opt
}
// namespace mindspore
tests/ut/cpp/python_input/gtest_input/pre_activate/add_input_to_output_test.py
0 → 100644
浏览文件 @
24f6b9d7
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from
mindspore.ops
import
operations
as
P
ApplyMomentum
=
P
.
ApplyMomentum
()
class
FnDict
:
def
__init__
(
self
):
self
.
fnDict
=
{}
def
__call__
(
self
,
fn
):
self
.
fnDict
[
fn
.
__name__
]
=
fn
def
__getitem__
(
self
,
name
):
return
self
.
fnDict
[
name
]
def
test_add_input_to_output
(
tag
):
fns
=
FnDict
()
@
fns
def
before
(
input0
,
input1
,
input2
,
input3
,
input4
):
return
ApplyMomentum
(
input0
,
input1
,
input2
,
input3
,
input4
)
return
fns
[
tag
]
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录