未验证 提交 5d43f5e4 编写于 作者: A Aurelius84 提交者: GitHub

[NewIR]Add NOT_FOR_INFER to prune Inference Library Size and Split VJP CodeGen...

[NewIR]Add NOT_FOR_INFER to prune Inference Library Size and Split VJP CodeGen into pd_op_vjp.cc (#56352)

* [NewIR]Prune Inference Library Size and Remove IR Dialect

* remove options

* add NOT_FOR_INFER

* fix pd_vjp.cc

* polish deps

* fix code style

* fix unittest

* fix cmake

* fix inference CI
上级 4f652ac2
...@@ -118,6 +118,19 @@ function(find_fluid_modules TARGET_NAME) ...@@ -118,6 +118,19 @@ function(find_fluid_modules TARGET_NAME)
endif() endif()
endfunction() endfunction()
# NOTE(Aurelius84): NOT_INFER_MODULES is used to tag
# and not considered as DEPS for inference libs.
set_property(GLOBAL PROPERTY NOT_INFER_MODULES "")
function(ignore_infer_modules TARGET_NAME)
get_property(not_infer_modules GLOBAL PROPERTY NOT_INFER_MODULES)
list(FIND not_infer_modules TARGET_NAME is_found)
if(is_found EQUAL -1) # NOT FOUND
set(not_infer_modules ${not_infer_modules} ${TARGET_NAME})
set_property(GLOBAL PROPERTY NOT_INFER_MODULES "${not_infer_modules}")
endif()
endfunction()
set_property(GLOBAL PROPERTY PHI_MODULES "") set_property(GLOBAL PROPERTY PHI_MODULES "")
# find all phi modules is used for paddle static library # find all phi modules is used for paddle static library
# for building inference libs # for building inference libs
...@@ -335,7 +348,15 @@ function(check_coverage_opt TARGET_NAME SRCS) ...@@ -335,7 +348,15 @@ function(check_coverage_opt TARGET_NAME SRCS)
endfunction() endfunction()
function(cc_library TARGET_NAME) function(cc_library TARGET_NAME)
set(options STATIC static SHARED shared INTERFACE interface) set(options
STATIC
static
SHARED
shared
INTERFACE
interface
NOT_FOR_INFER
not_for_infer)
set(oneValueArgs "") set(oneValueArgs "")
set(multiValueArgs SRCS DEPS) set(multiValueArgs SRCS DEPS)
cmake_parse_arguments(cc_library "${options}" "${oneValueArgs}" cmake_parse_arguments(cc_library "${options}" "${oneValueArgs}"
...@@ -347,6 +368,9 @@ function(cc_library TARGET_NAME) ...@@ -347,6 +368,9 @@ function(cc_library TARGET_NAME)
CACHE STRING "output library name for target ${TARGET_NAME}") CACHE STRING "output library name for target ${TARGET_NAME}")
endif() endif()
if(cc_library_SRCS) if(cc_library_SRCS)
if(cc_library_NOT_FOR_INFER OR cc_library_not_for_infer)
ignore_infer_modules(${TARGET_NAME})
endif()
if(cc_library_SHARED OR cc_library_shared) # build *.so if(cc_library_SHARED OR cc_library_shared) # build *.so
add_library(${TARGET_NAME} SHARED ${cc_library_SRCS}) add_library(${TARGET_NAME} SHARED ${cc_library_SRCS})
elseif(cc_library_INTERFACE OR cc_library_interface) elseif(cc_library_INTERFACE OR cc_library_interface)
......
...@@ -276,7 +276,7 @@ set(SSA_GRAPH_EXECUTOR_DEPS ...@@ -276,7 +276,7 @@ set(SSA_GRAPH_EXECUTOR_DEPS
inplace_addto_op_pass inplace_addto_op_pass
set_reader_device_info_utils) set_reader_device_info_utils)
cc_library( cc_library(
ssa_graph_executor ssa_graph_executor NOT_FOR_INFER
SRCS ssa_graph_executor.cc SRCS ssa_graph_executor.cc
DEPS ${SSA_GRAPH_EXECUTOR_DEPS}) DEPS ${SSA_GRAPH_EXECUTOR_DEPS})
......
cc_library( cc_library(
interpretercore_garbage_collector interpretercore_garbage_collector NOT_FOR_INFER
SRCS garbage_collector.cc event_garbage_collector.cc fast_garbage_collector.cc SRCS garbage_collector.cc event_garbage_collector.cc fast_garbage_collector.cc
no_event_garbage_collector.cc no_event_garbage_collector.cc
DEPS garbage_collector executor_gc_helper) DEPS garbage_collector executor_gc_helper)
...@@ -6,7 +6,7 @@ cc_library( ...@@ -6,7 +6,7 @@ cc_library(
if(WITH_CINN AND NOT CINN_ONLY) if(WITH_CINN AND NOT CINN_ONLY)
cc_library( cc_library(
cinn_jit_instruction cinn_jit_instruction NOT_FOR_INFER
SRCS cinn_jit_instruction.cc SRCS cinn_jit_instruction.cc
DEPS phi cinnapi cinn_dialect) DEPS phi cinnapi cinn_dialect)
endif() endif()
...@@ -34,6 +34,7 @@ endif() ...@@ -34,6 +34,7 @@ endif()
get_property(fluid_modules GLOBAL PROPERTY FLUID_MODULES) get_property(fluid_modules GLOBAL PROPERTY FLUID_MODULES)
get_property(phi_modules GLOBAL PROPERTY PHI_MODULES) get_property(phi_modules GLOBAL PROPERTY PHI_MODULES)
get_property(ir_targets GLOBAL PROPERTY IR_TARGETS) get_property(ir_targets GLOBAL PROPERTY IR_TARGETS)
get_property(not_infer_modules GLOBAL PROPERTY NOT_INFER_MODULES)
set(utils_modules pretty_log string_helper benchmark utf8proc) set(utils_modules pretty_log string_helper benchmark utf8proc)
add_subdirectory(api) add_subdirectory(api)
...@@ -57,18 +58,16 @@ set(KERNEL_LIST ...@@ -57,18 +58,16 @@ set(KERNEL_LIST
"" ""
CACHE STRING "The list of phi kernels that will be compiled") CACHE STRING "The list of phi kernels that will be compiled")
# shared inference library deps
list(REMOVE_DUPLICATES fluid_modules)
#windows GPU static library over the limit, so not create_static_lib, and cc_library is dummy #windows GPU static library over the limit, so not create_static_lib, and cc_library is dummy
if(WIN32 AND WITH_GPU) if(WIN32 AND WITH_GPU)
cc_library(paddle_inference DEPS ${fluid_modules} ${ir_targets} cc_library(paddle_inference DEPS ${fluid_modules} ${STATIC_INFERENCE_API}
${STATIC_INFERENCE_API} ${utils_modules}) ${utils_modules})
else() else()
# message("${fluid_modules}")
# message("${STATIC_INFERENCE_API}")
# message("${utils_modules}")
# message("${phi_modules}")
if(WIN32) if(WIN32)
create_static_lib(paddle_inference ${phi_modules} ${fluid_modules} create_static_lib(paddle_inference ${phi_modules} ${fluid_modules}
${STATIC_INFERENCE_API} ${utils_modules}) ${ir_targets} ${STATIC_INFERENCE_API} ${utils_modules})
else() else()
create_static_lib(paddle_inference ${phi_modules} ${fluid_modules} create_static_lib(paddle_inference ${phi_modules} ${fluid_modules}
${ir_targets} ${STATIC_INFERENCE_API} ${utils_modules}) ${ir_targets} ${STATIC_INFERENCE_API} ${utils_modules})
...@@ -82,7 +81,6 @@ if(NOT APPLE) ...@@ -82,7 +81,6 @@ if(NOT APPLE)
) )
set_target_properties(paddle_inference PROPERTIES LINK_FLAGS "${LINK_FLAGS}") set_target_properties(paddle_inference PROPERTIES LINK_FLAGS "${LINK_FLAGS}")
endif() endif()
# C inference API # C inference API
add_subdirectory(capi_exp) add_subdirectory(capi_exp)
...@@ -99,15 +97,17 @@ set(SHARED_INFERENCE_SRCS ...@@ -99,15 +97,17 @@ set(SHARED_INFERENCE_SRCS
${CMAKE_CURRENT_SOURCE_DIR}/api/details/zero_copy_tensor.cc ${CMAKE_CURRENT_SOURCE_DIR}/api/details/zero_copy_tensor.cc
${CMAKE_CURRENT_SOURCE_DIR}/utils/io_utils.cc) ${CMAKE_CURRENT_SOURCE_DIR}/utils/io_utils.cc)
# shared inference library deps # NOTE(Aurelius84): For inference library, some DEPS is usless
list(REMOVE_ITEM fluid_modules standalone_executor # such as non-infer operator related targets et.al.
interpretercore_garbage_collector) list(REMOVE_ITEM fluid_modules cinn_dialect)
if(WIN32) # NOTE(Aurelisu84): Remove ir dialect related target DEPS for inference
set(SHARED_INFERENCE_DEPS phi ${fluid_modules} analysis_predictor # shared library to prune library size.
list(REMOVE_ITEM fluid_modules ${not_infer_modules})
set(SHARED_INFERENCE_DEPS phi ${fluid_modules} analysis_predictor
${utils_modules}) ${utils_modules})
else() if(NOT WIN32)
set(SHARED_INFERENCE_DEPS phi ${fluid_modules} ${ir_targets} list(APPEND SHARED_INFERENCE_DEPS ${ir_targets})
analysis_predictor ${utils_modules})
endif() endif()
if(WITH_CRYPTO) if(WITH_CRYPTO)
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
import argparse import argparse
import logging
import os import os
import yaml import yaml
...@@ -115,13 +116,30 @@ CC_FILE_TEMPLATE = """// This file is generated by "paddle/fluid/ir/dialect/op_g ...@@ -115,13 +116,30 @@ CC_FILE_TEMPLATE = """// This file is generated by "paddle/fluid/ir/dialect/op_g
#include "paddle/phi/infermeta/fusion.h" #include "paddle/phi/infermeta/fusion.h"
#include "paddle/phi/api/lib/utils/allocator.h" #include "paddle/phi/api/lib/utils/allocator.h"
#include "paddle/fluid/primitive/rule/vjp/vjp.h" #include "paddle/fluid/primitive/rule/vjp/vjp.h"
{def_primitive}
#include "paddle/ir/core/op_base.h" #include "paddle/ir/core/op_base.h"
{input} {input}
{define_type_id} {define_type_id}
""" """
# =====================================
# String Template for pd_op_vjp.cc file code gen
# =====================================
VJP_CC_FILE_TEMPLATE = """// This file is generated by "paddle/fluid/ir/dialect/op_generator/op_gen.py"
#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_attribute.h"
#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_op.h"
#include "paddle/fluid/primitive/rule/vjp/vjp.h"
#include "paddle/fluid/primitive/type/lazy_tensor.h"
#include "paddle/ir/core/builtin_op.h"
#include "paddle/ir/core/op_base.h"
#include "paddle/phi/common/int_array.h"
namespace paddle {{
namespace dialect {{
{input}
}} // namespace dialect
}} // namespace paddle
"""
OP_N_ATTRIBUTE_DEFINED_TEMPLATE = """ OP_N_ATTRIBUTE_DEFINED_TEMPLATE = """
const char *{op_name}::attributes_name[{attribute_num}] = {{ {attribute_names} }}; const char *{op_name}::attributes_name[{attribute_num}] = {{ {attribute_names} }};
...@@ -681,6 +699,7 @@ def OpGenerator( ...@@ -681,6 +699,7 @@ def OpGenerator(
dialect_name, dialect_name,
op_def_h_file, op_def_h_file,
op_def_cc_file, op_def_cc_file,
op_vjp_cc_file,
): ):
# (1) Prepare: Delete existing old files: pd_op.h.tmp, pd_op.cc.tmp # (1) Prepare: Delete existing old files: pd_op.h.tmp, pd_op.cc.tmp
if os.path.exists(op_def_h_file): if os.path.exists(op_def_h_file):
...@@ -705,6 +724,7 @@ def OpGenerator( ...@@ -705,6 +724,7 @@ def OpGenerator(
ops_name_list = [] # all op class name store in this list ops_name_list = [] # all op class name store in this list
ops_declare_list = [] # all op class declare store in this list ops_declare_list = [] # all op class declare store in this list
ops_defined_list = [] # all op class defined store in this list ops_defined_list = [] # all op class defined store in this list
ops_vjp_defined_list = [] # all op vjp static interface defination
for key, op_info in op_info_items.items(): for key, op_info in op_info_items.items():
# get op inputs info # get op inputs info
op_input_name_list = op_info.input_name_list op_input_name_list = op_info.input_name_list
...@@ -1078,11 +1098,9 @@ def OpGenerator( ...@@ -1078,11 +1098,9 @@ def OpGenerator(
ops_defined_list.append(op_infer_meta_str) ops_defined_list.append(op_infer_meta_str)
# NOTE(chenxi67)skip if dialect_name==cinn # NOTE(chenxi67)skip if dialect_name==cinn
if dialect_name == "cinn": if dialect_name == "cinn":
import logging
logging.warning("cinn is currently not support Vjp function") logging.warning("cinn is currently not support Vjp function")
else: else:
ops_defined_list.append(op_vjp_str) ops_vjp_defined_list.append(op_vjp_str)
# (4) Generate head file str # (4) Generate head file str
op_namespaces_prev = "" op_namespaces_prev = ""
...@@ -1122,24 +1140,25 @@ def OpGenerator( ...@@ -1122,24 +1140,25 @@ def OpGenerator(
for op in ops_name_with_namespace_list: for op in ops_name_with_namespace_list:
define_type_id_str += DEFINE_OP_TYPE_ID.format(op_name=op) define_type_id_str += DEFINE_OP_TYPE_ID.format(op_name=op)
# NOTE(chenxi67) Skip include this header file if dialect_name == cinn
# otherwise we may get compile error when compile with "ncclDataType_t"
def_primitive_str = "#include \"paddle/fluid/primitive/type/lazy_tensor.h\""
if dialect_name == "cinn":
def_primitive_str = ""
source_file_str = CC_FILE_TEMPLATE.format( source_file_str = CC_FILE_TEMPLATE.format(
h_file=op_def_h_file[:-4], h_file=op_def_h_file[:-4],
def_primitive=def_primitive_str,
input=source_file_str, input=source_file_str,
define_type_id=define_type_id_str, define_type_id=define_type_id_str,
) # Add head ) # Add head
vjp_source_file_str = VJP_CC_FILE_TEMPLATE.format(
input="".join(ops_vjp_defined_list)
)
# (5) Generate pd_op.h.tmp, pd_op.cc.tmp # (5) Generate pd_op.h.tmp, pd_op.cc.tmp
with open(op_def_h_file, 'a') as f: with open(op_def_h_file, 'w') as f:
f.write(head_file_str) f.write(head_file_str)
with open(op_def_cc_file, 'a') as f: with open(op_def_cc_file, 'w') as f:
f.write(source_file_str) f.write(source_file_str)
# NOTE(Aurelius84): op_gen.py is called multiply times,
# and vjp is only avaible for pd dialect.
if dialect_name != 'cinn' and op_vjp_cc_file:
with open(op_vjp_cc_file, 'w') as f:
f.write(vjp_source_file_str)
# ===================================== # =====================================
...@@ -1155,6 +1174,7 @@ def ParseArguments(): ...@@ -1155,6 +1174,7 @@ def ParseArguments():
parser.add_argument('--dialect_name', type=str) parser.add_argument('--dialect_name', type=str)
parser.add_argument('--op_def_h_file', type=str) parser.add_argument('--op_def_h_file', type=str)
parser.add_argument('--op_def_cc_file', type=str) parser.add_argument('--op_def_cc_file', type=str)
parser.add_argument('--op_vjp_cc_file', type=str)
return parser.parse_args() return parser.parse_args()
...@@ -1172,6 +1192,7 @@ if __name__ == "__main__": ...@@ -1172,6 +1192,7 @@ if __name__ == "__main__":
dialect_name = args.dialect_name dialect_name = args.dialect_name
op_def_h_file = args.op_def_h_file op_def_h_file = args.op_def_h_file
op_def_cc_file = args.op_def_cc_file op_def_cc_file = args.op_def_cc_file
op_vjp_cc_file = args.op_vjp_cc_file
# auto code generate # auto code generate
OpGenerator( OpGenerator(
...@@ -1181,4 +1202,5 @@ if __name__ == "__main__": ...@@ -1181,4 +1202,5 @@ if __name__ == "__main__":
dialect_name, dialect_name,
op_def_h_file, op_def_h_file,
op_def_cc_file, op_def_cc_file,
op_vjp_cc_file,
) )
...@@ -57,17 +57,17 @@ OP_VJP_DEFINE_TEMPLATE = """ ...@@ -57,17 +57,17 @@ OP_VJP_DEFINE_TEMPLATE = """
std::vector<std::vector<ir::OpResult>> {op_class_name}::Vjp(ir::Operation* op, const std::vector<std::vector<ir::OpResult>>& out_grads, const std::vector<std::vector<bool>>& stop_gradients){{ std::vector<std::vector<ir::OpResult>> {op_class_name}::Vjp(ir::Operation* op, const std::vector<std::vector<ir::OpResult>>& out_grads, const std::vector<std::vector<bool>>& stop_gradients){{
{op_class_name} op_obj = op->dyn_cast<{op_class_name}>(); {op_class_name} op_obj = op->dyn_cast<{op_class_name}>();
VLOG(6) << "Prepare inputs of {op_grad_name}"; VLOG(6) << "Prepare inputs of {op_grad_name}";
{forward_input_code} {forward_input_code}
{forward_output_grad_code} {forward_output_grad_code}
VLOG(6) << "Vjp prepare Prepare attributes of {op_grad_name}"; VLOG(6) << "Vjp prepare Prepare attributes of {op_grad_name}";
{attribute_code} {attribute_code}
VLOG(4) << "Vjp prepare call {op_phi_name}'s vjp inteface"; VLOG(4) << "Vjp prepare call {op_phi_name}'s vjp inteface";
{call_vjp_code} {call_vjp_code}
VLOG(4) << "Vjp prepare stop gradient of {op_grad_name}"; VLOG(4) << "Vjp prepare stop gradient of {op_grad_name}";
{stop_gradient_input_grad_code} {stop_gradient_input_grad_code}
return res; return res;
}} }}
......
...@@ -36,18 +36,25 @@ set(op_source_file ${PD_DIALECT_BINARY_DIR}/pd_op.cc) ...@@ -36,18 +36,25 @@ set(op_source_file ${PD_DIALECT_BINARY_DIR}/pd_op.cc)
set(op_header_file_tmp ${op_header_file}.tmp) set(op_header_file_tmp ${op_header_file}.tmp)
set(op_source_file_tmp ${op_source_file}.tmp) set(op_source_file_tmp ${op_source_file}.tmp)
set(op_vjp_source_file ${PD_DIALECT_BINARY_DIR}/pd_op_vjp.cc)
set(op_vjp_source_file_tmp ${op_vjp_source_file}.tmp)
add_custom_command( add_custom_command(
OUTPUT ${op_header_file} ${op_source_file} OUTPUT ${op_header_file} ${op_source_file} ${op_vjp_source_file}
COMMAND COMMAND
${PYTHON_EXECUTABLE} ${op_gen_file} --op_yaml_files ${op_yaml_files} ${PYTHON_EXECUTABLE} ${op_gen_file} --op_yaml_files ${op_yaml_files}
--op_compat_yaml_file ${op_compat_yaml_file} --namespaces ${op_namespace} --op_compat_yaml_file ${op_compat_yaml_file} --namespaces ${op_namespace}
--dialect_name ${dialect_name} --op_def_h_file ${op_header_file_tmp} --dialect_name ${dialect_name} --op_def_h_file ${op_header_file_tmp}
--op_def_cc_file ${op_source_file_tmp} --op_def_cc_file ${op_source_file_tmp} --op_vjp_cc_file
${op_vjp_source_file_tmp}
COMMAND ${CMAKE_COMMAND} -E copy_if_different ${op_header_file_tmp} COMMAND ${CMAKE_COMMAND} -E copy_if_different ${op_header_file_tmp}
${op_header_file} ${op_header_file}
COMMAND ${CMAKE_COMMAND} -E copy_if_different ${op_source_file_tmp} COMMAND ${CMAKE_COMMAND} -E copy_if_different ${op_source_file_tmp}
${op_source_file} ${op_source_file}
COMMENT "copy_if_different ${op_header_file} ${op_source_file}" COMMAND ${CMAKE_COMMAND} -E copy_if_different ${op_vjp_source_file_tmp}
${op_vjp_source_file}
COMMENT
"copy_if_different ${op_header_file} ${op_source_file} ${op_vjp_source_file}"
DEPENDS ${op_gen_file} ${op_forward_yaml_file1} ${op_forward_yaml_file2} DEPENDS ${op_gen_file} ${op_forward_yaml_file1} ${op_forward_yaml_file2}
${op_backward_yaml_file1} ${op_backward_yaml_file2} ${op_backward_yaml_file1} ${op_backward_yaml_file2}
${op_compat_yaml_file} ${op_compat_yaml_file}
...@@ -98,6 +105,6 @@ target_include_directories(pd_dialect_api PRIVATE ${PD_DIALECT_BINARY_DIR}) ...@@ -98,6 +105,6 @@ target_include_directories(pd_dialect_api PRIVATE ${PD_DIALECT_BINARY_DIR})
cc_library( cc_library(
pd_dialect pd_dialect
SRCS pd_dialect.cc pd_op_vjp_manual.cc SRCS pd_dialect.cc pd_op_vjp_manual.cc ${op_vjp_source_file}
DEPS pd_dialect_api param_to_variable primitive_vjp_experimental DEPS pd_dialect_api param_to_variable primitive_vjp_experimental
pd_dialect_utils op_yaml_info_parser) pd_dialect_utils op_yaml_info_parser)
...@@ -25,6 +25,7 @@ ...@@ -25,6 +25,7 @@
#include "paddle/fluid/framework/op_desc.h" #include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/ir/dialect/paddle_dialect/interface/op_yaml_info.h" #include "paddle/fluid/ir/dialect/paddle_dialect/interface/op_yaml_info.h"
#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_attribute.h" #include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_attribute.h"
#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_dialect.h"
#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_type.h" #include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_type.h"
#include "paddle/fluid/ir/dialect/paddle_dialect/utils/utils.h" #include "paddle/fluid/ir/dialect/paddle_dialect/utils/utils.h"
#include "paddle/fluid/ir_adaptor/translator/attribute_translator.h" #include "paddle/fluid/ir_adaptor/translator/attribute_translator.h"
...@@ -1634,6 +1635,9 @@ struct ElementwiseGradTranscriber : public OpTranscriber { ...@@ -1634,6 +1635,9 @@ struct ElementwiseGradTranscriber : public OpTranscriber {
}; };
OpTranslator::OpTranslator() { OpTranslator::OpTranslator() {
ir::IrContext* ctx = ir::IrContext::Instance();
ctx->GetOrRegisterDialect<paddle::dialect::PaddleDialect>();
general_handler = OpTranscriber(); general_handler = OpTranscriber();
special_handlers["add_n"] = AddNOpTranscriber(); special_handlers["add_n"] = AddNOpTranscriber();
special_handlers["assign_value"] = AssignValueOpTranscriber(); special_handlers["assign_value"] = AssignValueOpTranscriber();
......
...@@ -24,7 +24,7 @@ cc_library( ...@@ -24,7 +24,7 @@ cc_library(
cc_library( cc_library(
jit_compilation_unit jit_compilation_unit
SRCS compilation_unit.cc SRCS compilation_unit.cc
DEPS proto_desc executor parallel_executor executor_cache) DEPS proto_desc executor parallel_executor)
cc_library( cc_library(
jit_function_schema jit_function_schema
......
...@@ -108,7 +108,7 @@ endif() ...@@ -108,7 +108,7 @@ endif()
cc_library( cc_library(
init init
SRCS init.cc SRCS init.cc
DEPS device_context phi memcpy pd_dialect) DEPS device_context phi memcpy)
# memcpy depends on device_context, here add deps individually for # memcpy depends on device_context, here add deps individually for
# avoiding cycle dependencies # avoiding cycle dependencies
......
...@@ -51,13 +51,9 @@ limitations under the License. */ ...@@ -51,13 +51,9 @@ limitations under the License. */
#include "paddle/fluid/platform/device/ipu/ipu_info.h" #include "paddle/fluid/platform/device/ipu/ipu_info.h"
#endif #endif
#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_dialect.h"
#include "paddle/fluid/memory/allocation/allocator_facade.h" #include "paddle/fluid/memory/allocation/allocator_facade.h"
#include "paddle/fluid/memory/memory.h" #include "paddle/fluid/memory/memory.h"
#include "paddle/fluid/platform/flags.h" #include "paddle/fluid/platform/flags.h"
#include "paddle/ir/core/builtin_dialect.h"
#include "paddle/ir/core/ir_context.h"
#include "paddle/ir/core/program.h"
#include "paddle/phi/common/memory_utils.h" #include "paddle/phi/common/memory_utils.h"
#include "paddle/phi/core/custom_kernel.h" #include "paddle/phi/core/custom_kernel.h"
...@@ -202,9 +198,6 @@ void InitDevices() { ...@@ -202,9 +198,6 @@ void InitDevices() {
} }
void InitDevices(const std::vector<int> devices) { void InitDevices(const std::vector<int> devices) {
ir::IrContext *ctx = ir::IrContext::Instance();
ctx->GetOrRegisterDialect<paddle::dialect::PaddleDialect>();
std::vector<platform::Place> places; std::vector<platform::Place> places;
for (auto device : devices) { for (auto device : devices) {
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <unordered_map> #include <unordered_map>
#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_dialect.h"
#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_type.h" #include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_type.h"
#include "paddle/ir/core/builtin_dialect.h" #include "paddle/ir/core/builtin_dialect.h"
#include "paddle/ir/core/builtin_type.h" #include "paddle/ir/core/builtin_type.h"
...@@ -209,6 +210,7 @@ IR_DEFINE_EXPLICIT_TYPE_ID(IntegerDialect) ...@@ -209,6 +210,7 @@ IR_DEFINE_EXPLICIT_TYPE_ID(IntegerDialect)
TEST(type_test, custom_type_dialect) { TEST(type_test, custom_type_dialect) {
ir::IrContext *ctx = ir::IrContext::Instance(); ir::IrContext *ctx = ir::IrContext::Instance();
ctx->GetOrRegisterDialect<paddle::dialect::PaddleDialect>();
// Test 1: Test the function of IrContext to register Dialect. // Test 1: Test the function of IrContext to register Dialect.
ctx->GetOrRegisterDialect<IntegerDialect>(); ctx->GetOrRegisterDialect<IntegerDialect>();
...@@ -240,6 +242,7 @@ TEST(type_test, custom_type_dialect) { ...@@ -240,6 +242,7 @@ TEST(type_test, custom_type_dialect) {
TEST(type_test, pd_dialect) { TEST(type_test, pd_dialect) {
ir::IrContext *ctx = ir::IrContext::Instance(); ir::IrContext *ctx = ir::IrContext::Instance();
ctx->GetOrRegisterDialect<paddle::dialect::PaddleDialect>();
ir::Type fp32_dtype = ir::Float32Type::get(ctx); ir::Type fp32_dtype = ir::Float32Type::get(ctx);
phi::DDim dims = {2, 2}; phi::DDim dims = {2, 2};
phi::DataLayout data_layout = phi::DataLayout::NCHW; phi::DataLayout data_layout = phi::DataLayout::NCHW;
......
...@@ -44,6 +44,7 @@ namespace framework { ...@@ -44,6 +44,7 @@ namespace framework {
TEST(VJP, TanhBackwardTest) { TEST(VJP, TanhBackwardTest) {
ir::IrContext* ctx = ir::IrContext::Instance(); ir::IrContext* ctx = ir::IrContext::Instance();
ctx->GetOrRegisterDialect<paddle::dialect::PaddleDialect>();
ir::Program program((ctx)); ir::Program program((ctx));
paddle::dialect::APIBuilder::Instance().SetProgram(&program); paddle::dialect::APIBuilder::Instance().SetProgram(&program);
...@@ -98,6 +99,7 @@ TEST(VJP, TanhBackwardTest) { ...@@ -98,6 +99,7 @@ TEST(VJP, TanhBackwardTest) {
TEST(VJP, Tanh_BackwardTest) { TEST(VJP, Tanh_BackwardTest) {
ir::IrContext* ctx = ir::IrContext::Instance(); ir::IrContext* ctx = ir::IrContext::Instance();
ctx->GetOrRegisterDialect<paddle::dialect::PaddleDialect>();
ir::Program program((ctx)); ir::Program program((ctx));
paddle::dialect::APIBuilder::Instance().SetProgram(&program); paddle::dialect::APIBuilder::Instance().SetProgram(&program);
...@@ -152,6 +154,7 @@ TEST(VJP, Tanh_BackwardTest) { ...@@ -152,6 +154,7 @@ TEST(VJP, Tanh_BackwardTest) {
TEST(VJP, MeanBackwardTest) { TEST(VJP, MeanBackwardTest) {
ir::IrContext* ctx = ir::IrContext::Instance(); ir::IrContext* ctx = ir::IrContext::Instance();
ctx->GetOrRegisterDialect<paddle::dialect::PaddleDialect>();
ir::Program program((ctx)); ir::Program program((ctx));
paddle::dialect::APIBuilder::Instance().SetProgram(&program); paddle::dialect::APIBuilder::Instance().SetProgram(&program);
...@@ -208,6 +211,7 @@ TEST(VJP, MeanBackwardTest) { ...@@ -208,6 +211,7 @@ TEST(VJP, MeanBackwardTest) {
TEST(VJP, ConcatBackwardTest) { TEST(VJP, ConcatBackwardTest) {
ir::IrContext* ctx = ir::IrContext::Instance(); ir::IrContext* ctx = ir::IrContext::Instance();
ctx->GetOrRegisterDialect<paddle::dialect::PaddleDialect>();
ir::Program program((ctx)); ir::Program program((ctx));
paddle::dialect::APIBuilder::Instance().SetProgram(&program); paddle::dialect::APIBuilder::Instance().SetProgram(&program);
...@@ -270,6 +274,7 @@ TEST(VJP, ConcatBackwardTest) { ...@@ -270,6 +274,7 @@ TEST(VJP, ConcatBackwardTest) {
TEST(VJP, AddBackwardTest) { TEST(VJP, AddBackwardTest) {
ir::IrContext* ctx = ir::IrContext::Instance(); ir::IrContext* ctx = ir::IrContext::Instance();
ctx->GetOrRegisterDialect<paddle::dialect::PaddleDialect>();
ir::Program program((ctx)); ir::Program program((ctx));
paddle::dialect::APIBuilder::Instance().SetProgram(&program); paddle::dialect::APIBuilder::Instance().SetProgram(&program);
...@@ -334,6 +339,7 @@ TEST(VJP, AddBackwardTest) { ...@@ -334,6 +339,7 @@ TEST(VJP, AddBackwardTest) {
TEST(VJP, Add_BackwardTest) { TEST(VJP, Add_BackwardTest) {
ir::IrContext* ctx = ir::IrContext::Instance(); ir::IrContext* ctx = ir::IrContext::Instance();
ctx->GetOrRegisterDialect<paddle::dialect::PaddleDialect>();
ir::Program program((ctx)); ir::Program program((ctx));
paddle::dialect::APIBuilder::Instance().SetProgram(&program); paddle::dialect::APIBuilder::Instance().SetProgram(&program);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册