未验证 提交 68e27f35 编写于 作者: C Chen Weihang 提交者: GitHub

fix gcc54 compile failed (#47172)

上级 dc64db15
...@@ -835,9 +835,11 @@ class ConvMKLDNNOpKernel : public framework::OpKernel<T> { ...@@ -835,9 +835,11 @@ class ConvMKLDNNOpKernel : public framework::OpKernel<T> {
ctx.template device_context<platform::MKLDNNDeviceContext>(); ctx.template device_context<platform::MKLDNNDeviceContext>();
const auto& mkldnn_engine = dev_ctx.GetEngine(); const auto& mkldnn_engine = dev_ctx.GetEngine();
const bool is_test = ctx.Attr<bool>("is_test"); bool is_test = ctx.Attr<bool>("is_test");
const bool is_conv3d = ctx.Attr<std::vector<int>>("strides").size() == 3U; const auto& strides = ctx.Attr<std::vector<int>>("strides");
const bool fuse_residual_conn = ctx.Attr<bool>("fuse_residual_connection"); bool is_conv3d = strides.size() == 3UL;
bool fuse_residual_conn = ctx.Attr<bool>("fuse_residual_connection");
int groups = ctx.Attr<int>("groups");
const auto* input = ctx.Input<phi::DenseTensor>("Input"); const auto* input = ctx.Input<phi::DenseTensor>("Input");
const auto* filter = ctx.Input<phi::DenseTensor>("Filter"); const auto* filter = ctx.Input<phi::DenseTensor>("Filter");
...@@ -861,7 +863,7 @@ class ConvMKLDNNOpKernel : public framework::OpKernel<T> { ...@@ -861,7 +863,7 @@ class ConvMKLDNNOpKernel : public framework::OpKernel<T> {
auto src_memory_p = handler.AcquireSrcMemoryWithReorder(input); auto src_memory_p = handler.AcquireSrcMemoryWithReorder(input);
auto weights_memory_p = handler.AcquireWeightsMemoryWithReorder( auto weights_memory_p = handler.AcquireWeightsMemoryWithReorder(
filter, ctx.Attr<int>("groups"), is_conv3d, is_test); filter, groups, is_conv3d, is_test);
std::shared_ptr<dnnl::memory> dst_memory_p; std::shared_ptr<dnnl::memory> dst_memory_p;
if (fuse_residual_conn) { if (fuse_residual_conn) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册