Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
3018ca51
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
3018ca51
编写于
6月 28, 2022
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix(mge): add param RuntimeArgs to customop kernel on cuda
GitOrigin-RevId: 7ed44c42ded50a2a07c258c13306dc5dc90bbd93
上级
086ee045
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
205 addition
and
80 deletion
+205
-80
src/CMakeLists.txt
src/CMakeLists.txt
+9
-0
src/custom/impl/op.cpp
src/custom/impl/op.cpp
+86
-69
src/custom/impl/platform/custom_cuda.cpp
src/custom/impl/platform/custom_cuda.cpp
+21
-0
src/custom/include/megbrain/custom/op.h
src/custom/include/megbrain/custom/op.h
+34
-11
src/custom/include/megbrain/custom/platform/custom_cuda.h
src/custom/include/megbrain/custom/platform/custom_cuda.h
+25
-0
src/custom/test/op.cpp
src/custom/test/op.cpp
+30
-0
未找到文件。
src/CMakeLists.txt
浏览文件 @
3018ca51
...
@@ -90,6 +90,15 @@ endif()
...
@@ -90,6 +90,15 @@ endif()
if
(
MGE_WITH_CUSTOM_OP
)
if
(
MGE_WITH_CUSTOM_OP
)
list
(
APPEND MGB_INC
${
CMAKE_CURRENT_LIST_DIR
}
/custom/include
)
list
(
APPEND MGB_INC
${
CMAKE_CURRENT_LIST_DIR
}
/custom/include
)
file
(
GLOB_RECURSE SOURCES_ custom/impl/*.cpp
)
file
(
GLOB_RECURSE SOURCES_ custom/impl/*.cpp
)
set
(
EXCLUDE_PLATFORM_DIR
"custom/impl/platform"
)
foreach
(
CUSOURCE
${
SOURCES_
}
)
string
(
FIND
${
CUSOURCE
}
${
EXCLUDE_PLATFORM_DIR
}
EXCLUDE_DIR_FOUND
)
if
(
NOT
${
EXCLUDE_DIR_FOUND
}
EQUAL -1
)
list
(
REMOVE_ITEM SOURCES_
${
CUSOURCE
}
)
endif
()
endforeach
(
CUSOURCE
)
list
(
APPEND SOURCES
${
SOURCES_
}
)
list
(
APPEND SOURCES
${
SOURCES_
}
)
endif
()
endif
()
...
...
src/custom/impl/op.cpp
浏览文件 @
3018ca51
...
@@ -6,6 +6,7 @@
...
@@ -6,6 +6,7 @@
#include <unordered_set>
#include <unordered_set>
#include "megbrain/custom/op.h"
#include "megbrain/custom/op.h"
#include "megbrain/custom/utils.h"
#include "megbrain/custom/utils.h"
#include "megbrain/utils/thin/function.h"
using
namespace
mgb
;
using
namespace
mgb
;
...
@@ -99,40 +100,6 @@ std::string ArgInfo::str() const {
...
@@ -99,40 +100,6 @@ std::string ArgInfo::str() const {
(arg_info).name().c_str(), static_cast<int>((arg_info).ndim()), \
(arg_info).name().c_str(), static_cast<int>((arg_info).ndim()), \
static_cast<int>((real_shape).ndim()))
static_cast<int>((real_shape).ndim()))
template
<
typename
T
>
class
Function
;
template
<
typename
RType
,
typename
...
Args
>
class
Function
<
RType
(
Args
...)
>
{
public:
using
Functor
=
RType
(
*
)(
Args
...);
Function
()
=
default
;
Function
(
Functor
f
)
:
m_f
(
f
)
{}
Function
(
const
Function
&
rhs
)
{
m_f
=
rhs
.
m_f
;
}
RType
operator
()(
Args
...
args
)
{
custom_assert
(
m_f
!=
nullptr
,
"invalid function ptr
\n
"
);
return
m_f
(
std
::
forward
<
Args
>
(
args
)...);
}
void
operator
=
(
const
Function
&
rhs
)
{
// not allowed continuous assignment
m_f
=
rhs
.
m_f
;
}
void
operator
=
(
const
Functor
f
)
{
m_f
=
f
;
}
private:
Functor
m_f
=
nullptr
;
};
template
<
typename
Functions
>
class
FuncWithSig
:
public
Functions
{
public:
using
Functions
::
operator
();
using
Functions
::
operator
=
;
};
class
CustomOpImpl
{
class
CustomOpImpl
{
static
constexpr
uint32_t
CURRENT_VERSION
=
CUSTOM_OP_VERSION
;
static
constexpr
uint32_t
CURRENT_VERSION
=
CUSTOM_OP_VERSION
;
const
uint32_t
m_version
;
const
uint32_t
m_version
;
...
@@ -143,29 +110,26 @@ class CustomOpImpl {
...
@@ -143,29 +110,26 @@ class CustomOpImpl {
std
::
vector
<
ArgInfo
>
m_output_infos
;
std
::
vector
<
ArgInfo
>
m_output_infos
;
ParamInfo
m_param_infos
;
ParamInfo
m_param_infos
;
using
DeviceInfer
=
FuncWithSig
<
Function
<
void
(
using
DeviceInfer
=
thin_function
<
void
(
const
std
::
vector
<
Device
>&
,
const
Param
&
,
std
::
vector
<
Device
>&
)
>>
;
const
std
::
vector
<
Device
>&
,
const
Param
&
,
std
::
vector
<
Device
>&
)
>
;
using
ShapeInfer
=
FuncWithSig
<
Function
<
void
(
using
ShapeInfer
=
thin_function
<
void
(
const
std
::
vector
<
Shape
>&
,
const
Param
&
,
std
::
vector
<
Shape
>&
)
>>
;
const
std
::
vector
<
Shape
>&
,
const
Param
&
,
std
::
vector
<
Shape
>&
)
>
;
using
DTypeInfer
=
FuncWithSig
<
Function
<
void
(
using
DTypeInfer
=
thin_function
<
void
(
const
std
::
vector
<
DType
>&
,
const
Param
&
,
std
::
vector
<
DType
>&
)
>>
;
const
std
::
vector
<
DType
>&
,
const
Param
&
,
std
::
vector
<
DType
>&
)
>
;
using
FormatInfer
=
FuncWithSig
<
Function
<
void
(
using
FormatInfer
=
thin_function
<
void
(
const
std
::
vector
<
Format
>&
,
const
Param
&
,
std
::
vector
<
Format
>&
)
>>
;
const
std
::
vector
<
Format
>&
,
const
Param
&
,
std
::
vector
<
Format
>&
)
>
;
using
Preprocess
=
FuncWithSig
<
Function
<
void
(
using
Process
=
thin_function
<
void
(
const
std
::
vector
<
Tensor
>&
,
const
Param
&
,
std
::
vector
<
Tensor
>&
)
>>
;
const
std
::
vector
<
Tensor
>&
,
const
Param
&
,
std
::
vector
<
Tensor
>&
,
using
Postprocess
=
FuncWithSig
<
Function
<
void
(
const
RuntimeArgs
&
)
>
;
const
std
::
vector
<
Tensor
>&
,
const
Param
&
,
std
::
vector
<
Tensor
>&
)
>>
;
using
Compute
=
FuncWithSig
<
Function
<
void
(
const
std
::
vector
<
Tensor
>&
,
const
Param
&
,
std
::
vector
<
Tensor
>&
)
>>
;
DeviceInfer
infer_output_device_func
;
DeviceInfer
infer_output_device_func
;
ShapeInfer
infer_output_shape_func
;
ShapeInfer
infer_output_shape_func
;
DTypeInfer
infer_output_dtype_func
;
DTypeInfer
infer_output_dtype_func
;
FormatInfer
infer_output_format_func
;
FormatInfer
infer_output_format_func
;
std
::
unordered_map
<
std
::
string
,
Compute
>
compute_funcs
;
std
::
unordered_map
<
std
::
string
,
Process
>
compute_funcs
;
std
::
unordered_map
<
std
::
string
,
Pr
epr
ocess
>
preprocess_funcs
;
std
::
unordered_map
<
std
::
string
,
Process
>
preprocess_funcs
;
std
::
unordered_map
<
std
::
string
,
P
ostp
rocess
>
postprocess_funcs
;
std
::
unordered_map
<
std
::
string
,
Process
>
postprocess_funcs
;
public:
public:
CustomOpImpl
(
const
std
::
string
&
,
uint32_t
version
);
CustomOpImpl
(
const
std
::
string
&
,
uint32_t
version
);
...
@@ -215,7 +179,8 @@ CustomOpImpl::CustomOpImpl(const std::string& op_type, uint32_t version)
...
@@ -215,7 +179,8 @@ CustomOpImpl::CustomOpImpl(const std::string& op_type, uint32_t version)
for
(
const
auto
&
device
:
Device
::
legal_devices
())
{
for
(
const
auto
&
device
:
Device
::
legal_devices
())
{
compute_funcs
[
device
]
=
[](
const
std
::
vector
<
Tensor
>&
,
const
Param
&
,
compute_funcs
[
device
]
=
[](
const
std
::
vector
<
Tensor
>&
,
const
Param
&
,
std
::
vector
<
Tensor
>&
outputs
)
->
void
{
std
::
vector
<
Tensor
>&
outputs
,
const
RuntimeArgs
&
)
->
void
{
auto
device
=
outputs
[
0
].
device
();
auto
device
=
outputs
[
0
].
device
();
mgb_assert
(
mgb_assert
(
false
,
false
,
...
@@ -224,9 +189,11 @@ CustomOpImpl::CustomOpImpl(const std::string& op_type, uint32_t version)
...
@@ -224,9 +189,11 @@ CustomOpImpl::CustomOpImpl(const std::string& op_type, uint32_t version)
device
.
str
().
c_str
());
device
.
str
().
c_str
());
};
};
preprocess_funcs
[
device
]
=
[](
const
std
::
vector
<
Tensor
>&
,
const
Param
&
,
preprocess_funcs
[
device
]
=
[](
const
std
::
vector
<
Tensor
>&
,
const
Param
&
,
std
::
vector
<
Tensor
>&
)
->
void
{
return
;
};
std
::
vector
<
Tensor
>&
,
const
RuntimeArgs
&
)
->
void
{
return
;
};
postprocess_funcs
[
device
]
=
[](
const
std
::
vector
<
Tensor
>&
,
const
Param
&
,
postprocess_funcs
[
device
]
=
[](
const
std
::
vector
<
Tensor
>&
,
const
Param
&
,
std
::
vector
<
Tensor
>&
)
->
void
{
return
;
};
std
::
vector
<
Tensor
>&
,
const
RuntimeArgs
&
)
->
void
{
return
;
};
}
}
m_param_infos
.
set_tag
(
op_type
);
m_param_infos
.
set_tag
(
op_type
);
}
}
...
@@ -256,33 +223,78 @@ CustomOp& CustomOp::set_format_infer(FormatInferFuncPtr func) {
...
@@ -256,33 +223,78 @@ CustomOp& CustomOp::set_format_infer(FormatInferFuncPtr func) {
return
*
this
;
return
*
this
;
}
}
CustomOp
&
CustomOp
::
set_preprocess
(
PreprocessFuncPtr
func
)
{
CustomOp
&
CustomOp
::
set_preprocess
(
ProcessFuncPtrWithoutRuntimeArgs
func
)
{
set_preprocess
(
"x86"
,
func
);
return
*
this
;
}
CustomOp
&
CustomOp
::
set_preprocess
(
const
std
::
string
&
device
,
ProcessFuncPtrWithoutRuntimeArgs
func
)
{
auto
wrap_func
=
[
func
](
const
std
::
vector
<
Tensor
>&
input
,
const
Param
&
param
,
std
::
vector
<
Tensor
>&
output
,
const
RuntimeArgs
&
)
->
void
{
return
func
(
input
,
param
,
output
);
};
OpImplRef
(
m_impl
.
get
())
->
preprocess_funcs
[
device
]
=
wrap_func
;
return
*
this
;
}
CustomOp
&
CustomOp
::
set_preprocess
(
ProcessFuncPtr
func
)
{
set_preprocess
(
"x86"
,
func
);
set_preprocess
(
"x86"
,
func
);
return
*
this
;
return
*
this
;
}
}
CustomOp
&
CustomOp
::
set_preprocess
(
const
std
::
string
&
device
,
Pr
epr
ocessFuncPtr
func
)
{
CustomOp
&
CustomOp
::
set_preprocess
(
const
std
::
string
&
device
,
ProcessFuncPtr
func
)
{
OpImplRef
(
m_impl
.
get
())
->
preprocess_funcs
[
device
]
=
func
;
OpImplRef
(
m_impl
.
get
())
->
preprocess_funcs
[
device
]
=
func
;
return
*
this
;
return
*
this
;
}
}
CustomOp
&
CustomOp
::
set_postprocess
(
P
ostprocessFuncPtr
func
)
{
CustomOp
&
CustomOp
::
set_postprocess
(
P
rocessFuncPtrWithoutRuntimeArgs
func
)
{
set_postprocess
(
"x86"
,
func
);
set_postprocess
(
"x86"
,
func
);
return
*
this
;
return
*
this
;
}
}
CustomOp
&
CustomOp
::
set_postprocess
(
CustomOp
&
CustomOp
::
set_postprocess
(
const
std
::
string
&
device
,
PostprocessFuncPtr
func
)
{
const
std
::
string
&
device
,
ProcessFuncPtrWithoutRuntimeArgs
func
)
{
auto
wrap_func
=
[
func
](
const
std
::
vector
<
Tensor
>&
input
,
const
Param
&
param
,
std
::
vector
<
Tensor
>&
output
,
const
RuntimeArgs
&
)
->
void
{
func
(
input
,
param
,
output
);
};
OpImplRef
(
m_impl
.
get
())
->
postprocess_funcs
[
device
]
=
wrap_func
;
return
*
this
;
}
CustomOp
&
CustomOp
::
set_postprocess
(
ProcessFuncPtr
func
)
{
set_postprocess
(
"x86"
,
func
);
return
*
this
;
}
CustomOp
&
CustomOp
::
set_postprocess
(
const
std
::
string
&
device
,
ProcessFuncPtr
func
)
{
OpImplRef
(
m_impl
.
get
())
->
postprocess_funcs
[
device
]
=
func
;
OpImplRef
(
m_impl
.
get
())
->
postprocess_funcs
[
device
]
=
func
;
return
*
this
;
return
*
this
;
}
}
CustomOp
&
CustomOp
::
set_compute
(
ComputeFuncPtr
func
)
{
CustomOp
&
CustomOp
::
set_compute
(
ProcessFuncPtrWithoutRuntimeArgs
func
)
{
set_compute
(
"x86"
,
func
);
return
*
this
;
}
CustomOp
&
CustomOp
::
set_compute
(
const
std
::
string
&
device
,
ProcessFuncPtrWithoutRuntimeArgs
func
)
{
auto
wrap_func
=
[
func
](
const
std
::
vector
<
Tensor
>&
input
,
const
Param
&
param
,
std
::
vector
<
Tensor
>&
output
,
const
RuntimeArgs
&
)
->
void
{
func
(
input
,
param
,
output
);
};
OpImplRef
(
m_impl
.
get
())
->
compute_funcs
[
device
]
=
wrap_func
;
return
*
this
;
}
CustomOp
&
CustomOp
::
set_compute
(
ProcessFuncPtr
func
)
{
set_compute
(
"x86"
,
func
);
set_compute
(
"x86"
,
func
);
return
*
this
;
return
*
this
;
}
}
CustomOp
&
CustomOp
::
set_compute
(
const
std
::
string
&
device
,
Compute
FuncPtr
func
)
{
CustomOp
&
CustomOp
::
set_compute
(
const
std
::
string
&
device
,
Process
FuncPtr
func
)
{
OpImplRef
(
m_impl
.
get
())
->
compute_funcs
[
device
]
=
func
;
OpImplRef
(
m_impl
.
get
())
->
compute_funcs
[
device
]
=
func
;
return
*
this
;
return
*
this
;
}
}
...
@@ -513,23 +525,28 @@ void CustomOp::compute(
...
@@ -513,23 +525,28 @@ void CustomOp::compute(
return
;
return
;
}
}
std
::
string
device
=
outputs
[
0
].
device
().
str
();
Device
device
=
outputs
[
0
].
device
();
std
::
string
device_str
=
device
.
str
();
for
(
size_t
i
=
1
;
i
<
outputs
.
size
();
++
i
)
{
for
(
size_t
i
=
1
;
i
<
outputs
.
size
();
++
i
)
{
mgb_assert
(
mgb_assert
(
outputs
[
i
].
device
().
str
()
==
device
,
outputs
[
i
].
device
().
str
()
==
device
_str
,
"all output tensors should have the same device attribute"
);
"all output tensors should have the same device attribute"
);
}
}
// need to add other input/output check
// need to add other input/output check
mgb_assert
(
Device
::
is_legal
(
device
),
"unsupported device type: %s"
,
device
.
c_str
());
mgb_assert
(
Device
::
is_legal
(
device_str
),
"unsupported device type: %s"
,
device_str
.
c_str
());
auto
preprocess_func
=
OpImplRef
(
m_impl
.
get
())
->
preprocess_funcs
[
device_str
];
auto
forward_func
=
OpImplRef
(
m_impl
.
get
())
->
compute_funcs
[
device_str
];
auto
postprocess_func
=
OpImplRef
(
m_impl
.
get
())
->
postprocess_funcs
[
device_str
];
auto
preprocess_func
=
OpImplRef
(
m_impl
.
get
())
->
preprocess_funcs
[
device
];
RuntimeArgs
rt_args
(
device
);
auto
forward_func
=
OpImplRef
(
m_impl
.
get
())
->
compute_funcs
[
device
];
auto
postprocess_func
=
OpImplRef
(
m_impl
.
get
())
->
postprocess_funcs
[
device
];
preprocess_func
(
inputs
,
param
,
outputs
);
preprocess_func
(
inputs
,
param
,
outputs
,
rt_args
);
forward_func
(
inputs
,
param
,
outputs
);
forward_func
(
inputs
,
param
,
outputs
,
rt_args
);
postprocess_func
(
outputs
,
param
,
outputs
);
postprocess_func
(
outputs
,
param
,
outputs
,
rt_args
);
assert_outputs_size_right
(
outputs
);
assert_outputs_size_right
(
outputs
);
}
}
...
...
src/custom/impl/platform/custom_cuda.cpp
0 → 100644
浏览文件 @
3018ca51
#include "megbrain/common.h"
#include "megbrain/comp_node_env.h"
#include "megbrain/custom/data_adaptor.h"
#include "megbrain/custom/platform/custom_cuda.h"
using
namespace
mgb
;
namespace
custom
{
const
CudaRuntimeArgs
get_cuda_runtime_args
(
const
RuntimeArgs
&
rt_args
)
{
mgb_assert
(
rt_args
.
device
().
enumv
()
==
DeviceEnum
::
cuda
,
"devive type should be cuda."
);
const
CompNodeEnv
&
env
=
CompNodeEnv
::
from_comp_node
(
to_builtin
<
CompNode
,
Device
>
(
rt_args
.
device
()));
const
CompNodeEnv
::
CudaEnv
&
cuda_env
=
env
.
cuda_env
();
return
{
cuda_env
.
device
,
cuda_env
.
stream
};
}
}
// namespace custom
src/custom/include/megbrain/custom/op.h
浏览文件 @
3018ca51
...
@@ -36,6 +36,18 @@ class MGE_WIN_DECLSPEC_FUC ArgInfo {
...
@@ -36,6 +36,18 @@ class MGE_WIN_DECLSPEC_FUC ArgInfo {
std
::
string
str
()
const
;
std
::
string
str
()
const
;
};
};
class
CudaRuntimeArgs
;
class
MGE_WIN_DECLSPEC_FUC
RuntimeArgs
{
Device
m_device
;
public:
RuntimeArgs
()
=
default
;
RuntimeArgs
(
Device
device
)
:
m_device
(
device
){};
const
Device
&
device
()
const
{
return
m_device
;
}
};
class
MGE_WIN_DECLSPEC_FUC
CustomOp
{
class
MGE_WIN_DECLSPEC_FUC
CustomOp
{
std
::
unique_ptr
<
void
,
void_deleter
>
m_impl
;
std
::
unique_ptr
<
void
,
void_deleter
>
m_impl
;
...
@@ -51,11 +63,10 @@ public:
...
@@ -51,11 +63,10 @@ public:
void
(
*
)(
const
std
::
vector
<
DType
>&
,
const
Param
&
,
std
::
vector
<
DType
>&
);
void
(
*
)(
const
std
::
vector
<
DType
>&
,
const
Param
&
,
std
::
vector
<
DType
>&
);
using
FormatInferFuncPtr
=
using
FormatInferFuncPtr
=
void
(
*
)(
const
std
::
vector
<
Format
>&
,
const
Param
&
,
std
::
vector
<
Format
>&
);
void
(
*
)(
const
std
::
vector
<
Format
>&
,
const
Param
&
,
std
::
vector
<
Format
>&
);
using
PreprocessFuncPtr
=
using
ProcessFuncPtr
=
void
(
*
)(
void
(
*
)(
const
std
::
vector
<
Tensor
>&
,
const
Param
&
,
std
::
vector
<
Tensor
>&
);
const
std
::
vector
<
Tensor
>&
,
const
Param
&
,
std
::
vector
<
Tensor
>&
,
using
PostprocessFuncPtr
=
const
RuntimeArgs
&
);
void
(
*
)(
const
std
::
vector
<
Tensor
>&
,
const
Param
&
,
std
::
vector
<
Tensor
>&
);
using
ProcessFuncPtrWithoutRuntimeArgs
=
using
ComputeFuncPtr
=
void
(
*
)(
const
std
::
vector
<
Tensor
>&
,
const
Param
&
,
std
::
vector
<
Tensor
>&
);
void
(
*
)(
const
std
::
vector
<
Tensor
>&
,
const
Param
&
,
std
::
vector
<
Tensor
>&
);
// write for forward
// write for forward
...
@@ -63,12 +74,24 @@ public:
...
@@ -63,12 +74,24 @@ public:
CustomOp
&
set_shape_infer
(
ShapeInferFuncPtr
func
);
CustomOp
&
set_shape_infer
(
ShapeInferFuncPtr
func
);
CustomOp
&
set_dtype_infer
(
DTypeInferFuncPtr
func
);
CustomOp
&
set_dtype_infer
(
DTypeInferFuncPtr
func
);
CustomOp
&
set_format_infer
(
FormatInferFuncPtr
func
);
CustomOp
&
set_format_infer
(
FormatInferFuncPtr
func
);
CustomOp
&
set_preprocess
(
PreprocessFuncPtr
func
);
//! set process function with RuntimeArgs e.g. cuda
CustomOp
&
set_preprocess
(
const
std
::
string
&
device
,
PreprocessFuncPtr
func
);
CustomOp
&
set_preprocess
(
ProcessFuncPtr
func
);
CustomOp
&
set_postprocess
(
PostprocessFuncPtr
func
);
CustomOp
&
set_preprocess
(
const
std
::
string
&
device
,
ProcessFuncPtr
func
);
CustomOp
&
set_postprocess
(
const
std
::
string
&
device
,
PostprocessFuncPtr
func
);
CustomOp
&
set_postprocess
(
ProcessFuncPtr
func
);
CustomOp
&
set_compute
(
ComputeFuncPtr
func
);
CustomOp
&
set_postprocess
(
const
std
::
string
&
device
,
ProcessFuncPtr
func
);
CustomOp
&
set_compute
(
const
std
::
string
&
device
,
ComputeFuncPtr
func
);
CustomOp
&
set_compute
(
ProcessFuncPtr
func
);
CustomOp
&
set_compute
(
const
std
::
string
&
device
,
ProcessFuncPtr
func
);
//! set process function without RuntimeArgs e.g. cpu
CustomOp
&
set_preprocess
(
ProcessFuncPtrWithoutRuntimeArgs
func
);
CustomOp
&
set_preprocess
(
const
std
::
string
&
device
,
ProcessFuncPtrWithoutRuntimeArgs
func
);
CustomOp
&
set_postprocess
(
ProcessFuncPtrWithoutRuntimeArgs
func
);
CustomOp
&
set_postprocess
(
const
std
::
string
&
device
,
ProcessFuncPtrWithoutRuntimeArgs
func
);
CustomOp
&
set_compute
(
ProcessFuncPtrWithoutRuntimeArgs
func
);
CustomOp
&
set_compute
(
const
std
::
string
&
device
,
ProcessFuncPtrWithoutRuntimeArgs
func
);
CustomOp
&
set_description
(
const
std
::
string
&
op_desc
);
CustomOp
&
set_description
(
const
std
::
string
&
op_desc
);
CustomOp
&
add_input
(
CustomOp
&
add_input
(
...
...
src/custom/include/megbrain/custom/platform/custom_cuda.h
0 → 100644
浏览文件 @
3018ca51
#pragma once
#include "megbrain/custom/op.h"
#include <cuda_runtime_api.h>
namespace
custom
{
class
CudaRuntimeArgs
{
private:
int
m_device
;
cudaStream_t
m_stream
;
public:
CudaRuntimeArgs
(
int
device
,
cudaStream_t
stream
)
:
m_device
(
device
),
m_stream
(
stream
)
{}
int
device
()
const
{
return
m_device
;
}
cudaStream_t
stream
()
const
{
return
m_stream
;
}
};
const
CudaRuntimeArgs
get_cuda_runtime_args
(
const
RuntimeArgs
&
rt_args
);
}
// namespace custom
src/custom/test/op.cpp
浏览文件 @
3018ca51
...
@@ -119,6 +119,34 @@ void gpu_kernel(
...
@@ -119,6 +119,34 @@ void gpu_kernel(
ASSERT_TRUE
(
params
[
"device"
]
==
"cuda"
);
ASSERT_TRUE
(
params
[
"device"
]
==
"cuda"
);
}
}
void
cpu_kernel_with_runtime_args
(
const
std
::
vector
<
Tensor
>&
inputs
,
const
Param
&
params
,
std
::
vector
<
Tensor
>&
outputs
,
const
RuntimeArgs
&
args
)
{
(
void
)
inputs
;
(
void
)
params
;
(
void
)
outputs
;
(
void
)
args
;
#if OP_TEST_LOG
std
::
cout
<<
"Checking CPU Forward - "
<<
params
[
"device"
].
as
<
std
::
string
>
()
<<
std
::
endl
;
#endif
ASSERT_TRUE
(
params
[
"device"
]
==
"x86"
);
}
void
gpu_kernel_with_runtime_args
(
const
std
::
vector
<
Tensor
>&
inputs
,
const
Param
&
params
,
std
::
vector
<
Tensor
>&
outputs
,
const
RuntimeArgs
&
args
)
{
(
void
)
inputs
;
(
void
)
params
;
(
void
)
outputs
;
(
void
)
args
;
#if OP_TEST_LOG
std
::
cout
<<
"Checking GPU Forward - "
<<
params
[
"device"
].
as
<
std
::
string
>
()
<<
std
::
endl
;
#endif
ASSERT_TRUE
(
params
[
"device"
]
==
"cuda"
);
}
TEST
(
TestCustomOp
,
TestCustomOpFuncSetter
)
{
TEST
(
TestCustomOp
,
TestCustomOpFuncSetter
)
{
#if MGB_CUDA
#if MGB_CUDA
CustomOp
test
(
"TestOp"
,
CUSTOM_OP_VERSION
);
CustomOp
test
(
"TestOp"
,
CUSTOM_OP_VERSION
);
...
@@ -179,6 +207,7 @@ TEST(TestCustomOp, TestCustomOpFuncSetter) {
...
@@ -179,6 +207,7 @@ TEST(TestCustomOp, TestCustomOpFuncSetter) {
ASSERT_TRUE
(
iformats
[
0
].
is_default
());
ASSERT_TRUE
(
iformats
[
0
].
is_default
());
ASSERT_TRUE
(
iformats
[
1
].
is_default
());
ASSERT_TRUE
(
iformats
[
1
].
is_default
());
test
.
set_compute
(
cpu_kernel_with_runtime_args
);
test
.
set_compute
(
cpu_kernel
);
test
.
set_compute
(
cpu_kernel
);
DeviceTensorND
cdev_itensor0
(
CompNode
::
load
(
"cpux"
),
{
3
,
2
},
dtype
::
Int32
{});
DeviceTensorND
cdev_itensor0
(
CompNode
::
load
(
"cpux"
),
{
3
,
2
},
dtype
::
Int32
{});
DeviceTensorND
cdev_itensor1
(
CompNode
::
load
(
"cpux"
),
{
3
,
2
},
dtype
::
Float32
{});
DeviceTensorND
cdev_itensor1
(
CompNode
::
load
(
"cpux"
),
{
3
,
2
},
dtype
::
Float32
{});
...
@@ -192,6 +221,7 @@ TEST(TestCustomOp, TestCustomOpFuncSetter) {
...
@@ -192,6 +221,7 @@ TEST(TestCustomOp, TestCustomOpFuncSetter) {
param
[
"device"
]
=
"x86"
;
param
[
"device"
]
=
"x86"
;
test
.
compute
(
cinputs
,
param
,
coutputs
);
test
.
compute
(
cinputs
,
param
,
coutputs
);
test
.
set_compute
(
"cuda"
,
gpu_kernel_with_runtime_args
);
test
.
set_compute
(
"cuda"
,
gpu_kernel
);
test
.
set_compute
(
"cuda"
,
gpu_kernel
);
DeviceTensorND
gdev_itensor0
(
CompNode
::
load
(
"gpux"
),
{
3
,
2
},
dtype
::
Int32
{});
DeviceTensorND
gdev_itensor0
(
CompNode
::
load
(
"gpux"
),
{
3
,
2
},
dtype
::
Int32
{});
DeviceTensorND
gdev_itensor1
(
CompNode
::
load
(
"gpux"
),
{
3
,
2
},
dtype
::
Float32
{});
DeviceTensorND
gdev_itensor1
(
CompNode
::
load
(
"gpux"
),
{
3
,
2
},
dtype
::
Float32
{});
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录