提交 58b682ca 编写于 作者: M Megvii Engine Team

feat(dnn/cuda): add naive bmm

GitOrigin-RevId: 4ba4b22e40368dd0918d02f9d01c24dd0a815640
上级 eddb0aba
......@@ -45,6 +45,7 @@ BatchedMatrixMulForwardImpl::AlgoPack::AlgoPack() {
#endif
all_algos.push_back(&int8x8x32);
all_algos.push_back(&brute_force);
all_algos.push_back(&naive_bmm);
for (auto&& algo : all_algos) {
m_all_algos_map.emplace(algo->info().desc, algo);
......
......@@ -24,6 +24,7 @@ public:
CUDA_CUBLAS,
CUDA_CUBLASLT,
CUDA_INT8X8X32,
CUDA_NAIVE_BMM,
};
using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>;
......@@ -94,6 +95,27 @@ public:
std::vector<SearchItem> get_subopr_list(
const TensorLayoutArray& layouts, const OperatorBase* opr) const override;
};
class BatchedMatrixMulForwardImpl::AlgoNaive final
: public BatchedMatrixMulForwardImpl::AlgoBase {
using Param = MatrixMulForward::Param;
private:
WorkspaceBundle get_workspace_bundle();
public:
bool is_available(const SizeArgs& args) const override;
size_t get_workspace_in_bytes(const SizeArgs& /*args*/) const override {
return 0;
};
void exec(const ExecArgs& args) const final;
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::NAIVE;
}
const char* name() const override { return "NAIVE_BMM"; }
MEGDNN_DECL_ALGO_TYPE(CUDA_NAIVE_BMM)
};
class BatchedMatrixMulForwardImpl::AlgoCublas final
: public BatchedMatrixMulForwardImpl::AlgoBase {
public:
......@@ -148,6 +170,7 @@ public:
AlgoInt8x8x32 int8x8x32;
std::vector<AlgoBase*> all_algos;
AlgoBruteForce brute_force;
AlgoNaive naive_bmm;
const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; }
};
......
/**
* \file dnn/src/cuda/batched_matrix_mul/naive.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "src/cuda/batched_matrix_mul/naive.cuh"
#include <cuda.h>
#include "src/cuda/batched_matrix_mul/algo.h"
#include "src/cuda/utils.h"
using namespace megdnn;
using namespace cuda;
#include "midout.h"
MIDOUT_DECL(megdnn_naive_matmul)
bool BatchedMatrixMulForwardImpl::AlgoNaive::is_available(const SizeArgs& args) const {
auto&& layout_a = args.layout_a;
auto&& layout_b = args.layout_b;
auto&& layout_c = args.layout_c;
return layout_a.dtype.enumv() == layout_b.dtype.enumv() &&
(layout_a.dtype.enumv() == DTypeEnum::Float32 ||
layout_a.dtype.enumv() == DTypeEnum::Float16) &&
(layout_c.dtype.enumv() == DTypeEnum::Float32 ||
layout_c.dtype.enumv() == DTypeEnum::Float16) &&
args.opr->param().format == param::MatrixMul::Format::DEFAULT;
}
void BatchedMatrixMulForwardImpl::AlgoNaive::exec(const ExecArgs& args) const {
auto&& param = args.opr->param();
auto Batch = args.tensor_c.layout.shape[0];
auto m = args.tensor_c.layout.shape[1], n = args.tensor_c.layout.shape[2],
k = args.tensor_a.layout.shape[param.transposeA ? 1 : 2];
auto LDA = args.tensor_a.layout.stride[1], LDB = args.tensor_b.layout.stride[1],
LDC = args.tensor_c.layout.stride[1];
auto&& handle = concrete_handle(args.opr->handle());
using ComputeMode = Param::ComputeMode;
#define DISPATCH_CMODE(in_dt, out_dt, in_ct, out_ct, comp_ct, cmode) \
MIDOUT_BEGIN( \
megdnn_naive_matmul, \
midout_iv(#in_dt #out_dt #in_ct, #out_ct, #comp_ct, #cmode)) { \
do { \
using namespace dtype; \
if (args.tensor_a.layout.dtype.enumv() == DTypeTrait<in_dt>::enumv && \
args.tensor_c.layout.dtype.enumv() == DTypeTrait<out_dt>::enumv && \
param.compute_mode == cmode) { \
in_ct* A = args.tensor_a.compatible_ptr<in_ct>(); \
in_ct* B = args.tensor_b.compatible_ptr<in_ct>(); \
out_ct* C = args.tensor_c.compatible_ptr<out_ct>(); \
exec_bgemm_naive<in_ct, in_ct, out_ct, comp_ct>( \
A, B, C, Batch, m, n, k, LDA, LDB, LDC, param.transposeA, \
param.transposeB, cuda_stream(handle)); \
return; \
} \
} while (0); \
} \
MIDOUT_END();
#define DISPATCH(in_dt, out_dt, in_ct, out_ct, comp_ct) \
DISPATCH_CMODE(in_dt, out_dt, in_ct, out_ct, comp_ct, ComputeMode::DEFAULT)
DISPATCH(Float32, Float32, dt_float32, dt_float32, dt_float32);
DISPATCH(Float16, Float16, dt_float16, dt_float16, dt_float16);
DNN_INC_FLOAT16(DISPATCH_CMODE(
Float16, Float16, dt_float16, dt_float16, dt_float32,
ComputeMode::FLOAT32));
#undef DISPATCH_CMODE
#undef DISPATCH
megdnn_throw(ssprintf(
"unsupported Matmul(%s, %s) -> %s with cmode = %d",
args.layout_a.dtype.name(), args.layout_b.dtype.name(),
args.layout_c.dtype.name(), static_cast<int>(param.compute_mode)));
}
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/cuda/batched_matrix_mul/naive.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include <cuda.h>
#include "src/cuda/matrix_mul/naive.cuh"
#include "src/cuda/utils.cuh"
namespace {
template <typename AType, typename BType, typename CType, typename CompType>
__global__ void do_exec(
const AType* A, const BType* B, CType* C, size_t Batch, size_t M, size_t N,
size_t K, size_t LDA, size_t LDB, size_t LDC, bool transA, bool transB) {
for (int bid = blockIdx.x; bid < Batch; bid += gridDim.x) {
const AType* A_r = A + (transA ? bid * K * LDA : bid * M * LDA);
const BType* B_r = B + (transB ? bid * N * LDB : bid * K * LDB);
CType* C_r = C + bid * M * LDC;
for (size_t m = 0; m < M; ++m) {
size_t n = threadIdx.x;
for (; n < N; n += blockDim.x) {
CompType res = static_cast<CompType>(0);
for (size_t k = 0; k < K; ++k) {
AType av = transA ? A_r[k * LDA + m] : A_r[m * LDA + k];
BType bv = transB ? B_r[n * LDB + k] : B_r[k * LDB + n];
res += av * bv;
}
C_r[m * LDC + n] = res;
}
}
}
}
} // namespace
namespace megdnn {
namespace cuda {
template <typename AType, typename BType, typename CType, typename CompType>
void exec_bgemm_naive(
const AType* A, const BType* B, CType* C, size_t Batch, size_t M, size_t N,
size_t K, size_t LDA, size_t LDB, size_t LDC, bool transA, bool transB,
cudaStream_t stream) {
do_exec<AType, BType, CType, CompType><<<Batch, 128, 0, stream>>>(
A, B, C, Batch, M, N, K, LDA, LDB, LDC, transA, transB);
}
#define INST(in_ct, out_ct, comp_ct) \
template void exec_bgemm_naive< \
typename in_ct, typename in_ct, typename out_ct, typename comp_ct>( \
const in_ct* A, const in_ct* B, out_ct* C, size_t Batch, size_t M, \
size_t N, size_t K, size_t LDA, size_t LDB, size_t LDC, bool transA, \
bool transB, cudaStream_t stream);
INST(megdnn::dt_float32, megdnn::dt_float32, megdnn::dt_float32)
INST(megdnn::dt_float16, megdnn::dt_float16, megdnn::dt_float16)
INST(megdnn::dt_float16, megdnn::dt_float16, megdnn::dt_float32)
#undef cb
#undef INST
} // namespace cuda
} // namespace megdnn
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/cuda/batched_matrix_mul/naive.cuh
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "src/cuda/utils.cuh"
namespace megdnn {
namespace cuda {
template <typename AType, typename BType, typename CType, typename CompType>
void exec_bgemm_naive(
const AType* A, const BType* B, CType* C, size_t Batch, size_t m, size_t n,
size_t k, size_t ldA, size_t ldB, size_t ldC, bool transA, bool transB,
cudaStream_t stream);
} // namespace cuda
} // namespace megdnn
// vim: syntax=cpp.doxygen
......@@ -10,6 +10,7 @@ public:
BatchedMatrixMulForwardImpl(Handle* handle) : BatchedMatrixMul(handle) {}
class AlgoBase;
class AlgoNaive;
class AlgoBruteForce;
class AlgoCublas;
#if CUDA_VERSION >= 10010
......
......@@ -266,19 +266,14 @@ void matrix_mul::check_matrix_mul(
checker.set_param(param);
if (format == param::MatrixMul::Format::DEFAULT) {
if (batched) {
checker.execl(
{TensorLayout{
{arg.b, A0, A1},
{A_batch_stride, A_stride, 1},
A_dtype},
TensorLayout{
{arg.b, B0, B1},
{B_batch_stride, B_stride, 1},
B_dtype},
TensorLayout{
{arg.b, m, n},
{C_batch_stride, C_stride, 1},
C_dtype}});
auto a_layout = TensorLayout{
{arg.b, A0, A1}, {A_batch_stride, A_stride, 1}, A_dtype};
auto b_layout = TensorLayout{
{arg.b, B0, B1}, {B_batch_stride, B_stride, 1}, B_dtype};
auto c_layout = TensorLayout{
{arg.b, m, n}, {C_batch_stride, C_stride, 1}, C_dtype};
checker.execl({a_layout, b_layout, c_layout});
} else {
checker.execl(
{TensorLayout{{A0, A1}, {A_stride, 1}, A_dtype},
......
......@@ -105,6 +105,62 @@ TEST_F(CUDA, BATCHED_MATRIX_MUL_F32_BRUTE_FORCE_PART4) {
matrix_mul::get_batched_matmul_args_mask(3));
}
TEST_F(CUDA, BATCHED_MATRIX_MUL_F32_NAIVE_PART0) {
matrix_mul::check_batched_matrix_mul(
dtype::Float32{}, dtype::Float32{}, {}, handle_cuda(),
ExecutionPolicyAlgoName{"NAIVE_BMM"}, 1e-5,
matrix_mul::get_batched_matmul_args_mask(0));
}
TEST_F(CUDA, BATCHED_MATRIX_MUL_F32_NAIVE_PART1) {
matrix_mul::check_batched_matrix_mul(
dtype::Float32{}, dtype::Float32{}, {}, handle_cuda(),
ExecutionPolicyAlgoName{"NAIVE_BMM"}, 1e-5,
matrix_mul::get_batched_matmul_args_mask(1));
}
TEST_F(CUDA, BATCHED_MATRIX_MUL_F32_NAIVE_PART2) {
matrix_mul::check_batched_matrix_mul(
dtype::Float32{}, dtype::Float32{}, {}, handle_cuda(),
ExecutionPolicyAlgoName{"NAIVE_BMM"}, 1e-5,
matrix_mul::get_batched_matmul_args_mask(2));
}
TEST_F(CUDA, BATCHED_MATRIX_MUL_F32_NAIVE_PART3) {
matrix_mul::check_batched_matrix_mul(
dtype::Float32{}, dtype::Float32{}, {}, handle_cuda(),
ExecutionPolicyAlgoName{"NAIVE_BMM"}, 1e-5,
matrix_mul::get_batched_matmul_args_mask(3));
}
TEST_F(CUDA, BATCHED_MATRIX_MUL_F16_NAIVE_PART0) {
matrix_mul::check_batched_matrix_mul(
dtype::Float16{}, dtype::Float16{}, {}, handle_cuda(),
ExecutionPolicyAlgoName{"NAIVE_BMM"}, 1e-5,
matrix_mul::get_batched_matmul_args_mask(0));
}
TEST_F(CUDA, BATCHED_MATRIX_MUL_F16_NAIVE_PART1) {
matrix_mul::check_batched_matrix_mul(
dtype::Float16{}, dtype::Float16{}, {}, handle_cuda(),
ExecutionPolicyAlgoName{"NAIVE_BMM"}, 1e-5,
matrix_mul::get_batched_matmul_args_mask(1));
}
TEST_F(CUDA, BATCHED_MATRIX_MUL_F16_NAIVE_PART2) {
matrix_mul::check_batched_matrix_mul(
dtype::Float16{}, dtype::Float16{}, {}, handle_cuda(),
ExecutionPolicyAlgoName{"NAIVE_BMM"}, 1e-5,
matrix_mul::get_batched_matmul_args_mask(2));
}
TEST_F(CUDA, BATCHED_MATRIX_MUL_F16_NAIVE_PART3) {
matrix_mul::check_batched_matrix_mul(
dtype::Float16{}, dtype::Float16{}, {}, handle_cuda(),
ExecutionPolicyAlgoName{"NAIVE_BMM"}, 1e-5,
matrix_mul::get_batched_matmul_args_mask(3));
}
TEST_F(CUDA, BATCHED_MATRIX_MUL_F16_PART1) {
require_compute_capability(6, 0);
matrix_mul::check_batched_matrix_mul(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册