Unverified commit 050a9bf7, authored by Jacek Czaja, committed by GitHub

[oneDNN] LRN cleanup (#25416)

Parent 1974aadc
@@ -59,15 +59,6 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
                                {MKLDNN_ARG_DST, *dst_memory},
                                {MKLDNN_ARG_WORKSPACE, *workspace_memory}});
     } else {
-      // mid has to be allocated and filled
-      // k to pass LRN unit tests
-      // TODO(jczaja): Disable checking mid in unit tests (Require API change)
-      mid->mutable_data<T>(ctx.GetPlace());
-      auto e_mid = framework::EigenTensor<T, 4>::From(*mid);
-      const float k = ctx.Attr<float>("k");
-      e_mid = e_mid.constant(k);
-      mid->set_format(platform::GetMKLDNNFormat(*dst_memory));
       lrn_p->execute(astream, {{MKLDNN_ARG_SRC, *src_memory},
                                {MKLDNN_ARG_DST, *dst_memory}});
     }
@@ -85,7 +76,7 @@ class LRNMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
     const bool is_float_type = std::is_same<T, float>::value;
     PADDLE_ENFORCE_EQ(is_float_type, true,
                       platform::errors::PreconditionNotMet(
-                          "DNNL LRN GradOpKernl must use float data."));
+                          "DNNL LRN GradOpKernel must use float data."));
     PADDLE_ENFORCE_EQ(platform::is_cpu_place(ctx.GetPlace()), true,
                       paddle::platform::errors::PreconditionNotMet(
                           "Operator DNNL LRNGrad must use CPUPlace"));
......
@@ -26,8 +26,10 @@ class TestLRNMKLDNNOp(TestLRNOp):
         return attrs

     def test_check_output(self):
+        # We cannot validate MidOut as LRN REF has a different meaning in it
         # TODO(wangzhongpu): support mkldnn op in dygraph mode
-        self.check_output(atol=0.002, check_dygraph=False)
+        self.check_output(
+            atol=0.002, no_check_set=['MidOut'], check_dygraph=False)

     def test_check_grad_normal(self):
         # TODO(wangzhongpu): support mkldnn op in dygraph mode
......
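For context, no_check_set tells Paddle's OpTest harness which output variables to leave out of the numerical comparison; here the LRN test skips 'MidOut' while still checking the other outputs with atol=0.002. The snippet below is only a minimal sketch of that idea under assumed names (compare_outputs, actual, expected are illustrative, not the real op_test.py API):

import numpy as np

# Illustrative sketch: how an OpTest-style harness could skip outputs listed
# in no_check_set instead of comparing them against the reference values.
def compare_outputs(actual, expected, no_check_set=None, atol=1e-5):
    no_check_set = set(no_check_set or [])
    for name, expected_value in expected.items():
        if name in no_check_set:
            # e.g. 'MidOut' for oneDNN LRN: the reference kernel stores a
            # different intermediate here, so it is not compared at all.
            continue
        np.testing.assert_allclose(actual[name], expected_value, atol=atol)

# Hypothetical usage mirroring the test change above:
# compare_outputs(mkldnn_outs, ref_outs, no_check_set=['MidOut'], atol=0.002)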
@@ -17,6 +17,7 @@ no_check_set_white_list = [
     'fake_quantize_range_abs_max',
     'coalesce_tensor',
     'flatten2',
+    'lrn',
     'squeeze2',
     'reshape2',
     'transpose2',
......
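This last hunk is presumably needed because the test harness only honors a non-empty no_check_set for operators that appear on this whitelist, so 'lrn' has to be added before the test above may skip 'MidOut'. A minimal sketch of that kind of gate, assuming a validation step along these lines (validate_no_check_set and its error message are hypothetical, not the exact op_test.py code):

# Subset of the whitelist shown in the diff above.
no_check_set_white_list = [
    'fake_quantize_range_abs_max',
    'coalesce_tensor',
    'flatten2',
    'lrn',
    'squeeze2',
    'reshape2',
    'transpose2',
]

def validate_no_check_set(op_type, no_check_set):
    # Reject no_check_set for operators that are not explicitly whitelisted.
    if no_check_set and op_type not in no_check_set_white_list:
        raise AssertionError(
            "no_check_set is not allowed for op %s." % op_type)

validate_no_check_set('lrn', ['MidOut'])  # passes once 'lrn' is whitelisted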