Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
46cad4d3
MegEngine
项目概览
MegEngine 天元
/
MegEngine
接近 2 年 前同步成功
通知
414
Star
4708
Fork
583
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
46cad4d3
编写于
3月 11, 2021
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(functional/ops): add _assert_equal
GitOrigin-RevId: b7ce4158b7087886e7a9aef5c89b682cae26c646
上级
585aa561
变更
9
隐藏空白更改
内联
并排
Showing
9 changed file
with
110 addition
and
13 deletion
+110
-13
imperative/python/megengine/functional/__init__.py
imperative/python/megengine/functional/__init__.py
+1
-1
imperative/python/megengine/functional/metric.py
imperative/python/megengine/functional/metric.py
+1
-0
imperative/python/megengine/functional/utils.py
imperative/python/megengine/functional/utils.py
+57
-0
imperative/python/megengine/jit/tracing.py
imperative/python/megengine/jit/tracing.py
+7
-2
imperative/python/test/unit/functional/test_functional.py
imperative/python/test/unit/functional/test_functional.py
+16
-0
imperative/src/impl/ops/specializations.cpp
imperative/src/impl/ops/specializations.cpp
+11
-9
imperative/src/impl/proxy_graph.cpp
imperative/src/impl/proxy_graph.cpp
+8
-1
imperative/src/impl/proxy_graph.h
imperative/src/impl/proxy_graph.h
+5
-0
imperative/src/impl/proxy_graph_detail.cpp
imperative/src/impl/proxy_graph_detail.cpp
+4
-0
未找到文件。
imperative/python/megengine/functional/__init__.py
浏览文件 @
46cad4d3
...
...
@@ -7,7 +7,7 @@
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# pylint: disable=redefined-builtin
from
.
import
metric
,
vision
from
.
import
metric
,
utils
,
vision
from
.elemwise
import
*
from
.math
import
*
from
.nn
import
*
...
...
imperative/python/megengine/functional/metric.py
浏览文件 @
46cad4d3
...
...
@@ -11,6 +11,7 @@ from typing import Iterable, Union
import
numpy
as
np
from
..tensor
import
Tensor
from
.elemwise
import
abs
,
maximum
,
minimum
from
.math
import
topk
as
_topk
from
.tensor
import
broadcast_to
,
transpose
...
...
imperative/python/megengine/functional/utils.py
0 → 100644
浏览文件 @
46cad4d3
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from
..core._imperative_rt.core2
import
apply
from
..core._imperative_rt.core2
import
sync
as
_sync
from
..core.ops.builtin
import
AssertEqual
from
..tensor
import
Tensor
from
.elemwise
import
abs
,
maximum
,
minimum
def
_assert_equal
(
expect
:
Tensor
,
actual
:
Tensor
,
*
,
maxerr
:
float
=
0.0001
,
verbose
:
bool
=
False
):
r
"""
Asserts two tensors equal and returns expected value (first input).
It is a variant of python assert which is symbolically traceable (similar to ``numpy.testing.assert_equal``).
If we want to verify the correctness of model, just ``assert`` its states and outputs.
While sometimes we need to verify the correctness at different backends for *dumped* model
(or in :class:`~jit.trace` context), and no python code could be executed in that case.
Thus we have to use :func:`~functional.utils._assert_equal` instead.
:param expect: expected tensor value
:param actual: tensor to check value
:param maxerr: max allowed error; error is defined as the minimal of absolute and relative error
:param verbose: whether to print maxerr to stdout during opr exec
:return: expected tensor
Examples:
.. testcode::
import numpy as np
from megengine import tensor
import megengine.functional as F
x = tensor([1, 2, 3], np.float32)
y = tensor([1, 2, 3], np.float32)
print(F.utils._assert_equal(x, y, maxerr=0).numpy())
Outputs:
.. testoutput::
[1. 2. 3.]
"""
err
=
(
abs
(
expect
-
actual
)
/
maximum
(
minimum
(
abs
(
expect
),
abs
(
actual
)),
Tensor
(
1.0
,
dtype
=
"float32"
))
).
max
()
result
=
apply
(
AssertEqual
(
maxerr
=
maxerr
,
verbose
=
verbose
),
expect
,
actual
,
err
)[
0
]
_sync
()
# sync interpreter to get exception
return
result
imperative/python/megengine/jit/tracing.py
浏览文件 @
46cad4d3
...
...
@@ -28,7 +28,12 @@ from ..core._imperative_rt.core2 import (
unset_compiled
,
unset_tracing
,
)
from
..core._imperative_rt.ops
import
CollectiveComm
,
RemoteRecv
,
RemoteSend
from
..core._imperative_rt.ops
import
(
AssertEqual
,
CollectiveComm
,
RemoteRecv
,
RemoteSend
,
)
from
..core._trace_option
import
set_symbolic_shape
from
..core._wrap
import
device
as
as_device
from
..core.ops.builtin
import
BackwardGraph
,
OpDef
...
...
@@ -110,7 +115,7 @@ class TensorInfo:
self
.
data_reader
=
None
_io_op_types
=
{
CollectiveComm
,
RemoteSend
,
RemoteRecv
}
_io_op_types
=
{
AssertEqual
,
CollectiveComm
,
RemoteSend
,
RemoteRecv
}
class
trace
:
...
...
imperative/python/test/unit/functional/test_functional.py
浏览文件 @
46cad4d3
...
...
@@ -21,6 +21,7 @@ from megengine.core._trace_option import use_symbolic_shape
from
megengine.core.autodiff.grad
import
Grad
from
megengine.core.tensor.utils
import
make_shape_tuple
from
megengine.distributed.helper
import
get_device_count_by_fork
from
megengine.jit
import
trace
def
test_where
():
...
...
@@ -746,3 +747,18 @@ def test_ones(val):
shp
=
tensor
(
val
)
np_shp
=
np
.
array
(
val
)
np
.
testing
.
assert_equal
(
F
.
ones
(
shp
),
np
.
ones
(
np_shp
))
def
test_assert_equal
():
shape
=
(
2
,
3
,
4
,
5
)
x
=
F
.
ones
(
shape
,
dtype
=
np
.
float32
)
y
=
F
.
zeros
(
shape
,
dtype
=
np
.
float32
)
+
1.00001
z
=
F
.
utils
.
_assert_equal
(
x
,
y
)
def
test_assert_not_equal
():
shape
=
(
2
,
3
,
4
,
5
)
x
=
F
.
ones
(
shape
,
dtype
=
np
.
float32
)
y
=
F
.
zeros
(
shape
,
dtype
=
np
.
float32
)
+
1.1
with
pytest
.
raises
(
RuntimeError
):
z
=
F
.
utils
.
_assert_equal
(
x
,
y
)
imperative/src/impl/ops/specializations.cpp
浏览文件 @
46cad4d3
...
...
@@ -451,20 +451,22 @@ OP_TRAIT_REG(Identity, Identity)
namespace
{
namespace
assert_equal
{
auto
apply_on_var_node
(
const
OpDef
&
def
,
const
VarNodeArray
&
inputs
)
{
auto
&&
op
=
static_cast
<
const
AssertEqual
&>
(
def
);
mgb_assert
(
inputs
.
size
()
==
2
);
OperatorNodeConfig
config
{
op
.
make_name
()};
return
opr
::
AssertEqual
::
make
(
inputs
[
0
],
inputs
[
1
],
op
.
param
(),
config
);
const
OpDef
&
def
,
const
VarNodeArray
&
inputs
)
{
auto
&&
op
=
def
.
cast_final
<
AssertEqual
>
();
if
(
inputs
.
size
()
==
2
)
{
return
opr
::
AssertEqual
::
make
(
inputs
[
0
],
inputs
[
1
],
op
.
param
());
}
else
{
// workaround for MiniGraph, which only allow one opr in the graph
mgb_assert
(
inputs
.
size
()
==
3
);
return
opr
::
AssertEqual
::
make
(
inputs
[
0
],
inputs
[
1
],
inputs
[
2
],
op
.
param
(),
{});
}
}
OP_TRAIT_REG
(
AssertEqual
,
AssertEqual
)
.
apply_on_var_node
(
apply_on_var_node
)
.
fallback
();
}}
}}
// assert_equal
namespace
{
namespace
uniform_rng
{
auto
apply_on_var_node
(
...
...
imperative/src/impl/proxy_graph.cpp
浏览文件 @
46cad4d3
...
...
@@ -445,6 +445,12 @@ public:
size_t
nr_oprs_in_graph
()
const
override
{
return
m_opr_refkeeper
.
size
();}
void
record_async_error
(
std
::
unique_ptr
<
MegBrainError
>
async_exc
)
override
{
if
(
!
ProxyGraph
::
tm_async_error
)
{
std
::
swap
(
async_exc
,
tm_async_error
);
}
}
std
::
unique_ptr
<
cg
::
AsyncExecutable
>
compile
(
const
OutputSpec
&
out_spec
)
override
{
mgb_assert
(
0
);}
SmallVector
<
std
::
unique_ptr
<
cg
::
AsyncExecutable
>>
compile_multi_part
(
const
SmallVector
<
OutputSpec
>&
out_specs
)
override
{
mgb_assert
(
0
);}
...
...
@@ -457,7 +463,6 @@ public:
size_t
get_device_memory_size
(
CompNode
cn
)
override
{
mgb_assert
(
0
);}
size_t
clear_device_memory
()
override
{
mgb_assert
(
0
);}
void
set_as_subgraph
(
ComputingGraph
&
par_graph
)
override
{
mgb_assert
(
0
);}
void
record_async_error
(
std
::
unique_ptr
<
MegBrainError
>
async_exc
)
override
{
mgb_assert
(
0
);}
};
std
::
atomic
<
size_t
>
ProxyGraph
::
ProxyGraphImpl
::
m_node_id
=
0
;
...
...
@@ -861,6 +866,8 @@ TensorPtr ProxyGraph::as_tensor(cg::OperatorNodeBase* opr, bool share) {
}
}
thread_local
std
::
unique_ptr
<
MegBrainError
>
ProxyGraph
::
tm_async_error
;
}
// namespace imperative
}
// namespace mgb
...
...
imperative/src/impl/proxy_graph.h
浏览文件 @
46cad4d3
...
...
@@ -24,6 +24,9 @@ namespace imperative {
class
ProxyGraph
:
public
NonCopyableObj
{
public:
static
ProxyGraph
*
get_default_graph
();
static
std
::
unique_ptr
<
MegBrainError
>
get_async_error
()
{
return
std
::
move
(
tm_async_error
);
}
/********************** Physical Tensor API **********************/
...
...
@@ -98,6 +101,8 @@ private:
std
::
unique_ptr
<
ExecEnv
>
m_env
;
std
::
unique_ptr
<
StaticInferManager
>
m_static_infer_manager
;
std
::
unique_ptr
<
SeqCompNodeOptimizer
>
m_seq_comp_node_optimizer
;
static
thread_local
std
::
unique_ptr
<
MegBrainError
>
tm_async_error
;
};
}
// namespace imperative
...
...
imperative/src/impl/proxy_graph_detail.cpp
浏览文件 @
46cad4d3
...
...
@@ -101,6 +101,10 @@ apply_on_physical_tensor(const OpDef& def,
}
}
exec
(
def
,
inputs
,
outputs
);
auto
async_error
=
ProxyGraph
::
get_async_error
();
if
(
async_error
)
{
throw
*
async_error
;
}
return
outputs
;
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录