提交 19d78f67 编写于 作者: X Xin Pan

polish

test=develop
上级 32d5a160
......@@ -50,7 +50,7 @@ std::unique_ptr<ir::Graph> AllReduceDepsPass::ApplyImpl(
std::unordered_map<std::string, int> vars;
// TODO(gongwb): use graph topology sort to find the order of operators.
// Note that must assert topology sort is stable
auto& ops = Get<const std::vector<OpDesc*>>(kAllOpDescs);
auto& ops = graph->Get<const std::vector<OpDesc*>>(kAllOpDescs);
for (auto* op_desc : ops) {
auto outputs = op_desc->Outputs();
for (auto& o_it : outputs) {
......@@ -120,4 +120,4 @@ std::unique_ptr<ir::Graph> AllReduceDepsPass::ApplyImpl(
REGISTER_PASS(all_reduce_deps_pass,
paddle::framework::details::AllReduceDepsPass)
.RequirePassAttr(paddle::framework::details::kAllOpDescs);
.RequireGraphAttr(paddle::framework::details::kAllOpDescs);
......@@ -183,7 +183,6 @@ std::unique_ptr<ir::Graph> BuildStrategy::Apply(
// Create a default one if not finalized by user.
CreatePassesFromStrategy(false);
std::vector<OpDesc *> all_ops = graph->OriginProgram().Block(0).AllOps();
for (std::shared_ptr<ir::Pass> &pass : pass_builder_->AllPasses()) {
if (IsMultiDevPass(pass->Type())) {
pass->Erase(kPlaces);
......@@ -201,33 +200,12 @@ std::unique_ptr<ir::Graph> BuildStrategy::Apply(
pass->Erase("nccl_ctxs");
pass->SetNotOwned<platform::NCCLContextMap>("nccl_ctxs", nctx);
#endif
} else if (pass->Type() == "memory_optimize_pass") {
if (graph->Has(kAllOpDescs)) {
graph->Erase(kAllOpDescs);
}
graph->SetNotOwned<const std::vector<OpDesc *>>(kAllOpDescs, &all_ops);
pass->Erase(kAllOpDescs);
pass->SetNotOwned<const std::vector<OpDesc *>>(kAllOpDescs, &all_ops);
} else if (pass->Type() == "sequential_execution_pass") {
LOG(INFO) << "set enable_sequential_execution:"
<< enable_sequential_execution_;
pass->Erase(kAllOpDescs);
pass->SetNotOwned<const std::vector<OpDesc *>>(kAllOpDescs, &all_ops);
} else if (pass->Type() == "all_reduce_deps_pass") {
LOG(INFO) << "SeqOnlyAllReduceOps:" << SeqOnlyAllReduceOps(*this)
<< ", num_trainers:" << num_trainers_;
pass->Erase(kAllOpDescs);
pass->SetNotOwned<const std::vector<OpDesc *>>(kAllOpDescs, &all_ops);
} else if (pass->Type() == "inplace_pass") {
if (graph->Has(kAllOpDescs)) {
graph->Erase(kAllOpDescs);
}
graph->SetNotOwned<const std::vector<OpDesc *>>(kAllOpDescs, &all_ops);
} else if (pass->Type() == "fuse_relu_depthwise_conv_pass") {
if (!use_cuda) {
LOG(WARNING) << "fuse_relu_depthwise_conv_pass is only supported on "
......
......@@ -81,7 +81,6 @@ ParallelSSAGraphExecutor::ParallelSSAGraphExecutor(
local_scopes_(std::move(local_scopes)),
pool_(places.size() >= 2 ? new ::ThreadPool(places.size()) : nullptr),
places_(std::move(places)),
main_prog_(graph->OriginProgram()),
// TODO(Yancey1989): Copying graphs is not safely since it deleted the
// attrs.
graphs_(SeparateMultiDevicesGraph(graph)) {
......@@ -89,10 +88,6 @@ ParallelSSAGraphExecutor::ParallelSSAGraphExecutor(
auto seq_allreduce_pass =
ir::PassRegistry::Instance().Get("all_reduce_deps_pass");
seq_allreduce_pass->Erase(details::kAllOpDescs);
seq_allreduce_pass->Set<const std::vector<OpDesc *>>(
details::kAllOpDescs,
new std::vector<OpDesc *>(main_prog_.Block(0).AllOps()));
for (size_t i = 0; i < graphs_.size(); ++i) {
graphs_[i] = seq_allreduce_pass->Apply(std::move(graphs_[i]));
}
......
......@@ -46,7 +46,6 @@ class ParallelSSAGraphExecutor : public SSAGraphExecutor {
std::vector<Scope *> local_scopes_;
std::unique_ptr<::ThreadPool> pool_{nullptr};
std::vector<platform::Place> places_;
framework::ProgramDesc main_prog_;
std::vector<std::unique_ptr<ir::Graph>> graphs_;
std::vector<std::unique_ptr<details::ThreadedSSAGraphExecutor>> executors_;
......
......@@ -40,7 +40,7 @@ std::unique_ptr<ir::Graph> SequentialExecutionPass::ApplyImpl(
static std::unordered_set<std::string> skip_dist_ops{
"send", "recv", "send_barrier", "fetch_barrier"};
auto &ops = Get<const std::vector<OpDesc *>>(kAllOpDescs);
auto &ops = graph->Get<const std::vector<OpDesc *>>(kAllOpDescs);
std::vector<ir::Node *> op_node_list;
op_node_list.reserve(ops.size());
......@@ -107,4 +107,4 @@ std::unique_ptr<ir::Graph> SequentialExecutionPass::ApplyImpl(
REGISTER_PASS(sequential_execution_pass,
paddle::framework::details::SequentialExecutionPass)
.RequirePassAttr(paddle::framework::details::kAllOpDescs);
.RequireGraphAttr(paddle::framework::details::kAllOpDescs);
......@@ -76,6 +76,9 @@ std::map<std::string, std::vector<ir::Node *>> Graph::InitFromProgram(
var->inputs.push_back(node);
}
}
Set<const std::vector<OpDesc *>>(
details::kAllOpDescs,
new std::vector<OpDesc *>(program.Block(0).AllOps()));
return var_nodes;
}
......
......@@ -195,12 +195,6 @@ class Graph {
return nullptr;
}
// Returns reference to the original program.
// WARN: After a series of passes, the current graph can be quite
// different from OriginProgram. Caller shouldn't assume much from
// the returned OriginProgram.
const ProgramDesc &OriginProgram() const { return program_; }
// This method takes ownership of `node`.
ir::Node *AddNode(ir::Node *node) {
PADDLE_ENFORCE(node_set_.find(node) == node_set_.end());
......
# copyright (c) 2018 paddlepaddle authors. all rights reserved.
#
# licensed under the apache license, version 2.0 (the "license");
# you may not use this file except in compliance with the license.
# you may obtain a copy of the license at
#
# http://www.apache.org/licenses/license-2.0
#
# unless required by applicable law or agreed to in writing, software
# distributed under the license is distributed on an "as is" basis,
# without warranties or conditions of any kind, either express or implied.
# see the license for the specific language governing permissions and
# limitations under the license.
import unittest
import random
import numpy as np
import paddle.fluid as fluid
import six
from paddle.fluid.framework import Program
from paddle.fluid.framework import IrGraph
from paddle.fluid.contrib.slim.quantization import QuantizationTransformPass
from paddle.fluid import core
def linear_fc(num):
data = fluid.layers.data(name='image', shape=[1, 28, 28], dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
hidden = data
for _ in six.moves.xrange(num):
hidden = fluid.layers.fc(hidden, size=128, act='relu')
fc = fluid.layers.fc(input=hidden, size=10)
loss = fluid.layers.softmax_with_cross_entropy(fc, label=label)
loss = fluid.layers.mean(loss)
return loss
def residual_block(num):
def conv_bn_layer(input,
ch_out,
filter_size,
stride,
padding,
act='relu',
bias_attr=False):
tmp = fluid.layers.conv2d(
input=input,
filter_size=filter_size,
num_filters=ch_out,
stride=stride,
padding=padding,
act=None,
bias_attr=bias_attr)
return fluid.layers.batch_norm(input=tmp, act=act)
data = fluid.layers.data(name='image', shape=[1, 28, 28], dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
hidden = data
for _ in six.moves.xrange(num):
conv = conv_bn_layer(hidden, 16, 3, 1, 1, act=None, bias_attr=True)
short = conv_bn_layer(hidden, 16, 1, 1, 0, act=None)
hidden = fluid.layers.elementwise_add(x=conv, y=short, act='relu')
fc = fluid.layers.fc(input=hidden, size=10)
loss = fluid.layers.softmax_with_cross_entropy(fc, label)
loss = fluid.layers.mean(loss)
return loss
class TestQuantizationTransformPass(unittest.TestCase):
def setUp(self):
self.quantizable_op_and_inputs = {
'conv2d': ['Input', 'Filter'],
'depthwise_conv2d': ['Input', 'Filter'],
'mul': ['X', 'Y']
}
self.quantizable_grad_op_inputs = {
'conv2d_grad': ['Input', 'Filter'],
'depthwise_conv2d_grad': ['Input', 'Filter'],
'mul_grad': ['X', 'Y']
}
def check_program(self, transform_pass, program):
quantized_ops = set()
for block in program.blocks:
for op in block.ops:
# check forward
if op.type in self.quantizable_op_and_inputs:
for arg_name in op.input_arg_names:
self.assertTrue(
arg_name.endswith('.quantized.dequantized'))
quantized_ops.add(arg_name)
for op in block.ops:
# check backward
if op.type in self.quantizable_grad_op_inputs:
for pname in self.quantizable_grad_op_inputs[op.type]:
arg_name = op.input(pname)[0]
self.assertTrue(
arg_name.endswith('.quantized.dequantized'))
self.assertTrue(arg_name in quantized_ops)
def linear_fc_quant(self, quant_type):
main = fluid.Program()
startup = fluid.Program()
with fluid.program_guard(main, startup):
loss = linear_fc(3)
opt = fluid.optimizer.Adam(learning_rate=0.001)
opt.minimize(loss)
exe = fluid.Executor(fluid.CPUPlace())
graph = IrGraph(core.Graph(main.desc), for_test=False)
transform_pass = QuantizationTransformPass(
scope=fluid.global_scope(),
program_exe=exe,
activation_quantize_type=quant_type)
transform_pass.apply(graph)
marked_nodes = set()
for op in graph.all_ops():
if op.name().find('quantize') > -1:
marked_nodes.add(op)
graph.draw('.', 'quantize_fc_' + quant_type, marked_nodes)
program = graph.to_program()
self.check_program(transform_pass, program)
val_graph = IrGraph(core.Graph(program.desc), for_test=False)
val_marked_nodes = set()
for op in val_graph.all_ops():
if op.name().find('quantize') > -1:
val_marked_nodes.add(op)
val_graph.draw('.', 'val_fc_' + quant_type, val_marked_nodes)
def test_linear_fc_quant_abs_max(self):
self.act_quant_op_type = 'fake_quantize_abs_max'
self.linear_fc_quant('abs_max')
def test_linear_fc_quant_range_abs_max(self):
self.act_quant_op_type = 'fake_quantize_range_abs_max'
self.linear_fc_quant('range_abs_max')
def residual_block_quant(self, quant_type):
main = fluid.Program()
startup = fluid.Program()
with fluid.program_guard(main, startup):
loss = residual_block(2)
opt = fluid.optimizer.Adam(learning_rate=0.001)
opt.minimize(loss)
exe = fluid.Executor(fluid.CPUPlace())
graph = IrGraph(core.Graph(main.desc), for_test=False)
transform_pass = QuantizationTransformPass(
scope=fluid.global_scope(),
program_exe=exe,
activation_quantize_type=quant_type)
transform_pass.apply(graph)
marked_nodes = set()
for op in graph.all_ops():
if op.name().find('quantize') > -1:
marked_nodes.add(op)
graph.draw('.', 'quantize_residual_' + quant_type, marked_nodes)
program = graph.to_program()
self.check_program(transform_pass, program)
val_graph = IrGraph(core.Graph(program.desc), for_test=False)
val_marked_nodes = set()
for op in val_graph.all_ops():
if op.name().find('quantize') > -1:
val_marked_nodes.add(op)
val_graph.draw('.', 'val_residual_' + quant_type, val_marked_nodes)
def test_residual_block_abs_max(self):
self.act_quant_op_type = 'fake_quantize_abs_max'
self.residual_block_quant('abs_max')
def test_residual_block_range_abs_max(self):
self.act_quant_op_type = 'fake_quantize_range_abs_max'
self.residual_block_quant('range_abs_max')
def test_execute_graph(self):
main = fluid.Program()
startup = fluid.Program()
with fluid.program_guard(main, startup):
loss = linear_fc(3)
opt = fluid.optimizer.Adam(learning_rate=0.0001)
opt.minimize(loss)
exe = fluid.Executor(fluid.CPUPlace())
graph = IrGraph(core.Graph(main.desc), for_test=False)
exe.run(startup)
binary = fluid.CompiledProgram(graph.graph).with_data_parallel(
loss_name=loss.name)
for i in range(10):
loss_val = exe.run(binary,
feed={
'image': np.ones(
[32, 784], dtype=np.float32),
'label': np.ones(
[32, 1], dtype=np.int64)
},
fetch_list=[loss])
if i == 0:
start_loss = np.sum(loss_val)
elif i == 9:
end_loss = np.sum(loss_val)
self.assertLess(end_loss, start_loss)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册