diff --git a/imperative/python/megengine/functional/__init__.py b/imperative/python/megengine/functional/__init__.py
index fcd76b1081483ce3d670ef39fe1bff7ebba07e27..7d5c23650849dfc37cff5d8fea42f0d5fd5932ed 100644
--- a/imperative/python/megengine/functional/__init__.py
+++ b/imperative/python/megengine/functional/__init__.py
@@ -7,7 +7,7 @@
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # pylint: disable=redefined-builtin
-from . import metric, vision
+from . import metric, utils, vision
 from .elemwise import *
 from .math import *
 from .nn import *
diff --git a/imperative/python/megengine/functional/metric.py b/imperative/python/megengine/functional/metric.py
index 91a03e1d25133b1fa38e2883b0b8a0566793a44d..6eb2ccdcfee9c2e8e5b7446c950ed60ee1a52a5e 100644
--- a/imperative/python/megengine/functional/metric.py
+++ b/imperative/python/megengine/functional/metric.py
@@ -11,6 +11,7 @@ from typing import Iterable, Union
 import numpy as np
 
 from ..tensor import Tensor
+from .elemwise import abs, maximum, minimum
 from .math import topk as _topk
 from .tensor import broadcast_to, transpose
 
diff --git a/imperative/python/megengine/functional/utils.py b/imperative/python/megengine/functional/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..15616f3c62d6dfaef87ea4412d56af2384c68bb6
--- /dev/null
+++ b/imperative/python/megengine/functional/utils.py
@@ -0,0 +1,57 @@
+# -*- coding: utf-8 -*-
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+from ..core._imperative_rt.core2 import apply
+from ..core._imperative_rt.core2 import sync as _sync
+from ..core.ops.builtin import AssertEqual
+from ..tensor import Tensor
+from .elemwise import abs, maximum, minimum
+
+
+def _assert_equal(
+    expect: Tensor, actual: Tensor, *, maxerr: float = 0.0001, verbose: bool = False
+):
+    r"""
+    Asserts that two tensors are equal and returns the expected value (the first input).
+    It is a variant of the Python ``assert`` statement which is symbolically traceable (similar to ``numpy.testing.assert_equal``).
+    To verify the correctness of a model in eager mode, one can simply ``assert`` on its states and outputs.
+    However, when verifying correctness across backends for a *dumped* model
+    (or inside a :class:`~jit.trace` context), no Python code can be executed.
+    In that case :func:`~functional.utils._assert_equal` has to be used instead.
+
+    :param expect: expected tensor value
+    :param actual: tensor to check value
+    :param maxerr: maximum allowed error; the error is defined as the minimum of the absolute and relative error
+    :param verbose: whether to print maxerr to stdout during opr execution
+    :return: expected tensor
+
+    Examples:
+
+    .. testcode::
+
+        import numpy as np
+        from megengine import tensor
+        import megengine.functional as F
+
+        x = tensor([1, 2, 3], np.float32)
+        y = tensor([1, 2, 3], np.float32)
+        print(F.utils._assert_equal(x, y, maxerr=0).numpy())
+
+    Outputs:
+
+    .. testoutput::
+
+        [1. 2. 3.]
+    """
+    err = (
+        abs(expect - actual)
+        / maximum(minimum(abs(expect), abs(actual)), Tensor(1.0, dtype="float32"))
+    ).max()
+    result = apply(AssertEqual(maxerr=maxerr, verbose=verbose), expect, actual, err)[0]
+    _sync()  # sync interpreter to get exception
+    return result
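
Note: for orientation, here is a quick numeric illustration of the error metric defined in the hunk above, as a NumPy-only sketch (the helper name toy_err and the sample values are made up). The error is the absolute difference divided by the smaller magnitude of the two inputs, floored at 1, so it acts as a relative error for large values and an absolute error for small ones.

    import numpy as np

    def toy_err(expect, actual):
        # mirrors err = |expect - actual| / max(min(|expect|, |actual|), 1) from utils.py above
        expect = np.asarray(expect, dtype="float32")
        actual = np.asarray(actual, dtype="float32")
        return (np.abs(expect - actual)
                / np.maximum(np.minimum(np.abs(expect), np.abs(actual)), 1.0)).max()

    print(toy_err([1.0, 2.0, 3.0], [1.00001, 2.00001, 3.00001]))  # ~1e-5, within the default maxerr=1e-4
    print(toy_err([1.0, 2.0, 3.0], [1.1, 2.1, 3.1]))              # 0.1, far above maxerr; AssertEqual would fire
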
+ """ + err = ( + abs(expect - actual) + / maximum(minimum(abs(expect), abs(actual)), Tensor(1.0, dtype="float32")) + ).max() + result = apply(AssertEqual(maxerr=maxerr, verbose=verbose), expect, actual, err)[0] + _sync() # sync interpreter to get exception + return result diff --git a/imperative/python/megengine/jit/tracing.py b/imperative/python/megengine/jit/tracing.py index 91c6e7d347aeaad2f8ca96755a43764a4ca1bb6a..fc8f12e8687b71532d91b2219f22460814389377 100644 --- a/imperative/python/megengine/jit/tracing.py +++ b/imperative/python/megengine/jit/tracing.py @@ -28,7 +28,12 @@ from ..core._imperative_rt.core2 import ( unset_compiled, unset_tracing, ) -from ..core._imperative_rt.ops import CollectiveComm, RemoteRecv, RemoteSend +from ..core._imperative_rt.ops import ( + AssertEqual, + CollectiveComm, + RemoteRecv, + RemoteSend, +) from ..core._trace_option import set_symbolic_shape from ..core._wrap import device as as_device from ..core.ops.builtin import BackwardGraph, OpDef @@ -110,7 +115,7 @@ class TensorInfo: self.data_reader = None -_io_op_types = {CollectiveComm, RemoteSend, RemoteRecv} +_io_op_types = {AssertEqual, CollectiveComm, RemoteSend, RemoteRecv} class trace: diff --git a/imperative/python/test/unit/functional/test_functional.py b/imperative/python/test/unit/functional/test_functional.py index 84ce29bc7c42d7685d22f05efd5bc8747c2b675a..54437072b3a6a52137448053b8d159dad4ccc5fd 100644 --- a/imperative/python/test/unit/functional/test_functional.py +++ b/imperative/python/test/unit/functional/test_functional.py @@ -21,6 +21,7 @@ from megengine.core._trace_option import use_symbolic_shape from megengine.core.autodiff.grad import Grad from megengine.core.tensor.utils import make_shape_tuple from megengine.distributed.helper import get_device_count_by_fork +from megengine.jit import trace def test_where(): @@ -746,3 +747,18 @@ def test_ones(val): shp = tensor(val) np_shp = np.array(val) np.testing.assert_equal(F.ones(shp), np.ones(np_shp)) + + +def test_assert_equal(): + shape = (2, 3, 4, 5) + x = F.ones(shape, dtype=np.float32) + y = F.zeros(shape, dtype=np.float32) + 1.00001 + z = F.utils._assert_equal(x, y) + + +def test_assert_not_equal(): + shape = (2, 3, 4, 5) + x = F.ones(shape, dtype=np.float32) + y = F.zeros(shape, dtype=np.float32) + 1.1 + with pytest.raises(RuntimeError): + z = F.utils._assert_equal(x, y) diff --git a/imperative/src/impl/ops/specializations.cpp b/imperative/src/impl/ops/specializations.cpp index 36a27e871a3516cc6ccab3d105f7990e81d7b51d..7e55d44d02adbeaa3e34a6d0c281e40f5e9ed34f 100644 --- a/imperative/src/impl/ops/specializations.cpp +++ b/imperative/src/impl/ops/specializations.cpp @@ -451,20 +451,22 @@ OP_TRAIT_REG(Identity, Identity) namespace { namespace assert_equal { auto apply_on_var_node( - const OpDef& def, - const VarNodeArray& inputs) { - auto&& op = static_cast(def); - mgb_assert(inputs.size() == 2); - OperatorNodeConfig config{op.make_name()}; - return opr::AssertEqual::make(inputs[0], inputs[1], op.param(), config); - + const OpDef& def, + const VarNodeArray& inputs) { + auto&& op = def.cast_final(); + if (inputs.size() == 2) { + return opr::AssertEqual::make(inputs[0], inputs[1], op.param()); + } else { + // workaround for MiniGraph, which only allow one opr in the graph + mgb_assert(inputs.size() == 3); + return opr::AssertEqual::make(inputs[0], inputs[1], inputs[2], op.param(), {}); } +} OP_TRAIT_REG(AssertEqual, AssertEqual) .apply_on_var_node(apply_on_var_node) .fallback(); - -}} +}} // assert_equal namespace { namespace uniform_rng 
diff --git a/imperative/src/impl/proxy_graph.cpp b/imperative/src/impl/proxy_graph.cpp
index d96d6c94598e2a3b20bfc72c9581034459f42c2b..c23ecf48593fdae67b2aa86e1eb882488e37f0d1 100644
--- a/imperative/src/impl/proxy_graph.cpp
+++ b/imperative/src/impl/proxy_graph.cpp
@@ -445,6 +445,12 @@ public:
 
     size_t nr_oprs_in_graph() const override {return m_opr_refkeeper.size();}
 
+    void record_async_error(std::unique_ptr<MegBrainError> async_exc) override {
+        if (!ProxyGraph::tm_async_error) {
+            std::swap(async_exc, tm_async_error);
+        }
+    }
+
     std::unique_ptr<AsyncExecutable> compile(const OutputSpec &out_spec) override {mgb_assert(0);}
     SmallVector<std::unique_ptr<AsyncExecutable>> compile_multi_part(
             const SmallVector<OutputSpec>& out_specs) override {mgb_assert(0);}
@@ -457,7 +463,6 @@ public:
     size_t get_device_memory_size(CompNode cn) override {mgb_assert(0);}
     size_t clear_device_memory() override {mgb_assert(0);}
     void set_as_subgraph(ComputingGraph &par_graph) override {mgb_assert(0);}
-    void record_async_error(std::unique_ptr<MegBrainError> async_exc) override {mgb_assert(0);}
 };
 
 std::atomic<size_t> ProxyGraph::ProxyGraphImpl::m_node_id = 0;
@@ -861,6 +866,8 @@ TensorPtr ProxyGraph::as_tensor(cg::OperatorNodeBase* opr, bool share) {
     }
 }
 
+thread_local std::unique_ptr<MegBrainError> ProxyGraph::tm_async_error;
+
 } // namespace imperative
 } // namespace mgb
 
diff --git a/imperative/src/impl/proxy_graph.h b/imperative/src/impl/proxy_graph.h
index 4fc7dcfc4a1fafdcf5e8f57a820142578f2c6e98..e0927cbc7b134c1065f2594d04ae525e90137599 100644
--- a/imperative/src/impl/proxy_graph.h
+++ b/imperative/src/impl/proxy_graph.h
@@ -24,6 +24,9 @@ namespace imperative {
 class ProxyGraph : public NonCopyableObj {
 public:
     static ProxyGraph* get_default_graph();
+    static std::unique_ptr<MegBrainError> get_async_error() {
+        return std::move(tm_async_error);
+    }
 
     /********************** Physical Tensor API **********************/
 
@@ -98,6 +101,8 @@ private:
     std::unique_ptr<ExecEnv> m_env;
     std::unique_ptr<StaticInferManager> m_static_infer_manager;
    std::unique_ptr<SeqCompNodeOptimizer> m_seq_comp_node_optimizer;
+
+    static thread_local std::unique_ptr<MegBrainError> tm_async_error;
 };
 
 } // namespace imperative
diff --git a/imperative/src/impl/proxy_graph_detail.cpp b/imperative/src/impl/proxy_graph_detail.cpp
index 46f5e542d90b8561b4f3810e964056dd7efd2ba7..d68ce881c771d8bcc834f38db8c0a9dcd63d09dd 100644
--- a/imperative/src/impl/proxy_graph_detail.cpp
+++ b/imperative/src/impl/proxy_graph_detail.cpp
@@ -101,6 +101,10 @@ apply_on_physical_tensor(const OpDef& def,
         }
     }
     exec(def, inputs, outputs);
+    auto async_error = ProxyGraph::get_async_error();
+    if (async_error) {
+        throw *async_error;
+    }
     return outputs;
 }
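
Note: in eager mode a failed check surfaces as a Python exception: the AssertEqual opr raises asynchronously, ProxyGraphImpl::record_async_error stores the error, apply_on_physical_tensor rethrows it, and the trailing _sync() in _assert_equal lets it reach the caller. A small eager-mode sketch mirroring test_assert_not_equal above (the tensor values are arbitrary):

    import numpy as np
    import megengine.functional as F

    x = F.ones((2, 3), dtype=np.float32)
    y = x + 0.1                      # off by far more than the default maxerr of 1e-4
    try:
        F.utils._assert_equal(x, y)
    except RuntimeError as exc:      # the recorded async error is rethrown synchronously
        print("mismatch detected:", exc)
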