提交 4a1d52c9 编写于 作者: M Megvii Engine Team

refactor(megdnn): refactor bfloat16 matmul to recursive interface

GitOrigin-RevId: 641c508aecab700c5c7c8cb758900ab605a29620
上级 b8febaf9
...@@ -31,8 +31,7 @@ MatrixMulForwardImpl::AlgoPack::AlgoPack() { ...@@ -31,8 +31,7 @@ MatrixMulForwardImpl::AlgoPack::AlgoPack() {
#endif #endif
all_algos.push_back(&naive); all_algos.push_back(&naive);
#if !MEGDNN_DISABLE_FLOAT16 #if !MEGDNN_DISABLE_FLOAT16
cublas_bfloat16 = std::make_unique<AlgoBFloat16>(&cublas); all_algos.push_back(&bfloat16);
all_algos.push_back(cublas_bfloat16.get());
#endif #endif
for (auto&& algo : all_algos) { for (auto&& algo : all_algos) {
......
...@@ -148,25 +148,20 @@ public: ...@@ -148,25 +148,20 @@ public:
#if !MEGDNN_DISABLE_FLOAT16 #if !MEGDNN_DISABLE_FLOAT16
class MatrixMulForwardImpl::AlgoBFloat16 final : public AlgoBase { class MatrixMulForwardImpl::AlgoBFloat16 final : public AlgoBase {
public: public:
AlgoBFloat16(MatrixMulForwardImpl::AlgoBase*);
bool is_available(const SizeArgs& args) const override; bool is_available(const SizeArgs& args) const override;
size_t get_workspace_in_bytes(const SizeArgs& args) const override; size_t get_workspace_in_bytes(const SizeArgs& args) const override;
const char* name() const override { return m_name.c_str(); }
void exec(const ExecArgs& args) const override; void exec(const ExecArgs& args) const override;
bool is_reproducible() const override { return true; } MEGDNN_DECL_ALGO_TYPE(CUDA_BFLOAT16)
MEGDNN_DECL_ALGO_TYPE(CUDA_NAIVE)
std::string param() const override { std::vector<SearchItem> get_subopr_list(
std::string ret; const TensorLayoutArray& layouts,
serialize_write_pod(m_algorithm, ret); const OperatorBase* opr) const override;
return ret;
} const char* name() const override { return "MATMUL_BFLOAT16"; }
bool is_reproducible() const override { return true; }
private: private:
MatrixMulForwardImpl::AlgoBase* m_algorithm = nullptr;
std::string m_name;
WorkspaceBundle get_workspace_bundle(void* ptr, const SizeArgs& args) const; WorkspaceBundle get_workspace_bundle(void* ptr, const SizeArgs& args) const;
SizeArgs float_args(const SizeArgs& args) const;
}; };
#endif #endif
...@@ -185,7 +180,7 @@ public: ...@@ -185,7 +180,7 @@ public:
AlgoCuBlasLt cublas_lt; AlgoCuBlasLt cublas_lt;
#endif #endif
#if !MEGDNN_DISABLE_FLOAT16 #if !MEGDNN_DISABLE_FLOAT16
std::unique_ptr<AlgoBFloat16> cublas_bfloat16; AlgoBFloat16 bfloat16;
#endif #endif
std::vector<AlgoBase*> all_algos; std::vector<AlgoBase*> all_algos;
......
...@@ -6,59 +6,87 @@ ...@@ -6,59 +6,87 @@
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/ */
#include "src/cuda/handle.h" #include "src/cuda/handle.h"
#include "src/cuda/matrix_mul/algos.h" #include "src/cuda/matrix_mul/algos.h"
#include "src/cuda/utils.h" #include "src/cuda/utils.h"
#include "src/common/algo_chooser.h"
using namespace megdnn; using namespace megdnn;
using namespace cuda; using namespace cuda;
MatrixMulForwardImpl::AlgoBFloat16::AlgoBFloat16( namespace {
MatrixMulForwardImpl::AlgoBase* algorithm) std::pair<TensorLayoutArray, MatrixMulForwardImpl::Param> sub_opr_config(
: m_algorithm(algorithm) { const TensorLayoutArray& layouts, const MatrixMulForwardImpl* opr) {
megdnn_assert_internal(algorithm); megdnn_assert(layouts.size() == 3);
m_name = ssprintf("MATMUL_BFLOAT16:%s", m_algorithm->name()); std::pair<TensorLayoutArray, MatrixMulForwardImpl::Param> ret;
} ret.first = layouts;
MatrixMulForwardImpl::AlgoBase::SizeArgs
MatrixMulForwardImpl::AlgoBFloat16::float_args(const SizeArgs& args) const {
auto new_args = args;
auto change_dtype = [](TensorLayout& layout) { auto change_dtype = [](TensorLayout& layout) {
if (layout.dtype == dtype::BFloat16()) { if (layout.dtype == dtype::BFloat16()) {
layout.dtype = dtype::Float32(); layout.dtype = dtype::Float32();
} }
}; };
change_dtype(new_args.layout_a); change_dtype(ret.first[0]);
change_dtype(new_args.layout_b); change_dtype(ret.first[1]);
change_dtype(new_args.layout_c); change_dtype(ret.first[2]);
return new_args;
ret.second = opr->param();
ret.second.compute_mode = MatrixMulForwardImpl::Param::ComputeMode::DEFAULT;
return ret;
}
} // namespace
std::vector<Algorithm::SearchItem>
MatrixMulForwardImpl::AlgoBFloat16::get_subopr_list(
const TensorLayoutArray& layouts, const OperatorBase* opr) const {
auto&& config = sub_opr_config(
layouts, static_cast<const MatrixMulForwardImpl*>(opr));
std::string param_str;
Algorithm::serialize_write_pod(config.second, param_str);
return {{Algorithm::OprType::MATRIX_MUL_FORWARD, param_str, config.first}};
} }
bool MatrixMulForwardImpl::AlgoBFloat16::is_available( bool MatrixMulForwardImpl::AlgoBFloat16::is_available(
const SizeArgs& args) const { const SizeArgs& args) const {
auto fargs = float_args(args); auto&& config = sub_opr_config(
{args.layout_a, args.layout_b, args.layout_c}, args.opr);
auto matmul_opr = args.opr->handle()->create_operator<MatrixMulForward>();
matmul_opr->param() = config.second;
return args.layout_a.dtype == dtype::BFloat16() && return args.layout_a.dtype == dtype::BFloat16() &&
m_algorithm->is_available(fargs); get_algorithm(static_cast<MatrixMulForwardImpl*>(matmul_opr.get()),
config.first[0], config.first[1], config.first[2]);
} }
WorkspaceBundle MatrixMulForwardImpl::AlgoBFloat16::get_workspace_bundle( WorkspaceBundle MatrixMulForwardImpl::AlgoBFloat16::get_workspace_bundle(
void* ptr, const SizeArgs& args) const { void* ptr, const SizeArgs& args) const {
auto fargs = float_args(args); auto matmul_opr = args.opr->handle()->create_operator<MatrixMulForward>();
if (args.opr->execution_policy().algo.valid()) {
megdnn_assert(args.opr->execution_policy().sub_policy.size() == 1);
matmul_opr->execution_policy() =
args.opr->execution_policy().sub_policy[0];
}
auto&& config = sub_opr_config(
{args.layout_a, args.layout_b, args.layout_c}, args.opr);
matmul_opr->param() = config.second;
SmallVector<size_t> sizes; SmallVector<size_t> sizes;
auto get_workspace = [&sizes](const TensorLayout& src) { auto get_workspace = [&sizes](const TensorLayout& src,
TensorLayout dst = src; const TensorLayout& dst) {
if (dst.dtype == dtype::BFloat16()) { if (src.dtype != dst.dtype) {
dst.dtype = dtype::Float32();
sizes.push_back(dst.span().dist_byte()); sizes.push_back(dst.span().dist_byte());
} }
}; };
get_workspace(args.layout_a);
get_workspace(args.layout_b); get_workspace(args.layout_a, config.first[0]);
get_workspace(args.layout_c); get_workspace(args.layout_b, config.first[1]);
sizes.push_back(m_algorithm->get_workspace_in_bytes(fargs)); get_workspace(args.layout_c, config.first[2]);
sizes.push_back(matmul_opr->get_workspace_in_bytes(
config.first[0], config.first[1], config.first[2]));
return {ptr, std::move(sizes)}; return {ptr, std::move(sizes)};
} }
...@@ -82,7 +110,12 @@ void MatrixMulForwardImpl::AlgoBFloat16::exec(const ExecArgs& args) const { ...@@ -82,7 +110,12 @@ void MatrixMulForwardImpl::AlgoBFloat16::exec(const ExecArgs& args) const {
args.opr->handle()->create_operator<MatrixMulForward>(); args.opr->handle()->create_operator<MatrixMulForward>();
matmul_opr->param() = args.opr->param(); matmul_opr->param() = args.opr->param();
matmul_opr->param().compute_mode = Param::ComputeMode::DEFAULT; matmul_opr->param().compute_mode = Param::ComputeMode::DEFAULT;
matmul_opr->execution_policy() = {m_algorithm->desc(), {}}; if (args.opr->execution_policy().algo.valid()) {
megdnn_assert(args.opr->execution_policy().sub_policy.size() == 1);
matmul_opr->execution_policy() =
args.opr->execution_policy().sub_policy[0];
}
matmul_opr->exec(a, b, c, ctypecvt.workspace()); matmul_opr->exec(a, b, c, ctypecvt.workspace());
} }
ctypecvt.comp_to_dst_type(c, args.tensor_c); ctypecvt.comp_to_dst_type(c, args.tensor_c);
......
...@@ -218,6 +218,9 @@ TEST_F(CUDA, MATRIX_MUL) { ...@@ -218,6 +218,9 @@ TEST_F(CUDA, MATRIX_MUL) {
B = TensorShape{k, n}; B = TensorShape{k, n};
if (dtype == dtype::BFloat16()) { if (dtype == dtype::BFloat16()) {
param.compute_mode = param::MatrixMul::ComputeMode::FLOAT32; param.compute_mode = param::MatrixMul::ComputeMode::FLOAT32;
checker.set_before_exec_callback(
AlgoChecker<MatrixMulForward>(ExecutionPolicyAlgoName{
"MATMUL_BFLOAT16", {{"CUBLAS", {}}}}));
} }
checker.set_param(param) checker.set_param(param)
.set_dtype(0, stype) .set_dtype(0, stype)
...@@ -228,6 +231,10 @@ TEST_F(CUDA, MATRIX_MUL) { ...@@ -228,6 +231,10 @@ TEST_F(CUDA, MATRIX_MUL) {
? 5e-2 ? 5e-2
: 5e-3) : 5e-3)
.execs({A, B, {}}); .execs({A, B, {}});
if (dtype == dtype::BFloat16()) {
checker.reset_before_exec_callback();
checker.opr()->execution_policy() = {};
}
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册