From add3778a6152bc605bcbed450e44543cedaeac70 Mon Sep 17 00:00:00 2001
From: kingfo
Date: Mon, 6 Jul 2020 17:26:28 +0800
Subject: [PATCH] add grad all in pynative mode

---
 mindspore/ccsrc/pynative/pynative_execute.cc  |   2 +-
 mindspore/common/tensor.py                    |   3 +
 mindspore/context.py                          |   5 +-
 mindspore/nn/cell.py                          |   4 +
 mindspore/ops/composite/base.py               |  55 ++++--
 mindspore/ops/functional.py                   |   1 +
 tests/st/ops/gpu/test_dense_op.py             |   1 +
 .../python/pipeline/infer/test_net_infer.py   |   1 +
 .../parse}/test_cell_bprop.py                 |  20 +-
 tests/ut/python/pipeline/parse/test_parse.py  | 118 +++++++++++-
 .../pynative_mode/nn/test_tensor_operation.py |   6 +
 .../ut/python/pynative_mode/ops/test_grad.py  |  44 +++--
 .../python/pynative_mode/test_framstruct.py   | 182 +++++------------
 tests/ut/python/pynative_mode/test_hook.py    |  40 +++-
 .../pynative_mode/test_insert_grad_of.py      |   2 +
 .../pynative_mode/test_stop_gradient.py       |  17 +-
 16 files changed, 307 insertions(+), 194 deletions(-)
 rename tests/ut/python/{pynative_mode => pipeline/parse}/test_cell_bprop.py (94%)

diff --git a/mindspore/ccsrc/pynative/pynative_execute.cc b/mindspore/ccsrc/pynative/pynative_execute.cc
index b353ab4f9..d72f89399 100644
--- a/mindspore/ccsrc/pynative/pynative_execute.cc
+++ b/mindspore/ccsrc/pynative/pynative_execute.cc
@@ -980,7 +980,7 @@ std::vector PynativeExecutor::GetWeightsArgs(const py::object &weigh
       }
     }
   } else {
-    MS_LOG(EXCEPTION) << "training not paramter_tuple";
+    MS_LOG(DEBUG) << "training not paramter_tuple";
   }
   return w_args;
 }
diff --git a/mindspore/common/tensor.py b/mindspore/common/tensor.py
index 5dc394755..64a8eb463 100644
--- a/mindspore/common/tensor.py
+++ b/mindspore/common/tensor.py
@@ -181,6 +181,9 @@ class Tensor(Tensor_):
     def __imod__(self, other):
         return self.__mod__(other)
 
+    def __pow__(self, other):
+        return tensor_operator_registry.get('__pow__')(self, other)
+
     def __floordiv__(self, other):
         return tensor_operator_registry.get('__floordiv__')(self, other)
 
diff --git a/mindspore/context.py b/mindspore/context.py
index 98dbfb327..fe3d95b19 100644
--- a/mindspore/context.py
+++ b/mindspore/context.py
@@ -176,7 +176,10 @@ class _Context:
             self._context_switches.push(True, None)
         else:
             if self.enable_debug_runtime:
-                self.set_backend_policy("ge")
+                if self.device_target == "CPU":
+                    self.set_backend_policy("vm")
+                else:
+                    self.set_backend_policy("ge")
             self._context_switches.push(False, None)
 
     def set_backend_policy(self, policy):
diff --git a/mindspore/nn/cell.py b/mindspore/nn/cell.py
index cffe00a92..4f1bb67a8 100755
--- a/mindspore/nn/cell.py
+++ b/mindspore/nn/cell.py
@@ -16,6 +16,7 @@
 import time
 import gc
 from collections import OrderedDict
+import numpy
 from mindspore import log as logger
 from .. import context
 from ..common import dtype as mstype
@@ -211,6 +212,9 @@ class Cell:
         if context.get_context("mode") == context.GRAPH_MODE:
             out = self.compile_and_run(*inputs)
             return out
+        for item in inputs:
+            if isinstance(item, numpy.ndarray):
+                raise TypeError("cell inputs should not be numpy array.")
         self.init_parameters_data()
         orign_grad = []
         if self.requires_grad is True:
diff --git a/mindspore/ops/composite/base.py b/mindspore/ops/composite/base.py
index b0f16d82b..632efa0cc 100644
--- a/mindspore/ops/composite/base.py
+++ b/mindspore/ops/composite/base.py
@@ -17,6 +17,7 @@
 """Basic composite operations."""
 from functools import partial
+from types import FunctionType
 
 from mindspore import context
 from ..._c_expression import EnvInstance_, GradOperation_, HyperMap_, Map_, MultitypeFuncGraph_, Tail_, \
@@ -25,6 +26,7 @@ from ...common import dtype as mstype
 from ...common.api import ms_function, _pynative_exec, _wrap_func
 from .. import functional as F
 from ...common.parameter import Parameter
+from ...common.tensor import Tensor
 
 __all__ = [EnvInstance_, TupleAdd_, TupleSlice_, UnpackCall_, TupleGetItemTensor_]
 
