提交 a3a4b6e5 编写于 作者: B baojun 提交者: Tao Luo

Enable ngraph through build_strategy (#19266)

* enable ngraph throught build_strategy test=develop

* add unittest test=develop

* put use_ngraph unconditional test=develop

* remove paddle_enforce test=develop

* remove paddle_enforce test=develop

* fix copyright test=develop

* limit for ngraph only test=develop
上级 4cfe432c
...@@ -87,6 +87,12 @@ cc_library(fast_threaded_ssa_graph_executor SRCS fast_threaded_ssa_graph_executo ...@@ -87,6 +87,12 @@ cc_library(fast_threaded_ssa_graph_executor SRCS fast_threaded_ssa_graph_executo
DEPS fetch_op_handle ssa_graph_executor scope simple_threadpool device_context) DEPS fetch_op_handle ssa_graph_executor scope simple_threadpool device_context)
cc_test(fused_broadcast_op_test SRCS fused_broadcast_op_handle_test.cc DEPS fused_broadcast_op_handle) cc_test(fused_broadcast_op_test SRCS fused_broadcast_op_handle_test.cc DEPS fused_broadcast_op_handle)
if(WITH_NGRAPH)
set(NGRAPH_BS_DEPS ngraph)
else()
set(NGRAPH_BS_DEPS)
endif()
cc_library(build_strategy SRCS build_strategy.cc DEPS cc_library(build_strategy SRCS build_strategy.cc DEPS
graph_viz_pass multi_devices_graph_pass graph_viz_pass multi_devices_graph_pass
multi_devices_graph_print_pass multi_devices_graph_check_pass multi_devices_graph_print_pass multi_devices_graph_check_pass
...@@ -94,4 +100,5 @@ cc_library(build_strategy SRCS build_strategy.cc DEPS ...@@ -94,4 +100,5 @@ cc_library(build_strategy SRCS build_strategy.cc DEPS
fuse_relu_depthwise_conv_pass fuse_relu_depthwise_conv_pass
lock_free_optimize_pass lock_free_optimize_pass
coalesce_grad_tensor_pass fuse_all_reduce_op_pass backward_optimizer_op_deps_pass coalesce_grad_tensor_pass fuse_all_reduce_op_pass backward_optimizer_op_deps_pass
fuse_adam_op_pass fuse_sgd_op_pass fuse_momentum_op_pass) fuse_adam_op_pass fuse_sgd_op_pass fuse_momentum_op_pass
${NGRAPH_BS_DEPS})
...@@ -27,6 +27,7 @@ limitations under the License. */ ...@@ -27,6 +27,7 @@ limitations under the License. */
#include "paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_pass.h" #include "paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_pass.h"
DECLARE_bool(use_mkldnn); DECLARE_bool(use_mkldnn);
DECLARE_bool(use_ngraph);
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -53,6 +54,8 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder { ...@@ -53,6 +54,8 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
"sequential_execution_pass"); "sequential_execution_pass");
AppendPassWithCheck(strategy_.sync_batch_norm_, "sync_batch_norm_pass"); AppendPassWithCheck(strategy_.sync_batch_norm_, "sync_batch_norm_pass");
AppendPassToUseNgraph("ngraph_subgraph_pass");
AppendOpFusePasses(); AppendOpFusePasses();
AppendPrintGraphPass("graph_viz_pass", "_fused_graph"); AppendPrintGraphPass("graph_viz_pass", "_fused_graph");
...@@ -220,6 +223,22 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder { ...@@ -220,6 +223,22 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
#endif #endif
} }
void AppendPassToUseNgraph(const std::string &pass_name) {
#ifdef PADDLE_WITH_NGRAPH
if (FLAGS_use_ngraph) {
if (strategy_.reduce_ != BuildStrategy::ReduceStrategy::kAllReduce) {
LOG(WARNING) << "Currently ngraph_subgraph_pass works under AllReduce,"
"please set FLAGS_use_ngraph=false.";
} else {
AppendPass(pass_name);
}
}
#else
PADDLE_ENFORCE_NE(FLAGS_use_ngraph, true,
"Please compile with NGRAPH first to use NGRAPH");
#endif
}
private: private:
BuildStrategy strategy_; BuildStrategy strategy_;
}; };
...@@ -360,3 +379,6 @@ USE_PASS(runtime_context_cache_pass); ...@@ -360,3 +379,6 @@ USE_PASS(runtime_context_cache_pass);
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
USE_PASS(mkldnn_placement_pass); USE_PASS(mkldnn_placement_pass);
#endif #endif
#ifdef PADDLE_WITH_NGRAPH
USE_PASS(ngraph_subgraph_pass);
#endif
...@@ -39,11 +39,11 @@ limitations under the License. */ ...@@ -39,11 +39,11 @@ limitations under the License. */
#ifdef PADDLE_WITH_NGRAPH #ifdef PADDLE_WITH_NGRAPH
#include "paddle/fluid/operators/ngraph/ngraph_engine.h" #include "paddle/fluid/operators/ngraph/ngraph_engine.h"
DEFINE_bool(use_ngraph, false, "Use NGRAPH to run");
#endif #endif
DECLARE_bool(benchmark); DECLARE_bool(benchmark);
DEFINE_bool(use_mkldnn, false, "Use MKLDNN to run"); DEFINE_bool(use_mkldnn, false, "Use MKLDNN to run");
DEFINE_bool(use_ngraph, false, "Use NGRAPH to run");
namespace paddle { namespace paddle {
namespace framework { namespace framework {
......
...@@ -47,8 +47,8 @@ std::string GenerateEngineKey(const std::set<std::string> &engine_inputs, ...@@ -47,8 +47,8 @@ std::string GenerateEngineKey(const std::set<std::string> &engine_inputs,
return engine_key; return engine_key;
} }
void NgraphSubgraphPass::ApplyImpl(ir::Graph *graph) const { void NgraphSubgraphPass::ApplyImpl(Graph *graph) const {
PADDLE_ENFORCE(graph); PADDLE_ENFORCE_NOT_NULL(graph);
FusePassBase::Init("ngraph_subgraph_pass", graph); FusePassBase::Init("ngraph_subgraph_pass", graph);
std::unordered_set<Node *> nodes2delete; std::unordered_set<Node *> nodes2delete;
...@@ -66,15 +66,13 @@ void NgraphSubgraphPass::ApplyImpl(ir::Graph *graph) const { ...@@ -66,15 +66,13 @@ void NgraphSubgraphPass::ApplyImpl(ir::Graph *graph) const {
if (node->IsOp() && !ANAT::Agent(node).subgraph()->empty()) { if (node->IsOp() && !ANAT::Agent(node).subgraph()->empty()) {
OpDesc *op_desc = node->Op(); OpDesc *op_desc = node->Op();
op_desc->SetType("ngraph_engine"); op_desc->SetType("ngraph_engine");
for (auto it = ANAT::Agent(node).subgraph()->begin();
it != ANAT::Agent(node).subgraph()->end(); ++it) {
}
CreateNgraphEngineOp(node, graph); CreateNgraphEngineOp(node, graph);
std::unordered_set<const Node *> nodes2remove( std::unordered_set<const Node *> nodes2remove(
ANAT::Agent(node).subgraph()->begin(), ANAT::Agent(node).subgraph()->begin(),
ANAT::Agent(node).subgraph()->end()); ANAT::Agent(node).subgraph()->end());
GraphSafeRemoveNodes(graph, nodes2remove); GraphSafeRemoveNodes(graph, nodes2remove);
} }
} }
...@@ -85,70 +83,100 @@ void NgraphSubgraphPass::ApplyImpl(ir::Graph *graph) const { ...@@ -85,70 +83,100 @@ void NgraphSubgraphPass::ApplyImpl(ir::Graph *graph) const {
nodes2remove.insert(node); nodes2remove.insert(node);
} }
} }
framework::ir::GraphSafeRemoveNodes(graph, nodes2remove); framework::ir::GraphSafeRemoveNodes(graph, nodes2remove);
std::vector<ir::Node *> nodes = ir::TopologySortOperations(*graph); // std::vector<ir::Node *> nodes = ir::TopologySortOperations(*graph);
} }
void NgraphSubgraphPass::CreateNgraphEngineOp(framework::ir::Node *node, bool IsValid(std::string name) {
Graph *graph) const { return name.find(Node::kControlDepVarName) == std::string::npos;
auto *op_desc = node->Op(); }
void UpdateNgraphIO(Node *node, Graph *graph,
std::vector<std::string> *input_names,
std::vector<std::string> *output_names) {
bool is_test = true, has_fetch = false;
for (Node *node : graph->Nodes()) {
if (node->IsOp() && node->Name().find("_grad") != std::string::npos) {
is_test = false;
}
if (node->IsVar() && node->Var()) {
for (auto out : node->outputs) {
if (out->Name() == "fetch") has_fetch = true;
}
}
}
if (is_test && has_fetch) {
for (auto *x : node->inputs) {
(*input_names).emplace_back(x->Name());
}
for (auto *x : node->outputs) {
(*output_names).emplace_back(x->Name());
}
return;
}
auto &subgraph = *ANAT::Agent(node).subgraph(); auto &subgraph = *ANAT::Agent(node).subgraph();
PADDLE_ENFORCE(!subgraph.empty()); std::unordered_set<std::string> inputs;
std::unordered_set<std::string> outputs;
for (auto *node : subgraph) {
for (auto in : node->inputs) {
auto name = in->Name();
if (!IsValid(name)) continue;
if (!outputs.count(name) && !inputs.count(name)) {
(*input_names).emplace_back(name);
inputs.insert(name);
}
}
for (auto out : node->outputs) {
auto name = out->Name();
if (!IsValid(name)) continue;
outputs.insert(name);
(*output_names).emplace_back(name);
}
}
}
framework::ProgramDesc *program_desc = void NgraphSubgraphPass::CreateNgraphEngineOp(Node *node, Graph *graph) const {
Get<framework::ProgramDesc *>("program"); auto &subgraph = *ANAT::Agent(node).subgraph();
const framework::BlockDesc &main_block = PADDLE_ENFORCE_NE(subgraph.empty(), true, "subgraph cannot be empty");
program_desc->Block(framework::kRootBlockIndex);
framework::BlockDesc *new_block = program_desc->AppendBlock(main_block);
framework::proto::BlockDesc block_proto; framework::proto::BlockDesc block_proto;
framework::BlockDesc block_desc(nullptr, &block_proto); framework::BlockDesc block_desc(nullptr, &block_proto);
block_desc.Proto()->set_parent_idx(-1); block_desc.Proto()->set_parent_idx(-1);
block_desc.Proto()->set_idx(0); block_desc.Proto()->set_idx(0);
for (auto *node : subgraph) { for (auto *node : subgraph) {
auto *new_block_op = new_block->AppendOp();
auto *op = block_desc.AppendOp(); auto *op = block_desc.AppendOp();
*new_block_op->Proto() = *node->Op()->Proto();
*op->Proto() = *node->Op()->Proto(); *op->Proto() = *node->Op()->Proto();
} }
std::set<std::string> input_names;
std::set<std::string> input_names_with_id;
for (auto *x : node->inputs) {
input_names.insert(x->Name());
input_names_with_id.insert(x->Name() + std::to_string(x->id()));
}
op_desc->SetInput(
"Xs", std::vector<std::string>(input_names.begin(), input_names.end()));
std::set<std::string> output_names;
std::set<std::string> output_names_with_id;
for (auto *x : node->outputs) {
output_names.insert(x->Name());
output_names_with_id.insert(x->Name() + std::to_string(x->id()));
}
op_desc->SetOutput(
"Ys", std::vector<std::string>(output_names.begin(), output_names.end()));
auto *vars = block_desc.Proto()->mutable_vars(); auto *vars = block_desc.Proto()->mutable_vars();
for (framework::ir::Node *node : graph->Nodes()) { for (Node *node : graph->Nodes()) {
if (node->IsVar() && node->Var()) { if (node->IsVar() && node->Var()) {
*vars->Add() = *node->Var()->Proto(); *vars->Add() = *node->Var()->Proto();
} }
} }
PADDLE_ENFORCE_NE(block_desc.Proto()->vars().empty(), true,
PADDLE_ENFORCE(!block_desc.Proto()->vars().empty(),
"the block has no var-desc"); "the block has no var-desc");
op_desc->SetType("ngraph_engine"); std::vector<std::string> input_names;
std::vector<std::string> output_names;
UpdateNgraphIO(node, graph, &input_names, &output_names);
auto *op_desc = node->Op();
op_desc->SetInput(
"Xs", std::vector<std::string>(input_names.begin(), input_names.end()));
op_desc->SetOutput(
"Ys", std::vector<std::string>(output_names.begin(), output_names.end()));
int sgs = subgraph.size(); int sgs = subgraph.size();
std::string engine_key = GenerateEngineKey( std::string subgraph_str = block_desc.Proto()->SerializeAsString();
input_names_with_id, output_names_with_id, std::to_string(sgs)); std::string engine_key =
std::to_string(std::hash<std::string>()(subgraph_str));
std::vector<int> interval{0, sgs}; std::vector<int> interval{0, sgs};
op_desc->SetType("ngraph_engine");
op_desc->SetAttr("interval", interval); op_desc->SetAttr("interval", interval);
op_desc->SetAttr("graph", block_desc.Proto()->SerializeAsString()); op_desc->SetAttr("graph", subgraph_str);
op_desc->SetAttr("engine_key", engine_key); op_desc->SetAttr("engine_key", engine_key);
op_desc->SetAttr("op_role", 0);
} }
} // namespace ir } // namespace ir
......
...@@ -21,6 +21,8 @@ limitations under the License. */ ...@@ -21,6 +21,8 @@ limitations under the License. */
#include "paddle/fluid/framework/ir/graph_pattern_detector.h" #include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/node.h" #include "paddle/fluid/framework/ir/node.h"
DECLARE_bool(use_ngraph);
namespace paddle { namespace paddle {
namespace inference { namespace inference {
namespace analysis { namespace analysis {
...@@ -398,6 +400,11 @@ void RemoveIntermediateOutputInSubgraph(const std::vector<Node *> &subgraph, ...@@ -398,6 +400,11 @@ void RemoveIntermediateOutputInSubgraph(const std::vector<Node *> &subgraph,
} }
} }
// In use for ngraph subgraph pass for parallel executor,
// this will remove all nodes, bypass this and let ngraph
// subgraph pass to process outputs
if (FLAGS_use_ngraph && valid_output.size() == 0) return;
outputs->assign(valid_output.begin(), valid_output.end()); outputs->assign(valid_output.begin(), valid_output.end());
} }
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
from paddle.fluid.tests.unittests.simple_nets import simple_fc_net
import paddle.fluid as fluid
import paddle.fluid.core as core
from paddle.fluid import compiler
import numpy as np
import unittest
import os
import sys
import math
class TestPallelExecutorNgraph(unittest.TestCase):
def check_network_convergence(self, build_strategy=None):
os.environ['CPU_NUM'] = str(2)
main = fluid.Program()
startup = fluid.Program()
with fluid.program_guard(main, startup):
loss = simple_fc_net()
test_program = main.clone(for_test=True)
opt = fluid.optimizer.Adam(learning_rate=0.001)
opt.minimize(loss)
batch_size = 32
image = np.random.normal(size=(batch_size, 784)).astype('float32')
label = np.random.randint(0, 10, (batch_size, 1), dtype="int64")
place = fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(startup)
feed_dict = {'image': image, 'label': label}
train_cp = compiler.CompiledProgram(main).with_data_parallel(
loss_name=loss.name, build_strategy=build_strategy)
test_cp = compiler.CompiledProgram(test_program).with_data_parallel(
loss_name=loss.name,
build_strategy=build_strategy,
share_vars_from=train_cp)
for i in range(5):
_ = exe.run(train_cp, fetch_list=[loss.name], feed=feed_dict)
test_loss, = exe.run(test_cp,
fetch_list=[loss.name],
feed=feed_dict)
train_loss = exe.run(train_cp,
fetch_list=[loss.name],
feed=feed_dict)
avg_test_loss_val = np.array(test_loss).mean()
if math.isnan(float(avg_test_loss_val)):
sys.exit("got NaN loss, testing failed.")
avg_train_loss_val = np.array(train_loss).mean()
if math.isnan(float(avg_train_loss_val)):
sys.exit("got NaN loss, training failed.")
self.assertTrue(
np.allclose(
train_loss, test_loss, atol=1e-8),
"Train loss: " + str(train_loss) + "\n Test loss:" +
str(test_loss))
def test_parallel_testing(self):
build_strategy = fluid.BuildStrategy()
build_strategy.enable_inplace = False
build_strategy.memory_optimize = False
self.check_network_convergence(build_strategy=build_strategy)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册