Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
c22c865c
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
c22c865c
编写于
6月 24, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
6月 24, 2020
浏览文件
操作
浏览文件
下载
差异文件
!1950 [front]add `has_effect` direct to target construct method not cell object
Merge pull request !1950 from vlne-v1/move-flags-to-construct
上级
ecff9f4e
8f56528f
变更
15
隐藏空白更改
内联
并排
Showing
15 changed file
with
44 addition
and
34 deletion
+44
-34
.gitignore
.gitignore
+1
-0
mindspore/ccsrc/ir/primitive.cc
mindspore/ccsrc/ir/primitive.cc
+1
-1
mindspore/ccsrc/ir/tensor.cc
mindspore/ccsrc/ir/tensor.cc
+2
-2
mindspore/ccsrc/pipeline/parse/parse.cc
mindspore/ccsrc/pipeline/parse/parse.cc
+8
-5
mindspore/ccsrc/pipeline/parse/parse.h
mindspore/ccsrc/pipeline/parse/parse.h
+3
-3
mindspore/nn/layer/basic.py
mindspore/nn/layer/basic.py
+1
-2
mindspore/nn/wrap/loss_scale.py
mindspore/nn/wrap/loss_scale.py
+1
-1
mindspore/ops/composite/base.py
mindspore/ops/composite/base.py
+15
-9
model_zoo/Transformer/src/transformer_for_train.py
model_zoo/Transformer/src/transformer_for_train.py
+1
-1
model_zoo/bert/src/bert_for_pre_training.py
model_zoo/bert/src/bert_for_pre_training.py
+3
-4
model_zoo/deeplabv3/src/deeplabv3.py
model_zoo/deeplabv3/src/deeplabv3.py
+3
-1
tests/st/networks/models/bert/src/bert_for_pre_training.py
tests/st/networks/models/bert/src/bert_for_pre_training.py
+1
-1
tests/ut/python/keep_order/test_keep_order.py
tests/ut/python/keep_order/test_keep_order.py
+1
-1
tests/ut/python/ops/test_math_ops.py
tests/ut/python/ops/test_math_ops.py
+1
-1
tests/ut/python/pynative_mode/ge/model/test_lenet_model.py
tests/ut/python/pynative_mode/ge/model/test_lenet_model.py
+2
-2
未找到文件。
.gitignore
浏览文件 @
c22c865c
...
@@ -26,6 +26,7 @@ cmake-build-debug
...
@@ -26,6 +26,7 @@ cmake-build-debug
*_pb2.py
*_pb2.py
*.pb.h
*.pb.h
*.pb.cc
*.pb.cc
*.pb
# Object files
# Object files
*.o
*.o
...
...
mindspore/ccsrc/ir/primitive.cc
浏览文件 @
c22c865c
...
@@ -86,7 +86,7 @@ void PrimitivePy::AddPyAttr(const py::str &name, const py::object &obj) {
...
@@ -86,7 +86,7 @@ void PrimitivePy::AddPyAttr(const py::str &name, const py::object &obj) {
}
}
bool
converted
=
parse
::
ConvertData
(
obj
,
&
converted_ret
);
bool
converted
=
parse
::
ConvertData
(
obj
,
&
converted_ret
);
if
(
!
converted
)
{
if
(
!
converted
)
{
MS_LOG
(
EXCEPTION
)
<<
"Attribute convert error with type:"
<<
std
::
string
(
py
::
str
(
obj
));
MS_LOG
(
EXCEPTION
)
<<
"Attribute convert error with type:
"
<<
std
::
string
(
py
::
str
(
obj
));
}
}
(
void
)
this
->
AddAttr
(
attr_name
,
converted_ret
);
(
void
)
this
->
AddAttr
(
attr_name
,
converted_ret
);
}
}
...
...
mindspore/ccsrc/ir/tensor.cc
浏览文件 @
c22c865c
...
@@ -345,14 +345,14 @@ abstract::AbstractBasePtr Tensor::ToAbstract() {
...
@@ -345,14 +345,14 @@ abstract::AbstractBasePtr Tensor::ToAbstract() {
std
::
string
Tensor
::
GetShapeAndDataTypeInfo
()
const
{
std
::
string
Tensor
::
GetShapeAndDataTypeInfo
()
const
{
std
::
ostringstream
buf
;
std
::
ostringstream
buf
;
buf
<<
"Tensor
\n
shape:["
<<
shape
()
<<
"]"
<<
this
->
Dtype
()
->
ToString
();
buf
<<
"Tensor shape:["
<<
shape
()
<<
"]"
<<
this
->
Dtype
()
->
ToString
();
return
buf
.
str
();
return
buf
.
str
();
}
}
std
::
string
Tensor
::
ToString
()
const
{
std
::
string
Tensor
::
ToString
()
const
{
const
int
small_tensor_size
=
30
;
const
int
small_tensor_size
=
30
;
std
::
ostringstream
buf
;
std
::
ostringstream
buf
;
buf
<<
"Tensor
\n
shape:["
<<
shape
()
<<
"]"
<<
this
->
Dtype
()
->
ToString
();
buf
<<
"Tensor shape:["
<<
shape
()
<<
"]"
<<
this
->
Dtype
()
->
ToString
();
// only print small tensor
// only print small tensor
if
(
DataSize
()
<
small_tensor_size
)
{
if
(
DataSize
()
<
small_tensor_size
)
{
buf
<<
"val:"
<<
std
::
string
(
py
::
str
(
data
()));
buf
<<
"val:"
<<
std
::
string
(
py
::
str
(
data
()));
...
...
mindspore/ccsrc/pipeline/parse/parse.cc
浏览文件 @
c22c865c
...
@@ -234,7 +234,11 @@ FunctionBlockPtr Parser::ParseFunction(const py::object &node, const FunctionBlo
...
@@ -234,7 +234,11 @@ FunctionBlockPtr Parser::ParseFunction(const py::object &node, const FunctionBlo
current_fg
->
debug_info
()
->
set_deco_location
(
GetLocation
(
deco_list
));
current_fg
->
debug_info
()
->
set_deco_location
(
GetLocation
(
deco_list
));
}
}
bool
set_flag
=
ast_
->
UpdateFuncGraphFlags
(
current_fg
);
bool
set_flag
=
UpdateFuncGraphFlags
(
ast_
->
function
(),
current_fg
);
if
(
ast_
->
obj
()
!=
ast_
->
function
())
{
set_flag
=
set_flag
&&
UpdateFuncGraphFlags
(
ast_
->
obj
(),
current_fg
);
}
if
(
!
set_flag
)
{
if
(
!
set_flag
)
{
MS_LOG
(
ERROR
)
<<
"Set flags failed"
;
MS_LOG
(
ERROR
)
<<
"Set flags failed"
;
return
nullptr
;
return
nullptr
;
...
@@ -1436,17 +1440,17 @@ bool ParseAst::IsClassMember(const py::object &node) {
...
@@ -1436,17 +1440,17 @@ bool ParseAst::IsClassMember(const py::object &node) {
return
ret
.
cast
<
bool
>
();
return
ret
.
cast
<
bool
>
();
}
}
bool
ParseAst
::
UpdateFuncGraphFlags
(
const
FuncGraphPtr
&
func_graph
)
{
bool
UpdateFuncGraphFlags
(
py
::
object
obj
,
const
FuncGraphPtr
&
func_graph
)
{
if
(
func_graph
==
nullptr
)
{
if
(
func_graph
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"FuncGraph is null"
;
MS_LOG
(
ERROR
)
<<
"FuncGraph is null"
;
return
false
;
return
false
;
}
}
if
(
!
py
::
hasattr
(
obj
_
,
PYTHON_EXTERN_MINDSPORE_FLAG
))
{
if
(
!
py
::
hasattr
(
obj
,
PYTHON_EXTERN_MINDSPORE_FLAG
))
{
MS_LOG
(
DEBUG
)
<<
"No flags"
;
MS_LOG
(
DEBUG
)
<<
"No flags"
;
return
true
;
return
true
;
}
}
py
::
dict
flags
=
python_adapter
::
GetPyObjAttr
(
obj
_
,
PYTHON_EXTERN_MINDSPORE_FLAG
);
py
::
dict
flags
=
python_adapter
::
GetPyObjAttr
(
obj
,
PYTHON_EXTERN_MINDSPORE_FLAG
);
for
(
auto
&
item
:
flags
)
{
for
(
auto
&
item
:
flags
)
{
if
(
!
py
::
isinstance
<
py
::
str
>
(
item
.
first
))
{
if
(
!
py
::
isinstance
<
py
::
str
>
(
item
.
first
))
{
MS_LOG
(
ERROR
)
<<
"Type error in flags dict convert"
;
MS_LOG
(
ERROR
)
<<
"Type error in flags dict convert"
;
...
@@ -1466,7 +1470,6 @@ bool ParseAst::UpdateFuncGraphFlags(const FuncGraphPtr &func_graph) {
...
@@ -1466,7 +1470,6 @@ bool ParseAst::UpdateFuncGraphFlags(const FuncGraphPtr &func_graph) {
return
false
;
return
false
;
}
}
}
}
return
true
;
return
true
;
}
}
...
...
mindspore/ccsrc/pipeline/parse/parse.h
浏览文件 @
c22c865c
...
@@ -327,9 +327,6 @@ class ParseAst {
...
@@ -327,9 +327,6 @@ class ParseAst {
bool
IsClassMember
(
const
py
::
object
&
node
);
bool
IsClassMember
(
const
py
::
object
&
node
);
// update the graph flags
bool
UpdateFuncGraphFlags
(
const
FuncGraphPtr
&
func_graph
);
private:
private:
// save obj,eg: class instance or function
// save obj,eg: class instance or function
py
::
object
obj_
;
py
::
object
obj_
;
...
@@ -350,6 +347,9 @@ class ParseAst {
...
@@ -350,6 +347,9 @@ class ParseAst {
int
function_line_offset_
;
int
function_line_offset_
;
};
};
// update the graph flags
bool
UpdateFuncGraphFlags
(
py
::
object
obj
,
const
FuncGraphPtr
&
func_graph
);
AnfNodePtr
GetMixedPrecisionCastHelp
(
const
FuncGraphPtr
&
func_graph
,
const
AnfNodePtr
&
param
);
AnfNodePtr
GetMixedPrecisionCastHelp
(
const
FuncGraphPtr
&
func_graph
,
const
AnfNodePtr
&
param
);
}
// namespace parse
}
// namespace parse
...
...
mindspore/nn/layer/basic.py
浏览文件 @
c22c865c
...
@@ -284,7 +284,6 @@ class ClipByNorm(Cell):
...
@@ -284,7 +284,6 @@ class ClipByNorm(Cell):
self
.
reduce_sum
=
P
.
ReduceSum
(
keep_dims
=
True
)
self
.
reduce_sum
=
P
.
ReduceSum
(
keep_dims
=
True
)
self
.
select_
=
P
.
Select
()
self
.
select_
=
P
.
Select
()
self
.
greater_
=
P
.
Greater
()
self
.
greater_
=
P
.
Greater
()
self
.
axis
=
()
self
.
cast
=
P
.
Cast
()
self
.
cast
=
P
.
Cast
()
self
.
zero
=
Tensor
(
np
.
array
([
0.0
]).
astype
(
np
.
float32
))
self
.
zero
=
Tensor
(
np
.
array
([
0.0
]).
astype
(
np
.
float32
))
self
.
sqrt
=
P
.
Sqrt
()
self
.
sqrt
=
P
.
Sqrt
()
...
@@ -299,7 +298,7 @@ class ClipByNorm(Cell):
...
@@ -299,7 +298,7 @@ class ClipByNorm(Cell):
def
construct
(
self
,
x
,
clip_norm
):
def
construct
(
self
,
x
,
clip_norm
):
"""add ms_function decorator for pynative mode"""
"""add ms_function decorator for pynative mode"""
mul_x
=
F
.
square
(
x
)
mul_x
=
F
.
square
(
x
)
l2sum
=
self
.
cast
(
self
.
reduce_sum
(
mul_x
,
self
.
axis
),
mstype
.
float32
)
l2sum
=
self
.
cast
(
self
.
reduce_sum
(
mul_x
),
mstype
.
float32
)
cond
=
self
.
greater_
(
l2sum
,
self
.
zero
)
cond
=
self
.
greater_
(
l2sum
,
self
.
zero
)
ones_
=
self
.
fill
(
self
.
dtype
(
cond
),
self
.
shape
(
cond
),
1.0
)
ones_
=
self
.
fill
(
self
.
dtype
(
cond
),
self
.
shape
(
cond
),
1.0
)
...
...
mindspore/nn/wrap/loss_scale.py
浏览文件 @
c22c865c
...
@@ -234,8 +234,8 @@ class TrainOneStepWithLossScaleCell(Cell):
...
@@ -234,8 +234,8 @@ class TrainOneStepWithLossScaleCell(Cell):
if
scale_update_cell
:
if
scale_update_cell
:
self
.
loss_scale
=
Parameter
(
Tensor
(
scale_update_cell
.
get_loss_scale
(),
dtype
=
mstype
.
float32
),
self
.
loss_scale
=
Parameter
(
Tensor
(
scale_update_cell
.
get_loss_scale
(),
dtype
=
mstype
.
float32
),
name
=
"loss_scale"
)
name
=
"loss_scale"
)
self
.
add_flags
(
has_effect
=
True
)
@
C
.
add_flags
(
has_effect
=
True
)
def
construct
(
self
,
data
,
label
,
sens
=
None
):
def
construct
(
self
,
data
,
label
,
sens
=
None
):
weights
=
self
.
weights
weights
=
self
.
weights
loss
=
self
.
network
(
data
,
label
)
loss
=
self
.
network
(
data
,
label
)
...
...
mindspore/ops/composite/base.py
浏览文件 @
c22c865c
...
@@ -30,16 +30,16 @@ from ...common.parameter import Parameter
...
@@ -30,16 +30,16 @@ from ...common.parameter import Parameter
__all__
=
[
EnvInstance_
,
TupleAdd_
,
TupleSlice_
,
UnpackCall_
,
TupleGetItemTensor_
]
__all__
=
[
EnvInstance_
,
TupleAdd_
,
TupleSlice_
,
UnpackCall_
,
TupleGetItemTensor_
]
def
add_flags
(
fn
,
**
flags
):
def
add_flags
(
fn
=
None
,
**
flags
):
"""
"""
An
interface
to add flag for a function.
An
decorator
to add flag for a function.
Note:
Note:
Only supports bool value.
Only supports bool value.
Args:
Args:
fn (Function): Function or cell to add flag.
fn (Function): Function or cell to add flag.
Default: None.
flags (
bool): Flags use kwargs
.
flags (
dict): Flags use kwargs. Default: None
.
Returns:
Returns:
Function, the fn added flags.
Function, the fn added flags.
...
@@ -47,11 +47,17 @@ def add_flags(fn, **flags):
...
@@ -47,11 +47,17 @@ def add_flags(fn, **flags):
Examples:
Examples:
>>> add_flags(net, predit=True)
>>> add_flags(net, predit=True)
"""
"""
# need set the attr and access on c++
def
deco
(
fn
):
if
not
hasattr
(
fn
,
"_mindspore_flags"
):
# need set the attr and access on c++
fn
.
_mindspore_flags
=
{}
if
not
hasattr
(
fn
,
"_mindspore_flags"
):
fn
.
_mindspore_flags
.
update
({
**
flags
})
fn
.
_mindspore_flags
=
{}
return
fn
fn
.
_mindspore_flags
.
update
({
**
flags
})
return
fn
ret
=
deco
if
fn
is
not
None
:
ret
=
deco
(
fn
)
return
ret
def
core
(
fn
=
None
,
**
flags
):
def
core
(
fn
=
None
,
**
flags
):
...
...
model_zoo/Transformer/src/transformer_for_train.py
浏览文件 @
c22c865c
...
@@ -277,8 +277,8 @@ class TransformerTrainOneStepWithLossScaleCell(nn.Cell):
...
@@ -277,8 +277,8 @@ class TransformerTrainOneStepWithLossScaleCell(nn.Cell):
if
scale_update_cell
:
if
scale_update_cell
:
self
.
loss_scale
=
Parameter
(
Tensor
(
scale_update_cell
.
get_loss_scale
(),
dtype
=
mstype
.
float32
),
self
.
loss_scale
=
Parameter
(
Tensor
(
scale_update_cell
.
get_loss_scale
(),
dtype
=
mstype
.
float32
),
name
=
"loss_scale"
)
name
=
"loss_scale"
)
self
.
add_flags
(
has_effect
=
True
)
@
C
.
add_flags
(
has_effect
=
True
)
def
construct
(
self
,
def
construct
(
self
,
source_eos_ids
,
source_eos_ids
,
source_eos_mask
,
source_eos_mask
,
...
...
model_zoo/bert/src/bert_for_pre_training.py
浏览文件 @
c22c865c
...
@@ -132,9 +132,9 @@ class GetNextSentenceOutput(nn.Cell):
...
@@ -132,9 +132,9 @@ class GetNextSentenceOutput(nn.Cell):
def
__init__
(
self
,
config
):
def
__init__
(
self
,
config
):
super
(
GetNextSentenceOutput
,
self
).
__init__
()
super
(
GetNextSentenceOutput
,
self
).
__init__
()
self
.
log_softmax
=
_selected_ops
.
LogSoftmax
()
self
.
log_softmax
=
_selected_ops
.
LogSoftmax
()
self
.
weight_init
=
TruncatedNormal
(
config
.
initializer_range
)
weight_init
=
TruncatedNormal
(
config
.
initializer_range
)
self
.
dense
=
nn
.
Dense
(
config
.
hidden_size
,
2
,
self
.
dense
=
nn
.
Dense
(
config
.
hidden_size
,
2
,
weight_init
=
self
.
weight_init
,
has_bias
=
True
).
to_float
(
config
.
compute_type
)
weight_init
=
weight_init
,
has_bias
=
True
).
to_float
(
config
.
compute_type
)
self
.
dtype
=
config
.
dtype
self
.
dtype
=
config
.
dtype
self
.
cast
=
P
.
Cast
()
self
.
cast
=
P
.
Cast
()
...
@@ -321,7 +321,6 @@ class BertTrainOneStepCell(nn.Cell):
...
@@ -321,7 +321,6 @@ class BertTrainOneStepCell(nn.Cell):
if
self
.
reducer_flag
:
if
self
.
reducer_flag
:
# apply grad reducer on grads
# apply grad reducer on grads
grads
=
self
.
grad_reducer
(
grads
)
grads
=
self
.
grad_reducer
(
grads
)
succ
=
self
.
optimizer
(
grads
)
succ
=
self
.
optimizer
(
grads
)
return
F
.
depend
(
loss
,
succ
)
return
F
.
depend
(
loss
,
succ
)
...
@@ -380,8 +379,8 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell):
...
@@ -380,8 +379,8 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell):
if
scale_update_cell
:
if
scale_update_cell
:
self
.
loss_scale
=
Parameter
(
Tensor
(
scale_update_cell
.
get_loss_scale
(),
dtype
=
mstype
.
float32
),
self
.
loss_scale
=
Parameter
(
Tensor
(
scale_update_cell
.
get_loss_scale
(),
dtype
=
mstype
.
float32
),
name
=
"loss_scale"
)
name
=
"loss_scale"
)
self
.
add_flags
(
has_effect
=
True
)
@
C
.
add_flags
(
has_effect
=
True
)
def
construct
(
self
,
def
construct
(
self
,
input_ids
,
input_ids
,
input_mask
,
input_mask
,
...
...
model_zoo/deeplabv3/src/deeplabv3.py
浏览文件 @
c22c865c
...
@@ -17,6 +17,7 @@
...
@@ -17,6 +17,7 @@
import
numpy
as
np
import
numpy
as
np
import
mindspore.nn
as
nn
import
mindspore.nn
as
nn
from
mindspore.ops
import
operations
as
P
from
mindspore.ops
import
operations
as
P
from
mindspore.ops.composite
import
add_flags
from
.backbone.resnet_deeplab
import
_conv_bn_relu
,
resnet50_dl
,
_deep_conv_bn_relu
,
\
from
.backbone.resnet_deeplab
import
_conv_bn_relu
,
resnet50_dl
,
_deep_conv_bn_relu
,
\
DepthwiseConv2dNative
,
SpaceToBatch
,
BatchToSpace
DepthwiseConv2dNative
,
SpaceToBatch
,
BatchToSpace
...
@@ -121,6 +122,7 @@ class ASPP(nn.Cell):
...
@@ -121,6 +122,7 @@ class ASPP(nn.Cell):
self
.
feature_shape
=
feature_shape
self
.
feature_shape
=
feature_shape
self
.
concat
=
P
.
Concat
(
axis
=
1
)
self
.
concat
=
P
.
Concat
(
axis
=
1
)
@
add_flags
(
loop_can_unroll
=
True
)
def
construct
(
self
,
x
,
scale_index
=
0
):
def
construct
(
self
,
x
,
scale_index
=
0
):
aspp0
=
self
.
aspp0
(
x
)
aspp0
=
self
.
aspp0
(
x
)
aspp1
=
self
.
global_poolings
[
scale_index
](
x
)
aspp1
=
self
.
global_poolings
[
scale_index
](
x
)
...
@@ -276,7 +278,7 @@ class SingleDeepLabV3(nn.Cell):
...
@@ -276,7 +278,7 @@ class SingleDeepLabV3(nn.Cell):
atrous_rates
=
atrous_rates
,
atrous_rates
=
atrous_rates
,
output_stride
=
output_stride
,
output_stride
=
output_stride
,
fine_tune_batch_norm
=
fine_tune_batch_norm
)
fine_tune_batch_norm
=
fine_tune_batch_norm
)
self
.
aspp
.
add_flags
(
loop_can_unroll
=
True
)
atrous_rates_len
=
0
atrous_rates_len
=
0
if
atrous_rates
is
not
None
:
if
atrous_rates
is
not
None
:
atrous_rates_len
=
len
(
atrous_rates
)
atrous_rates_len
=
len
(
atrous_rates
)
...
...
tests/st/networks/models/bert/src/bert_for_pre_training.py
浏览文件 @
c22c865c
...
@@ -379,8 +379,8 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell):
...
@@ -379,8 +379,8 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell):
if
scale_update_cell
:
if
scale_update_cell
:
self
.
loss_scale
=
Parameter
(
Tensor
(
scale_update_cell
.
get_loss_scale
(),
dtype
=
mstype
.
float32
),
self
.
loss_scale
=
Parameter
(
Tensor
(
scale_update_cell
.
get_loss_scale
(),
dtype
=
mstype
.
float32
),
name
=
"loss_scale"
)
name
=
"loss_scale"
)
self
.
add_flags
(
has_effect
=
True
)
@
C
.
add_flags
(
has_effect
=
True
)
def
construct
(
self
,
def
construct
(
self
,
input_ids
,
input_ids
,
input_mask
,
input_mask
,
...
...
tests/ut/python/keep_order/test_keep_order.py
浏览文件 @
c22c865c
...
@@ -133,8 +133,8 @@ def test_keep_order_io_effect_exception_return_dtype():
...
@@ -133,8 +133,8 @@ def test_keep_order_io_effect_exception_return_dtype():
self
.
dtype
=
P
.
DType
()
self
.
dtype
=
P
.
DType
()
self
.
sub
=
P
.
Sub
()
self
.
sub
=
P
.
Sub
()
self
.
neg
=
P
.
Neg
()
self
.
neg
=
P
.
Neg
()
self
.
add_flags
(
has_effect
=
True
)
@
C
.
add_flags
(
has_effect
=
True
)
def
construct
(
self
,
x
):
def
construct
(
self
,
x
):
init
=
self
.
alloc_status
()
init
=
self
.
alloc_status
()
self
.
clear_status
(
init
)
self
.
clear_status
(
init
)
...
...
tests/ut/python/ops/test_math_ops.py
浏览文件 @
c22c865c
...
@@ -268,8 +268,8 @@ class NpuFloatNet(nn.Cell):
...
@@ -268,8 +268,8 @@ class NpuFloatNet(nn.Cell):
self
.
reduce_sum
=
P
.
ReduceSum
(
keep_dims
=
True
)
self
.
reduce_sum
=
P
.
ReduceSum
(
keep_dims
=
True
)
self
.
sub
=
P
.
Sub
()
self
.
sub
=
P
.
Sub
()
self
.
neg
=
P
.
Neg
()
self
.
neg
=
P
.
Neg
()
self
.
add_flags
(
has_effect
=
True
)
@
C
.
add_flags
(
has_effect
=
True
)
def
construct
(
self
,
x
):
def
construct
(
self
,
x
):
init
=
self
.
alloc_status
()
init
=
self
.
alloc_status
()
self
.
clear_status
(
init
)
self
.
clear_status
(
init
)
...
...
tests/ut/python/pynative_mode/ge/model/test_lenet_model.py
浏览文件 @
c22c865c
...
@@ -14,13 +14,13 @@
...
@@ -14,13 +14,13 @@
# ============================================================================
# ============================================================================
""" test_lenet_model """
""" test_lenet_model """
import
numpy
as
np
import
numpy
as
np
import
pytest
import
mindspore.nn
as
nn
import
mindspore.nn
as
nn
from
mindspore.common.tensor
import
Tensor
from
mindspore.common.tensor
import
Tensor
from
mindspore.nn
import
WithGradCell
,
WithLossCell
from
mindspore.nn
import
WithGradCell
,
WithLossCell
from
mindspore.nn.optim
import
Momentum
from
mindspore.nn.optim
import
Momentum
from
mindspore.ops
import
operations
as
P
from
mindspore.ops
import
operations
as
P
from
....ut_filter
import
non_graph_engine
class
LeNet5
(
nn
.
Cell
):
class
LeNet5
(
nn
.
Cell
):
...
@@ -47,7 +47,7 @@ class LeNet5(nn.Cell):
...
@@ -47,7 +47,7 @@ class LeNet5(nn.Cell):
return
x
return
x
@
non_graph_engine
@
pytest
.
mark
.
skip
(
reason
=
"need ge backend"
)
def
test_lenet_pynative_train_net
():
def
test_lenet_pynative_train_net
():
""" test_lenet_pynative_train_net """
""" test_lenet_pynative_train_net """
data
=
Tensor
(
np
.
ones
([
1
,
1
,
32
,
32
]).
astype
(
np
.
float32
)
*
0.01
)
data
=
Tensor
(
np
.
ones
([
1
,
1
,
32
,
32
]).
astype
(
np
.
float32
)
*
0.01
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录