Commit 7afa422d authored by Megvii Engine Team

refactor(megdnn): refactor sub opr setter

GitOrigin-RevId: 475afb9c10f66f7aba41e164866364aa158dd13d
Parent: 821656aa
@@ -14,6 +14,7 @@
 #include <functional>
 #include <string>
+#include <tuple>
 #include "megdnn/oprs/base.h"
 #include "src/common/utils.h"
@@ -83,6 +84,29 @@ public:
     }
 };
+
+template <std::size_t I = 0, typename Opr, typename... Tp>
+inline typename std::enable_if<I == sizeof...(Tp), void>::type
+set_sub_execution_policy(const Opr*, std::tuple<Tp...>&) {}
+
+template <std::size_t I = 0, typename Opr, typename... Tp>
+inline typename std::enable_if<(I < sizeof...(Tp)), void>::type
+set_sub_execution_policy(const Opr* opr, std::tuple<Tp...>& t) {
+    std::get<I>(t)->execution_policy() = opr->execution_policy().sub_policy[I];
+    set_sub_execution_policy<I + 1, Opr, Tp...>(opr, t);
+}
+
+template <typename Opr, typename... SubOpr>
+void set_execution_policy(const Opr* opr, SubOpr... sub_oprs) {
+    if (opr->execution_policy().algo.valid() &&
+        !opr->execution_policy().sub_policy.empty()) {
+        megdnn_assert(opr->execution_policy().sub_policy.size() ==
+                      sizeof...(sub_oprs));
+        auto&& sub = std::make_tuple(sub_oprs...);
+        set_sub_execution_policy<0, Opr, SubOpr...>(opr, sub);
+    }
+}
 }  // namespace megdnn

 namespace std {
...
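The `set_sub_execution_policy` pair added above is a C++14 compile-time loop: `enable_if` selects the recursive overload while `I < sizeof...(Tp)` and the empty base case once the index reaches the tuple size, so the I-th sub operator in the tuple receives `sub_policy[I]` from its parent. A minimal self-contained sketch of the same idiom, with mock types standing in for megdnn's (nothing below is megdnn API):

#include <cstddef>
#include <cstdio>
#include <tuple>
#include <type_traits>
#include <vector>

struct Policy { int algo = -1; };

// Mock sub operator: holds one policy, like execution_policy() in megdnn.
struct MockOpr {
    Policy policy;
    Policy& execution_policy() { return policy; }
};

// Mock parent operator: one policy entry per sub operator.
struct MockParent {
    std::vector<Policy> sub_policy;
};

// Base case: the index reached the tuple size, nothing left to assign.
template <std::size_t I = 0, typename... Tp>
typename std::enable_if<I == sizeof...(Tp), void>::type
set_sub_policy(const MockParent&, std::tuple<Tp...>&) {}

// Recursive case: assign sub_policy[I] to the I-th operator, then recurse.
template <std::size_t I = 0, typename... Tp>
typename std::enable_if<(I < sizeof...(Tp)), void>::type
set_sub_policy(const MockParent& parent, std::tuple<Tp...>& t) {
    std::get<I>(t)->execution_policy() = parent.sub_policy[I];
    set_sub_policy<I + 1, Tp...>(parent, t);
}

int main() {
    MockParent parent{{Policy{7}, Policy{42}}};  // one entry per sub opr
    MockOpr a, b;
    auto oprs = std::make_tuple(&a, &b);
    set_sub_policy(parent, oprs);  // starts at I = 0 by default
    std::printf("%d %d\n", a.policy.algo, b.policy.algo);  // prints: 7 42
}

Under C++17 the two overloads would collapse into a single function using `if constexpr`; the SFINAE split keeps the helper usable on pre-C++17 toolchains.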
@@ -8,9 +8,12 @@
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  */
+#include <algorithm>
+#include <memory>
 #include "./algo.h"
 #include "megdnn/opr_param_defs.h"
 #include "src/common/algo_chooser.h"
+#include "src/common/algo_base.h"
 #include "src/cuda/handle.h"
 #include "src/cuda/utils.h"
@@ -27,6 +30,20 @@ std::pair<TensorLayoutArray, MatrixMulForward::Param> sub_opr_config(
     return {{mm_layout_a, mm_layout_b, mm_layout_c}, opr->param()};
 }
+
+std::pair<TensorLayoutArray, std::unique_ptr<MatrixMulForward>> prepare_sub_opr(
+        const BatchedMatrixMulForwardImpl::AlgoBase::SizeArgs& args) {
+    auto matmul_opr = args.opr->handle()->create_operator<MatrixMulForward>();
+    set_execution_policy<BatchedMatrixMulForward, MatrixMulForward*>(
+            args.opr, matmul_opr.get());
+    auto&& config = sub_opr_config(args.layout_a, args.layout_b, args.layout_c,
+                                   args.opr);
+    matmul_opr->param() = config.second;
+    return {config.first, std::move(matmul_opr)};
+}
 }  // namespace

 std::vector<Algorithm::SearchItem>
@@ -43,51 +60,23 @@ BatchedMatrixMulForwardImpl::AlgoBruteForce::get_subopr_list(
 bool BatchedMatrixMulForwardImpl::AlgoBruteForce::is_available(
         const SizeArgs& args) const {
-    auto matmul_opr = args.opr->handle()->create_operator<MatrixMulForward>();
-    if (args.opr->execution_policy().algo.valid() &&
-        !args.opr->execution_policy().sub_policy.empty()) {
-        megdnn_assert(args.opr->execution_policy().sub_policy.size() == 1);
-        matmul_opr->execution_policy() =
-                args.opr->execution_policy().sub_policy[0];
-    }
-    auto&& config = sub_opr_config(args.layout_a, args.layout_b, args.layout_c,
-                                   args.opr);
-    matmul_opr->param() = config.second;
-    return get_algorithm(static_cast<MatrixMulForwardImpl*>(matmul_opr.get()),
-                         config.first[0], config.first[1], config.first[2]);
+    auto config = prepare_sub_opr(args);
+    return get_algorithm(
+            static_cast<MatrixMulForwardImpl*>(config.second.get()),
+            config.first[0], config.first[1], config.first[2]);
 }
 size_t BatchedMatrixMulForwardImpl::AlgoBruteForce::get_workspace_in_bytes(
         const SizeArgs& args) const {
-    auto matmul_opr = args.opr->handle()->create_operator<MatrixMulForward>();
-    if (args.opr->execution_policy().algo.valid() &&
-        !args.opr->execution_policy().sub_policy.empty()) {
-        megdnn_assert(args.opr->execution_policy().sub_policy.size() == 1);
-        matmul_opr->execution_policy() =
-                args.opr->execution_policy().sub_policy[0];
-    }
-    auto&& config = sub_opr_config(args.layout_a, args.layout_b, args.layout_c,
-                                   args.opr);
-    matmul_opr->param() = config.second;
-    return matmul_opr->get_workspace_in_bytes(config.first[0], config.first[1],
-                                              config.first[2]);
+    auto config = prepare_sub_opr(args);
+    return config.second->get_workspace_in_bytes(
+            config.first[0], config.first[1], config.first[2]);
 }
 void BatchedMatrixMulForwardImpl::AlgoBruteForce::exec(
         const ExecArgs& args) const {
     auto N = args.layout_a.shape[0];
-    auto matmul_opr = args.opr->handle()->create_operator<MatrixMulForward>();
-    if (args.opr->execution_policy().algo.valid()) {
-        megdnn_assert(args.opr->execution_policy().sub_policy.size() == 1);
-        matmul_opr->execution_policy() =
-                args.opr->execution_policy().sub_policy[0];
-    }
-    auto&& config = sub_opr_config(args.layout_a, args.layout_b, args.layout_c,
-                                   args.opr);
-    matmul_opr->param() = config.second;
+    auto config = prepare_sub_opr(args);

     rep(n, N) {
         TensorND A_, B_, C_;
@@ -100,6 +89,6 @@ void BatchedMatrixMulForwardImpl::AlgoBruteForce::exec(
         tensor_n_from_batch(args.tensor_a, A_);
         tensor_n_from_batch(args.tensor_b, B_);
         tensor_n_from_batch(args.tensor_c, C_);
-        matmul_opr->exec(A_, B_, C_, args.workspace);
+        config.second->exec(A_, B_, C_, args.workspace);
     }
 }
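Every file below repeats the same mechanical change: a file-local `prepare_sub_opr()` like the one above creates the sub operator, forwards the parent's `sub_policy` through `set_execution_policy`, applies the layouts/param pair from `sub_opr_config()`, and returns both; `is_available`, the workspace queries, and `exec` then shrink to a single `prepare_sub_opr(args)` call each. A compilable sketch of that shape on mock types (illustrative stand-ins only, none of these are megdnn classes):

#include <cstdio>
#include <memory>
#include <utility>
#include <vector>

struct Policy { int algo = -1; };

struct ExecutionPolicy {
    Policy algo;                     // chosen algorithm of this operator
    std::vector<Policy> sub_policy;  // one entry per sub operator
};

struct MatMulMock {  // the sub operator
    int param = 0;
    Policy policy;
    Policy& execution_policy() { return policy; }
    void exec() const { std::printf("exec: param=%d algo=%d\n", param, policy.algo); }
};

struct HandleMock {
    template <typename T>
    std::unique_ptr<T> create_operator() const { return std::make_unique<T>(); }
};

struct ParentMock {  // the operator being implemented
    ExecutionPolicy policy;
    const ExecutionPolicy& execution_policy() const { return policy; }
};

struct SizeArgs {
    const ParentMock* opr;
    const HandleMock* handle;
};

// sub_opr_config stand-in: derive the sub layouts and param from the args.
using Layouts = std::vector<int>;
std::pair<Layouts, int> sub_opr_config(const SizeArgs&) { return {{1, 2, 3}, 9}; }

// The shape of every prepare_sub_opr() in this commit.
std::pair<Layouts, std::unique_ptr<MatMulMock>> prepare_sub_opr(
        const SizeArgs& args) {
    auto sub = args.handle->create_operator<MatMulMock>();
    // What set_execution_policy does for a single sub operator.
    if (args.opr->execution_policy().algo.algo != -1 &&
        !args.opr->execution_policy().sub_policy.empty())
        sub->execution_policy() = args.opr->execution_policy().sub_policy[0];
    auto&& config = sub_opr_config(args);
    sub->param = config.second;
    return {config.first, std::move(sub)};
}

int main() {
    HandleMock handle;
    ParentMock parent{{Policy{5}, {Policy{7}}}};
    SizeArgs args{&parent, &handle};
    auto config = prepare_sub_opr(args);  // each call site is now one line
    config.second->exec();                // prints: exec: param=9 algo=7
}

Returning the layouts and the configured operator as one pair keeps them from drifting apart at call sites, which is where the duplicated blocks this commit removes had already diverged (some checked `sub_policy.empty()` before indexing, some did not).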
@@ -11,6 +11,7 @@
  */
 #include "src/common/algo_chooser.h"
+#include "src/common/algo_base.h"
 #include "src/common/conv_bias.h"
 #include "src/cuda/batched_matrix_mul/algo.h"
 #include "src/cuda/conv_bias/algo.h"
@@ -51,6 +52,19 @@ std::pair<TensorLayoutArray, MatrixMulForward::Param> sub_opr_config(
     return {{A, B, C}, param};
 }
+
+std::pair<TensorLayoutArray, std::unique_ptr<BatchedMatrixMulForward>>
+prepare_sub_opr(const ConvBiasForwardImpl::AlgoBase::SizeArgs& args) {
+    auto bmatmul_opr = args.handle->create_operator<BatchedMatrixMulForward>();
+    set_execution_policy<ConvBiasForward, BatchedMatrixMulForward*>(
+            args.opr, bmatmul_opr.get());
+    auto&& config =
+            sub_opr_config(args.filter_meta, *args.src_layout,
+                           *args.filter_layout, *args.dst_layout, args.opr);
+    bmatmul_opr->param() = config.second;
+    return {config.first, std::move(bmatmul_opr)};
+}
 }  // namespace

 std::vector<Algorithm::SearchItem>
@@ -74,18 +88,7 @@ bool ConvBiasForwardImpl::AlgoBatchedMatmul::is_available(
     if (args.z_layout->ndim > 0)
         return false;

-    auto bmatmul_opr = args.handle->create_operator<BatchedMatrixMulForward>();
-    if (args.opr->execution_policy().algo.valid() &&
-        !args.opr->execution_policy().sub_policy.empty()) {
-        megdnn_assert(args.opr->execution_policy().sub_policy.size() == 1);
-        bmatmul_opr->execution_policy() =
-                args.opr->execution_policy().sub_policy[0];
-    }
-    auto&& config =
-            sub_opr_config(args.filter_meta, *args.src_layout,
-                           *args.filter_layout, *args.dst_layout, args.opr);
-    bmatmul_opr->param() = config.second;
+    auto config = prepare_sub_opr(args);

     auto&& fm = args.filter_meta;
     return fm.format == Param::Format::NCHW &&
@@ -95,9 +98,9 @@ bool ConvBiasForwardImpl::AlgoBatchedMatmul::is_available(
            fm.dilation[1] == 1 && fm.spatial[0] == 1 && fm.spatial[1] == 1 &&
            fm.padding[0] == 0 && fm.padding[1] == 0 && fm.stride[0] == 1 &&
            fm.stride[1] == 1 &&
-           get_algorithm(
-                   static_cast<BatchedMatrixMulForwardImpl*>(bmatmul_opr.get()),
-                   config.first[0], config.first[1], config.first[2]);
+           get_algorithm(static_cast<BatchedMatrixMulForwardImpl*>(
+                                 config.second.get()),
+                         config.first[0], config.first[1], config.first[2]);
 }
 WorkspaceBundle ConvBiasForwardImpl::AlgoBatchedMatmul::get_workspace_bundle(
@@ -115,21 +118,10 @@ WorkspaceBundle ConvBiasForwardImpl::AlgoBatchedMatmul::get_workspace_bundle(
     SizeArgs conv_args = args;
     conv_args.dst_layout = &dst_layout;

-    auto bmatmul_opr = args.handle->create_operator<BatchedMatrixMulForward>();
-    if (args.opr->execution_policy().algo.valid() &&
-        !args.opr->execution_policy().sub_policy.empty()) {
-        megdnn_assert(args.opr->execution_policy().sub_policy.size() == 1);
-        bmatmul_opr->execution_policy() =
-                args.opr->execution_policy().sub_policy[0];
-    }
-    auto&& config =
-            sub_opr_config(args.filter_meta, *args.src_layout,
-                           *args.filter_layout, *args.dst_layout, args.opr);
-    bmatmul_opr->param() = config.second;
+    auto config = prepare_sub_opr(args);

     sizes.insert(sizes.begin(),
-                 args.handle->batched_matrix_mul()->get_workspace_in_bytes(
+                 config.second->get_workspace_in_bytes(
                          config.first[0], config.first[1], config.first[2]));
     return {ptr, std::move(sizes)};
 }
@@ -154,23 +146,12 @@ void ConvBiasForwardImpl::AlgoBatchedMatmul::exec(const ExecArgs& args) const {
     conv_args.dst_tensor = &conv_dst_tensor;
     conv_args.dst_layout = &conv_dst_tensor.layout;
     {
-        auto bmatmul_opr =
-                args.handle->create_operator<BatchedMatrixMulForward>();
-        if (args.opr->execution_policy().algo.valid()) {
-            megdnn_assert(args.opr->execution_policy().sub_policy.size() == 1);
-            bmatmul_opr->execution_policy() =
-                    args.opr->execution_policy().sub_policy[0];
-        }
-        auto&& config =
-                sub_opr_config(args.filter_meta, *args.src_layout,
-                               *args.filter_layout, *args.dst_layout, args.opr);
-        bmatmul_opr->param() = config.second;
+        auto config = prepare_sub_opr(args);

         TensorND A{args.filter_tensor->raw_ptr, config.first[0]},
                 B{args.src_tensor->raw_ptr, config.first[1]},
                 C{args.dst_tensor->raw_ptr, config.first[2]};
-        bmatmul_opr->exec(A, B, C, bundle.get_workspace(0));
+        config.second->exec(A, B, C, bundle.get_workspace(0));
     }
     handle_bias_and_nonlinear(args.handle, args.nonlinear_mode,
                               &conv_dst_tensor, args.dst_tensor,
...
@@ -14,6 +14,7 @@
 #include "src/cuda/handle.h"
 #include "src/cuda/utils.cuh"
 #include "src/cuda/utils.h"
+#include "src/common/algo_base.h"

 using namespace megdnn;
 using namespace cuda;
@@ -40,6 +41,18 @@ std::pair<TensorLayoutArray, ConvBiasForwardImpl::Param> sub_opr_config(
     ret.second.compute_mode = ConvBiasForwardImpl::Param::ComputeMode::DEFAULT;
     return ret;
 }
+
+std::pair<TensorLayoutArray, std::unique_ptr<ConvBiasForward>> prepare_sub_opr(
+        const ConvBiasForwardImpl::AlgoBase::SizeArgs& args) {
+    auto convbias_opr = args.handle->create_operator<ConvBias>();
+    auto&& config = sub_opr_config(
+            {*args.src_layout, *args.filter_layout, *args.bias_layout,
+             *args.z_layout, *args.dst_layout},
+            args.opr);
+    convbias_opr->param() = config.second;
+    return {config.first, std::move(convbias_opr)};
+}
 }  // namespace

 std::vector<Algorithm::SearchItem>
@@ -55,33 +68,18 @@ ConvBiasForwardImpl::AlgoBFloat16::get_subopr_list(
 bool ConvBiasForwardImpl::AlgoBFloat16::is_available(
         const SizeArgs& args) const {
-    auto convbias_opr = args.handle->create_operator<ConvBias>();
-    auto&& config = sub_opr_config(
-            {*args.src_layout, *args.filter_layout, *args.bias_layout,
-             *args.z_layout, *args.dst_layout},
-            args.opr);
-    convbias_opr->param() = config.second;
+    auto config = prepare_sub_opr(args);

     return args.src_layout->dtype == args.filter_layout->dtype &&
            args.src_layout->dtype == dtype::BFloat16() &&
-           get_algorithm(static_cast<ConvBiasForwardImpl*>(convbias_opr.get()),
+           get_algorithm(static_cast<ConvBiasForwardImpl*>(config.second.get()),
                          config.first[0], config.first[1], config.first[2],
                          config.first[3], config.first[4]);
 }
 WorkspaceBundle ConvBiasForwardImpl::AlgoBFloat16::get_workspace_bundle(
         void* ptr, const SizeArgs& args) const {
-    auto convbias_opr = args.handle->create_operator<ConvBias>();
-    if (args.opr->execution_policy().algo.valid()) {
-        megdnn_assert(args.opr->execution_policy().sub_policy.size() == 1);
-        convbias_opr->execution_policy() =
-                args.opr->execution_policy().sub_policy[0];
-    }
-    auto&& config = sub_opr_config(
-            {*args.src_layout, *args.filter_layout, *args.bias_layout,
-             *args.z_layout, *args.dst_layout},
-            args.opr);
-    convbias_opr->param() = config.second;
+    auto config = prepare_sub_opr(args);

     SmallVector<size_t> sizes;
     auto get_workspace = [&sizes](const TensorLayout& src,
@@ -95,7 +93,7 @@ WorkspaceBundle ConvBiasForwardImpl::AlgoBFloat16::get_workspace_bundle(
     get_workspace(*args.bias_layout, config.first[2]);
     get_workspace(*args.z_layout, config.first[3]);
     get_workspace(*args.dst_layout, config.first[4]);
-    sizes.push_back(convbias_opr->get_workspace_in_bytes(
+    sizes.push_back(config.second->get_workspace_in_bytes(
             config.first[0], config.first[1], config.first[2], config.first[3],
             config.first[4], nullptr));
@@ -123,17 +121,10 @@ void ConvBiasForwardImpl::AlgoBFloat16::exec(const ExecArgs& args) const {
                 .src_to_comp_type(*args.dst_tensor, fdst_tensor);
     }
     {
-        auto convbias_opr = args.handle->create_operator<ConvBias>();
-        convbias_opr->param() = args.opr->param();
-        convbias_opr->param().compute_mode = Param::ComputeMode::DEFAULT;
-        if (args.opr->execution_policy().algo.valid()) {
-            megdnn_assert(args.opr->execution_policy().sub_policy.size() == 1);
-            convbias_opr->execution_policy() =
-                    args.opr->execution_policy().sub_policy[0];
-        }
-        convbias_opr->exec(fsrc_tensor, ffilter_tensor, fbias_tensor, fz_tensor,
-                           fdst_tensor, nullptr, cvter.workspace());
+        auto config = prepare_sub_opr(args);
+        config.second->exec(fsrc_tensor, ffilter_tensor, fbias_tensor,
+                            fz_tensor, fdst_tensor, nullptr, cvter.workspace());
     }
     { cvter.comp_to_dst_type(fdst_tensor, *args.dst_tensor); }
 }
...
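Context for the `AlgoBFloat16` algos in this and the following files: they run the actual arithmetic in float32, converting BFloat16 tensors on the way in (`src_to_comp_type`) and back on the way out (`comp_to_dst_type`), which is why each owns a fully configured full-precision sub operator. A rough self-contained illustration of that convert-compute-convert round trip (a host-side simplification; megdnn's `CompTypeCvter` does this through device workspaces):

#include <cstdint>
#include <cstdio>
#include <cstring>
#include <vector>

// bf16 is the top 16 bits of an IEEE float32; storage-only stand-in here.
using bf16 = std::uint16_t;

float to_f32(bf16 v) {
    std::uint32_t bits = std::uint32_t(v) << 16;
    float f;
    std::memcpy(&f, &bits, sizeof(f));
    return f;
}

bf16 to_bf16(float f) {
    std::uint32_t bits;
    std::memcpy(&bits, &f, sizeof(bits));
    return bf16(bits >> 16);  // truncation; real converters usually round
}

int main() {
    std::vector<bf16> src = {to_bf16(1.5f), to_bf16(2.5f)};

    // src_to_comp_type: widen BF16 inputs into a float32 workspace.
    std::vector<float> fsrc;
    for (bf16 v : src) fsrc.push_back(to_f32(v));

    // The configured sub operator runs entirely in float32
    // (a trivial "kernel" here).
    for (float& v : fsrc) v *= 2.f;

    // comp_to_dst_type: narrow the float32 result back to BF16.
    std::vector<bf16> dst;
    for (float v : fsrc) dst.push_back(to_bf16(v));

    std::printf("%g %g\n", to_f32(dst[0]), to_f32(dst[1]));  // prints: 3 5
}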
@@ -15,6 +15,7 @@
 #include "src/cuda/conv_bias/helper.h"
 #include "src/cuda/conv_bias/matmul/im2col.cuh"
 #include "src/cuda/utils.h"
+#include "src/common/algo_base.h"

 using namespace megdnn;
 using namespace cuda;
@@ -40,6 +41,19 @@ std::pair<TensorLayoutArray, MatrixMulForward::Param> sub_opr_config(
     return {{Al, Bl, Cl}, param};
 }
+
+std::pair<TensorLayoutArray, std::unique_ptr<MatrixMulForward>> prepare_sub_opr(
+        const ConvBiasForwardImpl::AlgoBase::SizeArgs& args) {
+    auto matmul_opr = args.handle->create_operator<MatrixMulForward>();
+    set_execution_policy<ConvBiasForward, MatrixMulForward*>(args.opr,
+                                                             matmul_opr.get());
+    auto&& config =
+            sub_opr_config(args.filter_meta, *args.src_layout,
+                           *args.filter_layout, *args.dst_layout, args.opr);
+    matmul_opr->param() = config.second;
+    return {config.first, std::move(matmul_opr)};
+}
 }  // namespace

 std::vector<Algorithm::SearchItem>
@@ -87,19 +101,8 @@ WorkspaceBundle ConvBiasForwardImpl::AlgoMatmul::get_workspace_bundle(
     conv_args.dst_layout = &dst_layout;
     SmallVector<size_t> matmul_sizes = matmul_get_workspace_bundle(conv_args);

-    auto matmul_opr = args.handle->create_operator<MatrixMulForward>();
-    if (args.opr->execution_policy().algo.valid() &&
-        !args.opr->execution_policy().sub_policy.empty()) {
-        megdnn_assert(args.opr->execution_policy().sub_policy.size() == 1);
-        matmul_opr->execution_policy() =
-                args.opr->execution_policy().sub_policy[0];
-    }
-    auto&& config =
-            sub_opr_config(args.filter_meta, *args.src_layout,
-                           *args.filter_layout, *args.dst_layout, args.opr);
-    matmul_opr->param() = config.second;
-    size_t mm_ws = matmul_opr->get_workspace_in_bytes(
+    auto config = prepare_sub_opr(args);
+    size_t mm_ws = config.second->get_workspace_in_bytes(
             config.first[0], config.first[1], config.first[2]);
     matmul_sizes.push_back(mm_ws);
@@ -162,17 +165,7 @@ void ConvBiasForwardImpl::AlgoMatmul::exec_internal(
                 args.src_layout->stride[0], IC, IH, IW, FH, FW, OH, OW,
                 PH, PW, SH, SW, DH, DW, stream);

-    auto matmul_opr = args.handle->create_operator<MatrixMulForward>();
-    if (args.opr->execution_policy().algo.valid()) {
-        megdnn_assert(args.opr->execution_policy().sub_policy.size() == 1);
-        matmul_opr->execution_policy() =
-                args.opr->execution_policy().sub_policy[0];
-    }
-    auto&& config =
-            sub_opr_config(args.filter_meta, *args.src_layout,
-                           *args.filter_layout, *args.dst_layout, args.opr);
-    matmul_opr->param() = config.second;
+    auto config = prepare_sub_opr(args);

     TensorND A(args.filter_tensor->ptr<T>(), config.first[0]),
             B(col, config.first[1]), C(dst_t, config.first[2]);
@@ -182,7 +175,7 @@ void ConvBiasForwardImpl::AlgoMatmul::exec_internal(
         matmul_ws_idx = 3;
     }

-    matmul_opr->exec(A, B, C, bundle.get_workspace(matmul_ws_idx));
+    config.second->exec(A, B, C, bundle.get_workspace(matmul_ws_idx));

     TensorLayout C2l({OC * OH * OW, N}, typename DTypeTrait<T>::dtype()),
             C3l = C2l;
...
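The matmul-backed algos in this commit (conv_bias `AlgoMatmul` above, the convolution backward `AlgoMatmul`s, and the deformable-conv algos further down) all reduce convolution to im2col followed by a GEMM, which is why their sub operator is a `MatrixMulForward` or `BatchedMatrixMulForward`. A tiny host-side illustration of the reduction (single channel, stride 1, no padding; the CUDA kernels in the diff are far more general):

#include <cstdio>
#include <vector>

int main() {
    const int IH = 3, IW = 3, FH = 2, FW = 2;
    const int OH = IH - FH + 1, OW = IW - FW + 1;  // 2x2 output
    float src[IH * IW] = {1, 2, 3, 4, 5, 6, 7, 8, 9};
    float flt[FH * FW] = {1, 0, 0, 1};  // picks src[y][x] + src[y+1][x+1]

    // im2col: one column of filter-sized patches per output position,
    // giving a (FH*FW) x (OH*OW) matrix.
    std::vector<float> col(FH * FW * OH * OW);
    for (int oh = 0; oh < OH; ++oh)
        for (int ow = 0; ow < OW; ++ow)
            for (int fh = 0; fh < FH; ++fh)
                for (int fw = 0; fw < FW; ++fw)
                    col[(fh * FW + fw) * OH * OW + oh * OW + ow] =
                            src[(oh + fh) * IW + (ow + fw)];

    // GEMM: (1 x FH*FW) filter row times the column matrix = output row.
    float dst[OH * OW] = {};
    for (int k = 0; k < FH * FW; ++k)
        for (int o = 0; o < OH * OW; ++o)
            dst[o] += flt[k] * col[k * OH * OW + o];

    for (int o = 0; o < OH * OW; ++o) std::printf("%g ", dst[o]);  // 6 8 12 14
}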
@@ -10,6 +10,7 @@
  */
 #include "./algo.h"
+#include "src/common/algo_base.h"
 #include "src/cuda/convolution/chanwise/kern.cuh"
 #include "src/cuda/utils.h"
@@ -38,7 +39,19 @@ std::pair<TensorLayoutArray, ConvolutionBackwardDataImpl::Param> sub_opr_config(
             ConvolutionBackwardData::Param::ComputeMode::DEFAULT;
     return ret;
 }
+
+std::pair<TensorLayoutArray, std::unique_ptr<ConvolutionBackwardData>>
+prepare_sub_opr(const ConvolutionBackwardDataImpl::AlgoBase::SizeArgs& args) {
+    auto conv_back_data_opr =
+            args.handle->create_operator<ConvolutionBackwardData>();
+    auto&& config = sub_opr_config(
+            {*args.filter_layout, *args.diff_layout, *args.grad_layout},
+            args.opr);
+    conv_back_data_opr->param() = config.second;
+    return {config.first, std::move(conv_back_data_opr)};
+}

 }  // namespace

 std::vector<Algorithm::SearchItem>
 ConvolutionBackwardDataImpl::AlgoBFloat16::get_subopr_list(
@@ -54,33 +67,17 @@ ConvolutionBackwardDataImpl::AlgoBFloat16::get_subopr_list(
 bool ConvolutionBackwardDataImpl::AlgoBFloat16::is_available(
         const SizeArgs& args) const {
-    TensorLayout ffilter, fdiff, fgrad;
-    auto conv_back_data_opr =
-            args.handle->create_operator<ConvolutionBackwardData>();
-    auto&& config = sub_opr_config(
-            {*args.filter_layout, *args.diff_layout, *args.grad_layout},
-            args.opr);
-    conv_back_data_opr->param() = config.second;
+    auto config = prepare_sub_opr(args);

     return args.diff_layout->dtype == args.filter_layout->dtype &&
            args.diff_layout->dtype == dtype::BFloat16() &&
            get_algorithm(static_cast<ConvolutionBackwardDataImpl*>(
-                                 conv_back_data_opr.get()),
+                                 config.second.get()),
                          config.first[0], config.first[1], config.first[2]);
 }
 WorkspaceBundle ConvolutionBackwardDataImpl::AlgoBFloat16::get_workspace_bundle(
         void* ptr, const SizeArgs& args) const {
-    auto conv_back_data_opr =
-            args.handle->create_operator<ConvolutionBackwardData>();
-    if (args.opr->execution_policy().algo.valid()) {
-        megdnn_assert(args.opr->execution_policy().sub_policy.size() == 1);
-        conv_back_data_opr->execution_policy() =
-                args.opr->execution_policy().sub_policy[0];
-    }
-    auto&& config = sub_opr_config(
-            {*args.filter_layout, *args.diff_layout, *args.grad_layout},
-            args.opr);
-    conv_back_data_opr->param() = config.second;
+    auto config = prepare_sub_opr(args);

     SmallVector<size_t> sizes;
     auto get_workspace = [&sizes](const TensorLayout& src,
                                   const TensorLayout& dst) {
@@ -92,7 +89,7 @@ WorkspaceBundle ConvolutionBackwardDataImpl::AlgoBFloat16::get_workspace_bundle(
     get_workspace(*args.diff_layout, config.first[1]);
     get_workspace(*args.grad_layout, config.first[2]);
-    sizes.push_back(conv_back_data_opr->get_workspace_in_bytes(
+    sizes.push_back(config.second->get_workspace_in_bytes(
             config.first[0], config.first[1], config.first[2]));
     return {ptr, std::move(sizes)};
 }
@@ -115,17 +112,9 @@ void ConvolutionBackwardDataImpl::AlgoBFloat16::exec(
                 .src_to_comp_type(*args.grad_tensor, fgrad_tensor);
     }
     {
-        auto conv_back_data_opr =
-                args.handle->create_operator<ConvolutionBackwardData>();
-        if (args.opr->execution_policy().algo.valid()) {
-            megdnn_assert(args.opr->execution_policy().sub_policy.size() == 1);
-            conv_back_data_opr->execution_policy() =
-                    args.opr->execution_policy().sub_policy[0];
-        }
-        conv_back_data_opr->param() = args.opr->param();
-        conv_back_data_opr->param().compute_mode = Param::ComputeMode::DEFAULT;
-        conv_back_data_opr->exec(ffilter_tensor, fdiff_tensor, fgrad_tensor,
-                                 cvter.workspace());
+        auto config = prepare_sub_opr(args);
+        config.second->exec(ffilter_tensor, fdiff_tensor, fgrad_tensor,
+                            cvter.workspace());
     }
     { cvter.comp_to_dst_type(fgrad_tensor, *args.grad_tensor); }
 }
...
@@ -11,6 +11,7 @@
  */
 #include "./algo.h"
+#include "src/common/algo_base.h"
 #include "src/cuda/convolution/helper.h"
 #include "src/cuda/convolution/im2col.cuh"
 #include "src/cuda/matrix_mul/opr_impl.h"
@@ -43,6 +44,19 @@ std::pair<TensorLayoutArray, MatrixMulForward::Param> sub_opr_config(
     param.transposeA = true;
     return {{Al, Cl, Bl}, param};
 }
+
+std::pair<TensorLayoutArray, std::unique_ptr<MatrixMulForward>> prepare_sub_opr(
+        const ConvolutionBackwardDataImpl::AlgoBase::SizeArgs& args) {
+    auto matmul_opr = args.handle->create_operator<MatrixMulForward>();
+    set_execution_policy<ConvolutionBackwardData, MatrixMulForward*>(
+            args.opr, matmul_opr.get());
+    auto&& config =
+            sub_opr_config(args.filter_meta, *args.filter_layout,
+                           *args.diff_layout, *args.grad_layout, args.opr);
+    matmul_opr->param() = config.second;
+    return {config.first, std::move(matmul_opr)};
+}
 }  // namespace

 std::vector<Algorithm::SearchItem>
@@ -57,8 +71,7 @@ ConvolutionBackwardDataImpl::AlgoMatmul::get_subopr_list(
     std::string param_str;
     Algorithm::serialize_write_pod(config.second, param_str);
-    return {{Algorithm::OprType::MATRIX_MUL_FORWARD, param_str,
-             config.first}};
+    return {{Algorithm::OprType::MATRIX_MUL_FORWARD, param_str, config.first}};
 }
 bool ConvolutionBackwardDataImpl::AlgoMatmul::is_available(
@@ -75,22 +88,10 @@ bool ConvolutionBackwardDataImpl::AlgoMatmul::is_available(
 size_t ConvolutionBackwardDataImpl::AlgoMatmul::get_workspace_in_bytes(
         const SizeArgs& args) const {
-    auto matmul_opr =
-            args.handle->create_operator<MatrixMulForward>();
-    if (args.opr->execution_policy().algo.valid() &&
-        !args.opr->execution_policy().sub_policy.empty()) {
-        megdnn_assert(args.opr->execution_policy().sub_policy.size() == 1);
-        matmul_opr->execution_policy() =
-                args.opr->execution_policy().sub_policy[0];
-    }
-    auto&& config =
-            sub_opr_config(args.filter_meta, *args.filter_layout,
-                           *args.diff_layout, *args.grad_layout, args.opr);
-    matmul_opr->param() = config.second;
+    auto config = prepare_sub_opr(args);

     auto&& sizes = matmul_get_workspace_bundle(args.as_fwd_args());
-    sizes.push_back(matmul_opr->get_workspace_in_bytes(
+    sizes.push_back(config.second->get_workspace_in_bytes(
             config.first[0], config.first[1], config.first[2]));
     return WorkspaceBundle(nullptr, sizes).total_size_in_bytes();
 }
@@ -121,19 +122,10 @@ void ConvolutionBackwardDataImpl::AlgoMatmul::exec_internal(
          DW = fm.dilation[1];
     auto stream = cuda_stream(args.handle);

-    auto matmul_opr = args.handle->create_operator<MatrixMulForward>();
-    if (args.opr->execution_policy().algo.valid()) {
-        megdnn_assert(args.opr->execution_policy().sub_policy.size() == 1);
-        matmul_opr->execution_policy() =
-                args.opr->execution_policy().sub_policy[0];
-    }
-    auto&& config =
-            sub_opr_config(args.filter_meta, *args.filter_layout,
-                           *args.diff_layout, *args.grad_layout, args.opr);
-    matmul_opr->param() = config.second;
+    auto config = prepare_sub_opr(args);

     auto&& sizes = matmul_get_workspace_bundle(args.as_fwd_args());
-    sizes.push_back(matmul_opr->get_workspace_in_bytes(
+    sizes.push_back(config.second->get_workspace_in_bytes(
             config.first[0], config.first[1], config.first[2]));
     auto wbundle = WorkspaceBundle(args.workspace.raw_ptr, sizes);
@@ -159,9 +151,9 @@ void ConvolutionBackwardDataImpl::AlgoMatmul::exec_internal(
         if (fm.should_flip) {
             convolution::flip_filter(args.as_fwd_args(),
                                      wbundle.get_workspace(2), A.raw_ptr);
-            matmul_opr->exec(A, C, B, wbundle.get_workspace(3));
+            config.second->exec(A, C, B, wbundle.get_workspace(3));
         } else {
-            matmul_opr->exec(A, C, B, wbundle.get_workspace(2));
+            config.second->exec(A, C, B, wbundle.get_workspace(2));
         }
     }
     {
...
@@ -11,6 +11,7 @@
  */
 #include "./algo.h"
+#include "src/common/algo_base.h"
 #include "src/cuda/convolution/chanwise/kern.cuh"
 #include "src/cuda/utils.h"
@@ -39,6 +40,18 @@ sub_opr_config(const TensorLayoutArray& layouts,
             ConvolutionBackwardFilter::Param::ComputeMode::DEFAULT;
     return ret;
 }
+
+std::pair<TensorLayoutArray, std::unique_ptr<ConvolutionBackwardFilter>>
+prepare_sub_opr(const ConvolutionBackwardFilterImpl::AlgoBase::SizeArgs& args) {
+    auto conv_back_filter_opr =
+            args.handle->create_operator<ConvolutionBackwardFilter>();
+    auto&& config = sub_opr_config(
+            {*args.src_layout, *args.diff_layout, *args.grad_layout}, args.opr);
+    conv_back_filter_opr->param() = config.second;
+    return {config.first, std::move(conv_back_filter_opr)};
+}
 }  // namespace

 std::vector<Algorithm::SearchItem>
@@ -55,36 +68,18 @@ ConvolutionBackwardFilterImpl::AlgoBFloat16::get_subopr_list(
 bool ConvolutionBackwardFilterImpl::AlgoBFloat16::is_available(
         const SizeArgs& args) const {
-    TensorLayout fsrc, fdiff, fgrad;
-    auto conv_back_filter_opr =
-            args.handle->create_operator<ConvolutionBackwardFilter>();
-    auto&& config = sub_opr_config(
-            {*args.src_layout, *args.diff_layout, *args.grad_layout},
-            args.opr);
-    conv_back_filter_opr->param() = config.second;
+    auto config = prepare_sub_opr(args);

     return args.src_layout->dtype == args.diff_layout->dtype &&
            args.src_layout->dtype == dtype::BFloat16() &&
            get_algorithm(static_cast<ConvolutionBackwardFilterImpl*>(
-                                 conv_back_filter_opr.get()),
+                                 config.second.get()),
                          config.first[0], config.first[1], config.first[2]);
 }
 WorkspaceBundle
 ConvolutionBackwardFilterImpl::AlgoBFloat16::get_workspace_bundle(
         void* ptr, const SizeArgs& args) const {
-    auto conv_back_filter_opr =
-            args.handle->create_operator<ConvolutionBackwardFilter>();
-    if (args.opr->execution_policy().algo.valid()) {
-        megdnn_assert(args.opr->execution_policy().sub_policy.size() == 1);
-        conv_back_filter_opr->execution_policy() =
-                args.opr->execution_policy().sub_policy[0];
-    }
-    auto&& config = sub_opr_config(
-            {*args.src_layout, *args.diff_layout, *args.grad_layout},
-            args.opr);
-    conv_back_filter_opr->param() = config.second;
+    auto config = prepare_sub_opr(args);

     SmallVector<size_t> sizes;
     auto get_workspace = [&sizes](const TensorLayout& src,
                                   const TensorLayout& dst) {
@@ -96,7 +91,7 @@ ConvolutionBackwardFilterImpl::AlgoBFloat16::get_workspace_bundle(
     get_workspace(*args.src_layout, config.first[0]);
     get_workspace(*args.diff_layout, config.first[1]);
     get_workspace(*args.grad_layout, config.first[2]);
-    sizes.push_back(conv_back_filter_opr->get_workspace_in_bytes(
+    sizes.push_back(config.second->get_workspace_in_bytes(
             config.first[0], config.first[1], config.first[2]));
     auto ret = WorkspaceBundle{ptr, std::move(sizes)};
     return ret;
@@ -120,19 +115,9 @@ void ConvolutionBackwardFilterImpl::AlgoBFloat16::exec(
                 .src_to_comp_type(*args.grad_tensor, fgrad_tensor);
     }
     {
-        auto conv_back_filter_opr =
-                args.handle->create_operator<ConvolutionBackwardFilter>();
-        conv_back_filter_opr->param() = args.opr->param();
-        conv_back_filter_opr->param().compute_mode =
-                Param::ComputeMode::DEFAULT;
-        if (args.opr->execution_policy().algo.valid()) {
-            megdnn_assert(args.opr->execution_policy().sub_policy.size() == 1);
-            conv_back_filter_opr->execution_policy() =
-                    args.opr->execution_policy().sub_policy[0];
-        }
-        conv_back_filter_opr->exec(fsrc_tensor, fdiff_tensor, fgrad_tensor,
-                                   cvter.workspace());
+        auto config = prepare_sub_opr(args);
+        config.second->exec(fsrc_tensor, fdiff_tensor, fgrad_tensor,
+                            cvter.workspace());
     }
     { cvter.comp_to_dst_type(fgrad_tensor, *args.grad_tensor); }
 }
...
@@ -11,6 +11,7 @@
  */
 #include "./algo.h"
+#include "src/common/algo_base.h"
 #include "src/cuda/convolution/helper.h"
 #include "src/cuda/convolution/im2col.cuh"
 #include "src/cuda/utils.h"
@@ -42,6 +43,20 @@ std::pair<TensorLayoutArray, MatrixMulForward::Param> sub_opr_config(
     param.transposeB = true;
     return {{Cl, Bl, Al}, param};
 }
+
+std::pair<TensorLayoutArray, std::unique_ptr<MatrixMulForward>> prepare_sub_opr(
+        const ConvolutionBackwardFilterImpl::AlgoBase::SizeArgs& args) {
+    auto matmul_opr = args.handle->create_operator<MatrixMulForward>();
+    set_execution_policy<ConvolutionBackwardFilter, MatrixMulForward*>(
+            args.opr, matmul_opr.get());
+    auto&& config =
+            sub_opr_config(args.grad_filter_meta, *args.src_layout,
+                           *args.diff_layout, *args.grad_layout, args.opr);
+    matmul_opr->param() = config.second;
+    return {config.first, std::move(matmul_opr)};
+}
 }  // namespace

 std::vector<Algorithm::SearchItem>
@@ -56,11 +71,9 @@ ConvolutionBackwardFilterImpl::AlgoMatmul::get_subopr_list(
     std::string param_str;
     Algorithm::serialize_write_pod(config.second, param_str);
-    return {{Algorithm::OprType::MATRIX_MUL_FORWARD, param_str,
-             config.first}};
+    return {{Algorithm::OprType::MATRIX_MUL_FORWARD, param_str, config.first}};
 }
 bool ConvolutionBackwardFilterImpl::AlgoMatmul::is_available(
         const SizeArgs& args) const {
     if (args.src_layout->dtype == args.diff_layout->dtype &&
@@ -75,21 +88,10 @@ bool ConvolutionBackwardFilterImpl::AlgoMatmul::is_available(
 size_t ConvolutionBackwardFilterImpl::AlgoMatmul::get_workspace_in_bytes(
         const SizeArgs& args) const {
-    auto matmul_opr = args.handle->create_operator<MatrixMulForward>();
-    if (args.opr->execution_policy().algo.valid() &&
-        !args.opr->execution_policy().sub_policy.empty()) {
-        megdnn_assert(args.opr->execution_policy().sub_policy.size() == 1);
-        matmul_opr->execution_policy() =
-                args.opr->execution_policy().sub_policy[0];
-    }
-    auto&& config =
-            sub_opr_config(args.grad_filter_meta, *args.src_layout,
-                           *args.diff_layout, *args.grad_layout, args.opr);
-    matmul_opr->param() = config.second;
+    auto config = prepare_sub_opr(args);

     auto&& sizes = matmul_get_workspace_bundle(args.as_fwd_args());
-    sizes.push_back(matmul_opr->get_workspace_in_bytes(
+    sizes.push_back(config.second->get_workspace_in_bytes(
             config.first[0], config.first[1], config.first[2]));
     return WorkspaceBundle(nullptr, sizes).total_size_in_bytes();
 }
@@ -121,19 +123,10 @@ void ConvolutionBackwardFilterImpl::AlgoMatmul::exec_internal(
          DW = fm.dilation[1];
     auto stream = cuda_stream(args.handle);

-    auto matmul_opr = args.handle->create_operator<MatrixMulForward>();
-    if (args.opr->execution_policy().algo.valid()) {
-        megdnn_assert(args.opr->execution_policy().sub_policy.size() == 1);
-        matmul_opr->execution_policy() =
-                args.opr->execution_policy().sub_policy[0];
-    }
-    auto&& config =
-            sub_opr_config(args.grad_filter_meta, *args.src_layout,
-                           *args.diff_layout, *args.grad_layout, args.opr);
-    matmul_opr->param() = config.second;
+    auto config = prepare_sub_opr(args);

     auto&& sizes = matmul_get_workspace_bundle(args.as_fwd_args());
-    sizes.push_back(matmul_opr->get_workspace_in_bytes(
+    sizes.push_back(config.second->get_workspace_in_bytes(
             config.first[0], config.first[1], config.first[2]));
     auto wbundle = WorkspaceBundle(args.workspace.raw_ptr, sizes);
@@ -164,14 +157,14 @@ void ConvolutionBackwardFilterImpl::AlgoMatmul::exec_internal(
         TensorND A(args.grad_tensor->ptr<T>(), Al), B(col, Bl), C(diff_t, Cl);
         if (fm.should_flip) {
             A.raw_ptr = wbundle.get(2);
-            matmul_opr->exec(C, B, A, wbundle.get_workspace(3));
+            config.second->exec(C, B, A, wbundle.get_workspace(3));
             convolution::flip_filter(
                     args.as_fwd_args(),
                     {static_cast<dt_byte*>(args.grad_tensor->raw_ptr),
                      wbundle.get_size(2)},
                     A.raw_ptr);
         } else {
-            matmul_opr->exec(C, B, A, wbundle.get_workspace(2));
+            config.second->exec(C, B, A, wbundle.get_workspace(2));
        }
     }
 }
...
@@ -65,6 +65,20 @@ std::pair<TensorLayoutArray, ConvBiasForward::Param> sub_opr_config(
     return ret;
 }
+
+std::pair<TensorLayoutArray, std::unique_ptr<ConvBiasForward>> prepare_sub_opr(
+        const ConvolutionForwardImpl::AlgoBase::SizeArgs& args) {
+    auto conv_bias_opr = args.opr->handle()->create_operator<ConvBiasForward>();
+    set_execution_policy<ConvolutionForward, ConvBiasForward*>(
+            args.opr, conv_bias_opr.get());
+    auto&& config = sub_opr_config(
+            *args.layout_src, *args.layout_filter, *args.layout_dst,
+            args.opr);
+    conv_bias_opr->param() = config.second;
+    return {config.first, std::move(conv_bias_opr)};
+}
 }  // namespace

 ConvolutionForwardImpl::AlgoPack::AlgoPack() {
@@ -121,13 +135,8 @@ ConvolutionForwardImpl::AlgoDefault::get_subopr_list(
 bool ConvolutionForwardImpl::AlgoDefault::is_available(
         const SizeArgs& args) const {
-    auto conv_bias_opr =
-            args.opr->handle()->create_operator<ConvBiasForward>();
-    auto&& config = sub_opr_config(
-            *args.layout_src, *args.layout_filter, *args.layout_dst,
-            args.opr);
-    conv_bias_opr->param() = config.second;
-    return get_algorithm(static_cast<ConvBiasForwardImpl*>(conv_bias_opr.get()),
+    auto config = prepare_sub_opr(args);
+    return get_algorithm(static_cast<ConvBiasForwardImpl*>(config.second.get()),
                          *args.layout_src, *args.layout_filter, config.first[0],
                          config.first[1], *args.layout_dst);
 }
@@ -135,36 +144,15 @@ bool ConvolutionForwardImpl::AlgoDefault::is_available(
 size_t ConvolutionForwardImpl::AlgoDefault::get_workspace_in_bytes(
         const SizeArgs& args) const {
-    auto conv_bias_opr = args.opr->handle()->create_operator<ConvBiasForward>();
-    if (args.opr->execution_policy().algo.valid() &&
-        !args.opr->execution_policy().sub_policy.empty()) {
-        megdnn_assert(args.opr->execution_policy().sub_policy.size() == 1);
-        conv_bias_opr->execution_policy() =
-                args.opr->execution_policy().sub_policy[0];
-    }
-    auto&& config = sub_opr_config(
-            *args.layout_src, *args.layout_filter, *args.layout_dst,
-            args.opr);
-    conv_bias_opr->param() = config.second;
-    return conv_bias_opr->get_workspace_in_bytes(
+    auto config = prepare_sub_opr(args);
+    return config.second->get_workspace_in_bytes(
             *args.layout_src, *args.layout_filter, config.first[0],
             config.first[1], *args.layout_dst, nullptr);
 }
 void ConvolutionForwardImpl::AlgoDefault::exec(const ExecArgs& args) const {
-    auto conv_bias_opr = args.opr->handle()->create_operator<ConvBiasForward>();
-    if (args.opr->execution_policy().algo.valid()) {
-        megdnn_assert(args.opr->execution_policy().sub_policy.size() == 1);
-        conv_bias_opr->execution_policy() =
-                args.opr->execution_policy().sub_policy[0];
-    }
-    auto&& config = sub_opr_config(
-            *args.layout_src, *args.layout_filter, *args.layout_dst,
-            args.opr);
-    conv_bias_opr->param() = config.second;
-    conv_bias_opr->exec(args.tensor_src, args.tensor_filter,
+    auto config = prepare_sub_opr(args);
+    config.second->exec(args.tensor_src, args.tensor_filter,
                         {nullptr, config.first[0]}, {nullptr, config.first[1]},
                         args.tensor_dst, nullptr, args.workspace);
 }
...
@@ -14,6 +14,7 @@
 #include "src/cuda/deformable_conv/bwd_data/algo.h"
 #include "src/cuda/deformable_conv/kimpl/deformable_conv.cuh"
 #include "src/cuda/deformable_conv/opr_impl.h"
+#include "src/common/algo_base.h"

 using namespace megdnn;
 using namespace cuda;
@@ -79,15 +80,28 @@ std::pair<TensorLayoutArray, BatchedMatrixMulForward::Param> sub_opr_config(
     return {{al, bl, cl}, param};
 }
+
+std::pair<TensorLayoutArray, std::unique_ptr<BatchedMatrixMulForward>>
+prepare_sub_opr(
+        const DeformableConvBackwardDataImpl::AlgoBase::SizeArgs& args) {
+    auto bmatmul_opr = args.handle->create_operator<BatchedMatrixMulForward>();
+    set_execution_policy<DeformableConvBackwardData, BatchedMatrixMulForward*>(
+            args.opr, bmatmul_opr.get());
+    auto&& config = sub_opr_config(args.filter_meta, args.im_layout,
+                                   args.out_grad_layout);
+    bmatmul_opr->param() = config.second;
+    return {config.first, std::move(bmatmul_opr)};
+}
 };  // anonymous namespace

-std::vector<Algorithm::SearchItem>
-Algo::get_subopr_list(
+std::vector<Algorithm::SearchItem> Algo::get_subopr_list(
         const TensorLayoutArray& layouts, const OperatorBase* opr) const {
     const DeformableConvBackwardDataImpl* deformable_conv =
             static_cast<const DeformableConvBackwardDataImpl*>(opr);
     CanonizedFilterMeta fm = deformable_conv->make_canonized_filter_meta(
             layouts[0].ndim, layouts[1], layouts[2]);
     auto&& config = sub_opr_config(fm, layouts[0], layouts[4]);

     std::string param_str;
@@ -106,19 +120,9 @@ WorkspaceBundle Algo::get_bundle(const SizeArgs& args) {
            OC = args.out_grad_layout[1], OH = args.out_grad_layout[2],
            OW = args.out_grad_layout[3], FH = fm.spatial[0], FW = fm.spatial[1];

-    auto bmatmul_opr = args.handle->create_operator<BatchedMatrixMulForward>();
-    if (args.opr->execution_policy().algo.valid() &&
-        !args.opr->execution_policy().sub_policy.empty()) {
-        megdnn_assert(args.opr->execution_policy().sub_policy.size() == 1);
-        bmatmul_opr->execution_policy() =
-                args.opr->execution_policy().sub_policy[0];
-    }
-    auto&& config = sub_opr_config(args.filter_meta, args.im_layout,
-                                   args.out_grad_layout);
-    bmatmul_opr->param() = config.second;
+    auto config = prepare_sub_opr(args);

-    size_t bmm_ws = bmatmul_opr->get_workspace_in_bytes(
+    size_t bmm_ws = config.second->get_workspace_in_bytes(
             config.first[0], config.first[1], config.first[2]);
     size_t result_ws = batch_sz * IC * FH * FW * OH * OW * sizeof(float);
     size_t relayout_ws1 = batch_sz * OC * OH * OW * sizeof(float);
@@ -183,24 +187,14 @@ void Algo::exec(const ExecArgs& args) const {
     // matmul [g, icpg, FH, FW, ocpg] * [g, ocpg, N, OH, OW] =>
     // => [g, icpg, FH, FW, N, OH, OW]
     {
-        auto bmatmul_opr =
-                args.handle->create_operator<BatchedMatrixMulForward>();
-        if (args.opr->execution_policy().algo.valid()) {
-            megdnn_assert(args.opr->execution_policy().sub_policy.size() == 1);
-            bmatmul_opr->execution_policy() =
-                    args.opr->execution_policy().sub_policy[0];
-        }
-        auto&& config = sub_opr_config(args.filter_meta, args.im_layout,
-                                       args.out_grad_layout);
-        bmatmul_opr->param() = config.second;
+        auto config = prepare_sub_opr(args);

         TensorND A(static_cast<void*>(dev_filter), config.first[0]),
                 B(static_cast<void*>(relayout_ws1), config.first[1]),
                 C(static_cast<void*>(result_ws), config.first[2]);
         size_t bmm_ws_size = bundle.get_size(0);
-        bmatmul_opr->exec(
+        config.second->exec(
                 A, B, C,
                 Workspace(static_cast<megdnn::dt_byte*>(bmm_ws), bmm_ws_size));
     }
...
@@ -15,6 +15,7 @@
 #include "src/cuda/deformable_conv/bwd_flt/algo.h"
 #include "src/cuda/deformable_conv/kimpl/deformable_conv.cuh"
 #include "src/cuda/deformable_conv/opr_impl.h"
+#include "src/common/algo_base.h"

 using namespace megdnn;
 using namespace cuda;
@@ -79,10 +80,23 @@ std::pair<TensorLayoutArray, BatchedMatrixMulForward::Param> sub_opr_config(
     return {{al, bl, cl}, param};
 }
+
+std::pair<TensorLayoutArray, std::unique_ptr<BatchedMatrixMulForward>>
+prepare_sub_opr(
+        const DeformableConvBackwardFilterImpl::AlgoBase::SizeArgs& args) {
+    auto bmatmul_opr = args.handle->create_operator<BatchedMatrixMulForward>();
+    set_execution_policy<DeformableConvBackwardFilter,
+                         BatchedMatrixMulForward*>(args.opr, bmatmul_opr.get());
+    auto&& config = sub_opr_config(args.filter_grad_meta, args.im_layout,
+                                   args.out_grad_layout);
+    bmatmul_opr->param() = config.second;
+    return {config.first, std::move(bmatmul_opr)};
+}
 };  // anonymous namespace

-std::vector<Algorithm::SearchItem>
-Algo::get_subopr_list(
+std::vector<Algorithm::SearchItem> Algo::get_subopr_list(
         const TensorLayoutArray& layouts, const OperatorBase* opr) const {
     const DeformableConvBackwardFilterImpl* deformable_conv =
             static_cast<const DeformableConvBackwardFilterImpl*>(opr);
@@ -107,21 +121,11 @@ WorkspaceBundle Algo::get_bundle(const SizeArgs& args) {
     size_t IC = fm.group * fm.icpg, OC = args.out_grad_layout[1];
     auto batch_sz = args.im_layout[0];

-    auto bmatmul_opr = args.handle->create_operator<BatchedMatrixMulForward>();
-    if (args.opr->execution_policy().algo.valid() &&
-        !args.opr->execution_policy().sub_policy.empty()) {
-        megdnn_assert(args.opr->execution_policy().sub_policy.size() == 1);
-        bmatmul_opr->execution_policy() =
-                args.opr->execution_policy().sub_policy[0];
-    }
-    auto&& config = sub_opr_config(args.filter_grad_meta, args.im_layout,
-                                   args.out_grad_layout);
-    bmatmul_opr->param() = config.second;
+    auto config = prepare_sub_opr(args);

     size_t col_ws = batch_sz * IC * FH * FW * OH * OW * sizeof(float);
     size_t out_grad_ws = batch_sz * OC * OH * OW * sizeof(float);
-    size_t bmm_ws = bmatmul_opr->get_workspace_in_bytes(
+    size_t bmm_ws = config.second->get_workspace_in_bytes(
             config.first[0], config.first[1], config.first[2]);

     return {nullptr, {col_ws, out_grad_ws, bmm_ws}};
@@ -166,23 +170,14 @@ void Algo::exec(const ExecArgs& args) const {
     args.handle->relayout_opr()->exec(C2, C3);

     // matmul
-    auto bmatmul_opr = args.handle->create_operator<BatchedMatrixMulForward>();
-    if (args.opr->execution_policy().algo.valid()) {
-        megdnn_assert(args.opr->execution_policy().sub_policy.size() == 1);
-        bmatmul_opr->execution_policy() =
-                args.opr->execution_policy().sub_policy[0];
-    }
-    auto&& config = sub_opr_config(args.filter_grad_meta, args.im_layout,
-                                   args.out_grad_layout);
-    bmatmul_opr->param() = config.second;
+    auto config = prepare_sub_opr(args);

     TensorND A(static_cast<void*>(out_grad_ws), config.first[0]),
             B(static_cast<void*>(col_ws), config.first[1]),
             C(static_cast<void*>(dev_filter_grad), config.first[2]);
     size_t bmm_ws_size = bundle.get_size(2);
-    bmatmul_opr->exec(
+    config.second->exec(
             A, B, C,
             Workspace(static_cast<megdnn::dt_byte*>(bmm_ws), bmm_ws_size));
 }
...
@@ -14,6 +14,7 @@
 #include "src/cuda/batched_matrix_mul/algo.h"
 #include "src/cuda/deformable_conv/fwd/algo.h"
 #include "src/cuda/deformable_conv/kimpl/deformable_conv.cuh"
+#include "src/common/algo_base.h"

 using namespace megdnn;
 using namespace cuda;
@@ -78,15 +79,27 @@ std::pair<TensorLayoutArray, BatchedMatrixMulForward::Param> sub_opr_config(
     return {{al, bl, cl}, param};
 }
+
+std::pair<TensorLayoutArray, std::unique_ptr<BatchedMatrixMulForward>>
+prepare_sub_opr(const DeformableConvForwardImpl::AlgoBase::SizeArgs& args) {
+    auto bmatmul_opr = args.handle->create_operator<BatchedMatrixMulForward>();
+    set_execution_policy<DeformableConvForward, BatchedMatrixMulForward*>(
+            args.opr, bmatmul_opr.get());
+    auto&& config =
+            sub_opr_config(args.filter_meta, args.im_layout, args.dst_layout);
+    bmatmul_opr->param() = config.second;
+    return {config.first, std::move(bmatmul_opr)};
+}
 };  // anonymous namespace

-std::vector<Algorithm::SearchItem>
-Algo::get_subopr_list(
+std::vector<Algorithm::SearchItem> Algo::get_subopr_list(
         const TensorLayoutArray& layouts, const OperatorBase* opr) const {
     const DeformableConvForwardImpl* deformable_conv =
             static_cast<const DeformableConvForwardImpl*>(opr);
     CanonizedFilterMeta fm = deformable_conv->make_canonized_filter_meta(
             layouts[0].ndim, layouts[1], layouts[2]);
     auto&& config = sub_opr_config(fm, layouts[0], layouts[4]);

     std::string param_str;
@@ -95,7 +108,6 @@ Algo::get_subopr_list(
                      config.first}};
 }

-
 bool Algo::is_available(const SizeArgs&) const {
     return true;
 }
...@@ -106,20 +118,10 @@ WorkspaceBundle Algo::get_bundle(const SizeArgs& args) { ...@@ -106,20 +118,10 @@ WorkspaceBundle Algo::get_bundle(const SizeArgs& args) {
OC = args.dst_layout[1], OH = args.dst_layout[2], OC = args.dst_layout[1], OH = args.dst_layout[2],
OW = args.dst_layout[3], FH = fm.spatial[0], FW = fm.spatial[1]; OW = args.dst_layout[3], FH = fm.spatial[0], FW = fm.spatial[1];
auto bmatmul_opr = args.handle->create_operator<BatchedMatrixMulForward>(); auto config = prepare_sub_opr(args);
if (args.opr->execution_policy().algo.valid() &&
!args.opr->execution_policy().sub_policy.empty()) {
megdnn_assert(args.opr->execution_policy().sub_policy.size() == 1);
bmatmul_opr->execution_policy() =
args.opr->execution_policy().sub_policy[0];
}
auto&& config =
sub_opr_config(args.filter_meta, args.im_layout, args.dst_layout);
bmatmul_opr->param() = config.second;
size_t col_ws = batch_sz * IC * FH * FW * OH * OW * sizeof(float); size_t col_ws = batch_sz * IC * FH * FW * OH * OW * sizeof(float);
size_t bmm_ws = bmatmul_opr->get_workspace_in_bytes( size_t bmm_ws = config.second->get_workspace_in_bytes(
config.first[0], config.first[1], config.first[2]); config.first[0], config.first[1], config.first[2]);
size_t result_ws = batch_sz * OC * OH * OW * sizeof(float); size_t result_ws = batch_sz * OC * OH * OW * sizeof(float);
...@@ -154,16 +156,7 @@ void Algo::exec(const ExecArgs& args) const { ...@@ -154,16 +156,7 @@ void Algo::exec(const ExecArgs& args) const {
deformable_conv::im2col(dev_im, dev_offset, dev_mask, deformable_conv::im2col(dev_im, dev_offset, dev_mask,
static_cast<float*>(col_ws), p); static_cast<float*>(col_ws), p);
auto bmatmul_opr = args.handle->create_operator<BatchedMatrixMulForward>(); auto config = prepare_sub_opr(args);
if (args.opr->execution_policy().algo.valid()) {
megdnn_assert(args.opr->execution_policy().sub_policy.size() == 1);
bmatmul_opr->execution_policy() =
args.opr->execution_policy().sub_policy[0];
}
auto&& config =
sub_opr_config(args.filter_meta, args.im_layout, args.dst_layout);
bmatmul_opr->param() = config.second;
// matmul // matmul
TensorND A(static_cast<void*>(dev_filter), config.first[0]), TensorND A(static_cast<void*>(dev_filter), config.first[0]),
...@@ -171,7 +164,7 @@ void Algo::exec(const ExecArgs& args) const { ...@@ -171,7 +164,7 @@ void Algo::exec(const ExecArgs& args) const {
C(static_cast<void*>(result_ws), config.first[2]); C(static_cast<void*>(result_ws), config.first[2]);
size_t bmm_ws_size = bundle.get_size(1); size_t bmm_ws_size = bundle.get_size(1);
bmatmul_opr->exec( config.second->exec(
A, B, C, A, B, C,
Workspace(static_cast<megdnn::dt_byte*>(bmm_ws), bmm_ws_size)); Workspace(static_cast<megdnn::dt_byte*>(bmm_ws), bmm_ws_size));
// relayout // relayout
......
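A side pattern worth noting in the forward algo above: `get_bundle()` registers the sizes of the im2col buffer, the matmul workspace, and the result buffer once, and `exec()` later addresses the matmul slice by index (`bundle.get_size(1)`). A rough sketch of that single-allocation slicing, ignoring any alignment handling the real `WorkspaceBundle` may apply; all names here are hypothetical:

```cpp
#include <cstddef>
#include <cstdint>
#include <numeric>
#include <vector>

class ToyWorkspaceBundle {
    std::vector<std::size_t> m_sizes;
    std::uint8_t* m_base;
public:
    ToyWorkspaceBundle(void* base, std::vector<std::size_t> sizes)
            : m_sizes(std::move(sizes)),
              m_base(static_cast<std::uint8_t*>(base)) {}
    std::size_t total() const {
        return std::accumulate(m_sizes.begin(), m_sizes.end(),
                               std::size_t{0});
    }
    std::size_t get_size(std::size_t i) const { return m_sizes[i]; }
    void* get(std::size_t i) const {  // start of the i-th sub-workspace
        std::size_t offset = 0;
        for (std::size_t j = 0; j < i; ++j)
            offset += m_sizes[j];
        return m_base + offset;
    }
};

int main() {
    // e.g. {col_ws, bmm_ws, result_ws} as in the forward algo above
    std::vector<std::size_t> sizes = {1024, 4096, 512};
    std::vector<std::uint8_t> storage(
            ToyWorkspaceBundle(nullptr, sizes).total());
    ToyWorkspaceBundle bundle(storage.data(), sizes);
    void* bmm_ws = bundle.get(1);  // matmul slice, size bundle.get_size(1)
    (void)bmm_ws;
}
```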
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#include "src/cuda/matrix_mul/algos.h" #include "src/cuda/matrix_mul/algos.h"
#include "src/cuda/utils.h" #include "src/cuda/utils.h"
#include "src/common/algo_chooser.h" #include "src/common/algo_chooser.h"
#include "src/common/algo_base.h"
using namespace megdnn; using namespace megdnn;
using namespace cuda; using namespace cuda;
...@@ -37,6 +38,15 @@ std::pair<TensorLayoutArray, MatrixMulForwardImpl::Param> sub_opr_config( ...@@ -37,6 +38,15 @@ std::pair<TensorLayoutArray, MatrixMulForwardImpl::Param> sub_opr_config(
ret.second.compute_mode = MatrixMulForwardImpl::Param::ComputeMode::DEFAULT; ret.second.compute_mode = MatrixMulForwardImpl::Param::ComputeMode::DEFAULT;
return ret; return ret;
} }
std::pair<TensorLayoutArray, std::unique_ptr<MatrixMulForward>> prepare_sub_opr(
const MatrixMulForwardImpl::AlgoBase::SizeArgs& args) {
auto&& config = sub_opr_config(
{args.layout_a, args.layout_b, args.layout_c}, args.opr);
auto matmul_opr = args.opr->handle()->create_operator<MatrixMulForward>();
matmul_opr->param() = config.second;
return {config.first, std::move(matmul_opr)};
}
} // namespace } // namespace
std::vector<Algorithm::SearchItem> std::vector<Algorithm::SearchItem>
...@@ -52,27 +62,16 @@ MatrixMulForwardImpl::AlgoBFloat16::get_subopr_list( ...@@ -52,27 +62,16 @@ MatrixMulForwardImpl::AlgoBFloat16::get_subopr_list(
bool MatrixMulForwardImpl::AlgoBFloat16::is_available( bool MatrixMulForwardImpl::AlgoBFloat16::is_available(
const SizeArgs& args) const { const SizeArgs& args) const {
auto&& config = sub_opr_config( auto config = prepare_sub_opr(args);
{args.layout_a, args.layout_b, args.layout_c}, args.opr);
auto matmul_opr = args.opr->handle()->create_operator<MatrixMulForward>();
matmul_opr->param() = config.second;
return args.layout_a.dtype == dtype::BFloat16() && return args.layout_a.dtype == dtype::BFloat16() &&
get_algorithm(static_cast<MatrixMulForwardImpl*>(matmul_opr.get()), get_algorithm(
config.first[0], config.first[1], config.first[2]); static_cast<MatrixMulForwardImpl*>(config.second.get()),
config.first[0], config.first[1], config.first[2]);
} }
WorkspaceBundle MatrixMulForwardImpl::AlgoBFloat16::get_workspace_bundle( WorkspaceBundle MatrixMulForwardImpl::AlgoBFloat16::get_workspace_bundle(
void* ptr, const SizeArgs& args) const { void* ptr, const SizeArgs& args) const {
auto matmul_opr = args.opr->handle()->create_operator<MatrixMulForward>(); auto config = prepare_sub_opr(args);
if (args.opr->execution_policy().algo.valid()) {
megdnn_assert(args.opr->execution_policy().sub_policy.size() == 1);
matmul_opr->execution_policy() =
args.opr->execution_policy().sub_policy[0];
}
auto&& config = sub_opr_config(
{args.layout_a, args.layout_b, args.layout_c}, args.opr);
matmul_opr->param() = config.second;
SmallVector<size_t> sizes; SmallVector<size_t> sizes;
auto get_workspace = [&sizes](const TensorLayout& src, auto get_workspace = [&sizes](const TensorLayout& src,
...@@ -85,7 +84,7 @@ WorkspaceBundle MatrixMulForwardImpl::AlgoBFloat16::get_workspace_bundle( ...@@ -85,7 +84,7 @@ WorkspaceBundle MatrixMulForwardImpl::AlgoBFloat16::get_workspace_bundle(
get_workspace(args.layout_a, config.first[0]); get_workspace(args.layout_a, config.first[0]);
get_workspace(args.layout_b, config.first[1]); get_workspace(args.layout_b, config.first[1]);
get_workspace(args.layout_c, config.first[2]); get_workspace(args.layout_c, config.first[2]);
sizes.push_back(matmul_opr->get_workspace_in_bytes( sizes.push_back(config.second->get_workspace_in_bytes(
config.first[0], config.first[1], config.first[2])); config.first[0], config.first[1], config.first[2]));
return {ptr, std::move(sizes)}; return {ptr, std::move(sizes)};
} }
...@@ -106,17 +105,8 @@ void MatrixMulForwardImpl::AlgoBFloat16::exec(const ExecArgs& args) const { ...@@ -106,17 +105,8 @@ void MatrixMulForwardImpl::AlgoBFloat16::exec(const ExecArgs& args) const {
.src_to_comp_type(args.tensor_b, b) .src_to_comp_type(args.tensor_b, b)
.src_to_comp_type(args.tensor_c, c); .src_to_comp_type(args.tensor_c, c);
{ {
auto matmul_opr = auto config = prepare_sub_opr(args);
args.opr->handle()->create_operator<MatrixMulForward>(); config.second->exec(a, b, c, ctypecvt.workspace());
matmul_opr->param() = args.opr->param();
matmul_opr->param().compute_mode = Param::ComputeMode::DEFAULT;
if (args.opr->execution_policy().algo.valid()) {
megdnn_assert(args.opr->execution_policy().sub_policy.size() == 1);
matmul_opr->execution_policy() =
args.opr->execution_policy().sub_policy[0];
}
matmul_opr->exec(a, b, c, ctypecvt.workspace());
} }
ctypecvt.comp_to_dst_type(c, args.tensor_c); ctypecvt.comp_to_dst_type(c, args.tensor_c);
} }
......
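The BFloat16 algo keeps one invariant through the refactor: inputs are converted to the compute type, the sub matmul runs entirely in that type (`ComputeMode::DEFAULT`), and the output is converted back (`src_to_comp_type` / `comp_to_dst_type` above). A self-contained sketch of that convert-compute-convert shape; the truncating bf16 conversion and `matmul_bf16` are illustrative stand-ins, not megdnn code:

```cpp
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <vector>

using bf16 = std::uint16_t;  // top 16 bits of an IEEE-754 binary32

static float bf16_to_f32(bf16 v) {
    std::uint32_t bits = std::uint32_t(v) << 16;
    float f;
    std::memcpy(&f, &bits, sizeof(f));
    return f;
}

static bf16 f32_to_bf16(float f) {  // truncation; real code would round
    std::uint32_t bits;
    std::memcpy(&bits, &f, sizeof(bits));
    return bf16(bits >> 16);
}

// C = A * B with bf16 storage but float accumulation (the compute type).
static void matmul_bf16(const bf16* A, const bf16* B, bf16* C,
                        std::size_t M, std::size_t K, std::size_t N) {
    for (std::size_t i = 0; i < M; ++i)
        for (std::size_t j = 0; j < N; ++j) {
            float acc = 0.f;  // src_to_comp_type, per element
            for (std::size_t k = 0; k < K; ++k)
                acc += bf16_to_f32(A[i * K + k]) * bf16_to_f32(B[k * N + j]);
            C[i * N + j] = f32_to_bf16(acc);  // comp_to_dst_type
        }
}

int main() {
    std::vector<bf16> A = {f32_to_bf16(1.f), f32_to_bf16(2.f)};  // 1x2
    std::vector<bf16> B = {f32_to_bf16(3.f), f32_to_bf16(4.f)};  // 2x1
    std::vector<bf16> C(1);
    matmul_bf16(A.data(), B.data(), C.data(), 1, 2, 1);
    // bf16_to_f32(C[0]) == 11.f; all values here are exactly representable
}
```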