提交 078738ad 编写于 作者: K kingfo

add tensor mod & floordiv operation

上级 c22c865c
...@@ -158,7 +158,9 @@ class _MindSporeFunction: ...@@ -158,7 +158,9 @@ class _MindSporeFunction:
# replace key with obj info and object ext info when fn is a method # replace key with obj info and object ext info when fn is a method
if self.obj is not None: if self.obj is not None:
self.obj.__parse_method__ = method_name self.obj.__parse_method__ = method_name
generate_name = self.obj.__module__ + "." + str(self.obj.create_time) generate_name = self.obj.__module__ + "."
if self.obj.__class__.__name__ != "ClipByNorm":
generate_name = generate_name + str(self.obj.create_time)
if self.identify_obj is not None: if self.identify_obj is not None:
generate_name = generate_name + str(id(self.identify_obj)) generate_name = generate_name + str(id(self.identify_obj))
......
...@@ -102,16 +102,14 @@ class Tensor(Tensor_): ...@@ -102,16 +102,14 @@ class Tensor(Tensor_):
return out return out
def __iadd__(self, other): def __iadd__(self, other):
out = self.__add__(other) return self.__add__(other)
return out
def __radd__(self, other): def __radd__(self, other):
out = tensor_operator_registry.get('__add__')(self, other) out = tensor_operator_registry.get('__add__')(self, other)
return out return out
def __imul__(self, other): def __imul__(self, other):
out = self.__mul__(other) return self.__mul__(other)
return out
def __rmul__(self, other): def __rmul__(self, other):
out = tensor_operator_registry.get('__mul__')(self, other) out = tensor_operator_registry.get('__mul__')(self, other)
...@@ -130,8 +128,7 @@ class Tensor(Tensor_): ...@@ -130,8 +128,7 @@ class Tensor(Tensor_):
return out return out
def __isub__(self, other): def __isub__(self, other):
out = self.__sub__(other) return self.__sub__(other)
return out
def __rsub__(self, other): def __rsub__(self, other):
out = tensor_operator_registry.get('__sub__')(other, self) out = tensor_operator_registry.get('__sub__')(other, self)
...@@ -168,6 +165,18 @@ class Tensor(Tensor_): ...@@ -168,6 +165,18 @@ class Tensor(Tensor_):
return 1 return 1
return out[0] return out[0]
def __mod__(self, other):
return tensor_operator_registry.get('__mod__')(self, other)
def __imod__(self, other):
return self.__mod__(other)
def __floordiv__(self, other):
return tensor_operator_registry.get('__floordiv__')(self, other)
def __ifloordiv__(self, other):
return self.__floordiv__(other)
def __str__(self): def __str__(self):
if self.dtype == mstype.type_none: if self.dtype == mstype.type_none:
return "Unknown Tensor type!" return "Unknown Tensor type!"
......
...@@ -157,6 +157,8 @@ tensor_operator_registry.register('__add__', tensor_add) ...@@ -157,6 +157,8 @@ tensor_operator_registry.register('__add__', tensor_add)
tensor_operator_registry.register('__sub__', tensor_sub) tensor_operator_registry.register('__sub__', tensor_sub)
tensor_operator_registry.register('__mul__', tensor_mul) tensor_operator_registry.register('__mul__', tensor_mul)
tensor_operator_registry.register('__truediv__', tensor_div) tensor_operator_registry.register('__truediv__', tensor_div)
tensor_operator_registry.register('__mod__', tensor_mod)
tensor_operator_registry.register('__floordiv__', tensor_floordiv)
#ms cannot support Tensor(True) compare #ms cannot support Tensor(True) compare
tensor_operator_registry.register('__eq__', equal) tensor_operator_registry.register('__eq__', equal)
tensor_operator_registry.register('__ne__', not_equal) tensor_operator_registry.register('__ne__', not_equal)
......
...@@ -24,13 +24,15 @@ import pytest ...@@ -24,13 +24,15 @@ import pytest
import mindspore as ms import mindspore as ms
import mindspore.common.api as me import mindspore.common.api as me
import mindspore.nn as nn import mindspore.nn as nn
from mindspore import Tensor from mindspore import Tensor, context
from mindspore.common.initializer import initializer from mindspore.common.initializer import initializer
from mindspore.common.parameter import Parameter from mindspore.common.parameter import Parameter
from ..ut_filter import non_graph_engine from ..ut_filter import non_graph_engine
ndarr = np.ones((2, 3)) ndarr = np.ones((2, 3))
context.set_context(mode=context.GRAPH_MODE)
def test_tensor_flatten(): def test_tensor_flatten():
with pytest.raises(AttributeError): with pytest.raises(AttributeError):
...@@ -452,5 +454,11 @@ def test_tensor_operation(): ...@@ -452,5 +454,11 @@ def test_tensor_operation():
assert np.all(res.asnumpy() == np.ones((3, 3)) * 2) assert np.all(res.asnumpy() == np.ones((3, 3)) * 2)
res = 8 / x res = 8 / x
assert np.all(res.asnumpy() == np.ones((3, 3)) * 2) assert np.all(res.asnumpy() == np.ones((3, 3)) * 2)
res = x % 3
assert np.all(res.asnumpy() == np.ones((3, 3)))
res = x // 3
assert np.all(res.asnumpy() == np.ones((3, 3)))
x %= 3
assert np.all(x.asnumpy() == np.ones((3, 3)))
with pytest.raises(ValueError): with pytest.raises(ValueError):
res = x * (2, 3) res = x * (2, 3)
...@@ -18,8 +18,7 @@ import pytest ...@@ -18,8 +18,7 @@ import pytest
import mindspore.nn as nn import mindspore.nn as nn
from mindspore.common.tensor import Tensor from mindspore.common.tensor import Tensor
from mindspore.nn import WithGradCell, WithLossCell from mindspore.nn import WithGradCell
from mindspore.nn.optim import Momentum
from mindspore.ops import operations as P from mindspore.ops import operations as P
...@@ -63,17 +62,6 @@ def test_lenet_pynative_train_net(): ...@@ -63,17 +62,6 @@ def test_lenet_pynative_train_net():
loss_fn = nn.SoftmaxCrossEntropyWithLogits(is_grad=False) loss_fn = nn.SoftmaxCrossEntropyWithLogits(is_grad=False)
grad_fn = nn.SoftmaxCrossEntropyWithLogits() grad_fn = nn.SoftmaxCrossEntropyWithLogits()
grad_net = WithGradCell(net, grad_fn, sens=dout) grad_net = WithGradCell(net, grad_fn, sens=dout)
gradients = grad_net(data, label)
# update parameters
opt = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
opt(gradients)
# verification
if i == verification_step:
loss_net = WithLossCell(net, loss_fn)
loss_output = loss_net(data, label)
print("The loss of %s-th iteration is %s" % (i, loss_output.asnumpy()))
def test_lenet_pynative_train_model(): def test_lenet_pynative_train_model():
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册