提交 c3f8cf04 编写于 作者: M Megvii Engine Team

feat(dnn): add conv_bwd_data and conv_bwd_filter accuracy shake check

GitOrigin-RevId: 4069e083d2218b8a5ce2ea77c3c7d5f81acc6149
上级 717b88e6
......@@ -236,7 +236,13 @@ public:
TensorLayout& grad_pg);
MEGDNN_DECL_ALGO_TYPE(CUDA_GROUP_CONV_GENERAL)
AlgoAttribute attribute() const override {
auto ret = static_cast<AlgoAttribute>(0);
auto ret = AlgoAttribute::DEFAULT;
#define cb(attr) \
if (m_impl->contain_attribute_all(attr)) { \
ret |= attr; \
}
MEGDNN_FOREACH_ALGO_ATTRIBUTE_INHERITABLE(cb)
#undef cb
if (m_impl->contain_attribute_all(AlgoAttribute::REPRODUCIBLE)) {
ret |= AlgoAttribute::REPRODUCIBLE;
}
......
......@@ -169,6 +169,7 @@ public:
opr, layouts)) {
if (!(algo_info.attribute &
AlgoAttribute::ACCURACY_DEPEND_ON_BATCH) &&
(algo_info.attribute & AlgoAttribute::REPRODUCIBLE) &&
std::regex_match(
algo_info.desc.name,
std::regex("(.*)(" + m_policy_name.name + ")(.*)"))) {
......
......@@ -241,6 +241,41 @@ TEST_F(CUDA, SHAKE_LOCAL_SHARE) {
checker.exec({{20, 16, 32, 32}, {3, 3, 16, 3, 3, 64}, {}});
}
TEST_F(CUDA, SHAKE_CONVOLUTION_BACKWARD_DATA) {
AccuracyShakeChecker<ConvolutionBackwardData> checker(handle_cuda());
NormalRNG default_rng;
checker.set_dtype(0, dtype::Float32())
.set_dtype(1, dtype::Float32())
.set_rng(0, &default_rng)
.set_rng(1, &default_rng);
// ConvolutionBackwardData
checker.exec({{8, 16, 3, 3}, {64, 8, 5, 5}, {64, 16, 7, 7}});
// group
ConvolutionBackwardData::Param param;
param.sparse = Convolution::Param::Sparse::GROUP;
checker.set_param(param);
checker.exec({{2, 16, 32, 3, 3}, {2, 32, 5, 5}, {2, 64, 7, 7}});
checker.exec({{2, 8, 32, 3, 3}, {64, 16, 19, 19}, {64, 64, 21, 21}});
}
TEST_F(CUDA, SHAKE_CONVOLUTION_BACKWARD_FILTER) {
AccuracyShakeChecker<ConvolutionBackwardFilter> checker(handle_cuda());
NormalRNG default_rng;
checker.set_dtype(0, dtype::Float32())
.set_dtype(1, dtype::Float32())
.set_rng(0, &default_rng)
.set_rng(1, &default_rng);
// ConvolutionBackwardFilter
checker.exec({{2, 64, 7, 7}, {2, 32, 5, 5}, {32, 64, 3, 3}});
// group
ConvolutionBackwardFilter::Param param;
param.sparse = Convolution::Param::Sparse::GROUP;
checker.set_param(param);
checker.exec({{2, 64, 7, 7}, {2, 32, 5, 5}, {2, 16, 32, 3, 3}});
}
} // namespace test
} // namespace megdnn
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册