Unverified commit 07e788f1 authored by hong19860320, committed by GitHub

[XPU] Add fast_where fusion op and XPU micro kernel (#55628)

Parent 744e1eaf
......@@ -280,6 +280,7 @@ if(WITH_XPU)
pass_library(matmul_weight_trans_pass inference DIR xpu DEPS ${XPU_PASS_DEPS})
pass_library(reshape2_matmul_xpu_fuse_pass inference DIR xpu DEPS
${XPU_PASS_DEPS})
pass_library(fast_where_xpu_fuse_pass inference DIR xpu DEPS ${XPU_PASS_DEPS})
endif()
cc_library(
......@@ -599,4 +600,8 @@ if(WITH_XPU)
test_reshape2_matmul_xpu_fuse_pass
SRCS xpu/reshape2_matmul_xpu_fuse_pass_test.cc
DEPS reshape2_matmul_xpu_fuse_pass)
cc_test(
test_fast_where_xpu_fuse_pass
SRCS xpu/fast_where_xpu_fuse_pass_test.cc
DEPS fast_where_xpu_fuse_pass)
endif()
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <string>
#include <unordered_set>
#include "glog/logging.h"
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/ir/xpu/pass_utils.h"
#include "paddle/fluid/framework/ir/xpu/quant_utils.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/enforce.h"
namespace phi {
class DenseTensor;
} // namespace phi
namespace paddle {
namespace framework {
class Scope;
} // namespace framework
} // namespace paddle
namespace paddle {
namespace framework {
namespace ir {
namespace patterns {
/*
Fuse cast+scale+mul+mul+add ops into a fast_where_xpu op to reduce memory access.
Case 0: when mode = 0,
condition
|
cast
|
/ \
/ \
scale \
x / y \
\ / \ /
mul0 mul1
\ /
\ /
\ /
add
After the pass is applied,
condition y x
\ | /
\ | /
\ | /
fast_where_xpu
Case 1: when mode = 1,
condition
|
cast
|
/ \
/ scale
/ \
x / y \
\ / \ /
mul0 mul1
\ /
\ /
\ /
add
After the pass is applied,
condition x y
\ | /
\ | /
\ | /
fast_where_xpu
Case 2: when mode = 0,
condition
|
cast
|
/ \
scale \
/ \
/ x \ y
\ / \ /
mul0 mul1
\ /
\ /
\ /
add
After the pass is applied,
condition y x
\ | /
\ | /
\ | /
fast_where_xpu
Case 3: when mode = 1,
condition
|
cast
|
/ \
/ scale
/ \
/ x \ y
\ / \ /
mul0 mul1
\ /
\ /
\ /
add
After the pass is applied,
condition x y
\ | /
\ | /
\ | /
fast_where_xpu
Case 4: when mode = 0,
condition
|
cast
|
/ \
scale \
/ \
/ x y \
\ / \ /
mul0 mul1
\ /
\ /
\ /
add
After the pass is applied,
condition y x
\ | /
\ | /
\ | /
fast_where_xpu
Case 5: when mode = 1,
condition
|
cast
|
/ \
/ scale
/ \
/ x y \
\ / \ /
mul0 mul1
\ /
\ /
\ /
add
After the pass is applied,
condition x y
\ | /
\ | /
\ | /
fast_where_xpu
*/
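// A minimal arithmetic sketch of the rewrite: with c = cast(condition) in
// {0, 1} and scale(c) = 1 - c, mode 0 computes
//   add_out = x * (1 - c) + y * c = condition ? y : x,
// which equals fast_where_xpu(condition, /*x=*/y, /*y=*/x); hence the
// handler below swaps x and y for mode 0 and keeps them for mode 1.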
struct OneFastWhereXPUPattern : public PatternBase {
OneFastWhereXPUPattern(PDPattern* pattern,
const std::string& name_scope,
int mode);
// declare operator node's name
PATTERN_DECL_NODE(cast);
PATTERN_DECL_NODE(scale);
PATTERN_DECL_NODE(mul0);
PATTERN_DECL_NODE(mul1);
PATTERN_DECL_NODE(add);
// declare variable node's name
// cast
PATTERN_DECL_NODE(condition);
PATTERN_DECL_NODE(cast_out);
// scale
PATTERN_DECL_NODE(scale_out);
// mul0
PATTERN_DECL_NODE(x);
PATTERN_DECL_NODE(mul0_out);
// mul1
PATTERN_DECL_NODE(y);
PATTERN_DECL_NODE(mul1_out);
// add
PATTERN_DECL_NODE(add_out);
private:
int mode_{0};
};
OneFastWhereXPUPattern::OneFastWhereXPUPattern(PDPattern* pattern,
const std::string& name_scope,
int mode)
: PatternBase(pattern, name_scope, name_scope), mode_(mode) {
// cast
auto condition =
pattern->NewNode(condition_repr())->assert_is_op_input("cast", "X");
auto cast_out = pattern->NewNode(cast_out_repr())
->assert_is_op_output("cast", "Out")
->assert_is_op_input("scale", "X")
->assert_is_op_input("elementwise_mul");
auto cast = pattern->NewNode(cast_repr())
->assert_is_op("cast")
->assert_more([](Node* n) {
auto in_dtype_val =
PADDLE_GET_CONST(int, n->Op()->GetAttr("in_dtype"));
auto out_dtype_val =
PADDLE_GET_CONST(int, n->Op()->GetAttr("out_dtype"));
return in_dtype_val == 0 &&
(out_dtype_val == 4 || out_dtype_val == 5);
});
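  // Note: in proto::VarType, 0 is BOOL, 4 is FP16 and 5 is FP32, so this
  // matches a bool condition cast to a floating-point mask.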
// scale
auto scale_out = pattern->NewNode(scale_out_repr())
->assert_is_op_output("scale", "Out")
->assert_is_op_input("elementwise_mul");
auto scale =
pattern->NewNode(scale_repr())
->assert_is_op("scale")
->assert_more([](Node* n) {
auto bias_val = PADDLE_GET_CONST(float, n->Op()->GetAttr("bias"));
auto scale_val = PADDLE_GET_CONST(float, n->Op()->GetAttr("scale"));
return fabs(bias_val - 1.0f) <= 1e-5f &&
fabs(scale_val + 1.0f) <= 1e-5f;
});
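  // scale with bias=1, scale=-1 computes 1 - x, i.e. the logical NOT of the
  // floating-point mask produced by the cast above.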
// mul0
auto x = pattern->NewNode(x_repr())->assert_is_op_input("elementwise_mul");
auto mul0_out = pattern->NewNode(mul0_out_repr())
->assert_is_op_output("elementwise_mul", "Out")
->assert_is_op_input("elementwise_add");
auto mul0 = pattern->NewNode(mul0_repr())
->assert_is_op("elementwise_mul")
->assert_more([](Node* node) {
auto node1 = node->inputs[0];
auto node2 = node->inputs[1];
auto node1_shape = node1->Var()->GetShape();
auto node2_shape = node2->Var()->GetShape();
if (node1_shape.size() != node2_shape.size()) return false;
for (size_t i = 0; i < node1_shape.size(); i++) {
if (node1_shape[i] != node2_shape[i] &&
(node1_shape[i] != 1 && node2_shape[i] != 1)) {
return false;
}
}
return true;
});
// mul1
auto y = pattern->NewNode(y_repr())->assert_is_op_input("elementwise_mul");
auto mul1_out = pattern->NewNode(mul1_out_repr())
->assert_is_op_output("elementwise_mul", "Out")
->assert_is_op_input("elementwise_add");
auto mul1 = pattern->NewNode(mul1_repr())
->assert_is_op("elementwise_mul")
->assert_more([](Node* node) {
auto node1 = node->inputs[0];
auto node2 = node->inputs[1];
auto node1_shape = node1->Var()->GetShape();
auto node2_shape = node2->Var()->GetShape();
if (node1_shape.size() != node2_shape.size()) return false;
for (size_t i = 0; i < node1_shape.size(); i++) {
if (node1_shape[i] != node2_shape[i] &&
(node1_shape[i] != 1 && node2_shape[i] != 1)) {
return false;
}
}
return true;
});
// add
auto add_out = pattern->NewNode(add_out_repr())
->assert_is_op_output("elementwise_add", "Out");
auto add = pattern->NewNode(add_repr())
->assert_is_op("elementwise_add")
->assert_more([](Node* node) {
auto node_in1 = node->inputs[0];
auto node_in2 = node->inputs[1];
if (node_in1->inputs.size() == 1 &&
node_in1->inputs[0]->Op()->Type() == "elementwise_mul" &&
node_in2->inputs.size() == 1 &&
node_in2->inputs[0]->Op()->Type() == "elementwise_mul") {
auto shape1 = node_in1->Var()->GetShape();
auto shape2 = node_in2->Var()->GetShape();
return shape1 == shape2;
}
return false;
});
cast->LinksFrom({condition}).LinksTo({cast_out});
scale->LinksFrom({cast_out}).LinksTo({scale_out});
PADDLE_ENFORCE_LE(
mode,
1,
platform::errors::InvalidArgument(
"one_fast_where_xpu_fuse_pass mode(%d) is not supported.", mode));
if (mode == 0) {
mul0->LinksFrom({x, scale_out}).LinksTo({mul0_out});
mul1->LinksFrom({y, cast_out}).LinksTo({mul1_out});
} else if (mode == 1) {
mul0->LinksFrom({x, cast_out}).LinksTo({mul0_out});
mul1->LinksFrom({y, scale_out}).LinksTo({mul1_out});
}
add->LinksFrom({mul0_out, mul1_out}).LinksTo({add_out});
}
/*
Fuse cascaded fast_where_xpu ops into one fast_where_xpu op to reduce memory access.
Case 0: when mode = 0,
x--------------
| |
| condition0 | y
| \ | /
| \ | /
| \ | /
condition1 | fast_where_xpu0
\ | /
\ | /
\ | /
fast_where_xpu1
After the pass is applied,
condition0 condition1
\ /
\ /
or
\ x y
\ | /
\ | /
fast_where_xpu
Case 1: when mode = 1,
condition0 x y
\ | / |
\ | / |
\ | / |
fast_where_xpu0 |
| |
condition1 | |
\ | /
\ | /
\ | /
fast_where_xpu1
After the pass is applied,
condition0 condition1
\ /
\ /
\ /
and
\ x y
\ | /
\ | /
fast_where_xpu
Other cases:
x ---------------------
| |
| condition0 y |
| \ | /
| \ | /
| \ | /
condition1 | fast_where_xpu0
\ | /
\ | /
\ | /
fast_where_xpu1
----------
| |
condition0 x y |
\ | / |
\ | / |
\ | / |
fast_where_xpu0 |
| |
condition1 | |
\ | /
\ | /
\ | /
fast_where_xpu1
*/
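// A minimal boolean sketch of why the rewrite holds (fwx = fast_where_xpu):
//   mode 0: fwx1(c1, x, fwx0(c0, x, y)) = c1 ? x : (c0 ? x : y)
//                                       = (c0 || c1) ? x : y
//   mode 1: fwx1(c1, fwx0(c0, x, y), y) = c1 ? (c0 ? x : y) : y
//                                       = (c0 && c1) ? x : y
// so the first fast_where_xpu collapses into a logical_or / logical_and op.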
struct CascadeFastWhereXPUPattern : public PatternBase {
CascadeFastWhereXPUPattern(PDPattern* pattern,
const std::string& name_scope,
int mode);
// declare operator node's name
PATTERN_DECL_NODE(fast_where_xpu0);
PATTERN_DECL_NODE(fast_where_xpu1);
// declare variable node's name
PATTERN_DECL_NODE(condition0);
PATTERN_DECL_NODE(condition1);
PATTERN_DECL_NODE(x);
PATTERN_DECL_NODE(y);
PATTERN_DECL_NODE(fast_where_xpu0_out);
PATTERN_DECL_NODE(fast_where_xpu1_out);
private:
int mode_{0};
};
CascadeFastWhereXPUPattern::CascadeFastWhereXPUPattern(
PDPattern* pattern, const std::string& name_scope, int mode)
: PatternBase(pattern, name_scope, name_scope), mode_(mode) {
// declare operator nodes
auto fast_where_xpu0 =
pattern->NewNode(fast_where_xpu0_repr())->assert_is_op("fast_where_xpu");
auto fast_where_xpu1 =
pattern->NewNode(fast_where_xpu1_repr())->assert_is_op("fast_where_xpu");
// declare variable nodes
auto condition0 = pattern->NewNode(condition0_repr())
->assert_is_op_input("fast_where_xpu", "condition");
auto condition1 = pattern->NewNode(condition1_repr())
->assert_is_op_input("fast_where_xpu", "condition");
auto fast_where_xpu0_out = pattern->NewNode(fast_where_xpu0_out_repr())
->assert_is_op_output("fast_where_xpu", "out");
auto fast_where_xpu1_out = pattern->NewNode(fast_where_xpu1_out_repr())
->assert_is_op_output("fast_where_xpu", "out");
auto x =
pattern->NewNode(x_repr())->assert_is_op_input("fast_where_xpu", "x");
auto y =
pattern->NewNode(y_repr())->assert_is_op_input("fast_where_xpu", "y");
fast_where_xpu0->LinksFrom({condition0, x, y}).LinksTo({fast_where_xpu0_out});
PADDLE_ENFORCE_LE(
mode,
1,
platform::errors::InvalidArgument(
"cascade_fast_where_xpu_fuse_pass mode(%d) is not supported.", mode));
if (mode == 0) {
fast_where_xpu0_out->assert_is_op_input("fast_where_xpu", "y");
fast_where_xpu1->LinksFrom({condition1, x, fast_where_xpu0_out})
.LinksTo({fast_where_xpu1_out});
} else if (mode == 1) {
fast_where_xpu0_out->assert_is_op_input("fast_where_xpu", "x");
fast_where_xpu1->LinksFrom({condition1, fast_where_xpu0_out, y})
.LinksTo({fast_where_xpu1_out});
}
}
} // namespace patterns
class OneFastWhereXPUFusePass : public FusePassBase {
public:
void ApplyImpl(ir::Graph* graph) const override;
private:
int ApplySubgraph(ir::Graph* graph, int mode) const;
const std::string name_scope_{"one_fast_where_xpu_fuse_pass"};
};
int OneFastWhereXPUFusePass::ApplySubgraph(ir::Graph* graph, int mode) const {
GraphPatternDetector gpd;
patterns::OneFastWhereXPUPattern pattern(
gpd.mutable_pattern(), name_scope_, mode);
int found_subgraph_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* graph) {
VLOG(4) << "handle OneFastWhereXPUFusePass fuse";
// declare operator node's name
GET_IR_NODE(cast);
GET_IR_NODE(scale);
GET_IR_NODE(mul0);
GET_IR_NODE(mul1);
GET_IR_NODE(add);
// declare variable node's name
// cast and scale
GET_IR_NODE(condition);
GET_IR_NODE(cast_out);
GET_IR_NODE(scale_out);
// mul0
GET_IR_NODE(x);
GET_IR_NODE(mul0_out);
// mul1
GET_IR_NODE(y);
GET_IR_NODE(mul1_out);
// add
GET_IR_NODE(add_out);
auto* block = add->Op()->Block();
framework::OpDesc fast_where_xpu_op_desc(block);
fast_where_xpu_op_desc.SetType("fast_where_xpu");
fast_where_xpu_op_desc.SetInput("condition", {condition->Name()});
if (mode == 0) {
fast_where_xpu_op_desc.SetInput("x", {y->Name()});
fast_where_xpu_op_desc.SetInput("y", {x->Name()});
} else if (mode == 1) {
fast_where_xpu_op_desc.SetInput("x", {x->Name()});
fast_where_xpu_op_desc.SetInput("y", {y->Name()});
}
fast_where_xpu_op_desc.SetOutput("out", {add_out->Name()});
auto fast_where_xpu_op_node = graph->CreateOpNode(&fast_where_xpu_op_desc);
IR_NODE_LINK_TO(x, fast_where_xpu_op_node);
IR_NODE_LINK_TO(y, fast_where_xpu_op_node);
IR_NODE_LINK_TO(condition, fast_where_xpu_op_node);
IR_NODE_LINK_TO(fast_where_xpu_op_node, add_out);
std::unordered_set<const Node*> delete_nodes = {
cast, cast_out, scale, scale_out, mul0, mul0_out, mul1, mul1_out, add};
GraphSafeRemoveNodes(graph, delete_nodes);
found_subgraph_count++;
};
gpd(graph, handler);
return found_subgraph_count;
}
void OneFastWhereXPUFusePass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::PreconditionNotMet("graph should not be null."));
Init(name_scope_, graph);
int found_subgraph_count = 0;
for (auto mode : {0, 1}) {
found_subgraph_count += ApplySubgraph(graph, mode);
}
AddStatis(found_subgraph_count);
}
class CascadeFastWhereXPUFusePass : public FusePassBase {
public:
void ApplyImpl(ir::Graph* graph) const override;
private:
int ApplySubgraph(ir::Graph* graph, int mode) const;
const std::string name_scope_{"cascade_fast_where_xpu_fuse_pass"};
};
int CascadeFastWhereXPUFusePass::ApplySubgraph(ir::Graph* graph,
int mode) const {
GraphPatternDetector gpd;
patterns::CascadeFastWhereXPUPattern pattern(
gpd.mutable_pattern(), name_scope_, mode);
int found_subgraph_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* graph) {
VLOG(4) << "handle CascadeFastWhereXPUFusePass fuse";
// declare operator node's name
GET_IR_NODE(fast_where_xpu0);
GET_IR_NODE(fast_where_xpu1);
// declare variable node's name
GET_IR_NODE(condition0);
GET_IR_NODE(condition1);
GET_IR_NODE(x);
GET_IR_NODE(y);
GET_IR_NODE(fast_where_xpu0_out);
GET_IR_NODE(fast_where_xpu1_out);
// Reuse the first op's output variable as the boolean result of the new
// logical op, so it takes the shape and dtype of condition0.
fast_where_xpu0_out->Var()->SetShape(condition0->Var()->GetShape());
fast_where_xpu0_out->Var()->SetDataType(condition0->Var()->GetDataType());
// Change the first fast_where_xpu op to logical op
fast_where_xpu0->Op()->RemoveInput("condition");
fast_where_xpu0->Op()->RemoveInput("x");
fast_where_xpu0->Op()->RemoveInput("y");
fast_where_xpu0->Op()->RemoveOutput("out");
fast_where_xpu0->Op()->SetInput(
"X", std::vector<std::string>({condition0->Name()}));
fast_where_xpu0->Op()->SetInput(
"Y", std::vector<std::string>({condition1->Name()}));
fast_where_xpu0->Op()->SetOutput(
"Out", std::vector<std::string>({fast_where_xpu0_out->Name()}));
// Keep the second fast_where_xpu op but change its inputs
fast_where_xpu1->Op()->SetInput(
"condition", std::vector<std::string>({fast_where_xpu0_out->Name()}));
fast_where_xpu1->Op()->SetInput("x", std::vector<std::string>({x->Name()}));
fast_where_xpu1->Op()->SetInput("y", std::vector<std::string>({y->Name()}));
if (mode == 0) {
fast_where_xpu0->Op()->SetType("logical_or");
} else if (mode == 1) {
fast_where_xpu0->Op()->SetType("logical_and");
}
IR_NODE_UNLINK(x, fast_where_xpu0);
IR_NODE_UNLINK(y, fast_where_xpu0);
IR_NODE_LINK_TO(condition1, fast_where_xpu0);
IR_NODE_UNLINK(condition1, fast_where_xpu1);
IR_NODE_LINK_TO(x, fast_where_xpu1);
IR_NODE_LINK_TO(y, fast_where_xpu1);
found_subgraph_count++;
};
gpd(graph, handler);
return found_subgraph_count;
}
void CascadeFastWhereXPUFusePass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::PreconditionNotMet("graph should not be null."));
Init(name_scope_, graph);
int total_found_subgraph_count = 0;
int cur_found_subgraph_count = 0;
do {
cur_found_subgraph_count = 0;
for (auto mode : {0, 1}) {
cur_found_subgraph_count += ApplySubgraph(graph, mode);
}
total_found_subgraph_count += cur_found_subgraph_count;
} while (cur_found_subgraph_count > 0);
AddStatis(total_found_subgraph_count);
}
class FastWhereXPUFusePass : public FusePassBase {
protected:
void ApplyImpl(ir::Graph* graph) const override;
private:
const std::string name_scope_{"fast_where_xpu_fuse_pass"};
};
void FastWhereXPUFusePass::ApplyImpl(ir::Graph* graph) const {
VLOG(4) << "handle fast_where_xpu op fusion.";
OneFastWhereXPUFusePass one_fast_where_xpu_fuse_pass;
one_fast_where_xpu_fuse_pass.ApplyImpl(graph);
CascadeFastWhereXPUFusePass cascade_fast_where_xpu_fuse_pass;
cascade_fast_where_xpu_fuse_pass.ApplyImpl(graph);
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(fast_where_xpu_fuse_pass,
paddle::framework::ir::FastWhereXPUFusePass);
REGISTER_PASS_CAPABILITY(fast_where_xpu_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination().EQ(
"fast_where_xpu_fuse_pass", 0));
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
namespace paddle {
namespace framework {
namespace ir {
#define APPLY_PASS \
std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program())); \
auto pass = PassRegistry::Instance().Get("fast_where_xpu_fuse_pass"); \
pass->Apply(graph.get());
#define VERIFY_GRAPH(x, y) \
auto num_op_nodes = GetNumOpNodes(graph); \
PADDLE_ENFORCE_EQ( \
num_op_nodes, \
1, \
platform::errors::PreconditionNotMet( \
"The graph contains only one op node, but %d op nodes found.", \
num_op_nodes)); \
auto fast_where_xpu_op_nodes = GetOpNodes(graph, "fast_where_xpu"); \
PADDLE_ENFORCE_EQ(fast_where_xpu_op_nodes.size(), \
1, \
platform::errors::PreconditionNotMet( \
"The graph contains only a fast_where_xpu op node, " \
"but %d op nodes found.", \
fast_where_xpu_op_nodes.size())); \
const auto& x_name = fast_where_xpu_op_nodes[0]->Op()->Input("x")[0]; \
PADDLE_ENFORCE_EQ(x_name, \
#x, \
platform::errors::PreconditionNotMet( \
"The input 'x' of fast_where_xpu op should be '%s', " \
"but receive '%s'.", \
#x, \
x_name)); \
const auto& y_name = fast_where_xpu_op_nodes[0]->Op()->Input("y")[0]; \
PADDLE_ENFORCE_EQ(y_name, \
#y, \
platform::errors::PreconditionNotMet( \
"The input 'y' of fast_where_xpu op should be '%s', " \
"but receive '%s'.", \
#y, \
y_name));
TEST(FastWhereXPUFusePass, one_case0) {
Layers layers;
auto* condition =
layers.data("condition", {20, 1}, false, proto::VarType::BOOL);
auto* x = layers.data("x", {20, 7});
auto* y = layers.data("y", {20, 7});
auto* cast_out = layers.cast(condition, 0, 5);
cast_out->SetShape({20, 1});
auto* scale_out = layers.scale(cast_out, -1.0f, 1.0f, true);
scale_out->SetShape({20, 1});
auto* mul0_out = layers.elementwise_mul(x, scale_out);
mul0_out->SetShape({20, 7});
auto* mul1_out = layers.elementwise_mul(y, cast_out);
mul1_out->SetShape({20, 7});
auto* add_out = layers.elementwise_add(mul0_out, mul1_out);
add_out->SetShape({20, 7});
APPLY_PASS
VERIFY_GRAPH(y, x)
}
TEST(FastWhereXPUFusePass, one_case1) {
Layers layers;
auto* condition =
layers.data("condition", {20, 1}, false, proto::VarType::BOOL);
auto* x = layers.data("x", {20, 7});
auto* y = layers.data("y", {20, 7});
auto* cast_out = layers.cast(condition, 0, 5);
cast_out->SetShape({20, 1});
auto* mul0_out = layers.elementwise_mul(x, cast_out);
mul0_out->SetShape({20, 7});
auto* scale_out = layers.scale(cast_out, -1.0f, 1.0f, true);
scale_out->SetShape({20, 1});
auto* mul1_out = layers.elementwise_mul(y, scale_out);
mul1_out->SetShape({20, 7});
auto* add_out = layers.elementwise_add(mul0_out, mul1_out);
add_out->SetShape({20, 7});
APPLY_PASS
VERIFY_GRAPH(x, y)
}
TEST(FastWhereXPUFusePass, one_case2) {
Layers layers;
auto* condition =
layers.data("condition", {20, 1}, false, proto::VarType::BOOL);
auto* x = layers.data("x", {20, 7});
auto* y = layers.data("y", {20, 7});
auto* cast_out = layers.cast(condition, 0, 5);
cast_out->SetShape({20, 1});
auto* scale_out = layers.scale(cast_out, -1.0f, 1.0f, true);
scale_out->SetShape({20, 1});
auto* mul0_out = layers.elementwise_mul(scale_out, x);
mul0_out->SetShape({20, 7});
auto* mul1_out = layers.elementwise_mul(cast_out, y);
mul1_out->SetShape({20, 7});
auto* add_out = layers.elementwise_add(mul0_out, mul1_out);
add_out->SetShape({20, 7});
APPLY_PASS
VERIFY_GRAPH(y, x)
}
TEST(FastWhereXPUFusePass, one_case3) {
Layers layers;
auto* condition =
layers.data("condition", {20, 1}, false, proto::VarType::BOOL);
auto* x = layers.data("x", {20, 7});
auto* y = layers.data("y", {20, 7});
auto* cast_out = layers.cast(condition, 0, 5);
cast_out->SetShape({20, 1});
auto* mul0_out = layers.elementwise_mul(cast_out, x);
mul0_out->SetShape({20, 7});
auto* scale_out = layers.scale(cast_out, -1.0f, 1.0f, true);
scale_out->SetShape({20, 1});
auto* mul1_out = layers.elementwise_mul(scale_out, y);
mul1_out->SetShape({20, 7});
auto* add_out = layers.elementwise_add(mul0_out, mul1_out);
add_out->SetShape({20, 7});
APPLY_PASS
VERIFY_GRAPH(x, y)
}
TEST(FastWhereXPUFusePass, one_case4) {
Layers layers;
auto* condition =
layers.data("condition", {20, 1}, false, proto::VarType::BOOL);
auto* x = layers.data("x", {20, 7});
auto* y = layers.data("y", {20, 7});
auto* cast_out = layers.cast(condition, 0, 5);
cast_out->SetShape({20, 1});
auto* scale_out = layers.scale(cast_out, -1.0f, 1.0f, true);
scale_out->SetShape({20, 1});
auto* mul0_out = layers.elementwise_mul(scale_out, x);
mul0_out->SetShape({20, 7});
auto* mul1_out = layers.elementwise_mul(y, cast_out);
mul1_out->SetShape({20, 7});
auto* add_out = layers.elementwise_add(mul0_out, mul1_out);
add_out->SetShape({20, 7});
APPLY_PASS
VERIFY_GRAPH(y, x)
}
TEST(FastWhereXPUFusePass, one_case5) {
Layers layers;
auto* condition =
layers.data("condition", {20, 1}, false, proto::VarType::BOOL);
auto* x = layers.data("x", {20, 7});
auto* y = layers.data("y", {20, 7});
auto* cast_out = layers.cast(condition, 0, 5);
cast_out->SetShape({20, 1});
auto* mul0_out = layers.elementwise_mul(cast_out, x);
mul0_out->SetShape({20, 7});
auto* scale_out = layers.scale(cast_out, -1.0f, 1.0f, true);
scale_out->SetShape({20, 1});
auto* mul1_out = layers.elementwise_mul(y, scale_out);
mul1_out->SetShape({20, 7});
auto* add_out = layers.elementwise_add(mul0_out, mul1_out);
add_out->SetShape({20, 7});
APPLY_PASS
VERIFY_GRAPH(x, y)
}
#undef VERIFY_GRAPH
#define VERIFY_GRAPH(logical_op, x, y) \
auto num_op_nodes = GetNumOpNodes(graph); \
PADDLE_ENFORCE_EQ( \
num_op_nodes, \
2, \
platform::errors::PreconditionNotMet( \
"The graph contains only two op nodes, but %d op nodes found.", \
num_op_nodes)); \
auto logical_op_nodes = GetOpNodes(graph, #logical_op); \
PADDLE_ENFORCE_EQ( \
logical_op_nodes.size(), \
1, \
platform::errors::PreconditionNotMet( \
"The graph contains only a '%s' op node, but %d op nodes found.", \
#logical_op, \
logical_op_nodes.size())); \
auto fast_where_xpu_op_nodes = GetOpNodes(graph, "fast_where_xpu"); \
PADDLE_ENFORCE_EQ(fast_where_xpu_op_nodes.size(), \
1, \
platform::errors::PreconditionNotMet( \
"The graph contains only a fast_where_xpu op node, " \
"but %d op nodes found.", \
fast_where_xpu_op_nodes.size())); \
const auto& x_name = fast_where_xpu_op_nodes[0]->Op()->Input("x")[0]; \
PADDLE_ENFORCE_EQ(x_name, \
#x, \
platform::errors::PreconditionNotMet( \
"The input 'x' of fast_where_xpu op should be '%s', " \
"but receive '%s'.", \
#x, \
x_name)); \
const auto& y_name = fast_where_xpu_op_nodes[0]->Op()->Input("y")[0]; \
PADDLE_ENFORCE_EQ(y_name, \
#y, \
platform::errors::PreconditionNotMet( \
"The input 'y' of fast_where_xpu op should be '%s', " \
"but receive '%s'.", \
#y, \
y_name));
TEST(FastWhereXPUFusePass, cascade_case0) {
Layers layers;
auto* condition0 =
layers.data("condition0", {20, 1}, false, proto::VarType::BOOL);
auto* condition1 =
layers.data("condition1", {20, 1}, false, proto::VarType::BOOL);
auto* x = layers.data("x", {20, 7});
auto* y = layers.data("y", {20, 7});
// fast_where_xpu0
auto* cast0_out = layers.cast(condition0, 0, 5);
cast0_out->SetShape({20, 1});
auto* mul0_out = layers.elementwise_mul(cast0_out, x);
mul0_out->SetShape({20, 7});
auto* scale0_out = layers.scale(cast0_out, -1.0f, 1.0f, true);
scale0_out->SetShape({20, 1});
auto* mul1_out = layers.elementwise_mul(scale0_out, y);
mul1_out->SetShape({20, 7});
auto* add0_out = layers.elementwise_add(mul0_out, mul1_out);
add0_out->SetShape({20, 7});
// fast_where_xpu1
auto* cast1_out = layers.cast(condition1, 0, 5);
cast1_out->SetShape({20, 1});
auto* mul2_out = layers.elementwise_mul(cast1_out, x);
mul2_out->SetShape({20, 7});
auto* scale1_out = layers.scale(cast1_out, -1.0f, 1.0f, true);
scale1_out->SetShape({20, 1});
auto* mul3_out = layers.elementwise_mul(scale1_out, add0_out);
mul3_out->SetShape({20, 7});
auto* add1_out = layers.elementwise_add(mul2_out, mul3_out);
add1_out->SetShape({20, 7});
APPLY_PASS
VERIFY_GRAPH(logical_or, x, y)
}
TEST(FastWhereXPUFusePass, cascade_case1) {
Layers layers;
auto* condition0 =
layers.data("condition0", {20, 1}, false, proto::VarType::BOOL);
auto* condition1 =
layers.data("condition1", {20, 1}, false, proto::VarType::BOOL);
auto* x = layers.data("x", {20, 7});
auto* y = layers.data("y", {20, 7});
// fast_where_xpu0
auto* cast0_out = layers.cast(condition0, 0, 5);
cast0_out->SetShape({20, 1});
auto* mul0_out = layers.elementwise_mul(cast0_out, x);
mul0_out->SetShape({20, 7});
auto* scale0_out = layers.scale(cast0_out, -1.0f, 1.0f, true);
scale0_out->SetShape({20, 1});
auto* mul1_out = layers.elementwise_mul(scale0_out, y);
mul1_out->SetShape({20, 7});
auto* add0_out = layers.elementwise_add(mul0_out, mul1_out);
add0_out->SetShape({20, 7});
// fast_where_xpu1
auto* cast1_out = layers.cast(condition1, 0, 5);
cast1_out->SetShape({20, 1});
auto* mul2_out = layers.elementwise_mul(cast1_out, add0_out);
mul2_out->SetShape({20, 7});
auto* scale1_out = layers.scale(cast1_out, -1.0f, 1.0f, true);
scale1_out->SetShape({20, 1});
auto* mul3_out = layers.elementwise_mul(scale1_out, y);
mul3_out->SetShape({20, 7});
auto* add1_out = layers.elementwise_add(mul2_out, mul3_out);
add1_out->SetShape({20, 7});
APPLY_PASS
VERIFY_GRAPH(logical_and, x, y)
}
#undef APPLY_PASS
#undef VERIFY_GRAPH
} // namespace ir
} // namespace framework
} // namespace paddle
USE_PASS(fast_where_xpu_fuse_pass);
......@@ -545,6 +545,7 @@ XpuPassStrategy::XpuPassStrategy() : PassStrategy({}) {
"add_activation_xpu_fuse_pass",
"add_layernorm_xpu_fuse_pass",
"yolo_box_xpu_fuse_pass",
"fast_where_xpu_fuse_pass",
"link_xpu_op_max_pass",
"delete_isolated_node_pass",
// "auto_mixed_precision_pass",
......
......@@ -53,6 +53,15 @@
data_type: tables
optional : mask, seq_lod, max_seq_len
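# fast_where_xpu computes out[i] = condition[i] ? x[i] : y[i]; the XPU kernel
# broadcasts condition to x's shape when their shapes differ.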
- op : fast_where_xpu
args : (Tensor condition, Tensor x, Tensor y)
output : Tensor(out)
infer_meta :
func : FastWhereXPUInferMeta
kernel :
func : fast_where_xpu
data_type : x
- op : fc_xpu
args : (Tensor x, Tensor x_max, Tensor w, Tensor w_max, Tensor bias, int in_num_col_dims, bool transpose_x, float alpha, float beta, int act_type, float act_alpha, DataType out_dtype)
output : Tensor(out), Tensor(out_max)
......
......@@ -295,6 +295,10 @@ XPUOpMap& get_kl2_ops() {
phi::DataType::BOOL,
phi::DataType::FLOAT16,
phi::DataType::FLOAT32})},
{"fast_where_xpu",
XPUKernelSet({phi::DataType::INT32,
phi::DataType::FLOAT32,
phi::DataType::FLOAT16})},
{"fc_xpu",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"fill",
......
......@@ -721,4 +721,12 @@ void Conv2dTransposeXPUInferMeta(const MetaTensor& x,
out_max);
}
void FastWhereXPUInferMeta(const MetaTensor& condition,
const MetaTensor& x,
const MetaTensor& y,
MetaTensor* out) {
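  // out mirrors x; the kernel additionally enforces that x and y have
  // identical dims.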
out->set_dims(x.dims());
out->set_dtype(x.dtype());
}
} // namespace phi
......@@ -175,4 +175,10 @@ void Conv2dTransposeXPUInferMeta(const MetaTensor& x,
const std::string& act_type,
MetaTensor* out,
MetaTensor* out_max);
void FastWhereXPUInferMeta(const MetaTensor& condition,
const MetaTensor& x,
const MetaTensor& y,
MetaTensor* out);
} // namespace phi
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "glog/logging.h"
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
namespace fusion {
template <typename T, typename Context>
void FastWhereXPUKernel(const Context& ctx,
const DenseTensor& condition,
const DenseTensor& x,
const DenseTensor& y,
DenseTensor* out) {
using XPUType = typename XPUTypeTrait<T>::Type;
auto* condition_data = condition.data<bool>();
auto* x_data = reinterpret_cast<const XPUType*>(x.data<T>());
auto* y_data = reinterpret_cast<const XPUType*>(y.data<T>());
auto* out_data = reinterpret_cast<XPUType*>(ctx.template Alloc<T>(out));
auto condition_dims = phi::vectorize<int>(condition.dims());
auto x_dims = phi::vectorize<int>(x.dims());
auto y_dims = phi::vectorize<int>(y.dims());
PADDLE_ENFORCE_EQ(
x_dims,
y_dims,
errors::PreconditionNotMet(
"The dimensions of inputs should be equal, but x_dims=[%s] and "
"y_dims=[%s].",
x.dims(),
y.dims()));
#ifndef PADDLE_WITH_XPU_PLUGIN
LOG(WARNING)
<< "Add -DWITH_XPU_PLUGIN=ON to build xpu::plugin::fast_where(); "
"falling back to xpu::select(), which leads to lower performance.";
int r = xpu::select<XPUType>(ctx.x_context(),
condition_data,
x_data,
y_data,
out_data,
condition_dims,
x_dims);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "select");
#else
xpu::ctx_guard RAII_GUARD(ctx.x_context());
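  // The plugin kernel expects condition to have exactly x's shape, so
  // broadcast it first when the shapes differ (e.g. [20, 1] -> [20, 7]).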
if (condition_dims != x_dims) {
bool* temp_data = RAII_GUARD.alloc_l3_or_gm<bool>(x.numel());
int r = xpu::broadcast<bool>(
ctx.x_context(), condition_data, temp_data, condition_dims, x_dims);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "broadcast");
condition_data = temp_data;
}
int r = xpu::plugin::fast_where<XPUType>(
ctx.x_context(), condition_data, x_data, y_data, out_data, x.numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "fast_where");
#endif
}
} // namespace fusion
} // namespace phi
PD_REGISTER_KERNEL(fast_where_xpu,
XPU,
ALL_LAYOUT,
phi::fusion::FastWhereXPUKernel,
float,
phi::dtype::float16,
int) {}
......@@ -154,7 +154,7 @@ macro(
${kernel_path} -D ${xpu_n_macro} --target=${TARGET_ARCH} ${HOST_XPU_FLAGS}
--basename ${kernel_name} -fno-builtin --xpu-arch=${xpu_n} -fPIC
-Wno-int-to-void-pointer-cast -Wno-int-to-pointer-cast -Werror -mllvm
--xpu-inline-cost -mllvm --xpu-inline-hot-call
--xpu-inline-cost -mllvm --xpu-inline-hot-call -I${XDNN_INC_DIR}
-I${CMAKE_CURRENT_SOURCE_DIR}/include -I${CMAKE_CURRENT_SOURCE_DIR}/src
-I${CMAKE_CURRENT_SOURCE_DIR}/src/kernel
-I${CMAKE_CURRENT_SOURCE_DIR}/src/kernel/include ${arg_rule}
......
......@@ -24,6 +24,13 @@ namespace api {
namespace plugin {
DLL_EXPORT int add2(Context* ctx, const float* x, float* y, int len);
template <typename T>
DLL_EXPORT int fast_where(Context* ctx,
const bool* condition,
const T* x,
const T* y,
T* out,
int64_t len);
} // namespace plugin
} // namespace api
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/*
* copyright (C) 2022 KUNLUNXIN, Inc
*/
#include "xpu/kernel/cluster.h"
#include "xpu/kernel/cluster_partition.h"
#include "xpu/kernel/cluster_primitive.h"
namespace xpu2 {
namespace plugin {
#define CALC_MASK(offset) \
mask |= static_cast<int>(condition[i + offset]) << offset;
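// CALC_MASK packs condition bytes into a lane bitmask: bit `offset` of
// `mask` is condition[i + offset]. The masked vector stores below then
// overwrite y's lanes with x only where the condition bit is set.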
static __device__ inline void do_select_16(const int8_t* condition,
const int16_t* x,
int16_t* y,
int len) {
int len_rounddown32 = rounddown32(len);
for (int i = 0; i < len_rounddown32; i += 32) {
int mask = condition[i];
CALC_MASK(1)
CALC_MASK(2)
CALC_MASK(3)
CALC_MASK(4)
CALC_MASK(5)
CALC_MASK(6)
CALC_MASK(7)
CALC_MASK(8)
CALC_MASK(9)
CALC_MASK(10)
CALC_MASK(11)
CALC_MASK(12)
CALC_MASK(13)
CALC_MASK(14)
CALC_MASK(15)
CALC_MASK(16)
CALC_MASK(17)
CALC_MASK(18)
CALC_MASK(19)
CALC_MASK(20)
CALC_MASK(21)
CALC_MASK(22)
CALC_MASK(23)
CALC_MASK(24)
CALC_MASK(25)
CALC_MASK(26)
CALC_MASK(27)
CALC_MASK(28)
CALC_MASK(29)
CALC_MASK(30)
CALC_MASK(31)
vstore_lm_int16x32_mh(y + i, vload_lm_int16x32(x + i), mask);
}
for (int i = len_rounddown32; i < len; i++) {
y[i] = condition[i] ? x[i] : y[i];
}
mfence_lm();
}
static __device__ inline void do_select_32(const int8_t* condition,
const int32_t* x,
int32_t* y,
int len) {
int len_rounddown16 = rounddown16(len);
for (int i = 0; i < len_rounddown16; i += 16) {
int mask = condition[i];
CALC_MASK(1)
CALC_MASK(2)
CALC_MASK(3)
CALC_MASK(4)
CALC_MASK(5)
CALC_MASK(6)
CALC_MASK(7)
CALC_MASK(8)
CALC_MASK(9)
CALC_MASK(10)
CALC_MASK(11)
CALC_MASK(12)
CALC_MASK(13)
CALC_MASK(14)
CALC_MASK(15)
vstore_lm_int32x16_mh(y + i, vload_lm_int32x16(x + i), mask);
}
for (int i = len_rounddown16; i < len; i++) {
y[i] = condition[i] ? x[i] : y[i];
}
mfence_lm();
}
template <typename T>
static __device__ void do_select(const int8_t* condition,
const T* x,
T* y,
int len) {}
template <>
__device__ void do_select<float16>(const int8_t* condition,
const float16* x,
float16* y,
int len) {
do_select_16(condition,
reinterpret_cast<const int16_t*>(x),
reinterpret_cast<int16_t*>(y),
len);
}
template <>
__device__ void do_select<float>(const int8_t* condition,
const float* x,
float* y,
int len) {
do_select_32(condition,
reinterpret_cast<const int32_t*>(x),
reinterpret_cast<int32_t*>(y),
len);
}
template <>
__device__ void do_select<int16_t>(const int8_t* condition,
const int16_t* x,
int16_t* y,
int len) {
do_select_16(condition, x, y, len);
}
template <>
__device__ void do_select<int32_t>(const int8_t* condition,
const int32_t* x,
int32_t* y,
int len) {
do_select_32(condition, x, y, len);
}
template <typename T>
__global__ void fast_where(
const int8_t* condition, const T* x, const T* y, T* z, int64_t len) {
int tid = core_id() * cluster_num() + cluster_id();
int nthreads = core_num() * cluster_num();
#ifdef __XPU3__
const int buf_len = 1536 / sizeof(T);
#else
const int buf_len = 512 / sizeof(T);
#endif
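  // Local-memory staging buffers for one tile of condition/x/y; XPU3 has
  // more local memory per core, hence the larger buf_len.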
__simd__ int8_t local_condition[buf_len];
__simd__ T local_x[buf_len];
__simd__ T local_y[buf_len];
int loop = 0;
for (int64_t i = tid * buf_len; i < len; i += nthreads * buf_len) {
int read_len = min(static_cast<int64_t>(buf_len), len - i);
GM2LM_ASYNC(condition + i, local_condition, read_len * sizeof(int8_t));
GM2LM_ASYNC(x + i, local_x, read_len * sizeof(T));
GM2LM(y + i, local_y, read_len * sizeof(T));
do_select<T>(local_condition, local_x, local_y, read_len);
LM2GM_ASYNC(local_y, z + i, read_len * sizeof(T));
mfence();
#ifndef __XPU3__
loop++;
if ((loop & 0xF) == 0) {
sync_all();
}
#endif
}
}
#define _XPU_DEF__FAST_WHERE_(DTYPE) \
template __global__ void fast_where<DTYPE>(const int8_t* condition, \
const DTYPE* x, \
const DTYPE* y, \
DTYPE* z, \
int64_t len);
_XPU_DEF__FAST_WHERE_(float16);
_XPU_DEF__FAST_WHERE_(float);
_XPU_DEF__FAST_WHERE_(int16_t);
_XPU_DEF__FAST_WHERE_(int32_t);
} // namespace plugin
} // namespace xpu2
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/*
* copyright (C) 2022 KUNLUNXIN, Inc
*/
#include "xpu/plugin.h"
#include "xpu/refactor/impl_public/wrapper_check.h"
namespace xpu2 {
namespace plugin {
template <typename T>
__attribute__((global)) void fast_where(
const int8_t* condition, const T* x, const T* y, T* z, int64_t len);
}
} // namespace xpu2
namespace baidu {
namespace xpu {
namespace api {
namespace plugin {
template <typename T>
static int cpu_wrapper(Context* ctx,
const bool* condition,
const T* x,
const T* y,
T* z,
int64_t len) {
for (int64_t i = 0; i < len; i++) {
z[i] = condition[i] ? x[i] : y[i];
}
return SUCCESS;
}
template <>
int cpu_wrapper<float16>(Context* ctx,
const bool* condition,
const float16* x,
const float16* y,
float16* z,
int64_t len) {
std::vector<float> x_fp32(len);
std::vector<float> y_fp32(len);
std::vector<float> z_fp32(len);
// Check each step so an intermediate failure is not masked by a later one.
int ret = cast<float16, float>(ctx, x, x_fp32.data(), len);
WRAPPER_ASSERT_SUCCESS(ctx, ret);
ret = cast<float16, float>(ctx, y, y_fp32.data(), len);
WRAPPER_ASSERT_SUCCESS(ctx, ret);
ret = cpu_wrapper<float>(
ctx, condition, x_fp32.data(), y_fp32.data(), z_fp32.data(), len);
WRAPPER_ASSERT_SUCCESS(ctx, ret);
ret = cast<float, float16>(ctx, z_fp32.data(), z, len);
WRAPPER_ASSERT_SUCCESS(ctx, ret);
return ret;
}
template <typename T>
static int xpu2_wrapper(Context* ctx,
const bool* condition,
const T* x,
const T* y,
T* z,
int64_t len) {
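  // Launch across all clusters with 64 cores per cluster; inside the kernel
  // each of the core_num() * cluster_num() threads strides over `len` in
  // buf_len-sized chunks.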
xpu2::plugin::fast_where<T><<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
reinterpret_cast<const int8_t*>(condition), x, y, z, len);
return SUCCESS;
}
template <typename T>
int fast_where(Context* ctx,
const bool* condition,
const T* x,
const T* y,
T* z,
int64_t len) {
WRAPPER_CHECK_CTX(ctx);
WRAPPER_DUMP_FUNCTION_T1(ctx, "fast_where", float);
WRAPPER_DUMP_PARAM5(ctx, condition, x, y, z, len);
WRAPPER_DUMP(ctx);
WRAPPER_ASSERT_GT(ctx, len, 0);
WRAPPER_CHECK_2PTRS(ctx, T, len, x, y);
if (ctx->dev().type() == api::kCPU) {
return cpu_wrapper<T>(ctx, condition, x, y, z, len);
}
if (ctx->dev().type() == api::kXPU2) {
return xpu2_wrapper<T>(ctx, condition, x, y, z, len);
}
WRAPPER_UNIMPLEMENTED(ctx);
}
template int fast_where(Context*,
const bool* condition,
const float*,
const float*,
float*,
int64_t);
template int fast_where(Context*,
const bool* condition,
const float16*,
const float16*,
float16*,
int64_t);
template int fast_where(Context*,
const bool* condition,
const int16_t*,
const int16_t*,
int16_t*,
int64_t);
template int fast_where(Context*,
const bool* condition,
const int32_t*,
const int32_t*,
int32_t*,
int64_t);
} // namespace plugin
} // namespace api
} // namespace xpu
} // namespace baidu
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from functools import partial
import hypothesis.strategies as st
import numpy as np
from auto_scan_test import PassAutoScanTest
from program_config import OpConfig, ProgramConfig, TensorConfig
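# Each test builds the cast+scale+elementwise_mul+elementwise_mul+
# elementwise_add subgraph (or two cascaded copies of it) and checks that
# fast_where_xpu_fuse_pass rewrites it into a single fast_where_xpu op
# (plus a logical op in the cascade cases).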
class TestFastWhereXPUFusePassOneCase0(PassAutoScanTest):
def sample_predictor_configs(self, program_config):
config = self.create_inference_config(use_xpu=True)
yield config, ["fast_where_xpu"], (1e-3, 1e-3)
def sample_program_config(self, draw):
value_shape = draw(
st.lists(
st.integers(min_value=1, max_value=4), min_size=2, max_size=4
)
)
condition_shape = value_shape.copy()
condition_shape[-1] = 1
def generate_condition():
return np.random.random(condition_shape).astype(bool)
def generate_value():
return np.random.random(value_shape).astype(np.float32)
cast_op = OpConfig(
"cast",
inputs={"X": ["condition"]},
outputs={"Out": ["cast_out"]},
in_dtype=0,
out_dtype=5,
)
scale_op = OpConfig(
"scale",
inputs={"X": ["cast_out"]},
outputs={"Out": ["scale_out"]},
scale=-1,
bias=1,
bias_after_scale=True,
)
mul0_op = OpConfig(
"elementwise_mul",
inputs={"X": ["x"], "Y": ["scale_out"]},
outputs={"Out": ["mul0_out"]},
axis=-1,
)
mul1_op = OpConfig(
"elementwise_mul",
inputs={"X": ["y"], "Y": ["cast_out"]},
outputs={"Out": ["mul1_out"]},
axis=-1,
)
add_op = OpConfig(
"elementwise_add",
inputs={"X": ["mul0_out"], "Y": ["mul1_out"]},
outputs={"Out": ["add0_out"]},
axis=-1,
)
ops = [cast_op, scale_op, mul0_op, mul1_op, add_op]
program_config = ProgramConfig(
ops=ops,
weights={},
inputs={
"condition": TensorConfig(data_gen=partial(generate_condition)),
"x": TensorConfig(data_gen=partial(generate_value)),
"y": TensorConfig(data_gen=partial(generate_value)),
},
outputs=ops[-1].outputs["Out"],
)
return program_config
def test(self):
self.run_and_statis(
quant=False,
max_examples=25,
passes=["fast_where_xpu_fuse_pass"],
)
class TestFastWhereXPUFusePassOneCase1(PassAutoScanTest):
def sample_predictor_configs(self, program_config):
config = self.create_inference_config(use_xpu=True)
yield config, ["fast_where_xpu"], (1e-3, 1e-3)
def sample_program_config(self, draw):
value_shape = draw(
st.lists(
st.integers(min_value=1, max_value=4), min_size=2, max_size=4
)
)
condition_shape = value_shape.copy()
condition_shape[-1] = 1
def generate_condition():
return np.random.random(condition_shape).astype(bool)
def generate_value():
return np.random.random(value_shape).astype(np.float32)
cast_op = OpConfig(
"cast",
inputs={"X": ["condition"]},
outputs={"Out": ["cast_out"]},
in_dtype=0,
out_dtype=5,
)
mul0_op = OpConfig(
"elementwise_mul",
inputs={"X": ["x"], "Y": ["cast_out"]},
outputs={"Out": ["mul0_out"]},
axis=-1,
)
scale_op = OpConfig(
"scale",
inputs={"X": ["cast_out"]},
outputs={"Out": ["scale_out"]},
scale=-1,
bias=1,
bias_after_scale=True,
)
mul1_op = OpConfig(
"elementwise_mul",
inputs={"X": ["y"], "Y": ["scale_out"]},
outputs={"Out": ["mul1_out"]},
axis=-1,
)
add_op = OpConfig(
"elementwise_add",
inputs={"X": ["mul0_out"], "Y": ["mul1_out"]},
outputs={"Out": ["add0_out"]},
axis=-1,
)
ops = [cast_op, mul0_op, scale_op, mul1_op, add_op]
program_config = ProgramConfig(
ops=ops,
weights={},
inputs={
"condition": TensorConfig(data_gen=partial(generate_condition)),
"x": TensorConfig(data_gen=partial(generate_value)),
"y": TensorConfig(data_gen=partial(generate_value)),
},
outputs=ops[-1].outputs["Out"],
)
return program_config
def test(self):
self.run_and_statis(
quant=False,
max_examples=25,
passes=["fast_where_xpu_fuse_pass"],
)
class TestFastWhereXPUFusePassOneCase2(PassAutoScanTest):
def sample_predictor_configs(self, program_config):
config = self.create_inference_config(use_xpu=True)
yield config, ["fast_where_xpu"], (1e-3, 1e-3)
def sample_program_config(self, draw):
value_shape = draw(
st.lists(
st.integers(min_value=1, max_value=4), min_size=2, max_size=4
)
)
condition_shape = value_shape.copy()
condition_shape[-1] = 1
def generate_condition():
return np.random.random(condition_shape).astype(bool)
def generate_value():
return np.random.random(value_shape).astype(np.float32)
cast_op = OpConfig(
"cast",
inputs={"X": ["condition"]},
outputs={"Out": ["cast_out"]},
in_dtype=0,
out_dtype=5,
)
scale_op = OpConfig(
"scale",
inputs={"X": ["cast_out"]},
outputs={"Out": ["scale_out"]},
scale=-1,
bias=1,
bias_after_scale=True,
)
mul0_op = OpConfig(
"elementwise_mul",
inputs={"X": ["scale_out"], "Y": ["x"]},
outputs={"Out": ["mul0_out"]},
axis=-1,
)
mul1_op = OpConfig(
"elementwise_mul",
inputs={"X": ["cast_out"], "Y": ["y"]},
outputs={"Out": ["mul1_out"]},
axis=-1,
)
add_op = OpConfig(
"elementwise_add",
inputs={"X": ["mul0_out"], "Y": ["mul1_out"]},
outputs={"Out": ["add0_out"]},
axis=-1,
)
ops = [cast_op, scale_op, mul0_op, mul1_op, add_op]
program_config = ProgramConfig(
ops=ops,
weights={},
inputs={
"condition": TensorConfig(data_gen=partial(generate_condition)),
"x": TensorConfig(data_gen=partial(generate_value)),
"y": TensorConfig(data_gen=partial(generate_value)),
},
outputs=ops[-1].outputs["Out"],
)
return program_config
def test(self):
self.run_and_statis(
quant=False,
max_examples=25,
passes=["fast_where_xpu_fuse_pass"],
)
class TestFastWhereXPUFusePassOneCase3(PassAutoScanTest):
def sample_predictor_configs(self, program_config):
config = self.create_inference_config(use_xpu=True)
yield config, ["fast_where_xpu"], (1e-3, 1e-3)
def sample_program_config(self, draw):
value_shape = draw(
st.lists(
st.integers(min_value=1, max_value=4), min_size=2, max_size=4
)
)
condition_shape = value_shape.copy()
condition_shape[-1] = 1
def generate_condition():
return np.random.random(condition_shape).astype(bool)
def generate_value():
return np.random.random(value_shape).astype(np.float32)
cast_op = OpConfig(
"cast",
inputs={"X": ["condition"]},
outputs={"Out": ["cast_out"]},
in_dtype=0,
out_dtype=5,
)
mul0_op = OpConfig(
"elementwise_mul",
inputs={"X": ["cast_out"], "Y": ["x"]},
outputs={"Out": ["mul0_out"]},
axis=-1,
)
scale_op = OpConfig(
"scale",
inputs={"X": ["cast_out"]},
outputs={"Out": ["scale_out"]},
scale=-1,
bias=1,
bias_after_scale=True,
)
mul1_op = OpConfig(
"elementwise_mul",
inputs={"X": ["scale_out"], "Y": ["y"]},
outputs={"Out": ["mul1_out"]},
axis=-1,
)
add_op = OpConfig(
"elementwise_add",
inputs={"X": ["mul0_out"], "Y": ["mul1_out"]},
outputs={"Out": ["add0_out"]},
axis=-1,
)
ops = [cast_op, mul0_op, scale_op, mul1_op, add_op]
program_config = ProgramConfig(
ops=ops,
weights={},
inputs={
"condition": TensorConfig(data_gen=partial(generate_condition)),
"x": TensorConfig(data_gen=partial(generate_value)),
"y": TensorConfig(data_gen=partial(generate_value)),
},
outputs=ops[-1].outputs["Out"],
)
return program_config
def test(self):
self.run_and_statis(
quant=False,
max_examples=25,
passes=["fast_where_xpu_fuse_pass"],
)
class TestFastWhereXPUFusePassOneCase4(PassAutoScanTest):
def sample_predictor_configs(self, program_config):
config = self.create_inference_config(use_xpu=True)
yield config, ["fast_where_xpu"], (1e-3, 1e-3)
def sample_program_config(self, draw):
value_shape = draw(
st.lists(
st.integers(min_value=1, max_value=4), min_size=2, max_size=4
)
)
condition_shape = value_shape.copy()
condition_shape[-1] = 1
def generate_condition():
return np.random.random(condition_shape).astype(bool)
def generate_value():
return np.random.random(value_shape).astype(np.float32)
cast_op = OpConfig(
"cast",
inputs={"X": ["condition"]},
outputs={"Out": ["cast_out"]},
in_dtype=0,
out_dtype=5,
)
scale_op = OpConfig(
"scale",
inputs={"X": ["cast_out"]},
outputs={"Out": ["scale_out"]},
scale=-1,
bias=1,
bias_after_scale=True,
)
mul0_op = OpConfig(
"elementwise_mul",
inputs={"X": ["scale_out"], "Y": ["x"]},
outputs={"Out": ["mul0_out"]},
axis=-1,
)
mul1_op = OpConfig(
"elementwise_mul",
inputs={"X": ["y"], "Y": ["cast_out"]},
outputs={"Out": ["mul1_out"]},
axis=-1,
)
add_op = OpConfig(
"elementwise_add",
inputs={"X": ["mul0_out"], "Y": ["mul1_out"]},
outputs={"Out": ["add0_out"]},
axis=-1,
)
ops = [cast_op, scale_op, mul0_op, mul1_op, add_op]
program_config = ProgramConfig(
ops=ops,
weights={},
inputs={
"condition": TensorConfig(data_gen=partial(generate_condition)),
"x": TensorConfig(data_gen=partial(generate_value)),
"y": TensorConfig(data_gen=partial(generate_value)),
},
outputs=ops[-1].outputs["Out"],
)
return program_config
def test(self):
self.run_and_statis(
quant=False,
max_examples=25,
passes=["fast_where_xpu_fuse_pass"],
)
class TestFastWhereXPUFusePassOneCase5(PassAutoScanTest):
def sample_predictor_configs(self, program_config):
config = self.create_inference_config(use_xpu=True)
yield config, ["fast_where_xpu"], (1e-3, 1e-3)
def sample_program_config(self, draw):
value_shape = draw(
st.lists(
st.integers(min_value=1, max_value=4), min_size=2, max_size=4
)
)
condition_shape = value_shape.copy()
condition_shape[-1] = 1
def generate_condition():
return np.random.random(condition_shape).astype(bool)
def generate_value():
return np.random.random(value_shape).astype(np.float32)
cast_op = OpConfig(
"cast",
inputs={"X": ["condition"]},
outputs={"Out": ["cast_out"]},
in_dtype=0,
out_dtype=5,
)
mul0_op = OpConfig(
"elementwise_mul",
inputs={"X": ["cast_out"], "Y": ["x"]},
outputs={"Out": ["mul0_out"]},
axis=-1,
)
scale_op = OpConfig(
"scale",
inputs={"X": ["cast_out"]},
outputs={"Out": ["scale_out"]},
scale=-1,
bias=1,
bias_after_scale=True,
)
mul1_op = OpConfig(
"elementwise_mul",
inputs={"X": ["y"], "Y": ["scale_out"]},
outputs={"Out": ["mul1_out"]},
axis=-1,
)
add_op = OpConfig(
"elementwise_add",
inputs={"X": ["mul0_out"], "Y": ["mul1_out"]},
outputs={"Out": ["add0_out"]},
axis=-1,
)
ops = [cast_op, mul0_op, scale_op, mul1_op, add_op]
program_config = ProgramConfig(
ops=ops,
weights={},
inputs={
"condition": TensorConfig(data_gen=partial(generate_condition)),
"x": TensorConfig(data_gen=partial(generate_value)),
"y": TensorConfig(data_gen=partial(generate_value)),
},
outputs=ops[-1].outputs["Out"],
)
return program_config
def test(self):
self.run_and_statis(
quant=False,
max_examples=25,
passes=["fast_where_xpu_fuse_pass"],
)
class TestFastWhereXPUFusePassCascadeCase0(PassAutoScanTest):
def sample_predictor_configs(self, program_config):
config = self.create_inference_config(use_xpu=True)
yield config, ["logical_or", "fast_where_xpu"], (1e-3, 1e-3)
def sample_program_config(self, draw):
value_shape = draw(
st.lists(
st.integers(min_value=1, max_value=4), min_size=2, max_size=4
)
)
condition_shape = value_shape.copy()
condition_shape[-1] = 1
def generate_condition():
return np.random.random(condition_shape).astype(bool)
def generate_value():
return np.random.random(value_shape).astype(np.float32)
# fast_where_xpu0
cast0_op = OpConfig(
"cast",
inputs={"X": ["condition0"]},
outputs={"Out": ["cast0_out"]},
in_dtype=0,
out_dtype=5,
)
mul0_op = OpConfig(
"elementwise_mul",
inputs={"X": ["cast0_out"], "Y": ["x"]},
outputs={"Out": ["mul0_out"]},
axis=-1,
)
scale0_op = OpConfig(
"scale",
inputs={"X": ["cast0_out"]},
outputs={"Out": ["scale0_out"]},
scale=-1,
bias=1,
bias_after_scale=True,
)
mul1_op = OpConfig(
"elementwise_mul",
inputs={"X": ["scale0_out"], "Y": ["y"]},
outputs={"Out": ["mul1_out"]},
axis=-1,
)
add0_op = OpConfig(
"elementwise_add",
inputs={"X": ["mul0_out"], "Y": ["mul1_out"]},
outputs={"Out": ["add0_out"]},
axis=-1,
)
# fast_where_xpu1
cast1_op = OpConfig(
"cast",
inputs={"X": ["condition1"]},
outputs={"Out": ["cast1_out"]},
in_dtype=0,
out_dtype=5,
)
mul2_op = OpConfig(
"elementwise_mul",
inputs={"X": ["cast1_out"], "Y": ["x"]},
outputs={"Out": ["mul2_out"]},
axis=-1,
)
scale1_op = OpConfig(
"scale",
inputs={"X": ["cast1_out"]},
outputs={"Out": ["scale1_out"]},
scale=-1,
bias=1,
bias_after_scale=True,
)
mul3_op = OpConfig(
"elementwise_mul",
inputs={"X": ["scale1_out"], "Y": ["add0_out"]},
outputs={"Out": ["mul3_out"]},
axis=-1,
)
add1_op = OpConfig(
"elementwise_add",
inputs={"X": ["mul2_out"], "Y": ["mul3_out"]},
outputs={"Out": ["add1_out"]},
axis=-1,
)
ops = [
cast0_op,
mul0_op,
scale0_op,
mul1_op,
add0_op,
cast1_op,
mul2_op,
scale1_op,
mul3_op,
add1_op,
]
program_config = ProgramConfig(
ops=ops,
weights={},
inputs={
"condition0": TensorConfig(
data_gen=partial(generate_condition)
),
"condition1": TensorConfig(
data_gen=partial(generate_condition)
),
"x": TensorConfig(data_gen=partial(generate_value)),
"y": TensorConfig(data_gen=partial(generate_value)),
},
outputs=ops[-1].outputs["Out"],
)
return program_config
def test(self):
self.run_and_statis(
quant=False,
max_examples=25,
passes=["fast_where_xpu_fuse_pass"],
)
class TestFastWhereXPUFusePassCascadeCase1(PassAutoScanTest):
def sample_predictor_configs(self, program_config):
config = self.create_inference_config(use_xpu=True)
yield config, ["logical_and", "fast_where_xpu"], (1e-3, 1e-3)
def sample_program_config(self, draw):
value_shape = draw(
st.lists(
st.integers(min_value=1, max_value=4), min_size=2, max_size=4
)
)
condition_shape = value_shape.copy()
condition_shape[-1] = 1
def generate_condition():
return np.random.random(condition_shape).astype(bool)
def generate_value():
return np.random.random(value_shape).astype(np.float32)
# fast_where_xpu0
cast0_op = OpConfig(
"cast",
inputs={"X": ["condition0"]},
outputs={"Out": ["cast0_out"]},
in_dtype=0,
out_dtype=5,
)
mul0_op = OpConfig(
"elementwise_mul",
inputs={"X": ["cast0_out"], "Y": ["x"]},
outputs={"Out": ["mul0_out"]},
axis=-1,
)
scale0_op = OpConfig(
"scale",
inputs={"X": ["cast0_out"]},
outputs={"Out": ["scale0_out"]},
scale=-1,
bias=1,
bias_after_scale=True,
)
mul1_op = OpConfig(
"elementwise_mul",
inputs={"X": ["scale0_out"], "Y": ["y"]},
outputs={"Out": ["mul1_out"]},
axis=-1,
)
add0_op = OpConfig(
"elementwise_add",
inputs={"X": ["mul0_out"], "Y": ["mul1_out"]},
outputs={"Out": ["add0_out"]},
axis=-1,
)
# fast_where_xpu1
cast1_op = OpConfig(
"cast",
inputs={"X": ["condition1"]},
outputs={"Out": ["cast1_out"]},
in_dtype=0,
out_dtype=5,
)
mul2_op = OpConfig(
"elementwise_mul",
inputs={"X": ["cast1_out"], "Y": ["add0_out"]},
outputs={"Out": ["mul2_out"]},
axis=-1,
)
scale1_op = OpConfig(
"scale",
inputs={"X": ["cast1_out"]},
outputs={"Out": ["scale1_out"]},
scale=-1,
bias=1,
bias_after_scale=True,
)
mul3_op = OpConfig(
"elementwise_mul",
inputs={"X": ["scale1_out"], "Y": ["y"]},
outputs={"Out": ["mul3_out"]},
axis=-1,
)
add1_op = OpConfig(
"elementwise_add",
inputs={"X": ["mul2_out"], "Y": ["mul3_out"]},
outputs={"Out": ["add1_out"]},
axis=-1,
)
ops = [
cast0_op,
mul0_op,
scale0_op,
mul1_op,
add0_op,
cast1_op,
mul2_op,
scale1_op,
mul3_op,
add1_op,
]
program_config = ProgramConfig(
ops=ops,
weights={},
inputs={
"condition0": TensorConfig(
data_gen=partial(generate_condition)
),
"condition1": TensorConfig(
data_gen=partial(generate_condition)
),
"x": TensorConfig(data_gen=partial(generate_value)),
"y": TensorConfig(data_gen=partial(generate_value)),
},
outputs=ops[-1].outputs["Out"],
)
return program_config
def test(self):
self.run_and_statis(
quant=False,
max_examples=25,
passes=["fast_where_xpu_fuse_pass"],
)
if __name__ == "__main__":
np.random.seed(200)
unittest.main()