Commit 92dc2ec6 authored by W weihaoji

[XPU] Support OCR models

* add exp and reciprocal for activation
* add conv2d transpose op
* add fill constant op
* add im2sequence op
* add interpolate op (including nearest and bilinear)
* add lrn op
* add split op
* add sum op
* add topk op
* add gru(int31) op
* fix bug in elementwise arithmetic op
* fix bug in conv2d op
* fix bug in dropout op
* fix bug in cast op, test=develop test=xpu
Parent 9a9d1cf2
@@ -8,6 +8,7 @@ if(LITE_WITH_XTCL)
else()
# basic
add_kernel(conv_compute_xpu XPU basic SRCS conv_compute.cc DEPS ${lite_kernel_deps})
add_kernel(conv2d_transpose_compute_xpu XPU basic SRCS conv2d_transpose_compute.cc DEPS ${lite_kernel_deps})
add_kernel(io_copy_compute_xpu XPU basic SRCS io_copy_compute.cc DEPS ${lite_kernel_deps} target_wrapper_xpu)
add_kernel(batch_norm_compute_xpu XPU basic SRCS batch_norm_compute.cc DEPS ${lite_kernel_deps})
add_kernel(activation_compute_xpu XPU basic SRCS activation_compute.cc DEPS ${lite_kernel_deps})
@@ -27,6 +28,10 @@ else()
add_kernel(reshape_compute_xpu XPU basic SRCS reshape_compute.cc DEPS ${lite_kernel_deps})
add_kernel(reduce_mean_compute_xpu XPU basic SRCS reduce_mean_compute.cc DEPS ${lite_kernel_deps})
add_kernel(reduce_sum_compute_xpu XPU basic SRCS reduce_sum_compute.cc DEPS ${lite_kernel_deps})
add_kernel(split_compute_xpu XPU basic SRCS split_compute.cc DEPS ${lite_kernel_deps})
add_kernel(sum_compute_xpu XPU basic SRCS sum_compute.cc DEPS ${lite_kernel_deps})
add_kernel(interpolate_compute_xpu XPU basic SRCS interpolate_compute.cc DEPS ${lite_kernel_deps})
add_kernel(fill_constant_compute_xpu XPU basic SRCS fill_constant_compute.cc DEPS ${lite_kernel_deps})
# extra
add_kernel(lookup_table_compute_xpu XPU extra SRCS lookup_table_compute.cc DEPS ${lite_kernel_deps})
@@ -39,6 +44,10 @@ else()
add_kernel(var_conv_2d_compute_xpu XPU extra SRCS var_conv_2d_compute.cc DEPS ${lite_kernel_deps})
add_kernel(search_grnn_compute_xpu XPU extra SRCS search_grnn_compute.cc DEPS ${lite_kernel_deps})
add_kernel(sequence_unpad_compute_xpu XPU extra SRCS sequence_unpad_compute.cc DEPS ${lite_kernel_deps})
add_kernel(lrn_compute_xpu XPU extra SRCS lrn_compute.cc DEPS ${lite_kernel_deps})
add_kernel(topk_compute_xpu XPU extra SRCS topk_compute.cc DEPS ${lite_kernel_deps})
add_kernel(im2sequence_compute_xpu XPU extra SRCS im2sequence_compute.cc DEPS ${lite_kernel_deps})
add_kernel(gru_compute_xpu XPU extra SRCS gru_compute.cc DEPS ${lite_kernel_deps})
# extra(fused kernel)
add_kernel(__xpu__resnet50_compute_xpu XPU extra SRCS __xpu__resnet50_compute.cc DEPS ${lite_kernel_deps})
......
@@ -13,6 +13,7 @@
// limitations under the License.
#include "lite/kernels/xpu/cast_compute.h"
#include <typeinfo>
#include "lite/backends/xpu/xpu_header_sitter.h"
#include "lite/core/op_registry.h"
@@ -21,31 +22,61 @@ namespace lite {
namespace kernels {
namespace xpu {
void CastCompute::Run() {
auto& param = this->template Param<param_t>();
auto& ctx = this->ctx_->template As<XPUContext>();
auto* out = param.Out;
int out_dtype = param.out_dtype;
int in_dtype = param.in_dtype;
int numel = param.X->numel();
int r = 0;
// BOOL = 0;INT16 = 1;INT32 = 2;INT64 = 3;FP16 = 4;FP32 = 5;FP64 = 6;
// SIZE_T = 19;UINT8 = 20;INT8 = 21;
if (in_dtype == 5 && out_dtype == 5) {
// float -> float
auto* in_data = param.X->data<float>();
auto* out_data = out->mutable_data<float>(TARGET(kXPU));
r = xdnn::cast<float, float>(ctx.GetRawContext(), in_data, out_data, numel);
} else if (in_dtype == 2 && out_dtype == 2) {
// int -> int
auto* in_data = param.X->data<int>();
auto* out_data = out->mutable_data<int>(TARGET(kXPU));
r = xdnn::cast<int, int>(ctx.GetRawContext(), in_data, out_data, numel);
} else if (in_dtype == 3 && out_dtype == 3) {
// int64 -> int64
auto* in_data = param.X->data<int64_t>();
auto* out_data = out->mutable_data<int64_t>(TARGET(kXPU));
r = xdnn::cast<int64_t, int64_t>(
ctx.GetRawContext(), in_data, out_data, numel);
} else if (in_dtype == 2 && out_dtype == 3) {
// int -> int64
auto* in_data = param.X->data<int>();
auto* out_data = out->mutable_data<int64_t>(TARGET(kXPU));
r = xdnn::cast<int, int64_t>(ctx.GetRawContext(), in_data, out_data, numel);
} else if (in_dtype == 2 && out_dtype == 5) {
// int -> float
auto* in_data = param.X->data<int>();
auto* out_data = out->mutable_data<float>(TARGET(kXPU));
r = xdnn::cast<int, float>(ctx.GetRawContext(), in_data, out_data, numel);
} else if (in_dtype == 3 && out_dtype == 5) {
// int64_t -> float
auto* in_data = param.X->data<int64_t>();
auto* out_data = out->mutable_data<float>(TARGET(kXPU));
r = xdnn::cast<int64_t, float>(
ctx.GetRawContext(), in_data, out_data, numel);
} else if (in_dtype == 5 && out_dtype == 3) {
// float -> int64_t
auto* in_data = param.X->data<float>();
auto* out_data = out->mutable_data<int64_t>(TARGET(kXPU));
r = xdnn::cast<float, int64_t>(
ctx.GetRawContext(), in_data, out_data, numel);
} else if (in_dtype == 5 && out_dtype == 2) {
// float -> int
auto* in_data = param.X->data<float>();
auto* out_data = out->mutable_data<int>(TARGET(kXPU));
r = xdnn::cast<float, int>(ctx.GetRawContext(), in_data, out_data, numel);
} else {
CHECK(false);
}
@@ -57,12 +88,8 @@ void CastCompute<InType>::Run() {
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(
cast, kXPU, kAny, kNCHW, paddle::lite::kernels::xpu::CastCompute, def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kAny))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kAny))})
.Finalize();
@@ -21,7 +21,6 @@ namespace lite {
namespace kernels {
namespace xpu {
class CastCompute : public KernelLite<TARGET(kXPU), PRECISION(kAny)> {
public:
using param_t = operators::CastParam;
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/xpu/conv2d_transpose_compute.h"
#include "lite/backends/xpu/xpu_header_sitter.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace xpu {
template <>
void Conv2dTransposeCompute<PRECISION(kFloat)>::PrepareForRun() {
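// The 8-float scratch pad holds two 4-float max slots: [0, 4) stores the
// filter max computed once here, [4, 8) stores the per-run image max filled
// in by Run().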
maxs_xpu_guard_ =
TargetWrapperXPU::MallocScratchPad(8 * sizeof(float), false /* use_l3 */);
auto& ctx = this->ctx_->As<XPUContext>();
auto& param = this->Param<param_t>();
float* max_filter_ptr = reinterpret_cast<float*>(maxs_xpu_guard_->addr_);
int filter_size = param.filter->numel();
int r = xdnn::findmax<float>(ctx.GetRawContext(),
param.filter->data<float>(),
filter_size,
max_filter_ptr);
CHECK_EQ(r, 0);
}
template <>
void Conv2dTransposeCompute<PRECISION(kFloat)>::Run() {
auto& param = this->Param<param_t>();
auto& ctx = this->ctx_->As<XPUContext>();
auto& out_dims = param.output->dims();
auto& w_dims = param.filter->dims();
auto& in_dims = param.x->dims();
int groups = param.groups;
auto& strides = param.strides;
auto paddings = *param.paddings;
auto dilations = *param.dilations;
float* max_filter_ptr = reinterpret_cast<float*>(maxs_xpu_guard_->addr_);
float* max_image_ptr = max_filter_ptr + 4;
int image_size = param.x->numel();
// find image max
int r = xdnn::findmax<float>(
ctx.GetRawContext(), param.x->data<float>(), image_size, max_image_ptr);
CHECK_EQ(r, 0);
r = xdnn::conv2d_backward_int16(
ctx.GetRawContext(),
out_dims[0],
out_dims[1],
out_dims[2],
out_dims[3],
in_dims[1],
w_dims[2],
w_dims[3],
strides[0],
strides[1],
paddings[0],
paddings[1],
dilations[0],
dilations[1],
groups,
param.x->data<float>(),
param.filter->data<float>(),
param.output->mutable_data<float>(TARGET(kXPU)),
max_image_ptr,
max_filter_ptr);
CHECK_EQ(r, 0);
}
} // namespace xpu
} // namespace kernels
} // namespace lite
} // namespace paddle
namespace xpu = paddle::lite::kernels::xpu;
using Conv2dTransposeFp32 = xpu::Conv2dTransposeCompute<PRECISION(kFloat)>;
REGISTER_LITE_KERNEL(
conv2d_transpose, kXPU, kFloat, kNCHW, Conv2dTransposeFp32, def)
.BindInput("Input", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindInput("Filter", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindInput("Bias", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindOutput("Output", {LiteType::GetTensorTy(TARGET(kXPU))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include "lite/backends/xpu/target_wrapper.h" // XPUScratchPadGuard
#include "lite/core/kernel.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace xpu {
template <PrecisionType FilterPtype>
class Conv2dTransposeCompute : public KernelLite<TARGET(kXPU), FilterPtype> {
public:
using param_t = operators::ConvParam;
void PrepareForRun() override;
void Run() override;
virtual ~Conv2dTransposeCompute() = default;
private:
XPUScratchPadGuard maxs_xpu_guard_;
};
} // namespace xpu
} // namespace kernels
} // namespace lite
} // namespace paddle
@@ -33,32 +33,55 @@ void Conv2dCompute<PRECISION(kFloat)>::Run() {
auto paddings = *param.paddings;
auto dilations = *param.dilations;
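// Dispatch on groups: the fused int16 path below only handles groups == 1,
// so grouped (e.g. depthwise) convolutions take the with_group variant that
// has no bias/branch/activation fusion.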
if (groups == 1) {
int r = xdnn::conv2d_forward_int16<float, float, float, float>(
ctx.GetRawContext(), /* context */
x_dims[0], /* num */
x_dims[1], /* input_c */
x_dims[2], /* input_h */
x_dims[3], /* input_w */
w_dims[0], /* num_filter */
w_dims[2], /* kernel_h */
w_dims[3], /* kernel_w */
strides[0], /* stride_h */
strides[1], /* stride_w */
paddings[0], /* pad_h */
paddings[1], /* pad_w */
dilations[0], /* dilation_h */
dilations[1], /* dilation_w */
groups, /* group */
param.x->data<float>(), /* bottom */
param.filter->data<float>(), /* weight */
param.output->mutable_data<float>(TARGET(kXPU)), /* top */
nullptr, /* bias */
nullptr, /* branch */
xdnn::Activation_t::LINEAR, /* type */
nullptr, /* max_image_ptr */
nullptr, /* max_filter_ptr */
nullptr /* max_result_ptr */);
CHECK_EQ(r, 0);
} else {
int r = xdnn::conv2d_int16_with_group<float, float, float>(
ctx.GetRawContext(), /* context */
param.x->data<float>(), /* bottom */
param.filter->data<float>(), /* weight */
param.output->mutable_data<float>(TARGET(kXPU)), /* top */
x_dims[0],
x_dims[1],
x_dims[2],
x_dims[3],
w_dims[0],
w_dims[2],
w_dims[3],
groups,
strides[0],
strides[1],
paddings[0],
paddings[1],
nullptr,
nullptr);
CHECK_EQ(r, 0);
}
}
} // namespace xpu
@@ -75,3 +98,10 @@ REGISTER_LITE_KERNEL(conv2d, kXPU, kFloat, kNCHW, Conv2dFp32, def)
.BindInput("Filter", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindOutput("Output", {LiteType::GetTensorTy(TARGET(kXPU))})
.Finalize();
REGISTER_LITE_KERNEL(depthwise_conv2d, kXPU, kFloat, kNCHW, Conv2dFp32, def)
.BindInput("Input", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindInput("Bias", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindInput("Filter", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindOutput("Output", {LiteType::GetTensorTy(TARGET(kXPU))})
.Finalize();
@@ -25,14 +25,25 @@ void DropoutCompute::Run() {
auto& param = this->Param<param_t>();
auto& ctx = this->ctx_->As<XPUContext>();
if (param.is_test) {
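// At inference time dropout reduces to a scale: "upscale_in_train" already
// rescaled activations during training, so the copy-through factor is 1;
// otherwise shrink by the keep probability (1 - dropout_prob).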
float scale = 1.0f;
if (param.dropout_implementation == "upscale_in_train") {
scale = 1.0f;
} else {
scale = 1.0f - param.dropout_prob;
}
int r =
xdnn::scale(ctx.GetRawContext(), /* context */
param.x->numel(),
scale,
0.0f,
0,
param.x->data<float>(), /* src */
param.output->mutable_data<float>(TARGET(kXPU))); /* dst */
CHECK_EQ(r, 0);
} else {
CHECK(false);
}
}
} // namespace xpu
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/xpu/fill_constant_compute.h"
#include <iostream>
#include "lite/backends/xpu/xpu_header_sitter.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace xpu {
union TypeUnion {
float fp32;
int32_t int32;
};
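// Reinterpreting the fill value's raw 32 bits through the union lets one
// 32-bit xdnn::memset pattern serve float, int32, and byte-replicated int8
// fills alike.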
void FillConstantCompute::Run() {
auto& param = this->Param<param_t>();
auto& ctx = this->ctx_->As<XPUContext>();
TypeUnion value;
int write_size = param.out->numel();
if (param.dtype == static_cast<int32_t>(lite::core::FluidType::FP32)) {
auto data = param.out->template mutable_data<float>(TARGET(kXPU));
value.fp32 = param.value;
write_size = write_size * sizeof(float);
int r = xdnn::memset(ctx.GetRawContext(), /* context */
reinterpret_cast<void*>(data),
value.int32,
write_size);
CHECK_EQ(r, 0);
} else if (param.dtype ==
static_cast<int32_t>(lite::core::FluidType::INT32)) {
auto data = param.out->template mutable_data<int32_t>(TARGET(kXPU));
value.int32 = param.value;
write_size = write_size * sizeof(int32_t);
int r = xdnn::memset(ctx.GetRawContext(), /* context */
reinterpret_cast<void*>(data),
value.int32,
write_size);
CHECK_EQ(r, 0);
} else if (param.dtype == static_cast<int32_t>(lite::core::FluidType::INT8)) {
auto data = param.out->template mutable_data<int8_t>(TARGET(kXPU));
value.int32 = 0;
// Replicate the int8 value into all four bytes: shift first, then OR in the
// low byte, so the top byte is not shifted out and the low byte is not left
// zero.
for (int i = 0; i < 4; i++) {
value.int32 = (value.int32 << 8) | (static_cast<int32_t>(param.value) & 0xff);
}
int r = xdnn::memset(ctx.GetRawContext(), /* context */
reinterpret_cast<void*>(data),
value.int32,
write_size);
CHECK_EQ(r, 0);
} else {
LOG(FATAL) << "not supported dtype " << param.dtype;
}
}
} // namespace xpu
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(fill_constant,
kXPU,
kAny,
kNCHW,
paddle::lite::kernels::xpu::FillConstantCompute,
def)
.BindInput("ShapeTensor",
{LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kInt32))})
.BindInput("ShapeTensorList",
{LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kInt32))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kAny))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "lite/core/kernel.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace xpu {
class FillConstantCompute : public KernelLite<TARGET(kXPU), PRECISION(kAny)> {
public:
using param_t = operators::FillConstantParam;
virtual void Run();
virtual ~FillConstantCompute() = default;
};
} // namespace xpu
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/xpu/gru_compute.h"
#include <math.h>
#include <algorithm>
#include <map>
#include <string>
#include <vector>
#include "lite/backends/xpu/xpu_header_sitter.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace xpu {
inline xdnn::Activation_t get_gru_act_type(const std::string& type) {
std::map<std::string, xdnn::Activation_t> act_type_map = {
{"sigmoid", xdnn::Activation_t::SIGMOID},
{"tanh", xdnn::Activation_t::TANH},
{"relu", xdnn::Activation_t::RELU}};
auto it = act_type_map.find(type);
if (it != act_type_map.end()) {
return it->second;
} else {
LOG(FATAL) << "unsupported activation type: " << type;
}
}
void GruCompute::PrepareForRun() {
offset_xpu_guard_ = TargetWrapperXPU::MallocScratchPad(
XPU_MAX_LOD_SIZE * sizeof(int), false /* use_l3 */);
new_offset_xpu_guard_ = TargetWrapperXPU::MallocScratchPad(
XPU_MAX_LOD_SEQ_LEN * sizeof(int), false /* use_l3 */);
idx_sorted_by_width_data_xpu_guard_ = TargetWrapperXPU::MallocScratchPad(
XPU_MAX_LOD_SIZE * sizeof(int), false /* use_l3 */);
idx_sorted_by_width_data_cpu.reset(new int[XPU_MAX_LOD_SIZE]);
offset_cpu.reset(new int[XPU_MAX_LOD_SIZE]);
new_offset_cpu.reset(new int[XPU_MAX_LOD_SEQ_LEN]);
// find max
maxs_xpu_guard_ =
TargetWrapperXPU::MallocScratchPad(8 * sizeof(float), false /* use_l3 */);
auto& ctx = this->ctx_->As<XPUContext>();
auto& param = this->Param<param_t>();
int frame_size = param.input->dims()[1] / 3;
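// The stacked GRU weight is laid out as [W_update | W_reset | W_candidate],
// each frame_size x frame_size; the int31 kernel needs separate quantization
// maxima for the update/reset block and the candidate block.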
float* weight_ur_max_ptr_xpu =
reinterpret_cast<float*>(maxs_xpu_guard_->addr_);
float* weight_c_max_ptr_xpu = weight_ur_max_ptr_xpu + 4;
// weight_ur_max
int ret = xdnn::findmax(ctx.GetRawContext(),
param.weight->data<float>(),
frame_size * frame_size * 2,
weight_ur_max_ptr_xpu);
CHECK_EQ(ret, 0);
// weight_c_max
ret = xdnn::findmax(ctx.GetRawContext(),
param.weight->data<float>() + frame_size * frame_size * 2,
frame_size * frame_size,
weight_c_max_ptr_xpu);
CHECK_EQ(ret, 0);
float weight_ur_max_cpu[4];
XPU_CALL(xpu_memcpy(weight_ur_max_cpu,
weight_ur_max_ptr_xpu,
sizeof(float) * 4,
XPUMemcpyKind::XPU_DEVICE_TO_HOST));
weight_u_r_max_value =
std::max(std::max(weight_ur_max_cpu[0], weight_ur_max_cpu[1]),
std::max(weight_ur_max_cpu[2], weight_ur_max_cpu[3]));
float weight_c_max_cpu[4];
XPU_CALL(xpu_memcpy(weight_c_max_cpu,
weight_c_max_ptr_xpu,
sizeof(float) * 4,
XPUMemcpyKind::XPU_DEVICE_TO_HOST));
weight_c_max_value =
std::max(std::max(weight_c_max_cpu[0], weight_c_max_cpu[1]),
std::max(weight_c_max_cpu[2], weight_c_max_cpu[3]));
}
void GruCompute::prepare_layout(const paddle::lite::LoD& lods,
int* offset_xpu,
int* new_offset_xpu,
int* idx_sorted_by_width_data_xpu) {
const auto& lod = lods[0];
for (size_t i = 0; i < lod.size(); i++) {
offset_cpu[i] = lod[i];
}
for (size_t seq_id = 0; seq_id < lod.size() - 1; ++seq_id) {
int length = lod[seq_id + 1] - lod[seq_id];
seq_info.push_back(SeqInfo(lod[seq_id], length, seq_id));
}
std::cout << "seq len is " << seq_info.size() << std::endl;
std::stable_sort(seq_info.begin(), seq_info.end(), [](SeqInfo a, SeqInfo b) {
return a.length > b.length;
});
for (size_t i = 0; i < seq_info.size(); i++) {
idx_sorted_by_width_data_cpu[i] = seq_info[i].seq_idx;
}
// max_width
int max_width = seq_info[0].length;
new_offset_cpu[0] = 0;
int cur_offset_idx = 1;
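// new_offset[t] is meant to be the start row of batch time step t after
// seq2batch, so consecutive differences equal the number of sequences still
// active at step t.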
for (size_t i = 0; i < seq_info.size(); i++) {
int cur_length = static_cast<int>(seq_info.size() - i);
int repeat_times = (i == 0) ? seq_info[i].length
: (seq_info[i].length - seq_info[i - 1].length);
for (int j = 0; j < repeat_times; j++) {
new_offset_cpu[cur_offset_idx] =
new_offset_cpu[cur_offset_idx - 1] + cur_length;
cur_offset_idx++;
}
}
XPU_CALL(xpu_memcpy(offset_xpu,
offset_cpu.get(),
sizeof(int) * lod.size(),
XPU_HOST_TO_DEVICE));
XPU_CALL(xpu_memcpy(idx_sorted_by_width_data_xpu,
idx_sorted_by_width_data_cpu.get(),
sizeof(int) * seq_info.size(),
XPU_HOST_TO_DEVICE));
XPU_CALL(xpu_memcpy(new_offset_xpu,
new_offset_cpu.get(),
sizeof(int) * (max_width + 1),
XPU_HOST_TO_DEVICE));
}
void GruCompute::Run() {
auto& param = this->Param<param_t>();
auto& ctx = this->ctx_->As<XPUContext>();
auto input = param.input;
float* batch_gate = param.batch_gate->mutable_data<float>(TARGET(kXPU));
float* batch_reset_hidden_prev =
param.batch_reset_hidden_prev->mutable_data<float>(TARGET(kXPU));
float* batch_hidden = param.hidden->mutable_data<float>(TARGET(kXPU));
bool origin_mode = param.origin_mode;
int frame_size = input->dims()[1] / 3;
int* offset_xpu = reinterpret_cast<int*>(offset_xpu_guard_->addr_);
int* new_offset_xpu = reinterpret_cast<int*>(new_offset_xpu_guard_->addr_);
int* idx_sorted_by_width_data_xpu =
reinterpret_cast<int*>(idx_sorted_by_width_data_xpu_guard_->addr_);
// prepare seq_info
auto lods = input->lod();
const auto& lod = lods[0];
prepare_layout(
lods, offset_xpu, new_offset_xpu, idx_sorted_by_width_data_xpu);
int max_width = seq_info[0].length;
// sequence to batch
XPUScratchPadGuard xpu_batch_data_guard_ = TargetWrapperXPU::MallocScratchPad(
lod[lod.size() - 1] * frame_size * 3 * sizeof(float), false /*use_l3 */);
float* batch_data = reinterpret_cast<float*>(xpu_batch_data_guard_->addr_);
bool is_reverse = param.is_reverse;
if (is_reverse) {
int ret = xdnn::sequence_reverse(ctx.GetRawContext(), /* context */
lod.size() - 1,
offset_xpu,
frame_size * 3,
param.input->data<float>(),
batch_data);
CHECK_EQ(ret, 0);
ret = xdnn::search_seq2batch(ctx.GetRawContext(), /* context */
lod.size() - 1,
max_width,
frame_size * 3,
idx_sorted_by_width_data_xpu,
offset_xpu,
new_offset_xpu,
batch_data,
batch_data);
CHECK_EQ(ret, 0);
} else {
int ret = xdnn::search_seq2batch(ctx.GetRawContext(), /* context */
lod.size() - 1,
max_width,
frame_size * 3,
idx_sorted_by_width_data_xpu,
offset_xpu,
new_offset_xpu,
param.input->data<float>(),
batch_data);
CHECK_EQ(ret, 0);
}
// prepare xpu_h0
auto* h0 = param.h0;
XPUScratchPadGuard xpu_h0_guard_ = TargetWrapperXPU::MallocScratchPad(
(lod.size() - 1) * frame_size * sizeof(float), false /*use_l3 */);
float* xpu_h0_start = reinterpret_cast<float*>(xpu_h0_guard_->addr_);
float* xpu_h0 = xpu_h0_start;
if (h0) {
for (size_t i = 0; i < seq_info.size(); i++) {
int ret = xdnn::memcpy_device(
ctx.GetRawContext(),
xpu_h0 + i * frame_size,
h0->data<float>() + seq_info[i].seq_idx * frame_size,
sizeof(float) * frame_size);
CHECK_EQ(ret, 0);
}
} else {
// initialize with zeros
int ret = xdnn::scale(ctx.GetRawContext(),
frame_size * seq_info.size(),
0.0,
0.0,
false,
xpu_h0,
xpu_h0);
CHECK_EQ(ret, 0);
}
// gru
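// Walk the batched layout one time step at a time: rows
// [new_offset[t], new_offset[t + 1]) hold step t of every still-active
// sequence, and each step's hidden output becomes h_{t-1} for the next step.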
for (int batch_idx = 0; batch_idx < max_width; batch_idx++) {
float* x = batch_data + new_offset_cpu[batch_idx] * frame_size * 3;
int ret = xdnn::gru_unit_int31(
ctx.GetRawContext(),
new_offset_cpu[batch_idx + 1] - new_offset_cpu[batch_idx],
frame_size,
origin_mode,
get_gru_act_type(param.gate_activation),
get_gru_act_type(param.activation),
x,
xpu_h0,
param.weight->data<float>(),
weight_u_r_max_value,
weight_c_max_value,
param.bias->data<float>(),
batch_gate + new_offset_cpu[batch_idx] * frame_size * 3,
batch_reset_hidden_prev + new_offset_cpu[batch_idx] * frame_size,
batch_hidden + new_offset_cpu[batch_idx] * frame_size);
CHECK_EQ(ret, 0);
xpu_h0 = batch_hidden + new_offset_cpu[batch_idx] * frame_size;
}
// batch to sequence
if (is_reverse) {
int ret = xdnn::search_batch2seq(ctx.GetRawContext(),
seq_info.size(),
max_width,
frame_size,
idx_sorted_by_width_data_xpu,
offset_xpu,
new_offset_xpu,
batch_hidden,
batch_data);
CHECK_EQ(ret, 0);
ret =
xdnn::sequence_reverse(ctx.GetRawContext(),
lod.size() - 1,
offset_xpu,
frame_size,
batch_data,
param.hidden->mutable_data<float>(TARGET(kXPU)));
CHECK_EQ(ret, 0);
} else {
int ret =
xdnn::search_batch2seq(ctx.GetRawContext(),
seq_info.size(),
max_width,
frame_size,
idx_sorted_by_width_data_xpu,
offset_xpu,
new_offset_xpu,
batch_hidden,
param.hidden->mutable_data<float>(TARGET(kXPU)));
CHECK_EQ(ret, 0);
}
seq_info.clear();
}
} // namespace xpu
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(
gru, kXPU, kFloat, kNCHW, paddle::lite::kernels::xpu::GruCompute, def)
.BindInput("Input", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindInput("H0", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindInput("Weight", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindInput("Bias", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindOutput("BatchGate", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindOutput("BatchResetHiddenPrev", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindOutput("BatchHidden", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindOutput("Hidden", {LiteType::GetTensorTy(TARGET(kXPU))})
.Finalize();
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <vector>
#include "lite/backends/xpu/target_wrapper.h" // XPUScratchPadGuard
#include "lite/core/kernel.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace xpu {
class GruCompute : public KernelLite<TARGET(kXPU), PRECISION(kFloat)> {
public:
using param_t = operators::GRUParam;
void PrepareForRun() override;
void prepare_layout(const paddle::lite::LoD& lods,
int* offset_xpu,
int* new_offset_xpu,
int* idx_sorted_by_width_data_xpu);
void Run() override;
virtual ~GruCompute() = default;
private:
XPUScratchPadGuard offset_xpu_guard_;
XPUScratchPadGuard new_offset_xpu_guard_;
XPUScratchPadGuard maxs_xpu_guard_;
XPUScratchPadGuard idx_sorted_by_width_data_xpu_guard_;
float weight_u_r_max_value;
float weight_c_max_value;
std::unique_ptr<int[]> idx_sorted_by_width_data_cpu;
std::unique_ptr<int[]> offset_cpu;
std::unique_ptr<int[]> new_offset_cpu;
struct SeqInfo {
SeqInfo() = default;
SeqInfo(int start, int length, int seq_idx)
: start(start), length(length), seq_idx(seq_idx) {}
int start;
int length;
int seq_idx;
};
std::vector<SeqInfo> seq_info;
};
} // namespace xpu
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/xpu/im2sequence_compute.h"
#include <vector>
#include "lite/backends/xpu/xpu_header_sitter.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace xpu {
inline int Im2SeqOutputSize(
int input_size, int filter_size, int padding_0, int padding_1, int stride) {
const int output_size =
(input_size + padding_0 + padding_1 - filter_size) / stride + 1;
return output_size;
}
void Im2SequenceCompute::Run() {
auto& param = this->Param<param_t>();
auto& ctx = this->ctx_->As<XPUContext>();
auto x_dims = param.X->dims();
int batch = x_dims[0];
int channel = x_dims[1];
int height = x_dims[2];
int width = x_dims[3];
int kernel_h = param.kernels[0];
int kernel_w = param.kernels[1];
int stride_h = param.strides[0];
int stride_w = param.strides[1];
int dilation_h = 1;
int dilation_w = 1;
int pad_h = param.paddings[0];
int pad_w = param.paddings[1];
int output_height =
Im2SeqOutputSize(height, kernel_h, pad_h, pad_h, stride_h);
int output_width = Im2SeqOutputSize(width, kernel_w, pad_w, pad_w, stride_w);
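// Each image unrolls into output_height * output_width rows of
// channel * kernel_h * kernel_w elements, so every batch entry contributes
// one fixed-length sequence to the output LoD.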
std::vector<uint64_t> out_offset;
out_offset.push_back(0);
// One lod entry per image, not just for the first one.
for (int batch_idx = 0; batch_idx < batch; batch_idx++) {
out_offset.push_back(out_offset.back() + output_height * output_width);
}
for (int batch_idx = 0; batch_idx < batch; batch_idx++) {
int r = xdnn::im2col_ocf(
ctx.GetRawContext(), /* context */
channel,
height,
width,
kernel_h,
kernel_w,
pad_h,
pad_w,
stride_h,
stride_w,
dilation_h,
dilation_w,
param.X->data<float>() + batch_idx * channel * height * width,
param.Out->mutable_data<float>(TARGET(kXPU)) +
batch_idx * output_height * output_width * channel * kernel_h *
kernel_w);
CHECK_EQ(r, 0);
}
auto lod = param.Out->mutable_lod();
lod->resize(1);
(*lod)[0] = out_offset;
}
} // namespace xpu
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(im2sequence,
kXPU,
kFloat,
kNCHW,
paddle::lite::kernels::xpu::Im2SequenceCompute,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindInput("Y", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kXPU))})
.Finalize();
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "lite/core/kernel.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace xpu {
class Im2SequenceCompute : public KernelLite<TARGET(kXPU), PRECISION(kFloat)> {
public:
using param_t = operators::Im2SequenceParam;
virtual void Run();
virtual ~Im2SequenceCompute() = default;
};
} // namespace xpu
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/xpu/interpolate_compute.h"
#include <iostream>
#include <memory>
#include "lite/backends/xpu/xpu_header_sitter.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace xpu {
void BilinearInterpCompute::Run() {
auto& param = this->Param<param_t>();
auto& ctx = this->ctx_->As<XPUContext>();
auto x_dims = param.X->dims();
CHECK_EQ(x_dims.size(), 4);
int n = x_dims[0];
int c = x_dims[1];
int in_h = x_dims[2];
int in_w = x_dims[3];
int out_w = param.out_w;
int out_h = param.out_h;
float scale = param.scale;
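// Resolve the output size in increasing priority: the out_h/out_w
// attributes, then scale, then the OutSize tensor (read on the host), which
// wins last.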
if (scale > 0) {
out_h = static_cast<int>(in_h * scale);
out_w = static_cast<int>(in_w * scale);
}
if (param.OutSize != nullptr) {
out_h = param.OutSize->data<int>()[0];
out_w = param.OutSize->data<int>()[1];
}
bool align_corners = param.align_corners;
CHECK_EQ(align_corners, true) << "XPU only supports align_corners = true";
int r = xdnn::bilinear_interp(ctx.GetRawContext(), /* context */
param.X->data<float>(),
param.Out->mutable_data<float>(TARGET(kXPU)),
n,
c,
in_h,
in_w,
out_h,
out_w,
align_corners,
1);
CHECK_EQ(r, 0);
}
void NearestInterpCompute::Run() {
auto& param = this->Param<param_t>();
auto& ctx = this->ctx_->As<XPUContext>();
auto x_dims = param.X->dims();
CHECK_EQ(x_dims.size(), 4);
int n = x_dims[0];
int c = x_dims[1];
int in_h = x_dims[2];
int in_w = x_dims[3];
int out_w = param.out_w;
int out_h = param.out_h;
float scale = param.scale;
if (scale > 0) {
out_h = static_cast<int>(in_h * scale);
out_w = static_cast<int>(in_w * scale);
}
if (param.OutSize != nullptr) {
out_h = param.OutSize->data<int>()[0];
out_w = param.OutSize->data<int>()[1];
}
bool align_corners = param.align_corners;
int r = xdnn::interpolate(ctx.GetRawContext(), /* context */
param.X->data<float>(),
param.Out->mutable_data<float>(TARGET(kXPU)),
n,
c,
in_h,
in_w,
out_h,
out_w,
align_corners);
CHECK_EQ(r, 0);
}
} // namespace xpu
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(bilinear_interp,
kXPU,
kFloat,
kNCHW,
paddle::lite::kernels::xpu::BilinearInterpCompute,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindInput("OutSize",
{LiteType::GetTensorTy(TARGET(kX86), PRECISION(kInt32))})
.BindInput("SizeTensor",
{LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kInt32))})
.BindInput("Scale", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kXPU))})
.Finalize();
REGISTER_LITE_KERNEL(nearest_interp,
kXPU,
kFloat,
kNCHW,
paddle::lite::kernels::xpu::NearestInterpCompute,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindInput("OutSize",
{LiteType::GetTensorTy(TARGET(kX86), PRECISION(kInt32))})
.BindInput("SizeTensor",
{LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kInt32))})
.BindInput("Scale", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kXPU))})
.Finalize();
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "lite/core/kernel.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace xpu {
class BilinearInterpCompute
: public KernelLite<TARGET(kXPU), PRECISION(kFloat)> {
public:
using param_t = operators::InterpolateParam;
void Run() override;
virtual ~BilinearInterpCompute() = default;
};
class NearestInterpCompute
: public KernelLite<TARGET(kXPU), PRECISION(kFloat)> {
public:
using param_t = operators::InterpolateParam;
void Run() override;
virtual ~NearestInterpCompute() = default;
};
} // namespace xpu
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/xpu/lrn_compute.h"
#include "lite/backends/xpu/xpu_header_sitter.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace xpu {
void LrnCompute::Run() {
auto& param = this->Param<param_t>();
auto& ctx = this->ctx_->As<XPUContext>();
auto x_dims = param.X->dims();
int batch = x_dims[0];
int channel = x_dims[1];
int h = x_dims[2];
int w = x_dims[3];
int n = param.n;
float alpha = param.alpha;
float beta = param.beta;
float k = param.k;
if (param.norm_region == "AcrossChannels") {
int r = xdnn::lrn_fwd(ctx.GetRawContext(),
param.X->data<float>(),
param.Out->mutable_data<float>(TARGET(kXPU)),
batch,
channel,
h,
w,
n,
k,
alpha,
beta);
CHECK_EQ(r, 0);
} else {
LOG(FATAL) << "Unsupport Norm Region Type: " << param.norm_region;
}
}
} // namespace xpu
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(
lrn, kXPU, kFloat, kNCHW, paddle::lite::kernels::xpu::LrnCompute, def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindOutput("MidOut", {LiteType::GetTensorTy(TARGET(kXPU))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "lite/core/kernel.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace xpu {
class LrnCompute : public KernelLite<TARGET(kXPU), PRECISION(kFloat)> {
public:
using param_t = operators::LrnParam;
virtual void Run();
virtual ~LrnCompute() = default;
};
} // namespace xpu
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/xpu/split_compute.h"
#include <vector>
#include "lite/backends/xpu/xpu_header_sitter.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace xpu {
void SplitCompute::Run() {
auto& param = this->Param<param_t>();
auto& ctx = this->ctx_->As<XPUContext>();
auto& dout = param.output;
auto in_dim = param.x->dims();
auto axis = param.axis;
int height = 1;
for (int i = 0; i < axis; i++) {
height = height * in_dim[i];
}
int n = 0;
std::vector<float*> out_ptrs;
std::vector<int> width_out;
for (auto out : dout) {
n++;
out->set_lod(param.x->lod());
out_ptrs.push_back(out->mutable_data<float>(TARGET(kXPU)));
int out_strides = out->numel();
width_out.push_back(out_strides / height);
}
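// Splitting along `axis` is the exact inverse of concat, so reuse concat's
// gradient kernel to scatter the input into the per-output slices.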
int r = xdnn::concat_grad(ctx.GetRawContext(),
height,
width_out.data(),
n,
out_ptrs.data(),
param.x->data<float>());
CHECK_EQ(r, 0);
}
} // namespace xpu
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(
split, kXPU, kFloat, kNCHW, paddle::lite::kernels::xpu::SplitCompute, def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindInput("AxisTensor",
{LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kInt32))})
.BindOutput("SectionsTensorList",
{LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kInt32))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kXPU))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "lite/core/kernel.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace xpu {
class SplitCompute : public KernelLite<TARGET(kXPU), PRECISION(kFloat)> {
public:
using param_t = operators::SplitParam;
virtual void Run();
virtual ~SplitCompute() = default;
};
} // namespace xpu
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/xpu/sum_compute.h"
#include <vector>
#include "lite/backends/xpu/xpu_header_sitter.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace xpu {
void SumCompute::Run() {
auto& param = this->Param<param_t>();
auto& ctx = this->ctx_->As<XPUContext>();
int N = param.x.size();
if (N == 1) {
param.output->ShareDataWith(*param.x[0]);
return;
}
std::vector<const float*> ptrs(N, nullptr);
for (int i = 0; i < N; i++) {
ptrs[i] = param.x[i]->data<float>();
}
int out_numel = param.output->numel();
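// xdnn::sum_batch elementwise-accumulates the N device buffers into the
// output in a single call.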
int r = xdnn::sum_batch(ctx.GetRawContext(),
ptrs.data(),
param.output->mutable_data<float>(TARGET(kXPU)),
N,
out_numel);
CHECK_EQ(r, 0);
}
} // namespace xpu
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(
sum, kXPU, kFloat, kNCHW, paddle::lite::kernels::xpu::SumCompute, def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kXPU))})
.Finalize();
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "lite/core/kernel.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace xpu {
class SumCompute : public KernelLite<TARGET(kXPU), PRECISION(kFloat)> {
public:
using param_t = operators::SumParam;
virtual void Run();
virtual ~SumCompute() = default;
};
} // namespace xpu
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/xpu/topk_compute.h"
#include "lite/backends/xpu/target_wrapper.h" // XPUScratchPadGuard
#include "lite/backends/xpu/xpu_header_sitter.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace xpu {
void TopkCompute::Run() {
auto& param = this->Param<param_t>();
auto& ctx = this->ctx_->As<XPUContext>();
DDim x_dims = param.X->dims();
int K = param.K;
int dim_size = x_dims.size();
int m = x_dims.production() / x_dims[dim_size - 1];
int n = x_dims[dim_size - 1];
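// xdnn::topk emits int32 indices, while the op's Indices output is int64:
// stage the int32 result in scratch memory, then cast on device.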
XPUScratchPadGuard indices_xpu_guard_ = TargetWrapperXPU::MallocScratchPad(
m * K * sizeof(int), false /* use_l3 */);
int* indices_int32_device = reinterpret_cast<int*>(indices_xpu_guard_->addr_);
int64_t* indices_int64_device =
param.Indices->mutable_data<int64_t>(TARGET(kXPU));
int r = xdnn::topk(ctx.GetRawContext(),
param.X->data<float>(),
param.Out->mutable_data<float>(TARGET(kXPU)),
indices_int32_device,
m,
n,
K);
CHECK_EQ(r, 0);
r = xdnn::cast<int, int64_t>(
ctx.GetRawContext(), indices_int32_device, indices_int64_device, m * K);
CHECK_EQ(r, 0);
}
} // namespace xpu
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(
top_k, kXPU, kFloat, kNCHW, paddle::lite::kernels::xpu::TopkCompute, def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindOutput("Indices",
{LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kInt64))})
.Finalize();
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "lite/core/kernel.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace xpu {
class TopkCompute : public KernelLite<TARGET(kXPU), PRECISION(kFloat)> {
public:
using param_t = operators::TopkParam;
virtual void Run();
virtual ~TopkCompute() = default;
};
} // namespace xpu
} // namespace kernels
} // namespace lite
} // namespace paddle
@@ -26,6 +26,7 @@ add_operator(interpolate_op basic SRCS interpolate_op.cc DEPS ${op_DEPS})
add_operator(argmax_op basic SRCS argmax_op.cc DEPS ${op_DEPS})
add_operator(prior_box_op basic SRCS prior_box_op.cc DEPS ${op_DEPS})
add_operator(concat_op basic SRCS concat_op.cc DEPS ${op_DEPS})
add_operator(sum_op basic SRCS sum_op.cc DEPS ${op_DEPS})
add_operator(pad2d_op basic SRCS pad2d_op.cc DEPS ${op_DEPS})
add_operator(calib_op basic SRCS calib_op.cc DEPS ${op_DEPS})
add_operator(split_op basic SRCS split_op.cc DEPS ${op_DEPS})
......
@@ -354,6 +354,31 @@ struct ReshapeParam : ParamBase {
}
};
// For Sum op
struct SumParam : ParamBase {
std::vector<lite::Tensor*> x{};
lite::Tensor* output{};
bool use_mkldnn{false};
// get a vector of input tensors
const std::vector<const Tensor*>* input_tensor_ptrs() override {
if (!input_tensor_ptrs_cache_) {
std::vector<const Tensor*> vec;
for (auto in : x) {
vec.push_back(in);
}
input_tensor_ptrs_cache_.reset(new std::vector<const Tensor*>(vec));
}
return input_tensor_ptrs_cache_.get();
}
// get a vector of output tensors
std::vector<Tensor*>* output_tensor_ptrs() override {
if (!output_tensor_ptrs_cache_) {
output_tensor_ptrs_cache_.reset(new std::vector<lite::Tensor*>({output}));
}
return output_tensor_ptrs_cache_.get();
}
};
// For Concat op
struct ConcatParam : ParamBase {
std::vector<lite::Tensor*> x{};
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/sum_op.h"
#include "lite/core/op_lite.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace operators {
bool SumOpLite::CheckShape() const {
CHECK_OR_FALSE(param_.output);
CHECK_GE_OR_FALSE(param_.x.size(), 1UL);
return true;
}
bool SumOpLite::InferShapeImpl() const {
const std::vector<Tensor *> &inputs = param_.x;
const size_t n = inputs.size();
CHECK_GT_OR_FALSE(n, 0);
auto out_dims = inputs[0]->dims();
// Set output dims
param_.output->Resize(out_dims);
auto out_lod = param_.output->mutable_lod();
*out_lod = param_.x[0]->lod();
return true;
}
// TODO(Superjomn) replace framework::OpDesc with a lite one.
bool SumOpLite::AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) {
AttachParam(&param_);
auto inputs = op_desc.Input("X");
auto out = op_desc.Output("Out").front();
param_.x.clear();
for (auto var : inputs) {
param_.x.push_back(scope->FindVar(var)->GetMutable<lite::Tensor>());
}
CHECK(scope->FindVar(out));
param_.output = scope->FindVar(out)->GetMutable<lite::Tensor>();
if (op_desc.HasAttr("use_mkldnn")) {
param_.use_mkldnn = op_desc.GetAttr<bool>("use_mkldnn");
}
return true;
}
} // namespace operators
} // namespace lite
} // namespace paddle
REGISTER_LITE_OP(sum, paddle::lite::operators::SumOpLite);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "lite/core/op_lite.h"
#include "lite/core/scope.h"
#include "lite/utils/all.h"
namespace paddle {
namespace lite {
namespace operators {
class SumOpLite : public OpLite {
public:
SumOpLite() {}
explicit SumOpLite(const std::string &op_type) : OpLite(op_type) {}
bool CheckShape() const override;
bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override { return "sum"; }
#ifdef LITE_WITH_PROFILE
void GetOpRuntimeInfo(paddle::lite::profile::OpCharacter *ch) {
auto output_dims = param_.output->dims();
std::string inputs_shape = "";
for (size_t i = 0; i < param_.x.size(); ++i) {
inputs_shape += ch->DimToStr(param_.x[i]->dims());
if (i != param_.x.size() - 1) inputs_shape += "/";
}
ch->input_shape = inputs_shape;
ch->output_shape = ch->DimToStr(output_dims);
ch->macs = 0.f; // no calc. only io operation
}
#endif
private:
mutable SumParam param_;
};
} // namespace operators
} // namespace lite
} // namespace paddle