未验证 提交 79ed7177 编写于 作者: 王明冬 提交者: GitHub

add method for enhance pass,test=develop (#33004)

上级 7be6191b
......@@ -50,6 +50,7 @@ if (WITH_TESTING)
endif(WITH_TESTING)
cc_library(graph_pattern_detector SRCS graph_pattern_detector.cc DEPS ${GRAPH_PATTERN_DETECTOR_DEPS})
cc_library(op_compat_sensible_pass SRCS op_compat_sensible_pass.cc DEPS graph_pattern_detector)
cc_library(subgraph_detector SRCS subgraph_detector.cc DEPS graph_pattern_detector executor)
cc_library(fuse_pass_base SRCS fuse_pass_base.cc DEPS pass)
cc_library(placement_pass_base SRCS placement_pass_base.cc DEPS pass)
......@@ -139,6 +140,7 @@ cc_test(graph_test SRCS graph_test.cc DEPS graph graph_helper op_registry)
cc_test(graph_helper_test SRCS graph_helper_test.cc DEPS graph graph_helper op_registry)
cc_test(graph_to_program_pass_test SRCS graph_to_program_pass_test.cc DEPS graph_to_program_pass)
cc_test(test_graph_pattern_detector SRCS graph_pattern_detector_tester.cc DEPS graph_pattern_detector)
cc_test(test_op_compat_sensible_pass SRCS op_compat_sensible_pass_tester.cc DEPS op_compat_sensible_pass)
cc_test(test_fc_fuse_pass_cc SRCS fc_fuse_pass_tester.cc DEPS fc_fuse_pass framework_proto)
cc_test(test_fc_lstm_fuse_pass_cc SRCS fc_lstm_fuse_pass_tester.cc DEPS fc_lstm_fuse_pass framework_proto)
cc_test(test_fc_gru_fuse_pass_cc SRCS fc_gru_fuse_pass_tester.cc DEPS fc_gru_fuse_pass framework_proto)
......
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <memory>
#include "paddle/fluid/framework/ir/op_compat_sensible_pass.h"
namespace paddle {
namespace framework {
namespace ir {
AttrCompat& AttrCompat::IsStringIn(const std::set<std::string>& candidates) {
conditions_.emplace_back([candidates](const Attribute& attr) -> bool {
std::string value = BOOST_GET_CONST(std::string, attr);
for (auto& str : candidates) {
if (str == value) {
return true;
}
}
return false;
});
return *this;
}
AttrCompat& AttrCompat::IsStringMatch(
const std::function<bool(const std::string&)>& func) {
conditions_.emplace_back([func](const Attribute& attr) -> bool {
std::string value = BOOST_GET_CONST(std::string, attr);
return func(value);
});
return *this;
}
AttrCompat& AttrCompat::IsIntIn(const std::set<int>& candidates) {
conditions_.emplace_back([candidates](const Attribute& attr) -> bool {
int value = BOOST_GET_CONST(int, attr);
return candidates.find(value) != candidates.end();
});
return *this;
}
//! Todo: append the definition.
AttrCompat& AttrCompat::IsLeftDefault() { return *this; }
bool AttrCompat::operator()(const OpDesc& op_desc) {
if (!op_desc.HasAttr(attr_name_)) {
return false;
}
const Attribute attr = op_desc.GetAttr(attr_name_);
for (auto& func : conditions_) {
if (!func(attr)) {
return false;
}
}
return true;
}
AttrCompat& AttrCompat::IsBoolEQ(bool v) {
conditions_.emplace_back([v](const Attribute& attr) -> bool {
bool value = BOOST_GET_CONST(bool, attr);
return value == v;
});
return *this;
}
InputOrOutputCompat& InputOrOutputCompat::IsTensor() {
conditions_.emplace_back([](const std::vector<std::string>& input) -> bool {
return input.size() == 1u;
});
return *this;
}
InputOrOutputCompat& InputOrOutputCompat::IsOptional() {
optional_ = true;
return *this;
}
bool InputOrOutputCompat::operator()(
const std::vector<std::string>& input) const {
if (input.empty()) return false;
for (auto& func : conditions_) {
if (!func(input)) {
return false;
}
}
return true;
}
AttrCompat& OpCompat::AddAttr(const std::string& attr_name) {
attr_compats_.emplace_back(attr_name, this);
return attr_compats_.back();
}
InputOrOutputCompat& OpCompat::AddInput(const std::string& name) {
PADDLE_ENFORCE_EQ(input_compats_.find(name), input_compats_.end(),
platform::errors::InvalidArgument(
"The input with the same name has been added"));
input_compats_.emplace(name, InputOrOutputCompat(name, this));
return input_compats_.at(name);
}
InputOrOutputCompat& OpCompat::AddOutput(const std::string& name) {
PADDLE_ENFORCE_EQ(output_compats_.find(name), output_compats_.end(),
platform::errors::InvalidArgument(
"The output with the same name has been added"));
output_compats_.emplace(name, InputOrOutputCompat(name, this));
return output_compats_.at(name);
}
bool OpCompat::Judge(const OpDesc& op_desc) {
for (auto& attr_compat : attr_compats_) {
if (!attr_compat(op_desc)) {
return false;
}
}
const VariableNameMap& inputs_map = op_desc.Inputs();
for (auto& input_desc : inputs_map) {
if (input_compats_.find(input_desc.first) == input_compats_.end()) {
if (!input_desc.second.empty()) {
return false;
}
}
}
for (auto& input_val : input_compats_) {
if (inputs_map.find(input_val.first) == inputs_map.end()) {
if (!input_val.second.Optional()) {
return false;
}
} else {
if (!input_val.second(inputs_map.at(input_val.first))) {
return false;
}
}
}
const VariableNameMap& outputs_map = op_desc.Outputs();
for (auto& output_desc : outputs_map) {
if (output_compats_.find(output_desc.first) == output_compats_.end()) {
if (!output_desc.second.empty()) {
return false;
}
}
}
for (auto& output_val : output_compats_) {
if (outputs_map.find(output_val.first) == outputs_map.end()) {
if (!output_val.second.Optional()) {
return false;
}
} else {
if (!output_val.second(outputs_map.at(output_val.first))) {
return false;
}
}
}
return true;
}
OpCompat& OpCompatSensiblePass::AddOpCompat(OpCompat&& op_compat) {
std::string name = op_compat.Name();
op_compat_judgers_[name].reset(new OpCompat(std::move(op_compat)));
return *(op_compat_judgers_[name]);
}
} // namespace ir
} // namespace framework
} // namespace paddle
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <map>
#include <vector>
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle {
namespace framework {
namespace ir {
class OpCompat;
class AttrCompat {
public:
AttrCompat(const std::string& attr_name, OpCompat* op_compat)
: attr_name_(attr_name), op_compat_(op_compat) {}
// @{ String-related methods
//! Assert the attribute is an string in the `candidates` domain.
AttrCompat& IsStringIn(const std::set<std::string>& candidates);
//! Assert the attribute is a string and match a custom judging function.
AttrCompat& IsStringMatch(
const std::function<bool(const std::string&)>& func);
// @}
//! Assert the attribute is an integer in the `candidates` domain.
AttrCompat& IsIntIn(const std::set<int>& candidates);
// @{ Number-releated methods
//! Assert the attribute is a number and > `v`.
template <typename T>
AttrCompat& IsNumGT(T v);
//! Assert the attribute is a number and >= `v`.
template <typename T>
AttrCompat& IsNumGE(T v);
//! Assert the attribute is a number and < `v`.
template <typename T>
AttrCompat& IsNumLT(T v);
//! Assert the attribute is a number and <= `v`.
template <typename T>
AttrCompat& IsNumLE(T v);
//! Assert the attribute is a number and == `v`.
template <typename T>
AttrCompat& IsNumEQ(T v);
//! Assert the attribute is a number and matches a customized judging
//! function.
template <typename T>
AttrCompat& IsNumMatch(bool (*func)(T));
// @}
//! Assert the attribute is a boolean value equals `v`.
AttrCompat& IsBoolEQ(bool v);
//! Tell whether this attribute is left as default value.
AttrCompat& IsLeftDefault();
//! Jump back to retrieve OpCompat instance.
OpCompat& End() { return *op_compat_; }
bool operator()(const OpDesc& op_desc);
private:
std::string attr_name_;
OpCompat* op_compat_;
std::vector<std::function<bool(const Attribute&)>> conditions_;
};
class InputOrOutputCompat {
public:
InputOrOutputCompat(const std::string& name, OpCompat* op_compat)
: optional_(false), name_(name), op_compat_(op_compat) {}
InputOrOutputCompat& IsTensor();
InputOrOutputCompat& IsOptional();
bool Optional() const { return optional_; }
bool operator()(const std::vector<std::string>& input) const;
//! Jump back to retrieve OpCompat instance.
OpCompat& End() { return *op_compat_; }
private:
bool optional_;
std::string name_;
OpCompat* op_compat_;
std::vector<std::function<bool(const std::vector<std::string>&)>> conditions_;
};
/**
* OpCompat is a helper class to help define the compatible Op definition.
*
* Usage:
* OpCompat compat("FC");
* compat.AddAttr("in_num_col_dims").IsNumLE(1).End()
* .AddAttr("activation_type").IsStringIn({"tanh", "sigmoid"}).End()
* .AddInput("Input").IsTensor().End()
* .AddInput("W").IsTensor().End()
* .AddInput("Bias").IsTensor().IsOptional().End()
* .AddOutput("Out").IsTensor().End()
*
* All the inference-aware Op defition is as above, all the other attributes not
* contained in the definition should be set default value or it would be judged
* incompatible.
*/
class OpCompat {
public:
explicit OpCompat(const std::string& op_name) : op_name_(op_name) {}
explicit OpCompat(std::string&& op_name) : op_name_(std::move(op_name)) {}
explicit OpCompat(const OpCompat&) = default;
explicit OpCompat(OpCompat&&) = default;
AttrCompat& AddAttr(const std::string& attr_name);
InputOrOutputCompat& AddInput(const std::string& name);
InputOrOutputCompat& AddOutput(const std::string& name);
//! Judge whether an OpDesc match the defined Op compatibility.
bool Judge(const OpDesc& op_desc);
const std::string& Name() const { return op_name_; }
private:
std::string op_name_;
std::vector<AttrCompat> attr_compats_;
std::unordered_map<std::string, InputOrOutputCompat> input_compats_;
std::unordered_map<std::string, InputOrOutputCompat> output_compats_;
};
/**
* OpCompatSensiblePass is a base class for all the passes thouse is sensitive
* to Op update.
* There are two methods to help tell the compability of an Op
* bool IsCompat(const GraphPatternDetector::subgraph_t& subgraph, Graph* g);
* bool IsCompat(const OpDesc& op_desc);
*
* One can register the related Op compabilities using
* void AddOpCompat(OpCompat&& judger);
*
* Most of the Passes are used for fusing ops, so we define a method for such
* scenerios.
* void AccessSubgraph(const GraphPatternDetector::subgraph_t& subgraph,
Graph* g);
* It will check the Op compatibility automatically.
* For other scenirios, one should call `IsCompat` by himself.
*
* A FC fuse pass example:
* class FcFusePass : public OpCompatSensiblePass {
* public:
* FcFusePass() {
* // define Mul op compatiblity.
* AddOpCompat(OpCompat("Mul"))
* .AddInput("Input").IsTensor().End()
* .AddAttr("in_num_col_dims").IsNumGE(1);
* AddOpCompat(OpCompat("Add")). ...;
* // There are multiple activation implemention.
* AddOpCompat(OpCompat("Tanh")). ...;
* AddOpCompat(OpCompat("Sigmoid")). ...;
* }
*
* // override the subgraph access method
* virtual bool AccessSubgraphImpl(
* const GraphPatternDetector::subgraph_t& subgraph,
* Graph* g) override { ... }
*
* // Call the AccessSubgraph method in main procedure of this Pass.
* };
*/
class OpCompatSensiblePass : public Pass {
public:
//! Access the subgraph and pattern.
void AccessSubgraph(const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
if (IsCompat(subgraph, g)) {
AccessSubgraphImpl(subgraph, g);
}
}
protected:
/**
* Developer should push the compatibility `teller` for each kind of Op in the
* subgraph.
* NOTE One should add all the related op compatiblity in the construct so
* that all the following methods are valid.
*/
OpCompat& AddOpCompat(OpCompat&& op_compat);
//! Modify the subgraph.
virtual bool AccessSubgraphImpl(
const GraphPatternDetector::subgraph_t& subgraph, Graph* g) const {
return true;
}
//! Tell the Op compability of a subgraph.
bool IsCompat(const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) const {
CHECK(!op_compat_judgers_.empty())
<< "At least one OpCompat instance should be added in the "
"OpCompatSensiblePass.";
// Check the all the ops in the subgraph are contained in the
// op_compat.
for (auto& node_pair : subgraph) {
if (!node_pair.first->IsOp()) continue;
auto op_type = node_pair.second->Op()->Type();
if (!op_compat_judgers_.count(op_type)) {
return false;
}
auto& judger = *op_compat_judgers_.at(op_type);
if (!judger.Judge(*(node_pair.second->Op()))) {
return false;
}
}
return true;
}
//! Tell the op compatibility of a single Op.
bool IsCompat(const OpDesc& op_desc) const {
if (!op_compat_judgers_.count(op_desc.Type())) return false;
return op_compat_judgers_.at(op_desc.Type())->Judge(op_desc);
}
private:
std::map<std::string, std::unique_ptr<OpCompat>> op_compat_judgers_;
};
template <typename T>
AttrCompat& AttrCompat::IsNumGT(T v) {
conditions_.emplace_back([v](const Attribute& attr) -> bool {
T value = BOOST_GET_CONST(T, attr);
return value > v;
});
return *this;
}
template <typename T>
AttrCompat& AttrCompat::IsNumGE(T v) {
conditions_.emplace_back([v](const Attribute& attr) -> bool {
T value = BOOST_GET_CONST(T, attr);
return value >= v;
});
return *this;
}
template <typename T>
AttrCompat& AttrCompat::IsNumLT(T v) {
conditions_.emplace_back([v](const Attribute& attr) -> bool {
T value = BOOST_GET_CONST(T, attr);
return value < v;
});
return *this;
}
template <typename T>
AttrCompat& AttrCompat::IsNumLE(T v) {
conditions_.emplace_back([v](const Attribute& attr) -> bool {
T value = BOOST_GET_CONST(T, attr);
return value <= v;
});
return *this;
}
template <typename T>
AttrCompat& AttrCompat::IsNumEQ(T v) {
conditions_.emplace_back([v](const Attribute& attr) -> bool {
T value = BOOST_GET_CONST(T, attr);
return value == v;
});
return *this;
}
template <typename T>
AttrCompat& AttrCompat::IsNumMatch(bool (*func)(T)) {
conditions_.emplace_back([func](const Attribute& attr) -> bool {
T value = BOOST_GET_CONST(T, attr);
return func(value);
});
return *this;
}
} // namespace ir
} // namespace framework
} // namespace paddle
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/ir/op_compat_sensible_pass.h"
#include "gtest/gtest.h"
#include "paddle/fluid/framework/program_desc.h"
namespace paddle {
namespace framework {
namespace ir {
TEST(OpCompatSensiblePass, compatOp) {
auto lambda = [](const std::string& str) { return str == "tanh"; };
OpCompat compat("FC");
compat.AddAttr("in_num_col_dims")
.IsIntIn({1, 2})
.IsNumLE(1)
.IsLeftDefault()
.End()
.AddAttr("activation_type")
.IsStringIn({"tanh", "sigmoid"})
.IsStringMatch(lambda)
.End()
.AddAttr("test_attr")
.IsBoolEQ(true)
.End()
.AddInput("Input")
.IsTensor()
.End()
.AddInput("W")
.IsTensor()
.End()
.AddInput("Bias")
.IsTensor()
.IsOptional()
.End()
.AddInput("Test")
.IsOptional()
.End()
.AddOutput("Out")
.IsTensor()
.End();
OpDesc fc_op;
std::unordered_map<std::string, Attribute> attr_map;
attr_map["in_num_col_dims"] = 1;
attr_map["activation_type"] = std::string("tanh");
attr_map["test_attr"] = true;
fc_op.SetAttrMap(attr_map);
fc_op.SetInput("Input", std::vector<std::string>{"test_input"});
fc_op.SetInput("W", std::vector<std::string>{"test_input_0"});
fc_op.SetInput("Bias", std::vector<std::string>{"test_input_1"});
fc_op.SetOutput("Out", std::vector<std::string>{"test_output"});
EXPECT_STREQ(compat.Name().c_str(), "FC");
EXPECT_TRUE(compat.Judge(fc_op));
}
class OpCompatSensiblePassTest : public OpCompatSensiblePass {
public:
OpCompatSensiblePassTest();
bool TestIsCompat(const OpDesc& op_desc) { return IsCompat(op_desc); }
};
OpCompatSensiblePassTest::OpCompatSensiblePassTest() {
AddOpCompat(OpCompat("FC"))
.AddAttr("in_num_col_dims")
.IsNumLE(1)
.End()
.AddAttr("activation_type")
.IsStringIn({"tanh", "sigmoid"})
.End()
.AddInput("Input")
.IsTensor()
.End()
.AddInput("W")
.IsTensor()
.End()
.AddInput("Bias")
.IsTensor()
.IsOptional()
.End()
.AddOutput("Out")
.IsTensor();
}
TEST(OpCompatSensiblePass, IsCompat) {
OpCompatSensiblePassTest test;
OpDesc fc_op;
fc_op.SetType("FC");
std::unordered_map<std::string, Attribute> attr_map;
attr_map["in_num_col_dims"] = 1;
attr_map["activation_type"] = std::string("tanh");
fc_op.SetAttrMap(attr_map);
fc_op.SetInput("Input", std::vector<std::string>{"test_input"});
fc_op.SetInput("W", std::vector<std::string>{"test_input_0"});
fc_op.SetInput("Bias", std::vector<std::string>{"test_input_1"});
fc_op.SetOutput("Out", std::vector<std::string>{"test_output"});
EXPECT_TRUE(test.TestIsCompat(fc_op));
ProgramDesc prog;
std::unique_ptr<Graph> g(new Graph(prog));
Node* o1 = g->CreateOpNode(&fc_op);
GraphPatternDetector detector;
PDNode* op2 =
detector.mutable_pattern()->NewNode([](Node* x) { return true; });
GraphPatternDetector::subgraph_t subgraph;
subgraph[op2] = o1;
test.AccessSubgraph(subgraph, g.get());
}
} // namespace ir
} // namespace framework
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册