diff --git a/README.md b/README.md index 60ffbe728178705b1734e682868614025214c2a4..45186ec4ef48dc305b2616dbf4966f01c3609962 100644 --- a/README.md +++ b/README.md @@ -19,7 +19,7 @@ Our vision is to enable deep learning for everyone via PaddlePaddle. Please refer to our [release announcement](https://github.com/PaddlePaddle/Paddle/releases) to track the latest feature of PaddlePaddle. -### Latest PaddlePaddle Release: [Fluid 0.14.0](https://github.com/PaddlePaddle/Paddle/tree/v0.14.0) +### Latest PaddlePaddle Release: [Fluid 0.15.0](https://github.com/PaddlePaddle/Paddle/tree/v0.15.0) ### Install Latest Stable Release: ``` # Linux CPU @@ -76,26 +76,26 @@ pip install paddlepaddle-gpu==0.14.0.post85 ## Installation -It is recommended to read [this doc](http://paddlepaddle.org/documentation/docs/zh/0.14.0/new_docs/beginners_guide/install/install_doc.html) on our website. +It is recommended to read [this doc](http://paddlepaddle.org/documentation/docs/zh/0.15.0/new_docs/beginners_guide/install/install_doc.html) on our website. ## Documentation -We provide [English](http://paddlepaddle.org/documentation/docs/en/0.14.0/getstarted/index_en.html) and -[Chinese](http://paddlepaddle.org/documentation/docs/zh/0.14.0/new_docs/beginners_guide/index.html) documentation. +We provide [English](http://paddlepaddle.org/documentation/docs/en/0.15.0/getstarted/index_en.html) and +[Chinese](http://paddlepaddle.org/documentation/docs/zh/0.15.0/new_docs/beginners_guide/index.html) documentation. - [Deep Learning 101](https://github.com/PaddlePaddle/book) You might want to start from this online interactive book that can run in a Jupyter Notebook. -- [Distributed Training](http://paddlepaddle.org/documentation/docs/zh/0.14.0/new_docs/user_guides/howto/training/cluster_howto.html) +- [Distributed Training](http://paddlepaddle.org/documentation/docs/zh/0.15.0/new_docs/user_guides/howto/training/cluster_howto.html) You can run distributed training jobs on MPI clusters. -- [Python API](http://paddlepaddle.org/documentation/api/zh/0.14.0/fluid.html) +- [Python API](http://paddlepaddle.org/documentation/api/zh/0.15.0/fluid.html) Our new API enables much shorter programs. -- [How to Contribute](http://paddlepaddle.org/documentation/docs/zh/0.14.0/new_docs/advanced_usage/development/contribute_to_paddle.html) +- [How to Contribute](http://paddlepaddle.org/documentation/docs/zh/0.15.0/new_docs/advanced_usage/development/contribute_to_paddle.html) We appreciate your contributions! diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index ce3ebed00b4db81b6ba1bab566a56207341a67c0..7004f484a9975124750fad4cb8f773342082b514 100644 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -28,6 +28,9 @@ cc_library(graph_pattern_detector SRCS graph_pattern_detector.cc DEPS graph grap pass_library(graph_to_program_pass base) pass_library(graph_viz_pass base) pass_library(fc_fuse_pass inference) +if(WITH_MKLDNN) + pass_library(conv_relu_mkldnn_fuse_pass inference) +endif() pass_library(attention_lstm_fuse_pass inference) pass_library(infer_clean_graph_pass inference) pass_library(fc_lstm_fuse_pass inference) @@ -42,3 +45,6 @@ cc_test(graph_helper_test SRCS graph_helper_test.cc DEPS graph graph_helper op_r cc_test(graph_to_program_pass_test SRCS graph_to_program_pass_test.cc DEPS graph_to_program_pass) cc_test(test_graph_pattern_detector SRCS graph_pattern_detector_tester.cc DEPS graph_pattern_detector) cc_test(test_fc_fuse_pass SRCS fc_fuse_pass_tester.cc DEPS fc_fuse_pass framework_proto) +if(WITH_MKLDNN) + cc_test(test_conv_relu_mkldnn_fuse_pass SRCS conv_relu_mkldnn_fuse_pass_tester.cc DEPS conv_relu_mkldnn_fuse_pass) +endif() diff --git a/paddle/fluid/framework/ir/conv_relu_mkldnn_fuse_pass.cc b/paddle/fluid/framework/ir/conv_relu_mkldnn_fuse_pass.cc new file mode 100644 index 0000000000000000000000000000000000000000..4408cb45acb3d46e1addf5c25c238af50e5f5e5f --- /dev/null +++ b/paddle/fluid/framework/ir/conv_relu_mkldnn_fuse_pass.cc @@ -0,0 +1,90 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/ir/conv_relu_mkldnn_fuse_pass.h" +#include +#include +#include "paddle/fluid/platform/enforce.h" + +namespace paddle { +namespace framework { +namespace ir { + +std::unique_ptr ConvReLUFusePass::ApplyImpl( + std::unique_ptr graph) const { + PADDLE_ENFORCE(graph.get()); + FusePassBase::Init("conv_relu_mkldnn_fuse", graph.get()); + + std::unordered_set nodes2delete; + + GraphPatternDetector gpd; + auto* conv_input = gpd.mutable_pattern() + ->NewNode("conv_relu_mkldnn_fuse/conv_input") + ->AsInput() + ->assert_is_op_input("conv2d", "Input"); + patterns::ConvReLU conv_relu_pattern(gpd.mutable_pattern(), + "conv_relu_mkldnn_fuse"); + conv_relu_pattern(conv_input); + + int found_conv_relu_count = 0; + auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* g) { + VLOG(4) << "handle ConvReLU fuse"; + GET_IR_NODE_FROM_SUBGRAPH(conv_weight, conv_weight, + conv_relu_pattern); // Filter + GET_IR_NODE_FROM_SUBGRAPH(conv_bias, conv_bias, conv_relu_pattern); // Bias + GET_IR_NODE_FROM_SUBGRAPH(conv_out, conv_out, conv_relu_pattern); // tmp + GET_IR_NODE_FROM_SUBGRAPH(conv, conv, conv_relu_pattern); // CONV op + GET_IR_NODE_FROM_SUBGRAPH(relu_out, relu_out, conv_relu_pattern); // Out + GET_IR_NODE_FROM_SUBGRAPH(relu, relu, conv_relu_pattern); // ReLU op + + // Create an ConvReLU Node. + OpDesc desc; + std::string conv_relu_i_in = subgraph.at(conv_input)->Name(); + std::string conv_relu_w_in = conv_weight->Name(); + std::string conv_relu_b_in = conv_bias->Name(); + std::string conv_relu_out = relu_out->Name(); + desc.SetInput("Input", std::vector({conv_relu_i_in})); + desc.SetInput("Filter", std::vector({conv_relu_w_in})); + desc.SetInput("Bias", std::vector({conv_relu_b_in})); + desc.SetOutput("Out", std::vector({conv_relu_out})); + desc.SetType("conv2d"); + for (auto& attr : conv->Op()->GetAttrMap()) { + desc.SetAttr(attr.first, attr.second); + } + desc.SetAttr("fuse_relu", true); + auto conv_relu_node = g->CreateOpNode(&desc); // OpDesc will be copied. + GraphSafeRemoveNodes(graph.get(), {conv, relu, conv_out}); + + PADDLE_ENFORCE(subgraph.count(conv_input)); + IR_NODE_LINK_TO(subgraph.at(conv_input), conv_relu_node); + IR_NODE_LINK_TO(conv_weight, conv_relu_node); + IR_NODE_LINK_TO(conv_bias, conv_relu_node); + IR_NODE_LINK_TO(conv_relu_node, relu_out); + + found_conv_relu_count++; + }; + + gpd(graph.get(), handler); + + AddStatis(found_conv_relu_count); + return graph; +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +REGISTER_PASS(conv_relu_mkldnn_fuse_pass, + paddle::framework::ir::ConvReLUFusePass); diff --git a/paddle/fluid/framework/ir/conv_relu_mkldnn_fuse_pass.h b/paddle/fluid/framework/ir/conv_relu_mkldnn_fuse_pass.h new file mode 100644 index 0000000000000000000000000000000000000000..b5de0d548713772e7ad41cfb6d8b3e9460683efb --- /dev/null +++ b/paddle/fluid/framework/ir/conv_relu_mkldnn_fuse_pass.h @@ -0,0 +1,39 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/fluid/framework/ir/fuse_pass_base.h" +#include "paddle/fluid/framework/ir/graph.h" +#include "paddle/fluid/framework/ir/graph_pattern_detector.h" +#include "paddle/fluid/framework/ir/pass.h" + +namespace paddle { +namespace framework { +namespace ir { + +/* + * Fuse the CONV and ReLU to a ConvReLUOp. + */ +class ConvReLUFusePass : public FusePassBase { + public: + virtual ~ConvReLUFusePass() {} + + protected: + std::unique_ptr ApplyImpl(std::unique_ptr graph) const; +}; + +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/ir/conv_relu_mkldnn_fuse_pass_tester.cc b/paddle/fluid/framework/ir/conv_relu_mkldnn_fuse_pass_tester.cc new file mode 100644 index 0000000000000000000000000000000000000000..82b5fa1886098ca3b19c147c307d3f2fc3ba03d6 --- /dev/null +++ b/paddle/fluid/framework/ir/conv_relu_mkldnn_fuse_pass_tester.cc @@ -0,0 +1,108 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/ir/conv_relu_mkldnn_fuse_pass.h" + +#include + +namespace paddle { +namespace framework { +namespace ir { + +void SetOp(ProgramDesc* prog, const std::string& type, + const std::vector& inputs, + const std::vector& outputs) { + auto* op = prog->MutableBlock(0)->AppendOp(); + op->SetType(type); + if (type == "conv2d") { + op->SetAttr("use_mkldnn", true); + op->SetInput("Input", {inputs[0]}); + op->SetInput("Filter", {inputs[1]}); + op->SetInput("Bias", {inputs[2]}); + } else if (type == "relu") { + op->SetInput("X", inputs); + } + op->SetOutput("Out", outputs); +} + +// a->OP0->b +// b->OP1->c +// (c, weights, bias)->conv->f +// (f)->relu->g +ProgramDesc BuildProgramDesc() { + ProgramDesc prog; + for (auto& v : + std::vector({"a", "b", "c", "weights", "bias", "f", "g"})) { + auto* var = prog.MutableBlock(0)->Var(v); + var->SetType(proto::VarType::SELECTED_ROWS); + if (v == "weights" || v == "bias") { + var->SetPersistable(true); + } + } + + SetOp(&prog, "OP0", std::vector({"a"}), + std::vector({"b"})); + SetOp(&prog, "OP1", std::vector({"b"}), + std::vector({"c"})); + SetOp(&prog, "conv2d", std::vector({"c", "weights", "bias"}), + std::vector({"f"})); + SetOp(&prog, "relu", std::vector({"f"}), + std::vector({"g"})); + + return prog; +} + +TEST(ConvReLUFusePass, basic) { + auto prog = BuildProgramDesc(); + + std::unique_ptr graph(new ir::Graph(prog)); + + auto pass = PassRegistry::Instance().Get("conv_relu_mkldnn_fuse_pass"); + + int original_nodes_num = graph->Nodes().size(); + + graph = pass->Apply(std::move(graph)); + + int current_nodes_num = graph->Nodes().size(); + + // Remove 3 Nodes: CONV, RELU, conv_out + // Add 1 Node: ConvReLU + EXPECT_EQ(original_nodes_num - 2, current_nodes_num); + + // Assert conv_relu op in newly generated graph + int conv_relu_count = 0; + + for (auto* node : graph->Nodes()) { + if (node->IsOp() && node->Op()->Type() == "conv2d") { + if (node->Op()->HasAttr("use_mkldnn")) { + bool use_mkldnn = boost::get(node->Op()->GetAttr("use_mkldnn")); + if (use_mkldnn) { + if (node->Op()->HasAttr("fuse_relu")) { + bool fuse_relu = boost::get(node->Op()->GetAttr("fuse_relu")); + if (fuse_relu) { + ++conv_relu_count; + } + } + } + } + } + } + EXPECT_EQ(conv_relu_count, 1); +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +USE_PASS(conv_relu_mkldnn_fuse_pass); diff --git a/paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc b/paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc index f7fda873574a0f8b10251d4fa6b604a9312ad7f9..aa95d3e9f6c8221f6e48d192b73ad5135539dc75 100644 --- a/paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc +++ b/paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc @@ -51,7 +51,7 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope, if (with_fc_bias) { // Add FC-bias with LSTM-bias and create a new weight PADDLE_ENFORCE(scope); - const std::string& new_bias_var = name_scope + "_bias.new"; + const std::string& new_bias_var = patterns::UniqueKey("NewBias"); auto* bias_var = scope->Var(new_bias_var); PADDLE_ENFORCE(bias_var); auto* bias_tensor = bias_var->GetMutable(); @@ -120,7 +120,6 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope, auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, Graph* g) { - GET_IR_NODE_FROM_SUBGRAPH(lstm, lstm, lstm_pattern); GET_IR_NODE_FROM_SUBGRAPH(Weight, Weight, lstm_pattern); GET_IR_NODE_FROM_SUBGRAPH(Bias, Bias, lstm_pattern); @@ -136,7 +135,7 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope, fc_bias); // Remove unneeded nodes. std::unordered_set marked_nodes( - {mul, lstm, elementwise_add}); + {mul, lstm, elementwise_add, fc_bias}); GraphSafeRemoveNodes(graph, marked_nodes); } else { GET_IR_NODE_FROM_SUBGRAPH(fc_out, mul_out, fc_pattern); diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc index 5825a129b731afe9de468b5a611c25ac2753aa3f..11d5998aafe1f325b94ef1a5ea1c13c72c13f5c9 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.cc +++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc @@ -522,6 +522,39 @@ bool VarLinksFromOp(Node* node, const std::string& op_type) { return false; } +PDNode* patterns::ConvReLU::operator()( + paddle::framework::ir::PDNode* conv_input) { + // Create Operators + conv_input->assert_is_op_input("conv2d", "Input"); + auto* conv_op = pattern->NewNode(conv_repr())->assert_is_op("conv2d"); + auto* relu_op = pattern->NewNode(relu_repr())->assert_is_op("relu"); + // Create variables + // Filter + auto* conv_weight_var = pattern->NewNode(conv_weight_repr()) + ->AsInput() + ->assert_is_persistable_var() + ->assert_is_op_input("conv2d", "Filter"); + // Bias + auto* conv_bias_var = pattern->NewNode(conv_bias_repr()) + ->AsInput() + ->assert_is_persistable_var() + ->assert_is_op_input("conv2d", "Bias"); + // intermediate variable, will be removed in the IR after fuse. + auto* conv_out_var = pattern->NewNode(conv_out_repr()) + ->AsIntermediate() + ->assert_is_only_output_of_op("conv2d") + ->assert_is_op_input("relu"); + // output + auto* relu_out_var = pattern->NewNode(relu_out_repr()) + ->AsOutput() + ->assert_is_op_output("relu"); + + conv_op->LinksFrom({conv_input, conv_weight_var, conv_bias_var}) + .LinksTo({conv_out_var}); + relu_op->LinksFrom({conv_out_var}).LinksTo({relu_out_var}); + return relu_out_var; +} + PDNode* patterns::FC::operator()(paddle::framework::ir::PDNode* x, bool with_bias) { // Create shared nodes. diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h index 57482a07b607ba1d9fa06a5f325f60ba58dce307..371384dc56eec91db1f621c0ebb65113e7a5a5cc 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.h +++ b/paddle/fluid/framework/ir/graph_pattern_detector.h @@ -360,6 +360,28 @@ struct PatternBase { size_t id_; }; +// CONV with ReLU +// op: conv + relu +// named nodes: +// conv_input, conv_weight, +// conv_bias, conv_out, conv, +// relu_out, relu +struct ConvReLU : public PatternBase { + ConvReLU(PDPattern* pattern, const std::string& name_scope) + : PatternBase(pattern, name_scope, "conv_relu") {} + + PDNode* operator()(PDNode* conv_input); + + // declare operator node's name + PATTERN_DECL_NODE(conv); + PATTERN_DECL_NODE(relu); + // declare variable node's name + PATTERN_DECL_NODE(conv_weight); + PATTERN_DECL_NODE(conv_bias); + PATTERN_DECL_NODE(conv_out); + PATTERN_DECL_NODE(relu_out); +}; + // FC with bias // op: mul + elementwise_add // named nodes: diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index d58d6e4f3e684b97fcc1121e51355bdf3aae3fce..b7fae7171a57666a8fb4613a7cbe3aa15997b638 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -464,35 +464,35 @@ class RuntimeInferShapeContext : public InferShapeContext { : op_(op), scope_(scope) {} bool HasInput(const std::string& name) const override { - if (!op_.HasInputs(name)) { + // has only one input + const auto& ins = op_.Inputs(); + auto it = ins.find(name); + if (it == ins.end()) { return false; } - auto& ins = Inputs(name); - size_t length = ins.size(); - if (length == 0) { + const auto& in = it->second; + if (in.size() == 0 || in[0] == kEmptyVarName) { return false; } - PADDLE_ENFORCE_EQ(length, 1UL, + PADDLE_ENFORCE_EQ(in.size(), 1UL, "Input %s should not have more than one inputs", name); - auto ipt = ins[0]; - auto* var = ipt == kEmptyVarName ? nullptr : scope_.FindVar(ipt); - return var != nullptr; + return scope_.FindVar(in[0]) != nullptr; } bool HasOutput(const std::string& name) const override { - if (!op_.HasOutputs(name)) { + // has only one output + const auto& outs = op_.Outputs(); + auto it = outs.find(name); + if (it == outs.end()) { return false; } - auto& outs = Outputs(name); - size_t length = outs.size(); - if (length == 0) { + const auto& out = it->second; + if (out.size() == 0 || out[0] == kEmptyVarName) { return false; } - PADDLE_ENFORCE_EQ(length, 1UL, - "Output %s should not have more than one inputs", name); - auto ipt = outs[0]; - auto* var = ipt == kEmptyVarName ? nullptr : scope_.FindVar(ipt); - return var != nullptr; + PADDLE_ENFORCE_EQ(out.size(), 1UL, + "Output %s should not have more than one outputs", name); + return scope_.FindVar(out[0]) != nullptr; } bool HasInputs(const std::string& name) const override { diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc index 81cb24bdda6b87a3d708cf5047dce05d5020a0d5..5b8c75a93de2ddd8f7260d2191c22a5945b3d2d9 100644 --- a/paddle/fluid/framework/parallel_executor.cc +++ b/paddle/fluid/framework/parallel_executor.cc @@ -352,7 +352,10 @@ void ParallelExecutor::FeedAndSplitTensorIntoLocalScopes( ParallelExecutor::~ParallelExecutor() { if (member_->own_local_scope_) { for (size_t i = 1; i < member_->local_scopes_.size(); ++i) { - member_->global_scope_->DeleteScope(member_->local_scopes_[i]); + Scope *local_scope = member_->local_scopes_[i]; + if (member_->global_scope_->HasKid(local_scope)) { + member_->global_scope_->DeleteScope(local_scope); + } } } } diff --git a/paddle/fluid/framework/program_desc_test.cc b/paddle/fluid/framework/program_desc_test.cc index 925ea98dbe62e4da91689f6e56c135e51c24a8a3..7e689a37da8a16bd9b1ac6650b9322d2eb5a2c85 100644 --- a/paddle/fluid/framework/program_desc_test.cc +++ b/paddle/fluid/framework/program_desc_test.cc @@ -87,8 +87,17 @@ TEST(ProgramDesc, copy_ctor) { ASSERT_EQ(op_origin->Inputs(), op_copy->Inputs()); ASSERT_EQ(op_origin->Outputs(), op_copy->Outputs()); - ASSERT_EQ(op_copy->Proto()->SerializeAsString(), - op_origin->Proto()->SerializeAsString()); + ASSERT_EQ(op_origin->Proto()->attrs().size(), + op_copy->Proto()->attrs().size()); + for (auto it = op_origin->Proto()->attrs().begin(); + it != op_origin->Proto()->attrs().end(); ++it) { + for (auto it_2 = op_copy->Proto()->attrs().begin(); + it_2 != op_copy->Proto()->attrs().end(); ++it_2) { + if (it->name() == it_2->name()) { + ASSERT_TRUE(it_2->SerializeAsString() == it->SerializeAsString()); + } + } + } if (op->Type() == "op_with_subblock") { ASSERT_EQ(1, op->GetBlockAttrId("sub_block")); diff --git a/paddle/fluid/framework/scope.cc b/paddle/fluid/framework/scope.cc index 50f374e3703a97f6c1fdb4b14fdeb0b603f9ac86..2be655b89a4caf2bf9874dcab6bc0bdb2856a026 100644 --- a/paddle/fluid/framework/scope.cc +++ b/paddle/fluid/framework/scope.cc @@ -72,6 +72,12 @@ void Scope::DropKids() { kids_.clear(); } +bool Scope::HasKid(const Scope* scope) const { + std::unique_lock lock(mutex_); + auto it = std::find(this->kids_.begin(), this->kids_.end(), scope); + return it != this->kids_.end(); +} + std::vector Scope::LocalVarNames() const { std::unique_lock lock(mutex_); std::vector known_vars; diff --git a/paddle/fluid/framework/scope.h b/paddle/fluid/framework/scope.h index e246241c0abfbc7bdcaf38d073cc58fc36a4f737..b6165a595d537c314a95685e8b1edbc42e387ab7 100644 --- a/paddle/fluid/framework/scope.h +++ b/paddle/fluid/framework/scope.h @@ -71,6 +71,9 @@ class Scope { /// Drop all kids scopes belonged to this scope. void DropKids(); + /// Find if a scope exists in the kid scopes + bool HasKid(const Scope* scope) const; + // enumerate all the variables current contains. std::vector LocalVarNames() const; diff --git a/paddle/fluid/inference/analysis/ir_pass_manager.cc b/paddle/fluid/inference/analysis/ir_pass_manager.cc index 30c1e8e93d2513dca4531d224fb939ec47a7739d..e76708baf4b39afb0febbcf3ff71281dfbfc8627 100644 --- a/paddle/fluid/inference/analysis/ir_pass_manager.cc +++ b/paddle/fluid/inference/analysis/ir_pass_manager.cc @@ -14,6 +14,7 @@ #include "paddle/fluid/inference/analysis/ir_pass_manager.h" #include +#include #include "paddle/fluid/framework/ir/fuse_pass_base.h" #include "paddle/fluid/framework/ir/graph.h" #include "paddle/fluid/framework/scope.h" @@ -37,13 +38,16 @@ IRPassManager::IRPassManager(const ProgramDesc &program, void IRPassManager::Apply(const std::vector &passes) { // Apply all the passes std::string pre_pass; + int pass_num = 0; for (const std::string &pass_name : passes) { PrettyLogEndl(Style::H2(), "--- Running IR pass [%s]", pass_name); auto pass = framework::ir::PassRegistry::Instance().Get(pass_name); if (pass_name == "graph_viz_pass") { - std::string dot_file_path = - "ir_" + (pre_pass.empty() ? "origin" : pre_pass) + ".dot"; + std::string dot_file_path = std::to_string(pass_num) + "_ir_" + + (pre_pass.empty() ? "origin" : pre_pass) + + ".dot"; pass->Set("graph_viz_path", new std::string(std::move(dot_file_path))); + pass_num++; } graph_ = pass->Apply(std::move(graph_)); pre_pass = pass_name; diff --git a/paddle/fluid/inference/tests/api/analyzer_ner_tester.cc b/paddle/fluid/inference/tests/api/analyzer_ner_tester.cc index 661b047ed7cb70545267e468d8c2c48596a2994c..6e8e43add7d3383fa79efea91c23750be9c8956f 100644 --- a/paddle/fluid/inference/tests/api/analyzer_ner_tester.cc +++ b/paddle/fluid/inference/tests/api/analyzer_ner_tester.cc @@ -144,8 +144,9 @@ void TestChineseNERPrediction(bool use_analysis) { size_t num_samples; for (int i = 0; i < FLAGS_repeat; i++) { DataRecord data(FLAGS_infer_data, FLAGS_batch_size); + // Just one batch, the num_samples remains the same. num_samples = data.num_samples; - for (size_t bid = 0; bid < num_samples; ++bid) { + for (size_t bid = 0; bid < num_samples / FLAGS_batch_size; ++bid) { PrepareInputs(&input_slots, &data, FLAGS_batch_size); timer.tic(); predictor->Run(input_slots, &outputs); diff --git a/paddle/fluid/operators/attention_lstm_op.cc b/paddle/fluid/operators/attention_lstm_op.cc index 39b0c856996c11c6efdb530f1396afd5731c778d..9b943440a869e213db4ed761cfe7c508bc5e94ae 100644 --- a/paddle/fluid/operators/attention_lstm_op.cc +++ b/paddle/fluid/operators/attention_lstm_op.cc @@ -24,28 +24,28 @@ namespace operators { void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const { PADDLE_ENFORCE(ctx->HasInput("X"), - "Input(X) of AttentionLSTM should not be null."); + "Assert only one Input(X) of AttentionLSTM."); PADDLE_ENFORCE(ctx->HasInput("C0"), - "Input(C0) of AttentionLSTM should not be null."); + "Assert only one Input(C0) of AttentionLSTM."); PADDLE_ENFORCE(ctx->HasInput("LSTMWeight"), - "Input(LSTMWeight) of AttentionLSTM should not be null."); + "Assert only one Input(LSTMWeight) of AttentionLSTM."); PADDLE_ENFORCE(ctx->HasInput("LSTMBias"), - "Input(LSTMBias) of AttentionLSTM should not be null."); + "Assert only one Input(LSTMBias) of AttentionLSTM."); PADDLE_ENFORCE(ctx->HasInput("AttentionWeight"), - "Input(AttentionWeight) of AttentionLSTM should not be null."); + "Assert only one Input(AttentionWeight) of AttentionLSTM."); PADDLE_ENFORCE(ctx->HasOutput("Hidden"), - "Output(Hidden) of AttentionLSTM should not be null."); + "Assert only one Output(Hidden) of AttentionLSTM."); PADDLE_ENFORCE(ctx->HasOutput("Cell"), - "Output(Cell) of AttentionLSTM should not be null."); + "Assert only one Output(Cell) of AttentionLSTM."); PADDLE_ENFORCE(ctx->HasOutput("AttentionedX"), - "Output(AttentionedX) of AttentionLSTM should not be null."); + "Assert only one Output(AttentionedX) of AttentionLSTM."); PADDLE_ENFORCE(ctx->HasOutput("AttentionFCOut"), - "Output(AttentionFCOut) of AttentionLSTM should not be null."); + "Assert only one Output(AttentionFCOut) of AttentionLSTM."); PADDLE_ENFORCE(ctx->HasOutput("LSTMX"), - "Output(LSTMX) of AttentionLSTM should not be null."); + "Assert only one Output(LSTMX) of AttentionLSTM."); PADDLE_ENFORCE(ctx->HasOutput("LSTMOUT"), - "Output(LSTMOUT) of AttentionLSTM should not be null."); + "Assert only one Output(LSTMOUT) of AttentionLSTM."); auto x_dims = ctx->GetInputDim("X"); const int M = x_dims[1]; diff --git a/paddle/fluid/operators/fusion_gru_op.cc b/paddle/fluid/operators/fusion_gru_op.cc index 916f84cb4a78c3721cb67bd3cf8d3759a8eaf1bf..31e87d9113118ebe7a4b25ffee5ba55e2714fb66 100644 --- a/paddle/fluid/operators/fusion_gru_op.cc +++ b/paddle/fluid/operators/fusion_gru_op.cc @@ -25,14 +25,14 @@ namespace paddle { namespace operators { void FusionGRUOp::InferShape(framework::InferShapeContext* ctx) const { - PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of GRU should not be null."); + PADDLE_ENFORCE(ctx->HasInput("X"), "Assert only one Input(X) of GRU."); PADDLE_ENFORCE(ctx->HasInput("WeightX"), - "Input(WeightX) of GRU should not be null."); + "Assert only one Input(WeightX) of GRU."); PADDLE_ENFORCE(ctx->HasInput("WeightH"), - "Input(WeightH) of GRU should not be null."); - PADDLE_ENFORCE(ctx->HasOutput("XX"), "Output(XX) of GRU should not be null."); + "Assert only one Input(WeightH) of GRU."); + PADDLE_ENFORCE(ctx->HasOutput("XX"), "Assert only one Output(XX) of GRU."); PADDLE_ENFORCE(ctx->HasOutput("Hidden"), - "Output(Hidden) of GRU should not be null."); + "Assert only one Output(Hidden) of GRU."); auto x_dims = ctx->GetInputDim("X"); PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank must be 2."); @@ -80,11 +80,11 @@ void FusionGRUOp::InferShape(framework::InferShapeContext* ctx) const { } else { xx_width = x_dims[1] > wx_dims[1] ? wx_dims[1] : x_dims[1]; PADDLE_ENFORCE(ctx->HasOutput("ReorderedH0"), - "Output(ReorderedH0) of GRU should not be null."); + "Assert only one Output(ReorderedH0) of GRU."); PADDLE_ENFORCE(ctx->HasOutput("BatchedInput"), - "Output(BatchedInput) of GRU should not be null."); + "Assert only one Output(BatchedInput) of GRU."); PADDLE_ENFORCE(ctx->HasOutput("BatchedOut"), - "Output(BatchedOut) of GRU should not be null."); + "Assert only one Output(BatchedOut) of GRU."); ctx->SetOutputDim("BatchedInput", {x_dims[0], wx_dims[1]}); ctx->SetOutputDim("BatchedOut", out_dims); } diff --git a/paddle/fluid/operators/fusion_lstm_op.cc b/paddle/fluid/operators/fusion_lstm_op.cc index ef23ab3f981786d33567619ad0194d21f31bdc8e..55e465e3af08c012b8cff7714452ed32b32a5556 100644 --- a/paddle/fluid/operators/fusion_lstm_op.cc +++ b/paddle/fluid/operators/fusion_lstm_op.cc @@ -24,20 +24,17 @@ namespace paddle { namespace operators { void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const { - PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of LSTM should not be null."); + PADDLE_ENFORCE(ctx->HasInput("X"), "Assert only one Input(X) of LSTM."); PADDLE_ENFORCE(ctx->HasInput("WeightX"), - "Input(WeightX) of LSTM should not be null."); + "Assert only one Input(WeightX) of LSTM."); PADDLE_ENFORCE(ctx->HasInput("WeightH"), - "Input(WeightH) of LSTM should not be null."); - PADDLE_ENFORCE(ctx->HasInput("Bias"), - "Input(Bias) of LSTM should not be null."); - - PADDLE_ENFORCE(ctx->HasOutput("XX"), - "Output(XX) of LSTM should not be null."); + "Assert only one Input(WeightH) of LSTM."); + PADDLE_ENFORCE(ctx->HasInput("Bias"), "Assert only one Input(Bias) of LSTM."); + PADDLE_ENFORCE(ctx->HasOutput("XX"), "Assert only one Output(XX) of LSTM."); PADDLE_ENFORCE(ctx->HasOutput("Hidden"), - "Output(Hidden) of LSTM should not be null."); + "Assert only one Output(Hidden) of LSTM."); PADDLE_ENFORCE(ctx->HasOutput("Cell"), - "Output(Cell) of LSTM should not be null."); + "Assert only one Output(Cell) of LSTM."); auto x_dims = ctx->GetInputDim("X"); PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank must be 2."); @@ -96,15 +93,15 @@ void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const { } else { xx_width = x_dims[1] > wx_dims[1] ? wx_dims[1] : x_dims[1]; PADDLE_ENFORCE(ctx->HasOutput("BatchedInput"), - "Output(BatchedInput) of LSTM should not be null."); + "Assert only one Output(BatchedInput) of LSTM."); PADDLE_ENFORCE(ctx->HasOutput("BatchedHidden"), - "Output(BatchedHidden) of LSTM should not be null."); + "Assert only one Output(BatchedHidden) of LSTM."); PADDLE_ENFORCE(ctx->HasOutput("BatchedCell"), - "Output(BatchedCell) of LSTM should not be null."); + "Assert only one Output(BatchedCell) of LSTM."); PADDLE_ENFORCE(ctx->HasOutput("ReorderedH0"), - "Output(ReorderedH0) of LSTM should not be null."); + "Assert only one Output(ReorderedH0) of LSTM"); PADDLE_ENFORCE(ctx->HasOutput("ReorderedC0"), - "Output(ReorderedC0) of LSTM should not be null."); + "Assert only one Output(ReorderedC0) of LSTM."); ctx->SetOutputDim("BatchedInput", {x_dims[0], wx_dims[1]}); ctx->SetOutputDim("BatchedHidden", out_dims); ctx->SetOutputDim("BatchedCell", out_dims); diff --git a/paddle/fluid/operators/softmax_with_cross_entropy_op.cu b/paddle/fluid/operators/softmax_with_cross_entropy_op.cu index 148faec4af50c4fe3a8e9d1f22e0da70c8ddcb44..a07c17348ebb3f768d1c8be65c2d31e3c130bd23 100644 --- a/paddle/fluid/operators/softmax_with_cross_entropy_op.cu +++ b/paddle/fluid/operators/softmax_with_cross_entropy_op.cu @@ -31,7 +31,8 @@ __global__ void CrossEntropyGrad(T* logit_grad, const int64_t* labels, for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < batch_size; i += blockDim.x * gridDim.x) { int idx = i * class_num + labels[i]; - logit_grad[idx] -= static_cast(1.); + logit_grad[idx] -= + ignore_index == labels[i] ? static_cast(0.) : static_cast(1.); } } diff --git a/python/paddle/fluid/tests/book/high-level-api/image_classification/test_image_classification_resnet.py b/python/paddle/fluid/tests/book/high-level-api/image_classification/test_image_classification_resnet.py index e5ae95e2d943917b9bc10f0d4c4bdc5f8fb07fdb..de276755bb1eb2746cc780575a40357255223809 100644 --- a/python/paddle/fluid/tests/book/high-level-api/image_classification/test_image_classification_resnet.py +++ b/python/paddle/fluid/tests/book/high-level-api/image_classification/test_image_classification_resnet.py @@ -178,7 +178,4 @@ if __name__ == '__main__': for parallel in (False, True): if use_cuda and not core.is_compiled_with_cuda(): continue - # TODO(minqiyang): remove this line after fixing the deletion - # order problem of Scope in ParallelExecutor in manylinux - if six.PY2: - main(use_cuda=use_cuda, parallel=parallel) + main(use_cuda=use_cuda, parallel=parallel) diff --git a/python/paddle/fluid/tests/book/high-level-api/image_classification/test_image_classification_vgg.py b/python/paddle/fluid/tests/book/high-level-api/image_classification/test_image_classification_vgg.py index ff91be72c918f8dac65b7030e45c4a00deb965ac..dd547f3448ae55c07b6c09f9de4ac08d8ec5ee88 100644 --- a/python/paddle/fluid/tests/book/high-level-api/image_classification/test_image_classification_vgg.py +++ b/python/paddle/fluid/tests/book/high-level-api/image_classification/test_image_classification_vgg.py @@ -152,7 +152,4 @@ if __name__ == '__main__': for parallel in (False, True): if use_cuda and not core.is_compiled_with_cuda(): continue - # TODO(minqiyang): remove this line after fixing the deletion - # order problem of Scope in ParallelExecutor in manylinux - if six.PY2: - main(use_cuda=use_cuda, parallel=parallel) + main(use_cuda=use_cuda, parallel=parallel) diff --git a/python/paddle/fluid/tests/book/high-level-api/recognize_digits/test_recognize_digits_conv.py b/python/paddle/fluid/tests/book/high-level-api/recognize_digits/test_recognize_digits_conv.py index fa72c939e57356f26d60032dd0a91c894b28c505..973308498bec3cddde2ef651751ad5d0c9f84503 100644 --- a/python/paddle/fluid/tests/book/high-level-api/recognize_digits/test_recognize_digits_conv.py +++ b/python/paddle/fluid/tests/book/high-level-api/recognize_digits/test_recognize_digits_conv.py @@ -155,7 +155,4 @@ if __name__ == '__main__': for parallel in (False, True): if use_cuda and not core.is_compiled_with_cuda(): continue - # TODO(minqiyang): remove this line after fixing the deletion - # order problem of Scope in ParallelExecutor in manylinux - if six.PY2: - main(use_cuda=use_cuda, parallel=parallel) + main(use_cuda=use_cuda, parallel=parallel) diff --git a/python/paddle/fluid/tests/book/high-level-api/recognize_digits/test_recognize_digits_mlp.py b/python/paddle/fluid/tests/book/high-level-api/recognize_digits/test_recognize_digits_mlp.py index 440d2a30835cb89089709f024a4dcc6e4113efa8..cb4aeb430e1a9662a183084c0cdacc41c5a8ec11 100644 --- a/python/paddle/fluid/tests/book/high-level-api/recognize_digits/test_recognize_digits_mlp.py +++ b/python/paddle/fluid/tests/book/high-level-api/recognize_digits/test_recognize_digits_mlp.py @@ -137,7 +137,4 @@ if __name__ == '__main__': for parallel in (False, True): if use_cuda and not core.is_compiled_with_cuda(): continue - # TODO(minqiyang): remove this line after fixing the deletion - # order problem of Scope in ParallelExecutor in manylinux - if six.PY2: - main(use_cuda=use_cuda, parallel=parallel) + main(use_cuda=use_cuda, parallel=parallel) diff --git a/python/paddle/fluid/tests/unittests/test_parallel_executor_transformer.py b/python/paddle/fluid/tests/unittests/test_parallel_executor_transformer.py index 5ad922725a0b692e28552737a99b745ed09ddbd5..a55b2002ed989d4588716202a37aa6f4139825ea 100644 --- a/python/paddle/fluid/tests/unittests/test_parallel_executor_transformer.py +++ b/python/paddle/fluid/tests/unittests/test_parallel_executor_transformer.py @@ -20,6 +20,7 @@ import numpy as np from parallel_executor_test_base import TestParallelExecutorBase import unittest import paddle +import paddle.fluid.core as core import paddle.dataset.wmt16 as wmt16 import os @@ -170,7 +171,8 @@ class TestTransformer(TestParallelExecutorBase): writer.complete_append_tensor() def test_main(self): - self.check_network_convergence(transformer, use_cuda=True) + if core.is_compiled_with_cuda(): + self.check_network_convergence(transformer, use_cuda=True) self.check_network_convergence(transformer, use_cuda=False, iter=5) diff --git a/python/paddle/fluid/tests/unittests/test_py_reader_using_executor.py b/python/paddle/fluid/tests/unittests/test_py_reader_using_executor.py index 931cac409f26fce4ecca18c4b0cfcca2e675046f..b7fad9b3a60632adb564e1d155a3d935706b467f 100644 --- a/python/paddle/fluid/tests/unittests/test_py_reader_using_executor.py +++ b/python/paddle/fluid/tests/unittests/test_py_reader_using_executor.py @@ -96,7 +96,8 @@ class TestPyReaderUsingExecutor(unittest.TestCase): self.queue_capacity = 50 def test(self): - for use_cuda in [False, True]: + for use_cuda in ([False, True] + if core.is_compiled_with_cuda() else [False]): for use_parallel_executor in [False, True]: for use_double_buffer in [False, True]: print('Test Parameters:'),