未验证 提交 07104881 作者: zhupengyang 提交者: GitHub

support int64 of slice, split; support cast: int64->int64, int32->int64 (#4048)

上级 8e76e305
...@@ -79,6 +79,13 @@ void slice(const Dtype* input, ...@@ -79,6 +79,13 @@ void slice(const Dtype* input,
} }
} }
// Explicit instantiation of the slice<Dtype> function template for float
// data on the ARM context. `input` and `out` are float buffers; `dims`,
// `axes`, `starts` and `ends` are passed by value as in the template's
// signature.
// NOTE(review): presumably the template body lives only in this source file,
// so every element type used by callers needs an instantiation here — confirm.
template void slice(const float* input,
std::vector<int64_t> dims,
std::vector<int> axes,
std::vector<int> starts,
std::vector<int> ends,
float* out,
Context<TARGET(kARM)>* ctx);
template void slice(const int* input, template void slice(const int* input,
std::vector<int64_t> dims, std::vector<int64_t> dims,
std::vector<int> axes, std::vector<int> axes,
...@@ -86,12 +93,12 @@ template void slice(const int* input, ...@@ -86,12 +93,12 @@ template void slice(const int* input,
std::vector<int> ends, std::vector<int> ends,
int* out, int* out,
Context<TARGET(kARM)>* ctx); Context<TARGET(kARM)>* ctx);
template void slice(const float* input, template void slice(const int64_t* input,
std::vector<int64_t> dims, std::vector<int64_t> dims,
std::vector<int> axes, std::vector<int> axes,
std::vector<int> starts, std::vector<int> starts,
std::vector<int> ends, std::vector<int> ends,
float* out, int64_t* out,
Context<TARGET(kARM)>* ctx); Context<TARGET(kARM)>* ctx);
} // namespace math } // namespace math
......
...@@ -51,11 +51,11 @@ void split_cpy<float>(const float* din, float* dout, int num) { ...@@ -51,11 +51,11 @@ void split_cpy<float>(const float* din, float* dout, int num) {
} }
} }
template <> template <typename T>
void split<float>(const float* din, void split(const T* din,
const std::vector<lite::Tensor*>& dout, const std::vector<lite::Tensor*>& dout,
const int axis, const int axis,
const std::vector<int>& in_strides) { const std::vector<int>& in_strides) {
int input_offset = 0; int input_offset = 0;
for (auto out : dout) { for (auto out : dout) {
auto out_dim = out->dims(); auto out_dim = out->dims();
...@@ -65,15 +65,15 @@ void split<float>(const float* din, ...@@ -65,15 +65,15 @@ void split<float>(const float* din,
out_strides[i] = out_strides[i + 1] * out_dim[i]; out_strides[i] = out_strides[i + 1] * out_dim[i];
} }
float* out_data = out->mutable_data<float>(); T* out_data = out->mutable_data<T>();
int before = out_strides[0] / out_strides[axis]; int before = out_strides[0] / out_strides[axis];
int in_after = in_strides[axis]; int in_after = in_strides[axis];
int out_after = out_strides[axis]; int out_after = out_strides[axis];
const float* din_ptr = din + input_offset; const T* din_ptr = din + input_offset;
for (int i = 0; i < before; ++i) { for (int i = 0; i < before; ++i) {
std::memcpy(out_data, din_ptr, sizeof(float) * out_after); std::memcpy(out_data, din_ptr, sizeof(T) * out_after);
din_ptr += in_after; din_ptr += in_after;
out_data += out_after; out_data += out_after;
} }
...@@ -81,6 +81,15 @@ void split<float>(const float* din, ...@@ -81,6 +81,15 @@ void split<float>(const float* din,
} }
} }
// Explicit instantiation of the split<T> function template for T = float.
// The definition above was changed from a float-only specialization to a
// generic template, so the float version is instantiated explicitly here.
template void split(const float* din,
const std::vector<lite::Tensor*>& dout,
const int axis,
const std::vector<int>& in_strides);
// Explicit instantiation of the split<T> function template for T = int64_t,
// providing the int64 variant of split with the same parameter list as the
// float instantiation.
template void split(const int64_t* din,
const std::vector<lite::Tensor*>& dout,
const int axis,
const std::vector<int>& in_strides);
} // namespace math } // namespace math
} // namespace arm } // namespace arm
} // namespace lite } // namespace lite
......
...@@ -40,6 +40,11 @@ void CastCompute::Run() { ...@@ -40,6 +40,11 @@ void CastCompute::Run() {
const auto* x_data = param.X->data<float>(); const auto* x_data = param.X->data<float>();
auto* o_data = param.Out->mutable_data<float>(); auto* o_data = param.Out->mutable_data<float>();
memcpy(o_data, x_data, sizeof(float) * param.X->numel()); memcpy(o_data, x_data, sizeof(float) * param.X->numel());
} else if (param.in_dtype == param.out_dtype &&
param.in_dtype == 3) { // int64->int64
const auto* x_data = param.X->data<int64_t>();
auto* o_data = param.Out->mutable_data<int64_t>();
memcpy(o_data, x_data, sizeof(int64_t) * param.X->numel());
} else if (param.in_dtype == 21 && param.out_dtype == 5) { // int8->float32 } else if (param.in_dtype == 21 && param.out_dtype == 5) { // int8->float32
const char* x_data_begin = param.X->data<char>(); const char* x_data_begin = param.X->data<char>();
const char* x_data_end = x_data_begin + param.X->numel(); const char* x_data_end = x_data_begin + param.X->numel();
...@@ -56,7 +61,7 @@ void CastCompute::Run() { ...@@ -56,7 +61,7 @@ void CastCompute::Run() {
float* out_data = param.Out->mutable_data<float>(); float* out_data = param.Out->mutable_data<float>();
std::transform( std::transform(
x_data_begin, x_data_end, out_data, TransOp<unsigned char, float>); x_data_begin, x_data_end, out_data, TransOp<unsigned char, float>);
} else if (param.in_dtype == 3 && param.out_dtype == 2) { } else if (param.in_dtype == 3 && param.out_dtype == 2) { // int64->int32
const int64_t* x_data_begin = param.X->data<int64_t>(); const int64_t* x_data_begin = param.X->data<int64_t>();
const int64_t* x_data_end = x_data_begin + param.X->numel(); const int64_t* x_data_end = x_data_begin + param.X->numel();
int32_t* out_data = param.Out->mutable_data<int32_t>(); int32_t* out_data = param.Out->mutable_data<int32_t>();
......
...@@ -169,21 +169,47 @@ void SliceCompute<T, PType>::Run() { ...@@ -169,21 +169,47 @@ void SliceCompute<T, PType>::Run() {
using slice_float = using slice_float =
paddle::lite::kernels::arm::SliceCompute<float, PRECISION(kFloat)>; paddle::lite::kernels::arm::SliceCompute<float, PRECISION(kFloat)>;
REGISTER_LITE_KERNEL(slice, kARM, kFloat, kNCHW, slice_float, def) REGISTER_LITE_KERNEL(slice, kARM, kFloat, kNCHW, slice_float, def)
.BindInput("Input", {LiteType::GetTensorTy(TARGET(kARM))}) .BindInput("Input",
.BindInput("StartsTensor", {LiteType::GetTensorTy(TARGET(kARM))}) {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kFloat))})
.BindInput("EndsTensor", {LiteType::GetTensorTy(TARGET(kARM))}) .BindInput("StartsTensor",
.BindInput("StartsTensorList", {LiteType::GetTensorTy(TARGET(kARM))}) {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
.BindInput("EndsTensorList", {LiteType::GetTensorTy(TARGET(kARM))}) .BindInput("EndsTensor",
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))}) {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
.BindInput("StartsTensorList",
{LiteType::GetTensorListTy(TARGET(kARM), PRECISION(kInt32))})
.BindInput("EndsTensorList",
{LiteType::GetTensorListTy(TARGET(kARM), PRECISION(kInt32))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kFloat))})
.Finalize(); .Finalize();
using slice_int32 = using slice_int32 =
paddle::lite::kernels::arm::SliceCompute<int, PRECISION(kInt32)>; paddle::lite::kernels::arm::SliceCompute<int, PRECISION(kInt32)>;
REGISTER_LITE_KERNEL(slice, kARM, kInt32, kNCHW, slice_int32, def) REGISTER_LITE_KERNEL(slice, kARM, kInt32, kNCHW, slice_int32, def)
.BindInput("Input", .BindInput("Input",
{LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))}) {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
.BindInput("StartsTensor", {LiteType::GetTensorTy(TARGET(kARM))}) .BindInput("StartsTensor",
.BindInput("EndsTensor", {LiteType::GetTensorTy(TARGET(kARM))}) {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
.BindInput("StartsTensorList", {LiteType::GetTensorTy(TARGET(kARM))}) .BindInput("EndsTensor",
.BindInput("EndsTensorList", {LiteType::GetTensorTy(TARGET(kARM))}) {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
.BindInput("StartsTensorList",
{LiteType::GetTensorListTy(TARGET(kARM), PRECISION(kInt32))})
.BindInput("EndsTensorList",
{LiteType::GetTensorListTy(TARGET(kARM), PRECISION(kInt32))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
.Finalize(); .Finalize();
// Registers the slice kernel for int64 data: SliceCompute instantiated with
// T = int64_t and PRECISION(kInt64), registered under (slice, kARM, kInt64,
// kNCHW) with the alias `def`.
using slice_int64 =
paddle::lite::kernels::arm::SliceCompute<int64_t, PRECISION(kInt64)>;
REGISTER_LITE_KERNEL(slice, kARM, kInt64, kNCHW, slice_int64, def)
// The data input is an int64 tensor.
.BindInput("Input",
{LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))})
// Start/end inputs are bound as int32, independent of the data precision.
.BindInput("StartsTensor",
{LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
.BindInput("EndsTensor",
{LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
// The *TensorList inputs use the tensor-list type rather than a single tensor.
.BindInput("StartsTensorList",
{LiteType::GetTensorListTy(TARGET(kARM), PRECISION(kInt32))})
.BindInput("EndsTensorList",
{LiteType::GetTensorListTy(TARGET(kARM), PRECISION(kInt32))})
// The output has the same precision as the data input (int64).
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))})
.Finalize();
...@@ -21,9 +21,10 @@ namespace lite { ...@@ -21,9 +21,10 @@ namespace lite {
namespace kernels { namespace kernels {
namespace arm { namespace arm {
void SplitCompute::Run() { template <typename T, PrecisionType PType>
auto& param = Param<operators::SplitParam>(); void SplitCompute<T, PType>::Run() {
const float* din = param.x->data<float>(); auto& param = this->template Param<operators::SplitParam>();
const T* din = param.x->template data<T>();
auto& dout = param.output; auto& dout = param.output;
auto in_dim = param.x->dims(); auto in_dim = param.x->dims();
std::vector<int> in_strides(in_dim.size()); std::vector<int> in_strides(in_dim.size());
...@@ -42,12 +43,24 @@ void SplitCompute::Run() { ...@@ -42,12 +43,24 @@ void SplitCompute::Run() {
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
REGISTER_LITE_KERNEL( using split_float =
split, kARM, kFloat, kNCHW, paddle::lite::kernels::arm::SplitCompute, def) paddle::lite::kernels::arm::SplitCompute<float, PRECISION(kFloat)>;
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))}) REGISTER_LITE_KERNEL(split, kARM, kFloat, kNCHW, split_float, def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kFloat))})
.BindInput("AxisTensor", .BindInput("AxisTensor",
{LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))}) {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
.BindInput("SectionsTensorList", .BindInput("SectionsTensorList",
{LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))}) {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kFloat))})
.Finalize();
// Registers the split kernel for int64 data: SplitCompute instantiated with
// T = int64_t and PRECISION(kInt64), registered under (split, kARM, kInt64,
// kNCHW) with the alias `def`. "X" and "Out" are int64 tensors; "AxisTensor"
// and "SectionsTensorList" are bound as int32.
// NOTE(review): "SectionsTensorList" is bound with GetTensorTy rather than
// GetTensorListTy — confirm a single-tensor type is intended for a list input.
using split_int64 =
paddle::lite::kernels::arm::SplitCompute<int64_t, PRECISION(kInt64)>;
REGISTER_LITE_KERNEL(split, kARM, kInt64, kNCHW, split_int64, def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))})
.BindInput("AxisTensor",
{LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
.BindInput("SectionsTensorList",
{LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))})
.Finalize(); .Finalize();
...@@ -22,7 +22,8 @@ namespace lite { ...@@ -22,7 +22,8 @@ namespace lite {
namespace kernels { namespace kernels {
namespace arm { namespace arm {
class SplitCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> { template <typename T, PrecisionType PType>
class SplitCompute : public KernelLite<TARGET(kARM), PType> {
public: public:
void Run() override; void Run() override;
......
...@@ -93,13 +93,13 @@ void split_compute_ref(const operators::SplitParam& param) { ...@@ -93,13 +93,13 @@ void split_compute_ref(const operators::SplitParam& param) {
} }
TEST(split_arm, init) { TEST(split_arm, init) {
SplitCompute split; SplitCompute<float, PRECISION(kFloat)> split;
ASSERT_EQ(split.precision(), PRECISION(kFloat)); ASSERT_EQ(split.precision(), PRECISION(kFloat));
ASSERT_EQ(split.target(), TARGET(kARM)); ASSERT_EQ(split.target(), TARGET(kARM));
} }
TEST(split_arm, compute) { TEST(split_arm, compute) {
SplitCompute split; SplitCompute<float, PRECISION(kFloat)> split;
operators::SplitParam param; operators::SplitParam param;
lite::Tensor x; lite::Tensor x;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册