diff --git a/dnn/src/cuda/convolution/opr_impl.cpp b/dnn/src/cuda/convolution/opr_impl.cpp
index b3dd29b49bdb5dcb604505e3c2c03839f692f87d..c3fbf76b8e00d26d17962593fc896e6741ad64d0 100644
--- a/dnn/src/cuda/convolution/opr_impl.cpp
+++ b/dnn/src/cuda/convolution/opr_impl.cpp
@@ -142,14 +142,16 @@ ConvolutionBackwardDataImpl::get_algorithm_heuristic(
     for (int i = 0; i < ret_count; ++i) {
         if (algo_perf[i].memory > workspace_limit_in_bytes)
             continue;
-        if ((positive_attr & AlgoAttribute::REPRODUCIBLE)) {
-            if (algo_perf[i].determinism == CUDNN_DETERMINISTIC) {
-                return reinterpret_cast<AlgoBase*>(
+        if ((positive_attr & AlgoAttribute::REPRODUCIBLE) &&
+            (algo_perf[i].determinism != CUDNN_DETERMINISTIC)) {
+            continue;
+        }
+        AlgoBase* conv_bd_data_algo = reinterpret_cast<AlgoBase*>(
                 sm_algo_pack.cudnn_from_enum(algo_perf[i].algo));
-            }
-        } else {
-            return reinterpret_cast<AlgoBase*>(
-                    sm_algo_pack.cudnn_from_enum(algo_perf[i].algo));
+        if (conv_bd_data_algo->is_available_attribute(
+                    args, positive_attr, negative_attr,
+                    workspace_limit_in_bytes)) {
+            return conv_bd_data_algo;
         }
     }
     return nullptr;
@@ -269,14 +271,16 @@ ConvolutionBackwardFilterImpl::get_algorithm_heuristic(
     for (int i = 0; i < ret_count; ++i) {
         if (algo_perf[i].memory > workspace_limit_in_bytes)
             continue;
-        if ((positive_attr & AlgoAttribute::REPRODUCIBLE)) {
-            if (algo_perf[i].determinism == CUDNN_DETERMINISTIC) {
-                return reinterpret_cast<AlgoBase*>(
-                        sm_algo_pack.cudnn_from_enum(algo_perf[i].algo));
-            }
-        } else {
-            return reinterpret_cast<AlgoBase*>(
-                    sm_algo_pack.cudnn_from_enum(algo_perf[i].algo));
+        if ((positive_attr & AlgoAttribute::REPRODUCIBLE) &&
+            (algo_perf[i].determinism != CUDNN_DETERMINISTIC)) {
+            continue;
+        }
+        AlgoBase* conv_bd_filter_algo = reinterpret_cast<AlgoBase*>(
+                sm_algo_pack.cudnn_from_enum(algo_perf[i].algo));
+        if (conv_bd_filter_algo->is_available_attribute(
+                    args, positive_attr, negative_attr,
+                    workspace_limit_in_bytes)) {
+            return conv_bd_filter_algo;
         }
     }
     return nullptr;
diff --git a/src/opr/test/dnn/convolution.cpp b/src/opr/test/dnn/convolution.cpp
index 734528e3558f1dbb5b32f066d58eb6140e6f995f..3fd50d80fe403537e26038bdd8aec2c246b9c820 100644
--- a/src/opr/test/dnn/convolution.cpp
+++ b/src/opr/test/dnn/convolution.cpp
@@ -582,6 +582,90 @@ TEST(TestOprDNN, ConvolutionBackwardDataBfloat16ExePolicy) {
     }
 }
 
+#if MGB_ENABLE_FASTRUN
+TEST(TestOprDNN, ConvolutionBackwardDataFloat16ExePolicy) {
+    REQUIRE_GPU(1);
+    Param param{Mode::CROSS_CORRELATION, 1, 1, 1, 1};
+    param.compute_mode = Param::ComputeMode::FLOAT32;
+    using Policy = opr::Convolution::ExecutionPolicy;
+    using S = Policy::Strategy;
+
+    auto gen_fp16 = [](HostTensorND& dest) {
+        RNGxorshf rng{next_rand_seed()};
+        auto rand_real = [&rng]() {
+            std::uniform_real_distribution<float> dist(-1, 1);
+            return dist(rng);
+        };
+        auto ptr = dest.ptr<dt_float16>();
+        size_t elems = dest.shape().total_nr_elems();
+        for (size_t i = 0; i < elems; i++) {
+            ptr[i] = dt_float16(rand_real());
+        }
+    };
+
+    auto f32_to_f16 = [](const std::shared_ptr<HostTensorND>& src)
+            -> std::shared_ptr<HostTensorND> {
+        auto ret = std::make_shared<HostTensorND>(
+                src->comp_node(), src->shape(), dtype::Float16{});
+        for (size_t i = 0; i < src->layout().total_nr_elems(); i++) {
+            ret->ptr<dt_float16>()[i] = src->ptr<dt_float32>()[i];
+        }
+        return ret;
+    };
+
+    auto f16_to_f32 = [](const std::shared_ptr<HostTensorND>& src)
+            -> std::shared_ptr<HostTensorND> {
+        auto ret = std::make_shared<HostTensorND>(
+                src->comp_node(), src->shape(), dtype::Float32{});
+        for (size_t i = 0; i < src->layout().total_nr_elems(); i++) {
+            ret->ptr<dt_float32>()[i] = src->ptr<dt_float16>()[i];
+        }
+        return ret;
+    };
+
+    int nr_get = 0;
+    auto on_get = [&nr_get](const std::string&, const void*, size_t,
+                            const void*, size_t) { ++nr_get; };
+    PersistentCacheHook cache_hook{on_get};
+
+    auto strategy = S(S::PROFILE | S::REPRODUCIBLE);
+    using Checker = AutoOprChecker<2, 1>;
+
+    auto make_graph =
+            [&](const Checker::SymInpArray& inputs) -> Checker::SymOutArray {
+        Policy policy;
+        policy.strategy = strategy;
+        return {opr::ConvolutionBackwardData::make_deconv(inputs[0], inputs[1],
+                                                          param, policy)};
+    };
+
+    auto fwd = [&](Checker::NumOutArray& dest, Checker::NumInpArray inp) {
+        std::shared_ptr<HostTensorND> out;
+        conv_bwd_data_brute({f16_to_f32(inp[0]), f16_to_f32(inp[1])}, out,
+                            param);
+        dest[0] = *f32_to_f16(out);
+    };
+
+    Checker::RunOptions opt;
+    opt.outputs_max_err = 1e-2;
+    nr_get = 0;
+    Checker(make_graph, fwd)
+            .disable_grad_check()
+            .set_input_dtype(0, dtype::Float16{})
+            .set_input_dtype(1, dtype::Float16{})
+            .set_input_generator(0, gen_fp16)
+            .set_input_generator(1, gen_fp16)
+            .run({TensorShape{3, 4, 10, 6}, {4, 2, 3, 3}}, opt)
+            .run({TensorShape{2, 2, 4, 3}, {2, 2, 3, 3}}, opt)
+            .run({TensorShape{1, 3, 10, 6}, {3, 2, 3, 3}}, opt);
+    if (strategy == S::HEURISTIC) {
+        ASSERT_EQ(0, nr_get);
+    } else {
+        ASSERT_LT(0, nr_get);
+    }
+}
+#endif
+
 TEST(TestOprDNN, Deconvolution) {
     // dilated grouped deconv
     using Checker = AutoOprChecker<2, 1>;