提交 2c4ff543 编写于 作者: M Megvii Engine Team

fix(mgb): fix cudnn ConvolutionBackwardData

GitOrigin-RevId: 1fffc06eaa6fe66435715ec5a93c86dd37de985e
上级 7138e4fd
...@@ -142,14 +142,16 @@ ConvolutionBackwardDataImpl::get_algorithm_heuristic( ...@@ -142,14 +142,16 @@ ConvolutionBackwardDataImpl::get_algorithm_heuristic(
for (int i = 0; i < ret_count; ++i) { for (int i = 0; i < ret_count; ++i) {
if (algo_perf[i].memory > workspace_limit_in_bytes) if (algo_perf[i].memory > workspace_limit_in_bytes)
continue; continue;
if ((positive_attr & AlgoAttribute::REPRODUCIBLE)) { if ((positive_attr & AlgoAttribute::REPRODUCIBLE) &&
if (algo_perf[i].determinism == CUDNN_DETERMINISTIC) { (algo_perf[i].determinism != CUDNN_DETERMINISTIC)) {
return reinterpret_cast<AlgoBase*>( continue;
sm_algo_pack.cudnn_from_enum(algo_perf[i].algo));
} }
} else { AlgoBase* conv_bd_data_algo = reinterpret_cast<AlgoBase*>(
return reinterpret_cast<AlgoBase*>(
sm_algo_pack.cudnn_from_enum(algo_perf[i].algo)); sm_algo_pack.cudnn_from_enum(algo_perf[i].algo));
if (conv_bd_data_algo->is_available_attribute(
args, positive_attr, negative_attr,
workspace_limit_in_bytes)) {
return conv_bd_data_algo;
} }
} }
return nullptr; return nullptr;
...@@ -269,14 +271,16 @@ ConvolutionBackwardFilterImpl::get_algorithm_heuristic( ...@@ -269,14 +271,16 @@ ConvolutionBackwardFilterImpl::get_algorithm_heuristic(
for (int i = 0; i < ret_count; ++i) { for (int i = 0; i < ret_count; ++i) {
if (algo_perf[i].memory > workspace_limit_in_bytes) if (algo_perf[i].memory > workspace_limit_in_bytes)
continue; continue;
if ((positive_attr & AlgoAttribute::REPRODUCIBLE)) { if ((positive_attr & AlgoAttribute::REPRODUCIBLE) &&
if (algo_perf[i].determinism == CUDNN_DETERMINISTIC) { (algo_perf[i].determinism != CUDNN_DETERMINISTIC)) {
return reinterpret_cast<AlgoBase*>( continue;
sm_algo_pack.cudnn_from_enum(algo_perf[i].algo));
} }
} else { AlgoBase* conv_bd_filter_algo = reinterpret_cast<AlgoBase*>(
return reinterpret_cast<AlgoBase*>(
sm_algo_pack.cudnn_from_enum(algo_perf[i].algo)); sm_algo_pack.cudnn_from_enum(algo_perf[i].algo));
if (conv_bd_filter_algo->is_available_attribute(
args, positive_attr, negative_attr,
workspace_limit_in_bytes)) {
return conv_bd_filter_algo;
} }
} }
return nullptr; return nullptr;
......
...@@ -582,6 +582,90 @@ TEST(TestOprDNN, ConvolutionBackwardDataBfloat16ExePolicy) { ...@@ -582,6 +582,90 @@ TEST(TestOprDNN, ConvolutionBackwardDataBfloat16ExePolicy) {
} }
} }
#if MGB_ENABLE_FASTRUN
TEST(TestOprDNN, ConvolutionBackwardDataFloat16ExePolicy) {
REQUIRE_GPU(1);
Param param{Mode::CROSS_CORRELATION, 1, 1, 1, 1};
param.compute_mode = Param::ComputeMode::FLOAT32;
using Policy = opr::Convolution::ExecutionPolicy;
using S = Policy::Strategy;
auto gen_fp16 = [](HostTensorND& dest) {
RNGxorshf rng{next_rand_seed()};
auto rand_real = [&rng]() {
std::uniform_real_distribution<float> dist(-1, 1);
return dist(rng);
};
auto ptr = dest.ptr<dt_float16>();
size_t elems = dest.shape().total_nr_elems();
for (size_t i = 0; i < elems; i++) {
ptr[i] = dt_float16(rand_real());
}
};
auto f32_to_f16 = [](const std::shared_ptr<HostTensorND>& src)
-> std::shared_ptr<HostTensorND> {
auto ret = std::make_shared<HostTensorND>(
src->comp_node(), src->shape(), dtype::Float16{});
for (size_t i = 0; i < src->layout().total_nr_elems(); i++) {
ret->ptr<dt_float16>()[i] = src->ptr<dt_float32>()[i];
}
return ret;
};
auto f16_to_f32 = [](const std::shared_ptr<HostTensorND>& src)
-> std::shared_ptr<HostTensorND> {
auto ret = std::make_shared<HostTensorND>(
src->comp_node(), src->shape(), dtype::Float32{});
for (size_t i = 0; i < src->layout().total_nr_elems(); i++) {
ret->ptr<dt_float32>()[i] = src->ptr<dt_float16>()[i];
}
return ret;
};
int nr_get = 0;
auto on_get = [&nr_get](const std::string&, const void*, size_t,
const void*, size_t) { ++nr_get; };
PersistentCacheHook cache_hook{on_get};
auto strategy = S(S::PROFILE | S::REPRODUCIBLE);
using Checker = AutoOprChecker<2, 1>;
auto make_graph =
[&](const Checker::SymInpArray& inputs) -> Checker::SymOutArray {
Policy policy;
policy.strategy = strategy;
return {opr::ConvolutionBackwardData::make_deconv(inputs[0], inputs[1],
param, policy)};
};
auto fwd = [&](Checker::NumOutArray& dest, Checker::NumInpArray inp) {
std::shared_ptr<HostTensorND> out;
conv_bwd_data_brute({f16_to_f32(inp[0]), f16_to_f32(inp[1])}, out,
param);
dest[0] = *f32_to_f16(out);
};
Checker::RunOptions opt;
opt.outputs_max_err = 1e-2;
nr_get = 0;
Checker(make_graph, fwd)
.disable_grad_check()
.set_input_dtype(0, dtype::Float16{})
.set_input_dtype(1, dtype::Float16{})
.set_input_generator(0, gen_fp16)
.set_input_generator(1, gen_fp16)
.run({TensorShape{3, 4, 10, 6}, {4, 2, 3, 3}}, opt)
.run({TensorShape{2, 2, 4, 3}, {2, 2, 3, 3}}, opt)
.run({TensorShape{1, 3, 10, 6}, {3, 2, 3, 3}}, opt);
if (strategy == S::HEURISTIC) {
ASSERT_EQ(0, nr_get);
} else {
ASSERT_LT(0, nr_get);
}
}
#endif
TEST(TestOprDNN, Deconvolution) { TEST(TestOprDNN, Deconvolution) {
// dilated grouped deconv // dilated grouped deconv
using Checker = AutoOprChecker<2, 1>; using Checker = AutoOprChecker<2, 1>;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册