diff --git a/mindspore/ccsrc/pipeline/jit/parse/parse.cc b/mindspore/ccsrc/pipeline/jit/parse/parse.cc index eee9d00bbe349ed885e350fec24fe1f545119692..f525c9fccf997645f2ad1396d7c3bc1824677c47 100644 --- a/mindspore/ccsrc/pipeline/jit/parse/parse.cc +++ b/mindspore/ccsrc/pipeline/jit/parse/parse.cc @@ -737,8 +737,7 @@ AnfNodePtr Parser::ParseCompare(const FunctionBlockPtr &block, const py::object return block->func_graph()->NewCNode({op_node, left_node, right_node}); } -AnfNodePtr Parser::ProcessBoolOpValueList(const FunctionBlockPtr &block, const py::list &value_list, - const py::object &op) { +AnfNodePtr Parser::ProcessBoolOpValueList(const FunctionBlockPtr &block, const py::list &value_list, AstSubType mode) { // if there is only one bool op now if (value_list.size() == 1) { AnfNodePtr first_node = ParseExprNode(block, value_list[0]); @@ -749,11 +748,41 @@ AnfNodePtr Parser::ProcessBoolOpValueList(const FunctionBlockPtr &block, const p for (size_t i = 1; i < value_list.size(); i++) { rest.append(value_list[i]); } + MS_EXCEPTION_IF_NULL(block); + TraceManager::DebugTrace(std::make_shared<TraceIfExpTrueBranch>(block->func_graph()->debug_info())); + FunctionBlockPtr true_block = MakeFunctionBlock(*this); + TraceManager::EndTrace(); + TraceManager::DebugTrace(std::make_shared<TraceIfExpFalseBranch>(block->func_graph()->debug_info())); + FunctionBlockPtr false_block = MakeFunctionBlock(*this); + TraceManager::EndTrace(); + MakeConditionBlocks(block, true_block, false_block); + FunctionBlockPtr b1, b2; + + // if it is and, we need to process the rest nodes; + // if it is or, we continue to next + if (mode == AST_SUB_TYPE_AND) { + b1 = true_block; + b2 = false_block; + } else if (mode == AST_SUB_TYPE_OR) { + b2 = true_block; + b1 = false_block; + } else { + MS_LOG(ERROR) << "Not supported mode: " << mode; + return nullptr; + } + AnfNodePtr test_node = ParseExprNode(block, first); + AnfNodePtr rest_node = ProcessBoolOpValueList(b1, rest, mode); + b1->func_graph()->set_output(rest_node); + 
b2->func_graph()->set_output(test_node); + + auto cond_node = block->ForceToBoolNode(test_node); + auto switch_app = + block->func_graph()->NewCNode({NewValueNode(prim::kPrimSwitch), cond_node, NewValueNode(true_block->func_graph()), + NewValueNode(false_block->func_graph())}); - AnfNodePtr first_node = ParseExprNode(block, first); - AnfNodePtr rest_node = ProcessBoolOpValueList(block, rest, op); - auto op_node = block->MakeResolveAstOp(op); - return block->func_graph()->NewCNode({op_node, first_node, rest_node}); + std::vector<AnfNodePtr> call_graph_nodes{switch_app}; + auto switch_app_call = block->func_graph()->NewCNode(call_graph_nodes); + return switch_app_call; } } @@ -761,8 +790,13 @@ AnfNodePtr Parser::ProcessBoolOpValueList(const FunctionBlockPtr &block, const p AnfNodePtr Parser::ParseBoolOp(const FunctionBlockPtr &block, const py::object &node) { MS_LOG(DEBUG) << "Process ast BoolOp"; py::object op_node = python_adapter::GetPyObjAttr(node, "op"); + AstSubType op_type = ast_->GetOpType(op_node); + if (op_type == AST_SUB_TYPE_UNKNOWN) { + MS_LOG(WARNING) << "ProcessBoolOp, got unknown op type"; + return nullptr; + } py::list op_values = python_adapter::GetPyObjAttr(node, "values"); - return ProcessBoolOpValueList(block, op_values, op_node); + return ProcessBoolOpValueList(block, op_values, op_type); } // Process a function def diff --git a/mindspore/ccsrc/pipeline/jit/parse/parse.h b/mindspore/ccsrc/pipeline/jit/parse/parse.h index 47366b664eddcd82068b21d80703ae35fb3bc29c..b922248e5e6b9300c0d8102278c5d847f3878c17 100644 --- a/mindspore/ccsrc/pipeline/jit/parse/parse.h +++ b/mindspore/ccsrc/pipeline/jit/parse/parse.h @@ -206,7 +206,7 @@ class Parser { void HandleAssignSubscript(const FunctionBlockPtr &block, const py::object &targ, const AnfNodePtr &assigned_node); // process a bool operation value list - AnfNodePtr ProcessBoolOpValueList(const FunctionBlockPtr &block, const py::list &value_list, const py::object &op); + AnfNodePtr ProcessBoolOpValueList(const 
FunctionBlockPtr &block, const py::list &value_list, AstSubType mode); CNodePtr GenerateIteratorInFor(const FunctionBlockPtr &block, const pybind11::object &node, const AnfNodePtr &op_iter); diff --git a/mindspore/ops/composite/multitype_ops/logic_not_impl.py b/mindspore/ops/composite/multitype_ops/logic_not_impl.py index 6705145a64dc6ad47691c76f60af1ccc4c08da73..73219afec16e44bac46a5b8244a78286f15f8af7 100644 --- a/mindspore/ops/composite/multitype_ops/logic_not_impl.py +++ b/mindspore/ops/composite/multitype_ops/logic_not_impl.py @@ -45,7 +45,9 @@ def _logical_not_tensor(x): Returns: Tensor, Return logical not operation result of x. """ - return F.logical_not(x) + if F.isconstant(x): + return F.bool_not(x.__bool__()) + return F.logical_not(x.__bool__()) @logical_not.register("Tuple") diff --git a/tests/st/control/test_ascend_control_sink.py b/tests/st/control/test_ascend_control_sink.py index 39af571c14f6427b577b681a3951108d5f2951f0..0e416e205ed2e08042fb85dbd4607f2ba8977b88 100644 --- a/tests/st/control/test_ascend_control_sink.py +++ b/tests/st/control/test_ascend_control_sink.py @@ -61,8 +61,7 @@ class ControlSimpleIfWithAssign(nn.Cell): class ControlIfinIf(nn.Cell): - def __init__(self): - super().__init__() + """pass""" def construct(self, x, y): if x > y: @@ -151,6 +150,40 @@ class ControlMixedWhileIf(nn.Cell): return out +class AndOperation(nn.Cell): + def __init__(self): + super().__init__() + self.reduce_sum = op.ReduceSum() + + def construct(self, x, y): + x_sum = self.reduce_sum(x) + y_sum = self.reduce_sum(y) + out = x_sum and y_sum + return out + + +class OrOperation(nn.Cell): + def __init__(self): + super().__init__() + self.reduce_sum = op.ReduceSum() + + def construct(self, x, y): + x_sum = self.reduce_sum(x) + y_sum = self.reduce_sum(y) + out = x_sum or y_sum + return out + + +class NotOperation(nn.Cell): + def __init__(self): + super().__init__() + self.reduce_sum = op.ReduceSum() + + def construct(self, x): + x_sum = self.reduce_sum(x) + return 
not x_sum + + @pytest.mark.level0 @pytest.mark.platform_arm_ascend_training @pytest.mark.platform_x86_ascend_training @@ -248,3 +281,27 @@ def test_mixed_while_if(): output = net(Tensor(x), Tensor(y), Tensor(z), c2, c4) expect = np.array(3318).astype(np.int32) assert np.allclose(expect, output.asnumpy(), 0.0001, 0.0001) + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_and_or_operation(): + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + x = np.array([0, 1]).astype(np.float32) + y = np.array([0, 0]).astype(np.float32) + net = AndOperation() + output = net(Tensor(x), Tensor(y)) + expect = np.sum(x) and np.sum(y) + assert np.allclose(expect, output.asnumpy(), 0.0001, 0.0001) + + net = OrOperation() + output = net(Tensor(x), Tensor(y)) + expect = np.sum(x) or np.sum(y) + assert np.allclose(expect, output.asnumpy(), 0.0001, 0.0001) + + net = NotOperation() + output = net(Tensor(x)) + expect = not np.sum(x) + assert np.allclose(expect, output.asnumpy(), 0.0001, 0.0001) diff --git a/tests/ut/python/ops/test_python_operators.py b/tests/ut/python/ops/test_python_operators.py index dd85c3310c55e32097e13f7bec33cc32730ffcdf..34881b5089f74d57b1ae10a9e6d8f00a095ab727 100644 --- a/tests/ut/python/ops/test_python_operators.py +++ b/tests/ut/python/ops/test_python_operators.py @@ -103,15 +103,15 @@ class LogicalTensorOpsNet(nn.Cell): self.const_true = Tensor(True, dtype=mstype.bool_) def construct(self, x, y): - ret = x and y and (y or self.const_true) and (not self.const_true) + ret = x and y and (y or self.const_true) and (not y) return ret test_case_ops = [ ('CompareOpsNet', { 'block': ComparisonOpsNet(), - 'desc_inputs': [Tensor(np.ones([6, 9, 10]), dtype=mstype.float32), - Tensor(np.zeros([6, 9, 10]), dtype=mstype.float32)]}), + 'desc_inputs': [Tensor(1.0, dtype=mstype.float32), + Tensor(1.0, dtype=mstype.float32)]}), ('MathOpsNet', { 'block': MathOpsNet(), 
'desc_inputs': [Tensor(np.ones([6, 9, 10]), dtype=mstype.float32), @@ -126,8 +126,8 @@ test_case_ops = [ Tensor(np.zeros([6, 9, 10]), dtype=mstype.float32)]}), ('LogicalTensorOps', { 'block': LogicalTensorOpsNet(), - 'desc_inputs': [Tensor(np.ones([6, 9, 10]).astype(np.bool_), dtype=mstype.bool_), - Tensor(np.zeros([6, 9, 10]).astype(np.bool_), dtype=mstype.bool_)]}), + 'desc_inputs': [Tensor(True, dtype=mstype.bool_), + Tensor(False, dtype=mstype.bool_)]}), ] test_case_lists = [test_case_ops] diff --git a/tests/vm_impl/math_ops_vm_impl.py b/tests/vm_impl/math_ops_vm_impl.py index 9a614c9c9240c979f8c0bb5dbc18b9a095a063f5..76ccebbb8e4e141fa4103ab4926ea277886890ba 100644 --- a/tests/vm_impl/math_ops_vm_impl.py +++ b/tests/vm_impl/math_ops_vm_impl.py @@ -41,10 +41,12 @@ def vm_impl_tensor_add(self): # pylint: disable=used-before-assignment @vm_impl_getters.register(P.LogicalNot) def vm_impl_logical_not(self): - x = x.asnumpy() - out = vm.logical_not(x) - return Tensor(out) + def vm_impl(x): + x = x.asnumpy() + out = vm.logical_not(x) + return Tensor(out) + return vm_impl @vm_impl_getters.register(P.MatMul) def vm_impl_mat_mul(self):