提交 81c84737 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!5936 fix CPU BatchNorm infer error

Merge pull request !5936 from zhaoting/mobilenet
...@@ -50,11 +50,13 @@ void FusedBatchNormCPUKernel::InitKernel(const CNodePtr &kernel_node) { ...@@ -50,11 +50,13 @@ void FusedBatchNormCPUKernel::InitKernel(const CNodePtr &kernel_node) {
dnnl::memory::desc scale_bias_desc = GetDefaultMemDesc({2, channel}); dnnl::memory::desc scale_bias_desc = GetDefaultMemDesc({2, channel});
auto epsilon = AnfAlgo::GetNodeAttr<float>(kernel_node, "epsilon"); auto epsilon = AnfAlgo::GetNodeAttr<float>(kernel_node, "epsilon");
auto prop_kind = dnnl::prop_kind::forward_inference; auto prop_kind = dnnl::prop_kind::forward_inference;
auto normalization_flags = dnnl::normalization_flags::use_scale_shift | dnnl::normalization_flags::use_global_stats;
if (is_train) { if (is_train) {
prop_kind = dnnl::prop_kind::forward_training; prop_kind = dnnl::prop_kind::forward_training;
normalization_flags = dnnl::normalization_flags::use_scale_shift;
} }
dnnl::batch_normalization_forward::desc desc = dnnl::batch_normalization_forward::desc desc =
dnnl::batch_normalization_forward::desc(prop_kind, x_desc, epsilon, dnnl::normalization_flags::use_scale_shift); dnnl::batch_normalization_forward::desc(prop_kind, x_desc, epsilon, normalization_flags);
auto prim_desc = dnnl::batch_normalization_forward::primitive_desc(desc, MKLKernelEngine::Get().engine()); auto prim_desc = dnnl::batch_normalization_forward::primitive_desc(desc, MKLKernelEngine::Get().engine());
primitive_ = std::make_shared<dnnl::batch_normalization_forward>(prim_desc); primitive_ = std::make_shared<dnnl::batch_normalization_forward>(prim_desc);
AddArgument(DNNL_ARG_SRC, x_desc); AddArgument(DNNL_ARG_SRC, x_desc);
...@@ -74,14 +76,14 @@ bool FusedBatchNormCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inpu ...@@ -74,14 +76,14 @@ bool FusedBatchNormCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inpu
auto wksp = reinterpret_cast<float *>(workspace[0]->addr); auto wksp = reinterpret_cast<float *>(workspace[0]->addr);
memcpy_s(wksp, workspace[0]->size, inputs[1]->addr, inputs[1]->size); memcpy_s(wksp, workspace[0]->size, inputs[1]->addr, inputs[1]->size);
memcpy_s(wksp + (inputs[1]->size / sizeof(float)), inputs[2]->size, inputs[2]->addr, inputs[2]->size); memcpy_s(wksp + (inputs[1]->size / sizeof(float)), inputs[2]->size, inputs[2]->addr, inputs[2]->size);
SetArgumentHandle(DNNL_ARG_SRC, inputs[0]->addr);
SetArgumentHandle(DNNL_ARG_MEAN, outputs[3]->addr);
SetArgumentHandle(DNNL_ARG_VARIANCE, outputs[4]->addr);
SetArgumentHandle(DNNL_ARG_SCALE_SHIFT, workspace[0]->addr);
SetArgumentHandle(DNNL_ARG_DST, outputs[0]->addr);
ExecutePrimitive();
if (is_train) { if (is_train) {
SetArgumentHandle(DNNL_ARG_SRC, inputs[0]->addr);
SetArgumentHandle(DNNL_ARG_MEAN, outputs[3]->addr);
SetArgumentHandle(DNNL_ARG_VARIANCE, outputs[4]->addr);
SetArgumentHandle(DNNL_ARG_SCALE_SHIFT, workspace[0]->addr);
SetArgumentHandle(DNNL_ARG_DST, outputs[0]->addr);
ExecutePrimitive();
auto moving_mean = reinterpret_cast<float *>(inputs[3]->addr); auto moving_mean = reinterpret_cast<float *>(inputs[3]->addr);
auto moving_variance = reinterpret_cast<float *>(inputs[4]->addr); auto moving_variance = reinterpret_cast<float *>(inputs[4]->addr);
auto mean = reinterpret_cast<float *>(outputs[3]->addr); auto mean = reinterpret_cast<float *>(outputs[3]->addr);
...@@ -90,6 +92,13 @@ bool FusedBatchNormCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inpu ...@@ -90,6 +92,13 @@ bool FusedBatchNormCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inpu
moving_mean[i] = moving_mean[i] * (1 - momentum) + mean[i] * momentum; moving_mean[i] = moving_mean[i] * (1 - momentum) + mean[i] * momentum;
moving_variance[i] = moving_variance[i] * (1 - momentum) + variance[i] * momentum; moving_variance[i] = moving_variance[i] * (1 - momentum) + variance[i] * momentum;
} }
} else {
SetArgumentHandle(DNNL_ARG_SRC, inputs[0]->addr);
SetArgumentHandle(DNNL_ARG_MEAN, inputs[3]->addr);
SetArgumentHandle(DNNL_ARG_VARIANCE, inputs[4]->addr);
SetArgumentHandle(DNNL_ARG_SCALE_SHIFT, workspace[0]->addr);
SetArgumentHandle(DNNL_ARG_DST, outputs[0]->addr);
ExecutePrimitive();
} }
return true; return true;
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册