未验证 提交 fc5b3a99 编写于 作者: 王明冬 提交者: GitHub

add the fc fuse example for pass enhance, test=develop (#33250)

上级 200d57c7
......@@ -27,7 +27,12 @@ add_subdirectory(fleet)
add_subdirectory(io)
#ddim lib
proto_library(framework_proto SRCS framework.proto)
proto_library(op_def_proto SRCS op_def.proto)
set(OP_DEF_FOLDER "${PADDLE_SOURCE_DIR}/paddle/fluid/operators/compat/")
configure_file("op_def_api.h.in" "op_def_api.h")
cc_library(op_def_api SRCS op_def_api.cc DEPS op_def_proto)
proto_library(heter_service_proto SRCS heter_service.proto)
proto_library(data_feed_proto SRCS data_feed.proto)
proto_library(trainer_desc_proto SRCS trainer_desc.proto DEPS framework_proto
......
......@@ -50,7 +50,7 @@ if (WITH_TESTING)
endif(WITH_TESTING)
cc_library(graph_pattern_detector SRCS graph_pattern_detector.cc DEPS ${GRAPH_PATTERN_DETECTOR_DEPS})
cc_library(op_compat_sensible_pass SRCS op_compat_sensible_pass.cc DEPS graph_pattern_detector)
cc_library(op_compat_sensible_pass SRCS op_compat_sensible_pass.cc DEPS graph_pattern_detector op_def_api)
cc_library(subgraph_detector SRCS subgraph_detector.cc DEPS graph_pattern_detector executor)
cc_library(fuse_pass_base SRCS fuse_pass_base.cc DEPS op_compat_sensible_pass)
cc_library(placement_pass_base SRCS placement_pass_base.cc DEPS pass)
......
......@@ -13,8 +13,8 @@
// limitations under the License.
#include "paddle/fluid/framework/ir/fc_fuse_pass.h"
#include <string>
#include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/enforce.h"
......@@ -23,6 +23,65 @@ namespace paddle {
namespace framework {
namespace ir {
FCFusePass::FCFusePass() {
AddOpCompat(OpCompat("mul"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("x_num_col_dims")
.IsNumGE(1)
.End()
.AddAttr("y_num_col_dims")
.End();
AddOpCompat(OpCompat("elementwise_add"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("axis")
.End();
AddOpCompat(OpCompat("relu"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End();
AddOpCompat(OpCompat("fc"))
.AddInput("Input")
.IsTensor()
.End()
.AddInput("W")
.IsTensor()
.End()
.AddInput("Bias")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("in_num_col_dims")
.IsNumGE(1)
.End()
.AddAttr("activation_type")
.IsStringIn({"relu", ""})
.End();
}
void FCFusePass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
......@@ -52,6 +111,10 @@ int FCFusePass::ApplyFCPattern(Graph* graph, bool with_relu) const {
LOG(WARNING) << "The subgraph is empty.";
return;
}
if (!IsCompat(subgraph, g)) {
LOG(WARNING) << "Pass in op compat failed.";
return;
}
VLOG(4) << "handle FC fuse";
GET_IR_NODE_FROM_SUBGRAPH(w, w, fc_pattern);
......@@ -159,6 +222,11 @@ int FCFusePass::ApplyFCPattern(Graph* graph, bool with_relu) const {
}
desc.Flush();
if (!IsCompat(desc)) {
LOG(WARNING) << "Fc fuse pass in out fc op compat failed.";
return;
}
auto fc_node = g->CreateOpNode(&desc); // OpDesc will be copied.
if (with_relu) {
GraphSafeRemoveNodes(
......
......@@ -30,6 +30,7 @@ class Graph;
class FCFusePass : public FusePassBase {
public:
FCFusePass();
virtual ~FCFusePass() {}
protected:
......
......@@ -12,10 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <memory>
#include "paddle/fluid/framework/ir/op_compat_sensible_pass.h"
#include <memory>
#include <mutex>
#include <unordered_map>
#include "paddle/fluid/framework/op_def_api.h"
#include "paddle/fluid/framework/op_info.h"
namespace paddle {
namespace framework {
namespace ir {
......@@ -50,18 +53,17 @@ AttrCompat& AttrCompat::IsIntIn(const std::set<int>& candidates) {
return *this;
}
//! Todo: append the definition.
AttrCompat& AttrCompat::IsLeftDefault() {
const std::string& op_name = op_compat_->Name();
if (!OpInfoMap::Instance().Has(op_name)) {
VLOG(3) << "Op (" << op_name << ") is not registered!";
LOG(WARNING) << "Op (" << op_name << ") is not registered!";
conditions_.emplace_back([](const Attribute& attr) { return false; });
return *this;
}
const OpInfo& op_info = OpInfoMap::Instance().Get(op_name);
const AttributeMap attrs = op_info.Checker()->GetAttrsDefaultValuesMap();
if (attrs.find(attr_name_) == attrs.end()) {
VLOG(3) << "Op (" << op_name << ") has no default attr:" << attr_name_;
LOG(WARNING) << "Op (" << op_name << ") has no default attr:" << attr_name_;
conditions_.emplace_back([](const Attribute& attr) { return false; });
} else {
Attribute default_attr = attrs.at(attr_name_);
......@@ -77,6 +79,10 @@ bool AttrCompat::operator()(const OpDesc& op_desc) {
return true;
}
if (!op_desc.HasAttr(attr_name_)) {
if (!optional_) {
LOG(WARNING) << "The non-optional Attr(" << attr_name_ << ") of Op ("
<< op_compat_->Name() << ") not find ! ";
}
return optional_;
}
const Attribute attr = op_desc.GetAttr(attr_name_);
......@@ -149,19 +155,35 @@ InputOrOutputCompat& OpCompat::AddOutput(const std::string& name) {
}
bool OpCompat::Judge(const OpDesc& op_desc) {
if (is_first_judge_) {
is_first_judge_ = false;
const proto::OpDef& op_def = GetOpDef(op_name_);
if (op_def.has_extra()) {
for (const proto::OpDef_AttrDef& attr : op_def.extra().attrs()) {
extra_attrs_.emplace(attr.name());
}
}
}
for (auto& attr_map : op_desc.GetAttrMap()) {
if (attr_compats_.find(attr_map.first) == attr_compats_.end()) {
if (extra_attrs_.find(attr_map.first) != extra_attrs_.end()) {
continue;
}
if (!AttrCompat(attr_map.first, this).IsLeftDefault()(op_desc)) {
VLOG(3) << "The Attr(" << attr_map.first << ") of Op (" << op_name_
<< ") not reigistered in OpCompat, not equal to default value!";
LOG(WARNING)
<< "The Attr(" << attr_map.first << ") of Op (" << op_name_
<< ") not reigistered in OpCompat, not in extra attribute, not "
"equal to default value!";
return false;
}
}
}
for (auto& attr_compat : attr_compats_) {
if (!attr_compat.second(op_desc)) {
VLOG(3) << " Check the Attr(" << attr_compat.first << ") of Op("
<< op_name_ << ") failed!";
LOG(WARNING) << " Check the Attr(" << attr_compat.first << ") of Op("
<< op_name_ << ") failed!";
return false;
}
}
......@@ -170,8 +192,8 @@ bool OpCompat::Judge(const OpDesc& op_desc) {
for (auto& input_desc : inputs_map) {
if (input_compats_.find(input_desc.first) == input_compats_.end()) {
if (!input_desc.second.empty()) {
VLOG(3) << "The Input (" << input_desc.first << ") of Operator ("
<< op_name_ << ") not reigistered in OpCompat!";
LOG(WARNING) << "The Input (" << input_desc.first << ") of Operator ("
<< op_name_ << ") not reigistered in OpCompat!";
return false;
}
}
......@@ -179,14 +201,15 @@ bool OpCompat::Judge(const OpDesc& op_desc) {
for (auto& input_val : input_compats_) {
if (inputs_map.find(input_val.first) == inputs_map.end()) {
if (!input_val.second.Optional()) {
VLOG(3) << "The No optional Input (" << input_val.first
<< ") of Operator (" << op_name_ << ") not find in op_desc!";
LOG(WARNING) << "The No optional Input (" << input_val.first
<< ") of Operator (" << op_name_
<< ") not find in op_desc!";
return false;
}
} else {
if (!input_val.second(inputs_map.at(input_val.first))) {
VLOG(3) << "The Input (" << input_val.first << ") of Operator ("
<< op_name_ << ") compat check failed!";
LOG(WARNING) << "The Input (" << input_val.first << ") of Operator ("
<< op_name_ << ") compat check failed!";
return false;
}
}
......@@ -196,8 +219,8 @@ bool OpCompat::Judge(const OpDesc& op_desc) {
for (auto& output_desc : outputs_map) {
if (output_compats_.find(output_desc.first) == output_compats_.end()) {
if (!output_desc.second.empty()) {
VLOG(3) << "The Output (" << output_desc.first << ") of Operator ("
<< op_name_ << ") not reigistered in OpCompat!";
LOG(WARNING) << "The Output (" << output_desc.first << ") of Operator ("
<< op_name_ << ") not reigistered in OpCompat!";
return false;
}
}
......@@ -205,14 +228,15 @@ bool OpCompat::Judge(const OpDesc& op_desc) {
for (auto& output_val : output_compats_) {
if (outputs_map.find(output_val.first) == outputs_map.end()) {
if (!output_val.second.Optional()) {
VLOG(3) << "The No optional Output (" << output_val.first
<< ") of Operator (" << op_name_ << ") not find in op_desc!";
LOG(WARNING) << "The No optional Output (" << output_val.first
<< ") of Operator (" << op_name_
<< ") not find in op_desc!";
return false;
}
} else {
if (!output_val.second(outputs_map.at(output_val.first))) {
VLOG(3) << "The Output (" << output_val.first << ") of Operator ("
<< op_name_ << ") compat check failed!";
LOG(WARNING) << "The Output (" << output_val.first << ") of Operator ("
<< op_name_ << ") compat check failed!";
return false;
}
}
......
......@@ -140,6 +140,8 @@ class OpCompat {
std::unordered_map<std::string, AttrCompat> attr_compats_;
std::unordered_map<std::string, InputOrOutputCompat> input_compats_;
std::unordered_map<std::string, InputOrOutputCompat> output_compats_;
std::unordered_set<std::string> extra_attrs_;
bool is_first_judge_ = true;
};
/**
......@@ -203,6 +205,7 @@ class OpCompatSensiblePass : public Pass {
if (!node_pair.second->IsOp()) continue;
auto op_type = node_pair.second->Op()->Type();
if (!op_compat_judgers_.count(op_type)) {
LOG(WARNING) << op_type << "compat not registered!";
return false;
}
auto& judger = *op_compat_judgers_.at(op_type);
......
......@@ -27,7 +27,6 @@ TEST(OpCompatSensiblePass, compatOp) {
compat.AddAttr("in_num_col_dims")
.IsIntIn({1, 2})
.IsNumLE(1)
.IsLeftDefault()
.End()
.AddAttr("activation_type")
.IsStringIn({"tanh", "sigmoid"})
......@@ -68,7 +67,7 @@ TEST(OpCompatSensiblePass, compatOp) {
fc_op.SetOutput("Out", std::vector<std::string>{"test_output"});
EXPECT_STREQ(compat.Name().c_str(), "fc");
EXPECT_FALSE(compat.Judge(fc_op));
EXPECT_TRUE(compat.Judge(fc_op));
}
TEST(OpCompatSensiblePass, compatOpAttribute) {
......
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#if defined _WIN32 || defined __APPLE__
#else
#define _LINUX
#endif
#include "paddle/fluid/framework/op_def_api.h"
#include <fstream>
#include <mutex>
#include <string>
#include <unordered_map>
#ifdef _LINUX
#include <stdio_ext.h>
#include <sys/mman.h>
#include <sys/stat.h>
#endif
#include <google/protobuf/io/zero_copy_stream_impl.h>
#include <google/protobuf/text_format.h>
#include "glog/logging.h"
#include "io/fs.h"
#include "paddle/fluid/framework/op_def.pb.h"
namespace paddle {
namespace framework {
const proto::OpDef& GetOpDef(const std::string& op_name) {
static std::unordered_map<std::string, proto::OpDef> ops_definition;
static std::mutex mtx;
if (ops_definition.find(op_name) == ops_definition.end()) {
std::lock_guard<std::mutex> lk(mtx);
if (ops_definition.find(op_name) == ops_definition.end()) {
proto::OpDef op_def;
std::string op_path = OP_DEF_FOLDER + op_name + ".pbtxt";
int fd = open(op_path.c_str(), O_RDONLY);
if (fd == -1) {
LOG(WARNING) << op_path << " open failed!";
} else {
::google::protobuf::io::FileInputStream* input =
new ::google::protobuf::io::FileInputStream(fd);
if (!::google::protobuf::TextFormat::Parse(input, &op_def)) {
LOG(WARNING) << "Failed to parse " << op_path;
}
delete input;
close(fd);
}
ops_definition.emplace(std::make_pair(op_name, std::move(op_def)));
}
}
return ops_definition.at(op_name);
}
} // namespace framework
} // namespace paddle
// the folder of pbtxt with op attribute definition
#pragma once
#include "paddle/fluid/framework/op_def.pb.h"
#define OP_DEF_FOLDER "@OP_DEF_FOLDER@"
namespace paddle {
namespace framework {
const proto::OpDef& GetOpDef(const std::string& op_name);
}
}
type: "elementwise_add"
def {
inputs {
name: "X"
}
inputs {
name: "Y"
}
outputs {
name: "Out"
}
attrs {
name: "axis"
type: INT
}
}
extra {
attrs {
name: "x_data_format"
type: STRING
# no longer to use
}
attrs {
name: "y_data_format"
type: STRING
# no longer to use
}
attrs {
name: "use_quantizer"
type: BOOLEAN
# no longer to use, Use 'mkldnn_data_type' instead.
}
attrs {
name: "use_mkldnn"
type: BOOLEAN
}
attrs {
name: "mkldnn_data_type"
type: STRING
}
attrs {
name: "Scale_x"
type: FLOAT
}
attrs {
name: "Scale_y"
type: FLOAT
}
attrs {
name: "Scale_out"
type: FLOAT
}
attrs {
name: "op_role"
type: INT
}
attrs {
name: "op_role_var"
type: STRINGS
}
attrs {
name: "op_namescope"
type: STRING
}
attrs {
name: "op_callstack"
type: STRINGS
}
attrs {
name: "op_device"
type: STRING
}
}
type: "fc"
def {
inputs {
name: "Input"
}
inputs {
name: "W"
}
inputs {
name: "Bias"
}
outputs {
name: "Out"
}
attrs {
name: "in_num_col_dims"
type: INT
}
attrs {
name: "activation_type"
type: STRING
}
}
extra {
attrs {
name: "use_mkldnn"
type: BOOLEAN
}
attrs {
name: "padding_weights"
type: BOOLEAN
}
attrs {
name: "@ALL_KERNELS_MUST_COMPUTE_RUNTIME_SHAPE@"
type: BOOLEAN
}
attrs {
name: "use_quantizer"
type: BOOLEAN
}
attrs {
name: "mkldnn_data_type"
type: STRING
}
attrs {
name: "weight_scale"
type: FLOATS
}
attrs {
name: "Input_scale"
type: FLOAT
}
attrs {
name: "out_scale"
type: FLOAT
}
attrs {
name: "out_threshold"
type: FLOAT
}
attrs {
name: "force_fp32_output"
type: BOOLEAN
}
attrs {
name: "enable_int8"
type: BOOLEAN
}
attrs {
name: "use_fc_padding"
type: BOOLEAN
}
attrs {
name: "use_gpu"
type: BOOLEAN
}
attrs {
name: "op_role"
type: INT
}
attrs {
name: "op_role_var"
type: STRINGS
}
attrs {
name: "op_namescope"
type: STRING
}
attrs {
name: "op_callstack"
type: STRINGS
}
attrs {
name: "op_device"
type: STRING
}
}
type: "mul"
def {
inputs {
name: "X"
}
inputs {
name: "Y"
}
outputs {
name: "Out"
}
attrs {
name: "x_num_col_dims"
type: INT
}
attrs {
name: "y_num_col_dims"
type: INT
}
}
extra {
attrs {
name: "skip_quant"
type: BOOLEAN
}
attrs {
name: "use_mkldnn"
type: BOOLEAN
}
attrs {
name: "scale_x"
type: FLOAT
}
attrs {
name: "scale_y"
type: FLOATS
}
attrs {
name: "scale_out"
type: FLOAT
}
attrs {
name: "force_fp32_output"
type: BOOLEAN
}
attrs {
name: "enable_int8"
type: BOOLEAN
}
attrs {
name: "X_scale"
type: FLOAT
}
attrs {
name: "weight_scale"
type: FLOAT
}
attrs {
name: "out_scale"
type: FLOAT
}
attrs {
name: "out_threshold"
type: FLOAT
}
attrs {
name: "op_role"
type: INT
}
attrs {
name: "op_role_var"
type: STRINGS
}
attrs {
name: "op_namescope"
type: STRING
}
attrs {
name: "op_callstack"
type: STRINGS
}
attrs {
name: "op_device"
type: STRING
}
}
type: "relu"
def {
inputs {
name: "X"
}
outputs {
name: "Out"
}
}
extra {
attrs {
name: "use_mkldnn"
type: BOOLEAN
}
attrs {
name: "use_cudnn"
type: BOOLEAN
}
attrs {
name: "op_role"
type: INT
}
attrs {
name: "op_role_var"
type: STRINGS
}
attrs {
name: "op_namescope"
type: STRING
}
attrs {
name: "op_callstack"
type: STRINGS
}
attrs {
name: "op_device"
type: STRING
}
attrs {
name: "is_test"
type: BOOLEAN
}
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册