Commit f7e10ea8
Authored on Mar 22, 2022 by Megvii Engine Team

perf(imperative): improve matmul/batch_matmul

GitOrigin-RevId: 4ceb2eb60148113dd789416d604f0e4f76a4ec7c
Parent: 1c2a323e

Showing 12 changed files with 657 additions and 160 deletions (+657 −160)
imperative/python/megengine/core/tensor/array_method.py        +63 −24
imperative/python/megengine/functional/math.py                 +1 −1
imperative/python/megengine/functional/nn.py                   +0 −17
imperative/python/src/tensor.cpp                                +4 −0
imperative/python/src/tensor_utils.cpp                          +94 −0
imperative/python/src/tensor_utils.h                            +4 −0
imperative/python/test/unit/autodiff/test_grad_manager.py       +0 −0
imperative/src/impl/ops/dot.cpp                                  +0 −87
imperative/src/impl/ops/matmul.cpp                               +435 −0
imperative/src/impl/ops/reduce.cpp                               +2 −3
imperative/src/impl/ops/specializations.cpp                      +0 −28
imperative/src/impl/transformations/dtype_promote.cpp            +54 −0
imperative/python/megengine/core/tensor/array_method.py
View file @ f7e10ea8

@@ -20,9 +20,10 @@ from .._imperative_rt.core2 import (
    Tensor,
    apply,
    astype_cpp,
    batched_matmul_cpp,
    broadcast_cpp,
    dtype_promotion,
    getitem_cpp,
    matmul_cpp,
)
from .._imperative_rt.core2 import reduce_to_scalar as _reduce_to_scalar
from .._imperative_rt.core2 import reshape_cpp, setitem_cpp, squeeze_cpp, transpose_cpp

@@ -266,6 +267,42 @@ class _Hashable:
        return self.value == o.value


def symbolicMatrixMul(
    inp1, inp2, dim1, dim2, transpose_a, transpose_b, compute_mode, format, strategy
):
    extentedMatrixMulOp = _get_extentedMatrixMulOp(
        inp1.device,
        inp1.dtype,
        dim1,
        dim2,
        transpose_a,
        transpose_b,
        compute_mode,
        format,
        strategy=_Hashable(strategy),
    )
    (result,) = apply(extentedMatrixMulOp(), inp1, inp2)
    return result


def symbolicBatchedMatrixMul(
    inp1, inp2, dim1, dim2, transpose_a, transpose_b, compute_mode, format, strategy
):
    extentedBatchedMatrixMulOp = _get_extentedBatchedMatrixMulOp(
        inp1.device,
        inp1.dtype,
        dim1,
        dim2,
        transpose_a,
        transpose_b,
        compute_mode,
        format,
        strategy=_Hashable(strategy),
    )
    (result,) = apply(extentedBatchedMatrixMulOp(), inp1, inp2)
    return result


def _matmul(
    inp1,
    inp2,

@@ -274,16 +311,6 @@ def _matmul(
    compute_mode="default",
    format="default",
):
-    if amp._enabled:
-        compute_mode = "float32"
-        inp1, inp2 = cast_tensors(inp1, inp2)
-    else:
-        dtype = dtype_promotion(inp1, inp2)
-        if inp1.dtype != dtype:
-            inp1 = inp1.astype(dtype)
-        if inp2.dtype != dtype:
-            inp2 = inp2.astype(dtype)
    dim1, dim2 = inp1.ndim, inp2.ndim
    assert dim1 > 0 and dim2 > 0
    maxdim = dim1 if dim1 > dim2 else dim2

@@ -301,34 +328,46 @@ def _matmul(
    if dim1 == 1 and dim2 == 1:
        # dispatch to Dot
        (result,) = apply(builtin.Dot(), inp1, inp2)
        return result
-    elif maxdim <= 2 or dim2 <= 2:
-        # dispath to MatrixMul
-        extentedMatrixMulOp = _get_extentedMatrixMulOp(
-            inp1.device,
-            inp1.dtype,
+    elif maxdim <= 2 or (dim2 <= 2 and not transpose_a):
+        # dispath to MatrixMul
+        #    2x1
+        #    1x2
+        #    2x2
+        #    nx1(transpose_a=False), n>=3
+        #    nx2(transpose_a=False), n>=3
+        return matmul_cpp(
+            inp1,
+            inp2,
            dim1,
            dim2,
            transpose_a,
            transpose_b,
            compute_mode,
            format,
-            strategy=_Hashable(strategy),
+            _config._benchmark_kernel,
+            _config._deterministic_kernel,
+            strategy,
+            symbolicMatrixMul,
        )
-        (result,) = apply(extentedMatrixMulOp(), inp1, inp2)
-        return result
    else:
        # dispath to BatchedMatrixMul
-        extentedBatchedMatrixMulOp = _get_extentedBatchedMatrixMulOp(
-            inp1.device,
-            inp1.dtype,
+        #    nx1(transpose_a=True), n>=3
+        #    nx2(transpose_a=True), n>=3
+        #    nxm,n>=3,m>=3
+        #    1xm,m>=3
+        #    2xm,m>=3
+        return batched_matmul_cpp(
+            inp1,
+            inp2,
            dim1,
            dim2,
            transpose_a,
            transpose_b,
            compute_mode,
            format,
-            strategy=_Hashable(strategy),
+            _config._benchmark_kernel,
+            _config._deterministic_kernel,
+            strategy,
+            symbolicBatchedMatrixMul,
        )
-        (result,) = apply(extentedBatchedMatrixMulOp(), inp1, inp2)
-        return result


def _unary_elwise(mode):
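For orientation, a rough illustration of the dispatch rules in _matmul above, using the public megengine.functional.matmul API (a sketch only; shapes chosen purely for demonstration):

    import numpy as np
    import megengine as mge
    import megengine.functional as F

    # 2-D x 2-D stays on the MatrixMul fast path
    x = mge.tensor(np.random.rand(3, 4).astype("float32"))
    y = mge.tensor(np.random.rand(4, 5).astype("float32"))
    print(F.matmul(x, y).shape)  # (3, 5)

    # n-D x 2-D (n >= 3, transpose_a=False) now also goes through matmul_cpp
    p = mge.tensor(np.random.rand(8, 3, 4).astype("float32"))
    print(F.matmul(p, y).shape)  # (8, 3, 5)

    # n-D x m-D with n, m >= 3 is dispatched to batched_matmul_cpp
    q = mge.tensor(np.random.rand(8, 4, 5).astype("float32"))
    print(F.matmul(p, q).shape)  # (8, 3, 5)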
imperative/python/megengine/functional/math.py
View file @ f7e10ea8

@@ -10,7 +10,7 @@ import collections
import math
from typing import Iterable, Optional, Sequence, Tuple, Union

-from ..core._imperative_rt.core2 import Const, apply, dtype_promotion
+from ..core._imperative_rt.core2 import Const, apply
from ..core._imperative_rt.ops import SubgraphBuilder as _SubgraphBuilder
from ..core.ops import builtin
from ..core.tensor.array_method import _matmul
imperative/python/megengine/functional/nn.py
View file @ f7e10ea8

@@ -17,7 +17,6 @@ from ..core._imperative_rt.core2 import (
    apply,
-    dtype_promotion,
)
from ..core._imperative_rt.ops import SubgraphBuilder as _SubgraphBuilder
from ..core._imperative_rt.ops import get_global_rng_seed as _get_global_rng_seed
from ..core.ops import builtin
from ..core.ops.builtin import (

@@ -177,16 +176,6 @@ def conv1d(
    assert compute_mode.lower() == "default" or compute_mode.name == "DEFAULT"
    assert inp.ndim == 3, "the input dimension of conv1d should be 3"
    assert weight.ndim == 3, "the weight dimension of conv1d should be 3"
-    if amp._enabled:
-        compute_mode = "float32"
-        inp, weight, bias = cast_tensors(inp, weight, bias)
-    else:
-        dtype = dtype_promotion(inp, weight)
-        if inp.dtype != dtype:
-            inp = inp.astype(dtype)
-        if weight.dtype != dtype:
-            weight = weight.astype(dtype)
    if bias is not None:
        assert bias.ndim == 3, "the bias dimension of conv1d should be 3"

@@ -522,12 +511,6 @@ def local_conv2d(
    pad_h, pad_w = expand_hw(padding)
    dilate_h, dilate_w = expand_hw(dilation)
-    dtype = dtype_promotion(inp, weight)
-    if inp.dtype != dtype:
-        inp = inp.astype(dtype)
-    if weight.dtype != dtype:
-        weight = weight.astype(dtype)
    # local conv only support "dense" mode, but weight could contain group dimension.
    op = builtin.GroupLocal(
        stride_h=stride_h,
imperative/python/src/tensor.cpp
View file @ f7e10ea8

@@ -433,6 +433,8 @@ WRAP_FUNC_PY35(reshape_cpp);
WRAP_FUNC_PY35(adaptive_pool2d_cpp);
WRAP_FUNC_PY35(Const);
WRAP_FUNC_PY35(astype_cpp);
+WRAP_FUNC_PY35(matmul_cpp);
+WRAP_FUNC_PY35(batched_matmul_cpp);
WRAP_FUNC_PY35(convert_single_value_cpp);
WRAP_FUNC_PY35(convert_inputs_cpp);
WRAP_FUNC_PY35(astensor1d_cpp);

@@ -588,6 +590,8 @@ void init_tensor(py::module m) {
        MGE_PY_INTERFACE(adaptive_pool2d_cpp, adaptive_pool2d_cpp),
        MGE_PY_INTERFACE(Const, Const),
        MGE_PY_INTERFACE(astype_cpp, astype_cpp),
+        MGE_PY_INTERFACE(matmul_cpp, matmul_cpp),
+        MGE_PY_INTERFACE(batched_matmul_cpp, batched_matmul_cpp),
        MGE_PY_INTERFACE(convert_single_value_cpp, convert_single_value_cpp),
        MGE_PY_INTERFACE(convert_inputs_cpp, convert_inputs_cpp),
        MGE_PY_INTERFACE(astensor1d_cpp, astensor1d_cpp),
imperative/python/src/tensor_utils.cpp
View file @ f7e10ea8

@@ -1490,6 +1490,78 @@ py::object _transpose_cpp(py::handle inp_hdl, py::handle args) {
    return ret[0];
}

py::object _matmul_cpp(
        py::handle inp1, py::handle inp2, py::handle dim1, py::handle dim2,
        py::handle transpose_a, py::handle transpose_b, py::handle compute_mode,
        py::handle format, py::handle profile, py::handle determistic,
        py::handle strategy, py::handle func) {
    if (enable_fastpath(inp1)) {
        ::megdnn::param::MatrixMul::ComputeMode mode =
                ::megdnn::param::MatrixMul::ComputeMode::DEFAULT;
        if (compute_mode.cast<std::string>().compare(std::string("float32")) == 0) {
            mode = ::megdnn::param::MatrixMul::ComputeMode::FLOAT32;
        }
        ::megdnn::param::ExecutionPolicy::Strategy cstrategy;
        if (profile.cast<bool>()) {
            cstrategy |= ::megdnn::param::ExecutionPolicy::Strategy::PROFILE;
        } else {
            cstrategy |= ::megdnn::param::ExecutionPolicy::Strategy::HEURISTIC;
        }
        if (determistic.cast<bool>()) {
            cstrategy |= ::megdnn::param::ExecutionPolicy::Strategy::REPRODUCIBLE;
        }
        std::shared_ptr<OpDef> op = MatrixMul::make(
                transpose_a.cast<bool>(), transpose_b.cast<bool>(), mode,
                ::megdnn::param::MatrixMul::Format::DEFAULT, cstrategy, UINT64_MAX);
        py::object Op = py::cast(op);
        PyObject* p[3] = {Op.ptr(), inp1.ptr(), inp2.ptr()};
        py::tuple ret = py::reinterpret_steal<py::object>(py_apply(NULL, p, 3));
        return ret[0];
    } else {
        // fallback to traceable implementation
        return func(
                inp1, inp2, dim1, dim2, transpose_a, transpose_b, compute_mode,
                format, strategy);
    }
}

py::object _batched_matmul_cpp(
        py::handle inp1, py::handle inp2, py::handle dim1, py::handle dim2,
        py::handle transpose_a, py::handle transpose_b, py::handle compute_mode,
        py::handle format, py::handle profile, py::handle determistic,
        py::handle strategy, py::handle func) {
    if (enable_fastpath(inp1)) {
        ::megdnn::param::MatrixMul::ComputeMode mode =
                ::megdnn::param::MatrixMul::ComputeMode::DEFAULT;
        if (compute_mode.cast<std::string>().compare(std::string("float32")) == 0) {
            mode = ::megdnn::param::MatrixMul::ComputeMode::FLOAT32;
        }
        ::megdnn::param::ExecutionPolicy::Strategy cstrategy;
        if (profile.cast<bool>()) {
            cstrategy |= ::megdnn::param::ExecutionPolicy::Strategy::PROFILE;
        } else {
            cstrategy |= ::megdnn::param::ExecutionPolicy::Strategy::HEURISTIC;
        }
        if (determistic.cast<bool>()) {
            cstrategy |= ::megdnn::param::ExecutionPolicy::Strategy::REPRODUCIBLE;
        }
        std::shared_ptr<OpDef> op = BatchedMatrixMul::make(
                transpose_a.cast<bool>(), transpose_b.cast<bool>(), mode,
                ::megdnn::param::MatrixMul::Format::DEFAULT, cstrategy, UINT64_MAX);
        py::object Op = py::cast(op);
        PyObject* p[3] = {Op.ptr(), inp1.ptr(), inp2.ptr()};
        py::tuple ret = py::reinterpret_steal<py::object>(py_apply(NULL, p, 3));
        return ret[0];
    } else {
        // fallback to traceable implementation
        return func(
                inp1, inp2, dim1, dim2, transpose_a, transpose_b, compute_mode,
                format, strategy);
    }
}

PyObject* make_shape_tuple(PyObject* self, PyObject* const* args, size_t nargs) {
    try {
        return _make_shape_tuple(args[0]).release().ptr();

@@ -1574,6 +1646,28 @@ PyObject* astype_cpp(PyObject* self, PyObject* const* args, size_t nargs) {
    PYEXT17_TRANSLATE_EXC_RET(nullptr)
}

PyObject* matmul_cpp(PyObject* self, PyObject* const* args, size_t nargs) {
    try {
        return _matmul_cpp(
                       args[0], args[1], args[2], args[3], args[4], args[5], args[6],
                       args[7], args[8], args[9], args[10], args[11])
                .release()
                .ptr();
    }
    PYEXT17_TRANSLATE_EXC_RET(nullptr)
}

PyObject* batched_matmul_cpp(PyObject* self, PyObject* const* args, size_t nargs) {
    try {
        return _batched_matmul_cpp(
                       args[0], args[1], args[2], args[3], args[4], args[5], args[6],
                       args[7], args[8], args[9], args[10], args[11])
                .release()
                .ptr();
    }
    PYEXT17_TRANSLATE_EXC_RET(nullptr)
}

PyObject* convert_single_value_cpp(PyObject* self, PyObject* const* args, size_t nargs) {
    try {
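The fast path above (and its batched twin) derives the megdnn execution-policy strategy from the two flags handed down from Python config. A minimal plain-Python sketch of that mapping, for reference only (the names PROFILE/HEURISTIC/REPRODUCIBLE mirror megdnn's ExecutionPolicy::Strategy bits; this is not MegEngine API):

    def pick_strategy(benchmark_kernel: bool, deterministic_kernel: bool) -> set:
        # PROFILE when benchmarking is requested, HEURISTIC otherwise,
        # with REPRODUCIBLE OR-ed in for deterministic runs
        strategy = {"PROFILE"} if benchmark_kernel else {"HEURISTIC"}
        if deterministic_kernel:
            strategy.add("REPRODUCIBLE")
        return strategy

    print(pick_strategy(False, False))  # {'HEURISTIC'}
    print(pick_strategy(True, True))    # {'PROFILE', 'REPRODUCIBLE'}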
imperative/python/src/tensor_utils.h
View file @ f7e10ea8

@@ -30,6 +30,10 @@ PyObject* Const(PyObject* self, PyObject* const* args, size_t nargs);
PyObject* astype_cpp(PyObject* self, PyObject* const* args, size_t nargs);

+PyObject* matmul_cpp(PyObject* self, PyObject* const* args, size_t nargs);
+
+PyObject* batched_matmul_cpp(PyObject* self, PyObject* const* args, size_t nargs);

PyObject* convert_single_value_cpp(PyObject* self, PyObject* const* args, size_t nargs);

PyObject* convert_inputs_cpp(PyObject* self, PyObject* const* args, size_t nargs);
imperative/python/test/unit/autodiff/test_grad_manger.py → imperative/python/test/unit/autodiff/test_grad_manager.py
View file @ f7e10ea8
File moved
imperative/src/impl/ops/dot.cpp
Deleted (100644 → 0)
View file @ 1c2a323e

#include "megbrain/imperative/opr_utility.h"
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/imperative/utils/stats.h"
#include "megbrain/opr/basic_arith.h"
#include "megbrain/opr/blas.h"
#include "megbrain/opr/utility.h"

#include "../blob_manager_impl.h"
#include "../dnn_op_helper.h"
#include "../op_trait.h"

namespace mgb {
namespace imperative {
namespace {
namespace dot {

auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& op = def.cast_final_safe<Dot>();
    mgb_assert(inputs.size() == 2);
    OperatorNodeConfig config{op.make_name()};
    return opr::Dot::make(inputs[0], inputs[1], config);
}

SmallVector<TensorPtr> apply_on_physical_tensor(
        const OpDef& def, const SmallVector<TensorPtr>& inputs,
        SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
    auto comp_node = inputs[0]->comp_node();
    using TensorND = megdnn::TensorND;
    SmallVector<TensorND> inp_tensornds;
    inp_tensornds.reserve(inputs.size());
    DnnOprCaller<megdnn::Dot> dnn_opr(comp_node);
    for (unsigned i = 0; i < inputs.size(); ++i) {
        auto dnn_ten = inputs[i]->dnn_tensor();
        inp_tensornds.push_back(dnn_ten);
    }
    TensorLayout oup_layout{inputs[0]->dtype()};
    auto inp1_tensor = inputs[0]->dnn_tensor();
    auto inp2_tensor = inputs[1]->dnn_tensor();
    dnn_opr.op->deduce_layout(inp1_tensor.layout, inp2_tensor.layout, oup_layout);
    if (inputs[0]->layout().is_empty() || inputs[1]->layout().is_empty()) {
        DnnOprCaller<megdnn::Fill> fill_opr(comp_node);
        DeviceTensorND out = BlobManager::inst()->alloc_workspace_with_defrag(
                comp_node, oup_layout);
        fill_opr.op->param() = 0;
        fill_opr.op->exec(out.as_megdnn(), {});
        return {Tensor::make(out)};
    }
    auto sz = dnn_opr.op->get_workspace_in_bytes(
            inp_tensornds[0].layout, inp_tensornds[1].layout, output_descs[0].layout);
    DeviceTensorND out_devtensor =
            BlobManager::inst()->alloc_workspace_with_defrag(comp_node, oup_layout);
    TensorLayout w_layout({sz}, dtype::Byte());
    auto dnn_wk = dnn_opr.create_workspace(w_layout);
    dnn_opr.op->exec(
            inp_tensornds[0], inp_tensornds[1], out_devtensor.as_megdnn(), dnn_wk);
    return {Tensor::make(out_devtensor)};
}

std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
        const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
    mgb_assert(
            inputs.size() == 2, "Dot expects 2 inputs; got %lu actually",
            inputs.size());
    SmallVector<LogicalTensorDesc> dests(1);
    dests[0].layout = TensorLayout(TensorShape{1}, inputs[0].layout.dtype);
    dests[0].comp_node = inputs[0].comp_node;
    bool validated = inputs[0].layout.ndim != 0 && inputs[1].layout.ndim != 0;
    return {dests, validated};
}

OP_TRAIT_REG(Dot, Dot, mgb::opr::Dot)
        .apply_on_var_node(apply_on_var_node)
        .infer_output_attrs_fallible(infer_output_attrs_fallible)
        .apply_on_physical_tensor(apply_on_physical_tensor)
        .fallback();

}  // namespace dot
}  // anonymous namespace
}  // namespace imperative
}  // namespace mgb
\ No newline at end of file
imperative/src/impl/ops/matmul.cpp
New file (0 → 100644)
View file @ f7e10ea8

#include <numeric>
#include "../blob_manager_impl.h"
#include "../dnn_op_helper.h"
#include "../op_trait.h"
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/opr/blas.h"

#include "../algo_chooser.h"

namespace mgb {
namespace imperative {
namespace {
namespace matrix_mul {

auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& matmul = def.cast_final_safe<MatrixMul>();
    mgb_assert(inputs.size() == 2);
    OperatorNodeConfig config{matmul.make_name()};
    return opr::MatrixMul::make(
            inputs[0], inputs[1], matmul.param(), matmul.policy(), config);
}

std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
        const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
    auto&& matmul = def.cast_final_safe<MatrixMul>();
    auto layout1 = inputs[0].layout;
    auto layout2 = inputs[1].layout;
    size_t dim1 = layout1.ndim, dim2 = layout2.ndim;
    if (dim1 == 0 || dim2 == 0) {
        return {{{TensorLayout(layout1.dtype), inputs[0].comp_node}}, false};
    }
    if (matmul.transposeA)
        std::swap(layout1[0], layout1[1]);
    if (matmul.transposeB)
        std::swap(layout2[0], layout2[1]);

    mgb_assert(layout1[dim1 - 1] == layout2[0]);

    TensorLayout dst_layout(layout1.dtype);
    size_t ci = 0;
    for (size_t i = 0; i < dim1 - 1; i++)
        dst_layout[ci++] = layout1[i];
    if (dim2 == 2)
        dst_layout[ci++] = layout2[1];
    dst_layout.ndim = ci;
    dst_layout.init_contiguous_stride();

    SmallVector<LogicalTensorDesc> out_descs(1u);
    out_descs[0] = {dst_layout, inputs[0].comp_node};
    return {out_descs, true};
}

SmallVector<TensorPtr> apply_on_physical_tensor(
        const OpDef& def, const SmallVector<TensorPtr>& inputs,
        SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
    auto&& matmul = def.cast_final_safe<MatrixMul>();
    auto&& cn = inputs[0]->comp_node();

    using TensorND = megdnn::TensorND;
    SmallVector<TensorND> inp_tensornds(inputs.size());
    TensorLayout layout1 = inputs[0]->layout(), layout2 = inputs[1]->layout();

    // only matters when layout1 has dim 2
    if (matmul.transposeA)
        std::swap(layout1.shape[0], layout1.shape[1]);
    // only matters when layout2 has dim 2
    if (matmul.transposeB)
        std::swap(layout2.shape[0], layout2.shape[1]);

    size_t dim1 = layout1.ndim, dim2 = layout2.ndim;
    TensorLayout real_dst_layout(layout1.dtype);
    if (validated) {
        real_dst_layout = output_descs[0].layout;
    } else {
        size_t ri = 0;
        for (size_t i = 0; i < dim1 - 2; i++)
            real_dst_layout[ri++] = layout1[i];
        real_dst_layout[ri++] = layout1[dim1 - 2];
        if (dim2 == 2)
            real_dst_layout[ri++] = layout2[dim2 - 1];
        real_dst_layout.ndim = ri;
        real_dst_layout.init_contiguous_stride();
    }

    if (dim1 == 0 || dim2 == 0 || layout1[layout1.ndim - 1] == 0) {
        DeviceTensorND out =
                BlobManager::inst()->alloc_workspace_with_defrag(cn, real_dst_layout);
        if (!out.empty()) {
            dev_tensor_memset(out, 0);
        }
        return {Tensor::make(out)};
    }

    TensorLayout layout_a = layout1, layout_b = layout2;
    if (dim1 == 1) {
        layout_a.add_axis_cont_inplace(0);
        inp_tensornds[0] = inputs[0]->dnn_tensor();
        inp_tensornds[0].layout = layout_a;
    } else if (dim1 > 2) {
        size_t batch = std::accumulate(
                layout1.shape, layout1.shape + dim1 - 1, (size_t)1,
                std::multiplies<size_t>());
        TensorShape na = TensorShape{batch, layout1[dim1 - 1]};
        auto inp1 = inputs[0];
        if (!layout1.try_reshape(layout_a, na)) {
            inp1 = Tensor::make(inp1->blob(), inp1->offset(), layout1);
            inp1->to_contiguous_inplace();
            layout1 = inp1->layout();
            layout_a = TensorLayout{{batch, layout1[dim1 - 1]}, layout1.dtype};
        }

        layout_a.init_contiguous_stride();
        inp_tensornds[0] = inp1->dnn_tensor();
        inp_tensornds[0].layout = layout_a;
    } else {
        inp_tensornds[0] = inputs[0]->dnn_tensor();
    }
    if (dim2 == 1) {
        layout_b.add_axis_inplace(1, 1, 1);
        inp_tensornds[1] = inputs[1]->dnn_tensor();
        inp_tensornds[1].layout = layout_b;
    } else {
        inp_tensornds[1] = inputs[1]->dnn_tensor();
    }

    TensorLayout dst_layout = TensorLayout({layout_a[0], layout_b[1]}, layout_a.dtype);
    dst_layout.init_contiguous_stride();

    DnnOprCaller<megdnn::MatrixMul> dnn_opr(cn);
    dnn_opr.op->param() = matmul.param();

    DeviceTensorND out =
            BlobManager::inst()->alloc_workspace_with_defrag(cn, dst_layout);

    size_t sz = setup_algo<megdnn::MatrixMul>(
            {layout_a, layout_b, dst_layout}, dnn_opr.op.get(), 0, false, false, cn,
            matmul.policy(), false);

    TensorLayout w_layout({sz}, dtype::Byte());
    auto dnn_wk = dnn_opr.create_workspace(w_layout);
    dnn_opr.op->exec(inp_tensornds[0], inp_tensornds[1], out.as_megdnn(), dnn_wk);

    return {Tensor::make(out.sub(SubTensorSpec::make_from_layout(real_dst_layout)))};
}

SmallVector<VarNode::LayoutConstraintCallback> get_input_layout_constraint(
        const OpDef& def, const SmallVector<TensorPtr>& inputs) {
    SmallVector<VarNode::LayoutConstraintCallback> layout_checker(inputs.size());
    layout_checker[0] = layout_checker[1] = [](const TensorLayout& layout) {
        return layout.is_contiguous();
    };
    return layout_checker;
}

OP_TRAIT_REG(MatrixMul, MatrixMul)
        .apply_on_var_node(apply_on_var_node)
        .infer_output_attrs_fallible(infer_output_attrs_fallible)
        .apply_on_physical_tensor(apply_on_physical_tensor)
        .get_input_layout_constraint(get_input_layout_constraint)
        .fallback();
}  // namespace matrix_mul
}  // namespace

namespace {
namespace batched_matrix_mul {

auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& matmul = def.cast_final_safe<BatchedMatrixMul>();
    mgb_assert(inputs.size() == 2);
    OperatorNodeConfig config{matmul.make_name()};
    return opr::BatchedMatrixMul::make(
            inputs[0], inputs[1], matmul.param(), matmul.policy(), config);
}

std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
        const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
    auto&& matmul = def.cast_final_safe<BatchedMatrixMul>();
    TensorLayout layout1 = inputs[0].layout, layout2 = inputs[1].layout;
    size_t dim1 = layout1.ndim, dim2 = layout2.ndim;
    if (dim1 == 0 || dim2 == 0) {
        return {{{TensorLayout(layout1.dtype), inputs[0].comp_node}}, false};
    }
    if (matmul.transposeA)
        std::swap(layout1[dim1 - 1], layout1[dim1 - 2]);
    if (matmul.transposeB)
        std::swap(layout2[dim2 - 1], layout2[dim2 - 2]);

    TensorLayout dst_layout(layout1.dtype);
    size_t di = 0;
    if (dim1 > dim2) {
        for (size_t i = 0; i < dim1 - 2; i++)
            dst_layout[di++] = layout1[i];
    } else {
        for (size_t i = 0; i < dim2 - 2; i++)
            dst_layout[di++] = layout2[i];
    }
    if (dim1 > 1)
        dst_layout[di++] = layout1[dim1 - 2];
    if (dim2 > 1)
        dst_layout[di++] = layout2[dim2 - 1];
    dst_layout.ndim = di;
    dst_layout.init_contiguous_stride();

    SmallVector<LogicalTensorDesc> out_descs(1u);
    out_descs[0] = {dst_layout, inputs[0].comp_node};
    return {out_descs, true};
}

SmallVector<TensorPtr> apply_on_physical_tensor(
        const OpDef& def, const SmallVector<TensorPtr>& inputs,
        SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
    auto&& matmul = def.cast_final_safe<BatchedMatrixMul>();
    auto&& cn = inputs[0]->comp_node();

    TensorLayout layout1 = inputs[0]->layout(), layout2 = inputs[1]->layout();
    size_t dim1 = layout1.ndim, dim2 = layout2.ndim;

    bool remove_row = false, remove_col = false;
    if (dim1 == 1) {
        dim1 = 2;
        remove_row = true;
    }
    if (dim2 == 1) {
        dim2 = 2;
        remove_col = true;
    }

    if (remove_row)
        layout1.add_axis_cont_inplace(0);
    if (remove_col)
        layout2.add_axis_inplace(1, 1, 1);

    TensorShape tshp, batch_shp;
    size_t j = 0;
    if (dim1 > dim2) {
        for (size_t i = 0; i < dim1 - 2; i++)
            tshp[j++] = layout1.shape[i];
        batch_shp = tshp;
        batch_shp.ndim = dim1 - 2;
        tshp[j++] = layout2[layout2.ndim - 2];
        tshp[j++] = layout2[layout2.ndim - 1];
        tshp.ndim = j;
        layout2 = layout2.broadcast(tshp);
    }
    if (dim2 > dim1) {
        for (size_t i = 0; i < dim2 - 2; i++)
            tshp[j++] = layout2.shape[i];
        batch_shp = tshp;
        batch_shp.ndim = dim2 - 2;
        tshp[j++] = layout1[layout1.ndim - 2];
        tshp[j++] = layout1[layout1.ndim - 1];
        tshp.ndim = j;
        layout1 = layout1.broadcast(tshp);
    }
    if (dim1 == dim2) {
        for (size_t i = 0; i < dim1 - 2; i++)
            tshp[j++] = layout1.shape[i];
        batch_shp = tshp;
        batch_shp.ndim = dim1 - 2;
    }

    TensorShape shp1 = batch_shp, shp2 = batch_shp;
    shp1.ndim += 2;
    shp2.ndim += 2;

    size_t maxdim = dim1 > dim2 ? dim1 : dim2;
    size_t nbatch = batch_shp[0];
    auto inp1 = inputs[0], inp2 = inputs[1];
    if (maxdim > 3) {
        nbatch = std::accumulate(
                batch_shp.shape, batch_shp.shape + batch_shp.ndim, (size_t)1,
                std::multiplies<size_t>());
        TensorLayout layout_a;
        TensorShape nl1 = TensorShape(
                {nbatch, layout1[layout1.ndim - 2], layout1[layout1.ndim - 1]});
        if (!layout1.try_reshape(layout_a, nl1)) {
            inp1 = Tensor::make(inputs[0]->blob(), inputs[0]->offset(), layout1);
            inp1->to_contiguous_inplace();
            layout1 = inp1->layout();
        }
        layout1 = layout_a;

        TensorShape nl2 = TensorShape(
                {nbatch, layout2[layout2.ndim - 2], layout2[layout2.ndim - 1]});
        if (!layout2.try_reshape(layout_a, nl2)) {
            inp2 = Tensor::make(inputs[1]->blob(), inputs[1]->offset(), layout2);
            inp2->to_contiguous_inplace();
            layout2 = inp2->layout();
        }
        layout2 = layout_a;
    }

    TensorLayout dst_layout(
            {nbatch, matmul.transposeA ? layout1[2] : layout1[1],
             matmul.transposeB ? layout2[1] : layout2[2]},
            layout1.dtype);
    dst_layout.init_contiguous_stride();

    if (dim1 == 0 || dim2 == 0 || layout1[layout1.ndim - 1] == 0) {
        DeviceTensorND out =
                BlobManager::inst()->alloc_workspace_with_defrag(cn, dst_layout);
        if (!out.empty()) {
            dev_tensor_memset(out, 0);
        }
        return {Tensor::make(out)};
    }

    using TensorND = megdnn::TensorND;
    TensorND inp_nd1 = inp1->dnn_tensor();
    inp_nd1.layout = layout1;
    TensorND inp_nd2 = inp2->dnn_tensor();
    inp_nd2.layout = layout2;

    DeviceTensorND out =
            BlobManager::inst()->alloc_workspace_with_defrag(cn, dst_layout);

    DnnOprCaller<megdnn::BatchedMatrixMul> dnn_opr(cn);
    dnn_opr.op->param() = matmul.param();

    size_t sz = setup_algo<megdnn::BatchedMatrixMul>(
            {layout1, layout2, dst_layout}, dnn_opr.op.get(), 0, false, false, cn,
            matmul.policy(), false);
    TensorLayout w_layout({sz}, dtype::Byte());
    auto dnn_wk = dnn_opr.create_workspace(w_layout);
    dnn_opr.op->exec(inp_nd1, inp_nd2, out.as_megdnn(), dnn_wk);

    shp1[shp1.ndim - 2] = dst_layout[dst_layout.ndim - 2];
    shp1[shp1.ndim - 1] = dst_layout[dst_layout.ndim - 1];
    if (maxdim > 3) {
        dst_layout = dst_layout.reshape(shp1);
    }
    if (remove_row) {
        dst_layout = dst_layout.remove_axis(maxdim - 2);
    }
    if (remove_col) {
        dst_layout = dst_layout.remove_axis(maxdim - 1);
    }

    return {Tensor::make(out.sub(SubTensorSpec::make_from_layout(dst_layout)))};
}

SmallVector<VarNode::LayoutConstraintCallback> get_input_layout_constraint(
        const OpDef& def, const SmallVector<TensorPtr>& inputs) {
    SmallVector<VarNode::LayoutConstraintCallback> layout_checker(inputs.size());
    layout_checker[0] = layout_checker[1] = [](const TensorLayout& layout) {
        return layout.is_contiguous();
    };
    return layout_checker;
}

OP_TRAIT_REG(BatchedMatrixMul, BatchedMatrixMul)
        .apply_on_var_node(apply_on_var_node)
        .infer_output_attrs_fallible(infer_output_attrs_fallible)
        .get_input_layout_constraint(get_input_layout_constraint)
        .apply_on_physical_tensor(apply_on_physical_tensor)
        .fallback();
}  // namespace batched_matrix_mul
}  // namespace

namespace {
namespace dot {

auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& op = def.cast_final_safe<Dot>();
    mgb_assert(inputs.size() == 2);
    OperatorNodeConfig config{op.make_name()};
    return opr::Dot::make(inputs[0], inputs[1], config);
}

SmallVector<TensorPtr> apply_on_physical_tensor(
        const OpDef& def, const SmallVector<TensorPtr>& inputs,
        SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
    auto comp_node = inputs[0]->comp_node();
    using TensorND = megdnn::TensorND;
    SmallVector<TensorND> inp_tensornds;
    inp_tensornds.reserve(inputs.size());
    DnnOprCaller<megdnn::Dot> dnn_opr(comp_node);
    for (unsigned i = 0; i < inputs.size(); ++i) {
        auto dnn_ten = inputs[i]->dnn_tensor();
        inp_tensornds.push_back(dnn_ten);
    }
    TensorLayout oup_layout{inputs[0]->dtype()};
    auto inp1_tensor = inputs[0]->dnn_tensor();
    auto inp2_tensor = inputs[1]->dnn_tensor();
    dnn_opr.op->deduce_layout(inp1_tensor.layout, inp2_tensor.layout, oup_layout);
    if (inputs[0]->layout().is_empty() || inputs[1]->layout().is_empty()) {
        DeviceTensorND out = BlobManager::inst()->alloc_workspace_with_defrag(
                comp_node, oup_layout);
        if (!out.empty()) {
            dev_tensor_memset(out, 0);
        }
        return {Tensor::make(out)};
    }
    auto sz = dnn_opr.op->get_workspace_in_bytes(
            inp_tensornds[0].layout, inp_tensornds[1].layout, output_descs[0].layout);
    DeviceTensorND out_devtensor =
            BlobManager::inst()->alloc_workspace_with_defrag(comp_node, oup_layout);
    TensorLayout w_layout({sz}, dtype::Byte());
    auto dnn_wk = dnn_opr.create_workspace(w_layout);
    dnn_opr.op->exec(
            inp_tensornds[0], inp_tensornds[1], out_devtensor.as_megdnn(), dnn_wk);
    return {Tensor::make(out_devtensor)};
}

std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
        const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
    mgb_assert(
            inputs.size() == 2, "Dot expects 2 inputs; got %lu actually",
            inputs.size());
    SmallVector<LogicalTensorDesc> dests(1);
    dests[0].layout = TensorLayout(TensorShape{1}, inputs[0].layout.dtype);
    dests[0].comp_node = inputs[0].comp_node;
    bool validated = inputs[0].layout.ndim != 0 && inputs[1].layout.ndim != 0;
    return {dests, validated};
}

OP_TRAIT_REG(Dot, Dot, mgb::opr::Dot)
        .apply_on_var_node(apply_on_var_node)
        .infer_output_attrs_fallible(infer_output_attrs_fallible)
        .apply_on_physical_tensor(apply_on_physical_tensor)
        .fallback();

}  // namespace dot
}  // anonymous namespace
}  // namespace imperative
}  // namespace mgb
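The output-shape deduction in infer_output_attrs_fallible above follows the familiar matmul rule. A small Python sketch of the MatrixMul case for intuition (a simplified model only; it ignores the transposed-batch and empty-tensor corner cases handled in the C++ code):

    def matmul_out_shape(shape1, shape2, transpose_a=False, transpose_b=False):
        # shape1 may carry leading batch-like dims that stay in the output;
        # shape2 is 1-D or 2-D on the MatrixMul path.
        s1, s2 = list(shape1), list(shape2)
        if transpose_a and len(s1) == 2:
            s1[0], s1[1] = s1[1], s1[0]
        if transpose_b and len(s2) == 2:
            s2[0], s2[1] = s2[1], s2[0]
        assert s1[-1] == s2[0], "contracted dimensions must match"
        out = s1[:-1]
        if len(s2) == 2:
            out.append(s2[1])
        return tuple(out)

    print(matmul_out_shape((8, 3, 4), (4, 5)))  # (8, 3, 5)
    print(matmul_out_shape((3, 4), (4,)))       # (3,)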
imperative/src/impl/ops/reduce.cpp
View file @ f7e10ea8

@@ -123,7 +123,6 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
    inputs[0]->dev_tensor().reset(inputs[0]->dev_tensor().storage(), src);
    auto mode = op_def.param().mode;

-    DnnOprCaller<megdnn::Fill> fill_op(comp_node);
    if (!keepdim && src.ndim > 1) {
        layout.remove_axis_inplace(axis);

@@ -135,12 +134,12 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
    switch (mode) {
        case Reduce::Mode::SUM:
            if (!out.empty()) {
-                fill_op.op->param() = 0;
-                fill_op.op->exec(out.as_megdnn(), {});
+                dev_tensor_memset(out, 0);
            }
            break;
        case Reduce::Mode::PRODUCT:
            if (!out.empty()) {
+                DnnOprCaller<megdnn::Fill> fill_op(comp_node);
                fill_op.op->param() = 1;
                fill_op.op->exec(out.as_megdnn(), {});
            }
imperative/src/impl/ops/specializations.cpp
View file @ f7e10ea8

@@ -319,34 +319,6 @@ OP_TRAIT_REG(BatchConvBias, BatchConvBias)
}  // namespace batch_conv_bias
}  // namespace

-namespace {
-namespace matrix_mul {
-auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
-    auto&& matmul = static_cast<const MatrixMul&>(def);
-    mgb_assert(inputs.size() == 2);
-    OperatorNodeConfig config{matmul.make_name()};
-    return opr::MatrixMul::make(
-            inputs[0], inputs[1], matmul.param(), matmul.policy(), config);
-}
-OP_TRAIT_REG(MatrixMul, MatrixMul).apply_on_var_node(apply_on_var_node).fallback();
-}  // namespace matrix_mul
-}  // namespace
-
-namespace {
-namespace batched_matrix_mul {
-auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
-    auto&& matmul = static_cast<const BatchedMatrixMul&>(def);
-    mgb_assert(inputs.size() == 2);
-    OperatorNodeConfig config{matmul.make_name()};
-    return opr::BatchedMatrixMul::make(
-            inputs[0], inputs[1], matmul.param(), matmul.policy(), config);
-}
-OP_TRAIT_REG(BatchedMatrixMul, BatchedMatrixMul)
-        .apply_on_var_node(apply_on_var_node)
-        .fallback();
-}  // namespace batched_matrix_mul
-}  // namespace

namespace {
namespace argsort {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
imperative/src/impl/transformations/dtype_promote.cpp
View file @ f7e10ea8

@@ -183,6 +183,57 @@ ValueRefList convolution_rule(const OpDef& op, Span<ValueRef> inputs) {
    return imperative::apply(op, converted);
}

ValueRefList matmul_rule(const OpDef& op, Span<ValueRef> inputs) {
    auto&& conv_op = const_cast<MatrixMul&>(op.cast_final_safe<MatrixMul>());
    SmallVector<DType> dtypes = get_value_dtypes(inputs);
    mgb::DType target_dtype;

    if (DTypePromoteCfg::amp_dtype_autocast_enabled) {
        conv_op.compute_mode = MatrixMul::ComputeMode::FLOAT32;
        target_dtype = DTypePromoteCfg::amp_low_prec_dtype;
    } else {
        target_dtype = get_promoted_dtype(dtypes);
    }

    ValueRefList converted(inputs.size());
    for (size_t i = 0; i < inputs.size(); ++i) {
        if (dtypes[i] != target_dtype) {
            converted[i] = imperative::apply(
                    ApplyOp(*TypeCvt::make(target_dtype)), inputs[i])[0];
        } else {
            converted[i] = inputs[i];
        }
    }

    return imperative::apply(op, converted);
}

ValueRefList batch_matmul_rule(const OpDef& op, Span<ValueRef> inputs) {
    auto&& conv_op =
            const_cast<BatchedMatrixMul&>(op.cast_final_safe<BatchedMatrixMul>());
    SmallVector<DType> dtypes = get_value_dtypes(inputs);
    mgb::DType target_dtype;

    if (DTypePromoteCfg::amp_dtype_autocast_enabled) {
        conv_op.compute_mode = BatchedMatrixMul::ComputeMode::FLOAT32;
        target_dtype = DTypePromoteCfg::amp_low_prec_dtype;
    } else {
        target_dtype = get_promoted_dtype(dtypes);
    }

    ValueRefList converted(inputs.size());
    for (size_t i = 0; i < inputs.size(); ++i) {
        if (dtypes[i] != target_dtype) {
            converted[i] = imperative::apply(
                    ApplyOp(*TypeCvt::make(target_dtype)), inputs[i])[0];
        } else {
            converted[i] = inputs[i];
        }
    }

    return imperative::apply(op, converted);
}

// differ from Convolution, ConvolutionBackwardData is used in both
// functional.conv_transpose2d and quantize.conv_transpose2d
ValueRefList convolution_backward_rule(const OpDef& op, Span<ValueRef> inputs) {

@@ -259,8 +310,11 @@ struct DTypePromoteRuleRegistry {
    DTypePromoteRuleRegistry() {
        register_dtype_promote_rule<Elemwise>(elemwise_rule);
        register_dtype_promote_rule<Concat>(naive_promote_rule);
        register_dtype_promote_rule<GroupLocal>(naive_promote_rule);
        register_dtype_promote_rule<Reduce>(reduce_rule);
        register_dtype_promote_rule<Convolution>(convolution_rule);
        register_dtype_promote_rule<MatrixMul>(matmul_rule);
        register_dtype_promote_rule<BatchedMatrixMul>(batch_matmul_rule);
        register_dtype_promote_rule<ConvolutionBackwardData>(convolution_backward_rule);
        register_dtype_promote_rule<BatchNorm>(batch_norm_rule);
        register_dtype_promote_rule<Convolution3D>(naive_promote_rule);
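A hedged usage sketch of what these promotion rules mean at the Python level, assuming the standard megengine.amp.autocast context manager and float16/float32 inputs (the exact output dtype depends on the amp low-precision dtype configured; values below are the expected defaults):

    import numpy as np
    import megengine as mge
    import megengine.functional as F
    import megengine.amp as amp

    a = mge.tensor(np.random.rand(2, 3).astype("float16"))
    b = mge.tensor(np.random.rand(3, 4).astype("float32"))

    # without autocast: inputs are promoted to a common dtype before MatrixMul
    print(F.matmul(a, b).dtype)  # expected: float32

    # with autocast: inputs are cast to the low-precision amp dtype while the op
    # itself accumulates with ComputeMode::FLOAT32
    with amp.autocast():
        c = mge.tensor(np.random.rand(4, 5).astype("float32"))
        print(F.matmul(b, c).dtype)  # expected: float16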