Commit 51397bfc
Authored June 24, 2022 by Megvii Engine Team

feat(mgb): supports value infer and empty input tensor in ElemwiseMultiType
GitOrigin-RevId: 05577a8bc8e214dcd7d7fc138ef952fc881c7a88
Parent: 247e2f59
Showing 8 changed files with 227 additions and 57 deletions (+227, −57).
Changed files:
imperative/python/megengine/core/tensor/array_method.py  (+6, −7)
imperative/python/megengine/functional/math.py  (+2, −2)
imperative/python/src/tensor.cpp  (+1, −1)
imperative/src/impl/transformations/dtype_promote.cpp  (+36, −0)
src/opr/impl/basic_arith.cpp  (+0, −46)
src/opr/impl/nn_int.cpp  (+123, −1)
src/opr/include/megbrain/opr/nn_int.h  (+16, −0)
src/opr/include/megbrain/opr/utility.h  (+43, −0)
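Taken together, the change set makes the comparison-style ElemwiseMultiType modes behave like ordinary elemwise ops: results carry the plain "bool" dtype, Python scalars and mixed dtypes are converted, empty inputs are legal, and outputs can be statically value-inferred. A minimal sketch of the user-visible behavior (illustrative values, not taken from the commit):

    import megengine as mge

    x = mge.tensor([1.0, 2.0, 3.0])
    y = mge.tensor([2.0, 2.0, 2.0])

    # Comparison dunders dispatch to ElemwiseMultiType with dtype="bool".
    mask = x < y
    print(mask.dtype)    # -> bool
    print(mask.numpy())  # -> [ True False False]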
imperative/python/megengine/core/tensor/array_method.py

@@ -37,7 +37,6 @@ _ElwMod = builtin.Elemwise.Mode
 def _elemwise_multi_type(*args, mode, **kwargs):
     op = builtin.ElemwiseMultiType(mode=mode, **kwargs)
-    args = convert_inputs(*args)
     (result,) = apply(op, *args)
     return result
@@ -249,22 +248,22 @@ class ArrayMethodMixin(abc.ABC):
     __hash__ = None  # due to __eq__ diviates from python convention

     __lt__ = lambda self, value: _elemwise_multi_type(
-        self, value, mode="lt", dtype="Bool"
+        self, value, mode="lt", dtype="bool"
     )
     __le__ = lambda self, value: _elemwise_multi_type(
-        self, value, mode="leq", dtype="Bool"
+        self, value, mode="leq", dtype="bool"
     )
     __gt__ = lambda self, value: _elemwise_multi_type(
-        value, self, mode="lt", dtype="Bool"
+        value, self, mode="lt", dtype="bool"
     )
     __ge__ = lambda self, value: _elemwise_multi_type(
-        value, self, mode="leq", dtype="Bool"
+        value, self, mode="leq", dtype="bool"
     )
     __eq__ = lambda self, value: _elemwise_multi_type(
-        self, value, mode="eq", dtype="Bool"
+        self, value, mode="eq", dtype="bool"
     )
     __ne__ = lambda self, value: _elemwise_multi_type(
-        self, value, mode="neq", dtype="Bool"
+        self, value, mode="neq", dtype="bool"
     )

     __neg__ = _unary_elwise(_ElwMod.NEGATE)
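One detail worth noting in this hunk: there are no dedicated "gt"/"geq" modes; __gt__ and __ge__ reuse "lt" and "leq" with the operands swapped. An illustrative sketch of the equivalence (values are hypothetical):

    import megengine as mge

    x = mge.tensor([1, 3])
    y = mge.tensor([2, 2])

    # x > y lowers to _elemwise_multi_type(y, x, mode="lt", dtype="bool"),
    # i.e. "y < x" with the operands swapped.
    print((x > y).numpy())  # -> [False  True]
    print((y < x).numpy())  # -> [False  True]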
imperative/python/megengine/functional/math.py

@@ -52,7 +52,7 @@ def isnan(inp: Tensor) -> Tensor:
     >>> F.isnan(x).numpy()
     array([False,  True, False])
     """
-    return _elemwise_multi_type(inp, mode="isnan", dtype="Bool")
+    return _elemwise_multi_type(inp, mode="isnan", dtype="bool")


 def isinf(inp: Tensor) -> Tensor:
@@ -69,7 +69,7 @@ def isinf(inp: Tensor) -> Tensor:
     >>> F.isinf(x).numpy()
     array([False,  True, False])
     """
-    return _elemwise_multi_type(inp, mode="isinf", dtype="Bool")
+    return _elemwise_multi_type(inp, mode="isinf", dtype="bool")


 def sign(inp: Tensor):
imperative/python/src/tensor.cpp

@@ -118,7 +118,7 @@ PyObject* py_apply(
             tensors[i] = tw->m_tensor->data();
         } else if (
                 DTypePromoteCfg::convert_input_enabled &&
-                op->same_type<Elemwise>()) {
+                (op->same_type<Elemwise>() || op->same_type<ElemwiseMultiType>())) {
             tensors[i] = convert_pyinput_to_tensor(i);
         } else {
             PyErr_SetString(PyExc_TypeError, "py_apply expects tensor as inputs");
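Allowing convert_pyinput_to_tensor for ElemwiseMultiType means non-tensor Python inputs are now converted inside py_apply, which is what lets the convert_inputs call be dropped from _elemwise_multi_type above. A hedged sketch of what this keeps working (illustrative values):

    import megengine as mge

    x = mge.tensor([0.5, 1.5])

    # The Python float 1.0 is converted to a tensor inside py_apply before
    # the ElemwiseMultiType "lt" mode is applied.
    print((x < 1.0).numpy())  # -> [ True False]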
imperative/src/impl/transformations/dtype_promote.cpp

@@ -53,6 +53,41 @@ mgb::DType get_promoted_dtype(const SmallVector<DType>& dtypes) {
     return ret;
 }

+ValueRefList elemwise_multi_type_rule(const OpDef& op, Span<ValueRef> inputs) {
+    auto&& elem_op = op.cast_final_safe<ElemwiseMultiType>();
+
+    static std::unordered_set<ElemwiseMultiType::Mode> cast_case = {
+            ElemwiseMultiType::Mode::EQ,
+            ElemwiseMultiType::Mode::NEQ,
+            ElemwiseMultiType::Mode::LT,
+            ElemwiseMultiType::Mode::LEQ,
+    };
+
+    if (cast_case.find(elem_op.mode) == cast_case.end()) {
+        return imperative::apply(op, inputs);
+    }
+
+    SmallVector<DType> dtypes(inputs.size());
+    for (size_t i = 0; i < inputs.size(); ++i) {
+        dtypes[i] = *(inputs[i].dtype());
+    }
+
+    ValueRefList converted(inputs.size());
+    mgb::DType target_dtype = get_promoted_dtype(dtypes);
+
+    for (size_t i = 0; i < inputs.size(); ++i) {
+        if (!is_quantized_dtype(dtypes[i]) && dtypes[i] != target_dtype &&
+            DTypePromoteCfg::convert_input_enabled) {
+            converted[i] = imperative::apply(
+                    ApplyOp(*TypeCvt::make(target_dtype)), inputs[i])[0];
+            dtypes[i] = target_dtype;
+        } else {
+            converted[i] = inputs[i];
+        }
+    }
+
+    return imperative::apply(op, converted);
+}
+
 ValueRefList elemwise_rule(const OpDef& op, Span<ValueRef> inputs) {
     auto&& elem_op = op.cast_final_safe<Elemwise>();
@@ -349,6 +384,7 @@ ValueRefList naive_promote_rule(const OpDef& op, Span<ValueRef> inputs) {
 struct DTypePromoteRuleRegistry {
     DTypePromoteRuleRegistry() {
         register_dtype_promote_rule<Elemwise>(elemwise_rule);
+        register_dtype_promote_rule<ElemwiseMultiType>(elemwise_multi_type_rule);
         register_dtype_promote_rule<Concat>(naive_promote_rule);
         register_dtype_promote_rule<GroupLocal>(naive_promote_rule);
         register_dtype_promote_rule<Reduce>(reduce_rule);
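The new rule only intervenes for the four comparison modes (EQ/NEQ/LT/LEQ); every other ElemwiseMultiType mode falls through to a plain imperative::apply. For the comparison modes, non-quantized inputs whose dtype differs from the promoted target are cast with TypeCvt first, so tensors of different dtypes can be compared. Illustrative sketch (values are hypothetical):

    import megengine as mge

    a = mge.tensor([1, 2], dtype="int32")
    b = mge.tensor([1.5, 1.5], dtype="float32")

    # int32 is promoted to float32 (via TypeCvt) before the "lt" mode runs.
    print((a < b).numpy())  # -> [ True False]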
src/opr/impl/basic_arith.cpp

@@ -16,52 +16,6 @@
 using namespace mgb;
 using namespace opr;

-namespace {
-//! global operator instance for static inference
-template <class Opr>
-class StaticInferOpr {
-    intl::UniqPtrWithCN<Opr> m_opr;
-    MGB_MUTEX m_mtx;
-
-public:
-    class Lock {
-        friend class StaticInferOpr;
-        StaticInferOpr* m_owner;
-
-        explicit Lock(StaticInferOpr* owner) : m_owner{owner} {
-#if !__DEPLOY_ON_XP_SP2__
-            m_owner->m_mtx.lock();
-#endif
-        }
-
-    public:
-        Lock(Lock&& rhs) : m_owner{rhs.m_owner} { rhs.m_owner = nullptr; }
-
-        ~Lock() {
-#if !__DEPLOY_ON_XP_SP2__
-            if (m_owner)
-                m_owner->m_mtx.unlock();
-#endif
-        }
-
-        Lock& operator=(const Lock&) = delete;
-        Lock& operator=(Lock&&) = delete;
-
-        intl::UniqPtrWithCN<Opr>& operator()() { return m_owner->m_opr; }
-    };
-
-    //! lock and acquire the operator
-    Lock lock() {
-        Lock ret{this};
-        if (!m_opr) {
-            m_opr = intl::create_megdnn_opr<Opr>(CompNode::default_cpu());
-        }
-        return ret;
-    }
-};
-}  // anonymous namespace
-
 /* ========================= BatchedDTypePromotion ========================= */
 intl::BatchedDTypePromotion::BatchedDTypePromotion(const VarNodeArrayView& vars)
         : m_orig_vars{vars} {
src/opr/impl/nn_int.cpp

 #include "megbrain/opr/nn_int.h"
 #include "./internal/megdnn_opr_wrapper.inl"
+#include "megbrain/opr/utility.h"
 #include "megdnn/oprs/general.h"

 using namespace mgb;
@@ -18,6 +18,7 @@ ElemwiseMultiType::ElemwiseMultiType(
     for (auto i : inputs) {
         add_input({i});
     }
+    output(0)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);
 }

 SymbolVar ElemwiseMultiType::make(
@@ -52,8 +53,13 @@ void ElemwiseMultiType::init_output_dtype() {
 void ElemwiseMultiType::scn_do_execute() {
     megdnn::TensorNDArray inp_arr(input().size());
     for (size_t i = 0; i < input().size(); ++i) {
+        if (input()[i]->dev_tensor().empty()) {
+            mgb_assert(output(0)->dev_tensor().empty());
+            return;
+        }
         inp_arr[i] = input()[i]->dev_tensor().as_megdnn();
     }
+    mgb_assert(!output(0)->dev_tensor().empty());
     megdnn_opr()->exec(inp_arr, output(0)->dev_tensor().as_megdnn());
 }
@@ -75,4 +81,120 @@ void ElemwiseMultiType::add_input_layout_constraint() {
 #endif
 }

+ElemwiseMultiType::NodeProp* ElemwiseMultiType::do_make_node_prop() const {
+    auto ret = Super::do_make_node_prop();
+    for (auto& inp : input()) {
+        ret->add_dep_type_existing_var(inp, NodeProp::DepType::VALUE_ALLOW_EMPTY);
+    }
+    return ret;
+}
+
+void ElemwiseMultiType::init_output_static_infer_desc() {
+    Super::init_output_static_infer_desc();
+    static StaticInferOpr<megdnn::ElemwiseMultiType> static_infer_opr;
+
+    using namespace cg::static_infer;
+
+    auto infer_value = [this](DeviceTensorND& dest, const InpVal& inp) {
+        SmallVector<DeviceTensorND> inp_vals(inp.val.size());
+        for (size_t i = 0; i < inp_vals.size(); ++i)
+            inp_vals[i] = inp.val[i].value();
+        DType out_dt;
+        auto trait = ModeTrait::from_mode(param().mode);
+        if (trait.need_specify_out_dtype) {
+            auto dtype = config().output_dtype();
+            mgb_assert(dtype.valid());
+            out_dt = dtype;
+        } else {
+            DType dtype;
+            trait.check_out(dtype, false);
+            out_dt = dtype;
+        }
+        auto sopr = static_infer_opr.lock();
+        perform(param().mode, out_dt, dest, inp_vals, sopr());
+        return true;
+    };
+
+    DepVal deps(input().size());
+    for (size_t i = 0; i < input().size(); ++i)
+        deps[i] = {input(i), DepType::VALUE};
+    owner_graph()->static_infer_manager().register_value_infer(
+            output(0), {SourceType::DEP, deps, infer_value});
+}
+
+TensorShape ElemwiseMultiType::get_output_var_shape(
+        Mode mode, const TensorShapeArray& input_shapes) {
+    mgb_assert(input_shapes.size() == ModeTrait::from_mode(mode).arity);
+    TensorShape ret;
+    megdnn::Elemwise::deduce_shape(input_shapes, ret);
+    return ret;
+}
+
+void ElemwiseMultiType::call_megdnn_opr_exec(
+        CompNode comp_node, megdnn::TensorNDArray& inp, const megdnn::TensorND& out,
+        megdnn::ElemwiseMultiType* opr, ElemwiseMultiType* caller) {
+    // All Elemwise operations on QuantizedS32/QuantizedS8 are not related to
+    // scale. MegDNN does not support computing Elemwise for
+    // QuantizedS32/QuantizedS8, we translate the data type to Int32/Int8 before
+    // passing to MegDNN.
+    if (inp.size() && inp[0].layout.dtype.category() == DTypeCategory::QUANTIZED) {
+        auto inp_dtype = inp[0].layout.dtype;
+        DType compute_dtype;
+        if (inp_dtype.enumv() == DTypeEnum::QuantizedS32) {
+            compute_dtype = dtype::Int32();
+        } else if (inp_dtype.enumv() == DTypeEnum::QuantizedS8) {
+            compute_dtype = dtype::Int8();
+        } else {
+            mgb_throw(
+                    MegBrainError, "Unsupported Quantized Elemwise Mode %s: %d on %s",
+                    inp[0].layout.dtype.name(), int(opr->param().mode),
+                    comp_node.to_string().c_str());
+        }
+
+        megdnn::TensorNDArray run_inp(inp);
+        for (size_t i = 0; i < inp.size(); i++) {
+            run_inp[i].layout.dtype = compute_dtype;
+        }
+        megdnn::TensorND run_out = out;
+        run_out.layout.dtype = compute_dtype;
+        opr->exec(run_inp, run_out);
+        return;
+    }
+
+    opr->exec(inp, out);
+}
+
+void ElemwiseMultiType::perform(
+        Mode mode, DType out_dt, DeviceTensorND& dest,
+        const SmallVector<DeviceTensorND>& inputs,
+        intl::UniqPtrWithCN<megdnn::ElemwiseMultiType>& opr) {
+    megdnn::TensorNDArray dnn_inputs(inputs.size());
+    TensorShapeArray inp_shapes(inputs.size());
+    CompNode out_cn;
+    for (size_t i = 0; i < inputs.size(); ++i) {
+        auto&& t = inputs[i];
+        if (!i) {
+            out_cn = t.comp_node();
+        } else {
+            mgb_assert(t.comp_node() == out_cn);
+        }
+        if (t.shape().is_empty()) {
+            mgb_assert(dest.empty());
+            return;
+        }
+        inp_shapes[i] = t.shape();
+    }
+    if (!opr) {
+        opr = intl::create_megdnn_opr<megdnn::ElemwiseMultiType>(out_cn);
+    } else {
+        mgb_assert(out_cn == opr.comp_node());
+    }
+    out_cn.activate();
+    for (size_t i = 0; i < inputs.size(); ++i)
+        dnn_inputs[i] = inputs[i].as_megdnn();
+    dest.comp_node(out_cn).dtype(out_dt).resize(get_output_var_shape(mode, inp_shapes));
+    opr->param() = {mode};
+    call_megdnn_opr_exec(out_cn, dnn_inputs, dest.as_megdnn(), opr.get(), nullptr);
+}
+
 // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
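Two behaviors fall out of these additions: scn_do_execute returns early when any input is empty (asserting that the output is empty too), and init_output_static_infer_desc registers value inference on output(0), so the graph can fold comparisons over statically known values. A minimal empty-input sketch (illustrative, assuming a build that includes this commit):

    import numpy as np
    import megengine as mge
    import megengine.functional as F

    empty = mge.Tensor(np.empty((0,), dtype="float32"))

    # Empty inputs now propagate to an empty bool output instead of raising.
    out = F.isnan(empty)
    print(out.shape, out.dtype)  # -> (0,) bool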
src/opr/include/megbrain/opr/nn_int.h

@@ -26,6 +26,14 @@ public:
             const VarNodeArrayView& inputs, Param param,
             const OperatorNodeConfig& config = {});

+    MGE_WIN_DECLSPEC_FUC static TensorShape get_output_var_shape(
+            Mode mode, const TensorShapeArray& input_shapes);
+
+    MGE_WIN_DECLSPEC_FUC static void perform(
+            Mode mode, DType out_dt, DeviceTensorND& dest,
+            const SmallVector<DeviceTensorND>& inputs,
+            intl::UniqPtrWithCN<megdnn::ElemwiseMultiType>& opr);
+
 private:
     using ModeTrait = megdnn::ElemwiseMultiType::ModeTrait;
@@ -40,6 +48,14 @@ private:
     void record_execute_deps(ExecDependencyArray& deps) override;
     void add_input_layout_constraint() override;
+    NodeProp* do_make_node_prop() const override;
+    void init_output_static_infer_desc() override;
+
+    static void call_megdnn_opr_exec(
+            CompNode comp_node, megdnn::TensorNDArray& inp,
+            const megdnn::TensorND& out, megdnn::ElemwiseMultiType* opr,
+            ElemwiseMultiType* caller);
 };

 //! deprecated; TODO: remove in megbrain 8
src/opr/include/megbrain/opr/utility.h

@@ -509,6 +509,49 @@ public:
     bool is_const() const { return m_is_const; }
 };

+//! global operator instance for static inference
+template <class Opr>
+class StaticInferOpr {
+    intl::UniqPtrWithCN<Opr> m_opr;
+    MGB_MUTEX m_mtx;
+
+public:
+    class Lock {
+        friend class StaticInferOpr;
+        StaticInferOpr* m_owner;
+
+        explicit Lock(StaticInferOpr* owner) : m_owner{owner} {
+#if !__DEPLOY_ON_XP_SP2__
+            m_owner->m_mtx.lock();
+#endif
+        }
+
+    public:
+        Lock(Lock&& rhs) : m_owner{rhs.m_owner} { rhs.m_owner = nullptr; }
+
+        ~Lock() {
+#if !__DEPLOY_ON_XP_SP2__
+            if (m_owner)
+                m_owner->m_mtx.unlock();
+#endif
+        }
+
+        Lock& operator=(const Lock&) = delete;
+        Lock& operator=(Lock&&) = delete;
+
+        intl::UniqPtrWithCN<Opr>& operator()() { return m_owner->m_opr; }
+    };
+
+    //! lock and acquire the operator
+    Lock lock() {
+        Lock ret{this};
+        if (!m_opr) {
+            m_opr = intl::create_megdnn_opr<Opr>(CompNode::default_cpu());
+        }
+        return ret;
+    }
+};
+
 }  // namespace opr
 }  // namespace mgb
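The class moved here (unchanged from basic_arith.cpp) is a lazily constructed, mutex-guarded singleton megdnn operator: lock() creates the operator on first use and hands back an RAII guard that releases the mutex on destruction. A rough Python analogue of the pattern, purely illustrative and not MegEngine API:

    import threading
    from contextlib import contextmanager

    class StaticInferOpr:
        """Share one lazily created operator under a mutex (cf. the C++ Lock)."""

        def __init__(self, factory):
            self._factory = factory  # stands in for intl::create_megdnn_opr<Opr>
            self._opr = None
            self._mtx = threading.Lock()

        @contextmanager
        def lock(self):
            with self._mtx:            # Lock() ctor locks, ~Lock() unlocks
                if self._opr is None:  # create on first use, as in lock()
                    self._opr = self._factory()
                yield self._opr        # operator()() returns the operator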