提交 a2969614 编写于 作者: B baojun-nervana

Added annotation

test=develop
上级 d6125a5e
...@@ -57,6 +57,7 @@ typedef enum { /* nGraph support state on ops */ ...@@ -57,6 +57,7 @@ typedef enum { /* nGraph support state on ops */
PARTIAL_TEST /* Support partial list of ops for test */ PARTIAL_TEST /* Support partial list of ops for test */
} op_state; } op_state;
// perform graph build through bridge and execute computation
class NgraphOperator { class NgraphOperator {
public: public:
explicit NgraphOperator(const Scope& scope, const platform::Place& place, explicit NgraphOperator(const Scope& scope, const platform::Place& place,
...@@ -100,33 +101,33 @@ class NgraphOperator { ...@@ -100,33 +101,33 @@ class NgraphOperator {
std::unordered_set<std::string> post_op_inputs_; std::unordered_set<std::string> post_op_inputs_;
op_state ng_op_state_; op_state ng_op_state_;
// ngraph backend eg. CPU
static std::shared_ptr<ngraph::runtime::Backend> backend_; static std::shared_ptr<ngraph::runtime::Backend> backend_;
// ngraph function to call and execute
std::shared_ptr<ngraph::Function> ngraph_function_; std::shared_ptr<ngraph::Function> ngraph_function_;
// var_name of inputs // var_name of inputs
std::vector<std::string> var_in_; std::vector<std::string> var_in_;
// var_name of outputs from fetch in order // var_name of outputs from fetch in order
std::vector<std::string> var_out_; std::vector<std::string> var_out_;
// map input vars to nodes
std::shared_ptr< std::shared_ptr<
std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>> std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
var_in_node_map_; var_in_node_map_;
// map each var name with a ngraph node // map each var name with a ngraph node
std::shared_ptr< std::shared_ptr<
std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>> std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
var_node_map_; var_node_map_;
// cache key to check if function is cached
std::shared_ptr<std::string> GetCacheKey(); std::shared_ptr<std::string> GetCacheKey();
// get ngraph input and define ngraph input parameters
void GetNgInputShape(std::shared_ptr<OperatorBase> op); void GetNgInputShape(std::shared_ptr<OperatorBase> op);
// Call ngraph bridge to map ops
void BuildNgNode(); void BuildNgNode();
// get the ngraph input and output var list
void BuildNgIO(); void BuildNgIO();
// build ngraph function call
void BuildNgFunction(); void BuildNgFunction();
// Check cache for ngraph function or otherwise build the function
void GetNgFunction(); void GetNgFunction();
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册