t_set.c 30.3 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29
/*
 * Copyright (c) 2009-2012, Salvatore Sanfilippo <antirez at gmail dot com>
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 *   * Redistributions of source code must retain the above copyright notice,
 *     this list of conditions and the following disclaimer.
 *   * Redistributions in binary form must reproduce the above copyright
 *     notice, this list of conditions and the following disclaimer in the
 *     documentation and/or other materials provided with the distribution.
 *   * Neither the name of Redis nor the names of its contributors may be used
 *     to endorse or promote products derived from this software without
 *     specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 * POSSIBILITY OF SUCH DAMAGE.
 */

30 31 32 33 34 35
#include "redis.h"

/*-----------------------------------------------------------------------------
 * Set Commands
 *----------------------------------------------------------------------------*/

36 37
void sunionDiffGenericCommand(redisClient *c, robj **setkeys, int setnum, robj *dstkey, int op);

38 39 40 41
/* Build an empty set whose encoding is able to store "value": an intset
 * object when the value can be represented as a long long, a regular
 * (hash table) set object otherwise. "value" itself is not added. */
robj *setTypeCreate(robj *value) {
    int integer_like =
        (isObjectRepresentableAsLongLong(value,NULL) == REDIS_OK);

    return integer_like ? createIntsetObject() : createSetObject();
}

/* Add "value" to the set "subject".
 *
 * Returns 1 when the element was inserted, 0 when the underlying container
 * did not insert it (dictAdd() not returning DICT_OK, or intsetAdd() not
 * reporting success).
 *
 * When the element is stored in a hash table a new reference to "value" is
 * taken. An intset is converted to a hash table either when "value" is not
 * representable as a long long, or when after the insertion it holds more
 * than server.set_max_intset_entries entries. */
int setTypeAdd(robj *subject, robj *value) {
    long long llval;

    switch (subject->encoding) {
    case REDIS_ENCODING_HT:
        if (dictAdd(subject->ptr,value,NULL) != DICT_OK) return 0;
        incrRefCount(value);
        return 1;

    case REDIS_ENCODING_INTSET:
        if (isObjectRepresentableAsLongLong(value,&llval) != REDIS_OK) {
            /* Not an integer: the set must become a regular set first. */
            setTypeConvert(subject,REDIS_ENCODING_HT);

            /* The set *was* an intset and this value is not integer
             * encodable, so dictAdd should always work. */
            redisAssertWithInfo(NULL,value,dictAdd(subject->ptr,value,NULL) == DICT_OK);
            incrRefCount(value);
            return 1;
        } else {
            uint8_t success = 0;

            subject->ptr = intsetAdd(subject->ptr,llval,&success);
            if (!success) return 0;

            /* Switch to a regular set once the intset gets too large. */
            if (intsetLen(subject->ptr) > server.set_max_intset_entries)
                setTypeConvert(subject,REDIS_ENCODING_HT);
            return 1;
        }

    default:
        redisPanic("Unknown set encoding");
    }
    return 0;
}

81
/* Remove "value" from the set "setobj".
 *
 * Returns 1 if the element was removed, 0 otherwise. For hash table sets
 * the table is resized after a removal when htNeedsResize() says so. For
 * intsets a value that is not representable as a long long is never
 * removed (0 is returned). */
int setTypeRemove(robj *setobj, robj *value) {
    long long llval;

    switch (setobj->encoding) {
    case REDIS_ENCODING_HT:
        if (dictDelete(setobj->ptr,value) != DICT_OK) return 0;
        if (htNeedsResize(setobj->ptr)) dictResize(setobj->ptr);
        return 1;

    case REDIS_ENCODING_INTSET:
        if (isObjectRepresentableAsLongLong(value,&llval) == REDIS_OK) {
            int success;

            setobj->ptr = intsetRemove(setobj->ptr,llval,&success);
            if (success) return 1;
        }
        return 0;

    default:
        redisPanic("Unknown set encoding");
    }
    return 0;
}

/* Test whether "value" is an element of the set "subject".
 *
 * Returns non-zero when the element is found, 0 otherwise. On intsets a
 * value that is not representable as a long long is reported as missing
 * without touching the intset. */
int setTypeIsMember(robj *subject, robj *value) {
    long long llval;

    switch (subject->encoding) {
    case REDIS_ENCODING_HT:
        return dictFind((dict*)subject->ptr,value) != NULL;

    case REDIS_ENCODING_INTSET:
        if (isObjectRepresentableAsLongLong(value,&llval) != REDIS_OK)
            return 0;
        return intsetFind((intset*)subject->ptr,llval);

    default:
        redisPanic("Unknown set encoding");
    }
    return 0;
}

