Unverified commit 07104881, authored by zhupengyang, committed by GitHub

support int64 for slice and split; support cast: int64->int64, int32->int64 (#4048)

Parent 8e76e305
......@@ -79,6 +79,13 @@ void slice(const Dtype* input,
}
}
template void slice(const float* input,
std::vector<int64_t> dims,
std::vector<int> axes,
std::vector<int> starts,
std::vector<int> ends,
float* out,
Context<TARGET(kARM)>* ctx);
template void slice(const int* input,
std::vector<int64_t> dims,
std::vector<int> axes,
......@@ -86,12 +93,12 @@ template void slice(const int* input,
std::vector<int> ends,
int* out,
Context<TARGET(kARM)>* ctx);
template void slice(const float* input,
template void slice(const int64_t* input,
std::vector<int64_t> dims,
std::vector<int> axes,
std::vector<int> starts,
std::vector<int> ends,
float* out,
int64_t* out,
Context<TARGET(kARM)>* ctx);
} // namespace math
......
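For reference, a minimal usage sketch of the newly instantiated int64_t overload of math::slice (not part of the patch; the header path, shapes, and the helper function name are assumptions):

#include <vector>

#include "lite/backends/arm/math/slice.h"  // assumed header location

// Slices columns [2, 6) out of a {4, 8} int64 buffer, using the signature
// shown in the hunk above.
void slice_int64_example(const int64_t* in, int64_t* out,
                         paddle::lite::Context<TARGET(kARM)>* ctx) {
  std::vector<int64_t> dims = {4, 8};  // input shape
  std::vector<int> axes = {1};         // slice along axis 1
  std::vector<int> starts = {2};
  std::vector<int> ends = {6};         // output shape becomes {4, 4}
  paddle::lite::arm::math::slice(in, dims, axes, starts, ends, out, ctx);
}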
......@@ -51,11 +51,11 @@ void split_cpy<float>(const float* din, float* dout, int num) {
}
}
template <>
void split<float>(const float* din,
const std::vector<lite::Tensor*>& dout,
const int axis,
const std::vector<int>& in_strides) {
template <typename T>
void split(const T* din,
const std::vector<lite::Tensor*>& dout,
const int axis,
const std::vector<int>& in_strides) {
int input_offset = 0;
for (auto out : dout) {
auto out_dim = out->dims();
......@@ -65,15 +65,15 @@ void split<float>(const float* din,
out_strides[i] = out_strides[i + 1] * out_dim[i];
}
float* out_data = out->mutable_data<float>();
T* out_data = out->mutable_data<T>();
int before = out_strides[0] / out_strides[axis];
int in_after = in_strides[axis];
int out_after = out_strides[axis];
const float* din_ptr = din + input_offset;
const T* din_ptr = din + input_offset;
for (int i = 0; i < before; ++i) {
std::memcpy(out_data, din_ptr, sizeof(float) * out_after);
std::memcpy(out_data, din_ptr, sizeof(T) * out_after);
din_ptr += in_after;
out_data += out_after;
}
......@@ -81,6 +81,15 @@ void split<float>(const float* din,
}
}
template void split(const float* din,
const std::vector<lite::Tensor*>& dout,
const int axis,
const std::vector<int>& in_strides);
template void split(const int64_t* din,
const std::vector<lite::Tensor*>& dout,
const int axis,
const std::vector<int>& in_strides);
} // namespace math
} // namespace arm
} // namespace lite
......
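Similarly, a minimal sketch of calling the templated math::split through its new int64_t instantiation (hypothetical shapes; the header locations and the example's function name are assumptions): a {4, 6} input is split into two {4, 3} outputs along axis 1.

#include <vector>

#include "lite/backends/arm/math/split.h"  // assumed header location
#include "lite/core/tensor.h"              // assumed header location

void split_int64_example(const int64_t* din,
                         paddle::lite::Tensor* out0,
                         paddle::lite::Tensor* out1) {
  // Outputs must carry their dims before the call, since split reads out->dims().
  out0->Resize(paddle::lite::DDim(std::vector<int64_t>({4, 3})));
  out1->Resize(paddle::lite::DDim(std::vector<int64_t>({4, 3})));
  std::vector<paddle::lite::Tensor*> dout = {out0, out1};
  // Suffix-product strides of the input shape {4, 6}, following the
  // convention used by SplitCompute::Run: in_strides[i] = prod(dims[i..end]).
  std::vector<int> in_strides = {4 * 6, 6};
  paddle::lite::arm::math::split(din, dout, /*axis=*/1, in_strides);
}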
......@@ -40,6 +40,11 @@ void CastCompute::Run() {
const auto* x_data = param.X->data<float>();
auto* o_data = param.Out->mutable_data<float>();
memcpy(o_data, x_data, sizeof(float) * param.X->numel());
} else if (param.in_dtype == param.out_dtype &&
param.in_dtype == 3) { // int64->int64
const auto* x_data = param.X->data<int64_t>();
auto* o_data = param.Out->mutable_data<int64_t>();
memcpy(o_data, x_data, sizeof(int64_t) * param.X->numel());
} else if (param.in_dtype == 21 && param.out_dtype == 5) { // int8->float32
const char* x_data_begin = param.X->data<char>();
const char* x_data_end = x_data_begin + param.X->numel();
......@@ -56,7 +61,7 @@ void CastCompute::Run() {
float* out_data = param.Out->mutable_data<float>();
std::transform(
x_data_begin, x_data_end, out_data, TransOp<unsigned char, float>);
} else if (param.in_dtype == 3 && param.out_dtype == 2) {
} else if (param.in_dtype == 3 && param.out_dtype == 2) { // int64->int32
const int64_t* x_data_begin = param.X->data<int64_t>();
const int64_t* x_data_end = x_data_begin + param.X->numel();
int32_t* out_data = param.Out->mutable_data<int32_t>();
......
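The integer codes compared above follow Paddle's VarType::Type encoding; the illustrative helper below (not part of the patch) lists only the values this kernel actually branches on, as confirmed by the comments in the hunk:

// Maps the dtype codes used by CastCompute to readable names.
const char* cast_dtype_name(int code) {
  switch (code) {
    case 2:  return "int32";
    case 3:  return "int64";    // new branch: 3 -> 3 is an identity copy
    case 5:  return "float32";
    case 20: return "uint8";
    case 21: return "int8";
    default: return "not handled by CastCompute";
  }
}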
......@@ -169,21 +169,47 @@ void SliceCompute<T, PType>::Run() {
using slice_float =
paddle::lite::kernels::arm::SliceCompute<float, PRECISION(kFloat)>;
REGISTER_LITE_KERNEL(slice, kARM, kFloat, kNCHW, slice_float, def)
.BindInput("Input", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("StartsTensor", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("EndsTensor", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("StartsTensorList", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("EndsTensorList", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Input",
{LiteType::GetTensorTy(TARGET(kARM), PRECISION(kFloat))})
.BindInput("StartsTensor",
{LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
.BindInput("EndsTensor",
{LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
.BindInput("StartsTensorList",
{LiteType::GetTensorListTy(TARGET(kARM), PRECISION(kInt32))})
.BindInput("EndsTensorList",
{LiteType::GetTensorListTy(TARGET(kARM), PRECISION(kInt32))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kFloat))})
.Finalize();
using slice_int32 =
paddle::lite::kernels::arm::SliceCompute<int, PRECISION(kInt32)>;
REGISTER_LITE_KERNEL(slice, kARM, kInt32, kNCHW, slice_int32, def)
.BindInput("Input",
{LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
.BindInput("StartsTensor", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("EndsTensor", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("StartsTensorList", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("EndsTensorList", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("StartsTensor",
{LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
.BindInput("EndsTensor",
{LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
.BindInput("StartsTensorList",
{LiteType::GetTensorListTy(TARGET(kARM), PRECISION(kInt32))})
.BindInput("EndsTensorList",
{LiteType::GetTensorListTy(TARGET(kARM), PRECISION(kInt32))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
.Finalize();
using slice_int64 =
paddle::lite::kernels::arm::SliceCompute<int64_t, PRECISION(kInt64)>;
REGISTER_LITE_KERNEL(slice, kARM, kInt64, kNCHW, slice_int64, def)
.BindInput("Input",
{LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))})
.BindInput("StartsTensor",
{LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
.BindInput("EndsTensor",
{LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
.BindInput("StartsTensorList",
{LiteType::GetTensorListTy(TARGET(kARM), PRECISION(kInt32))})
.BindInput("EndsTensorList",
{LiteType::GetTensorListTy(TARGET(kARM), PRECISION(kInt32))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))})
.Finalize();
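A quick sanity sketch in the style of the split unit test at the end of this patch (not part of the change; the test name is hypothetical) shows what the new int64 registration exposes:

TEST(slice_arm, init_int64) {
  paddle::lite::kernels::arm::SliceCompute<int64_t, PRECISION(kInt64)> slice;
  ASSERT_EQ(slice.precision(), PRECISION(kInt64));
  ASSERT_EQ(slice.target(), TARGET(kARM));
}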
......@@ -21,9 +21,10 @@ namespace lite {
namespace kernels {
namespace arm {
void SplitCompute::Run() {
auto& param = Param<operators::SplitParam>();
const float* din = param.x->data<float>();
template <typename T, PrecisionType PType>
void SplitCompute<T, PType>::Run() {
auto& param = this->template Param<operators::SplitParam>();
const T* din = param.x->template data<T>();
auto& dout = param.output;
auto in_dim = param.x->dims();
std::vector<int> in_strides(in_dim.size());
......@@ -42,12 +43,24 @@ void SplitCompute::Run() {
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(
split, kARM, kFloat, kNCHW, paddle::lite::kernels::arm::SplitCompute, def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
using split_float =
paddle::lite::kernels::arm::SplitCompute<float, PRECISION(kFloat)>;
REGISTER_LITE_KERNEL(split, kARM, kFloat, kNCHW, split_float, def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kFloat))})
.BindInput("AxisTensor",
{LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
.BindInput("SectionsTensorList",
{LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kFloat))})
.Finalize();
using split_int64 =
paddle::lite::kernels::arm::SplitCompute<int64_t, PRECISION(kInt64)>;
REGISTER_LITE_KERNEL(split, kARM, kInt64, kNCHW, split_int64, def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))})
.BindInput("AxisTensor",
{LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
.BindInput("SectionsTensorList",
{LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))})
.Finalize();
......@@ -22,7 +22,8 @@ namespace lite {
namespace kernels {
namespace arm {
class SplitCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
template <typename T, PrecisionType PType>
class SplitCompute : public KernelLite<TARGET(kARM), PType> {
public:
void Run() override;
......
......@@ -93,13 +93,13 @@ void split_compute_ref(const operators::SplitParam& param) {
}
TEST(split_arm, init) {
SplitCompute split;
SplitCompute<float, PRECISION(kFloat)> split;
ASSERT_EQ(split.precision(), PRECISION(kFloat));
ASSERT_EQ(split.target(), TARGET(kARM));
}
TEST(split_arm, compute) {
SplitCompute split;
SplitCompute<float, PRECISION(kFloat)> split;
operators::SplitParam param;
lite::Tensor x;
......
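An analogous check for the newly registered int64 split kernel would look like this (a sketch only; no such test is added in this patch):

TEST(split_arm, init_int64) {
  SplitCompute<int64_t, PRECISION(kInt64)> split;
  ASSERT_EQ(split.precision(), PRECISION(kInt64));
  ASSERT_EQ(split.target(), TARGET(kARM));
}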