Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
55efc8e1
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
403
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
55efc8e1
编写于
7月 01, 2021
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(mgb/gopt): add reformat emitter
GitOrigin-RevId: 937b20a57ce0f93add1e4de77365c3e73ec417b1
上级
c9d06030
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
323 addition
and
1 deletion
+323
-1
dnn/include/megdnn/named_tensor.h
dnn/include/megdnn/named_tensor.h
+0
-1
src/gopt/impl/reformat_emitter.cpp
src/gopt/impl/reformat_emitter.cpp
+172
-0
src/gopt/include/megbrain/gopt/reformat_emitter.h
src/gopt/include/megbrain/gopt/reformat_emitter.h
+66
-0
src/gopt/test/reformat_emitter.cpp
src/gopt/test/reformat_emitter.cpp
+85
-0
未找到文件。
dnn/include/megdnn/named_tensor.h
浏览文件 @
55efc8e1
...
...
@@ -14,7 +14,6 @@
#include "megdnn/internal/defs.h"
#include "megdnn/opr_param_defs.h"
#include "src/common/utils.h"
#include <array>
#include <string>
...
...
src/gopt/impl/reformat_emitter.cpp
0 → 100644
浏览文件 @
55efc8e1
/**
* \file src/gopt/impl/reformat_emitter.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include <numeric>
#include "megbrain/gopt/reformat_emitter.h"
#include "megbrain/opr/tensor_manip.h"
using
namespace
mgb
;
using
namespace
gopt
;
using
Dimension
=
megdnn
::
Dimension
;
using
NamedTensorShape
=
megdnn
::
NamedTensorShape
;
ReshapeEmitter::Operator ReshapeEmitter::emit() const {
    // Build a graph-level functor that reshapes a var laid out as m_src into
    // the layout described by m_dest. The per-axis pattern is computed once
    // here; init-capture moves it straight into the closure instead of
    // copying a named local a second time (the original copied twice).
    auto op = [pattern = analyze()](VarNode* var) {
        auto sym_var = SymbolVar(var);
        // symbolic shape of the input; individual axes are extracted on demand
        auto shp = opr::GetVarShape::make(sym_var);
        // immediate scalar constant associated with the input var
        auto cv = [&sym_var](int c) { return sym_var.make_scalar(c); };
        // symbolic extent of axis `ax` of the input shape
        auto sub = [&shp, &cv](int ax) {
            return opr::IndexAt::make(shp, {{0, cv(ax)}});
        };
        SymbolVarArray axs;
        // tuple layout: <src_axis, factor, is_mul>; see analyze()
        for (const auto& i : pattern) {
            if (std::get<0>(i) >= 0) {
                // runtime extent derived from a source axis, scaled by factor
                if (std::get<2>(i))
                    axs.emplace_back(sub(std::get<0>(i)) * std::get<1>(i));
                else
                    axs.emplace_back(sub(std::get<0>(i)) / std::get<1>(i));
            } else {
                // compile-time constant extent
                axs.emplace_back(cv(std::get<1>(i)));
            }
        }
        // assemble the full target shape and apply the reshape
        auto tshp = opr::Concat::make(axs, 0);
        auto ovar = opr::Reshape::make(sym_var, tshp);
        return ovar.node();
    };
    return op;
}
// Compute, for every axis of m_dest, how its extent is derived from m_src.
// Returns one tuple <src_axis, factor, is_mul> per destination axis:
//   * src_axis >= 0: the dest extent comes from source axis `src_axis`,
//     multiplied by `factor` when is_mul is true, divided otherwise;
//   * src_axis == -1: the dest extent is the compile-time constant `factor`
//     (is_mul is unused and set to false).
SmallVector<std::tuple<int, int, bool>> ReshapeEmitter::analyze() const {
    static constexpr uint32_t UNDETERMINED_EXTENT =
            Dimension::UNDETERMINED_EXTENT;
    // map each dimension name whose extent is only known at runtime to the
    // source axis that carries it; a duplicate name would be ambiguous, hence
    // the assert on insertion below
    ThinHashMap<Dimension::Name, int> name2dominant;
    for (size_t i = 0; i < m_src.ndim; ++i) {
        auto name = m_src[i].name();
        if (m_src[i].extent() == UNDETERMINED_EXTENT) {
            auto insert = name2dominant.insert(std::make_pair(name, i));
            mgb_assert(insert.second);
        }
    }
    SmallVector<std::tuple<int, int, bool>> pattern(m_dest.ndim);
    for (size_t i = 0; i < m_dest.ndim; ++i) {
        auto name = m_dest[i].name();
        if (m_dest[i].extent() == UNDETERMINED_EXTENT) {
            // runtime extent: locate the source axis with the same name and
            // record the ratio between the two dimensions
            int src_dim = name2dominant.at(name);
            bool mul = m_src[src_dim] < m_dest[i];
            int factor = mul ? (m_dest[i] / m_src[src_dim]).extent()
                             : (m_src[src_dim] / m_dest[i]).extent();
            pattern[i] = std::make_tuple(src_dim, factor, mul);
        } else {
            // statically known extent: emit it as a constant (-1 sentinel)
            pattern[i] = std::make_tuple(-1, m_dest[i].extent(), false);
        }
    }
    return pattern;
}
DimshuffleEmitter::Operator DimshuffleEmitter::emit() const {
    // Build a functor applying opr::Dimshuffle with the stored pattern.
    // Init-capture copies m_pattern once, directly into the closure; the
    // original copied it into a named local and then again into the lambda.
    auto op = [pattern = m_pattern](VarNode* var) {
        auto sym_var = SymbolVar(var);
        return opr::Dimshuffle::make(sym_var, pattern).node();
    };
    return op;
}
ReformatEmitter::Operator ReformatEmitter::emit() const {
    // Compose the reshape/dimshuffle/reshape stages produced by analyze()
    // into a single var-to-var functor. Init-capture moves the operator list
    // into the closure instead of copying a named local a second time.
    auto op = [ops = analyze()](VarNode* var) {
        VarNode* ovar = var;
        // apply the stages in order, feeding each output into the next
        for (const auto& o : ops) {
            ovar = o(ovar);
        }
        return ovar;
    };
    return op;
}
// Factorize m_src and m_dest into aligned (sub-)dimensions, then derive the
// reshape -> dimshuffle -> reshape operator sequence that converts a tensor
// from the m_src layout into the m_dest layout.
SmallVector<ReformatEmitter::Operator> ReformatEmitter::analyze() const {
    // a dimension together with the index of the axis it originated from
    struct Dim {
        Dimension dim;
        int index;
        Dim(Dimension dim_, int index_) : dim{dim_}, index{index_} {}
    };
    SmallVector<Dim> src_dims;
    SmallVector<Dim> dest_dims;
    for (size_t i = 0; i < m_src.ndim; ++i)
        src_dims.emplace_back(Dim(m_src[i], i));
    for (size_t i = 0; i < m_dest.ndim; ++i)
        dest_dims.emplace_back(Dim(m_dest[i], i));
    auto compare = [](const Dim& lhs, const Dim& rhs) {
        return lhs.dim < rhs.dim;
    };
    std::sort(src_dims.begin(), src_dims.end(), compare);
    std::sort(dest_dims.begin(), dest_dims.end(), compare);
    // Walk both sorted lists in lock step. Whenever the two current dims
    // differ, split the larger one (Dimension division) in place so that
    // afterwards both sides hold the same number of matching dims; both
    // halves of a split keep the original axis index.
    auto src_iter = src_dims.begin();
    auto dest_iter = dest_dims.begin();
    for (; src_iter != src_dims.end() && dest_iter != dest_dims.end();) {
        if (src_iter->dim == dest_iter->dim) {
            src_iter++;
            dest_iter++;
        } else if (src_iter->dim < dest_iter->dim) {
            // split *dest_iter into src_iter->dim and the remaining factor
            // (insert returns a valid iterator to the inserted element)
            auto split = dest_iter->dim / src_iter->dim;
            int dim_idx = dest_iter->index;
            dest_iter = dest_dims.insert(dest_iter, Dim(src_iter->dim, dim_idx));
            dest_iter++;
            dest_iter->dim = split;
            dest_iter->index = dim_idx;
            src_iter++;
        } else {
            // symmetric case: split *src_iter by dest_iter->dim
            auto split = src_iter->dim / dest_iter->dim;
            int dim_idx = src_iter->index;
            src_iter = src_dims.insert(src_iter, Dim(dest_iter->dim, dim_idx));
            src_iter++;
            src_iter->dim = split;
            src_iter->index = dim_idx;
            dest_iter++;
        }
    }
    mgb_assert(src_dims.size() == dest_dims.size());
    // src_perm: order of the aligned dims as they appear in the source
    // layout; permute: dimshuffle pattern mapping that order onto the order
    // required by the destination layout
    std::vector<int> src_perm(src_dims.size());
    std::vector<int> permute(dest_dims.size());
    std::iota(src_perm.begin(), src_perm.end(), 0);
    std::iota(permute.begin(), permute.end(), 0);
    std::sort(src_perm.begin(), src_perm.end(), [&](const int a, const int b) {
        if (src_dims[a].index != src_dims[b].index)
            return src_dims[a].index < src_dims[b].index;
        return src_dims[a].dim < src_dims[b].dim;
    });
    std::sort(permute.begin(), permute.end(), [&](const int a, const int b) {
        int perm_a = src_perm[a];
        int perm_b = src_perm[b];
        if (dest_dims[perm_a].index != dest_dims[perm_b].index)
            return dest_dims[perm_a].index < dest_dims[perm_b].index;
        return dest_dims[perm_a].dim < dest_dims[perm_b].dim;
    });
    // i1: source shape after splitting; i2: the same dims in destination
    // order, i.e. the shape right after the dimshuffle
    NamedTensorShape i1, i2;
    i1.ndim = src_dims.size(), i2.ndim = dest_dims.size();
    for (size_t i = 0; i < src_dims.size(); ++i) {
        i1[i] = src_dims[src_perm[i]].dim;
        i2[i] = src_dims[src_perm[permute[i]]].dim;
    }
    // chain: (optional) reshape to i1 -> dimshuffle -> (optional) reshape
    // from i2 to the final destination layout; no-op reshapes are skipped
    SmallVector<Operator> ops;
    if (!m_src.eq_shape(i1))
        ops.emplace_back(ReshapeEmitter(m_src, i1).emit());
    ops.emplace_back(DimshuffleEmitter(permute).emit());
    if (!m_dest.eq_shape(i2))
        ops.emplace_back(ReshapeEmitter(i2, m_dest).emit());
    return ops;
}
src/gopt/include/megbrain/gopt/reformat_emitter.h
0 → 100644
浏览文件 @
55efc8e1
/**
* \file src/gopt/include/megbrain/gopt/reformat_emitter.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include <tuple>
#include <utility>
#include <vector>
#include "megbrain/graph.h"
#include "megdnn/named_tensor.h"
namespace
mgb
{
namespace
gopt
{
//! Base of all layout emitters: an Emitter analyzes a layout conversion and
//! emits an Operator, i.e. a callable that rewrites a graph VarNode.
class Emitter {
public:
    //! graph-transformation functor: maps an input var to the rewritten var
    using Operator = thin_function<VarNode* (VarNode*)>;
    virtual ~Emitter() = default;
    //! build the transformation; implemented by concrete emitters
    virtual Operator emit() const = 0;
};
//! Emits a Reshape converting a tensor from the named layout \p src to the
//! named layout \p dest; the target shape is built symbolically from the
//! input var's shape, so runtime extents are supported.
class ReshapeEmitter final : public Emitter {
public:
    using Operator = typename Emitter::Operator;
    ReshapeEmitter(const megdnn::NamedTensorShape& src,
                   const megdnn::NamedTensorShape& dest)
            : m_src{src}, m_dest{dest} {}
    Operator emit() const override;

private:
    //! per-dest-axis pattern <src_axis, factor, is_mul>; see the .cpp impl
    SmallVector<std::tuple<int, int, bool>> analyze() const;
    megdnn::NamedTensorShape m_src, m_dest;
};
//! Emits an opr::Dimshuffle applying a fixed axis permutation pattern.
class DimshuffleEmitter final : public Emitter {
public:
    using Operator = typename Emitter::Operator;
    //! \param pattern dimshuffle pattern; taken by value (sink parameter)
    //! so callers may move — the original signature took a const reference
    //! and always copied
    DimshuffleEmitter(std::vector<int> pattern)
            : m_pattern{std::move(pattern)} {}
    //! build a functor applying opr::Dimshuffle with the stored pattern
    Operator emit() const override;

private:
    std::vector<int> m_pattern;
};
//! Emits a full layout conversion from \p src to \p dest, composed of an
//! optional reshape, a dimshuffle, and an optional reshape produced by the
//! ReshapeEmitter / DimshuffleEmitter helpers.
class ReformatEmitter final : public Emitter {
public:
    using Operator = typename Emitter::Operator;
    ReformatEmitter(const megdnn::NamedTensorShape& src,
                    const megdnn::NamedTensorShape& dest)
            : m_src{src}, m_dest{dest} {}
    Operator emit() const override;

private:
    //! decompose the conversion into the reshape/dimshuffle/reshape stages
    SmallVector<Operator> analyze() const;
    megdnn::NamedTensorShape m_src, m_dest;
};
}
// namespace gopt
}
// namespace mgb
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
src/gopt/test/reformat_emitter.cpp
0 → 100644
浏览文件 @
55efc8e1
/**
* \file src/gopt/test/reformat_emitter.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "./helper.h"
#include "megbrain/gopt/reformat_emitter.h"
#include "megbrain/opr/tensor_manip.h"
using
namespace
mgb
;
// Checks the emitter-generated NCHW32 -> NCHW4 conversion against a
// hand-written reference graph; both are executed and their outputs must be
// bit-identical.
TEST(TestReformatEmitter, Basic) {
    constexpr size_t N = 12, C = 64, H = 7, W = 7;
    HostTensorGenerator<> gen;
    using NamedTensorShape = megdnn::NamedTensorShape;
    auto dest = NamedTensorShape::make_named_tensor_shape(
            NamedTensorShape::Format::NCHW4);
    auto src = NamedTensorShape::make_named_tensor_shape(
            NamedTensorShape::Format::NCHW32);
    // functor under test: converts a var from NCHW32 layout to NCHW4
    auto reformat = gopt::ReformatEmitter(src, dest).emit();
    auto graph = ComputingGraph::make();
    // disable graph optimization so the emitted oprs run as-is
    graph->options().graph_opt_level = 0;
    // hand-written reference: reshape to split the 32-channel inner block
    // into 8 x 4, dimshuffle the 8 next to the outer channel axis, then
    // merge it back — yielding the NCHW4 layout
    auto nchw32_to_nchw4 = [](VarNode* in) {
        auto x = SymbolVar(in);
        auto xshp = opr::GetVarShape::make(x);
        // scalar constant helper
        auto cv = [&x](int v) { return x.make_scalar(v); };
        // symbolic extent of input axis `idx`
        auto sub = [&xshp, &cv](int idx) {
            return opr::IndexAt::make(xshp, {{0, cv(idx)}});
        };
        auto tshp0 = opr::Concat::make(
                     {sub(0), sub(1), sub(2), sub(3), cv(8), sub(4) / 8}, 0),
             tshp1 = opr::Concat::make(
                     {sub(0), sub(1) * 8, sub(2), sub(3), sub(4) / 8}, 0);
        auto y0 = opr::Reshape::make(x, tshp0);
        auto y1 = opr::Dimshuffle::make(y0, {0, 1, 4, 2, 3, 5});
        auto y2 = opr::Reshape::make(y1, tshp1);
        return y2.node();
    };
    auto mkvar = [&](const char* name, const TensorShape& shp) {
        return opr::Host2DeviceCopy::make(*graph, gen(shp)).rename(name);
    };
    // input tensor in NCHW32 layout: (N, C/32, H, W, 32)
    auto x = mkvar("x", {N, C / 32, H, W, 32});
    auto y1 = SymbolVar(reformat(x.node()));
    auto y2 = SymbolVar(nchw32_to_nchw4(x.node()));
    HostTensorND t1, t2;
    auto func1 = graph->compile({make_callback_copy(y1, t1)});
    func1->execute();
    auto func2 = graph->compile({make_callback_copy(y2, t2)});
    func2->execute();
    // emitter output must match the hand-written reference exactly
    MGB_ASSERT_TENSOR_EQ(t1, t2);
}
// Smoke test for a harder conversion (NCHW64 -> NCHW88): only checks that
// the emitted operator chain compiles and executes without error.
TEST(TestReformatEmitter, MoreComplicated) {
    constexpr size_t N = 16, C = 64, H = 7, W = 7;
    using NamedTensorShape = megdnn::NamedTensorShape;
    HostTensorGenerator<> tensor_gen;
    // source and destination named layouts for the conversion under test
    auto from_shape = NamedTensorShape::make_named_tensor_shape(
            NamedTensorShape::Format::NCHW64);
    auto to_shape = NamedTensorShape::make_named_tensor_shape(
            NamedTensorShape::Format::NCHW88);
    auto converter = gopt::ReformatEmitter(from_shape, to_shape).emit();
    auto cg = ComputingGraph::make();
    // run the emitted oprs verbatim, without graph-level optimization
    cg->options().graph_opt_level = 0;
    // helper: upload a freshly generated host tensor as a named graph input
    auto make_input = [&](const char* name, const TensorShape& shape) {
        auto host = tensor_gen(shape);
        return opr::Host2DeviceCopy::make(*cg, host).rename(name);
    };
    // input tensor in NCHW64 layout: (N, C/64, H, W, 64)
    auto inp = make_input("x", {N, C / 64, H, W, 64});
    auto out = SymbolVar(converter(inp.node()));
    HostTensorND host_result;
    auto compiled = cg->compile({make_callback_copy(out, host_result)});
    compiled->execute();
}
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录