From 96bbcbe933b390d6bfb2bcedbfc6ab058fff6019 Mon Sep 17 00:00:00 2001 From: lijiancheng0614 Date: Mon, 22 Oct 2018 20:10:14 +0800 Subject: [PATCH] add test_fill_constant_op --- test/CMakeLists.txt | 4 + test/operators/test_fill_constant_op.cpp | 113 +++++++++++++++++++++++ 2 files changed, 117 insertions(+) create mode 100644 test/operators/test_fill_constant_op.cpp diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 95d7dfc3a9..e1ccee3c1d 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -185,6 +185,10 @@ if (NOT FOUND_MATCH) ADD_EXECUTABLE(test-polygon-box-transform-op operators/test_polygon_box_transform_op.cpp test_helper.h test_include.h) target_link_libraries(test-polygon-box-transform-op paddle-mobile) + # gen test + ADD_EXECUTABLE(test-fill-constant-op operators/test_fill_constant_op.cpp test_helper.h test_include.h) + target_link_libraries(test-fill-constant-op paddle-mobile) + # gen test ADD_EXECUTABLE(test-reshape-op operators/test_reshape_op.cpp test_helper.h test_include.h) target_link_libraries(test-reshape-op paddle-mobile) diff --git a/test/operators/test_fill_constant_op.cpp b/test/operators/test_fill_constant_op.cpp new file mode 100644 index 0000000000..b099217d16 --- /dev/null +++ b/test/operators/test_fill_constant_op.cpp @@ -0,0 +1,113 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include "../test_include.h" +#include "operators/fill_constant_op.h" + +namespace paddle_mobile { +namespace framework { + +template +class TestFillConstantOp { + public: + explicit TestFillConstantOp(const Program p) : program_(p) { + if (use_optimize_) { + to_predict_program_ = program_.optimizeProgram; + } else { + to_predict_program_ = program_.originProgram; + } + const std::vector> blocks = + to_predict_program_->Blocks(); + for (auto block_desc : blocks) { + std::vector> ops = block_desc->Ops(); + for (auto op : ops) { + if (op->Type() == "fill_constant") { + DLOG << " attr size: " << op->GetAttrMap().size(); + std::unordered_map attrs = op->GetAttrMap(); + for (std::unordered_map::iterator it = + attrs.begin(); + it != attrs.end(); ++it) { + DLOG << " " << it->first << " " << it->second; + } + DLOG << " inputs size: " << op->GetInputs().size(); + DLOG << " outputs size: " << op->GetOutputs().size(); + DLOG << " output is : " << op->Output("Out")[0]; + output_var_name = op->Output("Out")[0]; + std::shared_ptr> op_ptr = + std::make_shared>( + op->Type(), op->GetInputs(), op->GetOutputs(), + op->GetAttrMap(), program_.scope); + ops_of_block_[*block_desc.get()].push_back(op_ptr); + } + } + } + } + + std::shared_ptr predict() { + auto scope = program_.scope; + + Variable *output = scope->Var(output_var_name); + auto *output_tensor = output->GetMutable(); + + std::shared_ptr out_tensor = std::make_shared(); + out_tensor.reset(output_tensor); + + predict(0); + + return out_tensor; + } + + private: + const framework::Program program_; + std::shared_ptr to_predict_program_; + std::map>>> + ops_of_block_; + bool use_optimize_ = false; + string output_var_name; + + void predict(int block_id) { + std::shared_ptr to_predict_block = + to_predict_program_->Block(block_id); + for (int j = 0; j < ops_of_block_[*to_predict_block.get()].size(); ++j) { + auto op = ops_of_block_[*to_predict_block.get()][j]; + op->Run(); + } + } +}; + +template class TestFillConstantOp; +} // namespace framework +} // namespace paddle_mobile + +int main() { + DLOG << "----------**********----------"; + DLOG << "begin to run FillConstant Test"; + paddle_mobile::Loader loader; + auto program = loader.Load(std::string(g_ocr) + "/model", + std::string(g_ocr) + "/params"); + + paddle_mobile::framework::TestFillConstantOp + testFillConstantOp(program); + + auto output = testFillConstantOp.predict(); + auto *output_ptr = output->data(); + + DLOG << "output : "; + for (int i = 0; i < output->numel(); ++i) { + DLOG << " index " << i << " : " << output_ptr[i]; + } + return 0; +} -- GitLab