Commit e3c37bd5 authored by baojun, committed by tensor-tang

remove const_cast and refactor ngraph engine code (#15925)

* remove const_cast and refactor code test=develop

* reduce flag use test=develop
Parent 09799566
@@ -34,11 +34,11 @@ limitations under the License. */
#ifdef PADDLE_WITH_NGRAPH
#include "paddle/fluid/operators/ngraph/ngraph_engine.h"
DEFINE_bool(use_ngraph, false, "Use NGRAPH to run");
#endif
DECLARE_bool(benchmark);
DEFINE_bool(use_mkldnn, false, "Use MKLDNN to run");
DEFINE_bool(use_ngraph, false, "Use NGRAPH to run");
namespace paddle {
namespace framework {
@@ -194,9 +194,6 @@ void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id,
bool force_disable_gc) {
platform::RecordBlock b(block_id);
if (FLAGS_use_mkldnn) EnableMKLDNN(pdesc);
#ifdef PADDLE_WITH_NGRAPH
if (FLAGS_use_ngraph) operators::NgraphEngine::EnableNgraph(pdesc);
#endif
auto ctx = Prepare(pdesc, block_id, skip_ref_cnt_vars, force_disable_gc);
RunPreparedContext(ctx.get(), scope, create_local_scope, create_vars);
}
@@ -372,6 +369,12 @@ std::unique_ptr<ExecutorPrepareContext> Executor::Prepare(
for (auto& op_desc : block.AllOps()) {
ctx->ops_.push_back(OpRegistry::CreateOp(*op_desc));
}
#ifdef PADDLE_WITH_NGRAPH
if (FLAGS_use_ngraph) {
paddle::operators::NgraphEngine::FuseNgraphOps(
ctx->prog_.Block(ctx->block_id_), &ctx->ops_);
}
#endif
return ctx;
}
@@ -12,12 +12,18 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifndef PADDLE_FLUID_OPERATORS_NGRAPH_NGRAPH_ENGINE_H_
#define PADDLE_FLUID_OPERATORS_NGRAPH_NGRAPH_ENGINE_H_
#include <memory>
#include <set>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/var_desc.h"
#include "ngraph/ngraph.hpp"
@@ -33,29 +39,47 @@ enum class OpState { /* nGraph support state on ops */
UNKNOWN /* Output all for debug purpose */
};
// cache engine state that is reused across runs
struct EngineCache {
std::shared_ptr<ngraph::Function> ngraph_function;
std::set<std::string> persistables;
std::vector<std::string> var_in;
std::vector<std::string> var_out;
std::vector<size_t> var_in_updates;
bool is_test = true;
};
// perform graph build through bridge and execute computation
class NgraphEngine {
public:
explicit NgraphEngine(const framework::Scope& scope,
const platform::Place& place,
const std::string& serialized_graph,
const std::vector<int>& interval);
const framework::ExecutionContext& ctx);
void Run(const framework::Scope& scope, const platform::Place& place) const;
static void EnableNgraph(const framework::ProgramDesc& program);
static const framework::BlockDesc* p_bdesc;
static std::vector<std::string> feed_vars, fetch_vars;
static void FuseNgraphOps(
const framework::BlockDesc& prog,
std::vector<std::unique_ptr<framework::OperatorBase>>* ops);
private:
static std::unordered_map<std::string, std::shared_ptr<ngraph::Function>>
func_cache_;
static std::unordered_map<std::string, EngineCache> engine_cache;
static std::unordered_map<
std::string, std::vector<std::shared_ptr<ngraph::runtime::Tensor>>>
t_in_cache_;
static framework::Variable* pre_var_ptr;
const framework::Scope& scope_;
const platform::Place& place_;
std::vector<std::shared_ptr<framework::OperatorBase>> fused_ops_;
std::unordered_map<std::string, ngraph::element::Type> var_type_map_;
std::unordered_set<std::string> persistables_;
std::unordered_set<std::string> fetches_;
std::set<std::string> persistables_;
std::unordered_set<std::string> post_op_inputs_;
OpState ng_op_state_ = OpState::UNKNOWN;
OpState op_state_ = OpState::UNKNOWN;
bool is_test_{true};
std::string func_cache_key_;
// ngraph backend, e.g. CPU
@@ -66,6 +90,8 @@ class NgraphEngine {
std::vector<std::string> var_in_;
// var_name of outputs from fetch in order
std::vector<std::string> var_out_;
// non-persistable var_in
std::vector<size_t> var_in_updates_;
// map input vars to nodes
std::shared_ptr<
std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
@@ -74,20 +100,23 @@ class NgraphEngine {
std::shared_ptr<
std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
var_node_map_;
// prepare info for the ngraph engine
void Prepare(const framework::BlockDesc& block,
// prepare the info needed by the ngraph engine
void Prepare(const std::vector<int>& interval);
// get ngraph engine input and output list
void BuildNgIO(const std::vector<framework::OpDesc*>& op_descs,
const std::vector<int>& interval);
// get ngraph input and define ngraph input parameters
void GetNgInputShape(std::shared_ptr<framework::OperatorBase> op);
void GetNgInputShape();
// Call ngraph bridge to map ops
void BuildNgNodes();
// get the ngraph input and output var list
void BuildNgIO();
// run paddle RuntimeInferShape to get the tensor shape
void RunInferShape();
// build ngraph function call
void BuildNgFunction();
void BuildNgFunction(const std::vector<int>& interval);
// Check cache for ngraph function or otherwise build the function
void GetNgFunction();
void GetNgFunction(std::string engine_key, const std::vector<int>& interval);
};
} // namespace operators
} // namespace paddle
#endif // PADDLE_FLUID_OPERATORS_NGRAPH_NGRAPH_ENGINE_H_
@@ -29,6 +29,7 @@ class NgraphEngineOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput("Xs", "A list of inputs.").AsDispensable();
AddOutput("Ys", "A list of outputs").AsDispensable();
AddAttr<std::string>("graph", "the graph.");
AddAttr<std::string>("engine_key", "the engine hash key.");
AddAttr<std::vector<int>>("interval", "op interval supported by ngraph");
AddComment("ngraph engine operator.");
}
@@ -46,10 +46,8 @@ class NgraphEngineKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext& ctx) const override {
auto& scope = ctx.scope();
auto place = ctx.GetPlace();
std::string serialized_graph = ctx.Attr<std::string>("graph");
auto interval = ctx.Attr<std::vector<int>>("interval");
NgraphEngine ngraph_engine(scope, place, serialized_graph, interval);
NgraphEngine ngraph_engine(scope, place, ctx);
ngraph_engine.Run(scope, place);
}
};
@@ -94,6 +94,14 @@ bool IsCompiledWithMKLDNN() {
#endif
}
bool IsCompiledWithNGRAPH() {
#ifndef PADDLE_WITH_NGRAPH
return false;
#else
return true;
#endif
}
bool IsCompiledWithBrpc() {
#ifndef PADDLE_WITH_DISTRIBUTE
return false;
@@ -874,6 +882,7 @@ All parameter, weight, gradient are variables in Paddle.
m.def("init_devices",
[](bool init_p2p) { framework::InitDevices(init_p2p); });
m.def("is_compiled_with_ngraph", IsCompiledWithNGRAPH);
m.def("is_compiled_with_cuda", IsCompiledWithCUDA);
m.def("is_compiled_with_mkldnn", IsCompiledWithMKLDNN);
m.def("is_compiled_with_brpc", IsCompiledWithBrpc);
@@ -125,7 +125,7 @@ def __bootstrap__():
os.environ['OMP_NUM_THREADS'] = str(num_threads)
sysstr = platform.system()
read_env_flags = [
'check_nan_inf', 'benchmark', 'eager_delete_scope', 'use_ngraph',
'check_nan_inf', 'benchmark', 'eager_delete_scope',
'initial_cpu_memory_in_mb', 'init_allocated_mem', 'free_idle_memory',
'paddle_num_threads', "dist_threadpool_size", 'eager_delete_tensor_gb',
'fast_eager_deletion_mode', 'memory_fraction_of_eager_deletion',
@@ -143,6 +143,9 @@ def __bootstrap__():
if core.is_compiled_with_mkldnn():
read_env_flags.append('use_mkldnn')
if core.is_compiled_with_ngraph():
read_env_flags.append('use_ngraph')
if core.is_compiled_with_dist():
read_env_flags.append('rpc_deadline')
read_env_flags.append('rpc_server_profile_path')
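With 'use_ngraph' now registered only when core.is_compiled_with_ngraph() is true, enabling the nGraph engine follows the same environment-flag pattern as use_mkldnn. A hedged usage sketch, assuming Paddle's usual FLAGS_<name> environment convention picked up by __bootstrap__ at import time:

import os

# Environment flags are read when paddle.fluid is imported, so set the flag first.
os.environ['FLAGS_use_ngraph'] = 'true'

import paddle.fluid as fluid

# Fail early on builds without PADDLE_WITH_NGRAPH, where the flag is not registered.
if not fluid.core.is_compiled_with_ngraph():
    raise RuntimeError("this Paddle build was compiled without nGraph support")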