未验证 提交 2096448b 编写于 作者: L Leo Chen 提交者: GitHub

make all cpp tests dynamic linked to libpaddle.so [except windows] (#47088)

* make all cpp tests dynamic linked to libpaddle.so

* add comments

* keep old cc_test for some tests

* fix some ut

* make some ut use cc_test_old

* fix typos and fit for win32

* fix lib path

* fix some tests

* skip lite test

* fit for rocm

* fit for cinn

* fit for mac

* fit for win32

* skip inference ut

* skip  windows

* fix coverage
上级 539f3006
...@@ -23,6 +23,7 @@ endif() ...@@ -23,6 +23,7 @@ endif()
# use to get_property location of static lib # use to get_property location of static lib
# https://cmake.org/cmake/help/v3.0/policy/CMP0026.html?highlight=cmp0026 # https://cmake.org/cmake/help/v3.0/policy/CMP0026.html?highlight=cmp0026
cmake_policy(SET CMP0026 OLD) cmake_policy(SET CMP0026 OLD)
cmake_policy(SET CMP0079 NEW)
set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${CMAKE_CURRENT_SOURCE_DIR}/cmake") set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${CMAKE_CURRENT_SOURCE_DIR}/cmake")
set(PADDLE_SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}) set(PADDLE_SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR})
set(PADDLE_BINARY_DIR ${CMAKE_CURRENT_BINARY_DIR}) set(PADDLE_BINARY_DIR ${CMAKE_CURRENT_BINARY_DIR})
......
...@@ -486,14 +486,17 @@ endfunction() ...@@ -486,14 +486,17 @@ endfunction()
function(cc_test_run TARGET_NAME) function(cc_test_run TARGET_NAME)
if(WITH_TESTING) if(WITH_TESTING)
set(oneValueArgs "") set(oneValueArgs DIR)
set(multiValueArgs COMMAND ARGS) set(multiValueArgs COMMAND ARGS)
cmake_parse_arguments(cc_test "${options}" "${oneValueArgs}" cmake_parse_arguments(cc_test "${options}" "${oneValueArgs}"
"${multiValueArgs}" ${ARGN}) "${multiValueArgs}" ${ARGN})
if(cc_test_DIR STREQUAL "")
set(cc_test_DIR ${CMAKE_CURRENT_BINARY_DIR})
endif()
add_test( add_test(
NAME ${TARGET_NAME} NAME ${TARGET_NAME}
COMMAND ${cc_test_COMMAND} ${cc_test_ARGS} COMMAND ${cc_test_COMMAND} ${cc_test_ARGS}
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}) WORKING_DIRECTORY ${cc_test_DIR})
set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT
FLAGS_cpu_deterministic=true) FLAGS_cpu_deterministic=true)
set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT
...@@ -513,7 +516,57 @@ function(cc_test_run TARGET_NAME) ...@@ -513,7 +516,57 @@ function(cc_test_run TARGET_NAME)
endif() endif()
endfunction() endfunction()
set_property(GLOBAL PROPERTY TEST_SRCS "")
set_property(GLOBAL PROPERTY TEST_NAMES "")
function(cc_test TARGET_NAME) function(cc_test TARGET_NAME)
if(WITH_TESTING)
set(oneValueArgs "")
set(multiValueArgs SRCS DEPS ARGS)
cmake_parse_arguments(cc_test "${options}" "${oneValueArgs}"
"${multiValueArgs}" ${ARGN})
if(WIN32)
# NOTE(zhiqiu): on windows platform, the symbols should be exported
# explicitly by __declspec(dllexport), however, there are serveral
# symbols not exported, and link error occurs.
# so, the tests are not built against dynamic libraries now.
cc_test_old(
${TARGET_NAME}
SRCS
${cc_test_SRCS}
DEPS
${cc_test_DEPS}
ARGS
${cc_test_ARGS})
else()
list(LENGTH cc_test_SRCS len)
# message("cc_test_SRCS ${cc_test_SRCS}")
# message("cc_test_ARGS ${cc_test_ARGS}")
if(${len} GREATER 1)
message(
SEND_ERROR
"The number source file of cc_test should be 1, but got ${len}, the source files are: ${cc_test_SRCS}"
)
endif()
list(LENGTH cc_test_ARGS len_arg)
if(len_arg GREATER_EQUAL 1)
set_property(GLOBAL PROPERTY "${TARGET_NAME}_ARGS" "${cc_test_ARGS}")
#message("${TARGET_NAME}_ARGS arg ${arg}")
endif()
get_property(test_srcs GLOBAL PROPERTY TEST_SRCS)
set(test_srcs ${test_srcs} "${CMAKE_CURRENT_SOURCE_DIR}/${cc_test_SRCS}")
set_property(GLOBAL PROPERTY TEST_SRCS "${test_srcs}")
get_property(test_names GLOBAL PROPERTY TEST_NAMES)
set(test_names ${test_names} ${TARGET_NAME})
set_property(GLOBAL PROPERTY TEST_NAMES "${test_names}")
endif()
endif()
endfunction()
function(cc_test_old TARGET_NAME)
if(WITH_TESTING) if(WITH_TESTING)
set(oneValueArgs "") set(oneValueArgs "")
set(multiValueArgs SRCS DEPS ARGS) set(multiValueArgs SRCS DEPS ARGS)
...@@ -626,25 +679,9 @@ function(nv_test TARGET_NAME) ...@@ -626,25 +679,9 @@ function(nv_test TARGET_NAME)
# Reference: https://cmake.org/cmake/help/v3.10/module/FindCUDA.html # Reference: https://cmake.org/cmake/help/v3.10/module/FindCUDA.html
add_executable(${TARGET_NAME} ${nv_test_SRCS}) add_executable(${TARGET_NAME} ${nv_test_SRCS})
get_property(os_dependency_modules GLOBAL PROPERTY OS_DEPENDENCY_MODULES) get_property(os_dependency_modules GLOBAL PROPERTY OS_DEPENDENCY_MODULES)
target_link_libraries( target_link_libraries(${TARGET_NAME} ${nv_test_DEPS}
${TARGET_NAME} ${os_dependency_modules} paddle_gtest_main)
${nv_test_DEPS} add_dependencies(${TARGET_NAME} ${nv_test_DEPS} paddle_gtest_main)
paddle_gtest_main
lod_tensor
memory
gtest
gflags
glog
${os_dependency_modules})
add_dependencies(
${TARGET_NAME}
${nv_test_DEPS}
paddle_gtest_main
lod_tensor
memory
gtest
gflags
glog)
common_link(${TARGET_NAME}) common_link(${TARGET_NAME})
add_test(${TARGET_NAME} ${TARGET_NAME}) add_test(${TARGET_NAME} ${TARGET_NAME})
set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT
......
add_subdirectory(utils) set(CC_TESTS_DIR
add_subdirectory(scripts) ${PADDLE_BINARY_DIR}/paddle/tests
add_subdirectory(testing) CACHE INTERNAL "c++ tests directory")
set(PYTHON_TESTS_DIR set(PYTHON_TESTS_DIR
${PADDLE_BINARY_DIR}/python/paddle/fluid/tests ${PADDLE_BINARY_DIR}/python/paddle/fluid/tests
CACHE INTERNAL "python tests directory") CACHE INTERNAL "python tests directory")
add_subdirectory(utils)
add_subdirectory(scripts)
add_subdirectory(testing)
add_subdirectory(phi) add_subdirectory(phi)
add_subdirectory(infrt) add_subdirectory(infrt)
add_subdirectory(fluid) add_subdirectory(fluid)
# NOTE(zhiqiu): The changes of cc tests
# Before, (1) the source file of cc tests are distributed in different sub-directories,
# (2) the tests are added and configured by calling `cc_test()` in each `CMakeLists.txt`,
# (3) the tests links static libraries of paddle modules,
# (4) the tests binaries are generated in different directories, as the same as the
# folder of source file.
# Now, we want to make all cc tests dynamically linked to the main paddle labrary,
# i.e., `libpaddle.so`, so we changes the logic of (2), (3), (4):
# (2) calling `cc_test()` in each `CMakeLists.txt` will not `exactly` add test, but
# record all tests and its source files, the action of add tests is defered to HERE.
# Why doing so? since the target of `libpaddle.so` is mostly the last target, and
# the tests should be added after that accroding to dependency.
# (3) the tests links dynamic libraries, `libpaddle.so`
# (4) the tests are generated to the same directory, i.e., `CC_TESTS_DIR` defined above.
# Next, (to be discusssed)
# (1) move all source files to same folder,
# (2) naturally, and and configure tests in only one `CMakeLists.txt`,
# (3) cc tests support linking pre-built dynamic libraries. For example, use the dynamic
# library in the installed paddle by `pip`.
# add all tests here
get_property(test_srcs GLOBAL PROPERTY TEST_SRCS)
get_property(test_names GLOBAL PROPERTY TEST_NAMES)
# message("test_srcs ${test_srcs}")
get_property(paddle_lib GLOBAL PROPERTY PADDLE_LIB_NAME)
set(POSTFIX ".so")
if(WIN32)
set(POSTFIX ".dll")
endif()
list(LENGTH test_names len)
if(${len} GREATER_EQUAL 1)
message("Total tests: ${len}")
math(EXPR stop "${len} - 1")
foreach(idx RANGE ${stop})
if(WITH_TESTING)
list(GET test_srcs ${idx} test_src)
list(GET test_names ${idx} test_name)
get_property(test_arg GLOBAL PROPERTY "${test_name}_ARGS")
message("add test ${test_name}")
add_executable(${test_name} ${test_src})
# target_link_libraries(
# ${test_name}
# ${CMAKE_BINARY_DIR}/paddle/fluid/pybind/libpaddle${POSTFIX})
target_link_libraries(${test_name} $<TARGET_LINKER_FILE:${paddle_lib}>)
target_link_libraries(${test_name} paddle_gtest_main_new)
add_dependencies(${test_name} ${paddle_lib} paddle_gtest_main_new)
if(WITH_GPU)
target_link_libraries(${test_name} ${CUDA_CUDART_LIBRARY}
"-Wl,--as-needed")
endif()
if(WITH_ROCM)
target_link_libraries(${test_name} ${ROCM_HIPRTC_LIB})
endif()
if(APPLE)
target_link_libraries(${test_name}
"-Wl,-rpath,$<TARGET_FILE_DIR:${paddle_lib}>")
endif()
if(NOT
("${test_name}" STREQUAL "c_broadcast_op_npu_test"
OR "${test_name}" STREQUAL "c_allreduce_sum_op_npu_test"
OR "${test_name}" STREQUAL "c_allreduce_max_op_npu_test"
OR "${test_name}" STREQUAL "c_reducescatter_op_npu_test"
OR "${test_name}" STREQUAL "c_allgather_op_npu_test"
OR "${test_name}" STREQUAL "send_v2_op_npu_test"
OR "${test_name}" STREQUAL "c_reduce_sum_op_npu_test"
OR "${test_name}" STREQUAL "recv_v2_op_npu_test"))
cc_test_run(
${test_name}
COMMAND
${test_name}
ARGS
${test_arg}
DIR
${CC_TESTS_DIR})
endif()
elseif(WITH_TESTING AND NOT TEST ${test_name})
add_test(NAME ${test_name} COMMAND ${CMAKE_COMMAND} -E echo CI skip
${test_name}.)
endif()
set_target_properties(${test_name} PROPERTIES RUNTIME_OUTPUT_DIRECTORY
"${CC_TESTS_DIR}")
endforeach()
endif()
# set properties for some tests, it should be set after the tests defined.
if(TARGET standalone_executor_test)
set_tests_properties(standalone_executor_test PROPERTIES TIMEOUT 100)
if(NOT WIN32)
add_dependencies(standalone_executor_test download_program)
endif()
endif()
if(TARGET layer_test)
add_dependencies(layer_test jit_download_program)
add_dependencies(layer_test_new jit_download_program)
set_tests_properties(layer_test_new PROPERTIES ENVIRONMENT
"FLAGS_jit_engine_type=New")
endif()
if(TEST buddy_allocator_test)
if(NOT WIN32)
add_dependencies(buddy_allocator_test download_data)
endif()
set_tests_properties(buddy_allocator_test PROPERTIES LABELS
"RUN_TYPE=EXCLUSIVE")
endif()
add_custom_target(build_tests)
# add target to build all cpp tests
if(${len} GREATER_EQUAL 1)
add_dependencies(build_tests ${test_names})
endif()
set_source_files_properties( set_source_files_properties(
interceptor_ping_pong_test.cc PROPERTIES COMPILE_FLAGS interceptor_ping_pong_test.cc PROPERTIES COMPILE_FLAGS
${DISTRIBUTE_COMPILE_FLAGS}) ${DISTRIBUTE_COMPILE_FLAGS})
cc_test( cc_test_old(interceptor_ping_pong_test SRCS interceptor_ping_pong_test.cc DEPS
interceptor_ping_pong_test fleet_executor ${BRPC_DEPS})
SRCS interceptor_ping_pong_test.cc
DEPS fleet_executor ${BRPC_DEPS})
set_source_files_properties( set_source_files_properties(
compute_interceptor_test.cc PROPERTIES COMPILE_FLAGS compute_interceptor_test.cc PROPERTIES COMPILE_FLAGS
${DISTRIBUTE_COMPILE_FLAGS}) ${DISTRIBUTE_COMPILE_FLAGS})
cc_test( cc_test_old(compute_interceptor_test SRCS compute_interceptor_test.cc DEPS
compute_interceptor_test fleet_executor ${BRPC_DEPS})
SRCS compute_interceptor_test.cc
DEPS fleet_executor ${BRPC_DEPS})
set_source_files_properties( set_source_files_properties(
source_interceptor_test.cc PROPERTIES COMPILE_FLAGS source_interceptor_test.cc PROPERTIES COMPILE_FLAGS
${DISTRIBUTE_COMPILE_FLAGS}) ${DISTRIBUTE_COMPILE_FLAGS})
cc_test( cc_test_old(source_interceptor_test SRCS source_interceptor_test.cc DEPS
source_interceptor_test fleet_executor ${BRPC_DEPS})
SRCS source_interceptor_test.cc
DEPS fleet_executor ${BRPC_DEPS})
set_source_files_properties( set_source_files_properties(
sink_interceptor_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) sink_interceptor_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test( cc_test_old(sink_interceptor_test SRCS sink_interceptor_test.cc DEPS
sink_interceptor_test fleet_executor ${BRPC_DEPS})
SRCS sink_interceptor_test.cc
DEPS fleet_executor ${BRPC_DEPS})
set_source_files_properties( set_source_files_properties(
interceptor_pipeline_short_path_test.cc interceptor_pipeline_short_path_test.cc
PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test( cc_test_old(
interceptor_pipeline_short_path_test interceptor_pipeline_short_path_test SRCS
SRCS interceptor_pipeline_short_path_test.cc interceptor_pipeline_short_path_test.cc DEPS fleet_executor ${BRPC_DEPS})
DEPS fleet_executor ${BRPC_DEPS})
set_source_files_properties( set_source_files_properties(
interceptor_pipeline_long_path_test.cc PROPERTIES COMPILE_FLAGS interceptor_pipeline_long_path_test.cc PROPERTIES COMPILE_FLAGS
${DISTRIBUTE_COMPILE_FLAGS}) ${DISTRIBUTE_COMPILE_FLAGS})
cc_test( cc_test_old(
interceptor_pipeline_long_path_test interceptor_pipeline_long_path_test SRCS
SRCS interceptor_pipeline_long_path_test.cc interceptor_pipeline_long_path_test.cc DEPS fleet_executor ${BRPC_DEPS})
DEPS fleet_executor ${BRPC_DEPS})
set_source_files_properties( set_source_files_properties(
compute_interceptor_run_op_test.cc PROPERTIES COMPILE_FLAGS compute_interceptor_run_op_test.cc PROPERTIES COMPILE_FLAGS
${DISTRIBUTE_COMPILE_FLAGS}) ${DISTRIBUTE_COMPILE_FLAGS})
cc_test( cc_test_old(
compute_interceptor_run_op_test compute_interceptor_run_op_test
SRCS compute_interceptor_run_op_test.cc SRCS
DEPS fleet_executor compute_interceptor_run_op_test.cc
${BRPC_DEPS} DEPS
op_registry fleet_executor
fill_constant_op ${BRPC_DEPS}
elementwise_add_op op_registry
scope fill_constant_op
device_context) elementwise_add_op
scope
device_context)
if(WITH_DISTRIBUTE if(WITH_DISTRIBUTE
AND WITH_PSCORE AND WITH_PSCORE
...@@ -65,8 +57,7 @@ if(WITH_DISTRIBUTE ...@@ -65,8 +57,7 @@ if(WITH_DISTRIBUTE
set_source_files_properties( set_source_files_properties(
interceptor_ping_pong_with_brpc_test.cc interceptor_ping_pong_with_brpc_test.cc
PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test( cc_test_old(
interceptor_ping_pong_with_brpc_test interceptor_ping_pong_with_brpc_test SRCS
SRCS interceptor_ping_pong_with_brpc_test.cc interceptor_ping_pong_with_brpc_test.cc DEPS fleet_executor ${BRPC_DEPS})
DEPS fleet_executor ${BRPC_DEPS})
endif() endif()
set_source_files_properties( set_source_files_properties(
table_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) table_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test( cc_test_old(
table_test table_test
SRCS table_test.cc SRCS
DEPS common_table table ps_framework_proto ${COMMON_DEPS} ${RPC_DEPS}) table_test.cc
DEPS
common_table
table
ps_framework_proto
${COMMON_DEPS}
${RPC_DEPS})
set_source_files_properties( set_source_files_properties(
dense_table_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) dense_table_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test( cc_test_old(
dense_table_test dense_table_test
SRCS dense_table_test.cc SRCS
DEPS common_table table ps_framework_proto ${COMMON_DEPS} ${RPC_DEPS}) dense_table_test.cc
DEPS
common_table
table
ps_framework_proto
${COMMON_DEPS}
${RPC_DEPS})
set_source_files_properties( set_source_files_properties(
barrier_table_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) barrier_table_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test( cc_test_old(
barrier_table_test barrier_table_test
SRCS barrier_table_test.cc SRCS
DEPS common_table table ps_framework_proto ${COMMON_DEPS}) barrier_table_test.cc
DEPS
common_table
table
ps_framework_proto
${COMMON_DEPS})
set_source_files_properties( set_source_files_properties(
brpc_service_dense_sgd_test.cc PROPERTIES COMPILE_FLAGS brpc_service_dense_sgd_test.cc PROPERTIES COMPILE_FLAGS
${DISTRIBUTE_COMPILE_FLAGS}) ${DISTRIBUTE_COMPILE_FLAGS})
cc_test( cc_test_old(
brpc_service_dense_sgd_test brpc_service_dense_sgd_test
SRCS brpc_service_dense_sgd_test.cc SRCS
DEPS scope ps_service table ps_framework_proto ${COMMON_DEPS}) brpc_service_dense_sgd_test.cc
DEPS
scope
ps_service
table
ps_framework_proto
${COMMON_DEPS})
set_source_files_properties( set_source_files_properties(
brpc_service_sparse_sgd_test.cc PROPERTIES COMPILE_FLAGS brpc_service_sparse_sgd_test.cc PROPERTIES COMPILE_FLAGS
${DISTRIBUTE_COMPILE_FLAGS}) ${DISTRIBUTE_COMPILE_FLAGS})
cc_test( cc_test_old(
brpc_service_sparse_sgd_test brpc_service_sparse_sgd_test
SRCS brpc_service_sparse_sgd_test.cc SRCS
DEPS scope ps_service table ps_framework_proto ${COMMON_DEPS}) brpc_service_sparse_sgd_test.cc
DEPS
scope
ps_service
table
ps_framework_proto
${COMMON_DEPS})
set_source_files_properties( set_source_files_properties(
brpc_utils_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) brpc_utils_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test( cc_test_old(
brpc_utils_test brpc_utils_test
SRCS brpc_utils_test.cc SRCS
DEPS brpc_utils scope math_function ${COMMON_DEPS} ${RPC_DEPS}) brpc_utils_test.cc
DEPS
brpc_utils
scope
math_function
${COMMON_DEPS}
${RPC_DEPS})
set_source_files_properties( set_source_files_properties(
graph_node_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) graph_node_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test( cc_test_old(
graph_node_test graph_node_test
SRCS graph_node_test.cc SRCS
DEPS scope ps_service table ps_framework_proto ${COMMON_DEPS}) graph_node_test.cc
DEPS
scope
ps_service
table
ps_framework_proto
${COMMON_DEPS})
set_source_files_properties( set_source_files_properties(
graph_node_split_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) graph_node_split_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test( cc_test_old(
graph_node_split_test graph_node_split_test
SRCS graph_node_split_test.cc SRCS
DEPS scope ps_service table ps_framework_proto ${COMMON_DEPS}) graph_node_split_test.cc
DEPS
scope
ps_service
table
ps_framework_proto
${COMMON_DEPS})
set_source_files_properties( set_source_files_properties(
graph_table_sample_test.cc PROPERTIES COMPILE_FLAGS graph_table_sample_test.cc PROPERTIES COMPILE_FLAGS
${DISTRIBUTE_COMPILE_FLAGS}) ${DISTRIBUTE_COMPILE_FLAGS})
cc_test( cc_test_old(
graph_table_sample_test graph_table_sample_test
SRCS graph_table_sample_test.cc SRCS
DEPS table ps_framework_proto ${COMMON_DEPS}) graph_table_sample_test.cc
DEPS
table
ps_framework_proto
${COMMON_DEPS})
set_source_files_properties( set_source_files_properties(
feature_value_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) feature_value_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test( cc_test_old(feature_value_test SRCS feature_value_test.cc DEPS ${COMMON_DEPS}
feature_value_test table)
SRCS feature_value_test.cc
DEPS ${COMMON_DEPS} table)
set_source_files_properties( set_source_files_properties(
sparse_sgd_rule_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) sparse_sgd_rule_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test( cc_test_old(sparse_sgd_rule_test SRCS sparse_sgd_rule_test.cc DEPS
sparse_sgd_rule_test ${COMMON_DEPS} table)
SRCS sparse_sgd_rule_test.cc
DEPS ${COMMON_DEPS} table)
set_source_files_properties( set_source_files_properties(
ctr_accessor_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) ctr_accessor_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test( cc_test_old(ctr_accessor_test SRCS ctr_accessor_test.cc DEPS ${COMMON_DEPS}
ctr_accessor_test table)
SRCS ctr_accessor_test.cc
DEPS ${COMMON_DEPS} table)
set_source_files_properties( set_source_files_properties(
ctr_dymf_accessor_test.cc PROPERTIES COMPILE_FLAGS ctr_dymf_accessor_test.cc PROPERTIES COMPILE_FLAGS
${DISTRIBUTE_COMPILE_FLAGS}) ${DISTRIBUTE_COMPILE_FLAGS})
cc_test( cc_test_old(ctr_dymf_accessor_test SRCS ctr_dymf_accessor_test.cc DEPS
ctr_dymf_accessor_test ${COMMON_DEPS} table)
SRCS ctr_dymf_accessor_test.cc
DEPS ${COMMON_DEPS} table)
set_source_files_properties( set_source_files_properties(
memory_sparse_table_test.cc PROPERTIES COMPILE_FLAGS memory_sparse_table_test.cc PROPERTIES COMPILE_FLAGS
${DISTRIBUTE_COMPILE_FLAGS}) ${DISTRIBUTE_COMPILE_FLAGS})
cc_test( cc_test_old(memory_sparse_table_test SRCS memory_sparse_table_test.cc DEPS
memory_sparse_table_test ${COMMON_DEPS} table)
SRCS memory_sparse_table_test.cc
DEPS ${COMMON_DEPS} table)
set_source_files_properties( set_source_files_properties(
memory_geo_table_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) memory_geo_table_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test( cc_test_old(memory_sparse_geo_table_test SRCS memory_geo_table_test.cc DEPS
memory_sparse_geo_table_test ${COMMON_DEPS} table)
SRCS memory_geo_table_test.cc
DEPS ${COMMON_DEPS} table)
cc_test( cc_test_old(test_egr_ds_eager_tensor SRCS eager_tensor_test.cc DEPS
test_egr_ds_eager_tensor ${eager_deps})
SRCS eager_tensor_test.cc cc_test_old(test_egr_ds_auotgrad_meta SRCS autograd_meta_test.cc DEPS
DEPS ${eager_deps}) ${eager_deps})
cc_test( cc_test_old(test_egr_ds_grad_node_info SRCS grad_node_info_test.cc DEPS
test_egr_ds_auotgrad_meta ${eager_deps})
SRCS autograd_meta_test.cc cc_test_old(test_egr_ds_accumulation_node SRCS accumulation_node_test.cc DEPS
DEPS ${eager_deps}) ${eager_deps})
cc_test( cc_test_old(test_egr_ds_tensor_wrapper SRCS tensor_wrapper_test.cc DEPS
test_egr_ds_grad_node_info ${eager_deps})
SRCS grad_node_info_test.cc
DEPS ${eager_deps})
cc_test(
test_egr_ds_accumulation_node
SRCS accumulation_node_test.cc
DEPS ${eager_deps})
cc_test(
test_egr_ds_tensor_wrapper
SRCS tensor_wrapper_test.cc
DEPS ${eager_deps})
if(NOT ((NOT WITH_PYTHON) AND ON_INFER)) if(NOT ((NOT WITH_PYTHON) AND ON_INFER))
cc_test( cc_test_old(test_egr_ds_grad_tensor_holder SRCS grad_tensor_holder_test.cc
test_egr_ds_grad_tensor_holder DEPS ${eager_deps} ${generated_deps})
SRCS grad_tensor_holder_test.cc
DEPS ${eager_deps} ${generated_deps})
endif() endif()
...@@ -10,20 +10,36 @@ cc_library( ...@@ -10,20 +10,36 @@ cc_library(
matmul_v2_op matmul_v2_op
dygraph_function) dygraph_function)
cc_test( cc_test_old(
test_egr_performance_benchmark_eager_cpu test_egr_performance_benchmark_eager_cpu
SRCS benchmark_eager_cpu.cc SRCS
DEPS performance_benchmark_utils ${eager_deps} ${fluid_deps}) benchmark_eager_cpu.cc
cc_test( DEPS
performance_benchmark_utils
${eager_deps}
${fluid_deps})
cc_test_old(
test_egr_performance_benchmark_fluid_cpu test_egr_performance_benchmark_fluid_cpu
SRCS benchmark_fluid_cpu.cc SRCS
DEPS performance_benchmark_utils ${eager_deps} ${fluid_deps}) benchmark_fluid_cpu.cc
DEPS
performance_benchmark_utils
${eager_deps}
${fluid_deps})
cc_test( cc_test_old(
test_egr_performance_benchmark_eager_cuda test_egr_performance_benchmark_eager_cuda
SRCS benchmark_eager_cuda.cc SRCS
DEPS performance_benchmark_utils ${eager_deps} ${fluid_deps}) benchmark_eager_cuda.cc
cc_test( DEPS
performance_benchmark_utils
${eager_deps}
${fluid_deps})
cc_test_old(
test_egr_performance_benchmark_fluid_cuda test_egr_performance_benchmark_fluid_cuda
SRCS benchmark_fluid_cuda.cc SRCS
DEPS performance_benchmark_utils ${eager_deps} ${fluid_deps}) benchmark_fluid_cuda.cc
DEPS
performance_benchmark_utils
${eager_deps}
${fluid_deps})
...@@ -1156,19 +1156,29 @@ cc_library( ...@@ -1156,19 +1156,29 @@ cc_library(
op_compatible_info op_compatible_info
SRCS op_compatible_info.cc SRCS op_compatible_info.cc
DEPS string_helper proto_desc) DEPS string_helper proto_desc)
cc_test( cc_test_old(
op_compatible_info_test op_compatible_info_test
SRCS op_compatible_info_test.cc SRCS
DEPS op_compatible_info proto_desc string_helper glog) op_compatible_info_test.cc
DEPS
op_compatible_info
proto_desc
string_helper
glog)
cc_library( cc_library(
save_load_util save_load_util
SRCS save_load_util.cc SRCS save_load_util.cc
DEPS tensor scope layer) DEPS tensor scope layer)
cc_test( cc_test_old(
save_load_util_test save_load_util_test
SRCS save_load_util_test.cc SRCS
DEPS save_load_util tensor scope layer) save_load_util_test.cc
DEPS
save_load_util
tensor
scope
layer)
cc_library( cc_library(
generator generator
SRCS generator.cc SRCS generator.cc
......
...@@ -322,16 +322,18 @@ cc_test( ...@@ -322,16 +322,18 @@ cc_test(
memory memory
device_context device_context
broadcast_op_handle) broadcast_op_handle)
cc_test( cc_test_old(
gather_op_test gather_op_test
SRCS gather_op_handle_test.cc SRCS
DEPS var_handle gather_op_handle_test.cc
op_handle_base DEPS
scope var_handle
ddim op_handle_base
memory scope
device_context ddim
gather_op_handle) memory
device_context
gather_op_handle)
cc_library( cc_library(
scope_buffered_monitor scope_buffered_monitor
......
...@@ -60,7 +60,7 @@ class SumOpWithKernel : public OperatorWithKernel { ...@@ -60,7 +60,7 @@ class SumOpWithKernel : public OperatorWithKernel {
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
REGISTER_OP_WITHOUT_GRADIENT(sum, REGISTER_OP_WITHOUT_GRADIENT(fake_sum,
paddle::framework::SumOpWithKernel, paddle::framework::SumOpWithKernel,
paddle::framework::SumOpMaker); paddle::framework::SumOpMaker);
...@@ -114,7 +114,7 @@ void BuildStrategyApply(BuildStrategy *build_strategy, ir::Graph *graph) { ...@@ -114,7 +114,7 @@ void BuildStrategyApply(BuildStrategy *build_strategy, ir::Graph *graph) {
std::unique_ptr<ir::Graph> CreateGraph() { std::unique_ptr<ir::Graph> CreateGraph() {
ProgramDesc prog; ProgramDesc prog;
auto *op = prog.MutableBlock(0)->AppendOp(); auto *op = prog.MutableBlock(0)->AppendOp();
op->SetType("sum"); op->SetType("fake_sum");
op->SetInput("X", {"a1"}); op->SetInput("X", {"a1"});
op->SetOutput("Out", {"b1"}); op->SetOutput("Out", {"b1"});
op->SetAttr("op_role", 1); op->SetAttr("op_role", 1);
...@@ -133,7 +133,7 @@ std::unique_ptr<ir::Graph> CreateMultiGraph() { ...@@ -133,7 +133,7 @@ std::unique_ptr<ir::Graph> CreateMultiGraph() {
// Set contents in block_0. // Set contents in block_0.
auto *op = prog.MutableBlock(0)->AppendOp(); auto *op = prog.MutableBlock(0)->AppendOp();
op->SetType("sum"); op->SetType("fake_sum");
op->SetInput("X", {"test_a", "test_b", "test_c"}); op->SetInput("X", {"test_a", "test_b", "test_c"});
op->SetOutput("Out", {"test_out"}); op->SetOutput("Out", {"test_out"});
op->SetAttr("op_role", 1); op->SetAttr("op_role", 1);
...@@ -149,7 +149,7 @@ std::unique_ptr<ir::Graph> CreateMultiGraph() { ...@@ -149,7 +149,7 @@ std::unique_ptr<ir::Graph> CreateMultiGraph() {
// Set contents in block_1. // Set contents in block_1.
op = prog.MutableBlock(1)->AppendOp(); op = prog.MutableBlock(1)->AppendOp();
op->SetType("sum"); op->SetType("fake_sum");
op->SetInput("X", {"a1"}); op->SetInput("X", {"a1"});
op->SetOutput("Out", {"b1"}); op->SetOutput("Out", {"b1"});
op->SetAttr("op_role", 1); op->SetAttr("op_role", 1);
...@@ -159,7 +159,7 @@ std::unique_ptr<ir::Graph> CreateMultiGraph() { ...@@ -159,7 +159,7 @@ std::unique_ptr<ir::Graph> CreateMultiGraph() {
// Set contents in block_2. // Set contents in block_2.
op = prog.MutableBlock(2)->AppendOp(); op = prog.MutableBlock(2)->AppendOp();
op->SetType("sum"); op->SetType("fake_sum");
op->SetInput("X", {"a2"}); op->SetInput("X", {"a2"});
op->SetOutput("Out", {"b2"}); op->SetOutput("Out", {"b2"});
op->SetAttr("op_role", 1); op->SetAttr("op_role", 1);
......
...@@ -384,30 +384,29 @@ if(WITH_MKLDNN) ...@@ -384,30 +384,29 @@ if(WITH_MKLDNN)
test_conv_concat_relu_mkldnn_fuse_pass test_conv_concat_relu_mkldnn_fuse_pass
SRCS mkldnn/conv_concat_relu_mkldnn_fuse_pass_tester.cc SRCS mkldnn/conv_concat_relu_mkldnn_fuse_pass_tester.cc
DEPS conv_activation_mkldnn_fuse_pass) DEPS conv_activation_mkldnn_fuse_pass)
cc_test( cc_test_old(
test_conv_elementwise_add_mkldnn_fuse_pass test_conv_elementwise_add_mkldnn_fuse_pass SRCS
SRCS mkldnn/conv_elementwise_add_mkldnn_fuse_pass_tester.cc mkldnn/conv_elementwise_add_mkldnn_fuse_pass_tester.cc DEPS
DEPS conv_elementwise_add_mkldnn_fuse_pass pass_test_util) conv_elementwise_add_mkldnn_fuse_pass pass_test_util)
cc_test( cc_test_old(
test_int8_scale_calculation_mkldnn_pass test_int8_scale_calculation_mkldnn_pass SRCS
SRCS mkldnn/int8_scale_calculation_mkldnn_pass_tester.cc mkldnn/int8_scale_calculation_mkldnn_pass_tester.cc DEPS
DEPS int8_scale_calculation_mkldnn_pass pass_test_util) int8_scale_calculation_mkldnn_pass pass_test_util)
cc_test( cc_test_old(
test_params_quantization_mkldnn_pass test_params_quantization_mkldnn_pass SRCS
SRCS mkldnn/params_quantization_mkldnn_pass_tester.cc mkldnn/params_quantization_mkldnn_pass_tester.cc DEPS
DEPS params_quantization_mkldnn_pass) params_quantization_mkldnn_pass)
cc_test( cc_test_old(
test_fc_elementwise_add_mkldnn_fuse_pass test_fc_elementwise_add_mkldnn_fuse_pass SRCS
SRCS mkldnn/fc_elementwise_add_mkldnn_fuse_pass_tester.cc mkldnn/fc_elementwise_add_mkldnn_fuse_pass_tester.cc DEPS
DEPS fc_elementwise_add_mkldnn_fuse_pass pass_test_util) fc_elementwise_add_mkldnn_fuse_pass pass_test_util)
cc_test( cc_test_old(
test_fc_act_mkldnn_fuse_pass test_fc_act_mkldnn_fuse_pass SRCS mkldnn/fc_act_mkldnn_fuse_pass_tester.cc
SRCS mkldnn/fc_act_mkldnn_fuse_pass_tester.cc
DEPS fc_act_mkldnn_fuse_pass pass_test_util) DEPS fc_act_mkldnn_fuse_pass pass_test_util)
cc_test( cc_test_old(
test_batch_norm_act_fuse_pass test_batch_norm_act_fuse_pass SRCS
SRCS mkldnn/batch_norm_act_fuse_pass_tester.cc mkldnn/batch_norm_act_fuse_pass_tester.cc DEPS batch_norm_act_fuse_pass
DEPS batch_norm_act_fuse_pass pass_test_util) pass_test_util)
set(TEST_CONV_BN_PASS_DEPS set(TEST_CONV_BN_PASS_DEPS
conv_bn_fuse_pass conv_bn_fuse_pass
graph_to_program_pass graph_to_program_pass
......
...@@ -74,7 +74,7 @@ class DummyOpVarTypeInference : public VarTypeInference { ...@@ -74,7 +74,7 @@ class DummyOpVarTypeInference : public VarTypeInference {
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
REGISTER_OPERATOR(sum, REGISTER_OPERATOR(fake_sum,
paddle::framework::NOP, paddle::framework::NOP,
paddle::framework::SumOpMaker, paddle::framework::SumOpMaker,
paddle::framework::SumOpVarTypeInference); paddle::framework::SumOpVarTypeInference);
...@@ -92,7 +92,7 @@ namespace framework { ...@@ -92,7 +92,7 @@ namespace framework {
TEST(GraphTest, Basic) { TEST(GraphTest, Basic) {
ProgramDesc prog; ProgramDesc prog;
auto *op = prog.MutableBlock(0)->AppendOp(); auto *op = prog.MutableBlock(0)->AppendOp();
op->SetType("sum"); op->SetType("fake_sum");
op->SetInput("X", {"test_a", "test_b", "test_c"}); op->SetInput("X", {"test_a", "test_b", "test_c"});
op->SetOutput("Out", {"test_out"}); op->SetOutput("Out", {"test_out"});
op->SetAttr("op_role", 1); op->SetAttr("op_role", 1);
...@@ -115,7 +115,7 @@ TEST(GraphTest, Basic) { ...@@ -115,7 +115,7 @@ TEST(GraphTest, Basic) {
std::unique_ptr<ir::Graph> g(new ir::Graph(prog)); std::unique_ptr<ir::Graph> g(new ir::Graph(prog));
std::vector<ir::Node *> nodes(g->Nodes().begin(), g->Nodes().end()); std::vector<ir::Node *> nodes(g->Nodes().begin(), g->Nodes().end());
for (ir::Node *n : nodes) { for (ir::Node *n : nodes) {
if (n->Name() == "sum") { if (n->Name() == "fake_sum") {
ASSERT_EQ(n->inputs.size(), 3UL); ASSERT_EQ(n->inputs.size(), 3UL);
ASSERT_EQ(n->outputs.size(), 1UL); ASSERT_EQ(n->outputs.size(), 1UL);
} else if (n->Name() == "test_a" || n->Name() == "test_b" || } else if (n->Name() == "test_a" || n->Name() == "test_b" ||
...@@ -242,7 +242,7 @@ TEST(GraphTest, TestMultiBlock) { ...@@ -242,7 +242,7 @@ TEST(GraphTest, TestMultiBlock) {
// Set contents in block_0. // Set contents in block_0.
auto *op = prog.MutableBlock(0)->AppendOp(); auto *op = prog.MutableBlock(0)->AppendOp();
op->SetType("sum"); op->SetType("fake_sum");
op->SetInput("X", {"test_a", "test_b", "test_c"}); op->SetInput("X", {"test_a", "test_b", "test_c"});
op->SetOutput("Out", {"test_out"}); op->SetOutput("Out", {"test_out"});
op->SetAttr("op_role", 1); op->SetAttr("op_role", 1);
...@@ -262,7 +262,7 @@ TEST(GraphTest, TestMultiBlock) { ...@@ -262,7 +262,7 @@ TEST(GraphTest, TestMultiBlock) {
// Set contents in block_1. // Set contents in block_1.
op = prog.MutableBlock(1)->AppendOp(); op = prog.MutableBlock(1)->AppendOp();
op->SetType("sum"); op->SetType("fake_sum");
op->SetInput("X", {"a"}); op->SetInput("X", {"a"});
op->SetOutput("Out", {"b"}); op->SetOutput("Out", {"b"});
op->SetAttr("op_role", 1); op->SetAttr("op_role", 1);
...@@ -280,7 +280,7 @@ TEST(GraphTest, TestMultiBlock) { ...@@ -280,7 +280,7 @@ TEST(GraphTest, TestMultiBlock) {
// Set contents in block_2. // Set contents in block_2.
op = prog.MutableBlock(2)->AppendOp(); op = prog.MutableBlock(2)->AppendOp();
op->SetType("sum"); op->SetType("fake_sum");
op->SetInput("X", {"a"}); op->SetInput("X", {"a"});
op->SetOutput("Out", {"b"}); op->SetOutput("Out", {"b"});
op->SetAttr("op_role", 1); op->SetAttr("op_role", 1);
...@@ -305,7 +305,7 @@ TEST(GraphTest, TestMultiBlock) { ...@@ -305,7 +305,7 @@ TEST(GraphTest, TestMultiBlock) {
const ir::Graph *g0 = g->GetSubGraph(0); const ir::Graph *g0 = g->GetSubGraph(0);
std::vector<ir::Node *> nodes(g0->Nodes().begin(), g0->Nodes().end()); std::vector<ir::Node *> nodes(g0->Nodes().begin(), g0->Nodes().end());
for (ir::Node *n : nodes) { for (ir::Node *n : nodes) {
if (n->Name() == "sum") { if (n->Name() == "fake_sum") {
ASSERT_EQ(n->inputs.size(), 3UL); ASSERT_EQ(n->inputs.size(), 3UL);
ASSERT_EQ(n->outputs.size(), 1UL); ASSERT_EQ(n->outputs.size(), 1UL);
} else if (n->Name() == "test_a" || n->Name() == "test_b" || } else if (n->Name() == "test_a" || n->Name() == "test_b" ||
...@@ -322,7 +322,7 @@ TEST(GraphTest, TestMultiBlock) { ...@@ -322,7 +322,7 @@ TEST(GraphTest, TestMultiBlock) {
// Check contents in sub_graph_1. // Check contents in sub_graph_1.
const ir::Graph *g1 = g->GetSubGraph(1); const ir::Graph *g1 = g->GetSubGraph(1);
for (ir::Node *n : g1->Nodes()) { for (ir::Node *n : g1->Nodes()) {
if (n->Name() == "sum") { if (n->Name() == "fake_sum") {
ASSERT_EQ(n->outputs[0]->Name(), "b"); ASSERT_EQ(n->outputs[0]->Name(), "b");
ASSERT_EQ(n->outputs.size(), 1UL); ASSERT_EQ(n->outputs.size(), 1UL);
} }
...@@ -335,7 +335,7 @@ TEST(GraphTest, TestMultiBlock) { ...@@ -335,7 +335,7 @@ TEST(GraphTest, TestMultiBlock) {
// Check contents in sub_graph_2. // Check contents in sub_graph_2.
const ir::Graph *g2 = g->GetSubGraph(2); const ir::Graph *g2 = g->GetSubGraph(2);
for (ir::Node *n : g2->Nodes()) { for (ir::Node *n : g2->Nodes()) {
if (n->Name() == "sum") { if (n->Name() == "fake_sum") {
ASSERT_EQ(n->outputs[0]->Name(), "b"); ASSERT_EQ(n->outputs[0]->Name(), "b");
ASSERT_EQ(n->outputs.size(), 1UL); ASSERT_EQ(n->outputs.size(), 1UL);
} }
......
...@@ -52,6 +52,8 @@ void SetOp(ProgramDesc* prog, ...@@ -52,6 +52,8 @@ void SetOp(ProgramDesc* prog,
op->SetAttr("alpha", 0.02f); op->SetAttr("alpha", 0.02f);
} else if (type == "relu6") { } else if (type == "relu6") {
op->SetAttr("threshold", 6.0f); op->SetAttr("threshold", 6.0f);
} else if (type == "mish") {
op->SetAttr("threshold", 20.0f);
} else if (type == "swish") { } else if (type == "swish") {
op->SetAttr("beta", 1.0f); op->SetAttr("beta", 1.0f);
} }
......
...@@ -27,7 +27,8 @@ if(WITH_GPU ...@@ -27,7 +27,8 @@ if(WITH_GPU
COMMAND wget -nc --no-check-certificate COMMAND wget -nc --no-check-certificate
https://paddle-ci.gz.bcebos.com/new_exec/lm_main_program https://paddle-ci.gz.bcebos.com/new_exec/lm_main_program
COMMAND wget -nc --no-check-certificate COMMAND wget -nc --no-check-certificate
https://paddle-ci.gz.bcebos.com/new_exec/lm_startup_program) https://paddle-ci.gz.bcebos.com/new_exec/lm_startup_program
WORKING_DIRECTORY "${CC_TESTS_DIR}")
# all operators used in the program # all operators used in the program
set(OPS set(OPS
...@@ -58,16 +59,11 @@ if(WITH_GPU ...@@ -58,16 +59,11 @@ if(WITH_GPU
# All deps of the operators above, part of GLOB_OPERATOR_DEPS. # All deps of the operators above, part of GLOB_OPERATOR_DEPS.
set(OP_DEPS generator softmax selected_rows_functor jit_kernel_helper set(OP_DEPS generator softmax selected_rows_functor jit_kernel_helper
concat_and_split cross_entropy) concat_and_split cross_entropy)
cc_test(standalone_executor_test SRCS standalone_executor_test.cc)
cc_test( # add_dependencies(standalone_executor_test download_program)
standalone_executor_test # if(WITH_PROFILER)
SRCS standalone_executor_test.cc # target_link_libraries(standalone_executor_test profiler)
DEPS standalone_executor operator op_registry executor ${OPS} ${OP_DEPS}) # add_dependencies(standalone_executor_test profiler)
set_tests_properties(standalone_executor_test PROPERTIES TIMEOUT 100) # endif()
add_dependencies(standalone_executor_test download_program)
if(WITH_PROFILER)
target_link_libraries(standalone_executor_test profiler)
add_dependencies(standalone_executor_test profiler)
endif()
endif() endif()
...@@ -40,55 +40,52 @@ cc_library( ...@@ -40,55 +40,52 @@ cc_library(
cinn_launch_context) cinn_launch_context)
if(WITH_TESTING) if(WITH_TESTING)
cc_test( cc_test_old(cinn_lib_test SRCS cinn_lib_test.cc DEPS cinn)
cinn_lib_test
SRCS cinn_lib_test.cc
DEPS cinn)
set_tests_properties(cinn_lib_test PROPERTIES LABELS "RUN_TYPE=CINN") set_tests_properties(cinn_lib_test PROPERTIES LABELS "RUN_TYPE=CINN")
cc_test( cc_test_old(cinn_cache_key_test SRCS cinn_cache_key_test.cc DEPS
cinn_cache_key_test cinn_cache_key)
SRCS cinn_cache_key_test.cc
DEPS cinn_cache_key)
set_tests_properties(cinn_cache_key_test PROPERTIES LABELS "RUN_TYPE=CINN") set_tests_properties(cinn_cache_key_test PROPERTIES LABELS "RUN_TYPE=CINN")
cc_test( cc_test_old(
build_cinn_pass_test build_cinn_pass_test
SRCS build_cinn_pass_test.cc SRCS
DEPS build_cinn_pass cinn_compiler op_registry mul_op activation_op build_cinn_pass_test.cc
elementwise_add_op) DEPS
build_cinn_pass
cinn_compiler
op_registry
mul_op
activation_op
elementwise_add_op)
set_tests_properties(build_cinn_pass_test PROPERTIES LABELS "RUN_TYPE=CINN") set_tests_properties(build_cinn_pass_test PROPERTIES LABELS "RUN_TYPE=CINN")
cc_test( cc_test_old(transform_desc_test SRCS transform_desc_test.cc DEPS
transform_desc_test transform_desc)
SRCS transform_desc_test.cc
DEPS transform_desc)
set_tests_properties(transform_desc_test PROPERTIES LABELS "RUN_TYPE=CINN") set_tests_properties(transform_desc_test PROPERTIES LABELS "RUN_TYPE=CINN")
cc_test( cc_test_old(transform_type_test SRCS transform_type_test.cc DEPS
transform_type_test transform_type)
SRCS transform_type_test.cc
DEPS transform_type)
set_tests_properties(transform_type_test PROPERTIES LABELS "RUN_TYPE=CINN") set_tests_properties(transform_type_test PROPERTIES LABELS "RUN_TYPE=CINN")
cc_test( cc_test_old(cinn_graph_symbolization_test SRCS
cinn_graph_symbolization_test cinn_graph_symbolization_test.cc DEPS cinn_graph_symbolization)
SRCS cinn_graph_symbolization_test.cc
DEPS cinn_graph_symbolization)
set_tests_properties(cinn_graph_symbolization_test PROPERTIES LABELS set_tests_properties(cinn_graph_symbolization_test PROPERTIES LABELS
"RUN_TYPE=CINN") "RUN_TYPE=CINN")
cc_test( cc_test_old(
cinn_compiler_test cinn_compiler_test
SRCS cinn_compiler_test.cc SRCS
DEPS cinn_compiler cinn_compiler_test.cc
place DEPS
proto_desc cinn_compiler
graph_viz_pass place
build_cinn_pass proto_desc
cinn graph_viz_pass
mul_op build_cinn_pass
activation_op cinn
elementwise_add_op) mul_op
activation_op
elementwise_add_op)
set_tests_properties(cinn_compiler_test PROPERTIES LABELS "RUN_TYPE=CINN") set_tests_properties(cinn_compiler_test PROPERTIES LABELS "RUN_TYPE=CINN")
endif() endif()
...@@ -63,7 +63,7 @@ class SumOpVarTypeInference : public VarTypeInference { ...@@ -63,7 +63,7 @@ class SumOpVarTypeInference : public VarTypeInference {
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
REGISTER_OPERATOR(sum, REGISTER_OPERATOR(fake_sum,
paddle::framework::NOP, paddle::framework::NOP,
paddle::framework::SumOpMaker, paddle::framework::SumOpMaker,
paddle::framework::SumOpVarTypeInference); paddle::framework::SumOpVarTypeInference);
...@@ -152,7 +152,7 @@ class TestStaticGraphVarTypeInference : public StaticGraphVarTypeInference { ...@@ -152,7 +152,7 @@ class TestStaticGraphVarTypeInference : public StaticGraphVarTypeInference {
TEST(InferVarType, sum_op) { TEST(InferVarType, sum_op) {
ProgramDesc prog; ProgramDesc prog;
auto* op = prog.MutableBlock(0)->AppendOp(); auto* op = prog.MutableBlock(0)->AppendOp();
op->SetType("sum"); op->SetType("fake_sum");
op->SetInput("X", {"test_a", "test_b", "test_c"}); op->SetInput("X", {"test_a", "test_b", "test_c"});
op->SetOutput("Out", {"test_out"}); op->SetOutput("Out", {"test_out"});
......
...@@ -71,6 +71,11 @@ if(WIN32 AND WITH_GPU) ...@@ -71,6 +71,11 @@ if(WIN32 AND WITH_GPU)
cc_library(paddle_inference DEPS ${fluid_modules} phi ${STATIC_INFERENCE_API} cc_library(paddle_inference DEPS ${fluid_modules} phi ${STATIC_INFERENCE_API}
${utils_modules}) ${utils_modules})
else() else()
# message("${fluid_modules}")
# message("PHI_MODULES ${phi_modules}")
# message("${phi_kernels}")
# message("${STATIC_INFERENCE_API}")
# message("${utils_modules}")
create_static_lib(paddle_inference ${fluid_modules} ${phi_modules} create_static_lib(paddle_inference ${fluid_modules} ${phi_modules}
${phi_kernels} ${STATIC_INFERENCE_API} ${utils_modules}) ${phi_kernels} ${STATIC_INFERENCE_API} ${utils_modules})
endif() endif()
......
...@@ -136,16 +136,25 @@ if(WITH_TESTING) ...@@ -136,16 +136,25 @@ if(WITH_TESTING)
endif() endif()
if(NOT APPLE AND NOT WIN32) if(NOT APPLE AND NOT WIN32)
cc_test( cc_test_old(
test_analysis_predictor test_analysis_predictor
SRCS analysis_predictor_tester.cc SRCS
DEPS paddle_inference_shared ARGS --dirname=${WORD2VEC_MODEL_DIR}) analysis_predictor_tester.cc
DEPS
paddle_inference_shared
ARGS
--dirname=${WORD2VEC_MODEL_DIR})
elseif(WIN32) elseif(WIN32)
cc_test( cc_test_old(
test_analysis_predictor test_analysis_predictor
SRCS analysis_predictor_tester.cc SRCS
DEPS analysis_predictor benchmark ${inference_deps} ARGS analysis_predictor_tester.cc
--dirname=${WORD2VEC_MODEL_DIR}) DEPS
analysis_predictor
benchmark
${inference_deps}
ARGS
--dirname=${WORD2VEC_MODEL_DIR})
endif() endif()
if(WITH_TESTING AND WITH_MKLDNN) if(WITH_TESTING AND WITH_MKLDNN)
......
...@@ -14,11 +14,16 @@ cc_library( ...@@ -14,11 +14,16 @@ cc_library(
lite_tensor_utils lite_tensor_utils
SRCS tensor_utils.cc SRCS tensor_utils.cc
DEPS memcpy ${LITE_DEPS} framework_proto device_context ${XPU_DEPS}) DEPS memcpy ${LITE_DEPS} framework_proto device_context ${XPU_DEPS})
cc_test( cc_test_old(
test_lite_engine test_lite_engine
SRCS test_engine_lite.cc SRCS
DEPS lite_engine protobuf framework_proto glog gtest analysis) test_engine_lite.cc
cc_test( DEPS
test_lite_tensor_utils lite_engine
SRCS test_tensor_utils.cc protobuf
DEPS lite_engine lite_tensor_utils) framework_proto
glog
gtest
analysis)
cc_test_old(test_lite_tensor_utils SRCS test_tensor_utils.cc DEPS lite_engine
lite_tensor_utils)
...@@ -2,10 +2,7 @@ cc_library( ...@@ -2,10 +2,7 @@ cc_library(
benchmark benchmark
SRCS benchmark.cc SRCS benchmark.cc
DEPS enforce) DEPS enforce)
cc_test( cc_test_old(test_benchmark SRCS benchmark_tester.cc DEPS benchmark)
test_benchmark
SRCS benchmark_tester.cc
DEPS benchmark)
cc_library( cc_library(
infer_io_utils infer_io_utils
SRCS io_utils.cc SRCS io_utils.cc
...@@ -14,10 +11,7 @@ cc_library( ...@@ -14,10 +11,7 @@ cc_library(
model_utils model_utils
SRCS model_utils.cc SRCS model_utils.cc
DEPS proto_desc enforce) DEPS proto_desc enforce)
cc_test( cc_test_old(infer_io_utils_tester SRCS io_utils_tester.cc DEPS infer_io_utils)
infer_io_utils_tester
SRCS io_utils_tester.cc
DEPS infer_io_utils)
if(WITH_ONNXRUNTIME AND WIN32) if(WITH_ONNXRUNTIME AND WIN32)
# Copy onnxruntime for some c++ test in Windows, since the test will # Copy onnxruntime for some c++ test in Windows, since the test will
...@@ -26,9 +20,6 @@ if(WITH_ONNXRUNTIME AND WIN32) ...@@ -26,9 +20,6 @@ if(WITH_ONNXRUNTIME AND WIN32)
endif() endif()
cc_library(table_printer SRCS table_printer.cc) cc_library(table_printer SRCS table_printer.cc)
cc_test( cc_test_old(test_table_printer SRCS table_printer_tester.cc DEPS table_printer)
test_table_printer
SRCS table_printer_tester.cc
DEPS table_printer)
proto_library(shape_range_info_proto SRCS shape_range_info.proto) proto_library(shape_range_info_proto SRCS shape_range_info.proto)
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
#include "paddle/fluid/inference/utils/table_printer.h" #include "paddle/fluid/inference/utils/table_printer.h"
#ifdef WIN32 #ifdef _WIN32
// suppress the min and max definitions in Windef.h. // suppress the min and max definitions in Windef.h.
#define NOMINMAX #define NOMINMAX
#include <Windows.h> #include <Windows.h>
...@@ -58,7 +58,7 @@ std::string TablePrinter::PrintTable() { ...@@ -58,7 +58,7 @@ std::string TablePrinter::PrintTable() {
TablePrinter::TablePrinter(const std::vector<std::string>& header) { TablePrinter::TablePrinter(const std::vector<std::string>& header) {
size_t terminal_witdh = 500; size_t terminal_witdh = 500;
#ifdef WIN32 #ifdef _WIN32
CONSOLE_SCREEN_BUFFER_INFO csbi; CONSOLE_SCREEN_BUFFER_INFO csbi;
int ret = GetConsoleScreenBufferInfo(GetStdHandle(STD_OUTPUT_HANDLE), &csbi); int ret = GetConsoleScreenBufferInfo(GetStdHandle(STD_OUTPUT_HANDLE), &csbi);
if (ret && (csbi.dwSize.X != 0)) { if (ret && (csbi.dwSize.X != 0)) {
......
...@@ -56,7 +56,8 @@ if(WITH_TESTING AND NOT WIN32) ...@@ -56,7 +56,8 @@ if(WITH_TESTING AND NOT WIN32)
COMMAND COMMAND
wget -nc -q --no-check-certificate wget -nc -q --no-check-certificate
https://paddle-ci.gz.bcebos.com/dy2st/multi_program_load_with_property.tar.gz https://paddle-ci.gz.bcebos.com/dy2st/multi_program_load_with_property.tar.gz
COMMAND tar zxf multi_program_load_with_property.tar.gz) COMMAND tar zxf multi_program_load_with_property.tar.gz
WORKING_DIRECTORY "${CC_TESTS_DIR}")
set(JIT_DEPS set(JIT_DEPS
phi phi
phi_api phi_api
...@@ -73,13 +74,13 @@ if(WITH_TESTING AND NOT WIN32) ...@@ -73,13 +74,13 @@ if(WITH_TESTING AND NOT WIN32)
layer_test layer_test
SRCS layer_test.cc SRCS layer_test.cc
DEPS ${JIT_DEPS}) DEPS ${JIT_DEPS})
add_dependencies(layer_test jit_download_program) # add_dependencies(layer_test jit_download_program)
cc_test( cc_test(
layer_test_new layer_test_new
SRCS layer_test.cc SRCS layer_test.cc
DEPS ${JIT_DEPS}) DEPS ${JIT_DEPS})
add_dependencies(layer_test_new jit_download_program) # add_dependencies(layer_test_new jit_download_program)
set_tests_properties(layer_test_new PROPERTIES ENVIRONMENT # set_tests_properties(layer_test_new PROPERTIES ENVIRONMENT
"FLAGS_jit_engine_type=New") # "FLAGS_jit_engine_type=New")
endif() endif()
...@@ -75,10 +75,8 @@ cc_test( ...@@ -75,10 +75,8 @@ cc_test(
naive_best_fit_allocator_test naive_best_fit_allocator_test
SRCS naive_best_fit_allocator_test.cc SRCS naive_best_fit_allocator_test.cc
DEPS allocator) DEPS allocator)
cc_test( cc_test_old(buffered_allocator_test SRCS buffered_allocator_test.cc DEPS
buffered_allocator_test allocator)
SRCS buffered_allocator_test.cc
DEPS allocator)
if(WITH_GPU) if(WITH_GPU)
nv_test( nv_test(
...@@ -104,21 +102,14 @@ elseif(WITH_ROCM) ...@@ -104,21 +102,14 @@ elseif(WITH_ROCM)
SRCS best_fit_allocator_test.cc best_fit_allocator_test.cu SRCS best_fit_allocator_test.cc best_fit_allocator_test.cu
DEPS allocator memcpy) DEPS allocator memcpy)
else() else()
cc_test( cc_test_old(best_fit_allocator_test SRCS best_fit_allocator_test.cc DEPS
best_fit_allocator_test allocator)
SRCS best_fit_allocator_test.cc
DEPS allocator)
endif() endif()
cc_test( cc_test_old(test_aligned_allocator SRCS test_aligned_allocator.cc DEPS
test_aligned_allocator allocator)
SRCS test_aligned_allocator.cc
DEPS allocator)
cc_test( cc_test_old(retry_allocator_test SRCS retry_allocator_test.cc DEPS allocator)
retry_allocator_test
SRCS retry_allocator_test.cc
DEPS allocator)
if(TEST retry_allocator_test) if(TEST retry_allocator_test)
set_tests_properties(retry_allocator_test PROPERTIES LABELS set_tests_properties(retry_allocator_test PROPERTIES LABELS
"RUN_TYPE=EXCLUSIVE") "RUN_TYPE=EXCLUSIVE")
...@@ -138,10 +129,8 @@ cc_test( ...@@ -138,10 +129,8 @@ cc_test(
auto_growth_best_fit_allocator_facade_test auto_growth_best_fit_allocator_facade_test
SRCS auto_growth_best_fit_allocator_facade_test.cc SRCS auto_growth_best_fit_allocator_facade_test.cc
DEPS allocator) DEPS allocator)
cc_test( cc_test_old(auto_growth_best_fit_allocator_test SRCS
auto_growth_best_fit_allocator_test auto_growth_best_fit_allocator_test.cc DEPS allocator)
SRCS auto_growth_best_fit_allocator_test.cc
DEPS allocator)
if(NOT WIN32) if(NOT WIN32)
cc_test( cc_test(
...@@ -161,11 +150,6 @@ cc_test( ...@@ -161,11 +150,6 @@ cc_test(
DEPS allocator) DEPS allocator)
if(WITH_TESTING) if(WITH_TESTING)
if(TEST buddy_allocator_test)
set_tests_properties(buddy_allocator_test PROPERTIES LABELS
"RUN_TYPE=EXCLUSIVE")
endif()
# TODO(zhiqiu): why not win32? because wget is not found on windows # TODO(zhiqiu): why not win32? because wget is not found on windows
if(NOT WIN32) if(NOT WIN32)
add_custom_target( add_custom_target(
...@@ -173,6 +157,5 @@ if(WITH_TESTING) ...@@ -173,6 +157,5 @@ if(WITH_TESTING)
COMMAND wget -nc --no-check-certificate COMMAND wget -nc --no-check-certificate
https://paddle-ci.cdn.bcebos.com/buddy_allocator_test_data.tar https://paddle-ci.cdn.bcebos.com/buddy_allocator_test_data.tar
COMMAND tar -xf buddy_allocator_test_data.tar) COMMAND tar -xf buddy_allocator_test_data.tar)
add_dependencies(buddy_allocator_test download_data)
endif() endif()
endif() endif()
cc_test( cc_test(
op_tester op_tester
SRCS op_tester.cc op_tester_config.cc SRCS op_tester.cc
DEPS memory DEPS memory
timer timer
framework_proto framework_proto
......
...@@ -20,224 +20,6 @@ limitations under the License. */ ...@@ -20,224 +20,6 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace benchmark { namespace benchmark {} // namespace benchmark
static const char kStartSeparator[] = "{";
static const char kEndSeparator[] = "}";
static const char kSepBetweenItems[] = ";";
static bool StartWith(const std::string& str, const std::string& substr) {
return str.find(substr) == 0;
}
static bool EndWith(const std::string& str, const std::string& substr) {
return str.rfind(substr) == (str.length() - substr.length());
}
static void EraseEndSep(std::string* str,
std::string substr = kSepBetweenItems) {
if (EndWith(*str, substr)) {
str->erase(str->length() - substr.length(), str->length());
}
}
OpInputConfig::OpInputConfig(std::istream& is) {
std::string sep;
is >> sep;
if (sep == kStartSeparator) {
while (sep != kEndSeparator) {
is >> sep;
if (sep == "name" || sep == "name:") {
is >> name;
EraseEndSep(&name);
} else if (sep == "dtype" || sep == "dtype:") {
ParseDType(is);
} else if (sep == "initializer" || sep == "initializer:") {
ParseInitializer(is);
} else if (sep == "dims" || sep == "dims:") {
ParseDims(is);
} else if (sep == "lod" || sep == "lod:") {
ParseLoD(is);
} else if (sep == "filename") {
is >> filename;
EraseEndSep(&filename);
}
}
}
}
void OpInputConfig::ParseDType(std::istream& is) {
std::string dtype_str;
is >> dtype_str;
EraseEndSep(&dtype_str);
if (dtype_str == "int32" || dtype_str == "int") {
dtype = "int32";
} else if (dtype_str == "int64" || dtype_str == "long") {
dtype = "int64";
} else if (dtype_str == "fp32" || dtype_str == "float") {
dtype = "fp32";
} else if (dtype_str == "fp64" || dtype_str == "double") {
dtype = "fp64";
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported dtype %s in OpInputConfig.", dtype_str.c_str()));
}
VLOG(4) << "dtype of input " << name << " is: " << dtype;
}
void OpInputConfig::ParseInitializer(std::istream& is) {
std::string initializer_str;
is >> initializer_str;
EraseEndSep(&initializer_str);
const std::vector<std::string> supported_initializers = {
"random", "natural", "zeros", "file"};
if (!Has(supported_initializers, initializer_str)) {
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported initializer %s in OpInputConfig.",
initializer_str.c_str()));
}
initializer = initializer_str;
VLOG(4) << "initializer of input " << name << " is: " << initializer;
}
void OpInputConfig::ParseDims(std::istream& is) {
std::string dims_str;
is >> dims_str;
dims.clear();
std::string token;
std::istringstream token_stream(dims_str);
while (std::getline(token_stream, token, 'x')) {
dims.push_back(std::stoi(token));
}
}
void OpInputConfig::ParseLoD(std::istream& is) {
std::string lod_str;
std::string start_sep =
std::string(kStartSeparator) + std::string(kStartSeparator);
std::string end_sep = std::string(kEndSeparator) + std::string(kEndSeparator);
std::string sep;
is >> sep;
if (StartWith(sep, start_sep)) {
lod_str += sep;
while (!EndWith(sep, end_sep)) {
is >> sep;
lod_str += sep;
}
}
EraseEndSep(&lod_str);
PADDLE_ENFORCE_GE(
lod_str.length(),
4U,
platform::errors::InvalidArgument(
"The length of lod string should be "
"equal to or larger than 4. But length of lod string is %zu.",
lod_str.length()));
VLOG(4) << "lod: " << lod_str << ", length: " << lod_str.length();
// Parse the lod_str
lod.clear();
for (size_t i = 1; i < lod_str.length() - 1;) {
if (lod_str[i] == '{') {
std::vector<size_t> level;
while (lod_str[i] != '}') {
++i;
std::string number;
while (lod_str[i] >= '0' && lod_str[i] <= '9') {
number += lod_str[i];
++i;
}
level.push_back(StringTo<size_t>(number));
}
lod.push_back(level);
} else if (lod_str[i] == '}') {
++i;
}
}
}
OpTesterConfig::OpTesterConfig(const std::string& filename) {
std::ifstream fin(filename, std::ios::in | std::ios::binary);
PADDLE_ENFORCE_EQ(
static_cast<bool>(fin),
true,
platform::errors::InvalidArgument("OpTesterConfig cannot open file %s.",
filename.c_str()));
Init(fin);
}
bool OpTesterConfig::Init(std::istream& is) {
std::string sep;
is >> sep;
if (sep == kStartSeparator) {
while (sep != kEndSeparator) {
is >> sep;
if (sep == "op_type" || sep == "op_type:") {
is >> op_type;
} else if (sep == "device_id" || sep == "device_id:") {
is >> device_id;
} else if (sep == "repeat" || sep == "repeat:") {
is >> repeat;
} else if (sep == "profile" || sep == "profile:") {
is >> profile;
} else if (sep == "print_debug_string" || sep == "print_debug_string:") {
is >> print_debug_string;
} else if (sep == "input" || sep == "input:") {
OpInputConfig input_config(is);
inputs.push_back(input_config);
} else if (sep == "attrs" || sep == "attrs:") {
ParseAttrs(is);
} else {
if (sep != kEndSeparator) {
return false;
}
}
}
} else {
return false;
}
return true;
}
bool OpTesterConfig::ParseAttrs(std::istream& is) {
std::string sep;
is >> sep;
if (sep == kStartSeparator) {
while (true) {
std::string key;
is >> key;
if (key == kEndSeparator) {
break;
}
std::string value;
is >> value;
EraseEndSep(&key, ":");
EraseEndSep(&value);
VLOG(4) << "attrs: " << key << ", " << value;
attrs[key] = value;
}
}
return true;
}
const OpInputConfig* OpTesterConfig::GetInput(const std::string& name) {
for (size_t i = 0; i < inputs.size(); ++i) {
if (inputs[i].name == name) {
return &inputs[i];
}
}
return nullptr;
}
} // namespace benchmark
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -78,6 +78,222 @@ T StringTo(const std::string& str) { ...@@ -78,6 +78,222 @@ T StringTo(const std::string& str) {
return value; return value;
} }
static const char kStartSeparator[] = "{";
static const char kEndSeparator[] = "}";
static const char kSepBetweenItems[] = ";";
static bool StartWith(const std::string& str, const std::string& substr) {
return str.find(substr) == 0;
}
static bool EndWith(const std::string& str, const std::string& substr) {
return str.rfind(substr) == (str.length() - substr.length());
}
static void EraseEndSep(std::string* str,
std::string substr = kSepBetweenItems) {
if (EndWith(*str, substr)) {
str->erase(str->length() - substr.length(), str->length());
}
}
OpInputConfig::OpInputConfig(std::istream& is) {
std::string sep;
is >> sep;
if (sep == kStartSeparator) {
while (sep != kEndSeparator) {
is >> sep;
if (sep == "name" || sep == "name:") {
is >> name;
EraseEndSep(&name);
} else if (sep == "dtype" || sep == "dtype:") {
ParseDType(is);
} else if (sep == "initializer" || sep == "initializer:") {
ParseInitializer(is);
} else if (sep == "dims" || sep == "dims:") {
ParseDims(is);
} else if (sep == "lod" || sep == "lod:") {
ParseLoD(is);
} else if (sep == "filename") {
is >> filename;
EraseEndSep(&filename);
}
}
}
}
void OpInputConfig::ParseDType(std::istream& is) {
std::string dtype_str;
is >> dtype_str;
EraseEndSep(&dtype_str);
if (dtype_str == "int32" || dtype_str == "int") {
dtype = "int32";
} else if (dtype_str == "int64" || dtype_str == "long") {
dtype = "int64";
} else if (dtype_str == "fp32" || dtype_str == "float") {
dtype = "fp32";
} else if (dtype_str == "fp64" || dtype_str == "double") {
dtype = "fp64";
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported dtype %s in OpInputConfig.", dtype_str.c_str()));
}
VLOG(4) << "dtype of input " << name << " is: " << dtype;
}
void OpInputConfig::ParseInitializer(std::istream& is) {
std::string initializer_str;
is >> initializer_str;
EraseEndSep(&initializer_str);
const std::vector<std::string> supported_initializers = {
"random", "natural", "zeros", "file"};
if (!Has(supported_initializers, initializer_str)) {
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported initializer %s in OpInputConfig.",
initializer_str.c_str()));
}
initializer = initializer_str;
VLOG(4) << "initializer of input " << name << " is: " << initializer;
}
void OpInputConfig::ParseDims(std::istream& is) {
std::string dims_str;
is >> dims_str;
dims.clear();
std::string token;
std::istringstream token_stream(dims_str);
while (std::getline(token_stream, token, 'x')) {
dims.push_back(std::stoi(token));
}
}
void OpInputConfig::ParseLoD(std::istream& is) {
std::string lod_str;
std::string start_sep =
std::string(kStartSeparator) + std::string(kStartSeparator);
std::string end_sep = std::string(kEndSeparator) + std::string(kEndSeparator);
std::string sep;
is >> sep;
if (StartWith(sep, start_sep)) {
lod_str += sep;
while (!EndWith(sep, end_sep)) {
is >> sep;
lod_str += sep;
}
}
EraseEndSep(&lod_str);
PADDLE_ENFORCE_GE(
lod_str.length(),
4U,
platform::errors::InvalidArgument(
"The length of lod string should be "
"equal to or larger than 4. But length of lod string is %zu.",
lod_str.length()));
VLOG(4) << "lod: " << lod_str << ", length: " << lod_str.length();
// Parse the lod_str
lod.clear();
for (size_t i = 1; i < lod_str.length() - 1;) {
if (lod_str[i] == '{') {
std::vector<size_t> level;
while (lod_str[i] != '}') {
++i;
std::string number;
while (lod_str[i] >= '0' && lod_str[i] <= '9') {
number += lod_str[i];
++i;
}
level.push_back(StringTo<size_t>(number));
}
lod.push_back(level);
} else if (lod_str[i] == '}') {
++i;
}
}
}
OpTesterConfig::OpTesterConfig(const std::string& filename) {
std::ifstream fin(filename, std::ios::in | std::ios::binary);
PADDLE_ENFORCE_EQ(
static_cast<bool>(fin),
true,
platform::errors::InvalidArgument("OpTesterConfig cannot open file %s.",
filename.c_str()));
Init(fin);
}
bool OpTesterConfig::Init(std::istream& is) {
std::string sep;
is >> sep;
if (sep == kStartSeparator) {
while (sep != kEndSeparator) {
is >> sep;
if (sep == "op_type" || sep == "op_type:") {
is >> op_type;
} else if (sep == "device_id" || sep == "device_id:") {
is >> device_id;
} else if (sep == "repeat" || sep == "repeat:") {
is >> repeat;
} else if (sep == "profile" || sep == "profile:") {
is >> profile;
} else if (sep == "print_debug_string" || sep == "print_debug_string:") {
is >> print_debug_string;
} else if (sep == "input" || sep == "input:") {
OpInputConfig input_config(is);
inputs.push_back(input_config);
} else if (sep == "attrs" || sep == "attrs:") {
ParseAttrs(is);
} else {
if (sep != kEndSeparator) {
return false;
}
}
}
} else {
return false;
}
return true;
}
bool OpTesterConfig::ParseAttrs(std::istream& is) {
std::string sep;
is >> sep;
if (sep == kStartSeparator) {
while (true) {
std::string key;
is >> key;
if (key == kEndSeparator) {
break;
}
std::string value;
is >> value;
EraseEndSep(&key, ":");
EraseEndSep(&value);
VLOG(4) << "attrs: " << key << ", " << value;
attrs[key] = value;
}
}
return true;
}
const OpInputConfig* OpTesterConfig::GetInput(const std::string& name) {
for (size_t i = 0; i < inputs.size(); ++i) {
if (inputs[i].name == name) {
return &inputs[i];
}
}
return nullptr;
}
} // namespace benchmark } // namespace benchmark
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -30,37 +30,48 @@ set(CINN_OP_DEPS ...@@ -30,37 +30,48 @@ set(CINN_OP_DEPS
register_operators(DEPS ${CINN_OP_DEPS}) register_operators(DEPS ${CINN_OP_DEPS})
if(WITH_TESTING) if(WITH_TESTING)
cc_test( cc_test_old(
cinn_launch_context_test cinn_launch_context_test
SRCS cinn_launch_context_test.cc SRCS
DEPS ddim cinn_launch_context_test.cc
lod_tensor DEPS
scope ddim
proto_desc lod_tensor
graph scope
cinn_launch_context proto_desc
cinn_instruction_run_op graph
cinn) cinn_launch_context
cinn_instruction_run_op
cinn)
set_tests_properties(cinn_launch_context_test PROPERTIES LABELS set_tests_properties(cinn_launch_context_test PROPERTIES LABELS
"RUN_TYPE=CINN") "RUN_TYPE=CINN")
set(CINN_RUN_ENVIRONMENT set(CINN_RUN_ENVIRONMENT
"OMP_NUM_THREADS=1;runtime_include_dir=${PADDLE_BINARY_DIR}/third_party/CINN/src/external_cinn/cinn/runtime/cuda" "OMP_NUM_THREADS=1;runtime_include_dir=${PADDLE_BINARY_DIR}/third_party/CINN/src/external_cinn/cinn/runtime/cuda"
) )
cc_test( cc_test_old(
cinn_launch_op_test cinn_launch_op_test
SRCS cinn_launch_op_test.cc SRCS
DEPS cinn_compiler cinn_launch_op cinn_instruction_run_op cinn_launch_op_test.cc
elementwise_add_op gflags) DEPS
cinn_compiler
cinn_launch_op
cinn_instruction_run_op
elementwise_add_op
gflags)
set_tests_properties( set_tests_properties(
cinn_launch_op_test PROPERTIES LABELS "RUN_TYPE=CINN" ENVIRONMENT cinn_launch_op_test PROPERTIES LABELS "RUN_TYPE=CINN" ENVIRONMENT
"${CINN_RUN_ENVIRONMENT}") "${CINN_RUN_ENVIRONMENT}")
cc_test( cc_test_old(
cinn_instruction_run_op_test cinn_instruction_run_op_test
SRCS cinn_instruction_run_op_test.cc SRCS
DEPS cinn_compiler cinn_launch_op cinn_instruction_run_op cinn_instruction_run_op_test.cc
elementwise_add_op) DEPS
cinn_compiler
cinn_launch_op
cinn_instruction_run_op
elementwise_add_op)
set_tests_properties( set_tests_properties(
cinn_instruction_run_op_test PROPERTIES LABELS "RUN_TYPE=CINN" ENVIRONMENT cinn_instruction_run_op_test PROPERTIES LABELS "RUN_TYPE=CINN" ENVIRONMENT
"${CINN_RUN_ENVIRONMENT}") "${CINN_RUN_ENVIRONMENT}")
......
...@@ -26,7 +26,6 @@ limitations under the License. */ ...@@ -26,7 +26,6 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/operators/copy_cross_scope_op.cc"
#include "paddle/fluid/string/printf.h" #include "paddle/fluid/string/printf.h"
#define Conn(x, y) x##y #define Conn(x, y) x##y
......
op_library(lite_engine_op DEPS lite_engine lite_tensor_utils) op_library(lite_engine_op DEPS lite_engine lite_tensor_utils)
cc_test( cc_test_old(test_lite_engine_op SRCS lite_engine_op_test.cc DEPS lite_engine_op
test_lite_engine_op analysis)
SRCS lite_engine_op_test.cc
DEPS lite_engine_op analysis)
cc_test( cc_test_old(
test_mkldnn_op_nhwc test_mkldnn_op_nhwc
SRCS mkldnn/test_mkldnn_op_nhwc.cc SRCS
DEPS op_registry mkldnn/test_mkldnn_op_nhwc.cc
pool_op DEPS
shape_op op_registry
crop_op pool_op
activation_op shape_op
pooling crop_op
transpose_op activation_op
scope pooling
device_context transpose_op
enforce scope
executor) device_context
enforce
executor)
...@@ -42,7 +42,4 @@ set(PRIM_OP_SRCS ...@@ -42,7 +42,4 @@ set(PRIM_OP_SRCS
rsqrt_p_op.cc rsqrt_p_op.cc
uniform_random_p_op.cc) uniform_random_p_op.cc)
cc_test( cc_test_old(prim_op_test SRCS prim_op_test.cc ${PRIM_OP_SRCS} DEPS op_registry)
prim_op_test
SRCS prim_op_test.cc ${PRIM_OP_SRCS}
DEPS op_registry)
...@@ -76,61 +76,69 @@ set(OPERATOR_DEPS ...@@ -76,61 +76,69 @@ set(OPERATOR_DEPS
set_source_files_properties( set_source_files_properties(
heter_server_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) heter_server_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test( cc_test_old(
heter_server_test heter_server_test
SRCS heter_server_test.cc SRCS
DEPS ${RPC_DEPS} heter_server_test.cc
${DISTRIBUTE_DEPS} DEPS
executor ${RPC_DEPS}
scope ${DISTRIBUTE_DEPS}
proto_desc executor
scale_op scope
eigen_function) proto_desc
scale_op
eigen_function)
set_source_files_properties( set_source_files_properties(
send_and_recv_op_cpu_test.cc PROPERTIES COMPILE_FLAGS send_and_recv_op_cpu_test.cc PROPERTIES COMPILE_FLAGS
${DISTRIBUTE_COMPILE_FLAGS}) ${DISTRIBUTE_COMPILE_FLAGS})
cc_test( cc_test_old(
send_and_recv_cpu_test send_and_recv_cpu_test
SRCS send_and_recv_op_cpu_test.cc SRCS
DEPS executor send_and_recv_op_cpu_test.cc
scope DEPS
proto_desc executor
scale_op scope
send_and_recv_op proto_desc
${RPC_DEPS} scale_op
${DISTRIBUTE_DEPS} send_and_recv_op
eigen_function) ${RPC_DEPS}
${DISTRIBUTE_DEPS}
eigen_function)
set_source_files_properties( set_source_files_properties(
send_and_recv_op_gpu_test.cc PROPERTIES COMPILE_FLAGS send_and_recv_op_gpu_test.cc PROPERTIES COMPILE_FLAGS
${DISTRIBUTE_COMPILE_FLAGS}) ${DISTRIBUTE_COMPILE_FLAGS})
cc_test( cc_test_old(
send_and_recv_gpu_test send_and_recv_gpu_test
SRCS send_and_recv_op_gpu_test.cc SRCS
DEPS executor send_and_recv_op_gpu_test.cc
scope DEPS
proto_desc executor
scale_op scope
send_and_recv_op proto_desc
${RPC_DEPS} scale_op
${DISTRIBUTE_DEPS} send_and_recv_op
eigen_function) ${RPC_DEPS}
${DISTRIBUTE_DEPS}
eigen_function)
set_source_files_properties( set_source_files_properties(
heter_listen_and_server_test.cc PROPERTIES COMPILE_FLAGS heter_listen_and_server_test.cc PROPERTIES COMPILE_FLAGS
${DISTRIBUTE_COMPILE_FLAGS}) ${DISTRIBUTE_COMPILE_FLAGS})
cc_test( cc_test_old(
heter_listen_and_server_test heter_listen_and_server_test
SRCS heter_listen_and_server_test.cc SRCS
DEPS executor heter_listen_and_server_test.cc
scope DEPS
proto_desc executor
scale_op scope
heter_listen_and_serv_op proto_desc
${RPC_DEPS} scale_op
${DISTRIBUTE_DEPS} heter_listen_and_serv_op
eigen_function) ${RPC_DEPS}
${DISTRIBUTE_DEPS}
eigen_function)
#set_source_files_properties(heter_cloud_comm_cpu_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) #set_source_files_properties(heter_cloud_comm_cpu_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
#cc_test(heter_cloud_comm_cpu_test SRCS heter_cloud_comm_cpu_test.cc DEPS executor scope proto_desc scale_op heter_listen_and_serv_op ${RPC_DEPS} ${DISTRIBUTE_DEPS} eigen_function) #cc_test(heter_cloud_comm_cpu_test SRCS heter_cloud_comm_cpu_test.cc DEPS executor scope proto_desc scale_op heter_listen_and_serv_op ${RPC_DEPS} ${DISTRIBUTE_DEPS} eigen_function)
......
...@@ -617,11 +617,20 @@ if(WITH_PYTHON) ...@@ -617,11 +617,20 @@ if(WITH_PYTHON)
if(WIN32) if(WIN32)
set(SHARD_LIB_NAME libpaddle) set(SHARD_LIB_NAME libpaddle)
endif() endif()
set_property(GLOBAL PROPERTY PADDLE_LIB_NAME ${SHARD_LIB_NAME})
cc_library( cc_library(
${SHARD_LIB_NAME} SHARED ${SHARD_LIB_NAME} SHARED
SRCS ${PYBIND_SRCS} SRCS ${PYBIND_SRCS}
DEPS ${PYBIND_DEPS} ${GLOB_OP_LIB} ${GLOB_OPERATOR_DEPS}) DEPS ${PYBIND_DEPS} ${GLOB_OP_LIB} ${GLOB_OPERATOR_DEPS})
# TODO(zhiqiu): some symbols not exported even setting the following
# property. Need to find a better way.
# if(WIN32)
# set_property(TARGET ${SHARD_LIB_NAME}
# PROPERTY WINDOWS_EXPORT_ALL_SYMBOLS ON)
# endif()
if(NOT ((NOT WITH_PYTHON) AND ON_INFER)) if(NOT ((NOT WITH_PYTHON) AND ON_INFER))
add_dependencies(${SHARD_LIB_NAME} legacy_eager_codegen) add_dependencies(${SHARD_LIB_NAME} legacy_eager_codegen)
add_dependencies(${SHARD_LIB_NAME} eager_legacy_op_function_generator_cmd) add_dependencies(${SHARD_LIB_NAME} eager_legacy_op_function_generator_cmd)
......
...@@ -24,10 +24,15 @@ cc_test( ...@@ -24,10 +24,15 @@ cc_test(
test_op_utils test_op_utils
SRCS test_op_utils.cc SRCS test_op_utils.cc
DEPS op_compat_infos) DEPS op_compat_infos)
cc_test( cc_test_old(
test_meta_fn_utils test_meta_fn_utils
SRCS test_meta_fn_utils.cc SRCS
DEPS dense_tensor wrapped_infermeta infermeta infermeta_utils) test_meta_fn_utils.cc
DEPS
dense_tensor
wrapped_infermeta
infermeta
infermeta_utils)
cc_test( cc_test(
test_ddim test_ddim
......
...@@ -33,7 +33,7 @@ TEST(ARG_MAP, fill_constant) { ...@@ -33,7 +33,7 @@ TEST(ARG_MAP, fill_constant) {
{"ShapeTensor", "ValueTensor"}, {}, {}, {}, {"Out"}); {"ShapeTensor", "ValueTensor"}, {}, {}, {}, {"Out"});
auto signature1 = (*OpUtilsMap::Instance().GetArgumentMappingFn( auto signature1 = (*OpUtilsMap::Instance().GetArgumentMappingFn(
"fill_constant"))(arg_case1); "fill_constant"))(arg_case1);
ASSERT_EQ(signature1.name, "full_sr"); EXPECT_STREQ(signature1.name, "full_sr");
TestArgumentMappingContext arg_case2( TestArgumentMappingContext arg_case2(
{"ShapeTensor"}, {"ShapeTensor"},
...@@ -43,7 +43,7 @@ TEST(ARG_MAP, fill_constant) { ...@@ -43,7 +43,7 @@ TEST(ARG_MAP, fill_constant) {
{"Out"}); {"Out"});
auto signature2 = (*OpUtilsMap::Instance().GetArgumentMappingFn( auto signature2 = (*OpUtilsMap::Instance().GetArgumentMappingFn(
"fill_constant"))(arg_case2); "fill_constant"))(arg_case2);
ASSERT_EQ(signature2.name, "full_sr"); EXPECT_STREQ(signature2.name, "full_sr");
TestArgumentMappingContext arg_case3( TestArgumentMappingContext arg_case3(
{"ShapeTensor"}, {"ShapeTensor"},
...@@ -53,13 +53,13 @@ TEST(ARG_MAP, fill_constant) { ...@@ -53,13 +53,13 @@ TEST(ARG_MAP, fill_constant) {
{"Out"}); {"Out"});
auto signature3 = (*OpUtilsMap::Instance().GetArgumentMappingFn( auto signature3 = (*OpUtilsMap::Instance().GetArgumentMappingFn(
"fill_constant"))(arg_case3); "fill_constant"))(arg_case3);
ASSERT_EQ(signature3.name, "full_sr"); EXPECT_STREQ(signature3.name, "full_sr");
TestArgumentMappingContext arg_case4( TestArgumentMappingContext arg_case4(
{"ShapeTensorList", "ValueTensor"}, {}, {}, {}, {"Out"}); {"ShapeTensorList", "ValueTensor"}, {}, {}, {}, {"Out"});
auto signature4 = (*OpUtilsMap::Instance().GetArgumentMappingFn( auto signature4 = (*OpUtilsMap::Instance().GetArgumentMappingFn(
"fill_constant"))(arg_case4); "fill_constant"))(arg_case4);
ASSERT_EQ(signature4.name, "full_sr"); EXPECT_STREQ(signature4.name, "full_sr");
TestArgumentMappingContext arg_case5( TestArgumentMappingContext arg_case5(
{"ShapeTensorList"}, {"ShapeTensorList"},
...@@ -69,7 +69,7 @@ TEST(ARG_MAP, fill_constant) { ...@@ -69,7 +69,7 @@ TEST(ARG_MAP, fill_constant) {
{"Out"}); {"Out"});
auto signature5 = (*OpUtilsMap::Instance().GetArgumentMappingFn( auto signature5 = (*OpUtilsMap::Instance().GetArgumentMappingFn(
"fill_constant"))(arg_case5); "fill_constant"))(arg_case5);
ASSERT_EQ(signature5.name, "full_sr"); EXPECT_STREQ(signature5.name, "full_sr");
TestArgumentMappingContext arg_case6( TestArgumentMappingContext arg_case6(
{"ShapeTensorList"}, {"ShapeTensorList"},
...@@ -79,7 +79,7 @@ TEST(ARG_MAP, fill_constant) { ...@@ -79,7 +79,7 @@ TEST(ARG_MAP, fill_constant) {
{"Out"}); {"Out"});
auto signature6 = (*OpUtilsMap::Instance().GetArgumentMappingFn( auto signature6 = (*OpUtilsMap::Instance().GetArgumentMappingFn(
"fill_constant"))(arg_case6); "fill_constant"))(arg_case6);
ASSERT_EQ(signature6.name, "full_sr"); EXPECT_STREQ(signature6.name, "full_sr");
TestArgumentMappingContext arg_case7( TestArgumentMappingContext arg_case7(
{"ValueTensor"}, {"ValueTensor"},
...@@ -89,7 +89,7 @@ TEST(ARG_MAP, fill_constant) { ...@@ -89,7 +89,7 @@ TEST(ARG_MAP, fill_constant) {
{"Out"}); {"Out"});
auto signature7 = (*OpUtilsMap::Instance().GetArgumentMappingFn( auto signature7 = (*OpUtilsMap::Instance().GetArgumentMappingFn(
"fill_constant"))(arg_case7); "fill_constant"))(arg_case7);
ASSERT_EQ(signature7.name, "full_sr"); EXPECT_STREQ(signature7.name, "full_sr");
TestArgumentMappingContext arg_case8( TestArgumentMappingContext arg_case8(
{}, {},
...@@ -101,7 +101,7 @@ TEST(ARG_MAP, fill_constant) { ...@@ -101,7 +101,7 @@ TEST(ARG_MAP, fill_constant) {
{"Out"}); {"Out"});
auto signature8 = (*OpUtilsMap::Instance().GetArgumentMappingFn( auto signature8 = (*OpUtilsMap::Instance().GetArgumentMappingFn(
"fill_constant"))(arg_case8); "fill_constant"))(arg_case8);
ASSERT_EQ(signature8.name, "full_sr"); EXPECT_STREQ(signature8.name, "full_sr");
TestArgumentMappingContext arg_case9( TestArgumentMappingContext arg_case9(
{}, {},
...@@ -112,7 +112,7 @@ TEST(ARG_MAP, fill_constant) { ...@@ -112,7 +112,7 @@ TEST(ARG_MAP, fill_constant) {
{"Out"}); {"Out"});
auto signature9 = (*OpUtilsMap::Instance().GetArgumentMappingFn( auto signature9 = (*OpUtilsMap::Instance().GetArgumentMappingFn(
"fill_constant"))(arg_case9); "fill_constant"))(arg_case9);
ASSERT_EQ(signature9.name, "full_sr"); EXPECT_STREQ(signature9.name, "full_sr");
} }
TEST(ARG_MAP, set_value) { TEST(ARG_MAP, set_value) {
...@@ -122,7 +122,7 @@ TEST(ARG_MAP, set_value) { ...@@ -122,7 +122,7 @@ TEST(ARG_MAP, set_value) {
{{"fp32_values", paddle::any{std::vector<float>{1}}}}, {{"fp32_values", paddle::any{std::vector<float>{1}}}},
{"Out"}, {"Out"},
{}); {});
ASSERT_EQ( EXPECT_STREQ(
(*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case) (*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case)
.name, .name,
"set_value"); "set_value");
...@@ -133,7 +133,7 @@ TEST(ARG_MAP, set_value) { ...@@ -133,7 +133,7 @@ TEST(ARG_MAP, set_value) {
{{"fp64_values", paddle::any{std::vector<double>{1}}}}, {{"fp64_values", paddle::any{std::vector<double>{1}}}},
{"Out"}, {"Out"},
{}); {});
ASSERT_EQ( EXPECT_STREQ(
(*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case1) (*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case1)
.name, .name,
"set_value"); "set_value");
...@@ -144,7 +144,7 @@ TEST(ARG_MAP, set_value) { ...@@ -144,7 +144,7 @@ TEST(ARG_MAP, set_value) {
{{"int32_values", paddle::any{std::vector<int>{1}}}}, {{"int32_values", paddle::any{std::vector<int>{1}}}},
{"Out"}, {"Out"},
{}); {});
ASSERT_EQ( EXPECT_STREQ(
(*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case2) (*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case2)
.name, .name,
"set_value"); "set_value");
...@@ -155,7 +155,7 @@ TEST(ARG_MAP, set_value) { ...@@ -155,7 +155,7 @@ TEST(ARG_MAP, set_value) {
{{"int64_values", paddle::any{std::vector<int64_t>{1}}}}, {{"int64_values", paddle::any{std::vector<int64_t>{1}}}},
{"Out"}, {"Out"},
{}); {});
ASSERT_EQ( EXPECT_STREQ(
(*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case3) (*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case3)
.name, .name,
"set_value"); "set_value");
...@@ -166,7 +166,7 @@ TEST(ARG_MAP, set_value) { ...@@ -166,7 +166,7 @@ TEST(ARG_MAP, set_value) {
{{"bool_values", paddle::any{std::vector<int>{1}}}}, {{"bool_values", paddle::any{std::vector<int>{1}}}},
{"Out"}, {"Out"},
{}); {});
ASSERT_EQ( EXPECT_STREQ(
(*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case4) (*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case4)
.name, .name,
"set_value"); "set_value");
...@@ -177,7 +177,7 @@ TEST(ARG_MAP, set_value) { ...@@ -177,7 +177,7 @@ TEST(ARG_MAP, set_value) {
{}, {},
{"Out"}, {"Out"},
{}); {});
ASSERT_EQ( EXPECT_STREQ(
(*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case5) (*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case5)
.name, .name,
"set_value_with_tensor"); "set_value_with_tensor");
...@@ -188,7 +188,7 @@ TEST(ARG_MAP, set_value) { ...@@ -188,7 +188,7 @@ TEST(ARG_MAP, set_value) {
{{"fp64_values", paddle::any{std::vector<double>{1}}}}, {{"fp64_values", paddle::any{std::vector<double>{1}}}},
{"Out"}, {"Out"},
{}); {});
ASSERT_EQ( EXPECT_STREQ(
(*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case6) (*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case6)
.name, .name,
"set_value"); "set_value");
...@@ -199,7 +199,7 @@ TEST(ARG_MAP, set_value) { ...@@ -199,7 +199,7 @@ TEST(ARG_MAP, set_value) {
{{"int32_values", paddle::any{std::vector<int>{1}}}}, {{"int32_values", paddle::any{std::vector<int>{1}}}},
{"Out"}, {"Out"},
{}); {});
ASSERT_EQ( EXPECT_STREQ(
(*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case7) (*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case7)
.name, .name,
"set_value"); "set_value");
...@@ -210,7 +210,7 @@ TEST(ARG_MAP, set_value) { ...@@ -210,7 +210,7 @@ TEST(ARG_MAP, set_value) {
{{"int64_values", paddle::any{std::vector<int64_t>{1}}}}, {{"int64_values", paddle::any{std::vector<int64_t>{1}}}},
{"Out"}, {"Out"},
{}); {});
ASSERT_EQ( EXPECT_STREQ(
(*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case8) (*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case8)
.name, .name,
"set_value"); "set_value");
...@@ -221,7 +221,7 @@ TEST(ARG_MAP, set_value) { ...@@ -221,7 +221,7 @@ TEST(ARG_MAP, set_value) {
{{"bool_values", paddle::any{std::vector<int>{1}}}}, {{"bool_values", paddle::any{std::vector<int>{1}}}},
{"Out"}, {"Out"},
{}); {});
ASSERT_EQ( EXPECT_STREQ(
(*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case9) (*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case9)
.name, .name,
"set_value"); "set_value");
...@@ -232,7 +232,7 @@ TEST(ARG_MAP, set_value) { ...@@ -232,7 +232,7 @@ TEST(ARG_MAP, set_value) {
{}, {},
{"Out"}, {"Out"},
{}); {});
ASSERT_EQ( EXPECT_STREQ(
(*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case10) (*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case10)
.name, .name,
"set_value_with_tensor"); "set_value_with_tensor");
...@@ -243,7 +243,7 @@ TEST(ARG_MAP, set_value) { ...@@ -243,7 +243,7 @@ TEST(ARG_MAP, set_value) {
{{"fp64_values", paddle::any{std::vector<double>{1}}}}, {{"fp64_values", paddle::any{std::vector<double>{1}}}},
{"Out"}, {"Out"},
{}); {});
ASSERT_EQ( EXPECT_STREQ(
(*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case11) (*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case11)
.name, .name,
"set_value"); "set_value");
...@@ -254,7 +254,7 @@ TEST(ARG_MAP, set_value) { ...@@ -254,7 +254,7 @@ TEST(ARG_MAP, set_value) {
{{"int32_values", paddle::any{std::vector<int>{1}}}}, {{"int32_values", paddle::any{std::vector<int>{1}}}},
{"Out"}, {"Out"},
{}); {});
ASSERT_EQ( EXPECT_STREQ(
(*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case12) (*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case12)
.name, .name,
"set_value"); "set_value");
...@@ -265,7 +265,7 @@ TEST(ARG_MAP, set_value) { ...@@ -265,7 +265,7 @@ TEST(ARG_MAP, set_value) {
{{"int64_values", paddle::any{std::vector<int64_t>{1}}}}, {{"int64_values", paddle::any{std::vector<int64_t>{1}}}},
{"Out"}, {"Out"},
{}); {});
ASSERT_EQ( EXPECT_STREQ(
(*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case13) (*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case13)
.name, .name,
"set_value"); "set_value");
...@@ -276,14 +276,14 @@ TEST(ARG_MAP, set_value) { ...@@ -276,14 +276,14 @@ TEST(ARG_MAP, set_value) {
{{"bool_values", paddle::any{std::vector<int>{1}}}}, {{"bool_values", paddle::any{std::vector<int>{1}}}},
{"Out"}, {"Out"},
{}); {});
ASSERT_EQ( EXPECT_STREQ(
(*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case14) (*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case14)
.name, .name,
"set_value"); "set_value");
TestArgumentMappingContext arg_case15( TestArgumentMappingContext arg_case15(
{"Input", "StartsTensorList", "ValueTensor"}, {}, {}, {"Out"}, {}); {"Input", "StartsTensorList", "ValueTensor"}, {}, {}, {"Out"}, {});
ASSERT_EQ( EXPECT_STREQ(
(*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case15) (*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case15)
.name, .name,
"set_value_with_tensor"); "set_value_with_tensor");
...@@ -294,7 +294,7 @@ TEST(ARG_MAP, set_value) { ...@@ -294,7 +294,7 @@ TEST(ARG_MAP, set_value) {
{{"fp32_values", paddle::any{std::vector<float>{1}}}}, {{"fp32_values", paddle::any{std::vector<float>{1}}}},
{"Out"}, {"Out"},
{}); {});
ASSERT_EQ( EXPECT_STREQ(
(*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case16) (*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case16)
.name, .name,
"set_value"); "set_value");
...@@ -305,7 +305,7 @@ TEST(ARG_MAP, set_value) { ...@@ -305,7 +305,7 @@ TEST(ARG_MAP, set_value) {
{{"fp64_values", paddle::any{std::vector<double>{1}}}}, {{"fp64_values", paddle::any{std::vector<double>{1}}}},
{"Out"}, {"Out"},
{}); {});
ASSERT_EQ( EXPECT_STREQ(
(*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case17) (*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case17)
.name, .name,
"set_value"); "set_value");
...@@ -316,7 +316,7 @@ TEST(ARG_MAP, set_value) { ...@@ -316,7 +316,7 @@ TEST(ARG_MAP, set_value) {
{{"int32_values", paddle::any{std::vector<int>{1}}}}, {{"int32_values", paddle::any{std::vector<int>{1}}}},
{"Out"}, {"Out"},
{}); {});
ASSERT_EQ( EXPECT_STREQ(
(*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case18) (*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case18)
.name, .name,
"set_value"); "set_value");
...@@ -327,7 +327,7 @@ TEST(ARG_MAP, set_value) { ...@@ -327,7 +327,7 @@ TEST(ARG_MAP, set_value) {
{{"int64_values", paddle::any{std::vector<int64_t>{1}}}}, {{"int64_values", paddle::any{std::vector<int64_t>{1}}}},
{"Out"}, {"Out"},
{}); {});
ASSERT_EQ( EXPECT_STREQ(
(*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case19) (*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case19)
.name, .name,
"set_value"); "set_value");
...@@ -338,7 +338,7 @@ TEST(ARG_MAP, set_value) { ...@@ -338,7 +338,7 @@ TEST(ARG_MAP, set_value) {
{{"bool_values", paddle::any{std::vector<int>{1}}}}, {{"bool_values", paddle::any{std::vector<int>{1}}}},
{"Out"}, {"Out"},
{}); {});
ASSERT_EQ( EXPECT_STREQ(
(*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case20) (*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case20)
.name, .name,
"set_value"); "set_value");
...@@ -349,7 +349,7 @@ TEST(ARG_MAP, set_value) { ...@@ -349,7 +349,7 @@ TEST(ARG_MAP, set_value) {
{}, {},
{"Out"}, {"Out"},
{}); {});
ASSERT_EQ( EXPECT_STREQ(
(*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case21) (*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case21)
.name, .name,
"set_value_with_tensor"); "set_value_with_tensor");
...@@ -360,7 +360,7 @@ TEST(ARG_MAP, set_value) { ...@@ -360,7 +360,7 @@ TEST(ARG_MAP, set_value) {
{{"fp64_values", paddle::any{std::vector<double>{1}}}}, {{"fp64_values", paddle::any{std::vector<double>{1}}}},
{"Out"}, {"Out"},
{}); {});
ASSERT_EQ( EXPECT_STREQ(
(*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case22) (*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case22)
.name, .name,
"set_value"); "set_value");
...@@ -371,7 +371,7 @@ TEST(ARG_MAP, set_value) { ...@@ -371,7 +371,7 @@ TEST(ARG_MAP, set_value) {
{{"int32_values", paddle::any{std::vector<int>{1}}}}, {{"int32_values", paddle::any{std::vector<int>{1}}}},
{"Out"}, {"Out"},
{}); {});
ASSERT_EQ( EXPECT_STREQ(
(*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case23) (*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case23)
.name, .name,
"set_value"); "set_value");
...@@ -382,7 +382,7 @@ TEST(ARG_MAP, set_value) { ...@@ -382,7 +382,7 @@ TEST(ARG_MAP, set_value) {
{{"int64_values", paddle::any{std::vector<int64_t>{1}}}}, {{"int64_values", paddle::any{std::vector<int64_t>{1}}}},
{"Out"}, {"Out"},
{}); {});
ASSERT_EQ( EXPECT_STREQ(
(*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case24) (*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case24)
.name, .name,
"set_value"); "set_value");
...@@ -393,14 +393,14 @@ TEST(ARG_MAP, set_value) { ...@@ -393,14 +393,14 @@ TEST(ARG_MAP, set_value) {
{{"bool_values", paddle::any{std::vector<int>{1}}}}, {{"bool_values", paddle::any{std::vector<int>{1}}}},
{"Out"}, {"Out"},
{}); {});
ASSERT_EQ( EXPECT_STREQ(
(*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case25) (*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case25)
.name, .name,
"set_value"); "set_value");
TestArgumentMappingContext arg_case26( TestArgumentMappingContext arg_case26(
{"Input", "EndsTensorList", "ValueTensor"}, {}, {}, {"Out"}, {}); {"Input", "EndsTensorList", "ValueTensor"}, {}, {}, {"Out"}, {});
ASSERT_EQ( EXPECT_STREQ(
(*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case26) (*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case26)
.name, .name,
"set_value_with_tensor"); "set_value_with_tensor");
...@@ -411,7 +411,7 @@ TEST(ARG_MAP, set_value) { ...@@ -411,7 +411,7 @@ TEST(ARG_MAP, set_value) {
{{"fp32_values", paddle::any{std::vector<float>{1}}}}, {{"fp32_values", paddle::any{std::vector<float>{1}}}},
{"Out"}, {"Out"},
{}); {});
ASSERT_EQ( EXPECT_STREQ(
(*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case27) (*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case27)
.name, .name,
"set_value"); "set_value");
...@@ -422,7 +422,7 @@ TEST(ARG_MAP, set_value) { ...@@ -422,7 +422,7 @@ TEST(ARG_MAP, set_value) {
{{"fp64_values", paddle::any{std::vector<double>{1}}}}, {{"fp64_values", paddle::any{std::vector<double>{1}}}},
{"Out"}, {"Out"},
{}); {});
ASSERT_EQ( EXPECT_STREQ(
(*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case28) (*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case28)
.name, .name,
"set_value"); "set_value");
...@@ -433,7 +433,7 @@ TEST(ARG_MAP, set_value) { ...@@ -433,7 +433,7 @@ TEST(ARG_MAP, set_value) {
{{"int32_values", paddle::any{std::vector<int>{1}}}}, {{"int32_values", paddle::any{std::vector<int>{1}}}},
{"Out"}, {"Out"},
{}); {});
ASSERT_EQ( EXPECT_STREQ(
(*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case29) (*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case29)
.name, .name,
"set_value"); "set_value");
...@@ -444,7 +444,7 @@ TEST(ARG_MAP, set_value) { ...@@ -444,7 +444,7 @@ TEST(ARG_MAP, set_value) {
{{"int64_values", paddle::any{std::vector<int64_t>{1}}}}, {{"int64_values", paddle::any{std::vector<int64_t>{1}}}},
{"Out"}, {"Out"},
{}); {});
ASSERT_EQ( EXPECT_STREQ(
(*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case30) (*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case30)
.name, .name,
"set_value"); "set_value");
...@@ -455,14 +455,14 @@ TEST(ARG_MAP, set_value) { ...@@ -455,14 +455,14 @@ TEST(ARG_MAP, set_value) {
{{"bool_values", paddle::any{std::vector<int>{1}}}}, {{"bool_values", paddle::any{std::vector<int>{1}}}},
{"Out"}, {"Out"},
{}); {});
ASSERT_EQ( EXPECT_STREQ(
(*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case31) (*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case31)
.name, .name,
"set_value"); "set_value");
TestArgumentMappingContext arg_case32( TestArgumentMappingContext arg_case32(
{"Input", "StepsTensorList", "ValueTensor"}, {}, {}, {"Out"}, {}); {"Input", "StepsTensorList", "ValueTensor"}, {}, {}, {"Out"}, {});
ASSERT_EQ( EXPECT_STREQ(
(*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case32) (*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case32)
.name, .name,
"set_value_with_tensor"); "set_value_with_tensor");
...@@ -473,7 +473,7 @@ TEST(ARG_MAP, set_value) { ...@@ -473,7 +473,7 @@ TEST(ARG_MAP, set_value) {
{{"fp32_values", paddle::any{std::vector<float>{1}}}}, {{"fp32_values", paddle::any{std::vector<float>{1}}}},
{"Out"}, {"Out"},
{}); {});
ASSERT_EQ( EXPECT_STREQ(
(*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case33) (*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case33)
.name, .name,
"set_value"); "set_value");
...@@ -484,7 +484,7 @@ TEST(ARG_MAP, set_value) { ...@@ -484,7 +484,7 @@ TEST(ARG_MAP, set_value) {
{{"fp64_values", paddle::any{std::vector<double>{1}}}}, {{"fp64_values", paddle::any{std::vector<double>{1}}}},
{"Out"}, {"Out"},
{}); {});
ASSERT_EQ( EXPECT_STREQ(
(*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case34) (*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case34)
.name, .name,
"set_value"); "set_value");
...@@ -495,7 +495,7 @@ TEST(ARG_MAP, set_value) { ...@@ -495,7 +495,7 @@ TEST(ARG_MAP, set_value) {
{{"int32_values", paddle::any{std::vector<int>{1}}}}, {{"int32_values", paddle::any{std::vector<int>{1}}}},
{"Out"}, {"Out"},
{}); {});
ASSERT_EQ( EXPECT_STREQ(
(*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case35) (*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case35)
.name, .name,
"set_value"); "set_value");
...@@ -506,7 +506,7 @@ TEST(ARG_MAP, set_value) { ...@@ -506,7 +506,7 @@ TEST(ARG_MAP, set_value) {
{{"int64_values", paddle::any{std::vector<int64_t>{1}}}}, {{"int64_values", paddle::any{std::vector<int64_t>{1}}}},
{"Out"}, {"Out"},
{}); {});
ASSERT_EQ( EXPECT_STREQ(
(*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case36) (*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case36)
.name, .name,
"set_value"); "set_value");
...@@ -517,7 +517,7 @@ TEST(ARG_MAP, set_value) { ...@@ -517,7 +517,7 @@ TEST(ARG_MAP, set_value) {
{{"bool_values", paddle::any{std::vector<int>{1}}}}, {{"bool_values", paddle::any{std::vector<int>{1}}}},
{"Out"}, {"Out"},
{}); {});
ASSERT_EQ( EXPECT_STREQ(
(*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case37) (*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case37)
.name, .name,
"set_value"); "set_value");
...@@ -530,7 +530,7 @@ TEST(ARG_MAP, set_value_grad) { ...@@ -530,7 +530,7 @@ TEST(ARG_MAP, set_value_grad) {
{}, {},
{"Input@GRAD", "ValueTensor@GRAD"}, {"Input@GRAD", "ValueTensor@GRAD"},
{}); {});
ASSERT_EQ( EXPECT_STREQ(
(*OpUtilsMap::Instance().GetArgumentMappingFn("set_value_grad"))(arg_case) (*OpUtilsMap::Instance().GetArgumentMappingFn("set_value_grad"))(arg_case)
.name, .name,
"set_value_grad"); "set_value_grad");
...@@ -541,20 +541,20 @@ TEST(ARG_MAP, set_value_grad) { ...@@ -541,20 +541,20 @@ TEST(ARG_MAP, set_value_grad) {
{}, {},
{"Input@GRAD", "ValueTensor@GRAD"}, {"Input@GRAD", "ValueTensor@GRAD"},
{}); {});
ASSERT_EQ((*OpUtilsMap::Instance().GetArgumentMappingFn("set_value_grad"))( EXPECT_STREQ((*OpUtilsMap::Instance().GetArgumentMappingFn("set_value_grad"))(
arg_case1) arg_case1)
.name, .name,
"set_value_grad"); "set_value_grad");
TestArgumentMappingContext arg_case2({"Out@GRAD", "StartsTensorList"}, TestArgumentMappingContext arg_case2({"Out@GRAD", "StartsTensorList"},
{}, {},
{}, {},
{"Input@GRAD", "ValueTensor@GRAD"}, {"Input@GRAD", "ValueTensor@GRAD"},
{}); {});
ASSERT_EQ((*OpUtilsMap::Instance().GetArgumentMappingFn("set_value_grad"))( EXPECT_STREQ((*OpUtilsMap::Instance().GetArgumentMappingFn("set_value_grad"))(
arg_case2) arg_case2)
.name, .name,
"set_value_grad"); "set_value_grad");
TestArgumentMappingContext arg_case3( TestArgumentMappingContext arg_case3(
{"Out@GRAD", "EndsTensorList", "StepsTensorList"}, {"Out@GRAD", "EndsTensorList", "StepsTensorList"},
...@@ -562,30 +562,30 @@ TEST(ARG_MAP, set_value_grad) { ...@@ -562,30 +562,30 @@ TEST(ARG_MAP, set_value_grad) {
{}, {},
{"Input@GRAD", "ValueTensor@GRAD"}, {"Input@GRAD", "ValueTensor@GRAD"},
{}); {});
ASSERT_EQ((*OpUtilsMap::Instance().GetArgumentMappingFn("set_value_grad"))( EXPECT_STREQ((*OpUtilsMap::Instance().GetArgumentMappingFn("set_value_grad"))(
arg_case3) arg_case3)
.name, .name,
"set_value_grad"); "set_value_grad");
TestArgumentMappingContext arg_case4({"Out@GRAD", "EndsTensorList"}, TestArgumentMappingContext arg_case4({"Out@GRAD", "EndsTensorList"},
{}, {},
{}, {},
{"Input@GRAD", "ValueTensor@GRAD"}, {"Input@GRAD", "ValueTensor@GRAD"},
{}); {});
ASSERT_EQ((*OpUtilsMap::Instance().GetArgumentMappingFn("set_value_grad"))( EXPECT_STREQ((*OpUtilsMap::Instance().GetArgumentMappingFn("set_value_grad"))(
arg_case4) arg_case4)
.name, .name,
"set_value_grad"); "set_value_grad");
TestArgumentMappingContext arg_case5({"Out@GRAD", "StepsTensorList"}, TestArgumentMappingContext arg_case5({"Out@GRAD", "StepsTensorList"},
{}, {},
{}, {},
{"Input@GRAD", "ValueTensor@GRAD"}, {"Input@GRAD", "ValueTensor@GRAD"},
{}); {});
ASSERT_EQ((*OpUtilsMap::Instance().GetArgumentMappingFn("set_value_grad"))( EXPECT_STREQ((*OpUtilsMap::Instance().GetArgumentMappingFn("set_value_grad"))(
arg_case5) arg_case5)
.name, .name,
"set_value_grad"); "set_value_grad");
} }
TEST(ARG_MAP, allclose) { TEST(ARG_MAP, allclose) {
...@@ -598,8 +598,8 @@ TEST(ARG_MAP, allclose) { ...@@ -598,8 +598,8 @@ TEST(ARG_MAP, allclose) {
{}); {});
auto signature1 = auto signature1 =
(*OpUtilsMap::Instance().GetArgumentMappingFn("allclose"))(arg_case1); (*OpUtilsMap::Instance().GetArgumentMappingFn("allclose"))(arg_case1);
ASSERT_EQ(signature1.name, "allclose"); EXPECT_STREQ(signature1.name, "allclose");
ASSERT_EQ(signature1.attr_names[0], "Rtol"); EXPECT_STREQ(signature1.attr_names[0], "Rtol");
TestArgumentMappingContext arg_case2( TestArgumentMappingContext arg_case2(
{"Input", "Other", "Atol"}, {"Input", "Other", "Atol"},
...@@ -610,26 +610,26 @@ TEST(ARG_MAP, allclose) { ...@@ -610,26 +610,26 @@ TEST(ARG_MAP, allclose) {
{}); {});
auto signature2 = auto signature2 =
(*OpUtilsMap::Instance().GetArgumentMappingFn("allclose"))(arg_case2); (*OpUtilsMap::Instance().GetArgumentMappingFn("allclose"))(arg_case2);
ASSERT_EQ(signature2.name, "allclose"); EXPECT_STREQ(signature2.name, "allclose");
ASSERT_EQ(signature2.attr_names[1], "Atol"); EXPECT_STREQ(signature2.attr_names[1], "Atol");
} }
TEST(ARG_MAP, reshape) { TEST(ARG_MAP, reshape) {
TestArgumentMappingContext arg_case1({"X", "ShapeTensor"}, {}, {}, {"Out"}); TestArgumentMappingContext arg_case1({"X", "ShapeTensor"}, {}, {}, {"Out"});
auto signature1 = auto signature1 =
(*OpUtilsMap::Instance().GetArgumentMappingFn("reshape2"))(arg_case1); (*OpUtilsMap::Instance().GetArgumentMappingFn("reshape2"))(arg_case1);
ASSERT_EQ(signature1.name, "reshape"); EXPECT_STREQ(signature1.name, "reshape");
TestArgumentMappingContext arg_case2({"X", "Shape"}, {}, {}, {"Out"}); TestArgumentMappingContext arg_case2({"X", "Shape"}, {}, {}, {"Out"});
auto signature2 = auto signature2 =
(*OpUtilsMap::Instance().GetArgumentMappingFn("reshape2"))(arg_case2); (*OpUtilsMap::Instance().GetArgumentMappingFn("reshape2"))(arg_case2);
ASSERT_EQ(signature2.name, "reshape"); EXPECT_STREQ(signature2.name, "reshape");
TestArgumentMappingContext arg_case3( TestArgumentMappingContext arg_case3(
{"X"}, {}, {{"shape", paddle::any(std::vector<int>({1, 2}))}}, {"Out"}); {"X"}, {}, {{"shape", paddle::any(std::vector<int>({1, 2}))}}, {"Out"});
auto signature3 = auto signature3 =
(*OpUtilsMap::Instance().GetArgumentMappingFn("reshape2"))(arg_case3); (*OpUtilsMap::Instance().GetArgumentMappingFn("reshape2"))(arg_case3);
ASSERT_EQ(signature3.name, "reshape"); EXPECT_STREQ(signature3.name, "reshape");
} }
} // namespace tests } // namespace tests
......
...@@ -18,4 +18,12 @@ if(WITH_TESTING) ...@@ -18,4 +18,12 @@ if(WITH_TESTING)
paddle_gtest_main paddle_gtest_main
SRCS paddle_gtest_main.cc SRCS paddle_gtest_main.cc
DEPS ${paddle_gtest_main_deps}) DEPS ${paddle_gtest_main_deps})
cc_library(
paddle_gtest_main_new
SRCS paddle_gtest_main.cc
DEPS gtest xxhash framework_proto eigen3 dlpack)
if(WITH_MKLDNN)
add_dependencies(paddle_gtest_main_new mkldnn)
endif()
endif() endif()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册