未验证 提交 ee476bb0 编写于 作者: S slguan 提交者: GitHub

Merge pull request #1288 from taosdata/refact/rpc

fix the bug for authentication
......@@ -394,8 +394,8 @@ void rpcSendResponse(void *handle, int32_t code, void *pCont, int contLen) {
// set msg header
pHead->version = 1;
pHead->msgType = pConn->inType+1;
pHead->spi = 0;
pHead->encrypt = 0;
pHead->spi = pConn->spi;
pHead->encrypt = pConn->encrypt;
pHead->tranId = pConn->inTranId;
pHead->sourceId = pConn->ownId;
pHead->destId = pConn->peerId;
......@@ -545,11 +545,11 @@ static SRpcConn *rpcAllocateClientConn(SRpcInfo *pRpc) {
static SRpcConn *rpcAllocateServerConn(SRpcInfo *pRpc, SRecvInfo *pRecv) {
SRpcConn *pConn = NULL;
char hashstr[40];
char hashstr[40] = {0};
SRpcHead *pHead = (SRpcHead *)pRecv->msg;
sprintf(hashstr, "%x:%x:%x:%d", pRecv->ip, pHead->uid, pHead->sourceId, pRecv->connType);
// check if it is already allocated
SRpcConn **ppConn = (SRpcConn **)(taosGetStrHashData(pRpc->hash, hashstr));
if (ppConn) pConn = *ppConn;
......@@ -567,7 +567,7 @@ static SRpcConn *rpcAllocateServerConn(SRpcInfo *pRpc, SRecvInfo *pRecv) {
pConn->sid = sid;
pConn->tranId = (uint16_t)(rand() & 0xFFFF);
pConn->ownId = htonl(pConn->sid);
if (pRpc->afp && (*pRpc->afp)(pConn->user, &pConn->spi, &pConn->encrypt, pConn->secret, pConn->ckey)) {
if (pRpc->afp && (*pRpc->afp)(pConn->user, &pConn->spi, &pConn->encrypt, pConn->secret, pConn->ckey) < 0) {
tWarn("%s %p, user not there", pRpc->label, pConn);
taosFreeId(pRpc->idPool, sid); // sid shall be released
terrno = TSDB_CODE_INVALID_USER;
......@@ -720,21 +720,12 @@ static SRpcConn *rpcProcessMsgHead(SRpcInfo *pRpc, SRecvInfo *pRecv) {
SRpcHead *pHead = (SRpcHead *)pRecv->msg;
sid = htonl(pHead->destId);
pHead->code = htonl(pHead->code);
pHead->msgLen = (int32_t)htonl((uint32_t)pHead->msgLen);
pHead->port = htons(pHead->port);
if (pHead->msgType >= TSDB_MSG_TYPE_MAX || pHead->msgType <= 0) {
tTrace("%s sid:%d, invalid message type:%d", pRpc->label, sid, pHead->msgType);
terrno = TSDB_CODE_INVALID_MSG_TYPE; return NULL;
}
if (pRecv->msgLen != pHead->msgLen) {
tTrace("%s sid:%d, %s has invalid length, dataLen:%d, msgLen:%d", pRpc->label, sid,
taosMsg[pHead->msgType], pRecv->msgLen, pHead->msgLen);
terrno = TSDB_CODE_INVALID_MSG_LEN; return NULL;
}
if (sid < 0 || sid >= pRpc->sessions) {
tTrace("%s sid:%d, sid is out of range, max sid:%d, %s discarded", pRpc->label, sid,
pRpc->sessions, taosMsg[pHead->msgType]);
......@@ -756,10 +747,14 @@ static SRpcConn *rpcProcessMsgHead(SRpcInfo *pRpc, SRecvInfo *pRecv) {
}
if (pRecv->port) pConn->peerPort = pRecv->port;
if (pHead->port) pConn->peerPort = pHead->port;
if (pHead->port) pConn->peerPort = htons(pHead->port);
if (pHead->uid) pConn->peerUid = pHead->uid;
terrno = rpcCheckAuthentication(pConn, (char *)pHead, pRecv->msgLen);
// code can be transformed only after authentication
pHead->code = htonl(pHead->code);
if (terrno == 0) {
if (pHead->msgType != TSDB_MSG_TYPE_REG && pHead->encrypt) {
// decrypt here
......@@ -906,9 +901,9 @@ static void rpcSendErrorMsgToPeer(SRecvInfo *pRecv, int32_t code) {
pReplyHead->version = pRecvHead->version;
pReplyHead->msgType = (char)(pRecvHead->msgType + 1);
pReplyHead->spi = 0;
pReplyHead->encrypt = 0;
pReplyHead->encrypt = pRecvHead->encrypt;
pReplyHead->tranId = pRecvHead->tranId;
pReplyHead->sourceId = 0;
pReplyHead->sourceId = pRecvHead->destId;
pReplyHead->destId = pRecvHead->sourceId;
memcpy(pReplyHead->user, pRecvHead->user, tListLen(pReplyHead->user));
......@@ -918,7 +913,7 @@ static void rpcSendErrorMsgToPeer(SRecvInfo *pRecv, int32_t code) {
if (code == TSDB_CODE_INVALID_TIME_STAMP) {
// include a time stamp if client's time is not synchronized well
uint8_t *pContent = pReplyHead->content;
timeStamp = taosGetTimestampSec();
timeStamp = htonl(taosGetTimestampSec());
memcpy(pContent, &timeStamp, sizeof(timeStamp));
msgLen += sizeof(timeStamp);
}
......@@ -1031,10 +1026,10 @@ static void rpcProcessRetryTimer(void *param, void *tmrId) {
pConn->retry++;
if (pConn->retry < 4) {
tTrace("%s %p, re-send msg:%s to %s:%hu", pRpc->label, pConn,
taosMsg[pConn->outType], pConn->peerIpstr, pConn->peerPort);
tTrace("%s %p, re-send msg:%s to %s:%hu retry:%d", pRpc->label, pConn,
taosMsg[pConn->outType], pConn->peerIpstr, pConn->peerPort, pConn->retry);
rpcSendMsgToPeer(pConn, pConn->pReqMsg, pConn->reqMsgLen);
taosTmrReset(rpcProcessRetryTimer, tsRpcTimer<<pConn->retry, pConn, pRpc->tmrCtrl, &pConn->pTimer);
taosTmrReset(rpcProcessRetryTimer, tsRpcTimer, pConn, pRpc->tmrCtrl, &pConn->pTimer);
} else {
// close the connection
tTrace("%s %p, failed to send msg:%s to %s:%hu", pRpc->label, pConn,
......@@ -1083,12 +1078,6 @@ static void rpcProcessProgressTimer(void *param, void *tmrId) {
rpcUnlockConn(pConn);
}
static void rpcFreeOutMsg(void *msg) {
if ( msg == NULL ) return;
char *req = ((char *)msg) - sizeof(SRpcReqContext);
free(req);
}
static int32_t rpcCompressRpcMsg(char* pCont, int32_t contLen) {
SRpcHead *pHead = rpcHeadFromCont(pCont);
int32_t finalLen = 0;
......@@ -1160,14 +1149,14 @@ static SRpcHead *rpcDecompressRpcMsg(SRpcHead *pHead) {
return pHead;
}
static int rpcAuthenticateMsg(uint8_t *pMsg, int msgLen, uint8_t *pAuth, uint8_t *pKey) {
static int rpcAuthenticateMsg(void *pMsg, int msgLen, void *pAuth, void *pKey) {
MD5_CTX context;
int ret = -1;
MD5Init(&context);
MD5Update(&context, pKey, TSDB_KEY_LEN);
MD5Update(&context, pMsg, msgLen);
MD5Update(&context, pKey, TSDB_KEY_LEN);
MD5Update(&context, (uint8_t *)pKey, TSDB_KEY_LEN);
MD5Update(&context, (uint8_t *)pMsg, msgLen);
MD5Update(&context, (uint8_t *)pKey, TSDB_KEY_LEN);
MD5Final(&context);
if (memcmp(context.digest, pAuth, sizeof(context.digest)) == 0) ret = 0;
......@@ -1175,18 +1164,16 @@ static int rpcAuthenticateMsg(uint8_t *pMsg, int msgLen, uint8_t *pAuth, uint8_t
return ret;
}
static int rpcBuildAuthHead(uint8_t *pMsg, int msgLen, uint8_t *pAuth, uint8_t *pKey) {
static void rpcBuildAuthHead(void *pMsg, int msgLen, void *pAuth, void *pKey) {
MD5_CTX context;
MD5Init(&context);
MD5Update(&context, pKey, TSDB_KEY_LEN);
MD5Update(&context, (uint8_t *)pKey, TSDB_KEY_LEN);
MD5Update(&context, (uint8_t *)pMsg, msgLen);
MD5Update(&context, pKey, TSDB_KEY_LEN);
MD5Update(&context, (uint8_t *)pKey, TSDB_KEY_LEN);
MD5Final(&context);
memcpy(pAuth, context.digest, sizeof(context.digest));
return 0;
}
static int rpcAddAuthPart(SRpcConn *pConn, char *msg, int msgLen) {
......@@ -1199,7 +1186,7 @@ static int rpcAddAuthPart(SRpcConn *pConn, char *msg, int msgLen) {
pDigest->timeStamp = htonl(taosGetTimestampSec());
msgLen += sizeof(SRpcDigest);
pHead->msgLen = (int32_t)htonl((uint32_t)msgLen);
rpcBuildAuthHead((uint8_t *)pHead, msgLen - TSDB_AUTH_LEN, pDigest->auth, (uint8_t *)pConn->secret);
rpcBuildAuthHead(pHead, msgLen - TSDB_AUTH_LEN, pDigest->auth, pConn->secret);
} else {
pHead->msgLen = (int32_t)htonl((uint32_t)msgLen);
}
......@@ -1212,7 +1199,21 @@ static int rpcCheckAuthentication(SRpcConn *pConn, char *msg, int msgLen) {
SRpcInfo *pRpc = pConn->pRpc;
int32_t code = 0;
if (pConn->spi == 0 ) return 0;
if (pConn->spi == 0) {
pHead->msgLen = (int32_t)htonl((uint32_t)pHead->msgLen);
return 0;
}
if ( !rpcIsReq(pHead->msgType) ) {
// for response, if code is auth failure, it shall bypass the auth process
code = htonl(pHead->code);
if (code==TSDB_CODE_INVALID_TIME_STAMP || code==TSDB_CODE_AUTH_FAILURE || code==TSDB_CODE_INVALID_USER) {
pHead->msgLen = (int32_t)htonl((uint32_t)pHead->msgLen);
return 0;
}
}
code = 0;
if (pHead->spi == pConn->spi) {
// authentication
......@@ -1222,23 +1223,19 @@ static int rpcCheckAuthentication(SRpcConn *pConn, char *msg, int msgLen) {
delta = (int32_t)htonl(pDigest->timeStamp);
delta -= (int32_t)taosGetTimestampSec();
if (abs(delta) > 900) {
tWarn("%s %p, time diff:%d is too big, msg discarded, timestamp:%d", pRpc->label, pConn,
delta, htonl(pDigest->timeStamp));
tWarn("%s %p, time diff:%d is too big, msg discarded", pRpc->label, pConn, delta);
code = TSDB_CODE_INVALID_TIME_STAMP;
} else {
if (rpcAuthenticateMsg((uint8_t *)pHead, msgLen - TSDB_AUTH_LEN, pDigest->auth, (uint8_t *)pConn->secret) < 0) {
if (rpcAuthenticateMsg(pHead, msgLen-TSDB_AUTH_LEN, pDigest->auth, pConn->secret) < 0) {
tError("%s %p, authentication failed, msg discarded", pRpc->label, pConn);
code = TSDB_CODE_AUTH_FAILURE;
} else {
pHead->msgLen -= sizeof(SRpcDigest);
pHead->msgLen = (int32_t)htonl((uint32_t)pHead->msgLen) - sizeof(SRpcDigest);
}
}
} else {
// if it is request or response with code 0, msg shall be discarded
if (rpcIsReq(pHead->msgType) || (pHead->content[0] == 0)) {
tTrace("%s %p, auth spi not matched, msg discarded", pRpc->label, pConn);
code = TSDB_CODE_AUTH_FAILURE;
}
tTrace("%s %p, auth spi not matched, msg discarded", pRpc->label, pConn);
code = TSDB_CODE_AUTH_FAILURE;
}
return code;
......
......@@ -285,9 +285,9 @@ int taosSendUdpData(uint32_t ip, uint16_t port, void *data, int dataLen, void *c
destAdd.sin_addr.s_addr = ip;
destAdd.sin_port = htons(port);
//tTrace("%s msg is sent to 0x%x:%hu len:%d ret:%d localPort:%hu chandle:0x%x", pConn->label, destAdd.sin_addr.s_addr,
// port, dataLen, ret, pConn->localPort, chandle);
int ret = (int)sendto(pConn->fd, data, (size_t)dataLen, 0, (struct sockaddr *)&destAdd, sizeof(destAdd));
tTrace("%s msg is sent to 0x%x:%hu len:%d ret:%d localPort:%hu chandle:0x%x", pConn->label, destAdd.sin_addr.s_addr,
port, dataLen, ret, pConn->localPort, chandle);
return ret;
}
......
......@@ -110,6 +110,7 @@ int main(int argc, char *argv[]) {
rpcInit.user = "michael";
rpcInit.secret = "mypassword";
rpcInit.ckey = "key";
rpcInit.spi = 1;
for (int i=1; i<argc; ++i) {
if (strcmp(argv[i], "-p")==0 && i < argc-1) {
......@@ -128,11 +129,16 @@ int main(int argc, char *argv[]) {
numOfReqs = atoi(argv[++i]);
} else if (strcmp(argv[i], "-a")==0 && i < argc-1) {
appThreads = atoi(argv[++i]);
} else if (strcmp(argv[i], "-d")==0 && i < argc-1) {
rpcDebugFlag = atoi(argv[++i]);
} else if (strcmp(argv[i], "-o")==0 && i < argc-1) {
tsCompressMsgSize = atoi(argv[++i]);
} else if (strcmp(argv[i], "-u")==0 && i < argc-1) {
rpcInit.user = argv[++i];
} else if (strcmp(argv[i], "-k")==0 && i < argc-1) {
rpcInit.secret = argv[++i];
} else if (strcmp(argv[i], "-spi")==0 && i < argc-1) {
rpcInit.spi = atoi(argv[++i]);
} else if (strcmp(argv[i], "-d")==0 && i < argc-1) {
rpcDebugFlag = atoi(argv[++i]);
} else {
printf("\nusage: %s [options] \n", argv[0]);
printf(" [-i ip]: first server IP address, default is:%s\n", serverIp);
......@@ -144,6 +150,9 @@ int main(int argc, char *argv[]) {
printf(" [-a threads]: number of app threads, default is:%d\n", appThreads);
printf(" [-n requests]: number of requests per thread, default is:%d\n", numOfReqs);
printf(" [-o compSize]: compression message size, default is:%d\n", tsCompressMsgSize);
printf(" [-u user]: user name for the connection, default is:%s\n", rpcInit.user);
printf(" [-k secret]: password for the connection, default is:%s\n", rpcInit.secret);
printf(" [-spi SPI]: security parameter index, default is:%d\n", rpcInit.spi);
printf(" [-d debugFlag]: debug flag, default:%d\n", rpcDebugFlag);
printf(" [-h help]: print out this help\n\n");
exit(0);
......
......@@ -68,6 +68,26 @@ void processShellMsg(int numOfMsgs, SRpcMsg *pMsg) {
}
int retrieveAuthInfo(char *meterId, char *spi, char *encrypt, char *secret, char *ckey) {
// app shall retrieve the auth info based on meterID from DB or a data file
// demo code here only for simple demo
int ret = 0;
if (strcmp(meterId, "michael") == 0) {
*spi = 1;
*encrypt = 0;
strcpy(secret, "mypassword");
strcpy(ckey, "key");
} else if (strcmp(meterId, "jeff") == 0) {
*spi = 0;
*encrypt = 0;
} else {
ret = -1; // user not there
}
return ret;
}
void processRequestMsg(char type, void *pCont, int contLen, void *thandle, int32_t code) {
tTrace("request is received, type:%d, contLen:%d", type, contLen);
SRpcMsg rpcMsg;
......@@ -91,6 +111,7 @@ int main(int argc, char *argv[]) {
rpcInit.cfp = processRequestMsg;
rpcInit.sessions = 1000;
rpcInit.idleTime = 2000;
rpcInit.afp = retrieveAuthInfo;
for (int i=1; i<argc; ++i) {
if (strcmp(argv[i], "-p")==0 && i < argc-1) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册