Unverified  Commit 4780849f  Authored by: Cwndmiao  Committed by: GitHub

[LITE][XPU] Support ResnetCbam and MMDNN (#3844)

* [LITE][XPU] accommodate resnet_cbam

* [LITE][XPU] accommodate content-dnn

* fix pr comments test=develop

* fix pr comments test=develop

* fix pr comments test=develop test=xpu

* fix compilation error, test=develop test=xpu

* [X86] Fix the unit test of slice op
test=develop test=xpu
Co-authored-by: hong19860320 <9973393+hong19860320@users.noreply.github.com>
Parent e6b3d883
......@@ -39,7 +39,7 @@ else()
endif()
find_library(XPU_SDK_XPU_RT_FILE NAMES xpurt
PATHS ${XPU_SDK_ROOT}/XTDK/runtime/shlib
PATHS ${XPU_SDK_ROOT}/XTDK/runtime/shlib ${XPU_SDK_ROOT}/XTDK/shlib # libxpurt.so may have been moved to XTDK/runtime/shlib
NO_DEFAULT_PATH)
if(NOT XPU_SDK_XPU_RT_FILE)
......
......@@ -24,6 +24,9 @@
#ifdef LITE_WITH_CUDA
#include "lite/backends/cuda/target_wrapper.h"
#endif
#ifdef LITE_WITH_XPU
#include "lite/backends/xpu/target_wrapper.h"
#endif
#ifdef LITE_WITH_MLU
#include "lite/backends/mlu/target_wrapper.h"
......@@ -272,7 +275,7 @@ CxxConfig::mlu_firstconv_param() const {
void CxxConfig::set_xpu_workspace_l3_size_per_thread(int l3_size) {
#ifdef LITE_WITH_XPU
lite::Context<TargetType::kXPU>::SetWorkspaceL3Size(l3_size);
lite::TargetWrapperXPU::workspace_l3_size_per_thread = l3_size;
#else
LOG(WARNING) << "The invoking of the function "
"'set_xpu_workspace_l3_size_per_thread' is ignored, please "
......@@ -282,7 +285,7 @@ void CxxConfig::set_xpu_workspace_l3_size_per_thread(int l3_size) {
void CxxConfig::set_xpu_dev_per_thread(int dev_no) {
#ifdef LITE_WITH_XPU
lite::Context<TargetType::kXPU>::SetDev(dev_no);
lite::TargetWrapperXPU::SetDev(dev_no);
#else
LOG(WARNING) << "The invoking of the function 'set_xpu_dev_per_thread' is "
"ignored, please rebuild it with LITE_WITH_XPU=ON.";
......@@ -291,7 +294,7 @@ void CxxConfig::set_xpu_dev_per_thread(int dev_no) {
void CxxConfig::set_xpu_multi_encoder_precision(const std::string &precision) {
#ifdef LITE_WITH_XPU
lite::Context<TargetType::kXPU>::_multi_encoder_precision = precision;
lite::TargetWrapperXPU::multi_encoder_precision = precision;
#else
LOG(WARNING) << "The invoking of the function "
"'set_xpu_multi_encoder_precision' is "
......
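For reference, the three XPU setters touched in this file are typically applied to a CxxConfig before the predictor is created. A minimal sketch, assuming the usual lite/api/paddle_api.h entry points; the model path and the sizes are placeholders:

  #include "lite/api/paddle_api.h"

  paddle::lite_api::CxxConfig config;
  config.set_model_dir("/path/to/model");                 // placeholder path
  config.set_xpu_workspace_l3_size_per_thread(0xfffc00);  // bytes of L3 workspace per thread
  config.set_xpu_dev_per_thread(0);                       // XPU card id (deprecated, see target_wrapper.h)
  config.set_xpu_multi_encoder_precision("int16");        // or "int31"
  auto predictor = paddle::lite_api::CreatePaddlePredictor(config);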
......@@ -55,6 +55,8 @@ USE_MIR_PASS(apu_subgraph_pass);
USE_MIR_PASS(quantized_op_attributes_inference_pass);
USE_MIR_PASS(lite_scale_activation_fuse_pass);
USE_MIR_PASS(__xpu__resnet_fuse_pass);
USE_MIR_PASS(__xpu__resnet_cbam_fuse_pass);
USE_MIR_PASS(__xpu__multi_encoder_fuse_pass);
USE_MIR_PASS(__xpu__embedding_with_eltwise_add_fuse_pass);
USE_MIR_PASS(__xpu__fc_fuse_pass);
USE_MIR_PASS(__xpu__mmdnn_fuse_pass);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cstddef>
#include <cstdio>
#include <memory>
#include <string>
#include <type_traits>
#include "lite/backends/xpu/xpu_header_sitter.h"
namespace paddle {
namespace lite {
namespace xpu {
template <typename T>
void DumpCPUMem(const T* ptr,
size_t len,
const std::string& comment = "",
size_t stride = 1,
size_t item_per_line = 30) {
size_t after_stride_len = (len + stride - 1) / stride;
std::unique_ptr<T[]> after_stride(new T[after_stride_len]);
for (size_t i = 0; i < after_stride_len; ++i) {
after_stride[i] = ptr[i * stride];
}
double sum = 0;
for (size_t i = 0; i < len; ++i) {
sum += ptr[i];
}
printf(
"------------------------------ [%s] len=%zd stride=%zd sum=%f BEGIN "
"------------------------------\n",
comment.c_str(),
len,
stride,
sum);
size_t nline = (after_stride_len + item_per_line - 1) / item_per_line;
for (size_t i = 0; i < nline; ++i) {
size_t line_begin = i * item_per_line;
size_t line_end = line_begin + item_per_line;
printf("line[%04zd] -- ", i);
for (size_t ii = line_begin; (ii < line_end) && (ii < after_stride_len);
++ii) {
if (std::is_same<T, float>::value) {
printf("%.6f, ", static_cast<float>(after_stride[ii]));
} else if (std::is_same<T, int16_t>::value) {
printf("%d ", static_cast<int>(after_stride[ii]));
} else {
// CHECK(false) << "unknown type";
}
}
printf("\n");
}
printf(
"------------------------------ [%s] len=%zd stride=%zd sum=%f END "
"------------------------------\n",
comment.c_str(),
len,
stride,
sum);
}
template <typename T>
void DumpXPUMem(const T* ptr,
size_t len,
const std::string& comment = "",
size_t stride = 1,
size_t item_per_line = 30) {
size_t after_stride_len = (len + stride - 1) / stride;
std::unique_ptr<T[]> cpu_mem(new T[len]);
xpu_memcpy(
cpu_mem.get(), ptr, len * sizeof(T), XPUMemcpyKind::XPU_DEVICE_TO_HOST);
std::unique_ptr<T[]> after_stride(new T[after_stride_len]);
for (size_t i = 0; i < after_stride_len; ++i) {
after_stride[i] = cpu_mem[i * stride];
}
double sum = 0;
for (size_t i = 0; i < len; ++i) {
sum += cpu_mem[i];
}
printf(
"------------------------------ [%s] len=%zd stride=%zd sum=%f BEGIN "
"------------------------------\n",
comment.c_str(),
len,
stride,
sum);
size_t nline = (after_stride_len + item_per_line - 1) / item_per_line;
for (size_t i = 0; i < nline; ++i) {
size_t line_begin = i * item_per_line;
size_t line_end = line_begin + item_per_line;
printf("line[%04zd] -- ", i);
for (size_t ii = line_begin; (ii < line_end) && (ii < after_stride_len);
++ii) {
if (std::is_same<T, float>::value) {
printf("%.6f, ", static_cast<float>(after_stride[ii]));
} else if (std::is_same<T, int16_t>::value) {
printf("%d ", static_cast<int>(after_stride[ii]));
} else {
// CHECK(false) << "unknown type";
}
}
printf("\n");
}
printf(
"------------------------------ [%s] len=%zd stride=%zd sum=%f END "
"------------------------------\n",
comment.c_str(),
len,
stride,
sum);
}
} // namespace xpu
} // namespace lite
} // namespace paddle
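A minimal usage sketch of these dump helpers; dev_ptr, host_ptr and n are hypothetical names for a device pointer, a host pointer and an element count:

  // host-side buffer: print every element, 30 per line, with a running sum in the banner
  paddle::lite::xpu::DumpCPUMem(host_ptr, n, "fc_out_cpu");
  // device buffer: DumpXPUMem first copies all n elements back to the host via
  // xpu_memcpy, then prints every 4th element
  paddle::lite::xpu::DumpXPUMem(dev_ptr, n, "fc_out_xpu", /* stride = */ 4);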
......@@ -13,7 +13,6 @@
// limitations under the License.
#include "lite/backends/xpu/target_wrapper.h"
#include "lite/backends/xpu/xpu_header_sitter.h"
namespace paddle {
namespace lite {
......@@ -42,5 +41,21 @@ void TargetWrapperXPU::MemcpySync(void* dst,
}
}
XPUScratchPadGuard TargetWrapperXPU::MallocScratchPad(size_t size,
bool use_l3) {
void* ptr{nullptr};
if (use_l3) {
ptr = xdnn::alloc_workspace(GetRawContext(), size);
} else {
ptr = TargetWrapperXPU::Malloc(size);
}
CHECK(ptr != nullptr);
return XPUScratchPadGuard(new XPUScratchPad(ptr, use_l3));
}
std::string TargetWrapperXPU::multi_encoder_precision; // NOLINT
int TargetWrapperXPU::workspace_l3_size_per_thread{0};
thread_local xdnn::Context* TargetWrapperXPU::tls_raw_ctx_{nullptr};
} // namespace lite
} // namespace paddle
......@@ -14,6 +14,8 @@
#pragma once
#include <memory> // std::unique_ptr
#include "lite/backends/xpu/xpu_header_sitter.h" // xpu_free
#include "lite/core/target_wrapper.h"
namespace paddle {
......@@ -21,6 +23,24 @@ namespace lite {
using TargetWrapperXPU = TargetWrapper<TARGET(kXPU)>;
struct XPUScratchPad {
XPUScratchPad(void* addr, bool is_l3) : addr_(addr), is_l3_(is_l3) {}
void* addr_{nullptr};
bool is_l3_{false};
};
struct XPUScratchPadDeleter {
void operator()(XPUScratchPad* sp) const {
if (!sp->is_l3_) {
xpu_free(sp->addr_);
}
delete sp;
}
};
using XPUScratchPadGuard = std::unique_ptr<XPUScratchPad, XPUScratchPadDeleter>;
template <>
class TargetWrapper<TARGET(kXPU)> {
public:
......@@ -34,6 +54,41 @@ class TargetWrapper<TARGET(kXPU)> {
const void* src,
size_t size,
IoDirection dir);
static XPUScratchPadGuard MallocScratchPad(size_t size, bool use_l3 = true);
static xdnn::Context* GetRawContext() {
if (tls_raw_ctx_ == nullptr) {
tls_raw_ctx_ = xdnn::create_context();
CHECK(tls_raw_ctx_);
int r = xdnn::set_workspace_l3_size(tls_raw_ctx_,
workspace_l3_size_per_thread);
if (r != 0) {
LOG(WARNING) << "xdnn::set_workspace_l3_size() failed, r = " << r
<< ", workspace_l3_size_per_thread = "
<< workspace_l3_size_per_thread;
}
}
return tls_raw_ctx_;
}
// **DEPRECATED**, use xpu_set_device() at the very beginning of each worker
// thread
static void SetDev(int dev_no = 0) {
const char* dev_env = getenv("LITE_XPU_DEV");
if (dev_env) {
xpu_set_device(atoi(dev_env));
return;
}
xpu_set_device(dev_no);
}
static std::string multi_encoder_precision; // NOLINT
static int workspace_l3_size_per_thread;
private:
static thread_local xdnn::Context* tls_raw_ctx_;
};
} // namespace lite
......
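The scratch-pad helpers added here are used by the new MMDNN kernels later in this commit; a condensed sketch of that pattern (the sizes and the host_offsets array are illustrative):

  // prefer the on-chip L3 workspace (use_l3 defaults to true)
  XPUScratchPadGuard offset_guard = TargetWrapperXPU::MallocScratchPad(64 * sizeof(int));
  // request plain global memory explicitly
  XPUScratchPadGuard gm_guard =
      TargetWrapperXPU::MallocScratchPad(1024 * sizeof(float), false /* use_l3 */);
  // stage host data into the pad and hand the raw pointer to xdnn
  xpu_memcpy(offset_guard->addr_,
             host_offsets,
             64 * sizeof(int),
             XPUMemcpyKind::XPU_HOST_TO_DEVICE);
  int* offset_xpu = reinterpret_cast<int*>(offset_guard->addr_);
  // when the guard goes out of scope, XPUScratchPadDeleter calls xpu_free only for non-L3 pads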
......@@ -21,12 +21,6 @@ namespace lite {
std::string Context<TargetType::kNPU>::subgraph_model_cache_dir_{""}; // NOLINT
#endif
#ifdef LITE_WITH_XPU
std::string Context<TargetType::kXPU>::_multi_encoder_precision; // NOLINT
thread_local xdnn::Context* Context<TargetType::kXPU>::_tls_raw_ctx{nullptr};
int Context<TargetType::kXPU>::_workspace_l3_size_per_thread{0};
#endif
#ifdef LITE_WITH_MLU
int Context<TargetType::kMLU>::next_queue_id_{0};
std::map<int, int> Context<TargetType::kMLU>::queue_id_map_;
......
......@@ -144,45 +144,12 @@ class Context<TargetType::kXPU> {
void CopySharedTo(XPUContext* ctx) {}
// TODO(miaotianxiang): remove this
static xdnn::Context* GetRawContext() {
if (_tls_raw_ctx == nullptr) {
_tls_raw_ctx = xdnn::create_context();
CHECK(_tls_raw_ctx);
int r = xdnn::set_workspace_l3_size(_tls_raw_ctx,
_workspace_l3_size_per_thread);
if (r != 0) {
LOG(WARNING) << "xdnn::set_workspace_l3_size() failed, r = " << r
<< ", _workspace_l3_size_per_thread = "
<< _workspace_l3_size_per_thread;
}
}
return _tls_raw_ctx;
}
static void SetWorkspaceL3Size(int l3_size = 0xfffc00) {
_workspace_l3_size_per_thread = l3_size;
}
// **DEPRECATED**, use xpu_set_device() at the very beginning of each worker
// thread
static void SetDev(int dev_no = 0) {
const char* dev_env = getenv("LITE_XPU_DEV");
if (dev_env) {
xpu_set_device(atoi(dev_env));
return;
}
xpu_set_device(dev_no);
return TargetWrapperXPU::GetRawContext();
}
std::string name() const { return "XPUContext"; }
public:
static std::string _multi_encoder_precision; // NOLINT
private:
static thread_local xdnn::Context* _tls_raw_ctx;
static int _workspace_l3_size_per_thread;
};
#endif
......
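After this change the thread-local xdnn::Context is owned by TargetWrapperXPU and XPUContext::GetRawContext() merely forwards to it, so kernel code keeps its existing call shape. Schematic only; xdnn_op stands for any xdnn primitive:

  auto& ctx = this->ctx_->As<XPUContext>();              // inside a kernel's Run()
  int r = xdnn::xdnn_op(ctx.GetRawContext(), /* ... */); // same raw context as TargetWrapperXPU::GetRawContext()
  CHECK_EQ(r, 0);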
......@@ -23,9 +23,11 @@ lite_cc_library(mir_passes
fusion/sequence_pool_concat_fuse_pass.cc
fusion/scale_activation_fuse_pass.cc
fusion/__xpu__resnet_fuse_pass.cc
fusion/__xpu__resnet_cbam_fuse_pass.cc
fusion/__xpu__multi_encoder_fuse_pass.cc
fusion/__xpu__embedding_with_eltwise_add_fuse_pass.cc
fusion/__xpu__fc_fuse_pass.cc
fusion/__xpu__mmdnn_fuse_pass.cc
elimination/identity_scale_eliminate_pass.cc
elimination/identity_dropout_eliminate_pass.cc
elimination/elementwise_mul_constant_eliminate_pass.cc
......
This diff is collapsed.
......@@ -639,20 +639,21 @@ class XPUMultiEncoderFusePass : public ProgramPass {
std::set<int> fc_int31_ids;
#ifdef LITE_WITH_XPU
// TODO(miaotianxiang): core/mir/*_pass.cc are compiled anyway and need to
// access Context<kXPU>::_multi_encoder_precision, but this static member
// variable in class specialization defined in lite/core/context.cc
// is only compiled iff LITE_WITH_XPU==ON. To suppress linkage error, we use
// access TargetWrapperXPU::multi_encoder_precision, but this static member
// variable in class specialization defined in
// lite/backends/xpu/target_wrapper.cc is only compiled iff
// LITE_WITH_XPU==ON. To suppress linkage error, we use
// #ifdef here. Any better idea?
if (GetStringFromEnv("XPU_ENCODER_PRECISION", "int16") == "int31" ||
lite::Context<TargetType::kXPU>::_multi_encoder_precision == "int31") {
lite::TargetWrapperXPU::multi_encoder_precision == "int31") {
fc_int31_ids = {0, 1, 2, 3, 4, 5};
VLOG(3) << "Use int31 in XPUMultiEncoderOp, "
<< "lite::Context<>::_multi_encoder_precision="
<< lite::Context<TargetType::kXPU>::_multi_encoder_precision;
<< "lite::TargetWrapperXPU::multi_encoder_precision="
<< lite::TargetWrapperXPU::multi_encoder_precision;
} else {
VLOG(3) << "Use int16 in XPUMultiEncoderOp, "
<< "lite::Context<>::_multi_encoder_precision="
<< lite::Context<TargetType::kXPU>::_multi_encoder_precision;
<< "lite::TargetWrapperXPU::multi_encoder_precision="
<< lite::TargetWrapperXPU::multi_encoder_precision;
}
#endif
......
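As the comment above notes, int31 can be requested either through the environment or through the new CxxConfig setter; both are consulted by this fuse pass. A schematic sketch (config is a CxxConfig as in the earlier example, and the environment variable must be set before the optimizer runs):

  // option 1: environment variable read via GetStringFromEnv in the pass
  setenv("XPU_ENCODER_PRECISION", "int31", 1);
  // option 2: API setter, which writes lite::TargetWrapperXPU::multi_encoder_precision
  config.set_xpu_multi_encoder_precision("int31");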
This diff is collapsed.
......@@ -94,6 +94,8 @@ class Optimizer {
#endif
"identity_dropout_eliminate_pass",
"__xpu__resnet_fuse_pass",
"__xpu__resnet_cbam_fuse_pass",
"__xpu__mmdnn_fuse_pass",
"__xpu__multi_encoder_fuse_pass",
"__xpu__embedding_with_eltwise_add_fuse_pass",
"__xpu__fc_fuse_pass",
......
......@@ -157,7 +157,7 @@ void slice_compute(const lite::Tensor* in,
}
}
out->mutable_data<float>(lite::TargetType::kX86);
out->mutable_data<float>();
auto new_out_dims = out->dims();
auto offsets = Eigen::array<int, D>();
......
......@@ -6,6 +6,7 @@ if(LITE_WITH_XTCL)
add_subdirectory(bridges)
add_kernel(subgraph_compute_xpu XPU basic SRCS subgraph_compute.cc DEPS ${lite_kernel_deps} device_xpu subgraph_bridge_engine ${xpu_subgraph_bridges})
else()
# basic
add_kernel(conv_compute_xpu XPU basic SRCS conv_compute.cc DEPS ${lite_kernel_deps})
add_kernel(io_copy_compute_xpu XPU basic SRCS io_copy_compute.cc DEPS ${lite_kernel_deps} target_wrapper_xpu)
add_kernel(batch_norm_compute_xpu XPU basic SRCS batch_norm_compute.cc DEPS ${lite_kernel_deps})
......@@ -15,15 +16,32 @@ else()
add_kernel(mul_compute_xpu XPU basic SRCS mul_compute.cc DEPS ${lite_kernel_deps})
add_kernel(softmax_compute_xpu XPU basic SRCS softmax_compute.cc DEPS ${lite_kernel_deps})
add_kernel(scale_compute_xpu XPU basic SRCS scale_compute.cc DEPS ${lite_kernel_deps})
add_kernel(lookup_table_compute_xpu XPU basic SRCS lookup_table_compute.cc DEPS ${lite_kernel_deps})
add_kernel(layer_norm_compute_xpu XPU basic SRCS layer_norm_compute.cc DEPS ${lite_kernel_deps})
add_kernel(dropout_compute_xpu XPU basic SRCS dropout_compute.cc DEPS ${lite_kernel_deps})
add_kernel(matmul_compute_xpu XPU basic SRCS matmul_compute.cc DEPS ${lite_kernel_deps})
add_kernel(stack_compute_xpu XPU basic SRCS stack_compute.cc DEPS ${lite_kernel_deps})
add_kernel(slice_compute_xpu XPU basic SRCS slice_compute.cc DEPS ${lite_kernel_deps})
add_kernel(cast_compute_xpu XPU basic SRCS cast_compute.cc DEPS ${lite_kernel_deps})
add_kernel(sequence_topk_avg_pooling_compute_xpu XPU basic SRCS sequence_topk_avg_pooling_compute.cc DEPS ${lite_kernel_deps})
add_kernel(concat_compute_xpu XPU basic SRCS concat_compute.cc DEPS ${lite_kernel_deps})
add_kernel(search_fc_compute_xpu XPU basic SRCS search_fc_compute.cc DEPS ${lite_kernel_deps})
# extra
add_kernel(lookup_table_compute_xpu XPU extra SRCS lookup_table_compute.cc DEPS ${lite_kernel_deps})
add_kernel(layer_norm_compute_xpu XPU extra SRCS layer_norm_compute.cc DEPS ${lite_kernel_deps})
add_kernel(sequence_reverse_compute_xpu XPU extra SRCS sequence_reverse_compute.cc DEPS ${lite_kernel_deps})
add_kernel(sequence_concat_compute_xpu XPU extra SRCS sequence_concat_compute.cc DEPS ${lite_kernel_deps})
add_kernel(sequence_arithmetic_compute_xpu XPU extra SRCS sequence_arithmetic_compute.cc DEPS ${lite_kernel_deps})
add_kernel(sequence_pool_compute_xpu XPU extra SRCS sequence_pool_compute.cc DEPS ${lite_kernel_deps})
add_kernel(match_matrix_tensor_compute_xpu XPU extra SRCS match_matrix_tensor_compute.cc DEPS ${lite_kernel_deps})
add_kernel(var_conv_2d_compute_xpu XPU extra SRCS var_conv_2d_compute.cc DEPS ${lite_kernel_deps})
add_kernel(search_grnn_compute_xpu XPU extra SRCS search_grnn_compute.cc DEPS ${lite_kernel_deps})
# extra(fused kernel)
add_kernel(__xpu__resnet50_compute_xpu XPU extra SRCS __xpu__resnet50_compute.cc DEPS ${lite_kernel_deps})
add_kernel(__xpu__resnet_cbam_compute_xpu XPU extra SRCS __xpu__resnet_cbam_compute.cc DEPS ${lite_kernel_deps})
add_kernel(__xpu__multi_encoder_compute_xpu XPU extra SRCS __xpu__multi_encoder_compute.cc DEPS ${lite_kernel_deps})
add_kernel(__xpu__embedding_with_eltwise_add_compute_xpu XPU extra SRCS __xpu__embedding_with_eltwise_add_compute.cc DEPS ${lite_kernel_deps})
add_kernel(__xpu__fc_compute_xpu XPU extra SRCS __xpu__fc_compute.cc DEPS ${lite_kernel_deps})
add_kernel(__xpu__search_attention_compute_xpu XPU extra SRCS __xpu__search_attention_compute.cc DEPS ${lite_kernel_deps})
add_kernel(__xpu__mmdnn_compute_xpu XPU extra SRCS __xpu__mmdnn_compute.cc DEPS ${lite_kernel_deps})
endif()
This diff is collapsed.
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/xpu/__xpu__resnet_cbam_compute.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace xpu {
void XPUResNetCbamCompute::PrepareForRun() {
auto& param = this->Param<param_t>();
for (auto* filter : param.filter) {
arg_filter_.push_back(
reinterpret_cast<const int16_t*>(filter->data<float>()));
}
for (auto* bias : param.bias) {
if (bias == nullptr) {
arg_bias_.push_back(nullptr);
} else {
arg_bias_.push_back(bias->data<float>());
}
}
for (auto* max_filter : param.max_filter) {
arg_max_filter_.push_back(max_filter->data<float>());
}
}
void XPUResNetCbamCompute::Run() {
auto& param = this->Param<param_t>();
auto& ctx = this->ctx_->As<XPUContext>();
auto input_dims = param.input->dims();
int batch_size = input_dims[0];
int height = input_dims[2];
int width = input_dims[3];
int r = xdnn::conv2d_int16_resnet_cbam<float, int16_t>(
ctx.GetRawContext(), /* context */
batch_size, /* num */
height, /* height */
width, /* width */
param.input->data<float>(), /* bottom */
&arg_filter_[0], /* weight_list */
param.output->mutable_data<float>(TARGET(kXPU)), /* top */
&arg_bias_[0], /* bias_list */
&arg_max_filter_[0], /* max_filter_list */
param.pool_p, /* pool_p */
true, /* midtype_fp16 */
false /* dynamic_shape */);
CHECK_EQ(r, 0);
}
} // namespace xpu
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(__xpu__resnet_cbam,
kXPU,
kFloat,
kNCHW,
paddle::lite::kernels::xpu::XPUResNetCbamCompute,
def)
.BindInput("Input", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindInput("Filter", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindInput("Bias", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindInput("MaxFilter", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindOutput("Output", {LiteType::GetTensorTy(TARGET(kXPU))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <vector>
#include "lite/backends/xpu/xpu_header_sitter.h"
#include "lite/core/kernel.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace xpu {
class XPUResNetCbamCompute
: public KernelLite<TARGET(kXPU), PRECISION(kFloat)> {
public:
using param_t = operators::XPUResNetCbamParam;
virtual void PrepareForRun();
virtual void Run();
private:
std::vector<const int16_t *> arg_filter_;
std::vector<const float *> arg_max_filter_;
std::vector<const float *> arg_bias_;
};
} // namespace xpu
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/xpu/__xpu__search_attention_compute.h"
#include "lite/backends/xpu/xpu_header_sitter.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace xpu {
void XPUMmdnnSearchAttentionCompute::PrepareForRun() {
offset_xpu_guard_ = TargetWrapperXPU::MallocScratchPad(64 * sizeof(int));
pad_begin_xpu_guard_ = TargetWrapperXPU::MallocScratchPad(64 * sizeof(int));
w_max_xpu_guard_ = TargetWrapperXPU::MallocScratchPad(8 * sizeof(float));
buffer_at_l3_guard_ = TargetWrapperXPU::MallocScratchPad(
5 * L3_SLOT_SIZE * sizeof(float), false /* use_l3 */);
buffer_at_gm_guard_ = TargetWrapperXPU::MallocScratchPad(
5 * GM_SLOT_SIZE * sizeof(float), false /* use_l3 */);
offset_cpu.reset(new int[64]);
pad_begin_cpu.reset(new int[64]);
}
void XPUMmdnnSearchAttentionCompute::Run() {
auto& param = this->Param<param_t>();
auto& ctx = this->ctx_->As<XPUContext>();
auto* X = param.X;
auto* W = param.W;
auto* b = param.b;
float W_max = param.W_max;
float alpha0 = param.alpha0;
float alpha1 = param.alpha1;
float mask = param.mask;
const int16_t* w_data = W->data<int16_t>();
const float* b_data = b->data<float>();
int batch = X->lod()[0].size() - 1;
int dim0 = X->dims()[0];
int dim1 = X->dims()[1];
const auto offset = X->lod()[0];
int max_seq = 0;
auto* top = param.Out;
LoD top_lod;
top_lod.push_back(X->lod()[0]);
top->set_lod(top_lod);
top->Resize({dim0, dim1});
auto* top_data = top->mutable_data<float>(TARGET(kXPU));
float maxs_cpu[8] = {0.0f, 0.0f, 0.0f, 0.0f, W_max, 0.0f, 0.0f, 0.0f};
for (int i = 0; i < batch; ++i) {
offset_cpu[i] = offset[i]; // type of offset is int64, not supported by xpu
pad_begin_cpu[i] = offset[i + 1] - offset[i];
if (offset[i + 1] - offset[i] > max_seq) {
max_seq = offset[i + 1] - offset[i];
}
}
offset_cpu[batch] = offset[batch];
xpu_memcpy(offset_xpu_guard_->addr_,
offset_cpu.get(),
offset.size() * sizeof(int),
XPUMemcpyKind::XPU_HOST_TO_DEVICE);
xpu_memcpy(pad_begin_xpu_guard_->addr_,
pad_begin_cpu.get(),
batch * sizeof(int),
XPUMemcpyKind::XPU_HOST_TO_DEVICE);
xpu_memcpy(w_max_xpu_guard_->addr_,
maxs_cpu,
8 * sizeof(float),
XPUMemcpyKind::XPU_HOST_TO_DEVICE);
int* offset_xpu = reinterpret_cast<int*>(offset_xpu_guard_->addr_);
int* pad_begin_xpu = reinterpret_cast<int*>(pad_begin_xpu_guard_->addr_);
float* maxs_xpu = reinterpret_cast<float*>(w_max_xpu_guard_->addr_);
float* buffer_at_l3 = reinterpret_cast<float*>(buffer_at_l3_guard_->addr_);
float* buffer_at_gm = reinterpret_cast<float*>(buffer_at_gm_guard_->addr_);
// when use l3, max_seq <= 128:
// group_padding: batch * max_seq * dim1; at (slot0, slot1)
// seq_fc: batch * max_seq * dim1; at (slot2, slot3)
// batchgemm0: batch * max_seq * max_seq; at slot4
// attention_padding_mask: batch * max_seq * max_seq; at slot3
// seq_softmax: batch * max_seq * max_seq; at slot4
// batchgemm1: batch * max_seq * dim1; at (slot2, slot3)
float* group_padding_output = buffer_at_l3;
float* seq_fc_output = buffer_at_l3 + 2 * L3_SLOT_SIZE;
float* batchgemm0_output = buffer_at_l3 + 4 * L3_SLOT_SIZE;
float* attention_output = buffer_at_l3 + 3 * L3_SLOT_SIZE;
float* seq_softmax_output = buffer_at_l3 + 4 * L3_SLOT_SIZE;
float* batchgemm1_output = buffer_at_l3 + 2 * L3_SLOT_SIZE;
if (max_seq > 128) {
group_padding_output = buffer_at_gm;
seq_fc_output = buffer_at_gm + 1 * GM_SLOT_SIZE;
batchgemm0_output = buffer_at_gm + 2 * GM_SLOT_SIZE;
attention_output = buffer_at_gm + 1 * GM_SLOT_SIZE;
seq_softmax_output = buffer_at_gm + 3 * GM_SLOT_SIZE;
batchgemm1_output = buffer_at_gm + 4 * GM_SLOT_SIZE;
}
const auto* bottom_data = X->data<float>();
xdnn::search_sequence_pad_depad(ctx.GetRawContext(),
const_cast<float*>(bottom_data),
group_padding_output,
offset_xpu,
max_seq,
batch,
dim1,
0); // is_depad = 0
// do-findmax
xdnn::findmax<float>(ctx.GetRawContext(),
group_padding_output,
batch * max_seq * dim1,
maxs_xpu);
xdnn::gemm_int16_maxptr<float, int16_t, float>(
ctx.GetRawContext(),
false,
true, // trans_a, trans_b
batch * max_seq,
dim1,
dim1, // m, n, k
1.0f,
group_padding_output,
dim1, // alpha, data_a, lda
w_data,
dim1,
0.0f, // data_b, ldb, beta
seq_fc_output,
dim1,
b_data, // data_c, ldc, bias
xdnn::Activation_t::LINEAR,
maxs_xpu,
maxs_xpu + 4,
nullptr); // max_a, max_b, max_c
xdnn::search_aligned_mat_mul(ctx.GetRawContext(),
0,
1,
batch,
max_seq,
max_seq,
dim1,
alpha0,
group_padding_output,
dim1,
seq_fc_output,
dim1,
batchgemm0_output,
max_seq);
xdnn::search_pad_mask(ctx.GetRawContext(),
batchgemm0_output,
attention_output,
pad_begin_xpu,
batch,
max_seq,
max_seq,
batch,
mask);
xdnn::softmax2d_forward(ctx.GetRawContext(),
attention_output,
seq_softmax_output,
batch * max_seq,
max_seq,
true);
xdnn::search_aligned_mat_mul(ctx.GetRawContext(),
0,
0,
batch,
max_seq,
dim1,
max_seq,
alpha1,
seq_softmax_output,
max_seq,
group_padding_output,
dim1,
batchgemm1_output,
dim1);
xdnn::search_sequence_pad_depad(ctx.GetRawContext(),
top_data,
batchgemm1_output,
offset_xpu,
max_seq,
batch,
dim1,
1); // is_depad = 1
}
} // namespace xpu
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(__xpu__mmdnn_search_attention,
kXPU,
kFloat,
kNCHW,
paddle::lite::kernels::xpu::XPUMmdnnSearchAttentionCompute,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindInput("W", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindInput("b", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kXPU))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include "lite/backends/xpu/target_wrapper.h" // XPUScratchPadGuard
#include "lite/core/kernel.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace xpu {
class XPUMmdnnSearchAttentionCompute
: public KernelLite<TARGET(kXPU), PRECISION(kFloat)> {
public:
using param_t = operators::XPUMmdnnSearchAttentionParam;
void PrepareForRun() override;
void Run() override;
private:
XPUScratchPadGuard offset_xpu_guard_;
XPUScratchPadGuard pad_begin_xpu_guard_;
XPUScratchPadGuard w_max_xpu_guard_;
XPUScratchPadGuard buffer_at_l3_guard_;
XPUScratchPadGuard buffer_at_gm_guard_;
std::unique_ptr<int[]> offset_cpu;
std::unique_ptr<int[]> pad_begin_cpu;
const int L3_SLOT_SIZE = 40 * 128 * 128;
const int GM_SLOT_SIZE = 40 * 512 * 512;
};
} // namespace xpu
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/xpu/concat_compute.h"
#include "lite/backends/xpu/xpu_header_sitter.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace xpu {
void ConcatCompute::Run() {
auto& param = this->Param<param_t>();
auto& ctx = this->ctx_->As<XPUContext>();
auto ins = param.x;
auto out = param.output;
int64_t axis = param.axis;
int n = ins.size();
int h = 1;
int w_except_axis = 1;
CHECK(n <= 8) << "XPU only supports at most 8 tensors for now";
for (int i = 0; i < axis; ++i) {
h *= (ins[0]->dims())[i];
}
for (int i = axis + 1; i < ins[0]->dims().size(); ++i) {
w_except_axis *= (ins[0]->dims())[i];
}
CHECK(axis >= 0) << "concat: axis should be >= 0!";
CHECK(axis < ins[0]->dims().size()) << "concat: axis should be < ins[0]->dims().size()!";
for (int i = 0; i < n; ++i) {
int hh = 1;
int ww = 1;
for (int j = 0; j < axis; ++j) {
hh *= (ins[i]->dims())[j];
}
for (int j = axis + 1; j < ins[i]->dims().size(); ++j) {
ww *= (ins[i]->dims())[j];
}
CHECK(hh == h) << "concat: h should be equal!";
CHECK(ww == w_except_axis) << "concat: w should be equal except for axis!";
}
int in_w_host[n]; // NOLINT
const float* ptrs[n]; // NOLINT
for (int i = 0; i < n; ++i) {
ptrs[i] = ins[i]->data<float>();
in_w_host[i] = w_except_axis * (ins[i]->dims())[axis];
}
int r = xdnn::concat<float>(ctx.GetRawContext(), /* ctx */
h, /* height */
in_w_host, /* width_x */
n, /* n */
ptrs, /* lm_ptrs */
out->mutable_data<float>(TARGET(kXPU)) /*y*/);
CHECK_EQ(r, 0);
}
} // namespace xpu
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(
concat, kXPU, kFloat, kNCHW, paddle::lite::kernels::xpu::ConcatCompute, def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindInput("AxisTensor",
{LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kInt32))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kXPU))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "lite/core/kernel.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace xpu {
class ConcatCompute : public KernelLite<TARGET(kXPU), PRECISION(kFloat)> {
public:
using param_t = operators::ConcatParam;
virtual void Run();
virtual ~ConcatCompute() = default;
};
} // namespace xpu
} // namespace kernels
} // namespace lite
} // namespace paddle
This diff is collapsed.
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include "lite/backends/xpu/target_wrapper.h" // XPUScratchPadGuard
#include "lite/core/kernel.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace xpu {
class MatchMatrixTensorCompute
: public KernelLite<TARGET(kXPU), PRECISION(kFloat)> {
public:
using param_t = operators::MatchMatrixTensorParam;
virtual void PrepareForRun();
virtual void Run();
private:
XPUScratchPadGuard wx_max_xpu_guard_;
XPUScratchPadGuard offset_l_xpu_guard_;
XPUScratchPadGuard offset_r_xpu_guard_;
std::unique_ptr<int[]> offset_l_cpu;
std::unique_ptr<int[]> offset_r_cpu;
};
} // namespace xpu
} // namespace kernels
} // namespace lite
} // namespace paddle
The diffs of 5 more files are collapsed.
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include "lite/core/kernel.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace xpu {
class SequenceArithmeticCompute
: public KernelLite<TARGET(kXPU), PRECISION(kFloat)> {
public:
using param_t = operators::SequenceArithmeticParam;
void Run() override;
};
} // namespace xpu
} // namespace kernels
} // namespace lite
} // namespace paddle
The diffs of 25 more files are collapsed.