magicwindyyd / mindspore (forked from MindSpore / mindspore)
Commit 6de7eb26

Authored on May 17, 2020 by mindspore-ci-bot; committed via Gitee on May 17, 2020.

!1203 delete ApplyAdamD for master

Merge pull request !1203 from changzherui/code516

Parents: c3d9f180, 1e939277
Showing 10 changed files with 39 additions and 67 deletions (+39 / -67):
graphengine                                      +1   -1
mindspore/ccsrc/kernel/tbe/tbe_adapter.cc        +1   -3
mindspore/ccsrc/transform/convert.cc             +3   -3
mindspore/ccsrc/transform/op_declare.cc          +11  -21
mindspore/ccsrc/transform/op_declare.h           +4   -6
mindspore/ops/_op_impl/tbe/apply_ftrl.py         +8   -10
mindspore/ops/_op_impl/tbe/apply_momentum.py     +8   -9
mindspore/ops/operations/nn_ops.py               +2   -12
tests/ut/python/ops/test_momentum.py             +1   -1
tests/ut/python/ops/test_ops.py                  +0   -1
graphengine @ 995b6dad

-Subproject commit cf29b3d853b38c13d8d56181256613a11bf9eb95
+Subproject commit 995b6dadc0fbbe4b80a08196886a53a18bffa60e
mindspore/ccsrc/kernel/tbe/tbe_adapter.cc

@@ -32,8 +32,6 @@ namespace tbe {
 static std::map<string, string> tbe_func_adapter_map = {
   {"softmax", "softmax_v2"},
   {"log_softmax", "log_softmax_v2"},
-  {"apply_momentum", "apply_momentum_d"},
-  {"apply_ftrl", "apply_ftrl_d"},
   {"re_lu6", "relu6"},
   {"re_lu6_grad", "relu6_grad"},
   {"re_lu", "relu"},
@@ -82,7 +80,7 @@ static std::map<string, string> tbe_func_adapter_map = {
   {"batch_to_space", "batch_to_space_d"},
   {"resize_bilinear", "resize_bilinear_v2_d"},
   {"resize_bilinear_grad", "resize_bilinear_v2_grad"},
-  {"adam", "apply_adam_d"},
+  {"adam", "apply_adam"},
   {"r_oi_align", "roi_align"},
   {"r_oi_align_grad", "roi_align_grad"}};
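A name-remapping table such as tbe_func_adapter_map is presumably consulted by looking up the framework op name and falling back to the name itself when no remapping exists. The standalone Python sketch below (not MindSpore source code; only the map entries visible in this diff are reproduced) illustrates that reading of the two hunks above: "apply_momentum" and "apply_ftrl" are no longer redirected to their *_d kernels, and "adam" now resolves to "apply_adam".

# Standalone sketch of a kernel-name adapter lookup (not MindSpore source code).
# Only the map entries visible in this diff are reproduced here.
tbe_func_adapter_map = {
    "softmax": "softmax_v2",
    "log_softmax": "log_softmax_v2",
    "adam": "apply_adam",  # was "apply_adam_d" before this commit
}

def to_tbe_kernel_name(op_name: str) -> str:
    """Map a framework op name to a TBE kernel name, falling back to the name itself."""
    return tbe_func_adapter_map.get(op_name, op_name)

assert to_tbe_kernel_name("adam") == "apply_adam"
assert to_tbe_kernel_name("apply_momentum") == "apply_momentum"  # no longer remapped to apply_momentum_d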
mindspore/ccsrc/transform/convert.cc

@@ -206,7 +206,7 @@ std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_ma
     {string(kNameIOU), ADPT_DESC(Iou)},
     {string(kNameGreaterEqual), ADPT_DESC(GreaterEqual)},
     {string(kNameSlice), ADPT_DESC(SliceD)},
-    {string(kNameApplyMomentum), ADPT_DESC(ApplyMomentumD)},
+    {string(kNameApplyMomentum), ADPT_DESC(ApplyMomentum)},
     {string(kNameMaxPool), ADPT_DESC(MaxPool)},
     {string(kNameAvgPool), ADPT_DESC(AvgPool)},
     {string(kNameMaxPoolWithArgmax), ADPT_DESC(MaxPoolWithArgmax)},
@@ -386,7 +386,7 @@ std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_ma
     {string(kNameDepthToSpace), ADPT_DESC(DepthToSpace)},
     {string(kNameSign), ADPT_DESC(Sign)},
     {string(kNameRound), ADPT_DESC(Round)},
-    {string(kNameApplyFtrl), ADPT_DESC(ApplyFtrlD)},
+    {string(kNameApplyFtrl), ADPT_DESC(ApplyFtrl)},
     {string(kNameDiag), ADPT_DESC(Diag)},
     {string(kNameDiagPart), ADPT_DESC(DiagPart)},
     {string(kNameSpaceToBatch), ADPT_DESC(SpaceToBatchD)},
@@ -398,7 +398,7 @@ std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_ma
     {string(kNameCTCLoss), ADPT_DESC(CTCLoss)}};
 #ifdef ENABLE_GE
   adpt_map[string(kNamePrint)] = ADPT_DESC(Print);
-  adpt_map[string(kNameApplyAdam)] = ADPT_DESC(ApplyAdamD);
+  adpt_map[string(kNameApplyAdam)] = ADPT_DESC(ApplyAdam);
 #endif
   return adpt_map;
 }
mindspore/ccsrc/transform/op_declare.cc

@@ -127,12 +127,11 @@ INPUT_MAP(Constant) = EMPTY_INPUT_MAP;
 ATTR_MAP(Constant) = {{"value", ATTR_DESC(value, AnyTraits<AnyValue>())}};
 OUTPUT_MAP(Constant) = {{0, OUTPUT_DESC(y)}};
 
-// ApplyMomentumD
-INPUT_MAP(ApplyMomentumD) = {
+// ApplyMomentum
+INPUT_MAP(ApplyMomentum) = {
   {1, INPUT_DESC(var)}, {2, INPUT_DESC(accum)}, {3, INPUT_DESC(lr)}, {4, INPUT_DESC(grad)}, {5, INPUT_DESC(momentum)}};
-ATTR_MAP(ApplyMomentumD) = {{"use_nesterov", ATTR_DESC(use_nesterov, AnyTraits<bool>())},
-                            {"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}};
-OUTPUT_MAP(ApplyMomentumD) = {{0, OUTPUT_DESC(var)}, {1, OUTPUT_DESC(accum)}};
+ATTR_MAP(ApplyMomentum) = {{"use_nesterov", ATTR_DESC(use_nesterov, AnyTraits<bool>())}};
+OUTPUT_MAP(ApplyMomentum) = {{0, OUTPUT_DESC(var)}};
 
 // ScalarSummary
 INPUT_MAP(Summary) = {{2, INPUT_DESC(x)}};
@@ -471,16 +470,7 @@ INPUT_MAP(ApplyAdam) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(m)},
                         {10, INPUT_DESC(grad)}};
 ATTR_MAP(ApplyAdam) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())},
                        {"use_nesterov", ATTR_DESC(use_nesterov, AnyTraits<bool>())}};
-OUTPUT_MAP(ApplyAdam) = {{0, OUTPUT_DESC(var)}};
-
-// ApplyAdamD
-INPUT_MAP(ApplyAdamD) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(m)}, {3, INPUT_DESC(v)},
-                         {4, INPUT_DESC(beta1_power)}, {5, INPUT_DESC(beta2_power)}, {6, INPUT_DESC(lr)},
-                         {7, INPUT_DESC(beta1)}, {8, INPUT_DESC(beta2)}, {9, INPUT_DESC(epsilon)},
-                         {10, INPUT_DESC(grad)}};
-ATTR_MAP(ApplyAdamD) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())},
-                        {"use_nesterov", ATTR_DESC(use_nesterov, AnyTraits<bool>())}};
-OUTPUT_MAP(ApplyAdamD) = {{0, OUTPUT_DESC(var)}, {1, OUTPUT_DESC(m)}, {2, OUTPUT_DESC(v)}};
+OUTPUT_MAP(ApplyAdam) = {{0, OUTPUT_DESC(var)}, {1, OUTPUT_DESC(m)}, {2, OUTPUT_DESC(v)}};
 
 // Relu6
 INPUT_MAP(Relu6) = {{1, INPUT_DESC(x)}};
@@ -1140,7 +1130,7 @@ INPUT_MAP(SparseApplyAdagradD) = {
   {1, INPUT_DESC(var)}, {2, INPUT_DESC(accum)}, {3, INPUT_DESC(grad)}, {4, INPUT_DESC(indices)}};
 ATTR_MAP(SparseApplyAdagradD) = {{"lr", ATTR_DESC(lr, AnyTraits<float>())},
                                  {"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}};
-OUTPUT_MAP(SparseApplyAdagradD) = {{0, OUTPUT_DESC(var)}, {1, OUTPUT_DESC(accum)}};
+OUTPUT_MAP(SparseApplyAdagradD) = {{0, OUTPUT_DESC(var)}};
 
 // SparseApplyFtrlD
 INPUT_MAP(SparseApplyFtrlD) = {{1, INPUT_DESC(var)},
@@ -1176,11 +1166,11 @@ ATTR_MAP(Round) = EMPTY_ATTR_MAP;
 OUTPUT_MAP(Round) = {{0, OUTPUT_DESC(y)}};
 
 // ApplyFtrl
-INPUT_MAP(ApplyFtrlD) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(accum)}, {3, INPUT_DESC(linear)},
-                         {4, INPUT_DESC(grad)}, {5, INPUT_DESC(lr)}, {6, INPUT_DESC(l1)},
-                         {7, INPUT_DESC(l2)}, {8, INPUT_DESC(lr_power)}};
-ATTR_MAP(ApplyFtrlD) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}};
-OUTPUT_MAP(ApplyFtrlD) = {{0, OUTPUT_DESC(var)}, {1, OUTPUT_DESC(accum)}, {2, OUTPUT_DESC(linear)}};
+INPUT_MAP(ApplyFtrl) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(accum)}, {3, INPUT_DESC(linear)},
+                        {4, INPUT_DESC(grad)}, {5, INPUT_DESC(lr)}, {6, INPUT_DESC(l1)},
+                        {7, INPUT_DESC(l2)}, {8, INPUT_DESC(lr_power)}};
+ATTR_MAP(ApplyFtrl) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}};
+OUTPUT_MAP(ApplyFtrl) = {{0, OUTPUT_DESC(var)}};
 
 // Diag
 INPUT_MAP(Diag) = {{1, INPUT_DESC(x)}};
mindspore/ccsrc/transform/op_declare.h

@@ -120,8 +120,6 @@ DECLARE_OP_ADAPTER(ResizeNearestNeighborV2Grad)
 DECLARE_OP_USE_OUTPUT(ResizeNearestNeighborV2Grad)
 DECLARE_OP_ADAPTER(ApplyAdam)
 DECLARE_OP_USE_OUTPUT(ApplyAdam)
-DECLARE_OP_ADAPTER(ApplyAdamD)
-DECLARE_OP_USE_OUTPUT(ApplyAdamD)
 DECLARE_OP_ADAPTER(Relu6)
 DECLARE_OP_USE_OUTPUT(Relu6)
 DECLARE_OP_ADAPTER(Relu6Grad)
@@ -315,8 +313,8 @@ DECLARE_OP_ADAPTER(Assign)
 DECLARE_OP_USE_OUTPUT(Assign)
 DECLARE_OP_ADAPTER(Constant)
 DECLARE_OP_USE_OUTPUT(Constant)
-DECLARE_OP_ADAPTER(ApplyMomentumD)
-DECLARE_OP_USE_OUTPUT(ApplyMomentumD)
+DECLARE_OP_ADAPTER(ApplyMomentum)
+DECLARE_OP_USE_OUTPUT(ApplyMomentum)
 // ** Summary Operations **
 DECLARE_OP_ADAPTER(Summary)
@@ -446,8 +444,8 @@ DECLARE_OP_ADAPTER(LarsV2Update)
 DECLARE_OP_USE_OUTPUT(LarsV2Update)
 DECLARE_OP_ADAPTER(Round)
 DECLARE_OP_USE_OUTPUT(Round)
-DECLARE_OP_ADAPTER(ApplyFtrlD)
-DECLARE_OP_USE_OUTPUT(ApplyFtrlD)
+DECLARE_OP_ADAPTER(ApplyFtrl)
+DECLARE_OP_USE_OUTPUT(ApplyFtrl)
 DECLARE_OP_ADAPTER(SparseApplyFtrlD)
 DECLARE_OP_USE_OUTPUT(SparseApplyFtrlD)
 DECLARE_OP_ADAPTER(Diag)
mindspore/ops/_op_impl/tbe/apply_ftrl.py

@@ -32,32 +32,30 @@ apply_ftrl_op_info = TBERegOp("ApplyFtrl") \
     .input(6, "l2", False, "required", "all") \
     .input(7, "lr_power", False, "required", "all") \
     .output(0, "var", False, "required", "all") \
-    .output(1, "accum", False, "required", "all") \
-    .output(2, "linear", False, "required", "all") \
     .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD,
                   DataType.F16_5HD, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
-                  DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \
+                  DataType.F16_5HD) \
     .dtype_format(DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_FracZ,
                   DataType.F16_FracZ, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
-                  DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_FracZ) \
+                  DataType.F16_FracZ) \
    .dtype_format(DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0,
                   DataType.F16_C1HWNCoC0, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
-                  DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0) \
+                  DataType.F16_C1HWNCoC0) \
     .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
                   DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
-                  DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
+                  DataType.F16_Default) \
     .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
                   DataType.F32_5HD, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
-                  DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
+                  DataType.F32_5HD) \
     .dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ,
                   DataType.F32_FracZ, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
-                  DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ) \
+                  DataType.F32_FracZ) \
     .dtype_format(DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0,
                   DataType.F32_C1HWNCoC0, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
-                  DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0) \
+                  DataType.F32_C1HWNCoC0) \
     .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
                   DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
-                  DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
+                  DataType.F32_Default) \
     .get_op_info()
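A quick way to read the dtype_format changes above: each dtype_format row of a TBERegOp lists one DataType per declared input and output, in order, as the row widths here suggest. Dropping the accum and linear outputs therefore also drops the two trailing dtype columns from every row. A small sanity-check sketch (plain Python with numbers taken from this diff, not MindSpore code):

# Hypothetical arity check for the ApplyFtrl TBE registration shown above.
num_inputs = 8           # var, accum, linear, grad, lr, l1, l2, lr_power
num_outputs_before = 3   # var, accum, linear
num_outputs_after = 1    # var
assert num_inputs + num_outputs_before == 11   # dtype entries per row before this commit
assert num_inputs + num_outputs_after == 9     # dtype entries per row after this commit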
mindspore/ops/_op_impl/tbe/apply_momentum.py

@@ -30,23 +30,22 @@ apply_momentum_op_info = TBERegOp("ApplyMomentum") \
     .input(3, "grad", False, "required", "all") \
     .input(4, "momentum", False, "required", "all") \
     .output(0, "var", False, "required", "all") \
-    .output(1, "accum", False, "required", "all") \
     .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
-                  DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
+                  DataType.F16_Default, DataType.F16_Default) \
     .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_Default, DataType.F16_5HD,
-                  DataType.F16_Default, DataType.F16_5HD, DataType.F16_5HD) \
+                  DataType.F16_Default, DataType.F16_5HD) \
     .dtype_format(DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_Default, DataType.F16_C1HWNCoC0,
-                  DataType.F16_Default, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0) \
+                  DataType.F16_Default, DataType.F16_C1HWNCoC0) \
     .dtype_format(DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_Default, DataType.F16_FracZ,
-                  DataType.F16_Default, DataType.F16_FracZ, DataType.F16_FracZ) \
+                  DataType.F16_Default, DataType.F16_FracZ) \
     .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
-                  DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
+                  DataType.F32_Default, DataType.F32_Default) \
     .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_Default, DataType.F32_5HD,
-                  DataType.F32_Default, DataType.F32_5HD, DataType.F32_5HD) \
+                  DataType.F32_Default, DataType.F32_5HD) \
     .dtype_format(DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_Default, DataType.F32_C1HWNCoC0,
-                  DataType.F32_Default, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0) \
+                  DataType.F32_Default, DataType.F32_C1HWNCoC0) \
     .dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_Default, DataType.F32_FracZ,
-                  DataType.F32_Default, DataType.F32_FracZ, DataType.F32_FracZ) \
+                  DataType.F32_Default, DataType.F32_FracZ) \
     .get_op_info()
mindspore/ops/operations/nn_ops.py

@@ -1465,11 +1465,8 @@ class ApplyMomentum(PrimitiveWithInfer):
     def __init__(self, use_nesterov=False, use_locking=False, gradient_scale=1.0):
         self.init_prim_io_names(inputs=['variable', 'accumulation', 'learning_rate', 'gradient', 'momentum'],
                                 outputs=['output'])
-        self.is_tbe = context.get_context("device_target") == "Ascend"
 
     def infer_shape(self, v_shape, a_shape, l_shape, g_shape, m_shape):
-        if self.is_tbe:
-            return v_shape, v_shape
         return v_shape
 
     def infer_dtype(self, v_dtype, a_dtype, l_dtype, g_dtype, m_dtype):
@@ -1480,8 +1477,6 @@ class ApplyMomentum(PrimitiveWithInfer):
         validator.check_scalar_or_tensor_type_same({"l_dtype": l_dtype}, valid_types, self.name)
         validator.check_scalar_or_tensor_type_same({"g_dtype": g_dtype}, valid_types, self.name)
         validator.check_scalar_or_tensor_type_same({"m_dtype": m_dtype}, valid_types, self.name)
-        if self.is_tbe:
-            return g_dtype, g_dtype
         return g_dtype
@@ -2622,13 +2617,13 @@ class SparseApplyAdagrad(PrimitiveWithInfer):
         validator.check('var_shape[1:]', var_shape[1:], 'grad_shape[1:]', grad_shape[1:], Rel.EQ, self.name)
         validator.check_integer("indices rank", len(indices_shape), 1, Rel.EQ, self.name)
         validator.check('grad_shape[0]', grad_shape[0], 'indices_shape[0]', indices_shape[0], Rel.EQ, self.name)
-        return var_shape, accum_shape
+        return var_shape
 
     def infer_dtype(self, var_type, accum_type, grad_type, indices_type):
         args = {'var': var_type, 'accum': accum_type, 'grad': grad_type}
         validator.check_tensor_type_same(args, (mstype.float32,), self.name)
         validator.check_tensor_type_same({'indices': indices_type}, [mstype.int32], self.name)
-        return var_type, accum_type
+        return var_type
 
 class LARSUpdate(PrimitiveWithInfer):
@@ -2737,14 +2732,11 @@ class ApplyFtrl(PrimitiveWithInfer):
         self.init_prim_io_names(inputs=['var', 'accum', 'linear', 'grad', 'lr', 'l1', 'l2', 'lr_power'],
                                 outputs=['output'])
         self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name)
-        self.is_tbe = context.get_context("device_target") == "Ascend"
 
     def infer_shape(self, var_shape, accum_shape, linear_shape, grad_shape, lr_shape, l1_shape, l2_shape,
                     lr_power_shape):
         validator.check('var shape', var_shape, 'accum shape', accum_shape, Rel.EQ, self.name)
         validator.check('var shape', var_shape, 'linear shape', linear_shape, Rel.EQ, self.name)
-        if self.is_tbe:
-            return var_shape, var_shape, var_shape
         return var_shape
 
     def infer_dtype(self, var_type, accum_type, linear_type, grad_type, lr_type, l1_type, l2_type, lr_power_type):
@@ -2756,8 +2748,6 @@ class ApplyFtrl(PrimitiveWithInfer):
         validator.check_scalar_or_tensor_type_same({"l1": l1_type}, valid_types, self.name)
         validator.check_scalar_or_tensor_type_same({"l2": l2_type}, valid_types, self.name)
         validator.check_scalar_or_tensor_type_same({"lr_power": lr_power_type}, valid_types, self.name)
-        if self.is_tbe:
-            return var_type, var_type, var_type
         return var_type
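The nn_ops.py hunks remove the Ascend-only branches that advertised extra outputs during shape and dtype inference. A standalone before/after illustration of the ApplyMomentum infer_shape change (plain Python, not MindSpore code; the function names are hypothetical):

# Before: on the Ascend/TBE backend the primitive advertised two outputs (var, accum).
def infer_shape_before(v_shape, is_tbe):
    if is_tbe:
        return v_shape, v_shape
    return v_shape

# After: a single output shape on every backend, which is why the test change
# below drops the [0] indexing on the optimizer's result.
def infer_shape_after(v_shape):
    return v_shape

assert infer_shape_before([2, 3], is_tbe=True) == ([2, 3], [2, 3])
assert infer_shape_after([2, 3]) == [2, 3]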
tests/ut/python/ops/test_momentum.py

@@ -38,7 +38,7 @@ def tensor_run_opt(opt, iters, learning_rate, momentum,
                    gradient, variable, moment):
     """ tensor_run_opt """
     success = True
-    new_weight = opt(variable, moment, learning_rate, gradient, momentum)[0]
+    new_weight = opt(variable, moment, learning_rate, gradient, momentum)
     success = F.depend(success, F.assign(variable, new_weight))
     return success
tests/ut/python/ops/test_ops.py

@@ -853,7 +853,6 @@ test_case_nn_ops = [
     ('SparseApplyAdagrad', {
         'block': P.SparseApplyAdagrad(0.5),
         'desc_inputs': [[3, 3], [3, 3], [3, 3], Tensor(np.ones((3,), np.int32))],
-        'desc_bprop': [[3, 3], [3, 3]],
         'skip': ['backward']}),
     ('Flatten_1', {
         'block': NetForFlatten(),