magicwindyyd / mindspore (fork of MindSpore / mindspore)
Commit bfc3065f

Authored on Jul 14, 2020 by mindspore-ci-bot; committed via Gitee on Jul 14, 2020.

!3025 [AutoParallel]Add embedding look up op

Merge pull request !3025 from lichen/add_embedding_look_up_op

Parents: e6ff8dc5, cde5cc2b
Showing 9 changed files with 115 additions and 110 deletions (+115 −110).
mindspore/ccsrc/parallel/dynamic_creator.h               +1   −0
mindspore/ccsrc/parallel/ops_info/gather_v2_p_info.cc    +23  −58
mindspore/ccsrc/parallel/ops_info/gather_v2_p_info.h     +9   −4
mindspore/ccsrc/parallel/ops_info/ops_utils.h            +1   −0
mindspore/ccsrc/parallel/step_parallel.cc                +1   −1
mindspore/nn/layer/embedding.py                          +46  −0
model_zoo/wide_and_deep/src/wide_and_deep.py             +3   −3
tests/ut/python/parallel/test_embeddinglookup.py         +31  −3
tests/ut/python/parallel/test_gather_v2.py               +0   −41
mindspore/ccsrc/parallel/dynamic_creator.h

```diff
@@ -132,6 +132,7 @@ REGISTER(SqueezeInfo);
 REGISTER(SigmoidCrossEntropyWithLogitsInfo);
 REGISTER(SquareInfo);
 REGISTER(GatherV2PInfo);
+REGISTER(EmbeddingLookupInfo);
 
 }  // namespace parallel
 }  // namespace mindspore
```
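The added REGISTER(EmbeddingLookupInfo) line is what lets the parallel planner construct the new operator's strategy object from its primitive name at runtime. As a rough illustration of that self-registration pattern only (a Python stand-in, not MindSpore's actual C++ macro; all names below are hypothetical):

```python
# A minimal sketch of the dynamic-creator pattern that REGISTER(...) implements.
_CREATORS = {}

def register(cls):
    """Map the class name to a factory, as REGISTER(XxxInfo) does in C++."""
    _CREATORS[cls.__name__] = cls
    return cls

@register
class GatherV2PInfo:
    def __init__(self, name, inputs_shape, outputs_shape, attrs):
        self.name = name

@register
class EmbeddingLookupInfo(GatherV2PInfo):
    pass  # inherits all behaviour from GatherV2PInfo, like the C++ subclass

def create(op_type, *args):
    return _CREATORS[op_type](*args)

info = create("EmbeddingLookupInfo", "embedding_lookup", [], [], {})
```

Registering a subclass rather than special-casing the primitive keeps EmbeddingLookup's parallel handling identical to GatherV2P's except where explicitly overridden.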
mindspore/ccsrc/parallel/ops_info/gather_v2_p_info.cc

```diff
@@ -28,24 +28,25 @@
 namespace mindspore {
 namespace parallel {
 Status GatherV2PInfo::GetAttrs() {
-  // get axis, the third input is the axis, is a ValueNode
-  if (input_value_.at(2) == nullptr) {
-    MS_LOG(ERROR) << name_ << ": the third input value is nullptr, is not a ValueNode!";
-    return FAILED;
-  }
-  auto axis = GetValue<int>(input_value_.at(2));
-  // if axis is negative then convert it to positive
-  auto params_shape = inputs_shape_.at(0);
-  if (params_shape.size() == 0) {
-    MS_LOG(ERROR) << name_ << ": params can not be a scalar!";
-    return FAILED;
-  }
-  if (axis < 0) {
-    axis += SizeToInt(inputs_shape_[0].size());
-  }
-  axis_ = axis;
+  // get axis, the third input is the axis, is a ValueNode, embeddinglookup doesn't have axis.
+  if (target_ != CPU) {
+    if (input_value_.at(2) == nullptr) {
+      MS_LOG(ERROR) << name_ << ": the third input value is nullptr, is not a ValueNode!";
+      return FAILED;
+    }
+    auto axis = GetValue<int>(input_value_.at(2));
+    // if axis is negative then convert it to positive
+    auto params_shape = inputs_shape_.at(0);
+    if (params_shape.size() == 0) {
+      MS_LOG(ERROR) << name_ << ": params can not be a scalar!";
+      return FAILED;
+    }
+    if (axis < 0) {
+      axis += SizeToInt(inputs_shape_[0].size());
+    }
+    axis_ = axis;
+  }
 
   // get target
   auto target_iter = attrs_.find(TARGET);
   if (target_iter != attrs_.end()) {
     MS_EXCEPTION_IF_NULL(target_iter->second);
@@ -53,16 +54,8 @@ Status GatherV2PInfo::GetAttrs() {
       target_ = target_iter->second->cast<StringImmPtr>()->value();
     } else {
      MS_LOG(ERROR) << name_ << " : The value of target is not a string.";
      return FAILED;
    }
  }
-
-  // target=CPU, axis must be 0
-  if (target_ == "CPU" && axis_ != 0) {
-    MS_LOG(ERROR) << name_ << ": target is CPU, axis must be 0, but got " << axis_;
-    return FAILED;
-  }
 
   auto manual_split_iter = attrs_.find("manual_split");
   if (manual_split_iter != attrs_.end()) {
     param_split_shapes_.clear();
@@ -459,38 +452,13 @@ Status GatherV2PInfo::InferForwardCommunication() {
     MS_LOG(ERROR) << name_ << ": Infer Group failed.";
     return FAILED;
   }
-  auto group_size = group_.GetDevNum();
   Attr attr_group;
-  if (host_reduce_scatter_) {
-    // group size <= 8
-    std::vector<int32_t> rank_list;
-    if (group_size <= 8) {
-      reduce_scatter_flag_ = false;
-      operator_name = HOST_REDUCE_SCATTER;
-      rank_list = GetRankFromGroup(group_);
-      attr_group = std::make_pair(GROUP, MakeValue(rank_list));
-    } else {
-      // group size > 8, don't support host reduce_scatter
-      reduce_scatter_flag_ = true;
-      split_num_ = SizeToInt(group_size / 8);
-      CheckGlobalDeviceManager();
-      operator_name = REDUCE_SCATTER;
-      int32_t rank = g_device_manager->global_rank();
-      size_t repeat = group_size / 8;
-      for (size_t i = 0; i < repeat; ++i) {
-        rank_list.push_back(rank + SizeToInt(i * 8));
-      }
-      Group g = g_device_manager->CreateGroup(rank_list);
-      attr_group = std::make_pair(GROUP, MakeValue(g.name()));
-    }
-  } else {
-    operator_name = REDUCE_SCATTER;
-    if (InferGroup() != SUCCESS) {
-      MS_LOG(ERROR) << name_ << ": Infer Group failed.";
-      return FAILED;
-    }
-    attr_group = std::make_pair(GROUP, MakeValue(group_.name()));
-  }
+  operator_name = REDUCE_SCATTER;
+  if (InferGroup() != SUCCESS) {
+    MS_LOG(ERROR) << name_ << ": Infer Group failed.";
+    return FAILED;
+  }
+  attr_group = std::make_pair(GROUP, MakeValue(group_.name()));
   Attr attr_op = std::make_pair(OP, MakeValue(REDUCE_OP_SUM));
   OperatorAttrs attrs = {attr_op, attr_group};
   OperatorParams params;
@@ -582,10 +550,7 @@ Status GatherV2PInfo::ComputeReplaceOp() {
   OperatorName op_name = EMBEDDING_LOOKUP;
   OperatorAttrs attrs;
   Attr param_offset = std::make_pair("offset", MakeValue(bias_));
-  Attr param_flag = std::make_pair("reduce_scatter_flag", MakeValue(reduce_scatter_flag_));
-  Attr param_split_num = std::make_pair("split_num", MakeValue(split_num_));
-  OperatorParams params = {std::make_pair(param_offset, 3), std::make_pair(param_flag, 4),
-                           std::make_pair(param_split_num, 5)};
+  OperatorParams params = {std::make_pair(param_offset, 3)};
   OperatorArgs args = std::make_pair(attrs, params);
   Operator op = std::make_pair(op_name, args);
   replace_op_.push_back(op);
```
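The deleted host reduce_scatter branch grouped devices in strides of 8 whenever the communication group spanned more than one 8-device host. For reference, a small Python sketch of the rank-list computation the removed C++ performed (names mirror the old code; this is an illustration, not a MindSpore API):

```python
def host_reduce_scatter_rank_list(rank, group_size):
    """Reproduce the removed rank-list logic: one entry per 8-device host.

    Mirrors the old loop:
      for (size_t i = 0; i < repeat; ++i) rank_list.push_back(rank + i * 8);
    """
    repeat = group_size // 8  # old code: size_t repeat = group_size / 8;
    return [rank + i * 8 for i in range(repeat)]

assert host_reduce_scatter_rank_list(3, 16) == [3, 11]
assert host_reduce_scatter_rank_list(0, 32) == [0, 8, 16, 24]
```

After this commit the forward communication path always emits a plain REDUCE_SCATTER over the inferred group, and ComputeReplaceOp passes only the offset to the replacement EmbeddingLookup.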
mindspore/ccsrc/parallel/ops_info/gather_v2_p_info.h

```diff
@@ -65,16 +65,13 @@ class GatherV2PInfo : public OperatorInfo {
   Status InferGroup();
 
   int32_t axis_;
-  std::string target_;
+  std::string target_ = DEVICE;
   std::string replace_op_name_ = GATHERV2;
   int32_t bias_;
   int32_t index_offset_;
   int32_t slice_size_;
   Shape out_dev_matrix_shape_;
   Group group_;
-  bool reduce_scatter_flag_ = false;
-  int32_t split_num_ = 1;
-  bool host_reduce_scatter_ = false;
   bool manual_split_ = false;
   std::vector<int32_t> param_split_shapes_;
   std::vector<int32_t> index_offsets_;
@@ -90,6 +87,14 @@ class SparseGatherV2Info : public GatherV2PInfo {
  private:
   std::string replace_op_name_ = SPARSE_GATHERV2;
 };
+
+class EmbeddingLookupInfo : public GatherV2PInfo {
+ public:
+  EmbeddingLookupInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
+                      const PrimitiveAttrs &attrs)
+      : GatherV2PInfo(name, inputs_shape, outputs_shape, attrs) {}
+  ~EmbeddingLookupInfo() override = default;
+};
 }  // namespace parallel
 }  // namespace mindspore
 #endif  // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_GATHER_V2_P_INFO_H_
```
mindspore/ccsrc/parallel/ops_info/ops_utils.h

```diff
@@ -132,6 +132,7 @@ constexpr char REDISTRIBUTION_OP[] = "redistribution_op";
 constexpr char DARA_PARALLEL[] = "data_parallel";
 constexpr char FORWARD_REDUCE_SCATTER[] = "forward_reduce_scatter";
 constexpr char OPTIMIZER_SUB_STRING[] = "optimizer";
+constexpr char DEVICE[] = "Device";
 
 // Operator
 constexpr char VIRTUAL_DIV[] = "_VirtualDiv";
```
mindspore/ccsrc/parallel/step_parallel.cc

```diff
@@ -536,7 +536,7 @@ std::vector<AnfNodePtr> ReplaceOpInput(const Operator &replace_op, const std::st
   }
   std::vector<AnfNodePtr> replace_input = {NewValueNode(pyop_instance), node->input(1)};
   auto prim = GetValueNode<PrimitivePtr>(node->input(0));
-  if (prim->name() == GATHERV2 || prim->name() == SPARSE_GATHERV2) {
+  if (prim->name() == EMBEDDING_LOOKUP) {
     replace_input = {NewValueNode(pyop_instance), node->input(1), node->input(2)};
   }
   if (!params.empty()) {
```
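ReplaceOpInput now forwards the second data operand (the indices) only when the replacement primitive is EMBEDDING_LOOKUP, since the inserted op consumes both the parameter slice and the indices. A hedged Python paraphrase of that selection logic (names mirror the C++; not a real MindSpore API):

```python
def replace_op_input(pyop_instance, node_inputs, prim_name):
    """Paraphrase of ReplaceOpInput's input selection after this commit.

    node_inputs[0] is the primitive value node; [1] and [2] are data inputs.
    """
    replace_input = [pyop_instance, node_inputs[1]]
    if prim_name == "EmbeddingLookup":  # was: GatherV2 or SparseGatherV2
        replace_input = [pyop_instance, node_inputs[1], node_inputs[2]]
    return replace_input

assert len(replace_op_input("op", ["prim", "params", "indices"], "EmbeddingLookup")) == 3
assert len(replace_op_input("op", ["prim", "params", "indices"], "GatherV2")) == 2
```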
mindspore/nn/layer/embedding.py

```diff
@@ -105,3 +105,49 @@ class Embedding(Cell):
                                       self.embedding_table,
                                       self.dtype)
         return s
+
+
+class EmbeddingLookup(Cell):
+    r"""
+    Returns a slice of input tensor based on the specified indices.
+
+    Note:
+        When 'target' is set to 'CPU', this module will use
+        P.EmbeddingLookup().add_prim_attr('primitive_target', 'CPU'), which
+        specifies 'offset = 0' for the lookup table.
+        When 'target' is set to 'DEVICE', this module will use P.GatherV2(), which
+        specifies 'axis = 0' for the lookup table.
+
+    Args:
+        target (str): Specifies the target where the op is executed. Default: 'CPU'.
+
+    Inputs:
+        - **input_params** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
+          The Tensor slice, instead of the entire Tensor.
+        - **input_indices** (Tensor) - The shape of tensor is :math:`(y_1, y_2, ..., y_S)`.
+          Specifies the indices of elements of the original Tensor. Values can be out of range of
+          `input_params`, and the exceeding part will be filled with 0 in the output.
+
+    Outputs:
+        Tensor, the shape of tensor is :math:`(z_1, z_2, ..., z_N)`.
+
+    Examples:
+        >>> input_params = Tensor(np.array([[8, 9], [10, 11], [12, 13], [14, 15]]), mindspore.float32)
+        >>> input_indices = Tensor(np.array([[1, 0], [3, 2]]), mindspore.int32)
+        >>> out = nn.EmbeddingLookup()(input_params, input_indices)
+        [[[10, 11], [8, 9]], [[14, 15], [12, 13]]]
+    """
+    def __init__(self, target='CPU'):
+        super(EmbeddingLookup, self).__init__()
+        self.target = target
+        if target not in ('CPU', 'DEVICE'):
+            raise ValueError('Attr \'target\' of \'EmbeddingLookup\' Op passed '
+                             + str(target) + ', should be one of values in \'CPU\', \'DEVICE\'.')
+        self.gatherv2 = P.GatherV2()
+        self.embeddinglookup = P.EmbeddingLookup().add_prim_attr('primitive_target', 'CPU')
+
+    def construct(self, params, indices):
+        if self.target == "CPU":
+            out = self.embeddinglookup(params, indices, 0)
+        else:
+            out = self.gatherv2(params, indices, 0)
+        return out
```
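A minimal usage sketch of the new cell, assuming the standard MindSpore imports: the CPU target routes through the EmbeddingLookup primitive with offset 0, while DEVICE falls back to an axis-0 GatherV2.

```python
import numpy as np
import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor

# Table of four 2-dim embeddings; rows are looked up by integer index.
params = Tensor(np.array([[8, 9], [10, 11], [12, 13], [14, 15]]), ms.float32)
indices = Tensor(np.array([[1, 0], [3, 2]]), ms.int32)

cpu_lookup = nn.EmbeddingLookup(target='CPU')        # P.EmbeddingLookup, offset 0
device_lookup = nn.EmbeddingLookup(target='DEVICE')  # P.GatherV2, axis 0

out = cpu_lookup(params, indices)  # [[[10, 11], [8, 9]], [[14, 15], [12, 13]]]
```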
model_zoo/wide_and_deep/src/wide_and_deep.py

```diff
@@ -188,7 +188,7 @@ class WideDeepModel(nn.Cell):
                                            self.deep_layer_act,
                                            use_activation=False, convert_dtype=True,
                                            drop_out=config.dropout_flag)
-        self.gather_v2 = P.GatherV2()
+        self.embeddinglookup = nn.EmbeddingLookup()
         self.mul = P.Mul()
         self.reduce_sum = P.ReduceSum(keep_dims=False)
         self.reshape = P.Reshape()
@@ -206,11 +206,11 @@ class WideDeepModel(nn.Cell):
         """
         mask = self.reshape(wt_hldr, (self.batch_size, self.field_size, 1))
         # Wide layer
-        wide_id_weight = self.gather_v2(self.wide_w, id_hldr, 0)
+        wide_id_weight = self.embeddinglookup(self.wide_w, id_hldr)
         wx = self.mul(wide_id_weight, mask)
         wide_out = self.reshape(self.reduce_sum(wx, 1) + self.wide_b, (-1, 1))
         # Deep layer
-        deep_id_embs = self.gather_v2(self.embedding_table, id_hldr, 0)
+        deep_id_embs = self.embeddinglookup(self.embedding_table, id_hldr)
         vx = self.mul(deep_id_embs, mask)
         deep_in = self.reshape(vx, (-1, self.field_size * self.emb_dim))
         deep_in = self.dense_layer_1(deep_in)
```
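The swap is behavior-preserving for this model: with offset 0, an embedding lookup is exactly an axis-0 gather. A quick NumPy check of that equivalence (NumPy stands in for the MindSpore ops here):

```python
import numpy as np

table = np.arange(12, dtype=np.float32).reshape(4, 3)  # 4 embeddings of width 3
ids = np.array([[0, 2], [3, 1]])

gathered = np.take(table, ids, axis=0)   # what P.GatherV2 computes with axis=0
looked_up = table[ids]                   # embedding lookup with offset 0

assert np.array_equal(gathered, looked_up)
```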
tests/ut/python/parallel/test_embeddinglookup.py

```diff
@@ -41,12 +41,12 @@ class NetWithLoss(nn.Cell):
         return self.loss(predict)
 
 class Net(nn.Cell):
-    def __init__(self, shape, offset):
+    def __init__(self, shape, offset, strategy1=None, strategy2=None, target="Device"):
         super().__init__()
         self.index = Tensor(np.ones(shape), dtype=ms.int32)
         self.offset = offset
-        self.elu = P.EmbeddingLookup()
-        self.mm = P.BatchMatMul()
+        self.elu = P.EmbeddingLookup().set_strategy(strategy1).add_prim_attr("primitive_target", target)
+        self.mm = P.BatchMatMul().set_strategy(strategy2)
 
     def construct(self, x, y):
         out = self.elu(x, self.index, self.offset)
@@ -97,3 +97,31 @@ def test_embeddinglookup_reducescatter_true_grad():
     x = Tensor(np.ones([64, 32]), dtype=ms.float32)
     y = Tensor(np.ones([8, 32, 8]), dtype=ms.float32)
     _executor.compile(net, x, y)
+
+
+def test_embeddinglookup_semi_auto1():
+    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
+    shape = [64, 32]
+    offset = 0
+    strategy1 = ((8, 1), (1, 1))
+    strategy2 = ((4, 1, 2), (4, 2, 1))
+    net = GradWrap(NetWithLoss(Net(shape, offset, strategy1, strategy2, "CPU")))
+
+    net.set_auto_parallel()
+    x = Tensor(np.ones([64, 64]), dtype=ms.float32)
+    y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)
+    _executor.compile(net, x, y)
+
+
+def test_embeddinglookup_semi_auto2():
+    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
+    shape = [64, 32]
+    offset = 0
+    strategy1 = ((1, 8), (1, 1))
+    strategy2 = ((4, 1, 2), (4, 2, 1))
+    net = GradWrap(NetWithLoss(Net(shape, offset, strategy1, strategy2, "CPU")))
+
+    net.set_auto_parallel()
+    x = Tensor(np.ones([64, 64]), dtype=ms.float32)
+    y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)
+    _executor.compile(net, x, y)
```
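In these tests, strategy1 = ((8, 1), (1, 1)) shards the first dimension of the lookup's parameter across the 8 devices while replicating the indices. A hedged helper showing how such a strategy tuple maps to per-device shard shapes (illustrative only, not a MindSpore API):

```python
def shard_shapes(input_shapes, strategy):
    """Each strategy entry gives the split factor per dimension of one input."""
    return [
        [dim // cut for dim, cut in zip(shape, cuts)]
        for shape, cuts in zip(input_shapes, strategy)
    ]

# strategy1 from test_embeddinglookup_semi_auto1: shard the parameter rows 8 ways.
assert shard_shapes([[64, 32], [64, 32]], ((8, 1), (1, 1))) == [[8, 32], [64, 32]]
```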
tests/ut/python/parallel/test_gather_v2.py

```diff
@@ -13,8 +13,6 @@
 # limitations under the License.
 # ============================================================================
 import numpy as np
-import pytest
 import mindspore as ms
 import mindspore.nn as nn
 from mindspore import Tensor
@@ -183,42 +181,3 @@ def test_gatherv2_auto1():
     x = Tensor(np.ones([64, 32]), dtype=ms.float32)
     y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)
     _executor.compile(net, x, y)
-
-
-@pytest.mark.skip(reason="The transition from GatherV2 to EmbeddingLookup needs adjusting. by lichen")
-def test_gatherv2_cpu0():
-    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
-    strategy1 = ((8, 1), (1, 1))
-    strategy2 = ((4, 2, 1), (4, 2, 1))
-    net = NetWithLoss(Net(0, strategy1, strategy2, None, "CPU"))
-    net.set_auto_parallel()
-
-    x = Tensor(np.ones([64, 64]), dtype=ms.float32)
-    y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)
-    _executor.compile(net, x, y)
-
-
-@pytest.mark.skip(reason="The transition from GatherV2 to EmbeddingLookup needs adjusting. by lichen")
-def test_gatherv2_cpu1():
-    context.set_auto_parallel_context(device_num=16, global_rank=0, parallel_mode="semi_auto_parallel")
-    strategy1 = ((16, 1), (1, 1))
-    strategy2 = ((4, 2, 1), (4, 2, 1))
-    net = NetWithLoss(Net(0, strategy1, strategy2, None, "CPU"))
-    net.set_auto_parallel()
-
-    x = Tensor(np.ones([64, 64]), dtype=ms.float32)
-    y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)
-    _executor.compile(net, x, y)
-
-
-@pytest.mark.skip(reason="The transition from GatherV2 to EmbeddingLookup needs adjusting. by lichen")
-def test_gatherv2_cpu2():
-    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
-    strategy1 = ((1, 8), (1, 1))
-    strategy2 = ((4, 2, 1), (4, 2, 1))
-    net = NetWithLoss(Net(0, strategy1, strategy2, None, "CPU"))
-    net.set_auto_parallel()
-
-    x = Tensor(np.ones([64, 64]), dtype=ms.float32)
-    y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)
-    _executor.compile(net, x, y)
```