114
/* Allocate and initialize an iterator over the set "subject".
 *
 * The iterator remembers the set and the encoding it had at creation time.
 * For hash table sets it wraps a dict iterator, for intsets it keeps a
 * position starting at 0. Release it with setTypeReleaseIterator(). */
setTypeIterator *setTypeInitIterator(robj *subject) {
    setTypeIterator *iter = zmalloc(sizeof(setTypeIterator));

    iter->subject = subject;
    iter->encoding = subject->encoding;

    switch (iter->encoding) {
    case REDIS_ENCODING_HT:
        iter->di = dictGetIterator(subject->ptr);
        break;
    case REDIS_ENCODING_INTSET:
        iter->ii = 0;
        break;
    default:
        redisPanic("Unknown set encoding");
    }
    return iter;
}

128
/* Free an iterator created by setTypeInitIterator(). The wrapped dict
 * iterator is released first when the iterator was created over a hash
 * table encoded set. */
void setTypeReleaseIterator(setTypeIterator *si) {
    int dict_backed = (si->encoding == REDIS_ENCODING_HT);

    if (dict_backed) dictReleaseIterator(si->di);
    zfree(si);
}

/* Advance the iterator and fetch the element at the new position.
 *
 * Set elements are stored internally either as redis objects (hash table
 * encoding) or as plain integers (intset encoding). The return value is
 * the encoding of the set being iterated and tells the caller which output
 * argument was populated: "objele" for REDIS_ENCODING_HT, "llele" for
 * REDIS_ENCODING_INTSET. The other argument is left untouched.
 *
 * When there are no more elements -1 is returned.
 * The reference count of a returned object is not incremented, so this
 * function is copy on write friendly. */
int setTypeNext(setTypeIterator *si, robj **objele, int64_t *llele) {
    switch (si->encoding) {
    case REDIS_ENCODING_HT: {
        dictEntry *entry = dictNext(si->di);

        if (entry == NULL) return -1;
        *objele = dictGetKey(entry);
        break;
    }
    case REDIS_ENCODING_INTSET: {
        /* The position is advanced even when the lookup fails. */
        int found = intsetGet(si->subject->ptr,si->ii++,llele);

        if (!found) return -1;
        break;
    }
    default:
        break;
    }
    return si->encoding;
}

157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181
/* Easier to use, but not copy on write friendly, variant of setTypeNext().
 *
 * Every call returns an object the caller owns a reference to: either a
 * newly created string object (intset encoded sets) or the stored object
 * with its reference count incremented (hash table encoded sets). Unless
 * the caller retains the pointer, decrRefCount() must be called on it.
 *
 * Returns NULL when the iteration is over.
 *
 * This is the function to use for write operations, where COW is not a
 * concern since reference counts end up being modified anyway. */
robj *setTypeNextObject(setTypeIterator *si) {
    int64_t llele;
    robj *ele;
    int encoding = setTypeNext(si,&ele,&llele);

    if (encoding == -1) return NULL;

    if (encoding == REDIS_ENCODING_INTSET)
        return createStringObjectFromLongLong(llele);

    if (encoding == REDIS_ENCODING_HT) {
        incrRefCount(ele);
        return ele;
    }

    redisPanic("Unsupported encoding");
    return NULL; /* just to suppress warnings */
}
182

183
/* Pick a random element from a non empty set.
 *
 * The element is returned as an int64_t through "llele" when the set is
 * intset encoded, or as a redis object through "objele" when the set is a
 * regular (hash table) set. The caller passes both pointers; the return
 * value is the encoding of the set and tells which one was populated.
 *
 * The set must not be empty: the entry returned by dictGetRandomKey() is
 * dereferenced without checking it for NULL.
 *
 * When an object is returned its reference count is not incremented, so
 * this function can be considered copy on write friendly. */
int setTypeRandomElement(robj *setobj, robj **objele, int64_t *llele) {
    switch (setobj->encoding) {
    case REDIS_ENCODING_HT: {
        dictEntry *entry = dictGetRandomKey(setobj->ptr);

        *objele = dictGetKey(entry);
        break;
    }
    case REDIS_ENCODING_INTSET:
        *llele = intsetRandom(setobj->ptr);
        break;
    default:
        redisPanic("Unknown set encoding");
    }
    return setobj->encoding;
}

/* Return the number of elements stored in the set "subject", whatever its
 * encoding is. An unknown encoding triggers a panic. */
unsigned long setTypeSize(robj *subject) {
    if (subject->encoding == REDIS_ENCODING_HT) {
        return dictSize((dict*)subject->ptr);
    } else if (subject->encoding == REDIS_ENCODING_INTSET) {
        return intsetLen((intset*)subject->ptr);
    } else {
        redisPanic("Unknown set encoding");
    }
    /* Every path of a non-void function must return a value; this mirrors
     * what setTypeNextObject() does after its own panic. */
    return 0; /* just to suppress warnings */
}

