Unverified commit e5e0b726, authored by Sławomir Siwek, committed by GitHub

conv + elementwise_add refactor (#41286)

* DRY

* change nodes names

* add const prefix

* change asX to as_x in all files
Parent 75a17cdb
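In short, the commit folds the two near-duplicate FuseConvAsX/FuseConvAsY routines into a single FuseConv parameterized by as_x, backed by a shared ResidualElementwise pattern. A rough before/after of the pass interface, taken from the header diff further down:

// before: two almost identical methods
GraphWithStats FuseConvAsX(const std::string& name_scope,
                           const GraphWithStats& graph_with_stats) const;
GraphWithStats FuseConvAsY(const std::string& name_scope,
                           const GraphWithStats& graph_with_stats) const;

// after: one method; as_x selects which elementwise_add input the conv
// output occupies (X when true, Y when false)
GraphWithStats FuseConv(const std::string& name_scope,
                        const GraphWithStats& graph_with_stats,
                        bool as_x) const;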
......@@ -2069,6 +2069,29 @@ PDNode *patterns::Elementwise::operator()(PDNode *x_var, PDNode *y_var,
return out_var;
}
PDNode *patterns::ResidualElementwise::operator()(
PDNode *op_var, PDNode *residual_var, const std::string elementwise_type,
bool as_x) {
auto elementwise_op =
pattern->NewNode(elementwise_op_repr())->assert_is_op(elementwise_type);
if (as_x) {
op_var->AsInput()->assert_is_op_input(elementwise_type, "X");
residual_var->AsInput()->assert_is_op_input(elementwise_type, "Y");
} else {
op_var->AsInput()->assert_is_op_input(elementwise_type, "Y");
residual_var->AsInput()->assert_is_op_input(elementwise_type, "X");
}
auto out_var = pattern->NewNode(elementwise_out_repr())
->AsOutput()
->assert_is_op_output(elementwise_type, "Out");
elementwise_op->LinksFrom({op_var, residual_var});
elementwise_op->LinksTo({out_var});
return out_var;
}
PDNode *patterns::Concat::operator()() {
auto concat_op = pattern->NewNode(concat_op_repr())->assert_is_op("concat");
......
......@@ -1032,6 +1032,22 @@ struct Elementwise : public PatternBase {
PATTERN_DECL_NODE(elementwise_out);
};
// Residual Elementwise ops
// This pattern allows the operator output to be X or Y
// and the residual data to be Y or X, based on the as_x flag
struct ResidualElementwise : public PatternBase {
ResidualElementwise(PDPattern* pattern, const std::string& name_scope,
bool as_x)
: PatternBase(pattern, name_scope, "residual_elementwise") {}
PDNode* operator()(PDNode* op_var, PDNode* residual_var,
const std::string elementwise_type, bool as_x);
PATTERN_DECL_NODE(operator_output);
PATTERN_DECL_NODE(residual_data);
PATTERN_DECL_NODE(elementwise_op);
PATTERN_DECL_NODE(elementwise_out);
};
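For reference, a minimal sketch of how the new pattern is instantiated (not one of this commit's hunks; it mirrors the FuseConv call site further down, with name_scope and as_x assumed to come from the caller):

// Sketch only: match a conv2d output feeding elementwise_add alongside a
// residual input; as_x selects which input slot the conv output takes.
patterns::Conv conv_pattern{pattern, name_scope};
auto conv_output = conv_pattern();
patterns::ResidualElementwise elementwise_pattern{pattern, name_scope, as_x};
elementwise_pattern(
    conv_output, pattern->NewNode(elementwise_pattern.residual_data_repr()),
    "elementwise_add", as_x);
conv_output->AsIntermediate();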
// Transpose op
// Forward pass for transpose.
// transpose_out is a result of the operator.
......
......@@ -12,6 +12,9 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <list>
#include <map>
#include "paddle/fluid/framework/ir/graph_traits.h"
namespace paddle {
......@@ -23,6 +26,51 @@ namespace ir {
//
class Node;
bool IsReachable(ir::Graph *graph, Node *from, Node *to) {
if (from == to) {
return true;
}
std::map<Node *, bool> visited;
for (auto &node : GraphTraits::DFS(*graph)) {
visited[&node] = false;
}
visited[from] = true;
std::list<Node *> queue;
queue.push_back(from);
while (!queue.empty()) {
auto cur = FindNode(graph, queue.front());
queue.pop_front();
if (!cur) return false;
for (const auto &n : cur->outputs) {
if (n == to) {
return true;
}
if (!visited[n]) {
visited[n] = true;
queue.push_back(n);
}
}
}
return false;
}
Node *FindNode(ir::Graph *graph, const Node *node) {
for (const auto &n : graph->Nodes()) {
if (n == node) {
return n;
}
}
return nullptr;
}
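A minimal usage sketch for the relocated helpers (illustrative only; some_node and target are hypothetical nodes):

// Sketch only: FindNode returns the graph-owned pointer (or nullptr if the
// node is no longer in the graph); IsReachable walks forward output links
// from `from` looking for `to`.
Node *from = FindNode(graph, some_node);
if (from != nullptr && IsReachable(graph, from, target)) {
  // the two nodes are connected through forward links
}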
NodesDFSIterator::NodesDFSIterator(const std::vector<Node *> &source) {
for (auto *x : source) stack_.push(x);
}
......
......@@ -29,6 +29,9 @@ namespace ir {
class Graph;
class Node;
bool IsReachable(ir::Graph *graph, Node *from, Node *to);
Node *FindNode(ir::Graph *graph, const Node *node);
template <typename IteratorT>
class iterator_range {
IteratorT begin_, end_;
......
......@@ -14,12 +14,6 @@
#include "paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.h"
#include <functional>
#include <list>
#include <map>
#include <memory>
#include <tuple>
#include "paddle/fluid/framework/ir/graph_traits.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/string/pretty_log.h"
......@@ -28,60 +22,6 @@ namespace paddle {
namespace framework {
namespace ir {
bool IsReachable(ir::Graph* graph, Node* from, Node* to) {
auto find_node = [](ir::Graph* graph, const Node* node) -> Node* {
for (auto n : graph->Nodes()) {
if (n == node) {
return n;
}
}
return nullptr;
};
if (from == to) {
return true;
}
std::map<Node*, bool> visited;
for (auto& node : GraphTraits::DFS(*graph)) {
visited[&node] = false;
}
visited[from] = true;
std::list<Node*> queue;
queue.push_back(from);
while (!queue.empty()) {
auto cur = find_node(graph, queue.front());
queue.pop_front();
if (!cur) return false;
for (auto n : cur->outputs) {
if (n == to) {
return true;
}
if (!visited[n]) {
visited[n] = true;
queue.push_back(n);
}
}
}
return false;
}
template <typename T>
paddle::optional<T> HasAttribute(const Node& op, const std::string& attr) {
if (op.Op()->HasAttr(attr))
return BOOST_GET_CONST(T, op.Op()->GetAttr(attr));
else
return paddle::none;
}
ResidualConnectionMKLDNNFusePass::ResidualConnectionMKLDNNFusePass() {
AddOpCompat(OpCompat("conv2d"))
.AddInput("Input")
......@@ -136,89 +76,22 @@ ResidualConnectionMKLDNNFusePass::ResidualConnectionMKLDNNFusePass() {
.End();
}
GraphWithStats ResidualConnectionMKLDNNFusePass::FuseConvAsX(
const std::string& name_scope,
const GraphWithStats& graph_with_stats) const {
GraphPatternDetector gpd;
auto pattern = gpd.mutable_pattern();
patterns::Conv conv_pattern{pattern, name_scope};
auto conv_output = conv_pattern();
patterns::Elementwise elementwise_pattern{pattern, name_scope};
elementwise_pattern(
conv_output, pattern->NewNode(elementwise_pattern.elementwise_y_repr()),
"elementwise_add");
conv_output->AsIntermediate();
int found_conv_as_x_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
GET_IR_NODE_FROM_SUBGRAPH(conv_op, conv_op, conv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(conv_input, conv_input, conv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(conv_filter, conv_filter, conv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(conv_output, conv_output, conv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise_op, elementwise_op,
elementwise_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise_identity, elementwise_y,
elementwise_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise_out, elementwise_out,
elementwise_pattern);
if (FindFuseOption(*conv_op, *elementwise_op) != FUSE_MKLDNN) return;
if (!IsReachable(g, elementwise_identity, conv_output)) return;
if (HasFusedActivation(conv_op)) return;
if (!IsCompat(subgraph, g)) {
LOG(WARNING)
<< "conv_elementwise_add_mkldnn_fuse_pass in op compat failed.";
return;
}
conv_op->Op()->SetInput("ResidualData", {elementwise_identity->Name()});
conv_op->Op()->SetOutput("Output", {elementwise_out->Name()});
conv_op->Op()->SetAttr("fuse_residual_connection", true);
GraphSafeRemoveNodes(g, {conv_output, elementwise_op});
IR_NODE_LINK_TO(elementwise_identity, conv_op);
IR_NODE_LINK_TO(conv_op, elementwise_out);
found_conv_as_x_count++;
};
gpd(graph_with_stats.first, handler);
if (!Has("disable_logs") || !Get<bool>("disable_logs")) {
std::stringstream msg_ss;
msg_ss << "--- Fused " << found_conv_as_x_count
<< " conv (as x) + elementwise_add patterns";
paddle::string::PrettyLogDetail(msg_ss.str().c_str());
}
return std::make_pair(graph_with_stats.first,
found_conv_as_x_count + graph_with_stats.second);
}
GraphWithStats ResidualConnectionMKLDNNFusePass::FuseConvAsY(
const std::string& name_scope,
const GraphWithStats& graph_with_stats) const {
GraphWithStats ResidualConnectionMKLDNNFusePass::FuseConv(
const std::string& name_scope, const GraphWithStats& graph_with_stats,
bool as_x) const {
GraphPatternDetector gpd;
auto pattern = gpd.mutable_pattern();
patterns::Conv conv_pattern{pattern, name_scope};
auto conv_output = conv_pattern();
patterns::Elementwise elementwise_pattern{pattern, name_scope};
patterns::ResidualElementwise elementwise_pattern{pattern, name_scope, as_x};
elementwise_pattern(
pattern->NewNode(elementwise_pattern.elementwise_x_repr()), conv_output,
"elementwise_add");
conv_output, pattern->NewNode(elementwise_pattern.residual_data_repr()),
"elementwise_add", as_x);
conv_output->AsIntermediate();
int found_conv_as_y_count = 0;
int found_conv_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
......@@ -229,15 +102,13 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseConvAsY(
GET_IR_NODE_FROM_SUBGRAPH(elementwise_op, elementwise_op,
elementwise_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise_x, elementwise_x,
GET_IR_NODE_FROM_SUBGRAPH(residual_data, residual_data,
elementwise_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise_out, elementwise_out,
elementwise_pattern);
if (FindFuseOption(*conv_op, *elementwise_op) != FUSE_MKLDNN) return;
if (!IsReachable(g, elementwise_x, conv_output)) return;
if (!IsReachable(g, residual_data, conv_output)) return;
if (HasFusedActivation(conv_op)) return;
if (!IsCompat(subgraph, g)) {
......@@ -246,28 +117,29 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseConvAsY(
return;
}
conv_op->Op()->SetInput("ResidualData", {elementwise_x->Name()});
conv_op->Op()->SetInput("ResidualData", {residual_data->Name()});
conv_op->Op()->SetOutput("Output", {elementwise_out->Name()});
conv_op->Op()->SetAttr("fuse_residual_connection", true);
GraphSafeRemoveNodes(g, {conv_output, elementwise_op});
IR_NODE_LINK_TO(elementwise_x, conv_op);
IR_NODE_LINK_TO(residual_data, conv_op);
IR_NODE_LINK_TO(conv_op, elementwise_out);
found_conv_as_y_count++;
found_conv_count++;
};
gpd(graph_with_stats.first, handler);
if (!Has("disable_logs") || !Get<bool>("disable_logs")) {
std::stringstream msg_ss;
msg_ss << "--- Fused " << found_conv_as_y_count
<< " conv (as y) + elementwise_add patterns";
std::string fusionMode = as_x ? "x" : "y";
msg_ss << "--- Fused " << found_conv_count << " conv (as " << fusionMode
<< ") + elementwise_add patterns";
paddle::string::PrettyLogDetail(msg_ss.str().c_str());
}
return std::make_pair(graph_with_stats.first,
found_conv_as_y_count + graph_with_stats.second);
found_conv_count + graph_with_stats.second);
}
GraphWithStats ResidualConnectionMKLDNNFusePass::FuseProjectionConv(
......@@ -308,7 +180,7 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseProjectionConv(
if (!IsCompat(subgraph, g)) {
LOG(WARNING)
<< "conv_elementwise_add_mkldnn_fuse_pass in op compat failed.";
<< "op compat for conv_elementwise_add_mkldnn_fuse_pass failed.";
return;
}
......@@ -361,8 +233,8 @@ void ResidualConnectionMKLDNNFusePass::ApplyImpl(ir::Graph* graph) const {
FusePassBase::Init(name_scope_, graph);
auto graph_with_stats =
FuseProjectionConv(name_scope_, std::make_pair(graph, 0));
graph_with_stats = FuseConvAsX(name_scope_, graph_with_stats);
graph_with_stats = FuseConvAsY(name_scope_, graph_with_stats);
graph_with_stats = FuseConv(name_scope_, graph_with_stats, true);
graph_with_stats = FuseConv(name_scope_, graph_with_stats, false);
AddStatis(graph_with_stats.second);
}
......
......@@ -14,30 +14,20 @@
#pragma once
#include <memory>
#include <string>
#include <tuple>
#include <utility>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include <boost/optional.hpp>
namespace paddle {
namespace framework {
namespace ir {
using GraphWithStats = std::pair<ir::Graph*, int>;
bool IsReachable(ir::Graph* graph, Node* from, Node* to);
class ResidualConnectionMKLDNNFusePass : public FusePassBase {
private:
GraphWithStats FuseConvAsX(const std::string& name_scope,
const GraphWithStats& graph_with_stats) const;
GraphWithStats FuseConvAsY(const std::string& name_scope,
const GraphWithStats& graph_with_stats) const;
GraphWithStats FuseConv(const std::string& name_scope,
const GraphWithStats& graph_with_stats,
bool as_x) const;
GraphWithStats FuseProjectionConv(
const std::string& name_scope,
const GraphWithStats& graph_with_stats) const;
......
......@@ -26,7 +26,7 @@ import hypothesis.strategies as st
# the two inputs of elementwise_add are tensor
class TestConvElementwiseAddMkldnnFusePass1(PassAutoScanTest):
class TestConvElementwiseAddMkldnnFusePass(PassAutoScanTest):
def is_program_valid(self, program_config: ProgramConfig) -> bool:
attrs = [
program_config.ops[i].attrs
......@@ -125,139 +125,5 @@ class TestConvElementwiseAddMkldnnFusePass1(PassAutoScanTest):
quant=False, passes=["conv_elementwise_add_mkldnn_fuse_pass"])
'''
class TestConvElementwiseAddMkldnnFusePass(PassAutoScanTest):
def is_program_valid(self, program_config: ProgramConfig) -> bool:
attrs = [
program_config.ops[i].attrs
for i in range(len(program_config.ops))
]
if "elementwise_weight" in program_config.weights:
if program_config.weights["elementwise_weight"].shape[0] == program_config.inputs["input_data1"].shape[1]:
if attrs[2]['axis'] != 1:
return False
if program_config.weights["elementwise_weight"].shape[0] == program_config.inputs["input_data1"].shape[3]:
if attrs[2]['axis'] != -1:
return False
return True
def sample_program_config(self, draw):
data_format = draw(st.sampled_from(["NCHW", "NHWC"]))
dilations = draw(st.sampled_from([[1, 1], [2, 2], [1, 2]]))
padding_algorithm = draw(st.sampled_from(["EXPLICIT", "SAME", "VALID"]))
groups = draw(st.sampled_from([1, 2, 4]))
paddings = draw(st.sampled_from([[0, 3], [1, 1], [1, 2, 3, 4]]))
strides = draw(st.sampled_from([[1, 1], [2, 2], [1, 2]]))
axis = draw(st.sampled_from([-1, 0, 1]))
batch_size = draw(st.integers(min_value=1, max_value=4))
def generate_input1():
if data_format == "NCHW":
return np.random.random(
[batch_size, 48, 64, 64]).astype(np.float32)
else:
return np.random.random(
[batch_size, 64, 64, 48]).astype(np.float32)
def generate_weight1():
return np.random.random(
[48, int(48 / groups), 3, 3]).astype(np.float32)
def compute_out_shape(padding_alg):
import paddle
import paddle.nn as nn
x_var = paddle.uniform(
(batch_size, 48, 64, 64), dtype='float32', min=-1., max=1.)
if padding_alg == "EXPLICIT":
conv = nn.Conv2D(48, 48, (3, 3), strides, paddings, dilations,
1)
else:
conv = nn.Conv2D(48, 48, (3, 3), strides, padding_alg,
dilations, 1)
y_var = conv(x_var)
return y_var.shape
def generate_weight2():
return np.random.random([48]).astype(np.float32)
if compute_out_shape(padding_algorithm) != (batch_size, 48, 64, 64):
axis = 1
relu_op = OpConfig(
type="relu",
inputs={"X": ["input_data1"]},
outputs={"Out": ["sigmoid_out"]},
attrs={})
conv2d_op = OpConfig(
type="conv2d",
inputs={"Input": ["sigmoid_out"],
"Filter": ["conv_weight"]},
outputs={"Output": ["conv_output"]},
attrs={
"data_format": data_format,
"dilations": dilations,
"padding_algorithm": padding_algorithm,
"groups": groups,
"paddings": paddings,
"strides": strides
})
if axis == 0:
elt_op = OpConfig(
type="elementwise_add",
inputs={"X": ["input_data1"],
"Y": ["conv_output"]},
outputs={"Out": ["elementwise_output"]},
attrs={'axis': axis})
else:
elt_op = OpConfig(
type="elementwise_add",
inputs={"X": ["conv_output"],
"Y": ["elementwise_weight"]},
outputs={"Out": ["elementwise_output"]},
attrs={'axis': axis})
model_net = [relu_op, conv2d_op, elt_op]
if axis == 0:
program_config = ProgramConfig(
ops=model_net,
weights={
"conv_weight":
TensorConfig(data_gen=partial(generate_weight1))
},
inputs={
"input_data1":
TensorConfig(data_gen=partial(generate_input1))
},
outputs=["elementwise_output"])
else:
program_config = ProgramConfig(
ops=model_net,
weights={
"conv_weight":
TensorConfig(data_gen=partial(generate_weight1)),
"elementwise_weight":
TensorConfig(data_gen=partial(generate_weight2))
},
inputs={
"input_data1":
TensorConfig(data_gen=partial(generate_input1))
},
outputs=["elementwise_output"])
return program_config
def sample_predictor_configs(self, program_config):
config = self.create_inference_config(use_mkldnn=True)
yield config, ["relu", "conv2d"], (1e-5, 1e-5)
def test(self):
self.run_and_statis(
quant=False, passes=["conv_elementwise_add_mkldnn_fuse_pass"])
'''
if __name__ == "__main__":
unittest.main()