Commit 04b03532 authored by Megvii Engine Team

feat(jit/opencl): add OpenCL jit test

GitOrigin-RevId: 9467bc7d51e3a1d343b3745d7ad0ce577ddffae8
Parent 6ce3040f
@@ -53,7 +53,7 @@ void gen_input_code_and_gen_input_data_update(
                 (is_half ? "half4 x_after_read" : "float4 x_after_read") +
                 std::to_string(i));
         std::string coord = "coord";
-        if (LayoutType::BROADCAST == b_info[i]) {
+        if (LayoutType::CHANNEL_BROADCAST == b_info[i]) {
             coord = "coord_b";
         }
         std::string read_method = read_image_func + "(" +
@@ -102,6 +102,7 @@ __kernel void {{KERNEL_NAME}} (
         __private const int global_size_dim1,
         __private const int wc_size,
         __private const int hb_size,
+        __private const int h,
         __private const uint w_size
 ) {
 #if OPENCL_ENABLE_FP16
@@ -121,7 +122,7 @@ __kernel void {{KERNEL_NAME}} (
     for (; hb < hb_size; hb += global_size_dim1) {
         for (; wc < wc_size; wc += global_size_dim0) {
             int2 coord = (int2)(wc, hb);
-            int2 coord_b = (int2)(wc / w_size, 0);
+            int2 coord_b = (int2)(wc / w_size, hb/h);
             {{INTERNAL_DECL_EXPRS}}
             {{ASSIGN_EXPRS}}
             {{INTERNAL_ASSIGN_EXPRS}}
......
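For reference, a rough sketch of the index arithmetic behind the coord_b change above, assuming the image x-axis packs width inside channel blocks (wc = c_block * w_size + w) and the y-axis packs height inside batches (hb = batch * h + h_idx); the variable names and shapes below are illustrative, not taken from the patch:

// Standalone C++ sketch: how the broadcast read coordinate collapses to
// (channel block, batch) once h is passed to the kernel.
#include <cstdio>

int main() {
    int w_size = 5, h = 4;          // output W and H (h is the newly added kernel argument)
    int c_block = 2, w = 3;         // channel block index and width index of the work item
    int batch = 1, h_idx = 2;       // batch index and height index of the work item

    int wc = c_block * w_size + w;  // x coordinate of the full-size output texel
    int hb = batch * h + h_idx;     // y coordinate of the full-size output texel

    // For a channel-broadcast-like input (H == W == 1), only the channel block
    // and the batch matter, so the read coordinate collapses to:
    int bx = wc / w_size;           // == c_block
    int by = hb / h;                // == batch (this was hard-coded to 0 before the patch)
    printf("coord_b = (%d, %d)\n", bx, by);
    return 0;
}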
@@ -109,6 +109,7 @@ void OpenCLExecutable::execute(JITExecutor* fusion_opr) {
     }
     mgb_assert(
             args.outputs.size() == 1, "OpenCL elemwise jit output size should be one");
+    size_t h = args.outputs[0].layout[1];
     //! create kernel
     std::string compile_options;
@@ -175,7 +176,8 @@ void OpenCLExecutable::execute(JITExecutor* fusion_opr) {
            {{&i_WGSX, sizeof(int)},
             {&i_WGSY, sizeof(int)},
             {&wc_size, sizeof(int)},
-            {&hb_size, sizeof(int)}});
+            {&hb_size, sizeof(int)},
+            {&h, sizeof(int)}});
     //! have broadcasted_channel_like_input case
     int may_w_size = args.outputs[0].layout[3];
     kernel.add_arg({&may_w_size, sizeof(cl_uint)});
......
@@ -32,7 +32,7 @@ std::vector<LayoutType> get_channel_broadcast_info(
         if (ih == h && iw == w) {
             ret.push_back(LayoutType::VEC);
         } else {
-            ret.push_back(LayoutType::BROADCAST);
+            ret.push_back(LayoutType::CHANNEL_BROADCAST);
             mgb_assert(ih == 1 && iw == 1, "invalid args for OpenCL jit");
         }
     }
......
@@ -21,7 +21,7 @@ int safe_int(S val) {
 enum class LayoutType {
     SCALAR = 0,
-    BROADCAST = 1,
+    CHANNEL_BROADCAST = 1,
     VEC = 2,
 };
......
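The two hunks above rename the tag to CHANNEL_BROADCAST at its definition and its use site. As a quick illustration of the classification rule, here is a toy standalone version (not the MegBrain function itself):

#include <cassert>
#include <cstdio>

enum class LayoutType { SCALAR = 0, CHANNEL_BROADCAST = 1, VEC = 2 };

// Toy re-statement of the rule in get_channel_broadcast_info above: an input
// matching the output H/W is VEC, an input with H == W == 1 is a per-channel
// broadcast, and anything else is rejected.
LayoutType classify(int ih, int iw, int h, int w) {
    if (ih == h && iw == w)
        return LayoutType::VEC;
    assert(ih == 1 && iw == 1 && "invalid args for OpenCL jit");
    return LayoutType::CHANNEL_BROADCAST;
}

int main() {
    printf("%d\n", static_cast<int>(classify(32, 32, 32, 32)));  // 2 -> VEC
    printf("%d\n", static_cast<int>(classify(1, 1, 32, 32)));    // 1 -> CHANNEL_BROADCAST
    return 0;
}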
 #include <memory>
 #include "./helper.h"
+#include "megbrain/comp_node_env.h"
+#include "megbrain/gopt/inference.h"
 #include "megbrain/jit/executor_opr.h"
 #include "megbrain/opr/basic_arith.h"
 #include "megbrain/opr/basic_arith_wrapper.h"
@@ -279,7 +281,6 @@ void run_mlir_mode(CompNode cn) {
     MGB_ASSERT_TENSOR_NEAR(host_y, host_y_jit, opt.maxerr);
 }
 #endif
 }  // anonymous namespace
 /* ===================== TestJITHalideCodeGenCude ===================== */
@@ -323,32 +324,11 @@ TEST(TestJITMlirCodeGen, BasicGPU) {
 }
 /* ===================== TestJITMlirUnaryElemwise ===================== */
-// clang-format off
-#define FOREACH_UNARY_MODE(cb) \
-    cb(RELU) \
-    cb(ABS) \
-    cb(NEGATE) \
-    cb(ACOS) \
-    cb(ASIN) \
-    cb(CEIL) \
-    cb(EXP) \
-    cb(FLOOR) \
-    cb(LOG) \
-    cb(LOG1P) \
-    cb(SIN) \
-    cb(COS) \
-    cb(TANH) \
-    cb(FAST_TANH) \
-    cb(H_SWISH) \
-    cb(SIGMOID) \
-    cb(EXPM1) \
-    cb(ROUND) \
-    cb(ERF) \
-    cb(ERFINV) \
-    cb(ERFC) \
-    cb(ERFCINV)
-// clang-format on
+#define FOREACH_UNARY_MODE(cb) \
+    cb(RELU) cb(ABS) cb(NEGATE) cb(ACOS) cb(ASIN) cb(CEIL) cb(EXP) cb(FLOOR) cb(LOG) \
+            cb(LOG1P) cb(SIN) cb(COS) cb(TANH) cb(FAST_TANH) cb(H_SWISH) cb(SIGMOID) \
+            cb(EXPM1) cb(ROUND) cb(ERF) cb(ERFINV) cb(ERFC) cb(ERFCINV)
 template <typename tag>
 class TestJITMlirUnaryElemwise : public ::testing::Test {};
@@ -388,34 +368,13 @@ TYPED_TEST(TestJITMlirUnaryElemwise, runGpu) {
 }
 /* ===================== TestJITMlirBinaryElemwise ===================== */
-// clang-format off
-#define FOREACH_BINARY_MODE(cb) \
-    cb(ADD) \
-    cb(FLOOR_DIV) \
-    cb(MUL) \
-    cb(MAX) \
-    cb(MIN) \
-    cb(MOD) \
-    cb(SUB) \
-    cb(TRUE_DIV) \
-    cb(POW) \
-    cb(ABS_GRAD) \
-    cb(SIGMOID_GRAD) \
-    cb(SWITCH_GT0) \
-    cb(TANH_GRAD) \
-    cb(LT) \
-    cb(LEQ) \
-    cb(EQ) \
-    cb(FUSE_ADD_RELU) \
-    cb(LOG_SUM_EXP) \
-    cb(FUSE_ADD_TANH) \
-    cb(FAST_TANH_GRAD) \
-    cb(FUSE_ADD_SIGMOID) \
-    cb(H_SWISH_GRAD) \
-    cb(FUSE_ADD_H_SWISH) \
-    cb(ATAN2)
-// clang-format on
+#define FOREACH_BINARY_MODE(cb) \
+    cb(ADD) cb(FLOOR_DIV) cb(MUL) cb(MAX) cb(MIN) cb(MOD) cb(SUB) cb(TRUE_DIV) cb(POW) \
+            cb(ABS_GRAD) cb(SIGMOID_GRAD) cb(SWITCH_GT0) cb(TANH_GRAD) cb(LT) cb(LEQ) \
+            cb(EQ) cb(FUSE_ADD_RELU) cb(LOG_SUM_EXP) cb(FUSE_ADD_TANH) \
+            cb(FAST_TANH_GRAD) cb(FUSE_ADD_SIGMOID) cb(H_SWISH_GRAD) \
+            cb(FUSE_ADD_H_SWISH) cb(ATAN2)
 template <typename tag>
 class TestJITMlirBinaryElemwise : public ::testing::Test {};
@@ -445,13 +404,8 @@ TYPED_TEST(TestJITMlirBinaryElemwise, runGpu) {
 }
 /* ===================== TestJITMlirTenaryElemwise ===================== */
-// clang-format off
-#define FOREACH_TERNARY_MODE(cb) \
-    cb(COND_LEQ_MOV) \
-    cb(COND_LT_MOV) \
-    cb(FUSE_MUL_ADD3) \
-// clang-format on
+#define FOREACH_TERNARY_MODE(cb) cb(COND_LEQ_MOV) cb(COND_LT_MOV) cb(FUSE_MUL_ADD3)
 template <typename tag>
 class TestJITMlirTernaryElemwise : public ::testing::Test {};
@@ -463,8 +417,8 @@ FOREACH_TERNARY_MODE(def_tag)
 #undef def_tag
 #define t(n) n,
 using mlir_elemwise_ternary_types =
         ::testing::Types<FOREACH_TERNARY_MODE(t) COND_LEQ_MOV>;
 #undef t
 TYPED_TEST_CASE(TestJITMlirTernaryElemwise, mlir_elemwise_ternary_types);
 TYPED_TEST(TestJITMlirTernaryElemwise, run) {
@@ -480,7 +434,6 @@ TYPED_TEST(TestJITMlirTernaryElemwise, runGpu) {
 #undef SKIP_MODE
 /* ===================== TestJITMlirTypeCvt ===================== */
 template <typename itype, typename otype>
@@ -505,35 +458,35 @@ void run_typecvt(CompNode cn) {
     auto y_jit = JITExecutor::make(igraph, ig_gen->orig_inps());
     HostTensorND host_y, host_y_jit;
-    auto func = graph->compile({make_callback_copy(y, host_y),
-                                make_callback_copy(y_jit, host_y_jit)});
+    auto func = graph->compile(
+            {make_callback_copy(y, host_y), make_callback_copy(y_jit, host_y_jit)});
     func->execute();
     MGB_ASSERT_TENSOR_EQ(host_y, host_y_jit);
 };
 #define add_typecvt_gtest(itype, otype) \
     TEST(TestJITMlirTypeCvt, itype##_to_##otype) { \
         run_typecvt<dtype::itype, dtype::otype>(CompNode::load("cpu0")); \
     } \
     TEST(TestJITMlirTypeCvt, itype##_to_##otype##_GPU) { \
         REQUIRE_GPU(1); \
         run_typecvt<dtype::itype, dtype::otype>(CompNode::load("gpu0")); \
     }
 #if !MEGDNN_DISABLE_FLOAT16
 // TODO: the support for f16 and bf16 is currently not complete in mlir
 // FPExtOp
 // add_typecvt_gtest(Float16, Float32);
 // add_typecvt_gtest(BFloat16, Float32);
 // add_typecvt_gtest(Float16, BFloat16);
 // FPTruncOp
 // add_typecvt_gtest(Float32, Float16);
 // add_typecvt_gtest(Float32, BFloat16);
 // add_typecvt_gtest(Float16, BFloat16);
 #endif
@@ -557,8 +510,7 @@ add_typecvt_gtest(Uint8, Float32);
 /* ===================== TestJITMlirDimshuffle ===================== */
-void run_dimshuffle(CompNode cn, TensorShape ishape,
-                    const std::vector<int>& pattern) {
+void run_dimshuffle(CompNode cn, TensorShape ishape, const std::vector<int>& pattern) {
     set_backend(Backend::MLIR);
     auto graph = ComputingGraph::make();
     HostTensorGenerator<> gen;
@@ -579,8 +531,8 @@ void run_dimshuffle(CompNode cn, TensorShape ishape,
     auto y_jit = JITExecutor::make(igraph, ig_gen->orig_inps());
     HostTensorND host_y, host_y_jit;
-    auto func = graph->compile({make_callback_copy(y, host_y),
-                                make_callback_copy(y_jit, host_y_jit)});
+    auto func = graph->compile(
+            {make_callback_copy(y, host_y), make_callback_copy(y_jit, host_y_jit)});
     func->execute();
     MGB_ASSERT_TENSOR_EQ(host_y, host_y_jit);
......
@@ -462,7 +462,7 @@ void run<all_oprs>(Backend backend, CompNode cn) {
     CHECK_ELEM2(FUSE_ADD_TANH, true, none);
     CHECK_ELEM2(FUSE_ADD_H_SWISH, true, none);
-    ASSERT_EQ(ast_c::elem_opr_generator().size(), tasks.size());
+    ASSERT_EQ(ast_c::elem_opr_generator(cn.device_type()).size(), tasks.size());
     auto type_cvt_test = [&](const char* name, DType src_dtype, DType dst_dtype) {
         tasks.emplace_back(name, [cn, src_dtype, dst_dtype]() {
@@ -496,7 +496,7 @@ void run<all_oprs>(Backend backend, CompNode cn) {
             }
             if (!::testing::Test::HasFailure()) {
                 mgb_log("going to run %s on worker %d", tasks[id].first, wid);
-                ASSERT_NO_THROW(tasks[id].second()) << "failed for " << tasks[id].first;
+                ASSERT_NO_THROW(tasks[id].second());
             }
         }
     };
@@ -1449,7 +1449,9 @@ TEST(TestJITNvrtc, JITConfig) {
         x = opr::Reduce::make(x + 2, {ReduceMode::SUM, 2});  // Reduce
         auto func = cg->compile({make_callback_copy(x + 1, *host_x)});
-        auto comp_seq = dynamic_cast<CompSeq*>(func.get());
+        //! cg->compile always returns a CompSeq* cast to AsyncExecutable*, so the
+        //! static_cast is safe; Android bazel builds pass -copt flags that disable RTTI
+        auto comp_seq = static_cast<CompSeq*>(func.get());
         ASSERT_TRUE(comp_seq != nullptr);
         bool dimshuffle_found = false, reduce_found = false, jit_executor_found = false;
......
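The dynamic_cast-to-static_cast switch above matters because builds without RTTI reject dynamic_cast at compile time. A minimal standalone illustration of the pattern follows; the types here are stand-ins, not MegBrain's actual classes:

#include <cassert>
#include <memory>

struct AsyncExecutable {
    virtual ~AsyncExecutable() = default;
};
struct CompSeq : AsyncExecutable {
    int nr_step = 3;
};

// Stand-in for cg->compile(): the concrete type is always CompSeq, so the
// caller may downcast without a runtime type check.
std::unique_ptr<AsyncExecutable> compile() {
    return std::make_unique<CompSeq>();
}

int main() {
    auto func = compile();
    // static_cast works even under -fno-rtti; dynamic_cast here would fail to
    // compile with RTTI disabled (and needs RTTI at runtime otherwise).
    auto* comp_seq = static_cast<CompSeq*>(func.get());
    assert(comp_seq->nr_step == 3);
    return 0;
}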
@@ -27,6 +27,9 @@ void jit::set_backend(Backend backend) {
         case Backend::MLIR:
             setenv("MGB_JIT_BACKEND", "MLIR", 1);
             return;
+        case Backend::TINYOPENCL:
+            setenv("MGB_JIT_BACKEND", "TINYOPENCL", 1);
+            return;
         default:
             mgb_assert(0);
     }
......
@@ -4,7 +4,7 @@
 namespace mgb {
 namespace jit {
-enum class Backend { NONE, HALIDE, NVRTC, MLIR };
+enum class Backend { NONE, HALIDE, NVRTC, MLIR, TINYOPENCL };
 void set_backend(Backend backend);
......
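The new enum value ties into the set_backend dispatch shown earlier: a test selects the OpenCL jit by setting MGB_JIT_BACKEND to TINYOPENCL. Below is a standalone sketch of that dispatch; only the two cases visible in this patch are reproduced, and the real function also handles NONE, HALIDE and NVRTC:

#include <cassert>
#include <cstdio>
#include <cstdlib>

enum class Backend { NONE, HALIDE, NVRTC, MLIR, TINYOPENCL };

// Mirrors the switch in jit::set_backend above: each backend maps to a value
// of the MGB_JIT_BACKEND environment variable read by the jit subsystem.
void set_backend(Backend backend) {
    switch (backend) {
        case Backend::MLIR:
            setenv("MGB_JIT_BACKEND", "MLIR", 1);
            return;
        case Backend::TINYOPENCL:
            setenv("MGB_JIT_BACKEND", "TINYOPENCL", 1);
            return;
        default:
            assert(0);  // other backends omitted in this sketch
    }
}

int main() {
    set_backend(Backend::TINYOPENCL);
    printf("MGB_JIT_BACKEND=%s\n", getenv("MGB_JIT_BACKEND"));  // -> TINYOPENCL
    return 0;
}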