未验证 提交 832bd720 编写于 作者: Z Zhen Wang 提交者: GitHub

Merge pull request #15610 from wzzju/quantization_inference_passes

Quantization inference passes
...@@ -13,10 +13,12 @@ ...@@ -13,10 +13,12 @@
// limitations under the License. // limitations under the License.
#include "paddle/fluid/pybind/ir.h" #include "paddle/fluid/pybind/ir.h"
#include <algorithm>
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include <unordered_set> #include <unordered_set>
#include "paddle/fluid/framework/ir/graph.h" #include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h" #include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/node.h" #include "paddle/fluid/framework/ir/node.h"
#include "paddle/fluid/framework/op_desc.h" #include "paddle/fluid/framework/op_desc.h"
...@@ -27,6 +29,10 @@ namespace py = pybind11; ...@@ -27,6 +29,10 @@ namespace py = pybind11;
using paddle::framework::ir::Graph; using paddle::framework::ir::Graph;
using paddle::framework::ir::Node; using paddle::framework::ir::Node;
using paddle::framework::ir::GraphSafeRemoveNodes; using paddle::framework::ir::GraphSafeRemoveNodes;
using paddle::framework::ir::HasCircle;
using paddle::framework::ir::GraphNum;
using paddle::framework::ir::TopologySortOperations;
using paddle::framework::ir::BuildOperationAdjList;
using paddle::framework::OpDesc; using paddle::framework::OpDesc;
using paddle::framework::ProgramDesc; using paddle::framework::ProgramDesc;
using paddle::framework::VarDesc; using paddle::framework::VarDesc;
...@@ -36,6 +42,12 @@ namespace paddle { ...@@ -36,6 +42,12 @@ namespace paddle {
namespace pybind { namespace pybind {
void BindGraph(py::module *m) { void BindGraph(py::module *m) {
m->def("graph_safe_remove_nodes", GraphSafeRemoveNodes); m->def("graph_safe_remove_nodes", GraphSafeRemoveNodes);
m->def("has_circle", HasCircle);
m->def("graph_num", GraphNum);
m->def("topology_sort", TopologySortOperations,
return_value_policy::reference);
m->def("build_adjacency_list", BuildOperationAdjList,
return_value_policy::reference);
py::class_<Graph, std::shared_ptr<Graph>>( py::class_<Graph, std::shared_ptr<Graph>>(
*m, "Graph", *m, "Graph",
"The graph is a Directed Acyclic Single Static Assignment Graph, see " "The graph is a Directed Acyclic Single Static Assignment Graph, see "
...@@ -46,7 +58,6 @@ void BindGraph(py::module *m) { ...@@ -46,7 +58,6 @@ void BindGraph(py::module *m) {
.def("get_float", &Graph::Get<float>) .def("get_float", &Graph::Get<float>)
.def("get_double", &Graph::Get<double>) .def("get_double", &Graph::Get<double>)
.def("get_string", &Graph::Get<std::string>) .def("get_string", &Graph::Get<std::string>)
.def("get_program", &Graph::Get<ProgramDesc>)
.def("get_marked_nodes", &Graph::Get<std::unordered_set<const Node *>>) .def("get_marked_nodes", &Graph::Get<std::unordered_set<const Node *>>)
.def("set", [](Graph &self, const std::string &attr_name, .def("set", [](Graph &self, const std::string &attr_name,
int attr) { return self.Set(attr_name, new int(attr)); }) int attr) { return self.Set(attr_name, new int(attr)); })
...@@ -63,11 +74,6 @@ void BindGraph(py::module *m) { ...@@ -63,11 +74,6 @@ void BindGraph(py::module *m) {
[](Graph &self, const std::string &attr_name, double attr) { [](Graph &self, const std::string &attr_name, double attr) {
return self.Set(attr_name, new double(attr)); return self.Set(attr_name, new double(attr));
}) })
.def("set",
[](Graph &self, const std::string &attr_name,
const ProgramDesc &attr) {
return self.Set(attr_name, new ProgramDesc(attr));
})
.def("set", .def("set",
[](Graph &self, const std::string &attr_name, [](Graph &self, const std::string &attr_name,
const std::unordered_set<const Node *> &attr) { const std::unordered_set<const Node *> &attr) {
...@@ -108,42 +114,42 @@ void BindNode(py::module *m) { ...@@ -108,42 +114,42 @@ void BindNode(py::module *m) {
.def("is_op", &Node::IsOp) .def("is_op", &Node::IsOp)
.def("is_var", &Node::IsVar) .def("is_var", &Node::IsVar)
.def("is_ctrl_var", &Node::IsCtrlVar) .def("is_ctrl_var", &Node::IsCtrlVar)
.def("clear_inputs", [](Node &self) { self.inputs.clear(); })
.def("inputs_remove", .def("inputs_remove",
[](Node &self, int node_id) { [](Node &self, int node_id) {
for (auto it = self.inputs.begin(); it != self.inputs.end(); auto pos = std::find_if(
it++) { self.inputs.begin(), self.inputs.end(),
if ((*it)->id() == node_id) { [&node_id](const Node *n) { return n->id() == node_id; });
self.inputs.erase(it); if (pos != self.inputs.end()) {
} self.inputs.erase(pos);
} }
}) })
.def("inputs_remove", .def("inputs_remove",
[](Node &self, Node &node) { [](Node &self, Node &node) {
for (auto it = self.inputs.begin(); it != self.inputs.end(); auto pos =
it++) { std::find(self.inputs.begin(), self.inputs.end(), &node);
if (*it == &node) { if (pos != self.inputs.end()) {
self.inputs.erase(it); self.inputs.erase(pos);
}
} }
}) })
.def("inputs_append", .def("inputs_append",
[](Node &self, Node &node) { self.inputs.push_back(&node); }) [](Node &self, Node &node) { self.inputs.push_back(&node); })
.def("clear_outputs", [](Node &self) { self.outputs.clear(); })
.def("outputs_remove", .def("outputs_remove",
[](Node &self, int node_id) { [](Node &self, int node_id) {
for (auto it = self.outputs.begin(); it != self.outputs.end(); auto pos = std::find_if(
it++) { self.outputs.begin(), self.outputs.end(),
if ((*it)->id() == node_id) { [&node_id](const Node *n) { return n->id() == node_id; });
self.outputs.erase(it); if (pos != self.outputs.end()) {
} self.outputs.erase(pos);
} }
}) })
.def("outputs_remove", .def("outputs_remove",
[](Node &self, Node &node) { [](Node &self, Node &node) {
for (auto it = self.outputs.begin(); it != self.outputs.end(); auto pos =
it++) { std::find(self.outputs.begin(), self.outputs.end(), &node);
if (*it == &node) { if (pos != self.outputs.end()) {
self.outputs.erase(it); self.outputs.erase(pos);
}
} }
}) })
.def("outputs_append", .def("outputs_append",
......
...@@ -829,8 +829,7 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -829,8 +829,7 @@ All parameter, weight, gradient are variables in Paddle.
m.def("disable_profiler", platform::DisableProfiler); m.def("disable_profiler", platform::DisableProfiler);
m.def("is_profiler_enabled", platform::IsProfileEnabled); m.def("is_profiler_enabled", platform::IsProfileEnabled);
m.def("reset_profiler", platform::ResetProfiler); m.def("reset_profiler", platform::ResetProfiler);
m.def("get_pass", [](const py::bytes &binary_str) { m.def("get_pass", [](const std::string &pass_type) {
std::string pass_type(binary_str);
auto pass = framework::ir::PassRegistry::Instance().Get(pass_type); auto pass = framework::ir::PassRegistry::Instance().Get(pass_type);
return std::shared_ptr<framework::ir::Pass>(std::move(pass)); return std::shared_ptr<framework::ir::Pass>(std::move(pass));
}); });
...@@ -838,10 +837,9 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -838,10 +837,9 @@ All parameter, weight, gradient are variables in Paddle.
py::class_<ir::Pass, std::shared_ptr<ir::Pass>> pass(m, "Pass"); py::class_<ir::Pass, std::shared_ptr<ir::Pass>> pass(m, "Pass");
pass.def(py::init()) pass.def(py::init())
.def("has", &ir::Pass::Has) .def("has", &ir::Pass::Has)
.def("set", .def("set_not_owned",
[](ir::Pass &self, const std::string &attr_name, [](ir::Pass &self, const std::string &attr_name, ProgramDesc &attr) {
const ProgramDesc &attr) { self.SetNotOwned<ProgramDesc>(attr_name, &attr);
return self.Set(attr_name, new ProgramDesc(attr));
}) })
.def( .def(
"set", "set",
...@@ -850,7 +848,6 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -850,7 +848,6 @@ All parameter, weight, gradient are variables in Paddle.
}) })
.def("set", [](ir::Pass &self, const std::string &name, .def("set", [](ir::Pass &self, const std::string &name,
int val) { self.Set<const int>(name, new int(val)); }) int val) { self.Set<const int>(name, new int(val)); })
.def("get_program", &ir::Pass::Get<ProgramDesc>)
.def("type", &ir::Pass::Type) .def("type", &ir::Pass::Type)
.def("apply", [](ir::Pass &self, std::shared_ptr<ir::Graph> graph) { .def("apply", [](ir::Pass &self, std::shared_ptr<ir::Graph> graph) {
std::unique_ptr<ir::Graph> origin_graph(graph.get()); std::unique_ptr<ir::Graph> origin_graph(graph.get());
......
...@@ -64,6 +64,7 @@ if (WITH_TESTING) ...@@ -64,6 +64,7 @@ if (WITH_TESTING)
add_subdirectory(paddle/dataset/tests) add_subdirectory(paddle/dataset/tests)
add_subdirectory(paddle/fluid/tests) add_subdirectory(paddle/fluid/tests)
add_subdirectory(paddle/fluid/contrib/tests) add_subdirectory(paddle/fluid/contrib/tests)
add_subdirectory(paddle/fluid/contrib/slim/tests)
endif() endif()
install(DIRECTORY ${PADDLE_PYTHON_PACKAGE_DIR} install(DIRECTORY ${PADDLE_PYTHON_PACKAGE_DIR}
DESTINATION opt/paddle/share/wheels DESTINATION opt/paddle/share/wheels
......
file(GLOB TEST_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "test_*.py")
string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}")
foreach(src ${TEST_OPS})
py_test(${src} SRCS ${src}.py)
endforeach()
version: 1.0 version: 1.0
include: ["./unitest/configs/pruners.yaml", "./unitest/configs/pruners_0.yaml"] include: ["./configs/pruners.yaml", "./configs/pruners_0.yaml"]
pruners: pruners:
pruner_1: pruner_1:
class: 'RatioPruner' class: 'RatioPruner'
......
...@@ -18,7 +18,7 @@ import unittest ...@@ -18,7 +18,7 @@ import unittest
class TestFactory(unittest.TestCase): class TestFactory(unittest.TestCase):
def test_parse(self): def test_parse(self):
factory = ConfigFactory('./unitest/configs/config.yaml') factory = ConfigFactory('./configs/config.yaml')
pruner = factory.instance('pruner_1') pruner = factory.instance('pruner_1')
self.assertEquals(pruner.ratios['conv1_1.w'], 0.3) self.assertEquals(pruner.ratios['conv1_1.w'], 0.3)
......
# copyright (c) 2018 paddlepaddle authors. all rights reserved.
#
# licensed under the apache license, version 2.0 (the "license");
# you may not use this file except in compliance with the license.
# you may obtain a copy of the license at
#
# http://www.apache.org/licenses/license-2.0
#
# unless required by applicable law or agreed to in writing, software
# distributed under the license is distributed on an "as is" basis,
# without warranties or conditions of any kind, either express or implied.
# see the license for the specific language governing permissions and
# limitations under the license.
from __future__ import print_function
import unittest
import paddle.fluid as fluid
import six
from paddle.fluid.framework import IrGraph
from paddle.fluid import core
def residual_block(num):
def conv_bn_layer(input,
ch_out,
filter_size,
stride,
padding,
act='relu',
bias_attr=False):
tmp = fluid.layers.conv2d(
input=input,
filter_size=filter_size,
num_filters=ch_out,
stride=stride,
padding=padding,
act=None,
bias_attr=bias_attr)
return fluid.layers.batch_norm(input=tmp, act=act)
data = fluid.layers.data(name='image', shape=[1, 32, 32], dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
hidden = data
for _ in six.moves.xrange(num):
conv = conv_bn_layer(hidden, 16, 3, 1, 1, act=None, bias_attr=True)
short = conv_bn_layer(hidden, 16, 1, 1, 0, act=None)
hidden = fluid.layers.elementwise_add(x=conv, y=short, act='relu')
fc = fluid.layers.fc(input=hidden, size=10)
loss = fluid.layers.cross_entropy(input=fc, label=label)
loss = fluid.layers.mean(loss)
return loss
class TestGraph(unittest.TestCase):
def test_graph_functions(self):
main = fluid.Program()
startup = fluid.Program()
with fluid.program_guard(main, startup):
loss = residual_block(2)
opt = fluid.optimizer.Adam(learning_rate=0.001)
opt.minimize(loss)
graph = IrGraph(core.Graph(main.desc), for_test=False)
marked_nodes = set()
for op in graph.all_ops():
if op.name().find('conv2d') > -1:
marked_nodes.add(op)
graph.draw('.', 'residual', marked_nodes)
self.assertFalse(graph.has_circle())
self.assertEqual(graph.graph_num(), 1)
nodes = graph.topology_sort()
self.assertEqual(len(nodes), len(graph.all_ops()))
nodes_map = graph.build_adjacency_list()
self.assertEqual(len(nodes_map), len(graph.all_ops()))
nodes_num = len(graph.all_nodes())
graph.safe_remove_nodes(marked_nodes)
self.assertEqual(len(graph.all_nodes()), nodes_num - len(marked_nodes))
if __name__ == '__main__':
unittest.main()
...@@ -17,9 +17,12 @@ import random ...@@ -17,9 +17,12 @@ import random
import numpy as np import numpy as np
import paddle.fluid as fluid import paddle.fluid as fluid
import six import six
from paddle.fluid.framework import Program import paddle
from paddle.fluid.framework import IrGraph from paddle.fluid.framework import IrGraph
from paddle.fluid.contrib.slim.quantization import QuantizationTransformPass from paddle.fluid.contrib.slim.quantization import QuantizationTransformPass
from paddle.fluid.contrib.slim.quantization import QuantizationFreezePass
from paddle.fluid.contrib.slim.quantization import ConvertToInt8Pass
from paddle.fluid.contrib.slim.quantization import TransformForMobilePass
from paddle.fluid import core from paddle.fluid import core
...@@ -65,6 +68,28 @@ def residual_block(num): ...@@ -65,6 +68,28 @@ def residual_block(num):
return loss return loss
def conv_net(img, label):
conv_pool_1 = fluid.nets.simple_img_conv_pool(
input=img,
filter_size=5,
num_filters=20,
pool_size=2,
pool_stride=2,
act="relu")
conv_pool_1 = fluid.layers.batch_norm(conv_pool_1)
conv_pool_2 = fluid.nets.simple_img_conv_pool(
input=conv_pool_1,
filter_size=5,
num_filters=50,
pool_size=2,
pool_stride=2,
act="relu")
prediction = fluid.layers.fc(input=conv_pool_2, size=10, act='softmax')
loss = fluid.layers.cross_entropy(input=prediction, label=label)
avg_loss = fluid.layers.mean(loss)
return avg_loss
class TestQuantizationTransformPass(unittest.TestCase): class TestQuantizationTransformPass(unittest.TestCase):
def setUp(self): def setUp(self):
self.quantizable_op_and_inputs = { self.quantizable_op_and_inputs = {
...@@ -171,5 +196,177 @@ class TestQuantizationTransformPass(unittest.TestCase): ...@@ -171,5 +196,177 @@ class TestQuantizationTransformPass(unittest.TestCase):
self.residual_block_quant('range_abs_max') self.residual_block_quant('range_abs_max')
class TestQuantizationFreezePass(unittest.TestCase):
def freeze_graph(self, use_cuda, seed, quant_type):
def build_program(main, startup, is_test):
main.random_seed = seed
startup.random_seed = seed
with fluid.unique_name.guard():
with fluid.program_guard(main, startup):
img = fluid.layers.data(
name='image', shape=[1, 28, 28], dtype='float32')
label = fluid.layers.data(
name='label', shape=[1], dtype='int64')
loss = conv_net(img, label)
if not is_test:
opt = fluid.optimizer.Adam(learning_rate=0.001)
opt.minimize(loss)
return [img, label], loss
random.seed(0)
np.random.seed(0)
main = fluid.Program()
startup = fluid.Program()
test_program = fluid.Program()
feeds, loss = build_program(main, startup, False)
build_program(test_program, startup, True)
test_program = test_program.clone(for_test=True)
main_graph = IrGraph(core.Graph(main.desc), for_test=False)
test_graph = IrGraph(core.Graph(test_program.desc), for_test=True)
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
exe = fluid.Executor(place)
scope = fluid.Scope()
with fluid.scope_guard(scope):
exe.run(startup)
transform_pass = QuantizationTransformPass(
scope=scope, program_exe=exe, activation_quantize_type=quant_type)
transform_pass.apply(main_graph)
transform_pass.apply(test_graph)
dev_name = '_gpu_' if use_cuda else '_cpu_'
marked_nodes = set()
for op in main_graph.all_ops():
if op.name().find('quantize') > -1:
marked_nodes.add(op)
main_graph.draw('.', 'main' + dev_name + quant_type, marked_nodes)
marked_nodes = set()
for op in test_graph.all_ops():
if op.name().find('quantize') > -1:
marked_nodes.add(op)
test_graph.draw('.', 'test' + dev_name + quant_type, marked_nodes)
quantized_main_program = main_graph.to_program()
quantized_test_program = test_graph.to_program()
iters = 5
batch_size = 8
#train_exe = fluid.ParallelExecutor(
# main_program=quantized_main_program,
# use_cuda=bool(use_cuda),
# loss_name=loss.name,
# scope=scope)
train_reader = paddle.batch(
paddle.reader.shuffle(
paddle.dataset.mnist.train(), buf_size=500),
batch_size=batch_size)
test_reader = paddle.batch(
paddle.dataset.mnist.test(), batch_size=batch_size)
feeder = fluid.DataFeeder(feed_list=feeds, place=place)
with fluid.scope_guard(scope):
for _ in range(iters):
data = next(train_reader())
loss_v = exe.run(program=quantized_main_program,
feed=feeder.feed(data),
fetch_list=[loss])
#loss_v = train_exe.run(feed=feeder.feed(data),
# fetch_list=[loss.name])
#print('{}: {}'.format('loss' + dev_name + quant_type, loss_v))
test_data = next(test_reader())
with fluid.program_guard(quantized_test_program):
w_var = fluid.framework._get_var('conv2d_1.w_0.quantized',
quantized_test_program)
# Testing
with fluid.scope_guard(scope):
test_loss1, w_quant = exe.run(program=quantized_test_program,
feed=feeder.feed(test_data),
fetch_list=[loss, w_var])
# Freeze graph for inference, but the weight of fc/conv is still float type.
freeze_pass = QuantizationFreezePass(scope=scope, place=place)
freeze_pass.apply(test_graph)
marked_nodes = set()
for op in test_graph.all_ops():
if op.name().find('quantize') > -1:
marked_nodes.add(op)
test_graph.draw('.', 'test_freeze' + dev_name + quant_type,
marked_nodes)
server_program = test_graph.to_program()
with fluid.scope_guard(scope):
test_loss2, = exe.run(program=server_program,
feed=feeder.feed(test_data),
fetch_list=[loss])
self.assertAlmostEqual(test_loss1, test_loss2, delta=5e-3)
#print('{}: {}'.format('test_loss1' + dev_name + quant_type, test_loss1))
#print('{}: {}'.format('test_loss2' + dev_name + quant_type, test_loss2))
w_freeze = np.array(scope.find_var('conv2d_1.w_0').get_tensor())
# Maybe failed, this is due to the calculation precision
# self.assertAlmostEqual(np.sum(w_freeze), np.sum(w_quant))
#print('{}: {}'.format('w_freeze' + dev_name + quant_type,
# np.sum(w_freeze)))
#print('{}: {}'.format('w_quant' + dev_name + quant_type,
# np.sum(w_quant)))
# Convert parameter to 8-bit.
convert_int8_pass = ConvertToInt8Pass(scope=scope, place=place)
convert_int8_pass.apply(test_graph)
marked_nodes = set()
for op in test_graph.all_ops():
if op.name().find('quantize') > -1:
marked_nodes.add(op)
test_graph.draw('.', 'test_int8' + dev_name + quant_type, marked_nodes)
server_program_int8 = test_graph.to_program()
# Save the 8-bit parameter and model file.
with fluid.scope_guard(scope):
fluid.io.save_inference_model('server_int8' + dev_name + quant_type,
['image', 'label'], [loss], exe,
server_program_int8)
# Test whether the 8-bit parameter and model file can be loaded successfully.
[infer, feed, fetch] = fluid.io.load_inference_model(
'server_int8' + dev_name + quant_type, exe)
# Check the loaded 8-bit weight.
w_8bit = np.array(scope.find_var('conv2d_1.w_0.int8').get_tensor())
self.assertEqual(w_8bit.dtype, np.int8)
self.assertEqual(np.sum(w_8bit), np.sum(w_freeze))
#print('{}: {}'.format('w_8bit' + dev_name + quant_type, np.sum(w_8bit)))
#print('{}: {}'.format('w_freeze' + dev_name + quant_type,
# np.sum(w_freeze)))
mobile_pass = TransformForMobilePass()
mobile_pass.apply(test_graph)
marked_nodes = set()
for op in test_graph.all_ops():
if op.name().find('quantize') > -1:
marked_nodes.add(op)
test_graph.draw('.', 'test_mobile' + dev_name + quant_type,
marked_nodes)
mobile_program = test_graph.to_program()
with fluid.scope_guard(scope):
fluid.io.save_inference_model('mobile_int8' + dev_name + quant_type,
['image', 'label'], [loss], exe,
mobile_program)
def test_freeze_graph_cuda_dynamic(self):
if fluid.core.is_compiled_with_cuda():
with fluid.unique_name.guard():
self.freeze_graph(True, seed=1, quant_type='abs_max')
def test_freeze_graph_cpu_dynamic(self):
with fluid.unique_name.guard():
self.freeze_graph(False, seed=2, quant_type='abs_max')
def test_freeze_graph_cuda_static(self):
if fluid.core.is_compiled_with_cuda():
with fluid.unique_name.guard():
self.freeze_graph(True, seed=1, quant_type='range_abs_max')
def test_freeze_graph_cpu_static(self):
with fluid.unique_name.guard():
self.freeze_graph(False, seed=2, quant_type='range_abs_max')
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -204,9 +204,11 @@ class TestQuantizeTranspiler(unittest.TestCase): ...@@ -204,9 +204,11 @@ class TestQuantizeTranspiler(unittest.TestCase):
build_program(test_program, startup, True) build_program(test_program, startup, True)
test_program = test_program.clone(for_test=True) test_program = test_program.clone(for_test=True)
quant_transpiler = QuantizeTranspiler() quant_type = 'range_abs_max' # 'range_abs_max' or 'abs_max'
quant_transpiler.training_transpile(main) quant_transpiler = QuantizeTranspiler(
quant_transpiler.training_transpile(test_program) activation_quantize_type=quant_type)
quant_transpiler.training_transpile(main, startup)
quant_transpiler.training_transpile(test_program, startup)
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace() place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
exe = fluid.Executor(place) exe = fluid.Executor(place)
......
...@@ -16,6 +16,8 @@ from __future__ import print_function ...@@ -16,6 +16,8 @@ from __future__ import print_function
import collections import collections
from collections import defaultdict from collections import defaultdict
from collections import Iterable
import contextlib
from .wrapped_decorator import signature_safe_contextmanager from .wrapped_decorator import signature_safe_contextmanager
import os import os
import re import re
...@@ -1529,12 +1531,16 @@ class Block(object): ...@@ -1529,12 +1531,16 @@ class Block(object):
class IrGraph(object): class IrGraph(object):
""" """
IrGraph uses core.Graph as the delegation to accomplish the manipulation. Python IrGraph. Beneath it is a core.Graph, which is used for
create a c++ Ir Pass Graph. An IrGraph is just a graph view of
a Program. In an IrGraph, both Variables and Operators are graph
nodes.
""" """
def __init__(self, graph, for_test=False): def __init__(self, graph, for_test=False):
""" """
Construct the IrGraph using core.Graph. Construct an IrGraph using core.Graph.
Args: Args:
graph(core.Graph): C++ Graph. graph(core.Graph): C++ Graph.
for_test(bool): True for the test graph and false for the train graph. for_test(bool): True for the test graph and false for the train graph.
...@@ -1545,23 +1551,81 @@ class IrGraph(object): ...@@ -1545,23 +1551,81 @@ class IrGraph(object):
self._for_test = for_test self._for_test = for_test
def is_test(self): def is_test(self):
"""
If the graph is used for testing, the function returns true. Otherwise, returns false.
"""
return self._for_test return self._for_test
def all_parameters(self): def all_nodes(self):
param_nodes = set() """
for node in self.graph.nodes(): Return all nodes included in the graph as a set.
if node.is_var() and node.var() is not None and node.var( """
).persistable(): return {node for node in self.graph.nodes()}
param_nodes.add(node)
return param_nodes
def all_vars(self): def all_vars(self):
"""
Return all variable nodes included in the graph as a set.
"""
return {node for node in self.graph.nodes() if node.is_var()} return {node for node in self.graph.nodes() if node.is_var()}
def all_persistable_vars(self):
"""
Return all persistable variable nodes included in the graph as a set.
"""
persistable_nodes = set()
for node in self.graph.nodes():
if node.is_var() and node.var() is not None and node.var(
).persistable():
persistable_nodes.add(node)
return persistable_nodes
def all_ops(self): def all_ops(self):
"""
Return all operator nodes included in the graph as a set.
"""
return {node for node in self.graph.nodes() if node.is_op()} return {node for node in self.graph.nodes() if node.is_op()}
def var_node(self, name):
"""
Get a variable node by name from the graph.
Args:
name(str): the name of the variable node.
Raises:
ValueError: The If input's type is not str, or this graph
doesn't have a variable with the giving name.
Returns:
core.Node: the variable node with the giving name.
"""
if not isinstance(name, six.string_types):
raise TypeError(
"var require string as parameter, but get %s instead." %
(type(name)))
target_var_node = None
var_nodes = self.all_vars()
for var_node in var_nodes:
if var_node.name() == name:
target_var_node = var_node
if target_var_node is None:
raise ValueError("var_node %s not in this graph" % name)
return target_var_node
def create_param_node(self, name, var_type, shape, var_dtype): def create_param_node(self, name, var_type, shape, var_dtype):
"""
Create a persistable variable node in the graph. In IrGraph,
it can not distinguish between persistable variables and parameters.
Args:
name(str): the name of the persistable variable node.
vart_type(core.VarDesc.VarType): the type of the persistable variable node.
shape(list): the shape of the persistable variable node.
var_dtype(core.VarDesc.VarType): the data type of the persistable variable node.
Returns:
core.Node: the created persistable variable node.
"""
var_desc = core.VarDesc(name) var_desc = core.VarDesc(name)
var_desc.set_type(var_type) var_desc.set_type(var_type)
var_desc.set_shape(shape) var_desc.set_shape(shape)
...@@ -1570,6 +1634,20 @@ class IrGraph(object): ...@@ -1570,6 +1634,20 @@ class IrGraph(object):
return self.graph.create_var_node(var_desc) return self.graph.create_var_node(var_desc)
def create_var_node(self, name, var_type, shape, var_dtype): def create_var_node(self, name, var_type, shape, var_dtype):
"""
Create a variable node in the graph. The created variable node is
not persistable.
Args:
name(str): the name of the variable node.
vart_type(core.VarDesc.VarType): the type of the variable node.
shape(list): the shape of the variable node.
var_dtype(core.VarDesc.VarType): the data type of the variable node.
Returns:
core.Node: the created variable node.
"""
var_desc = core.VarDesc(name) var_desc = core.VarDesc(name)
var_desc.set_type(var_type) var_desc.set_type(var_type)
var_desc.set_shape(shape) var_desc.set_shape(shape)
...@@ -1577,19 +1655,41 @@ class IrGraph(object): ...@@ -1577,19 +1655,41 @@ class IrGraph(object):
return self.graph.create_var_node(var_desc) return self.graph.create_var_node(var_desc)
def create_var_node_from_desc(self, var_desc): def create_var_node_from_desc(self, var_desc):
"""
Create a variable node by using an existing VarDesc in the graph.
Depend on the giving VarDesc, the created variable node may be persistable.
Args:
var_desc(core.VarDesc): the giving variable description.
Returns:
core.Node: the created variable node.
"""
return self.graph.create_var_node(var_desc) return self.graph.create_var_node(var_desc)
def create_op_node(self, op_type, attrs, inputs, outputs): def create_op_node(self, op_type, attrs, inputs, outputs):
"""
Create a operator node in the graph.
Args:
op_type(str): the type of the operator node.
attrs(dict): the attributes of the operator node.
inputs(dict): the inputs of the operator node.
outputs(dict): the outpus of the operator node.
Returns:
core.Node: the created operator node.
"""
op_desc = core.OpDesc() op_desc = core.OpDesc()
op_desc.set_type(op_type) op_desc.set_type(op_type)
for attr, value in attrs.iteritems(): for attr, value in six.iteritems(attrs):
self._update_desc_attr(op_desc, attr, value) self._update_desc_attr(op_desc, attr, value)
for input_name, var_nodes in inputs.iteritems(): for input_name, var_nodes in six.iteritems(inputs):
if not isinstance(var_nodes, list): if not isinstance(var_nodes, list):
var_nodes = [var_nodes] var_nodes = [var_nodes]
op_desc.set_input(input_name, op_desc.set_input(input_name,
[var_node.name() for var_node in var_nodes]) [var_node.name() for var_node in var_nodes])
for output_name, var_nodes in outputs.iteritems(): for output_name, var_nodes in six.iteritems(outputs):
if not isinstance(var_nodes, list): if not isinstance(var_nodes, list):
var_nodes = [var_nodes] var_nodes = [var_nodes]
op_desc.set_output(output_name, op_desc.set_output(output_name,
...@@ -1597,11 +1697,29 @@ class IrGraph(object): ...@@ -1597,11 +1697,29 @@ class IrGraph(object):
return self.graph.create_op_node(op_desc) return self.graph.create_op_node(op_desc)
def create_op_node_from_desc(self, op_desc): def create_op_node_from_desc(self, op_desc):
"""
Create a operator node by using an existing OpDesc in the graph.
Args:
op_desc(core.VarDesc): the giving operator description.
Returns:
core.Node: the created operator node.
"""
return self.graph.create_op_node(op_desc) return self.graph.create_op_node(op_desc)
def update_input_link(self, old_input_node, new_input_node, op_node): def update_input_link(self, old_input_node, new_input_node, op_node):
assert old_input_node in self.graph.nodes() and new_input_node in self.graph.nodes() and \ """
op_node in self.graph.nodes(), 'Th three arguments must be in the graph nodes.' Update the input's link of a operator node.
Args:
old_input_node(core.Node): the old input node of the giving op_node.
new_input_node(core.Node): the new input node of the giving op_node.
op_node(core.Node): the operator node that is needed to update input's link.
"""
assert old_input_node in self.graph.nodes() and new_input_node in \
self.graph.nodes() and op_node in self.graph.nodes(), \
'The three arguments(old_input_node&new_input_node&op_node) must be in the graph nodes.'
old_input_node.outputs_remove(op_node) old_input_node.outputs_remove(op_node)
op_node.inputs_remove(old_input_node) op_node.inputs_remove(old_input_node)
new_input_node.outputs_append(op_node) new_input_node.outputs_append(op_node)
...@@ -1609,17 +1727,85 @@ class IrGraph(object): ...@@ -1609,17 +1727,85 @@ class IrGraph(object):
op_node.op()._rename_input(old_input_node.name(), new_input_node.name()) op_node.op()._rename_input(old_input_node.name(), new_input_node.name())
def link_to(self, node_in, node_out): def link_to(self, node_in, node_out):
"""
Connect two nodes.
Args:
node_in(core.Node): the input node.
node_out(core.Node): the output node.
"""
assert node_in in self.graph.nodes() and node_out in self.graph.nodes(), \ assert node_in in self.graph.nodes() and node_out in self.graph.nodes(), \
'Th two arguments must be in the graph nodes.' 'The two arguments(node_in&node_out) must be in the graph nodes.'
node_in.outputs_append(node_out) node_in.outputs_append(node_out)
node_out.inputs_append(node_in) node_out.inputs_append(node_in)
def safe_remove_nodes(self, remove_nodes): def safe_remove_nodes(self, remove_nodes):
"""
Remove nodes safely since links connected to these removed nodes are
also removed.
Args:
remove_nodes(set): the nodes prepared to be removed.
"""
if not isinstance(remove_nodes, set): if not isinstance(remove_nodes, set):
if isinstance(remove_nodes, Iterable):
remove_nodes = set(remove_nodes) remove_nodes = set(remove_nodes)
else:
remove_nodes = {remove_nodes}
core.graph_safe_remove_nodes(self.graph, remove_nodes) core.graph_safe_remove_nodes(self.graph, remove_nodes)
def draw(self, save_path, name, marked_nodes=None): def has_circle(self):
"""
Check if the graph has a circle.
Returns:
bool: True if the graph has a circle else False.
"""
return core.has_circle(self.graph)
def graph_num(self):
"""
Count the number of unconnected graphs in this graph.
Returns:
int: the number of unconnected graphs.
"""
return core.graph_num(self.graph)
def topology_sort(self):
"""
Perform the topology sort operation on the graph.
Notes: the `graph` cannot contain a circle.
Returns:
set(core.Node): nodes in topology order.
"""
return core.topology_sort(self.graph)
def build_adjacency_list(self):
"""
Build an adjacency list of operations for the `graph`.
Returns:
dict{core.Node: set(core.Node)}: the adjacency list.
"""
return core.build_adjacency_list(self.graph)
def draw(self, save_path, name, marked_nodes=None, remove_ctr_var=True):
"""
Draw the graph. If `dot` command is installed, the drawn graph
will be saved as pdf file type, otherwise dot file type is used.
Args:
save_path(str): the save path of drawn graph.
name(str): the name of drawn graph.
marked_nodes(set(core.Node)): nodes that are needed to be marked.
Default value is None.
remove_ctr_var(bool): If it is set True, all control variable nodes
in the graph will be removed. Default value is True.
"""
def _convert_to_pdf(dot_file_path): def _convert_to_pdf(dot_file_path):
pdf_save_path = os.path.splitext(dot_file_path)[0] + '.pdf' pdf_save_path = os.path.splitext(dot_file_path)[0] + '.pdf'
exited_code = subprocess.call('dot -Tpdf ' + dot_file_path \ exited_code = subprocess.call('dot -Tpdf ' + dot_file_path \
...@@ -1629,15 +1815,17 @@ class IrGraph(object): ...@@ -1629,15 +1815,17 @@ class IrGraph(object):
print('The {} is saved as the dot filetype.'.format( print('The {} is saved as the dot filetype.'.format(
dot_file_path)) dot_file_path))
if remove_ctr_var:
remove_ctr_vars = set() remove_ctr_vars = set()
ops_num = 0
for node in self.graph.nodes(): for node in self.graph.nodes():
if node.is_ctrl_var(): if node.is_ctrl_var():
remove_ctr_vars.add(node) remove_ctr_vars.add(node)
elif node.is_op(): self.safe_remove_nodes(remove_ctr_vars)
ops_num = 0
for node in self.graph.nodes():
if node.is_op():
ops_num += 1 ops_num += 1
print('Total ops num = {}.'.format(ops_num)) print('Total ops num = {}.'.format(ops_num))
self.safe_remove_nodes(remove_ctr_vars)
if marked_nodes is not None: if marked_nodes is not None:
if not isinstance(marked_nodes, set): if not isinstance(marked_nodes, set):
marked_nodes = set(marked_nodes) marked_nodes = set(marked_nodes)
...@@ -1652,10 +1840,20 @@ class IrGraph(object): ...@@ -1652,10 +1840,20 @@ class IrGraph(object):
_convert_to_pdf(viz_dot_path) _convert_to_pdf(viz_dot_path)
def to_program(self): def to_program(self):
"""
Convert the graph into a Program.
Notes: When the graph includes backward operator nodes, the
conversion process may be failed. Usually, this function is
only used to convert a test graph.
Returns:
Program: a program converted from the graph.
"""
convert_pass = core.get_pass('graph_to_program_pass') convert_pass = core.get_pass('graph_to_program_pass')
convert_pass.set('program', Program().desc) desc = core.ProgramDesc()
convert_pass.set_not_owned('program', desc)
convert_pass.apply(self.graph) convert_pass.apply(self.graph)
desc = convert_pass.get_program('program')
program = Program._construct_from_desc(desc) program = Program._construct_from_desc(desc)
return program return program
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册