diff --git a/paddle/fluid/framework/details/all_reduce_deps_pass.cc b/paddle/fluid/framework/details/all_reduce_deps_pass.cc
index 2e20c436dfdb61fcda78cd044b86848c750cf22c..87d3b1042bce57c7e48cdd1ae3ad9adda95f624b 100644
--- a/paddle/fluid/framework/details/all_reduce_deps_pass.cc
+++ b/paddle/fluid/framework/details/all_reduce_deps_pass.cc
@@ -50,7 +50,7 @@ std::unique_ptr<ir::Graph> AllReduceDepsPass::ApplyImpl(
   std::unordered_map<std::string, int> vars;
   // TODO(gongwb): use graph topology sort to find the order of operators.
   //               Note that must assert topology sort is stable
-  auto& ops = Get<const std::vector<OpDesc *>>(kAllOpDescs);
+  auto& ops = graph->Get<const std::vector<OpDesc *>>(kAllOpDescs);
   for (auto* op_desc : ops) {
     auto outputs = op_desc->Outputs();
     for (auto& o_it : outputs) {
@@ -120,4 +120,4 @@ std::unique_ptr<ir::Graph> AllReduceDepsPass::ApplyImpl(
 
 REGISTER_PASS(all_reduce_deps_pass,
               paddle::framework::details::AllReduceDepsPass)
-    .RequirePassAttr(paddle::framework::details::kAllOpDescs);
+    .RequireGraphAttr(paddle::framework::details::kAllOpDescs);
diff --git a/paddle/fluid/framework/details/build_strategy.cc b/paddle/fluid/framework/details/build_strategy.cc
index 774be6c24c7e19b630b7d9e3c2f46ef81d8a9f09..c14a40a997786c46838958a7ea3e9ee3512350b6 100644
--- a/paddle/fluid/framework/details/build_strategy.cc
+++ b/paddle/fluid/framework/details/build_strategy.cc
@@ -183,7 +183,6 @@ std::unique_ptr<ir::Graph> BuildStrategy::Apply(
   // Create a default one if not finalized by user.
   CreatePassesFromStrategy(false);
 
-  std::vector<OpDesc *> all_ops = graph->OriginProgram().Block(0).AllOps();
   for (std::shared_ptr<ir::Pass> &pass : pass_builder_->AllPasses()) {
     if (IsMultiDevPass(pass->Type())) {
       pass->Erase(kPlaces);
@@ -201,33 +200,12 @@ std::unique_ptr<ir::Graph> BuildStrategy::Apply(
       pass->Erase("nccl_ctxs");
       pass->SetNotOwned<platform::NCCLContextMap>("nccl_ctxs", nctx);
 #endif
-    } else if (pass->Type() == "memory_optimize_pass") {
-      if (graph->Has(kAllOpDescs)) {
-        graph->Erase(kAllOpDescs);
-      }
-
-      graph->SetNotOwned<const std::vector<OpDesc *>>(kAllOpDescs, &all_ops);
-
-      pass->Erase(kAllOpDescs);
-      pass->SetNotOwned<const std::vector<OpDesc *>>(kAllOpDescs, &all_ops);
-
     } else if (pass->Type() == "sequential_execution_pass") {
       LOG(INFO) << "set enable_sequential_execution:"
                 << enable_sequential_execution_;
-
-      pass->Erase(kAllOpDescs);
-      pass->SetNotOwned<const std::vector<OpDesc *>>(kAllOpDescs, &all_ops);
     } else if (pass->Type() == "all_reduce_deps_pass") {
       LOG(INFO) << "SeqOnlyAllReduceOps:" << SeqOnlyAllReduceOps(*this)
                 << ", num_trainers:" << num_trainers_;
-
-      pass->Erase(kAllOpDescs);
-      pass->SetNotOwned<const std::vector<OpDesc *>>(kAllOpDescs, &all_ops);
-    } else if (pass->Type() == "inplace_pass") {
-      if (graph->Has(kAllOpDescs)) {
-        graph->Erase(kAllOpDescs);
-      }
-      graph->SetNotOwned<const std::vector<OpDesc *>>(kAllOpDescs, &all_ops);
     } else if (pass->Type() == "fuse_relu_depthwise_conv_pass") {
       if (!use_cuda) {
         LOG(WARNING) << "fuse_relu_depthwise_conv_pass is only supported on "
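Taken together, the two files above invert the ownership of kAllOpDescs: instead of BuildStrategy::Apply snapshotting the program's op list and injecting it into every interested pass as a pass attribute (an Erase + SetNotOwned pair before each Apply), a pass now reads the list directly off the graph it transforms and declares that dependency at registration time. The following is a minimal, self-contained sketch of the resulting access pattern; MiniGraph and the attribute name "all_op_descs" are invented stand-ins that only mimic the shape of ir::Graph::Set/Has/Get, not the real Paddle API:

    #include <cassert>
    #include <iostream>
    #include <memory>
    #include <string>
    #include <unordered_map>
    #include <vector>

    struct OpDesc {
      std::string type;
    };

    class MiniGraph {
     public:
      // The graph takes ownership of the attribute, as ir::Graph::Set does.
      void Set(const std::string &name, std::vector<OpDesc *> *v) {
        attrs_[name].reset(v);
      }
      bool Has(const std::string &name) const { return attrs_.count(name) > 0; }
      const std::vector<OpDesc *> &Get(const std::string &name) const {
        assert(Has(name));  // the condition RequireGraphAttr is meant to enforce
        return *attrs_.at(name);
      }

     private:
      std::unordered_map<std::string, std::unique_ptr<std::vector<OpDesc *>>>
          attrs_;
    };

    // A pass reads the shared graph attribute instead of a private copy, so
    // no caller has to Erase/SetNotOwned the list before every Apply.
    void VisitOpsInProgramOrder(const MiniGraph &graph) {
      for (OpDesc *op : graph.Get("all_op_descs")) {
        std::cout << op->type << "\n";
      }
    }

    int main() {
      OpDesc a{"mul"}, b{"elementwise_add"};
      MiniGraph g;
      g.Set("all_op_descs", new std::vector<OpDesc *>{&a, &b});
      VisitOpsInProgramOrder(g);  // prints mul, then elementwise_add
      return 0;
    }

The practical gain is that the attribute travels with the graph between passes, so no intermediate caller can forget to re-attach it.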
diff --git a/paddle/fluid/framework/details/parallel_ssa_graph_executor.cc b/paddle/fluid/framework/details/parallel_ssa_graph_executor.cc
index 46332a8f23d3fe4deb0c9382a3d4b9dc850fdddc..5b8ae8b6770df79df309bb6be16e4f2a24ee0460 100644
--- a/paddle/fluid/framework/details/parallel_ssa_graph_executor.cc
+++ b/paddle/fluid/framework/details/parallel_ssa_graph_executor.cc
@@ -81,7 +81,6 @@ ParallelSSAGraphExecutor::ParallelSSAGraphExecutor(
       local_scopes_(std::move(local_scopes)),
       pool_(places.size() >= 2 ? new ::ThreadPool(places.size()) : nullptr),
       places_(std::move(places)),
-      main_prog_(graph->OriginProgram()),
       // TODO(Yancey1989): Copying graphs is not safely since it deleted the
       // attrs.
       graphs_(SeparateMultiDevicesGraph(graph)) {
@@ -89,10 +88,6 @@ ParallelSSAGraphExecutor::ParallelSSAGraphExecutor(
 
   auto seq_allreduce_pass =
       ir::PassRegistry::Instance().Get("all_reduce_deps_pass");
-  seq_allreduce_pass->Erase(details::kAllOpDescs);
-  seq_allreduce_pass->Set<const std::vector<OpDesc *>>(
-      details::kAllOpDescs,
-      new std::vector<OpDesc *>(main_prog_.Block(0).AllOps()));
   for (size_t i = 0; i < graphs_.size(); ++i) {
     graphs_[i] = seq_allreduce_pass->Apply(std::move(graphs_[i]));
   }
diff --git a/paddle/fluid/framework/details/parallel_ssa_graph_executor.h b/paddle/fluid/framework/details/parallel_ssa_graph_executor.h
index a7a792dabd5992ac5f9b966f43016f5f05a70522..1e421f2a3a51363fe368859f7a34593c8c894077 100644
--- a/paddle/fluid/framework/details/parallel_ssa_graph_executor.h
+++ b/paddle/fluid/framework/details/parallel_ssa_graph_executor.h
@@ -46,7 +46,6 @@ class ParallelSSAGraphExecutor : public SSAGraphExecutor {
   std::vector<Scope *> local_scopes_;
   std::unique_ptr<::ThreadPool> pool_{nullptr};
   std::vector<platform::Place> places_;
-  framework::ProgramDesc main_prog_;
   std::vector<std::unique_ptr<ir::Graph>> graphs_;
 
   std::vector<std::unique_ptr<details::ThreadedSSAGraphExecutor>> executors_;
diff --git a/paddle/fluid/framework/details/sequential_execution_pass.cc b/paddle/fluid/framework/details/sequential_execution_pass.cc
index 879fb29d5926941e574d0080051c195293bc60a9..d4e7bb658980cea139ae8e7c06b7e8ed331fbc58 100644
--- a/paddle/fluid/framework/details/sequential_execution_pass.cc
+++ b/paddle/fluid/framework/details/sequential_execution_pass.cc
@@ -40,7 +40,7 @@ std::unique_ptr<ir::Graph> SequentialExecutionPass::ApplyImpl(
   static std::unordered_set<std::string> skip_dist_ops{
       "send", "recv", "send_barrier", "fetch_barrier"};
 
-  auto &ops = Get<const std::vector<OpDesc *>>(kAllOpDescs);
+  auto &ops = graph->Get<const std::vector<OpDesc *>>(kAllOpDescs);
   std::vector<ir::Node *> op_node_list;
   op_node_list.reserve(ops.size());
@@ -107,4 +107,4 @@ std::unique_ptr<ir::Graph> SequentialExecutionPass::ApplyImpl(
 
 REGISTER_PASS(sequential_execution_pass,
               paddle::framework::details::SequentialExecutionPass)
-    .RequirePassAttr(paddle::framework::details::kAllOpDescs);
+    .RequireGraphAttr(paddle::framework::details::kAllOpDescs);
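The registration change from RequirePassAttr to RequireGraphAttr also shifts where the precondition is enforced: the pass no longer obliges callers to set an attribute on the pass object, but instead expects the attribute to already be present on the incoming graph. A rough sketch of that contract follows, with invented Mini* types standing in for ir::Pass and ir::Graph (the real framework would enforce this with PADDLE_ENFORCE rather than assert):

    #include <cassert>
    #include <iostream>
    #include <set>
    #include <string>

    struct MiniGraph {
      std::set<std::string> attrs;  // names of attributes present on the graph
      bool Has(const std::string &n) const { return attrs.count(n) > 0; }
    };

    class MiniPass {
     public:
      MiniPass &RequireGraphAttr(const std::string &name) {
        required_graph_attrs_.insert(name);
        return *this;
      }
      void Apply(MiniGraph *graph) const {
        // Check declared graph attributes up front, before the pass body
        // runs, instead of failing deep inside the pass on a missing Get().
        for (const std::string &attr : required_graph_attrs_) {
          assert(graph->Has(attr) && "required graph attribute missing");
        }
        std::cout << "pass body runs\n";
      }

     private:
      std::set<std::string> required_graph_attrs_;
    };

    int main() {
      MiniGraph g;
      g.attrs.insert("all_op_descs");  // what Graph::InitFromProgram now provides
      MiniPass p;
      p.RequireGraphAttr("all_op_descs");  // cf. REGISTER_PASS(...).RequireGraphAttr
      p.Apply(&g);
      return 0;
    }

This is also why ParallelSSAGraphExecutor above can drop main_prog_ entirely: each graph produced by SeparateMultiDevicesGraph carries its own kAllOpDescs, so the executor no longer needs the original ProgramDesc to feed the all_reduce_deps_pass.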
diff --git a/paddle/fluid/framework/ir/graph.cc b/paddle/fluid/framework/ir/graph.cc
index 4b5c846f3271b2dd5e094020571069aff590cd2b..5ea30f824f95b6fbdcda86a6696c4055a085be49 100644
--- a/paddle/fluid/framework/ir/graph.cc
+++ b/paddle/fluid/framework/ir/graph.cc
@@ -76,6 +76,9 @@ std::map<std::string, std::vector<ir::Node *>> Graph::InitFromProgram(
       var->inputs.push_back(node);
     }
   }
+  Set<const std::vector<OpDesc *>>(
+      details::kAllOpDescs,
+      new std::vector<OpDesc *>(program.Block(0).AllOps()));
   return var_nodes;
 }
diff --git a/paddle/fluid/framework/ir/graph.h b/paddle/fluid/framework/ir/graph.h
index 7e783f74ff44219665530ed330fad6b0e286ad2a..296f3b83961c1379ee2c1237aa15784791b46878 100644
--- a/paddle/fluid/framework/ir/graph.h
+++ b/paddle/fluid/framework/ir/graph.h
@@ -195,12 +195,6 @@ class Graph {
     return nullptr;
   }
 
-  // Returns reference to the original program.
-  // WARN: After a series of passes, the current graph can be quite
-  // different from OriginProgram. Caller shouldn't assume much from
-  // the returned OriginProgram.
-  const ProgramDesc &OriginProgram() const { return program_; }
-
   // This method takes ownership of `node`.
   ir::Node *AddNode(ir::Node *node) {
     PADDLE_ENFORCE(node_set_.find(node) == node_set_.end());
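With OriginProgram() gone, the vector stored by InitFromProgram becomes the one sanctioned record of the original op order. Note that the added Set(...) call stores a copy of the pointer list, taken while the ProgramDesc is still untouched, so later graph rewrites cannot reorder the recorded sequence; the OpDesc objects themselves still live in the ProgramDesc, which therefore has to outlive the graph (an assumption of this reading, not something the diff states). A small stand-alone illustration of the copy-vs-reference point:

    #include <iostream>
    #include <string>
    #include <vector>

    struct OpDesc {
      std::string type;
    };

    int main() {
      // Stand-in for program.Block(0).AllOps(): pointers in program order.
      OpDesc fc{"fc"}, relu{"relu"};
      std::vector<OpDesc *> program_ops = {&fc, &relu};

      // Snapshot taken at graph construction, as the added Set(...) call
      // does: a new vector copying the pointers, not the OpDesc objects.
      std::vector<OpDesc *> all_op_descs(program_ops);

      // A later pass may rewrite its own view of the op list...
      program_ops.clear();

      // ...but the snapshot still yields the original, stable order that
      // all_reduce_deps_pass and sequential_execution_pass rely on.
      for (OpDesc *op : all_op_descs) {
        std::cout << op->type << "\n";  // prints fc, then relu
      }
      return 0;
    }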
diff --git a/python/paddle/fluid/contrib/slim/unitest/test_quantization_pass.py b/python/paddle/fluid/contrib/slim/unitest/test_quantization_pass.py
deleted file mode 100644
index 4f3fee09459490bf290d1c74bd7297a93200f415..0000000000000000000000000000000000000000
--- a/python/paddle/fluid/contrib/slim/unitest/test_quantization_pass.py
+++ /dev/null
@@ -1,204 +0,0 @@
-# copyright (c) 2018 paddlepaddle authors. all rights reserved.
-#
-# licensed under the apache license, version 2.0 (the "license");
-# you may not use this file except in compliance with the license.
-# you may obtain a copy of the license at
-#
-#     http://www.apache.org/licenses/license-2.0
-#
-# unless required by applicable law or agreed to in writing, software
-# distributed under the license is distributed on an "as is" basis,
-# without warranties or conditions of any kind, either express or implied.
-# see the license for the specific language governing permissions and
-# limitations under the license.
-
-import unittest
-import random
-import numpy as np
-import paddle.fluid as fluid
-import six
-from paddle.fluid.framework import Program
-from paddle.fluid.framework import IrGraph
-from paddle.fluid.contrib.slim.quantization import QuantizationTransformPass
-from paddle.fluid import core
-
-
-def linear_fc(num):
-    data = fluid.layers.data(name='image', shape=[1, 28, 28], dtype='float32')
-    label = fluid.layers.data(name='label', shape=[1], dtype='int64')
-    hidden = data
-    for _ in six.moves.xrange(num):
-        hidden = fluid.layers.fc(hidden, size=128, act='relu')
-    fc = fluid.layers.fc(input=hidden, size=10)
-    loss = fluid.layers.softmax_with_cross_entropy(fc, label=label)
-    loss = fluid.layers.mean(loss)
-    return loss
-
-
-def residual_block(num):
-    def conv_bn_layer(input,
-                      ch_out,
-                      filter_size,
-                      stride,
-                      padding,
-                      act='relu',
-                      bias_attr=False):
-        tmp = fluid.layers.conv2d(
-            input=input,
-            filter_size=filter_size,
-            num_filters=ch_out,
-            stride=stride,
-            padding=padding,
-            act=None,
-            bias_attr=bias_attr)
-        return fluid.layers.batch_norm(input=tmp, act=act)
-
-    data = fluid.layers.data(name='image', shape=[1, 28, 28], dtype='float32')
-    label = fluid.layers.data(name='label', shape=[1], dtype='int64')
-    hidden = data
-    for _ in six.moves.xrange(num):
-        conv = conv_bn_layer(hidden, 16, 3, 1, 1, act=None, bias_attr=True)
-        short = conv_bn_layer(hidden, 16, 1, 1, 0, act=None)
-        hidden = fluid.layers.elementwise_add(x=conv, y=short, act='relu')
-    fc = fluid.layers.fc(input=hidden, size=10)
-    loss = fluid.layers.softmax_with_cross_entropy(fc, label)
-    loss = fluid.layers.mean(loss)
-    return loss
-
-
-class TestQuantizationTransformPass(unittest.TestCase):
-    def setUp(self):
-        self.quantizable_op_and_inputs = {
-            'conv2d': ['Input', 'Filter'],
-            'depthwise_conv2d': ['Input', 'Filter'],
-            'mul': ['X', 'Y']
-        }
-        self.quantizable_grad_op_inputs = {
-            'conv2d_grad': ['Input', 'Filter'],
-            'depthwise_conv2d_grad': ['Input', 'Filter'],
-            'mul_grad': ['X', 'Y']
-        }
-
-    def check_program(self, transform_pass, program):
-        quantized_ops = set()
-        for block in program.blocks:
-            for op in block.ops:
-                # check forward
-                if op.type in self.quantizable_op_and_inputs:
-                    for arg_name in op.input_arg_names:
-                        self.assertTrue(
-                            arg_name.endswith('.quantized.dequantized'))
-                        quantized_ops.add(arg_name)
-
-            for op in block.ops:
-                # check backward
-                if op.type in self.quantizable_grad_op_inputs:
-                    for pname in self.quantizable_grad_op_inputs[op.type]:
-                        arg_name = op.input(pname)[0]
-                        self.assertTrue(
-                            arg_name.endswith('.quantized.dequantized'))
-                        self.assertTrue(arg_name in quantized_ops)
-
-    def linear_fc_quant(self, quant_type):
-        main = fluid.Program()
-        startup = fluid.Program()
-        with fluid.program_guard(main, startup):
-            loss = linear_fc(3)
-            opt = fluid.optimizer.Adam(learning_rate=0.001)
-            opt.minimize(loss)
-        exe = fluid.Executor(fluid.CPUPlace())
-        graph = IrGraph(core.Graph(main.desc), for_test=False)
-        transform_pass = QuantizationTransformPass(
-            scope=fluid.global_scope(),
-            program_exe=exe,
-            activation_quantize_type=quant_type)
-        transform_pass.apply(graph)
-        marked_nodes = set()
-        for op in graph.all_ops():
-            if op.name().find('quantize') > -1:
-                marked_nodes.add(op)
-        graph.draw('.', 'quantize_fc_' + quant_type, marked_nodes)
-        program = graph.to_program()
-        self.check_program(transform_pass, program)
-        val_graph = IrGraph(core.Graph(program.desc), for_test=False)
-        val_marked_nodes = set()
-        for op in val_graph.all_ops():
-            if op.name().find('quantize') > -1:
-                val_marked_nodes.add(op)
-        val_graph.draw('.', 'val_fc_' + quant_type, val_marked_nodes)
-
-    def test_linear_fc_quant_abs_max(self):
-        self.act_quant_op_type = 'fake_quantize_abs_max'
-        self.linear_fc_quant('abs_max')
-
-    def test_linear_fc_quant_range_abs_max(self):
-        self.act_quant_op_type = 'fake_quantize_range_abs_max'
-        self.linear_fc_quant('range_abs_max')
-
-    def residual_block_quant(self, quant_type):
-        main = fluid.Program()
-        startup = fluid.Program()
-        with fluid.program_guard(main, startup):
-            loss = residual_block(2)
-            opt = fluid.optimizer.Adam(learning_rate=0.001)
-            opt.minimize(loss)
-        exe = fluid.Executor(fluid.CPUPlace())
-        graph = IrGraph(core.Graph(main.desc), for_test=False)
-        transform_pass = QuantizationTransformPass(
-            scope=fluid.global_scope(),
-            program_exe=exe,
-            activation_quantize_type=quant_type)
-        transform_pass.apply(graph)
-        marked_nodes = set()
-        for op in graph.all_ops():
-            if op.name().find('quantize') > -1:
-                marked_nodes.add(op)
-        graph.draw('.', 'quantize_residual_' + quant_type, marked_nodes)
-        program = graph.to_program()
-        self.check_program(transform_pass, program)
-        val_graph = IrGraph(core.Graph(program.desc), for_test=False)
-        val_marked_nodes = set()
-        for op in val_graph.all_ops():
-            if op.name().find('quantize') > -1:
-                val_marked_nodes.add(op)
-        val_graph.draw('.', 'val_residual_' + quant_type, val_marked_nodes)
-
-    def test_residual_block_abs_max(self):
-        self.act_quant_op_type = 'fake_quantize_abs_max'
-        self.residual_block_quant('abs_max')
-
-    def test_residual_block_range_abs_max(self):
-        self.act_quant_op_type = 'fake_quantize_range_abs_max'
-        self.residual_block_quant('range_abs_max')
-
-    def test_execute_graph(self):
-        main = fluid.Program()
-        startup = fluid.Program()
-        with fluid.program_guard(main, startup):
-            loss = linear_fc(3)
-            opt = fluid.optimizer.Adam(learning_rate=0.0001)
-            opt.minimize(loss)
-
-        exe = fluid.Executor(fluid.CPUPlace())
-        graph = IrGraph(core.Graph(main.desc), for_test=False)
-        exe.run(startup)
-        binary = fluid.CompiledProgram(graph.graph).with_data_parallel(
-            loss_name=loss.name)
-        for i in range(10):
-            loss_val = exe.run(binary,
-                               feed={
-                                   'image': np.ones(
-                                       [32, 784], dtype=np.float32),
-                                   'label': np.ones(
-                                       [32, 1], dtype=np.int64)
-                               },
-                               fetch_list=[loss])
-            if i == 0:
-                start_loss = np.sum(loss_val)
-            elif i == 9:
-                end_loss = np.sum(loss_val)
-        self.assertLess(end_loss, start_loss)
-
-
-if __name__ == '__main__':
-    unittest.main()