提交 5622830c 编写于 作者: D Daniel P. Berrange 提交者: Daniel Veillard

Add mutex protection to SASL and TLS modules

The virNetSASLContext, virNetSASLSession, virNetTLSContext and
virNetTLSSession classes previously relied in their owners
(virNetClient / virNetServer / virNetServerClient) to provide
locking protection for concurrent usage. When virNetSocket
gained its own locking code, this invalidated the implicit
safety the SASL/TLS modules relied on. Thus we need to give
them all explicit locking of their own via new mutexes.

* src/rpc/virnetsaslcontext.c, src/rpc/virnettlscontext.c: Add
  a mutex per object
上级 a4458597
...@@ -27,6 +27,7 @@ ...@@ -27,6 +27,7 @@
#include "virterror_internal.h" #include "virterror_internal.h"
#include "memory.h" #include "memory.h"
#include "threads.h"
#include "logging.h" #include "logging.h"
#define VIR_FROM_THIS VIR_FROM_RPC #define VIR_FROM_THIS VIR_FROM_RPC
...@@ -36,11 +37,13 @@ ...@@ -36,11 +37,13 @@
struct _virNetSASLContext { struct _virNetSASLContext {
virMutex lock;
const char *const*usernameWhitelist; const char *const*usernameWhitelist;
int refs; int refs;
}; };
struct _virNetSASLSession { struct _virNetSASLSession {
virMutex lock;
sasl_conn_t *conn; sasl_conn_t *conn;
int refs; int refs;
size_t maxbufsize; size_t maxbufsize;
...@@ -65,6 +68,13 @@ virNetSASLContextPtr virNetSASLContextNewClient(void) ...@@ -65,6 +68,13 @@ virNetSASLContextPtr virNetSASLContextNewClient(void)
return NULL; return NULL;
} }
if (virMutexInit(&ctxt->lock) < 0) {
virNetError(VIR_ERR_INTERNAL_ERROR, "%s",
_("Failed to initialized mutex"));
VIR_FREE(ctxt);
return NULL;
}
ctxt->refs = 1; ctxt->refs = 1;
return ctxt; return ctxt;
...@@ -88,6 +98,13 @@ virNetSASLContextPtr virNetSASLContextNewServer(const char *const*usernameWhitel ...@@ -88,6 +98,13 @@ virNetSASLContextPtr virNetSASLContextNewServer(const char *const*usernameWhitel
return NULL; return NULL;
} }
if (virMutexInit(&ctxt->lock) < 0) {
virNetError(VIR_ERR_INTERNAL_ERROR, "%s",
_("Failed to initialized mutex"));
VIR_FREE(ctxt);
return NULL;
}
ctxt->usernameWhitelist = usernameWhitelist; ctxt->usernameWhitelist = usernameWhitelist;
ctxt->refs = 1; ctxt->refs = 1;
...@@ -98,21 +115,28 @@ int virNetSASLContextCheckIdentity(virNetSASLContextPtr ctxt, ...@@ -98,21 +115,28 @@ int virNetSASLContextCheckIdentity(virNetSASLContextPtr ctxt,
const char *identity) const char *identity)
{ {
const char *const*wildcards; const char *const*wildcards;
int ret = -1;
virMutexLock(&ctxt->lock);
/* If the list is not set, allow any DN. */ /* If the list is not set, allow any DN. */
wildcards = ctxt->usernameWhitelist; wildcards = ctxt->usernameWhitelist;
if (!wildcards) if (!wildcards) {
return 1; /* No ACL, allow all */ ret = 1; /* No ACL, allow all */
goto cleanup;
}
while (*wildcards) { while (*wildcards) {
int ret = fnmatch (*wildcards, identity, 0); int rv = fnmatch (*wildcards, identity, 0);
if (ret == 0) /* Succesful match */ if (rv == 0) {
return 1; ret = 1;
goto cleanup; /* Succesful match */
}
if (ret != FNM_NOMATCH) { if (ret != FNM_NOMATCH) {
virNetError(VIR_ERR_INTERNAL_ERROR, virNetError(VIR_ERR_INTERNAL_ERROR,
_("Malformed TLS whitelist regular expression '%s'"), _("Malformed TLS whitelist regular expression '%s'"),
*wildcards); *wildcards);
return -1; goto cleanup;
} }
wildcards++; wildcards++;
...@@ -124,13 +148,19 @@ int virNetSASLContextCheckIdentity(virNetSASLContextPtr ctxt, ...@@ -124,13 +148,19 @@ int virNetSASLContextCheckIdentity(virNetSASLContextPtr ctxt,
/* This is the most common error: make it informative. */ /* This is the most common error: make it informative. */
virNetError(VIR_ERR_SYSTEM_ERROR, "%s", virNetError(VIR_ERR_SYSTEM_ERROR, "%s",
_("Client's username is not on the list of allowed clients")); _("Client's username is not on the list of allowed clients"));
return 0; ret = 0;
cleanup:
virMutexUnlock(&ctxt->lock);
return ret;
} }
void virNetSASLContextRef(virNetSASLContextPtr ctxt) void virNetSASLContextRef(virNetSASLContextPtr ctxt)
{ {
virMutexLock(&ctxt->lock);
ctxt->refs++; ctxt->refs++;
virMutexUnlock(&ctxt->lock);
} }
void virNetSASLContextFree(virNetSASLContextPtr ctxt) void virNetSASLContextFree(virNetSASLContextPtr ctxt)
...@@ -138,10 +168,15 @@ void virNetSASLContextFree(virNetSASLContextPtr ctxt) ...@@ -138,10 +168,15 @@ void virNetSASLContextFree(virNetSASLContextPtr ctxt)
if (!ctxt) if (!ctxt)
return; return;
virMutexLock(&ctxt->lock);
ctxt->refs--; ctxt->refs--;
if (ctxt->refs > 0) if (ctxt->refs > 0) {
virMutexUnlock(&ctxt->lock);
return; return;
}
virMutexUnlock(&ctxt->lock);
virMutexDestroy(&ctxt->lock);
VIR_FREE(ctxt); VIR_FREE(ctxt);
} }
...@@ -160,6 +195,13 @@ virNetSASLSessionPtr virNetSASLSessionNewClient(virNetSASLContextPtr ctxt ATTRIB ...@@ -160,6 +195,13 @@ virNetSASLSessionPtr virNetSASLSessionNewClient(virNetSASLContextPtr ctxt ATTRIB
goto cleanup; goto cleanup;
} }
if (virMutexInit(&sasl->lock) < 0) {
virNetError(VIR_ERR_INTERNAL_ERROR, "%s",
_("Failed to initialized mutex"));
VIR_FREE(sasl);
return NULL;
}
sasl->refs = 1; sasl->refs = 1;
/* Arbitrary size for amount of data we can encode in a single block */ /* Arbitrary size for amount of data we can encode in a single block */
sasl->maxbufsize = 1 << 16; sasl->maxbufsize = 1 << 16;
...@@ -198,6 +240,13 @@ virNetSASLSessionPtr virNetSASLSessionNewServer(virNetSASLContextPtr ctxt ATTRIB ...@@ -198,6 +240,13 @@ virNetSASLSessionPtr virNetSASLSessionNewServer(virNetSASLContextPtr ctxt ATTRIB
goto cleanup; goto cleanup;
} }
if (virMutexInit(&sasl->lock) < 0) {
virNetError(VIR_ERR_INTERNAL_ERROR, "%s",
_("Failed to initialized mutex"));
VIR_FREE(sasl);
return NULL;
}
sasl->refs = 1; sasl->refs = 1;
/* Arbitrary size for amount of data we can encode in a single block */ /* Arbitrary size for amount of data we can encode in a single block */
sasl->maxbufsize = 1 << 16; sasl->maxbufsize = 1 << 16;
...@@ -226,43 +275,56 @@ cleanup: ...@@ -226,43 +275,56 @@ cleanup:
void virNetSASLSessionRef(virNetSASLSessionPtr sasl) void virNetSASLSessionRef(virNetSASLSessionPtr sasl)
{ {
virMutexLock(&sasl->lock);
sasl->refs++; sasl->refs++;
virMutexUnlock(&sasl->lock);
} }
int virNetSASLSessionExtKeySize(virNetSASLSessionPtr sasl, int virNetSASLSessionExtKeySize(virNetSASLSessionPtr sasl,
int ssf) int ssf)
{ {
int err; int err;
int ret = -1;
virMutexLock(&sasl->lock);
err = sasl_setprop(sasl->conn, SASL_SSF_EXTERNAL, &ssf); err = sasl_setprop(sasl->conn, SASL_SSF_EXTERNAL, &ssf);
if (err != SASL_OK) { if (err != SASL_OK) {
virNetError(VIR_ERR_INTERNAL_ERROR, virNetError(VIR_ERR_INTERNAL_ERROR,
_("cannot set external SSF %d (%s)"), _("cannot set external SSF %d (%s)"),
err, sasl_errstring(err, NULL, NULL)); err, sasl_errstring(err, NULL, NULL));
return -1; goto cleanup;
} }
return 0;
ret = 0;
cleanup:
virMutexLock(&sasl->lock);
return ret;
} }
const char *virNetSASLSessionGetIdentity(virNetSASLSessionPtr sasl) const char *virNetSASLSessionGetIdentity(virNetSASLSessionPtr sasl)
{ {
const void *val; const void *val = NULL;
int err; int err;
virMutexLock(&sasl->lock);
err = sasl_getprop(sasl->conn, SASL_USERNAME, &val); err = sasl_getprop(sasl->conn, SASL_USERNAME, &val);
if (err != SASL_OK) { if (err != SASL_OK) {
virNetError(VIR_ERR_AUTH_FAILED, virNetError(VIR_ERR_AUTH_FAILED,
_("cannot query SASL username on connection %d (%s)"), _("cannot query SASL username on connection %d (%s)"),
err, sasl_errstring(err, NULL, NULL)); err, sasl_errstring(err, NULL, NULL));
return NULL; val = NULL;
goto cleanup;
} }
if (val == NULL) { if (val == NULL) {
virNetError(VIR_ERR_AUTH_FAILED, virNetError(VIR_ERR_AUTH_FAILED,
_("no client username was found")); _("no client username was found"));
return NULL; goto cleanup;
} }
VIR_DEBUG("SASL client username %s", (const char *)val); VIR_DEBUG("SASL client username %s", (const char *)val);
cleanup:
virMutexUnlock(&sasl->lock);
return (const char*)val; return (const char*)val;
} }
...@@ -272,14 +334,20 @@ int virNetSASLSessionGetKeySize(virNetSASLSessionPtr sasl) ...@@ -272,14 +334,20 @@ int virNetSASLSessionGetKeySize(virNetSASLSessionPtr sasl)
int err; int err;
int ssf; int ssf;
const void *val; const void *val;
virMutexLock(&sasl->lock);
err = sasl_getprop(sasl->conn, SASL_SSF, &val); err = sasl_getprop(sasl->conn, SASL_SSF, &val);
if (err != SASL_OK) { if (err != SASL_OK) {
virNetError(VIR_ERR_AUTH_FAILED, virNetError(VIR_ERR_AUTH_FAILED,
_("cannot query SASL ssf on connection %d (%s)"), _("cannot query SASL ssf on connection %d (%s)"),
err, sasl_errstring(err, NULL, NULL)); err, sasl_errstring(err, NULL, NULL));
return -1; ssf = -1;
goto cleanup;
} }
ssf = *(const int *)val; ssf = *(const int *)val;
cleanup:
virMutexUnlock(&sasl->lock);
return ssf; return ssf;
} }
...@@ -290,10 +358,12 @@ int virNetSASLSessionSecProps(virNetSASLSessionPtr sasl, ...@@ -290,10 +358,12 @@ int virNetSASLSessionSecProps(virNetSASLSessionPtr sasl,
{ {
sasl_security_properties_t secprops; sasl_security_properties_t secprops;
int err; int err;
int ret = -1;
VIR_DEBUG("minSSF=%d maxSSF=%d allowAnonymous=%d maxbufsize=%zu", VIR_DEBUG("minSSF=%d maxSSF=%d allowAnonymous=%d maxbufsize=%zu",
minSSF, maxSSF, allowAnonymous, sasl->maxbufsize); minSSF, maxSSF, allowAnonymous, sasl->maxbufsize);
virMutexLock(&sasl->lock);
memset(&secprops, 0, sizeof secprops); memset(&secprops, 0, sizeof secprops);
secprops.min_ssf = minSSF; secprops.min_ssf = minSSF;
...@@ -307,10 +377,14 @@ int virNetSASLSessionSecProps(virNetSASLSessionPtr sasl, ...@@ -307,10 +377,14 @@ int virNetSASLSessionSecProps(virNetSASLSessionPtr sasl,
virNetError(VIR_ERR_INTERNAL_ERROR, virNetError(VIR_ERR_INTERNAL_ERROR,
_("cannot set security props %d (%s)"), _("cannot set security props %d (%s)"),
err, sasl_errstring(err, NULL, NULL)); err, sasl_errstring(err, NULL, NULL));
return -1; goto cleanup;
} }
return 0; ret = 0;
cleanup:
virMutexUnlock(&sasl->lock);
return ret;
} }
...@@ -336,9 +410,10 @@ static int virNetSASLSessionUpdateBufSize(virNetSASLSessionPtr sasl) ...@@ -336,9 +410,10 @@ static int virNetSASLSessionUpdateBufSize(virNetSASLSessionPtr sasl)
char *virNetSASLSessionListMechanisms(virNetSASLSessionPtr sasl) char *virNetSASLSessionListMechanisms(virNetSASLSessionPtr sasl)
{ {
const char *mechlist; const char *mechlist;
char *ret; char *ret = NULL;
int err; int err;
virMutexLock(&sasl->lock);
err = sasl_listmech(sasl->conn, err = sasl_listmech(sasl->conn,
NULL, /* Don't need to set user */ NULL, /* Don't need to set user */
"", /* Prefix */ "", /* Prefix */
...@@ -351,12 +426,15 @@ char *virNetSASLSessionListMechanisms(virNetSASLSessionPtr sasl) ...@@ -351,12 +426,15 @@ char *virNetSASLSessionListMechanisms(virNetSASLSessionPtr sasl)
virNetError(VIR_ERR_INTERNAL_ERROR, virNetError(VIR_ERR_INTERNAL_ERROR,
_("cannot list SASL mechanisms %d (%s)"), _("cannot list SASL mechanisms %d (%s)"),
err, sasl_errdetail(sasl->conn)); err, sasl_errdetail(sasl->conn));
return NULL; goto cleanup;
} }
if (!(ret = strdup(mechlist))) { if (!(ret = strdup(mechlist))) {
virReportOOMError(); virReportOOMError();
return NULL; goto cleanup;
} }
cleanup:
virMutexUnlock(&sasl->lock);
return ret; return ret;
} }
...@@ -369,35 +447,44 @@ int virNetSASLSessionClientStart(virNetSASLSessionPtr sasl, ...@@ -369,35 +447,44 @@ int virNetSASLSessionClientStart(virNetSASLSessionPtr sasl,
const char **mech) const char **mech)
{ {
unsigned outlen = 0; unsigned outlen = 0;
int err;
int ret = -1;
VIR_DEBUG("sasl=%p mechlist=%s prompt_need=%p clientout=%p clientoutlen=%p mech=%p", VIR_DEBUG("sasl=%p mechlist=%s prompt_need=%p clientout=%p clientoutlen=%p mech=%p",
sasl, mechlist, prompt_need, clientout, clientoutlen, mech); sasl, mechlist, prompt_need, clientout, clientoutlen, mech);
int err = sasl_client_start(sasl->conn, virMutexLock(&sasl->lock);
mechlist, err = sasl_client_start(sasl->conn,
prompt_need, mechlist,
clientout, prompt_need,
&outlen, clientout,
mech); &outlen,
mech);
*clientoutlen = outlen; *clientoutlen = outlen;
switch (err) { switch (err) {
case SASL_OK: case SASL_OK:
if (virNetSASLSessionUpdateBufSize(sasl) < 0) if (virNetSASLSessionUpdateBufSize(sasl) < 0)
return -1; goto cleanup;
return VIR_NET_SASL_COMPLETE; ret = VIR_NET_SASL_COMPLETE;
break;
case SASL_CONTINUE: case SASL_CONTINUE:
return VIR_NET_SASL_CONTINUE; ret = VIR_NET_SASL_CONTINUE;
break;
case SASL_INTERACT: case SASL_INTERACT:
return VIR_NET_SASL_INTERACT; ret = VIR_NET_SASL_INTERACT;
break;
default: default:
virNetError(VIR_ERR_AUTH_FAILED, virNetError(VIR_ERR_AUTH_FAILED,
_("Failed to start SASL negotiation: %d (%s)"), _("Failed to start SASL negotiation: %d (%s)"),
err, sasl_errdetail(sasl->conn)); err, sasl_errdetail(sasl->conn));
return -1; break;
} }
cleanup:
virMutexUnlock(&sasl->lock);
return ret;
} }
...@@ -410,34 +497,43 @@ int virNetSASLSessionClientStep(virNetSASLSessionPtr sasl, ...@@ -410,34 +497,43 @@ int virNetSASLSessionClientStep(virNetSASLSessionPtr sasl,
{ {
unsigned inlen = serverinlen; unsigned inlen = serverinlen;
unsigned outlen = 0; unsigned outlen = 0;
int err;
int ret = -1;
VIR_DEBUG("sasl=%p serverin=%s serverinlen=%zu prompt_need=%p clientout=%p clientoutlen=%p", VIR_DEBUG("sasl=%p serverin=%s serverinlen=%zu prompt_need=%p clientout=%p clientoutlen=%p",
sasl, serverin, serverinlen, prompt_need, clientout, clientoutlen); sasl, serverin, serverinlen, prompt_need, clientout, clientoutlen);
int err = sasl_client_step(sasl->conn, virMutexLock(&sasl->lock);
serverin, err = sasl_client_step(sasl->conn,
inlen, serverin,
prompt_need, inlen,
clientout, prompt_need,
&outlen); clientout,
&outlen);
*clientoutlen = outlen; *clientoutlen = outlen;
switch (err) { switch (err) {
case SASL_OK: case SASL_OK:
if (virNetSASLSessionUpdateBufSize(sasl) < 0) if (virNetSASLSessionUpdateBufSize(sasl) < 0)
return -1; goto cleanup;
return VIR_NET_SASL_COMPLETE; ret = VIR_NET_SASL_COMPLETE;
break;
case SASL_CONTINUE: case SASL_CONTINUE:
return VIR_NET_SASL_CONTINUE; ret = VIR_NET_SASL_CONTINUE;
break;
case SASL_INTERACT: case SASL_INTERACT:
return VIR_NET_SASL_INTERACT; ret = VIR_NET_SASL_INTERACT;
break;
default: default:
virNetError(VIR_ERR_AUTH_FAILED, virNetError(VIR_ERR_AUTH_FAILED,
_("Failed to step SASL negotiation: %d (%s)"), _("Failed to step SASL negotiation: %d (%s)"),
err, sasl_errdetail(sasl->conn)); err, sasl_errdetail(sasl->conn));
return -1; break;
} }
cleanup:
virMutexUnlock(&sasl->lock);
return ret;
} }
int virNetSASLSessionServerStart(virNetSASLSessionPtr sasl, int virNetSASLSessionServerStart(virNetSASLSessionPtr sasl,
...@@ -449,31 +545,41 @@ int virNetSASLSessionServerStart(virNetSASLSessionPtr sasl, ...@@ -449,31 +545,41 @@ int virNetSASLSessionServerStart(virNetSASLSessionPtr sasl,
{ {
unsigned inlen = clientinlen; unsigned inlen = clientinlen;
unsigned outlen = 0; unsigned outlen = 0;
int err = sasl_server_start(sasl->conn, int err;
mechname, int ret = -1;
clientin,
inlen, virMutexLock(&sasl->lock);
serverout, err = sasl_server_start(sasl->conn,
&outlen); mechname,
clientin,
inlen,
serverout,
&outlen);
*serveroutlen = outlen; *serveroutlen = outlen;
switch (err) { switch (err) {
case SASL_OK: case SASL_OK:
if (virNetSASLSessionUpdateBufSize(sasl) < 0) if (virNetSASLSessionUpdateBufSize(sasl) < 0)
return -1; goto cleanup;
return VIR_NET_SASL_COMPLETE; ret = VIR_NET_SASL_COMPLETE;
break;
case SASL_CONTINUE: case SASL_CONTINUE:
return VIR_NET_SASL_CONTINUE; ret = VIR_NET_SASL_CONTINUE;
break;
case SASL_INTERACT: case SASL_INTERACT:
return VIR_NET_SASL_INTERACT; ret = VIR_NET_SASL_INTERACT;
break;
default: default:
virNetError(VIR_ERR_AUTH_FAILED, virNetError(VIR_ERR_AUTH_FAILED,
_("Failed to start SASL negotiation: %d (%s)"), _("Failed to start SASL negotiation: %d (%s)"),
err, sasl_errdetail(sasl->conn)); err, sasl_errdetail(sasl->conn));
return -1; break;
} }
cleanup:
virMutexUnlock(&sasl->lock);
return ret;
} }
...@@ -485,36 +591,49 @@ int virNetSASLSessionServerStep(virNetSASLSessionPtr sasl, ...@@ -485,36 +591,49 @@ int virNetSASLSessionServerStep(virNetSASLSessionPtr sasl,
{ {
unsigned inlen = clientinlen; unsigned inlen = clientinlen;
unsigned outlen = 0; unsigned outlen = 0;
int err;
int ret = -1;
int err = sasl_server_step(sasl->conn, virMutexLock(&sasl->lock);
clientin, err = sasl_server_step(sasl->conn,
inlen, clientin,
serverout, inlen,
&outlen); serverout,
&outlen);
*serveroutlen = outlen; *serveroutlen = outlen;
switch (err) { switch (err) {
case SASL_OK: case SASL_OK:
if (virNetSASLSessionUpdateBufSize(sasl) < 0) if (virNetSASLSessionUpdateBufSize(sasl) < 0)
return -1; goto cleanup;
return VIR_NET_SASL_COMPLETE; ret = VIR_NET_SASL_COMPLETE;
break;
case SASL_CONTINUE: case SASL_CONTINUE:
return VIR_NET_SASL_CONTINUE; ret = VIR_NET_SASL_CONTINUE;
break;
case SASL_INTERACT: case SASL_INTERACT:
return VIR_NET_SASL_INTERACT; ret = VIR_NET_SASL_INTERACT;
break;
default: default:
virNetError(VIR_ERR_AUTH_FAILED, virNetError(VIR_ERR_AUTH_FAILED,
_("Failed to start SASL negotiation: %d (%s)"), _("Failed to start SASL negotiation: %d (%s)"),
err, sasl_errdetail(sasl->conn)); err, sasl_errdetail(sasl->conn));
return -1; break;
} }
cleanup:
virMutexUnlock(&sasl->lock);
return ret;
} }
size_t virNetSASLSessionGetMaxBufSize(virNetSASLSessionPtr sasl) size_t virNetSASLSessionGetMaxBufSize(virNetSASLSessionPtr sasl)
{ {
return sasl->maxbufsize; size_t ret;
virMutexLock(&sasl->lock);
ret = sasl->maxbufsize;
virMutexUnlock(&sasl->lock);
return ret;
} }
ssize_t virNetSASLSessionEncode(virNetSASLSessionPtr sasl, ssize_t virNetSASLSessionEncode(virNetSASLSessionPtr sasl,
...@@ -526,12 +645,14 @@ ssize_t virNetSASLSessionEncode(virNetSASLSessionPtr sasl, ...@@ -526,12 +645,14 @@ ssize_t virNetSASLSessionEncode(virNetSASLSessionPtr sasl,
unsigned inlen = inputLen; unsigned inlen = inputLen;
unsigned outlen = 0; unsigned outlen = 0;
int err; int err;
ssize_t ret = -1;
virMutexLock(&sasl->lock);
if (inputLen > sasl->maxbufsize) { if (inputLen > sasl->maxbufsize) {
virReportSystemError(EINVAL, virReportSystemError(EINVAL,
_("SASL data length %zu too long, max %zu"), _("SASL data length %zu too long, max %zu"),
inputLen, sasl->maxbufsize); inputLen, sasl->maxbufsize);
return -1; goto cleanup;
} }
err = sasl_encode(sasl->conn, err = sasl_encode(sasl->conn,
...@@ -545,9 +666,13 @@ ssize_t virNetSASLSessionEncode(virNetSASLSessionPtr sasl, ...@@ -545,9 +666,13 @@ ssize_t virNetSASLSessionEncode(virNetSASLSessionPtr sasl,
virNetError(VIR_ERR_INTERNAL_ERROR, virNetError(VIR_ERR_INTERNAL_ERROR,
_("failed to encode SASL data: %d (%s)"), _("failed to encode SASL data: %d (%s)"),
err, sasl_errstring(err, NULL, NULL)); err, sasl_errstring(err, NULL, NULL));
return -1; goto cleanup;
} }
return 0; ret = 0;
cleanup:
virMutexUnlock(&sasl->lock);
return ret;
} }
ssize_t virNetSASLSessionDecode(virNetSASLSessionPtr sasl, ssize_t virNetSASLSessionDecode(virNetSASLSessionPtr sasl,
...@@ -559,12 +684,14 @@ ssize_t virNetSASLSessionDecode(virNetSASLSessionPtr sasl, ...@@ -559,12 +684,14 @@ ssize_t virNetSASLSessionDecode(virNetSASLSessionPtr sasl,
unsigned inlen = inputLen; unsigned inlen = inputLen;
unsigned outlen = 0; unsigned outlen = 0;
int err; int err;
ssize_t ret = -1;
virMutexLock(&sasl->lock);
if (inputLen > sasl->maxbufsize) { if (inputLen > sasl->maxbufsize) {
virReportSystemError(EINVAL, virReportSystemError(EINVAL,
_("SASL data length %zu too long, max %zu"), _("SASL data length %zu too long, max %zu"),
inputLen, sasl->maxbufsize); inputLen, sasl->maxbufsize);
return -1; goto cleanup;
} }
err = sasl_decode(sasl->conn, err = sasl_decode(sasl->conn,
...@@ -577,9 +704,13 @@ ssize_t virNetSASLSessionDecode(virNetSASLSessionPtr sasl, ...@@ -577,9 +704,13 @@ ssize_t virNetSASLSessionDecode(virNetSASLSessionPtr sasl,
virNetError(VIR_ERR_INTERNAL_ERROR, virNetError(VIR_ERR_INTERNAL_ERROR,
_("failed to decode SASL data: %d (%s)"), _("failed to decode SASL data: %d (%s)"),
err, sasl_errstring(err, NULL, NULL)); err, sasl_errstring(err, NULL, NULL));
return -1; goto cleanup;
} }
return 0; ret = 0;
cleanup:
virMutexUnlock(&sasl->lock);
return ret;
} }
void virNetSASLSessionFree(virNetSASLSessionPtr sasl) void virNetSASLSessionFree(virNetSASLSessionPtr sasl)
...@@ -587,12 +718,17 @@ void virNetSASLSessionFree(virNetSASLSessionPtr sasl) ...@@ -587,12 +718,17 @@ void virNetSASLSessionFree(virNetSASLSessionPtr sasl)
if (!sasl) if (!sasl)
return; return;
virMutexLock(&sasl->lock);
sasl->refs--; sasl->refs--;
if (sasl->refs > 0) if (sasl->refs > 0) {
virMutexUnlock(&sasl->lock);
return; return;
}
if (sasl->conn) if (sasl->conn)
sasl_dispose(&sasl->conn); sasl_dispose(&sasl->conn);
virMutexUnlock(&sasl->lock);
virMutexDestroy(&sasl->lock);
VIR_FREE(sasl); VIR_FREE(sasl);
} }
...@@ -34,6 +34,7 @@ ...@@ -34,6 +34,7 @@
#include "virterror_internal.h" #include "virterror_internal.h"
#include "util.h" #include "util.h"
#include "logging.h" #include "logging.h"
#include "threads.h"
#include "configmake.h" #include "configmake.h"
#define DH_BITS 1024 #define DH_BITS 1024
...@@ -52,6 +53,7 @@ ...@@ -52,6 +53,7 @@
__FUNCTION__, __LINE__, __VA_ARGS__) __FUNCTION__, __LINE__, __VA_ARGS__)
struct _virNetTLSContext { struct _virNetTLSContext {
virMutex lock;
int refs; int refs;
gnutls_certificate_credentials_t x509cred; gnutls_certificate_credentials_t x509cred;
...@@ -63,6 +65,8 @@ struct _virNetTLSContext { ...@@ -63,6 +65,8 @@ struct _virNetTLSContext {
}; };
struct _virNetTLSSession { struct _virNetTLSSession {
virMutex lock;
int refs; int refs;
bool handshakeComplete; bool handshakeComplete;
...@@ -653,6 +657,13 @@ static virNetTLSContextPtr virNetTLSContextNew(const char *cacert, ...@@ -653,6 +657,13 @@ static virNetTLSContextPtr virNetTLSContextNew(const char *cacert,
return NULL; return NULL;
} }
if (virMutexInit(&ctxt->lock) < 0) {
virNetError(VIR_ERR_INTERNAL_ERROR, "%s",
_("Failed to initialized mutex"));
VIR_FREE(ctxt);
return NULL;
}
ctxt->refs = 1; ctxt->refs = 1;
/* Initialise GnuTLS. */ /* Initialise GnuTLS. */
...@@ -1053,18 +1064,29 @@ authfail: ...@@ -1053,18 +1064,29 @@ authfail:
int virNetTLSContextCheckCertificate(virNetTLSContextPtr ctxt, int virNetTLSContextCheckCertificate(virNetTLSContextPtr ctxt,
virNetTLSSessionPtr sess) virNetTLSSessionPtr sess)
{ {
int ret = -1;
virMutexLock(&ctxt->lock);
virMutexLock(&sess->lock);
if (virNetTLSContextValidCertificate(ctxt, sess) < 0) { if (virNetTLSContextValidCertificate(ctxt, sess) < 0) {
virErrorPtr err = virGetLastError(); virErrorPtr err = virGetLastError();
VIR_WARN("Certificate check failed %s", err && err->message ? err->message : "<unknown>"); VIR_WARN("Certificate check failed %s", err && err->message ? err->message : "<unknown>");
if (ctxt->requireValidCert) { if (ctxt->requireValidCert) {
virNetError(VIR_ERR_AUTH_FAILED, "%s", virNetError(VIR_ERR_AUTH_FAILED, "%s",
_("Failed to verify peer's certificate")); _("Failed to verify peer's certificate"));
return -1; goto cleanup;
} }
virResetLastError(); virResetLastError();
VIR_INFO("Ignoring bad certificate at user request"); VIR_INFO("Ignoring bad certificate at user request");
} }
return 0;
ret = 0;
cleanup:
virMutexUnlock(&ctxt->lock);
virMutexUnlock(&sess->lock);
return ret;
} }
void virNetTLSContextFree(virNetTLSContextPtr ctxt) void virNetTLSContextFree(virNetTLSContextPtr ctxt)
...@@ -1072,12 +1094,17 @@ void virNetTLSContextFree(virNetTLSContextPtr ctxt) ...@@ -1072,12 +1094,17 @@ void virNetTLSContextFree(virNetTLSContextPtr ctxt)
if (!ctxt) if (!ctxt)
return; return;
virMutexLock(&ctxt->lock);
ctxt->refs--; ctxt->refs--;
if (ctxt->refs > 0) if (ctxt->refs > 0) {
virMutexUnlock(&ctxt->lock);
return; return;
}
gnutls_dh_params_deinit(ctxt->dhParams); gnutls_dh_params_deinit(ctxt->dhParams);
gnutls_certificate_free_credentials(ctxt->x509cred); gnutls_certificate_free_credentials(ctxt->x509cred);
virMutexUnlock(&ctxt->lock);
virMutexDestroy(&ctxt->lock);
VIR_FREE(ctxt); VIR_FREE(ctxt);
} }
...@@ -1124,6 +1151,13 @@ virNetTLSSessionPtr virNetTLSSessionNew(virNetTLSContextPtr ctxt, ...@@ -1124,6 +1151,13 @@ virNetTLSSessionPtr virNetTLSSessionNew(virNetTLSContextPtr ctxt,
return NULL; return NULL;
} }
if (virMutexInit(&sess->lock) < 0) {
virNetError(VIR_ERR_INTERNAL_ERROR, "%s",
_("Failed to initialized mutex"));
VIR_FREE(ctxt);
return NULL;
}
sess->refs = 1; sess->refs = 1;
if (hostname && if (hostname &&
!(sess->hostname = strdup(hostname))) { !(sess->hostname = strdup(hostname))) {
...@@ -1184,7 +1218,9 @@ error: ...@@ -1184,7 +1218,9 @@ error:
void virNetTLSSessionRef(virNetTLSSessionPtr sess) void virNetTLSSessionRef(virNetTLSSessionPtr sess)
{ {
virMutexLock(&sess->lock);
sess->refs++; sess->refs++;
virMutexUnlock(&sess->lock);
} }
void virNetTLSSessionSetIOCallbacks(virNetTLSSessionPtr sess, void virNetTLSSessionSetIOCallbacks(virNetTLSSessionPtr sess,
...@@ -1192,9 +1228,11 @@ void virNetTLSSessionSetIOCallbacks(virNetTLSSessionPtr sess, ...@@ -1192,9 +1228,11 @@ void virNetTLSSessionSetIOCallbacks(virNetTLSSessionPtr sess,
virNetTLSSessionReadFunc readFunc, virNetTLSSessionReadFunc readFunc,
void *opaque) void *opaque)
{ {
virMutexLock(&sess->lock);
sess->writeFunc = writeFunc; sess->writeFunc = writeFunc;
sess->readFunc = readFunc; sess->readFunc = readFunc;
sess->opaque = opaque; sess->opaque = opaque;
virMutexUnlock(&sess->lock);
} }
...@@ -1202,10 +1240,12 @@ ssize_t virNetTLSSessionWrite(virNetTLSSessionPtr sess, ...@@ -1202,10 +1240,12 @@ ssize_t virNetTLSSessionWrite(virNetTLSSessionPtr sess,
const char *buf, size_t len) const char *buf, size_t len)
{ {
ssize_t ret; ssize_t ret;
virMutexLock(&sess->lock);
ret = gnutls_record_send(sess->session, buf, len); ret = gnutls_record_send(sess->session, buf, len);
if (ret >= 0) if (ret >= 0)
return ret; goto cleanup;
switch (ret) { switch (ret) {
case GNUTLS_E_AGAIN: case GNUTLS_E_AGAIN:
...@@ -1222,7 +1262,11 @@ ssize_t virNetTLSSessionWrite(virNetTLSSessionPtr sess, ...@@ -1222,7 +1262,11 @@ ssize_t virNetTLSSessionWrite(virNetTLSSessionPtr sess,
break; break;
} }
return -1; ret = -1;
cleanup:
virMutexUnlock(&sess->lock);
return ret;
} }
ssize_t virNetTLSSessionRead(virNetTLSSessionPtr sess, ssize_t virNetTLSSessionRead(virNetTLSSessionPtr sess,
...@@ -1230,10 +1274,11 @@ ssize_t virNetTLSSessionRead(virNetTLSSessionPtr sess, ...@@ -1230,10 +1274,11 @@ ssize_t virNetTLSSessionRead(virNetTLSSessionPtr sess,
{ {
ssize_t ret; ssize_t ret;
virMutexLock(&sess->lock);
ret = gnutls_record_recv(sess->session, buf, len); ret = gnutls_record_recv(sess->session, buf, len);
if (ret >= 0) if (ret >= 0)
return ret; goto cleanup;
switch (ret) { switch (ret) {
case GNUTLS_E_AGAIN: case GNUTLS_E_AGAIN:
...@@ -1247,21 +1292,29 @@ ssize_t virNetTLSSessionRead(virNetTLSSessionPtr sess, ...@@ -1247,21 +1292,29 @@ ssize_t virNetTLSSessionRead(virNetTLSSessionPtr sess,
break; break;
} }
return -1; ret = -1;
cleanup:
virMutexUnlock(&sess->lock);
return ret;
} }
int virNetTLSSessionHandshake(virNetTLSSessionPtr sess) int virNetTLSSessionHandshake(virNetTLSSessionPtr sess)
{ {
int ret;
VIR_DEBUG("sess=%p", sess); VIR_DEBUG("sess=%p", sess);
int ret = gnutls_handshake(sess->session); virMutexLock(&sess->lock);
ret = gnutls_handshake(sess->session);
VIR_DEBUG("Ret=%d", ret); VIR_DEBUG("Ret=%d", ret);
if (ret == 0) { if (ret == 0) {
sess->handshakeComplete = true; sess->handshakeComplete = true;
VIR_DEBUG("Handshake is complete"); VIR_DEBUG("Handshake is complete");
return 0; goto cleanup;
}
if (ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN) {
ret = 1;
goto cleanup;
} }
if (ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN)
return 1;
#if 0 #if 0
PROBE(CLIENT_TLS_FAIL, "fd=%d", PROBE(CLIENT_TLS_FAIL, "fd=%d",
...@@ -1271,32 +1324,43 @@ int virNetTLSSessionHandshake(virNetTLSSessionPtr sess) ...@@ -1271,32 +1324,43 @@ int virNetTLSSessionHandshake(virNetTLSSessionPtr sess)
virNetError(VIR_ERR_AUTH_FAILED, virNetError(VIR_ERR_AUTH_FAILED,
_("TLS handshake failed %s"), _("TLS handshake failed %s"),
gnutls_strerror(ret)); gnutls_strerror(ret));
return -1; ret = -1;
cleanup:
virMutexUnlock(&sess->lock);
return ret;
} }
virNetTLSSessionHandshakeStatus virNetTLSSessionHandshakeStatus
virNetTLSSessionGetHandshakeStatus(virNetTLSSessionPtr sess) virNetTLSSessionGetHandshakeStatus(virNetTLSSessionPtr sess)
{ {
virNetTLSSessionHandshakeStatus ret;
virMutexLock(&sess->lock);
if (sess->handshakeComplete) if (sess->handshakeComplete)
return VIR_NET_TLS_HANDSHAKE_COMPLETE; ret = VIR_NET_TLS_HANDSHAKE_COMPLETE;
else if (gnutls_record_get_direction(sess->session) == 0) else if (gnutls_record_get_direction(sess->session) == 0)
return VIR_NET_TLS_HANDSHAKE_RECVING; ret = VIR_NET_TLS_HANDSHAKE_RECVING;
else else
return VIR_NET_TLS_HANDSHAKE_SENDING; ret = VIR_NET_TLS_HANDSHAKE_SENDING;
virMutexUnlock(&sess->lock);
return ret;
} }
int virNetTLSSessionGetKeySize(virNetTLSSessionPtr sess) int virNetTLSSessionGetKeySize(virNetTLSSessionPtr sess)
{ {
gnutls_cipher_algorithm_t cipher; gnutls_cipher_algorithm_t cipher;
int ssf; int ssf;
virMutexLock(&sess->lock);
cipher = gnutls_cipher_get(sess->session); cipher = gnutls_cipher_get(sess->session);
if (!(ssf = gnutls_cipher_get_key_size(cipher))) { if (!(ssf = gnutls_cipher_get_key_size(cipher))) {
virNetError(VIR_ERR_INTERNAL_ERROR, "%s", virNetError(VIR_ERR_INTERNAL_ERROR, "%s",
_("invalid cipher size for TLS session")); _("invalid cipher size for TLS session"));
return -1; ssf = -1;
goto cleanup;
} }
cleanup:
virMutexUnlock(&sess->lock);
return ssf; return ssf;
} }
...@@ -1306,11 +1370,16 @@ void virNetTLSSessionFree(virNetTLSSessionPtr sess) ...@@ -1306,11 +1370,16 @@ void virNetTLSSessionFree(virNetTLSSessionPtr sess)
if (!sess) if (!sess)
return; return;
virMutexLock(&sess->lock);
sess->refs--; sess->refs--;
if (sess->refs > 0) if (sess->refs > 0) {
virMutexUnlock(&sess->lock);
return; return;
}
VIR_FREE(sess->hostname); VIR_FREE(sess->hostname);
gnutls_deinit(sess->session); gnutls_deinit(sess->session);
virMutexUnlock(&sess->lock);
virMutexDestroy(&sess->lock);
VIR_FREE(sess); VIR_FREE(sess);
} }
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册