Unverified commit 89a0ecd1, authored by Cwndmiao, committed by GitHub

[LITE][XPU] Support mmdnn3.0-ras (a.k.a. crmm-0608) (#3950)

* fix typo

* [LITE][XPU] accommodate crmm (variant 20200608)

* refine lite/tests/api/test_mmdnn_lite_xpu.cc

* more comments, test=develop test=xpu

* bugfix in crmm pattern match

* pr comments, test=develop test=xpu

* add XPU_CALL and retval check, test=develop test=xpu
Parent 85a12dab
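The central change below is an `XPU_CALL` macro (added in `lite/backends/xpu/target_wrapper.h`) that checks the return code of every XPU runtime call instead of silently discarding it. The following is a minimal self-contained sketch of that pattern, not the Lite code itself: the real macro streams into `CHECK_EQ` from `lite/utils/cp_logging.h`, and the stub `xpu_malloc`/`xpu_free` below are assumed stand-ins for the XPU runtime entry points.

```cpp
#include <cstdio>
#include <cstdlib>

// Same shape as the macro this commit adds: evaluate the call once, then
// fail loudly with the stringified expression and its return code.
#define XPU_CALL(func)                                              \
  do {                                                              \
    int e = (func);                                                 \
    if (e != 0) {                                                   \
      std::fprintf(stderr, "XPU: (%s) returns %d\n", #func, e);     \
      std::abort();                                                 \
    }                                                               \
  } while (0)

// Hypothetical stand-ins for the XPU runtime allocator.
static int xpu_malloc(void** ptr, std::size_t size) {
  *ptr = std::malloc(size);
  return *ptr != nullptr ? 0 : -1;
}
static int xpu_free(void* ptr) {
  std::free(ptr);
  return 0;
}

int main() {
  void* ptr = nullptr;
  XPU_CALL(xpu_malloc(&ptr, 1024));  // aborts with a message on failure
  XPU_CALL(xpu_free(ptr));
  return 0;
}
```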
@@ -19,7 +19,7 @@
#include <memory>
#include <string>
#include <type_traits>
-#include "lite/backends/xpu/xpu_header_sitter.h"
+#include "lite/backends/xpu/target_wrapper.h"
namespace paddle {
namespace lite {
@@ -82,8 +82,8 @@ void DumpXPUMem(const T* ptr,
size_t item_per_line = 30) {
size_t after_stride_len = (len + stride - 1) / stride;
std::unique_ptr<T[]> cpu_mem(new T[len]);
-xpu_memcpy(
-cpu_mem.get(), ptr, len * sizeof(T), XPUMemcpyKind::XPU_DEVICE_TO_HOST);
+XPU_CALL(xpu_memcpy(
+cpu_mem.get(), ptr, len * sizeof(T), XPUMemcpyKind::XPU_DEVICE_TO_HOST));
std::unique_ptr<T[]> after_stride(new T[after_stride_len]);
for (size_t i = 0; i < after_stride_len; ++i) {
after_stride[i] = cpu_mem[i * stride];
...
@@ -19,11 +19,11 @@ namespace lite {
void* TargetWrapperXPU::Malloc(size_t size) {
void* ptr{nullptr};
-xpu_malloc(&ptr, size);
+XPU_CALL(xpu_malloc(&ptr, size));
return ptr;
}
-void TargetWrapperXPU::Free(void* ptr) { xpu_free(ptr); }
+void TargetWrapperXPU::Free(void* ptr) { XPU_CALL(xpu_free(ptr)); }
void TargetWrapperXPU::MemcpySync(void* dst,
const void* src,
@@ -31,10 +31,10 @@ void TargetWrapperXPU::MemcpySync(void* dst,
IoDirection dir) {
switch (dir) {
case IoDirection::HtoD:
-xpu_memcpy(dst, src, size, XPU_HOST_TO_DEVICE);
+XPU_CALL(xpu_memcpy(dst, src, size, XPU_HOST_TO_DEVICE));
break;
case IoDirection::DtoH:
-xpu_memcpy(dst, src, size, XPU_DEVICE_TO_HOST);
+XPU_CALL(xpu_memcpy(dst, src, size, XPU_DEVICE_TO_HOST));
break;
default:
LOG(FATAL) << "Unsupported IoDirection " << static_cast<int>(dir);
@@ -49,7 +49,7 @@ XPUScratchPadGuard TargetWrapperXPU::MallocScratchPad(size_t size,
} else {
ptr = TargetWrapperXPU::Malloc(size);
}
-CHECK(ptr != nullptr);
+CHECK(ptr != nullptr) << "size = " << size << ", use_l3 = " << use_l3;
return XPUScratchPadGuard(new XPUScratchPad(ptr, use_l3));
}
...
@@ -16,11 +16,23 @@
#include <memory>  // std::unique_ptr
#include "lite/backends/xpu/xpu_header_sitter.h"  // xpu_free
-#include "lite/core/target_wrapper.h"
+#include "lite/core/target_wrapper.h"  // TargetWrapper
+#include "lite/utils/cp_logging.h"     // CHECK_EQ
+#define XPU_CALL(func)                                           \
+  {                                                              \
+    auto e = (func);                                             \
+    CHECK_EQ(e, 0) << "XPU: (" << #func << ") returns " << e;    \
+  }
namespace paddle {
namespace lite {
+// MAX(lod.size()) = 64
+const int XPU_MAX_LOD_SIZE = 64;
+// MAX(lod[i + 1] - lod[i]) = 512
+const int XPU_MAX_LOD_SEQ_LEN = 512;
using TargetWrapperXPU = TargetWrapper<TARGET(kXPU)>;
struct XPUScratchPad {
@@ -33,7 +45,7 @@ struct XPUScratchPad {
struct XPUScratchPadDeleter {
void operator()(XPUScratchPad* sp) const {
if (!sp->is_l3_) {
-xpu_free(sp->addr_);
+XPU_CALL(xpu_free(sp->addr_));
}
delete sp;
}
@@ -55,7 +67,7 @@ class TargetWrapper<TARGET(kXPU)> {
size_t size,
IoDirection dir);
-static XPUScratchPadGuard MallocScratchPad(size_t size, bool use_l3 = true);
+static XPUScratchPadGuard MallocScratchPad(size_t size, bool use_l3 = false);
static xdnn::Context* GetRawContext() {
if (tls_raw_ctx_ == nullptr) {
@@ -77,11 +89,10 @@ class TargetWrapper<TARGET(kXPU)> {
static void SetDev(int dev_no = 0) {
const char* dev_env = getenv("LITE_XPU_DEV");
if (dev_env) {
-xpu_set_device(atoi(dev_env));
-return;
+dev_no = atoi(dev_env);
}
-xpu_set_device(dev_no);
+XPU_CALL(xpu_set_device(dev_no));
}
static std::string multi_encoder_precision;  // NOLINT
...
@@ -326,6 +326,28 @@ class XPUMmdnnSearchAttentionFuser : public FuseBase {
}
};
// 4 inputs
// ========
//
// input_x
// input_y
// topk_row
// topk_col
//
// input_x ------- match_matrix_tensor ------- input_y
// |
// relu
// ________/ \________
// | |
// var_conv_2d |
// | |
// relu |
// |_______ _______|
// \ /
// sequence_concat
// |
// topk_row ---- sequence_topk_avg_pooling ----- topk_col
//
class XPUMmdnnMatchConvTopkFuser : public FuseBase {
public:
void BuildPattern() override {
@@ -418,10 +440,156 @@ class XPUMmdnnMatchConvTopkFuser : public FuseBase {
auto* match_op_info = matched.at("match_matrix_tensor")->stmt()->op_info();
op_desc.SetAttr<float>("input_w_max",
-match_op_info->GetAttr<float>("w_max"));
+match_op_info->GetAttr<float>("__xpu__w_max"));
op_desc.SetAttr<int>("dim_t", match_op_info->GetAttr<int>("dim_t"));
auto* conv_op_info = matched.at("conv")->stmt()->op_info();
op_desc.SetAttr<float>("conv_w_max",
conv_op_info->GetAttr<float>("__xpu__w_max"));
op_desc.SetAttr<int>("output_channel",
conv_op_info->GetAttr<int>("OutputChannel"));
auto* topk_op_info = matched.at("topk")->stmt()->op_info();
op_desc.SetAttr<std::vector<int>>(
"topks", topk_op_info->GetAttr<std::vector<int>>("topks"));
op_desc.SetAttr<int>("channel_num",
topk_op_info->GetAttr<int>("channel_num"));
auto* new_stmt = matched.at("match_matrix_tensor")->stmt();
auto new_op = LiteOpRegistry::Global().Create(op_desc.Type());
new_op->Attach(op_desc, new_stmt->op()->scope());
new_op->SetValidPlaces(new_stmt->op()->valid_places());
auto kernels = new_op->CreateKernels(new_op->valid_places());
new_stmt->SetOp(new_op);
new_stmt->SetKernels(std::move(kernels));
// XXX(miaotianxiang): redundant links around |topk| are automatically
// removed as |topk| is marked intermediate.
// RemoveDirectedLink(matched.at("topk_col"), matched.at("topk"));
// RemoveDirectedLink(matched.at("topk_row"), matched.at("topk"));
std::vector<std::string> arg_names{"conv_w"};
for (auto name : arg_names) {
DirectedLink(matched.at(name), matched.at("match_matrix_tensor"));
}
std::vector<std::string> out_names{"topk_out"};
for (auto name : out_names) {
IR_OP_VAR_LINK(matched.at("match_matrix_tensor"), matched.at(name));
}
}
};
// 2 inputs
// ========
//
// input_x
// input_y
//
// input_x ------- match_matrix_tensor ------- input_y
// | | |
// | relu |
// | ________/ \________ |
// | | | |
// | var_conv_2d | |
// | | | |
// | relu | |
// | |_______ _______| |
// | \ / |
// | sequence_concat |
// | | |
// |--------- sequence_topk_avg_pooling -------|
//
class XPUMmdnnMatchConvTopkFuser2 : public FuseBase {
public:
void BuildPattern() override {
auto* input_x = VarNode("input_x")
->assert_is_op_input("match_matrix_tensor", "X")
->assert_is_op_input("sequence_topk_avg_pooling", "ROW")
->AsInput();
auto* input_y =
VarNode("input_y")
->assert_is_op_input("match_matrix_tensor", "Y")
->assert_is_op_input("sequence_topk_avg_pooling", "COLUMN")
->AsInput();
auto* input_w = VarNode("input_w")
->assert_is_op_input("match_matrix_tensor", "W")
->AsInput();
auto* match_matrix_tensor =
OpNode("match_matrix_tensor", "match_matrix_tensor");
auto* match_out = VarNode("match_out")
->assert_is_op_output("match_matrix_tensor", "Out")
->AsIntermediate();
auto* match_tmp = VarNode("match_tmp")
->assert_is_op_output("match_matrix_tensor", "Tmp")
->AsIntermediate();
auto* relu0 = OpNode("relu0", "relu")->AsIntermediate();
auto* relu0_out = VarNode("relu0_out")
->assert_is_op_output("relu", "Out")
->AsIntermediate();
auto* conv_w =
VarNode("conv_w")->assert_is_op_input("var_conv_2d", "W")->AsInput();
auto* conv = OpNode("conv", "var_conv_2d")->AsIntermediate();
auto* conv_out = VarNode("conv_out")
->assert_is_op_output("var_conv_2d", "Out")
->AsIntermediate();
auto* conv_col = VarNode("conv_col")
->assert_is_op_output("var_conv_2d", "Col")
->AsIntermediate();
auto* relu1 = OpNode("relu1", "relu")->AsIntermediate();
auto* relu1_out = VarNode("relu1_out")
->assert_is_op_output("relu", "Out")
->AsIntermediate();
auto* seq_concat =
OpNode("seq_concat", "sequence_concat")->AsIntermediate();
auto* seq_concat_out =
VarNode("seq_concat_out")
->assert_is_op_output("sequence_concat", "Out")
->assert_is_op_input("sequence_topk_avg_pooling", "X")
->AsIntermediate();
auto* topk = OpNode("topk", "sequence_topk_avg_pooling")->AsIntermediate();
auto* topk_out =
VarNode("topk_out")
->assert_is_op_output("sequence_topk_avg_pooling", "Out")
->AsOutput();
auto* topk_pos =
VarNode("topk_pos")
->assert_is_op_output("sequence_topk_avg_pooling", "pos")
->AsIntermediate();
*input_x >> *match_matrix_tensor;
*input_y >> *match_matrix_tensor;
*input_w >> *match_matrix_tensor;
*match_matrix_tensor >> *match_out >> *relu0 >> *relu0_out;
*match_matrix_tensor >> *match_tmp;
*relu0_out >> *conv >> *conv_out >> *relu1 >> *relu1_out;
*conv_w >> *conv;
*conv >> *conv_col;
*relu0_out >> *seq_concat;
*relu1_out >> *seq_concat;
*seq_concat >> *seq_concat_out >> *topk >> *topk_out;
*input_x >> *topk;
*input_y >> *topk;
*topk >> *topk_pos;
}
void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override {
cpp::OpDesc op_desc;
op_desc.SetType("__xpu__mmdnn_match_conv_topk");
op_desc.SetInput("input_x", {matched.at("input_x")->arg()->name});
op_desc.SetInput("input_y", {matched.at("input_y")->arg()->name});
op_desc.SetInput("input_w", {matched.at("input_w")->arg()->name});
op_desc.SetInput("conv_w", {matched.at("conv_w")->arg()->name});
op_desc.SetOutput("topk_out", {matched.at("topk_out")->arg()->name});
auto* match_op_info = matched.at("match_matrix_tensor")->stmt()->op_info();
op_desc.SetAttr<float>("input_w_max",
match_op_info->GetAttr<float>("__xpu__w_max"));
op_desc.SetAttr<int>("dim_t", match_op_info->GetAttr<int>("dim_t")); op_desc.SetAttr<int>("dim_t", match_op_info->GetAttr<int>("dim_t"));
auto* conv_op_info = matched.at("conv")->stmt()->op_info(); auto* conv_op_info = matched.at("conv")->stmt()->op_info();
op_desc.SetAttr<float>("conv_w_max", conv_op_info->GetAttr<float>("w_max")); op_desc.SetAttr<float>("conv_w_max",
conv_op_info->GetAttr<float>("__xpu__w_max"));
op_desc.SetAttr<int>("output_channel",
conv_op_info->GetAttr<int>("OutputChannel"));
auto* topk_op_info = matched.at("topk")->stmt()->op_info(); auto* topk_op_info = matched.at("topk")->stmt()->op_info();
op_desc.SetAttr<std::vector<int>>( op_desc.SetAttr<std::vector<int>>(
"topks", topk_op_info->GetAttr<std::vector<int>>("topks")); "topks", topk_op_info->GetAttr<std::vector<int>>("topks"));
...@@ -437,8 +605,7 @@ class XPUMmdnnMatchConvTopkFuser : public FuseBase { ...@@ -437,8 +605,7 @@ class XPUMmdnnMatchConvTopkFuser : public FuseBase {
new_stmt->SetKernels(std::move(kernels)); new_stmt->SetKernels(std::move(kernels));
// XXX(miaotianxiang): redundant links around |topk| are automatically // XXX(miaotianxiang): redundant links around |topk| are automatically
// removed as |topk| is // removed as |topk| is marked intermediate.
// marked intermediate.
// RemoveDirectedLink(matched.at("topk_col"), matched.at("topk")); // RemoveDirectedLink(matched.at("topk_col"), matched.at("topk"));
// RemoveDirectedLink(matched.at("topk_row"), matched.at("topk")); // RemoveDirectedLink(matched.at("topk_row"), matched.at("topk"));
std::vector<std::string> arg_names{"conv_w"}; std::vector<std::string> arg_names{"conv_w"};
...@@ -624,6 +791,15 @@ class XPUMmdnnBidEmbAttFuser : public FuseBase { ...@@ -624,6 +791,15 @@ class XPUMmdnnBidEmbAttFuser : public FuseBase {
} }
}; };
// 5 outputs
// =========
//
// eltwise01_out
// seq_pool_right_out
// seq_pool_left_out
// seq_pool_2in1_out
// concat_3in1_out
//
class XPUMmdnnBidEmbGrnnAttFuser : public FuseBase {
public:
void BuildPattern() override {
@@ -818,17 +994,272 @@ class XPUMmdnnBidEmbGrnnAttFuser : public FuseBase {
auto* grnn_fw_op_info = matched.at("grnn_left")->stmt()->op_info();
op_desc.SetAttr<std::vector<float>>(
"grnn_fw_wh_maxs",
-grnn_fw_op_info->GetAttr<std::vector<float>>("wh_max"));
+grnn_fw_op_info->GetAttr<std::vector<float>>("__xpu__wh_max"));
op_desc.SetAttr<std::vector<float>>(
"grnn_fw_wi_maxs",
-grnn_fw_op_info->GetAttr<std::vector<float>>("wi_max"));
+grnn_fw_op_info->GetAttr<std::vector<float>>("__xpu__wi_max"));
auto* grnn_rv_op_info = matched.at("grnn_right")->stmt()->op_info();
op_desc.SetAttr<std::vector<float>>(
"grnn_rv_wh_maxs",
-grnn_rv_op_info->GetAttr<std::vector<float>>("wh_max"));
+grnn_rv_op_info->GetAttr<std::vector<float>>("__xpu__wh_max"));
op_desc.SetAttr<std::vector<float>>(
"grnn_rv_wi_maxs",
-grnn_rv_op_info->GetAttr<std::vector<float>>("wi_max"));
+grnn_rv_op_info->GetAttr<std::vector<float>>("__xpu__wi_max"));
auto* att_fc_op_info = matched.at("att_2in1")->stmt()->op_info();
op_desc.SetAttr<float>("att_fc_w_max",
att_fc_op_info->GetAttr<float>("W_max"));
auto* new_stmt = matched.at("emb0")->stmt();
auto new_op = LiteOpRegistry::Global().Create(op_desc.Type());
new_op->Attach(op_desc, new_stmt->op()->scope());
new_op->SetValidPlaces(new_stmt->op()->valid_places());
auto kernels = new_op->CreateKernels(new_op->valid_places());
new_stmt->SetOp(new_op);
new_stmt->SetKernels(std::move(kernels));
std::vector<std::string> arg_names{
"input1",
"grnn_left_wh",
"grnn_left_wi",
"grnn_right_wh",
"grnn_right_wi",
"att_2in1_w",
"att_2in1_b",
};
for (auto name : arg_names) {
DirectedLink(matched.at(name), matched.at("emb0"));
}
std::vector<std::string> out_names{
"seq_pool_left_out",
"seq_pool_right_out",
"seq_pool_2in1_out",
"concat_3in1_out",
"eltwise01_out",
};
for (auto name : out_names) {
IR_OP_VAR_LINK(matched.at("emb0"), matched.at(name));
}
}
};
// 6 outputs
// =========
//
// emb0_out
// eltwise01_out
// seq_pool_right_out
// seq_pool_left_out
// seq_pool_2in1_out
// concat_3in1_out
//
class XPUMmdnnBidEmbGrnnAttFuser2 : public FuseBase {
public:
void BuildPattern() override {
auto* input0 = VarNode("input0")->AsInput();
auto* input1 = VarNode("input1")->AsInput();
auto* emb_tbl = VarNode("emb_tbl")->AsInput();
auto* emb0 = OpNode("emb0", "lookup_table");
auto* emb0_out = VarNode("emb0_out")
->assert_is_op_output("lookup_table", "Out")
->assert_is_op_input("search_seq_arithmetic", "X")
->AsOutput();
auto* emb1 = OpNode("emb1", "lookup_table")->AsIntermediate();
auto* emb1_out = VarNode("emb1_out")
->assert_is_op_output("lookup_table", "Out")
->assert_is_op_input("search_seq_arithmetic", "Y")
->AsIntermediate();
auto* eltwise01 =
OpNode("eltwise01", "search_seq_arithmetic")->AsIntermediate();
auto* eltwise01_out =
VarNode("eltwise01_out")
->assert_is_op_output("search_seq_arithmetic", "Out")
->AsOutput();
auto* seq_rev_right0 =
OpNode("seq_rev_right0", "sequence_reverse")->AsIntermediate();
auto* seq_rev_right0_out =
VarNode("seq_rev_right0_out")
->assert_is_op_output("sequence_reverse", "Y")
->AsIntermediate();
auto* grnn_right_wh = VarNode("grnn_right_wh")
->assert_is_op_input("search_grnn", "Wh")
->AsInput();
auto* grnn_right_wi = VarNode("grnn_right_wi")
->assert_is_op_input("search_grnn", "Wi")
->AsInput();
auto* grnn_right = OpNode("grnn_right", "search_grnn")->AsIntermediate();
auto* grnn_right_out = VarNode("grnn_right_out")
->assert_is_op_output("search_grnn", "Out")
->AsIntermediate();
auto* grnn_right_idx_sorted_by_width =
VarNode("grnn_right_idx_sorted_by_width")
->assert_is_op_output("search_grnn", "idx_sorted_by_width")
->AsIntermediate();
auto* grnn_right_layout_input =
VarNode("grnn_right_layout_input")
->assert_is_op_output("search_grnn", "layout_input")
->AsIntermediate();
auto* grnn_right_tmp_buffer =
VarNode("grnn_right_tmp_buffer")
->assert_is_op_output("search_grnn", "tmp_buffer")
->AsIntermediate();
auto* seq_rev_right1 =
OpNode("seq_rev_right1", "sequence_reverse")->AsIntermediate();
auto* seq_rev_right1_out =
VarNode("seq_rev_right1_out")
->assert_is_op_output("sequence_reverse", "Y")
->AsIntermediate();
auto* seq_pool_right =
OpNode("seq_pool_right", "sequence_pool")->AsIntermediate();
auto* seq_pool_right_out = VarNode("seq_pool_right_out")
->assert_is_op_output("sequence_pool", "Out")
->AsOutput();
auto* seq_pool_right_max_idx =
VarNode("seq_pool_right_max_idx")
->assert_is_op_output("sequence_pool", "MaxIndex")
->AsIntermediate();
auto* grnn_left_wh = VarNode("grnn_left_wh")
->assert_is_op_input("search_grnn", "Wh")
->AsInput();
auto* grnn_left_wi = VarNode("grnn_left_wi")
->assert_is_op_input("search_grnn", "Wi")
->AsInput();
auto* grnn_left = OpNode("grnn_left", "search_grnn")->AsIntermediate();
auto* grnn_left_out = VarNode("grnn_left_out")
->assert_is_op_output("search_grnn", "Out")
->AsIntermediate();
auto* grnn_left_idx_sorted_by_width =
VarNode("grnn_left_idx_sorted_by_width")
->assert_is_op_output("search_grnn", "idx_sorted_by_width")
->AsIntermediate();
auto* grnn_left_layout_input =
VarNode("grnn_left_layout_input")
->assert_is_op_output("search_grnn", "layout_input")
->AsIntermediate();
auto* grnn_left_tmp_buffer =
VarNode("grnn_left_tmp_buffer")
->assert_is_op_output("search_grnn", "tmp_buffer")
->AsIntermediate();
auto* seq_pool_left =
OpNode("seq_pool_left", "sequence_pool")->AsIntermediate();
auto* seq_pool_left_out = VarNode("seq_pool_left_out")
->assert_is_op_output("sequence_pool", "Out")
->AsOutput();
auto* seq_pool_left_max_idx =
VarNode("seq_pool_left_max_idx")
->assert_is_op_output("sequence_pool", "MaxIndex")
->AsIntermediate();
auto* concat_2in1 = OpNode("concat_2in1", "concat")->AsIntermediate();
auto* concat_2in1_out = VarNode("concat_2in1_out")
->assert_is_op_output("concat", "Out")
->AsIntermediate();
auto* att_2in1_w =
VarNode("att_2in1_w")
->assert_is_op_input("__xpu__mmdnn_search_attention", "W")
->AsInput();
auto* att_2in1_b =
VarNode("att_2in1_b")
->assert_is_op_input("__xpu__mmdnn_search_attention", "b")
->AsInput();
auto* att_2in1 =
OpNode("att_2in1", "__xpu__mmdnn_search_attention")->AsIntermediate();
auto* att_2in1_out =
VarNode("att_2in1_out")
->assert_is_op_output("__xpu__mmdnn_search_attention", "Out")
->AsIntermediate();
auto* seq_pool_2in1 =
OpNode("seq_pool_2in1", "sequence_pool")->AsIntermediate();
auto* seq_pool_2in1_out = VarNode("seq_pool_2in1_out")
->assert_is_op_output("sequence_pool", "Out")
->AsOutput();
auto* seq_pool_2in1_max_idx =
VarNode("seq_pool_2in1_max_idx")
->assert_is_op_output("sequence_pool", "MaxIndex")
->AsIntermediate();
auto* concat_3in1 = OpNode("concat_3in1", "concat")->AsIntermediate();
auto* concat_3in1_out = VarNode("concat_3in1_out")
->assert_is_op_output("concat", "Out")
->AsOutput();
*input0 >> *emb0 >> *emb0_out >> *eltwise01 >> *eltwise01_out;
*emb_tbl >> *emb0;
*input1 >> *emb1 >> *emb1_out >> *eltwise01;
*emb_tbl >> *emb1;
*eltwise01_out >> *seq_rev_right0 >> *seq_rev_right0_out >> *grnn_right >>
*grnn_right_out >> *seq_rev_right1 >> *seq_rev_right1_out;
*grnn_right_out >> *seq_pool_right >> *seq_pool_right_out;
*seq_pool_right >> *seq_pool_right_max_idx;
*grnn_right_wh >> *grnn_right;
*grnn_right_wi >> *grnn_right;
*grnn_right >> *grnn_right_idx_sorted_by_width;
*grnn_right >> *grnn_right_layout_input;
*grnn_right >> *grnn_right_tmp_buffer;
*eltwise01_out >> *grnn_left >> *grnn_left_out >> *seq_pool_left >>
*seq_pool_left_out;
*seq_pool_left >> *seq_pool_left_max_idx;
*grnn_left_wh >> *grnn_left;
*grnn_left_wi >> *grnn_left;
*grnn_left >> *grnn_left_idx_sorted_by_width;
*grnn_left >> *grnn_left_layout_input;
*grnn_left >> *grnn_left_tmp_buffer;
*seq_rev_right1_out >> *concat_2in1;
*grnn_left_out >> *concat_2in1;
*concat_2in1 >> *concat_2in1_out >> *att_2in1 >> *att_2in1_out >>
*seq_pool_2in1 >> *seq_pool_2in1_out;
*seq_pool_2in1 >> *seq_pool_2in1_max_idx;
*att_2in1_w >> *att_2in1;
*att_2in1_b >> *att_2in1;
*eltwise01_out >> *concat_3in1;
*seq_rev_right1_out >> *concat_3in1;
*grnn_left_out >> *concat_3in1;
*concat_3in1 >> *concat_3in1_out;
}
void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override {
cpp::OpDesc op_desc;
op_desc.SetType("__xpu__mmdnn_bid_emb_grnn_att2");
op_desc.SetInput("id0", {matched.at("input0")->arg()->name});
op_desc.SetInput("id1", {matched.at("input1")->arg()->name});
op_desc.SetInput("emb_tbl", {matched.at("emb_tbl")->arg()->name});
op_desc.SetInput("grnn_fw_wh", {matched.at("grnn_left_wh")->arg()->name});
op_desc.SetInput("grnn_fw_wi", {matched.at("grnn_left_wi")->arg()->name});
op_desc.SetInput("grnn_rv_wh", {matched.at("grnn_right_wh")->arg()->name});
op_desc.SetInput("grnn_rv_wi", {matched.at("grnn_right_wi")->arg()->name});
op_desc.SetInput("att_fc_w", {matched.at("att_2in1_w")->arg()->name});
op_desc.SetInput("att_fc_b", {matched.at("att_2in1_b")->arg()->name});
op_desc.SetOutput("emb0_out", {matched.at("emb0_out")->arg()->name});
op_desc.SetOutput("grnn_fw_pool_out",
{matched.at("seq_pool_left_out")->arg()->name});
op_desc.SetOutput("grnn_rv_pool_out",
{matched.at("seq_pool_right_out")->arg()->name});
op_desc.SetOutput("att_pool_out",
{matched.at("seq_pool_2in1_out")->arg()->name});
op_desc.SetOutput("concat_3in1_out",
{matched.at("concat_3in1_out")->arg()->name});
op_desc.SetOutput("emb_fw_out", {matched.at("eltwise01_out")->arg()->name});
auto* grnn_fw_op_info = matched.at("grnn_left")->stmt()->op_info();
op_desc.SetAttr<std::vector<float>>(
"grnn_fw_wh_maxs",
grnn_fw_op_info->GetAttr<std::vector<float>>("__xpu__wh_max"));
op_desc.SetAttr<std::vector<float>>(
"grnn_fw_wi_maxs",
grnn_fw_op_info->GetAttr<std::vector<float>>("__xpu__wi_max"));
auto* grnn_rv_op_info = matched.at("grnn_right")->stmt()->op_info();
op_desc.SetAttr<std::vector<float>>(
"grnn_rv_wh_maxs",
grnn_rv_op_info->GetAttr<std::vector<float>>("__xpu__wh_max"));
op_desc.SetAttr<std::vector<float>>(
"grnn_rv_wi_maxs",
grnn_rv_op_info->GetAttr<std::vector<float>>("__xpu__wi_max"));
auto* att_fc_op_info = matched.at("att_2in1")->stmt()->op_info();
op_desc.SetAttr<float>("att_fc_w_max",
att_fc_op_info->GetAttr<float>("W_max"));
@@ -868,6 +1299,9 @@ class XPUMmdnnBidEmbGrnnAttFuser : public FuseBase {
class XPUMmdnnMergeAllFuser : public FuseBase {
public:
explicit XPUMmdnnMergeAllFuser(int n_concat_topk)
: n_concat_topk_(n_concat_topk) {}
void BuildPattern() override {
auto* concat_7in1_input0 = VarNode("concat_7in1_input0")
->assert_is_op_nth_input("concat", "X", 0)
@@ -909,16 +1343,25 @@ class XPUMmdnnMergeAllFuser : public FuseBase {
->assert_is_op_output("relu", "Out")
->AsIntermediate();
-auto* concat_2in1_input0 = VarNode("concat_2in1_input0")
+auto* concat_topk_input0 = VarNode("concat_topk_input0")
->assert_is_op_nth_input("concat", "X", 0)
->AsInput();
-auto* concat_2in1_input1 = VarNode("concat_2in1_input1")
+auto* concat_topk_input1 = VarNode("concat_topk_input1")
->assert_is_op_nth_input("concat", "X", 1)
->AsInput();
-auto* concat_2in1 = OpNode("concat_2in1", "concat")->AsIntermediate();
-auto* concat_2in1_out = VarNode("concat_2in1_out")
+auto* concat_topk = OpNode("concat_topk", "concat")->AsIntermediate();
+auto* concat_topk_out = VarNode("concat_topk_out")
->assert_is_op_output("concat", "Out")
->AsIntermediate();
for (int i = 2; i < n_concat_topk_; ++i) {
auto concat_topk_input_name =
paddle::lite::string_format("concat_topk_input%d", i);
auto* concat_topk_inputx = VarNode(concat_topk_input_name)
->assert_is_op_nth_input("concat", "X", i)
->AsInput();
*concat_topk_inputx >> *concat_topk;
}
auto* seq_rev = OpNode("seq_rev", "sequence_reverse")->AsIntermediate(); auto* seq_rev = OpNode("seq_rev", "sequence_reverse")->AsIntermediate();
auto* seq_rev_out = VarNode("seq_rev_out") auto* seq_rev_out = VarNode("seq_rev_out")
->assert_is_op_output("sequence_reverse", "Y") ->assert_is_op_output("sequence_reverse", "Y")
...@@ -1034,9 +1477,9 @@ class XPUMmdnnMergeAllFuser : public FuseBase { ...@@ -1034,9 +1477,9 @@ class XPUMmdnnMergeAllFuser : public FuseBase {
*search_fc0_w >> *search_fc0; *search_fc0_w >> *search_fc0;
*search_fc0_b >> *search_fc0; *search_fc0_b >> *search_fc0;
*concat_2in1_input0 >> *concat_2in1; *concat_topk_input0 >> *concat_topk;
*concat_2in1_input1 >> *concat_2in1; *concat_topk_input1 >> *concat_topk;
*concat_2in1 >> *concat_2in1_out >> *seq_rev >> *seq_rev_out; *concat_topk >> *concat_topk_out >> *seq_rev >> *seq_rev_out;
*seq_rev_out >> *grnn_rv >> *grnn_rv_out >> *seq_pool_rv >> *seq_rev_out >> *grnn_rv >> *grnn_rv_out >> *seq_pool_rv >>
*seq_pool_rv_out; *seq_pool_rv_out;
...@@ -1047,7 +1490,7 @@ class XPUMmdnnMergeAllFuser : public FuseBase { ...@@ -1047,7 +1490,7 @@ class XPUMmdnnMergeAllFuser : public FuseBase {
*grnn_rv >> *grnn_rv_layout_input; *grnn_rv >> *grnn_rv_layout_input;
*grnn_rv >> *grnn_rv_tmp_buffer; *grnn_rv >> *grnn_rv_tmp_buffer;
*concat_2in1_out >> *grnn_fw >> *grnn_fw_out >> *seq_pool_fw >> *concat_topk_out >> *grnn_fw >> *grnn_fw_out >> *seq_pool_fw >>
*seq_pool_fw_out; *seq_pool_fw_out;
*seq_pool_fw >> *seq_pool_fw_max_idx; *seq_pool_fw >> *seq_pool_fw_max_idx;
*grnn_fw_wh >> *grnn_fw; *grnn_fw_wh >> *grnn_fw;
...@@ -1075,8 +1518,8 @@ class XPUMmdnnMergeAllFuser : public FuseBase { ...@@ -1075,8 +1518,8 @@ class XPUMmdnnMergeAllFuser : public FuseBase {
op_desc.SetType("__xpu__mmdnn_merge_all"); op_desc.SetType("__xpu__mmdnn_merge_all");
auto* concat_7in1_op_info = matched.at("concat_7in1")->stmt()->op_info(); auto* concat_7in1_op_info = matched.at("concat_7in1")->stmt()->op_info();
op_desc.SetInput("concat_7in1_x", concat_7in1_op_info->Input("X")); op_desc.SetInput("concat_7in1_x", concat_7in1_op_info->Input("X"));
auto* concat_2in1_op_info = matched.at("concat_2in1")->stmt()->op_info(); auto* concat_topk_op_info = matched.at("concat_topk")->stmt()->op_info();
op_desc.SetInput("concat_2in1_x", concat_2in1_op_info->Input("X")); op_desc.SetInput("concat_topk_x", concat_topk_op_info->Input("X"));
op_desc.SetInput("grnn_fw_wh", {matched.at("grnn_fw_wh")->arg()->name}); op_desc.SetInput("grnn_fw_wh", {matched.at("grnn_fw_wh")->arg()->name});
op_desc.SetInput("grnn_fw_wi", {matched.at("grnn_fw_wi")->arg()->name}); op_desc.SetInput("grnn_fw_wi", {matched.at("grnn_fw_wi")->arg()->name});
op_desc.SetInput("grnn_rv_wh", {matched.at("grnn_rv_wh")->arg()->name}); op_desc.SetInput("grnn_rv_wh", {matched.at("grnn_rv_wh")->arg()->name});
...@@ -1093,23 +1536,26 @@ class XPUMmdnnMergeAllFuser : public FuseBase { ...@@ -1093,23 +1536,26 @@ class XPUMmdnnMergeAllFuser : public FuseBase {
auto* grnn_fw_op_info = matched.at("grnn_fw")->stmt()->op_info(); auto* grnn_fw_op_info = matched.at("grnn_fw")->stmt()->op_info();
op_desc.SetAttr<std::vector<float>>( op_desc.SetAttr<std::vector<float>>(
"grnn_fw_wh_maxs", "grnn_fw_wh_maxs",
grnn_fw_op_info->GetAttr<std::vector<float>>("wh_max")); grnn_fw_op_info->GetAttr<std::vector<float>>("__xpu__wh_max"));
op_desc.SetAttr<std::vector<float>>( op_desc.SetAttr<std::vector<float>>(
"grnn_fw_wi_maxs", "grnn_fw_wi_maxs",
grnn_fw_op_info->GetAttr<std::vector<float>>("wi_max")); grnn_fw_op_info->GetAttr<std::vector<float>>("__xpu__wi_max"));
auto* grnn_rv_op_info = matched.at("grnn_rv")->stmt()->op_info(); auto* grnn_rv_op_info = matched.at("grnn_rv")->stmt()->op_info();
op_desc.SetAttr<std::vector<float>>( op_desc.SetAttr<std::vector<float>>(
"grnn_rv_wh_maxs", "grnn_rv_wh_maxs",
grnn_rv_op_info->GetAttr<std::vector<float>>("wh_max")); grnn_rv_op_info->GetAttr<std::vector<float>>("__xpu__wh_max"));
op_desc.SetAttr<std::vector<float>>( op_desc.SetAttr<std::vector<float>>(
"grnn_rv_wi_maxs", "grnn_rv_wi_maxs",
grnn_rv_op_info->GetAttr<std::vector<float>>("wi_max")); grnn_rv_op_info->GetAttr<std::vector<float>>("__xpu__wi_max"));
auto* fc0_op_info = matched.at("search_fc0")->stmt()->op_info(); auto* fc0_op_info = matched.at("search_fc0")->stmt()->op_info();
op_desc.SetAttr<float>("fc0_w_max", fc0_op_info->GetAttr<float>("w_max")); op_desc.SetAttr<float>("fc0_w_max",
fc0_op_info->GetAttr<float>("__xpu__w_max"));
auto* fc1_op_info = matched.at("search_fc1")->stmt()->op_info(); auto* fc1_op_info = matched.at("search_fc1")->stmt()->op_info();
op_desc.SetAttr<float>("fc1_w_max", fc1_op_info->GetAttr<float>("w_max")); op_desc.SetAttr<float>("fc1_w_max",
fc1_op_info->GetAttr<float>("__xpu__w_max"));
auto* fc2_op_info = matched.at("search_fc2")->stmt()->op_info(); auto* fc2_op_info = matched.at("search_fc2")->stmt()->op_info();
op_desc.SetAttr<float>("fc2_w_max", fc2_op_info->GetAttr<float>("w_max")); op_desc.SetAttr<float>("fc2_w_max",
fc2_op_info->GetAttr<float>("__xpu__w_max"));
auto* new_stmt = matched.at("concat_7in1")->stmt(); auto* new_stmt = matched.at("concat_7in1")->stmt();
auto new_op = LiteOpRegistry::Global().Create(op_desc.Type()); auto new_op = LiteOpRegistry::Global().Create(op_desc.Type());
...@@ -1120,8 +1566,8 @@ class XPUMmdnnMergeAllFuser : public FuseBase { ...@@ -1120,8 +1566,8 @@ class XPUMmdnnMergeAllFuser : public FuseBase {
new_stmt->SetKernels(std::move(kernels)); new_stmt->SetKernels(std::move(kernels));
std::vector<std::string> arg_names{ std::vector<std::string> arg_names{
"concat_2in1_input0", "concat_topk_input0",
"concat_2in1_input1", "concat_topk_input1",
"grnn_fw_wh", "grnn_fw_wh",
"grnn_fw_wi", "grnn_fw_wi",
"grnn_rv_wh", "grnn_rv_wh",
...@@ -1133,6 +1579,11 @@ class XPUMmdnnMergeAllFuser : public FuseBase { ...@@ -1133,6 +1579,11 @@ class XPUMmdnnMergeAllFuser : public FuseBase {
"search_fc2_w", "search_fc2_w",
"search_fc2_b", "search_fc2_b",
}; };
for (int i = 2; i < n_concat_topk_; ++i) {
auto concat_topk_input_name =
paddle::lite::string_format("concat_topk_input%d", i);
arg_names.push_back(concat_topk_input_name);
}
for (auto name : arg_names) {
DirectedLink(matched.at(name), matched.at("concat_7in1"));
}
@@ -1143,6 +1594,9 @@ class XPUMmdnnMergeAllFuser : public FuseBase {
IR_OP_VAR_LINK(matched.at("concat_7in1"), matched.at(name));
}
}
private:
int n_concat_topk_;
};
}  // namespace fusion
@@ -1158,15 +1612,21 @@ class XPUMmdnnFusePass : public ProgramPass {
search_att_fuser(graph.get());
fusion::XPUMmdnnMatchConvTopkFuser match_conv_topk_fuser;
match_conv_topk_fuser(graph.get());
fusion::XPUMmdnnMatchConvTopkFuser2 match_conv_topk_fuser2;
match_conv_topk_fuser2(graph.get());
fusion::XPUMmdnnBidSeqRevEmbEltwiseFuser bi_seq_rev_emb_eltwise_fuser;
bi_seq_rev_emb_eltwise_fuser(graph.get());
fusion::XPUMmdnnBidEmbGrnnAttFuser bid_emb_grnn_att_fuser;
bid_emb_grnn_att_fuser(graph.get());
fusion::XPUMmdnnBidEmbGrnnAttFuser2 bid_emb_grnn_att_fuser2;
bid_emb_grnn_att_fuser2(graph.get());
fusion::XPUMmdnnBidEmbAttFuser bid_emb_att_fuser;
bid_emb_att_fuser(graph.get());
-fusion::XPUMmdnnMergeAllFuser merge_all_fuser;
-merge_all_fuser(graph.get());
+for (int n_concat_topk : {3, 2}) {
+fusion::XPUMmdnnMergeAllFuser merge_all_fuser(n_concat_topk);
+merge_all_fuser(graph.get());
+}
}
};
@@ -1178,6 +1638,7 @@ REGISTER_MIR_PASS(__xpu__mmdnn_fuse_pass, paddle::lite::mir::XPUMmdnnFusePass)
.BindTargets({TARGET(kXPU)})
.BindKernel("__xpu__mmdnn_search_attention")
.BindKernel("__xpu__mmdnn_bid_emb_grnn_att")
.BindKernel("__xpu__mmdnn_bid_emb_grnn_att2")
.BindKernel("__xpu__mmdnn_bid_emb_att") .BindKernel("__xpu__mmdnn_bid_emb_att")
.BindKernel("__xpu__mmdnn_match_conv_topk") .BindKernel("__xpu__mmdnn_match_conv_topk")
.BindKernel("__xpu__mmdnn_merge_all"); .BindKernel("__xpu__mmdnn_merge_all");
@@ -31,11 +31,14 @@ void XPUEmbeddingWithEltwiseAddCompute::PrepareForRun() {
CHECK_EQ(table_dims.size(), 2);  /* shape like [table_len, embed_dim] */
table_lens_cpu_.push_back(table_dims[0]);
}
-void* lens_ptr = nullptr;
size_t lens_size = table_lens_cpu_.size() * sizeof(int);
-xpu_malloc(&lens_ptr, lens_size);
-xpu_memcpy(lens_ptr, &table_lens_cpu_[0], lens_size, XPU_HOST_TO_DEVICE);
-table_lens_guard_.reset(lens_ptr);
+table_lens_guard_ =
+TargetWrapperXPU::MallocScratchPad(lens_size, false /* use_l3 */);
+XPU_CALL(xpu_memcpy(table_lens_guard_->addr_,
+&table_lens_cpu_[0],
+lens_size,
+XPU_HOST_TO_DEVICE));
}
void XPUEmbeddingWithEltwiseAddCompute::Run() {
@@ -55,16 +58,16 @@ void XPUEmbeddingWithEltwiseAddCompute::Run() {
int embed_dim = table_dims[1];
int emb_layer_num = param.Ids.size();
int r = xdnn::embedding_with_ewadd<float, int64_t, false, false>(
ctx.GetRawContext(), /* context */
embed_dim, /* embed_dim */
idx_len, /* idx_len */
emb_layer_num, /* emb_layer_num */
param.padding_idx, /* padding_idx */
&arg_tables_[0], /* tables */
&arg_ids_[0], /* indices */
-static_cast<int*>(table_lens_guard_.get()), /* table_lens */
+static_cast<int*>(table_lens_guard_->addr_), /* table_lens */
nullptr, /* scale_after_emb */
nullptr, /* scale_after_ewadd */
param.Out->mutable_data<float>(TARGET(kXPU)) /* top */);
CHECK_EQ(r, 0);
}
...
@@ -14,10 +14,9 @@
#pragma once
-#include <memory>
#include <vector>
+#include "lite/backends/xpu/target_wrapper.h"  // XPUScratchPadGuard
#include "lite/core/kernel.h"
-#include "lite/kernels/xpu/utils.h"  // XPUFreeDeleter
namespace paddle {
namespace lite {
@@ -36,7 +35,7 @@ class XPUEmbeddingWithEltwiseAddCompute
private:
std::vector<const int64_t*> arg_ids_;
std::vector<const float*> arg_tables_;
-std::unique_ptr<void, XPUFreeDeleter> table_lens_guard_;
+XPUScratchPadGuard table_lens_guard_;
std::vector<int> table_lens_cpu_;
};
...
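In the kernel above, the raw `xpu_malloc` plus `unique_ptr<void, XPUFreeDeleter>` pair is replaced by `TargetWrapperXPU::MallocScratchPad`, which returns an `XPUScratchPadGuard` whose deleter frees the device buffer (unless it lives in L3) when the guard goes out of scope. Below is a simplified host-only sketch of that ownership pattern, assuming `malloc`/`free` as stand-ins for the device allocator; the real deleter goes through `XPU_CALL(xpu_free(...))` as shown in the `target_wrapper.h` diff earlier.

```cpp
#include <cstdlib>
#include <memory>

// Mirrors the XPUScratchPad / XPUScratchPadDeleter pair in target_wrapper.h.
struct XPUScratchPad {
  XPUScratchPad(void* addr, bool is_l3) : addr_(addr), is_l3_(is_l3) {}
  void* addr_;
  bool is_l3_;
};

struct XPUScratchPadDeleter {
  void operator()(XPUScratchPad* sp) const {
    if (!sp->is_l3_) {
      std::free(sp->addr_);  // stand-in for XPU_CALL(xpu_free(sp->addr_))
    }
    delete sp;
  }
};

using XPUScratchPadGuard = std::unique_ptr<XPUScratchPad, XPUScratchPadDeleter>;

// Stand-in for TargetWrapperXPU::MallocScratchPad(size, use_l3 = false).
XPUScratchPadGuard MallocScratchPad(std::size_t size, bool use_l3 = false) {
  void* ptr = std::malloc(size);  // the real code calls TargetWrapperXPU::Malloc
  return XPUScratchPadGuard(new XPUScratchPad(ptr, use_l3));
}

int main() {
  auto guard = MallocScratchPad(64 * sizeof(int));
  // Device copies then target guard->addr_, e.g.
  // XPU_CALL(xpu_memcpy(guard->addr_, src, n, XPU_HOST_TO_DEVICE));
  return 0;  // buffer released by the deleter here
}
```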
@@ -27,8 +27,8 @@ namespace {
void FillMax(float max, float* xpu_ptr) {
float maxs[4] = {max, 0.0f, 0.0f, 0.0f};
-xpu_memcpy(
-xpu_ptr, maxs, 4 * sizeof(float), XPUMemcpyKind::XPU_HOST_TO_DEVICE);
+XPU_CALL(xpu_memcpy(
+xpu_ptr, maxs, 4 * sizeof(float), XPUMemcpyKind::XPU_HOST_TO_DEVICE));
}
void GrnnLayout(int batch,
@@ -156,8 +156,8 @@ class MMDNNIdInfo {
idx_sorted.data(),
idx_sorted.size() * sizeof(int));
offset += idx_sorted.size() * sizeof(int);
-xpu_memcpy(
-l3_buffer_, cpu_buffer_, offset, XPUMemcpyKind::XPU_HOST_TO_DEVICE);
+XPU_CALL(xpu_memcpy(
+l3_buffer_, cpu_buffer_, offset, XPUMemcpyKind::XPU_HOST_TO_DEVICE));
}
};
@@ -221,29 +221,32 @@ class MMDNNFcOp {
int m,
float* out,
const float* in_max_by_caller = nullptr) {
+int r = 0;
if (in_max_by_caller == nullptr) {
-xdnn::findmax<float>(ctx, in, m * k_, in_max_);
+r = xdnn::findmax<float>(ctx, in, m * k_, in_max_);
+CHECK_EQ(r, 0);
in_max_by_caller = in_max_;
}
-xdnn::gemm_int16_maxptr<float, int16_t, float>(ctx,
+r = xdnn::gemm_int16_maxptr<float, int16_t, float>(ctx,
false,
true,
m,
n_,
k_,
1.0f,
in,
k_,
weight_,
k_,
0.0f,
out,
n_,
bias_,
act_type_,
in_max_by_caller,
weight_max_,
out_max);
+CHECK_EQ(r, 0);
}
};
@@ -331,44 +334,49 @@ class MMDNNGrnnOp {
gru_out = l3_buffer + 4 * slot_size;
}
+int r = 0;
-xdnn::search_seq2batch(ctx,
+r = xdnn::search_seq2batch(ctx,
batch,
max_width,
cap_e_,
sentense.idx_sorted_32,
sentense.lod_32,
sentense.new_offset_32,
in,
seq2batch_out);
+CHECK_EQ(r, 0);
-xdnn::findmax<float>(ctx, in, cap_l * cap_e_, input_max_);
+r = xdnn::findmax<float>(ctx, in, cap_l * cap_e_, input_max_);
+CHECK_EQ(r, 0);
fc_e2h0_.Infer(ctx, seq2batch_out, cap_l, fc_e2h_out, input_max_);
fc_e2h1_.Infer(
ctx, seq2batch_out, cap_l, fc_e2h_out + cap_l * cap_h_, input_max_);
fc_e2h2_.Infer(
ctx, seq2batch_out, cap_l, fc_e2h_out + cap_l * cap_h_ * 2, input_max_);
-xdnn::search_grnn<float, int16_t>(ctx,
+r = xdnn::search_grnn<float, int16_t>(ctx,
cap_l,
cap_h_,
cap_e_,
max_width,
sentense.new_offset_32,
fc_e2h_out,
dense_h2h_,
gru_out,
dense_h2h_max_[0],
dense_h2h_max_[1],
dense_h2h_max_[2]);
+CHECK_EQ(r, 0);
-xdnn::search_batch2seq(ctx,
+r = xdnn::search_batch2seq(ctx,
batch,
max_width,
cap_h_,
sentense.idx_sorted_32,
sentense.lod_32,
sentense.new_offset_32,
gru_out,
out);
+CHECK_EQ(r, 0);
}
};
@@ -435,38 +443,43 @@ class MMDNNAttentionOp {
}
seqfc_.Infer(ctx, input, cap_l, seqfc_out);
+int r = 0;
-xdnn::search_noaligned_mat_mul(ctx,
+r = xdnn::search_noaligned_mat_mul(ctx,
0,
1,
batch,
lod_32,
max_width,
dim_,
alpha0_,
input,
seqfc_out,
batchgemm0_out);
+CHECK_EQ(r, 0);
-xdnn::search_seq_softmax(
+r = xdnn::search_seq_softmax(
ctx, batchgemm0_out, seq_softmax_out, lod_32, batch, max_width);
+CHECK_EQ(r, 0);
-xdnn::search_noaligned_mat_mul(ctx,
+r = xdnn::search_noaligned_mat_mul(ctx,
0,
0,
batch,
lod_32,
max_width,
dim_,
alpha1_,
seq_softmax_out,
input,
batchgemm1_out);
+CHECK_EQ(r, 0);
-xdnn::sequence_pooling_forward(ctx,
+r = xdnn::sequence_pooling_forward(ctx,
xdnn::Pooling_t::MAX_WITHOUT_INDEX,
batch,
lod_32,
dim_,
batchgemm1_out,
nullptr,
pool_out);
+CHECK_EQ(r, 0);
}
};
@@ -510,12 +523,13 @@ class MMDNNMatchConvTopk {
float conv_w_max,
int dim_t,
int dim_in,
+int out_channel,
int upper_bound_batch,
int upper_bound_seqlen,
const std::vector<int>& topks) {
dim_t_ = dim_t;
dim_in_ = dim_in;
-out_channel_ = 5;  // TODO(miaotianxiang):
+out_channel_ = out_channel;
topks_ = topks;
xw_fc_.Init(input_w,
@@ -553,10 +567,10 @@ class MMDNNMatchConvTopk {
topks_xpu_guard_ =
TargetWrapperXPU::MallocScratchPad(topks_.size() * sizeof(int), false);
topks_xpu_ = reinterpret_cast<int*>(topks_xpu_guard_->addr_);
-xpu_memcpy(topks_xpu_,
-topks_.data(),
-topks_.size() * sizeof(int),
-XPUMemcpyKind::XPU_HOST_TO_DEVICE);
+XPU_CALL(xpu_memcpy(topks_xpu_,
+topks_.data(),
+topks_.size() * sizeof(int),
+XPUMemcpyKind::XPU_HOST_TO_DEVICE));
useless_topk_pos_guard_ =
TargetWrapperXPU::MallocScratchPad(4 * sizeof(int), false);
useless_topk_pos_ = reinterpret_cast<int*>(useless_topk_pos_guard_->addr_);
@@ -576,18 +590,18 @@ class MMDNNMatchConvTopk {
for (auto e : left_lod) {
left_lod_32_cpu.push_back(e);
}
-xpu_memcpy(left_lod_32_,
-left_lod_32_cpu.data(),
-left_lod_32_cpu.size() * sizeof(int),
-XPUMemcpyKind::XPU_HOST_TO_DEVICE);
+XPU_CALL(xpu_memcpy(left_lod_32_,
+left_lod_32_cpu.data(),
+left_lod_32_cpu.size() * sizeof(int),
+XPUMemcpyKind::XPU_HOST_TO_DEVICE));
std::vector<int> right_lod_32_cpu;
for (auto e : right_lod) {
right_lod_32_cpu.push_back(e);
}
-xpu_memcpy(right_lod_32_,
-right_lod_32_cpu.data(),
-right_lod_32_cpu.size() * sizeof(int),
-XPUMemcpyKind::XPU_HOST_TO_DEVICE);
+XPU_CALL(xpu_memcpy(right_lod_32_,
+right_lod_32_cpu.data(),
+right_lod_32_cpu.size() * sizeof(int),
+XPUMemcpyKind::XPU_HOST_TO_DEVICE));
std::vector<int> lod_match = {0};
std::vector<int> lod_conv = {0};
@@ -611,18 +625,18 @@ class MMDNNMatchConvTopk {
left_seqlen_sum += len_x;
right_seqlen_sum += len_y;
}
-xpu_memcpy(match_lod_32_,
-lod_match.data(),
-lod_match.size() * sizeof(int),
-XPUMemcpyKind::XPU_HOST_TO_DEVICE);
+XPU_CALL(xpu_memcpy(match_lod_32_,
+lod_match.data(),
+lod_match.size() * sizeof(int),
+XPUMemcpyKind::XPU_HOST_TO_DEVICE));
-xpu_memcpy(conv_lod_32_,
-lod_conv.data(),
-lod_conv.size() * sizeof(int),
-XPUMemcpyKind::XPU_HOST_TO_DEVICE);
+XPU_CALL(xpu_memcpy(conv_lod_32_,
+lod_conv.data(),
+lod_conv.size() * sizeof(int),
+XPUMemcpyKind::XPU_HOST_TO_DEVICE));
-xpu_memcpy(topk_offset_32_,
-lod_topk.data(),
-lod_topk.size() * sizeof(int),
-XPUMemcpyKind::XPU_HOST_TO_DEVICE);
+XPU_CALL(xpu_memcpy(topk_offset_32_,
+lod_topk.data(),
+lod_topk.size() * sizeof(int),
+XPUMemcpyKind::XPU_HOST_TO_DEVICE));
float* xwy_out = hbm_buffer_;
float* conv_out = hbm_buffer_ + x_mul_y_sum * dim_t_;
@@ -640,19 +654,21 @@ class MMDNNMatchConvTopk {
int max_width = std::max(left_seqlen_max, right_seqlen_max);
xw_fc_.Infer(ctx, left->data<float>(), left_seqlen_sum, xw_out);
+int r = 0;
-xdnn::match_matrix_tensor(ctx,
+r = xdnn::match_matrix_tensor(ctx,
batch,
xw_out,
right->data<float>(),
left_lod_32_,
right_lod_32_,
dim_t_,
dim_in_,
xwy_out,
xw_fc_.out_max,
xdnn::Activation_t::RELU,
max_width);
+CHECK_EQ(r, 0);
-xdnn::search_varconv<float, int16_t>(
+r = xdnn::search_varconv<float, int16_t>(
ctx,
batch,
dim_t_,
@@ -668,24 +684,27 @@ class MMDNNMatchConvTopk {
conv_out,
conv_weight_max_,
xdnn::Activation_t::RELU);  // TODO(miaotianxiang):
+CHECK_EQ(r, 0);
-xdnn::sequence_concat(ctx,
+r = xdnn::sequence_concat(ctx,
xwy_out,
match_lod_32_,
conv_out,
conv_lod_32_,
seq_concat_out,
batch);
+CHECK_EQ(r, 0);
-xdnn::sequence_topk_avg_pooling(ctx,
+r = xdnn::sequence_topk_avg_pooling(ctx,
seq_concat_out,
seq_avg_topk_out,
useless_topk_pos_,
batch,
dim_t_ + out_channel_,
topk_offset_32_,
left_lod_32_,
right_lod_32_,
topks_xpu_,
topks_.size());
+CHECK_EQ(r, 0);
}
};
@@ -802,34 +821,38 @@ class MMDNNBidEmbGrnnAtt {
pool_rv = grnn_rv_pool_out->mutable_data<float>(TARGET(kXPU));
att_out = att_pool_out->mutable_data<float>(TARGET(kXPU));
+int r = 0;
-xdnn::search_bid_emb_ew(ctx,
+r = xdnn::search_bid_emb_ew(ctx,
batch,
sentense.lod_64,
sentense.id0_64,
sentense.id1_64,
table_,
table_len_,
emb_dim_,
emb_fw,
emb_rv,
table_len_ - 2,
1);
+CHECK_EQ(r, 0);
bi_rv_.Infer(ctx,
sentense,
emb_rv,
grnn_rv,
l3_buffer + 2 * slot_len,
l3_size - 2 * slot_len * sizeof(float));
-xdnn::sequence_reverse(
+r = xdnn::sequence_reverse(
ctx, batch, sentense.lod_32, cap_h_, grnn_rv, grnn_rv_rv);
+CHECK_EQ(r, 0);
-xdnn::sequence_pooling_forward(ctx,
+r = xdnn::sequence_pooling_forward(ctx,
xdnn::Pooling_t::LAST,
batch,
sentense.lod_32,
cap_h_,
grnn_rv,
nullptr,
pool_rv);
+CHECK_EQ(r, 0);
bi_fw_.Infer(ctx,
sentense,
@@ -837,19 +860,23 @@ class MMDNNBidEmbGrnnAtt {
grnn_fw,
l3_buffer + 2 * slot_len,
l3_size - 2 * slot_len * sizeof(float));
-xdnn::sequence_pooling_forward(ctx,
+r = xdnn::sequence_pooling_forward(ctx,
xdnn::Pooling_t::LAST,
batch,
sentense.lod_32,
cap_h_,
grnn_fw,
nullptr,
pool_fw);
+CHECK_EQ(r, 0);
const int concat_widths[] = {cap_h_, cap_h_, cap_h_};
const float* concat_ptrs[] = {emb_fw, grnn_fw, grnn_rv_rv};
-xdnn::concat<float>(
+r = xdnn::concat<float>(
ctx, cap_l, concat_widths + 1, 2, concat_ptrs + 1, concat_2in);
+CHECK_EQ(r, 0);
-xdnn::concat<float>(ctx, cap_l, concat_widths, 3, concat_ptrs, concat_3in);
+r = xdnn::concat<float>(
+ctx, cap_l, concat_widths, 3, concat_ptrs, concat_3in);
+CHECK_EQ(r, 0);
att_.Infer(ctx,
sentense,
concat_2in,
@@ -899,16 +926,18 @@ class MMDNNEmbAtt {
int cap_l = sentense.lod.back();
const float* emb_tables[] = {table_, table_};
const int64_t* emb_indices[] = {sentense.id0_64, sentense.id1_64};
-xdnn::embedding_with_ewadd<float, int64_t, false, false>(ctx,
+int r =
+xdnn::embedding_with_ewadd<float, int64_t, false, false>(ctx,
emb_dim_,
cap_l,
2,
table_len_ - 2,
emb_tables,
emb_indices,
nullptr,
nullptr,
emb_fw);
+CHECK_EQ(r, 0);
att_.Infer(ctx, sentense, emb_fw, att_out, l3_buffer, l3_size);
}
};
@@ -990,7 +1019,7 @@ class MMDNNMergeAll {
fc2_.Init(
fc2_w, fc2_w_max, fc2_b, fc2_n_, fc2_k_, xdnn::Activation_t::LINEAR);
-int hbm_total_len = max_cap_l * cap_h_ * 4 +
+int hbm_total_len = max_cap_l * cap_e_ * 2 + max_cap_l * cap_h_ * 2 +
upper_bound_batch * (2 * cap_h_ + fc0_k_ + fc0_n_ +
fc1_k_ + fc1_n_ + fc2_n_);
hbm_buffer_guard_ = TargetWrapperXPU::MallocScratchPad(
@@ -1000,7 +1029,7 @@ class MMDNNMergeAll {
void Infer(xdnn::Context* ctx,
const MMDNNIdInfo& sentense,
-const std::vector<lite::Tensor*> concat_2in1_x,
+const std::vector<lite::Tensor*> concat_topk_x,
const std::vector<lite::Tensor*> concat_7in1_x,
lite::Tensor* out,
float* l3_buffer = nullptr,
@@ -1010,13 +1039,13 @@ class MMDNNMergeAll {
float* topk_concat_out_fw = hbm_buffer_;
int hbm_total_len =
-cap_l * cap_h_ * 4 +
+cap_l * cap_e_ * 2 + cap_l * cap_h_ * 2 +
batch * (2 * cap_h_ + fc0_k_ + fc0_n_ + fc1_k_ + fc1_n_ + fc2_n_);
if (l3_size > 0 && l3_size >= hbm_total_len * sizeof(float)) {
topk_concat_out_fw = l3_buffer;
}
-float* topk_concat_out_rv = topk_concat_out_fw + cap_l * cap_h_;
-float* grnn_fw = topk_concat_out_rv + cap_l * cap_h_;
+float* topk_concat_out_rv = topk_concat_out_fw + cap_l * cap_e_;
+float* grnn_fw = topk_concat_out_rv + cap_l * cap_e_;
float* grnn_rv = grnn_fw + cap_l * cap_h_;
float* pool_fw = grnn_rv + cap_l * cap_h_;
float* pool_rv = pool_fw + batch * cap_h_;
@@ -1027,18 +1056,27 @@ class MMDNNMergeAll {
// float* fc2_out = fc1_out + batch * fc1_n_;
float* fc2_out = out->mutable_data<float>(TARGET(kXPU));
-const int concat_widths[] = {static_cast<int>(concat_2in1_x[0]->dims()[1]),
-static_cast<int>(concat_2in1_x[1]->dims()[1])};
-const float* concat_ptrs[] = {concat_2in1_x[0]->data<float>(),
-concat_2in1_x[1]->data<float>()};
-xdnn::concat<float>(
-ctx, cap_l, concat_widths, 2, concat_ptrs, topk_concat_out_fw);
+std::vector<int> concat_widths;
+std::vector<const float*> concat_ptrs;
+for (const auto* t : concat_topk_x) {
+concat_widths.push_back(static_cast<int>(t->dims()[1]));
+concat_ptrs.push_back(t->data<float>());
+}
+int r = 0;
+r = xdnn::concat<float>(ctx,
+cap_l,
+concat_widths.data(),
+concat_widths.size(),
+concat_ptrs.data(),
+topk_concat_out_fw);
+CHECK_EQ(r, 0);
-xdnn::sequence_reverse(ctx,
+r = xdnn::sequence_reverse(ctx,
batch,
sentense.lod_32,
cap_e_,
topk_concat_out_fw,
topk_concat_out_rv);
+CHECK_EQ(r, 0);
coverage_fw_.Infer(ctx,
sentense,
topk_concat_out_fw,
...@@ -1051,22 +1089,24 @@ class MMDNNMergeAll { ...@@ -1051,22 +1089,24 @@ class MMDNNMergeAll {
grnn_rv, grnn_rv,
l3_buffer + hbm_total_len, l3_buffer + hbm_total_len,
l3_size - hbm_total_len * sizeof(float)); l3_size - hbm_total_len * sizeof(float));
xdnn::sequence_pooling_forward(ctx, r = xdnn::sequence_pooling_forward(ctx,
xdnn::Pooling_t::LAST, xdnn::Pooling_t::LAST,
batch, batch,
sentense.lod_32, sentense.lod_32,
cap_h_, cap_h_,
grnn_fw, grnn_fw,
nullptr, nullptr,
pool_fw); pool_fw);
xdnn::sequence_pooling_forward(ctx, CHECK_EQ(r, 0);
xdnn::Pooling_t::LAST, r = xdnn::sequence_pooling_forward(ctx,
batch, xdnn::Pooling_t::LAST,
sentense.lod_32, batch,
cap_h_, sentense.lod_32,
grnn_rv, cap_h_,
nullptr, grnn_rv,
pool_rv); nullptr,
pool_rv);
CHECK_EQ(r, 0);
const int concat_widths_fc0[] = { const int concat_widths_fc0[] = {
static_cast<int>(concat_7in1_x[0]->dims()[1]), static_cast<int>(concat_7in1_x[0]->dims()[1]),
...@@ -1089,11 +1129,13 @@ class MMDNNMergeAll { ...@@ -1089,11 +1129,13 @@ class MMDNNMergeAll {
const int concat_widths_fc1[] = {cap_h_, cap_h_, fc0_n_}; const int concat_widths_fc1[] = {cap_h_, cap_h_, fc0_n_};
const float* concat_ptrs_fc1[] = {pool_fw, pool_rv, fc0_out}; const float* concat_ptrs_fc1[] = {pool_fw, pool_rv, fc0_out};
xdnn::concat<float>( r = xdnn::concat<float>(
ctx, batch, concat_widths_fc0, 7, concat_ptrs_fc0, fc0_in); ctx, batch, concat_widths_fc0, 7, concat_ptrs_fc0, fc0_in);
CHECK_EQ(r, 0);
fc0_.Infer(ctx, fc0_in, batch, fc0_out); fc0_.Infer(ctx, fc0_in, batch, fc0_out);
xdnn::concat<float>( r = xdnn::concat<float>(
ctx, batch, concat_widths_fc1, 3, concat_ptrs_fc1, fc1_in); ctx, batch, concat_widths_fc1, 3, concat_ptrs_fc1, fc1_in);
CHECK_EQ(r, 0);
fc1_.Infer(ctx, fc1_in, batch, fc1_out); fc1_.Infer(ctx, fc1_in, batch, fc1_out);
fc2_.Infer(ctx, fc1_out, batch, fc2_out); fc2_.Infer(ctx, fc1_out, batch, fc2_out);
} }
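
The buffer placement at the top of Infer follows an L3-first policy: the whole contiguous scratch layout is promoted to on-chip L3 only when it fits, otherwise the pre-reserved HBM region is used. A minimal sketch of that rule (PickWorkspace, hbm_buffer and needed_floats are illustrative names, not identifiers from this file):

#include <cstddef>

// Use the on-chip L3 workspace when the whole layout fits, else fall back.
inline float* PickWorkspace(float* hbm_buffer, float* l3_buffer,
                            size_t l3_size, size_t needed_floats) {
  if (l3_size > 0 && l3_size >= needed_floats * sizeof(float)) {
    return l3_buffer;  // fast path: scratch lives entirely in L3
  }
  return hbm_buffer;  // slow path: pre-allocated HBM region
}
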
...@@ -1111,14 +1153,12 @@ class XPUMmdnnBidEmbGrnnAttCompute ...@@ -1111,14 +1153,12 @@ class XPUMmdnnBidEmbGrnnAttCompute
private: private:
MMDNNIdInfo id_; MMDNNIdInfo id_;
MMDNNBidEmbGrnnAtt compound_; MMDNNBidEmbGrnnAtt compound_;
int upper_bound_batch_ = 40;
int upper_bound_seqlen_ = 512;
}; };
void XPUMmdnnBidEmbGrnnAttCompute::PrepareForRun() { void XPUMmdnnBidEmbGrnnAttCompute::PrepareForRun() {
auto& param = this->Param<param_t>(); auto& param = this->Param<param_t>();
id_.Init(upper_bound_batch_, upper_bound_seqlen_); id_.Init(XPU_MAX_LOD_SIZE, XPU_MAX_LOD_SEQ_LEN);
compound_.Init(param.emb_tbl, compound_.Init(param.emb_tbl,
param.grnn_fw_wh, param.grnn_fw_wh,
param.grnn_fw_wh_maxs, param.grnn_fw_wh_maxs,
...@@ -1131,8 +1171,8 @@ void XPUMmdnnBidEmbGrnnAttCompute::PrepareForRun() { ...@@ -1131,8 +1171,8 @@ void XPUMmdnnBidEmbGrnnAttCompute::PrepareForRun() {
param.att_fc_w, param.att_fc_w,
param.att_fc_w_max, param.att_fc_w_max,
param.att_fc_b, param.att_fc_b,
upper_bound_batch_, XPU_MAX_LOD_SIZE,
upper_bound_seqlen_); XPU_MAX_LOD_SEQ_LEN);
} }
void XPUMmdnnBidEmbGrnnAttCompute::Run() { void XPUMmdnnBidEmbGrnnAttCompute::Run() {
...@@ -1157,6 +1197,76 @@ void XPUMmdnnBidEmbGrnnAttCompute::Run() { ...@@ -1157,6 +1197,76 @@ void XPUMmdnnBidEmbGrnnAttCompute::Run() {
xpu_ctx->workspace_l3_size - xpu_ctx->used_l3_size); xpu_ctx->workspace_l3_size - xpu_ctx->used_l3_size);
} }
class XPUMmdnnBidEmbGrnnAttCompute2
: public KernelLite<TARGET(kXPU), PRECISION(kFloat)> {
public:
using param_t = operators::XPUMmdnnBidEmbGrnnAttParam2;
void PrepareForRun() override;
void Run() override;
private:
MMDNNIdInfo id_;
MMDNNBidEmbGrnnAtt compound_;
};
void XPUMmdnnBidEmbGrnnAttCompute2::PrepareForRun() {
auto& param = this->Param<param_t>();
id_.Init(XPU_MAX_LOD_SIZE, XPU_MAX_LOD_SEQ_LEN);
compound_.Init(param.emb_tbl,
param.grnn_fw_wh,
param.grnn_fw_wh_maxs,
param.grnn_fw_wi,
param.grnn_fw_wi_maxs,
param.grnn_rv_wh,
param.grnn_rv_wh_maxs,
param.grnn_rv_wi,
param.grnn_rv_wi_maxs,
param.att_fc_w,
param.att_fc_w_max,
param.att_fc_b,
XPU_MAX_LOD_SIZE,
XPU_MAX_LOD_SEQ_LEN);
}
void XPUMmdnnBidEmbGrnnAttCompute2::Run() {
auto& param = this->Param<param_t>();
auto& ctx = this->ctx_->As<XPUContext>();
auto* xpu_ctx = ctx.GetRawContext();
int batch = param.id0->lod()[0].size() - 1;
id_.Update(param.id0, param.id1);
compound_.Infer(ctx.GetRawContext(),
batch,
id_,
param.grnn_fw_pool_out,
param.grnn_rv_pool_out,
param.att_pool_out,
param.concat_3in1_out,
param.emb_fw_out,
reinterpret_cast<float*>(
reinterpret_cast<char*>(xpu_ctx->workspace_l3_ptr) +
xpu_ctx->used_l3_size),
xpu_ctx->workspace_l3_size - xpu_ctx->used_l3_size);
int num = param.id0->numel();
int embed_dim = param.emb_tbl->dims()[1];
// TODO(miaotianxiang):
int r = xdnn::embedding<float, int64_t>(
ctx.GetRawContext(), /* context */
num, /* num */
param.id0->data<int64_t>(), /* indices */
embed_dim, /* embed_dim */
param.emb_tbl->data<float>(), /* table */
param.emb0_out->mutable_data<float>(TARGET(kXPU)), /* top */
128000 /* padding_idx */);
CHECK_EQ(r, 0);
}
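
The tail above is the only functional difference between this variant's Run and the original kernel: after the fused compound finishes, it additionally materializes emb0_out as a plain lookup of id0 in which rows whose id equals 128000 are zero-filled. The value 128000 is the pad token id assumed by this model variant, not a general xdnn constant.
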
class XPUMmdnnBidEmbAttCompute class XPUMmdnnBidEmbAttCompute
: public KernelLite<TARGET(kXPU), PRECISION(kFloat)> { : public KernelLite<TARGET(kXPU), PRECISION(kFloat)> {
public: public:
...@@ -1169,20 +1279,18 @@ class XPUMmdnnBidEmbAttCompute ...@@ -1169,20 +1279,18 @@ class XPUMmdnnBidEmbAttCompute
private: private:
MMDNNIdInfo id_; MMDNNIdInfo id_;
MMDNNEmbAtt compound_; MMDNNEmbAtt compound_;
int upper_bound_batch_ = 40;
int upper_bound_seqlen_ = 512;
}; };
void XPUMmdnnBidEmbAttCompute::PrepareForRun() { void XPUMmdnnBidEmbAttCompute::PrepareForRun() {
auto& param = this->Param<param_t>(); auto& param = this->Param<param_t>();
id_.Init(upper_bound_batch_, upper_bound_seqlen_); id_.Init(XPU_MAX_LOD_SIZE, XPU_MAX_LOD_SEQ_LEN);
compound_.Init(param.emb_tbl, compound_.Init(param.emb_tbl,
param.att_fc_w, param.att_fc_w,
param.att_fc_w_max, param.att_fc_w_max,
param.att_fc_b, param.att_fc_b,
upper_bound_batch_, XPU_MAX_LOD_SIZE,
upper_bound_seqlen_); XPU_MAX_LOD_SEQ_LEN);
} }
void XPUMmdnnBidEmbAttCompute::Run() { void XPUMmdnnBidEmbAttCompute::Run() {
...@@ -1215,8 +1323,6 @@ class XPUMmdnnMatchConvTopkCompute ...@@ -1215,8 +1323,6 @@ class XPUMmdnnMatchConvTopkCompute
private: private:
MMDNNMatchConvTopk compound_; MMDNNMatchConvTopk compound_;
int upper_bound_batch_ = 40;
int upper_bound_seqlen_ = 512;
}; };
void XPUMmdnnMatchConvTopkCompute::PrepareForRun() { void XPUMmdnnMatchConvTopkCompute::PrepareForRun() {
...@@ -1228,8 +1334,9 @@ void XPUMmdnnMatchConvTopkCompute::PrepareForRun() { ...@@ -1228,8 +1334,9 @@ void XPUMmdnnMatchConvTopkCompute::PrepareForRun() {
param.conv_w_max, param.conv_w_max,
param.dim_t, param.dim_t,
param.input_w->dims()[0], param.input_w->dims()[0],
upper_bound_batch_, param.output_channel,
upper_bound_seqlen_, XPU_MAX_LOD_SIZE,
XPU_MAX_LOD_SEQ_LEN,
param.topks); param.topks);
} }
...@@ -1261,14 +1368,12 @@ class XPUMmdnnMergeAllCompute ...@@ -1261,14 +1368,12 @@ class XPUMmdnnMergeAllCompute
private: private:
MMDNNIdInfo id_; MMDNNIdInfo id_;
MMDNNMergeAll compound_; MMDNNMergeAll compound_;
int upper_bound_batch_ = 40;
int upper_bound_seqlen_ = 512;
}; };
void XPUMmdnnMergeAllCompute::PrepareForRun() { void XPUMmdnnMergeAllCompute::PrepareForRun() {
auto& param = this->Param<param_t>(); auto& param = this->Param<param_t>();
id_.Init(upper_bound_batch_, upper_bound_seqlen_); id_.Init(XPU_MAX_LOD_SIZE, XPU_MAX_LOD_SEQ_LEN);
compound_.Init(param.grnn_fw_wh, compound_.Init(param.grnn_fw_wh,
param.grnn_fw_wh_maxs, param.grnn_fw_wh_maxs,
param.grnn_fw_wi, param.grnn_fw_wi,
...@@ -1286,8 +1391,8 @@ void XPUMmdnnMergeAllCompute::PrepareForRun() { ...@@ -1286,8 +1391,8 @@ void XPUMmdnnMergeAllCompute::PrepareForRun() {
param.fc2_w, param.fc2_w,
param.fc2_w_max, param.fc2_w_max,
param.fc2_b, param.fc2_b,
upper_bound_batch_, XPU_MAX_LOD_SIZE,
upper_bound_seqlen_); XPU_MAX_LOD_SEQ_LEN);
} }
void XPUMmdnnMergeAllCompute::Run() { void XPUMmdnnMergeAllCompute::Run() {
...@@ -1296,10 +1401,10 @@ void XPUMmdnnMergeAllCompute::Run() { ...@@ -1296,10 +1401,10 @@ void XPUMmdnnMergeAllCompute::Run() {
auto* xpu_ctx = ctx.GetRawContext(); auto* xpu_ctx = ctx.GetRawContext();
id_.Update(param.concat_2in1_x[0], param.concat_2in1_x[1]); id_.Update(param.concat_topk_x[0], param.concat_topk_x[1]);
compound_.Infer(ctx.GetRawContext(), compound_.Infer(ctx.GetRawContext(),
id_, id_,
param.concat_2in1_x, param.concat_topk_x,
param.concat_7in1_x, param.concat_7in1_x,
param.out, param.out,
reinterpret_cast<float*>( reinterpret_cast<float*>(
...@@ -1335,6 +1440,29 @@ REGISTER_LITE_KERNEL(__xpu__mmdnn_bid_emb_grnn_att, ...@@ -1335,6 +1440,29 @@ REGISTER_LITE_KERNEL(__xpu__mmdnn_bid_emb_grnn_att,
.BindOutput("emb_fw_out", {LiteType::GetTensorTy(TARGET(kXPU))}) .BindOutput("emb_fw_out", {LiteType::GetTensorTy(TARGET(kXPU))})
.Finalize(); .Finalize();
REGISTER_LITE_KERNEL(__xpu__mmdnn_bid_emb_grnn_att2,
kXPU,
kFloat,
kNCHW,
paddle::lite::kernels::xpu::XPUMmdnnBidEmbGrnnAttCompute2,
def)
.BindInput("id0", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kInt64))})
.BindInput("id1", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kInt64))})
.BindInput("emb_tbl", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindInput("grnn_fw_wh", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindInput("grnn_fw_wi", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindInput("grnn_rv_wh", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindInput("grnn_rv_wi", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindInput("att_fc_w", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindInput("att_fc_b", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindOutput("emb0_out", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindOutput("grnn_fw_pool_out", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindOutput("grnn_rv_pool_out", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindOutput("att_pool_out", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindOutput("concat_3in1_out", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindOutput("emb_fw_out", {LiteType::GetTensorTy(TARGET(kXPU))})
.Finalize();
REGISTER_LITE_KERNEL(__xpu__mmdnn_bid_emb_att, REGISTER_LITE_KERNEL(__xpu__mmdnn_bid_emb_att,
kXPU, kXPU,
kFloat, kFloat,
...@@ -1371,7 +1499,7 @@ REGISTER_LITE_KERNEL(__xpu__mmdnn_merge_all, ...@@ -1371,7 +1499,7 @@ REGISTER_LITE_KERNEL(__xpu__mmdnn_merge_all,
paddle::lite::kernels::xpu::XPUMmdnnMergeAllCompute, paddle::lite::kernels::xpu::XPUMmdnnMergeAllCompute,
def) def)
.BindInput("concat_7in1_x", {LiteType::GetTensorTy(TARGET(kXPU))}) .BindInput("concat_7in1_x", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindInput("concat_2in1_x", {LiteType::GetTensorTy(TARGET(kXPU))}) .BindInput("concat_topk_x", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindInput("grnn_fw_wh", {LiteType::GetTensorTy(TARGET(kXPU))}) .BindInput("grnn_fw_wh", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindInput("grnn_fw_wi", {LiteType::GetTensorTy(TARGET(kXPU))}) .BindInput("grnn_fw_wi", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindInput("grnn_rv_wh", {LiteType::GetTensorTy(TARGET(kXPU))}) .BindInput("grnn_rv_wh", {LiteType::GetTensorTy(TARGET(kXPU))})
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
// limitations under the License. // limitations under the License.
#pragma once #pragma once
#include <vector> #include <vector>
#include "lite/backends/xpu/xpu_header_sitter.h" #include "lite/backends/xpu/xpu_header_sitter.h"
#include "lite/core/kernel.h" #include "lite/core/kernel.h"
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
// limitations under the License. // limitations under the License.
#pragma once #pragma once
#include <vector> #include <vector>
#include "lite/backends/xpu/xpu_header_sitter.h" #include "lite/backends/xpu/xpu_header_sitter.h"
#include "lite/core/kernel.h" #include "lite/core/kernel.h"
......
...@@ -22,16 +22,19 @@ namespace kernels { ...@@ -22,16 +22,19 @@ namespace kernels {
namespace xpu { namespace xpu {
void XPUMmdnnSearchAttentionCompute::PrepareForRun() { void XPUMmdnnSearchAttentionCompute::PrepareForRun() {
offset_xpu_guard_ = TargetWrapperXPU::MallocScratchPad(64 * sizeof(int)); offset_xpu_guard_ = TargetWrapperXPU::MallocScratchPad(
pad_begin_xpu_guard_ = TargetWrapperXPU::MallocScratchPad(64 * sizeof(int)); XPU_MAX_LOD_SIZE * sizeof(int), false /* use_l3 */);
w_max_xpu_guard_ = TargetWrapperXPU::MallocScratchPad(8 * sizeof(float)); pad_begin_xpu_guard_ = TargetWrapperXPU::MallocScratchPad(
XPU_MAX_LOD_SIZE * sizeof(int), false /* use_l3 */);
w_max_xpu_guard_ =
TargetWrapperXPU::MallocScratchPad(8 * sizeof(float), false /* use_l3 */);
buffer_at_l3_guard_ = TargetWrapperXPU::MallocScratchPad( buffer_at_l3_guard_ = TargetWrapperXPU::MallocScratchPad(
5 * L3_SLOT_SIZE * sizeof(float), false /* use_l3 */); 5 * L3_SLOT_SIZE * sizeof(float), false /* use_l3 */);
buffer_at_gm_guard_ = TargetWrapperXPU::MallocScratchPad( buffer_at_gm_guard_ = TargetWrapperXPU::MallocScratchPad(
5 * GM_SLOT_SIZE * sizeof(float), false /* use_l3 */); 5 * GM_SLOT_SIZE * sizeof(float), false /* use_l3 */);
offset_cpu.reset(new int[64]); offset_cpu.reset(new int[XPU_MAX_LOD_SIZE]);
pad_begin_cpu.reset(new int[64]); pad_begin_cpu.reset(new int[XPU_MAX_LOD_SIZE]);
} }
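
All raw XPU runtime calls in these kernels now go through XPU_CALL, so a nonzero return aborts with context instead of being silently dropped. The macro lives in lite/backends/xpu/target_wrapper.h; as a minimal sketch (illustrative, the real definition may format its message differently), such a wrapper expands to:

#define XPU_CALL(func)                                      \
  do {                                                      \
    auto __ret = (func);                                    \
    CHECK_EQ(__ret, 0) << "XPU runtime call " << #func      \
                       << " failed, error code: " << __ret; \
  } while (0)
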
void XPUMmdnnSearchAttentionCompute::Run() { void XPUMmdnnSearchAttentionCompute::Run() {
...@@ -72,18 +75,18 @@ void XPUMmdnnSearchAttentionCompute::Run() { ...@@ -72,18 +75,18 @@ void XPUMmdnnSearchAttentionCompute::Run() {
} }
offset_cpu[batch] = offset[batch]; offset_cpu[batch] = offset[batch];
xpu_memcpy(offset_xpu_guard_->addr_, XPU_CALL(xpu_memcpy(offset_xpu_guard_->addr_,
offset_cpu.get(), offset_cpu.get(),
offset.size() * sizeof(int), offset.size() * sizeof(int),
XPUMemcpyKind::XPU_HOST_TO_DEVICE); XPUMemcpyKind::XPU_HOST_TO_DEVICE));
xpu_memcpy(pad_begin_xpu_guard_->addr_, XPU_CALL(xpu_memcpy(pad_begin_xpu_guard_->addr_,
pad_begin_cpu.get(), pad_begin_cpu.get(),
batch * sizeof(int), batch * sizeof(int),
XPUMemcpyKind::XPU_HOST_TO_DEVICE); XPUMemcpyKind::XPU_HOST_TO_DEVICE));
xpu_memcpy(w_max_xpu_guard_->addr_, XPU_CALL(xpu_memcpy(w_max_xpu_guard_->addr_,
maxs_cpu, maxs_cpu,
8 * sizeof(float), 8 * sizeof(float),
XPUMemcpyKind::XPU_HOST_TO_DEVICE); XPUMemcpyKind::XPU_HOST_TO_DEVICE));
int* offset_xpu = reinterpret_cast<int*>(offset_xpu_guard_->addr_); int* offset_xpu = reinterpret_cast<int*>(offset_xpu_guard_->addr_);
int* pad_begin_xpu = reinterpret_cast<int*>(pad_begin_xpu_guard_->addr_); int* pad_begin_xpu = reinterpret_cast<int*>(pad_begin_xpu_guard_->addr_);
...@@ -115,90 +118,99 @@ void XPUMmdnnSearchAttentionCompute::Run() { ...@@ -115,90 +118,99 @@ void XPUMmdnnSearchAttentionCompute::Run() {
} }
const auto* bottom_data = X->data<float>(); const auto* bottom_data = X->data<float>();
xdnn::search_sequence_pad_depad(ctx.GetRawContext(), int r = 0;
const_cast<float*>(bottom_data), r = xdnn::search_sequence_pad_depad(ctx.GetRawContext(),
group_padding_output, const_cast<float*>(bottom_data),
offset_xpu, group_padding_output,
max_seq, offset_xpu,
batch, max_seq,
dim1, batch,
0); // is_depad = 0 dim1,
0); // is_depad = 0
CHECK_EQ(r, 0);
// find the activation max for the int16 GEMM below // find the activation max for the int16 GEMM below
xdnn::findmax<float>(ctx.GetRawContext(), r = xdnn::findmax<float>(ctx.GetRawContext(),
group_padding_output, group_padding_output,
batch * max_seq * dim1, batch * max_seq * dim1,
maxs_xpu); maxs_xpu);
xdnn::gemm_int16_maxptr<float, int16_t, float>( CHECK_EQ(r, 0);
ctx.GetRawContext(), r = xdnn::gemm_int16_maxptr<float, int16_t, float>(
false, ctx.GetRawContext(), /* ctx */
true, // trans_a, trans_b false, /* trans_a */
batch * max_seq, true, /* trans_b */
dim1, batch * max_seq, /* m */
dim1, // m, n, k dim1, /* n */
1.0f, dim1, /* k */
group_padding_output, 1.0f, /* alpha */
dim1, // alpha, data_a, lda group_padding_output, /* data_a */
w_data, dim1, /* lda */
dim1, w_data, /* data_b */
0.0f, // data_b, ldb, beta dim1, /* ldb */
seq_fc_output, 0.0f, /* beta */
dim1, seq_fc_output, /* data_c */
b_data, // data_c, ldc, bias dim1, /* ldc */
xdnn::Activation_t::LINEAR, b_data, /* bias */
maxs_xpu, xdnn::Activation_t::LINEAR, /* act */
maxs_xpu + 4, maxs_xpu, /* max_a */
nullptr); // max_a, max_b, max_c maxs_xpu + 4, /* max_b */
xdnn::search_aligned_mat_mul(ctx.GetRawContext(), nullptr /* max_c */);
0, CHECK_EQ(r, 0);
1, r = xdnn::search_aligned_mat_mul(ctx.GetRawContext(),
batch, 0,
max_seq, 1,
max_seq, batch,
dim1, max_seq,
alpha0, max_seq,
group_padding_output, dim1,
dim1, alpha0,
seq_fc_output, group_padding_output,
dim1, dim1,
batchgemm0_output, seq_fc_output,
max_seq); dim1,
xdnn::search_pad_mask(ctx.GetRawContext(), batchgemm0_output,
batchgemm0_output, max_seq);
attention_output, CHECK_EQ(r, 0);
pad_begin_xpu, r = xdnn::search_pad_mask(ctx.GetRawContext(),
batch, batchgemm0_output,
max_seq, attention_output,
max_seq, pad_begin_xpu,
batch, batch,
mask); max_seq,
xdnn::softmax2d_forward(ctx.GetRawContext(), max_seq,
attention_output, batch,
seq_softmax_output, mask);
batch * max_seq, CHECK_EQ(r, 0);
max_seq, r = xdnn::softmax2d_forward(ctx.GetRawContext(),
true); attention_output,
xdnn::search_aligned_mat_mul(ctx.GetRawContext(), seq_softmax_output,
0, batch * max_seq,
0, max_seq,
batch, true);
max_seq, CHECK_EQ(r, 0);
dim1, r = xdnn::search_aligned_mat_mul(ctx.GetRawContext(),
max_seq, 0,
alpha1, 0,
seq_softmax_output, batch,
max_seq, max_seq,
group_padding_output, dim1,
dim1, max_seq,
batchgemm1_output, alpha1,
dim1); seq_softmax_output,
xdnn::search_sequence_pad_depad(ctx.GetRawContext(), max_seq,
top_data, group_padding_output,
batchgemm1_output, dim1,
offset_xpu, batchgemm1_output,
max_seq, dim1);
batch, CHECK_EQ(r, 0);
dim1, r = xdnn::search_sequence_pad_depad(ctx.GetRawContext(),
1); // is_depad = 1 top_data,
batchgemm1_output,
offset_xpu,
max_seq,
batch,
dim1,
1); // is_depad = 1
CHECK_EQ(r, 0);
} }
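
Taken together, the calls above implement padded self-attention over the LoD input; the dataflow, with shapes in terms of the local names batch, max_seq and dim1, is:

// pad     : X (LoD)                         -> G [batch, max_seq, dim1]
// fc      : S = G * W^T + b (int16 GEMM)    -> [batch, max_seq, dim1]
// scores  : A = alpha0 * G x S^T per sample -> [batch, max_seq, max_seq]
// mask    : positions past each sequence's end are overwritten with `mask`
// softmax : row-wise over the last axis of A
// context : C = alpha1 * softmax(A) x G     -> [batch, max_seq, dim1]
// depad   : C -> top (LoD)
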
} // namespace xpu } // namespace xpu
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
// limitations under the License. // limitations under the License.
#pragma once #pragma once
#include "lite/core/kernel.h" #include "lite/core/kernel.h"
namespace paddle { namespace paddle {
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
// limitations under the License. // limitations under the License.
#pragma once #pragma once
#include "lite/core/kernel.h" #include "lite/core/kernel.h"
namespace paddle { namespace paddle {
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
// limitations under the License. // limitations under the License.
#pragma once #pragma once
#include "lite/core/kernel.h" #include "lite/core/kernel.h"
namespace paddle { namespace paddle {
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
// limitations under the License. // limitations under the License.
#pragma once #pragma once
#include "lite/core/kernel.h" #include "lite/core/kernel.h"
namespace paddle { namespace paddle {
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
// limitations under the License. // limitations under the License.
#pragma once #pragma once
#include "lite/core/kernel.h" #include "lite/core/kernel.h"
namespace paddle { namespace paddle {
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
// limitations under the License. // limitations under the License.
#pragma once #pragma once
#include "lite/core/kernel.h" #include "lite/core/kernel.h"
namespace paddle { namespace paddle {
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
// limitations under the License. // limitations under the License.
#pragma once #pragma once
#include "lite/core/kernel.h" #include "lite/core/kernel.h"
namespace paddle { namespace paddle {
......
...@@ -29,12 +29,13 @@ void LookupTableCompute::Run() { ...@@ -29,12 +29,13 @@ void LookupTableCompute::Run() {
int embed_dim = param.W->dims()[1]; int embed_dim = param.W->dims()[1];
int r = xdnn::embedding<float, int64_t>( int r = xdnn::embedding<float, int64_t>(
ctx.GetRawContext(), /* context */ ctx.GetRawContext(), /* context */
num, /* num */ num, /* num */
param.Ids->data<int64_t>(), /* indices */ param.Ids->data<int64_t>(), /* indices */
embed_dim, /* embed_dim */ embed_dim, /* embed_dim */
param.W->data<float>(), /* table */ param.W->data<float>(), /* table */
param.Out->mutable_data<float>(TARGET(kXPU)) /* top */); param.Out->mutable_data<float>(TARGET(kXPU)), /* top */
param.padding_idx /* padding_idx */);
CHECK_EQ(r, 0); CHECK_EQ(r, 0);
} }
......
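
The new padding_idx argument changes the lookup semantics: output rows whose id equals padding_idx are zero-filled instead of gathered from the table. A CPU reference of the assumed semantics (a sketch, not the xdnn device implementation):

#include <algorithm>
#include <cstdint>

void EmbeddingRef(int num, const int64_t* ids, int embed_dim,
                  const float* table, float* top, int64_t padding_idx) {
  for (int i = 0; i < num; ++i) {
    float* dst = top + static_cast<int64_t>(i) * embed_dim;
    if (ids[i] == padding_idx) {
      std::fill(dst, dst + embed_dim, 0.0f);  // padded position -> zeros
    } else {
      const float* src = table + ids[i] * embed_dim;
      std::copy(src, src + embed_dim, dst);   // ordinary gather
    }
  }
}
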
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
// limitations under the License. // limitations under the License.
#pragma once #pragma once
#include "lite/core/kernel.h" #include "lite/core/kernel.h"
namespace paddle { namespace paddle {
......
...@@ -23,12 +23,15 @@ namespace kernels { ...@@ -23,12 +23,15 @@ namespace kernels {
namespace xpu { namespace xpu {
void MatchMatrixTensorCompute::PrepareForRun() { void MatchMatrixTensorCompute::PrepareForRun() {
wx_max_xpu_guard_ = TargetWrapperXPU::MallocScratchPad(64 * sizeof(int)); wx_max_xpu_guard_ = TargetWrapperXPU::MallocScratchPad(
offset_l_xpu_guard_ = TargetWrapperXPU::MallocScratchPad(64 * sizeof(int)); XPU_MAX_LOD_SIZE * sizeof(int), false /* use_l3 */);
offset_r_xpu_guard_ = TargetWrapperXPU::MallocScratchPad(64 * sizeof(int)); offset_l_xpu_guard_ = TargetWrapperXPU::MallocScratchPad(
XPU_MAX_LOD_SIZE * sizeof(int), false /* use_l3 */);
offset_l_cpu.reset(new int[64]); offset_r_xpu_guard_ = TargetWrapperXPU::MallocScratchPad(
offset_r_cpu.reset(new int[64]); XPU_MAX_LOD_SIZE * sizeof(int), false /* use_l3 */);
offset_l_cpu.reset(new int[XPU_MAX_LOD_SIZE]);
offset_r_cpu.reset(new int[XPU_MAX_LOD_SIZE]);
} }
void MatchMatrixTensorCompute::Run() { void MatchMatrixTensorCompute::Run() {
...@@ -76,25 +79,25 @@ void MatchMatrixTensorCompute::Run() { ...@@ -76,25 +79,25 @@ void MatchMatrixTensorCompute::Run() {
int* offset_r_xpu = reinterpret_cast<int*>(offset_r_xpu_guard_->addr_); int* offset_r_xpu = reinterpret_cast<int*>(offset_r_xpu_guard_->addr_);
int r = xdnn::gemm_int16_tmp_api<float, int16_t, float>( int r = xdnn::gemm_int16_tmp_api<float, int16_t, float>(
ctx.GetRawContext(), /* ctx */ ctx.GetRawContext(), /* ctx */
false, false, /* trans_a */
false, /* trans_a, trans_b */ false, /* trans_b */
x->dims()[0], x->dims()[0], /* m */
dim_t * dim_in, dim_t * dim_in, /* n */
dim_in, /* m, n, k */ dim_in, /* k */
1.0f, 1.0f, /* alpha */
bottom_l_data, bottom_l_data, /* data_a */
dim_in, /* alpha, data_a, lda */ dim_in, /* lda */
w_data, w_data, /* data_b */
dim_t * dim_in, dim_t * dim_in, /* ldb */
0.0f, /* data_b, ldb, beta */ 0.0f, /* beta */
bottom_l_trans_data, bottom_l_trans_data, /* data_c */
dim_t * dim_in, /* data_c, ldc */ dim_t * dim_in, /* ldc */
nullptr, /* bias */ nullptr, /* bias */
xdnn::Activation_t::LINEAR, xdnn::Activation_t::LINEAR, /* act */
0.0f, 0.0f, /* max_a */
w_max, w_max, /* max_b */
wx_max /* max_a, max_b, max_c */); wx_max /* max_c */);
CHECK_EQ(r, 0); CHECK_EQ(r, 0);
int max_width = 0; int max_width = 0;
...@@ -110,14 +113,14 @@ void MatchMatrixTensorCompute::Run() { ...@@ -110,14 +113,14 @@ void MatchMatrixTensorCompute::Run() {
max_width = offset_r_cpu[i] - offset_r_cpu[i - 1]; max_width = offset_r_cpu[i] - offset_r_cpu[i - 1];
} }
} }
xpu_memcpy(offset_l_xpu, XPU_CALL(xpu_memcpy(offset_l_xpu,
offset_l_cpu.get(), offset_l_cpu.get(),
offset_l.size() * sizeof(int), offset_l.size() * sizeof(int),
XPUMemcpyKind::XPU_HOST_TO_DEVICE); XPUMemcpyKind::XPU_HOST_TO_DEVICE));
xpu_memcpy(offset_r_xpu, XPU_CALL(xpu_memcpy(offset_r_xpu,
offset_r_cpu.get(), offset_r_cpu.get(),
offset_r.size() * sizeof(int), offset_r.size() * sizeof(int),
XPUMemcpyKind::XPU_HOST_TO_DEVICE); XPUMemcpyKind::XPU_HOST_TO_DEVICE));
r = xdnn::match_matrix_tensor(ctx.GetRawContext(), r = xdnn::match_matrix_tensor(ctx.GetRawContext(),
batch_size, batch_size,
......
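
The relabeled arguments make the dimension bookkeeping of this GEMM explicit; for the record, the mapping in this kernel is:

// data_a = bottom_l_data       : [x->dims()[0], dim_in]         (stacked left sequences)
// data_b = w_data              : [dim_in, dim_t * dim_in]       (trans_b = false)
// data_c = bottom_l_trans_data : [x->dims()[0], dim_t * dim_in]
// so every row of the left input is projected for all dim_t interaction
// channels in one GEMM before xdnn::match_matrix_tensor consumes it.
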
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
// limitations under the License. // limitations under the License.
#pragma once #pragma once
#include "lite/core/kernel.h" #include "lite/core/kernel.h"
namespace paddle { namespace paddle {
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
// limitations under the License. // limitations under the License.
#pragma once #pragma once
#include "lite/core/kernel.h" #include "lite/core/kernel.h"
namespace paddle { namespace paddle {
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
// limitations under the License. // limitations under the License.
#pragma once #pragma once
#include "lite/core/kernel.h" #include "lite/core/kernel.h"
namespace paddle { namespace paddle {
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
// limitations under the License. // limitations under the License.
#pragma once #pragma once
#include "lite/core/kernel.h" #include "lite/core/kernel.h"
namespace paddle { namespace paddle {
......
...@@ -23,7 +23,8 @@ namespace kernels { ...@@ -23,7 +23,8 @@ namespace kernels {
namespace xpu { namespace xpu {
void SearchFcCompute::PrepareForRun() { void SearchFcCompute::PrepareForRun() {
maxs_xpu_guard_ = TargetWrapperXPU::MallocScratchPad(64 * sizeof(float)); maxs_xpu_guard_ = TargetWrapperXPU::MallocScratchPad(
XPU_MAX_LOD_SIZE * sizeof(float), false /* use_l3 */);
} }
void SearchFcCompute::Run() { void SearchFcCompute::Run() {
...@@ -59,34 +60,34 @@ void SearchFcCompute::Run() { ...@@ -59,34 +60,34 @@ void SearchFcCompute::Run() {
float* maxs_xpu = reinterpret_cast<float*>(maxs_xpu_guard_->addr_); float* maxs_xpu = reinterpret_cast<float*>(maxs_xpu_guard_->addr_);
float maxs_cpu[8] = {0.0f, 0.0f, 0.0f, 0.0f, w_max, 0.0f, 0.0f, 0.0f}; float maxs_cpu[8] = {0.0f, 0.0f, 0.0f, 0.0f, w_max, 0.0f, 0.0f, 0.0f};
xpu_memcpy(maxs_xpu, XPU_CALL(xpu_memcpy(maxs_xpu,
&maxs_cpu[0], &maxs_cpu[0],
8 * sizeof(float), 8 * sizeof(float),
XPUMemcpyKind::XPU_HOST_TO_DEVICE); XPUMemcpyKind::XPU_HOST_TO_DEVICE));
int r = xdnn::findmax<float>( int r = xdnn::findmax<float>(
ctx.GetRawContext(), bottom_data, batch * _in, maxs_xpu); ctx.GetRawContext(), bottom_data, batch * _in, maxs_xpu);
CHECK_EQ(r, 0); CHECK_EQ(r, 0);
r = xdnn::gemm_int16_maxptr<float, int16_t, float>( r = xdnn::gemm_int16_maxptr<float, int16_t, float>(
ctx.GetRawContext(), /* ctx */ ctx.GetRawContext(), /* ctx */
false, false, /* trans_a */
true, /*trans_a, trans_b*/ true, /* trans_b */
batch, batch, /* m */
_out, _out, /* n */
_in, /*m, n, k*/ _in, /* k */
1.0f, 1.0f, /* alpha */
bottom_data, bottom_data, /* data_a */
_in, /*alpha, data_a, lda*/ _in, /* lda */
weights, weights, /* data_b */
_in, _in, /* ldb */
0.0f, /*data_b, ldb, beta*/ 0.0f, /* beta */
top_data, top_data, /* data_c */
_out, _out, /* ldc */
bias_data, /* data_c, ldc, bias*/ bias_data, /* bias */
act, act, /* act */
maxs_xpu, maxs_xpu, /* max_a */
maxs_xpu + 4, maxs_xpu + 4, /* max_b */
nullptr /*act, max_a, max_b, max_c*/); nullptr /* max_c */);
CHECK_EQ(r, 0); CHECK_EQ(r, 0);
} }
......
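
The maxs buffer convention above is shared between the host-side initialization and the GEMM call, and is easy to misread (one 4-float slot per tensor max is an assumption consistent with this code):

// maxs_xpu[0..3] : activation max, written on device by xdnn::findmax
// maxs_xpu[4..7] : weight max, preset on the host from w_max
// hence max_a = maxs_xpu and max_b = maxs_xpu + 4 in gemm_int16_maxptr.
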
...@@ -24,13 +24,16 @@ namespace kernels { ...@@ -24,13 +24,16 @@ namespace kernels {
namespace xpu { namespace xpu {
void SearchGrnnCompute::PrepareForRun() { void SearchGrnnCompute::PrepareForRun() {
offset_xpu_guard_ = TargetWrapperXPU::MallocScratchPad(64 * sizeof(int)); offset_xpu_guard_ = TargetWrapperXPU::MallocScratchPad(
new_offset_xpu_guard_ = TargetWrapperXPU::MallocScratchPad(256 * sizeof(int)); XPU_MAX_LOD_SIZE * sizeof(int), false /* use_l3 */);
maxs_xpu_guard_ = TargetWrapperXPU::MallocScratchPad(16 * sizeof(float)); new_offset_xpu_guard_ = TargetWrapperXPU::MallocScratchPad(
XPU_MAX_LOD_SEQ_LEN * sizeof(int), false /* use_l3 */);
maxs_xpu_guard_ = TargetWrapperXPU::MallocScratchPad(16 * sizeof(float),
false /* use_l3 */);
idx_sorted_by_width_data_cpu.reset(new int[64]); idx_sorted_by_width_data_cpu.reset(new int[XPU_MAX_LOD_SIZE]);
offset_cpu.reset(new int[64]); offset_cpu.reset(new int[XPU_MAX_LOD_SIZE]);
new_offset_cpu.reset(new int[256]); new_offset_cpu.reset(new int[XPU_MAX_LOD_SEQ_LEN]);
} }
void SearchGrnnCompute::prepare_layout(const operators::SearchGrnnParam& param, void SearchGrnnCompute::prepare_layout(const operators::SearchGrnnParam& param,
...@@ -96,10 +99,10 @@ void SearchGrnnCompute::prepare_layout(const operators::SearchGrnnParam& param, ...@@ -96,10 +99,10 @@ void SearchGrnnCompute::prepare_layout(const operators::SearchGrnnParam& param,
layout_input->Resize({dim0, dim1}); layout_input->Resize({dim0, dim1});
} }
xpu_memcpy(idx_sorted_by_width->mutable_data<int>(TARGET(kXPU)), XPU_CALL(xpu_memcpy(idx_sorted_by_width->mutable_data<int>(TARGET(kXPU)),
idx_sorted_by_width_data_cpu.get(), idx_sorted_by_width_data_cpu.get(),
idx_sorted_by_width->numel() * sizeof(int), idx_sorted_by_width->numel() * sizeof(int),
XPUMemcpyKind::XPU_HOST_TO_DEVICE); XPUMemcpyKind::XPU_HOST_TO_DEVICE));
} }
void SearchGrnnCompute::Run() { void SearchGrnnCompute::Run() {
...@@ -156,14 +159,14 @@ void SearchGrnnCompute::Run() { ...@@ -156,14 +159,14 @@ void SearchGrnnCompute::Run() {
for (size_t i = 0; i < new_offset.size(); ++i) { for (size_t i = 0; i < new_offset.size(); ++i) {
new_offset_cpu[i] = new_offset[i]; new_offset_cpu[i] = new_offset[i];
} }
xpu_memcpy(offset_xpu, XPU_CALL(xpu_memcpy(offset_xpu,
offset_cpu.get(), offset_cpu.get(),
offset.size() * sizeof(int), offset.size() * sizeof(int),
XPUMemcpyKind::XPU_HOST_TO_DEVICE); XPUMemcpyKind::XPU_HOST_TO_DEVICE));
xpu_memcpy(new_offset_xpu, XPU_CALL(xpu_memcpy(new_offset_xpu,
new_offset_cpu.get(), new_offset_cpu.get(),
new_offset.size() * sizeof(int), new_offset.size() * sizeof(int),
XPUMemcpyKind::XPU_HOST_TO_DEVICE); XPUMemcpyKind::XPU_HOST_TO_DEVICE));
int r = xdnn::search_seq2batch(ctx.GetRawContext(), int r = xdnn::search_seq2batch(ctx.GetRawContext(),
batch, batch,
...@@ -200,10 +203,10 @@ void SearchGrnnCompute::Run() { ...@@ -200,10 +203,10 @@ void SearchGrnnCompute::Run() {
0.0f, 0.0f,
0.0f, 0.0f,
0.0f}; 0.0f};
xpu_memcpy(maxs_xpu, XPU_CALL(xpu_memcpy(maxs_xpu,
maxs_cpu, maxs_cpu,
16 * sizeof(float), 16 * sizeof(float),
XPUMemcpyKind::XPU_HOST_TO_DEVICE); XPUMemcpyKind::XPU_HOST_TO_DEVICE));
r = xdnn::findmax<float>( r = xdnn::findmax<float>(
ctx.GetRawContext(), new_emb, cap_l * cap_e, maxs_xpu); ctx.GetRawContext(), new_emb, cap_l * cap_e, maxs_xpu);
CHECK_EQ(r, 0); CHECK_EQ(r, 0);
......
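
The two constants replace the former magic sizes 64 and 256. An offset buffer stores one int per sequence boundary, so XPU_MAX_LOD_SIZE bounds batch + 1; new_offset indexes padded time steps, so it is bounded by the maximum supported sequence length, XPU_MAX_LOD_SEQ_LEN (both constants come from the XPU backend headers; their values are not visible in this diff). An illustrative sanity check matching those allocations:

CHECK_LE(offset.size(), static_cast<size_t>(XPU_MAX_LOD_SIZE));
CHECK_LE(new_offset.size(), static_cast<size_t>(XPU_MAX_LOD_SEQ_LEN));
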
...@@ -37,44 +37,54 @@ void SequenceArithmeticCompute::Run() { ...@@ -37,44 +37,54 @@ void SequenceArithmeticCompute::Run() {
const auto* bottom_data1 = bottom1->data<float>(); const auto* bottom_data1 = bottom1->data<float>();
auto* top_data = top->mutable_data<float>(TARGET(kXPU)); auto* top_data = top->mutable_data<float>(TARGET(kXPU));
int r = 0;
switch (op_type) { switch (op_type) {
case 1: // addition: top[0] = bottom[0] + bottom[1] case 1: // addition: top[0] = bottom[0] + bottom[1]
if (len1 > len2) { if (len1 > len2) {
xdnn::elementwise_add( r = xdnn::elementwise_add(
ctx.GetRawContext(), bottom_data0, bottom_data1, top_data, len2); ctx.GetRawContext(), bottom_data0, bottom_data1, top_data, len2);
xdnn::memcpy_device(ctx.GetRawContext(), CHECK_EQ(r, 0);
&top_data[len2], r = xdnn::memcpy_device(ctx.GetRawContext(),
&bottom_data0[len2], &top_data[len2],
(len1 - len2) * sizeof(float)); &bottom_data0[len2],
(len1 - len2) * sizeof(float));
CHECK_EQ(r, 0);
} else { } else {
xdnn::elementwise_add( r = xdnn::elementwise_add(
ctx.GetRawContext(), bottom_data0, bottom_data1, top_data, len1); ctx.GetRawContext(), bottom_data0, bottom_data1, top_data, len1);
CHECK_EQ(r, 0);
} }
break; break;
case 2: // subtraction: top[0] = bottom[0] - bottom[1] case 2: // subtraction: top[0] = bottom[0] - bottom[1]
if (len1 > len2) { if (len1 > len2) {
xdnn::elementwise_sub( r = xdnn::elementwise_sub(
ctx.GetRawContext(), bottom_data0, bottom_data1, top_data, len2); ctx.GetRawContext(), bottom_data0, bottom_data1, top_data, len2);
xdnn::memcpy_device(ctx.GetRawContext(), CHECK_EQ(r, 0);
&top_data[len2], r = xdnn::memcpy_device(ctx.GetRawContext(),
&bottom_data0[len2], &top_data[len2],
(len1 - len2) * sizeof(float)); &bottom_data0[len2],
(len1 - len2) * sizeof(float));
CHECK_EQ(r, 0);
} else { } else {
xdnn::elementwise_sub( r = xdnn::elementwise_sub(
ctx.GetRawContext(), bottom_data0, bottom_data1, top_data, len1); ctx.GetRawContext(), bottom_data0, bottom_data1, top_data, len1);
CHECK_EQ(r, 0);
} }
break; break;
case 3: // multiplication: top[0] = bottom[0] * bottom[1] case 3: // multiplication: top[0] = bottom[0] * bottom[1]
if (len1 > len2) { if (len1 > len2) {
xdnn::elementwise_mul( r = xdnn::elementwise_mul(
ctx.GetRawContext(), bottom_data0, bottom_data1, top_data, len2); ctx.GetRawContext(), bottom_data0, bottom_data1, top_data, len2);
xdnn::memcpy_device(ctx.GetRawContext(), CHECK_EQ(r, 0);
&top_data[len2], r = xdnn::memcpy_device(ctx.GetRawContext(),
&bottom_data0[len2], &top_data[len2],
(len1 - len2) * sizeof(float)); &bottom_data0[len2],
(len1 - len2) * sizeof(float));
CHECK_EQ(r, 0);
} else { } else {
xdnn::elementwise_mul( r = xdnn::elementwise_mul(
ctx.GetRawContext(), bottom_data0, bottom_data1, top_data, len1); ctx.GetRawContext(), bottom_data0, bottom_data1, top_data, len1);
CHECK_EQ(r, 0);
} }
break; break;
default: default:
......
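
All three arithmetic cases share one length rule: the elementwise op runs over the common prefix of the two inputs and, when the first input is longer, its tail is copied through unchanged. A CPU reference for the add case (a sketch of the semantics, not the device path):

#include <algorithm>
#include <cstring>

void SeqArithAddRef(const float* bottom0, int len1,
                    const float* bottom1, int len2, float* top) {
  int common = std::min(len1, len2);
  for (int i = 0; i < common; ++i) {
    top[i] = bottom0[i] + bottom1[i];  // elementwise on the overlap
  }
  if (len1 > len2) {
    std::memcpy(top + common, bottom0 + common,
                (len1 - common) * sizeof(float));  // pass-through tail
  }
}
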
...@@ -23,11 +23,13 @@ namespace kernels { ...@@ -23,11 +23,13 @@ namespace kernels {
namespace xpu { namespace xpu {
void SequenceConcatCompute::PrepareForRun() { void SequenceConcatCompute::PrepareForRun() {
lod0_xpu_guard_ = TargetWrapperXPU::MallocScratchPad(64 * sizeof(int)); lod0_xpu_guard_ = TargetWrapperXPU::MallocScratchPad(
lod1_xpu_guard_ = TargetWrapperXPU::MallocScratchPad(64 * sizeof(int)); XPU_MAX_LOD_SIZE * sizeof(int), false /* use_l3 */);
lod1_xpu_guard_ = TargetWrapperXPU::MallocScratchPad(
XPU_MAX_LOD_SIZE * sizeof(int), false /* use_l3 */);
lod0_cpu.reset(new int[64]); lod0_cpu.reset(new int[XPU_MAX_LOD_SIZE]);
lod1_cpu.reset(new int[64]); lod1_cpu.reset(new int[XPU_MAX_LOD_SIZE]);
} }
template <typename T> template <typename T>
...@@ -106,14 +108,14 @@ void SequenceConcatCompute::Run() { ...@@ -106,14 +108,14 @@ void SequenceConcatCompute::Run() {
for (int i = 0; i < lod1.size(); ++i) { for (int i = 0; i < lod1.size(); ++i) {
lod1_cpu[i] = lod1[i]; lod1_cpu[i] = lod1[i];
} }
xpu_memcpy(lod0_xpu, XPU_CALL(xpu_memcpy(lod0_xpu,
lod0_cpu.get(), lod0_cpu.get(),
lod0.size() * sizeof(int), lod0.size() * sizeof(int),
XPUMemcpyKind::XPU_HOST_TO_DEVICE); XPUMemcpyKind::XPU_HOST_TO_DEVICE));
xpu_memcpy(lod1_xpu, XPU_CALL(xpu_memcpy(lod1_xpu,
lod1_cpu.get(), lod1_cpu.get(),
lod1.size() * sizeof(int), lod1.size() * sizeof(int),
XPUMemcpyKind::XPU_HOST_TO_DEVICE); XPUMemcpyKind::XPU_HOST_TO_DEVICE));
int r = xdnn::sequence_concat(ctx.GetRawContext(), int r = xdnn::sequence_concat(ctx.GetRawContext(),
xs[0]->data<float>(), xs[0]->data<float>(),
......
...@@ -23,8 +23,9 @@ namespace kernels { ...@@ -23,8 +23,9 @@ namespace kernels {
namespace xpu { namespace xpu {
void XPUSequencePoolCompute::PrepareForRun() { void XPUSequencePoolCompute::PrepareForRun() {
lod_xpu_guard_ = TargetWrapperXPU::MallocScratchPad(64 * sizeof(int)); lod_xpu_guard_ = TargetWrapperXPU::MallocScratchPad(
lod_cpu.reset(new int[64]); XPU_MAX_LOD_SIZE * sizeof(int), false /* use_l3 */);
lod_cpu.reset(new int[XPU_MAX_LOD_SIZE]);
} }
void XPUSequencePoolCompute::Run() { void XPUSequencePoolCompute::Run() {
...@@ -55,10 +56,10 @@ void XPUSequencePoolCompute::Run() { ...@@ -55,10 +56,10 @@ void XPUSequencePoolCompute::Run() {
lod_cpu[i] = in_lod[i]; lod_cpu[i] = in_lod[i];
} }
int* lod_xpu = reinterpret_cast<int*>(lod_xpu_guard_->addr_); int* lod_xpu = reinterpret_cast<int*>(lod_xpu_guard_->addr_);
xpu_memcpy(lod_xpu, XPU_CALL(xpu_memcpy(lod_xpu,
lod_cpu.get(), lod_cpu.get(),
in_lod.size() * sizeof(int), in_lod.size() * sizeof(int),
XPUMemcpyKind::XPU_HOST_TO_DEVICE); XPUMemcpyKind::XPU_HOST_TO_DEVICE));
int r = int r =
xdnn::sequence_pooling_forward(ctx.GetRawContext(), xdnn::sequence_pooling_forward(ctx.GetRawContext(),
......
...@@ -23,8 +23,9 @@ namespace xpu { ...@@ -23,8 +23,9 @@ namespace xpu {
template <typename T, PrecisionType PType> template <typename T, PrecisionType PType>
void SequenceReverseCompute<T, PType>::PrepareForRun() { void SequenceReverseCompute<T, PType>::PrepareForRun() {
lod_xpu_guard_ = TargetWrapperXPU::MallocScratchPad(64 * sizeof(int)); lod_xpu_guard_ = TargetWrapperXPU::MallocScratchPad(
lod_cpu.reset(new int[64]); XPU_MAX_LOD_SIZE * sizeof(int), false /* use_l3 */);
lod_cpu.reset(new int[XPU_MAX_LOD_SIZE]);
} }
template <typename T, PrecisionType PType> template <typename T, PrecisionType PType>
...@@ -58,10 +59,10 @@ void SequenceReverseCompute<T, PType>::Run() { ...@@ -58,10 +59,10 @@ void SequenceReverseCompute<T, PType>::Run() {
lod_cpu[i] = lod[i]; lod_cpu[i] = lod[i];
} }
int* lod_xpu = reinterpret_cast<int*>(lod_xpu_guard_->addr_); int* lod_xpu = reinterpret_cast<int*>(lod_xpu_guard_->addr_);
xpu_memcpy(lod_xpu, XPU_CALL(xpu_memcpy(lod_xpu,
lod_cpu.get(), lod_cpu.get(),
lod.size() * sizeof(int), lod.size() * sizeof(int),
XPUMemcpyKind::XPU_HOST_TO_DEVICE); XPUMemcpyKind::XPU_HOST_TO_DEVICE));
int r = xdnn::sequence_reverse(ctx.GetRawContext(), int r = xdnn::sequence_reverse(ctx.GetRawContext(),
batch_size, batch_size,
......
...@@ -23,10 +23,11 @@ namespace kernels { ...@@ -23,10 +23,11 @@ namespace kernels {
namespace xpu { namespace xpu {
void SequenceTopkAvgPoolingCompute::PrepareForRun() { void SequenceTopkAvgPoolingCompute::PrepareForRun() {
lod_xpu_guard_ = TargetWrapperXPU::MallocScratchPad(256 * sizeof(int)); lod_xpu_guard_ = TargetWrapperXPU::MallocScratchPad(
in_lod_cpu.reset(new int[64]); 4 * XPU_MAX_LOD_SIZE * sizeof(int), false /* use_l3 */);
row_lod_cpu.reset(new int[64]); in_lod_cpu.reset(new int[XPU_MAX_LOD_SIZE]);
col_lod_cpu.reset(new int[64]); row_lod_cpu.reset(new int[XPU_MAX_LOD_SIZE]);
col_lod_cpu.reset(new int[XPU_MAX_LOD_SIZE]);
} }
void SequenceTopkAvgPoolingCompute::Run() { void SequenceTopkAvgPoolingCompute::Run() {
...@@ -81,22 +82,22 @@ void SequenceTopkAvgPoolingCompute::Run() { ...@@ -81,22 +82,22 @@ void SequenceTopkAvgPoolingCompute::Run() {
for (int i = 0; i < col_lod.size(); ++i) { for (int i = 0; i < col_lod.size(); ++i) {
col_lod_cpu[i] = col_lod[i]; col_lod_cpu[i] = col_lod[i];
} }
xpu_memcpy(in_lod_xpu, XPU_CALL(xpu_memcpy(in_lod_xpu,
in_lod_cpu.get(), in_lod_cpu.get(),
in_lod.size() * sizeof(int), in_lod.size() * sizeof(int),
XPUMemcpyKind::XPU_HOST_TO_DEVICE); XPUMemcpyKind::XPU_HOST_TO_DEVICE));
xpu_memcpy(row_lod_xpu, XPU_CALL(xpu_memcpy(row_lod_xpu,
row_lod_cpu.get(), row_lod_cpu.get(),
row_lod.size() * sizeof(int), row_lod.size() * sizeof(int),
XPUMemcpyKind::XPU_HOST_TO_DEVICE); XPUMemcpyKind::XPU_HOST_TO_DEVICE));
xpu_memcpy(col_lod_xpu, XPU_CALL(xpu_memcpy(col_lod_xpu,
col_lod_cpu.get(), col_lod_cpu.get(),
col_lod.size() * sizeof(int), col_lod.size() * sizeof(int),
XPUMemcpyKind::XPU_HOST_TO_DEVICE); XPUMemcpyKind::XPU_HOST_TO_DEVICE));
xpu_memcpy(topks_xpu, XPU_CALL(xpu_memcpy(topks_xpu,
topks.data(), topks.data(),
topks.size() * sizeof(int), topks.size() * sizeof(int),
XPUMemcpyKind::XPU_HOST_TO_DEVICE); XPUMemcpyKind::XPU_HOST_TO_DEVICE));
int r = xdnn::sequence_topk_avg_pooling(ctx.GetRawContext(), int r = xdnn::sequence_topk_avg_pooling(ctx.GetRawContext(),
in_data, in_data,
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
// limitations under the License. // limitations under the License.
#pragma once #pragma once
#include "lite/core/kernel.h" #include "lite/core/kernel.h"
namespace paddle { namespace paddle {
......
...@@ -25,9 +25,8 @@ void StackCompute::PrepareForRun() { ...@@ -25,9 +25,8 @@ void StackCompute::PrepareForRun() {
auto& param = this->Param<param_t>(); auto& param = this->Param<param_t>();
int n = param.X.size(); int n = param.X.size();
void* x_ptr = nullptr; x_ptr_guard_ = TargetWrapperXPU::MallocScratchPad(
xpu_malloc(&x_ptr, n * 8 /* sizeof(__global__ float*) */); n * 8 /* sizeof(__global__ float*) */, false /* use_l3 */);
x_ptr_guard_.reset(x_ptr);
x_ptr_cpu_.reserve(n); x_ptr_cpu_.reserve(n);
} }
...@@ -47,14 +46,15 @@ void StackCompute::Run() { ...@@ -47,14 +46,15 @@ void StackCompute::Run() {
for (int i = 0; i < n; ++i) { for (int i = 0; i < n; ++i) {
x_ptr_cpu_[i] = param.X[i]->data<float>(); x_ptr_cpu_[i] = param.X[i]->data<float>();
} }
xpu_memcpy(x_ptr_guard_.get(), &x_ptr_cpu_[0], n * 8, XPU_HOST_TO_DEVICE); XPU_CALL(xpu_memcpy(
x_ptr_guard_->addr_, &x_ptr_cpu_[0], n * 8, XPU_HOST_TO_DEVICE));
int r = xdnn::stack_forward( int r = xdnn::stack_forward(
ctx.GetRawContext(), /* context */ ctx.GetRawContext(), /* context */
height, /* height */ height, /* height */
width, /* width */ width, /* width */
n, /* n */ n, /* n */
x_ptr_guard_.get(), /* x_ptr */ x_ptr_guard_->addr_, /* x_ptr */
param.Out->mutable_data<float>(TARGET(kXPU)) /* out */); param.Out->mutable_data<float>(TARGET(kXPU)) /* out */);
CHECK_EQ(r, 0); CHECK_EQ(r, 0);
} }
......
...@@ -14,10 +14,9 @@ ...@@ -14,10 +14,9 @@
#pragma once #pragma once
#include <memory>
#include <vector> #include <vector>
#include "lite/backends/xpu/target_wrapper.h" // XPUScratchPadGuard
#include "lite/core/kernel.h" #include "lite/core/kernel.h"
#include "lite/kernels/xpu/utils.h" // XPUFreeDeleter
namespace paddle { namespace paddle {
namespace lite { namespace lite {
...@@ -35,7 +34,7 @@ class StackCompute : public KernelLite<TARGET(kXPU), PRECISION(kFloat)> { ...@@ -35,7 +34,7 @@ class StackCompute : public KernelLite<TARGET(kXPU), PRECISION(kFloat)> {
virtual ~StackCompute() = default; virtual ~StackCompute() = default;
private: private:
std::unique_ptr<void, XPUFreeDeleter> x_ptr_guard_; XPUScratchPadGuard x_ptr_guard_;
std::vector<const float*> x_ptr_cpu_; std::vector<const float*> x_ptr_cpu_;
}; };
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "lite/backends/xpu/xpu_header_sitter.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace xpu {
struct XPUFreeDeleter {
void operator()(void* p) const { xpu_free(p); }
};
} // namespace xpu
} // namespace kernels
} // namespace lite
} // namespace paddle
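
With StackCompute switched to XPUScratchPadGuard, the hand-rolled XPUFreeDeleter above loses its last user, so this header is deleted. For context, a sketch of the guard that replaces it, inferred from how target_wrapper.h is used in this patch (the real definition may differ, e.g. in how L3 allocations are returned):

#include <memory>
#include "lite/backends/xpu/xpu_header_sitter.h"  // xpu_free

struct XPUScratchPad {
  XPUScratchPad(void* addr, bool is_l3) : addr_(addr), is_l3_(is_l3) {}
  void* addr_{nullptr};
  bool is_l3_{false};
};

struct XPUScratchPadDeleter {
  void operator()(XPUScratchPad* sp) const {
    if (!sp->is_l3_) {
      xpu_free(sp->addr_);  // L3 memory is runtime-managed, not freed here
    }
    delete sp;
  }
};

using XPUScratchPadGuard =
    std::unique_ptr<XPUScratchPad, XPUScratchPadDeleter>;
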
...@@ -23,10 +23,12 @@ namespace kernels { ...@@ -23,10 +23,12 @@ namespace kernels {
namespace xpu { namespace xpu {
void VarConv2DCompute::PrepareForRun() { void VarConv2DCompute::PrepareForRun() {
offset_x_xpu_guard_ = TargetWrapperXPU::MallocScratchPad(64 * sizeof(int)); offset_x_xpu_guard_ = TargetWrapperXPU::MallocScratchPad(
offset_y_xpu_guard_ = TargetWrapperXPU::MallocScratchPad(64 * sizeof(int)); XPU_MAX_LOD_SIZE * sizeof(int), false /* use_l3 */);
offset_x_cpu.reset(new int[64]); offset_y_xpu_guard_ = TargetWrapperXPU::MallocScratchPad(
offset_y_cpu.reset(new int[64]); XPU_MAX_LOD_SIZE * sizeof(int), false /* use_l3 */);
offset_x_cpu.reset(new int[XPU_MAX_LOD_SIZE]);
offset_y_cpu.reset(new int[XPU_MAX_LOD_SIZE]);
} }
void VarConv2DCompute::Run() { void VarConv2DCompute::Run() {
...@@ -94,14 +96,14 @@ void VarConv2DCompute::Run() { ...@@ -94,14 +96,14 @@ void VarConv2DCompute::Run() {
offset_x_cpu[i] = offset_x[i]; offset_x_cpu[i] = offset_x[i];
offset_y_cpu[i] = offset_y[i]; offset_y_cpu[i] = offset_y[i];
} }
xpu_memcpy(offset_x_xpu, XPU_CALL(xpu_memcpy(offset_x_xpu,
offset_x_cpu.get(), offset_x_cpu.get(),
(batch + 1) * sizeof(int), (batch + 1) * sizeof(int),
XPUMemcpyKind::XPU_HOST_TO_DEVICE); XPUMemcpyKind::XPU_HOST_TO_DEVICE));
xpu_memcpy(offset_y_xpu, XPU_CALL(xpu_memcpy(offset_y_xpu,
offset_y_cpu.get(), offset_y_cpu.get(),
(batch + 1) * sizeof(int), (batch + 1) * sizeof(int),
XPUMemcpyKind::XPU_HOST_TO_DEVICE); XPUMemcpyKind::XPU_HOST_TO_DEVICE));
int r = xdnn::search_varconv<float, int16_t>(ctx.GetRawContext(), int r = xdnn::search_varconv<float, int16_t>(ctx.GetRawContext(),
batch, batch,
......
...@@ -88,6 +88,78 @@ bool XPUMmdnnBidEmbGrnnAttOp::AttachImpl(const cpp::OpDesc& op_desc, ...@@ -88,6 +88,78 @@ bool XPUMmdnnBidEmbGrnnAttOp::AttachImpl(const cpp::OpDesc& op_desc,
return true; return true;
} }
bool XPUMmdnnBidEmbGrnnAttOp2::CheckShape() const { return true; }
bool XPUMmdnnBidEmbGrnnAttOp2::InferShapeImpl() const {
auto& id_dims = param_.id0->dims();
auto& id_lod = param_.id0->lod()[0];
auto& emb_tbl_dims = param_.emb_tbl->dims();
auto& grnn_wh_dims = param_.grnn_rv_wh->dims();
param_.emb0_out->Resize({id_dims[0], emb_tbl_dims[1]});
param_.emb0_out->set_lod({id_lod});
param_.grnn_fw_pool_out->Resize(
{(int64_t)id_lod.size() - 1, grnn_wh_dims[2]});
param_.grnn_rv_pool_out->Resize(
{(int64_t)id_lod.size() - 1, grnn_wh_dims[2]});
param_.att_pool_out->Resize(
{(int64_t)id_lod.size() - 1, 2 * grnn_wh_dims[2]});
param_.concat_3in1_out->Resize({id_dims[0], 3 * grnn_wh_dims[2]});
param_.concat_3in1_out->set_lod({id_lod});
param_.emb_fw_out->Resize({id_dims[0], emb_tbl_dims[1]});
param_.emb_fw_out->set_lod({id_lod});
return true;
}
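
A worked example of the shapes InferShapeImpl produces (numbers illustrative; H = grnn_rv_wh->dims()[2]):

// id0 : {12, 1} with lod {0, 5, 12}  => batch = 2, total tokens = 12
// emb_tbl : {V, 128}                 => emb0_out, emb_fw_out : {12, 128}
// grnn_fw_pool_out, grnn_rv_pool_out : {2, H}
// att_pool_out                       : {2, 2 * H}
// concat_3in1_out                    : {12, 3 * H}  (lod {0, 5, 12} kept)
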
bool XPUMmdnnBidEmbGrnnAttOp2::AttachImpl(const cpp::OpDesc& op_desc,
lite::Scope* scope) {
param_.id0 =
scope->FindVar(op_desc.Input("id0").front())->GetMutable<lite::Tensor>();
param_.id1 =
scope->FindVar(op_desc.Input("id1").front())->GetMutable<lite::Tensor>();
param_.emb_tbl = scope->FindVar(op_desc.Input("emb_tbl").front())
->GetMutable<lite::Tensor>();
param_.grnn_fw_wh = scope->FindVar(op_desc.Input("grnn_fw_wh").front())
->GetMutable<lite::Tensor>();
param_.grnn_fw_wi = scope->FindVar(op_desc.Input("grnn_fw_wi").front())
->GetMutable<lite::Tensor>();
param_.grnn_rv_wh = scope->FindVar(op_desc.Input("grnn_rv_wh").front())
->GetMutable<lite::Tensor>();
param_.grnn_rv_wi = scope->FindVar(op_desc.Input("grnn_rv_wi").front())
->GetMutable<lite::Tensor>();
param_.att_fc_w = scope->FindVar(op_desc.Input("att_fc_w").front())
->GetMutable<lite::Tensor>();
param_.att_fc_b = scope->FindVar(op_desc.Input("att_fc_b").front())
->GetMutable<lite::Tensor>();
param_.emb0_out = scope->FindVar(op_desc.Output("emb0_out").front())
->GetMutable<lite::Tensor>();
param_.grnn_fw_pool_out =
scope->FindVar(op_desc.Output("grnn_fw_pool_out").front())
->GetMutable<lite::Tensor>();
param_.grnn_rv_pool_out =
scope->FindVar(op_desc.Output("grnn_rv_pool_out").front())
->GetMutable<lite::Tensor>();
param_.att_pool_out = scope->FindVar(op_desc.Output("att_pool_out").front())
->GetMutable<lite::Tensor>();
param_.concat_3in1_out =
scope->FindVar(op_desc.Output("concat_3in1_out").front())
->GetMutable<lite::Tensor>();
param_.emb_fw_out = scope->FindVar(op_desc.Output("emb_fw_out").front())
->GetMutable<lite::Tensor>();
param_.grnn_fw_wh_maxs =
op_desc.GetAttr<std::vector<float>>("grnn_fw_wh_maxs");
param_.grnn_fw_wi_maxs =
op_desc.GetAttr<std::vector<float>>("grnn_fw_wi_maxs");
param_.grnn_rv_wh_maxs =
op_desc.GetAttr<std::vector<float>>("grnn_rv_wh_maxs");
param_.grnn_rv_wi_maxs =
op_desc.GetAttr<std::vector<float>>("grnn_rv_wi_maxs");
param_.att_fc_w_max = op_desc.GetAttr<float>("att_fc_w_max");
return true;
}
bool XPUMmdnnBidEmbAttOp::CheckShape() const { return true; } bool XPUMmdnnBidEmbAttOp::CheckShape() const { return true; }
bool XPUMmdnnBidEmbAttOp::InferShapeImpl() const { bool XPUMmdnnBidEmbAttOp::InferShapeImpl() const {
...@@ -157,6 +229,7 @@ bool XPUMmdnnMatchConvTopkOp::AttachImpl(const cpp::OpDesc& op_desc, ...@@ -157,6 +229,7 @@ bool XPUMmdnnMatchConvTopkOp::AttachImpl(const cpp::OpDesc& op_desc,
param_.input_w_max = op_desc.GetAttr<float>("input_w_max"); param_.input_w_max = op_desc.GetAttr<float>("input_w_max");
param_.conv_w_max = op_desc.GetAttr<float>("conv_w_max"); param_.conv_w_max = op_desc.GetAttr<float>("conv_w_max");
param_.topks = op_desc.GetAttr<std::vector<int>>("topks"); param_.topks = op_desc.GetAttr<std::vector<int>>("topks");
param_.output_channel = op_desc.GetAttr<int>("output_channel");
param_.channel_num = op_desc.GetAttr<int>("channel_num"); param_.channel_num = op_desc.GetAttr<int>("channel_num");
param_.dim_t = op_desc.GetAttr<int>("dim_t"); param_.dim_t = op_desc.GetAttr<int>("dim_t");
return true; return true;
...@@ -182,10 +255,10 @@ bool XPUMmdnnMergeAllOp::AttachImpl(const cpp::OpDesc& op_desc, ...@@ -182,10 +255,10 @@ bool XPUMmdnnMergeAllOp::AttachImpl(const cpp::OpDesc& op_desc,
auto t = scope->FindVar(name)->GetMutable<lite::Tensor>(); auto t = scope->FindVar(name)->GetMutable<lite::Tensor>();
param_.concat_7in1_x.push_back(t); param_.concat_7in1_x.push_back(t);
} }
param_.concat_2in1_x.clear(); param_.concat_topk_x.clear();
for (auto& name : op_desc.Input("concat_2in1_x")) { for (auto& name : op_desc.Input("concat_topk_x")) {
auto t = scope->FindVar(name)->GetMutable<lite::Tensor>(); auto t = scope->FindVar(name)->GetMutable<lite::Tensor>();
param_.concat_2in1_x.push_back(t); param_.concat_topk_x.push_back(t);
} }
param_.grnn_fw_wh = scope->FindVar(op_desc.Input("grnn_fw_wh").front()) param_.grnn_fw_wh = scope->FindVar(op_desc.Input("grnn_fw_wh").front())
->GetMutable<lite::Tensor>(); ->GetMutable<lite::Tensor>();
...@@ -231,6 +304,8 @@ bool XPUMmdnnMergeAllOp::AttachImpl(const cpp::OpDesc& op_desc, ...@@ -231,6 +304,8 @@ bool XPUMmdnnMergeAllOp::AttachImpl(const cpp::OpDesc& op_desc,
REGISTER_LITE_OP(__xpu__mmdnn_bid_emb_grnn_att, REGISTER_LITE_OP(__xpu__mmdnn_bid_emb_grnn_att,
paddle::lite::operators::XPUMmdnnBidEmbGrnnAttOp); paddle::lite::operators::XPUMmdnnBidEmbGrnnAttOp);
REGISTER_LITE_OP(__xpu__mmdnn_bid_emb_grnn_att2,
paddle::lite::operators::XPUMmdnnBidEmbGrnnAttOp2);
REGISTER_LITE_OP(__xpu__mmdnn_bid_emb_att, REGISTER_LITE_OP(__xpu__mmdnn_bid_emb_att,
paddle::lite::operators::XPUMmdnnBidEmbAttOp); paddle::lite::operators::XPUMmdnnBidEmbAttOp);
REGISTER_LITE_OP(__xpu__mmdnn_match_conv_topk, REGISTER_LITE_OP(__xpu__mmdnn_match_conv_topk,
......
...@@ -41,6 +41,29 @@ class XPUMmdnnBidEmbGrnnAttOp : public OpLite { ...@@ -41,6 +41,29 @@ class XPUMmdnnBidEmbGrnnAttOp : public OpLite {
mutable XPUMmdnnBidEmbGrnnAttParam param_; mutable XPUMmdnnBidEmbGrnnAttParam param_;
}; };
class XPUMmdnnBidEmbGrnnAttOp2 : public OpLite {
public:
XPUMmdnnBidEmbGrnnAttOp2() {}
explicit XPUMmdnnBidEmbGrnnAttOp2(const std::string &op_type)
: OpLite(op_type) {}
bool CheckShape() const override;
bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override {
return "XPUMmdnnBidEmbGrnnAttOp2";
}
private:
mutable XPUMmdnnBidEmbGrnnAttParam2 param_;
};
class XPUMmdnnBidEmbAttOp : public OpLite { class XPUMmdnnBidEmbAttOp : public OpLite {
public: public:
XPUMmdnnBidEmbAttOp() {} XPUMmdnnBidEmbAttOp() {}
......
...@@ -1627,11 +1627,36 @@ struct XPUMmdnnBidEmbGrnnAttParam : ParamBase { ...@@ -1627,11 +1627,36 @@ struct XPUMmdnnBidEmbGrnnAttParam : ParamBase {
std::vector<float> grnn_rv_wi_maxs; std::vector<float> grnn_rv_wi_maxs;
float att_fc_w_max{0.0f}; float att_fc_w_max{0.0f};
lite::Tensor* grnn_fw_pool_out{}; // 1 lite::Tensor* grnn_fw_pool_out{};
lite::Tensor* grnn_rv_pool_out{}; // 2 lite::Tensor* grnn_rv_pool_out{};
lite::Tensor* att_pool_out{}; // 3 lite::Tensor* att_pool_out{};
lite::Tensor* concat_3in1_out{}; // 4 lite::Tensor* concat_3in1_out{};
lite::Tensor* emb_fw_out{}; // 5 lite::Tensor* emb_fw_out{};
};
struct XPUMmdnnBidEmbGrnnAttParam2 : ParamBase {
lite::Tensor* id0{};
lite::Tensor* id1{};
lite::Tensor* emb_tbl{};
lite::Tensor* grnn_fw_wh{};
lite::Tensor* grnn_fw_wi{};
lite::Tensor* grnn_rv_wh{};
lite::Tensor* grnn_rv_wi{};
lite::Tensor* att_fc_w{};
lite::Tensor* att_fc_b{};
std::vector<float> grnn_fw_wh_maxs;
std::vector<float> grnn_fw_wi_maxs;
std::vector<float> grnn_rv_wh_maxs;
std::vector<float> grnn_rv_wi_maxs;
float att_fc_w_max{0.0f};
lite::Tensor* emb0_out{};
lite::Tensor* grnn_fw_pool_out{};
lite::Tensor* grnn_rv_pool_out{};
lite::Tensor* att_pool_out{};
lite::Tensor* concat_3in1_out{};
lite::Tensor* emb_fw_out{};
}; };
struct XPUMmdnnBidEmbAttParam : ParamBase { struct XPUMmdnnBidEmbAttParam : ParamBase {
...@@ -1643,8 +1668,8 @@ struct XPUMmdnnBidEmbAttParam : ParamBase { ...@@ -1643,8 +1668,8 @@ struct XPUMmdnnBidEmbAttParam : ParamBase {
float att_fc_w_max{0.0f}; float att_fc_w_max{0.0f};
lite::Tensor* att_pool_out{}; // 1 lite::Tensor* att_pool_out{};
lite::Tensor* emb_fw_out{}; // 2 lite::Tensor* emb_fw_out{};
}; };
struct XPUMmdnnMatchConvTopkParam : ParamBase { struct XPUMmdnnMatchConvTopkParam : ParamBase {
...@@ -1656,6 +1681,7 @@ struct XPUMmdnnMatchConvTopkParam : ParamBase { ...@@ -1656,6 +1681,7 @@ struct XPUMmdnnMatchConvTopkParam : ParamBase {
float input_w_max{0.0f}; float input_w_max{0.0f};
float conv_w_max{0.0f}; float conv_w_max{0.0f};
std::vector<int> topks; std::vector<int> topks;
int output_channel{0};
int channel_num{0}; int channel_num{0};
int dim_t{0}; int dim_t{0};
...@@ -1664,7 +1690,7 @@ struct XPUMmdnnMatchConvTopkParam : ParamBase { ...@@ -1664,7 +1690,7 @@ struct XPUMmdnnMatchConvTopkParam : ParamBase {
struct XPUMmdnnMergeAllParam : ParamBase { struct XPUMmdnnMergeAllParam : ParamBase {
std::vector<lite::Tensor*> concat_7in1_x; std::vector<lite::Tensor*> concat_7in1_x;
std::vector<lite::Tensor*> concat_2in1_x; std::vector<lite::Tensor*> concat_topk_x;
lite::Tensor* grnn_fw_wh{}; lite::Tensor* grnn_fw_wh{};
lite::Tensor* grnn_fw_wi{}; lite::Tensor* grnn_fw_wi{};
lite::Tensor* grnn_rv_wh{}; lite::Tensor* grnn_rv_wh{};
......
...@@ -26,156 +26,171 @@ ...@@ -26,156 +26,171 @@
DEFINE_bool(perf, false, "perf?"); DEFINE_bool(perf, false, "perf?");
DEFINE_string(perf_input, "perf_input", "perf_input"); DEFINE_string(perf_input, "perf_input", "perf_input");
DEFINE_int32(perf_batch_size, 40, "perf_batch_size");
DEFINE_bool(use_xpu, true, "use_xpu?");
DEFINE_int32(perf_dev, 0, "perf_dev");
namespace paddle { namespace paddle {
namespace lite { namespace lite {
std::vector<int64_t> input0; class SampleReader {
std::vector<uint64_t> input0_lod = {0}; public:
std::vector<int64_t> input1; std::vector<std::vector<int64_t>> data;
std::vector<uint64_t> input1_lod = {0}; std::vector<std::vector<uint64_t>> lod;
std::vector<int64_t> input2;
std::vector<uint64_t> input2_lod = {0};
std::vector<int64_t> input3;
std::vector<uint64_t> input3_lod = {0};
std::vector<int64_t> input4;
std::vector<uint64_t> input4_lod = {0};
std::vector<int64_t> input5;
std::vector<uint64_t> input5_lod = {0};
void ParseInput() { void Read() {
std::string raw_input = std::string raw_input =
"0 1;145 10251 839 3719 428 52;1050488 1050488 911898 3719 760166 " "0 1;125 584 142 2114 197;125 756226 756913 855693 760836;125 584 142 "
"760166;3719 428 52 18 1102 10327 252 20 153 2897 1146 70 156 6 145 " "2114 197 10 2899;125 756226 756913 855693 760836 10 750793;125 584 "
"10251 839 5 1779 1729 1779 1729 18 2707 6 2707 20 4742 4937 432 6 " "142 2114 197 10 2899 2 825 32 18499 125 584 295 2114 197 2114 2730 6 "
"3869;3719 760166 760166 18 1035176 1035176 764393 764393 1259006 767614 " "15 32 18499 125 584 142 295 2114 1423 21 2 334 863 5122 197 974 21 "
"767614 1020808 769579 793958 793958 1050488 911898 751332 751332 750336 " "295 619 25 2114 1755 2701 197 15 216 23 18499 125 584 142 599 3228 23 "
"750799 750336 751575 751575 751544 751735 751397 751365 751512 751512 " "2 5122 1917 804 5 2114 197 1236 3 2114 1403 15 3886 1080 23 1150 125 "
"753011 751562;3719 428 52 18 1102 10327 252 20 153 2897 1146 70 156 6 " "475 23 2998 23;125 756226 756913 855693 760836 10 750793 2 825 750355 "
"145 10251 839 2 1211 3 3719 720 1540 145 10251 839 9405 4315 5998 4 2 " "18499 881680 756226 295 765124 760836 2114 872813 754265 15 32 18499 "
"600 373 41 3719 428 52 44 10251 4302 1319 7 12 2 768 6 918 6 841 870 8 " "881680 756226 756913 761251 765124 752843 766823 2 334 759834 5122 "
"843 8 271;3719 760166 760166 18 1035176 1035176 764393 764393 1259006 " "774643 758458 21 295 755114 25 1148365 1755 2701 197 15 216 23 18499 "
"767614 767614 1020808 769579 793958 793958 1050488 911898 2 773899 " "881680 756226 756913 826848 3228 23 2 5122 831009 804 752371 2114 "
"773899 3719 1118420 1118420 1050488 1050488 911898 9405 4315 5998 4 2 " "760836 1236 3 2114 910393 15 3886 1080 23 877375 752137 761034 792123 "
"785435 785435 41 3719 760166 760166 44 10251 4302 1319 750118 750118 2 " "2998 23;1;1;\n"
"750465 750465 750274 750398 750233 751252 751252 753447 752830 753112;\n" "0 0;125 584 142 2114 197;125 756226 756913 855693 760836;121 28 1054 "
"0 0;145 10251 839 3719 428 52;1050488 1050488 911898 3719 760166 " "1459 125 72 32 2321 531 125 295 584 142 2114 197 14 477 30 121;121 28 "
"760166;2109 2467 1805 227 3719 428 52 18 1102 10327 252 20 6 242 78 6 " "764114 1459 753052 750694 750001 886192 750435 752179 295 584 756913 "
"532 78;2109 2467 1805 1245431 1245431 760166 760166 18 1035176 1035176 " "855693 760836 14 477 30 753504;121 28 1054 1459 125 72 32 2321 531 "
"764393 764393 752116 242 750370 750370 752081 751247;2109 2467 1805 227 " "125 295 584 142 2114 197 2 121 28 1054 1459 125 72 32 2321 531 125 "
"3719 428 52 18 1102 10327 252 20 2 145 242 1050 252 3582 2212;2109 2467 " "295 584 142 4 263 2114 197 43 95 863 2114 323 20 142 626 11 2 45 10 "
"1805 1245431 1245431 760166 760166 18 1035176 1035176 764393 764393 2 " "45 58 142 65 918 741 2114 197 764 3 5122 26 51 1266 2037 295 222 1121 "
"871717 871717 757921 757921 3582 2212;\n" "4491 3 545 4338 11 2 5122 26 495 3 142 3444 3249 2114 197 3 626 4 "
"0 0;145 10251 839 3719 428 52;1050488 1050488 911898 3719 760166 " "2794;121 28 764114 1459 753052 750694 750001 886192 750435 752179 295 "
"760166;145 10251 839 76 31 1337 823 7506 567 65 170 8 21293 3719 5 43 " "584 756913 855693 760836 2 121 28 764114 1459 753052 750694 750001 "
"394 743 42;1050488 1050488 911898 750016 750016 1337 823 7506 762617 " "886192 750435 752179 295 584 756913 4 750885 2114 760836 43 750030 "
"762617 866652 8 21293 3719 5 43 914758 914758 757202;145 10251 839 76 " "754302 2114 323 822131 142 626 769001 2 45 750128 750324 58 142 "
"31 1337 823 7506 567 65 170 8 21293 3719 2 17580 30 523324 3 10251 4104 " "1147454 918 910829 2114 760836 841946 767340 5122 779102 51 1266 2037 "
"281 3 8511 3719 2217 3 13 226 3083 4 11251 1606 357 9 2 145 10251 839 " "756461 222 752031 942669 1139389 780275 4338 830597 2 5122 779102 495 "
"76 31 1337 823 7506 567 65 170 2 7506 2445 8 145 10251 839 528 839 " "761418 142 3444 852932 2114 760836 3 760162 757966 751127;121 295 "
"19670 6538;1050488 1050488 911898 750016 750016 1337 823 7506 762617 " "5593 142 2114 197;121 295 5593 925208 2114 760836;\n"
"762617 866652 8 21293 3719 2 816626 816626 523324 3 1181698 1181698 " "0 0;125 584 142 2114 197;125 756226 756913 855693 760836;207 125 584 "
"751656 780821 1063148 3719 2217 3 752498 752498 831323 753602 11251 " "142 2114 1423 14 5283 1745 73;207 752276 756226 756913 855693 752843 "
"1606 357 9 2 1050488 1050488 911898 750016 750016 1337 823 7506 762617 " "14 5283 781651 786597;6109 18807 142 5 64 5283 1745 73 3690 1060 3626 "
"762617 866652 2 7506 753045 753045 756756 1050488 911898 528 839 19670 " "4 716 51 1030 2114 197 4 428 936 9066 10 10 10 2 207 125 584 142 2114 "
"6538;\n" "1423 2 15329 2114 197 5669 401 318 285 953 4 2114 197 2285 7 1783 11 "
"0 0;145 10251 839 3719 428 52;1050488 1050488 911898 3719 760166 " "2 5122 197 14017 584;6109 18807 142 5 755319 5283 781651 786597 3690 "
"760166;145 10251 839 99 4 1102 10327 2196 41 3719 428 52 44 99 4 2899 " "1060 3626 4 716 910478 1030 2114 760836 4 750323 936 9066 10 750002 "
"229 10 10 10;1050488 1050488 911898 807966 750273 1035176 1035176 " "750002 2 207 752276 756226 756913 855693 752843 2 15329 2114 760836 "
"1237875 41 3719 760166 760166 753645 753645 750273 2899 229 750001 " "5669 401 318 757541 750261 4 2114 760836 2285 7 757639 11 2 5122 "
"750001 750001;145 10251 839 99 4 1102 10327 2196 41 3719 428 52 44 99 4 " "774643 14017 584;125 584 142 1745 5122;125 756226 756913 1745 "
"2899 229 10 10 10 2 1177 8 145 10251 839 99 4 1102 10327 2196 41 3719 " "755836;\n"
"428 52 44 99 4 2 101 8 1922 17 2184 2 1154 1922 72 1198 1266 " "0 0;125 584 142 2114 197;125 756226 756913 855693 760836;149 396 778 "
"4516;1050488 1050488 911898 807966 750273 1035176 1035176 1237875 41 " "584 142 295 2114 1423 14 64 125 584 73 21 36670 5834 10 211 25;149 "
"3719 760166 760166 753645 753645 750273 2899 229 750001 750001 750001 2 " "751876 1048872 584 756913 761251 765124 752843 14 64 125 756226 73 "
"750257 750257 756756 1050488 911898 807966 750273 1035176 1035176 " "944567 36670 5834 10 750012 753240;101 10 2114 197 3 946 2 149 396 "
"1237875 41 3719 760166 760166 753645 753645 750273 2 764513 764513 " "778 584 142 295 2114 1423 2 2610 6 1444 111 2114 948 72 32 21 15 494 "
"851213 851213 854628 2 753018 753018 754317 753328 754085 754070;\n" "25 4 2114 197 5669 1145 2 148 295 149 396 778 584 142 295 21 22853 41 "
"0 0;145 10251 839 3719 428 52;1050488 1050488 911898 3719 760166 " "348 619 25 366 5305 2114 807 4 1115 381 1955 2114 11;101 751178 2114 "
"760166;73 5347 112 8 145 10251 839 262 169 22729 3719 6 743 6 339 1156 " "760836 3 946 2 149 751876 1048872 584 756913 761251 765124 752843 2 "
"78 136 399 693 128 571;776150 776150 112 756756 756756 1050488 911898 " "2610 753567 775165 750899 972788 948 750125 750001 751875 15 494 25 4 "
"791355 791355 22729 3719 6 758277 758277 750137 750234 750241 750178 " "2114 760836 5669 1145 2 148 808886 982157 751876 1048872 584 756913 "
"750055 750216 750212 750049;73 5347 112 8 145 10251 839 262 169 22729 " "761251 790772 22853 41 348 619 25 366 894206 2114 1008440 4 753953 "
"3719 2 588 415 549 415 115 23;776150 776150 112 756756 756756 1050488 " "381 851474 765868 11;149 396 778 584 142 295 2 149 396 354 778 584 "
"911898 791355 791355 22729 3719 2 750221 750221 750262 750277 750277 " "142 1333 2 584 778 295 5122 2 149 396 778 584 3609 2 149 396 64478 "
"750261;"; "816 14246 1423 2 149 396 584 32 127 19 3609 2 149 396 584 73 2 149 "
auto raw_lines = Split(raw_input, "\n"); "396 584 778 295 2285 142 4922 323 2 149 396 584 2114 2 149 396 253 "
for (auto& raw_line : raw_lines) { "584 2114 197;149 751876 1048872 584 756913 761251 2 149 751876 756286 "
auto inputx = Split(raw_line, ";"); "767182 584 756913 1333 2 584 778 897778 941364 2 149 751876 1048872 "
for (size_t i = 1; i < inputx.size(); ++i) { "584 1102835 2 149 751876 64478 816 14246 912094 2 149 751876 584 "
auto tokens = Split(inputx[i], " "); "773547 127 750771 791456 2 149 751876 584 73 2 149 751876 584 778 "
static std::vector<int64_t>* const input_array[] = { "897778 2285 751493 791984 323 2 149 751876 584 2114 2 149 751876 "
&input0, &input0, &input1, &input2, &input3, &input4, &input5}; "808443 835481 2114 760836;\n"
static std::vector<uint64_t>* const lod_array[] = {&input0_lod, "0 0;125 584 142 2114 197;125 756226 756913 855693 760836;125 584 545 "
&input0_lod, "149 14 125 584;125 756226 545 874302 14 125 756226;2204 25 30 1692 "
&input1_lod, "1770 6534 295 125 584 72 32 1346 4 2698 2114 197 11 2 4235 4301 240 "
&input2_lod, "295 125 584 72 32 21 6708 15 56974 494 25 1030 2114 197 110 804 495 "
&input3_lod, "611 2 221 759 341 6 5283 1745 73 71 2114 1423 71 125 584 545 149 149 "
&input4_lod, "2 505 345 58 125 584 65 3486 2114 295 4 45 786 196 6604 6086;2204 25 "
&input5_lod}; "30 797189 1770 1191824 295 752782 756226 751697 750001 1346 4 2698 "
for (auto token : tokens) { "2114 760836 765158 2 4235 4301 240 753859 752782 756226 751697 750001 "
input_array[i]->push_back((int64_t)atoi(token.c_str())); "751875 6708 15 56974 494 25 1030 2114 760836 777607 762850 966521 611 "
} "2 221 752565 750130 750084 910219 781651 786597 71 2114 752843 71 125 "
lod_array[i]->push_back((uint64_t)tokens.size() + "756226 545 874302 149 2 505 825657 782848 125 756226 65 3486 2114 "
(*lod_array[i])[lod_array[i]->size() - 1]); "760669 4 45 755747 758903 6604 6086;125 584 2114 2 125 584 2114 1423 "
} "2 125 584 2114 149 2 149 584 1745 5122 725 2 2114 125 584 2 125 584 "
} "2114 2 2621 584 2114 2 527 37 2754 130 170 1013 494 887 240 2 4521 "
return; "11111 586 2321 531 125 584 142 1360 816 2842 1423 2 125 584 2114;125 "
} "756226 2114 2 125 756226 2114 752843 2 125 756226 2114 783644 2 149 "
"760183 1745 755836 725 2 2114 125 756226 2 125 756226 2114 2 2621 "
"932600 2114 2 527 751304 869964 754462 170 1013 750719 778287 774620 "
"2 4521 11111 586 2321 750435 752179 756226 756913 1360 764399 2842 "
"1423 2 125 756226 2114;\n"
"0 0;125 584 142 2114 197;125 756226 756913 855693 760836;207 584 142 "
"2114 197 4 207 584 142 2114 197 674 14 240 4328 14 4328 767;207 "
"1237071 756913 855693 760836 4 207 1237071 756913 855693 760836 674 "
"14 240 755573 14 4328 795065;207 584 142 2114 197 2 325 71 71 207 584 "
"142 2114 197 2 876 125 140 2114 197 2 207 584 142 2114 197 674 1210 "
"239 4328 767 268 1349 485 28 4389 504 3 941 57 1419 1978 11;207 "
"1237071 756913 855693 760836 2 325 71 71 207 1237071 756913 855693 "
"760836 2 876 125 750977 1250790 760836 2 207 1237071 756913 855693 "
"760836 674 814792 755820 812174 795065 818859 817155 816597 761001 "
"774461 780904 820475 1109800 790141 790459 780324 770390;584 142 295 "
"2114 232 2 207 584 2114 197 2 584 142 295 2114 232 2 584 142 512 2114 "
"197;584 756913 761251 765124 1006359 2 207 1237071 2114 760836 2 584 "
"756913 761251 765124 1006359 2 584 756913 879930 2114 760836;";
-class MmdnnReader {
-  std::ifstream ifs;
-  std::vector<std::string> StringSplit(const std::string& in,
-                                       const std::string& delim) {
-    std::vector<std::string> ret;
-    if (in == "") {
-      return ret;
-    }
-    auto begpos = in.find_first_not_of(delim);
-    while (begpos != std::string::npos) {
-      auto endpos = in.find_first_of(delim, begpos);
-      if (endpos == std::string::npos) {
-        endpos = in.size();
-      }
-      std::string ssubstr = in.substr(begpos, endpos - begpos);
-      ret.push_back(ssubstr);
-      begpos = endpos + 1;
-      if (endpos >= (in.size() - 1)) {
-        break;
-      }
-    }
-    return ret;
-  }
// [note: the SampleReader class header was swallowed by the broken diff
//  view; the header lines below are reconstructed from its use in the test
//  body (public data/lod containers and a no-argument Read()).]
+class SampleReader {
+ public:
+  std::vector<std::vector<int64_t>> data;
+  std::vector<std::vector<uint64_t>> lod;
+
+  void Read() {
+    auto lines = Split(raw_input, "\n");
+    for (auto& line : lines) {
+      auto split1 = Split(line, ";");
+      if (data.size() == 0) {
+        for (size_t i = 1; i < split1.size(); ++i) {
+          data.push_back(std::vector<int64_t>());
+          lod.push_back({0});
+        }
+      }
+      for (size_t i = 1; i < split1.size(); ++i) {
+        auto split2 = Split(split1[i], " ");
+        if (split2.size() == 0) {
+          split2.push_back("1280000");
+        }
+        for (auto e : split2) {
+          data[i - 1].push_back(std::stoi(e.c_str(), nullptr, 0));
+        }
+        lod[i - 1].push_back(lod[i - 1].back() + split2.size());
+      }
+    }
+  }
+};
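// [illustrative walk-through, not part of the patch: for a hypothetical
//  sample line "0 0;12 34 56;78 9", field 0 is skipped, field 1 appends
//  {12, 34, 56} to data[0] and extends lod[0] to {0, 3}, and field 2
//  appends {78, 9} to data[1] so that lod[1] = {0, 2}; an empty field
//  falls back to the single sentinel id "1280000".]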
+class FileReader {
+  std::ifstream ifs;
+
  public:
-  std::vector<int64_t> data[6];
-  std::vector<uint64_t> lod[6];
+  std::vector<std::vector<int64_t>> data;
+  std::vector<std::vector<uint64_t>> lod;
   void Init(std::string file_name) { ifs.open(file_name); }
   int Read(int maxline) {
-    for (int i = 0; i < 6; i++) {
-      data[i].clear();
-    }
-    for (int i = 0; i < 6; i++) {
-      lod[i].clear();
-      lod[i].push_back(0);
-    }
+    data.clear();
+    lod.clear();
     std::string line;
     int cnt = 0;
     while (cnt < maxline && getline(ifs, line)) {
-      std::vector<std::string> split1 = StringSplit(line, ";");
-      for (int i = 1; i < 7; i++) {
-        std::vector<std::string> split2 = StringSplit(split1[i], " ");
+      std::vector<std::string> split1 = Split(line, ";");
+      if (data.size() == 0) {
+        for (size_t i = 1; i < split1.size(); ++i) {
+          data.push_back(std::vector<int64_t>());
+          lod.push_back({0});
+        }
+      }
+      for (size_t i = 1; i < split1.size(); i++) {
+        std::vector<std::string> split2 = Split(split1[i], " ");
         if (split2.size() == 0) {
           split2.push_back("1280000");
         }
         for (size_t j = 0; j < split2.size(); j++) {
           data[i - 1].push_back(std::stoi(split2[j].c_str(), nullptr, 0));
         }
-        // if (i % 2 == 1) {
-        //   lod[i / 2].push_back(lod[i / 2].back() + split2.size());
-        // }
         lod[i - 1].push_back(lod[i - 1].back() + split2.size());
       }
       cnt++;
@@ -186,36 +201,47 @@ class MmdnnReader {
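// [usage sketch, illustrative only, with a hypothetical file name; the
//  FileReader above is consumed exactly as in the perf branch below:
//    FileReader reader;
//    reader.Init("mmdnn_samples.txt");
//    while (reader.Read(/*maxline=*/40) > 0) {
//      // feed reader.data[i] / reader.lod[i] into predictor input i
//    }]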
 TEST(MMDNN, test_mmdnn_lite_xpu) {
   lite_api::CxxConfig config;
-  config.set_model_dir(FLAGS_model_dir);
-  config.set_valid_places({lite_api::Place{TARGET(kXPU), PRECISION(kFloat)},
-                           lite_api::Place{TARGET(kXPU), PRECISION(kInt64)},
-                           lite_api::Place{TARGET(kX86), PRECISION(kFloat)},
-                           lite_api::Place{TARGET(kX86), PRECISION(kInt64)},
-                           lite_api::Place{TARGET(kHost), PRECISION(kFloat)}});
+  // config.set_model_dir(FLAGS_model_dir);
+  config.set_model_file(FLAGS_model_dir + "/__model__");
+  config.set_param_file(FLAGS_model_dir + "/__param__");
+  config.set_xpu_dev_per_thread(FLAGS_perf_dev);
+  if (FLAGS_use_xpu) {
+    config.set_valid_places(
+        {lite_api::Place{TARGET(kXPU), PRECISION(kFloat)},
+         lite_api::Place{TARGET(kXPU), PRECISION(kInt64)},
+         lite_api::Place{TARGET(kX86), PRECISION(kFloat)},
+         lite_api::Place{TARGET(kX86), PRECISION(kInt64)},
+         lite_api::Place{TARGET(kHost), PRECISION(kFloat)}});
+  } else {
+    config.set_valid_places(
+        {lite_api::Place{TARGET(kX86), PRECISION(kFloat)},
+         lite_api::Place{TARGET(kX86), PRECISION(kInt64)},
+         lite_api::Place{TARGET(kHost), PRECISION(kFloat)}});
+  }
   config.set_xpu_workspace_l3_size_per_thread();
   auto predictor = lite_api::CreatePaddlePredictor(config);
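// [context note: the FLAGS_* switches used in this test are gflags defined
//  or declared elsewhere in the file; a minimal sketch of what such
//  definitions look like, with purely illustrative defaults:
//    DEFINE_bool(use_xpu, true, "run with the XPU valid places");
//    DEFINE_int32(perf_dev, 0, "XPU device id bound to the current thread");
//    DEFINE_int32(perf_batch_size, 40, "upper bound of samples per batch");]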
   if (FLAGS_perf) {
-    MmdnnReader reader;
-    reader.Init(FLAGS_perf_input);
-    int UB_batch = 40;  // upper bound of batch
+    FileReader file_reader;
+    file_reader.Init(FLAGS_perf_input);
+    int UB_batch = FLAGS_perf_batch_size;  // upper bound of batch
     int iter = 0;
     double tsc_sum = 0;
     while (true) {
-      int batch = reader.Read(UB_batch);
+      int batch = file_reader.Read(UB_batch);
       if (batch <= 0) {
         break;
       }
       ++iter;
-      for (int i = 0; i < 6; ++i) {
+      for (size_t i = 0; i < file_reader.data.size(); ++i) {
         auto input_x = predictor->GetInput(i);
-        input_x->Resize({(int64_t)reader.data[i].size(), 1});
-        input_x->SetLoD({reader.lod[i]});
+        input_x->Resize({(int64_t)file_reader.data[i].size(), 1});
+        input_x->SetLoD({file_reader.lod[i]});
         auto* data_x = input_x->mutable_data<int64_t>();
         memcpy(data_x,
-               reader.data[i].data(),
-               reader.data[i].size() * sizeof(int64_t));
+               file_reader.data[i].data(),
+               file_reader.data[i].size() * sizeof(int64_t));
       }
       auto start = GetCurrentUS();
@@ -232,55 +258,17 @@ TEST(MMDNN, test_mmdnn_lite_xpu) {
     return;
   }
 
-  ParseInput();
-
-  {
-    std::vector<int64_t> input0_shape{(int64_t)input0.size(), 1};
-    auto input_tensor0 = predictor->GetInput(0);
-    input_tensor0->Resize(input0_shape);
-    input_tensor0->SetLoD({input0_lod});
-    auto* data0 = input_tensor0->mutable_data<int64_t>();
-    memcpy(data0, input0.data(), sizeof(int64_t) * input0.size());
-  }
-  {
-    std::vector<int64_t> input1_shape{(int64_t)input1.size(), 1};
-    auto input_tensor1 = predictor->GetInput(1);
-    input_tensor1->Resize(input1_shape);
-    input_tensor1->SetLoD({input1_lod});
-    auto* data1 = input_tensor1->mutable_data<int64_t>();
-    memcpy(data1, input1.data(), sizeof(int64_t) * input1.size());
-  }
-  {
-    std::vector<int64_t> input2_shape{(int64_t)input2.size(), 1};
-    auto input_tensor2 = predictor->GetInput(2);
-    input_tensor2->Resize(input2_shape);
-    input_tensor2->SetLoD({input2_lod});
-    auto* data2 = input_tensor2->mutable_data<int64_t>();
-    memcpy(data2, input2.data(), sizeof(int64_t) * input2.size());
-  }
-  {
-    std::vector<int64_t> input3_shape{(int64_t)input3.size(), 1};
-    auto input_tensor3 = predictor->GetInput(3);
-    input_tensor3->Resize(input3_shape);
-    input_tensor3->SetLoD({input3_lod});
-    auto* data3 = input_tensor3->mutable_data<int64_t>();
-    memcpy(data3, input3.data(), sizeof(int64_t) * input3.size());
-  }
-  {
-    std::vector<int64_t> input4_shape{(int64_t)input4.size(), 1};
-    auto input_tensor4 = predictor->GetInput(4);
-    input_tensor4->Resize(input4_shape);
-    input_tensor4->SetLoD({input4_lod});
-    auto* data4 = input_tensor4->mutable_data<int64_t>();
-    memcpy(data4, input4.data(), sizeof(int64_t) * input4.size());
-  }
-  {
-    std::vector<int64_t> input5_shape{(int64_t)input5.size(), 1};
-    auto input_tensor5 = predictor->GetInput(5);
-    input_tensor5->Resize(input5_shape);
-    input_tensor5->SetLoD({input5_lod});
-    auto* data5 = input_tensor5->mutable_data<int64_t>();
-    memcpy(data5, input5.data(), sizeof(int64_t) * input5.size());
-  }
+  SampleReader sample_reader;
+  sample_reader.Read();
+  for (size_t i = 0; i < sample_reader.data.size(); ++i) {
+    auto input_x = predictor->GetInput(i);
+    input_x->Resize({(int64_t)sample_reader.data[i].size(), 1});
+    input_x->SetLoD({sample_reader.lod[i]});
+    auto* data_x = input_x->mutable_data<int64_t>();
+    memcpy(data_x,
+           sample_reader.data[i].data(),
+           sample_reader.data[i].size() * sizeof(int64_t));
+  }
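// [note: each model input is a flat sequence of token ids shaped as a
//  {total_len, 1} tensor, with per-sample boundaries carried in the LoD;
//  that is why every input slot gets the same Resize / SetLoD / memcpy
//  sequence in both the perf branch and the built-in-sample path.]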
   for (int i = 0; i < FLAGS_warmup; ++i) {