Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
c8cdb6b3
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
c8cdb6b3
编写于
4月 07, 2020
作者:
C
c00425699
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
support distributed GatherV2 operator
上级
d949c17a
变更
11
隐藏空白更改
内联
并排
Showing
11 changed file
with
597 addition
and
47 deletion
+597
-47
mindspore/ccsrc/parallel/auto_parallel/operator_costmodel.cc
mindspore/ccsrc/parallel/auto_parallel/operator_costmodel.cc
+29
-0
mindspore/ccsrc/parallel/auto_parallel/operator_costmodel.h
mindspore/ccsrc/parallel/auto_parallel/operator_costmodel.h
+27
-0
mindspore/ccsrc/parallel/ops_info/batch_parallel_info.cc
mindspore/ccsrc/parallel/ops_info/batch_parallel_info.cc
+0
-20
mindspore/ccsrc/parallel/ops_info/batch_parallel_info.h
mindspore/ccsrc/parallel/ops_info/batch_parallel_info.h
+0
-9
mindspore/ccsrc/parallel/ops_info/gather_v2_info.cc
mindspore/ccsrc/parallel/ops_info/gather_v2_info.cc
+350
-0
mindspore/ccsrc/parallel/ops_info/gather_v2_info.h
mindspore/ccsrc/parallel/ops_info/gather_v2_info.h
+73
-0
mindspore/ccsrc/parallel/ops_info/operator_info.cc
mindspore/ccsrc/parallel/ops_info/operator_info.cc
+1
-0
mindspore/ccsrc/parallel/ops_info/operator_info.h
mindspore/ccsrc/parallel/ops_info/operator_info.h
+3
-0
mindspore/ccsrc/parallel/ops_info/ops_info_head_files.h
mindspore/ccsrc/parallel/ops_info/ops_info_head_files.h
+1
-0
mindspore/ccsrc/parallel/step_parallel.cc
mindspore/ccsrc/parallel/step_parallel.cc
+8
-0
tests/ut/python/parallel/test_gather_v2_primitive.py
tests/ut/python/parallel/test_gather_v2_primitive.py
+105
-18
未找到文件。
mindspore/ccsrc/parallel/auto_parallel/operator_costmodel.cc
浏览文件 @
c8cdb6b3
...
@@ -623,5 +623,34 @@ double DropOutCost::GetForwardComputationCost(const std::vector<TensorInfo>& inp
...
@@ -623,5 +623,34 @@ double DropOutCost::GetForwardComputationCost(const std::vector<TensorInfo>& inp
Shape
input0_slice_shape
=
input0
.
slice_shape
();
Shape
input0_slice_shape
=
input0
.
slice_shape
();
return
ListProduct
(
input0_slice_shape
)
*
static_cast
<
double
>
(
inputs_type_lengths_
[
0
])
*
DROPOUT_COST_RATE
;
return
ListProduct
(
input0_slice_shape
)
*
static_cast
<
double
>
(
inputs_type_lengths_
[
0
])
*
DROPOUT_COST_RATE
;
}
}
// return the per device communication cost in the forward phase.
double
GatherV2Cost
::
GetForwardCommCost
(
const
std
::
vector
<
TensorInfo
>&
,
const
std
::
vector
<
TensorInfo
>&
,
const
int32_t
&
)
const
{
// GatherV2Cost does not need communication in the forward phase
return
0.0
;
}
// return the per device communication cost in the backward phase.
double
GatherV2Cost
::
GetBackwardCommCost
(
const
std
::
vector
<
TensorInfo
>&
,
const
std
::
vector
<
TensorInfo
>&
,
const
int32_t
&
)
const
{
// GatherV2Cost does not need communication in the backward phase
return
0.0
;
}
double
GatherV2Cost
::
GetForwardComputationCost
(
const
std
::
vector
<
TensorInfo
>&
inputs
,
const
std
::
vector
<
TensorInfo
>&
,
const
int32_t
&
)
const
{
// In forward phase, the computation cost = slice(A) + slice(B)
Shape
input0_slice_shape
=
inputs
[
0
].
slice_shape
();
Shape
input1_slice_shape
=
inputs
[
1
].
slice_shape
();
double
result
=
ListProduct
(
input0_slice_shape
)
*
static_cast
<
double
>
(
inputs_type_lengths_
[
0
])
+
ListProduct
(
input1_slice_shape
)
*
static_cast
<
double
>
(
inputs_type_lengths_
[
1
]);
return
result
;
}
double
GatherV2Cost
::
GetBackwardComputationCost
(
const
std
::
vector
<
TensorInfo
>&
,
const
std
::
vector
<
TensorInfo
>&
,
const
int32_t
&
)
const
{
return
0.0
;
}
}
// namespace parallel
}
// namespace parallel
}
// namespace mindspore
}
// namespace mindspore
mindspore/ccsrc/parallel/auto_parallel/operator_costmodel.h
浏览文件 @
c8cdb6b3
...
@@ -81,6 +81,8 @@ class OperatorCost {
...
@@ -81,6 +81,8 @@ class OperatorCost {
std
::
vector
<
size_t
>
outputs_type_lengths_
;
std
::
vector
<
size_t
>
outputs_type_lengths_
;
};
};
using
OperatorCostPtr
=
std
::
shared_ptr
<
OperatorCost
>
;
class
MatMulCost
:
public
OperatorCost
{
class
MatMulCost
:
public
OperatorCost
{
public:
public:
MatMulCost
()
=
default
;
MatMulCost
()
=
default
;
...
@@ -525,6 +527,31 @@ class DropOutCost : public OperatorCost {
...
@@ -525,6 +527,31 @@ class DropOutCost : public OperatorCost {
};
};
using
DropOutCostPtr
=
std
::
shared_ptr
<
DropOutCost
>
;
using
DropOutCostPtr
=
std
::
shared_ptr
<
DropOutCost
>
;
class
GatherV2Cost
:
public
OperatorCost
{
public:
GatherV2Cost
()
=
default
;
~
GatherV2Cost
()
override
=
default
;
double
GetCommCost
(
const
std
::
vector
<
TensorInfo
>&
inputs
,
const
std
::
vector
<
TensorInfo
>&
outputs
,
const
int32_t
&
stage_id
)
const
override
{
return
GetForwardCommCost
(
inputs
,
outputs
,
stage_id
)
+
GetBackwardCommCost
(
inputs
,
outputs
,
stage_id
);
}
double
GetForwardCommCost
(
const
std
::
vector
<
TensorInfo
>&
inputs
,
const
std
::
vector
<
TensorInfo
>&
outputs
,
const
int32_t
&
stage_id
)
const
override
;
double
GetBackwardCommCost
(
const
std
::
vector
<
TensorInfo
>&
inputs
,
const
std
::
vector
<
TensorInfo
>&
outputs
,
const
int32_t
&
stage_id
)
const
override
;
double
GetComputationCost
(
const
std
::
vector
<
TensorInfo
>&
inputs
,
const
std
::
vector
<
TensorInfo
>&
outputs
,
const
int32_t
&
stage_id
)
const
override
{
return
GetForwardComputationCost
(
inputs
,
outputs
,
stage_id
)
+
GetBackwardComputationCost
(
inputs
,
outputs
,
stage_id
);
}
double
GetForwardComputationCost
(
const
std
::
vector
<
TensorInfo
>&
inputs
,
const
std
::
vector
<
TensorInfo
>&
outputs
,
const
int32_t
&
stage_id
)
const
override
;
double
GetBackwardComputationCost
(
const
std
::
vector
<
TensorInfo
>&
inputs
,
const
std
::
vector
<
TensorInfo
>&
outputs
,
const
int32_t
&
)
const
override
;
};
using
GatherV2CostPtr
=
std
::
shared_ptr
<
GatherV2Cost
>
;
}
// namespace parallel
}
// namespace parallel
}
// namespace mindspore
}
// namespace mindspore
#endif // PARALLEL_AUTO_PARALLEL_OPERATOR_COSTMODEL_H_
#endif // PARALLEL_AUTO_PARALLEL_OPERATOR_COSTMODEL_H_
mindspore/ccsrc/parallel/ops_info/batch_parallel_info.cc
浏览文件 @
c8cdb6b3
...
@@ -228,26 +228,6 @@ void SparseSoftmaxCrossEntropyWithLogitsInfo::ReComputeBatchSplitFlagList() {
...
@@ -228,26 +228,6 @@ void SparseSoftmaxCrossEntropyWithLogitsInfo::ReComputeBatchSplitFlagList() {
}
}
}
}
void
GatherV2Info
::
ReComputeBatchSplitFlagList
()
{
MS_ASSERT
(
inputs_shape_
.
size
()
==
2
);
MS_ASSERT
(
input_value_
.
size
()
==
3
);
MS_ASSERT
(
input_value_
[
0
]
==
nullptr
);
// the second input is the index tensor
MS_ASSERT
(
input_value_
[
1
]
!=
nullptr
);
// the third input is the axis
MS_ASSERT
(
input_value_
[
2
]
!=
nullptr
);
int
axis
=
GetValue
<
int
>
(
input_value_
[
2
]);
MS_ASSERT
(
axis
<
inputs_shape_
[
0
].
size
()
&&
axis
>=
0
-
inputs_shape_
[
0
].
size
());
if
(
axis
<
0
)
{
axis
+=
SizeToInt
(
inputs_shape_
[
0
].
size
());
}
split_flag_list_
[
0
]
=
true
;
// if gather axis is 0, the index's strategy is equal to device number
if
(
axis
==
0
)
{
split_flag_list_
[
1
]
=
true
;
}
}
Status
BatchParallelInfo
::
InferAsLossDivisor
()
{
Status
BatchParallelInfo
::
InferAsLossDivisor
()
{
as_loss_divisor_
=
1
;
as_loss_divisor_
=
1
;
return
SUCCESS
;
return
SUCCESS
;
...
...
mindspore/ccsrc/parallel/ops_info/batch_parallel_info.h
浏览文件 @
c8cdb6b3
...
@@ -62,15 +62,6 @@ class SparseSoftmaxCrossEntropyWithLogitsInfo : public BatchParallelInfo {
...
@@ -62,15 +62,6 @@ class SparseSoftmaxCrossEntropyWithLogitsInfo : public BatchParallelInfo {
~
SparseSoftmaxCrossEntropyWithLogitsInfo
()
override
=
default
;
~
SparseSoftmaxCrossEntropyWithLogitsInfo
()
override
=
default
;
void
ReComputeBatchSplitFlagList
()
override
;
void
ReComputeBatchSplitFlagList
()
override
;
};
};
class
GatherV2Info
:
public
BatchParallelInfo
{
public:
GatherV2Info
(
const
std
::
string
&
name
,
const
Shapes
&
inputs_shape
,
const
Shapes
&
outputs_shape
,
const
PrimitiveAttrs
&
attrs
)
:
BatchParallelInfo
(
name
,
inputs_shape
,
outputs_shape
,
attrs
)
{}
~
GatherV2Info
()
override
=
default
;
void
ReComputeBatchSplitFlagList
()
override
;
};
}
// namespace parallel
}
// namespace parallel
}
// namespace mindspore
}
// namespace mindspore
...
...
mindspore/ccsrc/parallel/ops_info/gather_v2_info.cc
0 → 100644
浏览文件 @
c8cdb6b3
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "parallel/ops_info/gather_v2_info.h"
#include <memory>
#include <utility>
#include <vector>
#include "ir/meta_tensor.h"
#include "ir/value.h"
#include "parallel/auto_parallel/costmodel.h"
#include "parallel/device_matrix.h"
#include "parallel/graph_util/generate_graph.h"
#include "parallel/strategy.h"
#include "utils/log_adapter.h"
namespace
mindspore
{
namespace
parallel
{
Status
GatherV2Info
::
GetAttrs
()
{
if
(
inputs_shape_
.
size
()
!=
GATHER_V2_INPUTS_SIZE
)
{
MS_LOG
(
ERROR
)
<<
name_
<<
": inputs shape size must be 2, but is "
<<
inputs_shape_
.
size
();
return
FAILED
;
}
if
(
outputs_shape_
.
size
()
!=
GATHER_V2_OUTPUTS_SIZE
)
{
MS_LOG
(
ERROR
)
<<
name_
<<
": outputs shape size must be 1, but is "
<<
outputs_shape_
.
size
();
return
FAILED
;
}
if
(
input_value_
.
size
()
!=
GATHER_V2_INPUTS_VALUE_SIZE
)
{
MS_LOG
(
ERROR
)
<<
name_
<<
": input value size must be 3, but is "
<<
input_value_
.
size
();
return
FAILED
;
}
// the second input is the index tensor
// the third input is the axis, is a ValueNode
if
(
input_value_
.
at
(
2
)
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
name_
<<
": the third input value is nullptr, is not a ValueNode!"
;
return
FAILED
;
}
if
(
inputs_shape_
.
at
(
0
).
size
()
==
0
)
{
MS_LOG
(
ERROR
)
<<
name_
<<
": input can not be a scalar!"
;
return
FAILED
;
}
int
axis
=
GetValue
<
int
>
(
input_value_
.
at
(
2
));
if
(
axis
>=
SizeToInt
(
inputs_shape_
.
at
(
0
).
size
())
||
axis
<
0
-
SizeToInt
(
inputs_shape_
.
at
(
0
).
size
()))
{
MS_LOG
(
ERROR
)
<<
"Axis is "
<<
axis
<<
", not in [-"
<<
inputs_shape_
.
at
(
0
).
size
()
<<
", "
<<
inputs_shape_
.
at
(
0
).
size
()
<<
")."
;
}
if
(
axis
<
0
)
{
axis
+=
SizeToInt
(
inputs_shape_
[
0
].
size
());
}
axis_
=
axis
;
index_size_
=
inputs_shape_
.
at
(
1
).
size
();
return
SUCCESS
;
}
Status
GatherV2Info
::
CheckStrategy
(
const
StrategyPtr
&
strategy
)
{
if
(
inputs_shape_
.
size
()
!=
GATHER_V2_INPUTS_SIZE
)
{
MS_LOG
(
ERROR
)
<<
name_
<<
": inputs shape size must be "
<<
GATHER_V2_INPUTS_SIZE
<<
", but is "
<<
inputs_shape_
.
size
();
return
FAILED
;
}
if
(
outputs_shape_
.
size
()
!=
GATHER_V2_OUTPUTS_SIZE
)
{
MS_LOG
(
ERROR
)
<<
name_
<<
": outputs shape size must be "
<<
GATHER_V2_OUTPUTS_SIZE
<<
", but is "
<<
outputs_shape_
.
size
();
return
FAILED
;
}
// Only strategy of the first input should be set.
if
(
CheckStrategyValue
(
strategy
,
{
inputs_shape_
.
at
(
0
)},
is_auto_parallel_
)
!=
SUCCESS
)
{
if
(
is_auto_parallel_
)
{
MS_LOG
(
DEBUG
)
<<
name_
<<
": Invalid strategy."
;
}
else
{
MS_LOG
(
ERROR
)
<<
name_
<<
": Invalid strategy."
;
}
return
FAILED
;
}
axis_strategy_
=
strategy
->
GetInputDim
().
at
(
0
).
at
(
axis_
);
if
(
index_size_
!=
1
&&
axis_strategy_
!=
1
)
{
MS_LOG
(
ERROR
)
<<
name_
<<
": Invalid strategy. If the index is a scalar or a more than 1 dimension vector, the strategy "
"corresponding to axis must be 1, but is "
<<
axis_strategy_
;
return
FAILED
;
}
if
(
index_size_
==
1
&&
axis_strategy_
!=
1
&&
inputs_shape_
.
at
(
1
).
at
(
0
)
%
axis_strategy_
!=
0
)
{
MS_LOG
(
ERROR
)
<<
name_
<<
": Invalid strategy. The first dimension of index can not be divided by strategy corresponding to "
"axis. The first dimension of index is "
<<
inputs_shape_
.
at
(
1
).
at
(
0
)
<<
" strategy corresponding to axis is "
<<
axis_strategy_
;
return
FAILED
;
}
return
SUCCESS
;
}
Status
GatherV2Info
::
InferDevMatrixShape
()
{
std
::
vector
<
Dimensions
>
stra
=
strategy_
->
GetInputDim
();
dev_matrix_shape_
=
stra
.
at
(
0
);
return
SUCCESS
;
}
// If index is a scalar, output dimension is input dimension minus 1;
// If index is a n dimension tensor, output dimension is input dimension plus (n - 1).
// Tensor map dimension is equal to the corresponding input and output dimension.
// If index's dimension is more than 1, we insert -1 for the output tensor map.
Status
GatherV2Info
::
InferTensorMap
()
{
if
(
inputs_shape_
.
size
()
!=
GATHER_V2_INPUTS_SIZE
)
{
MS_LOG
(
ERROR
)
<<
name_
<<
": inputs shape size must be "
<<
GATHER_V2_INPUTS_SIZE
<<
", but is "
<<
inputs_shape_
.
size
();
return
FAILED
;
}
if
(
outputs_shape_
.
size
()
!=
GATHER_V2_OUTPUTS_SIZE
)
{
MS_LOG
(
ERROR
)
<<
name_
<<
": outputs shape size must be "
<<
GATHER_V2_OUTPUTS_SIZE
<<
", but is "
<<
outputs_shape_
.
size
();
return
FAILED
;
}
std
::
vector
<
int32_t
>
tensor_map_in
;
std
::
vector
<
int32_t
>
tensor_map_out
;
size_t
size
=
inputs_shape_
.
at
(
0
).
size
();
// such as 4: tensor_map_index [3,2,1,0]
for
(
size_t
i
=
0
;
i
<
size
;
++
i
)
{
tensor_map_in
.
push_back
(
SizeToInt
(
size
-
i
-
1
));
tensor_map_out
.
push_back
(
SizeToInt
(
size
-
i
-
1
));
}
if
(
index_size_
==
0
)
{
(
void
)
tensor_map_out
.
erase
(
tensor_map_out
.
begin
()
+
axis_
);
}
else
if
(
index_size_
>
1
)
{
(
void
)
tensor_map_out
.
insert
(
tensor_map_out
.
begin
()
+
axis_
,
index_size_
-
1
,
-
1
);
}
if
(
tensor_map_out
.
size
()
!=
outputs_shape_
.
at
(
0
).
size
())
{
MS_LOG
(
ERROR
)
<<
"Out tensor map size is not equal to output size! Out tensor map size is "
<<
tensor_map_out
.
size
()
<<
" output size is "
<<
outputs_shape_
.
at
(
0
).
size
();
return
FAILED
;
}
std
::
vector
<
int32_t
>
tensor_map_in_index
;
if
(
index_size_
>=
1
)
{
tensor_map_in_index
.
push_back
(
SizeToInt
(
size
-
axis_
-
1
));
}
for
(
size_t
i
=
1
;
i
<
index_size_
;
++
i
)
{
tensor_map_in_index
.
push_back
(
-
1
);
}
inputs_tensor_map_
.
emplace_back
(
std
::
move
(
tensor_map_in
));
inputs_tensor_map_
.
emplace_back
(
std
::
move
(
tensor_map_in_index
));
outputs_tensor_map_
.
emplace_back
(
std
::
move
(
tensor_map_out
));
return
SUCCESS
;
}
Status
GatherV2Info
::
InferTensorInfo
()
{
if
(
inputs_shape_
.
size
()
!=
GATHER_V2_INPUTS_SIZE
)
{
MS_LOG
(
ERROR
)
<<
name_
<<
": inputs shape size must be "
<<
GATHER_V2_INPUTS_SIZE
<<
", but is "
<<
inputs_shape_
.
size
();
return
FAILED
;
}
if
(
outputs_shape_
.
size
()
!=
GATHER_V2_OUTPUTS_SIZE
)
{
MS_LOG
(
ERROR
)
<<
name_
<<
": outputs shape size must be "
<<
GATHER_V2_OUTPUTS_SIZE
<<
", but is "
<<
outputs_shape_
.
size
();
return
FAILED
;
}
if
(
inputs_tensor_map_
.
size
()
!=
GATHER_V2_INPUTS_SIZE
)
{
MS_LOG
(
ERROR
)
<<
name_
<<
": inputs tensor map size must be "
<<
GATHER_V2_INPUTS_SIZE
<<
", but is "
<<
inputs_tensor_map_
.
size
();
return
FAILED
;
}
if
(
outputs_tensor_map_
.
size
()
!=
GATHER_V2_OUTPUTS_SIZE
)
{
MS_LOG
(
ERROR
)
<<
name_
<<
": outputs tensor map size must be "
<<
GATHER_V2_OUTPUTS_SIZE
<<
", but is "
<<
outputs_tensor_map_
.
size
();
return
FAILED
;
}
// infer tensor shape
Shape
input_shape
=
inputs_shape_
.
at
(
0
);
Shape
input_index_shape
=
inputs_shape_
.
at
(
1
);
Shape
output_shape
=
outputs_shape_
.
at
(
0
);
TensorLayout
input_tensor_layout
,
input_index_layout
,
output_tensor_layout
;
if
((
input_tensor_layout
.
InitFromVector
(
dev_matrix_shape_
,
inputs_tensor_map_
.
at
(
0
),
input_shape
)
!=
SUCCESS
)
||
(
input_index_layout
.
InitFromVector
(
dev_matrix_shape_
,
inputs_tensor_map_
.
at
(
1
),
input_index_shape
)
!=
SUCCESS
)
||
(
output_tensor_layout
.
InitFromVector
(
dev_matrix_shape_
,
outputs_tensor_map_
.
at
(
0
),
output_shape
)
!=
SUCCESS
))
{
return
FAILED
;
}
TensorInfo
input_tensor_info
(
input_tensor_layout
);
TensorInfo
input_index_info
(
input_index_layout
);
TensorInfo
output_tensor_info
(
output_tensor_layout
);
inputs_tensor_info_
.
push_back
(
input_tensor_info
);
inputs_tensor_info_
.
push_back
(
input_index_info
);
outputs_tensor_info_
.
push_back
(
output_tensor_info
);
return
SUCCESS
;
}
OperatorVector
CreateSubOp
(
int32_t
sub_value
)
{
OperatorVector
ops
;
OperatorName
operator_name
=
SUB
;
OperatorAttrs
operator_attrs
;
py
::
tuple
tuple
=
py
::
make_tuple
(
sub_value
);
mindspore
::
tensor
::
TensorPtr
tensor_ptr
=
std
::
make_shared
<
mindspore
::
tensor
::
Tensor
>
(
tuple
,
kInt32
);
ValuePtr
op_param_value
=
MakeValue
(
tensor_ptr
);
Attr
op1_param
=
std
::
make_pair
(
""
,
op_param_value
);
OperatorParams
operator_param
=
{
std
::
make_pair
(
op1_param
,
2
)};
OperatorArgs
operator_args
=
std
::
make_pair
(
operator_attrs
,
operator_param
);
Operator
op
=
std
::
make_pair
(
operator_name
,
operator_args
);
ops
.
push_back
(
op
);
return
ops
;
}
Status
GatherV2Info
::
InferTensorSubOps
()
{
sub_ops_
.
clear
();
if
((
index_size_
==
0
)
||
(
axis_strategy_
==
1
))
{
return
SUCCESS
;
}
int32_t
mod_n
=
1
;
for
(
size_t
i
=
IntToSize
(
axis_
)
+
1
;
i
<
dev_matrix_shape_
.
size
();
i
++
)
{
mod_n
*=
dev_matrix_shape_
.
at
(
i
);
}
if
((
axis_
>=
SizeToInt
(
dev_matrix_shape_
.
size
()))
||
axis_
<
0
)
{
MS_LOG
(
ERROR
)
<<
"Axis is "
<<
axis_
<<
", not in [0, "
<<
dev_matrix_shape_
.
size
()
<<
")."
;
}
int32_t
mod_p
=
mod_n
*
dev_matrix_shape_
.
at
(
axis_
);
int32_t
rank
=
g_device_manager
->
global_rank
();
int32_t
mod_rank
=
rank
%
mod_p
;
mod_rank
=
static_cast
<
int32_t
>
(
mod_rank
/
mod_n
);
if
(
inputs_shape_
.
size
()
!=
GATHER_V2_INPUTS_SIZE
)
{
MS_LOG
(
ERROR
)
<<
name_
<<
": inputs shape size must be "
<<
GATHER_V2_INPUTS_SIZE
<<
", but is "
<<
inputs_shape_
.
size
();
return
FAILED
;
}
if
((
axis_
>=
SizeToInt
(
inputs_shape_
.
at
(
0
).
size
()))
||
axis_
<
0
)
{
MS_LOG
(
ERROR
)
<<
"Axis is "
<<
axis_
<<
", not in [0, "
<<
inputs_shape_
.
at
(
0
).
size
()
<<
")."
;
}
int32_t
sub_value
=
static_cast
<
int32_t
>
(
inputs_shape_
.
at
(
0
).
at
(
axis_
)
/
dev_matrix_shape_
.
at
(
axis_
))
*
mod_rank
;
OperatorVector
sub_op
;
sub_ops_
.
emplace_back
(
std
::
move
(
sub_op
));
sub_op
=
CreateSubOp
(
sub_value
);
sub_ops_
.
emplace_back
(
std
::
move
(
sub_op
));
return
SUCCESS
;
}
Status
GatherV2Info
::
Init
(
const
StrategyPtr
&
strategy
)
{
if
(
InitWithAutoRepeatCalc
(
strategy
)
!=
SUCCESS
)
{
MS_LOG
(
ERROR
)
<<
name_
<<
": Init failed."
;
return
FAILED
;
}
Status
status
=
InferTensorSubOps
();
if
(
status
!=
SUCCESS
)
{
MS_LOG
(
ERROR
)
<<
name_
<<
": InferTensorSubOps failed."
;
return
status
;
}
MS_LOG
(
INFO
)
<<
name_
<<
": Init success."
;
return
SUCCESS
;
}
Status
GatherV2Info
::
InitForCostModel
(
const
StrategyPtr
&
strategy
)
{
if
(
InitForCostModelWithAutoRepeatCalc
(
strategy
)
!=
SUCCESS
)
{
if
(
is_auto_parallel_
)
{
MS_LOG
(
DEBUG
)
<<
name_
<<
": Init for cost model failed."
;
}
else
{
MS_LOG
(
ERROR
)
<<
name_
<<
": Init for cost model failed."
;
}
return
FAILED
;
}
MS_LOG
(
INFO
)
<<
name_
<<
": Init for cost model success."
;
return
SUCCESS
;
}
Status
GatherV2Info
::
GenerateStrategies
(
int32_t
stage_id
)
{
if
((
inputs_shape_
.
size
()
!=
GATHER_V2_INPUTS_SIZE
)
||
(
outputs_shape_
.
size
()
!=
GATHER_V2_OUTPUTS_SIZE
))
{
MS_LOG
(
ERROR
)
<<
name_
<<
" : Inputs shape size("
<<
inputs_shape_
.
size
()
<<
") or outputs shape size("
<<
outputs_shape_
.
size
()
<<
"is wrong."
;
return
FAILED
;
}
is_auto_parallel_
=
true
;
Shape
input0_split
(
inputs_shape_
[
0
].
size
());
Shapes
splittable_inputs
=
{
input0_split
};
std
::
vector
<
StrategyPtr
>
sp_vector
;
if
(
GenerateStrategiesForIndependentInputs
(
stage_id
,
{
inputs_shape_
.
at
(
0
)},
splittable_inputs
,
&
sp_vector
)
!=
SUCCESS
)
{
MS_LOG
(
ERROR
)
<<
name_
<<
" : Generate strategies for independent inputs() failed."
;
return
FAILED
;
}
size_t
success
=
0
;
for
(
auto
&
sp
:
sp_vector
)
{
if
(
SetCostUnderStrategy
(
sp
)
==
SUCCESS
)
{
success
++
;
MS_LOG
(
INFO
)
<<
name_
<<
" : Successfully generated "
<<
success
<<
" strategy"
;
PrintStrategy
(
sp
);
}
}
return
SUCCESS
;
}
Status
GatherV2Info
::
SetCostUnderStrategy
(
const
StrategyPtr
&
strategy
)
{
if
(
SetCostUnderStrategyBase
(
strategy
)
!=
SUCCESS
)
{
if
(
is_auto_parallel_
)
{
MS_LOG
(
DEBUG
)
<<
name_
<<
": Set cost under strategy failed."
;
}
else
{
MS_LOG
(
ERROR
)
<<
name_
<<
": Set cost under strategy failed."
;
}
return
FAILED
;
}
return
SUCCESS
;
}
std
::
shared_ptr
<
std
::
vector
<
std
::
vector
<
int32_t
>>>
GatherV2Info
::
GenerateBatchStrategies
()
{
if
(
inputs_shape_
.
size
()
!=
GATHER_V2_INPUTS_SIZE
)
{
MS_LOG
(
EXCEPTION
)
<<
name_
<<
": inputs shape size must be "
<<
GATHER_V2_INPUTS_SIZE
<<
", but is "
<<
inputs_shape_
.
size
();
}
CheckGlobalDeviceManager
();
size_t
dev_num
=
g_device_manager
->
GetDeviceListByStageId
(
0
).
size
();
if
(
GetAttrs
()
!=
SUCCESS
)
{
MS_LOG
(
EXCEPTION
)
<<
"GetAttrs failed!"
;
}
Dimensions
strategy
;
if
(
index_size_
!=
1
)
{
strategy
.
push_back
(
1
);
}
else
{
strategy
.
push_back
(
SizeToInt
(
dev_num
));
}
for
(
size_t
i
=
1
;
i
<
inputs_shape_
[
0
].
size
();
i
++
)
{
strategy
.
push_back
(
1
);
}
std
::
vector
<
Dimensions
>
strategy_v
=
{
strategy
};
return
std
::
make_shared
<
std
::
vector
<
std
::
vector
<
int32_t
>>>
(
strategy_v
);
}
}
// namespace parallel
}
// namespace mindspore
mindspore/ccsrc/parallel/ops_info/gather_v2_info.h
0 → 100644
浏览文件 @
c8cdb6b3
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PARALLEL_OPS_INFO_GATHER_V2_INFO_H_
#define MINDSPORE_CCSRC_PARALLEL_OPS_INFO_GATHER_V2_INFO_H_
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "ir/value.h"
#include "parallel/auto_parallel/operator_costmodel.h"
#include "parallel/ops_info/operator_info.h"
#include "parallel/strategy.h"
namespace
mindspore
{
namespace
parallel
{
constexpr
size_t
GATHER_V2_INPUTS_SIZE
=
2
;
constexpr
size_t
GATHER_V2_OUTPUTS_SIZE
=
1
;
constexpr
size_t
GATHER_V2_INPUTS_VALUE_SIZE
=
3
;
// We now supported limited parallel strategies.
// If the strategy corresponding to axis is more than 1, index must be evenly distributed across the axis-dimension of
// the input.
// If Index is a scalar or n-dimension vector(n > 1), the strategy corresponding to axis must be 1.
class
GatherV2Info
:
public
OperatorInfo
{
public:
GatherV2Info
(
const
std
::
string
&
name
,
const
Shapes
&
inputs_shape
,
const
Shapes
&
outputs_shape
,
const
PrimitiveAttrs
&
attrs
)
:
OperatorInfo
(
name
,
inputs_shape
,
outputs_shape
,
attrs
,
std
::
make_shared
<
GatherV2Cost
>
()),
axis_
(
-
1
),
index_size_
(
0
),
axis_strategy_
(
1
)
{}
~
GatherV2Info
()
override
=
default
;
Status
Init
(
const
StrategyPtr
&
strategy
)
override
;
Status
InitForCostModel
(
const
StrategyPtr
&
strategy
)
override
;
Status
GenerateStrategies
(
int32_t
stage_id
)
override
;
Status
SetCostUnderStrategy
(
const
StrategyPtr
&
strategy
)
override
;
std
::
shared_ptr
<
std
::
vector
<
std
::
vector
<
int32_t
>>>
GenerateBatchStrategies
()
override
;
protected:
Status
CheckStrategy
(
const
StrategyPtr
&
strategy
)
override
;
Status
InferMirrorOps
()
override
{
return
SUCCESS
;
}
Status
InferForwardCommunication
()
override
{
return
SUCCESS
;
}
Status
InferTensorInfo
()
override
;
Status
InferDevMatrixShape
()
override
;
Status
InferTensorMap
()
override
;
Status
GetAttrs
()
override
;
private:
Status
InferTensorSubOps
();
int32_t
axis_
;
size_t
index_size_
;
int32_t
axis_strategy_
;
};
}
// namespace parallel
}
// namespace mindspore
#endif // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_GATHER_V2_INFO_H_
mindspore/ccsrc/parallel/ops_info/operator_info.cc
浏览文件 @
c8cdb6b3
...
@@ -112,6 +112,7 @@ void OperatorInfo::ResetQueueMember() {
...
@@ -112,6 +112,7 @@ void OperatorInfo::ResetQueueMember() {
dev_matrix_shape_
.
clear
();
dev_matrix_shape_
.
clear
();
forward_op_
.
clear
();
forward_op_
.
clear
();
mirror_ops_
.
clear
();
mirror_ops_
.
clear
();
sub_ops_
.
clear
();
replace_op_
.
clear
();
replace_op_
.
clear
();
replace_op_info_
.
clear
();
replace_op_info_
.
clear
();
virtual_div_op_
.
clear
();
virtual_div_op_
.
clear
();
...
...
mindspore/ccsrc/parallel/ops_info/operator_info.h
浏览文件 @
c8cdb6b3
...
@@ -41,6 +41,7 @@ namespace mindspore {
...
@@ -41,6 +41,7 @@ namespace mindspore {
namespace
parallel
{
namespace
parallel
{
using
ForwardOp
=
OperatorVector
;
using
ForwardOp
=
OperatorVector
;
using
MirrorOps
=
std
::
vector
<
OperatorVector
>
;
using
MirrorOps
=
std
::
vector
<
OperatorVector
>
;
using
Ops
=
std
::
vector
<
OperatorVector
>
;
using
VirtualDivOp
=
OperatorVector
;
using
VirtualDivOp
=
OperatorVector
;
using
TensorMaps
=
std
::
vector
<
std
::
vector
<
int32_t
>>
;
using
TensorMaps
=
std
::
vector
<
std
::
vector
<
int32_t
>>
;
using
TensorLayouts
=
std
::
vector
<
TensorLayout
>
;
using
TensorLayouts
=
std
::
vector
<
TensorLayout
>
;
...
@@ -99,6 +100,7 @@ class OperatorInfo {
...
@@ -99,6 +100,7 @@ class OperatorInfo {
OutPutInfoVector
replace_op_info
()
const
{
return
replace_op_info_
;
}
OutPutInfoVector
replace_op_info
()
const
{
return
replace_op_info_
;
}
virtual
ReplaceGraphPtr
replace_graph
(
const
CNodePtr
&
)
{
return
replace_graph_
;
}
virtual
ReplaceGraphPtr
replace_graph
(
const
CNodePtr
&
)
{
return
replace_graph_
;
}
MirrorOps
mirror_ops
()
const
{
return
mirror_ops_
;
}
MirrorOps
mirror_ops
()
const
{
return
mirror_ops_
;
}
Ops
sub_ops
()
const
{
return
sub_ops_
;
}
VirtualDivOp
virtual_div_op
()
const
{
return
virtual_div_op_
;
}
VirtualDivOp
virtual_div_op
()
const
{
return
virtual_div_op_
;
}
Shape
dev_matrix_shape
()
const
{
return
dev_matrix_shape_
;
}
Shape
dev_matrix_shape
()
const
{
return
dev_matrix_shape_
;
}
std
::
vector
<
TensorInfo
>
inputs_tensor_info
()
const
{
return
inputs_tensor_info_
;
}
std
::
vector
<
TensorInfo
>
inputs_tensor_info
()
const
{
return
inputs_tensor_info_
;
}
...
@@ -190,6 +192,7 @@ class OperatorInfo {
...
@@ -190,6 +192,7 @@ class OperatorInfo {
TensorMaps
inputs_tensor_map_
;
TensorMaps
inputs_tensor_map_
;
TensorMaps
outputs_tensor_map_
;
TensorMaps
outputs_tensor_map_
;
ForwardOp
forward_op_
;
ForwardOp
forward_op_
;
Ops
sub_ops_
;
ForwardOp
replace_op_
;
ForwardOp
replace_op_
;
OutPutInfoVector
replace_op_info_
;
OutPutInfoVector
replace_op_info_
;
ReplaceGraphPtr
replace_graph_
;
ReplaceGraphPtr
replace_graph_
;
...
...
mindspore/ccsrc/parallel/ops_info/ops_info_head_files.h
浏览文件 @
c8cdb6b3
...
@@ -24,6 +24,7 @@
...
@@ -24,6 +24,7 @@
#include "parallel/ops_info/comparison_function_info.h"
#include "parallel/ops_info/comparison_function_info.h"
#include "parallel/ops_info/dropout_do_mask_info.h"
#include "parallel/ops_info/dropout_do_mask_info.h"
#include "parallel/ops_info/elementary_function_info.h"
#include "parallel/ops_info/elementary_function_info.h"
#include "parallel/ops_info/gather_v2_info.h"
#include "parallel/ops_info/get_next_info.h"
#include "parallel/ops_info/get_next_info.h"
#include "parallel/ops_info/l2_normalize_info.h"
#include "parallel/ops_info/l2_normalize_info.h"
#include "parallel/ops_info/loss_info.h"
#include "parallel/ops_info/loss_info.h"
...
...
mindspore/ccsrc/parallel/step_parallel.cc
浏览文件 @
c8cdb6b3
...
@@ -464,6 +464,14 @@ void SplitTensor(const AnfNodePtr& node, const CNodePtr& next_node, int index) {
...
@@ -464,6 +464,14 @@ void SplitTensor(const AnfNodePtr& node, const CNodePtr& next_node, int index) {
MS_EXCEPTION_IF_NULL
(
func_graph
);
MS_EXCEPTION_IF_NULL
(
func_graph
);
Operator
op
=
CreateGetTensorSliceOp
(
tensor_layout
);
Operator
op
=
CreateGetTensorSliceOp
(
tensor_layout
);
InsertGetTensorSliceOp
(
op
,
next_node
,
func_graph
,
index
,
SPLIT_TENSOR
);
InsertGetTensorSliceOp
(
op
,
next_node
,
func_graph
,
index
,
SPLIT_TENSOR
);
if
(
!
op_info
->
sub_ops
().
empty
())
{
auto
sub_ops
=
op_info
->
sub_ops
();
for
(
size_t
i
=
0
;
i
<
sub_ops
.
size
();
i
++
)
{
if
(
!
sub_ops
.
at
(
i
).
empty
())
{
InsertGetTensorSliceOp
(
sub_ops
.
at
(
i
).
at
(
0
),
next_node
,
func_graph
,
index
,
SUB
);
}
}
}
}
}
void
StepSplitTensor
(
const
AnfNodePtr
&
node
,
const
FuncGraphManagerPtr
&
manager
)
{
void
StepSplitTensor
(
const
AnfNodePtr
&
node
,
const
FuncGraphManagerPtr
&
manager
)
{
...
...
tests/ut/python/parallel/test_gather_v2_primitive.py
浏览文件 @
c8cdb6b3
...
@@ -29,6 +29,8 @@ from mindspore.nn import Dense, Cell
...
@@ -29,6 +29,8 @@ from mindspore.nn import Dense, Cell
from
mindspore
import
context
from
mindspore
import
context
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
)
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
)
device_number
=
32
batch_size_per_device
=
128
class
Dataset
():
class
Dataset
():
...
@@ -57,15 +59,22 @@ class Dataset():
...
@@ -57,15 +59,22 @@ class Dataset():
class
GatherV2
(
_Loss
):
class
GatherV2
(
_Loss
):
def
__init__
(
self
,
batchsize
):
def
__init__
(
self
,
index_dim
,
strategy
,
index_size
=
16
):
super
(
GatherV2
,
self
).
__init__
()
super
(
GatherV2
,
self
).
__init__
()
self
.
pow
=
P
.
Pow
()
self
.
pow
=
P
.
Pow
()
emb_list
=
list
(
range
(
batchsize
))
emb1_list
=
21
emb1_list
=
emb_list
[
0
::
2
]
emb2_list
=
2
emb2_list
=
emb_list
[
1
::
2
]
if
index_dim
==
1
:
emb_list
=
list
(
range
(
index_size
))
emb1_list
=
emb_list
[
0
::
2
]
emb2_list
=
emb_list
[
1
::
2
]
if
index_dim
==
2
:
emb_list
=
np
.
arange
(
index_size
*
16
)
emb1_list
=
np
.
reshape
(
emb_list
[
0
::
2
],
(
int
(
index_size
/
2
),
16
))
emb2_list
=
np
.
reshape
(
emb_list
[
1
::
2
],
(
int
(
index_size
/
2
),
16
))
self
.
emb1_param
=
Tensor
(
emb1_list
,
dtype
=
mstype
.
int32
)
self
.
emb1_param
=
Tensor
(
emb1_list
,
dtype
=
mstype
.
int32
)
self
.
emb2_param
=
Tensor
(
emb2_list
,
dtype
=
mstype
.
int32
)
self
.
emb2_param
=
Tensor
(
emb2_list
,
dtype
=
mstype
.
int32
)
self
.
gatherv2
=
P
.
GatherV2
()
self
.
gatherv2
=
P
.
GatherV2
()
.
set_strategy
(
strategy
)
def
construct
(
self
,
nembeddings
):
def
construct
(
self
,
nembeddings
):
emb1
=
self
.
gatherv2
(
nembeddings
,
self
.
emb1_param
,
0
)
emb1
=
self
.
gatherv2
(
nembeddings
,
self
.
emb1_param
,
0
)
...
@@ -73,10 +82,6 @@ class GatherV2(_Loss):
...
@@ -73,10 +82,6 @@ class GatherV2(_Loss):
return
self
.
pow
((
emb1
-
emb2
),
2.0
)
return
self
.
pow
((
emb1
-
emb2
),
2.0
)
def
get_loss
(
batchsize
):
return
GatherV2
(
batchsize
)
def
fc_with_initialize
(
input_channels
,
out_channels
):
def
fc_with_initialize
(
input_channels
,
out_channels
):
return
Dense
(
input_channels
,
out_channels
)
return
Dense
(
input_channels
,
out_channels
)
...
@@ -114,26 +119,23 @@ class TrainOneStepCell(Cell):
...
@@ -114,26 +119,23 @@ class TrainOneStepCell(Cell):
return
F
.
depend
(
loss
,
self
.
optimizer
(
grads
))
return
F
.
depend
(
loss
,
self
.
optimizer
(
grads
))
def
test_trains
(
):
def
net_trains
(
gather_v2_strategy
,
criterion
,
rank
):
init
()
init
()
lr
=
0.1
lr
=
0.1
momentum
=
0.9
momentum
=
0.9
max_epoch
=
20
max_epoch
=
20
device_number
=
32
batch_size_per_device
=
128
input_channels
=
256
input_channels
=
256
out_channels
=
512
out_channels
=
512
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
save_graphs
=
False
)
context
.
reset_auto_parallel_context
()
context
.
reset_auto_parallel_context
()
context
.
set_auto_parallel_context
(
parallel_mode
=
ParallelMode
.
SEMI_AUTO_PARALLEL
,
device_num
=
device_number
)
context
.
set_auto_parallel_context
(
parallel_mode
=
ParallelMode
.
SEMI_AUTO_PARALLEL
,
device_num
=
device_number
,
global_rank
=
rank
)
predict
=
Tensor
(
np
.
ones
([
batch_size_per_device
,
input_channels
]),
dtype
=
ms
.
float32
)
predict
=
Tensor
(
np
.
ones
([
batch_size_per_device
,
input_channels
]),
dtype
=
ms
.
float32
)
dataset
=
Dataset
(
predict
,
4
)
dataset
=
Dataset
(
predict
,
4
)
network
=
fc_with_initialize
(
input_channels
,
out_channels
)
network
=
fc_with_initialize
(
input_channels
,
out_channels
)
network
.
set_train
()
network
.
set_train
()
criterion
=
get_loss
(
batch_size_per_device
*
device_number
)
train_network
=
BuildTrainNetwork
(
network
,
criterion
)
train_network
=
BuildTrainNetwork
(
network
,
criterion
)
train_network
.
set_train
()
train_network
.
set_train
()
opt
=
Momentum
(
train_network
.
trainable_params
(),
lr
,
momentum
)
opt
=
Momentum
(
train_network
.
trainable_params
(),
lr
,
momentum
)
...
@@ -143,5 +145,90 @@ def test_trains():
...
@@ -143,5 +145,90 @@ def test_trains():
model
.
train
(
max_epoch
,
dataset
,
dataset_sink_mode
=
False
)
model
.
train
(
max_epoch
,
dataset
,
dataset_sink_mode
=
False
)
context
.
reset_auto_parallel_context
()
context
.
reset_auto_parallel_context
()
# Script entry point: run the training smoke test when executed directly.
# NOTE(review): `test_trains` is not defined in the visible (post-diff) hunks —
# the commit appears to rename the driver to `net_trains` — confirm this guard
# was removed or updated together with that rename.
if __name__ == "__main__":
    test_trains()
def test_auto_batch_parallel():
    """1-D index GatherV2 with no explicit strategy (framework picks batch parallel)."""
    strategy = None  # no user strategy: exercise the auto batch-parallel path
    index_size = batch_size_per_device * device_number
    loss = GatherV2(1, strategy=strategy, index_size=index_size)
    net_trains(strategy, loss, 2)
def test_2d_index_auto_batch_parallel():
    """2-D index GatherV2 with no explicit strategy (framework picks batch parallel)."""
    strategy = None  # no user strategy: exercise the auto batch-parallel path
    index_size = batch_size_per_device * device_number
    loss = GatherV2(2, strategy=strategy, index_size=index_size)
    net_trains(strategy, loss, 2)
def test_batch_parallel():
    """1-D index GatherV2 with an explicit full batch-parallel strategy."""
    # Shard the parameter's first dimension across every device.
    strategy = ((device_number, 1),)
    index_size = batch_size_per_device * device_number
    loss = GatherV2(1, strategy=strategy, index_size=index_size)
    net_trains(strategy, loss, 2)
def test_strategy1():
    """1-D index GatherV2 with a mixed (16, 2) sharding strategy."""
    strategy = ((16, 2),)  # split dim0 by 16 and dim1 by 2
    index_size = batch_size_per_device * device_number
    loss = GatherV2(1, strategy=strategy, index_size=index_size)
    net_trains(strategy, loss, 2)
def test_strategy2():
    """1-D index GatherV2 sharding only the second parameter dimension."""
    strategy = ((1, device_number),)  # replicate dim0, split dim1 across devices
    index_size = batch_size_per_device * device_number
    loss = GatherV2(1, strategy=strategy, index_size=index_size)
    net_trains(strategy, loss, 2)
def test_strategy3():
    """1-D index GatherV2 using only part of the devices (split dim0 by 8)."""
    strategy = ((8, 1),)
    index_size = batch_size_per_device * device_number
    loss = GatherV2(1, strategy=strategy, index_size=index_size)
    net_trains(strategy, loss, 2)
class GatherV2Axis1(_Loss):
    """Pseudo-loss exercising a distributed GatherV2 along axis 1.

    Builds two interleaved index sets (even and odd positions), gathers both
    from the input embeddings on axis 1, and returns the squared difference —
    a cheap, deterministic value suitable for parallel-strategy tests.
    """

    def __init__(self, index_dim, strategy, index_size=16):
        super(GatherV2Axis1, self).__init__()
        self.pow = P.Pow()
        # Scalar fallbacks; only used when index_dim is neither 1 nor 2.
        even_indices = 21
        odd_indices = 2
        if index_dim == 1:
            all_indices = list(range(index_size))
            even_indices = all_indices[0::2]
            odd_indices = all_indices[1::2]
        if index_dim == 2:
            # Square 2-D index tensors of shape (index_size/2, index_size).
            all_indices = np.arange(index_size * index_size)
            half = int(index_size / 2)
            even_indices = np.reshape(all_indices[0::2], (half, index_size))
            odd_indices = np.reshape(all_indices[1::2], (half, index_size))
        self.emb1_param = Tensor(even_indices, dtype=mstype.int32)
        self.emb2_param = Tensor(odd_indices, dtype=mstype.int32)
        self.gatherv2 = P.GatherV2().set_strategy(strategy)

    def construct(self, nembeddings):
        """Gather both index sets on axis 1 and square their difference."""
        first = self.gatherv2(nembeddings, self.emb1_param, 1)
        second = self.gatherv2(nembeddings, self.emb2_param, 1)
        return self.pow((first - second), 2.0)
def test_axis1_auto_batch_parallel():
    """Axis-1 GatherV2 with no explicit strategy (auto batch parallel)."""
    strategy = None  # no user strategy: exercise the auto batch-parallel path
    loss = GatherV2Axis1(1, strategy=strategy, index_size=512)
    net_trains(strategy, loss, 2)
def test_axis1_batch_parallel():
    """Axis-1 GatherV2 with an explicit full batch-parallel strategy."""
    # Shard the parameter's first dimension across every device.
    strategy = ((device_number, 1),)
    loss = GatherV2Axis1(1, strategy=strategy, index_size=512)
    net_trains(strategy, loss, 2)
def test_axis1_strategy1():
    """Axis-1 GatherV2 with a mixed (16, 2) strategy on a non-zero rank."""
    strategy = ((16, 2),)  # split dim0 by 16 and dim1 by 2
    loss = GatherV2Axis1(1, strategy=strategy, index_size=512)
    net_trains(strategy, loss, 17)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录