From 189ac02b766cf72a3bbcbbb55d922a9787aad6fb Mon Sep 17 00:00:00 2001 From: 123malin Date: Mon, 17 Feb 2020 19:52:20 +0800 Subject: [PATCH] test=develop, add distributed tools (#22623) (#22637) --- .../fleet/utils/fleet_barrier_util.py | 1 + .../fluid/incubate/fleet/utils/fleet_util.py | 82 +++- .../fluid/incubate/fleet/utils/utils.py | 428 ++++++++++++++++++ .../fluid/tests/unittests/test_fleet_utils.py | 188 ++++++++ 4 files changed, 698 insertions(+), 1 deletion(-) create mode 100644 python/paddle/fluid/incubate/fleet/utils/utils.py diff --git a/python/paddle/fluid/incubate/fleet/utils/fleet_barrier_util.py b/python/paddle/fluid/incubate/fleet/utils/fleet_barrier_util.py index bce8da641c..a9fd8ac74f 100644 --- a/python/paddle/fluid/incubate/fleet/utils/fleet_barrier_util.py +++ b/python/paddle/fluid/incubate/fleet/utils/fleet_barrier_util.py @@ -15,6 +15,7 @@ from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler import fleet from paddle.fluid.contrib.utils import HDFSClient import os +import time def check_all_trainers_ready(ready_path, epoch): diff --git a/python/paddle/fluid/incubate/fleet/utils/fleet_util.py b/python/paddle/fluid/incubate/fleet/utils/fleet_util.py index 50fde2c47b..4a1fd20afc 100644 --- a/python/paddle/fluid/incubate/fleet/utils/fleet_util.py +++ b/python/paddle/fluid/incubate/fleet/utils/fleet_util.py @@ -23,15 +23,19 @@ import sys import time import paddle.fluid as fluid from paddle.fluid.log_helper import get_logger -from paddle.fluid.incubate.fleet.parameter_server.pslib import fleet +from paddle.fluid.incubate.fleet.parameter_server.pslib import fleet as fleet_pslib +from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler import fleet as fleet_transpiler from . import hdfs from .hdfs import * +from . import utils __all__ = ["FleetUtil"] _logger = get_logger( __name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s') +fleet = fleet_pslib + class FleetUtil(object): """ @@ -46,6 +50,16 @@ class FleetUtil(object): """ + def __init__(self, mode="pslib"): + global fleet + if mode == "pslib": + fleet = fleet_pslib + elif mode == "transpiler": + fleet = fleet_transpiler + else: + raise ValueError( + "Please choose one mode from [\"pslib\", \"transpiler\"]") + def rank0_print(self, s): """ Worker of rank 0 print some log. @@ -1535,3 +1549,69 @@ class FleetUtil(object): (print_prefix, auc, bucket_error, mae, rmse, actual_ctr, predicted_ctr, copc, mean_predict_qvalue, total_ins_num)) + + def program_type_trans(self, prog_dir, prog_fn, is_text): + return utils.program_type_trans(prog_dir, prog_fn, is_text) + + def draw_from_program_file(self, model_filename, is_text, output_dir, + output_filename): + """draw program from file""" + program = utils.load_program(model_filename, is_text) + utils.graphviz(program.global_block(), output_dir, output_filename) + + def draw_from_program(self, program, output_dir, output_name): + """draw Program""" + utils.graphviz(program.global_block(), output_dir, output_name) + + def check_two_programs(self, config): + train_prog = utils.load_program(config.train_prog_path, + config.is_text_train_program) + pruned_prog = utils.load_program(config.pruned_prog_path, + config.is_text_pruned_program) + if config.draw: + pruned_dir = os.path.dirname(config.pruned_prog_path) + self.draw_from_program(pruned_prog, pruned_dir, + config.draw_out_name) + res = utils.check_pruned_program_vars(train_prog, pruned_prog) + if res: + _logger.info("check_programs succeed.") + else: + _logger.info( + "check_programs failed. pruned program and train program not match!" + ) + return res + + def check_vars_and_dump(self, config): + _logger.info("start check_vars_and_dump.") + results = utils.check_saved_vars_try_dump( + config.dump_model_dir, config.dump_program_filename, + config.is_text_dump_program, config.feed_config, + config.fetch_config, config.batch_size, config.save_params_filename) + _logger.info("check_vars_and_dump succeed.") + return results + + def parse_program_proto(self, prog_path, is_text, output_dir): + """ + Parse program.proto into a more readable format. + This function will generate three files: + output_dir/vars_all.log, + output_dir/vars_persistable.log, + output_dir/ops.log. + + Args: + prog_path(str): proto file path to be parsed. + is_text(bool): proto file is human-readale format or not(binary). + output_dir(str): output dir. + + Examples: + .. code-block:: python + + from paddle.fluid.incubate.fleet.utils.fleet_util import FleetUtil + fleet_util = FleetUtil() + program_path = "./program.pbtxt" + is_text = True + output_dir = "/tmp/" + fleet_util.parse_program_proto(program_path, is_text, output_dir) + """ + program = utils.load_program(prog_path, is_text) + utils.parse_program(program, output_dir) diff --git a/python/paddle/fluid/incubate/fleet/utils/utils.py b/python/paddle/fluid/incubate/fleet/utils/utils.py new file mode 100644 index 0000000000..79f3fb9193 --- /dev/null +++ b/python/paddle/fluid/incubate/fleet/utils/utils.py @@ -0,0 +1,428 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function, absolute_import +import os +import sys +import logging +import subprocess +import numpy as np +from collections import OrderedDict +import paddle.fluid as fluid +from paddle.fluid import core +from paddle.fluid.log_helper import get_logger + +from google.protobuf import text_format +from paddle.fluid import debugger +from paddle.fluid.framework import Program +from paddle.fluid.proto import framework_pb2 + +__all__ = [ + "load_program", "save_program", "program_type_trans", + "check_saved_vars_try_dump", "parse_program", "check_pruned_program_vars", + "graphviz" +] + +logging.basicConfig( + format='%(asctime)s - %(levelname)s - %(message)s', level=logging.INFO) +logger = logging.getLogger(__name__) + +persistable_vars_out_fn = "vars_persistable.log" +all_vars_out_fn = "vars_all.log" +ops_out_fn = "ops.log" + +feed_fetch_type_list = [ + core.VarDesc.VarType.FEED_MINIBATCH, core.VarDesc.VarType.FETCH_LIST +] +not_expected_op_types = ["lookup_table"] + + +def load_program(model_filename, is_text=False): + if is_text: + return load_program_text(model_filename) + return load_program_binary(model_filename) + + +def load_program_binary(model_filename): + """load program from binary string file""" + with open(model_filename, "rb") as f: + program_desc_str = f.read() + return Program.parse_from_string(program_desc_str) + + +def load_program_text(model_filename): + """load program from human-readable text file""" + with open(model_filename, "r") as f: + program_desc_text = f.read() + + prog_desc = framework_pb2.ProgramDesc() + text_format.Merge(program_desc_text, prog_desc) + return Program.parse_from_string(prog_desc.SerializeToString()) + + +def save_program(program, model_filename='__model__', is_text=False): + if is_text: + with open(model_filename, "w") as f: + f.write(str(program)) + else: + with open(model_filename, "wb") as f: + f.write(program.desc.serialize_to_string()) + + +def check_pruned_program_vars(train_prog, pruned_prog): + is_match = True + + pruned_vars = [(v.name, v) for v in pruned_prog.list_vars() + if fluid.io.is_persistable(v)] + pruned_vars = OrderedDict(pruned_vars) + pruned_vars_name = [name for name in pruned_vars] + logger.info("persistable vars in pruned program: {}".format( + pruned_vars_name)) + + for var_name in pruned_vars: + var = pruned_vars[var_name] + # feed and fetch op is added in pruned program when pruning, not need to be found in train program + if var.type in feed_fetch_type_list: + break + try: + train_prog_var = train_prog.global_block().var(var_name) + except ValueError as e: + logger.error( + "not find variable '%s' in train program. please check pruning." + % var_name) + logger.error(e) + continue + if var.shape != train_prog_var.shape or var.dtype != train_prog_var.dtype: + logger.error( + "variable: {} not match. in pruned program shape: {} dtype:{}, in train program shape: {} dtype: {}". + format(var_name, var.shape, var.dtype, train_prog_var.shape, + train_prog_var.dtype)) + is_match = False + return is_match + + +def graphviz(block, output_dir="", filename='debug'): + dot_path = os.path.join(output_dir, filename + '.dot') + pdf_path = os.path.join(output_dir, filename + '.pdf') + debugger.draw_block_graphviz(block, path=dot_path) + cmd = ["dot", "-Tpdf", dot_path, "-o", pdf_path] + p = subprocess.Popen( + cmd, + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE) + p.wait() + + +def program_type_trans(prog_dir, prog_fn, is_text): + prog = load_program(os.path.join(prog_dir, prog_fn), is_text) + prog_out_fn = prog_fn + ".bin" if is_text else prog_fn + ".pbtxt" + save_program(prog, os.path.join(prog_dir, prog_out_fn), 1 - is_text) + return prog_out_fn + + +def append_save_op(block, var, path): + block.append_op( + type='save', inputs={'X': [var]}, outputs={}, + attrs={'file_path': path}) + + +def append_load_op(block, var, path): + block.append_op( + type='load', + inputs={}, + outputs={'Out': [var]}, + attrs={'file_path': path}) + + +def save_var(np_array, var_name, shape_list, dtype, save_path): + program = fluid.Program() + place = fluid.CPUPlace() + exe = fluid.Executor(place) + with fluid.program_guard(program): + d0_data = fluid.layers.data(var_name, shape=shape_list, dtype=dtype) + append_save_op(program.global_block(), d0_data, save_path) + exe.run(feed={var_name: np_array}, fetch_list=[]) + + +def load_var(var_name, shape_list, dtype, save_path): + program = fluid.Program() + place = fluid.CPUPlace() + exe = fluid.Executor(place) + with fluid.program_guard(program): + d0_data = fluid.layers.data(var_name, shape=shape_list, dtype=dtype) + append_load_op(program.global_block(), d0_data, save_path) + outs = exe.run(feed={}, fetch_list=[d0_data]) + return outs + + +def reader(batch_size, fn, dim): + data = [] + if isinstance(dim, list) or isinstance(dim, tuple): + shape = list(dim) + _temp = 1 + for x in dim: + _temp = _temp * x + dim = _temp + else: + shape = [dim] + + shape = [batch_size] + shape + dim = dim * batch_size + + for line in open(fn, 'r'): + fields = line.strip().split(' ') + fields = [float(d) for d in fields] + while len(fields) >= dim: + tmp = fields[:dim] + fields = fields[dim:] + data.append(np.array(tmp).reshape(shape)) + return data + + +def feed_gen(batch_size, feeded_vars_dims, feeded_vars_filelist): + batch_feed = [] + for i, fn in enumerate(feeded_vars_filelist): + batch_feed.append(reader(batch_size, fn, feeded_vars_dims[i])) + return batch_feed + + +def try_load_model_vars(dump_dir, dump_prog_fn, is_text_dump_program, + batch_size, feed_config, fetch_config, save_filename, + saved_params): + place = fluid.CPUPlace() + exe = fluid.Executor(place) + scope = fluid.core.Scope() + with fluid.scope_guard(scope): + if is_text_dump_program: + dump_prog_fn = program_type_trans(dump_dir, dump_prog_fn, + is_text_dump_program) + inference_program, feed_target_names, fetch_targets = \ + fluid.io.load_inference_model(dump_dir, exe, model_filename=dump_prog_fn, + params_filename=save_filename) + + # check program vars and saved vars shape + orig_para_shape = { + each_var.name: tuple(each_var.desc.shape()) + for each_var in saved_params + } + for each_var in saved_params: + var_temp = fluid.global_scope().find_var(each_var.name) + assert var_temp != None, "can't not find var: " + each_var.name + new_shape = (np.array(var_temp.get_tensor())).shape + assert each_var.name in orig_para_shape, each_var.name + "MUST in var list" + orig_shape = orig_para_shape.get(each_var.name) + if new_shape != orig_shape: + raise RuntimeError( + "Shape not matching: the Program requires a parameter with a shape of ({}), " + "while the loaded parameter (namely [ {} ]) has a shape of ({}).". + format(orig_shape, each_var.name, new_shape)) + + # check feed/fetch vars in program and config + fetch_targets_names = [v.name for v in fetch_targets] + if not feed_target_names: + logger.warning("no feed targets in program.") + if not fetch_targets_names: + logger.warning("no fetch targets in program.") + fetch_list = fetch_targets + feed_name_list = feed_target_names + if feed_config.feeded_vars_names is not None and feed_target_names != feed_config.feeded_vars_names: + logger.warning( + "feed vars in program and config are diff: feed in program: {}. feed in config {}.". + format(feed_target_names, feed_config.feeded_vars_names)) + feed_name_list = feed_config.feeded_vars_names + # remove feed op in inference_program. new feed op will be added in exe.run + global_block = inference_program.global_block() + need_to_remove_op_index = [] + for i, op in enumerate(global_block.ops): + op.desc.set_is_target(False) + if op.type == "feed": # only remove feed op here + need_to_remove_op_index.append(i) + for index in need_to_remove_op_index[::-1]: + global_block._remove_op(index) + if fetch_config.fetch_vars_names is not None and fetch_targets_names != fetch_config.fetch_vars_names: + logger.warning( + "fetch vars in program and config are diff: fetch in program: {}. fetch in config {}.". + format(fetch_targets_names, fetch_config.fetch_vars_names)) + fetch_list = [ + inference_program.global_block().var(i) + for i in fetch_config.fetch_vars_names + ] + # remove fetch op in inference_program. new fetch op will be added in exe.run + global_block = inference_program.global_block() + need_to_remove_op_index = [] + for i, op in enumerate(global_block.ops): + op.desc.set_is_target(False) + if op.type == "fetch": # only remove fetch op here + need_to_remove_op_index.append(i) + for index in need_to_remove_op_index[::-1]: + global_block._remove_op(index) + + # if fetch_list have lod tensor + return_numpy = all([v.lod_level == 0 for v in fetch_list]) + + # try dump fetch_targets + feed_tensors = [] + assert len(feed_config.feeded_vars_names) == len( + feed_config.feeded_vars_dims) == len(feed_config.feeded_vars_types) + # check program vars and feed tensor shape in config + for i in range(len(feed_config.feeded_vars_names)): + var = inference_program.global_block().var( + feed_config.feeded_vars_names[i]) + if not isinstance(feed_config.feeded_vars_dims[i], (list, tuple)): + tensor_shape = (feed_config.feeded_vars_dims[i], ) + else: + tensor_shape = tuple(feed_config.feeded_vars_dims[i]) + feed_config.feeded_vars_dims[i] = tensor_shape + var_shape = var.shape[1:] + if tensor_shape != var_shape: + raise RuntimeError( + "feed variable '{}' shape not match. infer program shape: {}. feed tensor shape: {}". + format(feed_config.feeded_vars_names[i], var_shape, + tensor_shape)) + + if not feed_config.feeded_vars_filelist: + logger.info("generate random feed vars.") + for i in range(len(feed_config.feeded_vars_names)): + var = inference_program.global_block().var( + feed_config.feeded_vars_names[i]) + # create fake feed tensor. if lod_level > 1, should create_lod_tensor() + if var.lod_level == 0: + feed_tensors.append( + np.array( + np.random.random( + tuple([batch_size] + list( + feed_config.feeded_vars_dims[i]))), + dtype=feed_config.feeded_vars_types[i])) + elif var.lod_level == 1: + t = np.array( + np.random.random( + tuple([batch_size] + list( + feed_config.feeded_vars_dims[i]))), + dtype=feed_config.feeded_vars_types[i]) + feed_tensors.append( + fluid.create_lod_tensor(t, [[1] * batch_size], place)) + else: + raise RuntimeError( + "vars with lod_level >= 2 is not supported now in this infer program check tool." + ) + results = exe.run(inference_program, + feed={ + name: feed_tensors[i] + for i, name in enumerate(feed_name_list) + }, + fetch_list=fetch_list, + return_numpy=return_numpy) + else: + logger.info("load feed vars from files: {}.".format( + feed_config.feeded_vars_filelist)) + feed_vars = [ + inference_program.global_block().var( + feed_config.feeded_vars_names[i]) + for i in range(len(feed_config.feeded_vars_names)) + ] + feeder = fluid.DataFeeder(feed_list=feed_vars, place=place) + batch_feed = feed_gen(batch_size, feed_config.feeded_vars_dims, + feed_config.feeded_vars_filelist) + slots = [batch_feed] + results = exe.run(inference_program, + feed=feeder.feed(slots), + fetch_list=fetch_list, + return_numpy=return_numpy) + for i, v in enumerate(fetch_list): + logger.info("fetch_targets name: %s" % v.name) + logger.info("fetch_targets: {}".format(results[i])) + return results + + +def check_not_expected_ops(prog): + op_types_set = set() + for op in prog.global_block().ops: + if op.type in not_expected_op_types and op.type not in op_types_set: + logger.warning( + "find op type '{}' in program, please check if your program is pruned correctly !". + format(op.type)) + op_types_set.add(op.type) + + +def check_saved_vars_try_dump(dump_dir, + dump_prog_fn, + is_text_dump_program, + feed_config, + fetch_config, + batch_size=1, + save_filename=None): + dump_prog = load_program( + os.path.join(dump_dir, dump_prog_fn), is_text_dump_program) + saved_params = [ + v for v in dump_prog.list_vars() if fluid.io.is_persistable(v) + ] + logger.info("persistable vars in dump program: {}".format( + [v.name for v in saved_params])) + + check_not_expected_ops(dump_prog) + + return try_load_model_vars(dump_dir, dump_prog_fn, is_text_dump_program, + batch_size, feed_config, fetch_config, + save_filename, saved_params) + + +def parse_program(program, output_dir): + # persistable vars + output = {} + persistable_vars = [ + v for v in program.list_vars() if fluid.io.is_persistable(v) + ] + output["persistable_vars"] = [{ + 'name': str(v.name), + 'shape': str(v.shape), + 'lod_level': int(v.lod_level), + 'dtype': str(v.dtype), + 'type': str(v.type) + } for v in persistable_vars] + with open(os.path.join(output_dir, persistable_vars_out_fn), 'w') as f: + f.write("persistable vars:\n") + for var in output["persistable_vars"]: + f.write(str(var)) + f.write("\n") + + # all vars + all_vars = [v for v in program.list_vars()] + output["all_vars"] = [{ + 'name': str(v.name), + 'shape': str(v.shape), + 'lod_level': int(v.lod_level), + 'dtype': str(v.dtype) + } if v.type not in feed_fetch_type_list else { + 'name': str(v.name), + 'type': str(v.type) + } for v in all_vars] + with open(os.path.join(output_dir, all_vars_out_fn), 'w') as f: + f.write("all vars:\n") + for var in output["all_vars"]: + f.write(str(var)) + f.write("\n") + + # ops + ops = program.global_block().ops + output["ops"] = [{ + 'type': op.type, + 'input_arg_names': str(op.input_arg_names), + 'output_arg_names': str(op.output_arg_names) + } for op in ops] + with open(os.path.join(output_dir, ops_out_fn), 'w') as f: + f.write("ops:\n") + for op in output["ops"]: + f.write(str(op)) + f.write("\n") diff --git a/python/paddle/fluid/tests/unittests/test_fleet_utils.py b/python/paddle/fluid/tests/unittests/test_fleet_utils.py index a26b27ff4f..51c1237594 100644 --- a/python/paddle/fluid/tests/unittests/test_fleet_utils.py +++ b/python/paddle/fluid/tests/unittests/test_fleet_utils.py @@ -13,14 +13,43 @@ # limitations under the License. from __future__ import print_function +import paddle import paddle.fluid as fluid import unittest +import numpy as np +import tarfile +import tempfile +import os +import sys +from paddle.dataset.common import download, DATA_HOME import paddle.fluid.incubate.fleet.base.role_maker as role_maker from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler import fleet from paddle.fluid.incubate.fleet.utils.fleet_barrier_util import check_all_trainers_ready +from paddle.fluid.incubate.fleet.utils.fleet_util import FleetUtil +import paddle.fluid.incubate.fleet.utils.utils as utils class TestFleetUtils(unittest.TestCase): + proto_data_url = "https://fleet.bj.bcebos.com/fleet_util_data.tgz" + proto_data_md5 = "59b7f12fd9dc24b64ae8e4629523a92a" + module_name = "fleet_util_data" + pruned_dir = os.path.join("fleet_util_data", "pruned_model") + train_dir = os.path.join("fleet_util_data", "train_program") + + def download_files(self): + path = download(self.proto_data_url, self.module_name, + self.proto_data_md5) + print('data is downloaded at ' + path) + tar = tarfile.open(path) + unzip_folder = tempfile.mkdtemp() + tar.extractall(unzip_folder) + return unzip_folder + + def test_fleet_util_init(self): + fleet_util_pslib = FleetUtil() + fleet_util_transpiler = FleetUtil(mode="transpiler") + self.assertRaises(Exception, FleetUtil, "other") + def test_fleet_barrier(self): role = role_maker.UserDefinedRoleMaker( current_id=0, @@ -30,6 +59,165 @@ class TestFleetUtils(unittest.TestCase): fleet.init(role) check_all_trainers_ready("/ready_path/", 0) + def test_program_type_trans(self): + data_dir = self.download_files() + program_dir = os.path.join(data_dir, self.pruned_dir) + text_program = "pruned_main_program.pbtxt" + binary_program = "pruned_main_program.bin" + fleet_util = FleetUtil() + text_to_binary = fleet_util.program_type_trans(program_dir, + text_program, True) + binary_to_text = fleet_util.program_type_trans(program_dir, + binary_program, False) + self.assertTrue( + os.path.exists(os.path.join(program_dir, text_to_binary))) + self.assertTrue( + os.path.exists(os.path.join(program_dir, binary_to_text))) + + def test_parse_program_proto(self): + data_dir = self.download_files() + parse_program_file_path = os.path.join( + data_dir, + os.path.join(self.pruned_dir, "pruned_main_program.pbtxt")) + is_text_parse_program = True + parse_output_dir = os.path.join(data_dir, self.pruned_dir) + fleet_util = FleetUtil() + fleet_util.parse_program_proto(parse_program_file_path, + is_text_parse_program, parse_output_dir) + ops_log = os.path.join(parse_output_dir, "ops.log") + vars_log = os.path.join(parse_output_dir, "vars_all.log") + vars_persistable = os.path.join(parse_output_dir, + "vars_persistable.log") + self.assertTrue(os.path.exists(ops_log)) + self.assertTrue(os.path.exists(vars_log)) + self.assertTrue(os.path.exists(vars_persistable)) + + def test_check_vars_and_dump(self): + data_dir = self.download_files() + + class config: + pass + + feed_config = config() + feed_config.feeded_vars_names = ['concat_1.tmp_0', 'concat_2.tmp_0'] + feed_config.feeded_vars_dims = [682, 1199] + feed_config.feeded_vars_types = [np.float32, np.float32] + feed_config.feeded_vars_filelist = [ + os.path.join(data_dir, os.path.join(self.pruned_dir, "concat_1")), + os.path.join(data_dir, os.path.join(self.pruned_dir, "concat_2")) + ] + + fetch_config = config() + fetch_config.fetch_vars_names = ['similarity_norm.tmp_0'] + + conf = config() + conf.batch_size = 1 + conf.feed_config = feed_config + conf.fetch_config = fetch_config + conf.dump_model_dir = os.path.join(data_dir, self.pruned_dir) + conf.dump_program_filename = "pruned_main_program.pbtxt" + conf.is_text_dump_program = True + conf.save_params_filename = None + + fleet_util = FleetUtil() + # test saved var's shape + conf.dump_program_filename = "pruned_main_program.save_var_shape_not_match" + self.assertRaises(Exception, fleet_util.check_vars_and_dump, conf) + + # test program.proto without feed_op and fetch_op + conf.dump_program_filename = "pruned_main_program.no_feed_fetch" + results = fleet_util.check_vars_and_dump(conf) + self.assertTrue(len(results) == 1) + np.testing.assert_array_almost_equal( + results[0], np.array( + [[3.0590223e-07]], dtype=np.float32)) + + # test feed_var's shape + conf.dump_program_filename = "pruned_main_program.feed_var_shape_not_match" + self.assertRaises(Exception, fleet_util.check_vars_and_dump, conf) + + # test correct case with feed_vars_filelist + conf.dump_program_filename = "pruned_main_program.pbtxt" + results = fleet_util.check_vars_and_dump(conf) + self.assertTrue(len(results) == 1) + np.testing.assert_array_almost_equal( + results[0], np.array( + [[3.0590223e-07]], dtype=np.float32)) + + # test correct case without feed_vars_filelist + conf.feed_config.feeded_vars_filelist = None + # test feed var with lod_level >= 2 + conf.dump_program_filename = "pruned_main_program.feed_lod2" + self.assertRaises(Exception, fleet_util.check_vars_and_dump, conf) + + conf.dump_program_filename = "pruned_main_program.pbtxt" + results = fleet_util.check_vars_and_dump(conf) + self.assertTrue(len(results) == 1) + + def test_check_two_programs(self): + data_dir = self.download_files() + + class config: + pass + + conf = config() + conf.train_prog_path = os.path.join( + data_dir, os.path.join(self.train_dir, "join_main_program.pbtxt")) + conf.is_text_train_program = True + + # test not match + conf.pruned_prog_path = os.path.join( + data_dir, + os.path.join(self.pruned_dir, + "pruned_main_program.save_var_shape_not_match")) + conf.is_text_pruned_program = True + conf.draw = False + fleet_util = FleetUtil() + res = fleet_util.check_two_programs(conf) + self.assertFalse(res) + + # test match + conf.pruned_prog_path = os.path.join( + data_dir, + os.path.join(self.pruned_dir, "pruned_main_program.pbtxt")) + if sys.platform == 'win32' or sys.platform == 'sys.platform': + conf.draw = False + else: + conf.draw = True + conf.draw_out_name = "pruned_check" + res = fleet_util.check_two_programs(conf) + self.assertTrue(res) + + def test_draw_program(self): + if sys.platform == 'win32' or sys.platform == 'sys.platform': + pass + else: + data_dir = self.download_files() + program_path = os.path.join( + data_dir, + os.path.join(self.train_dir, "join_main_program.pbtxt")) + is_text = True + program = utils.load_program(program_path, is_text) + output_dir = os.path.join(data_dir, self.train_dir) + output_filename_1 = "draw_prog_1" + output_filename_2 = "draw_prog_2" + fleet_util = FleetUtil() + fleet_util.draw_from_program_file(program_path, is_text, output_dir, + output_filename_1) + fleet_util.draw_from_program(program, output_dir, output_filename_2) + self.assertTrue( + os.path.exists( + os.path.join(output_dir, output_filename_1 + ".dot"))) + self.assertTrue( + os.path.exists( + os.path.join(output_dir, output_filename_1 + ".pdf"))) + self.assertTrue( + os.path.exists( + os.path.join(output_dir, output_filename_2 + ".dot"))) + self.assertTrue( + os.path.exists( + os.path.join(output_dir, output_filename_2 + ".pdf"))) + if __name__ == '__main__': unittest.main() -- GitLab