Commit add3778a authored by K kingfo

add grad all in pynative mode

Parent f201bd65
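
A minimal usage sketch (not part of the commit) of what this change enables: taking gradients with `C.grad_all` in PyNative mode, where the forward pass is now run internally by `GradOperation` and all inputs must be `Tensor` objects. The `Net` cell and input values below are hypothetical.

```python
import numpy as np
import mindspore.nn as nn
from mindspore import Tensor, context
from mindspore.ops import composite as C

context.set_context(mode=context.PYNATIVE_MODE)

class Net(nn.Cell):
    """Hypothetical cell: f(x, y) = 2*x*x + y*y, so df/dx = 4*x and df/dy = 2*y."""
    def __init__(self):
        super(Net, self).__init__()

    def construct(self, x, y):
        return 2 * x * x + y * y

x = Tensor(np.array([1, 2, 3]).astype(np.float32))
y = Tensor(np.array([2, 3, 4]).astype(np.float32))

# grad_all differentiates w.r.t. all inputs; after this commit the forward run
# and backward-graph construction happen inside GradOperation in PyNative mode,
# and non-Tensor inputs raise TypeError.
dx, dy = C.grad_all(Net())(x, y)
```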
@@ -980,7 +980,7 @@ std::vector<AnfNodePtr> PynativeExecutor::GetWeightsArgs(const py::object &weigh
       }
     }
   } else {
-    MS_LOG(EXCEPTION) << "training not paramter_tuple";
+    MS_LOG(DEBUG) << "training not paramter_tuple";
  }
  return w_args;
}
...
@@ -181,6 +181,9 @@ class Tensor(Tensor_):
     def __imod__(self, other):
         return self.__mod__(other)

+    def __pow__(self, other):
+        return tensor_operator_registry.get('__pow__')(self, other)
+
     def __floordiv__(self, other):
         return tensor_operator_registry.get('__floordiv__')(self, other)
...
@@ -176,7 +176,10 @@ class _Context:
             self._context_switches.push(True, None)
         else:
             if self.enable_debug_runtime:
-                self.set_backend_policy("ge")
+                if self.device_target == "CPU":
+                    self.set_backend_policy("vm")
+                else:
+                    self.set_backend_policy("ge")
             self._context_switches.push(False, None)

     def set_backend_policy(self, policy):
...
@@ -16,6 +16,7 @@
 import time
 import gc
 from collections import OrderedDict
+import numpy
 from mindspore import log as logger
 from .. import context
 from ..common import dtype as mstype
@@ -211,6 +212,9 @@ class Cell:
         if context.get_context("mode") == context.GRAPH_MODE:
             out = self.compile_and_run(*inputs)
             return out
+        for item in inputs:
+            if isinstance(item, numpy.ndarray):
+                raise TypeError("cell inputs should not be numpy array.")
         self.init_parameters_data()
         orign_grad = []
         if self.requires_grad is True:
...
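
The new check in `Cell.__call__` rejects raw numpy inputs in PyNative mode. A small sketch of the resulting behaviour, modelled on the `test_check_input` case added later in this commit; the `Net` cell here is hypothetical.

```python
import numpy as np
import pytest
import mindspore.nn as nn
from mindspore import Tensor, context

context.set_context(mode=context.PYNATIVE_MODE)

class Net(nn.Cell):
    def __init__(self):
        super(Net, self).__init__()

    def construct(self, x, y):
        return 2 * x * x + y * y

net = Net()
x = np.array([1, 2, 3]).astype(np.float32)
y = np.array([2, 3, 4]).astype(np.float32)

# Raw numpy arrays are now rejected at the Cell boundary in PyNative mode ...
with pytest.raises(TypeError):
    net(x, y)

# ... wrap the data in Tensor instead.
out = net(Tensor(x), Tensor(y))
```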
@@ -17,6 +17,7 @@
 """Basic composite operations."""
 from functools import partial
+from types import FunctionType

 from mindspore import context
 from ..._c_expression import EnvInstance_, GradOperation_, HyperMap_, Map_, MultitypeFuncGraph_, Tail_, \
@@ -25,6 +26,7 @@ from ...common import dtype as mstype
 from ...common.api import ms_function, _pynative_exec, _wrap_func
 from .. import functional as F
 from ...common.parameter import Parameter
+from ...common.tensor import Tensor

 __all__ = [EnvInstance_, TupleAdd_, TupleSlice_, UnpackCall_, TupleGetItemTensor_]
@@ -114,37 +116,48 @@ class GradOperation(GradOperation_):
         self.fn = None
         self.need_forward = False

+    def _pynative_forward_run(self, args, fn):
+        """ Pynative forward run to build grad graph. """
+        if self.sens_param:
+            args = args[:-1]
+        if isinstance(fn, FunctionType):
+            _pynative_exec.set_grad_flag(True)
+            _pynative_exec.new_graph(fn, *args)
+            output = fn(*args)
+            _pynative_exec.end_graph(fn, output, *args)
+        else:
+            if fn.is_run and not fn.requires_grad:
+                raise ValueError("obj must set_grad.")
+            if not fn.is_run:
+                self.need_forward = True
+                print("already has forward run before grad by user")
+            if self.need_forward:
+                fn.set_grad()
+                fn(*args)
+
     def __call__(self, fn, weights=None):
         grad_ = GradOperation('grad', self.get_all, self.get_by_list, self.sens_param)
         if self.grad_fn is None or self.fn != fn:
-            if self.get_by_list:
-                if context.get_context("mode") == context.GRAPH_MODE:
+            if context.get_context("mode") == context.GRAPH_MODE:
+                if self.get_by_list:
                     @ms_function(obj=fn)
                     def after_grad(*args):
                         return grad_(fn, weights)(*args)
                 else:
-                    @_wrap_func
+                    @ms_function(obj=fn)
                     def after_grad(*args):
-                        if fn.is_run and not fn.requires_grad:
-                            raise ValueError("obj must set_grad.")
-                        if not fn.is_run:
-                            self.need_forward = True
-                            print("already has forward run before grad by user")
-                        if self.need_forward:
-                            fn.set_grad()
-                            if self.sens_param:
-                                f_args = args[:-1]
-                                fn(*f_args)
-                            else:
-                                fn(*args)
-                        _pynative_exec.grad(grad_, fn, weights, *args)
-                        out = _pynative_exec(*args)
-                        _pynative_exec.clear()
-                        return out
+                        return grad_(fn)(*args)
             else:
-                @ms_function(obj=fn)
+                @_wrap_func
                 def after_grad(*args):
-                    return grad_(fn)(*args)
+                    for arg in args:
+                        if not isinstance(arg, Tensor):
+                            raise TypeError("grad inputs should be tensor in pynative mode")
+                    self._pynative_forward_run(args, fn)
+                    _pynative_exec.grad(grad_, fn, weights, *args)
+                    out = _pynative_exec(*args)
+                    _pynative_exec.clear()
+                    return out
             self.grad_fn = after_grad
             self.fn = fn
         return self.grad_fn
...
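
In PyNative mode, gradients are now produced by the `@_wrap_func`-decorated `after_grad` closure above: it checks that every argument is a `Tensor`, replays the forward through `_pynative_forward_run` (calling `set_grad()` on a Cell that has not run yet), then builds and executes the backward graph via `_pynative_exec` and clears the recorded resources. A usage sketch under those assumptions; the `Square` cell is hypothetical, and the positional `GradOperation('grad', get_all=True)` call mirrors how `grad_all` is constructed rather than an API guaranteed by this commit.

```python
import numpy as np
import mindspore.nn as nn
from mindspore import Tensor, context
from mindspore.ops import composite as C

context.set_context(mode=context.PYNATIVE_MODE)

class Square(nn.Cell):
    """Hypothetical cell: f(x) = x * x, so df/dx = 2 * x."""
    def __init__(self):
        super(Square, self).__init__()

    def construct(self, x):
        return x * x

# Same flags C.grad_all uses: gradients w.r.t. all inputs.
grad_op = C.GradOperation('grad', get_all=True)
x = Tensor(np.array([1.0, 2.0, 3.0]).astype(np.float32))

# after_grad validates the Tensor inputs, runs the forward to record the
# graph, then asks _pynative_exec for the gradients.
dx = grad_op(Square())(x)
```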
@@ -166,6 +166,7 @@ tensor_operator_registry.register('__sub__', tensor_sub)
 tensor_operator_registry.register('__mul__', tensor_mul)
 tensor_operator_registry.register('__truediv__', tensor_div)
 tensor_operator_registry.register('__mod__', tensor_mod)
+tensor_operator_registry.register('__pow__', tensor_pow)
 tensor_operator_registry.register('__floordiv__', tensor_floordiv)
 #ms cannot support Tensor(True) compare
 tensor_operator_registry.register('__eq__', equal)
...
@@ -228,6 +228,7 @@ def test_biasadd_3d():
     error = np.ones(shape=[3, 4, 8]) * 1.0e-6
     context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
     net = BiasAdd()
+    net.set_grad()
     result = net(x, b)
     diff = result.asnumpy() - expect
     assert np.all(diff < error)
...
@@ -45,6 +45,7 @@ def test_net_infer():

 def test_assign_in_while():
+    context.set_context(device_target="Ascend")
     context.set_context(mode=context.GRAPH_MODE)

     class Net(nn.Cell):
         def __init__(self, input_shape):
...
@@ -16,6 +16,7 @@
 import numpy as np
 import pytest

+import mindspore as ms
 import mindspore.common.dtype as mstype
 import mindspore.nn as nn
 from mindspore import Parameter
@@ -24,12 +25,15 @@ from mindspore.common.initializer import initializer
 from mindspore.common.tensor import Tensor
 from mindspore.ops import composite as C
 from mindspore.ops import operations as P
-from ....mindspore_test_framework.utils.bprop_util import bprop
+from .....mindspore_test_framework.utils.bprop_util import bprop

 def setup_module(module):
-    context.set_context(mode=context.PYNATIVE_MODE)
+    context.set_context(device_target="CPU")
+    context.set_context(mode=context.GRAPH_MODE)
+
+
+def teardown_module(module):
+    context.set_context(device_target="Ascend")

 class MulAdd(nn.Cell):
     def __init__(self):
@@ -45,7 +49,9 @@ class MulAdd(nn.Cell):

 def test_grad_mul_add():
     mul_add = MulAdd()
-    assert C.grad_all(mul_add)(1, 2) == (2, 4)
+    x = Tensor(1, dtype=ms.int32)
+    y = Tensor(2, dtype=ms.int32)
+    assert C.grad_all(mul_add)(x, y) == (2, 4)

 class InlineMulADD(nn.Cell):
@@ -60,7 +66,9 @@ class InlineMulADD(nn.Cell):

 def test_grad_inline_mul_add():
     inline_mul_add = InlineMulADD()
-    assert C.grad_all(inline_mul_add)(1, 2) == (3, 6)
+    x = Tensor(1, dtype=ms.int32)
+    y = Tensor(2, dtype=ms.int32)
+    assert C.grad_all(inline_mul_add)(x, y) == (3, 6)

 class WithParameter(nn.Cell):
@@ -93,7 +101,9 @@ class WithNoBprop(nn.Cell):

 def test_with_no_bprop():
     with_no_bprop = WithNoBprop()
-    assert C.grad_all(with_no_bprop)(1, 2) == (2, 1)
+    x = Tensor(1, dtype=ms.int32)
+    y = Tensor(2, dtype=ms.int32)
+    assert C.grad_all(with_no_bprop)(x, y) == (2, 1)

 def test_grad_in_bprop_1():
...
@@ -19,21 +19,27 @@
 @Desc :
 """
 import logging
+import pytest
 import numpy as np

 import mindspore as ms
 import mindspore.nn as nn
 from mindspore import Tensor
+from mindspore import context
+from mindspore.ops import composite as C
 from mindspore.common.api import ms_function, _executor
+from mindspore.ops._grad.grad_base import bprop_getters
+from mindspore.ops.primitive import prim_attr_register, PrimitiveWithInfer
 from mindspore.ops.functional import tensor_add
 from ...ut_filter import non_graph_engine

-# pylint: disable=W0613
+# pylint: disable=W0613,W0612
 # W0613: unused-argument

 log = logging.getLogger("test")
 log.setLevel(level=logging.ERROR)
+context.set_context(mode=context.GRAPH_MODE)

 # Test case: use the parse obj interface use default parameter
@@ -135,3 +141,113 @@ def test_net_with_ndarray():
     input_data = np.array([[1.2, 2.1], [2.2, 3.2]]).astype('float32')

     net(ms.Tensor(input_data))
+
+
+def test_bprop_with_wrong_output_num():
+    context.set_context(check_bprop=True)
+
+    class BpropWithWrongOutputNum(PrimitiveWithInfer):
+        @prim_attr_register
+        def __init__(self):
+            super(BpropWithWrongOutputNum, self).__init__('BpropWithWrongOutputNum')
+
+        def __call__(self, x, y):
+            return x
+
+        def infer_shape(self, x_shape, yshape):
+            return x_shape
+
+        def infer_dtype(self, x_type, y_type):
+            return x_type
+
+    @bprop_getters.register(BpropWithWrongOutputNum)
+    def get_bprop_with_wrong_output_num(self):
+        """Generate bprop for BpropWithWrongOutputNum"""
+        def bprop(x, y, out, dout):
+            return (dout,)
+        return bprop
+
+    class BpropWithWrongOutputNumCell(nn.Cell):
+        def __init__(self):
+            super(BpropWithWrongOutputNumCell, self).__init__()
+
+        def construct(self, x, y):
+            return BpropWithWrongOutputNum()(x, y)
+
+    with pytest.raises(TypeError):
+        C.grad_all(BpropWithWrongOutputNumCell())(1, 2)
+
+
+def test_bprop_with_wrong_output_type():
+    context.set_context(check_bprop=True)
+
+    class BpropWithWrongOutputType(PrimitiveWithInfer):
+        @prim_attr_register
+        def __init__(self):
+            super(BpropWithWrongOutputType, self).__init__('BpropWithWrongOutputType')
+
+        def __call__(self, x):
+            return x
+
+        def infer_shape(self, x_shape):
+            return x_shape
+
+        def infer_dtype(self, x_type):
+            return x_type
+
+    @bprop_getters.register(BpropWithWrongOutputType)
+    def get_bprop_with_wrong_output_type(self):
+        """Generate bprop for BpropWithWrongOutputType"""
+        def bprop(x, out, dout):
+            return (1,)
+        return bprop
+
+    class BpropWithWrongOutputTypeCell(nn.Cell):
+        def __init__(self):
+            super(BpropWithWrongOutputTypeCell, self).__init__()
+
+        def construct(self, x):
+            return BpropWithWrongOutputType()(x)
+
+    with pytest.raises(TypeError):
+        C.grad_all(BpropWithWrongOutputTypeCell())(Tensor(np.ones([64, 10]).astype(np.int32)))
+
+
+def test_bprop_with_wrong_output_shape():
+    context.set_context(check_bprop=True)
+
+    class BpropWithWrongOutputShape(PrimitiveWithInfer):
+        @prim_attr_register
+        def __init__(self):
+            super(BpropWithWrongOutputShape, self).__init__('BpropWithWrongOutputShape')
+
+        def __call__(self, x):
+            return x
+
+        def infer_shape(self, x_shape):
+            return x_shape
+
+        def infer_dtype(self, x_type):
+            return x_type
+
+    @bprop_getters.register(BpropWithWrongOutputShape)
+    def get_bprop_with_wrong_output_shape(self):
+        """Generate bprop for BpropWithWrongOutputShape"""
+        ones = Tensor(np.ones([2,]).astype(np.int32))
+
+        def bprop(x, out, dout):
+            return (ones,)
+        return bprop
+
+    class BpropWithWrongOutputShapeCell(nn.Cell):
+        def __init__(self):
+            super(BpropWithWrongOutputShapeCell, self).__init__()
+
+        def construct(self, x):
+            return BpropWithWrongOutputShape()(x)
+
+    with pytest.raises(TypeError):
+        net = BpropWithWrongOutputShapeCell()
+        net.set_grad()
+        C.grad_all(net)(Tensor(np.ones([64, 10]).astype(np.int32)))
@@ -78,3 +78,9 @@ def test_tensor_imul():
     y = Tensor(np.ones([3, 3, 3, 3]).astype(np.float32))
     x *= y
     assert x.asnumpy()[0][0][0][0] == 1.0
+
+
+def test_tensor_pow():
+    x = Tensor(np.ones([3, 3, 3, 3]).astype(np.float32) * 2)
+    y = x ** 3
+    assert y.asnumpy()[0][0][0][0] == 8.0
@@ -89,7 +89,11 @@ def test_scalar_cast_grad():
         output = F.scalar_cast(x, input_t)
         return output

-    gfn = C.grad(fx_cast)(input_x)
+    @ms_function
+    def grad_fx_cast(input_x):
+        return C.grad(fx_cast)(input_x)
+
+    gfn = grad_fx_cast(input_x)
     expect_dx = 1
     assert gfn == expect_dx
@@ -133,25 +137,6 @@ def test_transpose_grad():
     assert np.all(gout[0].asnumpy() == expect)

-@non_graph_engine
-def test_squeeze_grad():
-    """ test_squeeze_grad """
-    input_tensor = Tensor(np.ones(shape=[3, 2, 1]))
-    squeeze = P.Squeeze(2)
-
-    def fn(x):
-        output = squeeze(x)
-        return output
-
-    out = fn(input_tensor)
-    gfn = grad_all_with_sens(fn)
-    sens = Tensor(np.ones_like(out.asnumpy()))
-    args = [input_tensor, sens]
-    gout = gfn(*args)
-    expect = np.ones([3, 2, 1])
-    assert np.all(gout[0].asnumpy() == expect)
-
 def test_select_grad():
     """ test_select_grad """
     select = P.Select()
@@ -176,6 +161,25 @@ def test_select_grad():
     assert np.all(gout[2].asnumpy() == expect_y)

+@non_graph_engine
+def test_squeeze_grad():
+    """ test_squeeze_grad """
+    input_tensor = Tensor(np.ones(shape=[3, 2, 1]))
+    squeeze = P.Squeeze(2)
+
+    def fn(x):
+        output = squeeze(x)
+        return output
+
+    out = fn(input_tensor)
+    gfn = grad_all_with_sens(fn)
+    sens = Tensor(np.ones_like(out.asnumpy()))
+    args = [input_tensor, sens]
+    gout = gfn(*args)
+    expect = np.ones([3, 2, 1])
+    assert np.all(gout[0].asnumpy() == expect)
+
 def test_SubGrad():
     """ test_SubGrad """
     input_x = Tensor(np.array([[2, 2]]))
...
@@ -16,6 +16,7 @@
 import numpy as np
 import pytest

+import mindspore as ms
 import mindspore.nn as nn
 from mindspore import context
 from mindspore.common import dtype as mstype
@@ -23,8 +24,6 @@ from mindspore.common.parameter import Parameter, ParameterTuple
 from mindspore.common.tensor import Tensor
 from mindspore.ops import composite as C
 from mindspore.ops import operations as P
-from mindspore.ops._grad.grad_base import bprop_getters
-from mindspore.ops.primitive import prim_attr_register, PrimitiveWithInfer
 from ..ut_filter import non_graph_engine
 from ....mindspore_test_framework.utils.check_gradient import (
     ms_function, check_jacobian, Tensor, NNGradChecker,
@@ -156,14 +155,14 @@ def test_if_always_true():
 @non_graph_engine
 def test_f():
     """ test_f """
-    res = mainf(3, 2)
+    res = mainf(Tensor(3, dtype=ms.int32), Tensor(2, dtype=ms.int32))
     assert res == (2, 3)

 @non_graph_engine
 def test_grad_add_mul():
     """ test_grad_add_mul """
-    res = grad_add_mul(3, 2)
+    res = grad_add_mul(Tensor(3, dtype=ms.int32), Tensor(2, dtype=ms.int32))
     assert res == (2, 7)
@@ -262,17 +261,19 @@ def test_if_tensor():
     assert res == Tensor(np.ones([1]).astype(np.int32) * 4)

-@ms_function
 def rec(x):
     """ rec """
     if x > 0:
         return rec(x - 1)
     return x

+@ms_function
+def grad_rec(input_x):
+    return C.grad(rec)(input_x)
+
 def test_grad_rec():
     """ test_grad_rec """
-    res = C.grad(rec)(10)
+    res = grad_rec(3)
     assert res == 1
@@ -282,7 +283,6 @@ def test_me_rec():
     assert res == 0

-@ms_function
 def t2_while(x, y):
     out = y - x
     i = 0
@@ -298,8 +298,10 @@ def test_while2():

 def test_grad_while2():
-    res = C.grad(t2_while)(2, 3)
-    assert res == 3
+    @ms_function
+    def df_t2_while(input_x, input_y):
+        return C.grad(t2_while)(input_x, input_y)
+    assert df_t2_while(2, 3) == 3

 def if_test(a, b):
@@ -316,7 +318,7 @@ def grad_if(x, y):

 def test_grad_if():
     """ test_grad_if """
-    assert grad_if(5, 4) == (3, 0)
+    assert grad_if(Tensor(5, dtype=ms.int32), Tensor(4, dtype=ms.int32)) == (3, 0)

 # While loop is not unrolled in forward and backward graphs.
@@ -421,7 +423,7 @@ def grad_while(x):

 def test_grad_while():
     """ test_grad_while """
-    assert grad_while(5) == (60,)
+    assert grad_while(Tensor(5, dtype=ms.int32)) == (60,)

 @ms_function
@@ -438,8 +440,10 @@ def test_factorial():

 def test_grad_factorial():
-    res = C.grad(factorial)(3)
-    assert res == 11
+    @ms_function
+    def df_factorial(x):
+        return C.grad(factorial)(x)
+    assert df_factorial(3) == 11

 @ms_function
@@ -513,7 +517,7 @@ def _for(x):
         ret = ret * i
     return ret

+@ms_function
 def grad_for(x):
     """ grad_for """
     return C.grad_all(_for)(x)
@@ -786,7 +790,10 @@ def multi_outputs(x, y):

 def test_grad_multi_outputs():
-    assert C.grad_all_with_sens(multi_outputs)(2, 3, (1, 1)) == (4, 4)
+    @ms_function
+    def df_multi_outputs(x, y):
+        return C.grad_all_with_sens(multi_outputs)(x, y, (1, 1))
+    assert df_multi_outputs(2, 3) == (4, 4)

 @ms_function
@@ -813,7 +820,7 @@ def grad_refactor_simple_1(x, y):

 def test_grad_refactor_simple_1():
-    assert C.grad_all(grad_refactor_simple_1)(2, 1) == (4, 2)
+    assert C.grad_all(grad_refactor_simple_1)(Tensor(2, dtype=ms.int32), Tensor(1, dtype=ms.int32)) == (4, 2)

 def grad_refactor_simple_2(x, y, z):
@@ -822,7 +829,10 @@ def grad_refactor_simple_2(x, y, z):

 def test_grad_refactor_simple_2():
-    assert C.grad_all(grad_refactor_simple_2)(2, 3, 0) == (7, 4, 7)
+    x = Tensor(2, dtype=ms.int32)
+    y = Tensor(3, dtype=ms.int32)
+    z = Tensor(0, dtype=ms.int32)
+    assert C.grad_all(grad_refactor_simple_2)(x, y, z) == (7, 4, 7)

 def grad_refactor_1(a, b):
@@ -835,7 +845,7 @@ def grad_refactor_1(a, b):

 def test_grad_refactor_1():
-    assert C.grad_all(grad_refactor_1)(2, 3) == (3, 2)
+    assert C.grad_all(grad_refactor_1)(Tensor(2, dtype=ms.int32), Tensor(3, dtype=ms.int32)) == (3, 2)

 def grad_refactor_2(a, b):
@@ -848,7 +858,7 @@ def grad_refactor_2(a, b):

 def test_grad_refactor_2():
-    assert C.grad_all(grad_refactor_2)(2, 3) == (27, 54)
+    assert C.grad_all(grad_refactor_2)(Tensor(2, dtype=ms.int32), Tensor(3, dtype=ms.int32)) == (27, 54)

 def grad_refactor_3(a):
@@ -859,7 +869,10 @@ def grad_refactor_3(a):

 def test_grad_refactor_3():
-    assert C.grad_all(grad_refactor_3)(3) == (3,)
+    @ms_function
+    def df_refactor_3(x):
+        return C.grad_all(grad_refactor_3)(x)
+    assert df_refactor_3(3) == (3,)

 def grad_refactor_4(a):
@@ -870,7 +883,7 @@ def grad_refactor_4(a):

 def test_grad_refactor_4():
-    assert C.grad_all(grad_refactor_4)(4) == (3,)
+    assert C.grad_all(grad_refactor_4)(Tensor(4, dtype=ms.int32)) == (3,)

 def grad_refactor_5(a):
@@ -881,7 +894,10 @@ def grad_refactor_5(a):

 def test_grad_refactor_5():
-    assert C.grad_all(grad_refactor_5)(1) == (1,)
+    @ms_function
+    def df_refactor_5(x):
+        return C.grad_all(grad_refactor_5)(x)
+    assert df_refactor_5(1) == (1,)

 def grad_refactor_6(a, b):
@@ -892,7 +908,7 @@ def grad_refactor_6(a, b):

 def test_grad_refactor_6():
-    assert C.grad_all(grad_refactor_6)(3, 2) == (3, 1)
+    assert C.grad_all(grad_refactor_6)(Tensor(3, dtype=ms.int32), Tensor(2, dtype=ms.int32)) == (3, 1)

 def grad_refactor_while(x):
@@ -904,7 +920,10 @@ def grad_refactor_while(x):

 def test_grad_refactor_9():
-    assert C.grad_all(grad_refactor_while)(3) == (6,)
+    @ms_function
+    def df_refactor_while(input_x):
+        return C.grad_all(grad_refactor_while)(input_x)
+    assert df_refactor_while(3) == (6,)

 def grad_refactor__while_1(x):
@@ -919,7 +938,7 @@ def grad_refactor__while_1(x):

 def test_grad_refactor_10():
     """ test_grad_while """
-    assert C.grad_all(grad_refactor__while_1)(5) == (60,)
+    assert C.grad_all(grad_refactor__while_1)(Tensor(5, dtype=ms.int32)) == (60,)

 def test_grad_refactor_11():
@@ -985,7 +1004,10 @@ def grad_refactor_14(a, b):

 def test_grad_refactor_14():
-    assert C.grad_all(grad_refactor_14)(2, 3) == (3, 9)
+    @ms_function
+    def df_refactor_14(x, y):
+        return C.grad_all(grad_refactor_14)(x, y)
+    assert df_refactor_14(2, 3) == (3, 9)

 # pylint: disable=using-constant-test
@@ -1009,111 +1031,3 @@ def test_grad_if_defer_inline():
     inp = Tensor(np.ones([128, 96]).astype(np.float32))
     grads = C.grad_all(network)(inp)
     assert grads == (Tensor(np.full([128, 96], 0.6, dtype=np.float32)),)
-
-
-def test_bprop_with_wrong_output_num():
-    context.set_context(check_bprop=True)
-
-    class BpropWithWrongOutputNum(PrimitiveWithInfer):
-        @prim_attr_register
-        def __init__(self):
-            super(BpropWithWrongOutputNum, self).__init__('BpropWithWrongOutputNum')
-
-        def __call__(self, x, y):
-            return x
-
-        def infer_shape(self, x_shape, yshape):
-            return x_shape
-
-        def infer_dtype(self, x_type, y_type):
-            return x_type
-
-    @bprop_getters.register(BpropWithWrongOutputNum)
-    def get_bprop_with_wrong_output_num(self):
-        """Generate bprop for BpropWithWrongOutputNum"""
-        def bprop(x, y, out, dout):
-            return (dout,)
-        return bprop
-
-    class BpropWithWrongOutputNumCell(nn.Cell):
-        def __init__(self):
-            super(BpropWithWrongOutputNumCell, self).__init__()
-
-        def construct(self, x, y):
-            return BpropWithWrongOutputNum()(x, y)
-
-    with pytest.raises(TypeError):
-        C.grad_all(BpropWithWrongOutputNumCell())(1, 2)
-
-
-def test_bprop_with_wrong_output_type():
-    context.set_context(check_bprop=True)
-
-    class BpropWithWrongOutputType(PrimitiveWithInfer):
-        @prim_attr_register
-        def __init__(self):
-            super(BpropWithWrongOutputType, self).__init__('BpropWithWrongOutputType')
-
-        def __call__(self, x):
-            return x
-
-        def infer_shape(self, x_shape):
-            return x_shape
-
-        def infer_dtype(self, x_type):
-            return x_type
-
-    @bprop_getters.register(BpropWithWrongOutputType)
-    def get_bprop_with_wrong_output_type(self):
-        """Generate bprop for BpropWithWrongOutputType"""
-        def bprop(x, out, dout):
-            return (1,)
-        return bprop
-
-    class BpropWithWrongOutputTypeCell(nn.Cell):
-        def __init__(self):
-            super(BpropWithWrongOutputTypeCell, self).__init__()
-
-        def construct(self, x):
-            return BpropWithWrongOutputType()(x)
-
-    with pytest.raises(TypeError):
-        C.grad_all(BpropWithWrongOutputTypeCell())(Tensor(np.ones([64, 10]).astype(np.int32)))
-
-
-def test_bprop_with_wrong_output_shape():
-    context.set_context(check_bprop=True)
-
-    class BpropWithWrongOutputShape(PrimitiveWithInfer):
-        @prim_attr_register
-        def __init__(self):
-            super(BpropWithWrongOutputShape, self).__init__('BpropWithWrongOutputShape')
-
-        def __call__(self, x):
-            return x
-
-        def infer_shape(self, x_shape):
-            return x_shape
-
-        def infer_dtype(self, x_type):
-            return x_type
-
-    @bprop_getters.register(BpropWithWrongOutputShape)
-    def get_bprop_with_wrong_output_shape(self):
-        """Generate bprop for BpropWithWrongOutputShape"""
-        ones = Tensor(np.ones([2,]).astype(np.int32))
-
-        def bprop(x, out, dout):
-            return (ones,)
-        return bprop
-
-    class BpropWithWrongOutputShapeCell(nn.Cell):
-        def __init__(self):
-            super(BpropWithWrongOutputShapeCell, self).__init__()
-
-        def construct(self, x):
-            return BpropWithWrongOutputShape()(x)
-
-    with pytest.raises(TypeError):
-        C.grad_all(BpropWithWrongOutputShapeCell())(Tensor(np.ones([64, 10]).astype(np.int32)))
@@ -13,6 +13,7 @@
 # limitations under the License.
 # ============================================================================
 import numpy as np
+import pytest

 import mindspore.nn as nn
 import mindspore.ops.operations as P
@@ -154,22 +155,47 @@ def test_hook():
     print(loss_output.asnumpy().shape)

+bprop_debug = False
+
 class MulAdd(nn.Cell):
     def __init__(self):
         super(MulAdd, self).__init__()

     def construct(self, x, y):
-        return 2 * x + y
+        return 2 * x * x + y * y

     def bprop(self, x, y, out, dout):
-        assert (x == 1)
-        assert (y == 2)
-        assert (out == 4)
-        assert (dout == 1)
-        return 3 * dout, 2 * y
+        global bprop_debug
+        bprop_debug = True
+        return dout, 2 * y

 def test_custom_bprop():
     mul_add = MulAdd()
     mul_add.bprop_debug = True
-    assert C.grad_all(mul_add)(1, 2) == (3, 4)
+    x = Tensor(np.array([1, 2, 3]).astype(np.int32))
+    y = Tensor(np.array([2, 3, 4]).astype(np.int32))
+    C.grad_all(mul_add)(x, y)
+    assert bprop_debug
+
+
+class Net(nn.Cell):
+    def __init__(self):
+        super(Net, self).__init__()
+
+    def construct(self, x, y):
+        return 2 * x * x + y * y
+
+
+def test_grad_all():
+    net = Net()
+    x = Tensor(np.array([1, 2, 3]).astype(np.int32))
+    y = Tensor(np.array([2, 3, 4]).astype(np.int32))
+    res = C.grad_all(net)(x, y)
+    print(res)
+
+
+def test_check_input():
+    net = Net()
+    x = np.array([1, 2, 3])
+    y = np.array([2, 3, 4])
+    with pytest.raises(TypeError):
+        net(x, y)
@@ -46,6 +46,7 @@ def test_InsertGradientOf_1():
         c = x * y
         return c

+    @ms_function
     def f(x, y):
         return C.grad_all(stop_test)(x, y)
@@ -80,6 +81,7 @@ def test_InsertGradientOf_2():
     def f(x, y):
         return clip_test(x, y)

+    @ms_function
     def fd(x, y):
         return C.grad_all(clip_test)(x, y)
...
@@ -16,6 +16,7 @@
 import numpy as np
 import pytest

+import mindspore as ms
 import mindspore.common.dtype as mstype
 import mindspore.nn as nn
 from mindspore import Parameter, ParameterTuple
@@ -81,16 +82,24 @@ def stop_test4(x, y):
     return e

+@ms_function
 def grad_stop_test(x, y):
     """ grad_stop_test """
     return C.grad_all(stop_test2)(x, y)

+@ms_function
 def grad_stop_test1(x, y):
     """ grad_stop_test1 """
     return C.grad_all(stop_test3)(x, y)

+@ms_function
+def grad_stop_test5(x, y):
+    """ grad_stop_test5 """
+    return C.grad_all(stop_test5)(x, y)
+
+
 def test_stop():
     """ test_stop """
     print("test_stop:", grad_stop_test(1, 1))
@@ -103,7 +112,7 @@ def test_stop1():

 def test_stop5():
     """ test_stop1 """
-    print("test_stop5:", C.grad_all(stop_test5)(2, 3))
+    print("test_stop5:", grad_stop_test5(2, 3))

 class GradWrap(nn.Cell):
@@ -247,7 +256,7 @@ def test_stop_gradient_4():
     def stop_test(x):
         return stop_gradient(x)

-    assert C.grad_all(stop_test)(1) == (0,)
+    assert C.grad_all(stop_test)(Tensor(1, dtype=ms.int32)) == (0,)

 def test_stop_gradient_5():
@@ -257,7 +266,7 @@ def test_stop_gradient_5():
         ret = x + y
         return ret

-    assert C.grad_all(stop_test)(1) == (1,)
+    assert C.grad_all(stop_test)(Tensor(1, dtype=ms.int32)) == (1,)

 def test_stop_gradient_6():
@@ -266,7 +275,7 @@ def test_stop_gradient_6():
         ret = stop_gradient(ret)
         return ret

-    assert C.grad_all(stop_test)(1, 3) == (0, 0)
+    assert C.grad_all(stop_test)(Tensor(1, dtype=ms.int32), Tensor(3, dtype=ms.int32)) == (0, 0)

 class PrimWithMultiOutputs(PrimitiveWithInfer):
...