@@ -114,37 +116,48 @@ class GradOperation(GradOperation_):
         self.fn = None
         self.need_forward = False
 
+    def _pynative_forward_run(self, args, fn):
+        """ Pynative forward run to build grad graph. """
+        if self.sens_param:
+            args = args[:-1]
+        if isinstance(fn, FunctionType):
+            _pynative_exec.set_grad_flag(True)
+            _pynative_exec.new_graph(fn, *args)
+            output = fn(*args)
+            _pynative_exec.end_graph(fn, output, *args)
+        else:
+            if fn.is_run and not fn.requires_grad:
+                raise ValueError("obj must set_grad.")
+            if not fn.is_run:
+                self.need_forward = True
+                print("already has forward run before grad by user")
+            if self.need_forward:
+                fn.set_grad()
+                fn(*args)
+
     def __call__(self, fn, weights=None):
         grad_ = GradOperation('grad', self.get_all, self.get_by_list, self.sens_param)
         if self.grad_fn is None or self.fn != fn:
-            if self.get_by_list:
-                if context.get_context("mode") == context.GRAPH_MODE:
+            if context.get_context("mode") == context.GRAPH_MODE:
+                if self.get_by_list:
                     @ms_function(obj=fn)
                     def after_grad(*args):
                         return grad_(fn, weights)(*args)
                 else:
-                    @_wrap_func
+                    @ms_function(obj=fn)
                     def after_grad(*args):
-                        if fn.is_run and not fn.requires_grad:
-                            raise ValueError("obj must set_grad.")
-                        if not fn.is_run:
-                            self.need_forward = True
-                            print("already has forward run before grad by user")
-                        if self.need_forward:
-                            fn.set_grad()
-                            if self.sens_param:
-                                f_args = args[:-1]
-                                fn(*f_args)
-                            else:
-                                fn(*args)
-                        _pynative_exec.grad(grad_, fn, weights, *args)
-                        out = _pynative_exec(*args)
-                        _pynative_exec.clear()
-                        return out
+                        return grad_(fn)(*args)
             else:
-                @ms_function(obj=fn)
+                @_wrap_func
                 def after_grad(*args):
-                    return grad_(fn)(*args)
+                    for arg in args:
+                        if not isinstance(arg, Tensor):
+                            raise TypeError("grad inputs should be tensor in pynative mode")
+                    self._pynative_forward_run(args, fn)
+                    _pynative_exec.grad(grad_, fn, weights, *args)
+                    out = _pynative_exec(*args)
+                    _pynative_exec.clear()
+                    return out
             self.grad_fn = after_grad
             self.fn = fn
         return self.grad_fn
diff --git a/mindspore/ops/functional.py b/mindspore/ops/functional.py
index a5c3165ab..d23fcd309 100644
--- a/mindspore/ops/functional.py
+++ b/mindspore/ops/functional.py
@@ -166,6 +166,7 @@ tensor_operator_registry.register('__sub__', tensor_sub)
 tensor_operator_registry.register('__mul__', tensor_mul)
 tensor_operator_registry.register('__truediv__', tensor_div)
 tensor_operator_registry.register('__mod__', tensor_mod)
+tensor_operator_registry.register('__pow__', tensor_pow)
 tensor_operator_registry.register('__floordiv__', tensor_floordiv)
 #ms cannot support Tensor(True) compare
 tensor_operator_registry.register('__eq__', equal)
diff --git a/tests/st/ops/gpu/test_dense_op.py b/tests/st/ops/gpu/test_dense_op.py
index 220f7ae05..e9c010ea7 100644
--- a/tests/st/ops/gpu/test_dense_op.py
+++ b/tests/st/ops/gpu/test_dense_op.py
@@ -228,6 +228,7 @@ def test_biasadd_3d():
     error = np.ones(shape=[3, 4, 8]) * 1.0e-6
     context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
     net = BiasAdd()
+    net.set_grad()
     result = net(x, b)
     diff = result.asnumpy() - expect
     assert np.all(diff < error)
diff --git a/tests/ut/python/pipeline/infer/test_net_infer.py b/tests/ut/python/pipeline/infer/test_net_infer.py
index 6b32a7617..9c19f213f 100644
--- a/tests/ut/python/pipeline/infer/test_net_infer.py
+++ b/tests/ut/python/pipeline/infer/test_net_infer.py
@@ -45,6 +45,7 @@ def test_net_infer():
 
 
 def test_assign_in_while():
+    context.set_context(device_target="Ascend")
     context.set_context(mode=context.GRAPH_MODE)
     class Net(nn.Cell):
         def __init__(self, input_shape):
