Commit 4a1d52c9 authored by Megvii Engine Team

refactor(megdnn): refactor bfloat16 matmul to recursive interface

GitOrigin-RevId: 641c508aecab700c5c7c8cb758900ab605a29620
Parent b8febaf9
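
This refactor replaces constructor injection of a concrete cuBLAS algorithm with the recursive sub-operator interface: the BFloat16 algorithm describes its Float32 sub-matmul via get_subopr_list(), and the chosen sub-algorithm is carried in execution_policy().sub_policy rather than being fixed at construction time. A minimal sketch of the shape of that interface, with simplified types for illustration only:

    // Simplified for illustration; not the exact MegDNN declarations.
    struct SearchItem {
        OprType opr_type;           // operator kind of the sub-problem
        std::string param;          // serialized Param, used as a cache key
        TensorLayoutArray layouts;  // transformed (here: Float32) layouts
    };

    struct AlgoBase {
        // Algorithms that delegate declare their sub-problems here; the
        // dispatcher resolves them recursively and records each choice in
        // ExecutionPolicy::sub_policy, one entry per SearchItem.
        virtual std::vector<SearchItem> get_subopr_list(
                const TensorLayoutArray&, const OperatorBase*) const {
            return {};
        }
    };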
......@@ -31,8 +31,7 @@ MatrixMulForwardImpl::AlgoPack::AlgoPack() {
#endif
all_algos.push_back(&naive);
#if !MEGDNN_DISABLE_FLOAT16
cublas_bfloat16 = std::make_unique<AlgoBFloat16>(&cublas);
all_algos.push_back(cublas_bfloat16.get());
all_algos.push_back(&bfloat16);
#endif
for (auto&& algo : all_algos) {
......
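With the constructor argument gone, AlgoPack no longer needs a heap-allocated wrapper bound to the cuBLAS instance; a condensed view of the resulting registration (other algorithms elided, names as in this diff):

    struct AlgoPack {
        AlgoBFloat16 bfloat16;             // held by value, default-constructed
        std::vector<AlgoBase*> all_algos;  // flat registration list
        AlgoPack() {
    #if !MEGDNN_DISABLE_FLOAT16
            all_algos.push_back(&bfloat16);
    #endif
        }
    };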
......@@ -148,25 +148,20 @@ public:
#if !MEGDNN_DISABLE_FLOAT16
class MatrixMulForwardImpl::AlgoBFloat16 final : public AlgoBase {
public:
AlgoBFloat16(MatrixMulForwardImpl::AlgoBase*);
bool is_available(const SizeArgs& args) const override;
size_t get_workspace_in_bytes(const SizeArgs& args) const override;
const char* name() const override { return m_name.c_str(); }
void exec(const ExecArgs& args) const override;
bool is_reproducible() const override { return true; }
MEGDNN_DECL_ALGO_TYPE(CUDA_NAIVE)
MEGDNN_DECL_ALGO_TYPE(CUDA_BFLOAT16)
std::string param() const override {
std::string ret;
serialize_write_pod(m_algorithm, ret);
return ret;
}
std::vector<SearchItem> get_subopr_list(
const TensorLayoutArray& layouts,
const OperatorBase* opr) const override;
const char* name() const override { return "MATMUL_BFLOAT16"; }
bool is_reproducible() const override { return true; }
private:
MatrixMulForwardImpl::AlgoBase* m_algorithm = nullptr;
std::string m_name;
WorkspaceBundle get_workspace_bundle(void* ptr, const SizeArgs& args) const;
SizeArgs float_args(const SizeArgs& args) const;
};
#endif
......@@ -185,7 +180,7 @@ public:
AlgoCuBlasLt cublas_lt;
#endif
#if !MEGDNN_DISABLE_FLOAT16
std::unique_ptr<AlgoBFloat16> cublas_bfloat16;
AlgoBFloat16 bfloat16;
#endif
std::vector<AlgoBase*> all_algos;
......
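In the header, the wrapped-algorithm state (m_algorithm, m_name, the param() override that serialized it, and float_args) disappears; the class instead exposes a fixed name and get_subopr_list(), through which the dispatcher learns about the Float32 sub-problem. A hedged sketch of how a dispatcher might expand one returned SearchItem; profile_subopr and deserialize_read_pod are assumed names for illustration, not verified MegDNN API:

    // Hypothetical consumer: expand a SearchItem into a concrete
    // sub-operator.  deserialize_read_pod is assumed to be the inverse of
    // the serialize_write_pod call used in get_subopr_list() below.
    void profile_subopr(Handle* handle, const Algorithm::SearchItem& item) {
        auto sub = handle->create_operator<MatrixMulForward>();
        sub->param() =
                deserialize_read_pod<MatrixMulForward::Param>(item.param);
        // ... profile `sub` on item.layouts and cache the best algorithm,
        // keyed by (item.opr_type, item.param, item.layouts) ...
    }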
......@@ -6,59 +6,87 @@
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/cuda/handle.h"
#include "src/cuda/matrix_mul/algos.h"
#include "src/cuda/utils.h"
#include "src/common/algo_chooser.h"
using namespace megdnn;
using namespace cuda;
MatrixMulForwardImpl::AlgoBFloat16::AlgoBFloat16(
MatrixMulForwardImpl::AlgoBase* algorithm)
: m_algorithm(algorithm) {
megdnn_assert_internal(algorithm);
m_name = ssprintf("MATMUL_BFLOAT16:%s", m_algorithm->name());
}
MatrixMulForwardImpl::AlgoBase::SizeArgs
MatrixMulForwardImpl::AlgoBFloat16::float_args(const SizeArgs& args) const {
auto new_args = args;
namespace {
std::pair<TensorLayoutArray, MatrixMulForwardImpl::Param> sub_opr_config(
const TensorLayoutArray& layouts, const MatrixMulForwardImpl* opr) {
megdnn_assert(layouts.size() == 3);
std::pair<TensorLayoutArray, MatrixMulForwardImpl::Param> ret;
ret.first = layouts;
auto change_dtype = [](TensorLayout& layout) {
if (layout.dtype == dtype::BFloat16()) {
layout.dtype = dtype::Float32();
}
};
change_dtype(new_args.layout_a);
change_dtype(new_args.layout_b);
change_dtype(new_args.layout_c);
return new_args;
change_dtype(ret.first[0]);
change_dtype(ret.first[1]);
change_dtype(ret.first[2]);
ret.second = opr->param();
ret.second.compute_mode = MatrixMulForwardImpl::Param::ComputeMode::DEFAULT;
return ret;
}
} // namespace
std::vector<Algorithm::SearchItem>
MatrixMulForwardImpl::AlgoBFloat16::get_subopr_list(
const TensorLayoutArray& layouts, const OperatorBase* opr) const {
auto&& config = sub_opr_config(
layouts, static_cast<const MatrixMulForwardImpl*>(opr));
std::string param_str;
Algorithm::serialize_write_pod(config.second, param_str);
return {{Algorithm::OprType::MATRIX_MUL_FORWARD, param_str, config.first}};
}
bool MatrixMulForwardImpl::AlgoBFloat16::is_available(
const SizeArgs& args) const {
auto fargs = float_args(args);
auto&& config = sub_opr_config(
{args.layout_a, args.layout_b, args.layout_c}, args.opr);
auto matmul_opr = args.opr->handle()->create_operator<MatrixMulForward>();
matmul_opr->param() = config.second;
return args.layout_a.dtype == dtype::BFloat16() &&
m_algorithm->is_available(fargs);
get_algorithm(static_cast<MatrixMulForwardImpl*>(matmul_opr.get()),
config.first[0], config.first[1], config.first[2]);
}
WorkspaceBundle MatrixMulForwardImpl::AlgoBFloat16::get_workspace_bundle(
void* ptr, const SizeArgs& args) const {
auto fargs = float_args(args);
auto matmul_opr = args.opr->handle()->create_operator<MatrixMulForward>();
if (args.opr->execution_policy().algo.valid()) {
megdnn_assert(args.opr->execution_policy().sub_policy.size() == 1);
matmul_opr->execution_policy() =
args.opr->execution_policy().sub_policy[0];
}
auto&& config = sub_opr_config(
{args.layout_a, args.layout_b, args.layout_c}, args.opr);
matmul_opr->param() = config.second;
SmallVector<size_t> sizes;
auto get_workspace = [&sizes](const TensorLayout& src) {
TensorLayout dst = src;
if (dst.dtype == dtype::BFloat16()) {
dst.dtype = dtype::Float32();
auto get_workspace = [&sizes](const TensorLayout& src,
const TensorLayout& dst) {
if (src.dtype != dst.dtype) {
sizes.push_back(dst.span().dist_byte());
}
};
get_workspace(args.layout_a);
get_workspace(args.layout_b);
get_workspace(args.layout_c);
sizes.push_back(m_algorithm->get_workspace_in_bytes(fargs));
get_workspace(args.layout_a, config.first[0]);
get_workspace(args.layout_b, config.first[1]);
get_workspace(args.layout_c, config.first[2]);
sizes.push_back(matmul_opr->get_workspace_in_bytes(
config.first[0], config.first[1], config.first[2]));
return {ptr, std::move(sizes)};
}
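Taken together, sub_opr_config() and get_workspace_bundle() define the Float32 sub-problem and its staging memory: dtypes are promoted, compute_mode is reset to DEFAULT, and one buffer is reserved per converted tensor plus the sub-matmul's own workspace. An illustrative walk-through under assumed shapes:

    // Illustration only: assumed shapes M=64, K=32, N=128, contiguous.
    TensorLayout a({64, 32}, dtype::BFloat16());
    TensorLayout b({32, 128}, dtype::BFloat16());
    TensorLayout c({64, 128}, dtype::BFloat16());
    auto config = sub_opr_config({a, b, c}, opr);  // opr: the parent impl
    // config.first  : the same three shapes, dtype promoted to Float32
    // config.second : opr->param() with compute_mode == DEFAULT
    //
    // Workspace accounting (dist_byte() of the Float32 staging copies):
    //   a: 64*32*4  =  8192 bytes
    //   b: 32*128*4 = 16384 bytes
    //   c: 64*128*4 = 32768 bytes
    // plus matmul_opr->get_workspace_in_bytes(...) for the Float32 run,
    // packed sequentially by WorkspaceBundle.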
......@@ -82,7 +110,12 @@ void MatrixMulForwardImpl::AlgoBFloat16::exec(const ExecArgs& args) const {
args.opr->handle()->create_operator<MatrixMulForward>();
matmul_opr->param() = args.opr->param();
matmul_opr->param().compute_mode = Param::ComputeMode::DEFAULT;
matmul_opr->execution_policy() = {m_algorithm->desc(), {}};
if (args.opr->execution_policy().algo.valid()) {
megdnn_assert(args.opr->execution_policy().sub_policy.size() == 1);
matmul_opr->execution_policy() =
args.opr->execution_policy().sub_policy[0];
}
matmul_opr->exec(a, b, c, ctypecvt.workspace());
}
ctypecvt.comp_to_dst_type(c, args.tensor_c);
......
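exec() now converts the BFloat16 tensors, forwards the caller's sub-policy to the inner Float32 matmul when one was supplied, and otherwise lets the sub-operator pick its own algorithm heuristically. A hedged sketch of pinning the inner algorithm through the outer policy (the two *_desc values are assumed handles to previously selected algorithms):

    // Sketch: the outer policy names the BFloat16 wrapper; sub_policy[0]
    // carries the Float32 sub-matmul's choice, which exec() forwards.
    auto opr = handle->create_operator<MatrixMulForward>();
    opr->execution_policy().algo = bf16_algo_desc;           // assumed desc
    opr->execution_policy().sub_policy = {{cublas_algo_desc, {}}};
    opr->exec(a_tensor, b_tensor, c_tensor, workspace);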
......@@ -218,6 +218,9 @@ TEST_F(CUDA, MATRIX_MUL) {
B = TensorShape{k, n};
if (dtype == dtype::BFloat16()) {
param.compute_mode = param::MatrixMul::ComputeMode::FLOAT32;
checker.set_before_exec_callback(
AlgoChecker<MatrixMulForward>(ExecutionPolicyAlgoName{
"MATMUL_BFLOAT16", {{"CUBLAS", {}}}}));
}
checker.set_param(param)
.set_dtype(0, stype)
......@@ -228,6 +231,10 @@ TEST_F(CUDA, MATRIX_MUL) {
? 5e-2
: 5e-3)
.execs({A, B, {}});
if (dtype == dtype::BFloat16()) {
checker.reset_before_exec_callback();
checker.opr()->execution_policy() = {};
}
}
}
......
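The test pins the full recursive choice by name: ExecutionPolicyAlgoName nests the sub-algorithm name under the wrapper's, mirroring the policy tree one level per recursion step, and both the callback and the policy are cleared afterwards so later dtypes in the loop run unconstrained. Restated on its own:

    // The name tree mirrors ExecutionPolicy: outer algorithm, then one
    // entry per sub-operator.  Names as used in the test above.
    ExecutionPolicyAlgoName policy{"MATMUL_BFLOAT16", {{"CUBLAS", {}}}};
    checker.set_before_exec_callback(AlgoChecker<MatrixMulForward>(policy));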