diff --git a/src/common/types.cpp b/src/common/types.cpp index 7cdf967055939c2e6c76104f656fae83a3e70708..b90fb70f2a81b365f049632cc7281a69ec58e18d 100644 --- a/src/common/types.cpp +++ b/src/common/types.cpp @@ -40,6 +40,7 @@ const char *G_OP_TYPE_POOL2D = "pool2d"; const char *G_OP_TYPE_PRIOR_BOX = "prior_box"; const char *G_OP_TYPE_RELU = "relu"; const char *G_OP_TYPE_RESHAPE = "reshape"; +const char *G_OP_TYPE_RESHAPE2 = "reshape2"; const char *G_OP_TYPE_SIGMOID = "sigmoid"; const char *G_OP_TYPE_SOFTMAX = "softmax"; const char *G_OP_TYPE_TRANSPOSE = "transpose"; @@ -101,6 +102,7 @@ std::unordered_map< {G_OP_TYPE_POLYGON_BOX_TRANSFORM, {{"Input"}, {"Output"}}}, {G_OP_TYPE_FC, {{"X", "Y", "Z"}, {"Out"}}}, {G_OP_TYPE_RESHAPE, {{"X"}, {"Out"}}}, + {G_OP_TYPE_RESHAPE2, {{"X"}, {"Out", "XShape"}}}, {G_OP_TYPE_DEPTHWISE_CONV, {{"Input"}, {"Output"}}}, {G_OP_TYPE_FILL_CONSTANT, {{}, {"Out"}}}, {G_OP_TYPE_FUSION_CONV_ADD_RELU, {{"Input"}, {"Out"}}}, diff --git a/src/framework/load_ops.h b/src/framework/load_ops.h index d85f1bf548d5a1263ee2ce0e6062101322b5b15d..982f1c0f3525afde8475866c0121343fafc9d5a0 100644 --- a/src/framework/load_ops.h +++ b/src/framework/load_ops.h @@ -109,6 +109,9 @@ LOAD_FUSION_MATCHER(fusion_conv_add_bn_relu); #ifdef RESHAPE_OP LOAD_OP2(reshape, CPU, MALI_GPU); #endif +#ifdef RESHAPE2_OP +LOAD_OP2(reshape2, CPU, MALI_GPU); +#endif #ifdef TRANSPOSE_OP LOAD_OP1(transpose, CPU); #endif @@ -224,5 +227,9 @@ LOAD_FUSION_MATCHER(fusion_conv_bn); #ifdef ELEMENTWISESUB_OP LOAD_OP1(elementwise_sub, CPU) #endif +#ifdef QUANT_OP LOAD_OP1(quantize, CPU); +#endif +#ifdef DEQUANT_OP LOAD_OP1(dequantize, CPU); +#endif diff --git a/src/operators/kernel/arm/quantize_kernel.cpp b/src/operators/kernel/arm/quantize_kernel.cpp index e7552d2602b31f9a5c10e3d81122babae8fcf1a8..11a1f0a53d4886e1a07d258b76b3827671471dca 100644 --- a/src/operators/kernel/arm/quantize_kernel.cpp +++ b/src/operators/kernel/arm/quantize_kernel.cpp @@ -135,11 +135,15 @@ static void quantize_round_to_even(const Tensor *input, const float scale, #if defined(__ARM_NEON__) || defined(__ARM_NEON) size_t loop = size >> 4; size_t remain = size & 0xF; + + #pragma omp parallel for for (size_t i = 0; i < loop; ++i) { - float32x4_t r0 = vld1q_f32(x); - float32x4_t r1 = vld1q_f32(x + 4); - float32x4_t r2 = vld1q_f32(x + 8); - float32x4_t r3 = vld1q_f32(x + 12); + const float *local_x = x + (i << 4); + int8_t *local_y = y + (i << 4); + float32x4_t r0 = vld1q_f32(local_x); + float32x4_t r1 = vld1q_f32(local_x + 4); + float32x4_t r2 = vld1q_f32(local_x + 8); + float32x4_t r3 = vld1q_f32(local_x + 12); r0 = vmulq_n_f32(r0, scale); r1 = vmulq_n_f32(r1, scale); r2 = vmulq_n_f32(r2, scale); @@ -156,12 +160,12 @@ static void quantize_round_to_even(const Tensor *input, const float scale, int16x8_t q6 = vcombine_s16(d2, d3); int8x8_t d5 = vmovn_s16(q5); int8x8_t d6 = vmovn_s16(q6); - vst1_s8(y, d5); - vst1_s8(y + 8, d6); - x += 16; - y += 16; + vst1_s8(local_y, d5); + vst1_s8(local_y + 8, d6); } size = remain; + x += (loop << 4); + y += (loop << 4); #endif for (size_t i = 0; i < size; ++i) { float value = x[i] * scale; @@ -187,11 +191,15 @@ static void quantize_round_to_zero(const Tensor *input, const float scale, #ifdef defined(__ARM_NEON__) || defined(__ARM_NEON) size_t loop = size >> 4; size_t remain = size & 0xF; + + #pragma omp parallel for for (size_t i = 0; i < loop; ++i) { - float32x4_t r0 = vld1q_f32(x); - float32x4_t r1 = vld1q_f32(x + 4); - float32x4_t r2 = vld1q_f32(x + 8); - float32x4_t r3 = vld1q_f32(x + 12); + const float *local_x = x + (i << 4); + int8_t *local_y = y + (i << 4); + float32x4_t r0 = vld1q_f32(local_x); + float32x4_t r1 = vld1q_f32(local_x + 4); + float32x4_t r2 = vld1q_f32(local_x + 8); + float32x4_t r3 = vld1q_f32(local_x + 12); r0 = vmulq_n_f32(r0, scale); r1 = vmulq_n_f32(r1, scale); r2 = vmulq_n_f32(r2, scale); @@ -208,12 +216,12 @@ static void quantize_round_to_zero(const Tensor *input, const float scale, int16x8_t q6 = vcombine_s16(d2, d3); int8x8_t d5 = vmovn_s16(q5); int8x8_t d6 = vmovn_s16(q6); - vst1_s8(y, d5); - vst1_s8(y + 8, d6); - x += 16; - y += 16; + vst1_s8(local_y, d5); + vst1_s8(local_y + 8, d6); } size = remain; + x += (loop << 4); + y += (loop << 4); #endif for (size_t i = 0; i < size; ++i) { y[i] = trunc(x[i] * scale); @@ -228,11 +236,15 @@ static void quantize_round_to_nearest(const Tensor *input, const float scale, #if defined(__ARM_NEON__) || defined(__ARM_NEON) size_t loop = size >> 4; size_t remain = size & 0xF; + + #pragma omp parallel for for (size_t i = 0; i < loop; ++i) { - float32x4_t r0 = vld1q_f32(x); - float32x4_t r1 = vld1q_f32(x + 4); - float32x4_t r2 = vld1q_f32(x + 8); - float32x4_t r3 = vld1q_f32(x + 12); + const float *local_x = x + (i << 4); + int8_t *local_y = y + (i << 4); + float32x4_t r0 = vld1q_f32(local_x); + float32x4_t r1 = vld1q_f32(local_x + 4); + float32x4_t r2 = vld1q_f32(local_x + 8); + float32x4_t r3 = vld1q_f32(local_x + 12); r0 = vmulq_n_f32(r0, scale); r1 = vmulq_n_f32(r1, scale); r2 = vmulq_n_f32(r2, scale); @@ -249,12 +261,12 @@ static void quantize_round_to_nearest(const Tensor *input, const float scale, int16x8_t q6 = vcombine_s16(d2, d3); int8x8_t d5 = vmovn_s16(q5); int8x8_t d6 = vmovn_s16(q6); - vst1_s8(y, d5); - vst1_s8(y + 8, d6); - x += 16; - y += 16; + vst1_s8(local_y, d5); + vst1_s8(local_y + 8, d6); } size = remain; + x += (loop << 4); + y += (loop << 4); #endif for (size_t i = 0; i < size; ++i) { y[i] = round(x[i] * scale); diff --git a/src/operators/kernel/arm/reshape2_kernel.cpp b/src/operators/kernel/arm/reshape2_kernel.cpp new file mode 100644 index 0000000000000000000000000000000000000000..83bbf112abb8b5e290126d6909a0fe77291f8fac --- /dev/null +++ b/src/operators/kernel/arm/reshape2_kernel.cpp @@ -0,0 +1,37 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifdef RESHAPE2_OP + +#include "operators/kernel/reshape2_kernel.h" +#include "operators/kernel/central-arm-func/reshape2_arm_func.h" + +namespace paddle_mobile { +namespace operators { + +template <> +bool Reshape2Kernel::Init(Reshape2Param *param) { + return true; +} + +template <> +void Reshape2Kernel::Compute( + const Reshape2Param ¶m) const { + Reshape2Compute(param); +} + +} // namespace operators +} // namespace paddle_mobile + +#endif diff --git a/src/operators/kernel/central-arm-func/elementwise_add_arm_func.h b/src/operators/kernel/central-arm-func/elementwise_add_arm_func.h index 0c01ef0072444479d2d2e2f7676b842d89e432ec..b6288380a04c71b3d6467f7f6648db046ae9acc9 100644 --- a/src/operators/kernel/central-arm-func/elementwise_add_arm_func.h +++ b/src/operators/kernel/central-arm-func/elementwise_add_arm_func.h @@ -58,6 +58,7 @@ void ElementwiseAddCompute(const ElementwiseAddParam ¶m) { const float *input_data = input_x->data(); float *output_data = Out->mutable_data(); for (int i = 0; i < batch; ++i) { + #pragma omp parallel for for (int j = 0; j < channels; ++j) { size_t offset = (i * channels + j) * elementwise_num; const float *input = input_data + offset; diff --git a/src/operators/kernel/central-arm-func/reshape2_arm_func.h b/src/operators/kernel/central-arm-func/reshape2_arm_func.h new file mode 100644 index 0000000000000000000000000000000000000000..c22cf120313b039944932fb4e6cc52aa59a68fd4 --- /dev/null +++ b/src/operators/kernel/central-arm-func/reshape2_arm_func.h @@ -0,0 +1,59 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifdef RESHAPE2_OP +#pragma once + +#include +#include "operators/kernel/reshape_kernel.h" +#include "operators/op_param.h" + +namespace paddle_mobile { +namespace operators { + +template +void Reshape2Compute(const Reshape2Param ¶m) { + const auto *input_x = param.InputX(); + const auto &input_x_dims = input_x->dims(); + auto *out = param.Out(); + framework::DDim out_dims = out->dims(); + const auto *input_shape = param.InputShape(); + + if (input_shape) { + auto *shape_data = input_shape->data(); + framework::Tensor cpu_shape_tensor; + auto shape = + std::vector(shape_data, shape_data + input_shape->numel()); + out_dims = ValidateShape(shape, input_x->dims()); + } else { + auto &shape = param.Shape(); + out_dims = ValidateShape(shape, input_x_dims); + } + + bool inplace = param.Inplace(); + out->Resize(out_dims); + if (!inplace) { + out->mutable_data(); + framework::TensorCopy(*input_x, out); + out->Resize(out_dims); + } else { + out->ShareDataWith(*input_x); + out->Resize(out_dims); + } +} + +} // namespace operators +} // namespace paddle_mobile + +#endif diff --git a/src/operators/kernel/reshape2_kernel.h b/src/operators/kernel/reshape2_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..8d15a619d314e3f5d3085a34cff503e286b5ee37 --- /dev/null +++ b/src/operators/kernel/reshape2_kernel.h @@ -0,0 +1,36 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifdef RESHAPE2_OP + +#pragma once + +#include +#include "framework/operator.h" +#include "operators/op_param.h" + +namespace paddle_mobile { +namespace operators { + +template +class Reshape2Kernel + : public framework::OpKernelBase> { + public: + void Compute(const Reshape2Param& param) const; + bool Init(Reshape2Param* param); +}; +} // namespace operators +} // namespace paddle_mobile + +#endif diff --git a/src/operators/op_param.h b/src/operators/op_param.h index 8c3478a5f04e7c18bef7a50686a8fa1eead4d7eb..c60014094b582036ef2038b04edf7be3313e571e 100644 --- a/src/operators/op_param.h +++ b/src/operators/op_param.h @@ -1270,6 +1270,49 @@ class ReshapeParam : public OpParam { }; #endif +#ifdef RESHAPE2_OP +template +class Reshape2Param : public OpParam { + typedef typename DtypeTensorTrait::gtype GType; + typedef typename DtypeTensorTrait::rtype RType; + + public: + Reshape2Param(const VariableNameMap &inputs, const VariableNameMap &outputs, + const AttributeMap &attrs, const Scope &scope) { + input_x_ = InputXFrom(inputs, scope); + input_shape_ = InputShapeFrom(inputs, scope); + out_ = OutFrom(outputs, scope); + output_xshape_ = OutputXShapeFrom(outputs, scope); + shape_ = GetAttr>("shape", attrs); + if (HasAttr("inplace", attrs)) { + inplace_ = GetAttr("inplace", attrs); + } else { + inplace_ = false; + } + } + + const RType *InputX() const { return input_x_; } + + const RType *InputShape() const { return input_shape_; } + + RType *Out() const { return out_; } + + RType *OutputXShape() const { return output_xshape_; } + + const vector &Shape() const { return shape_; } + + const bool &Inplace() const { return inplace_; } + + private: + RType *input_x_; + RType *input_shape_; + RType *out_; + RType *output_xshape_; + vector shape_; + bool inplace_; +}; +#endif + #ifdef SCALE_OP template class ScaleParam : public OpParam { diff --git a/src/operators/reshape2_op.cpp b/src/operators/reshape2_op.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d1623076570d466fc53f885374060c5e744365ed --- /dev/null +++ b/src/operators/reshape2_op.cpp @@ -0,0 +1,47 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifdef RESHAPE2_OP + +#include "operators/reshape2_op.h" +#include +#include "operators/kernel/reshape_kernel.h" +namespace paddle_mobile { +namespace operators { + +template +void Reshape2Op::InferShape() const { + auto &shape = this->param_.Shape(); + auto input_x_dims = this->param_.InputX()->dims(); + auto out_dims = ValidateShape(shape, input_x_dims); + this->param_.Out()->Resize(out_dims); + std::vector xshape_dims(input_x_dims.size() + 1, 0); + for (int i = 0; i < input_x_dims.size(); ++i) { + xshape_dims[i + 1] = input_x_dims[i]; + } + this->param_.OutputXShape()->Resize(framework::make_ddim(xshape_dims)); +} + +} // namespace operators +} // namespace paddle_mobile + +namespace ops = paddle_mobile::operators; +#ifdef PADDLE_MOBILE_CPU +REGISTER_OPERATOR_CPU(reshape2, ops::Reshape2Op); +#endif +#ifdef PADDLE_MOBILE_MALI_GPU +REGISTER_OPERATOR_MALI_GPU(reshape2, ops::Reshape2Op); +#endif + +#endif diff --git a/src/operators/reshape2_op.h b/src/operators/reshape2_op.h new file mode 100644 index 0000000000000000000000000000000000000000..3a06c2b9b90233b6ad0bacb6176f4cc274ff1cc0 --- /dev/null +++ b/src/operators/reshape2_op.h @@ -0,0 +1,54 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifdef RESHAPE2_OP + +#pragma once + +#include + +#include "framework/operator.h" +#include "operators/kernel/reshape2_kernel.h" +#include "operators/op_param.h" + +namespace paddle_mobile { +namespace operators { + +using paddle_mobile::framework::Tensor; + +template +class Reshape2Op : public framework::OperatorWithKernel< + DeviceType, Reshape2Param, + operators::Reshape2Kernel> { + public: + Reshape2Op(const std::string &type, const VariableNameMap &inputs, + const VariableNameMap &outputs, + const framework::AttributeMap &attrs, + std::shared_ptr scope) + : framework::OperatorWithKernel, + operators::Reshape2Kernel>( + type, inputs, outputs, attrs, scope) {} + + using framework::OperatorWithKernel< + DeviceType, Reshape2Param, + operators::Reshape2Kernel>::OperatorWithKernel; + void InferShape() const override; + + protected: +}; + +} // namespace operators +} // namespace paddle_mobile + +#endif diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 2050b34d21f1fc9d22b2144f2fa5126ecc44c4b4..c534123952eb5c33173abddb4ca1700c57fd103a 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -200,6 +200,10 @@ if (NOT FOUND_MATCH) ADD_EXECUTABLE(test-reshape-op operators/test_reshape_op.cpp test_helper.h test_include.h) target_link_libraries(test-reshape-op paddle-mobile) + # gen test + ADD_EXECUTABLE(test-reshape2-op operators/test_reshape2_op.cpp test_helper.h test_include.h) + target_link_libraries(test-reshape2-op paddle-mobile) + # gen test ADD_EXECUTABLE(test-relu-op operators/test_relu_op.cpp test_helper.h test_include.h) target_link_libraries(test-relu-op paddle-mobile) diff --git a/test/net/test_googlenet.cpp b/test/net/test_googlenet.cpp index c88a78974c330ec270fbcb3f5c28e368ef16440e..f7d29942224b51734cf62988ba8f271f1fa05bc3 100644 --- a/test/net/test_googlenet.cpp +++ b/test/net/test_googlenet.cpp @@ -25,8 +25,8 @@ int main() { paddle_mobile::PaddleMobile paddle_mobile; #endif - paddle_mobile.SetThreadNum(1); - bool optimize = false; + paddle_mobile.SetThreadNum(4); + bool optimize = true; auto time1 = time(); if (paddle_mobile.Load(g_googlenet, optimize)) { auto time2 = time(); @@ -35,10 +35,10 @@ int main() { std::vector output; std::vector dims{1, 3, 224, 224}; GetInput(g_test_image_1x3x224x224, &input, dims); - // // 预热十次 - // for (int i = 0; i < 10; ++i) { - // output = paddle_mobile.Predict(input, dims); - // } + // 预热十次 + for (int i = 0; i < 10; ++i) { + output = paddle_mobile.Predict(input, dims); + } auto time3 = time(); for (int i = 0; i < 10; ++i) { output = paddle_mobile.Predict(input, dims); @@ -47,9 +47,6 @@ int main() { std::cout << "predict cost :" << time_diff(time3, time4) / 10 << "ms" << std::endl; - for (int i = 0; i < output.size(); ++i) { - DLOG << "result[" << i << "] = " << output[i]; - } } return 0; } diff --git a/test/operators/test_reshape2_op.cpp b/test/operators/test_reshape2_op.cpp new file mode 100644 index 0000000000000000000000000000000000000000..564b8bcb4db8bdc2c97d4bbc9635262a8a28a6e4 --- /dev/null +++ b/test/operators/test_reshape2_op.cpp @@ -0,0 +1,143 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include "../test_include.h" +#include "operators/reshape2_op.h" + +namespace paddle_mobile { +namespace framework { + +template +class TestReshape2Op { + public: + explicit TestReshape2Op(const Program p) : program_(p) { + if (use_optimize_) { + to_predict_program_ = program_.optimizeProgram; + } else { + to_predict_program_ = program_.originProgram; + } + const std::vector> blocks = + to_predict_program_->Blocks(); + for (auto block_desc : blocks) { + std::vector> ops = block_desc->Ops(); + for (auto op : ops) { + if (op->Type() == "reshape2") { + DLOG << " attr size: " << op->GetAttrMap().size(); + std::unordered_map attrs = op->GetAttrMap(); + for (std::unordered_map::iterator it = + attrs.begin(); + it != attrs.end(); ++it) { + DLOG << " " << it->first << " " << it->second; + } + + DLOG << " inputs size: " << op->GetInputs().size(); + VariableNameMap inputs = op->GetInputs(); + for (VariableNameMap::iterator it = inputs.begin(); + it != inputs.end(); ++it) { + DLOG << " " << it->first << " " << it->second; + } + + DLOG << " outputs size: " << op->GetOutputs().size(); + VariableNameMap outputs = op->GetOutputs(); + for (VariableNameMap::iterator it = outputs.begin(); + it != outputs.end(); ++it) { + DLOG << " " << it->first << " " << it->second; + } + + input_var_name = op->Input("X")[0]; + output_var_name = op->Output("Out")[0]; + std::shared_ptr> op_ptr = + std::make_shared>( + op->Type(), op->GetInputs(), op->GetOutputs(), + op->GetAttrMap(), program_.scope); + ops_of_block_[*block_desc.get()].push_back(op_ptr); + return; + } + } + } + } + + std::shared_ptr predict(const Tensor &t) { + auto scope = program_.scope; + Variable *input_feed_value = scope->Var(input_var_name); + auto tensor_input = input_feed_value->GetMutable(); + tensor_input->ShareDataWith(t); + + Variable *output = scope->Var(output_var_name); + auto *output_tensor = output->GetMutable(); + + std::shared_ptr out_tensor = std::make_shared(); + out_tensor.reset(output_tensor); + + predict(t, 0); + + return out_tensor; + } + + private: + const framework::Program program_; + std::shared_ptr to_predict_program_; + std::map>>> + ops_of_block_; + bool use_optimize_ = false; + string input_var_name; + string output_var_name; + + void predict(const Tensor &t, int block_id) { + std::shared_ptr to_predict_block = + to_predict_program_->Block(block_id); + for (int j = 0; j < ops_of_block_[*to_predict_block.get()].size(); ++j) { + auto op = ops_of_block_[*to_predict_block.get()][j]; + op->Run(); + } + } +}; + +template class TestReshape2Op; +} // namespace framework +} // namespace paddle_mobile + +int main() { + DLOG << "----------**********----------"; + DLOG << "begin to run Reshape2 Test"; + paddle_mobile::Loader loader; + auto program = loader.Load(std::string(g_ocr) + "/model", + std::string(g_ocr) + "/params"); + + paddle_mobile::framework::Tensor input; + SetupTensor(&input, {1, 4, 4}, static_cast(0), + static_cast(1)); + auto *input_ptr = input.data(); + for (int i = 0; i < 16; ++i) { + *(input_ptr + i) = i; + } + DLOG << "input : "; + for (int i = 0; i < input.numel(); ++i) { + DLOG << " index " << i << " : " << input_ptr[i]; + } + + paddle_mobile::framework::TestReshape2Op testReshape2Op( + program); + + auto output = testReshape2Op.predict(input); + auto *output_ptr = output->data(); + + DLOG << "output : "; + for (int i = 0; i < output->numel(); ++i) { + DLOG << " index " << i << " : " << output_ptr[i]; + } + return 0; +} diff --git a/tools/op.cmake b/tools/op.cmake index a1b6c889aa01ab7747ecc8844f25830be431d5f5..2e1e311a2c96bac02257cfdce2d2fbebcd962dfb 100644 --- a/tools/op.cmake +++ b/tools/op.cmake @@ -201,6 +201,7 @@ if(NOT FOUND_MATCH) set(PRIORBOX_OP ON) set(RELU_OP ON) set(RESHAPE_OP ON) + set(RESHAPE2_OP ON) set(SIGMOID_OP ON) set(SOFTMAX_OP ON) set(TRANSPOSE_OP ON) @@ -247,6 +248,7 @@ endif() # option(PRIORBOX_OP "" ON) # option(RELU_OP "" ON) # option(RESHAPE_OP "" ON) + # option(RESHAPE2_OP "" ON) # option(SIGMOID_OP "" ON) # option(SOFTMAX_OP "" ON) # option(TRANSPOSE_OP "" ON) @@ -316,6 +318,9 @@ endif() if (RESHAPE_OP) add_definitions(-DRESHAPE_OP) endif() +if (RESHAPE2_OP) + add_definitions(-DRESHAPE2_OP) +endif() if (SIGMOID_OP) add_definitions(-DSIGMOID_OP) endif()