Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
5a886794
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
5a886794
编写于
6月 28, 2020
作者:
Y
yanghaitao
浏览文件
操作
浏览文件
下载
差异文件
Merge branch 'master' of gitee.com:mindspore/mindspore
上级
7f54d17b
d6d93f16
变更
122
隐藏空白更改
内联
并排
Showing
122 changed file
with
3860 addition
and
1962 deletion
+3860
-1962
akg
akg
+1
-1
mindspore/ccsrc/CMakeLists.txt
mindspore/ccsrc/CMakeLists.txt
+7
-6
mindspore/ccsrc/dataset/kernels/image/random_crop_and_resize_with_bbox_op.cc
...aset/kernels/image/random_crop_and_resize_with_bbox_op.cc
+1
-2
mindspore/ccsrc/dataset/kernels/image/random_crop_with_bbox_op.cc
...e/ccsrc/dataset/kernels/image/random_crop_with_bbox_op.cc
+1
-2
mindspore/ccsrc/dataset/kernels/image/random_vertical_flip_with_bbox_op.cc
...ataset/kernels/image/random_vertical_flip_with_bbox_op.cc
+1
-2
mindspore/ccsrc/dataset/util/CMakeLists.txt
mindspore/ccsrc/dataset/util/CMakeLists.txt
+6
-0
mindspore/ccsrc/dataset/util/allocator.h
mindspore/ccsrc/dataset/util/allocator.h
+87
-0
mindspore/ccsrc/dataset/util/auto_index.h
mindspore/ccsrc/dataset/util/auto_index.h
+1
-1
mindspore/ccsrc/dataset/util/buddy.cc
mindspore/ccsrc/dataset/util/buddy.cc
+388
-0
mindspore/ccsrc/dataset/util/buddy.h
mindspore/ccsrc/dataset/util/buddy.h
+133
-0
mindspore/ccsrc/dataset/util/cache_pool.cc
mindspore/ccsrc/dataset/util/cache_pool.cc
+202
-0
mindspore/ccsrc/dataset/util/cache_pool.h
mindspore/ccsrc/dataset/util/cache_pool.h
+139
-0
mindspore/ccsrc/dataset/util/list.h
mindspore/ccsrc/dataset/util/list.h
+18
-0
mindspore/ccsrc/dataset/util/memory_pool.h
mindspore/ccsrc/dataset/util/memory_pool.h
+0
-14
mindspore/ccsrc/dataset/util/path.cc
mindspore/ccsrc/dataset/util/path.cc
+115
-3
mindspore/ccsrc/dataset/util/path.h
mindspore/ccsrc/dataset/util/path.h
+14
-0
mindspore/ccsrc/dataset/util/semaphore.cc
mindspore/ccsrc/dataset/util/semaphore.cc
+41
-0
mindspore/ccsrc/dataset/util/semaphore.h
mindspore/ccsrc/dataset/util/semaphore.h
+54
-0
mindspore/ccsrc/dataset/util/slice.cc
mindspore/ccsrc/dataset/util/slice.cc
+38
-0
mindspore/ccsrc/dataset/util/slice.h
mindspore/ccsrc/dataset/util/slice.h
+122
-0
mindspore/ccsrc/dataset/util/storage_container.cc
mindspore/ccsrc/dataset/util/storage_container.cc
+164
-0
mindspore/ccsrc/dataset/util/storage_container.h
mindspore/ccsrc/dataset/util/storage_container.h
+79
-0
mindspore/ccsrc/dataset/util/storage_manager.cc
mindspore/ccsrc/dataset/util/storage_manager.cc
+167
-0
mindspore/ccsrc/dataset/util/storage_manager.h
mindspore/ccsrc/dataset/util/storage_manager.h
+76
-0
mindspore/ccsrc/dataset/util/system_pool.h
mindspore/ccsrc/dataset/util/system_pool.h
+7
-0
mindspore/ccsrc/device/kernel_runtime.cc
mindspore/ccsrc/device/kernel_runtime.cc
+21
-6
mindspore/ccsrc/device/kernel_runtime.h
mindspore/ccsrc/device/kernel_runtime.h
+2
-2
mindspore/ccsrc/ir/optimizer_caller.h
mindspore/ccsrc/ir/optimizer_caller.h
+11
-1
mindspore/ccsrc/kernel/kernel_query.cc
mindspore/ccsrc/kernel/kernel_query.cc
+7
-0
mindspore/ccsrc/optimizer/cse.cc
mindspore/ccsrc/optimizer/cse.cc
+53
-27
mindspore/ccsrc/optimizer/cse.h
mindspore/ccsrc/optimizer/cse.h
+1
-1
mindspore/ccsrc/optimizer/irpass.cc
mindspore/ccsrc/optimizer/irpass.cc
+89
-75
mindspore/ccsrc/optimizer/irpass/arithmetic_simplify.h
mindspore/ccsrc/optimizer/irpass/arithmetic_simplify.h
+28
-31
mindspore/ccsrc/optimizer/irpass/cast_eliminate.h
mindspore/ccsrc/optimizer/irpass/cast_eliminate.h
+3
-3
mindspore/ccsrc/optimizer/irpass/env_item_eliminate.h
mindspore/ccsrc/optimizer/irpass/env_item_eliminate.h
+16
-14
mindspore/ccsrc/optimizer/irpass/incorporate_getitem.h
mindspore/ccsrc/optimizer/irpass/incorporate_getitem.h
+15
-12
mindspore/ccsrc/optimizer/irpass/item_tuple_eliminate.h
mindspore/ccsrc/optimizer/irpass/item_tuple_eliminate.h
+16
-17
mindspore/ccsrc/optimizer/irpass/ref_eliminate.h
mindspore/ccsrc/optimizer/irpass/ref_eliminate.h
+2
-2
mindspore/ccsrc/optimizer/irpass/reshape_eliminate.h
mindspore/ccsrc/optimizer/irpass/reshape_eliminate.h
+6
-5
mindspore/ccsrc/optimizer/irpass/special_op_eliminate.h
mindspore/ccsrc/optimizer/irpass/special_op_eliminate.h
+18
-18
mindspore/ccsrc/optimizer/opt.cc
mindspore/ccsrc/optimizer/opt.cc
+9
-10
mindspore/ccsrc/optimizer/opt.h
mindspore/ccsrc/optimizer/opt.h
+9
-15
mindspore/ccsrc/parallel/auto_parallel/operator_costmodel.cc
mindspore/ccsrc/parallel/auto_parallel/operator_costmodel.cc
+2
-2
mindspore/ccsrc/parallel/context.cc
mindspore/ccsrc/parallel/context.cc
+1
-0
mindspore/ccsrc/parallel/context.h
mindspore/ccsrc/parallel/context.h
+6
-0
mindspore/ccsrc/pipeline/init.cc
mindspore/ccsrc/pipeline/init.cc
+4
-0
mindspore/ccsrc/pre_activate/pass/common_subexpression_elimination.cc
...src/pre_activate/pass/common_subexpression_elimination.cc
+1
-1
mindspore/ccsrc/pre_activate/pass/common_subexpression_elimination.h
...csrc/pre_activate/pass/common_subexpression_elimination.h
+1
-1
mindspore/ccsrc/pybind_api/export_flags.cc
mindspore/ccsrc/pybind_api/export_flags.cc
+1
-0
mindspore/ccsrc/pybind_api/export_flags.h
mindspore/ccsrc/pybind_api/export_flags.h
+1
-1
mindspore/ccsrc/session/ascend_control_parser.cc
mindspore/ccsrc/session/ascend_control_parser.cc
+32
-8
mindspore/ccsrc/session/ascend_control_parser.h
mindspore/ccsrc/session/ascend_control_parser.h
+3
-2
mindspore/ccsrc/session/kernel_graph.cc
mindspore/ccsrc/session/kernel_graph.cc
+56
-2
mindspore/ccsrc/session/session.cc
mindspore/ccsrc/session/session.cc
+73
-39
mindspore/ccsrc/session/session_basic.cc
mindspore/ccsrc/session/session_basic.cc
+10
-2
mindspore/ccsrc/transform/convert.cc
mindspore/ccsrc/transform/convert.cc
+3
-1
mindspore/ccsrc/utils/log_adapter.cc
mindspore/ccsrc/utils/log_adapter.cc
+21
-11
mindspore/ccsrc/utils/utils.h
mindspore/ccsrc/utils/utils.h
+1
-0
mindspore/common/tensor.py
mindspore/common/tensor.py
+8
-0
mindspore/dataset/engine/datasets.py
mindspore/dataset/engine/datasets.py
+3
-3
mindspore/dataset/transforms/vision/c_transforms.py
mindspore/dataset/transforms/vision/c_transforms.py
+3
-3
mindspore/nn/optim/adam.py
mindspore/nn/optim/adam.py
+47
-36
mindspore/nn/optim/lamb.py
mindspore/nn/optim/lamb.py
+72
-70
mindspore/nn/optim/optimizer.py
mindspore/nn/optim/optimizer.py
+94
-4
mindspore/nn/wrap/cell_wrapper.py
mindspore/nn/wrap/cell_wrapper.py
+3
-1
mindspore/ops/_op_impl/akg/__init__.py
mindspore/ops/_op_impl/akg/__init__.py
+1
-0
mindspore/ops/_op_impl/akg/batchmatmul.py
mindspore/ops/_op_impl/akg/batchmatmul.py
+73
-0
mindspore/ops/_op_impl/tbe/confusion_transpose_d.py
mindspore/ops/_op_impl/tbe/confusion_transpose_d.py
+2
-20
mindspore/ops/composite/multitype_ops/setitem_impl.py
mindspore/ops/composite/multitype_ops/setitem_impl.py
+16
-0
mindspore/ops/operations/comm_ops.py
mindspore/ops/operations/comm_ops.py
+1
-0
mindspore/ops/operations/debug_ops.py
mindspore/ops/operations/debug_ops.py
+1
-7
mindspore/ops/operations/math_ops.py
mindspore/ops/operations/math_ops.py
+4
-2
mindspore/ops/operations/other_ops.py
mindspore/ops/operations/other_ops.py
+1
-2
mindspore/parallel/_auto_parallel_context.py
mindspore/parallel/_auto_parallel_context.py
+25
-3
mindspore/train/callback/_summary_collector.py
mindspore/train/callback/_summary_collector.py
+20
-4
model_zoo/faster_rcnn/src/dataset.py
model_zoo/faster_rcnn/src/dataset.py
+14
-18
model_zoo/vgg16/src/config.py
model_zoo/vgg16/src/config.py
+3
-1
model_zoo/vgg16/train.py
model_zoo/vgg16/train.py
+16
-10
serving/core/server.cc
serving/core/server.cc
+28
-13
serving/core/util/file_system_operation.cc
serving/core/util/file_system_operation.cc
+2
-3
serving/core/util/option_parser.cc
serving/core/util/option_parser.cc
+23
-17
serving/core/util/option_parser.h
serving/core/util/option_parser.h
+1
-2
serving/core/version_control/model.cc
serving/core/version_control/model.cc
+0
-1
serving/core/version_control/version_controller.cc
serving/core/version_control/version_controller.cc
+6
-8
serving/core/version_control/version_controller.h
serving/core/version_control/version_controller.h
+0
-1
serving/cpp_example/ms_client.cc
serving/cpp_example/ms_client.cc
+1
-1
serving/scripts/format_source_code.sh
serving/scripts/format_source_code.sh
+1
-1
setup.py
setup.py
+1
-0
tests/ut/cpp/dataset/btree_test.cc
tests/ut/cpp/dataset/btree_test.cc
+3
-3
tests/ut/cpp/optimizer/opt_test.cc
tests/ut/cpp/optimizer/opt_test.cc
+4
-4
tests/ut/cpp/parallel/step_parallel_test.cc
tests/ut/cpp/parallel/step_parallel_test.cc
+3
-0
tests/ut/data/dataset/declient.cfg
tests/ut/data/dataset/declient.cfg
+2
-1
tests/ut/data/dataset/golden/bounding_box_augment_crop_c_result.npz
...ata/dataset/golden/bounding_box_augment_crop_c_result.npz
+0
-0
tests/ut/data/dataset/golden/bounding_box_augment_rotation_c_result.npz
...dataset/golden/bounding_box_augment_rotation_c_result.npz
+0
-0
tests/ut/data/dataset/golden/bounding_box_augment_valid_edge_c_result.npz
...taset/golden/bounding_box_augment_valid_edge_c_result.npz
+0
-0
tests/ut/data/dataset/golden/bounding_box_augment_valid_ratio_c_result.npz
...aset/golden/bounding_box_augment_valid_ratio_c_result.npz
+0
-0
tests/ut/data/dataset/golden/random_crop_with_bbox_01_c_result.npz
...data/dataset/golden/random_crop_with_bbox_01_c_result.npz
+0
-0
tests/ut/data/dataset/golden/random_horizontal_flip_with_bbox_01_c_result.npz
...t/golden/random_horizontal_flip_with_bbox_01_c_result.npz
+0
-0
tests/ut/data/dataset/golden/random_resize_with_bbox_op_01_c_result.npz
...dataset/golden/random_resize_with_bbox_op_01_c_result.npz
+0
-0
tests/ut/data/dataset/golden/random_resized_crop_with_bbox_01_c_result.npz
...aset/golden/random_resized_crop_with_bbox_01_c_result.npz
+0
-0
tests/ut/data/dataset/golden/random_vertical_flip_with_bbox_01_c_result.npz
...set/golden/random_vertical_flip_with_bbox_01_c_result.npz
+0
-0
tests/ut/data/dataset/golden/resize_with_bbox_op_01_c_result.npz
...t/data/dataset/golden/resize_with_bbox_op_01_c_result.npz
+0
-0
tests/ut/python/dataset/test_batch.py
tests/ut/python/dataset/test_batch.py
+3
-5
tests/ut/python/dataset/test_bounding_box_augment.py
tests/ut/python/dataset/test_bounding_box_augment.py
+182
-188
tests/ut/python/dataset/test_center_crop.py
tests/ut/python/dataset/test_center_crop.py
+3
-8
tests/ut/python/dataset/test_config.py
tests/ut/python/dataset/test_config.py
+6
-1
tests/ut/python/dataset/test_filterop.py
tests/ut/python/dataset/test_filterop.py
+14
-43
tests/ut/python/dataset/test_pad.py
tests/ut/python/dataset/test_pad.py
+5
-9
tests/ut/python/dataset/test_random_crop_and_resize_with_bbox.py
...t/python/dataset/test_random_crop_and_resize_with_bbox.py
+60
-156
tests/ut/python/dataset/test_random_crop_with_bbox.py
tests/ut/python/dataset/test_random_crop_with_bbox.py
+56
-155
tests/ut/python/dataset/test_random_horizontal_flip_bbox.py
tests/ut/python/dataset/test_random_horizontal_flip_bbox.py
+0
-266
tests/ut/python/dataset/test_random_horizontal_flip_with_bbox.py
...t/python/dataset/test_random_horizontal_flip_with_bbox.py
+229
-0
tests/ut/python/dataset/test_random_resize_with_bbox.py
tests/ut/python/dataset/test_random_resize_with_bbox.py
+126
-197
tests/ut/python/dataset/test_random_vertical_flip_with_bbox.py
.../ut/python/dataset/test_random_vertical_flip_with_bbox.py
+74
-21
tests/ut/python/dataset/test_resize_with_bbox.py
tests/ut/python/dataset/test_resize_with_bbox.py
+99
-229
tests/ut/python/dataset/util.py
tests/ut/python/dataset/util.py
+18
-13
tests/ut/python/ops/test_signature.py
tests/ut/python/ops/test_signature.py
+0
-22
tests/ut/python/parallel/test_parallel_optimizer.py
tests/ut/python/parallel/test_parallel_optimizer.py
+114
-0
tests/ut/python/parallel/test_set_auto_parallel_context.py
tests/ut/python/parallel/test_set_auto_parallel_context.py
+4
-0
tests/ut/python/pynative_mode/ge/model/__init__.py
tests/ut/python/pynative_mode/ge/model/__init__.py
+0
-0
tests/ut/python/pynative_mode/ge/model/test_lenet_model.py
tests/ut/python/pynative_mode/ge/model/test_lenet_model.py
+0
-70
tests/ut/python/utils/test_serialize.py
tests/ut/python/utils/test_serialize.py
+1
-0
未找到文件。
akg
@
df57a6cf
Subproject commit
c460176523d039c8995f1d71089753725ebc0792
Subproject commit
df57a6cf9450e347d1854687d1fe66a420ee3b35
mindspore/ccsrc/CMakeLists.txt
浏览文件 @
5a886794
...
@@ -277,10 +277,11 @@ endif ()
...
@@ -277,10 +277,11 @@ endif ()
if
(
USE_GLOG
)
if
(
USE_GLOG
)
target_link_libraries
(
inference PRIVATE mindspore::glog
)
target_link_libraries
(
inference PRIVATE mindspore::glog
)
else
()
if
(
CMAKE_SYSTEM_NAME MATCHES
"Linux"
)
target_link_options
(
inference PRIVATE -Wl,-init,mindspore_log_init
)
elseif
(
CMAKE_SYSTEM_NAME MATCHES
"Darwin"
)
set_target_properties
(
inference PROPERTIES MACOSX_RPATH ON
)
endif
()
endif
()
endif
()
if
(
CMAKE_SYSTEM_NAME MATCHES
"Linux"
)
target_link_options
(
inference PRIVATE -Wl,-init,common_log_init
)
elseif
(
CMAKE_SYSTEM_NAME MATCHES
"Darwin"
)
set_target_properties
(
inference PROPERTIES MACOSX_RPATH ON
)
endif
()
mindspore/ccsrc/dataset/kernels/image/random_crop_and_resize_with_bbox_op.cc
浏览文件 @
5a886794
...
@@ -30,8 +30,7 @@ Status RandomCropAndResizeWithBBoxOp::Compute(const TensorRow &input, TensorRow
...
@@ -30,8 +30,7 @@ Status RandomCropAndResizeWithBBoxOp::Compute(const TensorRow &input, TensorRow
BOUNDING_BOX_CHECK
(
input
);
BOUNDING_BOX_CHECK
(
input
);
CHECK_FAIL_RETURN_UNEXPECTED
(
input
[
0
]
->
shape
().
Size
()
>=
2
,
"The shape of input is abnormal"
);
CHECK_FAIL_RETURN_UNEXPECTED
(
input
[
0
]
->
shape
().
Size
()
>=
2
,
"The shape of input is abnormal"
);
(
*
output
).
push_back
(
nullptr
);
// init memory for return vector
output
->
resize
(
2
);
(
*
output
).
push_back
(
nullptr
);
(
*
output
)[
1
]
=
std
::
move
(
input
[
1
]);
// move boxes over to output
(
*
output
)[
1
]
=
std
::
move
(
input
[
1
]);
// move boxes over to output
size_t
bboxCount
=
input
[
1
]
->
shape
()[
0
];
// number of rows in bbox tensor
size_t
bboxCount
=
input
[
1
]
->
shape
()[
0
];
// number of rows in bbox tensor
...
...
mindspore/ccsrc/dataset/kernels/image/random_crop_with_bbox_op.cc
浏览文件 @
5a886794
...
@@ -36,8 +36,7 @@ Status RandomCropWithBBoxOp::Compute(const TensorRow &input, TensorRow *output)
...
@@ -36,8 +36,7 @@ Status RandomCropWithBBoxOp::Compute(const TensorRow &input, TensorRow *output)
int32_t
padded_image_h
;
int32_t
padded_image_h
;
int32_t
padded_image_w
;
int32_t
padded_image_w
;
(
*
output
).
push_back
(
nullptr
);
output
->
resize
(
2
);
(
*
output
).
push_back
(
nullptr
);
(
*
output
)[
1
]
=
std
::
move
(
input
[
1
]);
// since some boxes may be removed
(
*
output
)[
1
]
=
std
::
move
(
input
[
1
]);
// since some boxes may be removed
bool
crop_further
=
true
;
// Whether further cropping will be required or not, true unless required size matches
bool
crop_further
=
true
;
// Whether further cropping will be required or not, true unless required size matches
...
...
mindspore/ccsrc/dataset/kernels/image/random_vertical_flip_with_bbox_op.cc
浏览文件 @
5a886794
...
@@ -45,8 +45,7 @@ Status RandomVerticalFlipWithBBoxOp::Compute(const TensorRow &input, TensorRow *
...
@@ -45,8 +45,7 @@ Status RandomVerticalFlipWithBBoxOp::Compute(const TensorRow &input, TensorRow *
RETURN_IF_NOT_OK
(
input
[
1
]
->
SetItemAt
({
i
,
1
},
newBoxCorner_y
));
RETURN_IF_NOT_OK
(
input
[
1
]
->
SetItemAt
({
i
,
1
},
newBoxCorner_y
));
}
}
(
*
output
).
push_back
(
nullptr
);
output
->
resize
(
2
);
(
*
output
).
push_back
(
nullptr
);
(
*
output
)[
1
]
=
std
::
move
(
input
[
1
]);
(
*
output
)[
1
]
=
std
::
move
(
input
[
1
]);
return
VerticalFlip
(
input
[
0
],
&
(
*
output
)[
0
]);
return
VerticalFlip
(
input
[
0
],
&
(
*
output
)[
0
]);
...
...
mindspore/ccsrc/dataset/util/CMakeLists.txt
浏览文件 @
5a886794
...
@@ -2,6 +2,8 @@ file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc"
...
@@ -2,6 +2,8 @@ file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc"
set_property
(
SOURCE
${
_CURRENT_SRC_FILES
}
PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD
)
set_property
(
SOURCE
${
_CURRENT_SRC_FILES
}
PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD
)
add_library
(
utils OBJECT
add_library
(
utils OBJECT
arena.cc
arena.cc
buddy.cc
cache_pool.cc
circular_pool.cc
circular_pool.cc
memory_pool.cc
memory_pool.cc
cond_var.cc
cond_var.cc
...
@@ -11,7 +13,11 @@ add_library(utils OBJECT
...
@@ -11,7 +13,11 @@ add_library(utils OBJECT
service.cc
service.cc
services.cc
services.cc
lock.cc
lock.cc
semaphore.cc
status.cc
status.cc
storage_container.cc
storage_manager.cc
slice.cc
path.cc
path.cc
wait_post.cc
wait_post.cc
sig_handler.cc
)
sig_handler.cc
)
mindspore/ccsrc/dataset/util/allocator.h
浏览文件 @
5a886794
...
@@ -17,8 +17,10 @@
...
@@ -17,8 +17,10 @@
#define DATASET_UTIL_ALLOCATOR_H_
#define DATASET_UTIL_ALLOCATOR_H_
#include <cstdlib>
#include <cstdlib>
#include <functional>
#include <memory>
#include <memory>
#include <type_traits>
#include <type_traits>
#include <utility>
#include "dataset/util/memory_pool.h"
#include "dataset/util/memory_pool.h"
namespace
mindspore
{
namespace
mindspore
{
...
@@ -84,6 +86,91 @@ class Allocator {
...
@@ -84,6 +86,91 @@ class Allocator {
private:
private:
std
::
shared_ptr
<
MemoryPool
>
pool_
;
std
::
shared_ptr
<
MemoryPool
>
pool_
;
};
};
/// \brief It is a wrapper of unique_ptr with a custom allocator and acts like std::lock_guard such that the memory will
/// be released when the object goes out of scope \tparam T The type of object to be allocated \tparam C Allocator.
/// Default to std::allocator
template
<
typename
T
,
typename
C
=
std
::
allocator
<
T
>
>
class
MemGuard
{
public:
using
allocator
=
C
;
MemGuard
()
:
n_
(
0
)
{}
explicit
MemGuard
(
allocator
a
)
:
n_
(
0
),
alloc_
(
a
)
{}
// There is no copy constructor nor assignment operator because the memory is solely owned by this object.
MemGuard
(
const
MemGuard
&
)
=
delete
;
MemGuard
&
operator
=
(
const
MemGuard
&
)
=
delete
;
// On the other hand, We can support move constructor
MemGuard
(
MemGuard
&&
lhs
)
noexcept
:
alloc_
(
std
::
move
(
lhs
.
alloc_
)),
ptr_
(
std
::
move
(
lhs
.
ptr_
)),
n_
(
lhs
.
n_
)
{}
MemGuard
&
operator
=
(
MemGuard
&&
lhs
)
noexcept
{
if
(
this
!=
&
lhs
)
{
this
->
deallocate
();
n_
=
lhs
.
n_
;
alloc_
=
std
::
move
(
lhs
.
alloc_
);
ptr_
=
std
::
move
(
lhs
.
ptr_
);
}
return
*
this
;
}
/// \brief Explicitly deallocate the memory if allocated
void
deallocate
()
{
if
(
ptr_
)
{
auto
*
p
=
ptr_
.
release
();
if
(
!
std
::
is_arithmetic
<
T
>::
value
&&
std
::
is_destructible
<
T
>::
value
)
{
for
(
auto
i
=
0
;
i
<
n_
;
++
i
)
{
p
[
i
].
~
T
();
}
}
alloc_
.
deallocate
(
p
,
n_
);
n_
=
0
;
}
}
/// \brief Allocate memory (with emplace feature). Previous one will be released. If size is 0, no new memory is
/// allocated.
/// \param n Number of objects of type T to be allocated
/// \tparam Args Extra arguments pass to the constructor of T
template
<
typename
...
Args
>
Status
allocate
(
size_t
n
,
Args
&&
...
args
)
noexcept
{
try
{
deallocate
();
if
(
n
>
0
)
{
T
*
data
=
alloc_
.
allocate
(
n
);
if
(
!
std
::
is_arithmetic
<
T
>::
value
)
{
for
(
auto
i
=
0
;
i
<
n
;
i
++
)
{
std
::
allocator_traits
<
C
>::
construct
(
alloc_
,
&
(
data
[
i
]),
std
::
forward
<
Args
>
(
args
)...);
}
}
ptr_
=
std
::
unique_ptr
<
T
[]
>
(
data
);
n_
=
n
;
}
}
catch
(
const
std
::
bad_alloc
&
e
)
{
return
Status
(
StatusCode
::
kOutOfMemory
);
}
catch
(
std
::
exception
&
e
)
{
RETURN_STATUS_UNEXPECTED
(
e
.
what
());
}
return
Status
::
OK
();
}
~
MemGuard
()
noexcept
{
deallocate
();
}
/// \brief Getter function
/// \return The pointer to the memory allocated
T
*
GetPointer
()
const
{
return
ptr_
.
get
();
}
/// \brief Getter function
/// \return The pointer to the memory allocated
T
*
GetMutablePointer
()
{
return
ptr_
.
get
();
}
/// \brief Overload [] operator to access a particular element
/// \param x index to the element. Must be less than number of element allocated.
/// \return pointer to the x-th element
T
*
operator
[](
size_t
x
)
{
return
GetMutablePointer
()
+
x
;
}
/// \brief Overload [] operator to access a particular element
/// \param x index to the element. Must be less than number of element allocated.
/// \return pointer to the x-th element
T
*
operator
[](
size_t
x
)
const
{
return
GetPointer
()
+
x
;
}
/// \brief Return how many bytes are allocated in total
/// \return Number of bytes allocated in total
size_t
GetSizeInBytes
()
const
{
return
n_
*
sizeof
(
T
);
}
private:
allocator
alloc_
;
std
::
unique_ptr
<
T
[],
std
::
function
<
void
(
T
*
)
>>
ptr_
;
size_t
n_
;
};
}
// namespace dataset
}
// namespace dataset
}
// namespace mindspore
}
// namespace mindspore
...
...
mindspore/ccsrc/dataset/util/auto_index.h
浏览文件 @
5a886794
...
@@ -91,7 +91,7 @@ class AutoIndexObj : public BPlusTree<int64_t, T, A> {
...
@@ -91,7 +91,7 @@ class AutoIndexObj : public BPlusTree<int64_t, T, A> {
}
}
private:
private:
static
constexpr
key_type
kMinKey
=
1
;
static
constexpr
key_type
kMinKey
=
0
;
std
::
atomic
<
key_type
>
inx_
;
std
::
atomic
<
key_type
>
inx_
;
};
};
}
// namespace dataset
}
// namespace dataset
...
...
mindspore/ccsrc/dataset/util/buddy.cc
0 → 100644
浏览文件 @
5a886794
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "dataset/util/buddy.h"
#include <iomanip>
#include <stdexcept>
#include "dataset/util/de_error.h"
#include "dataset/util/memory_pool.h"
#include "dataset/util/system_pool.h"
#include "./securec.h"
inline
uint64_t
BitLeftShift
(
uint64_t
v
,
uint64_t
n
)
{
return
(
v
<<
n
);
}
inline
uint64_t
BitRightShift
(
uint64_t
v
,
uint64_t
n
)
{
return
(
v
>>
n
);
}
inline
uint64_t
BitOr
(
uint64_t
rhs
,
uint64_t
lhs
)
{
return
rhs
|
lhs
;
}
inline
uint64_t
BitEx
(
uint64_t
rhs
,
uint64_t
lhs
)
{
return
rhs
^
lhs
;
}
inline
uint64_t
BitAnd
(
uint64_t
rhs
,
uint64_t
lhs
)
{
return
rhs
&
lhs
;
}
namespace
mindspore
{
namespace
dataset
{
Status
BuddySpace
::
Init
()
{
if
(
log_min_
<
0
)
{
return
Status
(
StatusCode
::
kUnexpectedError
,
__LINE__
,
__FILE__
,
"log_min must be positive : "
+
std
::
to_string
(
log_min_
));
}
if
(
num_lvl_
<
3
||
num_lvl_
>
18
)
{
return
Status
(
StatusCode
::
kUnexpectedError
,
__LINE__
,
__FILE__
,
"num_lvl must be between 3 and 18 : "
+
std
::
to_string
(
num_lvl_
));
}
min_
=
BitLeftShift
(
1
,
log_min_
);
max_
=
BitLeftShift
(
1
,
log_min_
+
num_lvl_
-
1
);
size_t
offset_1
=
sizeof
(
rel_addr_t
)
*
num_lvl_
;
size_t
offset_2
=
sizeof
(
int
)
*
num_lvl_
+
offset_1
;
size_t
offset_3
=
sizeof
(
char
)
*
BitLeftShift
(
1
,
num_lvl_
-
3
)
+
offset_2
;
RETURN_IF_NOT_OK
(
DeMalloc
(
offset_3
,
&
ptr_
,
true
));
hint_
=
reinterpret_cast
<
rel_addr_t
*>
(
ptr_
);
count_
=
reinterpret_cast
<
int
*>
((
reinterpret_cast
<
char
*>
(
ptr_
)
+
offset_1
));
map_
=
reinterpret_cast
<
char
*>
(
ptr_
)
+
offset_2
;
count_
[
num_lvl_
-
1
]
=
1
;
map_
[
0
]
=
BitOr
(
MORE_BIT
,
num_lvl_
-
3
);
return
Status
::
OK
();
}
Status
BuddySpace
::
Alloc
(
const
uint64_t
sz
,
BSpaceDescriptor
*
desc
,
addr_t
*
p
)
noexcept
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
mutex_
);
addr_t
addr
=
AllocNoLock
(
sz
,
desc
);
if
(
addr
!=
NOSPACE
)
{
*
p
=
addr
;
return
Status
::
OK
();
}
else
{
return
Status
(
StatusCode
::
kNoSpace
,
"BuddySpace full. Not an error. Please ignore."
);
}
}
addr_t
BuddySpace
::
AllocNoLock
(
const
uint64_t
sz
,
BSpaceDescriptor
*
desc
)
noexcept
{
DS_ASSERT
(
sz
<=
max_
);
uint32_t
reqSize
=
SizeToBlock
(
sz
);
rel_addr_t
rel_addr
=
AllocBuddySeg
(
reqSize
);
if
(
rel_addr
!=
static_cast
<
rel_addr_t
>
(
NOSPACE
))
{
(
void
)
memset_s
(
desc
,
sizeof
(
BSpaceDescriptor
),
0
,
sizeof
(
BSpaceDescriptor
));
desc
->
sig
=
static_cast
<
int
>
(
0xDEADBEEF
);
desc
->
addr
=
rel_addr
;
desc
->
req_size
=
reqSize
;
desc
->
blk_size
=
NextPowerOf2
(
reqSize
);
return
static_cast
<
addr_t
>
(
rel_addr
*
min_
);
}
else
{
return
NOSPACE
;
}
}
void
BuddySpace
::
FreeNoLock
(
const
BSpaceDescriptor
*
desc
)
{
DS_ASSERT
(
desc
->
sig
==
0XDEADBEEF
);
rel_addr_t
rel_addr
=
desc
->
addr
;
size_t
blk_size
=
desc
->
blk_size
;
size_t
req_size
=
desc
->
req_size
;
FreeBuddySeg
(
rel_addr
,
blk_size
,
req_size
);
}
void
BuddySpace
::
Free
(
const
BSpaceDescriptor
*
desc
)
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
mutex_
);
return
FreeNoLock
(
desc
);
}
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
BuddySpace
&
s
)
{
os
<<
"1 unit = "
<<
s
.
GetMinSize
()
<<
"
\n
"
<<
"Size of buddy space = "
<<
s
.
GetMaxSize
()
<<
"
\n
"
<<
"Number of levels = "
<<
s
.
num_lvl_
<<
"
\n\n
"
<<
"Percent free = "
<<
s
.
PercentFree
()
<<
"
\n
"
<<
"Dumping count array : "
<<
"
\n
"
;
for
(
int
i
=
0
;
i
<
s
.
num_lvl_
;
i
++
)
{
os
<<
"["
<<
i
<<
"] = "
<<
s
.
count_
[
i
]
<<
" "
;
if
(((
i
+
1
)
%
4
)
==
0
)
{
os
<<
"
\n
"
;
}
}
os
<<
"
\n
"
;
os
<<
"Dumping allocation info:"
<<
"
\n
"
;
auto
max_addr
=
static_cast
<
rel_addr_t
>
(
BitLeftShift
(
1
,
s
.
num_lvl_
-
1
));
rel_addr_t
addr
=
0
;
while
(
addr
<
max_addr
)
{
size_t
sz
=
0
;
BuddySpace
::
STATE
st
;
s
.
GetBuddySegState
(
addr
,
&
sz
,
&
st
);
os
<<
"Address : "
<<
std
::
left
<<
std
::
setw
(
8
)
<<
addr
<<
" Size : "
<<
std
::
setw
(
8
)
<<
sz
<<
" State : "
<<
((
st
==
BuddySpace
::
STATE
::
kAlloc
)
?
"ALLOC"
:
((
st
==
BuddySpace
::
STATE
::
kFree
)
?
"FREE"
:
"Unkonwn"
))
<<
"
\n
"
;
addr
+=
sz
;
}
return
os
;
}
void
BuddySpace
::
GetBuddySegState
(
const
rel_addr_t
rel_addr
,
size_t
*
rel_sz
,
STATE
*
st
)
const
{
char
byte
;
int
pos
;
int
offset
;
uint64_t
val
=
0
;
int
shift
;
pos
=
BitRightShift
(
rel_addr
,
2
);
offset
=
rel_addr
%
4
;
shift
=
offset
*
2
;
byte
=
map_
[
pos
];
switch
(
offset
)
{
case
0
:
val
=
byte
;
break
;
case
1
:
case
3
:
if
(
offset
==
1
)
{
val
=
BitLeftShift
(
BitAnd
(
byte
,
0x30
),
shift
);
}
else
{
val
=
BitLeftShift
(
BitAnd
(
byte
,
0x03
),
shift
);
}
break
;
case
2
:
val
=
BitLeftShift
(
BitAnd
(
byte
,
0x0F
),
shift
);
break
;
}
if
(
BitAnd
(
val
,
ONE_BIT
))
{
*
rel_sz
=
1
;
}
else
if
(
BitAnd
(
val
,
TWO_BIT
))
{
*
rel_sz
=
2
;
}
else
if
(
BitAnd
(
val
,
MORE_BIT
))
{
log_t
lg
=
BitAnd
(
val
,
0x0F
);
*
rel_sz
=
BitLeftShift
(
1
,
lg
+
2
);
}
else
{
*
st
=
STATE
::
kEmpty
;
return
;
}
*
st
=
BitAnd
(
val
,
ALLOC_BIT
)
?
STATE
::
kAlloc
:
STATE
::
kFree
;
}
void
BuddySpace
::
SetBuddySegState
(
rel_addr_t
rel_addr
,
size_t
rel_sz
,
STATE
st
)
{
int
clr
;
int
mask
;
int
pos
;
int
offset
;
int
val
=
0
;
int
shift
;
auto
log_sz
=
static_cast
<
log_t
>
(
Log2
(
rel_sz
));
pos
=
BitRightShift
(
rel_addr
,
2
);
offset
=
rel_addr
%
4
;
shift
=
offset
*
2
;
if
(
rel_sz
==
1
)
{
val
=
ONE_BIT
;
mask
=
0xC0
;
}
else
if
(
rel_sz
==
2
)
{
val
=
TWO_BIT
;
mask
=
0xF0
;
}
else
{
val
=
BitOr
(
log_sz
-
2
,
MORE_BIT
);
mask
=
0xFF
;
}
if
(
st
==
STATE
::
kAlloc
)
{
val
=
BitOr
(
val
,
ALLOC_BIT
);
}
else
if
(
st
==
STATE
::
kFree
)
{
val
=
BitAnd
(
val
,
~
(
static_cast
<
uint64_t
>
(
ALLOC_BIT
)));
}
else
if
(
st
==
STATE
::
kEmpty
)
{
val
=
0
;
}
clr
=
static_cast
<
int
>
(
~
(
BitRightShift
(
mask
,
shift
)));
map_
[
pos
]
=
static_cast
<
char
>
(
BitAnd
(
map_
[
pos
],
clr
));
map_
[
pos
]
=
static_cast
<
char
>
(
BitOr
(
map_
[
pos
],
BitRightShift
(
val
,
shift
)));
if
(
st
==
STATE
::
kAlloc
)
{
count_
[
log_sz
]
--
;
}
else
if
(
st
==
STATE
::
kFree
)
{
count_
[
log_sz
]
++
;
if
(
rel_addr
<
hint_
[
log_sz
])
{
hint_
[
log_sz
]
=
rel_addr
;
}
}
}
void
BuddySpace
::
JoinBuddySeg
(
rel_addr_t
addr
,
size_t
blk_sz
)
{
while
(
blk_sz
<
BitLeftShift
(
1
,
num_lvl_
))
{
rel_addr_t
buddy
=
BitEx
(
addr
,
blk_sz
);
size_t
sz
=
0
;
STATE
st
;
GetBuddySegState
(
buddy
,
&
sz
,
&
st
);
if
(
st
==
STATE
::
kFree
&&
sz
==
blk_sz
)
{
auto
log_sz
=
static_cast
<
log_t
>
(
Log2
(
blk_sz
));
rel_addr_t
left
=
(
buddy
<
addr
)
?
buddy
:
addr
;
rel_addr_t
right
=
left
+
blk_sz
;
DS_ASSERT
(
count_
[
log_sz
]
>=
2
);
count_
[
log_sz
]
-=
2
;
SetBuddySegState
(
right
,
blk_sz
,
STATE
::
kEmpty
);
SetBuddySegState
(
left
,
BitLeftShift
(
blk_sz
,
1
),
STATE
::
kFree
);
for
(
int
i
=
0
;
i
<
log_sz
;
i
++
)
{
if
(
hint_
[
i
]
==
right
)
{
hint_
[
i
]
=
left
;
}
}
addr
=
left
;
blk_sz
<<=
1u
;
}
else
{
break
;
}
}
}
void
BuddySpace
::
TrimBuddySeg
(
rel_addr_t
addr
,
size_t
blk_sz
,
size_t
ask_sz
)
{
DS_ASSERT
(
ask_sz
<
blk_sz
);
uint32_t
inx
=
Log2
(
blk_sz
);
size_t
remaining_sz
=
ask_sz
;
for
(
int
i
=
inx
;
i
>
0
;
i
--
)
{
size_t
b_size
=
BitLeftShift
(
1
,
i
);
size_t
half_sz
=
BitRightShift
(
b_size
,
1
);
count_
[
i
]
--
;
SetBuddySegState
(
addr
,
half_sz
,
STATE
::
kFree
);
SetBuddySegState
(
addr
+
half_sz
,
half_sz
,
STATE
::
kFree
);
if
(
remaining_sz
>=
half_sz
)
{
SetBuddySegState
(
addr
,
half_sz
,
STATE
::
kAlloc
);
remaining_sz
-=
half_sz
;
if
(
remaining_sz
==
0
)
{
break
;
}
addr
+=
half_sz
;
}
}
}
void
BuddySpace
::
UnTrimBuddySeg
(
rel_addr_t
addr
,
size_t
blk_sz
,
size_t
ask_sz
)
{
DS_ASSERT
(
ask_sz
<
blk_sz
);
uint32_t
inx
=
Log2
(
blk_sz
);
size_t
remaining_sz
=
ask_sz
;
for
(
int
i
=
inx
;
i
>
0
;
i
--
)
{
size_t
b_size
=
BitLeftShift
(
1
,
i
);
size_t
half_sz
=
BitRightShift
(
b_size
,
1
);
if
(
remaining_sz
>=
half_sz
)
{
#ifdef DEBUG
{
size_t
sz
=
0
;
STATE
st
;
GetBuddySegState
(
addr
,
&
sz
,
&
st
);
DS_ASSERT
(
sz
==
half_sz
&&
st
==
STATE
::
kAlloc
);
}
#endif
SetBuddySegState
(
addr
,
half_sz
,
STATE
::
kFree
);
remaining_sz
-=
half_sz
;
if
(
remaining_sz
==
0
)
{
JoinBuddySeg
(
addr
,
half_sz
);
break
;
}
addr
+=
half_sz
;
}
}
}
rel_addr_t
BuddySpace
::
AllocBuddySeg
(
uint32_t
req_size
)
noexcept
{
uint32_t
blk_size
=
NextPowerOf2
(
req_size
);
int
start_inx
=
static_cast
<
int
>
(
Log2
(
blk_size
));
bool
found
=
false
;
rel_addr_t
ask_addr
=
0
;
auto
max_addr
=
static_cast
<
rel_addr_t
>
(
BitLeftShift
(
1
,
num_lvl_
-
1
));
STATE
st
;
size_t
sz
=
0
;
for
(
int
i
=
start_inx
;
!
found
&&
i
<
num_lvl_
;
i
++
)
{
DS_ASSERT
(
count_
[
i
]
>=
0
);
if
(
count_
[
i
]
==
0
)
{
continue
;
}
auto
blk_sz
=
static_cast
<
size_t
>
(
BitLeftShift
(
1
,
i
));
ask_addr
=
hint_
[
i
];
while
(
ask_addr
<
max_addr
&&
!
found
)
{
GetBuddySegState
(
ask_addr
,
&
sz
,
&
st
);
if
(
st
==
STATE
::
kFree
&&
sz
==
blk_sz
)
{
found
=
true
;
}
else
{
DS_ASSERT
(
st
!=
STATE
::
kEmpty
);
ask_addr
+=
((
sz
>
blk_sz
)
?
sz
:
blk_sz
);
}
}
}
if
(
found
)
{
if
(
sz
>
req_size
)
{
TrimBuddySeg
(
ask_addr
,
sz
,
req_size
);
}
else
{
SetBuddySegState
(
ask_addr
,
sz
,
STATE
::
kAlloc
);
hint_
[
start_inx
]
=
ask_addr
;
}
return
ask_addr
;
}
else
{
return
static_cast
<
rel_addr_t
>
(
NOSPACE
);
}
}
void
BuddySpace
::
FreeBuddySeg
(
rel_addr_t
addr
,
size_t
blk_size
,
size_t
req_size
)
{
if
(
req_size
==
blk_size
)
{
#ifdef DEBUG
{
size_t
sz
=
0
;
STATE
st
;
GetBuddySegState
(
addr
,
&
sz
,
&
st
);
}
#endif
SetBuddySegState
(
addr
,
blk_size
,
STATE
::
kFree
);
JoinBuddySeg
(
addr
,
blk_size
);
}
else
{
UnTrimBuddySeg
(
addr
,
blk_size
,
req_size
);
}
}
int
BuddySpace
::
PercentFree
()
const
{
uint64_t
total_free_sz
=
0
;
uint64_t
max_sz_in_unit
=
BitLeftShift
(
1
,
num_lvl_
-
1
);
// Go through the count array without lock
for
(
int
i
=
0
;
i
<
num_lvl_
;
i
++
)
{
int
cnt
=
count_
[
i
];
if
(
cnt
==
0
)
{
continue
;
}
uint64_t
blk_sz
=
BitLeftShift
(
1
,
i
);
total_free_sz
+=
(
blk_sz
*
cnt
);
}
return
static_cast
<
int
>
(
static_cast
<
float
>
(
total_free_sz
)
/
static_cast
<
float
>
(
max_sz_in_unit
)
*
100
);
}
// Constructor only records the configuration; all allocation of the
// bookkeeping arrays happens later in Init() (called via CreateBuddySpace).
// log_min: log2 of the minimum block size; num_lvl: number of buddy levels.
BuddySpace::BuddySpace(int log_min, int num_lvl)
    : hint_(nullptr),
      count_(nullptr),
      map_(nullptr),
      log_min_(log_min),
      num_lvl_(num_lvl),
      min_(0),
      max_(0),
      ptr_(nullptr) {}
// Destructor releases the single backing allocation; hint_/count_/map_ are
// interior pointers into ptr_, so they are only nulled, never freed.
BuddySpace::~BuddySpace() {
  // free(nullptr) is defined as a no-op, so no guard is needed.
  free(ptr_);
  hint_ = nullptr;
  count_ = nullptr;
  map_ = nullptr;
}
// Factory: construct a BuddySpace, run its two-phase Init(), and hand
// ownership to the caller only on success. Same observable behavior as the
// manual new/delete version, but exception/early-return safe via unique_ptr.
Status BuddySpace::CreateBuddySpace(std::unique_ptr<BuddySpace> *out_bs, int log_min, int num_lvl) {
  std::unique_ptr<BuddySpace> bs(new (std::nothrow) BuddySpace(log_min, num_lvl));
  if (bs == nullptr) {
    return Status(StatusCode::kOutOfMemory);
  }
  Status rc = bs->Init();
  if (rc.IsOk()) {
    // Transfer ownership to the caller; on failure bs frees itself.
    out_bs->reset(bs.release());
  }
  return rc;
}
}
// namespace dataset
}
// namespace mindspore
mindspore/ccsrc/dataset/util/buddy.h
0 → 100644
浏览文件 @
5a886794
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DATASET_UTIL_BUDDY_H_
#define DATASET_UTIL_BUDDY_H_
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <iostream>
#include <memory>
#include <mutex>
#include "dataset/util/status.h"
using
addr_t
=
int64_t
;
using
rel_addr_t
=
int32_t
;
using
log_t
=
int
;
#define ALLOC_BIT 0x80
#define ONE_BIT 0x40
#define TWO_BIT 0x20
#define MORE_BIT 0x10
#define NOSPACE ((addr_t)(-1))
namespace
mindspore
{
namespace
dataset
{
// Descriptor returned by BuddySpace::Alloc and consumed by Free; records
// where and how big an allocation is so it can be released correctly.
struct BSpaceDescriptor {
  int32_t sig;         // signature/validity marker for the descriptor
  rel_addr_t addr;     // relative address of the segment inside the space
  size_t req_size;     // size the caller originally asked for
  size_t blk_size;     // power-of-two block size actually reserved
};
// A buddy-system space manager: carves a logical address range into
// power-of-two segments, supporting Alloc/Free with coalescing.
// Thread safety is provided internally via mutex_ (see AllocNoLock/FreeNoLock
// split between locked entry points and lock-free workers).
class BuddySpace {
 public:
  // C++11 feature. Change STATE into a type safe class with
  // the keyword. Don't take out the keyword 'class'
  enum class STATE { kFree, kAlloc, kEmpty };

  BuddySpace(const BuddySpace &) = delete;

  BuddySpace &operator=(const BuddySpace &) = delete;

  virtual ~BuddySpace();

  // Allocate sz bytes; fills in *desc (needed later by Free) and the
  // out parameter with the resulting address.
  Status Alloc(uint64_t sz, BSpaceDescriptor *desc, addr_t *) noexcept;

  // Release an allocation previously described by Alloc.
  void Free(const BSpaceDescriptor *desc);

  // Minimum allocatable size (set during Init; 0 before then).
  uint64_t GetMinSize() const { return min_; }

  // Maximum allocatable size (set during Init; 0 before then).
  uint64_t GetMaxSize() const { return max_; }

  // Approximate free percentage; reads counters without locking.
  int PercentFree() const;

  friend std::ostream &operator<<(std::ostream &os, const BuddySpace &s);

  // Round n up to the next power of two (returns n itself if n is already
  // a power of two; returns 1 for n <= 1).
  static uint64_t NextPowerOf2(uint64_t n) {
    if (n <= 1) {
      return 1;
    }
    n = n - 1;
    // Clear low-order set bits until only the highest one remains.
    while (n & (n - 1)) {
      n = n & (n - 1);
    }
    return n << 1;
  }

  // Floor of log2(n); returns 0 for n <= 1.
  static uint32_t Log2(uint64_t n) {
    uint32_t cnt = 0;
    while (n >>= 1) {
      cnt++;
    }
    return cnt;
  }

  // Factory. Allocates, initializes, and returns a BuddySpace through
  // out_bs. Defaults: 32K minimum block (2^15), 18 levels.
  static Status CreateBuddySpace(std::unique_ptr<BuddySpace> *out_bs, int log_min = 15, int num_lvl = 18);

 private:
  rel_addr_t *hint_;   // per-level last-seen free segment, to speed up scans
  int *count_;         // per-level count of free segments
  char *map_;          // state map of the whole space
  int log_min_;        // log2 of the minimum block size
  int num_lvl_;        // number of buddy levels
  uint64_t min_;       // minimum allocatable size in bytes
  uint64_t max_;       // maximum allocatable size in bytes
  void *ptr_;          // single backing allocation holding the arrays above
  std::mutex mutex_;   // guards the public Alloc/Free entry points

  explicit BuddySpace(int log_min = 15, int num_lvl = 18);

  // Second construction phase; allocates the bookkeeping arrays.
  Status Init();

  addr_t AllocNoLock(const uint64_t sz, BSpaceDescriptor *desc) noexcept;

  void FreeNoLock(const BSpaceDescriptor *desc);

  // Convert a byte size to a count of minimum-size blocks, rounding up.
  uint32_t SizeToBlock(const uint64_t sz) const {
    uint32_t reqSize = (sz / min_);
    if (sz % min_) {
      reqSize++;
    }
    return reqSize;
  }

  void GetBuddySegState(const rel_addr_t rel_addr, size_t *rel_sz, STATE *st) const;

  void SetBuddySegState(rel_addr_t rel_addr, size_t rel_sz, STATE st);

  // Coalesce a freed segment with its buddy where possible.
  void JoinBuddySeg(rel_addr_t addr, size_t blk_sz);

  // Split a larger block so only ask_sz is allocated; rest stays free.
  void TrimBuddySeg(rel_addr_t addr, size_t blk_sz, size_t ask_sz);

  // Inverse of TrimBuddySeg, used when freeing a trimmed allocation.
  void UnTrimBuddySeg(rel_addr_t addr, size_t blk_sz, size_t ask_sz);

  rel_addr_t AllocBuddySeg(uint32_t req_size) noexcept;

  void FreeBuddySeg(rel_addr_t addr, size_t blk_size, size_t req_size);
};
}
// namespace dataset
}
// namespace mindspore
#endif // DATASET_UTIL_BUDDY_H_
mindspore/ccsrc/dataset/util/cache_pool.cc
0 → 100644
浏览文件 @
5a886794
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <algorithm>
#include "common/utils.h"
#include "dataset/util/cache_pool.h"
#include "dataset/util/services.h"
namespace
mindspore
{
namespace
dataset
{
// Constructor records the allocator and the optional spill root directory.
// A unique subfolder name is generated up front; the StorageManager (sm_)
// and the index (tree_) are created later in DoServiceStart().
CachePool::CachePool(const value_allocator &alloc, const std::string &root)
    : alloc_(alloc), root_(root), subfolder_(Services::GetUniqueID()), sm_(nullptr), tree_(nullptr) {}
// Service startup: create the in-memory index and, when a disk root was
// supplied, create the spill directory and start a StorageManager on it.
Status CachePool::DoServiceStart() {
  tree_ = std::make_shared<data_index>();
  // If we are given a disk path, set up the StorageManager
  if (!root_.toString().empty()) {
    Path spill = GetSpillPath();
    RETURN_IF_NOT_OK(spill.CreateDirectories());
    sm_ = std::make_shared<StorageManager>(spill);
    RETURN_IF_NOT_OK(sm_->ServiceStart());
    MS_LOG(INFO) << "CachePool will use disk folder: " << common::SafeCStr(spill.toString());
  }
  return Status::OK();
}
// Service shutdown: stop the StorageManager, release all in-memory cached
// blocks, and remove the spill directory (best effort). The first error
// encountered is remembered in rc2 and returned, but cleanup continues.
Status CachePool::DoServiceStop() {
  Status rc;
  Status rc2;
  if (sm_ != nullptr) {
    rc = sm_->ServiceStop();
    if (rc.IsError()) {
      rc2 = rc;
    }
  }
  sm_.reset();
  // FIX: tree_ is only created in DoServiceStart(); guard against a stop
  // without a prior (successful) start to avoid dereferencing a null
  // shared_ptr.
  if (tree_ != nullptr) {
    for (auto &bl : *tree_) {
      if (bl.ptr != nullptr) {
        alloc_.deallocate(bl.ptr, bl.sz);
      }
    }
  }
  tree_.reset();
  if (!root_.toString().empty()) {
    Path spill = GetSpillPath();
    auto it = Path::DirIterator::OpenDirectory(&spill);
    // FIX: OpenDirectory allocates with new(std::nothrow) and returns
    // nullptr on failure; the original dereferenced it unconditionally.
    if (it != nullptr) {
      while (it->hasNext()) {
        rc = it->next().Remove();
        if (rc.IsError() && rc2.IsOk()) {
          rc2 = rc;
        }
      }
    }
    rc = spill.Remove();
    if (rc.IsError() && rc2.IsOk()) {
      rc2 = rc;
    }
  }
  return rc2;
}
// Destructor ensures the service is stopped (frees memory, removes spill
// files). The Status result is intentionally discarded: destructors must
// not fail, and ServiceStop already logs/collects errors internally.
CachePool::~CachePool() noexcept { (void)ServiceStop(); }
// Back up a buffer given as a sequence of slices. All slices are copied
// into one contiguous memory block; if the allocator throws bad_alloc the
// buffer is spilled to disk instead (when a StorageManager exists).
// On success *key receives the generated index key for later Read().
Status CachePool::Insert(const std::vector<ReadableSlice> &buf, CachePool::key_type *key) {
  DataLocator bl;
  Status rc;
  size_t sz = 0;
  // We will consolidate all the slices into one piece.
  for (auto &v : buf) {
    sz += v.GetSize();
  }
  bl.sz = sz;
  try {
    bl.ptr = alloc_.allocate(sz);
    // We will do a piecewise copy.
    WritableSlice dest(bl.ptr, bl.sz);
    size_t pos = 0;
    for (auto &v : buf) {
      // 'out' is a window into dest starting at the current offset.
      WritableSlice out(dest, pos);
      rc = WritableSlice::Copy(&out, v);
      if (rc.IsError()) {
        break;
      }
      pos += v.GetSize();
    }
    if (rc.IsError()) {
      // Copy failed: give the block back before propagating the error.
      alloc_.deallocate(bl.ptr, sz);
      bl.ptr = nullptr;
      return rc;
    }
  } catch (std::bad_alloc &e) {
    // Memory is exhausted: spill to disk if we can, otherwise report OOM.
    if (sm_ != nullptr) {
      RETURN_IF_NOT_OK(sm_->Write(&bl.storage_key, buf));
      // We have an assumption 0 is not a valid key from the design of AutoIndexObj.
      // Make sure it is not 0.
      if (bl.storage_key == 0) {
        RETURN_STATUS_UNEXPECTED("Key 0 is returned which is unexpected");
      }
    } else {
      return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__);
    }
  }
  rc = tree_->insert(bl, key);
  // If indexing failed, release the memory copy (disk copy, if any, is
  // left to the StorageManager's own cleanup).
  if (rc.IsError() && bl.ptr != nullptr) {
    alloc_.deallocate(bl.ptr, sz);
  }
  return rc;
}
// Restore a previously inserted buffer into *dest. The data comes from the
// in-memory copy when present, otherwise from the StorageManager on disk.
// bytesRead (optional) receives the logical size of the cached buffer.
Status CachePool::Read(CachePool::key_type key, WritableSlice *dest, size_t *bytesRead) const {
  RETURN_UNEXPECTED_IF_NULL(dest);
  auto r = tree_->Search(key);
  if (r.second) {
    auto &it = r.first;
    if (it->ptr != nullptr) {
      // Memory-resident: straight copy into the destination slice.
      ReadableSlice src(it->ptr, it->sz);
      RETURN_IF_NOT_OK(WritableSlice::Copy(dest, src));
    } else if (sm_ != nullptr) {
      // Spilled: fetch from disk and verify the length matches the index.
      size_t expectedLength = 0;
      RETURN_IF_NOT_OK(sm_->Read(it->storage_key, dest, &expectedLength));
      if (expectedLength != it->sz) {
        MS_LOG(ERROR) << "Unexpected length. Read " << expectedLength << ". Expected " << it->sz << "."
                      << " Internal key: " << key << "\n";
        RETURN_STATUS_UNEXPECTED("Length mismatch. See log file for details.");
      }
    }
    if (bytesRead != nullptr) {
      *bytesRead = it->sz;
    }
  } else {
    RETURN_STATUS_UNEXPECTED("Key not found");
  }
  return Status::OK();
}
// Accessor for the allocator this pool copies memory-resident buffers with.
const CachePool::value_allocator &CachePool::get_allocator() const { return alloc_; }
// Full spill directory for this pool: the user-supplied root joined with
// the pool's unique subfolder name.
Path CachePool::GetSpillPath() const { return Path(root_) / subfolder_; }
// Walk the index and count how many entries currently live in memory
// versus on disk (ptr == nullptr means the entry was spilled).
CachePool::CacheStat CachePool::GetStat() const {
  CacheStat stats{0};
  for (auto &entry : *tree_) {
    if (entry.ptr == nullptr) {
      ++stats.num_disk_cached;
    } else {
      ++stats.num_mem_cached;
    }
  }
  return stats;
}
// Push a memory-resident entry out to disk and free its memory block.
// A storage_key of 0 means it has never been written; entries already
// written once are not re-written (the disk copy is reused).
Status CachePool::Spill(CachePool::DataLocator *dl) {
  if (sm_ == nullptr) {
    RETURN_STATUS_UNEXPECTED("No disk storage to spill");
  }
  RETURN_UNEXPECTED_IF_NULL(dl);
  RETURN_UNEXPECTED_IF_NULL(dl->ptr);
  if (dl->storage_key == 0) {
    ReadableSlice data(dl->ptr, dl->sz);
    RETURN_IF_NOT_OK(sm_->Write(&dl->storage_key, {data}));
  }
  // Only after a successful write do we drop the memory copy.
  alloc_.deallocate(dl->ptr, dl->sz);
  dl->ptr = nullptr;
  return Status::OK();
}
// Bring a spilled entry back into memory: allocate a block of dl->sz bytes
// and read the data into it. No-op if the entry is already in memory.
Status CachePool::Locate(CachePool::DataLocator *dl) {
  RETURN_UNEXPECTED_IF_NULL(dl);
  if (dl->ptr == nullptr) {
    if (sm_ == nullptr) {
      RETURN_STATUS_UNEXPECTED("No disk storage to locate the data");
    }
    try {
      dl->ptr = alloc_.allocate(dl->sz);
      WritableSlice dest(dl->ptr, dl->sz);
      // NOTE(review): this calls CachePool::Read, whose parameter is the
      // tree index key, but passes dl->storage_key (a StorageManager key).
      // Looks like it should be sm_->Read(...) — verify against callers.
      Status rc = Read(dl->storage_key, &dest);
      if (rc.IsError()) {
        // Read failed: roll back the allocation.
        alloc_.deallocate(dl->ptr, dl->sz);
        dl->ptr = nullptr;
        return rc;
      }
    } catch (const std::bad_alloc &e) {
      return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__);
    }
  }
  return Status::OK();
}
// Return the logical size of the cached buffer for a key, or 0 when the
// key is not present in the index.
size_t CachePool::GetSize(CachePool::key_type key) const {
  auto result = tree_->Search(key);
  return result.second ? result.first->sz : 0;
}
}
// namespace dataset
}
// namespace mindspore
mindspore/ccsrc/dataset/util/cache_pool.h
0 → 100644
浏览文件 @
5a886794
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DATASET_UTIL_CACHE_POOL_H_
#define DATASET_UTIL_CACHE_POOL_H_
#include <memory>
#include <mutex>
#include <string>
#include <vector>
#include "dataset/util/allocator.h"
#include "dataset/util/service.h"
#include "dataset/util/slice.h"
#include "dataset/util/storage_manager.h"
#include "dataset/util/auto_index.h"
namespace
mindspore
{
namespace
dataset
{
/// \brief A CachePool provides service for backup/restore a buffer. A buffer can be represented in a form of vector of
/// ReadableSlice where all memory blocks will be copied to one contiguous block which can be in memory or spilled to
/// disk (if a disk directory is provided). Every buffer insert will return a generated key which can be used to
/// restore the buffer.
/// \see ReadableSlice
class CachePool : public Service {
 public:
  using base_type = uint8_t;
  using pointer = base_type *;
  using const_pointer = const base_type *;
  using reference = base_type &;
  using const_reference = const base_type &;
  using value_allocator = Allocator<base_type>;

  // An internal class to locate the whereabouts of a backed up buffer which
  // can be either in memory (ptr != nullptr) or on disk (via storage_key).
  class DataLocator {
   public:
    DataLocator() : ptr(nullptr), sz(0), storage_key(0) {}

    ~DataLocator() = default;

    DataLocator(const DataLocator &other) = default;

    DataLocator &operator=(const DataLocator &other) = default;

    // Move transfers ownership of the raw pointer and resets the source so
    // the memory block is never double-freed.
    DataLocator(DataLocator &&other) noexcept {
      ptr = other.ptr;
      sz = other.sz;
      storage_key = other.storage_key;
      other.ptr = nullptr;
      other.sz = 0;
      other.storage_key = 0;
    }

    DataLocator &operator=(DataLocator &&other) noexcept {
      if (&other != this) {
        ptr = other.ptr;
        sz = other.sz;
        storage_key = other.storage_key;
        other.ptr = nullptr;
        other.sz = 0;
        other.storage_key = 0;
      }
      return *this;
    }

    pointer ptr;                          // in-memory copy, or nullptr when spilled
    size_t sz;                            // logical size of the buffer in bytes
    StorageManager::key_type storage_key; // disk key; 0 means never written
  };

  using data_index = AutoIndexObj<DataLocator>;
  using key_type = data_index::key_type;
  using bl_alloc_type = typename value_allocator::template rebind<DataLocator>::other;

  /// \brief Simple statistics returned from CachePool like how many elements are cached in memory and
  /// how many elements are spilled to disk.
  struct CacheStat {
    int64_t num_mem_cached;
    int64_t num_disk_cached;
  };

  /// \brief Constructor
  /// \param alloc Allocator to allocate memory from
  /// \param root Optional disk folder to spill
  explicit CachePool(const value_allocator &alloc, const std::string &root = "");

  CachePool(const CachePool &) = delete;

  CachePool(CachePool &&) = delete;

  CachePool &operator=(const CachePool &) = delete;

  CachePool &operator=(CachePool &&) = delete;

  ~CachePool() noexcept;

  Status DoServiceStart() override;

  Status DoServiceStop() override;

  // Root directory joined with this pool's unique subfolder.
  Path GetSpillPath() const;

  /// \brief Insert a sequence of ReadableSlice objects into the pool.
  /// All memory blocks will be consolidated into one contiguous block and be cached in either memory or on disk.
  /// \param[in] buf A sequence of ReadableSlice objects.
  /// \param[out] key Generated key
  /// \return Error code
  Status Insert(const std::vector<ReadableSlice> &buf, key_type *key);

  /// \brief Restore a cached buffer (from memory or disk)
  /// \param[in] key A previous key returned from Insert
  /// \param[out] dest The cached buffer will be copied to this destination represented by a WritableSlice
  /// \param[out] bytesRead Optional. Number of bytes read.
  /// \return Error code
  Status Read(key_type key, WritableSlice *dest, size_t *bytesRead = nullptr) const;

  // Move an in-memory entry to disk and free its memory block.
  Status Spill(DataLocator *dl);

  // Bring a spilled entry back into memory.
  Status Locate(DataLocator *dl);

  // Logical size of the buffer for key, or 0 if the key is unknown.
  size_t GetSize(key_type key) const;

  /// \brief Get statistics.
  /// \return CacheStat object
  CacheStat GetStat() const;

  const value_allocator &get_allocator() const;

  // The pool's unique subfolder name, used as its identity.
  std::string MyName() const { return subfolder_; }

 private:
  value_allocator alloc_;              // allocator for memory-resident copies
  Path root_;                          // optional spill root (empty = memory only)
  const std::string subfolder_;        // unique per-pool directory name
  std::shared_ptr<StorageManager> sm_; // disk backend; nullptr when memory only
  std::shared_ptr<data_index> tree_;   // key -> DataLocator index
};
}
// namespace dataset
}
// namespace mindspore
#endif
mindspore/ccsrc/dataset/util/list.h
浏览文件 @
5a886794
...
@@ -106,6 +106,24 @@ struct List {
...
@@ -106,6 +106,24 @@ struct List {
++
count
;
++
count
;
}
}
  // Insert elem2 before elem1 in the list.
  // Links are reached through the pointer-to-member 'node' (intrusive list);
  // head is updated when elem1 was the first element. elem1 is assumed to be
  // already in the list and elem2 not — TODO confirm with callers.
  virtual void InsertBefore(pointer elem1, pointer elem2) {
    DS_ASSERT(elem1 != elem2);
    Node<T> &elem1_node = elem1->*node;
    Node<T> &elem2_node = elem2->*node;
    // Wire elem2 between elem1's predecessor and elem1.
    elem2_node.next = elem1;
    elem2_node.prev = elem1_node.prev;
    if (elem1_node.prev != nullptr) {
      Node<T> &prev_node = elem1_node.prev->*node;
      prev_node.next = elem2;
    }
    elem1_node.prev = elem2;
    if (head == elem1) {
      head = elem2;
    }
    ++count;
  }
// Remove an element in the list
// Remove an element in the list
virtual
void
Remove
(
pointer
elem
)
noexcept
{
virtual
void
Remove
(
pointer
elem
)
noexcept
{
Node
<
T
>
&
elem_node
=
elem
->*
node
;
Node
<
T
>
&
elem_node
=
elem
->*
node
;
...
...
mindspore/ccsrc/dataset/util/memory_pool.h
浏览文件 @
5a886794
...
@@ -44,20 +44,6 @@ class MemoryPool {
...
@@ -44,20 +44,6 @@ class MemoryPool {
virtual
~
MemoryPool
()
{}
virtual
~
MemoryPool
()
{}
};
};
// Used by unique_ptr
// Custom deleter for std::unique_ptr: returns the pointer to its MemoryPool
// instead of calling delete. Holds a shared_ptr so the pool outlives every
// outstanding unique_ptr that uses this deleter.
template <typename T>
class Deleter {
 public:
  explicit Deleter(std::shared_ptr<MemoryPool> &mp) : mp_(mp) {}

  ~Deleter() = default;

  // Invoked by unique_ptr on destruction/reset.
  void operator()(T *ptr) const { mp_->Deallocate(ptr); }

 private:
  std::shared_ptr<MemoryPool> mp_;
};
Status
DeMalloc
(
std
::
size_t
s
,
void
**
p
,
bool
);
Status
DeMalloc
(
std
::
size_t
s
,
void
**
p
,
bool
);
}
// namespace dataset
}
// namespace dataset
}
// namespace mindspore
}
// namespace mindspore
...
...
mindspore/ccsrc/dataset/util/path.cc
浏览文件 @
5a886794
...
@@ -16,6 +16,8 @@
...
@@ -16,6 +16,8 @@
#include "dataset/util/path.h"
#include "dataset/util/path.h"
#include <sys/stat.h>
#include <sys/stat.h>
#include <fcntl.h>
#include <unistd.h>
#include <new>
#include <new>
#include <sstream>
#include <sstream>
#include <utility>
#include <utility>
...
@@ -26,7 +28,7 @@
...
@@ -26,7 +28,7 @@
namespace
mindspore
{
namespace
mindspore
{
namespace
dataset
{
namespace
dataset
{
#if
def _WIN32
#if
defined(_WIN32) || defined(_WIN64)
char
Path
::
separator_
=
'\\'
;
char
Path
::
separator_
=
'\\'
;
#else
#else
char
Path
::
separator_
=
'/'
;
char
Path
::
separator_
=
'/'
;
...
@@ -132,7 +134,7 @@ Status Path::CreateDirectory() {
...
@@ -132,7 +134,7 @@ Status Path::CreateDirectory() {
#if defined(_WIN32) || defined(_WIN64)
#if defined(_WIN32) || defined(_WIN64)
int
rc
=
mkdir
(
common
::
SafeCStr
(
path_
));
int
rc
=
mkdir
(
common
::
SafeCStr
(
path_
));
#else
#else
int
rc
=
mkdir
(
common
::
SafeCStr
(
path_
),
0700
);
int
rc
=
mkdir
(
common
::
SafeCStr
(
path_
),
S_IRUSR
|
S_IWUSR
|
S_IXUSR
);
#endif
#endif
if
(
rc
)
{
if
(
rc
)
{
std
::
ostringstream
oss
;
std
::
ostringstream
oss
;
...
@@ -182,6 +184,111 @@ Status Path::CreateDirectories() {
...
@@ -182,6 +184,111 @@ Status Path::CreateDirectories() {
return
Status
::
OK
();
return
Status
::
OK
();
}
}
// Delete the file or (empty) directory at this path. Non-existent paths are
// treated as success. Directories are removed with rmdir, so a non-empty
// directory fails with an error status carrying errno.
Status Path::Remove() {
  if (Exists()) {
    if (IsDirectory()) {
      errno_t err = rmdir(common::SafeCStr(path_));
      if (err == -1) {
        std::ostringstream oss;
        oss << "Unable to delete directory " << path_ << ". Errno = " << errno;
        RETURN_STATUS_UNEXPECTED(oss.str());
      }
    } else {
      errno_t err = unlink(common::SafeCStr(path_));
      if (err == -1) {
        std::ostringstream oss;
        oss << "Unable to delete file " << path_ << ". Errno = " << errno;
        RETURN_STATUS_UNEXPECTED(oss.str());
      }
    }
  }
  return Status::OK();
}
// Create (or truncate) the file at this path and return its descriptor.
// Thin wrapper over OpenFile with create = true.
Status Path::CreateFile(int *file_descriptor) { return OpenFile(file_descriptor, true); }
// Open the file at this path for read/write, optionally creating it.
// The path is first canonicalized (realpath/_fullpath); when 'create' is set
// and the file does not exist yet, the parent directory is canonicalized and
// the file name is appended manually.
// \param[out] file_descriptor Receives the open fd on success.
// \param[in] create When true, open with O_CREAT|O_TRUNC (mode 0640).
Status Path::OpenFile(int *file_descriptor, bool create) {
  int fd;
  if (file_descriptor == nullptr) {
    RETURN_STATUS_UNEXPECTED("null pointer");
  }
  if (IsDirectory()) {
    std::ostringstream oss;
    oss << "Unable to create file " << path_ << " which is a directory.";
    RETURN_STATUS_UNEXPECTED(oss.str());
  }
  // Convert to canonical form.
  // NOTE(review): errno is not set by strlen, so strerror(errno) here can
  // report a stale/unrelated message for an over-long path.
  if (strlen(common::SafeCStr(path_)) > PATH_MAX) {
    RETURN_STATUS_UNEXPECTED(strerror(errno));
  }
  char canonical_path[PATH_MAX + 1] = {0x00};
#if defined(_WIN32) || defined(_WIN64)
  if (_fullpath(canonical_path, common::SafeCStr(path_), PATH_MAX) == nullptr) {
#else
  if (realpath(common::SafeCStr(path_), canonical_path) == nullptr) {
#endif
    if (errno == ENOENT && create) {
      // File doesn't exist and we are to create it. Let's break it down.
      auto file_part = Basename();
      auto parent_part = ParentPath();
      // Canonicalize the parent (which must exist), then append the
      // file name with an explicit separator.
#if defined(_WIN32) || defined(_WIN64)
      if (_fullpath(canonical_path, common::SafeCStr(parent_part), PATH_MAX) == nullptr) {
#else
      if (realpath(common::SafeCStr(parent_part), canonical_path) == nullptr) {
#endif
        RETURN_STATUS_UNEXPECTED(strerror(errno));
      }
      auto cur_inx = strlen(canonical_path);
      if ((cur_inx + file_part.length() + 1) > PATH_MAX) {
        RETURN_STATUS_UNEXPECTED(strerror(errno));
      }
      canonical_path[cur_inx++] = separator_;
      if (strncpy_s(canonical_path + cur_inx, PATH_MAX - cur_inx, common::SafeCStr(file_part),
                    file_part.length()) != EOK) {
        RETURN_STATUS_UNEXPECTED(strerror(errno));
      }
    } else {
      RETURN_STATUS_UNEXPECTED(strerror(errno));
    }
  }
  if (create) {
    fd = open(canonical_path, O_CREAT | O_TRUNC | O_RDWR, S_IRUSR | S_IWUSR | S_IRGRP);
  } else {
    fd = open(canonical_path, O_RDWR);
  }
  if (fd == -1) {
    RETURN_STATUS_UNEXPECTED(strerror(errno));
  }
  *file_descriptor = fd;
  return Status::OK();
}
// Close a descriptor previously opened through OpenFile/CreateFile.
// close(2) reports failure with a negative return; surface the errno text.
Status Path::CloseFile(int fd) const {
  const int ret = close(fd);
  if (ret >= 0) {
    return Status::OK();
  }
  RETURN_STATUS_UNEXPECTED(strerror(errno));
}
// Shrink the open file to zero length, keeping the descriptor usable.
Status Path::TruncateFile(int fd) const {
  if (ftruncate(fd, 0) != 0) {
    RETURN_STATUS_UNEXPECTED(strerror(errno));
  }
  return Status::OK();
}
// Final path component: everything after the last separator, or the whole
// path when no separator is present.
std::string Path::Basename() {
  const auto pos = path_.find_last_of(separator_);
  return (pos == std::string::npos) ? path_ : path_.substr(pos + 1);
}
std
::
shared_ptr
<
Path
::
DirIterator
>
Path
::
DirIterator
::
OpenDirectory
(
Path
*
f
)
{
std
::
shared_ptr
<
Path
::
DirIterator
>
Path
::
DirIterator
::
OpenDirectory
(
Path
*
f
)
{
auto
it
=
new
(
std
::
nothrow
)
DirIterator
(
f
);
auto
it
=
new
(
std
::
nothrow
)
DirIterator
(
f
);
...
@@ -208,7 +315,7 @@ Path::DirIterator::~DirIterator() {
...
@@ -208,7 +315,7 @@ Path::DirIterator::~DirIterator() {
Path
::
DirIterator
::
DirIterator
(
Path
*
f
)
:
dir_
(
f
),
dp_
(
nullptr
),
entry_
(
nullptr
)
{
Path
::
DirIterator
::
DirIterator
(
Path
*
f
)
:
dir_
(
f
),
dp_
(
nullptr
),
entry_
(
nullptr
)
{
MS_LOG
(
DEBUG
)
<<
"Open directory "
<<
f
->
toString
()
<<
"."
;
MS_LOG
(
DEBUG
)
<<
"Open directory "
<<
f
->
toString
()
<<
"."
;
dp_
=
opendir
(
common
::
SafeCStr
(
f
->
toString
()
));
dp_
=
opendir
(
f
->
toString
().
c_str
(
));
}
}
bool
Path
::
DirIterator
::
hasNext
()
{
bool
Path
::
DirIterator
::
hasNext
()
{
...
@@ -225,5 +332,10 @@ bool Path::DirIterator::hasNext() {
...
@@ -225,5 +332,10 @@ bool Path::DirIterator::hasNext() {
}
}
Path
Path
::
DirIterator
::
next
()
{
return
(
*
(
this
->
dir_
)
/
Path
(
entry_
->
d_name
));
}
Path
Path
::
DirIterator
::
next
()
{
return
(
*
(
this
->
dir_
)
/
Path
(
entry_
->
d_name
));
}
// Stream insertion: prints the raw path string.
std::ostream &operator<<(std::ostream &os, const Path &s) {
  os << s.path_;
  return os;
}
}
// namespace dataset
}
// namespace dataset
}
// namespace mindspore
}
// namespace mindspore
mindspore/ccsrc/dataset/util/path.h
浏览文件 @
5a886794
...
@@ -90,6 +90,20 @@ class Path {
...
@@ -90,6 +90,20 @@ class Path {
std
::
string
ParentPath
();
std
::
string
ParentPath
();
Status
Remove
();
Status
CreateFile
(
int
*
fd
);
Status
OpenFile
(
int
*
fd
,
bool
create
=
false
);
Status
CloseFile
(
int
fd
)
const
;
Status
TruncateFile
(
int
fd
)
const
;
std
::
string
Basename
();
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
Path
&
s
);
private:
private:
static
char
separator_
;
static
char
separator_
;
std
::
string
path_
;
std
::
string
path_
;
...
...
mindspore/ccsrc/dataset/util/semaphore.cc
0 → 100644
浏览文件 @
5a886794
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "dataset/util/semaphore.h"
#include "dataset/util/task_manager.h"
namespace
mindspore
{
namespace
dataset
{
// P (wait): block until the counter is positive, then decrement it.
// The wait is interruptible; an interrupt surfaces as a non-OK Status.
Status Semaphore::P() {
  std::unique_lock<std::mutex> lck(mutex_);
  // CondVar::Wait releases the lock while blocked and re-acquires it before
  // the predicate is re-evaluated.
  RETURN_IF_NOT_OK(wait_cond_.Wait(&lck, [this]() { return value_ > 0; }));
  --value_;
  return Status::OK();
}
// V (signal): increment the counter and wake one waiter, if any.
void Semaphore::V() {
  std::lock_guard<std::mutex> guard(mutex_);
  ++value_;
  wait_cond_.NotifyOne();
}
// Snapshot of the current counter value, taken under the lock.
int Semaphore::Peek() {
  std::lock_guard<std::mutex> guard(mutex_);
  return value_;
}
// Register the internal condition variable with the task group's interrupt
// service so blocked P() calls can be interrupted.
Status Semaphore::Register(TaskGroup *vg) { return wait_cond_.Register(vg->GetIntrpService()); }
// Detach the condition variable from the interrupt service.
Status Semaphore::Deregister() { return wait_cond_.Deregister(); }
// Clear any pending interrupt state on the condition variable so the
// semaphore can be waited on again after an interrupt.
void Semaphore::ResetIntrpState() { wait_cond_.ResetIntrpState(); }
}
// namespace dataset
}
// namespace mindspore
mindspore/ccsrc/dataset/util/semaphore.h
0 → 100644
浏览文件 @
5a886794
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DATASET_UTIL_SEMAPHORE_H_
#define DATASET_UTIL_SEMAPHORE_H_
#include "dataset/util/cond_var.h"
namespace
mindspore
{
namespace
dataset
{
/// \brief A counting semaphore. There are two external functions P and V. P decrements the internal count and will be
/// blocked if the count is 0 (zero). V increments the internal count and wake up one of the waiters.
class Semaphore {
 public:
  /// \brief Constructor
  /// \param init Initial value of the internal counter.
  explicit Semaphore(int init) : value_(init) {}

  virtual ~Semaphore() {}

  /// \brief Decrement the internal counter. Will be blocked if the value is 0.
  /// \return Error code. Can get interrupt.
  Status P();

  /// \brief Increment the internal counter. Wake up one of the waiters if any.
  void V();

  /// \brief Peek the internal value
  /// \return The internal value
  int Peek();

  /// \brief Hook the semaphore up to the task group's interrupt service.
  Status Register(TaskGroup *vg);

  /// \brief Detach from the interrupt service.
  Status Deregister();

  /// \brief Clear any pending interrupt state.
  void ResetIntrpState();

 private:
  int value_;           // current counter; guarded by mutex_
  std::mutex mutex_;
  CondVar wait_cond_;   // waiters blocked in P() sleep here
};
}
// namespace dataset
}
// namespace mindspore
#endif // DATASET_UTIL_SEMAPHORE_H_
mindspore/ccsrc/dataset/util/slice.cc
0 → 100644
浏览文件 @
5a886794
/**
* Copyright 2019 Huawei Technologies Co., Ltd
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "dataset/util/slice.h"
namespace
mindspore
{
namespace
dataset
{
// Sub-slice constructor: a writable window of 'len' bytes starting at
// 'offset' within an existing writable slice. The read-only view is set up
// by the base class; the mutable pointer is advanced by the same offset.
WritableSlice::WritableSlice(const WritableSlice &src, off64_t offset, size_t len) : ReadableSlice(src, offset, len) {
  mutable_data_ = static_cast<char *>(src.mutable_data_) + offset;
}
// Convenience sub-slice: from 'offset' to the end of the source slice.
WritableSlice::WritableSlice(const WritableSlice &src, off64_t offset)
    : WritableSlice(src, offset, src.GetSize() - offset) {}
// Copy the contents of 'src' into 'dest'. memcpy_s enforces the bounds:
// a source longer than the destination is reported as an error code, which
// is surfaced as an unexpected Status.
Status WritableSlice::Copy(WritableSlice *dest, const ReadableSlice &src) {
  RETURN_UNEXPECTED_IF_NULL(dest);
  RETURN_UNEXPECTED_IF_NULL(dest->GetMutablePointer());
  // GetSize() is size_t, so this effectively rejects an empty destination.
  if (dest->GetSize() <= 0) {
    RETURN_STATUS_UNEXPECTED("Destination length is non-positive");
  }
  auto err = memcpy_s(dest->GetMutablePointer(), dest->GetSize(), src.GetPointer(), src.GetSize());
  if (err) {
    RETURN_STATUS_UNEXPECTED(std::to_string(err));
  }
  return Status::OK();
}
}
// namespace dataset
}
// namespace mindspore
mindspore/ccsrc/dataset/util/slice.h
0 → 100644
浏览文件 @
5a886794
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DATASET_UTIL_SLICE_H_
#define DATASET_UTIL_SLICE_H_
#include <unistd.h>
#include <cstddef>
#include <utility>
#include "./securec.h"
#include "dataset/util/allocator.h"
#include "dataset/util/status.h"
namespace
mindspore
{
namespace
dataset
{
/// \brief A ReadableSlice wraps a const pointer in memory and its size.
/// \see WritableSlice for a non-const version
///
class ReadableSlice {
 public:
  ReadableSlice() : ptr_(nullptr), sz_(0) {}

  // Wrap an existing buffer; the slice does NOT own the memory.
  ReadableSlice(const void *ptr, size_t sz) : ptr_(ptr), sz_(sz) {}

  // Sub-slice: 'len' bytes starting at 'offset' within 'src'.
  ReadableSlice(const ReadableSlice &src, off64_t offset, size_t len) {
    ptr_ = static_cast<const char *>(src.GetPointer()) + offset;
    sz_ = len;
  }

  // Sub-slice from 'offset' to the end of 'src'.
  ReadableSlice(const ReadableSlice &src, off64_t offset) : ReadableSlice(src, offset, src.sz_ - offset) {}

  // Copies share the same underlying buffer (shallow copy by design).
  ReadableSlice(const ReadableSlice &lhs) {
    ptr_ = lhs.ptr_;
    sz_ = lhs.sz_;
  }

  ReadableSlice &operator=(const ReadableSlice &lhs) {
    if (this != &lhs) {
      ptr_ = lhs.ptr_;
      sz_ = lhs.sz_;
    }
    return *this;
  }

  // Move clears the source so it no longer refers to the buffer.
  ReadableSlice(ReadableSlice &&lhs) noexcept {
    if (this != &lhs) {
      ptr_ = lhs.ptr_;
      sz_ = lhs.sz_;
      lhs.ptr_ = nullptr;
      lhs.sz_ = 0;
    }
  }

  ReadableSlice &operator=(ReadableSlice &&lhs) noexcept {
    if (this != &lhs) {
      ptr_ = lhs.ptr_;
      sz_ = lhs.sz_;
      lhs.ptr_ = nullptr;
      lhs.sz_ = 0;
    }
    return *this;
  }

  /// \brief Getter function
  /// \return Const version of the pointer
  const void *GetPointer() const { return ptr_; }

  /// \brief Getter function
  /// \return Size of the slice
  size_t GetSize() const { return sz_; }

  // A slice is "empty" when it wraps no buffer at all (null pointer);
  // note this is independent of sz_.
  bool empty() const { return ptr_ == nullptr; }

 private:
  const void *ptr_;  // non-owning pointer into someone else's buffer
  size_t sz_;        // number of bytes visible through this slice
};
/// \brief A WritableSlice inherits from ReadableSlice to allow
/// one to write to the address pointed to by the pointer.
///
class WritableSlice : public ReadableSlice {
 public:
  // StorageContainer needs access to the mutable pointer for direct I/O.
  friend class StorageContainer;

  /// \brief Default constructor
  WritableSlice() : ReadableSlice(), mutable_data_(nullptr) {}

  /// \brief This form of a constructor takes a pointer and its size.
  /// The slice does NOT own the memory.
  WritableSlice(void *ptr, size_t sz) : ReadableSlice(ptr, sz), mutable_data_(ptr) {}

  // Sub-slice constructors (defined in slice.cc).
  WritableSlice(const WritableSlice &src, off64_t offset, size_t len);

  WritableSlice(const WritableSlice &src, off64_t offset);

  // Copies share the same underlying buffer (shallow copy by design).
  WritableSlice(const WritableSlice &lhs) : ReadableSlice(lhs) { mutable_data_ = lhs.mutable_data_; }

  WritableSlice &operator=(const WritableSlice &lhs) {
    if (this != &lhs) {
      mutable_data_ = lhs.mutable_data_;
      ReadableSlice::operator=(lhs);
    }
    return *this;
  }

  // Move clears the source's mutable pointer; the base class clears the
  // read-only view.
  WritableSlice(WritableSlice &&lhs) noexcept : ReadableSlice(std::move(lhs)) {
    if (this != &lhs) {
      mutable_data_ = lhs.mutable_data_;
      lhs.mutable_data_ = nullptr;
    }
  }

  WritableSlice &operator=(WritableSlice &&lhs) noexcept {
    if (this != &lhs) {
      mutable_data_ = lhs.mutable_data_;
      lhs.mutable_data_ = nullptr;
      ReadableSlice::operator=(std::move(lhs));
    }
    return *this;
  }

  /// \brief Copy the content from one slice onto another.
  static Status Copy(WritableSlice *dest, const ReadableSlice &src);

 private:
  void *mutable_data_;  // writable alias of the base class pointer

  void *GetMutablePointer() { return mutable_data_; }
};
}
// namespace dataset
}
// namespace mindspore
#endif // DATASET_UTIL_SLICE_H_
mindspore/ccsrc/dataset/util/storage_container.cc
0 → 100644
浏览文件 @
5a886794
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "dataset/util/storage_container.h"
#include <fcntl.h>
#include <sys/stat.h>
#include <unistd.h>
#include <vector>
#include "common/utils.h"
#include "dataset/util/de_error.h"
#include "dataset/util/path.h"
#include "dataset/util/status.h"
#include "utils/log_adapter.h"
namespace
mindspore
{
namespace
dataset
{
// Create the backing resources of this container: a BuddySpace allocator that
// hands out offsets within the file, and the on-disk file itself (cont_ is the path).
// Marks the container open on success.
Status StorageContainer::Create() {
  RETURN_IF_NOT_OK(BuddySpace::CreateBuddySpace(&bs_));
  RETURN_IF_NOT_OK(cont_.CreateFile(&fd_));
  is_open_ = true;
  MS_LOG(INFO) << "Container " << cont_ << " created";
  return Status::OK();
}
// (Re)open the container file if it is not already open.
// Thread-safe: the open flag is tested under mutex_ so concurrent callers
// open the file at most once.
Status StorageContainer::Open() noexcept {
  std::lock_guard<std::mutex> lck(mutex_);
  // Check again under the lock: another thread may have opened it first.
  if (!is_open_) {
    RETURN_IF_NOT_OK(cont_.OpenFile(&fd_));
    is_open_ = true;
  }
  return Status::OK();
}
// Close the container file if it is open, invalidating fd_.
// Thread-safe with respect to Open()/Close() on other threads.
//
// Fix: the previous version read is_open_ once *before* taking mutex_ as a
// fast path. is_open_ is a plain bool, so that unsynchronized read races with
// the writes done under the lock in Open()/Close() — undefined behavior per
// the C++ memory model. Mirror Open(): acquire the lock first, then test.
Status StorageContainer::Close() noexcept {
  std::lock_guard<std::mutex> lck(mutex_);
  if (is_open_) {
    RETURN_IF_NOT_OK(cont_.CloseFile(fd_));
    is_open_ = false;
    fd_ = -1;  // poison the descriptor so stale use is caught early
  }
  return Status::OK();
}
// Read dest->GetSize() bytes starting at file offset `offset` into the slice.
// The container must already be open (debug-asserted).
// \param[out] dest Writable slice whose size dictates how many bytes to read
// \param[in] offset Byte offset within the container file
// \return Error status if the seek/read fails or fewer bytes than requested arrive.
Status StorageContainer::Read(WritableSlice *dest, off64_t offset) const noexcept {
  DS_ASSERT(is_open_);
  RETURN_UNEXPECTED_IF_NULL(dest);
  auto sz = dest->GetSize();
#if defined(_WIN32) || defined(_WIN64)
  // Doesn't seem there is any pread64 on mingw.
  // So we will do a seek and then a read under
  // a protection of mutex.
  std::lock_guard<std::mutex> lck(mutex_);
  auto seek_err = lseek(fd_, offset, SEEK_SET);
  if (seek_err < 0) {
    RETURN_STATUS_UNEXPECTED(strerror(errno));
  }
  auto r_sz = read(fd_, dest->GetMutablePointer(), sz);
#else
  // pread64 is atomic w.r.t. the file offset, so no lock is needed here.
  auto r_sz = pread64(fd_, dest->GetMutablePointer(), sz, offset);
#endif
  if (r_sz != sz) {
    // A zero-byte read means end-of-file rather than a system error.
    errno_t err = (r_sz == 0) ? EOF : errno;
    RETURN_STATUS_UNEXPECTED(strerror(err));
  }
  return Status::OK();
}
// Write the bytes of `dest` to the container file at offset `offset`.
// The container must already be open (debug-asserted).
// NOTE(review): the parameter is named `dest` but is actually the *source*
// of the write — kept as-is to match the declaration in the header.
// \param[in] dest Readable slice providing the bytes to write
// \param[in] offset Byte offset within the container file
// \return Error status if the seek/write fails or fewer bytes are written.
Status StorageContainer::Write(const ReadableSlice &dest, off64_t offset) const noexcept {
  DS_ASSERT(is_open_);
  auto sz = dest.GetSize();
#if defined(_WIN32) || defined(_WIN64)
  // Doesn't seem there is any pwrite64 on mingw.
  // So we will do a seek and then a read under
  // a protection of mutex.
  std::lock_guard<std::mutex> lck(mutex_);
  auto seek_err = lseek(fd_, offset, SEEK_SET);
  if (seek_err < 0) {
    RETURN_STATUS_UNEXPECTED(strerror(errno));
  }
  auto r_sz = write(fd_, dest.GetPointer(), sz);
#else
  // pwrite64 is atomic w.r.t. the file offset, so no lock is needed here.
  auto r_sz = pwrite64(fd_, dest.GetPointer(), sz, offset);
#endif
  if (r_sz != sz) {
    // A zero-byte write is mapped to EOF; otherwise report errno.
    errno_t err = (r_sz == 0) ? EOF : errno;
    RETURN_STATUS_UNEXPECTED(strerror(err));
  }
  return Status::OK();
}
// Append a logical record, given as a list of slices, to the container.
// The total size is allocated from the BuddySpace (which decides where in the
// file the record lives) and the slices are then written back-to-back.
// \param[in] buf Slices that together form one record; total size must be
//                non-zero and no larger than the BuddySpace maximum.
// \param[out] offset File offset where the record begins (the lookup key).
// \return NoSpace-style error from bs_->Alloc when the container is full —
//         the caller (StorageManager) reacts by opening a new container.
Status StorageContainer::Insert(const std::vector<ReadableSlice> &buf, off64_t *offset) noexcept {
  size_t sz = 0;
  for (auto &v : buf) {
    sz += v.GetSize();
  }
  if (sz == 0) {
    RETURN_STATUS_UNEXPECTED("Unexpected 0 length");
  }
  if (sz > bs_->GetMaxSize()) {
    RETURN_STATUS_UNEXPECTED("Request size too big");
  }
  BSpaceDescriptor bspd{0};
  addr_t addr = 0;
  RETURN_IF_NOT_OK(bs_->Alloc(sz, &bspd, &addr));
  *offset = static_cast<off64_t>(addr);
  // We will do piecewise copy of the data to disk.
  for (auto &v : buf) {
    RETURN_IF_NOT_OK(Write(v, addr));
    addr += v.GetSize();
  }
  return Status::OK();
}
// Discard the on-disk contents of the container (delegated to Path::TruncateFile).
// No-op when the container is not open.
Status StorageContainer::Truncate() const noexcept {
  if (is_open_) {
    RETURN_IF_NOT_OK(cont_.TruncateFile(fd_));
    MS_LOG(INFO) << "Container " << cont_ << " truncated";
  }
  return Status::OK();
}
// Destructor: best-effort cleanup — truncate the file then close it.
// Errors are deliberately ignored (cast to void); destructors must not throw.
StorageContainer::~StorageContainer() noexcept {
  (void)Truncate();
  (void)Close();
}
// Debug dump: print the container's file path followed by its BuddySpace state.
std::ostream &operator<<(std::ostream &os, const StorageContainer &s) {
  os << "File path : " << s.cont_ << "\n" << *(s.bs_.get());
  return os;
}
// Factory: construct a StorageContainer for `path`, create its backing file,
// and hand ownership to *out_sc on success.
// Uses new(std::nothrow) + reset rather than std::make_shared because the
// constructor is private (only this factory and friends may build one), and
// so an allocation failure maps to kOutOfMemory instead of std::bad_alloc.
Status StorageContainer::CreateStorageContainer(std::shared_ptr<StorageContainer> *out_sc, const std::string &path) {
  Status rc;
  auto sc = new (std::nothrow) StorageContainer(path);
  if (sc == nullptr) {
    return Status(StatusCode::kOutOfMemory);
  }
  rc = sc->Create();
  if (rc.IsOk()) {
    (*out_sc).reset(sc);
  } else {
    // Create() failed: the raw pointer was never adopted, so free it here.
    delete sc;
  }
  return rc;
}
}
// namespace dataset
}
// namespace mindspore
mindspore/ccsrc/dataset/util/storage_container.h
0 → 100644
浏览文件 @
5a886794
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DATASET_UTIL_STORAGE_CONTAINER_H_
#define DATASET_UTIL_STORAGE_CONTAINER_H_
#include <limits.h>
#include <unistd.h>
#include <memory>
#include <mutex>
#include <string>
#include <vector>
#include "dataset/util/system_pool.h"
#include "dataset/util/buddy.h"
#include "dataset/util/path.h"
#include "dataset/util/slice.h"
#include "dataset/util/status.h"
namespace
mindspore
{
namespace
dataset
{
class StorageManager;

/// \brief A StorageContainer is one on-disk file plus a BuddySpace allocator
/// that manages where records live inside that file. Records are written with
/// Insert() and fetched back by the file offset Insert() returned.
/// Construction goes through the CreateStorageContainer() factory only.
class StorageContainer {
 public:
  friend class StorageManager;

  ~StorageContainer() noexcept;

  // Not copyable: the container owns a file descriptor and allocator state.
  StorageContainer(const StorageContainer &) = delete;

  StorageContainer &operator=(const StorageContainer &) = delete;

  friend std::ostream &operator<<(std::ostream &os, const StorageContainer &s);

  /// \brief Open the backing file if not already open. Thread-safe.
  Status Open() noexcept;

  /// \brief Close the backing file if open. Thread-safe.
  Status Close() noexcept;

  /// \brief Append one record (a list of slices); returns its file offset.
  Status Insert(const std::vector<ReadableSlice> &buf, off64_t *offset) noexcept;

  /// \brief Write the slice's bytes at the given file offset.
  Status Write(const ReadableSlice &dest, off64_t offset) const noexcept;

  /// \brief Read dest->GetSize() bytes from the given file offset.
  Status Read(WritableSlice *dest, off64_t offset) const noexcept;

  /// \brief Discard the on-disk contents (no-op when closed).
  Status Truncate() const noexcept;

  bool IsOpen() const { return is_open_; }

  /// \brief Factory: build a container for `path` and create its file.
  static Status CreateStorageContainer(std::shared_ptr<StorageContainer> *out_sc, const std::string &path);

 private:
  // Guards open/close and the seek+read/write fallback on Windows;
  // mutable so the const Read()/Write() paths can lock it.
  mutable std::mutex mutex_;
  Path cont_;      // path of the backing file
  int fd_;         // file descriptor; -1 when closed
  bool is_open_;
  std::unique_ptr<BuddySpace> bs_;  // offset allocator within the file
  // Use the default value of BuddySpace
  // which can map upto 4G of space.
  explicit StorageContainer(const std::string &path) : cont_(path), fd_(-1), is_open_(false), bs_(nullptr) {}
  // Allocate the BuddySpace and create the on-disk file.
  Status Create();
};
}
// namespace dataset
}
// namespace mindspore
#endif // DATASET_UTIL_STORAGE_CONTAINER_H_
mindspore/ccsrc/dataset/util/storage_manager.cc
0 → 100644
浏览文件 @
5a886794
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "dataset/util/storage_manager.h"
#include <iomanip>
#include <sstream>
#include <stdexcept>
#include <utility>
#include "common/utils.h"
#include "dataset/util/path.h"
#include "dataset/util/services.h"
#include "dataset/util//de_error.h"
#include "utils/log_adapter.h"
namespace
mindspore
{
namespace
dataset
{
// Build the base file name for a container: the prefix followed by the
// file id rendered as a zero-padded, five-digit decimal (e.g. "IMG00007").
std::string StorageManager::GetBaseName(const std::string &prefix, int32_t file_id) {
  std::ostringstream name_stream;
  name_stream << prefix;
  name_stream << std::setfill('0') << std::setw(5) << file_id;
  return name_stream.str();
}
// Compose a full container file name: "<prefix><5-digit id>.<suffix>".
std::string StorageManager::ConstructFileName(const std::string &prefix, int32_t file_id, const std::string &suffix) {
  return GetBaseName(prefix, file_id) + "." + suffix;
}
// Create one more StorageContainer under root_, named "IMG<id>.LB", append it
// to containers_, and bump the running file id.
// NOTE(review): callers in Write() hold the exclusive rw_lock_ when invoking
// this; DoServiceStart() calls it before any concurrency begins.
Status StorageManager::AddOneContainer() {
  const std::string kPrefix = "IMG";
  const std::string kSuffix = "LB";
  Path container_name = root_ / ConstructFileName(kPrefix, file_id_, kSuffix);
  std::shared_ptr<StorageContainer> sc;
  RETURN_IF_NOT_OK(StorageContainer::CreateStorageContainer(&sc, container_name.toString()));
  containers_.push_back(sc);
  file_id_++;
  return Status::OK();
}
// Service start hook: verify root_ is an existing directory and seed the
// manager with its first container. The reserve() avoids reallocation of the
// container vector while readers hold only the shared lock.
Status StorageManager::DoServiceStart() {
  containers_.reserve(1000);
  if (root_.IsDirectory()) {
    RETURN_IF_NOT_OK(AddOneContainer());
  } else {
    RETURN_STATUS_UNEXPECTED("Not a directory");
  }
  return Status::OK();
}
// Store one record (the concatenation of the slices in `buf`) into the newest
// container and register its location in the index.
// \param[out] key Auto-generated index key for later Read().
// \param[in] buf Slices forming the record; total size must be non-zero.
//
// Concurrency: the normal path runs under a shared lock so many writers can
// insert into the (thread-safe via BuddySpace) last container concurrently.
// Only when that container reports NoSpace do we upgrade to the exclusive
// lock to append a fresh container, then downgrade and retry.
Status StorageManager::Write(key_type *key, const std::vector<ReadableSlice> &buf) {
  RETURN_UNEXPECTED_IF_NULL(key);
  size_t sz = 0;
  for (auto &v : buf) {
    sz += v.GetSize();
  }
  if (sz == 0) {
    RETURN_STATUS_UNEXPECTED("Unexpected 0 length");
  }
  std::shared_ptr<StorageContainer> cont;
  key_type out_key;
  value_type out_value;
  bool create_new_container = false;
  do {
    SharedLock lock_s(&rw_lock_);
    size_t num_containers = containers_.size();
    if (create_new_container) {
      // Upgrade to exclusive lock.
      lock_s.Upgrade();
      create_new_container = false;
      // Check again if someone has already added a
      // new container after we got the x lock
      if (containers_.size() == num_containers) {
        RETURN_IF_NOT_OK(AddOneContainer());
      }
      // Refresh how many containers there are.
      num_containers = containers_.size();
      // Downgrade back to shared lock
      lock_s.Downgrade();
    }
    if (num_containers == 0) {
      RETURN_STATUS_UNEXPECTED("num_containers is zero");
    }
    // Go to the last container to insert.
    cont = containers_.at(num_containers - 1);
    off64_t offset;
    Status rc = cont->Insert(buf, &offset);
    if (rc.IsNoSpace()) {
      // Container full: loop around and create a new one under the x lock.
      create_new_container = true;
    } else if (rc.IsOk()) {
      // Index value = (container #, (offset within container, record size)).
      out_value = std::make_pair(num_containers - 1, std::make_pair(offset, sz));
      RETURN_IF_NOT_OK(index_.insert(out_value, &out_key));
      *key = out_key;
      break;
    } else {
      return rc;
    }
  } while (true);
  return Status::OK();
}
// Fetch the record previously stored under `key` into `dest`.
// \param[in] key Key returned by Write().
// \param[out] dest Destination slice; must be at least as large as the record.
// \param[out] bytesRead Optional: receives the record size when non-null.
Status StorageManager::Read(StorageManager::key_type key, WritableSlice *dest, size_t *bytesRead) const {
  RETURN_UNEXPECTED_IF_NULL(dest);
  auto r = index_.Search(key);
  if (r.second) {
    auto &it = r.first;
    // Unpack (container #, (offset, size)) as stored by Write().
    value_type v = *it;
    int container_inx = v.first;
    off_t offset = v.second.first;
    size_t sz = v.second.second;
    if (dest->GetSize() < sz) {
      std::string errMsg = "Destination buffer too small. Expect at least " + std::to_string(sz) +
                           " but length = " + std::to_string(dest->GetSize());
      RETURN_STATUS_UNEXPECTED(errMsg);
    }
    if (bytesRead != nullptr) {
      *bytesRead = sz;
    }
    auto cont = containers_.at(container_inx);
    RETURN_IF_NOT_OK(cont->Read(dest, offset));
  } else {
    RETURN_STATUS_UNEXPECTED("Key not found");
  }
  return Status::OK();
}
// Service stop hook: truncate every container (best effort — the last error,
// if any, is the returned status), then drop all references and reset the id.
Status StorageManager::DoServiceStop() noexcept {
  Status rc;
  Status rc1;
  for (auto const &p : containers_) {
    // The destructor of StorageContainer is not called automatically until the use
    // count drops to 0. But it is not always the case. We will do it ourselves.
    rc = p.get()->Truncate();
    if (rc.IsError()) {
      rc1 = rc;
    }
  }
  containers_.clear();
  file_id_ = 0;
  return rc1;
}
// Construct a manager rooted at `root`; containers are created lazily by
// DoServiceStart()/Write(), starting from file id 0.
StorageManager::StorageManager(const Path &root) : root_(root), file_id_(0), index_() {}
// Destructor: run the (noexcept) stop sequence; the explicit qualification
// avoids dispatching to an override from a destructor.
StorageManager::~StorageManager() { (void)StorageManager::DoServiceStop(); }
// Debug dump: print every container owned by the manager.
std::ostream &operator<<(std::ostream &os, const StorageManager &s) {
  os << "Dumping all containers ..." << "\n";
  for (auto const &p : s.containers_) {
    os << *(p.get());
  }
  return os;
}
}
// namespace dataset
}
// namespace mindspore
mindspore/ccsrc/dataset/util/storage_manager.h
0 → 100644
浏览文件 @
5a886794
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DATASET_UTIL_STORAGE_MANAGER_H_
#define DATASET_UTIL_STORAGE_MANAGER_H_
#include <unistd.h>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "dataset/util/allocator.h"
#include "dataset/util/auto_index.h"
#include "dataset/util/lock.h"
#include "dataset/util/memory_pool.h"
#include "dataset/util/path.h"
#include "dataset/util/service.h"
#include "dataset/util/slice.h"
#include "dataset/util/storage_container.h"
using
ListOfContainers
=
std
::
vector
<
std
::
shared_ptr
<
mindspore
::
dataset
::
StorageContainer
>>
;
namespace
mindspore
{
namespace
dataset
{
/// \brief StorageManager spills records to a set of on-disk StorageContainers
/// under a root directory and indexes them by an auto-generated key.
/// Write() appends to the newest container (growing the set when one fills
/// up); Read() looks the key up in the index and reads the bytes back.
class StorageManager : public Service {
 public:
  // Index maps an auto key -> (container #, (offset in container, size)).
  using storage_index = AutoIndexObj<std::pair<int, std::pair<off_t, size_t>>>;
  using key_type = storage_index::key_type;
  using value_type = storage_index::value_type;

  explicit StorageManager(const Path &);

  ~StorageManager() override;

  // Not copyable: owns containers, a lock, and an index.
  StorageManager(const StorageManager &) = delete;

  StorageManager &operator=(const StorageManager &) = delete;

  /// \brief Store the record formed by `buf`; returns its key via out_key.
  Status Write(key_type *out_key, const std::vector<ReadableSlice> &buf);

  /// \brief Read the record stored under `key` into `dest`;
  /// bytesRead (optional) receives the record size.
  Status Read(key_type key, WritableSlice *dest, size_t *bytesRead) const;

  Status DoServiceStart() override;

  Status DoServiceStop() noexcept override;

  friend std::ostream &operator<<(std::ostream &os, const StorageManager &s);

 private:
  Path root_;                    // directory that holds all container files
  ListOfContainers containers_;  // grows as containers fill up
  int file_id_;                  // next container file number
  RWLock rw_lock_;               // shared for inserts, exclusive to add a container
  storage_index index_;          // key -> record location

  std::string GetBaseName(const std::string &prefix, int32_t file_id);

  std::string ConstructFileName(const std::string &prefix, int32_t file_id, const std::string &suffix);

  // Must be called with the exclusive lock held (or before service start).
  Status AddOneContainer();
};
}
// namespace dataset
}
// namespace mindspore
#endif // DATASET_UTIL_STORAGE_MANAGER_H_
mindspore/ccsrc/dataset/util/system_pool.h
浏览文件 @
5a886794
...
@@ -19,8 +19,10 @@
...
@@ -19,8 +19,10 @@
#include <cstddef>
#include <cstddef>
#include <cstdlib>
#include <cstdlib>
#include <limits>
#include <limits>
#include <memory>
#include <new>
#include <new>
#include "./securec.h"
#include "./securec.h"
#include "dataset/util/allocator.h"
#include "dataset/util/memory_pool.h"
#include "dataset/util/memory_pool.h"
namespace
mindspore
{
namespace
mindspore
{
...
@@ -61,6 +63,11 @@ class SystemPool : public MemoryPool {
...
@@ -61,6 +63,11 @@ class SystemPool : public MemoryPool {
uint64_t
get_max_size
()
const
override
{
return
std
::
numeric_limits
<
uint64_t
>::
max
();
}
uint64_t
get_max_size
()
const
override
{
return
std
::
numeric_limits
<
uint64_t
>::
max
();
}
int
PercentFree
()
const
override
{
return
100
;
}
int
PercentFree
()
const
override
{
return
100
;
}
/// \brief Convenience factory: an Allocator<T> backed by a fresh SystemPool,
/// i.e. plain heap allocation with no pooling limits.
template <typename T>
static Allocator<T> GetAllocator() {
  return Allocator<T>(std::make_shared<SystemPool>());
}
};
};
}
// namespace dataset
}
// namespace dataset
}
// namespace mindspore
}
// namespace mindspore
...
...
mindspore/ccsrc/device/kernel_runtime.cc
浏览文件 @
5a886794
...
@@ -30,6 +30,7 @@
...
@@ -30,6 +30,7 @@
#include "kernel/common_utils.h"
#include "kernel/common_utils.h"
#include "kernel/oplib/oplib.h"
#include "kernel/oplib/oplib.h"
#include "ir/value.h"
#include "ir/value.h"
#include "pre_activate/common/helper.h"
using
mindspore
::
kernel
::
Address
;
using
mindspore
::
kernel
::
Address
;
using
mindspore
::
kernel
::
AddressPtr
;
using
mindspore
::
kernel
::
AddressPtr
;
...
@@ -632,7 +633,7 @@ void KernelRuntime::AssignWorkSpaceMem(int flag, const AnfNodePtr &node) {
...
@@ -632,7 +633,7 @@ void KernelRuntime::AssignWorkSpaceMem(int flag, const AnfNodePtr &node) {
}
}
}
}
void
KernelRuntime
::
GenLaunchArgs
(
const
mindspore
::
kernel
::
KernelMod
&
kernel_mod
,
const
mindspore
::
AnfNodePtr
&
kernel
,
void
KernelRuntime
::
GenLaunchArgs
(
const
session
::
KernelGraph
&
graph
,
const
mindspore
::
AnfNodePtr
&
kernel
,
AddressPtrList
*
kernel_inputs
,
AddressPtrList
*
const
kernel_workspaces
,
AddressPtrList
*
kernel_inputs
,
AddressPtrList
*
const
kernel_workspaces
,
AddressPtrList
*
kernel_outputs
)
{
AddressPtrList
*
kernel_outputs
)
{
MS_EXCEPTION_IF_NULL
(
kernel
);
MS_EXCEPTION_IF_NULL
(
kernel
);
...
@@ -644,9 +645,15 @@ void KernelRuntime::GenLaunchArgs(const mindspore::kernel::KernelMod &kernel_mod
...
@@ -644,9 +645,15 @@ void KernelRuntime::GenLaunchArgs(const mindspore::kernel::KernelMod &kernel_mod
if
(
AnfAlgo
::
GetCNodeName
(
cnode
)
==
kAtomicAddrCleanOpName
)
{
if
(
AnfAlgo
::
GetCNodeName
(
cnode
)
==
kAtomicAddrCleanOpName
)
{
return
GenAddrCleanLaunchArgs
(
cnode
,
kernel_inputs
);
return
GenAddrCleanLaunchArgs
(
cnode
,
kernel_inputs
);
}
}
auto
is_all_nop_node
=
opt
::
IsAllNopNode
(
&
graph
);
for
(
size_t
i
=
0
;
i
<
AnfAlgo
::
GetInputTensorNum
(
kernel
);
++
i
)
{
for
(
size_t
i
=
0
;
i
<
AnfAlgo
::
GetInputTensorNum
(
kernel
);
++
i
)
{
auto
real_input
=
AnfAlgo
::
GetRealInputIndex
(
kernel
,
i
);
auto
real_input
=
AnfAlgo
::
GetRealInputIndex
(
kernel
,
i
);
auto
device_address
=
AnfAlgo
::
GetPrevNodeOutputAddr
(
kernel
,
real_input
);
DeviceAddressPtr
device_address
;
if
(
is_all_nop_node
)
{
device_address
=
AnfAlgo
::
GetPrevNodeMutableOutputAddr
(
kernel
,
real_input
,
false
);
}
else
{
device_address
=
AnfAlgo
::
GetPrevNodeMutableOutputAddr
(
kernel
,
real_input
,
true
);
}
MS_EXCEPTION_IF_NULL
(
device_address
);
MS_EXCEPTION_IF_NULL
(
device_address
);
kernel
::
AddressPtr
input
=
std
::
make_shared
<
kernel
::
Address
>
();
kernel
::
AddressPtr
input
=
std
::
make_shared
<
kernel
::
Address
>
();
MS_EXCEPTION_IF_NULL
(
input
);
MS_EXCEPTION_IF_NULL
(
input
);
...
@@ -656,8 +663,16 @@ void KernelRuntime::GenLaunchArgs(const mindspore::kernel::KernelMod &kernel_mod
...
@@ -656,8 +663,16 @@ void KernelRuntime::GenLaunchArgs(const mindspore::kernel::KernelMod &kernel_mod
kernel_inputs
->
emplace_back
(
input
);
kernel_inputs
->
emplace_back
(
input
);
}
}
for
(
size_t
i
=
0
;
i
<
kernel_mod
.
GetOutputSizeList
().
size
();
++
i
)
{
auto
kernel_mod
=
AnfAlgo
::
GetKernelMod
(
kernel
);
auto
device_address
=
AnfAlgo
::
GetOutputAddr
(
kernel
,
i
);
MS_EXCEPTION_IF_NULL
(
kernel_mod
);
for
(
size_t
i
=
0
;
i
<
kernel_mod
->
GetOutputSizeList
().
size
();
++
i
)
{
DeviceAddressPtr
device_address
;
if
(
is_all_nop_node
)
{
device_address
=
AnfAlgo
::
GetMutableOutputAddr
(
kernel
,
i
,
false
);
}
else
{
device_address
=
AnfAlgo
::
GetMutableOutputAddr
(
kernel
,
i
,
true
);
}
MS_EXCEPTION_IF_NULL
(
device_address
);
kernel
::
AddressPtr
output
=
std
::
make_shared
<
kernel
::
Address
>
();
kernel
::
AddressPtr
output
=
std
::
make_shared
<
kernel
::
Address
>
();
MS_EXCEPTION_IF_NULL
(
output
);
MS_EXCEPTION_IF_NULL
(
output
);
output
->
addr
=
device_address
->
ptr_
;
output
->
addr
=
device_address
->
ptr_
;
...
@@ -666,7 +681,7 @@ void KernelRuntime::GenLaunchArgs(const mindspore::kernel::KernelMod &kernel_mod
...
@@ -666,7 +681,7 @@ void KernelRuntime::GenLaunchArgs(const mindspore::kernel::KernelMod &kernel_mod
kernel_outputs
->
emplace_back
(
output
);
kernel_outputs
->
emplace_back
(
output
);
}
}
for
(
size_t
i
=
0
;
i
<
kernel_mod
.
GetWorkspaceSizeList
().
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
kernel_mod
->
GetWorkspaceSizeList
().
size
();
++
i
)
{
auto
device_address
=
AnfAlgo
::
GetWorkspaceAddr
(
kernel
,
i
);
auto
device_address
=
AnfAlgo
::
GetWorkspaceAddr
(
kernel
,
i
);
kernel
::
AddressPtr
workspace
=
std
::
make_shared
<
kernel
::
Address
>
();
kernel
::
AddressPtr
workspace
=
std
::
make_shared
<
kernel
::
Address
>
();
MS_EXCEPTION_IF_NULL
(
workspace
);
MS_EXCEPTION_IF_NULL
(
workspace
);
...
@@ -721,7 +736,7 @@ bool KernelRuntime::LaunchKernelMod(const session::KernelGraph &graph) {
...
@@ -721,7 +736,7 @@ bool KernelRuntime::LaunchKernelMod(const session::KernelGraph &graph) {
AddressPtrList
kernel_inputs
;
AddressPtrList
kernel_inputs
;
AddressPtrList
kernel_workspaces
;
AddressPtrList
kernel_workspaces
;
AddressPtrList
kernel_outputs
;
AddressPtrList
kernel_outputs
;
GenLaunchArgs
(
*
kernel_mod
,
kernel
,
&
kernel_inputs
,
&
kernel_workspaces
,
&
kernel_outputs
);
GenLaunchArgs
(
graph
,
kernel
,
&
kernel_inputs
,
&
kernel_workspaces
,
&
kernel_outputs
);
auto
ret
=
kernel_mod
->
Launch
(
kernel_inputs
,
kernel_workspaces
,
kernel_outputs
,
stream_
);
auto
ret
=
kernel_mod
->
Launch
(
kernel_inputs
,
kernel_workspaces
,
kernel_outputs
,
stream_
);
if
(
!
ret
)
{
if
(
!
ret
)
{
MS_LOG
(
ERROR
)
<<
"Launch kernel failed."
;
MS_LOG
(
ERROR
)
<<
"Launch kernel failed."
;
...
...
mindspore/ccsrc/device/kernel_runtime.h
浏览文件 @
5a886794
...
@@ -96,8 +96,8 @@ class KernelRuntime {
...
@@ -96,8 +96,8 @@ class KernelRuntime {
private:
private:
void
AssignStaticMemoryOutput
(
const
session
::
KernelGraph
*
graph
);
void
AssignStaticMemoryOutput
(
const
session
::
KernelGraph
*
graph
);
void
GenLaunchArgs
(
const
mindspore
::
kernel
::
KernelMod
&
kernel_mod
,
const
AnfNodePtr
&
kernel
,
void
GenLaunchArgs
(
const
session
::
KernelGraph
&
graph
,
const
AnfNodePtr
&
kernel
,
AddressPtrList
*
kernel_inputs
,
AddressPtrList
*
kernel_
inputs
,
AddressPtrList
*
kernel_
workspaces
,
AddressPtrList
*
kernel_outputs
);
AddressPtrList
*
kernel_workspaces
,
AddressPtrList
*
kernel_outputs
);
bool
LaunchKernelMod
(
const
session
::
KernelGraph
&
graph
);
bool
LaunchKernelMod
(
const
session
::
KernelGraph
&
graph
);
void
GenAddrCleanLaunchArgs
(
const
CNodePtr
&
cnode
,
AddressPtrList
*
kernel_inputs
);
void
GenAddrCleanLaunchArgs
(
const
CNodePtr
&
cnode
,
AddressPtrList
*
kernel_inputs
);
size_t
CountNodeDeviceMemorySize
(
const
AnfNodePtr
&
node
,
size_t
output_index
);
size_t
CountNodeDeviceMemorySize
(
const
AnfNodePtr
&
node
,
size_t
output_index
);
...
...
mindspore/ccsrc/ir/optimizer_caller.h
浏览文件 @
5a886794
...
@@ -17,13 +17,23 @@
...
@@ -17,13 +17,23 @@
#ifndef MINDSPORE_CCSRC_IR_OPTIMIZER_CALLER_H_
#ifndef MINDSPORE_CCSRC_IR_OPTIMIZER_CALLER_H_
#define MINDSPORE_CCSRC_IR_OPTIMIZER_CALLER_H_
#define MINDSPORE_CCSRC_IR_OPTIMIZER_CALLER_H_
#include <memory>
#include "ir/anf.h"
#include "ir/anf.h"
#include "optimizer/opt.h"
namespace
mindspore
{
namespace
mindspore
{
namespace
opt
{
class
Optimizer
;
using
OptimizerPtr
=
std
::
shared_ptr
<
Optimizer
>
;
using
OptimizerWeakPtr
=
std
::
weak_ptr
<
Optimizer
>
;
using
PredicateFuncType
=
std
::
function
<
bool
(
const
AnfNodePtr
&
)
>
;
}
// namespace opt
class
OptimizerCaller
{
class
OptimizerCaller
{
public:
public:
virtual
AnfNodePtr
operator
()(
const
opt
::
OptimizerPtr
&
,
const
AnfNodePtr
&
)
{
return
nullptr
;
}
virtual
AnfNodePtr
operator
()(
const
opt
::
OptimizerPtr
&
,
const
AnfNodePtr
&
)
{
return
nullptr
;
}
};
};
using
OptimizerCallerPtr
=
std
::
shared_ptr
<
OptimizerCaller
>
;
}
// namespace mindspore
}
// namespace mindspore
#endif // MINDSPORE_CCSRC_IR_OPTIMIZER_CALLER_H_
#endif // MINDSPORE_CCSRC_IR_OPTIMIZER_CALLER_H_
mindspore/ccsrc/kernel/kernel_query.cc
浏览文件 @
5a886794
...
@@ -23,6 +23,7 @@
...
@@ -23,6 +23,7 @@
#include "kernel/tbe/tbe_kernel_select/tbe_kernel_select.h"
#include "kernel/tbe/tbe_kernel_select/tbe_kernel_select.h"
#include "kernel/akg/akg_kernel_metadata.h"
#include "kernel/akg/akg_kernel_metadata.h"
#include "session/anf_runtime_algorithm.h"
#include "session/anf_runtime_algorithm.h"
#include "utils/context/ms_context.h"
namespace
mindspore
{
namespace
mindspore
{
namespace
kernel
{
namespace
kernel
{
...
@@ -97,6 +98,12 @@ void KernelQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel
...
@@ -97,6 +98,12 @@ void KernelQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel
std
::
string
op_name
=
AnfAlgo
::
GetCNodeName
(
kernel_node
);
std
::
string
op_name
=
AnfAlgo
::
GetCNodeName
(
kernel_node
);
auto
context_ptr
=
MsContext
::
GetInstance
();
MS_EXCEPTION_IF_NULL
(
context_ptr
);
if
(
context_ptr
->
enable_graph_kernel
()
&&
IsPrimitiveCNode
(
kernel_node
,
prim
::
kPrimBatchMatMul
))
{
kernel_type
=
KernelType
::
AKG_KERNEL
;
}
switch
(
kernel_type
)
{
switch
(
kernel_type
)
{
case
KernelType
::
AKG_KERNEL
:
case
KernelType
::
AKG_KERNEL
:
AkgMetadataInfo
(
kernel_node
,
kernel_info_list
);
AkgMetadataInfo
(
kernel_node
,
kernel_info_list
);
...
...
mindspore/ccsrc/optimizer/cse.cc
浏览文件 @
5a886794
...
@@ -89,15 +89,28 @@ bool CSE::BuildOrderGroupAndDoReplace(const FuncGraphManagerPtr manager) const {
...
@@ -89,15 +89,28 @@ bool CSE::BuildOrderGroupAndDoReplace(const FuncGraphManagerPtr manager) const {
return
changed
;
return
changed
;
}
}
// The op like print, summary, or the op do not has true output, and always as a depend node input.
static
bool
HasSideEffect
(
const
AnfNodePtr
&
node
)
{
auto
prim
=
GetCNodePrimitive
(
node
);
if
(
prim
==
nullptr
)
{
return
false
;
}
auto
side_effect_v
=
prim
->
GetAttr
(
GRAPH_FLAG_SIDE_EFFECT
);
if
(
side_effect_v
!=
nullptr
&&
side_effect_v
->
isa
<
BoolImm
>
())
{
return
GetValue
<
bool
>
(
side_effect_v
);
}
return
false
;
}
// If true do not merge the node.
bool
CSE
::
CheckRandomEffect
(
const
AnfNodePtr
&
main
,
const
AnfNodePtr
&
node
)
const
{
bool
CSE
::
CheckRandomEffect
(
const
AnfNodePtr
&
main
,
const
AnfNodePtr
&
node
)
const
{
bool
has_random_effect
=
false
;
bool
has_random_effect
=
false
;
auto
prim_main
=
GetCNodePrimitive
(
main
);
auto
prim_main
=
GetCNodePrimitive
(
main
);
auto
prim_node
=
GetCNodePrimitive
(
node
);
auto
prim_node
=
GetCNodePrimitive
(
node
);
if
(
prim_main
==
prim_node
)
{
// if has random effect, when generate by different op (not same object), do not merge.
return
false
;
}
if
(
prim_main
!=
nullptr
)
{
if
(
prim_main
!=
nullptr
)
{
if
(
prim_main
==
prim_node
)
{
return
false
;
}
auto
effect_val
=
prim_main
->
GetAttr
(
GRAPH_FLAG_RANDOM_EFFECT
);
auto
effect_val
=
prim_main
->
GetAttr
(
GRAPH_FLAG_RANDOM_EFFECT
);
if
(
effect_val
!=
nullptr
&&
effect_val
->
isa
<
BoolImm
>
())
{
if
(
effect_val
!=
nullptr
&&
effect_val
->
isa
<
BoolImm
>
())
{
has_random_effect
=
GetValue
<
bool
>
(
effect_val
);
has_random_effect
=
GetValue
<
bool
>
(
effect_val
);
...
@@ -106,45 +119,58 @@ bool CSE::CheckRandomEffect(const AnfNodePtr &main, const AnfNodePtr &node) cons
...
@@ -106,45 +119,58 @@ bool CSE::CheckRandomEffect(const AnfNodePtr &main, const AnfNodePtr &node) cons
return
has_random_effect
;
return
has_random_effect
;
}
}
bool
CSE
::
CheckReplace
(
const
AnfNodePtr
&
main
,
const
AnfNodePtr
&
node
)
const
{
bool
CSE
::
CheckReplace
(
const
AnfNodePtr
&
main
,
const
AnfNodePtr
&
node
,
bool
check_side_effect
)
const
{
MS_EXCEPTION_IF_NULL
(
main
);
MS_EXCEPTION_IF_NULL
(
main
);
MS_EXCEPTION_IF_NULL
(
node
);
MS_EXCEPTION_IF_NULL
(
node
);
bool
replace
=
false
;
if
(
main
->
isa
<
ValueNode
>
()
&&
node
->
isa
<
ValueNode
>
())
{
if
(
main
->
isa
<
ValueNode
>
()
&&
node
->
isa
<
ValueNode
>
())
{
auto
main_value
=
GetValueNode
(
main
);
auto
main_value
=
GetValueNode
(
main
);
auto
node_value
=
GetValueNode
(
node
);
auto
node_value
=
GetValueNode
(
node
);
re
place
=
(
AbsOf
(
main
)
==
AbsOf
(
node
))
&&
(
*
main_value
==
*
node_value
);
re
turn
(
AbsOf
(
main
)
==
AbsOf
(
node
))
&&
(
*
main_value
==
*
node_value
);
}
else
if
(
main
->
isa
<
CNode
>
()
&&
node
->
isa
<
CNode
>
())
{
}
else
if
(
main
->
isa
<
CNode
>
()
&&
node
->
isa
<
CNode
>
())
{
auto
c_main
=
main
->
cast
<
CNodePtr
>
();
auto
c_main
=
main
->
cast
<
CNodePtr
>
();
auto
c_node
=
node
->
cast
<
CNodePtr
>
();
auto
c_node
=
node
->
cast
<
CNodePtr
>
();
// When appsame is true, check if has side effect, do not merge.
if
(
check_side_effect
&&
HasSideEffect
(
main
))
{
return
false
;
}
const
auto
&
inp1
=
c_main
->
inputs
();
const
auto
&
inp1
=
c_main
->
inputs
();
const
auto
&
inp2
=
c_node
->
inputs
();
const
auto
&
inp2
=
c_node
->
inputs
();
if
(
inp1
.
size
()
==
inp2
.
size
())
{
if
(
inp1
.
size
()
!=
inp2
.
size
())
{
bool
appsame
=
true
;
return
false
;
for
(
size_t
j
=
0
;
j
<
inp1
.
size
();
j
++
)
{
}
MS_EXCEPTION_IF_NULL
(
inp1
[
j
]);
for
(
size_t
j
=
0
;
j
<
inp1
.
size
();
j
++
)
{
MS_EXCEPTION_IF_NULL
(
inp2
[
j
]);
auto
inp1_j
=
inp1
[
j
];
if
(
!
(
*
inp1
[
j
]
==
*
inp2
[
j
]))
{
auto
inp2_j
=
inp2
[
j
];
// Handle the case of two different Tensor, but with the same value
MS_EXCEPTION_IF_NULL
(
inp1_j
);
if
(
IsValueNode
<
tensor
::
Tensor
>
(
inp1
[
j
])
&&
IsValueNode
<
tensor
::
Tensor
>
(
inp2
[
j
]))
{
MS_EXCEPTION_IF_NULL
(
inp2_j
);
auto
tensor1
=
GetValueNode
<
tensor
::
TensorPtr
>
(
inp1
[
j
]);
if
(
!
(
*
inp1_j
==
*
inp2_j
))
{
auto
tensor2
=
GetValueNode
<
tensor
::
TensorPtr
>
(
inp2
[
j
]);
// Handle the case of two different Tensor, but with the same value
if
(
tensor1
->
ValueEqual
(
*
tensor2
))
{
if
(
IsValueNode
<
tensor
::
Tensor
>
(
inp1_j
)
&&
IsValueNode
<
tensor
::
Tensor
>
(
inp2_j
))
{
continue
;
auto
tensor1
=
GetValueNode
<
tensor
::
TensorPtr
>
(
inp1_j
);
}
auto
tensor2
=
GetValueNode
<
tensor
::
TensorPtr
>
(
inp2_j
);
if
(
tensor1
->
ValueEqual
(
*
tensor2
))
{
continue
;
}
}
else
if
(
HasSideEffect
(
inp1_j
)
&&
HasSideEffect
(
inp2_j
))
{
// When the same side effect node as another two nodes' inputs, we still merge the node.
// Because the node only can be the inputs of `depend`, when the `depend` is duplicated merge the depend the
// node.
if
(
CheckReplace
(
inp1_j
,
inp2_j
,
false
))
{
continue
;
}
}
appsame
=
false
;
break
;
}
}
return
false
;
}
}
if
(
CheckRandomEffect
(
c_main
,
c_node
))
{
appsame
=
false
;
}
replace
=
appsame
;
}
}
// When appsame is true, check if has random effect do not merge
if
(
CheckRandomEffect
(
c_main
,
c_node
))
{
return
false
;
}
return
true
;
}
}
return
replace
;
// a parameter node.
return
false
;
}
}
bool
CSE
::
DoReplace
(
const
FuncGraphManagerPtr
manager
,
const
std
::
vector
<
std
::
size_t
>
&
order_group
,
bool
CSE
::
DoReplace
(
const
FuncGraphManagerPtr
manager
,
const
std
::
vector
<
std
::
size_t
>
&
order_group
,
...
...
mindspore/ccsrc/optimizer/cse.h
浏览文件 @
5a886794
...
@@ -41,7 +41,7 @@ class CSE {
...
@@ -41,7 +41,7 @@ class CSE {
return
chg
&&
report_changes_
;
return
chg
&&
report_changes_
;
}
}
virtual
bool
CheckReplace
(
const
AnfNodePtr
&
main
,
const
AnfNodePtr
&
node
)
const
;
virtual
bool
CheckReplace
(
const
AnfNodePtr
&
main
,
const
AnfNodePtr
&
node
,
bool
check_side_effect
=
true
)
const
;
virtual
bool
CheckRandomEffect
(
const
AnfNodePtr
&
main
,
const
AnfNodePtr
&
node
)
const
;
virtual
bool
CheckRandomEffect
(
const
AnfNodePtr
&
main
,
const
AnfNodePtr
&
node
)
const
;
...
...
mindspore/ccsrc/optimizer/irpass.cc
浏览文件 @
5a886794
...
@@ -14,140 +14,154 @@
...
@@ -14,140 +14,154 @@
* limitations under the License.
* limitations under the License.
*/
*/
#include "optimizer/irpass.h"
#include <string>
#include <string>
#include "optimizer/irpass
/symbol_resolver
.h"
#include "optimizer/irpass.h"
#include "optimizer/irpass/arithmetic_simplify.h"
#include "optimizer/irpass/arithmetic_simplify.h"
#include "optimizer/irpass/special_op_eliminate.h"
#include "optimizer/irpass/item_tuple_eliminate.h"
#include "optimizer/irpass/env_item_eliminate.h"
#include "optimizer/irpass/tile_eliminate.h"
#include "optimizer/irpass/cast_eliminate.h"
#include "optimizer/irpass/reshape_eliminate.h"
#include "optimizer/irpass/transpose_eliminate.h"
#include "optimizer/irpass/reduce_eliminate.h"
#include "optimizer/irpass/partial_eliminate.h"
#include "optimizer/irpass/ref_eliminate.h"
#include "optimizer/irpass/merge_addn.h"
#include "optimizer/irpass/branch_culling.h"
#include "optimizer/irpass/branch_culling.h"
#include "optimizer/irpass/cast_eliminate.h"
#include "optimizer/irpass/convert.h"
#include "optimizer/irpass/env_item_eliminate.h"
#include "optimizer/irpass/grad_var_prepare.h"
#include "optimizer/irpass/gradient_eliminate.h"
#include "optimizer/irpass/gradient_eliminate.h"
#include "optimizer/irpass/minmax_grad.h"
#include "optimizer/irpass/inline.h"
#include "optimizer/irpass/inline.h"
#include "optimizer/irpass/convert.h"
#include "optimizer/irpass/specialize_transform.h"
#include "optimizer/irpass/incorporate_getitem.h"
#include "optimizer/irpass/incorporate_call.h"
#include "optimizer/irpass/incorporate_call.h"
#include "optimizer/irpass/
grad_var_prepare
.h"
#include "optimizer/irpass/
incorporate_getitem
.h"
#include "optimizer/irpass/
param_replac
e.h"
#include "optimizer/irpass/
item_tuple_eliminat
e.h"
#include "optimizer/irpass/mark_interface_fusion.h"
#include "optimizer/irpass/mark_interface_fusion.h"
#include "optimizer/irpass/merge_addn.h"
#include "optimizer/irpass/minmax_grad.h"
#include "optimizer/irpass/param_replace.h"
#include "optimizer/irpass/partial_eliminate.h"
#include "optimizer/irpass/reduce_eliminate.h"
#include "optimizer/irpass/ref_eliminate.h"
#include "optimizer/irpass/reshape_eliminate.h"
#include "optimizer/irpass/special_op_eliminate.h"
#include "optimizer/irpass/specialize_transform.h"
#include "optimizer/irpass/symbol_resolver.h"
#include "optimizer/irpass/tile_eliminate.h"
#include "optimizer/irpass/transpose_eliminate.h"
#include "optimizer/opt.h"
#include "optimizer/opt.h"
namespace
mindspore
{
namespace
mindspore
{
namespace
opt
{
namespace
opt
{
namespace
irpass
{
namespace
irpass
{
OptimizeIRPassLib
::
OptimizeIRPassLib
()
{
OptimizeIRPassLib
::
OptimizeIRPassLib
()
{
arithmetic_simplify_
=
MakeSubstitution
(
ArithmeticSimplify
(),
"arithmetic_simplify"
,
arithmetic_simplify_
=
MakeSubstitution
(
std
::
make_shared
<
ArithmeticSimplify
>
(),
"arithmetic_simplify"
,
{
prim
::
kPrimScalarAdd
,
prim
::
kPrimScalarMul
,
prim
::
kPrimTensorAdd
,
{
prim
::
kPrimScalarAdd
,
prim
::
kPrimScalarMul
,
prim
::
kPrimTensorAdd
,
prim
::
kPrimIdentity
,
prim
::
kPrimMomentum
,
prim
::
kPrimMul
,
prim
::
kPrimPow
});
prim
::
kPrimIdentity
,
prim
::
kPrimMomentum
,
prim
::
kPrimMul
,
prim
::
kPrimPow
});
arithmetic_simplify2_
=
MakeSubstitution
(
ArithmeticSimplify2
(),
"arithmetic_simplify2"
,
{
prim
::
kPrimMul
});
arithmetic_simplify2_
=
MakeSubstitution
(
std
::
make_shared
<
ArithmeticSimplify2
>
(),
"arithmetic_simplify2"
,
{
prim
::
kPrimMul
});
special_op_eliminate_
=
special_op_eliminate_
=
MakeSubstitution
(
SpecialOpEliminater
(),
"special_op_eliminate"
,
MakeSubstitution
(
std
::
make_shared
<
SpecialOpEliminater
>
(),
"special_op_eliminate"
,
{
prim
::
kPrimInsertGradientOf
,
prim
::
kPrimStopGradient
,
prim
::
kPrimHookBackward
,
{
prim
::
kPrimInsertGradientOf
,
prim
::
kPrimStopGradient
,
prim
::
kPrimHookBackward
,
prim
::
kPrimPrintShapeType
,
prim
::
kPrimGetRefKey
,
prim
::
kPrimMirror
,
prim
::
kPrimVirtualDiv
});
prim
::
kPrimPrintShapeType
,
prim
::
kPrimGetRefKey
,
prim
::
kPrimMirror
,
prim
::
kPrimVirtualDiv
});
zero_like_fill_zero_
=
MakeSubstitution
(
ZeroLikeFillZero
(),
"zero_like_fill_zero"
,
prim
::
kPrimZerosLike
);
zero_like_fill_zero_
=
adjust_all_reduce_mul_add_
=
MakeSubstitution
(
AdjustAllReduceMulAdd
(),
"adjust_all_reduce_mul_add"
,
prim
::
kPrimAddN
);
MakeSubstitution
(
std
::
make_shared
<
ZeroLikeFillZero
>
(),
"zero_like_fill_zero"
,
prim
::
kPrimZerosLike
);
adjust_all_reduce_mul_add_
=
MakeSubstitution
(
std
::
make_shared
<
AdjustAllReduceMulAdd
>
(),
"adjust_all_reduce_mul_add"
,
prim
::
kPrimAddN
);
// ops eliminate
// ops eliminate
item_tuple_eliminate_
=
item_tuple_eliminate_
=
MakeSubstitution
(
std
::
make_shared
<
ItemTupleEliminater
>
(),
"item_tuple_eliminate"
,
MakeSubstitution
(
ItemTupleEliminater
(),
"item_tuple_eliminate"
,
{
prim
::
kPrimTupleGetItem
,
prim
::
kPrimTupleSetItem
});
{
prim
::
kPrimTupleGetItem
,
prim
::
kPrimTupleSetItem
});
tile_eliminate_
=
MakeSubstitution
(
TileMultiplyByOne
(),
"tile_eliminate"
,
prim
::
kPrimTile
);
tile_eliminate_
=
MakeSubstitution
(
std
::
make_shared
<
TileMultiplyByOne
>
(),
"tile_eliminate"
,
prim
::
kPrimTile
);
cast_eliminate_
=
MakeSubstitution
(
CastEliminater
(),
"cast_eliminate"
,
prim
::
kPrimCast
);
cast_eliminate_
=
MakeSubstitution
(
std
::
make_shared
<
CastEliminater
>
(),
"cast_eliminate"
,
prim
::
kPrimCast
);
reshape_eliminate_
=
MakeSubstitution
(
ReshapeEliminater
(),
"reshape_eliminate"
,
prim
::
kPrimReshape
);
reshape_eliminate_
=
MakeSubstitution
(
std
::
make_shared
<
ReshapeEliminater
>
(),
"reshape_eliminate"
,
prim
::
kPrimReshape
);
transpose_eliminate_
=
MakeSubstitution
(
TransposeSameIOEliminater
(),
"transpose_eliminate"
,
prim
::
kPrimTranspose
);
transpose_eliminate_
=
MakeSubstitution
(
std
::
make_shared
<
TransposeSameIOEliminater
>
(),
"transpose_eliminate"
,
prim
::
kPrimTranspose
);
reduce_eliminate_
=
MakeSubstitution
(
reduce_eliminate_
=
MakeSubstitution
(
ReduceOneEliminater
(),
"reduce_eliminate"
,
std
::
make_shared
<
ReduceOneEliminater
>
(),
"reduce_eliminate"
,
{
prim
::
kPrimReduceMean
,
prim
::
kPrimReduceAll
,
prim
::
kPrimReduceSum
,
prim
::
kPrimReduceMax
,
prim
::
kPrimReduceMin
});
{
prim
::
kPrimReduceMean
,
prim
::
kPrimReduceAll
,
prim
::
kPrimReduceSum
,
prim
::
kPrimReduceMax
,
prim
::
kPrimReduceMin
});
partial_eliminate_
=
MakeSubstitution
(
PartialEliminater
(),
"partial_eliminate"
,
IsCNodeDup
);
partial_eliminate_
=
MakeSubstitution
(
std
::
make_shared
<
PartialEliminater
>
(),
"partial_eliminate"
,
IsCNodeDup
);
same_eliminate_
=
MakeSubstitution
(
SameEliminater
(),
"same_eliminate"
,
prim
::
kPrimSameTypeShape
);
same_eliminate_
=
MakeSubstitution
(
std
::
make_shared
<
SameEliminater
>
(),
"same_eliminate"
,
prim
::
kPrimSameTypeShape
);
check_bprop_eliminate_
=
MakeSubstitution
(
CheckBpropEliminater
(),
"check_bprop_eliminate"
,
prim
::
kPrimCheckBprop
);
check_bprop_eliminate_
=
reset_defer_inline_
=
MakeSubstitution
(
ResetDeferInline
(),
"reset_defer_inline"
,
IsValueNode
<
FuncGraph
>
);
MakeSubstitution
(
std
::
make_shared
<
CheckBpropEliminater
>
(),
"check_bprop_eliminate"
,
prim
::
kPrimCheckBprop
);
depend_value_elim_
=
MakeSubstitution
(
DependValueElim
(),
"depend_value_elim"
,
prim
::
kPrimDepend
);
reset_defer_inline_
=
MakeSubstitution
(
std
::
make_shared
<
ResetDeferInline
>
(),
"reset_defer_inline"
,
IsValueNode
<
FuncGraph
>
);
depend_value_elim_
=
MakeSubstitution
(
std
::
make_shared
<
DependValueElim
>
(),
"depend_value_elim"
,
prim
::
kPrimDepend
);
// Env Item Eliminate
// Env Item Eliminate
env_get_item_eliminate_
=
MakeSubstitution
(
EnvGetItemEliminater
(),
"env_get_item_eliminate"
,
prim
::
kPrimEnvGetItem
);
env_get_item_eliminate_
=
new_env_get_item_
=
MakeSubstitution
(
NewEnvGetItem
(),
"new_env_get_item"
,
prim
::
kPrimEnvGetItem
);
MakeSubstitution
(
std
::
make_shared
<
EnvGetItemEliminater
>
(),
"env_get_item_eliminate"
,
prim
::
kPrimEnvGetItem
);
new_env_get_item_
=
MakeSubstitution
(
std
::
make_shared
<
NewEnvGetItem
>
(),
"new_env_get_item"
,
prim
::
kPrimEnvGetItem
);
incorporate_env_getitem_
=
incorporate_env_getitem_
=
MakeSubstitution
(
IncorporateEnvGetitem
(),
"incorporate_env_get_item"
,
prim
::
kPrimEnvGetItem
);
MakeSubstitution
(
std
::
make_shared
<
IncorporateEnvGetitem
>
(),
"incorporate_env_get_item"
,
prim
::
kPrimEnvGetItem
);
incorporate_env_getitem_switch_
=
incorporate_env_getitem_switch_
=
MakeSubstitution
(
std
::
make_shared
<
IncorporateEnvGetitemSwitch
>
(),
MakeSubstitution
(
IncorporateEnvGetitemSwitch
(),
"incorporate_env_getitem_switch"
,
prim
::
kPrimEnvGetItem
);
"incorporate_env_getitem_switch"
,
prim
::
kPrimEnvGetItem
);
// Ref eliminate
// Ref eliminate
make_ref_eliminate_
=
MakeSubstitution
(
MakeRefEliminater
(),
"make_ref_eliminate"
,
prim
::
kPrimMakeRef
);
make_ref_eliminate_
=
get_ref_param_eliminate_
=
MakeSubstitution
(
GetRefParamEliminater
(),
"get_ref_param_eliminate"
,
MakeSubstitution
(
std
::
make_shared
<
MakeRefEliminater
>
(),
"make_ref_eliminate"
,
prim
::
kPrimMakeRef
);
get_ref_param_eliminate_
=
MakeSubstitution
(
std
::
make_shared
<
GetRefParamEliminater
>
(),
"get_ref_param_eliminate"
,
{
prim
::
kPrimGetRefValue
,
prim
::
kPrimGetRefOrigin
});
{
prim
::
kPrimGetRefValue
,
prim
::
kPrimGetRefOrigin
});
get_make_ref_eliminate_
=
MakeSubstitution
(
GetMakeRefEliminater
(),
"get_make_ref_eliminate"
,
get_make_ref_eliminate_
=
MakeSubstitution
(
std
::
make_shared
<
GetMakeRefEliminater
>
(),
"get_make_ref_eliminate"
,
{
prim
::
kPrimGetRefKey
,
prim
::
kPrimGetRefValue
,
prim
::
kPrimGetRefOrigin
});
{
prim
::
kPrimGetRefKey
,
prim
::
kPrimGetRefValue
,
prim
::
kPrimGetRefOrigin
});
replace_refkey_by_param_
=
replace_refkey_by_param_
=
MakeSubstitution
(
std
::
make_shared
<
ReplaceRefkeyByParam
>
(),
"replace_refkey_by_param"
,
MakeSubstitution
(
ReplaceRefkeyByParam
(),
"replace_refkey_by_param"
,
IsValueNode
<
RefKey
>
,
opt
::
FORCE_RENORM
);
IsValueNode
<
RefKey
>
,
opt
::
FORCE_RENORM
);
replace_old_param_
=
MakeSubstitution
(
ReplaceOldParam
(),
"replace_old_param"
,
IsParam
);
replace_old_param_
=
MakeSubstitution
(
std
::
make_shared
<
ReplaceOldParam
>
(),
"replace_old_param"
,
IsParam
);
// Gradient transforms
// Gradient transforms
expand_jprim_
=
MakeSubstitution
(
ExpandJPrim
(),
"expand_jprim"
,
prim
::
kPrimJ
);
expand_jprim_
=
MakeSubstitution
(
std
::
make_shared
<
ExpandJPrim
>
(),
"expand_jprim"
,
prim
::
kPrimJ
);
minmaximum_grad_
=
MakeSubstitution
(
MinMaximumGrad
(),
"minmaximum_grad"
,
prim
::
kPrimTupleGetItem
);
minmaximum_grad_
=
MakeSubstitution
(
std
::
make_shared
<
MinMaximumGrad
>
(),
"minmaximum_grad"
,
prim
::
kPrimTupleGetItem
);
// branch culling
// branch culling
switch_simplify_
=
MakeSubstitution
(
SwitchSimplify
(),
"switch_simplify"
,
prim
::
kPrimSwitch
);
switch_simplify_
=
MakeSubstitution
(
std
::
make_shared
<
SwitchSimplify
>
(),
"switch_simplify"
,
prim
::
kPrimSwitch
);
float_tuple_getitem_switch_
=
float_tuple_getitem_switch_
=
MakeSubstitution
(
std
::
make_shared
<
FloatTupleGetItemSwitch
>
(),
MakeSubstitution
(
FloatTupleGetItemSwitch
(),
"float_tuple_getitem_switch"
,
prim
::
kPrimTupleGetItem
);
"float_tuple_getitem_switch"
,
prim
::
kPrimTupleGetItem
);
float_env_getitem_switch_
=
float_env_getitem_switch_
=
MakeSubstitution
(
FloatEnvGetItemSwitch
(),
"float_env_getitem_switch"
,
prim
::
kPrimEnvGetItem
);
MakeSubstitution
(
std
::
make_shared
<
FloatEnvGetItemSwitch
>
(),
"float_env_getitem_switch"
,
prim
::
kPrimEnvGetItem
);
convert_switch_replacement_
=
MakeSubstitution
(
ConvertSwitchReplacement
(),
"convert_switch_replacement"
,
IsCNodeDup
);
convert_switch_replacement_
=
MakeSubstitution
(
std
::
make_shared
<
ConvertSwitchReplacement
>
(),
"convert_switch_replacement"
,
IsCNodeDup
);
// Addn
// Addn
merge_addn_
=
MakeSubstitution
(
MergeAddN
(),
"merge_addn"
,
prim
::
kPrimAddN
);
merge_addn_
=
MakeSubstitution
(
std
::
make_shared
<
MergeAddN
>
(),
"merge_addn"
,
prim
::
kPrimAddN
);
addn_zero_filter_
=
MakeSubstitution
(
AddNZeroFilter
(),
"addn_zero_filter"
,
prim
::
kPrimAddN
);
addn_zero_filter_
=
MakeSubstitution
(
std
::
make_shared
<
AddNZeroFilter
>
(),
"addn_zero_filter"
,
prim
::
kPrimAddN
);
// inline
// inline
inline_
=
MakeSubstitution
(
Inliner
(),
"inline"
,
IsCNodeGraph
);
inline_
=
MakeSubstitution
(
std
::
make_shared
<
Inliner
>
(),
"inline"
,
IsCNodeGraph
);
replace_applicator_
=
MakeSubstitution
(
ReplaceApplicator
(),
"replace_applicator"
,
IsValueNode
<
FuncGraph
>
);
replace_applicator_
=
specialize_transform_
=
MakeSubstitution
(
SpecializeOnGraphArguments
(),
"specialize_transform"
,
IsCNodeGraph
);
MakeSubstitution
(
std
::
make_shared
<
ReplaceApplicator
>
(),
"replace_applicator"
,
IsValueNode
<
FuncGraph
>
);
specialize_transform_
=
MakeSubstitution
(
std
::
make_shared
<
SpecializeOnGraphArguments
>
(),
"specialize_transform"
,
IsCNodeGraph
);
// Incorporation
// Incorporation
incorporate_getitem_set_
=
incorporate_getitem_set_
=
MakeSubstitution
(
IncorporateGetitemSet
(),
"incorporate_getitem_set"
,
prim
::
kPrimTupleGetItem
);
MakeSubstitution
(
std
::
make_shared
<
IncorporateGetitemSet
>
(),
"incorporate_getitem_set"
,
prim
::
kPrimTupleGetItem
);
incorporate_getitem_from_param_
=
incorporate_getitem_from_param_
=
MakeSubstitution
(
std
::
make_shared
<
IncorporateGetitemFromParam
>
(),
MakeSubstitution
(
IncorporateGetitemFromParam
(),
"incorporate_getitem_from_param"
,
IsCNodeGraphKernel
);
"incorporate_getitem_from_param"
,
IsCNodeGraphKernel
);
incorporate_call_
=
MakeSubstitution
(
IncorporateCall
(),
"incorporate_call"
,
IsCNodeDup
);
incorporate_call_
=
MakeSubstitution
(
std
::
make_shared
<
IncorporateCall
>
(),
"incorporate_call"
,
IsCNodeDup
);
incorporate_call_switch_
=
MakeSubstitution
(
IncorporateCallSwitch
(),
"incorporate_call_switch"
,
IsCNodeDup
);
incorporate_call_switch_
=
MakeSubstitution
(
std
::
make_shared
<
IncorporateCallSwitch
>
(),
"incorporate_call_switch"
,
IsCNodeDup
);
// Virtual Dataset
// Virtual Dataset
virtual_dataset_eliminate_
=
virtual_dataset_eliminate_
=
MakeSubstitution
(
std
::
make_shared
<
VirtualDatasetEliminater
>
(),
MakeSubstitution
(
VirtualDatasetEliminater
(),
"virtual_dataset_eliminate"
,
prim
::
kPrimVirtualDataset
);
"virtual_dataset_eliminate"
,
prim
::
kPrimVirtualDataset
);
// Convert
// Convert
print_tuple_wrapper_
=
MakeSubstitution
(
PrintTupleWrapper
(),
"print_tuple_wrapper"
,
prim
::
kPrimPrint
);
print_tuple_wrapper_
=
MakeSubstitution
(
std
::
make_shared
<
PrintTupleWrapper
>
(),
"print_tuple_wrapper"
,
prim
::
kPrimPrint
);
// Unused parameter eliminate
// Unused parameter eliminate
unused_parameter_eliminate_
=
unused_parameter_eliminate_
=
MakeSubstitution
(
UnusedParasEliminater
(),
"unused_parameter_eliminate"
,
IsCNodeGraphKernel
);
MakeSubstitution
(
std
::
make_shared
<
UnusedParasEliminater
>
(),
"unused_parameter_eliminate"
,
IsCNodeGraphKernel
);
unused_output_eliminate_
=
MakeSubstitution
(
UnusedOutputEliminater
(),
"unused_output_eliminate"
,
IsCNodeGraphKernel
);
unused_output_eliminate_
=
MakeSubstitution
(
std
::
make_shared
<
UnusedOutputEliminater
>
(),
"unused_output_eliminate"
,
IsCNodeGraphKernel
);
// AddN eliminate
// AddN eliminate
addn_eliminate_
=
MakeSubstitution
(
AddNEliminater
(),
"addn_eliminate"
,
IsCNodeGraphKernel
);
addn_eliminate_
=
MakeSubstitution
(
std
::
make_shared
<
AddNEliminater
>
(),
"addn_eliminate"
,
IsCNodeGraphKernel
);
// Mark interface fusion
// Mark interface fusion
mark_interface_fusion_
=
MakeSubstitution
(
MarkInterfaceFusion
(),
"mark_interface_fusion"
,
prim
::
kPrimSelect
);
mark_interface_fusion_
=
MakeSubstitution
(
std
::
make_shared
<
MarkInterfaceFusion
>
(),
"mark_interface_fusion"
,
prim
::
kPrimSelect
);
}
}
ResolveIRPassLib
::
ResolveIRPassLib
()
{
ResolveIRPassLib
::
ResolveIRPassLib
()
{
resolver_resolve_
=
MakeSubstitution
(
ResolverResolve
(),
"resolver_resolve"
,
prim
::
kPrimResolve
);
resolver_resolve_
=
MakeSubstitution
(
std
::
make_shared
<
ResolverResolve
>
(),
"resolver_resolve"
,
prim
::
kPrimResolve
);
resolver_getattr_
=
MakeSubstitution
(
ResolverGetattr
(),
"resolver_getattr"
,
prim
::
kPrimGetAttr
);
resolver_getattr_
=
MakeSubstitution
(
std
::
make_shared
<
ResolverGetattr
>
(),
"resolver_getattr"
,
prim
::
kPrimGetAttr
);
}
}
InferenceOptPrepareLib
::
InferenceOptPrepareLib
()
{
InferenceOptPrepareLib
::
InferenceOptPrepareLib
()
{
grad_var_prepare_
=
MakeSubstitution
(
GradVarPrepare
(),
"grad_var_prepare"
,
IsCNode
);
grad_var_prepare_
=
MakeSubstitution
(
std
::
make_shared
<
GradVarPrepare
>
(),
"grad_var_prepare"
,
IsCNode
);
}
}
}
// namespace irpass
}
// namespace irpass
}
// namespace opt
}
// namespace opt
...
...
mindspore/ccsrc/optimizer/irpass/arithmetic_simplify.h
浏览文件 @
5a886794
...
@@ -17,15 +17,16 @@
...
@@ -17,15 +17,16 @@
#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ARITHMETIC_SIMPLIFY_H_
#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ARITHMETIC_SIMPLIFY_H_
#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ARITHMETIC_SIMPLIFY_H_
#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ARITHMETIC_SIMPLIFY_H_
#include <vector>
#include <memory>
#include <algorithm>
#include <algorithm>
#include <memory>
#include <vector>
#include "optimizer/optimizer.h"
#include "ir/optimizer_caller.h"
#include "optimizer/irpass.h"
#include "optimizer/irpass/prim_eliminate.h"
#include "ir/visitor.h"
#include "ir/visitor.h"
#include "operator/ops.h"
#include "operator/ops.h"
#include "optimizer/irpass.h"
#include "optimizer/irpass/prim_eliminate.h"
#include "optimizer/optimizer.h"
namespace
mindspore
{
namespace
mindspore
{
namespace
opt
{
namespace
opt
{
...
@@ -739,17 +740,17 @@ class AdjustAllReduceMulAdd : public AnfVisitor {
...
@@ -739,17 +740,17 @@ class AdjustAllReduceMulAdd : public AnfVisitor {
FuncGraphPtr
all_reduce_fg_
{
nullptr
};
FuncGraphPtr
all_reduce_fg_
{
nullptr
};
};
};
class
ArithmeticSimplify
{
class
ArithmeticSimplify
:
public
OptimizerCaller
{
public:
public:
ArithmeticSimplify
()
ArithmeticSimplify
()
:
multiply_by_zero_or_one_
(),
:
multiply_by_zero_or_one_
(
std
::
make_shared
<
MultiplyByZeroOrOne
>
()
),
tensor_multiply_by_one_
(),
tensor_multiply_by_one_
(
std
::
make_shared
<
TensorMultiplyByOne
>
()
),
add_by_zero_
(),
add_by_zero_
(
std
::
make_shared
<
AddByZero
>
()
),
tensor_add_by_zero_
(),
tensor_add_by_zero_
(
std
::
make_shared
<
TensorAddByZero
>
()
),
identity_
(
prim
::
kPrimIdentity
),
identity_
(
std
::
make_shared
<
PrimEliminater
>
(
prim
::
kPrimIdentity
)
),
opt_update_zero_tensor_
(),
opt_update_zero_tensor_
(
std
::
make_shared
<
OptUpdateZeroTensor
>
()
),
constant_duplicate_mul_
(),
constant_duplicate_mul_
(
std
::
make_shared
<
ConstantDuplicateMul
>
()
),
power_one_
()
{
power_one_
(
std
::
make_shared
<
PowerOneEliminate
>
()
)
{
eliminaters_
.
emplace_back
(
multiply_by_zero_or_one_
);
eliminaters_
.
emplace_back
(
multiply_by_zero_or_one_
);
eliminaters_
.
emplace_back
(
tensor_multiply_by_one_
);
eliminaters_
.
emplace_back
(
tensor_multiply_by_one_
);
eliminaters_
.
emplace_back
(
add_by_zero_
);
eliminaters_
.
emplace_back
(
add_by_zero_
);
...
@@ -761,10 +762,10 @@ class ArithmeticSimplify {
...
@@ -761,10 +762,10 @@ class ArithmeticSimplify {
}
}
~
ArithmeticSimplify
()
=
default
;
~
ArithmeticSimplify
()
=
default
;
AnfNodePtr
operator
()(
const
OptimizerPtr
&
optimizer
,
const
AnfNodePtr
&
node
)
{
AnfNodePtr
operator
()(
const
OptimizerPtr
&
optimizer
,
const
AnfNodePtr
&
node
)
override
{
AnfNodePtr
new_node
;
AnfNodePtr
new_node
;
for
(
auto
&
eliminater
:
eliminaters_
)
{
for
(
auto
&
eliminater
:
eliminaters_
)
{
new_node
=
eliminater
(
optimizer
,
node
);
new_node
=
(
*
eliminater
)
(
optimizer
,
node
);
if
(
new_node
!=
nullptr
)
{
if
(
new_node
!=
nullptr
)
{
return
new_node
;
return
new_node
;
}
}
...
@@ -773,15 +774,9 @@ class ArithmeticSimplify {
...
@@ -773,15 +774,9 @@ class ArithmeticSimplify {
}
}
private:
private:
MultiplyByZeroOrOne
multiply_by_zero_or_one_
;
OptimizerCallerPtr
multiply_by_zero_or_one_
,
tensor_multiply_by_one_
,
add_by_zero_
,
tensor_add_by_zero_
,
identity_
,
TensorMultiplyByOne
tensor_multiply_by_one_
;
opt_update_zero_tensor_
,
constant_duplicate_mul_
,
power_one_
;
AddByZero
add_by_zero_
;
std
::
vector
<
OptimizerCallerPtr
>
eliminaters_
{};
TensorAddByZero
tensor_add_by_zero_
;
PrimEliminater
identity_
;
OptUpdateZeroTensor
opt_update_zero_tensor_
;
ConstantDuplicateMul
constant_duplicate_mul_
;
PowerOneEliminate
power_one_
;
std
::
vector
<
TransformFuncType
>
eliminaters_
{};
};
};
// Arithmetic Simplifications should be done after step_parallel.
// Arithmetic Simplifications should be done after step_parallel.
...
@@ -789,15 +784,17 @@ class ArithmeticSimplify {
...
@@ -789,15 +784,17 @@ class ArithmeticSimplify {
// with shape(weight), but after step_parallel, shape of weight may be changed, so the
// with shape(weight), but after step_parallel, shape of weight may be changed, so the
// shape of the constant tensor should also be changed. So this pass is seperated from
// shape of the constant tensor should also be changed. So this pass is seperated from
// ArithmeticSimplify and deferred until step_parallel.
// ArithmeticSimplify and deferred until step_parallel.
class
ArithmeticSimplify2
{
class
ArithmeticSimplify2
:
public
OptimizerCaller
{
public:
public:
ArithmeticSimplify2
()
:
tensor_multiply_by_zero_
()
{
eliminaters_
.
emplace_back
(
tensor_multiply_by_zero_
);
}
ArithmeticSimplify2
()
:
tensor_multiply_by_zero_
(
std
::
make_shared
<
TensorMultiplyByZero
>
())
{
eliminaters_
.
emplace_back
(
tensor_multiply_by_zero_
);
}
~
ArithmeticSimplify2
()
=
default
;
~
ArithmeticSimplify2
()
=
default
;
AnfNodePtr
operator
()(
const
OptimizerPtr
&
optimizer
,
const
AnfNodePtr
&
node
)
{
AnfNodePtr
operator
()(
const
OptimizerPtr
&
optimizer
,
const
AnfNodePtr
&
node
)
override
{
AnfNodePtr
new_node
;
AnfNodePtr
new_node
;
for
(
auto
&
eliminater
:
eliminaters_
)
{
for
(
auto
&
eliminater
:
eliminaters_
)
{
new_node
=
eliminater
(
optimizer
,
node
);
new_node
=
(
*
eliminater
)
(
optimizer
,
node
);
if
(
new_node
!=
nullptr
)
{
if
(
new_node
!=
nullptr
)
{
return
new_node
;
return
new_node
;
}
}
...
@@ -806,8 +803,8 @@ class ArithmeticSimplify2 {
...
@@ -806,8 +803,8 @@ class ArithmeticSimplify2 {
}
}
private:
private:
TensorMultiplyByZero
tensor_multiply_by_zero_
;
OptimizerCallerPtr
tensor_multiply_by_zero_
;
std
::
vector
<
TransformFuncType
>
eliminaters_
{};
std
::
vector
<
OptimizerCallerPtr
>
eliminaters_
{};
};
};
}
// namespace irpass
}
// namespace irpass
}
// namespace opt
}
// namespace opt
...
...
mindspore/ccsrc/optimizer/irpass/cast_eliminate.h
浏览文件 @
5a886794
...
@@ -17,9 +17,9 @@
...
@@ -17,9 +17,9 @@
#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_CAST_ELIMINATE_H_
#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_CAST_ELIMINATE_H_
#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_CAST_ELIMINATE_H_
#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_CAST_ELIMINATE_H_
#include "ir/visitor.h"
#include "optimizer/irpass.h"
#include "optimizer/irpass.h"
#include "optimizer/optimizer.h"
#include "optimizer/optimizer.h"
#include "ir/visitor.h"
namespace
mindspore
{
namespace
mindspore
{
namespace
opt
{
namespace
opt
{
...
@@ -52,12 +52,12 @@ class TwoCastEliminater : public AnfVisitor {
...
@@ -52,12 +52,12 @@ class TwoCastEliminater : public AnfVisitor {
AnfNodePtr
x_
{
nullptr
},
t_
{
nullptr
};
AnfNodePtr
x_
{
nullptr
},
t_
{
nullptr
};
};
};
class
CastEliminater
{
class
CastEliminater
:
public
OptimizerCaller
{
public:
public:
CastEliminater
()
:
cast_same_type_eliminater_
(),
two_cast_eliminater_
()
{}
CastEliminater
()
:
cast_same_type_eliminater_
(),
two_cast_eliminater_
()
{}
~
CastEliminater
()
=
default
;
~
CastEliminater
()
=
default
;
AnfNodePtr
operator
()(
const
OptimizerPtr
&
optimizer
,
const
AnfNodePtr
&
node
)
{
AnfNodePtr
operator
()(
const
OptimizerPtr
&
optimizer
,
const
AnfNodePtr
&
node
)
override
{
auto
new_node
=
cast_same_type_eliminater_
(
optimizer
,
node
);
auto
new_node
=
cast_same_type_eliminater_
(
optimizer
,
node
);
if
(
new_node
!=
nullptr
)
{
if
(
new_node
!=
nullptr
)
{
return
new_node
;
return
new_node
;
...
...
mindspore/ccsrc/optimizer/irpass/env_item_eliminate.h
浏览文件 @
5a886794
...
@@ -17,18 +17,19 @@
...
@@ -17,18 +17,19 @@
#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ENV_ITEM_ELIMINATE_H_
#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ENV_ITEM_ELIMINATE_H_
#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ENV_ITEM_ELIMINATE_H_
#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ENV_ITEM_ELIMINATE_H_
#include <vector>
#include <utility>
#include <algorithm>
#include <algorithm>
#include <unordered_map>
#include <memory>
#include <memory>
#include <unordered_map>
#include <utility>
#include <vector>
#include "optimizer/irpass.h"
#include "optimizer/optimizer.h"
#include "ir/visitor.h"
#include "ir/func_graph.h"
#include "ir/func_graph.h"
#include "ir/func_graph_cloner.h"
#include "ir/func_graph_cloner.h"
#include "ir/optimizer_caller.h"
#include "ir/visitor.h"
#include "operator/ops.h"
#include "operator/ops.h"
#include "optimizer/irpass.h"
#include "optimizer/optimizer.h"
#include "utils/symbolic.h"
#include "utils/symbolic.h"
namespace
mindspore
{
namespace
mindspore
{
...
@@ -225,19 +226,22 @@ class EnvGetSetItem : public AnfVisitor {
...
@@ -225,19 +226,22 @@ class EnvGetSetItem : public AnfVisitor {
bool
is_match_
{
false
};
bool
is_match_
{
false
};
};
};
class
EnvGetItemEliminater
{
class
EnvGetItemEliminater
:
public
OptimizerCaller
{
public:
public:
EnvGetItemEliminater
()
:
new_env_get_item_
(),
add_env_get_item_
(),
env_get_set_item_
()
{
EnvGetItemEliminater
()
:
new_env_get_item_
(
std
::
make_shared
<
NewEnvGetItem
>
()),
add_env_get_item_
(
std
::
make_shared
<
AddEnvGetItem
>
()),
env_get_set_item_
(
std
::
make_shared
<
EnvGetSetItem
>
())
{
eliminaters_
.
emplace_back
(
new_env_get_item_
);
eliminaters_
.
emplace_back
(
new_env_get_item_
);
eliminaters_
.
emplace_back
(
add_env_get_item_
);
eliminaters_
.
emplace_back
(
add_env_get_item_
);
eliminaters_
.
emplace_back
(
env_get_set_item_
);
eliminaters_
.
emplace_back
(
env_get_set_item_
);
}
}
~
EnvGetItemEliminater
()
=
default
;
~
EnvGetItemEliminater
()
=
default
;
AnfNodePtr
operator
()(
const
OptimizerPtr
&
optimizer
,
const
AnfNodePtr
&
node
)
{
AnfNodePtr
operator
()(
const
OptimizerPtr
&
optimizer
,
const
AnfNodePtr
&
node
)
override
{
AnfNodePtr
new_node
;
AnfNodePtr
new_node
;
for
(
auto
&
eliminater
:
eliminaters_
)
{
for
(
auto
&
eliminater
:
eliminaters_
)
{
new_node
=
eliminater
(
optimizer
,
node
);
new_node
=
(
*
eliminater
)
(
optimizer
,
node
);
if
(
new_node
!=
nullptr
)
{
if
(
new_node
!=
nullptr
)
{
return
new_node
;
return
new_node
;
}
}
...
@@ -246,10 +250,8 @@ class EnvGetItemEliminater {
...
@@ -246,10 +250,8 @@ class EnvGetItemEliminater {
}
}
private:
private:
NewEnvGetItem
new_env_get_item_
;
OptimizerCallerPtr
new_env_get_item_
,
add_env_get_item_
,
env_get_set_item_
;
AddEnvGetItem
add_env_get_item_
;
std
::
vector
<
OptimizerCallerPtr
>
eliminaters_
{};
EnvGetSetItem
env_get_set_item_
;
std
::
vector
<
TransformFuncType
>
eliminaters_
{};
};
};
// {prim::kPrimEnvGetItem, {G, Xs}, C, Y}
// {prim::kPrimEnvGetItem, {G, Xs}, C, Y}
...
...
mindspore/ccsrc/optimizer/irpass/incorporate_getitem.h
浏览文件 @
5a886794
...
@@ -17,18 +17,20 @@
...
@@ -17,18 +17,20 @@
#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_INCORPORATE_GETITEM_H_
#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_INCORPORATE_GETITEM_H_
#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_INCORPORATE_GETITEM_H_
#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_INCORPORATE_GETITEM_H_
#include <vector>
#include <algorithm>
#include <algorithm>
#include <unordered_map>
#include <memory>
#include <memory>
#include <unordered_map>
#include <unordered_set>
#include <unordered_set>
#include <vector>
#include "optimizer/irpass.h"
#include "optimizer/optimizer.h"
#include "ir/visitor.h"
#include "ir/func_graph.h"
#include "ir/func_graph.h"
#include "ir/func_graph_cloner.h"
#include "ir/func_graph_cloner.h"
#include "ir/optimizer_caller.h"
#include "ir/visitor.h"
#include "operator/ops.h"
#include "operator/ops.h"
#include "optimizer/irpass.h"
#include "optimizer/optimizer.h"
namespace
mindspore
{
namespace
mindspore
{
namespace
opt
{
namespace
opt
{
namespace
irpass
{
namespace
irpass
{
...
@@ -383,18 +385,20 @@ class IncorporateGetitemSwitch : public AnfVisitor {
...
@@ -383,18 +385,20 @@ class IncorporateGetitemSwitch : public AnfVisitor {
internal
::
GetitemTransform
getitem_transform_
;
internal
::
GetitemTransform
getitem_transform_
;
};
};
class
IncorporateGetitemSet
{
class
IncorporateGetitemSet
:
public
OptimizerCaller
{
public:
public:
IncorporateGetitemSet
()
:
incorporate_getitem_
(),
incorporate_getitem_switch_
()
{
IncorporateGetitemSet
()
:
incorporate_getitem_
(
std
::
make_shared
<
IncorporateGetitem
>
()),
incorporate_getitem_switch_
(
std
::
make_shared
<
IncorporateGetitemSwitch
>
())
{
eliminaters_
.
emplace_back
(
incorporate_getitem_
);
eliminaters_
.
emplace_back
(
incorporate_getitem_
);
eliminaters_
.
emplace_back
(
incorporate_getitem_switch_
);
eliminaters_
.
emplace_back
(
incorporate_getitem_switch_
);
}
}
~
IncorporateGetitemSet
()
=
default
;
~
IncorporateGetitemSet
()
=
default
;
AnfNodePtr
operator
()(
const
OptimizerPtr
&
optimizer
,
const
AnfNodePtr
&
node
)
{
AnfNodePtr
operator
()(
const
OptimizerPtr
&
optimizer
,
const
AnfNodePtr
&
node
)
override
{
AnfNodePtr
new_node
;
AnfNodePtr
new_node
;
for
(
auto
&
eliminater
:
eliminaters_
)
{
for
(
auto
&
eliminater
:
eliminaters_
)
{
new_node
=
eliminater
(
optimizer
,
node
);
new_node
=
(
*
eliminater
)
(
optimizer
,
node
);
if
(
new_node
!=
nullptr
)
{
if
(
new_node
!=
nullptr
)
{
return
new_node
;
return
new_node
;
}
}
...
@@ -403,9 +407,8 @@ class IncorporateGetitemSet {
...
@@ -403,9 +407,8 @@ class IncorporateGetitemSet {
}
}
private:
private:
IncorporateGetitem
incorporate_getitem_
;
OptimizerCallerPtr
incorporate_getitem_
,
incorporate_getitem_switch_
;
IncorporateGetitemSwitch
incorporate_getitem_switch_
;
std
::
vector
<
OptimizerCallerPtr
>
eliminaters_
{};
std
::
vector
<
TransformFuncType
>
eliminaters_
{};
};
};
}
// namespace irpass
}
// namespace irpass
}
// namespace opt
}
// namespace opt
...
...
mindspore/ccsrc/optimizer/irpass/item_tuple_eliminate.h
浏览文件 @
5a886794
...
@@ -17,13 +17,15 @@
...
@@ -17,13 +17,15 @@
#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ITEM_TUPLE_ELIMINATE_H_
#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ITEM_TUPLE_ELIMINATE_H_
#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ITEM_TUPLE_ELIMINATE_H_
#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ITEM_TUPLE_ELIMINATE_H_
#include <vector>
#include <algorithm>
#include <algorithm>
#include <memory>
#include <vector>
#include "optimizer/irpass.h"
#include "ir/optimizer_caller.h"
#include "optimizer/optimizer.h"
#include "ir/visitor.h"
#include "ir/visitor.h"
#include "operator/ops.h"
#include "operator/ops.h"
#include "optimizer/irpass.h"
#include "optimizer/optimizer.h"
namespace
mindspore
{
namespace
mindspore
{
namespace
opt
{
namespace
opt
{
...
@@ -261,14 +263,14 @@ class GetitemDependReorder : public AnfVisitor {
...
@@ -261,14 +263,14 @@ class GetitemDependReorder : public AnfVisitor {
AnfNodePtr
x_
{
nullptr
},
y_
{
nullptr
},
c_
{
nullptr
};
AnfNodePtr
x_
{
nullptr
},
y_
{
nullptr
},
c_
{
nullptr
};
};
};
class
ItemTupleEliminater
{
class
ItemTupleEliminater
:
public
OptimizerCaller
{
public:
public:
ItemTupleEliminater
()
ItemTupleEliminater
()
:
get_item_eliminater_
(),
:
get_item_eliminater_
(
std
::
make_shared
<
GetitemEliminater
>
()
),
get_item_const_eliminater_
(),
get_item_const_eliminater_
(
std
::
make_shared
<
GetitemConstEliminater
>
()
),
set_item_eliminater_
(),
set_item_eliminater_
(
std
::
make_shared
<
SetitemEliminater
>
()
),
get_set_item_eliminater_
(),
get_set_item_eliminater_
(
std
::
make_shared
<
GetSetitemEliminater
>
()
),
get_item_depend_reorder_
()
{
get_item_depend_reorder_
(
std
::
make_shared
<
GetitemDependReorder
>
()
)
{
eliminaters_
.
emplace_back
(
get_item_eliminater_
);
eliminaters_
.
emplace_back
(
get_item_eliminater_
);
eliminaters_
.
emplace_back
(
get_item_const_eliminater_
);
eliminaters_
.
emplace_back
(
get_item_const_eliminater_
);
eliminaters_
.
emplace_back
(
set_item_eliminater_
);
eliminaters_
.
emplace_back
(
set_item_eliminater_
);
...
@@ -277,10 +279,10 @@ class ItemTupleEliminater {
...
@@ -277,10 +279,10 @@ class ItemTupleEliminater {
}
}
~
ItemTupleEliminater
()
=
default
;
~
ItemTupleEliminater
()
=
default
;
AnfNodePtr
operator
()(
const
OptimizerPtr
&
optimizer
,
const
AnfNodePtr
&
node
)
{
AnfNodePtr
operator
()(
const
OptimizerPtr
&
optimizer
,
const
AnfNodePtr
&
node
)
override
{
AnfNodePtr
new_node
;
AnfNodePtr
new_node
;
for
(
auto
&
eliminater
:
eliminaters_
)
{
for
(
auto
&
eliminater
:
eliminaters_
)
{
new_node
=
eliminater
(
optimizer
,
node
);
new_node
=
(
*
eliminater
)
(
optimizer
,
node
);
if
(
new_node
!=
nullptr
)
{
if
(
new_node
!=
nullptr
)
{
return
new_node
;
return
new_node
;
}
}
...
@@ -289,12 +291,9 @@ class ItemTupleEliminater {
...
@@ -289,12 +291,9 @@ class ItemTupleEliminater {
}
}
private:
private:
GetitemEliminater
get_item_eliminater_
;
OptimizerCallerPtr
get_item_eliminater_
,
get_item_const_eliminater_
,
set_item_eliminater_
,
get_set_item_eliminater_
,
GetitemConstEliminater
get_item_const_eliminater_
;
get_item_depend_reorder_
;
SetitemEliminater
set_item_eliminater_
;
std
::
vector
<
OptimizerCallerPtr
>
eliminaters_
{};
GetSetitemEliminater
get_set_item_eliminater_
;
GetitemDependReorder
get_item_depend_reorder_
;
std
::
vector
<
TransformFuncType
>
eliminaters_
{};
};
};
}
// namespace irpass
}
// namespace irpass
}
// namespace opt
}
// namespace opt
...
...
mindspore/ccsrc/optimizer/irpass/ref_eliminate.h
浏览文件 @
5a886794
...
@@ -19,9 +19,9 @@
...
@@ -19,9 +19,9 @@
#include <memory>
#include <memory>
#include "optimizer/optimizer.h"
#include "optimizer/irpass.h"
#include "ir/pattern_matcher.h"
#include "ir/pattern_matcher.h"
#include "optimizer/irpass.h"
#include "optimizer/optimizer.h"
namespace
mindspore
{
namespace
mindspore
{
namespace
opt
{
namespace
opt
{
...
...
mindspore/ccsrc/optimizer/irpass/reshape_eliminate.h
浏览文件 @
5a886794
...
@@ -19,11 +19,12 @@
...
@@ -19,11 +19,12 @@
#include <vector>
#include <vector>
#include "optimizer/irpass.h"
#include "optimizer/optimizer.h"
#include "ir/visitor.h"
#include "ir/func_graph.h"
#include "ir/func_graph.h"
#include "ir/optimizer_caller.h"
#include "ir/visitor.h"
#include "operator/ops.h"
#include "operator/ops.h"
#include "optimizer/irpass.h"
#include "optimizer/optimizer.h"
#include "pipeline/static_analysis/dshape.h"
#include "pipeline/static_analysis/dshape.h"
namespace
mindspore
{
namespace
mindspore
{
...
@@ -124,12 +125,12 @@ class TwoReshapeEliminater : public AnfVisitor {
...
@@ -124,12 +125,12 @@ class TwoReshapeEliminater : public AnfVisitor {
AnfNodePtr
x_
{
nullptr
},
shape_
{
nullptr
};
AnfNodePtr
x_
{
nullptr
},
shape_
{
nullptr
};
};
};
class
ReshapeEliminater
{
class
ReshapeEliminater
:
public
OptimizerCaller
{
public:
public:
ReshapeEliminater
()
:
reshape_same_shape_eliminater_
(),
two_reshape_eliminater_
()
{}
ReshapeEliminater
()
:
reshape_same_shape_eliminater_
(),
two_reshape_eliminater_
()
{}
~
ReshapeEliminater
()
=
default
;
~
ReshapeEliminater
()
=
default
;
AnfNodePtr
operator
()(
const
OptimizerPtr
&
optimizer
,
const
AnfNodePtr
&
node
)
{
AnfNodePtr
operator
()(
const
OptimizerPtr
&
optimizer
,
const
AnfNodePtr
&
node
)
override
{
auto
new_node
=
reshape_same_shape_eliminater_
(
optimizer
,
node
);
auto
new_node
=
reshape_same_shape_eliminater_
(
optimizer
,
node
);
if
(
new_node
!=
nullptr
)
{
if
(
new_node
!=
nullptr
)
{
return
new_node
;
return
new_node
;
...
...
mindspore/ccsrc/optimizer/irpass/special_op_eliminate.h
浏览文件 @
5a886794
...
@@ -18,31 +18,31 @@
...
@@ -18,31 +18,31 @@
#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_SPECIAL_OP_ELIMINATE_H_
#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_SPECIAL_OP_ELIMINATE_H_
#include <securec.h>
#include <securec.h>
#include <vector>
#include <memory>
#include <algorithm>
#include <algorithm>
#include <memory>
#include <vector>
#include "optimizer/optimizer.h"
#include "optimizer/irpass.h"
#include "ir/optimizer_caller.h"
#include "ir/optimizer_caller.h"
#include "
optimizer/irpass/prim_eliminate
.h"
#include "
ir/pattern_matcher
.h"
#include "ir/visitor.h"
#include "ir/visitor.h"
#include "operator/ops.h"
#include "operator/ops.h"
#include "ir/pattern_matcher.h"
#include "optimizer/irpass.h"
#include "optimizer/irpass/prim_eliminate.h"
#include "optimizer/optimizer.h"
namespace
mindspore
{
namespace
mindspore
{
namespace
opt
{
namespace
opt
{
namespace
irpass
{
namespace
irpass
{
class
SpecialOpEliminater
{
class
SpecialOpEliminater
:
public
OptimizerCaller
{
public:
public:
SpecialOpEliminater
()
SpecialOpEliminater
()
:
insert_gradient_of_
(
prim
::
kPrimInsertGradientOf
),
:
insert_gradient_of_
(
std
::
make_shared
<
PrimEliminater
>
(
prim
::
kPrimInsertGradientOf
)
),
stop_gradient_
(
prim
::
kPrimStopGradient
),
stop_gradient_
(
std
::
make_shared
<
PrimEliminater
>
(
prim
::
kPrimStopGradient
)
),
hook_backward_
(
prim
::
kPrimHookBackward
),
hook_backward_
(
std
::
make_shared
<
PrimEliminater
>
(
prim
::
kPrimHookBackward
)
),
print_shape_type_
(
prim
::
kPrimPrintShapeType
),
print_shape_type_
(
std
::
make_shared
<
PrimEliminater
>
(
prim
::
kPrimPrintShapeType
)
),
get_ref_value_
(
prim
::
kPrimGetRefValue
),
get_ref_value_
(
std
::
make_shared
<
PrimEliminater
>
(
prim
::
kPrimGetRefValue
)
),
mirror_
(
prim
::
kPrimMirror
),
mirror_
(
std
::
make_shared
<
PrimEliminater
>
(
prim
::
kPrimMirror
)
),
virtual_div_
(
prim
::
kPrimVirtualDiv
)
{
virtual_div_
(
std
::
make_shared
<
PrimEliminater
>
(
prim
::
kPrimVirtualDiv
)
)
{
eliminaters_
.
emplace_back
(
insert_gradient_of_
);
eliminaters_
.
emplace_back
(
insert_gradient_of_
);
eliminaters_
.
emplace_back
(
stop_gradient_
);
eliminaters_
.
emplace_back
(
stop_gradient_
);
eliminaters_
.
emplace_back
(
hook_backward_
);
eliminaters_
.
emplace_back
(
hook_backward_
);
...
@@ -53,10 +53,10 @@ class SpecialOpEliminater {
...
@@ -53,10 +53,10 @@ class SpecialOpEliminater {
}
}
~
SpecialOpEliminater
()
=
default
;
~
SpecialOpEliminater
()
=
default
;
AnfNodePtr
operator
()(
const
OptimizerPtr
&
optimizer
,
const
AnfNodePtr
&
node
)
{
AnfNodePtr
operator
()(
const
OptimizerPtr
&
optimizer
,
const
AnfNodePtr
&
node
)
override
{
AnfNodePtr
new_node
;
AnfNodePtr
new_node
;
for
(
auto
&
eliminater
:
eliminaters_
)
{
for
(
auto
&
eliminater
:
eliminaters_
)
{
new_node
=
eliminater
(
optimizer
,
node
);
new_node
=
(
*
eliminater
)
(
optimizer
,
node
);
if
(
new_node
!=
nullptr
)
{
if
(
new_node
!=
nullptr
)
{
return
new_node
;
return
new_node
;
}
}
...
@@ -65,9 +65,9 @@ class SpecialOpEliminater {
...
@@ -65,9 +65,9 @@ class SpecialOpEliminater {
}
}
private:
private:
PrimEliminate
r
insert_gradient_of_
,
stop_gradient_
,
hook_backward_
,
print_shape_type_
,
get_ref_value_
,
mirror_
,
OptimizerCallerPt
r
insert_gradient_of_
,
stop_gradient_
,
hook_backward_
,
print_shape_type_
,
get_ref_value_
,
mirror_
,
virtual_div_
;
virtual_div_
;
std
::
vector
<
TransformFuncType
>
eliminaters_
{};
std
::
vector
<
OptimizerCallerPtr
>
eliminaters_
{};
};
};
// {PrimVirtualDataset, X} -> X
// {PrimVirtualDataset, X} -> X
...
...
mindspore/ccsrc/optimizer/opt.cc
浏览文件 @
5a886794
...
@@ -16,28 +16,27 @@
...
@@ -16,28 +16,27 @@
#include "optimizer/opt.h"
#include "optimizer/opt.h"
#include <algorithm>
#include <deque>
#include <memory>
#include <memory>
#include <unordered_set>
#include <unordered_set>
#include <deque>
#include <algorithm>
#include "ir/anf.h"
#include "ir/anf.h"
#include "ir/manager.h"
#include "ir/manager.h"
#include "utils/ordered_set.h"
#include "utils/log_adapter.h"
#include "optimizer/optimizer.h"
#include "optimizer/optimizer.h"
#include "utils/log_adapter.h"
#include "utils/ordered_set.h"
namespace
mindspore
{
namespace
mindspore
{
/* namespace to support opt */
/* namespace to support opt */
namespace
opt
{
namespace
opt
{
SubstitutionPtr
MakeSubstitution
(
const
TransformFuncType
&
transform
,
const
std
::
string
&
name
,
const
PrimitivePtr
&
prim
,
SubstitutionPtr
MakeSubstitution
(
const
OptimizerCallerPtr
&
transform
,
const
std
::
string
&
name
,
const
PrimitivePtr
&
prim
,
const
RenormAction
&
renorm_action
)
{
const
RenormAction
&
renorm_action
)
{
auto
fn
=
[
prim
](
const
AnfNodePtr
&
node
)
->
bool
{
return
IsPrimitiveCNode
(
node
,
prim
);
};
auto
fn
=
[
prim
](
const
AnfNodePtr
&
node
)
->
bool
{
return
IsPrimitiveCNode
(
node
,
prim
);
};
return
std
::
make_shared
<
Substitution
>
(
transform
,
name
,
fn
,
renorm_action
);
return
std
::
make_shared
<
Substitution
>
(
transform
,
name
,
fn
,
renorm_action
);
}
}
SubstitutionPtr
MakeSubstitution
(
const
TransformFuncType
&
transform
,
const
std
::
string
&
name
,
SubstitutionPtr
MakeSubstitution
(
const
OptimizerCallerPtr
&
transform
,
const
std
::
string
&
name
,
const
std
::
vector
<
PrimitivePtr
>
&
prims
,
const
RenormAction
&
renorm_action
)
{
const
std
::
vector
<
PrimitivePtr
>
&
prims
,
const
RenormAction
&
renorm_action
)
{
auto
fn
=
[
prims
](
const
AnfNodePtr
&
node
)
->
bool
{
auto
fn
=
[
prims
](
const
AnfNodePtr
&
node
)
->
bool
{
if
(
!
node
->
isa
<
CNode
>
())
{
if
(
!
node
->
isa
<
CNode
>
())
{
...
@@ -64,16 +63,16 @@ SubstitutionPtr MakeSubstitution(const TransformFuncType &transform, const std::
...
@@ -64,16 +63,16 @@ SubstitutionPtr MakeSubstitution(const TransformFuncType &transform, const std::
return
std
::
make_shared
<
Substitution
>
(
transform
,
name
,
fn
,
renorm_action
);
return
std
::
make_shared
<
Substitution
>
(
transform
,
name
,
fn
,
renorm_action
);
}
}
SubstitutionPtr
MakeSubstitution
(
const
TransformFuncType
&
transform
,
const
std
::
string
&
name
,
SubstitutionPtr
MakeSubstitution
(
const
OptimizerCallerPtr
&
transform
,
const
std
::
string
&
name
,
const
PredicateFuncType
&
predicate
,
const
RenormAction
&
renorm_action
)
{
const
PredicateFuncType
&
predicate
,
const
RenormAction
&
renorm_action
)
{
return
std
::
make_shared
<
Substitution
>
(
transform
,
name
,
predicate
,
renorm_action
);
return
std
::
make_shared
<
Substitution
>
(
transform
,
name
,
predicate
,
renorm_action
);
}
}
AnfNodePtr
Substitution
::
operator
()(
const
OptimizerPtr
&
optimizer
,
const
AnfNodePtr
&
node
)
const
{
AnfNodePtr
Substitution
::
operator
()(
const
OptimizerPtr
&
optimizer
,
const
AnfNodePtr
&
node
)
{
#ifdef ENABLE_PROFILE
#ifdef ENABLE_PROFILE
double
t
=
GetTime
();
double
t
=
GetTime
();
#endif
#endif
AnfNodePtr
result
=
transform_
(
optimizer
,
node
);
AnfNodePtr
result
=
(
*
transform_
)
(
optimizer
,
node
);
#ifdef ENABLE_PROFILE
#ifdef ENABLE_PROFILE
if
(
optimizer
!=
nullptr
)
{
if
(
optimizer
!=
nullptr
)
{
auto
time
=
GetTime
();
auto
time
=
GetTime
();
...
...
mindspore/ccsrc/optimizer/opt.h
浏览文件 @
5a886794
...
@@ -17,24 +17,18 @@
...
@@ -17,24 +17,18 @@
#ifndef MINDSPORE_CCSRC_OPTIMIZER_OPT_H_
#ifndef MINDSPORE_CCSRC_OPTIMIZER_OPT_H_
#define MINDSPORE_CCSRC_OPTIMIZER_OPT_H_
#define MINDSPORE_CCSRC_OPTIMIZER_OPT_H_
#include <vector>
#include <string>
#include <memory>
#include <memory>
#include <string>
#include <vector>
#include "ir/anf.h"
#include "ir/anf.h"
#include "ir/func_graph.h"
#include "ir/func_graph.h"
#include "ir/optimizer_caller.h"
#include "operator/ops.h"
#include "operator/ops.h"
namespace
mindspore
{
namespace
mindspore
{
/* namespace to support opt */
/* namespace to support opt */
namespace
opt
{
namespace
opt
{
class
Optimizer
;
using
OptimizerPtr
=
std
::
shared_ptr
<
Optimizer
>
;
using
OptimizerWeakPtr
=
std
::
weak_ptr
<
Optimizer
>
;
using
PredicateFuncType
=
std
::
function
<
bool
(
const
AnfNodePtr
&
)
>
;
using
TransformFuncType
=
std
::
function
<
AnfNodePtr
(
const
OptimizerPtr
&
,
const
AnfNodePtr
&
)
>
;
// Define the interaction mode between an Optimize pass and Renormalize pass
// Define the interaction mode between an Optimize pass and Renormalize pass
// FORCE_RENORM: if the pass modified the graph then the next Renormalize will be executed
// FORCE_RENORM: if the pass modified the graph then the next Renormalize will be executed
...
@@ -43,26 +37,26 @@ enum RenormAction : int { FORCE_RENORM = 0, CHECK_RENORM };
...
@@ -43,26 +37,26 @@ enum RenormAction : int { FORCE_RENORM = 0, CHECK_RENORM };
class
Substitution
{
class
Substitution
{
public:
public:
TransformFuncType
transform_
{
nullptr
}
;
OptimizerCallerPtr
transform_
;
std
::
string
name_
;
std
::
string
name_
;
PredicateFuncType
predicate_
{
nullptr
};
PredicateFuncType
predicate_
{
nullptr
};
// an enum to mark this Substitution relation to renormalize pass
// an enum to mark this Substitution relation to renormalize pass
RenormAction
renorm_action_
;
RenormAction
renorm_action_
;
Substitution
(
const
TransformFuncType
&
transform
,
const
std
::
string
&
name
,
const
PredicateFuncType
&
predicate
,
Substitution
(
const
OptimizerCallerPtr
&
transform
,
const
std
::
string
&
name
,
const
PredicateFuncType
&
predicate
,
const
RenormAction
&
renorm_action
)
const
RenormAction
&
renorm_action
)
:
transform_
(
transform
),
name_
(
name
),
predicate_
(
predicate
),
renorm_action_
(
renorm_action
)
{}
:
transform_
(
transform
),
name_
(
name
),
predicate_
(
predicate
),
renorm_action_
(
renorm_action
)
{}
~
Substitution
()
=
default
;
~
Substitution
()
=
default
;
AnfNodePtr
operator
()(
const
OptimizerPtr
&
optimizer
,
const
AnfNodePtr
&
node
)
const
;
AnfNodePtr
operator
()(
const
OptimizerPtr
&
optimizer
,
const
AnfNodePtr
&
node
);
};
};
using
SubstitutionPtr
=
std
::
shared_ptr
<
Substitution
>
;
using
SubstitutionPtr
=
std
::
shared_ptr
<
Substitution
>
;
SubstitutionPtr
MakeSubstitution
(
const
TransformFuncType
&
transform
,
const
std
::
string
&
name
,
const
PrimitivePtr
&
prim
,
SubstitutionPtr
MakeSubstitution
(
const
OptimizerCallerPtr
&
transform
,
const
std
::
string
&
name
,
const
PrimitivePtr
&
prim
,
const
RenormAction
&
action_renorm
=
CHECK_RENORM
);
const
RenormAction
&
action_renorm
=
CHECK_RENORM
);
SubstitutionPtr
MakeSubstitution
(
const
TransformFuncType
&
transform
,
const
std
::
string
&
name
,
SubstitutionPtr
MakeSubstitution
(
const
OptimizerCallerPtr
&
transform
,
const
std
::
string
&
name
,
const
std
::
vector
<
PrimitivePtr
>
&
prims
,
const
std
::
vector
<
PrimitivePtr
>
&
prims
,
const
RenormAction
&
action_renorm
=
CHECK_RENORM
);
const
RenormAction
&
action_renorm
=
CHECK_RENORM
);
SubstitutionPtr
MakeSubstitution
(
const
TransformFuncType
&
transform
,
const
std
::
string
&
name
,
SubstitutionPtr
MakeSubstitution
(
const
OptimizerCallerPtr
&
transform
,
const
std
::
string
&
name
,
const
PredicateFuncType
&
predicate
,
const
RenormAction
&
action_renorm
=
CHECK_RENORM
);
const
PredicateFuncType
&
predicate
,
const
RenormAction
&
action_renorm
=
CHECK_RENORM
);
class
SubstitutionList
{
class
SubstitutionList
{
...
...
mindspore/ccsrc/parallel/auto_parallel/operator_costmodel.cc
浏览文件 @
5a886794
...
@@ -465,7 +465,7 @@ double ReshapeCost::GetForwardCommCost(const std::vector<TensorInfo> &inputs, co
...
@@ -465,7 +465,7 @@ double ReshapeCost::GetForwardCommCost(const std::vector<TensorInfo> &inputs, co
CheckGlobalDeviceManager
();
CheckGlobalDeviceManager
();
MS_EXCEPTION_IF_NULL
(
g_device_manager
);
MS_EXCEPTION_IF_NULL
(
g_device_manager
);
RankList
dev_list
=
g_device_manager
->
GetDeviceListByStageId
(
stage_id
);
RankList
dev_list
=
g_device_manager
->
GetDeviceListByStageId
(
stage_id
);
TensorRedistribution
tensor_redistribution
;
TensorRedistribution
tensor_redistribution
(
false
,
true
)
;
if
(
tensor_redistribution
.
Init
(
inputs
[
0
].
tensor_layout
(),
outputs
[
0
].
tensor_layout
(),
dev_list
)
==
FAILED
)
{
if
(
tensor_redistribution
.
Init
(
inputs
[
0
].
tensor_layout
(),
outputs
[
0
].
tensor_layout
(),
dev_list
)
==
FAILED
)
{
MS_LOG
(
EXCEPTION
)
<<
"Failure: tensor_redistribution init failed."
;
MS_LOG
(
EXCEPTION
)
<<
"Failure: tensor_redistribution init failed."
;
}
}
...
@@ -503,7 +503,7 @@ double ReshapeCost::GetForwardComputationCost(const std::vector<TensorInfo> &inp
...
@@ -503,7 +503,7 @@ double ReshapeCost::GetForwardComputationCost(const std::vector<TensorInfo> &inp
CheckGlobalDeviceManager
();
CheckGlobalDeviceManager
();
MS_EXCEPTION_IF_NULL
(
g_device_manager
);
MS_EXCEPTION_IF_NULL
(
g_device_manager
);
RankList
dev_list
=
g_device_manager
->
GetDeviceListByStageId
(
stage_id
);
RankList
dev_list
=
g_device_manager
->
GetDeviceListByStageId
(
stage_id
);
TensorRedistribution
tensor_redistribution
;
TensorRedistribution
tensor_redistribution
(
false
,
true
)
;
if
(
tensor_redistribution
.
Init
(
inputs
[
0
].
tensor_layout
(),
outputs
[
0
].
tensor_layout
(),
dev_list
)
==
FAILED
)
{
if
(
tensor_redistribution
.
Init
(
inputs
[
0
].
tensor_layout
(),
outputs
[
0
].
tensor_layout
(),
dev_list
)
==
FAILED
)
{
MS_LOG
(
EXCEPTION
)
<<
"Failure: tensor_redistribution init failed."
;
MS_LOG
(
EXCEPTION
)
<<
"Failure: tensor_redistribution init failed."
;
}
}
...
...
mindspore/ccsrc/parallel/context.cc
浏览文件 @
5a886794
...
@@ -62,6 +62,7 @@ void ParallelContext::Reset() {
...
@@ -62,6 +62,7 @@ void ParallelContext::Reset() {
enable_all_reduce_fusion_
=
false
;
enable_all_reduce_fusion_
=
false
;
strategy_ckpt_load_file_
=
""
;
strategy_ckpt_load_file_
=
""
;
strategy_ckpt_save_file_
=
""
;
strategy_ckpt_save_file_
=
""
;
enable_parallel_optimizer_
=
false
;
}
}
void
ParallelContext
::
set_device_num
(
int32_t
device_num
)
{
void
ParallelContext
::
set_device_num
(
int32_t
device_num
)
{
...
...
mindspore/ccsrc/parallel/context.h
浏览文件 @
5a886794
...
@@ -100,6 +100,11 @@ class ParallelContext {
...
@@ -100,6 +100,11 @@ class ParallelContext {
void
set_strategy_ckpt_save_file
(
const
std
::
string
&
strategy_ckpt_save_file
);
void
set_strategy_ckpt_save_file
(
const
std
::
string
&
strategy_ckpt_save_file
);
std
::
string
strategy_ckpt_save_file
()
const
{
return
strategy_ckpt_save_file_
;
}
std
::
string
strategy_ckpt_save_file
()
const
{
return
strategy_ckpt_save_file_
;
}
void
set_enable_parallel_optimizer
(
bool
enable_parallel_optimizer
)
{
enable_parallel_optimizer_
=
enable_parallel_optimizer
;
}
bool
enable_parallel_optimizer
()
const
{
return
enable_parallel_optimizer_
;
}
void
Reset
();
void
Reset
();
private:
private:
...
@@ -123,6 +128,7 @@ class ParallelContext {
...
@@ -123,6 +128,7 @@ class ParallelContext {
std
::
map
<
std
::
string
,
std
::
vector
<
uint32_t
>>
all_reduce_fusion_split_sizes_
;
std
::
map
<
std
::
string
,
std
::
vector
<
uint32_t
>>
all_reduce_fusion_split_sizes_
;
std
::
string
strategy_ckpt_load_file_
;
std
::
string
strategy_ckpt_load_file_
;
std
::
string
strategy_ckpt_save_file_
;
std
::
string
strategy_ckpt_save_file_
;
bool
enable_parallel_optimizer_
;
};
};
void
ParallelParameterContextInit
(
const
FuncGraphPtr
&
func_graph
);
void
ParallelParameterContextInit
(
const
FuncGraphPtr
&
func_graph
);
...
...
mindspore/ccsrc/pipeline/init.cc
浏览文件 @
5a886794
...
@@ -205,6 +205,10 @@ PYBIND11_MODULE(_c_expression, m) {
...
@@ -205,6 +205,10 @@ PYBIND11_MODULE(_c_expression, m) {
.
def
(
"get_strategy_ckpt_save_file"
,
&
ParallelContext
::
strategy_ckpt_save_file
,
"Get strategy checkpoint save file."
)
.
def
(
"get_strategy_ckpt_save_file"
,
&
ParallelContext
::
strategy_ckpt_save_file
,
"Get strategy checkpoint save file."
)
.
def
(
"set_full_batch"
,
&
ParallelContext
::
set_full_batch
,
"Set whether load full batch on each device."
)
.
def
(
"set_full_batch"
,
&
ParallelContext
::
set_full_batch
,
"Set whether load full batch on each device."
)
.
def
(
"get_full_batch"
,
&
ParallelContext
::
full_batch
,
"Get whether load full batch on each device."
)
.
def
(
"get_full_batch"
,
&
ParallelContext
::
full_batch
,
"Get whether load full batch on each device."
)
.
def
(
"set_enable_parallel_optimizer"
,
&
ParallelContext
::
set_enable_parallel_optimizer
,
"Set enable/disable parallel optimizer."
)
.
def
(
"get_enable_parallel_optimizer"
,
&
ParallelContext
::
enable_parallel_optimizer
,
"Get enable/disable parallel optimizer."
)
.
def
(
"reset"
,
&
ParallelContext
::
Reset
,
"Reset auto parallel context."
);
.
def
(
"reset"
,
&
ParallelContext
::
Reset
,
"Reset auto parallel context."
);
(
void
)
py
::
class_
<
CostModelContext
,
std
::
shared_ptr
<
CostModelContext
>>
(
m
,
"CostModelContext"
)
(
void
)
py
::
class_
<
CostModelContext
,
std
::
shared_ptr
<
CostModelContext
>>
(
m
,
"CostModelContext"
)
...
...
mindspore/ccsrc/pre_activate/pass/common_subexpression_elimination.cc
浏览文件 @
5a886794
...
@@ -35,7 +35,7 @@ bool CheckEqualKernelBuildInfo(const AnfNodePtr &main, const AnfNodePtr &node) {
...
@@ -35,7 +35,7 @@ bool CheckEqualKernelBuildInfo(const AnfNodePtr &main, const AnfNodePtr &node) {
}
}
}
// namespace
}
// namespace
bool
BackendCSE
::
CheckReplace
(
const
AnfNodePtr
&
main
,
const
AnfNodePtr
&
node
)
const
{
bool
BackendCSE
::
CheckReplace
(
const
AnfNodePtr
&
main
,
const
AnfNodePtr
&
node
,
bool
)
const
{
MS_EXCEPTION_IF_NULL
(
main
);
MS_EXCEPTION_IF_NULL
(
main
);
MS_EXCEPTION_IF_NULL
(
node
);
MS_EXCEPTION_IF_NULL
(
node
);
...
...
mindspore/ccsrc/pre_activate/pass/common_subexpression_elimination.h
浏览文件 @
5a886794
...
@@ -31,7 +31,7 @@ class BackendCSE : public CSE {
...
@@ -31,7 +31,7 @@ class BackendCSE : public CSE {
public:
public:
BackendCSE
()
=
default
;
BackendCSE
()
=
default
;
~
BackendCSE
()
override
=
default
;
~
BackendCSE
()
override
=
default
;
bool
CheckReplace
(
const
AnfNodePtr
&
main
,
const
AnfNodePtr
&
node
)
const
override
;
bool
CheckReplace
(
const
AnfNodePtr
&
main
,
const
AnfNodePtr
&
node
,
bool
check_side_effect
=
true
)
const
override
;
};
};
}
// namespace opt
}
// namespace opt
}
// namespace mindspore
}
// namespace mindspore
...
...
mindspore/ccsrc/pybind_api/export_flags.cc
浏览文件 @
5a886794
...
@@ -33,5 +33,6 @@ const char GRAPH_FLAG_LOOP_CAN_UNROLL[] = "loop_can_unroll";
...
@@ -33,5 +33,6 @@ const char GRAPH_FLAG_LOOP_CAN_UNROLL[] = "loop_can_unroll";
const
char
GRAPH_FLAG_HAS_EFFECT
[]
=
"has_effect"
;
const
char
GRAPH_FLAG_HAS_EFFECT
[]
=
"has_effect"
;
const
char
GRAPH_FLAG_EFFECT_PATIAL_ORDER
[]
=
"_effect_patial_order"
;
const
char
GRAPH_FLAG_EFFECT_PATIAL_ORDER
[]
=
"_effect_patial_order"
;
const
char
GRAPH_FLAG_RANDOM_EFFECT
[]
=
"_random_effect"
;
const
char
GRAPH_FLAG_RANDOM_EFFECT
[]
=
"_random_effect"
;
const
char
GRAPH_FLAG_SIDE_EFFECT
[]
=
"_side_effect"
;
}
// namespace mindspore
}
// namespace mindspore
mindspore/ccsrc/pybind_api/export_flags.h
浏览文件 @
5a886794
...
@@ -34,7 +34,7 @@ extern const char GRAPH_FLAG_LOOP_CAN_UNROLL[];
...
@@ -34,7 +34,7 @@ extern const char GRAPH_FLAG_LOOP_CAN_UNROLL[];
extern
const
char
GRAPH_FLAG_HAS_EFFECT
[];
extern
const
char
GRAPH_FLAG_HAS_EFFECT
[];
extern
const
char
GRAPH_FLAG_EFFECT_PATIAL_ORDER
[];
extern
const
char
GRAPH_FLAG_EFFECT_PATIAL_ORDER
[];
extern
const
char
GRAPH_FLAG_RANDOM_EFFECT
[];
extern
const
char
GRAPH_FLAG_RANDOM_EFFECT
[];
extern
const
char
GRAPH_FLAG_SIDE_EFFECT
[];
}
// namespace mindspore
}
// namespace mindspore
#endif // PYBIND_API_EXPORT_FLAGS_H_
#endif // PYBIND_API_EXPORT_FLAGS_H_
mindspore/ccsrc/session/ascend_control_parser.cc
浏览文件 @
5a886794
...
@@ -33,6 +33,21 @@ static constexpr size_t kCNodeSwitchLayerLength = 3;
...
@@ -33,6 +33,21 @@ static constexpr size_t kCNodeSwitchLayerLength = 3;
namespace
mindspore
{
namespace
mindspore
{
namespace
session
{
namespace
session
{
static
CNodePtr
GetJumpNode
(
NotNull
<
KernelGraphPtr
>
parent_graph
,
NotNull
<
KernelGraphPtr
>
child_graph
)
{
auto
&
nodes
=
parent_graph
->
execution_order
();
for
(
auto
&
node
:
nodes
)
{
if
(
IsPrimitiveCNode
(
node
,
prim
::
kPrimLabelGoto
)
&&
child_graph
->
get_start_label
()
==
node
->
input
(
kCNodeCallArg
))
{
return
node
;
}
else
if
(
IsPrimitiveCNode
(
node
,
prim
::
kPrimLabelSwitch
)
&&
(
child_graph
->
get_start_label
()
==
node
->
input
(
kCNodeSwitchFalse
)
||
child_graph
->
get_start_label
()
==
node
->
input
(
kCNodeSwitchTrue
)))
{
return
node
;
}
}
MS_LOG
(
INFO
)
<<
"Cannot find jump node from "
<<
parent_graph
->
ToString
()
<<
" to "
<<
child_graph
->
ToString
();
return
nullptr
;
}
static
void
InitUnionFindSet
(
NotNull
<
KernelGraphPtr
>
kg
,
const
NotNull
<
UnionFindSet
<
AnfNodePtr
>
*>
union_find_set
,
static
void
InitUnionFindSet
(
NotNull
<
KernelGraphPtr
>
kg
,
const
NotNull
<
UnionFindSet
<
AnfNodePtr
>
*>
union_find_set
,
const
NotNull
<
std
::
set
<
KernelGraphPtr
>
*>
memo
)
{
const
NotNull
<
std
::
set
<
KernelGraphPtr
>
*>
memo
)
{
if
(
memo
->
find
(
kg
.
get
())
!=
memo
->
end
())
{
if
(
memo
->
find
(
kg
.
get
())
!=
memo
->
end
())
{
...
@@ -200,7 +215,8 @@ void AscendControlParser::ChildGraphDataAssign(const std::map<uint32_t, KernelGr
...
@@ -200,7 +215,8 @@ void AscendControlParser::ChildGraphDataAssign(const std::map<uint32_t, KernelGr
if
(
target_graph_iter
==
graph_id_map
.
end
())
{
if
(
target_graph_iter
==
graph_id_map
.
end
())
{
MS_LOG
(
EXCEPTION
)
<<
"Graph id "
<<
AnfAlgo
::
GetGraphId
(
arg
.
get
())
<<
" not found."
;
MS_LOG
(
EXCEPTION
)
<<
"Graph id "
<<
AnfAlgo
::
GetGraphId
(
arg
.
get
())
<<
" not found."
;
}
}
InsertMultipleAssignToGraph
(
NOT_NULL
(
target_graph_iter
->
second
),
NOT_NULL
(
arg
),
NOT_NULL
(
parameter
));
InsertMultipleAssignToGraph
(
NOT_NULL
(
target_graph_iter
->
second
),
NOT_NULL
(
kg
),
NOT_NULL
(
arg
),
NOT_NULL
(
parameter
));
}
}
}
}
}
}
...
@@ -263,7 +279,7 @@ NotNull<CNodePtr> AscendControlParser::ProcessKernelGraph(NotNull<KernelGraphPtr
...
@@ -263,7 +279,7 @@ NotNull<CNodePtr> AscendControlParser::ProcessKernelGraph(NotNull<KernelGraphPtr
RecurseSwitchLayer
(
kg
,
NOT_NULL
(
cnode
),
GetNextRealKernel
(
nodes
,
i
+
1
),
memo
);
RecurseSwitchLayer
(
kg
,
NOT_NULL
(
cnode
),
GetNextRealKernel
(
nodes
,
i
+
1
),
memo
);
}
}
}
}
kg
->
SetExecOrderByDefault
();
MS_LOG
(
INFO
)
<<
"End KernelGraph process: "
<<
kg
->
ToString
();
MS_LOG
(
INFO
)
<<
"End KernelGraph process: "
<<
kg
->
ToString
();
return
NOT_NULL
(
start_label
);
return
NOT_NULL
(
start_label
);
}
}
...
@@ -433,7 +449,8 @@ std::tuple<CNodePtr, KernelGraphPtr> AscendControlParser::ParsePartial(NotNull<A
...
@@ -433,7 +449,8 @@ std::tuple<CNodePtr, KernelGraphPtr> AscendControlParser::ParsePartial(NotNull<A
return
{
partial_cnode
,
branch_kg
};
return
{
partial_cnode
,
branch_kg
};
}
}
void
AscendControlParser
::
InsertMultipleAssignToGraph
(
NotNull
<
KernelGraphPtr
>
kg
,
NotNull
<
AnfNodePtr
>
from
,
void
AscendControlParser
::
InsertMultipleAssignToGraph
(
NotNull
<
KernelGraphPtr
>
from_graph
,
NotNull
<
KernelGraphPtr
>
to_graph
,
NotNull
<
AnfNodePtr
>
from
,
NotNull
<
AnfNodePtr
>
to
)
{
NotNull
<
AnfNodePtr
>
to
)
{
std
::
vector
<
AnfNodePtr
>
from_outputs
=
AnfAlgo
::
GetAllOutput
(
from
,
{
prim
::
kPrimTupleGetItem
});
std
::
vector
<
AnfNodePtr
>
from_outputs
=
AnfAlgo
::
GetAllOutput
(
from
,
{
prim
::
kPrimTupleGetItem
});
std
::
vector
<
AnfNodePtr
>
to_outputs
=
AnfAlgo
::
GetAllOutput
(
to
,
{
prim
::
kPrimTupleGetItem
});
std
::
vector
<
AnfNodePtr
>
to_outputs
=
AnfAlgo
::
GetAllOutput
(
to
,
{
prim
::
kPrimTupleGetItem
});
...
@@ -443,18 +460,24 @@ void AscendControlParser::InsertMultipleAssignToGraph(NotNull<KernelGraphPtr> kg
...
@@ -443,18 +460,24 @@ void AscendControlParser::InsertMultipleAssignToGraph(NotNull<KernelGraphPtr> kg
<<
to_outputs
.
size
()
<<
"]"
;
<<
to_outputs
.
size
()
<<
"]"
;
}
}
for
(
size_t
i
=
0
;
i
<
from_outputs
.
size
();
i
++
)
{
for
(
size_t
i
=
0
;
i
<
from_outputs
.
size
();
i
++
)
{
InsertAssignToGraph
(
kg
,
NOT_NULL
(
from_outputs
[
i
]),
NOT_NULL
(
to_outputs
[
i
]));
auto
assign_node
=
InsertAssignToGraph
(
from_graph
,
NOT_NULL
(
from_outputs
[
i
]),
NOT_NULL
(
to_outputs
[
i
]));
if
(
assign_node
!=
nullptr
)
{
auto
jump_node
=
GetJumpNode
(
from_graph
,
to_graph
);
if
(
jump_node
!=
nullptr
)
{
InsertControlDependToGraph
(
from_graph
,
NOT_NULL
(
assign_node
),
NOT_NULL
(
jump_node
));
}
}
}
}
}
}
void
AscendControlParser
::
InsertAssignToGraph
(
NotNull
<
KernelGraphPtr
>
kg
,
NotNull
<
AnfNodePtr
>
from
,
AnfNodePtr
AscendControlParser
::
InsertAssignToGraph
(
NotNull
<
KernelGraphPtr
>
kg
,
NotNull
<
AnfNodePtr
>
from
,
NotNull
<
AnfNodePtr
>
to
)
{
NotNull
<
AnfNodePtr
>
to
)
{
if
(
AnfAlgo
::
OutputAddrExist
(
from
,
0
)
&&
AnfAlgo
::
OutputAddrExist
(
to
,
0
)
&&
if
(
AnfAlgo
::
OutputAddrExist
(
from
,
0
)
&&
AnfAlgo
::
OutputAddrExist
(
to
,
0
)
&&
AnfAlgo
::
GetOutputAddr
(
from
,
0
)
==
AnfAlgo
::
GetOutputAddr
(
to
,
0
))
{
AnfAlgo
::
GetOutputAddr
(
from
,
0
)
==
AnfAlgo
::
GetOutputAddr
(
to
,
0
))
{
return
;
return
nullptr
;
}
}
if
(
from
.
get
()
==
to
.
get
())
{
if
(
from
.
get
()
==
to
.
get
())
{
return
;
return
nullptr
;
}
}
MS_LOG
(
INFO
)
<<
"Insert assign to graph "
<<
kg
->
ToString
()
<<
" from "
<<
from
->
DebugString
()
<<
" to "
MS_LOG
(
INFO
)
<<
"Insert assign to graph "
<<
kg
->
ToString
()
<<
" from "
<<
from
->
DebugString
()
<<
" to "
<<
to
->
DebugString
();
<<
to
->
DebugString
();
...
@@ -466,6 +489,7 @@ void AscendControlParser::InsertAssignToGraph(NotNull<KernelGraphPtr> kg, NotNul
...
@@ -466,6 +489,7 @@ void AscendControlParser::InsertAssignToGraph(NotNull<KernelGraphPtr> kg, NotNul
assign_node
->
set_abstract
(
to
->
abstract
());
assign_node
->
set_abstract
(
to
->
abstract
());
// append the assign at the end of from graph
// append the assign at the end of from graph
InsertDependToGraph
(
kg
,
NOT_NULL
(
assign_node
));
InsertDependToGraph
(
kg
,
NOT_NULL
(
assign_node
));
return
assign_node
;
}
}
std
::
vector
<
CNodePtr
>
AscendControlParser
::
RecurseGraph
(
NotNull
<
KernelGraphPtr
>
graph
,
std
::
vector
<
CNodePtr
>
AscendControlParser
::
RecurseGraph
(
NotNull
<
KernelGraphPtr
>
graph
,
...
...
mindspore/ccsrc/session/ascend_control_parser.h
浏览文件 @
5a886794
...
@@ -52,8 +52,9 @@ class AscendControlParser {
...
@@ -52,8 +52,9 @@ class AscendControlParser {
const
CNodePtr
&
last_label
);
const
CNodePtr
&
last_label
);
static
std
::
tuple
<
CNodePtr
,
KernelGraphPtr
>
ParsePartial
(
NotNull
<
AnfNodePtr
>
node
);
static
std
::
tuple
<
CNodePtr
,
KernelGraphPtr
>
ParsePartial
(
NotNull
<
AnfNodePtr
>
node
);
static
void
InsertMultipleAssignToGraph
(
NotNull
<
KernelGraphPtr
>
kg
,
NotNull
<
AnfNodePtr
>
from
,
NotNull
<
AnfNodePtr
>
to
);
static
void
InsertMultipleAssignToGraph
(
NotNull
<
KernelGraphPtr
>
from_graph
,
NotNull
<
KernelGraphPtr
>
to_graph
,
static
void
InsertAssignToGraph
(
NotNull
<
KernelGraphPtr
>
kg
,
NotNull
<
AnfNodePtr
>
from
,
NotNull
<
AnfNodePtr
>
to
);
NotNull
<
AnfNodePtr
>
from
,
NotNull
<
AnfNodePtr
>
to
);
static
AnfNodePtr
InsertAssignToGraph
(
NotNull
<
KernelGraphPtr
>
kg
,
NotNull
<
AnfNodePtr
>
from
,
NotNull
<
AnfNodePtr
>
to
);
// root graph order
// root graph order
static
bool
CheckLabelIndex
(
uint32_t
order_index
,
uint32_t
label_index
,
const
CNodePtr
&
cnode
,
static
bool
CheckLabelIndex
(
uint32_t
order_index
,
uint32_t
label_index
,
const
CNodePtr
&
cnode
,
...
...
mindspore/ccsrc/session/kernel_graph.cc
浏览文件 @
5a886794
...
@@ -521,6 +521,47 @@ std::vector<AnfNodePtr> KernelGraph::GetOutputNodes(const AnfNodePtr &node) {
...
@@ -521,6 +521,47 @@ std::vector<AnfNodePtr> KernelGraph::GetOutputNodes(const AnfNodePtr &node) {
return
output_nodes
;
return
output_nodes
;
}
}
// Find control_depend real input nodes.
void
GetAllFatherRealNode
(
const
AnfNodePtr
&
anf_node
,
std
::
vector
<
AnfNodePtr
>
*
result
,
std
::
set
<
AnfNodePtr
>
*
visited
)
{
MS_EXCEPTION_IF_NULL
(
anf_node
);
MS_EXCEPTION_IF_NULL
(
result
);
MS_EXCEPTION_IF_NULL
(
visited
);
if
(
visited
->
find
(
anf_node
)
!=
visited
->
end
())
{
MS_LOG
(
WARNING
)
<<
"Node:"
<<
anf_node
->
fullname_with_scope
()
<<
" has alreday been visited"
;
return
;
}
visited
->
insert
(
anf_node
);
if
(
AnfAlgo
::
IsRealKernel
(
anf_node
))
{
result
->
emplace_back
(
anf_node
);
return
;
}
if
(
!
anf_node
->
isa
<
CNode
>
())
{
return
;
}
auto
cnode
=
anf_node
->
cast
<
CNodePtr
>
();
MS_EXCEPTION_IF_NULL
(
cnode
);
if
(
cnode
->
inputs
().
empty
())
{
MS_LOG
(
EXCEPTION
)
<<
"Illegal null input of cnode(%s)"
<<
anf_node
->
DebugString
();
}
auto
input0
=
cnode
->
input
(
0
);
if
(
IsPrimitive
(
input0
,
prim
::
kPrimMakeTuple
))
{
for
(
size_t
i
=
1
;
i
<
cnode
->
inputs
().
size
();
++
i
)
{
GetAllFatherRealNode
(
cnode
->
input
(
i
),
result
,
visited
);
}
}
else
if
(
IsPrimitive
(
input0
,
prim
::
kPrimTupleGetItem
))
{
if
(
cnode
->
inputs
().
size
()
!=
kTupleGetItemInputSize
)
{
MS_LOG
(
EXCEPTION
)
<<
"The node tuple_get_item must have 2 inputs!"
;
}
GetAllFatherRealNode
(
cnode
->
input
(
kRealInputNodeIndexInTupleGetItem
),
result
,
visited
);
}
else
if
(
IsPrimitive
(
input0
,
prim
::
kPrimDepend
))
{
if
(
cnode
->
inputs
().
size
()
!=
kDependInputSize
)
{
MS_LOG
(
EXCEPTION
)
<<
"Depend node must have 2 inputs!"
;
}
GetAllFatherRealNode
(
cnode
->
input
(
kRealInputIndexInDepend
),
result
,
visited
);
GetAllFatherRealNode
(
cnode
->
input
(
kDependAttachNodeIndex
),
result
,
visited
);
}
}
// update the depend relations of control depend
// update the depend relations of control depend
void
KernelGraph
::
UpdateControlDependRelations
(
const
std
::
vector
<
AnfNodePtr
>
&
depends
)
{
void
KernelGraph
::
UpdateControlDependRelations
(
const
std
::
vector
<
AnfNodePtr
>
&
depends
)
{
for
(
const
auto
&
node
:
depends
)
{
for
(
const
auto
&
node
:
depends
)
{
...
@@ -551,11 +592,24 @@ void KernelGraph::UpdateControlDependRelations(const std::vector<AnfNodePtr> &de
...
@@ -551,11 +592,24 @@ void KernelGraph::UpdateControlDependRelations(const std::vector<AnfNodePtr> &de
if
(
depend_node
->
isa
<
Parameter
>
()
&&
depend_mode
==
1
)
{
if
(
depend_node
->
isa
<
Parameter
>
()
&&
depend_mode
==
1
)
{
depend_nodes
=
GetOutputNodes
(
depend_node
);
depend_nodes
=
GetOutputNodes
(
depend_node
);
}
}
for
(
auto
&
first_node
:
prior_nodes
)
{
std
::
vector
<
AnfNodePtr
>
real_prior_nodes
;
std
::
set
<
AnfNodePtr
>
prior_visited
;
for
(
const
auto
&
tmp
:
prior_nodes
)
{
GetAllFatherRealNode
(
tmp
,
&
real_prior_nodes
,
&
prior_visited
);
}
std
::
vector
<
AnfNodePtr
>
real_depend_nodes
;
std
::
set
<
AnfNodePtr
>
depend_visited
;
for
(
const
auto
&
tmp
:
depend_nodes
)
{
GetAllFatherRealNode
(
tmp
,
&
real_depend_nodes
,
&
depend_visited
);
}
for
(
auto
&
first_node
:
real_prior_nodes
)
{
if
(
AnfAlgo
::
CheckPrimitiveType
(
first_node
,
prim
::
kPrimControlDepend
))
{
if
(
AnfAlgo
::
CheckPrimitiveType
(
first_node
,
prim
::
kPrimControlDepend
))
{
continue
;
continue
;
}
}
for
(
auto
&
second_node
:
depend_nodes
)
{
for
(
auto
&
second_node
:
real_
depend_nodes
)
{
if
(
AnfAlgo
::
CheckPrimitiveType
(
second_node
,
prim
::
kPrimControlDepend
))
{
if
(
AnfAlgo
::
CheckPrimitiveType
(
second_node
,
prim
::
kPrimControlDepend
))
{
continue
;
continue
;
}
}
...
...
mindspore/ccsrc/session/session.cc
浏览文件 @
5a886794
...
@@ -33,9 +33,14 @@
...
@@ -33,9 +33,14 @@
namespace
py
=
pybind11
;
namespace
py
=
pybind11
;
namespace
mindspore
::
inference
{
namespace
mindspore
::
inference
{
std
::
shared_ptr
<
FuncGraph
>
LoadModel
(
const
char
*
model_buf
,
size_t
size
,
const
std
::
string
&
device
)
{
std
::
shared_ptr
<
FuncGraph
>
LoadModel
(
const
char
*
model_buf
,
size_t
size
,
const
std
::
string
&
device
)
{
inference
::
Session
::
RegAllOp
();
try
{
auto
anf_graph
=
lite
::
AnfConverter
::
RunAnfConverter
(
model_buf
,
size
);
inference
::
Session
::
RegAllOp
();
return
anf_graph
;
auto
anf_graph
=
lite
::
AnfConverter
::
RunAnfConverter
(
model_buf
,
size
);
return
anf_graph
;
}
catch
(
std
::
exception
&
e
)
{
MS_LOG
(
ERROR
)
<<
"Inference LoadModel failed"
;
return
nullptr
;
}
}
}
void
ExitInference
()
{
void
ExitInference
()
{
...
@@ -51,12 +56,17 @@ void ExitInference() {
...
@@ -51,12 +56,17 @@ void ExitInference() {
}
}
std
::
shared_ptr
<
MSSession
>
MSSession
::
CreateSession
(
const
std
::
string
&
device
,
uint32_t
device_id
)
{
std
::
shared_ptr
<
MSSession
>
MSSession
::
CreateSession
(
const
std
::
string
&
device
,
uint32_t
device_id
)
{
auto
session
=
std
::
make_shared
<
inference
::
Session
>
();
try
{
auto
ret
=
session
->
Init
(
device
,
device_id
);
auto
session
=
std
::
make_shared
<
inference
::
Session
>
();
if
(
ret
!=
0
)
{
auto
ret
=
session
->
Init
(
device
,
device_id
);
if
(
ret
!=
0
)
{
return
nullptr
;
}
return
session
;
}
catch
(
std
::
exception
&
e
)
{
MS_LOG
(
ERROR
)
<<
"Inference CreatSession failed"
;
return
nullptr
;
return
nullptr
;
}
}
return
session
;
}
}
void
Session
::
RegAllOp
()
{
void
Session
::
RegAllOp
()
{
...
@@ -113,47 +123,71 @@ void Session::RegAllOp() {
...
@@ -113,47 +123,71 @@ void Session::RegAllOp() {
uint32_t
Session
::
CompileGraph
(
std
::
shared_ptr
<
FuncGraph
>
funcGraphPtr
)
{
uint32_t
Session
::
CompileGraph
(
std
::
shared_ptr
<
FuncGraph
>
funcGraphPtr
)
{
MS_ASSERT
(
session_impl_
!=
nullptr
);
MS_ASSERT
(
session_impl_
!=
nullptr
);
auto
graph_id
=
session_impl_
->
CompileGraph
(
NOT_NULL
(
funcGraphPtr
));
try
{
py
::
gil_scoped_release
gil_release
;
auto
graph_id
=
session_impl_
->
CompileGraph
(
NOT_NULL
(
funcGraphPtr
));
return
graph_id
;
py
::
gil_scoped_release
gil_release
;
return
graph_id
;
}
catch
(
std
::
exception
&
e
)
{
MS_LOG
(
ERROR
)
<<
"Inference CompileGraph failed"
;
return
static_cast
<
uint32_t
>
(
-
1
);
}
}
}
MultiTensor
Session
::
RunGraph
(
uint32_t
graph_id
,
const
std
::
vector
<
std
::
shared_ptr
<
inference
::
MSTensor
>>
&
inputs
)
{
MultiTensor
Session
::
RunGraph
(
uint32_t
graph_id
,
const
std
::
vector
<
std
::
shared_ptr
<
inference
::
MSTensor
>>
&
inputs
)
{
std
::
vector
<
tensor
::
TensorPtr
>
inTensors
;
try
{
inTensors
.
resize
(
inputs
.
size
());
std
::
vector
<
tensor
::
TensorPtr
>
inTensors
;
bool
has_error
=
false
;
inTensors
.
resize
(
inputs
.
size
());
std
::
transform
(
inputs
.
begin
(),
inputs
.
end
(),
inTensors
.
begin
(),
bool
has_error
=
false
;
[
&
has_error
](
const
std
::
shared_ptr
<
inference
::
MSTensor
>
&
tensor_ptr
)
->
tensor
::
TensorPtr
{
std
::
transform
(
inputs
.
begin
(),
inputs
.
end
(),
inTensors
.
begin
(),
if
(
tensor_ptr
==
nullptr
)
{
[
&
has_error
](
const
std
::
shared_ptr
<
inference
::
MSTensor
>
&
tensor_ptr
)
->
tensor
::
TensorPtr
{
MS_LOG
(
WARNING
)
<<
"input MSTensor is nullptr, return nullptr"
;
if
(
tensor_ptr
==
nullptr
)
{
has_error
=
true
;
MS_LOG
(
WARNING
)
<<
"input MSTensor is nullptr, return nullptr"
;
return
nullptr
;
has_error
=
true
;
}
return
nullptr
;
auto
tensor
=
static_cast
<
inference
::
Tensor
*>
(
tensor_ptr
.
get
());
}
if
(
tensor
==
nullptr
)
{
auto
tensor
=
static_cast
<
inference
::
Tensor
*>
(
tensor_ptr
.
get
());
MS_LOG
(
ERROR
)
<<
"Can not cast input MSTensor to tensor"
;
if
(
tensor
==
nullptr
)
{
has_error
=
true
;
MS_LOG
(
ERROR
)
<<
"Can not cast input MSTensor to tensor"
;
return
nullptr
;
has_error
=
true
;
}
return
nullptr
;
return
tensor
->
tensor
();
}
});
return
tensor
->
tensor
();
if
(
has_error
)
{
});
MS_LOG
(
ERROR
)
<<
"Init Tensor failed, returning empty result"
;
if
(
has_error
)
{
std
::
vector
<
std
::
shared_ptr
<
inference
::
MSTensor
>>
multiTensor
;
MS_LOG
(
ERROR
)
<<
"Init Tensor failed, returning empty result"
;
return
multiTensor
;
std
::
vector
<
std
::
shared_ptr
<
inference
::
MSTensor
>>
multiTensor
;
}
return
multiTensor
;
VectorRef
outputs
;
}
session_impl_
->
RunGraph
(
graph_id
,
inTensors
,
&
outputs
);
VectorRef
outputs
;
session_impl_
->
RunGraph
(
graph_id
,
inTensors
,
&
outputs
);
return
TransformVectorRefToMultiTensor
(
outputs
);
return
TransformVectorRefToMultiTensor
(
outputs
);
}
catch
(
std
::
exception
&
e
)
{
MS_LOG
(
ERROR
)
<<
"Inference Rungraph failed"
;
return
MultiTensor
();
}
}
}
namespace
{
string
AjustTargetName
(
const
std
::
string
&
device
)
{
if
(
device
==
kAscendDevice
)
{
return
std
::
string
(
kAscendDevice
)
+
"Inference"
;
}
else
{
MS_LOG
(
ERROR
)
<<
"Only support device Ascend right now"
;
return
""
;
}
}
}
// namespace
int
Session
::
Init
(
const
std
::
string
&
device
,
uint32_t
device_id
)
{
int
Session
::
Init
(
const
std
::
string
&
device
,
uint32_t
device_id
)
{
RegAllOp
();
RegAllOp
();
auto
ms_context
=
MsContext
::
GetInstance
();
auto
ms_context
=
MsContext
::
GetInstance
();
ms_context
->
set_execution_mode
(
kGraphMode
);
ms_context
->
set_execution_mode
(
kGraphMode
);
ms_context
->
set_device_target
(
kAscendDevice
);
ms_context
->
set_device_id
(
device_id
);
session_impl_
=
session
::
SessionFactory
::
Get
().
Create
(
device
);
auto
ajust_device
=
AjustTargetName
(
device
);
if
(
ajust_device
==
""
)
{
return
-
1
;
}
ms_context
->
set_device_target
(
device
);
session_impl_
=
session
::
SessionFactory
::
Get
().
Create
(
ajust_device
);
if
(
session_impl_
==
nullptr
)
{
if
(
session_impl_
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"Session create failed!, please make sure target device:"
<<
device
<<
" is available."
;
MS_LOG
(
ERROR
)
<<
"Session create failed!, please make sure target device:"
<<
device
<<
" is available."
;
return
-
1
;
return
-
1
;
...
...
mindspore/ccsrc/session/session_basic.cc
浏览文件 @
5a886794
...
@@ -81,7 +81,15 @@ BaseRef CreateOneTensor(const AnfNodePtr &node, size_t output_index, const Kerne
...
@@ -81,7 +81,15 @@ BaseRef CreateOneTensor(const AnfNodePtr &node, size_t output_index, const Kerne
}
}
}
}
// if proccess reach here,it remarks item_with_index is a real node(Parameter,or executable CNode)
// if proccess reach here,it remarks item_with_index is a real node(Parameter,or executable CNode)
auto
address
=
AnfAlgo
::
GetOutputAddr
(
node
,
output_index
);
DeviceAddressPtr
address
;
auto
is_all_nop_node
=
opt
::
IsAllNopNode
(
&
graph
);
if
(
is_all_nop_node
)
{
// The graph does not remove the nop node.
address
=
AnfAlgo
::
GetMutableOutputAddr
(
node
,
output_index
,
false
);
}
else
{
// The graph removes the nop node.
address
=
AnfAlgo
::
GetMutableOutputAddr
(
node
,
output_index
,
true
);
}
MS_EXCEPTION_IF_NULL
(
address
);
MS_EXCEPTION_IF_NULL
(
address
);
auto
shape
=
AnfAlgo
::
GetOutputInferShape
(
node
,
output_index
);
auto
shape
=
AnfAlgo
::
GetOutputInferShape
(
node
,
output_index
);
TypeId
type_id
=
kNumberTypeFloat32
;
TypeId
type_id
=
kNumberTypeFloat32
;
...
@@ -93,7 +101,7 @@ BaseRef CreateOneTensor(const AnfNodePtr &node, size_t output_index, const Kerne
...
@@ -93,7 +101,7 @@ BaseRef CreateOneTensor(const AnfNodePtr &node, size_t output_index, const Kerne
auto
ms_context
=
MsContext
::
GetInstance
();
auto
ms_context
=
MsContext
::
GetInstance
();
MS_EXCEPTION_IF_NULL
(
ms_context
);
MS_EXCEPTION_IF_NULL
(
ms_context
);
if
(
ms_context
->
execution_mode
()
==
kPynativeMode
||
ms_context
->
device_target
()
==
kGPUDevice
)
{
if
(
ms_context
->
execution_mode
()
==
kPynativeMode
||
ms_context
->
device_target
()
==
kGPUDevice
)
{
tensor
->
set_device_address
(
AnfAlgo
::
GetMutableOutputAddr
(
node
,
output_index
)
);
tensor
->
set_device_address
(
address
);
tensor
->
set_dirty
(
false
);
tensor
->
set_dirty
(
false
);
}
else
if
(
!
address
->
SyncDeviceToHost
(
trans
::
GetRuntimePaddingShape
(
node
,
output_index
),
}
else
if
(
!
address
->
SyncDeviceToHost
(
trans
::
GetRuntimePaddingShape
(
node
,
output_index
),
LongToSize
(
tensor
->
data
().
nbytes
()),
tensor
->
data_type
(),
tensor
->
data_c
()))
{
LongToSize
(
tensor
->
data
().
nbytes
()),
tensor
->
data_type
(),
tensor
->
data_c
()))
{
...
...
mindspore/ccsrc/transform/convert.cc
浏览文件 @
5a886794
...
@@ -1646,7 +1646,7 @@ bool DfGraphConvertor::GetControlDependList(const CNodePtr &node,
...
@@ -1646,7 +1646,7 @@ bool DfGraphConvertor::GetControlDependList(const CNodePtr &node,
dst_ops_list
->
insert
(
dst_ops_list
->
end
(),
converted_list
.
begin
(),
converted_list
.
end
());
dst_ops_list
->
insert
(
dst_ops_list
->
end
(),
converted_list
.
begin
(),
converted_list
.
end
());
}
}
if
(
src_ops_list
->
empty
()
||
dst_ops_list
->
empty
())
{
if
(
src_ops_list
->
empty
()
||
dst_ops_list
->
empty
())
{
MS_LOG
(
WARNING
)
<<
"Control depend node's src or dest node is not a apply n
ode, ignore it"
;
MS_LOG
(
DEBUG
)
<<
"Control depend node's src or dest node is not a CN
ode, ignore it"
;
error_
=
SUCCESS
;
error_
=
SUCCESS
;
}
}
return
true
;
return
true
;
...
@@ -1690,6 +1690,8 @@ void DfGraphConvertor::ConvertControlDependNode(const CNodePtr node) {
...
@@ -1690,6 +1690,8 @@ void DfGraphConvertor::ConvertControlDependNode(const CNodePtr node) {
});
});
}
else
if
(
src_ops_list
->
size
()
==
1
&&
dst_ops_list
->
size
()
==
1
)
{
}
else
if
(
src_ops_list
->
size
()
==
1
&&
dst_ops_list
->
size
()
==
1
)
{
control_edges
.
push_back
({(
*
src_ops_list
)[
0
],
(
*
dst_ops_list
)[
0
]});
control_edges
.
push_back
({(
*
src_ops_list
)[
0
],
(
*
dst_ops_list
)[
0
]});
}
else
if
(
src_ops_list
->
empty
()
||
dst_ops_list
->
empty
())
{
MS_LOG
(
DEBUG
)
<<
"Depend list of src or dst is empty, ignore it"
;
}
else
{
}
else
{
MS_LOG
(
ERROR
)
<<
"Convert control depend node to operator failed, depend src:"
<<
src_ops_list
->
size
()
MS_LOG
(
ERROR
)
<<
"Convert control depend node to operator failed, depend src:"
<<
src_ops_list
->
size
()
<<
" -> dst:"
<<
dst_ops_list
->
size
();
<<
" -> dst:"
<<
dst_ops_list
->
size
();
...
...
mindspore/ccsrc/utils/log_adapter.cc
浏览文件 @
5a886794
...
@@ -463,7 +463,7 @@ void InitSubModulesLogLevel() {
...
@@ -463,7 +463,7 @@ void InitSubModulesLogLevel() {
// set submodule's log level
// set submodule's log level
auto
submodule
=
GetEnv
(
"MS_SUBMODULE_LOG_v"
);
auto
submodule
=
GetEnv
(
"MS_SUBMODULE_LOG_v"
);
MS_LOG
(
INFO
)
<<
"MS_SUBMODULE_LOG_v=`"
<<
submodule
<<
"`"
;
MS_LOG
(
DEBUG
)
<<
"MS_SUBMODULE_LOG_v=`"
<<
submodule
<<
"`"
;
LogConfigParser
parser
(
submodule
);
LogConfigParser
parser
(
submodule
);
auto
configs
=
parser
.
Parse
();
auto
configs
=
parser
.
Parse
();
for
(
const
auto
&
cfg
:
configs
)
{
for
(
const
auto
&
cfg
:
configs
)
{
...
@@ -489,22 +489,14 @@ void InitSubModulesLogLevel() {
...
@@ -489,22 +489,14 @@ void InitSubModulesLogLevel() {
}
// namespace mindspore
}
// namespace mindspore
extern
"C"
{
extern
"C"
{
// shared lib init hook
#if defined(_WIN32) || defined(_WIN64)
#if defined(_WIN32) || defined(_WIN64)
__attribute__
((
constructor
))
void
mindspore
_log_init
(
void
)
{
__attribute__
((
constructor
))
void
common
_log_init
(
void
)
{
#else
#else
void
mindspore
_log_init
(
void
)
{
void
common
_log_init
(
void
)
{
#endif
#endif
#ifdef USE_GLOG
#ifdef USE_GLOG
// do not use glog predefined log prefix
// do not use glog predefined log prefix
FLAGS_log_prefix
=
false
;
FLAGS_log_prefix
=
false
;
static
bool
is_glog_initialzed
=
false
;
if
(
!
is_glog_initialzed
)
{
#if !defined(_WIN32) && !defined(_WIN64)
google
::
InitGoogleLogging
(
"mindspore"
);
#endif
is_glog_initialzed
=
true
;
}
// set default log level to WARNING
// set default log level to WARNING
if
(
mindspore
::
GetEnv
(
"GLOG_v"
).
empty
())
{
if
(
mindspore
::
GetEnv
(
"GLOG_v"
).
empty
())
{
FLAGS_v
=
mindspore
::
WARNING
;
FLAGS_v
=
mindspore
::
WARNING
;
...
@@ -525,4 +517,22 @@ void mindspore_log_init(void) {
...
@@ -525,4 +517,22 @@ void mindspore_log_init(void) {
#endif
#endif
mindspore
::
InitSubModulesLogLevel
();
mindspore
::
InitSubModulesLogLevel
();
}
}
// shared lib init hook
#if defined(_WIN32) || defined(_WIN64)
__attribute__
((
constructor
))
void
mindspore_log_init
(
void
)
{
#else
void
mindspore_log_init
(
void
)
{
#endif
#ifdef USE_GLOG
static
bool
is_glog_initialzed
=
false
;
if
(
!
is_glog_initialzed
)
{
#if !defined(_WIN32) && !defined(_WIN64)
google
::
InitGoogleLogging
(
"mindspore"
);
#endif
is_glog_initialzed
=
true
;
}
#endif
common_log_init
();
}
}
}
mindspore/ccsrc/utils/utils.h
浏览文件 @
5a886794
...
@@ -252,6 +252,7 @@ constexpr auto kControlDependMode = "depend_mode";
...
@@ -252,6 +252,7 @@ constexpr auto kControlDependMode = "depend_mode";
// index define of depend
// index define of depend
constexpr
auto
kRealInputIndexInDepend
=
1
;
constexpr
auto
kRealInputIndexInDepend
=
1
;
constexpr
auto
kDependAttachNodeIndex
=
2
;
constexpr
auto
kDependAttachNodeIndex
=
2
;
constexpr
auto
kDependInputSize
=
3
;
// format
// format
constexpr
auto
kOpFormat_DEFAULT
=
"DefaultFormat"
;
constexpr
auto
kOpFormat_DEFAULT
=
"DefaultFormat"
;
constexpr
auto
kOpFormat_NC1KHKWHWC0
=
"NC1KHKWHWC0"
;
constexpr
auto
kOpFormat_NC1KHKWHWC0
=
"NC1KHKWHWC0"
;
...
...
mindspore/common/tensor.py
浏览文件 @
5a886794
...
@@ -22,6 +22,10 @@ from . import dtype as mstype
...
@@ -22,6 +22,10 @@ from . import dtype as mstype
from
._register_for_tensor
import
tensor_operator_registry
from
._register_for_tensor
import
tensor_operator_registry
__all__
=
[
'Tensor'
,
'MetaTensor'
]
__all__
=
[
'Tensor'
,
'MetaTensor'
]
np_types
=
(
np
.
int8
,
np
.
int16
,
np
.
int32
,
np
.
int64
,
np
.
uint8
,
np
.
uint16
,
np
.
uint32
,
np
.
uint64
,
np
.
float16
,
np
.
float32
,
np
.
float64
,
np
.
bool_
)
class
Tensor
(
Tensor_
):
class
Tensor
(
Tensor_
):
...
@@ -54,6 +58,10 @@ class Tensor(Tensor_):
...
@@ -54,6 +58,10 @@ class Tensor(Tensor_):
"""
"""
def
__init__
(
self
,
input_data
,
dtype
=
None
):
def
__init__
(
self
,
input_data
,
dtype
=
None
):
# If input data is numpy number, convert it to np array
if
isinstance
(
input_data
,
np_types
):
input_data
=
np
.
array
(
input_data
)
# If input_data is tuple/list/numpy.ndarray, it's support in check_type method.
# If input_data is tuple/list/numpy.ndarray, it's support in check_type method.
check_type
(
'tensor input_data'
,
input_data
,
(
Tensor_
,
float
,
int
))
check_type
(
'tensor input_data'
,
input_data
,
(
Tensor_
,
float
,
int
))
if
dtype
is
not
None
:
if
dtype
is
not
None
:
...
...
mindspore/dataset/engine/datasets.py
浏览文件 @
5a886794
...
@@ -1040,7 +1040,7 @@ class Dataset:
...
@@ -1040,7 +1040,7 @@ class Dataset:
Args:
Args:
columns (list[str], optional): List of columns to be used to specify the order of columns
columns (list[str], optional): List of columns to be used to specify the order of columns
(default
s
=None, means all columns).
(default=None, means all columns).
Returns:
Returns:
Iterator, list of ndarray.
Iterator, list of ndarray.
...
@@ -3382,7 +3382,7 @@ class ManifestDataset(MappableDataset):
...
@@ -3382,7 +3382,7 @@ class ManifestDataset(MappableDataset):
class_indexing (dict, optional): A str-to-int mapping from label name to index
class_indexing (dict, optional): A str-to-int mapping from label name to index
(default=None, the folder names will be sorted alphabetically and each
(default=None, the folder names will be sorted alphabetically and each
class will be given a unique index starting from 0).
class will be given a unique index starting from 0).
decode (bool, optional): decode the images after reading (default
s
=False).
decode (bool, optional): decode the images after reading (default=False).
num_shards (int, optional): Number of shards that the dataset should be divided
num_shards (int, optional): Number of shards that the dataset should be divided
into (default=None).
into (default=None).
shard_id (int, optional): The shard ID within num_shards (default=None). This
shard_id (int, optional): The shard ID within num_shards (default=None). This
...
@@ -4760,7 +4760,7 @@ class _NumpySlicesDataset:
...
@@ -4760,7 +4760,7 @@ class _NumpySlicesDataset:
def
process_dict
(
self
,
input_data
):
def
process_dict
(
self
,
input_data
):
"""
"""
Convert the dict like data into tuple format, when input is a tuple of dict then compose it into a dict first.
Convert the dict like data into tuple format, when input is a tuple of dict
s
then compose it into a dict first.
"""
"""
# Convert pandas like dict(has "values" column) into General dict
# Convert pandas like dict(has "values" column) into General dict
data_keys
=
list
(
input_data
.
keys
())
data_keys
=
list
(
input_data
.
keys
())
...
...
mindspore/dataset/transforms/vision/c_transforms.py
浏览文件 @
5a886794
...
@@ -202,7 +202,7 @@ class RandomHorizontalFlip(cde.RandomHorizontalFlipOp):
...
@@ -202,7 +202,7 @@ class RandomHorizontalFlip(cde.RandomHorizontalFlipOp):
Flip the input image horizontally, randomly with a given probability.
Flip the input image horizontally, randomly with a given probability.
Args:
Args:
prob (float): Probability of the image being flipped (default=0.5).
prob (float
, optional
): Probability of the image being flipped (default=0.5).
"""
"""
@
check_prob
@
check_prob
...
@@ -217,7 +217,7 @@ class RandomHorizontalFlipWithBBox(cde.RandomHorizontalFlipWithBBoxOp):
...
@@ -217,7 +217,7 @@ class RandomHorizontalFlipWithBBox(cde.RandomHorizontalFlipWithBBoxOp):
Maintains data integrity by also flipping bounding boxes in an object detection pipeline.
Maintains data integrity by also flipping bounding boxes in an object detection pipeline.
Args:
Args:
prob (float): Probability of the image being flipped (default=0.5).
prob (float
, optional
): Probability of the image being flipped (default=0.5).
"""
"""
@
check_prob
@
check_prob
...
@@ -231,7 +231,7 @@ class RandomVerticalFlip(cde.RandomVerticalFlipOp):
...
@@ -231,7 +231,7 @@ class RandomVerticalFlip(cde.RandomVerticalFlipOp):
Flip the input image vertically, randomly with a given probability.
Flip the input image vertically, randomly with a given probability.
Args:
Args:
prob (float): Probability of the image being flipped (default=0.5).
prob (float
, optional
): Probability of the image being flipped (default=0.5).
"""
"""
@
check_prob
@
check_prob
...
...
mindspore/nn/optim/adam.py
浏览文件 @
5a886794
...
@@ -29,8 +29,9 @@ from .optimizer import Optimizer
...
@@ -29,8 +29,9 @@ from .optimizer import Optimizer
_adam_opt
=
C
.
MultitypeFuncGraph
(
"adam_opt"
)
_adam_opt
=
C
.
MultitypeFuncGraph
(
"adam_opt"
)
@
_adam_opt
.
register
(
"Tensor"
,
"Tensor"
,
"Tensor"
,
"Tensor"
,
"Tensor"
,
"Tensor"
,
"Tensor"
,
"Tensor"
,
"Tensor"
,
"Bool"
)
@
_adam_opt
.
register
(
"Tensor"
,
"Tensor"
,
"Tensor"
,
"Tensor"
,
"Tensor"
,
"Tensor"
,
"Tensor"
,
"Tensor"
,
def
_update_run_op
(
beta1
,
beta2
,
eps
,
lr
,
weight_decay_tensor
,
param
,
m
,
v
,
gradient
,
decay_flag
):
"Tensor"
,
"Bool"
,
"Bool"
)
def
_update_run_op
(
beta1
,
beta2
,
eps
,
lr
,
weight_decay_tensor
,
param
,
m
,
v
,
gradient
,
decay_flag
,
optim_filter
):
"""
"""
Update parameters.
Update parameters.
...
@@ -44,38 +45,44 @@ def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, param, m, v, grad
...
@@ -44,38 +45,44 @@ def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, param, m, v, grad
m (Tensor): m value of parameters.
m (Tensor): m value of parameters.
v (Tensor): v value of parameters.
v (Tensor): v value of parameters.
gradient (Tensor): Gradient of parameters.
gradient (Tensor): Gradient of parameters.
decay_flag (bool): Applies weight decay or not.
optim_filter (bool): Applies parameter update or not.
Returns:
Returns:
Tensor, the new value of v after updating.
Tensor, the new value of v after updating.
"""
"""
op_mul
=
P
.
Mul
()
if
optim_filter
:
op_square
=
P
.
Square
()
op_mul
=
P
.
Mul
()
op_sqrt
=
P
.
Sqrt
()
op_square
=
P
.
Square
()
op_cast
=
P
.
Cast
()
op_sqrt
=
P
.
Sqrt
()
op_reshape
=
P
.
Reshape
()
op_cast
=
P
.
Cast
()
op_shape
=
P
.
Shape
()
op_reshape
=
P
.
Reshape
()
op_shape
=
P
.
Shape
()
param_fp32
=
op_cast
(
param
,
mstype
.
float32
)
param_fp32
=
op_cast
(
param
,
mstype
.
float32
)
m_fp32
=
op_cast
(
m
,
mstype
.
float32
)
m_fp32
=
op_cast
(
m
,
mstype
.
float32
)
v_fp32
=
op_cast
(
v
,
mstype
.
float32
)
v_fp32
=
op_cast
(
v
,
mstype
.
float32
)
gradient_fp32
=
op_cast
(
gradient
,
mstype
.
float32
)
gradient_fp32
=
op_cast
(
gradient
,
mstype
.
float32
)
next_m
=
op_mul
(
beta1
,
m_fp32
)
+
op_mul
(
op_cast
(
F
.
tuple_to_array
((
1.0
,)),
mstype
.
float32
)
-
beta1
,
gradient_fp32
)
next_m
=
op_mul
(
beta1
,
m_fp32
)
+
op_mul
(
op_cast
(
F
.
tuple_to_array
((
1.0
,)),
mstype
.
float32
)
-
beta1
,
gradient_fp32
)
next_v
=
op_mul
(
beta2
,
v_fp32
)
+
op_mul
(
op_cast
(
F
.
tuple_to_array
((
1.0
,)),
mstype
.
float32
)
next_v
=
op_mul
(
beta2
,
v_fp32
)
+
op_mul
(
op_cast
(
F
.
tuple_to_array
((
1.0
,)),
mstype
.
float32
)
-
beta2
,
op_square
(
gradient_fp32
))
-
beta2
,
op_square
(
gradient_fp32
))
update
=
next_m
/
(
eps
+
op_sqrt
(
next_v
))
if
decay_flag
:
update
=
op_mul
(
weight_decay_tensor
,
param_fp32
)
+
update
update_with_lr
=
op_mul
(
lr
,
update
)
update
=
next_m
/
(
eps
+
op_sqrt
(
next_v
))
next_param
=
param_fp32
-
op_reshape
(
update_with_lr
,
op_shape
(
param_fp32
))
if
decay_flag
:
update
=
op_mul
(
weight_decay_tensor
,
param_fp32
)
+
update
next_v
=
F
.
depend
(
next_v
,
F
.
assign
(
param
,
op_cast
(
next_param
,
F
.
dtype
(
param
))))
update_with_lr
=
op_mul
(
lr
,
update
)
next_v
=
F
.
depend
(
next_v
,
F
.
assign
(
m
,
op_cast
(
next_m
,
F
.
dtype
(
m
))))
next_param
=
param_fp32
-
op_reshape
(
update_with_lr
,
op_shape
(
param_fp32
))
next_v
=
F
.
depend
(
next_v
,
F
.
assign
(
v
,
op_cast
(
next_v
,
F
.
dtype
(
v
))))
return
next_v
next_param
=
F
.
depend
(
next_param
,
F
.
assign
(
param
,
op_cast
(
next_param
,
F
.
dtype
(
param
))))
next_param
=
F
.
depend
(
next_param
,
F
.
assign
(
m
,
op_cast
(
next_m
,
F
.
dtype
(
m
))))
next_param
=
F
.
depend
(
next_param
,
F
.
assign
(
v
,
op_cast
(
next_v
,
F
.
dtype
(
v
))))
return
next_param
return
gradient
def
_check_param_value
(
beta1
,
beta2
,
eps
,
weight_decay
,
prim_name
):
def
_check_param_value
(
beta1
,
beta2
,
eps
,
weight_decay
,
prim_name
):
...
@@ -300,7 +307,7 @@ class AdamWeightDecay(Optimizer):
...
@@ -300,7 +307,7 @@ class AdamWeightDecay(Optimizer):
- **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.
- **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.
Outputs:
Outputs:
tuple[
Parameter], the updated velocity value, the shape is the same as `params`
.
tuple[
bool], all elements are True
.
Examples:
Examples:
>>> net = Net()
>>> net = Net()
...
@@ -328,11 +335,13 @@ class AdamWeightDecay(Optimizer):
...
@@ -328,11 +335,13 @@ class AdamWeightDecay(Optimizer):
def
construct
(
self
,
gradients
):
def
construct
(
self
,
gradients
):
lr
=
self
.
get_lr
()
lr
=
self
.
get_lr
()
updated_velocity
=
self
.
hyper_map
(
F
.
partial
(
_adam_opt
,
self
.
beta1
,
self
.
beta2
,
self
.
eps
,
lr
,
optim_result
=
self
.
hyper_map
(
F
.
partial
(
_adam_opt
,
self
.
beta1
,
self
.
beta2
,
self
.
eps
,
lr
,
self
.
weight_decay_tensor
),
self
.
weight_decay_tensor
),
self
.
params
,
self
.
moments1
,
self
.
moments2
,
gradients
,
self
.
decay_flag
)
self
.
params
,
self
.
moments1
,
self
.
moments2
,
gradients
,
self
.
decay_flag
,
self
.
optim_filter
)
return
updated_velocity
if
self
.
use_parallel
:
optim_result
=
self
.
broadcast_params
(
optim_result
)
return
optim_result
class
AdamWeightDecayDynamicLR
(
Optimizer
):
class
AdamWeightDecayDynamicLR
(
Optimizer
):
...
@@ -363,7 +372,7 @@ class AdamWeightDecayDynamicLR(Optimizer):
...
@@ -363,7 +372,7 @@ class AdamWeightDecayDynamicLR(Optimizer):
- **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.
- **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.
Outputs:
Outputs:
tuple[
Parameter], the updated velocity value, the shape is the same as `params`
.
tuple[
bool], all elements are True
.
Examples:
Examples:
>>> net = Net()
>>> net = Net()
...
@@ -424,12 +433,14 @@ class AdamWeightDecayDynamicLR(Optimizer):
...
@@ -424,12 +433,14 @@ class AdamWeightDecayDynamicLR(Optimizer):
warmup_lr
=
self
.
start_learning_rate
*
warmup_percent
warmup_lr
=
self
.
start_learning_rate
*
warmup_percent
is_warmup
=
self
.
cast
(
self
.
greater
(
self
.
warmup_steps
,
self
.
global_step
),
mstype
.
float32
)
is_warmup
=
self
.
cast
(
self
.
greater
(
self
.
warmup_steps
,
self
.
global_step
),
mstype
.
float32
)
lr
=
(
self
.
one
-
is_warmup
)
*
lr
+
is_warmup
*
warmup_lr
lr
=
(
self
.
one
-
is_warmup
)
*
lr
+
is_warmup
*
warmup_lr
updated_velocity
=
self
.
hyper_map
(
F
.
partial
(
_adam_opt
,
self
.
beta1
,
self
.
beta2
,
self
.
eps
,
lr
,
optim_result
=
self
.
hyper_map
(
F
.
partial
(
_adam_opt
,
self
.
beta1
,
self
.
beta2
,
self
.
eps
,
lr
,
self
.
weight_decay_tensor
),
self
.
weight_decay_tensor
),
self
.
params
,
self
.
moments1
,
self
.
moments2
,
gradients
,
self
.
decay_flag
)
self
.
params
,
self
.
moments1
,
self
.
moments2
,
gradients
,
self
.
decay_flag
,
self
.
optim_filter
)
if
self
.
use_parallel
:
optim_result
=
self
.
broadcast_params
(
optim_result
)
added_global_step
=
self
.
global_step
+
self
.
one
added_global_step
=
self
.
global_step
+
self
.
one
F
.
control_depend
(
lr
,
added_global_step
)
F
.
control_depend
(
lr
,
added_global_step
)
self
.
global_step
=
added_global_step
self
.
global_step
=
added_global_step
return
updated_velocity
return
optim_result
mindspore/nn/optim/lamb.py
浏览文件 @
5a886794
...
@@ -32,11 +32,10 @@ num_one = Tensor(np.ones([1]), mstype.float32)
...
@@ -32,11 +32,10 @@ num_one = Tensor(np.ones([1]), mstype.float32)
_lamb_opt
=
C
.
MultitypeFuncGraph
(
"lamb_opt"
)
_lamb_opt
=
C
.
MultitypeFuncGraph
(
"lamb_opt"
)
@
_lamb_opt
.
register
(
"Tensor"
,
"Tensor"
,
"Tensor"
,
"Tensor"
,
"Tensor"
,
"Tensor"
,
"Tensor"
,
"Tensor"
,
"Tensor"
,
@
_lamb_opt
.
register
(
"Tensor"
,
"Tensor"
,
"Tensor"
,
"Tensor"
,
"Tensor"
,
"Tensor"
,
"Tensor"
,
"Bool"
,
"Bool"
)
"Tensor"
,
"Tensor"
,
"Tensor"
,
"Tensor"
,
"Bool"
)
def
_update_run_op
(
beta1
,
beta2
,
eps
,
lr
,
weight_decay_tensor
,
global_step
,
param
,
m
,
v
,
def
_update_run_op
(
beta1
,
beta2
,
eps
,
lr
,
weight_decay_tensor
,
global_step
,
param
,
m
,
v
,
gradient
,
decay_flag
):
gradient
,
decay_flag
,
optim_filter
):
"""
"""
Update parameters.
Update parameters.
...
@@ -52,66 +51,66 @@ def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, global_step, para
...
@@ -52,66 +51,66 @@ def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, global_step, para
v (Tensor): v value of parameters.
v (Tensor): v value of parameters.
gradient (Tensor): Gradient of parameters.
gradient (Tensor): Gradient of parameters.
decay_flag (bool): Specifies whether param update with weight decay.
decay_flag (bool): Specifies whether param update with weight decay.
optim_filter(bool): Applies parameter update or not.
Returns:
Returns:
Tensor, the new value of v after updating.
Tensor, the new value of v after updating.
"""
"""
op_mul
=
P
.
Mul
()
if
optim_filter
:
op_sqrt
=
P
.
Sqrt
()
op_mul
=
P
.
Mul
()
op_rsqrt
=
P
.
Rsqrt
()
op_sqrt
=
P
.
Sqrt
()
op_square
=
P
.
Square
()
op_rsqrt
=
P
.
Rsqrt
()
op_cast
=
P
.
Cast
()
op_square
=
P
.
Square
()
op_reshape
=
P
.
Reshape
()
op_cast
=
P
.
Cast
()
op_shape
=
P
.
Shape
()
op_reshape
=
P
.
Reshape
()
op_pow
=
P
.
Pow
()
op_shape
=
P
.
Shape
()
op_norm
=
layer
.
Norm
()
op_pow
=
P
.
Pow
()
op_select
=
P
.
Select
()
op_norm
=
layer
.
Norm
()
op_greater
=
P
.
Greater
()
op_select
=
P
.
Select
()
op_fill
=
P
.
Fill
()
op_greater
=
P
.
Greater
()
op_dtype
=
P
.
DType
()
op_fill
=
P
.
Fill
()
op_dtype
=
P
.
DType
()
param_fp32
=
op_cast
(
param
,
mstype
.
float32
)
m_fp32
=
op_cast
(
m
,
mstype
.
float32
)
param_fp32
=
op_cast
(
param
,
mstype
.
float32
)
v_fp32
=
op_cast
(
v
,
mstype
.
float32
)
m_fp32
=
op_cast
(
m
,
mstype
.
float32
)
gradient_fp32
=
op_cast
(
gradient
,
mstype
.
float32
)
v_fp32
=
op_cast
(
v
,
mstype
.
float32
)
gradient_fp32
=
op_cast
(
gradient
,
mstype
.
float32
)
next_m
=
op_mul
(
beta1
,
m_fp32
)
+
op_mul
(
op_cast
(
num_one
,
mstype
.
float32
)
-
beta1
,
gradient_fp32
)
next_m
=
op_mul
(
beta1
,
m_fp32
)
+
op_mul
(
op_cast
(
num_one
,
mstype
.
float32
)
-
beta1
,
gradient_fp32
)
next_v
=
op_mul
(
beta2
,
v_fp32
)
+
op_mul
(
op_cast
(
num_one
,
next_v
=
op_mul
(
beta2
,
v_fp32
)
+
op_mul
(
op_cast
(
num_one
,
mstype
.
float32
)
-
beta2
,
op_square
(
gradient_fp32
))
mstype
.
float32
)
-
beta2
,
op_square
(
gradient_fp32
))
next_mm
=
next_m
/
(
op_cast
(
num_one
,
mstype
.
float32
)
next_mm
=
next_m
/
(
op_cast
(
num_one
,
mstype
.
float32
)
-
op_pow
(
beta1
,
op_cast
(
global_step
+
num_one
,
mstype
.
float32
)))
-
op_pow
(
beta1
,
op_cast
(
global_step
+
num_one
,
mstype
.
float32
)))
next_vv
=
next_v
/
(
op_cast
(
num_one
,
mstype
.
float32
)
-
next_vv
=
next_v
/
(
op_cast
(
num_one
,
mstype
.
float32
)
-
op_pow
(
beta2
,
op_cast
(
global_step
+
num_one
,
mstype
.
float32
)))
op_pow
(
beta2
,
op_cast
(
global_step
+
num_one
,
mstype
.
float32
)))
w_norm
=
op_norm
(
param_fp32
)
w_norm
=
op_norm
(
param_fp32
)
g_norm
=
op_norm
(
gradient_fp32
)
g_norm
=
op_norm
(
gradient_fp32
)
g_norm_hat
=
op_norm
(
op_mul
(
next_mm
,
op_rsqrt
(
next_vv
+
eps
))
+
weight_decay_tensor
*
param_fp32
)
g_norm_hat
=
op_norm
(
op_mul
(
next_mm
,
op_rsqrt
(
zeros
=
F
.
zeros_like
(
w_norm
)
next_vv
+
eps
))
+
weight_decay_tensor
*
param_fp32
)
ones
=
op_fill
(
op_dtype
(
w_norm
),
op_shape
(
w_norm
),
1.0
)
zeros
=
F
.
zeros_like
(
w_norm
)
trust_ratio
=
op_select
(
ones
=
op_fill
(
op_dtype
(
w_norm
),
op_shape
(
w_norm
),
1.0
)
op_greater
(
w_norm
,
zeros
),
trust_ratio
=
op_select
(
op_select
(
op_greater
(
g_norm
,
zeros
),
w_norm
/
g_norm_hat
,
ones
),
op_greater
(
w_norm
,
zeros
),
ones
)
op_select
(
op_greater
(
g_norm
,
zeros
),
w_norm
/
g_norm_hat
,
ones
),
tens
=
op_fill
(
op_dtype
(
trust_ratio
),
op_shape
(
trust_ratio
),
10.0
)
ones
)
trust_ratio
=
C
.
clip_by_value
(
trust_ratio
,
zeros
,
tens
)
tens
=
op_fill
(
op_dtype
(
trust_ratio
),
op_shape
(
trust_ratio
),
10.0
)
update
=
next_mm
/
(
op_sqrt
(
next_vv
)
+
eps
)
trust_ratio
=
C
.
clip_by_value
(
trust_ratio
,
zeros
,
tens
)
update
=
next_mm
/
(
op_sqrt
(
next_vv
)
+
eps
)
if
decay_flag
:
update
=
update
+
op_mul
(
weight_decay_tensor
,
param_fp32
)
if
decay_flag
:
update
=
update
+
op_mul
(
weight_decay_tensor
,
param_fp32
)
update_with_lr
=
op_mul
(
op_mul
(
trust_ratio
,
lr
),
update
)
update_with_lr
=
op_mul
(
op_mul
(
trust_ratio
,
lr
),
update
)
next_param
=
param_fp32
-
op_reshape
(
update_with_lr
,
op_shape
(
param_fp32
))
next_param
=
param_fp32
-
op_reshape
(
update_with_lr
,
op_shape
(
param_fp32
))
next_param
=
F
.
depend
(
next_param
,
F
.
assign
(
param
,
next_param
))
next_param
=
F
.
depend
(
next_param
,
F
.
assign
(
m
,
next_m
))
next_v
=
F
.
depend
(
next_v
,
F
.
assign
(
param
,
next_param
))
next_param
=
F
.
depend
(
next_param
,
F
.
assign
(
v
,
next_v
))
next_v
=
F
.
depend
(
next_v
,
F
.
assign
(
m
,
next_m
))
next_v
=
F
.
depend
(
next_v
,
F
.
assign
(
v
,
next_v
))
return
next_param
return
gradient
return
next_v
lamb_opt_graph_kernel
=
C
.
MultitypeFuncGraph
(
"lamb_opt_graph_kernel"
)
lamb_opt_graph_kernel
=
C
.
MultitypeFuncGraph
(
"lamb_opt_graph_kernel"
)
...
@@ -238,7 +237,7 @@ class Lamb(Optimizer):
...
@@ -238,7 +237,7 @@ class Lamb(Optimizer):
- **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.
- **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.
Outputs:
Outputs:
tuple[
Parameter], the updated velocity value, the shape is the same as `params`
.
tuple[
bool], all elements are True
.
Examples:
Examples:
>>> net = Net()
>>> net = Net()
...
@@ -311,18 +310,21 @@ class Lamb(Optimizer):
...
@@ -311,18 +310,21 @@ class Lamb(Optimizer):
self
.
warmup_steps
,
self
.
global_step
),
mstype
.
float32
)
self
.
warmup_steps
,
self
.
global_step
),
mstype
.
float32
)
lr
=
(
self
.
one
-
is_warmup
)
*
lr
+
is_warmup
*
warmup_lr
lr
=
(
self
.
one
-
is_warmup
)
*
lr
+
is_warmup
*
warmup_lr
if
self
.
enable_graph_kernel
:
if
self
.
enable_graph_kernel
:
updated_velocity
=
self
.
hyper_map
(
F
.
partial
(
lamb_opt_graph_kernel
,
optim_result
=
self
.
hyper_map
(
F
.
partial
(
lamb_opt_graph_kernel
,
self
.
beta1
,
self
.
beta2
,
self
.
eps
,
lr
,
self
.
beta1
,
self
.
beta2
,
self
.
eps
,
lr
,
self
.
weight_decay_tensor
,
self
.
global_step
),
self
.
weight_decay_tensor
,
self
.
global_step
),
self
.
params
,
self
.
moments1
,
self
.
moments2
,
gradients
,
self
.
decay_flag
)
self
.
params
,
self
.
moments1
,
self
.
moments2
,
gradients
,
self
.
decay_flag
)
else
:
else
:
updated_velocity
=
self
.
hyper_map
(
F
.
partial
(
_lamb_opt
,
optim_result
=
self
.
hyper_map
(
F
.
partial
(
_lamb_opt
,
self
.
beta1
,
self
.
beta2
,
self
.
eps
,
lr
,
self
.
beta1
,
self
.
beta2
,
self
.
eps
,
lr
,
self
.
weight_decay_tensor
,
self
.
global_step
),
self
.
weight_decay_tensor
,
self
.
global_step
),
self
.
params
,
self
.
moments1
,
self
.
moments2
,
gradients
,
self
.
decay_flag
)
self
.
params
,
self
.
moments1
,
self
.
moments2
,
gradients
,
self
.
decay_flag
,
self
.
optim_filter
)
if
self
.
use_parallel
:
optim_result
=
self
.
broadcast_params
(
optim_result
)
added_global_step
=
self
.
global_step
+
self
.
one
added_global_step
=
self
.
global_step
+
self
.
one
F
.
control_depend
(
lr
,
added_global_step
)
F
.
control_depend
(
lr
,
added_global_step
)
self
.
global_step
=
added_global_step
self
.
global_step
=
added_global_step
return
updated_velocity
return
optim_result
mindspore/nn/optim/optimizer.py
浏览文件 @
5a886794
...
@@ -22,11 +22,14 @@ from mindspore.ops import functional as F, composite as C, operations as P
...
@@ -22,11 +22,14 @@ from mindspore.ops import functional as F, composite as C, operations as P
from
mindspore.nn.cell
import
Cell
from
mindspore.nn.cell
import
Cell
from
mindspore.common.parameter
import
Parameter
,
ParameterTuple
from
mindspore.common.parameter
import
Parameter
,
ParameterTuple
from
mindspore.common.initializer
import
initializer
from
mindspore.common.initializer
import
initializer
from
mindspore.common.tensor
import
Tensor
import
mindspore.common.dtype
as
mstype
import
mindspore.common.dtype
as
mstype
from
mindspore._checkparam
import
Validator
as
validator
from
mindspore._checkparam
import
Validator
as
validator
from
mindspore._checkparam
import
Rel
from
mindspore._checkparam
import
Rel
from
mindspore.common.tensor
import
Tensor
from
mindspore
import
log
as
logger
from
mindspore
import
log
as
logger
from
mindspore.parallel._utils
import
_get_global_rank
,
_get_device_num
,
_get_parallel_mode
from
mindspore.parallel._auto_parallel_context
import
auto_parallel_context
from
mindspore.train.parallel_utils
import
ParallelMode
__all__
=
[
'Optimizer'
]
__all__
=
[
'Optimizer'
]
...
@@ -155,6 +158,27 @@ class Optimizer(Cell):
...
@@ -155,6 +158,27 @@ class Optimizer(Cell):
self
.
param_length
=
len
(
self
.
parameters
)
self
.
param_length
=
len
(
self
.
parameters
)
self
.
map_
=
C
.
Map
()
self
.
map_
=
C
.
Map
()
use_parallel
=
auto_parallel_context
().
get_enable_parallel_optimizer
()
self
.
use_parallel
=
use_parallel
if
use_parallel
:
if
self
.
cls_name
not
in
[
"Lamb"
,
"AdamWeightDecayDynamicLR"
,
"AdamWeightDecay"
]:
raise
RuntimeError
(
"Optimizer segmentation does not support optimizer {}"
.
format
(
self
.
cls_name
))
if
_get_parallel_mode
()
not
in
[
ParallelMode
.
HYBRID_PARALLEL
,
ParallelMode
.
DATA_PARALLEL
,
ParallelMode
.
AUTO_PARALLEL
]:
raise
RuntimeError
(
"Optimizer segmentation does not support parallel mode {}"
.
format
(
_get_parallel_mode
()))
self
.
dev_num
=
_get_device_num
()
if
self
.
dev_num
>
self
.
param_length
:
raise
RuntimeError
(
"Optimizer segmentation can not be applied when the number of parameters {} is"
" less than the number of devices {}"
.
format
(
self
.
param_length
,
self
.
dev_num
))
self
.
param_rank
=
self
.
_get_parameter_group_id
()
self
.
optim_filter
=
tuple
(
map
(
lambda
x
:
x
==
_get_global_rank
(),
self
.
param_rank
))
self
.
param_names
=
[]
for
param
in
self
.
parameters
:
self
.
param_names
.
append
(
param
.
name
)
else
:
self
.
optim_filter
=
(
True
,)
*
self
.
param_length
def
decay_weight
(
self
,
gradients
):
def
decay_weight
(
self
,
gradients
):
"""
"""
Weight decay.
Weight decay.
...
@@ -219,8 +243,32 @@ class Optimizer(Cell):
...
@@ -219,8 +243,32 @@ class Optimizer(Cell):
raise
TypeError
(
"Learning rate should be float, Tensor or Iterable."
)
raise
TypeError
(
"Learning rate should be float, Tensor or Iterable."
)
return
lr
return
lr
def
_check_group_params
(
self
,
parameters
):
"""Check group params."""
parse_keys
=
[
'params'
,
'lr'
,
'weight_decay'
,
'order_params'
]
for
group_param
in
parameters
:
invalid_key
=
list
(
filter
(
lambda
x
:
x
not
in
parse_keys
,
group_param
.
keys
()))
if
invalid_key
:
raise
KeyError
(
f
'The key "
{
invalid_key
}
" cannot be recognized in group params.'
)
if
'order_params'
in
group_param
.
keys
():
if
len
(
group_param
.
keys
())
>
1
:
raise
ValueError
(
"The order params dict in group parameters should "
"only include the 'order_params' key."
)
if
not
isinstance
(
group_param
[
'order_params'
],
Iterable
):
raise
TypeError
(
"The value of 'order_params' should be an Iterable type."
)
continue
if
not
group_param
[
'params'
]:
raise
ValueError
(
"Optimizer got an empty group parameter list."
)
for
param
in
group_param
[
'params'
]:
if
not
isinstance
(
param
,
Parameter
):
raise
TypeError
(
"The group param should be an iterator of Parameter type."
)
def
_parse_group_params
(
self
,
parameters
,
learning_rate
):
def
_parse_group_params
(
self
,
parameters
,
learning_rate
):
"""Parse group params."""
"""Parse group params."""
self
.
_check_group_params
(
parameters
)
if
self
.
dynamic_lr
:
if
self
.
dynamic_lr
:
dynamic_lr_length
=
learning_rate
.
size
()
dynamic_lr_length
=
learning_rate
.
size
()
else
:
else
:
...
@@ -250,9 +298,6 @@ class Optimizer(Cell):
...
@@ -250,9 +298,6 @@ class Optimizer(Cell):
if
dynamic_lr_length
not
in
(
lr_length
,
0
):
if
dynamic_lr_length
not
in
(
lr_length
,
0
):
raise
ValueError
(
"The dynamic learning rate in group should be the same size."
)
raise
ValueError
(
"The dynamic learning rate in group should be the same size."
)
if
not
group_param
[
'params'
]:
raise
ValueError
(
"Optimizer got an empty group parameter list."
)
dynamic_lr_length
=
lr_length
dynamic_lr_length
=
lr_length
self
.
dynamic_lr_length
=
dynamic_lr_length
self
.
dynamic_lr_length
=
dynamic_lr_length
...
@@ -384,6 +429,51 @@ class Optimizer(Cell):
...
@@ -384,6 +429,51 @@ class Optimizer(Cell):
lr
=
self
.
learning_rate
lr
=
self
.
learning_rate
return
lr
return
lr
def
_get_parameter_group_id
(
self
):
"""
Get the parameter partition group id, which is less than the number of devices.
Returns:
tuple, the group id tuple of parameters.
"""
rank_list
=
()
count
=
0
for
_
in
range
(
self
.
param_length
):
rank_list
=
rank_list
+
(
count
,)
count
=
count
+
1
if
count
==
self
.
dev_num
:
count
=
0
return
rank_list
def
broadcast_params
(
self
,
optim_result
):
"""
Apply Broadcast operations in the sequential order of parameter groups.
Returns:
bool, the status flag.
"""
param_group
=
[]
key_group
=
[]
for
_
in
range
(
self
.
dev_num
):
param_group
.
append
(
F
.
make_tuple
())
key_group
.
append
(
F
.
make_tuple
())
for
i
in
range
(
self
.
param_length
):
param_group
[
self
.
param_rank
[
i
]]
=
param_group
[
self
.
param_rank
[
i
]]
+
(
optim_result
[
i
],)
key
=
P
.
MakeRefKey
(
self
.
param_names
[
i
])()
key_group
[
self
.
param_rank
[
i
]]
=
key_group
[
self
.
param_rank
[
i
]]
+
(
key
,)
new_param_group
=
[]
for
root
in
range
(
self
.
dev_num
):
ops
=
P
.
Broadcast
(
root
)
next_params
=
ops
(
param_group
[
root
])
new_param_group
.
append
(
next_params
)
for
i
in
range
(
F
.
tuple_len
(
next_params
)):
F
.
assign
(
key_group
[
root
][
i
],
next_params
[
i
])
status
=
True
for
i
in
range
(
self
.
dev_num
-
1
):
status
=
F
.
control_depend
(
new_param_group
[
i
][
0
],
new_param_group
[
i
+
1
])
return
status
def
construct
(
self
,
*
hyper_params
):
def
construct
(
self
,
*
hyper_params
):
raise
NotImplementedError
raise
NotImplementedError
...
...
mindspore/nn/wrap/cell_wrapper.py
浏览文件 @
5a886794
...
@@ -220,7 +220,9 @@ class DataWrapper(Cell):
...
@@ -220,7 +220,9 @@ class DataWrapper(Cell):
def
__init__
(
self
,
network
,
dataset_types
,
dataset_shapes
,
queue_name
):
def
__init__
(
self
,
network
,
dataset_types
,
dataset_shapes
,
queue_name
):
super
(
DataWrapper
,
self
).
__init__
(
auto_prefix
=
False
,
flags
=
network
.
get_flags
())
super
(
DataWrapper
,
self
).
__init__
(
auto_prefix
=
False
,
flags
=
network
.
get_flags
())
# Also copy the flag in `network` construct
flags
=
getattr
(
network
.
__class__
.
construct
,
"_mindspore_flags"
,
{})
self
.
add_flags
(
**
flags
)
self
.
get_next
=
P
.
GetNext
(
dataset_types
,
dataset_shapes
,
len
(
dataset_types
),
queue_name
)
self
.
get_next
=
P
.
GetNext
(
dataset_types
,
dataset_shapes
,
len
(
dataset_types
),
queue_name
)
self
.
network
=
network
self
.
network
=
network
...
...
mindspore/ops/_op_impl/akg/__init__.py
浏览文件 @
5a886794
...
@@ -47,6 +47,7 @@ from .gather_v2 import _gather_v2_akg
...
@@ -47,6 +47,7 @@ from .gather_v2 import _gather_v2_akg
from
.less
import
_less_akg
from
.less
import
_less_akg
from
.log
import
_log_akg
from
.log
import
_log_akg
from
.matmul
import
_matmul_akg
from
.matmul
import
_matmul_akg
from
.batchmatmul
import
_batchmatmul_akg
from
.max_pool_grad_with_argmax
import
_max_pool_grad_with_argmax_akg
from
.max_pool_grad_with_argmax
import
_max_pool_grad_with_argmax_akg
from
.max_pool_with_argmax
import
_max_pool_with_argmax_akg
from
.max_pool_with_argmax
import
_max_pool_with_argmax_akg
from
.max
import
_max_akg
from
.max
import
_max_akg
...
...
mindspore/ops/_op_impl/akg/batchmatmul.py
0 → 100644
浏览文件 @
5a886794
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""BatchMatMul op"""
from
mindspore.ops.op_info_register
import
op_info_register
@
op_info_register
(
"""{
"op_name": "BatchMatMul",
"imply_type": "AutoDiff",
"fusion_type": "OPAQUE",
"attr": [
{
"name": "transpose_a",
"param_type": "optional",
"type": "bool"
},
{
"name": "transpose_b",
"param_type": "optional",
"type": "bool"
}
],
"inputs": [
{
"index": 0,
"dtype": [
"float16"
],
"format": [
"FRACTAL_NZ"
],
"name": "x1"
},
{
"index": 1,
"dtype": [
"float16"
],
"format": [
"FRACTAL_NZ"
],
"name": "x2"
}
],
"outputs": [
{
"index": 0,
"dtype": [
"float16"
],
"format": [
"FRACTAL_NZ"
],
"name": "output"
}
]
}"""
)
def
_batchmatmul_akg
():
"""BatchMatMul AKG register"""
return
mindspore/ops/_op_impl/tbe/confusion_transpose_d.py
浏览文件 @
5a886794
...
@@ -28,26 +28,8 @@ confusion_transpose_d_op_info = TBERegOp("ConfusionTransposeD") \
...
@@ -28,26 +28,8 @@ confusion_transpose_d_op_info = TBERegOp("ConfusionTransposeD") \
.
attr
(
"transpose_first"
,
"required"
,
"bool"
,
"all"
)
\
.
attr
(
"transpose_first"
,
"required"
,
"bool"
,
"all"
)
\
.
input
(
0
,
"x"
,
False
,
"required"
,
"all"
)
\
.
input
(
0
,
"x"
,
False
,
"required"
,
"all"
)
\
.
output
(
0
,
"y"
,
False
,
"required"
,
"all"
)
\
.
output
(
0
,
"y"
,
False
,
"required"
,
"all"
)
\
.
dtype_format
(
DataType
.
I8_FracNZ
,
DataType
.
I8_FracNZ
)
\
.
op_pattern
(
"dynamicFormat"
)
\
.
dtype_format
(
DataType
.
I8_Default
,
DataType
.
I8_Default
)
\
.
dtype_format
(
DataType
.
None_None
,
DataType
.
None_None
)
\
.
dtype_format
(
DataType
.
U8_FracNZ
,
DataType
.
U8_FracNZ
)
\
.
dtype_format
(
DataType
.
U8_Default
,
DataType
.
U8_Default
)
\
.
dtype_format
(
DataType
.
I16_FracNZ
,
DataType
.
I16_FracNZ
)
\
.
dtype_format
(
DataType
.
I16_Default
,
DataType
.
I16_Default
)
\
.
dtype_format
(
DataType
.
U16_FracNZ
,
DataType
.
U16_FracNZ
)
\
.
dtype_format
(
DataType
.
U16_Default
,
DataType
.
U16_Default
)
\
.
dtype_format
(
DataType
.
I32_FracNZ
,
DataType
.
I32_FracNZ
)
\
.
dtype_format
(
DataType
.
I32_Default
,
DataType
.
I32_Default
)
\
.
dtype_format
(
DataType
.
U32_FracNZ
,
DataType
.
U32_FracNZ
)
\
.
dtype_format
(
DataType
.
U32_Default
,
DataType
.
U32_Default
)
\
.
dtype_format
(
DataType
.
I64_FracNZ
,
DataType
.
I64_FracNZ
)
\
.
dtype_format
(
DataType
.
I64_Default
,
DataType
.
I64_Default
)
\
.
dtype_format
(
DataType
.
U64_FracNZ
,
DataType
.
U64_FracNZ
)
\
.
dtype_format
(
DataType
.
U64_Default
,
DataType
.
U64_Default
)
\
.
dtype_format
(
DataType
.
F16_FracNZ
,
DataType
.
F16_FracNZ
)
\
.
dtype_format
(
DataType
.
F16_Default
,
DataType
.
F16_Default
)
\
.
dtype_format
(
DataType
.
F32_FracNZ
,
DataType
.
F32_FracNZ
)
\
.
dtype_format
(
DataType
.
F32_Default
,
DataType
.
F32_Default
)
\
.
get_op_info
()
.
get_op_info
()
...
...
mindspore/ops/composite/multitype_ops/setitem_impl.py
浏览文件 @
5a886794
...
@@ -85,6 +85,22 @@ def _list_setitem_with_List(data, number_index, value):
...
@@ -85,6 +85,22 @@ def _list_setitem_with_List(data, number_index, value):
return
F
.
list_setitem
(
data
,
number_index
,
value
)
return
F
.
list_setitem
(
data
,
number_index
,
value
)
@
setitem
.
register
(
"List"
,
"Number"
,
"Tuple"
)
def
_list_setitem_with_Tuple
(
data
,
number_index
,
value
):
"""
Assigns value to list.
Inputs:
data (list): Data of type lis.
number_index (Number): Index of data.
value (list): Value given.
Outputs:
list, type is same as the element type of data.
"""
return
F
.
list_setitem
(
data
,
number_index
,
value
)
@
setitem
.
register
(
"Dictionary"
,
"String"
,
"Tensor"
)
@
setitem
.
register
(
"Dictionary"
,
"String"
,
"Tensor"
)
def
_dict_setitem_with_tensor
(
data
,
key
,
value
):
def
_dict_setitem_with_tensor
(
data
,
key
,
value
):
"""
"""
...
...
mindspore/ops/operations/comm_ops.py
浏览文件 @
5a886794
...
@@ -98,6 +98,7 @@ class AllReduce(PrimitiveWithInfer):
...
@@ -98,6 +98,7 @@ class AllReduce(PrimitiveWithInfer):
self
.
op
=
op
self
.
op
=
op
self
.
add_prim_attr
(
'group'
,
_get_group
(
group
))
self
.
add_prim_attr
(
'group'
,
_get_group
(
group
))
self
.
add_prim_attr
(
'fusion'
,
0
)
self
.
add_prim_attr
(
'fusion'
,
0
)
self
.
add_prim_attr
(
'index'
,
0
)
def
vm_impl
(
self
,
x
):
def
vm_impl
(
self
,
x
):
"""Implement by vm mode."""
"""Implement by vm mode."""
...
...
mindspore/ops/operations/debug_ops.py
浏览文件 @
5a886794
...
@@ -309,12 +309,6 @@ class Print(PrimitiveWithInfer):
...
@@ -309,12 +309,6 @@ class Print(PrimitiveWithInfer):
Output tensor or string to stdout.
Output tensor or string to stdout.
Note:
Note:
The print operation cannot support the following cases currently.
1. The type of tensor is float64 or bool.
2. The data of tensor is a scalar type.
In pynative mode, please use python print function.
In pynative mode, please use python print function.
Inputs:
Inputs:
...
@@ -334,7 +328,7 @@ class Print(PrimitiveWithInfer):
...
@@ -334,7 +328,7 @@ class Print(PrimitiveWithInfer):
@
prim_attr_register
@
prim_attr_register
def
__init__
(
self
):
def
__init__
(
self
):
pass
self
.
add_prim_attr
(
"_side_effect"
,
True
)
def
__call__
(
self
,
*
args
):
def
__call__
(
self
,
*
args
):
for
arg
in
args
:
for
arg
in
args
:
...
...
mindspore/ops/operations/math_ops.py
浏览文件 @
5a886794
...
@@ -888,7 +888,8 @@ class Neg(PrimitiveWithInfer):
...
@@ -888,7 +888,8 @@ class Neg(PrimitiveWithInfer):
def
infer_value
(
self
,
input_x
):
def
infer_value
(
self
,
input_x
):
if
input_x
is
not
None
:
if
input_x
is
not
None
:
input_x
=
input_x
.
asnumpy
()
input_x
=
input_x
.
asnumpy
()
return
Tensor
(
-
input_x
)
out
=
np
.
array
(
-
input_x
,
input_x
.
dtype
)
return
Tensor
(
out
)
return
None
return
None
...
@@ -1667,7 +1668,8 @@ class Div(_MathBinaryOp):
...
@@ -1667,7 +1668,8 @@ class Div(_MathBinaryOp):
if
x
is
not
None
and
y
is
not
None
:
if
x
is
not
None
and
y
is
not
None
:
x
=
x
.
asnumpy
()
x
=
x
.
asnumpy
()
y
=
y
.
asnumpy
()
y
=
y
.
asnumpy
()
return
Tensor
(
x
/
y
)
out
=
np
.
array
(
x
/
y
,
x
.
dtype
)
return
Tensor
(
out
)
return
None
return
None
...
...
mindspore/ops/operations/other_ops.py
浏览文件 @
5a886794
...
@@ -59,8 +59,7 @@ class Assign(PrimitiveWithInfer):
...
@@ -59,8 +59,7 @@ class Assign(PrimitiveWithInfer):
return
variable
return
variable
def
infer_dtype
(
self
,
variable
,
value
):
def
infer_dtype
(
self
,
variable
,
value
):
args
=
{
"variable"
:
variable
,
"value"
:
value
}
# Add a type validation later when we don't have to assign a value to RefKey.
validator
.
check_tensor_type_same
(
args
,
(
mstype
.
bool_
,)
+
mstype
.
number_type
,
self
.
name
)
return
variable
return
variable
...
...
mindspore/parallel/_auto_parallel_context.py
浏览文件 @
5a886794
...
@@ -400,6 +400,23 @@ class _AutoParallelContext:
...
@@ -400,6 +400,23 @@ class _AutoParallelContext:
self
.
check_context_handle
()
self
.
check_context_handle
()
return
self
.
_context_handle
.
get_global_rank_is_set
()
return
self
.
_context_handle
.
get_global_rank_is_set
()
def
set_enable_parallel_optimizer
(
self
,
enable_parallel_optimizer
):
"""
Set enable/disable parallel optimizer.
Args:
set_enable_parallel_optimizer (bool): Enable/disable parallel optimizer.
"""
self
.
check_context_handle
()
if
not
isinstance
(
enable_parallel_optimizer
,
bool
):
raise
TypeError
(
'enable_parallel_optimizer is invalid type'
)
self
.
_context_handle
.
set_enable_parallel_optimizer
(
enable_parallel_optimizer
)
def
get_enable_parallel_optimizer
(
self
):
"""Get parallel optimizer flag."""
self
.
check_context_handle
()
return
self
.
_context_handle
.
get_enable_parallel_optimizer
()
def
reset
(
self
):
def
reset
(
self
):
"""Reset all settings."""
"""Reset all settings."""
self
.
check_context_handle
()
self
.
check_context_handle
()
...
@@ -433,7 +450,8 @@ _set_auto_parallel_context_func_map = {
...
@@ -433,7 +450,8 @@ _set_auto_parallel_context_func_map = {
"parameter_broadcast"
:
auto_parallel_context
().
set_parameter_broadcast
,
"parameter_broadcast"
:
auto_parallel_context
().
set_parameter_broadcast
,
"strategy_ckpt_load_file"
:
auto_parallel_context
().
set_strategy_ckpt_load_file
,
"strategy_ckpt_load_file"
:
auto_parallel_context
().
set_strategy_ckpt_load_file
,
"strategy_ckpt_save_file"
:
auto_parallel_context
().
set_strategy_ckpt_save_file
,
"strategy_ckpt_save_file"
:
auto_parallel_context
().
set_strategy_ckpt_save_file
,
"full_batch"
:
auto_parallel_context
().
set_full_batch
}
"full_batch"
:
auto_parallel_context
().
set_full_batch
,
"enable_parallel_optimizer"
:
auto_parallel_context
().
set_enable_parallel_optimizer
}
_get_auto_parallel_context_func_map
=
{
_get_auto_parallel_context_func_map
=
{
...
@@ -447,13 +465,15 @@ _get_auto_parallel_context_func_map = {
...
@@ -447,13 +465,15 @@ _get_auto_parallel_context_func_map = {
"parameter_broadcast"
:
auto_parallel_context
().
get_parameter_broadcast
,
"parameter_broadcast"
:
auto_parallel_context
().
get_parameter_broadcast
,
"strategy_ckpt_load_file"
:
auto_parallel_context
().
get_strategy_ckpt_load_file
,
"strategy_ckpt_load_file"
:
auto_parallel_context
().
get_strategy_ckpt_load_file
,
"strategy_ckpt_save_file"
:
auto_parallel_context
().
get_strategy_ckpt_save_file
,
"strategy_ckpt_save_file"
:
auto_parallel_context
().
get_strategy_ckpt_save_file
,
"full_batch"
:
auto_parallel_context
().
get_full_batch
}
"full_batch"
:
auto_parallel_context
().
get_full_batch
,
"enable_parallel_optimizer"
:
auto_parallel_context
().
get_enable_parallel_optimizer
}
@
args_type_check
(
device_num
=
int
,
global_rank
=
int
,
mirror_mean
=
bool
,
cast_before_mirror
=
bool
,
@
args_type_check
(
device_num
=
int
,
global_rank
=
int
,
mirror_mean
=
bool
,
cast_before_mirror
=
bool
,
loss_repeated_mean
=
bool
,
parallel_mode
=
str
,
auto_parallel_search_mode
=
str
,
loss_repeated_mean
=
bool
,
parallel_mode
=
str
,
auto_parallel_search_mode
=
str
,
parameter_broadcast
=
bool
,
strategy_ckpt_load_file
=
str
,
parameter_broadcast
=
bool
,
strategy_ckpt_load_file
=
str
,
strategy_ckpt_save_file
=
str
,
full_batch
=
bool
)
strategy_ckpt_save_file
=
str
,
full_batch
=
bool
,
enable_parallel_optimizer
=
bool
)
def
_set_auto_parallel_context
(
**
kwargs
):
def
_set_auto_parallel_context
(
**
kwargs
):
"""
"""
Set auto parallel context.
Set auto parallel context.
...
@@ -493,6 +513,7 @@ def _set_auto_parallel_context(**kwargs):
...
@@ -493,6 +513,7 @@ def _set_auto_parallel_context(**kwargs):
strategy_ckpt_load_file (str): The path to load parallel strategy checkpoint. Default: ''
strategy_ckpt_load_file (str): The path to load parallel strategy checkpoint. Default: ''
strategy_ckpt_save_file (str): The path to save parallel strategy checkpoint. Default: ''
strategy_ckpt_save_file (str): The path to save parallel strategy checkpoint. Default: ''
full_batch (bool): Whether to load the whole batch on each device. Default: False.
full_batch (bool): Whether to load the whole batch on each device. Default: False.
enable_parallel_optimizer (bool): Enable using optimizer segmentation or noe. Default: False.
Raises:
Raises:
ValueError: If input key is not attribute in auto parallel context.
ValueError: If input key is not attribute in auto parallel context.
...
@@ -535,5 +556,6 @@ def _reset_auto_parallel_context():
...
@@ -535,5 +556,6 @@ def _reset_auto_parallel_context():
- parameter_broadcast: False.
- parameter_broadcast: False.
- strategy_ckpt_load_file: ""
- strategy_ckpt_load_file: ""
- strategy_ckpt_save_file: ""
- strategy_ckpt_save_file: ""
- enable_parallel_optimizer: False
"""
"""
auto_parallel_context
().
reset
()
auto_parallel_context
().
reset
()
mindspore/train/callback/_summary_collector.py
浏览文件 @
5a886794
...
@@ -166,8 +166,11 @@ class SummaryCollector(Callback):
...
@@ -166,8 +166,11 @@ class SummaryCollector(Callback):
self
.
_has_saved_custom_data
=
False
self
.
_has_saved_custom_data
=
False
self
.
_is_parse_loss_success
=
True
self
.
_is_parse_loss_success
=
True
self
.
_first_step
=
True
self
.
_first_step
=
True
self
.
_dataset_sink_mode
=
True
def
__enter__
(
self
):
def
__enter__
(
self
):
self
.
_first_step
=
True
self
.
_dataset_sink_mode
=
True
self
.
_record
=
SummaryRecord
(
log_dir
=
self
.
_summary_dir
)
self
.
_record
=
SummaryRecord
(
log_dir
=
self
.
_summary_dir
)
return
self
return
self
...
@@ -279,15 +282,15 @@ class SummaryCollector(Callback):
...
@@ -279,15 +282,15 @@ class SummaryCollector(Callback):
def
step_end
(
self
,
run_context
):
def
step_end
(
self
,
run_context
):
cb_params
=
run_context
.
original_args
()
cb_params
=
run_context
.
original_args
()
if
self
.
_first_step
:
# Notice: This way of determining whether dataset sink mode is True does not work in the eval scenario
self
.
_dataset_sink_mode
=
bool
(
cb_params
.
cur_step_num
==
cb_params
.
batch_num
)
if
cb_params
.
mode
==
ModeEnum
.
TRAIN
.
value
:
if
cb_params
.
mode
==
ModeEnum
.
TRAIN
.
value
:
# Make sure the first step data is recorded
if
not
self
.
_is_collect_this_step
(
cb_params
):
if
not
self
.
_first_step
and
cb_params
.
cur_step_num
%
self
.
_collect_freq
:
return
return
self
.
_first_step
=
False
if
not
self
.
_has_saved_train_network
:
if
not
self
.
_has_saved_train_network
:
self
.
_collect_graphs
(
cb_params
)
self
.
_collect_graphs
(
cb_params
)
...
@@ -295,6 +298,7 @@ class SummaryCollector(Callback):
...
@@ -295,6 +298,7 @@ class SummaryCollector(Callback):
self
.
_collect_metric
(
cb_params
)
self
.
_collect_metric
(
cb_params
)
self
.
_collect_histogram
(
cb_params
)
self
.
_collect_histogram
(
cb_params
)
self
.
_first_step
=
False
self
.
_record
.
record
(
cb_params
.
cur_step_num
)
self
.
_record
.
record
(
cb_params
.
cur_step_num
)
def
end
(
self
,
run_context
):
def
end
(
self
,
run_context
):
...
@@ -320,6 +324,18 @@ class SummaryCollector(Callback):
...
@@ -320,6 +324,18 @@ class SummaryCollector(Callback):
raise
ValueError
(
f
"There are more than one
{
self
.
__class__
.
__name__
}
instance in callback list,"
raise
ValueError
(
f
"There are more than one
{
self
.
__class__
.
__name__
}
instance in callback list,"
f
"but expected only one
{
self
.
__class__
.
__name__
}
instance."
)
f
"but expected only one
{
self
.
__class__
.
__name__
}
instance."
)
def
_is_collect_this_step
(
self
,
cb_params
):
"""Decide whether to collect data for the current step."""
# Make sure the first step data is recorded
if
not
self
.
_first_step
:
if
self
.
_dataset_sink_mode
:
if
cb_params
.
cur_epoch_num
%
self
.
_collect_freq
:
return
False
else
:
if
cb_params
.
cur_step_num
%
self
.
_collect_freq
:
return
False
return
True
@
staticmethod
@
staticmethod
def
_package_custom_lineage_data
(
custom_lineage_data
):
def
_package_custom_lineage_data
(
custom_lineage_data
):
"""
"""
...
...
model_zoo/faster_rcnn/src/dataset.py
浏览文件 @
5a886794
...
@@ -318,10 +318,6 @@ def preprocess_fn(image, box, is_training):
...
@@ -318,10 +318,6 @@ def preprocess_fn(image, box, is_training):
else
:
else
:
input_data
=
resize_column
(
*
input_data
)
input_data
=
resize_column
(
*
input_data
)
photo
=
(
np
.
random
.
rand
()
<
config
.
photo_ratio
)
if
photo
:
input_data
=
photo_crop_column
(
*
input_data
)
input_data
=
image_bgr_rgb
(
*
input_data
)
input_data
=
image_bgr_rgb
(
*
input_data
)
output_data
=
input_data
output_data
=
input_data
...
@@ -432,19 +428,19 @@ def data_to_mindrecord_byte_image(dataset="coco", is_training=True, prefix="fast
...
@@ -432,19 +428,19 @@ def data_to_mindrecord_byte_image(dataset="coco", is_training=True, prefix="fast
writer
.
write_raw_data
([
row
])
writer
.
write_raw_data
([
row
])
writer
.
commit
()
writer
.
commit
()
def
create_fasterrcnn_dataset
(
mindrecord_file
,
batch_size
=
2
,
repeat_num
=
12
,
device_num
=
1
,
rank_id
=
0
,
def
create_fasterrcnn_dataset
(
mindrecord_file
,
batch_size
=
2
,
repeat_num
=
12
,
device_num
=
1
,
rank_id
=
0
,
is_training
=
True
,
num_parallel_workers
=
8
):
is_training
=
True
,
num_parallel_workers
=
4
):
"""Creatr FasterRcnn dataset with MindDataset."""
"""Creatr FasterRcnn dataset with MindDataset."""
ds
=
de
.
MindDataset
(
mindrecord_file
,
columns_list
=
[
"image"
,
"annotation"
],
num_shards
=
device_num
,
shard_id
=
rank_id
,
ds
=
de
.
MindDataset
(
mindrecord_file
,
columns_list
=
[
"image"
,
"annotation"
],
num_shards
=
device_num
,
shard_id
=
rank_id
,
num_parallel_workers
=
num_parallel_workers
,
shuffle
=
is_training
)
num_parallel_workers
=
1
,
shuffle
=
is_training
)
decode
=
C
.
Decode
()
decode
=
C
.
Decode
()
ds
=
ds
.
map
(
input_columns
=
[
"image"
],
operations
=
decode
)
ds
=
ds
.
map
(
input_columns
=
[
"image"
],
operations
=
decode
,
num_parallel_workers
=
1
)
compose_map_func
=
(
lambda
image
,
annotation
:
preprocess_fn
(
image
,
annotation
,
is_training
))
compose_map_func
=
(
lambda
image
,
annotation
:
preprocess_fn
(
image
,
annotation
,
is_training
))
hwc_to_chw
=
C
.
HWC2CHW
()
hwc_to_chw
=
C
.
HWC2CHW
()
normalize_op
=
C
.
Normalize
((
123.675
,
116.28
,
103.53
),
(
58.395
,
57.12
,
57.375
))
normalize_op
=
C
.
Normalize
((
123.675
,
116.28
,
103.53
),
(
58.395
,
57.12
,
57.375
))
horizontally_op
=
C
.
RandomHorizontalFlip
(
1
)
horizontally_op
=
C
.
RandomHorizontalFlip
(
1
)
type_cast0
=
CC
.
TypeCast
(
mstype
.
float32
)
type_cast1
=
CC
.
TypeCast
(
mstype
.
float16
)
type_cast1
=
CC
.
TypeCast
(
mstype
.
float16
)
type_cast2
=
CC
.
TypeCast
(
mstype
.
int32
)
type_cast2
=
CC
.
TypeCast
(
mstype
.
int32
)
type_cast3
=
CC
.
TypeCast
(
mstype
.
bool_
)
type_cast3
=
CC
.
TypeCast
(
mstype
.
bool_
)
...
@@ -453,17 +449,18 @@ def create_fasterrcnn_dataset(mindrecord_file, batch_size=2, repeat_num=12, devi
...
@@ -453,17 +449,18 @@ def create_fasterrcnn_dataset(mindrecord_file, batch_size=2, repeat_num=12, devi
ds
=
ds
.
map
(
input_columns
=
[
"image"
,
"annotation"
],
ds
=
ds
.
map
(
input_columns
=
[
"image"
,
"annotation"
],
output_columns
=
[
"image"
,
"image_shape"
,
"box"
,
"label"
,
"valid_num"
],
output_columns
=
[
"image"
,
"image_shape"
,
"box"
,
"label"
,
"valid_num"
],
columns_order
=
[
"image"
,
"image_shape"
,
"box"
,
"label"
,
"valid_num"
],
columns_order
=
[
"image"
,
"image_shape"
,
"box"
,
"label"
,
"valid_num"
],
operations
=
compose_map_func
,
num_parallel_workers
=
4
)
operations
=
compose_map_func
,
num_parallel_workers
=
num_parallel_workers
)
ds
=
ds
.
map
(
input_columns
=
[
"image"
],
operations
=
[
normalize_op
,
type_cast0
],
num_parallel_workers
=
num_parallel_workers
)
flip
=
(
np
.
random
.
rand
()
<
config
.
flip_ratio
)
flip
=
(
np
.
random
.
rand
()
<
config
.
flip_ratio
)
if
flip
:
if
flip
:
ds
=
ds
.
map
(
input_columns
=
[
"image"
],
operations
=
[
horizontally_op
],
ds
=
ds
.
map
(
input_columns
=
[
"image"
],
operations
=
[
normalize_op
,
horizontally_op
,
hwc_to_chw
,
type_cast1
],
num_parallel_workers
=
num_parallel_workers
)
num_parallel_workers
=
24
)
ds
=
ds
.
map
(
input_columns
=
[
"image"
,
"image_shape"
,
"box"
,
"label"
,
"valid_num"
],
ds
=
ds
.
map
(
input_columns
=
[
"image"
,
"image_shape"
,
"box"
,
"label"
,
"valid_num"
],
operations
=
flipped_generation
,
num_parallel_workers
=
4
)
operations
=
flipped_generation
,
num_parallel_workers
=
num_parallel_workers
)
else
:
ds
=
ds
.
map
(
input_columns
=
[
"image"
],
operations
=
[
normalize_op
,
hwc_to_chw
,
type_cast1
],
num_parallel_workers
=
24
)
else
:
else
:
ds
=
ds
.
map
(
input_columns
=
[
"image"
,
"annotation"
],
ds
=
ds
.
map
(
input_columns
=
[
"image"
,
"annotation"
],
output_columns
=
[
"image"
,
"image_shape"
,
"box"
,
"label"
,
"valid_num"
],
output_columns
=
[
"image"
,
"image_shape"
,
"box"
,
"label"
,
"valid_num"
],
...
@@ -471,11 +468,10 @@ def create_fasterrcnn_dataset(mindrecord_file, batch_size=2, repeat_num=12, devi
...
@@ -471,11 +468,10 @@ def create_fasterrcnn_dataset(mindrecord_file, batch_size=2, repeat_num=12, devi
operations
=
compose_map_func
,
operations
=
compose_map_func
,
num_parallel_workers
=
num_parallel_workers
)
num_parallel_workers
=
num_parallel_workers
)
ds
=
ds
.
map
(
input_columns
=
[
"image"
],
operations
=
[
normalize_op
,
type_cast0
],
ds
=
ds
.
map
(
input_columns
=
[
"image"
],
operations
=
[
normalize_op
,
hwc_to_chw
,
type_cast1
],
num_parallel_workers
=
num_parallel_workers
)
num_parallel_workers
=
24
)
# transpose_column from python to c
# transpose_column from python to c
ds
=
ds
.
map
(
input_columns
=
[
"image"
],
operations
=
[
hwc_to_chw
,
type_cast1
])
ds
=
ds
.
map
(
input_columns
=
[
"image_shape"
],
operations
=
[
type_cast1
])
ds
=
ds
.
map
(
input_columns
=
[
"image_shape"
],
operations
=
[
type_cast1
])
ds
=
ds
.
map
(
input_columns
=
[
"box"
],
operations
=
[
type_cast1
])
ds
=
ds
.
map
(
input_columns
=
[
"box"
],
operations
=
[
type_cast1
])
ds
=
ds
.
map
(
input_columns
=
[
"label"
],
operations
=
[
type_cast2
])
ds
=
ds
.
map
(
input_columns
=
[
"label"
],
operations
=
[
type_cast2
])
...
...
model_zoo/vgg16/src/config.py
浏览文件 @
5a886794
...
@@ -19,7 +19,9 @@ from easydict import EasyDict as edict
...
@@ -19,7 +19,9 @@ from easydict import EasyDict as edict
cifar_cfg
=
edict
({
cifar_cfg
=
edict
({
'num_classes'
:
10
,
'num_classes'
:
10
,
'lr_init'
:
0.05
,
'lr_init'
:
0.01
,
'lr_max'
:
0.1
,
'warmup_epochs'
:
5
,
'batch_size'
:
64
,
'batch_size'
:
64
,
'epoch_size'
:
70
,
'epoch_size'
:
70
,
'momentum'
:
0.9
,
'momentum'
:
0.9
,
...
...
model_zoo/vgg16/train.py
浏览文件 @
5a886794
...
@@ -38,20 +38,25 @@ random.seed(1)
...
@@ -38,20 +38,25 @@ random.seed(1)
np
.
random
.
seed
(
1
)
np
.
random
.
seed
(
1
)
def
lr_steps
(
global_step
,
lr_
max
=
None
,
total_epochs
=
None
,
steps_per_epoch
=
None
):
def
lr_steps
(
global_step
,
lr_
init
,
lr_max
,
warmup_epochs
,
total_epochs
,
steps_per_epoch
):
"""Set learning rate."""
"""Set learning rate."""
lr_each_step
=
[]
lr_each_step
=
[]
total_steps
=
steps_per_epoch
*
total_epochs
total_steps
=
steps_per_epoch
*
total_epochs
decay_epoch_index
=
[
0.3
*
total_steps
,
0.6
*
total_steps
,
0.8
*
total_steps
]
warmup_steps
=
steps_per_epoch
*
warmup_epochs
if
warmup_steps
!=
0
:
inc_each_step
=
(
float
(
lr_max
)
-
float
(
lr_init
))
/
float
(
warmup_steps
)
else
:
inc_each_step
=
0
for
i
in
range
(
total_steps
):
for
i
in
range
(
total_steps
):
if
i
<
decay_epoch_index
[
0
]:
if
i
<
warmup_steps
:
lr_each_step
.
append
(
lr_max
)
lr_value
=
float
(
lr_init
)
+
inc_each_step
*
float
(
i
)
elif
i
<
decay_epoch_index
[
1
]:
lr_each_step
.
append
(
lr_max
*
0.1
)
elif
i
<
decay_epoch_index
[
2
]:
lr_each_step
.
append
(
lr_max
*
0.01
)
else
:
else
:
lr_each_step
.
append
(
lr_max
*
0.001
)
base
=
(
1.0
-
(
float
(
i
)
-
float
(
warmup_steps
))
/
(
float
(
total_steps
)
-
float
(
warmup_steps
)))
lr_value
=
float
(
lr_max
)
*
base
*
base
if
lr_value
<
0.0
:
lr_value
=
0.0
lr_each_step
.
append
(
lr_value
)
current_step
=
global_step
current_step
=
global_step
lr_each_step
=
np
.
array
(
lr_each_step
).
astype
(
np
.
float32
)
lr_each_step
=
np
.
array
(
lr_each_step
).
astype
(
np
.
float32
)
learning_rate
=
lr_each_step
[
current_step
:]
learning_rate
=
lr_each_step
[
current_step
:]
...
@@ -86,7 +91,8 @@ if __name__ == '__main__':
...
@@ -86,7 +91,8 @@ if __name__ == '__main__':
if
args_opt
.
pre_trained
:
if
args_opt
.
pre_trained
:
load_param_into_net
(
net
,
load_checkpoint
(
args_opt
.
pre_trained
))
load_param_into_net
(
net
,
load_checkpoint
(
args_opt
.
pre_trained
))
lr
=
lr_steps
(
0
,
lr_max
=
cfg
.
lr_init
,
total_epochs
=
cfg
.
epoch_size
,
steps_per_epoch
=
batch_num
)
lr
=
lr_steps
(
0
,
lr_init
=
cfg
.
lr_init
,
lr_max
=
cfg
.
lr_max
,
warmup_epochs
=
cfg
.
warmup_epochs
,
total_epochs
=
cfg
.
epoch_size
,
steps_per_epoch
=
batch_num
)
opt
=
Momentum
(
filter
(
lambda
x
:
x
.
requires_grad
,
net
.
get_parameters
()),
Tensor
(
lr
),
cfg
.
momentum
,
opt
=
Momentum
(
filter
(
lambda
x
:
x
.
requires_grad
,
net
.
get_parameters
()),
Tensor
(
lr
),
cfg
.
momentum
,
weight_decay
=
cfg
.
weight_decay
)
weight_decay
=
cfg
.
weight_decay
)
loss
=
nn
.
SoftmaxCrossEntropyWithLogits
(
sparse
=
True
,
reduction
=
'mean'
,
is_grad
=
False
)
loss
=
nn
.
SoftmaxCrossEntropyWithLogits
(
sparse
=
True
,
reduction
=
'mean'
,
is_grad
=
False
)
...
...
serving/core/server.cc
浏览文件 @
5a886794
...
@@ -22,6 +22,7 @@
...
@@ -22,6 +22,7 @@
#include <vector>
#include <vector>
#include <utility>
#include <utility>
#include <memory>
#include <memory>
#include <future>
#include "mindspore/ccsrc/utils/log_adapter.h"
#include "mindspore/ccsrc/utils/log_adapter.h"
#include "serving/ms_service.grpc.pb.h"
#include "serving/ms_service.grpc.pb.h"
...
@@ -40,7 +41,7 @@ namespace serving {
...
@@ -40,7 +41,7 @@ namespace serving {
using
MSTensorPtr
=
std
::
shared_ptr
<
inference
::
MSTensor
>
;
using
MSTensorPtr
=
std
::
shared_ptr
<
inference
::
MSTensor
>
;
Status
Session
::
CreatDeviceSession
(
const
std
::
string
&
device
,
uint32_t
device_id
)
{
Status
Session
::
CreatDeviceSession
(
const
std
::
string
&
device
,
uint32_t
device_id
)
{
session_
=
inference
::
MSSession
::
CreateSession
(
device
+
"Inference"
,
device_id
);
session_
=
inference
::
MSSession
::
CreateSession
(
device
,
device_id
);
if
(
session_
==
nullptr
)
{
if
(
session_
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"Creat Session Failed"
;
MS_LOG
(
ERROR
)
<<
"Creat Session Failed"
;
return
FAILED
;
return
FAILED
;
...
@@ -67,6 +68,7 @@ Status Session::Predict(const std::vector<MSTensorPtr> &inputs, inference::Multi
...
@@ -67,6 +68,7 @@ Status Session::Predict(const std::vector<MSTensorPtr> &inputs, inference::Multi
MS_LOG
(
INFO
)
<<
"run Predict"
;
MS_LOG
(
INFO
)
<<
"run Predict"
;
*
outputs
=
session_
->
RunGraph
(
graph_id_
,
inputs
);
*
outputs
=
session_
->
RunGraph
(
graph_id_
,
inputs
);
MS_LOG
(
INFO
)
<<
"run Predict finished"
;
return
SUCCESS
;
return
SUCCESS
;
}
}
...
@@ -80,12 +82,16 @@ Status Session::Warmup(const MindSporeModelPtr model) {
...
@@ -80,12 +82,16 @@ Status Session::Warmup(const MindSporeModelPtr model) {
std
::
string
file_name
=
model
->
GetModelPath
()
+
'/'
+
model
->
GetModelName
();
std
::
string
file_name
=
model
->
GetModelPath
()
+
'/'
+
model
->
GetModelName
();
char
*
graphBuf
=
ReadFile
(
file_name
.
c_str
(),
&
size
);
char
*
graphBuf
=
ReadFile
(
file_name
.
c_str
(),
&
size
);
if
(
graphBuf
==
nullptr
)
{
if
(
graphBuf
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"
Load graph model
failed, file name is "
<<
file_name
.
c_str
();
MS_LOG
(
ERROR
)
<<
"
Read model file
failed, file name is "
<<
file_name
.
c_str
();
return
FAILED
;
return
FAILED
;
}
}
last_graph_
=
inference
::
LoadModel
(
graphBuf
,
size
,
device_type_
);
last_graph_
=
inference
::
LoadModel
(
graphBuf
,
size
,
device_type_
);
if
(
last_graph_
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"Load graph model failed, file name is "
<<
file_name
.
c_str
();
return
FAILED
;
}
graph_id_
=
session_
->
CompileGraph
(
last_graph_
);
graph_id_
=
session_
->
CompileGraph
(
last_graph_
);
MS_LOG
(
INFO
)
<<
"Session Warmup"
;
MS_LOG
(
INFO
)
<<
"Session Warmup
finished
"
;
return
SUCCESS
;
return
SUCCESS
;
}
}
...
@@ -95,6 +101,9 @@ Status Session::Clear() {
...
@@ -95,6 +101,9 @@ Status Session::Clear() {
}
}
namespace
{
namespace
{
static
const
uint32_t
uint32max
=
0x7FFFFFFF
;
std
::
promise
<
void
>
exit_requested
;
const
std
::
map
<
ms_serving
::
DataType
,
TypeId
>
type2id_map
{
const
std
::
map
<
ms_serving
::
DataType
,
TypeId
>
type2id_map
{
{
ms_serving
::
MS_UNKNOWN
,
TypeId
::
kNumberTypeBegin
},
{
ms_serving
::
MS_BOOL
,
TypeId
::
kNumberTypeBool
},
{
ms_serving
::
MS_UNKNOWN
,
TypeId
::
kNumberTypeBegin
},
{
ms_serving
::
MS_BOOL
,
TypeId
::
kNumberTypeBool
},
{
ms_serving
::
MS_INT8
,
TypeId
::
kNumberTypeInt8
},
{
ms_serving
::
MS_UINT8
,
TypeId
::
kNumberTypeUInt8
},
{
ms_serving
::
MS_INT8
,
TypeId
::
kNumberTypeInt8
},
{
ms_serving
::
MS_UINT8
,
TypeId
::
kNumberTypeUInt8
},
...
@@ -141,7 +150,7 @@ MSTensorPtr ServingTensor2MSTensor(const ms_serving::Tensor &tensor) {
...
@@ -141,7 +150,7 @@ MSTensorPtr ServingTensor2MSTensor(const ms_serving::Tensor &tensor) {
}
}
TypeId
type
=
iter
->
second
;
TypeId
type
=
iter
->
second
;
auto
ms_tensor
=
std
::
shared_ptr
<
inference
::
MSTensor
>
(
inference
::
MSTensor
::
CreateTensor
(
type
,
shape
));
auto
ms_tensor
=
std
::
shared_ptr
<
inference
::
MSTensor
>
(
inference
::
MSTensor
::
CreateTensor
(
type
,
shape
));
memcpy_s
(
ms_tensor
->
MutableData
(),
tensor
.
data
().
s
ize
(),
tensor
.
data
().
data
(),
tensor
.
data
().
size
());
memcpy_s
(
ms_tensor
->
MutableData
(),
ms_tensor
->
S
ize
(),
tensor
.
data
().
data
(),
tensor
.
data
().
size
());
return
ms_tensor
;
return
ms_tensor
;
}
}
...
@@ -166,10 +175,7 @@ void ClearEnv() {
...
@@ -166,10 +175,7 @@ void ClearEnv() {
Session
::
Instance
().
Clear
();
Session
::
Instance
().
Clear
();
inference
::
ExitInference
();
inference
::
ExitInference
();
}
}
void
HandleSignal
(
int
sig
)
{
void
HandleSignal
(
int
sig
)
{
exit_requested
.
set_value
();
}
ClearEnv
();
exit
(
0
);
}
#ifdef ENABLE_D
#ifdef ENABLE_D
static
rtContext_t
g_ctx
=
nullptr
;
static
rtContext_t
g_ctx
=
nullptr
;
...
@@ -247,6 +253,7 @@ Status Server::BuildAndStart() {
...
@@ -247,6 +253,7 @@ Status Server::BuildAndStart() {
rtError_t
rt_ret
=
rtCtxGetCurrent
(
&
ctx
);
rtError_t
rt_ret
=
rtCtxGetCurrent
(
&
ctx
);
if
(
rt_ret
!=
RT_ERROR_NONE
||
ctx
==
nullptr
)
{
if
(
rt_ret
!=
RT_ERROR_NONE
||
ctx
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"the ascend device context is null"
;
MS_LOG
(
ERROR
)
<<
"the ascend device context is null"
;
ClearEnv
();
return
FAILED
;
return
FAILED
;
}
}
g_ctx
=
ctx
;
g_ctx
=
ctx
;
...
@@ -258,6 +265,7 @@ Status Server::BuildAndStart() {
...
@@ -258,6 +265,7 @@ Status Server::BuildAndStart() {
auto
option
=
grpc
::
MakeChannelArgumentOption
(
GRPC_ARG_ALLOW_REUSEPORT
,
0
);
auto
option
=
grpc
::
MakeChannelArgumentOption
(
GRPC_ARG_ALLOW_REUSEPORT
,
0
);
grpc
::
ServerBuilder
builder
;
grpc
::
ServerBuilder
builder
;
builder
.
SetOption
(
std
::
move
(
option
));
builder
.
SetOption
(
std
::
move
(
option
));
builder
.
SetMaxMessageSize
(
uint32max
);
// Listen on the given address without any authentication mechanism.
// Listen on the given address without any authentication mechanism.
builder
.
AddListeningPort
(
server_address
,
grpc
::
InsecureServerCredentials
());
builder
.
AddListeningPort
(
server_address
,
grpc
::
InsecureServerCredentials
());
// Register "service" as the instance through which we'll communicate with
// Register "service" as the instance through which we'll communicate with
...
@@ -265,13 +273,20 @@ Status Server::BuildAndStart() {
...
@@ -265,13 +273,20 @@ Status Server::BuildAndStart() {
builder
.
RegisterService
(
&
service
);
builder
.
RegisterService
(
&
service
);
// Finally assemble the server.
// Finally assemble the server.
std
::
unique_ptr
<
grpc
::
Server
>
server
(
builder
.
BuildAndStart
());
std
::
unique_ptr
<
grpc
::
Server
>
server
(
builder
.
BuildAndStart
());
if
(
server
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"The serving server create failed"
;
ClearEnv
();
return
FAILED
;
}
auto
grpc_server_run
=
[
&
server
]()
{
server
->
Wait
();
};
std
::
thread
serving_thread
(
grpc_server_run
);
MS_LOG
(
INFO
)
<<
"Server listening on "
<<
server_address
<<
std
::
endl
;
MS_LOG
(
INFO
)
<<
"Server listening on "
<<
server_address
<<
std
::
endl
;
auto
exit_future
=
exit_requested
.
get_future
();
// Wait for the server to shutdown. Note that some other thread must be
exit_future
.
wait
();
// responsible for shutting down the server for this call to ever return.
ClearEnv
();
server
->
Wait
();
server
->
Shutdown
();
serving_thread
.
join
();
return
SUCCESS
;
return
SUCCESS
;
}
}
}
// namespace serving
}
// namespace serving
}
// namespace mindspore
}
// namespace mindspore
serving/core/util/file_system_operation.cc
浏览文件 @
5a886794
...
@@ -29,7 +29,6 @@
...
@@ -29,7 +29,6 @@
namespace
mindspore
{
namespace
mindspore
{
namespace
serving
{
namespace
serving
{
char
*
ReadFile
(
const
char
*
file
,
size_t
*
size
)
{
char
*
ReadFile
(
const
char
*
file
,
size_t
*
size
)
{
if
(
file
==
nullptr
)
{
if
(
file
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"file is nullptr"
;
MS_LOG
(
ERROR
)
<<
"file is nullptr"
;
...
@@ -70,8 +69,8 @@ bool DirOrFileExist(const std::string &file_path) {
...
@@ -70,8 +69,8 @@ bool DirOrFileExist(const std::string &file_path) {
}
}
std
::
vector
<
std
::
string
>
GetAllSubDirs
(
const
std
::
string
&
dir_path
)
{
std
::
vector
<
std
::
string
>
GetAllSubDirs
(
const
std
::
string
&
dir_path
)
{
DIR
*
dir
;
DIR
*
dir
=
nullptr
;
struct
dirent
*
ptr
;
struct
dirent
*
ptr
=
nullptr
;
std
::
vector
<
std
::
string
>
SubDirs
;
std
::
vector
<
std
::
string
>
SubDirs
;
if
((
dir
=
opendir
(
dir_path
.
c_str
()))
==
NULL
)
{
if
((
dir
=
opendir
(
dir_path
.
c_str
()))
==
NULL
)
{
...
...
serving/core/util/option_parser.cc
浏览文件 @
5a886794
...
@@ -36,17 +36,16 @@ bool RemovePrefix(std::string *str, const std::string &prefix) {
...
@@ -36,17 +36,16 @@ bool RemovePrefix(std::string *str, const std::string &prefix) {
bool
Option
::
ParseInt32
(
std
::
string
*
arg
)
{
bool
Option
::
ParseInt32
(
std
::
string
*
arg
)
{
if
(
RemovePrefix
(
arg
,
"--"
)
&&
RemovePrefix
(
arg
,
name_
)
&&
RemovePrefix
(
arg
,
"="
))
{
if
(
RemovePrefix
(
arg
,
"--"
)
&&
RemovePrefix
(
arg
,
name_
)
&&
RemovePrefix
(
arg
,
"="
))
{
char
extra
;
int32_t
parsed_value
;
int32_t
parsed_value
;
if
(
sscanf
(
arg
->
data
(),
"%d%c"
,
&
parsed_value
,
&
extra
)
!=
1
)
{
try
{
std
::
cout
<<
"Parse "
<<
name_
<<
"Error for option "
<<
*
arg
<<
std
::
endl
;
parsed_value
=
std
::
stoi
(
arg
->
data
());
}
catch
(
std
::
invalid_argument
)
{
std
::
cout
<<
"Parse "
<<
name_
<<
" Error for option "
<<
*
arg
<<
std
::
endl
;
return
false
;
return
false
;
}
else
{
*
int32_default_
=
parsed_value
;
}
}
*
int32_default_
=
parsed_value
;
return
true
;
return
true
;
}
}
return
false
;
return
false
;
}
}
...
@@ -76,17 +75,16 @@ bool Option::ParseString(std::string *arg) {
...
@@ -76,17 +75,16 @@ bool Option::ParseString(std::string *arg) {
bool
Option
::
ParseFloat
(
std
::
string
*
arg
)
{
bool
Option
::
ParseFloat
(
std
::
string
*
arg
)
{
if
(
RemovePrefix
(
arg
,
"--"
)
&&
RemovePrefix
(
arg
,
name_
)
&&
RemovePrefix
(
arg
,
"="
))
{
if
(
RemovePrefix
(
arg
,
"--"
)
&&
RemovePrefix
(
arg
,
name_
)
&&
RemovePrefix
(
arg
,
"="
))
{
char
extra
;
float
parsed_value
;
float
parsed_value
;
if
(
sscanf
(
arg
->
data
(),
"%f%c"
,
&
parsed_value
,
&
extra
)
!=
1
)
{
try
{
std
::
cout
<<
"Parse "
<<
name_
<<
"Error for option "
<<
*
arg
<<
std
::
endl
;
parsed_value
=
std
::
stof
(
arg
->
data
());
}
catch
(
std
::
invalid_argument
)
{
std
::
cout
<<
"Parse "
<<
name_
<<
" Error for option "
<<
*
arg
<<
std
::
endl
;
return
false
;
return
false
;
}
else
{
*
float_default_
=
parsed_value
;
}
}
*
float_default_
=
parsed_value
;
return
true
;
return
true
;
}
}
return
false
;
return
false
;
}
}
...
@@ -159,10 +157,11 @@ Options::Options() : args_(nullptr) { CreateOptions(); }
...
@@ -159,10 +157,11 @@ Options::Options() : args_(nullptr) { CreateOptions(); }
void
Options
::
CreateOptions
()
{
void
Options
::
CreateOptions
()
{
args_
=
std
::
make_shared
<
Arguments
>
();
args_
=
std
::
make_shared
<
Arguments
>
();
std
::
vector
<
Option
>
options
=
{
std
::
vector
<
Option
>
options
=
{
Option
(
"port"
,
&
args_
->
grpc_port
,
"Port to listen on for gRPC API, default is 5500"
),
Option
(
"port"
,
&
args_
->
grpc_port
,
Option
(
"model_name"
,
&
args_
->
model_name
,
"model name "
),
"[Optional] Port to listen on for gRPC API, default is 5500, range from 1 to 65535"
),
Option
(
"model_path"
,
&
args_
->
model_path
,
"the path of the model files"
),
Option
(
"model_name"
,
&
args_
->
model_name
,
"[Required] model name "
),
Option
(
"device_id"
,
&
args_
->
device_id
,
"the device id, default is 0"
),
Option
(
"model_path"
,
&
args_
->
model_path
,
"[Required] the path of the model files"
),
Option
(
"device_id"
,
&
args_
->
device_id
,
"[Optional] the device id, default is 0, range from 0 to 7"
),
};
};
options_
=
options
;
options_
=
options
;
}
}
...
@@ -176,6 +175,14 @@ bool Options::CheckOptions() {
...
@@ -176,6 +175,14 @@ bool Options::CheckOptions() {
std
::
cout
<<
"device_type only support Ascend right now"
<<
std
::
endl
;
std
::
cout
<<
"device_type only support Ascend right now"
<<
std
::
endl
;
return
false
;
return
false
;
}
}
if
(
args_
->
device_id
>
7
)
{
std
::
cout
<<
"the device_id should be in [0~7]"
<<
std
::
endl
;
return
false
;
}
if
(
args_
->
grpc_port
<
1
||
args_
->
grpc_port
>
65535
)
{
std
::
cout
<<
"the port should be in [1~65535]"
<<
std
::
endl
;
return
false
;
}
return
true
;
return
true
;
}
}
...
@@ -238,6 +245,5 @@ void Options::Usage() {
...
@@ -238,6 +245,5 @@ void Options::Usage() {
<<
option
.
usage_
<<
std
::
endl
;
<<
option
.
usage_
<<
std
::
endl
;
}
}
}
}
}
// namespace serving
}
// namespace serving
}
// namespace mindspore
}
// namespace mindspore
serving/core/util/option_parser.h
浏览文件 @
5a886794
...
@@ -22,7 +22,6 @@
...
@@ -22,7 +22,6 @@
namespace
mindspore
{
namespace
mindspore
{
namespace
serving
{
namespace
serving
{
struct
Arguments
{
struct
Arguments
{
int32_t
grpc_port
=
5500
;
int32_t
grpc_port
=
5500
;
std
::
string
grpc_socket_path
;
std
::
string
grpc_socket_path
;
...
@@ -40,6 +39,7 @@ class Option {
...
@@ -40,6 +39,7 @@ class Option {
Option
(
const
std
::
string
&
name
,
bool
*
default_point
,
const
std
::
string
&
usage
);
Option
(
const
std
::
string
&
name
,
bool
*
default_point
,
const
std
::
string
&
usage
);
Option
(
const
std
::
string
&
name
,
std
::
string
*
default_point
,
const
std
::
string
&
usage
);
Option
(
const
std
::
string
&
name
,
std
::
string
*
default_point
,
const
std
::
string
&
usage
);
Option
(
const
std
::
string
&
name
,
float
*
default_point
,
const
std
::
string
&
usage
);
Option
(
const
std
::
string
&
name
,
float
*
default_point
,
const
std
::
string
&
usage
);
~
Option
()
=
default
;
private:
private:
friend
class
Options
;
friend
class
Options
;
...
@@ -77,7 +77,6 @@ class Options {
...
@@ -77,7 +77,6 @@ class Options {
std
::
vector
<
Option
>
options_
;
std
::
vector
<
Option
>
options_
;
std
::
shared_ptr
<
Arguments
>
args_
;
std
::
shared_ptr
<
Arguments
>
args_
;
};
};
}
// namespace serving
}
// namespace serving
}
// namespace mindspore
}
// namespace mindspore
...
...
serving/core/version_control/model.cc
浏览文件 @
5a886794
...
@@ -19,7 +19,6 @@
...
@@ -19,7 +19,6 @@
namespace
mindspore
{
namespace
mindspore
{
namespace
serving
{
namespace
serving
{
MindSporeModel
::
MindSporeModel
(
const
std
::
string
&
model_name
,
const
std
::
string
&
model_path
,
MindSporeModel
::
MindSporeModel
(
const
std
::
string
&
model_name
,
const
std
::
string
&
model_path
,
const
std
::
string
&
model_version
,
const
time_t
&
last_update_time
)
const
std
::
string
&
model_version
,
const
time_t
&
last_update_time
)
:
model_name_
(
model_name
),
:
model_name_
(
model_name
),
...
...
serving/core/version_control/version_controller.cc
浏览文件 @
5a886794
...
@@ -25,7 +25,6 @@
...
@@ -25,7 +25,6 @@
namespace
mindspore
{
namespace
mindspore
{
namespace
serving
{
namespace
serving
{
volatile
bool
stop_poll
=
false
;
volatile
bool
stop_poll
=
false
;
std
::
string
GetVersionFromPath
(
const
std
::
string
&
path
)
{
std
::
string
GetVersionFromPath
(
const
std
::
string
&
path
)
{
...
@@ -102,10 +101,10 @@ Status VersionController::CreateInitModels() {
...
@@ -102,10 +101,10 @@ Status VersionController::CreateInitModels() {
}
}
std
::
vector
<
std
::
string
>
SubDirs
=
GetAllSubDirs
(
models_path_
);
std
::
vector
<
std
::
string
>
SubDirs
=
GetAllSubDirs
(
models_path_
);
if
(
version_control_strategy_
==
kLastest
)
{
if
(
version_control_strategy_
==
kLastest
)
{
auto
path
=
SubDirs
.
empty
()
?
models_path_
:
SubDirs
.
back
(
);
std
::
string
model_version
=
GetVersionFromPath
(
models_path_
);
std
::
string
model_version
=
GetVersionFromPath
(
path
);
time_t
last_update_time
=
GetModifyTime
(
models_path_
);
time_t
last_update_time
=
GetModifyTime
(
path
);
MindSporeModelPtr
model_ptr
=
MindSporeModelPtr
model_ptr
=
std
::
make_shared
<
MindSporeModel
>
(
model_name_
,
path
,
model_version
,
last_update_time
);
std
::
make_shared
<
MindSporeModel
>
(
model_name_
,
models_path_
,
model_version
,
last_update_time
);
valid_models_
.
emplace_back
(
model_ptr
);
valid_models_
.
emplace_back
(
model_ptr
);
}
else
{
}
else
{
for
(
auto
&
dir
:
SubDirs
)
{
for
(
auto
&
dir
:
SubDirs
)
{
...
@@ -119,8 +118,8 @@ Status VersionController::CreateInitModels() {
...
@@ -119,8 +118,8 @@ Status VersionController::CreateInitModels() {
MS_LOG
(
ERROR
)
<<
"There is no valid model for serving"
;
MS_LOG
(
ERROR
)
<<
"There is no valid model for serving"
;
return
FAILED
;
return
FAILED
;
}
}
Session
::
Instance
().
Warmup
(
valid_models_
.
back
());
auto
ret
=
Session
::
Instance
().
Warmup
(
valid_models_
.
back
());
return
SUCCESS
;
return
ret
;
}
}
void
VersionController
::
StartPollModelPeriodic
()
{
void
VersionController
::
StartPollModelPeriodic
()
{
...
@@ -129,6 +128,5 @@ void VersionController::StartPollModelPeriodic() {
...
@@ -129,6 +128,5 @@ void VersionController::StartPollModelPeriodic() {
}
}
void
VersionController
::
StopPollModelPeriodic
()
{}
void
VersionController
::
StopPollModelPeriodic
()
{}
}
// namespace serving
}
// namespace serving
}
// namespace mindspore
}
// namespace mindspore
serving/core/version_control/version_controller.h
浏览文件 @
5a886794
...
@@ -64,7 +64,6 @@ class PeriodicFunction {
...
@@ -64,7 +64,6 @@ class PeriodicFunction {
VersionController
::
VersionControllerStrategy
version_control_strategy_
;
VersionController
::
VersionControllerStrategy
version_control_strategy_
;
std
::
vector
<
MindSporeModelPtr
>
valid_models_
;
std
::
vector
<
MindSporeModelPtr
>
valid_models_
;
};
};
}
// namespace serving
}
// namespace serving
}
// namespace mindspore
}
// namespace mindspore
...
...
serving/cpp_example/ms_client.cc
浏览文件 @
5a886794
...
@@ -214,6 +214,7 @@ PredictRequest ReadBertInput() {
...
@@ -214,6 +214,7 @@ PredictRequest ReadBertInput() {
class
MSClient
{
class
MSClient
{
public:
public:
explicit
MSClient
(
std
::
shared_ptr
<
Channel
>
channel
)
:
stub_
(
MSService
::
NewStub
(
channel
))
{}
explicit
MSClient
(
std
::
shared_ptr
<
Channel
>
channel
)
:
stub_
(
MSService
::
NewStub
(
channel
))
{}
~
MSClient
()
=
default
;
std
::
string
Predict
(
const
std
::
string
&
type
)
{
std
::
string
Predict
(
const
std
::
string
&
type
)
{
// Data we are sending to the server.
// Data we are sending to the server.
...
@@ -310,7 +311,6 @@ int main(int argc, char **argv) {
...
@@ -310,7 +311,6 @@ int main(int argc, char **argv) {
type
=
"add"
;
type
=
"add"
;
}
}
}
}
}
else
{
}
else
{
target_str
=
"localhost:5500"
;
target_str
=
"localhost:5500"
;
type
=
"add"
;
type
=
"add"
;
...
...
serving/scripts/format_source_code.sh
浏览文件 @
5a886794
...
@@ -81,7 +81,7 @@ function checkopts()
...
@@ -81,7 +81,7 @@ function checkopts()
checkopts
"
$@
"
checkopts
"
$@
"
# switch to project root path, which contains clang-format config file '.clang-format'
# switch to project root path, which contains clang-format config file '.clang-format'
cd
"
${
SCRIPTS_PATH
}
/.."
||
exit
1
cd
"
${
SCRIPTS_PATH
}
/..
/..
"
||
exit
1
FMT_FILE_LIST
=
'__format_files_list__'
FMT_FILE_LIST
=
'__format_files_list__'
...
...
setup.py
浏览文件 @
5a886794
...
@@ -161,6 +161,7 @@ setup(
...
@@ -161,6 +161,7 @@ setup(
description
=
'MindSpore is a new open source deep learning training/inference '
description
=
'MindSpore is a new open source deep learning training/inference '
'framework that could be used for mobile, edge and cloud scenarios.'
,
'framework that could be used for mobile, edge and cloud scenarios.'
,
long_description
=
"
\n\n
"
.
join
([
readme
,
release
]),
long_description
=
"
\n\n
"
.
join
([
readme
,
release
]),
long_description_content_type
=
"text/markdown"
,
packages
=
find_packages
(),
packages
=
find_packages
(),
package_data
=
package_data
,
package_data
=
package_data
,
include_package_data
=
True
,
include_package_data
=
True
,
...
...
tests/ut/cpp/dataset/btree_test.cc
浏览文件 @
5a886794
...
@@ -190,9 +190,9 @@ TEST_F(MindDataTestBPlusTree, Test3) {
...
@@ -190,9 +190,9 @@ TEST_F(MindDataTestBPlusTree, Test3) {
EXPECT_TRUE
(
rc
.
IsOk
());
EXPECT_TRUE
(
rc
.
IsOk
());
uint64_t
min
=
ai
.
min_key
();
uint64_t
min
=
ai
.
min_key
();
uint64_t
max
=
ai
.
max_key
();
uint64_t
max
=
ai
.
max_key
();
EXPECT_EQ
(
min
,
1
);
EXPECT_EQ
(
min
,
0
);
EXPECT_EQ
(
max
,
4
);
EXPECT_EQ
(
max
,
3
);
auto
r
=
ai
.
Search
(
3
);
auto
r
=
ai
.
Search
(
2
);
auto
&
it
=
r
.
first
;
auto
&
it
=
r
.
first
;
EXPECT_EQ
(
it
.
value
(),
"b"
);
EXPECT_EQ
(
it
.
value
(),
"b"
);
MS_LOG
(
INFO
)
<<
"Dump all the values using [] operator."
;
MS_LOG
(
INFO
)
<<
"Dump all the values using [] operator."
;
...
...
tests/ut/cpp/optimizer/opt_test.cc
浏览文件 @
5a886794
...
@@ -77,10 +77,10 @@ class TestOptOpt : public UT::Common {
...
@@ -77,10 +77,10 @@ class TestOptOpt : public UT::Common {
};
};
void
SetUp
()
{
void
SetUp
()
{
elim_Z
=
MakeSubstitution
(
irpass
::
AddByZero
(),
"elim_Z"
,
prim
::
kPrimScalarAdd
);
elim_Z
=
MakeSubstitution
(
std
::
make_shared
<
irpass
::
AddByZero
>
(),
"elim_Z"
,
prim
::
kPrimScalarAdd
);
elim_R
=
MakeSubstitution
(
irpass
::
PrimEliminater
(
R
),
"elim_R"
,
R
);
elim_R
=
MakeSubstitution
(
std
::
make_shared
<
irpass
::
PrimEliminater
>
(
R
),
"elim_R"
,
R
);
idempotent_P
=
MakeSubstitution
(
IdempotentEliminater
(),
"idempotent_P"
,
P
);
idempotent_P
=
MakeSubstitution
(
std
::
make_shared
<
IdempotentEliminater
>
(),
"idempotent_P"
,
P
);
Qct_to_P
=
MakeSubstitution
(
QctToP
(),
"Qct_to_P"
,
Q
);
Qct_to_P
=
MakeSubstitution
(
std
::
make_shared
<
QctToP
>
(),
"Qct_to_P"
,
Q
);
}
}
bool
CheckTransform
(
FuncGraphPtr
gbefore
,
FuncGraphPtr
gafter
,
const
SubstitutionList
&
transform
)
{
bool
CheckTransform
(
FuncGraphPtr
gbefore
,
FuncGraphPtr
gafter
,
const
SubstitutionList
&
transform
)
{
...
...
tests/ut/cpp/parallel/step_parallel_test.cc
浏览文件 @
5a886794
...
@@ -327,6 +327,9 @@ TEST_F(TestStepParallel, CreatOpInstance) {
...
@@ -327,6 +327,9 @@ TEST_F(TestStepParallel, CreatOpInstance) {
}
else
if
(
name
==
"instance_name"
)
{
}
else
if
(
name
==
"instance_name"
)
{
parse
::
ConvertData
(
py
::
cast
<
py
::
object
>
(
item
.
second
),
&
converted_ret
);
parse
::
ConvertData
(
py
::
cast
<
py
::
object
>
(
item
.
second
),
&
converted_ret
);
ASSERT_EQ
(
converted_ret
->
ToString
(),
"test"
);
ASSERT_EQ
(
converted_ret
->
ToString
(),
"test"
);
}
else
if
(
name
==
"index"
)
{
parse
::
ConvertData
(
py
::
cast
<
py
::
object
>
(
item
.
second
),
&
converted_ret
);
ASSERT_EQ
(
converted_ret
->
ToString
(),
"0"
);
}
else
{
}
else
{
MS_LOG
(
EXCEPTION
)
<<
"Test failed"
;
MS_LOG
(
EXCEPTION
)
<<
"Test failed"
;
}
}
...
...
tests/ut/data/dataset/declient.cfg
浏览文件 @
5a886794
...
@@ -4,6 +4,7 @@
...
@@ -4,6 +4,7 @@
"numParallelWorkers": 4,
"numParallelWorkers": 4,
"workerConnectorSize": 16,
"workerConnectorSize": 16,
"opConnectorSize": 16,
"opConnectorSize": 16,
"seed": 5489
"seed": 5489,
"monitor_sampling_interval": 15
}
}
tests/ut/data/dataset/golden/bounding_box_augment_crop_c_result.npz
0 → 100644
浏览文件 @
5a886794
文件已添加
tests/ut/data/dataset/golden/bounding_box_augment_rotation_c_result.npz
0 → 100644
浏览文件 @
5a886794
文件已添加
tests/ut/data/dataset/golden/bounding_box_augment_valid_edge_c_result.npz
0 → 100644
浏览文件 @
5a886794
文件已添加
tests/ut/data/dataset/golden/bounding_box_augment_valid_ratio_c_result.npz
0 → 100644
浏览文件 @
5a886794
文件已添加
tests/ut/data/dataset/golden/random_crop_with_bbox_01_c_result.npz
0 → 100644
浏览文件 @
5a886794
文件已添加
tests/ut/data/dataset/golden/random_horizontal_flip_with_bbox_01_c_result.npz
0 → 100644
浏览文件 @
5a886794
文件已添加
tests/ut/data/dataset/golden/random_resize_with_bbox_op_01_c_result.npz
0 → 100644
浏览文件 @
5a886794
文件已添加
tests/ut/data/dataset/golden/random_resized_crop_with_bbox_01_c_result.npz
0 → 100644
浏览文件 @
5a886794
文件已添加
tests/ut/data/dataset/golden/random_vertical_flip_with_bbox_01_c_result.npz
0 → 100644
浏览文件 @
5a886794
文件已添加
tests/ut/data/dataset/golden/resize_with_bbox_op_01_c_result.npz
0 → 100644
浏览文件 @
5a886794
文件已添加
tests/ut/python/dataset/test_batch.py
浏览文件 @
5a886794
...
@@ -12,10 +12,9 @@
...
@@ -12,10 +12,9 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
# ==============================================================================
# ==============================================================================
from
util
import
save_and_check
import
mindspore.dataset
as
ds
import
mindspore.dataset
as
ds
from
mindspore
import
log
as
logger
from
mindspore
import
log
as
logger
from
util
import
save_and_check
# Note: Number of rows in test.data dataset: 12
# Note: Number of rows in test.data dataset: 12
DATA_DIR
=
[
"../data/dataset/testTFTestAllTypes/test.data"
]
DATA_DIR
=
[
"../data/dataset/testTFTestAllTypes/test.data"
]
...
@@ -434,7 +433,6 @@ def test_batch_exception_11():
...
@@ -434,7 +433,6 @@ def test_batch_exception_11():
assert
"drop_remainder"
in
str
(
e
)
assert
"drop_remainder"
in
str
(
e
)
# pylint: disable=redundant-keyword-arg
def
test_batch_exception_12
():
def
test_batch_exception_12
():
"""
"""
Test batch exception: wrong input order, drop_remainder wrongly used as batch_size
Test batch exception: wrong input order, drop_remainder wrongly used as batch_size
...
@@ -447,12 +445,12 @@ def test_batch_exception_12():
...
@@ -447,12 +445,12 @@ def test_batch_exception_12():
# apply dataset operations
# apply dataset operations
data1
=
ds
.
TFRecordDataset
(
DATA_DIR
)
data1
=
ds
.
TFRecordDataset
(
DATA_DIR
)
try
:
try
:
data1
=
data1
.
batch
(
drop_remainder
,
batch_size
=
batch_size
)
data1
=
data1
.
batch
(
drop_remainder
,
batch_size
)
sum
([
1
for
_
in
data1
])
sum
([
1
for
_
in
data1
])
except
Exception
as
e
:
except
Exception
as
e
:
logger
.
info
(
"Got an exception in DE: {}"
.
format
(
str
(
e
)))
logger
.
info
(
"Got an exception in DE: {}"
.
format
(
str
(
e
)))
assert
"
batch_size
"
in
str
(
e
)
assert
"
drop_remainder
"
in
str
(
e
)
def
test_batch_exception_13
():
def
test_batch_exception_13
():
...
...
tests/ut/python/dataset/test_bounding_box_augment.py
浏览文件 @
5a886794
...
@@ -15,92 +15,18 @@
...
@@ -15,92 +15,18 @@
"""
"""
Testing the bounding box augment op in DE
Testing the bounding box augment op in DE
"""
"""
from
enum
import
Enum
from
util
import
visualize_with_bounding_boxes
,
InvalidBBoxType
,
check_bad_bbox
,
\
config_get_set_seed
,
config_get_set_num_parallel_workers
,
save_and_check_md5
import
numpy
as
np
import
mindspore.log
as
logger
import
mindspore.log
as
logger
import
mindspore.dataset
as
ds
import
mindspore.dataset
as
ds
import
mindspore.dataset.transforms.vision.c_transforms
as
c_vision
import
mindspore.dataset.transforms.vision.c_transforms
as
c_vision
import
matplotlib.pyplot
as
plt
import
matplotlib.patches
as
patches
import
numpy
as
np
GENERATE_GOLDEN
=
False
GENERATE_GOLDEN
=
False
DATA_DIR
=
"../data/dataset/testVOC2012_2"
DATA_DIR
=
"../data/dataset/testVOC2012_2"
class
BoxType
(
Enum
):
"""
Defines box types for test cases
"""
WidthOverflow
=
1
HeightOverflow
=
2
NegativeXY
=
3
OnEdge
=
4
WrongShape
=
5
def
add_bad_annotation
(
img
,
bboxes
,
box_type
):
"""
Used to generate erroneous bounding box examples on given img.
:param img: image where the bounding boxes are.
:param bboxes: in [x_min, y_min, w, h, label, truncate, difficult] format
:param box_type: type of bad box
:return: bboxes with bad examples added
"""
height
=
img
.
shape
[
0
]
width
=
img
.
shape
[
1
]
if
box_type
==
BoxType
.
WidthOverflow
:
# use box that overflows on width
return
img
,
np
.
array
([[
0
,
0
,
width
+
1
,
height
,
0
,
0
,
0
]]).
astype
(
np
.
uint32
)
if
box_type
==
BoxType
.
HeightOverflow
:
# use box that overflows on height
return
img
,
np
.
array
([[
0
,
0
,
width
,
height
+
1
,
0
,
0
,
0
]]).
astype
(
np
.
uint32
)
if
box_type
==
BoxType
.
NegativeXY
:
# use box with negative xy
return
img
,
np
.
array
([[
-
10
,
-
10
,
width
,
height
,
0
,
0
,
0
]]).
astype
(
np
.
uint32
)
if
box_type
==
BoxType
.
OnEdge
:
# use box that covers the whole image
return
img
,
np
.
array
([[
0
,
0
,
width
,
height
,
0
,
0
,
0
]]).
astype
(
np
.
uint32
)
if
box_type
==
BoxType
.
WrongShape
:
# use box that covers the whole image
return
img
,
np
.
array
([[
0
,
0
,
width
-
1
]]).
astype
(
np
.
uint32
)
return
img
,
bboxes
def
check_bad_box
(
data
,
box_type
,
expected_error
):
"""
:param data: de object detection pipeline
:param box_type: type of bad box
:param expected_error: error expected to get due to bad box
:return: None
"""
try
:
test_op
=
c_vision
.
BoundingBoxAugment
(
c_vision
.
RandomHorizontalFlip
(
1
),
1
)
# DEFINE TEST OP HERE -- (PROB 1 IN CASE OF RANDOM)
data
=
data
.
map
(
input_columns
=
[
"annotation"
],
output_columns
=
[
"annotation"
],
operations
=
fix_annotate
)
# map to use width overflow
data
=
data
.
map
(
input_columns
=
[
"image"
,
"annotation"
],
output_columns
=
[
"image"
,
"annotation"
],
columns_order
=
[
"image"
,
"annotation"
],
operations
=
lambda
img
,
bboxes
:
add_bad_annotation
(
img
,
bboxes
,
box_type
))
# map to apply ops
data
=
data
.
map
(
input_columns
=
[
"image"
,
"annotation"
],
output_columns
=
[
"image"
,
"annotation"
],
columns_order
=
[
"image"
,
"annotation"
],
operations
=
[
test_op
])
# Add column for "annotation"
for
_
,
_
in
enumerate
(
data
.
create_dict_iterator
()):
break
except
RuntimeError
as
error
:
logger
.
info
(
"Got an exception in DE: {}"
.
format
(
str
(
error
)))
assert
expected_error
in
str
(
error
)
def
fix_annotate
(
bboxes
):
def
fix_annotate
(
bboxes
):
"""
"""
Fix annotations to format followed by mindspore.
Fix annotations to format followed by mindspore.
...
@@ -117,153 +43,217 @@ def fix_annotate(bboxes):
...
@@ -117,153 +43,217 @@ def fix_annotate(bboxes):
return
bboxes
return
bboxes
def
add_bounding_boxes
(
axis
,
bboxes
):
def
test_bounding_box_augment_with_rotation_op
(
plot_vis
=
False
):
"""
:param axis: axis to modify
:param bboxes: bounding boxes to draw on the axis
:return: None
"""
for
bbox
in
bboxes
:
rect
=
patches
.
Rectangle
((
bbox
[
0
],
bbox
[
1
]),
bbox
[
2
],
bbox
[
3
],
linewidth
=
1
,
edgecolor
=
'r'
,
facecolor
=
'none'
)
# Add the patch to the Axes
axis
.
add_patch
(
rect
)
def
visualize
(
unaugmented_data
,
augment_data
):
"""
"""
:param unaugmented_data: original data
Test BoundingBoxAugment op (passing rotation op as transform)
:param augment_data: data after augmentations
:return: None
"""
for
idx
,
(
un_aug_item
,
aug_item
)
in
\
enumerate
(
zip
(
unaugmented_data
.
create_dict_iterator
(),
augment_data
.
create_dict_iterator
())):
axis
=
plt
.
subplot
(
141
)
plt
.
imshow
(
un_aug_item
[
"image"
])
add_bounding_boxes
(
axis
,
un_aug_item
[
"annotation"
])
# add Orig BBoxes
plt
.
title
(
"Original"
+
str
(
idx
+
1
))
logger
.
info
(
"Original "
,
str
(
idx
+
1
),
" :"
,
un_aug_item
[
"annotation"
])
axis
=
plt
.
subplot
(
142
)
plt
.
imshow
(
aug_item
[
"image"
])
add_bounding_boxes
(
axis
,
aug_item
[
"annotation"
])
# add AugBBoxes
plt
.
title
(
"Augmented"
+
str
(
idx
+
1
))
logger
.
info
(
"Augmented "
,
str
(
idx
+
1
),
" "
,
aug_item
[
"annotation"
],
"
\n
"
)
plt
.
show
()
def
test_bounding_box_augment_with_rotation_op
(
plot
=
False
):
"""
Test BoundingBoxAugment op
Prints images side by side with and without Aug applied + bboxes to compare and test
Prints images side by side with and without Aug applied + bboxes to compare and test
"""
"""
logger
.
info
(
"test_bounding_box_augment_with_rotation_op"
)
logger
.
info
(
"test_bounding_box_augment_with_rotation_op"
)
data_voc1
=
ds
.
VOCDataset
(
DATA_DIR
,
task
=
"Detection"
,
mode
=
"train"
,
decode
=
True
,
shuffle
=
False
)
original_seed
=
config_get_set_seed
(
0
)
data_voc2
=
ds
.
VOCDataset
(
DATA_DIR
,
task
=
"Detection"
,
mode
=
"train"
,
decode
=
True
,
shuffle
=
False
)
original_num_parallel_workers
=
config_get_set_num_parallel_workers
(
1
)
dataVoc1
=
ds
.
VOCDataset
(
DATA_DIR
,
task
=
"Detection"
,
mode
=
"train"
,
decode
=
True
,
shuffle
=
False
)
dataVoc2
=
ds
.
VOCDataset
(
DATA_DIR
,
task
=
"Detection"
,
mode
=
"train"
,
decode
=
True
,
shuffle
=
False
)
# Ratio is set to 1 to apply rotation on all bounding boxes.
test_op
=
c_vision
.
BoundingBoxAugment
(
c_vision
.
RandomRotation
(
90
),
1
)
test_op
=
c_vision
.
BoundingBoxAugment
(
c_vision
.
RandomRotation
(
90
),
1
)
# DEFINE TEST OP HERE -- (PROB 1 IN CASE OF RANDOM)
# maps to fix annotations to minddata standard
# maps to fix annotations to minddata standard
data
_voc1
=
data_v
oc1
.
map
(
input_columns
=
[
"annotation"
],
data
Voc1
=
dataV
oc1
.
map
(
input_columns
=
[
"annotation"
],
output_columns
=
[
"annotation"
],
output_columns
=
[
"annotation"
],
operations
=
fix_annotate
)
operations
=
fix_annotate
)
data
_voc2
=
data_v
oc2
.
map
(
input_columns
=
[
"annotation"
],
data
Voc2
=
dataV
oc2
.
map
(
input_columns
=
[
"annotation"
],
output_columns
=
[
"annotation"
],
output_columns
=
[
"annotation"
],
operations
=
fix_annotate
)
operations
=
fix_annotate
)
# map to apply ops
# map to apply ops
data_voc2
=
data_voc2
.
map
(
input_columns
=
[
"image"
,
"annotation"
],
dataVoc2
=
dataVoc2
.
map
(
input_columns
=
[
"image"
,
"annotation"
],
output_columns
=
[
"image"
,
"annotation"
],
output_columns
=
[
"image"
,
"annotation"
],
columns_order
=
[
"image"
,
"annotation"
],
columns_order
=
[
"image"
,
"annotation"
],
operations
=
[
test_op
])
# Add column for "annotation"
operations
=
[
test_op
])
if
plot
:
visualize
(
data_voc1
,
data_voc2
)
filename
=
"bounding_box_augment_rotation_c_result.npz"
save_and_check_md5
(
dataVoc2
,
filename
,
generate_golden
=
GENERATE_GOLDEN
)
def
test_bounding_box_augment_with_crop_op
(
plot
=
False
):
unaugSamp
,
augSamp
=
[],
[]
for
unAug
,
Aug
in
zip
(
dataVoc1
.
create_dict_iterator
(),
dataVoc2
.
create_dict_iterator
()):
unaugSamp
.
append
(
unAug
)
augSamp
.
append
(
Aug
)
if
plot_vis
:
visualize_with_bounding_boxes
(
unaugSamp
,
augSamp
)
# Restore config setting
ds
.
config
.
set_seed
(
original_seed
)
ds
.
config
.
set_num_parallel_workers
(
original_num_parallel_workers
)
def
test_bounding_box_augment_with_crop_op
(
plot_vis
=
False
):
"""
"""
Test BoundingBoxAugment op
Test BoundingBoxAugment op
(passing crop op as transform)
Prints images side by side with and without Aug applied + bboxes to compare and test
Prints images side by side with and without Aug applied + bboxes to compare and test
"""
"""
logger
.
info
(
"test_bounding_box_augment_with_crop_op"
)
logger
.
info
(
"test_bounding_box_augment_with_crop_op"
)
data_voc1
=
ds
.
VOCDataset
(
DATA_DIR
,
task
=
"Detection"
,
mode
=
"train"
,
decode
=
True
,
shuffle
=
False
)
original_seed
=
config_get_set_seed
(
1
)
data_voc2
=
ds
.
VOCDataset
(
DATA_DIR
,
task
=
"Detection"
,
mode
=
"train"
,
decode
=
True
,
shuffle
=
False
)
original_num_parallel_workers
=
config_get_set_num_parallel_workers
(
1
)
dataVoc1
=
ds
.
VOCDataset
(
DATA_DIR
,
task
=
"Detection"
,
mode
=
"train"
,
decode
=
True
,
shuffle
=
False
)
dataVoc2
=
ds
.
VOCDataset
(
DATA_DIR
,
task
=
"Detection"
,
mode
=
"train"
,
decode
=
True
,
shuffle
=
False
)
# Ratio is set to 1 to apply rotation on all bounding boxes.
test_op
=
c_vision
.
BoundingBoxAugment
(
c_vision
.
RandomCrop
(
90
),
1
)
test_op
=
c_vision
.
BoundingBoxAugment
(
c_vision
.
RandomCrop
(
90
),
1
)
# maps to fix annotations to minddata standard
# maps to fix annotations to minddata standard
data
_voc1
=
data_v
oc1
.
map
(
input_columns
=
[
"annotation"
],
data
Voc1
=
dataV
oc1
.
map
(
input_columns
=
[
"annotation"
],
output_columns
=
[
"annotation"
],
output_columns
=
[
"annotation"
],
operations
=
fix_annotate
)
operations
=
fix_annotate
)
data
_voc2
=
data_v
oc2
.
map
(
input_columns
=
[
"annotation"
],
data
Voc2
=
dataV
oc2
.
map
(
input_columns
=
[
"annotation"
],
output_columns
=
[
"annotation"
],
output_columns
=
[
"annotation"
],
operations
=
fix_annotate
)
operations
=
fix_annotate
)
# map to apply ops
# map to apply ops
data_voc2
=
data_voc2
.
map
(
input_columns
=
[
"image"
,
"annotation"
],
dataVoc2
=
dataVoc2
.
map
(
input_columns
=
[
"image"
,
"annotation"
],
output_columns
=
[
"image"
,
"annotation"
],
output_columns
=
[
"image"
,
"annotation"
],
columns_order
=
[
"image"
,
"annotation"
],
columns_order
=
[
"image"
,
"annotation"
],
operations
=
[
test_op
])
# Add column for "annotation"
operations
=
[
test_op
])
if
plot
:
visualize
(
data_voc1
,
data_voc2
)
filename
=
"bounding_box_augment_crop_c_result.npz"
save_and_check_md5
(
dataVoc2
,
filename
,
generate_golden
=
GENERATE_GOLDEN
)
unaugSamp
,
augSamp
=
[],
[]
def
test_bounding_box_augment_valid_ratio_c
(
plot
=
False
):
for
unAug
,
Aug
in
zip
(
dataVoc1
.
create_dict_iterator
(),
dataVoc2
.
create_dict_iterator
()):
unaugSamp
.
append
(
unAug
)
augSamp
.
append
(
Aug
)
if
plot_vis
:
visualize_with_bounding_boxes
(
unaugSamp
,
augSamp
)
# Restore config setting
ds
.
config
.
set_seed
(
original_seed
)
ds
.
config
.
set_num_parallel_workers
(
original_num_parallel_workers
)
def
test_bounding_box_augment_valid_ratio_c
(
plot_vis
=
False
):
"""
"""
Test
RandomHorizontalFlipWithBBox op
Test
BoundingBoxAugment op (testing with valid ratio, less than 1.
Prints images side by side with and without Aug applied + bboxes to compare and test
Prints images side by side with and without Aug applied + bboxes to compare and test
"""
"""
logger
.
info
(
"test_bounding_box_augment_valid_ratio_c"
)
logger
.
info
(
"test_bounding_box_augment_valid_ratio_c"
)
data_voc1
=
ds
.
VOCDataset
(
DATA_DIR
,
task
=
"Detection"
,
mode
=
"train"
,
decode
=
True
,
shuffle
=
False
)
original_seed
=
config_get_set_seed
(
1
)
data_voc2
=
ds
.
VOCDataset
(
DATA_DIR
,
task
=
"Detection"
,
mode
=
"train"
,
decode
=
True
,
shuffle
=
False
)
original_num_parallel_workers
=
config_get_set_num_parallel_workers
(
1
)
dataVoc1
=
ds
.
VOCDataset
(
DATA_DIR
,
task
=
"Detection"
,
mode
=
"train"
,
decode
=
True
,
shuffle
=
False
)
dataVoc2
=
ds
.
VOCDataset
(
DATA_DIR
,
task
=
"Detection"
,
mode
=
"train"
,
decode
=
True
,
shuffle
=
False
)
test_op
=
c_vision
.
BoundingBoxAugment
(
c_vision
.
RandomHorizontalFlip
(
1
),
0.9
)
test_op
=
c_vision
.
BoundingBoxAugment
(
c_vision
.
RandomHorizontalFlip
(
1
),
0.9
)
# DEFINE TEST OP HERE -- (PROB 1 IN CASE OF RANDOM)
# maps to fix annotations to minddata standard
# maps to fix annotations to minddata standard
data_voc1
=
data_voc1
.
map
(
input_columns
=
[
"annotation"
],
dataVoc1
=
dataVoc1
.
map
(
input_columns
=
[
"annotation"
],
output_columns
=
[
"annotation"
],
output_columns
=
[
"annotation"
],
operations
=
fix_annotate
)
operations
=
fix_annotate
)
data_voc2
=
data_voc2
.
map
(
input_columns
=
[
"annotation"
],
dataVoc2
=
dataVoc2
.
map
(
input_columns
=
[
"annotation"
],
output_columns
=
[
"annotation"
],
output_columns
=
[
"annotation"
],
operations
=
fix_annotate
)
operations
=
fix_annotate
)
# map to apply ops
dataVoc2
=
dataVoc2
.
map
(
input_columns
=
[
"image"
,
"annotation"
],
output_columns
=
[
"image"
,
"annotation"
],
columns_order
=
[
"image"
,
"annotation"
],
operations
=
[
test_op
])
# Add column for "annotation"
filename
=
"bounding_box_augment_valid_ratio_c_result.npz"
save_and_check_md5
(
dataVoc2
,
filename
,
generate_golden
=
GENERATE_GOLDEN
)
unaugSamp
,
augSamp
=
[],
[]
for
unAug
,
Aug
in
zip
(
dataVoc1
.
create_dict_iterator
(),
dataVoc2
.
create_dict_iterator
()):
unaugSamp
.
append
(
unAug
)
augSamp
.
append
(
Aug
)
if
plot_vis
:
visualize_with_bounding_boxes
(
unaugSamp
,
augSamp
)
# Restore config setting
ds
.
config
.
set_seed
(
original_seed
)
ds
.
config
.
set_num_parallel_workers
(
original_num_parallel_workers
)
def
test_bounding_box_augment_valid_edge_c
(
plot_vis
=
False
):
"""
Test BoundingBoxAugment op (testing with valid edge case, box covering full image).
Prints images side by side with and without Aug applied + bboxes to compare and test
"""
logger
.
info
(
"test_bounding_box_augment_valid_edge_c"
)
original_seed
=
config_get_set_seed
(
1
)
original_num_parallel_workers
=
config_get_set_num_parallel_workers
(
1
)
dataVoc1
=
ds
.
VOCDataset
(
DATA_DIR
,
task
=
"Detection"
,
mode
=
"train"
,
decode
=
True
,
shuffle
=
False
)
dataVoc2
=
ds
.
VOCDataset
(
DATA_DIR
,
task
=
"Detection"
,
mode
=
"train"
,
decode
=
True
,
shuffle
=
False
)
test_op
=
c_vision
.
BoundingBoxAugment
(
c_vision
.
RandomHorizontalFlip
(
1
),
1
)
# maps to fix annotations to minddata standard
dataVoc1
=
dataVoc1
.
map
(
input_columns
=
[
"annotation"
],
output_columns
=
[
"annotation"
],
operations
=
fix_annotate
)
dataVoc2
=
dataVoc2
.
map
(
input_columns
=
[
"annotation"
],
output_columns
=
[
"annotation"
],
operations
=
fix_annotate
)
# map to apply ops
# map to apply ops
data_voc2
=
data_voc2
.
map
(
input_columns
=
[
"image"
,
"annotation"
],
# Add column for "annotation"
output_columns
=
[
"image"
,
"annotation"
],
dataVoc1
=
dataVoc1
.
map
(
input_columns
=
[
"image"
,
"annotation"
],
columns_order
=
[
"image"
,
"annotation"
],
output_columns
=
[
"image"
,
"annotation"
],
operations
=
[
test_op
])
# Add column for "annotation"
columns_order
=
[
"image"
,
"annotation"
],
if
plot
:
operations
=
lambda
img
,
bbox
:
visualize
(
data_voc1
,
data_voc2
)
(
img
,
np
.
array
([[
0
,
0
,
img
.
shape
[
1
],
img
.
shape
[
0
],
0
,
0
,
0
]]).
astype
(
np
.
uint32
)))
dataVoc2
=
dataVoc2
.
map
(
input_columns
=
[
"image"
,
"annotation"
],
output_columns
=
[
"image"
,
"annotation"
],
columns_order
=
[
"image"
,
"annotation"
],
operations
=
lambda
img
,
bbox
:
(
img
,
np
.
array
([[
0
,
0
,
img
.
shape
[
1
],
img
.
shape
[
0
],
0
,
0
,
0
]]).
astype
(
np
.
uint32
)))
dataVoc2
=
dataVoc2
.
map
(
input_columns
=
[
"image"
,
"annotation"
],
output_columns
=
[
"image"
,
"annotation"
],
columns_order
=
[
"image"
,
"annotation"
],
operations
=
[
test_op
])
filename
=
"bounding_box_augment_valid_edge_c_result.npz"
save_and_check_md5
(
dataVoc2
,
filename
,
generate_golden
=
GENERATE_GOLDEN
)
unaugSamp
,
augSamp
=
[],
[]
for
unAug
,
Aug
in
zip
(
dataVoc1
.
create_dict_iterator
(),
dataVoc2
.
create_dict_iterator
()):
unaugSamp
.
append
(
unAug
)
augSamp
.
append
(
Aug
)
if
plot_vis
:
visualize_with_bounding_boxes
(
unaugSamp
,
augSamp
)
# Restore config setting
ds
.
config
.
set_seed
(
original_seed
)
ds
.
config
.
set_num_parallel_workers
(
original_num_parallel_workers
)
def
test_bounding_box_augment_invalid_ratio_c
():
def
test_bounding_box_augment_invalid_ratio_c
():
"""
"""
Test
RandomHorizontalFlipWithBBox op with invalid input probability
Test
BoundingBoxAugment op with invalid input ratio
"""
"""
logger
.
info
(
"test_bounding_box_augment_invalid_ratio_c"
)
logger
.
info
(
"test_bounding_box_augment_invalid_ratio_c"
)
data_voc1
=
ds
.
VOCDataset
(
DATA_DIR
,
task
=
"Detection"
,
mode
=
"train"
,
decode
=
True
,
shuffle
=
False
)
dataVoc2
=
ds
.
VOCDataset
(
DATA_DIR
,
task
=
"Detection"
,
mode
=
"train"
,
decode
=
True
,
shuffle
=
False
)
data_voc2
=
ds
.
VOCDataset
(
DATA_DIR
,
task
=
"Detection"
,
mode
=
"train"
,
decode
=
True
,
shuffle
=
False
)
try
:
try
:
# ratio range is from 0 - 1
# ratio range is from 0 - 1
test_op
=
c_vision
.
BoundingBoxAugment
(
c_vision
.
RandomHorizontalFlip
(
1
),
1.5
)
test_op
=
c_vision
.
BoundingBoxAugment
(
c_vision
.
RandomHorizontalFlip
(
1
),
1.5
)
# maps to fix annotations to minddata standard
# maps to fix annotations to minddata standard
data_voc1
=
data_voc1
.
map
(
input_columns
=
[
"annotation"
],
dataVoc2
=
dataVoc2
.
map
(
input_columns
=
[
"annotation"
],
output_columns
=
[
"annotation"
],
output_columns
=
[
"annotation"
],
operations
=
fix_annotate
)
operations
=
fix_annotate
)
data_voc2
=
data_voc2
.
map
(
input_columns
=
[
"annotation"
],
output_columns
=
[
"annotation"
],
operations
=
fix_annotate
)
# map to apply ops
# map to apply ops
data
_voc2
=
data_v
oc2
.
map
(
input_columns
=
[
"image"
,
"annotation"
],
data
Voc2
=
dataV
oc2
.
map
(
input_columns
=
[
"image"
,
"annotation"
],
output_columns
=
[
"image"
,
"annotation"
],
output_columns
=
[
"image"
,
"annotation"
],
columns_order
=
[
"image"
,
"annotation"
],
columns_order
=
[
"image"
,
"annotation"
],
operations
=
[
test_op
])
# Add column for "annotation"
operations
=
[
test_op
])
# Add column for "annotation"
except
ValueError
as
error
:
except
ValueError
as
error
:
logger
.
info
(
"Got an exception in DE: {}"
.
format
(
str
(
error
)))
logger
.
info
(
"Got an exception in DE: {}"
.
format
(
str
(
error
)))
assert
"Input is not"
in
str
(
error
)
assert
"Input is not"
in
str
(
error
)
...
@@ -275,20 +265,24 @@ def test_bounding_box_augment_invalid_bounds_c():
...
@@ -275,20 +265,24 @@ def test_bounding_box_augment_invalid_bounds_c():
"""
"""
logger
.
info
(
"test_bounding_box_augment_invalid_bounds_c"
)
logger
.
info
(
"test_bounding_box_augment_invalid_bounds_c"
)
data_voc2
=
ds
.
VOCDataset
(
DATA_DIR
,
task
=
"Detection"
,
mode
=
"train"
,
decode
=
True
,
shuffle
=
False
)
test_op
=
c_vision
.
BoundingBoxAugment
(
c_vision
.
RandomHorizontalFlip
(
1
),
check_bad_box
(
data_voc2
,
BoxType
.
WidthOverflow
,
"bounding boxes is out of bounds of the image"
)
1
)
data_voc2
=
ds
.
VOCDataset
(
DATA_DIR
,
task
=
"Detection"
,
mode
=
"train"
,
decode
=
True
,
shuffle
=
False
)
check_bad_box
(
data_voc2
,
BoxType
.
HeightOverflow
,
"bounding boxes is out of bounds of the image"
)
dataVoc2
=
ds
.
VOCDataset
(
DATA_DIR
,
task
=
"Detection"
,
mode
=
"train"
,
decode
=
True
,
shuffle
=
False
)
data_voc2
=
ds
.
VOCDataset
(
DATA_DIR
,
task
=
"Detection"
,
mode
=
"train"
,
decode
=
True
,
shuffle
=
False
)
check_bad_bbox
(
dataVoc2
,
test_op
,
InvalidBBoxType
.
WidthOverflow
,
"bounding boxes is out of bounds of the image"
)
check_bad_box
(
data_voc2
,
BoxType
.
NegativeXY
,
"min_x"
)
dataVoc2
=
ds
.
VOCDataset
(
DATA_DIR
,
task
=
"Detection"
,
mode
=
"train"
,
decode
=
True
,
shuffle
=
False
)
data_voc2
=
ds
.
VOCDataset
(
DATA_DIR
,
task
=
"Detection"
,
mode
=
"train"
,
decode
=
True
,
shuffle
=
False
)
check_bad_bbox
(
dataVoc2
,
test_op
,
InvalidBBoxType
.
HeightOverflow
,
"bounding boxes is out of bounds of the image"
)
check_bad_box
(
data_voc2
,
BoxType
.
WrongShape
,
"4 features"
)
dataVoc2
=
ds
.
VOCDataset
(
DATA_DIR
,
task
=
"Detection"
,
mode
=
"train"
,
decode
=
True
,
shuffle
=
False
)
check_bad_bbox
(
dataVoc2
,
test_op
,
InvalidBBoxType
.
NegativeXY
,
"min_x"
)
dataVoc2
=
ds
.
VOCDataset
(
DATA_DIR
,
task
=
"Detection"
,
mode
=
"train"
,
decode
=
True
,
shuffle
=
False
)
check_bad_bbox
(
dataVoc2
,
test_op
,
InvalidBBoxType
.
WrongShape
,
"4 features"
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
# set to false to not show plots
# set to false to not show plots
test_bounding_box_augment_with_rotation_op
(
False
)
test_bounding_box_augment_with_rotation_op
(
plot_vis
=
False
)
test_bounding_box_augment_with_crop_op
(
False
)
test_bounding_box_augment_with_crop_op
(
plot_vis
=
False
)
test_bounding_box_augment_valid_ratio_c
(
False
)
test_bounding_box_augment_valid_ratio_c
(
plot_vis
=
False
)
test_bounding_box_augment_valid_edge_c
(
plot_vis
=
False
)
test_bounding_box_augment_invalid_ratio_c
()
test_bounding_box_augment_invalid_ratio_c
()
test_bounding_box_augment_invalid_bounds_c
()
test_bounding_box_augment_invalid_bounds_c
()
tests/ut/python/dataset/test_center_crop.py
浏览文件 @
5a886794
...
@@ -109,23 +109,18 @@ def test_center_crop_comp(height=375, width=375, plot=False):
...
@@ -109,23 +109,18 @@ def test_center_crop_comp(height=375, width=375, plot=False):
visualize_list
(
image_c_cropped
,
image_py_cropped
,
visualize_mode
=
2
)
visualize_list
(
image_c_cropped
,
image_py_cropped
,
visualize_mode
=
2
)
# pylint: disable=unnecessary-lambda
def
test_crop_grayscale
(
height
=
375
,
width
=
375
):
def
test_crop_grayscale
(
height
=
375
,
width
=
375
):
"""
"""
Test that centercrop works with pad and grayscale images
Test that centercrop works with pad and grayscale images
"""
"""
def
channel_swap
(
image
):
# Note: image.transpose performs channel swap to allow py transforms to
"""
# work with c transforms
Py func hack for our pytransforms to work with c transforms
"""
return
(
image
.
transpose
(
1
,
2
,
0
)
*
255
).
astype
(
np
.
uint8
)
transforms
=
[
transforms
=
[
py_vision
.
Decode
(),
py_vision
.
Decode
(),
py_vision
.
Grayscale
(
1
),
py_vision
.
Grayscale
(
1
),
py_vision
.
ToTensor
(),
py_vision
.
ToTensor
(),
(
lambda
image
:
channel_swap
(
image
))
(
lambda
image
:
(
image
.
transpose
(
1
,
2
,
0
)
*
255
).
astype
(
np
.
uint8
))
]
]
transform
=
py_vision
.
ComposeOp
(
transforms
)
transform
=
py_vision
.
ComposeOp
(
transforms
)
...
...
tests/ut/python/dataset/test_config.py
浏览文件 @
5a886794
...
@@ -37,6 +37,7 @@ def test_basic():
...
@@ -37,6 +37,7 @@ def test_basic():
num_parallel_workers_original
=
ds
.
config
.
get_num_parallel_workers
()
num_parallel_workers_original
=
ds
.
config
.
get_num_parallel_workers
()
prefetch_size_original
=
ds
.
config
.
get_prefetch_size
()
prefetch_size_original
=
ds
.
config
.
get_prefetch_size
()
seed_original
=
ds
.
config
.
get_seed
()
seed_original
=
ds
.
config
.
get_seed
()
monitor_sampling_interval_original
=
ds
.
config
.
get_monitor_sampling_interval
()
ds
.
config
.
load
(
'../data/dataset/declient.cfg'
)
ds
.
config
.
load
(
'../data/dataset/declient.cfg'
)
...
@@ -45,23 +46,27 @@ def test_basic():
...
@@ -45,23 +46,27 @@ def test_basic():
# assert ds.config.get_worker_connector_size() == 16
# assert ds.config.get_worker_connector_size() == 16
assert
ds
.
config
.
get_prefetch_size
()
==
16
assert
ds
.
config
.
get_prefetch_size
()
==
16
assert
ds
.
config
.
get_seed
()
==
5489
assert
ds
.
config
.
get_seed
()
==
5489
# assert ds.config.get_monitor_sampling_interval() == 15
# ds.config.set_rows_per_buffer(1)
# ds.config.set_rows_per_buffer(1)
ds
.
config
.
set_num_parallel_workers
(
2
)
ds
.
config
.
set_num_parallel_workers
(
2
)
# ds.config.set_worker_connector_size(3)
# ds.config.set_worker_connector_size(3)
ds
.
config
.
set_prefetch_size
(
4
)
ds
.
config
.
set_prefetch_size
(
4
)
ds
.
config
.
set_seed
(
5
)
ds
.
config
.
set_seed
(
5
)
ds
.
config
.
set_monitor_sampling_interval
(
45
)
# assert ds.config.get_rows_per_buffer() == 1
# assert ds.config.get_rows_per_buffer() == 1
assert
ds
.
config
.
get_num_parallel_workers
()
==
2
assert
ds
.
config
.
get_num_parallel_workers
()
==
2
# assert ds.config.get_worker_connector_size() == 3
# assert ds.config.get_worker_connector_size() == 3
assert
ds
.
config
.
get_prefetch_size
()
==
4
assert
ds
.
config
.
get_prefetch_size
()
==
4
assert
ds
.
config
.
get_seed
()
==
5
assert
ds
.
config
.
get_seed
()
==
5
assert
ds
.
config
.
get_monitor_sampling_interval
()
==
45
# Restore original configuration values
# Restore original configuration values
ds
.
config
.
set_num_parallel_workers
(
num_parallel_workers_original
)
ds
.
config
.
set_num_parallel_workers
(
num_parallel_workers_original
)
ds
.
config
.
set_prefetch_size
(
prefetch_size_original
)
ds
.
config
.
set_prefetch_size
(
prefetch_size_original
)
ds
.
config
.
set_seed
(
seed_original
)
ds
.
config
.
set_seed
(
seed_original
)
ds
.
config
.
set_monitor_sampling_interval
(
monitor_sampling_interval_original
)
def
test_get_seed
():
def
test_get_seed
():
...
@@ -150,7 +155,7 @@ def test_deterministic_run_fail():
...
@@ -150,7 +155,7 @@ def test_deterministic_run_fail():
def
test_deterministic_run_pass
():
def
test_deterministic_run_pass
():
"""
"""
Test deterministic run with
with
setting the seed
Test deterministic run with setting the seed
"""
"""
logger
.
info
(
"test_deterministic_run_pass"
)
logger
.
info
(
"test_deterministic_run_pass"
)
...
...
tests/ut/python/dataset/test_filterop.py
浏览文件 @
5a886794
...
@@ -50,9 +50,7 @@ def test_diff_predicate_func():
...
@@ -50,9 +50,7 @@ def test_diff_predicate_func():
def
filter_func_ge
(
data
):
def
filter_func_ge
(
data
):
if
data
>
10
:
return
data
<=
10
return
False
return
True
def
generator_1d
():
def
generator_1d
():
...
@@ -108,15 +106,11 @@ def test_filter_by_generator_with_repeat_after():
...
@@ -108,15 +106,11 @@ def test_filter_by_generator_with_repeat_after():
def
filter_func_batch
(
data
):
def
filter_func_batch
(
data
):
if
data
[
0
]
>
8
:
return
data
[
0
]
<=
8
return
False
return
True
def
filter_func_batch_after
(
data
):
def
filter_func_batch_after
(
data
):
if
data
>
20
:
return
data
<=
20
return
False
return
True
# test with batchOp before
# test with batchOp before
...
@@ -152,9 +146,7 @@ def test_filter_by_generator_with_batch_after():
...
@@ -152,9 +146,7 @@ def test_filter_by_generator_with_batch_after():
def
filter_func_shuffle
(
data
):
def
filter_func_shuffle
(
data
):
if
data
>
20
:
return
data
<=
20
return
False
return
True
# test with batchOp before
# test with batchOp before
...
@@ -169,9 +161,7 @@ def test_filter_by_generator_with_shuffle():
...
@@ -169,9 +161,7 @@ def test_filter_by_generator_with_shuffle():
def
filter_func_shuffle_after
(
data
):
def
filter_func_shuffle_after
(
data
):
if
data
>
20
:
return
data
<=
20
return
False
return
True
# test with batchOp after
# test with batchOp after
...
@@ -197,15 +187,11 @@ def generator_1d_zip2():
...
@@ -197,15 +187,11 @@ def generator_1d_zip2():
def
filter_func_zip
(
data1
,
data2
):
def
filter_func_zip
(
data1
,
data2
):
_
=
data2
_
=
data2
if
data1
>
20
:
return
data1
<=
20
return
False
return
True
def
filter_func_zip_after
(
data1
):
def
filter_func_zip_after
(
data1
):
if
data1
>
20
:
return
data1
<=
20
return
False
return
True
# test with zipOp before
# test with zipOp before
...
@@ -247,16 +233,11 @@ def test_filter_by_generator_with_zip_after():
...
@@ -247,16 +233,11 @@ def test_filter_by_generator_with_zip_after():
def
filter_func_map
(
col1
,
col2
):
def
filter_func_map
(
col1
,
col2
):
_
=
col2
_
=
col2
if
col1
[
0
]
>
8
:
return
col1
[
0
]
>
8
return
True
return
False
# pylint: disable=simplifiable-if-statement
def
filter_func_map_part
(
col1
):
def
filter_func_map_part
(
col1
):
if
col1
<
3
:
return
col1
<
3
return
True
return
False
def
filter_func_map_all
(
col1
,
col2
):
def
filter_func_map_all
(
col1
,
col2
):
...
@@ -311,9 +292,7 @@ def test_filter_by_generator_with_map_part_col():
...
@@ -311,9 +292,7 @@ def test_filter_by_generator_with_map_part_col():
def
filter_func_rename
(
data
):
def
filter_func_rename
(
data
):
if
data
>
8
:
return
data
>
8
return
True
return
False
# test with rename before
# test with rename before
...
@@ -334,15 +313,11 @@ def test_filter_by_generator_with_rename():
...
@@ -334,15 +313,11 @@ def test_filter_by_generator_with_rename():
# test input_column
# test input_column
def
filter_func_input_column1
(
col1
,
col2
):
def
filter_func_input_column1
(
col1
,
col2
):
_
=
col2
_
=
col2
if
col1
[
0
]
<
8
:
return
col1
[
0
]
<
8
return
True
return
False
def
filter_func_input_column2
(
col1
):
def
filter_func_input_column2
(
col1
):
if
col1
[
0
]
<
8
:
return
col1
[
0
]
<
8
return
True
return
False
def
filter_func_input_column3
(
col1
):
def
filter_func_input_column3
(
col1
):
...
@@ -439,9 +414,7 @@ def test_filter_by_generator_Partial2():
...
@@ -439,9 +414,7 @@ def test_filter_by_generator_Partial2():
def
filter_func_Partial
(
col1
,
col2
):
def
filter_func_Partial
(
col1
,
col2
):
_
=
col2
_
=
col2
if
col1
[
0
]
%
3
==
0
:
return
col1
[
0
]
%
3
==
0
return
True
return
False
def
generator_big
(
maxid
=
20
):
def
generator_big
(
maxid
=
20
):
...
@@ -461,9 +434,7 @@ def test_filter_by_generator_Partial():
...
@@ -461,9 +434,7 @@ def test_filter_by_generator_Partial():
def
filter_func_cifar
(
col1
,
col2
):
def
filter_func_cifar
(
col1
,
col2
):
_
=
col1
_
=
col1
if
col2
%
3
==
0
:
return
col2
%
3
==
0
return
True
return
False
# test with cifar10
# test with cifar10
...
...
tests/ut/python/dataset/test_pad.py
浏览文件 @
5a886794
...
@@ -16,12 +16,12 @@
...
@@ -16,12 +16,12 @@
Testing Pad op in DE
Testing Pad op in DE
"""
"""
import
numpy
as
np
import
numpy
as
np
from
util
import
diff_mse
import
mindspore.dataset
as
ds
import
mindspore.dataset
as
ds
import
mindspore.dataset.transforms.vision.c_transforms
as
c_vision
import
mindspore.dataset.transforms.vision.c_transforms
as
c_vision
import
mindspore.dataset.transforms.vision.py_transforms
as
py_vision
import
mindspore.dataset.transforms.vision.py_transforms
as
py_vision
from
mindspore
import
log
as
logger
from
mindspore
import
log
as
logger
from
util
import
diff_mse
DATA_DIR
=
[
"../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"
]
DATA_DIR
=
[
"../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"
]
SCHEMA_DIR
=
"../data/dataset/test_tf_file_3_images/datasetSchema.json"
SCHEMA_DIR
=
"../data/dataset/test_tf_file_3_images/datasetSchema.json"
...
@@ -69,23 +69,19 @@ def test_pad_op():
...
@@ -69,23 +69,19 @@ def test_pad_op():
assert
mse
<
0.01
assert
mse
<
0.01
# pylint: disable=unnecessary-lambda
def
test_pad_grayscale
():
def
test_pad_grayscale
():
"""
"""
Tests that the pad works for grayscale images
Tests that the pad works for grayscale images
"""
"""
def
channel_swap
(
image
):
# Note: image.transpose performs channel swap to allow py transforms to
"""
# work with c transforms
Py func hack for our pytransforms to work with c transforms
"""
return
(
image
.
transpose
(
1
,
2
,
0
)
*
255
).
astype
(
np
.
uint8
)
transforms
=
[
transforms
=
[
py_vision
.
Decode
(),
py_vision
.
Decode
(),
py_vision
.
Grayscale
(
1
),
py_vision
.
Grayscale
(
1
),
py_vision
.
ToTensor
(),
py_vision
.
ToTensor
(),
(
lambda
image
:
channel_swap
(
image
))
(
lambda
image
:
(
image
.
transpose
(
1
,
2
,
0
)
*
255
).
astype
(
np
.
uint8
))
]
]
transform
=
py_vision
.
ComposeOp
(
transforms
)
transform
=
py_vision
.
ComposeOp
(
transforms
)
...
...
tests/ut/python/dataset/test_random_crop_and_resize_with_bbox.py
浏览文件 @
5a886794
...
@@ -13,17 +13,17 @@
...
@@ -13,17 +13,17 @@
# limitations under the License.
# limitations under the License.
# ==============================================================================
# ==============================================================================
"""
"""
Testing RandomCropAndResizeWithBBox op
Testing RandomCropAndResizeWithBBox op
in DE
"""
"""
import
numpy
as
np
import
numpy
as
np
import
matplotlib.pyplot
as
plt
import
matplotlib.patches
as
patches
import
mindspore.dataset
as
ds
import
mindspore.dataset
as
ds
import
mindspore.dataset.transforms.vision.c_transforms
as
c_vision
import
mindspore.dataset.transforms.vision.c_transforms
as
c_vision
from
mindspore
import
log
as
logger
from
mindspore
import
log
as
logger
from
util
import
visualize_with_bounding_boxes
,
InvalidBBoxType
,
check_bad_bbox
,
\
config_get_set_seed
,
config_get_set_num_parallel_workers
,
save_and_check_md5
GENERATE_GOLDEN
=
False
# updated VOC dataset with correct annotations
# updated VOC dataset with correct annotations
DATA_DIR
=
"../data/dataset/testVOC2012_2"
DATA_DIR
=
"../data/dataset/testVOC2012_2"
...
@@ -31,8 +31,7 @@ DATA_DIR = "../data/dataset/testVOC2012_2"
...
@@ -31,8 +31,7 @@ DATA_DIR = "../data/dataset/testVOC2012_2"
def
fix_annotate
(
bboxes
):
def
fix_annotate
(
bboxes
):
"""
"""
Update Current VOC dataset format to Proposed HQ BBox format
Fix annotations to format followed by mindspore.
:param bboxes: in [label, x_min, y_min, w, h, truncate, difficult] format
:param bboxes: in [label, x_min, y_min, w, h, truncate, difficult] format
:return: annotation in [x_min, y_min, w, h, label, truncate, difficult] format
:return: annotation in [x_min, y_min, w, h, label, truncate, difficult] format
"""
"""
...
@@ -46,112 +45,22 @@ def fix_annotate(bboxes):
...
@@ -46,112 +45,22 @@ def fix_annotate(bboxes):
return
bboxes
return
bboxes
def
add_bounding_boxes
(
ax
,
bboxes
):
def
test_random_resized_crop_with_bbox_op_c
(
plot_vis
=
False
):
for
bbox
in
bboxes
:
rect
=
patches
.
Rectangle
((
bbox
[
0
],
bbox
[
1
]),
bbox
[
2
],
bbox
[
3
],
linewidth
=
1
,
edgecolor
=
'r'
,
facecolor
=
'none'
)
# Add the patch to the Axes
ax
.
add_patch
(
rect
)
def
vis_check
(
orig
,
aug
):
if
not
isinstance
(
orig
,
list
)
or
not
isinstance
(
aug
,
list
):
return
False
if
len
(
orig
)
!=
len
(
aug
):
return
False
return
True
def
visualize
(
orig
,
aug
):
if
not
vis_check
(
orig
,
aug
):
return
plotrows
=
3
compset
=
int
(
len
(
orig
)
/
plotrows
)
orig
,
aug
=
np
.
array
(
orig
),
np
.
array
(
aug
)
orig
=
np
.
split
(
orig
[:
compset
*
plotrows
],
compset
)
+
[
orig
[
compset
*
plotrows
:]]
aug
=
np
.
split
(
aug
[:
compset
*
plotrows
],
compset
)
+
[
aug
[
compset
*
plotrows
:]]
for
ix
,
allData
in
enumerate
(
zip
(
orig
,
aug
)):
base_ix
=
ix
*
plotrows
# will signal what base level we're on
fig
,
axs
=
plt
.
subplots
(
len
(
allData
[
0
]),
2
)
fig
.
tight_layout
(
pad
=
1.5
)
for
x
,
(
dataA
,
dataB
)
in
enumerate
(
zip
(
allData
[
0
],
allData
[
1
])):
cur_ix
=
base_ix
+
x
axs
[
x
,
0
].
imshow
(
dataA
[
"image"
])
add_bounding_boxes
(
axs
[
x
,
0
],
dataA
[
"annotation"
])
axs
[
x
,
0
].
title
.
set_text
(
"Original"
+
str
(
cur_ix
+
1
))
print
(
"Original **
\n
"
,
str
(
cur_ix
+
1
),
" :"
,
dataA
[
"annotation"
])
axs
[
x
,
1
].
imshow
(
dataB
[
"image"
])
add_bounding_boxes
(
axs
[
x
,
1
],
dataB
[
"annotation"
])
axs
[
x
,
1
].
title
.
set_text
(
"Augmented"
+
str
(
cur_ix
+
1
))
print
(
"Augmented **
\n
"
,
str
(
cur_ix
+
1
),
" "
,
dataB
[
"annotation"
],
"
\n
"
)
plt
.
show
()
# Functions to pass to Gen for creating invalid bounding boxes
def
gen_bad_bbox_neg_xy
(
im
,
bbox
):
im_h
,
im_w
=
im
.
shape
[
0
],
im
.
shape
[
1
]
bbox
[
0
][:
4
]
=
[
-
50
,
-
50
,
im_w
-
10
,
im_h
-
10
]
return
im
,
bbox
def
gen_bad_bbox_overflow_width
(
im
,
bbox
):
im_h
,
im_w
=
im
.
shape
[
0
],
im
.
shape
[
1
]
bbox
[
0
][:
4
]
=
[
0
,
0
,
im_w
+
10
,
im_h
-
10
]
return
im
,
bbox
def
gen_bad_bbox_overflow_height
(
im
,
bbox
):
im_h
,
im_w
=
im
.
shape
[
0
],
im
.
shape
[
1
]
bbox
[
0
][:
4
]
=
[
0
,
0
,
im_w
-
10
,
im_h
+
10
]
return
im
,
bbox
def
gen_bad_bbox_wrong_shape
(
im
,
bbox
):
bbox
=
np
.
array
([[
0
,
0
,
0
]]).
astype
(
bbox
.
dtype
)
return
im
,
bbox
badGenFuncs
=
[
gen_bad_bbox_neg_xy
,
gen_bad_bbox_overflow_width
,
gen_bad_bbox_overflow_height
,
gen_bad_bbox_wrong_shape
]
assertVal
=
[
"min_x"
,
"is out of bounds of the image"
,
"is out of bounds of the image"
,
"4 features"
]
# Gen Edge case BBox
def
gen_bbox_edge
(
im
,
bbox
):
im_h
,
im_w
=
im
.
shape
[
0
],
im
.
shape
[
1
]
bbox
[
0
][:
4
]
=
[
0
,
0
,
im_w
,
im_h
]
return
im
,
bbox
def
test_c_random_resized_crop_with_bbox_op
(
plot_vis
=
False
):
"""
"""
Prints images side by side with and without Aug applied + bboxes to compare and test
Prints images and bboxes side by side with and without RandomResizedCropWithBBox Op applied,
tests with MD5 check, expected to pass
"""
"""
logger
.
info
(
"test_random_resized_crop_with_bbox_op_c"
)
original_seed
=
config_get_set_seed
(
23415
)
original_num_parallel_workers
=
config_get_set_num_parallel_workers
(
1
)
# Load dataset
# Load dataset
dataVoc1
=
ds
.
VOCDataset
(
DATA_DIR
,
task
=
"Detection"
,
mode
=
"train"
,
decode
=
True
,
shuffle
=
False
)
dataVoc1
=
ds
.
VOCDataset
(
DATA_DIR
,
task
=
"Detection"
,
mode
=
"train"
,
decode
=
True
,
shuffle
=
False
)
dataVoc2
=
ds
.
VOCDataset
(
DATA_DIR
,
task
=
"Detection"
,
mode
=
"train"
,
decode
=
True
,
shuffle
=
False
)
dataVoc2
=
ds
.
VOCDataset
(
DATA_DIR
,
task
=
"Detection"
,
mode
=
"train"
,
decode
=
True
,
shuffle
=
False
)
test_op
=
c_vision
.
RandomResizedCropWithBBox
((
256
,
512
),
(
0.5
,
0.5
),
(
0.5
,
0.5
))
test_op
=
c_vision
.
RandomResizedCropWithBBox
((
256
,
512
),
(
0.5
,
0.5
),
(
0.5
,
0.5
))
# maps to fix annotations to HQ standard
dataVoc1
=
dataVoc1
.
map
(
input_columns
=
[
"annotation"
],
dataVoc1
=
dataVoc1
.
map
(
input_columns
=
[
"annotation"
],
output_columns
=
[
"annotation"
],
output_columns
=
[
"annotation"
],
operations
=
fix_annotate
)
operations
=
fix_annotate
)
...
@@ -164,6 +73,9 @@ def test_c_random_resized_crop_with_bbox_op(plot_vis=False):
...
@@ -164,6 +73,9 @@ def test_c_random_resized_crop_with_bbox_op(plot_vis=False):
columns_order
=
[
"image"
,
"annotation"
],
columns_order
=
[
"image"
,
"annotation"
],
operations
=
[
test_op
])
# Add column for "annotation"
operations
=
[
test_op
])
# Add column for "annotation"
filename
=
"random_resized_crop_with_bbox_01_c_result.npz"
save_and_check_md5
(
dataVoc2
,
filename
,
generate_golden
=
GENERATE_GOLDEN
)
unaugSamp
,
augSamp
=
[],
[]
unaugSamp
,
augSamp
=
[],
[]
for
unAug
,
Aug
in
zip
(
dataVoc1
.
create_dict_iterator
(),
dataVoc2
.
create_dict_iterator
()):
for
unAug
,
Aug
in
zip
(
dataVoc1
.
create_dict_iterator
(),
dataVoc2
.
create_dict_iterator
()):
...
@@ -171,20 +83,26 @@ def test_c_random_resized_crop_with_bbox_op(plot_vis=False):
...
@@ -171,20 +83,26 @@ def test_c_random_resized_crop_with_bbox_op(plot_vis=False):
augSamp
.
append
(
Aug
)
augSamp
.
append
(
Aug
)
if
plot_vis
:
if
plot_vis
:
visualize
(
unaugSamp
,
augSamp
)
visualize_with_bounding_boxes
(
unaugSamp
,
augSamp
)
# Restore config setting
ds
.
config
.
set_seed
(
original_seed
)
ds
.
config
.
set_num_parallel_workers
(
original_num_parallel_workers
)
def
test_
c_random_resized_crop_with_bbox_op_edge
(
plot_vis
=
False
):
def
test_
random_resized_crop_with_bbox_op_edge_c
(
plot_vis
=
False
):
"""
"""
Prints images side by side with and without Aug applied + bboxes to compare and test
Prints images and bboxes side by side with and without RandomResizedCropWithBBox Op applied,
tests on dynamically generated edge case, expected to pass
"""
"""
logger
.
info
(
"test_random_resized_crop_with_bbox_op_edge_c"
)
# Load dataset
# Load dataset
dataVoc1
=
ds
.
VOCDataset
(
DATA_DIR
,
task
=
"Detection"
,
mode
=
"train"
,
decode
=
True
,
shuffle
=
False
)
dataVoc1
=
ds
.
VOCDataset
(
DATA_DIR
,
task
=
"Detection"
,
mode
=
"train"
,
decode
=
True
,
shuffle
=
False
)
dataVoc2
=
ds
.
VOCDataset
(
DATA_DIR
,
task
=
"Detection"
,
mode
=
"train"
,
decode
=
True
,
shuffle
=
False
)
dataVoc2
=
ds
.
VOCDataset
(
DATA_DIR
,
task
=
"Detection"
,
mode
=
"train"
,
decode
=
True
,
shuffle
=
False
)
test_op
=
c_vision
.
RandomResizedCropWithBBox
((
256
,
512
),
(
0.5
,
0.5
),
(
0.5
,
0.5
))
test_op
=
c_vision
.
RandomResizedCropWithBBox
((
256
,
512
),
(
0.5
,
0.5
),
(
0.5
,
0.5
))
# maps to fix annotations to HQ standard
dataVoc1
=
dataVoc1
.
map
(
input_columns
=
[
"annotation"
],
dataVoc1
=
dataVoc1
.
map
(
input_columns
=
[
"annotation"
],
output_columns
=
[
"annotation"
],
output_columns
=
[
"annotation"
],
operations
=
fix_annotate
)
operations
=
fix_annotate
)
...
@@ -192,17 +110,17 @@ def test_c_random_resized_crop_with_bbox_op_edge(plot_vis=False):
...
@@ -192,17 +110,17 @@ def test_c_random_resized_crop_with_bbox_op_edge(plot_vis=False):
output_columns
=
[
"annotation"
],
output_columns
=
[
"annotation"
],
operations
=
fix_annotate
)
operations
=
fix_annotate
)
#
Modify BBoxes to serve as valid edge cases
#
maps to convert data into valid edge case data
dataVoc
2
=
dataVoc2
.
map
(
input_columns
=
[
"image"
,
"annotation"
],
dataVoc
1
=
dataVoc1
.
map
(
input_columns
=
[
"image"
,
"annotation"
],
output_columns
=
[
"image"
,
"annotation"
],
output_columns
=
[
"image"
,
"annotation"
],
columns_order
=
[
"image"
,
"annotation"
],
columns_order
=
[
"image"
,
"annotation"
],
operations
=
[
gen_bbox_edge
])
operations
=
[
lambda
img
,
bboxes
:
(
img
,
np
.
array
([[
0
,
0
,
img
.
shape
[
1
],
img
.
shape
[
0
]]]).
astype
(
bboxes
.
dtype
))
])
#
map to apply ops
#
Test Op added to list of Operations here
dataVoc2
=
dataVoc2
.
map
(
input_columns
=
[
"image"
,
"annotation"
],
dataVoc2
=
dataVoc2
.
map
(
input_columns
=
[
"image"
,
"annotation"
],
output_columns
=
[
"image"
,
"annotation"
],
output_columns
=
[
"image"
,
"annotation"
],
columns_order
=
[
"image"
,
"annotation"
],
columns_order
=
[
"image"
,
"annotation"
],
operations
=
[
test_op
])
# Add column for "annotation"
operations
=
[
lambda
img
,
bboxes
:
(
img
,
np
.
array
([[
0
,
0
,
img
.
shape
[
1
],
img
.
shape
[
0
]]]).
astype
(
bboxes
.
dtype
)),
test_op
])
unaugSamp
,
augSamp
=
[],
[]
unaugSamp
,
augSamp
=
[],
[]
...
@@ -211,21 +129,22 @@ def test_c_random_resized_crop_with_bbox_op_edge(plot_vis=False):
...
@@ -211,21 +129,22 @@ def test_c_random_resized_crop_with_bbox_op_edge(plot_vis=False):
augSamp
.
append
(
Aug
)
augSamp
.
append
(
Aug
)
if
plot_vis
:
if
plot_vis
:
visualize
(
unaugSamp
,
augSamp
)
visualize
_with_bounding_boxes
(
unaugSamp
,
augSamp
)
def
test_
c_random_resized_crop_with_bbox_op_invalid
():
def
test_
random_resized_crop_with_bbox_op_invalid_c
():
"""
"""
Prints images side by side with and without Aug applied + bboxes to compare and test
Tests RandomResizedCropWithBBox on invalid constructor parameters, expected to raise ValueError
"""
"""
# Load dataset # only loading the to AugDataset as test will fail on this
logger
.
info
(
"test_random_resized_crop_with_bbox_op_invalid_c"
)
# Load dataset, only Augmented Dataset as test will raise ValueError
dataVoc2
=
ds
.
VOCDataset
(
DATA_DIR
,
task
=
"Detection"
,
mode
=
"train"
,
decode
=
True
,
shuffle
=
False
)
dataVoc2
=
ds
.
VOCDataset
(
DATA_DIR
,
task
=
"Detection"
,
mode
=
"train"
,
decode
=
True
,
shuffle
=
False
)
try
:
try
:
# If input range of scale is not in the order of (min, max), ValueError will be raised.
# If input range of scale is not in the order of (min, max), ValueError will be raised.
test_op
=
c_vision
.
RandomResizedCropWithBBox
((
256
,
512
),
(
1
,
0.5
),
(
0.5
,
0.5
))
test_op
=
c_vision
.
RandomResizedCropWithBBox
((
256
,
512
),
(
1
,
0.5
),
(
0.5
,
0.5
))
# maps to fix annotations to HQ standard
dataVoc2
=
dataVoc2
.
map
(
input_columns
=
[
"annotation"
],
dataVoc2
=
dataVoc2
.
map
(
input_columns
=
[
"annotation"
],
output_columns
=
[
"annotation"
],
output_columns
=
[
"annotation"
],
operations
=
fix_annotate
)
operations
=
fix_annotate
)
...
@@ -243,10 +162,11 @@ def test_c_random_resized_crop_with_bbox_op_invalid():
...
@@ -243,10 +162,11 @@ def test_c_random_resized_crop_with_bbox_op_invalid():
assert
"Input range is not valid"
in
str
(
err
)
assert
"Input range is not valid"
in
str
(
err
)
def
test_
c_random_resized_crop_with_bbox_op_invalid2
():
def
test_
random_resized_crop_with_bbox_op_invalid2_c
():
"""
"""
Prints images side by side with and without Aug applied + bboxes to compare and test
Tests RandomResizedCropWithBBox Op on invalid constructor parameters, expected to raise ValueError
"""
"""
logger
.
info
(
"test_random_resized_crop_with_bbox_op_invalid2_c"
)
# Load dataset # only loading the to AugDataset as test will fail on this
# Load dataset # only loading the to AugDataset as test will fail on this
dataVoc2
=
ds
.
VOCDataset
(
DATA_DIR
,
task
=
"Detection"
,
mode
=
"train"
,
decode
=
True
,
shuffle
=
False
)
dataVoc2
=
ds
.
VOCDataset
(
DATA_DIR
,
task
=
"Detection"
,
mode
=
"train"
,
decode
=
True
,
shuffle
=
False
)
...
@@ -254,7 +174,6 @@ def test_c_random_resized_crop_with_bbox_op_invalid2():
...
@@ -254,7 +174,6 @@ def test_c_random_resized_crop_with_bbox_op_invalid2():
# If input range of ratio is not in the order of (min, max), ValueError will be raised.
# If input range of ratio is not in the order of (min, max), ValueError will be raised.
test_op
=
c_vision
.
RandomResizedCropWithBBox
((
256
,
512
),
(
1
,
1
),
(
1
,
0.5
))
test_op
=
c_vision
.
RandomResizedCropWithBBox
((
256
,
512
),
(
1
,
1
),
(
1
,
0.5
))
# maps to fix annotations to HQ standard
dataVoc2
=
dataVoc2
.
map
(
input_columns
=
[
"annotation"
],
dataVoc2
=
dataVoc2
.
map
(
input_columns
=
[
"annotation"
],
output_columns
=
[
"annotation"
],
output_columns
=
[
"annotation"
],
operations
=
fix_annotate
)
operations
=
fix_annotate
)
...
@@ -272,41 +191,26 @@ def test_c_random_resized_crop_with_bbox_op_invalid2():
...
@@ -272,41 +191,26 @@ def test_c_random_resized_crop_with_bbox_op_invalid2():
assert
"Input range is not valid"
in
str
(
err
)
assert
"Input range is not valid"
in
str
(
err
)
def
test_c_random_resized_crop_with_bbox_op_bad
():
def
test_random_resized_crop_with_bbox_op_bad_c
():
# Should Fail - Errors logged to logger
"""
for
ix
,
badFunc
in
enumerate
(
badGenFuncs
):
Test RandomCropWithBBox op with invalid bounding boxes, expected to catch multiple errors.
try
:
"""
dataVoc2
=
ds
.
VOCDataset
(
DATA_DIR
,
task
=
"Detection"
,
mode
=
"train"
,
logger
.
info
(
"test_random_resized_crop_with_bbox_op_bad_c"
)
decode
=
True
,
shuffle
=
False
)
test_op
=
c_vision
.
RandomResizedCropWithBBox
((
256
,
512
),
(
0.5
,
0.5
),
(
0.5
,
0.5
))
test_op
=
c_vision
.
RandomVerticalFlipWithBBox
(
1
)
dataVoc2
=
dataVoc2
.
map
(
input_columns
=
[
"annotation"
],
output_columns
=
[
"annotation"
],
operations
=
fix_annotate
)
dataVoc2
=
dataVoc2
.
map
(
input_columns
=
[
"image"
,
"annotation"
],
output_columns
=
[
"image"
,
"annotation"
],
columns_order
=
[
"image"
,
"annotation"
],
operations
=
[
badFunc
])
# map to apply ops
dataVoc2
=
dataVoc2
.
map
(
input_columns
=
[
"image"
,
"annotation"
],
output_columns
=
[
"image"
,
"annotation"
],
columns_order
=
[
"image"
,
"annotation"
],
operations
=
[
test_op
])
for
_
in
dataVoc2
.
create_dict_iterator
():
break
# first sample will cause exception
except
RuntimeError
as
err
:
data_voc2
=
ds
.
VOCDataset
(
DATA_DIR
,
task
=
"Detection"
,
mode
=
"train"
,
decode
=
True
,
shuffle
=
False
)
logger
.
info
(
"Got an exception in DE: {}"
.
format
(
str
(
err
)))
check_bad_bbox
(
data_voc2
,
test_op
,
InvalidBBoxType
.
WidthOverflow
,
"bounding boxes is out of bounds of the image"
)
assert
assertVal
[
ix
]
in
str
(
err
)
data_voc2
=
ds
.
VOCDataset
(
DATA_DIR
,
task
=
"Detection"
,
mode
=
"train"
,
decode
=
True
,
shuffle
=
False
)
check_bad_bbox
(
data_voc2
,
test_op
,
InvalidBBoxType
.
HeightOverflow
,
"bounding boxes is out of bounds of the image"
)
data_voc2
=
ds
.
VOCDataset
(
DATA_DIR
,
task
=
"Detection"
,
mode
=
"train"
,
decode
=
True
,
shuffle
=
False
)
check_bad_bbox
(
data_voc2
,
test_op
,
InvalidBBoxType
.
NegativeXY
,
"min_x"
)
data_voc2
=
ds
.
VOCDataset
(
DATA_DIR
,
task
=
"Detection"
,
mode
=
"train"
,
decode
=
True
,
shuffle
=
False
)
check_bad_bbox
(
data_voc2
,
test_op
,
InvalidBBoxType
.
WrongShape
,
"4 features"
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
test_
c_random_resized_crop_with_bbox_op
(
plot_vis
=
True
)
test_
random_resized_crop_with_bbox_op_c
(
plot_vis
=
True
)
test_
c_random_resized_crop_with_bbox_op_edge
(
plot_vis
=
True
)
test_
random_resized_crop_with_bbox_op_edge_c
(
plot_vis
=
True
)
test_
c_random_resized_crop_with_bbox_op_invalid
()
test_
random_resized_crop_with_bbox_op_invalid_c
()
test_
c_random_resized_crop_with_bbox_op_invalid2
()
test_
random_resized_crop_with_bbox_op_invalid2_c
()
test_
c_random_resized_crop_with_bbox_op_bad
()
test_
random_resized_crop_with_bbox_op_bad_c
()
tests/ut/python/dataset/test_random_crop_with_bbox.py
浏览文件 @
5a886794
...
@@ -13,18 +13,18 @@
...
@@ -13,18 +13,18 @@
# limitations under the License.
# limitations under the License.
# ==============================================================================
# ==============================================================================
"""
"""
Testing RandomCropWithBBox op
Testing RandomCropWithBBox op
in DE
"""
"""
import
numpy
as
np
import
numpy
as
np
import
matplotlib.pyplot
as
plt
import
matplotlib.patches
as
patches
import
mindspore.dataset
as
ds
import
mindspore.dataset
as
ds
import
mindspore.dataset.transforms.vision.c_transforms
as
c_vision
import
mindspore.dataset.transforms.vision.c_transforms
as
c_vision
import
mindspore.dataset.transforms.vision.utils
as
mode
import
mindspore.dataset.transforms.vision.utils
as
mode
from
mindspore
import
log
as
logger
from
mindspore
import
log
as
logger
from
util
import
visualize_with_bounding_boxes
,
InvalidBBoxType
,
check_bad_bbox
,
\
config_get_set_seed
,
config_get_set_num_parallel_workers
,
save_and_check_md5
GENERATE_GOLDEN
=
False
# updated VOC dataset with correct annotations
# updated VOC dataset with correct annotations
DATA_DIR
=
"../data/dataset/testVOC2012_2"
DATA_DIR
=
"../data/dataset/testVOC2012_2"
...
@@ -32,8 +32,7 @@ DATA_DIR = "../data/dataset/testVOC2012_2"
...
@@ -32,8 +32,7 @@ DATA_DIR = "../data/dataset/testVOC2012_2"
def
fix_annotate
(
bboxes
):
def
fix_annotate
(
bboxes
):
"""
"""
Update Current VOC dataset format to Proposed HQ BBox format
Fix annotations to format followed by mindspore.
:param bboxes: in [label, x_min, y_min, w, h, truncate, difficult] format
:param bboxes: in [label, x_min, y_min, w, h, truncate, difficult] format
:return: annotation in [x_min, y_min, w, h, label, truncate, difficult] format
:return: annotation in [x_min, y_min, w, h, label, truncate, difficult] format
"""
"""
...
@@ -47,113 +46,19 @@ def fix_annotate(bboxes):
...
@@ -47,113 +46,19 @@ def fix_annotate(bboxes):
return
bboxes
return
bboxes
def add_bounding_boxes(ax, bboxes):
    """Draw every bounding box in *bboxes* onto the matplotlib axes *ax*.

    Boxes are expected in [x, y, w, h, ...] order; only the first four
    entries of each row are used.
    """
    for box in bboxes:
        x, y, w, h = box[0], box[1], box[2], box[3]
        # Add the patch to the Axes
        ax.add_patch(patches.Rectangle((x, y), w, h,
                                       linewidth=1, edgecolor='r',
                                       facecolor='none'))
def vis_check(orig, aug):
    """Return True when both arguments are lists of equal length.

    Used as a precondition check before plotting original/augmented pairs.
    """
    if not (isinstance(orig, list) and isinstance(aug, list)):
        return False
    return len(orig) == len(aug)
def visualize(orig, aug):
    """Plot original vs. augmented samples side by side with their bboxes.

    orig/aug: equal-length lists of dataset rows, each presumably a dict
    with "image" and "annotation" entries — TODO confirm against callers.
    Shows up to `plotrows` sample pairs per figure and also prints the
    annotations for manual inspection.
    """
    if not vis_check(orig, aug):
        return
    plotrows = 3
    # Number of complete groups of `plotrows` samples.
    compset = int(len(orig) / plotrows)
    orig, aug = np.array(orig), np.array(aug)
    # Split into `compset` full groups plus one (possibly empty) remainder.
    # NOTE(review): np.split raises when compset == 0 — this assumes at
    # least `plotrows` samples; confirm with callers.
    orig = np.split(orig[:compset * plotrows], compset) + [orig[compset * plotrows:]]
    aug = np.split(aug[:compset * plotrows], compset) + [aug[compset * plotrows:]]
    for ix, allData in enumerate(zip(orig, aug)):
        base_ix = ix * plotrows  # will signal what base level we're on
        # One figure per group: left column original, right column augmented.
        fig, axs = plt.subplots(len(allData[0]), 2)
        fig.tight_layout(pad=1.5)
        for x, (dataA, dataB) in enumerate(zip(allData[0], allData[1])):
            cur_ix = base_ix + x
            # Original sample with its bounding boxes overlaid.
            axs[x, 0].imshow(dataA["image"])
            add_bounding_boxes(axs[x, 0], dataA["annotation"])
            axs[x, 0].title.set_text("Original" + str(cur_ix + 1))
            print("Original **\n", str(cur_ix + 1), " :", dataA["annotation"])
            # Augmented sample with its bounding boxes overlaid.
            axs[x, 1].imshow(dataB["image"])
            add_bounding_boxes(axs[x, 1], dataB["annotation"])
            axs[x, 1].title.set_text("Augmented" + str(cur_ix + 1))
            print("Augmented **\n", str(cur_ix + 1), " ", dataB["annotation"], "\n")
    plt.show()
# Functions to pass to Gen for creating invalid bounding boxes
def gen_bad_bbox_neg_xy(im, bbox):
    """Corrupt the first bbox with negative x/y coordinates (invalid input).

    Modifies the bbox array in place and returns it with the image.
    """
    height, width = im.shape[0], im.shape[1]
    bbox[0][:4] = [-50, -50, width - 10, height - 10]
    return im, bbox
def gen_bad_bbox_overflow_width(im, bbox):
    """Corrupt the first bbox so its width overflows the image (invalid input).

    Modifies the bbox array in place and returns it with the image.
    """
    height, width = im.shape[0], im.shape[1]
    bbox[0][:4] = [0, 0, width + 10, height - 10]
    return im, bbox
def gen_bad_bbox_overflow_height(im, bbox):
    """Corrupt the first bbox so its height overflows the image (invalid input).

    Modifies the bbox array in place and returns it with the image.
    """
    height, width = im.shape[0], im.shape[1]
    bbox[0][:4] = [0, 0, width - 10, height + 10]
    return im, bbox
def gen_bad_bbox_wrong_shape(im, bbox):
    """Replace the annotation with a 3-feature row (invalid: 4 features required).

    The returned array is all zeros and keeps the incoming annotation dtype.
    """
    return im, np.zeros((1, 3), dtype=bbox.dtype)
# Corruption callables for negative bounding-box tests; each entry pairs
# positionally with the expected error substring in assertVal below.
badGenFuncs = [gen_bad_bbox_neg_xy,
               gen_bad_bbox_overflow_width,
               gen_bad_bbox_overflow_height,
               gen_bad_bbox_wrong_shape]

# Error-message substring expected when the generator at the same index
# in badGenFuncs is applied.
assertVal = ["min_x",
             "is out of bounds of the image",
             "is out of bounds of the image",
             "4 features"]
# Gen Edge case BBox
def gen_bbox_edge(im, bbox):
    """Expand the first bounding box so it spans the full image.

    This is a valid edge case (box touches every border); the bbox array is
    modified in place and returned alongside the untouched image.
    """
    full_h, full_w = im.shape[0], im.shape[1]
    bbox[0][:4] = [0, 0, full_w, full_h]
    return im, bbox
def
test_random_crop_with_bbox_op_c
(
plot_vis
=
False
):
def
test_random_crop_with_bbox_op_c
(
plot_vis
=
False
):
"""
"""
Prints images
side by side with and without Aug applied + bboxes
Prints images
and bboxes side by side with and without RandomCropWithBBox Op applied
"""
"""
logger
.
info
(
"test_random_crop_with_bbox_op_c"
)
# Load dataset
# Load dataset
dataVoc1
=
ds
.
VOCDataset
(
DATA_DIR
,
task
=
"Detection"
,
mode
=
"train"
,
decode
=
True
,
shuffle
=
False
)
dataVoc1
=
ds
.
VOCDataset
(
DATA_DIR
,
task
=
"Detection"
,
mode
=
"train"
,
decode
=
True
,
shuffle
=
False
)
dataVoc2
=
ds
.
VOCDataset
(
DATA_DIR
,
task
=
"Detection"
,
mode
=
"train"
,
decode
=
True
,
shuffle
=
False
)
dataVoc2
=
ds
.
VOCDataset
(
DATA_DIR
,
task
=
"Detection"
,
mode
=
"train"
,
decode
=
True
,
shuffle
=
False
)
# define test OP with values to match existing Op
unit - test
# define test OP with values to match existing Op
UT
test_op
=
c_vision
.
RandomCropWithBBox
([
512
,
512
],
[
200
,
200
,
200
,
200
])
test_op
=
c_vision
.
RandomCropWithBBox
([
512
,
512
],
[
200
,
200
,
200
,
200
])
# maps to fix annotations to HQ standard
dataVoc1
=
dataVoc1
.
map
(
input_columns
=
[
"annotation"
],
dataVoc1
=
dataVoc1
.
map
(
input_columns
=
[
"annotation"
],
output_columns
=
[
"annotation"
],
output_columns
=
[
"annotation"
],
operations
=
fix_annotate
)
operations
=
fix_annotate
)
...
@@ -173,14 +78,17 @@ def test_random_crop_with_bbox_op_c(plot_vis=False):
...
@@ -173,14 +78,17 @@ def test_random_crop_with_bbox_op_c(plot_vis=False):
augSamp
.
append
(
Aug
)
augSamp
.
append
(
Aug
)
if
plot_vis
:
if
plot_vis
:
visualize
(
unaugSamp
,
augSamp
)
visualize
_with_bounding_boxes
(
unaugSamp
,
augSamp
)
def
test_random_crop_with_bbox_op2_c
(
plot_vis
=
False
):
def
test_random_crop_with_bbox_op2_c
(
plot_vis
=
False
):
"""
"""
Prints images
side by side with and without Aug applied + bboxes
Prints images
and bboxes side by side with and without RandomCropWithBBox Op applied,
With Fill Value
with md5 check, expected to pass
"""
"""
logger
.
info
(
"test_random_crop_with_bbox_op2_c"
)
original_seed
=
config_get_set_seed
(
593447
)
original_num_parallel_workers
=
config_get_set_num_parallel_workers
(
1
)
# Load dataset
# Load dataset
dataVoc1
=
ds
.
VOCDataset
(
DATA_DIR
,
task
=
"Detection"
,
mode
=
"train"
,
decode
=
True
,
shuffle
=
False
)
dataVoc1
=
ds
.
VOCDataset
(
DATA_DIR
,
task
=
"Detection"
,
mode
=
"train"
,
decode
=
True
,
shuffle
=
False
)
...
@@ -189,7 +97,6 @@ def test_random_crop_with_bbox_op2_c(plot_vis=False):
...
@@ -189,7 +97,6 @@ def test_random_crop_with_bbox_op2_c(plot_vis=False):
# define test OP with values to match existing Op unit - test
# define test OP with values to match existing Op unit - test
test_op
=
c_vision
.
RandomCropWithBBox
(
512
,
[
200
,
200
,
200
,
200
],
fill_value
=
(
255
,
255
,
255
))
test_op
=
c_vision
.
RandomCropWithBBox
(
512
,
[
200
,
200
,
200
,
200
],
fill_value
=
(
255
,
255
,
255
))
# maps to fix annotations to HQ standard
dataVoc1
=
dataVoc1
.
map
(
input_columns
=
[
"annotation"
],
dataVoc1
=
dataVoc1
.
map
(
input_columns
=
[
"annotation"
],
output_columns
=
[
"annotation"
],
output_columns
=
[
"annotation"
],
operations
=
fix_annotate
)
operations
=
fix_annotate
)
...
@@ -202,6 +109,9 @@ def test_random_crop_with_bbox_op2_c(plot_vis=False):
...
@@ -202,6 +109,9 @@ def test_random_crop_with_bbox_op2_c(plot_vis=False):
columns_order
=
[
"image"
,
"annotation"
],
columns_order
=
[
"image"
,
"annotation"
],
operations
=
[
test_op
])
# Add column for "annotation"
operations
=
[
test_op
])
# Add column for "annotation"
filename
=
"random_crop_with_bbox_01_c_result.npz"
save_and_check_md5
(
dataVoc2
,
filename
,
generate_golden
=
GENERATE_GOLDEN
)
unaugSamp
,
augSamp
=
[],
[]
unaugSamp
,
augSamp
=
[],
[]
for
unAug
,
Aug
in
zip
(
dataVoc1
.
create_dict_iterator
(),
dataVoc2
.
create_dict_iterator
()):
for
unAug
,
Aug
in
zip
(
dataVoc1
.
create_dict_iterator
(),
dataVoc2
.
create_dict_iterator
()):
...
@@ -209,14 +119,20 @@ def test_random_crop_with_bbox_op2_c(plot_vis=False):
...
@@ -209,14 +119,20 @@ def test_random_crop_with_bbox_op2_c(plot_vis=False):
augSamp
.
append
(
Aug
)
augSamp
.
append
(
Aug
)
if
plot_vis
:
if
plot_vis
:
visualize
(
unaugSamp
,
augSamp
)
visualize_with_bounding_boxes
(
unaugSamp
,
augSamp
)
# Restore config setting
ds
.
config
.
set_seed
(
original_seed
)
ds
.
config
.
set_num_parallel_workers
(
original_num_parallel_workers
)
def
test_random_crop_with_bbox_op3_c
(
plot_vis
=
False
):
def
test_random_crop_with_bbox_op3_c
(
plot_vis
=
False
):
"""
"""
Prints images
side by side with and without Aug applied + bboxes
Prints images
and bboxes side by side with and without RandomCropWithBBox Op applied,
With Padding Mode
passed
with Padding Mode explicitly
passed
"""
"""
logger
.
info
(
"test_random_crop_with_bbox_op3_c"
)
# Load dataset
# Load dataset
dataVoc1
=
ds
.
VOCDataset
(
DATA_DIR
,
task
=
"Detection"
,
mode
=
"train"
,
decode
=
True
,
shuffle
=
False
)
dataVoc1
=
ds
.
VOCDataset
(
DATA_DIR
,
task
=
"Detection"
,
mode
=
"train"
,
decode
=
True
,
shuffle
=
False
)
dataVoc2
=
ds
.
VOCDataset
(
DATA_DIR
,
task
=
"Detection"
,
mode
=
"train"
,
decode
=
True
,
shuffle
=
False
)
dataVoc2
=
ds
.
VOCDataset
(
DATA_DIR
,
task
=
"Detection"
,
mode
=
"train"
,
decode
=
True
,
shuffle
=
False
)
...
@@ -224,7 +140,6 @@ def test_random_crop_with_bbox_op3_c(plot_vis=False):
...
@@ -224,7 +140,6 @@ def test_random_crop_with_bbox_op3_c(plot_vis=False):
# define test OP with values to match existing Op unit - test
# define test OP with values to match existing Op unit - test
test_op
=
c_vision
.
RandomCropWithBBox
(
512
,
[
200
,
200
,
200
,
200
],
padding_mode
=
mode
.
Border
.
EDGE
)
test_op
=
c_vision
.
RandomCropWithBBox
(
512
,
[
200
,
200
,
200
,
200
],
padding_mode
=
mode
.
Border
.
EDGE
)
# maps to fix annotations to HQ standard
dataVoc1
=
dataVoc1
.
map
(
input_columns
=
[
"annotation"
],
dataVoc1
=
dataVoc1
.
map
(
input_columns
=
[
"annotation"
],
output_columns
=
[
"annotation"
],
output_columns
=
[
"annotation"
],
operations
=
fix_annotate
)
operations
=
fix_annotate
)
...
@@ -244,14 +159,16 @@ def test_random_crop_with_bbox_op3_c(plot_vis=False):
...
@@ -244,14 +159,16 @@ def test_random_crop_with_bbox_op3_c(plot_vis=False):
augSamp
.
append
(
Aug
)
augSamp
.
append
(
Aug
)
if
plot_vis
:
if
plot_vis
:
visualize
(
unaugSamp
,
augSamp
)
visualize
_with_bounding_boxes
(
unaugSamp
,
augSamp
)
def
test_random_crop_with_bbox_op_edge_c
(
plot_vis
=
False
):
def
test_random_crop_with_bbox_op_edge_c
(
plot_vis
=
False
):
"""
"""
Prints images
side by side with and without Aug applied + bboxes
Prints images
and bboxes side by side with and without RandomCropWithBBox Op applied,
Testing for an Edge case
applied on dynamically generated edge case, expected to pass
"""
"""
logger
.
info
(
"test_random_crop_with_bbox_op_edge_c"
)
# Load dataset
# Load dataset
dataVoc1
=
ds
.
VOCDataset
(
DATA_DIR
,
task
=
"Detection"
,
mode
=
"train"
,
decode
=
True
,
shuffle
=
False
)
dataVoc1
=
ds
.
VOCDataset
(
DATA_DIR
,
task
=
"Detection"
,
mode
=
"train"
,
decode
=
True
,
shuffle
=
False
)
dataVoc2
=
ds
.
VOCDataset
(
DATA_DIR
,
task
=
"Detection"
,
mode
=
"train"
,
decode
=
True
,
shuffle
=
False
)
dataVoc2
=
ds
.
VOCDataset
(
DATA_DIR
,
task
=
"Detection"
,
mode
=
"train"
,
decode
=
True
,
shuffle
=
False
)
...
@@ -259,7 +176,6 @@ def test_random_crop_with_bbox_op_edge_c(plot_vis=False):
...
@@ -259,7 +176,6 @@ def test_random_crop_with_bbox_op_edge_c(plot_vis=False):
# define test OP with values to match existing Op unit - test
# define test OP with values to match existing Op unit - test
test_op
=
c_vision
.
RandomCropWithBBox
(
512
,
[
200
,
200
,
200
,
200
],
padding_mode
=
mode
.
Border
.
EDGE
)
test_op
=
c_vision
.
RandomCropWithBBox
(
512
,
[
200
,
200
,
200
,
200
],
padding_mode
=
mode
.
Border
.
EDGE
)
# maps to fix annotations to HQ standard
dataVoc1
=
dataVoc1
.
map
(
input_columns
=
[
"annotation"
],
dataVoc1
=
dataVoc1
.
map
(
input_columns
=
[
"annotation"
],
output_columns
=
[
"annotation"
],
output_columns
=
[
"annotation"
],
operations
=
fix_annotate
)
operations
=
fix_annotate
)
...
@@ -267,17 +183,17 @@ def test_random_crop_with_bbox_op_edge_c(plot_vis=False):
...
@@ -267,17 +183,17 @@ def test_random_crop_with_bbox_op_edge_c(plot_vis=False):
output_columns
=
[
"annotation"
],
output_columns
=
[
"annotation"
],
operations
=
fix_annotate
)
operations
=
fix_annotate
)
#
Modify BBoxes to serve as valid edge cases
#
maps to convert data into valid edge case data
dataVoc
2
=
dataVoc2
.
map
(
input_columns
=
[
"image"
,
"annotation"
],
dataVoc
1
=
dataVoc1
.
map
(
input_columns
=
[
"image"
,
"annotation"
],
output_columns
=
[
"image"
,
"annotation"
],
output_columns
=
[
"image"
,
"annotation"
],
columns_order
=
[
"image"
,
"annotation"
],
columns_order
=
[
"image"
,
"annotation"
],
operations
=
[
gen_bbox_edge
])
operations
=
[
lambda
img
,
bboxes
:
(
img
,
np
.
array
([[
0
,
0
,
img
.
shape
[
1
],
img
.
shape
[
0
]]]).
astype
(
bboxes
.
dtype
))
])
#
map to apply ops
#
Test Op added to list of Operations here
dataVoc2
=
dataVoc2
.
map
(
input_columns
=
[
"image"
,
"annotation"
],
dataVoc2
=
dataVoc2
.
map
(
input_columns
=
[
"image"
,
"annotation"
],
output_columns
=
[
"image"
,
"annotation"
],
output_columns
=
[
"image"
,
"annotation"
],
columns_order
=
[
"image"
,
"annotation"
],
columns_order
=
[
"image"
,
"annotation"
],
operations
=
[
test_op
])
# Add column for "annotation"
operations
=
[
lambda
img
,
bboxes
:
(
img
,
np
.
array
([[
0
,
0
,
img
.
shape
[
1
],
img
.
shape
[
0
]]]).
astype
(
bboxes
.
dtype
)),
test_op
])
unaugSamp
,
augSamp
=
[],
[]
unaugSamp
,
augSamp
=
[],
[]
...
@@ -286,13 +202,15 @@ def test_random_crop_with_bbox_op_edge_c(plot_vis=False):
...
@@ -286,13 +202,15 @@ def test_random_crop_with_bbox_op_edge_c(plot_vis=False):
augSamp
.
append
(
Aug
)
augSamp
.
append
(
Aug
)
if
plot_vis
:
if
plot_vis
:
visualize
(
unaugSamp
,
augSamp
)
visualize
_with_bounding_boxes
(
unaugSamp
,
augSamp
)
def
test_random_crop_with_bbox_op_invalid_c
():
def
test_random_crop_with_bbox_op_invalid_c
():
"""
"""
Checking for invalid params passed to Aug Construct
or
Test RandomCropWithBBox Op on invalid constructor parameters, expected to raise ValueErr
or
"""
"""
logger
.
info
(
"test_random_crop_with_bbox_op_invalid_c"
)
# Load dataset
# Load dataset
dataVoc2
=
ds
.
VOCDataset
(
DATA_DIR
,
task
=
"Detection"
,
mode
=
"train"
,
decode
=
True
,
shuffle
=
False
)
dataVoc2
=
ds
.
VOCDataset
(
DATA_DIR
,
task
=
"Detection"
,
mode
=
"train"
,
decode
=
True
,
shuffle
=
False
)
...
@@ -300,8 +218,6 @@ def test_random_crop_with_bbox_op_invalid_c():
...
@@ -300,8 +218,6 @@ def test_random_crop_with_bbox_op_invalid_c():
# define test OP with values to match existing Op unit - test
# define test OP with values to match existing Op unit - test
test_op
=
c_vision
.
RandomCropWithBBox
([
512
,
512
,
375
])
test_op
=
c_vision
.
RandomCropWithBBox
([
512
,
512
,
375
])
# maps to fix annotations to HQ standard
dataVoc2
=
dataVoc2
.
map
(
input_columns
=
[
"annotation"
],
dataVoc2
=
dataVoc2
.
map
(
input_columns
=
[
"annotation"
],
output_columns
=
[
"annotation"
],
output_columns
=
[
"annotation"
],
operations
=
fix_annotate
)
operations
=
fix_annotate
)
...
@@ -320,35 +236,20 @@ def test_random_crop_with_bbox_op_invalid_c():
...
@@ -320,35 +236,20 @@ def test_random_crop_with_bbox_op_invalid_c():
def
test_random_crop_with_bbox_op_bad_c
():
def
test_random_crop_with_bbox_op_bad_c
():
# Should Fail - Errors logged to logger
"""
for
ix
,
badFunc
in
enumerate
(
badGenFuncs
):
Tests RandomCropWithBBox Op with invalid bounding boxes, expected to catch multiple errors.
try
:
"""
dataVoc2
=
ds
.
VOCDataset
(
DATA_DIR
,
task
=
"Detection"
,
mode
=
"train"
,
logger
.
info
(
"test_random_crop_with_bbox_op_bad_c"
)
decode
=
True
,
shuffle
=
False
)
test_op
=
c_vision
.
RandomCropWithBBox
([
512
,
512
],
[
200
,
200
,
200
,
200
])
test_op
=
c_vision
.
RandomCropWithBBox
([
512
,
512
],
[
200
,
200
,
200
,
200
])
data_voc2
=
ds
.
VOCDataset
(
DATA_DIR
,
task
=
"Detection"
,
mode
=
"train"
,
decode
=
True
,
shuffle
=
False
)
check_bad_bbox
(
data_voc2
,
test_op
,
InvalidBBoxType
.
WidthOverflow
,
"bounding boxes is out of bounds of the image"
)
dataVoc2
=
dataVoc2
.
map
(
input_columns
=
[
"annotation"
],
data_voc2
=
ds
.
VOCDataset
(
DATA_DIR
,
task
=
"Detection"
,
mode
=
"train"
,
decode
=
True
,
shuffle
=
False
)
output_columns
=
[
"annotation"
],
check_bad_bbox
(
data_voc2
,
test_op
,
InvalidBBoxType
.
HeightOverflow
,
"bounding boxes is out of bounds of the image"
)
operations
=
fix_annotate
)
data_voc2
=
ds
.
VOCDataset
(
DATA_DIR
,
task
=
"Detection"
,
mode
=
"train"
,
decode
=
True
,
shuffle
=
False
)
check_bad_bbox
(
data_voc2
,
test_op
,
InvalidBBoxType
.
NegativeXY
,
"min_x"
)
dataVoc2
=
dataVoc2
.
map
(
input_columns
=
[
"image"
,
"annotation"
],
data_voc2
=
ds
.
VOCDataset
(
DATA_DIR
,
task
=
"Detection"
,
mode
=
"train"
,
decode
=
True
,
shuffle
=
False
)
output_columns
=
[
"image"
,
"annotation"
],
check_bad_bbox
(
data_voc2
,
test_op
,
InvalidBBoxType
.
WrongShape
,
"4 features"
)
columns_order
=
[
"image"
,
"annotation"
],
operations
=
[
badFunc
])
# map to apply ops
dataVoc2
=
dataVoc2
.
map
(
input_columns
=
[
"image"
,
"annotation"
],
output_columns
=
[
"image"
,
"annotation"
],
columns_order
=
[
"image"
,
"annotation"
],
operations
=
[
test_op
])
for
_
in
dataVoc2
.
create_dict_iterator
():
break
# first sample will cause exception
except
RuntimeError
as
err
:
logger
.
info
(
"Got an exception in DE: {}"
.
format
(
str
(
err
)))
assert
assertVal
[
ix
]
in
str
(
err
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
...
...
tests/ut/python/dataset/test_random_horizontal_flip_bbox.py
已删除
100644 → 0
浏览文件 @
7f54d17b
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Testing the random horizontal flip with bounding boxes op in DE
"""
from
enum
import
Enum
import
matplotlib.pyplot
as
plt
import
matplotlib.patches
as
patches
import
numpy
as
np
import
mindspore.log
as
logger
import
mindspore.dataset
as
ds
import
mindspore.dataset.transforms.vision.c_transforms
as
c_vision
GENERATE_GOLDEN
=
False
DATA_DIR
=
"../data/dataset/testVOC2012_2"
class
BoxType
(
Enum
):
"""
Defines box types for test cases
"""
WidthOverflow
=
1
HeightOverflow
=
2
NegativeXY
=
3
OnEdge
=
4
WrongShape
=
5
def
add_bad_annotation
(
img
,
bboxes
,
box_type
):
"""
Used to generate erroneous bounding box examples on given img.
:param img: image where the bounding boxes are.
:param bboxes: in [x_min, y_min, w, h, label, truncate, difficult] format
:param box_type: type of bad box
:return: bboxes with bad examples added
"""
height
=
img
.
shape
[
0
]
width
=
img
.
shape
[
1
]
if
box_type
==
BoxType
.
WidthOverflow
:
# use box that overflows on width
return
img
,
np
.
array
([[
0
,
0
,
width
+
1
,
height
,
0
,
0
,
0
]]).
astype
(
np
.
uint32
)
if
box_type
==
BoxType
.
HeightOverflow
:
# use box that overflows on height
return
img
,
np
.
array
([[
0
,
0
,
width
,
height
+
1
,
0
,
0
,
0
]]).
astype
(
np
.
uint32
)
if
box_type
==
BoxType
.
NegativeXY
:
# use box with negative xy
return
img
,
np
.
array
([[
-
10
,
-
10
,
width
,
height
,
0
,
0
,
0
]]).
astype
(
np
.
uint32
)
if
box_type
==
BoxType
.
OnEdge
:
# use box that covers the whole image
return
img
,
np
.
array
([[
0
,
0
,
width
,
height
,
0
,
0
,
0
]]).
astype
(
np
.
uint32
)
if
box_type
==
BoxType
.
WrongShape
:
# use box that covers the whole image
return
img
,
np
.
array
([[
0
,
0
,
width
-
1
]]).
astype
(
np
.
uint32
)
return
img
,
bboxes
def
h_flip
(
image
):
"""
Apply the random_horizontal
"""
# that's why we flip here too
image
=
image
[:,
::
-
1
,
:]
return
image
def
check_bad_box
(
data
,
box_type
,
expected_error
):
"""
:param data: de object detection pipeline
:param box_type: type of bad box
:param expected_error: error expected to get due to bad box
:return: None
"""
# DEFINE TEST OP HERE -- (PROB 1 IN CASE OF RANDOM)
try
:
test_op
=
c_vision
.
RandomHorizontalFlipWithBBox
(
1
)
data
=
data
.
map
(
input_columns
=
[
"annotation"
],
output_columns
=
[
"annotation"
],
operations
=
fix_annotate
)
# map to use width overflow
data
=
data
.
map
(
input_columns
=
[
"image"
,
"annotation"
],
output_columns
=
[
"image"
,
"annotation"
],
columns_order
=
[
"image"
,
"annotation"
],
operations
=
lambda
img
,
bboxes
:
add_bad_annotation
(
img
,
bboxes
,
box_type
))
# map to apply ops
data
=
data
.
map
(
input_columns
=
[
"image"
,
"annotation"
],
output_columns
=
[
"image"
,
"annotation"
],
columns_order
=
[
"image"
,
"annotation"
],
operations
=
[
test_op
])
# Add column for "annotation"
for
_
,
_
in
enumerate
(
data
.
create_dict_iterator
()):
break
except
RuntimeError
as
error
:
logger
.
info
(
"Got an exception in DE: {}"
.
format
(
str
(
error
)))
assert
expected_error
in
str
(
error
)
def
fix_annotate
(
bboxes
):
"""
Fix annotations to format followed by mindspore.
:param bboxes: in [label, x_min, y_min, w, h, truncate, difficult] format
:return: annotation in [x_min, y_min, w, h, label, truncate, difficult] format
"""
for
bbox
in
bboxes
:
tmp
=
bbox
[
0
]
bbox
[
0
]
=
bbox
[
1
]
bbox
[
1
]
=
bbox
[
2
]
bbox
[
2
]
=
bbox
[
3
]
bbox
[
3
]
=
bbox
[
4
]
bbox
[
4
]
=
tmp
return
bboxes
def
add_bounding_boxes
(
axis
,
bboxes
):
"""
:param axis: axis to modify
:param bboxes: bounding boxes to draw on the axis
:return: None
"""
for
bbox
in
bboxes
:
rect
=
patches
.
Rectangle
((
bbox
[
0
],
bbox
[
1
]),
bbox
[
2
],
bbox
[
3
],
linewidth
=
1
,
edgecolor
=
'r'
,
facecolor
=
'none'
)
# Add the patch to the Axes
axis
.
add_patch
(
rect
)
def
visualize
(
unaugmented_data
,
augment_data
):
"""
:param unaugmented_data: original data
:param augment_data: data after augmentations
:return: None
"""
for
idx
,
(
un_aug_item
,
aug_item
)
in
\
enumerate
(
zip
(
unaugmented_data
.
create_dict_iterator
(),
augment_data
.
create_dict_iterator
())):
axis
=
plt
.
subplot
(
141
)
plt
.
imshow
(
un_aug_item
[
"image"
])
add_bounding_boxes
(
axis
,
un_aug_item
[
"annotation"
])
# add Orig BBoxes
plt
.
title
(
"Original"
+
str
(
idx
+
1
))
logger
.
info
(
"Original "
,
str
(
idx
+
1
),
" :"
,
un_aug_item
[
"annotation"
])
axis
=
plt
.
subplot
(
142
)
plt
.
imshow
(
aug_item
[
"image"
])
add_bounding_boxes
(
axis
,
aug_item
[
"annotation"
])
# add AugBBoxes
plt
.
title
(
"Augmented"
+
str
(
idx
+
1
))
logger
.
info
(
"Augmented "
,
str
(
idx
+
1
),
" "
,
aug_item
[
"annotation"
],
"
\n
"
)
plt
.
show
()
def
test_random_horizontal_bbox_op
(
plot
=
False
):
"""
Test RandomHorizontalFlipWithBBox op
Prints images side by side with and without Aug applied + bboxes to compare and test
"""
logger
.
info
(
"test_random_horizontal_bbox_c"
)
data_voc1
=
ds
.
VOCDataset
(
DATA_DIR
,
task
=
"Detection"
,
mode
=
"train"
,
decode
=
True
,
shuffle
=
False
)
data_voc2
=
ds
.
VOCDataset
(
DATA_DIR
,
task
=
"Detection"
,
mode
=
"train"
,
decode
=
True
,
shuffle
=
False
)
# DEFINE TEST OP HERE -- (PROB 1 IN CASE OF RANDOM)
test_op
=
c_vision
.
RandomHorizontalFlipWithBBox
(
1
)
# maps to fix annotations to minddata standard
data_voc1
=
data_voc1
.
map
(
input_columns
=
[
"annotation"
],
output_columns
=
[
"annotation"
],
operations
=
fix_annotate
)
data_voc2
=
data_voc2
.
map
(
input_columns
=
[
"annotation"
],
output_columns
=
[
"annotation"
],
operations
=
fix_annotate
)
# map to apply ops
data_voc2
=
data_voc2
.
map
(
input_columns
=
[
"image"
,
"annotation"
],
output_columns
=
[
"image"
,
"annotation"
],
columns_order
=
[
"image"
,
"annotation"
],
operations
=
[
test_op
])
# Add column for "annotation"
if
plot
:
visualize
(
data_voc1
,
data_voc2
)
def
test_random_horizontal_bbox_valid_prob_c
(
plot
=
False
):
"""
Test RandomHorizontalFlipWithBBox op
Prints images side by side with and without Aug applied + bboxes to compare and test
"""
logger
.
info
(
"test_random_horizontal_bbox_valid_prob_c"
)
data_voc1
=
ds
.
VOCDataset
(
DATA_DIR
,
task
=
"Detection"
,
mode
=
"train"
,
decode
=
True
,
shuffle
=
False
)
data_voc2
=
ds
.
VOCDataset
(
DATA_DIR
,
task
=
"Detection"
,
mode
=
"train"
,
decode
=
True
,
shuffle
=
False
)
# DEFINE TEST OP HERE -- (PROB 1 IN CASE OF RANDOM)
test_op
=
c_vision
.
RandomHorizontalFlipWithBBox
(
0.3
)
# maps to fix annotations to minddata standard
data_voc1
=
data_voc1
.
map
(
input_columns
=
[
"annotation"
],
output_columns
=
[
"annotation"
],
operations
=
fix_annotate
)
data_voc2
=
data_voc2
.
map
(
input_columns
=
[
"annotation"
],
output_columns
=
[
"annotation"
],
operations
=
fix_annotate
)
# map to apply ops
data_voc2
=
data_voc2
.
map
(
input_columns
=
[
"image"
,
"annotation"
],
output_columns
=
[
"image"
,
"annotation"
],
columns_order
=
[
"image"
,
"annotation"
],
operations
=
[
test_op
])
# Add column for "annotation"
if
plot
:
visualize
(
data_voc1
,
data_voc2
)
def
test_random_horizontal_bbox_invalid_prob_c
():
"""
Test RandomHorizontalFlipWithBBox op with invalid input probability
"""
logger
.
info
(
"test_random_horizontal_bbox_invalid_prob_c"
)
data_voc2
=
ds
.
VOCDataset
(
DATA_DIR
,
task
=
"Detection"
,
mode
=
"train"
,
decode
=
True
,
shuffle
=
False
)
try
:
# Note: Valid range of prob should be [0.0, 1.0]
test_op
=
c_vision
.
RandomHorizontalFlipWithBBox
(
1.5
)
data_voc2
=
data_voc2
.
map
(
input_columns
=
[
"annotation"
],
output_columns
=
[
"annotation"
],
operations
=
fix_annotate
)
# map to apply ops
data_voc2
=
data_voc2
.
map
(
input_columns
=
[
"image"
,
"annotation"
],
output_columns
=
[
"image"
,
"annotation"
],
columns_order
=
[
"image"
,
"annotation"
],
operations
=
[
test_op
])
# Add column for "annotation"
except
ValueError
as
error
:
logger
.
info
(
"Got an exception in DE: {}"
.
format
(
str
(
error
)))
assert
"Input is not"
in
str
(
error
)
def
test_random_horizontal_bbox_invalid_bounds_c
():
"""
Test RandomHorizontalFlipWithBBox op with invalid bounding boxes
"""
logger
.
info
(
"test_random_horizontal_bbox_invalid_bounds_c"
)
data_voc2
=
ds
.
VOCDataset
(
DATA_DIR
,
task
=
"Detection"
,
mode
=
"train"
,
decode
=
True
,
shuffle
=
False
)
check_bad_box
(
data_voc2
,
BoxType
.
WidthOverflow
,
"bounding boxes is out of bounds of the image"
)
data_voc2
=
ds
.
VOCDataset
(
DATA_DIR
,
task
=
"Detection"
,
mode
=
"train"
,
decode
=
True
,
shuffle
=
False
)
check_bad_box
(
data_voc2
,
BoxType
.
HeightOverflow
,
"bounding boxes is out of bounds of the image"
)
data_voc2
=
ds
.
VOCDataset
(
DATA_DIR
,
task
=
"Detection"
,
mode
=
"train"
,
decode
=
True
,
shuffle
=
False
)
check_bad_box
(
data_voc2
,
BoxType
.
NegativeXY
,
"min_x"
)
data_voc2
=
ds
.
VOCDataset
(
DATA_DIR
,
task
=
"Detection"
,
mode
=
"train"
,
decode
=
True
,
shuffle
=
False
)
check_bad_box
(
data_voc2
,
BoxType
.
WrongShape
,
"4 features"
)
if
__name__
==
"__main__"
:
# set to false to not show plots
test_random_horizontal_bbox_op
(
False
)
test_random_horizontal_bbox_valid_prob_c
(
False
)
test_random_horizontal_bbox_invalid_prob_c
()
test_random_horizontal_bbox_invalid_bounds_c
()
tests/ut/python/dataset/test_random_horizontal_flip_with_bbox.py
0 → 100644
浏览文件 @
5a886794
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Testing the random horizontal flip with bounding boxes op in DE
"""
import
numpy
as
np
import
mindspore.log
as
logger
import
mindspore.dataset
as
ds
import
mindspore.dataset.transforms.vision.c_transforms
as
c_vision
from
util
import
visualize_with_bounding_boxes
,
InvalidBBoxType
,
check_bad_bbox
,
\
config_get_set_seed
,
config_get_set_num_parallel_workers
,
save_and_check_md5
GENERATE_GOLDEN
=
False
DATA_DIR
=
"../data/dataset/testVOC2012_2"
def
fix_annotate
(
bboxes
):
"""
Fix annotations to format followed by mindspore.
:param bboxes: in [label, x_min, y_min, w, h, truncate, difficult] format
:return: annotation in [x_min, y_min, w, h, label, truncate, difficult] format
"""
for
bbox
in
bboxes
:
tmp
=
bbox
[
0
]
bbox
[
0
]
=
bbox
[
1
]
bbox
[
1
]
=
bbox
[
2
]
bbox
[
2
]
=
bbox
[
3
]
bbox
[
3
]
=
bbox
[
4
]
bbox
[
4
]
=
tmp
return
bboxes
def
test_random_horizontal_flip_with_bbox_op_c
(
plot_vis
=
False
):
"""
Prints images side by side with and without Aug applied + bboxes to
compare and test
"""
logger
.
info
(
"test_random_horizontal_flip_with_bbox_op_c"
)
# Load dataset
dataVoc1
=
ds
.
VOCDataset
(
DATA_DIR
,
task
=
"Detection"
,
mode
=
"train"
,
decode
=
True
,
shuffle
=
False
)
dataVoc2
=
ds
.
VOCDataset
(
DATA_DIR
,
task
=
"Detection"
,
mode
=
"train"
,
decode
=
True
,
shuffle
=
False
)
test_op
=
c_vision
.
RandomHorizontalFlipWithBBox
(
1
)
# maps to fix annotations to minddata standard
dataVoc1
=
dataVoc1
.
map
(
input_columns
=
[
"annotation"
],
output_columns
=
[
"annotation"
],
operations
=
fix_annotate
)
dataVoc2
=
dataVoc2
.
map
(
input_columns
=
[
"annotation"
],
output_columns
=
[
"annotation"
],
operations
=
fix_annotate
)
# map to apply ops
dataVoc2
=
dataVoc2
.
map
(
input_columns
=
[
"image"
,
"annotation"
],
output_columns
=
[
"image"
,
"annotation"
],
columns_order
=
[
"image"
,
"annotation"
],
operations
=
[
test_op
])
unaugSamp
,
augSamp
=
[],
[]
for
unAug
,
Aug
in
zip
(
dataVoc1
.
create_dict_iterator
(),
dataVoc2
.
create_dict_iterator
()):
unaugSamp
.
append
(
unAug
)
augSamp
.
append
(
Aug
)
if
plot_vis
:
visualize_with_bounding_boxes
(
unaugSamp
,
augSamp
)
def
test_random_horizontal_bbox_with_bbox_valid_rand_c
(
plot_vis
=
False
):
"""
Uses a valid non-default input, expect to pass
Prints images side by side with and without Aug applied + bboxes to
compare and test
"""
logger
.
info
(
"test_random_horizontal_bbox_valid_rand_c"
)
original_seed
=
config_get_set_seed
(
1
)
original_num_parallel_workers
=
config_get_set_num_parallel_workers
(
1
)
# Load dataset
dataVoc1
=
ds
.
VOCDataset
(
DATA_DIR
,
task
=
"Detection"
,
mode
=
"train"
,
decode
=
True
,
shuffle
=
False
)
dataVoc2
=
ds
.
VOCDataset
(
DATA_DIR
,
task
=
"Detection"
,
mode
=
"train"
,
decode
=
True
,
shuffle
=
False
)
test_op
=
c_vision
.
RandomHorizontalFlipWithBBox
(
0.6
)
# maps to fix annotations to minddata standard
dataVoc1
=
dataVoc1
.
map
(
input_columns
=
[
"annotation"
],
output_columns
=
[
"annotation"
],
operations
=
fix_annotate
)
dataVoc2
=
dataVoc2
.
map
(
input_columns
=
[
"annotation"
],
output_columns
=
[
"annotation"
],
operations
=
fix_annotate
)
# map to apply ops
dataVoc2
=
dataVoc2
.
map
(
input_columns
=
[
"image"
,
"annotation"
],
output_columns
=
[
"image"
,
"annotation"
],
columns_order
=
[
"image"
,
"annotation"
],
operations
=
[
test_op
])
filename
=
"random_horizontal_flip_with_bbox_01_c_result.npz"
save_and_check_md5
(
dataVoc2
,
filename
,
generate_golden
=
GENERATE_GOLDEN
)
unaugSamp
,
augSamp
=
[],
[]
for
unAug
,
Aug
in
zip
(
dataVoc1
.
create_dict_iterator
(),
dataVoc2
.
create_dict_iterator
()):
unaugSamp
.
append
(
unAug
)
augSamp
.
append
(
Aug
)
if
plot_vis
:
visualize_with_bounding_boxes
(
unaugSamp
,
augSamp
)
# Restore config setting
ds
.
config
.
set_seed
(
original_seed
)
ds
.
config
.
set_num_parallel_workers
(
original_num_parallel_workers
)
def
test_random_horizontal_flip_with_bbox_valid_edge_c
(
plot_vis
=
False
):
"""
Test RandomHorizontalFlipWithBBox op (testing with valid edge case, box covering full image).
Prints images side by side with and without Aug applied + bboxes to compare and test
"""
logger
.
info
(
"test_horizontal_flip_with_bbox_valid_edge_c"
)
dataVoc1
=
ds
.
VOCDataset
(
DATA_DIR
,
task
=
"Detection"
,
mode
=
"train"
,
decode
=
True
,
shuffle
=
False
)
dataVoc2
=
ds
.
VOCDataset
(
DATA_DIR
,
task
=
"Detection"
,
mode
=
"train"
,
decode
=
True
,
shuffle
=
False
)
test_op
=
c_vision
.
RandomHorizontalFlipWithBBox
(
1
)
# maps to fix annotations to minddata standard
dataVoc1
=
dataVoc1
.
map
(
input_columns
=
[
"annotation"
],
output_columns
=
[
"annotation"
],
operations
=
fix_annotate
)
dataVoc2
=
dataVoc2
.
map
(
input_columns
=
[
"annotation"
],
output_columns
=
[
"annotation"
],
operations
=
fix_annotate
)
# map to apply ops
# Add column for "annotation"
dataVoc1
=
dataVoc1
.
map
(
input_columns
=
[
"image"
,
"annotation"
],
output_columns
=
[
"image"
,
"annotation"
],
columns_order
=
[
"image"
,
"annotation"
],
operations
=
lambda
img
,
bbox
:
(
img
,
np
.
array
([[
0
,
0
,
img
.
shape
[
1
],
img
.
shape
[
0
],
0
,
0
,
0
]]).
astype
(
np
.
uint32
)))
dataVoc2
=
dataVoc2
.
map
(
input_columns
=
[
"image"
,
"annotation"
],
output_columns
=
[
"image"
,
"annotation"
],
columns_order
=
[
"image"
,
"annotation"
],
operations
=
lambda
img
,
bbox
:
(
img
,
np
.
array
([[
0
,
0
,
img
.
shape
[
1
],
img
.
shape
[
0
],
0
,
0
,
0
]]).
astype
(
np
.
uint32
)))
dataVoc2
=
dataVoc2
.
map
(
input_columns
=
[
"image"
,
"annotation"
],
output_columns
=
[
"image"
,
"annotation"
],
columns_order
=
[
"image"
,
"annotation"
],
operations
=
[
test_op
])
unaugSamp
,
augSamp
=
[],
[]
for
unAug
,
Aug
in
zip
(
dataVoc1
.
create_dict_iterator
(),
dataVoc2
.
create_dict_iterator
()):
unaugSamp
.
append
(
unAug
)
augSamp
.
append
(
Aug
)
if
plot_vis
:
visualize_with_bounding_boxes
(
unaugSamp
,
augSamp
)
def
test_random_horizontal_flip_with_bbox_invalid_prob_c
():
"""
Test RandomHorizontalFlipWithBBox op with invalid input probability
"""
logger
.
info
(
"test_random_horizontal_bbox_invalid_prob_c"
)
dataVoc2
=
ds
.
VOCDataset
(
DATA_DIR
,
task
=
"Detection"
,
mode
=
"train"
,
decode
=
True
,
shuffle
=
False
)
try
:
# Note: Valid range of prob should be [0.0, 1.0]
test_op
=
c_vision
.
RandomHorizontalFlipWithBBox
(
1.5
)
dataVoc2
=
dataVoc2
.
map
(
input_columns
=
[
"annotation"
],
output_columns
=
[
"annotation"
],
operations
=
fix_annotate
)
# map to apply ops
dataVoc2
=
dataVoc2
.
map
(
input_columns
=
[
"image"
,
"annotation"
],
output_columns
=
[
"image"
,
"annotation"
],
columns_order
=
[
"image"
,
"annotation"
],
operations
=
[
test_op
])
# Add column for "annotation"
except
ValueError
as
error
:
logger
.
info
(
"Got an exception in DE: {}"
.
format
(
str
(
error
)))
assert
"Input is not"
in
str
(
error
)
def
test_random_horizontal_flip_with_bbox_invalid_bounds_c
():
"""
Test RandomHorizontalFlipWithBBox op with invalid bounding boxes
"""
logger
.
info
(
"test_random_horizontal_bbox_invalid_bounds_c"
)
test_op
=
c_vision
.
RandomHorizontalFlipWithBBox
(
1
)
dataVoc2
=
ds
.
VOCDataset
(
DATA_DIR
,
task
=
"Detection"
,
mode
=
"train"
,
decode
=
True
,
shuffle
=
False
)
check_bad_bbox
(
dataVoc2
,
test_op
,
InvalidBBoxType
.
WidthOverflow
,
"bounding boxes is out of bounds of the image"
)
dataVoc2
=
ds
.
VOCDataset
(
DATA_DIR
,
task
=
"Detection"
,
mode
=
"train"
,
decode
=
True
,
shuffle
=
False
)
check_bad_bbox
(
dataVoc2
,
test_op
,
InvalidBBoxType
.
HeightOverflow
,
"bounding boxes is out of bounds of the image"
)
dataVoc2
=
ds
.
VOCDataset
(
DATA_DIR
,
task
=
"Detection"
,
mode
=
"train"
,
decode
=
True
,
shuffle
=
False
)
check_bad_bbox
(
dataVoc2
,
test_op
,
InvalidBBoxType
.
NegativeXY
,
"min_x"
)
dataVoc2
=
ds
.
VOCDataset
(
DATA_DIR
,
task
=
"Detection"
,
mode
=
"train"
,
decode
=
True
,
shuffle
=
False
)
check_bad_bbox
(
dataVoc2
,
test_op
,
InvalidBBoxType
.
WrongShape
,
"4 features"
)
if
__name__
==
"__main__"
:
# set to false to not show plots
test_random_horizontal_flip_with_bbox_op_c
(
plot_vis
=
False
)
test_random_horizontal_bbox_with_bbox_valid_rand_c
(
plot_vis
=
False
)
test_random_horizontal_flip_with_bbox_valid_edge_c
(
plot_vis
=
False
)
test_random_horizontal_flip_with_bbox_invalid_prob_c
()
test_random_horizontal_flip_with_bbox_invalid_bounds_c
()
tests/ut/python/dataset/test_random_resize_with_bbox.py
浏览文件 @
5a886794
...
@@ -15,251 +15,180 @@
...
@@ -15,251 +15,180 @@
"""
"""
Testing the random resize with bounding boxes op in DE
Testing the random resize with bounding boxes op in DE
"""
"""
from
enum
import
Enum
import
matplotlib.pyplot
as
plt
import
matplotlib.patches
as
patches
import
numpy
as
np
import
numpy
as
np
import
mindspore.dataset
as
ds
import
mindspore.dataset
as
ds
from
mindspore
import
log
as
logger
import
mindspore.dataset.transforms.vision.c_transforms
as
c_vision
import
mindspore.dataset.transforms.vision.c_transforms
as
c_vision
from
mindspore
import
log
as
logger
from
util
import
visualize_with_bounding_boxes
,
InvalidBBoxType
,
check_bad_bbox
,
\
config_get_set_seed
,
config_get_set_num_parallel_workers
,
save_and_check_md5
GENERATE_GOLDEN
=
False
GENERATE_GOLDEN
=
False
DATA_DIR
=
"../data/dataset/testVOC2012"
DATA_DIR
=
"../data/dataset/testVOC2012
_2
"
def
fix_annotate
(
bboxes
):
def
fix_annotate
(
bboxes
):
"""
"""
Fix annotations to format followed by mindspore.
:param bboxes: in [label, x_min, y_min, w, h, truncate, difficult] format
:param bboxes: in [label, x_min, y_min, w, h, truncate, difficult] format
:return: annotation in [x_min, y_min, w, h, label, truncate, difficult] format
:return: annotation in [x_min, y_min, w, h, label, truncate, difficult] format
"""
"""
for
bbox
in
bboxes
:
for
(
i
,
box
)
in
enumerate
(
bboxes
):
tmp
=
bbox
[
0
]
bboxes
[
i
]
=
np
.
roll
(
box
,
-
1
)
bbox
[
0
]
=
bbox
[
1
]
bbox
[
1
]
=
bbox
[
2
]
bbox
[
2
]
=
bbox
[
3
]
bbox
[
3
]
=
bbox
[
4
]
bbox
[
4
]
=
tmp
return
bboxes
return
bboxes
class
BoxType
(
Enum
):
def
test_random_resize_with_bbox_op_rand_c
(
plot_vis
=
False
):
"""
"""
Defines box types for test cases
Prints images and bboxes side by side with and without RandomResizeWithBBox Op applied,
tests with MD5 check, expected to pass
"""
"""
WidthOverflow
=
1
logger
.
info
(
"test_random_resize_with_bbox_rand_c"
)
HeightOverflow
=
2
original_seed
=
config_get_set_seed
(
1
)
NegativeXY
=
3
original_num_parallel_workers
=
config_get_set_num_parallel_workers
(
1
)
OnEdge
=
4
WrongShape
=
5
# Load dataset
dataVoc1
=
ds
.
VOCDataset
(
DATA_DIR
,
task
=
"Detection"
,
mode
=
"train"
,
decode
=
True
,
shuffle
=
False
)
class
AddBadAnnotation
:
# pylint: disable=too-few-public-methods
dataVoc2
=
ds
.
VOCDataset
(
DATA_DIR
,
task
=
"Detection"
,
mode
=
"train"
,
"""
decode
=
True
,
shuffle
=
False
)
Used to add erroneous bounding boxes to object detection pipelines.
Usage:
>>> # Adds a box that covers the whole image. Good for testing edge cases
>>> de = de.map(input_columns=["image", "annotation"],
>>> output_columns=["image", "annotation"],
>>> operations=AddBadAnnotation(BoxType.OnEdge))
"""
def
__init__
(
self
,
box_type
):
test_op
=
c_vision
.
RandomResizeWithBBox
(
200
)
self
.
box_type
=
box_type
dataVoc1
=
dataVoc1
.
map
(
input_columns
=
[
"annotation"
],
def
__call__
(
self
,
img
,
bboxes
):
output_columns
=
[
"annotation"
],
"""
operations
=
fix_annotate
)
Used to generate erroneous bounding box examples on given img.
dataVoc2
=
dataVoc2
.
map
(
input_columns
=
[
"annotation"
],
:param img: image where the bounding boxes are.
output_columns
=
[
"annotation"
],
:param bboxes: in [x_min, y_min, w, h, label, truncate, difficult] format
operations
=
fix_annotate
)
:return: bboxes with bad examples added
# map to apply ops
"""
dataVoc2
=
dataVoc2
.
map
(
input_columns
=
[
"image"
,
"annotation"
],
height
=
img
.
shape
[
0
]
output_columns
=
[
"image"
,
"annotation"
],
width
=
img
.
shape
[
1
]
columns_order
=
[
"image"
,
"annotation"
],
if
self
.
box_type
==
BoxType
.
WidthOverflow
:
operations
=
[
test_op
])
# use box that overflows on width
return
img
,
np
.
array
([[
0
,
0
,
width
+
1
,
height
-
1
,
0
,
0
,
0
]]).
astype
(
np
.
uint32
)
filename
=
"random_resize_with_bbox_op_01_c_result.npz"
save_and_check_md5
(
dataVoc2
,
filename
,
generate_golden
=
GENERATE_GOLDEN
)
if
self
.
box_type
==
BoxType
.
HeightOverflow
:
# use box that overflows on height
unaugSamp
,
augSamp
=
[],
[]
return
img
,
np
.
array
([[
0
,
0
,
width
-
1
,
height
+
1
,
0
,
0
,
0
]]).
astype
(
np
.
uint32
)
for
unAug
,
Aug
in
zip
(
dataVoc1
.
create_dict_iterator
(),
dataVoc2
.
create_dict_iterator
()):
if
self
.
box_type
==
BoxType
.
NegativeXY
:
unaugSamp
.
append
(
unAug
)
# use box with negative xy
augSamp
.
append
(
Aug
)
return
img
,
np
.
array
([[
-
10
,
-
10
,
width
-
1
,
height
-
1
,
0
,
0
,
0
]]).
astype
(
np
.
uint32
)
if
plot_vis
:
if
self
.
box_type
==
BoxType
.
OnEdge
:
visualize_with_bounding_boxes
(
unaugSamp
,
augSamp
)
# use box that covers the whole image
return
img
,
np
.
array
([[
0
,
0
,
width
-
1
,
height
-
1
,
0
,
0
,
0
]]).
astype
(
np
.
uint32
)
# Restore config setting
ds
.
config
.
set_seed
(
original_seed
)
if
self
.
box_type
==
BoxType
.
WrongShape
:
ds
.
config
.
set_num_parallel_workers
(
original_num_parallel_workers
)
# use box that covers the whole image
return
img
,
np
.
array
([[
0
,
0
,
width
-
1
]]).
astype
(
np
.
uint32
)
return
img
,
bboxes
def
test_random_resize_with_bbox_op_edge_c
(
plot_vis
=
False
):
def
check_bad_box
(
data
,
box_type
,
expected_error
):
try
:
test_op
=
c_vision
.
RandomResizeWithBBox
(
100
)
# DEFINE TEST OP HERE -- (PROB 1 IN CASE OF RANDOM)
data
=
data
.
map
(
input_columns
=
[
"annotation"
],
output_columns
=
[
"annotation"
],
operations
=
fix_annotate
)
# map to use width overflow
data
=
data
.
map
(
input_columns
=
[
"image"
,
"annotation"
],
output_columns
=
[
"image"
,
"annotation"
],
columns_order
=
[
"image"
,
"annotation"
],
operations
=
AddBadAnnotation
(
box_type
))
# Add column for "annotation"
# map to apply ops
data
=
data
.
map
(
input_columns
=
[
"image"
,
"annotation"
],
output_columns
=
[
"image"
,
"annotation"
],
columns_order
=
[
"image"
,
"annotation"
],
operations
=
[
test_op
])
# Add column for "annotation"
for
_
,
_
in
enumerate
(
data
.
create_dict_iterator
()):
break
except
RuntimeError
as
e
:
logger
.
info
(
"Got an exception in DE: {}"
.
format
(
str
(
e
)))
assert
expected_error
in
str
(
e
)
def
add_bounding_boxes
(
axis
,
bboxes
):
"""
:param axis: axis to modify
:param bboxes: bounding boxes to draw on the axis
:return: None
"""
for
bbox
in
bboxes
:
rect
=
patches
.
Rectangle
((
bbox
[
0
],
bbox
[
1
]),
bbox
[
2
],
bbox
[
3
],
linewidth
=
1
,
edgecolor
=
'r'
,
facecolor
=
'none'
)
# Add the patch to the Axes
axis
.
add_patch
(
rect
)
def
visualize
(
unaugmented_data
,
augment_data
):
for
idx
,
(
un_aug_item
,
aug_item
)
in
\
enumerate
(
zip
(
unaugmented_data
.
create_dict_iterator
(),
augment_data
.
create_dict_iterator
())):
axis
=
plt
.
subplot
(
141
)
plt
.
imshow
(
un_aug_item
[
"image"
])
add_bounding_boxes
(
axis
,
un_aug_item
[
"annotation"
])
# add Orig BBoxes
plt
.
title
(
"Original"
+
str
(
idx
+
1
))
logger
.
info
(
"Original "
,
str
(
idx
+
1
),
" :"
,
un_aug_item
[
"annotation"
])
axis
=
plt
.
subplot
(
142
)
plt
.
imshow
(
aug_item
[
"image"
])
add_bounding_boxes
(
axis
,
aug_item
[
"annotation"
])
# add AugBBoxes
plt
.
title
(
"Augmented"
+
str
(
idx
+
1
))
logger
.
info
(
"Augmented "
,
str
(
idx
+
1
),
" "
,
aug_item
[
"annotation"
],
"
\n
"
)
plt
.
show
()
def
test_random_resize_with_bbox_op
(
plot
=
False
):
"""
"""
Test random_resize_with_bbox_op
Prints images and bboxes side by side with and without RandomresizeWithBBox Op applied,
applied on dynamically generated edge case, expected to pass. edge case is when bounding
box has dimensions as the image itself.
"""
"""
logger
.
info
(
"Test random resize with bbox"
)
logger
.
info
(
"test_random_resize_with_bbox_op_edge_c"
)
dataVoc1
=
ds
.
VOCDataset
(
DATA_DIR
,
task
=
"Detection"
,
mode
=
"train"
,
decode
=
True
,
shuffle
=
False
)
# original images
dataVoc2
=
ds
.
VOCDataset
(
DATA_DIR
,
task
=
"Detection"
,
mode
=
"train"
,
data_original
=
ds
.
VOCDataset
(
DATA_DIR
,
task
=
"Detection"
,
mode
=
"train"
,
decode
=
True
,
shuffle
=
False
)
decode
=
True
,
shuffle
=
False
)
# augmented images
test_op
=
c_vision
.
RandomResizeWithBBox
(
500
)
data_augmented
=
ds
.
VOCDataset
(
DATA_DIR
,
task
=
"Detection"
,
mode
=
"train"
,
decode
=
True
,
shuffle
=
False
)
data_original
=
data_original
.
map
(
input_columns
=
[
"annotation"
],
dataVoc1
=
dataVoc1
.
map
(
input_columns
=
[
"annotation"
],
output_columns
=
[
"annotation"
],
output_columns
=
[
"annotation"
],
operations
=
fix_annotate
)
operations
=
fix_annotate
)
dataVoc2
=
dataVoc2
.
map
(
input_columns
=
[
"annotation"
],
output_columns
=
[
"annotation"
],
operations
=
fix_annotate
)
data_augmented
=
data_augmented
.
map
(
input_columns
=
[
"annotation"
],
# maps to convert data into valid edge case data
output_columns
=
[
"annotation"
],
dataVoc1
=
dataVoc1
.
map
(
input_columns
=
[
"image"
,
"annotation"
],
operations
=
fix_annotate
)
output_columns
=
[
"image"
,
"annotation"
],
columns_order
=
[
"image"
,
"annotation"
],
operations
=
[
lambda
img
,
bboxes
:
(
img
,
np
.
array
([[
0
,
0
,
img
.
shape
[
1
],
img
.
shape
[
0
]]]).
astype
(
bboxes
.
dtype
))])
# define map operations
dataVoc2
=
dataVoc2
.
map
(
input_columns
=
[
"image"
,
"annotation"
],
test_op
=
c_vision
.
RandomResizeWithBBox
(
100
)
# input value being the target size of resizeOp
output_columns
=
[
"image"
,
"annotation"
],
columns_order
=
[
"image"
,
"annotation"
],
operations
=
[
lambda
img
,
bboxes
:
(
img
,
np
.
array
([[
0
,
0
,
img
.
shape
[
1
],
img
.
shape
[
0
]]]).
astype
(
bboxes
.
dtype
)),
test_op
])
data_augmented
=
data_augmented
.
map
(
input_columns
=
[
"image"
,
"annotation"
],
unaugSamp
,
augSamp
=
[],
[]
output_columns
=
[
"image"
,
"annotation"
],
columns_order
=
[
"image"
,
"annotation"
],
operations
=
[
test_op
])
if
plot
:
visualize
(
data_original
,
data_augmented
)
for
unAug
,
Aug
in
zip
(
dataVoc1
.
create_dict_iterator
(),
dataVoc2
.
create_dict_iterator
()):
unaugSamp
.
append
(
unAug
)
augSamp
.
append
(
Aug
)
def
test_random_resize_with_bbox_invalid_bounds
():
if
plot_vis
:
data_voc2
=
ds
.
VOCDataset
(
DATA_DIR
,
task
=
"Detection"
,
mode
=
"train"
,
decode
=
True
,
shuffle
=
False
)
visualize_with_bounding_boxes
(
unaugSamp
,
augSamp
)
check_bad_box
(
data_voc2
,
BoxType
.
WidthOverflow
,
"bounding boxes is out of bounds of the image"
)
data_voc2
=
ds
.
VOCDataset
(
DATA_DIR
,
task
=
"Detection"
,
mode
=
"train"
,
decode
=
True
,
shuffle
=
False
)
check_bad_box
(
data_voc2
,
BoxType
.
HeightOverflow
,
"bounding boxes is out of bounds of the image"
)
data_voc2
=
ds
.
VOCDataset
(
DATA_DIR
,
task
=
"Detection"
,
mode
=
"train"
,
decode
=
True
,
shuffle
=
False
)
check_bad_box
(
data_voc2
,
BoxType
.
NegativeXY
,
"min_x"
)
data_voc2
=
ds
.
VOCDataset
(
DATA_DIR
,
task
=
"Detection"
,
mode
=
"train"
,
decode
=
True
,
shuffle
=
False
)
check_bad_box
(
data_voc2
,
BoxType
.
WrongShape
,
"4 features"
)
def
test_random_resize_with_bbox_invalid_size
():
def
test_random_resize_with_bbox_op_invalid_c
():
"""
Test RandomResizeWithBBox Op on invalid constructor parameters, expected to raise ValueError
"""
"""
Test random_resize_with_bbox_op
logger
.
info
(
"test_random_resize_with_bbox_op_invalid_c"
)
"""
logger
.
info
(
"Test random resize with bbox with invalid target size"
)
# original images
try
:
data
=
ds
.
VOCDataset
(
DATA_DIR
,
task
=
"Detection"
,
mode
=
"train"
,
decode
=
True
,
shuffle
=
False
)
# zero value for resize
c_vision
.
RandomResizeWithBBox
(
0
)
data
=
data
.
map
(
input_columns
=
[
"annotation"
],
except
ValueError
as
err
:
output_columns
=
[
"annotation"
],
logger
.
info
(
"Got an exception in DE: {}"
.
format
(
str
(
err
)))
operations
=
fix_annotate
)
assert
"Input is not"
in
str
(
err
)
# negative target size as input
try
:
try
:
test_op
=
c_vision
.
RandomResizeWithBBox
(
-
10
)
# DEFINE TEST OP HERE -- (PROB 1 IN CASE OF RANDOM)
# one of the size values is zero
c_vision
.
RandomResizeWithBBox
((
0
,
100
))
# map to apply ops
except
ValueError
as
err
:
data
=
data
.
map
(
input_columns
=
[
"image"
,
"annotation"
],
logger
.
info
(
"Got an exception in DE: {}"
.
format
(
str
(
err
)))
output_columns
=
[
"image"
,
"annotation"
],
assert
"Input is not"
in
str
(
err
)
columns_order
=
[
"image"
,
"annotation"
],
operations
=
[
test_op
])
# Add column for "annotation"
for
_
,
_
in
enumerate
(
data
.
create_dict_iterator
()):
try
:
break
# negative value for resize
c_vision
.
RandomResizeWithBBox
(
-
10
)
except
ValueError
as
e
:
except
ValueError
as
err
:
logger
.
info
(
"Got an exception in DE: {}"
.
format
(
str
(
e
)))
logger
.
info
(
"Got an exception in DE: {}"
.
format
(
str
(
err
)))
print
(
e
)
assert
"Input is not"
in
str
(
err
)
assert
"Input is not"
in
str
(
e
)
# zero target size as input
try
:
try
:
test_op
=
c_vision
.
RandomResizeWithBBox
(
0
)
# DEFINE TEST OP HERE -- (PROB 1 IN CASE OF RANDOM)
# invalid input shape
c_vision
.
RandomResizeWithBBox
((
100
,
100
,
100
))
# map to apply ops
except
TypeError
as
err
:
data
=
data
.
map
(
input_columns
=
[
"image"
,
"annotation"
],
logger
.
info
(
"Got an exception in DE: {}"
.
format
(
str
(
err
)))
output_columns
=
[
"image"
,
"annotation"
],
assert
"Size should be"
in
str
(
err
)
columns_order
=
[
"image"
,
"annotation"
],
operations
=
[
test_op
])
# Add column for "annotation"
for
_
,
_
in
enumerate
(
data
.
create_dict_iterator
()):
break
except
ValueError
as
e
:
def
test_random_resize_with_bbox_op_bad_c
():
logger
.
info
(
"Got an exception in DE: {}"
.
format
(
str
(
e
)))
"""
assert
"Input is not"
in
str
(
e
)
Tests RandomResizeWithBBox Op with invalid bounding boxes, expected to catch multiple errors
"""
# invalid input shape
logger
.
info
(
"test_random_resize_with_bbox_op_bad_c"
)
try
:
test_op
=
c_vision
.
RandomResizeWithBBox
((
400
,
300
))
test_op
=
c_vision
.
RandomResizeWithBBox
((
10
,
10
,
10
))
# DEFINE TEST OP HERE -- (PROB 1 IN CASE OF RANDOM)
# map to apply ops
data
=
data
.
map
(
input_columns
=
[
"image"
,
"annotation"
],
output_columns
=
[
"image"
,
"annotation"
],
columns_order
=
[
"image"
,
"annotation"
],
operations
=
[
test_op
])
# Add column for "annotation"
for
_
,
_
in
enumerate
(
data
.
create_dict_iterator
()):
data_voc2
=
ds
.
VOCDataset
(
DATA_DIR
,
task
=
"Detection"
,
mode
=
"train"
,
decode
=
True
,
shuffle
=
False
)
break
check_bad_bbox
(
data_voc2
,
test_op
,
InvalidBBoxType
.
WidthOverflow
,
"bounding boxes is out of bounds of the image"
)
data_voc2
=
ds
.
VOCDataset
(
DATA_DIR
,
task
=
"Detection"
,
mode
=
"train"
,
decode
=
True
,
shuffle
=
False
)
check_bad_bbox
(
data_voc2
,
test_op
,
InvalidBBoxType
.
HeightOverflow
,
"bounding boxes is out of bounds of the image"
)
data_voc2
=
ds
.
VOCDataset
(
DATA_DIR
,
task
=
"Detection"
,
mode
=
"train"
,
decode
=
True
,
shuffle
=
False
)
check_bad_bbox
(
data_voc2
,
test_op
,
InvalidBBoxType
.
NegativeXY
,
"min_x"
)
data_voc2
=
ds
.
VOCDataset
(
DATA_DIR
,
task
=
"Detection"
,
mode
=
"train"
,
decode
=
True
,
shuffle
=
False
)
check_bad_bbox
(
data_voc2
,
test_op
,
InvalidBBoxType
.
WrongShape
,
"4 features"
)
except
TypeError
as
e
:
logger
.
info
(
"Got an exception in DE: {}"
.
format
(
str
(
e
)))
assert
"Size should be"
in
str
(
e
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
test_random_resize_with_bbox_op
(
plot
=
False
)
test_random_resize_with_bbox_op_rand_c
(
plot_vis
=
False
)
test_random_resize_with_bbox_invalid_bounds
()
test_random_resize_with_bbox_op_edge_c
(
plot_vis
=
False
)
test_random_resize_with_bbox_invalid_size
()
test_random_resize_with_bbox_op_invalid_c
()
test_random_resize_with_bbox_op_bad_c
()
tests/ut/python/dataset/test_random_vertical_flip_with_bbox.py
浏览文件 @
5a886794
...
@@ -13,14 +13,17 @@
...
@@ -13,14 +13,17 @@
# limitations under the License.
# limitations under the License.
# ==============================================================================
# ==============================================================================
"""
"""
Testing RandomVerticalFlipWithBBox op
Testing RandomVerticalFlipWithBBox op
in DE
"""
"""
import
numpy
as
np
import
mindspore.dataset
as
ds
import
mindspore.dataset
as
ds
import
mindspore.dataset.transforms.vision.c_transforms
as
c_vision
import
mindspore.dataset.transforms.vision.c_transforms
as
c_vision
from
mindspore
import
log
as
logger
from
mindspore
import
log
as
logger
from
util
import
visualize_with_bounding_boxes
,
InvalidBBoxType
,
check_bad_bbox
from
util
import
visualize_with_bounding_boxes
,
InvalidBBoxType
,
check_bad_bbox
,
\
config_get_set_seed
,
config_get_set_num_parallel_workers
,
save_and_check_md5
GENERATE_GOLDEN
=
False
# updated VOC dataset with correct annotations
# updated VOC dataset with correct annotations
DATA_DIR
=
"../data/dataset/testVOC2012_2"
DATA_DIR
=
"../data/dataset/testVOC2012_2"
...
@@ -28,10 +31,9 @@ DATA_DIR = "../data/dataset/testVOC2012_2"
...
@@ -28,10 +31,9 @@ DATA_DIR = "../data/dataset/testVOC2012_2"
def
fix_annotate
(
bboxes
):
def
fix_annotate
(
bboxes
):
"""
"""
Update Current VOC dataset format to Proposed HQ BBox format
Fix annotations to format followed by mindspore.
:param bboxes: in [label, x_min, y_min, w, h, truncate, difficult] format
:param bboxes: as [label, x_min, y_min, w, h, truncate, difficult]
:return: annotation in [x_min, y_min, w, h, label, truncate, difficult] format
:return: annotation as [x_min, y_min, w, h, label, truncate, difficult]
"""
"""
for
bbox
in
bboxes
:
for
bbox
in
bboxes
:
tmp
=
bbox
[
0
]
tmp
=
bbox
[
0
]
...
@@ -45,9 +47,9 @@ def fix_annotate(bboxes):
...
@@ -45,9 +47,9 @@ def fix_annotate(bboxes):
def
test_random_vertical_flip_with_bbox_op_c
(
plot_vis
=
False
):
def
test_random_vertical_flip_with_bbox_op_c
(
plot_vis
=
False
):
"""
"""
Prints images side by side with and without Aug applied + bboxes to
Prints images and bboxes side by side with and without RandomVerticalFlipWithBBox Op applied
compare and test
"""
"""
logger
.
info
(
"test_random_vertical_flip_with_bbox_op_c"
)
# Load dataset
# Load dataset
dataVoc1
=
ds
.
VOCDataset
(
DATA_DIR
,
task
=
"Detection"
,
mode
=
"train"
,
dataVoc1
=
ds
.
VOCDataset
(
DATA_DIR
,
task
=
"Detection"
,
mode
=
"train"
,
decode
=
True
,
shuffle
=
False
)
decode
=
True
,
shuffle
=
False
)
...
@@ -57,7 +59,6 @@ def test_random_vertical_flip_with_bbox_op_c(plot_vis=False):
...
@@ -57,7 +59,6 @@ def test_random_vertical_flip_with_bbox_op_c(plot_vis=False):
test_op
=
c_vision
.
RandomVerticalFlipWithBBox
(
1
)
test_op
=
c_vision
.
RandomVerticalFlipWithBBox
(
1
)
# maps to fix annotations to HQ standard
dataVoc1
=
dataVoc1
.
map
(
input_columns
=
[
"annotation"
],
dataVoc1
=
dataVoc1
.
map
(
input_columns
=
[
"annotation"
],
output_columns
=
[
"annotation"
],
output_columns
=
[
"annotation"
],
operations
=
fix_annotate
)
operations
=
fix_annotate
)
...
@@ -82,9 +83,12 @@ def test_random_vertical_flip_with_bbox_op_c(plot_vis=False):
...
@@ -82,9 +83,12 @@ def test_random_vertical_flip_with_bbox_op_c(plot_vis=False):
def
test_random_vertical_flip_with_bbox_op_rand_c
(
plot_vis
=
False
):
def
test_random_vertical_flip_with_bbox_op_rand_c
(
plot_vis
=
False
):
"""
"""
Prints images
side by side with and without Aug applied + bboxes to
Prints images
and bboxes side by side with and without RandomVerticalFlipWithBBox Op applied,
compare and test
tests with MD5 check, expected to pass
"""
"""
logger
.
info
(
"test_random_vertical_flip_with_bbox_op_rand_c"
)
original_seed
=
config_get_set_seed
(
29847
)
original_num_parallel_workers
=
config_get_set_num_parallel_workers
(
1
)
# Load dataset
# Load dataset
dataVoc1
=
ds
.
VOCDataset
(
DATA_DIR
,
task
=
"Detection"
,
mode
=
"train"
,
dataVoc1
=
ds
.
VOCDataset
(
DATA_DIR
,
task
=
"Detection"
,
mode
=
"train"
,
...
@@ -93,9 +97,8 @@ def test_random_vertical_flip_with_bbox_op_rand_c(plot_vis=False):
...
@@ -93,9 +97,8 @@ def test_random_vertical_flip_with_bbox_op_rand_c(plot_vis=False):
dataVoc2
=
ds
.
VOCDataset
(
DATA_DIR
,
task
=
"Detection"
,
mode
=
"train"
,
dataVoc2
=
ds
.
VOCDataset
(
DATA_DIR
,
task
=
"Detection"
,
mode
=
"train"
,
decode
=
True
,
shuffle
=
False
)
decode
=
True
,
shuffle
=
False
)
test_op
=
c_vision
.
RandomVerticalFlipWithBBox
(
0.
6
)
test_op
=
c_vision
.
RandomVerticalFlipWithBBox
(
0.
8
)
# maps to fix annotations to HQ standard
dataVoc1
=
dataVoc1
.
map
(
input_columns
=
[
"annotation"
],
dataVoc1
=
dataVoc1
.
map
(
input_columns
=
[
"annotation"
],
output_columns
=
[
"annotation"
],
output_columns
=
[
"annotation"
],
operations
=
fix_annotate
)
operations
=
fix_annotate
)
...
@@ -108,6 +111,56 @@ def test_random_vertical_flip_with_bbox_op_rand_c(plot_vis=False):
...
@@ -108,6 +111,56 @@ def test_random_vertical_flip_with_bbox_op_rand_c(plot_vis=False):
columns_order
=
[
"image"
,
"annotation"
],
columns_order
=
[
"image"
,
"annotation"
],
operations
=
[
test_op
])
operations
=
[
test_op
])
filename
=
"random_vertical_flip_with_bbox_01_c_result.npz"
save_and_check_md5
(
dataVoc2
,
filename
,
generate_golden
=
GENERATE_GOLDEN
)
unaugSamp
,
augSamp
=
[],
[]
for
unAug
,
Aug
in
zip
(
dataVoc1
.
create_dict_iterator
(),
dataVoc2
.
create_dict_iterator
()):
unaugSamp
.
append
(
unAug
)
augSamp
.
append
(
Aug
)
if
plot_vis
:
visualize_with_bounding_boxes
(
unaugSamp
,
augSamp
)
# Restore config setting
ds
.
config
.
set_seed
(
original_seed
)
ds
.
config
.
set_num_parallel_workers
(
original_num_parallel_workers
)
def
test_random_vertical_flip_with_bbox_op_edge_c
(
plot_vis
=
False
):
"""
Prints images and bboxes side by side with and without RandomVerticalFlipWithBBox Op applied,
applied on dynamically generated edge case, expected to pass
"""
logger
.
info
(
"test_random_vertical_flip_with_bbox_op_edge_c"
)
dataVoc1
=
ds
.
VOCDataset
(
DATA_DIR
,
task
=
"Detection"
,
mode
=
"train"
,
decode
=
True
,
shuffle
=
False
)
dataVoc2
=
ds
.
VOCDataset
(
DATA_DIR
,
task
=
"Detection"
,
mode
=
"train"
,
decode
=
True
,
shuffle
=
False
)
test_op
=
c_vision
.
RandomVerticalFlipWithBBox
(
1
)
dataVoc1
=
dataVoc1
.
map
(
input_columns
=
[
"annotation"
],
output_columns
=
[
"annotation"
],
operations
=
fix_annotate
)
dataVoc2
=
dataVoc2
.
map
(
input_columns
=
[
"annotation"
],
output_columns
=
[
"annotation"
],
operations
=
fix_annotate
)
# maps to convert data into valid edge case data
dataVoc1
=
dataVoc1
.
map
(
input_columns
=
[
"image"
,
"annotation"
],
output_columns
=
[
"image"
,
"annotation"
],
columns_order
=
[
"image"
,
"annotation"
],
operations
=
[
lambda
img
,
bboxes
:
(
img
,
np
.
array
([[
0
,
0
,
img
.
shape
[
1
],
img
.
shape
[
0
]]]).
astype
(
bboxes
.
dtype
))])
# Test Op added to list of Operations here
dataVoc2
=
dataVoc2
.
map
(
input_columns
=
[
"image"
,
"annotation"
],
output_columns
=
[
"image"
,
"annotation"
],
columns_order
=
[
"image"
,
"annotation"
],
operations
=
[
lambda
img
,
bboxes
:
(
img
,
np
.
array
([[
0
,
0
,
img
.
shape
[
1
],
img
.
shape
[
0
]]]).
astype
(
bboxes
.
dtype
)),
test_op
])
unaugSamp
,
augSamp
=
[],
[]
unaugSamp
,
augSamp
=
[],
[]
for
unAug
,
Aug
in
zip
(
dataVoc1
.
create_dict_iterator
(),
dataVoc2
.
create_dict_iterator
()):
for
unAug
,
Aug
in
zip
(
dataVoc1
.
create_dict_iterator
(),
dataVoc2
.
create_dict_iterator
()):
...
@@ -119,16 +172,15 @@ def test_random_vertical_flip_with_bbox_op_rand_c(plot_vis=False):
...
@@ -119,16 +172,15 @@ def test_random_vertical_flip_with_bbox_op_rand_c(plot_vis=False):
def
test_random_vertical_flip_with_bbox_op_invalid_c
():
def
test_random_vertical_flip_with_bbox_op_invalid_c
():
# Should Fail
"""
# Load dataset
Test RandomVerticalFlipWithBBox Op on invalid constructor parameters, expected to raise ValueError
"""
logger
.
info
(
"test_random_vertical_flip_with_bbox_op_invalid_c"
)
dataVoc2
=
ds
.
VOCDataset
(
DATA_DIR
,
task
=
"Detection"
,
mode
=
"train"
,
dataVoc2
=
ds
.
VOCDataset
(
DATA_DIR
,
task
=
"Detection"
,
mode
=
"train"
,
decode
=
True
,
shuffle
=
False
)
decode
=
True
,
shuffle
=
False
)
try
:
try
:
test_op
=
c_vision
.
RandomVerticalFlipWithBBox
(
2
)
test_op
=
c_vision
.
RandomVerticalFlipWithBBox
(
2
)
# maps to fix annotations to HQ standard
dataVoc2
=
dataVoc2
.
map
(
input_columns
=
[
"annotation"
],
dataVoc2
=
dataVoc2
.
map
(
input_columns
=
[
"annotation"
],
output_columns
=
[
"annotation"
],
output_columns
=
[
"annotation"
],
operations
=
fix_annotate
)
operations
=
fix_annotate
)
...
@@ -148,9 +200,9 @@ def test_random_vertical_flip_with_bbox_op_invalid_c():
...
@@ -148,9 +200,9 @@ def test_random_vertical_flip_with_bbox_op_invalid_c():
def
test_random_vertical_flip_with_bbox_op_bad_c
():
def
test_random_vertical_flip_with_bbox_op_bad_c
():
"""
"""
Test
RandomHorizontalFlipWithBBox op with invalid bounding boxe
s
Test
s RandomVerticalFlipWithBBox Op with invalid bounding boxes, expected to catch multiple error
s
"""
"""
logger
.
info
(
"test_random_
horizontal_bbox_invalid_bounds
_c"
)
logger
.
info
(
"test_random_
vertical_flip_with_bbox_op_bad
_c"
)
test_op
=
c_vision
.
RandomVerticalFlipWithBBox
(
1
)
test_op
=
c_vision
.
RandomVerticalFlipWithBBox
(
1
)
data_voc2
=
ds
.
VOCDataset
(
DATA_DIR
,
task
=
"Detection"
,
mode
=
"train"
,
decode
=
True
,
shuffle
=
False
)
data_voc2
=
ds
.
VOCDataset
(
DATA_DIR
,
task
=
"Detection"
,
mode
=
"train"
,
decode
=
True
,
shuffle
=
False
)
...
@@ -166,5 +218,6 @@ def test_random_vertical_flip_with_bbox_op_bad_c():
...
@@ -166,5 +218,6 @@ def test_random_vertical_flip_with_bbox_op_bad_c():
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
test_random_vertical_flip_with_bbox_op_c
(
plot_vis
=
True
)
test_random_vertical_flip_with_bbox_op_c
(
plot_vis
=
True
)
test_random_vertical_flip_with_bbox_op_rand_c
(
plot_vis
=
True
)
test_random_vertical_flip_with_bbox_op_rand_c
(
plot_vis
=
True
)
test_random_vertical_flip_with_bbox_op_edge_c
(
plot_vis
=
True
)
test_random_vertical_flip_with_bbox_op_invalid_c
()
test_random_vertical_flip_with_bbox_op_invalid_c
()
test_random_vertical_flip_with_bbox_op_bad_c
()
test_random_vertical_flip_with_bbox_op_bad_c
()
tests/ut/python/dataset/test_resize_with_bbox.py
浏览文件 @
5a886794
...
@@ -15,281 +15,151 @@
...
@@ -15,281 +15,151 @@
"""
"""
Testing the resize with bounding boxes op in DE
Testing the resize with bounding boxes op in DE
"""
"""
from
enum
import
Enum
import
numpy
as
np
import
numpy
as
np
import
matplotlib.patches
as
patches
import
mindspore.dataset
as
ds
import
matplotlib.pyplot
as
plt
import
mindspore.dataset.transforms.vision.c_transforms
as
c_vision
import
mindspore.dataset.transforms.vision.c_transforms
as
c_vision
from
mindspore
import
log
as
logger
from
mindspore
import
log
as
logger
import
mindspore.dataset
as
ds
from
util
import
visualize_with_bounding_boxes
,
InvalidBBoxType
,
check_bad_bbox
,
\
save_and_check_md5
GENERATE_GOLDEN
=
False
GENERATE_GOLDEN
=
False
DATA_DIR
=
"../data/dataset/testVOC2012"
DATA_DIR
=
"../data/dataset/testVOC2012
_2
"
def
fix_annotate
(
bboxes
):
def
fix_annotate
(
bboxes
):
"""
"""
Fix annotations to format followed by mindspore.
:param bboxes: in [label, x_min, y_min, w, h, truncate, difficult] format
:param bboxes: in [label, x_min, y_min, w, h, truncate, difficult] format
:return: annotation in [x_min, y_min, w, h, label, truncate, difficult] format
:return: annotation in [x_min, y_min, w, h, label, truncate, difficult] format
"""
"""
for
bbox
in
bboxes
:
for
(
i
,
box
)
in
enumerate
(
bboxes
):
tmp
=
bbox
[
0
]
bboxes
[
i
]
=
np
.
roll
(
box
,
-
1
)
bbox
[
0
]
=
bbox
[
1
]
bbox
[
1
]
=
bbox
[
2
]
bbox
[
2
]
=
bbox
[
3
]
bbox
[
3
]
=
bbox
[
4
]
bbox
[
4
]
=
tmp
return
bboxes
return
bboxes
class
BoxType
(
Enum
):
def
test_resize_with_bbox_op_c
(
plot_vis
=
False
):
"""
"""
Defines box types for test cases
Prints images and bboxes side by side with and without ResizeWithBBox Op applied,
tests with MD5 check, expected to pass
"""
"""
WidthOverflow
=
1
logger
.
info
(
"test_resize_with_bbox_op_c"
)
HeightOverflow
=
2
NegativeXY
=
3
OnEdge
=
4
WrongShape
=
5
# Load dataset
dataVoc1
=
ds
.
VOCDataset
(
DATA_DIR
,
task
=
"Detection"
,
mode
=
"train"
,
decode
=
True
,
shuffle
=
False
)
class
AddBadAnnotation
:
# pylint: disable=too-few-public-methods
dataVoc2
=
ds
.
VOCDataset
(
DATA_DIR
,
task
=
"Detection"
,
mode
=
"train"
,
"""
decode
=
True
,
shuffle
=
False
)
Used to add erroneous bounding boxes to object detection pipelines.
Usage:
>>> # Adds a box that covers the whole image. Good for testing edge cases
>>> de = de.map(input_columns=["image", "annotation"],
>>> output_columns=["image", "annotation"],
>>> operations=AddBadAnnotation(BoxType.OnEdge))
"""
def
__init__
(
self
,
box_type
):
test_op
=
c_vision
.
ResizeWithBBox
(
200
)
self
.
box_type
=
box_type
def
__call__
(
self
,
img
,
bboxes
):
"""
Used to generate erroneous bounding box examples on given img.
:param img: image where the bounding boxes are.
:param bboxes: in [x_min, y_min, w, h, label, truncate, difficult] format
:return: bboxes with bad examples added
"""
height
=
img
.
shape
[
0
]
width
=
img
.
shape
[
1
]
if
self
.
box_type
==
BoxType
.
WidthOverflow
:
# use box that overflows on width
return
img
,
np
.
array
([[
0
,
0
,
width
+
1
,
height
-
1
,
0
,
0
,
0
]]).
astype
(
np
.
uint32
)
if
self
.
box_type
==
BoxType
.
HeightOverflow
:
# use box that overflows on height
return
img
,
np
.
array
([[
0
,
0
,
width
-
1
,
height
+
1
,
0
,
0
,
0
]]).
astype
(
np
.
uint32
)
if
self
.
box_type
==
BoxType
.
NegativeXY
:
# use box with negative xy
return
img
,
np
.
array
([[
-
10
,
-
10
,
width
-
1
,
height
-
1
,
0
,
0
,
0
]]).
astype
(
np
.
uint32
)
if
self
.
box_type
==
BoxType
.
OnEdge
:
# use box that covers the whole image
return
img
,
np
.
array
([[
0
,
0
,
width
-
1
,
height
-
1
,
0
,
0
,
0
]]).
astype
(
np
.
uint32
)
if
self
.
box_type
==
BoxType
.
WrongShape
:
# use box that covers the whole image
return
img
,
np
.
array
([[
0
,
0
,
width
-
1
]]).
astype
(
np
.
uint32
)
return
img
,
bboxes
def
check_bad_box
(
data
,
box_type
,
expected_error
):
try
:
test_op
=
c_vision
.
ResizeWithBBox
(
100
)
data
=
data
.
map
(
input_columns
=
[
"annotation"
],
output_columns
=
[
"annotation"
],
operations
=
fix_annotate
)
# map to use width overflow
data
=
data
.
map
(
input_columns
=
[
"image"
,
"annotation"
],
output_columns
=
[
"image"
,
"annotation"
],
columns_order
=
[
"image"
,
"annotation"
],
operations
=
AddBadAnnotation
(
box_type
))
# Add column for "annotation"
# map to apply ops
data
=
data
.
map
(
input_columns
=
[
"image"
,
"annotation"
],
output_columns
=
[
"image"
,
"annotation"
],
columns_order
=
[
"image"
,
"annotation"
],
operations
=
[
test_op
])
# Add column for "annotation"
for
_
,
_
in
enumerate
(
data
.
create_dict_iterator
()):
break
except
RuntimeError
as
e
:
logger
.
info
(
"Got an exception in DE: {}"
.
format
(
str
(
e
)))
assert
expected_error
in
str
(
e
)
def
add_bounding_boxes
(
axis
,
bboxes
):
"""
:param axis: axis to modify
:param bboxes: bounding boxes to draw on the axis
:return: None
"""
for
bbox
in
bboxes
:
rect
=
patches
.
Rectangle
((
bbox
[
0
],
bbox
[
1
]),
bbox
[
2
],
bbox
[
3
],
linewidth
=
1
,
edgecolor
=
'r'
,
facecolor
=
'none'
)
# Add the patch to the Axes
axis
.
add_patch
(
rect
)
def
visualize
(
unaugmented_data
,
augment_data
):
for
idx
,
(
un_aug_item
,
aug_item
)
in
enumerate
(
zip
(
unaugmented_data
.
create_dict_iterator
(),
augment_data
.
create_dict_iterator
())):
axis
=
plt
.
subplot
(
141
)
plt
.
imshow
(
un_aug_item
[
"image"
])
add_bounding_boxes
(
axis
,
un_aug_item
[
"annotation"
])
# add Orig BBoxes
plt
.
title
(
"Original"
+
str
(
idx
+
1
))
logger
.
info
(
"Original "
,
str
(
idx
+
1
),
" :"
,
un_aug_item
[
"annotation"
])
axis
=
plt
.
subplot
(
142
)
plt
.
imshow
(
aug_item
[
"image"
])
add_bounding_boxes
(
axis
,
aug_item
[
"annotation"
])
# add AugBBoxes
plt
.
title
(
"Augmented"
+
str
(
idx
+
1
))
logger
.
info
(
"Augmented "
,
str
(
idx
+
1
),
" "
,
aug_item
[
"annotation"
],
"
\n
"
)
plt
.
show
()
def
test_resize_with_bbox_op
(
plot
=
False
):
"""
Test resize_with_bbox_op
"""
logger
.
info
(
"Test resize with bbox"
)
# original images
data_original
=
ds
.
VOCDataset
(
DATA_DIR
,
task
=
"Detection"
,
mode
=
"train"
,
decode
=
True
,
shuffle
=
False
)
# augmented images
data_augmented
=
ds
.
VOCDataset
(
DATA_DIR
,
task
=
"Detection"
,
mode
=
"train"
,
decode
=
True
,
shuffle
=
False
)
data_original
=
data_original
.
map
(
input_columns
=
[
"annotation"
],
output_columns
=
[
"annotation"
],
operations
=
fix_annotate
)
data_augmented
=
data_augmented
.
map
(
input_columns
=
[
"annotation"
],
dataVoc1
=
dataVoc1
.
map
(
input_columns
=
[
"annotation"
],
output_columns
=
[
"annotation"
],
output_columns
=
[
"annotation"
],
operations
=
fix_annotate
)
operations
=
fix_annotate
)
dataVoc2
=
dataVoc2
.
map
(
input_columns
=
[
"annotation"
],
output_columns
=
[
"annotation"
],
operations
=
fix_annotate
)
# map to apply ops
dataVoc2
=
dataVoc2
.
map
(
input_columns
=
[
"image"
,
"annotation"
],
output_columns
=
[
"image"
,
"annotation"
],
columns_order
=
[
"image"
,
"annotation"
],
operations
=
[
test_op
])
# define map operations
filename
=
"resize_with_bbox_op_01_c_result.npz"
test_op
=
c_vision
.
ResizeWithBBox
(
100
)
# input value being the target size of resizeOp
save_and_check_md5
(
dataVoc2
,
filename
,
generate_golden
=
GENERATE_GOLDEN
)
data_augmented
=
data_augmented
.
map
(
input_columns
=
[
"image"
,
"annotation"
],
unaugSamp
,
augSamp
=
[],
[]
output_columns
=
[
"image"
,
"annotation"
],
columns_order
=
[
"image"
,
"annotation"
],
operations
=
[
test_op
])
if
plot
:
visualize
(
data_original
,
data_augmented
)
for
unAug
,
Aug
in
zip
(
dataVoc1
.
create_dict_iterator
(),
dataVoc2
.
create_dict_iterator
()):
unaugSamp
.
append
(
unAug
)
augSamp
.
append
(
Aug
)
def
test_resize_with_bbox_invalid_bounds
():
if
plot_vis
:
data_voc2
=
ds
.
VOCDataset
(
DATA_DIR
,
task
=
"Detection"
,
mode
=
"train"
,
decode
=
True
,
shuffle
=
False
)
visualize_with_bounding_boxes
(
unaugSamp
,
augSamp
)
check_bad_box
(
data_voc2
,
BoxType
.
WidthOverflow
,
"bounding boxes is out of bounds of the image"
)
data_voc2
=
ds
.
VOCDataset
(
DATA_DIR
,
task
=
"Detection"
,
mode
=
"train"
,
decode
=
True
,
shuffle
=
False
)
check_bad_box
(
data_voc2
,
BoxType
.
HeightOverflow
,
"bounding boxes is out of bounds of the image"
)
data_voc2
=
ds
.
VOCDataset
(
DATA_DIR
,
task
=
"Detection"
,
mode
=
"train"
,
decode
=
True
,
shuffle
=
False
)
check_bad_box
(
data_voc2
,
BoxType
.
NegativeXY
,
"min_x"
)
data_voc2
=
ds
.
VOCDataset
(
DATA_DIR
,
task
=
"Detection"
,
mode
=
"train"
,
decode
=
True
,
shuffle
=
False
)
check_bad_box
(
data_voc2
,
BoxType
.
WrongShape
,
"4 features"
)
def
test_resize_with_bbox_
invalid_size
(
):
def
test_resize_with_bbox_
op_edge_c
(
plot_vis
=
False
):
"""
"""
Test resize_with_bbox_op
Prints images and bboxes side by side with and without ResizeWithBBox Op applied,
"""
applied on dynamically generated edge case, expected to pass. edge case is when bounding
logger
.
info
(
"Test resize with bbox with invalid target size"
)
box has dimensions as the image itself.
"""
logger
.
info
(
"test_resize_with_bbox_op_edge_c"
)
dataVoc1
=
ds
.
VOCDataset
(
DATA_DIR
,
task
=
"Detection"
,
mode
=
"train"
,
decode
=
True
,
shuffle
=
False
)
# original images
dataVoc2
=
ds
.
VOCDataset
(
DATA_DIR
,
task
=
"Detection"
,
mode
=
"train"
,
data
=
ds
.
VOCDataset
(
DATA_DIR
,
task
=
"Detection"
,
mode
=
"train"
,
decode
=
True
,
shuffle
=
False
)
decode
=
True
,
shuffle
=
False
)
data
=
data
.
map
(
input_columns
=
[
"annotation"
],
test_op
=
c_vision
.
ResizeWithBBox
(
500
)
output_columns
=
[
"annotation"
],
operations
=
fix_annotate
)
# negative target size as input
dataVoc1
=
dataVoc1
.
map
(
input_columns
=
[
"annotation"
],
try
:
output_columns
=
[
"annotation"
],
test_op
=
c_vision
.
ResizeWithBBox
(
-
10
)
operations
=
fix_annotate
)
dataVoc2
=
dataVoc2
.
map
(
input_columns
=
[
"annotation"
],
output_columns
=
[
"annotation"
],
operations
=
fix_annotate
)
# map to apply ops
# maps to convert data into valid edge case data
data
=
data
.
map
(
input_columns
=
[
"image"
,
"annotation"
],
dataVoc1
=
dataVoc1
.
map
(
input_columns
=
[
"image"
,
"annotation"
],
output_columns
=
[
"image"
,
"annotation"
],
output_columns
=
[
"image"
,
"annotation"
],
columns_order
=
[
"image"
,
"annotation"
],
columns_order
=
[
"image"
,
"annotation"
],
operations
=
[
test_op
])
# Add column for "annotation"
operations
=
[
lambda
img
,
bboxes
:
(
img
,
np
.
array
([[
0
,
0
,
img
.
shape
[
1
],
img
.
shape
[
0
]]]).
astype
(
bboxes
.
dtype
))])
for
_
,
_
in
enumerate
(
data
.
create_dict_iterator
()):
# Test Op added to list of Operations here
break
dataVoc2
=
dataVoc2
.
map
(
input_columns
=
[
"image"
,
"annotation"
],
output_columns
=
[
"image"
,
"annotation"
],
columns_order
=
[
"image"
,
"annotation"
],
operations
=
[
lambda
img
,
bboxes
:
(
img
,
np
.
array
([[
0
,
0
,
img
.
shape
[
1
],
img
.
shape
[
0
]]]).
astype
(
bboxes
.
dtype
)),
test_op
])
except
ValueError
as
e
:
unaugSamp
,
augSamp
=
[],
[]
logger
.
info
(
"Got an exception in DE: {}"
.
format
(
str
(
e
)))
assert
"Input is not"
in
str
(
e
)
# zero target size as input
for
unAug
,
Aug
in
zip
(
dataVoc1
.
create_dict_iterator
(),
dataVoc2
.
create_dict_iterator
()):
try
:
unaugSamp
.
append
(
unAug
)
test_op
=
c_vision
.
ResizeWithBBox
(
0
)
augSamp
.
append
(
Aug
)
# map to apply ops
if
plot_vis
:
data
=
data
.
map
(
input_columns
=
[
"image"
,
"annotation"
],
visualize_with_bounding_boxes
(
unaugSamp
,
augSamp
)
output_columns
=
[
"image"
,
"annotation"
],
columns_order
=
[
"image"
,
"annotation"
],
operations
=
[
test_op
])
# Add column for "annotation"
for
_
,
_
in
enumerate
(
data
.
create_dict_iterator
()):
break
except
ValueError
as
e
:
def
test_resize_with_bbox_op_invalid_c
():
logger
.
info
(
"Got an exception in DE: {}"
.
format
(
str
(
e
)))
"""
assert
"Input is not"
in
str
(
e
)
Test ResizeWithBBox Op on invalid constructor parameters, expected to raise ValueError
"""
logger
.
info
(
"test_resize_with_bbox_op_invalid_c"
)
# invalid input shape
try
:
try
:
test_op
=
c_vision
.
ResizeWithBBox
((
10
,
10
,
10
))
# invalid interpolation value
c_vision
.
ResizeWithBBox
(
400
,
interpolation
=
"invalid"
)
# map to apply ops
except
ValueError
as
err
:
data
=
data
.
map
(
input_columns
=
[
"image"
,
"annotation"
],
logger
.
info
(
"Got an exception in DE: {}"
.
format
(
str
(
err
)))
output_columns
=
[
"image"
,
"annotation"
],
assert
"interpolation"
in
str
(
err
)
columns_order
=
[
"image"
,
"annotation"
],
operations
=
[
test_op
])
# Add column for "annotation"
for
_
,
_
in
enumerate
(
data
.
create_dict_iterator
()):
break
except
TypeError
as
e
:
def
test_resize_with_bbox_op_bad_c
():
logger
.
info
(
"Got an exception in DE: {}"
.
format
(
str
(
e
)))
assert
"Size should be"
in
str
(
e
)
def
test_resize_with_bbox_invalid_interpolation
():
"""
"""
Test resize_with_bbox_op
Tests ResizeWithBBox Op with invalid bounding boxes, expected to catch multiple errors
"""
"""
logger
.
info
(
"Test resize with bbox with invalid interpolation size"
)
logger
.
info
(
"test_resize_with_bbox_op_bad_c"
)
test_op
=
c_vision
.
ResizeWithBBox
((
200
,
300
))
# original images
data
=
ds
.
VOCDataset
(
DATA_DIR
,
task
=
"Detection"
,
mode
=
"train"
,
decode
=
True
,
shuffle
=
False
)
data
=
data
.
map
(
input_columns
=
[
"annotation"
],
output_columns
=
[
"annotation"
],
operations
=
fix_annotate
)
# invalid interpolation
try
:
test_op
=
c_vision
.
ResizeWithBBox
(
100
,
interpolation
=
"invalid"
)
# map to apply ops
data
=
data
.
map
(
input_columns
=
[
"image"
,
"annotation"
],
output_columns
=
[
"image"
,
"annotation"
],
columns_order
=
[
"image"
,
"annotation"
],
operations
=
[
test_op
])
# Add column for "annotation"
for
_
,
_
in
enumerate
(
data
.
create_dict_iterator
()):
data_voc2
=
ds
.
VOCDataset
(
DATA_DIR
,
task
=
"Detection"
,
mode
=
"train"
,
decode
=
True
,
shuffle
=
False
)
break
check_bad_bbox
(
data_voc2
,
test_op
,
InvalidBBoxType
.
WidthOverflow
,
"bounding boxes is out of bounds of the image"
)
data_voc2
=
ds
.
VOCDataset
(
DATA_DIR
,
task
=
"Detection"
,
mode
=
"train"
,
decode
=
True
,
shuffle
=
False
)
check_bad_bbox
(
data_voc2
,
test_op
,
InvalidBBoxType
.
HeightOverflow
,
"bounding boxes is out of bounds of the image"
)
data_voc2
=
ds
.
VOCDataset
(
DATA_DIR
,
task
=
"Detection"
,
mode
=
"train"
,
decode
=
True
,
shuffle
=
False
)
check_bad_bbox
(
data_voc2
,
test_op
,
InvalidBBoxType
.
NegativeXY
,
"min_x"
)
data_voc2
=
ds
.
VOCDataset
(
DATA_DIR
,
task
=
"Detection"
,
mode
=
"train"
,
decode
=
True
,
shuffle
=
False
)
check_bad_bbox
(
data_voc2
,
test_op
,
InvalidBBoxType
.
WrongShape
,
"4 features"
)
except
ValueError
as
e
:
logger
.
info
(
"Got an exception in DE: {}"
.
format
(
str
(
e
)))
assert
"interpolation"
in
str
(
e
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
test_resize_with_bbox_op
(
plot
=
False
)
test_resize_with_bbox_op
_c
(
plot_vis
=
False
)
test_resize_with_bbox_
invalid_bounds
(
)
test_resize_with_bbox_
op_edge_c
(
plot_vis
=
False
)
test_resize_with_bbox_
invalid_size
()
test_resize_with_bbox_
op_invalid_c
()
test_resize_with_bbox_
invalid_interpolation
()
test_resize_with_bbox_
op_bad_c
()
tests/ut/python/dataset/util.py
浏览文件 @
5a886794
...
@@ -312,34 +312,39 @@ def visualize_with_bounding_boxes(orig, aug, plot_rows=3):
...
@@ -312,34 +312,39 @@ def visualize_with_bounding_boxes(orig, aug, plot_rows=3):
if
len
(
orig
)
!=
len
(
aug
)
or
not
orig
:
if
len
(
orig
)
!=
len
(
aug
)
or
not
orig
:
return
return
comp_set
=
int
(
len
(
orig
)
/
plot_rows
)
batch_size
=
int
(
len
(
orig
)
/
plot_rows
)
# creates batches of images to plot together
split_point
=
batch_size
*
plot_rows
orig
,
aug
=
np
.
array
(
orig
),
np
.
array
(
aug
)
orig
,
aug
=
np
.
array
(
orig
),
np
.
array
(
aug
)
if
len
(
orig
)
>
plot_rows
:
if
len
(
orig
)
>
plot_rows
:
orig
=
np
.
split
(
orig
[:
comp_set
*
plot_rows
],
comp_set
)
+
[
orig
[
comp_set
*
plot_rows
:]]
# Create batches of required size and add remainder to last batch
aug
=
np
.
split
(
aug
[:
comp_set
*
plot_rows
],
comp_set
)
+
[
aug
[
comp_set
*
plot_rows
:]]
orig
=
np
.
split
(
orig
[:
split_point
],
batch_size
)
+
([
orig
[
split_point
:]]
if
(
split_point
<
orig
.
shape
[
0
])
else
[])
# check to avoid empty arrays being added
aug
=
np
.
split
(
aug
[:
split_point
],
batch_size
)
+
([
aug
[
split_point
:]]
if
(
split_point
<
aug
.
shape
[
0
])
else
[])
else
:
else
:
orig
=
[
orig
]
orig
=
[
orig
]
aug
=
[
aug
]
aug
=
[
aug
]
for
ix
,
allData
in
enumerate
(
zip
(
orig
,
aug
)):
for
ix
,
allData
in
enumerate
(
zip
(
orig
,
aug
)):
base_ix
=
ix
*
plot_rows
# will signal what base level we're on
base_ix
=
ix
*
plot_rows
# current batch starting index
curPlot
=
len
(
allData
[
0
])
sub_plot_count
=
2
if
(
len
(
allData
[
0
])
<
2
)
else
len
(
allData
[
0
])
# if 1 image remains, create subplot for 2 to simplify axis selection
fig
,
axs
=
plt
.
subplots
(
curPlot
,
2
)
fig
,
axs
=
plt
.
subplots
(
sub_plot_count
,
2
)
fig
.
tight_layout
(
pad
=
1.5
)
fig
.
tight_layout
(
pad
=
1.5
)
for
x
,
(
dataA
,
dataB
)
in
enumerate
(
zip
(
allData
[
0
],
allData
[
1
])):
for
x
,
(
dataA
,
dataB
)
in
enumerate
(
zip
(
allData
[
0
],
allData
[
1
])):
cur_ix
=
base_ix
+
x
cur_ix
=
base_ix
+
x
(
axA
,
axB
)
=
(
axs
[
x
,
0
],
axs
[
x
,
1
])
if
(
curPlot
>
1
)
else
(
axs
[
0
],
axs
[
1
])
# select plotting axes based on number of image rows on plot - else case when 1 row
axs
[
x
,
0
].
imshow
(
dataA
[
"image"
])
axA
.
imshow
(
dataA
[
"image"
])
add_bounding_boxes
(
axs
[
x
,
0
],
dataA
[
"annotation"
])
add_bounding_boxes
(
axA
,
dataA
[
"annotation"
])
axs
[
x
,
0
].
title
.
set_text
(
"Original"
+
str
(
cur_ix
+
1
))
axA
.
title
.
set_text
(
"Original"
+
str
(
cur_ix
+
1
))
logger
.
info
(
"Original **
\n
{} : {}"
.
format
(
str
(
cur_ix
+
1
),
dataA
[
"annotation"
]))
axs
[
x
,
1
].
imshow
(
dataB
[
"image"
])
axB
.
imshow
(
dataB
[
"image"
])
add_bounding_boxes
(
axs
[
x
,
1
],
dataB
[
"annotation"
])
add_bounding_boxes
(
axB
,
dataB
[
"annotation"
])
axs
[
x
,
1
].
title
.
set_text
(
"Augmented"
+
str
(
cur_ix
+
1
))
axB
.
title
.
set_text
(
"Augmented"
+
str
(
cur_ix
+
1
))
logger
.
info
(
"Original **
\n
{} : {}"
.
format
(
str
(
cur_ix
+
1
),
dataA
[
"annotation"
]))
logger
.
info
(
"Augmented **
\n
{} : {}
\n
"
.
format
(
str
(
cur_ix
+
1
),
dataB
[
"annotation"
]))
logger
.
info
(
"Augmented **
\n
{} : {}
\n
"
.
format
(
str
(
cur_ix
+
1
),
dataB
[
"annotation"
]))
plt
.
show
()
plt
.
show
()
...
...
tests/ut/python/ops/test_signature.py
浏览文件 @
5a886794
...
@@ -16,7 +16,6 @@
...
@@ -16,7 +16,6 @@
test assign sub
test assign sub
"""
"""
import
numpy
as
np
import
numpy
as
np
import
pytest
import
mindspore.context
as
context
import
mindspore.context
as
context
import
mindspore.nn
as
nn
import
mindspore.nn
as
nn
...
@@ -36,27 +35,6 @@ class AssignW(nn.Cell):
...
@@ -36,27 +35,6 @@ class AssignW(nn.Cell):
return
x
return
x
class Net(nn.Cell):
    """Cell that writes ``value`` into its parameter ``b`` through an AssignW sub-cell."""

    def __init__(self):
        super(Net, self).__init__()
        # One learnable vector of five ones; AssignW performs the in-graph assignment.
        self.b = Parameter(initializer('ones', [5]), name='b')
        self.assign = AssignW()

    def construct(self, value):
        # Delegate: assign ``value`` into ``self.b`` and return the result.
        return self.assign(self.b, value)
def test_assign_through_cell():
    """A float16-cast cell with fp16 flags cleared still type-checks its inputs."""
    context.set_context(mode=context.GRAPH_MODE)
    net = Net()
    # Cast the whole cell to float16, then recursively clear the fp16 flag.
    net.to_float(ms.float16)
    net.add_flags_recursive(fp16=False)
    # A well-typed tensor input must be accepted...
    net(Tensor(np.ones([5]).astype(np.float32)))
    # ...while None must be rejected with a TypeError.
    with pytest.raises(TypeError):
        net(None)
class
AssignOp
(
nn
.
Cell
):
class
AssignOp
(
nn
.
Cell
):
def
__init__
(
self
):
def
__init__
(
self
):
super
(
AssignOp
,
self
).
__init__
()
super
(
AssignOp
,
self
).
__init__
()
...
...
tests/ut/python/parallel/test_parallel_optimizer.py
0 → 100644
浏览文件 @
5a886794
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test adam """
import
numpy
as
np
import
pytest
import
mindspore.nn
as
nn
from
mindspore
import
Tensor
from
mindspore.common.api
import
_executor
from
mindspore.nn
import
TrainOneStepCell
,
WithLossCell
from
mindspore.nn.optim
import
Adam
,
AdamWeightDecay
,
AdamWeightDecayDynamicLR
,
Lamb
from
mindspore.ops
import
operations
as
P
from
mindspore.parallel._auto_parallel_context
import
auto_parallel_context
from
mindspore
import
context
class Net(nn.Cell):
    """Net definition: a small attention-style block used to exercise parallel optimizers."""

    def __init__(self):
        super(Net, self).__init__()
        # Three parallel 128 -> 768 projections of the same input.
        self.fc1 = nn.Dense(128, 768, activation='relu')
        self.fc2 = nn.Dense(128, 768, activation='relu')
        self.fc3 = nn.Dense(128, 768, activation='relu')
        # Output head, 768 -> 768.
        self.fc4 = nn.Dense(768, 768, activation='relu')
        self.relu4 = nn.ReLU()
        self.relu5 = nn.ReLU()
        self.transpose = P.Transpose()
        self.matmul1 = P.MatMul()
        self.matmul2 = P.MatMul()

    def construct(self, x):
        """Attention-like forward pass: score = relu(q @ k.T), out = fc4(relu(score @ v))."""
        query = self.fc1(x)
        key = self.fc2(x)
        value = self.fc3(x)
        key = self.transpose(key, (1, 0))
        scores = self.relu4(self.matmul1(query, key))
        weighted = self.relu5(self.matmul2(scores, value))
        return self.fc4(weighted)
def test_AdamWeightDecayDynamicLR():
    """Compile a train network with AdamWeightDecayDynamicLR under the parallel optimizer."""
    # Enable the parallel optimizer in data-parallel mode across 2 devices.
    auto_parallel_context().set_enable_parallel_optimizer(True)
    context.set_auto_parallel_context(parallel_mode="data_parallel", device_num=2)
    batch = Tensor(np.ones([32, 128]).astype(np.float32))
    target = Tensor(np.zeros([32, 768]).astype(np.float32))
    model = Net()
    model.set_train()
    criterion = nn.SoftmaxCrossEntropyWithLogits()
    opt = AdamWeightDecayDynamicLR(model.trainable_params(), decay_steps=20, learning_rate=0.1)
    with_loss = WithLossCell(model, criterion)
    train_network = TrainOneStepCell(with_loss, opt)
    # Graph compilation is the assertion: it must not raise.
    _executor.compile(train_network, batch, target)
def test_AdamWeightDecay():
    """ test_AdamWeightDecay """
    # NOTE: docstring previously said "test_AdamWeightDecayDynamicLR" — a copy-paste
    # slip from the test above; fixed to name this test.
    # Enable the parallel optimizer in data-parallel mode across 2 devices.
    auto_parallel_context().set_enable_parallel_optimizer(True)
    context.set_auto_parallel_context(parallel_mode="data_parallel", device_num=2)
    inputs = Tensor(np.ones([32, 128]).astype(np.float32))
    label = Tensor(np.zeros([32, 768]).astype(np.float32))
    net = Net()
    net.set_train()
    loss = nn.SoftmaxCrossEntropyWithLogits()
    optimizer = AdamWeightDecay(net.trainable_params(), learning_rate=0.1)
    net_with_loss = WithLossCell(net, loss)
    train_network = TrainOneStepCell(net_with_loss, optimizer)
    # Graph compilation is the assertion: it must not raise.
    _executor.compile(train_network, inputs, label)
def test_lamb_compile():
    """Compile a train network with the Lamb optimizer under auto-parallel mode."""
    # Parallel optimizer on, auto-parallel across 2 devices.
    auto_parallel_context().set_enable_parallel_optimizer(True)
    context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=2)
    batch = Tensor(np.ones([32, 128]).astype(np.float32))
    target = Tensor(np.zeros([32, 768]).astype(np.float32))
    model = Net()
    model.set_train()
    criterion = nn.SoftmaxCrossEntropyWithLogits()
    opt = Lamb(model.trainable_params(), decay_steps=10)
    with_loss = WithLossCell(model, criterion)
    train_network = TrainOneStepCell(with_loss, opt)
    # Graph compilation is the assertion: it must not raise.
    _executor.compile(train_network, batch, target)
def test_edge_case():
    """Unsupported configurations must raise RuntimeError when the parallel optimizer is on."""
    auto_parallel_context().set_enable_parallel_optimizer(True)
    model = Net()
    # stand_alone mode is incompatible with the parallel optimizer.
    with pytest.raises(RuntimeError):
        context.set_auto_parallel_context(parallel_mode="stand_alone")
        Lamb(model.trainable_params(), decay_steps=10)
    # Adam is not supported by the parallel optimizer.
    with pytest.raises(RuntimeError):
        Adam(model.trainable_params(), learning_rate=0.1)
    # 16 devices is an unsupported device count here.
    with pytest.raises(RuntimeError):
        context.set_auto_parallel_context(device_num=16)
        Lamb(model.trainable_params(), decay_steps=10)
tests/ut/python/parallel/test_set_auto_parallel_context.py
浏览文件 @
5a886794
...
@@ -81,6 +81,10 @@ def test_set_auto_parallel_context():
...
@@ -81,6 +81,10 @@ def test_set_auto_parallel_context():
with
pytest
.
raises
(
ValueError
):
with
pytest
.
raises
(
ValueError
):
set_algo_parameters
(
tensor_slice_align_size
=
1025
)
set_algo_parameters
(
tensor_slice_align_size
=
1025
)
auto_parallel_context
().
set_enable_parallel_optimizer
(
True
)
assert
auto_parallel_context
().
get_enable_parallel_optimizer
()
is
True
assert
not
auto_parallel_context
().
get_all_reduce_fusion_split_indices
()
def
test_reset_auto_parallel_context
():
def
test_reset_auto_parallel_context
():
context
.
reset_auto_parallel_context
()
context
.
reset_auto_parallel_context
()
...
...
tests/ut/python/pynative_mode/ge/model/__init__.py
已删除
100644 → 0
浏览文件 @
7f54d17b
tests/ut/python/pynative_mode/ge/model/test_lenet_model.py
已删除
100644 → 0
浏览文件 @
7f54d17b
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test_lenet_model """
import
numpy
as
np
import
pytest
import
mindspore.nn
as
nn
from
mindspore.common.tensor
import
Tensor
from
mindspore.nn
import
WithGradCell
from
mindspore.ops
import
operations
as
P
class LeNet5(nn.Cell):
    """Classic LeNet-5: two conv/pool stages followed by a three-layer classifier head."""

    def __init__(self):
        super(LeNet5, self).__init__()
        # Feature extractor: 1 -> 6 -> 16 channels, 5x5 kernels, no padding.
        self.conv1 = nn.Conv2d(1, 6, 5, pad_mode='valid')
        self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
        # Classifier: flattened 16*5*5 features down to 10 logits.
        self.fc1 = nn.Dense(16 * 5 * 5, 120)
        self.fc2 = nn.Dense(120, 84)
        self.fc3 = nn.Dense(84, 10)
        self.relu = nn.ReLU()
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        self.flatten = P.Flatten()

    def construct(self, x):
        """Forward pass; returns unnormalized class logits."""
        feat = self.max_pool2d(self.relu(self.conv1(x)))
        feat = self.max_pool2d(self.relu(self.conv2(feat)))
        feat = self.flatten(feat)
        feat = self.relu(self.fc1(feat))
        feat = self.relu(self.fc2(feat))
        return self.fc3(feat)
@pytest.mark.skip(reason="need ge backend")
def test_lenet_pynative_train_net():
    """Build the gradient machinery for a pynative LeNet training step (GE backend only)."""
    data = Tensor(np.ones([1, 1, 32, 32]).astype(np.float32) * 0.01)
    label = Tensor(np.ones([1, 10]).astype(np.float32))
    dout = Tensor(np.ones([1]).astype(np.float32))
    iteration_num = 1
    verification_step = 0  # NOTE(review): unused — kept for parity with the original test
    net = LeNet5()
    for _ in range(0, iteration_num):
        # Construct the loss/gradient cells for this step.
        # NOTE(review): their outputs are never consumed — the test only checks construction.
        loss_fn = nn.SoftmaxCrossEntropyWithLogits(is_grad=False)
        grad_fn = nn.SoftmaxCrossEntropyWithLogits()
        grad_net = WithGradCell(net, grad_fn, sens=dout)
def test_lenet_pynative_train_model():
    """Placeholder test: the loss should eventually come from model.compute_loss."""
    return None
tests/ut/python/utils/test_serialize.py
浏览文件 @
5a886794
...
@@ -336,6 +336,7 @@ class PrintNet(nn.Cell):
...
@@ -336,6 +336,7 @@ class PrintNet(nn.Cell):
def
construct
(
self
,
int8
,
uint8
,
int16
,
uint16
,
int32
,
uint32
,
int64
,
uint64
,
flt16
,
flt32
,
flt64
,
bool_
,
def
construct
(
self
,
int8
,
uint8
,
int16
,
uint16
,
int32
,
uint32
,
int64
,
uint64
,
flt16
,
flt32
,
flt64
,
bool_
,
scale1
,
scale2
):
scale1
,
scale2
):
self
.
print
(
'============tensor int8:=============='
,
int8
)
self
.
print
(
'============tensor int8:=============='
,
int8
)
self
.
print
(
'============tensor int8:=============='
,
int8
)
self
.
print
(
'============tensor uint8:=============='
,
uint8
)
self
.
print
(
'============tensor uint8:=============='
,
uint8
)
self
.
print
(
'============tensor int16:=============='
,
int16
)
self
.
print
(
'============tensor int16:=============='
,
int16
)
self
.
print
(
'============tensor uint16:=============='
,
uint16
)
self
.
print
(
'============tensor uint16:=============='
,
uint16
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录