未验证 提交 da7c73f8 编写于 作者: A Adam 提交者: GitHub

Delete is_test attribute from activation operators (#23318)

* Delete is_test from activation operators
test=develop

* Revent unneeded changes
test=develop
上级 9fd90674
...@@ -53,11 +53,6 @@ static constexpr bool CanInplaceAct() { ...@@ -53,11 +53,6 @@ static constexpr bool CanInplaceAct() {
"(bool, default false) Only used in cudnn kernel, need " \ "(bool, default false) Only used in cudnn kernel, need " \
"install cudnn") \ "install cudnn") \
.SetDefault(false); \ .SetDefault(false); \
AddAttr<bool>( \
"is_test", \
"(bool, default false) Set to true for inference only, false " \
"for training. Some layers may run faster when this is true.") \
.SetDefault(false); \
AddComment(OP_COMMENT); \ AddComment(OP_COMMENT); \
} \ } \
} }
...@@ -361,10 +356,6 @@ class LeakyReluOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -361,10 +356,6 @@ class LeakyReluOpMaker : public framework::OpProtoAndCheckerMaker {
AddAttr<bool>("use_mkldnn", AddAttr<bool>("use_mkldnn",
"(bool, default false) Only used in mkldnn kernel") "(bool, default false) Only used in mkldnn kernel")
.SetDefault(false); .SetDefault(false);
AddAttr<bool>("is_test",
"(bool, default false) Set to true for inference only, false "
"for training. Some layers may run faster when this is true.")
.SetDefault(false);
AddComment(R"DOC( AddComment(R"DOC(
LeakyRelu Activation Operator. LeakyRelu Activation Operator.
...@@ -592,10 +583,6 @@ class SwishOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -592,10 +583,6 @@ class SwishOpMaker : public framework::OpProtoAndCheckerMaker {
AddAttr<bool>("use_mkldnn", AddAttr<bool>("use_mkldnn",
"(bool, default false) Only used in mkldnn kernel") "(bool, default false) Only used in mkldnn kernel")
.SetDefault(false); .SetDefault(false);
AddAttr<bool>("is_test",
"(bool, default false) Set to true for inference only, false "
"for training. Some layers may run faster when this is true.")
.SetDefault(false);
AddComment(R"DOC( AddComment(R"DOC(
Swish Activation Operator. Swish Activation Operator.
......
...@@ -93,10 +93,6 @@ class GeluOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -93,10 +93,6 @@ class GeluOpMaker : public framework::OpProtoAndCheckerMaker {
"(bool, default false) Only used in cudnn kernel, need " "(bool, default false) Only used in cudnn kernel, need "
"install cudnn") "install cudnn")
.SetDefault(false); .SetDefault(false);
AddAttr<bool>("is_test",
"(bool, default false) Set to true for inference only, false "
"for training. Some layers may run faster when this is true.")
.SetDefault(false);
AddComment(R"DOC( AddComment(R"DOC(
Gelu Activation Operator. Gelu Activation Operator.
......
...@@ -54,10 +54,6 @@ class MKLDNNActivationGradKernel ...@@ -54,10 +54,6 @@ class MKLDNNActivationGradKernel
PADDLE_ENFORCE_NE(diff_y->format(), MKLDNNMemoryFormat::undef, PADDLE_ENFORCE_NE(diff_y->format(), MKLDNNMemoryFormat::undef,
"Wrong format set for Input OutGrad tensor"); "Wrong format set for Input OutGrad tensor");
PADDLE_ENFORCE_EQ(
ctx.Attr<bool>("is_test"), false,
"is_test attribute should be set to False in training phase.");
Functor functor; Functor functor;
functor(ctx); functor(ctx);
} }
...@@ -89,11 +85,9 @@ void eltwise_forward(const framework::ExecutionContext &ctx, ...@@ -89,11 +85,9 @@ void eltwise_forward(const framework::ExecutionContext &ctx,
auto src_format = src_tz.size() == 2 ? MKLDNNMemoryFormat::nc : x->format(); auto src_format = src_tz.size() == 2 ? MKLDNNMemoryFormat::nc : x->format();
bool is_test = ctx.Attr<bool>("is_test");
platform::ActivationMKLDNNHandler<T> handler( platform::ActivationMKLDNNHandler<T> handler(
src_tz, algorithm, alpha, beta, src_format, is_test, dev_ctx, src_tz, algorithm, alpha, beta, src_format, dev_ctx, ctx.GetPlace(),
ctx.GetPlace(), ctx.InputName("X")); ctx.InputName("X"));
auto src_memory_p = handler.AcquireSrcMemory(x); auto src_memory_p = handler.AcquireSrcMemory(x);
auto dst_memory_p = handler.AcquireDstMemory(y); auto dst_memory_p = handler.AcquireDstMemory(y);
......
...@@ -411,7 +411,7 @@ class ActivationMKLDNNHandler ...@@ -411,7 +411,7 @@ class ActivationMKLDNNHandler
public: public:
ActivationMKLDNNHandler(const std::vector<int64_t>& dims, ActivationMKLDNNHandler(const std::vector<int64_t>& dims,
mkldnn::algorithm algorithm, float alpha, float beta, mkldnn::algorithm algorithm, float alpha, float beta,
const MKLDNNMemoryFormat fmt, bool is_test, const MKLDNNMemoryFormat fmt,
const platform::MKLDNNDeviceContext& dev_ctx, const platform::MKLDNNDeviceContext& dev_ctx,
platform::Place cpu_place, platform::Place cpu_place,
const std::string& unique_name) const std::string& unique_name)
...@@ -422,9 +422,7 @@ class ActivationMKLDNNHandler ...@@ -422,9 +422,7 @@ class ActivationMKLDNNHandler
platform::CreateKey(dims, unique_name)) { platform::CreateKey(dims, unique_name)) {
auto md = mkldnn::memory::desc(dims, platform::MKLDNNGetDataType<T>(), fmt); auto md = mkldnn::memory::desc(dims, platform::MKLDNNGetDataType<T>(), fmt);
this->AcquireForwardPrimitiveDescriptor( this->AcquireForwardPrimitiveDescriptor(mkldnn::prop_kind::forward_training,
is_test ? mkldnn::prop_kind::forward_inference
: mkldnn::prop_kind::forward_training,
algorithm, md, alpha, beta); algorithm, md, alpha, beta);
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册