Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
8e20d4d8
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
8e20d4d8
编写于
6月 22, 2020
作者:
H
hongxing
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix l2normalize/prelu/softmax cost
上级
948ea950
变更
8
隐藏空白更改
内联
并排
Showing
8 changed file
with
111 addition
and
31 deletion
+111
-31
mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_cost.cc
mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_cost.cc
+43
-0
mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_cost.h
mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_cost.h
+6
-0
mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_generate_strategy.cc
.../parallel/auto_parallel/rec_core/rec_generate_strategy.cc
+29
-13
mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_generate_strategy.h
...c/parallel/auto_parallel/rec_core/rec_generate_strategy.h
+2
-3
mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_graph.h
mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_graph.h
+1
-0
mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.cc
.../ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.cc
+14
-4
mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.h
...e/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.h
+1
-1
mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_partition.cc
...re/ccsrc/parallel/auto_parallel/rec_core/rec_partition.cc
+15
-10
未找到文件。
mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_cost.cc
浏览文件 @
8e20d4d8
...
@@ -703,5 +703,48 @@ StrategyRec CostBatchParallel::ChoseStr(const std::vector<double> &cost_op, Stra
...
@@ -703,5 +703,48 @@ StrategyRec CostBatchParallel::ChoseStr(const std::vector<double> &cost_op, Stra
}
}
return
str
;
return
str
;
}
}
// Chose strategy for CostSoftmaxCrossEntropyWithLogits
StrategyRec
CostSoftmaxCrossEntropyWithLogits
::
ChoseStr
(
const
std
::
vector
<
double
>
&
cost_op
,
StrategyRec
str
)
{
uint64_t
min_position
=
min_element
(
cost_op
.
begin
(),
cost_op
.
end
())
-
cost_op
.
begin
();
if
(
cost_op
[
min_position
]
>
(
DOUBLE_MAX
-
0.1
))
{
return
str
;
}
switch
(
min_position
)
{
case
0
:
str
.
inputTensor
[
0
].
str_n
/=
2.0
;
str
.
inputTensor
[
1
].
str_n
/=
2.0
;
str
.
cut_counter
+=
1
;
str
.
cost
=
str
.
cost
+
cost_in_
;
break
;
case
1
:
str
.
inputTensor
[
0
].
str_c
/=
2.0
;
str
.
inputTensor
[
1
].
str_c
/=
2.0
;
str
.
cut_counter
+=
1
;
str
.
cost
=
str
.
cost
+
cost_in_
;
break
;
case
2
:
str
.
inputTensor
[
0
].
str_h
/=
2.0
;
str
.
inputTensor
[
1
].
str_h
/=
2.0
;
str
.
outputTensor
.
str_w
/=
2.0
;
str
.
cut_counter
+=
1
;
str
.
cost
=
str
.
cost
+
cost_in_
;
break
;
case
3
:
str
.
inputTensor
[
0
].
str_w
/=
2.0
;
str
.
inputTensor
[
1
].
str_w
/=
2.0
;
str
.
cut_counter
+=
1
;
str
.
cost
=
str
.
cost
+
cost_in_
;
break
;
default:
MS_LOG
(
EXCEPTION
)
<<
"Failure: CostSoftmax failed."
;
}
return
str
;
}
}
// namespace parallel
}
// namespace parallel
}
// namespace mindspore
}
// namespace mindspore
mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_cost.h
浏览文件 @
8e20d4d8
...
@@ -222,6 +222,12 @@ class CostBatchParallel {
...
@@ -222,6 +222,12 @@ class CostBatchParallel {
class
CostBatchNorm
:
public
CostBatchParallel
{};
class
CostBatchNorm
:
public
CostBatchParallel
{};
class
CostOneHot
:
public
CostBatchParallel
{};
class
CostOneHot
:
public
CostBatchParallel
{};
class
CostPRelu
:
public
CostBatchParallel
{};
class
CostSoftmax
:
public
CostBatchParallel
{};
class
CostSoftmaxCrossEntropyWithLogits
:
public
CostBatchParallel
{
StrategyRec
ChoseStr
(
const
std
::
vector
<
double
>
&
cost_op
,
StrategyRec
str
);
};
}
// namespace parallel
}
// namespace parallel
}
// namespace mindspore
}
// namespace mindspore
#endif // PARALLEL_AUTO_PARALLEL_REC_COST_H_
#endif // PARALLEL_AUTO_PARALLEL_REC_COST_H_
mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_generate_strategy.cc
浏览文件 @
8e20d4d8
...
@@ -127,14 +127,6 @@ std::vector<std::vector<int32_t>> PrepareMatMul(const std::shared_ptr<Graph> &gr
...
@@ -127,14 +127,6 @@ std::vector<std::vector<int32_t>> PrepareMatMul(const std::shared_ptr<Graph> &gr
return
strategies
;
return
strategies
;
}
}
std
::
vector
<
std
::
vector
<
int32_t
>>
PreparePReLU
(
const
std
::
shared_ptr
<
Graph
>
&
graph
,
const
std
::
vector
<
std
::
shared_ptr
<
OperatorInfo
>>
&
ops
,
const
size_t
iter_graph
,
const
size_t
iter_ops
)
{
std
::
vector
<
std
::
vector
<
int32_t
>>
strategies
=
MakeDataParallelStrategy
(
graph
,
ops
,
iter_graph
,
iter_ops
);
strategies
[
1
][
0
]
=
1
;
return
strategies
;
}
std
::
vector
<
std
::
vector
<
int32_t
>>
PrepareBiasAdd
(
const
std
::
shared_ptr
<
std
::
vector
<
int32_t
>>
&
s
)
{
std
::
vector
<
std
::
vector
<
int32_t
>>
PrepareBiasAdd
(
const
std
::
shared_ptr
<
std
::
vector
<
int32_t
>>
&
s
)
{
std
::
vector
<
std
::
vector
<
int32_t
>>
strategies
;
std
::
vector
<
std
::
vector
<
int32_t
>>
strategies
;
strategies
.
push_back
(
*
s
);
strategies
.
push_back
(
*
s
);
...
@@ -164,6 +156,32 @@ std::vector<std::vector<int32_t>> PrepareGatherV2(const std::shared_ptr<std::vec
...
@@ -164,6 +156,32 @@ std::vector<std::vector<int32_t>> PrepareGatherV2(const std::shared_ptr<std::vec
return
strategies
;
return
strategies
;
}
}
std
::
vector
<
std
::
vector
<
int32_t
>>
PrepareL2Normalize
(
const
std
::
vector
<
std
::
shared_ptr
<
OperatorInfo
>>
&
ops
,
const
size_t
iter_ops
,
std
::
vector
<
int32_t
>
s
)
{
int32_t
axis
=
0
;
auto
iter
=
ops
[
iter_ops
]
->
attrs
().
find
(
AXIS
);
if
(
iter
!=
ops
[
iter_ops
]
->
attrs
().
end
())
{
MS_EXCEPTION_IF_NULL
(
iter
->
second
);
if
(
iter
->
second
->
isa
<
Int32Imm
>
())
{
axis
=
iter
->
second
->
cast
<
Int32ImmPtr
>
()
->
value
();
}
else
{
MS_LOG
(
EXCEPTION
)
<<
ops
[
iter_ops
]
->
name
()
<<
" : The value of axis is not int."
;
}
}
int32_t
axis_index
=
axis
;
if
(
axis
<
0
)
{
size_t
input_dim
=
ops
[
iter_ops
]
->
inputs_tensor_info
()[
0
].
shape
().
size
();
axis_index
=
static_cast
<
int32_t
>
(
input_dim
)
+
axis
;
}
s
[
IntToSize
(
axis_index
)]
=
1
;
std
::
vector
<
std
::
vector
<
int32_t
>>
strategies
;
strategies
.
push_back
(
s
);
return
strategies
;
}
std
::
vector
<
std
::
vector
<
int32_t
>>
MakeRecSearchStrategy
(
const
std
::
shared_ptr
<
Graph
>
&
graph
,
std
::
vector
<
std
::
vector
<
int32_t
>>
MakeRecSearchStrategy
(
const
std
::
shared_ptr
<
Graph
>
&
graph
,
const
std
::
vector
<
std
::
shared_ptr
<
OperatorInfo
>>
&
ops
,
const
std
::
vector
<
std
::
shared_ptr
<
OperatorInfo
>>
&
ops
,
const
size_t
iter_graph
,
const
size_t
iter_ops
)
{
const
size_t
iter_graph
,
const
size_t
iter_ops
)
{
...
@@ -279,13 +297,8 @@ std::vector<std::vector<int32_t>> PrepareStrategy(const std::shared_ptr<Graph> &
...
@@ -279,13 +297,8 @@ std::vector<std::vector<int32_t>> PrepareStrategy(const std::shared_ptr<Graph> &
if
(
type
==
MATMUL
)
{
if
(
type
==
MATMUL
)
{
return
PrepareMatMul
(
graph
,
ops
,
iter_graph
,
iter_ops
);
return
PrepareMatMul
(
graph
,
ops
,
iter_graph
,
iter_ops
);
}
else
if
(
type
==
PRELU
)
{
return
PreparePReLU
(
graph
,
ops
,
iter_graph
,
iter_ops
);
}
else
if
(
type
==
ONEHOT
)
{
}
else
if
(
type
==
ONEHOT
)
{
return
PrepareOneHot
(
graph
,
ops
,
iter_graph
,
iter_ops
);
return
PrepareOneHot
(
graph
,
ops
,
iter_graph
,
iter_ops
);
}
else
if
(
type
==
SOFTMAX
||
type
==
LOG_SOFTMAX
||
type
==
SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS
||
type
==
SOFTMAX_CROSS_ENTROPY_WITH_LOGITS
)
{
return
MakeDataParallelStrategy
(
graph
,
ops
,
iter_graph
,
iter_ops
);
}
else
{
}
else
{
return
MakeRecSearchStrategy
(
graph
,
ops
,
iter_graph
,
iter_ops
);
return
MakeRecSearchStrategy
(
graph
,
ops
,
iter_graph
,
iter_ops
);
}
}
...
@@ -510,6 +523,9 @@ std::vector<std::vector<int32_t>> GenerateStrategiesFromStrategy(const std::vect
...
@@ -510,6 +523,9 @@ std::vector<std::vector<int32_t>> GenerateStrategiesFromStrategy(const std::vect
if
(
ops
[
iter_ops
]
->
type
()
==
GATHERV2
)
{
if
(
ops
[
iter_ops
]
->
type
()
==
GATHERV2
)
{
return
PrepareGatherV2
(
s_ptr
);
return
PrepareGatherV2
(
s_ptr
);
}
}
if
(
ops
[
iter_ops
]
->
type
()
==
L2_NORMALIZE
)
{
return
PrepareL2Normalize
(
ops
,
iter_ops
,
basic_stra
);
}
for
(
size_t
iter_op_inputs
=
0
;
iter_op_inputs
<
(
size_t
)
ops
[
iter_ops
]
->
inputs_tensor_info
().
size
();
for
(
size_t
iter_op_inputs
=
0
;
iter_op_inputs
<
(
size_t
)
ops
[
iter_ops
]
->
inputs_tensor_info
().
size
();
iter_op_inputs
++
)
{
iter_op_inputs
++
)
{
...
...
mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_generate_strategy.h
浏览文件 @
8e20d4d8
...
@@ -34,14 +34,13 @@ void GenerateStrategy(std::shared_ptr<Graph> graph, const std::vector<std::share
...
@@ -34,14 +34,13 @@ void GenerateStrategy(std::shared_ptr<Graph> graph, const std::vector<std::share
std
::
vector
<
std
::
vector
<
int32_t
>>
PrepareMatMul
(
const
std
::
shared_ptr
<
Graph
>
&
graph
,
std
::
vector
<
std
::
vector
<
int32_t
>>
PrepareMatMul
(
const
std
::
shared_ptr
<
Graph
>
&
graph
,
const
std
::
vector
<
std
::
shared_ptr
<
OperatorInfo
>>
&
ops
,
const
std
::
vector
<
std
::
shared_ptr
<
OperatorInfo
>>
&
ops
,
const
size_t
iter_graph
,
const
size_t
iter_ops
);
const
size_t
iter_graph
,
const
size_t
iter_ops
);
std
::
vector
<
std
::
vector
<
int32_t
>>
PreparePReLU
(
const
std
::
shared_ptr
<
Graph
>
&
graph
,
const
std
::
vector
<
std
::
shared_ptr
<
OperatorInfo
>>
&
ops
,
const
size_t
iter_graph
,
const
size_t
iter_ops
);
std
::
vector
<
std
::
vector
<
int32_t
>>
PrepareBiasAdd
(
const
std
::
shared_ptr
<
std
::
vector
<
int32_t
>>
&
s
);
std
::
vector
<
std
::
vector
<
int32_t
>>
PrepareBiasAdd
(
const
std
::
shared_ptr
<
std
::
vector
<
int32_t
>>
&
s
);
std
::
vector
<
std
::
vector
<
int32_t
>>
PrepareOneHot
(
const
std
::
shared_ptr
<
Graph
>
&
graph
,
std
::
vector
<
std
::
vector
<
int32_t
>>
PrepareOneHot
(
const
std
::
shared_ptr
<
Graph
>
&
graph
,
const
std
::
vector
<
std
::
shared_ptr
<
OperatorInfo
>>
&
ops
,
const
std
::
vector
<
std
::
shared_ptr
<
OperatorInfo
>>
&
ops
,
const
size_t
iter_graph
,
const
size_t
iter_ops
);
const
size_t
iter_graph
,
const
size_t
iter_ops
);
std
::
vector
<
std
::
vector
<
int32_t
>>
PrepareGatherV2
(
const
std
::
shared_ptr
<
std
::
vector
<
int32_t
>>
&
s
);
std
::
vector
<
std
::
vector
<
int32_t
>>
PrepareGatherV2
(
const
std
::
shared_ptr
<
std
::
vector
<
int32_t
>>
&
s
);
std
::
vector
<
std
::
vector
<
int32_t
>>
PrepareL2Normalize
(
const
std
::
vector
<
std
::
shared_ptr
<
OperatorInfo
>>
&
ops
,
const
size_t
iter_ops
,
std
::
vector
<
int32_t
>
s
);
std
::
vector
<
std
::
vector
<
int32_t
>>
MakeRecSearchStrategy
(
const
std
::
shared_ptr
<
Graph
>
&
graph
,
std
::
vector
<
std
::
vector
<
int32_t
>>
MakeRecSearchStrategy
(
const
std
::
shared_ptr
<
Graph
>
&
graph
,
const
std
::
vector
<
std
::
shared_ptr
<
OperatorInfo
>>
&
ops
,
const
std
::
vector
<
std
::
shared_ptr
<
OperatorInfo
>>
&
ops
,
const
size_t
iter_graph
,
const
size_t
iter_ops
);
const
size_t
iter_graph
,
const
size_t
iter_ops
);
...
...
mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_graph.h
浏览文件 @
8e20d4d8
...
@@ -38,6 +38,7 @@ enum OperatorType {
...
@@ -38,6 +38,7 @@ enum OperatorType {
kRecBiasAdd
,
kRecBiasAdd
,
kRecSoftmax
,
kRecSoftmax
,
kRecSparseSoftmaxCrossEntropyWithLogits
,
kRecSparseSoftmaxCrossEntropyWithLogits
,
kRecSoftmaxCrossEntropyWithLogits
,
kRecOneHot
,
kRecOneHot
,
kRecLog
,
kRecLog
,
kRecExp
,
kRecExp
,
...
...
mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.cc
浏览文件 @
8e20d4d8
...
@@ -250,12 +250,22 @@ std::shared_ptr<Graph> EliminateGraph(const std::shared_ptr<Graph> graph,
...
@@ -250,12 +250,22 @@ std::shared_ptr<Graph> EliminateGraph(const std::shared_ptr<Graph> graph,
new_graph
->
nodes
.
push_back
(
graph
->
nodes
[
i
]);
new_graph
->
nodes
.
push_back
(
graph
->
nodes
[
i
]);
auto
*
node_in
=
&
new_graph
->
nodes
[
index_list
->
at
(
i
)].
node_in
;
auto
*
node_in
=
&
new_graph
->
nodes
[
index_list
->
at
(
i
)].
node_in
;
for
(
size_t
j
=
0
;
j
<
node_in
->
size
();
j
++
)
{
for
(
size_t
j
=
node_in
->
size
();
j
>
0
;
j
--
)
{
node_in
->
at
(
j
)
=
index_list
->
at
(
node_in
->
at
(
j
));
bool
IsEliminated
=
(
index_list
->
at
(
node_in
->
at
(
j
-
1
))
==
SIZE_MAX
);
if
(
IsEliminated
)
{
node_in
->
erase
(
node_in
->
begin
()
+
j
-
1
);
}
else
{
node_in
->
at
(
j
-
1
)
=
index_list
->
at
(
node_in
->
at
(
j
-
1
));
}
}
}
auto
*
node_out
=
&
new_graph
->
nodes
[
index_list
->
at
(
i
)].
node_out
;
auto
*
node_out
=
&
new_graph
->
nodes
[
index_list
->
at
(
i
)].
node_out
;
for
(
size_t
j
=
0
;
j
<
node_out
->
size
();
j
++
)
{
for
(
size_t
j
=
node_out
->
size
();
j
>
0
;
j
--
)
{
node_out
->
at
(
j
)
=
index_list
->
at
(
node_out
->
at
(
j
));
bool
IsEliminated
=
(
index_list
->
at
(
node_out
->
at
(
j
-
1
))
==
SIZE_MAX
);
if
(
IsEliminated
)
{
node_out
->
erase
(
node_out
->
begin
()
+
j
-
1
);
}
else
{
node_out
->
at
(
j
-
1
)
=
index_list
->
at
(
node_out
->
at
(
j
-
1
));
}
}
}
}
}
return
new_graph
;
return
new_graph
;
...
...
mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.h
浏览文件 @
8e20d4d8
...
@@ -67,7 +67,7 @@ const std::map<std::string, OperatorType> DictOpType{
...
@@ -67,7 +67,7 @@ const std::map<std::string, OperatorType> DictOpType{
{
REAL_DIV
,
OperatorType
::
kRecElmWiseOp
},
{
REAL_DIV
,
OperatorType
::
kRecElmWiseOp
},
{
SOFTMAX
,
OperatorType
::
kRecSoftmax
},
{
SOFTMAX
,
OperatorType
::
kRecSoftmax
},
{
LOG_SOFTMAX
,
OperatorType
::
kRecSoftmax
},
{
LOG_SOFTMAX
,
OperatorType
::
kRecSoftmax
},
{
SOFTMAX_CROSS_ENTROPY_WITH_LOGITS
,
OperatorType
::
kRecSoftmax
},
{
SOFTMAX_CROSS_ENTROPY_WITH_LOGITS
,
OperatorType
::
kRecSoftmax
CrossEntropyWithLogits
},
{
SQRT
,
OperatorType
::
kRecElmWiseOp
},
{
SQRT
,
OperatorType
::
kRecElmWiseOp
},
{
NEG
,
OperatorType
::
kRecElmWiseOp
},
{
NEG
,
OperatorType
::
kRecElmWiseOp
},
{
POW
,
OperatorType
::
kRecElmWiseOp
},
{
POW
,
OperatorType
::
kRecElmWiseOp
},
...
...
mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_partition.cc
浏览文件 @
8e20d4d8
...
@@ -76,15 +76,16 @@ double GetWeights(const Graph::NodeType &node) {
...
@@ -76,15 +76,16 @@ double GetWeights(const Graph::NodeType &node) {
auto
cost_ptr
=
std
::
make_shared
<
CostCommon
>
();
auto
cost_ptr
=
std
::
make_shared
<
CostCommon
>
();
return
cost_ptr
->
GetMinCostIn
();
return
cost_ptr
->
GetMinCostIn
();
}
else
if
(
op
.
op_type
==
OperatorType
::
kRecBatchNorm
||
op
.
op_type
==
OperatorType
::
kRecOneHot
)
{
}
else
if
(
op
.
op_type
==
OperatorType
::
kRecBatchNorm
||
op
.
op_type
==
OperatorType
::
kRecOneHot
||
op
.
op_type
==
OperatorType
::
kRecPReLU
||
op
.
op_type
==
OperatorType
::
kRecSoftmax
||
op
.
op_type
==
OperatorType
::
kRecSparseSoftmaxCrossEntropyWithLogits
||
op
.
op_type
==
OperatorType
::
kRecSoftmaxCrossEntropyWithLogits
)
{
// For BatchParallel op
// For BatchParallel op
auto
cost_ptr
=
std
::
make_shared
<
CostBatchParallel
>
();
auto
cost_ptr
=
std
::
make_shared
<
CostBatchParallel
>
();
return
cost_ptr
->
GetMaxCostIn
();
return
cost_ptr
->
GetMaxCostIn
();
}
else
if
(
op
.
op_type
==
OperatorType
::
kRecUnkownType
||
op
.
op_type
==
OperatorType
::
kRecPReLU
||
}
else
if
(
op
.
op_type
==
OperatorType
::
kRecUnkownType
)
{
op
.
op_type
==
OperatorType
::
kRecSoftmax
||
// For Unkown type
op
.
op_type
==
OperatorType
::
kRecSparseSoftmaxCrossEntropyWithLogits
)
{
// For unprocessed type
return
0.0
;
return
0.0
;
}
else
{
}
else
{
MS_LOG
(
EXCEPTION
)
<<
"Failure: GetOperatorWeight failed."
;
MS_LOG
(
EXCEPTION
)
<<
"Failure: GetOperatorWeight failed."
;
...
@@ -170,14 +171,18 @@ StrategyRec PartitionNode(const Graph::NodeType &node,
...
@@ -170,14 +171,18 @@ StrategyRec PartitionNode(const Graph::NodeType &node,
auto
cost_ptr
=
std
::
make_shared
<
CostCommon
>
();
auto
cost_ptr
=
std
::
make_shared
<
CostCommon
>
();
return
cost_ptr
->
GetOptimalStr
(
node
,
node_name_to_strategy
,
*
graph
);
return
cost_ptr
->
GetOptimalStr
(
node
,
node_name_to_strategy
,
*
graph
);
}
else
if
(
node
.
apply
.
op_type
==
OperatorType
::
kRecBatchNorm
||
node
.
apply
.
op_type
==
OperatorType
::
kRecOneHot
)
{
}
else
if
(
node
.
apply
.
op_type
==
OperatorType
::
kRecBatchNorm
||
node
.
apply
.
op_type
==
OperatorType
::
kRecOneHot
||
node
.
apply
.
op_type
==
OperatorType
::
kRecPReLU
||
node
.
apply
.
op_type
==
kRecSoftmax
||
node
.
apply
.
op_type
==
OperatorType
::
kRecSparseSoftmaxCrossEntropyWithLogits
)
{
// For BatchParallel type
// For BatchParallel type
auto
cost_ptr
=
std
::
make_shared
<
CostBatchParallel
>
();
auto
cost_ptr
=
std
::
make_shared
<
CostBatchParallel
>
();
return
cost_ptr
->
GetOptimalStr
(
node
);
return
cost_ptr
->
GetOptimalStr
(
node
);
}
else
if
(
node
.
apply
.
op_type
==
OperatorType
::
kRecUnkownType
||
node
.
apply
.
op_type
==
OperatorType
::
kRecPReLU
||
}
else
if
(
node
.
apply
.
op_type
==
OperatorType
::
kRecSoftmaxCrossEntropyWithLogits
)
{
node
.
apply
.
op_type
==
OperatorType
::
kRecSoftmax
||
// For SoftmaxCrossEntropyWithLogits type
node
.
apply
.
op_type
==
OperatorType
::
kRecSparseSoftmaxCrossEntropyWithLogits
)
{
auto
cost_ptr
=
std
::
make_shared
<
CostSoftmaxCrossEntropyWithLogits
>
();
// For unprocessed type
return
cost_ptr
->
GetOptimalStr
(
node
);
}
else
if
(
node
.
apply
.
op_type
==
OperatorType
::
kRecUnkownType
)
{
// For Unkown type
StrategyRec
default_strategy
;
StrategyRec
default_strategy
;
return
default_strategy
;
return
default_strategy
;
}
else
{
}
else
{
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录