diff --git a/paddle/fluid/lite/arm/math/CMakeLists.txt b/paddle/fluid/lite/arm/math/CMakeLists.txt index 02a72ad8a737ae3c1bc1e7ec3ac2822c5c0b9388..3c0fa97afae0f83485004e14ddf60eac97cdef3e 100644 --- a/paddle/fluid/lite/arm/math/CMakeLists.txt +++ b/paddle/fluid/lite/arm/math/CMakeLists.txt @@ -6,4 +6,4 @@ if(NOT (LITE_WITH_LIGHT_WEIGHT_FRAMEWORK AND LITE_WITH_ARM)) return() endif() -cc_library(math_arm SRCS funcs.cc packed_sgemm.cc softmax.cc DEPS ${lite_kernel_deps} eigen3) +cc_library(math_arm SRCS funcs.cc packed_sgemm.cc softmax.cc scale.cc DEPS ${lite_kernel_deps} eigen3) diff --git a/paddle/fluid/lite/arm/math/scale.cc b/paddle/fluid/lite/arm/math/scale.cc new file mode 100644 index 0000000000000000000000000000000000000000..40b91e6979f6f330f96f4c086fe1856707d9b189 --- /dev/null +++ b/paddle/fluid/lite/arm/math/scale.cc @@ -0,0 +1,64 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/lite/arm/math/scale.h" +#include "paddle/fluid/lite/arm/math/funcs.h" + +namespace paddle { +namespace lite { +namespace arm { +namespace math { + +template <> +void scale(const float* din, float* dout, int num, float scale, + float bias) { + int cnt = num >> 4; + int remain = num % 16; + float32x4_t vscale = vdupq_n_f32(scale); + float32x4_t vbias = vdupq_n_f32(bias); +#pragma omp parallel for + for (int i = 0; i < cnt; i++) { + const float* din_ptr = din + (i << 4); + float* dout_ptr = dout + (i << 4); + + float32x4_t din0 = vld1q_f32(din_ptr); + float32x4_t din1 = vld1q_f32(din_ptr + 4); + float32x4_t din2 = vld1q_f32(din_ptr + 8); + float32x4_t din3 = vld1q_f32(din_ptr + 12); + + float32x4_t vsum1 = vmlaq_f32(vbias, din0, vscale); + float32x4_t vsum2 = vmlaq_f32(vbias, din1, vscale); + float32x4_t vsum3 = vmlaq_f32(vbias, din2, vscale); + float32x4_t vsum4 = vmlaq_f32(vbias, din3, vscale); + + vst1q_f32(dout_ptr, vsum1); + vst1q_f32(dout_ptr + 4, vsum2); + vst1q_f32(dout_ptr + 8, vsum3); + vst1q_f32(dout_ptr + 12, vsum4); + } + if (remain > 0) { + const float* din_ptr = din + (cnt << 4); + float* dout_ptr = dout + (cnt << 4); + for (int i = 0; i < remain; i++) { + *dout_ptr = *din_ptr * scale + bias; + dout_ptr++; + din_ptr++; + } + } +} + +} // namespace math +} // namespace arm +} // namespace lite +} // namespace paddle diff --git a/paddle/fluid/lite/arm/math/scale.h b/paddle/fluid/lite/arm/math/scale.h new file mode 100644 index 0000000000000000000000000000000000000000..97a5f79fc6bfabee5e38854e2ba89ce388648aac --- /dev/null +++ b/paddle/fluid/lite/arm/math/scale.h @@ -0,0 +1,28 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +namespace paddle { +namespace lite { +namespace arm { +namespace math { + +template +void scale(const T* din, T* dout, int num, float scale, float bias); + +} // namespace math +} // namespace arm +} // namespace lite +} // namespace paddle diff --git a/paddle/fluid/lite/core/hvy_tensor.h b/paddle/fluid/lite/core/hvy_tensor.h index 1fa8dbbee33a8776e87abe81493680ddd45a9420..7c60e1661ce1384ef5b49a57c7a5145697482909 100644 --- a/paddle/fluid/lite/core/hvy_tensor.h +++ b/paddle/fluid/lite/core/hvy_tensor.h @@ -102,7 +102,8 @@ class TensorHvy : public TensorBase { data_.ShareDataWith(other.data_); } void CopyDataFrom(const TensorHvy& other) { - data_.ShareDataWith(other.data_); + data_.mutable_data(other.data_.place(), other.data_.type()); + TensorCopySync(other.data_, data_.place(), &data_); } DDimT dims() const { return DDimT(framework::vectorize(data_.dims())); } diff --git a/paddle/fluid/lite/core/lite_tensor.h b/paddle/fluid/lite/core/lite_tensor.h index 433bc6911164f106bbf595b77a5665597bc1ce34..24f28300d05a2c56c9305a22624fca755ff3d3ef 100644 --- a/paddle/fluid/lite/core/lite_tensor.h +++ b/paddle/fluid/lite/core/lite_tensor.h @@ -37,7 +37,7 @@ class DDimLite : public DDimBase { void ConstructFrom(const std::vector &x) { data_ = x; } value_type operator[](int offset) const { return data_[offset]; } - std::vector Vectorize() { return data_; } + std::vector Vectorize() const { return data_; } size_t size() const { return data_.size(); } bool empty() const { return data_.empty(); } diff --git a/paddle/fluid/lite/core/tensor.h b/paddle/fluid/lite/core/tensor.h index 807fbfc6a623501ddebaf14a95ffe0f57d478f1c..11b682a617c8654b18c80e91b3dcdb7057f6d264 100644 --- a/paddle/fluid/lite/core/tensor.h +++ b/paddle/fluid/lite/core/tensor.h @@ -48,7 +48,7 @@ class DDimBase { explicit DDimBase(const std::vector &x) { self()->ConstructFrom(x); } value_type operator[](int offset) const { return (*self())[offset]; } - std::vector Vectorize() { return self()->Vectorize(); } + std::vector Vectorize() const { return self()->Vectorize(); } size_t size() const { return const_self()->size(); } bool empty() const { return const_self()->empty(); } diff --git a/paddle/fluid/lite/kernels/arm/CMakeLists.txt b/paddle/fluid/lite/kernels/arm/CMakeLists.txt index 82b1a07810ffc265510b2e6d7dac1d2cbadcbbb5..1b5f5b7fdc3399afbddf114e2096fc06810237ed 100644 --- a/paddle/fluid/lite/kernels/arm/CMakeLists.txt +++ b/paddle/fluid/lite/kernels/arm/CMakeLists.txt @@ -7,10 +7,11 @@ message(STATUS "compile with lite ARM kernels") cc_library(fc_compute_arm SRCS fc_compute.cc DEPS ${lite_kernel_deps} math_arm) cc_library(relu_compute_arm SRCS relu_compute.cc DEPS ${lite_kernel_deps}) cc_library(mul_compute_arm SRCS mul_compute.cc DEPS ${lite_kernel_deps} eigen3) -cc_library(scale_compute_arm SRCS scale_compute.cc DEPS ${lite_kernel_deps} eigen3) +cc_library(scale_compute_arm SRCS scale_compute.cc DEPS ${lite_kernel_deps} math_arm) cc_library(softmax_compute_arm SRCS softmax_compute.cc DEPS ${lite_kernel_deps} math_arm) lite_cc_test(test_fc_compute_arm SRCS fc_compute_test.cc DEPS fc_compute_arm math_arm) +lite_cc_test(test_scale_compute_arm SRCS scale_compute_test.cc DEPS scale_compute_arm) lite_cc_test(test_softmax_compute_arm SRCS softmax_compute_test.cc DEPS softmax_compute_arm) set(arm_kernels diff --git a/paddle/fluid/lite/kernels/arm/scale_compute.cc b/paddle/fluid/lite/kernels/arm/scale_compute.cc index f078318e42ac1e8eaeff752a17d35008d86b2d4a..a89e19fb05a412ade6cb630f4d800cdcc244f75c 100644 --- a/paddle/fluid/lite/kernels/arm/scale_compute.cc +++ b/paddle/fluid/lite/kernels/arm/scale_compute.cc @@ -12,39 +12,28 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include -#include "paddle/fluid/lite/core/kernel.h" -#include "paddle/fluid/lite/core/op_registry.h" -#include "paddle/fluid/lite/core/types.h" +#include "paddle/fluid/lite/kernels/arm/scale_compute.h" +#include "paddle/fluid/lite/arm/math/funcs.h" namespace paddle { namespace lite { namespace kernels { namespace arm { -template -void scale_compute(const T* x, T* out, int size, float scale, float bias, - bool bias_before) { - if (bias_before) bias *= scale; - for (int i = 0; i < size; i++) { - out[i] = x[i] * scale + bias; +void ScaleCompute::Run() { + auto& param = Param(); + const float* x_data = param.x->data(); + float* output_data = param.output->mutable_data(); + DDim x_dims = param.x->dims(); + bool bias_after_scale = param.bias_after_scale; + float scale = param.scale; + float bias = param.bias; + if (!bias_after_scale) { + bias *= scale; } + lite::arm::math::scale(x_data, output_data, x_dims.production(), scale, bias); } -class ScaleCompute : public KernelLite { - public: - using param_t = operators::MulParam; - - void Run() override { - auto& param = Param(); - scale_compute(param.x->data(), param.output->mutable_data(), - param.x->dims().production(), param.scale, param.bias, - param.bias_after_scale); - } - - virtual ~ScaleCompute() = default; -}; - } // namespace arm } // namespace kernels } // namespace lite diff --git a/paddle/fluid/lite/kernels/arm/scale_compute.h b/paddle/fluid/lite/kernels/arm/scale_compute.h new file mode 100644 index 0000000000000000000000000000000000000000..b0ee41c654d209f1a9eec7701f14e7ffb09fae76 --- /dev/null +++ b/paddle/fluid/lite/kernels/arm/scale_compute.h @@ -0,0 +1,34 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "paddle/fluid/lite/core/kernel.h" +#include "paddle/fluid/lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace arm { + +class ScaleCompute : public KernelLite { + public: + void Run() override; + + virtual ~ScaleCompute() = default; +}; + +} // namespace arm +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/paddle/fluid/lite/kernels/arm/scale_compute_test.cc b/paddle/fluid/lite/kernels/arm/scale_compute_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..fee47d7eb7a6c093524bb0af617c60d069add01a --- /dev/null +++ b/paddle/fluid/lite/kernels/arm/scale_compute_test.cc @@ -0,0 +1,106 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/lite/kernels/arm/scale_compute.h" +#include +#include +#include "paddle/fluid/lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace arm { + +template +void scale_compute_ref(const operators::ScaleParam& param) { + const dtype* x_data = param.x->mutable_data(); + dtype* output_data = param.output->mutable_data(); + DDim x_dims = param.x->dims(); + DDim output_dims = param.output->dims(); + ASSERT_EQ(x_dims.data(), output_dims.data()); + bool bias_after_scale = param.bias_after_scale; + float scale = param.scale; + float bias = param.bias; + if (!bias_after_scale) { + bias *= scale; + } + for (int i = 0; i < output_dims.production(); i++) { + output_data[i] = x_data[i] * scale + bias; + } +} + +TEST(scale_arm, init) { + ScaleCompute scale; + ASSERT_EQ(scale.precision(), PRECISION(kFloat)); + ASSERT_EQ(scale.target(), TARGET(kARM)); +} + +TEST(scale_arm, compute) { + ScaleCompute scale; + operators::ScaleParam param; + + lite::Tensor x; + lite::Tensor output; + lite::Tensor output_ref; + + for (auto n : {1, 3, 4, 11}) { + for (auto c : {1, 3, 11, 4}) { + for (auto h : {3, 1, 11, 4}) { + for (auto w : {1, 3, 4, 12}) { + for (auto bias_after_scale : {true, false}) { + for (auto s : {-100.25f, -1.0f, 0.13f, 3840.975f}) { + for (auto b : {-3075.495f, -15.f, 0.11234f, 128.15f}) { + x.Resize(DDim(std::vector({n, c, h, w}))); + output.Resize(DDim(std::vector({n, c, h, w}))); + output_ref.Resize(DDim(std::vector({n, c, h, w}))); + auto* x_data = x.mutable_data(); + auto* output_data = output.mutable_data(); + auto* output_ref_data = output_ref.mutable_data(); + for (int i = 0; i < x.dims().production(); i++) { + x_data[i] = i; + } + param.x = &x; + param.output = &output; + param.bias_after_scale = bias_after_scale; + param.scale = s; + param.bias = b; + scale.SetParam(param); + scale.Run(); + param.output = &output_ref; + scale_compute_ref(param); + for (int i = 0; i < output.dims().production(); i++) { + EXPECT_NEAR(output_data[i], output_ref_data[i], 1e-5); + } + } + } + } + } + } + } + } +} + +TEST(scale, retrive_op) { + auto scale = + KernelRegistry::Global().Create("scale"); + ASSERT_FALSE(scale.empty()); + ASSERT_TRUE(scale.front()); +} + +} // namespace arm +} // namespace kernels +} // namespace lite +} // namespace paddle + +USE_LITE_KERNEL(scale, kARM, kFloat, kNCHW, def); diff --git a/paddle/fluid/lite/kernels/arm/softmax_compute.cc b/paddle/fluid/lite/kernels/arm/softmax_compute.cc index ceb061c901ff6c13c3e2160cd6a648ff36f4429e..099385395e2e79e46394273c894be9cc25d5feb1 100644 --- a/paddle/fluid/lite/kernels/arm/softmax_compute.cc +++ b/paddle/fluid/lite/kernels/arm/softmax_compute.cc @@ -24,15 +24,15 @@ void SoftmaxCompute::Run() { auto& param = Param(); const float* din = param.x->data(); float* dout = param.output->mutable_data(); - auto dim_x = param.x->dims(); - auto rank_x = dim_x.size(); + auto x_dims = param.x->dims(); + auto x_rank = x_dims.size(); int axis = param.axis; if (axis < 0) { - axis += rank_x; + axis += x_rank; } - int outer_num = dim_x.Slice(0, axis).production(); - int inner_num = dim_x.Slice(axis + 1, rank_x).production(); - int axis_size = dim_x[axis]; + int outer_num = x_dims.Slice(0, axis).production(); + int inner_num = x_dims.Slice(axis + 1, x_rank).production(); + int axis_size = x_dims[axis]; if (inner_num == 1) { if (axis_size >= 4) { lite::arm::math::softmax_inner1_large_axis(din, dout, outer_num, @@ -64,10 +64,6 @@ void SoftmaxCompute::Run() { } } -TargetType SoftmaxCompute::target() const { return TARGET(kARM); } - -PrecisionType SoftmaxCompute::precision() const { return PRECISION(kFloat); } - } // namespace arm } // namespace kernels } // namespace lite diff --git a/paddle/fluid/lite/kernels/arm/softmax_compute.h b/paddle/fluid/lite/kernels/arm/softmax_compute.h index 2daec0f9ee4167772fa7eeb5c0059a810f5db9ca..4d538473ebd89e61384e02164fb10a4ae1997df1 100644 --- a/paddle/fluid/lite/kernels/arm/softmax_compute.h +++ b/paddle/fluid/lite/kernels/arm/softmax_compute.h @@ -26,9 +26,6 @@ class SoftmaxCompute : public KernelLite { public: void Run() override; - TargetType target() const override; - PrecisionType precision() const override; - virtual ~SoftmaxCompute() = default; }; diff --git a/paddle/fluid/lite/kernels/arm/softmax_compute_test.cc b/paddle/fluid/lite/kernels/arm/softmax_compute_test.cc index d24868f2c5679e4f9b4bf0b5ad1bfbf62f3cbad5..80a64f4eaf74288d0fff6431ad1707afcf1b9eb2 100644 --- a/paddle/fluid/lite/kernels/arm/softmax_compute_test.cc +++ b/paddle/fluid/lite/kernels/arm/softmax_compute_test.cc @@ -14,6 +14,7 @@ #include "paddle/fluid/lite/kernels/arm/softmax_compute.h" #include +#include #include #include "paddle/fluid/lite/core/op_registry.h" @@ -23,19 +24,19 @@ namespace kernels { namespace arm { template -void softmat_compute_ref(const operators::SoftmaxParam& param) { +void softmax_compute_ref(const operators::SoftmaxParam& param) { const dtype* x_data = param.x->mutable_data(); dtype* output_data = param.output->mutable_data(); - DDim dim = param.x->dims(); - ASSERT_EQ(dim.data(), param.output->dims().data()); - auto rank = dim.size(); + DDim x_dims = param.x->dims(); + ASSERT_EQ(x_dims.data(), param.output->dims().data()); + auto x_rank = x_dims.size(); int axis = param.axis; if (axis < 0) { - axis += rank; + axis += x_rank; } - int axis_size = dim[axis]; - int outer_num = dim.Slice(0, axis).production(); - int inner_num = dim.Slice(axis + 1, rank).production(); + int axis_size = x_dims[axis]; + int outer_num = x_dims.Slice(0, axis).production(); + int inner_num = x_dims.Slice(axis + 1, x_rank).production(); int compute_size = outer_num * inner_num; for (int i = 0; i < compute_size; i++) { int idx_inner = i % inner_num; @@ -100,7 +101,7 @@ TEST(softmax_arm, compute) { softmax.SetParam(param); softmax.Run(); param.output = &output_ref; - softmat_compute_ref(param); + softmax_compute_ref(param); for (int i = 0; i < output.dims().production(); i++) { EXPECT_NEAR(output_data[i], output_ref_data[i], 1e-5); } diff --git a/paddle/fluid/lite/kernels/host/CMakeLists.txt b/paddle/fluid/lite/kernels/host/CMakeLists.txt index 5642d4d9d07e5248fe329b3db95015a4d8efd74d..03c0023cb4132058499e105deb612d22aca34007 100644 --- a/paddle/fluid/lite/kernels/host/CMakeLists.txt +++ b/paddle/fluid/lite/kernels/host/CMakeLists.txt @@ -2,10 +2,14 @@ message(STATUS "compile with lite host kernels") cc_library(feed_compute_host SRCS feed_compute.cc DEPS ${lite_kernel_deps}) cc_library(fetch_compute_host SRCS fetch_compute.cc DEPS ${lite_kernel_deps}) +cc_library(reshape_compute_host SRCS reshape_compute.cc DEPS ${lite_kernel_deps} reshape_op_lite) + +lite_cc_test(test_reshape_compute_host SRCS reshape_compute_test.cc DEPS reshape_compute_host) set(host_kernels feed_compute_host fetch_compute_host + reshape_compute_host ) set(host_kernels "${host_kernels}" CACHE INTERNAL "host kernels") diff --git a/paddle/fluid/lite/kernels/host/reshape_compute.cc b/paddle/fluid/lite/kernels/host/reshape_compute.cc new file mode 100644 index 0000000000000000000000000000000000000000..c797ddf45b4faeb704ed7f1ab6b04a302f78ff5f --- /dev/null +++ b/paddle/fluid/lite/kernels/host/reshape_compute.cc @@ -0,0 +1,80 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/lite/kernels/host/reshape_compute.h" +#include +#include "paddle/fluid/lite/operators/reshape_op.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace host { + +void ReshapeCompute::Run() { + auto& param = Param(); + auto x = param.x; + auto actual_shape = param.actual_shape; + auto output = param.output; + bool inplace = param.inplace; + auto x_dims = x->dims(); + auto output_dims = output->dims(); + if (actual_shape) { + auto actual_shape_dims = actual_shape->dims(); + auto* actual_shape_data = actual_shape->data(); +#ifdef LITE_WITH_CUDA + lite::Tensor cpu_actual_shape; + if (actual_shape->target() == TARGET(kCUDA)) { + cpu_actual_shape.CopyDataFrom(*actual_shape); + actual_shape_data = cpu_actual_shape.data(); + } +#endif + auto shape = std::vector( + actual_shape_data, actual_shape_data + actual_shape_dims.production()); + output_dims = lite::operators::ValidateShape(shape, x_dims); + output->Resize(output_dims); + } + if (inplace) { + output->ShareDataWith(*x); + } else { + output->CopyDataFrom(*x); + } + output->Resize(output_dims); +} + +} // namespace host +} // namespace kernels +} // namespace lite +} // namespace paddle + +REGISTER_LITE_KERNEL(reshape, kHost, kAny, kAny, + paddle::lite::kernels::host::ReshapeCompute, def) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kAny), + DATALAYOUT(kAny), -1)}) + .BindInput("Shape", {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kAny), + DATALAYOUT(kAny), -1)}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kAny), + DATALAYOUT(kAny), -1)}) + .Finalize(); + +REGISTER_LITE_KERNEL(reshape2, kHost, kAny, kAny, + paddle::lite::kernels::host::ReshapeCompute, def) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kAny), + DATALAYOUT(kAny), -1)}) + .BindInput("Shape", {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kAny), + DATALAYOUT(kAny), -1)}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kAny), + DATALAYOUT(kAny), -1)}) + .BindOutput("XShape", {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kAny), + DATALAYOUT(kAny), -1)}) + .Finalize(); diff --git a/paddle/fluid/lite/kernels/host/reshape_compute.h b/paddle/fluid/lite/kernels/host/reshape_compute.h new file mode 100644 index 0000000000000000000000000000000000000000..423b589d37d01586c897a68c3d7129849bc25baa --- /dev/null +++ b/paddle/fluid/lite/kernels/host/reshape_compute.h @@ -0,0 +1,36 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include "paddle/fluid/lite/core/kernel.h" +#include "paddle/fluid/lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace host { + +class ReshapeCompute + : public KernelLite { + public: + void Run() override; + + virtual ~ReshapeCompute() = default; +}; + +} // namespace host +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/paddle/fluid/lite/kernels/host/reshape_compute_test.cc b/paddle/fluid/lite/kernels/host/reshape_compute_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..07a8101fec631a42a35e5f1b9705a6ac60b12a84 --- /dev/null +++ b/paddle/fluid/lite/kernels/host/reshape_compute_test.cc @@ -0,0 +1,101 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/lite/kernels/host/reshape_compute.h" +#include +#include +#include "paddle/fluid/lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace host { + +TEST(reshape_host, init) { + ReshapeCompute reshape; + ASSERT_EQ(reshape.precision(), PRECISION(kAny)); + ASSERT_EQ(reshape.target(), TARGET(kHost)); +} + +TEST(reshape_host, compute) { + ReshapeCompute reshape; + operators::ReshapeParam param; + + Tensor x; + Tensor actual_shape; + Tensor output; + + x.Resize(DDim(std::vector({1, 2, 4, 6}))); + actual_shape.Resize(DDim(std::vector({2}))); + + auto* x_data = x.mutable_data(); + auto* actual_shape_data = actual_shape.mutable_data(); + for (int i = 0; i < x.dims().production(); i++) { + x_data[i] = i; + } + actual_shape_data[0] = 6; + actual_shape_data[1] = 8; + + param.x = &x; + param.shape = {-1, 0, 3, 2, 1}; + param.output = &output; + param.actual_shape = &actual_shape; + param.inplace = false; + reshape.SetParam(param); + reshape.Run(); + + // check output dims + CHECK_EQ(actual_shape.dims().production(), output.dims().size()); + for (int i = 0; i < output.dims().size(); i++) { + CHECK_EQ(output.dims()[i], actual_shape_data[i]); + } + + // check output data + auto* output_data = output.mutable_data(); + CHECK_NE(output_data, x_data); + for (int i = 0; i < output.dims().production(); i++) { + EXPECT_NEAR(output_data[i], x_data[i], 1e-6); + } + + // check output data if inplace = true; + param.inplace = true; + reshape.SetParam(param); + reshape.Run(); + output_data = output.mutable_data(); + CHECK_EQ(output_data, x_data); +} + +TEST(reshape, retrive_op) { + auto reshape = + KernelRegistry::Global() + .Create("reshape"); + ASSERT_FALSE(reshape.empty()); + ASSERT_TRUE(reshape.front()); +} + +TEST(reshape2, retrive_op) { + auto reshape2 = + KernelRegistry::Global() + .Create("reshape2"); + ASSERT_FALSE(reshape2.empty()); + ASSERT_TRUE(reshape2.front()); +} + +} // namespace host +} // namespace kernels +} // namespace lite +} // namespace paddle + +USE_LITE_KERNEL(reshape, kHost, kAny, kAny, def); +USE_LITE_KERNEL(reshape2, kHost, kAny, kAny, def); diff --git a/paddle/fluid/lite/kernels/host/use_kernels.h b/paddle/fluid/lite/kernels/host/use_kernels.h index 52e087cdfa0b938841b93c0d8aabf16ff68419ef..b3b534283b350952b20e6b0dc8f8ff3dfa50d78f 100644 --- a/paddle/fluid/lite/kernels/host/use_kernels.h +++ b/paddle/fluid/lite/kernels/host/use_kernels.h @@ -17,3 +17,5 @@ USE_LITE_KERNEL(feed, kHost, kAny, kAny, def); USE_LITE_KERNEL(fetch, kHost, kAny, kAny, def); +USE_LITE_KERNEL(reshape, kHost, kAny, kAny, def); +USE_LITE_KERNEL(reshape2, kHost, kAny, kAny, def); diff --git a/paddle/fluid/lite/model_parser/pb/op_desc.cc b/paddle/fluid/lite/model_parser/pb/op_desc.cc index fb269cd067180b9df30aba27f9dd61b61b58279d..27ccc5c686aac0cd40a6f2e1dff14f93351715c5 100644 --- a/paddle/fluid/lite/model_parser/pb/op_desc.cc +++ b/paddle/fluid/lite/model_parser/pb/op_desc.cc @@ -38,6 +38,29 @@ void OpDesc::SetAttr(const std::string &name, it->set_s(v.c_str()); } +template <> +void OpDesc::SetAttr>(const std::string &name, + const std::vector &v) { + auto &xs = *desc_.mutable_attrs(); + auto it = std::find_if( + xs.begin(), xs.end(), + [&](const framework::proto::OpDesc_Attr &x) { return x.name() == name; }); + if (it == xs.end()) { + auto *attr = xs.Add(); + attr->set_name(name); + it = std::find_if(xs.begin(), xs.end(), + [&](const framework::proto::OpDesc_Attr &x) { + return x.name() == name; + }); + } + + it->set_type(framework::proto::INTS); + it->clear_ints(); + for (auto &i : v) { + it->add_ints(i); + } +} + } // namespace pb } // namespace lite } // namespace paddle diff --git a/paddle/fluid/lite/model_parser/pb/op_desc.h b/paddle/fluid/lite/model_parser/pb/op_desc.h index b1fbce54d865c9924fbdc686edd87662be4cec4f..0be809da8db87c0a45d72e4f16a5f6576f431013 100644 --- a/paddle/fluid/lite/model_parser/pb/op_desc.h +++ b/paddle/fluid/lite/model_parser/pb/op_desc.h @@ -33,7 +33,8 @@ namespace paddle { namespace lite { namespace pb { -using Attribute = variant>; +using Attribute = + variant, std::vector>; using VariableNameMap = std::map>; /* @@ -152,7 +153,6 @@ class OpDesc { Attribute res; CHECK(it != xs.end()); - switch (it->type()) { case framework::proto::INT: res.set(it->i()); @@ -166,6 +166,13 @@ class OpDesc { case framework::proto::BOOLEAN: res.set(it->b()); break; + case framework::proto::INTS: { + std::vector values; + const auto &ys = it->ints(); + std::transform(ys.begin(), ys.end(), std::back_inserter(values), + [](const int &x) { return x; }); + res.set>(values); + } break; default: LOG(FATAL) << "unsupported attr type"; @@ -231,6 +238,10 @@ template <> void OpDesc::SetAttr(const std::string &name, const std::string &v); +template <> +void OpDesc::SetAttr>(const std::string &name, + const std::vector &v); + } // namespace pb } // namespace lite } // namespace paddle diff --git a/paddle/fluid/lite/operators/CMakeLists.txt b/paddle/fluid/lite/operators/CMakeLists.txt index d17ff90ecf073b8a34556912e315d06214885caa..6ed8a410c37a57949bcf092a9973d09bdb0110ff 100644 --- a/paddle/fluid/lite/operators/CMakeLists.txt +++ b/paddle/fluid/lite/operators/CMakeLists.txt @@ -5,6 +5,7 @@ cc_library(relu_op_lite SRCS relu_op.cc DEPS ${op_DEPS}) cc_library(mul_op_lite SRCS mul_op.cc DEPS ${op_DEPS}) cc_library(scale_op_lite SRCS scale_op.cc DEPS ${op_DEPS}) cc_library(softmax_op_lite SRCS softmax_op.cc DEPS ${op_DEPS}) +cc_library(reshape_op_lite SRCS reshape_op.cc DEPS ${op_DEPS} ) cc_library(feed_op_lite SRCS feed_op.cc DEPS ${op_DEPS}) cc_library(fetch_op_lite SRCS fetch_op.cc DEPS ${op_DEPS}) cc_library(io_copy_op_lite SRCS io_copy_op.cc DEPS ${op_DEPS}) @@ -22,6 +23,7 @@ set(ops_lite mul_op_lite scale_op_lite softmax_op_lite + reshape_op_lite feed_op_lite fetch_op_lite io_copy_op_lite @@ -35,5 +37,7 @@ set(ops_lite lite_cc_test(test_fc_op_lite SRCS fc_op_test.cc DEPS fc_op_lite memory_lite X86_DEPS fc_compute_x86 - ARM_DEPS fc_compute_arm) + ARM_DEPS fc_compute_arm) +lite_cc_test(test_scale_op_lite SRCS scale_op_test.cc DEPS scale_op_lite memory_lite) lite_cc_test(test_softmax_op_lite SRCS softmax_op_test.cc DEPS softmax_op_lite memory_lite) +lite_cc_test(test_reshape_op_lite SRCS reshape_op_test.cc DEPS reshape_op_lite memory_lite) diff --git a/paddle/fluid/lite/operators/op_params.h b/paddle/fluid/lite/operators/op_params.h index d65a096d2d849892dde3907ac5e2e9e409aa633b..166a5ad868e232755ede45c939693abc474ab86e 100644 --- a/paddle/fluid/lite/operators/op_params.h +++ b/paddle/fluid/lite/operators/op_params.h @@ -102,6 +102,17 @@ struct SoftmaxParam { int axis{-1}; }; +// For Reshape and Reshape2 Op +struct ReshapeParam { + const lite::Tensor* x{}; + const lite::Tensor* actual_shape{nullptr}; + lite::Tensor* output{}; + lite::Tensor* xshape{}; + + std::vector shape{}; + bool inplace{false}; +}; + // For Convolution op struct ConvParam { lite::Tensor* x{}; diff --git a/paddle/fluid/lite/operators/reshape_op.cc b/paddle/fluid/lite/operators/reshape_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..bf43f52340cd75553a234cec6f96bd349a6dfdbc --- /dev/null +++ b/paddle/fluid/lite/operators/reshape_op.cc @@ -0,0 +1,145 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/lite/operators/reshape_op.h" +#include "paddle/fluid/lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace operators { + +bool ReshapeOp::CheckShape() const { + CHECK_OR_FALSE(param_.x); + CHECK_OR_FALSE(param_.output); + CHECK_OR_FALSE(!param_.shape.empty()); + return true; +} + +bool ReshapeOp::InferShape() const { + auto x_dims = param_.x->dims(); + auto output_dims = ValidateShape(param_.shape, x_dims); + param_.output->Resize(output_dims); + return true; +} + +bool ReshapeOp::AttachImpl(const OpDesc &opdesc, lite::Scope *scope) { + auto x_var = scope->FindVar(opdesc.Input("X").front()); + auto output_var = scope->FindVar(opdesc.Output("Out").front()); + CHECK(x_var); + CHECK(output_var); + param_.x = const_cast(&(x_var->Get())); + param_.output = output_var->GetMutable(); + std::vector input_arg_names = opdesc.InputArgumentNames(); + if (std::find(input_arg_names.begin(), input_arg_names.end(), "Shape") != + input_arg_names.end()) { + auto actual_shape_var = scope->FindVar(opdesc.Input("Shape").front()); + if (actual_shape_var != nullptr) { + param_.actual_shape = + const_cast(&(actual_shape_var->Get())); + } + } + param_.shape = GetAttr>(opdesc.GetAttr("shape")); + if (opdesc.HasAttr("inplace")) { + param_.inplace = GetAttr(opdesc.GetAttr("inplace")); + } + CHECK(param_.x) << "Input(X) of ReshapeOp should not be null."; + CHECK(param_.output) << "Output(Out) of ReshapeOp should not be null."; + CHECK(!param_.shape.empty()) + << "The shape information must be set by Attr(shape)."; + return true; +} + +bool Reshape2Op::CheckShape() const { + ReshapeOp::CheckShape(); + CHECK_OR_FALSE(param_.xshape); + return true; +} + +bool Reshape2Op::InferShape() const { + ReshapeOp::InferShape(); + auto x_dims = param_.x->dims(); + std::vector xshape_dims(x_dims.size() + 1, 0); + for (int i = 0; i < x_dims.size(); i++) { + xshape_dims[i + 1] = x_dims[i]; + } + param_.xshape->Resize(DDim(xshape_dims)); + return true; +} + +bool Reshape2Op::AttachImpl(const OpDesc &opdesc, lite::Scope *scope) { + ReshapeOp::AttachImpl(opdesc, scope); + auto xshape_var = scope->FindVar(opdesc.Output("XShape").front()); + CHECK(xshape_var); + param_.xshape = xshape_var->GetMutable(); + CHECK(param_.xshape) << "Output(XShape) of ReshapeOp should not be null."; + return true; +} + +DDim ValidateShape(const std::vector &shape, const DDim &input_dims) { + const DDim::value_type input_size = input_dims.production(); + auto input_shape = input_dims.Vectorize(); + bool all_positive = std::all_of(input_shape.cbegin(), input_shape.cend(), + [](DDim::value_type i) { return i > 0; }); + // only one dimension can be set to -1, whose size will be automatically + // infered. + const int unk_dim_val = -1; + const int copy_dim_val = 0; + + std::vector output_shape(shape.size(), 0); + DDim::value_type capacity = 1; + int unk_dim_idx = -1; + for (size_t i = 0; i < shape.size(); ++i) { + if (shape[i] == unk_dim_val) { + CHECK_EQ(unk_dim_idx, -1) + << "Only one input dimension of Attr(shape) can be unknown."; + unk_dim_idx = i; + } else if (shape[i] == copy_dim_val) { + CHECK_LT(static_cast(i), input_shape.size()) + << "The index of dimension to copy from input shape must be less " + "than the size of input shape."; + } else { + CHECK_GT(shape[i], 0) << "Each input dimension of Attr(shape) must not " + "be negtive except one unknown dimension."; + } + + capacity *= + (shape[i] ? static_cast(shape[i]) : input_shape[i]); + output_shape[i] = + (shape[i] ? static_cast(shape[i]) : input_shape[i]); + } + + if (unk_dim_idx != -1) { + if (all_positive) { + // input_size < 0 and is un-determinate in compile time, skip the check, + // for example, input_dims = [-1, 8, 1, 1], shape = [-1, 3, 8], + // capacity = -24, input_size = -8, output_shape[0] = 0 + // the following check will fail. + output_shape[unk_dim_idx] = -input_size / capacity; + CHECK_EQ(output_shape[unk_dim_idx] * capacity, -input_size) + << "Invalid shape is given."; + } else { + output_shape[unk_dim_idx] = -1; + } + } else { + CHECK_EQ(capacity, input_size) << "Invalid shape is given."; + } + return DDim(output_shape); +} + +} // namespace operators +} // namespace lite +} // namespace paddle + +REGISTER_LITE_OP(reshape, paddle::lite::operators::ReshapeOp); +REGISTER_LITE_OP(reshape2, paddle::lite::operators::Reshape2Op); diff --git a/paddle/fluid/lite/operators/reshape_op.h b/paddle/fluid/lite/operators/reshape_op.h new file mode 100644 index 0000000000000000000000000000000000000000..d96da8d5d01d9b86ae41d06dc0bdb3ea53a4b87d --- /dev/null +++ b/paddle/fluid/lite/operators/reshape_op.h @@ -0,0 +1,63 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include +#include "paddle/fluid/lite/core/op_lite.h" +#include "paddle/fluid/lite/core/scope.h" +#include "paddle/fluid/lite/utils/all.h" + +namespace paddle { +namespace lite { +namespace operators { + +class ReshapeOp : public OpLite { + public: + ReshapeOp() {} + explicit ReshapeOp(const std::string &op_type) : OpLite(op_type) {} + + bool CheckShape() const override; + + bool InferShape() const override; + + bool AttachImpl(const OpDesc &opdesc, lite::Scope *scope) override; + + void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } + std::string DebugString() const override { return "reshape"; } + + protected: + mutable ReshapeParam param_; +}; + +class Reshape2Op : public ReshapeOp { + public: + Reshape2Op() : ReshapeOp() {} + explicit Reshape2Op(const std::string &op_type) : ReshapeOp(op_type) {} + + bool CheckShape() const override; + + bool InferShape() const override; + + bool AttachImpl(const OpDesc &opdesc, lite::Scope *scope) override; + + void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } + std::string DebugString() const override { return "reshape2"; } +}; + +DDim ValidateShape(const std::vector &shape, const DDim &input_dims); + +} // namespace operators +} // namespace lite +} // namespace paddle diff --git a/paddle/fluid/lite/operators/reshape_op_test.cc b/paddle/fluid/lite/operators/reshape_op_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..41f6999c1d88782a5f9e782537dc286e8c188661 --- /dev/null +++ b/paddle/fluid/lite/operators/reshape_op_test.cc @@ -0,0 +1,145 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/lite/operators/reshape_op.h" +#include +#include +#include +#include "paddle/fluid/lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace operators { + +TEST(reshape_op_lite, test) { + // prepare variables + Scope scope; + auto* x = scope.Var("x")->GetMutable(); + auto* actual_shape = scope.Var("actual_shape")->GetMutable(); + auto* output = scope.Var("output")->GetMutable(); + std::map, std::vector> shapes = { + {{-1, 0, 3, 2, 1}, {2, 4, 3, 2, 1}}, + {{0, -1, 3, 2, 1}, {2, 4, 3, 2, 1}}, + {{-1, 48}, {1, 48}}, + {{48, -1}, {48, 1}}, + {{0, 24}, {2, 24}}, + {{12, 0}, {12, 4}}, + }; + x->Resize(DDim(std::vector({2, 4, 6}))); + actual_shape->Resize(DDim(std::vector({2}))); + + auto* actual_shape_data = actual_shape->mutable_data(); + actual_shape_data[0] = 6; + actual_shape_data[1] = 8; + + for (auto& shape : shapes) { + for (auto& has_actual_shape : {true, false}) { + for (auto& inplace : {true, false}) { + // prepare op desc + lite::OpDesc desc; + desc.SetType("reshape"); + desc.SetInput("X", {"x"}); + if (has_actual_shape) { + desc.SetInput("Shape", {"actual_shape"}); + } + desc.SetOutput("Out", {"output"}); + desc.SetAttr("shape", shape.first); + desc.SetAttr("inplace", inplace); + + ReshapeOp reshape("reshape"); + + reshape.SetValidPlaces( + {Place{TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny)}}); + reshape.Attach(desc, &scope); + reshape.CheckShape(); + reshape.InferShape(); + + // check output dims + auto output_dims = output->dims(); + CHECK_EQ(output_dims.size(), shape.second.size()); + for (int i = 0; i < output_dims.size(); i++) { + CHECK_EQ(output_dims[i], shape.second[i]); + } + } + } + } +} + +TEST(reshape2_op_lite, test) { + // prepare variables + Scope scope; + auto* x = scope.Var("x")->GetMutable(); + auto* actual_shape = scope.Var("actual_shape")->GetMutable(); + auto* output = scope.Var("output")->GetMutable(); + auto* xshape = scope.Var("xshape")->GetMutable(); + std::map, std::vector> shapes = { + {{-1, 0, 3, 2, 1}, {2, 4, 3, 2, 1}}, + {{0, -1, 3, 2, 1}, {2, 4, 3, 2, 1}}, + {{-1, 48}, {1, 48}}, + {{48, -1}, {48, 1}}, + {{0, 24}, {2, 24}}, + {{12, 0}, {12, 4}}, + }; + x->Resize(DDim(std::vector({2, 4, 6}))); + actual_shape->Resize(DDim(std::vector({2}))); + + auto* actual_shape_data = actual_shape->mutable_data(); + actual_shape_data[0] = 6; + actual_shape_data[1] = 8; + + for (auto& shape : shapes) { + for (auto& has_actual_shape : {true, false}) { + for (auto& inplace : {true, false}) { + // prepare op desc + lite::OpDesc desc; + desc.SetType("reshape"); + desc.SetInput("X", {"x"}); + if (has_actual_shape) { + desc.SetInput("Shape", {"actual_shape"}); + } + desc.SetOutput("Out", {"output"}); + desc.SetOutput("XShape", {"xshape"}); + desc.SetAttr("shape", shape.first); + desc.SetAttr("inplace", inplace); + + Reshape2Op reshape2("reshape2"); + + reshape2.SetValidPlaces( + {Place{TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny)}}); + reshape2.Attach(desc, &scope); + reshape2.CheckShape(); + reshape2.InferShape(); + + // check output dims + auto output_dims = output->dims(); + CHECK_EQ(output_dims.size(), shape.second.size()); + for (int i = 0; i < output_dims.size(); i++) { + CHECK_EQ(output_dims[i], shape.second[i]); + } + // check xshape dims + auto x_dims = x->dims(); + auto xshape_dims = xshape->dims(); + CHECK_EQ(xshape_dims.size(), x_dims.size() + 1); + CHECK_EQ(xshape_dims[0], 0); + for (int i = 0; i < x_dims.size(); i++) { + CHECK_EQ(xshape_dims[i + 1], x_dims[i]); + } + } + } + } +} + +} // namespace operators +} // namespace lite +} // namespace paddle diff --git a/paddle/fluid/lite/operators/scale_op.cc b/paddle/fluid/lite/operators/scale_op.cc index 87cbe2a2e035bd2f943c18c4f19bd40e5e9df0dd..0a6dec991a0ec1669792accd86960611a4212c24 100644 --- a/paddle/fluid/lite/operators/scale_op.cc +++ b/paddle/fluid/lite/operators/scale_op.cc @@ -12,58 +12,35 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include -#include -#include "paddle/fluid/lite/core/kernel.h" -#include "paddle/fluid/lite/core/op_lite.h" +#include "paddle/fluid/lite/operators/scale_op.h" #include "paddle/fluid/lite/core/op_registry.h" -#include "paddle/fluid/lite/core/scope.h" -#include "paddle/fluid/lite/operators/op_params.h" -#include "paddle/fluid/lite/utils/all.h" - namespace paddle { namespace lite { namespace operators { -class ScaleOp : public OpLite { - public: - ScaleOp() {} - - explicit ScaleOp(const std::string &type) : OpLite(type) {} - - bool CheckShape() const override { - CHECK_OR_FALSE(param_.x); - CHECK_OR_FALSE(param_.output); - return true; - } - - bool InferShape() const override { - param_.output->Resize(param_.x->dims()); - return true; - } - - void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } - - // TODO(Superjomn) replace framework::OpDesc with a lite one. - bool AttachImpl(const OpDesc &op_desc, lite::Scope *scope) override { - auto x = op_desc.Input("X").front(); - auto out = op_desc.Output("Out").front(); - - param_.x = scope->FindVar(x)->GetMutable(); - CHECK(scope->FindVar(out)); - param_.output = scope->FindVar(out)->GetMutable(); - param_.scale = GetAttr(op_desc.GetAttr("scale")); - param_.bias = GetAttr(op_desc.GetAttr("bias")); - param_.bias_after_scale = - GetAttr(op_desc.GetAttr("bias_after_scale")); - return true; - } - - std::string DebugString() const override { return op_type_; } - - private: - mutable ScaleParam param_; -}; +bool ScaleOp::CheckShape() const { + CHECK_OR_FALSE(param_.x); + CHECK_OR_FALSE(param_.output); + return true; +} + +bool ScaleOp::InferShape() const { + param_.output->Resize(param_.x->dims()); + return true; +} + +bool ScaleOp::AttachImpl(const OpDesc &op_desc, lite::Scope *scope) { + auto x = op_desc.Input("X").front(); + auto output = op_desc.Output("Out").front(); + param_.x = scope->FindVar(x)->GetMutable(); + param_.output = scope->FindVar(output)->GetMutable(); + param_.scale = GetAttr(op_desc.GetAttr("scale")); + param_.bias = GetAttr(op_desc.GetAttr("bias")); + param_.bias_after_scale = GetAttr(op_desc.GetAttr("bias_after_scale")); + CHECK(param_.x); + CHECK(param_.output); + return true; +} } // namespace operators } // namespace lite diff --git a/paddle/fluid/lite/operators/scale_op.h b/paddle/fluid/lite/operators/scale_op.h new file mode 100644 index 0000000000000000000000000000000000000000..8866e6a29b7565dd80ec05be795d6ff8ac7120b1 --- /dev/null +++ b/paddle/fluid/lite/operators/scale_op.h @@ -0,0 +1,46 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include +#include "paddle/fluid/lite/core/op_lite.h" +#include "paddle/fluid/lite/core/scope.h" +#include "paddle/fluid/lite/utils/all.h" + +namespace paddle { +namespace lite { +namespace operators { + +class ScaleOp : public OpLite { + public: + ScaleOp() {} + explicit ScaleOp(const std::string &op_type) : OpLite(op_type) {} + + bool CheckShape() const override; + + bool InferShape() const override; + + bool AttachImpl(const OpDesc &opdesc, lite::Scope *scope) override; + + void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } + std::string DebugString() const override { return "scale"; } + + private: + mutable ScaleParam param_; +}; + +} // namespace operators +} // namespace lite +} // namespace paddle diff --git a/paddle/fluid/lite/operators/scale_op_test.cc b/paddle/fluid/lite/operators/scale_op_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..ad61a27a1c3d8a212597a0e565a06144796782c9 --- /dev/null +++ b/paddle/fluid/lite/operators/scale_op_test.cc @@ -0,0 +1,58 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/lite/operators/scale_op.h" +#include +#include "paddle/fluid/lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace operators { + +TEST(scale_op_lite, test) { + // prepare variables + Scope scope; + auto* x = scope.Var("x")->GetMutable(); + auto* output = scope.Var("output")->GetMutable(); + x->Resize(DDim(std::vector({10, 20}))); + output->Resize(DDim(std::vector{1, 1})); + + // prepare op desc + lite::OpDesc desc; + desc.SetType("scale"); + desc.SetInput("X", {"x"}); + desc.SetOutput("Out", {"output"}); + desc.SetAttr("bias_after_scale", false); + desc.SetAttr("scale", 0.5f); + desc.SetAttr("bias", 0.125f); + + ScaleOp scale("scale"); + + scale.SetValidPlaces({Place{TARGET(kHost), PRECISION(kFloat)}}); + scale.Attach(desc, &scope); + scale.CheckShape(); + scale.InferShape(); + + // check output dims + auto x_dims = x->dims(); + auto output_dims = output->dims(); + CHECK_EQ(output_dims.size(), x_dims.size()); + for (int i = 0; i < output_dims.size(); i++) { + CHECK_EQ(output_dims[i], x_dims[i]); + } +} + +} // namespace operators +} // namespace lite +} // namespace paddle diff --git a/paddle/fluid/lite/operators/softmax_op.cc b/paddle/fluid/lite/operators/softmax_op.cc index 3e41a5ffffe9ab1d045bdc5da3a819cb645408d5..518d6a3d36a9d919b4fa326db07d9a53f4dc454e 100644 --- a/paddle/fluid/lite/operators/softmax_op.cc +++ b/paddle/fluid/lite/operators/softmax_op.cc @@ -22,9 +22,9 @@ namespace operators { bool SoftmaxOp::CheckShape() const { CHECK_OR_FALSE(param_.x); CHECK_OR_FALSE(param_.output); - auto dim_x = param_.x->dims(); - auto rank_x = dim_x.size(); - CHECK_OR_FALSE(param_.axis >= -rank_x && param_.axis < rank_x); + auto x_dims = param_.x->dims(); + auto x_rank = x_dims.size(); + CHECK_OR_FALSE(param_.axis >= -x_rank && param_.axis < x_rank); return true; }