Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
1f5441d7
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
1f5441d7
编写于
8月 08, 2020
作者:
W
WilliamLian
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
remove inferfunction to core
上级
01158763
变更
22
展开全部
隐藏空白更改
内联
并排
Showing
22 changed file
with
1285 addition
and
1090 deletion
+1285
-1090
mindspore/ccsrc/frontend/operator/cc_implementations.cc
mindspore/ccsrc/frontend/operator/cc_implementations.cc
+0
-35
mindspore/ccsrc/frontend/operator/cc_implementations.h
mindspore/ccsrc/frontend/operator/cc_implementations.h
+0
-1
mindspore/ccsrc/frontend/operator/ops.h
mindspore/ccsrc/frontend/operator/ops.h
+0
-15
mindspore/ccsrc/frontend/operator/ops_front_infer_function.cc
...spore/ccsrc/frontend/operator/ops_front_infer_function.cc
+340
-401
mindspore/ccsrc/frontend/operator/ops_front_infer_function.h
mindspore/ccsrc/frontend/operator/ops_front_infer_function.h
+77
-0
mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc
mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc
+1
-104
mindspore/ccsrc/pipeline/jit/static_analysis/prim.h
mindspore/ccsrc/pipeline/jit/static_analysis/prim.h
+1
-195
mindspore/core/abstract/infer_functions.h
mindspore/core/abstract/infer_functions.h
+187
-0
mindspore/core/abstract/prim_arrays.cc
mindspore/core/abstract/prim_arrays.cc
+38
-3
mindspore/core/abstract/prim_debug.cc
mindspore/core/abstract/prim_debug.cc
+1
-2
mindspore/core/abstract/prim_maths.cc
mindspore/core/abstract/prim_maths.cc
+1
-2
mindspore/core/abstract/prim_nn.cc
mindspore/core/abstract/prim_nn.cc
+90
-9
mindspore/core/abstract/prim_others.cc
mindspore/core/abstract/prim_others.cc
+2
-142
mindspore/core/abstract/prim_statement.cc
mindspore/core/abstract/prim_statement.cc
+1
-34
mindspore/core/abstract/prim_structures.cc
mindspore/core/abstract/prim_structures.cc
+278
-0
mindspore/core/abstract/primitive_infer_map.cc
mindspore/core/abstract/primitive_infer_map.cc
+114
-0
mindspore/core/abstract/primitive_infer_map.h
mindspore/core/abstract/primitive_infer_map.h
+53
-0
mindspore/core/base/core_ops.h
mindspore/core/base/core_ops.h
+19
-0
mindspore/core/c_ops/conv2d.cc
mindspore/core/c_ops/conv2d.cc
+53
-84
mindspore/core/c_ops/conv2d.h
mindspore/core/c_ops/conv2d.h
+24
-59
mindspore/core/c_ops/primitive_c.h
mindspore/core/c_ops/primitive_c.h
+4
-4
tests/ut/cpp/CMakeLists.txt
tests/ut/cpp/CMakeLists.txt
+1
-0
未找到文件。
mindspore/ccsrc/frontend/operator/cc_implementations.cc
浏览文件 @
1f5441d7
...
...
@@ -393,40 +393,5 @@ ValuePtr BoolEq(const ValuePtrList &list) {
MS_LOG
(
EXCEPTION
)
<<
"Unsported Value for BoolEq, x: "
<<
x
->
ToString
()
<<
"."
;
}
std
::
vector
<
int
>
BroadcastShape_
(
std
::
vector
<
int
>
shpx
,
std
::
vector
<
int
>
shpy
)
{
int
dlen
=
SizeToInt
(
shpx
.
size
())
-
SizeToInt
(
shpy
.
size
());
if
(
dlen
<
0
)
{
for
(
int
i
=
0
;
i
<
-
dlen
;
++
i
)
{
(
void
)
shpx
.
insert
(
shpx
.
begin
(),
1
);
}
}
else
if
(
dlen
>
0
)
{
for
(
int
i
=
0
;
i
<
dlen
;
i
++
)
{
(
void
)
shpy
.
insert
(
shpy
.
begin
(),
1
);
}
}
if
(
shpx
.
size
()
!=
shpy
.
size
())
{
MS_LOG
(
EXCEPTION
)
<<
"Failure: shpx.size() != shpy.size()."
;
}
std
::
vector
<
int
>
shp
;
for
(
size_t
i
=
0
;
i
<
shpx
.
size
();
i
++
)
{
auto
a
=
shpx
[
i
];
auto
b
=
shpy
[
i
];
if
(
a
==
1
)
{
shp
.
push_back
(
b
);
}
else
if
(
b
==
1
)
{
shp
.
push_back
(
a
);
}
else
if
(
a
==
-
1
)
{
shp
.
push_back
(
b
);
}
else
if
(
b
==
-
1
)
{
shp
.
push_back
(
a
);
}
else
if
(
a
==
b
)
{
shp
.
push_back
(
a
);
}
else
{
return
std
::
vector
<
int
>
();
}
}
return
shp
;
}
}
// namespace prim
}
// namespace mindspore
mindspore/ccsrc/frontend/operator/cc_implementations.h
浏览文件 @
1f5441d7
...
...
@@ -52,7 +52,6 @@ ValuePtr BoolNot(const ValuePtrList &list);
ValuePtr
BoolAnd
(
const
ValuePtrList
&
list
);
ValuePtr
BoolOr
(
const
ValuePtrList
&
list
);
ValuePtr
BoolEq
(
const
ValuePtrList
&
list
);
std
::
vector
<
int
>
BroadcastShape_
(
std
::
vector
<
int
>
s1
,
std
::
vector
<
int
>
s2
);
}
// namespace prim
}
// namespace mindspore
...
...
mindspore/ccsrc/frontend/operator/ops.h
浏览文件 @
1f5441d7
...
...
@@ -42,28 +42,13 @@ inline const PrimitivePtr kPrimRefToEmbed = std::make_shared<Primitive>("RefToEm
inline
const
PrimitivePtr
kPrimCreateInstance
=
std
::
make_shared
<
Primitive
>
(
"create_instance"
);
// Other miscellaneous
inline
const
PrimitivePtr
kPrimEnvSetItem
=
std
::
make_shared
<
Primitive
>
(
"env_setitem"
);
inline
const
PrimitivePtr
kPrimEnvGetItem
=
std
::
make_shared
<
Primitive
>
(
"env_getitem"
);
inline
const
PrimitivePtr
kPrimEnvAdd
=
std
::
make_shared
<
Primitive
>
(
"env_add"
);
inline
const
PrimitivePtr
kPrimMakeRefKey
=
std
::
make_shared
<
Primitive
>
(
"MakeRefKey"
);
inline
const
PrimitivePtr
kPrimGetRefKey
=
std
::
make_shared
<
Primitive
>
(
"get_ref_key"
);
inline
const
PrimitivePtr
kPrimGetRefValue
=
std
::
make_shared
<
Primitive
>
(
"get_ref_value"
);
inline
const
PrimitivePtr
kPrimGetRefOrigin
=
std
::
make_shared
<
Primitive
>
(
"get_ref_origin"
);
inline
const
PrimitivePtr
kPrimInsertGradientOf
=
std
::
make_shared
<
Primitive
>
(
"InsertGradientOf"
);
inline
const
PrimitivePtr
kPrimCheckBprop
=
std
::
make_shared
<
Primitive
>
(
"CheckBprop"
);
inline
const
PrimitivePtr
kPrimMakeRef
=
std
::
make_shared
<
Primitive
>
(
"make_ref"
);
inline
const
PrimitivePtr
kPrimMixedPrecisionCast
=
std
::
make_shared
<
Primitive
>
(
"mixed_precision_cast"
);
inline
const
PrimitivePtr
kPrimMakeRecord
=
std
::
make_shared
<
Primitive
>
(
"make_record"
);
// Structures
inline
const
PrimitivePtr
kPrimMakeList
=
std
::
make_shared
<
Primitive
>
(
"make_list"
);
inline
const
PrimitivePtr
kPrimMakeKeywordArg
=
std
::
make_shared
<
Primitive
>
(
"make_keyword_arg"
);
inline
const
PrimitivePtr
kPrimListGetItem
=
std
::
make_shared
<
Primitive
>
(
"list_getitem"
);
inline
const
PrimitivePtr
kPrimListSetItem
=
std
::
make_shared
<
Primitive
>
(
"list_setitem"
);
inline
const
PrimitivePtr
kPrimDictGetItem
=
std
::
make_shared
<
Primitive
>
(
"dict_getitem"
);
inline
const
PrimitivePtr
kPrimDictSetItem
=
std
::
make_shared
<
Primitive
>
(
"dict_setitem"
);
inline
const
PrimitivePtr
kPrimListAppend
=
std
::
make_shared
<
Primitive
>
(
"list_append"
);
inline
const
PrimitivePtr
kPrimListLen
=
std
::
make_shared
<
Primitive
>
(
"list_len"
);
inline
const
PrimitivePtr
kPrimListMap
=
std
::
make_shared
<
Primitive
>
(
"list_map"
);
inline
const
PrimitivePtr
kPrimListReduce
=
std
::
make_shared
<
Primitive
>
(
"list_reduce"
);
...
...
mindspore/ccsrc/frontend/operator/
prim_structures
.cc
→
mindspore/ccsrc/frontend/operator/
ops_front_infer_function
.cc
浏览文件 @
1f5441d7
此差异已折叠。
点击以展开。
mindspore/ccsrc/frontend/operator/ops_front_infer_function.h
0 → 100644
浏览文件 @
1f5441d7
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_FRONTEND_OPERATE_OPS_FRONT_INFER_FUNCTION_H_
#define MINDSPORE_CCSRC_FRONTEND_OPERATE_OPS_FRONT_INFER_FUNCTION_H_
#include "abstract/abstract_value.h"
#include "abstract/primitive_infer_map.h"
namespace
mindspore
{
namespace
abstract
{
AbstractBasePtr
InferImplTypeof
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplHasType
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplBroadcastGradientArgs
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplListMap
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplListReduce
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplTupleReversed
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplReduceShape
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplTupleDiv
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplTuple2Array
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplShapeMul
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplTupleEqual
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplListEqual
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplMakeRange
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplStopGradient
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplStringEqual
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplStringConcat
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplDictLen
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplJ
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplFakeBprop
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplMakeRecord
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
class
RegisterFrontendPrimitiveEvalHelper
{
public:
RegisterFrontendPrimitiveEvalHelper
(
const
PrimitivePtr
&
primitive
,
const
StandardPrimitiveEvalImpl
&
impl
)
{
const
StandardPrimitiveImplReg
impl_reg
{
impl
,
false
};
RegisterStandardPrimitiveImpl
(
primitive
,
impl_reg
);
}
~
RegisterFrontendPrimitiveEvalHelper
()
=
default
;
};
#define REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(name, primitive, impl) \
static auto helper_##name = RegisterFrontendPrimitiveEvalHelper(primitive, impl)
}
// namespace abstract
}
// namespace mindspore
#endif // MINDSPORE_CCSRC_FRONTEND_OPERATE_OPS_FRONT_INFER_FUNCTION_H_
mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc
浏览文件 @
1f5441d7
...
...
@@ -36,115 +36,12 @@
#include "utils/convert_utils.h"
#include "utils/ms_context.h"
#include "pipeline/jit/parse/data_converter.h"
#include "abstract/primitive_infer_map.h"
#include "abstract/param_validator.h"
#include "utils/ms_utils.h"
namespace
mindspore
{
namespace
abstract
{
PrimitiveEvalImplMap
&
GetPrimitiveToEvalImplMap
()
{
static
PrimitiveEvalImplMap
prim_eval_implement_map
=
{
// Statements
{
prim
::
kPrimReturn
,
{
InferImplReturn
,
true
}},
{
prim
::
kPrimTypeOf
,
{
InferImplTypeof
,
false
}},
{
prim
::
kPrimHasType
,
{
InferImplHasType
,
false
}},
{
prim
::
kPrimDot
,
{
InferImplDot
,
true
}},
{
prim
::
kPrimSwitch
,
{
InferImplSwitch
,
true
}},
{
prim
::
kPrimSwitchLayer
,
{
InferImplSwitchLayer
,
true
}},
{
prim
::
kPrimIs_
,
{
InferImplIs_
,
true
}},
{
prim
::
kPrimIsNot
,
{
InferImplIsNot
,
true
}},
{
prim
::
kPrimInDict
,
{
InferImplInDict
,
true
}},
{
prim
::
kPrimNotInDict
,
{
InferImplNotInDict
,
true
}},
{
prim
::
kPrimIsConsant
,
{
InferImplIsConstant
,
true
}},
// Maths
{
prim
::
kPrimMaximumGrad
,
{
InferImplMinOrMaxGrad
,
true
}},
{
prim
::
kPrimMinimumGrad
,
{
InferImplMinOrMaxGrad
,
true
}},
// Array
{
prim
::
kPrimScalarToArray
,
{
InferImplScalarToArray
,
true
}},
{
prim
::
kPrimArrayToScalar
,
{
InferImplArrayToScalar
,
true
}},
{
prim
::
kPrimBroadcastShape
,
{
InferImplBroadCastShape
,
true
}},
{
prim
::
kPrimPack
,
{
InferImplPack
,
true
}},
{
prim
::
kPrimUnique
,
{
InferImplUnique
,
true
}},
{
prim
::
kPrimUniqueGrad
,
{
InferImplUniqueGrad
,
true
}},
// Structure
{
prim
::
kPrimMakeTuple
,
{
InferImplMakeTuple
,
true
}},
{
prim
::
kPrimMakeList
,
{
InferImplMakeList
,
true
}},
{
prim
::
kPrimMakeDict
,
{
InferImplMakeDict
,
true
}},
{
prim
::
kPrimMakeSlice
,
{
InferImplMakeSlice
,
true
}},
{
prim
::
kPrimMakeKeywordArg
,
{
InferImplMakeKwarg
,
true
}},
{
prim
::
kPrimExtractKeywordArg
,
{
InferImplExtractKwarg
,
true
}},
{
prim
::
kPrimMakeRecord
,
{
InferImplMakeRecord
,
false
}},
{
prim
::
kPrimTupleGetItem
,
{
InferImplTupleGetItem
,
true
}},
{
prim
::
kPrimListGetItem
,
{
InferImplListGetItem
,
true
}},
{
prim
::
kPrimTupleSetItem
,
{
InferImplTupleSetItem
,
true
}},
{
prim
::
kPrimListSetItem
,
{
InferImplListSetItem
,
true
}},
{
prim
::
kPrimDictGetItem
,
{
InferImplDictGetItem
,
true
}},
{
prim
::
kPrimDictSetItem
,
{
InferImplDictSetItem
,
true
}},
{
prim
::
kPrimListAppend
,
{
InferImplListAppend
,
true
}},
{
prim
::
kPrimTupleLen
,
{
InferImplTupleLen
,
true
}},
{
prim
::
kPrimListLen
,
{
InferImplListLen
,
true
}},
{
prim
::
kPrimArrayLen
,
{
InferImplArrayLen
,
true
}},
{
prim
::
kPrimListMap
,
{
InferImplListMap
,
false
}},
{
prim
::
kPrimListReduce
,
{
InferImplListReduce
,
false
}},
{
prim
::
kPrimTupleReversed
,
{
InferImplTupleReversed
,
false
}},
{
prim
::
kPrimReducedShape
,
{
InferImplReduceShape
,
false
}},
{
prim
::
kPrimTupleDiv
,
{
InferImplTupleDiv
,
false
}},
{
prim
::
kPrimTupleToArray
,
{
InferImplTuple2Array
,
false
}},
{
prim
::
kPrimShapeMul
,
{
InferImplShapeMul
,
false
}},
{
prim
::
kPrimTupleEqual
,
{
InferImplTupleEqual
,
false
}},
{
prim
::
kPrimListEqual
,
{
InferImplListEqual
,
false
}},
{
prim
::
kPrimMakeRange
,
{
InferImplMakeRange
,
false
}},
{
prim
::
kPrimStopGradient
,
{
InferImplStopGradient
,
false
}},
{
prim
::
kPrimStringEqual
,
{
InferImplStringEqual
,
false
}},
{
prim
::
kPrimStringConcat
,
{
InferImplStringConcat
,
false
}},
{
prim
::
kPrimDictLen
,
{
InferImplDictLen
,
false
}},
// NN
{
prim
::
kPrimPooling
,
{
InferImplPooling
,
true
}},
{
prim
::
kPrimPoolingGrad
,
{
InferImplPoolingGrad
,
true
}},
{
prim
::
kPrimFusedBatchNorm
,
{
InferImplFusedBatchNorm
,
true
}},
{
prim
::
kPrimFusedBatchNormGrad
,
{
InferImplFusedBatchNormGrad
,
true
}},
{
prim
::
kPrimReluGrad
,
{
InferImplReluGrad
,
true
}},
{
prim
::
kPrimConv2DBackpropInput
,
{
InferImplConv2DBackpropInput
,
true
}},
{
prim
::
kPrimConv2DBackpropFilter
,
{
InferImplConv2DBackpropFilter
,
true
}},
{
prim
::
kPrimBiasAddGrad
,
{
InferImplBiasAddGrad
,
true
}},
{
prim
::
kPrimRelu
,
{
InferImplRelu
,
true
}},
{
prim
::
kPrimFakeBprop
,
{
InferImplFakeBprop
,
false
}},
{
prim
::
kPrimZerosLike
,
{
InferImplZerosLike
,
true
}},
{
prim
::
kPrimBpropCut
,
{
InferImplBpropCut
,
true
}},
{
prim
::
kPrimLayerNorm
,
{
InferImplLayerNorm
,
true
}},
{
prim
::
kPrimLayerNormGrad
,
{
InferImplLayerNormGrad
,
true
}},
{
prim
::
kPrimDropoutGenMask
,
{
InferImplDropoutGenMask
,
true
}},
// Others
{
prim
::
kPrimIdentity
,
{
InferImplIdentity
,
true
}},
// Set impl to null as it will use PartialEvaluator;
{
prim
::
kPrimPartial
,
{
nullptr
,
true
}},
{
prim
::
kPrimJ
,
{
InferImplJ
,
false
}},
{
prim
::
kPrimEnvGetItem
,
{
InferImplEnvGetItem
,
true
}},
{
prim
::
kPrimEnvSetItem
,
{
InferImplEnvSetItem
,
true
}},
{
prim
::
kPrimEnvAdd
,
{
InferImplEnvAdd
,
true
}},
{
prim
::
kPrimMakeRefKey
,
{
InferImplMakeRefKey
,
true
}},
{
prim
::
kPrimMakeRef
,
{
InferImplMakeRef
,
true
}},
{
prim
::
kPrimGetRefKey
,
{
InferImplGetRefKey
,
true
}},
{
prim
::
kPrimGetRefValue
,
{
InferImplGetRefValue
,
true
}},
{
prim
::
kPrimStateSetItem
,
{
InferImplStateSetItem
,
true
}},
{
prim
::
kPrimDepend
,
{
InferImplDepend
,
true
}},
{
prim
::
kPrimBroadcastGradientArgs
,
{
InferImplBroadcastGradientArgs
,
false
}},
{
prim
::
kPrimControlDepend
,
{
InferImplControlDepend
,
true
}},
// Debug
{
prim
::
kPrimDebug
,
{
InferImplDebug
,
true
}},
// RowTensor
{
prim
::
kPrimMakeRowTensor
,
{
InferImplMakeRowTensor
,
true
}},
{
prim
::
kPrimRowTensorGetValues
,
{
InferImplRowTensorGetValues
,
true
}},
{
prim
::
kPrimRowTensorGetIndices
,
{
InferImplRowTensorGetIndices
,
true
}},
{
prim
::
kPrimRowTensorGetDenseShape
,
{
InferImplRowTensorGetDenseShape
,
true
}},
// SparseTensor
{
prim
::
kPrimMakeSparseTensor
,
{
InferImplMakeSparseTensor
,
true
}},
{
prim
::
kPrimSparseTensorGetValues
,
{
InferImplSparseTensorGetValues
,
true
}},
{
prim
::
kPrimSparseTensorGetIndices
,
{
InferImplSparseTensorGetIndices
,
true
}},
{
prim
::
kPrimSparseTensorGetDenseShape
,
{
InferImplSparseTensorGetDenseShape
,
true
}},
};
return
prim_eval_implement_map
;
}
using
mindspore
::
parse
::
PyObjectWrapper
;
std
::
unordered_set
<
std
::
string
>
prims_to_skip_undetermined_infer
{
"make_tuple"
,
"make_list"
,
"switch"
,
"env_setitem"
,
...
...
mindspore/ccsrc/pipeline/jit/static_analysis/prim.h
浏览文件 @
1f5441d7
...
...
@@ -26,19 +26,10 @@
#include <vector>
#include "pipeline/jit/static_analysis/evaluator.h"
#include "abstract/primitive_infer_map.h"
namespace
mindspore
{
namespace
abstract
{
using
StandardPrimitiveEvalImpl
=
AbstractBasePtr
(
*
)(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
,
const
AbstractBasePtrList
&
);
struct
StandartPrimitiveImplReg
{
StandardPrimitiveEvalImpl
impl_
;
// Implement function of Primitive.
bool
in_white_list_
;
// true if this Primitive in white list, else false.
};
using
PrimitiveEvalImplMap
=
std
::
unordered_map
<
PrimitivePtr
,
StandartPrimitiveImplReg
,
PrimitiveHasher
,
PrimitiveEqual
>
;
class
StandardPrimEvaluator
:
public
TrivialPrimEvaluator
{
public:
StandardPrimEvaluator
(
const
PrimitivePtr
primitive
,
StandardPrimitiveEvalImpl
eval_impl
)
...
...
@@ -179,191 +170,6 @@ bool IsSubtype(const AbstractBasePtr x, const TypePtr model);
void
ClearPrimEvaluatorMap
();
py
::
dict
ConvertAbstractToPython
(
const
AbstractBasePtr
&
abs_base
);
AbstractBasePtr
InferImplReturn
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplTypeof
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplHasType
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplDot
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplSwitch
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplSwitchLayer
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplIs_
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplIsNot
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplInDict
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplNotInDict
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplIsConstant
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplPooling
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplPoolingGrad
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplFusedBatchNorm
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplFusedBatchNormGrad
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplReluGrad
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplConv2DBackpropInput
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplConv2DBackpropFilter
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplBiasAddGrad
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplRelu
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplFakeBprop
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplZerosLike
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplBpropCut
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplLayerNorm
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplLayerNormGrad
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplDropoutGenMask
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplMinOrMaxGrad
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplScalarToArray
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplArrayToScalar
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplBroadCastShape
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplPack
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplUnique
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplUniqueGrad
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplMakeTuple
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplMakeList
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplMakeDict
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplMakeSlice
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplMakeKwarg
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplExtractKwarg
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplMakeRecord
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplTupleGetItem
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplListGetItem
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplTupleSetItem
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplListSetItem
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplDictGetItem
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplDictSetItem
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplListAppend
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplTupleLen
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplListLen
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplArrayLen
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplListMap
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplListReduce
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplTupleReversed
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplReduceShape
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplTupleDiv
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplTuple2Array
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplShapeMul
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplGenShapeIndex
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplGenInverseIndex
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplTupleEqual
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplListEqual
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplMakeRange
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplStopGradient
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplStringEqual
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplStringConcat
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplDictLen
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplIdentity
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplJ
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplEnvGetItem
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplEnvSetItem
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplEnvAdd
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplMakeRefKey
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplMakeRef
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplGetRefKey
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplGetRefValue
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplGetRefOrigin
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplStateSetItem
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplDepend
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplBroadcastGradientArgs
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplControlDepend
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplDebug
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplMakeRowTensor
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplRowTensorGetValues
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplRowTensorGetIndices
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplRowTensorGetDenseShape
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplMakeSparseTensor
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplSparseTensorGetValues
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplSparseTensorGetIndices
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplSparseTensorGetDenseShape
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
}
// namespace abstract
}
// namespace mindspore
...
...
mindspore/core/abstract/infer_functions.h
0 → 100644
浏览文件 @
1f5441d7
/**
* This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
*
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_ABSTRACT_INFER_FUNCTIONS_H_
#define MINDSPORE_CORE_ABSTRACT_INFER_FUNCTIONS_H_
#include <string>
#include <memory>
#include "abstract/abstract_value.h"
#include "abstract/param_validator.h"
#include "base/core_ops.h"
namespace
mindspore
{
namespace
abstract
{
AbstractBasePtr
InferImplReturn
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplDot
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplSwitch
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplSwitchLayer
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplIs_
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplIsNot
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplInDict
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplNotInDict
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplIsConstant
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplPooling
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplPoolingGrad
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplFusedBatchNorm
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplFusedBatchNormGrad
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplReluGrad
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplConv2DBackpropInput
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplConv2DBackpropFilter
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplBiasAddGrad
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplGelu
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplGeluGrad
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplRelu
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplZerosLike
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplBpropCut
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplLayerNorm
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplLayerNormGrad
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplDropoutGenMask
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplMinOrMaxGrad
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplScalarToArray
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplArrayToScalar
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplBroadCastShape
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplPack
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplMakeTuple
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplMakeList
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplMakeDict
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplMakeSlice
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplMakeKwarg
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplExtractKwarg
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplMakeRecord
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplTupleGetItem
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplListGetItem
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplTupleSetItem
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplListSetItem
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplDictGetItem
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplDictSetItem
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplListAppend
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplTupleLen
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplListLen
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplArrayLen
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplGenShapeIndex
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplGenInverseIndex
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplIdentity
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplEnvGetItem
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplEnvSetItem
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplEnvAdd
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplMakeRefKey
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplMakeRef
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplGetRefKey
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplGetRefValue
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplGetRefOrigin
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplStateSetItem
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplDepend
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplControlDepend
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplDebug
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplMakeSparseTensor
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplSparseTensorGetValues
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplSparseTensorGetIndices
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplSparseTensorGetDenseShape
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplMakeRowTensor
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplRowTensorGetValues
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplRowTensorGetIndices
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplRowTensorGetDenseShape
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplUniqueGrad
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplUnique
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
template
<
typename
T
>
AbstractBasePtr
InferTupleOrListOrDictLen
(
const
std
::
string
&
op_name
,
const
AbstractBasePtrList
&
args_spec_list
)
{
// Inputs: a tuple or list or dict.
CheckArgsSize
(
op_name
,
args_spec_list
,
1
);
auto
arg
=
CheckArg
<
T
>
(
op_name
,
args_spec_list
,
0
);
return
std
::
make_shared
<
AbstractScalar
>
(
SizeToInt
(
arg
->
size
()));
}
}
// namespace abstract
}
// namespace mindspore
#endif // MINDSPORE_CORE_ABSTRACT_INFER_FUNCTIONS_H_
mindspore/c
csrc/frontend/operator
/prim_arrays.cc
→
mindspore/c
ore/abstract
/prim_arrays.cc
浏览文件 @
1f5441d7
...
...
@@ -14,13 +14,48 @@
* limitations under the License.
*/
#include "
pipeline/jit/static_analysis/prim
.h"
#include "
abstract/infer_functions
.h"
#include "abstract/utils.h"
#include "frontend/operator/cc_implementations.h"
#include "abstract/param_validator.h"
namespace
mindspore
{
namespace
abstract
{
namespace
{
std
::
vector
<
int
>
BroadcastShape
(
std
::
vector
<
int
>
shpx
,
std
::
vector
<
int
>
shpy
)
{
int
dlen
=
SizeToInt
(
shpx
.
size
())
-
SizeToInt
(
shpy
.
size
());
if
(
dlen
<
0
)
{
for
(
int
i
=
0
;
i
<
-
dlen
;
++
i
)
{
(
void
)
shpx
.
insert
(
shpx
.
begin
(),
1
);
}
}
else
if
(
dlen
>
0
)
{
for
(
int
i
=
0
;
i
<
dlen
;
i
++
)
{
(
void
)
shpy
.
insert
(
shpy
.
begin
(),
1
);
}
}
if
(
shpx
.
size
()
!=
shpy
.
size
())
{
MS_LOG
(
EXCEPTION
)
<<
"Failure: shpx.size() != shpy.size()."
;
}
std
::
vector
<
int
>
shp
;
for
(
size_t
i
=
0
;
i
<
shpx
.
size
();
i
++
)
{
auto
a
=
shpx
[
i
];
auto
b
=
shpy
[
i
];
if
(
a
==
1
)
{
shp
.
push_back
(
b
);
}
else
if
(
b
==
1
)
{
shp
.
push_back
(
a
);
}
else
if
(
a
==
-
1
)
{
shp
.
push_back
(
b
);
}
else
if
(
b
==
-
1
)
{
shp
.
push_back
(
a
);
}
else
if
(
a
==
b
)
{
shp
.
push_back
(
a
);
}
else
{
return
std
::
vector
<
int
>
();
}
}
return
shp
;
}
}
// namespace
AbstractBasePtr
InferImplScalarToArray
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
)
{
// Inputs: a scalar.
...
...
@@ -65,7 +100,7 @@ AbstractBasePtr InferImplBroadCastShape(const AnalysisEnginePtr &, const Primiti
(
void
)
std
::
transform
(
std
::
begin
(
shp_tuple_y
),
std
::
end
(
shp_tuple_y
),
std
::
back_inserter
(
shp_y
),
[](
const
ValuePtr
&
e
)
->
int
{
return
GetValue
<
int
>
(
e
);
});
std
::
vector
<
int
>
res
=
prim
::
BroadcastShape_
(
shp_x
,
shp_y
);
std
::
vector
<
int
>
res
=
BroadcastShape
(
shp_x
,
shp_y
);
if
(
res
.
empty
())
{
MS_LOG
(
EXCEPTION
)
<<
"BroadcastShape fail: "
<<
args_spec_list
[
0
]
->
ToString
()
<<
","
<<
args_spec_list
[
1
]
->
ToString
();
...
...
mindspore/c
csrc/frontend/operator
/prim_debug.cc
→
mindspore/c
ore/abstract
/prim_debug.cc
浏览文件 @
1f5441d7
...
...
@@ -15,8 +15,7 @@
*/
#include "abstract/param_validator.h"
#include "pipeline/jit/static_analysis/prim.h"
#include "frontend/operator/ops.h"
#include "abstract/infer_functions.h"
#include "abstract/utils.h"
#include "utils/symbolic.h"
...
...
mindspore/c
csrc/frontend/operator
/prim_maths.cc
→
mindspore/c
ore/abstract
/prim_maths.cc
浏览文件 @
1f5441d7
...
...
@@ -14,8 +14,7 @@
* limitations under the License.
*/
#include "pipeline/jit/static_analysis/prim.h"
#include "frontend/operator/ops.h"
#include "abstract/infer_functions.h"
#include "abstract/utils.h"
#include "abstract/param_validator.h"
#include "utils/ms_utils.h"
...
...
mindspore/c
csrc/frontend/operator
/prim_nn.cc
→
mindspore/c
ore/abstract
/prim_nn.cc
浏览文件 @
1f5441d7
...
...
@@ -14,10 +14,12 @@
* limitations under the License.
*/
#include "pipeline/jit/static_analysis/prim.h"
#include "frontend/operator/ops.h"
#include "abstract/infer_functions.h"
#include "abstract/utils.h"
#include "abstract/param_validator.h"
#include "utils/check_convert_utils.h"
#include "c_ops/conv2d.h"
#include "abstract/primitive_infer_map.h"
namespace
mindspore
{
namespace
abstract
{
...
...
@@ -278,13 +280,6 @@ AbstractBasePtr InferImplZerosLike(const AnalysisEnginePtr &, const PrimitivePtr
return
args_spec_list
[
0
]
->
Broaden
();
}
AbstractBasePtr
InferImplFakeBprop
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
)
{
// Inputs: a tensor.
CheckArgsSize
(
primitive
->
name
(),
args_spec_list
,
1
);
return
args_spec_list
[
0
]
->
Broaden
();
}
AbstractBasePtr
InferImplBpropCut
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
)
{
// Inputs: a tensor.
...
...
@@ -433,5 +428,91 @@ AbstractBasePtr InferImplDropoutGenMask(const AnalysisEnginePtr &, const Primiti
return
std
::
make_shared
<
AbstractTensor
>
(
std
::
make_shared
<
AbstractScalar
>
(
kAnyValue
,
kUInt8
),
std
::
make_shared
<
Shape
>
(
std
::
vector
<
int64_t
>
{
shape_y
}));
}
abstract
::
ShapePtr
Conv2dInferShape
(
const
PrimitivePtr
&
primitive
,
const
std
::
vector
<
AbstractBasePtr
>
&
input_args
)
{
MS_EXCEPTION_IF_NULL
(
primitive
);
auto
conv_prim
=
primitive
->
cast
<
PrimConv2dPtr
>
();
MS_EXCEPTION_IF_NULL
(
conv_prim
);
auto
prim_name
=
conv_prim
->
name
();
CheckAndConvertUtils
::
CheckInRange
(
"Conv2d Infer"
,
input_args
.
size
(),
kIncludeLeft
,
{
2
,
3
},
prim_name
);
auto
w_shape
=
CheckAndConvertUtils
::
ConvertShapePtrToShape
(
"w_shape"
,
input_args
[
0
]
->
GetShapeTrack
(),
prim_name
);
auto
x_shape
=
CheckAndConvertUtils
::
ConvertShapePtrToShape
(
"x_shape"
,
input_args
[
1
]
->
GetShapeTrack
(),
prim_name
);
CheckAndConvertUtils
::
CheckInteger
(
"weight rank"
,
w_shape
.
size
(),
kEqual
,
4
,
prim_name
);
CheckAndConvertUtils
::
CheckInteger
(
"x rank"
,
x_shape
.
size
(),
kEqual
,
4
,
prim_name
);
CheckAndConvertUtils
::
Check
(
"x_shape[1] / group"
,
x_shape
[
1
]
/
conv_prim
->
GetGroup
(),
kEqual
,
"w_shape[1]"
,
w_shape
[
1
],
conv_prim
->
name
());
auto
out_channel
=
conv_prim
->
GetOutputChannel
();
CheckAndConvertUtils
::
Check
(
"out_channel"
,
out_channel
,
kEqual
,
"w_shape[0]"
,
w_shape
[
0
],
conv_prim
->
name
());
std
::
vector
<
int
>
temp_w
;
std
::
copy
(
w_shape
.
begin
()
+
2
,
w_shape
.
end
(),
std
::
back_inserter
(
temp_w
));
CheckAndConvertUtils
::
Check
(
"kernel_size"
,
conv_prim
->
GetKernelSize
(),
kEqual
,
"w_shape[2:4]"
,
temp_w
,
conv_prim
->
name
());
auto
kernel_size_h
=
w_shape
[
2
];
auto
kernel_size_w
=
w_shape
[
3
];
auto
stride
=
conv_prim
->
GetStride
();
auto
dilation
=
conv_prim
->
GetDilation
();
auto
stride_h
=
stride
[
2
];
auto
stride_w
=
stride
[
3
];
auto
dilation_h
=
dilation
[
2
];
auto
dilation_w
=
dilation
[
3
];
int
h_out
=
-
1
;
int
w_out
=
-
1
;
std
::
vector
<
int
>
pad_list
(
4
,
0
);
auto
pad_mode
=
conv_prim
->
GetPadMode
();
if
(
pad_mode
==
"valid"
)
{
h_out
=
ceil
((
x_shape
[
2
]
-
dilation_h
*
(
kernel_size_h
-
1
))
/
stride_h
);
w_out
=
ceil
((
x_shape
[
3
]
-
dilation_w
*
(
kernel_size_w
-
1
))
/
stride_w
);
}
else
if
(
pad_mode
==
"same"
)
{
h_out
=
ceil
(
x_shape
[
2
]
/
stride_h
);
w_out
=
ceil
(
x_shape
[
3
]
/
stride_w
);
auto
pad_needed_h
=
std
::
max
(
0
,
(
h_out
-
1
)
*
stride_h
+
dilation_h
*
(
kernel_size_h
-
1
)
+
1
-
x_shape
[
2
]);
pad_list
.
emplace_back
(
floor
(
pad_needed_h
/
2
));
pad_list
.
emplace_back
(
pad_needed_h
/
2
);
auto
pad_needed_w
=
std
::
max
(
0
,
(
w_out
-
1
)
*
stride_w
+
dilation_w
*
(
kernel_size_w
-
1
)
+
1
-
x_shape
[
3
]);
auto
pad_left
=
floor
(
pad_needed_w
/
2
);
pad_list
.
emplace_back
(
pad_left
);
pad_list
.
emplace_back
(
pad_needed_h
-
pad_left
);
}
else
if
(
pad_mode
==
"pad"
)
{
std
::
copy
(
conv_prim
->
GetPad
().
begin
(),
conv_prim
->
GetPad
().
end
(),
std
::
back_inserter
(
pad_list
));
auto
pad_top
=
conv_prim
->
GetPad
()[
0
];
auto
pad_bottom
=
conv_prim
->
GetPad
()[
1
];
auto
pad_right
=
conv_prim
->
GetPad
()[
2
];
auto
pad_left
=
conv_prim
->
GetPad
()[
3
];
h_out
=
1
+
(
x_shape
[
2
]
+
pad_top
+
pad_bottom
-
kernel_size_h
-
(
kernel_size_h
-
1
)
*
(
dilation_h
-
1
))
/
stride_h
;
w_out
=
1
+
(
x_shape
[
3
]
+
pad_left
+
pad_right
-
kernel_size_w
-
(
kernel_size_w
-
1
)
*
(
dilation_w
-
1
))
/
stride_w
;
h_out
=
floor
(
h_out
);
w_out
=
floor
(
w_out
);
}
conv_prim
->
SetPadList
(
pad_list
);
std
::
vector
<
int
>
out_shape
=
{
x_shape
[
0
],
out_channel
,
h_out
,
w_out
};
return
std
::
make_shared
<
abstract
::
Shape
>
(
out_shape
);
}
TypePtr
Conv2dInferType
(
const
PrimitivePtr
&
prim
,
const
std
::
vector
<
AbstractBasePtr
>
&
input_args
)
{
CheckAndConvertUtils
::
CheckInRange
(
""
,
input_args
.
size
(),
kIncludeLeft
,
{
2
,
3
},
prim
->
name
());
for
(
const
auto
&
item
:
input_args
)
{
MS_EXCEPTION_IF_NULL
(
item
);
}
auto
x_type
=
CheckAndConvertUtils
::
ConvertTypePtrToTypeId
(
"x_dtype"
,
input_args
[
0
]
->
GetTypeTrack
(),
prim
->
name
());
const
std
::
set
<
TypeId
>
valid_types
=
{
kNumberTypeInt8
,
kNumberTypeInt32
,
kNumberTypeFloat16
,
kNumberTypeFloat32
};
std
::
map
<
std
::
string
,
TypePtr
>
types
;
types
.
emplace
(
"x"
,
input_args
[
0
]
->
GetTypeTrack
());
types
.
emplace
(
"w"
,
input_args
[
1
]
->
GetTypeTrack
());
CheckAndConvertUtils
::
CheckTensorTypeSame
(
types
,
valid_types
,
prim
->
name
());
if
(
x_type
==
kNumberTypeInt8
)
{
return
std
::
make_shared
<
TensorType
>
(
TypeIdToType
(
kNumberTypeInt32
));
}
return
std
::
make_shared
<
TensorType
>
(
TypeIdToType
(
x_type
));
}
AbstractBasePtr
Conv2dInfer
(
const
abstract
::
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
std
::
vector
<
AbstractBasePtr
>
&
input_args
)
{
return
std
::
make_shared
<
abstract
::
AbstractTensor
>
(
Conv2dInferType
(
primitive
,
input_args
),
Conv2dInferShape
(
primitive
,
input_args
)
->
shape
());
}
REGISTER_PRIMITIVE_EVAL_IMPL
(
Conv2D
,
prim
::
kPrimConv2D
,
Conv2dInfer
);
}
// namespace abstract
}
// namespace mindspore
mindspore/c
csrc/frontend/operator
/prim_others.cc
→
mindspore/c
ore/abstract
/prim_others.cc
浏览文件 @
1f5441d7
...
...
@@ -19,9 +19,9 @@
#include "ir/dtype.h"
#include "utils/ms_utils.h"
#include "
frontend/operator/
ops.h"
#include "
base/core_
ops.h"
#include "abstract/param_validator.h"
#include "
pipeline/jit/static_analysis/prim
.h"
#include "
abstract/infer_functions
.h"
#include "abstract/utils.h"
#include "utils/ms_context.h"
#include "utils/symbolic.h"
...
...
@@ -35,27 +35,6 @@ AbstractBasePtr InferImplIdentity(const AnalysisEnginePtr &, const PrimitivePtr
return
args_spec_list
[
0
];
}
AbstractBasePtr
InferImplJ
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
)
{
// args: An object of AbstractFunction.
CheckArgsSize
(
primitive
->
name
(),
args_spec_list
,
1
);
MS_LOG
(
DEBUG
)
<<
"evaluate J: "
<<
args_spec_list
[
0
]
->
ToString
();
AbstractFunctionPtr
x
=
dyn_cast
<
AbstractFunction
>
(
args_spec_list
[
0
]);
if
(
x
==
nullptr
)
{
return
std
::
make_shared
<
AbstractJTagged
>
(
args_spec_list
[
0
]);
}
AbstractFuncAtomPtrList
jv
;
auto
build_jv
=
[
&
jv
](
const
AbstractFuncAtomPtr
&
func
)
{
auto
j_closure
=
std
::
make_shared
<
JTransformedAbstractClosure
>
(
func
);
jv
.
push_back
(
j_closure
);
};
x
->
Visit
(
build_jv
);
return
AbstractFunction
::
MakeAbstractFunction
(
jv
);
}
AbstractBasePtr
InferImplEnvGetItem
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
)
{
MS_EXCEPTION_IF_NULL
(
primitive
);
...
...
@@ -196,125 +175,6 @@ AbstractBasePtr InferImplDepend(const AnalysisEnginePtr &, const PrimitivePtr &p
return
depends
;
}
bool
CompareShape
(
const
std
::
vector
<
ValuePtr
>
&
x_shape
,
const
std
::
vector
<
ValuePtr
>
&
y_shape
)
{
if
(
x_shape
.
size
()
!=
y_shape
.
size
())
{
return
false
;
}
for
(
size_t
i
=
0
;
i
<
x_shape
.
size
();
++
i
)
{
if
(
GetValue
<
int
>
(
x_shape
[
i
])
!=
GetValue
<
int
>
(
y_shape
[
i
]))
{
return
false
;
}
}
return
true
;
}
enum
State
{
SAME
,
X_ONE
,
Y_ONE
,
};
void
ComputeReduceIndex
(
const
std
::
vector
<
int
>
&
reverse_x
,
const
std
::
vector
<
int
>
&
reverse_y
,
std
::
vector
<
int
>
*
grad_x_reduce_idx
,
std
::
vector
<
int
>
*
grad_y_reduce_idy
)
{
const
size_t
n
=
reverse_x
.
size
();
for
(
size_t
i
=
0
;
i
<
n
;
++
i
)
{
State
curr
;
const
int32_t
x_i
=
reverse_x
[
i
];
const
int32_t
y_i
=
reverse_y
[
i
];
const
int
reduce_idx
=
SizeToInt
(
n
-
1
-
i
);
if
(
x_i
==
y_i
)
{
curr
=
SAME
;
}
else
if
(
x_i
==
1
)
{
grad_x_reduce_idx
->
push_back
(
reduce_idx
);
curr
=
X_ONE
;
}
else
if
(
y_i
==
1
)
{
grad_y_reduce_idy
->
push_back
(
reduce_idx
);
curr
=
Y_ONE
;
}
else
{
MS_LOG
(
EXCEPTION
)
<<
"not compatible shape input for BroadcastGradientArgs"
;
}
if
(
curr
==
SAME
&&
x_i
==
1
)
{
grad_x_reduce_idx
->
push_back
(
reduce_idx
);
grad_y_reduce_idy
->
push_back
(
reduce_idx
);
continue
;
}
}
std
::
reverse
(
grad_x_reduce_idx
->
begin
(),
grad_x_reduce_idx
->
end
());
std
::
reverse
(
grad_y_reduce_idy
->
begin
(),
grad_y_reduce_idy
->
end
());
}
AbstractBasePtr
BroadcastGradientArgsDiff
(
const
std
::
vector
<
ValuePtr
>
&
x_shape
,
const
std
::
vector
<
ValuePtr
>
&
y_shape
)
{
std
::
vector
<
int
>
reverse_x
;
std
::
vector
<
int
>
reverse_y
;
(
void
)
std
::
transform
(
x_shape
.
rbegin
(),
x_shape
.
rend
(),
std
::
back_inserter
(
reverse_x
),
[](
const
ValuePtr
&
v
)
{
return
v
->
cast
<
Int32ImmPtr
>
()
->
value
();
});
(
void
)
std
::
transform
(
y_shape
.
rbegin
(),
y_shape
.
rend
(),
std
::
back_inserter
(
reverse_y
),
[](
const
ValuePtr
&
v
)
{
return
v
->
cast
<
Int32ImmPtr
>
()
->
value
();
});
if
(
reverse_x
.
size
()
>
reverse_y
.
size
())
{
reverse_y
.
resize
(
reverse_x
.
size
(),
1
);
}
else
{
reverse_x
.
resize
(
reverse_y
.
size
(),
1
);
}
std
::
vector
<
int
>
grad_x_reduce_idx
;
std
::
vector
<
int
>
grad_y_reduce_idy
;
ComputeReduceIndex
(
reverse_x
,
reverse_y
,
&
grad_x_reduce_idx
,
&
grad_y_reduce_idy
);
AbstractBasePtrList
abs_list_x
;
AbstractBasePtrList
abs_list_y
;
(
void
)
std
::
transform
(
grad_x_reduce_idx
.
begin
(),
grad_x_reduce_idx
.
end
(),
std
::
back_inserter
(
abs_list_x
),
[](
int
v
)
{
return
abstract
::
FromValue
(
v
);
});
(
void
)
std
::
transform
(
grad_y_reduce_idy
.
begin
(),
grad_y_reduce_idy
.
end
(),
std
::
back_inserter
(
abs_list_y
),
[](
int
v
)
{
return
abstract
::
FromValue
(
v
);
});
auto
x_reduce_idx
=
std
::
make_shared
<
AbstractTuple
>
(
abs_list_x
);
auto
y_reduce_idx
=
std
::
make_shared
<
AbstractTuple
>
(
abs_list_y
);
AbstractBasePtrList
elem_list
;
elem_list
.
push_back
(
x_reduce_idx
);
elem_list
.
push_back
(
y_reduce_idx
);
return
std
::
make_shared
<
AbstractTuple
>
(
elem_list
);
}
AbstractBasePtr
InferImplBroadcastGradientArgs
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
)
{
// this primitive get the index that need to reduce
// input: x's shape and y's shape, inputs should be tuple
// output: tuple of x and y 's reduce index, reduce index should be a tuple
const
std
::
string
op_name
=
primitive
->
name
();
CheckArgsSize
(
op_name
,
args_spec_list
,
2
);
auto
arg_x
=
CheckArg
<
AbstractTuple
>
(
op_name
,
args_spec_list
,
0
);
auto
arg_y
=
CheckArg
<
AbstractTuple
>
(
op_name
,
args_spec_list
,
1
);
ValueTuplePtr
arg_x_value
=
arg_x
->
BuildValue
()
->
cast
<
ValueTuplePtr
>
();
MS_EXCEPTION_IF_NULL
(
arg_x_value
);
ValueTuplePtr
arg_y_value
=
arg_y
->
BuildValue
()
->
cast
<
ValueTuplePtr
>
();
MS_EXCEPTION_IF_NULL
(
arg_y_value
);
const
std
::
vector
<
ValuePtr
>
x_shape
=
arg_x_value
->
value
();
const
std
::
vector
<
ValuePtr
>
y_shape
=
arg_y_value
->
value
();
bool
is_same_shape
=
CompareShape
(
x_shape
,
y_shape
);
// if it is the same shape , do not need reduce , return empty tuple
if
(
is_same_shape
)
{
AbstractBasePtrList
empty_list
;
auto
x_reduce_idx
=
std
::
make_shared
<
AbstractTuple
>
(
empty_list
);
auto
y_reduce_idx
=
std
::
make_shared
<
AbstractTuple
>
(
empty_list
);
AbstractBasePtrList
elem_list
;
elem_list
.
push_back
(
x_reduce_idx
);
elem_list
.
push_back
(
y_reduce_idx
);
return
std
::
make_shared
<
AbstractTuple
>
(
elem_list
);
}
return
BroadcastGradientArgsDiff
(
x_shape
,
y_shape
);
}
AbstractBasePtr
InferImplControlDepend
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
)
{
// args: Two objects of a subclass of AbstractBase
...
...
mindspore/c
csrc/frontend/operator
/prim_statement.cc
→
mindspore/c
ore/abstract
/prim_statement.cc
浏览文件 @
1f5441d7
...
...
@@ -15,8 +15,7 @@
*/
#include "abstract/param_validator.h"
#include "pipeline/jit/static_analysis/prim.h"
#include "frontend/operator/ops.h"
#include "abstract/infer_functions.h"
#include "abstract/utils.h"
#include "utils/symbolic.h"
...
...
@@ -34,38 +33,6 @@ AbstractBasePtr InferImplReturn(const AnalysisEnginePtr &, const PrimitivePtr &,
return
abs_base
;
}
AbstractBasePtr
InferImplTypeof
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
,
const
AbstractBasePtrList
&
args_spec_list
)
{
// Inputs: a pointer to an AbstractBase object
if
(
args_spec_list
.
size
()
!=
1
)
{
MS_LOG
(
EXCEPTION
)
<<
"Typeof evaluator requires 1 parameter, while the input size is "
<<
args_spec_list
.
size
()
<<
"."
;
}
AbstractBasePtr
abs_base
=
args_spec_list
[
0
];
MS_EXCEPTION_IF_NULL
(
abs_base
);
TypePtr
type
=
abs_base
->
BuildType
();
return
std
::
make_shared
<
AbstractType
>
(
type
);
}
AbstractBasePtr
InferImplHasType
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
)
{
// Inputs: a pointer to an AbstractBase object and a pointer to a Type
const
std
::
string
op_name
=
primitive
->
name
();
CheckArgsSize
(
op_name
,
args_spec_list
,
2
);
AbstractTypePtr
abs_type
=
CheckArg
<
AbstractType
>
(
op_name
,
args_spec_list
,
1
);
auto
mode_v
=
abs_type
->
GetValueTrack
();
MS_EXCEPTION_IF_NULL
(
mode_v
);
if
(
!
mode_v
->
isa
<
Type
>
())
{
MS_LOG
(
EXCEPTION
)
<<
"Get the type from AbstractType value failed."
;
}
TypePtr
mode_t
=
mode_v
->
cast
<
TypePtr
>
();
MS_EXCEPTION_IF_NULL
(
args_spec_list
[
0
]);
bool
v
=
IsSubtype
(
args_spec_list
[
0
],
mode_t
);
return
std
::
make_shared
<
AbstractScalar
>
(
std
::
make_shared
<
BoolImm
>
(
v
),
kBool
);
}
AbstractBasePtr
InferImplDot
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
)
{
// Inputs: two tensors.
...
...
mindspore/core/abstract/prim_structures.cc
0 → 100644
浏览文件 @
1f5441d7
/**
* This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
*
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "abstract/infer_functions.h"
#include "abstract/utils.h"
#include "abstract/param_validator.h"
namespace
mindspore
{
namespace
abstract
{
AbstractBasePtr
InferImplMakeTuple
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
,
const
AbstractBasePtrList
&
args_spec_list
)
{
return
std
::
make_shared
<
AbstractTuple
>
(
args_spec_list
);
}
AbstractBasePtr
InferImplMakeList
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
,
const
AbstractBasePtrList
&
args_spec_list
)
{
return
std
::
make_shared
<
AbstractList
>
(
args_spec_list
);
}
AbstractBasePtr
InferImplMakeDict
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
)
{
// Inputs: two tuples.
const
std
::
string
op_name
=
primitive
->
name
();
CheckArgsSize
(
op_name
,
args_spec_list
,
2
);
AbstractTuplePtr
keys
=
CheckArg
<
AbstractTuple
>
(
op_name
,
args_spec_list
,
0
);
AbstractTuplePtr
values
=
CheckArg
<
AbstractTuple
>
(
op_name
,
args_spec_list
,
1
);
size_t
keys_size
=
keys
->
size
();
if
(
values
->
size
()
!=
keys_size
)
{
MS_LOG
(
EXCEPTION
)
<<
op_name
<<
" evaluator keys' size is not equal with values' size"
;
}
std
::
vector
<
AbstractAttribute
>
key_value
;
AbstractScalarPtr
key
;
AbstractBasePtrList
key_list
=
keys
->
elements
();
AbstractBasePtrList
value_list
=
values
->
elements
();
for
(
size_t
index
=
0
;
index
<
keys_size
;
index
++
)
{
key
=
CheckArg
<
AbstractScalar
>
(
op_name
+
"key"
,
key_list
,
index
);
ValuePtr
keyPtr
=
key
->
BuildValue
();
MS_EXCEPTION_IF_NULL
(
keyPtr
);
if
(
!
keyPtr
->
isa
<
StringImm
>
())
{
MS_LOG
(
EXCEPTION
)
<<
op_name
<<
" evaluator keys should be string, but got "
<<
keyPtr
->
ToString
();
}
std
::
string
key_string
=
GetValue
<
std
::
string
>
(
keyPtr
);
key_value
.
emplace_back
(
key_string
,
value_list
[
index
]);
}
return
std
::
make_shared
<
AbstractDictionary
>
(
key_value
);
}
AbstractBasePtr
InferImplMakeKwarg
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
)
{
// Inputs: a string and an object of a subclass of AbstractBase.
const
std
::
string
op_name
=
primitive
->
name
();
CheckArgsSize
(
op_name
,
args_spec_list
,
2
);
AbstractScalarPtr
key
=
CheckArg
<
AbstractScalar
>
(
op_name
,
args_spec_list
,
0
);
ValuePtr
keyPtr
=
key
->
BuildValue
();
if
(
!
keyPtr
->
isa
<
StringImm
>
())
{
MS_LOG
(
EXCEPTION
)
<<
op_name
<<
" evaluator key should be string, but got "
<<
keyPtr
->
ToString
();
}
std
::
string
key_string
=
GetValue
<
std
::
string
>
(
keyPtr
);
return
std
::
make_shared
<
AbstractKeywordArg
>
(
key_string
,
args_spec_list
[
1
]);
}
AbstractBasePtr
InferImplExtractKwarg
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
)
{
// Inputs: a string and a keyword.
const
std
::
string
op_name
=
primitive
->
name
();
CheckArgsSize
(
op_name
,
args_spec_list
,
2
);
AbstractScalarPtr
key
=
CheckArg
<
AbstractScalar
>
(
op_name
,
args_spec_list
,
0
);
AbstractKeywordArgPtr
kwarg
=
CheckArg
<
AbstractKeywordArg
>
(
op_name
,
args_spec_list
,
1
);
ValuePtr
key_value
=
key
->
BuildValue
();
if
(
!
key_value
->
isa
<
StringImm
>
())
{
MS_LOG
(
EXCEPTION
)
<<
op_name
<<
" evaluator key should be string, but got "
<<
key_value
->
ToString
();
}
std
::
string
key_input
=
GetValue
<
std
::
string
>
(
key_value
);
std
::
string
key_actual
=
kwarg
->
get_key
();
if
(
key_actual
!=
key_input
)
{
MS_LOG
(
EXCEPTION
)
<<
op_name
<<
" evaluator input key should be same as AbstractKeywordArg' key, but input is "
<<
key_input
<<
", AbstractKeywordArg' key is "
<<
key_actual
;
}
return
kwarg
->
get_arg
();
}
AbstractBasePtr
InferImplMakeSlice
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
)
{
// Inputs: three scalars whose value is an int32 number.
CheckArgsSize
(
primitive
->
name
(),
args_spec_list
,
3
);
size_t
args_size
=
args_spec_list
.
size
();
for
(
size_t
index
=
0
;
index
<
args_size
;
index
++
)
{
MS_EXCEPTION_IF_NULL
(
args_spec_list
[
index
]);
if
(
!
args_spec_list
[
index
]
->
isa
<
AbstractScalar
>
()
&&
!
args_spec_list
[
index
]
->
isa
<
AbstractNone
>
())
{
MS_EXCEPTION
(
TypeError
)
<<
"MakeSlice eval "
<<
index
<<
" parameter is neither AbstractScalar nor AbstractNone."
;
}
if
(
args_spec_list
[
index
]
->
isa
<
AbstractScalar
>
()
&&
!
dyn_cast
<
AbstractScalar
>
(
args_spec_list
[
index
])
->
BuildValue
()
->
isa
<
Int32Imm
>
())
{
MS_EXCEPTION
(
TypeError
)
<<
"MakeSlice eval "
<<
index
<<
" parameter is an AbstractScalar, but is not an int32 number."
;
}
}
// Slice: start, end, step
return
std
::
make_shared
<
AbstractSlice
>
(
args_spec_list
[
0
],
args_spec_list
[
1
],
args_spec_list
[
2
]);
}
template
<
typename
T
>
AbstractBasePtr
InferTupleOrListGetItem
(
const
std
::
string
&
op_name
,
const
AbstractBasePtrList
&
args_spec_list
)
{
// Inputs: a tuple or list and a scalar whose value is an int32 number.
CheckArgsSize
(
op_name
,
args_spec_list
,
2
);
auto
queue
=
CheckArg
<
T
>
(
op_name
,
args_spec_list
,
0
);
AbstractScalarPtr
index
=
CheckArg
<
AbstractScalar
>
(
op_name
,
args_spec_list
,
1
);
ValuePtr
index_value
=
index
->
BuildValue
();
if
(
!
index_value
->
isa
<
Int32Imm
>
())
{
// when index_value is an AnyValue and args_spec_list[0] is a scalar, try to return the type of the first element
// and continue
if
(
dyn_cast
<
AbstractScalar
>
(
queue
->
elements
()[
0
])
!=
nullptr
)
{
return
std
::
make_shared
<
AbstractScalar
>
(
queue
->
elements
()[
0
]
->
BuildType
());
}
MS_EXCEPTION
(
IndexError
)
<<
op_name
<<
" evaluator index should be an int32 number, but got "
<<
index_value
->
ToString
();
}
int
idx_v
=
GetValue
<
int
>
(
index_value
);
std
::
size_t
nelems
=
queue
->
elements
().
size
();
if
(
idx_v
>=
SizeToInt
(
nelems
)
||
idx_v
<
-
SizeToInt
(
nelems
))
{
MS_EXCEPTION
(
IndexError
)
<<
op_name
<<
" evaluator index should be in range[-"
<<
SizeToInt
(
nelems
)
<<
", "
<<
SizeToInt
(
nelems
)
<<
"), but got "
<<
idx_v
<<
"."
;
}
std
::
size_t
uidx_v
=
0
;
if
(
idx_v
>=
0
)
{
uidx_v
=
IntToSize
(
idx_v
);
}
else
{
uidx_v
=
IntToSize
(
idx_v
+
SizeToInt
(
nelems
));
}
return
queue
->
elements
()[
uidx_v
];
}
template
<
typename
T
>
AbstractBasePtr
InferTupleOrListSetItem
(
const
std
::
string
&
op_name
,
const
AbstractBasePtrList
&
args_spec_list
)
{
// Inputs: a tuple or list, a scalar whose value is an int32 number and an object of a subclass of AbstractBase.
CheckArgsSize
(
op_name
,
args_spec_list
,
3
);
auto
queue
=
CheckArg
<
T
>
(
op_name
,
args_spec_list
,
0
);
AbstractScalarPtr
index
=
CheckArg
<
AbstractScalar
>
(
op_name
,
args_spec_list
,
1
);
ValuePtr
index_value
=
index
->
BuildValue
();
if
(
!
index_value
->
isa
<
Int32Imm
>
())
{
MS_EXCEPTION
(
IndexError
)
<<
op_name
<<
" evaluator index should be an int32 number, but got "
<<
index_value
->
ToString
();
}
int
idx_v
=
GetValue
<
int
>
(
index_value
);
if
(
idx_v
<
0
)
{
MS_EXCEPTION
(
IndexError
)
<<
"The index of "
<<
typeid
(
T
).
name
()
<<
" should be positive number, but got "
<<
idx_v
<<
"."
;
}
size_t
uidx_v
=
IntToSize
(
idx_v
);
AbstractBasePtrList
elements
=
queue
->
elements
();
std
::
size_t
nelems
=
elements
.
size
();
if
(
uidx_v
>=
nelems
)
{
MS_EXCEPTION
(
IndexError
)
<<
op_name
<<
" evaluator the index: "
<<
uidx_v
<<
" to set out of range: "
<<
nelems
-
1
<<
"."
;
}
elements
[
uidx_v
]
=
args_spec_list
[
2
];
return
std
::
make_shared
<
T
>
(
elements
);
}
AbstractBasePtr
InferImplTupleGetItem
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
)
{
return
InferTupleOrListGetItem
<
AbstractTuple
>
(
primitive
->
name
(),
args_spec_list
);
}
AbstractBasePtr
InferImplListGetItem
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
)
{
return
InferTupleOrListGetItem
<
AbstractList
>
(
primitive
->
name
(),
args_spec_list
);
}
AbstractBasePtr
InferImplTupleSetItem
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
)
{
return
InferTupleOrListSetItem
<
AbstractTuple
>
(
primitive
->
name
(),
args_spec_list
);
}
AbstractBasePtr
InferImplListSetItem
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
)
{
return
InferTupleOrListSetItem
<
AbstractList
>
(
primitive
->
name
(),
args_spec_list
);
}
AbstractBasePtr
InferImplDictGetItem
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
)
{
// Inputs: a dict and a scalar whose value is a string.
const
std
::
string
op_name
=
primitive
->
name
();
CheckArgsSize
(
op_name
,
args_spec_list
,
2
);
AbstractDictionaryPtr
dict
=
CheckArg
<
AbstractDictionary
>
(
op_name
,
args_spec_list
,
0
);
AbstractScalarPtr
key
=
CheckArg
<
AbstractScalar
>
(
op_name
,
args_spec_list
,
1
);
ValuePtr
key_value
=
key
->
BuildValue
();
if
(
!
key_value
->
isa
<
StringImm
>
())
{
MS_LOG
(
EXCEPTION
)
<<
op_name
<<
" evaluator key should be string, but got "
<<
key_value
->
ToString
();
}
auto
key_str
=
GetValue
<
std
::
string
>
(
key_value
);
std
::
vector
<
AbstractAttribute
>
dict_elems
=
dict
->
elements
();
auto
it
=
std
::
find_if
(
dict_elems
.
begin
(),
dict_elems
.
end
(),
[
key_str
](
const
AbstractAttribute
&
item
)
{
return
item
.
first
==
key_str
;
});
if
(
it
==
dict_elems
.
end
())
{
MS_LOG
(
EXCEPTION
)
<<
"The key "
<<
key_str
<<
" does not exist in the dict:"
<<
args_spec_list
[
0
]
->
ToString
();
}
return
it
->
second
;
}
AbstractBasePtr
InferImplDictSetItem
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
)
{
// Inputs: a dict and a scalar whose value is a string and an object of a subclass of AbstractBase.
const
std
::
string
op_name
=
primitive
->
name
();
CheckArgsSize
(
op_name
,
args_spec_list
,
3
);
AbstractDictionaryPtr
dict
=
CheckArg
<
AbstractDictionary
>
(
op_name
,
args_spec_list
,
0
);
AbstractScalarPtr
key
=
CheckArg
<
AbstractScalar
>
(
op_name
,
args_spec_list
,
1
);
ValuePtr
key_value
=
key
->
BuildValue
();
if
(
!
key_value
->
isa
<
StringImm
>
())
{
MS_LOG
(
EXCEPTION
)
<<
op_name
<<
" evaluator key should be string, but got "
<<
key_value
->
ToString
();
}
std
::
string
key_str
=
GetValue
<
std
::
string
>
(
key_value
);
std
::
vector
<
AbstractAttribute
>
dict_elems
=
dict
->
elements
();
auto
it
=
std
::
find_if
(
dict_elems
.
begin
(),
dict_elems
.
end
(),
[
key_str
](
const
AbstractAttribute
&
item
)
{
return
item
.
first
==
key_str
;
});
MS_EXCEPTION_IF_NULL
(
args_spec_list
[
2
]);
auto
new_ele
=
std
::
make_pair
(
key_str
,
args_spec_list
[
2
]);
if
(
it
!=
dict_elems
.
end
())
{
int
index
=
it
-
dict_elems
.
begin
();
dict_elems
[
IntToSize
(
index
)]
=
new_ele
;
}
else
{
dict_elems
.
push_back
(
new_ele
);
}
return
std
::
make_shared
<
AbstractDictionary
>
(
dict_elems
);
}
AbstractBasePtr
InferImplListAppend
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
)
{
// Inputs: a list and an object of a subclass of AbstractBase.
const
std
::
string
op_name
=
primitive
->
name
();
CheckArgsSize
(
op_name
,
args_spec_list
,
2
);
AbstractListPtr
list
=
CheckArg
<
AbstractList
>
(
op_name
,
args_spec_list
,
0
);
(
void
)
AbstractJoin
(
list
->
elements
());
return
list
;
}
AbstractBasePtr
InferImplTupleLen
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
)
{
return
InferTupleOrListOrDictLen
<
AbstractTuple
>
(
primitive
->
name
(),
args_spec_list
);
}
AbstractBasePtr
InferImplListLen
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
)
{
return
InferTupleOrListOrDictLen
<
AbstractList
>
(
primitive
->
name
(),
args_spec_list
);
}
AbstractBasePtr
InferImplArrayLen
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
,
const
AbstractBasePtrList
&
args_spec_list
)
{
return
std
::
make_shared
<
AbstractScalar
>
(
kAnyValue
,
kInt32
);
}
}
// namespace abstract
}
// namespace mindspore
mindspore/core/abstract/primitive_infer_map.cc
0 → 100644
浏览文件 @
1f5441d7
/**
* This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
*
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "abstract/primitive_infer_map.h"
#include "abstract/abstract_function.h"
#include "abstract/infer_functions.h"
namespace
mindspore
{
namespace
abstract
{
PrimitiveEvalImplMap
&
GetPrimitiveToEvalImplMap
()
{
static
PrimitiveEvalImplMap
prim_eval_implement_map
=
{
// Statements
{
prim
::
kPrimReturn
,
{
InferImplReturn
,
true
}},
{
prim
::
kPrimDot
,
{
InferImplDot
,
true
}},
{
prim
::
kPrimSwitch
,
{
InferImplSwitch
,
true
}},
{
prim
::
kPrimSwitchLayer
,
{
InferImplSwitchLayer
,
true
}},
{
prim
::
kPrimIs_
,
{
InferImplIs_
,
true
}},
{
prim
::
kPrimIsNot
,
{
InferImplIsNot
,
true
}},
{
prim
::
kPrimInDict
,
{
InferImplInDict
,
true
}},
{
prim
::
kPrimNotInDict
,
{
InferImplNotInDict
,
true
}},
{
prim
::
kPrimIsConsant
,
{
InferImplIsConstant
,
true
}},
// Maths
{
prim
::
kPrimMaximumGrad
,
{
InferImplMinOrMaxGrad
,
true
}},
{
prim
::
kPrimMinimumGrad
,
{
InferImplMinOrMaxGrad
,
true
}},
// Array
{
prim
::
kPrimScalarToArray
,
{
InferImplScalarToArray
,
true
}},
{
prim
::
kPrimArrayToScalar
,
{
InferImplArrayToScalar
,
true
}},
{
prim
::
kPrimBroadcastShape
,
{
InferImplBroadCastShape
,
true
}},
{
prim
::
kPrimPack
,
{
InferImplPack
,
true
}},
{
prim
::
kPrimUnique
,
{
InferImplUnique
,
true
}},
{
prim
::
kPrimUniqueGrad
,
{
InferImplUniqueGrad
,
true
}},
// Structure
{
prim
::
kPrimMakeTuple
,
{
InferImplMakeTuple
,
true
}},
{
prim
::
kPrimMakeList
,
{
InferImplMakeList
,
true
}},
{
prim
::
kPrimMakeDict
,
{
InferImplMakeDict
,
true
}},
{
prim
::
kPrimMakeSlice
,
{
InferImplMakeSlice
,
true
}},
{
prim
::
kPrimMakeKeywordArg
,
{
InferImplMakeKwarg
,
true
}},
{
prim
::
kPrimExtractKeywordArg
,
{
InferImplExtractKwarg
,
true
}},
{
prim
::
kPrimTupleGetItem
,
{
InferImplTupleGetItem
,
true
}},
{
prim
::
kPrimListGetItem
,
{
InferImplListGetItem
,
true
}},
{
prim
::
kPrimTupleSetItem
,
{
InferImplTupleSetItem
,
true
}},
{
prim
::
kPrimListSetItem
,
{
InferImplListSetItem
,
true
}},
{
prim
::
kPrimDictGetItem
,
{
InferImplDictGetItem
,
true
}},
{
prim
::
kPrimDictSetItem
,
{
InferImplDictSetItem
,
true
}},
{
prim
::
kPrimListAppend
,
{
InferImplListAppend
,
true
}},
{
prim
::
kPrimTupleLen
,
{
InferImplTupleLen
,
true
}},
{
prim
::
kPrimListLen
,
{
InferImplListLen
,
true
}},
{
prim
::
kPrimArrayLen
,
{
InferImplArrayLen
,
true
}},
// NN
{
prim
::
kPrimPooling
,
{
InferImplPooling
,
true
}},
{
prim
::
kPrimPoolingGrad
,
{
InferImplPoolingGrad
,
true
}},
{
prim
::
kPrimFusedBatchNorm
,
{
InferImplFusedBatchNorm
,
true
}},
{
prim
::
kPrimFusedBatchNormGrad
,
{
InferImplFusedBatchNormGrad
,
true
}},
{
prim
::
kPrimReluGrad
,
{
InferImplReluGrad
,
true
}},
{
prim
::
kPrimConv2DBackpropInput
,
{
InferImplConv2DBackpropInput
,
true
}},
{
prim
::
kPrimConv2DBackpropFilter
,
{
InferImplConv2DBackpropFilter
,
true
}},
{
prim
::
kPrimBiasAddGrad
,
{
InferImplBiasAddGrad
,
true
}},
{
prim
::
kPrimRelu
,
{
InferImplRelu
,
true
}},
{
prim
::
kPrimZerosLike
,
{
InferImplZerosLike
,
true
}},
{
prim
::
kPrimBpropCut
,
{
InferImplBpropCut
,
true
}},
{
prim
::
kPrimLayerNorm
,
{
InferImplLayerNorm
,
true
}},
{
prim
::
kPrimLayerNormGrad
,
{
InferImplLayerNormGrad
,
true
}},
{
prim
::
kPrimDropoutGenMask
,
{
InferImplDropoutGenMask
,
true
}},
// Others
{
prim
::
kPrimIdentity
,
{
InferImplIdentity
,
true
}},
// Set impl to null as it will use PartialEvaluator;
{
prim
::
kPrimPartial
,
{
nullptr
,
true
}},
{
prim
::
kPrimEnvGetItem
,
{
InferImplEnvGetItem
,
true
}},
{
prim
::
kPrimEnvSetItem
,
{
InferImplEnvSetItem
,
true
}},
{
prim
::
kPrimEnvAdd
,
{
InferImplEnvAdd
,
true
}},
{
prim
::
kPrimMakeRefKey
,
{
InferImplMakeRefKey
,
true
}},
{
prim
::
kPrimMakeRef
,
{
InferImplMakeRef
,
true
}},
{
prim
::
kPrimGetRefKey
,
{
InferImplGetRefKey
,
true
}},
{
prim
::
kPrimGetRefValue
,
{
InferImplGetRefValue
,
true
}},
{
prim
::
kPrimStateSetItem
,
{
InferImplStateSetItem
,
true
}},
{
prim
::
kPrimDepend
,
{
InferImplDepend
,
true
}},
{
prim
::
kPrimControlDepend
,
{
InferImplControlDepend
,
true
}},
// Debug
{
prim
::
kPrimDebug
,
{
InferImplDebug
,
true
}},
// SparseTensor
{
prim
::
kPrimMakeSparseTensor
,
{
InferImplMakeSparseTensor
,
true
}},
{
prim
::
kPrimSparseTensorGetValues
,
{
InferImplSparseTensorGetValues
,
true
}},
{
prim
::
kPrimSparseTensorGetIndices
,
{
InferImplSparseTensorGetIndices
,
true
}},
{
prim
::
kPrimSparseTensorGetDenseShape
,
{
InferImplSparseTensorGetDenseShape
,
true
}},
// RowTensor
{
prim
::
kPrimMakeRowTensor
,
{
InferImplMakeRowTensor
,
true
}},
{
prim
::
kPrimRowTensorGetValues
,
{
InferImplRowTensorGetValues
,
true
}},
{
prim
::
kPrimRowTensorGetIndices
,
{
InferImplRowTensorGetIndices
,
true
}},
{
prim
::
kPrimRowTensorGetDenseShape
,
{
InferImplRowTensorGetDenseShape
,
true
}},
};
return
prim_eval_implement_map
;
}
void
RegisterStandardPrimitiveImpl
(
const
PrimitivePtr
&
primitive
,
const
StandardPrimitiveImplReg
&
impl_reg
)
{
auto
&
prim_eval_map
=
GetPrimitiveToEvalImplMap
();
prim_eval_map
[
primitive
]
=
impl_reg
;
}
}
// namespace abstract
}
// namespace mindspore
mindspore/core/abstract/primitive_infer_map.h
0 → 100644
浏览文件 @
1f5441d7
/**
* This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
*
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_ABSTRACT_PRIMITIVE_INFER_MAP_H_
#define MINDSPORE_CORE_ABSTRACT_PRIMITIVE_INFER_MAP_H_
#include <unordered_map>
#include "ir/primitive.h"
#include "base/core_ops.h"
#include "abstract/abstract_value.h"
namespace
mindspore
{
namespace
abstract
{
using
StandardPrimitiveEvalImpl
=
AbstractBasePtr
(
*
)(
const
abstract
::
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
,
const
AbstractBasePtrList
&
);
struct
StandardPrimitiveImplReg
{
StandardPrimitiveEvalImpl
impl_
;
// Implement function of Primitive.
bool
in_white_list_
;
// true if this Primitive in white list, else false.
};
using
PrimitiveEvalImplMap
=
std
::
unordered_map
<
PrimitivePtr
,
StandardPrimitiveImplReg
,
PrimitiveHasher
,
PrimitiveEqual
>
;
PrimitiveEvalImplMap
&
GetPrimitiveToEvalImplMap
();
void
RegisterStandardPrimitiveImpl
(
const
PrimitivePtr
&
primitive
,
const
StandardPrimitiveImplReg
&
impl_reg
);
class
RegisterStandardPrimitiveEvalHelper
{
public:
RegisterStandardPrimitiveEvalHelper
(
const
PrimitivePtr
&
primitive
,
const
StandardPrimitiveEvalImpl
&
impl
)
{
const
StandardPrimitiveImplReg
impl_reg
{
impl
,
true
};
RegisterStandardPrimitiveImpl
(
primitive
,
impl_reg
);
}
~
RegisterStandardPrimitiveEvalHelper
()
=
default
;
};
#define REGISTER_PRIMITIVE_EVAL_IMPL(name, primitive, impl) \
static auto helper_##name = RegisterStandardPrimitiveEvalHelper(primitive, impl)
}
// namespace abstract
}
// namespace mindspore
#endif // MINDSPORE_CORE_ABSTRACT_PRIMITIVE_INFER_MAP_H_
mindspore/core/base/core_ops.h
浏览文件 @
1f5441d7
...
...
@@ -246,6 +246,25 @@ inline const PrimitivePtr kPrimNotInDict = std::make_shared<Primitive>("not_in_d
inline
const
PrimitivePtr
kPrimIsConsant
=
std
::
make_shared
<
Primitive
>
(
"is_constant"
);
inline
const
PrimitivePtr
kPrimEquivFormat
=
std
::
make_shared
<
Primitive
>
(
"EquivFormat"
);
// Structures
inline
const
PrimitivePtr
kPrimMakeList
=
std
::
make_shared
<
Primitive
>
(
"make_list"
);
inline
const
PrimitivePtr
kPrimMakeKeywordArg
=
std
::
make_shared
<
Primitive
>
(
"make_keyword_arg"
);
inline
const
PrimitivePtr
kPrimListGetItem
=
std
::
make_shared
<
Primitive
>
(
"list_getitem"
);
inline
const
PrimitivePtr
kPrimListSetItem
=
std
::
make_shared
<
Primitive
>
(
"list_setitem"
);
inline
const
PrimitivePtr
kPrimDictGetItem
=
std
::
make_shared
<
Primitive
>
(
"dict_getitem"
);
inline
const
PrimitivePtr
kPrimDictSetItem
=
std
::
make_shared
<
Primitive
>
(
"dict_setitem"
);
inline
const
PrimitivePtr
kPrimListAppend
=
std
::
make_shared
<
Primitive
>
(
"list_append"
);
inline
const
PrimitivePtr
kPrimListLen
=
std
::
make_shared
<
Primitive
>
(
"list_len"
);
// Other miscellaneous
inline
const
PrimitivePtr
kPrimEnvSetItem
=
std
::
make_shared
<
Primitive
>
(
"env_setitem"
);
inline
const
PrimitivePtr
kPrimEnvGetItem
=
std
::
make_shared
<
Primitive
>
(
"env_getitem"
);
inline
const
PrimitivePtr
kPrimEnvAdd
=
std
::
make_shared
<
Primitive
>
(
"env_add"
);
inline
const
PrimitivePtr
kPrimMakeRefKey
=
std
::
make_shared
<
Primitive
>
(
"MakeRefKey"
);
inline
const
PrimitivePtr
kPrimGetRefKey
=
std
::
make_shared
<
Primitive
>
(
"get_ref_key"
);
inline
const
PrimitivePtr
kPrimMakeRef
=
std
::
make_shared
<
Primitive
>
(
"make_ref"
);
inline
const
PrimitivePtr
kPrimGetRefValue
=
std
::
make_shared
<
Primitive
>
(
"get_ref_value"
);
// Other primitve not used by backend but used in core;
inline
const
PrimitivePtr
kPrimStateSetItem
=
std
::
make_shared
<
Primitive
>
(
"state_setitem"
);
inline
const
PrimitivePtr
kPrimJ
=
std
::
make_shared
<
Primitive
>
(
"J"
);
...
...
mindspore/core/c_ops/conv2d.cc
浏览文件 @
1f5441d7
...
...
@@ -26,87 +26,19 @@
namespace
mindspore
{
namespace
{
using
PrimConv2dPtr
=
std
::
shared_ptr
<
Conv2d
>
;
abstract
::
ShapePtr
InferShape
(
const
PrimitivePtr
&
primitive
,
const
std
::
vector
<
AbstractBasePtr
>
&
input_args
)
{
MS_EXCEPTION_IF_NULL
(
primitive
);
auto
conv_prim
=
primitive
->
cast
<
PrimConv2dPtr
>
();
MS_EXCEPTION_IF_NULL
(
conv_prim
);
auto
prim_name
=
conv_prim
->
name
();
CheckAndConvertUtils
::
CheckInRange
(
"Conv2d Infer"
,
input_args
.
size
(),
kIncludeLeft
,
{
2
,
3
},
prim_name
);
auto
w_shape
=
CheckAndConvertUtils
::
ConvertShapePtrToShape
(
"w_shape"
,
input_args
[
0
]
->
GetShapeTrack
(),
prim_name
);
auto
x_shape
=
CheckAndConvertUtils
::
ConvertShapePtrToShape
(
"x_shape"
,
input_args
[
1
]
->
GetShapeTrack
(),
prim_name
);
CheckAndConvertUtils
::
CheckInteger
(
"weight rank"
,
w_shape
.
size
(),
kEqual
,
4
,
prim_name
);
CheckAndConvertUtils
::
CheckInteger
(
"x rank"
,
x_shape
.
size
(),
kEqual
,
4
,
prim_name
);
CheckAndConvertUtils
::
Check
(
"x_shape[1] / group"
,
x_shape
[
1
]
/
conv_prim
->
GetGroup
(),
kEqual
,
"w_shape[1]"
,
w_shape
[
1
],
conv_prim
->
name
());
auto
out_channel
=
conv_prim
->
GetOutputChannel
();
CheckAndConvertUtils
::
Check
(
"out_channel"
,
out_channel
,
kEqual
,
"w_shape[0]"
,
w_shape
[
0
],
conv_prim
->
name
());
std
::
vector
<
int
>
temp_w
;
std
::
copy
(
w_shape
.
begin
()
+
2
,
w_shape
.
end
(),
std
::
back_inserter
(
temp_w
));
CheckAndConvertUtils
::
Check
(
"kernel_size"
,
conv_prim
->
GetKernelSize
(),
kEqual
,
"w_shape[2:4]"
,
temp_w
,
conv_prim
->
name
());
auto
kernel_size_h
=
w_shape
[
2
];
auto
kernel_size_w
=
w_shape
[
3
];
auto
stride
=
conv_prim
->
GetStride
();
auto
dilation
=
conv_prim
->
GetDilation
();
auto
stride_h
=
stride
[
2
];
auto
stride_w
=
stride
[
3
];
auto
dilation_h
=
dilation
[
2
];
auto
dilation_w
=
dilation
[
3
];
int
h_out
=
-
1
;
int
w_out
=
-
1
;
std
::
vector
<
int
>
pad_list
(
4
,
0
);
auto
pad_mode
=
conv_prim
->
GetPadMode
();
if
(
pad_mode
==
"valid"
)
{
h_out
=
ceil
((
x_shape
[
2
]
-
dilation_h
*
(
kernel_size_h
-
1
))
/
stride_h
);
w_out
=
ceil
((
x_shape
[
3
]
-
dilation_w
*
(
kernel_size_w
-
1
))
/
stride_w
);
}
else
if
(
pad_mode
==
"same"
)
{
h_out
=
ceil
(
x_shape
[
2
]
/
stride_h
);
w_out
=
ceil
(
x_shape
[
3
]
/
stride_w
);
auto
pad_needed_h
=
std
::
max
(
0
,
(
h_out
-
1
)
*
stride_h
+
dilation_h
*
(
kernel_size_h
-
1
)
+
1
-
x_shape
[
2
]);
pad_list
.
emplace_back
(
floor
(
pad_needed_h
/
2
));
pad_list
.
emplace_back
(
pad_needed_h
/
2
);
auto
pad_needed_w
=
std
::
max
(
0
,
(
w_out
-
1
)
*
stride_w
+
dilation_w
*
(
kernel_size_w
-
1
)
+
1
-
x_shape
[
3
]);
auto
pad_left
=
floor
(
pad_needed_w
/
2
);
pad_list
.
emplace_back
(
pad_left
);
pad_list
.
emplace_back
(
pad_needed_h
-
pad_left
);
}
else
if
(
pad_mode
==
"pad"
)
{
std
::
copy
(
conv_prim
->
GetPad
().
begin
(),
conv_prim
->
GetPad
().
end
(),
std
::
back_inserter
(
pad_list
));
auto
pad_top
=
conv_prim
->
GetPad
()[
0
];
auto
pad_bottom
=
conv_prim
->
GetPad
()[
1
];
auto
pad_right
=
conv_prim
->
GetPad
()[
2
];
auto
pad_left
=
conv_prim
->
GetPad
()[
3
];
h_out
=
1
+
(
x_shape
[
2
]
+
pad_top
+
pad_bottom
-
kernel_size_h
-
(
kernel_size_h
-
1
)
*
(
dilation_h
-
1
))
/
stride_h
;
w_out
=
1
+
(
x_shape
[
3
]
+
pad_left
+
pad_right
-
kernel_size_w
-
(
kernel_size_w
-
1
)
*
(
dilation_w
-
1
))
/
stride_w
;
h_out
=
floor
(
h_out
);
w_out
=
floor
(
w_out
);
}
conv_prim
->
SetPadList
(
pad_list
);
std
::
vector
<
int
>
out_shape
=
{
x_shape
[
0
],
out_channel
,
h_out
,
w_out
};
return
std
::
make_shared
<
abstract
::
Shape
>
(
out_shape
);
}
TypePtr
InferType
(
const
PrimitivePtr
&
prim
,
const
std
::
vector
<
AbstractBasePtr
>
&
input_args
)
{
CheckAndConvertUtils
::
CheckInRange
(
""
,
input_args
.
size
(),
kIncludeLeft
,
{
2
,
3
},
prim
->
name
());
for
(
const
auto
&
item
:
input_args
)
{
MS_EXCEPTION_IF_NULL
(
item
);
}
auto
x_type
=
CheckAndConvertUtils
::
ConvertTypePtrToTypeId
(
"x_dtype"
,
input_args
[
0
]
->
GetTypeTrack
(),
prim
->
name
());
const
std
::
set
<
TypeId
>
valid_types
=
{
kNumberTypeInt8
,
kNumberTypeInt32
,
kNumberTypeFloat16
,
kNumberTypeFloat32
};
std
::
map
<
std
::
string
,
TypePtr
>
types
;
types
.
emplace
(
"x"
,
input_args
[
0
]
->
GetTypeTrack
());
types
.
emplace
(
"w"
,
input_args
[
1
]
->
GetTypeTrack
());
CheckAndConvertUtils
::
CheckTensorTypeSame
(
types
,
valid_types
,
prim
->
name
());
if
(
x_type
==
kNumberTypeInt8
)
{
return
std
::
make_shared
<
TensorType
>
(
TypeIdToType
(
kNumberTypeInt32
));
}
return
std
::
make_shared
<
TensorType
>
(
TypeIdToType
(
x_type
));
}
constexpr
auto
kKernelSize
=
"kernel_size"
;
constexpr
auto
kStride
=
"stride"
;
constexpr
auto
kDilation
=
"dilation"
;
constexpr
auto
kPadMode
=
"pad_mode"
;
constexpr
auto
kPad
=
"pad"
;
constexpr
auto
kMode
=
"mode"
;
constexpr
auto
kGroup
=
"group"
;
constexpr
auto
kOutputChannel
=
"output channel"
;
constexpr
auto
kPadList
=
"pad_list"
;
constexpr
auto
kConv2DName
=
"Conv2D"
;
}
// namespace
Conv2d
::
Conv2d
()
:
PrimitiveC
(
kConv2DName
)
{
InitIOName
({
"x"
,
"w"
},
{
"output"
});
}
void
Conv2d
::
Init
(
int
out_channel
,
const
std
::
vector
<
int
>
&
kernel_size
,
int
mode
,
const
std
::
string
&
pad_mode
,
const
std
::
vector
<
int
>
&
pad
,
const
std
::
vector
<
int
>
&
stride
,
const
std
::
vector
<
int
>
&
dilation
,
int
group
)
{
...
...
@@ -130,10 +62,47 @@ void Conv2d::Init(int out_channel, const std::vector<int> &kernel_size, int mode
this
->
SetOutChannel
(
CheckAndConvertUtils
::
CheckInteger
(
"out_channel"
,
out_channel
,
kGreaterThan
,
0
,
prim_name
));
this
->
SetGroup
(
CheckAndConvertUtils
::
CheckInteger
(
"group"
,
group
,
kGreaterThan
,
0
,
prim_name
));
}
std
::
vector
<
int
>
Conv2d
::
GetKernelSize
()
const
{
auto
value_ptr
=
GetAttr
(
kKernelSize
);
return
GetValue
<
std
::
vector
<
int
>>
(
value_ptr
);
}
std
::
vector
<
int
>
Conv2d
::
GetStride
()
const
{
auto
value_ptr
=
GetAttr
(
kStride
);
return
GetValue
<
std
::
vector
<
int
>>
(
value_ptr
);
}
std
::
vector
<
int
>
Conv2d
::
GetDilation
()
const
{
auto
value_ptr
=
GetAttr
(
kDilation
);
return
GetValue
<
std
::
vector
<
int
>>
(
value_ptr
);
}
std
::
string
Conv2d
::
GetPadMode
()
const
{
auto
value_ptr
=
this
->
GetAttr
(
kPadMode
);
return
GetValue
<
string
>
(
value_ptr
);
}
std
::
vector
<
int
>
Conv2d
::
GetPad
()
const
{
auto
value_ptr
=
this
->
GetAttr
(
kPad
);
return
GetValue
<
std
::
vector
<
int
>>
(
value_ptr
);
}
int
Conv2d
::
GetMode
()
const
{
auto
value_ptr
=
this
->
GetAttr
(
kMode
);
return
GetValue
<
int
>
(
value_ptr
);
}
AbstractBasePtr
Conv2dInfer
(
const
abstract
::
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
std
::
vector
<
AbstractBasePtr
>
&
input_args
)
{
return
std
::
make_shared
<
abstract
::
AbstractTensor
>
(
InferType
(
primitive
,
input_args
),
InferShape
(
primitive
,
input_args
)
->
shape
());
int
Conv2d
::
GetGroup
()
const
{
auto
value_ptr
=
this
->
GetAttr
(
kGroup
);
return
GetValue
<
int
>
(
value_ptr
);
}
int
Conv2d
::
GetOutputChannel
()
const
{
auto
value_ptr
=
this
->
GetAttr
(
kOutputChannel
);
return
GetValue
<
int
>
(
value_ptr
);
}
void
Conv2d
::
SetKernelSize
(
const
std
::
vector
<
int
>
&
kernel_size
)
{
this
->
AddAttr
(
kKernelSize
,
MakeValue
(
kernel_size
));
}
void
Conv2d
::
SetStride
(
const
std
::
vector
<
int
>
&
stride
)
{
this
->
AddAttr
(
kStride
,
MakeValue
(
stride
));
}
void
Conv2d
::
SetDilation
(
const
std
::
vector
<
int
>
&
dilation
)
{
this
->
AddAttr
(
kDilation
,
MakeValue
(
dilation
));
}
void
Conv2d
::
SetPadMode
(
const
std
::
string
&
pad_mode
)
{
this
->
AddAttr
(
kPadMode
,
MakeValue
(
pad_mode
));
}
void
Conv2d
::
SetPad
(
const
std
::
vector
<
int
>
&
pad
)
{
this
->
AddAttr
(
kPad
,
MakeValue
(
pad
));
}
void
Conv2d
::
SetMode
(
int
mode
)
{
this
->
AddAttr
(
kMode
,
MakeValue
(
mode
));
}
void
Conv2d
::
SetGroup
(
int
group
)
{
this
->
AddAttr
(
kGroup
,
MakeValue
(
group
));
}
void
Conv2d
::
SetOutChannel
(
int
output_channel
)
{
this
->
AddAttr
(
kOutputChannel
,
MakeValue
(
output_channel
));
}
void
Conv2d
::
SetPadList
(
const
std
::
vector
<
int
>
&
pad_list
)
{
this
->
AddAttr
(
kPadList
,
MakeValue
(
pad_list
));
}
}
// namespace mindspore
mindspore/core/c_ops/conv2d.h
浏览文件 @
1f5441d7
...
...
@@ -16,79 +16,44 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_C_OPS_CONV2D_H
#define MINDSPORE_CORE_C_OPS_CONV2D_H
#ifndef MINDSPORE_CORE_C_OPS_CONV2D_H
_
#define MINDSPORE_CORE_C_OPS_CONV2D_H
_
#include <map>
#include <vector>
#include <string>
#include <memory>
#include "c_ops/primitive_c.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"
namespace
mindspore
{
class
Conv2d
:
public
PrimitiveC
{
public:
Conv2d
()
:
PrimitiveC
(
kConv2DName
)
{
InitIOName
({
"x"
,
"w"
},
{
"output"
});
}
Conv2d
()
;
void
Init
(
int
out_channel
,
const
std
::
vector
<
int
>
&
kernel_size
,
int
mode
=
1
,
const
std
::
string
&
pad_mode
=
"valid"
,
const
std
::
vector
<
int
>
&
pad
=
{
0
,
0
,
0
,
0
},
const
std
::
vector
<
int
>
&
stride
=
{
1
,
1
,
1
,
1
},
const
std
::
vector
<
int
>
&
dilation
=
{
1
,
1
,
1
,
1
},
int
group
=
1
);
std
::
vector
<
int
>
GetKernelSize
()
const
{
auto
value_ptr
=
this
->
GetAttr
(
kKernelSize
);
return
GetValue
<
std
::
vector
<
int
>>
(
value_ptr
);
}
std
::
vector
<
int
>
GetStride
()
const
{
auto
value_ptr
=
GetAttr
(
kStride
);
return
GetValue
<
std
::
vector
<
int
>>
(
value_ptr
);
}
std
::
vector
<
int
>
GetDilation
()
const
{
auto
value_ptr
=
GetAttr
(
kDilation
);
return
GetValue
<
std
::
vector
<
int
>>
(
value_ptr
);
}
std
::
string
GetPadMode
()
const
{
auto
value_ptr
=
this
->
GetAttr
(
kPadMode
);
return
GetValue
<
string
>
(
value_ptr
);
}
std
::
vector
<
int
>
GetPad
()
const
{
auto
value_ptr
=
this
->
GetAttr
(
kPad
);
return
GetValue
<
std
::
vector
<
int
>>
(
value_ptr
);
}
int
GetMode
()
const
{
auto
value_ptr
=
this
->
GetAttr
(
kMode
);
return
GetValue
<
int
>
(
value_ptr
);
}
int
GetGroup
()
const
{
auto
value_ptr
=
this
->
GetAttr
(
kGroup
);
return
GetValue
<
int
>
(
value_ptr
);
}
int
GetOutputChannel
()
const
{
auto
value_ptr
=
this
->
GetAttr
(
kOutputChannel
);
return
GetValue
<
int
>
(
value_ptr
);
}
void
SetKernelSize
(
const
std
::
vector
<
int
>
&
kernel_size
)
{
this
->
AddAttr
(
kKernelSize
,
MakeValue
(
kernel_size
));
}
void
SetStride
(
const
std
::
vector
<
int
>
&
stride
)
{
this
->
AddAttr
(
kStride
,
MakeValue
(
stride
));
}
void
SetDilation
(
const
std
::
vector
<
int
>
&
dilation
)
{
this
->
AddAttr
(
kDilation
,
MakeValue
(
dilation
));
}
void
SetPadMode
(
const
std
::
string
&
pad_mode
)
{
this
->
AddAttr
(
kPadMode
,
MakeValue
(
pad_mode
));
}
void
SetPad
(
const
std
::
vector
<
int
>
&
pad
)
{
this
->
AddAttr
(
kPad
,
MakeValue
(
pad
));
}
void
SetMode
(
int
mode
)
{
this
->
AddAttr
(
kMode
,
MakeValue
(
mode
));
}
void
SetGroup
(
int
group
)
{
this
->
AddAttr
(
kGroup
,
MakeValue
(
group
));
}
void
SetOutChannel
(
int
output_channel
)
{
this
->
AddAttr
(
kOutputChannel
,
MakeValue
(
output_channel
));
}
void
SetPadList
(
const
std
::
vector
<
int
>
&
pad_list
)
{
this
->
AddAttr
(
kPadList
,
MakeValue
(
pad_list
));
}
private:
inline
static
const
string
kKernelSize
=
"kernel_size"
;
inline
static
const
string
kStride
=
"stride"
;
inline
static
const
string
kDilation
=
"dilation"
;
inline
static
const
string
kPadMode
=
"pad_mode"
;
inline
static
const
string
kPad
=
"pad"
;
inline
static
const
string
kMode
=
"mode"
;
inline
static
const
string
kGroup
=
"group"
;
inline
static
const
string
kOutputChannel
=
"output channel"
;
inline
static
const
string
kPadList
=
"pad_list"
;
inline
static
const
string
kConv2DName
=
"Conv2D"
;
std
::
vector
<
int
>
GetKernelSize
()
const
;
std
::
vector
<
int
>
GetStride
()
const
;
std
::
vector
<
int
>
GetDilation
()
const
;
std
::
string
GetPadMode
()
const
;
std
::
vector
<
int
>
GetPad
()
const
;
int
GetMode
()
const
;
int
GetGroup
()
const
;
int
GetOutputChannel
()
const
;
void
SetKernelSize
(
const
std
::
vector
<
int
>
&
kernel_size
);
void
SetStride
(
const
std
::
vector
<
int
>
&
stride
);
void
SetDilation
(
const
std
::
vector
<
int
>
&
dilation
);
void
SetPadMode
(
const
std
::
string
&
pad_mode
);
void
SetPad
(
const
std
::
vector
<
int
>
&
pad
);
void
SetMode
(
int
mode
);
void
SetGroup
(
int
group
);
void
SetOutChannel
(
int
output_channel
);
void
SetPadList
(
const
std
::
vector
<
int
>
&
pad_list
);
};
AbstractBasePtr
Conv2dInfer
(
const
abstract
::
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
std
::
vector
<
AbstractBasePtr
>
&
input_args
);
using
PrimConv2dPtr
=
std
::
shared_ptr
<
Conv2d
>
;
}
// namespace mindspore
#endif // MINDSPORE_CORE_C_OPS_CONV2D_H
#endif // MINDSPORE_CORE_C_OPS_CONV2D_H
_
mindspore/core/c_ops/primitive_c.h
浏览文件 @
1f5441d7
...
...
@@ -16,8 +16,8 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_C_OPS_PRIMITIVE_C_H
#define MINDSPORE_CORE_C_OPS_PRIMITIVE_C_H
#ifndef MINDSPORE_CORE_C_OPS_PRIMITIVE_C_H
_
#define MINDSPORE_CORE_C_OPS_PRIMITIVE_C_H
_
#include <string>
#include <vector>
#include "ir/primitive.h"
...
...
@@ -25,7 +25,7 @@
namespace
mindspore
{
class
PrimitiveC
:
public
Primitive
{
public:
explicit
PrimitiveC
(
const
std
::
string
&
name
)
:
Primitive
(
name
)
{
attrs_
=
{};
}
explicit
PrimitiveC
(
const
std
::
string
&
name
)
:
Primitive
(
name
)
{}
protected:
void
InitIOName
(
const
std
::
vector
<
std
::
string
>
&
inputs_name
,
const
std
::
vector
<
std
::
string
>
&
outputs_name
)
{
...
...
@@ -34,4 +34,4 @@ class PrimitiveC : public Primitive {
}
};
}
// namespace mindspore
#endif // MINDSPORE_CORE_C_OPS_PRIMITIVE_C_H
#endif // MINDSPORE_CORE_C_OPS_PRIMITIVE_C_H
_
tests/ut/cpp/CMakeLists.txt
浏览文件 @
1f5441d7
...
...
@@ -61,6 +61,7 @@ file(GLOB_RECURSE MINDSPORE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
"../../../mindspore/core/abstract/*.cc"
"../../../mindspore/core/ir/*.cc"
"../../../mindspore/core/utils/*.cc"
"../../../mindspore/core/c_ops/*.cc"
"../../../mindspore/ccsrc/common/*.cc"
"../../../mindspore/ccsrc/utils/*.cc"
"../../../mindspore/ccsrc/pipeline/jit/parse/*.cc"
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录