/* Convert the set to specified encoding. The resulting dict (when converting
219
 * to a hash table) is presized to hold the number of elements in the original
220
 * set. */
221
void setTypeConvert(robj *setobj, int enc) {
222
    setTypeIterator *si;
223 224
    redisAssertWithInfo(NULL,setobj,setobj->type == REDIS_SET &&
                             setobj->encoding == REDIS_ENCODING_INTSET);
225 226

    if (enc == REDIS_ENCODING_HT) {
227
        int64_t intele;
228
        dict *d = dictCreate(&setDictType,NULL);
229 230
        robj *element;

231
        /* Presize the dict to avoid rehashing */
232
        dictExpand(d,intsetLen(setobj->ptr));
233

234 235 236 237
        /* To add the elements we extract integers and create redis objects */
        si = setTypeInitIterator(setobj);
        while (setTypeNext(si,NULL,&intele) != -1) {
            element = createStringObjectFromLongLong(intele);
238
            redisAssertWithInfo(NULL,element,dictAdd(d,element,NULL) == DICT_OK);
239
        }
240 241
        setTypeReleaseIterator(si);

242 243 244
        setobj->encoding = REDIS_ENCODING_HT;
        zfree(setobj->ptr);
        setobj->ptr = d;
245 246 247 248 249
    } else {
        redisPanic("Unsupported set conversion");
    }
}

250 251
/* Add every member argument (argv[2] onwards) to the set stored at
 * argv[1], creating the set when the key does not exist. Replies with the
 * number of elements actually added, or with a wrong type error when the
 * key holds something that is not a set. */
void saddCommand(redisClient *c) {
    robj *key = c->argv[1];
    robj *set = lookupKeyWrite(c->db,key);
    int j, added = 0;

    if (set != NULL && set->type != REDIS_SET) {
        addReply(c,shared.wrongtypeerr);
        return;
    }

    if (set == NULL) {
        /* The first member decides the initial encoding of the set. */
        set = setTypeCreate(c->argv[2]);
        dbAdd(c->db,key,set);
    }

    for (j = 2; j < c->argc; j++) {
        robj *member = tryObjectEncoding(c->argv[j]);

        c->argv[j] = member;
        if (setTypeAdd(set,member)) added++;
    }

    if (added) signalModifiedKey(c->db,key);
    server.dirty += added;
    addReplyLongLong(c,added);
}

/* Remove every member argument (argv[2] onwards) from the set stored at
 * argv[1]. The key is deleted as soon as the set becomes empty, in which
 * case the remaining arguments are not processed. Replies with the number
 * of elements removed (0 when the key does not exist). */
void sremCommand(redisClient *c) {
    robj *set = lookupKeyWriteOrReply(c,c->argv[1],shared.czero);
    int j, deleted = 0;

    if (set == NULL) return;
    if (checkType(c,set,REDIS_SET)) return;

    for (j = 2; j < c->argc; j++) {
        if (!setTypeRemove(set,c->argv[j])) continue;

        deleted++;
        if (setTypeSize(set) == 0) {
            dbDelete(c->db,c->argv[1]);
            break;
        }
    }

    if (deleted != 0) {
        signalModifiedKey(c->db,c->argv[1]);
        server.dirty += deleted;
    }
    addReplyLongLong(c,deleted);
}

/* Move the element argv[3] from the set at argv[1] to the set at argv[2].
 *
 * Replies 1 when the element was a member of the source set, 0 when the
 * source key does not exist or does not contain the element. A wrong type
 * error is sent (by checkType()) if the source, or an existing
 * destination, is not a set. */
void smoveCommand(redisClient *c) {
    robj *srcset, *dstset, *ele;
    srcset = lookupKeyWrite(c->db,c->argv[1]);
    dstset = lookupKeyWrite(c->db,c->argv[2]);
    ele = c->argv[3] = tryObjectEncoding(c->argv[3]);

    /* If the source key does not exist return 0 */
    if (srcset == NULL) {
        addReply(c,shared.czero);
        return;
    }

    /* If the source key has the wrong type, or the destination key
     * is set and has the wrong type, return with an error. */
    if (checkType(c,srcset,REDIS_SET) ||
        (dstset && checkType(c,dstset,REDIS_SET))) return;

    /* If srcset and dstset are equal, SMOVE is a no-op: nothing is modified,
     * but the reply must still tell whether the element is in the source
     * set, exactly as in the general case below. */
    if (srcset == dstset) {
        addReply(c,setTypeIsMember(srcset,ele) ? shared.cone : shared.czero);
        return;
    }

    /* If the element cannot be removed from the src set, return 0. */
    if (!setTypeRemove(srcset,ele)) {
        addReply(c,shared.czero);
        return;
    }

    /* Remove the src set from the database when empty */
    if (setTypeSize(srcset) == 0) dbDelete(c->db,c->argv[1]);
    signalModifiedKey(c->db,c->argv[1]);
    signalModifiedKey(c->db,c->argv[2]);
    server.dirty++;

    /* Create the destination set when it doesn't exist */
    if (!dstset) {
        dstset = setTypeCreate(ele);
        dbAdd(c->db,c->argv[2],dstset);
    }

    /* An extra key has changed when ele was successfully added to dstset */
    if (setTypeAdd(dstset,ele)) server.dirty++;
    addReply(c,shared.cone);
}

