提交 250206d1 编写于 作者: L Liu Yiqun

Change the example of inference to a unittest.

上级 c6482444
...@@ -24,19 +24,6 @@ if(NOT WITH_C_API AND WITH_FLUID) ...@@ -24,19 +24,6 @@ if(NOT WITH_C_API AND WITH_FLUID)
install(TARGETS paddle_fluid_shared DESTINATION lib) install(TARGETS paddle_fluid_shared DESTINATION lib)
endif() endif()
add_executable(example example.cc) if(WITH_TESTING)
if(APPLE) add_subdirectory(tests/book)
set(OPTIONAL_LINK_FLAGS)
if("${CMAKE_CXX_COMPILER_ID}" STREQUAL "Clang" OR "${CMAKE_CXX_COMPILER_ID}" STREQUAL "AppleClang")
set(OPTIONAL_LINK_FLAGS "-undefined dynamic_lookup")
endif()
target_link_libraries(example
-Wl,-force_load paddle_fluid
${OPTIONAL_LINK_FLAGS}
${PTOOLS_LIB})
else()
target_link_libraries(example
-Wl,--start-group -Wl,--whole-archive paddle_fluid
-Wl,--no-whole-archive -Wl,--end-group
${PTOOLS_LIB})
endif() endif()
add_executable(test_inference_recognize_digits_mlp test_inference_recognize_digits_mlp.cc)
target_circle_link_libraries(
test_inference_recognize_digits_mlp
ARCHIVE_START
paddle_fluid
ARCHIVE_END
gtest
gflags)
add_test(
NAME test_inference_recognize_digits_mlp
COMMAND test_inference_recognize_digits_mlp
--dirname=${PADDLE_SOURCE_DIR}/python/paddle/v2/fluid/tests/book/recognize_digits_mlp.inference.model
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR})
...@@ -12,20 +12,17 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,20 +12,17 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include <gtest/gtest.h>
#include <time.h> #include <time.h>
#include <iostream> #include <sstream>
#include "gflags/gflags.h" #include "gflags/gflags.h"
#include "paddle/inference/inference.h" #include "paddle/inference/inference.h"
DEFINE_string(dirname, "", "Directory of the inference model."); DEFINE_string(dirname, "", "Directory of the inference model.");
int main(int argc, char** argv) { TEST(inference, recognize_digits_mlp) {
google::ParseCommandLineFlags(&argc, &argv, true);
if (FLAGS_dirname.empty()) { if (FLAGS_dirname.empty()) {
// Example: LOG(FATAL) << "Usage: ./example --dirname=path/to/your/model";
// ./example --dirname=recognize_digits_mlp.inference.model
std::cout << "Usage: ./example --dirname=path/to/your/model" << std::endl;
exit(1);
} }
std::cout << "FLAGS_dirname: " << FLAGS_dirname << std::endl; std::cout << "FLAGS_dirname: " << FLAGS_dirname << std::endl;
...@@ -48,20 +45,21 @@ int main(int argc, char** argv) { ...@@ -48,20 +45,21 @@ int main(int argc, char** argv) {
engine->Execute(feeds, fetchs); engine->Execute(feeds, fetchs);
for (size_t i = 0; i < fetchs.size(); ++i) { for (size_t i = 0; i < fetchs.size(); ++i) {
auto dims_i = fetchs[i].dims(); LOG(INFO) << fetchs[i].dims();
std::cout << "dims_i:"; std::stringstream ss;
for (int j = 0; j < dims_i.size(); ++j) { ss << "result:";
std::cout << " " << dims_i[j];
}
std::cout << std::endl;
std::cout << "result:";
float* output_ptr = fetchs[i].data<float>(); float* output_ptr = fetchs[i].data<float>();
for (int j = 0; j < paddle::framework::product(dims_i); ++j) { for (int j = 0; j < fetchs[i].numel(); ++j) {
std::cout << " " << output_ptr[j]; ss << " " << output_ptr[j];
} }
std::cout << std::endl; LOG(INFO) << ss.str();
} }
delete engine; delete engine;
return 0; }
int main(int argc, char** argv) {
google::ParseCommandLineFlags(&argc, &argv, false);
testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
} }
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册