Unverified commit 5d43f5e4, authored by Aurelius84 and committed by GitHub

[NewIR]Add NOT_FOR_INFER to prune Inference Library Size and Split VJP CodeGen into pd_op_vjp.cc (#56352)

* [NewIR]Prune Inference Library Size and Remove IR Dialect

* remove options

* add NOT_FOR_INFER

* fix pd_vjp.cc

* polish deps

* fix code style

* fix unittest

* fix cmake

* fix inference CI
Parent 4f652ac2
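The two changes described above can be summarized with a small usage sketch. The snippet below is illustrative only (the target name train_only_pass is hypothetical); it shows how a training-only target would be tagged with the new NOT_FOR_INFER flag and how the inference build later prunes such targets from its DEPS, mirroring the CMake changes in the diff below.

```cmake
# Hypothetical training-only target: NOT_FOR_INFER records it in the
# global NOT_INFER_MODULES property via ignore_infer_modules().
cc_library(
  train_only_pass NOT_FOR_INFER
  SRCS train_only_pass.cc
  DEPS phi)

# The inference CMakeLists later drops every tagged target from its deps,
# which is what shrinks the inference library (same call as in the diff).
get_property(not_infer_modules GLOBAL PROPERTY NOT_INFER_MODULES)
list(REMOVE_ITEM fluid_modules ${not_infer_modules})
```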
......@@ -118,6 +118,19 @@ function(find_fluid_modules TARGET_NAME)
endif()
endfunction()
# NOTE(Aurelius84): NOT_INFER_MODULES is used to tag targets that should
# not be considered as DEPS for inference libs.
set_property(GLOBAL PROPERTY NOT_INFER_MODULES "")
function(ignore_infer_modules TARGET_NAME)
get_property(not_infer_modules GLOBAL PROPERTY NOT_INFER_MODULES)
list(FIND not_infer_modules ${TARGET_NAME} is_found)
if(is_found EQUAL -1) # NOT FOUND
set(not_infer_modules ${not_infer_modules} ${TARGET_NAME})
set_property(GLOBAL PROPERTY NOT_INFER_MODULES "${not_infer_modules}")
endif()
endfunction()
set_property(GLOBAL PROPERTY PHI_MODULES "")
# find all phi modules, which are used for the paddle static library
# when building inference libs
......@@ -335,7 +348,15 @@ function(check_coverage_opt TARGET_NAME SRCS)
endfunction()
function(cc_library TARGET_NAME)
set(options STATIC static SHARED shared INTERFACE interface)
set(options
STATIC
static
SHARED
shared
INTERFACE
interface
NOT_FOR_INFER
not_for_infer)
set(oneValueArgs "")
set(multiValueArgs SRCS DEPS)
cmake_parse_arguments(cc_library "${options}" "${oneValueArgs}"
......@@ -347,6 +368,9 @@ function(cc_library TARGET_NAME)
CACHE STRING "output library name for target ${TARGET_NAME}")
endif()
if(cc_library_SRCS)
if(cc_library_NOT_FOR_INFER OR cc_library_not_for_infer)
ignore_infer_modules(${TARGET_NAME})
endif()
if(cc_library_SHARED OR cc_library_shared) # build *.so
add_library(${TARGET_NAME} SHARED ${cc_library_SRCS})
elseif(cc_library_INTERFACE OR cc_library_interface)
......
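For readers unfamiliar with how the flag travels: cc_library forwards its arguments through cmake_parse_arguments, which turns each listed option into a <prefix>_<OPTION> boolean. A minimal, self-contained sketch of that mechanism follows; the function and target names (demo_library, demo_target) are hypothetical and not part of the diff.

```cmake
# Sketch: how an option such as NOT_FOR_INFER becomes a boolean variable.
function(demo_library TARGET_NAME)
  set(options NOT_FOR_INFER not_for_infer)
  set(oneValueArgs "")
  set(multiValueArgs SRCS DEPS)
  cmake_parse_arguments(demo_library "${options}" "${oneValueArgs}"
                        "${multiValueArgs}" ${ARGN})
  if(demo_library_NOT_FOR_INFER OR demo_library_not_for_infer)
    # In the real cc_library this is where ignore_infer_modules() is called.
    message(STATUS "${TARGET_NAME} tagged as not-for-inference")
  endif()
endfunction()

demo_library(demo_target NOT_FOR_INFER SRCS demo.cc DEPS phi)
```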
......@@ -276,7 +276,7 @@ set(SSA_GRAPH_EXECUTOR_DEPS
inplace_addto_op_pass
set_reader_device_info_utils)
cc_library(
ssa_graph_executor
ssa_graph_executor NOT_FOR_INFER
SRCS ssa_graph_executor.cc
DEPS ${SSA_GRAPH_EXECUTOR_DEPS})
......
cc_library(
interpretercore_garbage_collector
interpretercore_garbage_collector NOT_FOR_INFER
SRCS garbage_collector.cc event_garbage_collector.cc fast_garbage_collector.cc
no_event_garbage_collector.cc
DEPS garbage_collector executor_gc_helper)
......@@ -6,7 +6,7 @@ cc_library(
if(WITH_CINN AND NOT CINN_ONLY)
cc_library(
cinn_jit_instruction
cinn_jit_instruction NOT_FOR_INFER
SRCS cinn_jit_instruction.cc
DEPS phi cinnapi cinn_dialect)
endif()
......@@ -34,6 +34,7 @@ endif()
get_property(fluid_modules GLOBAL PROPERTY FLUID_MODULES)
get_property(phi_modules GLOBAL PROPERTY PHI_MODULES)
get_property(ir_targets GLOBAL PROPERTY IR_TARGETS)
get_property(not_infer_modules GLOBAL PROPERTY NOT_INFER_MODULES)
set(utils_modules pretty_log string_helper benchmark utf8proc)
add_subdirectory(api)
......@@ -57,18 +58,16 @@ set(KERNEL_LIST
""
CACHE STRING "The list of phi kernels that will be compiled")
# shared inference library deps
list(REMOVE_DUPLICATES fluid_modules)
# On Windows with GPU, the static library exceeds the size limit, so we skip
# create_static_lib and use a dummy cc_library instead
if(WIN32 AND WITH_GPU)
cc_library(paddle_inference DEPS ${fluid_modules} ${ir_targets}
${STATIC_INFERENCE_API} ${utils_modules})
cc_library(paddle_inference DEPS ${fluid_modules} ${STATIC_INFERENCE_API}
${utils_modules})
else()
# message("${fluid_modules}")
# message("${STATIC_INFERENCE_API}")
# message("${utils_modules}")
# message("${phi_modules}")
if(WIN32)
create_static_lib(paddle_inference ${phi_modules} ${fluid_modules}
${STATIC_INFERENCE_API} ${utils_modules})
${ir_targets} ${STATIC_INFERENCE_API} ${utils_modules})
else()
create_static_lib(paddle_inference ${phi_modules} ${fluid_modules}
${ir_targets} ${STATIC_INFERENCE_API} ${utils_modules})
......@@ -82,7 +81,6 @@ if(NOT APPLE)
)
set_target_properties(paddle_inference PROPERTIES LINK_FLAGS "${LINK_FLAGS}")
endif()
# C inference API
add_subdirectory(capi_exp)
......@@ -99,15 +97,17 @@ set(SHARED_INFERENCE_SRCS
${CMAKE_CURRENT_SOURCE_DIR}/api/details/zero_copy_tensor.cc
${CMAKE_CURRENT_SOURCE_DIR}/utils/io_utils.cc)
# shared inference library deps
list(REMOVE_ITEM fluid_modules standalone_executor
interpretercore_garbage_collector)
if(WIN32)
set(SHARED_INFERENCE_DEPS phi ${fluid_modules} analysis_predictor
# NOTE(Aurelius84): For the inference library, some DEPS are useless,
# such as non-inference operator related targets, etc.
list(REMOVE_ITEM fluid_modules cinn_dialect)
# NOTE(Aurelius84): Remove ir dialect related target DEPS for the inference
# shared library to prune the library size.
list(REMOVE_ITEM fluid_modules ${not_infer_modules})
set(SHARED_INFERENCE_DEPS phi ${fluid_modules} analysis_predictor
${utils_modules})
else()
set(SHARED_INFERENCE_DEPS phi ${fluid_modules} ${ir_targets}
analysis_predictor ${utils_modules})
if(NOT WIN32)
list(APPEND SHARED_INFERENCE_DEPS ${ir_targets})
endif()
if(WITH_CRYPTO)
......
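One caveat worth flagging (this is an observation, not something the patch guards against): if NOT_INFER_MODULES could ever expand to nothing, list(REMOVE_ITEM fluid_modules) would be called with no items, which some CMake versions reject at configure time. A hypothetical guard would look like this:

```cmake
# Hypothetical guard: only prune when there is actually something to prune.
get_property(not_infer_modules GLOBAL PROPERTY NOT_INFER_MODULES)
if(not_infer_modules)
  list(REMOVE_ITEM fluid_modules ${not_infer_modules})
endif()
```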
......@@ -13,6 +13,7 @@
# limitations under the License.
import argparse
import logging
import os
import yaml
......@@ -115,13 +116,30 @@ CC_FILE_TEMPLATE = """// This file is generated by "paddle/fluid/ir/dialect/op_g
#include "paddle/phi/infermeta/fusion.h"
#include "paddle/phi/api/lib/utils/allocator.h"
#include "paddle/fluid/primitive/rule/vjp/vjp.h"
{def_primitive}
#include "paddle/ir/core/op_base.h"
{input}
{define_type_id}
"""
# =====================================
# String Template for pd_op_vjp.cc file code gen
# =====================================
VJP_CC_FILE_TEMPLATE = """// This file is generated by "paddle/fluid/ir/dialect/op_generator/op_gen.py"
#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_attribute.h"
#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_op.h"
#include "paddle/fluid/primitive/rule/vjp/vjp.h"
#include "paddle/fluid/primitive/type/lazy_tensor.h"
#include "paddle/ir/core/builtin_op.h"
#include "paddle/ir/core/op_base.h"
#include "paddle/phi/common/int_array.h"
namespace paddle {{
namespace dialect {{
{input}
}} // namespace dialect
}} // namespace paddle
"""
OP_N_ATTRIBUTE_DEFINED_TEMPLATE = """
const char *{op_name}::attributes_name[{attribute_num}] = {{ {attribute_names} }};
......@@ -681,6 +699,7 @@ def OpGenerator(
dialect_name,
op_def_h_file,
op_def_cc_file,
op_vjp_cc_file,
):
# (1) Prepare: Delete existing old files: pd_op.h.tmp, pd_op.cc.tmp
if os.path.exists(op_def_h_file):
......@@ -705,6 +724,7 @@ def OpGenerator(
ops_name_list = []  # all op class names are stored in this list
ops_declare_list = []  # all op class declarations are stored in this list
ops_defined_list = []  # all op class definitions are stored in this list
ops_vjp_defined_list = []  # all op vjp static interface definitions
for key, op_info in op_info_items.items():
# get op inputs info
op_input_name_list = op_info.input_name_list
......@@ -1078,11 +1098,9 @@ def OpGenerator(
ops_defined_list.append(op_infer_meta_str)
# NOTE(chenxi67): skip if dialect_name == cinn
if dialect_name == "cinn":
import logging
logging.warning("cinn is currently not support Vjp function")
else:
ops_defined_list.append(op_vjp_str)
ops_vjp_defined_list.append(op_vjp_str)
# (4) Generate head file str
op_namespaces_prev = ""
......@@ -1122,24 +1140,25 @@ def OpGenerator(
for op in ops_name_with_namespace_list:
define_type_id_str += DEFINE_OP_TYPE_ID.format(op_name=op)
# NOTE(chenxi67): Skip including this header file if dialect_name == cinn,
# otherwise we may get a compile error when compiling with "ncclDataType_t"
def_primitive_str = "#include \"paddle/fluid/primitive/type/lazy_tensor.h\""
if dialect_name == "cinn":
def_primitive_str = ""
source_file_str = CC_FILE_TEMPLATE.format(
h_file=op_def_h_file[:-4],
def_primitive=def_primitive_str,
input=source_file_str,
define_type_id=define_type_id_str,
) # Add head
vjp_source_file_str = VJP_CC_FILE_TEMPLATE.format(
input="".join(ops_vjp_defined_list)
)
# (5) Generate pd_op.h.tmp, pd_op.cc.tmp
with open(op_def_h_file, 'a') as f:
with open(op_def_h_file, 'w') as f:
f.write(head_file_str)
with open(op_def_cc_file, 'a') as f:
with open(op_def_cc_file, 'w') as f:
f.write(source_file_str)
# NOTE(Aurelius84): op_gen.py is called multiple times,
# and vjp is only available for the pd dialect.
if dialect_name != 'cinn' and op_vjp_cc_file:
with open(op_vjp_cc_file, 'w') as f:
f.write(vjp_source_file_str)
# =====================================
......@@ -1155,6 +1174,7 @@ def ParseArguments():
parser.add_argument('--dialect_name', type=str)
parser.add_argument('--op_def_h_file', type=str)
parser.add_argument('--op_def_cc_file', type=str)
parser.add_argument('--op_vjp_cc_file', type=str)
return parser.parse_args()
......@@ -1172,6 +1192,7 @@ if __name__ == "__main__":
dialect_name = args.dialect_name
op_def_h_file = args.op_def_h_file
op_def_cc_file = args.op_def_cc_file
op_vjp_cc_file = args.op_vjp_cc_file
# auto code generate
OpGenerator(
......@@ -1181,4 +1202,5 @@ if __name__ == "__main__":
dialect_name,
op_def_h_file,
op_def_cc_file,
op_vjp_cc_file,
)
......@@ -57,17 +57,17 @@ OP_VJP_DEFINE_TEMPLATE = """
std::vector<std::vector<ir::OpResult>> {op_class_name}::Vjp(ir::Operation* op, const std::vector<std::vector<ir::OpResult>>& out_grads, const std::vector<std::vector<bool>>& stop_gradients){{
{op_class_name} op_obj = op->dyn_cast<{op_class_name}>();
VLOG(6) << "Prepare inputs of {op_grad_name}";
VLOG(6) << "Prepare inputs of {op_grad_name}";
{forward_input_code}
{forward_output_grad_code}
VLOG(6) << "Vjp prepare Prepare attributes of {op_grad_name}";
VLOG(6) << "Vjp prepare Prepare attributes of {op_grad_name}";
{attribute_code}
VLOG(4) << "Vjp prepare call {op_phi_name}'s vjp inteface";
VLOG(4) << "Vjp prepare call {op_phi_name}'s vjp inteface";
{call_vjp_code}
VLOG(4) << "Vjp prepare stop gradient of {op_grad_name}";
VLOG(4) << "Vjp prepare stop gradient of {op_grad_name}";
{stop_gradient_input_grad_code}
return res;
}}
......
......@@ -36,18 +36,25 @@ set(op_source_file ${PD_DIALECT_BINARY_DIR}/pd_op.cc)
set(op_header_file_tmp ${op_header_file}.tmp)
set(op_source_file_tmp ${op_source_file}.tmp)
set(op_vjp_source_file ${PD_DIALECT_BINARY_DIR}/pd_op_vjp.cc)
set(op_vjp_source_file_tmp ${op_vjp_source_file}.tmp)
add_custom_command(
OUTPUT ${op_header_file} ${op_source_file}
OUTPUT ${op_header_file} ${op_source_file} ${op_vjp_source_file}
COMMAND
${PYTHON_EXECUTABLE} ${op_gen_file} --op_yaml_files ${op_yaml_files}
--op_compat_yaml_file ${op_compat_yaml_file} --namespaces ${op_namespace}
--dialect_name ${dialect_name} --op_def_h_file ${op_header_file_tmp}
--op_def_cc_file ${op_source_file_tmp}
--op_def_cc_file ${op_source_file_tmp} --op_vjp_cc_file
${op_vjp_source_file_tmp}
COMMAND ${CMAKE_COMMAND} -E copy_if_different ${op_header_file_tmp}
${op_header_file}
COMMAND ${CMAKE_COMMAND} -E copy_if_different ${op_source_file_tmp}
${op_source_file}
COMMENT "copy_if_different ${op_header_file} ${op_source_file}"
COMMAND ${CMAKE_COMMAND} -E copy_if_different ${op_vjp_source_file_tmp}
${op_vjp_source_file}
COMMENT
"copy_if_different ${op_header_file} ${op_source_file} ${op_vjp_source_file}"
DEPENDS ${op_gen_file} ${op_forward_yaml_file1} ${op_forward_yaml_file2}
${op_backward_yaml_file1} ${op_backward_yaml_file2}
${op_compat_yaml_file}
......@@ -98,6 +105,6 @@ target_include_directories(pd_dialect_api PRIVATE ${PD_DIALECT_BINARY_DIR})
cc_library(
pd_dialect
SRCS pd_dialect.cc pd_op_vjp_manual.cc
SRCS pd_dialect.cc pd_op_vjp_manual.cc ${op_vjp_source_file}
DEPS pd_dialect_api param_to_variable primitive_vjp_experimental
pd_dialect_utils op_yaml_info_parser)
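The generated pd_op_vjp.cc follows the same pattern as pd_op.cc: the generator writes a .tmp file and copy_if_different only touches the real source when its content changed, so editing the generator does not force a full rebuild. A generic sketch of that pattern, with hypothetical file names (gen.py, generated.cc):

```cmake
# Generic "generate to .tmp, then copy_if_different" pattern
# (gen.py and generated.cc are hypothetical names).
set(gen_out ${CMAKE_CURRENT_BINARY_DIR}/generated.cc)
set(gen_out_tmp ${gen_out}.tmp)

add_custom_command(
  OUTPUT ${gen_out}
  COMMAND ${PYTHON_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/gen.py --out
          ${gen_out_tmp}
  COMMAND ${CMAKE_COMMAND} -E copy_if_different ${gen_out_tmp} ${gen_out}
  DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/gen.py
  COMMENT "Regenerate ${gen_out} only when its content changes")
```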
......@@ -25,6 +25,7 @@
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/ir/dialect/paddle_dialect/interface/op_yaml_info.h"
#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_attribute.h"
#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_dialect.h"
#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_type.h"
#include "paddle/fluid/ir/dialect/paddle_dialect/utils/utils.h"
#include "paddle/fluid/ir_adaptor/translator/attribute_translator.h"
......@@ -1634,6 +1635,9 @@ struct ElementwiseGradTranscriber : public OpTranscriber {
};
OpTranslator::OpTranslator() {
ir::IrContext* ctx = ir::IrContext::Instance();
ctx->GetOrRegisterDialect<paddle::dialect::PaddleDialect>();
general_handler = OpTranscriber();
special_handlers["add_n"] = AddNOpTranscriber();
special_handlers["assign_value"] = AssignValueOpTranscriber();
......
......@@ -24,7 +24,7 @@ cc_library(
cc_library(
jit_compilation_unit
SRCS compilation_unit.cc
DEPS proto_desc executor parallel_executor executor_cache)
DEPS proto_desc executor parallel_executor)
cc_library(
jit_function_schema
......
......@@ -108,7 +108,7 @@ endif()
cc_library(
init
SRCS init.cc
DEPS device_context phi memcpy pd_dialect)
DEPS device_context phi memcpy)
# memcpy depends on device_context; here we add deps individually to
# avoid cyclic dependencies
......
......@@ -51,13 +51,9 @@ limitations under the License. */
#include "paddle/fluid/platform/device/ipu/ipu_info.h"
#endif
#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_dialect.h"
#include "paddle/fluid/memory/allocation/allocator_facade.h"
#include "paddle/fluid/memory/memory.h"
#include "paddle/fluid/platform/flags.h"
#include "paddle/ir/core/builtin_dialect.h"
#include "paddle/ir/core/ir_context.h"
#include "paddle/ir/core/program.h"
#include "paddle/phi/common/memory_utils.h"
#include "paddle/phi/core/custom_kernel.h"
......@@ -202,9 +198,6 @@ void InitDevices() {
}
void InitDevices(const std::vector<int> devices) {
ir::IrContext *ctx = ir::IrContext::Instance();
ctx->GetOrRegisterDialect<paddle::dialect::PaddleDialect>();
std::vector<platform::Place> places;
for (auto device : devices) {
......
......@@ -15,6 +15,7 @@
#include <gtest/gtest.h>
#include <unordered_map>
#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_dialect.h"
#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_type.h"
#include "paddle/ir/core/builtin_dialect.h"
#include "paddle/ir/core/builtin_type.h"
......@@ -209,6 +210,7 @@ IR_DEFINE_EXPLICIT_TYPE_ID(IntegerDialect)
TEST(type_test, custom_type_dialect) {
ir::IrContext *ctx = ir::IrContext::Instance();
ctx->GetOrRegisterDialect<paddle::dialect::PaddleDialect>();
// Test 1: Test the function of IrContext to register Dialect.
ctx->GetOrRegisterDialect<IntegerDialect>();
......@@ -240,6 +242,7 @@ TEST(type_test, custom_type_dialect) {
TEST(type_test, pd_dialect) {
ir::IrContext *ctx = ir::IrContext::Instance();
ctx->GetOrRegisterDialect<paddle::dialect::PaddleDialect>();
ir::Type fp32_dtype = ir::Float32Type::get(ctx);
phi::DDim dims = {2, 2};
phi::DataLayout data_layout = phi::DataLayout::NCHW;
......
......@@ -44,6 +44,7 @@ namespace framework {
TEST(VJP, TanhBackwardTest) {
ir::IrContext* ctx = ir::IrContext::Instance();
ctx->GetOrRegisterDialect<paddle::dialect::PaddleDialect>();
ir::Program program((ctx));
paddle::dialect::APIBuilder::Instance().SetProgram(&program);
......@@ -98,6 +99,7 @@ TEST(VJP, TanhBackwardTest) {
TEST(VJP, Tanh_BackwardTest) {
ir::IrContext* ctx = ir::IrContext::Instance();
ctx->GetOrRegisterDialect<paddle::dialect::PaddleDialect>();
ir::Program program((ctx));
paddle::dialect::APIBuilder::Instance().SetProgram(&program);
......@@ -152,6 +154,7 @@ TEST(VJP, Tanh_BackwardTest) {
TEST(VJP, MeanBackwardTest) {
ir::IrContext* ctx = ir::IrContext::Instance();
ctx->GetOrRegisterDialect<paddle::dialect::PaddleDialect>();
ir::Program program((ctx));
paddle::dialect::APIBuilder::Instance().SetProgram(&program);
......@@ -208,6 +211,7 @@ TEST(VJP, MeanBackwardTest) {
TEST(VJP, ConcatBackwardTest) {
ir::IrContext* ctx = ir::IrContext::Instance();
ctx->GetOrRegisterDialect<paddle::dialect::PaddleDialect>();
ir::Program program((ctx));
paddle::dialect::APIBuilder::Instance().SetProgram(&program);
......@@ -270,6 +274,7 @@ TEST(VJP, ConcatBackwardTest) {
TEST(VJP, AddBackwardTest) {
ir::IrContext* ctx = ir::IrContext::Instance();
ctx->GetOrRegisterDialect<paddle::dialect::PaddleDialect>();
ir::Program program((ctx));
paddle::dialect::APIBuilder::Instance().SetProgram(&program);
......@@ -334,6 +339,7 @@ TEST(VJP, AddBackwardTest) {
TEST(VJP, Add_BackwardTest) {
ir::IrContext* ctx = ir::IrContext::Instance();
ctx->GetOrRegisterDialect<paddle::dialect::PaddleDialect>();
ir::Program program((ctx));
paddle::dialect::APIBuilder::Instance().SetProgram(&program);
......