/* Reply 1 if argv[2] is a member of the set stored at argv[1], 0 if it is
 * not or if the key does not exist. */
void sismemberCommand(redisClient *c) {
    robj *set = lookupKeyReadOrReply(c,c->argv[1],shared.czero);
    robj *member;

    if (set == NULL) return;
    if (checkType(c,set,REDIS_SET)) return;

    member = tryObjectEncoding(c->argv[2]);
    c->argv[2] = member;

    addReply(c,setTypeIsMember(set,member) ? shared.cone : shared.czero);
}

/* Reply with the number of elements of the set stored at argv[1], or with
 * 0 when the key does not exist. */
void scardCommand(redisClient *c) {
    robj *set = lookupKeyReadOrReply(c,c->argv[1],shared.czero);

    if (set == NULL) return;
    if (checkType(c,set,REDIS_SET)) return;

    addReplyLongLong(c,setTypeSize(set));
}

/* Remove a random element from the set stored at argv[1] and reply with
 * it. Replies with a null bulk when the key does not exist. The key is
 * deleted when the set becomes empty.
 *
 * Since the removed element is picked at random, the client command vector
 * is rewritten as "SREM <key> <element>" for replication/AOF purposes. */
void spopCommand(redisClient *c) {
    robj *set, *ele, *aux;
    int64_t llele;
    int encoding;

    if ((set = lookupKeyWriteOrReply(c,c->argv[1],shared.nullbulk)) == NULL ||
        checkType(c,set,REDIS_SET)) return;

    /* In both branches we end up owning one reference to "ele": a newly
     * created object for intsets, an extra reference on the stored object
     * (taken before the set drops its own) otherwise. */
    encoding = setTypeRandomElement(set,&ele,&llele);
    if (encoding == REDIS_ENCODING_INTSET) {
        ele = createStringObjectFromLongLong(llele);
        set->ptr = intsetRemove(set->ptr,llele,NULL);
    } else {
        incrRefCount(ele);
        setTypeRemove(set,ele);
    }

    /* Replicate/AOF this command as an SREM operation */
    aux = createStringObject("SREM",4);
    rewriteClientCommandVector(c,3,aux,c->argv[1],ele);
    decrRefCount(aux);

    /* Reply while we still hold our own reference, so that the validity of
     * "ele" here does not depend on references retained elsewhere. */
    addReplyBulk(c,ele);
    decrRefCount(ele);

    if (setTypeSize(set) == 0) dbDelete(c->db,c->argv[1]);
    signalModifiedKey(c->db,c->argv[1]);
    server.dirty++;
}

394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509
/* handle the "SRANDMEMBER key <count>" variant. The normal version of the
 * command is handled by the srandmemberCommand() function itself. */

/* How many times bigger should be the set compared to the requested size
 * for us to don't use the "remove elements" strategy? Read later in the
 * implementation for more info. */
#define SRANDMEMBER_SUB_STRATEGY_MUL 3

