diff --git a/src/operators/kernel/arm/reshape_kernel.cpp b/src/operators/kernel/arm/reshape_kernel.cpp index 15b3c6979a44fd42d11c71de488b3eadeb097226..7f7e80ece9f30631c109d0d27f4025e2617cec95 100644 --- a/src/operators/kernel/arm/reshape_kernel.cpp +++ b/src/operators/kernel/arm/reshape_kernel.cpp @@ -17,35 +17,35 @@ limitations under the License. */ #include "operators/kernel/reshape_kernel.h" namespace paddle_mobile { - namespace operators { - - template <> - void ReshapeKernel::Compute(const ReshapeParam ¶m) const { - const auto *input_x = param.InputX(); - const auto &input_x_dims = input_x->dims(); - auto *out = param.Out(); - framework::DDim out_dims = out->dims(); - const auto *input_shape = param.InputShape(); - - if (input_shape) { - auto *shape_data = input_shape->data(); - framework::Tensor cpu_shape_tensor; - auto shape = - std::vector(shape_data, shape_data + input_shape->numel()); - out_dims = ValidateShape(shape, input_x->dims()); - } - - bool inplace = param.Inplace(); - out->Resize(out_dims); - if (!inplace) { - out->mutable_data(); - framework::TensorCopy(*input_x,out); - out->Resize(out_dims); - } else { - out->ShareDataWith(*input_x); - out->Resize(out_dims); - } - } - - } // namespace operators +namespace operators { + +template <> +void ReshapeKernel::Compute(const ReshapeParam ¶m) const { + const auto *input_x = param.InputX(); + const auto &input_x_dims = input_x->dims(); + auto *out = param.Out(); + framework::DDim out_dims = out->dims(); + const auto *input_shape = param.InputShape(); + + if (input_shape) { + auto *shape_data = input_shape->data(); + framework::Tensor cpu_shape_tensor; + auto shape = + std::vector(shape_data, shape_data + input_shape->numel()); + out_dims = ValidateShape(shape, input_x->dims()); + } + + bool inplace = param.Inplace(); + out->Resize(out_dims); + if (!inplace) { + out->mutable_data(); + framework::TensorCopy(*input_x, out); + out->Resize(out_dims); + } else { + out->ShareDataWith(*input_x); + out->Resize(out_dims); + } +} + +} // namespace operators } // namespace paddle_mobile diff --git a/src/operators/kernel/reshape_kernel.h b/src/operators/kernel/reshape_kernel.h index 8d0f7dadd14bcd79a18b866ee2019651f4eb9798..7d5dcdf71de232b1c72180231731fcf76483b9e4 100644 --- a/src/operators/kernel/reshape_kernel.h +++ b/src/operators/kernel/reshape_kernel.h @@ -20,58 +20,55 @@ limitations under the License. */ #pragma once; namespace paddle_mobile { - namespace operators { +namespace operators { +inline framework::DDim ValidateShape(const std::vector shape, + const framework::DDim& in_dims) { + const int64_t in_size = framework::product(in_dims); + // only one dimension can be set to -1, whose size will be automatically + // infered. + const int64_t unk_dim_val = -1; + const int64_t copy_dim_val = 0; - inline framework::DDim ValidateShape(const std::vector shape, - const framework::DDim &in_dims) { - const int64_t in_size = framework::product(in_dims); - // only one dimension can be set to -1, whose size will be automatically - // infered. - const int64_t unk_dim_val = -1; - const int64_t copy_dim_val = 0; + std::vector output_shape(shape.size(), 0); + int64_t capacity = 1; + int unk_dim_idx = -1; + for (size_t i = 0; i < shape.size(); ++i) { + if (shape[i] == unk_dim_val) { + PADDLE_MOBILE_ENFORCE( + unk_dim_idx == -1, + "Only one input dimension of Attr(shape) can be unknown."); + unk_dim_idx = i; + } else if (shape[i] == copy_dim_val) { + PADDLE_MOBILE_ENFORCE( + static_cast(i) < in_dims.size(), + "The index of dimension to copy from input shape must be less " + "than the size of input shape."); + } else { + PADDLE_MOBILE_ENFORCE( + shape[i] > 0, + "Each input dimension of Attr(shape) must not be negtive except " + "one unknown dimension."); + } - std::vector output_shape(shape.size(), 0); - int64_t capacity = 1; - int unk_dim_idx = -1; - for (size_t i = 0; i < shape.size(); ++i) { - if (shape[i] == unk_dim_val) { - PADDLE_MOBILE_ENFORCE( - unk_dim_idx == -1, - "Only one input dimension of Attr(shape) can be unknown."); - unk_dim_idx = i; - } else if (shape[i] == copy_dim_val) { - PADDLE_MOBILE_ENFORCE( - static_cast(i) < in_dims.size(), - "The index of dimension to copy from input shape must be less " - "than the size of input shape."); - } else { - PADDLE_MOBILE_ENFORCE( - shape[i] > 0, - "Each input dimension of Attr(shape) must not be negtive except " - "one unknown dimension."); - } + capacity *= (shape[i] ? shape[i] : in_dims[i]); + output_shape[i] = (shape[i] ? static_cast(shape[i]) : in_dims[i]); + } - capacity *= (shape[i] ? shape[i] : in_dims[i]); - output_shape[i] = - (shape[i] ? static_cast(shape[i]) : in_dims[i]); - } + if (unk_dim_idx != -1) { + output_shape[unk_dim_idx] = -in_size / capacity; + PADDLE_MOBILE_ENFORCE(output_shape[unk_dim_idx] * capacity == -in_size, + "Invalid shape is given."); + } else { + PADDLE_MOBILE_ENFORCE(capacity == in_size, "Invalid shape is given."); + } + return framework::make_ddim(output_shape); +} - if (unk_dim_idx != -1) { - output_shape[unk_dim_idx] = -in_size / capacity; - PADDLE_MOBILE_ENFORCE(output_shape[unk_dim_idx] * capacity == -in_size, - "Invalid shape is given."); - } else { - PADDLE_MOBILE_ENFORCE(capacity==in_size, "Invalid shape is given."); - } - return framework::make_ddim(output_shape); - } - - template - class ReshapeKernel - : public framework::OpKernelBase { - public: - void Compute(const ReshapeParam& param) const; - }; - } // namespace operators +template +class ReshapeKernel : public framework::OpKernelBase { + public: + void Compute(const ReshapeParam& param) const; +}; +} // namespace operators } // namespace paddle_mobile diff --git a/src/operators/reshape_op.cpp b/src/operators/reshape_op.cpp index 7cffa6c7c0fc27a124266ca9020c547abfe27bd1..6562b7a5eb491a7e69e9bd9481251b8aaf9f3f4b 100644 --- a/src/operators/reshape_op.cpp +++ b/src/operators/reshape_op.cpp @@ -15,18 +15,18 @@ limitations under the License. */ #include "operators/reshape_op.h" #include namespace paddle_mobile { - namespace operators { +namespace operators { - template - void ReshapeOp::InferShape() const { - /// todo: add InputShape() detection. - auto &shape = param_.Shape(); - auto input_x_dims = param_.InputX()->dims(); - auto out_dims = ValidateShape(shape, input_x_dims); - param_.Out()->Resize(out_dims); - } - template class ReshapeOp; - } // namespace operators +template +void ReshapeOp::InferShape() const { + /// todo: add InputShape() detection. + auto &shape = param_.Shape(); + auto input_x_dims = param_.InputX()->dims(); + auto out_dims = ValidateShape(shape, input_x_dims); + param_.Out()->Resize(out_dims); +} +template class ReshapeOp; +} // namespace operators } // namespace paddle_mobile namespace ops = paddle_mobile::operators; diff --git a/src/operators/reshape_op.h b/src/operators/reshape_op.h index ae34fb66951a2ae4e62ce81df7803dcebd7cf75c..62bcb3a67980b75f46487aba4dbf5c89d2b65c7d 100644 --- a/src/operators/reshape_op.h +++ b/src/operators/reshape_op.h @@ -21,32 +21,31 @@ limitations under the License. */ #include "operators/op_param.h" namespace paddle_mobile { - namespace operators { - - using paddle_mobile::framework::Tensor; - - template - class ReshapeOp : public framework::OperatorWithKernel { - public: - ReshapeOp(const std::string &type, const VariableNameMap &inputs, - const VariableNameMap &outputs, - const framework::AttributeMap attrs, - std::shared_ptr scope) - : framework::OperatorWithKernel(type, inputs, outputs, attrs, - scope), - param_(inputs, outputs, attrs, *scope) {} - - void Run() const { - operators::ReshapeKernel kernel; - kernel.Compute(param_); - } - - using framework::OperatorWithKernel::OperatorWithKernel; - void InferShape() const override; - - protected: - ReshapeParam param_; - }; - - } // namespace operators +namespace operators { + +using paddle_mobile::framework::Tensor; + +template +class ReshapeOp : public framework::OperatorWithKernel { + public: + ReshapeOp(const std::string &type, const VariableNameMap &inputs, + const VariableNameMap &outputs, const framework::AttributeMap attrs, + std::shared_ptr scope) + : framework::OperatorWithKernel(type, inputs, outputs, attrs, + scope), + param_(inputs, outputs, attrs, *scope) {} + + void Run() const { + operators::ReshapeKernel kernel; + kernel.Compute(param_); + } + + using framework::OperatorWithKernel::OperatorWithKernel; + void InferShape() const override; + + protected: + ReshapeParam param_; +}; + +} // namespace operators } // namespace paddle_mobile diff --git a/test/executor_for_test.h b/test/executor_for_test.h index 2d592fdf349b76965a9b662d8b2c4e4679d1ad9b..11a1747ef06228d5e80255e722cadf61a9a6700e 100644 --- a/test/executor_for_test.h +++ b/test/executor_for_test.h @@ -21,9 +21,9 @@ limitations under the License. */ #include "io.h" #include "operators/conv_op.h" #include "operators/pool_op.h" +#include "operators/reshape_op.h" #include "operators/softmax_op.h" #include "operators/transpose_op.h" -#include "operators/reshape_op.h" using paddle_mobile::Executor; using paddle_mobile::framework::BlockDesc; diff --git a/test/operators/test_reshape_op.cpp b/test/operators/test_reshape_op.cpp index 29d184b77408e247e774ac94e95f27bcd85983b1..7ba2faa47dfad443df3ff59e8db4a66f8d1d8bcf 100644 --- a/test/operators/test_reshape_op.cpp +++ b/test/operators/test_reshape_op.cpp @@ -17,32 +17,32 @@ limitations under the License. */ #include "./io.h" int main() { - paddle_mobile::Loader loader; - auto program = loader.Load(std::string("../../test/models/mobilenet+ssd")); - if (program.originProgram == nullptr) { - DLOG << "program read file"; - } - Executor4Test> - executor(program, "reshape"); - paddle_mobile::framework::Tensor input; - SetupTensor(&input, {2, 3, 3, 2}, static_cast(0), - static_cast(1)); - auto input_ptr = input.data(); - auto out_ddim = paddle_mobile::framework::make_ddim({2, 9, 2}); - auto output = - executor.predict(input, "transpose_0.tmp_0", "reshape_0.tmp_0", out_ddim); - auto *output_ptr = output->data(); - - DLOG << "input : "; - for (int j = 0; j < input.numel(); ++j) { - DLOG << " index " << j << " : " << input_ptr[j]; - } - - DLOG << "output : "; - for (int j = 0; j < output->numel(); ++j) { - DLOG << " index " << j << " : " << output_ptr[j]; - } - - return 0; + paddle_mobile::Loader loader; + auto program = loader.Load(std::string("../../test/models/mobilenet+ssd")); + if (program.originProgram == nullptr) { + DLOG << "program read file"; + } + Executor4Test> + executor(program, "reshape"); + paddle_mobile::framework::Tensor input; + SetupTensor(&input, {2, 3, 3, 2}, static_cast(0), + static_cast(1)); + auto input_ptr = input.data(); + auto out_ddim = paddle_mobile::framework::make_ddim({2, 9, 2}); + auto output = + executor.predict(input, "transpose_0.tmp_0", "reshape_0.tmp_0", out_ddim); + auto *output_ptr = output->data(); + + DLOG << "input : "; + for (int j = 0; j < input.numel(); ++j) { + DLOG << " index " << j << " : " << input_ptr[j]; + } + + DLOG << "output : "; + for (int j = 0; j < output->numel(); ++j) { + DLOG << " index " << j << " : " << output_ptr[j]; + } + + return 0; }