/** * Copyright 2019-2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef GE_SESSION_SESSION_MANAGER_H_ #define GE_SESSION_SESSION_MANAGER_H_ #include #include #include #include #include #include "common/ge_inner_error_codes.h" #include "ge/ge_api_types.h" #include "session/inner_session.h" namespace ge { using SessionPtr = std::shared_ptr; class SessionManager { friend class GELib; public: Status SetrtContext(rtContext_t rt_context); /// /// @ingroup ge_session /// @brief create session /// @param [in] options session config options /// @param [out] session_id session id /// @return Status result of function /// Status CreateSession(const std::map &options, SessionId &session_id); /// /// @ingroup ge_session /// @brief destroy the session with specific session id /// @param [in] session_id session id /// @return Status result of function /// Status DestroySession(SessionId session_id); /// /// @ingroup ge_session /// @brief add a graph to the session with specific session id /// @param [in] session_id session id /// @param [in] graph_id graph id /// @param [in] graph the graph to add /// @return Status result of function /// Status AddGraph(SessionId session_id, uint32_t graph_id, const ge::Graph &graph); /// /// @ingroup ge_session /// @brief run a graph of the session with specific session id /// @param [in] session_id session id /// @param [in] graph_id graph id /// @param [in] inputs input data /// @param [out] outputs output data /// @return Status result of function /// Status RunGraph(SessionId session_id, uint32_t graph_id, const std::vector &inputs, std::vector &outputs); /// /// @ingroup ge_session /// @brief remove a graph from the session with specific session id /// @param [in] session_id session id /// @param [in] graph_id graph id /// @return Status result of function /// Status RemoveGraph(SessionId session_id, uint32_t graph_id); /// /// @ingroup ge_session /// @brief get variable value from the session with specific session id /// @param [in] session_id session id /// @param [in] name op name /// @param [out] val out value tensor /// @return Status result of function /// Status GetVariable(SessionId session_id, const std::string &name, Tensor &val); /// /// @ingroup ge_session /// @brief run a graph of the session with specific session id for train asynchronously /// @param [in] session_id session id /// @param [in] graph_id graph id /// @param [in] inputs input data /// @param [out] outputs output data /// @return Status result of function /// Status RunGraphAsync(SessionId session_id, uint32_t graph_id, const std::vector &inputs, std::vector &outputs, std::function callback); /// /// @ingroup ge_graph /// @brief me register the callback function to get the result of summary or checkpoin /// @param [in] session_id session id /// @param [in] key: summary or checkpoint /// @param [in] callbak: The real callback object of me /// @return Status result of function /// Status RegisterCallBackFunc( SessionId session_id, const std::string &key, const std::function &)> &callback); bool IsGraphNeedRebuild(SessionId session_id, uint32_t graph_id); private: SessionManager() = default; ~SessionManager() = default; /// /// @ingroup ge_session /// @brief initialize session manager /// @param [in] options session manager config options /// @return Status result of function /// Status Initialize(const std::map &options); /// /// @ingroup ge_session /// @brief finalize session manager /// @return Status result of function /// Status Finalize(); bool HasSession(SessionId session_id); Status GetNextSessionId(SessionId &next_session_id) const; std::map session_manager_map_; std::mutex mutex_; bool init_flag_ = false; }; }; // namespace ge #endif // GE_SESSION_SESSION_MANAGER_H_