void srandmemberWithCountCommand(redisClient *c) {
    long l;
    unsigned long count, size;
    int uniq = 1;
    robj *set, *ele;
    int64_t llele;
    int encoding;

    dict *d;

    if (getLongFromObjectOrReply(c,c->argv[2],&l,NULL) != REDIS_OK) return;
    if (l >= 0) {
        count = (unsigned) l;
    } else {
        /* A negative count means: return the same elements multiple times
         * (i.e. don't remove the extracted element after every extraction). */
        count = -l;
        uniq = 0;
    }

    if ((set = lookupKeyReadOrReply(c,c->argv[1],shared.emptymultibulk))
        == NULL || checkType(c,set,REDIS_SET)) return;
    size = setTypeSize(set);

    /* If count is zero, serve it ASAP to avoid special cases later. */
    if (count == 0) {
        addReply(c,shared.emptymultibulk);
        return;
    }

    /* CASE 1: The count was negative, so the extraction method is just:
     * "return N random elements" sampling the whole set every time.
     * This case is trivial and can be served without auxiliary data
     * structures. */
    if (!uniq) {
        addReplyMultiBulkLen(c,count);
        while(count--) {
            encoding = setTypeRandomElement(set,&ele,&llele);
            if (encoding == REDIS_ENCODING_INTSET) {
                addReplyBulkLongLong(c,llele);
            } else {
                addReplyBulk(c,ele);
            }
        }
        return;
    }

    /* CASE 2:
     * The number of requested elements is greater than the number of
     * elements inside the set: simply return the whole set. */
    if (count >= size) {
        sunionDiffGenericCommand(c,c->argv,c->argc-1,NULL,REDIS_OP_UNION);
        return;
    }

    /* For CASE 3 and CASE 4 we need an auxiliary dictionary. */
    d = dictCreate(&setDictType,NULL);

    /* CASE 3:
     * The number of elements inside the set is not greater than
     * SRANDMEMBER_SUB_STRATEGY_MUL times the number of requested elements.
     * In this case we create a set from scratch with all the elements, and
     * subtract random elements to reach the requested number of elements.
     *
     * This is done because if the number of requsted elements is just
     * a bit less than the number of elements in the set, the natural approach
     * used into CASE 3 is highly inefficient. */
    if (count*SRANDMEMBER_SUB_STRATEGY_MUL > size) {
        setTypeIterator *si;

        /* Add all the elements into the temporary dictionary. */
        si = setTypeInitIterator(set);
        while((encoding = setTypeNext(si,&ele,&llele)) != -1) {
            int retval;

            if (encoding == REDIS_ENCODING_INTSET) {
                retval = dictAdd(d,createStringObjectFromLongLong(llele),NULL);
            } else if (ele->encoding == REDIS_ENCODING_RAW) {
                retval = dictAdd(d,dupStringObject(ele),NULL);
            } else if (ele->encoding == REDIS_ENCODING_INT) {
                retval = dictAdd(d,
                    createStringObjectFromLongLong((long)ele->ptr),NULL);
            }
            redisAssert(retval == DICT_OK);
        }
        setTypeReleaseIterator(si);
        redisAssert(dictSize(d) == size);

        /* Remove random elements to reach the right count. */
        while(size > count) {
            dictEntry *de;

            de = dictGetRandomKey(d);
            dictDelete(d,dictGetKey(de));
            size--;
        }
    }
    
    /* CASE 4: We have a big set compared to the requested number of elements.
     * In this case we can simply get random elements from the set and add
     * to the temporary set, trying to eventually get enough unique elements
     * to reach the specified count. */
    else {
        unsigned long added = 0;

        while(added < count) {
            encoding = setTypeRandomElement(set,&ele,&llele);
            if (encoding == REDIS_ENCODING_INTSET) {
A
antirez 已提交
510
                ele = createStringObjectFromLongLong(llele);
511
            } else if (ele->encoding == REDIS_ENCODING_RAW) {
A
antirez 已提交
512
                ele = dupStringObject(ele);
513
            } else if (ele->encoding == REDIS_ENCODING_INT) {
A
antirez 已提交
514
                ele = createStringObjectFromLongLong((long)ele->ptr);
515
            }
A
antirez 已提交
516 517 518 519 520 521 522
            /* Try to add the object to the dictionary. If it already exists
             * free it, otherwise increment the number of objects we have
             * in the result dictionary. */
            if (dictAdd(d,ele,NULL) == DICT_OK)
                added++;
            else
                decrRefCount(ele);
523 524 525 526 527 528 529 530 531 532 533 534 535 536 537 538 539
        }
    }

    /* CASE 3 & 4: send the result to the user. */
    {
        dictIterator *di;
        dictEntry *de;

        addReplyMultiBulkLen(c,count);
        di = dictGetIterator(d);
        while((de = dictNext(di)) != NULL)
            addReplyBulk(c,dictGetKey(de));
        dictReleaseIterator(di);
        dictRelease(d);
    }
}

540
/* Reply with one random element of the set stored at argv[1], or with a
 * null bulk when the key does not exist. With three arguments the request
 * is handed over to srandmemberWithCountCommand(); more than three
 * arguments are a syntax error. */
void srandmemberCommand(redisClient *c) {
    robj *set, *ele;
    int64_t llele;

    if (c->argc > 3) {
        addReply(c,shared.syntaxerr);
        return;
    }
    if (c->argc == 3) {
        srandmemberWithCountCommand(c);
        return;
    }

    set = lookupKeyReadOrReply(c,c->argv[1],shared.nullbulk);
    if (set == NULL) return;
    if (checkType(c,set,REDIS_SET)) return;

    if (setTypeRandomElement(set,&ele,&llele) == REDIS_ENCODING_INTSET)
        addReplyBulkLongLong(c,llele);
    else
        addReplyBulk(c,ele);
}

