Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
a72e0cb5
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
410
Star
4707
Fork
583
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
a72e0cb5
编写于
11月 09, 2021
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(imperative,src): add jit builder for custom op
GitOrigin-RevId: 3bb0b46311ce94e2255fddf5a45c8c8007b99c09
上级
93310c0e
变更
21
展开全部
隐藏空白更改
内联
并排
Showing
21 changed file
with
1611 addition
and
183 deletion
+1611
-183
CMakeLists.txt
CMakeLists.txt
+8
-0
imperative/CMakeLists.txt
imperative/CMakeLists.txt
+3
-0
imperative/python/megengine/core/ops/custom.py
imperative/python/megengine/core/ops/custom.py
+12
-1
imperative/python/megengine/utils/custom_op_tools.py
imperative/python/megengine/utils/custom_op_tools.py
+909
-0
imperative/python/src/ops.cpp
imperative/python/src/ops.cpp
+7
-0
imperative/python/test/unit/core/custom_opsrc/elem_add.cpp
imperative/python/test/unit/core/custom_opsrc/elem_add.cpp
+140
-0
imperative/python/test/unit/core/custom_opsrc/matmul_scale.cpp
...ative/python/test/unit/core/custom_opsrc/matmul_scale.cpp
+65
-0
imperative/python/test/unit/core/custom_opsrc/matmul_scale.cu
...rative/python/test/unit/core/custom_opsrc/matmul_scale.cu
+97
-0
imperative/python/test/unit/core/custom_opsrc/matmul_scale.h
imperative/python/test/unit/core/custom_opsrc/matmul_scale.h
+24
-0
imperative/python/test/unit/core/test_custom_op.py
imperative/python/test/unit/core/test_custom_op.py
+111
-0
scripts/whl/macos/macos_build_whl.sh
scripts/whl/macos/macos_build_whl.sh
+1
-0
scripts/whl/manylinux2014/do_build_common.sh
scripts/whl/manylinux2014/do_build_common.sh
+1
-0
scripts/whl/windows/windows_build_whl.sh
scripts/whl/windows/windows_build_whl.sh
+4
-1
src/CMakeLists.txt
src/CMakeLists.txt
+13
-10
src/custom/impl/manager.cpp
src/custom/impl/manager.cpp
+19
-7
src/custom/include/megbrain/custom/custom.h
src/custom/include/megbrain/custom/custom.h
+2
-1
src/custom/include/megbrain/custom/op.h
src/custom/include/megbrain/custom/op.h
+47
-53
src/custom/include/megbrain/custom/param.h
src/custom/include/megbrain/custom/param.h
+10
-10
src/custom/include/megbrain/custom/param_val.h
src/custom/include/megbrain/custom/param_val.h
+40
-31
src/custom/include/megbrain/custom/tensor.h
src/custom/include/megbrain/custom/tensor.h
+70
-64
src/custom/include/megbrain/custom/utils.h
src/custom/include/megbrain/custom/utils.h
+28
-5
未找到文件。
CMakeLists.txt
浏览文件 @
a72e0cb5
...
...
@@ -1145,6 +1145,14 @@ if(TARGET _imperative_rt)
COMMAND
${
CMAKE_COMMAND
}
-E create_symlink
${
CMAKE_CURRENT_BINARY_DIR
}
/imperative/python/
${
PACKAGE_NAME
}
/version.py
${
CMAKE_CURRENT_SOURCE_DIR
}
/imperative/python/
${
PACKAGE_NAME
}
/version.py
COMMAND
${
CMAKE_COMMAND
}
-E create_symlink
${
CMAKE_CURRENT_SOURCE_DIR
}
/src/custom/include
${
CMAKE_CURRENT_SOURCE_DIR
}
/imperative/python/
${
PACKAGE_NAME
}
/core/include
COMMAND
${
CMAKE_COMMAND
}
-E make_directory
${
CMAKE_CURRENT_SOURCE_DIR
}
/imperative/python/
${
PACKAGE_NAME
}
/core/lib
COMMAND
${
CMAKE_COMMAND
}
-E create_symlink
${
CMAKE_CURRENT_BINARY_DIR
}
/src/$<TARGET_FILE_NAME:
${
MGE_SHARED_LIB
}
>
${
CMAKE_CURRENT_SOURCE_DIR
}
/imperative/python/
${
PACKAGE_NAME
}
/core/lib/$<TARGET_FILE_NAME:
${
MGE_SHARED_LIB
}
>
DEPENDS _imperative_rt
VERBATIM
)
...
...
imperative/CMakeLists.txt
浏览文件 @
a72e0cb5
...
...
@@ -67,8 +67,11 @@ add_custom_command(
COMMAND
${
CMAKE_COMMAND
}
-E copy
${
PROJECT_SOURCE_DIR
}
/LICENSE
${
PROJECT_SOURCE_DIR
}
/ACKNOWLEDGMENTS
${
PROJECT_BINARY_DIR
}
COMMAND
${
CMAKE_COMMAND
}
-E remove -f
${
CMAKE_CURRENT_SOURCE_DIR
}
/python/megengine/core/$<TARGET_FILE_NAME:
${
MODULE_NAME
}
>
# clean develop
COMMAND
${
CMAKE_COMMAND
}
-E remove -f
${
CMAKE_CURRENT_SOURCE_DIR
}
/python/megengine/version.py
# clean develop
COMMAND
${
CMAKE_COMMAND
}
-E remove -f
${
CMAKE_CURRENT_SOURCE_DIR
}
/python/megengine/core/include
# clean develop
COMMAND
${
CMAKE_COMMAND
}
-E remove -f
${
CMAKE_CURRENT_SOURCE_DIR
}
/python/megengine/core/lib
# clean develop
COMMAND
${
CMAKE_COMMAND
}
-E copy_directory
${
CMAKE_CURRENT_SOURCE_DIR
}
/python/megengine
${
CMAKE_CURRENT_BINARY_DIR
}
/python/megengine
COMMAND
${
CMAKE_COMMAND
}
-E copy_directory
${
CMAKE_CURRENT_SOURCE_DIR
}
/python/test
${
CMAKE_CURRENT_BINARY_DIR
}
/python/test
COMMAND
${
CMAKE_COMMAND
}
-E copy_directory
${
PROJECT_SOURCE_DIR
}
/src/custom/include
${
CMAKE_CURRENT_BINARY_DIR
}
/python/megengine/core/include
COMMAND
${
CMAKE_COMMAND
}
-E copy
${
CMAKE_CURRENT_SOURCE_DIR
}
/python/setup.py
${
CMAKE_CURRENT_BINARY_DIR
}
/python/setup.py
COMMAND
${
CMAKE_COMMAND
}
-E copy
${
CMAKE_CURRENT_SOURCE_DIR
}
/python/requires.txt
${
CMAKE_CURRENT_BINARY_DIR
}
/python/requires.txt
COMMAND
${
CMAKE_COMMAND
}
-E copy
${
CMAKE_CURRENT_SOURCE_DIR
}
/python/requires-style.txt
${
CMAKE_CURRENT_BINARY_DIR
}
/python/requires-style.txt
...
...
imperative/python/megengine/core/ops/custom.py
浏览文件 @
a72e0cb5
...
...
@@ -7,11 +7,14 @@
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import
os
from
.._imperative_rt.ops._custom
import
(
_get_custom_op_list
,
_install
,
_make_custom_op
,
_uninstall
,
get_custom_op_abi_tag
,
)
__all__
=
[
"load"
]
...
...
@@ -25,8 +28,16 @@ def _gen_custom_op_maker(custom_op_name):
def
load
(
lib_path
):
op_in_this_lib
=
_install
(
lib_path
[
0
:
-
3
],
lib_path
)
lib_path
=
os
.
path
.
abspath
(
lib_path
)
lib_name
=
os
.
path
.
splitext
(
lib_path
)[
0
]
op_in_this_lib
=
_install
(
lib_name
,
lib_path
)
for
op
in
op_in_this_lib
:
op_maker
=
_gen_custom_op_maker
(
op
)
globals
()[
op
]
=
op_maker
__all__
.
append
(
op
)
def
unload
(
lib_path
):
lib_path
=
os
.
path
.
abspath
(
lib_path
)
lib_name
=
os
.
path
.
splitext
(
lib_path
)[
0
]
_uninstall
(
lib_name
)
imperative/python/megengine/utils/custom_op_tools.py
0 → 100644
浏览文件 @
a72e0cb5
此差异已折叠。
点击以展开。
imperative/python/src/ops.cpp
浏览文件 @
a72e0cb5
...
...
@@ -766,6 +766,13 @@ void init_custom(pybind11::module m) {
m
.
def
(
"_install"
,
&
install_custom
);
m
.
def
(
"_uninstall"
,
&
uninstall_custom
);
m
.
def
(
"_get_custom_op_list"
,
&
get_custom_op_list
);
m
.
def
(
"get_custom_op_abi_tag"
,
[](
void
)
->
int
{
int
ret
=
0
;
#ifdef _GLIBCXX_USE_CXX11_ABI
ret
=
_GLIBCXX_USE_CXX11_ABI
;
#endif
return
ret
;
});
static
PyMethodDef
method_def
=
{
#ifdef METH_FASTCALL
...
...
imperative/python/test/unit/core/custom_opsrc/elem_add.cpp
0 → 100644
浏览文件 @
a72e0cb5
/**
* \file imperative/python/test/unit/core/custom_opsrc/elem_add.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "megbrain/custom/custom.h"
CUSTOM_OP_REG_BEGIN
(
ElemAddSmooth
)
void
forward_device_infer
(
const
std
::
vector
<
Device
>&
inputs
,
const
Param
&
params
,
std
::
vector
<
Device
>&
outputs
)
{
outputs
[
0
]
=
inputs
[
0
];
}
void
forward_shape_infer
(
const
std
::
vector
<
Shape
>&
inputs
,
const
Param
&
params
,
std
::
vector
<
Shape
>&
outputs
)
{
outputs
[
0
]
=
inputs
[
0
];
}
void
forward_dtype_infer
(
const
std
::
vector
<
DType
>&
inputs
,
const
Param
&
params
,
std
::
vector
<
DType
>&
outputs
)
{
outputs
[
0
]
=
inputs
[
0
];
}
void
forward_format_infer
(
const
std
::
vector
<
Format
>&
inputs
,
const
Param
&
params
,
std
::
vector
<
Format
>&
outputs
)
{
outputs
[
0
]
=
inputs
[
0
];
}
template
<
typename
scalar_t
>
void
forward_kernel
(
const
scalar_t
*
input0
,
const
scalar_t
*
input1
,
scalar_t
*
output
,
size_t
len
,
float
smooth
)
{
for
(
size_t
i
=
0
;
i
<
len
;
++
i
)
{
output
[
i
]
=
input0
[
i
]
+
input1
[
i
];
if
(
output
[
i
]
<
0
)
output
[
i
]
+=
smooth
;
else
output
[
i
]
-=
smooth
;
}
}
void
forward_compute
(
const
std
::
vector
<
Tensor
>&
inputs
,
const
Param
&
params
,
std
::
vector
<
Tensor
>&
outputs
)
{
DISPATCH_SIGN_INT_AND_FLOAT_TYPES
(
outputs
[
0
].
dtype
(),
"forward_compute"
,
([
&
]()
{
forward_kernel
<
scalar_t
>
(
inputs
[
0
].
data
<
scalar_t
>
(),
inputs
[
1
].
data
<
scalar_t
>
(),
outputs
[
0
].
data
<
scalar_t
>
(),
outputs
[
0
].
size
(),
params
[
"smooth"
].
as
<
float
>
());
}));
}
CUSTOM_OP_REG
(
ElemAddSmoothForward
)
.
set_description
(
"Custom ElemAdd Operator With a Smooth Parameter, "
"which is used to verify the CPU kernel"
)
.
add_input
(
"lhs"
)
.
add_input
(
"rhs"
)
.
add_output
(
"output"
)
.
add_param
(
"smooth"
,
0.
f
)
.
set_device_infer
(
forward_device_infer
)
.
set_shape_infer
(
forward_shape_infer
)
.
set_dtype_infer
(
forward_dtype_infer
)
.
set_format_infer
(
forward_format_infer
)
.
set_compute
(
forward_compute
);
void
backward_device_infer
(
const
std
::
vector
<
Device
>&
ograds
,
const
Param
&
params
,
std
::
vector
<
Device
>&
igrads
)
{
igrads
[
0
]
=
ograds
[
0
];
igrads
[
1
]
=
ograds
[
0
];
}
void
backward_shape_infer
(
const
std
::
vector
<
Shape
>&
ograds
,
const
Param
&
params
,
std
::
vector
<
Shape
>&
igrads
)
{
igrads
[
0
]
=
ograds
[
0
];
igrads
[
1
]
=
ograds
[
0
];
}
void
backward_dtype_infer
(
const
std
::
vector
<
DType
>&
ograds
,
const
Param
&
params
,
std
::
vector
<
DType
>&
igrads
)
{
igrads
[
0
]
=
ograds
[
0
];
igrads
[
1
]
=
ograds
[
0
];
}
void
backward_format_infer
(
const
std
::
vector
<
Format
>&
ograds
,
const
Param
&
params
,
std
::
vector
<
Format
>&
igrads
)
{
igrads
[
0
]
=
ograds
[
0
];
igrads
[
1
]
=
ograds
[
0
];
}
template
<
typename
scalar_t
>
void
backward_kernel
(
const
scalar_t
*
ograd
,
scalar_t
*
igrad0
,
scalar_t
*
igrad1
,
size_t
len
)
{
for
(
size_t
i
=
0
;
i
<
len
;
++
i
)
{
igrad0
[
i
]
=
ograd
[
i
];
igrad1
[
i
]
=
ograd
[
i
];
}
}
void
backward_compute
(
const
std
::
vector
<
Tensor
>&
ograds
,
const
Param
&
params
,
std
::
vector
<
Tensor
>&
igrads
)
{
DISPATCH_SIGN_INT_AND_FLOAT_TYPES
(
igrads
[
0
].
dtype
(),
"backward_compute"
,
([
&
]()
{
backward_kernel
<
scalar_t
>
(
ograds
[
0
].
data
<
scalar_t
>
(),
igrads
[
0
].
data
<
scalar_t
>
(),
igrads
[
1
].
data
<
scalar_t
>
(),
igrads
[
0
].
size
());
}));
}
CUSTOM_OP_REG
(
ElemAddSmoothBackward
)
.
set_description
(
"Custom ElemAdd Operator With a Smooth Parameter, "
"which is used to verify the CPU kernel"
)
.
add_input
(
"ograd"
)
.
add_output
(
"igrad_lhs"
)
.
add_output
(
"igrad_rhs"
)
.
set_device_infer
(
backward_device_infer
)
.
set_shape_infer
(
backward_shape_infer
)
.
set_dtype_infer
(
backward_dtype_infer
)
.
set_format_infer
(
backward_format_infer
)
.
set_compute
(
backward_compute
);
CUSTOM_OP_REG_END
(
ElemAddSmooth
)
imperative/python/test/unit/core/custom_opsrc/matmul_scale.cpp
0 → 100644
浏览文件 @
a72e0cb5
/**
* \file imperative/python/test/unit/core/custom_opsrc/matmul_scale.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "./matmul_scale.h"
#include "megbrain/custom/custom.h"
CUSTOM_OP_REG_BEGIN
(
MatMulScale
)
void
forward_shape_infer
(
const
std
::
vector
<
Shape
>&
inputs
,
const
Param
&
params
,
std
::
vector
<
Shape
>&
outputs
)
{
outputs
[
0
]
=
{
inputs
[
0
][
0
],
inputs
[
1
][
1
]};
}
void
forward_compute
(
const
std
::
vector
<
Tensor
>&
inputs
,
const
Param
&
params
,
std
::
vector
<
Tensor
>&
outputs
)
{
matmul_forward_helper
(
inputs
[
0
],
inputs
[
1
],
outputs
[
0
],
inputs
[
0
].
shape
()[
0
],
inputs
[
0
].
shape
()[
1
],
inputs
[
1
].
shape
()[
1
],
params
[
"scale"
].
as
<
float
>
());
}
CUSTOM_OP_REG
(
MatMulScaleForward
)
.
add_inputs
(
2
)
.
add_outputs
(
1
)
.
add_param
(
"scale"
,
1.0
f
)
.
set_shape_infer
(
forward_shape_infer
)
.
set_compute
(
"cuda"
,
forward_compute
);
void
backward_shape_infer
(
const
std
::
vector
<
Shape
>&
ograd_and_inputs
,
const
Param
&
params
,
std
::
vector
<
Shape
>&
outputs
)
{
outputs
[
0
]
=
ograd_and_inputs
[
1
];
outputs
[
1
]
=
ograd_and_inputs
[
2
];
}
void
backward_compute
(
const
std
::
vector
<
Tensor
>&
ograd_and_inputs
,
const
Param
&
params
,
std
::
vector
<
Tensor
>&
igrads
)
{
matmul_backward_lhs_helper
(
ograd_and_inputs
[
2
],
ograd_and_inputs
[
0
],
igrads
[
0
],
ograd_and_inputs
[
1
].
shape
()[
0
],
ograd_and_inputs
[
1
].
shape
()[
1
],
ograd_and_inputs
[
2
].
shape
()[
1
],
params
[
"scale"
].
as
<
float
>
());
matmul_backward_rhs_helper
(
ograd_and_inputs
[
1
],
ograd_and_inputs
[
0
],
igrads
[
1
],
ograd_and_inputs
[
1
].
shape
()[
0
],
ograd_and_inputs
[
1
].
shape
()[
1
],
ograd_and_inputs
[
2
].
shape
()[
1
],
params
[
"scale"
].
as
<
float
>
());
}
CUSTOM_OP_REG
(
MatMulScaleBackward
)
.
add_inputs
(
3
)
.
add_outputs
(
2
)
.
add_param
(
"scale"
,
1.0
f
)
.
set_shape_infer
(
backward_shape_infer
)
.
set_compute
(
"cuda"
,
backward_compute
);
CUSTOM_OP_REG_END
(
MatMulScale
)
imperative/python/test/unit/core/custom_opsrc/matmul_scale.cu
0 → 100644
浏览文件 @
a72e0cb5
/**
* \file imperative/python/test/unit/core/custom_opsrc/matmul_scale.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include <cuda.h>
#include <cuda_runtime.h>
#include <stdio.h>
#include "./matmul_scale.h"
using
namespace
custom
;
// matmul_forward for Mat_mxk * Mat_k*n
template
<
typename
T
>
__global__
void
matmul_forward_naive
(
const
T
*
lhs
,
const
T
*
rhs
,
T
*
res
,
size_t
M
,
size_t
K
,
size_t
N
,
float
scale
)
{
int
row
=
blockIdx
.
y
*
blockDim
.
y
+
threadIdx
.
y
;
int
col
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
T
acc
=
0
;
for
(
int
i
=
0
;
i
<
K
;
++
i
)
acc
+=
lhs
[
row
*
K
+
i
]
*
rhs
[
i
*
N
+
col
];
res
[
row
*
N
+
col
]
=
acc
*
scale
;
}
// matmul_backward_lhs for Mat_mxk * Mat_k*n = Mat_mxn
// that is Mat_mxn * Mat_nxk
template
<
typename
T
>
__global__
void
matmul_backward_lhs_naive
(
const
T
*
rhs
,
const
T
*
ograd
,
T
*
lhs_grad
,
size_t
M
,
size_t
K
,
size_t
N
,
float
scale
)
{
int
row
=
blockIdx
.
y
*
blockDim
.
y
+
threadIdx
.
y
;
int
col
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
T
acc
=
0
;
for
(
int
i
=
0
;
i
<
N
;
++
i
)
acc
+=
ograd
[
row
*
N
+
i
]
*
rhs
[
col
*
N
+
i
];
lhs_grad
[
row
*
K
+
col
]
=
acc
/
scale
;
}
// matmul_backward_rhs for Mat_mxk * Mat_k*n = Mat_mxn
// that is Mat_kxm * Mat_mxn
template
<
typename
T
>
__global__
void
matmul_backward_rhs_naive
(
const
T
*
lhs
,
const
T
*
ograd
,
T
*
rhs_grad
,
size_t
M
,
size_t
K
,
size_t
N
,
float
scale
)
{
int
row
=
blockIdx
.
y
*
blockDim
.
y
+
threadIdx
.
y
;
int
col
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
T
acc
=
0
;
for
(
int
i
=
0
;
i
<
M
;
++
i
)
acc
+=
lhs
[
i
*
K
+
row
]
*
ograd
[
i
*
N
+
col
];
rhs_grad
[
row
*
N
+
col
]
=
acc
/
scale
;
}
void
matmul_forward_helper
(
const
Tensor
&
lhs
,
const
Tensor
&
rhs
,
Tensor
&
res
,
size_t
M
,
size_t
K
,
size_t
N
,
float
scale
)
{
dim3
block
(
1
,
1
);
dim3
grid
(
N
/
block
.
x
,
M
/
block
.
y
);
DISPATCH_INT_AND_FLOAT_TYPES
(
res
.
dtype
(),
"matmul_forward"
,
([
&
]()
{
matmul_forward_naive
<
scalar_t
><<<
grid
,
block
>>>
(
lhs
.
data
<
scalar_t
>
(),
rhs
.
data
<
scalar_t
>
(),
res
.
data
<
scalar_t
>
(),
M
,
K
,
N
,
scale
);
}));
}
void
matmul_backward_lhs_helper
(
const
Tensor
&
rhs
,
const
Tensor
&
ograd
,
Tensor
&
lhs_grad
,
size_t
M
,
size_t
K
,
size_t
N
,
float
scale
)
{
dim3
block
(
1
,
1
);
dim3
grid
(
K
/
block
.
x
,
M
/
block
.
y
);
DISPATCH_INT_AND_FLOAT_TYPES
(
lhs_grad
.
dtype
(),
"matmul_backward_lhs"
,
([
&
]()
{
matmul_backward_lhs_naive
<
scalar_t
><<<
grid
,
block
>>>
(
rhs
.
data
<
scalar_t
>
(),
ograd
.
data
<
scalar_t
>
(),
lhs_grad
.
data
<
scalar_t
>
(),
M
,
K
,
N
,
scale
);
}));
}
void
matmul_backward_rhs_helper
(
const
Tensor
&
lhs
,
const
Tensor
&
ograd
,
Tensor
&
rhs_grad
,
size_t
M
,
size_t
K
,
size_t
N
,
float
scale
)
{
dim3
block
(
1
,
1
);
dim3
grid
(
N
/
block
.
x
,
K
/
block
.
y
);
DISPATCH_INT_AND_FLOAT_TYPES
(
rhs_grad
.
dtype
(),
"matmul_backward_rhs"
,
([
&
]()
{
matmul_backward_rhs_naive
<
scalar_t
><<<
grid
,
block
>>>
(
lhs
.
data
<
scalar_t
>
(),
ograd
.
data
<
scalar_t
>
(),
rhs_grad
.
data
<
scalar_t
>
(),
M
,
K
,
N
,
scale
);
}));
}
imperative/python/test/unit/core/custom_opsrc/matmul_scale.h
0 → 100644
浏览文件 @
a72e0cb5
/**
* \file imperative/python/test/unit/core/custom_opsrc/matmul_scale.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "megbrain/custom/custom.h"
using
Tensor
=
custom
::
Tensor
;
void
matmul_forward_helper
(
const
Tensor
&
lhs
,
const
Tensor
&
rhs
,
Tensor
&
res
,
size_t
M
,
size_t
K
,
size_t
N
,
float
scale
);
void
matmul_backward_lhs_helper
(
const
Tensor
&
rhs
,
const
Tensor
&
ograd
,
Tensor
&
lhs_grad
,
size_t
M
,
size_t
K
,
size_t
N
,
float
scale
);
void
matmul_backward_rhs_helper
(
const
Tensor
&
lhs
,
const
Tensor
&
ograd
,
Tensor
&
rhs_grad
,
size_t
M
,
size_t
K
,
size_t
N
,
float
scale
);
imperative/python/test/unit/core/test_custom_op.py
0 → 100644
浏览文件 @
a72e0cb5
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import
os
import
platform
import
shutil
import
sys
import
numpy
as
np
import
pytest
import
megengine
import
megengine.functional
as
F
import
megengine.optimizer
as
optim
from
megengine
import
jit
from
megengine.autodiff
import
Function
,
GradManager
from
megengine.core._imperative_rt.core2
import
apply
from
megengine.core.ops
import
custom
from
megengine.device
import
get_device_count
from
megengine.module
import
Conv2d
,
Linear
,
Module
from
megengine.random
import
normal
from
megengine.tensor
import
Parameter
,
Tensor
from
megengine.utils
import
custom_op_tools
def
compare
(
ref
,
real
):
if
ref
.
shape
!=
real
.
shape
:
real
=
real
.
T
np
.
testing
.
assert_allclose
(
ref
,
real
,
rtol
=
1e-3
,
atol
=
1e-5
)
def
build_and_clean
(
test_func
):
def
wrapper
():
cur_dir_path
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
))
build_path
=
os
.
path
.
join
(
cur_dir_path
,
"custom_opsrc"
,
"build"
)
mgb_root_path
=
os
.
path
.
dirname
(
os
.
path
.
dirname
(
os
.
path
.
dirname
(
os
.
path
.
dirname
(
os
.
path
.
dirname
(
cur_dir_path
)))
)
)
extra_include_paths
=
[
os
.
path
.
join
(
mgb_root_path
,
"src"
,
"custom"
,
"include"
)]
extra_ld_flags
=
[]
if
sys
.
platform
!=
"win32"
:
ld_path
=
os
.
environ
.
get
(
"LD_LIBRARY_PATH"
)
if
ld_path
!=
None
:
ld_dirs
=
ld_path
.
split
(
":"
)
for
ld_dir
in
ld_dirs
:
if
os
.
path
.
exists
(
ld_dir
)
and
os
.
path
.
isdir
(
ld_dir
):
for
lib
in
os
.
listdir
(
ld_dir
):
if
"megengine_shared"
in
lib
:
extra_ld_flags
+=
[
"-L{} -Wl,-rpath,{}"
.
format
(
ld_dir
,
ld_dir
)
]
break
if
get_device_count
(
"gpu"
)
>
0
:
custom_opsrc
=
[
os
.
path
.
join
(
cur_dir_path
,
"custom_opsrc"
,
"matmul_scale.cpp"
),
os
.
path
.
join
(
cur_dir_path
,
"custom_opsrc"
,
"matmul_scale.cu"
),
]
else
:
custom_opsrc
=
[
os
.
path
.
join
(
cur_dir_path
,
"custom_opsrc"
,
"elem_add.cpp"
)]
lib_path
=
custom_op_tools
.
build_and_load
(
"test_op"
,
custom_opsrc
,
extra_include_paths
=
extra_include_paths
,
extra_ldflags
=
extra_ld_flags
,
build_dir
=
build_path
,
verbose
=
False
,
abi_tag
=
custom
.
get_custom_op_abi_tag
(),
)
test_func
()
custom
.
unload
(
lib_path
)
if
os
.
path
.
exists
(
build_path
):
shutil
.
rmtree
(
build_path
)
return
wrapper
@
pytest
.
mark
.
skipif
(
get_device_count
(
"gpu"
)
>
0
,
reason
=
"elem_add operator is only supported on CPU"
)
@
build_and_clean
def
test_custom_op_cpu_build
():
assert
"ElemAddSmoothForward"
in
custom
.
_get_custom_op_list
()
assert
"ElemAddSmoothBackward"
in
custom
.
_get_custom_op_list
()
assert
hasattr
(
custom
,
"ElemAddSmoothForward"
)
assert
hasattr
(
custom
,
"ElemAddSmoothBackward"
)
@
pytest
.
mark
.
skipif
(
platform
.
system
()
==
"Darwin"
,
reason
=
"GPU kernel is only support on Linux and Windows"
,
)
@
pytest
.
mark
.
skipif
(
get_device_count
(
"gpu"
)
<
1
,
reason
=
"matmul scale operator is only supported on GPU"
)
@
build_and_clean
def
test_custom_op_gpu_build
():
assert
"MatMulScaleForward"
in
custom
.
_get_custom_op_list
()
assert
"MatMulScaleBackward"
in
custom
.
_get_custom_op_list
()
assert
hasattr
(
custom
,
"MatMulScaleForward"
)
assert
hasattr
(
custom
,
"MatMulScaleBackward"
)
scripts/whl/macos/macos_build_whl.sh
浏览文件 @
a72e0cb5
...
...
@@ -171,6 +171,7 @@ function do_build() {
mkdir
-p
staging
cp
-a
imperative/python/
{
megengine,setup.py,requires.txt,requires-style.txt,requires-test.txt
}
staging/
cp
-a
${
SRC_DIR
}
/src/custom/include staging/megengine/core/include/
cd
${
BUILD_DIR
}
/staging/megengine/core
rt_file
=
`
ls
_imperative_rt.
*
.so
`
echo
"rt file is:
${
rt_file
}
"
...
...
scripts/whl/manylinux2014/do_build_common.sh
浏览文件 @
a72e0cb5
...
...
@@ -151,6 +151,7 @@ do
rm
-rf
staging
mkdir
-p
staging
cp
-a
imperative/python/
{
megengine,setup.py,requires.txt,requires-style.txt,requires-test.txt
}
staging/
cp
-a
${
SRC_DIR
}
/src/custom/include/megbrain staging/megengine/core/include
cd
${
BUILD_DIR
}
/staging/megengine/core
mkdir
-p
lib/ucx
...
...
scripts/whl/windows/windows_build_whl.sh
浏览文件 @
a72e0cb5
...
...
@@ -77,11 +77,13 @@ CUBLAS_LIB="/c/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v10.1/bin/cublas6
CURAND_LIB
=
"/c/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v10.1/bin/curand64_10.dll"
CUBLASLT_LIB
=
"/c/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v10.1/bin/cublasLt64_10.dll"
CUDART_LIB
=
"/c/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v10.1/bin/cudart64_101.dll"
MGE_EXPORT_LIB
=
"
${
SRC_DIR
}
/build_dir/host/build/src/megengine_shared.dll"
MGE_EXPORT_DLL
=
"
${
SRC_DIR
}
/build_dir/host/build/src/megengine_shared.dll"
MGE_EXPORT_LIB
=
"
${
SRC_DIR
}
/build_dir/host/build/src/megengine_shared.lib"
function
depend_real_copy
()
{
REAL_DST
=
$1
echo
"real copy lib to
$1
"
cp
"
${
MGE_EXPORT_DLL
}
"
${
REAL_DST
}
cp
"
${
MGE_EXPORT_LIB
}
"
${
REAL_DST
}
if
[
${
BUILD_WHL_CPU_ONLY
}
=
"OFF"
]
;
then
...
...
@@ -190,6 +192,7 @@ function do_build() {
rm
-rf
staging
mkdir
-p
staging
cp
-a
imperative/python/
{
megengine,setup.py,requires.txt,requires-style.txt,requires-test.txt
}
staging/
cp
-a
${
SRC_DIR
}
/src/custom/include/megbrain staging/megengine/core/include/
cd
${
BUILD_DIR
}
/staging/megengine/core
rt_file
=
`
ls
_imperative_rt.
*
.pyd
`
echo
"rt file is:
${
rt_file
}
"
...
...
src/CMakeLists.txt
浏览文件 @
a72e0cb5
# force define a SHARED target for whl, caused by when build for APPLE
# we will force set BUILD_SHARED_LIBS=OFF for xcode needed
set
(
MGE_SHARED_LIB megengine_shared
)
set
(
MGE_SHARED_LIB
${
MGE_SHARED_LIB
}
PARENT_SCOPE
)
if
(
MGE_WITH_JIT_MLIR
)
add_subdirectory
(
jit/include/megbrain/jit/mlir/ir
)
endif
()
...
...
@@ -206,32 +211,30 @@ set (_VER_FILE ${PROJECT_SOURCE_DIR}/src/version.ld)
# Build as SHARED or STATIC depending on BUILD_SHARED_LIBS=ON/OFF
add_library
(
megengine
)
# force define a SHARED target for whl, caused by when build for APPLE
# we will force set BUILD_SHARED_LIBS=OFF for xcode needed
add_library
(
megengine_shared SHARED
)
add_library
(
${
MGE_SHARED_LIB
}
SHARED
)
target_link_libraries
(
megengine PRIVATE
${
MGE_CUDA_LIBS
}
)
target_link_libraries
(
megengine PUBLIC megbrain megdnn
)
target_link_libraries
(
megengine_shared
PUBLIC megbrain megdnn
)
target_link_libraries
(
megengine_shared
PRIVATE
${
MGE_CUDA_LIBS
}
)
target_link_libraries
(
${
MGE_SHARED_LIB
}
PUBLIC megbrain megdnn
)
target_link_libraries
(
${
MGE_SHARED_LIB
}
PRIVATE
${
MGE_CUDA_LIBS
}
)
if
(
UNIX AND NOT APPLE
)
target_link_options
(
megengine PRIVATE -Wl,--no-undefined -Wl,--version-script=
${
_VER_FILE
}
)
set_target_properties
(
megengine PROPERTIES LINK_DEPENDS
${
_VER_FILE
}
)
target_link_options
(
megengine_shared
PRIVATE -Wl,--no-undefined -Wl,--version-script=
${
_VER_FILE
}
)
set_target_properties
(
megengine_shared
PROPERTIES LINK_DEPENDS
${
_VER_FILE
}
)
target_link_options
(
${
MGE_SHARED_LIB
}
PRIVATE -Wl,--no-undefined -Wl,--version-script=
${
_VER_FILE
}
)
set_target_properties
(
${
MGE_SHARED_LIB
}
PROPERTIES LINK_DEPENDS
${
_VER_FILE
}
)
endif
()
if
(
WIN32 OR MSVC
)
target_compile_definitions
(
megbrain PRIVATE MGE_DLL_EXPORT
)
target_compile_definitions
(
megdnn PRIVATE MGE_DLL_EXPORT
)
target_compile_definitions
(
megengine PRIVATE MGE_DLL_EXPORT
)
target_compile_definitions
(
megengine_shared
PRIVATE MGE_DLL_EXPORT
)
target_compile_definitions
(
${
MGE_SHARED_LIB
}
PRIVATE MGE_DLL_EXPORT
)
# please do not use WINDOWS_EXPORT_ALL_SYMBOLS, as symbols max than 65535 when build with CUDA
#set_target_properties(megengine PROPERTIES WINDOWS_EXPORT_ALL_SYMBOLS TRUE)
#set_target_properties(
megengine_shared
PROPERTIES WINDOWS_EXPORT_ALL_SYMBOLS TRUE)
#set_target_properties(
${MGE_SHARED_LIB}
PROPERTIES WINDOWS_EXPORT_ALL_SYMBOLS TRUE)
endif
()
if
(
MGE_WITH_DISTRIBUTED
)
message
(
VERBOSE
"megengine configured to link megray"
)
target_link_libraries
(
megengine PUBLIC megray
)
target_link_libraries
(
megengine_shared
PUBLIC megray
)
target_link_libraries
(
${
MGE_SHARED_LIB
}
PUBLIC megray
)
endif
()
# Do not export targets if MGE_WITH_DISTRIBUTED is on. MegRay is not ready
# for this.
...
...
src/custom/impl/manager.cpp
浏览文件 @
a72e0cb5
...
...
@@ -18,12 +18,31 @@
#ifndef _WIN32
#include <dlfcn.h>
#else
#include <windows.h>
#endif
using
namespace
mgb
;
namespace
custom
{
#ifdef _WIN32
#define RTLD_LAZY 0
void
*
dlopen
(
const
char
*
file
,
int
)
{
return
static_cast
<
void
*>
(
LoadLibrary
(
file
));
}
int
dlclose
(
void
*
handle
)
{
return
static_cast
<
int
>
(
FreeLibrary
(
static_cast
<
HMODULE
>
(
handle
)));
}
const
char
*
dlerror
(
void
)
{
static
char
win_err_info
[]
=
"no dlerror info in windows"
;
return
win_err_info
;
}
#endif
CustomOpManager
*
CustomOpManager
::
inst
(
void
)
{
static
CustomOpManager
op_manager
;
return
&
op_manager
;
...
...
@@ -127,7 +146,6 @@ std::vector<RunTimeId> CustomOpManager::op_id_list(void) {
return
ret
;
}
#ifndef _WIN32
CustomLib
::
CustomLib
(
const
std
::
string
&
path
,
int
mode
=
RTLD_LAZY
)
:
m_handle
(
nullptr
,
[](
void
*
handle
)
{
dlclose
(
handle
);
})
{
auto
op_list_before_load
=
CustomOpManager
::
inst
()
->
op_name_list
();
...
...
@@ -146,12 +164,6 @@ CustomLib::CustomLib(const std::string& path, int mode = RTLD_LAZY)
}
}
}
#else
CustomLib
::
CustomLib
(
const
std
::
string
&
path
,
int
mode
=
0
)
:
m_handle
(
nullptr
,
[](
void
*
handle
)
{})
{
mgb_assert
(
false
,
"custom op is only supported on Linux now"
);
}
#endif
const
std
::
vector
<
std
::
string
>&
CustomLib
::
ops_in_lib
(
void
)
const
{
return
m_ops
;
...
...
src/custom/include/megbrain/custom/custom.h
浏览文件 @
a72e0cb5
...
...
@@ -16,7 +16,8 @@
#include "tensor.h"
namespace
custom
{
std
::
shared_ptr
<
CustomOp
>
op_insert
(
std
::
string
opname
,
uint32_t
version
);
MGE_WIN_DECLSPEC_FUC
std
::
shared_ptr
<
CustomOp
>
op_insert
(
std
::
string
opname
,
uint32_t
version
);
}
#define CUSTOM_OP_REG(OpName) \
...
...
src/custom/include/megbrain/custom/op.h
浏览文件 @
a72e0cb5
...
...
@@ -32,27 +32,26 @@ namespace custom {
using
RunTimeId
=
uint64_t
;
class
ArgInfo
{
class
MGE_WIN_DECLSPEC_FUC
ArgInfo
{
CUSTOM_PIMPL_CLS_DECL
(
ArgInfo
);
MGE_WIN_DECLSPEC_FUC
ArgInfo
(
const
std
::
string
&
name
,
const
std
::
string
&
desc
,
ArgInfo
(
const
std
::
string
&
name
,
const
std
::
string
&
desc
,
const
std
::
unordered_set
<
std
::
string
>&
dtypes
,
const
int
&
ndim
,
const
std
::
string
&
mem_stgy
);
MGE_WIN_DECLSPEC_FUC
const
std
::
string
&
name
(
void
)
const
;
MGE_WIN_DECLSPEC_FUC
const
std
::
string
&
desc
(
void
)
const
;
MGE_WIN_DECLSPEC_FUC
const
std
::
unordered_set
<
std
::
string
>&
dtypes
(
void
)
const
;
MGE_WIN_DECLSPEC_FUC
int
ndim
(
void
)
const
;
MGE_WIN_DECLSPEC_FUC
const
std
::
string
&
mem_strategy
(
void
)
const
;
const
std
::
string
&
name
(
void
)
const
;
const
std
::
string
&
desc
(
void
)
const
;
const
std
::
unordered_set
<
std
::
string
>&
dtypes
(
void
)
const
;
int
ndim
(
void
)
const
;
const
std
::
string
&
mem_strategy
(
void
)
const
;
MGE_WIN_DECLSPEC_FUC
std
::
string
str
()
const
;
std
::
string
str
()
const
;
};
class
CustomOp
{
class
MGE_WIN_DECLSPEC_FUC
CustomOp
{
std
::
unique_ptr
<
void
,
void_deleter
>
m_impl
;
public:
MGE_WIN_DECLSPEC_FUC
CustomOp
(
const
std
::
string
&
op_type
,
uint32_t
version
);
CustomOp
(
const
std
::
string
&
op_type
,
uint32_t
version
);
PREVENT_COPY_AND_ASSIGN
(
CustomOp
);
using
DeviceInferFuncPtr
=
...
...
@@ -71,70 +70,65 @@ public:
void
(
*
)(
const
std
::
vector
<
Tensor
>&
,
const
Param
&
,
std
::
vector
<
Tensor
>&
);
// write for forward
MGE_WIN_DECLSPEC_FUC
CustomOp
&
set_device_infer
(
DeviceInferFuncPtr
func
);
MGE_WIN_DECLSPEC_FUC
CustomOp
&
set_shape_infer
(
ShapeInferFuncPtr
func
);
MGE_WIN_DECLSPEC_FUC
CustomOp
&
set_dtype_infer
(
DTypeInferFuncPtr
func
);
MGE_WIN_DECLSPEC_FUC
CustomOp
&
set_format_infer
(
FormatInferFuncPtr
func
);
MGE_WIN_DECLSPEC_FUC
CustomOp
&
set_preprocess
(
PreprocessFuncPtr
func
);
MGE_WIN_DECLSPEC_FUC
CustomOp
&
set_preprocess
(
const
std
::
string
&
device
,
PreprocessFuncPtr
func
);
MGE_WIN_DECLSPEC_FUC
CustomOp
&
set_postprocess
(
PostprocessFuncPtr
func
);
MGE_WIN_DECLSPEC_FUC
CustomOp
&
set_postprocess
(
const
std
::
string
&
device
,
PostprocessFuncPtr
func
);
MGE_WIN_DECLSPEC_FUC
CustomOp
&
set_compute
(
ComputeFuncPtr
func
);
MGE_WIN_DECLSPEC_FUC
CustomOp
&
set_compute
(
const
std
::
string
&
device
,
ComputeFuncPtr
func
);
MGE_WIN_DECLSPEC_FUC
CustomOp
&
set_description
(
const
std
::
string
&
op_desc
);
MGE_WIN_DECLSPEC_FUC
CustomOp
&
add_input
(
CustomOp
&
set_device_infer
(
DeviceInferFuncPtr
func
);
CustomOp
&
set_shape_infer
(
ShapeInferFuncPtr
func
);
CustomOp
&
set_dtype_infer
(
DTypeInferFuncPtr
func
);
CustomOp
&
set_format_infer
(
FormatInferFuncPtr
func
);
CustomOp
&
set_preprocess
(
PreprocessFuncPtr
func
);
CustomOp
&
set_preprocess
(
const
std
::
string
&
device
,
PreprocessFuncPtr
func
);
CustomOp
&
set_postprocess
(
PostprocessFuncPtr
func
);
CustomOp
&
set_postprocess
(
const
std
::
string
&
device
,
PostprocessFuncPtr
func
);
CustomOp
&
set_compute
(
ComputeFuncPtr
func
);
CustomOp
&
set_compute
(
const
std
::
string
&
device
,
ComputeFuncPtr
func
);
CustomOp
&
set_description
(
const
std
::
string
&
op_desc
);
CustomOp
&
add_input
(
const
std
::
string
&
name
,
const
std
::
string
&
desc
,
const
std
::
initializer_list
<
std
::
string
>&
legal_dtypes
=
{
"float32"
},
int
dims
=
-
1
,
const
std
::
string
&
mem_stgy
=
"default"
);
MGE_WIN_DECLSPEC_FUC
CustomOp
&
add_output
(
CustomOp
&
add_output
(
const
std
::
string
&
name
,
const
std
::
string
&
desc
,
const
std
::
initializer_list
<
std
::
string
>&
legal_dtypes
=
{
"float32"
},
int
dims
=
-
1
,
const
std
::
string
&
mem_stgy
=
"default"
);
MGE_WIN_DECLSPEC_FUC
CustomOp
&
add_input
(
CustomOp
&
add_input
(
const
std
::
string
&
name
,
const
std
::
initializer_list
<
std
::
string
>&
legal_dtypes
=
{
"float32"
},
int
dims
=
-
1
,
const
std
::
string
&
mem_stgy
=
"default"
);
MGE_WIN_DECLSPEC_FUC
CustomOp
&
add_output
(
CustomOp
&
add_output
(
const
std
::
string
&
name
,
const
std
::
initializer_list
<
std
::
string
>&
legal_dtypes
=
{
"float32"
},
int
dims
=
-
1
,
const
std
::
string
&
mem_stgy
=
"default"
);
MGE_WIN_DECLSPEC_FUC
CustomOp
&
add_inputs
(
const
size_t
&
input_num
);
MGE_WIN_DECLSPEC_FUC
CustomOp
&
add_outputs
(
const
size_t
&
output_num
);
MGE_WIN_DECLSPEC_FUC
CustomOp
&
add_param
(
const
std
::
string
&
name
,
const
ParamVal
&
default_val
);
MGE_WIN_DECLSPEC_FUC
CustomOp
&
add_param
(
CustomOp
&
add_inputs
(
const
size_t
&
input_num
);
CustomOp
&
add_outputs
(
const
size_t
&
output_num
);
CustomOp
&
add_param
(
const
std
::
string
&
name
,
const
ParamVal
&
default_val
);
CustomOp
&
add_param
(
const
std
::
string
&
name
,
const
std
::
string
&
desc
,
const
ParamVal
&
default_val
);
// read
MGE_WIN_DECLSPEC_FUC
std
::
string
op_type
(
void
)
const
;
MGE_WIN_DECLSPEC_FUC
std
::
string
op_desc
(
void
)
const
;
MGE_WIN_DECLSPEC_FUC
RunTimeId
runtime_id
(
void
)
const
;
MGE_WIN_DECLSPEC_FUC
size_t
input_num
(
void
)
const
;
MGE_WIN_DECLSPEC_FUC
size_t
output_num
(
void
)
const
;
MGE_WIN_DECLSPEC_FUC
std
::
string
str
(
void
)
const
;
MGE_WIN_DECLSPEC_FUC
const
ParamInfo
&
param_info
(
void
)
const
;
MGE_WIN_DECLSPEC_FUC
ArgInfo
input_info
(
size_t
idx
)
const
;
MGE_WIN_DECLSPEC_FUC
ArgInfo
output_info
(
size_t
idx
)
const
;
MGE_WIN_DECLSPEC_FUC
const
std
::
vector
<
ArgInfo
>&
inputs_info
(
void
)
const
;
MGE_WIN_DECLSPEC_FUC
const
std
::
vector
<
ArgInfo
>&
outputs_info
(
void
)
const
;
std
::
string
op_type
(
void
)
const
;
std
::
string
op_desc
(
void
)
const
;
RunTimeId
runtime_id
(
void
)
const
;
size_t
input_num
(
void
)
const
;
size_t
output_num
(
void
)
const
;
std
::
string
str
(
void
)
const
;
const
ParamInfo
&
param_info
(
void
)
const
;
ArgInfo
input_info
(
size_t
idx
)
const
;
ArgInfo
output_info
(
size_t
idx
)
const
;
const
std
::
vector
<
ArgInfo
>&
inputs_info
(
void
)
const
;
const
std
::
vector
<
ArgInfo
>&
outputs_info
(
void
)
const
;
// use
MGE_WIN_DECLSPEC_FUC
std
::
vector
<
Device
>
infer_output_device
(
std
::
vector
<
Device
>
infer_output_device
(
const
std
::
vector
<
Device
>&
,
const
Param
&
)
const
;
MGE_WIN_DECLSPEC_FUC
std
::
vector
<
Shape
>
infer_output_shape
(
std
::
vector
<
Shape
>
infer_output_shape
(
const
std
::
vector
<
Shape
>&
,
const
Param
&
)
const
;
MGE_WIN_DECLSPEC_FUC
std
::
vector
<
DType
>
infer_output_dtype
(
std
::
vector
<
DType
>
infer_output_dtype
(
const
std
::
vector
<
DType
>&
,
const
Param
&
)
const
;
MGE_WIN_DECLSPEC_FUC
std
::
vector
<
Format
>
infer_output_format
(
std
::
vector
<
Format
>
infer_output_format
(
const
std
::
vector
<
Format
>&
,
const
Param
&
)
const
;
MGE_WIN_DECLSPEC_FUC
void
compute
(
const
std
::
vector
<
Tensor
>&
,
const
Param
&
,
std
::
vector
<
Tensor
>&
)
const
;
void
compute
(
const
std
::
vector
<
Tensor
>&
,
const
Param
&
,
std
::
vector
<
Tensor
>&
)
const
;
};
}
// namespace custom
src/custom/include/megbrain/custom/param.h
浏览文件 @
a72e0cb5
...
...
@@ -23,7 +23,7 @@ class ParamInfoImpl;
class
ParamImpl
;
// Schema of a param element
class
ParamSchema
{
class
MGE_WIN_DECLSPEC_FUC
ParamSchema
{
CUSTOM_PIMPL_CLS_DECL
(
ParamSchema
);
ParamSchema
(
const
std
::
string
&
name
,
const
ParamVal
&
value
,
...
...
@@ -36,7 +36,7 @@ class ParamSchema {
std
::
string
str
(
void
)
const
;
};
class
ParamInfo
{
class
MGE_WIN_DECLSPEC_FUC
ParamInfo
{
CUSTOM_PIMPL_CLS_DECL
(
ParamInfo
);
void
set_tag
(
const
std
::
string
&
);
...
...
@@ -46,16 +46,16 @@ class ParamInfo {
const
std
::
vector
<
ParamSchema
>&
meta
(
void
)
const
;
};
class
Param
{
class
MGE_WIN_DECLSPEC_FUC
Param
{
CUSTOM_PIMPL_CLS_DECL
(
Param
);
MGE_WIN_DECLSPEC_FUC
Param
(
const
ParamInfo
&
);
MGE_WIN_DECLSPEC_FUC
ParamVal
&
operator
[](
const
std
::
string
&
);
MGE_WIN_DECLSPEC_FUC
const
ParamVal
&
operator
[](
const
std
::
string
&
)
const
;
MGE_WIN_DECLSPEC_FUC
const
std
::
unordered_map
<
std
::
string
,
ParamVal
>&
raw
()
const
;
MGE_WIN_DECLSPEC_FUC
bool
exist
(
const
std
::
string
&
name
)
const
;
MGE_WIN_DECLSPEC_FUC
std
::
string
to_bytes
(
void
)
const
;
MGE_WIN_DECLSPEC_FUC
void
from_bytes
(
const
std
::
string
&
);
Param
(
const
ParamInfo
&
);
ParamVal
&
operator
[](
const
std
::
string
&
);
const
ParamVal
&
operator
[](
const
std
::
string
&
)
const
;
const
std
::
unordered_map
<
std
::
string
,
ParamVal
>&
raw
()
const
;
bool
exist
(
const
std
::
string
&
name
)
const
;
std
::
string
to_bytes
(
void
)
const
;
void
from_bytes
(
const
std
::
string
&
);
};
MGE_WIN_DECLSPEC_FUC
bool
operator
==
(
const
Param
&
,
const
Param
&
);
...
...
src/custom/include/megbrain/custom/param_val.h
浏览文件 @
a72e0cb5
...
...
@@ -169,21 +169,21 @@ std::string vec2str(const std::vector<T>& vec) {
* Con1: user need to set the type explicitly when class template instantiation
* Con2: ParamVal<int> can not be assigned to ParamVal<double>
*/
class
ParamVal
{
class
MGE_WIN_DECLSPEC_FUC
ParamVal
{
std
::
unique_ptr
<
void
,
void_deleter
>
m_ptr
;
ParamDynType
m_type
;
public:
template
<
typename
T
>
MGE_WIN_DECLSPEC_FUC
ParamVal
(
const
T
&
val
);
ParamVal
(
const
T
&
val
);
template
<
typename
T
>
MGE_WIN_DECLSPEC_FUC
ParamVal
(
const
std
::
initializer_list
<
T
>&
val
);
ParamVal
(
const
std
::
initializer_list
<
T
>&
val
);
MGE_WIN_DECLSPEC_FUC
ParamVal
();
MGE_WIN_DECLSPEC_FUC
ParamVal
(
const
char
*
str
);
MGE_WIN_DECLSPEC_FUC
ParamVal
(
const
std
::
initializer_list
<
const
char
*>&
strs
);
MGE_WIN_DECLSPEC_FUC
ParamVal
(
const
std
::
vector
<
const
char
*>&
strs
);
MGE_WIN_DECLSPEC_FUC
ParamVal
(
const
ParamVal
&
rhs
);
ParamVal
();
ParamVal
(
const
char
*
str
);
ParamVal
(
const
std
::
initializer_list
<
const
char
*>&
strs
);
ParamVal
(
const
std
::
vector
<
const
char
*>&
strs
);
ParamVal
(
const
ParamVal
&
rhs
);
template
<
typename
T
>
ParamVal
&
operator
=
(
const
T
&
rhs
);
...
...
@@ -196,30 +196,39 @@ public:
ParamVal
&
operator
=
(
const
ParamVal
&
rhs
);
template
<
typename
T
>
MGE_WIN_DECLSPEC_FUC
const
T
&
as
(
void
)
const
;
const
T
&
as
(
void
)
const
;
template
<
typename
T
>
MGE_WIN_DECLSPEC_FUC
T
&
as
(
void
);
MGE_WIN_DECLSPEC_FUC
const
void
*
raw_ptr
(
void
)
const
;
MGE_WIN_DECLSPEC_FUC
void
*
raw_ptr
(
void
);
MGE_WIN_DECLSPEC_FUC
ParamDynType
type
(
void
)
const
;
MGE_WIN_DECLSPEC_FUC
std
::
string
str
(
void
)
const
;
MGE_WIN_DECLSPEC_FUC
size_t
size
(
void
)
const
;
MGE_WIN_DECLSPEC_FUC
static
std
::
string
to_bytes
(
const
ParamVal
&
value
);
MGE_WIN_DECLSPEC_FUC
static
ParamVal
from_bytes
(
const
std
::
string
&
bytes
,
size_t
&
offset
);
friend
ParamVal
operator
+
(
const
ParamVal
&
lhs
,
const
ParamVal
&
rhs
);
friend
ParamVal
operator
-
(
const
ParamVal
&
lhs
,
const
ParamVal
&
rhs
);
friend
ParamVal
operator
*
(
const
ParamVal
&
lhs
,
const
ParamVal
&
rhs
);
friend
ParamVal
operator
/
(
const
ParamVal
&
lhs
,
const
ParamVal
&
rhs
);
friend
bool
operator
==
(
const
ParamVal
&
lhs
,
const
ParamVal
&
rhs
);
friend
bool
operator
!=
(
const
ParamVal
&
lhs
,
const
ParamVal
&
rhs
);
friend
bool
operator
>
(
const
ParamVal
&
lhs
,
const
ParamVal
&
rhs
);
friend
bool
operator
<
(
const
ParamVal
&
lhs
,
const
ParamVal
&
rhs
);
friend
bool
operator
>=
(
const
ParamVal
&
lhs
,
const
ParamVal
&
rhs
);
friend
bool
operator
<=
(
const
ParamVal
&
lhs
,
const
ParamVal
&
rhs
);
T
&
as
(
void
);
const
void
*
raw_ptr
(
void
)
const
;
void
*
raw_ptr
(
void
);
ParamDynType
type
(
void
)
const
;
std
::
string
str
(
void
)
const
;
size_t
size
(
void
)
const
;
static
std
::
string
to_bytes
(
const
ParamVal
&
value
);
static
ParamVal
from_bytes
(
const
std
::
string
&
bytes
,
size_t
&
offset
);
MGE_WIN_DECLSPEC_FUC
friend
ParamVal
operator
+
(
const
ParamVal
&
lhs
,
const
ParamVal
&
rhs
);
MGE_WIN_DECLSPEC_FUC
friend
ParamVal
operator
-
(
const
ParamVal
&
lhs
,
const
ParamVal
&
rhs
);
MGE_WIN_DECLSPEC_FUC
friend
ParamVal
operator
*
(
const
ParamVal
&
lhs
,
const
ParamVal
&
rhs
);
MGE_WIN_DECLSPEC_FUC
friend
ParamVal
operator
/
(
const
ParamVal
&
lhs
,
const
ParamVal
&
rhs
);
MGE_WIN_DECLSPEC_FUC
friend
bool
operator
==
(
const
ParamVal
&
lhs
,
const
ParamVal
&
rhs
);
MGE_WIN_DECLSPEC_FUC
friend
bool
operator
!=
(
const
ParamVal
&
lhs
,
const
ParamVal
&
rhs
);
MGE_WIN_DECLSPEC_FUC
friend
bool
operator
>
(
const
ParamVal
&
lhs
,
const
ParamVal
&
rhs
);
MGE_WIN_DECLSPEC_FUC
friend
bool
operator
<
(
const
ParamVal
&
lhs
,
const
ParamVal
&
rhs
);
MGE_WIN_DECLSPEC_FUC
friend
bool
operator
>=
(
const
ParamVal
&
lhs
,
const
ParamVal
&
rhs
);
MGE_WIN_DECLSPEC_FUC
friend
bool
operator
<=
(
const
ParamVal
&
lhs
,
const
ParamVal
&
rhs
);
};
ParamVal
operator
+
(
const
ParamVal
&
lhs
,
const
ParamVal
&
rhs
);
...
...
src/custom/include/megbrain/custom/tensor.h
浏览文件 @
a72e0cb5
...
...
@@ -30,9 +30,9 @@ namespace custom {
#define CUSTOM_DEVICE_TYPE_ENUM_DECL(custom_type, builtin_type, builtin_str) \
custom_type,
class
Device
{
MGE_WIN_DECLSPEC_FUC
const
void
*
impl
()
const
;
MGE_WIN_DECLSPEC_FUC
Device
(
const
void
*
impl
);
class
MGE_WIN_DECLSPEC_FUC
Device
{
const
void
*
impl
()
const
;
Device
(
const
void
*
impl
);
CUSTOM_PIMPL_CLS_DECL
(
Device
);
public:
...
...
@@ -40,19 +40,19 @@ public:
CUSTOM_FOR_EACH_DEVICE_TYPE
(
CUSTOM_DEVICE_TYPE_ENUM_DECL
)
};
MGE_WIN_DECLSPEC_FUC
Device
(
const
std
::
string
&
device
);
MGE_WIN_DECLSPEC_FUC
Device
(
const
char
*
device
);
MGE_WIN_DECLSPEC_FUC
Device
(
DeviceEnum
device
);
Device
(
const
std
::
string
&
device
);
Device
(
const
char
*
device
);
Device
(
DeviceEnum
device
);
MGE_WIN_DECLSPEC_FUC
std
::
string
str
(
void
)
const
;
MGE_WIN_DECLSPEC_FUC
DeviceEnum
enumv
(
void
)
const
;
std
::
string
str
(
void
)
const
;
DeviceEnum
enumv
(
void
)
const
;
MGE_WIN_DECLSPEC_FUC
static
bool
is_legal
(
const
std
::
string
&
device
);
MGE_WIN_DECLSPEC_FUC
static
bool
is_legal
(
DeviceEnum
device
);
MGE_WIN_DECLSPEC_FUC
static
std
::
vector
<
std
::
string
>
legal_devices
(
void
);
static
bool
is_legal
(
const
std
::
string
&
device
);
static
bool
is_legal
(
DeviceEnum
device
);
static
std
::
vector
<
std
::
string
>
legal_devices
(
void
);
friend
class
Tensor
;
friend
bool
operator
==
(
const
Device
&
lhs
,
const
Device
&
rhs
);
MGE_WIN_DECLSPEC_FUC
friend
bool
operator
==
(
const
Device
&
lhs
,
const
Device
&
rhs
);
CUSTOM_DATA_ADAPTOR_FRIEND_DECL
;
};
...
...
@@ -60,23 +60,23 @@ using DeviceEnum = Device::DeviceEnum;
bool
operator
==
(
const
Device
&
lhs
,
const
Device
&
rhs
);
class
Shape
{
MGE_WIN_DECLSPEC_FUC
const
void
*
impl
()
const
;
MGE_WIN_DECLSPEC_FUC
Shape
(
const
void
*
impl
);
class
MGE_WIN_DECLSPEC_FUC
Shape
{
const
void
*
impl
()
const
;
Shape
(
const
void
*
impl
);
CUSTOM_PIMPL_CLS_DECL
(
Shape
);
public:
MGE_WIN_DECLSPEC_FUC
Shape
(
const
std
::
vector
<
size_t
>&
rhs
);
MGE_WIN_DECLSPEC_FUC
Shape
(
const
std
::
initializer_list
<
size_t
>&
rhs
);
Shape
(
const
std
::
vector
<
size_t
>&
rhs
);
Shape
(
const
std
::
initializer_list
<
size_t
>&
rhs
);
size_t
&
operator
[](
size_t
idx
);
size_t
operator
[](
size_t
idx
)
const
;
MGE_WIN_DECLSPEC_FUC
void
ndim
(
size_t
dim
);
MGE_WIN_DECLSPEC_FUC
size_t
ndim
(
void
)
const
;
void
ndim
(
size_t
dim
);
size_t
ndim
(
void
)
const
;
friend
class
Tensor
;
friend
bool
operator
==
(
const
Shape
&
lhs
,
const
Shape
&
rhs
);
MGE_WIN_DECLSPEC_FUC
friend
bool
operator
==
(
const
Shape
&
lhs
,
const
Shape
&
rhs
);
CUSTOM_DATA_ADAPTOR_FRIEND_DECL
;
};
...
...
@@ -104,9 +104,9 @@ using bfloat16_t = uint16_t;
#define CUSTOM_DTYPE_ENUM_DECL(custom_type, builtin_type, ctype) custom_type,
class
DType
{
MGE_WIN_DECLSPEC_FUC
const
void
*
impl
()
const
;
MGE_WIN_DECLSPEC_FUC
DType
(
const
void
*
impl
);
class
MGE_WIN_DECLSPEC_FUC
DType
{
const
void
*
impl
()
const
;
DType
(
const
void
*
impl
);
CUSTOM_PIMPL_CLS_DECL
(
DType
);
public:
...
...
@@ -114,27 +114,33 @@ public:
CUSTOM_FOR_EACH_TENSOR_DATA_TYPE
(
CUSTOM_DTYPE_ENUM_DECL
)
};
MGE_WIN_DECLSPEC_FUC
DType
(
const
std
::
string
&
dtype
);
MGE_WIN_DECLSPEC_FUC
DType
(
const
char
*
dtype
);
MGE_WIN_DECLSPEC_FUC
DType
(
const
std
::
string
&
dtype
,
float
scale
,
uint8_t
zero_point
=
0
);
MGE_WIN_DECLSPEC_FUC
DType
(
const
char
*
dtype
,
float
scale
,
uint8_t
zero_point
=
0
);
MGE_WIN_DECLSPEC_FUC
DType
(
DTypeEnum
dtype
);
MGE_WIN_DECLSPEC_FUC
DType
(
DTypeEnum
dtype
,
float
scale
,
uint8_t
zero_point
=
0
);
MGE_WIN_DECLSPEC_FUC
std
::
string
str
(
void
)
const
;
MGE_WIN_DECLSPEC_FUC
DTypeEnum
enumv
()
const
;
MGE_WIN_DECLSPEC_FUC
float
scale
(
void
)
const
;
MGE_WIN_DECLSPEC_FUC
uint8_t
zero_point
(
void
)
const
;
DType
(
const
std
::
string
&
dtype
);
DType
(
const
char
*
dtype
);
DType
(
const
std
::
string
&
dtype
,
float
scale
,
uint8_t
zero_point
=
0
);
DType
(
const
char
*
dtype
,
float
scale
,
uint8_t
zero_point
=
0
);
DType
(
DTypeEnum
dtype
);
DType
(
DTypeEnum
dtype
,
float
scale
,
uint8_t
zero_point
=
0
);
std
::
string
str
(
void
)
const
;
DTypeEnum
enumv
()
const
;
float
scale
(
void
)
const
;
uint8_t
zero_point
(
void
)
const
;
template
<
typename
T
>
MGE_WIN_DECLSPEC_FUC
bool
is_compatible
(
void
)
const
;
bool
is_compatible
(
void
)
const
;
MGE_WIN_DECLSPEC_FUC
static
bool
is_legal
(
const
std
::
string
&
dtype
);
MGE_WIN_DECLSPEC_FUC
static
bool
is_legal
(
const
DTypeEnum
&
dtype
);
MGE_WIN_DECLSPEC_FUC
static
std
::
vector
<
std
::
string
>
legal_dtypes
(
void
);
static
bool
is_legal
(
const
std
::
string
&
dtype
);
static
bool
is_legal
(
const
DTypeEnum
&
dtype
);
static
std
::
vector
<
std
::
string
>
legal_dtypes
(
void
);
friend
class
Tensor
;
friend
bool
operator
==
(
const
DType
&
lhs
,
const
DType
&
rhs
);
MGE_WIN_DECLSPEC_FUC
friend
bool
operator
==
(
const
DType
&
lhs
,
const
DType
&
rhs
);
MGE_WIN_DECLSPEC_FUC
friend
bool
operator
==
(
const
DType
&
lhs
,
const
std
::
string
&
rhs
);
MGE_WIN_DECLSPEC_FUC
friend
bool
operator
==
(
const
DType
&
lhs
,
const
char
*
rhs
);
MGE_WIN_DECLSPEC_FUC
friend
bool
operator
==
(
const
std
::
string
&
lhs
,
const
DType
&
rhs
);
MGE_WIN_DECLSPEC_FUC
friend
bool
operator
==
(
const
char
*
lhs
,
const
DType
&
rhs
);
CUSTOM_DATA_ADAPTOR_FRIEND_DECL
;
};
...
...
@@ -180,45 +186,45 @@ bool operator==(const DType& lhs, const char* rhs);
bool
operator
==
(
const
std
::
string
&
lhs
,
const
DType
&
rhs
);
bool
operator
==
(
const
char
*
lhs
,
const
DType
&
rhs
);
class
Format
{
MGE_WIN_DECLSPEC_FUC
const
void
*
impl
()
const
;
MGE_WIN_DECLSPEC_FUC
Format
(
const
void
*
impl
);
class
MGE_WIN_DECLSPEC_FUC
Format
{
const
void
*
impl
()
const
;
Format
(
const
void
*
impl
);
CUSTOM_PIMPL_CLS_DECL
(
Format
);
public:
MGE_WIN_DECLSPEC_FUC
Format
(
const
std
::
string
&
format
);
MGE_WIN_DECLSPEC_FUC
Format
(
const
char
*
format
);
Format
(
const
std
::
string
&
format
);
Format
(
const
char
*
format
);
MGE_WIN_DECLSPEC_FUC
std
::
string
str
(
void
)
const
;
MGE_WIN_DECLSPEC_FUC
bool
is_default
(
void
)
const
;
std
::
string
str
(
void
)
const
;
bool
is_default
(
void
)
const
;
friend
class
Tensor
;
CUSTOM_DATA_ADAPTOR_FRIEND_DECL
;
};
class
Tensor
{
class
MGE_WIN_DECLSPEC_FUC
Tensor
{
void
*
m_tensor
;
MGE_WIN_DECLSPEC_FUC
const
void
*
impl
(
void
)
const
;
MGE_WIN_DECLSPEC_FUC
Tensor
(
const
void
*
impl
);
const
void
*
impl
(
void
)
const
;
Tensor
(
const
void
*
impl
);
MGE_WIN_DECLSPEC_FUC
const
size_t
*
shapes_raw
(
void
)
const
;
MGE_WIN_DECLSPEC_FUC
const
ptrdiff_t
*
strides_raw
(
void
)
const
;
const
size_t
*
shapes_raw
(
void
)
const
;
const
ptrdiff_t
*
strides_raw
(
void
)
const
;
public:
Tensor
()
=
delete
;
MGE_WIN_DECLSPEC_FUC
Tensor
(
const
Tensor
&
rhs
);
MGE_WIN_DECLSPEC_FUC
Tensor
&
operator
=
(
const
Tensor
&
rhs
);
MGE_WIN_DECLSPEC_FUC
Shape
shape
(
void
)
const
;
MGE_WIN_DECLSPEC_FUC
DType
dtype
(
void
)
const
;
MGE_WIN_DECLSPEC_FUC
Format
format
(
void
)
const
;
MGE_WIN_DECLSPEC_FUC
Device
device
(
void
)
const
;
MGE_WIN_DECLSPEC_FUC
size_t
size
(
void
)
const
;
MGE_WIN_DECLSPEC_FUC
std
::
vector
<
ptrdiff_t
>
stride
(
void
)
const
;
MGE_WIN_DECLSPEC_FUC
float
scale
(
void
)
const
;
MGE_WIN_DECLSPEC_FUC
uint8_t
zero_point
(
void
)
const
;
Tensor
(
const
Tensor
&
rhs
);
Tensor
&
operator
=
(
const
Tensor
&
rhs
);
Shape
shape
(
void
)
const
;
DType
dtype
(
void
)
const
;
Format
format
(
void
)
const
;
Device
device
(
void
)
const
;
size_t
size
(
void
)
const
;
std
::
vector
<
ptrdiff_t
>
stride
(
void
)
const
;
float
scale
(
void
)
const
;
uint8_t
zero_point
(
void
)
const
;
void
*
data
(
void
);
const
void
*
data
(
void
)
const
;
...
...
src/custom/include/megbrain/custom/utils.h
浏览文件 @
a72e0cb5
...
...
@@ -19,10 +19,19 @@
namespace
custom
{
void
assert_failed_log
(
#ifndef MGE_WIN_DECLSPEC_FUC
#ifdef _WIN32
#define MGE_WIN_DECLSPEC_FUC __declspec(dllexport)
#else
#define MGE_WIN_DECLSPEC_FUC
#endif
#endif
MGE_WIN_DECLSPEC_FUC
void
assert_failed_log
(
const
char
*
file
,
int
line
,
const
char
*
func
,
const
char
*
expr
,
const
char
*
msg_fmt
,
...);
#ifndef _WIN32
#define custom_expect(expr, msg...) \
if (!(expr)) { \
assert_failed_log(__FILE__, __LINE__, __PRETTY_FUNCTION__, #expr, ##msg); \
...
...
@@ -33,8 +42,22 @@ void assert_failed_log(
assert_failed_log(__FILE__, __LINE__, __PRETTY_FUNCTION__, #expr, ##msg); \
} \
assert((expr))
#else
#define custom_expect(expr, ...) \
if (!(expr)) { \
assert_failed_log( \
__FILE__, __LINE__, __PRETTY_FUNCTION__, #expr, __VA_ARGS__); \
}
#define custom_assert(expr, ...) \
if (!(expr)) { \
assert_failed_log( \
__FILE__, __LINE__, __PRETTY_FUNCTION__, #expr, __VA_ARGS__); \
} \
assert((expr))
#endif
class
UnImpleWarnLog
{
class
MGE_WIN_DECLSPEC_FUC
UnImpleWarnLog
{
public:
UnImpleWarnLog
(
const
std
::
string
&
func
,
const
std
::
string
&
attr
,
const
std
::
string
&
val
);
...
...
@@ -54,9 +77,9 @@ void impl_deleter(void* ptr) {
std::unique_ptr<void, void_deleter> m_impl; \
\
public: \
MGE_WIN_DECLSPEC_FUC Cls();
\
MGE_WIN_DECLSPEC_FUC Cls(const Cls& rhs);
\
MGE_WIN_DECLSPEC_FUC
Cls& operator=(const Cls& rhs)
Cls();
\
Cls(const Cls& rhs);
\
Cls& operator=(const Cls& rhs)
#define CUSTOM_PIMPL_CLS_DEFINE(Cls) \
Cls::Cls() : m_impl(new Cls##Impl(), impl_deleter<Cls##Impl>) {} \
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录