diff --git a/paddle/framework/block_desc.cc b/paddle/framework/block_desc.cc
index dd2ed87252102aee6d384f37365d19305f19b281..3e344ea3790f57b0f53f36a40263dcdd326e67a9 100644
--- a/paddle/framework/block_desc.cc
+++ b/paddle/framework/block_desc.cc
@@ -162,9 +162,8 @@ BlockDesc::BlockDesc(const BlockDesc &other, proto::BlockDesc *desc,
     : prog_(prog), desc_(desc) {
   need_update_ = true;
   for (auto &op : other.ops_) {
-    ops_.emplace_back(new OpDesc(*op, this));
+    ops_.emplace_back(new OpDesc(*op->Proto(), prog, this));
   }
-
   for (auto &it : other.vars_) {
     auto *var = new VarDesc(*it.second);
     vars_[it.first].reset(var);
diff --git a/paddle/framework/op_desc.cc b/paddle/framework/op_desc.cc
index ea4028750248ec47f5094a67f736fb217216af6d..b51afe499bbc0e6b727aeeb4334f56e400ea81a5 100644
--- a/paddle/framework/op_desc.cc
+++ b/paddle/framework/op_desc.cc
@@ -125,11 +125,10 @@ OpDesc::OpDesc(const proto::OpDesc &desc, ProgramDesc *prog, BlockDesc *block)
   // restore attrs_
   for (const proto::OpDesc::Attr &attr : desc_.attrs()) {
     std::string attr_name = attr.name();
+    // The sub_block referred to by a BLOCK attr may not have been added
+    // to the ProgramDesc yet, so we skip setting BLOCK attrs here.
     if (attr.type() != proto::AttrType::BLOCK) {
       attrs_[attr_name] = GetAttrValue(attr);
-    } else {
-      auto bid = attr.block_idx();
-      attrs_[attr_name] = prog->MutableBlock(bid);
     }
   }
   this->block_ = block;
diff --git a/paddle/framework/program_desc.cc b/paddle/framework/program_desc.cc
index 15ea4035c6e6193105b621210a900e74d1466941..0e937dda4e185590648962a6d4f827eea21eb620 100644
--- a/paddle/framework/program_desc.cc
+++ b/paddle/framework/program_desc.cc
@@ -43,11 +43,20 @@ ProgramDesc::ProgramDesc() {
 
 ProgramDesc::ProgramDesc(const ProgramDesc &o) {
   desc_ = o.desc_;
-
   for (int i = 0; i < desc_.blocks_size(); ++i) {
     auto *block = desc_.mutable_blocks(i);
     blocks_.emplace_back(new BlockDesc(*o.blocks_[i], block, this));
   }
+  for (auto &block : blocks_) {
+    for (auto *op : block->AllOps()) {
+      for (const auto &attr : op->Proto()->attrs()) {
+        if (attr.type() == proto::AttrType::BLOCK) {
+          size_t blk_idx = attr.block_idx();
+          op->SetBlockAttr(attr.name(), *this->MutableBlock(blk_idx));
+        }
+      }
+    }
+  }
 }
 
 ProgramDesc::ProgramDesc(const proto::ProgramDesc &desc) {
@@ -55,6 +64,16 @@ ProgramDesc::ProgramDesc(const proto::ProgramDesc &desc) {
   for (auto &block_desc : *desc_.mutable_blocks()) {
     blocks_.emplace_back(new BlockDesc(this, &block_desc));
   }
+  for (auto &block : blocks_) {
+    for (auto *op : block->AllOps()) {
+      for (const auto &attr : op->Proto()->attrs()) {
+        if (attr.type() == proto::AttrType::BLOCK) {
+          size_t blk_idx = attr.block_idx();
+          op->SetBlockAttr(attr.name(), *this->MutableBlock(blk_idx));
+        }
+      }
+    }
+  }
 }
 
 ProgramDesc::ProgramDesc(const std::string &binary_str) {
diff --git a/paddle/framework/prune.cc b/paddle/framework/prune.cc
index bff8e0bceaca9749101b2c45edddba526d565624..ddd6b993d40f72cba919fad95318f70409c98bca 100644
--- a/paddle/framework/prune.cc
+++ b/paddle/framework/prune.cc
@@ -49,11 +49,28 @@ bool IsTarget(const proto::OpDesc& op_desc) {
   return false;
 }
 
-void prune_impl(const proto::ProgramDesc& input, proto::ProgramDesc* output,
-                int block_id) {
-  // TODO(tonyyang-svail):
-  //    - will change to use multiple blocks for RNN op and Cond Op
-
+int GetSubBlockIndex(const proto::OpDesc& op_desc) {
+  for (auto& attr : op_desc.attrs()) {
+    if (attr.type() == proto::AttrType::BLOCK) {
+      PADDLE_ENFORCE(attr.has_block_idx());
+      return attr.block_idx();
+    }
+  }
+  return -1;
+}
+
+bool HasSubBlock(const proto::OpDesc& op_desc) {
+  return GetSubBlockIndex(op_desc) > 0;
+}
+
+// block_id is the index of the current block in the input desc
+// parent_block_id is the index of the parent of the current block
+// in the output desc; -1 means the current block is the global block
+// dependent_vars is passed recursively from the parent block to
+// the child block to help pruning
+void prune_impl(const proto::ProgramDesc& input, proto::ProgramDesc* output,
+                int block_id, int parent_block_id,
+                std::set<std::string>& dependent_vars) {
   auto& block = input.blocks(block_id);
   auto& ops = block.ops();

@@ -72,11 +89,9 @@ void prune_impl(const proto::ProgramDesc& input, proto::ProgramDesc* output,
     expect_fetch = (op_desc.type() == kFetchOpType);
   }

-  std::set<std::string> dependent_vars;
   std::vector<bool> should_run;
   for (auto op_iter = ops.rbegin(); op_iter != ops.rend(); ++op_iter) {
     auto& op_desc = *op_iter;
-
     if (IsTarget(op_desc) || HasDependentVar(op_desc, dependent_vars)) {
       // insert its input to the dependency graph
       for (auto& var : op_desc.inputs()) {
@@ -84,7 +99,6 @@ void prune_impl(const proto::ProgramDesc& input, proto::ProgramDesc* output,
           dependent_vars.insert(argu);
         }
       }
-
       should_run.push_back(true);
     } else {
       should_run.push_back(false);
@@ -95,45 +109,81 @@ void prune_impl(const proto::ProgramDesc& input, proto::ProgramDesc* output,
   // we reverse the should_run vector
   std::reverse(should_run.begin(), should_run.end());

-  *output = input;
-  auto* op_field = output->mutable_blocks(block_id)->mutable_ops();
+  // copy the current block from input to output
+  auto* block_field = output->mutable_blocks();
+  *block_field->Add() = input.blocks(block_id);
+
+  int output_block_id = output->blocks_size() - 1;
+  auto* output_block = output->mutable_blocks(output_block_id);
+  output_block->set_idx(output_block_id);
+  output_block->set_parent_idx(parent_block_id);
+
+  auto* op_field = output_block->mutable_ops();
   op_field->Clear();
   for (size_t i = 0; i < should_run.size(); ++i) {
     if (should_run[i]) {
-      *op_field->Add() = input.blocks(block_id).ops(i);
+      auto* op = op_field->Add();
+      *op = input.blocks(block_id).ops(i);
+      if (HasSubBlock(*op)) {
+        // create sub_block_dependent_vars here to help prune the sub block
+        std::set<std::string> sub_block_dependent_vars;
+        for (auto& var : op->inputs()) {
+          for (auto& argu : var.arguments()) {
+            sub_block_dependent_vars.insert(argu);
+          }
+        }
+        for (auto& var : op->outputs()) {
+          for (auto& argu : var.arguments()) {
+            sub_block_dependent_vars.insert(argu);
+          }
+        }
+        // GetSubBlockIndex(*op) is the index of the sub-block in the input desc
+        // output_block_id is the index of the current block in the output desc
+        prune_impl(input, output, GetSubBlockIndex(*op), output_block_id,
+                   sub_block_dependent_vars);
+      }
     }
   }

   // remove the VarDescs in BlockDesc that are not referenced in
   // the pruned OpDescs
   std::unordered_map<std::string, proto::VarDesc> var_map;
-  auto* var_field = output->mutable_blocks(block_id)->mutable_vars();
+  auto* var_field = output->mutable_blocks(output_block_id)->mutable_vars();
   for (const auto& var : *var_field) {
     var_map[var.name()] = var;
   }

-  var_field->Clear();
+  std::set<std::string> var_names;
   for (const auto& op : *op_field) {
-    // add VarDescs of all input arguments for each OpDesc
     auto& input_field = op.inputs();
     for (auto& input_var : input_field) {
       for (auto& arg : input_var.arguments()) {
-        *var_field->Add() = var_map[arg];
+        if (var_map.count(arg) != 0) {
+          var_names.insert(arg);
+        }
       }
     }
-    // add VarDescs of all output arguments for each OpDesc
     auto& output_field = op.outputs();
     for (auto& output_var : output_field) {
       for (auto& arg : output_var.arguments()) {
-        *var_field->Add() = var_map[arg];
+        if (var_map.count(arg) != 0) {
+          var_names.insert(arg);
+        }
       }
     }
   }
+
+  var_field->Clear();
+  for (const auto& name : var_names) {
+    *var_field->Add() = var_map[name];
+  }
 }

 // TODO(fengjiayi): Prune() could be inplaced to avoid unnecessary copies
 void Prune(const proto::ProgramDesc& input, proto::ProgramDesc* output) {
-  prune_impl(input, output, 0);
+  std::set<std::string> dependent_vars;
+  output->clear_blocks();
+  prune_impl(input, output, 0, -1, dependent_vars);
 }

 void inference_optimize_impl(const proto::ProgramDesc& input,
diff --git a/paddle/inference/tests/book/CMakeLists.txt b/paddle/inference/tests/book/CMakeLists.txt
index 63afeb18aebdf446c01cd4fdac13d238467801e4..0a96829bdd20f5dcb0c3fed501d27c27f2f73b17 100644
--- a/paddle/inference/tests/book/CMakeLists.txt
+++ b/paddle/inference/tests/book/CMakeLists.txt
@@ -27,3 +27,4 @@ endfunction(inference_test)
 inference_test(recognize_digits ARGS mlp)
 inference_test(image_classification ARGS vgg resnet)
 inference_test(label_semantic_roles)
+inference_test(rnn_encoder_decoder)
diff --git a/paddle/inference/tests/book/test_inference_rnn_encoder_decoder.cc b/paddle/inference/tests/book/test_inference_rnn_encoder_decoder.cc
new file mode 100644
index 0000000000000000000000000000000000000000..9bfc0407b7f2732a14e7ac0f319a3d39b9e641bc
--- /dev/null
+++ b/paddle/inference/tests/book/test_inference_rnn_encoder_decoder.cc
@@ -0,0 +1,67 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include <gtest/gtest.h>
+#include "gflags/gflags.h"
+#include "test_helper.h"
+
+DEFINE_string(dirname, "", "Directory of the inference model.");
+
+TEST(inference, rnn_encoder_decoder) {
+  if (FLAGS_dirname.empty()) {
+    LOG(FATAL) << "Usage: ./example --dirname=path/to/your/model";
+  }
+
+  LOG(INFO) << "FLAGS_dirname: " << FLAGS_dirname << std::endl;
+  std::string dirname = FLAGS_dirname;
+
+  // 0. Call `paddle::framework::InitDevices()` to initialize all the devices.
+  // In unittests, this is done in paddle/testing/paddle_gtest_main.cc
+
+  paddle::framework::LoDTensor word_data, trg_word;
+  paddle::framework::LoD lod{{0, 4, 10}};
+
+  SetupLoDTensor(
+      word_data, lod, static_cast<int64_t>(0), static_cast<int64_t>(1));
+  SetupLoDTensor(
+      trg_word, lod, static_cast<int64_t>(0), static_cast<int64_t>(1));
+
+  std::vector<paddle::framework::LoDTensor*> cpu_feeds;
+  cpu_feeds.push_back(&word_data);
+  cpu_feeds.push_back(&trg_word);
+
+  paddle::framework::LoDTensor output1;
+  std::vector<paddle::framework::LoDTensor*> cpu_fetchs1;
+  cpu_fetchs1.push_back(&output1);
+
+  // Run inference on CPU
+  TestInference<paddle::platform::CPUPlace, float>(
+      dirname, cpu_feeds, cpu_fetchs1);
+  LOG(INFO) << output1.lod();
+  LOG(INFO) << output1.dims();
+
+#ifdef PADDLE_WITH_CUDA
+  paddle::framework::LoDTensor output2;
+  std::vector<paddle::framework::LoDTensor*> cpu_fetchs2;
+  cpu_fetchs2.push_back(&output2);
+
+  // Run inference on CUDA GPU
+  TestInference<paddle::platform::CUDAPlace, float>(
+      dirname, cpu_feeds, cpu_fetchs2);
+  LOG(INFO) << output2.lod();
+  LOG(INFO) << output2.dims();
+
+  CheckError<float>(output1, output2);
+#endif
+}
diff --git a/python/paddle/v2/fluid/tests/book/test_rnn_encoder_decoder.py b/python/paddle/v2/fluid/tests/book/test_rnn_encoder_decoder.py
index fdc60861760163d2ebad3b050e551929321baafd..7fe43c680ca9319682c42836986308856185a464 100644
--- a/python/paddle/v2/fluid/tests/book/test_rnn_encoder_decoder.py
+++ b/python/paddle/v2/fluid/tests/book/test_rnn_encoder_decoder.py
@@ -18,6 +18,10 @@ import paddle.v2.fluid as fluid
 import paddle.v2.fluid.core as core
 import paddle.v2.fluid.framework as framework
 import paddle.v2.fluid.layers as layers
+import contextlib
+import math
+import sys
+import unittest
 from paddle.v2.fluid.executor import Executor

 dict_size = 30000
@@ -145,7 +149,7 @@ def seq_to_seq_net():
     cost = fluid.layers.cross_entropy(input=prediction, label=label)
     avg_cost = fluid.layers.mean(x=cost)

-    return avg_cost
+    return avg_cost, prediction


 def to_lodtensor(data, place):
@@ -163,8 +167,16 @@ def to_lodtensor(data, place):
     return res


-def main():
-    avg_cost = seq_to_seq_net()
+def create_random_lodtensor(lod, place, low, high):
+    data = np.random.random_integers(low, high, [lod[-1], 1]).astype("int64")
+    res = fluid.LoDTensor()
+    res.set(data, place)
+    res.set_lod([lod])
+    return res
+
+
+def train(use_cuda, save_dirname=None):
+    [avg_cost, prediction] = seq_to_seq_net()

     optimizer = fluid.optimizer.Adagrad(learning_rate=1e-4)
     optimizer.minimize(avg_cost)
@@ -174,7 +186,7 @@ def main():
             paddle.dataset.wmt14.train(dict_size), buf_size=1000),
         batch_size=batch_size)

-    place = core.CPUPlace()
+    place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
     exe = Executor(place)

     exe.run(framework.default_startup_program())
@@ -185,6 +197,7 @@ def main():
             word_data = to_lodtensor(map(lambda x: x[0], data), place)
             trg_word = to_lodtensor(map(lambda x: x[1], data), place)
             trg_word_next = to_lodtensor(map(lambda x: x[2], data), place)
+
             outs = exe.run(framework.default_main_program(),
                            feed={
                                'source_sequence': word_data,
@@ -192,13 +205,86 @@ def main():
                                'label_sequence': trg_word_next
                            },
                            fetch_list=[avg_cost])
+
             avg_cost_val = np.array(outs[0])
             print('pass_id=' + str(pass_id) + ' batch=' + str(batch_id) +
                   " avg_cost=" + str(avg_cost_val))
+            if math.isnan(float(avg_cost_val[0])):
+                sys.exit("got NaN loss, training failed.")
             if batch_id > 3:
-                exit(0)
+                if save_dirname is not None:
+                    fluid.io.save_inference_model(
+                        save_dirname, ['source_sequence',
+                                       'target_sequence'], [prediction], exe)
+                return
+
             batch_id += 1


+def infer(use_cuda, save_dirname=None):
+    if save_dirname is None:
+        return
+
+    place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
+    exe = fluid.Executor(place)
+
+    # Use fluid.io.load_inference_model to obtain the inference program desc,
+    # the feed_target_names (the names of variables that will be fed
+    # data using feed operators), and the fetch_targets (variables that
+    # we want to obtain data from using fetch operators).
+    [inference_program, feed_target_names,
+     fetch_targets] = fluid.io.load_inference_model(save_dirname, exe)
+
+    lod = [0, 4, 10]
+    word_data = create_random_lodtensor(lod, place, low=0, high=1)
+    trg_word = create_random_lodtensor(lod, place, low=0, high=1)
+
+    # Construct feed as a dictionary of {feed_target_name: feed_target_data}
+    # and results will contain a list of data corresponding to fetch_targets.
+    assert feed_target_names[0] == 'source_sequence'
+    assert feed_target_names[1] == 'target_sequence'
+    results = exe.run(inference_program,
+                      feed={
+                          feed_target_names[0]: word_data,
+                          feed_target_names[1]: trg_word,
+                      },
+                      fetch_list=fetch_targets,
+                      return_numpy=False)
+    print(results[0].lod())
+    np_data = np.array(results[0])
+    print("Inference shape: ", np_data.shape)
+    print("Inference results: ", np_data)
+
+
+def main(use_cuda):
+    if use_cuda and not fluid.core.is_compiled_with_cuda():
+        return
+
+    # Directory for saving the trained model
+    save_dirname = "rnn_encoder_decoder.inference.model"
+
+    train(use_cuda, save_dirname)
+    infer(use_cuda, save_dirname)
+
+
+class TestRnnEncoderDecoder(unittest.TestCase):
+    def test_cuda(self):
+        with self.scope_prog_guard():
+            main(use_cuda=True)
+
+    def test_cpu(self):
+        with self.scope_prog_guard():
+            main(use_cuda=False)
+
+    @contextlib.contextmanager
+    def scope_prog_guard(self):
+        prog = fluid.Program()
+        startup_prog = fluid.Program()
+        scope = fluid.core.Scope()
+        with fluid.scope_guard(scope):
+            with fluid.program_guard(prog, startup_prog):
+                yield
+
+
 if __name__ == '__main__':
-    main()
+    unittest.main()