int qsortCompareSetsByCardinality(const void *s1, const void *s2) {
565
    return setTypeSize(*(robj**)s1)-setTypeSize(*(robj**)s2);
566 567
}

568 569 570 571 572 573 574 575
/* This is used by SDIFF and in this case we can receive NULL that should
 * be handled as empty sets. */
int qsortCompareSetsByRevCardinality(const void *s1, const void *s2) {
    robj *o1 = *(robj**)s1, *o2 = *(robj**)s2;

    return  (o2 ? setTypeSize(o2) : 0) - (o1 ? setTypeSize(o1) : 0);
}

576 577
/* Helper for sinterGenericCommand(): check if the element produced by
 * setTypeNext() -- "intobj" when "encoding" is REDIS_ENCODING_INTSET,
 * "eleobj" when it is REDIS_ENCODING_HT -- belongs to "set".
 * Returns 0 when the element is known to be missing, non-zero otherwise. */
static int sinterSetContains(robj *set, int encoding, robj *eleobj, int64_t intobj) {
    if (encoding == REDIS_ENCODING_INTSET) {
        /* intset with intset is simple... and fast */
        if (set->encoding == REDIS_ENCODING_INTSET)
            return intsetFind((intset*)set->ptr,intobj) ? 1 : 0;

        /* in order to compare an integer with an object we have to use
         * the generic function, creating a temporary object for this */
        if (set->encoding == REDIS_ENCODING_HT) {
            robj *tmp = createStringObjectFromLongLong(intobj);
            int found = setTypeIsMember(set,tmp);

            decrRefCount(tmp);
            return found ? 1 : 0;
        }
        return 1;
    }

    if (encoding == REDIS_ENCODING_HT) {
        /* Optimization... if the source object is integer encoded AND
         * the target set is an intset, a miss can be detected through a
         * much faster path. */
        if (eleobj->encoding == REDIS_ENCODING_INT &&
            set->encoding == REDIS_ENCODING_INTSET &&
            !intsetFind((intset*)set->ptr,(long)eleobj->ptr))
        {
            return 0;
        }
        /* else... object to object check is easy as we use the type
         * agnostic API here. */
        return setTypeIsMember(set,eleobj) ? 1 : 0;
    }
    return 1;
}

/* Compute the intersection of the "setnum" sets named by "setkeys".
 *
 * When "dstkey" is NULL the members of the intersection are sent to the
 * client as a multi bulk reply. Otherwise the result replaces the content
 * of "dstkey" (the key is just deleted when the intersection is empty) and
 * the reply is the cardinality of the result.
 *
 * A missing input key makes the intersection empty; an input key holding
 * the wrong type makes the command fail with the error sent by
 * checkType(). */
void sinterGenericCommand(redisClient *c, robj **setkeys, unsigned long setnum, robj *dstkey) {
    robj **sets = zmalloc(sizeof(robj*)*setnum);
    setTypeIterator *si;
    robj *eleobj = NULL, *dstset = NULL;
    int64_t intobj;
    void *replylen = NULL;
    unsigned long j, cardinality = 0;
    int encoding;

    for (j = 0; j < setnum; j++) {
        robj *setobj;

        if (dstkey)
            setobj = lookupKeyWrite(c->db,setkeys[j]);
        else
            setobj = lookupKeyRead(c->db,setkeys[j]);

        if (setobj == NULL) {
            /* One missing set is enough for the result to be empty. */
            zfree(sets);
            if (!dstkey) {
                addReply(c,shared.emptymultibulk);
                return;
            }
            if (dbDelete(c->db,dstkey)) {
                signalModifiedKey(c->db,dstkey);
                server.dirty++;
            }
            addReply(c,shared.czero);
            return;
        }
        if (checkType(c,setobj,REDIS_SET)) {
            zfree(sets);
            return;
        }
        sets[j] = setobj;
    }

    /* Sort sets from the smallest to largest, this will improve our
     * algorithm's performance */
    qsort(sets,setnum,sizeof(robj*),qsortCompareSetsByCardinality);

    if (dstkey) {
        /* If we have a target key where to store the resulting set
         * create this key with an empty set inside */
        dstset = createIntsetObject();
    } else {
        /* The first thing we should output is the total number of
         * elements, but at this stage we don't know the intersection set
         * size, so a deferred length is appended to the output and the
         * handle is saved to set the right length later. */
        replylen = addDeferredMultiBulkLength(c);
    }

    /* Iterate all the elements of the first (smallest) set, and test
     * the element against all the other sets, if at least one set does
     * not include the element it is discarded */
    si = setTypeInitIterator(sets[0]);
    while((encoding = setTypeNext(si,&eleobj,&intobj)) != -1) {
        for (j = 1; j < setnum; j++) {
            if (sets[j] == sets[0]) continue;
            if (!sinterSetContains(sets[j],encoding,eleobj,intobj)) break;
        }

        /* Only take action when all sets contain the member */
        if (j != setnum) continue;

        if (dstkey) {
            if (encoding == REDIS_ENCODING_INTSET) {
                robj *tmp = createStringObjectFromLongLong(intobj);

                setTypeAdd(dstset,tmp);
                decrRefCount(tmp);
            } else {
                setTypeAdd(dstset,eleobj);
            }
        } else {
            if (encoding == REDIS_ENCODING_HT)
                addReplyBulk(c,eleobj);
            else
                addReplyBulkLongLong(c,intobj);
            cardinality++;
        }
    }
    setTypeReleaseIterator(si);

    if (!dstkey) {
        setDeferredMultiBulkLength(c,replylen,cardinality);
    } else {
        /* Store the resulting set into the target, if the intersection
         * is not an empty set. */
        dbDelete(c->db,dstkey);
        if (setTypeSize(dstset) > 0) {
            dbAdd(c->db,dstkey,dstset);
            addReplyLongLong(c,setTypeSize(dstset));
        } else {
            decrRefCount(dstset);
            addReply(c,shared.czero);
        }
        signalModifiedKey(c->db,dstkey);
        server.dirty++;
    }
    zfree(sets);
}

