提交 32a20149 编写于 作者: D dzhwinter

refine build strategy. test=develop

上级 1a44b2fb
......@@ -44,28 +44,18 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
public:
explicit ParallelExecutorPassBuilder(const BuildStrategy &strategy)
: ir::PassBuilder(), strategy_(strategy) {
if (strategy_.enable_inplace_) {
// before inplaced
// if (!strategy_.debug_graphviz_path_.empty()) {
// const std::string path = strategy_.debug_graphviz_path_ +
// "before_inplaced";
// auto pass = AppendPass("graph_print_pass");
// pass->Set<std::string>(kGraphvizPath, new std::string(path));
// }
if (strategy_.enable_sequential_execution_) {
AppendPass("sequential_execution_pass");
}
AppendPass("inplace_pass");
// after inplaced
// if (!strategy_.debug_graphviz_path_.empty()) {
// const std::string path = strategy_.debug_graphviz_path_ +
// "after_inplaced";
// auto pass = AppendPass("graph_print_pass");
// pass->Set<std::string>(details::kGraphvizPath, new
// std::string(path));
// }
// Add op fusion.
if (strategy.fuse_relu_depthwise_conv_) {
AppendPass("fuse_relu_depthwise_conv_pass");
}
if (strategy_.enable_sequential_execution_) {
AppendPass("sequential_execution_pass");
// Add automatically inplace.
if (strategy_.enable_inplace_) {
AppendPass("inplace_pass");
}
// Add a graph viz pass to record a graph.
......@@ -76,10 +66,6 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
viz_pass->Set<std::string>("graph_viz_path", new std::string(graph_path));
}
// Add op fusion.
if (strategy.fuse_relu_depthwise_conv_) {
AppendPass("fuse_relu_depthwise_conv_pass");
}
if (strategy.fuse_elewise_add_act_ops_) {
auto fuse_elewise_add_act_pass = AppendPass("fuse_elewise_add_act_pass");
// Add a graph viz pass to record a graph.
......
......@@ -74,40 +74,6 @@ std::vector<T*> FilterByNodeWrapper(const Container& con) {
return ret;
}
// bool DetectCircleRecursive(const std::map<ir::Node*,
// std::unordered_set<ir::Node*>>, std::unordered_set<ir::Node*>* visited,
// std::unordered_set<ir::Node*> *in_trace, std::vector<std::vector<ir::Node*>>*
// circles) {
// if (visited->find(node) == visited->end()) {
// visited->insert(node);
// in_trace->insert(node);
// for (ir::Node *in : adj_list.at(node)) {
// if (visited->find(in) == visited->end() &&
// HasCircleHelper(in, adj_list, visited, in_trace)) {
// return true;
// } else if (in_trace->find(in) != in_trace->end()) {
// circles->push_back(in_trace);
// return true;
// }
// }
// }
// in_trace->erase(node);
// return false;
// }
// bool DetectCircle(const std::map<ir::Node*, std::unordered_set<ir::Node*>>&
// adj_list, std::vector<std::vector<ir::Node*>>* circles) {
// std::unordered_set<ir::Node *> visited;
// std::unordered_set<ir::Node *> in_trace;
// bool has_circle = false;
// for(auto& adj : adj_list) {
// has_circle &= DetectCircleRecursive(adj, adj_list,&visited, &in_trace,
// circles);
// }
// return has_circle;
// }
std::unordered_map<ir::Node*, int> SSAGraphPrinterImpl::ToGraphvizNode(
const ir::Graph& graph) const {
// Convert to GraphvizNode format
......@@ -125,8 +91,6 @@ std::unordered_map<ir::Node*, int> SSAGraphPrinterImpl::ToGraphvizNode(
std::unique_ptr<GraphvizOp> op(new GraphvizOp(node, op_id++));
ops[node] = op.get();
graphviz_nodes.emplace(std::move(op));
// graphviz_nodes.emplace(new GraphvizOp(node, op_id++));
// ops.emplace(std::make_pair(node, graphviz_nodes.back().get()));
} else {
PADDLE_THROW("Unknown op type");
}
......
......@@ -100,6 +100,7 @@ static inline ir::Node* GetNextCascadeInplacedVar(ir::Node* var) {
static inline ir::Node* GetPrevCascadeInplacedVar(ir::Node* var) {
PADDLE_ENFORCE(var && var->IsVar() && !var->IsCtrlVar());
if (var->inputs.empty()) return nullptr;
auto* prev_op = var->inputs.at(0);
auto input_it = std::find_if(prev_op->inputs.begin(), prev_op->inputs.end(),
[&](ir::Node* node) {
......@@ -165,12 +166,6 @@ std::unique_ptr<ir::Graph> InplacePass::ApplyImpl(
view_.Build(graph.get());
InitSSAGraphNodes();
std::unique_ptr<SSAGraphPrinter> printer(new SSAGraphPrinterImpl);
constexpr char graph_path1[] = "ir_graph_before_inplaced.txt";
std::unique_ptr<std::ostream> fout1(new std::ofstream(graph_path1));
PADDLE_ENFORCE(fout1->good());
printer->Print(*graph, *fout1);
for (auto* op : view_.AllOps()) {
if (FLAGS_enable_inplace_whitelist && !whitelist_.count(op->Name()))
continue;
......@@ -178,10 +173,6 @@ std::unique_ptr<ir::Graph> InplacePass::ApplyImpl(
}
graph->ResolveHazard(var_nodes_);
constexpr char graph_path[] = "ir_graph_inplaced.txt";
std::unique_ptr<std::ostream> fout(new std::ofstream(graph_path));
PADDLE_ENFORCE(fout->good());
printer->Print(*graph, *fout);
return graph;
}
......@@ -291,6 +282,7 @@ void InplacePass::WithdrawModify(const SSANodePair& nodes,
void InplacePass::TryInplaceOpInputOutput(ir::Node* op,
ir::Graph* graph) const {
VLOG(4) << "Try to inplace op " << op->Name();
PADDLE_ENFORCE(op->Op() != nullptr && op->Op()->Block() != nullptr,
"op_desc is nullptr");
// 4 pre-requirments need to meet if the op want to inplaced.
......
......@@ -25,6 +25,7 @@ import paddle.fluid.layers as layers
import paddle.fluid.optimizer as optimizer
from paddle.fluid.framework import Program, program_guard
from paddle.fluid.io import save_inference_model, load_inference_model
from paddle.fluid.transpiler import memory_optimize
class TestBook(unittest.TestCase):
......@@ -86,5 +87,31 @@ class TestBook(unittest.TestCase):
self.assertEqual(expected, actual)
class TestSaveInferenceModel(unittest.TestCase):
def test_save_inference_model(self):
MODEL_DIR = "./tmp/inference_model2"
init_program = Program()
program = Program()
# fake program without feed/fetch
with program_guard(program, init_program):
x = layers.data(name='x', shape=[2], dtype='float32')
y = layers.data(name='y', shape=[1], dtype='float32')
y_predict = layers.fc(input=x, size=1, act=None)
cost = layers.square_error_cost(input=y_predict, label=y)
avg_cost = layers.mean(cost)
place = core.CPUPlace()
exe = executor.Executor(place)
exe.run(init_program, feed={}, fetch_list=[])
memory_optimize(program, print_log=True)
self.assertRaises(RuntimeError,
save_inference_model(MODEL_DIR, ["x", "y"],
[avg_cost], exe, program))
if __name__ == '__main__':
unittest.main()
......@@ -277,7 +277,7 @@ class TestResnet(TestParallelExecutorBase):
use_cuda=True,
use_reduce=False,
iter=20,
delta2=1e-6):
delta2=1e-5):
if use_cuda and not core.is_compiled_with_cuda():
return
......@@ -308,7 +308,7 @@ class TestResnet(TestParallelExecutorBase):
optimizer=optimizer)
self.assertAlmostEquals(
np.mean(parallel_first_loss), single_first_loss[0], delta=1e-6)
np.mean(parallel_first_loss), single_first_loss[0], delta=1e-5)
self.assertAlmostEquals(
np.mean(parallel_last_loss), single_last_loss[0], delta=delta2)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册