diff --git a/paddle/fluid/framework/ngraph_operator.cc b/paddle/fluid/framework/ngraph_operator.cc index 8878917da1683606b3e5602e4981c10b88a7d735..99c4cf0da607cdb5d3282696fc7290187a3f4ed9 100644 --- a/paddle/fluid/framework/ngraph_operator.cc +++ b/paddle/fluid/framework/ngraph_operator.cc @@ -57,6 +57,7 @@ typedef enum { /* nGraph support state on ops */ PARTIAL_TEST /* Support partial list of ops for test */ } op_state; +// perform graph build through bridge and execute computation class NgraphOperator { public: explicit NgraphOperator(const Scope& scope, const platform::Place& place, @@ -100,33 +101,33 @@ class NgraphOperator { std::unordered_set post_op_inputs_; op_state ng_op_state_; + // ngraph backend eg. CPU static std::shared_ptr backend_; - + // ngraph function to call and execute std::shared_ptr ngraph_function_; // var_name of inputs std::vector var_in_; // var_name of outputs from fetch in order std::vector var_out_; - + // map input vars to nodes std::shared_ptr< std::unordered_map>> var_in_node_map_; - // map each var name with a ngraph node std::shared_ptr< std::unordered_map>> var_node_map_; - + // cache key to check if function is cached std::shared_ptr GetCacheKey(); - + // get ngraph input and define ngraph input parameters void GetNgInputShape(std::shared_ptr op); - + // Call ngraph bridge to map ops void BuildNgNode(); - + // get the ngraph input and output var list void BuildNgIO(); - + // build ngraph function call void BuildNgFunction(); - + // Check cache for ngraph function or otherwise build the function void GetNgFunction(); };