提交 b9874251 编写于 作者: T Tomasz Patejko

Plain LRN op throws an exception when is_test is set in backward pass

上级 95658767
...@@ -214,7 +214,10 @@ class LRNOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -214,7 +214,10 @@ class LRNOpMaker : public framework::OpProtoAndCheckerMaker {
"Defaults to \"NHWC\". Specify the data format of the output data, " "Defaults to \"NHWC\". Specify the data format of the output data, "
"the input will be transformed automatically. ") "the input will be transformed automatically. ")
.SetDefault("AnyLayout"); .SetDefault("AnyLayout");
AddAttr<bool>("is_test", "").SetDefault(false); AddAttr<bool>("is_test",
"Turns on memory optimization that optimizes away "
"unnecessary memory allocations. Used by MKLDNN.")
.SetDefault(false);
AddComment(R"DOC( AddComment(R"DOC(
Local Response Normalization Operator. Local Response Normalization Operator.
......
...@@ -121,6 +121,10 @@ class LRNGradKernel : public framework::OpKernel<T> { ...@@ -121,6 +121,10 @@ class LRNGradKernel : public framework::OpKernel<T> {
T alpha = ctx.Attr<T>("alpha"); T alpha = ctx.Attr<T>("alpha");
T beta = ctx.Attr<T>("beta"); T beta = ctx.Attr<T>("beta");
PADDLE_ENFORCE(
!ctx.Attr<bool>("is_test"),
"is_test attribute should be set to False in training phase.");
LRNGradFunctor<DeviceContext, T> f; LRNGradFunctor<DeviceContext, T> f;
f(ctx, x, out, mid, x_g, out_g, N, C, H, W, n, alpha, beta); f(ctx, x, out, mid, x_g, out_g, N, C, H, W, n, alpha, beta);
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册