未验证 提交 da7c73f8 编写于 作者: A Adam 提交者: GitHub

Delete is_test attribute from activation operators (#23318)

* Delete is_test from activation operators
test=develop

* Revent unneeded changes
test=develop
上级 9fd90674
......@@ -53,11 +53,6 @@ static constexpr bool CanInplaceAct() {
"(bool, default false) Only used in cudnn kernel, need " \
"install cudnn") \
.SetDefault(false); \
AddAttr<bool>( \
"is_test", \
"(bool, default false) Set to true for inference only, false " \
"for training. Some layers may run faster when this is true.") \
.SetDefault(false); \
AddComment(OP_COMMENT); \
} \
}
......@@ -361,10 +356,6 @@ class LeakyReluOpMaker : public framework::OpProtoAndCheckerMaker {
AddAttr<bool>("use_mkldnn",
"(bool, default false) Only used in mkldnn kernel")
.SetDefault(false);
AddAttr<bool>("is_test",
"(bool, default false) Set to true for inference only, false "
"for training. Some layers may run faster when this is true.")
.SetDefault(false);
AddComment(R"DOC(
LeakyRelu Activation Operator.
......@@ -592,10 +583,6 @@ class SwishOpMaker : public framework::OpProtoAndCheckerMaker {
AddAttr<bool>("use_mkldnn",
"(bool, default false) Only used in mkldnn kernel")
.SetDefault(false);
AddAttr<bool>("is_test",
"(bool, default false) Set to true for inference only, false "
"for training. Some layers may run faster when this is true.")
.SetDefault(false);
AddComment(R"DOC(
Swish Activation Operator.
......
......@@ -93,10 +93,6 @@ class GeluOpMaker : public framework::OpProtoAndCheckerMaker {
"(bool, default false) Only used in cudnn kernel, need "
"install cudnn")
.SetDefault(false);
AddAttr<bool>("is_test",
"(bool, default false) Set to true for inference only, false "
"for training. Some layers may run faster when this is true.")
.SetDefault(false);
AddComment(R"DOC(
Gelu Activation Operator.
......
......@@ -54,10 +54,6 @@ class MKLDNNActivationGradKernel
PADDLE_ENFORCE_NE(diff_y->format(), MKLDNNMemoryFormat::undef,
"Wrong format set for Input OutGrad tensor");
PADDLE_ENFORCE_EQ(
ctx.Attr<bool>("is_test"), false,
"is_test attribute should be set to False in training phase.");
Functor functor;
functor(ctx);
}
......@@ -89,11 +85,9 @@ void eltwise_forward(const framework::ExecutionContext &ctx,
auto src_format = src_tz.size() == 2 ? MKLDNNMemoryFormat::nc : x->format();
bool is_test = ctx.Attr<bool>("is_test");
platform::ActivationMKLDNNHandler<T> handler(
src_tz, algorithm, alpha, beta, src_format, is_test, dev_ctx,
ctx.GetPlace(), ctx.InputName("X"));
src_tz, algorithm, alpha, beta, src_format, dev_ctx, ctx.GetPlace(),
ctx.InputName("X"));
auto src_memory_p = handler.AcquireSrcMemory(x);
auto dst_memory_p = handler.AcquireDstMemory(y);
......
......@@ -411,7 +411,7 @@ class ActivationMKLDNNHandler
public:
ActivationMKLDNNHandler(const std::vector<int64_t>& dims,
mkldnn::algorithm algorithm, float alpha, float beta,
const MKLDNNMemoryFormat fmt, bool is_test,
const MKLDNNMemoryFormat fmt,
const platform::MKLDNNDeviceContext& dev_ctx,
platform::Place cpu_place,
const std::string& unique_name)
......@@ -422,9 +422,7 @@ class ActivationMKLDNNHandler
platform::CreateKey(dims, unique_name)) {
auto md = mkldnn::memory::desc(dims, platform::MKLDNNGetDataType<T>(), fmt);
this->AcquireForwardPrimitiveDescriptor(
is_test ? mkldnn::prop_kind::forward_inference
: mkldnn::prop_kind::forward_training,
this->AcquireForwardPrimitiveDescriptor(mkldnn::prop_kind::forward_training,
algorithm, md, alpha, beta);
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册