Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
855d6b8f
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
855d6b8f
编写于
8月 18, 2020
作者:
B
buxue
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add check for user define bprop in Pynative mode.
上级
030af09f
变更
11
隐藏空白更改
内联
并排
Showing
11 changed file
with
316 addition
and
19 deletion
+316
-19
mindspore/ccsrc/frontend/operator/composite/do_signature.cc
mindspore/ccsrc/frontend/operator/composite/do_signature.cc
+1
-1
mindspore/ccsrc/pipeline/jit/pipeline.cc
mindspore/ccsrc/pipeline/jit/pipeline.cc
+4
-1
mindspore/ccsrc/pipeline/pynative/pynative_execute.cc
mindspore/ccsrc/pipeline/pynative/pynative_execute.cc
+7
-0
mindspore/ccsrc/pybind_api/ir/log_adapter_py.cc
mindspore/ccsrc/pybind_api/ir/log_adapter_py.cc
+3
-0
mindspore/ccsrc/pybind_api/ir/primitive_py.cc
mindspore/ccsrc/pybind_api/ir/primitive_py.cc
+44
-3
mindspore/ccsrc/utils/convert_utils.cc
mindspore/ccsrc/utils/convert_utils.cc
+0
-6
mindspore/core/abstract/prim_structures.cc
mindspore/core/abstract/prim_structures.cc
+5
-5
mindspore/core/utils/log_adapter.cc
mindspore/core/utils/log_adapter.cc
+2
-1
mindspore/core/utils/log_adapter.h
mindspore/core/utils/log_adapter.h
+1
-0
tests/ut/python/pynative_mode/test_implicit_conversion.py
tests/ut/python/pynative_mode/test_implicit_conversion.py
+38
-2
tests/ut/python/pynative_mode/test_user_define_bprop_check.py
...s/ut/python/pynative_mode/test_user_define_bprop_check.py
+211
-0
未找到文件。
mindspore/ccsrc/frontend/operator/composite/do_signature.cc
浏览文件 @
855d6b8f
...
...
@@ -143,7 +143,7 @@ TypeId GetMaxTypeId(const abstract::AbstractBasePtrList &args_spec_list, std::ve
}
}
if
(
max_type_id
==
kNumberTypeUInt8
&&
has_int8
==
true
)
{
if
(
max_type_id
==
kNumberTypeUInt8
&&
has_int8
)
{
max_type_id
=
kNumberTypeInt16
;
}
// if bool is the max type, see if there is scalar input
...
...
mindspore/ccsrc/pipeline/jit/pipeline.cc
浏览文件 @
855d6b8f
...
...
@@ -445,7 +445,7 @@ bool ExecutorPy::CompileInner(const py::object &obj, const py::tuple &args, cons
GetGeBackendPolicy
();
#endif
ExecutorInfoPtr
executor_info
=
std
::
make_shared
<
ExecutorInfo
>
();
std
::
string
phase_s
=
py
::
cast
<
std
::
string
>
(
phase
);
auto
phase_s
=
py
::
cast
<
std
::
string
>
(
phase
);
MS_LOG
(
INFO
)
<<
"ExecutorPy compile phase:"
<<
phase_s
<<
"!"
;
ResourcePtr
resource
=
std
::
make_shared
<
Resource
>
(
obj
);
...
...
@@ -540,6 +540,9 @@ bool ExecutorPy::Compile(const py::object &obj, const py::tuple &args, const py:
}
catch
(
const
py
::
index_error
&
ex
)
{
ReleaseResource
(
phase
);
throw
py
::
index_error
(
ex
);
}
catch
(
const
py
::
key_error
&
ex
)
{
ReleaseResource
(
phase
);
throw
py
::
key_error
(
ex
);
}
catch
(
const
py
::
attribute_error
&
ex
)
{
ReleaseResource
(
phase
);
throw
py
::
attribute_error
(
ex
);
...
...
mindspore/ccsrc/pipeline/pynative/pynative_execute.cc
浏览文件 @
855d6b8f
...
...
@@ -175,6 +175,7 @@ std::map<SignatureEnumDType, TypeId> GetDstType(const py::tuple &py_args,
TypeId
max_type
=
TypeId
::
kTypeUnknown
;
bool
has_float
=
false
;
bool
has_int
=
false
;
bool
has_int8
=
false
;
for
(
size_t
index
:
indexes
)
{
if
(
!
has_float
&&
py
::
isinstance
<
py
::
float_
>
(
py_args
[
index
]))
{
has_float
=
true
;
...
...
@@ -191,6 +192,9 @@ std::map<SignatureEnumDType, TypeId> GetDstType(const py::tuple &py_args,
if
(
type_priority
==
prim
::
type_map
.
end
())
{
continue
;
}
if
(
arg_type_id
==
kNumberTypeInt8
)
{
has_int8
=
true
;
}
if
(
type_priority
->
second
>
priority
)
{
max_type
=
type_priority
->
first
;
priority
=
type_priority
->
second
;
...
...
@@ -205,6 +209,9 @@ std::map<SignatureEnumDType, TypeId> GetDstType(const py::tuple &py_args,
max_type
=
TypeId
::
kNumberTypeFloat32
;
}
}
if
(
max_type
==
TypeId
::
kNumberTypeUInt8
&&
has_int8
)
{
max_type
=
TypeId
::
kNumberTypeInt16
;
}
(
void
)
dst_type
.
insert
(
std
::
make_pair
(
type
,
max_type
));
}
return
dst_type
;
...
...
mindspore/ccsrc/pybind_api/ir/log_adapter_py.cc
浏览文件 @
855d6b8f
...
...
@@ -39,6 +39,9 @@ class PyExceptionInitializer {
if
(
exception_type
==
TypeError
)
{
throw
py
::
type_error
(
str
);
}
if
(
exception_type
==
KeyError
)
{
throw
py
::
key_error
(
str
);
}
if
(
exception_type
==
AttributeError
)
{
throw
py
::
attribute_error
(
str
);
}
...
...
mindspore/ccsrc/pybind_api/ir/primitive_py.cc
浏览文件 @
855d6b8f
...
...
@@ -24,6 +24,7 @@
#include "utils/convert_utils_base.h"
#include "utils/primitive_utils.h"
#include "utils/base_ref_extends.h"
#include "utils/ms_context.h"
#include "pybind_api/api_register.h"
#include "pybind_api/export_flags.h"
#include "pybind_api/ir/base_ref_py.h"
...
...
@@ -77,9 +78,47 @@ py::function PrimitivePy::GetBpropFunction() {
}
}
py
::
tuple
check_bprop_out
(
const
py
::
object
&
grads_obj
,
const
py
::
tuple
&
py_args
)
{
py
::
tuple
grads
;
if
(
!
py
::
isinstance
<
py
::
tuple
>
(
grads_obj
))
{
grads
=
py
::
make_tuple
(
grads_obj
);
}
else
{
grads
=
py
::
cast
<
py
::
tuple
>
(
grads_obj
);
}
if
(
grads
.
size
()
!=
py_args
.
size
()
-
2
)
{
MS_EXCEPTION
(
ValueError
)
<<
"For user define net bprop, the gradients number: "
<<
grads
.
size
()
<<
" is not equal to the args number: "
<<
py_args
.
size
()
-
2
<<
"."
;
}
if
(
MsContext
::
GetInstance
()
->
check_bprop_flag
())
{
for
(
size_t
i
=
0
;
i
<
grads
.
size
();
i
++
)
{
if
(
py
::
isinstance
<
tensor
::
Tensor
>
(
py_args
[
i
]))
{
if
(
!
py
::
isinstance
<
tensor
::
Tensor
>
(
grads
[
i
]))
{
MS_EXCEPTION
(
ValueError
)
<<
"For user define net bprop, the gradient of the "
<<
i
<<
"th arg should be Tensor, but got "
<<
py
::
cast
<
std
::
string
>
(
grads
[
i
].
attr
(
"__class__"
).
attr
(
"__name__"
))
<<
", and the value is "
<<
py
::
cast
<
py
::
str
>
(
grads
[
i
])
<<
"."
;
}
py
::
tuple
grad_shape
=
grads
[
i
].
attr
(
"shape"
);
py
::
object
grad_dtype
=
grads
[
i
].
attr
(
"dtype"
);
py
::
tuple
arg_shape
=
py_args
[
i
].
attr
(
"shape"
);
py
::
object
arg_dtype
=
py_args
[
i
].
attr
(
"dtype"
);
if
(
!
grad_shape
.
equal
(
arg_shape
)
||
grad_dtype
!=
arg_dtype
)
{
MS_EXCEPTION
(
ValueError
)
<<
"For user define net bprop, the gradient of the "
<<
i
<<
"th arg should have the same shape and dtype as the "
<<
i
<<
"th arg, but the "
<<
i
<<
"th arg shape: "
<<
py
::
cast
<
py
::
str
>
(
arg_shape
)
<<
" and dtype: "
<<
py
::
cast
<
py
::
str
>
(
arg_dtype
)
<<
", the gradient shape: "
<<
py
::
cast
<
py
::
str
>
(
grad_shape
)
<<
" and dtype: "
<<
py
::
cast
<
py
::
str
>
(
grad_dtype
)
<<
"."
;
}
}
}
}
return
grads
;
}
BaseRef
PrimitivePy
::
RunHookFunction
(
const
VectorRef
&
args
)
const
{
py
::
tuple
py_args
=
ConvertDatatoPyTuple
(
args
);
py
::
object
obj
;
bool
is_bprop
=
this
->
HasAttr
(
kBpropAttrName
);
if
(
is_bprop
)
{
SyncData
(
py_args
);
...
...
@@ -90,11 +129,13 @@ BaseRef PrimitivePy::RunHookFunction(const VectorRef &args) const {
parse
::
PYTHON_MOD_CONVERT_TO_MS_TENSOR
,
py_args
[
i
])
:
py_args
[
i
];
}
obj
=
hook_
(
*
convert_args
);
return
std
::
make_shared
<
PyObjectRef
>
(
obj
);
py
::
object
grads_obj
=
hook_
(
*
convert_args
);
py
::
tuple
grads
=
check_bprop_out
(
grads_obj
,
py_args
);
return
std
::
make_shared
<
PyObjectRef
>
(
grads
);
}
SyncData
(
py_args
[
2
]);
bool
is_cell
=
this
->
HasAttr
(
kCellHookAttrName
);
py
::
object
obj
;
if
(
is_cell
)
{
auto
cell_id
=
GetValue
<
std
::
string
>
(
this
->
GetAttr
(
kCellIDAttrName
));
auto
iter
=
hook_grad_
.
find
(
cell_id
);
...
...
mindspore/ccsrc/utils/convert_utils.cc
浏览文件 @
855d6b8f
...
...
@@ -440,16 +440,10 @@ bool IsGraphOutputValueNodeOrParameter(const AnfNodePtr &output, const py::tuple
// inputs (a.k.a args in current function) size less than parameters'.
if
(
output
->
isa
<
Parameter
>
())
{
MS_LOG
(
INFO
)
<<
"Graph's output is a parameter. If all params are inputs, no need to execute."
;
if
(
args
.
empty
())
{
MS_LOG
(
EXCEPTION
)
<<
"Inputs size is 0, let graph to be executed."
;
}
// Find the right parameter as ret_val.
auto
func_graph
=
output
->
func_graph
();
MS_EXCEPTION_IF_NULL
(
func_graph
);
auto
params
=
func_graph
->
parameters
();
if
(
params
.
empty
())
{
MS_EXCEPTION
(
UnknownError
)
<<
"Graph's parameters size is 0"
;
}
if
((
args
.
size
()
+
func_graph
->
hyper_param_count
())
!=
params
.
size
())
{
MS_LOG
(
EXCEPTION
)
<<
"Input size "
<<
args
.
size
()
<<
" add Parameter count "
<<
func_graph
->
hyper_param_count
()
<<
" not equal to graph input size "
<<
params
.
size
()
<<
", let graph to be executed."
;
...
...
mindspore/core/abstract/prim_structures.cc
浏览文件 @
855d6b8f
...
...
@@ -55,7 +55,7 @@ AbstractBasePtr InferImplMakeDict(const AnalysisEnginePtr &, const PrimitivePtr
if
(
!
keyPtr
->
isa
<
StringImm
>
())
{
MS_LOG
(
EXCEPTION
)
<<
op_name
<<
" evaluator keys should be string, but got "
<<
keyPtr
->
ToString
();
}
std
::
string
key_string
=
GetValue
<
std
::
string
>
(
keyPtr
);
auto
key_string
=
GetValue
<
std
::
string
>
(
keyPtr
);
key_value
.
emplace_back
(
key_string
,
value_list
[
index
]);
}
return
std
::
make_shared
<
AbstractDictionary
>
(
key_value
);
...
...
@@ -72,7 +72,7 @@ AbstractBasePtr InferImplMakeKwarg(const AnalysisEnginePtr &, const PrimitivePtr
if
(
!
keyPtr
->
isa
<
StringImm
>
())
{
MS_LOG
(
EXCEPTION
)
<<
op_name
<<
" evaluator key should be string, but got "
<<
keyPtr
->
ToString
();
}
std
::
string
key_string
=
GetValue
<
std
::
string
>
(
keyPtr
);
auto
key_string
=
GetValue
<
std
::
string
>
(
keyPtr
);
return
std
::
make_shared
<
AbstractKeywordArg
>
(
key_string
,
args_spec_list
[
1
]);
}
...
...
@@ -88,7 +88,7 @@ AbstractBasePtr InferImplExtractKwarg(const AnalysisEnginePtr &, const Primitive
if
(
!
key_value
->
isa
<
StringImm
>
())
{
MS_LOG
(
EXCEPTION
)
<<
op_name
<<
" evaluator key should be string, but got "
<<
key_value
->
ToString
();
}
std
::
string
key_input
=
GetValue
<
std
::
string
>
(
key_value
);
auto
key_input
=
GetValue
<
std
::
string
>
(
key_value
);
std
::
string
key_actual
=
kwarg
->
get_key
();
if
(
key_actual
!=
key_input
)
{
MS_LOG
(
EXCEPTION
)
<<
op_name
<<
" evaluator input key should be same as AbstractKeywordArg' key, but input is "
...
...
@@ -216,7 +216,7 @@ AbstractBasePtr InferImplDictGetItem(const AnalysisEnginePtr &, const PrimitiveP
auto
it
=
std
::
find_if
(
dict_elems
.
begin
(),
dict_elems
.
end
(),
[
key_str
](
const
AbstractAttribute
&
item
)
{
return
item
.
first
==
key_str
;
});
if
(
it
==
dict_elems
.
end
())
{
MS_
LOG
(
EXCEPTION
)
<<
"The key "
<<
key_str
<<
" does not exist in the dict:"
<<
args_spec_list
[
0
]
->
ToString
();
MS_
EXCEPTION
(
KeyError
)
<<
"The key "
<<
key_str
<<
" does not exist in the dict:"
<<
args_spec_list
[
0
]
->
ToString
();
}
return
it
->
second
;
}
...
...
@@ -233,7 +233,7 @@ AbstractBasePtr InferImplDictSetItem(const AnalysisEnginePtr &, const PrimitiveP
if
(
!
key_value
->
isa
<
StringImm
>
())
{
MS_LOG
(
EXCEPTION
)
<<
op_name
<<
" evaluator key should be string, but got "
<<
key_value
->
ToString
();
}
std
::
string
key_str
=
GetValue
<
std
::
string
>
(
key_value
);
auto
key_str
=
GetValue
<
std
::
string
>
(
key_value
);
std
::
vector
<
AbstractAttribute
>
dict_elems
=
dict
->
elements
();
auto
it
=
std
::
find_if
(
dict_elems
.
begin
(),
dict_elems
.
end
(),
[
key_str
](
const
AbstractAttribute
&
item
)
{
return
item
.
first
==
key_str
;
});
...
...
mindspore/core/utils/log_adapter.cc
浏览文件 @
855d6b8f
...
...
@@ -147,6 +147,7 @@ static std::string ExceptionTypeToString(ExceptionType type) {
_TO_STRING
(
IndexError
),
_TO_STRING
(
ValueError
),
_TO_STRING
(
TypeError
),
_TO_STRING
(
KeyError
),
_TO_STRING
(
AttributeError
),
};
// clang-format on
...
...
@@ -236,7 +237,7 @@ void LogWriter::operator^(const LogStream &stream) const {
std
::
ostringstream
oss
;
oss
<<
location_
.
file_
<<
":"
<<
location_
.
line_
<<
" "
<<
location_
.
func_
<<
"] "
;
if
(
exception_type_
!=
NoExceptionType
&&
exception_type_
!=
IndexError
&&
exception_type_
!=
TypeError
&&
exception_type_
!=
ValueError
&&
exception_type_
!=
AttributeError
)
{
exception_type_
!=
ValueError
&&
exception_type_
!=
KeyError
&&
exception_type_
!=
AttributeError
)
{
oss
<<
ExceptionTypeToString
(
exception_type_
)
<<
" "
;
}
oss
<<
msg
.
str
();
...
...
mindspore/core/utils/log_adapter.h
浏览文件 @
855d6b8f
...
...
@@ -60,6 +60,7 @@ enum ExceptionType {
IndexError
,
ValueError
,
TypeError
,
KeyError
,
AttributeError
,
};
...
...
tests/ut/python/pynative_mode/test_implicit_conversion.py
浏览文件 @
855d6b8f
...
...
@@ -88,6 +88,16 @@ def test_float_tensor_and_bool_tensors_add():
y
=
Tensor
(
np
.
array
([[
True
,
True
,
True
],
[
False
,
False
,
False
]],
dtype
=
np
.
bool_
))
ret_actual
=
x
+
y
ret_expect
=
Tensor
(
np
.
array
([[
1.1
,
1.2
,
1.3
],
[
0.4
,
0.5
,
0.6
]],
dtype
=
np
.
float32
))
assert
ret_actual
.
dtype
==
ret_expect
.
dtype
assert
(
ret_actual
.
asnumpy
()
==
ret_expect
.
asnumpy
()).
all
()
def
test_int8_tensor_and_uint8_tensors_add
():
x
=
Tensor
(
np
.
array
([[
1
,
2
,
3
],
[
4
,
5
,
6
]],
dtype
=
np
.
int8
))
y
=
Tensor
(
np
.
array
([[
1
,
2
,
3
],
[
4
,
5
,
6
]],
dtype
=
np
.
uint8
))
ret_actual
=
x
+
y
ret_expect
=
Tensor
(
np
.
array
([[
2
,
4
,
6
],
[
8
,
10
,
12
]],
dtype
=
np
.
int16
))
assert
ret_actual
.
dtype
==
ret_expect
.
dtype
assert
(
ret_actual
.
asnumpy
()
==
ret_expect
.
asnumpy
()).
all
()
...
...
@@ -165,7 +175,6 @@ def test_float_tensor_and_int_tensors_sub_grad():
net
=
Net
()
grad_net
=
GradNet
(
net
)
ret
=
grad_net
(
x
,
y
,
sens
)
print
(
ret
)
assert
ret
[
0
].
dtype
==
x
.
dtype
assert
ret
[
1
].
dtype
==
y
.
dtype
assert
(
ret
[
0
].
asnumpy
()
==
sens
.
asnumpy
()).
all
()
...
...
@@ -194,7 +203,6 @@ def test_float16_tensor_and_float32_tensors_sub_grad():
net
=
Net
()
grad_net
=
GradNet
(
net
)
ret
=
grad_net
(
x
,
y
,
sens
)
print
(
ret
)
assert
ret
[
0
].
dtype
==
x
.
dtype
assert
ret
[
1
].
dtype
==
y
.
dtype
assert
(
ret
[
0
].
asnumpy
()
==
sens
.
asnumpy
()).
all
()
...
...
@@ -224,3 +232,31 @@ def test_float_tensor_and_int_add_grad():
ret
=
grad_net
(
x
,
sens
)
assert
ret
[
0
].
dtype
==
x
.
dtype
assert
(
ret
[
0
].
asnumpy
()
==
sens
.
asnumpy
()).
all
()
def
test_int8_tensor_and_uint8_tensors_add_grad
():
class
Net
(
nn
.
Cell
):
def
__init__
(
self
):
super
(
Net
,
self
).
__init__
()
def
construct
(
self
,
x
,
y
):
return
x
+
y
class
GradNet
(
nn
.
Cell
):
def
__init__
(
self
,
net
):
super
(
GradNet
,
self
).
__init__
()
self
.
net
=
net
def
construct
(
self
,
x
,
y
,
sens
):
return
C
.
grad_all_with_sens
(
self
.
net
)(
x
,
y
,
sens
)
x
=
Tensor
(
np
.
array
([[
1
,
2
,
3
],
[
4
,
5
,
6
]],
dtype
=
np
.
int8
))
y
=
Tensor
(
np
.
array
([[
1
,
2
,
3
],
[
4
,
5
,
6
]],
dtype
=
np
.
uint8
))
sens
=
Tensor
(
np
.
array
([[
1
,
2
,
3
],
[
4
,
5
,
6
]],
dtype
=
np
.
int16
))
net
=
Net
()
grad_net
=
GradNet
(
net
)
ret
=
grad_net
(
x
,
y
,
sens
)
assert
ret
[
0
].
dtype
==
x
.
dtype
assert
ret
[
1
].
dtype
==
y
.
dtype
assert
(
ret
[
0
].
asnumpy
()
==
sens
.
asnumpy
()).
all
()
assert
(
ret
[
1
].
asnumpy
()
==
sens
.
asnumpy
()).
all
()
tests/ut/python/pynative_mode/test_user_define_bprop_check.py
0 → 100644
浏览文件 @
855d6b8f
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test implicit conversion """
import
numpy
as
np
import
pytest
from
mindspore
import
Tensor
,
nn
,
context
,
Parameter
from
mindspore
import
dtype
as
mstype
from
mindspore.ops
import
composite
as
C
def
test_user_define_bprop_check_ok
():
class
Net
(
nn
.
Cell
):
def
__init__
(
self
):
super
(
Net
,
self
).
__init__
()
self
.
grad
=
Tensor
(
np
.
array
([[
1.1
,
2.2
,
3.3
],
[
2.0
,
3.0
,
4.0
]],
dtype
=
np
.
float32
))
def
construct
(
self
,
x
):
ret
=
x
*
2
return
ret
def
bprop
(
self
,
x
,
out
,
dout
):
return
(
self
.
grad
*
3
,)
class
GradNet
(
nn
.
Cell
):
def
__init__
(
self
,
net
):
super
(
GradNet
,
self
).
__init__
()
self
.
net
=
net
def
construct
(
self
,
x
,
sens
):
return
C
.
grad_all_with_sens
(
self
.
net
)(
x
,
sens
)
x
=
Tensor
(
np
.
array
([[
0.1
,
0.2
,
0.3
],
[
0.4
,
0.5
,
0.6
]],
dtype
=
np
.
float32
))
sens
=
Tensor
(
np
.
array
([[
1.0
,
2.0
,
0.0
],
[
0.0
,
3.0
,
4.0
]],
dtype
=
np
.
float32
))
context
.
set_context
(
mode
=
context
.
PYNATIVE_MODE
,
check_bprop
=
True
)
net
=
Net
()
grad_net
=
GradNet
(
net
)
ret
=
grad_net
(
x
,
sens
)
assert
ret
[
0
].
shape
==
(
2
,
3
)
assert
ret
[
0
].
dtype
==
mstype
.
float32
assert
(
ret
[
0
].
asnumpy
()
==
np
.
array
([[
1.1
,
2.2
,
3.3
],
[
2.0
,
3.0
,
4.0
]],
np
.
float32
)
*
3
).
all
()
def
test_user_define_bprop_no_check_dtype
():
class
Net
(
nn
.
Cell
):
def
__init__
(
self
):
super
(
Net
,
self
).
__init__
()
self
.
grad
=
Tensor
(
np
.
array
([[
1.1
,
2.2
,
3.3
],
[
2.0
,
3.0
,
4.0
]],
dtype
=
np
.
float16
))
def
construct
(
self
,
x
):
ret
=
x
*
2
return
ret
def
bprop
(
self
,
x
,
out
,
dout
):
return
(
self
.
grad
*
3
,)
class
GradNet
(
nn
.
Cell
):
def
__init__
(
self
,
net
):
super
(
GradNet
,
self
).
__init__
()
self
.
net
=
net
def
construct
(
self
,
x
,
sens
):
return
C
.
grad_all_with_sens
(
self
.
net
)(
x
,
sens
)
x
=
Tensor
(
np
.
array
([[
0.1
,
0.2
,
0.3
],
[
0.4
,
0.5
,
0.6
]],
dtype
=
np
.
float32
))
sens
=
Tensor
(
np
.
array
([[
1.0
,
2.0
,
0.0
],
[
0.0
,
3.0
,
4.0
]],
dtype
=
np
.
float32
))
context
.
set_context
(
mode
=
context
.
PYNATIVE_MODE
,
check_bprop
=
False
)
net
=
Net
()
grad_net
=
GradNet
(
net
)
ret
=
grad_net
(
x
,
sens
)
assert
ret
[
0
].
shape
==
(
2
,
3
)
assert
ret
[
0
].
dtype
==
mstype
.
float16
assert
(
ret
[
0
].
asnumpy
()
==
np
.
array
([[
1.1
,
2.2
,
3.3
],
[
2.0
,
3.0
,
4.0
]],
np
.
float16
)
*
3
).
all
()
def
test_user_define_bprop_check_shape
():
class
Net
(
nn
.
Cell
):
def
__init__
(
self
):
super
(
Net
,
self
).
__init__
()
self
.
grad
=
Tensor
(
np
.
array
([[
1.1
,
2.2
],
[
2.0
,
3.0
]],
dtype
=
np
.
float32
))
def
construct
(
self
,
x
):
ret
=
x
*
2
return
ret
def
bprop
(
self
,
x
,
out
,
dout
):
return
(
self
.
grad
*
3
,)
class
GradNet
(
nn
.
Cell
):
def
__init__
(
self
,
net
):
super
(
GradNet
,
self
).
__init__
()
self
.
net
=
net
def
construct
(
self
,
x
,
sens
):
return
C
.
grad_all_with_sens
(
self
.
net
)(
x
,
sens
)
x
=
Tensor
(
np
.
array
([[
0.1
,
0.2
,
0.3
],
[
0.4
,
0.5
,
0.6
]],
dtype
=
np
.
float32
))
sens
=
Tensor
(
np
.
array
([[
1.0
,
2.0
,
0.0
],
[
0.0
,
3.0
,
4.0
]],
dtype
=
np
.
float32
))
context
.
set_context
(
mode
=
context
.
PYNATIVE_MODE
,
check_bprop
=
True
)
net
=
Net
()
grad_net
=
GradNet
(
net
)
with
pytest
.
raises
(
ValueError
)
as
ex
:
ret
=
grad_net
(
x
,
sens
)
assert
"the gradient of the 0th arg should have the same shape and dtype as the 0th arg"
in
str
(
ex
.
value
)
def
test_user_define_bprop_check_dtype
():
class
Net
(
nn
.
Cell
):
def
__init__
(
self
):
super
(
Net
,
self
).
__init__
()
self
.
grad
=
Tensor
(
np
.
array
([[
1.1
,
2.2
,
3.3
],
[
2.0
,
3.0
,
4.0
]],
dtype
=
np
.
float16
))
def
construct
(
self
,
x
):
ret
=
x
*
2
return
ret
def
bprop
(
self
,
x
,
out
,
dout
):
return
(
self
.
grad
*
3
,)
class
GradNet
(
nn
.
Cell
):
def
__init__
(
self
,
net
):
super
(
GradNet
,
self
).
__init__
()
self
.
net
=
net
def
construct
(
self
,
x
,
sens
):
return
C
.
grad_all_with_sens
(
self
.
net
)(
x
,
sens
)
x
=
Tensor
(
np
.
array
([[
0.1
,
0.2
,
0.3
],
[
0.4
,
0.5
,
0.6
]],
dtype
=
np
.
float32
))
sens
=
Tensor
(
np
.
array
([[
1.0
,
2.0
,
0.0
],
[
0.0
,
3.0
,
4.0
]],
dtype
=
np
.
float32
))
context
.
set_context
(
mode
=
context
.
PYNATIVE_MODE
,
check_bprop
=
True
)
net
=
Net
()
grad_net
=
GradNet
(
net
)
with
pytest
.
raises
(
ValueError
)
as
ex
:
ret
=
grad_net
(
x
,
sens
)
assert
"the gradient of the 0th arg should have the same shape and dtype as the 0th arg"
in
str
(
ex
.
value
)
def
test_user_define_bprop_check_parameter
():
class
Net
(
nn
.
Cell
):
def
__init__
(
self
):
super
(
Net
,
self
).
__init__
()
self
.
par
=
Parameter
(
Tensor
(
np
.
array
([[
1.1
,
2.2
,
3.3
],
[
2.0
,
3.0
,
4.0
]],
dtype
=
np
.
float32
)),
name
=
"par"
)
self
.
grad
=
Tensor
(
np
.
array
([[
1.1
,
2.2
,
3.3
],
[
2.0
,
3.0
,
4.0
]],
dtype
=
np
.
float16
))
def
construct
(
self
,
x
):
ret
=
x
*
2
+
self
.
par
return
ret
def
bprop
(
self
,
x
,
out
,
dout
):
return
dout
+
x
class
GradNet
(
nn
.
Cell
):
def
__init__
(
self
,
net
):
super
(
GradNet
,
self
).
__init__
()
self
.
net
=
net
def
construct
(
self
,
x
,
sens
):
return
C
.
grad_all_with_sens
(
self
.
net
)(
x
,
sens
)
x
=
Tensor
(
np
.
array
([[
0.1
,
0.2
,
0.3
],
[
0.4
,
0.5
,
0.6
]],
dtype
=
np
.
float32
))
sens
=
Tensor
(
np
.
array
([[
1.0
,
2.0
,
0.0
],
[
0.0
,
3.0
,
4.0
]],
dtype
=
np
.
float32
))
context
.
set_context
(
mode
=
context
.
PYNATIVE_MODE
,
check_bprop
=
True
)
net
=
Net
()
grad_net
=
GradNet
(
net
)
with
pytest
.
raises
(
RuntimeError
)
as
ex
:
ret
=
grad_net
(
x
,
sens
)
assert
"in scope Default does not support Parameter data type."
in
str
(
ex
.
value
)
def
test_user_define_bprop_check_number
():
class
Net
(
nn
.
Cell
):
def
__init__
(
self
):
super
(
Net
,
self
).
__init__
()
self
.
grad
=
Tensor
(
np
.
array
([[
1.1
,
2.2
,
3.3
],
[
2.0
,
3.0
,
4.0
]],
dtype
=
np
.
float32
))
def
construct
(
self
,
x
,
y
):
ret
=
x
*
2
+
y
return
ret
def
bprop
(
self
,
x
,
y
,
out
,
dout
):
return
(
dout
,)
class
GradNet
(
nn
.
Cell
):
def
__init__
(
self
,
net
):
super
(
GradNet
,
self
).
__init__
()
self
.
net
=
net
def
construct
(
self
,
x
,
y
,
sens
):
return
C
.
grad_all_with_sens
(
self
.
net
)(
x
,
y
,
sens
)
x
=
Tensor
(
np
.
array
([[
0.1
,
0.2
,
0.3
],
[
0.4
,
0.5
,
0.6
]],
dtype
=
np
.
float32
))
y
=
Tensor
(
np
.
array
([[
0.1
,
0.2
,
0.3
],
[
0.4
,
0.5
,
0.6
]],
dtype
=
np
.
float32
))
sens
=
Tensor
(
np
.
array
([[
1.0
,
2.0
,
0.0
],
[
0.0
,
3.0
,
4.0
]],
dtype
=
np
.
float32
))
context
.
set_context
(
mode
=
context
.
PYNATIVE_MODE
,
check_bprop
=
True
)
net
=
Net
()
grad_net
=
GradNet
(
net
)
with
pytest
.
raises
(
ValueError
)
as
ex
:
ret
=
grad_net
(
x
,
y
,
sens
)
assert
"For user define net bprop, the gradients number: 1 is not equal to the args number: 2."
in
str
(
ex
.
value
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录