Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
ae641e14
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
ae641e14
编写于
8月 07, 2020
作者:
W
WilliamLian
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
unify primitive : add conv2d's primitive C & checkfunction
上级
eb4bee3c
变更
9
隐藏空白更改
内联
并排
Showing
9 changed file
with
632 addition
and
1 deletion
+632
-1
mindspore/ccsrc/CMakeLists.txt
mindspore/ccsrc/CMakeLists.txt
+3
-1
mindspore/core/c_ops/CMakeLists.txt
mindspore/core/c_ops/CMakeLists.txt
+2
-0
mindspore/core/c_ops/conv2d.cc
mindspore/core/c_ops/conv2d.cc
+139
-0
mindspore/core/c_ops/conv2d.h
mindspore/core/c_ops/conv2d.h
+94
-0
mindspore/core/c_ops/primitive_c.h
mindspore/core/c_ops/primitive_c.h
+37
-0
mindspore/core/ir/func_graph.cc
mindspore/core/ir/func_graph.cc
+13
-0
mindspore/core/ir/func_graph.h
mindspore/core/ir/func_graph.h
+2
-0
mindspore/core/utils/check_convert_utils.cc
mindspore/core/utils/check_convert_utils.cc
+270
-0
mindspore/core/utils/check_convert_utils.h
mindspore/core/utils/check_convert_utils.h
+72
-0
未找到文件。
mindspore/ccsrc/CMakeLists.txt
浏览文件 @
ae641e14
...
...
@@ -149,7 +149,9 @@ add_subdirectory(${CMAKE_SOURCE_DIR}/mindspore/core/utils util)
list
(
APPEND SUB_OBJECTS_SRC $<TARGET_OBJECTS:_mindspore_core_utils_obj>
)
add_subdirectory
(
${
CMAKE_SOURCE_DIR
}
/mindspore/core/ir ir
)
list
(
APPEND SUB_OBJECTS_SRC $<TARGET_OBJECTS:_mindspore_ir_obj>
)
add_dependencies
(
_mindspore_core_utils_obj _mindspore_base_obj _mindspore_ir_obj _mindspore_abstract_obj proto_input
)
add_subdirectory
(
${
CMAKE_SOURCE_DIR
}
/mindspore/core/c_ops c_ops
)
list
(
APPEND SUB_OBJECTS_SRC $<TARGET_OBJECTS:_mindspore_c_ops_obj>
)
add_dependencies
(
_mindspore_core_utils_obj _mindspore_base_obj _mindspore_ir_obj _mindspore_abstract_obj _mindspore_c_ops_obj proto_input
)
set_property
(
SOURCE
${
SUB_OBJECTS_SRC
}
PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_ME
)
add_library
(
mindspore STATIC
${
SUB_OBJECTS_SRC
}
)
...
...
mindspore/core/c_ops/CMakeLists.txt
0 → 100644
浏览文件 @
ae641e14
file
(
GLOB_RECURSE _C_OPS_ALL_SRC_FILES RELATIVE
${
CMAKE_CURRENT_SOURCE_DIR
}
"*.cc"
)
add_library
(
_mindspore_c_ops_obj OBJECT
${
_C_OPS_ALL_SRC_FILES
}
)
mindspore/core/c_ops/conv2d.cc
0 → 100644
浏览文件 @
ae641e14
/**
* This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
*
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "c_ops/conv2d.h"
#include <string>
#include <algorithm>
#include <memory>
#include <set>
#include <vector>
#include "utils/check_convert_utils.h"
namespace
mindspore
{
namespace
{
using
PrimConv2dPtr
=
std
::
shared_ptr
<
Conv2d
>
;
abstract
::
ShapePtr
InferShape
(
const
PrimitivePtr
&
primitive
,
const
std
::
vector
<
AbstractBasePtr
>
&
input_args
)
{
MS_EXCEPTION_IF_NULL
(
primitive
);
auto
conv_prim
=
primitive
->
cast
<
PrimConv2dPtr
>
();
MS_EXCEPTION_IF_NULL
(
conv_prim
);
auto
prim_name
=
conv_prim
->
name
();
CheckAndConvertUtils
::
CheckInRange
(
"Conv2d Infer"
,
input_args
.
size
(),
kIncludeLeft
,
{
2
,
3
},
prim_name
);
auto
w_shape
=
CheckAndConvertUtils
::
ConvertShapePtrToShape
(
"w_shape"
,
input_args
[
0
]
->
GetShapeTrack
(),
prim_name
);
auto
x_shape
=
CheckAndConvertUtils
::
ConvertShapePtrToShape
(
"x_shape"
,
input_args
[
1
]
->
GetShapeTrack
(),
prim_name
);
CheckAndConvertUtils
::
CheckInteger
(
"weight rank"
,
w_shape
.
size
(),
kEqual
,
4
,
prim_name
);
CheckAndConvertUtils
::
CheckInteger
(
"x rank"
,
x_shape
.
size
(),
kEqual
,
4
,
prim_name
);
CheckAndConvertUtils
::
Check
(
"x_shape[1] / group"
,
x_shape
[
1
]
/
conv_prim
->
GetGroup
(),
kEqual
,
"w_shape[1]"
,
w_shape
[
1
],
conv_prim
->
name
());
auto
out_channel
=
conv_prim
->
GetOutputChannel
();
CheckAndConvertUtils
::
Check
(
"out_channel"
,
out_channel
,
kEqual
,
"w_shape[0]"
,
w_shape
[
0
],
conv_prim
->
name
());
std
::
vector
<
int
>
temp_w
;
std
::
copy
(
w_shape
.
begin
()
+
2
,
w_shape
.
end
(),
std
::
back_inserter
(
temp_w
));
CheckAndConvertUtils
::
Check
(
"kernel_size"
,
conv_prim
->
GetKernelSize
(),
kEqual
,
"w_shape[2:4]"
,
temp_w
,
conv_prim
->
name
());
auto
kernel_size_h
=
w_shape
[
2
];
auto
kernel_size_w
=
w_shape
[
3
];
auto
stride
=
conv_prim
->
GetStride
();
auto
dilation
=
conv_prim
->
GetDilation
();
auto
stride_h
=
stride
[
2
];
auto
stride_w
=
stride
[
3
];
auto
dilation_h
=
dilation
[
2
];
auto
dilation_w
=
dilation
[
3
];
int
h_out
=
-
1
;
int
w_out
=
-
1
;
std
::
vector
<
int
>
pad_list
(
4
,
0
);
auto
pad_mode
=
conv_prim
->
GetPadMode
();
if
(
pad_mode
==
"valid"
)
{
h_out
=
ceil
((
x_shape
[
2
]
-
dilation_h
*
(
kernel_size_h
-
1
))
/
stride_h
);
w_out
=
ceil
((
x_shape
[
3
]
-
dilation_w
*
(
kernel_size_w
-
1
))
/
stride_w
);
}
else
if
(
pad_mode
==
"same"
)
{
h_out
=
ceil
(
x_shape
[
2
]
/
stride_h
);
w_out
=
ceil
(
x_shape
[
3
]
/
stride_w
);
auto
pad_needed_h
=
std
::
max
(
0
,
(
h_out
-
1
)
*
stride_h
+
dilation_h
*
(
kernel_size_h
-
1
)
+
1
-
x_shape
[
2
]);
pad_list
.
emplace_back
(
floor
(
pad_needed_h
/
2
));
pad_list
.
emplace_back
(
pad_needed_h
/
2
);
auto
pad_needed_w
=
std
::
max
(
0
,
(
w_out
-
1
)
*
stride_w
+
dilation_w
*
(
kernel_size_w
-
1
)
+
1
-
x_shape
[
3
]);
auto
pad_left
=
floor
(
pad_needed_w
/
2
);
pad_list
.
emplace_back
(
pad_left
);
pad_list
.
emplace_back
(
pad_needed_h
-
pad_left
);
}
else
if
(
pad_mode
==
"pad"
)
{
std
::
copy
(
conv_prim
->
GetPad
().
begin
(),
conv_prim
->
GetPad
().
end
(),
std
::
back_inserter
(
pad_list
));
auto
pad_top
=
conv_prim
->
GetPad
()[
0
];
auto
pad_bottom
=
conv_prim
->
GetPad
()[
1
];
auto
pad_right
=
conv_prim
->
GetPad
()[
2
];
auto
pad_left
=
conv_prim
->
GetPad
()[
3
];
h_out
=
1
+
(
x_shape
[
2
]
+
pad_top
+
pad_bottom
-
kernel_size_h
-
(
kernel_size_h
-
1
)
*
(
dilation_h
-
1
))
/
stride_h
;
w_out
=
1
+
(
x_shape
[
3
]
+
pad_left
+
pad_right
-
kernel_size_w
-
(
kernel_size_w
-
1
)
*
(
dilation_w
-
1
))
/
stride_w
;
h_out
=
floor
(
h_out
);
w_out
=
floor
(
w_out
);
}
conv_prim
->
SetPadList
(
pad_list
);
std
::
vector
<
int
>
out_shape
=
{
x_shape
[
0
],
out_channel
,
h_out
,
w_out
};
return
std
::
make_shared
<
abstract
::
Shape
>
(
out_shape
);
}
TypePtr
InferType
(
const
PrimitivePtr
&
prim
,
const
std
::
vector
<
AbstractBasePtr
>
&
input_args
)
{
CheckAndConvertUtils
::
CheckInRange
(
""
,
input_args
.
size
(),
kIncludeLeft
,
{
2
,
3
},
prim
->
name
());
for
(
const
auto
&
item
:
input_args
)
{
MS_EXCEPTION_IF_NULL
(
item
);
}
auto
x_type
=
CheckAndConvertUtils
::
ConvertTypePtrToTypeId
(
"x_dtype"
,
input_args
[
0
]
->
GetTypeTrack
(),
prim
->
name
());
const
std
::
set
<
TypeId
>
valid_types
=
{
kNumberTypeInt8
,
kNumberTypeInt32
,
kNumberTypeFloat16
,
kNumberTypeFloat32
};
std
::
map
<
std
::
string
,
TypePtr
>
types
;
types
.
emplace
(
"x"
,
input_args
[
0
]
->
GetTypeTrack
());
types
.
emplace
(
"w"
,
input_args
[
1
]
->
GetTypeTrack
());
CheckAndConvertUtils
::
CheckTensorTypeSame
(
types
,
valid_types
,
prim
->
name
());
if
(
x_type
==
kNumberTypeInt8
)
{
return
std
::
make_shared
<
TensorType
>
(
TypeIdToType
(
kNumberTypeInt32
));
}
return
std
::
make_shared
<
TensorType
>
(
TypeIdToType
(
x_type
));
}
}
// namespace
void
Conv2d
::
Init
(
int
out_channel
,
const
std
::
vector
<
int
>
&
kernel_size
,
int
mode
,
const
std
::
string
&
pad_mode
,
const
std
::
vector
<
int
>
&
pad
,
const
std
::
vector
<
int
>
&
stride
,
const
std
::
vector
<
int
>
&
dilation
,
int
group
)
{
auto
prim_name
=
this
->
name
();
this
->
AddAttr
(
"data_format"
,
MakeValue
(
"NCHW"
));
this
->
AddAttr
(
"offset_a"
,
MakeValue
(
0
));
this
->
SetKernelSize
(
CheckAndConvertUtils
::
CheckPositiveVector
(
kKernelSize
,
kernel_size
,
prim_name
));
this
->
SetStride
(
CheckAndConvertUtils
::
CheckPositiveVector
(
kStride
,
stride
,
this
->
name
(),
true
,
true
));
this
->
SetDilation
(
CheckAndConvertUtils
::
CheckPositiveVector
(
kDilation
,
dilation
,
this
->
name
(),
true
,
true
));
this
->
SetPadMode
(
CheckAndConvertUtils
::
CheckString
(
kPadMode
,
pad_mode
,
{
"valid"
,
"same"
,
"pad"
},
prim_name
));
CheckAndConvertUtils
::
CheckInteger
(
"pad size"
,
pad
.
size
(),
kEqual
,
4
,
prim_name
);
if
(
pad_mode
==
"pad"
)
{
for
(
auto
item
:
pad
)
{
CheckAndConvertUtils
::
Check
(
"pad item"
,
item
,
kGreaterEqual
,
"zeros list"
,
0
,
prim_name
);
}
}
else
{
CheckAndConvertUtils
::
Check
(
kPad
,
pad
,
kEqual
,
"zeros list"
,
{
0
,
0
,
0
,
0
},
prim_name
);
}
this
->
SetPad
(
CheckAndConvertUtils
::
CheckPositiveVector
(
kPad
,
pad
,
this
->
name
(),
true
,
true
));
this
->
SetMode
(
CheckAndConvertUtils
::
CheckInteger
(
"mode"
,
mode
,
kEqual
,
1
,
prim_name
));
this
->
SetOutChannel
(
CheckAndConvertUtils
::
CheckInteger
(
"out_channel"
,
out_channel
,
kGreaterThan
,
0
,
prim_name
));
this
->
SetGroup
(
CheckAndConvertUtils
::
CheckInteger
(
"group"
,
group
,
kGreaterThan
,
0
,
prim_name
));
}
AbstractBasePtr
Conv2dInfer
(
const
abstract
::
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
std
::
vector
<
AbstractBasePtr
>
&
input_args
)
{
return
std
::
make_shared
<
abstract
::
AbstractTensor
>
(
InferType
(
primitive
,
input_args
),
InferShape
(
primitive
,
input_args
)
->
shape
());
}
}
// namespace mindspore
mindspore/core/c_ops/conv2d.h
0 → 100644
浏览文件 @
ae641e14
/**
* This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
*
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_C_OPS_CONV2D_H
#define MINDSPORE_CORE_C_OPS_CONV2D_H
#include <map>
#include <vector>
#include <string>
#include "c_ops/primitive_c.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"
namespace
mindspore
{
class
Conv2d
:
public
PrimitiveC
{
public:
Conv2d
()
:
PrimitiveC
(
kConv2DName
)
{
InitIOName
({
"x"
,
"w"
},
{
"output"
});
}
void
Init
(
int
out_channel
,
const
std
::
vector
<
int
>
&
kernel_size
,
int
mode
=
1
,
const
std
::
string
&
pad_mode
=
"valid"
,
const
std
::
vector
<
int
>
&
pad
=
{
0
,
0
,
0
,
0
},
const
std
::
vector
<
int
>
&
stride
=
{
1
,
1
,
1
,
1
},
const
std
::
vector
<
int
>
&
dilation
=
{
1
,
1
,
1
,
1
},
int
group
=
1
);
std
::
vector
<
int
>
GetKernelSize
()
const
{
auto
value_ptr
=
this
->
GetAttr
(
kKernelSize
);
return
GetValue
<
std
::
vector
<
int
>>
(
value_ptr
);
}
std
::
vector
<
int
>
GetStride
()
const
{
auto
value_ptr
=
GetAttr
(
kStride
);
return
GetValue
<
std
::
vector
<
int
>>
(
value_ptr
);
}
std
::
vector
<
int
>
GetDilation
()
const
{
auto
value_ptr
=
GetAttr
(
kDilation
);
return
GetValue
<
std
::
vector
<
int
>>
(
value_ptr
);
}
std
::
string
GetPadMode
()
const
{
auto
value_ptr
=
this
->
GetAttr
(
kPadMode
);
return
GetValue
<
string
>
(
value_ptr
);
}
std
::
vector
<
int
>
GetPad
()
const
{
auto
value_ptr
=
this
->
GetAttr
(
kPad
);
return
GetValue
<
std
::
vector
<
int
>>
(
value_ptr
);
}
int
GetMode
()
const
{
auto
value_ptr
=
this
->
GetAttr
(
kMode
);
return
GetValue
<
int
>
(
value_ptr
);
}
int
GetGroup
()
const
{
auto
value_ptr
=
this
->
GetAttr
(
kGroup
);
return
GetValue
<
int
>
(
value_ptr
);
}
int
GetOutputChannel
()
const
{
auto
value_ptr
=
this
->
GetAttr
(
kOutputChannel
);
return
GetValue
<
int
>
(
value_ptr
);
}
void
SetKernelSize
(
const
std
::
vector
<
int
>
&
kernel_size
)
{
this
->
AddAttr
(
kKernelSize
,
MakeValue
(
kernel_size
));
}
void
SetStride
(
const
std
::
vector
<
int
>
&
stride
)
{
this
->
AddAttr
(
kStride
,
MakeValue
(
stride
));
}
void
SetDilation
(
const
std
::
vector
<
int
>
&
dilation
)
{
this
->
AddAttr
(
kDilation
,
MakeValue
(
dilation
));
}
void
SetPadMode
(
const
std
::
string
&
pad_mode
)
{
this
->
AddAttr
(
kPadMode
,
MakeValue
(
pad_mode
));
}
void
SetPad
(
const
std
::
vector
<
int
>
&
pad
)
{
this
->
AddAttr
(
kPad
,
MakeValue
(
pad
));
}
void
SetMode
(
int
mode
)
{
this
->
AddAttr
(
kMode
,
MakeValue
(
mode
));
}
void
SetGroup
(
int
group
)
{
this
->
AddAttr
(
kGroup
,
MakeValue
(
group
));
}
void
SetOutChannel
(
int
output_channel
)
{
this
->
AddAttr
(
kOutputChannel
,
MakeValue
(
output_channel
));
}
void
SetPadList
(
const
std
::
vector
<
int
>
&
pad_list
)
{
this
->
AddAttr
(
kPadList
,
MakeValue
(
pad_list
));
}
private:
inline
static
const
string
kKernelSize
=
"kernel_size"
;
inline
static
const
string
kStride
=
"stride"
;
inline
static
const
string
kDilation
=
"dilation"
;
inline
static
const
string
kPadMode
=
"pad_mode"
;
inline
static
const
string
kPad
=
"pad"
;
inline
static
const
string
kMode
=
"mode"
;
inline
static
const
string
kGroup
=
"group"
;
inline
static
const
string
kOutputChannel
=
"output channel"
;
inline
static
const
string
kPadList
=
"pad_list"
;
inline
static
const
string
kConv2DName
=
"Conv2D"
;
};
AbstractBasePtr
Conv2dInfer
(
const
abstract
::
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
std
::
vector
<
AbstractBasePtr
>
&
input_args
);
}
// namespace mindspore
#endif // MINDSPORE_CORE_C_OPS_CONV2D_H
mindspore/core/c_ops/primitive_c.h
0 → 100644
浏览文件 @
ae641e14
/**
* This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
*
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_C_OPS_PRIMITIVE_C_H
#define MINDSPORE_CORE_C_OPS_PRIMITIVE_C_H
#include <string>
#include <vector>
#include "ir/primitive.h"
#include "ir/value.h"
namespace
mindspore
{
class
PrimitiveC
:
public
Primitive
{
public:
explicit
PrimitiveC
(
const
std
::
string
&
name
)
:
Primitive
(
name
)
{
attrs_
=
{};
}
protected:
void
InitIOName
(
const
std
::
vector
<
std
::
string
>
&
inputs_name
,
const
std
::
vector
<
std
::
string
>
&
outputs_name
)
{
this
->
AddAttr
(
"input_names"
,
MakeValue
(
inputs_name
));
this
->
AddAttr
(
"output_names"
,
MakeValue
(
outputs_name
));
}
};
}
// namespace mindspore
#endif // MINDSPORE_CORE_C_OPS_PRIMITIVE_C_H
mindspore/core/ir/func_graph.cc
浏览文件 @
ae641e14
...
...
@@ -632,6 +632,19 @@ void FuncGraph::CheckOrder() {
MS_LOG
(
DEBUG
)
<<
"Check order okay."
;
}
}
CNodePtr
FuncGraph
::
NewCNode
(
const
PrimitivePtr
&
primitive
,
const
std
::
vector
<
AnfNodePtr
>
&
inputs
)
{
auto
primitive_node
=
std
::
make_shared
<
ValueNode
>
(
primitive
);
std
::
vector
<
AnfNodePtr
>
input_node_list
=
{
primitive_node
};
std
::
copy
(
inputs
.
begin
(),
inputs
.
end
(),
std
::
back_inserter
(
input_node_list
));
return
NewCNode
(
input_node_list
);
}
ParameterPtr
FuncGraph
::
add_parameter
(
const
tensor
::
MetaTensorPtr
&
meta_tensor
)
{
auto
parameter
=
add_parameter
();
parameter
->
set_default_param
(
MakeValue
(
meta_tensor
));
parameter
->
set_abstract
(
meta_tensor
->
ToAbstract
());
return
parameter
;
}
size_t
NewFgSeenGeneration
()
{
static
size_t
fg_seen_generation
=
0
;
...
...
mindspore/core/ir/func_graph.h
浏览文件 @
ae641e14
...
...
@@ -170,7 +170,9 @@ class FuncGraph : public FuncGraphBase {
// create a cnode with given inputs, bound to this graph, and set to specific scope
CNodePtr
NewCNodeWithScope
(
const
std
::
vector
<
AnfNodePtr
>
&
inputs
,
const
ScopePtr
&
scope
);
virtual
CNodePtr
NewCNode
(
const
PrimitivePtr
&
primitive
,
const
std
::
vector
<
AnfNodePtr
>
&
prim_inputs
);
virtual
ParameterPtr
add_parameter
(
const
tensor
::
MetaTensorPtr
&
meta_tensor
);
// Functions for handling variable argument, keyword-only arguments and variable keyword argument
AnfNodePtr
GetDefaultValueByName
(
const
std
::
string
&
name
);
void
set_param_default_value
(
const
std
::
string
&
name
,
const
AnfNodePtr
&
node
)
{
...
...
mindspore/core/utils/check_convert_utils.cc
0 → 100644
浏览文件 @
ae641e14
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "utils/check_convert_utils.h"
#include <utility>
#include "abstract/abstract_value.h"
namespace
mindspore
{
namespace
{
const
std
::
map
<
CompareEnum
,
std
::
function
<
bool
(
int
,
int
)
>>
kCompareMap
=
{
{
kEqual
,
[](
int
num1
,
int
num2
)
->
bool
{
return
num1
==
num2
;
}},
{
kNotEqual
,
[](
int
num1
,
int
num2
)
->
bool
{
return
num1
!=
num2
;
}},
{
kLessThan
,
[](
int
num1
,
int
num2
)
->
bool
{
return
num1
<
num2
;
}},
{
kLessEqual
,
[](
int
num1
,
int
num2
)
->
bool
{
return
num1
<=
num2
;
}},
{
kGreaterThan
,
[](
int
num1
,
int
num2
)
->
bool
{
return
num1
>
num2
;
}},
{
kGreaterEqual
,
[](
int
num1
,
int
num2
)
->
bool
{
return
num1
>=
num2
;
}}};
const
std
::
map
<
CompareRange
,
std
::
function
<
bool
(
int
,
std
::
pair
<
int
,
int
>
)
>>
kCompareRangeMap
=
{
{
kIncludeNeither
,
[](
int
num1
,
std
::
pair
<
int
,
int
>
range
)
->
bool
{
return
num1
>
range
.
first
&&
num1
<
range
.
second
;
}},
{
kIncludeLeft
,
[](
int
num1
,
std
::
pair
<
int
,
int
>
range
)
->
bool
{
return
num1
>=
range
.
first
&&
num1
<
range
.
second
;
}},
{
kIncludeRight
,
[](
int
num1
,
std
::
pair
<
int
,
int
>
range
)
->
bool
{
return
num1
>
range
.
first
&&
num1
<=
range
.
second
;
}},
{
kIncludeBoth
,
[](
int
num1
,
std
::
pair
<
int
,
int
>
range
)
->
bool
{
return
num1
>=
range
.
first
&&
num1
<=
range
.
second
;
}}};
const
std
::
map
<
CompareEnum
,
std
::
string
>
kCompareToString
=
{
{
kEqual
,
"equal"
},
{
kNotEqual
,
"not equal"
},
{
kLessThan
,
"less than"
},
{
kLessEqual
,
"less eqaul"
},
{
kGreaterThan
,
"greater than"
},
{
kGreaterEqual
,
"greate equal"
}};
const
std
::
map
<
CompareRange
,
std
::
pair
<
std
::
string
,
std
::
string
>>
kCompareRangeToString
=
{
{
kIncludeNeither
,
{
"in ("
,
")"
}},
{
kIncludeLeft
,
{
" in ["
,
")"
}},
{
kIncludeRight
,
{
"in ("
,
"]"
}},
{
kIncludeBoth
,
{
"in ["
,
"]"
}}};
}
// namespace
bool
CheckAndConvertUtils
::
IsEqualVector
(
const
std
::
vector
<
int
>
&
vec_1
,
const
std
::
vector
<
int
>
&
vec_2
)
{
if
(
vec_1
.
size
()
!=
vec_2
.
size
())
{
return
false
;
}
for
(
size_t
index
=
0
;
index
<
vec_1
.
size
();
++
index
)
{
if
(
vec_1
[
index
]
!=
vec_2
[
index
])
{
return
false
;
}
}
return
true
;
}
std
::
vector
<
int
>
CheckAndConvertUtils
::
CheckPositiveVector
(
const
std
::
string
&
arg_name
,
const
std
::
vector
<
int
>
&
arg_value
,
const
std
::
string
&
prim_name
,
bool
allow_four
,
bool
ret_four
)
{
if
(
arg_value
.
size
()
==
2
)
{
return
ret_four
?
std
::
vector
<
int
>
{
1
,
1
,
arg_value
[
0
],
arg_value
[
1
]}
:
arg_value
;
}
else
if
(
arg_value
.
size
()
==
4
&&
allow_four
)
{
return
ret_four
?
arg_value
:
std
::
vector
<
int
>
{
arg_value
[
2
],
arg_value
[
3
]};
}
std
::
ostringstream
buffer
;
buffer
<<
"For "
<<
prim_name
<<
" attr "
<<
arg_name
<<
" should be a positive vector of size two "
;
if
(
allow_four
)
{
buffer
<<
"or four "
;
}
buffer
<<
" positive int numbers , but got ["
;
for
(
auto
item
:
arg_value
)
{
buffer
<<
item
<<
","
;
}
buffer
<<
"]"
;
MS_EXCEPTION
(
ValueError
)
<<
buffer
.
str
();
}
std
::
string
CheckAndConvertUtils
::
CheckString
(
const
std
::
string
&
arg_name
,
const
std
::
string
&
arg_value
,
const
std
::
set
<
std
::
string
>
&
check_list
,
const
std
::
string
&
prim_name
)
{
if
(
check_list
.
find
(
arg_value
)
!=
check_list
.
end
())
{
return
arg_value
;
}
std
::
ostringstream
buffer
;
buffer
<<
"For "
<<
prim_name
<<
" the "
<<
arg_name
<<
" should be str and must be "
;
if
(
check_list
.
size
()
==
1
)
{
buffer
<<
(
*
check_list
.
begin
())
<<
"but got "
<<
arg_value
;
MS_EXCEPTION
(
ValueError
)
<<
buffer
.
str
();
}
buffer
<<
"one of {"
;
for
(
const
auto
&
item
:
check_list
)
{
buffer
<<
item
<<
" ,"
;
}
buffer
<<
" }"
<<
" but got "
<<
arg_value
;
MS_EXCEPTION
(
ValueError
)
<<
buffer
.
str
();
}
int
CheckAndConvertUtils
::
CheckInteger
(
const
std
::
string
&
arg_name
,
int
arg_value
,
CompareEnum
compare_operator
,
int
match_value
,
const
std
::
string
&
prim_name
)
{
auto
iter
=
kCompareMap
.
find
(
compare_operator
);
if
(
iter
==
kCompareMap
.
end
())
{
MS_EXCEPTION
(
NotExistsError
)
<<
"compare_operator "
<<
compare_operator
<<
" cannot find in the compare map"
;
}
if
(
iter
->
second
(
arg_value
,
match_value
))
{
return
arg_value
;
}
std
::
ostringstream
buffer
;
if
(
prim_name
.
empty
())
{
buffer
<<
"The "
;
}
else
{
buffer
<<
"For "
<<
prim_name
<<
" the "
;
}
buffer
<<
arg_name
<<
" must "
;
auto
iter_to_string
=
kCompareToString
.
find
(
compare_operator
);
if
(
iter_to_string
==
kCompareToString
.
end
())
{
MS_EXCEPTION
(
NotExistsError
)
<<
"compare_operator "
<<
compare_operator
<<
" cannot find in the compare string map"
;
}
buffer
<<
iter_to_string
->
second
<<
match_value
<<
" , but got "
<<
arg_value
;
MS_EXCEPTION
(
ValueError
)
<<
buffer
.
str
();
}
void
CheckAndConvertUtils
::
CheckInRange
(
const
std
::
string
&
arg_name
,
int
arg_value
,
CompareRange
compare_operator
,
const
std
::
pair
<
int
,
int
>
&
range
,
const
std
::
string
&
prim_name
)
{
auto
iter
=
kCompareRangeMap
.
find
(
compare_operator
);
if
(
iter
==
kCompareRangeMap
.
end
())
{
MS_EXCEPTION
(
NotExistsError
)
<<
"compare_operator "
<<
compare_operator
<<
" cannot find in the compare map"
;
}
if
(
iter
->
second
(
arg_value
,
range
))
{
return
;
}
std
::
ostringstream
buffer
;
if
(
prim_name
.
empty
())
{
buffer
<<
"The "
;
}
else
{
buffer
<<
"For "
<<
prim_name
<<
" the "
;
}
buffer
<<
arg_name
<<
" must "
;
auto
iter_to_string
=
kCompareRangeToString
.
find
(
compare_operator
);
if
(
iter_to_string
==
kCompareRangeToString
.
end
())
{
MS_EXCEPTION
(
NotExistsError
)
<<
"compare_operator "
<<
compare_operator
<<
" cannot find in the compare string map"
;
}
auto
range_strng
=
iter_to_string
->
second
;
buffer
<<
range_strng
.
first
<<
range
.
first
<<
","
<<
range_strng
.
second
<<
" , but got "
<<
arg_value
;
MS_EXCEPTION
(
ValueError
)
<<
buffer
.
str
();
}
std
::
vector
<
int
>
CheckAndConvertUtils
::
ConvertShapePtrToShape
(
const
std
::
string
&
arg_name
,
const
BaseShapePtr
&
shape
,
const
std
::
string
&
prim_name
)
{
MS_EXCEPTION_IF_NULL
(
shape
);
if
(
!
shape
->
isa
<
abstract
::
Shape
>
())
{
MS_EXCEPTION
(
ValueError
)
<<
"The "
<<
arg_name
<<
"'s shape is "
<<
shape
->
ToString
()
<<
"should be a common shape!"
;
}
auto
shape_element
=
shape
->
cast
<
abstract
::
ShapePtr
>
();
MS_EXCEPTION_IF_NULL
(
shape_element
);
return
shape_element
->
shape
();
}
TypeId
CheckAndConvertUtils
::
ConvertTypePtrToTypeId
(
const
string
&
arg_name
,
const
TypePtr
&
type_ptr
,
const
string
&
prim_name
)
{
MS_EXCEPTION_IF_NULL
(
type_ptr
);
if
(
!
type_ptr
->
isa
<
TensorType
>
()
||
!
type_ptr
->
isa
<
Number
>
())
{
MS_EXCEPTION
(
ValueError
)
<<
"The "
<<
arg_name
<<
"'s shape is "
<<
type_ptr
->
ToString
()
<<
"should be a common type!(tensor_type && numbertype)"
;
}
return
type_ptr
->
type_id
();
}
void
CheckAndConvertUtils
::
Check
(
const
string
&
arg_name
,
int
arg_value
,
CompareEnum
compare_type
,
const
string
&
value_name
,
int
value
,
const
string
&
prim_name
,
ExceptionType
exception_type
)
{
auto
iter
=
kCompareMap
.
find
(
compare_type
);
if
(
iter
==
kCompareMap
.
end
())
{
MS_EXCEPTION
(
NotExistsError
)
<<
"the compare type :"
<<
compare_type
<<
" is not in the compare map"
;
}
if
(
iter
->
second
(
arg_value
,
value
))
{
return
;
}
std
::
ostringstream
buffer
;
if
(
prim_name
.
empty
())
{
buffer
<<
"The "
;
}
else
{
buffer
<<
"For "
<<
prim_name
<<
" the "
;
}
auto
iter_to_string
=
kCompareToString
.
find
(
compare_type
);
if
(
iter_to_string
==
kCompareToString
.
end
())
{
MS_EXCEPTION
(
NotExistsError
)
<<
"compare_operator "
<<
compare_type
<<
" cannot find in the compare string map"
;
}
MS_EXCEPTION
(
exception_type
)
<<
buffer
.
str
()
<<
arg_name
<<
" should be "
<<
iter_to_string
->
second
<<
value
<<
" but got "
<<
arg_value
;
}
void
CheckAndConvertUtils
::
Check
(
const
string
&
arg_name
,
const
std
::
vector
<
int
>
&
arg_value
,
CompareEnum
compare_type
,
const
string
&
value_name
,
const
std
::
vector
<
int
>
&
value
,
const
string
&
prim_name
,
ExceptionType
exception_type
)
{
if
(
compare_type
!=
kEqual
)
{
auto
iter
=
kCompareToString
.
find
(
compare_type
);
if
(
iter
!=
kCompareToString
.
end
())
{
MS_EXCEPTION
(
NotSupportError
)
<<
"Only supported equal to compare two vectors but got "
<<
iter
->
second
;
}
MS_EXCEPTION
(
UnknownError
)
<<
"Cannot find the operator "
<<
compare_type
<<
"in the compare map!"
;
}
if
(
arg_value
==
value
)
{
return
;
}
std
::
ostringstream
buffer
;
if
(
prim_name
.
empty
())
{
buffer
<<
"The "
;
}
else
{
buffer
<<
"For "
<<
prim_name
<<
" the "
;
}
auto
iter_to_string
=
kCompareToString
.
find
(
compare_type
);
if
(
iter_to_string
==
kCompareToString
.
end
())
{
MS_EXCEPTION
(
NotExistsError
)
<<
"compare_operator "
<<
compare_type
<<
" cannot find in the compare string map"
;
}
buffer
<<
arg_name
<<
"should be "
<<
iter_to_string
->
second
<<
" ["
;
for
(
auto
item
:
value
)
{
buffer
<<
item
<<
","
;
}
buffer
<<
"] "
<<
"but got ["
;
for
(
auto
item
:
arg_value
)
{
buffer
<<
item
<<
" ,"
;
}
buffer
<<
"]"
;
MS_EXCEPTION
(
exception_type
)
<<
buffer
.
str
();
}
void
CheckAndConvertUtils
::
CheckTensorTypeSame
(
const
std
::
map
<
std
::
string
,
TypePtr
>
&
types
,
const
std
::
set
<
TypeId
>
&
check_list
,
const
std
::
string
&
prim_name
)
{
if
(
types
.
empty
())
{
MS_LOG
(
WARNING
)
<<
"Tryinh to use the function to check a empty types map!"
;
return
;
}
std
::
set
<
TypeId
>
types_id
;
std
::
ostringstream
buffer
;
buffer
<<
"For "
<<
prim_name
;
for
(
const
auto
&
type
:
types
)
{
MS_EXCEPTION_IF_NULL
(
type
.
second
);
if
(
!
type
.
second
->
isa
<
TensorType
>
())
{
MS_EXCEPTION
(
TypeError
)
<<
"The "
<<
prim_name
<<
"'s"
<<
type
.
first
<<
" input must be tensor type but got "
<<
type
.
second
->
ToString
();
}
types_id
.
emplace
(
type
.
second
->
type_id
());
}
if
(
types_id
.
size
()
>
1
)
{
buffer
<<
"'s input type is not same : "
;
for
(
const
auto
&
item
:
types
)
{
buffer
<<
"[ name : "
<<
item
.
first
<<
" ,type : "
<<
item
.
second
->
ToString
()
<<
"]"
;
}
MS_EXCEPTION
(
TypeError
)
<<
buffer
.
str
();
}
if
(
check_list
.
find
(
*
(
types_id
.
begin
()))
!=
check_list
.
end
())
{
buffer
<<
" type of "
;
for
(
const
auto
&
elem
:
types
)
{
buffer
<<
elem
.
first
<<
" should be in ["
;
for
(
auto
type_elem
:
check_list
)
{
buffer
<<
type_elem
<<
" ,"
;
}
buffer
<<
"] , but got "
<<
types
.
begin
()
->
second
->
ToString
();
}
}
MS_EXCEPTION
(
TypeError
)
<<
buffer
.
str
();
}
}
// namespace mindspore
mindspore/core/utils/check_convert_utils.h
0 → 100644
浏览文件 @
ae641e14
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_UTILS_CHECK_CONVERT_UTILS_H
#define MINDSPORE_CORE_UTILS_CHECK_CONVERT_UTILS_H
#include <vector>
#include <string>
#include <map>
#include <set>
#include <utility>
#include "base/base.h"
#include "ir/anf.h"
#include "ir/dtype/type_id.h"
#include "utils/log_adapter.h"
namespace
mindspore
{
enum
CompareEnum
:
int
{
kEqual
=
1
,
// ==
kNotEqual
=
2
,
// !=
kLessThan
=
3
,
// <
kLessEqual
=
4
,
// <=
kGreaterThan
=
5
,
// >
kGreaterEqual
=
6
,
// >=
};
enum
CompareRange
{
kIncludeNeither
=
1
,
// (a,b)
kIncludeLeft
=
2
,
// [a,b)
kIncludeRight
=
3
,
// (a,b]
kIncludeBoth
=
4
,
// [a,b]
};
class
CheckAndConvertUtils
{
public:
static
std
::
vector
<
int
>
CheckPositiveVector
(
const
std
::
string
&
arg_name
,
const
std
::
vector
<
int
>
&
arg_value
,
const
std
::
string
&
prim_name
,
bool
allow_four
=
false
,
bool
ret_four
=
false
);
static
std
::
string
CheckString
(
const
std
::
string
&
arg_name
,
const
std
::
string
&
arg_value
,
const
std
::
set
<
std
::
string
>
&
check_list
,
const
std
::
string
&
prim_name
);
static
int
CheckInteger
(
const
std
::
string
&
arg_name
,
int
arg_value
,
CompareEnum
compare_operator
,
int
match_value
,
const
std
::
string
&
prim_name
);
static
void
CheckInRange
(
const
std
::
string
&
arg_name
,
int
arg_value
,
CompareRange
compare_operator
,
const
std
::
pair
<
int
,
int
>
&
range
,
const
std
::
string
&
prim_name
);
static
std
::
vector
<
int
>
ConvertShapePtrToShape
(
const
std
::
string
&
arg_name
,
const
BaseShapePtr
&
shape
,
const
std
::
string
&
prim_name
);
static
TypeId
ConvertTypePtrToTypeId
(
const
std
::
string
&
arg_name
,
const
TypePtr
&
type_ptr
,
const
std
::
string
&
prim_name
);
static
void
Check
(
const
std
::
string
&
arg_name
,
int
arg_value
,
CompareEnum
compare_type
,
const
std
::
string
&
value_name
,
int
value
,
const
std
::
string
&
prim_name
=
""
,
ExceptionType
exception_type
=
ValueError
);
static
void
Check
(
const
std
::
string
&
arg_name
,
const
std
::
vector
<
int
>
&
arg_value
,
CompareEnum
compare_type
,
const
std
::
string
&
value_name
,
const
std
::
vector
<
int
>
&
value
,
const
std
::
string
&
prim_name
=
""
,
ExceptionType
exception_type
=
ValueError
);
static
void
CheckTensorTypeSame
(
const
std
::
map
<
std
::
string
,
TypePtr
>
&
types
,
const
std
::
set
<
TypeId
>
&
check_list
,
const
std
::
string
&
prim_name
);
private:
static
bool
IsEqualVector
(
const
std
::
vector
<
int
>
&
vec_1
,
const
std
::
vector
<
int
>
&
vec_2
);
};
}
// namespace mindspore
#endif // MINDSPORE_CORE_UTILS_CHECK_CONVERT_UTILS_H
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录