Commit 44c8d2d1 authored by Megvii Engine Team

refactor(megdnn): refactor matmul algo in deformable conv

GitOrigin-RevId: 05291baf98f36141ccb6d686e2f92a58766d848c
Parent b04ad06f
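The commit replaces the per-algo `get_matmul_layout` helper with an anonymous-namespace `sub_opr_config` that returns both the batched-matmul layouts and the operator param, and exposes that configuration through a `get_subopr_list` override so the dispatcher can select and profile the `BatchedMatrixMulForward` sub-operator. A minimal, self-contained sketch of the shape logic for the backward-data case (simplified standalone types, not the real MegDNN API):

#include <array>
#include <cstddef>
#include <cstdio>
#include <utility>
#include <vector>

// Simplified stand-ins for TensorLayout and BatchedMatrixMul::Param.
struct Layout {
    std::array<std::size_t, 3> shape;  // [batch, rows, cols]
};
struct BmmParam {
    bool transposeA = false;
    bool transposeB = false;
};

// One function is the single source of truth for the sub-operator:
// get_bundle(), exec() and get_subopr_list() all consume its result.
std::pair<std::vector<Layout>, BmmParam> sub_opr_config(
        std::size_t group, std::size_t icpg, std::size_t ocpg, std::size_t FH,
        std::size_t FW, std::size_t batch_sz, std::size_t OH, std::size_t OW) {
    std::size_t M = icpg * FH * FW, K = ocpg, N = batch_sz * OH * OW;
    BmmParam param;
    param.transposeA = true;  // filter is stored as [group, K, M]
    return {{Layout{{group, K, M}},   // A: filter
             Layout{{group, K, N}},   // B: relayouted output gradient
             Layout{{group, M, N}}},  // C: column-form input gradient
            param};
}

int main() {
    auto config = sub_opr_config(2, 4, 8, 3, 3, 1, 5, 5);
    for (auto&& l : config.first)
        std::printf("[%zu, %zu, %zu]\n", l.shape[0], l.shape[1], l.shape[2]);
}

Centralizing the config in one function keeps get_bundle, exec, and get_subopr_list from drifting apart, which was the standing risk with three separate per-operator get_matmul_layout copies.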
@@ -102,24 +102,24 @@ class DeformableConvBackwardDataImpl::AlgoMatmul final : public AlgoBase {
 private:
     static WorkspaceBundle get_bundle(const SizeArgs& args);
-    static void get_matmul_layout(const SizeArgs& args, TensorLayout& al,
-                                  TensorLayout& bl, TensorLayout& cl);
 
 public:
-    AlgoMatmul() {}
     bool is_available(const SizeArgs& args) const override;
     size_t get_workspace_in_bytes(const SizeArgs& args) const override;
     void exec(const ExecArgs& args) const override;
 
     bool is_reproducible() const override { return true; }
 
-    const char* name() const override { return "AlgoMatmul"; }
+    std::vector<SearchItem> get_subopr_list(
+            const TensorLayoutArray& layouts,
+            const OperatorBase* opr) const override;
+
+    const char* name() const override { return "MATMUL"; }
     MEGDNN_DECL_ALGO_TYPE(CUDA_MATMUL)
 };
 
 class DeformableConvBackwardDataImpl::AlgoPack : NonCopyableObj {
     AlgoBase::Mapper m_all_algos_map;
 
 public:
     AlgoPack();
     AlgoMatmul algo_matmul;
...
@@ -57,24 +57,47 @@ deformable_conv::Param create_param(const Algo::SizeArgs& args,
     return p;
 }
-}; // anonymous namespace
 
-bool Algo::is_available(const SizeArgs&) const {
-    return true;
-}
+std::pair<TensorLayoutArray, BatchedMatrixMulForward::Param> sub_opr_config(
+        const DeformableConvForwardImpl::CanonizedFilterMeta& fm,
+        const TensorLayout& im, const TensorLayout& out_grad) {
+    auto&& dt = im.dtype;
+    size_t batch_sz = im[0], OH = out_grad[2], OW = out_grad[3],
+           FH = fm.spatial[0], FW = fm.spatial[1];
+    size_t M = fm.icpg * FH * FW, K = fm.ocpg, N = batch_sz * OH * OW,
+           batch = fm.group;
+    TensorLayout al = {{batch, K, M}, dt};
+    TensorLayout bl = {{batch, K, N}, dt};
+    TensorLayout cl = {{batch, M, N}, dt};
+    BatchedMatrixMulForward::Param param;
+    param.compute_mode = param::MatrixMul::ComputeMode::DEFAULT;
+    param.transposeA = true;
+    return {{al, bl, cl}, param};
+}
+}; // anonymous namespace
 
-void Algo::get_matmul_layout(const SizeArgs& args, TensorLayout& al,
-                             TensorLayout& bl, TensorLayout& cl) {
-    auto&& dt = args.im_layout.dtype;
-    auto&& fm = args.filter_meta;
-    size_t batch_sz = args.im_layout[0], OH = args.out_grad_layout[2],
-           OW = args.out_grad_layout[3], FH = fm.spatial[0], FW = fm.spatial[1];
-    size_t M = fm.icpg * FH * FW, K = fm.ocpg, N = batch_sz * OH * OW,
-           batch = fm.group;
-    al = {{batch, K, M}, dt};
-    bl = {{batch, K, N}, dt};
-    cl = {{batch, M, N}, dt};
+std::vector<Algorithm::SearchItem> Algo::get_subopr_list(
+        const TensorLayoutArray& layouts, const OperatorBase* opr) const {
+    const DeformableConvBackwardDataImpl* deformable_conv =
+            static_cast<const DeformableConvBackwardDataImpl*>(opr);
+    CanonizedFilterMeta fm = deformable_conv->make_canonized_filter_meta(
+            layouts[0].ndim, layouts[1], layouts[2]);
+    auto&& config = sub_opr_config(fm, layouts[0], layouts[4]);
+
+    std::string param_str;
+    Algorithm::serialize_write_pod(config.second, param_str);
+    return {{Algorithm::OprType::BATCHED_MATRIX_MUL_FORWARD, param_str,
+             config.first}};
+}
+
+bool Algo::is_available(const SizeArgs&) const {
+    return true;
 }
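In this backward-data case the stored filter per group is [K = ocpg, M = icpg*FH*FW], so `transposeA` makes the batched matmul compute C = A^T * B: contracting over the output channels yields the column-form input gradient that the exec comment below describes as [g, icpg, FH, FW, N, OH, OW], which a col2im-style step then scatters back (outside these hunks). A toy shape walk-through with hypothetical sizes:

#include <cstdio>

int main() {
    // Hypothetical sizes: group=1, icpg=16, ocpg=32, 3x3 filter,
    // batch_sz=8, 14x14 output gradient.
    unsigned M = 16 * 3 * 3;   // icpg * FH * FW = 144
    unsigned K = 32;           // ocpg, the contracted dimension
    unsigned N = 8 * 14 * 14;  // batch_sz * OH * OW = 1568
    // al = {1, K, M}: filter as stored; transposeA makes it act as [M, K].
    // bl = {1, K, N}: output gradient, one column per spatial position.
    // cl = {1, M, N}: column-form input gradient.
    std::printf("A^T[%u,%u] x B[%u,%u] -> C[%u,%u]\n", M, K, K, N, M, N);
}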
@@ -83,14 +106,20 @@ WorkspaceBundle Algo::get_bundle(const SizeArgs& args) {
            OC = args.out_grad_layout[1], OH = args.out_grad_layout[2],
            OW = args.out_grad_layout[3], FH = fm.spatial[0], FW = fm.spatial[1];
 
-    auto&& bmm_opr = args.handle->create_operator<BatchedMatrixMulForward>();
-    TensorLayout al, bl, cl;
+    auto bmatmul_opr = args.handle->create_operator<BatchedMatrixMulForward>();
+    if (args.opr->execution_policy().algo.valid() &&
+        !args.opr->execution_policy().sub_policy.empty()) {
+        megdnn_assert(args.opr->execution_policy().sub_policy.size() == 1);
+        bmatmul_opr->execution_policy() =
+                args.opr->execution_policy().sub_policy[0];
+    }
 
-    get_matmul_layout(args, al, bl, cl);
-    bmm_opr->param().compute_mode = param::MatrixMul::ComputeMode::DEFAULT;
-    bmm_opr->param().transposeA = true;
+    auto&& config = sub_opr_config(args.filter_meta, args.im_layout,
+                                   args.out_grad_layout);
+    bmatmul_opr->param() = config.second;
 
-    size_t bmm_ws = bmm_opr->get_workspace_in_bytes(al, bl, cl);
+    size_t bmm_ws = bmatmul_opr->get_workspace_in_bytes(
+            config.first[0], config.first[1], config.first[2]);
     size_t result_ws = batch_sz * IC * FH * FW * OH * OW * sizeof(float);
     size_t relayout_ws1 = batch_sz * OC * OH * OW * sizeof(float);
     size_t relayout_ws2 = batch_sz * IC * FH * FW * OH * OW * sizeof(float);
@@ -154,21 +183,24 @@ void Algo::exec(const ExecArgs& args) const {
     // matmul [g, icpg, FH, FW, ocpg] * [g, ocpg, N, OH, OW] =>
     // => [g, icpg, FH, FW, N, OH, OW]
     {
-        TensorLayout al, bl, cl;
-        get_matmul_layout(args, al, bl, cl);
-
-        TensorND A(static_cast<void*>(dev_filter), al),
-                B(static_cast<void*>(relayout_ws1), bl),
-                C(static_cast<void*>(result_ws), cl);
-
-        size_t bmm_ws_size = bundle.get_size(0);
-        auto&& bmm_opr =
+        auto bmatmul_opr =
                 args.handle->create_operator<BatchedMatrixMulForward>();
+        if (args.opr->execution_policy().algo.valid()) {
+            megdnn_assert(args.opr->execution_policy().sub_policy.size() == 1);
+            bmatmul_opr->execution_policy() =
+                    args.opr->execution_policy().sub_policy[0];
+        }
 
-        bmm_opr->param().compute_mode = param::MatrixMul::ComputeMode::DEFAULT;
-        bmm_opr->param().transposeA = true;
+        auto&& config = sub_opr_config(args.filter_meta, args.im_layout,
+                                       args.out_grad_layout);
+        bmatmul_opr->param() = config.second;
 
-        bmm_opr->exec(
+        TensorND A(static_cast<void*>(dev_filter), config.first[0]),
+                B(static_cast<void*>(relayout_ws1), config.first[1]),
+                C(static_cast<void*>(result_ws), config.first[2]);
+        size_t bmm_ws_size = bundle.get_size(0);
+        bmatmul_opr->exec(
                 A, B, C,
                 Workspace(static_cast<megdnn::dt_byte*>(bmm_ws), bmm_ws_size));
     }
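Note how both get_bundle and exec now forward `sub_policy[0]` into the `BatchedMatrixMulForward` they create: the matmul algorithm chosen by the dispatcher is the one whose workspace was measured and the one that actually runs. A simplified model of the nested-policy invariant being asserted (not the real `megdnn::ExecutionPolicy`; the algorithm names are hypothetical):

#include <cassert>
#include <string>
#include <vector>

struct AlgoDesc {
    std::string name;  // empty means "not chosen", i.e. !valid()
    bool valid() const { return !name.empty(); }
};

struct ExecutionPolicy {
    AlgoDesc algo;
    std::vector<ExecutionPolicy> sub_policy;  // one entry per sub-operator
};

int main() {
    ExecutionPolicy conv_policy;
    conv_policy.algo = {"MATMUL"};  // the deformable-conv algo by name
    // Hypothetical sub-algorithm for the BatchedMatrixMul:
    conv_policy.sub_policy.push_back({AlgoDesc{"BMM_DEFAULT"}, {}});
    // Mirrors the guard in exec(): a chosen top-level algo implies
    // exactly one sub-policy to hand to the created BatchedMatrixMul.
    if (conv_policy.algo.valid())
        assert(conv_policy.sub_policy.size() == 1);
}

The extra `!sub_policy.empty()` check in get_bundle, absent from exec, presumably tolerates workspace-size queries made before a sub-policy has been attached.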
...
@@ -92,20 +92,20 @@ public:
 class DeformableConvBackwardFilterImpl::AlgoMatmul final : public AlgoBase {
 private:
-    static void get_matmul_layout(const SizeArgs& args, TensorLayout& al,
-                                  TensorLayout& bl, TensorLayout& cl);
     static WorkspaceBundle get_bundle(const SizeArgs& args);
 
 public:
-    AlgoMatmul() {}
     bool is_available(const SizeArgs& args) const override;
     size_t get_workspace_in_bytes(const SizeArgs& args) const override;
     void exec(const ExecArgs& args) const override;
 
     bool is_reproducible() const override { return true; }
 
-    const char* name() const override { return "AlgoMatmul"; }
+    std::vector<SearchItem> get_subopr_list(
+            const TensorLayoutArray& layouts,
+            const OperatorBase* opr) const override;
+
+    const char* name() const override { return "MATMUL"; }
     MEGDNN_DECL_ALGO_TYPE(CUDA_MATMUL)
 };
...
@@ -6,7 +6,8 @@
  *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
  */
 
 #include "src/cuda/utils.h"
@@ -57,25 +58,46 @@ deformable_conv::Param create_param(const Algo::SizeArgs& args,
     return p;
 }
-}; // anonymous namespace
 
-bool Algo::is_available(const SizeArgs&) const {
-    return true;
-}
+std::pair<TensorLayoutArray, BatchedMatrixMulForward::Param> sub_opr_config(
+        const DeformableConvBackwardFilterImpl::CanonizedFilterMeta& fm,
+        const TensorLayout& im, const TensorLayout& out_grad) {
+    auto&& dt = im.dtype;
+    size_t batch_sz = im[0], OH = out_grad[2], OW = out_grad[3],
+           FH = fm.spatial[0], FW = fm.spatial[1];
+    size_t M = fm.ocpg, K = OH * OW * batch_sz, N = fm.icpg * FH * FW,
+           batch = fm.group;
+    TensorLayout al = {{batch, M, K}, dt};
+    TensorLayout bl = {{batch, N, K}, dt};
+    TensorLayout cl = {{batch, M, N}, dt};
+    BatchedMatrixMulForward::Param param;
+    param.compute_mode = param::MatrixMul::ComputeMode::DEFAULT;
+    param.transposeB = true;
+    return {{al, bl, cl}, param};
+}
+}; // anonymous namespace
 
-void Algo::get_matmul_layout(const SizeArgs& args, TensorLayout& al,
-                             TensorLayout& bl, TensorLayout& cl) {
-    auto&& dt = args.im_layout.dtype;
-    auto&& fm = args.filter_grad_meta;
-    size_t batch_sz = args.im_layout[0], OH = args.out_grad_layout[2],
-           OW = args.out_grad_layout[3], FH = fm.spatial[0], FW = fm.spatial[1];
-    size_t M = fm.ocpg, K = OH * OW * batch_sz, N = fm.icpg * FH * FW,
-           batch = fm.group;
-
-    al = {{batch, M, K}, dt};
-    bl = {{batch, N, K}, dt};
-    cl = {{batch, M, N}, dt};
+std::vector<Algorithm::SearchItem> Algo::get_subopr_list(
+        const TensorLayoutArray& layouts, const OperatorBase* opr) const {
+    const DeformableConvBackwardFilterImpl* deformable_conv =
+            static_cast<const DeformableConvBackwardFilterImpl*>(opr);
+    CanonizedFilterMeta fm = deformable_conv->make_canonized_filter_meta(
+            layouts[0].ndim, layouts[4], layouts[1]);
+    auto&& config = sub_opr_config(fm, layouts[0], layouts[3]);
+
+    std::string param_str;
+    Algorithm::serialize_write_pod(config.second, param_str);
+    return {{Algorithm::OprType::BATCHED_MATRIX_MUL_FORWARD, param_str,
+             config.first}};
+}
+
+bool Algo::is_available(const SizeArgs&) const {
+    return true;
 }
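The filter-gradient case contracts over every output position (K = OH*OW*batch_sz) and sets `transposeB`, so per group the product is C = A * B^T, with A the relayouted output gradient and B the im2col buffer. A toy reference of that contraction (plain float loops, hypothetical sizes):

#include <cassert>
#include <vector>

int main() {
    // Toy stand-ins: M = ocpg, N = icpg*FH*FW, K = batch_sz*OH*OW.
    const int M = 2, N = 3, K = 4;
    std::vector<float> A(M * K, 1.f), B(N * K, 2.f), C(M * N, 0.f);
    for (int m = 0; m < M; ++m)
        for (int n = 0; n < N; ++n)
            for (int k = 0; k < K; ++k)
                // B indexed as [n][k]: the transposeB contraction C = A*B^T.
                C[m * N + n] += A[m * K + k] * B[n * K + k];
    for (float c : C)
        assert(c == 8.f);  // K * 1 * 2, exact in float
}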
@@ -85,16 +107,22 @@ WorkspaceBundle Algo::get_bundle(const SizeArgs& args) {
     size_t IC = fm.group * fm.icpg, OC = args.out_grad_layout[1];
     auto batch_sz = args.im_layout[0];
 
-    auto&& bmm_opr = args.handle->create_operator<BatchedMatrixMulForward>();
-    TensorLayout al, bl, cl;
+    auto bmatmul_opr = args.handle->create_operator<BatchedMatrixMulForward>();
+    if (args.opr->execution_policy().algo.valid() &&
+        !args.opr->execution_policy().sub_policy.empty()) {
+        megdnn_assert(args.opr->execution_policy().sub_policy.size() == 1);
+        bmatmul_opr->execution_policy() =
+                args.opr->execution_policy().sub_policy[0];
+    }
 
-    get_matmul_layout(args, al, bl, cl);
-    bmm_opr->param().compute_mode = param::MatrixMul::ComputeMode::DEFAULT;
-    bmm_opr->param().transposeB = true;
+    auto&& config = sub_opr_config(args.filter_grad_meta, args.im_layout,
+                                   args.out_grad_layout);
+    bmatmul_opr->param() = config.second;
 
     size_t col_ws = batch_sz * IC * FH * FW * OH * OW * sizeof(float);
     size_t out_grad_ws = batch_sz * OC * OH * OW * sizeof(float);
-    size_t bmm_ws = bmm_opr->get_workspace_in_bytes(al, bl, cl);
+    size_t bmm_ws = bmatmul_opr->get_workspace_in_bytes(
+            config.first[0], config.first[1], config.first[2]);
 
     return {nullptr, {col_ws, out_grad_ws, bmm_ws}};
 }
@@ -138,20 +166,23 @@ void Algo::exec(const ExecArgs& args) const {
     args.handle->relayout_opr()->exec(C2, C3);
 
     // matmul
-    TensorLayout al, bl, cl;
-    get_matmul_layout(args, al, bl, cl);
-
-    TensorND A(static_cast<void*>(out_grad_ws), al),
-            B(static_cast<void*>(col_ws), bl),
-            C(static_cast<void*>(dev_filter_grad), cl);
-
-    size_t bmm_ws_size = bundle.get_size(2);
-    auto&& bmm_opr = args.handle->create_operator<BatchedMatrixMulForward>();
+    auto bmatmul_opr = args.handle->create_operator<BatchedMatrixMulForward>();
+    if (args.opr->execution_policy().algo.valid()) {
+        megdnn_assert(args.opr->execution_policy().sub_policy.size() == 1);
+        bmatmul_opr->execution_policy() =
+                args.opr->execution_policy().sub_policy[0];
+    }
 
-    bmm_opr->param().compute_mode = param::MatrixMul::ComputeMode::DEFAULT;
-    bmm_opr->param().transposeB = true;
+    auto&& config = sub_opr_config(args.filter_grad_meta, args.im_layout,
+                                   args.out_grad_layout);
+    bmatmul_opr->param() = config.second;
 
-    bmm_opr->exec(
+    TensorND A(static_cast<void*>(out_grad_ws), config.first[0]),
+            B(static_cast<void*>(col_ws), config.first[1]),
+            C(static_cast<void*>(dev_filter_grad), config.first[2]);
+    size_t bmm_ws_size = bundle.get_size(2);
+    bmatmul_opr->exec(
             A, B, C,
             Workspace(static_cast<megdnn::dt_byte*>(bmm_ws), bmm_ws_size));
 }
...
@@ -87,20 +87,20 @@ public:
 class DeformableConvForwardImpl::AlgoMatmul final : public AlgoBase {
 private:
-    static void get_matmul_layout(const SizeArgs& args, TensorLayout& al,
-                                  TensorLayout& bl, TensorLayout& cl);
     static WorkspaceBundle get_bundle(const SizeArgs& args);
 
 public:
-    AlgoMatmul(){};
     bool is_available(const SizeArgs& args) const override;
     size_t get_workspace_in_bytes(const SizeArgs& args) const override;
     void exec(const ExecArgs& args) const override;
 
     bool is_reproducible() const override { return true; }
 
-    const char* name() const override { return "AlgoMatmul"; }
+    std::vector<SearchItem> get_subopr_list(
+            const TensorLayoutArray& layouts,
+            const OperatorBase* opr) const override;
+
+    const char* name() const override { return "MATMUL"; }
     MEGDNN_DECL_ALGO_TYPE(CUDA_MATMUL)
 };
...
@@ -57,24 +57,47 @@ deformable_conv::Param create_param(const Algo::SizeArgs& args,
     return p;
 }
 
+std::pair<TensorLayoutArray, BatchedMatrixMulForward::Param> sub_opr_config(
+        const DeformableConvForwardImpl::CanonizedFilterMeta& fm,
+        const TensorLayout& im, const TensorLayout& dst) {
+    auto&& dt = im.dtype;
+    size_t batch_sz = im[0], OH = dst[2], OW = dst[3], FH = fm.spatial[0],
+           FW = fm.spatial[1];
+    size_t M = fm.ocpg, N = OH * OW * batch_sz, K = fm.icpg * FH * FW,
+           batch = fm.group;
+    TensorLayout al = {{batch, M, K}, dt};
+    TensorLayout bl = {{batch, K, N}, dt};
+    TensorLayout cl = {{batch, M, N}, dt};
+    BatchedMatrixMulForward::Param param;
+    param.compute_mode = param::MatrixMul::ComputeMode::DEFAULT;
+    return {{al, bl, cl}, param};
+}
+
 }; // anonymous namespace
 
-bool Algo::is_available(const SizeArgs&) const {
-    return true;
+std::vector<Algorithm::SearchItem> Algo::get_subopr_list(
+        const TensorLayoutArray& layouts, const OperatorBase* opr) const {
+    const DeformableConvForwardImpl* deformable_conv =
+            static_cast<const DeformableConvForwardImpl*>(opr);
+    CanonizedFilterMeta fm = deformable_conv->make_canonized_filter_meta(
+            layouts[0].ndim, layouts[1], layouts[2]);
+    auto&& config = sub_opr_config(fm, layouts[0], layouts[4]);
+
+    std::string param_str;
+    Algorithm::serialize_write_pod(config.second, param_str);
+    return {{Algorithm::OprType::BATCHED_MATRIX_MUL_FORWARD, param_str,
+             config.first}};
 }
 
-void Algo::get_matmul_layout(const SizeArgs& args, TensorLayout& al,
-                             TensorLayout& bl, TensorLayout& cl) {
-    auto&& dt = args.im_layout.dtype;
-    auto&& fm = args.filter_meta;
-    size_t batch_sz = args.im_layout[0], OH = args.dst_layout[2],
-           OW = args.dst_layout[3], FH = fm.spatial[0], FW = fm.spatial[1];
-    size_t M = fm.ocpg, N = OH * OW * batch_sz, K = fm.icpg * FH * FW,
-           batch = fm.group;
-    al = {{batch, M, K}, dt};
-    bl = {{batch, K, N}, dt};
-    cl = {{batch, M, N}, dt};
+bool Algo::is_available(const SizeArgs&) const {
+    return true;
 }
@@ -83,17 +106,24 @@ WorkspaceBundle Algo::get_bundle(const SizeArgs& args) {
            OC = args.dst_layout[1], OH = args.dst_layout[2],
            OW = args.dst_layout[3], FH = fm.spatial[0], FW = fm.spatial[1];
 
-    auto&& bmm_opr = args.handle->create_operator<BatchedMatrixMulForward>();
-    TensorLayout al, bl, cl;
+    auto bmatmul_opr = args.handle->create_operator<BatchedMatrixMulForward>();
+    if (args.opr->execution_policy().algo.valid() &&
+        !args.opr->execution_policy().sub_policy.empty()) {
+        megdnn_assert(args.opr->execution_policy().sub_policy.size() == 1);
+        bmatmul_opr->execution_policy() =
+                args.opr->execution_policy().sub_policy[0];
+    }
 
-    get_matmul_layout(args, al, bl, cl);
-    bmm_opr->param().compute_mode = param::MatrixMul::ComputeMode::DEFAULT;
+    auto&& config =
+            sub_opr_config(args.filter_meta, args.im_layout, args.dst_layout);
+    bmatmul_opr->param() = config.second;
 
     size_t col_ws = batch_sz * IC * FH * FW * OH * OW * sizeof(float);
-    size_t bmm_ws = bmm_opr->get_workspace_in_bytes(al, bl, cl);
+    size_t bmm_ws = bmatmul_opr->get_workspace_in_bytes(
+            config.first[0], config.first[1], config.first[2]);
     size_t result_ws = batch_sz * OC * OH * OW * sizeof(float);
 
-    return {nullptr, {col_ws, bmm_ws, result_ws}};
+    return WorkspaceBundle{nullptr, {col_ws, bmm_ws, result_ws}};
 }
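Note that get_bundle builds the bundle over a null base pointer, so it can answer size queries alone; a real pointer is attached later when sub-buffers are carved out. A minimal sketch of that idiom, assuming simplified (unaligned) WorkspaceBundle semantics rather than the real MegDNN class:

#include <cstddef>
#include <cstdint>
#include <vector>

// Simplified bundle: contiguous sub-buffers carved from one allocation.
class WorkspaceBundle {
    void* m_base;
    std::vector<std::size_t> m_sizes;

public:
    WorkspaceBundle(void* base, std::vector<std::size_t> sizes)
            : m_base(base), m_sizes(std::move(sizes)) {}
    std::size_t total_size_in_bytes() const {
        std::size_t s = 0;
        for (auto x : m_sizes) s += x;  // the real class also aligns chunks
        return s;
    }
    void* get(std::size_t i) const {
        auto p = static_cast<std::uint8_t*>(m_base);
        for (std::size_t j = 0; j < i; ++j) p += m_sizes[j];
        return p;
    }
};

int main() {
    // With a null base the bundle still answers size queries, which is
    // how a size-only bundle serves get_workspace_in_bytes().
    WorkspaceBundle sizes_only{nullptr, {128, 256, 64}};
    return sizes_only.total_size_in_bytes() == 448 ? 0 : 1;
}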
@@ -123,18 +153,25 @@ void Algo::exec(const ExecArgs& args) const {
     // im2col
     deformable_conv::im2col(dev_im, dev_offset, dev_mask,
                             static_cast<float*>(col_ws), p);
 
-    // matmul
-    TensorLayout al, bl, cl;
-    get_matmul_layout(args, al, bl, cl);
+    auto bmatmul_opr = args.handle->create_operator<BatchedMatrixMulForward>();
+    if (args.opr->execution_policy().algo.valid()) {
+        megdnn_assert(args.opr->execution_policy().sub_policy.size() == 1);
+        bmatmul_opr->execution_policy() =
+                args.opr->execution_policy().sub_policy[0];
+    }
 
-    TensorND A(static_cast<void*>(dev_filter), al),
-            B(static_cast<void*>(col_ws), bl),
-            C(static_cast<void*>(result_ws), cl);
+    auto&& config =
+            sub_opr_config(args.filter_meta, args.im_layout, args.dst_layout);
+    bmatmul_opr->param() = config.second;
+
+    // matmul
+    TensorND A(static_cast<void*>(dev_filter), config.first[0]),
+            B(static_cast<void*>(col_ws), config.first[1]),
+            C(static_cast<void*>(result_ws), config.first[2]);
 
     size_t bmm_ws_size = bundle.get_size(1);
-    auto&& bmm_opr = args.handle->create_operator<BatchedMatrixMulForward>();
-    bmm_opr->param().compute_mode = param::MatrixMul::ComputeMode::DEFAULT;
-    bmm_opr->exec(
+    bmatmul_opr->exec(
             A, B, C,
             Workspace(static_cast<megdnn::dt_byte*>(bmm_ws), bmm_ws_size));
 
     // relayout
...
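Reading the forward path end to end: im2col expands the deformably sampled input into a [g, icpg*FH*FW, N*OH*OW] column buffer, the batched matmul applies the [g, ocpg, icpg*FH*FW] filter to it, and the relayout rearranges the [g, ocpg, N, OH, OW] result into the NCHW dst. A toy shape trace with hypothetical sizes:

#include <cstddef>
#include <cstdio>

int main() {
    std::size_t g = 2, icpg = 8, ocpg = 16, FH = 3, FW = 3;
    std::size_t batch = 4, OH = 7, OW = 7;
    std::size_t M = ocpg, K = icpg * FH * FW, N = OH * OW * batch;
    std::printf("col   : [%zu, %zu, %zu]\n", g, K, N);  // im2col output
    std::printf("filter: [%zu, %zu, %zu]\n", g, M, K);  // matmul A
    std::printf("bmm   : [%zu, %zu, %zu]\n", g, M, N);  // matmul C
    std::printf("dst   : [%zu, %zu, %zu, %zu]\n",       // after relayout
                batch, g * ocpg, OH, OW);
}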