未验证 提交 cf3ddd3b 编写于 作者: T TeslaZhao 提交者: GitHub

Pass compat of conv_transpose_bias_mkldnn_fuse_pass (#33708)

上级 18284261
......@@ -25,6 +25,102 @@ namespace paddle {
namespace framework {
namespace ir {
// Registers OpCompat rules for the conv2d + elementwise_add pair.  IsCompat()
// consults these in ApplyImpl() before fusing, so graphs whose ops carry
// unexpected attribute values are skipped instead of being fused incorrectly.
ConvBiasFusePass::ConvBiasFusePass() {
// conv2d: Input and Filter must be tensors; Bias is optional because this
// pass exists precisely to attach a bias to conv ops that may lack one.
AddOpCompat(OpCompat("conv2d"))
.AddInput("Input")
.IsTensor()
.End()
.AddInput("Filter")
.IsTensor()
.End()
.AddInput("Bias")
.IsTensor()
.IsOptional()
.End()
.AddOutput("Output")
.IsTensor()
.End()
// strides / paddings / dilations: only presence is checked, any value passes.
.AddAttr("strides")
.End()
.AddAttr("paddings")
.End()
.AddAttr("padding_algorithm")
.IsStringIn({"EXPLICIT", "SAME", "VALID"})
.End()
.AddAttr("groups")
.IsNumGE(1)
.End()
.AddAttr("dilations")
.End()
.AddAttr("data_format")
.IsStringIn({"NCHW", "NHWC"})
.End();
// elementwise_add supplies the bias being folded in: Y is the bias tensor
// added to the conv output X.
// NOTE(review): axis is pinned to -1 here; graphs that spell the channel
// axis explicitly (axis == 1) would fail this check and be skipped -- confirm
// whether such models occur in practice.
AddOpCompat(OpCompat("elementwise_add"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("axis")
.IsNumEQ(-1)
.End();
}
// Registers OpCompat rules for the conv2d_transpose + elementwise_add pair.
// IsCompat() consults these in ApplyImpl() before fusing, so ops with
// unexpected signatures are skipped rather than fused incorrectly.
Conv2DTransposeBiasFusePass::Conv2DTransposeBiasFusePass() {
  AddOpCompat(OpCompat("conv2d_transpose"))
      .AddInput("Input")
      .IsTensor()
      .End()
      .AddInput("Filter")
      .IsTensor()
      .End()
      .AddInput("Bias")
      .IsTensor()
      // Bias is an optional input of conv2d_transpose; requiring it would
      // reject the very ops this pass is meant to handle (the bias comes
      // from the trailing elementwise_add, not from the conv op itself).
      .IsOptional()
      .End()
      .AddOutput("Output")
      .IsTensor()
      .End()
      // output_padding defaults to an empty std::vector<int> and may be
      // absent on ops serialized before the attribute was introduced.
      .AddAttr("output_padding")
      .IsType<std::vector<int>>()
      .IsOptional()
      .End()
      // BUGFIX: output_size is a std::vector<int> attribute (default: empty),
      // not a scalar, so the previous IsNumGE(1) scalar check could never
      // validate it and the pass silently skipped valid graphs.  Check the
      // type instead and allow the attribute to be absent.
      .AddAttr("output_size")
      .IsType<std::vector<int>>()
      .IsOptional()
      .End()
      .AddAttr("groups")
      .IsNumGE(1)
      .End()
      // dilations / strides / paddings: only presence is checked.
      .AddAttr("dilations")
      .End()
      .AddAttr("strides")
      .End()
      .AddAttr("paddings")
      .End()
      .AddAttr("padding_algorithm")
      .IsStringIn({"EXPLICIT", "SAME", "VALID"})
      .End()
      .AddAttr("data_format")
      .IsStringIn({"NCHW", "NHWC"})
      .End();
  // elementwise_add supplies the bias being folded in: Y is the bias tensor
  // added to the transposed-conv output X.
  AddOpCompat(OpCompat("elementwise_add"))
      .AddInput("X")
      .IsTensor()
      .End()
      .AddInput("Y")
      .IsTensor()
      .End()
      .AddOutput("Out")
      .IsTensor()
      .End()
      .AddAttr("axis")
      .IsNumEQ(-1)
      .End();
}
template <typename BinaryOperation>
LoDTensor tensor_apply_eltwise(const LoDTensor& vec_a, const LoDTensor& vec_b,
BinaryOperation f) {
......@@ -80,6 +176,12 @@ void ConvBiasFusePass::ApplyImpl(ir::Graph* graph) const {
subgraph.count(conv_input), 0,
platform::errors::NotFound("Detector did not find conv input."));
// check compat
if (!IsCompat(subgraph, g)) {
VLOG(3) << "Pass in op compat failed.";
return;
}
// check if fuse can be done and if MKL-DNN should be used
FuseOptions fuse_option = FindFuseOption(*conv, *eltwise);
if (fuse_option == DO_NOT_FUSE || fuse_option == FUSE_NATIVE) {
......
......@@ -29,6 +29,7 @@ class Graph;
class ConvBiasFusePass : public FusePassBase {
public:
ConvBiasFusePass();
virtual ~ConvBiasFusePass() {}
virtual std::string type() const { return "conv2d"; }
......@@ -41,6 +42,7 @@ class ConvBiasFusePass : public FusePassBase {
*/
// Variant of ConvBiasFusePass that fuses an elementwise_add bias into
// conv2d_transpose ops; its constructor registers the conv2d_transpose
// OpCompat checks.
class Conv2DTransposeBiasFusePass : public ConvBiasFusePass {
public:
Conv2DTransposeBiasFusePass();
// Tells the shared fuse logic which conv op type to match in the graph.
std::string type() const override { return "conv2d_transpose"; }
};
......
......@@ -31,8 +31,19 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
auto* op = prog->MutableBlock(0)->AppendOp();
op->SetType(type);
if (type == "conv2d") {
const std::vector<int> strides({1, 1});
const std::vector<int> paddings({0, 0});
const std::vector<int> dilations({1, 1});
op->SetAttr("use_mkldnn", true);
op->SetAttr("name", name);
op->SetAttr("strides", strides);
op->SetAttr("groups", 1);
op->SetAttr("paddings", paddings);
op->SetAttr("padding_algorithm", std::string("EXPLICIT"));
op->SetAttr("dilations", dilations);
op->SetAttr("data_format", std::string("NCHW"));
op->SetOutput("Output", outputs);
op->SetInput("Input", {inputs[0]});
op->SetInput("Filter", {inputs[1]});
if (inputs.size() > 2)
......@@ -41,10 +52,11 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
op->SetInput("Bias", {});
} else if (type == "elementwise_add") {
op->SetAttr("use_mkldnn", true);
op->SetAttr("axis", -1);
op->SetInput("X", {inputs[0]});
op->SetInput("Y", {inputs[1]});
op->SetOutput("Out", outputs);
}
op->SetOutput("Out", outputs);
op->SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
static_cast<int>(OpRole::kForward));
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册