提交 3abe0b24 编写于 作者: M Megvii Engine Team

fix(mgb): fix rocm pooling

GitOrigin-RevId: 44876d398ed56214d71e117ada0b21085a2c05de
上级 f9722af3
......@@ -60,10 +60,10 @@ void PoolingForwardImpl::AlgoMIOpen::init_mode(
case param::Pooling::Mode::MAX:
mode = miopenPoolingMax;
break;
case param::Pooling::Mode::AVERAGE:
case param::Pooling::Mode::AVERAGE_COUNT_EXCLUDE_PADDING:
mode = miopenPoolingAverage;
break;
case param::Pooling::Mode::AVERAGE_COUNT_EXCLUDE_PADDING:
case param::Pooling::Mode::AVERAGE:
mode = miopenPoolingAverageInclusive;
break;
default:
......@@ -96,7 +96,7 @@ void PoolingForwardImpl::AlgoMIOpen::exec(const ExecArgs& args) const {
miopen_check(miopenPoolingForward(
handle, miopen_desc, &alpha, src_desc.desc,
args.src_tensor->raw_ptr, &beta, dst_desc.desc,
args.src_tensor->raw_ptr, false, nullptr, 0_z));
args.dst_tensor->raw_ptr, false, nullptr, 0_z));
miopen_check(miopenDestroyPoolingDescriptor(miopen_desc));
}
......@@ -163,10 +163,10 @@ void PoolingBackwardImpl::AlgoMIOpen::init_mode(const ExecArgs& args,
case param::Pooling::Mode::MAX:
mode = miopenPoolingMax;
break;
case param::Pooling::Mode::AVERAGE:
case param::Pooling::Mode::AVERAGE_COUNT_EXCLUDE_PADDING:
mode = miopenPoolingAverage;
break;
case param::Pooling::Mode::AVERAGE_COUNT_EXCLUDE_PADDING:
case param::Pooling::Mode::AVERAGE:
mode = miopenPoolingAverageInclusive;
break;
default:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册