Commit 12178011
Authored on Oct 29, 2020 by Megvii Engine Team
perf(mge): add opdef for broadcast
GitOrigin-RevId: 92f0af29eb000b3e37f059e83fda52d26f21b383
Parent: fccb2510
Showing 8 changed files with 142 additions and 30 deletions (+142, -30).
Changed files:

  dnn/src/common/basic_types.cpp                                   +1   -1
  imperative/python/megengine/core/autodiff/builtin_op_utils.py    +3   -2
  imperative/python/megengine/core/tensor/tensor_wrapper.py        +0   -22
  imperative/python/src/ops.cpp                                    +4   -0
  imperative/python/test/unit/functional/test_tensor.py            +3   -3
  imperative/src/impl/ops/broadcast.cpp                            +95  -0
  imperative/src/include/megbrain/imperative/ops/broadcast.h       +35  -0
  imperative/src/include/megbrain/imperative/ops/nms.h             +1   -2
dnn/src/common/basic_types.cpp  (+1, -1)

@@ -413,7 +413,7 @@ TensorLayout TensorLayout::broadcast(const TensorShape& tshape) const {
         megdnn_throw_if(cur_shape != 1 && cur_stride != 0, tensor_reshape_error,
                         megdnn_mangle(ssprintf(
-                                "brodcast on dim with shape not equal to 1: "
+                                "broadcast on dim with shape not equal to 1: "
                                 "src_shape=%s dst_shape=%s",
                                 to_string().c_str(), tshape.to_string().c_str())));
         result.shape[target_idx] = tshape.shape[target_idx];
imperative/python/megengine/core/autodiff/builtin_op_utils.py  (+3, -2)

@@ -47,7 +47,9 @@ def _(op: OpDef, inputs, outputs, input_requires_grad):
             grad_fn = reduce_sum_grad_fn
         else:
             grad_fn = default_grad_fn
-    elif isinstance(op, Elemwise) and op.mode == Elemwise.Mode.ADD:
+    elif isinstance(op, Broadcast) or (
+        isinstance(op, Elemwise) and op.mode == Elemwise.Mode.ADD
+    ):
         grad_fn = elemwise_add_grad_fn
     else:
         grad_fn = default_grad_fn

@@ -212,5 +214,4 @@ _oprAttr_grad_fn = {
     Reshape.name: reshape_grad_fn,
     Subtensor.name: subtensor_grad_fn,
     IndexingMultiAxisVec.name: indexingMultiAxisVec_grad_fn,
-    Broadcast.name: elemwise_add_grad_fn,
 }
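Note on the change above: the new Broadcast OpDef is routed to the same gradient function as element-wise ADD. The backward of a broadcast is a sum-reduction of the output gradient back to the input shape, which is exactly what the ADD gradient already has to do for an implicitly broadcast operand. A minimal NumPy sketch of that reduction, illustrative only and not MegEngine's grad implementation:

# Illustrative sketch (NumPy): why Broadcast can share elemwise_add_grad_fn.
import numpy as np

def reduce_grad_to(grad, src_shape):
    """Sum `grad` over the axes that were expanded when broadcasting an
    array of shape `src_shape` up to `grad.shape`."""
    # sum over the leading axes that src_shape does not have
    while grad.ndim > len(src_shape):
        grad = grad.sum(axis=0)
    # sum over axes where the source had size 1 but the gradient does not
    for axis, size in enumerate(src_shape):
        if size == 1 and grad.shape[axis] != 1:
            grad = grad.sum(axis=axis, keepdims=True)
    return grad

x = np.ones((2, 1, 3))
dy = np.ones((4, 2, 5, 3))          # gradient w.r.t. broadcast_to(x, (4, 2, 5, 3))
assert reduce_grad_to(dy, x.shape).shape == x.shape   # back to (2, 1, 3)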
imperative/python/megengine/core/tensor/tensor_wrapper.py  (+0, -22)

@@ -59,29 +59,7 @@ def _transpose(data, axes):
 def _broadcast(inp, shape):
-    def valid_broadcast(src, tar):
-        def failed():
-            raise ValueError(
-                "the input shape {} can not be broadcasted to target shape {}".format(
-                    src, tar
-                )
-            )
-
-        if isinstance(src, (TensorBase, TensorWrapperBase)):
-            src = src.numpy()
-        if isinstance(tar, (TensorBase, TensorWrapperBase)):
-            tar = tar.numpy()
-
-        if len(src) > len(tar):
-            failed()
-
-        for i in range(min(len(src), len(tar))):
-            if src[-i - 1] != 1 and src[-i - 1] != tar[-i - 1]:
-                failed()
-
     shape = utils.astensor1d(shape, inp, dtype="int32", device=inp.device)
-    valid_broadcast(inp.shape, shape)
     (result,) = apply(builtin.Broadcast(), inp, shape)
     return result
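With the Python-level helper deleted, shape validation for _broadcast now happens inside the new C++ OpDef (see valid_broadcast in imperative/src/impl/ops/broadcast.cpp below). For reference, a pure-Python sketch of the rule the deleted helper enforced and the C++ code re-implements, assuming shapes are plain tuples:

# Sketch of the broadcast-compatibility rule, assuming plain tuples; the
# authoritative check is now valid_broadcast() in broadcast.cpp.
def can_broadcast(src, tar):
    if len(src) > len(tar):
        return False
    # compare dimensions from the right: each source dim must be 1 or equal
    return all(s == 1 or s == t for s, t in zip(reversed(src), reversed(tar)))

assert can_broadcast((2, 1, 3), (4, 2, 5, 3))
assert not can_broadcast((2, 1, 3), (2, 3, 4))   # trailing 3 cannot become 4
assert not can_broadcast((2, 1, 3), (1, 3))      # fewer target dims than source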
imperative/python/src/ops.cpp  (+4, -0)

@@ -21,6 +21,7 @@
 #include "megbrain/imperative/ops/nms.h"
 #include "megbrain/imperative/ops/elemwise.h"
 #include "megbrain/imperative/ops/batch_norm.h"
+#include "megbrain/imperative/ops/broadcast.h"

 namespace py = pybind11;

@@ -206,4 +207,7 @@ void init_ops(py::module m) {
     V(INFERENCE);
 #undef V

+    py::class_<Broadcast, std::shared_ptr<Broadcast>, OpDef>(m, "Broadcast")
+        .def(py::init<>());
 }
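After this registration the op is constructible from Python as a stateless OpDef; tensor_wrapper.py's _broadcast applies it via apply(builtin.Broadcast(), inp, shape). A minimal sketch, assuming the module path megengine.core.ops.builtin used elsewhere in this tree (other MegEngine versions may differ):

# Sketch only; module path and OpDef export are assumed from this tree's imports.
from megengine.core.ops import builtin

op = builtin.Broadcast()               # no-argument constructor bound via py::init<>()
assert isinstance(op, builtin.OpDef)   # assuming OpDef is exported by the same module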
imperative/python/test/unit/functional/test_tensor.py  (+3, -3)

@@ -262,13 +262,13 @@ def test_broadcast():
     opr_test(cases, F.broadcast_to, compare_fn=compare_fn)

     x = F.ones((2, 1, 3))
-    with pytest.raises(ValueError):
+    with pytest.raises(RuntimeError):
         F.broadcast_to(x, (2, 3, 4))

-    with pytest.raises(ValueError):
+    with pytest.raises(RuntimeError):
         F.broadcast_to(x, (4, 1, 3))

-    with pytest.raises(ValueError):
+    with pytest.raises(RuntimeError):
         F.broadcast_to(x, (1, 3))
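User-visible effect of the test change: invalid target shapes are now rejected by the C++ shape check and surface as RuntimeError, instead of the ValueError that the deleted Python-side helper used to raise. A short usage sketch based on the test above:

import megengine.functional as F

x = F.ones((2, 1, 3))
y = F.broadcast_to(x, (2, 5, 3))    # valid: the size-1 axis is expanded
print(y.shape)                      # (2, 5, 3)

try:
    F.broadcast_to(x, (2, 3, 4))    # invalid: trailing dim 3 cannot become 4
except RuntimeError as exc:
    print("rejected:", exc)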
imperative/src/impl/ops/broadcast.cpp  (new file, +95)

/**
 * \file imperative/src/impl/ops/broadcast.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#include "megbrain/imperative/ops/broadcast.h"
#include "../op_trait.h"

namespace mgb {
namespace imperative {

namespace {

std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) {
    node_->cast_final_safe<opr::Broadcast>();
    return Broadcast::make();
}

cg::OperatorNodeBase* apply_on_var_node(
        const OpDef& def,
        const VarNodeArray& inputs) {
    def.cast_final_safe<Broadcast>();
    size_t nr_inp = inputs.size();
    mgb_assert(nr_inp == 2, "Broadcast expects 2 inputs; got %lu actually", nr_inp);
    return opr::Broadcast::make(inputs[0], inputs[1]).node()->owner_opr();
}

bool valid_broadcast(const TensorShape& src_shape,
                     const TensorShape& tar_shape) {
    size_t src_ndim = src_shape.ndim, tar_ndim = tar_shape.ndim;
    if (src_ndim > tar_ndim) {
        return false;
    }
    size_t min_ndim = src_ndim < tar_ndim ? src_ndim : tar_ndim;
    for (size_t i = 0; i < min_ndim; ++i) {
        if (src_shape[src_ndim - i - 1] != 1 &&
            src_shape[src_ndim - i - 1] != tar_shape[tar_ndim - i - 1]) {
            return false;
        }
    }
    return true;
}

SmallVector<LogicalTensorDesc> infer_output_attrs_fallible(
        const OpDef& def,
        const SmallVector<LogicalTensorDesc>& inputs) {
    def.cast_final_safe<Broadcast>();
    size_t nr_inp = inputs.size();
    mgb_assert(nr_inp == 2, "Broadcast expects 2 inputs; got %lu actually", nr_inp);
    auto&& src = inputs[0];
    auto&& tshp = inputs[1];

    TensorLayout out_layout = src.layout;

    if (tshp.layout.ndim == 0 || tshp.value.empty()) {
        out_layout.ndim = 0;
        return {{out_layout, src.comp_node}};
    }
    mgb_assert(
        tshp.layout.ndim == 1,
        "target shape of Broadcast expects ndim=1; got ndim=%lu actually",
        tshp.layout.ndim);

    size_t target_ndim = tshp.layout.shape[0];
    out_layout.ndim = target_ndim;
    auto* ptr = tshp.value.ptr<dt_int32>();
    for (size_t i = 0; i < target_ndim; ++i) {
        out_layout.shape[i] = ptr[i];
    }
    mgb_assert(valid_broadcast(src.layout, out_layout),
               "the input shape %s can not be broadcasted to target shape %s",
               src.layout.TensorShape::to_string().c_str(),
               out_layout.TensorShape::to_string().c_str());

    return {{out_layout, src.comp_node}};
}

OP_TRAIT_REG(Broadcast, Broadcast, opr::Broadcast)
    .make_from_op_node(make_from_op_node)
    .apply_on_var_node(apply_on_var_node)
    .infer_output_attrs_fallible(infer_output_attrs_fallible)
    .fallback();

} // anonymous namespace

MGB_DYN_TYPE_OBJ_FINAL_IMPL(Broadcast);

} // namespace imperative
} // namespace mgb

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
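The "fallible" part of infer_output_attrs_fallible matters when the value of the target-shape tensor is not yet known (tshp.value is empty): the function then returns a layout with ndim = 0, i.e. an unknown shape, and the .fallback() registered in OP_TRAIT_REG takes over. A small pure-Python sketch of that control flow, using tuples and None in place of TensorLayout, illustrative only:

# Illustrative sketch of the shape-inference logic above, not MegEngine code.
def infer_broadcast_shape(src_shape, target_values):
    """target_values: the ints read from the 1-D int32 shape tensor, or None
    when the value is not known yet (tshp.value.empty() in the C++ code)."""
    if target_values is None:
        return None                  # ndim = 0 in the C++ code: shape unknown, fall back
    out = tuple(target_values)
    ok = len(src_shape) <= len(out) and all(
        s == 1 or s == t for s, t in zip(reversed(src_shape), reversed(out))
    )
    if not ok:
        raise RuntimeError(
            "the input shape %s can not be broadcasted to target shape %s"
            % (src_shape, out)
        )
    return out

assert infer_broadcast_shape((2, 1, 3), [2, 5, 3]) == (2, 5, 3)
assert infer_broadcast_shape((2, 1, 3), None) is None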
imperative/src/include/megbrain/imperative/ops/broadcast.h  (new file, +35)

/**
 * \file imperative/src/include/megbrain/imperative/ops/broadcast.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#pragma once

#include "megbrain/opr/tensor_manip.h"

#include "megbrain/imperative/ops/opr_attr.h"
#include "megbrain/imperative/op_def.h"

namespace mgb::imperative {

class Broadcast : public OpDefImplBase<Broadcast> {
    MGB_DYN_TYPE_OBJ_FINAL_DECL;

public:
    Broadcast() = default;

    size_t hash() const override {
        return reinterpret_cast<std::uintptr_t>(dyn_typeinfo());
    }

    bool is_same_st(const Hashable& rhs) const override {
        return true;
    }
};

} // namespace mgb::imperative
imperative/src/include/megbrain/imperative/ops/nms.h  (+1, -2)

@@ -32,8 +32,7 @@ public:
     bool is_same_st(const Hashable& rhs_) const override {
         auto&& rhs = static_cast<const NMSKeep&>(rhs_);
-        return rhs.dyn_typeinfo() == dyn_typeinfo()
-            && rhs.iou_thresh == iou_thresh
+        return rhs.iou_thresh == iou_thresh
             && rhs.max_output == max_output;
     }