From 5d43f5e4f45c4eaf769d9682d997e7c1954e3040 Mon Sep 17 00:00:00 2001 From: Aurelius84 Date: Thu, 24 Aug 2023 10:55:57 +0800 Subject: [PATCH] [NewIR]Add NOT_FOR_INFER to prune Inference Library Size and Split VJP CodeGen into pd_op_vjp.cc (#56352) * [NewIR]Prune Inference Library Size and Remove IR Dialect * remove options * add NOT_FOR_INFER * fix pd_vjp.cc * polish deps * fix code style * fix unittest * fix cmake * fix inference CI --- cmake/generic.cmake | 26 +++++++++- paddle/fluid/framework/details/CMakeLists.txt | 2 +- .../garbage_collector/CMakeLists.txt | 2 +- .../new_executor/instruction/CMakeLists.txt | 2 +- paddle/fluid/inference/CMakeLists.txt | 34 ++++++------- .../fluid/ir/dialect/op_generator/op_gen.py | 48 ++++++++++++++----- .../dialect/op_generator/op_interface_gen.py | 8 ++-- .../dialect/paddle_dialect/ir/CMakeLists.txt | 15 ++++-- .../ir_adaptor/translator/op_translator.cc | 4 ++ paddle/fluid/jit/CMakeLists.txt | 2 +- paddle/fluid/platform/CMakeLists.txt | 2 +- paddle/fluid/platform/init.cc | 7 --- test/cpp/ir/core/type_test.cc | 3 ++ test/cpp/prim/test_vjp.cc | 6 +++ 14 files changed, 110 insertions(+), 51 deletions(-) diff --git a/cmake/generic.cmake b/cmake/generic.cmake index 947d44950d5..077db75fde2 100644 --- a/cmake/generic.cmake +++ b/cmake/generic.cmake @@ -118,6 +118,19 @@ function(find_fluid_modules TARGET_NAME) endif() endfunction() +# NOTE(Aurelius84): NOT_INFER_MODULES is used to tag +# and not considered as DEPS for inference libs. +set_property(GLOBAL PROPERTY NOT_INFER_MODULES "") + +function(ignore_infer_modules TARGET_NAME) + get_property(not_infer_modules GLOBAL PROPERTY NOT_INFER_MODULES) + list(FIND not_infer_modules TARGET_NAME is_found) + if(is_found EQUAL -1) # NOT FOUND + set(not_infer_modules ${not_infer_modules} ${TARGET_NAME}) + set_property(GLOBAL PROPERTY NOT_INFER_MODULES "${not_infer_modules}") + endif() +endfunction() + set_property(GLOBAL PROPERTY PHI_MODULES "") # find all phi modules is used for paddle static library # for building inference libs @@ -335,7 +348,15 @@ function(check_coverage_opt TARGET_NAME SRCS) endfunction() function(cc_library TARGET_NAME) - set(options STATIC static SHARED shared INTERFACE interface) + set(options + STATIC + static + SHARED + shared + INTERFACE + interface + NOT_FOR_INFER + not_for_infer) set(oneValueArgs "") set(multiValueArgs SRCS DEPS) cmake_parse_arguments(cc_library "${options}" "${oneValueArgs}" @@ -347,6 +368,9 @@ function(cc_library TARGET_NAME) CACHE STRING "output library name for target ${TARGET_NAME}") endif() if(cc_library_SRCS) + if(cc_library_NOT_FOR_INFER OR cc_library_not_for_infer) + ignore_infer_modules(${TARGET_NAME}) + endif() if(cc_library_SHARED OR cc_library_shared) # build *.so add_library(${TARGET_NAME} SHARED ${cc_library_SRCS}) elseif(cc_library_INTERFACE OR cc_library_interface) diff --git a/paddle/fluid/framework/details/CMakeLists.txt b/paddle/fluid/framework/details/CMakeLists.txt index 4d9a88cf223..ded28eaf5cc 100644 --- a/paddle/fluid/framework/details/CMakeLists.txt +++ b/paddle/fluid/framework/details/CMakeLists.txt @@ -276,7 +276,7 @@ set(SSA_GRAPH_EXECUTOR_DEPS inplace_addto_op_pass set_reader_device_info_utils) cc_library( - ssa_graph_executor + ssa_graph_executor NOT_FOR_INFER SRCS ssa_graph_executor.cc DEPS ${SSA_GRAPH_EXECUTOR_DEPS}) diff --git a/paddle/fluid/framework/new_executor/garbage_collector/CMakeLists.txt b/paddle/fluid/framework/new_executor/garbage_collector/CMakeLists.txt index 340d0483fe1..f6dcd385170 100644 --- a/paddle/fluid/framework/new_executor/garbage_collector/CMakeLists.txt +++ b/paddle/fluid/framework/new_executor/garbage_collector/CMakeLists.txt @@ -1,5 +1,5 @@ cc_library( - interpretercore_garbage_collector + interpretercore_garbage_collector NOT_FOR_INFER SRCS garbage_collector.cc event_garbage_collector.cc fast_garbage_collector.cc no_event_garbage_collector.cc DEPS garbage_collector executor_gc_helper) diff --git a/paddle/fluid/framework/new_executor/instruction/CMakeLists.txt b/paddle/fluid/framework/new_executor/instruction/CMakeLists.txt index fb52d0ac4a4..17c8f8cd762 100644 --- a/paddle/fluid/framework/new_executor/instruction/CMakeLists.txt +++ b/paddle/fluid/framework/new_executor/instruction/CMakeLists.txt @@ -6,7 +6,7 @@ cc_library( if(WITH_CINN AND NOT CINN_ONLY) cc_library( - cinn_jit_instruction + cinn_jit_instruction NOT_FOR_INFER SRCS cinn_jit_instruction.cc DEPS phi cinnapi cinn_dialect) endif() diff --git a/paddle/fluid/inference/CMakeLists.txt b/paddle/fluid/inference/CMakeLists.txt index 6287a547aec..7b6175a9756 100644 --- a/paddle/fluid/inference/CMakeLists.txt +++ b/paddle/fluid/inference/CMakeLists.txt @@ -34,6 +34,7 @@ endif() get_property(fluid_modules GLOBAL PROPERTY FLUID_MODULES) get_property(phi_modules GLOBAL PROPERTY PHI_MODULES) get_property(ir_targets GLOBAL PROPERTY IR_TARGETS) +get_property(not_infer_modules GLOBAL PROPERTY NOT_INFER_MODULES) set(utils_modules pretty_log string_helper benchmark utf8proc) add_subdirectory(api) @@ -57,18 +58,16 @@ set(KERNEL_LIST "" CACHE STRING "The list of phi kernels that will be compiled") +# shared inference library deps +list(REMOVE_DUPLICATES fluid_modules) #windows GPU static library over the limit, so not create_static_lib, and cc_library is dummy if(WIN32 AND WITH_GPU) - cc_library(paddle_inference DEPS ${fluid_modules} ${ir_targets} - ${STATIC_INFERENCE_API} ${utils_modules}) + cc_library(paddle_inference DEPS ${fluid_modules} ${STATIC_INFERENCE_API} + ${utils_modules}) else() - # message("${fluid_modules}") - # message("${STATIC_INFERENCE_API}") - # message("${utils_modules}") - # message("${phi_modules}") if(WIN32) create_static_lib(paddle_inference ${phi_modules} ${fluid_modules} - ${STATIC_INFERENCE_API} ${utils_modules}) + ${ir_targets} ${STATIC_INFERENCE_API} ${utils_modules}) else() create_static_lib(paddle_inference ${phi_modules} ${fluid_modules} ${ir_targets} ${STATIC_INFERENCE_API} ${utils_modules}) @@ -82,7 +81,6 @@ if(NOT APPLE) ) set_target_properties(paddle_inference PROPERTIES LINK_FLAGS "${LINK_FLAGS}") endif() - # C inference API add_subdirectory(capi_exp) @@ -99,15 +97,17 @@ set(SHARED_INFERENCE_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/api/details/zero_copy_tensor.cc ${CMAKE_CURRENT_SOURCE_DIR}/utils/io_utils.cc) -# shared inference library deps -list(REMOVE_ITEM fluid_modules standalone_executor - interpretercore_garbage_collector) -if(WIN32) - set(SHARED_INFERENCE_DEPS phi ${fluid_modules} analysis_predictor - ${utils_modules}) -else() - set(SHARED_INFERENCE_DEPS phi ${fluid_modules} ${ir_targets} - analysis_predictor ${utils_modules}) +# NOTE(Aurelius84): For inference library, some DEPS is usless +# such as non-infer operator related targets et.al. +list(REMOVE_ITEM fluid_modules cinn_dialect) +# NOTE(Aurelisu84): Remove ir dialect related target DEPS for inference +# shared library to prune library size. +list(REMOVE_ITEM fluid_modules ${not_infer_modules}) + +set(SHARED_INFERENCE_DEPS phi ${fluid_modules} analysis_predictor + ${utils_modules}) +if(NOT WIN32) + list(APPEND SHARED_INFERENCE_DEPS ${ir_targets}) endif() if(WITH_CRYPTO) diff --git a/paddle/fluid/ir/dialect/op_generator/op_gen.py b/paddle/fluid/ir/dialect/op_generator/op_gen.py index 29d4a1b1fab..2042c626f06 100644 --- a/paddle/fluid/ir/dialect/op_generator/op_gen.py +++ b/paddle/fluid/ir/dialect/op_generator/op_gen.py @@ -13,6 +13,7 @@ # limitations under the License. import argparse +import logging import os import yaml @@ -115,13 +116,30 @@ CC_FILE_TEMPLATE = """// This file is generated by "paddle/fluid/ir/dialect/op_g #include "paddle/phi/infermeta/fusion.h" #include "paddle/phi/api/lib/utils/allocator.h" #include "paddle/fluid/primitive/rule/vjp/vjp.h" -{def_primitive} #include "paddle/ir/core/op_base.h" {input} {define_type_id} """ +# ===================================== +# String Template for pd_op_vjp.cc file code gen +# ===================================== +VJP_CC_FILE_TEMPLATE = """// This file is generated by "paddle/fluid/ir/dialect/op_generator/op_gen.py" +#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_attribute.h" +#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_op.h" +#include "paddle/fluid/primitive/rule/vjp/vjp.h" +#include "paddle/fluid/primitive/type/lazy_tensor.h" +#include "paddle/ir/core/builtin_op.h" +#include "paddle/ir/core/op_base.h" +#include "paddle/phi/common/int_array.h" + +namespace paddle {{ +namespace dialect {{ +{input} +}} // namespace dialect +}} // namespace paddle +""" OP_N_ATTRIBUTE_DEFINED_TEMPLATE = """ const char *{op_name}::attributes_name[{attribute_num}] = {{ {attribute_names} }}; @@ -681,6 +699,7 @@ def OpGenerator( dialect_name, op_def_h_file, op_def_cc_file, + op_vjp_cc_file, ): # (1) Prepare: Delete existing old files: pd_op.h.tmp, pd_op.cc.tmp if os.path.exists(op_def_h_file): @@ -705,6 +724,7 @@ def OpGenerator( ops_name_list = [] # all op class name store in this list ops_declare_list = [] # all op class declare store in this list ops_defined_list = [] # all op class defined store in this list + ops_vjp_defined_list = [] # all op vjp static interface defination for key, op_info in op_info_items.items(): # get op inputs info op_input_name_list = op_info.input_name_list @@ -1078,11 +1098,9 @@ def OpGenerator( ops_defined_list.append(op_infer_meta_str) # NOTE(chenxi67)skip if dialect_name==cinn if dialect_name == "cinn": - import logging - logging.warning("cinn is currently not support Vjp function") else: - ops_defined_list.append(op_vjp_str) + ops_vjp_defined_list.append(op_vjp_str) # (4) Generate head file str op_namespaces_prev = "" @@ -1122,24 +1140,25 @@ def OpGenerator( for op in ops_name_with_namespace_list: define_type_id_str += DEFINE_OP_TYPE_ID.format(op_name=op) - # NOTE(chenxi67) Skip include this header file if dialect_name == cinn - # otherwise we may get compile error when compile with "ncclDataType_t" - def_primitive_str = "#include \"paddle/fluid/primitive/type/lazy_tensor.h\"" - if dialect_name == "cinn": - def_primitive_str = "" - source_file_str = CC_FILE_TEMPLATE.format( h_file=op_def_h_file[:-4], - def_primitive=def_primitive_str, input=source_file_str, define_type_id=define_type_id_str, ) # Add head + vjp_source_file_str = VJP_CC_FILE_TEMPLATE.format( + input="".join(ops_vjp_defined_list) + ) # (5) Generate pd_op.h.tmp, pd_op.cc.tmp - with open(op_def_h_file, 'a') as f: + with open(op_def_h_file, 'w') as f: f.write(head_file_str) - with open(op_def_cc_file, 'a') as f: + with open(op_def_cc_file, 'w') as f: f.write(source_file_str) + # NOTE(Aurelius84): op_gen.py is called multiply times, + # and vjp is only avaible for pd dialect. + if dialect_name != 'cinn' and op_vjp_cc_file: + with open(op_vjp_cc_file, 'w') as f: + f.write(vjp_source_file_str) # ===================================== @@ -1155,6 +1174,7 @@ def ParseArguments(): parser.add_argument('--dialect_name', type=str) parser.add_argument('--op_def_h_file', type=str) parser.add_argument('--op_def_cc_file', type=str) + parser.add_argument('--op_vjp_cc_file', type=str) return parser.parse_args() @@ -1172,6 +1192,7 @@ if __name__ == "__main__": dialect_name = args.dialect_name op_def_h_file = args.op_def_h_file op_def_cc_file = args.op_def_cc_file + op_vjp_cc_file = args.op_vjp_cc_file # auto code generate OpGenerator( @@ -1181,4 +1202,5 @@ if __name__ == "__main__": dialect_name, op_def_h_file, op_def_cc_file, + op_vjp_cc_file, ) diff --git a/paddle/fluid/ir/dialect/op_generator/op_interface_gen.py b/paddle/fluid/ir/dialect/op_generator/op_interface_gen.py index 8762c6328e1..48078e8c432 100644 --- a/paddle/fluid/ir/dialect/op_generator/op_interface_gen.py +++ b/paddle/fluid/ir/dialect/op_generator/op_interface_gen.py @@ -57,17 +57,17 @@ OP_VJP_DEFINE_TEMPLATE = """ std::vector> {op_class_name}::Vjp(ir::Operation* op, const std::vector>& out_grads, const std::vector>& stop_gradients){{ {op_class_name} op_obj = op->dyn_cast<{op_class_name}>(); -VLOG(6) << "Prepare inputs of {op_grad_name}"; + VLOG(6) << "Prepare inputs of {op_grad_name}"; {forward_input_code} {forward_output_grad_code} -VLOG(6) << "Vjp prepare Prepare attributes of {op_grad_name}"; + VLOG(6) << "Vjp prepare Prepare attributes of {op_grad_name}"; {attribute_code} -VLOG(4) << "Vjp prepare call {op_phi_name}'s vjp inteface"; + VLOG(4) << "Vjp prepare call {op_phi_name}'s vjp inteface"; {call_vjp_code} -VLOG(4) << "Vjp prepare stop gradient of {op_grad_name}"; + VLOG(4) << "Vjp prepare stop gradient of {op_grad_name}"; {stop_gradient_input_grad_code} return res; }} diff --git a/paddle/fluid/ir/dialect/paddle_dialect/ir/CMakeLists.txt b/paddle/fluid/ir/dialect/paddle_dialect/ir/CMakeLists.txt index 69ffb2fcebb..33613157397 100644 --- a/paddle/fluid/ir/dialect/paddle_dialect/ir/CMakeLists.txt +++ b/paddle/fluid/ir/dialect/paddle_dialect/ir/CMakeLists.txt @@ -36,18 +36,25 @@ set(op_source_file ${PD_DIALECT_BINARY_DIR}/pd_op.cc) set(op_header_file_tmp ${op_header_file}.tmp) set(op_source_file_tmp ${op_source_file}.tmp) +set(op_vjp_source_file ${PD_DIALECT_BINARY_DIR}/pd_op_vjp.cc) +set(op_vjp_source_file_tmp ${op_vjp_source_file}.tmp) + add_custom_command( - OUTPUT ${op_header_file} ${op_source_file} + OUTPUT ${op_header_file} ${op_source_file} ${op_vjp_source_file} COMMAND ${PYTHON_EXECUTABLE} ${op_gen_file} --op_yaml_files ${op_yaml_files} --op_compat_yaml_file ${op_compat_yaml_file} --namespaces ${op_namespace} --dialect_name ${dialect_name} --op_def_h_file ${op_header_file_tmp} - --op_def_cc_file ${op_source_file_tmp} + --op_def_cc_file ${op_source_file_tmp} --op_vjp_cc_file + ${op_vjp_source_file_tmp} COMMAND ${CMAKE_COMMAND} -E copy_if_different ${op_header_file_tmp} ${op_header_file} COMMAND ${CMAKE_COMMAND} -E copy_if_different ${op_source_file_tmp} ${op_source_file} - COMMENT "copy_if_different ${op_header_file} ${op_source_file}" + COMMAND ${CMAKE_COMMAND} -E copy_if_different ${op_vjp_source_file_tmp} + ${op_vjp_source_file} + COMMENT + "copy_if_different ${op_header_file} ${op_source_file} ${op_vjp_source_file}" DEPENDS ${op_gen_file} ${op_forward_yaml_file1} ${op_forward_yaml_file2} ${op_backward_yaml_file1} ${op_backward_yaml_file2} ${op_compat_yaml_file} @@ -98,6 +105,6 @@ target_include_directories(pd_dialect_api PRIVATE ${PD_DIALECT_BINARY_DIR}) cc_library( pd_dialect - SRCS pd_dialect.cc pd_op_vjp_manual.cc + SRCS pd_dialect.cc pd_op_vjp_manual.cc ${op_vjp_source_file} DEPS pd_dialect_api param_to_variable primitive_vjp_experimental pd_dialect_utils op_yaml_info_parser) diff --git a/paddle/fluid/ir_adaptor/translator/op_translator.cc b/paddle/fluid/ir_adaptor/translator/op_translator.cc index cd7f5029a44..1805ae21bf1 100644 --- a/paddle/fluid/ir_adaptor/translator/op_translator.cc +++ b/paddle/fluid/ir_adaptor/translator/op_translator.cc @@ -25,6 +25,7 @@ #include "paddle/fluid/framework/op_desc.h" #include "paddle/fluid/ir/dialect/paddle_dialect/interface/op_yaml_info.h" #include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_attribute.h" +#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_dialect.h" #include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_type.h" #include "paddle/fluid/ir/dialect/paddle_dialect/utils/utils.h" #include "paddle/fluid/ir_adaptor/translator/attribute_translator.h" @@ -1634,6 +1635,9 @@ struct ElementwiseGradTranscriber : public OpTranscriber { }; OpTranslator::OpTranslator() { + ir::IrContext* ctx = ir::IrContext::Instance(); + ctx->GetOrRegisterDialect(); + general_handler = OpTranscriber(); special_handlers["add_n"] = AddNOpTranscriber(); special_handlers["assign_value"] = AssignValueOpTranscriber(); diff --git a/paddle/fluid/jit/CMakeLists.txt b/paddle/fluid/jit/CMakeLists.txt index 7730b6d8fe0..16d9b080b3e 100644 --- a/paddle/fluid/jit/CMakeLists.txt +++ b/paddle/fluid/jit/CMakeLists.txt @@ -24,7 +24,7 @@ cc_library( cc_library( jit_compilation_unit SRCS compilation_unit.cc - DEPS proto_desc executor parallel_executor executor_cache) + DEPS proto_desc executor parallel_executor) cc_library( jit_function_schema diff --git a/paddle/fluid/platform/CMakeLists.txt b/paddle/fluid/platform/CMakeLists.txt index c215611f9e5..0cca954a627 100644 --- a/paddle/fluid/platform/CMakeLists.txt +++ b/paddle/fluid/platform/CMakeLists.txt @@ -108,7 +108,7 @@ endif() cc_library( init SRCS init.cc - DEPS device_context phi memcpy pd_dialect) + DEPS device_context phi memcpy) # memcpy depends on device_context, here add deps individually for # avoiding cycle dependencies diff --git a/paddle/fluid/platform/init.cc b/paddle/fluid/platform/init.cc index 34f261ac213..2ae413db5e6 100644 --- a/paddle/fluid/platform/init.cc +++ b/paddle/fluid/platform/init.cc @@ -51,13 +51,9 @@ limitations under the License. */ #include "paddle/fluid/platform/device/ipu/ipu_info.h" #endif -#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_dialect.h" #include "paddle/fluid/memory/allocation/allocator_facade.h" #include "paddle/fluid/memory/memory.h" #include "paddle/fluid/platform/flags.h" -#include "paddle/ir/core/builtin_dialect.h" -#include "paddle/ir/core/ir_context.h" -#include "paddle/ir/core/program.h" #include "paddle/phi/common/memory_utils.h" #include "paddle/phi/core/custom_kernel.h" @@ -202,9 +198,6 @@ void InitDevices() { } void InitDevices(const std::vector devices) { - ir::IrContext *ctx = ir::IrContext::Instance(); - ctx->GetOrRegisterDialect(); - std::vector places; for (auto device : devices) { diff --git a/test/cpp/ir/core/type_test.cc b/test/cpp/ir/core/type_test.cc index 62d6d48941a..2def5aa3d17 100644 --- a/test/cpp/ir/core/type_test.cc +++ b/test/cpp/ir/core/type_test.cc @@ -15,6 +15,7 @@ #include #include +#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_dialect.h" #include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_type.h" #include "paddle/ir/core/builtin_dialect.h" #include "paddle/ir/core/builtin_type.h" @@ -209,6 +210,7 @@ IR_DEFINE_EXPLICIT_TYPE_ID(IntegerDialect) TEST(type_test, custom_type_dialect) { ir::IrContext *ctx = ir::IrContext::Instance(); + ctx->GetOrRegisterDialect(); // Test 1: Test the function of IrContext to register Dialect. ctx->GetOrRegisterDialect(); @@ -240,6 +242,7 @@ TEST(type_test, custom_type_dialect) { TEST(type_test, pd_dialect) { ir::IrContext *ctx = ir::IrContext::Instance(); + ctx->GetOrRegisterDialect(); ir::Type fp32_dtype = ir::Float32Type::get(ctx); phi::DDim dims = {2, 2}; phi::DataLayout data_layout = phi::DataLayout::NCHW; diff --git a/test/cpp/prim/test_vjp.cc b/test/cpp/prim/test_vjp.cc index 783de29b9ef..667690472c2 100644 --- a/test/cpp/prim/test_vjp.cc +++ b/test/cpp/prim/test_vjp.cc @@ -44,6 +44,7 @@ namespace framework { TEST(VJP, TanhBackwardTest) { ir::IrContext* ctx = ir::IrContext::Instance(); + ctx->GetOrRegisterDialect(); ir::Program program((ctx)); paddle::dialect::APIBuilder::Instance().SetProgram(&program); @@ -98,6 +99,7 @@ TEST(VJP, TanhBackwardTest) { TEST(VJP, Tanh_BackwardTest) { ir::IrContext* ctx = ir::IrContext::Instance(); + ctx->GetOrRegisterDialect(); ir::Program program((ctx)); paddle::dialect::APIBuilder::Instance().SetProgram(&program); @@ -152,6 +154,7 @@ TEST(VJP, Tanh_BackwardTest) { TEST(VJP, MeanBackwardTest) { ir::IrContext* ctx = ir::IrContext::Instance(); + ctx->GetOrRegisterDialect(); ir::Program program((ctx)); paddle::dialect::APIBuilder::Instance().SetProgram(&program); @@ -208,6 +211,7 @@ TEST(VJP, MeanBackwardTest) { TEST(VJP, ConcatBackwardTest) { ir::IrContext* ctx = ir::IrContext::Instance(); + ctx->GetOrRegisterDialect(); ir::Program program((ctx)); paddle::dialect::APIBuilder::Instance().SetProgram(&program); @@ -270,6 +274,7 @@ TEST(VJP, ConcatBackwardTest) { TEST(VJP, AddBackwardTest) { ir::IrContext* ctx = ir::IrContext::Instance(); + ctx->GetOrRegisterDialect(); ir::Program program((ctx)); paddle::dialect::APIBuilder::Instance().SetProgram(&program); @@ -334,6 +339,7 @@ TEST(VJP, AddBackwardTest) { TEST(VJP, Add_BackwardTest) { ir::IrContext* ctx = ir::IrContext::Instance(); + ctx->GetOrRegisterDialect(); ir::Program program((ctx)); paddle::dialect::APIBuilder::Instance().SetProgram(&program); -- GitLab