/* Intersect the sets named by every argument after the command name and
 * send the result to the client. */
void sinterCommand(redisClient *c) {
    robj **keys = c->argv+1;
    unsigned long numkeys = c->argc-1;

    sinterGenericCommand(c,keys,numkeys,NULL);
}

/* Intersect the sets named by argv[2] onwards and store the result into
 * the key argv[1]. */
void sinterstoreCommand(redisClient *c) {
    robj *dstkey = c->argv[1];
    robj **keys = c->argv+2;
    unsigned long numkeys = c->argc-2;

    sinterGenericCommand(c,keys,numkeys,dstkey);
}

714 715 716
#define REDIS_OP_UNION 0
#define REDIS_OP_DIFF 1
#define REDIS_OP_INTER 2
717

718 719
/* Compute the union (op == REDIS_OP_UNION) or the difference
 * (op == REDIS_OP_DIFF) of the "setnum" sets named by "setkeys".
 *
 * When "dstkey" is NULL the resulting elements are sent to the client as a
 * multi bulk reply. Otherwise the result replaces the content of "dstkey"
 * (the key is just deleted when the result is empty) and the reply is the
 * cardinality of the result.
 *
 * Missing input keys are handled as empty sets; an input key holding the
 * wrong type makes the command fail with the error sent by checkType(). */
