提交 8abc3ab8 编写于 作者: M Megvii Engine Team

fix(imperative): fix convolution in rocm

GitOrigin-RevId: 9e97099fd5ccccf13dbdda393efd5cd004dd1be4
上级 3b1101b5
...@@ -214,6 +214,12 @@ public: ...@@ -214,6 +214,12 @@ public:
_megdnn_tensor_in src, _megdnn_tensor_in filter, _megdnn_tensor_out dst, _megdnn_tensor_in src, _megdnn_tensor_in filter, _megdnn_tensor_out dst,
const PreprocessedFilter* preprocessed_filter, const PreprocessedFilter* preprocessed_filter,
_megdnn_workspace workspace) = 0; _megdnn_workspace workspace) = 0;
MGE_WIN_DECLSPEC_FUC void exec(
_megdnn_tensor_in src, _megdnn_tensor_in filter, _megdnn_tensor_out dst,
_megdnn_workspace workspace) {
exec(src, filter, dst, nullptr, workspace);
}
/** /**
* \brief execute weight preprocessing, read weights form filter and write * \brief execute weight preprocessing, read weights form filter and write
* to preprocessed_filter after preprocessed. * to preprocessed_filter after preprocessed.
......
...@@ -57,6 +57,28 @@ SmallVector<TensorPtr> apply_on_physical_tensor( ...@@ -57,6 +57,28 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
// create megdnn opr // create megdnn opr
auto&& conv = def.cast_final_safe<Convolution>(); auto&& conv = def.cast_final_safe<Convolution>();
CompNode cn = inputs[0]->comp_node(); CompNode cn = inputs[0]->comp_node();
// calling dnn ConvolutionForward when device is rocm
// because there is no dnn ConvBiasForward on rocm
if (cn.device_type() == CompNode::DeviceType::ROCM) {
DnnOprCaller<megdnn::ConvolutionForward> dnn_opr(
cn, conv.param(), conv.policy());
auto out_layout = [&] {
if (validated) {
return output_descs[0].layout;
} else {
return dnn_opr.deduce_layout(inputs[0]->layout(), inputs[1]->layout());
}
}();
// alloc memory
auto out = Tensor::make(out_layout, cn);
dnn_opr.exec_fastrun(inputs[0], inputs[1], out);
return {out};
}
// calling dnn ConvBiasForward on cuda because it's faster then ConvolutionForward
// ConvolutionForward internally uses ConvBiasForward to calculate the result
auto&& param = conv_bias_param_from_convolution(conv); auto&& param = conv_bias_param_from_convolution(conv);
DnnOprCaller<megdnn::ConvBiasForward> dnn_opr(cn, param, conv.policy()); DnnOprCaller<megdnn::ConvBiasForward> dnn_opr(cn, param, conv.policy());
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册