magicwindyyd / mindspore (forked from MindSpore / mindspore)
Commit e5c35169
Authored May 29, 2020 by Yi Huaijie

support load full dataset on each device

Parent 8de8289c

Showing 14 changed files with 229 additions and 79 deletions (+229 -79)
mindspore/ccsrc/parallel/context.cc                                 +3   -0
mindspore/ccsrc/parallel/context.h                                  +4   -0
mindspore/ccsrc/parallel/ops_info/get_next_info.cc                  +15  -2
mindspore/ccsrc/parallel/ops_info/virtual_dataset_info.cc           +29  -50
mindspore/ccsrc/parallel/step_parallel.cc                           +9   -1
mindspore/ccsrc/pipeline/init.cc                                    +2   -0
mindspore/context.py                                                +3   -1
mindspore/parallel/_auto_parallel_context.py                        +21  -3
mindspore/parallel/_utils.py                                        +16  -0
mindspore/train/dataset_helper.py                                   +7   -13
model_zoo/wide_and_deep/src/config.py                               +3   -0
model_zoo/wide_and_deep/src/metrics.py                              +15  -5
model_zoo/wide_and_deep/train_and_test_multinpu_auto_parallel.py    +13  -4
tests/ut/python/parallel/test_full_batch.py                         +89  -0
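Summary: this commit adds a `full_batch` option to the auto parallel context. When it is enabled, every device loads the complete global batch rather than a per-device shard: the batch axis is no longer mapped onto the device matrix, GetNext keeps the global batch size, the dataset helpers skip the full-shape/slice conversion, and the Wide&Deep model zoo scripts and AUC metric are adapted accordingly. A minimal usage sketch based on the docstring and unit test added in this commit (the surrounding training setup is assumed, not part of the diff):

    from mindspore import context
    from mindspore.train import ParallelMode

    # With full_batch=True each of the 8 devices reads the identical,
    # full-sized batch from the dataset instead of a per-rank shard.
    context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL,
                                      device_num=8, full_batch=True)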
mindspore/ccsrc/parallel/context.cc

@@ -48,6 +48,7 @@ ParallelContext::ParallelContext() { Reset(); }
 void ParallelContext::Reset() {
   mirror_mean_ = false;
+  full_batch_ = false;
   cast_before_mirror_ = true;
   loss_repeated_mean_ = true;
   device_num_ = 1;

@@ -75,6 +76,8 @@ void ParallelContext::set_global_rank(int32_t global_rank) {
 void ParallelContext::set_mirror_mean(bool mirror_mean) { mirror_mean_ = mirror_mean; }

+void ParallelContext::set_full_batch(bool full_batch) { full_batch_ = full_batch; }
+
 void ParallelContext::set_cast_before_mirror(bool cast_before_mirror) { cast_before_mirror_ = cast_before_mirror; }

 void ParallelContext::set_loss_repeated_mean(bool loss_repeated_mean) { loss_repeated_mean_ = loss_repeated_mean; }
mindspore/ccsrc/parallel/context.h

@@ -55,6 +55,9 @@ class ParallelContext {
   void set_mirror_mean(bool mirror_mean);
   bool mirror_mean() const { return mirror_mean_; }

+  void set_full_batch(bool full_batch);
+  bool full_batch() const { return full_batch_; }
+
   void set_cast_before_mirror(bool cast_before_mirror);
   bool cast_before_mirror() const { return cast_before_mirror_; }

@@ -103,6 +106,7 @@ class ParallelContext {
   ParallelContext();
   static std::shared_ptr<ParallelContext> inst_context_;
   bool mirror_mean_;
+  bool full_batch_;
   bool cast_before_mirror_;
   bool loss_repeated_mean_;
   int32_t device_num_;
mindspore/ccsrc/parallel/ops_info/get_next_info.cc

@@ -24,15 +24,23 @@
 #include "ir/value.h"
 #include "parallel/device_matrix.h"
 #include "parallel/strategy.h"
+#include "parallel/context.h"
 #include "parallel/tensor_layout/tensor_redistribution.h"

 namespace mindspore {
 namespace parallel {
 Status GetNextInfo::InferTensorMap() {
+  MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
+  bool full_batch = ParallelContext::GetInstance()->full_batch();
+
   for (auto shp : shapes_) {
     TensorMap out_tensor_map;
     for (size_t i = 0; i < shp.size(); ++i) {
-      out_tensor_map.push_back(SizeToInt(dev_matrix_shape_.size() - i - 1));
+      if (full_batch) {
+        out_tensor_map.push_back(MAP_NONE);
+      } else {
+        out_tensor_map.push_back(SizeToInt(dev_matrix_shape_.size() - i - 1));
+      }
     }
     outputs_tensor_map_.push_back(out_tensor_map);
   }

@@ -190,6 +198,9 @@ Status GetNextInfo::GetAttrs() {
 }

 Status GetNextInfo::InferReplaceOps(const StrategyPtr &) {
+  MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
+  bool full_batch = ParallelContext::GetInstance()->full_batch();
   Shapes out_shapes = outputs_shape_;
   for (size_t i = 0; i < out_shapes.size(); ++i) {
     if (dev_num_ <= 0) {

@@ -200,7 +211,9 @@ Status GetNextInfo::InferReplaceOps(const StrategyPtr &) {
       MS_LOG(ERROR) << name_ << " : batch num cannot floor div dev num.";
       return FAILED;
     }
-    out_shapes[i][0] = out_shapes[i][0] / dev_num_;
+    if (!full_batch) {
+      out_shapes[i][0] = out_shapes[i][0] / dev_num_;
+    }
   }
   ValuePtr new_shapes = MakeValue(out_shapes);
   Attr attr_types = std::make_pair(TYPES, attrs_[TYPES]);
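The net effect of the two `full_batch` checks above: the tensor map for GetNext outputs marks the batch axis as unsplit, and `InferReplaceOps` stops dividing the output batch dimension by the device count. Illustrative arithmetic with assumed numbers (not taken from the commit):

    # Per-device GetNext batch dimension, assuming a global batch of 256 on 8 devices.
    global_batch, dev_num = 256, 8
    per_device = global_batch // dev_num   # 32  -- previous behaviour (batch is sharded)
    full_batch = global_batch              # 256 -- with full_batch=True nothing is divided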
mindspore/ccsrc/parallel/ops_info/virtual_dataset_info.cc

@@ -23,6 +23,7 @@
 #include "parallel/device_manager.h"
 #include "parallel/device_matrix.h"
 #include "parallel/step_parallel.h"
+#include "parallel/context.h"
 #include "utils/log_adapter.h"

 namespace mindspore {

@@ -93,59 +94,21 @@ Status VirtualDatasetInfo::InferDevMatrixShape() {
   return SUCCESS;
 }

-Status VirtualDatasetInfo::InferMirrorOps() {
-  mirror_ops_.clear();
-  int32_t stage = strategy_->GetInputStage();
-  CheckGlobalDeviceManager();
-  RankList dev_list = g_device_manager->GetDeviceListByStageId(stage);
-  if (dev_list.empty()) {
-    MS_LOG(ERROR) << name_ << ": The current stage is empty!";
-    return Status::FAILED;
-  }
-  if (dev_list.size() == 1) {
-    MS_LOG(INFO) << name_ << ": No need mirror ops.";
-    return Status::SUCCESS;
-  }
-
-  OperatorName operator_name = BROADCAST;
-  ValuePtr attr0_value = MakeValue(dev_list.front());
-  std::vector<Group> group_list;
-  if (CreateGroupByDim(dev_matrix_shape_.size() - 1, &group_list) != SUCCESS) {
-    MS_LOG(ERROR) << name_ << ": Infer mirror ops, create group failed.";
-    return FAILED;
-  } else if (group_list.empty()) {
-    MS_LOG(INFO) << name_ << ": No need mirror ops.";
-    return SUCCESS;
-  }
-  std::string group = group_list[0].name();
-  ValuePtr attr1_value = MakeValue(group);
-  Attr attr0 = std::make_pair(SRC, attr0_value);
-  Attr attr1 = std::make_pair(GROUP, attr1_value);
-  OperatorAttrs operator_attrs = {attr0, attr1};
-  OperatorParams operator_param;
-  OperatorArgs operator_args = std::make_pair(operator_attrs, operator_param);
-  Operator op = std::make_pair(operator_name, operator_args);
-  OperatorVector op_vector = {op};
-
-  size_t size = inputs_shape_.size();
-  for (size_t i = 0; i < size; ++i) {
-    mirror_ops_.push_back(op_vector);
-  }
-
-  mirror_ops_.clear();
-  return SUCCESS;
-}
+Status VirtualDatasetInfo::InferMirrorOps() { return SUCCESS; }

 Status VirtualDatasetInfo::InferForwardCommunication() { return SUCCESS; }

 Status VirtualDatasetInfo::InferTensorMap() {
+  MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
+  bool full_batch = ParallelContext::GetInstance()->full_batch();
+
   for (size_t i = 0; i < strategy_->GetInputNumber(); i++) {
     std::vector<int32_t> tensor_map_index;
-    tensor_map_index.push_back((int32_t)(LAST_INDEX(SizeToUint(dev_matrix_shape_.size()))));
+    if (full_batch) {
+      tensor_map_index.push_back(MAP_NONE);
+    } else {
+      tensor_map_index.push_back((int32_t)(LAST_INDEX(SizeToUint(dev_matrix_shape_.size()))));
+    }
     for (size_t j = 1; j < strategy_->GetInputDim()[i].size(); ++j) {
       tensor_map_index.push_back(MAP_NONE);
     }

@@ -213,6 +176,10 @@ Status VirtualDatasetInfo::SetCostUnderStrategy(const StrategyPtr &strategy) {
 }

 Status VirtualDatasetInfo::GenerateStrategies(int32_t stage_id) {
+  MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
+  bool full_batch = ParallelContext::GetInstance()->full_batch();
+  size_t total_dev_num;
+
   if (GetAttrs() != SUCCESS) {
     MS_LOG(ERROR) << name_ << ": GetAttrs failed";
     return FAILED;

@@ -220,7 +187,11 @@ Status VirtualDatasetInfo::GenerateStrategies(int32_t stage_id) {
   CheckGlobalDeviceManager();
   is_auto_parallel_ = true;
-  size_t total_dev_num = g_device_manager->GetDeviceListByStageId(stage_id).size();
+  if (full_batch) {
+    total_dev_num = 1;
+  } else {
+    total_dev_num = g_device_manager->GetDeviceListByStageId(stage_id).size();
+  }
   StrategyPtr sp;
   std::vector<Dimensions> strategy;
   for (auto &shape : inputs_shape_) {

@@ -232,10 +203,18 @@ Status VirtualDatasetInfo::GenerateStrategies(int32_t stage_id) {
   sp = std::make_shared<Strategy>(stage_id, strategy);
   if (SetCostUnderStrategy(sp) == SUCCESS) {
-    MS_LOG(INFO) << name_ << ": Successfully generated batch-parallel-strategy.";
+    if (full_batch) {
+      MS_LOG(INFO) << name_ << ": Successfully generated full-batch-parallel-strategy.";
+    } else {
+      MS_LOG(INFO) << name_ << ": Successfully generated batch-parallel-strategy.";
+    }
     PrintStrategy(sp);
   } else {
-    MS_LOG(ERROR) << name_ << ": Generating batch-parallel-strategy failed.";
+    if (full_batch) {
+      MS_LOG(ERROR) << name_ << ": Generating full-batch-parallel-strategy failed.";
+    } else {
+      MS_LOG(ERROR) << name_ << ": Generating batch-parallel-strategy failed.";
+    }
     return FAILED;
   }
   return SUCCESS;
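Both `GetNextInfo::InferTensorMap` and `VirtualDatasetInfo::InferTensorMap` now apply the same rule: under `full_batch` the batch axis gets `MAP_NONE` instead of the last device-matrix dimension, so it is never sliced across devices. A rough Python sketch of that rule (`MAP_NONE` is assumed to be the usual -1 "unmapped" marker; the real constant is defined in the C++ parallel code):

    MAP_NONE = -1  # assumed value of the "axis not sliced" marker

    def batch_dim_map(dev_matrix_size, full_batch):
        # Mirrors the InferTensorMap change: map the batch axis onto the last
        # device-matrix dimension unless full_batch is set.
        if full_batch:
            return MAP_NONE             # batch axis stays whole on every device
        return dev_matrix_size - 1      # batch axis is split across devices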
mindspore/ccsrc/parallel/step_parallel.cc

@@ -1375,11 +1375,19 @@ void SetClonedTensorShapeForOptimizer(const FuncGraphPtr &root) {
 void SetVirtualDatasetStrategy(const CNodePtr &node) {
   MS_EXCEPTION_IF_NULL(node);
+  MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
+  bool full_batch = ParallelContext::GetInstance()->full_batch();
+
   PrimitivePtr prim = GetValueNode<PrimitivePtr>(node->input(0));
   MS_EXCEPTION_IF_NULL(prim);
   if (prim->name() == VIRTUAL_DATA_SET) {
     CheckGlobalDeviceManager();
-    int32_t dev_num = SizeToInt(g_device_manager->GetDeviceListByStageId(0).size());
+    int32_t dev_num;
+    if (full_batch) {
+      dev_num = 1;
+    } else {
+      dev_num = SizeToInt(g_device_manager->GetDeviceListByStageId(0).size());
+    }
     auto attrs_temp = prim->attrs();
     std::vector<Shapes> shape_list = ExtractShape(node);
     if (shape_list.empty()) {
mindspore/ccsrc/pipeline/init.cc

@@ -187,6 +187,8 @@ PYBIND11_MODULE(_c_expression, m) {
          "Set strategy checkpoint save file.")
     .def("get_strategy_ckpt_load_file", &ParallelContext::strategy_ckpt_load_file, "Get strategy checkpoint load file.")
     .def("get_strategy_ckpt_save_file", &ParallelContext::strategy_ckpt_save_file, "Get strategy checkpoint save file.")
+    .def("set_full_batch", &ParallelContext::set_full_batch, "Set whether load full batch on each device.")
+    .def("get_full_batch", &ParallelContext::full_batch, "Get whether load full batch on each device.")
     .def("reset", &ParallelContext::Reset, "Reset auto parallel context.");
   (void)py::class_<CostModelContext, std::shared_ptr<CostModelContext>>(m, "CostModelContext")
mindspore/context.py

@@ -367,7 +367,8 @@ def _context():
 @args_type_check(device_num=int, global_rank=int, mirror_mean=bool, cast_before_mirror=bool, parallel_mode=str,
-                 parameter_broadcast=bool, strategy_ckpt_load_file=str, strategy_ckpt_save_file=str)
+                 parameter_broadcast=bool, strategy_ckpt_load_file=str, strategy_ckpt_save_file=str,
+                 full_batch=bool)
 def set_auto_parallel_context(**kwargs):
     """
     Set auto parallel context.

@@ -404,6 +405,7 @@ def set_auto_parallel_context(**kwargs):
                     broadcast. Default: False.
         strategy_ckpt_load_file (str): The path to load parallel strategy checkpoint. Default: ''
         strategy_ckpt_save_file (str): The path to save parallel strategy checkpoint. Default: ''
+        full_batch (bool): Whether to load the whole batch on each device. Default: False.

     Raises:
         ValueError: If input key is not attribute in auto parallel context.
mindspore/parallel/_auto_parallel_context.py

@@ -225,6 +225,21 @@ class _AutoParallelContext:
         self.check_context_handle()
         return self._context_handle.get_strategy_ckpt_load_file()

+    def set_full_batch(self, full_batch):
+        """
+        Set whether load full batch on each device.
+
+        Args:
+            full_batch (bool): True if load full batch on each device.
+        """
+        self.check_context_handle()
+        self._context_handle.set_full_batch(full_batch)
+
+    def get_full_batch(self):
+        """Get whether load full batch on each device."""
+        self.check_context_handle()
+        return self._context_handle.get_full_batch()
+
     def set_strategy_ckpt_save_file(self, strategy_ckpt_save_file):
         """
         Set strategy checkpoint save path.

@@ -415,7 +430,8 @@ _set_auto_parallel_context_func_map = {
     "parallel_mode": auto_parallel_context().set_parallel_mode,
     "parameter_broadcast": auto_parallel_context().set_parameter_broadcast,
     "strategy_ckpt_load_file": auto_parallel_context().set_strategy_ckpt_load_file,
-    "strategy_ckpt_save_file": auto_parallel_context().set_strategy_ckpt_save_file}
+    "strategy_ckpt_save_file": auto_parallel_context().set_strategy_ckpt_save_file,
+    "full_batch": auto_parallel_context().set_full_batch}

 _get_auto_parallel_context_func_map = {

@@ -427,12 +443,13 @@ _get_auto_parallel_context_func_map = {
     "parallel_mode": auto_parallel_context().get_parallel_mode,
     "parameter_broadcast": auto_parallel_context().get_parameter_broadcast,
     "strategy_ckpt_load_file": auto_parallel_context().get_strategy_ckpt_load_file,
-    "strategy_ckpt_save_file": auto_parallel_context().get_strategy_ckpt_save_file}
+    "strategy_ckpt_save_file": auto_parallel_context().get_strategy_ckpt_save_file,
+    "full_batch": auto_parallel_context().get_full_batch}


 @args_type_check(device_num=int, global_rank=int, mirror_mean=bool, cast_before_mirror=bool,
                  loss_repeated_mean=bool, parallel_mode=str, parameter_broadcast=bool,
-                 strategy_ckpt_load_file=str, strategy_ckpt_save_file=str)
+                 strategy_ckpt_load_file=str, strategy_ckpt_save_file=str, full_batch=bool)
 def _set_auto_parallel_context(**kwargs):
     """
     Set auto parallel context.

@@ -465,6 +482,7 @@ def _set_auto_parallel_context(**kwargs):
                     broadcast. Default: False.
         strategy_ckpt_load_file (str): The path to load parallel strategy checkpoint. Default: ''
         strategy_ckpt_save_file (str): The path to save parallel strategy checkpoint. Default: ''
+        full_batch (bool): Whether to load the whole batch on each device. Default: False.

     Raises:
         ValueError: If input key is not attribute in auto parallel context.
mindspore/parallel/_utils.py

@@ -20,10 +20,26 @@ from mindspore.parallel._auto_parallel_context import auto_parallel_context
 def _get_parallel_mode():
     """Get parallel mode."""
     return auto_parallel_context().get_parallel_mode()


+def _get_full_batch():
+    """Get whether to use full_batch."""
+    return auto_parallel_context().get_full_batch()
+
+
+def _need_to_full():
+    """Check whether to convert input to full shape or tensor."""
+    parallel_mode = _get_parallel_mode()
+    full_batch = _get_full_batch()
+    need = ((parallel_mode in ("semi_auto_parallel", "auto_parallel")) and (not full_batch))
+    return need
+
+
 def _get_mirror_mean():
     """Get if using mirror_mean."""
     return auto_parallel_context().get_mirror_mean()
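`_need_to_full` centralizes a decision that dataset_helper.py previously made inline by comparing the parallel mode against the ParallelMode enum; the next file switches all three iterator classes over to it. Its behaviour, read straight off the definition:

    parallel_mode          full_batch   _need_to_full()
    semi_auto_parallel     False        True
    semi_auto_parallel     True         False
    auto_parallel          False        True
    auto_parallel          True         False
    any other mode         any          False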
mindspore/train/dataset_helper.py

@@ -17,11 +17,10 @@ import math
 from mindspore._checkparam import check_bool
 from .. import context
-from .parallel_utils import ParallelMode
 from ._utils import _exec_datagraph, _get_types_and_shapes, _to_tensor, \
     _construct_tensor_list, _to_full_shapes, _to_full_tensor
 from ..nn.wrap import GetNextSingleOp
-from ..parallel._utils import _get_device_num, _get_global_rank, _get_parallel_mode
+from ..parallel._utils import _get_device_num, _get_global_rank, _need_to_full


 class DatasetHelper:

@@ -118,10 +117,10 @@ class _DatasetIterMSLoopSink(_DatasetIter):
     def __init__(self, dataset):
         super(_DatasetIterMSLoopSink, self).__init__(dataset)
         self.loop_count = self.get_loop_count(dataset)
-        # for self._parallel_mode equal to semi_auto_parallel or auto_parallel, use a complete tensor to
-        # compile, and slice tensor to run. The batch dimension of tensors for compile is device_number
-        # times the batch dimension of tensors for run. Now only support LoopSink.
-        if _get_parallel_mode() in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
+        # for self._parallel_mode equal to semi_auto_parallel or auto_parallel, and not using full_batch,
+        # use a complete tensor to compile, and slice tensor to run. The batch dimension of tensors for
+        # compile is device_number times the batch dimension of tensors for run. Now only support LoopSink.
+        if _need_to_full():
             device_num = _get_device_num()
             self.dataset_shapes = _to_full_shapes(self.dataset_shapes, device_num)

@@ -146,10 +145,8 @@ class _DatasetIterGE(_DatasetIter):
     def __init__(self, dataset):
         super(_DatasetIterGE, self).__init__(dataset)
         self.loop_count = self.get_loop_count(dataset)
-        parallel_mode = _get_parallel_mode()
-        self.need_to_full = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
         batch_expand_num = 1
-        if self.need_to_full:
+        if _need_to_full():
             batch_expand_num = _get_device_num()
         tensor_list_run = _construct_tensor_list(self.dataset_types, self.dataset_shapes, batch_expand_num)

@@ -170,9 +167,6 @@ class _DatasetIterFeed:
         self.loop_count = dataset.get_dataset_size()
         self.ind = 0
-        parallel_mode = context.get_auto_parallel_context("parallel_mode")
-        self.need_to_full = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
-
     def __iter__(self):
         if self.repeat_ind % self.repeat_count == 0:
             self.iter = self.dataset.__iter__()

@@ -186,6 +180,6 @@ class _DatasetIterFeed:
             raise StopIteration()
         self.ind += 1
         data = self.iter.__next__()
-        if self.need_to_full:
+        if _need_to_full():
             return _to_full_tensor(data, self.device_num, self.global_rank)
         return _to_tensor(data)
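The comment in `_DatasetIterMSLoopSink` describes the conversion that `_need_to_full` now gates: compile against the global batch, run on per-device slices. A hedged sketch of the batch-dimension scaling that `_to_full_shapes` performs (the real helper is only imported here; this illustrates the behaviour described by the comment, not its actual implementation):

    def to_full_shapes_sketch(dataset_shapes, device_num):
        # Scale dim 0 (the batch dimension) of every dataset shape by device_num,
        # so the graph is compiled against the global batch size.
        return [[s[0] * device_num] + list(s[1:]) for s in dataset_shapes]

    # e.g. per-device shapes [[32, 39], [32, 1]] on 8 devices compile as
    # to_full_shapes_sketch([[32, 39], [32, 1]], 8) -> [[256, 39], [256, 1]]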
model_zoo/wide_and_deep/src/config.py

@@ -22,6 +22,7 @@ def argparse_init():
     parser = argparse.ArgumentParser(description='WideDeep')
     parser.add_argument("--data_path", type=str, default="./test_raw_data/")
     parser.add_argument("--epochs", type=int, default=15)
+    parser.add_argument("--full_batch", type=bool, default=False)
     parser.add_argument("--batch_size", type=int, default=16000)
     parser.add_argument("--eval_batch_size", type=int, default=16000)
     parser.add_argument("--field_size", type=int, default=39)

@@ -44,6 +45,7 @@ class WideDeepConfig():
     """
     def __init__(self):
         self.data_path = "./test_raw_data/"
+        self.full_batch = False
         self.epochs = 15
         self.batch_size = 16000
         self.eval_batch_size = 16000

@@ -72,6 +74,7 @@ class WideDeepConfig():
         args, _ = parser.parse_known_args()
         self.data_path = args.data_path
         self.epochs = args.epochs
+        self.full_batch = args.full_batch
         self.batch_size = args.batch_size
         self.eval_batch_size = args.eval_batch_size
         self.field_size = args.field_size
model_zoo/wide_and_deep/src/metrics.py

@@ -17,8 +17,10 @@
 Area under curve metric
 """
-from mindspore.nn.metrics import Metric
 from sklearn.metrics import roc_auc_score
+from mindspore import context
+from mindspore.nn.metrics import Metric
+from mindspore.communication.management import get_rank, get_group_size

 class AUCMetric(Metric):
     """

@@ -28,6 +30,7 @@ class AUCMetric(Metric):
     def __init__(self):
         super(AUCMetric, self).__init__()
         self.clear()
+        self.full_batch = context.get_auto_parallel_context("full_batch")

     def clear(self):
         """Clear the internal evaluation result."""

@@ -35,10 +38,17 @@ class AUCMetric(Metric):
         self.pred_probs = []

     def update(self, *inputs):
-        # inputs
-        all_predict = inputs[1].asnumpy()  # predict
-        all_label = inputs[2].asnumpy()  # label
-        self.true_labels.extend(all_label.flatten().tolist())
-        self.pred_probs.extend(all_predict.flatten().tolist())
+        """Update list of predicts and labels."""
+        all_predict = inputs[1].asnumpy().flatten().tolist()  # predict
+        all_label = inputs[2].asnumpy().flatten().tolist()  # label
+        self.pred_probs.extend(all_predict)
+        if self.full_batch:
+            rank_id = get_rank()
+            group_size = get_group_size()
+            gap = len(all_label) // group_size
+            self.true_labels.extend(all_label[rank_id * gap: (rank_id + 1) * gap])
+        else:
+            self.true_labels.extend(all_label)

     def eval(self):
         if len(self.true_labels) != len(self.pred_probs):
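The slicing arithmetic in `update` keeps only the contiguous label segment that belongs to the current rank when every device received the full batch. A worked example with assumed numbers (not from the commit):

    # Assumed: 16 labels arriving on each of 4 devices under full_batch.
    all_label = list(range(16))
    rank_id, group_size = 2, 4
    gap = len(all_label) // group_size                    # 4 labels per rank
    shard = all_label[rank_id * gap:(rank_id + 1) * gap]  # rank 2 keeps [8, 9, 10, 11]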
model_zoo/wide_and_deep/train_and_test_multinpu_auto_parallel.py

@@ -17,6 +17,7 @@
 import os
 import sys
+import mindspore.dataset.engine as de
 from mindspore import Model, context
 from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor
 from mindspore.train import ParallelMode

@@ -79,10 +80,18 @@ def test_train_eval():
     batch_size = config.batch_size
     epochs = config.epochs
     print("epochs is {}".format(epochs))
-    ds_train = create_dataset(data_path, train_mode=True, epochs=epochs, batch_size=batch_size,
-                              rank_id=get_rank(), rank_size=get_group_size())
-    ds_eval = create_dataset(data_path, train_mode=False, epochs=epochs + 1, batch_size=batch_size,
-                             rank_id=get_rank(), rank_size=get_group_size())
+    if config.full_batch:
+        context.set_auto_parallel_context(full_batch=True)
+        de.config.set_seed(1)
+        ds_train = create_dataset(data_path, train_mode=True, epochs=epochs,
+                                  batch_size=batch_size * get_group_size())
+        ds_eval = create_dataset(data_path, train_mode=False, epochs=epochs + 1,
+                                 batch_size=batch_size * get_group_size())
+    else:
+        ds_train = create_dataset(data_path, train_mode=True, epochs=epochs, batch_size=batch_size,
+                                  rank_id=get_rank(), rank_size=get_group_size())
+        ds_eval = create_dataset(data_path, train_mode=False, epochs=epochs + 1, batch_size=batch_size,
+                                 rank_id=get_rank(), rank_size=get_group_size())
     print("ds_train.size: {}".format(ds_train.get_dataset_size()))
     print("ds_eval.size: {}".format(ds_eval.get_dataset_size()))
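In the `full_batch` branch the script drops the per-rank `rank_id`/`rank_size` sharding arguments and instead multiplies `batch_size` by `get_group_size()`, so every device reads the identical full-sized batch; `de.config.set_seed(1)` pins the dataset seed, presumably so that all ranks see the same sample order. Illustrative arithmetic (assumed: the default `batch_size=16000` and 8 devices):

    # full_batch:  16000 * 8 = 128000 rows per step, identical on every device
    # otherwise:   16000 rows per step, a distinct shard per device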
tests/ut/python/parallel/test_full_batch.py (new file, mode 100644)

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np

import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import context
from mindspore.common.parameter import Parameter
from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
from mindspore.nn.optim.momentum import Momentum
from mindspore.ops import operations as P
from mindspore.parallel._utils import _reset_op_id
from mindspore.train import Model, ParallelMode
from tests.dataset_mock import MindData


class Dataset(MindData):
    def __init__(self, predict, label, length=3):
        super(Dataset, self).__init__(size=length)
        self.predict = predict
        self.label = label
        self.index = 0
        self.length = length

    def __iter__(self):
        return self

    def __next__(self):
        if self.index >= self.length:
            raise StopIteration
        self.index += 1
        return self.predict, self.label

    def reset(self):
        self.index = 0


class AllToAllNet(nn.Cell):
    def __init__(self, strategy1):
        super(AllToAllNet, self).__init__()
        self.matmul = P.MatMul().set_strategy(((1, 1), (1, 8)))
        self.matmul_weight = Parameter(Tensor(np.ones([128, 256]), dtype=ms.float32), name="weight")
        self.transpose1 = P.Transpose().set_strategy(strategy1)

    def construct(self, x):
        x = self.matmul(x, self.matmul_weight)
        x = self.transpose1(x, (1, 0))
        return x


def all_to_all_net(strategy1):
    return AllToAllNet(strategy1=strategy1)


def all_to_all_common(strategy1):
    learning_rate = 0.1
    momentum = 0.9
    epoch_size = 2

    context.set_context(mode=context.GRAPH_MODE, save_graphs=False)
    context.reset_auto_parallel_context()
    context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL, device_num=8,
                                      full_batch=True)
    predict = Tensor(np.ones([256, 128]), dtype=ms.float32)
    label = Tensor(np.ones([256]), dtype=ms.int32)
    dataset = Dataset(predict, label, 2)
    net = all_to_all_net(strategy1)

    loss = SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
    loss.softmax_cross_entropy.set_strategy(((8, 1), (8, 1)))
    loss.one_hot.set_strategy(((8, 1), (), ()))
    opt = Momentum(net.trainable_params(), learning_rate, momentum)
    model = Model(net, loss, opt)

    model.train(epoch_size, dataset, dataset_sink_mode=False)


def test_all_to_all():
    strategy1 = ((8, 1),)
    _reset_op_id()
    all_to_all_common(strategy1)