C-API design survey
Created by: jacquesqiao
下面是对tensorflow和mxnet的api设计做了一些调研的结论。
共同点:
- 1,都使用标准c对外提供api接口。
- 2,所有对外接口封装到一个文件中,例:c_api.h。
- 3,使用下面的类似结构进行调用状态的传递,封装的语言需要对结果进行check。
struct Status {
const char* msg;
int32_t code;
};
- 4,资源需要有宿主语言来释放。
不同点:
- 1,mxnet使用了exception,标准API通过try catch封装内部接口。TF不使用exception。
- 2,mxnet使用python直接load so的方式进行封装。TF的官方python binding使用了swig。
#849 (closed)):
经过一些讨论得到的结论(- 1,所有API通过一个类似c_api.h方式对外暴露。
- 2,错误信息使用Status的方式,宿主语言做错误检查。
- 3,不使用exception。
- 4,不使用swig,直接封装so的方式。
- 5,资源需要由宿主语言来释放。
实现细节。
TF封装API的完整demo。
- c_api.h中的接口申明。
extern TF_Session* TF_NewSession(TF_Graph* graph, const TF_SessionOptions* opts, TF_Status* status);
- c_api.cc中的接口实现。
TF_Session* TF_NewSession(TF_Graph* graph, const TF_SessionOptions* opt,
TF_Status* status) {
//初始化并做一些检查工作,根据检查结果设置status
Session* session;
status->status = NewSession(opt->options, &session);
return new TF_Session(session, graph);
}
- session.py中封装和session相关的api。通过try ... finally ... 方式进行资源释放。
class Session:
def __init__():
opts = tf_session.TF_NewSessionOptions(target=target, config=config)
try:
status = tf_session.TF_NewStatus()
try:
self._session = tf_session.TF_NewSession(opts, status)
# 状态检查
if tf_session.TF_GetCode(status) != 0:
raise RuntimeError(compat.as_text(tf_session.TF_Message(status)))
finally:
# 释放资源
tf_session.TF_DeleteStatus(status)
finally:
# 释放资源
tf_session.TF_DeleteSessionOptions(opts)
mxnet封装demo
- 所有接口封装在c_api.h中。
extern “C" int MXOptimizerFindCreator(const char *key, OptimizerCreator *out);
-
接口实现在c_api.cc中。用API_BEGIN() 开始,用 API_END() 结束。实际上是使用了try catch。返回值是int,0代表正常,非0代表有异常。
-
接口实现在 c_api.cc中。
int MXOptimizerCreateOptimizer(OptimizerCreator creator,
mx_uint num_param,
const char **keys,
const char **vals,
OptimizerHandle *out) {
API_BEGIN();
# user logic
API_END();
}
- python 调用方式
# 直接从so中load所有的接口。
_LIB = _load_lib()
# check_call检查返回值,
check_call(_LIB.MXOptimizerFindCreator(c_str(name), ctypes.byref(creator)))
def check_call(ret):
"""Check the return value of C API call
This function will raise exception when error occurs.
Wrap every API call with this function
Parameters
----------
ret : int
return value from API calls
"""
if ret != 0:
raise MXNetError(py_str(_LIB.MXGetLastError()))