提交 56e863b7 编写于 作者: M Megvii Engine Team

fix(dnn/cuda): fix int4 epilogue stg bug

GitOrigin-RevId: e86da9a8a8d3d55931d55b12673c99f940a390d0
上级 cff61a53
......@@ -182,10 +182,10 @@ CheckerHelper::CheckerHelper(Handle *handle, bool check_dispatch):
const char* env_p = std::getenv("MGB_NO_NAIVE_CHECK");
if (env_p) {
int no_naive_flag = atoi(env_p);
no_naive_and_check = no_naive_flag > 0 ? true : false;
m_no_naive_and_check = no_naive_flag > 0 ? true : false;
check_dispatch = false;
} else {
no_naive_and_check = false;
m_no_naive_and_check = false;
}
auto tmp_handle = create_cpu_handle(2, check_dispatch);
m_handle_naive = std::move(tmp_handle);
......@@ -282,7 +282,18 @@ void CheckerHelper::do_exec(const TensorLayoutArray &user_layouts,
m_expect_exec_fail = {};
return;
}
if (no_naive_and_check){
if (m_stable_check) {
auto tensors_bak_host_storage =
alloc_tensors(m_handle_naive.get(), layouts, m_offset);
auto&& tensors_bak_host = *tensors_bak_host_storage;
copy_tensors_from_device(tensors_bak_host, tensors_cur);
for (int i = 0; i < 10; i++) {
exec_opr(tensors_cur);
copy_tensors_from_device(tensors_cur_host, tensors_cur);
check_tensors(tensors_bak_host, tensors_cur_host);
}
}
if (m_no_naive_and_check) {
m_prev_succ = !::testing::Test::HasFailure();
return;
}
......
......@@ -76,7 +76,8 @@ protected:
ExtraOprImpl m_extra_opr_impl;
OutputCanonizer m_output_canonizer;
TensorsConstriant m_tensor_constraint;
bool no_naive_and_check = false;
bool m_no_naive_and_check = false;
bool m_stable_check = false;
/**
* the offset from the start of malloc memory
*
......@@ -230,6 +231,17 @@ public:
return *this;
}
//! stable check: rerun the opr several iterations and require every
//! run to produce exactly the same result as the first iteration
Checker& set_stable_check(bool stable_check) {
    this->m_stable_check = stable_check;
    return *this;
}
//! when enabled, do_exec skips the naive-reference comparison and only
//! records whether the opr execution itself raised a gtest failure
Checker& set_no_naive_check(bool no_naive_and_check) {
m_no_naive_and_check = no_naive_and_check;
return *this;
}
//! load input tensors from file for next run
Checker& load_input_tensors(const char* fpath) {
m_input_tensors_fpath = fpath;
......
......@@ -731,9 +731,10 @@ std::vector<TestArg> get_int8_chwn4_tensorcore_args(size_t kernel_size) {
void check_conv_bias(DType src_dtype, DType filter_dtype, DType bias_dtype,
DType dst_dtype, Handle* handle, const char* algo,
param::ConvBias::Format format,
const std::vector<TestArg>& args, bool fuse_z) {
const std::vector<TestArg>& args, bool fuse_z,
bool stable_test) {
megdnn_assert(src_dtype.enumv() == filter_dtype.enumv());
Checker<ConvBiasForward> checker(handle);
Checker<ConvBiasForward> checker(handle, !stable_test);
if (algo) {
checker.set_before_exec_callback(
ConvBiasAlgoChecker<ConvBiasForward>(algo));
......@@ -823,6 +824,10 @@ void check_conv_bias(DType src_dtype, DType filter_dtype, DType bias_dtype,
.set_rng(1, rng.get())
.set_rng(2, bias_rng.get())
.set_rng(3, rng.get());
if (stable_test) {
checker.set_stable_check(true);
checker.set_no_naive_check(true);
}
if (args.empty()) {
std::vector<TestArg> default_args;
if (format == Format::NCHW4) {
......
......@@ -69,7 +69,7 @@ void check_conv_bias(
DType src_dtype, DType filter_dtype, DType bias_dtype, DType dst_dtype,
Handle* handle, const char* algo = nullptr,
param::ConvBias::Format format = param::ConvBias::Format::NCHW4,
const std::vector<TestArg>& args = {}, bool fuse_z = false);
const std::vector<TestArg>& args = {}, bool fuse_z = false, bool stable_test = false);
#if MEGDNN_WITH_BENCHMARK
std::vector<conv_bias::TestArg> get_winograd_benchmark_args(
......
......@@ -71,7 +71,7 @@ public:
auto rand_real2 = [&](double range) {
return rand_real(-range, range);
};
dt_float32 res;
dt_float32 res = 0;
switch (idx) {
case 0:
rot = rand_real(0, M_PI * 2);
......
......@@ -28,6 +28,13 @@ bool check_compute_capability_eq(int major, int minor) {
cuda_check(cudaGetDeviceProperties(&prop, dev));
return (prop.major == major && prop.minor == minor);
}
//! query the cudaDeviceProp of the currently active CUDA device
//! (each call re-queries the runtime; cuda_check aborts on API error)
const cudaDeviceProp current_cuda_device_prop() {
    int device_id;
    cuda_check(cudaGetDevice(&device_id));
    cudaDeviceProp device_prop;
    cuda_check(cudaGetDeviceProperties(&device_prop, device_id));
    return device_prop;
}
} // namespace test
} // namespace megdnn
......
......@@ -27,6 +27,7 @@ namespace megdnn {
namespace test {
bool check_compute_capability(int major, int minor);
bool check_compute_capability_eq(int major, int minor);
const cudaDeviceProp current_cuda_device_prop();
} // namespace test
} // namespace megdnn
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册