未验证 提交 521bba9c 编写于 作者: J JingZhuangzhuang 提交者: GitHub

modify cmake rules temporarily (#51644)

上级 fc2ccbb8
...@@ -109,6 +109,8 @@ set(SHARED_INFERENCE_SRCS ...@@ -109,6 +109,8 @@ set(SHARED_INFERENCE_SRCS
${PADDLE_CUSTOM_OP_SRCS}) ${PADDLE_CUSTOM_OP_SRCS})
# shared inference library deps # shared inference library deps
list(REMOVE_ITEM fluid_modules standalone_executor
interpretercore_garbage_collector)
set(SHARED_INFERENCE_DEPS ${fluid_modules} phi analysis_predictor set(SHARED_INFERENCE_DEPS ${fluid_modules} phi analysis_predictor
${utils_modules}) ${utils_modules})
......
...@@ -3,7 +3,13 @@ if(WITH_GPU) ...@@ -3,7 +3,13 @@ if(WITH_GPU)
nv_library( nv_library(
gpu_info gpu_info
SRCS gpu_info.cc SRCS gpu_info.cc
DEPS phi_backends gflags glog enforce monitor dynload_cuda) DEPS phi_backends
gflags
glog
enforce
monitor
dynload_cuda
malloc)
nv_test(cuda_helper_test SRCS cuda_helper_test.cu) nv_test(cuda_helper_test SRCS cuda_helper_test.cu)
nv_test( nv_test(
......
...@@ -30,15 +30,20 @@ def parse_args(): ...@@ -30,15 +30,20 @@ def parse_args():
'--model_dir', '--model_dir',
type=str, type=str,
default="", default="",
help='Directory of the inference models.', help='Directory of the inference models that named with pdmodel.',
)
parser.add_argument(
'--op_list',
type=str,
default="",
help='List of ops like "conv2d;pool2d;relu".',
) )
return parser.parse_args() return parser.parse_args()
def get_model_ops(model_file): def get_model_ops(model_file, ops_set):
model_bytes = paddle.static.load_from_file(model_file) model_bytes = paddle.static.load_from_file(model_file)
pg = paddle.static.deserialize_program(model_bytes) pg = paddle.static.deserialize_program(model_bytes)
ops_set = set()
for i in range(0, pg.desc.num_blocks()): for i in range(0, pg.desc.num_blocks()):
block = pg.desc.block(i) block = pg.desc.block(i)
...@@ -47,12 +52,12 @@ def get_model_ops(model_file): ...@@ -47,12 +52,12 @@ def get_model_ops(model_file):
for j in range(0, size): for j in range(0, size):
ops_set.add(block.op(j).type()) ops_set.add(block.op(j).type())
return ops_set
def get_model_phi_kernels(ops_set): def get_model_phi_kernels(ops_set):
phi_set = set() phi_set = set()
for op in ops_set: for op in ops_set:
print(op)
print(_get_phi_kernel_name(op))
phi_set.add(_get_phi_kernel_name(op)) phi_set.add(_get_phi_kernel_name(op))
return phi_set return phi_set
...@@ -60,10 +65,17 @@ def get_model_phi_kernels(ops_set): ...@@ -60,10 +65,17 @@ def get_model_phi_kernels(ops_set):
if __name__ == '__main__': if __name__ == '__main__':
args = parse_args() args = parse_args()
for root, dirs, files in os.walk(args.model_dir, topdown=True): ops_set = set()
for name in files: if args.op_list != "":
if re.match(r'.*pdmodel', name): op_list = args.op_list.split(";")
ops_set = get_model_ops(os.path.join(root, name)) for op in op_list:
ops_set.add(op)
if args.model_dir != "":
for root, dirs, files in os.walk(args.model_dir, topdown=True):
for name in files:
if re.match(r'.*pdmodel', name):
get_model_ops(os.path.join(root, name), ops_set)
phi_set = get_model_phi_kernels(ops_set) phi_set = get_model_phi_kernels(ops_set)
ops = ";".join(ops_set) ops = ";".join(ops_set)
kernels = ";".join(phi_set) kernels = ";".join(phi_set)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册