Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
正统之独孤求败
mindspore
提交
b413638f
M
mindspore
项目概览
正统之独孤求败
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
b413638f
编写于
4月 09, 2020
作者:
C
c00425699
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
refactor OperatorCostPtr in OperatorInfo
上级
cc53ddae
变更
27
隐藏空白更改
内联
并排
Showing
27 changed file
with
62 addition
and
211 deletion
+62
-211
mindspore/ccsrc/parallel/auto_parallel/operator_costmodel.cc
mindspore/ccsrc/parallel/auto_parallel/operator_costmodel.cc
+0
-54
mindspore/ccsrc/parallel/auto_parallel/operator_costmodel.h
mindspore/ccsrc/parallel/auto_parallel/operator_costmodel.h
+4
-26
mindspore/ccsrc/parallel/ops_info/activation_info.h
mindspore/ccsrc/parallel/ops_info/activation_info.h
+4
-14
mindspore/ccsrc/parallel/ops_info/arithmetic_info.h
mindspore/ccsrc/parallel/ops_info/arithmetic_info.h
+1
-5
mindspore/ccsrc/parallel/ops_info/batch_parallel_info.h
mindspore/ccsrc/parallel/ops_info/batch_parallel_info.h
+1
-5
mindspore/ccsrc/parallel/ops_info/bias_add_info.h
mindspore/ccsrc/parallel/ops_info/bias_add_info.h
+1
-5
mindspore/ccsrc/parallel/ops_info/dropout_do_mask_info.h
mindspore/ccsrc/parallel/ops_info/dropout_do_mask_info.h
+1
-7
mindspore/ccsrc/parallel/ops_info/generator_info.h
mindspore/ccsrc/parallel/ops_info/generator_info.h
+1
-5
mindspore/ccsrc/parallel/ops_info/get_next_info.h
mindspore/ccsrc/parallel/ops_info/get_next_info.h
+1
-5
mindspore/ccsrc/parallel/ops_info/l2_normalize_info.h
mindspore/ccsrc/parallel/ops_info/l2_normalize_info.h
+1
-5
mindspore/ccsrc/parallel/ops_info/loss_info.h
mindspore/ccsrc/parallel/ops_info/loss_info.h
+1
-5
mindspore/ccsrc/parallel/ops_info/matmul_info.cc
mindspore/ccsrc/parallel/ops_info/matmul_info.cc
+3
-3
mindspore/ccsrc/parallel/ops_info/matmul_info.h
mindspore/ccsrc/parallel/ops_info/matmul_info.h
+1
-6
mindspore/ccsrc/parallel/ops_info/onehot_info.h
mindspore/ccsrc/parallel/ops_info/onehot_info.h
+1
-5
mindspore/ccsrc/parallel/ops_info/operator_info.cc
mindspore/ccsrc/parallel/ops_info/operator_info.cc
+6
-7
mindspore/ccsrc/parallel/ops_info/operator_info.h
mindspore/ccsrc/parallel/ops_info/operator_info.h
+9
-4
mindspore/ccsrc/parallel/ops_info/prelu_info.h
mindspore/ccsrc/parallel/ops_info/prelu_info.h
+1
-5
mindspore/ccsrc/parallel/ops_info/reduce_method_info.cc
mindspore/ccsrc/parallel/ops_info/reduce_method_info.cc
+6
-2
mindspore/ccsrc/parallel/ops_info/reduce_method_info.h
mindspore/ccsrc/parallel/ops_info/reduce_method_info.h
+2
-6
mindspore/ccsrc/parallel/ops_info/reshape_info.h
mindspore/ccsrc/parallel/ops_info/reshape_info.h
+2
-6
mindspore/ccsrc/parallel/ops_info/tmp_identity_info.h
mindspore/ccsrc/parallel/ops_info/tmp_identity_info.h
+1
-7
mindspore/ccsrc/parallel/ops_info/transpose_info.h
mindspore/ccsrc/parallel/ops_info/transpose_info.h
+1
-5
mindspore/ccsrc/parallel/ops_info/virtual_dataset_info.h
mindspore/ccsrc/parallel/ops_info/virtual_dataset_info.h
+1
-7
tests/ut/cpp/parallel/ops_info/activation_test.cc
tests/ut/cpp/parallel/ops_info/activation_test.cc
+4
-4
tests/ut/cpp/parallel/ops_info/matmul_info_test.cc
tests/ut/cpp/parallel/ops_info/matmul_info_test.cc
+2
-2
tests/ut/cpp/parallel/ops_info/tensor_add_info_test.cc
tests/ut/cpp/parallel/ops_info/tensor_add_info_test.cc
+4
-4
tests/ut/cpp/parallel/ops_info/tmpidentity_test.cc
tests/ut/cpp/parallel/ops_info/tmpidentity_test.cc
+2
-2
未找到文件。
mindspore/ccsrc/parallel/auto_parallel/operator_costmodel.cc
浏览文件 @
b413638f
...
@@ -514,60 +514,6 @@ double ArithmeticCost::GetBackwardCommCost(const std::vector<TensorInfo>& inputs
...
@@ -514,60 +514,6 @@ double ArithmeticCost::GetBackwardCommCost(const std::vector<TensorInfo>& inputs
return
result
;
return
result
;
}
}
double
L2NormalizeCost
::
GetBackwardCommCost
(
const
std
::
vector
<
TensorInfo
>&
inputs
,
const
std
::
vector
<
TensorInfo
>&
,
const
int32_t
&
stage_id
)
const
{
double
result
=
0.0
;
if
(
is_parameter_
[
0
])
{
TensorInfo
input_tensor_info
=
inputs
[
0
];
CheckGlobalDeviceManager
();
MS_EXCEPTION_IF_NULL
(
g_device_manager
);
auto
total_device_num
=
g_device_manager
->
GetDeviceListByStageId
(
stage_id
).
size
();
Shape
input_shape
=
input_tensor_info
.
shape
();
Shape
input_slice_shape
=
input_tensor_info
.
slice_shape
();
int32_t
used_device_num
=
1
;
for
(
size_t
i
=
0
;
i
<
input_shape
.
size
();
++
i
)
{
used_device_num
*=
input_shape
[
i
]
/
input_slice_shape
[
i
];
}
if
(
total_device_num
!=
IntToSize
(
used_device_num
))
result
+=
ListProduct
(
input_slice_shape
)
*
static_cast
<
double
>
(
inputs_type_lengths_
[
0
]);
}
return
result
;
}
double
L2NormalizeCost
::
GetForwardComputationCost
(
const
std
::
vector
<
TensorInfo
>&
inputs
,
const
std
::
vector
<
TensorInfo
>&
,
const
int32_t
&
)
const
{
TensorInfo
input0_info
=
inputs
[
0
];
Shape
input0_slice_shape
=
input0_info
.
slice_shape
();
return
ListProduct
(
input0_slice_shape
)
*
static_cast
<
double
>
(
inputs_type_lengths_
[
0
]);
}
double
L2NormalizeCost
::
GetBackwardComputationCost
(
const
std
::
vector
<
TensorInfo
>&
inputs
,
const
std
::
vector
<
TensorInfo
>&
,
const
int32_t
&
stage_id
)
const
{
double
result
=
0.0
;
if
(
is_parameter_
[
0
])
{
TensorInfo
input_tensor_info
=
inputs
[
0
];
CheckGlobalDeviceManager
();
MS_EXCEPTION_IF_NULL
(
g_device_manager
);
auto
total_device_num
=
g_device_manager
->
GetDeviceListByStageId
(
stage_id
).
size
();
Shape
input_shape
=
input_tensor_info
.
shape
();
Shape
input_slice_shape
=
input_tensor_info
.
slice_shape
();
int32_t
used_device_num
=
1
;
for
(
size_t
i
=
0
;
i
<
input_shape
.
size
();
++
i
)
{
used_device_num
*=
input_shape
[
i
]
/
input_slice_shape
[
i
];
}
if
(
total_device_num
!=
IntToSize
(
used_device_num
))
result
+=
ListProduct
(
input_slice_shape
)
*
static_cast
<
double
>
(
inputs_type_lengths_
[
0
]);
}
return
result
;
}
bool
IsDataParallel
(
const
Shape
&
shape
,
const
Shape
&
slice_shape
,
const
int32_t
&
stage_id
)
{
bool
IsDataParallel
(
const
Shape
&
shape
,
const
Shape
&
slice_shape
,
const
int32_t
&
stage_id
)
{
CheckGlobalDeviceManager
();
CheckGlobalDeviceManager
();
MS_EXCEPTION_IF_NULL
(
g_device_manager
);
MS_EXCEPTION_IF_NULL
(
g_device_manager
);
...
...
mindspore/ccsrc/parallel/auto_parallel/operator_costmodel.h
浏览文件 @
b413638f
...
@@ -132,6 +132,8 @@ class ActivationCost : public OperatorCost {
...
@@ -132,6 +132,8 @@ class ActivationCost : public OperatorCost {
};
};
using
ActivationCostPtr
=
std
::
shared_ptr
<
ActivationCost
>
;
using
ActivationCostPtr
=
std
::
shared_ptr
<
ActivationCost
>
;
using
TransposeCost
=
ActivationCost
;
using
TransposeCostPtr
=
std
::
shared_ptr
<
TransposeCost
>
;
class
SoftmaxCost
:
public
OperatorCost
{
class
SoftmaxCost
:
public
OperatorCost
{
public:
public:
...
@@ -415,32 +417,8 @@ class ArithmeticCost : public OperatorCost {
...
@@ -415,32 +417,8 @@ class ArithmeticCost : public OperatorCost {
const
int32_t
&
stage_id
)
const
override
;
const
int32_t
&
stage_id
)
const
override
;
};
};
using
ArithmeticCostPtr
=
std
::
shared_ptr
<
ArithmeticCost
>
;
using
ArithmeticCostPtr
=
std
::
shared_ptr
<
ArithmeticCost
>
;
using
BiasAddCost
=
ArithmeticCost
;
class
L2NormalizeCost
:
public
OperatorCost
{
using
BiasAddCostPtr
=
std
::
shared_ptr
<
BiasAddCost
>
;
public:
L2NormalizeCost
()
=
default
;
~
L2NormalizeCost
()
override
=
default
;
double
GetCommCost
(
const
std
::
vector
<
TensorInfo
>&
inputs
,
const
std
::
vector
<
TensorInfo
>&
outputs
,
const
int32_t
&
stage_id
)
const
override
{
return
GetForwardCommCost
(
inputs
,
outputs
,
stage_id
)
+
GetBackwardCommCost
(
inputs
,
outputs
,
stage_id
);
}
double
GetForwardCommCost
(
const
std
::
vector
<
TensorInfo
>&
,
const
std
::
vector
<
TensorInfo
>&
,
const
int32_t
&
)
const
override
{
return
0.0
;
}
double
GetBackwardCommCost
(
const
std
::
vector
<
TensorInfo
>&
inputs
,
const
std
::
vector
<
TensorInfo
>&
outputs
,
const
int32_t
&
stage_id
)
const
override
;
double
GetComputationCost
(
const
std
::
vector
<
TensorInfo
>&
inputs
,
const
std
::
vector
<
TensorInfo
>&
outputs
,
const
int32_t
&
stage_id
)
const
override
{
return
GetForwardComputationCost
(
inputs
,
outputs
,
stage_id
)
+
GetBackwardComputationCost
(
inputs
,
outputs
,
stage_id
);
}
double
GetForwardComputationCost
(
const
std
::
vector
<
TensorInfo
>&
inputs
,
const
std
::
vector
<
TensorInfo
>&
outputs
,
const
int32_t
&
stage_id
)
const
override
;
double
GetBackwardComputationCost
(
const
std
::
vector
<
TensorInfo
>&
inputs
,
const
std
::
vector
<
TensorInfo
>&
outputs
,
const
int32_t
&
stage_id
)
const
override
;
};
using
L2NormalizeCostPtr
=
std
::
shared_ptr
<
L2NormalizeCost
>
;
class
ReduceMethodCost
:
public
OperatorCost
{
class
ReduceMethodCost
:
public
OperatorCost
{
public:
public:
...
...
mindspore/ccsrc/parallel/ops_info/activation_info.h
浏览文件 @
b413638f
...
@@ -32,8 +32,8 @@ namespace parallel {
...
@@ -32,8 +32,8 @@ namespace parallel {
class
ActivationBase
:
public
OperatorInfo
{
class
ActivationBase
:
public
OperatorInfo
{
public:
public:
ActivationBase
(
const
std
::
string
&
operator_name
,
const
Shapes
&
inputs_shape
,
const
Shapes
&
outputs_shape
,
ActivationBase
(
const
std
::
string
&
operator_name
,
const
Shapes
&
inputs_shape
,
const
Shapes
&
outputs_shape
,
const
PrimitiveAttrs
&
attrs
)
const
PrimitiveAttrs
&
attrs
,
OperatorCostPtr
cost
)
:
OperatorInfo
(
operator_name
,
inputs_shape
,
outputs_shape
,
attrs
)
{}
:
OperatorInfo
(
operator_name
,
inputs_shape
,
outputs_shape
,
attrs
,
cost
)
{}
~
ActivationBase
()
override
=
default
;
~
ActivationBase
()
override
=
default
;
Status
Init
(
const
StrategyPtr
&
strategy
)
override
;
Status
Init
(
const
StrategyPtr
&
strategy
)
override
;
...
@@ -51,19 +51,13 @@ class Activation : public ActivationBase {
...
@@ -51,19 +51,13 @@ class Activation : public ActivationBase {
public:
public:
Activation
(
const
std
::
string
&
name
,
const
Shapes
&
inputs_shape
,
const
Shapes
&
outputs_shape
,
Activation
(
const
std
::
string
&
name
,
const
Shapes
&
inputs_shape
,
const
Shapes
&
outputs_shape
,
const
PrimitiveAttrs
&
attrs
)
const
PrimitiveAttrs
&
attrs
)
:
ActivationBase
(
name
,
inputs_shape
,
outputs_shape
,
attrs
)
{
:
ActivationBase
(
name
,
inputs_shape
,
outputs_shape
,
attrs
,
std
::
make_shared
<
ActivationCost
>
())
{}
ac_cost_ptr_
=
std
::
make_shared
<
ActivationCost
>
();
}
~
Activation
()
override
=
default
;
~
Activation
()
override
=
default
;
Status
GenerateStrategies
(
int32_t
stage_id
)
override
;
Status
GenerateStrategies
(
int32_t
stage_id
)
override
;
Status
SetCostUnderStrategy
(
const
StrategyPtr
&
strategy
)
override
;
Status
SetCostUnderStrategy
(
const
StrategyPtr
&
strategy
)
override
;
OperatorCostPtr
GetOperatorCost
()
const
override
{
return
ac_cost_ptr_
;
}
protected:
protected:
Status
CheckStrategy
(
const
StrategyPtr
&
strategy
)
override
;
Status
CheckStrategy
(
const
StrategyPtr
&
strategy
)
override
;
private:
ActivationCostPtr
ac_cost_ptr_
;
};
};
class
ActivationInfo
:
public
Activation
{
class
ActivationInfo
:
public
Activation
{
...
@@ -108,13 +102,10 @@ class Softmax : public ActivationBase {
...
@@ -108,13 +102,10 @@ class Softmax : public ActivationBase {
public:
public:
explicit
Softmax
(
const
std
::
string
&
name
,
const
Shapes
&
inputs_shape
,
const
Shapes
&
outputs_shape
,
explicit
Softmax
(
const
std
::
string
&
name
,
const
Shapes
&
inputs_shape
,
const
Shapes
&
outputs_shape
,
const
PrimitiveAttrs
&
attrs
)
const
PrimitiveAttrs
&
attrs
)
:
ActivationBase
(
name
,
inputs_shape
,
outputs_shape
,
attrs
)
{
:
ActivationBase
(
name
,
inputs_shape
,
outputs_shape
,
attrs
,
std
::
make_shared
<
SoftmaxCost
>
())
{}
sm_cost_ptr_
=
std
::
make_shared
<
SoftmaxCost
>
();
}
~
Softmax
()
override
=
default
;
~
Softmax
()
override
=
default
;
Status
GenerateStrategies
(
int32_t
stage_id
)
override
;
Status
GenerateStrategies
(
int32_t
stage_id
)
override
;
Status
SetCostUnderStrategy
(
const
StrategyPtr
&
strategy
)
override
;
Status
SetCostUnderStrategy
(
const
StrategyPtr
&
strategy
)
override
;
OperatorCostPtr
GetOperatorCost
()
const
override
{
return
sm_cost_ptr_
;
}
protected:
protected:
Status
CheckStrategy
(
const
StrategyPtr
&
strategy
)
override
;
Status
CheckStrategy
(
const
StrategyPtr
&
strategy
)
override
;
...
@@ -122,7 +113,6 @@ class Softmax : public ActivationBase {
...
@@ -122,7 +113,6 @@ class Softmax : public ActivationBase {
private:
private:
std
::
vector
<
int32_t
>
axis_
;
std
::
vector
<
int32_t
>
axis_
;
SoftmaxCostPtr
sm_cost_ptr_
;
};
};
class
SoftmaxInfo
:
public
Softmax
{
class
SoftmaxInfo
:
public
Softmax
{
...
...
mindspore/ccsrc/parallel/ops_info/arithmetic_info.h
浏览文件 @
b413638f
...
@@ -33,15 +33,12 @@ class ArithmeticBase : public OperatorInfo {
...
@@ -33,15 +33,12 @@ class ArithmeticBase : public OperatorInfo {
public:
public:
ArithmeticBase
(
const
std
::
string
&
operator_name
,
const
Shapes
&
inputs_shape
,
const
Shapes
&
outputs_shape
,
ArithmeticBase
(
const
std
::
string
&
operator_name
,
const
Shapes
&
inputs_shape
,
const
Shapes
&
outputs_shape
,
const
PrimitiveAttrs
&
attrs
)
const
PrimitiveAttrs
&
attrs
)
:
OperatorInfo
(
operator_name
,
inputs_shape
,
outputs_shape
,
attrs
)
{
:
OperatorInfo
(
operator_name
,
inputs_shape
,
outputs_shape
,
attrs
,
std
::
make_shared
<
ArithmeticCost
>
())
{}
arithmeticcost_ptr_
=
std
::
make_shared
<
ArithmeticCost
>
();
}
~
ArithmeticBase
()
override
=
default
;
~
ArithmeticBase
()
override
=
default
;
Status
Init
(
const
StrategyPtr
&
strategy
)
override
;
Status
Init
(
const
StrategyPtr
&
strategy
)
override
;
Status
InitForCostModel
(
const
StrategyPtr
&
strategy
)
override
;
Status
InitForCostModel
(
const
StrategyPtr
&
strategy
)
override
;
Status
GenerateStrategies
(
int32_t
)
override
;
Status
GenerateStrategies
(
int32_t
)
override
;
Status
SetCostUnderStrategy
(
const
StrategyPtr
&
)
override
;
Status
SetCostUnderStrategy
(
const
StrategyPtr
&
)
override
;
OperatorCostPtr
GetOperatorCost
()
const
override
{
return
arithmeticcost_ptr_
;
}
void
ReComputeBatchSplitFlagList
()
override
;
void
ReComputeBatchSplitFlagList
()
override
;
protected:
protected:
...
@@ -54,7 +51,6 @@ class ArithmeticBase : public OperatorInfo {
...
@@ -54,7 +51,6 @@ class ArithmeticBase : public OperatorInfo {
Status
InferTensorMap
()
override
;
Status
InferTensorMap
()
override
;
Status
InferTensorLayout
(
TensorLayouts
*
inputs_layout
,
TensorLayouts
*
outputs_layout
,
const
Shape
&
dev_matrix_array
);
Status
InferTensorLayout
(
TensorLayouts
*
inputs_layout
,
TensorLayouts
*
outputs_layout
,
const
Shape
&
dev_matrix_array
);
Shapes
InferExpendShape
();
Shapes
InferExpendShape
();
ArithmeticCostPtr
arithmeticcost_ptr_
;
};
};
class
SubInfo
:
public
ArithmeticBase
{
class
SubInfo
:
public
ArithmeticBase
{
...
...
mindspore/ccsrc/parallel/ops_info/batch_parallel_info.h
浏览文件 @
b413638f
...
@@ -31,16 +31,13 @@ class BatchParallelInfo : public OperatorInfo {
...
@@ -31,16 +31,13 @@ class BatchParallelInfo : public OperatorInfo {
public:
public:
BatchParallelInfo
(
const
std
::
string
&
name
,
const
Shapes
&
inputs_shape
,
const
Shapes
&
outputs_shape
,
BatchParallelInfo
(
const
std
::
string
&
name
,
const
Shapes
&
inputs_shape
,
const
Shapes
&
outputs_shape
,
const
PrimitiveAttrs
&
attrs
)
const
PrimitiveAttrs
&
attrs
)
:
OperatorInfo
(
name
,
inputs_shape
,
outputs_shape
,
attrs
),
dev_num_
(
1
)
{
:
OperatorInfo
(
name
,
inputs_shape
,
outputs_shape
,
attrs
,
std
::
make_shared
<
BatchParallelCost
>
()),
dev_num_
(
1
)
{}
bp_cost_ptr_
=
std
::
make_shared
<
BatchParallelCost
>
();
}
~
BatchParallelInfo
()
override
=
default
;
~
BatchParallelInfo
()
override
=
default
;
Status
Init
(
const
StrategyPtr
&
strategy
)
override
;
Status
Init
(
const
StrategyPtr
&
strategy
)
override
;
Status
InitForCostModel
(
const
StrategyPtr
&
strategy
)
override
;
Status
InitForCostModel
(
const
StrategyPtr
&
strategy
)
override
;
Status
GenerateStrategies
(
int32_t
stage_id
)
override
;
Status
GenerateStrategies
(
int32_t
stage_id
)
override
;
Status
SetCostUnderStrategy
(
const
StrategyPtr
&
strategy
)
override
;
Status
SetCostUnderStrategy
(
const
StrategyPtr
&
strategy
)
override
;
OperatorCostPtr
GetOperatorCost
()
const
override
{
return
bp_cost_ptr_
;
}
protected:
protected:
Status
CheckStrategy
(
const
StrategyPtr
&
strategy
)
override
;
Status
CheckStrategy
(
const
StrategyPtr
&
strategy
)
override
;
...
@@ -55,7 +52,6 @@ class BatchParallelInfo : public OperatorInfo {
...
@@ -55,7 +52,6 @@ class BatchParallelInfo : public OperatorInfo {
private:
private:
int32_t
dev_num_
;
int32_t
dev_num_
;
BatchParallelCostPtr
bp_cost_ptr_
;
};
};
class
SparseSoftmaxCrossEntropyWithLogitsInfo
:
public
BatchParallelInfo
{
class
SparseSoftmaxCrossEntropyWithLogitsInfo
:
public
BatchParallelInfo
{
...
...
mindspore/ccsrc/parallel/ops_info/bias_add_info.h
浏览文件 @
b413638f
...
@@ -34,16 +34,13 @@ class BiasAddInfo : public OperatorInfo {
...
@@ -34,16 +34,13 @@ class BiasAddInfo : public OperatorInfo {
public:
public:
BiasAddInfo
(
const
std
::
string
&
operator_name
,
const
Shapes
&
inputs_shape
,
const
Shapes
&
outputs_shape
,
BiasAddInfo
(
const
std
::
string
&
operator_name
,
const
Shapes
&
inputs_shape
,
const
Shapes
&
outputs_shape
,
const
PrimitiveAttrs
&
attrs
)
const
PrimitiveAttrs
&
attrs
)
:
OperatorInfo
(
operator_name
,
inputs_shape
,
outputs_shape
,
attrs
)
{
:
OperatorInfo
(
operator_name
,
inputs_shape
,
outputs_shape
,
attrs
,
std
::
make_shared
<
BiasAddCost
>
())
{}
biasaddcost_ptr_
=
std
::
make_shared
<
ArithmeticCost
>
();
}
~
BiasAddInfo
()
override
=
default
;
~
BiasAddInfo
()
override
=
default
;
Status
Init
(
const
StrategyPtr
&
strategy
)
override
;
Status
Init
(
const
StrategyPtr
&
strategy
)
override
;
Status
InitForCostModel
(
const
StrategyPtr
&
strategy
)
override
;
Status
InitForCostModel
(
const
StrategyPtr
&
strategy
)
override
;
Status
GenerateStrategies
(
int32_t
)
override
;
Status
GenerateStrategies
(
int32_t
)
override
;
Status
SetCostUnderStrategy
(
const
StrategyPtr
&
)
override
;
Status
SetCostUnderStrategy
(
const
StrategyPtr
&
)
override
;
OperatorCostPtr
GetOperatorCost
()
const
override
{
return
biasaddcost_ptr_
;
}
void
ReComputeBatchSplitFlagList
()
override
;
void
ReComputeBatchSplitFlagList
()
override
;
protected:
protected:
...
@@ -55,7 +52,6 @@ class BiasAddInfo : public OperatorInfo {
...
@@ -55,7 +52,6 @@ class BiasAddInfo : public OperatorInfo {
Status
InferDevMatrixShape
()
override
;
Status
InferDevMatrixShape
()
override
;
Status
InferTensorMap
()
override
;
Status
InferTensorMap
()
override
;
Status
InferTensorLayout
(
TensorLayouts
*
inputs_layout
,
TensorLayouts
*
outputs_layout
,
const
Shape
&
dev_matrix_array
);
Status
InferTensorLayout
(
TensorLayouts
*
inputs_layout
,
TensorLayouts
*
outputs_layout
,
const
Shape
&
dev_matrix_array
);
ArithmeticCostPtr
biasaddcost_ptr_
;
};
};
}
// namespace parallel
}
// namespace parallel
}
// namespace mindspore
}
// namespace mindspore
...
...
mindspore/ccsrc/parallel/ops_info/dropout_do_mask_info.h
浏览文件 @
b413638f
...
@@ -33,15 +33,12 @@ class DropoutDoMaskInfo : public OperatorInfo {
...
@@ -33,15 +33,12 @@ class DropoutDoMaskInfo : public OperatorInfo {
public:
public:
DropoutDoMaskInfo
(
const
std
::
string
&
name
,
const
Shapes
&
inputs_shape
,
const
Shapes
&
outputs_shape
,
DropoutDoMaskInfo
(
const
std
::
string
&
name
,
const
Shapes
&
inputs_shape
,
const
Shapes
&
outputs_shape
,
const
PrimitiveAttrs
&
attrs
)
const
PrimitiveAttrs
&
attrs
)
:
OperatorInfo
(
name
,
inputs_shape
,
outputs_shape
,
attrs
)
{
:
OperatorInfo
(
name
,
inputs_shape
,
outputs_shape
,
attrs
,
std
::
make_shared
<
BatchParallelCost
>
())
{}
bpcost_ptr_
=
std
::
make_shared
<
BatchParallelCost
>
();
}
~
DropoutDoMaskInfo
()
override
=
default
;
~
DropoutDoMaskInfo
()
override
=
default
;
Status
Init
(
const
StrategyPtr
&
strategy
)
override
;
Status
Init
(
const
StrategyPtr
&
strategy
)
override
;
Status
GenerateStrategies
(
int32_t
stage_id
)
override
;
Status
GenerateStrategies
(
int32_t
stage_id
)
override
;
Status
SetCostUnderStrategy
(
const
StrategyPtr
&
strategy
)
override
;
Status
SetCostUnderStrategy
(
const
StrategyPtr
&
strategy
)
override
;
OperatorCostPtr
GetOperatorCost
()
const
override
{
return
bpcost_ptr_
;
}
Status
InitForCostModel
(
const
StrategyPtr
&
strategy
)
override
;
Status
InitForCostModel
(
const
StrategyPtr
&
strategy
)
override
;
std
::
shared_ptr
<
std
::
vector
<
std
::
vector
<
int32_t
>>>
GenerateBatchStrategies
()
override
;
std
::
shared_ptr
<
std
::
vector
<
std
::
vector
<
int32_t
>>>
GenerateBatchStrategies
()
override
;
...
@@ -53,9 +50,6 @@ class DropoutDoMaskInfo : public OperatorInfo {
...
@@ -53,9 +50,6 @@ class DropoutDoMaskInfo : public OperatorInfo {
Status
GetAttrs
()
override
{
return
SUCCESS
;
}
Status
GetAttrs
()
override
{
return
SUCCESS
;
}
Status
InferTensorInfo
()
override
;
Status
InferTensorInfo
()
override
;
Status
InferDevMatrixShape
()
override
;
Status
InferDevMatrixShape
()
override
;
private:
BatchParallelCostPtr
bpcost_ptr_
;
};
};
}
// namespace parallel
}
// namespace parallel
}
// namespace mindspore
}
// namespace mindspore
...
...
mindspore/ccsrc/parallel/ops_info/generator_info.h
浏览文件 @
b413638f
...
@@ -32,15 +32,12 @@ class GeneratorBase : public OperatorInfo {
...
@@ -32,15 +32,12 @@ class GeneratorBase : public OperatorInfo {
public:
public:
GeneratorBase
(
const
std
::
string
&
operator_name
,
const
Shapes
&
inputs_shape
,
const
Shapes
&
outputs_shape
,
GeneratorBase
(
const
std
::
string
&
operator_name
,
const
Shapes
&
inputs_shape
,
const
Shapes
&
outputs_shape
,
const
PrimitiveAttrs
&
attrs
)
const
PrimitiveAttrs
&
attrs
)
:
OperatorInfo
(
operator_name
,
inputs_shape
,
outputs_shape
,
attrs
)
{
:
OperatorInfo
(
operator_name
,
inputs_shape
,
outputs_shape
,
attrs
,
std
::
make_shared
<
GeneratorBaseCost
>
())
{}
generatorbasecost_ptr_
=
std
::
make_shared
<
GeneratorBaseCost
>
();
}
~
GeneratorBase
()
override
=
default
;
~
GeneratorBase
()
override
=
default
;
Status
Init
(
const
StrategyPtr
&
strategy
)
override
;
Status
Init
(
const
StrategyPtr
&
strategy
)
override
;
Status
SetCostUnderStrategy
(
const
StrategyPtr
&
strategy
)
override
;
Status
SetCostUnderStrategy
(
const
StrategyPtr
&
strategy
)
override
;
OperatorCostPtr
GetOperatorCost
()
const
override
{
return
generatorbasecost_ptr_
;
}
Status
InitForCostModel
(
const
StrategyPtr
&
strategy
)
override
;
Status
InitForCostModel
(
const
StrategyPtr
&
strategy
)
override
;
protected:
protected:
...
@@ -52,7 +49,6 @@ class GeneratorBase : public OperatorInfo {
...
@@ -52,7 +49,6 @@ class GeneratorBase : public OperatorInfo {
Status
InferMirrorOps
()
override
{
return
SUCCESS
;
}
Status
InferMirrorOps
()
override
{
return
SUCCESS
;
}
Status
InferForwardCommunication
()
override
{
return
SUCCESS
;
}
Status
InferForwardCommunication
()
override
{
return
SUCCESS
;
}
virtual
Status
InferReplaceOps
(
const
StrategyPtr
&
strategy
)
=
0
;
virtual
Status
InferReplaceOps
(
const
StrategyPtr
&
strategy
)
=
0
;
GeneratorBaseCostPtr
generatorbasecost_ptr_
;
};
};
class
DropoutGenMaskInfo
:
public
GeneratorBase
{
class
DropoutGenMaskInfo
:
public
GeneratorBase
{
...
...
mindspore/ccsrc/parallel/ops_info/get_next_info.h
浏览文件 @
b413638f
...
@@ -32,14 +32,11 @@ class GetNextInfo : public OperatorInfo {
...
@@ -32,14 +32,11 @@ class GetNextInfo : public OperatorInfo {
public:
public:
GetNextInfo
(
const
std
::
string
&
operator_name
,
const
Shapes
&
inputs_shape
,
const
Shapes
&
outputs_shape
,
GetNextInfo
(
const
std
::
string
&
operator_name
,
const
Shapes
&
inputs_shape
,
const
Shapes
&
outputs_shape
,
const
PrimitiveAttrs
&
attrs
)
const
PrimitiveAttrs
&
attrs
)
:
OperatorInfo
(
operator_name
,
inputs_shape
,
outputs_shape
,
attrs
)
{
:
OperatorInfo
(
operator_name
,
inputs_shape
,
outputs_shape
,
attrs
,
std
::
make_shared
<
GetNextCost
>
())
{}
getnextcost_ptr_
=
std
::
make_shared
<
GetNextCost
>
();
}
~
GetNextInfo
()
override
=
default
;
~
GetNextInfo
()
override
=
default
;
Status
Init
(
const
StrategyPtr
&
strategy
)
override
;
Status
Init
(
const
StrategyPtr
&
strategy
)
override
;
Status
SetCostUnderStrategy
(
const
StrategyPtr
&
strategy
)
override
;
Status
SetCostUnderStrategy
(
const
StrategyPtr
&
strategy
)
override
;
OperatorCostPtr
GetOperatorCost
()
const
override
{
return
getnextcost_ptr_
;
}
Status
InitForCostModel
(
const
StrategyPtr
&
strategy
)
override
;
Status
InitForCostModel
(
const
StrategyPtr
&
strategy
)
override
;
Status
GenerateStrategies
(
int32_t
stage_id
)
override
;
Status
GenerateStrategies
(
int32_t
stage_id
)
override
;
...
@@ -65,7 +62,6 @@ class GetNextInfo : public OperatorInfo {
...
@@ -65,7 +62,6 @@ class GetNextInfo : public OperatorInfo {
Shapes
shapes_
;
Shapes
shapes_
;
int32_t
output_num_
=
0
;
int32_t
output_num_
=
0
;
std
::
string
shared_name_
;
std
::
string
shared_name_
;
GetNextCostPtr
getnextcost_ptr_
;
};
};
}
// namespace parallel
}
// namespace parallel
}
// namespace mindspore
}
// namespace mindspore
...
...
mindspore/ccsrc/parallel/ops_info/l2_normalize_info.h
浏览文件 @
b413638f
...
@@ -33,12 +33,9 @@ class L2NormalizeInfo : public Activation {
...
@@ -33,12 +33,9 @@ class L2NormalizeInfo : public Activation {
public:
public:
L2NormalizeInfo
(
const
std
::
string
&
name
,
const
Shapes
&
inputs_shape
,
const
Shapes
&
outputs_shape
,
L2NormalizeInfo
(
const
std
::
string
&
name
,
const
Shapes
&
inputs_shape
,
const
Shapes
&
outputs_shape
,
const
PrimitiveAttrs
&
attrs
)
const
PrimitiveAttrs
&
attrs
)
:
Activation
(
name
,
inputs_shape
,
outputs_shape
,
attrs
)
{
:
Activation
(
name
,
inputs_shape
,
outputs_shape
,
attrs
)
{}
l2normalizecost_ptr_
=
std
::
make_shared
<
L2NormalizeCost
>
();
}
~
L2NormalizeInfo
()
override
=
default
;
~
L2NormalizeInfo
()
override
=
default
;
Status
GenerateStrategies
(
int32_t
stage_id
)
override
;
Status
GenerateStrategies
(
int32_t
stage_id
)
override
;
OperatorCostPtr
GetOperatorCost
()
const
override
{
return
l2normalizecost_ptr_
;
}
protected:
protected:
Status
GetAttrs
()
override
;
Status
GetAttrs
()
override
;
...
@@ -47,7 +44,6 @@ class L2NormalizeInfo : public Activation {
...
@@ -47,7 +44,6 @@ class L2NormalizeInfo : public Activation {
private:
private:
int32_t
axis_
=
0
;
// Default value = 0
int32_t
axis_
=
0
;
// Default value = 0
L2NormalizeCostPtr
l2normalizecost_ptr_
;
};
};
}
// namespace parallel
}
// namespace parallel
}
// namespace mindspore
}
// namespace mindspore
...
...
mindspore/ccsrc/parallel/ops_info/loss_info.h
浏览文件 @
b413638f
...
@@ -36,16 +36,13 @@ class SoftmaxCrossEntropyWithLogitsInfo : public OperatorInfo {
...
@@ -36,16 +36,13 @@ class SoftmaxCrossEntropyWithLogitsInfo : public OperatorInfo {
public:
public:
SoftmaxCrossEntropyWithLogitsInfo
(
const
std
::
string
&
name
,
const
Shapes
&
inputs_shape
,
const
Shapes
&
outputs_shape
,
SoftmaxCrossEntropyWithLogitsInfo
(
const
std
::
string
&
name
,
const
Shapes
&
inputs_shape
,
const
Shapes
&
outputs_shape
,
const
PrimitiveAttrs
&
attrs
)
const
PrimitiveAttrs
&
attrs
)
:
OperatorInfo
(
name
,
inputs_shape
,
outputs_shape
,
attrs
)
{
:
OperatorInfo
(
name
,
inputs_shape
,
outputs_shape
,
attrs
,
std
::
make_shared
<
SoftmaxCrossEntropyWithLogitsCost
>
())
{}
softmax_loss_cost_ptr_
=
std
::
make_shared
<
SoftmaxCrossEntropyWithLogitsCost
>
();
}
~
SoftmaxCrossEntropyWithLogitsInfo
()
override
=
default
;
~
SoftmaxCrossEntropyWithLogitsInfo
()
override
=
default
;
Status
Init
(
const
StrategyPtr
&
strategy
)
override
;
Status
Init
(
const
StrategyPtr
&
strategy
)
override
;
Status
InitForCostModel
(
const
StrategyPtr
&
strategy
)
override
;
Status
InitForCostModel
(
const
StrategyPtr
&
strategy
)
override
;
Status
GenerateStrategies
(
int32_t
stage_id
)
override
;
Status
GenerateStrategies
(
int32_t
stage_id
)
override
;
Status
SetCostUnderStrategy
(
const
StrategyPtr
&
strategy
)
override
;
Status
SetCostUnderStrategy
(
const
StrategyPtr
&
strategy
)
override
;
OperatorCostPtr
GetOperatorCost
()
const
override
{
return
softmax_loss_cost_ptr_
;
}
void
ReComputeBatchSplitFlagList
()
override
;
void
ReComputeBatchSplitFlagList
()
override
;
protected:
protected:
...
@@ -59,7 +56,6 @@ class SoftmaxCrossEntropyWithLogitsInfo : public OperatorInfo {
...
@@ -59,7 +56,6 @@ class SoftmaxCrossEntropyWithLogitsInfo : public OperatorInfo {
// There are two outputs for SoftmaxCrossEntropyWithLogits, and outputs[1] is used for grad and overload
// There are two outputs for SoftmaxCrossEntropyWithLogits, and outputs[1] is used for grad and overload
// the InferAsLossDivisor.
// the InferAsLossDivisor.
Status
InferAsLossDivisor
()
override
;
Status
InferAsLossDivisor
()
override
;
SoftmaxCrossEntropyWithLogitsCostPtr
softmax_loss_cost_ptr_
;
private:
private:
int32_t
axis_
=
-
1
;
// default -1
int32_t
axis_
=
-
1
;
// default -1
...
...
mindspore/ccsrc/parallel/ops_info/matmul_info.cc
浏览文件 @
b413638f
...
@@ -593,11 +593,11 @@ Status MatMulBase::SetCostUnderStrategy(const mindspore::parallel::StrategyPtr&
...
@@ -593,11 +593,11 @@ Status MatMulBase::SetCostUnderStrategy(const mindspore::parallel::StrategyPtr&
// Here, we use the origin outputs_, because we only use the slice size of the output tensor.
// Here, we use the origin outputs_, because we only use the slice size of the output tensor.
// It does not matter whether the output tensor is transposed or not.
// It does not matter whether the output tensor is transposed or not.
double
computation_cost
=
double
computation_cost
=
matmulcost_ptr
->
GetForwardComputationCost
(
relica_inputs_tensor_vector
,
outputs_tensor_info_
,
stage_id
);
cost
()
->
GetForwardComputationCost
(
relica_inputs_tensor_vector
,
outputs_tensor_info_
,
stage_id
);
double
communication_cost
=
matmulcost_ptr
->
GetCommCost
(
relica_inputs_tensor_vector
,
outputs_tensor_info_
,
stage_id
);
double
communication_cost
=
cost
()
->
GetCommCost
(
relica_inputs_tensor_vector
,
outputs_tensor_info_
,
stage_id
);
std
::
shared_ptr
<
Cost
>
result
=
std
::
make_shared
<
Cost
>
(
computation_cost
,
communication_cost
);
std
::
shared_ptr
<
Cost
>
result
=
std
::
make_shared
<
Cost
>
(
computation_cost
,
communication_cost
);
result
->
communication_without_parameter_
=
result
->
communication_without_parameter_
=
matmulcost_ptr
->
GetForwardCommCost
(
relica_inputs_tensor_vector
,
outputs_tensor_info_
,
stage_id
);
cost
()
->
GetForwardCommCost
(
relica_inputs_tensor_vector
,
outputs_tensor_info_
,
stage_id
);
result
->
communication_with_partial_para_
=
result
->
communication_with_partial_para_
=
result
->
communication_without_parameter_
+
result
->
communication_without_parameter_
+
COST_MODEL_GAMMA
*
(
communication_cost
-
result
->
communication_without_parameter_
);
COST_MODEL_GAMMA
*
(
communication_cost
-
result
->
communication_without_parameter_
);
...
...
mindspore/ccsrc/parallel/ops_info/matmul_info.h
浏览文件 @
b413638f
...
@@ -34,9 +34,7 @@ class MatMulBase : public OperatorInfo {
...
@@ -34,9 +34,7 @@ class MatMulBase : public OperatorInfo {
public:
public:
MatMulBase
(
const
std
::
string
&
name
,
const
Shapes
&
inputs_shape
,
const
Shapes
&
outputs_shape
,
MatMulBase
(
const
std
::
string
&
name
,
const
Shapes
&
inputs_shape
,
const
Shapes
&
outputs_shape
,
const
PrimitiveAttrs
&
attrs
)
const
PrimitiveAttrs
&
attrs
)
:
OperatorInfo
(
name
,
inputs_shape
,
outputs_shape
,
attrs
)
{
:
OperatorInfo
(
name
,
inputs_shape
,
outputs_shape
,
attrs
,
std
::
make_shared
<
MatMulCost
>
())
{}
matmulcost_ptr
=
std
::
make_shared
<
MatMulCost
>
();
}
~
MatMulBase
()
override
=
default
;
~
MatMulBase
()
override
=
default
;
Status
Init
(
const
StrategyPtr
&
strategy
)
override
;
Status
Init
(
const
StrategyPtr
&
strategy
)
override
;
...
@@ -48,7 +46,6 @@ class MatMulBase : public OperatorInfo {
...
@@ -48,7 +46,6 @@ class MatMulBase : public OperatorInfo {
Status
PrepareStrategy
(
int32_t
stage_id
,
size_t
dev_num
,
Dimensions
combined_partitions
,
size_t
input0_shape_size
,
Status
PrepareStrategy
(
int32_t
stage_id
,
size_t
dev_num
,
Dimensions
combined_partitions
,
size_t
input0_shape_size
,
size_t
input1_shape_size
,
StrategyPtr
*
sp
);
size_t
input1_shape_size
,
StrategyPtr
*
sp
);
OperatorCostPtr
GetOperatorCost
()
const
override
{
return
matmulcost_ptr
;
}
Status
SwapLastTwoElements
(
Shape
*
shape
);
Status
SwapLastTwoElements
(
Shape
*
shape
);
protected:
protected:
...
@@ -66,8 +63,6 @@ class MatMulBase : public OperatorInfo {
...
@@ -66,8 +63,6 @@ class MatMulBase : public OperatorInfo {
bool
transpose_b_
=
false
;
bool
transpose_b_
=
false
;
size_t
mat_a_dimension_
=
0
;
size_t
mat_a_dimension_
=
0
;
size_t
mat_b_dimension_
=
0
;
size_t
mat_b_dimension_
=
0
;
MatMulCostPtr
matmulcost_ptr
;
};
};
class
MatMul
:
public
MatMulBase
{
class
MatMul
:
public
MatMulBase
{
...
...
mindspore/ccsrc/parallel/ops_info/onehot_info.h
浏览文件 @
b413638f
...
@@ -33,16 +33,13 @@ class OneHotInfo : public OperatorInfo {
...
@@ -33,16 +33,13 @@ class OneHotInfo : public OperatorInfo {
public:
public:
OneHotInfo
(
const
std
::
string
&
name
,
const
Shapes
&
inputs_shape
,
const
Shapes
&
outputs_shape
,
OneHotInfo
(
const
std
::
string
&
name
,
const
Shapes
&
inputs_shape
,
const
Shapes
&
outputs_shape
,
const
PrimitiveAttrs
&
attrs
)
const
PrimitiveAttrs
&
attrs
)
:
OperatorInfo
(
name
,
inputs_shape
,
outputs_shape
,
attrs
)
{
:
OperatorInfo
(
name
,
inputs_shape
,
outputs_shape
,
attrs
,
std
::
make_shared
<
OneHotCost
>
())
{}
onehot_cost_ptr_
=
std
::
make_shared
<
OneHotCost
>
();
}
~
OneHotInfo
()
override
=
default
;
~
OneHotInfo
()
override
=
default
;
Status
Init
(
const
StrategyPtr
&
strategy
)
override
;
Status
Init
(
const
StrategyPtr
&
strategy
)
override
;
Status
InitForCostModel
(
const
StrategyPtr
&
strategy
)
override
;
Status
InitForCostModel
(
const
StrategyPtr
&
strategy
)
override
;
Status
GenerateStrategies
(
int32_t
stage_id
)
override
;
Status
GenerateStrategies
(
int32_t
stage_id
)
override
;
Status
SetCostUnderStrategy
(
const
StrategyPtr
&
strategy
)
override
;
Status
SetCostUnderStrategy
(
const
StrategyPtr
&
strategy
)
override
;
OperatorCostPtr
GetOperatorCost
()
const
override
{
return
onehot_cost_ptr_
;
}
ReplaceGraphPtr
replace_graph
(
const
CNodePtr
&
cnode
)
override
;
ReplaceGraphPtr
replace_graph
(
const
CNodePtr
&
cnode
)
override
;
std
::
shared_ptr
<
std
::
vector
<
std
::
vector
<
int32_t
>>>
GenerateBatchStrategies
()
override
;
std
::
shared_ptr
<
std
::
vector
<
std
::
vector
<
int32_t
>>>
GenerateBatchStrategies
()
override
;
...
@@ -60,7 +57,6 @@ class OneHotInfo : public OperatorInfo {
...
@@ -60,7 +57,6 @@ class OneHotInfo : public OperatorInfo {
Status
ComputeReplaceGraph
(
const
CNodePtr
&
cnode
);
Status
ComputeReplaceGraph
(
const
CNodePtr
&
cnode
);
int
axis_
=
-
1
;
int
axis_
=
-
1
;
OneHotCostPtr
onehot_cost_ptr_
;
int32_t
rank_
=
0
;
int32_t
rank_
=
0
;
int32_t
total_class_number_
=
1
;
int32_t
total_class_number_
=
1
;
int32_t
classes_each_device_
=
1
;
int32_t
classes_each_device_
=
1
;
...
...
mindspore/ccsrc/parallel/ops_info/operator_info.cc
浏览文件 @
b413638f
...
@@ -1034,12 +1034,11 @@ Status OperatorInfo::SetCostUnderStrategyBase(const StrategyPtr& strategy) {
...
@@ -1034,12 +1034,11 @@ Status OperatorInfo::SetCostUnderStrategyBase(const StrategyPtr& strategy) {
return
FAILED
;
return
FAILED
;
}
}
int32_t
stage_id
=
strategy
->
GetInputStage
();
int32_t
stage_id
=
strategy
->
GetInputStage
();
double
computation_cost
=
double
computation_cost
=
cost
()
->
GetForwardComputationCost
(
inputs_tensor_info_
,
outputs_tensor_info_
,
stage_id
);
GetOperatorCost
()
->
GetForwardComputationCost
(
inputs_tensor_info_
,
outputs_tensor_info_
,
stage_id
);
double
communication_cost
=
cost
()
->
GetCommCost
(
inputs_tensor_info_
,
outputs_tensor_info_
,
stage_id
);
double
communication_cost
=
GetOperatorCost
()
->
GetCommCost
(
inputs_tensor_info_
,
outputs_tensor_info_
,
stage_id
);
std
::
shared_ptr
<
Cost
>
result
=
std
::
make_shared
<
Cost
>
(
computation_cost
,
communication_cost
);
std
::
shared_ptr
<
Cost
>
result
=
std
::
make_shared
<
Cost
>
(
computation_cost
,
communication_cost
);
result
->
communication_without_parameter_
=
result
->
communication_without_parameter_
=
GetOperatorC
ost
()
->
GetForwardCommCost
(
inputs_tensor_info_
,
outputs_tensor_info_
,
stage_id
);
c
ost
()
->
GetForwardCommCost
(
inputs_tensor_info_
,
outputs_tensor_info_
,
stage_id
);
result
->
communication_with_partial_para_
=
result
->
communication_with_partial_para_
=
result
->
communication_without_parameter_
+
result
->
communication_without_parameter_
+
COST_MODEL_GAMMA
*
(
communication_cost
-
result
->
communication_without_parameter_
);
COST_MODEL_GAMMA
*
(
communication_cost
-
result
->
communication_without_parameter_
);
...
@@ -1096,7 +1095,7 @@ Status OperatorInfo::set_is_parameter(const std::vector<bool>& is_parameter) {
...
@@ -1096,7 +1095,7 @@ Status OperatorInfo::set_is_parameter(const std::vector<bool>& is_parameter) {
return
FAILED
;
return
FAILED
;
}
}
is_parameter_
=
is_parameter
;
is_parameter_
=
is_parameter
;
GetOperatorC
ost
()
->
set_is_parameter
(
is_parameter
);
c
ost
()
->
set_is_parameter
(
is_parameter
);
return
SUCCESS
;
return
SUCCESS
;
}
}
...
@@ -1193,7 +1192,7 @@ Status OperatorInfo::SetInputAndOutputTypeLength(const std::vector<size_t>& inpu
...
@@ -1193,7 +1192,7 @@ Status OperatorInfo::SetInputAndOutputTypeLength(const std::vector<size_t>& inpu
}
}
inputs_type_lengths_
=
input_lengths
;
inputs_type_lengths_
=
input_lengths
;
outputs_type_lengths_
=
output_lengths
;
outputs_type_lengths_
=
output_lengths
;
GetOperatorC
ost
()
->
SetInputAndOutputTypeLength
(
input_lengths
,
output_lengths
);
c
ost
()
->
SetInputAndOutputTypeLength
(
input_lengths
,
output_lengths
);
return
SUCCESS
;
return
SUCCESS
;
}
}
...
@@ -1211,7 +1210,7 @@ void OperatorInfo::BreakingTiesForPerferringDataParallel(const StrategyPtr& stra
...
@@ -1211,7 +1210,7 @@ void OperatorInfo::BreakingTiesForPerferringDataParallel(const StrategyPtr& stra
}
}
double
OperatorInfo
::
GetForwardMemoryCostFromCNode
()
{
double
OperatorInfo
::
GetForwardMemoryCostFromCNode
()
{
return
GetOperatorC
ost
()
->
GetForwardComputationCost
(
inputs_tensor_info_
,
outputs_tensor_info_
,
0
);
return
c
ost
()
->
GetForwardComputationCost
(
inputs_tensor_info_
,
outputs_tensor_info_
,
0
);
}
}
}
// namespace parallel
}
// namespace parallel
...
...
mindspore/ccsrc/parallel/ops_info/operator_info.h
浏览文件 @
b413638f
...
@@ -53,12 +53,13 @@ class Edge;
...
@@ -53,12 +53,13 @@ class Edge;
class
OperatorInfo
{
class
OperatorInfo
{
public:
public:
OperatorInfo
(
std
::
string
name
,
Shapes
inputs_shape
,
Shapes
outputs_shape
,
PrimitiveAttrs
attrs
)
OperatorInfo
(
std
::
string
name
,
Shapes
inputs_shape
,
Shapes
outputs_shape
,
PrimitiveAttrs
attrs
,
OperatorCostPtr
cost
)
:
name_
(
std
::
move
(
name
)),
:
name_
(
std
::
move
(
name
)),
inputs_shape_
(
std
::
move
(
inputs_shape
)),
inputs_shape_
(
std
::
move
(
inputs_shape
)),
outputs_shape_
(
std
::
move
(
outputs_shape
)),
outputs_shape_
(
std
::
move
(
outputs_shape
)),
attrs_
(
std
::
move
(
attrs
)),
attrs_
(
std
::
move
(
attrs
)),
is_alive_
(
true
)
{
is_alive_
(
true
),
cost_
(
cost
)
{
std
::
vector
<
bool
>
not_parameteter
(
inputs_shape_
.
size
(),
false
);
std
::
vector
<
bool
>
not_parameteter
(
inputs_shape_
.
size
(),
false
);
is_parameter_
=
not_parameteter
;
is_parameter_
=
not_parameteter
;
refkey_parameter_name_
=
""
;
refkey_parameter_name_
=
""
;
...
@@ -75,7 +76,8 @@ class OperatorInfo {
...
@@ -75,7 +76,8 @@ class OperatorInfo {
// Given the stage_id (which indicates the number of devices),
// Given the stage_id (which indicates the number of devices),
// generate all strategies for this operator
// generate all strategies for this operator
virtual
Status
GenerateStrategies
(
int32_t
stage_id
)
=
0
;
virtual
Status
GenerateStrategies
(
int32_t
stage_id
)
=
0
;
virtual
OperatorCostPtr
GetOperatorCost
()
const
=
0
;
const
OperatorCostPtr
&
cost
()
const
{
return
cost_
;
}
void
set_cost
(
const
OperatorCostPtr
&
cost
)
{
cost_
=
cost
;
}
virtual
Status
SetCostUnderStrategy
(
const
StrategyPtr
&
strategy
)
=
0
;
virtual
Status
SetCostUnderStrategy
(
const
StrategyPtr
&
strategy
)
=
0
;
virtual
std
::
shared_ptr
<
std
::
vector
<
std
::
vector
<
int32_t
>>>
GenerateBatchStrategies
();
virtual
std
::
shared_ptr
<
std
::
vector
<
std
::
vector
<
int32_t
>>>
GenerateBatchStrategies
();
...
@@ -115,7 +117,7 @@ class OperatorInfo {
...
@@ -115,7 +117,7 @@ class OperatorInfo {
void
ReplaceSuccEdge
(
const
std
::
shared_ptr
<
OperatorInfo
>&
op
,
const
std
::
shared_ptr
<
Edge
>&
new_edge
);
void
ReplaceSuccEdge
(
const
std
::
shared_ptr
<
OperatorInfo
>&
op
,
const
std
::
shared_ptr
<
Edge
>&
new_edge
);
void
ReplacePreEdges
(
const
std
::
shared_ptr
<
OperatorInfo
>&
op
,
const
std
::
shared_ptr
<
Edge
>&
new_edge
);
void
ReplacePreEdges
(
const
std
::
shared_ptr
<
OperatorInfo
>&
op
,
const
std
::
shared_ptr
<
Edge
>&
new_edge
);
void
ReplaceSuccEdges
(
const
std
::
shared_ptr
<
OperatorInfo
>&
op
,
const
std
::
shared_ptr
<
Edge
>&
new_edge
);
void
ReplaceSuccEdges
(
const
std
::
shared_ptr
<
OperatorInfo
>&
op
,
const
std
::
shared_ptr
<
Edge
>&
new_edge
);
std
::
vector
<
size_t
>
GetOutputTypeLengths
()
const
{
return
GetOperatorC
ost
()
->
outputs_type_lengths
();
}
std
::
vector
<
size_t
>
GetOutputTypeLengths
()
const
{
return
c
ost
()
->
outputs_type_lengths
();
}
void
SetSelectedStrategyAndCost
(
const
StrategyPtr
&
s_strategy
,
const
CostPtr
&
cost
)
{
void
SetSelectedStrategyAndCost
(
const
StrategyPtr
&
s_strategy
,
const
CostPtr
&
cost
)
{
selected_strategy_
=
s_strategy
;
selected_strategy_
=
s_strategy
;
selected_cost_
=
cost
;
selected_cost_
=
cost
;
...
@@ -221,6 +223,9 @@ class OperatorInfo {
...
@@ -221,6 +223,9 @@ class OperatorInfo {
std
::
string
refkey_parameter_name_
;
std
::
string
refkey_parameter_name_
;
CNodePtr
cnode_
;
CNodePtr
cnode_
;
int32_t
used_devices_
=
-
1
;
int32_t
used_devices_
=
-
1
;
private:
OperatorCostPtr
cost_
;
};
};
Shape
GetSliceShape
(
const
Shape
&
tensor_shape
,
const
Dimensions
&
strategy
);
Shape
GetSliceShape
(
const
Shape
&
tensor_shape
,
const
Dimensions
&
strategy
);
...
...
mindspore/ccsrc/parallel/ops_info/prelu_info.h
浏览文件 @
b413638f
...
@@ -35,15 +35,12 @@ class PReLUInfo : public OperatorInfo {
...
@@ -35,15 +35,12 @@ class PReLUInfo : public OperatorInfo {
public:
public:
PReLUInfo
(
const
std
::
string
&
name
,
const
Shapes
&
inputs_shape
,
const
Shapes
&
outputs_shape
,
PReLUInfo
(
const
std
::
string
&
name
,
const
Shapes
&
inputs_shape
,
const
Shapes
&
outputs_shape
,
const
PrimitiveAttrs
&
attrs
)
const
PrimitiveAttrs
&
attrs
)
:
OperatorInfo
(
name
,
inputs_shape
,
outputs_shape
,
attrs
)
{
:
OperatorInfo
(
name
,
inputs_shape
,
outputs_shape
,
attrs
,
std
::
make_shared
<
PReLUCost
>
())
{}
prelucost_ptr
=
std
::
make_shared
<
PReLUCost
>
();
}
~
PReLUInfo
()
override
=
default
;
~
PReLUInfo
()
override
=
default
;
Status
Init
(
const
StrategyPtr
&
strategy
)
override
;
Status
Init
(
const
StrategyPtr
&
strategy
)
override
;
Status
InitForCostModel
(
const
StrategyPtr
&
strategy
)
override
;
Status
InitForCostModel
(
const
StrategyPtr
&
strategy
)
override
;
Status
GenerateStrategies
(
int32_t
stage_id
)
override
;
Status
GenerateStrategies
(
int32_t
stage_id
)
override
;
OperatorCostPtr
GetOperatorCost
()
const
override
{
return
prelucost_ptr
;
}
Status
SetCostUnderStrategy
(
const
StrategyPtr
&
strategy
)
override
;
Status
SetCostUnderStrategy
(
const
StrategyPtr
&
strategy
)
override
;
protected:
protected:
...
@@ -59,7 +56,6 @@ class PReLUInfo : public OperatorInfo {
...
@@ -59,7 +56,6 @@ class PReLUInfo : public OperatorInfo {
private:
private:
Dimensions
input_strategy_
;
Dimensions
input_strategy_
;
PReLUCostPtr
prelucost_ptr
;
};
};
}
// namespace parallel
}
// namespace parallel
}
// namespace mindspore
}
// namespace mindspore
...
...
mindspore/ccsrc/parallel/ops_info/reduce_method_info.cc
浏览文件 @
b413638f
...
@@ -109,8 +109,12 @@ Status ReduceMethod::GetAttrs() {
...
@@ -109,8 +109,12 @@ Status ReduceMethod::GetAttrs() {
}
}
cross_batch_
=
cross_batch_iter
->
second
->
cast
<
BoolImmPtr
>
()
->
value
();
cross_batch_
=
cross_batch_iter
->
second
->
cast
<
BoolImmPtr
>
()
->
value
();
}
}
reducemethodcost_ptr_
->
set_cross_batch
(
cross_batch_
);
auto
reducemethodcost
=
std
::
dynamic_pointer_cast
<
ReduceMethodCost
>
(
cost
());
if
(
reducemethodcost
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"Cost cast to ReduceMethodCostPtr failed!"
;
return
FAILED
;
}
reducemethodcost
->
set_cross_batch
(
cross_batch_
);
return
SUCCESS
;
return
SUCCESS
;
}
}
...
...
mindspore/ccsrc/parallel/ops_info/reduce_method_info.h
浏览文件 @
b413638f
...
@@ -34,9 +34,7 @@ class ReduceMethod : public OperatorInfo {
...
@@ -34,9 +34,7 @@ class ReduceMethod : public OperatorInfo {
public:
public:
ReduceMethod
(
const
std
::
string
&
name
,
const
Shapes
&
inputs_shape
,
const
Shapes
&
outputs_shape
,
ReduceMethod
(
const
std
::
string
&
name
,
const
Shapes
&
inputs_shape
,
const
Shapes
&
outputs_shape
,
const
PrimitiveAttrs
&
attrs
)
const
PrimitiveAttrs
&
attrs
)
:
OperatorInfo
(
name
,
inputs_shape
,
outputs_shape
,
attrs
)
{
:
OperatorInfo
(
name
,
inputs_shape
,
outputs_shape
,
attrs
,
std
::
make_shared
<
ReduceMethodCost
>
())
{}
reducemethodcost_ptr_
=
std
::
make_shared
<
ReduceMethodCost
>
();
}
~
ReduceMethod
()
override
=
default
;
~
ReduceMethod
()
override
=
default
;
Status
Init
(
const
StrategyPtr
&
strategy
)
override
;
Status
Init
(
const
StrategyPtr
&
strategy
)
override
;
...
@@ -44,13 +42,11 @@ class ReduceMethod : public OperatorInfo {
...
@@ -44,13 +42,11 @@ class ReduceMethod : public OperatorInfo {
Status
GenerateStrategies
(
int32_t
stage_id
)
override
;
Status
GenerateStrategies
(
int32_t
stage_id
)
override
;
Status
SetCostUnderStrategy
(
const
StrategyPtr
&
strategy
)
override
;
Status
SetCostUnderStrategy
(
const
StrategyPtr
&
strategy
)
override
;
OperatorCostPtr
GetOperatorCost
()
const
override
{
return
reducemethodcost_ptr_
;
}
protected:
protected:
std
::
string
reduce_method_
;
std
::
string
reduce_method_
;
bool
keepdims_
=
false
;
bool
keepdims_
=
false
;
bool
cross_batch_
=
false
;
bool
cross_batch_
=
false
;
ReduceMethodCostPtr
reducemethodcost_ptr_
;
Status
CheckStrategy
(
const
StrategyPtr
&
strategy
)
override
;
Status
CheckStrategy
(
const
StrategyPtr
&
strategy
)
override
;
Status
GetAttrs
()
override
;
Status
GetAttrs
()
override
;
Dimensions
InferOutputStrategy
();
Dimensions
InferOutputStrategy
();
...
@@ -110,7 +106,7 @@ class ReduceMeanInfo : public ReduceMethod {
...
@@ -110,7 +106,7 @@ class ReduceMeanInfo : public ReduceMethod {
ReduceMeanInfo
(
const
std
::
string
&
name
,
const
Shapes
&
inputs_shape
,
const
Shapes
&
outputs_shape
,
ReduceMeanInfo
(
const
std
::
string
&
name
,
const
Shapes
&
inputs_shape
,
const
Shapes
&
outputs_shape
,
const
PrimitiveAttrs
&
attrs
)
const
PrimitiveAttrs
&
attrs
)
:
ReduceMethod
(
name
,
inputs_shape
,
outputs_shape
,
attrs
)
{
:
ReduceMethod
(
name
,
inputs_shape
,
outputs_shape
,
attrs
)
{
reducemethodcost_ptr_
=
std
::
make_shared
<
ReduceMeanCost
>
(
);
set_cost
(
std
::
make_shared
<
ReduceMeanCost
>
()
);
}
}
~
ReduceMeanInfo
()
override
=
default
;
~
ReduceMeanInfo
()
override
=
default
;
...
...
mindspore/ccsrc/parallel/ops_info/reshape_info.h
浏览文件 @
b413638f
...
@@ -36,12 +36,10 @@ class ReshapeInfo : public OperatorInfo {
...
@@ -36,12 +36,10 @@ class ReshapeInfo : public OperatorInfo {
public:
public:
ReshapeInfo
(
const
std
::
string
&
name
,
const
Shapes
&
inputs_shape
,
const
Shapes
&
outputs_shape
,
ReshapeInfo
(
const
std
::
string
&
name
,
const
Shapes
&
inputs_shape
,
const
Shapes
&
outputs_shape
,
const
PrimitiveAttrs
&
attrs
)
const
PrimitiveAttrs
&
attrs
)
:
OperatorInfo
(
name
,
inputs_shape
,
outputs_shape
,
attrs
),
:
OperatorInfo
(
name
,
inputs_shape
,
outputs_shape
,
attrs
,
std
::
make_shared
<
ReshapeCost
>
()
),
dev_num_
(
0
),
dev_num_
(
0
),
input_layout_set_flag_
(
false
),
input_layout_set_flag_
(
false
),
output_layout_set_flag_
(
false
)
{
output_layout_set_flag_
(
false
)
{}
reshape_cost_ptr_
=
std
::
make_shared
<
ReshapeCost
>
();
}
~
ReshapeInfo
()
override
=
default
;
~
ReshapeInfo
()
override
=
default
;
Status
Init
(
const
StrategyPtr
&
strategy
)
override
;
Status
Init
(
const
StrategyPtr
&
strategy
)
override
;
void
SetInputLayout
(
const
TensorLayout
&
input_layout
)
{
void
SetInputLayout
(
const
TensorLayout
&
input_layout
)
{
...
@@ -55,7 +53,6 @@ class ReshapeInfo : public OperatorInfo {
...
@@ -55,7 +53,6 @@ class ReshapeInfo : public OperatorInfo {
Status
InitForCostModel
(
const
StrategyPtr
&
strategy
)
override
;
Status
InitForCostModel
(
const
StrategyPtr
&
strategy
)
override
;
Status
GenerateStrategies
(
int32_t
stage_id
)
override
;
Status
GenerateStrategies
(
int32_t
stage_id
)
override
;
Status
SetCostUnderStrategy
(
const
StrategyPtr
&
strategy
)
override
;
Status
SetCostUnderStrategy
(
const
StrategyPtr
&
strategy
)
override
;
OperatorCostPtr
GetOperatorCost
()
const
override
{
return
reshape_cost_ptr_
;
}
protected:
protected:
Status
CheckStrategy
(
const
StrategyPtr
&
strategy
)
override
;
Status
CheckStrategy
(
const
StrategyPtr
&
strategy
)
override
;
...
@@ -67,7 +64,6 @@ class ReshapeInfo : public OperatorInfo {
...
@@ -67,7 +64,6 @@ class ReshapeInfo : public OperatorInfo {
Status
InferTensorLayout
(
TensorLayouts
*
inputs_layout
,
TensorLayouts
*
outputs_layout
);
Status
InferTensorLayout
(
TensorLayouts
*
inputs_layout
,
TensorLayouts
*
outputs_layout
);
Status
GetAttrs
()
override
;
Status
GetAttrs
()
override
;
Strategys
GetOutputsStrategy
();
Strategys
GetOutputsStrategy
();
ReshapeCostPtr
reshape_cost_ptr_
;
private:
private:
Status
GetParameterInput
();
Status
GetParameterInput
();
...
...
mindspore/ccsrc/parallel/ops_info/tmp_identity_info.h
浏览文件 @
b413638f
...
@@ -34,9 +34,7 @@ class TmpIdentityInfo : public OperatorInfo {
...
@@ -34,9 +34,7 @@ class TmpIdentityInfo : public OperatorInfo {
public:
public:
TmpIdentityInfo
(
const
Shapes
&
inputs_shape
,
const
Shapes
&
outputs_shape
,
const
PrimitiveAttrs
&
attrs
,
TmpIdentityInfo
(
const
Shapes
&
inputs_shape
,
const
Shapes
&
outputs_shape
,
const
PrimitiveAttrs
&
attrs
,
const
std
::
string
&
name
=
IDENTITY_INFO
)
const
std
::
string
&
name
=
IDENTITY_INFO
)
:
OperatorInfo
(
name
,
inputs_shape
,
outputs_shape
,
attrs
)
{
:
OperatorInfo
(
name
,
inputs_shape
,
outputs_shape
,
attrs
,
std
::
make_shared
<
TmpIdentityCost
>
())
{}
id_cost_ptr_
=
std
::
make_shared
<
TmpIdentityCost
>
();
}
~
TmpIdentityInfo
()
override
=
default
;
~
TmpIdentityInfo
()
override
=
default
;
Status
Init
(
const
StrategyPtr
&
strategy
)
override
;
Status
Init
(
const
StrategyPtr
&
strategy
)
override
;
...
@@ -44,7 +42,6 @@ class TmpIdentityInfo : public OperatorInfo {
...
@@ -44,7 +42,6 @@ class TmpIdentityInfo : public OperatorInfo {
Status
GenerateStrategies
(
int32_t
stage_id
)
override
;
Status
GenerateStrategies
(
int32_t
stage_id
)
override
;
Status
SetCostUnderStrategy
(
const
StrategyPtr
&
strategy
)
override
;
Status
SetCostUnderStrategy
(
const
StrategyPtr
&
strategy
)
override
;
OperatorCostPtr
GetOperatorCost
()
const
override
{
return
id_cost_ptr_
;
}
protected:
protected:
Status
CheckStrategy
(
const
StrategyPtr
&
strategy
)
override
;
Status
CheckStrategy
(
const
StrategyPtr
&
strategy
)
override
;
...
@@ -54,9 +51,6 @@ class TmpIdentityInfo : public OperatorInfo {
...
@@ -54,9 +51,6 @@ class TmpIdentityInfo : public OperatorInfo {
Status
InferTensorInfo
()
override
;
Status
InferTensorInfo
()
override
;
Status
InferDevMatrixShape
()
override
;
Status
InferDevMatrixShape
()
override
;
Status
InferTensorMap
()
override
;
Status
InferTensorMap
()
override
;
private:
TmpIdentityCostPtr
id_cost_ptr_
;
};
};
}
// namespace parallel
}
// namespace parallel
}
// namespace mindspore
}
// namespace mindspore
...
...
mindspore/ccsrc/parallel/ops_info/transpose_info.h
浏览文件 @
b413638f
...
@@ -35,15 +35,12 @@ class TransposeInfo : public OperatorInfo {
...
@@ -35,15 +35,12 @@ class TransposeInfo : public OperatorInfo {
public:
public:
TransposeInfo
(
const
std
::
string
&
name
,
const
Shapes
&
inputs_shape
,
const
Shapes
&
outputs_shape
,
TransposeInfo
(
const
std
::
string
&
name
,
const
Shapes
&
inputs_shape
,
const
Shapes
&
outputs_shape
,
const
PrimitiveAttrs
&
attrs
)
const
PrimitiveAttrs
&
attrs
)
:
OperatorInfo
(
name
,
inputs_shape
,
outputs_shape
,
attrs
)
{
:
OperatorInfo
(
name
,
inputs_shape
,
outputs_shape
,
attrs
,
std
::
make_shared
<
TransposeCost
>
())
{}
transpose_cost_ptr_
=
std
::
make_shared
<
ActivationCost
>
();
}
~
TransposeInfo
()
override
=
default
;
~
TransposeInfo
()
override
=
default
;
Status
Init
(
const
StrategyPtr
&
strategy
)
override
;
Status
Init
(
const
StrategyPtr
&
strategy
)
override
;
Status
InitForCostModel
(
const
StrategyPtr
&
strategy
)
override
;
Status
InitForCostModel
(
const
StrategyPtr
&
strategy
)
override
;
Status
GenerateStrategies
(
int32_t
stage_id
)
override
;
Status
GenerateStrategies
(
int32_t
stage_id
)
override
;
Status
SetCostUnderStrategy
(
const
StrategyPtr
&
strategy
)
override
;
Status
SetCostUnderStrategy
(
const
StrategyPtr
&
strategy
)
override
;
OperatorCostPtr
GetOperatorCost
()
const
override
{
return
transpose_cost_ptr_
;
}
protected:
protected:
Status
CheckStrategy
(
const
StrategyPtr
&
strategy
)
override
;
Status
CheckStrategy
(
const
StrategyPtr
&
strategy
)
override
;
...
@@ -60,7 +57,6 @@ class TransposeInfo : public OperatorInfo {
...
@@ -60,7 +57,6 @@ class TransposeInfo : public OperatorInfo {
Status
ComputeAxis
();
Status
ComputeAxis
();
std
::
vector
<
int32_t
>
axis_v_
;
std
::
vector
<
int32_t
>
axis_v_
;
Dimensions
input_strategy_
;
Dimensions
input_strategy_
;
ActivationCostPtr
transpose_cost_ptr_
;
};
};
}
// namespace parallel
}
// namespace parallel
}
// namespace mindspore
}
// namespace mindspore
...
...
mindspore/ccsrc/parallel/ops_info/virtual_dataset_info.h
浏览文件 @
b413638f
...
@@ -32,16 +32,13 @@ class VirtualDatasetInfo : public OperatorInfo {
...
@@ -32,16 +32,13 @@ class VirtualDatasetInfo : public OperatorInfo {
public:
public:
VirtualDatasetInfo
(
const
std
::
string
&
name
,
const
Shapes
&
inputs_shape
,
const
Shapes
&
outputs_shape
,
VirtualDatasetInfo
(
const
std
::
string
&
name
,
const
Shapes
&
inputs_shape
,
const
Shapes
&
outputs_shape
,
const
PrimitiveAttrs
&
attrs
)
const
PrimitiveAttrs
&
attrs
)
:
OperatorInfo
(
name
,
inputs_shape
,
outputs_shape
,
attrs
)
{
:
OperatorInfo
(
name
,
inputs_shape
,
outputs_shape
,
attrs
,
std
::
make_shared
<
VirtualDatasetCost
>
())
{}
vd_cost_ptr_
=
std
::
make_shared
<
VirtualDatasetCost
>
();
}
~
VirtualDatasetInfo
()
override
=
default
;
~
VirtualDatasetInfo
()
override
=
default
;
Status
Init
(
const
StrategyPtr
&
strategy
)
override
;
Status
Init
(
const
StrategyPtr
&
strategy
)
override
;
Status
InitForCostModel
(
const
StrategyPtr
&
strategy
)
override
;
Status
InitForCostModel
(
const
StrategyPtr
&
strategy
)
override
;
Status
GenerateStrategies
(
int32_t
stage_id
)
override
;
Status
GenerateStrategies
(
int32_t
stage_id
)
override
;
Status
SetCostUnderStrategy
(
const
StrategyPtr
&
strategy
)
override
;
Status
SetCostUnderStrategy
(
const
StrategyPtr
&
strategy
)
override
;
OperatorCostPtr
GetOperatorCost
()
const
override
{
return
vd_cost_ptr_
;
}
void
ReComputeBatchSplitFlagList
()
override
;
void
ReComputeBatchSplitFlagList
()
override
;
protected:
protected:
...
@@ -53,9 +50,6 @@ class VirtualDatasetInfo : public OperatorInfo {
...
@@ -53,9 +50,6 @@ class VirtualDatasetInfo : public OperatorInfo {
Status
InferTensorMap
()
override
;
Status
InferTensorMap
()
override
;
Status
GetAttrs
()
override
;
Status
GetAttrs
()
override
;
Status
InferAsLossDivisor
()
override
;
Status
InferAsLossDivisor
()
override
;
private:
VirtualDatasetCostPtr
vd_cost_ptr_
;
};
};
}
// namespace parallel
}
// namespace parallel
...
...
tests/ut/cpp/parallel/ops_info/activation_test.cc
浏览文件 @
b413638f
...
@@ -84,9 +84,9 @@ TEST_F(TestActivation, test_activation_strategies) {
...
@@ -84,9 +84,9 @@ TEST_F(TestActivation, test_activation_strategies) {
act_ptr_
->
InitForCostModel
(
sp
);
act_ptr_
->
InitForCostModel
(
sp
);
std
::
vector
<
TensorInfo
>
inputs_info
=
act_ptr_
->
inputs_tensor_info
();
std
::
vector
<
TensorInfo
>
inputs_info
=
act_ptr_
->
inputs_tensor_info
();
std
::
vector
<
TensorInfo
>
outputs_info
=
act_ptr_
->
outputs_tensor_info
();
std
::
vector
<
TensorInfo
>
outputs_info
=
act_ptr_
->
outputs_tensor_info
();
ASSERT_DOUBLE_EQ
(
act_ptr_
->
GetOperatorC
ost
()
->
GetComputationCost
(
inputs_info
,
outputs_info
,
sp
->
GetInputStage
()),
ASSERT_DOUBLE_EQ
(
act_ptr_
->
c
ost
()
->
GetComputationCost
(
inputs_info
,
outputs_info
,
sp
->
GetInputStage
()),
cost
.
computation_cost_
);
cost
.
computation_cost_
);
ASSERT_DOUBLE_EQ
(
act_ptr_
->
GetOperatorC
ost
()
->
GetCommCost
(
inputs_info
,
outputs_info
,
sp
->
GetInputStage
()),
ASSERT_DOUBLE_EQ
(
act_ptr_
->
c
ost
()
->
GetCommCost
(
inputs_info
,
outputs_info
,
sp
->
GetInputStage
()),
cost
.
communication_cost_
);
cost
.
communication_cost_
);
}
}
}
}
...
@@ -109,9 +109,9 @@ TEST_F(TestActivation, test_softmax_strategies) {
...
@@ -109,9 +109,9 @@ TEST_F(TestActivation, test_softmax_strategies) {
soft_ptr_
->
InitForCostModel
(
sp
);
soft_ptr_
->
InitForCostModel
(
sp
);
std
::
vector
<
TensorInfo
>
inputs_info
=
soft_ptr_
->
inputs_tensor_info
();
std
::
vector
<
TensorInfo
>
inputs_info
=
soft_ptr_
->
inputs_tensor_info
();
std
::
vector
<
TensorInfo
>
outputs_info
=
soft_ptr_
->
outputs_tensor_info
();
std
::
vector
<
TensorInfo
>
outputs_info
=
soft_ptr_
->
outputs_tensor_info
();
ASSERT_DOUBLE_EQ
(
soft_ptr_
->
GetOperatorC
ost
()
->
GetComputationCost
(
inputs_info
,
outputs_info
,
sp
->
GetInputStage
()),
ASSERT_DOUBLE_EQ
(
soft_ptr_
->
c
ost
()
->
GetComputationCost
(
inputs_info
,
outputs_info
,
sp
->
GetInputStage
()),
cost
.
computation_cost_
);
cost
.
computation_cost_
);
ASSERT_DOUBLE_EQ
(
soft_ptr_
->
GetOperatorC
ost
()
->
GetCommCost
(
inputs_info
,
outputs_info
,
sp
->
GetInputStage
()),
ASSERT_DOUBLE_EQ
(
soft_ptr_
->
c
ost
()
->
GetCommCost
(
inputs_info
,
outputs_info
,
sp
->
GetInputStage
()),
cost
.
communication_cost_
);
cost
.
communication_cost_
);
}
}
}
}
...
...
tests/ut/cpp/parallel/ops_info/matmul_info_test.cc
浏览文件 @
b413638f
...
@@ -569,7 +569,7 @@ TEST_F(TestMatmulInfo, test_GenerateStrategies1) {
...
@@ -569,7 +569,7 @@ TEST_F(TestMatmulInfo, test_GenerateStrategies1) {
matmul1
->
InitForCostModel
(
sp
);
matmul1
->
InitForCostModel
(
sp
);
std
::
vector
<
TensorInfo
>
inputs_info
=
matmul1
->
inputs_tensor_info
();
std
::
vector
<
TensorInfo
>
inputs_info
=
matmul1
->
inputs_tensor_info
();
std
::
vector
<
TensorInfo
>
outputs_info
=
matmul1
->
outputs_tensor_info
();
std
::
vector
<
TensorInfo
>
outputs_info
=
matmul1
->
outputs_tensor_info
();
ASSERT_DOUBLE_EQ
(
matmul1
->
GetOperatorC
ost
()
->
GetComputationCost
(
inputs_info
,
outputs_info
,
sp
->
GetInputStage
()),
ASSERT_DOUBLE_EQ
(
matmul1
->
c
ost
()
->
GetComputationCost
(
inputs_info
,
outputs_info
,
sp
->
GetInputStage
()),
cost
.
computation_cost_
);
cost
.
computation_cost_
);
break
;
break
;
}
}
...
@@ -599,7 +599,7 @@ TEST_F(TestMatmulInfo, test_GenerateStrategies2) {
...
@@ -599,7 +599,7 @@ TEST_F(TestMatmulInfo, test_GenerateStrategies2) {
TensorInfo
replica_input1_info
(
tly
,
input1_shape
,
input1_slice_shape
);
TensorInfo
replica_input1_info
(
tly
,
input1_shape
,
input1_slice_shape
);
replica_inputs_info
.
push_back
(
replica_input1_info
);
replica_inputs_info
.
push_back
(
replica_input1_info
);
ASSERT_DOUBLE_EQ
(
matmul3
->
GetOperatorC
ost
()
->
GetComputationCost
(
replica_inputs_info
,
outputs_info
,
sp
->
GetInputStage
()),
ASSERT_DOUBLE_EQ
(
matmul3
->
c
ost
()
->
GetComputationCost
(
replica_inputs_info
,
outputs_info
,
sp
->
GetInputStage
()),
cost
.
computation_cost_
);
cost
.
computation_cost_
);
break
;
break
;
}
}
...
...
tests/ut/cpp/parallel/ops_info/tensor_add_info_test.cc
浏览文件 @
b413638f
...
@@ -188,11 +188,11 @@ TEST_F(TestTensorAddInfo, GenerateStrategies) {
...
@@ -188,11 +188,11 @@ TEST_F(TestTensorAddInfo, GenerateStrategies) {
tensor_add
->
InitForCostModel
(
sp
);
tensor_add
->
InitForCostModel
(
sp
);
std
::
vector
<
TensorInfo
>
inputs_info
=
tensor_add
->
inputs_tensor_info
();
std
::
vector
<
TensorInfo
>
inputs_info
=
tensor_add
->
inputs_tensor_info
();
std
::
vector
<
TensorInfo
>
outputs_info
=
tensor_add
->
outputs_tensor_info
();
std
::
vector
<
TensorInfo
>
outputs_info
=
tensor_add
->
outputs_tensor_info
();
double
memory_cost0
=
tensor_add
->
GetOperatorC
ost
()
->
GetComputationCost
(
inputs_info
,
outputs_info
,
sp
->
GetInputStage
());
double
memory_cost0
=
tensor_add
->
c
ost
()
->
GetComputationCost
(
inputs_info
,
outputs_info
,
sp
->
GetInputStage
());
double
memory_cost1
=
cost
.
computation_cost_
;
double
memory_cost1
=
cost
.
computation_cost_
;
bool
memory
=
memory_cost0
-
memory_cost1
<=
1.0
;
bool
memory
=
memory_cost0
-
memory_cost1
<=
1.0
;
double
comm_cost0
=
tensor_add
->
GetOperatorC
ost
()
->
GetCommCost
(
inputs_info
,
outputs_info
,
sp
->
GetInputStage
());
double
comm_cost0
=
tensor_add
->
c
ost
()
->
GetCommCost
(
inputs_info
,
outputs_info
,
sp
->
GetInputStage
());
double
comm_cost1
=
cost
.
communication_cost_
;
double
comm_cost1
=
cost
.
communication_cost_
;
bool
comm
=
comm_cost0
-
comm_cost1
<=
1.0
;
bool
comm
=
comm_cost0
-
comm_cost1
<=
1.0
;
...
@@ -210,11 +210,11 @@ TEST_F(TestTensorAddInfo, GenerateStrategies1) {
...
@@ -210,11 +210,11 @@ TEST_F(TestTensorAddInfo, GenerateStrategies1) {
tensor_add1
->
InitForCostModel
(
sp
);
tensor_add1
->
InitForCostModel
(
sp
);
std
::
vector
<
TensorInfo
>
inputs_info
=
tensor_add1
->
inputs_tensor_info
();
std
::
vector
<
TensorInfo
>
inputs_info
=
tensor_add1
->
inputs_tensor_info
();
std
::
vector
<
TensorInfo
>
outputs_info
=
tensor_add1
->
outputs_tensor_info
();
std
::
vector
<
TensorInfo
>
outputs_info
=
tensor_add1
->
outputs_tensor_info
();
double
memory_cost0
=
tensor_add1
->
GetOperatorC
ost
()
->
GetComputationCost
(
inputs_info
,
outputs_info
,
sp
->
GetInputStage
());
double
memory_cost0
=
tensor_add1
->
c
ost
()
->
GetComputationCost
(
inputs_info
,
outputs_info
,
sp
->
GetInputStage
());
double
memory_cost1
=
cost
.
computation_cost_
;
double
memory_cost1
=
cost
.
computation_cost_
;
bool
memory
=
memory_cost0
-
memory_cost1
<=
1.0
;
bool
memory
=
memory_cost0
-
memory_cost1
<=
1.0
;
double
comm_cost0
=
tensor_add1
->
GetOperatorC
ost
()
->
GetCommCost
(
inputs_info
,
outputs_info
,
sp
->
GetInputStage
());
double
comm_cost0
=
tensor_add1
->
c
ost
()
->
GetCommCost
(
inputs_info
,
outputs_info
,
sp
->
GetInputStage
());
double
comm_cost1
=
cost
.
communication_cost_
;
double
comm_cost1
=
cost
.
communication_cost_
;
bool
comm
=
comm_cost0
-
comm_cost1
<=
1.0
;
bool
comm
=
comm_cost0
-
comm_cost1
<=
1.0
;
...
...
tests/ut/cpp/parallel/ops_info/tmpidentity_test.cc
浏览文件 @
b413638f
...
@@ -145,9 +145,9 @@ TEST_F(TestTmpIdentityInfo, test_generate_strategies) {
...
@@ -145,9 +145,9 @@ TEST_F(TestTmpIdentityInfo, test_generate_strategies) {
identity_ptr
->
Init
(
sp
);
identity_ptr
->
Init
(
sp
);
std
::
vector
<
TensorInfo
>
inputs_info
=
identity_ptr
->
inputs_tensor_info
();
std
::
vector
<
TensorInfo
>
inputs_info
=
identity_ptr
->
inputs_tensor_info
();
std
::
vector
<
TensorInfo
>
outputs_info
=
identity_ptr
->
outputs_tensor_info
();
std
::
vector
<
TensorInfo
>
outputs_info
=
identity_ptr
->
outputs_tensor_info
();
ASSERT_DOUBLE_EQ
(
identity_ptr
->
GetOperatorC
ost
()
->
GetComputationCost
(
inputs_info
,
outputs_info
,
sp
->
GetInputStage
()),
ASSERT_DOUBLE_EQ
(
identity_ptr
->
c
ost
()
->
GetComputationCost
(
inputs_info
,
outputs_info
,
sp
->
GetInputStage
()),
cost
.
computation_cost_
);
cost
.
computation_cost_
);
ASSERT_DOUBLE_EQ
(
identity_ptr
->
GetOperatorC
ost
()
->
GetCommCost
(
inputs_info
,
outputs_info
,
sp
->
GetInputStage
()),
ASSERT_DOUBLE_EQ
(
identity_ptr
->
c
ost
()
->
GetCommCost
(
inputs_info
,
outputs_info
,
sp
->
GetInputStage
()),
cost
.
communication_cost_
);
cost
.
communication_cost_
);
}
}
}
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录