diff --git a/paddle/fluid/framework/ir/simplify_with_basic_ops_pass.cc b/paddle/fluid/framework/ir/simplify_with_basic_ops_pass.cc index dff2f2451dac4ca985c206b7913e42fc563be4c3..282bac4e1634de4a47e573b60a9040abbfc90258 100644 --- a/paddle/fluid/framework/ir/simplify_with_basic_ops_pass.cc +++ b/paddle/fluid/framework/ir/simplify_with_basic_ops_pass.cc @@ -34,6 +34,26 @@ namespace ir { */ class Graph; +SimplifyWithBasicOpsPass::SimplifyWithBasicOpsPass() { + AddOpCompat(OpCompat("scale")) + .AddInput("X") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("scale") + .IsNumGE(0.f) + .IsNumLE(1.f) + .End() + .AddAttr("bias") + .IsNumEQ(0.f) + .End() + .AddAttr("bias_after_scale") + .IsNumEQ(true) + .End(); +} + void SimplifyWithBasicOpsPass::ApplyImpl(Graph* graph) const { VLOG(3) << "Simplify the Graph with basic ops."; std::unordered_set del_node_set; @@ -145,6 +165,11 @@ bool SimplifyWithBasicOpsPass::SimplifyDropout( new_op_desc.SetAttr("bias", static_cast(0)); new_op_desc.SetAttr("bias_after_scale", true); + if (!IsCompat(new_op_desc)) { + LOG(WARNING) << "Basic ops pass in scale op compat failed."; + return false; + } + auto* scale_op_node = graph->CreateOpNode(&new_op_desc); IR_NODE_LINK_TO(dropout_x, scale_op_node); IR_NODE_LINK_TO(scale_op_node, dropout_out); diff --git a/paddle/fluid/framework/ir/simplify_with_basic_ops_pass.h b/paddle/fluid/framework/ir/simplify_with_basic_ops_pass.h index 6a245c444a7ec8dd800d8432693d2fa247360634..e80de5e1cd9d1e51acebab613a1dc543eb354da6 100644 --- a/paddle/fluid/framework/ir/simplify_with_basic_ops_pass.h +++ b/paddle/fluid/framework/ir/simplify_with_basic_ops_pass.h @@ -17,7 +17,7 @@ limitations under the License. */ #include #include -#include "paddle/fluid/framework/ir/pass.h" +#include "paddle/fluid/framework/ir/op_compat_sensible_pass.h" namespace paddle { namespace framework { @@ -26,7 +26,10 @@ namespace ir { class Graph; class Node; -class SimplifyWithBasicOpsPass : public Pass { +class SimplifyWithBasicOpsPass : public OpCompatSensiblePass { + public: + SimplifyWithBasicOpsPass(); + protected: void ApplyImpl(Graph* graph) const override;