diff --git a/lite/backends/arm/math/interpolate.cc b/lite/backends/arm/math/interpolate.cc
index f89410ad11590c60bf5542702b60fa883298d3e6..e9e18043dfc09001ebba23f952a59474630e54aa 100644
--- a/lite/backends/arm/math/interpolate.cc
+++ b/lite/backends/arm/math/interpolate.cc
@@ -22,6 +22,28 @@ namespace lite {
 namespace arm {
 namespace math {
 
+inline std::vector<int> get_new_shape(
+    std::vector<const lite::Tensor*> list_new_shape_tensor) {
+  // get tensor from
+  std::vector<int> vec_new_shape;
+  for (size_t i = 0; i < list_new_shape_tensor.size(); ++i) {
+    auto tensor = list_new_shape_tensor[i];
+    vec_new_shape.push_back(static_cast<int32_t>(*tensor->data<int32_t>()));
+  }
+
+  return vec_new_shape;
+}
+
+template <typename T>
+inline std::vector<T> get_new_data_from_tensor(const Tensor* new_data_tensor) {
+  std::vector<T> vec_new_data;
+  auto* new_data = new_data_tensor->data<T>();
+  lite::Tensor cpu_starts_tensor;
+  vec_new_data = std::vector<T>(
+      new_data, new_data + new_data_tensor->dims().production());
+  return vec_new_data;
+}
+
 // The following function bilinear_interp is partially base on
 // https://github.com/Tencent/ncnn/blob/master/src/layer/arm/interp_arm.cpp
 // Tencent is pleased to support the open source community by making ncnn
@@ -472,33 +494,52 @@ void nearest_interp(const float* src,
 
 void interpolate(lite::Tensor* X,
                  lite::Tensor* OutSize,
+                 std::vector<const lite::Tensor*> SizeTensor,
+                 lite::Tensor* Scale,
                  lite::Tensor* Out,
                  int out_height,
                  int out_width,
-                 float height_scale,
-                 float width_scale,
+                 float scale,
                  bool with_align,
                  std::string interpolate_type) {
+  int in_h = X->dims()[2];
+  int in_w = X->dims()[3];
+  if (SizeTensor.size() > 0) {
+    auto new_size = get_new_shape(SizeTensor);
+    out_height = new_size[0];
+    out_width = new_size[1];
+  } else {
+    auto scale_tensor = Scale;
+    if (scale_tensor != nullptr) {
+      auto scale_data = get_new_data_from_tensor<float>(scale_tensor);
+      scale = scale_data[0];
+    }
+    if (scale > 0) {
+      out_height = static_cast<int>(in_h * scale);
+      out_width = static_cast<int>(in_w * scale);
+    }
+    auto out_size = OutSize;
+    if (out_size != nullptr) {
+      auto out_size_data = get_new_data_from_tensor<int>(out_size);
+      out_height = static_cast<int>(out_size_data[0]);
+      out_width = static_cast<int>(out_size_data[1]);
+    }
+  }
+  float height_scale = scale;
+  float width_scale = scale;
   if (out_width > 0 && out_height > 0) {
     height_scale = static_cast<float>(out_height / X->dims()[2]);
     width_scale = static_cast<float>(out_width / X->dims()[3]);
   }
-  if (OutSize != nullptr) {
-    auto OutSize_data = OutSize->data<int>();
-    int h_out = OutSize_data[0];  // HW
-    int w_out = OutSize_data[1];  // HW
-    int num_cout = Out->dims()[0];
-    int c_cout = Out->dims()[1];
-    Out->Resize({num_cout, c_cout, h_out, w_out});
-  }
+  int num_cout = X->dims()[0];
+  int c_cout = X->dims()[1];
+  Out->Resize({num_cout, c_cout, out_height, out_width});
   float* dout = Out->mutable_data<float>();
   const float* din = X->data<float>();
   int out_num = Out->dims()[0];
   int out_c = Out->dims()[1];
   int count = out_num * out_c;
-  int in_h = X->dims()[2];
-  int in_w = X->dims()[3];
   int out_h = Out->dims()[2];
   int out_w = Out->dims()[3];
   int spatial_in = in_h * in_w;
diff --git a/lite/backends/arm/math/interpolate.h b/lite/backends/arm/math/interpolate.h
index be250f6a5e7581ba70809362d169167fea1d1c11..e9c41c5bc86c8f00d57e096e3cd2b5f37df3a474 100644
--- a/lite/backends/arm/math/interpolate.h
+++ b/lite/backends/arm/math/interpolate.h
@@ -44,11 +44,12 @@ void nearest_interp(const float* src,
 
 void interpolate(lite::Tensor* X,
                  lite::Tensor* OutSize,
+                 std::vector<const lite::Tensor*> SizeTensor,
+                 lite::Tensor* Scale,
                  lite::Tensor* Out,
                  int out_height,
                  int out_width,
-                 float height_scale,
-                 float width_scale,
+                 float scale,
                  bool with_align,
                  std::string interpolate_type);
diff --git a/lite/kernels/arm/interpolate_compute.cc b/lite/kernels/arm/interpolate_compute.cc
index a26777826db6976c755fac7798880871f407c12d..0398dabeaee4c042b33ac5572b783b126bc8ddb4 100644
--- a/lite/kernels/arm/interpolate_compute.cc
+++ b/lite/kernels/arm/interpolate_compute.cc
@@ -28,6 +28,8 @@ void BilinearInterpCompute::Run() {
   auto& param = Param<operators::InterpolateParam>();
   lite::Tensor* X = param.X;
   lite::Tensor* OutSize = param.OutSize;
+  auto SizeTensor = param.SizeTensor;
+  auto Scale = param.Scale;
   lite::Tensor* Out = param.Out;
   float scale = param.scale;
   int out_w = param.out_w;
@@ -36,11 +38,12 @@ void BilinearInterpCompute::Run() {
   std::string interp_method = "Bilinear";
   lite::arm::math::interpolate(X,
                                OutSize,
+                               SizeTensor,
+                               Scale,
                                Out,
                                out_h,
                                out_w,
                                scale,
-                               scale,
                                align_corners,
                                interp_method);
 }
@@ -49,6 +52,8 @@ void NearestInterpCompute::Run() {
   auto& param = Param<operators::InterpolateParam>();
   lite::Tensor* X = param.X;
   lite::Tensor* OutSize = param.OutSize;
+  auto SizeTensor = param.SizeTensor;
+  auto Scale = param.Scale;
   lite::Tensor* Out = param.Out;
   float scale = param.scale;
   int out_w = param.out_w;
@@ -57,11 +62,12 @@ void NearestInterpCompute::Run() {
   std::string interp_method = "Nearest";
   lite::arm::math::interpolate(X,
                                OutSize,
+                               SizeTensor,
+                               Scale,
                                Out,
                                out_h,
                                out_w,
                                scale,
-                               scale,
                                align_corners,
                                interp_method);
 }
@@ -79,6 +85,8 @@ REGISTER_LITE_KERNEL(bilinear_interp,
                     def)
     .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
     .BindInput("OutSize", {LiteType::GetTensorTy(TARGET(kARM))})
+    .BindInput("SizeTensor", {LiteType::GetTensorTy(TARGET(kARM))})
+    .BindInput("Scale", {LiteType::GetTensorTy(TARGET(kARM))})
     .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
     .Finalize();
@@ -90,5 +98,7 @@ REGISTER_LITE_KERNEL(nearest_interp,
                     def)
     .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
     .BindInput("OutSize", {LiteType::GetTensorTy(TARGET(kARM))})
+    .BindInput("SizeTensor", {LiteType::GetTensorTy(TARGET(kARM))})
+    .BindInput("Scale", {LiteType::GetTensorTy(TARGET(kARM))})
     .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
     .Finalize();
diff --git a/lite/kernels/cuda/bilinear_interp_compute.cu b/lite/kernels/cuda/bilinear_interp_compute.cu
index 7e1dbaf228c31d8123e48832e93e0180c4920359..a3bd89f642127da6fa7d7c475d87ab389313f975 100644
--- a/lite/kernels/cuda/bilinear_interp_compute.cu
+++ b/lite/kernels/cuda/bilinear_interp_compute.cu
@@ -11,6 +11,7 @@ limitations under the License.
*/ #pragma once #include +#include "lite/backends/cuda/target_wrapper.h" #include "lite/core/op_registry.h" #include "lite/kernels/cuda/bilinear_interp_compute.h" @@ -20,6 +21,43 @@ namespace kernels { namespace cuda { using Tensor = lite::Tensor; +inline std::vector get_new_shape( + std::vector list_new_shape_tensor) { + // get tensor from + std::vector vec_new_shape; + for (size_t i = 0; i < list_new_shape_tensor.size(); ++i) { + auto tensor = list_new_shape_tensor[i]; + lite::Tensor temp; + auto temp_data = temp.mutable_data(); + auto tensor_data = tensor->data(TARGET(kCUDA)); + cudaMemcpy(temp_data, + tensor_data, + tensor->dims().production() * sizeof(float), + cudaMemcpyDeviceToHost); + + vec_new_shape.push_back(static_cast(*temp_data)); + } + + return vec_new_shape; +} + +template +inline std::vector get_new_data_from_tensor(const Tensor* new_data_tensor) { + std::vector vec_new_data; + auto* new_data = new_data_tensor->data(kCUDA); + lite::Tensor cpu_starts_tensor; + auto cpu_starts_tensor_data = cpu_starts_tensor.mutable_data(); + cudaMemcpy(cpu_starts_tensor_data, + new_data, + new_data_tensor->dims().production() * sizeof(T), + cudaMemcpyDeviceToHost); + + auto new_data_ = cpu_starts_tensor.data(); + vec_new_data = std::vector( + new_data_, new_data_ + new_data_tensor->dims().production()); + return vec_new_data; +} + template __global__ void BilinearInterp(const T* in, const size_t in_img_h, @@ -103,19 +141,34 @@ void BilinearInterpCompute::Run() { int out_w = param.out_w; float scale = param.scale; bool align_corners = param.align_corners; - if (scale > 0) { - out_h = static_cast(in_h * scale); - out_w = static_cast(in_w * scale); - } + auto align_mode = param.align_mode; + + auto list_new_shape_tensor = param.SizeTensor; + if (list_new_shape_tensor.size() > 0) { + // have size tensor + auto new_size = get_new_shape(list_new_shape_tensor); + out_h = new_size[0]; + out_w = new_size[1]; + } else { + auto scale_tensor = param.Scale; + if (scale_tensor != nullptr) { + auto scale_data = get_new_data_from_tensor(scale_tensor); + scale = scale_data[0]; + } + if (scale > 0) { + out_h = static_cast(in_h * scale); + out_w = static_cast(in_w * scale); + } - if (out_size != nullptr) { - Tensor sizes; - float* size_data = sizes.mutable_data(); - float* outsize_data = out_size->mutable_data(TARGET(kCUDA)); - cudaMemcpy( - size_data, outsize_data, sizeof(float) * 2, cudaMemcpyDeviceToHost); - out_h = static_cast(size_data[0]); - out_w = static_cast(size_data[1]); + if (out_size != nullptr) { + lite::Tensor sizes; + float* size_data = sizes.mutable_data(); + float* outsize_data = out_size->mutable_data(TARGET(kCUDA)); + cudaMemcpy( + size_data, outsize_data, sizeof(float) * 2, cudaMemcpyDeviceToHost); + out_h = static_cast(size_data[0]); + out_w = static_cast(size_data[1]); + } } auto output_data = output->mutable_data(TARGET(kCUDA)); @@ -188,6 +241,14 @@ REGISTER_LITE_KERNEL(bilinear_interp, {LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kFloat), DATALAYOUT(kNCHW))}) + .BindInput("SizeTensor", + {LiteType::GetTensorTy(TARGET(kCUDA), + PRECISION(kFloat), + DATALAYOUT(kNCHW))}) + .BindInput("Scale", + {LiteType::GetTensorTy(TARGET(kCUDA), + PRECISION(kFloat), + DATALAYOUT(kNCHW))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kFloat), diff --git a/lite/kernels/cuda/bilinear_interp_compute_test.cc b/lite/kernels/cuda/bilinear_interp_compute_test.cc index e7e8143150d2963fb4cb74c3530cfd6e125a454c..d82823680ca2d12178f7438772b16dc48cf8cb2c 100644 --- 
a/lite/kernels/cuda/bilinear_interp_compute_test.cc +++ b/lite/kernels/cuda/bilinear_interp_compute_test.cc @@ -16,6 +16,7 @@ #include #include #include +#include namespace paddle { namespace lite { @@ -98,6 +99,110 @@ TEST(bilinear_interp, normal) { } } +TEST(bilinear_interp, update) { + BilinearInterpCompute bilinear_interp_kernel; + std::unique_ptr ctx(new KernelContext); + auto& context = ctx->As(); + + operators::InterpolateParam param; + + std::vector size_tensor(2), size_tensor_cpu(2), size_tensor_ref(2); + Tensor x, input_scale, osz, out; + Tensor x_cpu, input_scale_cpu, osz_cpu, out_cpu; + Tensor x_ref, size_tensor_ref, input_scale_ref, osz_ref, out_ref; + + int n = 1, c = 1, in_h = 3, in_w = 3; + int out_h = 6, out_w = 6; + float scale = 2.0; + + param.out_h = out_h; + param.out_w = out_w; + param.scale = scale; + param.align_corners = false; + param.align_mode = 0; + + x.Resize({n, c, in_h, in_w}); + size_tensor[0]->Resize({1}); + size_tensor[1]->Resize({1}); + input_scale.Resize({1}); + osz.Resize({2}); + out.Resize({n, c, out_h, out_w}); + + x_cpu.Resize({n, c, in_h, in_w}); + size_tensor_cpu[0]->Resize({1}); + size_tensor_cpu[1]->Resize({1}); + input_scale_cpu.Resize({1}); + osz_cpu.Resize({2}); + out_cpu.Resize({n, c, out_h, out_w}); + + x_ref.Resize({n, c, in_h, in_w}); + size_tensor_ref[0]->Resize({1}); + size_tensor_ref[1]->Resize({1}); + input_scale_ref.Resize({1}); + osz_ref.Resize({2}); + out_ref.Resize({n, c, out_h, out_w}); + + auto* out_data = out.mutable_data(TARGET(kCUDA)); + + float* x_cpu_data = x_cpu.mutable_data(); + float* size_tensor0_cpu_data = size_tensor_cpu[0]->mutable_data(); + float* size_tensor1_cpu_data = size_tensor_cpu[1]->mutable_data(); + float* input_scale_cpu_data = input_scale_cpu.mutable_data(); + float* osz_cpu_data = osz_cpu.mutable_data(); + float* out_cpu_data = out_cpu.mutable_data(); + + float* x_ref_data = x_ref.mutable_data(); + float* size_tensor0_ref_data = size_tensor_ref[0]->mutable_data(); + float* size_tensor1_ref_data = size_tensor_ref[1]->mutable_data(); + float* input_scale_ref_data = input_scale_ref.mutable_data(); + float* osz_ref_data = osz_ref.mutable_data(); + + for (int i = 0; i < x_cpu.numel(); ++i) { + x_cpu_data[i] = i + 5.0; + x_ref_data[i] = i + 5.0; + } + osz_cpu_data[0] = out_h; + osz_cpu_data[1] = out_w; + size_tensor0_cpu_data[0] = out_h; + size_tensor1_cpu_data[0] = out_w; + input_scale_cpu_data[0] = scale; + osz_ref_data[0] = out_h; + osz_ref_data[1] = out_w; + size_tensor0_ref_data[0] = out_h; + size_tensor1_ref_data[0] = out_w; + input_scale_ref_data[0] = scale; + + x.Assign(x_cpu_data, x_cpu.dims()); + size_tensor[0]->Assign( + size_tensor0_cpu_data, {1}); + size_tensor[1]->Assign( + size_tensor1_cpu_data, {1}); + input_scale.Assign(input_scale_cpu_data, + {1}); + osz.Assign(osz_cpu_data, osz_cpu.dims()); + + param.X = &x; + param.SizeTensor = size_tensor; + param.Scale = &input_scale; + param.OutSize = &osz; + param.Out = &out; + bilinear_interp_kernel.SetParam(param); + + cudaStream_t stream; + cudaStreamCreate(&stream); + context.SetExecStream(stream); + + bilinear_interp_kernel.SetContext(std::move(ctx)); + bilinear_interp_kernel.Launch(); + cudaDeviceSynchronize(); + + CopySync( + out_cpu_data, out_data, sizeof(float) * out.numel(), IoDirection::DtoH); + for (int i = 0; i < out.numel(); i++) { + LOG(INFO) << out_cpu_data[i]; + } +} + } // namespace cuda } // namespace kernels } // namespace lite diff --git a/lite/kernels/cuda/nearest_interp_compute.cu b/lite/kernels/cuda/nearest_interp_compute.cu 
index 1a614e0656b417786deff8df6b7a827433b33f7b..6653fa62df737606d18bdc975d8ffa411701a9af 100644 --- a/lite/kernels/cuda/nearest_interp_compute.cu +++ b/lite/kernels/cuda/nearest_interp_compute.cu @@ -11,6 +11,7 @@ limitations under the License. */ #pragma once #include +#include "lite/backends/cuda/target_wrapper.h" #include "lite/core/op_registry.h" #include "lite/kernels/cuda/nearest_interp_compute.h" @@ -20,6 +21,43 @@ namespace kernels { namespace cuda { using Tensor = lite::Tensor; +inline std::vector get_new_shape( + std::vector list_new_shape_tensor) { + // get tensor from + std::vector vec_new_shape; + for (size_t i = 0; i < list_new_shape_tensor.size(); ++i) { + auto tensor = list_new_shape_tensor[i]; + lite::Tensor temp; + auto temp_data = temp.mutable_data(); + auto tensor_data = tensor->data(TARGET(kCUDA)); + cudaMemcpy(temp_data, + tensor_data, + tensor->dims().production() * sizeof(float), + cudaMemcpyDeviceToHost); + + vec_new_shape.push_back(static_cast(*temp_data)); + } + + return vec_new_shape; +} + +template +inline std::vector get_new_data_from_tensor(const Tensor* new_data_tensor) { + std::vector vec_new_data; + auto* new_data = new_data_tensor->data(kCUDA); + lite::Tensor cpu_starts_tensor; + auto cpu_starts_tensor_data = cpu_starts_tensor.mutable_data(); + cudaMemcpy(cpu_starts_tensor_data, + new_data, + new_data_tensor->dims().production() * sizeof(T), + cudaMemcpyDeviceToHost); + + auto new_data_ = cpu_starts_tensor.data(); + vec_new_data = std::vector( + new_data_, new_data_ + new_data_tensor->dims().production()); + return vec_new_data; +} + __global__ void KeNearestNeighborInterp(const float* in, const size_t in_img_h, const size_t in_img_w, @@ -79,19 +117,34 @@ void NearestInterpCompute::Run() { int out_w = param.out_w; float scale = param.scale; bool align_corners = param.align_corners; - if (scale > 0) { - out_h = static_cast(in_h * scale); - out_w = static_cast(in_w * scale); - } - - if (out_size != nullptr) { - Tensor sizes; - float* size_data = sizes.mutable_data(); - float* outsize_data = out_size->mutable_data(TARGET(kCUDA)); - cudaMemcpy( - size_data, outsize_data, sizeof(float) * 2, cudaMemcpyDeviceToHost); - out_h = static_cast(size_data[0]); - out_w = static_cast(size_data[1]); + auto align_mode = param.align_mode; + + auto list_new_shape_tensor = param.SizeTensor; + if (list_new_shape_tensor.size() > 0) { + // have size tensor + auto new_size = get_new_shape(list_new_shape_tensor); + out_h = new_size[0]; + out_w = new_size[1]; + } else { + auto scale_tensor = param.Scale; + if (scale_tensor != nullptr) { + auto scale_data = get_new_data_from_tensor(scale_tensor); + scale = scale_data[0]; + } + if (scale > 0) { + out_h = static_cast(in_h * scale); + out_w = static_cast(in_w * scale); + } + + if (out_size != nullptr) { + lite::Tensor sizes; + float* size_data = sizes.mutable_data(); + float* outsize_data = out_size->mutable_data(TARGET(kCUDA)); + cudaMemcpy( + size_data, outsize_data, sizeof(float) * 2, cudaMemcpyDeviceToHost); + out_h = static_cast(size_data[0]); + out_w = static_cast(size_data[1]); + } } auto output_data = output->mutable_data(TARGET(kCUDA)); @@ -162,6 +215,14 @@ REGISTER_LITE_KERNEL(nearest_interp, {LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kFloat), DATALAYOUT(kNCHW))}) + .BindInput("SizeTensor", + {LiteType::GetTensorTy(TARGET(kCUDA), + PRECISION(kFloat), + DATALAYOUT(kNCHW))}) + .BindInput("Scale", + {LiteType::GetTensorTy(TARGET(kCUDA), + PRECISION(kFloat), + DATALAYOUT(kNCHW))}) .BindOutput("Out", 
{LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kFloat), diff --git a/lite/kernels/cuda/nearest_interp_compute_test.cc b/lite/kernels/cuda/nearest_interp_compute_test.cc index 85032016d630f11bbfe150f750470e89e241c61b..e8cb01b5e001c10015a8126f05c86d0f445865ac 100644 --- a/lite/kernels/cuda/nearest_interp_compute_test.cc +++ b/lite/kernels/cuda/nearest_interp_compute_test.cc @@ -16,6 +16,7 @@ #include #include #include +#include namespace paddle { namespace lite { @@ -143,6 +144,110 @@ TEST(nearest_interp, normal) { } } +TEST(nearest_interp, update) { + NearestInterpCompute nearest_interp_kernel; + std::unique_ptr ctx(new KernelContext); + auto& context = ctx->As(); + + operators::InterpolateParam param; + + std::vector size_tensor(2), size_tensor_cpu(2), size_tensor_ref(2); + Tensor x, input_scale, osz, out; + Tensor x_cpu, input_scale_cpu, osz_cpu, out_cpu; + Tensor x_ref, size_tensor_ref, input_scale_ref, osz_ref, out_ref; + + int n = 1, c = 3, in_h = 40, in_w = 40; + int out_h = 80, out_w = 80; + float scale = 2.0; + + param.out_h = out_h; + param.out_w = out_w; + param.scale = scale; + param.align_corners = false; + param.align_mode = 0; + + x.Resize({n, c, in_h, in_w}); + size_tensor[0]->Resize({1}); + size_tensor[1]->Resize({1}); + input_scale.Resize({1}); + osz.Resize({2}); + out.Resize({n, c, out_h, out_w}); + + x_cpu.Resize({n, c, in_h, in_w}); + size_tensor_cpu[0]->Resize({1}); + size_tensor_cpu[1]->Resize({1}); + input_scale_cpu.Resize({1}); + osz_cpu.Resize({2}); + out_cpu.Resize({n, c, out_h, out_w}); + + x_ref.Resize({n, c, in_h, in_w}); + size_tensor_ref[0]->Resize({1}); + size_tensor_ref[1]->Resize({1}); + input_scale_ref.Resize({1}); + osz_ref.Resize({2}); + out_ref.Resize({n, c, out_h, out_w}); + + auto* out_data = out.mutable_data(TARGET(kCUDA)); + + float* x_cpu_data = x_cpu.mutable_data(); + float* size_tensor0_cpu_data = size_tensor_cpu[0]->mutable_data(); + float* size_tensor1_cpu_data = size_tensor_cpu[1]->mutable_data(); + float* input_scale_cpu_data = input_scale_cpu.mutable_data(); + float* osz_cpu_data = osz_cpu.mutable_data(); + float* out_cpu_data = out_cpu.mutable_data(); + + float* x_ref_data = x_ref.mutable_data(); + float* size_tensor0_ref_data = size_tensor_ref[0]->mutable_data(); + float* size_tensor1_ref_data = size_tensor_ref[1]->mutable_data(); + float* input_scale_ref_data = input_scale_ref.mutable_data(); + float* osz_ref_data = osz_ref.mutable_data(); + + for (int i = 0; i < x_cpu.numel(); ++i) { + x_cpu_data[i] = i + 5.0; + x_ref_data[i] = i + 5.0; + } + osz_cpu_data[0] = out_h; + osz_cpu_data[1] = out_w; + size_tensor0_cpu_data[0] = out_h; + size_tensor1_cpu_data[0] = out_w; + input_scale_cpu_data[0] = scale; + osz_ref_data[0] = out_h; + osz_ref_data[1] = out_w; + size_tensor0_ref_data[0] = out_h; + size_tensor1_ref_data[0] = out_w; + input_scale_ref_data[0] = scale; + + x.Assign(x_cpu_data, x_cpu.dims()); + size_tensor[0]->Assign( + size_tensor0_cpu_data, {1}); + size_tensor[1]->Assign( + size_tensor1_cpu_data, {1}); + input_scale.Assign(input_scale_cpu_data, + {1}); + osz.Assign(osz_cpu_data, osz_cpu.dims()); + + param.X = &x; + param.SizeTensor = size_tensor; + param.Scale = &input_scale; + param.OutSize = &osz; + param.Out = &out; + nearest_interp_kernel.SetParam(param); + + cudaStream_t stream; + cudaStreamCreate(&stream); + context.SetExecStream(stream); + + nearest_interp_kernel.SetContext(std::move(ctx)); + nearest_interp_kernel.Launch(); + cudaDeviceSynchronize(); + + CopySync( + out_cpu_data, out_data, sizeof(float) * out.numel(), 
            IoDirection::DtoH);
+  for (int i = 0; i < out.numel(); i++) {
+    LOG(INFO) << out_cpu_data[i];
+  }
+}
+
 }  // namespace cuda
 }  // namespace kernels
 }  // namespace lite
diff --git a/lite/operators/interpolate_op.cc b/lite/operators/interpolate_op.cc
index b98240ba4f255377c0ac661950a45bef0a7d0516..936da73d89007f4f6dd36fa770df537996c40a51 100644
--- a/lite/operators/interpolate_op.cc
+++ b/lite/operators/interpolate_op.cc
@@ -45,23 +45,42 @@ bool InterpolateOp::InferShape() const {
   int out_h;
   int out_w;
-  if (OutSize != nullptr) {
-    auto outsize_data = OutSize->data<int>();
-    int h_out = outsize_data[0];  // HW
-    int w_out = outsize_data[1];  // HW
-    param_.Out->Resize({n, c, h_out, w_out});
+  auto SizeTensor = param_.SizeTensor;
+  if (!SizeTensor.empty()) {
+    CHECK(SizeTensor.size() == 2)
+        << "Input(SizeTensor)'size of Op(interpolate) must be 2. "
+           "Attr(out_shape)'s length must be 2 for 4-D input tensor.";
+    out_h = param_.out_h;
+    out_w = param_.out_w;
+    param_.Out->Resize({n, c, out_h, out_w});
+    return true;
+  }
+
+  auto Scale = param_.Scale;
+  if (Scale) {
+    auto scale_dims = Scale->dims();
+    CHECK(scale_dims.size() == 1) << "Scale's dimension size must be 1.";
+    out_h = -1;
+    out_w = -1;
   } else {
-    if (0 >= param_.out_h && 0 >= param_.out_w) {
-      out_h = h * param_.scale;
-      out_w = w * param_.scale;
+    auto scale = param_.scale;
+    if (scale > 0) {
+      out_h = static_cast<int>(h * scale);
+      out_w = static_cast<int>(w * scale);
       out_h = out_h > 0 ? out_h : -1;
       out_w = out_w > 0 ? out_w : -1;
     } else {
       out_h = param_.out_h;
       out_w = param_.out_w;
     }
-    param_.Out->Resize({n, c, out_h, out_w});
   }
+
+  if (OutSize != nullptr) {
+    auto out_lod = param_.Out->mutable_lod();
+    *out_lod = param_.X->lod();
+  }
+  param_.Out->Resize({n, c, out_h, out_w});
+
   return true;
 }
@@ -76,6 +95,24 @@ bool InterpolateOp::AttachImpl(const cpp::OpDesc& op_desc, lite::Scope* scope) {
   } else {
     param_.OutSize = nullptr;
   }
+
+  if (op_desc.HasInput("SizeTensor")) {
+    auto size_tensor = op_desc.Input("SizeTensor");
+    for (auto var : size_tensor) {
+      param_.SizeTensor.push_back(
+          scope->FindVar(var)->GetMutable<lite::Tensor>());
+    }
+  }
+
+  if (op_desc.HasInput("Scale")) {
+    auto scale_var_names = op_desc.Input("Scale");
+    if (scale_var_names.size() > 0) {
+      param_.Scale =
+          scope->FindVar(scale_var_names.front())->GetMutable<lite::Tensor>();
+    }
+  } else {
+    param_.Scale = nullptr;
+  }
   auto Out = op_desc.Output("Out").front();
   param_.X = scope->FindVar(X)->GetMutable<lite::Tensor>();
   param_.Out = scope->FindVar(Out)->GetMutable<lite::Tensor>();
diff --git a/lite/operators/op_params.h b/lite/operators/op_params.h
index 32b80518e505c9dbc46d392308cf572a4e7f1278..474c97559041d069ccdaa2e149c83cea4ea9ae2c 100644
--- a/lite/operators/op_params.h
+++ b/lite/operators/op_params.h
@@ -94,6 +94,8 @@ struct InterpolateParam {
   lite::Tensor* X{};
   lite::Tensor* OutSize{};
   lite::Tensor* Out{};
+  std::vector<const lite::Tensor*> SizeTensor;
+  lite::Tensor* Scale;
   float scale{0.f};
   int out_h{-1};
@@ -101,6 +103,7 @@
   bool align_corners{true};
   int align_mode{1};
   std::string interp_method{"Nearest"};
+  DataLayoutType data_layout{DATALAYOUT(kNCHW)};
 };
 
 // For Mul Op
diff --git a/lite/tests/kernels/bilinear_interp_compute_test.cc b/lite/tests/kernels/bilinear_interp_compute_test.cc
index 0779caf67aac907e6f8ccde8b3e65d413cf65db9..7ea4293f080df31d9bb05b4998b5b2d9ae7d5a47 100644
--- a/lite/tests/kernels/bilinear_interp_compute_test.cc
+++ b/lite/tests/kernels/bilinear_interp_compute_test.cc
@@ -22,6 +22,27 @@
 namespace paddle {
 namespace lite {
 
+inline std::vector<int> get_new_shape(
+    std::vector<const lite::Tensor*> list_new_shape_tensor) {
+ // get tensor from + std::vector vec_new_shape; + for (size_t i = 0; i < list_new_shape_tensor.size(); ++i) { + auto tensor = list_new_shape_tensor[i]; + vec_new_shape.push_back(static_cast(*(tensor->data()))); + } + return vec_new_shape; +} + +template +inline std::vector get_new_data_from_tensor(const Tensor* new_data_tensor) { + std::vector vec_new_data; + auto* new_data = new_data_tensor->data(); + lite::Tensor cpu_starts_tensor; + vec_new_data = + std::vector(new_data, new_data + new_data_tensor->dims().production()); + return vec_new_data; +} + template void resize_bilinear_align(std::vector inputs, lite::Tensor* output) { @@ -149,6 +170,9 @@ class BilinearInterpComputeTester : public arena::TestCase { protected: // common attributes for this op. std::string input0_ = "X"; + std::string sizetensor0_ = "SizeTensor0"; + std::string sizetensor1_ = "SizeTensor1"; + std::string input_scale_ = "Scale"; std::string input1_ = "OutSize"; std::string output_ = "Out"; @@ -162,6 +186,8 @@ class BilinearInterpComputeTester : public arena::TestCase { std::string interp_method_ = "Bilinear"; DDim _dims0_{{1, 1, 16, 16}}; DDim _dims1_{{2}}; + DDim sizetensor_dims_{{1}}; + DDim scale_dims_{{1}}; public: BilinearInterpComputeTester(const Place& place, @@ -190,33 +216,48 @@ class BilinearInterpComputeTester : public arena::TestCase { if (outsize_height_ > 0 && outsize_width_ > 0) { inputs.emplace_back(scope->FindTensor(input1_)); } + std::vector SizeTensor; + if (outsize_height_ > 0 && outsize_width_ > 0) { + SizeTensor.emplace_back(scope->FindTensor(sizetensor0_)); + SizeTensor.emplace_back(scope->FindTensor(sizetensor1_)); + } + const lite::Tensor* input_scale = scope->FindTensor(input_scale_); + float scale = height_scale_; + int in_h = inputs[0]->dims()[2]; + int in_w = inputs[0]->dims()[3]; + if (SizeTensor.size() > 0) { + auto new_size = get_new_shape(SizeTensor); + out_height_ = new_size[0]; + out_width_ = new_size[1]; + } else { + auto scale_tensor = input_scale; + if (scale_tensor != nullptr) { + auto scale_data = get_new_data_from_tensor(scale_tensor); + scale = scale_data[0]; + } + if (scale > 0) { + out_height_ = static_cast(in_h * scale); + out_width_ = static_cast(in_w * scale); + } + if (inputs.size() > 1) { + auto out_size = inputs[1]; + auto out_size_data = get_new_data_from_tensor(out_size); + out_height_ = out_size_data[0]; + out_width_ = out_size_data[1]; + } + } + height_scale_ = scale; + width_scale_ = scale; + if (out_width_ != -1 && out_height_ != -1) { height_scale_ = static_cast(out_height_ / inputs[0]->dims()[2]); width_scale_ = static_cast(out_width_ / inputs[0]->dims()[3]); } auto* outputs = scope->NewTensor(output_); CHECK(outputs); - if (inputs.size() > 1) { - auto outsize_data = inputs[1]->data(); - int h_out = outsize_data[0]; // HW - int w_out = outsize_data[1]; // HW - int num_cout = inputs[0]->dims()[0]; - int c_cout = inputs[0]->dims()[1]; - outputs->Resize({num_cout, c_cout, h_out, w_out}); - } else { - int out_h; - int out_w; - if (-1 == out_height_ && -1 == out_width_) { - out_h = inputs[0]->dims()[2] * height_scale_; - out_w = inputs[0]->dims()[3] * width_scale_; - } else { - out_h = out_height_; - out_w = out_width_; - } - outputs->Resize( - {inputs[0]->dims()[0], inputs[0]->dims()[1], out_h, out_w}); - } - + int num_cout = inputs[0]->dims()[0]; + int c_cout = inputs[0]->dims()[1]; + outputs->Resize({num_cout, c_cout, out_height_, out_width_}); if (align_corners_) { resize_bilinear_align(inputs, outputs); } else { @@ -229,6 +270,10 @@ class 
BilinearInterpComputeTester : public arena::TestCase { op_desc->SetInput("X", {input0_}); if (outsize_height_ > 0 && outsize_width_ > 0) { op_desc->SetInput("OutSize", {input1_}); + op_desc->SetInput("SizeTensor", {sizetensor0_, sizetensor1_}); + } + if (height_scale_ > 0) { + op_desc->SetInput("Scale", {input_scale_}); } op_desc->SetOutput("Out", {output_}); op_desc->SetAttr("scale", height_scale_); @@ -250,6 +295,19 @@ class BilinearInterpComputeTester : public arena::TestCase { data1[0] = outsize_height_; data1[1] = outsize_width_; SetCommonTensor(input1_, _dims1_, data1.data()); + + std::vector sizetensor_data(1); + sizetensor_data[0] = outsize_height_; + SetCommonTensor(sizetensor0_, sizetensor_dims_, sizetensor_data.data()); + + sizetensor_data[0] = outsize_width_; + SetCommonTensor(sizetensor1_, sizetensor_dims_, sizetensor_data.data()); + } + + if (height_scale_ > 0) { + std::vector scale_data(1); + scale_data[0] = height_scale_; + SetCommonTensor(input_scale_, scale_dims_, scale_data.data()); } } }; diff --git a/lite/tests/kernels/nearest_interp_compute_test.cc b/lite/tests/kernels/nearest_interp_compute_test.cc index 3256ababcab639cd31ef51294a890b7fbdb54d5d..894959f9090cce8a391c146815f550d5f42adcb6 100644 --- a/lite/tests/kernels/nearest_interp_compute_test.cc +++ b/lite/tests/kernels/nearest_interp_compute_test.cc @@ -22,6 +22,28 @@ namespace paddle { namespace lite { +inline std::vector get_new_shape( + const std::vector& list_new_shape_tensor) { + // get tensor from + std::vector vec_new_shape; + for (size_t i = 0; i < list_new_shape_tensor.size(); ++i) { + auto tensor = list_new_shape_tensor[i]; + vec_new_shape.push_back(static_cast(*tensor->data())); + } + + return vec_new_shape; +} + +template +inline std::vector get_new_data_from_tensor(const Tensor* new_data_tensor) { + std::vector vec_new_data; + auto* new_data = new_data_tensor->data(); + lite::Tensor cpu_starts_tensor; + vec_new_data = + std::vector(new_data, new_data + new_data_tensor->dims().production()); + return vec_new_data; +} + template void resize_nearest_align(std::vector inputs, lite::Tensor* output, @@ -73,6 +95,9 @@ class NearestInterpComputeTester : public arena::TestCase { protected: // common attributes for this op. 
std::string input0_ = "X"; + std::string sizetensor0_ = "SizeTensor0"; + std::string sizetensor1_ = "SizeTensor1"; + std::string input_scale_ = "Scale"; std::string input1_ = "OutSize"; std::string output_ = "Out"; @@ -85,6 +110,8 @@ class NearestInterpComputeTester : public arena::TestCase { DDim dims_{{2, 3}}; DDim _dims0_{{2, 3, 3, 2}}; DDim _dims1_{{2}}; + DDim sizetensor_dims_{{1}}; + DDim scale_dims_{{1}}; public: NearestInterpComputeTester(const Place& place, @@ -112,24 +139,54 @@ class NearestInterpComputeTester : public arena::TestCase { inputs.emplace_back(scope->FindTensor(input0_)); inputs.emplace_back(scope->FindTensor(input1_)); - auto outsize_data = inputs[1]->data(); + std::vector SizeTensor(2); + SizeTensor[0] = scope->FindTensor(sizetensor0_); + SizeTensor[1] = scope->FindTensor(sizetensor1_); + const lite::Tensor* input_scale = scope->FindTensor(input_scale_); + + float scale = height_scale_; + int in_h = inputs[0]->dims()[2]; + int in_w = inputs[0]->dims()[3]; + if (SizeTensor.size() > 0) { + auto new_size = get_new_shape(SizeTensor); + out_height_ = new_size[0]; + out_width_ = new_size[1]; + } else { + auto scale_tensor = input_scale; + if (scale_tensor != nullptr) { + auto scale_data = get_new_data_from_tensor(scale_tensor); + scale = scale_data[0]; + } + if (scale > 0) { + out_height_ = static_cast(in_h * scale); + out_width_ = static_cast(in_w * scale); + } + auto out_size = inputs[1]; + if (out_size != nullptr) { + auto out_size_data = get_new_data_from_tensor(out_size); + out_height_ = out_size_data[0]; + out_width_ = out_size_data[1]; + } + } + height_scale_ = scale; + width_scale_ = scale; + if (out_width_ != -1 && out_height_ != -1) { height_scale_ = static_cast(out_height_ / inputs[0]->dims()[2]); width_scale_ = static_cast(out_width_ / inputs[0]->dims()[3]); } - if (inputs.size() > 1) { - int h_out = outsize_data[0]; // HW - int w_out = outsize_data[1]; // HW - int num_cout = outputs->dims()[0]; - int c_cout = outputs->dims()[1]; - outputs->Resize({num_cout, c_cout, h_out, w_out}); - } + int num_cout = inputs[0]->dims()[0]; + int c_cout = inputs[0]->dims()[1]; + outputs->Resize({num_cout, c_cout, out_height_, out_width_}); + resize_nearest_align(inputs, outputs, align_corners_); } void PrepareOpDesc(cpp::OpDesc* op_desc) { op_desc->SetType("nearest_interp"); op_desc->SetInput("X", {input0_}); + op_desc->SetInput("SizeTensor", {sizetensor0_, sizetensor1_}); + op_desc->SetInput("Scale", {input_scale_}); op_desc->SetInput("OutSize", {input1_}); op_desc->SetOutput("Out", {output_}); op_desc->SetAttr("scale", height_scale_); @@ -152,6 +209,17 @@ class NearestInterpComputeTester : public arena::TestCase { SetCommonTensor(input0_, _dims0_, data0.data()); SetCommonTensor(input1_, _dims1_, data1.data()); + + std::vector sizetensor_data(1); + sizetensor_data[0] = out_height_; + SetCommonTensor(sizetensor0_, sizetensor_dims_, sizetensor_data.data()); + + sizetensor_data[0] = out_width_; + SetCommonTensor(sizetensor1_, sizetensor_dims_, sizetensor_data.data()); + + std::vector scale_data(1); + scale_data[0] = height_scale_; + SetCommonTensor(input_scale_, scale_dims_, scale_data.data()); } }; diff --git a/lite/tests/kernels/shuffle_channel_compute_test.cc b/lite/tests/kernels/shuffle_channel_compute_test.cc index d0e9912e65de7a0aae10f83c31ba4ab5bbd50890..66123625fae606a9022537698cdc1032abb13451 100644 --- a/lite/tests/kernels/shuffle_channel_compute_test.cc +++ b/lite/tests/kernels/shuffle_channel_compute_test.cc @@ -12,12 +12,9 @@ // See the License for the specific 
language governing permissions and
 // limitations under the License.
 
-// TODO(zhengxi)
-// shuffle_channel_test can pass on local compilation
-// while on ci compilation, the test will be killed immediately.
-
-/*
-#include <gtest/gtest.h>
+// TODO(FrostML): shuffle_channel cannot pass on CI, but runs fine on a local
+// machine. Re-enable this test once the CI failure is resolved.
+/*#include <gtest/gtest.h>
 #include "lite/api/paddle_use_kernels.h"
 #include "lite/api/paddle_use_ops.h"
 #include "lite/core/arena/framework.h"
@@ -30,8 +27,8 @@ class ShuffleChannelComputeTester : public arena::TestCase {
   // common attributes for this op.
   std::string input_ = "X";
   std::string output_ = "Out";
-  int group_ = 1;
-  DDim dims_{{1, 2}};
+  int group_ = 4;
+  DDim dims_{{10, 16, 4, 4}};
 
  public:
   ShuffleChannelComputeTester(const Place& place,
@@ -87,7 +84,7 @@ class ShuffleChannelComputeTester : public arena::TestCase {
 };
 
 void test_shuffle_channel(Place place) {
-  for (int group : {1, 2, 3}) {
+  for (int group : {4}) {
     std::unique_ptr<arena::TestCase> tester(
         new ShuffleChannelComputeTester(place, "def", group));
     arena::Arena arena(std::move(tester), place, 2e-5);
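
Taken together, the patch makes every interpolate kernel resolve its output size in the same priority order: the SizeTensor list of one-element tensors wins outright; otherwise a runtime Scale tensor (or, failing that, the scale attribute) derives a size that an OutSize tensor may still override. The standalone sketch below illustrates only that ordering; the struct, field names, and plain standard-library types are simplified stand-ins for the Paddle-Lite tensors, not the actual lite::Tensor API.

```cpp
// Illustrative sketch of the output-size resolution order; not Paddle-Lite code.
#include <cstdio>
#include <optional>
#include <utility>
#include <vector>

struct InterpInputs {
  std::vector<int> size_tensor;       // e.g. {out_h, out_w}; empty if absent
  std::optional<float> scale_tensor;  // runtime Scale input, if bound
  float scale_attr = 0.f;             // "scale" attribute from the op desc
  std::vector<int> out_size;          // OutSize tensor data; empty if absent
  int out_h_attr = -1, out_w_attr = -1;
};

// Mirrors the priority used by the patched kernels: SizeTensor beats
// everything; otherwise Scale (tensor first, then attribute) computes a
// size that OutSize may still override.
std::pair<int, int> ResolveOutHW(const InterpInputs& in, int in_h, int in_w) {
  if (!in.size_tensor.empty()) {
    return {in.size_tensor[0], in.size_tensor[1]};
  }
  float scale = in.scale_tensor.value_or(in.scale_attr);
  int out_h = in.out_h_attr;
  int out_w = in.out_w_attr;
  if (scale > 0) {
    out_h = static_cast<int>(in_h * scale);
    out_w = static_cast<int>(in_w * scale);
  }
  if (!in.out_size.empty()) {
    out_h = in.out_size[0];
    out_w = in.out_size[1];
  }
  return {out_h, out_w};
}

int main() {
  InterpInputs in;
  in.scale_attr = 2.f;
  in.out_size = {5, 7};  // OutSize overrides the scale-derived 6x6
  auto hw = ResolveOutHW(in, 3, 3);
  std::printf("out_h=%d out_w=%d\n", hw.first, hw.second);  // out_h=5 out_w=7
  return 0;
}
```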