From 6da6ff6a4dfba49cf55ae9622af99fb3901faef4 Mon Sep 17 00:00:00 2001 From: wenbin Date: Fri, 18 Jun 2021 18:59:29 +0800 Subject: [PATCH] SimplifyWithBasicOpsPass (#33637) * simplify_with_basic * fix * scale factor --- .../ir/simplify_with_basic_ops_pass.cc | 25 +++++++++++++++++++ .../ir/simplify_with_basic_ops_pass.h | 7 ++++-- 2 files changed, 30 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/framework/ir/simplify_with_basic_ops_pass.cc b/paddle/fluid/framework/ir/simplify_with_basic_ops_pass.cc index dff2f2451da..282bac4e163 100644 --- a/paddle/fluid/framework/ir/simplify_with_basic_ops_pass.cc +++ b/paddle/fluid/framework/ir/simplify_with_basic_ops_pass.cc @@ -34,6 +34,26 @@ namespace ir { */ class Graph; +SimplifyWithBasicOpsPass::SimplifyWithBasicOpsPass() { + AddOpCompat(OpCompat("scale")) + .AddInput("X") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("scale") + .IsNumGE(0.f) + .IsNumLE(1.f) + .End() + .AddAttr("bias") + .IsNumEQ(0.f) + .End() + .AddAttr("bias_after_scale") + .IsNumEQ(true) + .End(); +} + void SimplifyWithBasicOpsPass::ApplyImpl(Graph* graph) const { VLOG(3) << "Simplify the Graph with basic ops."; std::unordered_set del_node_set; @@ -145,6 +165,11 @@ bool SimplifyWithBasicOpsPass::SimplifyDropout( new_op_desc.SetAttr("bias", static_cast(0)); new_op_desc.SetAttr("bias_after_scale", true); + if (!IsCompat(new_op_desc)) { + LOG(WARNING) << "Basic ops pass in scale op compat failed."; + return false; + } + auto* scale_op_node = graph->CreateOpNode(&new_op_desc); IR_NODE_LINK_TO(dropout_x, scale_op_node); IR_NODE_LINK_TO(scale_op_node, dropout_out); diff --git a/paddle/fluid/framework/ir/simplify_with_basic_ops_pass.h b/paddle/fluid/framework/ir/simplify_with_basic_ops_pass.h index 6a245c444a7..e80de5e1cd9 100644 --- a/paddle/fluid/framework/ir/simplify_with_basic_ops_pass.h +++ b/paddle/fluid/framework/ir/simplify_with_basic_ops_pass.h @@ -17,7 +17,7 @@ limitations under the License. */ #include #include -#include "paddle/fluid/framework/ir/pass.h" +#include "paddle/fluid/framework/ir/op_compat_sensible_pass.h" namespace paddle { namespace framework { @@ -26,7 +26,10 @@ namespace ir { class Graph; class Node; -class SimplifyWithBasicOpsPass : public Pass { +class SimplifyWithBasicOpsPass : public OpCompatSensiblePass { + public: + SimplifyWithBasicOpsPass(); + protected: void ApplyImpl(Graph* graph) const override; -- GitLab