From 7c233a57fa285842054612eb83d612adb7c05e96 Mon Sep 17 00:00:00 2001 From: buxue Date: Mon, 20 Apr 2020 17:40:46 +0800 Subject: [PATCH] support python func print and != for list with none --- mindspore/_extends/parse/resources.py | 1 + mindspore/_extends/parse/trope.py | 4 +- .../composite/multitype_ops/not_equal_impl.py | 37 +++++++++++++++++-- mindspore/ops/functional.py | 2 +- mindspore/ops/operations/_grad_ops.py | 1 + .../ut/python/pipeline/parse/test_operator.py | 6 ++- tests/vm_impl/nn_ops_vm_impl.py | 2 - 7 files changed, 43 insertions(+), 10 deletions(-) diff --git a/mindspore/_extends/parse/resources.py b/mindspore/_extends/parse/resources.py index c2c271669..7178cd263 100644 --- a/mindspore/_extends/parse/resources.py +++ b/mindspore/_extends/parse/resources.py @@ -114,6 +114,7 @@ convert_object_map = { T.map: C.HyperMap(), T.partial: F.partial, T.zip: C.zip_operation, + T.print: F.print_, # custom define operation T.iter: M.ms_iter, diff --git a/mindspore/_extends/parse/trope.py b/mindspore/_extends/parse/trope.py index 9f8f67fba..7b40adcd1 100644 --- a/mindspore/_extends/parse/trope.py +++ b/mindspore/_extends/parse/trope.py @@ -27,7 +27,7 @@ from operator import ( # noqa # support system function call from builtins import ( # noqa - bool, getattr, setattr, len, iter, next, pow, range, map, zip + bool, getattr, setattr, len, iter, next, pow, range, map, zip, print ) # support functools @@ -44,7 +44,7 @@ __all__ = ['add', 'sub', 'mul', 'truediv', 'floordiv', 'mod', 'eq', 'ne', 'lt', 'not_', 'and_', 'or_', 'xor', 'lshift', 'rshift', 'invert', 'is_', 'is_not', 'contains', 'matmul', 'getitem', 'setitem', 'bool', 'getattr', 'setattr', 'len', 'iter', 'next', 'pow', 'range', 'map', 'zip', - 'partial', + 'partial', 'print', 'exp', 'log', 'sin', 'cos', 'tan'] diff --git a/mindspore/ops/composite/multitype_ops/not_equal_impl.py b/mindspore/ops/composite/multitype_ops/not_equal_impl.py index de099a2b8..7196f370c 100644 --- a/mindspore/ops/composite/multitype_ops/not_equal_impl.py +++ b/mindspore/ops/composite/multitype_ops/not_equal_impl.py @@ -132,7 +132,7 @@ def _none_not_equal_scalar(x, y): @not_equal.register("Tuple", "Tuple") -def _euqal_tuple(x, y): +def _not_euqal_tuple(x, y): """ Determine if two tuples are not equal by element. @@ -147,7 +147,7 @@ def _euqal_tuple(x, y): @not_equal.register("List", "List") -def _euqal_list(x, y): +def _not_euqal_list(x, y): """ Determine if two lists are not equal by element. @@ -162,7 +162,7 @@ def _euqal_list(x, y): @not_equal.register("Tuple", "None") -def _tuple_euqal_none(x, y): +def _tuple_not_euqal_none(x, y): """ Determine if tuple element not equals none element. @@ -190,6 +190,7 @@ def _none_not_equal_tuple(x, y): """ return True + @not_equal.register("Tensor", "Number") @not_equal.register("Number", "Tensor") @not_equal.register("Tensor", "Tensor") @@ -235,3 +236,33 @@ def _none_not_equal_tensor(x, y): bool, return True. """ return True + + +@not_equal.register("List", "None") +def _list_not_equal_none(x, y): + """ + Determine if list not equal none. + + Args: + x (list): The first input which is a list. + y (none): The second input which is none. + + Returns: + bool, return true. + """ + return True + + +@not_equal.register("None", "List") +def _none_not_equal_list(x, y): + """ + Determine if none not equal list. + + Args: + x (none): The first input which is none. + y (list): The second input which is a list. + + Returns: + bool, return true. + """ + return True diff --git a/mindspore/ops/functional.py b/mindspore/ops/functional.py index 0ed750beb..d94ef3a11 100644 --- a/mindspore/ops/functional.py +++ b/mindspore/ops/functional.py @@ -66,7 +66,7 @@ scalar_to_array = P.ScalarToArray() scalar_to_tensor = P.ScalarToTensor() tuple_to_array = P.TupleToArray() scalar_cast = P.ScalarCast() - +print_ = P.Print() tuple_setitem = Primitive('tuple_setitem') tuple_getitem = Primitive('tuple_getitem') diff --git a/mindspore/ops/operations/_grad_ops.py b/mindspore/ops/operations/_grad_ops.py index 48d1a2a89..9670ddd86 100644 --- a/mindspore/ops/operations/_grad_ops.py +++ b/mindspore/ops/operations/_grad_ops.py @@ -108,6 +108,7 @@ class BinaryCrossEntropyGrad(PrimitiveWithInfer): validator.check_two_types_same('x_type', x_type, 'weight_type', weight_type) return x_type + class ConcatOffset(PrimitiveWithInfer): """primitive for computing Concat's gradient.""" diff --git a/tests/ut/python/pipeline/parse/test_operator.py b/tests/ut/python/pipeline/parse/test_operator.py index a3c5f7e42..6ae02fa96 100644 --- a/tests/ut/python/pipeline/parse/test_operator.py +++ b/tests/ut/python/pipeline/parse/test_operator.py @@ -160,8 +160,10 @@ def test_ops(): ret_floor = p // q + q // p ret = ret_pow + ret_mod + ret_floor if self.int > self.float: - if self.str_a + self.str_b == "helloworld": - return ret + if [1, 2, 3] != None: + if self.str_a + self.str_b == "helloworld": + print("hello world") + return ret return x net = OpsNet(9, 2) diff --git a/tests/vm_impl/nn_ops_vm_impl.py b/tests/vm_impl/nn_ops_vm_impl.py index fc1fa9502..8794acbbd 100644 --- a/tests/vm_impl/nn_ops_vm_impl.py +++ b/tests/vm_impl/nn_ops_vm_impl.py @@ -151,8 +151,6 @@ def vm_impl_max_pool_grad_with_argmax(self): """Generate vm_impl function for MaxPoolGradWithArgmax""" def vm_impl(x, dout, argmax): - print("buxue") - print(argmax) x = x.asnumpy() dout = dout.asnumpy() arg_max = argmax.asnumpy() -- GitLab