diff --git a/tests/ut/python/pynative_mode/test_cell_bprop.py b/tests/ut/python/pipeline/parse/test_cell_bprop.py
similarity index 94%
rename from tests/ut/python/pynative_mode/test_cell_bprop.py
rename to tests/ut/python/pipeline/parse/test_cell_bprop.py
index 09a096a09..7207160ca 100644
--- a/tests/ut/python/pynative_mode/test_cell_bprop.py
+++ b/tests/ut/python/pipeline/parse/test_cell_bprop.py
@@ -16,6 +16,7 @@
 import numpy as np
 import pytest
 
+import mindspore as ms
 import mindspore.common.dtype as mstype
 import mindspore.nn as nn
 from mindspore import Parameter
@@ -24,12 +25,15 @@ from mindspore.common.initializer import initializer
 from mindspore.common.tensor import Tensor
 from mindspore.ops import composite as C
 from mindspore.ops import operations as P
-from ....mindspore_test_framework.utils.bprop_util import bprop
+from .....mindspore_test_framework.utils.bprop_util import bprop
 
 
 def setup_module(module):
-    context.set_context(mode=context.PYNATIVE_MODE)
+    context.set_context(device_target="CPU")
+    context.set_context(mode=context.GRAPH_MODE)
 
+def teardown_module(module):
+    context.set_context(device_target="Ascend")
 
 class MulAdd(nn.Cell):
     def __init__(self):
@@ -45,7 +49,9 @@ class MulAdd(nn.Cell):
 
 def test_grad_mul_add():
     mul_add = MulAdd()
-    assert C.grad_all(mul_add)(1, 2) == (2, 4)
+    x = Tensor(1, dtype=ms.int32)
+    y = Tensor(2, dtype=ms.int32)
+    assert C.grad_all(mul_add)(x, y) == (2, 4)
 
 
 class InlineMulADD(nn.Cell):
@@ -60,7 +66,9 @@ class InlineMulADD(nn.Cell):
 
 def test_grad_inline_mul_add():
     inline_mul_add = InlineMulADD()
-    assert C.grad_all(inline_mul_add)(1, 2) == (3, 6)
+    x = Tensor(1, dtype=ms.int32)
+    y = Tensor(2, dtype=ms.int32)
+    assert C.grad_all(inline_mul_add)(x, y) == (3, 6)
 
 
 class WithParameter(nn.Cell):
@@ -93,7 +101,9 @@ class WithNoBprop(nn.Cell):
 
 def test_with_no_bprop():
     with_no_bprop = WithNoBprop()
-    assert C.grad_all(with_no_bprop)(1, 2) == (2, 1)
+    x = Tensor(1, dtype=ms.int32)
+    y = Tensor(2, dtype=ms.int32)
+    assert C.grad_all(with_no_bprop)(x, y) == (2, 1)
 
 
 def test_grad_in_bprop_1():
