未验证 提交 d5a6a1e5 编写于 作者: C Cwndmiao 提交者: GitHub

[XPU] Add more XPU op kernels (#3457)

上级 807454fd
......@@ -208,6 +208,8 @@ class LITE_API CxxConfig : public ConfigBase {
// current thread.
void set_xpu_workspace_l3_size_per_thread(int l3_size = 0xfffc00);
// XPU only, specify the target device ID for the current thread.
// **DEPRECATED**, use xpu_set_device() at the very beginning of each worker
// thread
void set_xpu_dev_per_thread(int dev_no = 0);
};
......
......@@ -33,6 +33,7 @@ USE_MIR_PASS(lite_transpose_softmax_transpose_fuse_pass);
USE_MIR_PASS(lite_interpolate_fuse_pass);
USE_MIR_PASS(lite_sequence_pool_concat_fuse_pass);
USE_MIR_PASS(identity_scale_eliminate_pass);
USE_MIR_PASS(identity_dropout_eliminate_pass);
USE_MIR_PASS(lite_conv_elementwise_fuse_pass);
USE_MIR_PASS(lite_conv_activation_fuse_pass);
USE_MIR_PASS(lite_var_conv_2d_activation_fuse_pass);
......@@ -53,3 +54,5 @@ USE_MIR_PASS(apu_subgraph_pass);
USE_MIR_PASS(quantized_op_attributes_inference_pass);
USE_MIR_PASS(__xpu__resnet_fuse_pass);
USE_MIR_PASS(__xpu__multi_encoder_fuse_pass);
USE_MIR_PASS(__xpu__embedding_with_eltwise_add_fuse_pass);
USE_MIR_PASS(__xpu__fc_fuse_pass);
......@@ -19,6 +19,7 @@ namespace lite {
#ifdef LITE_WITH_XPU
thread_local xdnn::Context* Context<TargetType::kXPU>::_tls_raw_ctx{nullptr};
int Context<TargetType::kXPU>::_workspace_l3_size_per_thread{0};
#endif
} // namespace lite
......
......@@ -151,14 +151,23 @@ class Context<TargetType::kXPU> {
if (_tls_raw_ctx == nullptr) {
_tls_raw_ctx = xdnn::create_context();
CHECK(_tls_raw_ctx);
int r = xdnn::set_workspace_l3_size(_tls_raw_ctx,
_workspace_l3_size_per_thread);
if (r != 0) {
LOG(WARNING) << "xdnn::set_workspace_l3_size() failed, r = " << r
<< ", _workspace_l3_size_per_thread = "
<< _workspace_l3_size_per_thread;
}
}
return _tls_raw_ctx;
}
static void SetWorkspaceL3Size(int l3_size = 0xfffc00) {
xdnn::set_workspace_l3_size(GetRawContext(), l3_size);
_workspace_l3_size_per_thread = l3_size;
}
// **DEPRECATED**, use xpu_set_device() at the very beginning of each worker
// thread
static void SetDev(int dev_no = 0) {
const char* dev_env = getenv("LITE_XPU_DEV");
if (dev_env) {
......@@ -173,6 +182,7 @@ class Context<TargetType::kXPU> {
private:
static thread_local xdnn::Context* _tls_raw_ctx;
static int _workspace_l3_size_per_thread;
};
#endif
......
......@@ -23,7 +23,10 @@ lite_cc_library(mir_passes
fusion/sequence_pool_concat_fuse_pass.cc
fusion/__xpu__resnet_fuse_pass.cc
fusion/__xpu__multi_encoder_fuse_pass.cc
fusion/__xpu__embedding_with_eltwise_add_fuse_pass.cc
fusion/__xpu__fc_fuse_pass.cc
elimination/identity_scale_eliminate_pass.cc
elimination/identity_dropout_eliminate_pass.cc
elimination/elementwise_mul_constant_eliminate_pass.cc
static_kernel_pick_pass.cc
variable_place_inference_pass.cc
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/core/mir/pass.h"
#include "lite/core/mir/pass_registry.h"
#include "lite/core/mir/pattern_matcher_high_api.h"
namespace paddle {
namespace lite {
namespace mir {
namespace {
class Eliminator : public FuseBase {
public:
void BuildPattern() override {
// the previous op's output need updat
auto* pre_op = OpNode("preop")->assert_is_not_op_type("conditional_block");
// TODO(Superjomn) check has only one output
auto* x = VarNode("x")->assert_is_op_input("dropout", "X");
auto* dropout_op = OpNode("dropout", "dropout")
->assert_op_attr<int>("is_test", 1)
->assert_op_attr<std::string>(
"dropout_implementation", "upscale_in_train");
auto* out = VarNode("out")->assert_is_op_output("dropout", "Out");
auto* mask = VarNode("mask")->assert_is_op_output("dropout", "Mask");
*pre_op >> *x >> *dropout_op >> *out;
*dropout_op >> *mask;
// The pre_op will be eliminated, and a new output-updated op will insert.
x->AsIntermediate(); // x is pre_op's output, need to update
dropout_op->AsIntermediate();
mask->AsIntermediate();
}
private:
void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override {
auto& pre_op = matched.at("preop")->AsStmt();
auto op_info = *pre_op.op_info();
op_info.UpdateAllOutputs(matched.at("x")->AsArg().name,
matched.at("out")->AsArg().name);
pre_op.ResetOp(op_info, graph->valid_places());
IR_NODE_LINK_TO(matched.at("preop"), matched.at("out"));
}
};
} // namespace
class IdentityDropoutEliminatePass : public ProgramPass {
public:
void Apply(const std::unique_ptr<SSAGraph>& graph) override {
Eliminator eliminator;
eliminator(graph.get());
}
};
} // namespace mir
} // namespace lite
} // namespace paddle
REGISTER_MIR_PASS(identity_dropout_eliminate_pass,
paddle::lite::mir::IdentityDropoutEliminatePass)
.BindTargets({TARGET(kXPU)});
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <memory>
#include <vector>
#include "lite/core/mir/pass_registry.h"
#include "lite/core/mir/xpu_pattern_matcher_high_api.h"
#include "lite/utils/string.h"
namespace paddle {
namespace lite {
namespace mir {
namespace fusion {
class XPUEmbeddingWithEltwiseAddFuser : public FuseBase {
public:
explicit XPUEmbeddingWithEltwiseAddFuser(int n_embedding)
: n_embedding_(n_embedding) {}
void BuildPattern() override {
auto* ids0 =
VarNode("ids0")->assert_is_op_input("lookup_table", "Ids")->AsInput();
auto* table0 =
VarNode("table0")->assert_is_op_input("lookup_table", "W")->AsInput();
auto* embedding0 = OpNode("embedding0", "lookup_table");
auto* embedding_out0 = VarNode("embedding_out0")
->assert_is_op_output("lookup_table", "Out")
->assert_is_op_input("elementwise_add", "X")
->AsIntermediate();
auto* ids1 =
VarNode("ids1")->assert_is_op_input("lookup_table", "Ids")->AsInput();
auto* table1 =
VarNode("table1")->assert_is_op_input("lookup_table", "W")->AsInput();
auto* embedding1 = OpNode("embedding1", "lookup_table")->AsIntermediate();
auto* embedding_out1 = VarNode("embedding_out1")
->assert_is_op_output("lookup_table", "Out")
->assert_is_op_input("elementwise_add", "Y")
->AsIntermediate();
auto* ewadd01 = OpNode("ewadd01", "elementwise_add")->AsIntermediate();
auto* ewadd01_out = VarNode("ewadd01_out")
->assert_is_op_output("elementwise_add", "Out")
->AsIntermediate();
embedding0->LinksFrom({ids0, table0});
embedding0->LinksTo({embedding_out0});
embedding1->LinksFrom({ids1, table1});
embedding1->LinksTo({embedding_out1});
ewadd01->LinksFrom({embedding_out0, embedding_out1});
ewadd01->LinksTo({ewadd01_out});
auto* last_ewadd_out = ewadd01_out;
for (int i = 2; i < n_embedding_; ++i) {
auto ids_name = paddle::lite::string_format("ids%d", i);
auto table_name = paddle::lite::string_format("table%d", i);
auto embedding_name = paddle::lite::string_format("embedding%d", i);
auto embedding_out_name =
paddle::lite::string_format("embedding_out%d", i);
auto* new_ids = VarNode(ids_name)
->assert_is_op_input("lookup_table", "Ids")
->AsInput();
auto* new_table = VarNode(table_name)
->assert_is_op_input("lookup_table", "W")
->AsInput();
auto* new_embedding =
OpNode(embedding_name, "lookup_table")->AsIntermediate();
auto* new_embedding_out = VarNode(embedding_out_name)
->assert_is_op_output("lookup_table", "Out")
->assert_is_op_input("elementwise_add", "Y")
->AsIntermediate();
new_embedding->LinksFrom({new_ids, new_table});
new_embedding->LinksTo({new_embedding_out});
auto ewadd_name = paddle::lite::string_format("ewadd%d%d", i - 1, i);
auto ewadd_out_name = ewadd_name + "_out";
auto* new_ewadd = OpNode(ewadd_name, "elementwise_add")->AsIntermediate();
auto* new_ewadd_out = VarNode(ewadd_out_name)
->assert_is_op_output("elementwise_add", "Out")
->AsIntermediate();
new_ewadd->LinksFrom({last_ewadd_out, new_embedding_out});
new_ewadd->LinksTo({new_ewadd_out});
last_ewadd_out = new_ewadd_out;
}
last_ewadd_out->AsOutput();
}
void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override {
cpp::OpDesc op_desc;
op_desc.SetType("__xpu__embedding_with_eltwise_add");
std::vector<std::string> ids_names;
std::vector<std::string> table_names;
for (int i = 0; i < n_embedding_; ++i) {
auto ids_name = paddle::lite::string_format("ids%d", i);
ids_names.push_back(matched.at(ids_name)->arg()->name);
auto table_name = paddle::lite::string_format("table%d", i);
table_names.push_back(matched.at(table_name)->arg()->name);
}
op_desc.SetInput("Ids", ids_names);
op_desc.SetInput("Tables", table_names);
auto output_name = paddle::lite::string_format(
"ewadd%d%d_out", n_embedding_ - 2, n_embedding_ - 1);
op_desc.SetOutput("Output", {matched.at(output_name)->arg()->name});
op_desc.SetAttr<int>("n_embedding", n_embedding_);
auto* embedding0_op_info = matched.at("embedding0")->stmt()->op_info();
op_desc.SetAttr<int64_t>(
"padding_idx", embedding0_op_info->GetAttr<int64_t>("padding_idx"));
auto* new_stmt = matched.at("embedding0")->stmt();
auto new_op = LiteOpRegistry::Global().Create(op_desc.Type());
new_op->Attach(op_desc, new_stmt->op()->scope());
new_op->SetValidPlaces(new_stmt->op()->valid_places());
auto kernels = new_op->CreateKernels(new_op->valid_places());
new_stmt->SetOp(new_op);
new_stmt->SetKernels(std::move(kernels));
for (int i = 0; i < n_embedding_; ++i) {
auto ids_name = paddle::lite::string_format("ids%d", i);
auto table_name = paddle::lite::string_format("table%d", i);
DirectedLink(matched.at(ids_name), matched.at("embedding0"));
DirectedLink(matched.at(table_name), matched.at("embedding0"));
}
IR_OP_VAR_LINK(matched.at("embedding0"), matched.at(output_name));
}
private:
int n_embedding_;
};
} // namespace fusion
class XPUEmbeddingWithEltwiseAddFusePass : public ProgramPass {
public:
void Apply(const std::unique_ptr<SSAGraph>& graph) override {
if (GetBoolFromEnv("XPU_ENABLE_XTCL")) return;
for (int n_embedding : {4, 3}) {
fusion::XPUEmbeddingWithEltwiseAddFuser fuser(n_embedding);
fuser(graph.get());
}
}
};
} // namespace mir
} // namespace lite
} // namespace paddle
REGISTER_MIR_PASS(__xpu__embedding_with_eltwise_add_fuse_pass,
paddle::lite::mir::XPUEmbeddingWithEltwiseAddFusePass)
.BindTargets({TARGET(kXPU)})
.BindKernel("lookup_table");
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <memory>
#include <string>
#include "lite/backends/xpu/math.h"
#include "lite/core/mir/pass_registry.h"
#include "lite/core/mir/pattern_matcher_high_api.h"
namespace paddle {
namespace lite {
namespace mir {
namespace fusion {
class XPUFcFuser : public FuseBase {
public:
explicit XPUFcFuser(bool with_relu) : with_relu_(with_relu) {}
void BuildPattern() override {
// create nodes.
auto* x = VarNode("x")->assert_is_op_input("mul", "X");
auto* W = VarNode("W")->assert_is_op_input("mul", "Y");
auto* b = VarNode("b")->assert_is_persistable_var();
auto* mul = OpNode("mul", "mul");
auto* mul_out = VarNode("mul_out");
auto* add = OpNode("add", "elementwise_add");
auto* Out = VarNode("Out");
// create topology.
std::vector<PMNode*> mul_inputs{W, x};
std::vector<PMNode*> add_inputs{mul_out, b};
mul_inputs >> *mul >> *mul_out;
// Some op specialities.
mul_out->AsIntermediate();
mul->AsIntermediate();
add->AsIntermediate();
if (with_relu_) {
auto* add_out = VarNode("add_out");
auto* relu = OpNode("relu", "relu");
std::vector<PMNode*> relu_inputs{add_out};
add_inputs >> *add >> *add_out;
relu_inputs >> *relu >> *Out;
add_out->AsIntermediate();
relu->AsIntermediate();
} else {
add_inputs >> *add >> *Out;
}
}
void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override {
auto mul = matched.at("mul")->stmt()->op();
auto* scope = mul->scope();
// convert W from float to int16, and transpose W
auto weight_name = matched.at("W")->arg()->name;
auto* weight_t = scope->FindMutableTensor(weight_name);
auto weight_dims = weight_t->dims();
int weight_len = weight_t->numel();
float* weight_on_host = weight_t->mutable_data<float>();
float max_f =
paddle::lite::xpu::math::FindMaxAbs(weight_on_host, weight_len);
std::unique_ptr<int16_t[]> weight_int16(new int16_t[weight_len]);
std::unique_ptr<int16_t[]> weight_trans_int16(new int16_t[weight_len]);
paddle::lite::xpu::math::ConvertFP32ToInt16(
weight_on_host, weight_int16.get(), max_f, weight_len);
paddle::lite::xpu::math::Transpose(weight_int16.get(),
weight_trans_int16.get(),
weight_dims[0],
weight_dims[1]);
memcpy(
weight_on_host, weight_trans_int16.get(), weight_len * sizeof(int16_t));
auto op_desc = GenOpDesc(matched, max_f, true);
auto fc_op = LiteOpRegistry::Global().Create("__xpu__fc");
auto& valid_places = mul->valid_places();
fc_op->Attach(op_desc, scope);
auto* new_op_node = graph->GraphCreateInstructNode(fc_op, valid_places);
IR_NODE_LINK_TO(matched.at("W"), new_op_node);
IR_NODE_LINK_TO(matched.at("x"), new_op_node);
IR_NODE_LINK_TO(matched.at("b"), new_op_node);
IR_NODE_LINK_TO(new_op_node, matched.at("Out"));
}
private:
cpp::OpDesc GenOpDesc(const key2nodes_t& matched,
float w_max,
bool transpose_w) {
cpp::OpDesc op_desc = *matched.at("mul")->stmt()->op_info();
op_desc.mutable_inputs()->clear();
op_desc.mutable_outputs()->clear();
op_desc.SetType("__xpu__fc");
op_desc.SetInput("Input", {matched.at("x")->arg()->name});
op_desc.SetInput("W", {matched.at("W")->arg()->name});
op_desc.SetInput("Bias", {matched.at("b")->arg()->name});
op_desc.SetOutput("Out", {matched.at("Out")->arg()->name});
op_desc.SetAttr(
"in_num_col_dims",
matched.at("mul")->stmt()->op_info()->GetAttr<int>("x_num_col_dims"));
op_desc.SetAttr("w_max", w_max);
op_desc.SetAttr("transpose_w", transpose_w);
if (with_relu_) {
op_desc.SetAttr("activation_type", std::string{"relu"});
}
return op_desc;
}
bool with_relu_;
};
} // namespace fusion
class XPUFcFusePass : public ProgramPass {
public:
void Apply(const std::unique_ptr<SSAGraph>& graph) override {
if (GetBoolFromEnv("XPU_ENABLE_XTCL")) return;
fusion::XPUFcFuser fuser(true /* with_relu */);
fuser(graph.get());
fusion::XPUFcFuser fuser2(false /* with_relu */);
fuser2(graph.get());
}
};
} // namespace mir
} // namespace lite
} // namespace paddle
REGISTER_MIR_PASS(__xpu__fc_fuse_pass, paddle::lite::mir::XPUFcFusePass)
.BindTargets({TARGET(kXPU)})
.BindKernel("fc");
......@@ -16,6 +16,7 @@
#include <vector>
#include "lite/backends/xpu/math.h"
#include "lite/core/mir/pass_registry.h"
#include "lite/core/mir/type_precision_cast_pass.h" // For UpdateInputs()
#include "lite/core/mir/xpu_pattern_matcher_high_api.h"
#include "lite/operators/subgraph_op.h"
......@@ -588,8 +589,7 @@ class XPUMultiEncoderFuser {
multi_encoder_stmt->SetOp(multi_encoder_op);
multi_encoder_stmt->SetKernels(std::move(kernels));
// temp remove useless cast
std::unordered_set<const Node*> to_remove2;
// remove dangling/useless cast
Node* stack = nullptr;
for (auto* node : graph->StmtTopologicalOrder()) {
CHECK(node->IsStmt());
......@@ -597,16 +597,39 @@ class XPUMultiEncoderFuser {
stack = node;
}
}
Node* stack_out = stack->outlinks.front();
for (Node* cast : stack_out->outlinks) {
Node* cast_out = cast->outlinks.front();
if (cast_out->outlinks.size() == 0) {
// remove
to_remove2.insert(cast_out);
to_remove2.insert(cast);
if (stack) {
std::unordered_set<const Node*> to_remove2;
Node* stack_out = stack->outlinks.front();
// avoid modification while traversing
auto stack_out_outlinks = stack_out->outlinks;
for (Node* cast : stack_out_outlinks) {
if (cast->stmt()->op_info()->Type() != "cast") {
continue;
}
Node* cast_out = cast->outlinks.front();
if (cast_out->outlinks.size() == 0) {
// dangling cast
to_remove2.insert(cast);
to_remove2.insert(cast_out);
VLOG(3) << "Remove dangling cast [" << cast_out->arg()->name << "]";
} else if (cast_out->outlinks.size() == 1) {
// useless cast
to_remove2.insert(cast);
to_remove2.insert(cast_out);
VLOG(3) << "Remove useless cast [" << cast_out->arg()->name << "]";
auto* multi_encoder = cast_out->outlinks.front();
DirectedLink(stack_out, multi_encoder);
UpdateInputs(multi_encoder->stmt()->op().get(),
cast_out->arg()->name,
stack_out->arg()->name);
auto update_op_info = *multi_encoder->stmt()->op_info();
multi_encoder->stmt()->ResetOp(update_op_info, graph->valid_places());
}
}
GraphSafeRemoveNodes(graph, to_remove2);
}
GraphSafeRemoveNodes(graph, to_remove2);
}
};
......
......@@ -77,6 +77,10 @@ class Optimizer {
#endif
"__xpu__resnet_fuse_pass",
"__xpu__multi_encoder_fuse_pass",
"__xpu__embedding_with_eltwise_add_fuse_pass",
"__xpu__fc_fuse_pass",
"identity_dropout_eliminate_pass", // should be placed after
// xpu fusion
"quantized_op_attributes_inference_pass", // Only for fully
// quantized model, infer
// the output scale and
......
......@@ -24,4 +24,6 @@ else()
add_kernel(cast_compute_xpu XPU basic SRCS cast_compute.cc DEPS ${lite_kernel_deps})
add_kernel(__xpu__resnet50_compute_xpu XPU extra SRCS __xpu__resnet50_compute.cc DEPS ${lite_kernel_deps})
add_kernel(__xpu__multi_encoder_compute_xpu XPU extra SRCS __xpu__multi_encoder_compute.cc DEPS ${lite_kernel_deps})
add_kernel(__xpu__embedding_with_eltwise_add_compute_xpu XPU extra SRCS __xpu__embedding_with_eltwise_add_compute.cc DEPS ${lite_kernel_deps})
add_kernel(__xpu__fc_compute_xpu XPU extra SRCS __xpu__fc_compute.cc DEPS ${lite_kernel_deps})
endif()
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/xpu/__xpu__embedding_with_eltwise_add_compute.h"
#include "lite/backends/xpu/xpu_header_sitter.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace xpu {
void XPUEmbeddingWithEltwiseAddCompute::PrepareForRun() {
auto& param = this->Param<param_t>();
arg_ids_.reserve(param.Ids.size());
arg_tables_.reserve(param.Tables.size());
for (auto* table : param.Tables) {
auto& table_dims = table->dims();
CHECK_EQ(table_dims.size(), 2); /* shape like [table_len, embed_dim] */
table_lens_cpu_.push_back(table_dims[0]);
}
void* lens_ptr = nullptr;
size_t lens_size = table_lens_cpu_.size() * sizeof(int);
xpu_malloc(&lens_ptr, lens_size);
xpu_memcpy(lens_ptr, &table_lens_cpu_[0], lens_size, XPU_HOST_TO_DEVICE);
table_lens_guard_.reset(lens_ptr);
}
void XPUEmbeddingWithEltwiseAddCompute::Run() {
auto& param = this->Param<param_t>();
auto& ctx = this->ctx_->As<XPUContext>();
for (size_t i = 0; i < param.Ids.size(); ++i) {
arg_ids_[i] = param.Ids[i]->data<int64_t>();
}
for (size_t i = 0; i < param.Tables.size(); ++i) {
arg_tables_[i] = param.Tables[i]->data<float>();
}
auto& id_dims = param.Ids[0]->dims();
auto& table_dims = param.Tables[0]->dims();
int idx_len = id_dims[0] * id_dims[1];
int embed_dim = table_dims[1];
int emb_layer_num = param.Ids.size();
int r = xdnn::embedding_with_ewadd<float, int64_t, false, false>(
ctx.GetRawContext(), /* context */
embed_dim, /* embed_dim */
idx_len, /* idx_len */
emb_layer_num, /* emb_layer_num */
param.padding_idx, /* padding_idx */
&arg_tables_[0], /* tables */
&arg_ids_[0], /* indices */
static_cast<int*>(table_lens_guard_.get()), /* table_lens */
nullptr, /* scale_after_emb */
nullptr, /* scale_after_ewadd */
param.Out->mutable_data<float>(TARGET(kXPU)) /* top */);
CHECK_EQ(r, 0);
}
} // namespace xpu
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(
__xpu__embedding_with_eltwise_add,
kXPU,
kFloat,
kNCHW,
paddle::lite::kernels::xpu::XPUEmbeddingWithEltwiseAddCompute,
def)
.BindInput("Ids", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kInt64))})
.BindInput("Tables", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindOutput("Output", {LiteType::GetTensorTy(TARGET(kXPU))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <vector>
#include "lite/core/kernel.h"
#include "lite/kernels/xpu/utils.h" // XPUFreeDeleter
namespace paddle {
namespace lite {
namespace kernels {
namespace xpu {
class XPUEmbeddingWithEltwiseAddCompute
: public KernelLite<TARGET(kXPU), PRECISION(kFloat)> {
public:
using param_t = operators::XPUEmbeddingWithEltwiseAddParam;
void PrepareForRun() override;
void Run() override;
private:
std::vector<const int64_t*> arg_ids_;
std::vector<const float*> arg_tables_;
std::unique_ptr<void, XPUFreeDeleter> table_lens_guard_;
std::vector<int> table_lens_cpu_;
};
} // namespace xpu
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/xpu/__xpu__fc_compute.h"
#include "lite/backends/xpu/xpu_header_sitter.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace xpu {
void XPUFcCompute::Run() {
auto& param = this->Param<param_t>();
auto& ctx = this->ctx_->As<XPUContext>();
auto input_dims = param.input->dims();
param.in_mat_dims = input_dims.Flatten2D(param.in_num_col_dims);
int m = param.in_mat_dims[0];
int k = param.in_mat_dims[1];
int n = param.w->dims()[1];
const float* bias = param.bias ? param.bias->data<float>() : nullptr;
xdnn::Activation_t act_type = (param.activation_type == "relu")
? xdnn::Activation_t::RELU
: xdnn::Activation_t::LINEAR;
int r = xdnn::fc_int16(
ctx.GetRawContext(), /* context */
false, /* TransA */
param.transpose_w, /* TransB */
m, /* m */
n, /* n */
k, /* k */
1.0f, /* alpha */
param.input->data<float>(), /* A */
reinterpret_cast<const int16_t*>(param.w->data<float>()), /* B */
param.w_max, /* max_b */
0.0f, /* beta */
param.output->mutable_data<float>(TARGET(kXPU)), /* C */
bias, /* bias */
act_type /* act_type */);
CHECK_EQ(r, 0);
}
} // namespace xpu
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(__xpu__fc,
kXPU,
kFloat,
kNCHW,
paddle::lite::kernels::xpu::XPUFcCompute,
def)
.BindInput("Input", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindInput("Bias", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindInput("W", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kXPU))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "lite/core/kernel.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace xpu {
class XPUFcCompute : public KernelLite<TARGET(kXPU), PRECISION(kFloat)> {
public:
using param_t = operators::XPUFcParam;
virtual void Run();
virtual ~XPUFcCompute() = default;
};
} // namespace xpu
} // namespace kernels
} // namespace lite
} // namespace paddle
......@@ -13,6 +13,7 @@
// limitations under the License.
#include "lite/kernels/xpu/stack_compute.h"
#include "lite/backends/xpu/xpu_header_sitter.h"
#include "lite/core/op_registry.h"
namespace paddle {
......
......@@ -16,18 +16,14 @@
#include <memory>
#include <vector>
#include "lite/backends/xpu/xpu_header_sitter.h"
#include "lite/core/kernel.h"
#include "lite/kernels/xpu/utils.h" // XPUFreeDeleter
namespace paddle {
namespace lite {
namespace kernels {
namespace xpu {
struct XPUFreeDeleter {
void operator()(void* p) const { xpu_free(p); }
};
class StackCompute : public KernelLite<TARGET(kXPU), PRECISION(kFloat)> {
public:
using param_t = operators::StackParam;
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "lite/backends/xpu/xpu_header_sitter.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace xpu {
struct XPUFreeDeleter {
void operator()(void* p) const { xpu_free(p); }
};
} // namespace xpu
} // namespace kernels
} // namespace lite
} // namespace paddle
......@@ -154,6 +154,8 @@ add_operator(sgd_op train SRCS sgd_op.cc DEPS ${op_DEPS})
# Only for XPU
add_operator(__xpu__resnet50_op extra SRCS __xpu__resnet50_op.cc DEPS ${op_DEPS})
add_operator(__xpu__multi_encoder_op extra SRCS __xpu__multi_encoder_op.cc DEPS ${op_DEPS})
add_operator(__xpu__embedding_with_eltwise_add_op extra SRCS __xpu__embedding_with_eltwise_add_op.cc DEPS ${op_DEPS})
add_operator(__xpu__fc_op extra SRCS __xpu__fc_op.cc DEPS ${op_DEPS})
if (NOT LITE_WITH_X86)
lite_cc_test(test_fc_op SRCS fc_op_test.cc
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/__xpu__embedding_with_eltwise_add_op.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace operators {
bool XPUEmbeddingWithEltwiseAddOp::CheckShape() const {
CHECK_OR_FALSE(param_.Ids.size() == param_.Tables.size());
auto& id_dims = param_.Ids[0]->dims();
auto& table_dims = param_.Tables[0]->dims();
int id_rank = id_dims.size();
CHECK_EQ_OR_FALSE(table_dims.size(), 2);
CHECK_EQ_OR_FALSE(id_dims[id_rank - 1], 1);
return true;
}
bool XPUEmbeddingWithEltwiseAddOp::InferShapeImpl() const {
auto& id_dims = param_.Ids[0]->dims();
auto& table_dims = param_.Tables[0]->dims();
auto out_dims = id_dims;
int id_rank = id_dims.size();
out_dims[id_rank - 1] = table_dims[1];
param_.Out->Resize(out_dims);
param_.Out->set_lod(param_.Ids[0]->lod());
return true;
}
bool XPUEmbeddingWithEltwiseAddOp::AttachImpl(const cpp::OpDesc& op_desc,
lite::Scope* scope) {
param_.Out = scope->FindVar(op_desc.Output("Output").front())
->GetMutable<lite::Tensor>();
param_.Ids.clear();
for (auto& name : op_desc.Input("Ids")) {
auto t =
const_cast<lite::Tensor*>(&scope->FindVar(name)->Get<lite::Tensor>());
param_.Ids.push_back(t);
}
param_.Tables.clear();
for (auto& name : op_desc.Input("Tables")) {
auto t =
const_cast<lite::Tensor*>(&scope->FindVar(name)->Get<lite::Tensor>());
param_.Tables.push_back(t);
}
param_.padding_idx = op_desc.GetAttr<int64_t>("padding_idx");
return true;
}
} // namespace operators
} // namespace lite
} // namespace paddle
REGISTER_LITE_OP(__xpu__embedding_with_eltwise_add,
paddle::lite::operators::XPUEmbeddingWithEltwiseAddOp);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "lite/core/op_lite.h"
namespace paddle {
namespace lite {
namespace operators {
class XPUEmbeddingWithEltwiseAddOp : public OpLite {
public:
XPUEmbeddingWithEltwiseAddOp() {}
explicit XPUEmbeddingWithEltwiseAddOp(const std::string &op_type)
: OpLite(op_type) {}
bool CheckShape() const override;
bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override { return "EmbeddingWithEltwiseAdd"; }
private:
mutable XPUEmbeddingWithEltwiseAddParam param_;
};
} // namespace operators
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/__xpu__fc_op.h"
#include <vector>
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace operators {
bool XPUFcOp::CheckShape() const {
CHECK_OR_FALSE(param_.input);
CHECK_OR_FALSE(param_.output);
CHECK_OR_FALSE(param_.w);
// bias is optional.
const auto input_dims = param_.input->dims();
const auto w_dims = param_.w->dims();
CHECK_EQ_OR_FALSE(w_dims.size(), 2UL);
int64_t w_dims_1 = w_dims[1];
if (param_.bias) {
const auto bias_dims = param_.bias->dims();
if (bias_dims.size() == 2) {
CHECK_EQ_OR_FALSE(bias_dims[0], 1);
CHECK_EQ_OR_FALSE(bias_dims[1], w_dims_1);
} else if (bias_dims.size() == 1) {
CHECK_EQ_OR_FALSE(bias_dims[0], w_dims_1);
}
}
CHECK_GT_OR_FALSE(input_dims.size(),
static_cast<size_t>(param_.in_num_col_dims));
param_.in_mat_dims = input_dims.Flatten2D(param_.in_num_col_dims);
CHECK_EQ_OR_FALSE(param_.in_mat_dims[1], w_dims[0]);
return true;
}
bool XPUFcOp::InferShapeImpl() const {
const auto& input_dims = param_.input->dims();
const auto& w_dims = param_.w->dims();
int in_num_col_dims = param_.in_num_col_dims;
int64_t w_dims_1 = w_dims[1];
// Set output dims
std::vector<DDim::value_type> output_dims(in_num_col_dims + 1);
for (int i = 0; i < in_num_col_dims; ++i) {
output_dims[i] = input_dims[i];
}
output_dims[in_num_col_dims] = w_dims_1;
param_.output->Resize(output_dims);
// share LoD
param_.output->set_lod(param_.input->lod());
return true;
}
bool XPUFcOp::AttachImpl(const cpp::OpDesc& op_desc, lite::Scope* scope) {
auto input = op_desc.Input("Input").front();
auto W = op_desc.Input("W").front();
auto out = op_desc.Output("Out").front();
param_.input = scope->FindVar(input)->GetMutable<lite::Tensor>();
param_.w = scope->FindVar(W)->GetMutable<lite::Tensor>();
std::vector<std::string> input_arg_names = op_desc.InputArgumentNames();
if (std::find(input_arg_names.begin(), input_arg_names.end(), "Bias") !=
input_arg_names.end()) {
auto bias_arguments = op_desc.Input("Bias");
if (bias_arguments.size() > 0) {
auto bias_var = scope->FindVar(bias_arguments.front());
if (bias_var != nullptr) {
param_.bias = bias_var->GetMutable<lite::Tensor>();
}
}
}
CHECK(scope->FindVar(out));
param_.output = scope->FindVar(out)->GetMutable<lite::Tensor>();
param_.in_num_col_dims = op_desc.GetAttr<int>("in_num_col_dims");
param_.w_max = op_desc.GetAttr<float>("w_max");
if (op_desc.HasAttr("activation_type")) {
param_.activation_type = op_desc.GetAttr<std::string>("activation_type");
}
if (op_desc.HasAttr("transpose_w")) {
param_.transpose_w = op_desc.GetAttr<bool>("transpose_w");
}
return true;
}
} // namespace operators
} // namespace lite
} // namespace paddle
REGISTER_LITE_OP(__xpu__fc, paddle::lite::operators::XPUFcOp);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "lite/core/op_lite.h"
namespace paddle {
namespace lite {
namespace operators {
class XPUFcOp : public OpLite {
public:
XPUFcOp() {}
explicit XPUFcOp(const std::string &op_type) : OpLite(op_type) {}
bool CheckShape() const override;
bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override { return "XPUFc"; }
private:
mutable XPUFcParam param_;
};
} // namespace operators
} // namespace lite
} // namespace paddle
......@@ -1491,6 +1491,26 @@ struct XPUMultiEncoderParam : ParamBase {
std::string act_type{};
};
struct XPUEmbeddingWithEltwiseAddParam : ParamBase {
std::vector<lite::Tensor*> Ids;
std::vector<lite::Tensor*> Tables;
lite::Tensor* Out{};
int64_t padding_idx{-1};
};
struct XPUFcParam : ParamBase {
lite::Tensor* input{nullptr};
lite::Tensor* w{nullptr};
lite::Tensor* bias{nullptr};
lite::Tensor* output{nullptr};
int in_num_col_dims{1};
lite::DDim in_mat_dims;
float w_max{0.0f};
bool transpose_w{true};
std::string activation_type{""};
};
} // namespace operators
} // namespace lite
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册