提交 e3c37bd5 编写于 作者: B baojun 提交者: tensor-tang

remove const_cast and refactor ngraph engine code (#15925)

* remove concast_cast and refactor code test=develop

* reduce flag use test=develop
上级 09799566
...@@ -34,11 +34,11 @@ limitations under the License. */ ...@@ -34,11 +34,11 @@ limitations under the License. */
#ifdef PADDLE_WITH_NGRAPH #ifdef PADDLE_WITH_NGRAPH
#include "paddle/fluid/operators/ngraph/ngraph_engine.h" #include "paddle/fluid/operators/ngraph/ngraph_engine.h"
DEFINE_bool(use_ngraph, false, "Use NGRAPH to run");
#endif #endif
DECLARE_bool(benchmark); DECLARE_bool(benchmark);
DEFINE_bool(use_mkldnn, false, "Use MKLDNN to run"); DEFINE_bool(use_mkldnn, false, "Use MKLDNN to run");
DEFINE_bool(use_ngraph, false, "Use NGRAPH to run");
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -194,9 +194,6 @@ void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id, ...@@ -194,9 +194,6 @@ void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id,
bool force_disable_gc) { bool force_disable_gc) {
platform::RecordBlock b(block_id); platform::RecordBlock b(block_id);
if (FLAGS_use_mkldnn) EnableMKLDNN(pdesc); if (FLAGS_use_mkldnn) EnableMKLDNN(pdesc);
#ifdef PADDLE_WITH_NGRAPH
if (FLAGS_use_ngraph) operators::NgraphEngine::EnableNgraph(pdesc);
#endif
auto ctx = Prepare(pdesc, block_id, skip_ref_cnt_vars, force_disable_gc); auto ctx = Prepare(pdesc, block_id, skip_ref_cnt_vars, force_disable_gc);
RunPreparedContext(ctx.get(), scope, create_local_scope, create_vars); RunPreparedContext(ctx.get(), scope, create_local_scope, create_vars);
} }
...@@ -372,6 +369,12 @@ std::unique_ptr<ExecutorPrepareContext> Executor::Prepare( ...@@ -372,6 +369,12 @@ std::unique_ptr<ExecutorPrepareContext> Executor::Prepare(
for (auto& op_desc : block.AllOps()) { for (auto& op_desc : block.AllOps()) {
ctx->ops_.push_back(OpRegistry::CreateOp(*op_desc)); ctx->ops_.push_back(OpRegistry::CreateOp(*op_desc));
} }
#ifdef PADDLE_WITH_NGRAPH
if (FLAGS_use_ngraph) {
paddle::operators::NgraphEngine::FuseNgraphOps(
ctx->prog_.Block(ctx->block_id_), &ctx->ops_);
}
#endif
return ctx; return ctx;
} }
......
...@@ -12,12 +12,18 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,12 +12,18 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#ifndef PADDLE_FLUID_OPERATORS_NGRAPH_NGRAPH_ENGINE_H_
#define PADDLE_FLUID_OPERATORS_NGRAPH_NGRAPH_ENGINE_H_
#include <memory>
#include <set>
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include <unordered_set>
#include <vector> #include <vector>
#include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/var_desc.h"
#include "ngraph/ngraph.hpp" #include "ngraph/ngraph.hpp"
...@@ -33,29 +39,47 @@ enum class OpState { /* nGraph support state on ops */ ...@@ -33,29 +39,47 @@ enum class OpState { /* nGraph support state on ops */
UNKNOWN /* Output all for debug purpose */ UNKNOWN /* Output all for debug purpose */
}; };
// cache engine repetitives
struct EngineCache {
std::shared_ptr<ngraph::Function> ngraph_function;
std::set<std::string> persistables;
std::vector<std::string> var_in;
std::vector<std::string> var_out;
std::vector<size_t> var_in_updates;
bool is_test = true;
};
// perform graph build through bridge and execute computation // perform graph build through bridge and execute computation
class NgraphEngine { class NgraphEngine {
public: public:
explicit NgraphEngine(const framework::Scope& scope, explicit NgraphEngine(const framework::Scope& scope,
const platform::Place& place, const platform::Place& place,
const std::string& serialized_graph, const framework::ExecutionContext& ctx);
const std::vector<int>& interval);
void Run(const framework::Scope& scope, const platform::Place& place) const; void Run(const framework::Scope& scope, const platform::Place& place) const;
static void EnableNgraph(const framework::ProgramDesc& program); static const framework::BlockDesc* p_bdesc;
static std::vector<std::string> feed_vars, fetch_vars;
static void FuseNgraphOps(
const framework::BlockDesc& prog,
std::vector<std::unique_ptr<framework::OperatorBase>>* ops);
private: private:
static std::unordered_map<std::string, std::shared_ptr<ngraph::Function>> static std::unordered_map<std::string, EngineCache> engine_cache;
func_cache_; static std::unordered_map<
std::string, std::vector<std::shared_ptr<ngraph::runtime::Tensor>>>
t_in_cache_;
static framework::Variable* pre_var_ptr;
const framework::Scope& scope_; const framework::Scope& scope_;
const platform::Place& place_; const platform::Place& place_;
std::vector<std::shared_ptr<framework::OperatorBase>> fused_ops_; std::vector<std::shared_ptr<framework::OperatorBase>> fused_ops_;
std::unordered_map<std::string, ngraph::element::Type> var_type_map_; std::unordered_map<std::string, ngraph::element::Type> var_type_map_;
std::unordered_set<std::string> persistables_; std::set<std::string> persistables_;
std::unordered_set<std::string> fetches_;
std::unordered_set<std::string> post_op_inputs_; std::unordered_set<std::string> post_op_inputs_;
OpState ng_op_state_ = OpState::UNKNOWN; OpState op_state_ = OpState::UNKNOWN;
bool is_test_{true};
std::string func_cache_key_; std::string func_cache_key_;
// ngraph backend eg. CPU // ngraph backend eg. CPU
...@@ -66,6 +90,8 @@ class NgraphEngine { ...@@ -66,6 +90,8 @@ class NgraphEngine {
std::vector<std::string> var_in_; std::vector<std::string> var_in_;
// var_name of outputs from fetch in order // var_name of outputs from fetch in order
std::vector<std::string> var_out_; std::vector<std::string> var_out_;
// non-persitable var_in
std::vector<size_t> var_in_updates_;
// map input vars to nodes // map input vars to nodes
std::shared_ptr< std::shared_ptr<
std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>> std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
...@@ -74,20 +100,23 @@ class NgraphEngine { ...@@ -74,20 +100,23 @@ class NgraphEngine {
std::shared_ptr< std::shared_ptr<
std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>> std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
var_node_map_; var_node_map_;
// prepare info for nraph engine // prepare info for ngraph engine need
void Prepare(const framework::BlockDesc& block, void Prepare(const std::vector<int>& interval);
const std::vector<int>& interval); // get ngraph engine input and output list
void BuildNgIO(const std::vector<framework::OpDesc*>& op_descs,
const std::vector<int>& interval);
// get ngraph input and define ngraph input parameters // get ngraph input and define ngraph input parameters
void GetNgInputShape(std::shared_ptr<framework::OperatorBase> op); void GetNgInputShape();
// Call ngraph bridge to map ops // Call ngraph bridge to map ops
void BuildNgNodes(); void BuildNgNodes();
// get the ngraph input and output var list // run paddle RuntimeInferShape to get the tensor shape
void BuildNgIO(); void RunInferShape();
// build ngraph function call // build ngraph function call
void BuildNgFunction(); void BuildNgFunction(const std::vector<int>& interval);
// Check cache for ngraph function or otherwise build the function // Check cache for ngraph function or otherwise build the function
void GetNgFunction(); void GetNgFunction(std::string engine_key, const std::vector<int>& interval);
}; };
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
#endif // PADDLE_FLUID_OPERATORS_NGRAPH_NGRAPH_ENGINE_H_
...@@ -29,6 +29,7 @@ class NgraphEngineOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -29,6 +29,7 @@ class NgraphEngineOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput("Xs", "A list of inputs.").AsDispensable(); AddInput("Xs", "A list of inputs.").AsDispensable();
AddOutput("Ys", "A list of outputs").AsDispensable(); AddOutput("Ys", "A list of outputs").AsDispensable();
AddAttr<std::string>("graph", "the graph."); AddAttr<std::string>("graph", "the graph.");
AddAttr<std::string>("engine_key", "the engine hash key.");
AddAttr<std::vector<int>>("interval", "op interval supported by ngraph"); AddAttr<std::vector<int>>("interval", "op interval supported by ngraph");
AddComment("ngraph engine operator."); AddComment("ngraph engine operator.");
} }
......
...@@ -46,10 +46,8 @@ class NgraphEngineKernel : public framework::OpKernel<T> { ...@@ -46,10 +46,8 @@ class NgraphEngineKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
auto& scope = ctx.scope(); auto& scope = ctx.scope();
auto place = ctx.GetPlace(); auto place = ctx.GetPlace();
std::string serialized_graph = ctx.Attr<std::string>("graph");
auto interval = ctx.Attr<std::vector<int>>("interval");
NgraphEngine ngraph_engine(scope, place, serialized_graph, interval); NgraphEngine ngraph_engine(scope, place, ctx);
ngraph_engine.Run(scope, place); ngraph_engine.Run(scope, place);
} }
}; };
......
...@@ -94,6 +94,14 @@ bool IsCompiledWithMKLDNN() { ...@@ -94,6 +94,14 @@ bool IsCompiledWithMKLDNN() {
#endif #endif
} }
bool IsCompiledWithNGRAPH() {
#ifndef PADDLE_WITH_NGRAPH
return false;
#else
return true;
#endif
}
bool IsCompiledWithBrpc() { bool IsCompiledWithBrpc() {
#ifndef PADDLE_WITH_DISTRIBUTE #ifndef PADDLE_WITH_DISTRIBUTE
return false; return false;
...@@ -874,6 +882,7 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -874,6 +882,7 @@ All parameter, weight, gradient are variables in Paddle.
m.def("init_devices", m.def("init_devices",
[](bool init_p2p) { framework::InitDevices(init_p2p); }); [](bool init_p2p) { framework::InitDevices(init_p2p); });
m.def("is_compiled_with_ngraph", IsCompiledWithNGRAPH);
m.def("is_compiled_with_cuda", IsCompiledWithCUDA); m.def("is_compiled_with_cuda", IsCompiledWithCUDA);
m.def("is_compiled_with_mkldnn", IsCompiledWithMKLDNN); m.def("is_compiled_with_mkldnn", IsCompiledWithMKLDNN);
m.def("is_compiled_with_brpc", IsCompiledWithBrpc); m.def("is_compiled_with_brpc", IsCompiledWithBrpc);
......
...@@ -125,7 +125,7 @@ def __bootstrap__(): ...@@ -125,7 +125,7 @@ def __bootstrap__():
os.environ['OMP_NUM_THREADS'] = str(num_threads) os.environ['OMP_NUM_THREADS'] = str(num_threads)
sysstr = platform.system() sysstr = platform.system()
read_env_flags = [ read_env_flags = [
'check_nan_inf', 'benchmark', 'eager_delete_scope', 'use_ngraph', 'check_nan_inf', 'benchmark', 'eager_delete_scope',
'initial_cpu_memory_in_mb', 'init_allocated_mem', 'free_idle_memory', 'initial_cpu_memory_in_mb', 'init_allocated_mem', 'free_idle_memory',
'paddle_num_threads', "dist_threadpool_size", 'eager_delete_tensor_gb', 'paddle_num_threads', "dist_threadpool_size", 'eager_delete_tensor_gb',
'fast_eager_deletion_mode', 'memory_fraction_of_eager_deletion', 'fast_eager_deletion_mode', 'memory_fraction_of_eager_deletion',
...@@ -143,6 +143,9 @@ def __bootstrap__(): ...@@ -143,6 +143,9 @@ def __bootstrap__():
if core.is_compiled_with_mkldnn(): if core.is_compiled_with_mkldnn():
read_env_flags.append('use_mkldnn') read_env_flags.append('use_mkldnn')
if core.is_compiled_with_ngraph():
read_env_flags.append('use_ngraph')
if core.is_compiled_with_dist(): if core.is_compiled_with_dist():
read_env_flags.append('rpc_deadline') read_env_flags.append('rpc_deadline')
read_env_flags.append('rpc_server_profile_path') read_env_flags.append('rpc_server_profile_path')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册