提交 d925c52b 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!2875 add grad all in pynative

Merge pull request !2875 from wangqiuliang/add-grad-all-in-pynative
......@@ -980,7 +980,7 @@ std::vector<AnfNodePtr> PynativeExecutor::GetWeightsArgs(const py::object &weigh
}
}
} else {
MS_LOG(EXCEPTION) << "training not paramter_tuple";
MS_LOG(DEBUG) << "training not paramter_tuple";
}
return w_args;
}
......
......@@ -181,6 +181,9 @@ class Tensor(Tensor_):
def __imod__(self, other):
return self.__mod__(other)
def __pow__(self, other):
return tensor_operator_registry.get('__pow__')(self, other)
def __floordiv__(self, other):
return tensor_operator_registry.get('__floordiv__')(self, other)
......
......@@ -176,7 +176,10 @@ class _Context:
self._context_switches.push(True, None)
else:
if self.enable_debug_runtime:
self.set_backend_policy("ge")
if self.device_target == "CPU":
self.set_backend_policy("vm")
else:
self.set_backend_policy("ge")
self._context_switches.push(False, None)
def set_backend_policy(self, policy):
......
......@@ -16,6 +16,7 @@
import time
import gc
from collections import OrderedDict
import numpy
from mindspore import log as logger
from .. import context
from ..common import dtype as mstype
......@@ -211,6 +212,9 @@ class Cell:
if context.get_context("mode") == context.GRAPH_MODE:
out = self.compile_and_run(*inputs)
return out
for item in inputs:
if isinstance(item, numpy.ndarray):
raise TypeError("cell inputs should not be numpy array.")
self.init_parameters_data()
orign_grad = []
if self.requires_grad is True:
......
......@@ -17,6 +17,7 @@
"""Basic composite operations."""
from functools import partial
from types import FunctionType
from mindspore import context
from ..._c_expression import EnvInstance_, GradOperation_, HyperMap_, Map_, MultitypeFuncGraph_, Tail_, \
......@@ -25,6 +26,7 @@ from ...common import dtype as mstype
from ...common.api import ms_function, _pynative_exec, _wrap_func
from .. import functional as F
from ...common.parameter import Parameter
from ...common.tensor import Tensor
__all__ = [EnvInstance_, TupleAdd_, TupleSlice_, UnpackCall_, TupleGetItemTensor_]
......@@ -114,37 +116,48 @@ class GradOperation(GradOperation_):
self.fn = None
self.need_forward = False
def _pynative_forward_run(self, args, fn):
""" Pynative forward run to build grad graph. """
if self.sens_param:
args = args[:-1]
if isinstance(fn, FunctionType):
_pynative_exec.set_grad_flag(True)
_pynative_exec.new_graph(fn, *args)
output = fn(*args)
_pynative_exec.end_graph(fn, output, *args)
else:
if fn.is_run and not fn.requires_grad:
raise ValueError("obj must set_grad.")
if not fn.is_run:
self.need_forward = True
print("already has forward run before grad by user")
if self.need_forward:
fn.set_grad()
fn(*args)
def __call__(self, fn, weights=None):
grad_ = GradOperation('grad', self.get_all, self.get_by_list, self.sens_param)
if self.grad_fn is None or self.fn != fn:
if self.get_by_list:
if context.get_context("mode") == context.GRAPH_MODE:
if context.get_context("mode") == context.GRAPH_MODE:
if self.get_by_list:
@ms_function(obj=fn)
def after_grad(*args):
return grad_(fn, weights)(*args)
else:
@_wrap_func
@ms_function(obj=fn)
def after_grad(*args):
if fn.is_run and not fn.requires_grad:
raise ValueError("obj must set_grad.")
if not fn.is_run:
self.need_forward = True
print("already has forward run before grad by user")
if self.need_forward:
fn.set_grad()
if self.sens_param:
f_args = args[:-1]
fn(*f_args)
else:
fn(*args)
_pynative_exec.grad(grad_, fn, weights, *args)
out = _pynative_exec(*args)
_pynative_exec.clear()
return out
return grad_(fn)(*args)
else:
@ms_function(obj=fn)
@_wrap_func
def after_grad(*args):
return grad_(fn)(*args)
for arg in args:
if not isinstance(arg, Tensor):
raise TypeError("grad inputs should be tensor in pynative mode")
self._pynative_forward_run(args, fn)
_pynative_exec.grad(grad_, fn, weights, *args)
out = _pynative_exec(*args)
_pynative_exec.clear()
return out
self.grad_fn = after_grad
self.fn = fn
return self.grad_fn
......
......@@ -166,6 +166,7 @@ tensor_operator_registry.register('__sub__', tensor_sub)
tensor_operator_registry.register('__mul__', tensor_mul)
tensor_operator_registry.register('__truediv__', tensor_div)
tensor_operator_registry.register('__mod__', tensor_mod)
tensor_operator_registry.register('__pow__', tensor_pow)
tensor_operator_registry.register('__floordiv__', tensor_floordiv)
#ms cannot support Tensor(True) compare
tensor_operator_registry.register('__eq__', equal)
......
......@@ -228,6 +228,7 @@ def test_biasadd_3d():
error = np.ones(shape=[3, 4, 8]) * 1.0e-6
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
net = BiasAdd()
net.set_grad()
result = net(x, b)
diff = result.asnumpy() - expect
assert np.all(diff < error)
......
......@@ -45,6 +45,7 @@ def test_net_infer():
def test_assign_in_while():
context.set_context(device_target="Ascend")
context.set_context(mode=context.GRAPH_MODE)
class Net(nn.Cell):
def __init__(self, input_shape):
......
......@@ -16,6 +16,7 @@
import numpy as np
import pytest
import mindspore as ms
import mindspore.common.dtype as mstype
import mindspore.nn as nn
from mindspore import Parameter
......@@ -24,12 +25,15 @@ from mindspore.common.initializer import initializer
from mindspore.common.tensor import Tensor
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from ....mindspore_test_framework.utils.bprop_util import bprop
from .....mindspore_test_framework.utils.bprop_util import bprop
def setup_module(module):
context.set_context(mode=context.PYNATIVE_MODE)
context.set_context(device_target="CPU")
context.set_context(mode=context.GRAPH_MODE)
def teardown_module(module):
context.set_context(device_target="Ascend")
class MulAdd(nn.Cell):
def __init__(self):
......@@ -45,7 +49,9 @@ class MulAdd(nn.Cell):
def test_grad_mul_add():
mul_add = MulAdd()
assert C.grad_all(mul_add)(1, 2) == (2, 4)
x = Tensor(1, dtype=ms.int32)
y = Tensor(2, dtype=ms.int32)
assert C.grad_all(mul_add)(x, y) == (2, 4)
class InlineMulADD(nn.Cell):
......@@ -60,7 +66,9 @@ class InlineMulADD(nn.Cell):
def test_grad_inline_mul_add():
inline_mul_add = InlineMulADD()
assert C.grad_all(inline_mul_add)(1, 2) == (3, 6)
x = Tensor(1, dtype=ms.int32)
y = Tensor(2, dtype=ms.int32)
assert C.grad_all(inline_mul_add)(x, y) == (3, 6)
class WithParameter(nn.Cell):
......@@ -93,7 +101,9 @@ class WithNoBprop(nn.Cell):
def test_with_no_bprop():
with_no_bprop = WithNoBprop()
assert C.grad_all(with_no_bprop)(1, 2) == (2, 1)
x = Tensor(1, dtype=ms.int32)
y = Tensor(2, dtype=ms.int32)
assert C.grad_all(with_no_bprop)(x, y) == (2, 1)
def test_grad_in_bprop_1():
......
......@@ -19,21 +19,27 @@
@Desc :
"""
import logging
import pytest
import numpy as np
import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import context
from mindspore.ops import composite as C
from mindspore.common.api import ms_function, _executor
from mindspore.ops._grad.grad_base import bprop_getters
from mindspore.ops.primitive import prim_attr_register, PrimitiveWithInfer
from mindspore.ops.functional import tensor_add
from ...ut_filter import non_graph_engine
# pylint: disable=W0613
# pylint: disable=W0613,W0612
# W0613: unused-argument
log = logging.getLogger("test")
log.setLevel(level=logging.ERROR)
context.set_context(mode=context.GRAPH_MODE)
# Test case: use the parse obj interface use default parameter
......@@ -135,3 +141,113 @@ def test_net_with_ndarray():
input_data = np.array([[1.2, 2.1], [2.2, 3.2]]).astype('float32')
net(ms.Tensor(input_data))
def test_bprop_with_wrong_output_num():
context.set_context(check_bprop=True)
class BpropWithWrongOutputNum(PrimitiveWithInfer):
@prim_attr_register
def __init__(self):
super(BpropWithWrongOutputNum, self).__init__('BpropWithWrongOutputNum')
def __call__(self, x, y):
return x
def infer_shape(self, x_shape, yshape):
return x_shape
def infer_dtype(self, x_type, y_type):
return x_type
@bprop_getters.register(BpropWithWrongOutputNum)
def get_bprop_with_wrong_output_num(self):
"""Generate bprop for BpropWithWrongOutputNum"""
def bprop(x, y, out, dout):
return (dout,)
return bprop
class BpropWithWrongOutputNumCell(nn.Cell):
def __init__(self):
super(BpropWithWrongOutputNumCell, self).__init__()
def construct(self, x, y):
return BpropWithWrongOutputNum()(x, y)
with pytest.raises(TypeError):
C.grad_all(BpropWithWrongOutputNumCell())(1, 2)
def test_bprop_with_wrong_output_type():
context.set_context(check_bprop=True)
class BpropWithWrongOutputType(PrimitiveWithInfer):
@prim_attr_register
def __init__(self):
super(BpropWithWrongOutputType, self).__init__('BpropWithWrongOutputType')
def __call__(self, x):
return x
def infer_shape(self, x_shape):
return x_shape
def infer_dtype(self, x_type):
return x_type
@bprop_getters.register(BpropWithWrongOutputType)
def get_bprop_with_wrong_output_type(self):
"""Generate bprop for BpropWithWrongOutputType"""
def bprop(x, out, dout):
return (1,)
return bprop
class BpropWithWrongOutputTypeCell(nn.Cell):
def __init__(self):
super(BpropWithWrongOutputTypeCell, self).__init__()
def construct(self, x):
return BpropWithWrongOutputType()(x)
with pytest.raises(TypeError):
C.grad_all(BpropWithWrongOutputTypeCell())(Tensor(np.ones([64, 10]).astype(np.int32)))
def test_bprop_with_wrong_output_shape():
context.set_context(check_bprop=True)
class BpropWithWrongOutputShape(PrimitiveWithInfer):
@prim_attr_register
def __init__(self):
super(BpropWithWrongOutputShape, self).__init__('BpropWithWrongOutputShape')
def __call__(self, x):
return x
def infer_shape(self, x_shape):
return x_shape
def infer_dtype(self, x_type):
return x_type
@bprop_getters.register(BpropWithWrongOutputShape)
def get_bprop_with_wrong_output_shape(self):
"""Generate bprop for BpropWithWrongOutputShape"""
ones = Tensor(np.ones([2,]).astype(np.int32))
def bprop(x, out, dout):
return (ones,)
return bprop
class BpropWithWrongOutputShapeCell(nn.Cell):
def __init__(self):
super(BpropWithWrongOutputShapeCell, self).__init__()
def construct(self, x):
return BpropWithWrongOutputShape()(x)
with pytest.raises(TypeError):
net = BpropWithWrongOutputShapeCell()
net.set_grad()
C.grad_all(net)(Tensor(np.ones([64, 10]).astype(np.int32)))
......@@ -78,3 +78,9 @@ def test_tensor_imul():
y = Tensor(np.ones([3, 3, 3, 3]).astype(np.float32))
x *= y
assert x.asnumpy()[0][0][0][0] == 1.0
def test_tensor_pow():
x = Tensor(np.ones([3, 3, 3, 3]).astype(np.float32) * 2)
y = x ** 3
assert y.asnumpy()[0][0][0][0] == 8.0
......@@ -89,7 +89,11 @@ def test_scalar_cast_grad():
output = F.scalar_cast(x, input_t)
return output
gfn = C.grad(fx_cast)(input_x)
@ms_function
def grad_fx_cast(input_x):
return C.grad(fx_cast)(input_x)
gfn = grad_fx_cast(input_x)
expect_dx = 1
assert gfn == expect_dx
......@@ -133,25 +137,6 @@ def test_transpose_grad():
assert np.all(gout[0].asnumpy() == expect)
@non_graph_engine
def test_squeeze_grad():
""" test_squeeze_grad """
input_tensor = Tensor(np.ones(shape=[3, 2, 1]))
squeeze = P.Squeeze(2)
def fn(x):
output = squeeze(x)
return output
out = fn(input_tensor)
gfn = grad_all_with_sens(fn)
sens = Tensor(np.ones_like(out.asnumpy()))
args = [input_tensor, sens]
gout = gfn(*args)
expect = np.ones([3, 2, 1])
assert np.all(gout[0].asnumpy() == expect)
def test_select_grad():
""" test_select_grad """
select = P.Select()
......@@ -176,6 +161,25 @@ def test_select_grad():
assert np.all(gout[2].asnumpy() == expect_y)
@non_graph_engine
def test_squeeze_grad():
""" test_squeeze_grad """
input_tensor = Tensor(np.ones(shape=[3, 2, 1]))
squeeze = P.Squeeze(2)
def fn(x):
output = squeeze(x)
return output
out = fn(input_tensor)
gfn = grad_all_with_sens(fn)
sens = Tensor(np.ones_like(out.asnumpy()))
args = [input_tensor, sens]
gout = gfn(*args)
expect = np.ones([3, 2, 1])
assert np.all(gout[0].asnumpy() == expect)
def test_SubGrad():
""" test_SubGrad """
input_x = Tensor(np.array([[2, 2]]))
......
......@@ -16,6 +16,7 @@
import numpy as np
import pytest
import mindspore as ms
import mindspore.nn as nn
from mindspore import context
from mindspore.common import dtype as mstype
......@@ -23,8 +24,6 @@ from mindspore.common.parameter import Parameter, ParameterTuple
from mindspore.common.tensor import Tensor
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from mindspore.ops._grad.grad_base import bprop_getters
from mindspore.ops.primitive import prim_attr_register, PrimitiveWithInfer
from ..ut_filter import non_graph_engine
from ....mindspore_test_framework.utils.check_gradient import (
ms_function, check_jacobian, Tensor, NNGradChecker,
......@@ -156,14 +155,14 @@ def test_if_always_true():
@non_graph_engine
def test_f():
""" test_f """
res = mainf(3, 2)
res = mainf(Tensor(3, dtype=ms.int32), Tensor(2, dtype=ms.int32))
assert res == (2, 3)
@non_graph_engine
def test_grad_add_mul():
""" test_grad_add_mul """
res = grad_add_mul(3, 2)
res = grad_add_mul(Tensor(3, dtype=ms.int32), Tensor(2, dtype=ms.int32))
assert res == (2, 7)
......@@ -262,17 +261,19 @@ def test_if_tensor():
assert res == Tensor(np.ones([1]).astype(np.int32) * 4)
@ms_function
def rec(x):
""" rec """
if x > 0:
return rec(x - 1)
return x
@ms_function
def grad_rec(input_x):
return C.grad(rec)(input_x)
def test_grad_rec():
""" test_grad_rec """
res = C.grad(rec)(10)
res = grad_rec(3)
assert res == 1
......@@ -282,7 +283,6 @@ def test_me_rec():
assert res == 0
@ms_function
def t2_while(x, y):
out = y - x
i = 0
......@@ -298,8 +298,10 @@ def test_while2():
def test_grad_while2():
res = C.grad(t2_while)(2, 3)
assert res == 3
@ms_function
def df_t2_while(input_x, input_y):
return C.grad(t2_while)(input_x, input_y)
assert df_t2_while(2, 3) == 3
def if_test(a, b):
......@@ -316,7 +318,7 @@ def grad_if(x, y):
def test_grad_if():
""" test_grad_if """
assert grad_if(5, 4) == (3, 0)
assert grad_if(Tensor(5, dtype=ms.int32), Tensor(4, dtype=ms.int32)) == (3, 0)
# While loop is not unrolled in forward and backward graphs.
......@@ -421,7 +423,7 @@ def grad_while(x):
def test_grad_while():
""" test_grad_while """
assert grad_while(5) == (60,)
assert grad_while(Tensor(5, dtype=ms.int32)) == (60,)
@ms_function
......@@ -438,8 +440,10 @@ def test_factorial():
def test_grad_factorial():
res = C.grad(factorial)(3)
assert res == 11
@ms_function
def df_factorial(x):
return C.grad(factorial)(x)
assert df_factorial(3) == 11
@ms_function
......@@ -513,7 +517,7 @@ def _for(x):
ret = ret * i
return ret
@ms_function
def grad_for(x):
""" grad_for """
return C.grad_all(_for)(x)
......@@ -786,7 +790,10 @@ def multi_outputs(x, y):
def test_grad_multi_outputs():
assert C.grad_all_with_sens(multi_outputs)(2, 3, (1, 1)) == (4, 4)
@ms_function
def df_multi_outputs(x, y):
return C.grad_all_with_sens(multi_outputs)(x, y, (1, 1))
assert df_multi_outputs(2, 3) == (4, 4)
@ms_function
......@@ -813,7 +820,7 @@ def grad_refactor_simple_1(x, y):
def test_grad_refactor_simple_1():
assert C.grad_all(grad_refactor_simple_1)(2, 1) == (4, 2)
assert C.grad_all(grad_refactor_simple_1)(Tensor(2, dtype=ms.int32), Tensor(1, dtype=ms.int32)) == (4, 2)
def grad_refactor_simple_2(x, y, z):
......@@ -822,7 +829,10 @@ def grad_refactor_simple_2(x, y, z):
def test_grad_refactor_simple_2():
assert C.grad_all(grad_refactor_simple_2)(2, 3, 0) == (7, 4, 7)
x = Tensor(2, dtype=ms.int32)
y = Tensor(3, dtype=ms.int32)
z = Tensor(0, dtype=ms.int32)
assert C.grad_all(grad_refactor_simple_2)(x, y, z) == (7, 4, 7)
def grad_refactor_1(a, b):
......@@ -835,7 +845,7 @@ def grad_refactor_1(a, b):
def test_grad_refactor_1():
assert C.grad_all(grad_refactor_1)(2, 3) == (3, 2)
assert C.grad_all(grad_refactor_1)(Tensor(2, dtype=ms.int32), Tensor(3, dtype=ms.int32)) == (3, 2)
def grad_refactor_2(a, b):
......@@ -848,7 +858,7 @@ def grad_refactor_2(a, b):
def test_grad_refactor_2():
assert C.grad_all(grad_refactor_2)(2, 3) == (27, 54)
assert C.grad_all(grad_refactor_2)(Tensor(2, dtype=ms.int32), Tensor(3, dtype=ms.int32)) == (27, 54)
def grad_refactor_3(a):
......@@ -859,7 +869,10 @@ def grad_refactor_3(a):
def test_grad_refactor_3():
assert C.grad_all(grad_refactor_3)(3) == (3,)
@ms_function
def df_refactor_3(x):
return C.grad_all(grad_refactor_3)(x)
assert df_refactor_3(3) == (3,)
def grad_refactor_4(a):
......@@ -870,7 +883,7 @@ def grad_refactor_4(a):
def test_grad_refactor_4():
assert C.grad_all(grad_refactor_4)(4) == (3,)
assert C.grad_all(grad_refactor_4)(Tensor(4, dtype=ms.int32)) == (3,)
def grad_refactor_5(a):
......@@ -881,7 +894,10 @@ def grad_refactor_5(a):
def test_grad_refactor_5():
assert C.grad_all(grad_refactor_5)(1) == (1,)
@ms_function
def df_refactor_5(x):
return C.grad_all(grad_refactor_5)(x)
assert df_refactor_5(1) == (1,)
def grad_refactor_6(a, b):
......@@ -892,7 +908,7 @@ def grad_refactor_6(a, b):
def test_grad_refactor_6():
assert C.grad_all(grad_refactor_6)(3, 2) == (3, 1)
assert C.grad_all(grad_refactor_6)(Tensor(3, dtype=ms.int32), Tensor(2, dtype=ms.int32)) == (3, 1)
def grad_refactor_while(x):
......@@ -904,7 +920,10 @@ def grad_refactor_while(x):
def test_grad_refactor_9():
assert C.grad_all(grad_refactor_while)(3) == (6,)
@ms_function
def df_refactor_while(input_x):
return C.grad_all(grad_refactor_while)(input_x)
assert df_refactor_while(3) == (6,)
def grad_refactor__while_1(x):
......@@ -919,7 +938,7 @@ def grad_refactor__while_1(x):
def test_grad_refactor_10():
""" test_grad_while """
assert C.grad_all(grad_refactor__while_1)(5) == (60,)
assert C.grad_all(grad_refactor__while_1)(Tensor(5, dtype=ms.int32)) == (60,)
def test_grad_refactor_11():
......@@ -985,7 +1004,10 @@ def grad_refactor_14(a, b):
def test_grad_refactor_14():
assert C.grad_all(grad_refactor_14)(2, 3) == (3, 9)
@ms_function
def df_refactor_14(x, y):
return C.grad_all(grad_refactor_14)(x, y)
assert df_refactor_14(2, 3) == (3, 9)
# pylint: disable=using-constant-test
......@@ -1009,111 +1031,3 @@ def test_grad_if_defer_inline():
inp = Tensor(np.ones([128, 96]).astype(np.float32))
grads = C.grad_all(network)(inp)
assert grads == (Tensor(np.full([128, 96], 0.6, dtype=np.float32)),)
def test_bprop_with_wrong_output_num():
context.set_context(check_bprop=True)
class BpropWithWrongOutputNum(PrimitiveWithInfer):
@prim_attr_register
def __init__(self):
super(BpropWithWrongOutputNum, self).__init__('BpropWithWrongOutputNum')
def __call__(self, x, y):
return x
def infer_shape(self, x_shape, yshape):
return x_shape
def infer_dtype(self, x_type, y_type):
return x_type
@bprop_getters.register(BpropWithWrongOutputNum)
def get_bprop_with_wrong_output_num(self):
"""Generate bprop for BpropWithWrongOutputNum"""
def bprop(x, y, out, dout):
return (dout,)
return bprop
class BpropWithWrongOutputNumCell(nn.Cell):
def __init__(self):
super(BpropWithWrongOutputNumCell, self).__init__()
def construct(self, x, y):
return BpropWithWrongOutputNum()(x, y)
with pytest.raises(TypeError):
C.grad_all(BpropWithWrongOutputNumCell())(1, 2)
def test_bprop_with_wrong_output_type():
context.set_context(check_bprop=True)
class BpropWithWrongOutputType(PrimitiveWithInfer):
@prim_attr_register
def __init__(self):
super(BpropWithWrongOutputType, self).__init__('BpropWithWrongOutputType')
def __call__(self, x):
return x
def infer_shape(self, x_shape):
return x_shape
def infer_dtype(self, x_type):
return x_type
@bprop_getters.register(BpropWithWrongOutputType)
def get_bprop_with_wrong_output_type(self):
"""Generate bprop for BpropWithWrongOutputType"""
def bprop(x, out, dout):
return (1,)
return bprop
class BpropWithWrongOutputTypeCell(nn.Cell):
def __init__(self):
super(BpropWithWrongOutputTypeCell, self).__init__()
def construct(self, x):
return BpropWithWrongOutputType()(x)
with pytest.raises(TypeError):
C.grad_all(BpropWithWrongOutputTypeCell())(Tensor(np.ones([64, 10]).astype(np.int32)))
def test_bprop_with_wrong_output_shape():
context.set_context(check_bprop=True)
class BpropWithWrongOutputShape(PrimitiveWithInfer):
@prim_attr_register
def __init__(self):
super(BpropWithWrongOutputShape, self).__init__('BpropWithWrongOutputShape')
def __call__(self, x):
return x
def infer_shape(self, x_shape):
return x_shape
def infer_dtype(self, x_type):
return x_type
@bprop_getters.register(BpropWithWrongOutputShape)
def get_bprop_with_wrong_output_shape(self):
"""Generate bprop for BpropWithWrongOutputShape"""
ones = Tensor(np.ones([2,]).astype(np.int32))
def bprop(x, out, dout):
return (ones,)
return bprop
class BpropWithWrongOutputShapeCell(nn.Cell):
def __init__(self):
super(BpropWithWrongOutputShapeCell, self).__init__()
def construct(self, x):
return BpropWithWrongOutputShape()(x)
with pytest.raises(TypeError):
C.grad_all(BpropWithWrongOutputShapeCell())(Tensor(np.ones([64, 10]).astype(np.int32)))
......@@ -13,6 +13,7 @@
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.nn as nn
import mindspore.ops.operations as P
......@@ -154,22 +155,47 @@ def test_hook():
print(loss_output.asnumpy().shape)
bprop_debug = False
class MulAdd(nn.Cell):
def __init__(self):
super(MulAdd, self).__init__()
def construct(self, x, y):
return 2 * x + y
return 2 * x * x + y * y
def bprop(self, x, y, out, dout):
assert (x == 1)
assert (y == 2)
assert (out == 4)
assert (dout == 1)
return 3 * dout, 2 * y
global bprop_debug
bprop_debug = True
return dout, 2 * y
def test_custom_bprop():
mul_add = MulAdd()
mul_add.bprop_debug = True
assert C.grad_all(mul_add)(1, 2) == (3, 4)
x = Tensor(np.array([1, 2, 3]).astype(np.int32))
y = Tensor(np.array([2, 3, 4]).astype(np.int32))
C.grad_all(mul_add)(x, y)
assert bprop_debug
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
def construct(self, x, y):
return 2 * x * x + y * y
def test_grad_all():
net = Net()
x = Tensor(np.array([1, 2, 3]).astype(np.int32))
y = Tensor(np.array([2, 3, 4]).astype(np.int32))
res = C.grad_all(net)(x, y)
print(res)
def test_check_input():
net = Net()
x = np.array([1, 2, 3])
y = np.array([2, 3, 4])
with pytest.raises(TypeError):
net(x, y)
......@@ -46,6 +46,7 @@ def test_InsertGradientOf_1():
c = x * y
return c
@ms_function
def f(x, y):
return C.grad_all(stop_test)(x, y)
......@@ -80,6 +81,7 @@ def test_InsertGradientOf_2():
def f(x, y):
return clip_test(x, y)
@ms_function
def fd(x, y):
return C.grad_all(clip_test)(x, y)
......
......@@ -16,6 +16,7 @@
import numpy as np
import pytest
import mindspore as ms
import mindspore.common.dtype as mstype
import mindspore.nn as nn
from mindspore import Parameter, ParameterTuple
......@@ -81,16 +82,24 @@ def stop_test4(x, y):
return e
@ms_function
def grad_stop_test(x, y):
""" grad_stop_test """
return C.grad_all(stop_test2)(x, y)
@ms_function
def grad_stop_test1(x, y):
""" grad_stop_test1 """
return C.grad_all(stop_test3)(x, y)
@ms_function
def grad_stop_test5(x, y):
""" grad_stop_test5 """
return C.grad_all(stop_test5)(x, y)
def test_stop():
""" test_stop """
print("test_stop:", grad_stop_test(1, 1))
......@@ -103,7 +112,7 @@ def test_stop1():
def test_stop5():
""" test_stop1 """
print("test_stop5:", C.grad_all(stop_test5)(2, 3))
print("test_stop5:", grad_stop_test5(2, 3))
class GradWrap(nn.Cell):
......@@ -247,7 +256,7 @@ def test_stop_gradient_4():
def stop_test(x):
return stop_gradient(x)
assert C.grad_all(stop_test)(1) == (0,)
assert C.grad_all(stop_test)(Tensor(1, dtype=ms.int32)) == (0,)
def test_stop_gradient_5():
......@@ -257,7 +266,7 @@ def test_stop_gradient_5():
ret = x + y
return ret
assert C.grad_all(stop_test)(1) == (1,)
assert C.grad_all(stop_test)(Tensor(1, dtype=ms.int32)) == (1,)
def test_stop_gradient_6():
......@@ -266,7 +275,7 @@ def test_stop_gradient_6():
ret = stop_gradient(ret)
return ret
assert C.grad_all(stop_test)(1, 3) == (0, 0)
assert C.grad_all(stop_test)(Tensor(1, dtype=ms.int32), Tensor(3, dtype=ms.int32)) == (0, 0)
class PrimWithMultiOutputs(PrimitiveWithInfer):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册