Unverified commit f55f9d79, authored by wz1qqx, committed by GitHub

[XPU]Add act add fuse (#53965)

Parent 75fc4bf0
......@@ -248,6 +248,8 @@ if(WITH_XPU)
pass_library(stack_fuse_pass inference DIR xpu DEPS ${XPU_PASS_DEPS})
pass_library(fused_multi_transformer_cachekv_layout_trans_pass inference DIR
xpu DEPS ${XPU_PASS_DEPS})
pass_library(add_activation_xpu_fuse_pass inference DIR xpu DEPS
${XPU_PASS_DEPS})
endif()
cc_library(
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <string>
#include "glog/logging.h"
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/ir/xpu/pass_utils.h"
#include "paddle/fluid/framework/ir/xpu/quant_utils.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/enforce.h"
namespace phi {
class DenseTensor;
} // namespace phi
namespace paddle {
namespace framework {
class Scope;
} // namespace framework
} // namespace paddle
namespace paddle {
namespace framework {
namespace ir {
namespace patterns {
/*
fuse the elementwise_add + activation block into an add_act_xpu op
For example:
graph:
                      ele_x
                        |
                        |
      elementwise_add ----- ele_y
                        |
                        |
                       act
                        |
                        |
                     act_out
------------------------------------------------------
After the pass is applied:
                      Input
                        |      ele_y
                        |     /
                        |    /
  Input_max ------ add_act_xpu ------ ele_y_max
                        |    \
                        |     \
                        |      OutputMax
                      Output
*/
struct AddActXPUPattern : public PatternBase {
AddActXPUPattern(PDPattern* pattern,
const std::string& name_scope,
const std::string& act_type);
// declare operator node's name
PATTERN_DECL_NODE(ele_add);
PATTERN_DECL_NODE(act);
// declare variable node's name
PATTERN_DECL_NODE(ele_x);
PATTERN_DECL_NODE(ele_y);
PATTERN_DECL_NODE(ele_out);
PATTERN_DECL_NODE(act_out);
private:
std::string act_type_;
};
AddActXPUPattern::AddActXPUPattern(PDPattern* pattern,
const std::string& name_scope,
const std::string& act_type)
: PatternBase(pattern, name_scope, name_scope), act_type_(act_type) {
auto ele_add =
pattern->NewNode(ele_add_repr())->assert_is_op("elementwise_add");
auto ele_x = pattern->NewNode(ele_x_repr())
->assert_is_op_input("elementwise_add", "X")
->assert_var_not_persistable()
->AsInput();
auto ele_y = pattern->NewNode(ele_y_repr())
->assert_is_op_input("elementwise_add", "Y")
->assert_var_not_persistable()
->AsInput();
auto ele_out = pattern->NewNode(ele_out_repr())
->assert_is_op_output("elementwise_add", "Out")
->assert_has_n_outputs(1);
ele_add->LinksFrom({ele_x, ele_y}).LinksTo({ele_out});
ele_out->assert_is_op_input(act_type_, "X");
auto act = pattern->NewNode(act_repr())->assert_is_op(act_type_);
auto act_out =
pattern->NewNode(act_out_repr())->assert_is_op_output(act_type_, "Out");
act->LinksFrom({ele_out}).LinksTo({act_out});
}
} // namespace patterns
class AddActXPUFusePass : public FusePassBase {
protected:
void ApplyImpl(ir::Graph* graph) const override;
private:
int ApplyImpl(ir::Graph* graph, const std::string& act_type) const;
const std::string name_scope_{"add_activation_xpu_fuse_pass"};
};
void AddActXPUFusePass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::PreconditionNotMet("graph should not be null."));
Init(name_scope_, graph);
int found_subgraph_count = 0;
for (auto act_type : {"relu", "gelu"}) {
found_subgraph_count += ApplyImpl(graph, act_type);
}
AddStatis(found_subgraph_count);
}
int AddActXPUFusePass::ApplyImpl(ir::Graph* graph,
const std::string& act_type) const {
GraphPatternDetector gpd;
patterns::AddActXPUPattern pattern(
gpd.mutable_pattern(), name_scope_, act_type);
int found_subgraph_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* graph) {
VLOG(4) << "handle AddActXPUFusePass fuse";
/* declare operator node's name */
GET_IR_NODE(ele_add);
GET_IR_NODE(act);
/* declare variable node's name*/
GET_IR_NODE(ele_x);
GET_IR_NODE(ele_y);
GET_IR_NODE(ele_out);
GET_IR_NODE(act_out);
auto* block = ele_add->Op()->Block();
auto* scope = param_scope();
PADDLE_ENFORCE_NOT_NULL(
scope, platform::errors::InvalidArgument("Scope cannot be nullptr."));
    std::string fused_op_out_name = act_out->Name();
std::string fused_op_out_max_name = fused_op_out_name + "_max";
VarDesc fused_op_out_max_desc(fused_op_out_max_name);
Node* fused_op_out_max = graph->CreateVarNode(&fused_op_out_max_desc);
// Generate add_act fused op
framework::OpDesc fused_op_desc(block);
fused_op_desc.SetType("add_act_xpu");
// set attrs for fused op
fused_op_desc.SetAttr("act_type", ConvertActivationType(act_type));
fused_op_desc.SetInput("x", {ele_x->Name()});
fused_op_desc.SetInput("y", {ele_y->Name()});
fused_op_desc.SetOutput("out", {fused_op_out_name});
fused_op_desc.SetOutput("out_max", {fused_op_out_max_name});
// relink fused op
auto* fused_op = graph->CreateOpNode(&fused_op_desc);
IR_NODE_LINK_TO(ele_x, fused_op);
IR_NODE_LINK_TO(ele_y, fused_op);
IR_NODE_LINK_TO(fused_op, act_out);
IR_NODE_LINK_TO(fused_op, fused_op_out_max);
// delete useless node
std::unordered_set<const Node*> delete_nodes = {ele_add, act, ele_out};
GraphSafeRemoveNodes(graph, delete_nodes);
found_subgraph_count++;
};
gpd(graph, handler);
return found_subgraph_count;
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(add_activation_xpu_fuse_pass,
paddle::framework::ir::AddActXPUFusePass);
REGISTER_PASS_CAPABILITY(add_activation_xpu_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination().EQ(
"add_act_xpu", 0));
......@@ -12,13 +12,12 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/xpu/link_xpu_op_max_pass.h"
#include <string>
#include "glog/logging.h"
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/ir/xpu/pass_utils.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/enforce.h"
......@@ -36,165 +35,211 @@ class Scope;
namespace paddle {
namespace framework {
namespace ir {
namespace patterns {
struct LinkAddActPattern : public PatternBase {
LinkAddActPattern(PDPattern* pattern, const std::string& name_scope);
// declare operator node's name
PATTERN_DECL_NODE(fusion_op);
// declare variable node's name
PATTERN_DECL_NODE(x);
PATTERN_DECL_NODE(ele_y);
};
LinkAddActPattern::LinkAddActPattern(PDPattern* pattern,
const std::string& name_scope)
: PatternBase(pattern, name_scope, name_scope) {
auto* fusion_op =
pattern->NewNode(fusion_op_repr())->assert_is_op("add_act_xpu");
auto* x = pattern->NewNode(x_repr())->assert_is_op_input("add_act_xpu", "x");
auto* ele_y =
pattern->NewNode(ele_y_repr())->assert_is_op_input("add_act_xpu", "y");
fusion_op->LinksFrom({x, ele_y});
}
struct FusionXPUOpPattern : public PatternBase {
FusionXPUOpPattern(PDPattern* pattern,
const std::string& name_scope,
const std::string& op_type,
bool with_branch);
struct LinkConv2dPattern : public PatternBase {
LinkConv2dPattern(PDPattern* pattern,
const std::string& name_scope,
bool with_branch);
// declare operator node's name
PATTERN_DECL_NODE(fusion_op);
// declare variable node's name
PATTERN_DECL_NODE(input);
PATTERN_DECL_NODE(x);
PATTERN_DECL_NODE(branch);
private:
std::string op_type_;
bool with_branch_{false};
};
FusionXPUOpPattern::FusionXPUOpPattern(PDPattern* pattern,
const std::string& name_scope,
const std::string& op_type,
bool with_branch)
: PatternBase(pattern, name_scope, name_scope),
op_type_(op_type),
with_branch_(with_branch) {
auto* fusion_op = pattern->NewNode(fusion_op_repr())->assert_is_op(op_type_);
auto* input =
pattern->NewNode(input_repr())->assert_is_op_input(op_type_, "x");
LinkConv2dPattern::LinkConv2dPattern(PDPattern* pattern,
const std::string& name_scope,
bool with_branch)
: PatternBase(pattern, name_scope, name_scope), with_branch_(with_branch) {
auto* fusion_op =
pattern->NewNode(fusion_op_repr())->assert_is_op("conv2d_xpu");
auto* x = pattern->NewNode(x_repr())->assert_is_op_input("conv2d_xpu", "x");
PDNode* branch = nullptr;
if (with_branch_) {
branch =
pattern->NewNode(branch_repr())->assert_is_op_input(op_type_, "branch");
fusion_op->LinksFrom({input, branch});
} else {
fusion_op->LinksFrom({input});
branch = pattern->NewNode(branch_repr())
->assert_is_op_input("conv2d_xpu", "branch");
fusion_op->LinksFrom({branch});
}
fusion_op->LinksFrom({x});
}
} // namespace patterns
struct LinkFcPattern : public PatternBase {
LinkFcPattern(PDPattern* pattern, const std::string& name_scope);
class LinkXPUOpMaxPass : public FusePassBase {
protected:
void ApplyImpl(ir::Graph* graph) const override;
// declare operator node's name
PATTERN_DECL_NODE(fusion_op);
// declare variable node's name
PATTERN_DECL_NODE(x);
};
private:
void ApplyImpl(ir::Graph* graph,
const std::string& op_type,
bool with_branch) const;
LinkFcPattern::LinkFcPattern(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, name_scope) {
auto* fusion_op = pattern->NewNode(fusion_op_repr())->assert_is_op("fc_xpu");
auto* x = pattern->NewNode(x_repr())->assert_is_op_input("fc_xpu", "x");
const std::string name_scope_{"link_xpu_op_max_pass"};
// ops with x_max/out_max
std::set<std::string> op_types_{"fc_xpu", "conv2d_xpu"};
};
fusion_op->LinksFrom({x});
}
/*
Origin subgraph:
            fusion_xpu_op0
              /       \
              |        |
            out0    out0_max
              |
               \
             fusion_op

Fused subgraph:
            fusion_xpu_op0
              /       \
              |        |
            out0    out0_max
              |        |
               \      /
             fusion_op

Origin subgraph1:
      fusion_xpu_op0      fusion_xpu_op1
        /      \            /      \
        |       |           |       |
      out0   out0_max     out1   out1_max
        |                   |
  (x)    \                 /  (branch)
           fusion_xpu_op2

Fused subgraph1:
      fusion_xpu_op0      fusion_xpu_op1
        /      \            /      \
        |       |           |       |
      out0   out0_max     out1   out1_max
        |       |           |       |
  (x)   \       |(x_max)    |(branch)  /(branch_max)
         \      |           |         /
          \     |           |        /
           \    |           |       /
             fusion_xpu_op2
*/
void LinkXPUOpMaxPass::ApplyImpl(ir::Graph* graph) const {
Init(name_scope_, graph);
for (auto op_type : op_types_) {
for (auto with_branch : {true, false}) {
ApplyImpl(graph, op_type, with_branch);
} // namespace patterns
void LinkXPUOpMaxPass::LinkAddActMax(ir::Graph* graph) const {
GraphPatternDetector gpd;
patterns::LinkAddActPattern pattern(gpd.mutable_pattern(), name_scope_);
int found_subgraph_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* graph) {
VLOG(4) << "handle LinkAddActMax";
/* declare operator node's name */
GET_IR_NODE(fusion_op);
/* declare variable node's name*/
GET_IR_NODE(x);
GET_IR_NODE(ele_y);
auto* fusion_op_desc = fusion_op->Op();
    if (x->inputs.size() > 0 && x->inputs[0]->IsOp() &&
        x->inputs[0]->Op()->HasOutput("out_max")) {
      auto* x_pre_op = x->inputs[0]->Op();
auto preop_max_var_name = x_pre_op->Output("out_max");
for (auto max_node : x->inputs[0]->outputs) {
if (preop_max_var_name[0] == max_node->Name()) {
fusion_op_desc->SetInput("x_max", {max_node->Name()});
IR_NODE_LINK_TO(max_node, fusion_op);
}
}
}
}
    if (ele_y->inputs.size() > 0 && ele_y->inputs[0]->IsOp() &&
        ele_y->inputs[0]->Op()->HasOutput("out_max")) {
      auto* ele_y_pre_op = ele_y->inputs[0]->Op();
auto preop_max_var_name = ele_y_pre_op->Output("out_max");
for (auto max_node : ele_y->inputs[0]->outputs) {
if (preop_max_var_name[0] == max_node->Name()) {
fusion_op_desc->SetInput("y_max", {max_node->Name()});
IR_NODE_LINK_TO(max_node, fusion_op);
}
}
}
found_subgraph_count++;
};
gpd(graph, handler);
AddStatis(found_subgraph_count);
}
void LinkXPUOpMaxPass::ApplyImpl(ir::Graph* graph,
const std::string& op_type,
bool with_branch) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::PreconditionNotMet("graph should not be null."));
void LinkXPUOpMaxPass::LinkConv2dMax(ir::Graph* graph, bool with_branch) const {
GraphPatternDetector gpd;
patterns::FusionXPUOpPattern pattern(
gpd.mutable_pattern(), name_scope_, op_type, with_branch);
patterns::LinkConv2dPattern pattern(
gpd.mutable_pattern(), name_scope_, with_branch);
int found_subgraph_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* graph) {
VLOG(4) << "handle LinkXPUOpMaxPass fuse";
VLOG(4) << "handle LinkConv2dMax";
/* declare operator node's name */
GET_IR_NODE(fusion_op);
GET_IR_NODE(input);
/* declare variable node's name*/
GET_IR_NODE(x);
GET_IR_NODE(branch);
auto* fusion_op_desc = fusion_op->Op();
if (fusion_op_desc->HasAttr("has_branch")) {
bool fusion_op_branch =
PADDLE_GET_CONST(bool, fusion_op_desc->GetAttr("has_branch"));
if (fusion_op_branch != with_branch) {
return;
}
}
if (input->inputs.size() > 0 && input->inputs[0]->IsOp() &&
input->inputs[0]->Op()->HasOutput("out_max")) {
auto input_max_name = input->inputs[0]->Op()->Output("out_max");
for (auto max_node : input->inputs[0]->outputs) {
if (input_max_name[0] == max_node->Name()) {
      if (x->inputs.size() > 0 && x->inputs[0]->IsOp() &&
          x->inputs[0]->Op()->HasOutput("out_max")) {
        auto* x_pre_op = x->inputs[0]->Op();
auto preop_max_var_name = x_pre_op->Output("out_max");
for (auto max_node : x->inputs[0]->outputs) {
if (preop_max_var_name[0] == max_node->Name()) {
fusion_op_desc->SetInput("x_max", {max_node->Name()});
IR_NODE_LINK_TO(max_node, fusion_op);
found_subgraph_count++;
}
}
}
if (with_branch) {
auto* branch_pre_op = branch->inputs[0]->Op();
if (branch->inputs.size() > 0 && branch->inputs[0]->IsOp() &&
branch->inputs[0]->Op()->HasOutput("out_max")) {
auto branch_max_name = branch->inputs[0]->Op()->Output("out_max");
branch_pre_op->HasOutput("out_max")) {
auto preop_max_var_name = branch_pre_op->Output("out_max");
for (auto max_node : branch->inputs[0]->outputs) {
if (branch_max_name[0] == max_node->Name()) {
if (preop_max_var_name[0] == max_node->Name()) {
fusion_op_desc->SetInput("branch_max", {max_node->Name()});
IR_NODE_LINK_TO(max_node, fusion_op);
found_subgraph_count++;
}
}
}
}
found_subgraph_count++;
};
gpd(graph, handler);
AddStatis(found_subgraph_count);
}
void LinkXPUOpMaxPass::LinkFcMax(ir::Graph* graph) const {
GraphPatternDetector gpd;
patterns::LinkFcPattern pattern(gpd.mutable_pattern(), name_scope_);
int found_subgraph_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* graph) {
VLOG(4) << "handle LinkFcMax";
/* declare operator node's name */
GET_IR_NODE(fusion_op);
/* declare variable node's name*/
GET_IR_NODE(x);
auto* fusion_op_desc = fusion_op->Op();
    if (x->inputs.size() > 0 && x->inputs[0]->IsOp() &&
        x->inputs[0]->Op()->HasOutput("out_max")) {
      auto* x_pre_op = x->inputs[0]->Op();
auto preop_max_var_name = x_pre_op->Output("out_max");
for (auto max_node : x->inputs[0]->outputs) {
if (preop_max_var_name[0] == max_node->Name()) {
fusion_op_desc->SetInput("x_max", {max_node->Name()});
IR_NODE_LINK_TO(max_node, fusion_op);
}
}
}
found_subgraph_count++;
};
gpd(graph, handler);
AddStatis(found_subgraph_count);
}
void LinkXPUOpMaxPass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::PreconditionNotMet("graph should not be null."));
Init(name_scope_, graph);
LinkFcMax(graph);
for (auto with_branch : {true, false}) {
LinkConv2dMax(graph, with_branch);
}
LinkAddActMax(graph);
}
} // namespace ir
} // namespace framework
} // namespace paddle
......@@ -203,5 +248,7 @@ REGISTER_PASS(link_xpu_op_max_pass, paddle::framework::ir::LinkXPUOpMaxPass);
REGISTER_PASS_CAPABILITY(link_xpu_op_max_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination().EQ(
"fc_xpu", 0));
paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("fc_xpu", 0)
.EQ("conv2d_xpu", 0)
.EQ("add_act_xpu", 0));
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace phi {
class DenseTensor;
} // namespace phi
namespace paddle {
namespace framework {
class Scope;
} // namespace framework
} // namespace paddle
namespace paddle {
namespace framework {
namespace ir {
class LinkXPUOpMaxPass : public FusePassBase {
protected:
void ApplyImpl(ir::Graph* graph) const override;
private:
  /*
  Origin subgraph:
            fusion_xpu_op0
              /       \
              |        |
            out0    out0_max
              |
               \
              fc_xpu

  Fused subgraph:
            fusion_xpu_op0
              /       \
              |        |
            out0    out0_max
              |        |
               \      /
               fc_xpu
  */
void LinkFcMax(ir::Graph* graph) const;
  /*
  Origin subgraph:
        fusion_xpu_op0      fusion_xpu_op1
          /      \            /      \
          |       |           |       |
        out0   out0_max     out1   out1_max
          |                   |
    (x)    \                 /  (branch)
              conv2d_xpu

  Fused subgraph:
        fusion_xpu_op0      fusion_xpu_op1
          /      \            /      \
          |       |           |       |
        out0   out0_max     out1   out1_max
          |       |           |       |
    (x)   \       |(x_max)    |(branch)  /(branch_max)
           \      |           |         /
            \     |           |        /
             \    |           |       /
                 conv2d_xpu
  */
void LinkConv2dMax(ir::Graph* graph, bool with_branch) const;
  /*
  Origin subgraph:
        fusion_xpu_op0      fusion_xpu_op1
          /      \            /      \
          |       |           |       |
        out0   out0_max     out1   out1_max
          |                   |
    (x)    \                 /  (y)
              add_act_xpu

  Fused subgraph:
        fusion_xpu_op0      fusion_xpu_op1
          /      \            /      \
          |       |           |       |
        out0   out0_max     out1   out1_max
          |       |           |       |
    (x)   \       |(x_max)    |(y)     /(y_max)
           \      |           |       /
            \     |           |      /
             \    |           |     /
                 add_act_xpu
  */
void LinkAddActMax(ir::Graph* graph) const;
const std::string name_scope_{"link_xpu_op_max_pass"};
};
} // namespace ir
} // namespace framework
} // namespace paddle
......@@ -527,6 +527,7 @@ XpuPassStrategy::XpuPassStrategy() : PassStrategy({}) {
"sigmoid_elementmul_fuse_pass",
"fc_xpu_fuse_pass",
"conv2d_xpu_fuse_pass",
"add_activation_xpu_fuse_pass",
"link_xpu_op_max_pass",
"inplace_op_var_pass",
"delete_isolated_node_pass",
......
......@@ -4,6 +4,16 @@
# If an operator has "support_dygraph_mode : true", it supports dygraph mode;
# otherwise the operator can only be used in static mode.
- op : add_act_xpu
args : (Tensor x, Tensor x_max, Tensor y, Tensor y_max, int act_type)
output : Tensor(out), Tensor(out_max)
infer_meta :
func : AddActXPUInferMeta
kernel :
func : add_act_xpu
data_type : x
optional : x_max, y_max
- op : conv2d_xpu
args : (Tensor x, Tensor x_max, Tensor filter, Tensor filter_max, Tensor bias, Tensor branch, Tensor branch_max, int[] paddings, int[] dilations, int[] strides, str padding_algorithm, int groups, bool has_bias, bool has_branch, int act_type, float act_param)
output : Tensor(out), Tensor(out_max)
......
......@@ -22,6 +22,8 @@ namespace xpu {
XPUOpMap& get_kl2_ops() {
// Ops supported by KL2, indexed by op_name, data_type and place
static XPUOpMap s_xpu2_kernels{
{"add_act_xpu",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"abs", XPUKernelSet({phi::DataType::FLOAT32})},
{"abs_grad",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
......
......@@ -19,9 +19,66 @@ limitations under the License. */
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/core/meta_tensor.h"
#include "paddle/phi/kernels/cpu/conv_util.h"
#include "paddle/phi/kernels/funcs/common_shape.h"
namespace phi {
void AddActXPUInferMeta(const MetaTensor& x,
const MetaTensor& x_max,
const MetaTensor& y,
const MetaTensor& y_max,
int act_type,
MetaTensor* out,
MetaTensor* out_max) {
int axis = -1;
if (x.dims() != y.dims()) {
auto x_dims = x.dims();
auto y_dims = y.dims();
int max_dim = std::max(x_dims.size(), y_dims.size());
if (x_dims.size() == y_dims.size()) {
PADDLE_ENFORCE_EQ((axis == -1) || (axis == 0),
true,
phi::errors::InvalidArgument(
"axis should be -1 or 0 while the dimension of "
"tensor X (%s) is equal to the dimension of "
"tensor Y (%s), but received axis: %s",
x_dims.size(),
y_dims.size(),
axis));
}
PADDLE_ENFORCE_EQ((axis >= (-1 * max_dim)) && (axis < max_dim),
true,
phi::errors::InvalidArgument(
"The axis range must be [%s, %s), but axis is %s. "
"Please set the axis again.",
-1 * max_dim,
max_dim,
axis));
axis = (axis < 0 ? (std::abs(x_dims.size() - y_dims.size()) + axis + 1)
: axis);
std::vector<int> x_dims_array(max_dim);
std::vector<int> y_dims_array(max_dim);
std::vector<int> out_dims_array(max_dim);
funcs::GetBroadcastDimsArrays(x_dims,
y_dims,
x_dims_array.data(),
y_dims_array.data(),
out_dims_array.data(),
max_dim,
axis);
auto out_dims = phi::make_ddim(out_dims_array);
out->set_dims(out_dims);
} else {
out->set_dims(x.dims());
}
out->set_dtype(x.dtype());
out->set_layout(x.layout());
out->share_lod(x);
out_max->set_dims(phi::make_ddim({6}));
out_max->set_dtype(x.dtype());
out_max->set_layout(x.layout());
}
inline int ConvOutSize(int input_size,
int filter_size,
int dilation,
......
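A note on AddActXPUInferMeta above: when x and y have different dimensions, the output shape follows the standard right-aligned elementwise broadcast (the axis = -1 case), and out_max is always a fixed 6-element float buffer. A minimal NumPy sketch of that broadcast rule, for illustration only:

import numpy as np

# Right-aligned broadcasting, matching the axis == -1 path in AddActXPUInferMeta:
# x: [8, 3, 100, 100] and y: [3, 100, 100] broadcast to out: [8, 3, 100, 100]
print(np.broadcast_shapes((8, 3, 100, 100), (3, 100, 100)))  # (8, 3, 100, 100)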
......@@ -22,6 +22,14 @@ namespace phi {
// Common InferMeta Functions for fusion operators.
// NOTE: The InferMeta Functions in this file are arranged in alphabetic order.
void AddActXPUInferMeta(const MetaTensor& x,
const MetaTensor& x_max,
const MetaTensor& y,
const MetaTensor& y_max,
int act_type,
MetaTensor* out,
MetaTensor* out_max);
void Conv2dXPUInferMeta(const MetaTensor& x,
const MetaTensor& x_max,
const MetaTensor& filter,
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
namespace fusion {
template <typename T, typename Context>
void AddActXPUKernel(const Context& ctx,
const DenseTensor& x,
const paddle::optional<DenseTensor>& x_max,
const DenseTensor& y,
const paddle::optional<DenseTensor>& y_max,
int act_type,
DenseTensor* out,
DenseTensor* out_max) {
using XPUType = typename XPUTypeTrait<T>::Type;
auto* x_data = reinterpret_cast<const XPUType*>(x.data<T>());
const float* x_max_data =
x_max.get_ptr() == nullptr ? nullptr : x_max.get_ptr()->data<float>();
auto* y_data = reinterpret_cast<const XPUType*>(y.data<T>());
const float* y_max_data =
y_max.get_ptr() == nullptr ? nullptr : y_max.get_ptr()->data<float>();
auto* out_data = reinterpret_cast<XPUType*>(ctx.template Alloc<T>(out));
std::vector<int64_t> x_shape = phi::vectorize(x.dims());
std::vector<int64_t> y_shape = phi::vectorize(y.dims());
xpu::Activation_t act(static_cast<xpu::Activation_t::act_enum>(act_type));
int r =
xpu::add_activation_fusion<XPUType, XPUType, XPUType>( // TX/TY/TZ/TID
/* baidu::xpu::api::Context* ctx */ ctx.x_context(),
/* const TX* x */ x_data,
/* const TY* y */ y_data,
/* TZ* z */ out_data,
/* const std::vector<int64_t>& x_shape */ x_shape,
/* const std::vector<int64_t>& y_shape */ y_shape,
/* const float* max_x */ x_max_data,
/* const float* max_y */ y_max_data,
/* float* max_z */ ctx.template Alloc<float>(out_max),
/* const baidu::xpu::api::Activation_t& act */ act);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "add_act_xpu");
}
} // namespace fusion
} // namespace phi
PD_REGISTER_KERNEL(add_act_xpu,
XPU,
ALL_LAYOUT,
phi::fusion::AddActXPUKernel,
float,
phi::dtype::float16) {}
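The kernel forwards the optional x_max/y_max scales and the out_max buffer directly to xpu::add_activation_fusion. As a rough reference only — assuming the XPU *_max tensors hold the maximum absolute value of the corresponding data, which this diff does not spell out — the output-side bookkeeping would look roughly like:

import numpy as np

def add_relu_with_out_max(x, y):
    # Hypothetical NumPy analogue of the fused kernel's outputs: the fused
    # result plus a 6-element out_max buffer filled with max(|out|)
    # (an assumption about what the XPU runtime stores in its *_max tensors).
    out = np.maximum(x + y, 0.0)
    out_max = np.full((6,), np.abs(out).max(), dtype=np.float32)
    return out, out_max

out, out_max = add_relu_with_out_max(
    np.random.rand(2, 3).astype(np.float32),
    np.random.rand(2, 3).astype(np.float32),
)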
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from functools import partial
import hypothesis.strategies as st
import numpy as np
from auto_scan_test import PassAutoScanTest
from program_config import OpConfig, ProgramConfig, TensorConfig
class TestAddActXPUFusePass(PassAutoScanTest):
def sample_predictor_configs(self, program_config):
config = self.create_inference_config(use_xpu=True)
yield config, ["add_act_xpu"], (1e-3, 1e-3)
def sample_program_config(self, draw):
batch_size = draw(st.integers(min_value=1, max_value=50))
        # Generate the input tensors X and Y of elementwise_add
def generate_input():
return np.random.random([batch_size, 3, 100, 100]).astype(
np.float32
)
axis = -1
        # Here we compose a program.
        # There is still some risk that the program is invalid or triggers a bug at runtime.
        # Use `is_program_valid` to filter out invalid programs before running.
        # Use `add_skip_pass_case` to skip programs even if they trigger a bug while running.
elementwise_op = OpConfig(
type='elementwise_add',
inputs={'X': ['eltwise_X'], 'Y': ['eltwise_Y']},
outputs={'Out': ['eltwise_output']},
axis=axis,
)
relu_op = OpConfig(
"relu",
inputs={"X": ["eltwise_output"]},
outputs={"Out": ["relu_out"]},
)
mini_graph = [elementwise_op, relu_op]
program_config = ProgramConfig(
ops=mini_graph,
weights={},
inputs={
"eltwise_X": TensorConfig(data_gen=partial(generate_input)),
"eltwise_Y": TensorConfig(data_gen=partial(generate_input)),
},
outputs=mini_graph[-1].outputs["Out"],
)
return program_config
def test(self):
self.run_and_statis(
quant=False,
max_examples=25,
passes=["add_activation_xpu_fuse_pass"],
)
if __name__ == "__main__":
unittest.main()