diff --git a/lite/api/paddle_api.h b/lite/api/paddle_api.h index 79ab98da799a99540217d55e3d40b46800f17626..31600bda3017861a9f43b1f5b844ab0157395627 100644 --- a/lite/api/paddle_api.h +++ b/lite/api/paddle_api.h @@ -208,6 +208,8 @@ class LITE_API CxxConfig : public ConfigBase { // current thread. void set_xpu_workspace_l3_size_per_thread(int l3_size = 0xfffc00); // XPU only, specify the target device ID for the current thread. + // **DEPRECATED**, use xpu_set_device() at the very beginning of each worker + // thread void set_xpu_dev_per_thread(int dev_no = 0); }; diff --git a/lite/api/paddle_use_passes.h b/lite/api/paddle_use_passes.h index 82cd7f3d8da5eb4f00c9069731960a81ef9fe87d..e81bebe1a31656409ed718b29b956a7a66560248 100644 --- a/lite/api/paddle_use_passes.h +++ b/lite/api/paddle_use_passes.h @@ -33,6 +33,7 @@ USE_MIR_PASS(lite_transpose_softmax_transpose_fuse_pass); USE_MIR_PASS(lite_interpolate_fuse_pass); USE_MIR_PASS(lite_sequence_pool_concat_fuse_pass); USE_MIR_PASS(identity_scale_eliminate_pass); +USE_MIR_PASS(identity_dropout_eliminate_pass); USE_MIR_PASS(lite_conv_elementwise_fuse_pass); USE_MIR_PASS(lite_conv_activation_fuse_pass); USE_MIR_PASS(lite_var_conv_2d_activation_fuse_pass); @@ -53,3 +54,5 @@ USE_MIR_PASS(apu_subgraph_pass); USE_MIR_PASS(quantized_op_attributes_inference_pass); USE_MIR_PASS(__xpu__resnet_fuse_pass); USE_MIR_PASS(__xpu__multi_encoder_fuse_pass); +USE_MIR_PASS(__xpu__embedding_with_eltwise_add_fuse_pass); +USE_MIR_PASS(__xpu__fc_fuse_pass); diff --git a/lite/core/context.cc b/lite/core/context.cc index be41aa6eb0cb986760f38eaa2bb5b7e017cc4edb..711c67f8b7f36edcd2d66569d964296d96e8d85c 100644 --- a/lite/core/context.cc +++ b/lite/core/context.cc @@ -19,6 +19,7 @@ namespace lite { #ifdef LITE_WITH_XPU thread_local xdnn::Context* Context::_tls_raw_ctx{nullptr}; +int Context::_workspace_l3_size_per_thread{0}; #endif } // namespace lite diff --git a/lite/core/context.h b/lite/core/context.h index fa415c7cc452d524b0b6f1b2ad17418e8cfdade1..d0c1bd93cc7b93628aedc5f549c84d19c44f4f71 100644 --- a/lite/core/context.h +++ b/lite/core/context.h @@ -151,14 +151,23 @@ class Context { if (_tls_raw_ctx == nullptr) { _tls_raw_ctx = xdnn::create_context(); CHECK(_tls_raw_ctx); + int r = xdnn::set_workspace_l3_size(_tls_raw_ctx, + _workspace_l3_size_per_thread); + if (r != 0) { + LOG(WARNING) << "xdnn::set_workspace_l3_size() failed, r = " << r + << ", _workspace_l3_size_per_thread = " + << _workspace_l3_size_per_thread; + } } return _tls_raw_ctx; } static void SetWorkspaceL3Size(int l3_size = 0xfffc00) { - xdnn::set_workspace_l3_size(GetRawContext(), l3_size); + _workspace_l3_size_per_thread = l3_size; } + // **DEPRECATED**, use xpu_set_device() at the very beginning of each worker + // thread static void SetDev(int dev_no = 0) { const char* dev_env = getenv("LITE_XPU_DEV"); if (dev_env) { @@ -173,6 +182,7 @@ class Context { private: static thread_local xdnn::Context* _tls_raw_ctx; + static int _workspace_l3_size_per_thread; }; #endif diff --git a/lite/core/mir/CMakeLists.txt b/lite/core/mir/CMakeLists.txt index d036bf7988b98e64586e42683d33b4696e9ff706..a365fe3f7b8f04b3568fbf2c8f85af4e2469706c 100644 --- a/lite/core/mir/CMakeLists.txt +++ b/lite/core/mir/CMakeLists.txt @@ -23,7 +23,10 @@ lite_cc_library(mir_passes fusion/sequence_pool_concat_fuse_pass.cc fusion/__xpu__resnet_fuse_pass.cc fusion/__xpu__multi_encoder_fuse_pass.cc + fusion/__xpu__embedding_with_eltwise_add_fuse_pass.cc + fusion/__xpu__fc_fuse_pass.cc elimination/identity_scale_eliminate_pass.cc + elimination/identity_dropout_eliminate_pass.cc elimination/elementwise_mul_constant_eliminate_pass.cc static_kernel_pick_pass.cc variable_place_inference_pass.cc diff --git a/lite/core/mir/elimination/identity_dropout_eliminate_pass.cc b/lite/core/mir/elimination/identity_dropout_eliminate_pass.cc new file mode 100644 index 0000000000000000000000000000000000000000..92401df875da1f500ec09b34b2786d15cea2991b --- /dev/null +++ b/lite/core/mir/elimination/identity_dropout_eliminate_pass.cc @@ -0,0 +1,77 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/core/mir/pass.h" +#include "lite/core/mir/pass_registry.h" +#include "lite/core/mir/pattern_matcher_high_api.h" + +namespace paddle { +namespace lite { +namespace mir { + +namespace { + +class Eliminator : public FuseBase { + public: + void BuildPattern() override { + // the previous op's output need updat + auto* pre_op = OpNode("preop")->assert_is_not_op_type("conditional_block"); + // TODO(Superjomn) check has only one output + auto* x = VarNode("x")->assert_is_op_input("dropout", "X"); + auto* dropout_op = OpNode("dropout", "dropout") + ->assert_op_attr("is_test", 1) + ->assert_op_attr( + "dropout_implementation", "upscale_in_train"); + auto* out = VarNode("out")->assert_is_op_output("dropout", "Out"); + auto* mask = VarNode("mask")->assert_is_op_output("dropout", "Mask"); + + *pre_op >> *x >> *dropout_op >> *out; + *dropout_op >> *mask; + + // The pre_op will be eliminated, and a new output-updated op will insert. + x->AsIntermediate(); // x is pre_op's output, need to update + dropout_op->AsIntermediate(); + mask->AsIntermediate(); + } + + private: + void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override { + auto& pre_op = matched.at("preop")->AsStmt(); + auto op_info = *pre_op.op_info(); + + op_info.UpdateAllOutputs(matched.at("x")->AsArg().name, + matched.at("out")->AsArg().name); + pre_op.ResetOp(op_info, graph->valid_places()); + + IR_NODE_LINK_TO(matched.at("preop"), matched.at("out")); + } +}; + +} // namespace + +class IdentityDropoutEliminatePass : public ProgramPass { + public: + void Apply(const std::unique_ptr& graph) override { + Eliminator eliminator; + eliminator(graph.get()); + } +}; + +} // namespace mir +} // namespace lite +} // namespace paddle + +REGISTER_MIR_PASS(identity_dropout_eliminate_pass, + paddle::lite::mir::IdentityDropoutEliminatePass) + .BindTargets({TARGET(kXPU)}); diff --git a/lite/core/mir/fusion/__xpu__embedding_with_eltwise_add_fuse_pass.cc b/lite/core/mir/fusion/__xpu__embedding_with_eltwise_add_fuse_pass.cc new file mode 100644 index 0000000000000000000000000000000000000000..1272ae4c63c2521bf738ca8623fcde2d40014dea --- /dev/null +++ b/lite/core/mir/fusion/__xpu__embedding_with_eltwise_add_fuse_pass.cc @@ -0,0 +1,166 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include "lite/core/mir/pass_registry.h" +#include "lite/core/mir/xpu_pattern_matcher_high_api.h" +#include "lite/utils/string.h" + +namespace paddle { +namespace lite { +namespace mir { + +namespace fusion { + +class XPUEmbeddingWithEltwiseAddFuser : public FuseBase { + public: + explicit XPUEmbeddingWithEltwiseAddFuser(int n_embedding) + : n_embedding_(n_embedding) {} + + void BuildPattern() override { + auto* ids0 = + VarNode("ids0")->assert_is_op_input("lookup_table", "Ids")->AsInput(); + auto* table0 = + VarNode("table0")->assert_is_op_input("lookup_table", "W")->AsInput(); + auto* embedding0 = OpNode("embedding0", "lookup_table"); + auto* embedding_out0 = VarNode("embedding_out0") + ->assert_is_op_output("lookup_table", "Out") + ->assert_is_op_input("elementwise_add", "X") + ->AsIntermediate(); + + auto* ids1 = + VarNode("ids1")->assert_is_op_input("lookup_table", "Ids")->AsInput(); + auto* table1 = + VarNode("table1")->assert_is_op_input("lookup_table", "W")->AsInput(); + auto* embedding1 = OpNode("embedding1", "lookup_table")->AsIntermediate(); + auto* embedding_out1 = VarNode("embedding_out1") + ->assert_is_op_output("lookup_table", "Out") + ->assert_is_op_input("elementwise_add", "Y") + ->AsIntermediate(); + + auto* ewadd01 = OpNode("ewadd01", "elementwise_add")->AsIntermediate(); + auto* ewadd01_out = VarNode("ewadd01_out") + ->assert_is_op_output("elementwise_add", "Out") + ->AsIntermediate(); + + embedding0->LinksFrom({ids0, table0}); + embedding0->LinksTo({embedding_out0}); + embedding1->LinksFrom({ids1, table1}); + embedding1->LinksTo({embedding_out1}); + ewadd01->LinksFrom({embedding_out0, embedding_out1}); + ewadd01->LinksTo({ewadd01_out}); + + auto* last_ewadd_out = ewadd01_out; + for (int i = 2; i < n_embedding_; ++i) { + auto ids_name = paddle::lite::string_format("ids%d", i); + auto table_name = paddle::lite::string_format("table%d", i); + auto embedding_name = paddle::lite::string_format("embedding%d", i); + auto embedding_out_name = + paddle::lite::string_format("embedding_out%d", i); + + auto* new_ids = VarNode(ids_name) + ->assert_is_op_input("lookup_table", "Ids") + ->AsInput(); + auto* new_table = VarNode(table_name) + ->assert_is_op_input("lookup_table", "W") + ->AsInput(); + auto* new_embedding = + OpNode(embedding_name, "lookup_table")->AsIntermediate(); + auto* new_embedding_out = VarNode(embedding_out_name) + ->assert_is_op_output("lookup_table", "Out") + ->assert_is_op_input("elementwise_add", "Y") + ->AsIntermediate(); + + new_embedding->LinksFrom({new_ids, new_table}); + new_embedding->LinksTo({new_embedding_out}); + + auto ewadd_name = paddle::lite::string_format("ewadd%d%d", i - 1, i); + auto ewadd_out_name = ewadd_name + "_out"; + + auto* new_ewadd = OpNode(ewadd_name, "elementwise_add")->AsIntermediate(); + auto* new_ewadd_out = VarNode(ewadd_out_name) + ->assert_is_op_output("elementwise_add", "Out") + ->AsIntermediate(); + + new_ewadd->LinksFrom({last_ewadd_out, new_embedding_out}); + new_ewadd->LinksTo({new_ewadd_out}); + last_ewadd_out = new_ewadd_out; + } + last_ewadd_out->AsOutput(); + } + + void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override { + cpp::OpDesc op_desc; + op_desc.SetType("__xpu__embedding_with_eltwise_add"); + std::vector ids_names; + std::vector table_names; + for (int i = 0; i < n_embedding_; ++i) { + auto ids_name = paddle::lite::string_format("ids%d", i); + ids_names.push_back(matched.at(ids_name)->arg()->name); + auto table_name = paddle::lite::string_format("table%d", i); + table_names.push_back(matched.at(table_name)->arg()->name); + } + op_desc.SetInput("Ids", ids_names); + op_desc.SetInput("Tables", table_names); + auto output_name = paddle::lite::string_format( + "ewadd%d%d_out", n_embedding_ - 2, n_embedding_ - 1); + op_desc.SetOutput("Output", {matched.at(output_name)->arg()->name}); + op_desc.SetAttr("n_embedding", n_embedding_); + auto* embedding0_op_info = matched.at("embedding0")->stmt()->op_info(); + op_desc.SetAttr( + "padding_idx", embedding0_op_info->GetAttr("padding_idx")); + + auto* new_stmt = matched.at("embedding0")->stmt(); + auto new_op = LiteOpRegistry::Global().Create(op_desc.Type()); + new_op->Attach(op_desc, new_stmt->op()->scope()); + new_op->SetValidPlaces(new_stmt->op()->valid_places()); + auto kernels = new_op->CreateKernels(new_op->valid_places()); + new_stmt->SetOp(new_op); + new_stmt->SetKernels(std::move(kernels)); + + for (int i = 0; i < n_embedding_; ++i) { + auto ids_name = paddle::lite::string_format("ids%d", i); + auto table_name = paddle::lite::string_format("table%d", i); + DirectedLink(matched.at(ids_name), matched.at("embedding0")); + DirectedLink(matched.at(table_name), matched.at("embedding0")); + } + IR_OP_VAR_LINK(matched.at("embedding0"), matched.at(output_name)); + } + + private: + int n_embedding_; +}; + +} // namespace fusion + +class XPUEmbeddingWithEltwiseAddFusePass : public ProgramPass { + public: + void Apply(const std::unique_ptr& graph) override { + if (GetBoolFromEnv("XPU_ENABLE_XTCL")) return; + for (int n_embedding : {4, 3}) { + fusion::XPUEmbeddingWithEltwiseAddFuser fuser(n_embedding); + fuser(graph.get()); + } + } +}; + +} // namespace mir +} // namespace lite +} // namespace paddle + +REGISTER_MIR_PASS(__xpu__embedding_with_eltwise_add_fuse_pass, + paddle::lite::mir::XPUEmbeddingWithEltwiseAddFusePass) + .BindTargets({TARGET(kXPU)}) + .BindKernel("lookup_table"); diff --git a/lite/core/mir/fusion/__xpu__fc_fuse_pass.cc b/lite/core/mir/fusion/__xpu__fc_fuse_pass.cc new file mode 100644 index 0000000000000000000000000000000000000000..1e6b28790e1c87f2e9e80acc99f3cf517621c477 --- /dev/null +++ b/lite/core/mir/fusion/__xpu__fc_fuse_pass.cc @@ -0,0 +1,147 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include "lite/backends/xpu/math.h" +#include "lite/core/mir/pass_registry.h" +#include "lite/core/mir/pattern_matcher_high_api.h" + +namespace paddle { +namespace lite { +namespace mir { +namespace fusion { + +class XPUFcFuser : public FuseBase { + public: + explicit XPUFcFuser(bool with_relu) : with_relu_(with_relu) {} + + void BuildPattern() override { + // create nodes. + auto* x = VarNode("x")->assert_is_op_input("mul", "X"); + auto* W = VarNode("W")->assert_is_op_input("mul", "Y"); + auto* b = VarNode("b")->assert_is_persistable_var(); + auto* mul = OpNode("mul", "mul"); + auto* mul_out = VarNode("mul_out"); + auto* add = OpNode("add", "elementwise_add"); + auto* Out = VarNode("Out"); + + // create topology. + std::vector mul_inputs{W, x}; + std::vector add_inputs{mul_out, b}; + mul_inputs >> *mul >> *mul_out; + + // Some op specialities. + mul_out->AsIntermediate(); + mul->AsIntermediate(); + add->AsIntermediate(); + + if (with_relu_) { + auto* add_out = VarNode("add_out"); + auto* relu = OpNode("relu", "relu"); + std::vector relu_inputs{add_out}; + add_inputs >> *add >> *add_out; + relu_inputs >> *relu >> *Out; + add_out->AsIntermediate(); + relu->AsIntermediate(); + } else { + add_inputs >> *add >> *Out; + } + } + + void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override { + auto mul = matched.at("mul")->stmt()->op(); + auto* scope = mul->scope(); + + // convert W from float to int16, and transpose W + auto weight_name = matched.at("W")->arg()->name; + auto* weight_t = scope->FindMutableTensor(weight_name); + auto weight_dims = weight_t->dims(); + int weight_len = weight_t->numel(); + float* weight_on_host = weight_t->mutable_data(); + float max_f = + paddle::lite::xpu::math::FindMaxAbs(weight_on_host, weight_len); + + std::unique_ptr weight_int16(new int16_t[weight_len]); + std::unique_ptr weight_trans_int16(new int16_t[weight_len]); + paddle::lite::xpu::math::ConvertFP32ToInt16( + weight_on_host, weight_int16.get(), max_f, weight_len); + paddle::lite::xpu::math::Transpose(weight_int16.get(), + weight_trans_int16.get(), + weight_dims[0], + weight_dims[1]); + memcpy( + weight_on_host, weight_trans_int16.get(), weight_len * sizeof(int16_t)); + + auto op_desc = GenOpDesc(matched, max_f, true); + auto fc_op = LiteOpRegistry::Global().Create("__xpu__fc"); + auto& valid_places = mul->valid_places(); + fc_op->Attach(op_desc, scope); + + auto* new_op_node = graph->GraphCreateInstructNode(fc_op, valid_places); + + IR_NODE_LINK_TO(matched.at("W"), new_op_node); + IR_NODE_LINK_TO(matched.at("x"), new_op_node); + IR_NODE_LINK_TO(matched.at("b"), new_op_node); + IR_NODE_LINK_TO(new_op_node, matched.at("Out")); + } + + private: + cpp::OpDesc GenOpDesc(const key2nodes_t& matched, + float w_max, + bool transpose_w) { + cpp::OpDesc op_desc = *matched.at("mul")->stmt()->op_info(); + op_desc.mutable_inputs()->clear(); + op_desc.mutable_outputs()->clear(); + op_desc.SetType("__xpu__fc"); + op_desc.SetInput("Input", {matched.at("x")->arg()->name}); + op_desc.SetInput("W", {matched.at("W")->arg()->name}); + op_desc.SetInput("Bias", {matched.at("b")->arg()->name}); + op_desc.SetOutput("Out", {matched.at("Out")->arg()->name}); + op_desc.SetAttr( + "in_num_col_dims", + matched.at("mul")->stmt()->op_info()->GetAttr("x_num_col_dims")); + op_desc.SetAttr("w_max", w_max); + op_desc.SetAttr("transpose_w", transpose_w); + if (with_relu_) { + op_desc.SetAttr("activation_type", std::string{"relu"}); + } + return op_desc; + } + + bool with_relu_; +}; + +} // namespace fusion + +class XPUFcFusePass : public ProgramPass { + public: + void Apply(const std::unique_ptr& graph) override { + if (GetBoolFromEnv("XPU_ENABLE_XTCL")) return; + + fusion::XPUFcFuser fuser(true /* with_relu */); + fuser(graph.get()); + + fusion::XPUFcFuser fuser2(false /* with_relu */); + fuser2(graph.get()); + } +}; + +} // namespace mir +} // namespace lite +} // namespace paddle + +REGISTER_MIR_PASS(__xpu__fc_fuse_pass, paddle::lite::mir::XPUFcFusePass) + .BindTargets({TARGET(kXPU)}) + .BindKernel("fc"); diff --git a/lite/core/mir/fusion/__xpu__multi_encoder_fuse_pass.cc b/lite/core/mir/fusion/__xpu__multi_encoder_fuse_pass.cc index 655274070f1ffcccf39b5f3ff6aaa705c5cbbfda..a6640f107f5dd46e6570a55cf59d2ad69a2bee1a 100644 --- a/lite/core/mir/fusion/__xpu__multi_encoder_fuse_pass.cc +++ b/lite/core/mir/fusion/__xpu__multi_encoder_fuse_pass.cc @@ -16,6 +16,7 @@ #include #include "lite/backends/xpu/math.h" #include "lite/core/mir/pass_registry.h" +#include "lite/core/mir/type_precision_cast_pass.h" // For UpdateInputs() #include "lite/core/mir/xpu_pattern_matcher_high_api.h" #include "lite/operators/subgraph_op.h" @@ -588,8 +589,7 @@ class XPUMultiEncoderFuser { multi_encoder_stmt->SetOp(multi_encoder_op); multi_encoder_stmt->SetKernels(std::move(kernels)); - // temp remove useless cast - std::unordered_set to_remove2; + // remove dangling/useless cast Node* stack = nullptr; for (auto* node : graph->StmtTopologicalOrder()) { CHECK(node->IsStmt()); @@ -597,16 +597,39 @@ class XPUMultiEncoderFuser { stack = node; } } - Node* stack_out = stack->outlinks.front(); - for (Node* cast : stack_out->outlinks) { - Node* cast_out = cast->outlinks.front(); - if (cast_out->outlinks.size() == 0) { - // remove - to_remove2.insert(cast_out); - to_remove2.insert(cast); + if (stack) { + std::unordered_set to_remove2; + Node* stack_out = stack->outlinks.front(); + // avoid modification while traversing + auto stack_out_outlinks = stack_out->outlinks; + for (Node* cast : stack_out_outlinks) { + if (cast->stmt()->op_info()->Type() != "cast") { + continue; + } + + Node* cast_out = cast->outlinks.front(); + if (cast_out->outlinks.size() == 0) { + // dangling cast + to_remove2.insert(cast); + to_remove2.insert(cast_out); + VLOG(3) << "Remove dangling cast [" << cast_out->arg()->name << "]"; + } else if (cast_out->outlinks.size() == 1) { + // useless cast + to_remove2.insert(cast); + to_remove2.insert(cast_out); + VLOG(3) << "Remove useless cast [" << cast_out->arg()->name << "]"; + + auto* multi_encoder = cast_out->outlinks.front(); + DirectedLink(stack_out, multi_encoder); + UpdateInputs(multi_encoder->stmt()->op().get(), + cast_out->arg()->name, + stack_out->arg()->name); + auto update_op_info = *multi_encoder->stmt()->op_info(); + multi_encoder->stmt()->ResetOp(update_op_info, graph->valid_places()); + } } + GraphSafeRemoveNodes(graph, to_remove2); } - GraphSafeRemoveNodes(graph, to_remove2); } }; diff --git a/lite/core/optimizer.h b/lite/core/optimizer.h index 83df76f0230f666ec3857834e234afd921daa927..bd33796f046a7605ff32fb8f2d08610f51976d6c 100644 --- a/lite/core/optimizer.h +++ b/lite/core/optimizer.h @@ -77,6 +77,10 @@ class Optimizer { #endif "__xpu__resnet_fuse_pass", "__xpu__multi_encoder_fuse_pass", + "__xpu__embedding_with_eltwise_add_fuse_pass", + "__xpu__fc_fuse_pass", + "identity_dropout_eliminate_pass", // should be placed after + // xpu fusion "quantized_op_attributes_inference_pass", // Only for fully // quantized model, infer // the output scale and diff --git a/lite/kernels/xpu/CMakeLists.txt b/lite/kernels/xpu/CMakeLists.txt index 07dc127695e3906719b45020a585966877bec868..7ded008387b7d7c92fb2ce6b18e73e1c1e51f29d 100644 --- a/lite/kernels/xpu/CMakeLists.txt +++ b/lite/kernels/xpu/CMakeLists.txt @@ -24,4 +24,6 @@ else() add_kernel(cast_compute_xpu XPU basic SRCS cast_compute.cc DEPS ${lite_kernel_deps}) add_kernel(__xpu__resnet50_compute_xpu XPU extra SRCS __xpu__resnet50_compute.cc DEPS ${lite_kernel_deps}) add_kernel(__xpu__multi_encoder_compute_xpu XPU extra SRCS __xpu__multi_encoder_compute.cc DEPS ${lite_kernel_deps}) + add_kernel(__xpu__embedding_with_eltwise_add_compute_xpu XPU extra SRCS __xpu__embedding_with_eltwise_add_compute.cc DEPS ${lite_kernel_deps}) + add_kernel(__xpu__fc_compute_xpu XPU extra SRCS __xpu__fc_compute.cc DEPS ${lite_kernel_deps}) endif() diff --git a/lite/kernels/xpu/__xpu__embedding_with_eltwise_add_compute.cc b/lite/kernels/xpu/__xpu__embedding_with_eltwise_add_compute.cc new file mode 100644 index 0000000000000000000000000000000000000000..376cdd0dc23426ede42ddac60e061727f73322e3 --- /dev/null +++ b/lite/kernels/xpu/__xpu__embedding_with_eltwise_add_compute.cc @@ -0,0 +1,87 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/kernels/xpu/__xpu__embedding_with_eltwise_add_compute.h" +#include "lite/backends/xpu/xpu_header_sitter.h" +#include "lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace xpu { + +void XPUEmbeddingWithEltwiseAddCompute::PrepareForRun() { + auto& param = this->Param(); + + arg_ids_.reserve(param.Ids.size()); + arg_tables_.reserve(param.Tables.size()); + for (auto* table : param.Tables) { + auto& table_dims = table->dims(); + CHECK_EQ(table_dims.size(), 2); /* shape like [table_len, embed_dim] */ + table_lens_cpu_.push_back(table_dims[0]); + } + void* lens_ptr = nullptr; + size_t lens_size = table_lens_cpu_.size() * sizeof(int); + xpu_malloc(&lens_ptr, lens_size); + xpu_memcpy(lens_ptr, &table_lens_cpu_[0], lens_size, XPU_HOST_TO_DEVICE); + table_lens_guard_.reset(lens_ptr); +} + +void XPUEmbeddingWithEltwiseAddCompute::Run() { + auto& param = this->Param(); + auto& ctx = this->ctx_->As(); + + for (size_t i = 0; i < param.Ids.size(); ++i) { + arg_ids_[i] = param.Ids[i]->data(); + } + for (size_t i = 0; i < param.Tables.size(); ++i) { + arg_tables_[i] = param.Tables[i]->data(); + } + + auto& id_dims = param.Ids[0]->dims(); + auto& table_dims = param.Tables[0]->dims(); + int idx_len = id_dims[0] * id_dims[1]; + int embed_dim = table_dims[1]; + int emb_layer_num = param.Ids.size(); + int r = xdnn::embedding_with_ewadd( + ctx.GetRawContext(), /* context */ + embed_dim, /* embed_dim */ + idx_len, /* idx_len */ + emb_layer_num, /* emb_layer_num */ + param.padding_idx, /* padding_idx */ + &arg_tables_[0], /* tables */ + &arg_ids_[0], /* indices */ + static_cast(table_lens_guard_.get()), /* table_lens */ + nullptr, /* scale_after_emb */ + nullptr, /* scale_after_ewadd */ + param.Out->mutable_data(TARGET(kXPU)) /* top */); + CHECK_EQ(r, 0); +} + +} // namespace xpu +} // namespace kernels +} // namespace lite +} // namespace paddle + +REGISTER_LITE_KERNEL( + __xpu__embedding_with_eltwise_add, + kXPU, + kFloat, + kNCHW, + paddle::lite::kernels::xpu::XPUEmbeddingWithEltwiseAddCompute, + def) + .BindInput("Ids", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kInt64))}) + .BindInput("Tables", {LiteType::GetTensorTy(TARGET(kXPU))}) + .BindOutput("Output", {LiteType::GetTensorTy(TARGET(kXPU))}) + .Finalize(); diff --git a/lite/kernels/xpu/__xpu__embedding_with_eltwise_add_compute.h b/lite/kernels/xpu/__xpu__embedding_with_eltwise_add_compute.h new file mode 100644 index 0000000000000000000000000000000000000000..10ba6e0b5b76a1dbebfd633732f7c36e6ac7c954 --- /dev/null +++ b/lite/kernels/xpu/__xpu__embedding_with_eltwise_add_compute.h @@ -0,0 +1,46 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include "lite/core/kernel.h" +#include "lite/kernels/xpu/utils.h" // XPUFreeDeleter + +namespace paddle { +namespace lite { +namespace kernels { +namespace xpu { + +class XPUEmbeddingWithEltwiseAddCompute + : public KernelLite { + public: + using param_t = operators::XPUEmbeddingWithEltwiseAddParam; + + void PrepareForRun() override; + + void Run() override; + + private: + std::vector arg_ids_; + std::vector arg_tables_; + std::unique_ptr table_lens_guard_; + std::vector table_lens_cpu_; +}; + +} // namespace xpu +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/kernels/xpu/__xpu__fc_compute.cc b/lite/kernels/xpu/__xpu__fc_compute.cc new file mode 100644 index 0000000000000000000000000000000000000000..7d7ec01d36aa58f45954ede6f745d50e6c06df41 --- /dev/null +++ b/lite/kernels/xpu/__xpu__fc_compute.cc @@ -0,0 +1,71 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/kernels/xpu/__xpu__fc_compute.h" +#include "lite/backends/xpu/xpu_header_sitter.h" +#include "lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace xpu { + +void XPUFcCompute::Run() { + auto& param = this->Param(); + auto& ctx = this->ctx_->As(); + + auto input_dims = param.input->dims(); + param.in_mat_dims = input_dims.Flatten2D(param.in_num_col_dims); + int m = param.in_mat_dims[0]; + int k = param.in_mat_dims[1]; + int n = param.w->dims()[1]; + const float* bias = param.bias ? param.bias->data() : nullptr; + xdnn::Activation_t act_type = (param.activation_type == "relu") + ? xdnn::Activation_t::RELU + : xdnn::Activation_t::LINEAR; + + int r = xdnn::fc_int16( + ctx.GetRawContext(), /* context */ + false, /* TransA */ + param.transpose_w, /* TransB */ + m, /* m */ + n, /* n */ + k, /* k */ + 1.0f, /* alpha */ + param.input->data(), /* A */ + reinterpret_cast(param.w->data()), /* B */ + param.w_max, /* max_b */ + 0.0f, /* beta */ + param.output->mutable_data(TARGET(kXPU)), /* C */ + bias, /* bias */ + act_type /* act_type */); + CHECK_EQ(r, 0); +} + +} // namespace xpu +} // namespace kernels +} // namespace lite +} // namespace paddle + +REGISTER_LITE_KERNEL(__xpu__fc, + kXPU, + kFloat, + kNCHW, + paddle::lite::kernels::xpu::XPUFcCompute, + def) + .BindInput("Input", {LiteType::GetTensorTy(TARGET(kXPU))}) + .BindInput("Bias", {LiteType::GetTensorTy(TARGET(kXPU))}) + .BindInput("W", {LiteType::GetTensorTy(TARGET(kXPU))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kXPU))}) + .Finalize(); diff --git a/lite/kernels/xpu/__xpu__fc_compute.h b/lite/kernels/xpu/__xpu__fc_compute.h new file mode 100644 index 0000000000000000000000000000000000000000..73295645ab50dbc1d341479a330ffcfa94dad3f4 --- /dev/null +++ b/lite/kernels/xpu/__xpu__fc_compute.h @@ -0,0 +1,36 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "lite/core/kernel.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace xpu { + +class XPUFcCompute : public KernelLite { + public: + using param_t = operators::XPUFcParam; + + virtual void Run(); + + virtual ~XPUFcCompute() = default; +}; + +} // namespace xpu +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/kernels/xpu/stack_compute.cc b/lite/kernels/xpu/stack_compute.cc index e9e5c19d25135ac5877e38eaf65829fefc500e07..90a6c70b49f39ce744f2a03eec41d79ddc768a19 100644 --- a/lite/kernels/xpu/stack_compute.cc +++ b/lite/kernels/xpu/stack_compute.cc @@ -13,6 +13,7 @@ // limitations under the License. #include "lite/kernels/xpu/stack_compute.h" +#include "lite/backends/xpu/xpu_header_sitter.h" #include "lite/core/op_registry.h" namespace paddle { diff --git a/lite/kernels/xpu/stack_compute.h b/lite/kernels/xpu/stack_compute.h index 6f77cbb3a73bce2d5496f840b2a1f8e14313e776..1ba1d92dc9479cfd00c5e154df7b5476ffd9976c 100644 --- a/lite/kernels/xpu/stack_compute.h +++ b/lite/kernels/xpu/stack_compute.h @@ -16,18 +16,14 @@ #include #include -#include "lite/backends/xpu/xpu_header_sitter.h" #include "lite/core/kernel.h" +#include "lite/kernels/xpu/utils.h" // XPUFreeDeleter namespace paddle { namespace lite { namespace kernels { namespace xpu { -struct XPUFreeDeleter { - void operator()(void* p) const { xpu_free(p); } -}; - class StackCompute : public KernelLite { public: using param_t = operators::StackParam; diff --git a/lite/kernels/xpu/utils.h b/lite/kernels/xpu/utils.h new file mode 100644 index 0000000000000000000000000000000000000000..d410cb1567d5c60aeb52b798d9f17c7f5692e096 --- /dev/null +++ b/lite/kernels/xpu/utils.h @@ -0,0 +1,31 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "lite/backends/xpu/xpu_header_sitter.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace xpu { + +struct XPUFreeDeleter { + void operator()(void* p) const { xpu_free(p); } +}; + +} // namespace xpu +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/operators/CMakeLists.txt b/lite/operators/CMakeLists.txt index c7fa674bff745df29b271e10c8c4d99687a889ed..87f74f9fe7a9211b0dd9925ff588aa6a3595f8f9 100644 --- a/lite/operators/CMakeLists.txt +++ b/lite/operators/CMakeLists.txt @@ -154,6 +154,8 @@ add_operator(sgd_op train SRCS sgd_op.cc DEPS ${op_DEPS}) # Only for XPU add_operator(__xpu__resnet50_op extra SRCS __xpu__resnet50_op.cc DEPS ${op_DEPS}) add_operator(__xpu__multi_encoder_op extra SRCS __xpu__multi_encoder_op.cc DEPS ${op_DEPS}) +add_operator(__xpu__embedding_with_eltwise_add_op extra SRCS __xpu__embedding_with_eltwise_add_op.cc DEPS ${op_DEPS}) +add_operator(__xpu__fc_op extra SRCS __xpu__fc_op.cc DEPS ${op_DEPS}) if (NOT LITE_WITH_X86) lite_cc_test(test_fc_op SRCS fc_op_test.cc diff --git a/lite/operators/__xpu__embedding_with_eltwise_add_op.cc b/lite/operators/__xpu__embedding_with_eltwise_add_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..7c36e7b8157d5d781ad162515364290d8c9ef3be --- /dev/null +++ b/lite/operators/__xpu__embedding_with_eltwise_add_op.cc @@ -0,0 +1,76 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/operators/__xpu__embedding_with_eltwise_add_op.h" +#include "lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace operators { + +bool XPUEmbeddingWithEltwiseAddOp::CheckShape() const { + CHECK_OR_FALSE(param_.Ids.size() == param_.Tables.size()); + + auto& id_dims = param_.Ids[0]->dims(); + auto& table_dims = param_.Tables[0]->dims(); + + int id_rank = id_dims.size(); + + CHECK_EQ_OR_FALSE(table_dims.size(), 2); + CHECK_EQ_OR_FALSE(id_dims[id_rank - 1], 1); + + return true; +} + +bool XPUEmbeddingWithEltwiseAddOp::InferShapeImpl() const { + auto& id_dims = param_.Ids[0]->dims(); + auto& table_dims = param_.Tables[0]->dims(); + + auto out_dims = id_dims; + int id_rank = id_dims.size(); + out_dims[id_rank - 1] = table_dims[1]; + + param_.Out->Resize(out_dims); + param_.Out->set_lod(param_.Ids[0]->lod()); + return true; +} + +bool XPUEmbeddingWithEltwiseAddOp::AttachImpl(const cpp::OpDesc& op_desc, + lite::Scope* scope) { + param_.Out = scope->FindVar(op_desc.Output("Output").front()) + ->GetMutable(); + + param_.Ids.clear(); + for (auto& name : op_desc.Input("Ids")) { + auto t = + const_cast(&scope->FindVar(name)->Get()); + param_.Ids.push_back(t); + } + param_.Tables.clear(); + for (auto& name : op_desc.Input("Tables")) { + auto t = + const_cast(&scope->FindVar(name)->Get()); + param_.Tables.push_back(t); + } + + param_.padding_idx = op_desc.GetAttr("padding_idx"); + return true; +} + +} // namespace operators +} // namespace lite +} // namespace paddle + +REGISTER_LITE_OP(__xpu__embedding_with_eltwise_add, + paddle::lite::operators::XPUEmbeddingWithEltwiseAddOp); diff --git a/lite/operators/__xpu__embedding_with_eltwise_add_op.h b/lite/operators/__xpu__embedding_with_eltwise_add_op.h new file mode 100644 index 0000000000000000000000000000000000000000..6cfea5d3f1f8c5085f0d276c0ba420e03d2c75cb --- /dev/null +++ b/lite/operators/__xpu__embedding_with_eltwise_add_op.h @@ -0,0 +1,46 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include "lite/core/op_lite.h" + +namespace paddle { +namespace lite { +namespace operators { + +class XPUEmbeddingWithEltwiseAddOp : public OpLite { + public: + XPUEmbeddingWithEltwiseAddOp() {} + + explicit XPUEmbeddingWithEltwiseAddOp(const std::string &op_type) + : OpLite(op_type) {} + + bool CheckShape() const override; + + bool InferShapeImpl() const override; + + bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; + + void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } + + std::string DebugString() const override { return "EmbeddingWithEltwiseAdd"; } + + private: + mutable XPUEmbeddingWithEltwiseAddParam param_; +}; + +} // namespace operators +} // namespace lite +} // namespace paddle diff --git a/lite/operators/__xpu__fc_op.cc b/lite/operators/__xpu__fc_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..75a870065570afcdb0c0906458c5922499a33383 --- /dev/null +++ b/lite/operators/__xpu__fc_op.cc @@ -0,0 +1,109 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/operators/__xpu__fc_op.h" +#include +#include "lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace operators { + +bool XPUFcOp::CheckShape() const { + CHECK_OR_FALSE(param_.input); + CHECK_OR_FALSE(param_.output); + CHECK_OR_FALSE(param_.w); + // bias is optional. + + const auto input_dims = param_.input->dims(); + const auto w_dims = param_.w->dims(); + CHECK_EQ_OR_FALSE(w_dims.size(), 2UL); + + int64_t w_dims_1 = w_dims[1]; + if (param_.bias) { + const auto bias_dims = param_.bias->dims(); + if (bias_dims.size() == 2) { + CHECK_EQ_OR_FALSE(bias_dims[0], 1); + CHECK_EQ_OR_FALSE(bias_dims[1], w_dims_1); + } else if (bias_dims.size() == 1) { + CHECK_EQ_OR_FALSE(bias_dims[0], w_dims_1); + } + } + + CHECK_GT_OR_FALSE(input_dims.size(), + static_cast(param_.in_num_col_dims)); + param_.in_mat_dims = input_dims.Flatten2D(param_.in_num_col_dims); + CHECK_EQ_OR_FALSE(param_.in_mat_dims[1], w_dims[0]); + + return true; +} + +bool XPUFcOp::InferShapeImpl() const { + const auto& input_dims = param_.input->dims(); + const auto& w_dims = param_.w->dims(); + int in_num_col_dims = param_.in_num_col_dims; + int64_t w_dims_1 = w_dims[1]; + + // Set output dims + std::vector output_dims(in_num_col_dims + 1); + for (int i = 0; i < in_num_col_dims; ++i) { + output_dims[i] = input_dims[i]; + } + output_dims[in_num_col_dims] = w_dims_1; + param_.output->Resize(output_dims); + + // share LoD + param_.output->set_lod(param_.input->lod()); + + return true; +} + +bool XPUFcOp::AttachImpl(const cpp::OpDesc& op_desc, lite::Scope* scope) { + auto input = op_desc.Input("Input").front(); + auto W = op_desc.Input("W").front(); + auto out = op_desc.Output("Out").front(); + + param_.input = scope->FindVar(input)->GetMutable(); + param_.w = scope->FindVar(W)->GetMutable(); + std::vector input_arg_names = op_desc.InputArgumentNames(); + if (std::find(input_arg_names.begin(), input_arg_names.end(), "Bias") != + input_arg_names.end()) { + auto bias_arguments = op_desc.Input("Bias"); + if (bias_arguments.size() > 0) { + auto bias_var = scope->FindVar(bias_arguments.front()); + if (bias_var != nullptr) { + param_.bias = bias_var->GetMutable(); + } + } + } + CHECK(scope->FindVar(out)); + param_.output = scope->FindVar(out)->GetMutable(); + param_.in_num_col_dims = op_desc.GetAttr("in_num_col_dims"); + param_.w_max = op_desc.GetAttr("w_max"); + + if (op_desc.HasAttr("activation_type")) { + param_.activation_type = op_desc.GetAttr("activation_type"); + } + if (op_desc.HasAttr("transpose_w")) { + param_.transpose_w = op_desc.GetAttr("transpose_w"); + } + + return true; +} + +} // namespace operators +} // namespace lite +} // namespace paddle + +REGISTER_LITE_OP(__xpu__fc, paddle::lite::operators::XPUFcOp); diff --git a/lite/operators/__xpu__fc_op.h b/lite/operators/__xpu__fc_op.h new file mode 100644 index 0000000000000000000000000000000000000000..ee8d857335bc469f2de93dd704331709945a98bc --- /dev/null +++ b/lite/operators/__xpu__fc_op.h @@ -0,0 +1,46 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include "lite/core/op_lite.h" + +namespace paddle { +namespace lite { +namespace operators { + +class XPUFcOp : public OpLite { + public: + XPUFcOp() {} + + explicit XPUFcOp(const std::string &op_type) : OpLite(op_type) {} + + bool CheckShape() const override; + + bool InferShapeImpl() const override; + + bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; + + void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } + + std::string DebugString() const override { return "XPUFc"; } + + private: + mutable XPUFcParam param_; +}; + +} // namespace operators +} // namespace lite +} // namespace paddle diff --git a/lite/operators/op_params.h b/lite/operators/op_params.h index 05bcdd54cdc42b4cc874db2157579cc1cc9a65cb..cfdb0d5389cccda03d304216c4e0a6329e5dc86f 100644 --- a/lite/operators/op_params.h +++ b/lite/operators/op_params.h @@ -1491,6 +1491,26 @@ struct XPUMultiEncoderParam : ParamBase { std::string act_type{}; }; +struct XPUEmbeddingWithEltwiseAddParam : ParamBase { + std::vector Ids; + std::vector Tables; + lite::Tensor* Out{}; + int64_t padding_idx{-1}; +}; + +struct XPUFcParam : ParamBase { + lite::Tensor* input{nullptr}; + lite::Tensor* w{nullptr}; + lite::Tensor* bias{nullptr}; + lite::Tensor* output{nullptr}; + + int in_num_col_dims{1}; + lite::DDim in_mat_dims; + float w_max{0.0f}; + bool transpose_w{true}; + std::string activation_type{""}; +}; + } // namespace operators } // namespace lite } // namespace paddle