Commit 46cad4d3 authored by Megvii Engine Team

feat(functional/ops): add _assert_equal

GitOrigin-RevId: b7ce4158b7087886e7a9aef5c89b682cae26c646
Parent 585aa561
@@ -7,7 +7,7 @@
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# pylint: disable=redefined-builtin
-from . import metric, vision
+from . import metric, utils, vision
from .elemwise import *
from .math import *
from .nn import *
...
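(With `utils` re-exported from `megengine.functional`, the helper defined below becomes reachable as `F.utils._assert_equal`, which is how the new tests at the end of this commit invoke it.)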
@@ -11,6 +11,7 @@ from typing import Iterable, Union
import numpy as np
from ..tensor import Tensor
+from .elemwise import abs, maximum, minimum
from .math import topk as _topk
from .tensor import broadcast_to, transpose
...
+# -*- coding: utf-8 -*-
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+from ..core._imperative_rt.core2 import apply
+from ..core._imperative_rt.core2 import sync as _sync
+from ..core.ops.builtin import AssertEqual
+from ..tensor import Tensor
+from .elemwise import abs, maximum, minimum
+
+def _assert_equal(
+    expect: Tensor, actual: Tensor, *, maxerr: float = 0.0001, verbose: bool = False
+):
+    r"""
+    Asserts that two tensors are equal (within ``maxerr``) and returns the expected value (the first input).
+
+    It is a variant of the Python ``assert`` that is symbolically traceable (similar to ``numpy.testing.assert_equal``).
+    To verify the correctness of a model in plain Python, one can simply ``assert`` on its states and outputs.
+    However, when verifying correctness across backends with a *dumped* model
+    (or inside a :class:`~jit.trace` context), no Python code can be executed,
+    so :func:`~functional.utils._assert_equal` must be used instead.
+
+    :param expect: expected tensor value
+    :param actual: tensor whose value is checked
+    :param maxerr: maximum allowed error; the error is defined as the minimum of the absolute and relative error
+    :param verbose: whether to print the error to stdout during opr execution
+    :return: the expected tensor
+
+    Examples:
+
+    .. testcode::
+
+        import numpy as np
+        from megengine import tensor
+        import megengine.functional as F
+
+        x = tensor([1, 2, 3], np.float32)
+        y = tensor([1, 2, 3], np.float32)
+        print(F.utils._assert_equal(x, y, maxerr=0).numpy())
+
+    Outputs:
+
+    .. testoutput::
+
+        [1. 2. 3.]
+    """
+    err = (
+        abs(expect - actual)
+        / maximum(minimum(abs(expect), abs(actual)), Tensor(1.0, dtype="float32"))
+    ).max()
+    result = apply(AssertEqual(maxerr=maxerr, verbose=verbose), expect, actual, err)[0]
+    _sync()  # sync interpreter to get exception
+    return result
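The error metric above is `|expect - actual| / max(min(|expect|, |actual|), 1)`: a relative error that degrades to the absolute error for small-magnitude values. The docstring's motivation is easiest to see under `jit.trace`: a plain Python `assert` fires only once, at trace time, whereas `_assert_equal` is recorded as a graph operator and re-checked on every execution, including runs of a dumped model. A minimal sketch of traced usage (the function name and the `symbolic=True` flag are illustrative assumptions, not part of this commit):

    import numpy as np
    from megengine import tensor
    from megengine.jit import trace
    import megengine.functional as F

    @trace(symbolic=True)  # assumption: symbolic tracing, as in MegEngine's jit API
    def fwd(x, y, expect):
        z = x + y
        # recorded as an AssertEqual opr: checked on every run of the compiled graph
        return F.utils._assert_equal(expect, z, maxerr=1e-4)

    x = tensor([1.0, 2.0], np.float32)
    y = tensor([3.0, 4.0], np.float32)
    out = fwd(x, y, tensor([4.0, 6.0], np.float32))  # passes and returns the expected value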
@@ -28,7 +28,12 @@ from ..core._imperative_rt.core2 import (
    unset_compiled,
    unset_tracing,
)
-from ..core._imperative_rt.ops import CollectiveComm, RemoteRecv, RemoteSend
+from ..core._imperative_rt.ops import (
+    AssertEqual,
+    CollectiveComm,
+    RemoteRecv,
+    RemoteSend,
+)
from ..core._trace_option import set_symbolic_shape
from ..core._wrap import device as as_device
from ..core.ops.builtin import BackwardGraph, OpDef
@@ -110,7 +115,7 @@ class TensorInfo:
        self.data_reader = None


-_io_op_types = {CollectiveComm, RemoteSend, RemoteRecv}
+_io_op_types = {AssertEqual, CollectiveComm, RemoteSend, RemoteRecv}


class trace:
...
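AssertEqual joins the set of ops the tracer treats as I/O-like. Presumably this is because, like the collective and remote ops, its value lies in a side effect (raising on mismatch) rather than in a tensor output an optimizer could fold away, so the tracer must keep and execute it.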
@@ -21,6 +21,7 @@ from megengine.core._trace_option import use_symbolic_shape
from megengine.core.autodiff.grad import Grad
from megengine.core.tensor.utils import make_shape_tuple
from megengine.distributed.helper import get_device_count_by_fork
+from megengine.jit import trace


def test_where():
@@ -746,3 +747,18 @@ def test_ones(val):
    shp = tensor(val)
    np_shp = np.array(val)
    np.testing.assert_equal(F.ones(shp), np.ones(np_shp))
+
+
+def test_assert_equal():
+    shape = (2, 3, 4, 5)
+    x = F.ones(shape, dtype=np.float32)
+    y = F.zeros(shape, dtype=np.float32) + 1.00001
+    z = F.utils._assert_equal(x, y)
+
+
+def test_assert_not_equal():
+    shape = (2, 3, 4, 5)
+    x = F.ones(shape, dtype=np.float32)
+    y = F.zeros(shape, dtype=np.float32) + 1.1
+    with pytest.raises(RuntimeError):
+        z = F.utils._assert_equal(x, y)
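Why the first test passes and the second raises under the default `maxerr=1e-4`: for `x = 1` and `y = 1.00001` the metric gives `1e-5 / max(min(1, 1.00001), 1) = 1e-5`, below the threshold, while `y = 1.1` gives an error of `0.1`, well above it. A plain-NumPy restatement of the metric, for intuition only:

    import numpy as np

    def err(expect, actual):
        # same metric as _assert_equal: relative error, clamped to absolute
        # error when both magnitudes are below 1
        return (np.abs(expect - actual)
                / np.maximum(np.minimum(np.abs(expect), np.abs(actual)), 1.0)).max()

    print(err(np.ones(4, np.float32), np.full(4, 1.00001, np.float32)))  # ~1e-5: passes
    print(err(np.ones(4, np.float32), np.full(4, 1.1, np.float32)))      # ~0.1: raises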
@@ -453,18 +453,20 @@ namespace { namespace assert_equal {
auto apply_on_var_node(
        const OpDef& def,
        const VarNodeArray& inputs) {
-    auto&& op = static_cast<const AssertEqual&>(def);
-    mgb_assert(inputs.size() == 2);
-    OperatorNodeConfig config{op.make_name()};
-    return opr::AssertEqual::make(inputs[0], inputs[1], op.param(), config);
+    auto&& op = def.cast_final<AssertEqual>();
+    if (inputs.size() == 2) {
+        return opr::AssertEqual::make(inputs[0], inputs[1], op.param());
+    } else {
+        // workaround for MiniGraph, which only allows one opr in the graph
+        mgb_assert(inputs.size() == 3);
+        return opr::AssertEqual::make(inputs[0], inputs[1], inputs[2], op.param(), {});
+    }
}

OP_TRAIT_REG(AssertEqual, AssertEqual)
    .apply_on_var_node(apply_on_var_node)
    .fallback();
-}}
+}} // assert_equal
namespace { namespace uniform_rng {
auto apply_on_var_node(
...
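Note how the two branches line up with the callers: the Python helper computes `err` eagerly and applies AssertEqual with three inputs (expect, actual, err), so a MiniGraph containing just this one opr never needs the extra elemwise oprs that computing the error would require; the two-input form, where the opr derives the error itself, evidently remains for ordinary graph construction.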
@@ -445,6 +445,12 @@ public:
    size_t nr_oprs_in_graph() const override {return m_opr_refkeeper.size();}

+    void record_async_error(std::unique_ptr<MegBrainError> async_exc) override {
+        if (!ProxyGraph::tm_async_error) {
+            std::swap(async_exc, tm_async_error);
+        }
+    }
+
    std::unique_ptr<cg::AsyncExecutable> compile(const OutputSpec &out_spec) override {mgb_assert(0);}
    SmallVector<std::unique_ptr<cg::AsyncExecutable>> compile_multi_part(
            const SmallVector<OutputSpec>& out_specs) override {mgb_assert(0);}
@@ -457,7 +463,6 @@ public:
    size_t get_device_memory_size(CompNode cn) override {mgb_assert(0);}
    size_t clear_device_memory() override {mgb_assert(0);}
    void set_as_subgraph(ComputingGraph &par_graph) override {mgb_assert(0);}
-    void record_async_error(std::unique_ptr<MegBrainError> async_exc) override {mgb_assert(0);}
};

std::atomic<size_t> ProxyGraph::ProxyGraphImpl::m_node_id = 0;
@@ -861,6 +866,8 @@ TensorPtr ProxyGraph::as_tensor(cg::OperatorNodeBase* opr, bool share) {
    }
}

+thread_local std::unique_ptr<MegBrainError> ProxyGraph::tm_async_error;
+
} // namespace imperative
} // namespace mgb
...
@@ -24,6 +24,9 @@ namespace imperative {

class ProxyGraph : public NonCopyableObj {
public:
    static ProxyGraph* get_default_graph();
+    static std::unique_ptr<MegBrainError> get_async_error() {
+        return std::move(tm_async_error);
+    }

    /********************** Physical Tensor API **********************/
@@ -98,6 +101,8 @@ private:
    std::unique_ptr<ExecEnv> m_env;
    std::unique_ptr<StaticInferManager> m_static_infer_manager;
    std::unique_ptr<SeqCompNodeOptimizer> m_seq_comp_node_optimizer;
+
+    static thread_local std::unique_ptr<MegBrainError> tm_async_error;
};

} // namespace imperative
...
@@ -101,6 +101,10 @@ apply_on_physical_tensor(const OpDef& def,
        }
    }
    exec(def, inputs, outputs);
+    auto async_error = ProxyGraph::get_async_error();
+    if (async_error) {
+        throw *async_error;
+    }
    return outputs;
}
...
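Taken together, the ProxyGraph changes give AssertEqual a way to fail loudly in imperative mode: graph execution reports the mismatch through record_async_error, which latches the first error per thread; apply_on_physical_tensor then fetches and rethrows it synchronously (the `_sync()` call in `_assert_equal` ensures execution has finished by that point). A Python analogue of the thread-local latch, purely illustrative:

    import threading

    _tls = threading.local()  # stands in for ProxyGraph::tm_async_error

    def record_async_error(exc):
        # first error wins, as in the C++ override above
        if getattr(_tls, "err", None) is None:
            _tls.err = exc

    def get_async_error():
        # fetch-and-clear: moving the unique_ptr out resets the slot in C++
        err = getattr(_tls, "err", None)
        _tls.err = None
        return err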