Unverified · Commit 4780849f · Authored by Cwndmiao, committed by GitHub

[LITE][XPU] Support ResnetCbam and MMDNN (#3844)

* [LITE][XPU] accommodate resnet_cbam

* [LITE][XPU] accommodate content-dnn

* fix pr comments test=develop

* fix pr comments test=develop

* fix pr comments test=develop test=xpu

* fix compilation error, test=develop test=xpu

* [X86] Fix the unit test of slice op
test=develop test=xpu
Co-authored-by: hong19860320 <9973393+hong19860320@users.noreply.github.com>
Parent e6b3d883
@@ -39,7 +39,7 @@ else()
endif()
find_library(XPU_SDK_XPU_RT_FILE NAMES xpurt
-  PATHS ${XPU_SDK_ROOT}/XTDK/runtime/shlib
+  PATHS ${XPU_SDK_ROOT}/XTDK/runtime/shlib ${XPU_SDK_ROOT}/XTDK/shlib # libxpurt.so may have been moved to XTDK/runtime/shlib
  NO_DEFAULT_PATH)
if(NOT XPU_SDK_XPU_RT_FILE)
......
@@ -24,6 +24,9 @@
#ifdef LITE_WITH_CUDA
#include "lite/backends/cuda/target_wrapper.h"
#endif
#ifdef LITE_WITH_XPU
#include "lite/backends/xpu/target_wrapper.h"
#endif
#ifdef LITE_WITH_MLU
#include "lite/backends/mlu/target_wrapper.h"
@@ -272,7 +275,7 @@ CxxConfig::mlu_firstconv_param() const {
void CxxConfig::set_xpu_workspace_l3_size_per_thread(int l3_size) {
#ifdef LITE_WITH_XPU
-  lite::Context<TargetType::kXPU>::SetWorkspaceL3Size(l3_size);
+  lite::TargetWrapperXPU::workspace_l3_size_per_thread = l3_size;
#else
  LOG(WARNING) << "The invoking of the function "
                  "'set_xpu_workspace_l3_size_per_thread' is ignored, please "
@@ -282,7 +285,7 @@ void CxxConfig::set_xpu_workspace_l3_size_per_thread(int l3_size) {
void CxxConfig::set_xpu_dev_per_thread(int dev_no) {
#ifdef LITE_WITH_XPU
-  lite::Context<TargetType::kXPU>::SetDev(dev_no);
+  lite::TargetWrapperXPU::SetDev(dev_no);
#else
  LOG(WARNING) << "The invoking of the function 'set_xpu_dev_per_thread' is "
                  "ignored, please rebuild it with LITE_WITH_XPU=ON.";
@@ -291,7 +294,7 @@ void CxxConfig::set_xpu_dev_per_thread(int dev_no) {
void CxxConfig::set_xpu_multi_encoder_precision(const std::string &precision) {
#ifdef LITE_WITH_XPU
-  lite::Context<TargetType::kXPU>::_multi_encoder_precision = precision;
+  lite::TargetWrapperXPU::multi_encoder_precision = precision;
#else
  LOG(WARNING) << "The invoking of the function "
                  "'set_xpu_multi_encoder_precision' is "
......
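
A minimal usage sketch (not part of this commit) for the XPU knobs touched above, assuming the public CxxConfig API from lite/api/paddle_api.h; the model path, place list and predictor creation are illustrative:

#include "lite/api/paddle_api.h"

void BuildXpuPredictor() {  // hypothetical helper
  paddle::lite_api::CxxConfig config;
  config.set_model_dir("./model_dir");  // assumed model location
  config.set_valid_places({paddle::lite_api::Place{TARGET(kXPU), PRECISION(kFloat)},
                           paddle::lite_api::Place{TARGET(kX86), PRECISION(kFloat)}});
  config.set_xpu_dev_per_thread(0);                       // bind this worker thread to card 0
  config.set_xpu_workspace_l3_size_per_thread(0xfffc00);  // per-thread L3 workspace, in bytes
  config.set_xpu_multi_encoder_precision("int16");        // or "int31" for higher precision
  auto predictor = paddle::lite_api::CreatePaddlePredictor(config);
}
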
@@ -55,6 +55,8 @@ USE_MIR_PASS(apu_subgraph_pass);
USE_MIR_PASS(quantized_op_attributes_inference_pass);
USE_MIR_PASS(lite_scale_activation_fuse_pass);
USE_MIR_PASS(__xpu__resnet_fuse_pass);
USE_MIR_PASS(__xpu__resnet_cbam_fuse_pass);
USE_MIR_PASS(__xpu__multi_encoder_fuse_pass);
USE_MIR_PASS(__xpu__embedding_with_eltwise_add_fuse_pass);
USE_MIR_PASS(__xpu__fc_fuse_pass);
USE_MIR_PASS(__xpu__mmdnn_fuse_pass);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cstddef>
#include <cstdio>
#include <memory>
#include <string>
#include <type_traits>
#include "lite/backends/xpu/xpu_header_sitter.h"
namespace paddle {
namespace lite {
namespace xpu {
template <typename T>
void DumpCPUMem(const T* ptr,
size_t len,
const std::string& comment = "",
size_t stride = 1,
size_t item_per_line = 30) {
size_t after_stride_len = (len + stride - 1) / stride;
std::unique_ptr<T[]> after_stride(new T[after_stride_len]);
for (size_t i = 0; i < after_stride_len; ++i) {
after_stride[i] = ptr[i * stride];
}
double sum = 0;
for (size_t i = 0; i < len; ++i) {
sum += ptr[i];
}
printf(
"------------------------------ [%s] len=%zd stride=%zd sum=%f BEGIN "
"------------------------------\n",
comment.c_str(),
len,
stride,
sum);
size_t nline = (after_stride_len + item_per_line - 1) / item_per_line;
for (size_t i = 0; i < nline; ++i) {
size_t line_begin = i * item_per_line;
size_t line_end = line_begin + item_per_line;
printf("line[%04zd] -- ", i);
for (size_t ii = line_begin; (ii < line_end) && (ii < after_stride_len);
++ii) {
if (std::is_same<T, float>::value) {
printf("%.6f, ", static_cast<float>(after_stride[ii]));
} else if (std::is_same<T, int16_t>::value) {
printf("%d ", static_cast<int>(after_stride[ii]));
} else {
// CHECK(false) << "unknown type";
}
}
printf("\n");
}
printf(
"------------------------------ [%s] len=%zd stride=%zd sum=%f END "
"------------------------------\n",
comment.c_str(),
len,
stride,
sum);
}
template <typename T>
void DumpXPUMem(const T* ptr,
size_t len,
const std::string& comment = "",
size_t stride = 1,
size_t item_per_line = 30) {
size_t after_stride_len = (len + stride - 1) / stride;
std::unique_ptr<T[]> cpu_mem(new T[len]);
xpu_memcpy(
cpu_mem.get(), ptr, len * sizeof(T), XPUMemcpyKind::XPU_DEVICE_TO_HOST);
std::unique_ptr<T[]> after_stride(new T[after_stride_len]);
for (size_t i = 0; i < after_stride_len; ++i) {
after_stride[i] = cpu_mem[i * stride];
}
double sum = 0;
for (size_t i = 0; i < len; ++i) {
sum += cpu_mem[i];
}
printf(
"------------------------------ [%s] len=%zd stride=%zd sum=%f BEGIN "
"------------------------------\n",
comment.c_str(),
len,
stride,
sum);
size_t nline = (after_stride_len + item_per_line - 1) / item_per_line;
for (size_t i = 0; i < nline; ++i) {
size_t line_begin = i * item_per_line;
size_t line_end = line_begin + item_per_line;
printf("line[%04zd] -- ", i);
for (size_t ii = line_begin; (ii < line_end) && (ii < after_stride_len);
++ii) {
if (std::is_same<T, float>::value) {
printf("%.6f, ", static_cast<float>(after_stride[ii]));
} else if (std::is_same<T, int16_t>::value) {
printf("%d ", static_cast<int>(after_stride[ii]));
} else {
// CHECK(false) << "unknown type";
}
}
printf("\n");
}
printf(
"------------------------------ [%s] len=%zd stride=%zd sum=%f END "
"------------------------------\n",
comment.c_str(),
len,
stride,
sum);
}
} // namespace xpu
} // namespace lite
} // namespace paddle
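
Usage sketch for the dump helpers above (not part of the diff); `dev_ptr`, `host_ptr` and `n` are assumed to come from the surrounding kernel code:

  // Copy an XPU buffer back to host and print every 4th float, 30 values per line.
  paddle::lite::xpu::DumpXPUMem(dev_ptr, n, "conv_out", /*stride=*/4);
  // Same layout for a buffer that already lives on the host.
  paddle::lite::xpu::DumpCPUMem(host_ptr, n, "conv_out_host");
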
@@ -13,7 +13,6 @@
// limitations under the License.
#include "lite/backends/xpu/target_wrapper.h"
#include "lite/backends/xpu/xpu_header_sitter.h"
namespace paddle {
namespace lite {
@@ -42,5 +41,21 @@ void TargetWrapperXPU::MemcpySync(void* dst,
  }
}
XPUScratchPadGuard TargetWrapperXPU::MallocScratchPad(size_t size,
bool use_l3) {
void* ptr{nullptr};
if (use_l3) {
ptr = xdnn::alloc_workspace(GetRawContext(), size);
} else {
ptr = TargetWrapperXPU::Malloc(size);
}
CHECK(ptr != nullptr);
return XPUScratchPadGuard(new XPUScratchPad(ptr, use_l3));
}
std::string TargetWrapperXPU::multi_encoder_precision; // NOLINT
int TargetWrapperXPU::workspace_l3_size_per_thread{0};
thread_local xdnn::Context* TargetWrapperXPU::tls_raw_ctx_{nullptr};
} // namespace lite
} // namespace paddle
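
A sketch of how a kernel might use the new scratch-pad helper (illustrative, not in this commit); the kernel name and buffer are assumptions:

  void SomeXpuKernelRun(int len) {  // hypothetical kernel body
    // Prefer L3 memory; the guard's deleter calls xpu_free only when global memory was used.
    auto guard = paddle::lite::TargetWrapperXPU::MallocScratchPad(
        len * sizeof(float), /*use_l3=*/true);
    float* workspace = reinterpret_cast<float*>(guard->addr_);
    // ... pass |workspace| to xdnn calls; the memory is released when |guard| goes out of scope ...
  }
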
@@ -14,6 +14,8 @@
#pragma once
#include <memory> // std::unique_ptr
#include "lite/backends/xpu/xpu_header_sitter.h" // xpu_free
#include "lite/core/target_wrapper.h" #include "lite/core/target_wrapper.h"
namespace paddle { namespace paddle {
...@@ -21,6 +23,24 @@ namespace lite { ...@@ -21,6 +23,24 @@ namespace lite {
using TargetWrapperXPU = TargetWrapper<TARGET(kXPU)>; using TargetWrapperXPU = TargetWrapper<TARGET(kXPU)>;
struct XPUScratchPad {
XPUScratchPad(void* addr, bool is_l3) : addr_(addr), is_l3_(is_l3) {}
void* addr_{nullptr};
bool is_l3_{false};
};
struct XPUScratchPadDeleter {
void operator()(XPUScratchPad* sp) const {
if (!sp->is_l3_) {
xpu_free(sp->addr_);
}
delete sp;
}
};
using XPUScratchPadGuard = std::unique_ptr<XPUScratchPad, XPUScratchPadDeleter>;
template <>
class TargetWrapper<TARGET(kXPU)> {
 public:
@@ -34,6 +54,41 @@ class TargetWrapper<TARGET(kXPU)> {
                         const void* src,
                         size_t size,
                         IoDirection dir);
static XPUScratchPadGuard MallocScratchPad(size_t size, bool use_l3 = true);
static xdnn::Context* GetRawContext() {
if (tls_raw_ctx_ == nullptr) {
tls_raw_ctx_ = xdnn::create_context();
CHECK(tls_raw_ctx_);
int r = xdnn::set_workspace_l3_size(tls_raw_ctx_,
workspace_l3_size_per_thread);
if (r != 0) {
LOG(WARNING) << "xdnn::set_workspace_l3_size() failed, r = " << r
<< ", workspace_l3_size_per_thread = "
<< workspace_l3_size_per_thread;
}
}
return tls_raw_ctx_;
}
// **DEPRECATED**, use xpu_set_device() at the very beginning of each worker
// thread
static void SetDev(int dev_no = 0) {
const char* dev_env = getenv("LITE_XPU_DEV");
if (dev_env) {
xpu_set_device(atoi(dev_env));
return;
}
xpu_set_device(dev_no);
}
static std::string multi_encoder_precision; // NOLINT
static int workspace_l3_size_per_thread;
private:
static thread_local xdnn::Context* tls_raw_ctx_;
};
} // namespace lite
......
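
Sketch of the intended per-thread flow (not from the diff; the worker function is hypothetical): each worker binds a device once, then reuses one lazily created, thread_local xdnn::Context:

  void XpuWorkerThread(int card_id) {
    paddle::lite::TargetWrapperXPU::SetDev(card_id);  // deprecated; xpu_set_device() is preferred
    xdnn::Context* ctx = paddle::lite::TargetWrapperXPU::GetRawContext();
    // Repeated GetRawContext() calls on this thread return the same context,
    // already configured with workspace_l3_size_per_thread.
    (void)ctx;
  }
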
@@ -21,12 +21,6 @@ namespace lite {
std::string Context<TargetType::kNPU>::subgraph_model_cache_dir_{""};  // NOLINT
#endif
#ifdef LITE_WITH_XPU
std::string Context<TargetType::kXPU>::_multi_encoder_precision; // NOLINT
thread_local xdnn::Context* Context<TargetType::kXPU>::_tls_raw_ctx{nullptr};
int Context<TargetType::kXPU>::_workspace_l3_size_per_thread{0};
#endif
#ifdef LITE_WITH_MLU
int Context<TargetType::kMLU>::next_queue_id_{0};
std::map<int, int> Context<TargetType::kMLU>::queue_id_map_;
......
@@ -144,45 +144,12 @@ class Context<TargetType::kXPU> {
  void CopySharedTo(XPUContext* ctx) {}
// TODO(miaotianxiang): remove this
  static xdnn::Context* GetRawContext() {
-    if (_tls_raw_ctx == nullptr) {
+    return TargetWrapperXPU::GetRawContext();
_tls_raw_ctx = xdnn::create_context();
CHECK(_tls_raw_ctx);
int r = xdnn::set_workspace_l3_size(_tls_raw_ctx,
_workspace_l3_size_per_thread);
if (r != 0) {
LOG(WARNING) << "xdnn::set_workspace_l3_size() failed, r = " << r
<< ", _workspace_l3_size_per_thread = "
<< _workspace_l3_size_per_thread;
}
}
return _tls_raw_ctx;
}
static void SetWorkspaceL3Size(int l3_size = 0xfffc00) {
_workspace_l3_size_per_thread = l3_size;
}
// **DEPRECATED**, use xpu_set_device() at the very beginning of each worker
// thread
static void SetDev(int dev_no = 0) {
const char* dev_env = getenv("LITE_XPU_DEV");
if (dev_env) {
xpu_set_device(atoi(dev_env));
return;
}
xpu_set_device(dev_no);
} }
std::string name() const { return "XPUContext"; } std::string name() const { return "XPUContext"; }
public:
static std::string _multi_encoder_precision; // NOLINT
private:
static thread_local xdnn::Context* _tls_raw_ctx;
static int _workspace_l3_size_per_thread;
};
#endif
......
@@ -23,9 +23,11 @@ lite_cc_library(mir_passes
fusion/sequence_pool_concat_fuse_pass.cc
fusion/scale_activation_fuse_pass.cc
fusion/__xpu__resnet_fuse_pass.cc
fusion/__xpu__resnet_cbam_fuse_pass.cc
fusion/__xpu__multi_encoder_fuse_pass.cc
fusion/__xpu__embedding_with_eltwise_add_fuse_pass.cc
fusion/__xpu__fc_fuse_pass.cc
fusion/__xpu__mmdnn_fuse_pass.cc
elimination/identity_scale_eliminate_pass.cc
elimination/identity_dropout_eliminate_pass.cc
elimination/elementwise_mul_constant_eliminate_pass.cc
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <memory>
#include <vector>
#include "lite/backends/xpu/math.h"
#include "lite/core/mir/pass_registry.h"
#include "lite/core/mir/xpu_pattern_matcher_high_api.h"
namespace paddle {
namespace lite {
namespace mir {
namespace fusion {
class XPUMmdnnFloat2Fix {
public:
void operator()(SSAGraph* graph) {
for (auto* node : graph->StmtTopologicalOrder()) {
CHECK(node->IsStmt());
auto* op_info = node->stmt()->op_info();
std::string op_type = op_info->Type();
static const std::vector<std::string> target_ops{"var_conv_2d",
"search_fc"};
if (std::find(target_ops.begin(), target_ops.end(), op_type) !=
target_ops.end()) {
std::string weight_name = op_info->Input("W").front();
auto* scope = node->stmt()->op()->scope();
auto* weight_t = scope->FindMutableTensor(weight_name);
auto weight_dims = weight_t->dims();
auto weight_len = weight_t->numel();
float* weight_on_host = weight_t->mutable_data<float>();
float max_f =
paddle::lite::xpu::math::FindMaxAbs(weight_on_host, weight_len);
std::unique_ptr<int16_t[]> weight_int16(new int16_t[weight_len]);
paddle::lite::xpu::math::ConvertFP32ToInt16(
weight_on_host, weight_int16.get(), max_f, weight_len);
memcpy(
weight_on_host, weight_int16.get(), weight_len * sizeof(int16_t));
auto update_op_info = *op_info;
update_op_info.SetAttr<bool>("__xpu__float_to_fix", true);
update_op_info.SetAttr<float>("__xpu__w_max", max_f);
node->stmt()->ResetOp(update_op_info, graph->valid_places());
VLOG(3) << "Float2Fix, op_type=" << op_type
<< ", weight_name=" << weight_name;
} else if (op_type == "match_matrix_tensor") {
std::string weight_name = op_info->Input("W").front();
auto* scope = node->stmt()->op()->scope();
auto* weight_t = scope->FindMutableTensor(weight_name);
auto weight_dims = weight_t->dims();
auto weight_len = weight_t->numel();
float* weight_on_host = weight_t->mutable_data<float>();
float max_f =
paddle::lite::xpu::math::FindMaxAbs(weight_on_host, weight_len);
std::unique_ptr<int16_t[]> weight_int16(new int16_t[weight_len]);
std::unique_ptr<int16_t[]> weight_trans_int16(new int16_t[weight_len]);
paddle::lite::xpu::math::ConvertFP32ToInt16(
weight_on_host, weight_int16.get(), max_f, weight_len);
paddle::lite::xpu::math::Transpose(weight_int16.get(),
weight_trans_int16.get(),
weight_dims[0],
weight_dims[1] * weight_dims[2]);
memcpy(weight_on_host,
weight_trans_int16.get(),
weight_len * sizeof(int16_t));
auto update_op_info = *op_info;
update_op_info.SetAttr<bool>("__xpu__float_to_fix", true);
update_op_info.SetAttr<float>("__xpu__w_max", max_f);
node->stmt()->ResetOp(update_op_info, graph->valid_places());
VLOG(3) << "Float2Fix && Transposed, op_type=" << op_type
<< ", weight_name=" << weight_name;
} else if (op_type == "search_grnn") {
auto* scope = node->stmt()->op()->scope();
std::string wi_name = op_info->Input("Wi").front();
auto* wi_t = scope->FindMutableTensor(wi_name);
auto wi_dims = wi_t->dims();
auto wi_len = wi_t->numel();
auto wi_stride_len = wi_len / 3;
float* wi_on_host = wi_t->mutable_data<float>();
std::unique_ptr<int16_t[]> wi_int16(new int16_t[wi_len]);
std::vector<float> wi_max(3);
for (int i = 0; i < 3; ++i) {
float max_f = paddle::lite::xpu::math::FindMaxAbs(
wi_on_host + i * wi_stride_len, wi_stride_len);
paddle::lite::xpu::math::ConvertFP32ToInt16(
wi_on_host + i * wi_stride_len,
wi_int16.get() + i * wi_stride_len,
max_f,
wi_stride_len);
wi_max[i] = max_f;
}
memcpy(wi_on_host, wi_int16.get(), wi_len * sizeof(int16_t));
std::string wh_name = op_info->Input("Wh").front();
auto* wh_t = scope->FindMutableTensor(wh_name);
auto wh_dims = wh_t->dims();
auto wh_len = wh_t->numel();
auto wh_stride_len = wh_len / 3;
float* wh_on_host = wh_t->mutable_data<float>();
std::unique_ptr<int16_t[]> wh_int16(new int16_t[wh_len]);
std::vector<float> wh_max(3);
for (int i = 0; i < 3; ++i) {
float max_f = paddle::lite::xpu::math::FindMaxAbs(
wh_on_host + i * wh_stride_len, wh_stride_len);
paddle::lite::xpu::math::ConvertFP32ToInt16(
wh_on_host + i * wh_stride_len,
wh_int16.get() + i * wh_stride_len,
max_f,
wh_stride_len);
wh_max[i] = max_f;
}
memcpy(wh_on_host, wh_int16.get(), wh_len * sizeof(int16_t));
auto update_op_info = *op_info;
update_op_info.SetAttr<bool>("__xpu__float_to_fix", true);
update_op_info.SetAttr<std::vector<float>>("__xpu__wi_max", wi_max);
update_op_info.SetAttr<std::vector<float>>("__xpu__wh_max", wh_max);
node->stmt()->ResetOp(update_op_info, graph->valid_places());
VLOG(3) << "Float2Fix, op_type=" << op_type << ", wi_name=" << wi_name
<< ", wh_name=" << wh_name;
}
}
}
};
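
The weight rewriting above leans on FindMaxAbs and ConvertFP32ToInt16 from lite/backends/xpu/math.h. A self-contained sketch of the presumed max-abs int16 quantization scheme (the exact rounding in the real helpers may differ):

  #include <algorithm>
  #include <cmath>
  #include <cstdint>
  #include <vector>

  // Quantize FP32 weights with a single per-tensor scale; *max_out corresponds to
  // the "__xpu__w_max" attribute recorded by the pass.
  std::vector<int16_t> QuantizeToInt16(const std::vector<float>& w, float* max_out) {
    float max_abs = 1e-6f;
    for (float v : w) max_abs = std::max(max_abs, std::fabs(v));
    std::vector<int16_t> q(w.size());
    for (size_t i = 0; i < w.size(); ++i) {
      q[i] = static_cast<int16_t>(std::round(w[i] / max_abs * 32767.0f));
    }
    *max_out = max_abs;
    return q;
  }
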
class XPUMmdnnSearchAttentionFuser : public FuseBase {
public:
void BuildPattern() override {
auto* input = VarNode("input")->AsInput();
auto* search_group_padding =
OpNode("search_group_padding", "search_group_padding");
auto* out_emb_padding =
VarNode("out_emb_padding")
->assert_is_op_output("search_group_padding", "Out_emb_padding")
->AsIntermediate();
auto* out_new = VarNode("out_new")
->assert_is_op_output("search_group_padding", "Out_new")
->AsIntermediate();
auto* out_padding =
VarNode("out_padding")
->assert_is_op_output("search_group_padding", "Out_padding")
->AsIntermediate();
auto* search_seq_fc_w = VarNode("search_seq_fc_w")
->assert_is_op_input("search_seq_fc", "W")
->AsInput();
auto* search_seq_fc_b = VarNode("search_seq_fc_b")
->assert_is_op_input("search_seq_fc", "b")
->AsInput();
auto* search_seq_fc =
OpNode("search_seq_fc", "search_seq_fc")->AsIntermediate();
auto* search_seq_fc_out = VarNode("search_seq_fc_out")
->assert_is_op_output("search_seq_fc", "Out")
->AsIntermediate();
auto* search_aligned_mat_mul =
OpNode("search_aligned_mat_mul", "search_aligned_mat_mul")
->AsIntermediate();
auto* search_aligned_mat_mul_out =
VarNode("search_aligned_mat_mul_out")
->assert_is_op_output("search_aligned_mat_mul", "Out")
->AsIntermediate();
auto* search_aligned_mat_mul_a =
VarNode("search_aligned_mat_mul_a")
->assert_is_op_output("search_aligned_mat_mul", "_a_addr")
->AsIntermediate();
auto* search_aligned_mat_mul_b =
VarNode("search_aligned_mat_mul_b")
->assert_is_op_output("search_aligned_mat_mul", "_b_addr")
->AsIntermediate();
auto* search_aligned_mat_mul_c =
VarNode("search_aligned_mat_mul_c")
->assert_is_op_output("search_aligned_mat_mul", "_c_addr")
->AsIntermediate();
auto* search_attention_padding_mask =
OpNode("search_attention_padding_mask", "search_attention_padding_mask")
->AsIntermediate();
auto* search_attention_padding_mask_out =
VarNode("search_attention_padding_mask_out")
->assert_is_op_output("search_attention_padding_mask", "Out")
->AsIntermediate();
auto* search_attention_padding_mask_pad_begin =
VarNode("search_attention_padding_mask_pad_begin")
->assert_is_op_output("search_attention_padding_mask", "pad_begin")
->AsIntermediate();
auto* search_seq_softmax =
OpNode("search_seq_softmax", "search_seq_softmax")->AsIntermediate();
auto* search_seq_softmax_out =
VarNode("search_seq_softmax_out")
->assert_is_op_output("search_seq_softmax", "Out")
->AsIntermediate();
auto* search_seq_softmax_out_log =
VarNode("search_seq_softmax_out_log")
->assert_is_op_output("search_seq_softmax", "Out_log")
->AsIntermediate();
auto* search_aligned_mat_mul_2 =
OpNode("search_aligned_mat_mul_2", "search_aligned_mat_mul")
->AsIntermediate();
auto* search_aligned_mat_mul_2_out =
VarNode("search_aligned_mat_mul_2_out")
->assert_is_op_output("search_aligned_mat_mul", "Out")
->AsIntermediate();
auto* search_aligned_mat_mul_2_a =
VarNode("search_aligned_mat_mul_2_a")
->assert_is_op_output("search_aligned_mat_mul", "_a_addr")
->AsIntermediate();
auto* search_aligned_mat_mul_2_b =
VarNode("search_aligned_mat_mul_2_b")
->assert_is_op_output("search_aligned_mat_mul", "_b_addr")
->AsIntermediate();
auto* search_aligned_mat_mul_2_c =
VarNode("search_aligned_mat_mul_2_c")
->assert_is_op_output("search_aligned_mat_mul", "_c_addr")
->AsIntermediate();
auto* search_seq_depadding =
OpNode("search_seq_depadding")->AsIntermediate();
auto* search_seq_depadding_out =
VarNode("search_seq_depadding_out")->AsOutput();
*input >> *search_group_padding >> *out_emb_padding;
*search_group_padding >> *out_new;
*search_group_padding >> *out_padding;
*search_seq_fc_w >> *search_seq_fc;
*search_seq_fc_b >> *search_seq_fc;
*out_emb_padding >> *search_seq_fc;
*search_seq_fc >> *search_seq_fc_out;
*search_seq_fc_out >> *search_aligned_mat_mul;
*out_emb_padding >> *search_aligned_mat_mul;
*search_aligned_mat_mul >> *search_aligned_mat_mul_out;
*search_aligned_mat_mul >> *search_aligned_mat_mul_a;
*search_aligned_mat_mul >> *search_aligned_mat_mul_b;
*search_aligned_mat_mul >> *search_aligned_mat_mul_c;
*search_aligned_mat_mul_out >> *search_attention_padding_mask;
*out_padding >> *search_attention_padding_mask;
*search_attention_padding_mask >> *search_attention_padding_mask_out;
*search_attention_padding_mask >> *search_attention_padding_mask_pad_begin;
*search_attention_padding_mask_out >> *search_seq_softmax;
*search_seq_softmax >> *search_seq_softmax_out;
*search_seq_softmax >> *search_seq_softmax_out_log;
*search_seq_softmax_out >> *search_aligned_mat_mul_2;
*out_emb_padding >> *search_aligned_mat_mul_2;
*search_aligned_mat_mul_2 >> *search_aligned_mat_mul_2_out;
*search_aligned_mat_mul_2 >> *search_aligned_mat_mul_2_a;
*search_aligned_mat_mul_2 >> *search_aligned_mat_mul_2_b;
*search_aligned_mat_mul_2 >> *search_aligned_mat_mul_2_c;
*search_aligned_mat_mul_2_out >> *search_seq_depadding;
*out_new >> *search_seq_depadding;
*search_seq_depadding >> *search_seq_depadding_out;
}
void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override {
cpp::OpDesc op_desc;
op_desc.SetType("__xpu__mmdnn_search_attention");
op_desc.SetInput("X", {matched.at("input")->arg()->name});
op_desc.SetInput("W", {matched.at("search_seq_fc_w")->arg()->name});
op_desc.SetInput("b", {matched.at("search_seq_fc_b")->arg()->name});
op_desc.SetOutput("Out",
{matched.at("search_seq_depadding_out")->arg()->name});
auto* padding_op_info =
matched.at("search_group_padding")->stmt()->op_info();
op_desc.SetAttr<int>("pad_id", padding_op_info->GetAttr<int>("pad_id"));
auto* matmul_0_op_info =
matched.at("search_aligned_mat_mul")->stmt()->op_info();
op_desc.SetAttr<float>("alpha0", matmul_0_op_info->GetAttr<float>("alpha"));
auto* matmul_1_op_info =
matched.at("search_aligned_mat_mul_2")->stmt()->op_info();
op_desc.SetAttr<float>("alpha1", matmul_1_op_info->GetAttr<float>("alpha"));
auto* mask_op_info =
matched.at("search_attention_padding_mask")->stmt()->op_info();
op_desc.SetAttr<float>("mask", mask_op_info->GetAttr<float>("mask"));
auto* new_stmt = matched.at("search_group_padding")->stmt();
auto* scope = new_stmt->op()->scope();
auto w_name = matched.at("search_seq_fc_w")->arg()->name;
auto* w_t = scope->FindMutableTensor(w_name);
auto w_dims = w_t->dims();
int w_len = w_t->numel();
float* w_on_host = w_t->mutable_data<float>();
float max_f = paddle::lite::xpu::math::FindMaxAbs(w_on_host, w_len);
std::unique_ptr<int16_t[]> w_int16(new int16_t[w_len]);
paddle::lite::xpu::math::ConvertFP32ToInt16(
w_on_host, w_int16.get(), max_f, w_len);
memcpy(w_on_host, w_int16.get(), w_len * sizeof(int16_t));
op_desc.SetAttr<float>("W_max", max_f);
auto new_op = LiteOpRegistry::Global().Create(op_desc.Type());
new_op->Attach(op_desc, scope);
new_op->SetValidPlaces(new_stmt->op()->valid_places());
auto kernels = new_op->CreateKernels(new_op->valid_places());
new_stmt->SetOp(new_op);
new_stmt->SetKernels(std::move(kernels));
DirectedLink(matched.at("search_seq_fc_w"),
matched.at("search_group_padding"));
DirectedLink(matched.at("search_seq_fc_b"),
matched.at("search_group_padding"));
IR_OP_VAR_LINK(matched.at("search_group_padding"),
matched.at("search_seq_depadding_out"));
}
};
class XPUMmdnnMatchConvTopkFuser : public FuseBase {
public:
void BuildPattern() override {
auto* input_x = VarNode("input_x")
->assert_is_op_input("match_matrix_tensor", "X")
->AsInput();
auto* input_y = VarNode("input_y")
->assert_is_op_input("match_matrix_tensor", "Y")
->AsInput();
auto* input_w = VarNode("input_w")
->assert_is_op_input("match_matrix_tensor", "W")
->AsInput();
auto* match_matrix_tensor =
OpNode("match_matrix_tensor", "match_matrix_tensor");
auto* match_out = VarNode("match_out")
->assert_is_op_output("match_matrix_tensor", "Out")
->AsIntermediate();
auto* match_tmp = VarNode("match_tmp")
->assert_is_op_output("match_matrix_tensor", "Tmp")
->AsIntermediate();
auto* relu0 = OpNode("relu0", "relu")->AsIntermediate();
auto* relu0_out = VarNode("relu0_out")
->assert_is_op_output("relu", "Out")
->AsIntermediate();
auto* conv_w =
VarNode("conv_w")->assert_is_op_input("var_conv_2d", "W")->AsInput();
auto* conv = OpNode("conv", "var_conv_2d")->AsIntermediate();
auto* conv_out = VarNode("conv_out")
->assert_is_op_output("var_conv_2d", "Out")
->AsIntermediate();
auto* conv_col = VarNode("conv_col")
->assert_is_op_output("var_conv_2d", "Col")
->AsIntermediate();
auto* relu1 = OpNode("relu1", "relu")->AsIntermediate();
auto* relu1_out = VarNode("relu1_out")
->assert_is_op_output("relu", "Out")
->AsIntermediate();
auto* seq_concat =
OpNode("seq_concat", "sequence_concat")->AsIntermediate();
auto* seq_concat_out =
VarNode("seq_concat_out")
->assert_is_op_output("sequence_concat", "Out")
->assert_is_op_input("sequence_topk_avg_pooling", "X")
->AsIntermediate();
auto* topk_col =
VarNode("topk_col")
->assert_is_op_input("sequence_topk_avg_pooling", "COLUMN")
->AsInput();
auto* topk_row =
VarNode("topk_row")
->assert_is_op_input("sequence_topk_avg_pooling", "ROW")
->AsInput();
auto* topk = OpNode("topk", "sequence_topk_avg_pooling")->AsIntermediate();
auto* topk_out =
VarNode("topk_out")
->assert_is_op_output("sequence_topk_avg_pooling", "Out")
->AsOutput();
auto* topk_pos =
VarNode("topk_pos")
->assert_is_op_output("sequence_topk_avg_pooling", "pos")
->AsIntermediate();
*input_x >> *match_matrix_tensor;
*input_y >> *match_matrix_tensor;
*input_w >> *match_matrix_tensor;
*match_matrix_tensor >> *match_out >> *relu0 >> *relu0_out;
*match_matrix_tensor >> *match_tmp;
*relu0_out >> *conv >> *conv_out >> *relu1 >> *relu1_out;
*conv_w >> *conv;
*conv >> *conv_col;
*relu0_out >> *seq_concat;
*relu1_out >> *seq_concat;
*seq_concat >> *seq_concat_out >> *topk >> *topk_out;
*topk_col >> *topk;
*topk_row >> *topk;
*topk >> *topk_pos;
}
void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override {
cpp::OpDesc op_desc;
op_desc.SetType("__xpu__mmdnn_match_conv_topk");
op_desc.SetInput("input_x", {matched.at("input_x")->arg()->name});
op_desc.SetInput("input_y", {matched.at("input_y")->arg()->name});
op_desc.SetInput("input_w", {matched.at("input_w")->arg()->name});
op_desc.SetInput("conv_w", {matched.at("conv_w")->arg()->name});
op_desc.SetOutput("topk_out", {matched.at("topk_out")->arg()->name});
auto* match_op_info = matched.at("match_matrix_tensor")->stmt()->op_info();
op_desc.SetAttr<float>("input_w_max",
match_op_info->GetAttr<float>("w_max"));
op_desc.SetAttr<int>("dim_t", match_op_info->GetAttr<int>("dim_t"));
auto* conv_op_info = matched.at("conv")->stmt()->op_info();
op_desc.SetAttr<float>("conv_w_max", conv_op_info->GetAttr<float>("w_max"));
auto* topk_op_info = matched.at("topk")->stmt()->op_info();
op_desc.SetAttr<std::vector<int>>(
"topks", topk_op_info->GetAttr<std::vector<int>>("topks"));
op_desc.SetAttr<int>("channel_num",
topk_op_info->GetAttr<int>("channel_num"));
auto* new_stmt = matched.at("match_matrix_tensor")->stmt();
auto new_op = LiteOpRegistry::Global().Create(op_desc.Type());
new_op->Attach(op_desc, new_stmt->op()->scope());
new_op->SetValidPlaces(new_stmt->op()->valid_places());
auto kernels = new_op->CreateKernels(new_op->valid_places());
new_stmt->SetOp(new_op);
new_stmt->SetKernels(std::move(kernels));
// XXX(miaotianxiang): redundant links around |topk| are automatically
// removed as |topk| is marked intermediate.
// RemoveDirectedLink(matched.at("topk_col"), matched.at("topk"));
// RemoveDirectedLink(matched.at("topk_row"), matched.at("topk"));
std::vector<std::string> arg_names{"conv_w"};
for (auto name : arg_names) {
DirectedLink(matched.at(name), matched.at("match_matrix_tensor"));
}
std::vector<std::string> out_names{"topk_out"};
for (auto name : out_names) {
IR_OP_VAR_LINK(matched.at("match_matrix_tensor"), matched.at(name));
}
}
};
class XPUMmdnnBidSeqRevEmbEltwiseFuser : public FuseBase {
public:
void BuildPattern() override {
auto* input0 = VarNode("input0")->AsInput();
auto* input1 = VarNode("input1")->AsInput();
auto* emb_tbl = VarNode("emb_tbl")->AsInput();
// fwd emb
auto* emb0 = OpNode("emb0", "lookup_table");
auto* emb0_out =
VarNode("emb0_out")->assert_is_op_output("lookup_table", "Out");
auto* emb1 = OpNode("emb1", "lookup_table");
auto* emb1_out =
VarNode("emb1_out")->assert_is_op_output("lookup_table", "Out");
auto* eltwise01 = OpNode("eltwise01", "search_seq_arithmetic");
auto* eltwise01_out =
VarNode("eltwise01_out")
->assert_is_op_output("search_seq_arithmetic", "Out")
->AsOutput();
// rev emb
auto* seq_rev2 = OpNode("seq_rev2", "sequence_reverse")->AsIntermediate();
auto* seq_rev2_out = VarNode("seq_rev2_out")
->assert_is_op_output("sequence_reverse", "Y")
->AsIntermediate();
auto* seq_rev3 = OpNode("seq_rev3", "sequence_reverse")->AsIntermediate();
auto* seq_rev3_out = VarNode("seq_rev3_out")
->assert_is_op_output("sequence_reverse", "Y")
->AsIntermediate();
auto* emb2 = OpNode("emb2", "lookup_table")->AsIntermediate();
auto* emb2_out = VarNode("emb2_out")
->assert_is_op_output("lookup_table", "Out")
->AsIntermediate();
auto* emb3 = OpNode("emb3", "lookup_table")->AsIntermediate();
auto* emb3_out = VarNode("emb3_out")
->assert_is_op_output("lookup_table", "Out")
->AsIntermediate();
auto* eltwise23 =
OpNode("eltwise23", "search_seq_arithmetic")->AsIntermediate();
auto* eltwise23_out =
VarNode("eltwise23_out")
->assert_is_op_output("search_seq_arithmetic", "Out")
->AsOutput();
*input0 >> *emb0 >> *emb0_out >> *eltwise01 >> *eltwise01_out;
*emb_tbl >> *emb0;
*input1 >> *emb1 >> *emb1_out >> *eltwise01;
*emb_tbl >> *emb1;
*input0 >> *seq_rev2 >> *seq_rev2_out >> *emb2 >> *emb2_out >> *eltwise23 >>
*eltwise23_out;
*emb_tbl >> *emb2;
*input1 >> *seq_rev3 >> *seq_rev3_out >> *emb3 >> *emb3_out >> *eltwise23;
*emb_tbl >> *emb3;
}
void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override {
cpp::OpDesc op_desc;
op_desc.SetType("sequence_reverse");
op_desc.SetInput("X", {matched.at("eltwise01_out")->arg()->name});
op_desc.SetOutput("Y", {matched.at("eltwise23_out")->arg()->name});
auto emb0_op = matched.at("emb0")->stmt()->op();
auto new_seq_rev_op = LiteOpRegistry::Global().Create("sequence_reverse");
new_seq_rev_op->Attach(op_desc, emb0_op->scope());
auto* new_seq_rev_node =
graph->GraphCreateInstructNode(new_seq_rev_op, emb0_op->valid_places());
DirectedLink(matched.at("eltwise01_out"), new_seq_rev_node);
DirectedLink(new_seq_rev_node, matched.at("eltwise23_out"));
}
};
class XPUMmdnnBidEmbAttFuser : public FuseBase {
public:
void BuildPattern() override {
auto* input0 = VarNode("input0")->AsInput();
auto* input1 = VarNode("input1")->AsInput();
auto* emb_tbl = VarNode("emb_tbl")->AsInput();
auto* emb0 = OpNode("emb0", "lookup_table");
auto* emb0_out = VarNode("emb0_out")
->assert_is_op_output("lookup_table", "Out")
->AsIntermediate();
auto* emb1 = OpNode("emb1", "lookup_table")->AsIntermediate();
auto* emb1_out = VarNode("emb1_out")
->assert_is_op_output("lookup_table", "Out")
->AsIntermediate();
auto* eltwise01 =
OpNode("eltwise01", "search_seq_arithmetic")->AsIntermediate();
auto* eltwise01_out =
VarNode("eltwise01_out")
->assert_is_op_output("search_seq_arithmetic", "Out")
->AsOutput();
auto* att_2in1_w =
VarNode("att_2in1_w")
->assert_is_op_input("__xpu__mmdnn_search_attention", "W")
->AsInput();
auto* att_2in1_b =
VarNode("att_2in1_b")
->assert_is_op_input("__xpu__mmdnn_search_attention", "b")
->AsInput();
auto* att_2in1 =
OpNode("att_2in1", "__xpu__mmdnn_search_attention")->AsIntermediate();
auto* att_2in1_out =
VarNode("att_2in1_out")
->assert_is_op_output("__xpu__mmdnn_search_attention", "Out")
->AsIntermediate();
auto* seq_pool_2in1 =
OpNode("seq_pool_2in1", "sequence_pool")->AsIntermediate();
auto* seq_pool_2in1_out = VarNode("seq_pool_2in1_out")
->assert_is_op_output("sequence_pool", "Out")
->AsOutput();
auto* seq_pool_2in1_max_idx =
VarNode("seq_pool_2in1_max_idx")
->assert_is_op_output("sequence_pool", "MaxIndex")
->AsIntermediate();
*input0 >> *emb0 >> *emb0_out >> *eltwise01 >> *eltwise01_out;
*emb_tbl >> *emb0;
*input1 >> *emb1 >> *emb1_out >> *eltwise01;
*emb_tbl >> *emb1;
*eltwise01_out >> *att_2in1 >> *att_2in1_out >> *seq_pool_2in1 >>
*seq_pool_2in1_out;
*seq_pool_2in1 >> *seq_pool_2in1_max_idx;
*att_2in1_w >> *att_2in1;
*att_2in1_b >> *att_2in1;
}
void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override {
cpp::OpDesc op_desc;
op_desc.SetType("__xpu__mmdnn_bid_emb_att");
op_desc.SetInput("id0", {matched.at("input0")->arg()->name});
op_desc.SetInput("id1", {matched.at("input1")->arg()->name});
op_desc.SetInput("emb_tbl", {matched.at("emb_tbl")->arg()->name});
op_desc.SetInput("att_fc_w", {matched.at("att_2in1_w")->arg()->name});
op_desc.SetInput("att_fc_b", {matched.at("att_2in1_b")->arg()->name});
op_desc.SetOutput("att_pool_out",
{matched.at("seq_pool_2in1_out")->arg()->name});
op_desc.SetOutput("emb_fw_out", {matched.at("eltwise01_out")->arg()->name});
auto* att_fc_op_info = matched.at("att_2in1")->stmt()->op_info();
op_desc.SetAttr<float>("att_fc_w_max",
att_fc_op_info->GetAttr<float>("W_max"));
auto* new_stmt = matched.at("emb0")->stmt();
auto new_op = LiteOpRegistry::Global().Create(op_desc.Type());
new_op->Attach(op_desc, new_stmt->op()->scope());
new_op->SetValidPlaces(new_stmt->op()->valid_places());
auto kernels = new_op->CreateKernels(new_op->valid_places());
new_stmt->SetOp(new_op);
new_stmt->SetKernels(std::move(kernels));
std::vector<std::string> arg_names{
"input1", "att_2in1_w", "att_2in1_b",
};
for (auto name : arg_names) {
DirectedLink(matched.at(name), matched.at("emb0"));
}
std::vector<std::string> out_names{
"seq_pool_2in1_out", "eltwise01_out",
};
for (auto name : out_names) {
IR_OP_VAR_LINK(matched.at("emb0"), matched.at(name));
}
}
};
class XPUMmdnnBidEmbGrnnAttFuser : public FuseBase {
public:
void BuildPattern() override {
auto* input0 = VarNode("input0")->AsInput();
auto* input1 = VarNode("input1")->AsInput();
auto* emb_tbl = VarNode("emb_tbl")->AsInput();
auto* emb0 = OpNode("emb0", "lookup_table");
auto* emb0_out = VarNode("emb0_out")
->assert_is_op_output("lookup_table", "Out")
->AsIntermediate();
auto* emb1 = OpNode("emb1", "lookup_table")->AsIntermediate();
auto* emb1_out = VarNode("emb1_out")
->assert_is_op_output("lookup_table", "Out")
->AsIntermediate();
auto* eltwise01 =
OpNode("eltwise01", "search_seq_arithmetic")->AsIntermediate();
auto* eltwise01_out =
VarNode("eltwise01_out")
->assert_is_op_output("search_seq_arithmetic", "Out")
->AsOutput();
auto* seq_rev_right0 =
OpNode("seq_rev_right0", "sequence_reverse")->AsIntermediate();
auto* seq_rev_right0_out =
VarNode("seq_rev_right0_out")
->assert_is_op_output("sequence_reverse", "Y")
->AsIntermediate();
auto* grnn_right_wh = VarNode("grnn_right_wh")
->assert_is_op_input("search_grnn", "Wh")
->AsInput();
auto* grnn_right_wi = VarNode("grnn_right_wi")
->assert_is_op_input("search_grnn", "Wi")
->AsInput();
auto* grnn_right = OpNode("grnn_right", "search_grnn")->AsIntermediate();
auto* grnn_right_out = VarNode("grnn_right_out")
->assert_is_op_output("search_grnn", "Out")
->AsIntermediate();
auto* grnn_right_idx_sorted_by_width =
VarNode("grnn_right_idx_sorted_by_width")
->assert_is_op_output("search_grnn", "idx_sorted_by_width")
->AsIntermediate();
auto* grnn_right_layout_input =
VarNode("grnn_right_layout_input")
->assert_is_op_output("search_grnn", "layout_input")
->AsIntermediate();
auto* grnn_right_tmp_buffer =
VarNode("grnn_right_tmp_buffer")
->assert_is_op_output("search_grnn", "tmp_buffer")
->AsIntermediate();
auto* seq_rev_right1 =
OpNode("seq_rev_right1", "sequence_reverse")->AsIntermediate();
auto* seq_rev_right1_out =
VarNode("seq_rev_right1_out")
->assert_is_op_output("sequence_reverse", "Y")
->AsIntermediate();
auto* seq_pool_right =
OpNode("seq_pool_right", "sequence_pool")->AsIntermediate();
auto* seq_pool_right_out = VarNode("seq_pool_right_out")
->assert_is_op_output("sequence_pool", "Out")
->AsOutput();
auto* seq_pool_right_max_idx =
VarNode("seq_pool_right_max_idx")
->assert_is_op_output("sequence_pool", "MaxIndex")
->AsIntermediate();
auto* grnn_left_wh = VarNode("grnn_left_wh")
->assert_is_op_input("search_grnn", "Wh")
->AsInput();
auto* grnn_left_wi = VarNode("grnn_left_wi")
->assert_is_op_input("search_grnn", "Wi")
->AsInput();
auto* grnn_left = OpNode("grnn_left", "search_grnn")->AsIntermediate();
auto* grnn_left_out = VarNode("grnn_left_out")
->assert_is_op_output("search_grnn", "Out")
->AsIntermediate();
auto* grnn_left_idx_sorted_by_width =
VarNode("grnn_left_idx_sorted_by_width")
->assert_is_op_output("search_grnn", "idx_sorted_by_width")
->AsIntermediate();
auto* grnn_left_layout_input =
VarNode("grnn_left_layout_input")
->assert_is_op_output("search_grnn", "layout_input")
->AsIntermediate();
auto* grnn_left_tmp_buffer =
VarNode("grnn_left_tmp_buffer")
->assert_is_op_output("search_grnn", "tmp_buffer")
->AsIntermediate();
auto* seq_pool_left =
OpNode("seq_pool_left", "sequence_pool")->AsIntermediate();
auto* seq_pool_left_out = VarNode("seq_pool_left_out")
->assert_is_op_output("sequence_pool", "Out")
->AsOutput();
auto* seq_pool_left_max_idx =
VarNode("seq_pool_left_max_idx")
->assert_is_op_output("sequence_pool", "MaxIndex")
->AsIntermediate();
auto* concat_2in1 = OpNode("concat_2in1", "concat")->AsIntermediate();
auto* concat_2in1_out = VarNode("concat_2in1_out")
->assert_is_op_output("concat", "Out")
->AsIntermediate();
auto* att_2in1_w =
VarNode("att_2in1_w")
->assert_is_op_input("__xpu__mmdnn_search_attention", "W")
->AsInput();
auto* att_2in1_b =
VarNode("att_2in1_b")
->assert_is_op_input("__xpu__mmdnn_search_attention", "b")
->AsInput();
auto* att_2in1 =
OpNode("att_2in1", "__xpu__mmdnn_search_attention")->AsIntermediate();
auto* att_2in1_out =
VarNode("att_2in1_out")
->assert_is_op_output("__xpu__mmdnn_search_attention", "Out")
->AsIntermediate();
auto* seq_pool_2in1 =
OpNode("seq_pool_2in1", "sequence_pool")->AsIntermediate();
auto* seq_pool_2in1_out = VarNode("seq_pool_2in1_out")
->assert_is_op_output("sequence_pool", "Out")
->AsOutput();
auto* seq_pool_2in1_max_idx =
VarNode("seq_pool_2in1_max_idx")
->assert_is_op_output("sequence_pool", "MaxIndex")
->AsIntermediate();
auto* concat_3in1 = OpNode("concat_3in1", "concat")->AsIntermediate();
auto* concat_3in1_out = VarNode("concat_3in1_out")
->assert_is_op_output("concat", "Out")
->AsOutput();
*input0 >> *emb0 >> *emb0_out >> *eltwise01 >> *eltwise01_out;
*emb_tbl >> *emb0;
*input1 >> *emb1 >> *emb1_out >> *eltwise01;
*emb_tbl >> *emb1;
*eltwise01_out >> *seq_rev_right0 >> *seq_rev_right0_out >> *grnn_right >>
*grnn_right_out >> *seq_rev_right1 >> *seq_rev_right1_out;
*grnn_right_out >> *seq_pool_right >> *seq_pool_right_out;
*seq_pool_right >> *seq_pool_right_max_idx;
*grnn_right_wh >> *grnn_right;
*grnn_right_wi >> *grnn_right;
*grnn_right >> *grnn_right_idx_sorted_by_width;
*grnn_right >> *grnn_right_layout_input;
*grnn_right >> *grnn_right_tmp_buffer;
*eltwise01_out >> *grnn_left >> *grnn_left_out >> *seq_pool_left >>
*seq_pool_left_out;
*seq_pool_left >> *seq_pool_left_max_idx;
*grnn_left_wh >> *grnn_left;
*grnn_left_wi >> *grnn_left;
*grnn_left >> *grnn_left_idx_sorted_by_width;
*grnn_left >> *grnn_left_layout_input;
*grnn_left >> *grnn_left_tmp_buffer;
*seq_rev_right1_out >> *concat_2in1;
*grnn_left_out >> *concat_2in1;
*concat_2in1 >> *concat_2in1_out >> *att_2in1 >> *att_2in1_out >>
*seq_pool_2in1 >> *seq_pool_2in1_out;
*seq_pool_2in1 >> *seq_pool_2in1_max_idx;
*att_2in1_w >> *att_2in1;
*att_2in1_b >> *att_2in1;
*eltwise01_out >> *concat_3in1;
*seq_rev_right1_out >> *concat_3in1;
*grnn_left_out >> *concat_3in1;
*concat_3in1 >> *concat_3in1_out;
}
void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override {
cpp::OpDesc op_desc;
op_desc.SetType("__xpu__mmdnn_bid_emb_grnn_att");
op_desc.SetInput("id0", {matched.at("input0")->arg()->name});
op_desc.SetInput("id1", {matched.at("input1")->arg()->name});
op_desc.SetInput("emb_tbl", {matched.at("emb_tbl")->arg()->name});
op_desc.SetInput("grnn_fw_wh", {matched.at("grnn_left_wh")->arg()->name});
op_desc.SetInput("grnn_fw_wi", {matched.at("grnn_left_wi")->arg()->name});
op_desc.SetInput("grnn_rv_wh", {matched.at("grnn_right_wh")->arg()->name});
op_desc.SetInput("grnn_rv_wi", {matched.at("grnn_right_wi")->arg()->name});
op_desc.SetInput("att_fc_w", {matched.at("att_2in1_w")->arg()->name});
op_desc.SetInput("att_fc_b", {matched.at("att_2in1_b")->arg()->name});
op_desc.SetOutput("grnn_fw_pool_out",
{matched.at("seq_pool_left_out")->arg()->name});
op_desc.SetOutput("grnn_rv_pool_out",
{matched.at("seq_pool_right_out")->arg()->name});
op_desc.SetOutput("att_pool_out",
{matched.at("seq_pool_2in1_out")->arg()->name});
op_desc.SetOutput("concat_3in1_out",
{matched.at("concat_3in1_out")->arg()->name});
op_desc.SetOutput("emb_fw_out", {matched.at("eltwise01_out")->arg()->name});
auto* grnn_fw_op_info = matched.at("grnn_left")->stmt()->op_info();
op_desc.SetAttr<std::vector<float>>(
"grnn_fw_wh_maxs",
grnn_fw_op_info->GetAttr<std::vector<float>>("wh_max"));
op_desc.SetAttr<std::vector<float>>(
"grnn_fw_wi_maxs",
grnn_fw_op_info->GetAttr<std::vector<float>>("wi_max"));
auto* grnn_rv_op_info = matched.at("grnn_right")->stmt()->op_info();
op_desc.SetAttr<std::vector<float>>(
"grnn_rv_wh_maxs",
grnn_rv_op_info->GetAttr<std::vector<float>>("wh_max"));
op_desc.SetAttr<std::vector<float>>(
"grnn_rv_wi_maxs",
grnn_rv_op_info->GetAttr<std::vector<float>>("wi_max"));
auto* att_fc_op_info = matched.at("att_2in1")->stmt()->op_info();
op_desc.SetAttr<float>("att_fc_w_max",
att_fc_op_info->GetAttr<float>("W_max"));
auto* new_stmt = matched.at("emb0")->stmt();
auto new_op = LiteOpRegistry::Global().Create(op_desc.Type());
new_op->Attach(op_desc, new_stmt->op()->scope());
new_op->SetValidPlaces(new_stmt->op()->valid_places());
auto kernels = new_op->CreateKernels(new_op->valid_places());
new_stmt->SetOp(new_op);
new_stmt->SetKernels(std::move(kernels));
std::vector<std::string> arg_names{
"input1",
"grnn_left_wh",
"grnn_left_wi",
"grnn_right_wh",
"grnn_right_wi",
"att_2in1_w",
"att_2in1_b",
};
for (auto name : arg_names) {
DirectedLink(matched.at(name), matched.at("emb0"));
}
std::vector<std::string> out_names{
"seq_pool_left_out",
"seq_pool_right_out",
"seq_pool_2in1_out",
"concat_3in1_out",
"eltwise01_out",
};
for (auto name : out_names) {
IR_OP_VAR_LINK(matched.at("emb0"), matched.at(name));
}
}
};
class XPUMmdnnMergeAllFuser : public FuseBase {
public:
void BuildPattern() override {
auto* concat_7in1_input0 = VarNode("concat_7in1_input0")
->assert_is_op_nth_input("concat", "X", 0)
->AsInput();
auto* concat_7in1_input1 = VarNode("concat_7in1_input1")
->assert_is_op_nth_input("concat", "X", 1)
->AsInput();
auto* concat_7in1_input2 = VarNode("concat_7in1_input2")
->assert_is_op_nth_input("concat", "X", 2)
->AsInput();
auto* concat_7in1_input3 = VarNode("concat_7in1_input3")
->assert_is_op_nth_input("concat", "X", 3)
->AsInput();
auto* concat_7in1_input4 = VarNode("concat_7in1_input4")
->assert_is_op_nth_input("concat", "X", 4)
->AsInput();
auto* concat_7in1_input5 = VarNode("concat_7in1_input5")
->assert_is_op_nth_input("concat", "X", 5)
->AsInput();
auto* concat_7in1_input6 = VarNode("concat_7in1_input6")
->assert_is_op_nth_input("concat", "X", 6)
->AsInput();
auto* concat_7in1 = OpNode("concat_7in1", "concat");
auto* concat_7in1_out = VarNode("concat_7in1_out")
->assert_is_op_output("concat", "Out")
->AsIntermediate();
auto* search_fc0_w = VarNode("search_fc0_w")
->assert_is_op_input("search_fc", "W")
->AsInput();
auto* search_fc0_b = VarNode("search_fc0_b")
->assert_is_op_input("search_fc", "b")
->AsInput();
auto* search_fc0 = OpNode("search_fc0", "search_fc")->AsIntermediate();
auto* search_fc0_out = VarNode("search_fc0_out")
->assert_is_op_output("search_fc", "Out")
->AsIntermediate();
auto* relu0 = OpNode("relu0", "relu")->AsIntermediate();
auto* relu0_out = VarNode("relu0_out")
->assert_is_op_output("relu", "Out")
->AsIntermediate();
auto* concat_2in1_input0 = VarNode("concat_2in1_input0")
->assert_is_op_nth_input("concat", "X", 0)
->AsInput();
auto* concat_2in1_input1 = VarNode("concat_2in1_input1")
->assert_is_op_nth_input("concat", "X", 1)
->AsInput();
auto* concat_2in1 = OpNode("concat_2in1", "concat")->AsIntermediate();
auto* concat_2in1_out = VarNode("concat_2in1_out")
->assert_is_op_output("concat", "Out")
->AsIntermediate();
auto* seq_rev = OpNode("seq_rev", "sequence_reverse")->AsIntermediate();
auto* seq_rev_out = VarNode("seq_rev_out")
->assert_is_op_output("sequence_reverse", "Y")
->AsIntermediate();
auto* grnn_rv_wh = VarNode("grnn_rv_wh")
->assert_is_op_input("search_grnn", "Wh")
->AsInput();
auto* grnn_rv_wi = VarNode("grnn_rv_wi")
->assert_is_op_input("search_grnn", "Wi")
->AsInput();
auto* grnn_rv = OpNode("grnn_rv", "search_grnn")->AsIntermediate();
auto* grnn_rv_out = VarNode("grnn_rv_out")
->assert_is_op_output("search_grnn", "Out")
->AsIntermediate();
auto* grnn_rv_idx_sorted_by_width =
VarNode("grnn_rv_idx_sorted_by_width")
->assert_is_op_output("search_grnn", "idx_sorted_by_width")
->AsIntermediate();
auto* grnn_rv_layout_input =
VarNode("grnn_rv_layout_input")
->assert_is_op_output("search_grnn", "layout_input")
->AsIntermediate();
auto* grnn_rv_tmp_buffer =
VarNode("grnn_rv_tmp_buffer")
->assert_is_op_output("search_grnn", "tmp_buffer")
->AsIntermediate();
auto* seq_pool_rv =
OpNode("seq_pool_rv", "sequence_pool")->AsIntermediate();
auto* seq_pool_rv_out = VarNode("seq_pool_rv_out")
->assert_is_op_output("sequence_pool", "Out")
->AsIntermediate();
auto* seq_pool_rv_max_idx =
VarNode("seq_pool_rv_max_idx")
->assert_is_op_output("sequence_pool", "MaxIndex")
->AsIntermediate();
auto* grnn_fw_wh = VarNode("grnn_fw_wh")
->assert_is_op_input("search_grnn", "Wh")
->AsInput();
auto* grnn_fw_wi = VarNode("grnn_fw_wi")
->assert_is_op_input("search_grnn", "Wi")
->AsInput();
auto* grnn_fw = OpNode("grnn_fw", "search_grnn")->AsIntermediate();
auto* grnn_fw_out = VarNode("grnn_fw_out")
->assert_is_op_output("search_grnn", "Out")
->AsIntermediate();
auto* grnn_fw_idx_sorted_by_width =
VarNode("grnn_fw_idx_sorted_by_width")
->assert_is_op_output("search_grnn", "idx_sorted_by_width")
->AsIntermediate();
auto* grnn_fw_layout_input =
VarNode("grnn_fw_layout_input")
->assert_is_op_output("search_grnn", "layout_input")
->AsIntermediate();
auto* grnn_fw_tmp_buffer =
VarNode("grnn_fw_tmp_buffer")
->assert_is_op_output("search_grnn", "tmp_buffer")
->AsIntermediate();
auto* seq_pool_fw =
OpNode("seq_pool_fw", "sequence_pool")->AsIntermediate();
auto* seq_pool_fw_out = VarNode("seq_pool_fw_out")
->assert_is_op_output("sequence_pool", "Out")
->AsIntermediate();
auto* seq_pool_fw_max_idx =
VarNode("seq_pool_fw_max_idx")
->assert_is_op_output("sequence_pool", "MaxIndex")
->AsIntermediate();
auto* rv_fw_concat = OpNode("rv_fw_concat", "concat")->AsIntermediate();
auto* rv_fw_concat_out = VarNode("rv_fw_concat_out")
->assert_is_op_output("concat", "Out")
->AsIntermediate();
auto* last_concat = OpNode("last_concat", "concat")->AsIntermediate();
auto* last_concat_out = VarNode("last_concat_out")
->assert_is_op_output("concat", "Out")
->AsIntermediate();
auto* search_fc1_w = VarNode("search_fc1_w")
->assert_is_op_input("search_fc", "W")
->AsInput();
auto* search_fc1_b = VarNode("search_fc1_b")
->assert_is_op_input("search_fc", "b")
->AsInput();
auto* search_fc1 = OpNode("search_fc1", "search_fc")->AsIntermediate();
auto* search_fc1_out = VarNode("search_fc1_out")
->assert_is_op_output("search_fc", "Out")
->AsIntermediate();
auto* relu1 = OpNode("relu1", "relu")->AsIntermediate();
auto* relu1_out = VarNode("relu1_out")
->assert_is_op_output("relu", "Out")
->AsIntermediate();
auto* search_fc2_w = VarNode("search_fc2_w")
->assert_is_op_input("search_fc", "W")
->AsInput();
auto* search_fc2_b = VarNode("search_fc2_b")
->assert_is_op_input("search_fc", "b")
->AsInput();
auto* search_fc2 = OpNode("search_fc2", "search_fc")->AsIntermediate();
auto* search_fc2_out = VarNode("search_fc2_out")
->assert_is_op_output("search_fc", "Out")
->AsOutput();
*concat_7in1_input0 >> *concat_7in1;
*concat_7in1_input1 >> *concat_7in1;
*concat_7in1_input2 >> *concat_7in1;
*concat_7in1_input3 >> *concat_7in1;
*concat_7in1_input4 >> *concat_7in1;
*concat_7in1_input5 >> *concat_7in1;
*concat_7in1_input6 >> *concat_7in1;
*concat_7in1 >> *concat_7in1_out >> *search_fc0 >> *search_fc0_out >>
*relu0 >> *relu0_out;
*search_fc0_w >> *search_fc0;
*search_fc0_b >> *search_fc0;
*concat_2in1_input0 >> *concat_2in1;
*concat_2in1_input1 >> *concat_2in1;
*concat_2in1 >> *concat_2in1_out >> *seq_rev >> *seq_rev_out;
*seq_rev_out >> *grnn_rv >> *grnn_rv_out >> *seq_pool_rv >>
*seq_pool_rv_out;
*seq_pool_rv >> *seq_pool_rv_max_idx;
*grnn_rv_wh >> *grnn_rv;
*grnn_rv_wi >> *grnn_rv;
*grnn_rv >> *grnn_rv_idx_sorted_by_width;
*grnn_rv >> *grnn_rv_layout_input;
*grnn_rv >> *grnn_rv_tmp_buffer;
*concat_2in1_out >> *grnn_fw >> *grnn_fw_out >> *seq_pool_fw >>
*seq_pool_fw_out;
*seq_pool_fw >> *seq_pool_fw_max_idx;
*grnn_fw_wh >> *grnn_fw;
*grnn_fw_wi >> *grnn_fw;
*grnn_fw >> *grnn_fw_idx_sorted_by_width;
*grnn_fw >> *grnn_fw_layout_input;
*grnn_fw >> *grnn_fw_tmp_buffer;
*seq_pool_rv_out >> *rv_fw_concat;
*seq_pool_fw_out >> *rv_fw_concat;
*rv_fw_concat >> *rv_fw_concat_out;
*rv_fw_concat_out >> *last_concat;
*relu0_out >> *last_concat;
*last_concat >> *last_concat_out >> *search_fc1 >> *search_fc1_out >>
*relu1 >> *relu1_out >> *search_fc2 >> *search_fc2_out;
*search_fc1_w >> *search_fc1;
*search_fc1_b >> *search_fc1;
*search_fc2_w >> *search_fc2;
*search_fc2_b >> *search_fc2;
}
void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override {
cpp::OpDesc op_desc;
op_desc.SetType("__xpu__mmdnn_merge_all");
auto* concat_7in1_op_info = matched.at("concat_7in1")->stmt()->op_info();
op_desc.SetInput("concat_7in1_x", concat_7in1_op_info->Input("X"));
auto* concat_2in1_op_info = matched.at("concat_2in1")->stmt()->op_info();
op_desc.SetInput("concat_2in1_x", concat_2in1_op_info->Input("X"));
op_desc.SetInput("grnn_fw_wh", {matched.at("grnn_fw_wh")->arg()->name});
op_desc.SetInput("grnn_fw_wi", {matched.at("grnn_fw_wi")->arg()->name});
op_desc.SetInput("grnn_rv_wh", {matched.at("grnn_rv_wh")->arg()->name});
op_desc.SetInput("grnn_rv_wi", {matched.at("grnn_rv_wi")->arg()->name});
op_desc.SetInput("fc0_w", {matched.at("search_fc0_w")->arg()->name});
op_desc.SetInput("fc0_b", {matched.at("search_fc0_b")->arg()->name});
op_desc.SetInput("fc1_w", {matched.at("search_fc1_w")->arg()->name});
op_desc.SetInput("fc1_b", {matched.at("search_fc1_b")->arg()->name});
op_desc.SetInput("fc2_w", {matched.at("search_fc2_w")->arg()->name});
op_desc.SetInput("fc2_b", {matched.at("search_fc2_b")->arg()->name});
op_desc.SetOutput("out", {matched.at("search_fc2_out")->arg()->name});
auto* grnn_fw_op_info = matched.at("grnn_fw")->stmt()->op_info();
op_desc.SetAttr<std::vector<float>>(
"grnn_fw_wh_maxs",
grnn_fw_op_info->GetAttr<std::vector<float>>("wh_max"));
op_desc.SetAttr<std::vector<float>>(
"grnn_fw_wi_maxs",
grnn_fw_op_info->GetAttr<std::vector<float>>("wi_max"));
auto* grnn_rv_op_info = matched.at("grnn_rv")->stmt()->op_info();
op_desc.SetAttr<std::vector<float>>(
"grnn_rv_wh_maxs",
grnn_rv_op_info->GetAttr<std::vector<float>>("wh_max"));
op_desc.SetAttr<std::vector<float>>(
"grnn_rv_wi_maxs",
grnn_rv_op_info->GetAttr<std::vector<float>>("wi_max"));
auto* fc0_op_info = matched.at("search_fc0")->stmt()->op_info();
op_desc.SetAttr<float>("fc0_w_max", fc0_op_info->GetAttr<float>("w_max"));
auto* fc1_op_info = matched.at("search_fc1")->stmt()->op_info();
op_desc.SetAttr<float>("fc1_w_max", fc1_op_info->GetAttr<float>("w_max"));
auto* fc2_op_info = matched.at("search_fc2")->stmt()->op_info();
op_desc.SetAttr<float>("fc2_w_max", fc2_op_info->GetAttr<float>("w_max"));
auto* new_stmt = matched.at("concat_7in1")->stmt();
auto new_op = LiteOpRegistry::Global().Create(op_desc.Type());
new_op->Attach(op_desc, new_stmt->op()->scope());
new_op->SetValidPlaces(new_stmt->op()->valid_places());
auto kernels = new_op->CreateKernels(new_op->valid_places());
new_stmt->SetOp(new_op);
new_stmt->SetKernels(std::move(kernels));
std::vector<std::string> arg_names{
"concat_2in1_input0",
"concat_2in1_input1",
"grnn_fw_wh",
"grnn_fw_wi",
"grnn_rv_wh",
"grnn_rv_wi",
"search_fc0_w",
"search_fc0_b",
"search_fc1_w",
"search_fc1_b",
"search_fc2_w",
"search_fc2_b",
};
for (auto name : arg_names) {
DirectedLink(matched.at(name), matched.at("concat_7in1"));
}
std::vector<std::string> out_names{
"search_fc2_out",
};
for (auto name : out_names) {
IR_OP_VAR_LINK(matched.at("concat_7in1"), matched.at(name));
}
}
};
} // namespace fusion
class XPUMmdnnFusePass : public ProgramPass {
public:
void Apply(const std::unique_ptr<SSAGraph>& graph) override {
if (GetBoolFromEnv("XPU_ENABLE_XTCL")) return;
fusion::XPUMmdnnFloat2Fix float_2_fix;
float_2_fix(graph.get());
fusion::XPUMmdnnSearchAttentionFuser search_att_fuser;
search_att_fuser(graph.get());
fusion::XPUMmdnnMatchConvTopkFuser match_conv_topk_fuser;
match_conv_topk_fuser(graph.get());
fusion::XPUMmdnnBidSeqRevEmbEltwiseFuser bi_seq_rev_emb_eltwise_fuser;
bi_seq_rev_emb_eltwise_fuser(graph.get());
fusion::XPUMmdnnBidEmbGrnnAttFuser bid_emb_grnn_att_fuser;
bid_emb_grnn_att_fuser(graph.get());
fusion::XPUMmdnnBidEmbAttFuser bid_emb_att_fuser;
bid_emb_att_fuser(graph.get());
fusion::XPUMmdnnMergeAllFuser merge_all_fuser;
merge_all_fuser(graph.get());
}
};
} // namespace mir
} // namespace lite
} // namespace paddle
REGISTER_MIR_PASS(__xpu__mmdnn_fuse_pass, paddle::lite::mir::XPUMmdnnFusePass)
.BindTargets({TARGET(kXPU)})
.BindKernel("__xpu__mmdnn_search_attention")
.BindKernel("__xpu__mmdnn_bid_emb_grnn_att")
.BindKernel("__xpu__mmdnn_bid_emb_att")
.BindKernel("__xpu__mmdnn_match_conv_topk")
.BindKernel("__xpu__mmdnn_merge_all");
@@ -639,20 +639,21 @@ class XPUMultiEncoderFusePass : public ProgramPass {
  std::set<int> fc_int31_ids;
#ifdef LITE_WITH_XPU
  // TODO(miaotianxiang): core/mir/*_pass.cc are compiled anyway and need to
-  // access Context<kXPU>::_multi_encoder_precision, but this static member
-  // variable in class specialization defined in lite/core/context.cc
-  // is only compiled iff LITE_WITH_XPU==ON. To suppress linkage error, we use
+  // access TargetWrapperXPU::multi_encoder_precision, but this static member
+  // variable in class specialization defined in
+  // lite/backends/xpu/target_wrapper.cc is only compiled iff
+  // LITE_WITH_XPU==ON. To suppress linkage error, we use
  // #ifdef here. Any better idea?
  if (GetStringFromEnv("XPU_ENCODER_PRECISION", "int16") == "int31" ||
-      lite::Context<TargetType::kXPU>::_multi_encoder_precision == "int31") {
+      lite::TargetWrapperXPU::multi_encoder_precision == "int31") {
    fc_int31_ids = {0, 1, 2, 3, 4, 5};
    VLOG(3) << "Use int31 in XPUMultiEncoderOp, "
-            << "lite::Context<>::_multi_encoder_precision="
-            << lite::Context<TargetType::kXPU>::_multi_encoder_precision;
+            << "lite::TargetWrapperXPU::multi_encoder_precision="
+            << lite::TargetWrapperXPU::multi_encoder_precision;
  } else {
    VLOG(3) << "Use int16 in XPUMultiEncoderOp, "
-            << "lite::Context<>::_multi_encoder_precision="
-            << lite::Context<TargetType::kXPU>::_multi_encoder_precision;
+            << "lite::TargetWrapperXPU::multi_encoder_precision="
+            << lite::TargetWrapperXPU::multi_encoder_precision;
  }
#endif
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <memory>
#include <vector>
#include "lite/backends/xpu/math.h"
#include "lite/core/mir/pass_registry.h"
#include "lite/core/mir/xpu_pattern_matcher_high_api.h"
#include "lite/operators/subgraph_op.h"
namespace paddle {
namespace lite {
namespace mir {
namespace fusion {
class XPUResNetCbamBlock0Fuser : public FuseBase {
public:
XPUResNetCbamBlock0Fuser() {}
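// Matches a CBAM bottleneck block whose shortcut is a conv2d + batch_norm
// projection: conv/bn/relu, conv/bn/relu, conv/bn on the main path, then the
// CBAM spatial-attention tail (reduce_mean & reduce_max -> concat -> conv2d ->
// sigmoid -> elementwise_mul with the reshaped feature map), and finally the
// residual elementwise_add + relu.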
void BuildPattern() override {
auto* input =
VarNode("input")->assert_is_op_input("conv2d", "Input")->AsInput();
auto* left_conv1_weight = VarNode("left_conv1_weight")
->assert_is_op_input("conv2d", "Filter")
->AsInput();
auto* left_conv1 = OpNode("left_conv1", "conv2d");
auto* left_conv1_out = VarNode("left_conv1_out")
->assert_is_op_output("conv2d", "Output")
->assert_is_op_input("batch_norm", "X")
->AsIntermediate();
auto* left_bn1_scale = VarNode("left_bn1_scale")
->assert_is_op_input("batch_norm", "Scale")
->AsIntermediate();
auto* left_bn1_bias = VarNode("left_bn1_bias")
->assert_is_op_input("batch_norm", "Bias")
->AsInput();
auto* left_bn1_mean = VarNode("left_bn1_mean")
->assert_is_op_input("batch_norm", "Mean")
->AsIntermediate();
auto* left_bn1_var = VarNode("left_bn1_variance")
->assert_is_op_input("batch_norm", "Variance")
->AsIntermediate();
auto* left_bn1 = OpNode("left_bn1", "batch_norm")->AsIntermediate();
auto* left_bn1_out = VarNode("left_bn1_out")
->assert_is_op_output("batch_norm", "Y")
->assert_is_op_input("relu", "X")
->AsIntermediate();
auto* left_bn1_mean_out = VarNode("left_bn1_mean_out")
->assert_is_op_output("batch_norm", "MeanOut")
->AsIntermediate();
auto* left_bn1_var_out =
VarNode("left_bn1_var_out")
->assert_is_op_output("batch_norm", "VarianceOut")
->AsIntermediate();
auto* left_bn1_saved_mean =
VarNode("left_bn1_saved_mean")
->assert_is_op_output("batch_norm", "SavedMean")
->AsIntermediate();
auto* left_bn1_saved_var =
VarNode("left_bn1_saved_var")
->assert_is_op_output("batch_norm", "SavedVariance")
->AsIntermediate();
auto* left_relu1 = OpNode("left_relu1", "relu")->AsIntermediate();
auto* left_relu1_out = VarNode("left_relu1_out")
->assert_is_op_output("relu", "Out")
->assert_is_op_input("conv2d", "Input")
->AsIntermediate();
auto* left_conv2_weight = VarNode("left_conv2_weight")
->assert_is_op_input("conv2d", "Filter")
->AsInput();
auto* left_conv2 = OpNode("left_conv2", "conv2d")->AsIntermediate();
auto* left_conv2_out = VarNode("left_conv2_out")
->assert_is_op_output("conv2d", "Output")
->assert_is_op_input("batch_norm", "X")
->AsIntermediate();
auto* left_bn2_scale = VarNode("left_bn2_scale")
->assert_is_op_input("batch_norm", "Scale")
->AsIntermediate();
auto* left_bn2_bias = VarNode("left_bn2_bias")
->assert_is_op_input("batch_norm", "Bias")
->AsInput();
auto* left_bn2_mean = VarNode("left_bn2_mean")
->assert_is_op_input("batch_norm", "Mean")
->AsIntermediate();
auto* left_bn2_var = VarNode("left_bn2_variance")
->assert_is_op_input("batch_norm", "Variance")
->AsIntermediate();
auto* left_bn2 = OpNode("left_bn2", "batch_norm")->AsIntermediate();
auto* left_bn2_out = VarNode("left_bn2_out")
->assert_is_op_output("batch_norm", "Y")
->assert_is_op_input("relu", "X")
->AsIntermediate();
auto* left_bn2_mean_out = VarNode("left_bn2_mean_out")
->assert_is_op_output("batch_norm", "MeanOut")
->AsIntermediate();
auto* left_bn2_var_out =
VarNode("left_bn2_var_out")
->assert_is_op_output("batch_norm", "VarianceOut")
->AsIntermediate();
auto* left_bn2_saved_mean =
VarNode("left_bn2_saved_mean")
->assert_is_op_output("batch_norm", "SavedMean")
->AsIntermediate();
auto* left_bn2_saved_var =
VarNode("left_bn2_saved_var")
->assert_is_op_output("batch_norm", "SavedVariance")
->AsIntermediate();
auto* left_relu2 = OpNode("left_relu2", "relu")->AsIntermediate();
auto* left_relu2_out = VarNode("left_relu2_out")
->assert_is_op_output("relu", "Out")
->assert_is_op_input("conv2d", "Input")
->AsIntermediate();
auto* left_conv3_weight = VarNode("left_conv3_weight")
->assert_is_op_input("conv2d", "Filter")
->AsInput();
auto* left_conv3 = OpNode("left_conv3", "conv2d")->AsIntermediate();
auto* left_conv3_out = VarNode("left_conv3_out")
->assert_is_op_output("conv2d", "Output")
->assert_is_op_input("batch_norm", "X")
->AsIntermediate();
auto* left_bn3_scale = VarNode("left_bn3_scale")
->assert_is_op_input("batch_norm", "Scale")
->AsIntermediate();
auto* left_bn3_bias = VarNode("left_bn3_bias")
->assert_is_op_input("batch_norm", "Bias")
->AsInput();
auto* left_bn3_mean = VarNode("left_bn3_mean")
->assert_is_op_input("batch_norm", "Mean")
->AsIntermediate();
auto* left_bn3_var = VarNode("left_bn3_variance")
->assert_is_op_input("batch_norm", "Variance")
->AsIntermediate();
auto* left_bn3 = OpNode("left_bn3", "batch_norm")->AsIntermediate();
auto* left_bn3_out = VarNode("left_bn3_out")
->assert_is_op_output("batch_norm", "Y")
->AsIntermediate();
auto* left_bn3_mean_out = VarNode("left_bn3_mean_out")
->assert_is_op_output("batch_norm", "MeanOut")
->AsIntermediate();
auto* left_bn3_var_out =
VarNode("left_bn3_var_out")
->assert_is_op_output("batch_norm", "VarianceOut")
->AsIntermediate();
auto* left_bn3_saved_mean =
VarNode("left_bn3_saved_mean")
->assert_is_op_output("batch_norm", "SavedMean")
->AsIntermediate();
auto* left_bn3_saved_var =
VarNode("left_bn3_saved_var")
->assert_is_op_output("batch_norm", "SavedVariance")
->AsIntermediate();
// cbam specific
auto* reduce_mean = OpNode("reduce_mean", "reduce_mean")->AsIntermediate();
auto* reduce_mean_out = VarNode("reduce_mean_out")
->assert_is_op_output("reduce_mean", "Out")
->assert_is_op_input("concat")
->AsIntermediate();
auto* reduce_max = OpNode("reduce_max", "reduce_max")->AsIntermediate();
auto* reduce_max_out = VarNode("reduce_max_out")
->assert_is_op_output("reduce_max", "Out")
->assert_is_op_input("concat")
->AsIntermediate();
auto* concat = OpNode("concat", "concat")->AsIntermediate();
auto* concat_out = VarNode("concat_out")
->assert_is_op_output("concat", "Out")
->assert_is_op_input("conv2d", "Input")
->AsIntermediate();
auto* left_conv4_weight = VarNode("left_conv4_weight")
->assert_is_op_input("conv2d", "Filter")
->AsInput();
auto* left_conv4 = OpNode("left_conv4", "conv2d")->AsIntermediate();
auto* left_conv4_out = VarNode("left_conv4_out")
->assert_is_op_output("conv2d", "Output")
->assert_is_op_input("sigmoid", "X")
->AsIntermediate();
auto* sigmoid = OpNode("sigmoid", "sigmoid")->AsIntermediate();
auto* sigmoid_out = VarNode("sigmoid_out")
->assert_is_op_output("sigmoid", "Out")
->assert_is_op_input("elementwise_mul")
->AsIntermediate();
auto* reshape = OpNode("reshape", "reshape2")->AsIntermediate();
auto* reshape_out = VarNode("reshape_out")
->assert_is_op_output("reshape2", "Out")
->assert_is_op_input("elementwise_mul")
->AsIntermediate();
auto* reshape_xshape = VarNode("reshape_xshape")
->assert_is_op_output("reshape2", "XShape")
->AsIntermediate();
auto* eltwise_mul =
OpNode("eltwise_mul", "elementwise_mul")->AsIntermediate();
auto* eltwise_mul_out = VarNode("eltwise_mul_out")
->assert_is_op_output("elementwise_mul", "Out")
->assert_is_op_input("elementwise_add")
->AsIntermediate();
auto* right_conv1_weight = VarNode("right_conv1_weight")
->assert_is_op_input("conv2d", "Filter")
->AsInput();
auto* right_conv1 = OpNode("right_conv1", "conv2d")->AsIntermediate();
auto* right_conv1_out = VarNode("right_conv1_out")
->assert_is_op_output("conv2d", "Output")
->assert_is_op_input("batch_norm", "X")
->AsIntermediate();
auto* right_bn1_scale = VarNode("right_bn1_scale")
->assert_is_op_input("batch_norm", "Scale")
->AsIntermediate();
auto* right_bn1_bias = VarNode("right_bn1_bias")
->assert_is_op_input("batch_norm", "Bias")
->AsInput();
auto* right_bn1_mean = VarNode("right_bn1_mean")
->assert_is_op_input("batch_norm", "Mean")
->AsIntermediate();
auto* right_bn1_var = VarNode("right_bn1_variance")
->assert_is_op_input("batch_norm", "Variance")
->AsIntermediate();
auto* right_bn1 = OpNode("right_bn1", "batch_norm")->AsIntermediate();
auto* right_bn1_out = VarNode("right_bn1_out")
->assert_is_op_output("batch_norm", "Y")
->assert_is_op_input("elementwise_add")
->AsIntermediate();
auto* right_bn1_mean_out =
VarNode("right_bn1_mean_out")
->assert_is_op_output("batch_norm", "MeanOut")
->AsIntermediate();
auto* right_bn1_var_out =
VarNode("right_bn1_var_out")
->assert_is_op_output("batch_norm", "VarianceOut")
->AsIntermediate();
auto* right_bn1_saved_mean =
VarNode("right_bn1_saved_mean")
->assert_is_op_output("batch_norm", "SavedMean")
->AsIntermediate();
auto* right_bn1_saved_var =
VarNode("right_bn1_saved_var")
->assert_is_op_output("batch_norm", "SavedVariance")
->AsIntermediate();
auto* add = OpNode("add", "elementwise_add")->AsIntermediate();
auto* add_out = VarNode("add_out")
->assert_is_op_output("elementwise_add", "Out")
->assert_is_op_input("relu", "X")
->AsIntermediate();
auto* relu = OpNode("relu", "relu")->AsIntermediate();
auto* relu_out =
VarNode("relu_out")->assert_is_op_output("relu", "Out")->AsOutput();
*input >> *left_conv1 >> *left_conv1_out >> *left_bn1 >> *left_bn1_out >>
*left_relu1 >> *left_relu1_out >> *left_conv2 >> *left_conv2_out >>
*left_bn2 >> *left_bn2_out >> *left_relu2 >> *left_relu2_out >>
*left_conv3 >> *left_conv3_out >> *left_bn3 >>
*left_bn3_out /* >> *add*/;
*left_bn3_out >> *reduce_mean >> *reduce_mean_out >> *concat;
*left_bn3_out >> *reduce_max >> *reduce_max_out >> *concat;
*concat >> *concat_out >> *left_conv4 >> *left_conv4_out >> *sigmoid >>
*sigmoid_out >> *eltwise_mul;
*left_conv4_weight >> *left_conv4;
*left_bn3_out >> *reshape >> *reshape_out >> *eltwise_mul;
*reshape >> *reshape_xshape;
*eltwise_mul >> *eltwise_mul_out >> *add;
*left_conv1_weight >> *left_conv1;
*left_bn1_scale >> *left_bn1;
*left_bn1_bias >> *left_bn1;
*left_bn1_mean >> *left_bn1;
*left_bn1_var >> *left_bn1;
*left_bn1 >> *left_bn1_mean_out;
*left_bn1 >> *left_bn1_var_out;
*left_bn1 >> *left_bn1_saved_mean;
*left_bn1 >> *left_bn1_saved_var;
*left_conv2_weight >> *left_conv2;
*left_bn2_scale >> *left_bn2;
*left_bn2_bias >> *left_bn2;
*left_bn2_mean >> *left_bn2;
*left_bn2_var >> *left_bn2;
*left_bn2 >> *left_bn2_mean_out;
*left_bn2 >> *left_bn2_var_out;
*left_bn2 >> *left_bn2_saved_mean;
*left_bn2 >> *left_bn2_saved_var;
*left_conv3_weight >> *left_conv3;
*left_bn3_scale >> *left_bn3;
*left_bn3_bias >> *left_bn3;
*left_bn3_mean >> *left_bn3;
*left_bn3_var >> *left_bn3;
*left_bn3 >> *left_bn3_mean_out;
*left_bn3 >> *left_bn3_var_out;
*left_bn3 >> *left_bn3_saved_mean;
*left_bn3 >> *left_bn3_saved_var;
*input >> *right_conv1 >> *right_conv1_out >> *right_bn1 >>
*right_bn1_out >> *add;
*right_conv1_weight >> *right_conv1;
*right_bn1_scale >> *right_bn1;
*right_bn1_bias >> *right_bn1;
*right_bn1_mean >> *right_bn1;
*right_bn1_var >> *right_bn1;
*right_bn1 >> *right_bn1_mean_out;
*right_bn1 >> *right_bn1_var_out;
*right_bn1 >> *right_bn1_saved_mean;
*right_bn1 >> *right_bn1_saved_var;
*add >> *add_out >> *relu >> *relu_out;
}
void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override {
cpp::OpDesc op_desc;
op_desc.SetType("resnet_cbam_block0");
op_desc.SetInput("Inputs", {matched.at("input")->arg()->name});
op_desc.SetInput("Filter",
{
matched.at("left_conv1_weight")->arg()->name,
matched.at("left_conv2_weight")->arg()->name,
matched.at("left_conv3_weight")->arg()->name,
matched.at("left_conv4_weight")->arg()->name,
matched.at("right_conv1_weight")->arg()->name,
});
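// "placeholder_sa_conv" fills the Scale/Bias/Mean/Var slots of the CBAM
// spatial-attention conv, which has no batch_norm; XPUResNetCbamFuser keys on
// this name to skip BN folding for that filter (see handle_placeholder_sa_conv).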
op_desc.SetInput("Scale",
{
matched.at("left_bn1_scale")->arg()->name,
matched.at("left_bn2_scale")->arg()->name,
matched.at("left_bn3_scale")->arg()->name,
"placeholder_sa_conv",
matched.at("right_bn1_scale")->arg()->name,
});
op_desc.SetInput("Bias",
{
matched.at("left_bn1_bias")->arg()->name,
matched.at("left_bn2_bias")->arg()->name,
matched.at("left_bn3_bias")->arg()->name,
"placeholder_sa_conv",
matched.at("right_bn1_bias")->arg()->name,
});
op_desc.SetInput("Mean",
{
matched.at("left_bn1_mean")->arg()->name,
matched.at("left_bn2_mean")->arg()->name,
matched.at("left_bn3_mean")->arg()->name,
"placeholder_sa_conv",
matched.at("right_bn1_mean")->arg()->name,
});
op_desc.SetInput("Var",
{
matched.at("left_bn1_variance")->arg()->name,
matched.at("left_bn2_variance")->arg()->name,
matched.at("left_bn3_variance")->arg()->name,
"placeholder_sa_conv",
matched.at("right_bn1_variance")->arg()->name,
});
op_desc.SetOutput("Outputs", {matched.at("relu_out")->arg()->name});
// XXX: keep these to fool SubgraphOp::AttachImpl()
op_desc.SetAttr<int>("sub_block", 0);
op_desc.SetAttr<std::vector<std::string>>("input_data_names", {});
op_desc.SetAttr<std::vector<std::string>>("output_data_names", {});
auto block0_stmt = matched.at("left_conv1")->stmt();
// block0_stmt->ResetOp(op_desc, graph->valid_places());
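// Attach the op_desc to a "subgraph" op instance so AttachImpl() accepts it;
// the op type "resnet_cbam_block0" mainly serves as a marker for the later
// XPUResNetCbamFuser pattern to match.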
auto fake_subgraph_op = LiteOpRegistry::Global().Create("subgraph");
// XXX: memleak?
auto sub_block_desc = new cpp::BlockDesc();
static_cast<operators::SubgraphOp*>(fake_subgraph_op.get())
->SetSubBlock(sub_block_desc);
fake_subgraph_op->Attach(op_desc, block0_stmt->op()->scope());
fake_subgraph_op->SetValidPlaces(block0_stmt->op()->valid_places());
block0_stmt->SetOp(fake_subgraph_op);
std::vector<std::string> froms = {
"left_conv2_weight",
"left_conv3_weight",
"left_conv4_weight",
"right_conv1_weight",
"left_bn1_bias",
"left_bn2_bias",
"left_bn3_bias",
"right_bn1_bias",
};
for (auto& from : froms) {
IR_NODE_LINK_TO(matched.at(from), matched.at("left_conv1"));
}
IR_OP_VAR_LINK(matched.at("left_conv1"), matched.at("relu_out"));
}
};
class XPUResNetCbamBlock1Fuser : public FuseBase {
public:
XPUResNetCbamBlock1Fuser() {}
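// Same structure as XPUResNetCbamBlock0Fuser, but for the bottleneck variant
// with an identity shortcut: the block input feeds the elementwise_add
// directly and there is no projection conv/bn on the shortcut path.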
void BuildPattern() override {
auto* input = VarNode("input")
->assert_is_op_input("conv2d", "Input")
->assert_is_op_input("elementwise_add")
->AsInput();
auto* right_conv1_weight = VarNode("right_conv1_weight")
->assert_is_op_input("conv2d", "Filter")
->AsInput();
auto* right_conv1 = OpNode("right_conv1", "conv2d");
auto* right_conv1_out = VarNode("right_conv1_out")
->assert_is_op_output("conv2d", "Output")
->assert_is_op_input("batch_norm", "X")
->AsIntermediate();
auto* right_bn1_scale = VarNode("right_bn1_scale")
->assert_is_op_input("batch_norm", "Scale")
->AsIntermediate();
auto* right_bn1_bias = VarNode("right_bn1_bias")
->assert_is_op_input("batch_norm", "Bias")
->AsInput();
auto* right_bn1_mean = VarNode("right_bn1_mean")
->assert_is_op_input("batch_norm", "Mean")
->AsIntermediate();
auto* right_bn1_var = VarNode("right_bn1_variance")
->assert_is_op_input("batch_norm", "Variance")
->AsIntermediate();
auto* right_bn1 = OpNode("right_bn1", "batch_norm")->AsIntermediate();
auto* right_bn1_out = VarNode("right_bn1_out")
->assert_is_op_output("batch_norm", "Y")
->assert_is_op_input("relu", "X")
->AsIntermediate();
auto* right_bn1_mean_out =
VarNode("right_bn1_mean_out")
->assert_is_op_output("batch_norm", "MeanOut")
->AsIntermediate();
auto* right_bn1_var_out =
VarNode("right_bn1_var_out")
->assert_is_op_output("batch_norm", "VarianceOut")
->AsIntermediate();
auto* right_bn1_saved_mean =
VarNode("right_bn1_saved_mean")
->assert_is_op_output("batch_norm", "SavedMean")
->AsIntermediate();
auto* right_bn1_saved_var =
VarNode("right_bn1_saved_var")
->assert_is_op_output("batch_norm", "SavedVariance")
->AsIntermediate();
auto* right_relu1 = OpNode("right_relu1", "relu")->AsIntermediate();
auto* right_relu1_out = VarNode("right_relu1_out")
->assert_is_op_output("relu", "Out")
->assert_is_op_input("conv2d", "Input")
->AsIntermediate();
auto* right_conv2_weight = VarNode("right_conv2_weight")
->assert_is_op_input("conv2d", "Filter")
->AsInput();
auto* right_conv2 = OpNode("right_conv2", "conv2d")->AsIntermediate();
auto* right_conv2_out = VarNode("right_conv2_out")
->assert_is_op_output("conv2d", "Output")
->assert_is_op_input("batch_norm", "X")
->AsIntermediate();
auto* right_bn2_scale = VarNode("right_bn2_scale")
->assert_is_op_input("batch_norm", "Scale")
->AsIntermediate();
auto* right_bn2_bias = VarNode("right_bn2_bias")
->assert_is_op_input("batch_norm", "Bias")
->AsInput();
auto* right_bn2_mean = VarNode("right_bn2_mean")
->assert_is_op_input("batch_norm", "Mean")
->AsIntermediate();
auto* right_bn2_var = VarNode("right_bn2_variance")
->assert_is_op_input("batch_norm", "Variance")
->AsIntermediate();
auto* right_bn2 = OpNode("right_bn2", "batch_norm")->AsIntermediate();
auto* right_bn2_out = VarNode("right_bn2_out")
->assert_is_op_output("batch_norm", "Y")
->assert_is_op_input("relu", "X")
->AsIntermediate();
auto* right_bn2_mean_out =
VarNode("right_bn2_mean_out")
->assert_is_op_output("batch_norm", "MeanOut")
->AsIntermediate();
auto* right_bn2_var_out =
VarNode("right_bn2_var_out")
->assert_is_op_output("batch_norm", "VarianceOut")
->AsIntermediate();
auto* right_bn2_saved_mean =
VarNode("right_bn2_saved_mean")
->assert_is_op_output("batch_norm", "SavedMean")
->AsIntermediate();
auto* right_bn2_saved_var =
VarNode("right_bn2_saved_var")
->assert_is_op_output("batch_norm", "SavedVariance")
->AsIntermediate();
auto* right_relu2 = OpNode("right_relu2", "relu")->AsIntermediate();
auto* right_relu2_out = VarNode("right_relu2_out")
->assert_is_op_output("relu", "Out")
->assert_is_op_input("conv2d", "Input")
->AsIntermediate();
auto* right_conv3_weight = VarNode("right_conv3_weight")
->assert_is_op_input("conv2d", "Filter")
->AsInput();
auto* right_conv3 = OpNode("right_conv3", "conv2d")->AsIntermediate();
auto* right_conv3_out = VarNode("right_conv3_out")
->assert_is_op_output("conv2d", "Output")
->assert_is_op_input("batch_norm", "X")
->AsIntermediate();
auto* right_bn3_scale = VarNode("right_bn3_scale")
->assert_is_op_input("batch_norm", "Scale")
->AsIntermediate();
auto* right_bn3_bias = VarNode("right_bn3_bias")
->assert_is_op_input("batch_norm", "Bias")
->AsInput();
auto* right_bn3_mean = VarNode("right_bn3_mean")
->assert_is_op_input("batch_norm", "Mean")
->AsIntermediate();
auto* right_bn3_var = VarNode("right_bn3_variance")
->assert_is_op_input("batch_norm", "Variance")
->AsIntermediate();
auto* right_bn3 = OpNode("right_bn3", "batch_norm")->AsIntermediate();
auto* right_bn3_out = VarNode("right_bn3_out")
->assert_is_op_output("batch_norm", "Y")
->AsIntermediate();
auto* right_bn3_mean_out =
VarNode("right_bn3_mean_out")
->assert_is_op_output("batch_norm", "MeanOut")
->AsIntermediate();
auto* right_bn3_var_out =
VarNode("right_bn3_var_out")
->assert_is_op_output("batch_norm", "VarianceOut")
->AsIntermediate();
auto* right_bn3_saved_mean =
VarNode("right_bn3_saved_mean")
->assert_is_op_output("batch_norm", "SavedMean")
->AsIntermediate();
auto* right_bn3_saved_var =
VarNode("right_bn3_saved_var")
->assert_is_op_output("batch_norm", "SavedVariance")
->AsIntermediate();
// cbam specific
auto* reduce_mean = OpNode("reduce_mean", "reduce_mean")->AsIntermediate();
auto* reduce_mean_out = VarNode("reduce_mean_out")
->assert_is_op_output("reduce_mean", "Out")
->assert_is_op_input("concat")
->AsIntermediate();
auto* reduce_max = OpNode("reduce_max", "reduce_max")->AsIntermediate();
auto* reduce_max_out = VarNode("reduce_max_out")
->assert_is_op_output("reduce_max", "Out")
->assert_is_op_input("concat")
->AsIntermediate();
auto* concat = OpNode("concat", "concat")->AsIntermediate();
auto* concat_out = VarNode("concat_out")
->assert_is_op_output("concat", "Out")
->assert_is_op_input("conv2d", "Input")
->AsIntermediate();
auto* right_conv4_weight = VarNode("right_conv4_weight")
->assert_is_op_input("conv2d", "Filter")
->AsInput();
auto* right_conv4 = OpNode("right_conv4", "conv2d")->AsIntermediate();
auto* right_conv4_out = VarNode("right_conv4_out")
->assert_is_op_output("conv2d", "Output")
->assert_is_op_input("sigmoid", "X")
->AsIntermediate();
auto* sigmoid = OpNode("sigmoid", "sigmoid")->AsIntermediate();
auto* sigmoid_out = VarNode("sigmoid_out")
->assert_is_op_output("sigmoid", "Out")
->assert_is_op_input("elementwise_mul")
->AsIntermediate();
auto* reshape = OpNode("reshape", "reshape2")->AsIntermediate();
auto* reshape_out = VarNode("reshape_out")
->assert_is_op_output("reshape2", "Out")
->assert_is_op_input("elementwise_mul")
->AsIntermediate();
auto* reshape_xshape = VarNode("reshape_xshape")
->assert_is_op_output("reshape2", "XShape")
->AsIntermediate();
auto* eltwise_mul =
OpNode("eltwise_mul", "elementwise_mul")->AsIntermediate();
auto* eltwise_mul_out = VarNode("eltwise_mul_out")
->assert_is_op_output("elementwise_mul", "Out")
->assert_is_op_input("elementwise_add")
->AsIntermediate();
auto* add = OpNode("add", "elementwise_add")->AsIntermediate();
auto* add_out = VarNode("add_out")
->assert_is_op_output("elementwise_add", "Out")
->assert_is_op_input("relu", "X")
->AsIntermediate();
auto* relu = OpNode("relu", "relu")->AsIntermediate();
auto* relu_out =
VarNode("relu_out")->assert_is_op_output("relu", "Out")->AsOutput();
*input >> *right_conv1 >> *right_conv1_out >> *right_bn1 >>
*right_bn1_out >> *right_relu1 >> *right_relu1_out >> *right_conv2 >>
*right_conv2_out >> *right_bn2 >> *right_bn2_out >> *right_relu2 >>
*right_relu2_out >> *right_conv3 >> *right_conv3_out >> *right_bn3 >>
*right_bn3_out /* >> *add*/;
*right_bn3_out >> *reduce_mean >> *reduce_mean_out >> *concat;
*right_bn3_out >> *reduce_max >> *reduce_max_out >> *concat;
*concat >> *concat_out >> *right_conv4 >> *right_conv4_out >> *sigmoid >>
*sigmoid_out >> *eltwise_mul;
*right_conv4_weight >> *right_conv4;
*right_bn3_out >> *reshape >> *reshape_out >> *eltwise_mul;
*reshape >> *reshape_xshape;
*eltwise_mul >> *eltwise_mul_out >> *add;
*right_conv1_weight >> *right_conv1;
*right_bn1_scale >> *right_bn1;
*right_bn1_bias >> *right_bn1;
*right_bn1_mean >> *right_bn1;
*right_bn1_var >> *right_bn1;
*right_bn1 >> *right_bn1_mean_out;
*right_bn1 >> *right_bn1_var_out;
*right_bn1 >> *right_bn1_saved_mean;
*right_bn1 >> *right_bn1_saved_var;
*right_conv2_weight >> *right_conv2;
*right_bn2_scale >> *right_bn2;
*right_bn2_bias >> *right_bn2;
*right_bn2_mean >> *right_bn2;
*right_bn2_var >> *right_bn2;
*right_bn2 >> *right_bn2_mean_out;
*right_bn2 >> *right_bn2_var_out;
*right_bn2 >> *right_bn2_saved_mean;
*right_bn2 >> *right_bn2_saved_var;
*right_conv3_weight >> *right_conv3;
*right_bn3_scale >> *right_bn3;
*right_bn3_bias >> *right_bn3;
*right_bn3_mean >> *right_bn3;
*right_bn3_var >> *right_bn3;
*right_bn3 >> *right_bn3_mean_out;
*right_bn3 >> *right_bn3_var_out;
*right_bn3 >> *right_bn3_saved_mean;
*right_bn3 >> *right_bn3_saved_var;
*input >> *add;
*add >> *add_out >> *relu >> *relu_out;
}
void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override {
cpp::OpDesc op_desc;
op_desc.SetType("resnet_cbam_block1");
op_desc.SetInput("Inputs", {matched.at("input")->arg()->name});
op_desc.SetInput("Filter",
{
matched.at("right_conv1_weight")->arg()->name,
matched.at("right_conv2_weight")->arg()->name,
matched.at("right_conv3_weight")->arg()->name,
matched.at("right_conv4_weight")->arg()->name,
});
op_desc.SetInput("Scale",
{
matched.at("right_bn1_scale")->arg()->name,
matched.at("right_bn2_scale")->arg()->name,
matched.at("right_bn3_scale")->arg()->name,
"placeholder_sa_conv",
});
op_desc.SetInput("Bias",
{
matched.at("right_bn1_bias")->arg()->name,
matched.at("right_bn2_bias")->arg()->name,
matched.at("right_bn3_bias")->arg()->name,
"placeholder_sa_conv",
});
op_desc.SetInput("Mean",
{
matched.at("right_bn1_mean")->arg()->name,
matched.at("right_bn2_mean")->arg()->name,
matched.at("right_bn3_mean")->arg()->name,
"placeholder_sa_conv",
});
op_desc.SetInput("Var",
{
matched.at("right_bn1_variance")->arg()->name,
matched.at("right_bn2_variance")->arg()->name,
matched.at("right_bn3_variance")->arg()->name,
"placeholder_sa_conv",
});
op_desc.SetOutput("Outputs", {matched.at("relu_out")->arg()->name});
// XXX: keep these to fool SubgraphOp::AttachImpl()
op_desc.SetAttr<int>("sub_block", 0);
op_desc.SetAttr<std::vector<std::string>>("input_data_names", {});
op_desc.SetAttr<std::vector<std::string>>("output_data_names", {});
auto block1_stmt = matched.at("right_conv1")->stmt();
auto fake_subgraph_op = LiteOpRegistry::Global().Create("subgraph");
// XXX: memleak?
auto sub_block_desc = new cpp::BlockDesc();
static_cast<operators::SubgraphOp*>(fake_subgraph_op.get())
->SetSubBlock(sub_block_desc);
fake_subgraph_op->Attach(op_desc, block1_stmt->op()->scope());
fake_subgraph_op->SetValidPlaces(block1_stmt->op()->valid_places());
block1_stmt->SetOp(fake_subgraph_op);
std::vector<std::string> froms = {
"right_conv2_weight",
"right_conv3_weight",
"right_conv4_weight",
"right_bn1_bias",
"right_bn2_bias",
"right_bn3_bias",
};
for (auto& from : froms) {
IR_NODE_LINK_TO(matched.at(from), matched.at("right_conv1"));
}
IR_OP_VAR_LINK(matched.at("right_conv1"), matched.at("relu_out"));
}
};
class XPUResNetCbamBlock2Fuser : public FuseBase {
public:
XPUResNetCbamBlock2Fuser() {}
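// Matches the network tail: clip -> elementwise_pow -> pad2d -> pool2d ->
// elementwise_pow (what looks like generalized-mean pooling, with the exponent
// taken from "eltwise_y"), a reshape whose target shape is assembled via
// shape/gather/concat/cast, the final matmul + elementwise_add FC, and a
// normalization of the FC output by its (shifted, scaled) L2 norm.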
void BuildPattern() override {
auto* input = VarNode("input")->assert_is_op_input("clip", "X")->AsInput();
auto* clip = OpNode("clip", "clip");
auto* clip_out = VarNode("clip_out")
->assert_is_op_output("clip", "Out")
->assert_is_op_input("elementwise_pow")
->AsIntermediate();
auto* eltwise_y = VarNode("eltwise_y")
->assert_is_op_input("elementwise_pow")
->assert_is_op_input("elementwise_div")
->AsIntermediate();
auto* eltwise_pow =
OpNode("eltwise_pow", "elementwise_pow")->AsIntermediate();
auto* eltwise_pow_out = VarNode("eltwise_pow_out")
->assert_is_op_output("elementwise_pow", "Out")
->assert_is_op_input("pad2d", "X")
->AsIntermediate();
auto* pad2d = OpNode("pad2d", "pad2d")->AsIntermediate();
auto* pad2d_out = VarNode("pad2d_out")
->assert_is_op_output("pad2d", "Out")
->assert_is_op_input("pool2d", "X")
->AsIntermediate();
auto* pool2d = OpNode("pool2d", "pool2d")->AsIntermediate();
auto* pool2d_out = VarNode("pool2d_out")
->assert_is_op_output("pool2d", "Out")
->assert_is_op_input("elementwise_pow")
->AsIntermediate();
auto* fill_const = OpNode("fill_const", "fill_constant")->AsIntermediate();
auto* fill_const_out = VarNode("fill_const_out")
->assert_is_op_output("fill_constant", "Out")
->assert_is_op_input("elementwise_div")
->AsIntermediate();
auto* eltwise_div =
OpNode("eltwise_div", "elementwise_div")->AsIntermediate();
auto* eltwise_div_out = VarNode("eltwise_div_out")
->assert_is_op_output("elementwise_div", "Out")
->assert_is_op_input("elementwise_pow")
->AsIntermediate();
auto* eltwise_pow2 =
OpNode("eltwise_pow2", "elementwise_pow")->AsIntermediate();
auto* eltwise_pow2_out = VarNode("eltwise_pow2_out")
->assert_is_op_output("elementwise_pow", "Out")
->AsIntermediate();
auto* shape = OpNode("shape", "shape")->AsIntermediate();
auto* shape_out = VarNode("shape_out")
->assert_is_op_output("shape", "Out")
->assert_is_op_input("gather")
->AsIntermediate();
auto* fill_const2 =
OpNode("fill_const2", "fill_constant")->AsIntermediate();
auto* fill_const2_out = VarNode("fill_const2_out")
->assert_is_op_output("fill_constant", "Out")
->assert_is_op_input("gather")
->AsIntermediate();
auto* gather = OpNode("gather", "gather")->AsIntermediate();
auto* gather_out = VarNode("gather_out")
->assert_is_op_output("gather", "Out")
->assert_is_op_input("assign", "X")
->AsIntermediate();
auto* assign = OpNode("assign", "assign")->AsIntermediate();
auto* assign_out = VarNode("assign_out")
->assert_is_op_output("assign", "Out")
->assert_is_op_input("concat")
->AsIntermediate();
auto* fill_const3 =
OpNode("fill_const3", "fill_constant")->AsIntermediate();
auto* fill_const3_out = VarNode("fill_const3_out")
->assert_is_op_output("fill_constant", "Out")
->assert_is_op_input("assign")
->AsIntermediate();
auto* assign2 = OpNode("assign2", "assign")->AsIntermediate();
auto* assign2_out = VarNode("assign2_out")
->assert_is_op_output("assign", "Out")
->assert_is_op_input("concat")
->AsIntermediate();
auto* concat = OpNode("concat", "concat")->AsIntermediate();
auto* concat_out = VarNode("concat_out")
->assert_is_op_output("concat", "Out")
->assert_is_op_input("cast", "X")
->AsIntermediate();
auto* cast = OpNode("cast", "cast")->AsIntermediate();
auto* cast_out = VarNode("cast_out")
->assert_is_op_output("cast", "Out")
->assert_is_op_input("reshape2", "Shape")
->AsIntermediate();
auto* reshape2 = OpNode("reshape2", "reshape2")->AsIntermediate();
auto* reshape2_out = VarNode("reshape2_out")
->assert_is_op_output("reshape2", "Out")
->assert_is_op_input("matmul", "X")
->AsIntermediate();
auto* reshape2_xshape = VarNode("reshape2_xshape")
->assert_is_op_output("reshape2", "XShape")
->AsIntermediate();
auto* matmul_y =
VarNode("matmul_y")->assert_is_op_input("matmul", "Y")->AsInput();
auto* matmul = OpNode("matmul", "matmul")->AsIntermediate();
auto* matmul_out = VarNode("matmul_out")
->assert_is_op_output("matmul", "Out")
->assert_is_op_input("elementwise_add")
->AsIntermediate();
auto* eltwise_add_y = VarNode("eltwise_add_y")
->assert_is_op_input("elementwise_add")
->AsInput();
auto* eltwise_add =
OpNode("eltwise_add", "elementwise_add")->AsIntermediate();
auto* eltwise_add_out = VarNode("eltwise_add_out")
->assert_is_op_output("elementwise_add", "Out")
->AsIntermediate();
auto* norm = OpNode("norm", "norm")->AsIntermediate();
auto* norm_out = VarNode("norm_out")
->assert_is_op_output("norm", "Out")
->assert_is_op_input("elementwise_add")
->AsIntermediate();
auto* norm_norm = VarNode("norm_norm")
->assert_is_op_output("norm", "Norm")
->AsIntermediate();
auto* fill_const4 =
OpNode("fill_const4", "fill_constant")->AsIntermediate();
auto* fill_const4_out = VarNode("fill_const4_out")
->assert_is_op_output("fill_constant", "Out")
->assert_is_op_input("elementwise_add")
->AsIntermediate();
auto* eltwise_add2 =
OpNode("eltwise_add2", "elementwise_add")->AsIntermediate();
auto* eltwise_add2_out = VarNode("eltwise_add2_out")
->assert_is_op_output("elementwise_add", "Out")
->assert_is_op_input("elementwise_mul")
->AsIntermediate();
auto* fill_const5 =
OpNode("fill_const5", "fill_constant")->AsIntermediate();
auto* fill_const5_out = VarNode("fill_const5_out")
->assert_is_op_output("fill_constant", "Out")
->assert_is_op_input("elementwise_mul")
->AsIntermediate();
auto* eltwise_mul =
OpNode("eltwise_mul", "elementwise_mul")->AsIntermediate();
auto* eltwise_mul_out = VarNode("eltwise_mul_out")
->assert_is_op_output("elementwise_mul", "Out")
->assert_is_op_input("elementwise_div")
->AsIntermediate();
auto* eltwise_div2 =
OpNode("eltwise_div2", "elementwise_div")->AsIntermediate();
auto* eltwise_div2_out = VarNode("eltwise_div2_out")
->assert_is_op_output("elementwise_div", "Out")
->AsOutput();
*input >> *clip >> *clip_out >> *eltwise_pow >> *eltwise_pow_out >>
*pad2d >> *pad2d_out >> *pool2d >> *pool2d_out >> *eltwise_pow2;
*eltwise_y >> *eltwise_pow;
*fill_const >> *fill_const_out >> *eltwise_div >> *eltwise_div_out >>
*eltwise_pow2;
*eltwise_y >> *eltwise_div;
*eltwise_pow2 >> *eltwise_pow2_out >> *shape >> *shape_out >> *gather >>
*gather_out >> *assign >> *assign_out >> *concat >> *concat_out >>
*cast >> *cast_out >> *reshape2;
*fill_const2 >> *fill_const2_out >> *gather;
*fill_const3 >> *fill_const3_out >> *assign2 >> *assign2_out >> *concat;
*eltwise_pow2_out >> *reshape2;
*reshape2 >> *reshape2_out >> *matmul >> *matmul_out >> *eltwise_add >>
*eltwise_add_out;
*reshape2 >> *reshape2_xshape;
*matmul_y >> *matmul;
*eltwise_add_y >> *eltwise_add;
*eltwise_add_out >> *norm >> *norm_out >> *eltwise_add2 >>
*eltwise_add2_out >> *eltwise_mul >> *eltwise_mul_out >>
*eltwise_div2 >> *eltwise_div2_out;
*norm >> *norm_norm;
*fill_const4 >> *fill_const4_out >> *eltwise_add2;
*fill_const5 >> *fill_const5_out >> *eltwise_mul;
*eltwise_add_out >> *eltwise_div2;
}
void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override {
cpp::OpDesc op_desc;
op_desc.SetType("resnet_cbam_block2");
op_desc.SetInput("Inputs", {matched.at("input")->arg()->name});
op_desc.SetInput("Filter", {matched.at("matmul_y")->arg()->name});
op_desc.SetInput("Scale", {"placeholder_last_fc"});
op_desc.SetInput("Bias", {matched.at("eltwise_add_y")->arg()->name});
op_desc.SetInput("Mean", {"placeholder_last_fc"});
op_desc.SetInput("Var", {"placeholder_last_fc"});
op_desc.SetOutput("Outputs", {matched.at("eltwise_div2_out")->arg()->name});
// XXX: keep these to fool SubgraphOp::AttachImpl()
op_desc.SetAttr<int>("sub_block", 0);
op_desc.SetAttr<std::vector<std::string>>("input_data_names", {});
op_desc.SetAttr<std::vector<std::string>>("output_data_names", {});
// extra traits to distill
auto block2_stmt = matched.at("clip")->stmt();
auto* scope = block2_stmt->op()->scope();
auto pow_tensor_name = matched.at("eltwise_y")->arg()->name;
auto* pow_tensor = scope->FindTensor(pow_tensor_name);
float pool_p = pow_tensor->data<float>()[0];
op_desc.SetAttr<float>("pool_p", pool_p);
auto* matmul_op_info = matched.at("matmul")->stmt()->op_info();
CHECK(matmul_op_info->GetAttr<bool>("transpose_Y") == true)
<< "Y of last fc must have been transposed";
auto fake_subgraph_op = LiteOpRegistry::Global().Create("subgraph");
// XXX: memleak?
auto sub_block_desc = new cpp::BlockDesc();
static_cast<operators::SubgraphOp*>(fake_subgraph_op.get())
->SetSubBlock(sub_block_desc);
fake_subgraph_op->Attach(op_desc, scope);
fake_subgraph_op->SetValidPlaces(block2_stmt->op()->valid_places());
block2_stmt->SetOp(fake_subgraph_op);
std::vector<std::string> froms = {
"matmul_y", "eltwise_add_y",
};
for (auto& from : froms) {
IR_NODE_LINK_TO(matched.at(from), matched.at("clip"));
}
IR_OP_VAR_LINK(matched.at("clip"), matched.at("eltwise_div2_out"));
}
};
class XPUResNetCbamFuser : public xpu::XPUFuseBase {
public:
XPUResNetCbamFuser() {}
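// Top-level fuser: stitches the stem (conv + bn + relu + pool2d) and the
// already-fused blocks into one __xpu__resnet_cbam op. The block sequence is
// 3 + 4 + 6 + 3 bottlenecks per stage (one projection block0 followed by
// identity block1s), i.e. a ResNet50-style layout, ending in resnet_cbam_block2.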
void BuildPattern() override {
auto* input =
VarNode("input")->assert_is_op_input("conv2d", "Input")->AsInput();
auto* top_conv_weight = VarNode("top_conv_weight")
->assert_is_op_input("conv2d", "Filter")
->AsInput();
auto* top_conv = OpNode("top_conv", "conv2d");
auto* top_conv_out = VarNode("top_conv_out")
->assert_is_op_output("conv2d", "Output")
->assert_is_op_input("batch_norm", "X")
->AsIntermediate();
auto* top_bn_scale = VarNode("top_bn_scale")
->assert_is_op_input("batch_norm", "Scale")
->AsIntermediate();
auto* top_bn_bias = VarNode("top_bn_bias")
->assert_is_op_input("batch_norm", "Bias")
->AsInput();
auto* top_bn_mean = VarNode("top_bn_mean")
->assert_is_op_input("batch_norm", "Mean")
->AsIntermediate();
auto* top_bn_var = VarNode("top_bn_variance")
->assert_is_op_input("batch_norm", "Variance")
->AsIntermediate();
auto* top_bn = OpNode("top_bn", "batch_norm")->AsIntermediate();
auto* top_bn_out = VarNode("top_bn_out")
->assert_is_op_output("batch_norm", "Y")
->assert_is_op_input("relu", "X")
->AsIntermediate();
auto* top_bn_mean_out = VarNode("top_bn_mean_out")
->assert_is_op_output("batch_norm", "MeanOut")
->AsIntermediate();
auto* top_bn_var_out =
VarNode("top_bn_var_out")
->assert_is_op_output("batch_norm", "VarianceOut")
->AsIntermediate();
auto* top_bn_saved_mean =
VarNode("top_bn_saved_mean")
->assert_is_op_output("batch_norm", "SavedMean")
->AsIntermediate();
auto* top_bn_saved_var =
VarNode("top_bn_saved_var")
->assert_is_op_output("batch_norm", "SavedVariance")
->AsIntermediate();
auto* top_relu = OpNode("top_relu", "relu")->AsIntermediate();
auto* top_relu_out = VarNode("top_relu_out")
->assert_is_op_output("relu", "Out")
->assert_is_op_input("pool2d", "X")
->AsIntermediate();
auto* top_pool = OpNode("top_pool", "pool2d")->AsIntermediate();
auto* top_pool_out =
VarNode("top_pool_out")
->assert_is_op_output("pool2d", "Out")
->assert_is_op_input("resnet_cbam_block0", "Inputs")
->AsIntermediate();
// Filter/Scale/Bias/Mean/Var args of the fused block ops are left out of the
// pattern
auto* resnet_block0_1 =
OpNode("resnet_block0_1", "resnet_cbam_block0")->AsIntermediate();
auto* resnet_block0_1_out =
VarNode("resnet_block0_1_out")
->assert_is_op_output("resnet_cbam_block0", "Outputs")
->AsIntermediate();
auto* resnet_block1_1_1 =
OpNode("resnet_block1_1_1", "resnet_cbam_block1")->AsIntermediate();
auto* resnet_block1_1_1_out =
VarNode("resnet_block1_1_1_out")
->assert_is_op_output("resnet_cbam_block1", "Outputs")
->AsIntermediate();
auto* resnet_block1_1_2 =
OpNode("resnet_block1_1_2", "resnet_cbam_block1")->AsIntermediate();
auto* resnet_block1_1_2_out =
VarNode("resnet_block1_1_2_out")
->assert_is_op_output("resnet_cbam_block1", "Outputs")
->AsIntermediate();
auto* resnet_block0_2 =
OpNode("resnet_block0_2", "resnet_cbam_block0")->AsIntermediate();
auto* resnet_block0_2_out =
VarNode("resnet_block0_2_out")
->assert_is_op_output("resnet_cbam_block0", "Outputs")
->AsIntermediate();
auto* resnet_block1_2_1 =
OpNode("resnet_block1_2_1", "resnet_cbam_block1")->AsIntermediate();
auto* resnet_block1_2_1_out =
VarNode("resnet_block1_2_1_out")
->assert_is_op_output("resnet_cbam_block1", "Outputs")
->AsIntermediate();
auto* resnet_block1_2_2 =
OpNode("resnet_block1_2_2", "resnet_cbam_block1")->AsIntermediate();
auto* resnet_block1_2_2_out =
VarNode("resnet_block1_2_2_out")
->assert_is_op_output("resnet_cbam_block1", "Outputs")
->AsIntermediate();
auto* resnet_block1_2_3 =
OpNode("resnet_block1_2_3", "resnet_cbam_block1")->AsIntermediate();
auto* resnet_block1_2_3_out =
VarNode("resnet_block1_2_3_out")
->assert_is_op_output("resnet_cbam_block1", "Outputs")
->AsIntermediate();
auto* resnet_block0_3 =
OpNode("resnet_block0_3", "resnet_cbam_block0")->AsIntermediate();
auto* resnet_block0_3_out =
VarNode("resnet_block0_3_out")
->assert_is_op_output("resnet_cbam_block0", "Outputs")
->AsIntermediate();
auto* resnet_block1_3_1 =
OpNode("resnet_block1_3_1", "resnet_cbam_block1")->AsIntermediate();
auto* resnet_block1_3_1_out =
VarNode("resnet_block1_3_1_out")
->assert_is_op_output("resnet_cbam_block1", "Outputs")
->AsIntermediate();
auto* resnet_block1_3_2 =
OpNode("resnet_block1_3_2", "resnet_cbam_block1")->AsIntermediate();
auto* resnet_block1_3_2_out =
VarNode("resnet_block1_3_2_out")
->assert_is_op_output("resnet_cbam_block1", "Outputs")
->AsIntermediate();
auto* resnet_block1_3_3 =
OpNode("resnet_block1_3_3", "resnet_cbam_block1")->AsIntermediate();
auto* resnet_block1_3_3_out =
VarNode("resnet_block1_3_3_out")
->assert_is_op_output("resnet_cbam_block1", "Outputs")
->AsIntermediate();
auto* resnet_block1_3_4 =
OpNode("resnet_block1_3_4", "resnet_cbam_block1")->AsIntermediate();
auto* resnet_block1_3_4_out =
VarNode("resnet_block1_3_4_out")
->assert_is_op_output("resnet_cbam_block1", "Outputs")
->AsIntermediate();
auto* resnet_block1_3_5 =
OpNode("resnet_block1_3_5", "resnet_cbam_block1")->AsIntermediate();
auto* resnet_block1_3_5_out =
VarNode("resnet_block1_3_5_out")
->assert_is_op_output("resnet_cbam_block1", "Outputs")
->AsIntermediate();
auto* resnet_block0_4 =
OpNode("resnet_block0_4", "resnet_cbam_block0")->AsIntermediate();
auto* resnet_block0_4_out =
VarNode("resnet_block0_4_out")
->assert_is_op_output("resnet_cbam_block0", "Outputs")
->AsIntermediate();
auto* resnet_block1_4_1 =
OpNode("resnet_block1_4_1", "resnet_cbam_block1")->AsIntermediate();
auto* resnet_block1_4_1_out =
VarNode("resnet_block1_4_1_out")
->assert_is_op_output("resnet_cbam_block1", "Outputs")
->AsIntermediate();
auto* resnet_block1_4_2 =
OpNode("resnet_block1_4_2", "resnet_cbam_block1")->AsIntermediate();
auto* resnet_block1_4_2_out =
VarNode("resnet_block1_4_2_out")
->assert_is_op_output("resnet_cbam_block1", "Outputs")
->AsIntermediate();
auto* resnet_block2 =
OpNode("resnet_block2", "resnet_cbam_block2")->AsIntermediate();
auto* resnet_block2_out =
VarNode("resnet_block2_out")
->assert_is_op_output("resnet_cbam_block2", "Outputs")
->AsOutput();
*input >> *top_conv >> *top_conv_out >> *top_bn >> *top_bn_out >>
*top_relu >> *top_relu_out >> *top_pool >> *top_pool_out >>
*resnet_block0_1 >> *resnet_block0_1_out >> *resnet_block1_1_1 >>
*resnet_block1_1_1_out >> *resnet_block1_1_2 >>
*resnet_block1_1_2_out >> *resnet_block0_2 >> *resnet_block0_2_out >>
*resnet_block1_2_1 >> *resnet_block1_2_1_out >> *resnet_block1_2_2 >>
*resnet_block1_2_2_out >> *resnet_block1_2_3 >>
*resnet_block1_2_3_out >> *resnet_block0_3 >> *resnet_block0_3_out >>
*resnet_block1_3_1 >> *resnet_block1_3_1_out >> *resnet_block1_3_2 >>
*resnet_block1_3_2_out >> *resnet_block1_3_3 >>
*resnet_block1_3_3_out >> *resnet_block1_3_4 >>
*resnet_block1_3_4_out >> *resnet_block1_3_5 >>
*resnet_block1_3_5_out >> *resnet_block0_4 >> *resnet_block0_4_out >>
*resnet_block1_4_1 >> *resnet_block1_4_1_out >> *resnet_block1_4_2 >>
*resnet_block1_4_2_out >> *resnet_block2 >> *resnet_block2_out;
*top_conv_weight >> *top_conv;
*top_bn_scale >> *top_bn;
*top_bn_bias >> *top_bn;
*top_bn_mean >> *top_bn;
*top_bn_var >> *top_bn;
*top_bn >> *top_bn_mean_out;
*top_bn >> *top_bn_var_out;
*top_bn >> *top_bn_saved_mean;
*top_bn >> *top_bn_saved_var;
}
void handle_placeholder_sa_conv(SSAGraph* graph,
const key2nodes_t& matched,
paddle::lite::Scope* scope,
const std::string& filter_name,
std::vector<std::string>* max_filter_name) {
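// The spatial-attention conv has no batch_norm to fold, so just quantize the
// filter to int16 in place and record its max-abs value in a freshly created
// 4-float "<filter>_max" weight tensor linked to the fused op.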
auto* filter_t = scope->FindMutableTensor(filter_name);
int filter_len = filter_t->numel();
float* filter_on_host = filter_t->mutable_data<float>();
float max_f =
paddle::lite::xpu::math::FindMaxAbs(filter_on_host, filter_len);
std::unique_ptr<int16_t[]> filter_int16(new int16_t[filter_len]);
paddle::lite::xpu::math::ConvertFP32ToInt16(
filter_on_host, filter_int16.get(), max_f, filter_len);
memcpy(filter_on_host, filter_int16.get(), filter_len * sizeof(int16_t));
// create new arg in graph and scope
std::string max_name = filter_name + "_max";
max_filter_name->push_back(max_name);
auto* max_filter_node = graph->NewArgumentNode(max_name);
max_filter_node->arg()->is_weight = true;
max_filter_node->arg()->type = LiteType::GetTensorTy(
TARGET(kHost), PRECISION(kFloat), DATALAYOUT(kNCHW));
DirectedLink(max_filter_node, matched.at("top_conv"));
auto* max_filter_t = scope->NewTensor(max_name);
max_filter_t->Resize({4});
float* max_ptr = max_filter_t->mutable_data<float>();
max_ptr[0] = max_f;
max_ptr[1] = max_f;
max_ptr[2] = max_f;
max_ptr[3] = max_f;
}
void handle_placeholder_last_fc(SSAGraph* graph,
const key2nodes_t& matched,
paddle::lite::Scope* scope,
const std::string& filter_name,
std::vector<std::string>* max_filter_name) {
auto* filter_t = scope->FindMutableTensor(filter_name);
auto filter_dims = filter_t->dims();
int filter_len = filter_t->numel();
float* filter_on_host = filter_t->mutable_data<float>();
// XXX(miaotianxiang): Y has already been transposed in model...
float max_f =
paddle::lite::xpu::math::FindMaxAbs(filter_on_host, filter_len);
std::unique_ptr<int16_t[]> filter_int16(new int16_t[filter_len]);
paddle::lite::xpu::math::ConvertFP32ToInt16(
filter_on_host, filter_int16.get(), max_f, filter_len);
memcpy(filter_on_host, filter_int16.get(), filter_len * sizeof(int16_t));
// create new arg in graph and scope
std::string max_name = filter_name + "_max";
max_filter_name->push_back(max_name);
auto* max_filter_node = graph->NewArgumentNode(max_name);
max_filter_node->arg()->is_weight = true;
max_filter_node->arg()->type = LiteType::GetTensorTy(
TARGET(kHost), PRECISION(kFloat), DATALAYOUT(kNCHW));
DirectedLink(max_filter_node, matched.at("top_conv"));
auto* max_filter_t = scope->NewTensor(max_name);
max_filter_t->Resize({4});
float* max_ptr = max_filter_t->mutable_data<float>();
max_ptr[0] = max_f;
max_ptr[1] = max_f;
max_ptr[2] = max_f;
max_ptr[3] = max_f;
}
void InsertNewNode(SSAGraph* graph,
const key2nodes_t& matched,
const std::vector<Node*>& extra_input_vars) override {
cpp::OpDesc op_desc;
op_desc.SetType("__xpu__resnet_cbam");
op_desc.SetInput("Input", {matched.at("input")->arg()->name});
std::vector<std::string> filter_name = {
matched.at("top_conv_weight")->arg()->name};
std::vector<std::string> scale_name = {
matched.at("top_bn_scale")->arg()->name};
std::vector<std::string> bias_name = {
matched.at("top_bn_bias")->arg()->name};
std::vector<std::string> mean_name = {
matched.at("top_bn_mean")->arg()->name};
std::vector<std::string> var_name = {
matched.at("top_bn_variance")->arg()->name};
std::vector<std::string> max_filter_name;
std::vector<std::string> resnet_block_vec = {
"resnet_block0_1",
"resnet_block1_1_1",
"resnet_block1_1_2",
"resnet_block0_2",
"resnet_block1_2_1",
"resnet_block1_2_2",
"resnet_block1_2_3",
"resnet_block0_3",
"resnet_block1_3_1",
"resnet_block1_3_2",
"resnet_block1_3_3",
"resnet_block1_3_4",
"resnet_block1_3_5",
"resnet_block0_4",
"resnet_block1_4_1",
"resnet_block1_4_2",
"resnet_block2",
};
for (auto& block : resnet_block_vec) {
auto* block_op_info = matched.at(block)->stmt()->op_info();
auto block_filter_name = block_op_info->Input("Filter");
std::copy(block_filter_name.begin(),
block_filter_name.end(),
std::back_inserter(filter_name));
auto block_scale_name = block_op_info->Input("Scale");
std::copy(block_scale_name.begin(),
block_scale_name.end(),
std::back_inserter(scale_name));
auto block_bias_name = block_op_info->Input("Bias");
std::copy(block_bias_name.begin(),
block_bias_name.end(),
std::back_inserter(bias_name));
auto block_mean_name = block_op_info->Input("Mean");
std::copy(block_mean_name.begin(),
block_mean_name.end(),
std::back_inserter(mean_name));
auto block_var_name = block_op_info->Input("Var");
std::copy(block_var_name.begin(),
block_var_name.end(),
std::back_inserter(var_name));
}
auto* resnet_cbam_stmt = matched.at("top_conv")->stmt();
auto* scope = resnet_cbam_stmt->op()->scope();
for (size_t i = 0; i < filter_name.size(); ++i) {
if (scale_name[i] == "placeholder_sa_conv") {
handle_placeholder_sa_conv(
graph, matched, scope, filter_name[i], &max_filter_name);
continue;
} else if (scale_name[i] == "placeholder_last_fc") {
handle_placeholder_last_fc(
graph, matched, scope, filter_name[i], &max_filter_name);
continue;
}
auto* filter_t = scope->FindMutableTensor(filter_name[i]);
auto* scale_t = scope->FindMutableTensor(scale_name[i]);
auto* bias_t = scope->FindMutableTensor(bias_name[i]);
auto* mean_t = scope->FindMutableTensor(mean_name[i]);
auto* var_t = scope->FindMutableTensor(var_name[i]);
int mean_len = mean_t->numel();
int filter_len = filter_t->numel();
int filter_stride = filter_len / mean_len;
float* filter_on_host = filter_t->mutable_data<float>();
float* scale_on_host = scale_t->mutable_data<float>();
float* bias_on_host = bias_t->mutable_data<float>();
float* mean_on_host = mean_t->mutable_data<float>();
float* var_on_host = var_t->mutable_data<float>();
// Perform preprocess
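// Fold batch_norm into the preceding conv (standard inference-time rewrite):
//   scale' = scale / sqrt(var + eps)
//   filter'[oc][...] = filter[oc][...] * scale'[oc]
//   bias'   = bias - mean * scale'
// so that conv(x; filter') + bias' equals bn(conv(x; filter)).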
for (int i = 0; i < mean_len; ++i) {
scale_on_host[i] = scale_on_host[i] / sqrtf(var_on_host[i] + 0.00001f);
}
for (int i = 0; i < mean_len; ++i) {
for (int j = 0; j < filter_stride; ++j) {
filter_on_host[i * filter_stride + j] *= scale_on_host[i];
}
}
for (int i = 0; i < mean_len; ++i) {
bias_on_host[i] += -mean_on_host[i] * scale_on_host[i];
}
float max_f =
paddle::lite::xpu::math::FindMaxAbs(filter_on_host, filter_len);
std::unique_ptr<int16_t[]> filter_int16(new int16_t[filter_len]);
paddle::lite::xpu::math::ConvertFP32ToInt16(
filter_on_host, filter_int16.get(), max_f, filter_len);
memcpy(filter_on_host, filter_int16.get(), filter_len * sizeof(int16_t));
// create new arg in graph and scope
std::string max_name = filter_name[i] + "_max";
max_filter_name.push_back(max_name);
auto* max_filter_node = graph->NewArgumentNode(max_name);
max_filter_node->arg()->is_weight = true;
max_filter_node->arg()->type = LiteType::GetTensorTy(
TARGET(kHost), PRECISION(kFloat), DATALAYOUT(kNCHW));
DirectedLink(max_filter_node, matched.at("top_conv"));
auto* max_filter_t = scope->NewTensor(max_name);
max_filter_t->Resize({4});
float* max_ptr = max_filter_t->mutable_data<float>();
max_ptr[0] = max_f;
max_ptr[1] = max_f;
max_ptr[2] = max_f;
max_ptr[3] = max_f;
}
op_desc.SetInput("Filter", filter_name);
op_desc.SetInput("Bias", bias_name);
op_desc.SetInput("MaxFilter", max_filter_name);
op_desc.SetOutput("Output", {matched.at("resnet_block2_out")->arg()->name});
op_desc.SetAttr<int>("xpu", 1);
auto* block2_op_info = matched.at("resnet_block2")->stmt()->op_info();
op_desc.SetAttr<float>("pool_p", block2_op_info->GetAttr<float>("pool_p"));
auto resnet_cbam_op = LiteOpRegistry::Global().Create(op_desc.Type());
resnet_cbam_op->Attach(op_desc, scope);
resnet_cbam_op->SetValidPlaces(resnet_cbam_stmt->op()->valid_places());
auto kernels =
resnet_cbam_op->CreateKernels(resnet_cbam_op->valid_places());
resnet_cbam_stmt->SetOp(resnet_cbam_op);
resnet_cbam_stmt->SetKernels(std::move(kernels));
IR_NODE_LINK_TO(matched.at("top_bn_bias"), matched.at("top_conv"));
for (auto* node : extra_input_vars) {
IR_NODE_LINK_TO(node, matched.at("top_conv"));
}
IR_OP_VAR_LINK(matched.at("top_conv"), matched.at("resnet_block2_out"));
}
};
} // namespace fusion
class XPUResNetCbamFusePass : public ProgramPass {
public:
void Apply(const std::unique_ptr<SSAGraph>& graph) override {
if (GetBoolFromEnv("XPU_ENABLE_XTCL")) return;
fusion::XPUResNetCbamBlock0Fuser block0_fuser;
block0_fuser(graph.get());
fusion::XPUResNetCbamBlock1Fuser block1_fuser;
block1_fuser(graph.get());
fusion::XPUResNetCbamBlock2Fuser block2_fuser;
block2_fuser(graph.get());
fusion::XPUResNetCbamFuser resnet_fuser;
resnet_fuser(graph.get());
}
};
} // namespace mir
} // namespace lite
} // namespace paddle
REGISTER_MIR_PASS(__xpu__resnet_cbam_fuse_pass,
paddle::lite::mir::XPUResNetCbamFusePass)
.BindTargets({TARGET(kXPU)})
.BindKernel("__xpu__resnet_cbam");
...@@ -94,6 +94,8 @@ class Optimizer { ...@@ -94,6 +94,8 @@ class Optimizer {
#endif #endif
"identity_dropout_eliminate_pass", "identity_dropout_eliminate_pass",
"__xpu__resnet_fuse_pass", "__xpu__resnet_fuse_pass",
"__xpu__resnet_cbam_fuse_pass",
"__xpu__mmdnn_fuse_pass",
"__xpu__multi_encoder_fuse_pass", "__xpu__multi_encoder_fuse_pass",
"__xpu__embedding_with_eltwise_add_fuse_pass", "__xpu__embedding_with_eltwise_add_fuse_pass",
"__xpu__fc_fuse_pass", "__xpu__fc_fuse_pass",
......
...@@ -157,7 +157,7 @@ void slice_compute(const lite::Tensor* in, ...@@ -157,7 +157,7 @@ void slice_compute(const lite::Tensor* in,
} }
} }
out->mutable_data<float>(lite::TargetType::kX86); out->mutable_data<float>();
auto new_out_dims = out->dims(); auto new_out_dims = out->dims();
auto offsets = Eigen::array<int, D>(); auto offsets = Eigen::array<int, D>();
......
...@@ -6,6 +6,7 @@ if(LITE_WITH_XTCL) ...@@ -6,6 +6,7 @@ if(LITE_WITH_XTCL)
add_subdirectory(bridges) add_subdirectory(bridges)
add_kernel(subgraph_compute_xpu XPU basic SRCS subgraph_compute.cc DEPS ${lite_kernel_deps} device_xpu subgraph_bridge_engine ${xpu_subgraph_bridges}) add_kernel(subgraph_compute_xpu XPU basic SRCS subgraph_compute.cc DEPS ${lite_kernel_deps} device_xpu subgraph_bridge_engine ${xpu_subgraph_bridges})
else() else()
# basic
add_kernel(conv_compute_xpu XPU basic SRCS conv_compute.cc DEPS ${lite_kernel_deps}) add_kernel(conv_compute_xpu XPU basic SRCS conv_compute.cc DEPS ${lite_kernel_deps})
add_kernel(io_copy_compute_xpu XPU basic SRCS io_copy_compute.cc DEPS ${lite_kernel_deps} target_wrapper_xpu) add_kernel(io_copy_compute_xpu XPU basic SRCS io_copy_compute.cc DEPS ${lite_kernel_deps} target_wrapper_xpu)
add_kernel(batch_norm_compute_xpu XPU basic SRCS batch_norm_compute.cc DEPS ${lite_kernel_deps}) add_kernel(batch_norm_compute_xpu XPU basic SRCS batch_norm_compute.cc DEPS ${lite_kernel_deps})
...@@ -15,15 +16,32 @@ else() ...@@ -15,15 +16,32 @@ else()
add_kernel(mul_compute_xpu XPU basic SRCS mul_compute.cc DEPS ${lite_kernel_deps}) add_kernel(mul_compute_xpu XPU basic SRCS mul_compute.cc DEPS ${lite_kernel_deps})
add_kernel(softmax_compute_xpu XPU basic SRCS softmax_compute.cc DEPS ${lite_kernel_deps}) add_kernel(softmax_compute_xpu XPU basic SRCS softmax_compute.cc DEPS ${lite_kernel_deps})
add_kernel(scale_compute_xpu XPU basic SRCS scale_compute.cc DEPS ${lite_kernel_deps}) add_kernel(scale_compute_xpu XPU basic SRCS scale_compute.cc DEPS ${lite_kernel_deps})
add_kernel(lookup_table_compute_xpu XPU basic SRCS lookup_table_compute.cc DEPS ${lite_kernel_deps})
add_kernel(layer_norm_compute_xpu XPU basic SRCS layer_norm_compute.cc DEPS ${lite_kernel_deps})
add_kernel(dropout_compute_xpu XPU basic SRCS dropout_compute.cc DEPS ${lite_kernel_deps}) add_kernel(dropout_compute_xpu XPU basic SRCS dropout_compute.cc DEPS ${lite_kernel_deps})
add_kernel(matmul_compute_xpu XPU basic SRCS matmul_compute.cc DEPS ${lite_kernel_deps}) add_kernel(matmul_compute_xpu XPU basic SRCS matmul_compute.cc DEPS ${lite_kernel_deps})
add_kernel(stack_compute_xpu XPU basic SRCS stack_compute.cc DEPS ${lite_kernel_deps}) add_kernel(stack_compute_xpu XPU basic SRCS stack_compute.cc DEPS ${lite_kernel_deps})
add_kernel(slice_compute_xpu XPU basic SRCS slice_compute.cc DEPS ${lite_kernel_deps}) add_kernel(slice_compute_xpu XPU basic SRCS slice_compute.cc DEPS ${lite_kernel_deps})
add_kernel(cast_compute_xpu XPU basic SRCS cast_compute.cc DEPS ${lite_kernel_deps}) add_kernel(cast_compute_xpu XPU basic SRCS cast_compute.cc DEPS ${lite_kernel_deps})
add_kernel(sequence_topk_avg_pooling_compute_xpu XPU basic SRCS sequence_topk_avg_pooling_compute.cc DEPS ${lite_kernel_deps})
add_kernel(concat_compute_xpu XPU basic SRCS concat_compute.cc DEPS ${lite_kernel_deps})
add_kernel(search_fc_compute_xpu XPU basic SRCS search_fc_compute.cc DEPS ${lite_kernel_deps})
# extra
add_kernel(lookup_table_compute_xpu XPU extra SRCS lookup_table_compute.cc DEPS ${lite_kernel_deps})
add_kernel(layer_norm_compute_xpu XPU extra SRCS layer_norm_compute.cc DEPS ${lite_kernel_deps})
add_kernel(sequence_reverse_compute_xpu XPU extra SRCS sequence_reverse_compute.cc DEPS ${lite_kernel_deps})
add_kernel(sequence_concat_compute_xpu XPU extra SRCS sequence_concat_compute.cc DEPS ${lite_kernel_deps})
add_kernel(sequence_arithmetic_compute_xpu XPU extra SRCS sequence_arithmetic_compute.cc DEPS ${lite_kernel_deps})
add_kernel(sequence_pool_compute_xpu XPU extra SRCS sequence_pool_compute.cc DEPS ${lite_kernel_deps})
add_kernel(match_matrix_tensor_compute_xpu XPU extra SRCS match_matrix_tensor_compute.cc DEPS ${lite_kernel_deps})
add_kernel(var_conv_2d_compute_xpu XPU extra SRCS var_conv_2d_compute.cc DEPS ${lite_kernel_deps})
add_kernel(search_grnn_compute_xpu XPU extra SRCS search_grnn_compute.cc DEPS ${lite_kernel_deps})
# extra(fused kernel)
add_kernel(__xpu__resnet50_compute_xpu XPU extra SRCS __xpu__resnet50_compute.cc DEPS ${lite_kernel_deps}) add_kernel(__xpu__resnet50_compute_xpu XPU extra SRCS __xpu__resnet50_compute.cc DEPS ${lite_kernel_deps})
add_kernel(__xpu__resnet_cbam_compute_xpu XPU extra SRCS __xpu__resnet_cbam_compute.cc DEPS ${lite_kernel_deps})
add_kernel(__xpu__multi_encoder_compute_xpu XPU extra SRCS __xpu__multi_encoder_compute.cc DEPS ${lite_kernel_deps}) add_kernel(__xpu__multi_encoder_compute_xpu XPU extra SRCS __xpu__multi_encoder_compute.cc DEPS ${lite_kernel_deps})
add_kernel(__xpu__embedding_with_eltwise_add_compute_xpu XPU extra SRCS __xpu__embedding_with_eltwise_add_compute.cc DEPS ${lite_kernel_deps}) add_kernel(__xpu__embedding_with_eltwise_add_compute_xpu XPU extra SRCS __xpu__embedding_with_eltwise_add_compute.cc DEPS ${lite_kernel_deps})
add_kernel(__xpu__fc_compute_xpu XPU extra SRCS __xpu__fc_compute.cc DEPS ${lite_kernel_deps}) add_kernel(__xpu__fc_compute_xpu XPU extra SRCS __xpu__fc_compute.cc DEPS ${lite_kernel_deps})
add_kernel(__xpu__search_attention_compute_xpu XPU extra SRCS __xpu__search_attention_compute.cc DEPS ${lite_kernel_deps})
add_kernel(__xpu__mmdnn_compute_xpu XPU extra SRCS __xpu__mmdnn_compute.cc DEPS ${lite_kernel_deps})
endif()
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <memory>
#include <vector>
#include "lite/backends/xpu/xpu_header_sitter.h"
#include "lite/core/kernel.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace xpu {
namespace {
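// Writes a single max value (padded to the 4-float layout expected by the
// XPU maxptr APIs) from host to device memory.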
void FillMax(float max, float* xpu_ptr) {
float maxs[4] = {max, 0.0f, 0.0f, 0.0f};
xpu_memcpy(
xpu_ptr, maxs, 4 * sizeof(float), XPUMemcpyKind::XPU_HOST_TO_DEVICE);
}
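// Computes the time-major batch layout used by search_seq2batch/search_grnn:
// idx_sorted orders sequences by descending length, and new_offset[i] gives
// the offset of time step i after the sequences have been packed together.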
void GrnnLayout(int batch,
const std::vector<int>& offset,
std::vector<int>* new_offset_ptr,
std::vector<int>* idx_sorted_ptr) {
auto& new_offset = *new_offset_ptr;
auto& idx_sorted = *idx_sorted_ptr;
std::vector<int> width;
width.resize(batch);
new_offset.clear();
idx_sorted.clear();
idx_sorted.resize(batch);
for (int i = 0; i < batch; i++) {
width[i] = offset[i + 1] - offset[i];
idx_sorted[i] = i;
}
std::sort(idx_sorted.data(),
idx_sorted.data() + batch,
[&width](int a, int b) { return width[a] > width[b]; });
int max_width = width[idx_sorted[0]];
new_offset.resize(max_width + 1);
new_offset[0] = 0;
int j = batch - 1;
int last_width = 0;
int sub_row = 0;
int sub_col = 0;
for (int i = 1; i <= max_width;) {
for (int k = j; k >= 0; --k) {
if (width[idx_sorted[k]] > last_width) {
sub_row = width[idx_sorted[k]] - last_width;
sub_col = k + 1;
for (int s = 0; s < sub_row; s++) {
new_offset[i] = new_offset[i - 1] + sub_col;
i++;
}
// move on
last_width = width[idx_sorted[k]];
j = k - 1;
break;
}
}
}
}
} // anonymous namespace
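// Caches per-batch sequence (LoD) information for the two id inputs.
// Update() rebuilds the 64-bit lod, the 32-bit lod, the GRNN batch layout and
// the sort indices on the host, then uploads them to one device buffer with a
// single xpu_memcpy.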
class MMDNNIdInfo {
XPUScratchPadGuard l3_buffer_guard_;
char* l3_buffer_{nullptr};
std::unique_ptr<char[]> cpu_buffer_guard_;
char* cpu_buffer_{nullptr};
public:
const int64_t* id0_64{nullptr};
const int64_t* id1_64{nullptr};
int64_t* lod_64{nullptr};
int* lod_32{nullptr};
int* new_offset_32{nullptr};
int* idx_sorted_32{nullptr};
std::vector<int> lod;
std::vector<int> new_offset;
std::vector<int> idx_sorted;
int batch;
int seqlen_max;
int seqlen_sum;
int seqlen_square_sum;
void Init(int upper_bound_batch, int upper_bound_seqlen) {
int ub_lod_64_size = (upper_bound_batch + 1) * sizeof(int64_t);
int ub_lod_32_size = (upper_bound_batch + 1) * sizeof(int);
int ub_new_offset_32_size = (upper_bound_seqlen + 1) * sizeof(int);
int ub_idx_sorted_32_size = (upper_bound_batch + 1) * sizeof(int);
int total_size = ub_lod_64_size + ub_lod_32_size + ub_new_offset_32_size +
ub_idx_sorted_32_size;
// TODO(miaotianxiang): use l3?
l3_buffer_guard_ = TargetWrapperXPU::MallocScratchPad(total_size, false);
l3_buffer_ = reinterpret_cast<char*>(l3_buffer_guard_->addr_);
cpu_buffer_guard_.reset(new char[total_size]);
cpu_buffer_ = cpu_buffer_guard_.get();
}
void Update(lite::Tensor* id0, lite::Tensor* id1) {
auto& id0_lod = id0->lod()[0];
lod.clear();
for (auto e : id0_lod) {
lod.push_back(e);
}
seqlen_max = 0;
seqlen_sum = 0;
seqlen_square_sum = 0;
batch = lod.size() - 1;
for (int i = 0; i < batch; i++) {
int seqlen = lod[i + 1] - lod[i];
seqlen_max = std::max(seqlen_max, seqlen);
seqlen_sum = seqlen_sum + seqlen;
seqlen_square_sum = seqlen_square_sum + seqlen * seqlen;
}
GrnnLayout(batch, lod, &new_offset, &idx_sorted);
id0_64 = id0->data<int64_t>();
id1_64 = id1->data<int64_t>();
int offset = 0;
lod_64 = reinterpret_cast<int64_t*>(l3_buffer_ + offset);
memcpy(
cpu_buffer_ + offset, id0_lod.data(), id0_lod.size() * sizeof(int64_t));
offset += id0_lod.size() * sizeof(int64_t);
lod_32 = reinterpret_cast<int*>(l3_buffer_ + offset);
memcpy(cpu_buffer_ + offset, lod.data(), lod.size() * sizeof(int));
offset += lod.size() * sizeof(int);
new_offset_32 = reinterpret_cast<int*>(l3_buffer_ + offset);
memcpy(cpu_buffer_ + offset,
new_offset.data(),
new_offset.size() * sizeof(int));
offset += new_offset.size() * sizeof(int);
idx_sorted_32 = reinterpret_cast<int*>(l3_buffer_ + offset);
memcpy(cpu_buffer_ + offset,
idx_sorted.data(),
idx_sorted.size() * sizeof(int));
offset += idx_sorted.size() * sizeof(int);
xpu_memcpy(
l3_buffer_, cpu_buffer_, offset, XPUMemcpyKind::XPU_HOST_TO_DEVICE);
}
};
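// Int16-quantized fully-connected helper. Weights come pre-quantized with a
// known max; the input max is computed via findmax unless the caller supplies
// one, and gemm_int16_maxptr evaluates out = act(in * weight^T + bias).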
class MMDNNFcOp {
const int16_t* weight_{nullptr};
XPUScratchPadGuard weight_max_guard_;
float* weight_max_{nullptr};
const float* bias_{nullptr};
XPUScratchPadGuard in_max_guard_;
float* in_max_{nullptr};
int n_;
int k_;
xdnn::Activation_t::act_enum act_type_;
XPUScratchPadGuard out_max_guard_;
public:
float* out_max{nullptr};
void Init(const int16_t* weight,
float weight_max,
const float* bias,
int n,
int k,
xdnn::Activation_t::act_enum act_type) {
n_ = n;
k_ = k;
act_type_ = act_type;
weight_ = weight;
weight_max_guard_ =
TargetWrapperXPU::MallocScratchPad(4 * sizeof(float), false);
weight_max_ = reinterpret_cast<float*>(weight_max_guard_->addr_);
FillMax(weight_max, weight_max_);
bias_ = bias;
in_max_guard_ =
TargetWrapperXPU::MallocScratchPad(4 * sizeof(float), false);
out_max_guard_ =
TargetWrapperXPU::MallocScratchPad(4 * sizeof(float), false);
in_max_ = reinterpret_cast<float*>(in_max_guard_->addr_);
out_max = reinterpret_cast<float*>(out_max_guard_->addr_);
}
void Init(lite::Tensor* weight,
float weight_max,
lite::Tensor* bias,
int n,
int k,
xdnn::Activation_t::act_enum act_type) {
Init(weight->data<int16_t>(),
weight_max,
bias ? bias->data<float>() : nullptr,
n,
k,
act_type);
}
void Infer(xdnn::Context* ctx,
const float* in,
int m,
float* out,
const float* in_max_by_caller = nullptr) {
if (in_max_by_caller == nullptr) {
xdnn::findmax<float>(ctx, in, m * k_, in_max_);
in_max_by_caller = in_max_;
}
xdnn::gemm_int16_maxptr<float, int16_t, float>(ctx,
false,
true,
m,
n_,
k_,
1.0f,
in,
k_,
weight_,
k_,
0.0f,
out,
n_,
bias_,
act_type_,
in_max_by_caller,
weight_max_,
out_max);
}
};
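// Single-direction GRNN: seq2batch packing, three input projections through
// MMDNNFcOp, the search_grnn recurrence, and batch2seq unpacking. Scratch
// buffers come from L3 when the caller provides enough space.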
class MMDNNGrnnOp {
MMDNNFcOp fc_e2h0_;
MMDNNFcOp fc_e2h1_;
MMDNNFcOp fc_e2h2_;
const int16_t* dense_h2h_{nullptr};
float dense_h2h_max_[3];
XPUScratchPadGuard input_max_guard_;
float* input_max_{nullptr};
XPUScratchPadGuard hbm_buffer_guard_;
float* hbm_buffer_{nullptr};
// require: cap_l * max(cap_e_, cap_h_) * 5
// seq2batch_out: [cap_l, cap_e_]
// fc_e2h_out: [3, cap_l, cap_h_]
// gru_out: [cap_l, cap_h_]
int cap_e_;
int cap_h_;
int max_cap_l_;
public:
void Init(lite::Tensor* wh,
const std::vector<float>& wh_maxs,
lite::Tensor* wi,
const std::vector<float>& wi_maxs,
int cap_e,
int cap_h,
int max_cap_l) {
cap_e_ = cap_e;
cap_h_ = cap_h;
max_cap_l_ = max_cap_l;
// weight
auto* dense_e2h = wi->data<int16_t>();
fc_e2h0_.Init(dense_e2h,
wi_maxs[0],
nullptr,
cap_h_,
cap_e_,
xdnn::Activation_t::LINEAR);
fc_e2h1_.Init(dense_e2h + cap_e_ * cap_h_,
wi_maxs[1],
nullptr,
cap_h_,
cap_e_,
xdnn::Activation_t::LINEAR);
fc_e2h2_.Init(dense_e2h + cap_e_ * cap_h_ * 2,
wi_maxs[2],
nullptr,
cap_h_,
cap_e_,
xdnn::Activation_t::LINEAR);
dense_h2h_ = wh->data<int16_t>();
dense_h2h_max_[0] = wh_maxs[0];
dense_h2h_max_[1] = wh_maxs[1];
dense_h2h_max_[2] = wh_maxs[2];
input_max_guard_ =
TargetWrapperXPU::MallocScratchPad(4 * sizeof(float), false);
input_max_ = reinterpret_cast<float*>(input_max_guard_->addr_);
hbm_buffer_guard_ = TargetWrapperXPU::MallocScratchPad(
5 * std::max(cap_e_, cap_h_) * max_cap_l_ * sizeof(float), false);
hbm_buffer_ = reinterpret_cast<float*>(hbm_buffer_guard_->addr_);
}
void Infer(xdnn::Context* ctx,
const MMDNNIdInfo& sentense,
const float* in,
float* out,
float* l3_buffer = nullptr,
int l3_size = 0) {
int batch = sentense.batch;
int cap_l = sentense.seqlen_sum;
int max_width = sentense.seqlen_max;
int slot_size = cap_l * std::max(cap_e_, cap_h_);
float* seq2batch_out = hbm_buffer_;
float* fc_e2h_out = hbm_buffer_ + 1 * slot_size;
float* gru_out = hbm_buffer_ + 4 * slot_size;
if (l3_size > 0 && l3_size >= 5 * slot_size * sizeof(float)) {
seq2batch_out = l3_buffer;
fc_e2h_out = l3_buffer + 1 * slot_size;
gru_out = l3_buffer + 4 * slot_size;
}
xdnn::search_seq2batch(ctx,
batch,
max_width,
cap_e_,
sentense.idx_sorted_32,
sentense.lod_32,
sentense.new_offset_32,
in,
seq2batch_out);
xdnn::findmax<float>(ctx, in, cap_l * cap_e_, input_max_);
fc_e2h0_.Infer(ctx, seq2batch_out, cap_l, fc_e2h_out, input_max_);
fc_e2h1_.Infer(
ctx, seq2batch_out, cap_l, fc_e2h_out + cap_l * cap_h_, input_max_);
fc_e2h2_.Infer(
ctx, seq2batch_out, cap_l, fc_e2h_out + cap_l * cap_h_ * 2, input_max_);
xdnn::search_grnn<float, int16_t>(ctx,
cap_l,
cap_h_,
cap_e_,
max_width,
sentense.new_offset_32,
fc_e2h_out,
dense_h2h_,
gru_out,
dense_h2h_max_[0],
dense_h2h_max_[1],
dense_h2h_max_[2]);
xdnn::search_batch2seq(ctx,
batch,
max_width,
cap_h_,
sentense.idx_sorted_32,
sentense.lod_32,
sentense.new_offset_32,
gru_out,
out);
}
};
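// Sequence attention pooling: an FC over the input, a batched sequence
// matmul, a per-sequence softmax, a second matmul, and a final max pooling
// over each sequence.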
class MMDNNAttentionOp {
int dim_;
float alpha0_;
float alpha1_;
MMDNNFcOp seqfc_;
XPUScratchPadGuard hbm_buffer_guard_;
float* hbm_buffer_{nullptr};
// require: cap_l * dim_ + seqlen_square_sum
// seqfc_out: [cap_l, dim_]
// batchgemm0_out: [seqlen_square_sum]
// seq_softmax_out: [seqlen_square_sum], reuse of batchgemm0_out
// batchgemm1_out: [cap_l, dim_], reuse of seqfc_out
public:
void Init(lite::Tensor* att_fc_w,
float att_fc_w_max,
lite::Tensor* att_fc_b,
int dim,
int upper_bound_batch,
int upper_bound_seqlen) {
dim_ = dim;
alpha0_ = 0.0883883461356163f; // TODO(miaotianxiang):
alpha1_ = 1.0f;
seqfc_.Init(att_fc_w,
att_fc_w_max,
att_fc_b,
dim_,
dim_,
xdnn::Activation_t::LINEAR);
hbm_buffer_guard_ = TargetWrapperXPU::MallocScratchPad(
(upper_bound_batch * (upper_bound_seqlen * dim_ +
upper_bound_seqlen * upper_bound_seqlen)) *
sizeof(float),
false);
hbm_buffer_ = reinterpret_cast<float*>(hbm_buffer_guard_->addr_);
}
void Infer(xdnn::Context* ctx,
const MMDNNIdInfo& sentense,
const float* input,
float* pool_out,
float* l3_buffer = nullptr,
int l3_size = 0) {
int batch = sentense.batch;
int cap_l = sentense.seqlen_sum;
int max_width = sentense.seqlen_max;
int* lod_32 = sentense.lod_32;
float* seqfc_out = hbm_buffer_;
float* batchgemm0_out = hbm_buffer_ + cap_l * dim_;
float* seq_softmax_out = batchgemm0_out;
float* batchgemm1_out = seqfc_out;
if (l3_size > 0 &&
l3_size >=
(cap_l * dim_ + sentense.seqlen_square_sum) * sizeof(float)) {
seqfc_out = l3_buffer;
batchgemm0_out = l3_buffer + cap_l * dim_;
seq_softmax_out = batchgemm0_out;
batchgemm1_out = seqfc_out;
}
seqfc_.Infer(ctx, input, cap_l, seqfc_out);
xdnn::search_noaligned_mat_mul(ctx,
0,
1,
batch,
lod_32,
max_width,
dim_,
alpha0_,
input,
seqfc_out,
batchgemm0_out);
xdnn::search_seq_softmax(
ctx, batchgemm0_out, seq_softmax_out, lod_32, batch, max_width);
xdnn::search_noaligned_mat_mul(ctx,
0,
0,
batch,
lod_32,
max_width,
dim_,
alpha1_,
seq_softmax_out,
input,
batchgemm1_out);
xdnn::sequence_pooling_forward(ctx,
xdnn::Pooling_t::MAX_WITHOUT_INDEX,
batch,
lod_32,
dim_,
batchgemm1_out,
nullptr,
pool_out);
}
};
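// Match-matrix between the left and right sequences, followed by a variable
// convolution, a sequence concat of both results, and top-k average pooling.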
class MMDNNMatchConvTopk {
std::vector<int> topks_;
int dim_t_;
int dim_in_;
int out_channel_;
MMDNNFcOp xw_fc_;
const int16_t* conv_weight_{nullptr};
float conv_weight_max_;
XPUScratchPadGuard hbm_buffer_guard_;
float* hbm_buffer_{nullptr};
// xw_out: [sum(left_len), dim_t_ * dim_in_]
// xwy_out: [sum(left_len * right_len) * dim_t_]
// conv_out: [sum(left_len * right_len) * out_channel_]
// seq_concat_out: [sum(left_len * right_len) * (dim_t_ + out_channel_)]
XPUScratchPadGuard left_lod_32_guard_;
int* left_lod_32_{nullptr};
XPUScratchPadGuard right_lod_32_guard_;
int* right_lod_32_{nullptr};
XPUScratchPadGuard match_lod_32_guard_;
int* match_lod_32_{nullptr};
XPUScratchPadGuard conv_lod_32_guard_;
int* conv_lod_32_{nullptr};
XPUScratchPadGuard topk_offset_32_guard_;
int* topk_offset_32_{nullptr};
XPUScratchPadGuard topks_xpu_guard_;
int* topks_xpu_{nullptr};
XPUScratchPadGuard useless_topk_pos_guard_;
int* useless_topk_pos_{nullptr};
public:
float* seq_avg_topk_out{nullptr};
void Init(lite::Tensor* input_w,
float input_w_max,
lite::Tensor* conv_w,
float conv_w_max,
int dim_t,
int dim_in,
int upper_bound_batch,
int upper_bound_seqlen,
const std::vector<int>& topks) {
dim_t_ = dim_t;
dim_in_ = dim_in;
out_channel_ = 5; // TODO(miaotianxiang):
topks_ = topks;
xw_fc_.Init(input_w,
input_w_max,
nullptr,
dim_t_ * dim_in_,
dim_in_,
xdnn::Activation_t::LINEAR);
conv_weight_ = conv_w->data<int16_t>();
conv_weight_max_ = conv_w_max;
hbm_buffer_guard_ = TargetWrapperXPU::MallocScratchPad(
(upper_bound_batch * upper_bound_seqlen * dim_t_ * dim_in_ +
upper_bound_batch * upper_bound_seqlen * upper_bound_seqlen *
(dim_t_ + out_channel_) * 2) *
sizeof(float),
false);
hbm_buffer_ = reinterpret_cast<float*>(hbm_buffer_guard_->addr_);
left_lod_32_guard_ = TargetWrapperXPU::MallocScratchPad(
(upper_bound_batch + 1) * sizeof(int), false);
left_lod_32_ = reinterpret_cast<int*>(left_lod_32_guard_->addr_);
right_lod_32_guard_ = TargetWrapperXPU::MallocScratchPad(
(upper_bound_batch + 1) * sizeof(int), false);
right_lod_32_ = reinterpret_cast<int*>(right_lod_32_guard_->addr_);
match_lod_32_guard_ = TargetWrapperXPU::MallocScratchPad(
(upper_bound_batch + 1) * sizeof(int), false);
match_lod_32_ = reinterpret_cast<int*>(match_lod_32_guard_->addr_);
conv_lod_32_guard_ = TargetWrapperXPU::MallocScratchPad(
(upper_bound_batch + 1) * sizeof(int), false);
conv_lod_32_ = reinterpret_cast<int*>(conv_lod_32_guard_->addr_);
topk_offset_32_guard_ = TargetWrapperXPU::MallocScratchPad(
(upper_bound_batch + 1) * sizeof(int), false);
topk_offset_32_ = reinterpret_cast<int*>(topk_offset_32_guard_->addr_);
topks_xpu_guard_ =
TargetWrapperXPU::MallocScratchPad(topks_.size() * sizeof(int), false);
topks_xpu_ = reinterpret_cast<int*>(topks_xpu_guard_->addr_);
xpu_memcpy(topks_xpu_,
topks_.data(),
topks_.size() * sizeof(int),
XPUMemcpyKind::XPU_HOST_TO_DEVICE);
useless_topk_pos_guard_ =
TargetWrapperXPU::MallocScratchPad(4 * sizeof(int), false);
useless_topk_pos_ = reinterpret_cast<int*>(useless_topk_pos_guard_->addr_);
}
void Infer(xdnn::Context* ctx,
lite::Tensor* left,
lite::Tensor* right,
lite::Tensor* out,
float* l3_buffer = nullptr,
int l3_size = 0) {
auto left_lod = left->lod()[0];
auto right_lod = right->lod()[0];
int batch = left_lod.size() - 1;
std::vector<int> left_lod_32_cpu;
for (auto e : left_lod) {
left_lod_32_cpu.push_back(e);
}
xpu_memcpy(left_lod_32_,
left_lod_32_cpu.data(),
left_lod_32_cpu.size() * sizeof(int),
XPUMemcpyKind::XPU_HOST_TO_DEVICE);
std::vector<int> right_lod_32_cpu;
for (auto e : right_lod) {
right_lod_32_cpu.push_back(e);
}
xpu_memcpy(right_lod_32_,
right_lod_32_cpu.data(),
right_lod_32_cpu.size() * sizeof(int),
XPUMemcpyKind::XPU_HOST_TO_DEVICE);
std::vector<int> lod_match = {0};
std::vector<int> lod_conv = {0};
std::vector<int> lod_topk = {0};
int x_mul_y_sum = 0;
int left_seqlen_sum = 0;
int left_seqlen_max = 0;
int right_seqlen_sum = 0;
int right_seqlen_max = 0;
for (int i = 0; i < batch; i++) {
int len_x = left_lod[i + 1] - left_lod[i];
int len_y = right_lod[i + 1] - right_lod[i];
int imgsize = len_x * len_y;
x_mul_y_sum = x_mul_y_sum + imgsize;
lod_match.push_back(lod_match.back() + imgsize * dim_t_);
lod_conv.push_back(lod_conv.back() + imgsize * out_channel_);
lod_topk.push_back(lod_topk.back() + imgsize * (dim_t_ + out_channel_));
left_seqlen_max = std::max(left_seqlen_max, len_x);
right_seqlen_max = std::max(right_seqlen_max, len_y);
left_seqlen_sum += len_x;
right_seqlen_sum += len_y;
}
xpu_memcpy(match_lod_32_,
lod_match.data(),
lod_match.size() * sizeof(int),
XPUMemcpyKind::XPU_HOST_TO_DEVICE);
xpu_memcpy(conv_lod_32_,
lod_conv.data(),
lod_conv.size() * sizeof(int),
XPUMemcpyKind::XPU_HOST_TO_DEVICE);
xpu_memcpy(topk_offset_32_,
lod_topk.data(),
lod_topk.size() * sizeof(int),
XPUMemcpyKind::XPU_HOST_TO_DEVICE);
float* xwy_out = hbm_buffer_;
float* conv_out = hbm_buffer_ + x_mul_y_sum * dim_t_;
float* seq_concat_out = hbm_buffer_ + x_mul_y_sum * (dim_t_ + out_channel_);
float* xw_out = hbm_buffer_ + x_mul_y_sum * (dim_t_ + out_channel_) * 2;
int total_len = x_mul_y_sum * (dim_t_ + out_channel_) * 2 +
left_seqlen_sum * dim_t_ * dim_in_;
if (l3_size > 0 && l3_size >= total_len * sizeof(float)) {
xwy_out = l3_buffer;
conv_out = l3_buffer + x_mul_y_sum * dim_t_;
seq_concat_out = l3_buffer + x_mul_y_sum * (dim_t_ + out_channel_);
xw_out = l3_buffer + x_mul_y_sum * (dim_t_ + out_channel_) * 2;
}
seq_avg_topk_out = out->mutable_data<float>(TARGET(kXPU));
int max_width = std::max(left_seqlen_max, right_seqlen_max);
xw_fc_.Infer(ctx, left->data<float>(), left_seqlen_sum, xw_out);
xdnn::match_matrix_tensor(ctx,
batch,
xw_out,
right->data<float>(),
left_lod_32_,
right_lod_32_,
dim_t_,
dim_in_,
xwy_out,
xw_fc_.out_max,
xdnn::Activation_t::RELU,
max_width);
xdnn::search_varconv<float, int16_t>(
ctx,
batch,
dim_t_,
out_channel_,
5,
5,
1,
1,
xwy_out,
conv_weight_,
right_lod_32_,
left_lod_32_,
conv_out,
conv_weight_max_,
xdnn::Activation_t::RELU); // TODO(miaotianxiang):
xdnn::sequence_concat(ctx,
xwy_out,
match_lod_32_,
conv_out,
conv_lod_32_,
seq_concat_out,
batch);
xdnn::sequence_topk_avg_pooling(ctx,
seq_concat_out,
seq_avg_topk_out,
useless_topk_pos_,
batch,
dim_t_ + out_channel_,
topk_offset_32_,
left_lod_32_,
right_lod_32_,
topks_xpu_,
topks_.size());
}
};
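// Fused bidirectional embedding + GRNN + attention block; the execution-plan
// and alloc-plan comments inside the class describe how the shared scratch
// buffers are reused across the steps.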
class MMDNNBidEmbGrnnAtt {
const float* table_{nullptr};
int table_len_;
int emb_dim_;
int cap_h_;
MMDNNGrnnOp bi_fw_;
MMDNNGrnnOp bi_rv_;
MMDNNAttentionOp att_;
XPUScratchPadGuard hbm_buffer_guard_;
float* hbm_buffer_{nullptr};
// require at least: 4 * cap_l * emb_dim_
// emb_rv: [cap_l, emb_dim_]
// grnn_fw: [cap_l, emb_dim_]
// grnn_rv: [cap_l, emb_dim_]
// grnn_rv_rv: [cap_l, emb_dim_]
// concat_2in: [cap_l, 2 * emb_dim_]
// L3.bi_fw: 5 * cap_l * emb_dim_
// L3.bi_rv: 5 * cap_l * emb_dim_
// L3.att: cap_l * 2 * emb_dim_ + seqlen_square_sum
// execution-plan:
// 1. bid_emb_ew, alloc(emb_rv)
// 2. bi_rv, alloc(grnn_rv)
// 3. free(emb_rv)
// 4. sequence_reverse, alloc(grnn_rv_rv)
// 5. sequence_pooling(grnn_rv)
// 6. free(grnn_rv)
// 7. bi_fw alloc(grnn_fw)
// 8. sequence_pooling(grnn_fw)
// 9. concat_2 alloc(concat_2in)
// 10. concat_3
// 11. att
// alloc-plan:
// [0]: emb_rv, grnn_rv_rv
// [1]: grnn_rv, grnn_fw
// [2, 3]: concat_2in
// [2, 3, 4, 5, 6]: L3.bi_fw, L3.bi_rv
// [4, 5, ..., ?]: L3.att
public:
float* emb_fw{nullptr};
float* concat_3in{nullptr};
float* pool_fw{nullptr};
float* pool_rv{nullptr};
float* att_out{nullptr};
void Init(lite::Tensor* table,
lite::Tensor* fw_wh,
const std::vector<float>& fw_wh_maxs,
lite::Tensor* fw_wi,
const std::vector<float>& fw_wi_maxs,
lite::Tensor* rv_wh,
const std::vector<float>& rv_wh_maxs,
lite::Tensor* rv_wi,
const std::vector<float>& rv_wi_maxs,
lite::Tensor* att_fc_w,
float att_fc_w_max,
lite::Tensor* att_fc_b,
int upper_bound_batch,
int upper_bound_seqlen) {
table_ = table->data<float>();
table_len_ = table->dims()[0];
emb_dim_ = table->dims()[1];
cap_h_ = emb_dim_;
int max_cap_l = upper_bound_batch * upper_bound_seqlen;
bi_fw_.Init(
fw_wh, fw_wh_maxs, fw_wi, fw_wi_maxs, emb_dim_, cap_h_, max_cap_l);
bi_rv_.Init(
rv_wh, rv_wh_maxs, rv_wi, rv_wi_maxs, emb_dim_, cap_h_, max_cap_l);
att_.Init(att_fc_w,
att_fc_w_max,
att_fc_b,
2 * cap_h_,
upper_bound_batch,
upper_bound_seqlen);
hbm_buffer_guard_ = TargetWrapperXPU::MallocScratchPad(
4 * max_cap_l * cap_h_ * sizeof(float), false);
hbm_buffer_ = reinterpret_cast<float*>(hbm_buffer_guard_->addr_);
}
void Infer(xdnn::Context* ctx,
int batch,
const MMDNNIdInfo& sentense,
lite::Tensor* grnn_fw_pool_out,
lite::Tensor* grnn_rv_pool_out,
lite::Tensor* att_pool_out,
lite::Tensor* concat_3in1_out,
lite::Tensor* emb_fw_out,
float* l3_buffer = nullptr,
int l3_size = 0) {
int cap_l = sentense.seqlen_sum;
int slot_len = cap_l * cap_h_;
float* emb_rv = hbm_buffer_;
float* grnn_fw = hbm_buffer_ + slot_len;
float* grnn_rv = hbm_buffer_ + slot_len;
float* grnn_rv_rv = hbm_buffer_;
float* concat_2in = hbm_buffer_ + 2 * slot_len;
if (l3_size > 0 && l3_size >= 4 * slot_len * sizeof(float)) {
emb_rv = l3_buffer;
grnn_fw = l3_buffer + slot_len;
grnn_rv = l3_buffer + slot_len;
grnn_rv_rv = l3_buffer;
}
emb_fw = emb_fw_out->mutable_data<float>(TARGET(kXPU));
concat_3in = concat_3in1_out->mutable_data<float>(TARGET(kXPU));
pool_fw = grnn_fw_pool_out->mutable_data<float>(TARGET(kXPU));
pool_rv = grnn_rv_pool_out->mutable_data<float>(TARGET(kXPU));
att_out = att_pool_out->mutable_data<float>(TARGET(kXPU));
xdnn::search_bid_emb_ew(ctx,
batch,
sentense.lod_64,
sentense.id0_64,
sentense.id1_64,
table_,
table_len_,
emb_dim_,
emb_fw,
emb_rv,
table_len_ - 2,
1);
bi_rv_.Infer(ctx,
sentense,
emb_rv,
grnn_rv,
l3_buffer + 2 * slot_len,
l3_size - 2 * slot_len * sizeof(float));
xdnn::sequence_reverse(
ctx, batch, sentense.lod_32, cap_h_, grnn_rv, grnn_rv_rv);
xdnn::sequence_pooling_forward(ctx,
xdnn::Pooling_t::LAST,
batch,
sentense.lod_32,
cap_h_,
grnn_rv,
nullptr,
pool_rv);
bi_fw_.Infer(ctx,
sentense,
emb_fw,
grnn_fw,
l3_buffer + 2 * slot_len,
l3_size - 2 * slot_len * sizeof(float));
xdnn::sequence_pooling_forward(ctx,
xdnn::Pooling_t::LAST,
batch,
sentense.lod_32,
cap_h_,
grnn_fw,
nullptr,
pool_fw);
const int concat_widths[] = {cap_h_, cap_h_, cap_h_};
const float* concat_ptrs[] = {emb_fw, grnn_fw, grnn_rv_rv};
xdnn::concat<float>(
ctx, cap_l, concat_widths + 1, 2, concat_ptrs + 1, concat_2in);
xdnn::concat<float>(ctx, cap_l, concat_widths, 3, concat_ptrs, concat_3in);
att_.Infer(ctx,
sentense,
concat_2in,
att_out,
l3_buffer + 4 * slot_len,
l3_size - 4 * slot_len * sizeof(float));
}
};
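// Embedding lookup (element-wise add of the two id embeddings) followed by
// attention pooling.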
class MMDNNEmbAtt {
const float* table_{nullptr};
int table_len_;
int emb_dim_;
MMDNNAttentionOp att_;
public:
float* emb_fw{nullptr};
float* att_out{nullptr};
void Init(lite::Tensor* table,
lite::Tensor* att_fc_w,
float att_fc_w_max,
lite::Tensor* att_fc_b,
int upper_bound_batch,
int upper_bound_seqlen) {
table_ = table->data<float>();
table_len_ = table->dims()[0];
emb_dim_ = table->dims()[1];
att_.Init(att_fc_w,
att_fc_w_max,
att_fc_b,
emb_dim_,
upper_bound_batch,
upper_bound_seqlen);
}
void Infer(xdnn::Context* ctx,
int batch,
const MMDNNIdInfo& sentense,
lite::Tensor* att_pool_out,
lite::Tensor* emb_fw_out,
float* l3_buffer = nullptr,
int l3_size = 0) {
emb_fw = emb_fw_out->mutable_data<float>(TARGET(kXPU));
att_out = att_pool_out->mutable_data<float>(TARGET(kXPU));
int cap_l = sentense.lod.back();
const float* emb_tables[] = {table_, table_};
const int64_t* emb_indices[] = {sentense.id0_64, sentense.id1_64};
xdnn::embedding_with_ewadd<float, int64_t, false, false>(ctx,
emb_dim_,
cap_l,
2,
table_len_ - 2,
emb_tables,
emb_indices,
nullptr,
nullptr,
emb_fw);
att_.Infer(ctx, sentense, emb_fw, att_out, l3_buffer, l3_size);
}
};
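// Final merge stage: concatenates the upstream match outputs, runs forward
// and reverse coverage GRNNs with last pooling, then three FC layers that
// produce the single output value per batch item.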
class MMDNNMergeAll {
MMDNNGrnnOp coverage_fw_;
MMDNNGrnnOp coverage_rv_;
int cap_e_;
int cap_h_;
// TODO(miaotianxiang):
const int fc0_k_ = 1152;
const int fc0_n_ = 512;
const int fc1_k_ = 640;
const int fc1_n_ = 320;
const int fc2_k_ = 320;
const int fc2_n_ = 1;
MMDNNFcOp fc0_;
MMDNNFcOp fc1_;
MMDNNFcOp fc2_;
XPUScratchPadGuard hbm_buffer_guard_;
float* hbm_buffer_{nullptr};
// topk_concat_out_fw: [cap_l, cap_e_] <= [cap_l, cap_h_]
// topk_concat_out_rv: [cap_l, cap_e_] <= [cap_l, cap_h_]
// grnn_fw: [cap_l, cap_h_]
// grnn_rv: [cap_l, cap_h_]
// pool_fw: [batch, cap_h_]
// pool_rv: [batch, cap_h_]
// fc0_in: [batch, fc0_k_]
// fc0_out: [batch, fc0_n_]
// fc1_in: [batch, fc1_k_]
// fc1_out: [batch, fc1_n_]
// fc2_out: [batch, fc2_n_]
public:
void Init(lite::Tensor* grnn_fw_wh,
std::vector<float> grnn_fw_wh_maxs,
lite::Tensor* grnn_fw_wi,
std::vector<float> grnn_fw_wi_maxs,
lite::Tensor* grnn_rv_wh,
std::vector<float> grnn_rv_wh_maxs,
lite::Tensor* grnn_rv_wi,
std::vector<float> grnn_rv_wi_maxs,
lite::Tensor* fc0_w,
float fc0_w_max,
lite::Tensor* fc0_b,
lite::Tensor* fc1_w,
float fc1_w_max,
lite::Tensor* fc1_b,
lite::Tensor* fc2_w,
float fc2_w_max,
lite::Tensor* fc2_b,
int upper_bound_batch,
int upper_bound_seqlen) {
int max_cap_l = upper_bound_batch * upper_bound_seqlen;
cap_e_ = grnn_fw_wi->dims()[2];
cap_h_ = grnn_fw_wi->dims()[1];
coverage_fw_.Init(grnn_fw_wh,
grnn_fw_wh_maxs,
grnn_fw_wi,
grnn_fw_wi_maxs,
cap_e_,
cap_h_,
max_cap_l);
coverage_rv_.Init(grnn_rv_wh,
grnn_rv_wh_maxs,
grnn_rv_wi,
grnn_rv_wi_maxs,
cap_e_,
cap_h_,
max_cap_l);
fc0_.Init(
fc0_w, fc0_w_max, fc0_b, fc0_n_, fc0_k_, xdnn::Activation_t::RELU);
fc1_.Init(
fc1_w, fc1_w_max, fc1_b, fc1_n_, fc1_k_, xdnn::Activation_t::RELU);
fc2_.Init(
fc2_w, fc2_w_max, fc2_b, fc2_n_, fc2_k_, xdnn::Activation_t::LINEAR);
int hbm_total_len = max_cap_l * cap_h_ * 4 +
upper_bound_batch * (2 * cap_h_ + fc0_k_ + fc0_n_ +
fc1_k_ + fc1_n_ + fc2_n_);
hbm_buffer_guard_ = TargetWrapperXPU::MallocScratchPad(
hbm_total_len * sizeof(float), false);
hbm_buffer_ = reinterpret_cast<float*>(hbm_buffer_guard_->addr_);
}
void Infer(xdnn::Context* ctx,
const MMDNNIdInfo& sentense,
const std::vector<lite::Tensor*> concat_2in1_x,
const std::vector<lite::Tensor*> concat_7in1_x,
lite::Tensor* out,
float* l3_buffer = nullptr,
int l3_size = 0) {
int batch = sentense.batch;
int cap_l = sentense.seqlen_sum;
float* topk_concat_out_fw = hbm_buffer_;
int hbm_total_len =
cap_l * cap_h_ * 4 +
batch * (2 * cap_h_ + fc0_k_ + fc0_n_ + fc1_k_ + fc1_n_ + fc2_n_);
if (l3_size > 0 && l3_size >= hbm_total_len * sizeof(float)) {
topk_concat_out_fw = l3_buffer;
}
float* topk_concat_out_rv = topk_concat_out_fw + cap_l * cap_h_;
float* grnn_fw = topk_concat_out_rv + cap_l * cap_h_;
float* grnn_rv = grnn_fw + cap_l * cap_h_;
float* pool_fw = grnn_rv + cap_l * cap_h_;
float* pool_rv = pool_fw + batch * cap_h_;
float* fc0_in = pool_fw + batch * cap_h_ * 2;
float* fc0_out = fc0_in + batch * fc0_k_;
float* fc1_in = fc0_out + batch * fc0_n_;
float* fc1_out = fc1_in + batch * fc1_k_;
// float* fc2_out = fc1_out + batch * fc1_n_;
float* fc2_out = out->mutable_data<float>(TARGET(kXPU));
const int concat_widths[] = {static_cast<int>(concat_2in1_x[0]->dims()[1]),
static_cast<int>(concat_2in1_x[1]->dims()[1])};
const float* concat_ptrs[] = {concat_2in1_x[0]->data<float>(),
concat_2in1_x[1]->data<float>()};
xdnn::concat<float>(
ctx, cap_l, concat_widths, 2, concat_ptrs, topk_concat_out_fw);
xdnn::sequence_reverse(ctx,
batch,
sentense.lod_32,
cap_e_,
topk_concat_out_fw,
topk_concat_out_rv);
coverage_fw_.Infer(ctx,
sentense,
topk_concat_out_fw,
grnn_fw,
l3_buffer + hbm_total_len,
l3_size - hbm_total_len * sizeof(float));
coverage_rv_.Infer(ctx,
sentense,
topk_concat_out_rv,
grnn_rv,
l3_buffer + hbm_total_len,
l3_size - hbm_total_len * sizeof(float));
xdnn::sequence_pooling_forward(ctx,
xdnn::Pooling_t::LAST,
batch,
sentense.lod_32,
cap_h_,
grnn_fw,
nullptr,
pool_fw);
xdnn::sequence_pooling_forward(ctx,
xdnn::Pooling_t::LAST,
batch,
sentense.lod_32,
cap_h_,
grnn_rv,
nullptr,
pool_rv);
const int concat_widths_fc0[] = {
static_cast<int>(concat_7in1_x[0]->dims()[1]),
static_cast<int>(concat_7in1_x[1]->dims()[1]),
static_cast<int>(concat_7in1_x[2]->dims()[1]),
static_cast<int>(concat_7in1_x[3]->dims()[1]),
static_cast<int>(concat_7in1_x[4]->dims()[1]),
static_cast<int>(concat_7in1_x[5]->dims()[1]),
static_cast<int>(concat_7in1_x[6]->dims()[1]),
};
const float* concat_ptrs_fc0[] = {
concat_7in1_x[0]->data<float>(),
concat_7in1_x[1]->data<float>(),
concat_7in1_x[2]->data<float>(),
concat_7in1_x[3]->data<float>(),
concat_7in1_x[4]->data<float>(),
concat_7in1_x[5]->data<float>(),
concat_7in1_x[6]->data<float>(),
};
const int concat_widths_fc1[] = {cap_h_, cap_h_, fc0_n_};
const float* concat_ptrs_fc1[] = {pool_fw, pool_rv, fc0_out};
xdnn::concat<float>(
ctx, batch, concat_widths_fc0, 7, concat_ptrs_fc0, fc0_in);
fc0_.Infer(ctx, fc0_in, batch, fc0_out);
xdnn::concat<float>(
ctx, batch, concat_widths_fc1, 3, concat_ptrs_fc1, fc1_in);
fc1_.Infer(ctx, fc1_in, batch, fc1_out);
fc2_.Infer(ctx, fc1_out, batch, fc2_out);
}
};
class XPUMmdnnBidEmbGrnnAttCompute
: public KernelLite<TARGET(kXPU), PRECISION(kFloat)> {
public:
using param_t = operators::XPUMmdnnBidEmbGrnnAttParam;
void PrepareForRun() override;
void Run() override;
private:
MMDNNIdInfo id_;
MMDNNBidEmbGrnnAtt compound_;
int upper_bound_batch_ = 40;
int upper_bound_seqlen_ = 512;
};
void XPUMmdnnBidEmbGrnnAttCompute::PrepareForRun() {
auto& param = this->Param<param_t>();
id_.Init(upper_bound_batch_, upper_bound_seqlen_);
compound_.Init(param.emb_tbl,
param.grnn_fw_wh,
param.grnn_fw_wh_maxs,
param.grnn_fw_wi,
param.grnn_fw_wi_maxs,
param.grnn_rv_wh,
param.grnn_rv_wh_maxs,
param.grnn_rv_wi,
param.grnn_rv_wi_maxs,
param.att_fc_w,
param.att_fc_w_max,
param.att_fc_b,
upper_bound_batch_,
upper_bound_seqlen_);
}
void XPUMmdnnBidEmbGrnnAttCompute::Run() {
auto& param = this->Param<param_t>();
auto& ctx = this->ctx_->As<XPUContext>();
auto* xpu_ctx = ctx.GetRawContext();
int batch = param.id0->lod()[0].size() - 1;
id_.Update(param.id0, param.id1);
compound_.Infer(ctx.GetRawContext(),
batch,
id_,
param.grnn_fw_pool_out,
param.grnn_rv_pool_out,
param.att_pool_out,
param.concat_3in1_out,
param.emb_fw_out,
reinterpret_cast<float*>(
reinterpret_cast<char*>(xpu_ctx->workspace_l3_ptr) +
xpu_ctx->used_l3_size),
xpu_ctx->workspace_l3_size - xpu_ctx->used_l3_size);
}
class XPUMmdnnBidEmbAttCompute
: public KernelLite<TARGET(kXPU), PRECISION(kFloat)> {
public:
using param_t = operators::XPUMmdnnBidEmbAttParam;
void PrepareForRun() override;
void Run() override;
private:
MMDNNIdInfo id_;
MMDNNEmbAtt compound_;
int upper_bound_batch_ = 40;
int upper_bound_seqlen_ = 512;
};
void XPUMmdnnBidEmbAttCompute::PrepareForRun() {
auto& param = this->Param<param_t>();
id_.Init(upper_bound_batch_, upper_bound_seqlen_);
compound_.Init(param.emb_tbl,
param.att_fc_w,
param.att_fc_w_max,
param.att_fc_b,
upper_bound_batch_,
upper_bound_seqlen_);
}
void XPUMmdnnBidEmbAttCompute::Run() {
auto& param = this->Param<param_t>();
auto& ctx = this->ctx_->As<XPUContext>();
auto* xpu_ctx = ctx.GetRawContext();
int batch = param.id0->lod()[0].size() - 1;
id_.Update(param.id0, param.id1);
compound_.Infer(ctx.GetRawContext(),
batch,
id_,
param.att_pool_out,
param.emb_fw_out,
reinterpret_cast<float*>(
reinterpret_cast<char*>(xpu_ctx->workspace_l3_ptr) +
xpu_ctx->used_l3_size),
xpu_ctx->workspace_l3_size - xpu_ctx->used_l3_size);
}
class XPUMmdnnMatchConvTopkCompute
: public KernelLite<TARGET(kXPU), PRECISION(kFloat)> {
public:
using param_t = operators::XPUMmdnnMatchConvTopkParam;
void PrepareForRun() override;
void Run() override;
private:
MMDNNMatchConvTopk compound_;
int upper_bound_batch_ = 40;
int upper_bound_seqlen_ = 512;
};
void XPUMmdnnMatchConvTopkCompute::PrepareForRun() {
auto& param = this->Param<param_t>();
compound_.Init(param.input_w,
param.input_w_max,
param.conv_w,
param.conv_w_max,
param.dim_t,
param.input_w->dims()[0],
upper_bound_batch_,
upper_bound_seqlen_,
param.topks);
}
void XPUMmdnnMatchConvTopkCompute::Run() {
auto& param = this->Param<param_t>();
auto& ctx = this->ctx_->As<XPUContext>();
auto* xpu_ctx = ctx.GetRawContext();
compound_.Infer(ctx.GetRawContext(),
param.input_x,
param.input_y,
param.topk_out,
reinterpret_cast<float*>(
reinterpret_cast<char*>(xpu_ctx->workspace_l3_ptr) +
xpu_ctx->used_l3_size),
xpu_ctx->workspace_l3_size - xpu_ctx->used_l3_size);
}
class XPUMmdnnMergeAllCompute
: public KernelLite<TARGET(kXPU), PRECISION(kFloat)> {
public:
using param_t = operators::XPUMmdnnMergeAllParam;
void PrepareForRun() override;
void Run() override;
private:
MMDNNIdInfo id_;
MMDNNMergeAll compound_;
int upper_bound_batch_ = 40;
int upper_bound_seqlen_ = 512;
};
void XPUMmdnnMergeAllCompute::PrepareForRun() {
auto& param = this->Param<param_t>();
id_.Init(upper_bound_batch_, upper_bound_seqlen_);
compound_.Init(param.grnn_fw_wh,
param.grnn_fw_wh_maxs,
param.grnn_fw_wi,
param.grnn_fw_wi_maxs,
param.grnn_rv_wh,
param.grnn_rv_wh_maxs,
param.grnn_rv_wi,
param.grnn_rv_wi_maxs,
param.fc0_w,
param.fc0_w_max,
param.fc0_b,
param.fc1_w,
param.fc1_w_max,
param.fc1_b,
param.fc2_w,
param.fc2_w_max,
param.fc2_b,
upper_bound_batch_,
upper_bound_seqlen_);
}
void XPUMmdnnMergeAllCompute::Run() {
auto& param = this->Param<param_t>();
auto& ctx = this->ctx_->As<XPUContext>();
auto* xpu_ctx = ctx.GetRawContext();
id_.Update(param.concat_2in1_x[0], param.concat_2in1_x[1]);
compound_.Infer(ctx.GetRawContext(),
id_,
param.concat_2in1_x,
param.concat_7in1_x,
param.out,
reinterpret_cast<float*>(
reinterpret_cast<char*>(xpu_ctx->workspace_l3_ptr) +
xpu_ctx->used_l3_size),
xpu_ctx->workspace_l3_size - xpu_ctx->used_l3_size);
}
} // namespace xpu
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(__xpu__mmdnn_bid_emb_grnn_att,
kXPU,
kFloat,
kNCHW,
paddle::lite::kernels::xpu::XPUMmdnnBidEmbGrnnAttCompute,
def)
.BindInput("id0", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kInt64))})
.BindInput("id1", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kInt64))})
.BindInput("emb_tbl", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindInput("grnn_fw_wh", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindInput("grnn_fw_wi", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindInput("grnn_rv_wh", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindInput("grnn_rv_wi", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindInput("att_fc_w", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindInput("att_fc_b", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindOutput("grnn_fw_pool_out", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindOutput("grnn_rv_pool_out", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindOutput("att_pool_out", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindOutput("concat_3in1_out", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindOutput("emb_fw_out", {LiteType::GetTensorTy(TARGET(kXPU))})
.Finalize();
REGISTER_LITE_KERNEL(__xpu__mmdnn_bid_emb_att,
kXPU,
kFloat,
kNCHW,
paddle::lite::kernels::xpu::XPUMmdnnBidEmbAttCompute,
def)
.BindInput("id0", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kInt64))})
.BindInput("id1", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kInt64))})
.BindInput("emb_tbl", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindInput("att_fc_w", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindInput("att_fc_b", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindOutput("att_pool_out", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindOutput("concat_3in1_out", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindOutput("emb_fw_out", {LiteType::GetTensorTy(TARGET(kXPU))})
.Finalize();
REGISTER_LITE_KERNEL(__xpu__mmdnn_match_conv_topk,
kXPU,
kFloat,
kNCHW,
paddle::lite::kernels::xpu::XPUMmdnnMatchConvTopkCompute,
def)
.BindInput("input_x", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindInput("input_y", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindInput("input_w", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindInput("conv_w", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindOutput("topk_out", {LiteType::GetTensorTy(TARGET(kXPU))})
.Finalize();
REGISTER_LITE_KERNEL(__xpu__mmdnn_merge_all,
kXPU,
kFloat,
kNCHW,
paddle::lite::kernels::xpu::XPUMmdnnMergeAllCompute,
def)
.BindInput("concat_7in1_x", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindInput("concat_2in1_x", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindInput("grnn_fw_wh", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindInput("grnn_fw_wi", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindInput("grnn_rv_wh", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindInput("grnn_rv_wi", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindInput("fc0_w", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindInput("fc0_b", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindInput("fc1_w", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindInput("fc1_b", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindInput("fc2_w", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindInput("fc2_b", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindOutput("out", {LiteType::GetTensorTy(TARGET(kXPU))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/xpu/__xpu__resnet_cbam_compute.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace xpu {
void XPUResNetCbamCompute::PrepareForRun() {
auto& param = this->Param<param_t>();
for (auto* filter : param.filter) {
arg_filter_.push_back(
reinterpret_cast<const int16_t*>(filter->data<float>()));
}
for (auto* bias : param.bias) {
if (bias == nullptr) {
arg_bias_.push_back(nullptr);
} else {
arg_bias_.push_back(bias->data<float>());
}
}
for (auto* max_filter : param.max_filter) {
arg_max_filter_.push_back(max_filter->data<float>());
}
}
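// The whole ResNet-CBAM backbone runs as one fused xdnn call; the filter,
// bias and max_filter lists were collected in PrepareForRun().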
void XPUResNetCbamCompute::Run() {
auto& param = this->Param<param_t>();
auto& ctx = this->ctx_->As<XPUContext>();
auto input_dims = param.input->dims();
int batch_size = input_dims[0];
int height = input_dims[2];
int width = input_dims[3];
int r = xdnn::conv2d_int16_resnet_cbam<float, int16_t>(
ctx.GetRawContext(), /* context */
batch_size, /* num */
height, /* height */
width, /* width */
param.input->data<float>(), /* bottom */
&arg_filter_[0], /* weight_list */
param.output->mutable_data<float>(TARGET(kXPU)), /* top */
&arg_bias_[0], /* bias_list */
&arg_max_filter_[0], /* max_filter_list */
param.pool_p, /* pool_p */
true, /* midtype_fp16 */
false /* dynamic_shape */);
CHECK_EQ(r, 0);
}
} // namespace xpu
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(__xpu__resnet_cbam,
kXPU,
kFloat,
kNCHW,
paddle::lite::kernels::xpu::XPUResNetCbamCompute,
def)
.BindInput("Input", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindInput("Filter", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindInput("Bias", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindInput("MaxFilter", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindOutput("Output", {LiteType::GetTensorTy(TARGET(kXPU))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <vector>
#include "lite/backends/xpu/xpu_header_sitter.h"
#include "lite/core/kernel.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace xpu {
class XPUResNetCbamCompute
: public KernelLite<TARGET(kXPU), PRECISION(kFloat)> {
public:
using param_t = operators::XPUResNetCbamParam;
virtual void PrepareForRun();
virtual void Run();
private:
std::vector<const int16_t *> arg_filter_;
std::vector<const float *> arg_max_filter_;
std::vector<const float *> arg_bias_;
};
} // namespace xpu
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/xpu/__xpu__search_attention_compute.h"
#include "lite/backends/xpu/xpu_header_sitter.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace xpu {
void XPUMmdnnSearchAttentionCompute::PrepareForRun() {
offset_xpu_guard_ = TargetWrapperXPU::MallocScratchPad(64 * sizeof(int));
pad_begin_xpu_guard_ = TargetWrapperXPU::MallocScratchPad(64 * sizeof(int));
w_max_xpu_guard_ = TargetWrapperXPU::MallocScratchPad(8 * sizeof(float));
buffer_at_l3_guard_ = TargetWrapperXPU::MallocScratchPad(
5 * L3_SLOT_SIZE * sizeof(float), false /* use_l3 */);
buffer_at_gm_guard_ = TargetWrapperXPU::MallocScratchPad(
5 * GM_SLOT_SIZE * sizeof(float), false /* use_l3 */);
offset_cpu.reset(new int[64]);
pad_begin_cpu.reset(new int[64]);
}
void XPUMmdnnSearchAttentionCompute::Run() {
auto& param = this->Param<param_t>();
auto& ctx = this->ctx_->As<XPUContext>();
auto* X = param.X;
auto* W = param.W;
auto* b = param.b;
float W_max = param.W_max;
float alpha0 = param.alpha0;
float alpha1 = param.alpha1;
float mask = param.mask;
const int16_t* w_data = W->data<int16_t>();
const float* b_data = b->data<float>();
int batch = X->lod()[0].size() - 1;
int dim0 = X->dims()[0];
int dim1 = X->dims()[1];
const auto offset = X->lod()[0];
int max_seq = 0;
auto* top = param.Out;
LoD top_lod;
top_lod.push_back(X->lod()[0]);
top->set_lod(top_lod);
top->Resize({dim0, dim1});
auto* top_data = top->mutable_data<float>(TARGET(kXPU));
float maxs_cpu[8] = {0.0f, 0.0f, 0.0f, 0.0f, W_max, 0.0f, 0.0f, 0.0f};
for (int i = 0; i < batch; ++i) {
offset_cpu[i] = offset[i]; // type of offset is int64, not supported by xpu
pad_begin_cpu[i] = offset[i + 1] - offset[i];
if (offset[i + 1] - offset[i] > max_seq) {
max_seq = offset[i + 1] - offset[i];
}
}
offset_cpu[batch] = offset[batch];
xpu_memcpy(offset_xpu_guard_->addr_,
offset_cpu.get(),
offset.size() * sizeof(int),
XPUMemcpyKind::XPU_HOST_TO_DEVICE);
xpu_memcpy(pad_begin_xpu_guard_->addr_,
pad_begin_cpu.get(),
batch * sizeof(int),
XPUMemcpyKind::XPU_HOST_TO_DEVICE);
xpu_memcpy(w_max_xpu_guard_->addr_,
maxs_cpu,
8 * sizeof(float),
XPUMemcpyKind::XPU_HOST_TO_DEVICE);
int* offset_xpu = reinterpret_cast<int*>(offset_xpu_guard_->addr_);
int* pad_begin_xpu = reinterpret_cast<int*>(pad_begin_xpu_guard_->addr_);
float* maxs_xpu = reinterpret_cast<float*>(w_max_xpu_guard_->addr_);
float* buffer_at_l3 = reinterpret_cast<float*>(buffer_at_l3_guard_->addr_);
float* buffer_at_gm = reinterpret_cast<float*>(buffer_at_gm_guard_->addr_);
// when use l3, max_seq <= 128:
// group_padding: batch * max_seq * dim1; at (slot0, slot1)
// seq_fc: batch * max_seq * dim1; at (slot2, slot3)
// batchgemm0: batch * max_seq * max_seq; at slot4
// attention_padding_mask: batch * max_seq * max_seq; at slot3
// seq_softmax: batch * max_seq * max_seq; at slot4
// batchgemm1: batch * max_seq * dim1; at (slot2, slot3)
float* group_padding_output = buffer_at_l3;
float* seq_fc_output = buffer_at_l3 + 2 * L3_SLOT_SIZE;
float* batchgemm0_output = buffer_at_l3 + 4 * L3_SLOT_SIZE;
float* attention_output = buffer_at_l3 + 3 * L3_SLOT_SIZE;
float* seq_softmax_output = buffer_at_l3 + 4 * L3_SLOT_SIZE;
float* batchgemm1_output = buffer_at_l3 + 2 * L3_SLOT_SIZE;
if (max_seq > 128) {
group_padding_output = buffer_at_gm;
seq_fc_output = buffer_at_gm + 1 * GM_SLOT_SIZE;
batchgemm0_output = buffer_at_gm + 2 * GM_SLOT_SIZE;
attention_output = buffer_at_gm + 1 * GM_SLOT_SIZE;
seq_softmax_output = buffer_at_gm + 3 * GM_SLOT_SIZE;
batchgemm1_output = buffer_at_gm + 4 * GM_SLOT_SIZE;
}
const auto* bottom_data = X->data<float>();
xdnn::search_sequence_pad_depad(ctx.GetRawContext(),
const_cast<float*>(bottom_data),
group_padding_output,
offset_xpu,
max_seq,
batch,
dim1,
0); // is_depad = 0
// do-findmax
xdnn::findmax<float>(ctx.GetRawContext(),
group_padding_output,
batch * max_seq * dim1,
maxs_xpu);
xdnn::gemm_int16_maxptr<float, int16_t, float>(
ctx.GetRawContext(),
false,
true, // trans_a, trans_b
batch * max_seq,
dim1,
dim1, // m, n, k
1.0f,
group_padding_output,
dim1, // alpha, data_a, lda
w_data,
dim1,
0.0f, // data_b, ldb, beta
seq_fc_output,
dim1,
b_data, // data_c, ldc, bias
xdnn::Activation_t::LINEAR,
maxs_xpu,
maxs_xpu + 4,
nullptr); // max_a, max_b, max_c
xdnn::search_aligned_mat_mul(ctx.GetRawContext(),
0,
1,
batch,
max_seq,
max_seq,
dim1,
alpha0,
group_padding_output,
dim1,
seq_fc_output,
dim1,
batchgemm0_output,
max_seq);
xdnn::search_pad_mask(ctx.GetRawContext(),
batchgemm0_output,
attention_output,
pad_begin_xpu,
batch,
max_seq,
max_seq,
batch,
mask);
xdnn::softmax2d_forward(ctx.GetRawContext(),
attention_output,
seq_softmax_output,
batch * max_seq,
max_seq,
true);
xdnn::search_aligned_mat_mul(ctx.GetRawContext(),
0,
0,
batch,
max_seq,
dim1,
max_seq,
alpha1,
seq_softmax_output,
max_seq,
group_padding_output,
dim1,
batchgemm1_output,
dim1);
xdnn::search_sequence_pad_depad(ctx.GetRawContext(),
top_data,
batchgemm1_output,
offset_xpu,
max_seq,
batch,
dim1,
1); // is_depad = 1
}
} // namespace xpu
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(__xpu__mmdnn_search_attention,
kXPU,
kFloat,
kNCHW,
paddle::lite::kernels::xpu::XPUMmdnnSearchAttentionCompute,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindInput("W", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindInput("b", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kXPU))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include "lite/backends/xpu/target_wrapper.h" // XPUScratchPadGuard
#include "lite/core/kernel.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace xpu {
class XPUMmdnnSearchAttentionCompute
: public KernelLite<TARGET(kXPU), PRECISION(kFloat)> {
public:
using param_t = operators::XPUMmdnnSearchAttentionParam;
void PrepareForRun() override;
void Run() override;
private:
XPUScratchPadGuard offset_xpu_guard_;
XPUScratchPadGuard pad_begin_xpu_guard_;
XPUScratchPadGuard w_max_xpu_guard_;
XPUScratchPadGuard buffer_at_l3_guard_;
XPUScratchPadGuard buffer_at_gm_guard_;
std::unique_ptr<int[]> offset_cpu;
std::unique_ptr<int[]> pad_begin_cpu;
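// Scratch slot sizes: Run() uses the L3 slots when max_seq <= 128 and falls
// back to the GM slots otherwise (sized for a batch of up to 40 sequences of
// length up to 512).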
const int L3_SLOT_SIZE = 40 * 128 * 128;
const int GM_SLOT_SIZE = 40 * 512 * 512;
};
} // namespace xpu
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/xpu/concat_compute.h"
#include "lite/backends/xpu/xpu_header_sitter.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace xpu {
void ConcatCompute::Run() {
auto& param = this->Param<param_t>();
auto& ctx = this->ctx_->As<XPUContext>();
auto ins = param.x;
auto out = param.output;
int64_t axis = param.axis;
int n = ins.size();
int h = 1;
int w_except_axis = 1;
CHECK(n <= 8) << "XPU only supports at most 8 tensors for now";
for (int i = 0; i < axis; ++i) {
h *= (ins[0]->dims())[i];
}
for (int i = axis + 1; i < ins[0]->dims().size(); ++i) {
w_except_axis *= (ins[0]->dims())[i];
}
CHECK(axis >= 0) << "concat: axis should be >= 0!";
CHECK(axis < ins[0]->dims().size()) << "concat: axis should be < ins[0]->dims().size()!";
for (int i = 0; i < n; ++i) {
int hh = 1;
int ww = 1;
for (int j = 0; j < axis; ++j) {
hh *= (ins[i]->dims())[j];
}
for (int j = axis + 1; j < ins[i]->dims().size(); ++j) {
ww *= (ins[i]->dims())[j];
}
CHECK(hh == h) << "concat: h should be equal!";
CHECK(ww == w_except_axis) << "concat: w should be equal except for axis!";
}
int in_w_host[n]; // NOLINT
const float* ptrs[n]; // NOLINT
for (int i = 0; i < n; ++i) {
ptrs[i] = ins[i]->data<float>();
in_w_host[i] = w_except_axis * (ins[i]->dims())[axis];
}
int r = xdnn::concat<float>(ctx.GetRawContext(), /* ctx */
h, /* height */
in_w_host, /* width_x */
n, /* n */
ptrs, /* lm_ptrs */
out->mutable_data<float>(TARGET(kXPU)) /*y*/);
CHECK_EQ(r, 0);
}
} // namespace xpu
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(
concat, kXPU, kFloat, kNCHW, paddle::lite::kernels::xpu::ConcatCompute, def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindInput("AxisTensor",
{LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kInt32))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kXPU))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "lite/core/kernel.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace xpu {
class ConcatCompute : public KernelLite<TARGET(kXPU), PRECISION(kFloat)> {
public:
using param_t = operators::ConcatParam;
virtual void Run();
virtual ~ConcatCompute() = default;
};
} // namespace xpu
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/xpu/match_matrix_tensor_compute.h"
#include <vector>
#include "lite/backends/xpu/xpu_header_sitter.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace xpu {
void MatchMatrixTensorCompute::PrepareForRun() {
wx_max_xpu_guard_ = TargetWrapperXPU::MallocScratchPad(64 * sizeof(int));
offset_l_xpu_guard_ = TargetWrapperXPU::MallocScratchPad(64 * sizeof(int));
offset_r_xpu_guard_ = TargetWrapperXPU::MallocScratchPad(64 * sizeof(int));
offset_l_cpu.reset(new int[64]);
offset_r_cpu.reset(new int[64]);
}
void MatchMatrixTensorCompute::Run() {
auto& param = this->Param<param_t>();
auto& ctx = this->ctx_->As<XPUContext>();
auto* x = param.x;
auto* y = param.y;
auto* w = param.w;
auto* out = param.out;
auto* tmp = param.tmp;
int dim_t = param.dim_t;
float w_max = param.__xpu__w_max;
bool fuse_relu = param.fuse_relu;
bool float_to_fix = param.__xpu__float_to_fix;
CHECK(float_to_fix) << "W should be fixed point";
xdnn::Activation_t act = xdnn::Activation_t::LINEAR;
if (fuse_relu) {
act = xdnn::Activation_t::RELU;
}
int dim_in = x->dims()[1];
const auto& offset_l = x->lod()[0];
const auto& offset_r = y->lod()[0];
std::vector<size_t> top_offset;
int top_size = 0;
top_offset.push_back(top_size);
for (size_t b = 0; b < x->lod()[0].size() - 1; b++) {
int len_l = offset_l[b + 1] - offset_l[b];
int len_r = offset_r[b + 1] - offset_r[b];
top_size += dim_t * len_l * len_r;
top_offset.push_back(top_size);
}
auto* bottom_l_data = x->data<float>();
auto* bottom_r_data = y->data<float>();
auto* w_data = w->data<int16_t>();
auto* out_data = out->mutable_data<float>(TARGET(kXPU));
auto* bottom_l_trans_data = tmp->mutable_data<float>(TARGET(kXPU));
int batch_size = x->lod()[0].size() - 1;
float* wx_max = reinterpret_cast<float*>(wx_max_xpu_guard_->addr_);
int* offset_l_xpu = reinterpret_cast<int*>(offset_l_xpu_guard_->addr_);
int* offset_r_xpu = reinterpret_cast<int*>(offset_r_xpu_guard_->addr_);
int r = xdnn::gemm_int16_tmp_api<float, int16_t, float>(
ctx.GetRawContext(), /* ctx */
false,
false, /* trans_a, trans_b */
x->dims()[0],
dim_t * dim_in,
dim_in, /* m, n, k */
1.0f,
bottom_l_data,
dim_in, /* alpha, data_a, lda */
w_data,
dim_t * dim_in,
0.0f, /* data_b, ldb, beta */
bottom_l_trans_data,
dim_t * dim_in, /* data_c, ldc */
nullptr, /* bias */
xdnn::Activation_t::LINEAR,
0.0f,
w_max,
wx_max /* max_a, max_b, max_c */);
CHECK_EQ(r, 0);
int max_width = 0;
for (int i = 0; i < offset_l.size(); ++i) {
offset_l_cpu[i] = offset_l[i];
if (i != 0 && (offset_l_cpu[i] - offset_l_cpu[i - 1] > max_width)) {
max_width = offset_l_cpu[i] - offset_l_cpu[i - 1];
}
}
for (int i = 0; i < offset_r.size(); ++i) {
offset_r_cpu[i] = offset_r[i];
if (i != 0 && (offset_r_cpu[i] - offset_r_cpu[i - 1] > max_width)) {
max_width = offset_r_cpu[i] - offset_r_cpu[i - 1];
}
}
xpu_memcpy(offset_l_xpu,
offset_l_cpu.get(),
offset_l.size() * sizeof(int),
XPUMemcpyKind::XPU_HOST_TO_DEVICE);
xpu_memcpy(offset_r_xpu,
offset_r_cpu.get(),
offset_r.size() * sizeof(int),
XPUMemcpyKind::XPU_HOST_TO_DEVICE);
r = xdnn::match_matrix_tensor(ctx.GetRawContext(),
batch_size,
bottom_l_trans_data,
bottom_r_data,
offset_l_xpu,
offset_r_xpu,
dim_t,
dim_in,
out_data,
wx_max,
act,
max_width);
CHECK_EQ(r, 0);
int lod_lv1_size = batch_size * dim_t;
int lod_lv2_size = x->lod()[0].back() * dim_t;
std::vector<size_t> out_lod0(batch_size + 1, 0);
std::vector<size_t> out_lod1(lod_lv1_size + 1, 0);
std::vector<size_t> out_lod2(lod_lv2_size + 1, 0);
for (int i = 0; i < batch_size; i++) {
out_lod0[i + 1] = out_lod0[i] + dim_t;
int len_l = offset_l[i + 1] - offset_l[i];
for (int j = 0; j < dim_t; j++) {
out_lod1[i * dim_t + j + 1] = out_lod1[i * dim_t + j] + len_l;
int len_r = offset_r[i + 1] - offset_r[i];
for (int k = 0; k < len_l; k++) {
out_lod2[offset_l[i] * dim_t + j * len_l + k + 1] =
out_lod2[offset_l[i] * dim_t + j * len_l + k] + len_r;
}
}
}
paddle::lite::LoD out_lod;
out_lod.push_back(top_offset);
out_lod.push_back(offset_l);
out_lod.push_back(offset_r);
out->set_lod(out_lod);
}
} // namespace xpu
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(match_matrix_tensor,
kXPU,
kFloat,
kNCHW,
paddle::lite::kernels::xpu::MatchMatrixTensorCompute,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindInput("W", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindInput("Y", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindOutput("Tmp", {LiteType::GetTensorTy(TARGET(kXPU))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include "lite/backends/xpu/target_wrapper.h" // XPUScratchPadGuard
#include "lite/core/kernel.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace xpu {
class MatchMatrixTensorCompute
: public KernelLite<TARGET(kXPU), PRECISION(kFloat)> {
public:
using param_t = operators::MatchMatrixTensorParam;
virtual void PrepareForRun();
virtual void Run();
private:
XPUScratchPadGuard wx_max_xpu_guard_;
XPUScratchPadGuard offset_l_xpu_guard_;
XPUScratchPadGuard offset_r_xpu_guard_;
std::unique_ptr<int[]> offset_l_cpu;
std::unique_ptr<int[]> offset_r_cpu;
};
} // namespace xpu
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/xpu/search_fc_compute.h"
#include <vector>
#include "lite/backends/xpu/xpu_header_sitter.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace xpu {
void SearchFcCompute::PrepareForRun() {
maxs_xpu_guard_ = TargetWrapperXPU::MallocScratchPad(64 * sizeof(float));
}
void SearchFcCompute::Run() {
auto& param = this->template Param<param_t>();
auto& ctx = this->ctx_->template As<XPUContext>();
auto* bottom = param.X;
auto* w = param.W;
auto* b = param.b;
auto* top = param.Out;
float w_max = param.__xpu__w_max;
int out_size = param.out_size;
bool fuse_relu = param.fuse_relu;
bool float_to_fix = param.__xpu__float_to_fix;
CHECK(float_to_fix) << "W should be fixed point";
int batch = bottom->dims()[0];
int _out = w->dims()[0];
int _in = w->dims()[1];
xdnn::Activation_t act = xdnn::Activation_t::LINEAR;
if (fuse_relu) {
act = xdnn::Activation_t::RELU;
}
std::vector<int64_t> top_dims{bottom->dims()[0], out_size};
top->Resize(top_dims);
const auto* bottom_data = bottom->data<float>();
const auto* weights = w->data<int16_t>();
const auto* bias_data = b->data<float>();
auto* top_data = top->mutable_data<float>(TARGET(kXPU));
float* maxs_xpu = reinterpret_cast<float*>(maxs_xpu_guard_->addr_);
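// maxs buffer layout: the leading slots receive the runtime max of the input
// (written by findmax below); slot 4 holds the precomputed weight max that the
// gemm call reads as max_b.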
float maxs_cpu[8] = {0.0f, 0.0f, 0.0f, 0.0f, w_max, 0.0f, 0.0f, 0.0f};
xpu_memcpy(maxs_xpu,
&maxs_cpu[0],
8 * sizeof(float),
XPUMemcpyKind::XPU_HOST_TO_DEVICE);
int r = xdnn::findmax<float>(
ctx.GetRawContext(), bottom_data, batch * _in, maxs_xpu);
CHECK_EQ(r, 0);
r = xdnn::gemm_int16_maxptr<float, int16_t, float>(
ctx.GetRawContext(), /* ctx */
false,
true, /*trans_a, trans_b*/
batch,
_out,
_in, /*m, n, k*/
1.0f,
bottom_data,
_in, /*alpha, data_a, lda*/
weights,
_in,
0.0f, /*data_b, ldb, beta*/
top_data,
_out,
bias_data, /* data_c, ldc, bias*/
act,
maxs_xpu,
maxs_xpu + 4,
nullptr /*act, max_a, max_b, max_c*/);
CHECK_EQ(r, 0);
}
} // namespace xpu
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(search_fc,
kXPU,
kFloat,
kNCHW,
paddle::lite::kernels::xpu::SearchFcCompute,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindInput("W", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindInput("b", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kXPU))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "lite/backends/xpu/target_wrapper.h" // XPUScratchPadGuard
#include "lite/core/kernel.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace xpu {
class SearchFcCompute : public KernelLite<TARGET(kXPU), PRECISION(kFloat)> {
public:
using param_t = operators::SearchFcParam;
void PrepareForRun() override;
void Run() override;
private:
XPUScratchPadGuard maxs_xpu_guard_;
};
} // namespace xpu
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/xpu/search_grnn_compute.h"
#include <algorithm>
#include <vector>
#include "lite/backends/xpu/xpu_header_sitter.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace xpu {
void SearchGrnnCompute::PrepareForRun() {
offset_xpu_guard_ = TargetWrapperXPU::MallocScratchPad(64 * sizeof(int));
new_offset_xpu_guard_ = TargetWrapperXPU::MallocScratchPad(256 * sizeof(int));
maxs_xpu_guard_ = TargetWrapperXPU::MallocScratchPad(16 * sizeof(float));
idx_sorted_by_width_data_cpu.reset(new int[64]);
offset_cpu.reset(new int[64]);
new_offset_cpu.reset(new int[256]);
}
void SearchGrnnCompute::prepare_layout(const operators::SearchGrnnParam& param,
const paddle::lite::Tensor* bottom) {
auto* idx_sorted_by_width = param.idx_sorted_by_width;
auto* layout_input = param.layout_input;
int dim0 = bottom->dims()[0];
int dim1 = 1;
if (bottom->dims().size() > 1) {
dim1 = bottom->dims()[1];
}
int batch = bottom->lod()[0].size() - 1;
auto& offset = bottom->lod()[0];
idx_sorted_by_width->Resize({batch});
std::vector<int> width;
width.resize(batch);
// sort sequences by width (descending) and find the largest width in the
// batch
for (int i = 0; i < batch; i++) {
width[i] = offset[i + 1] - offset[i];
idx_sorted_by_width_data_cpu[i] = i;
}
std::sort(idx_sorted_by_width_data_cpu.get(),
idx_sorted_by_width_data_cpu.get() + batch,
[&width](int a, int b) { return width[a] > width[b]; });
int max_width = width[idx_sorted_by_width_data_cpu[0]];
// start of reorganizing the input
std::vector<size_t> new_offset;
new_offset.resize(max_width + 1);
new_offset[0] = 0;
int j = batch - 1;
int last_width = 0;
int sub_row = 0;
int sub_col = 0;
for (int i = 1; i <= max_width;) {
for (int k = j; k >= 0; --k) {
if (width[idx_sorted_by_width_data_cpu[k]] > last_width) {
sub_row = width[idx_sorted_by_width_data_cpu[k]] - last_width;
sub_col = k + 1;
for (int s = 0; s < sub_row; s++) {
new_offset[i] = new_offset[i - 1] + sub_col;
i++;
}
// move on
last_width = width[idx_sorted_by_width_data_cpu[k]];
j = k - 1;
break;
}
}
}
// copying to the reorganized buffer
if (bottom->dims().size() == 1) {
} else {
LoD new_lod;
new_lod.push_back(new_offset);
layout_input->set_lod(new_lod);
layout_input->Resize({dim0, dim1});
}
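// Upload the width-sorted sequence indices so the device-side seq2batch /
// batch2seq kernels can gather rows in the reordered layout.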
xpu_memcpy(idx_sorted_by_width->mutable_data<int>(TARGET(kXPU)),
idx_sorted_by_width_data_cpu.get(),
idx_sorted_by_width->numel() * sizeof(int),
XPUMemcpyKind::XPU_HOST_TO_DEVICE);
}
void SearchGrnnCompute::Run() {
auto& param = this->template Param<param_t>();
auto& ctx = this->ctx_->template As<XPUContext>();
auto* bottom = param.x;
auto* wi = param.wi;
auto* wh = param.wh;
auto* top = param.out;
auto* tmp_buffer = param.tmp_buffer;
auto* idx_sorted_by_width = param.idx_sorted_by_width;
auto* layout_input = param.layout_input;
int cap_h = param.num_hidden;
int cap_e = param.num_input;
int cap_l = bottom->dims()[0];
auto wi_max = param.__xpu__wi_max;
auto wh_max = param.__xpu__wh_max;
bool float_to_fix = param.__xpu__float_to_fix;
CHECK(float_to_fix) << "W should be fixed point";
int dim = 1;
if (bottom->dims().size() > 1) {
dim = bottom->dims()[1];
}
const auto& offset = bottom->lod()[0];
LoD top_lod;
top_lod.push_back(offset);
top->set_lod(top_lod);
std::vector<int64_t> top_dims_vec{cap_l, cap_h};
top->Resize(top_dims_vec);
auto* top_hidden = top->mutable_data<float>(TARGET(kXPU));
const auto* dense_e2h = wi->data<int16_t>();
const auto* dense_h2h = wh->data<int16_t>();
// Prepare idx_sorted_by_width
prepare_layout(param, bottom);
int batch = bottom->lod()[0].size() - 1;
int max_width = layout_input->lod()[0].size() - 1;
const auto& new_offset = layout_input->lod()[0];
auto* new_emb = layout_input->mutable_data<float>(TARGET(kXPU));
// Prepare offset and new_offset
int* offset_xpu = reinterpret_cast<int*>(offset_xpu_guard_->addr_);
int* new_offset_xpu = reinterpret_cast<int*>(new_offset_xpu_guard_->addr_);
float* maxs_xpu = reinterpret_cast<float*>(maxs_xpu_guard_->addr_);
CHECK_LE(offset.size(), 64);
CHECK_LE(new_offset.size(), 256);
for (size_t i = 0; i < offset.size(); ++i) {
offset_cpu[i] = offset[i];
}
for (size_t i = 0; i < new_offset.size(); ++i) {
new_offset_cpu[i] = new_offset[i];
}
xpu_memcpy(offset_xpu,
offset_cpu.get(),
offset.size() * sizeof(int),
XPUMemcpyKind::XPU_HOST_TO_DEVICE);
xpu_memcpy(new_offset_xpu,
new_offset_cpu.get(),
new_offset.size() * sizeof(int),
XPUMemcpyKind::XPU_HOST_TO_DEVICE);
int r = xdnn::search_seq2batch(ctx.GetRawContext(),
batch,
max_width,
dim,
idx_sorted_by_width->data<int>(),
offset_xpu,
new_offset_xpu,
bottom->data<float>(),
new_emb);
CHECK_EQ(r, 0);
// This buffer holds bookkeeping info that is also needed for backprop, so it
// is made larger than the forward pass alone requires.
tmp_buffer->Resize({20, cap_l, cap_h});
auto* buffer_data = tmp_buffer->mutable_data<float>(TARGET(kXPU));
// the internal hidden
auto* hidden = buffer_data + 19 * cap_l * cap_h;
// do-findmax
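// maxs buffer layout: the leading slots receive the runtime max of the
// reordered embedding (written by findmax); slots 4, 8 and 12 hold the
// precomputed maxima of the three e2h weight blocks used below.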
float maxs_cpu[16] = {0.0f,
0.0f,
0.0f,
0.0f,
wi_max[0],
0.0f,
0.0f,
0.0f,
wi_max[1],
0.0f,
0.0f,
0.0f,
wi_max[2],
0.0f,
0.0f,
0.0f};
xpu_memcpy(maxs_xpu,
maxs_cpu,
16 * sizeof(float),
XPUMemcpyKind::XPU_HOST_TO_DEVICE);
r = xdnn::findmax<float>(
ctx.GetRawContext(), new_emb, cap_l * cap_e, maxs_xpu);
CHECK_EQ(r, 0);
// precompute embedding to hidden
for (int i = 0; i < 3; ++i) {
const int16_t* data_b = dense_e2h + i * cap_e * cap_h; // e2h, e2hr, e2hz
float* data_c = buffer_data + i * cap_l * cap_h; // w_x_e, wr_x_e, wz_x_e
int r = xdnn::gemm_int16_maxptr<float, int16_t, float>(
ctx.GetRawContext(),
false,
true, // trans_a, trans_b
cap_l,
cap_h,
cap_e, // m, n, k
1.0f,
new_emb,
cap_e, // alpha, data_a, lda
data_b,
cap_e,
0.0f, // data_b, ldb, beta
data_c,
cap_h, // data_c, ldc
nullptr,
xdnn::Activation_t::LINEAR, // bias, act
maxs_xpu,
maxs_xpu + 4 * (i + 1)); // max_a, max_b
CHECK_EQ(r, 0);
}
r = xdnn::search_grnn<float, int16_t>(ctx.GetRawContext(),
cap_l,
cap_h,
cap_e,
max_width,
new_offset_xpu,
buffer_data,
dense_h2h,
hidden,
wh_max[0],
wh_max[1],
wh_max[2]);
CHECK_EQ(r, 0);
r = xdnn::search_batch2seq(ctx.GetRawContext(),
batch,
max_width,
cap_h,
idx_sorted_by_width->data<int>(),
offset_xpu,
new_offset_xpu,
hidden,
top_hidden);
CHECK_EQ(r, 0);
}
} // namespace xpu
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(search_grnn,
kXPU,
kFloat,
kNCHW,
paddle::lite::kernels::xpu::SearchGrnnCompute,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindInput("Wi", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindInput("Wh", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindOutput("tmp_buffer", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindOutput("idx_sorted_by_width",
{LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kInt32))})
.BindOutput("layout_input", {LiteType::GetTensorTy(TARGET(kXPU))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include "lite/backends/xpu/target_wrapper.h" // XPUScratchPadGuard
#include "lite/core/kernel.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace xpu {
class SearchGrnnCompute : public KernelLite<TARGET(kXPU), PRECISION(kFloat)> {
public:
using param_t = operators::SearchGrnnParam;
void PrepareForRun() override;
void prepare_layout(const operators::SearchGrnnParam& param,
const paddle::lite::Tensor* bottom);
void Run() override;
private:
XPUScratchPadGuard offset_xpu_guard_;
XPUScratchPadGuard new_offset_xpu_guard_;
XPUScratchPadGuard maxs_xpu_guard_;
std::unique_ptr<int[]> idx_sorted_by_width_data_cpu;
std::unique_ptr<int[]> offset_cpu;
std::unique_ptr<int[]> new_offset_cpu;
};
} // namespace xpu
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/xpu/sequence_arithmetic_compute.h"
#include "lite/backends/xpu/xpu_header_sitter.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace xpu {
void SequenceArithmeticCompute::Run() {
auto& param = this->template Param<param_t>();
auto& ctx = this->ctx_->template As<XPUContext>();
auto* bottom0 = param.X;
auto* bottom1 = param.Y;
auto* top = param.Out;
int op_type = param.op_type;
auto len1 = bottom0->numel();
auto len2 = bottom1->numel();
const auto* bottom_data0 = bottom0->data<float>();
const auto* bottom_data1 = bottom1->data<float>();
auto* top_data = top->mutable_data<float>(TARGET(kXPU));
switch (op_type) {
case 1: // addition: top[0] = bottom[0] + bottom[1]
if (len1 > len2) {
xdnn::elementwise_add(
ctx.GetRawContext(), bottom_data0, bottom_data1, top_data, len2);
xdnn::memcpy_device(ctx.GetRawContext(),
&top_data[len2],
&bottom_data0[len2],
(len1 - len2) * sizeof(float));
} else {
xdnn::elementwise_add(
ctx.GetRawContext(), bottom_data0, bottom_data1, top_data, len1);
}
break;
case 2: // subtraction: top[0] = bottom[0] - bottom[1]
if (len1 > len2) {
xdnn::elementwise_sub(
ctx.GetRawContext(), bottom_data0, bottom_data1, top_data, len2);
xdnn::memcpy_device(ctx.GetRawContext(),
&top_data[len2],
&bottom_data0[len2],
(len1 - len2) * sizeof(float));
} else {
xdnn::elementwise_sub(
ctx.GetRawContext(), bottom_data0, bottom_data1, top_data, len1);
}
break;
case 3: // multiplication: top[0] = bottom[0] * bottom[1]
if (len1 > len2) {
xdnn::elementwise_mul(
ctx.GetRawContext(), bottom_data0, bottom_data1, top_data, len2);
xdnn::memcpy_device(ctx.GetRawContext(),
&top_data[len2],
&bottom_data0[len2],
(len1 - len2) * sizeof(float));
} else {
xdnn::elementwise_mul(
ctx.GetRawContext(), bottom_data0, bottom_data1, top_data, len1);
}
break;
default:
break;
}
}
} // namespace xpu
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(sequence_arithmetic,
kXPU,
kFloat,
kNCHW,
paddle::lite::kernels::xpu::SequenceArithmeticCompute,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindInput("Y", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kXPU))})
.Finalize();
REGISTER_LITE_KERNEL(search_seq_arithmetic,
kXPU,
kFloat,
kNCHW,
paddle::lite::kernels::xpu::SequenceArithmeticCompute,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindInput("Y", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kXPU))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include "lite/core/kernel.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace xpu {
class SequenceArithmeticCompute
: public KernelLite<TARGET(kXPU), PRECISION(kFloat)> {
public:
using param_t = operators::SequenceArithmeticParam;
void Run() override;
};
} // namespace xpu
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/xpu/sequence_concat_compute.h"
#include <vector>
#include "lite/backends/xpu/xpu_header_sitter.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace xpu {
void SequenceConcatCompute::PrepareForRun() {
lod0_xpu_guard_ = TargetWrapperXPU::MallocScratchPad(64 * sizeof(int));
lod1_xpu_guard_ = TargetWrapperXPU::MallocScratchPad(64 * sizeof(int));
lod0_cpu.reset(new int[64]);
lod1_cpu.reset(new int[64]);
}
template <typename T>
inline LoD ConcatLoD(const std::vector<lite::Tensor*>& xs,
std::vector<lite::Tensor>* xs_in_order) {
std::vector<uint64_t> result;
result.resize(xs[0]->lod()[0].size());
for (size_t i = 1; i < result.size(); ++i) {
size_t sum = 0;
for (size_t j = 0; j < xs.size(); ++j) {
auto& x_lod = xs[j]->lod()[0];
if (x_lod[i - 1] < x_lod[i]) {
xs_in_order->emplace_back(xs[j]->Slice<T>(x_lod[i - 1], x_lod[i]));
}
sum += x_lod[i];
}
result[i] = sum;
}
LoD lod;
lod.emplace_back(result);
return lod;
}
void SequenceConcatCompute::Run() {
auto& param = this->template Param<param_t>();
auto& ctx = this->ctx_->template As<XPUContext>();
auto xs = param.X;
auto out = param.Out;
size_t lod_size = 0;
for (auto& x : xs) {
if (lod_size == 0) {
lod_size = x->lod()[0].size();
} else {
CHECK_EQ(lod_size, x->lod()[0].size())
<< "The number of sequence must be same between each input";
}
}
CHECK_NE(lod_size, 0) << "Each input must have sequence information";
// TODO(miaotianxiang):
int64_t dim0 = 0;
int64_t feature_size = 0;
std::vector<int64_t> out_dims;
for (const auto& tensor : param.X) {
const auto x_dims = tensor->dims();
if (out_dims.empty()) {
out_dims = x_dims.data();
}
dim0 += x_dims[0];
if (feature_size == 0) {
feature_size = x_dims.production() / x_dims[0];
} else {
CHECK_EQ(feature_size, x_dims.production() / x_dims[0])
<< "Inputs of sequence concat must have same feature size";
}
}
out_dims[0] = dim0;
out->Resize(out_dims);
std::vector<lite::Tensor> x_in_order;
out->set_lod(ConcatLoD<float>(xs, &x_in_order));
CHECK(xs.size() == 2) << "XPU sequence_concat only supports 2 input tensors";
auto lod0 = xs[0]->lod()[0];
auto lod1 = xs[1]->lod()[0];
int batch_size = lod0.size() - 1;
int* lod0_xpu = reinterpret_cast<int*>(lod0_xpu_guard_->addr_);
int* lod1_xpu = reinterpret_cast<int*>(lod1_xpu_guard_->addr_);
for (int i = 0; i < lod0.size(); ++i) {
lod0_cpu[i] = lod0[i];
}
for (int i = 0; i < lod1.size(); ++i) {
lod1_cpu[i] = lod1[i];
}
xpu_memcpy(lod0_xpu,
lod0_cpu.get(),
lod0.size() * sizeof(int),
XPUMemcpyKind::XPU_HOST_TO_DEVICE);
xpu_memcpy(lod1_xpu,
lod1_cpu.get(),
lod1.size() * sizeof(int),
XPUMemcpyKind::XPU_HOST_TO_DEVICE);
int r = xdnn::sequence_concat(ctx.GetRawContext(),
xs[0]->data<float>(),
lod0_xpu,
xs[1]->data<float>(),
lod1_xpu,
out->mutable_data<float>(TARGET(kXPU)),
batch_size);
CHECK_EQ(r, 0);
}
} // namespace xpu
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(sequence_concat,
kXPU,
kFloat,
kNCHW,
paddle::lite::kernels::xpu::SequenceConcatCompute,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kXPU))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include "lite/backends/xpu/target_wrapper.h" // XPUScratchPadGuard
#include "lite/core/kernel.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace xpu {
class SequenceConcatCompute
: public KernelLite<TARGET(kXPU), PRECISION(kFloat)> {
public:
using param_t = operators::SequenceConcatParam;
void PrepareForRun() override;
void Run() override;
private:
XPUScratchPadGuard lod0_xpu_guard_;
XPUScratchPadGuard lod1_xpu_guard_;
std::unique_ptr<int[]> lod0_cpu;
std::unique_ptr<int[]> lod1_cpu;
};
} // namespace xpu
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/xpu/sequence_pool_compute.h"
#include <string>
#include "lite/backends/xpu/xpu_header_sitter.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace xpu {
void XPUSequencePoolCompute::PrepareForRun() {
lod_xpu_guard_ = TargetWrapperXPU::MallocScratchPad(64 * sizeof(int));
lod_cpu.reset(new int[64]);
}
void XPUSequencePoolCompute::Run() {
auto& param = this->template Param<param_t>();
auto& ctx = this->ctx_->template As<XPUContext>();
auto* in = param.X;
auto* out = param.Out;
std::string pool_type_str = param.pool_type;
auto dims = in->dims();
auto lod = in->lod();
dims[0] = lod[0].size() - 1;
xdnn::Pooling_t pool_type = xdnn::Pooling_t::MAX_WITHOUT_INDEX;
if (pool_type_str == "MAX") {
// default pool_type is already MAX_WITHOUT_INDEX
} else if (pool_type_str == "LAST") {
pool_type = xdnn::Pooling_t::LAST;
} else {
CHECK(false) << "unsupported pool_type for XPU sequence_pool: " << pool_type_str;
}
int num_seq = out->dims()[0];
int dim = out->numel() / num_seq;
auto in_lod = in->lod()[0];
for (size_t i = 0; i < in_lod.size(); ++i) {
lod_cpu[i] = in_lod[i];
}
int* lod_xpu = reinterpret_cast<int*>(lod_xpu_guard_->addr_);
xpu_memcpy(lod_xpu,
lod_cpu.get(),
in_lod.size() * sizeof(int),
XPUMemcpyKind::XPU_HOST_TO_DEVICE);
int r =
xdnn::sequence_pooling_forward(ctx.GetRawContext(),
pool_type,
num_seq,
lod_xpu,
dim,
in->data<float>(),
nullptr /* index */,
out->mutable_data<float>(TARGET(kXPU)));
CHECK_EQ(r, 0);
}
} // namespace xpu
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(sequence_pool,
kXPU,
kFloat,
kNCHW,
paddle::lite::kernels::xpu::XPUSequencePoolCompute,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindOutput("MaxIndex", {LiteType::GetTensorTy(TARGET(kXPU))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include "lite/backends/xpu/target_wrapper.h" // XPUScratchPadGuard
#include "lite/core/kernel.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace xpu {
class XPUSequencePoolCompute
: public KernelLite<TARGET(kXPU), PRECISION(kFloat)> {
public:
using param_t = operators::SequencePoolParam;
void PrepareForRun() override;
void Run() override;
private:
XPUScratchPadGuard lod_xpu_guard_;
std::unique_ptr<int[]> lod_cpu;
};
} // namespace xpu
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/xpu/sequence_reverse_compute.h"
#include "lite/backends/xpu/xpu_header_sitter.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace xpu {
template <typename T, PrecisionType PType>
void SequenceReverseCompute<T, PType>::PrepareForRun() {
lod_xpu_guard_ = TargetWrapperXPU::MallocScratchPad(64 * sizeof(int));
lod_cpu.reset(new int[64]);
}
template <typename T, PrecisionType PType>
void SequenceReverseCompute<T, PType>::Run() {
auto& param = this->template Param<param_t>();
auto& ctx = this->ctx_->template As<XPUContext>();
auto* x = param.X;
auto* y = param.Out;
auto lod = x->lod()[0];
size_t limit = x->numel();
size_t ele_cnt_in_4_byte = limit / x->dims()[0];
auto* x_data = x->template data<T>();
auto* y_data = y->template mutable_data<T>(TARGET(kXPU));
int batch_size = lod.size() - 1;
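// xdnn::sequence_reverse operates on 4-byte words, so rescale the per-row
// element count according to sizeof(T) before the call.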
if (std::is_same<T, uint8_t>::value) {
ele_cnt_in_4_byte /= 4;
} else if (std::is_same<T, int>::value) {
// remain the same
} else if (std::is_same<T, int64_t>::value) {
ele_cnt_in_4_byte *= 2;
} else if (std::is_same<T, float>::value) {
// remain the same
} else if (std::is_same<T, double>::value) {
ele_cnt_in_4_byte *= 2;
}
for (size_t i = 0; i < lod.size(); ++i) {
lod_cpu[i] = lod[i];
}
int* lod_xpu = reinterpret_cast<int*>(lod_xpu_guard_->addr_);
xpu_memcpy(lod_xpu,
lod_cpu.get(),
lod.size() * sizeof(int),
XPUMemcpyKind::XPU_HOST_TO_DEVICE);
int r = xdnn::sequence_reverse(ctx.GetRawContext(),
batch_size,
lod_xpu,
ele_cnt_in_4_byte,
reinterpret_cast<const float*>(x_data),
reinterpret_cast<float*>(y_data));
CHECK_EQ(r, 0);
}
} // namespace xpu
} // namespace kernels
} // namespace lite
} // namespace paddle
namespace xpu = paddle::lite::kernels::xpu;
using SequenceReverseFp32 =
xpu::SequenceReverseCompute<float, PRECISION(kFloat)>;
using SequenceReverseInt64 =
xpu::SequenceReverseCompute<int64_t, PRECISION(kInt64)>;
REGISTER_LITE_KERNEL(
sequence_reverse, kXPU, kFloat, kNCHW, SequenceReverseFp32, def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindOutput("Y", {LiteType::GetTensorTy(TARGET(kXPU))})
.Finalize();
REGISTER_LITE_KERNEL(
sequence_reverse, kXPU, kInt64, kNCHW, SequenceReverseInt64, def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kInt64))})
.BindOutput("Y", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kInt64))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include "lite/backends/xpu/target_wrapper.h" // XPUScratchPadGuard
#include "lite/core/kernel.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace xpu {
template <typename T, PrecisionType PType>
class SequenceReverseCompute : public KernelLite<TARGET(kXPU), PType> {
public:
using param_t = operators::SequenceReverseParam;
void PrepareForRun() override;
void Run() override;
private:
XPUScratchPadGuard lod_xpu_guard_;
std::unique_ptr<int[]> lod_cpu;
};
} // namespace xpu
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/xpu/sequence_topk_avg_pooling_compute.h"
#include <vector>
#include "lite/backends/xpu/xpu_header_sitter.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace xpu {
void SequenceTopkAvgPoolingCompute::PrepareForRun() {
lod_xpu_guard_ = TargetWrapperXPU::MallocScratchPad(256 * sizeof(int));
in_lod_cpu.reset(new int[64]);
row_lod_cpu.reset(new int[64]);
col_lod_cpu.reset(new int[64]);
}
void SequenceTopkAvgPoolingCompute::Run() {
auto& param = this->template Param<param_t>();
auto& ctx = this->ctx_->template As<XPUContext>();
auto* in = param.X;
auto* row = param.ROW;
auto* col = param.COLUMN;
auto* out = param.Out;
auto* pos = param.pos;
auto channel_num = param.channel_num;
auto topks = param.topks;
auto k_num = topks.size();
auto max_k = topks[topks.size() - 1];
auto in_lod = in->lod()[0];
auto row_lod = row->lod()[0];
auto col_lod = col->lod()[0];
int batch_size = row_lod.size() - 1;
int pos_total_size = row_lod[batch_size] * channel_num * max_k;
std::vector<int64_t> vec_pos_shape;
vec_pos_shape.push_back(pos_total_size);
pos->Resize(vec_pos_shape);
auto pos_data = pos->mutable_data<int>(TARGET(kXPU));
int offset = 0;
std::vector<uint64_t> vec_out_lod;
vec_out_lod.reserve(batch_size + 1);
for (int i = 0; i <= batch_size; ++i) {
offset = row_lod[i];
vec_out_lod.push_back(offset);
}
LoD lod_temp;
lod_temp.push_back(vec_out_lod);
out->set_lod(lod_temp);
auto in_data = in->data<float>();
auto out_data = out->mutable_data<float>(TARGET(kXPU));
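// The single scratch pad is partitioned into four consecutive segments:
// input LoD, row LoD, column LoD and the top-k list.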
int* in_lod_xpu = reinterpret_cast<int*>(lod_xpu_guard_->addr_);
int* row_lod_xpu = in_lod_xpu + in_lod.size();
int* col_lod_xpu = row_lod_xpu + row_lod.size();
int* topks_xpu = col_lod_xpu + col_lod.size();
for (int i = 0; i < in_lod.size(); ++i) {
in_lod_cpu[i] = in_lod[i];
}
for (int i = 0; i < row_lod.size(); ++i) {
row_lod_cpu[i] = row_lod[i];
}
for (int i = 0; i < col_lod.size(); ++i) {
col_lod_cpu[i] = col_lod[i];
}
xpu_memcpy(in_lod_xpu,
in_lod_cpu.get(),
in_lod.size() * sizeof(int),
XPUMemcpyKind::XPU_HOST_TO_DEVICE);
xpu_memcpy(row_lod_xpu,
row_lod_cpu.get(),
row_lod.size() * sizeof(int),
XPUMemcpyKind::XPU_HOST_TO_DEVICE);
xpu_memcpy(col_lod_xpu,
col_lod_cpu.get(),
col_lod.size() * sizeof(int),
XPUMemcpyKind::XPU_HOST_TO_DEVICE);
xpu_memcpy(topks_xpu,
topks.data(),
topks.size() * sizeof(int),
XPUMemcpyKind::XPU_HOST_TO_DEVICE);
int r = xdnn::sequence_topk_avg_pooling(ctx.GetRawContext(),
in_data,
out_data,
pos_data,
batch_size,
channel_num,
in_lod_xpu,
row_lod_xpu,
col_lod_xpu,
topks_xpu,
k_num);
CHECK_EQ(r, 0);
}
} // namespace xpu
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(sequence_topk_avg_pooling,
kXPU,
kFloat,
kNCHW,
paddle::lite::kernels::xpu::SequenceTopkAvgPoolingCompute,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindInput("ROW", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindInput("COLUMN", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindOutput("pos", {LiteType::GetTensorTy(TARGET(kXPU))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include "lite/backends/xpu/target_wrapper.h" // XPUScratchPadGuard
#include "lite/core/kernel.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace xpu {
class SequenceTopkAvgPoolingCompute
: public KernelLite<TARGET(kXPU), PRECISION(kFloat)> {
public:
using param_t = operators::SequenceTopkAvgPoolingParam;
void PrepareForRun() override;
void Run() override;
private:
XPUScratchPadGuard lod_xpu_guard_;
std::unique_ptr<int[]> in_lod_cpu;
std::unique_ptr<int[]> row_lod_cpu;
std::unique_ptr<int[]> col_lod_cpu;
};
} // namespace xpu
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/xpu/var_conv_2d_compute.h"
#include <vector>
#include "lite/backends/xpu/xpu_header_sitter.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace xpu {
void VarConv2DCompute::PrepareForRun() {
offset_x_xpu_guard_ = TargetWrapperXPU::MallocScratchPad(64 * sizeof(int));
offset_y_xpu_guard_ = TargetWrapperXPU::MallocScratchPad(64 * sizeof(int));
offset_x_cpu.reset(new int[64]);
offset_y_cpu.reset(new int[64]);
}
void VarConv2DCompute::Run() {
auto& param = this->template Param<param_t>();
auto& ctx = this->ctx_->template As<XPUContext>();
auto* bottom = param.X;
auto* w = param.W;
auto* top = param.Out;
int output_channel = param.output_channel;
int input_channel = param.input_channel;
int kernel_h = param.kernel_h;
int kernel_w = param.kernel_w;
int stride_h = param.stride_h;
int stride_w = param.stride_w;
float w_max = param.__xpu__w_max;
bool fuse_relu = param.fuse_relu;
bool float_to_fix = param.__xpu__float_to_fix;
CHECK(float_to_fix) << "W should be fixed point";
xdnn::Activation_t act = xdnn::Activation_t::LINEAR;
if (fuse_relu) {
act = xdnn::Activation_t::RELU;
}
int batch = bottom->lod()[0].size() - 1;
const auto& offset_x = bottom->lod()[2];
const auto& offset_y = bottom->lod()[1];
std::vector<size_t> top_offset;
int top_size = 0;
top_offset.push_back(top_size);
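// Derive the output LoD: each sequence of size height x width produces
// output_channel * ceil(height / stride_h) * ceil(width / stride_w) elements.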
for (int b = 0; b < batch; ++b) {
int width = offset_x[b + 1] - offset_x[b];
int height = offset_y[b + 1] - offset_y[b];
int top_im_x = 0;
int top_im_y = 0;
if (width != 0) {
top_im_x = (width - 1) / stride_w + 1;
}
if (height != 0) {
top_im_y = (height - 1) / stride_h + 1;
}
int top_im_size = top_im_y * top_im_x;
top_size += output_channel * top_im_size;
top_offset.push_back(top_size);
}
LoD top_lod;
top_lod.push_back(top_offset);
top_lod.push_back(bottom->lod()[1]);
top_lod.push_back(bottom->lod()[2]);
top->set_lod(top_lod);
std::vector<int64_t> top_dims_vec{top_size};
top_dims_vec.push_back(1);
top->Resize(top_dims_vec);
auto* top_data = top->mutable_data<float>(TARGET(kXPU));
auto* bottom_data = bottom->data<float>();
auto* w_data = w->data<int16_t>();
int* offset_x_xpu = reinterpret_cast<int*>(offset_x_xpu_guard_->addr_);
int* offset_y_xpu = reinterpret_cast<int*>(offset_y_xpu_guard_->addr_);
for (int i = 0; i < (batch + 1); ++i) {
offset_x_cpu[i] = offset_x[i];
offset_y_cpu[i] = offset_y[i];
}
xpu_memcpy(offset_x_xpu,
offset_x_cpu.get(),
(batch + 1) * sizeof(int),
XPUMemcpyKind::XPU_HOST_TO_DEVICE);
xpu_memcpy(offset_y_xpu,
offset_y_cpu.get(),
(batch + 1) * sizeof(int),
XPUMemcpyKind::XPU_HOST_TO_DEVICE);
int r = xdnn::search_varconv<float, int16_t>(ctx.GetRawContext(),
batch,
input_channel,
output_channel,
kernel_h,
kernel_w,
stride_h,
stride_w,
bottom_data,
w_data,
offset_x_xpu,
offset_y_xpu,
top_data,
w_max,
act);
CHECK_EQ(r, 0);
}
} // namespace xpu
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(var_conv_2d,
kXPU,
kFloat,
kNCHW,
paddle::lite::kernels::xpu::VarConv2DCompute,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindInput("W", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindOutput("Col", {LiteType::GetTensorTy(TARGET(kXPU))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include "lite/backends/xpu/target_wrapper.h" // XPUScratchPadGuard
#include "lite/core/kernel.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace xpu {
class VarConv2DCompute : public KernelLite<TARGET(kXPU), PRECISION(kFloat)> {
public:
using param_t = operators::VarConv2DParam;
void PrepareForRun() override;
void Run() override;
private:
XPUScratchPadGuard offset_x_xpu_guard_;
XPUScratchPadGuard offset_y_xpu_guard_;
std::unique_ptr<int[]> offset_x_cpu;
std::unique_ptr<int[]> offset_y_cpu;
};
} // namespace xpu
} // namespace kernels
} // namespace lite
} // namespace paddle
...@@ -168,6 +168,9 @@ add_operator(__xpu__resnet50_op extra SRCS __xpu__resnet50_op.cc DEPS ${op_DEPS}
add_operator(__xpu__multi_encoder_op extra SRCS __xpu__multi_encoder_op.cc DEPS ${op_DEPS})
add_operator(__xpu__embedding_with_eltwise_add_op extra SRCS __xpu__embedding_with_eltwise_add_op.cc DEPS ${op_DEPS})
add_operator(__xpu__fc_op extra SRCS __xpu__fc_op.cc DEPS ${op_DEPS})
add_operator(__xpu__resnet_cbam_op extra SRCS __xpu__resnet_cbam_op.cc DEPS ${op_DEPS})
add_operator(__xpu__search_attention_op extra SRCS __xpu__search_attention_op.cc DEPS ${op_DEPS})
add_operator(__xpu__mmdnn_op extra SRCS __xpu__mmdnn_op.cc DEPS ${op_DEPS})
if (NOT LITE_WITH_X86)
lite_cc_test(test_fc_op SRCS fc_op_test.cc
...
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/__xpu__mmdnn_op.h"
#include <vector>
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace operators {
bool XPUMmdnnBidEmbGrnnAttOp::CheckShape() const { return true; }
bool XPUMmdnnBidEmbGrnnAttOp::InferShapeImpl() const {
auto& id_dims = param_.id0->dims();
auto& id_lod = param_.id0->lod()[0];
auto& emb_tbl_dims = param_.emb_tbl->dims();
auto& grnn_wh_dims = param_.grnn_rv_wh->dims();
param_.grnn_fw_pool_out->Resize(
{(int64_t)id_lod.size() - 1, grnn_wh_dims[2]});
param_.grnn_rv_pool_out->Resize(
{(int64_t)id_lod.size() - 1, grnn_wh_dims[2]});
param_.att_pool_out->Resize(
{(int64_t)id_lod.size() - 1, 2 * grnn_wh_dims[2]});
param_.concat_3in1_out->Resize({id_dims[0], 3 * grnn_wh_dims[2]});
param_.concat_3in1_out->set_lod({id_lod});
param_.emb_fw_out->Resize({id_dims[0], emb_tbl_dims[1]});
param_.emb_fw_out->set_lod({id_lod});
return true;
}
bool XPUMmdnnBidEmbGrnnAttOp::AttachImpl(const cpp::OpDesc& op_desc,
lite::Scope* scope) {
param_.id0 =
scope->FindVar(op_desc.Input("id0").front())->GetMutable<lite::Tensor>();
param_.id1 =
scope->FindVar(op_desc.Input("id1").front())->GetMutable<lite::Tensor>();
param_.emb_tbl = scope->FindVar(op_desc.Input("emb_tbl").front())
->GetMutable<lite::Tensor>();
param_.grnn_fw_wh = scope->FindVar(op_desc.Input("grnn_fw_wh").front())
->GetMutable<lite::Tensor>();
param_.grnn_fw_wi = scope->FindVar(op_desc.Input("grnn_fw_wi").front())
->GetMutable<lite::Tensor>();
param_.grnn_rv_wh = scope->FindVar(op_desc.Input("grnn_rv_wh").front())
->GetMutable<lite::Tensor>();
param_.grnn_rv_wi = scope->FindVar(op_desc.Input("grnn_rv_wi").front())
->GetMutable<lite::Tensor>();
param_.att_fc_w = scope->FindVar(op_desc.Input("att_fc_w").front())
->GetMutable<lite::Tensor>();
param_.att_fc_b = scope->FindVar(op_desc.Input("att_fc_b").front())
->GetMutable<lite::Tensor>();
param_.grnn_fw_pool_out =
scope->FindVar(op_desc.Output("grnn_fw_pool_out").front())
->GetMutable<lite::Tensor>();
param_.grnn_rv_pool_out =
scope->FindVar(op_desc.Output("grnn_rv_pool_out").front())
->GetMutable<lite::Tensor>();
param_.att_pool_out = scope->FindVar(op_desc.Output("att_pool_out").front())
->GetMutable<lite::Tensor>();
param_.concat_3in1_out =
scope->FindVar(op_desc.Output("concat_3in1_out").front())
->GetMutable<lite::Tensor>();
param_.emb_fw_out = scope->FindVar(op_desc.Output("emb_fw_out").front())
->GetMutable<lite::Tensor>();
param_.grnn_fw_wh_maxs =
op_desc.GetAttr<std::vector<float>>("grnn_fw_wh_maxs");
param_.grnn_fw_wi_maxs =
op_desc.GetAttr<std::vector<float>>("grnn_fw_wi_maxs");
param_.grnn_rv_wh_maxs =
op_desc.GetAttr<std::vector<float>>("grnn_rv_wh_maxs");
param_.grnn_rv_wi_maxs =
op_desc.GetAttr<std::vector<float>>("grnn_rv_wi_maxs");
param_.att_fc_w_max = op_desc.GetAttr<float>("att_fc_w_max");
return true;
}
bool XPUMmdnnBidEmbAttOp::CheckShape() const { return true; }
bool XPUMmdnnBidEmbAttOp::InferShapeImpl() const {
auto& id_dims = param_.id0->dims();
auto& id_lod = param_.id0->lod()[0];
auto& emb_tbl_dims = param_.emb_tbl->dims();
param_.att_pool_out->Resize({(int64_t)id_lod.size() - 1, emb_tbl_dims[1]});
param_.emb_fw_out->Resize({id_dims[0], emb_tbl_dims[1]});
param_.emb_fw_out->set_lod({id_lod});
return true;
}
bool XPUMmdnnBidEmbAttOp::AttachImpl(const cpp::OpDesc& op_desc,
lite::Scope* scope) {
param_.id0 =
scope->FindVar(op_desc.Input("id0").front())->GetMutable<lite::Tensor>();
param_.id1 =
scope->FindVar(op_desc.Input("id1").front())->GetMutable<lite::Tensor>();
param_.emb_tbl = scope->FindVar(op_desc.Input("emb_tbl").front())
->GetMutable<lite::Tensor>();
param_.att_fc_w = scope->FindVar(op_desc.Input("att_fc_w").front())
->GetMutable<lite::Tensor>();
param_.att_fc_b = scope->FindVar(op_desc.Input("att_fc_b").front())
->GetMutable<lite::Tensor>();
param_.att_pool_out = scope->FindVar(op_desc.Output("att_pool_out").front())
->GetMutable<lite::Tensor>();
param_.emb_fw_out = scope->FindVar(op_desc.Output("emb_fw_out").front())
->GetMutable<lite::Tensor>();
param_.att_fc_w_max = op_desc.GetAttr<float>("att_fc_w_max");
return true;
}
bool XPUMmdnnMatchConvTopkOp::CheckShape() const { return true; }
bool XPUMmdnnMatchConvTopkOp::InferShapeImpl() const {
int channel_num = param_.channel_num;
std::vector<int> topks = param_.topks;
auto row_dim = param_.input_x->dims();
auto num_k = topks.size();
auto row_shape_0 = row_dim[0];
std::vector<int64_t> vec_out_shape;
vec_out_shape.push_back(row_shape_0);
vec_out_shape.push_back(channel_num * num_k);
param_.topk_out->Resize(lite::DDim(vec_out_shape));
param_.topk_out->set_lod(param_.input_x->lod());
return true;
}
bool XPUMmdnnMatchConvTopkOp::AttachImpl(const cpp::OpDesc& op_desc,
lite::Scope* scope) {
param_.input_x = scope->FindVar(op_desc.Input("input_x").front())
->GetMutable<lite::Tensor>();
param_.input_y = scope->FindVar(op_desc.Input("input_y").front())
->GetMutable<lite::Tensor>();
param_.input_w = scope->FindVar(op_desc.Input("input_w").front())
->GetMutable<lite::Tensor>();
param_.conv_w = scope->FindVar(op_desc.Input("conv_w").front())
->GetMutable<lite::Tensor>();
param_.topk_out = scope->FindVar(op_desc.Output("topk_out").front())
->GetMutable<lite::Tensor>();
param_.input_w_max = op_desc.GetAttr<float>("input_w_max");
param_.conv_w_max = op_desc.GetAttr<float>("conv_w_max");
param_.topks = op_desc.GetAttr<std::vector<int>>("topks");
param_.channel_num = op_desc.GetAttr<int>("channel_num");
param_.dim_t = op_desc.GetAttr<int>("dim_t");
return true;
}
bool XPUMmdnnMergeAllOp::CheckShape() const { return true; }
bool XPUMmdnnMergeAllOp::InferShapeImpl() const {
int64_t dim0 = param_.concat_7in1_x[0]->dims()[0];
int64_t dim1 = param_.fc2_w->dims()[0];
std::vector<int64_t> vec_out_shape;
vec_out_shape.push_back(dim0);
vec_out_shape.push_back(dim1);
param_.out->Resize(lite::DDim(vec_out_shape));
return true;
}
bool XPUMmdnnMergeAllOp::AttachImpl(const cpp::OpDesc& op_desc,
lite::Scope* scope) {
param_.concat_7in1_x.clear();
for (auto& name : op_desc.Input("concat_7in1_x")) {
auto t = scope->FindVar(name)->GetMutable<lite::Tensor>();
param_.concat_7in1_x.push_back(t);
}
param_.concat_2in1_x.clear();
for (auto& name : op_desc.Input("concat_2in1_x")) {
auto t = scope->FindVar(name)->GetMutable<lite::Tensor>();
param_.concat_2in1_x.push_back(t);
}
param_.grnn_fw_wh = scope->FindVar(op_desc.Input("grnn_fw_wh").front())
->GetMutable<lite::Tensor>();
param_.grnn_fw_wi = scope->FindVar(op_desc.Input("grnn_fw_wi").front())
->GetMutable<lite::Tensor>();
param_.grnn_rv_wh = scope->FindVar(op_desc.Input("grnn_rv_wh").front())
->GetMutable<lite::Tensor>();
param_.grnn_rv_wi = scope->FindVar(op_desc.Input("grnn_rv_wi").front())
->GetMutable<lite::Tensor>();
param_.fc0_w = scope->FindVar(op_desc.Input("fc0_w").front())
->GetMutable<lite::Tensor>();
param_.fc0_b = scope->FindVar(op_desc.Input("fc0_b").front())
->GetMutable<lite::Tensor>();
param_.fc1_w = scope->FindVar(op_desc.Input("fc1_w").front())
->GetMutable<lite::Tensor>();
param_.fc1_b = scope->FindVar(op_desc.Input("fc1_b").front())
->GetMutable<lite::Tensor>();
param_.fc2_w = scope->FindVar(op_desc.Input("fc2_w").front())
->GetMutable<lite::Tensor>();
param_.fc2_b = scope->FindVar(op_desc.Input("fc2_b").front())
->GetMutable<lite::Tensor>();
param_.out =
scope->FindVar(op_desc.Output("out").front())->GetMutable<lite::Tensor>();
param_.grnn_fw_wh_maxs =
op_desc.GetAttr<std::vector<float>>("grnn_fw_wh_maxs");
param_.grnn_fw_wi_maxs =
op_desc.GetAttr<std::vector<float>>("grnn_fw_wi_maxs");
param_.grnn_rv_wh_maxs =
op_desc.GetAttr<std::vector<float>>("grnn_rv_wh_maxs");
param_.grnn_rv_wi_maxs =
op_desc.GetAttr<std::vector<float>>("grnn_rv_wi_maxs");
param_.fc0_w_max = op_desc.GetAttr<float>("fc0_w_max");
param_.fc1_w_max = op_desc.GetAttr<float>("fc1_w_max");
param_.fc2_w_max = op_desc.GetAttr<float>("fc2_w_max");
return true;
}
} // namespace operators
} // namespace lite
} // namespace paddle
REGISTER_LITE_OP(__xpu__mmdnn_bid_emb_grnn_att,
paddle::lite::operators::XPUMmdnnBidEmbGrnnAttOp);
REGISTER_LITE_OP(__xpu__mmdnn_bid_emb_att,
paddle::lite::operators::XPUMmdnnBidEmbAttOp);
REGISTER_LITE_OP(__xpu__mmdnn_match_conv_topk,
paddle::lite::operators::XPUMmdnnMatchConvTopkOp);
REGISTER_LITE_OP(__xpu__mmdnn_merge_all,
paddle::lite::operators::XPUMmdnnMergeAllOp);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "lite/core/op_lite.h"
namespace paddle {
namespace lite {
namespace operators {
class XPUMmdnnBidEmbGrnnAttOp : public OpLite {
public:
XPUMmdnnBidEmbGrnnAttOp() {}
explicit XPUMmdnnBidEmbGrnnAttOp(const std::string &op_type)
: OpLite(op_type) {}
bool CheckShape() const override;
bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override { return "XPUMmdnnBidEmbGrnnAttOp"; }
private:
mutable XPUMmdnnBidEmbGrnnAttParam param_;
};
class XPUMmdnnBidEmbAttOp : public OpLite {
public:
XPUMmdnnBidEmbAttOp() {}
explicit XPUMmdnnBidEmbAttOp(const std::string &op_type) : OpLite(op_type) {}
bool CheckShape() const override;
bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override { return "XPUMmdnnBidEmbAttOp"; }
private:
mutable XPUMmdnnBidEmbAttParam param_;
};
class XPUMmdnnMatchConvTopkOp : public OpLite {
public:
XPUMmdnnMatchConvTopkOp() {}
explicit XPUMmdnnMatchConvTopkOp(const std::string &op_type)
: OpLite(op_type) {}
bool CheckShape() const override;
bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override { return "XPUMmdnnMatchConvTopkOp"; }
private:
mutable XPUMmdnnMatchConvTopkParam param_;
};
class XPUMmdnnMergeAllOp : public OpLite {
public:
XPUMmdnnMergeAllOp() {}
explicit XPUMmdnnMergeAllOp(const std::string &op_type) : OpLite(op_type) {}
bool CheckShape() const override;
bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override { return "XPUMmdnnMergeAllOp"; }
private:
mutable XPUMmdnnMergeAllParam param_;
};
} // namespace operators
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/__xpu__resnet_cbam_op.h"
#include <vector>
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace operators {
bool XPUResNetCbamOp::CheckShape() const { return true; }
bool XPUResNetCbamOp::InferShapeImpl() const {
auto input_shape = param_.input->dims();
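// The fused cbam backbone yields a fixed 64-dim feature per sample; only the
// batch dimension is carried over from the input.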
std::vector<int64_t> output_shape_vec{1, 64};
paddle::lite::DDim output_shape(output_shape_vec);
output_shape[0] = input_shape[0];
param_.output->Resize(output_shape);
return true;
}
bool XPUResNetCbamOp::AttachImpl(const cpp::OpDesc& op_desc,
lite::Scope* scope) {
param_.input = const_cast<lite::Tensor*>(
&scope->FindVar(op_desc.Input("Input").front())->Get<lite::Tensor>());
param_.output = scope->FindVar(op_desc.Output("Output").front())
->GetMutable<lite::Tensor>();
param_.filter.clear();
for (auto& name : op_desc.Input("Filter")) {
auto t =
const_cast<lite::Tensor*>(&scope->FindVar(name)->Get<lite::Tensor>());
param_.filter.push_back(t);
}
param_.bias.clear();
for (auto& name : op_desc.Input("Bias")) {
if (name.substr(0, 11) == "placeholder") {
param_.bias.push_back(nullptr);
} else {
auto t =
const_cast<lite::Tensor*>(&scope->FindVar(name)->Get<lite::Tensor>());
param_.bias.push_back(t);
}
}
param_.max_filter.clear();
for (auto& name : op_desc.Input("MaxFilter")) {
auto t =
const_cast<lite::Tensor*>(&scope->FindVar(name)->Get<lite::Tensor>());
param_.max_filter.push_back(t);
}
param_.pool_p = op_desc.GetAttr<float>("pool_p");
return true;
}
} // namespace operators
} // namespace lite
} // namespace paddle
REGISTER_LITE_OP(__xpu__resnet_cbam, paddle::lite::operators::XPUResNetCbamOp);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "lite/core/op_lite.h"
namespace paddle {
namespace lite {
namespace operators {
class XPUResNetCbamOp : public OpLite {
public:
XPUResNetCbamOp() {}
explicit XPUResNetCbamOp(const std::string &op_type) : OpLite(op_type) {}
bool CheckShape() const override;
bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override { return "ResNetCbam"; }
private:
mutable XPUResNetCbamParam param_;
};
} // namespace operators
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/__xpu__search_attention_op.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace operators {
bool XPUMmdnnSearchAttentionOp::CheckShape() const { return true; }
bool XPUMmdnnSearchAttentionOp::InferShapeImpl() const {
auto& x_dims = param_.X->dims();
param_.Out->Resize(x_dims);
param_.Out->set_lod(param_.X->lod());
return true;
}
bool XPUMmdnnSearchAttentionOp::AttachImpl(const cpp::OpDesc& op_desc,
lite::Scope* scope) {
auto x = op_desc.Input("X").front();
auto w = op_desc.Input("W").front();
auto b = op_desc.Input("b").front();
auto out = op_desc.Output("Out").front();
param_.X = scope->FindVar(x)->GetMutable<lite::Tensor>();
param_.W = scope->FindVar(w)->GetMutable<lite::Tensor>();
param_.b = scope->FindVar(b)->GetMutable<lite::Tensor>();
param_.Out = scope->FindVar(out)->GetMutable<lite::Tensor>();
param_.W_max = op_desc.GetAttr<float>("W_max");
param_.pad_id = op_desc.GetAttr<int>("pad_id");
param_.alpha0 = op_desc.GetAttr<float>("alpha0");
param_.alpha1 = op_desc.GetAttr<float>("alpha1");
param_.mask = op_desc.GetAttr<float>("mask");
return true;
}
} // namespace operators
} // namespace lite
} // namespace paddle
REGISTER_LITE_OP(__xpu__mmdnn_search_attention,
paddle::lite::operators::XPUMmdnnSearchAttentionOp);
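InferShapeImpl above only copies X's dims and LoD onto Out, so the fused attention op never changes sequence boundaries. For reference, a minimal sketch of what those LoD offsets look like for a toy batch of three sequences (standard C++ only; the lengths are made up):
// Sketch: LoD offsets for sequences of lengths 2, 3 and 1; the op keeps these
// offsets unchanged on its output tensor.
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  std::vector<uint64_t> lod = {0};
  for (uint64_t len : {2, 3, 1}) {
    lod.push_back(lod.back() + len);
  }
  for (uint64_t v : lod) std::cout << v << " ";  // prints: 0 2 5 6
  std::cout << "\n";
  return 0;
}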
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "lite/core/op_lite.h"
namespace paddle {
namespace lite {
namespace operators {
class XPUMmdnnSearchAttentionOp : public OpLite {
public:
XPUMmdnnSearchAttentionOp() {}
explicit XPUMmdnnSearchAttentionOp(const std::string &op_type)
: OpLite(op_type) {}
bool CheckShape() const override;
bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override {
return "XPUMmdnnSearchAttentionOp";
}
private:
mutable XPUMmdnnSearchAttentionParam param_;
};
} // namespace operators
} // namespace lite
} // namespace paddle
...@@ -145,6 +145,7 @@ REGISTER_LITE_OP(elementwise_mul, paddle::lite::operators::ElementwiseOp);
REGISTER_LITE_OP(elementwise_max, paddle::lite::operators::ElementwiseOp);
REGISTER_LITE_OP(elementwise_div, paddle::lite::operators::ElementwiseOp);
REGISTER_LITE_OP(elementwise_mod, paddle::lite::operators::ElementwiseOp);
REGISTER_LITE_OP(elementwise_pow, paddle::lite::operators::ElementwiseOp);
// #ifdef LITE_WITH_TRAIN
// REGISTER_LITE_OP(elementwise_sub_grad,
...
...@@ -94,6 +94,18 @@ bool MatchMatrixTensorOpLite::AttachImpl(const cpp::OpDesc& op_desc,
  param_.dim_t = op_desc.GetAttr<int32_t>("dim_t");
if (op_desc.HasAttr("fuse_relu")) {
param_.fuse_relu = op_desc.GetAttr<bool>("fuse_relu");
}
#ifdef LITE_WITH_XPU
if (op_desc.HasAttr("__xpu__float_to_fix")) {
param_.__xpu__float_to_fix = op_desc.GetAttr<bool>("__xpu__float_to_fix");
}
if (op_desc.HasAttr("__xpu__w_max")) {
param_.__xpu__w_max = op_desc.GetAttr<float>("__xpu__w_max");
}
#endif
  return true;
}
...
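The __xpu__float_to_fix and __xpu__w_max attributes picked up here (and in the search_fc, search_grnn and var_conv_2d hunks below) describe weights that were converted offline to int16/int8, with only the absolute maximum kept so a kernel can recover the scale. A rough sketch of such a symmetric abs-max scheme, offered purely as an assumption about the convention (the actual conversion is done by the XPU optimization passes and is not part of this diff):
// Hedged sketch of symmetric abs-max int16 quantization; details may differ
// from the real XPU conversion.
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  std::vector<float> w = {0.50f, -1.25f, 0.75f};
  float w_max = 0.0f;  // the role played by __xpu__w_max
  for (float v : w) w_max = std::max(w_max, std::fabs(v));
  std::vector<int16_t> w_fix;  // what the weight would hold when __xpu__float_to_fix is true
  for (float v : w) {
    w_fix.push_back(static_cast<int16_t>(std::round(v / w_max * 32767.0f)));
  }
  // Dequantization would then be roughly: v ~ w_fix[i] * w_max / 32767.
  std::cout << "w_max = " << w_max << ", w_fix[1] = " << w_fix[1] << "\n";
  return 0;
}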
...@@ -1129,6 +1129,11 @@ struct VarConv2DParam : ParamBase {
  int kernel_w;
  bool fuse_relu{false};
#ifdef LITE_WITH_XPU
bool __xpu__float_to_fix{false}; // Is W already converted to int16/int8
float __xpu__w_max{0.0f}; // Abs max in W
#endif
};
/// ----------------------- shape operators ----------------------
...@@ -1378,6 +1383,13 @@ struct SearchFcParam : ParamBase {
  const lite::Tensor* b{};
  lite::Tensor* Out{};
  int out_size{};
bool fuse_relu{false};
#ifdef LITE_WITH_XPU
bool __xpu__float_to_fix{false}; // Is W already converted to int16/int8
float __xpu__w_max{0.0f}; // Abs max in W
#endif
};
/// --------------------- match_matrix_tensor operators --------------------
struct MatchMatrixTensorParam : ParamBase {
...@@ -1388,6 +1400,12 @@ struct MatchMatrixTensorParam : ParamBase {
  lite::Tensor* tmp{};
  int dim_t;
bool fuse_relu{false};
#ifdef LITE_WITH_XPU
bool __xpu__float_to_fix{false}; // Is w already converted to int16/int8
float __xpu__w_max{0.0f}; // Abs max in w
#endif
};
/// --------------------- search_seq_depadding operators --------------------
...@@ -1409,6 +1427,12 @@ struct SearchGrnnParam : ParamBase {
  lite::Tensor* tmp_buffer{};
  lite::Tensor* idx_sorted_by_width{};
  lite::Tensor* layout_input{};
#ifdef LITE_WITH_XPU
bool __xpu__float_to_fix{false}; // Is wi/wh already converted to int16/int8
std::vector<float> __xpu__wi_max; // Abs max in wi
std::vector<float> __xpu__wh_max; // Abs max in wh
#endif
};
struct SplitLodTensorParam : ParamBase {
...@@ -1563,6 +1587,106 @@ struct XPUFcParam : ParamBase {
  std::string activation_type{""};
};
struct XPUResNetCbamParam : ParamBase {
lite::Tensor* input{};
std::vector<lite::Tensor*> filter;
std::vector<lite::Tensor*> bias;
std::vector<lite::Tensor*> max_filter;
lite::Tensor* output{};
float pool_p{1.0f};
};
struct XPUMmdnnSearchAttentionParam : ParamBase {
lite::Tensor* X{};
lite::Tensor* W{};
lite::Tensor* b{};
lite::Tensor* Out{};
float W_max{0.0f};
int pad_id{0};
float alpha0{1.0f};
float alpha1{1.0f};
float mask{1.0f};
};
struct XPUMmdnnBidEmbGrnnAttParam : ParamBase {
lite::Tensor* id0{};
lite::Tensor* id1{};
lite::Tensor* emb_tbl{};
lite::Tensor* grnn_fw_wh{};
lite::Tensor* grnn_fw_wi{};
lite::Tensor* grnn_rv_wh{};
lite::Tensor* grnn_rv_wi{};
lite::Tensor* att_fc_w{};
lite::Tensor* att_fc_b{};
std::vector<float> grnn_fw_wh_maxs;
std::vector<float> grnn_fw_wi_maxs;
std::vector<float> grnn_rv_wh_maxs;
std::vector<float> grnn_rv_wi_maxs;
float att_fc_w_max{0.0f};
lite::Tensor* grnn_fw_pool_out{}; // 1
lite::Tensor* grnn_rv_pool_out{}; // 2
lite::Tensor* att_pool_out{}; // 3
lite::Tensor* concat_3in1_out{}; // 4
lite::Tensor* emb_fw_out{}; // 5
};
struct XPUMmdnnBidEmbAttParam : ParamBase {
lite::Tensor* id0{};
lite::Tensor* id1{};
lite::Tensor* emb_tbl{};
lite::Tensor* att_fc_w{};
lite::Tensor* att_fc_b{};
float att_fc_w_max{0.0f};
lite::Tensor* att_pool_out{}; // 1
lite::Tensor* emb_fw_out{}; // 2
};
struct XPUMmdnnMatchConvTopkParam : ParamBase {
lite::Tensor* input_x{};
lite::Tensor* input_y{};
lite::Tensor* input_w{};
lite::Tensor* conv_w{};
float input_w_max{0.0f};
float conv_w_max{0.0f};
std::vector<int> topks;
int channel_num{0};
int dim_t{0};
lite::Tensor* topk_out{};
};
struct XPUMmdnnMergeAllParam : ParamBase {
std::vector<lite::Tensor*> concat_7in1_x;
std::vector<lite::Tensor*> concat_2in1_x;
lite::Tensor* grnn_fw_wh{};
lite::Tensor* grnn_fw_wi{};
lite::Tensor* grnn_rv_wh{};
lite::Tensor* grnn_rv_wi{};
lite::Tensor* fc0_w{};
lite::Tensor* fc0_b{};
lite::Tensor* fc1_w{};
lite::Tensor* fc1_b{};
lite::Tensor* fc2_w{};
lite::Tensor* fc2_b{};
std::vector<float> grnn_fw_wh_maxs;
std::vector<float> grnn_fw_wi_maxs;
std::vector<float> grnn_rv_wh_maxs;
std::vector<float> grnn_rv_wi_maxs;
float fc0_w_max{0.0f};
float fc1_w_max{0.0f};
float fc2_w_max{0.0f};
lite::Tensor* out{};
};
// For DeformableConvolution op
struct DeformableConvParam : ParamBase {
  lite::Tensor* x{};
...
...@@ -70,6 +70,18 @@ bool SearchFcOpLite::AttachImpl(const cpp::OpDesc &op_desc,
  param_.Out = scope->FindVar(Out)->GetMutable<lite::Tensor>();
  param_.out_size = op_desc.GetAttr<int>("out_size");
if (op_desc.HasAttr("fuse_relu")) {
param_.fuse_relu = op_desc.GetAttr<bool>("fuse_relu");
}
#ifdef LITE_WITH_XPU
if (op_desc.HasAttr("__xpu__float_to_fix")) {
param_.__xpu__float_to_fix = op_desc.GetAttr<bool>("__xpu__float_to_fix");
}
if (op_desc.HasAttr("__xpu__w_max")) {
param_.__xpu__w_max = op_desc.GetAttr<float>("__xpu__w_max");
}
#endif
  return true;
}
...
...@@ -13,6 +13,7 @@
// limitations under the License.
#include "lite/operators/search_grnn_op.h"
#include <vector>
#include "lite/core/op_lite.h" #include "lite/core/op_lite.h"
#include "lite/core/op_registry.h" #include "lite/core/op_registry.h"
...@@ -84,6 +85,18 @@ bool SearchGrnnOpLite::AttachImpl(const cpp::OpDesc& op_desc, ...@@ -84,6 +85,18 @@ bool SearchGrnnOpLite::AttachImpl(const cpp::OpDesc& op_desc,
param_.layout_input = param_.layout_input =
scope->FindVar(layout_input)->GetMutable<lite::Tensor>(); scope->FindVar(layout_input)->GetMutable<lite::Tensor>();
#ifdef LITE_WITH_XPU
if (op_desc.HasAttr("__xpu__float_to_fix")) {
param_.__xpu__float_to_fix = op_desc.GetAttr<bool>("__xpu__float_to_fix");
}
if (op_desc.HasAttr("__xpu__wi_max")) {
param_.__xpu__wi_max = op_desc.GetAttr<std::vector<float>>("__xpu__wi_max");
}
if (op_desc.HasAttr("__xpu__wh_max")) {
param_.__xpu__wh_max = op_desc.GetAttr<std::vector<float>>("__xpu__wh_max");
}
#endif
  return true;
}
...
...@@ -34,6 +34,7 @@ bool SequenceReverseOp::InferShapeImpl() const {
  const auto *input = param_.X;
  auto out_dims = input->dims();
  param_.Out->Resize(out_dims);
param_.Out->set_lod(param_.X->lod());
  return true;
}
...@@ -45,6 +46,7 @@ bool SequenceReverseOp::AttachImpl(const cpp::OpDesc &opdesc,
      scope->FindVar(opdesc.Output("Y").front())->GetMutable<lite::Tensor>();
  CHECK(param_.X);
  CHECK(param_.Out);
  return true;
}
...
...@@ -52,6 +52,15 @@ bool VarConv2dOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
  if (opdesc.HasAttr("fuse_relu")) {
    param_.fuse_relu = opdesc.GetAttr<bool>("fuse_relu");
  }
#ifdef LITE_WITH_XPU
if (opdesc.HasAttr("__xpu__float_to_fix")) {
param_.__xpu__float_to_fix = opdesc.GetAttr<bool>("__xpu__float_to_fix");
}
if (opdesc.HasAttr("__xpu__w_max")) {
param_.__xpu__w_max = opdesc.GetAttr<float>("__xpu__w_max");
}
#endif
  return true;
}
...
...@@ -16,6 +16,15 @@ if(LITE_WITH_XPU)
    add_dependencies(test_ernie_lite_xpu extern_lite_download_ernie_tar_gz)
    add_dependencies(test_bert_lite_xpu extern_lite_download_bert_tar_gz)
  endif()
# TODO(miaotianxiang): enable later
#lite_cc_test(test_fpr_lite_xpu SRCS test_fpr_lite_xpu.cc
#DEPS mir_passes lite_api_test_helper paddle_api_full paddle_api_light gflags utils
#${ops} ${host_kernels} ${x86_kernels} ${xpu_kernels}
#ARGS --model_dir=${LITE_MODEL_DIR}/resnet50)
#lite_cc_test(test_mmdnn_lite_xpu SRCS test_mmdnn_lite_xpu.cc
#DEPS mir_passes lite_api_test_helper paddle_api_full paddle_api_light gflags utils
#${ops} ${host_kernels} ${x86_kernels} ${xpu_kernels}
#ARGS --model_dir=${LITE_MODEL_DIR}/resnet50)
endif()
if(LITE_WITH_RKNPU)
...
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gflags/gflags.h>
#include <gtest/gtest.h>
#include <vector>
#include "lite/api/lite_api_test_helper.h"
#include "lite/api/paddle_api.h"
#include "lite/api/paddle_use_kernels.h"
#include "lite/api/paddle_use_ops.h"
#include "lite/api/paddle_use_passes.h"
#include "lite/api/test_helper.h"
#include "lite/utils/cp_logging.h"
namespace paddle {
namespace lite {
TEST(ResnetCbam, test_resnet_cbam_lite_xpu) {
lite_api::CxxConfig config;
// config.set_model_dir(FLAGS_model_dir);
config.set_model_file(FLAGS_model_dir + "/__model__");
config.set_param_file(FLAGS_model_dir + "/__params__");
config.set_valid_places({lite_api::Place{TARGET(kXPU), PRECISION(kFloat)},
lite_api::Place{TARGET(kX86), PRECISION(kFloat)},
lite_api::Place{TARGET(kHost), PRECISION(kFloat)}});
config.set_xpu_workspace_l3_size_per_thread();
auto predictor = lite_api::CreatePaddlePredictor(config);
auto input_tensor = predictor->GetInput(0);
std::vector<int64_t> input_shape{1, 3, 224, 224};
input_tensor->Resize(input_shape);
auto* data = input_tensor->mutable_data<float>();
int input_num = 1;
for (size_t i = 0; i < input_shape.size(); ++i) {
input_num *= input_shape[i];
}
for (int i = 0; i < input_num; i++) {
data[i] = 1;
}
for (int i = 0; i < FLAGS_warmup; ++i) {
predictor->Run();
}
auto start = GetCurrentUS();
for (int i = 0; i < FLAGS_repeats; ++i) {
predictor->Run();
}
LOG(INFO) << "================== Speed Report ===================";
LOG(INFO) << "Model: " << FLAGS_model_dir << ", threads num " << FLAGS_threads
<< ", warmup: " << FLAGS_warmup << ", repeats: " << FLAGS_repeats
<< ", spend " << (GetCurrentUS() - start) / FLAGS_repeats / 1000.0
<< " ms in average.";
}
} // namespace lite
} // namespace paddle
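The element count above is accumulated with an explicit loop; an equivalent helper using <numeric>, matching the std::accumulate pattern the MMDNN test below applies to out_size, would be (illustrative only, not part of the test sources):
// Illustrative element-count helper under the same shape convention.
#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

int64_t NumElements(const std::vector<int64_t>& shape) {
  return std::accumulate(shape.begin(), shape.end(), static_cast<int64_t>(1),
                         std::multiplies<int64_t>());
}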
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gflags/gflags.h>
#include <gtest/gtest.h>
#include <vector>
#include "lite/api/lite_api_test_helper.h"
#include "lite/api/paddle_api.h"
#include "lite/api/paddle_use_kernels.h"
#include "lite/api/paddle_use_ops.h"
#include "lite/api/paddle_use_passes.h"
#include "lite/api/test_helper.h"
#include "lite/utils/cp_logging.h"
#include "lite/utils/string.h"
DEFINE_bool(perf, false, "perf?");
DEFINE_string(perf_input, "perf_input", "perf_input");
namespace paddle {
namespace lite {
std::vector<int64_t> input0;
std::vector<uint64_t> input0_lod = {0};
std::vector<int64_t> input1;
std::vector<uint64_t> input1_lod = {0};
std::vector<int64_t> input2;
std::vector<uint64_t> input2_lod = {0};
std::vector<int64_t> input3;
std::vector<uint64_t> input3_lod = {0};
std::vector<int64_t> input4;
std::vector<uint64_t> input4_lod = {0};
std::vector<int64_t> input5;
std::vector<uint64_t> input5_lod = {0};
void ParseInput() {
std::string raw_input =
"0 1;145 10251 839 3719 428 52;1050488 1050488 911898 3719 760166 "
"760166;3719 428 52 18 1102 10327 252 20 153 2897 1146 70 156 6 145 "
"10251 839 5 1779 1729 1779 1729 18 2707 6 2707 20 4742 4937 432 6 "
"3869;3719 760166 760166 18 1035176 1035176 764393 764393 1259006 767614 "
"767614 1020808 769579 793958 793958 1050488 911898 751332 751332 750336 "
"750799 750336 751575 751575 751544 751735 751397 751365 751512 751512 "
"753011 751562;3719 428 52 18 1102 10327 252 20 153 2897 1146 70 156 6 "
"145 10251 839 2 1211 3 3719 720 1540 145 10251 839 9405 4315 5998 4 2 "
"600 373 41 3719 428 52 44 10251 4302 1319 7 12 2 768 6 918 6 841 870 8 "
"843 8 271;3719 760166 760166 18 1035176 1035176 764393 764393 1259006 "
"767614 767614 1020808 769579 793958 793958 1050488 911898 2 773899 "
"773899 3719 1118420 1118420 1050488 1050488 911898 9405 4315 5998 4 2 "
"785435 785435 41 3719 760166 760166 44 10251 4302 1319 750118 750118 2 "
"750465 750465 750274 750398 750233 751252 751252 753447 752830 753112;\n"
"0 0;145 10251 839 3719 428 52;1050488 1050488 911898 3719 760166 "
"760166;2109 2467 1805 227 3719 428 52 18 1102 10327 252 20 6 242 78 6 "
"532 78;2109 2467 1805 1245431 1245431 760166 760166 18 1035176 1035176 "
"764393 764393 752116 242 750370 750370 752081 751247;2109 2467 1805 227 "
"3719 428 52 18 1102 10327 252 20 2 145 242 1050 252 3582 2212;2109 2467 "
"1805 1245431 1245431 760166 760166 18 1035176 1035176 764393 764393 2 "
"871717 871717 757921 757921 3582 2212;\n"
"0 0;145 10251 839 3719 428 52;1050488 1050488 911898 3719 760166 "
"760166;145 10251 839 76 31 1337 823 7506 567 65 170 8 21293 3719 5 43 "
"394 743 42;1050488 1050488 911898 750016 750016 1337 823 7506 762617 "
"762617 866652 8 21293 3719 5 43 914758 914758 757202;145 10251 839 76 "
"31 1337 823 7506 567 65 170 8 21293 3719 2 17580 30 523324 3 10251 4104 "
"281 3 8511 3719 2217 3 13 226 3083 4 11251 1606 357 9 2 145 10251 839 "
"76 31 1337 823 7506 567 65 170 2 7506 2445 8 145 10251 839 528 839 "
"19670 6538;1050488 1050488 911898 750016 750016 1337 823 7506 762617 "
"762617 866652 8 21293 3719 2 816626 816626 523324 3 1181698 1181698 "
"751656 780821 1063148 3719 2217 3 752498 752498 831323 753602 11251 "
"1606 357 9 2 1050488 1050488 911898 750016 750016 1337 823 7506 762617 "
"762617 866652 2 7506 753045 753045 756756 1050488 911898 528 839 19670 "
"6538;\n"
"0 0;145 10251 839 3719 428 52;1050488 1050488 911898 3719 760166 "
"760166;145 10251 839 99 4 1102 10327 2196 41 3719 428 52 44 99 4 2899 "
"229 10 10 10;1050488 1050488 911898 807966 750273 1035176 1035176 "
"1237875 41 3719 760166 760166 753645 753645 750273 2899 229 750001 "
"750001 750001;145 10251 839 99 4 1102 10327 2196 41 3719 428 52 44 99 4 "
"2899 229 10 10 10 2 1177 8 145 10251 839 99 4 1102 10327 2196 41 3719 "
"428 52 44 99 4 2 101 8 1922 17 2184 2 1154 1922 72 1198 1266 "
"4516;1050488 1050488 911898 807966 750273 1035176 1035176 1237875 41 "
"3719 760166 760166 753645 753645 750273 2899 229 750001 750001 750001 2 "
"750257 750257 756756 1050488 911898 807966 750273 1035176 1035176 "
"1237875 41 3719 760166 760166 753645 753645 750273 2 764513 764513 "
"851213 851213 854628 2 753018 753018 754317 753328 754085 754070;\n"
"0 0;145 10251 839 3719 428 52;1050488 1050488 911898 3719 760166 "
"760166;73 5347 112 8 145 10251 839 262 169 22729 3719 6 743 6 339 1156 "
"78 136 399 693 128 571;776150 776150 112 756756 756756 1050488 911898 "
"791355 791355 22729 3719 6 758277 758277 750137 750234 750241 750178 "
"750055 750216 750212 750049;73 5347 112 8 145 10251 839 262 169 22729 "
"3719 2 588 415 549 415 115 23;776150 776150 112 756756 756756 1050488 "
"911898 791355 791355 22729 3719 2 750221 750221 750262 750277 750277 "
"750261;";
auto raw_lines = Split(raw_input, "\n");
for (auto& raw_line : raw_lines) {
auto inputx = Split(raw_line, ";");
for (size_t i = 1; i < inputx.size(); ++i) {
auto tokens = Split(inputx[i], " ");
static std::vector<int64_t>* const input_array[] = {
&input0, &input0, &input1, &input2, &input3, &input4, &input5};
static std::vector<uint64_t>* const lod_array[] = {&input0_lod,
&input0_lod,
&input1_lod,
&input2_lod,
&input3_lod,
&input4_lod,
&input5_lod};
for (auto token : tokens) {
input_array[i]->push_back((int64_t)atoi(token.c_str()));
}
lod_array[i]->push_back((uint64_t)tokens.size() +
(*lod_array[i])[lod_array[i]->size() - 1]);
}
}
return;
}
class MmdnnReader {
std::ifstream ifs;
std::vector<std::string> StringSplit(const std::string& in,
const std::string& delim) {
std::vector<std::string> ret;
if (in == "") {
return ret;
}
auto begpos = in.find_first_not_of(delim);
while (begpos != std::string::npos) {
auto endpos = in.find_first_of(delim, begpos);
if (endpos == std::string::npos) {
endpos = in.size();
}
std::string ssubstr = in.substr(begpos, endpos - begpos);
ret.push_back(ssubstr);
begpos = endpos + 1;
if (endpos >= (in.size() - 1)) {
break;
}
}
return ret;
}
public:
std::vector<int64_t> data[6];
std::vector<uint64_t> lod[6];
void Init(std::string file_name) { ifs.open(file_name); }
int Read(int maxline) {
for (int i = 0; i < 6; i++) {
data[i].clear();
}
for (int i = 0; i < 6; i++) {
lod[i].clear();
lod[i].push_back(0);
}
std::string line;
int cnt = 0;
while (cnt < maxline && getline(ifs, line)) {
std::vector<std::string> split1 = StringSplit(line, ";");
for (int i = 1; i < 7; i++) {
std::vector<std::string> split2 = StringSplit(split1[i], " ");
if (split2.size() == 0) {
split2.push_back("1280000");
}
for (size_t j = 0; j < split2.size(); j++) {
data[i - 1].push_back(std::stoi(split2[j].c_str(), nullptr, 0));
}
// if (i % 2 == 1) {
// lod[i / 2].push_back(lod[i / 2].back() + split2.size());
//}
lod[i - 1].push_back(lod[i - 1].back() + split2.size());
}
cnt++;
}
return cnt;
}
};
TEST(MMDNN, test_mmdnn_lite_xpu) {
lite_api::CxxConfig config;
config.set_model_dir(FLAGS_model_dir);
config.set_valid_places({lite_api::Place{TARGET(kXPU), PRECISION(kFloat)},
lite_api::Place{TARGET(kXPU), PRECISION(kInt64)},
lite_api::Place{TARGET(kX86), PRECISION(kFloat)},
lite_api::Place{TARGET(kX86), PRECISION(kInt64)},
lite_api::Place{TARGET(kHost), PRECISION(kFloat)}});
config.set_xpu_workspace_l3_size_per_thread();
auto predictor = lite_api::CreatePaddlePredictor(config);
if (FLAGS_perf) {
MmdnnReader reader;
reader.Init(FLAGS_perf_input);
int UB_batch = 40; // upper bound of batch
int iter = 0;
double tsc_sum = 0;
while (true) {
int batch = reader.Read(UB_batch);
if (batch <= 0) {
break;
}
++iter;
for (int i = 0; i < 6; ++i) {
auto input_x = predictor->GetInput(i);
input_x->Resize({(int64_t)reader.data[i].size(), 1});
input_x->SetLoD({reader.lod[i]});
auto* data_x = input_x->mutable_data<int64_t>();
memcpy(data_x,
reader.data[i].data(),
reader.data[i].size() * sizeof(int64_t));
}
auto start = GetCurrentUS();
predictor->Run();
auto end = GetCurrentUS();
tsc_sum += end - start;
}
LOG(INFO) << "================== Speed Report ===================";
LOG(INFO) << "Model: " << FLAGS_model_dir << ", threads num "
<< FLAGS_threads << ", warmup: " << FLAGS_warmup
<< ", repeats: " << iter << ", spend " << tsc_sum / iter / 1000.0
<< " ms in average.";
return;
}
ParseInput();
{
std::vector<int64_t> input0_shape{(int64_t)input0.size(), 1};
auto input_tensor0 = predictor->GetInput(0);
input_tensor0->Resize(input0_shape);
input_tensor0->SetLoD({input0_lod});
auto* data0 = input_tensor0->mutable_data<int64_t>();
memcpy(data0, input0.data(), sizeof(int64_t) * input0.size());
}
{
std::vector<int64_t> input1_shape{(int64_t)input1.size(), 1};
auto input_tensor1 = predictor->GetInput(1);
input_tensor1->Resize(input1_shape);
input_tensor1->SetLoD({input1_lod});
auto* data1 = input_tensor1->mutable_data<int64_t>();
memcpy(data1, input1.data(), sizeof(int64_t) * input1.size());
}
{
std::vector<int64_t> input2_shape{(int64_t)input2.size(), 1};
auto input_tensor2 = predictor->GetInput(2);
input_tensor2->Resize(input2_shape);
input_tensor2->SetLoD({input2_lod});
auto* data2 = input_tensor2->mutable_data<int64_t>();
memcpy(data2, input2.data(), sizeof(int64_t) * input2.size());
}
{
std::vector<int64_t> input3_shape{(int64_t)input3.size(), 1};
auto input_tensor3 = predictor->GetInput(3);
input_tensor3->Resize(input3_shape);
input_tensor3->SetLoD({input3_lod});
auto* data3 = input_tensor3->mutable_data<int64_t>();
memcpy(data3, input3.data(), sizeof(int64_t) * input3.size());
}
{
std::vector<int64_t> input4_shape{(int64_t)input4.size(), 1};
auto input_tensor4 = predictor->GetInput(4);
input_tensor4->Resize(input4_shape);
input_tensor4->SetLoD({input4_lod});
auto* data4 = input_tensor4->mutable_data<int64_t>();
memcpy(data4, input4.data(), sizeof(int64_t) * input4.size());
}
{
std::vector<int64_t> input5_shape{(int64_t)input5.size(), 1};
auto input_tensor5 = predictor->GetInput(5);
input_tensor5->Resize(input5_shape);
input_tensor5->SetLoD({input5_lod});
auto* data5 = input_tensor5->mutable_data<int64_t>();
memcpy(data5, input5.data(), sizeof(int64_t) * input5.size());
}
for (int i = 0; i < FLAGS_warmup; ++i) {
predictor->Run();
}
auto start = GetCurrentUS();
for (int i = 0; i < FLAGS_repeats; ++i) {
predictor->Run();
}
auto out = predictor->GetOutput(0);
auto out_shape = out->shape();
auto out_size = std::accumulate(
out_shape.begin(), out_shape.end(), 1, std::multiplies<int64_t>());
for (int i = 0; i < out_size; ++i) {
LOG(INFO) << "out[" << i << "] = " << out->data<float>()[i];
}
LOG(INFO) << "================== Speed Report ===================";
LOG(INFO) << "Model: " << FLAGS_model_dir << ", threads num " << FLAGS_threads
<< ", warmup: " << FLAGS_warmup << ", repeats: " << FLAGS_repeats
<< ", spend " << (GetCurrentUS() - start) / FLAGS_repeats / 1000.0
<< " ms in average.";
}
} // namespace lite
} // namespace paddle
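For readers of ParseInput() and MmdnnReader above: each raw line carries a leading "<label> <flag>" field followed by six space-separated id fields, all joined by ';', and every field's LoD offset grows by that field's token count per line. A toy reproduction of that bookkeeping (standard C++ only; the ids below are made up, not taken from the real input):
// Toy sketch of the field/LoD bookkeeping used by the MMDNN test input parsers.
#include <cstdint>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

static std::vector<std::string> SplitBy(const std::string& s, char delim) {
  std::vector<std::string> parts;
  std::stringstream ss(s);
  std::string item;
  while (std::getline(ss, item, delim)) parts.push_back(item);
  return parts;
}

int main() {
  // One toy line: "<label> <flag>;field1;...;field6".
  std::string line =
      "0 1;145 10251;1050488 1050488;3719 428;760166;18 1102 10327;2 7506";
  std::vector<std::string> fields = SplitBy(line, ';');
  std::vector<int64_t> data0;
  std::vector<uint64_t> lod0 = {0};
  for (const std::string& tok : SplitBy(fields[1], ' ')) {
    data0.push_back(std::stoll(tok));
  }
  lod0.push_back(lod0.back() + data0.size());  // lod0 becomes {0, 2}
  std::cout << "field1 tokens: " << data0.size()
            << ", lod0.back(): " << lod0.back() << "\n";
  return 0;
}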