未验证 提交 050a9bf7 编写于 作者: J Jacek Czaja 提交者: GitHub

[oneDNN] LRN cleanup (#25416)

上级 1974aadc
......@@ -59,15 +59,6 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
{MKLDNN_ARG_DST, *dst_memory},
{MKLDNN_ARG_WORKSPACE, *workspace_memory}});
} else {
// mid has to be allocated and filled
// k to pass LRN unit tests
// TODO(jczaja): Disable checking mid in unit tests (Require API change)
mid->mutable_data<T>(ctx.GetPlace());
auto e_mid = framework::EigenTensor<T, 4>::From(*mid);
const float k = ctx.Attr<float>("k");
e_mid = e_mid.constant(k);
mid->set_format(platform::GetMKLDNNFormat(*dst_memory));
lrn_p->execute(astream, {{MKLDNN_ARG_SRC, *src_memory},
{MKLDNN_ARG_DST, *dst_memory}});
}
......@@ -85,7 +76,7 @@ class LRNMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
const bool is_float_type = std::is_same<T, float>::value;
PADDLE_ENFORCE_EQ(is_float_type, true,
platform::errors::PreconditionNotMet(
"DNNL LRN GradOpKernl must use float data."));
"DNNL LRN GradOpKernel must use float data."));
PADDLE_ENFORCE_EQ(platform::is_cpu_place(ctx.GetPlace()), true,
paddle::platform::errors::PreconditionNotMet(
"Operator DNNL LRNGrad must use CPUPlace"));
......
......@@ -26,8 +26,10 @@ class TestLRNMKLDNNOp(TestLRNOp):
return attrs
def test_check_output(self):
# We cannot validate MidOut as LRN REF has diffrent meaning in it
# TODO(wangzhongpu): support mkldnn op in dygraph mode
self.check_output(atol=0.002, check_dygraph=False)
self.check_output(
atol=0.002, no_check_set=['MidOut'], check_dygraph=False)
def test_check_grad_normal(self):
# TODO(wangzhongpu): support mkldnn op in dygraph mode
......
......@@ -17,6 +17,7 @@ no_check_set_white_list = [
'fake_quantize_range_abs_max',
'coalesce_tensor',
'flatten2',
'lrn',
'squeeze2',
'reshape2',
'transpose2',
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册