void sunionDiffGenericCommand(redisClient *c, robj **setkeys, int setnum, robj *dstkey, int op) {
    robj **sets = zmalloc(sizeof(robj*)*setnum);
    setTypeIterator *si;
    robj *ele, *dstset = NULL;
    int j, cardinality = 0;
    int diff_algo = 1;
    int diff_with_first; /* DIFF requested and the first set exists. */

    for (j = 0; j < setnum; j++) {
        robj *setobj;

        if (dstkey)
            setobj = lookupKeyWrite(c->db,setkeys[j]);
        else
            setobj = lookupKeyRead(c->db,setkeys[j]);

        if (setobj != NULL && checkType(c,setobj,REDIS_SET)) {
            zfree(sets);
            return;
        }
        sets[j] = setobj; /* NULL for non existing keys. */
    }

    /* Only sets[1..] are ever reordered below, so this stays valid. */
    diff_with_first = (op == REDIS_OP_DIFF && sets[0] != NULL);

    /* Select what DIFF algorithm to use.
     *
     * Algorithm 1 is O(N*M) where N is the size of the first set
     * and M the total number of sets.
     *
     * Algorithm 2 is O(N) where N is the total number of elements in all
     * the sets.
     *
     * We compute what is the best bet with the current input here. */
    if (diff_with_first) {
        long long algo_one_work = 0, algo_two_work = 0;

        for (j = 0; j < setnum; j++) {
            if (sets[j] != NULL) {
                algo_one_work += setTypeSize(sets[0]);
                algo_two_work += setTypeSize(sets[j]);
            }
        }

        /* Algorithm 1 has better constant times and performs less operations
         * if there are elements in common. Give it some advantage. */
        algo_one_work /= 2;
        if (algo_one_work <= algo_two_work)
            diff_algo = 1;
        else
            diff_algo = 2;

        if (diff_algo == 1 && setnum > 1) {
            /* With algorithm 1 it is better to order the sets to subtract
             * by decreasing size, so that we are more likely to find
             * duplicated elements ASAP. */
            qsort(sets+1,setnum-1,sizeof(robj*),
                qsortCompareSetsByRevCardinality);
        }
    }

    /* We need a temp set object to store our result. If the dstkey
     * is not NULL (that is, we are inside a STORE operation) then this
     * set object will be the resulting object to set into the target key */
    dstset = createIntsetObject();

    if (op == REDIS_OP_UNION) {
        /* Union is trivial, just add every element of every set to the
         * temporary set. */
        for (j = 0; j < setnum; j++) {
            if (sets[j] == NULL) continue; /* missing keys are empty sets */

            si = setTypeInitIterator(sets[j]);
            while((ele = setTypeNextObject(si)) != NULL) {
                if (setTypeAdd(dstset,ele)) cardinality++;
                decrRefCount(ele);
            }
            setTypeReleaseIterator(si);
        }
    } else if (diff_with_first && diff_algo == 1) {
        /* DIFF Algorithm 1:
         *
         * We perform the diff by iterating all the elements of the first set,
         * and only adding it to the target set if the element does not exist
         * into all the other sets.
         *
         * This way we perform at max N*M operations, where N is the size of
         * the first set, and M the number of sets. */
        si = setTypeInitIterator(sets[0]);
        while((ele = setTypeNextObject(si)) != NULL) {
            for (j = 1; j < setnum; j++) {
                /* no key is an empty set. */
                if (sets[j] && setTypeIsMember(sets[j],ele)) break;
            }
            if (j == setnum) {
                /* There is no other set with this element. Add it. */
                setTypeAdd(dstset,ele);
                cardinality++;
            }
            decrRefCount(ele);
        }
        setTypeReleaseIterator(si);
    } else if (diff_with_first && diff_algo == 2) {
        /* DIFF Algorithm 2:
         *
         * Add all the elements of the first set to the auxiliary set.
         * Then remove all the elements of all the next sets from it.
         *
         * This is O(N) where N is the sum of all the elements in every
         * set. */
        for (j = 0; j < setnum; j++) {
            if (sets[j] == NULL) continue; /* missing keys are empty sets */

            si = setTypeInitIterator(sets[j]);
            while((ele = setTypeNextObject(si)) != NULL) {
                if (j == 0) {
                    if (setTypeAdd(dstset,ele)) cardinality++;
                } else if (setTypeRemove(dstset,ele)) {
                    cardinality--;
                }
                decrRefCount(ele);
            }
            setTypeReleaseIterator(si);

            /* Exit if result set is empty as any additional removal
             * of elements will have no effect. */
            if (cardinality == 0) break;
        }
    }

    if (dstkey) {
        /* If we have a target key where to store the resulting set
         * create this key with the result set inside */
        dbDelete(c->db,dstkey);
        if (setTypeSize(dstset) > 0) {
            dbAdd(c->db,dstkey,dstset);
            addReplyLongLong(c,setTypeSize(dstset));
        } else {
            decrRefCount(dstset);
            addReply(c,shared.czero);
        }
        signalModifiedKey(c->db,dstkey);
        server.dirty++;
    } else {
        /* Output the content of the resulting set, if not in STORE mode */
        addReplyMultiBulkLen(c,cardinality);
        si = setTypeInitIterator(dstset);
        while((ele = setTypeNextObject(si)) != NULL) {
            addReplyBulk(c,ele);
            decrRefCount(ele);
        }
        setTypeReleaseIterator(si);
        decrRefCount(dstset);
    }
    zfree(sets);
}

/* Send to the client the union of the sets named by every argument after
 * the command name. */
void sunionCommand(redisClient *c) {
    robj **keys = c->argv+1;
    int numkeys = c->argc-1;

    sunionDiffGenericCommand(c,keys,numkeys,NULL,REDIS_OP_UNION);
}

/* Store into the key argv[1] the union of the sets named by argv[2]
 * onwards. */
void sunionstoreCommand(redisClient *c) {
    robj *dstkey = c->argv[1];
    robj **keys = c->argv+2;
    int numkeys = c->argc-2;

    sunionDiffGenericCommand(c,keys,numkeys,dstkey,REDIS_OP_UNION);
}

/* Send to the client the difference between the first set (argv[1]) and
 * all the following ones. */
void sdiffCommand(redisClient *c) {
    robj **keys = c->argv+1;
    int numkeys = c->argc-1;

    sunionDiffGenericCommand(c,keys,numkeys,NULL,REDIS_OP_DIFF);
}

/* Store into the key argv[1] the difference between the set argv[2] and
 * all the following ones. */
void sdiffstoreCommand(redisClient *c) {
    robj *dstkey = c->argv[1];
    robj **keys = c->argv+2;
    int numkeys = c->argc-2;

    sunionDiffGenericCommand(c,keys,numkeys,dstkey,REDIS_OP_DIFF);
}