Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
edb32495
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
edb32495
编写于
9月 18, 2020
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(dnn/opr): add megdnn adaptive pooling opr
GitOrigin-RevId: 563ce65479c90cb5686761889d9b8459c9afcefa
上级
5a85c907
变更
12
隐藏空白更改
内联
并排
Showing
12 changed file
with
434 addition
and
0 deletion
+434
-0
dnn/include/megdnn/oprs/nn.h
dnn/include/megdnn/oprs/nn.h
+47
-0
dnn/src/common/adaptive_pooling.cpp
dnn/src/common/adaptive_pooling.cpp
+37
-0
dnn/src/common/handle_impl.h
dnn/src/common/handle_impl.h
+2
-0
dnn/src/cuda/adaptive_pooling/opr_impl.cpp
dnn/src/cuda/adaptive_pooling/opr_impl.cpp
+53
-0
dnn/src/cuda/adaptive_pooling/opr_impl.h
dnn/src/cuda/adaptive_pooling/opr_impl.h
+44
-0
dnn/src/cuda/handle_create.cpp
dnn/src/cuda/handle_create.cpp
+1
-0
dnn/src/naive/adaptive_pooling/opr_impl.cpp
dnn/src/naive/adaptive_pooling/opr_impl.cpp
+52
-0
dnn/src/naive/adaptive_pooling/opr_impl.h
dnn/src/naive/adaptive_pooling/opr_impl.h
+43
-0
dnn/src/naive/handle.cpp
dnn/src/naive/handle.cpp
+1
-0
dnn/test/common/adaptive_pooling.h
dnn/test/common/adaptive_pooling.h
+55
-0
dnn/test/common/opr_trait.h
dnn/test/common/opr_trait.h
+2
-0
dnn/test/cuda/adaptive_pooling.cpp
dnn/test/cuda/adaptive_pooling.cpp
+97
-0
未找到文件。
dnn/include/megdnn/oprs/nn.h
浏览文件 @
edb32495
...
@@ -682,6 +682,53 @@ protected:
...
@@ -682,6 +682,53 @@ protected:
size_t
workspace_in_bytes
);
size_t
workspace_in_bytes
);
};
};
/**
* \brief base class for AdaptivePooling
*/
class
AdaptivePoolingBase
:
public
OperatorBase
{
DEF_OPR_IMPL_CTOR
(
AdaptivePoolingBase
,
OperatorBase
);
DEF_OPR_PARAM
(
AdaptivePooling
);
protected:
param
::
Pooling
deduce_pooling_param
(
const
TensorLayout
&
src
,
const
TensorLayout
&
dst
);
};
class
AdaptivePoolingForward
:
public
AdaptivePoolingBase
{
DEF_OPR_IMPL
(
AdaptivePoolingForward
,
AdaptivePoolingBase
,
1
,
1
);
public:
/**
* \param[in] src input tensor
* \param[out] dst output tensor
*/
virtual
void
exec
(
_megdnn_tensor_in
src
,
_megdnn_tensor_out
dst
,
_megdnn_workspace
workspace
)
=
0
;
virtual
size_t
get_workspace_in_bytes
(
const
TensorLayout
&
src
,
const
TensorLayout
&
dst
)
=
0
;
};
using
AdaptivePooling
=
AdaptivePoolingForward
;
class
AdaptivePoolingBackward
:
public
AdaptivePoolingBase
{
DEF_OPR_IMPL
(
AdaptivePoolingBackward
,
AdaptivePoolingBase
,
3
,
1
);
public:
/**
* \param[in] src the `src' parameter in AdaptivePoolingForward::exec
* \param[in] dst the `dst' parameter in AdaptivePoolingForward::exec
* \param[in] diff the backpropagated gradient wrt. dst
* \param[out] grad the backpropagated gradient wrt. src
*/
virtual
void
exec
(
_megdnn_tensor_in
src
,
_megdnn_tensor_in
dst
,
_megdnn_tensor_in
diff
,
_megdnn_tensor_out
grad
,
_megdnn_workspace
workspace
)
=
0
;
virtual
size_t
get_workspace_in_bytes
(
const
TensorLayout
&
src
,
const
TensorLayout
&
dst
,
const
TensorLayout
&
diff
,
const
TensorLayout
&
grad
)
=
0
;
};
/**
/**
* \brief base class for Local
* \brief base class for Local
*/
*/
...
...
dnn/src/common/adaptive_pooling.cpp
0 → 100644
浏览文件 @
edb32495
/**
* \file dnn/src/common/adaptive_pooling.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "megdnn/opr_param_defs.h"
#include "megdnn/oprs.h"
#include "src/common/utils.h"
namespace
megdnn
{
param
::
Pooling
AdaptivePoolingBase
::
deduce_pooling_param
(
const
TensorLayout
&
src
,
const
TensorLayout
&
dst
)
{
megdnn_assert
(
param
().
format
==
param
::
AdaptivePooling
::
Format
::
NCHW
);
size_t
IH
=
src
.
shape
[
2
],
IW
=
src
.
shape
[
3
],
OH
=
dst
.
shape
[
2
],
OW
=
dst
.
shape
[
3
];
param
::
Pooling
ret
;
ret
.
mode
=
param
().
mode
;
ret
.
format
=
param
().
format
;
ret
.
pad_h
=
ret
.
pad_w
=
0
;
ret
.
stride_h
=
floor
(
IH
/
OH
);
ret
.
stride_w
=
floor
(
IW
/
OW
);
ret
.
window_h
=
IH
-
(
OH
-
1
)
*
ret
.
stride_h
;
ret
.
window_w
=
IW
-
(
OW
-
1
)
*
ret
.
stride_w
;
return
ret
;
}
}
// namespace megdnn
// vim: syntax=cpp.doxygen
dnn/src/common/handle_impl.h
浏览文件 @
edb32495
...
@@ -199,6 +199,8 @@ private:
...
@@ -199,6 +199,8 @@ private:
cb(Remap) \
cb(Remap) \
cb(RemapBackwardData) \
cb(RemapBackwardData) \
cb(RemapBackwardMat) \
cb(RemapBackwardMat) \
cb(AdaptivePoolingForward) \
cb(AdaptivePoolingBackward) \
/*!
/*!
* \brief specialize HandleImpl::create_operator for a single opr type;
* \brief specialize HandleImpl::create_operator for a single opr type;
...
...
dnn/src/cuda/adaptive_pooling/opr_impl.cpp
0 → 100644
浏览文件 @
edb32495
/**
* \file dnn/src/cuda/adaptive_pooling/opr_impl.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/cuda/adaptive_pooling/opr_impl.h"
#include "src/cuda/utils.h"
namespace
megdnn
{
namespace
cuda
{
void
AdaptivePoolingForwardImpl
::
exec
(
_megdnn_tensor_in
src
,
_megdnn_tensor_out
dst
,
_megdnn_workspace
workspace
)
{
auto
opr
=
handle
()
->
create_operator
<
PoolingForward
>
();
opr
->
param
()
=
deduce_pooling_param
(
src
.
layout
,
dst
.
layout
);
opr
->
exec
(
src
,
dst
,
workspace
);
}
size_t
AdaptivePoolingForwardImpl
::
get_workspace_in_bytes
(
const
TensorLayout
&
src
,
const
TensorLayout
&
dst
)
{
auto
opr
=
handle
()
->
create_operator
<
PoolingForward
>
();
opr
->
param
()
=
deduce_pooling_param
(
src
,
dst
);
return
opr
->
get_workspace_in_bytes
(
src
,
dst
);
}
void
AdaptivePoolingBackwardImpl
::
exec
(
_megdnn_tensor_in
src
,
_megdnn_tensor_in
dst
,
_megdnn_tensor_in
diff
,
_megdnn_tensor_out
grad
,
_megdnn_workspace
workspace
)
{
auto
opr
=
handle
()
->
create_operator
<
PoolingBackward
>
();
opr
->
param
()
=
deduce_pooling_param
(
src
.
layout
,
dst
.
layout
);
opr
->
exec
(
src
,
dst
,
diff
,
grad
,
workspace
);
}
size_t
AdaptivePoolingBackwardImpl
::
get_workspace_in_bytes
(
const
TensorLayout
&
src
,
const
TensorLayout
&
dst
,
const
TensorLayout
&
diff
,
const
TensorLayout
&
grad
)
{
auto
opr
=
handle
()
->
create_operator
<
PoolingBackward
>
();
opr
->
param
()
=
deduce_pooling_param
(
src
,
dst
);
return
opr
->
get_workspace_in_bytes
(
src
,
dst
,
diff
,
grad
);
}
}
// namespace cuda
}
// namespace megdnn
// vim: syntax=cpp.doxygen
dnn/src/cuda/adaptive_pooling/opr_impl.h
0 → 100644
浏览文件 @
edb32495
/**
* \file dnn/src/cuda/adaptive_pooling/opr_impl.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include "megdnn/oprs.h"
#include "src/cuda/cudnn_wrapper.h"
#include "src/cuda/utils.h"
namespace
megdnn
{
namespace
cuda
{
class
AdaptivePoolingForwardImpl
final
:
public
AdaptivePoolingForward
{
public:
using
AdaptivePoolingForward
::
AdaptivePoolingForward
;
void
exec
(
_megdnn_tensor_in
src
,
_megdnn_tensor_out
dst
,
_megdnn_workspace
workspace
)
override
;
size_t
get_workspace_in_bytes
(
const
TensorLayout
&
src
,
const
TensorLayout
&
dst
)
override
;
};
class
AdaptivePoolingBackwardImpl
final
:
public
AdaptivePoolingBackward
{
public:
using
AdaptivePoolingBackward
::
AdaptivePoolingBackward
;
void
exec
(
_megdnn_tensor_in
src
,
_megdnn_tensor_in
dst
,
_megdnn_tensor_in
diff
,
_megdnn_tensor_out
grad
,
_megdnn_workspace
workspace
)
override
;
size_t
get_workspace_in_bytes
(
const
TensorLayout
&
src
,
const
TensorLayout
&
dst
,
const
TensorLayout
&
diff
,
const
TensorLayout
&
grad
)
override
;
};
}
// namespace cuda
}
// namespace megdnn
// vim: syntax=cpp.doxygen
dnn/src/cuda/handle_create.cpp
浏览文件 @
edb32495
...
@@ -11,6 +11,7 @@
...
@@ -11,6 +11,7 @@
#include "src/common/handle_impl.h"
#include "src/common/handle_impl.h"
#include "src/cuda/adaptive_pooling/opr_impl.h"
#include "src/cuda/add_update/opr_impl.h"
#include "src/cuda/add_update/opr_impl.h"
#include "src/cuda/argmxx/opr_impl.h"
#include "src/cuda/argmxx/opr_impl.h"
#include "src/cuda/argsort/opr_impl.h"
#include "src/cuda/argsort/opr_impl.h"
...
...
dnn/src/naive/adaptive_pooling/opr_impl.cpp
0 → 100644
浏览文件 @
edb32495
/**
* \file dnn/src/naive/adaptive_pooling/opr_impl.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/naive/adaptive_pooling/opr_impl.h"
#include "src/common/opr_delegate.h"
#include "src/common/utils.h"
#include "src/naive/handle.h"
namespace
megdnn
{
namespace
naive
{
void
AdaptivePoolingForwardImpl
::
exec
(
_megdnn_tensor_in
src
,
_megdnn_tensor_out
dst
,
_megdnn_workspace
workspace
)
{
MEGDNN_DISPATCH_CPU_KERN
(
static_cast
<
naive
::
HandleImpl
*>
(
handle
()),
{
auto
opr
=
inplace_cpu_handle
()
->
create_operator
<
PoolingForward
>
();
opr
->
param
()
=
deduce_pooling_param
(
src
.
layout
,
dst
.
layout
);
opr
->
exec
(
src
,
dst
,
workspace
);
});
}
void
AdaptivePoolingBackwardImpl
::
exec
(
_megdnn_tensor_in
src
,
_megdnn_tensor_in
dst
,
_megdnn_tensor_in
diff
,
_megdnn_tensor_out
grad
,
_megdnn_workspace
workspace
)
{
MEGDNN_DISPATCH_CPU_KERN
(
static_cast
<
naive
::
HandleImpl
*>
(
handle
()),
{
auto
opr
=
inplace_cpu_handle
()
->
create_operator
<
PoolingBackward
>
();
opr
->
param
()
=
deduce_pooling_param
(
src
.
layout
,
dst
.
layout
);
opr
->
exec
(
src
,
dst
,
diff
,
grad
,
workspace
);
});
}
size_t
AdaptivePoolingBackwardImpl
::
get_workspace_in_bytes
(
const
TensorLayout
&
src
,
const
TensorLayout
&
dst
,
const
TensorLayout
&
diff
,
const
TensorLayout
&
grad
)
{
auto
opr
=
inplace_cpu_handle
()
->
create_operator
<
PoolingBackward
>
();
opr
->
param
()
=
deduce_pooling_param
(
src
,
dst
);
return
opr
->
get_workspace_in_bytes
(
src
,
dst
,
diff
,
grad
);
}
}
// namespace naive
}
// namespace megdnn
// vim: syntax=cpp.doxygen
dnn/src/naive/adaptive_pooling/opr_impl.h
0 → 100644
浏览文件 @
edb32495
/**
* \file dnn/src/naive/adaptive_pooling/opr_impl.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include "megdnn/oprs.h"
#include "src/common/utils.h"
namespace
megdnn
{
namespace
naive
{
class
AdaptivePoolingForwardImpl
:
public
AdaptivePoolingForward
{
public:
using
AdaptivePoolingForward
::
AdaptivePoolingForward
;
void
exec
(
_megdnn_tensor_in
src
,
_megdnn_tensor_out
dst
,
_megdnn_workspace
workspace
)
override
;
size_t
get_workspace_in_bytes
(
const
TensorLayout
&
,
const
TensorLayout
&
)
override
{
return
0
;
}
};
class
AdaptivePoolingBackwardImpl
:
public
AdaptivePoolingBackward
{
public:
using
AdaptivePoolingBackward
::
AdaptivePoolingBackward
;
void
exec
(
_megdnn_tensor_in
src
,
_megdnn_tensor_in
dst
,
_megdnn_tensor_in
diff
,
_megdnn_tensor_out
grad
,
_megdnn_workspace
workspace
)
override
;
size_t
get_workspace_in_bytes
(
const
TensorLayout
&
src
,
const
TensorLayout
&
dst
,
const
TensorLayout
&
diff
,
const
TensorLayout
&
grad
)
override
;
};
}
// namespace naive
}
// namespace megdnn
// vim: syntax=cpp.doxygen
dnn/src/naive/handle.cpp
浏览文件 @
edb32495
...
@@ -13,6 +13,7 @@
...
@@ -13,6 +13,7 @@
#include "src/common/handle_impl.h"
#include "src/common/handle_impl.h"
#include "src/naive/adaptive_pooling/opr_impl.h"
#include "src/naive/add_update/opr_impl.h"
#include "src/naive/add_update/opr_impl.h"
#include "src/naive/argmxx/opr_impl.h"
#include "src/naive/argmxx/opr_impl.h"
#include "src/naive/argsort/opr_impl.h"
#include "src/naive/argsort/opr_impl.h"
...
...
dnn/test/common/adaptive_pooling.h
0 → 100644
浏览文件 @
edb32495
/**
* \file dnn/test/common/adaptive_pooling.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include <cstddef>
#include "megdnn/basic_types.h"
#include "megdnn/opr_param_defs.h"
namespace
megdnn
{
namespace
test
{
namespace
adaptive_pooling
{
struct
TestArg
{
param
::
AdaptivePooling
param
;
TensorShape
ishape
;
TensorShape
oshape
;
TestArg
(
param
::
AdaptivePooling
param
,
TensorShape
ishape
,
TensorShape
oshape
)
:
param
(
param
),
ishape
(
ishape
),
oshape
(
oshape
)
{}
};
inline
std
::
vector
<
TestArg
>
get_args
()
{
std
::
vector
<
TestArg
>
args
;
using
Param
=
param
::
AdaptivePooling
;
using
Mode
=
param
::
AdaptivePooling
::
Mode
;
for
(
size_t
i
=
36
;
i
<
40
;
++
i
)
{
args
.
emplace_back
(
Param
{
Mode
::
AVERAGE
},
TensorShape
{
2
,
3
,
i
,
i
+
1
},
TensorShape
{
2
,
3
,
i
-
4
,
i
-
2
});
args
.
emplace_back
(
Param
{
Mode
::
MAX
},
TensorShape
{
2
,
3
,
i
,
i
+
1
},
TensorShape
{
2
,
3
,
i
-
4
,
i
-
2
});
}
for
(
size_t
i
=
5
;
i
<
10
;
++
i
)
{
args
.
emplace_back
(
Param
{
Mode
::
AVERAGE
},
TensorShape
{
2
,
3
,
i
,
i
+
1
},
TensorShape
{
2
,
3
,
i
-
3
,
i
-
2
});
args
.
emplace_back
(
Param
{
Mode
::
MAX
},
TensorShape
{
2
,
3
,
i
,
i
+
1
},
TensorShape
{
2
,
3
,
i
-
3
,
i
-
2
});
}
return
args
;
}
}
// namespace adaptive_pooling
}
// namespace test
}
// namespace megdnn
// vim: syntax=cpp.doxygen
dnn/test/common/opr_trait.h
浏览文件 @
edb32495
...
@@ -41,6 +41,8 @@ DEF(Images2NeibsForward, 2, true, true);
...
@@ -41,6 +41,8 @@ DEF(Images2NeibsForward, 2, true, true);
DEF
(
Images2NeibsBackward
,
2
,
true
,
false
);
DEF
(
Images2NeibsBackward
,
2
,
true
,
false
);
DEF
(
PoolingForward
,
2
,
true
,
true
);
DEF
(
PoolingForward
,
2
,
true
,
true
);
DEF
(
PoolingBackward
,
4
,
true
,
false
);
DEF
(
PoolingBackward
,
4
,
true
,
false
);
DEF
(
AdaptivePoolingForward
,
2
,
true
,
false
);
DEF
(
AdaptivePoolingBackward
,
4
,
true
,
false
);
DEF
(
LocalForward
,
3
,
true
,
true
);
DEF
(
LocalForward
,
3
,
true
,
true
);
DEF
(
LocalBackwardData
,
3
,
true
,
false
);
DEF
(
LocalBackwardData
,
3
,
true
,
false
);
DEF
(
LocalBackwardFilter
,
3
,
true
,
false
);
DEF
(
LocalBackwardFilter
,
3
,
true
,
false
);
...
...
dnn/test/cuda/adaptive_pooling.cpp
0 → 100644
浏览文件 @
edb32495
/**
* \file dnn/test/cuda/adaptive_pooling.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "test/cuda/fixture.h"
#include "megdnn/tensor_iter.h"
#include "test/common/adaptive_pooling.h"
#include "test/common/checker.h"
#include "src/common/utils.h"
#include "test/cuda/utils.h"
#include <cudnn.h>
#include "test/cuda/benchmark.h"
namespace
megdnn
{
namespace
test
{
TEST_F
(
CUDA
,
ADAPTIVE_POOLING_FORWARD
)
{
auto
args
=
adaptive_pooling
::
get_args
();
using
Format
=
param
::
AdaptivePooling
::
Format
;
DType
dtype
=
dtype
::
Float32
();
for
(
auto
&&
arg
:
args
)
{
auto
param
=
arg
.
param
;
auto
src
=
arg
.
ishape
;
auto
dst
=
arg
.
oshape
;
param
.
format
=
Format
::
NCHW
;
Checker
<
AdaptivePooling
>
checker
(
handle_cuda
());
checker
.
set_epsilon
(
1e-2
);
checker
.
set_param
(
param
).
set_dtype
(
0
,
dtype
).
set_dtype
(
1
,
dtype
).
exec
(
TensorShapeArray
{
src
,
dst
,
{}});
}
}
TEST_F
(
CUDA
,
ADAPTIVE_POOLING_BACKWARD
)
{
auto
args
=
adaptive_pooling
::
get_args
();
for
(
auto
&&
arg
:
args
)
{
Checker
<
AdaptivePoolingBackward
>
checker
(
handle_cuda
());
TensorLayout
ilayout
=
TensorLayout
(
arg
.
ishape
,
dtype
::
Float32
());
TensorLayout
olayout
=
TensorLayout
(
arg
.
oshape
,
dtype
::
Float32
());
auto
constraint
=
[
this
,
arg
](
CheckerHelper
::
TensorValueArray
&
tensors_orig
)
{
megdnn_assert
(
tensors_orig
.
size
()
==
4
);
auto
opr
=
handle_cuda
()
->
create_operator
<
AdaptivePoolingForward
>
();
opr
->
param
()
=
arg
.
param
;
auto
tensors_cuda_storage
=
CheckerHelper
::
alloc_tensors
(
handle_cuda
(),
{
tensors_orig
[
0
].
layout
,
tensors_orig
[
1
].
layout
},
0
);
auto
&&
tensors_cuda
=
*
tensors_cuda_storage
;
auto
span
=
tensors_cuda
[
0
].
layout
.
span
();
auto
dst
=
static_cast
<
dt_byte
*>
(
tensors_cuda
[
0
].
raw_ptr
)
+
span
.
low_byte
;
auto
src
=
static_cast
<
const
dt_byte
*>
(
tensors_orig
[
0
].
raw_ptr
)
+
span
.
low_byte
;
megdnn_memcpy_H2D
(
handle_cuda
(),
dst
,
src
,
span
.
dist_byte
());
auto
workspace_size
=
opr
->
get_workspace_in_bytes
(
tensors_cuda
[
0
].
layout
,
tensors_cuda
[
1
].
layout
);
auto
workspace_cuda
=
megdnn_malloc
(
handle_cuda
(),
workspace_size
);
Workspace
workspace
{
static_cast
<
dt_byte
*>
(
workspace_cuda
),
workspace_size
};
opr
->
exec
(
tensors_cuda
[
0
],
tensors_cuda
[
1
],
workspace
);
megdnn_free
(
handle_cuda
(),
workspace_cuda
);
span
=
tensors_cuda
[
1
].
layout
.
span
();
dst
=
static_cast
<
dt_byte
*>
(
tensors_orig
[
1
].
raw_ptr
)
+
span
.
low_byte
;
src
=
static_cast
<
const
dt_byte
*>
(
tensors_cuda
[
1
].
raw_ptr
)
+
span
.
low_byte
;
megdnn_memcpy_D2H
(
handle_cuda
(),
dst
,
src
,
span
.
dist_byte
());
};
DType
dtype
=
dtype
::
Float32
();
checker
.
set_tensors_constraint
(
constraint
)
.
set_dtype
(
0
,
dtype
)
.
set_dtype
(
1
,
dtype
)
.
set_dtype
(
2
,
dtype
)
.
set_dtype
(
3
,
dtype
)
.
set_param
(
arg
.
param
)
.
exec
(
TensorShapeArray
{
ilayout
,
olayout
,
olayout
,
ilayout
});
}
}
}
// namespace test
}
// namespace megdnn
// vim: syntax=cpp.doxygen
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录