From b9874251c623a17c7db8c5c3c7214ae8b451a52f Mon Sep 17 00:00:00 2001 From: Tomasz Patejko Date: Fri, 30 Mar 2018 03:12:33 -0400 Subject: [PATCH] Plain LRN op throws an exception when is_test is set in backward pass --- paddle/fluid/operators/lrn_op.cc | 5 ++++- paddle/fluid/operators/lrn_op.h | 4 ++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/operators/lrn_op.cc b/paddle/fluid/operators/lrn_op.cc index b36b5c3a33..cb15683981 100644 --- a/paddle/fluid/operators/lrn_op.cc +++ b/paddle/fluid/operators/lrn_op.cc @@ -214,7 +214,10 @@ class LRNOpMaker : public framework::OpProtoAndCheckerMaker { "Defaults to \"NHWC\". Specify the data format of the output data, " "the input will be transformed automatically. ") .SetDefault("AnyLayout"); - AddAttr("is_test", "").SetDefault(false); + AddAttr("is_test", + "Turns on memory optimization that optimizes away " + "unnecessary memory allocations. Used by MKLDNN.") + .SetDefault(false); AddComment(R"DOC( Local Response Normalization Operator. diff --git a/paddle/fluid/operators/lrn_op.h b/paddle/fluid/operators/lrn_op.h index 95796f7eec..0fd3175e85 100644 --- a/paddle/fluid/operators/lrn_op.h +++ b/paddle/fluid/operators/lrn_op.h @@ -121,6 +121,10 @@ class LRNGradKernel : public framework::OpKernel { T alpha = ctx.Attr("alpha"); T beta = ctx.Attr("beta"); + PADDLE_ENFORCE( + !ctx.Attr("is_test"), + "is_test attribute should be set to False in training phase."); + LRNGradFunctor f; f(ctx, x, out, mid, x_g, out_g, N, C, H, W, n, alpha, beta); } -- GitLab