Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
d3b2b519
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
d3b2b519
编写于
9月 09, 2022
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(mge/opr): add meshgrid opr
GitOrigin-RevId: 6f703295be2c33974ba27edd3161744251eea570
上级
ec234135
变更
8
显示空白变更内容
内联
并排
Showing
8 changed file
with
320 addition
and
11 deletion
+320
-11
imperative/python/megengine/functional/tensor.py
imperative/python/megengine/functional/tensor.py
+48
-1
imperative/src/impl/ops/broadcast.cpp
imperative/src/impl/ops/broadcast.cpp
+119
-5
imperative/tablegen/generated/hash.txt
imperative/tablegen/generated/hash.txt
+5
-5
imperative/tablegen/generated/opdef.cpp.inl
imperative/tablegen/generated/opdef.cpp.inl
+37
-0
imperative/tablegen/generated/opdef.cpy.inl
imperative/tablegen/generated/opdef.cpy.inl
+90
-0
imperative/tablegen/generated/opdef.h.inl
imperative/tablegen/generated/opdef.h.inl
+9
-0
imperative/tablegen/generated/opdef.py.inl
imperative/tablegen/generated/opdef.py.inl
+7
-0
src/core/include/megbrain/ir/ops.td
src/core/include/megbrain/ir/ops.td
+5
-0
未找到文件。
imperative/python/megengine/functional/tensor.py
浏览文件 @
d3b2b519
# -*- coding: utf-8 -*-
# -*- coding: utf-8 -*-
from
functools
import
lru_cache
from
functools
import
lru_cache
from
typing
import
Iterable
,
Optional
,
Sequence
,
Tuple
,
Union
from
typing
import
Iterable
,
List
,
Optional
,
Sequence
,
Tuple
,
Union
import
numpy
as
np
import
numpy
as
np
...
@@ -36,6 +36,7 @@ __all__ = [
...
@@ -36,6 +36,7 @@ __all__ = [
"full_like"
,
"full_like"
,
"gather"
,
"gather"
,
"linspace"
,
"linspace"
,
"meshgrid"
,
"ones"
,
"ones"
,
"ones_like"
,
"ones_like"
,
"repeat"
,
"repeat"
,
...
@@ -1205,3 +1206,49 @@ def cumsum(inp: Tensor, axis: int):
...
@@ -1205,3 +1206,49 @@ def cumsum(inp: Tensor, axis: int):
assert
isinstance
(
inp
,
Tensor
),
"input of cumsum must be type of Tensor"
assert
isinstance
(
inp
,
Tensor
),
"input of cumsum must be type of Tensor"
op
=
builtin
.
Cumsum
(
axis
=
axis
,
exclusive
=
False
,
reverse
=
False
)
op
=
builtin
.
Cumsum
(
axis
=
axis
,
exclusive
=
False
,
reverse
=
False
)
return
apply
(
op
,
inp
)[
0
]
return
apply
(
op
,
inp
)[
0
]
def
meshgrid
(
*
inputs
:
Tensor
,
indexing
:
str
=
"xy"
)
->
List
[
Tensor
]:
r
"""Returns coordinate matrices from coordinate vectors.
Args:
inputs: an arbitrary number of one-dimensional tensors representing grid
coordinates. Each input should have the same numeric data type.
indexing: Cartesian ``'xy'`` or matrix ``'ij'`` indexing of output.
If provided zero or one one-dimensional vector(s) (i.e., the zero- and one-dimensional
cases, respectively), the indexing keyword has no effect and should be ignored.
Returns:
out: list of N tensors, where N is the number of provided one-dimensional input tensors.
Each returned tensor must have rank N. For N one-dimensional tensors having lengths ``Ni = len(xi)``,
* if matrix indexing ``ij``, then each returned tensor must have the shape ``(N1, N2, N3, ..., Nn)``.
* if Cartesian indexing ``xy``, then each returned tensor must have shape ``(N2, N1, N3, ..., Nn)``.
Accordingly, for the two-dimensional case with input one-dimensional tensors of length ``M`` and ``N``,
if matrix indexing ``ij``, then each returned tensor must have shape ``(M, N)``, and, if Cartesian indexing ``xy``,
then each returned tensor must have shape ``(N, M)``.
Similarly, for the three-dimensional case with input one-dimensional tensor of length ``M``, ``N``, and ``P``,
if matrix indexing ``ij``, then each returned tensor must have shape ``(M, N, P)``, and, if Cartesian indexing ``xy``,
then each returned tensor must have shape ``(N, M, P)``.
Each returned tensor should have the same data type as the input tensors.
Examples:
>>> nx, ny = (3, 2)
>>> x = F.linspace(0, 1, nx)
>>> y = F.linspace(0, 1, ny)
>>> xv, yv = F.meshgrid(x, y)
>>> xv
Tensor([[0. 0.5 1. ]
[0. 0.5 1. ]], device=xpux:0)
>>> yv
Tensor([[0. 0. 0.]
[1. 1. 1.]], device=xpux:0)
"""
op
=
builtin
.
MeshGrid
(
indexing
)
return
apply
(
op
,
*
inputs
)
imperative/src/impl/ops/broadcast.cpp
浏览文件 @
d3b2b519
#include <numeric>
#include "megbrain/graph/helper.h"
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/opr/io.h"
#include "megbrain/opr/tensor_manip.h"
#include "megbrain/opr/tensor_manip.h"
#include "megbrain/graph/helper.h"
#include "../op_trait.h"
#include "../op_trait.h"
namespace
mgb
{
namespace
mgb
{
namespace
imperative
{
namespace
imperative
{
namespace
meshgrid
{
SmallVector
<
VarNode
::
LayoutConstraintCallback
>
get_input_layout_constraint
(
const
OpDef
&
def
,
const
SmallVector
<
TensorPtr
>&
inputs
)
{
return
SmallVector
<
VarNode
::
LayoutConstraintCallback
>
(
inputs
.
size
());
}
std
::
tuple
<
SmallVector
<
LogicalTensorDesc
>
,
bool
>
infer_output_attrs_fallible
(
const
OpDef
&
def
,
const
SmallVector
<
LogicalTensorDesc
>&
inputs
)
{
for
(
size_t
i
=
0
;
i
<
inputs
.
size
()
-
1
;
i
++
)
{
mgb_assert
(
inputs
[
i
].
layout
.
dtype
==
inputs
[
i
+
1
].
layout
.
dtype
);
mgb_assert
(
inputs
[
i
].
comp_node
==
inputs
[
i
+
1
].
comp_node
);
}
auto
&&
op
=
def
.
cast_final_safe
<
MeshGrid
>
();
mgb_assert
(
op
.
indexing
==
"xy"
||
op
.
indexing
==
"ij"
);
bool
success
=
true
;
SmallVector
<
size_t
>
shp
;
for
(
size_t
i
=
0
;
i
<
inputs
.
size
();
i
++
)
{
mgb_assert
(
inputs
[
i
].
layout
.
ndim
<=
1
);
if
(
inputs
[
i
].
layout
.
ndim
==
0
)
{
success
=
false
;
}
shp
.
push_back
(
inputs
[
i
].
layout
.
total_nr_elems
());
}
if
(
op
.
indexing
==
"xy"
and
shp
.
size
()
>=
2
)
{
std
::
swap
(
shp
[
0
],
shp
[
1
]);
}
TensorShape
tshp
(
shp
);
SmallVector
<
LogicalTensorDesc
>
descs
;
for
(
size_t
i
=
0
;
i
<
inputs
.
size
();
i
++
)
{
if
(
success
)
{
descs
.
push_back
(
{
TensorLayout
(
tshp
,
inputs
[
0
].
layout
.
dtype
),
inputs
[
0
].
comp_node
});
}
else
{
descs
.
push_back
(
{
TensorLayout
(
inputs
[
0
].
layout
.
dtype
),
inputs
[
0
].
comp_node
});
}
}
return
{
descs
,
success
};
}
VarNodeArray
apply_on_var_node
(
const
OpDef
&
def
,
const
VarNodeArray
&
inputs
)
{
auto
&&
op
=
def
.
cast_final_safe
<
MeshGrid
>
();
std
::
vector
<
size_t
>
indexs
(
inputs
.
size
());
std
::
iota
(
indexs
.
begin
(),
indexs
.
end
(),
0
);
auto
cn
=
inputs
[
0
]
->
comp_node
();
auto
graph
=
inputs
[
0
]
->
owner_graph
();
if
(
op
.
indexing
==
"xy"
)
{
if
(
indexs
.
size
()
>=
2
)
{
std
::
swap
(
indexs
[
0
],
indexs
[
1
]);
}
}
else
{
mgb_assert
(
op
.
indexing
==
"ij"
,
"meshgrid only support
\"
ij
\"
or
\"
xy
\"
"
);
}
VarNodeArray
shps
;
for
(
size_t
ind
=
0
;
ind
<
inputs
.
size
();
ind
++
)
{
auto
&&
inp
=
inputs
[
indexs
[
ind
]];
shps
.
push_back
(
opr
::
GetVarShape
::
make
(
inp
).
node
());
}
VarNode
*
tshp
=
opr
::
Concat
::
make
(
shps
,
0
,
cn
).
node
();
VarNodeArray
results
;
auto
t_ndim
=
inputs
.
size
();
for
(
size_t
ind
=
0
;
ind
<
inputs
.
size
();
ind
++
)
{
auto
axis
=
indexs
[
ind
];
HostTensorND
hv
=
HostTensorND
(
cn
,
{
t_ndim
},
dtype
::
Int32
());
auto
*
ptr
=
hv
.
ptr
<
dt_int32
>
();
std
::
fill_n
(
ptr
,
t_ndim
,
1
);
ptr
[
axis
]
=
-
1
;
auto
shp
=
opr
::
ImmutableTensor
::
make
(
*
graph
,
hv
,
cn
).
node
();
auto
tmp
=
opr
::
Reshape
::
make
(
inputs
[
ind
],
shp
,
axis
).
node
();
results
.
push_back
(
opr
::
Broadcast
::
make
(
tmp
,
tshp
).
node
());
}
return
results
;
}
SmallVector
<
TensorPtr
>
apply_on_physical_tensor
(
const
OpDef
&
def
,
const
SmallVector
<
TensorPtr
>&
inputs
,
SmallVector
<
LogicalTensorDesc
>&
output_descs
,
const
bool
&
validated
)
{
auto
&&
op
=
def
.
cast_final_safe
<
MeshGrid
>
();
TensorShape
tshp
;
TensorShape
view_shp
;
tshp
.
ndim
=
inputs
.
size
();
view_shp
.
ndim
=
inputs
.
size
();
std
::
vector
<
size_t
>
indexs
(
inputs
.
size
());
std
::
iota
(
indexs
.
begin
(),
indexs
.
end
(),
0
);
if
(
op
.
indexing
==
"xy"
)
{
if
(
indexs
.
size
()
>=
2
)
{
std
::
swap
(
indexs
[
0
],
indexs
[
1
]);
}
}
else
{
mgb_assert
(
op
.
indexing
==
"ij"
,
"meshgrid only support
\"
ij
\"
or
\"
xy
\"
"
);
}
for
(
size_t
ind
=
0
;
ind
<
inputs
.
size
();
ind
++
)
{
auto
&&
inp
=
inputs
[
indexs
[
ind
]];
mgb_assert
(
inp
->
layout
().
ndim
<=
1
);
tshp
[
ind
]
=
inp
->
layout
().
total_nr_elems
();
view_shp
[
ind
]
=
1
;
}
SmallVector
<
TensorPtr
>
grids
;
for
(
size_t
i
=
0
;
i
<
inputs
.
size
();
i
++
)
{
auto
&&
src
=
inputs
[
i
];
TensorLayout
layout
;
view_shp
[
indexs
[
i
]]
=
src
->
layout
().
total_nr_elems
();
mgb_assert
(
src
->
layout
().
try_reshape
(
layout
,
view_shp
));
layout
=
layout
.
broadcast
(
tshp
);
view_shp
[
indexs
[
i
]]
=
1
;
grids
.
push_back
(
Tensor
::
make
(
src
->
blob
(),
src
->
offset
(),
layout
));
}
return
grids
;
}
OP_TRAIT_REG
(
MeshGrid
,
MeshGrid
)
.
apply_on_var_node
(
apply_on_var_node
)
.
infer_output_attrs_fallible
(
infer_output_attrs_fallible
)
.
apply_on_physical_tensor
(
apply_on_physical_tensor
)
.
get_input_layout_constraint
(
get_input_layout_constraint
)
.
fallback
();
}
// namespace meshgrid
namespace
broadcast
{
namespace
broadcast
{
std
::
shared_ptr
<
OpDef
>
make_from_op_node
(
cg
::
OperatorNodeBase
*
node_
)
{
std
::
shared_ptr
<
OpDef
>
make_from_op_node
(
cg
::
OperatorNodeBase
*
node_
)
{
...
@@ -211,7 +327,6 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
...
@@ -211,7 +327,6 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
tshp
,
tshp_nd
->
get_value
().
proxy_to_default_cpu
());
tshp
,
tshp_nd
->
get_value
().
proxy_to_default_cpu
());
}
}
if
(
op
.
axis
!=
opr
::
Reshape
::
Param
::
INVALID_AXIS
)
{
if
(
op
.
axis
!=
opr
::
Reshape
::
Param
::
INVALID_AXIS
)
{
mgb_assert
(
tshp
[
op
.
axis
]
==
-
1
);
tshp
[
op
.
axis
]
=
1
;
tshp
[
op
.
axis
]
=
1
;
tshp
[
op
.
axis
]
=
src
->
layout
().
total_nr_elems
()
/
tshp
.
total_nr_elems
();
tshp
[
op
.
axis
]
=
src
->
layout
().
total_nr_elems
()
/
tshp
.
total_nr_elems
();
}
}
...
@@ -237,7 +352,6 @@ SmallVector<VarNode::LayoutConstraintCallback> get_input_layout_constraint(
...
@@ -237,7 +352,6 @@ SmallVector<VarNode::LayoutConstraintCallback> get_input_layout_constraint(
tshp
,
inputs
[
1
]
->
get_value
().
proxy_to_default_cpu
());
tshp
,
inputs
[
1
]
->
get_value
().
proxy_to_default_cpu
());
}
}
if
(
op
.
axis
!=
opr
::
Reshape
::
Param
::
INVALID_AXIS
)
{
if
(
op
.
axis
!=
opr
::
Reshape
::
Param
::
INVALID_AXIS
)
{
mgb_assert
(
tshp
[
op
.
axis
]
==
-
1
);
tshp
[
op
.
axis
]
=
1
;
tshp
[
op
.
axis
]
=
1
;
tshp
[
op
.
axis
]
=
layout
.
total_nr_elems
()
/
tshp
.
total_nr_elems
();
tshp
[
op
.
axis
]
=
layout
.
total_nr_elems
()
/
tshp
.
total_nr_elems
();
}
}
...
@@ -250,7 +364,7 @@ SmallVector<VarNode::LayoutConstraintCallback> get_input_layout_constraint(
...
@@ -250,7 +364,7 @@ SmallVector<VarNode::LayoutConstraintCallback> get_input_layout_constraint(
return
layout_checker
;
return
layout_checker
;
}
}
OP_TRAIT_REG
(
Reshape
,
Reshape
)
OP_TRAIT_REG
(
Reshape
,
Reshape
,
opr
::
Reshape
)
.
apply_on_var_node
(
apply_on_var_node
)
.
apply_on_var_node
(
apply_on_var_node
)
.
infer_output_attrs_fallible
(
infer_output_attrs_fallible
)
.
infer_output_attrs_fallible
(
infer_output_attrs_fallible
)
.
apply_on_physical_tensor
(
apply_on_physical_tensor
)
.
apply_on_physical_tensor
(
apply_on_physical_tensor
)
...
...
imperative/tablegen/generated/hash.txt
浏览文件 @
d3b2b519
905bdf78e5413b06873be64b4ba55db9 ../../dnn/scripts/opr_param_defs.py
905bdf78e5413b06873be64b4ba55db9 ../../dnn/scripts/opr_param_defs.py
e35e13523f43b7bea4034a0bf75937b7
../../src/core/include/megbrain/ir/ops.td
40708c56b1f05fdb7d06cc097a300330
../../src/core/include/megbrain/ir/ops.td
240dccd6f8d42cadfd08c6ca90fe61b1
generated/opdef.h.inl
9f3af118c7fe8d0c9db433825d5ad77b
generated/opdef.h.inl
a79a4058ff18ffd9593ee5db3deef6c4
generated/opdef.cpp.inl
4041e44a8ba3cca3b3affa1ed9ed44a2
generated/opdef.cpp.inl
83c179ee7416824fbfab978a097cd4d3
generated/opdef.py.inl
319e1d170c989fe793a4e9c45decefc4
generated/opdef.py.inl
86f70b1052331130f5e4c0ca53e68423
generated/opdef.cpy.inl
26a18a7593566128ecce76e8f74dcc5d
generated/opdef.cpy.inl
71e1462bf4d882e2615c3c632cb671cc generated/enum_macro.h
71e1462bf4d882e2615c3c632cb671cc generated/enum_macro.h
imperative/tablegen/generated/opdef.cpp.inl
浏览文件 @
d3b2b519
...
@@ -4672,6 +4672,43 @@ OP_TRAIT_REG(MatrixMul, MatrixMul)
...
@@ -4672,6 +4672,43 @@ OP_TRAIT_REG(MatrixMul, MatrixMul)
.props(MatrixMul_props_impl)
.props(MatrixMul_props_impl)
.make_name(MatrixMul_make_name_impl);
.make_name(MatrixMul_make_name_impl);
MGB_DYN_TYPE_OBJ_FINAL_IMPL(MeshGrid);
namespace {
size_t MeshGrid_hash_impl(const OpDef& def_) {
auto&& op_ = def_.cast_final_safe<MeshGrid>();
static_cast<void>(op_);
size_t val = mgb::hash(op_.dyn_typeinfo());
val = mgb::hash_pair_combine(val, mgb::hash(op_.indexing));
return val;
}
bool MeshGrid_is_same_st_impl(const OpDef& lhs_, const OpDef& rhs_) {
auto &&a_ = lhs_.cast_final_safe<MeshGrid>(),
&&b_ = rhs_.cast_final_safe<MeshGrid>();
static_cast<void>(a_);
static_cast<void>(b_);
if (a_.indexing != b_.indexing) return false;
return true;
}
std::vector<std::pair<const char*, std::string>> MeshGrid_props_impl(const OpDef& def_) {
auto&& op_ = def_.cast_final_safe<MeshGrid>();
static_cast<void>(op_);
std::vector<std::pair<const char*, std::string>> props_;
props_.emplace_back("indexing", op_.indexing);
return props_;
}
std::string MeshGrid_make_name_impl(const OpDef& def_) {
auto&& op_ = def_.cast_final_safe<MeshGrid>();
static_cast<void>(op_);
return "MeshGrid";
}
} // anonymous namespace
OP_TRAIT_REG(MeshGrid, MeshGrid)
.hash(MeshGrid_hash_impl)
.is_same_st(MeshGrid_is_same_st_impl)
.props(MeshGrid_props_impl)
.make_name(MeshGrid_make_name_impl);
MGB_DYN_TYPE_OBJ_FINAL_IMPL(MeshIndexing);
MGB_DYN_TYPE_OBJ_FINAL_IMPL(MeshIndexing);
namespace {
namespace {
...
...
imperative/tablegen/generated/opdef.cpy.inl
浏览文件 @
d3b2b519
...
@@ -12467,6 +12467,95 @@ void _init_py_MatrixMul(py::module m) {
...
@@ -12467,6 +12467,95 @@ void _init_py_MatrixMul(py::module m) {
mgb_assert(PyOp(OpDef)::ctype2pytype.emplace(MatrixMul::typeinfo(), &py_type).second);
mgb_assert(PyOp(OpDef)::ctype2pytype.emplace(MatrixMul::typeinfo(), &py_type).second);
}
}
PyOpDefBegin(MeshGrid) // {
static PyGetSetDef py_getsetters[];
static PyMethodDef tp_methods[];
static PyObject* getstate(PyObject* self, PyObject*) {
auto& opdef = reinterpret_cast<PyOp(MeshGrid)*>(self)->inst();
static_cast<void>(opdef);
std::unordered_map<std::string, py::object> state {
{"indexing", serialization<decltype(opdef.indexing)>::dump(opdef.indexing)}
};
return py::cast(state).release().ptr();
}
static PyObject* setstate(PyObject* self, PyObject* args) {
PyObject* dict = PyTuple_GetItem(args, 0);
if (!dict) return NULL;
auto state = py::cast<std::unordered_map<std::string, py::object>>(dict);
auto& opdef = reinterpret_cast<PyOp(MeshGrid)*>(self)->inst();
static_cast<void>(opdef);
{
auto&& iter = state.find("indexing");
if (iter != state.end()) {
opdef.indexing = serialization<decltype(opdef.indexing)>::load(iter->second);
}
}
Py_RETURN_NONE;
}
static int py_init(PyObject *self, PyObject *args, PyObject *kwds);
// };
PyOpDefEnd(MeshGrid)
int PyOp(MeshGrid)::py_init(PyObject *self, PyObject *args, PyObject *kwds) {
static const char* kwlist[] = {"indexing", "scope", NULL};
PyObject *indexing = NULL, *scope = NULL;
if (!PyArg_ParseTupleAndKeywords(args, kwds, "|OO", const_cast<char**>(kwlist), &indexing, &scope))
return -1;
if (indexing) {
try {
// TODO: remove this guard which is used for pybind11 implicit conversion
py::detail::loader_life_support guard{};
reinterpret_cast<PyOp(MeshGrid)*>(self)->inst().indexing =
py::cast<decltype(MeshGrid::indexing)>(py::handle(indexing));
} CATCH_ALL(-1)
}
if (scope) {
try {
reinterpret_cast<PyOp(OpDef)*>(self)->op
->set_scope(py::cast<std::string>(py::handle(scope)));
} CATCH_ALL(-1)
}
return 0;
}
PyGetSetDef PyOp(MeshGrid)::py_getsetters[] = {
{const_cast<char*>("indexing"), py_get_generic(MeshGrid, indexing), py_set_generic(MeshGrid, indexing), const_cast<char*>("indexing"), NULL},
{NULL} /* Sentinel */
};
PyMethodDef PyOp(MeshGrid)::tp_methods[] = {
{const_cast<char*>("__getstate__"), PyOp(MeshGrid)::getstate, METH_NOARGS, "MeshGrid getstate"},
{const_cast<char*>("__setstate__"), PyOp(MeshGrid)::setstate, METH_VARARGS, "MeshGrid setstate"},
{NULL} /* Sentinel */
};
void _init_py_MeshGrid(py::module m) {
using py_op = PyOp(MeshGrid);
auto& py_type = PyOpType(MeshGrid);
py_type = {PyVarObject_HEAD_INIT(NULL, 0)};
py_type.tp_name = "megengine.core._imperative_rt.ops.MeshGrid";
py_type.tp_basicsize = sizeof(PyOp(MeshGrid));
py_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE;
py_type.tp_doc = "MeshGrid";
py_type.tp_base = &PyOpType(OpDef);
py_type.tp_dealloc = py_dealloc_generic<py_op>;
py_type.tp_new = py_new_generic<py_op>;
py_type.tp_init = py_op::py_init;
py_type.tp_methods = py_op::tp_methods;
py_type.tp_getset = py_op::py_getsetters;
mgb_assert(PyType_Ready(&py_type) >= 0);
PyType_Modified(&py_type);
m.add_object("MeshGrid", reinterpret_cast<PyObject*>(&py_type));
mgb_assert(PyOp(OpDef)::ctype2pytype.emplace(MeshGrid::typeinfo(), &py_type).second);
}
PyOpDefBegin(MeshIndexing) // {
PyOpDefBegin(MeshIndexing) // {
static PyGetSetDef py_getsetters[];
static PyGetSetDef py_getsetters[];
static PyMethodDef tp_methods[];
static PyMethodDef tp_methods[];
...
@@ -18594,6 +18683,7 @@ void _init_py_WarpPerspectiveBackwardMat(py::module m) {
...
@@ -18594,6 +18683,7 @@ void _init_py_WarpPerspectiveBackwardMat(py::module m) {
_init_py_MagicMindRuntime(m); \
_init_py_MagicMindRuntime(m); \
_init_py_MatrixInverse(m); \
_init_py_MatrixInverse(m); \
_init_py_MatrixMul(m); \
_init_py_MatrixMul(m); \
_init_py_MeshGrid(m); \
_init_py_MeshIndexing(m); \
_init_py_MeshIndexing(m); \
_init_py_NMSKeep(m); \
_init_py_NMSKeep(m); \
_init_py_NvOf(m); \
_init_py_NvOf(m); \
...
...
imperative/tablegen/generated/opdef.h.inl
浏览文件 @
d3b2b519
...
@@ -1262,6 +1262,15 @@ public:
...
@@ -1262,6 +1262,15 @@ public:
}
}
};
};
class MeshGrid : public OpDefImplBase<MeshGrid> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
std::string indexing;
MeshGrid() = default;
MeshGrid(std::string indexing_, std::string scope_ = {}): indexing(indexing_) { set_scope(scope_); }
};
class MeshIndexing : public OpDefImplBase<MeshIndexing> {
class MeshIndexing : public OpDefImplBase<MeshIndexing> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
MGB_DYN_TYPE_OBJ_FINAL_DECL;
...
...
imperative/tablegen/generated/opdef.py.inl
浏览文件 @
d3b2b519
...
@@ -1365,6 +1365,13 @@ MatrixMulInst
...
@@ -1365,6 +1365,13 @@ MatrixMulInst
.def_readwrite("dimA", &MatrixMul::dimA)
.def_readwrite("dimA", &MatrixMul::dimA)
.def_readwrite("dimB", &MatrixMul::dimB);
.def_readwrite("dimB", &MatrixMul::dimB);
py::class_<MeshGrid, std::shared_ptr<MeshGrid>, OpDef> MeshGridInst(m, "MeshGrid");
MeshGridInst
.def(py::init<std::string, std::string>(), py::arg("indexing"), py::arg("scope") = {})
.def(py::init<>())
.def_readwrite("indexing", &MeshGrid::indexing);
py::class_<MeshIndexing, std::shared_ptr<MeshIndexing>, OpDef> MeshIndexingInst(m, "MeshIndexing");
py::class_<MeshIndexing, std::shared_ptr<MeshIndexing>, OpDef> MeshIndexingInst(m, "MeshIndexing");
MeshIndexingInst
MeshIndexingInst
...
...
src/core/include/megbrain/ir/ops.td
浏览文件 @
d3b2b519
...
@@ -515,4 +515,9 @@ def Dropout: MgbHashableOp<"Dropout", [DropoutParam]> {
...
@@ -515,4 +515,9 @@ def Dropout: MgbHashableOp<"Dropout", [DropoutParam]> {
let cmpFunction = [{return $0.handle == $1.handle && $0.drop_prob == $1.drop_prob;}];
let cmpFunction = [{return $0.handle == $1.handle && $0.drop_prob == $1.drop_prob;}];
}
}
def MeshGrid: MgbHashableOp<"MeshGrid"> {
let extraArguments = (ins
MgbStringAttr:$indexing
);
}
#endif // MGB_OPS
#endif // MGB_OPS
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录