Unverified commit 8d325d82, authored by csy0225, committed by GitHub

[XPU] Migrate xpu_embedding_with_eltwise_add_fuse_pass (#50590)

Parent d7673e2f
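This commit adds an XPU fuse pass that matches 2, 3, or 4 lookup_table /
lookup_table_v2 ops whose outputs are summed by a chain of elementwise_add ops,
and replaces the whole subgraph with a single embedding_with_eltwise_add_xpu
op. Roughly, for n_embedding = 3 (names follow the pattern's node reprs and are
shown only for illustration):

    embedding_out0 = lookup_table_v2(x0, table0)
    embedding_out1 = lookup_table_v2(x1, table1)
    embedding_out2 = lookup_table_v2(x2, table2)
    ewadd01_out = elementwise_add(embedding_out0, embedding_out1)
    ewadd12_out = elementwise_add(ewadd01_out, embedding_out2)

    becomes

    ewadd12_out = embedding_with_eltwise_add_xpu(ids={x0, x1, x2},
                                                 tables={table0, table1, table2})

The commit also reworks DeleteDropoutOpPass so that the dropout pattern is
matched both with and without a Mask output.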
......@@ -221,6 +221,7 @@ if(WITH_XPU)
     SRCS xpu/pass_utils.cc
     DEPS pass)
   set(XPU_PASS_DEPS xpu_quant_utils xpu_pass_utils)
+  pass_library(embedding_with_eltwise_add_xpu_fuse_pass inference DIR xpu)
   pass_library(fc_xpu_fuse_pass inference DIR xpu DEPS ${XPU_PASS_DEPS})
   pass_library(multi_encoder_xpu_fuse_pass inference DIR xpu DEPS
                ${XPU_PASS_DEPS})
......
......@@ -30,46 +30,50 @@ namespace ir {
 void DeleteDropoutOpPass::ApplyImpl(ir::Graph* graph) const {
   const std::string pattern_name = "delete_dropout_op_pattern";
   FusePassBase::Init(pattern_name, graph);
   int found_subgraph_count = 0;
-  GraphPatternDetector gpd;
-  patterns::DeleteDropoutOpPattern pattern(gpd.mutable_pattern(), pattern_name);
-  pattern();
+  for (auto with_mask : {true, false}) {
+    GraphPatternDetector gpd;
+    patterns::DeleteDropoutOpPattern pattern(gpd.mutable_pattern(),
+                                             pattern_name);
+    pattern(with_mask);
 
-  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
-                     Graph* g) {
-    GET_IR_NODE(dropout_op_x);
-    GET_IR_NODE(dropout_op);
-    GET_IR_NODE(dropout_op_out);
-    GET_IR_NODE(dropout_op_mask);
-    // link dropout_op_out to pre_op
-    auto dropout_op_x_name = dropout_op_x->Var()->Name();
-    auto dropout_op_out_name = dropout_op_out->Var()->Name();
-    auto pre_ops = dropout_op_x->inputs;
-    if (pre_ops.empty()) return;
-    auto pre_op_desc = pre_ops[0]->Op();
-    auto pre_op_outs = pre_op_desc->Outputs();
-    for (auto& out_var : pre_op_outs) {
-      auto names = out_var.second;
-      for (size_t i = 0; i < names.size(); i++) {
-        if (names[i] == dropout_op_x_name) {
-          names[i] = dropout_op_out_name;
-          pre_op_desc->SetOutput(out_var.first, names);
-          break;
-        }
-      }
-    }
-    IR_NODE_LINK_TO(pre_ops[0], dropout_op_out);
-    // delete useless node
-    std::unordered_set<const Node*> delete_nodes{
-        dropout_op_x, dropout_op, dropout_op_mask};
-    GraphSafeRemoveNodes(graph, delete_nodes);
-    found_subgraph_count++;
-  };
-  gpd(graph, handler);
+    auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
+                       Graph* g) {
+      GET_IR_NODE(dropout_op_x);
+      GET_IR_NODE(dropout_op);
+      GET_IR_NODE(dropout_op_out);
+      // link dropout_op_x to next_op
+      auto dropout_op_x_name = dropout_op_x->Var()->Name();
+      auto dropout_op_out_name = dropout_op_out->Var()->Name();
+      auto next_op_nodes = dropout_op_out->outputs;
+      for (auto next_op_node : next_op_nodes) {
+        auto next_op_desc = next_op_node->Op();
+        auto next_op_inputs = next_op_desc->Inputs();
+        for (auto& input_var : next_op_inputs) {
+          auto names = input_var.second;
+          for (size_t i = 0; i < names.size(); i++) {
+            if (names[i] == dropout_op_out_name) {
+              names[i] = dropout_op_x_name;
+              next_op_desc->SetInput(input_var.first, names);
+              break;
+            }
+          }
+        }
+        IR_NODE_LINK_TO(dropout_op_x, next_op_node);
+      }
+      // delete useless node
+      std::unordered_set<const Node*> delete_nodes{dropout_op, dropout_op_out};
+      if (with_mask) {
+        GET_IR_NODE(dropout_op_mask);
+        delete_nodes.insert(dropout_op_mask);
+      }
+      GraphSafeRemoveNodes(graph, delete_nodes);
+      found_subgraph_count++;
+    };
+    gpd(graph, handler);
+  }
   AddStatis(found_subgraph_count);
 }
......
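(Note on the change above: the old pass redirected the producer of dropout's X
to write straight into dropout's Out variable and then deleted dropout_op_x;
the new pass instead rewrites every consumer of dropout's Out to read dropout's
X, deleting dropout_op and dropout_op_out. Running the detector twice, with
with_mask set to true and then false, lets the pass also remove dropout ops
that have no Mask output.)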
......@@ -3032,7 +3032,7 @@ PDNode *patterns::TransposeFlattenConcat::operator()(
   return concat_out;
 }
 
-void patterns::DeleteDropoutOpPattern::operator()() {
+void patterns::DeleteDropoutOpPattern::operator()(bool with_mask) {
   auto dropout_op_x = pattern->NewNode(dropout_op_x_repr())
                           ->assert_is_op_input("dropout", "X")
                           ->AsInput();
......@@ -3042,10 +3042,14 @@ void patterns::DeleteDropoutOpPattern::operator()() {
std::string("upscale_in_train"));
auto dropout_op_out = pattern->NewNode(dropout_op_out_repr())
->assert_is_op_output("dropout", "Out");
auto dropout_op_mask = pattern->NewNode(dropout_op_mask_repr())
->assert_is_op_output("dropout", "Mask");
dropout_op->LinksFrom({dropout_op_x})
.LinksTo({dropout_op_out, dropout_op_mask});
if (with_mask) {
auto dropout_op_mask = pattern->NewNode(dropout_op_mask_repr())
->assert_is_op_output("dropout", "Mask");
dropout_op->LinksFrom({dropout_op_x})
.LinksTo({dropout_op_out, dropout_op_mask});
} else {
dropout_op->LinksFrom({dropout_op_x}).LinksTo({dropout_op_out});
}
}
void patterns::DeleteQuantOpFuse::operator()(PDNode *input_act_node,
......
......@@ -1759,7 +1759,7 @@ struct DeleteDropoutOpPattern : public PatternBase {
   DeleteDropoutOpPattern(PDPattern* pattern, const std::string& name_scope)
       : PatternBase(pattern, name_scope, "delete_dropout_op_pattern") {}
 
-  void operator()();
+  void operator()(bool with_mask);
 
   PATTERN_DECL_NODE(dropout_op_x);
   PATTERN_DECL_NODE(dropout_op);
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <string>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/ir/xpu/pass_utils.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/enforce.h"
namespace phi {
class DenseTensor;
} // namespace phi
namespace paddle {
namespace framework {
class Scope;
} // namespace framework
} // namespace paddle
namespace paddle {
namespace framework {
namespace ir {
static bool GetBoolFromEnv(const std::string& str, bool def = false) {
char* variable = std::getenv(str.c_str());
if (!variable) {
return def;
}
if (strcmp(variable, "false") == 0 || strcmp(variable, "0") == 0) {
return false;
} else {
return true;
}
}
namespace patterns {
struct EmbeddingWithEltwiseAddXPUPattern : public PatternBase {
EmbeddingWithEltwiseAddXPUPattern(PDPattern* pattern,
const std::string& name_scope,
int n_embedding_,
const std::string& op_type,
const std::string& pre_op_type);
// declare operator node's name
PATTERN_DECL_NODE(embedding0);
PATTERN_DECL_NODE(embedding1);
PATTERN_DECL_NODE(ewadd01);
// declare variable node's name
PATTERN_DECL_NODE(x0);
PATTERN_DECL_NODE(x1);
PATTERN_DECL_NODE(table0);
PATTERN_DECL_NODE(table1);
PATTERN_DECL_NODE(embedding_out0);
PATTERN_DECL_NODE(embedding_out1);
PATTERN_DECL_NODE(ewadd01_out);
std::unordered_map<std::string, std::string> node_reprs;
private:
int n_embedding_;
std::string op_type_;
std::string pre_op_type_;
};
EmbeddingWithEltwiseAddXPUPattern::EmbeddingWithEltwiseAddXPUPattern(
PDPattern* pattern,
const std::string& name_scope,
int n_embedding,
const std::string& op_type,
const std::string& pre_op_type)
: PatternBase(pattern, name_scope, name_scope),
n_embedding_(n_embedding),
op_type_(op_type),
pre_op_type_(pre_op_type) {
for (int i = 0; i < n_embedding; i++) {
node_reprs["x" + std::to_string(i)] =
PDNodeName(name_scope_, repr_, id_, "x" + std::to_string(i));
node_reprs["table" + std::to_string(i)] =
PDNodeName(name_scope_, repr_, id_, "table" + std::to_string(i));
node_reprs["embedding" + std::to_string(i)] =
PDNodeName(name_scope_, repr_, id_, "embedding" + std::to_string(i));
node_reprs["embedding_out" + std::to_string(i)] = PDNodeName(
name_scope_, repr_, id_, "embedding_out" + std::to_string(i));
if (i - 1 >= 0) {
auto ewadd_name = string::Sprintf("ewadd%d%d", i - 1, i);
node_reprs[ewadd_name] = PDNodeName(name_scope_, repr_, id_, ewadd_name);
auto ewadd_out_name = string::Sprintf("ewadd%d%d_out", i - 1, i);
node_reprs[ewadd_out_name] =
PDNodeName(name_scope_, repr_, id_, ewadd_out_name);
}
}
PDNode* x0 = pattern->NewNode(x0_repr())
->assert_is_op_input(op_type_, "Ids")
->assert_var_not_persistable()
->AsInput();
PDNode* x1 = pattern->NewNode(x1_repr())
->assert_is_op_input(op_type_, "Ids")
->assert_var_not_persistable()
->AsInput();
PDNode* embedding0 =
pattern->NewNode(embedding0_repr())->assert_is_op(op_type_);
auto* table0 = pattern->NewNode(table0_repr())
->assert_is_op_input(op_type_, "W")
->AsInput();
auto* embedding_out0 = pattern->NewNode(embedding_out0_repr())
->assert_is_op_output(op_type_, "Out")
->assert_is_op_input("elementwise_add", "X");
auto* table1 = pattern->NewNode(table1_repr())
->assert_is_op_input(op_type_, "W")
->AsInput();
auto* embedding1 =
pattern->NewNode(embedding1_repr())->assert_is_op(op_type_);
auto* embedding_out1 = pattern->NewNode(embedding_out1_repr())
->assert_is_op_output(op_type_, "Out")
->assert_is_op_input("elementwise_add", "Y");
auto* ewadd01 =
pattern->NewNode(ewadd01_repr())->assert_is_op("elementwise_add");
auto* ewadd01_out = pattern->NewNode(ewadd01_out_repr())
->assert_is_op_output("elementwise_add", "Out");
embedding0->LinksFrom({x0, table0});
embedding1->LinksFrom({x1, table1});
embedding0->LinksTo({embedding_out0});
embedding1->LinksTo({embedding_out1});
ewadd01->LinksFrom({embedding_out0, embedding_out1});
ewadd01->LinksTo({ewadd01_out});
auto* last_ewadd_out = ewadd01_out;
for (int i = 2; i < n_embedding; ++i) {
auto x_name = node_reprs["x" + std::to_string(i)];
auto table_name = node_reprs["table" + std::to_string(i)];
auto embedding_name = node_reprs["embedding" + std::to_string(i)];
auto embedding_out_name = node_reprs["embedding_out" + std::to_string(i)];
auto* new_table = pattern->NewNode(table_name)
->assert_is_op_input(op_type_, "W")
->AsInput();
auto* new_embedding =
pattern->NewNode(embedding_name)->assert_is_op(op_type_);
auto* new_embedding_out = pattern->NewNode(embedding_out_name)
->assert_is_op_output(op_type_, "Out")
->assert_is_op_input("elementwise_add", "Y");
auto* new_x = pattern->NewNode(x_name)
->assert_is_op_input(op_type_, "Ids")
->AsInput();
new_embedding->LinksFrom({new_x, new_table});
new_embedding->LinksTo({new_embedding_out});
auto ewadd_name =
node_reprs["ewadd" + std::to_string(i - 1) + std::to_string(i)];
auto ewadd_out_name = node_reprs["ewadd" + std::to_string(i - 1) +
std::to_string(i) + "_out"];
auto* new_ewadd =
pattern->NewNode(ewadd_name)->assert_is_op("elementwise_add");
auto* new_ewadd_out = pattern->NewNode(ewadd_out_name)
->assert_is_op_output("elementwise_add", "Out");
new_ewadd->LinksFrom({last_ewadd_out, new_embedding_out});
new_ewadd->LinksTo({new_ewadd_out});
last_ewadd_out = new_ewadd_out;
}
last_ewadd_out->AsOutput();
}
} // namespace patterns
class EmbeddingWithEltwiseAddXPUFusePass : public FusePassBase {
protected:
void ApplyImpl(ir::Graph* graph) const override;
private:
void ApplyImpl(ir::Graph* graph,
int n_embedding,
const std::string op_type,
const std::string pre_op_type) const;
const std::string name_scope_{"embedding_with_eltwise_add_xpu_fuse_pass"};
};
void EmbeddingWithEltwiseAddXPUFusePass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::PreconditionNotMet("graph should not be null."));
FusePassBase::Init(name_scope_, graph);
std::vector<std::string> pre_op_types{"reshape2", "squeeze2", ""};
std::vector<std::string> op_types{"lookup_table", "lookup_table_v2"};
for (auto& pre_op_type : pre_op_types) {
for (int n_embedding : {4, 3, 2}) {
for (auto& op_type : op_types) {
ApplyImpl(graph, n_embedding, op_type, pre_op_type);
}
}
}
}
void EmbeddingWithEltwiseAddXPUFusePass::ApplyImpl(
ir::Graph* graph,
int n_embedding,
const std::string op_type,
const std::string pre_op_type) const {
GraphPatternDetector gpd;
patterns::EmbeddingWithEltwiseAddXPUPattern pattern(
gpd.mutable_pattern(), name_scope_, n_embedding, op_type, pre_op_type);
int found_subgraph_count = 0;
#define GET_IR_NODE_FROM_SUBGRAPH_BY_NAME(name, rt_node, pat) \
PADDLE_ENFORCE_NE( \
subgraph.count(pat.PatternBase::pattern->RetrieveNode(name)), \
0UL, \
platform::errors::NotFound("Node not found for PDNode %s", name)); \
Node* rt_node = subgraph.at(pat.PatternBase::pattern->RetrieveNode(name)); \
PADDLE_ENFORCE_NOT_NULL( \
rt_node, \
platform::errors::NotFound("node %s not exists in the sub-graph", \
#rt_node));
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* graph) {
std::vector<std::string> x_names;
std::vector<std::string> table_names;
std::vector<Node*> x_nodes;
std::vector<Node*> table_nodes;
std::vector<Node*> embedding_nodes;
auto output_name = pattern.node_reprs[string::Sprintf(
"ewadd%d%d_out", n_embedding - 2, n_embedding - 1)];
GET_IR_NODE_FROM_SUBGRAPH_BY_NAME(output_name, output_node, pattern)
std::unordered_set<const Node*> delete_nodes;
for (int i = 0; i < n_embedding; ++i) {
// Ids
auto x_name = pattern.node_reprs["x" + std::to_string(i)];
GET_IR_NODE_FROM_SUBGRAPH_BY_NAME(x_name, x_node, pattern)
x_nodes.push_back(x_node);
x_names.push_back(x_node->Name());
// Tables
auto table_name = pattern.node_reprs["table" + std::to_string(i)];
GET_IR_NODE_FROM_SUBGRAPH_BY_NAME(table_name, table_node, pattern)
table_nodes.push_back(table_node);
table_names.push_back(table_node->Name());
// Embedding
auto embedding_name = pattern.node_reprs["embedding" + std::to_string(i)];
GET_IR_NODE_FROM_SUBGRAPH_BY_NAME(embedding_name, embedding_node, pattern)
embedding_nodes.push_back(embedding_node);
delete_nodes.insert(embedding_node);
auto embedding_out_name =
pattern.node_reprs["embedding_out" + std::to_string(i)];
GET_IR_NODE_FROM_SUBGRAPH_BY_NAME(
embedding_out_name, embedding_out_node, pattern)
delete_nodes.insert(embedding_out_node);
if (i - 1 >= 0) {
auto ewadd_name =
pattern.node_reprs[string::Sprintf("ewadd%d%d", i - 1, i)];
GET_IR_NODE_FROM_SUBGRAPH_BY_NAME(ewadd_name, ewadd_node, pattern)
delete_nodes.insert(ewadd_node);
auto ewadd_out_name =
pattern.node_reprs[string::Sprintf("ewadd%d%d_out", i - 1, i)];
GET_IR_NODE_FROM_SUBGRAPH_BY_NAME(
ewadd_out_name, ewadd_out_node, pattern)
if (i != n_embedding - 1) {
delete_nodes.insert(ewadd_out_node);
}
}
}
// Generate embedding_with_eltwise_add_xpu op
framework::OpDesc embedding_with_eltwise_add_xpu_op_desc;
embedding_with_eltwise_add_xpu_op_desc.SetType(
"embedding_with_eltwise_add_xpu");
embedding_with_eltwise_add_xpu_op_desc.SetInput("ids", x_names);
embedding_with_eltwise_add_xpu_op_desc.SetInput("tables", table_names);
embedding_with_eltwise_add_xpu_op_desc.SetOutput("out",
{output_node->Name()});
embedding_with_eltwise_add_xpu_op_desc.SetAttr("n_embedding", n_embedding);
int64_t padding_idx = PADDLE_GET_CONST(
int64_t, embedding_nodes[0]->Op()->GetAttr("padding_idx"));
if (GetBoolFromEnv("XPU_PADDING_IDX", true)) {
padding_idx = -1;
}
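    // GetBoolFromEnv above defaults to true and treats any value other than
    // "false"/"0" as true, so padding_idx is overridden to -1 unless the
    // XPU_PADDING_IDX environment variable is explicitly set to "false" or
    // "0".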
embedding_with_eltwise_add_xpu_op_desc.SetAttr(
"padding_idx", static_cast<int64_t>(padding_idx));
auto* embedding_with_eltwise_add_xpu_op =
graph->CreateOpNode(&embedding_with_eltwise_add_xpu_op_desc);
for (size_t i = 0; i < x_nodes.size(); i++) {
SAFE_IR_NODE_LINK_TO(x_nodes[i], embedding_with_eltwise_add_xpu_op);
SAFE_IR_NODE_LINK_TO(table_nodes[i], embedding_with_eltwise_add_xpu_op);
}
SAFE_IR_NODE_LINK_TO(embedding_with_eltwise_add_xpu_op, output_node);
// delete useless node
GraphSafeRemoveNodes(graph, delete_nodes);
found_subgraph_count++;
};
gpd(graph, handler);
AddStatis(found_subgraph_count);
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(embedding_with_eltwise_add_xpu_fuse_pass,
paddle::framework::ir::EmbeddingWithEltwiseAddXPUFusePass);
REGISTER_PASS_CAPABILITY(embedding_with_eltwise_add_xpu_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination().EQ(
"embedding_with_eltwise_add_xpu", 0));
......@@ -521,7 +521,7 @@ XpuPassStrategy::XpuPassStrategy() : PassStrategy({}) {
"generate_sequence_xpu_fuse_pass",
"multi_encoder_xpu_fuse_pass",
"multi_encoder_xpu_slice_fuse_pass",
// "embedding_with_eltwise_add_xpu_fuse_pass",
"embedding_with_eltwise_add_xpu_fuse_pass",
"fc_xpu_fuse_pass",
"link_xpu_op_max_pass",
});
......
+- op : embedding_with_eltwise_add_xpu
+  args : (Tensor[] ids, Tensor[] tables, int64_t padding_idx)
+  output: Tensor
+  infer_meta :
+    func: EmbeddingWithEltwiseAddXPUInferMeta
+  kernel:
+    func: embedding_with_eltwise_add_xpu
+    data_type: tables
+
 - op : fc_xpu
   args : (Tensor x, Tensor x_max, Tensor w, Tensor w_max, Tensor bias, int in_num_col_dims, bool transpose_x, float alpha, float beta, int act_type, float act_alpha)
   output : Tensor(out), Tensor(out_max)
......
......@@ -80,6 +80,8 @@ XPUOpMap& get_kl1_ops() {
{"elementwise_pow", XPUKernelSet({phi::DataType::FLOAT32})},
{"elementwise_sub_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"elementwise_sub", XPUKernelSet({phi::DataType::FLOAT32})},
{"embedding_with_eltwise_add_xpu",
XPUKernelSet({phi::DataType::FLOAT32})},
{"equal", XPUKernelSet({phi::DataType::INT64})},
{"expand_as_v2",
XPUKernelSet({phi::DataType::INT32,
......
......@@ -212,6 +212,8 @@ XPUOpMap& get_kl2_ops() {
                  phi::DataType::FLOAT16,
                  phi::DataType::INT64,
                  phi::DataType::INT32})},
+  {"embedding_with_eltwise_add_xpu",
+   XPUKernelSet({phi::DataType::FLOAT32})},
   {"empty",
    XPUKernelSet({phi::DataType::INT64,
                  phi::DataType::INT32,
......
......@@ -21,6 +21,28 @@ limitations under the License. */
 namespace phi {
 
+void EmbeddingWithEltwiseAddXPUInferMeta(
+    const std::vector<const MetaTensor*>& ids,
+    const std::vector<const MetaTensor*>& tables,
+    MetaTensor* out) {
+  PADDLE_ENFORCE_GT(ids.size(),
+                    0UL,
+                    phi::errors::InvalidArgument(
+                        "The input ids in EmbeddingWithEltwiseAddXPUInferMeta "
+                        "can't be empty."));
+  PADDLE_ENFORCE_GT(tables.size(),
+                    0UL,
+                    phi::errors::InvalidArgument(
+                        "The input tables in "
+                        "EmbeddingWithEltwiseAddXPUInferMeta can't be empty."));
+  auto id_dims = ids[0]->dims();
+  auto table_dims = tables[0]->dims();
+  out->set_dims(phi::make_ddim({id_dims[0], id_dims[1], table_dims[1]}));
+  out->set_dtype(tables[0]->dtype());
+  out->set_layout(ids[0]->layout());
+}
+
 void FcXPUInferMeta(const MetaTensor& x,
                     const MetaTensor& x_max,
                     const MetaTensor& w,
......
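(As a concrete check of the inferred shape, using the values from the unit test
below: ids of shape [1, 32] and tables of shape [1000, 32] give an output of
shape [1, 32, 32], i.e. [batch, seq_len, embed_dim], with the dtype taken from
the tables.)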
......@@ -22,6 +22,11 @@ namespace phi {
 // Common InferMeta Functions for fusion operators.
 // NOTE: The InferMeta Functions in this file are arranged in alphabetic order.
 
+void EmbeddingWithEltwiseAddXPUInferMeta(
+    const std::vector<const MetaTensor*>& ids,
+    const std::vector<const MetaTensor*>& tables,
+    MetaTensor* out);
+
 void FcXPUInferMeta(const MetaTensor& x,
                     const MetaTensor& x_max,
                     const MetaTensor& w,
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
namespace fusion {
template <typename T, typename Context>
void EmbeddingWithEltwiseAddXpuKernel(
const Context& ctx,
const std::vector<const DenseTensor*>& ids,
const std::vector<const DenseTensor*>& tables,
int64_t padding_idx,
DenseTensor* out) {
auto& id_dims = ids[0]->dims();
int idx_len = id_dims[0] * id_dims[1];
int emb_layer_num = ids.size();
int embed_dim = tables[0]->dims()[1];
std::vector<int> table_lens_cpu;
std::vector<const float*> arg_tables;
for (auto* table : tables) {
auto& table_dims = table->dims();
PADDLE_ENFORCE_EQ(
table_dims.size(),
2,
errors::InvalidArgument(
"The table_dims size [%d] should be equal 2.",
table_dims.size())); /* shape like [table_len, embed_dim] */
PADDLE_ENFORCE_EQ(
table_dims[1],
embed_dim,
errors::InvalidArgument(
"Every embed_dim [%d] should be equal the first one [%d].",
table_dims[1],
embed_dim));
table_lens_cpu.push_back(table_dims[0]);
arg_tables.push_back(table->data<float>());
}
std::vector<std::vector<int>> int_idx(emb_layer_num,
std::vector<int>(idx_len, 0));
std::vector<xpu::VectorParam<int>> arg_ids;
for (int i = 0; i < emb_layer_num; i++) {
for (int j = 0; j < idx_len; j++) {
int_idx[i][j] = static_cast<int>(ids[i]->data<int64_t>()[j]);
}
arg_ids.push_back(
xpu::VectorParam<int>{int_idx[i].data(), idx_len, nullptr});
}
ctx.template Alloc<T>(out);
int r = xpu::multi_embedding_fusion<float, float, int>(
ctx.x_context(),
arg_tables, /* tables */
out->data<T>(),
arg_ids,
table_lens_cpu,
embed_dim,
std::vector<float>(table_lens_cpu.size(), 1.0f),
std::vector<int>(table_lens_cpu.size(), padding_idx));
PADDLE_ENFORCE_XDNN_SUCCESS(r, "embedding_with_eltwise_add_xpu");
}
} // namespace fusion
} // namespace phi
PD_REGISTER_KERNEL(embedding_with_eltwise_add_xpu,
XPU,
ALL_LAYOUT,
phi::fusion::EmbeddingWithEltwiseAddXpuKernel,
float) {
kernel->InputAt(0).SetBackend(phi::Backend::CPU);
}
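(Two details of the kernel above: the int64 ids are copied and narrowed to
int32 on the host before calling xpu::multi_embedding_fusion, and the
registration pins input 0 (the ids) to phi::Backend::CPU so that the raw id
values are host-accessible inside the kernel.)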
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from functools import partial
import hypothesis.strategies as st
import numpy as np
from auto_scan_test import PassAutoScanTest
from program_config import OpConfig, ProgramConfig, TensorConfig
class TestEmbeddingWithEltwiseAddXPUFusePass(PassAutoScanTest):
def sample_predictor_configs(self, program_config):
config = self.create_inference_config(use_xpu=True)
yield config, ["embedding_with_eltwise_add_xpu"], (1e-3, 1e-3)
def sample_program_config(self, draw):
# lookup_table_v2
lookup_table_num = draw(st.sampled_from([2, 3, 4]))
print("lookup_table_num: ", lookup_table_num)
ids_shape = draw(st.sampled_from([[1, 32]]))
w_shape = draw(st.sampled_from([[1000, 32]]))
padding_idx = draw(st.sampled_from([-1]))
axis = draw(st.sampled_from([-1]))
def gen_lookup_table_ops():
lookup_table_op_config_list = []
lookup_table_op_0 = OpConfig(
"lookup_table_v2",
inputs={
"Ids": ["lookup_table_ids_0"],
"W": ["lookup_table_w_0"],
},
outputs={"Out": ["lookup_table_out_0"]},
padding_idx=padding_idx,
)
lookup_table_op_1 = OpConfig(
"lookup_table_v2",
inputs={
"Ids": ["lookup_table_ids_1"],
"W": ["lookup_table_w_1"],
},
outputs={"Out": ["lookup_table_out_1"]},
padding_idx=padding_idx,
)
lookup_table_ops_list = [lookup_table_op_0, lookup_table_op_1]
if lookup_table_num >= 3:
lookup_table_op_2 = OpConfig(
"lookup_table_v2",
inputs={
"Ids": ["lookup_table_ids_2"],
"W": ["lookup_table_w_2"],
},
outputs={"Out": ["lookup_table_out_2"]},
padding_idx=padding_idx,
)
lookup_table_ops_list.append(lookup_table_op_2)
if lookup_table_num >= 4:
lookup_table_op_3 = OpConfig(
"lookup_table_v2",
inputs={
"Ids": ["lookup_table_ids_3"],
"W": ["lookup_table_w_3"],
},
outputs={"Out": ["lookup_table_out_3"]},
padding_idx=padding_idx,
)
lookup_table_ops_list.append(lookup_table_op_3)
return lookup_table_ops_list
add_op_num = lookup_table_num - 1
def gen_eltwise_add_ops():
add_op_0 = OpConfig(
"elementwise_add",
inputs={
"X": ["lookup_table_out_0"],
"Y": ["lookup_table_out_1"],
},
outputs={"Out": ["add_op_0_out"]},
axis=axis,
)
add_op_list = [add_op_0]
if add_op_num >= 2:
add_op_1 = OpConfig(
"elementwise_add",
inputs={"X": ["add_op_0_out"], "Y": ["lookup_table_out_2"]},
outputs={"Out": ["add_op_1_out"]},
axis=axis,
)
add_op_list.append(add_op_1)
if add_op_num >= 3:
add_op_2 = OpConfig(
"elementwise_add",
inputs={"X": ["add_op_1_out"], "Y": ["lookup_table_out_3"]},
outputs={"Out": ["add_op_2_out"]},
axis=axis,
)
add_op_list.append(add_op_2)
return add_op_list
lookup_table_op_list = gen_lookup_table_ops()
add_op_list = gen_eltwise_add_ops()
# ops
ops = []
ops.extend(lookup_table_op_list)
ops.extend(add_op_list)
# inputs
def generate_input(*args, **kwargs):
return np.random.randint(0, w_shape[0], ids_shape).astype(np.int64)
def gen_lookup_table_inputs_data(*args, **kwargs):
inputs = {}
for i in range(lookup_table_num):
input_name = "lookup_table_ids_{}".format(i)
inputs[input_name] = TensorConfig(
data_gen=partial(generate_input)
)
return inputs
inputs = gen_lookup_table_inputs_data()
# weights
def gen_lookup_table_weights_data():
weights = {}
for i in range(lookup_table_num):
w_name = "lookup_table_w_{}".format(i)
weights[w_name] = TensorConfig(shape=w_shape)
return weights
weights = gen_lookup_table_weights_data()
program_config = ProgramConfig(
ops=ops,
weights=weights,
inputs=inputs,
outputs=add_op_list[-1].outputs["Out"],
)
return program_config
def test(self):
self.run_and_statis(
quant=False,
max_examples=3,
min_success_num=3,
passes=["embedding_with_eltwise_add_xpu_fuse_pass"],
)
if __name__ == "__main__":
unittest.main()