diff --git a/tests/ut/python/pipeline/parse/test_parse.py b/tests/ut/python/pipeline/parse/test_parse.py
index bbc32d072..b295adcbe 100644
--- a/tests/ut/python/pipeline/parse/test_parse.py
+++ b/tests/ut/python/pipeline/parse/test_parse.py
@@ -19,21 +19,27 @@
 @Desc :
 """
 import logging
+import pytest
 import numpy as np
 
 import mindspore as ms
 import mindspore.nn as nn
 from mindspore import Tensor
+from mindspore import context
+from mindspore.ops import composite as C
 from mindspore.common.api import ms_function, _executor
+from mindspore.ops._grad.grad_base import bprop_getters
+from mindspore.ops.primitive import prim_attr_register, PrimitiveWithInfer
 from mindspore.ops.functional import tensor_add
 from ...ut_filter import non_graph_engine
 
-# pylint: disable=W0613
+# pylint: disable=W0613,W0612
 # W0613: unused-argument
 log = logging.getLogger("test")
 log.setLevel(level=logging.ERROR)
+context.set_context(mode=context.GRAPH_MODE)
 
 
 # Test case: use the parse obj interface use default parameter
@@ -135,3 +141,113 @@ def test_net_with_ndarray():
     input_data = np.array([[1.2, 2.1], [2.2, 3.2]]).astype('float32')
 
     net(ms.Tensor(input_data))
+
+
+def test_bprop_with_wrong_output_num():
+    context.set_context(check_bprop=True)
+    class BpropWithWrongOutputNum(PrimitiveWithInfer):
+        @prim_attr_register
+        def __init__(self):
+            super(BpropWithWrongOutputNum, self).__init__('BpropWithWrongOutputNum')
+
+        def __call__(self, x, y):
+            return x
+
+        def infer_shape(self, x_shape, yshape):
+            return x_shape
+
+        def infer_dtype(self, x_type, y_type):
+            return x_type
+
+    @bprop_getters.register(BpropWithWrongOutputNum)
+    def get_bprop_with_wrong_output_num(self):
+        """Generate bprop for BpropWithWrongOutputNum"""
+
+        def bprop(x, y, out, dout):
+            return (dout,)
+
+        return bprop
+
+    class BpropWithWrongOutputNumCell(nn.Cell):
+        def __init__(self):
+            super(BpropWithWrongOutputNumCell, self).__init__()
+
+        def construct(self, x, y):
+            return BpropWithWrongOutputNum()(x, y)
+
+    with pytest.raises(TypeError):
+        C.grad_all(BpropWithWrongOutputNumCell())(1, 2)
+
+def test_bprop_with_wrong_output_type():
+    context.set_context(check_bprop=True)
+    class BpropWithWrongOutputType(PrimitiveWithInfer):
+        @prim_attr_register
+        def __init__(self):
+            super(BpropWithWrongOutputType, self).__init__('BpropWithWrongOutputType')
+
+        def __call__(self, x):
+            return x
+
+        def infer_shape(self, x_shape):
+            return x_shape
+
+        def infer_dtype(self, x_type):
+            return x_type
+
+    @bprop_getters.register(BpropWithWrongOutputType)
+    def get_bprop_with_wrong_output_type(self):
+        """Generate bprop for BpropWithWrongOutputType"""
+
+        def bprop(x, out, dout):
+            return (1,)
+
+        return bprop
+
+    class BpropWithWrongOutputTypeCell(nn.Cell):
+        def __init__(self):
+            super(BpropWithWrongOutputTypeCell, self).__init__()
+
+        def construct(self, x):
+            return BpropWithWrongOutputType()(x)
+
+    with pytest.raises(TypeError):
+        C.grad_all(BpropWithWrongOutputTypeCell())(Tensor(np.ones([64, 10]).astype(np.int32)))
+
+
+def test_bprop_with_wrong_output_shape():
+    context.set_context(check_bprop=True)
+    class BpropWithWrongOutputShape(PrimitiveWithInfer):
+        @prim_attr_register
+        def __init__(self):
+            super(BpropWithWrongOutputShape, self).__init__('BpropWithWrongOutputShape')
+
+        def __call__(self, x):
+            return x
+
+        def infer_shape(self, x_shape):
+            return x_shape
+
+        def infer_dtype(self, x_type):
+            return x_type
+
+    @bprop_getters.register(BpropWithWrongOutputShape)
+    def get_bprop_with_wrong_output_shape(self):
+        """Generate bprop for BpropWithWrongOutputShape"""
+        ones = Tensor(np.ones([2,]).astype(np.int32))
+
+        def bprop(x, out, dout):
+            return (ones,)
+
+        return bprop
+
+    class BpropWithWrongOutputShapeCell(nn.Cell):
+        def __init__(self):
+            super(BpropWithWrongOutputShapeCell, self).__init__()
+
+        def construct(self, x):
+            return BpropWithWrongOutputShape()(x)
+
+    with pytest.raises(TypeError):
+        net = BpropWithWrongOutputShapeCell()
+        net.set_grad()
+        C.grad_all(net)(Tensor(np.ones([64, 10]).astype(np.int32)))
diff --git a/tests/ut/python/pynative_mode/nn/test_tensor_operation.py b/tests/ut/python/pynative_mode/nn/test_tensor_operation.py
index 306ba63c9..eb8610bdf 100644
--- a/tests/ut/python/pynative_mode/nn/test_tensor_operation.py
+++ b/tests/ut/python/pynative_mode/nn/test_tensor_operation.py
@@ -78,3 +78,9 @@ def test_tensor_imul():
     y = Tensor(np.ones([3, 3, 3, 3]).astype(np.float32))
     x *= y
     assert x.asnumpy()[0][0][0][0] == 1.0
+
+
+def test_tensor_pow():
+    x = Tensor(np.ones([3, 3, 3, 3]).astype(np.float32) * 2)
+    y = x ** 3
+    assert y.asnumpy()[0][0][0][0] == 8.0
diff --git a/tests/ut/python/pynative_mode/ops/test_grad.py b/tests/ut/python/pynative_mode/ops/test_grad.py
index 8d880a86d..f028e91be 100644
--- a/tests/ut/python/pynative_mode/ops/test_grad.py
+++ b/tests/ut/python/pynative_mode/ops/test_grad.py
@@ -89,7 +89,11 @@ def test_scalar_cast_grad():
         output = F.scalar_cast(x, input_t)
         return output
 
-    gfn = C.grad(fx_cast)(input_x)
+    @ms_function
+    def grad_fx_cast(input_x):
+        return C.grad(fx_cast)(input_x)
+
+    gfn = grad_fx_cast(input_x)
     expect_dx = 1
     assert gfn == expect_dx
 
@@ -133,25 +137,6 @@ def test_transpose_grad():
     assert np.all(gout[0].asnumpy() == expect)
 
 
-@non_graph_engine
-def test_squeeze_grad():
-    """ test_squeeze_grad """
-    input_tensor = Tensor(np.ones(shape=[3, 2, 1]))
-    squeeze = P.Squeeze(2)
-
-    def fn(x):
-        output = squeeze(x)
-        return output
-
-    out = fn(input_tensor)
-    gfn = grad_all_with_sens(fn)
-    sens = Tensor(np.ones_like(out.asnumpy()))
-    args = [input_tensor, sens]
-    gout = gfn(*args)
-    expect = np.ones([3, 2, 1])
-    assert np.all(gout[0].asnumpy() == expect)
-
-
 def test_select_grad():
     """ test_select_grad """
     select = P.Select()
@@ -176,6 +161,25 @@ def test_select_grad():
     assert np.all(gout[2].asnumpy() == expect_y)
 
 
+@non_graph_engine
+def test_squeeze_grad():
+    """ test_squeeze_grad """
+    input_tensor = Tensor(np.ones(shape=[3, 2, 1]))
+    squeeze = P.Squeeze(2)
+
+    def fn(x):
+        output = squeeze(x)
+        return output
+
+    out = fn(input_tensor)
+    gfn = grad_all_with_sens(fn)
+    sens = Tensor(np.ones_like(out.asnumpy()))
+    args = [input_tensor, sens]
+    gout = gfn(*args)
+    expect = np.ones([3, 2, 1])
+    assert np.all(gout[0].asnumpy() == expect)
+
+
 def test_SubGrad():
     """ test_SubGrad """
     input_x = Tensor(np.array([[2, 2]]))
diff --git a/tests/ut/python/pynative_mode/test_framstruct.py b/tests/ut/python/pynative_mode/test_framstruct.py
index 39a4c97ab..cdae50dc8 100644
--- a/tests/ut/python/pynative_mode/test_framstruct.py
+++ b/tests/ut/python/pynative_mode/test_framstruct.py
@@ -16,6 +16,7 @@
 import numpy as np
 import pytest
 
+import mindspore as ms
 import mindspore.nn as nn
 from mindspore import context
 from mindspore.common import dtype as mstype
@@ -23,8 +24,6 @@ from mindspore.common.parameter import Parameter, ParameterTuple
 from mindspore.common.tensor import Tensor
 from mindspore.ops import composite as C
 from mindspore.ops import operations as P
-from mindspore.ops._grad.grad_base import bprop_getters
-from mindspore.ops.primitive import prim_attr_register, PrimitiveWithInfer
 from ..ut_filter import non_graph_engine
 from ....mindspore_test_framework.utils.check_gradient import (
     ms_function, check_jacobian, Tensor, NNGradChecker,
@@ -156,14 +155,14 @@ def test_if_always_true():
 @non_graph_engine
 def test_f():
     """ test_f """
-    res = mainf(3, 2)
+    res = mainf(Tensor(3, dtype=ms.int32), Tensor(2, dtype=ms.int32))
     assert res == (2, 3)
 
 
 @non_graph_engine
 def test_grad_add_mul():
     """ test_grad_add_mul """
-    res = grad_add_mul(3, 2)
+    res = grad_add_mul(Tensor(3, dtype=ms.int32), Tensor(2, dtype=ms.int32))
     assert res == (2, 7)
 
 
@@ -262,17 +261,19 @@ def test_if_tensor():
     assert res == Tensor(np.ones([1]).astype(np.int32) * 4)
 
 
-@ms_function
 def rec(x):
     """ rec """
     if x > 0:
         return rec(x - 1)
     return x
 
+@ms_function
+def grad_rec(input_x):
+    return C.grad(rec)(input_x)
 
 def test_grad_rec():
     """ test_grad_rec """
-    res = C.grad(rec)(10)
+    res = grad_rec(3)
     assert res == 1
 
 
@@ -282,7 +283,6 @@ def test_me_rec():
     assert res == 0
 
 
-@ms_function
 def t2_while(x, y):
     out = y - x
     i = 0
@@ -298,8 +298,10 @@ def test_while2():
 
 
 def test_grad_while2():
-    res = C.grad(t2_while)(2, 3)
-    assert res == 3
+    @ms_function
+    def df_t2_while(input_x, input_y):
+        return C.grad(t2_while)(input_x, input_y)
+    assert df_t2_while(2, 3) == 3
 
 
 def if_test(a, b):
@@ -316,7 +318,7 @@ def grad_if(x, y):
 
 def test_grad_if():
     """ test_grad_if """
-    assert grad_if(5, 4) == (3, 0)
+    assert grad_if(Tensor(5, dtype=ms.int32), Tensor(4, dtype=ms.int32)) == (3, 0)
 
 
 # While loop is not unrolled in forward and backward graphs.
@@ -421,7 +423,7 @@ def grad_while(x):
 
 def test_grad_while():
     """ test_grad_while """
-    assert grad_while(5) == (60,)
+    assert grad_while(Tensor(5, dtype=ms.int32)) == (60,)
 
 
 @ms_function
@@ -438,8 +440,10 @@ def test_factorial():
 
 
 def test_grad_factorial():
-    res = C.grad(factorial)(3)
-    assert res == 11
+    @ms_function
+    def df_factorial(x):
+        return C.grad(factorial)(x)
+    assert df_factorial(3) == 11
 
 
 @ms_function
@@ -513,7 +517,7 @@ def _for(x):
         ret = ret * i
     return ret
 
-
+@ms_function
 def grad_for(x):
     """ grad_for """
     return C.grad_all(_for)(x)
@@ -786,7 +790,10 @@ def multi_outputs(x, y):
 
 
 def test_grad_multi_outputs():
-    assert C.grad_all_with_sens(multi_outputs)(2, 3, (1, 1)) == (4, 4)
+    @ms_function
+    def df_multi_outputs(x, y):
+        return C.grad_all_with_sens(multi_outputs)(x, y, (1, 1))
+    assert df_multi_outputs(2, 3) == (4, 4)
 
 
 @ms_function
@@ -813,7 +820,7 @@ def grad_refactor_simple_1(x, y):
 
 
 def test_grad_refactor_simple_1():
-    assert C.grad_all(grad_refactor_simple_1)(2, 1) == (4, 2)
+    assert C.grad_all(grad_refactor_simple_1)(Tensor(2, dtype=ms.int32), Tensor(1, dtype=ms.int32)) == (4, 2)
 
 
 def grad_refactor_simple_2(x, y, z):
@@ -822,7 +829,10 @@ def grad_refactor_simple_2(x, y, z):
 
 
 def test_grad_refactor_simple_2():
-    assert C.grad_all(grad_refactor_simple_2)(2, 3, 0) == (7, 4, 7)
+    x = Tensor(2, dtype=ms.int32)
+    y = Tensor(3, dtype=ms.int32)
+    z = Tensor(0, dtype=ms.int32)
+    assert C.grad_all(grad_refactor_simple_2)(x, y, z) == (7, 4, 7)
 
 
 def grad_refactor_1(a, b):
@@ -835,7 +845,7 @@ def grad_refactor_1(a, b):
 
 
 def test_grad_refactor_1():
-    assert C.grad_all(grad_refactor_1)(2, 3) == (3, 2)
+    assert C.grad_all(grad_refactor_1)(Tensor(2, dtype=ms.int32), Tensor(3, dtype=ms.int32)) == (3, 2)
 
 
 def grad_refactor_2(a, b):
@@ -848,7 +858,7 @@ def grad_refactor_2(a, b):
 
 
 def test_grad_refactor_2():
-    assert C.grad_all(grad_refactor_2)(2, 3) == (27, 54)
+    assert C.grad_all(grad_refactor_2)(Tensor(2, dtype=ms.int32), Tensor(3, dtype=ms.int32)) == (27, 54)
 
 
 def grad_refactor_3(a):
@@ -859,7 +869,10 @@ def grad_refactor_3(a):
 
 
 def test_grad_refactor_3():
-    assert C.grad_all(grad_refactor_3)(3) == (3,)
+    @ms_function
+    def df_refactor_3(x):
+        return C.grad_all(grad_refactor_3)(x)
+    assert df_refactor_3(3) == (3,)
 
 
 def grad_refactor_4(a):
@@ -870,7 +883,7 @@ def grad_refactor_4(a):
 
 
 def test_grad_refactor_4():
-    assert C.grad_all(grad_refactor_4)(4) == (3,)
+    assert C.grad_all(grad_refactor_4)(Tensor(4, dtype=ms.int32)) == (3,)
 
 
 def grad_refactor_5(a):
@@ -881,7 +894,10 @@ def grad_refactor_5(a):
 
 
 def test_grad_refactor_5():
-    assert C.grad_all(grad_refactor_5)(1) == (1,)
+    @ms_function
+    def df_refactor_5(x):
+        return C.grad_all(grad_refactor_5)(x)
+    assert df_refactor_5(1) == (1,)
 
 
 def grad_refactor_6(a, b):
@@ -892,7 +908,7 @@ def grad_refactor_6(a, b):
 
 
 def test_grad_refactor_6():
-    assert C.grad_all(grad_refactor_6)(3, 2) == (3, 1)
+    assert C.grad_all(grad_refactor_6)(Tensor(3, dtype=ms.int32), Tensor(2, dtype=ms.int32)) == (3, 1)
 
 
 def grad_refactor_while(x):
@@ -904,7 +920,10 @@ def grad_refactor_while(x):
 
 
 def test_grad_refactor_9():
-    assert C.grad_all(grad_refactor_while)(3) == (6,)
+    @ms_function
+    def df_refactor_while(input_x):
+        return C.grad_all(grad_refactor_while)(input_x)
+    assert df_refactor_while(3) == (6,)
 
 
 def grad_refactor__while_1(x):
@@ -919,7 +938,7 @@ def grad_refactor__while_1(x):
 
 def test_grad_refactor_10():
     """ test_grad_while """
-    assert C.grad_all(grad_refactor__while_1)(5) == (60,)
+    assert C.grad_all(grad_refactor__while_1)(Tensor(5, dtype=ms.int32)) == (60,)
 
 
 def test_grad_refactor_11():
@@ -985,7 +1004,10 @@ def grad_refactor_14(a, b):
 
 
 def test_grad_refactor_14():
-    assert C.grad_all(grad_refactor_14)(2, 3) == (3, 9)
+    @ms_function
+    def df_refactor_14(x, y):
+        return C.grad_all(grad_refactor_14)(x, y)
+    assert df_refactor_14(2, 3) == (3, 9)
 
 
 # pylint: disable=using-constant-test
@@ -1009,111 +1031,3 @@ def test_grad_if_defer_inline():
     inp = Tensor(np.ones([128, 96]).astype(np.float32))
     grads = C.grad_all(network)(inp)
     assert grads == (Tensor(np.full([128, 96], 0.6, dtype=np.float32)),)
-
-
-def test_bprop_with_wrong_output_num():
-    context.set_context(check_bprop=True)
-    class BpropWithWrongOutputNum(PrimitiveWithInfer):
-        @prim_attr_register
-        def __init__(self):
-            super(BpropWithWrongOutputNum, self).__init__('BpropWithWrongOutputNum')
-
-        def __call__(self, x, y):
-            return x
-
-        def infer_shape(self, x_shape, yshape):
-            return x_shape
-
-        def infer_dtype(self, x_type, y_type):
-            return x_type
-
-    @bprop_getters.register(BpropWithWrongOutputNum)
-    def get_bprop_with_wrong_output_num(self):
-        """Generate bprop for BpropWithWrongOutputNum"""
-
-        def bprop(x, y, out, dout):
-            return (dout,)
-
-        return bprop
-
-    class BpropWithWrongOutputNumCell(nn.Cell):
-        def __init__(self):
-            super(BpropWithWrongOutputNumCell, self).__init__()
-
-        def construct(self, x, y):
-            return BpropWithWrongOutputNum()(x, y)
-
-    with pytest.raises(TypeError):
-        C.grad_all(BpropWithWrongOutputNumCell())(1, 2)
-
-def test_bprop_with_wrong_output_type():
-    context.set_context(check_bprop=True)
-    class BpropWithWrongOutputType(PrimitiveWithInfer):
-        @prim_attr_register
-        def __init__(self):
-            super(BpropWithWrongOutputType, self).__init__('BpropWithWrongOutputType')
-
-        def __call__(self, x):
-            return x
-
-        def infer_shape(self, x_shape):
-            return x_shape
-
-        def infer_dtype(self, x_type):
-            return x_type
-
-    @bprop_getters.register(BpropWithWrongOutputType)
-    def get_bprop_with_wrong_output_type(self):
-        """Generate bprop for BpropWithWrongOutputType"""
-
-        def bprop(x, out, dout):
-            return (1,)
-
-        return bprop
-
-    class BpropWithWrongOutputTypeCell(nn.Cell):
-        def __init__(self):
-            super(BpropWithWrongOutputTypeCell, self).__init__()
-
-        def construct(self, x):
-            return BpropWithWrongOutputType()(x)
-
-    with pytest.raises(TypeError):
-        C.grad_all(BpropWithWrongOutputTypeCell())(Tensor(np.ones([64, 10]).astype(np.int32)))
-
-
-def test_bprop_with_wrong_output_shape():
-    context.set_context(check_bprop=True)
-    class BpropWithWrongOutputShape(PrimitiveWithInfer):
-        @prim_attr_register
-        def __init__(self):
-            super(BpropWithWrongOutputShape, self).__init__('BpropWithWrongOutputShape')
-
-        def __call__(self, x):
-            return x
-
-        def infer_shape(self, x_shape):
-            return x_shape
-
-        def infer_dtype(self, x_type):
-            return x_type
-
-    @bprop_getters.register(BpropWithWrongOutputShape)
-    def get_bprop_with_wrong_output_shape(self):
-        """Generate bprop for BpropWithWrongOutputShape"""
-        ones = Tensor(np.ones([2,]).astype(np.int32))
-
-        def bprop(x, out, dout):
-            return (ones,)
-
-        return bprop
-
-    class BpropWithWrongOutputShapeCell(nn.Cell):
-        def __init__(self):
-            super(BpropWithWrongOutputShapeCell, self).__init__()
-
-        def construct(self, x):
-            return BpropWithWrongOutputShape()(x)
-
-    with pytest.raises(TypeError):
-        C.grad_all(BpropWithWrongOutputShapeCell())(Tensor(np.ones([64, 10]).astype(np.int32)))
diff --git a/tests/ut/python/pynative_mode/test_hook.py b/tests/ut/python/pynative_mode/test_hook.py
index 07a7a7ad8..f34a81ab5 100644
--- a/tests/ut/python/pynative_mode/test_hook.py
+++ b/tests/ut/python/pynative_mode/test_hook.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 # ============================================================================
 import numpy as np
+import pytest
 
 import mindspore.nn as nn
 import mindspore.ops.operations as P
@@ -154,22 +155,47 @@ def test_hook():
     print(loss_output.asnumpy().shape)
 
 
+bprop_debug = False
+
 class MulAdd(nn.Cell):
     def __init__(self):
         super(MulAdd, self).__init__()
 
     def construct(self, x, y):
-        return 2 * x + y
+        return 2 * x * x + y * y
 
     def bprop(self, x, y, out, dout):
-        assert (x == 1)
-        assert (y == 2)
-        assert (out == 4)
-        assert (dout == 1)
-        return 3 * dout, 2 * y
+        global bprop_debug
+        bprop_debug = True
+        return dout, 2 * y
 
 
 def test_custom_bprop():
     mul_add = MulAdd()
     mul_add.bprop_debug = True
-    assert C.grad_all(mul_add)(1, 2) == (3, 4)
+    x = Tensor(np.array([1, 2, 3]).astype(np.int32))
+    y = Tensor(np.array([2, 3, 4]).astype(np.int32))
+    C.grad_all(mul_add)(x, y)
+    assert bprop_debug
+
+
+class Net(nn.Cell):
+    def __init__(self):
+        super(Net, self).__init__()
+
+    def construct(self, x, y):
+        return 2 * x * x + y * y
+
+def test_grad_all():
+    net = Net()
+    x = Tensor(np.array([1, 2, 3]).astype(np.int32))
+    y = Tensor(np.array([2, 3, 4]).astype(np.int32))
+    res = C.grad_all(net)(x, y)
+    print(res)
+
+def test_check_input():
+    net = Net()
+    x = np.array([1, 2, 3])
+    y = np.array([2, 3, 4])
+    with pytest.raises(TypeError):
+        net(x, y)
diff --git a/tests/ut/python/pynative_mode/test_insert_grad_of.py b/tests/ut/python/pynative_mode/test_insert_grad_of.py
index 0a28bbbb6..218a4ee25 100644
--- a/tests/ut/python/pynative_mode/test_insert_grad_of.py
+++ b/tests/ut/python/pynative_mode/test_insert_grad_of.py
@@ -46,6 +46,7 @@ def test_InsertGradientOf_1():
         c = x * y
         return c
 
+    @ms_function
     def f(x, y):
         return C.grad_all(stop_test)(x, y)
 
@@ -80,6 +81,7 @@ def test_InsertGradientOf_2():
     def f(x, y):
         return clip_test(x, y)
 
+    @ms_function
     def fd(x, y):
         return C.grad_all(clip_test)(x, y)
 
diff --git a/tests/ut/python/pynative_mode/test_stop_gradient.py b/tests/ut/python/pynative_mode/test_stop_gradient.py
index a94f80adf..09e4f25c5 100644
--- a/tests/ut/python/pynative_mode/test_stop_gradient.py
+++ b/tests/ut/python/pynative_mode/test_stop_gradient.py
@@ -16,6 +16,7 @@
 import numpy as np
 import pytest
 
+import mindspore as ms
 import mindspore.common.dtype as mstype
 import mindspore.nn as nn
 from mindspore import Parameter, ParameterTuple
@@ -81,16 +82,24 @@ def stop_test4(x, y):
     return e
 
 
+@ms_function
 def grad_stop_test(x, y):
     """ grad_stop_test """
     return C.grad_all(stop_test2)(x, y)
 
 
+@ms_function
 def grad_stop_test1(x, y):
     """ grad_stop_test1 """
     return C.grad_all(stop_test3)(x, y)
 
 
+@ms_function
+def grad_stop_test5(x, y):
+    """ grad_stop_test5 """
+    return C.grad_all(stop_test5)(x, y)
+
+
 def test_stop():
     """ test_stop """
     print("test_stop:", grad_stop_test(1, 1))
@@ -103,7 +112,7 @@ def test_stop1():
 
 def test_stop5():
     """ test_stop1 """
-    print("test_stop5:", C.grad_all(stop_test5)(2, 3))
+    print("test_stop5:", grad_stop_test5(2, 3))
 
 
 class GradWrap(nn.Cell):
@@ -247,7 +256,7 @@ def test_stop_gradient_4():
     def stop_test(x):
         return stop_gradient(x)
 
-    assert C.grad_all(stop_test)(1) == (0,)
+    assert C.grad_all(stop_test)(Tensor(1, dtype=ms.int32)) == (0,)
 
 
 def test_stop_gradient_5():
@@ -257,7 +266,7 @@ def test_stop_gradient_5():
         ret = x + y
         return ret
 
-    assert C.grad_all(stop_test)(1) == (1,)
+    assert C.grad_all(stop_test)(Tensor(1, dtype=ms.int32)) == (1,)
 
 
 def test_stop_gradient_6():
@@ -266,7 +275,7 @@ def test_stop_gradient_6():
         ret = stop_gradient(ret)
         return ret
 
-    assert C.grad_all(stop_test)(1, 3) == (0, 0)
+    assert C.grad_all(stop_test)(Tensor(1, dtype=ms.int32), Tensor(3, dtype=ms.int32)) == (0, 0)
 
 
 class PrimWithMultiOutputs(PrimitiveWithInfer):
--
GitLab