diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt
index ed4d042062e368105d1e18d69419a001e94c4163..0ed62ac93a7278ac9063d096853366b999e68818 100644
--- a/paddle/fluid/framework/CMakeLists.txt
+++ b/paddle/fluid/framework/CMakeLists.txt
@@ -203,6 +203,10 @@ cc_library(op_registry SRCS op_registry.cc DEPS op_proto_maker op_info operator
 cc_library(op_call_stack SRCS op_call_stack.cc DEPS op_proto_maker enforce)
 cc_test(op_call_stack_test SRCS op_call_stack_test.cc DEPS op_call_stack)
+
+cc_library(program_processing SRCS program_processing.cc DEPS framework_proto)
+cc_test(program_processing_test SRCS program_processing_test.cc DEPS proto_desc program_processing)
+
 if(WITH_GPU)
   nv_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry)
 elseif(WITH_ROCM)
diff --git a/paddle/fluid/framework/op_desc.h b/paddle/fluid/framework/op_desc.h
index 1bc1a308e453bb816b00d6ca9a62358f8d33082a..31339b4d620b5164e1b2eeac78cdbf6c935f77d1 100644
--- a/paddle/fluid/framework/op_desc.h
+++ b/paddle/fluid/framework/op_desc.h
@@ -122,6 +122,16 @@ class OpDesc {
 
   const VariableNameMap &Outputs() const { return outputs_; }
 
+  VariableNameMap *MutableInputs() {
+    this->need_update_ = true;
+    return &this->inputs_;
+  }
+
+  VariableNameMap *MutableOutputs() {
+    this->need_update_ = true;
+    return &this->outputs_;
+  }
+
   AttributeMap *MutableAttrMap() {
     this->need_update_ = true;
     return &this->attrs_;
diff --git a/paddle/fluid/framework/program_processing.cc b/paddle/fluid/framework/program_processing.cc
new file mode 100644
index 0000000000000000000000000000000000000000..3bcf6f8f3855f0edd857cca37bd1497e2fd9ab2f
--- /dev/null
+++ b/paddle/fluid/framework/program_processing.cc
@@ -0,0 +1,131 @@
+/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/framework/program_processing.h"
+#include "paddle/fluid/framework/block_desc.h"
+
+namespace paddle {
+namespace framework {
+
+void ProgramProcessor::GetInputsOutputsInBlock(
+    const BlockDesc &current_block, std::set<std::string> *inner_inputs,
+    std::set<std::string> *inner_outputs) {
+  /* Find the inputs and outputs of the current control flow block.
+  :param current_block: the current control flow block.
+  :param inner_inputs: the set of input variable names of ops in the block.
+  :param inner_outputs: the set of output variable names of ops in the block. */
+
+  // Step 1: update inner_inputs and inner_outputs from every op in the block.
+  // NOTE: this assumes that every variable is the input or output of some op.
+
+  for (OpDesc *op : current_block.AllOps()) {
+    for (auto iname : op->InputNames()) {
+      for (auto in_var_name : op->Input(iname)) {
+        VLOG(3) << "insert inner_input_name:" << in_var_name;
+        inner_inputs->insert(in_var_name);
+      }
+    }
+
+    for (auto oname : op->OutputNames()) {
+      for (auto out_var_name : op->Output(oname)) {
+        VLOG(3) << "insert out_var_name:" << out_var_name;
+        inner_outputs->insert(out_var_name);
+      }
+    }
+  }
+
+  // Step 2: remove the variables created in the current control flow block.
+  BlockDesc *parent_block = current_block.ParentBlock();
+
+  if (parent_block) {
+    for (auto iter = inner_inputs->begin(); iter != inner_inputs->end();) {
+      const std::string &in_var_name = *iter;
+      if (current_block.HasVar(in_var_name)) {
+        VLOG(3) << "remove inner input var:" << in_var_name;
+        iter = inner_inputs->erase(iter);
+      } else {
+        ++iter;
+      }
+    }
+
+    for (auto iter = inner_outputs->begin(); iter != inner_outputs->end();) {
+      const std::string &out_var_name = *iter;
+      if (current_block.HasVar(out_var_name)) {
+        VLOG(3) << "remove inner output var:" << out_var_name;
+        iter = inner_outputs->erase(iter);
+      } else {
+        ++iter;
+      }
+    }
+  }
+}
+
+void ProgramProcessor::AddDepToBlockOp(const BlockDesc &block) {
+  VLOG(3) << "Op size:" << block.AllOps().size();
+  for (OpDesc *op : block.AllOps()) {
+    if (op->HasAttr("sub_block")) {
+      auto op_type = op->Type();
+      BlockDesc *sub_block =
+          BOOST_GET_MUTABLE(BlockDesc *, op->GetAttr("sub_block"));
+
+      // Process nested sub-blocks recursively.
+      AddDepToBlockOp(*sub_block);
+
+      std::set<std::string> sub_inputs;
+      std::set<std::string> sub_outputs;
+      ProgramProcessor::GetInputsOutputsInBlock(*sub_block, &sub_inputs,
+                                                &sub_outputs);
+      VLOG(3) << "sub_inputs.size:" << sub_inputs.size();
+      VLOG(3) << "sub_outputs.size:" << sub_outputs.size();
+
+      auto *op_inputs = op->MutableInputs();
+      std::vector<std::string> *op_input_var_vec;
+      VLOG(3) << "op_type:" << op_type;
+      if (op_type.compare("while") == 0) {
+        op_input_var_vec = &((*op_inputs)["kX"]);
+      } else if (op_type.compare("conditional_block") == 0) {
+        op_input_var_vec = &((*op_inputs)["kInputs"]);
+      } else {
+        // Only while_op and conditional_block_op are supported for now.
+        LOG(WARNING)
+            << "Currently, only while_op and conditional_block_op are supported.\n";
+        continue;
+      }
+
+      for (auto sub_input : sub_inputs) {
+        if (std::find(op_input_var_vec->begin(), op_input_var_vec->end(),
+                      sub_input) == op_input_var_vec->end())
+          op_input_var_vec->push_back(sub_input);
+        VLOG(3) << "modified private inputs, inputs.size():"
+                << op_input_var_vec->size();
+      }
+
+      auto *op_outputs = op->MutableOutputs();
+      auto *op_output_var_vec = &((*op_outputs)["kOutputs"]);
+
+      for (auto sub_output : sub_outputs) {
+        if (std::find(op_output_var_vec->begin(), op_output_var_vec->end(),
+                      sub_output) == op_output_var_vec->end())
+          op_output_var_vec->push_back(sub_output);
+        VLOG(3) << "modified private outputs, outputs.size():"
+                << op_output_var_vec->size();
+      }
+    }
+  }
+}
+
+ProgramProcessor::ProgramProcessor() {}
+
+}  // namespace framework
+}  // namespace paddle
diff --git a/paddle/fluid/framework/program_processing.h b/paddle/fluid/framework/program_processing.h
new file mode 100644
index 0000000000000000000000000000000000000000..b495c31793d9ace2edf5d3b7c1cc537bb8387c56
--- /dev/null
+++ b/paddle/fluid/framework/program_processing.h
@@ -0,0 +1,34 @@
+/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+#include "paddle/fluid/framework/program_desc.h"
+
+namespace paddle {
+namespace framework {
+
+class ProgramDesc;
+
+class ProgramProcessor {
+ public:
+  ProgramProcessor();
+
+  void GetInputsOutputsInBlock(const BlockDesc &current_block,
+                               std::set<std::string> *inner_inputs,
+                               std::set<std::string> *inner_outputs);
+
+  void AddDepToBlockOp(const BlockDesc &block);
+};
+}  // namespace framework
+}  // namespace paddle
diff --git a/paddle/fluid/framework/program_processing_test.cc b/paddle/fluid/framework/program_processing_test.cc
new file mode 100644
index 0000000000000000000000000000000000000000..18f4ed42938e23d6be726b99d8167cfa1988920a
--- /dev/null
+++ b/paddle/fluid/framework/program_processing_test.cc
@@ -0,0 +1,242 @@
+/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/framework/program_processing.h"
+
+#include "gtest/gtest-message.h"
+#include "gtest/gtest-test-part.h"
+#include "gtest/gtest.h"
+#include "gtest/gtest_pred_impl.h"
+
+namespace paddle {
+
+namespace framework {
+
+TEST(ProgramDesc, GetInputsOutputsInBlock) {
+  ProgramDesc program;
+  auto* global_block = program.MutableBlock(0);
+  auto* mul_1_x = global_block->Var("Mul_1_X");
+  mul_1_x->SetType(proto::VarType::LOD_TENSOR);
+  mul_1_x->SetLoDLevel(0);
+  mul_1_x->SetDataType(proto::VarType::FP32);
+  mul_1_x->SetShape({1000, 784});
+
+  auto* mul_1_y = global_block->Var("Mul_1_Y");
+  mul_1_y->SetType(proto::VarType::LOD_TENSOR);
+  mul_1_y->SetLoDLevel(0);
+  mul_1_y->SetDataType(proto::VarType::FP32);
+  mul_1_y->SetShape({784, 100});
+
+  auto* mul_1_out = global_block->Var("Mul_1_Out");
+  mul_1_out->SetType(proto::VarType::LOD_TENSOR);
+  auto* mul_op_1 = global_block->AppendOp();
+
+  mul_op_1->SetType("mul");
+  mul_op_1->SetInput("X", {mul_1_x->Name()});
+  mul_op_1->SetInput("Y", {mul_1_y->Name()});
+  mul_op_1->SetOutput("Out", {mul_1_out->Name()});
+
+  // Build a cond op such as less_than.
+  auto* less_than_op_1 = global_block->AppendOp();
+  less_than_op_1->SetType("less_than");
+  auto* less_than_1_x = global_block->Var("Less_than_1_X");
+  less_than_1_x->SetType(proto::VarType::LOD_TENSOR);
+  less_than_1_x->SetLoDLevel(0);
+  less_than_1_x->SetDataType(proto::VarType::FP32);
+  less_than_1_x->SetShape({1});
+
+  auto* less_than_1_y = global_block->Var("Less_than_1_Y");
+  less_than_1_y->SetType(proto::VarType::LOD_TENSOR);
+  less_than_1_y->SetLoDLevel(0);
+  less_than_1_y->SetDataType(proto::VarType::FP32);
+  less_than_1_y->SetShape({1});
+
+  auto* less_than_1_out = global_block->Var("Less_than_1_Out");
+  less_than_1_out->SetType(proto::VarType::BOOL);
+
+  less_than_op_1->SetInput("X", {less_than_1_x->Name()});
+  less_than_op_1->SetInput("Y", {less_than_1_y->Name()});
+  less_than_op_1->SetOutput("Out", {less_than_1_out->Name()});
+
+  BlockDesc* sub_block = program.AppendBlock(*global_block);
+  std::vector<BlockDesc*> sub_blocks;
+  sub_blocks.push_back(sub_block);
+
+  BlockDesc* sub_block2 =
+      program.AppendBlock(*sub_block);  // for testing the nested case.
+  sub_blocks.push_back(sub_block2);
+
+  // Build a while op in the global block whose sub_block attr is sub_blocks[0].
+  auto* while_op = global_block->AppendOp();
+  while_op->SetType("while");
+  while_op->SetAttr("sub_block", sub_blocks[0]);
+
+  auto* while_x = global_block->Var("While_X");
+  while_x->SetType(proto::VarType::LOD_TENSOR);
+  while_x->SetLoDLevel(0);
+  while_x->SetDataType(proto::VarType::FP32);
+  while_x->SetShape({1});
+
+  while_op->SetInput("kX", {while_x->Name()});
+  while_op->SetInput("kCondition", {less_than_1_out->Name()});
+
+  auto* while_out = global_block->Var("While_Out");
+  while_out->SetType(proto::VarType::LOD_TENSOR);
+  while_out->SetLoDLevel(0);
+  while_out->SetDataType(proto::VarType::FP32);
+  while_out->SetShape({1});
+
+  auto* steps = global_block->Var("StepScopes");
+
+  while_op->SetOutput("kOutputs", {while_out->Name()});
+  while_op->SetOutput("kStepScopes", {steps->Name()});
+
+  auto* mul_2_x = global_block->Var("Mul_2_X");
+  mul_2_x->SetType(proto::VarType::LOD_TENSOR);
+  mul_2_x->SetLoDLevel(0);
+  mul_2_x->SetDataType(proto::VarType::FP32);
+  mul_2_x->SetShape({1000, 784});
+
+  auto* mul_2_y = global_block->Var("Mul_2_Y");
+  mul_2_y->SetType(proto::VarType::LOD_TENSOR);
+  mul_2_y->SetLoDLevel(0);
+  mul_2_y->SetDataType(proto::VarType::FP32);
+  mul_2_y->SetShape({784, 100});
+
+  auto* mul_op_2 = sub_blocks[0]->AppendOp();
+  mul_op_2->SetType("mul");
+  mul_op_2->SetInput("X", {mul_2_x->Name()});
+  mul_op_2->SetInput("Y", {mul_2_y->Name()});
+
+  auto* mul_2_out = global_block->Var("Mul_2_Out");
+  mul_2_out->SetType(proto::VarType::LOD_TENSOR);
+  mul_op_2->SetOutput("Out", {mul_2_out->Name()});
+
+  auto* less_than_op_2 = sub_blocks[0]->AppendOp();
+  less_than_op_2->SetType("less_than");
+  auto* less_than_2_x = global_block->Var("Less_than_2_X");
+  less_than_2_x->SetType(proto::VarType::LOD_TENSOR);
+  less_than_2_x->SetLoDLevel(0);
+  less_than_2_x->SetDataType(proto::VarType::FP32);
+  less_than_2_x->SetShape({1});
+
+  auto* less_than_2_y = global_block->Var("Less_than_2_Y");
+  less_than_2_y->SetType(proto::VarType::LOD_TENSOR);
+  less_than_2_y->SetLoDLevel(0);
+  less_than_2_y->SetDataType(proto::VarType::FP32);
+  less_than_2_y->SetShape({1});
+
+  less_than_op_2->SetInput("X", {less_than_2_x->Name()});
+  less_than_op_2->SetInput("Y", {less_than_2_y->Name()});
+
+  auto* less_than_2_out = global_block->Var("Less_than_2_Out");
+  less_than_2_out->SetType(proto::VarType::BOOL);
+  less_than_op_2->SetOutput("Out", {less_than_2_out->Name()});
+
+  auto* cond_op = sub_blocks[0]->AppendOp();
+  cond_op->SetType("conditional_block");
+  cond_op->SetAttr("sub_block", sub_blocks[1]);
+
+  auto* cond_x = sub_blocks[0]->Var("Cond_X");
+  cond_x->SetType(proto::VarType::LOD_TENSOR);
+  cond_x->SetLoDLevel(0);
+  cond_x->SetDataType(proto::VarType::FP32);
+  cond_x->SetShape({1});
+
+  cond_op->SetInput("kInputs", {cond_x->Name()});
+  cond_op->SetInput("kCondition", {less_than_2_out->Name()});
+
+  auto* cond_out = sub_blocks[0]->Var("Cond_Out");
+  cond_out->SetType(proto::VarType::LOD_TENSOR);
+  cond_out->SetLoDLevel(0);
+  cond_out->SetDataType(proto::VarType::FP32);
+  cond_out->SetShape({1});
+
+  auto* scope = sub_blocks[0]->Var("Scope");
+  scope->SetType(proto::VarType::STEP_SCOPES);
+
+  cond_op->SetOutput("kOutputs", {cond_out->Name()});
+  cond_op->SetOutput("kScope", {scope->Name()});
+
+  auto* mul_3_x = global_block->Var("Mul_3_X");
+  mul_3_x->SetType(proto::VarType::LOD_TENSOR);
+  mul_3_x->SetLoDLevel(0);
+  mul_3_x->SetDataType(proto::VarType::FP32);
+  mul_3_x->SetShape({1000, 784});
+
+  auto* mul_3_y = global_block->Var("Mul_3_Y");
global_block->Var("Mul_3_Y"); + mul_3_y->SetType(proto::VarType::LOD_TENSOR); + mul_3_y->SetLoDLevel(0); + mul_3_y->SetDataType(proto::VarType::FP32); + mul_3_y->SetShape({784, 100}); + + auto* mul_3_out = global_block->Var("Mul_3_Out"); + mul_3_out->SetType(proto::VarType::LOD_TENSOR); + + auto* mul_op_3 = sub_blocks[1]->AppendOp(); + mul_op_3->SetType("mul"); + mul_op_3->SetInput("X", {mul_3_x->Name()}); + mul_op_3->SetInput("Y", {mul_3_y->Name()}); + mul_op_3->SetOutput("Y", {mul_3_out->Name()}); + + ProgramProcessor program_processor; + std::set inner_inputs; + std::set inner_outputs; + + program_processor.GetInputsOutputsInBlock(*sub_blocks[0], &inner_inputs, + &inner_outputs); + + VLOG(3) << "inner_inputs().size():" << inner_inputs.size(); + VLOG(3) << "inner_outputs().size():" << inner_outputs.size(); + + ASSERT_EQ(5UL, inner_inputs.size()); + ASSERT_EQ(2UL, inner_outputs.size()); + + // varible "Less_than_2_Out" is the input of cond_op, it also is the output of + // less_than_op. + std::set inner_inputs_{"Less_than_2_Out", "Less_than_2_X", + "Less_than_2_Y", "Mul_2_X", "Mul_2_Y"}; + std::set inner_outputs_{"Less_than_2_Out", "Mul_2_Out"}; + + ASSERT_EQ(inner_inputs, inner_inputs_); + ASSERT_EQ(inner_outputs, inner_outputs_); + + // Test AddDepToBlockOp + VLOG(3) << "Before AddDependency, while op's input kX size:" + << while_op->Input("kX").size(); + VLOG(3) << "Before AddDependency, while op's output kOutPuts size:" + << while_op->Output("kOutputs").size(); + + program_processor.AddDepToBlockOp(*global_block); + + VLOG(3) << "After AddDependency, while op's input kX size:" + << while_op->Input("kX").size(); + VLOG(3) << "After AddDependency, while op's output kOutPuts size:" + << while_op->Output("kOutputs").size(); + + ASSERT_EQ(8UL, while_op->Input("kX").size()); + ASSERT_EQ(4UL, while_op->Output("kOutputs").size()); + + std::vector var_input_vec = { + "While_X", "Less_than_2_Out", "Less_than_2_X", "Less_than_2_Y", + "Mul_2_X", "Mul_2_Y", "Mul_3_X", "Mul_3_Y"}; + + std::vector var_output_vec = {"While_Out", "Less_than_2_Out", + "Mul_2_Out", "Mul_3_Out"}; + + ASSERT_EQ(var_input_vec, while_op->Input("kX")); + ASSERT_EQ(var_output_vec, while_op->Output("kOutputs")); +} +} // namespace framework +} // namespace paddle