提交 c67b29c1 编写于 作者: W WangZhen

fix some bugs of graph.to_program and get_pass.

上级 c64f2204
......@@ -58,7 +58,6 @@ void BindGraph(py::module *m) {
.def("get_float", &Graph::Get<float>)
.def("get_double", &Graph::Get<double>)
.def("get_string", &Graph::Get<std::string>)
.def("get_program", &Graph::Get<ProgramDesc>)
.def("get_marked_nodes", &Graph::Get<std::unordered_set<const Node *>>)
.def("set", [](Graph &self, const std::string &attr_name,
int attr) { return self.Set(attr_name, new int(attr)); })
......@@ -75,11 +74,6 @@ void BindGraph(py::module *m) {
[](Graph &self, const std::string &attr_name, double attr) {
return self.Set(attr_name, new double(attr));
})
.def("set",
[](Graph &self, const std::string &attr_name,
const ProgramDesc &attr) {
return self.Set(attr_name, new ProgramDesc(attr));
})
.def("set",
[](Graph &self, const std::string &attr_name,
const std::unordered_set<const Node *> &attr) {
......
......@@ -788,8 +788,7 @@ All parameter, weight, gradient are variables in Paddle.
m.def("disable_profiler", platform::DisableProfiler);
m.def("is_profiler_enabled", platform::IsProfileEnabled);
m.def("reset_profiler", platform::ResetProfiler);
m.def("get_pass", [](const py::bytes &binary_str) {
std::string pass_type(binary_str);
m.def("get_pass", [](const std::string &pass_type) {
auto pass = framework::ir::PassRegistry::Instance().Get(pass_type);
return std::shared_ptr<framework::ir::Pass>(std::move(pass));
});
......@@ -797,10 +796,9 @@ All parameter, weight, gradient are variables in Paddle.
py::class_<ir::Pass, std::shared_ptr<ir::Pass>> pass(m, "Pass");
pass.def(py::init())
.def("has", &ir::Pass::Has)
.def("set",
[](ir::Pass &self, const std::string &attr_name,
const ProgramDesc &attr) {
return self.Set(attr_name, new ProgramDesc(attr));
.def("set_not_owned",
[](ir::Pass &self, const std::string &attr_name, ProgramDesc &attr) {
self.SetNotOwned<ProgramDesc>(attr_name, &attr);
})
.def(
"set",
......@@ -809,7 +807,6 @@ All parameter, weight, gradient are variables in Paddle.
})
.def("set", [](ir::Pass &self, const std::string &name,
int val) { self.Set<const int>(name, new int(val)); })
.def("get_program", &ir::Pass::Get<ProgramDesc>)
.def("type", &ir::Pass::Type)
.def("apply", [](ir::Pass &self, std::shared_ptr<ir::Graph> graph) {
std::unique_ptr<ir::Graph> origin_graph(graph.get());
......
......@@ -248,8 +248,8 @@ class TestQuantizationFreezePass(unittest.TestCase):
quantized_main_program = main_graph.to_program()
quantized_test_program = test_graph.to_program()
iters = 5
batch_size = 8
iters = 10
batch_size = 128
train_reader = paddle.batch(
paddle.reader.shuffle(
......
......@@ -204,7 +204,7 @@ class TestQuantizeTranspiler(unittest.TestCase):
build_program(test_program, startup, True)
test_program = test_program.clone(for_test=True)
quant_type = 'abs_max'
quant_type = 'range_abs_max'
quant_transpiler = QuantizeTranspiler(
activation_quantize_type=quant_type)
quant_transpiler.training_transpile(main, startup)
......
......@@ -1683,9 +1683,9 @@ class IrGraph(object):
def to_program(self):
convert_pass = core.get_pass('graph_to_program_pass')
convert_pass.set('program', Program().desc)
desc = core.ProgramDesc()
convert_pass.set_not_owned('program', desc)
convert_pass.apply(self.graph)
desc = convert_pass.get_program('program')
program = Program._construct_from_desc(desc)
return program
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册