提交 250206d1 编写于 作者: L Liu Yiqun

Change the example of inference to a unittest.

上级 c6482444
......@@ -24,19 +24,6 @@ if(NOT WITH_C_API AND WITH_FLUID)
install(TARGETS paddle_fluid_shared DESTINATION lib)
endif()
add_executable(example example.cc)
if(APPLE)
set(OPTIONAL_LINK_FLAGS)
if("${CMAKE_CXX_COMPILER_ID}" STREQUAL "Clang" OR "${CMAKE_CXX_COMPILER_ID}" STREQUAL "AppleClang")
set(OPTIONAL_LINK_FLAGS "-undefined dynamic_lookup")
endif()
target_link_libraries(example
-Wl,-force_load paddle_fluid
${OPTIONAL_LINK_FLAGS}
${PTOOLS_LIB})
else()
target_link_libraries(example
-Wl,--start-group -Wl,--whole-archive paddle_fluid
-Wl,--no-whole-archive -Wl,--end-group
${PTOOLS_LIB})
if(WITH_TESTING)
add_subdirectory(tests/book)
endif()
add_executable(test_inference_recognize_digits_mlp test_inference_recognize_digits_mlp.cc)
target_circle_link_libraries(
test_inference_recognize_digits_mlp
ARCHIVE_START
paddle_fluid
ARCHIVE_END
gtest
gflags)
add_test(
NAME test_inference_recognize_digits_mlp
COMMAND test_inference_recognize_digits_mlp
--dirname=${PADDLE_SOURCE_DIR}/python/paddle/v2/fluid/tests/book/recognize_digits_mlp.inference.model
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR})
......@@ -12,20 +12,17 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h>
#include <time.h>
#include <iostream>
#include <sstream>
#include "gflags/gflags.h"
#include "paddle/inference/inference.h"
DEFINE_string(dirname, "", "Directory of the inference model.");
int main(int argc, char** argv) {
google::ParseCommandLineFlags(&argc, &argv, true);
TEST(inference, recognize_digits_mlp) {
if (FLAGS_dirname.empty()) {
// Example:
// ./example --dirname=recognize_digits_mlp.inference.model
std::cout << "Usage: ./example --dirname=path/to/your/model" << std::endl;
exit(1);
LOG(FATAL) << "Usage: ./example --dirname=path/to/your/model";
}
std::cout << "FLAGS_dirname: " << FLAGS_dirname << std::endl;
......@@ -48,20 +45,21 @@ int main(int argc, char** argv) {
engine->Execute(feeds, fetchs);
for (size_t i = 0; i < fetchs.size(); ++i) {
auto dims_i = fetchs[i].dims();
std::cout << "dims_i:";
for (int j = 0; j < dims_i.size(); ++j) {
std::cout << " " << dims_i[j];
}
std::cout << std::endl;
std::cout << "result:";
LOG(INFO) << fetchs[i].dims();
std::stringstream ss;
ss << "result:";
float* output_ptr = fetchs[i].data<float>();
for (int j = 0; j < paddle::framework::product(dims_i); ++j) {
std::cout << " " << output_ptr[j];
for (int j = 0; j < fetchs[i].numel(); ++j) {
ss << " " << output_ptr[j];
}
std::cout << std::endl;
LOG(INFO) << ss.str();
}
delete engine;
return 0;
}
int main(int argc, char** argv) {
google::ParseCommandLineFlags(&argc, &argv, false);
testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册