未验证 提交 c7b3bfcd 编写于 作者: T Tao Luo 提交者: GitHub

Merge pull request #14376 from baojun-nervana/intel/ngraph_fusedop

Adding fused operator for ngraph 
......@@ -136,6 +136,10 @@ cc_library(version SRCS version.cc)
cc_test(version_test SRCS version_test.cc DEPS version)
cc_library(proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc DEPS shape_inference op_info operator glog version)
cc_library(ngraph_bridge SRCS ngraph_bridge.cc DEPS operator framework_proto)
cc_library(ngraph_operator SRCS ngraph_operator.cc DEPS ngraph_bridge operator op_info device_context tensor scope glog
shape_inference data_transform lod_tensor profiler)
cc_library(op_registry SRCS op_registry.cc DEPS op_proto_maker op_info operator glog proto_desc)
nv_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry)
......@@ -163,10 +167,10 @@ if(WITH_DISTRIBUTE)
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
set_source_files_properties(executor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
else()
cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method graph_to_program_pass)
cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method graph_to_program_pass ngraph_operator)
cc_test(test_naive_executor SRCS naive_executor_test.cc DEPS naive_executor elementwise_add_op)
endif()
if (NOT WIN32)
cc_library(parallel_executor SRCS parallel_executor.cc DEPS
threaded_ssa_graph_executor scope_buffered_ssa_graph_executor
......
......@@ -17,6 +17,7 @@ limitations under the License. */
#include "paddle/fluid/framework/feed_fetch_method.h"
#include "paddle/fluid/framework/lod_rank_table.h"
#include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/framework/ngraph_operator.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/reader.h"
#include "paddle/fluid/operators/detail/macros.h"
......@@ -25,6 +26,7 @@ limitations under the License. */
DECLARE_bool(benchmark);
DEFINE_bool(use_mkldnn, false, "Use MKLDNN to run");
DEFINE_bool(use_ngraph, false, "Use NGRAPH to run");
namespace paddle {
namespace framework {
......@@ -81,6 +83,24 @@ static void DeleteUnusedTensors(const Scope& scope, const OperatorBase* op,
}
}
static void EnableFusedOp(ExecutorPrepareContext* ctx) {
#ifdef PADDLE_WITH_NGRAPH
VLOG(3) << "use_ngraph=True";
auto intervals = FusedOperator::FusedOpIntervals(&ctx->ops_);
for (auto& interval : intervals) {
auto* fused_op = new FusedOperator(ctx->prog_, ctx->block_id_,
interval.at(0), interval.at(1));
*interval[0] = std::unique_ptr<OperatorBase>(fused_op);
}
for (auto it = intervals.rbegin(); it != intervals.rend(); ++it) {
ctx->ops_.erase(it->at(0) + 1, it->at(1));
}
#else
LOG(WARNING)
<< "'NGRAPH' is not supported, Please re-compile with WITH_NGRAPH option";
#endif
}
Executor::Executor(const platform::Place& place) : place_(place) {}
void Executor::Close() {
......@@ -338,6 +358,7 @@ std::unique_ptr<ExecutorPrepareContext> Executor::Prepare(
for (auto& op_desc : block.AllOps()) {
ctx->ops_.push_back(OpRegistry::CreateOp(*op_desc));
}
if (FLAGS_use_ngraph) EnableFusedOp(ctx.get());
return ctx;
}
......@@ -486,6 +507,5 @@ void Executor::EnableMKLDNN(const ProgramDesc& program) {
<< "'MKLDNN' is not supported, Please re-compile with WITH_MKLDNN option";
#endif
}
} // namespace framework
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef PADDLE_WITH_NGRAPH
#include <algorithm>
#include <functional>
#include "paddle/fluid/framework/ngraph_bridge.h"
#include "ngraph/ngraph.hpp"
namespace paddle {
namespace framework {
std::map<std::string,
std::function<void(const std::shared_ptr<OperatorBase>&,
std::shared_ptr<std::unordered_map<
std::string, std::shared_ptr<ngraph::Node>>>)>>
NgraphBridge::NG_NODE_MAP = {};
void NgraphBridge::build_graph(const std::shared_ptr<OperatorBase>& op) {
auto& op_type = op->Type();
NG_NODE_MAP[op_type](op, ngb_node_map);
}
} // namespace framework
} // namespace paddle
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#ifdef PADDLE_WITH_NGRAPH
#include <algorithm>
#include <map>
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/enforce.h"
#include "ngraph/ngraph.hpp"
namespace paddle {
namespace framework {
class NgraphBridge {
public:
static std::map<
std::string,
std::function<void(const std::shared_ptr<OperatorBase>&,
std::shared_ptr<std::unordered_map<
std::string, std::shared_ptr<ngraph::Node>>>)>>
NG_NODE_MAP;
explicit NgraphBridge(
std::shared_ptr<
std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
var_node_map)
: ngb_node_map(var_node_map) {}
void build_graph(const std::shared_ptr<OperatorBase>& op);
private:
std::shared_ptr<
std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
ngb_node_map;
};
} // namespace framework
} // namespace paddle
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef PADDLE_WITH_NGRAPH
#include <glog/logging.h>
#include <algorithm>
#include <map>
#include "paddle/fluid/framework/feed_fetch_type.h"
#include "paddle/fluid/framework/ngraph_operator.h"
#include "paddle/fluid/framework/shape_inference.h"
#include "paddle/fluid/framework/var_desc.h"
#include "paddle/fluid/framework/var_type.h"
namespace paddle {
namespace framework {
static std::map<proto::VarType::Type, ngraph::element::Type> pd2ng_type_map = {
{proto::VarType::FP32, ngraph::element::f32},
{proto::VarType::FP64, ngraph::element::f64},
{proto::VarType::INT32, ngraph::element::i32},
{proto::VarType::INT64, ngraph::element::i64},
{proto::VarType::BOOL, ngraph::element::boolean},
};
typedef enum { /* nGraph support state on ops */
FULL_TRAIN, /* Support full ops for train */
PARTIAL_TRAIN, /* Support partial ops for train */
FULL_TEST, /* Support full list of ops for test */
PARTIAL_TEST /* Support partial list of ops for test */
} op_state;
class NgraphOperator {
public:
explicit NgraphOperator(const Scope& scope, const platform::Place& place,
const std::vector<std::shared_ptr<OperatorBase>>& ops,
const std::unordered_map<
std::string, ngraph::element::Type>& var_type_map,
const std::unordered_set<std::string>& persist,
const std::unordered_set<std::string>& fetches,
const std::unordered_set<std::string>& post_op_inputs,
op_state ng_op_state)
: scope_(scope),
place_(place),
fused_ops_(ops),
var_type_map_(var_type_map),
persistables_(persist),
fetches_(fetches),
post_op_inputs_(post_op_inputs),
ng_op_state_(ng_op_state) {}
void Run(const Scope& scope, const platform::Place& place) const;
private:
static std::unordered_map<std::string, std::shared_ptr<ngraph::Function>>
func_cache;
const Scope& scope_;
const platform::Place& place_;
std::vector<std::shared_ptr<OperatorBase>> fused_ops_;
std::unordered_map<std::string, ngraph::element::Type> var_type_map_;
std::unordered_set<std::string> persistables_;
std::unordered_set<std::string> fetches_;
std::unordered_set<std::string> post_op_inputs_;
op_state ng_op_state_;
};
std::vector<std::vector<std::vector<std::unique_ptr<OperatorBase>>::iterator>>
FusedOperator::FusedOpIntervals(
std::vector<std::unique_ptr<paddle::framework::OperatorBase>>* ops) {
std::vector<std::vector<std::vector<std::unique_ptr<OperatorBase>>::iterator>>
intervals;
if (ops->empty()) {
return intervals;
}
size_t size = ops->size();
size_t left = 0;
while (left < size && ops.at(left)->Type() != kFeedOpType) {
++left;
}
if (left == size) {
return intervals;
}
while (left < size && ops->at(left)->Type() == kFeedOpType) {
++left;
}
size_t right = left;
while (right < size && ops->at(right)->Type() != kFetchOpType) {
++right;
}
if (right == size) {
return intervals;
}
if (left >= right) return intervals;
// (left, right - 1) represents indices between feed and fetch
size_t pivot = left;
while (pivot < right) {
auto op_type = ops->at(pivot)->Type();
if (paddle::framework::NgraphBridge::NG_NODE_MAP.find(op_type) ==
paddle::framework::NgraphBridge::NG_NODE_MAP.end()) {
++pivot;
} else {
size_t start = pivot, end = start;
while (pivot < right &&
(paddle::framework::NgraphBridge::NG_NODE_MAP.find(
ops.at(pivot)->Type()) !=
paddle::framework::NgraphBridge::NG_NODE_MAP.end())) {
++pivot;
++end;
}
std::vector<std::vector<std::unique_ptr<OperatorBase>>::iterator>
interval = {ops->begin() + start, ops->begin() + end};
intervals.push_back(interval);
}
} // end while
return intervals;
}
FusedOperator::FusedOperator(
const ProgramDesc& prog, size_t block_id,
std::vector<std::unique_ptr<OperatorBase>>::iterator start,
std::vector<std::unique_ptr<OperatorBase>>::iterator end,
const std::string& type, const VariableNameMap& inputs,
const VariableNameMap& outputs, const AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs), pdesc(prog), block(block_id) {
for (std::vector<std::unique_ptr<OperatorBase>>::iterator it = start;
it != end; ++it) {
fused_ops_.push_back(std::move(*it));
}
for (std::vector<std::unique_ptr<OperatorBase>>::iterator it = end;
(*it)->Type() != kFetchOpType; ++it) {
for (auto& var_name_item : (*it)->Inputs()) {
for (auto& var_name : var_name_item.second) {
post_op_inputs_.insert(var_name);
}
}
}
if ((*(start - 1))->Type() == kFeedOpType && (*end)->Type() == kFetchOpType) {
is_complete = true;
}
Process();
}
void FusedOperator::Process() {
auto& bdesc = pdesc_.Block(block_);
for (auto& var : bdesc.AllVars()) {
if (!(var->GetType() == proto::VarType::SELECTED_ROWS ||
var->GetType() == proto::VarType::LOD_TENSOR ||
var->GetType() == proto::VarType::LOD_TENSOR_ARRAY)) {
continue;
}
auto var_name = var->Name();
if (var->Name() == framework::kEmptyVarName) {
continue;
}
if (var_name != "fetch" && var_name != "feed") {
auto pd_type = var->GetDataType();
if (pd2ng_type_map.find(pd_type) == pd2ng_type_map.end()) {
PADDLE_THROW("Data type of var %s not found in pd2ng_type_map",
var_name);
}
var_type_map_[var_name] = pd2ng_type_map[pd_type];
}
if (var->Persistable()) {
persistables_.insert(var->Name());
}
}
for (auto* op : bdesc.AllOps()) {
if (op->Type() == kFetchOpType) {
std::string fetch_target_name = op->Input("X")[0];
fetches_.insert(fetch_target_name);
}
}
}
void FusedOperator::RunImpl(const Scope& scope,
const platform::Place& place) const {
op_state ng_op_state = PARTIAL_TEST;
auto& bdesc = pdesc_.Block(block_);
for (auto* op : bdesc.AllOps()) {
if (op->Type().find("_grad") != std::string::npos) {
ng_op_state = PARTIAL_TRAIN;
break;
}
}
if (is_full) {
ng_op_state = ng_op_state == PARTIAL_TEST ? FULL_TEST : FULL_TRAIN;
}
NgraphOperator ngraph_op(scope, place, fused_ops_, var_type_map_,
persistables_, fetches_, post_op_inputs_,
ng_op_state);
ngraph_op.Run(scope, place);
}
} // namespace framework
} // namespace paddle
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#ifdef PADDLE_WITH_NGRAPH
#include <algorithm>
#include <atomic>
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/framework/attribute.h"
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/ngraph_bridge.h"
#include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/op_kernel_type.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/variant.h"
#include "ngraph/ngraph.hpp"
namespace paddle {
namespace framework {
class FusedOperator : public OperatorBase {
public:
static std::vector<
std::vector<std::vector<std::unique_ptr<OperatorBase>>::iterator>>
FusedOpIntervals(
std::vector<std::unique_ptr<paddle::framework::OperatorBase>>* ops);
explicit FusedOperator(
const ProgramDesc& prog, size_t block_id,
std::vector<std::unique_ptr<OperatorBase>>::iterator start,
std::vector<std::unique_ptr<OperatorBase>>::iterator end,
const std::string& type = "fused_op", const VariableNameMap& inputs = {},
const VariableNameMap& outputs = {}, const AttributeMap& attrs = {});
void RunImpl(const Scope& scope, const platform::Place& place) const final;
private:
const ProgramDesc pdesc_;
size_t block_;
std::vector<std::shared_ptr<OperatorBase>> fused_ops_;
std::unordered_map<std::string, ngraph::element::Type> var_type_map_;
std::unordered_set<std::string> persistables_;
std::unordered_set<std::string> fetches_;
std::unordered_set<std::string> post_op_inputs_;
bool is_full_ = false;
void Process();
};
} // namespace framework
} // namespace paddle
#endif
......@@ -112,10 +112,10 @@ def __bootstrap__():
read_env_flags = [
'use_pinned_memory', 'check_nan_inf', 'benchmark', 'warpctc_dir',
'eager_delete_scope', 'use_mkldnn', 'initial_cpu_memory_in_mb',
'init_allocated_mem', 'free_idle_memory', 'paddle_num_threads',
'dist_threadpool_size', 'cpu_deterministic', 'eager_delete_tensor_gb',
'reader_queue_speed_test_mode'
'eager_delete_scope', 'use_mkldnn', 'use_ngraph',
'initial_cpu_memory_in_mb', 'init_allocated_mem', 'free_idle_memory',
'paddle_num_threads', 'dist_threadpool_size', 'cpu_deterministic',
'eager_delete_tensor_gb', 'reader_queue_speed_test_mode'
]
if core.is_compiled_with_dist():
read_env_flags.append('rpc_deadline')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册