diff --git a/paddle/inference/inference.cc b/paddle/inference/inference.cc
index 49001778808173b82865a4b6632a6b175ef96242..09268ffb3a1410b22f1b7d997a5cc0e4176b6d55 100644
--- a/paddle/inference/inference.cc
+++ b/paddle/inference/inference.cc
@@ -19,14 +19,10 @@ limitations under the License. */
 #include "paddle/framework/init.h"
 #include "paddle/framework/scope.h"
 
-#ifdef PADDLE_USE_PTOOLS
-#include "chooseser.h"
-#endif
-
 namespace paddle {
 
 void InferenceEngine::LoadInferenceModel(const std::string& dirname) {
-  std::string model_filename = dirname + "/__model__.dat";
+  std::string model_filename = dirname + "/__model__";
   LOG(INFO) << "loading model from " << model_filename;
   std::ifstream inputfs(model_filename, std::ios::in | std::ios::binary);
   std::string program_desc_str;
@@ -52,39 +48,15 @@ void InferenceEngine::LoadInferenceModel(const std::string& dirname) {
   }
 }
 
-void InferenceEngine::LoadInferenceModel(
-    const std::string& dirname,
-    const std::vector<std::string>& feed_var_names,
-    const std::vector<std::string>& fetch_var_names) {
-  std::string model_filename = dirname + "/__model__.dat";
-  LOG(INFO) << "loading model from " << model_filename;
-  std::ifstream inputfs(model_filename, std::ios::in | std::ios::binary);
-  std::string program_desc_str;
-  inputfs.seekg(0, std::ios::end);
-  program_desc_str.resize(inputfs.tellg());
-  inputfs.seekg(0, std::ios::beg);
-  LOG(INFO) << "program_desc_str's size: " << program_desc_str.size();
-  inputfs.read(&program_desc_str[0], program_desc_str.size());
-  inputfs.close();
-
-  program_ = new framework::ProgramDesc(program_desc_str);
-  GenerateLoadProgram(dirname);
-
-  if (feed_var_names.empty() || fetch_var_names.empty()) {
-    LOG(FATAL) << "Please specify the feed_var_names and fetch_var_names.";
-  }
-  feed_var_names_ = feed_var_names;
-  fetch_var_names_ = fetch_var_names;
-  PrependFeedOp();
-  AppendFetchOp();
-}
-
 bool InferenceEngine::IsParameter(const framework::VarDesc* var) {
-  if (var->Persistable() && var->Name() != "feed" && var->Name() != "fetch") {
+  if (var->Persistable()) {
     // There are many unreachable variables in the program
     for (size_t i = 0; i < program_->Size(); ++i) {
       const framework::BlockDesc& block = program_->Block(i);
       for (auto* op : block.AllOps()) {
+        if (op->Type() == "feed") {
+          continue;
+        }
         for (auto input_argument_name : op->InputArgumentNames()) {
           if (input_argument_name == var->Name()) {
             return true;
diff --git a/paddle/inference/inference.h b/paddle/inference/inference.h
index 7fc09cb9e539a65a8cd3cceb1543bc7d111c22b3..26f259824b945e260b370ced9d065842264075d5 100644
--- a/paddle/inference/inference.h
+++ b/paddle/inference/inference.h
@@ -29,9 +29,6 @@ public:
   }
 
   void LoadInferenceModel(const std::string& dirname);
-  void LoadInferenceModel(const std::string& dirname,
-                          const std::vector<std::string>& feed_var_names,
-                          const std::vector<std::string>& fetch_var_names);
 
   void Execute(const std::vector<framework::LoDTensor>& feeds,
                std::vector<framework::LoDTensor>& fetchs);
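The revised `IsParameter` drops the name-based special cases for "feed" and "fetch"; instead it skips feed operators while scanning for readers of a persistable variable, since a feed op only writes its output. A rough Python sketch of the same check, assuming fluid-style `program.blocks`, `op.type`, and `op.input_arg_names` attributes (these attribute names are an assumption for illustration, not part of this patch):

```python
def is_parameter(var, program):
    # A persistable variable counts as a parameter only if some op other
    # than 'feed' reads it as an input.
    if not var.persistable:
        return False
    for block in program.blocks:
        for op in block.ops:
            if op.type == 'feed':
                continue  # feed only writes the variable; skip it
            if var.name in op.input_arg_names:
                return True
    return False  # persistable but unreachable: not a parameter
```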
diff --git a/python/paddle/v2/fluid/executor.py b/python/paddle/v2/fluid/executor.py
index 9d5ed9571a2fa0a871a25e43b23b1a3c3a6102db..9f48815b8b84426c7d539af4e7d45ea47e69d4d9 100644
--- a/python/paddle/v2/fluid/executor.py
+++ b/python/paddle/v2/fluid/executor.py
@@ -68,6 +68,84 @@ def as_numpy(tensor):
     return ans
 
 
+def has_feed_operators(block, feed_targets, feed_holder_name):
+    """Check whether the block already has feed operators.
+
+    Return False if the block does not have any feed operators.
+    If some feed operators have been prepended to the block, check that
+    the info contained in these feed operators matches the feed_targets
+    and feed_holder_name. Raise an exception when any mismatch is found.
+    Return True when the block has feed operators with matching info.
+
+    Args:
+        block: a block instance (typically the global block of a program)
+        feed_targets: a dictionary of {feed_target_name: feed_target_data}
+        feed_holder_name: the name of the variable that holds the data of
+            all feed targets. The type of this feed_holder variable is
+            FEED_MINIBATCH, which is essentially vector<LoDTensor>.
+
+    Returns:
+        A boolean value that indicates whether a block has feed operators
+        that match the info contained in feed_targets and feed_holder_name.
+    """
+
+    feed_count = 0
+    for op in block.ops:
+        if op.desc.type() == 'feed':
+            feed_count += 1
+            assert op.desc.input('X')[0] == feed_holder_name
+            feed_target_name = op.desc.output('Out')[0]
+            if feed_target_name not in feed_targets:
+                raise Exception("'feed_targets' does not have {} variable".
+                                format(feed_target_name))
+        else:
+            break
+    if feed_count > 0 and feed_count != len(feed_targets):
+        raise Exception(
+            "Feed operators in program desc do not match 'feed_targets'")
+    return feed_count > 0
+
+
+def has_fetch_operators(block, fetch_targets, fetch_holder_name):
+    """Check whether the block already has fetch operators.
+
+    Return False if the block does not have any fetch operators.
+    If some fetch operators have been appended to the block, check that
+    the info contained in these fetch operators matches the fetch_targets
+    and fetch_holder_name. Raise an exception when any mismatch is found.
+    Return True when the block has fetch operators with matching info.
+
+    Args:
+        block: a block instance (typically the global block of a program)
+        fetch_targets: a list of the Variables we want to fetch results
+            from (each is matched against a fetch op by name and 'col' index)
+        fetch_holder_name: the name of the variable that holds the data of
+            all fetch targets. The type of this fetch_holder variable is
+            FETCH_LIST, which is essentially vector<LoDTensor>.
+
+    Returns:
+        A boolean value that indicates whether a block has fetch operators
+        that match the info contained in fetch_targets and fetch_holder_name.
+    """
+
+    fetch_count = 0
+    for op in block.ops:
+        if op.desc.type() == 'fetch':
+            fetch_count += 1
+            assert op.desc.output('Out')[0] == fetch_holder_name
+            fetch_target_name = op.desc.input('X')[0]
+            if fetch_target_name not in [
+                    var.desc.name() for var in fetch_targets
+            ]:
+                raise Exception("'fetch_targets' does not have {} variable".
+                                format(fetch_target_name))
+            idx = op.desc.attr('col')
+            assert fetch_target_name == fetch_targets[idx].desc.name()
+    if fetch_count > 0 and fetch_count != len(fetch_targets):
+        raise Exception(
+            "Fetch operators in program desc do not match 'fetch_targets'")
+    return fetch_count > 0
+
+
 class Executor(object):
     def __init__(self, places):
         if not isinstance(places, list) and not isinstance(places, tuple):
@@ -147,33 +225,50 @@ class Executor(object):
 
         program = program.clone()
         global_block = program.global_block()
-        feed_var = global_block.create_var(
-            name=feed_var_name,
-            type=core.VarDesc.VarType.FEED_MINIBATCH,
-            persistable=True)
-
-        for i, name in enumerate(feed):
-            out = global_block.var(name)
-            global_block.prepend_op(
-                'feed',
-                inputs={'X': [feed_var]},
-                outputs={'Out': [out]},
-                attrs={'col': i})
-            cur_feed = feed[name]
-            if not isinstance(cur_feed, core.LoDTensor):
-                cur_feed = self.aslodtensor(cur_feed)
-            core.set_feed_variable(scope, cur_feed, feed_var.name, i)
-
-        fetch_var = global_block.create_var(
-            name=fetch_var_name,
-            type=core.VarDesc.VarType.FETCH_LIST,
-            persistable=True)
-        for i, var in enumerate(fetch_list):
-            global_block.append_op(
-                type='fetch',
-                inputs={'X': [var]},
-                outputs={'Out': [fetch_var]},
-                attrs={'col': i})
+
+        if feed_var_name in global_block.vars:
+            feed_var = global_block.var(feed_var_name)
+        else:
+            feed_var = global_block.create_var(
+                name=feed_var_name,
+                type=core.VarDesc.VarType.FEED_MINIBATCH,
+                persistable=True)
+
+        if fetch_var_name in global_block.vars:
+            fetch_var = global_block.var(fetch_var_name)
+        else:
+            fetch_var = global_block.create_var(
+                name=fetch_var_name,
+                type=core.VarDesc.VarType.FETCH_LIST,
+                persistable=True)
+
+        if not has_feed_operators(global_block, feed, feed_var_name):
+            for i, name in enumerate(feed):
+                out = global_block.var(name)
+                global_block.prepend_op(
+                    type='feed',
+                    inputs={'X': [feed_var]},
+                    outputs={'Out': [out]},
+                    attrs={'col': i})
+
+        for op in global_block.ops:
+            if op.desc.type() == 'feed':
+                feed_target_name = op.desc.output('Out')[0]
+                cur_feed = feed[feed_target_name]
+                if not isinstance(cur_feed, core.LoDTensor):
+                    cur_feed = self.aslodtensor(cur_feed)
+                idx = op.desc.attr('col')
+                core.set_feed_variable(scope, cur_feed, feed_var_name, idx)
+            else:
+                break
+
+        if not has_fetch_operators(global_block, fetch_list, fetch_var_name):
+            for i, var in enumerate(fetch_list):
+                global_block.append_op(
+                    type='fetch',
+                    inputs={'X': [var]},
+                    outputs={'Out': [fetch_var]},
+                    attrs={'col': i})
 
         self.executor.run(program.desc, scope, 0, True, True)
         outs = [
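Note how the rewritten `Executor.run` routes feed data by each feed op's `'col'` attribute rather than by dict iteration order, so pre-existing feed operators still receive the right tensors. A pure-Python illustration of that routing (the names and tensor placeholders are made up for the example):

```python
feed = {'image': 'img_tensor', 'label': 'lbl_tensor'}
# (out_name, col) pairs as they might appear in previously prepended feed
# ops; the op order need not match the iteration order of the feed dict.
feed_ops = [('label', 1), ('image', 0)]

feed_holder = [None] * len(feed)
for out_name, col in feed_ops:
    # stands in for core.set_feed_variable(scope, cur_feed, feed_var_name, col)
    feed_holder[col] = feed[out_name]

print(feed_holder)  # ['img_tensor', 'lbl_tensor']
```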
diff --git a/python/paddle/v2/fluid/io.py b/python/paddle/v2/fluid/io.py
index 606033af43a368293629d733f92b99d08a108187..d56ec45c538b580f5520bc060b4b339bb1be0539 100644
--- a/python/paddle/v2/fluid/io.py
+++ b/python/paddle/v2/fluid/io.py
@@ -13,7 +13,6 @@
 # limitations under the License.
 
 import os
-import cPickle as pickle
 
 from paddle.v2.fluid.evaluator import Evaluator
 from paddle.v2.fluid.framework import Program, Parameter, default_main_program, Variable
@@ -200,12 +199,16 @@ def get_inference_program(target_vars, main_program=None):
     return inference_program
 
 
-def prepend_feed_ops(inference_program, feeded_var_names):
+def prepend_feed_ops(inference_program,
+                     feed_target_names,
+                     feed_holder_name='feed'):
     global_block = inference_program.global_block()
     feed_var = global_block.create_var(
-        name='feed', type=core.VarDesc.VarType.FEED_MINIBATCH, persistable=True)
+        name=feed_holder_name,
+        type=core.VarDesc.VarType.FEED_MINIBATCH,
+        persistable=True)
 
-    for i, name in enumerate(feeded_var_names):
+    for i, name in enumerate(feed_target_names):
         out = global_block.var(name)
         global_block.prepend_op(
             type='feed',
@@ -214,12 +217,16 @@ def prepend_feed_ops(inference_program, feeded_var_names):
             attrs={'col': i})
 
 
-def append_fetch_ops(inference_program, fetch_var_names):
+def append_fetch_ops(inference_program,
+                     fetch_target_names,
+                     fetch_holder_name='fetch'):
     global_block = inference_program.global_block()
     fetch_var = global_block.create_var(
-        name='fetch', type=core.VarDesc.VarType.FETCH_LIST, persistable=True)
+        name=fetch_holder_name,
+        type=core.VarDesc.VarType.FETCH_LIST,
+        persistable=True)
 
-    for i, name in enumerate(fetch_var_names):
+    for i, name in enumerate(fetch_target_names):
         global_block.append_op(
             type='fetch',
             inputs={'X': [name]},
@@ -269,21 +276,12 @@ def save_inference_model(dirname,
     inference_program = pruned_program.inference_optimize()
     fetch_var_names = [v.name for v in target_vars]
 
-    model_file_name = dirname + "/__model__"
-    with open(model_file_name, "w") as f:
-        pickle.dump({
-            "program_desc_str": inference_program.desc.serialize_to_string(),
-            "feed_var_names": feeded_var_names,
-            "fetch_var_names": fetch_var_names
-        }, f, -1)
-
     prepend_feed_ops(inference_program, feeded_var_names)
     append_fetch_ops(inference_program, fetch_var_names)
 
-    # Save only programDesc of inference_program in binary format
-    # in another file: __model__.dat
-    with open(model_file_name + ".dat", "wb") as fp:
-        fp.write(inference_program.desc.serialize_to_string())
+    model_file_name = dirname + "/__model__"
+    with open(model_file_name, "wb") as f:
+        f.write(inference_program.desc.serialize_to_string())
 
     save_params(executor, dirname, main_program)
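With this change, `__model__` holds nothing but the binary serialized `ProgramDesc`, so it can be read back directly, with no pickle wrapper. A minimal sketch of the new on-disk contract, assuming a model was previously saved under the directory used in the test at the end of this patch:

```python
from paddle.v2.fluid.framework import Program

# __model__ is now the raw serialized ProgramDesc; deserialize it directly.
with open("./recognize_digits_mlp.inference.model/__model__", "rb") as f:
    program = Program.parse_from_string(f.read())
```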
@@ -306,6 +304,24 @@ def load_persistables_if_exist(executor, dirname, main_program=None):
         predicate=_is_presistable_and_exist_)
 
 
+def get_feed_targets_names(program):
+    feed_targets_names = []
+    global_block = program.global_block()
+    for op in global_block.ops:
+        if op.desc.type() == 'feed':
+            feed_targets_names.insert(0, op.desc.output('Out')[0])
+    return feed_targets_names
+
+
+def get_fetch_targets_names(program):
+    fetch_targets_names = []
+    global_block = program.global_block()
+    for op in global_block.ops:
+        if op.desc.type() == 'fetch':
+            fetch_targets_names.append(op.desc.input('X')[0])
+    return fetch_targets_names
+
+
 def load_inference_model(dirname, executor):
     """
     Load inference model from a directory
@@ -313,24 +329,28 @@ def load_inference_model(dirname, executor):
 
     :param dirname: directory path
    :param executor: the executor that loads the inference model
-    :return: [program, feed_var_names, fetch_var_names]
+    :return: [program, feed_target_names, fetch_targets]
             program: program especially for inference.
-            feeded_var_names: Names of variables that need to feed data
-            fetch_vars: Variables from which we can get inference results.
+            feed_target_names: names of the variables that need to be fed data
+            fetch_targets: variables from which we can get inference results
     """
     if not os.path.isdir(dirname):
        raise ValueError("There is no directory named '%s'" % dirname)
 
     model_file_name = dirname + "/__model__"
-    model = pickle.load(open(model_file_name, "r"))
-    program_desc_str = model["program_desc_str"]
-    feed_var_names = model["feed_var_names"]
-    fetch_var_names = model["fetch_var_names"]
+    with open(model_file_name, "rb") as f:
+        program_desc_str = f.read()
+
     program = Program.parse_from_string(program_desc_str)
     load_persistables_if_exist(executor, dirname, program)
-    fetch_vars = [program.global_block().var(name) for name in fetch_var_names]
 
-    return [program, feed_var_names, fetch_vars]
+    feed_target_names = get_feed_targets_names(program)
+    fetch_target_names = get_fetch_targets_names(program)
+    fetch_targets = [
+        program.global_block().var(name) for name in fetch_target_names
+    ]
+
+    return [program, feed_target_names, fetch_targets]
 
 
 def get_parameter_value(para, executor):
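`get_feed_targets_names` uses `insert(0, ...)` because `prepend_feed_ops` pushes each new feed op to the front of the block, leaving the feed ops in reverse target order. A pure-Python illustration (no Paddle needed) of why a front-to-back scan plus `insert(0, ...)` restores the original order:

```python
feed_target_names = ['x', 'y', 'z']

ops = []  # each entry stands for the 'Out' name of one feed op
for name in feed_target_names:
    ops.insert(0, name)  # prepend_op puts every new op at the front
print(ops)  # ['z', 'y', 'x'] -- block order is the reverse of target order

recovered = []
for out_name in ops:  # scan the block front to back, as the loader does
    recovered.insert(0, out_name)
print(recovered)  # ['x', 'y', 'z'] -- original feed order restored
```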
diff --git a/python/paddle/v2/fluid/tests/book/test_recognize_digits_mlp.py b/python/paddle/v2/fluid/tests/book/test_recognize_digits_mlp.py
index 8776a65bf804e93dfeb295ecca34fac0840b0a90..236ee4f3398538403228a00ee9c4b72c7c8231cf 100644
--- a/python/paddle/v2/fluid/tests/book/test_recognize_digits_mlp.py
+++ b/python/paddle/v2/fluid/tests/book/test_recognize_digits_mlp.py
@@ -91,6 +91,21 @@ for pass_id in range(PASS_NUM):
             fluid.io.save_inference_model(
                 "./recognize_digits_mlp.inference.model/", ["x"], [predict],
                 exe)
-            exit(0)
-
-exit(1)
+            break
+
+# Use load_inference_model to obtain the inference program desc,
+# the feed_target_names (the names of variables that will be fed
+# data using feed operators), and the fetch_targets (variables that
+# we want to obtain data from using fetch operators).
+[infer_prog, feed_target_names, fetch_targets] = fluid.io.load_inference_model(
+    "./recognize_digits_mlp.inference.model/", exe)
+
+tensor_x = np.random.rand(1, 784).astype("float32")
+# Construct feed as a dictionary of {feed_target_name: feed_target_data}
+# and results will contain a list of data corresponding to fetch_targets.
+results = exe.run(infer_prog,
+                  feed={feed_target_names[0]: tensor_x},
+                  fetch_list=fetch_targets)
+print(results[0])
+
+exit(0)
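A program produced by `load_inference_model` already carries its feed and fetch operators, and `has_feed_operators`/`has_fetch_operators` let `Executor.run` detect and reuse them instead of inserting a second set. A sketch of how the test above could continue, reusing the same `exe`, `infer_prog`, `feed_target_names`, and `fetch_targets`:

```python
# Repeated inference on the same loaded program; each run reuses the
# feed/fetch operators already present in infer_prog.
for _ in range(3):
    tensor_x = np.random.rand(1, 784).astype("float32")
    results = exe.run(infer_prog,
                      feed={feed_target_names[0]: tensor_x},
                      fetch_list=fetch_targets)
    print(results[0])
```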