提交 7c233a57 编写于 作者: B buxue

support python func print and != for list with none

上级 679dbd27
......@@ -114,6 +114,7 @@ convert_object_map = {
T.map: C.HyperMap(),
T.partial: F.partial,
T.zip: C.zip_operation,
T.print: F.print_,
# custom define operation
T.iter: M.ms_iter,
......
......@@ -27,7 +27,7 @@ from operator import ( # noqa
# support system function call
from builtins import ( # noqa
bool, getattr, setattr, len, iter, next, pow, range, map, zip
bool, getattr, setattr, len, iter, next, pow, range, map, zip, print
)
# support functools
......@@ -44,7 +44,7 @@ __all__ = ['add', 'sub', 'mul', 'truediv', 'floordiv', 'mod', 'eq', 'ne', 'lt',
'not_', 'and_', 'or_', 'xor', 'lshift', 'rshift', 'invert', 'is_', 'is_not', 'contains',
'matmul', 'getitem', 'setitem',
'bool', 'getattr', 'setattr', 'len', 'iter', 'next', 'pow', 'range', 'map', 'zip',
'partial',
'partial', 'print',
'exp', 'log', 'sin', 'cos', 'tan']
......
......@@ -132,7 +132,7 @@ def _none_not_equal_scalar(x, y):
@not_equal.register("Tuple", "Tuple")
def _euqal_tuple(x, y):
def _not_euqal_tuple(x, y):
"""
Determine if two tuples are not equal by element.
......@@ -147,7 +147,7 @@ def _euqal_tuple(x, y):
@not_equal.register("List", "List")
def _euqal_list(x, y):
def _not_euqal_list(x, y):
"""
Determine if two lists are not equal by element.
......@@ -162,7 +162,7 @@ def _euqal_list(x, y):
@not_equal.register("Tuple", "None")
def _tuple_euqal_none(x, y):
def _tuple_not_euqal_none(x, y):
"""
Determine if tuple element not equals none element.
......@@ -190,6 +190,7 @@ def _none_not_equal_tuple(x, y):
"""
return True
@not_equal.register("Tensor", "Number")
@not_equal.register("Number", "Tensor")
@not_equal.register("Tensor", "Tensor")
......@@ -235,3 +236,33 @@ def _none_not_equal_tensor(x, y):
bool, return True.
"""
return True
@not_equal.register("List", "None")
def _list_not_equal_none(x, y):
"""
Determine if list not equal none.
Args:
x (list): The first input which is a list.
y (none): The second input which is none.
Returns:
bool, return true.
"""
return True
@not_equal.register("None", "List")
def _none_not_equal_list(x, y):
"""
Determine if none not equal list.
Args:
x (none): The first input which is none.
y (list): The second input which is a list.
Returns:
bool, return true.
"""
return True
......@@ -66,7 +66,7 @@ scalar_to_array = P.ScalarToArray()
scalar_to_tensor = P.ScalarToTensor()
tuple_to_array = P.TupleToArray()
scalar_cast = P.ScalarCast()
print_ = P.Print()
tuple_setitem = Primitive('tuple_setitem')
tuple_getitem = Primitive('tuple_getitem')
......
......@@ -108,6 +108,7 @@ class BinaryCrossEntropyGrad(PrimitiveWithInfer):
validator.check_two_types_same('x_type', x_type, 'weight_type', weight_type)
return x_type
class ConcatOffset(PrimitiveWithInfer):
"""primitive for computing Concat's gradient."""
......
......@@ -160,8 +160,10 @@ def test_ops():
ret_floor = p // q + q // p
ret = ret_pow + ret_mod + ret_floor
if self.int > self.float:
if self.str_a + self.str_b == "helloworld":
return ret
if [1, 2, 3] != None:
if self.str_a + self.str_b == "helloworld":
print("hello world")
return ret
return x
net = OpsNet(9, 2)
......
......@@ -151,8 +151,6 @@ def vm_impl_max_pool_grad_with_argmax(self):
"""Generate vm_impl function for MaxPoolGradWithArgmax"""
def vm_impl(x, dout, argmax):
print("buxue")
print(argmax)
x = x.asnumpy()
dout = dout.asnumpy()
arg_max = argmax.asnumpy()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册