提交 c67b29c1 编写于 作者: W WangZhen

Fix some bugs in graph.to_program and get_pass.

上级 c64f2204
...@@ -58,7 +58,6 @@ void BindGraph(py::module *m) { ...@@ -58,7 +58,6 @@ void BindGraph(py::module *m) {
.def("get_float", &Graph::Get<float>) .def("get_float", &Graph::Get<float>)
.def("get_double", &Graph::Get<double>) .def("get_double", &Graph::Get<double>)
.def("get_string", &Graph::Get<std::string>) .def("get_string", &Graph::Get<std::string>)
.def("get_program", &Graph::Get<ProgramDesc>)
.def("get_marked_nodes", &Graph::Get<std::unordered_set<const Node *>>) .def("get_marked_nodes", &Graph::Get<std::unordered_set<const Node *>>)
.def("set", [](Graph &self, const std::string &attr_name, .def("set", [](Graph &self, const std::string &attr_name,
int attr) { return self.Set(attr_name, new int(attr)); }) int attr) { return self.Set(attr_name, new int(attr)); })
...@@ -75,11 +74,6 @@ void BindGraph(py::module *m) { ...@@ -75,11 +74,6 @@ void BindGraph(py::module *m) {
[](Graph &self, const std::string &attr_name, double attr) { [](Graph &self, const std::string &attr_name, double attr) {
return self.Set(attr_name, new double(attr)); return self.Set(attr_name, new double(attr));
}) })
.def("set",
[](Graph &self, const std::string &attr_name,
const ProgramDesc &attr) {
return self.Set(attr_name, new ProgramDesc(attr));
})
.def("set", .def("set",
[](Graph &self, const std::string &attr_name, [](Graph &self, const std::string &attr_name,
const std::unordered_set<const Node *> &attr) { const std::unordered_set<const Node *> &attr) {
......
...@@ -788,8 +788,7 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -788,8 +788,7 @@ All parameter, weight, gradient are variables in Paddle.
m.def("disable_profiler", platform::DisableProfiler); m.def("disable_profiler", platform::DisableProfiler);
m.def("is_profiler_enabled", platform::IsProfileEnabled); m.def("is_profiler_enabled", platform::IsProfileEnabled);
m.def("reset_profiler", platform::ResetProfiler); m.def("reset_profiler", platform::ResetProfiler);
m.def("get_pass", [](const py::bytes &binary_str) { m.def("get_pass", [](const std::string &pass_type) {
std::string pass_type(binary_str);
auto pass = framework::ir::PassRegistry::Instance().Get(pass_type); auto pass = framework::ir::PassRegistry::Instance().Get(pass_type);
return std::shared_ptr<framework::ir::Pass>(std::move(pass)); return std::shared_ptr<framework::ir::Pass>(std::move(pass));
}); });
...@@ -797,10 +796,9 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -797,10 +796,9 @@ All parameter, weight, gradient are variables in Paddle.
py::class_<ir::Pass, std::shared_ptr<ir::Pass>> pass(m, "Pass"); py::class_<ir::Pass, std::shared_ptr<ir::Pass>> pass(m, "Pass");
pass.def(py::init()) pass.def(py::init())
.def("has", &ir::Pass::Has) .def("has", &ir::Pass::Has)
.def("set", .def("set_not_owned",
[](ir::Pass &self, const std::string &attr_name, [](ir::Pass &self, const std::string &attr_name, ProgramDesc &attr) {
const ProgramDesc &attr) { self.SetNotOwned<ProgramDesc>(attr_name, &attr);
return self.Set(attr_name, new ProgramDesc(attr));
}) })
.def( .def(
"set", "set",
...@@ -809,7 +807,6 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -809,7 +807,6 @@ All parameter, weight, gradient are variables in Paddle.
}) })
.def("set", [](ir::Pass &self, const std::string &name, .def("set", [](ir::Pass &self, const std::string &name,
int val) { self.Set<const int>(name, new int(val)); }) int val) { self.Set<const int>(name, new int(val)); })
.def("get_program", &ir::Pass::Get<ProgramDesc>)
.def("type", &ir::Pass::Type) .def("type", &ir::Pass::Type)
.def("apply", [](ir::Pass &self, std::shared_ptr<ir::Graph> graph) { .def("apply", [](ir::Pass &self, std::shared_ptr<ir::Graph> graph) {
std::unique_ptr<ir::Graph> origin_graph(graph.get()); std::unique_ptr<ir::Graph> origin_graph(graph.get());
......
...@@ -248,8 +248,8 @@ class TestQuantizationFreezePass(unittest.TestCase): ...@@ -248,8 +248,8 @@ class TestQuantizationFreezePass(unittest.TestCase):
quantized_main_program = main_graph.to_program() quantized_main_program = main_graph.to_program()
quantized_test_program = test_graph.to_program() quantized_test_program = test_graph.to_program()
iters = 5 iters = 10
batch_size = 8 batch_size = 128
train_reader = paddle.batch( train_reader = paddle.batch(
paddle.reader.shuffle( paddle.reader.shuffle(
......
...@@ -204,7 +204,7 @@ class TestQuantizeTranspiler(unittest.TestCase): ...@@ -204,7 +204,7 @@ class TestQuantizeTranspiler(unittest.TestCase):
build_program(test_program, startup, True) build_program(test_program, startup, True)
test_program = test_program.clone(for_test=True) test_program = test_program.clone(for_test=True)
quant_type = 'abs_max' quant_type = 'range_abs_max'
quant_transpiler = QuantizeTranspiler( quant_transpiler = QuantizeTranspiler(
activation_quantize_type=quant_type) activation_quantize_type=quant_type)
quant_transpiler.training_transpile(main, startup) quant_transpiler.training_transpile(main, startup)
......
...@@ -1683,9 +1683,9 @@ class IrGraph(object): ...@@ -1683,9 +1683,9 @@ class IrGraph(object):
def to_program(self): def to_program(self):
convert_pass = core.get_pass('graph_to_program_pass') convert_pass = core.get_pass('graph_to_program_pass')
convert_pass.set('program', Program().desc) desc = core.ProgramDesc()
convert_pass.set_not_owned('program', desc)
convert_pass.apply(self.graph) convert_pass.apply(self.graph)
desc = convert_pass.get_program('program')
program = Program._construct_from_desc(desc) program = Program._construct_from_desc(desc)
return program return program
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册