diff --git a/src/rpc/src/rpcMain.c b/src/rpc/src/rpcMain.c index 247e21a4cf7bffce568ba6c82c649f3db1fa6f8c..6ceb1f98bb8e524861c4a14633395e2e02853eb3 100755 --- a/src/rpc/src/rpcMain.c +++ b/src/rpc/src/rpcMain.c @@ -394,8 +394,8 @@ void rpcSendResponse(void *handle, int32_t code, void *pCont, int contLen) { // set msg header pHead->version = 1; pHead->msgType = pConn->inType+1; - pHead->spi = 0; - pHead->encrypt = 0; + pHead->spi = pConn->spi; + pHead->encrypt = pConn->encrypt; pHead->tranId = pConn->inTranId; pHead->sourceId = pConn->ownId; pHead->destId = pConn->peerId; @@ -545,11 +545,11 @@ static SRpcConn *rpcAllocateClientConn(SRpcInfo *pRpc) { static SRpcConn *rpcAllocateServerConn(SRpcInfo *pRpc, SRecvInfo *pRecv) { SRpcConn *pConn = NULL; - char hashstr[40]; + char hashstr[40] = {0}; SRpcHead *pHead = (SRpcHead *)pRecv->msg; sprintf(hashstr, "%x:%x:%x:%d", pRecv->ip, pHead->uid, pHead->sourceId, pRecv->connType); - + // check if it is already allocated SRpcConn **ppConn = (SRpcConn **)(taosGetStrHashData(pRpc->hash, hashstr)); if (ppConn) pConn = *ppConn; @@ -567,7 +567,7 @@ static SRpcConn *rpcAllocateServerConn(SRpcInfo *pRpc, SRecvInfo *pRecv) { pConn->sid = sid; pConn->tranId = (uint16_t)(rand() & 0xFFFF); pConn->ownId = htonl(pConn->sid); - if (pRpc->afp && (*pRpc->afp)(pConn->user, &pConn->spi, &pConn->encrypt, pConn->secret, pConn->ckey)) { + if (pRpc->afp && (*pRpc->afp)(pConn->user, &pConn->spi, &pConn->encrypt, pConn->secret, pConn->ckey) < 0) { tWarn("%s %p, user not there", pRpc->label, pConn); taosFreeId(pRpc->idPool, sid); // sid shall be released terrno = TSDB_CODE_INVALID_USER; @@ -720,21 +720,12 @@ static SRpcConn *rpcProcessMsgHead(SRpcInfo *pRpc, SRecvInfo *pRecv) { SRpcHead *pHead = (SRpcHead *)pRecv->msg; sid = htonl(pHead->destId); - pHead->code = htonl(pHead->code); - pHead->msgLen = (int32_t)htonl((uint32_t)pHead->msgLen); - pHead->port = htons(pHead->port); if (pHead->msgType >= TSDB_MSG_TYPE_MAX || pHead->msgType <= 0) { tTrace("%s sid:%d, invalid message type:%d", pRpc->label, sid, pHead->msgType); terrno = TSDB_CODE_INVALID_MSG_TYPE; return NULL; } - if (pRecv->msgLen != pHead->msgLen) { - tTrace("%s sid:%d, %s has invalid length, dataLen:%d, msgLen:%d", pRpc->label, sid, - taosMsg[pHead->msgType], pRecv->msgLen, pHead->msgLen); - terrno = TSDB_CODE_INVALID_MSG_LEN; return NULL; - } - if (sid < 0 || sid >= pRpc->sessions) { tTrace("%s sid:%d, sid is out of range, max sid:%d, %s discarded", pRpc->label, sid, pRpc->sessions, taosMsg[pHead->msgType]); @@ -756,10 +747,14 @@ static SRpcConn *rpcProcessMsgHead(SRpcInfo *pRpc, SRecvInfo *pRecv) { } if (pRecv->port) pConn->peerPort = pRecv->port; - if (pHead->port) pConn->peerPort = pHead->port; + if (pHead->port) pConn->peerPort = htons(pHead->port); if (pHead->uid) pConn->peerUid = pHead->uid; terrno = rpcCheckAuthentication(pConn, (char *)pHead, pRecv->msgLen); + + // code can be transformed only after authentication + pHead->code = htonl(pHead->code); + if (terrno == 0) { if (pHead->msgType != TSDB_MSG_TYPE_REG && pHead->encrypt) { // decrypt here @@ -906,9 +901,9 @@ static void rpcSendErrorMsgToPeer(SRecvInfo *pRecv, int32_t code) { pReplyHead->version = pRecvHead->version; pReplyHead->msgType = (char)(pRecvHead->msgType + 1); pReplyHead->spi = 0; - pReplyHead->encrypt = 0; + pReplyHead->encrypt = pRecvHead->encrypt; pReplyHead->tranId = pRecvHead->tranId; - pReplyHead->sourceId = 0; + pReplyHead->sourceId = pRecvHead->destId; pReplyHead->destId = pRecvHead->sourceId; memcpy(pReplyHead->user, pRecvHead->user, tListLen(pReplyHead->user)); @@ -918,7 +913,7 @@ static void rpcSendErrorMsgToPeer(SRecvInfo *pRecv, int32_t code) { if (code == TSDB_CODE_INVALID_TIME_STAMP) { // include a time stamp if client's time is not synchronized well uint8_t *pContent = pReplyHead->content; - timeStamp = taosGetTimestampSec(); + timeStamp = htonl(taosGetTimestampSec()); memcpy(pContent, &timeStamp, sizeof(timeStamp)); msgLen += sizeof(timeStamp); } @@ -1031,10 +1026,10 @@ static void rpcProcessRetryTimer(void *param, void *tmrId) { pConn->retry++; if (pConn->retry < 4) { - tTrace("%s %p, re-send msg:%s to %s:%hu", pRpc->label, pConn, - taosMsg[pConn->outType], pConn->peerIpstr, pConn->peerPort); + tTrace("%s %p, re-send msg:%s to %s:%hu retry:%d", pRpc->label, pConn, + taosMsg[pConn->outType], pConn->peerIpstr, pConn->peerPort, pConn->retry); rpcSendMsgToPeer(pConn, pConn->pReqMsg, pConn->reqMsgLen); - taosTmrReset(rpcProcessRetryTimer, tsRpcTimer<retry, pConn, pRpc->tmrCtrl, &pConn->pTimer); + taosTmrReset(rpcProcessRetryTimer, tsRpcTimer, pConn, pRpc->tmrCtrl, &pConn->pTimer); } else { // close the connection tTrace("%s %p, failed to send msg:%s to %s:%hu", pRpc->label, pConn, @@ -1083,12 +1078,6 @@ static void rpcProcessProgressTimer(void *param, void *tmrId) { rpcUnlockConn(pConn); } -static void rpcFreeOutMsg(void *msg) { - if ( msg == NULL ) return; - char *req = ((char *)msg) - sizeof(SRpcReqContext); - free(req); -} - static int32_t rpcCompressRpcMsg(char* pCont, int32_t contLen) { SRpcHead *pHead = rpcHeadFromCont(pCont); int32_t finalLen = 0; @@ -1160,14 +1149,14 @@ static SRpcHead *rpcDecompressRpcMsg(SRpcHead *pHead) { return pHead; } -static int rpcAuthenticateMsg(uint8_t *pMsg, int msgLen, uint8_t *pAuth, uint8_t *pKey) { +static int rpcAuthenticateMsg(void *pMsg, int msgLen, void *pAuth, void *pKey) { MD5_CTX context; int ret = -1; MD5Init(&context); - MD5Update(&context, pKey, TSDB_KEY_LEN); - MD5Update(&context, pMsg, msgLen); - MD5Update(&context, pKey, TSDB_KEY_LEN); + MD5Update(&context, (uint8_t *)pKey, TSDB_KEY_LEN); + MD5Update(&context, (uint8_t *)pMsg, msgLen); + MD5Update(&context, (uint8_t *)pKey, TSDB_KEY_LEN); MD5Final(&context); if (memcmp(context.digest, pAuth, sizeof(context.digest)) == 0) ret = 0; @@ -1175,18 +1164,16 @@ static int rpcAuthenticateMsg(uint8_t *pMsg, int msgLen, uint8_t *pAuth, uint8_t return ret; } -static int rpcBuildAuthHead(uint8_t *pMsg, int msgLen, uint8_t *pAuth, uint8_t *pKey) { +static void rpcBuildAuthHead(void *pMsg, int msgLen, void *pAuth, void *pKey) { MD5_CTX context; MD5Init(&context); - MD5Update(&context, pKey, TSDB_KEY_LEN); + MD5Update(&context, (uint8_t *)pKey, TSDB_KEY_LEN); MD5Update(&context, (uint8_t *)pMsg, msgLen); - MD5Update(&context, pKey, TSDB_KEY_LEN); + MD5Update(&context, (uint8_t *)pKey, TSDB_KEY_LEN); MD5Final(&context); memcpy(pAuth, context.digest, sizeof(context.digest)); - - return 0; } static int rpcAddAuthPart(SRpcConn *pConn, char *msg, int msgLen) { @@ -1199,7 +1186,7 @@ static int rpcAddAuthPart(SRpcConn *pConn, char *msg, int msgLen) { pDigest->timeStamp = htonl(taosGetTimestampSec()); msgLen += sizeof(SRpcDigest); pHead->msgLen = (int32_t)htonl((uint32_t)msgLen); - rpcBuildAuthHead((uint8_t *)pHead, msgLen - TSDB_AUTH_LEN, pDigest->auth, (uint8_t *)pConn->secret); + rpcBuildAuthHead(pHead, msgLen - TSDB_AUTH_LEN, pDigest->auth, pConn->secret); } else { pHead->msgLen = (int32_t)htonl((uint32_t)msgLen); } @@ -1212,7 +1199,21 @@ static int rpcCheckAuthentication(SRpcConn *pConn, char *msg, int msgLen) { SRpcInfo *pRpc = pConn->pRpc; int32_t code = 0; - if (pConn->spi == 0 ) return 0; + if (pConn->spi == 0) { + pHead->msgLen = (int32_t)htonl((uint32_t)pHead->msgLen); + return 0; + } + + if ( !rpcIsReq(pHead->msgType) ) { + // for response, if code is auth failure, it shall bypass the auth process + code = htonl(pHead->code); + if (code==TSDB_CODE_INVALID_TIME_STAMP || code==TSDB_CODE_AUTH_FAILURE || code==TSDB_CODE_INVALID_USER) { + pHead->msgLen = (int32_t)htonl((uint32_t)pHead->msgLen); + return 0; + } + } + + code = 0; if (pHead->spi == pConn->spi) { // authentication @@ -1222,23 +1223,19 @@ static int rpcCheckAuthentication(SRpcConn *pConn, char *msg, int msgLen) { delta = (int32_t)htonl(pDigest->timeStamp); delta -= (int32_t)taosGetTimestampSec(); if (abs(delta) > 900) { - tWarn("%s %p, time diff:%d is too big, msg discarded, timestamp:%d", pRpc->label, pConn, - delta, htonl(pDigest->timeStamp)); + tWarn("%s %p, time diff:%d is too big, msg discarded", pRpc->label, pConn, delta); code = TSDB_CODE_INVALID_TIME_STAMP; } else { - if (rpcAuthenticateMsg((uint8_t *)pHead, msgLen - TSDB_AUTH_LEN, pDigest->auth, (uint8_t *)pConn->secret) < 0) { + if (rpcAuthenticateMsg(pHead, msgLen-TSDB_AUTH_LEN, pDigest->auth, pConn->secret) < 0) { tError("%s %p, authentication failed, msg discarded", pRpc->label, pConn); code = TSDB_CODE_AUTH_FAILURE; } else { - pHead->msgLen -= sizeof(SRpcDigest); + pHead->msgLen = (int32_t)htonl((uint32_t)pHead->msgLen) - sizeof(SRpcDigest); } } } else { - // if it is request or response with code 0, msg shall be discarded - if (rpcIsReq(pHead->msgType) || (pHead->content[0] == 0)) { - tTrace("%s %p, auth spi not matched, msg discarded", pRpc->label, pConn); - code = TSDB_CODE_AUTH_FAILURE; - } + tTrace("%s %p, auth spi not matched, msg discarded", pRpc->label, pConn); + code = TSDB_CODE_AUTH_FAILURE; } return code; diff --git a/src/rpc/src/rpcUdp.c b/src/rpc/src/rpcUdp.c index 471b1bcc723b8e1557b1ab158a7570c39bf1f466..64a4df0e735a9454f7e0f7f0030c931054bd40b1 100644 --- a/src/rpc/src/rpcUdp.c +++ b/src/rpc/src/rpcUdp.c @@ -285,9 +285,9 @@ int taosSendUdpData(uint32_t ip, uint16_t port, void *data, int dataLen, void *c destAdd.sin_addr.s_addr = ip; destAdd.sin_port = htons(port); + //tTrace("%s msg is sent to 0x%x:%hu len:%d ret:%d localPort:%hu chandle:0x%x", pConn->label, destAdd.sin_addr.s_addr, + // port, dataLen, ret, pConn->localPort, chandle); int ret = (int)sendto(pConn->fd, data, (size_t)dataLen, 0, (struct sockaddr *)&destAdd, sizeof(destAdd)); - tTrace("%s msg is sent to 0x%x:%hu len:%d ret:%d localPort:%hu chandle:0x%x", pConn->label, destAdd.sin_addr.s_addr, - port, dataLen, ret, pConn->localPort, chandle); return ret; } diff --git a/src/rpc/test/rclient.c b/src/rpc/test/rclient.c index 181f8a8475f4eb42bf08decfb6d6edfeca9c3897..63c23ce7bc755ef8e492d1b7ada6b49351dfbc2f 100644 --- a/src/rpc/test/rclient.c +++ b/src/rpc/test/rclient.c @@ -110,6 +110,7 @@ int main(int argc, char *argv[]) { rpcInit.user = "michael"; rpcInit.secret = "mypassword"; rpcInit.ckey = "key"; + rpcInit.spi = 1; for (int i=1; i