/* t_set.c -- Redis Set data type: encoding-agnostic set helpers and the set commands. */
#include "redis.h"

/*-----------------------------------------------------------------------------
 * Set Commands
 *----------------------------------------------------------------------------*/

7 8 9 10
/* Factory method to return a set that *can* hold "value". When the object has
 * an integer-encodable value, an intset will be returned. Otherwise a regular
 * hash table. */
robj *setTypeCreate(robj *value) {
11
    if (isObjectRepresentableAsLongLong(value,NULL) == REDIS_OK)
12 13 14 15 16 17 18 19 20 21 22 23
        return createIntsetObject();
    return createSetObject();
}

/* Insert "value" into the set "subject".
 *
 * Returns 1 when the element was added and 0 when it was already a member.
 * An intset is converted to a hash table when "value" is not integer
 * encodable, or when the insertion makes it hold more than
 * server.set_max_intset_entries elements. Whenever "value" itself is
 * stored in a hash table its reference count is incremented. */
int setTypeAdd(robj *subject, robj *value) {
    long long llval;

    switch (subject->encoding) {
    case REDIS_ENCODING_HT:
        if (dictAdd(subject->ptr,value,NULL) != DICT_OK) return 0;
        incrRefCount(value);
        return 1;

    case REDIS_ENCODING_INTSET:
        if (isObjectRepresentableAsLongLong(value,&llval) != REDIS_OK) {
            /* Not an integer: the intset representation no longer fits. */
            setTypeConvert(subject,REDIS_ENCODING_HT);

            /* Every element of the converted set came from an integer,
             * while this value is not integer encodable, so it cannot be
             * present already and dictAdd must succeed. */
            redisAssertWithInfo(NULL,value,dictAdd(subject->ptr,value,NULL) == DICT_OK);
            incrRefCount(value);
            return 1;
        } else {
            uint8_t inserted = 0;

            subject->ptr = intsetAdd(subject->ptr,llval,&inserted);
            if (!inserted) return 0;

            /* Too many entries for an intset: switch to a hash table. */
            if (intsetLen(subject->ptr) > server.set_max_intset_entries)
                setTypeConvert(subject,REDIS_ENCODING_HT);
            return 1;
        }

    default:
        redisPanic("Unknown set encoding");
    }
    return 0;
}

50
int setTypeRemove(robj *setobj, robj *value) {
51
    long long llval;
52 53 54
    if (setobj->encoding == REDIS_ENCODING_HT) {
        if (dictDelete(setobj->ptr,value) == DICT_OK) {
            if (htNeedsResize(setobj->ptr)) dictResize(setobj->ptr);
55 56
            return 1;
        }
57
    } else if (setobj->encoding == REDIS_ENCODING_INTSET) {
58
        if (isObjectRepresentableAsLongLong(value,&llval) == REDIS_OK) {
59 60
            int success;
            setobj->ptr = intsetRemove(setobj->ptr,llval,&success);
61 62 63 64 65 66 67 68 69 70 71 72 73
            if (success) return 1;
        }
    } else {
        redisPanic("Unknown set encoding");
    }
    return 0;
}

/* Membership test: non-zero when "value" belongs to the set "subject",
 * 0 otherwise. */
int setTypeIsMember(robj *subject, robj *value) {
    long long llval;

    if (subject->encoding == REDIS_ENCODING_INTSET) {
        /* Only integer-encodable values can be stored in an intset. */
        if (isObjectRepresentableAsLongLong(value,&llval) != REDIS_OK) return 0;
        return intsetFind((intset*)subject->ptr,llval);
    }

    if (subject->encoding == REDIS_ENCODING_HT)
        return dictFind((dict*)subject->ptr,value) != NULL;

    redisPanic("Unknown set encoding");
    return 0;
}

83
setTypeIterator *setTypeInitIterator(robj *subject) {
84
    setTypeIterator *si = zmalloc(sizeof(setTypeIterator));
85 86 87 88 89 90 91 92 93 94 95 96
    si->subject = subject;
    si->encoding = subject->encoding;
    if (si->encoding == REDIS_ENCODING_HT) {
        si->di = dictGetIterator(subject->ptr);
    } else if (si->encoding == REDIS_ENCODING_INTSET) {
        si->ii = 0;
    } else {
        redisPanic("Unknown set encoding");
    }
    return si;
}

97
void setTypeReleaseIterator(setTypeIterator *si) {
98 99 100 101 102 103
    if (si->encoding == REDIS_ENCODING_HT)
        dictReleaseIterator(si->di);
    zfree(si);
}

/* Move to the next entry in the set.
 *
 * Set elements are stored either as redis objects (hash table encoding) or
 * as plain integers (intset encoding). setTypeNext() returns the encoding
 * of the set being iterated and populates only the matching output pointer:
 * "objele" for REDIS_ENCODING_HT, "llele" for REDIS_ENCODING_INTSET. The
 * other pointer is not touched, so the caller may pass NULL for it when it
 * knows the encoding in advance.
 *
 * When there are no more elements -1 is returned.
 * A returned object's ref count is not incremented, so this function is
 * copy on write friendly. */
int setTypeNext(setTypeIterator *si, robj **objele, int64_t *llele) {
    if (si->encoding == REDIS_ENCODING_HT) {
        dictEntry *de = dictNext(si->di);
        if (de == NULL) return -1;
        *objele = dictGetKey(de);
    } else if (si->encoding == REDIS_ENCODING_INTSET) {
        if (!intsetGet(si->subject->ptr,si->ii++,llele))
            return -1;
    } else {
        /* Like every other set helper, refuse unknown encodings: returning
         * the raw encoding here would never signal the end of iteration
         * and callers looping until -1 would spin forever. */
        redisPanic("Unknown set encoding");
    }
    return si->encoding;
}

126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150
/* The not copy on write friendly version but easy to use version
 * of setTypeNext() is setTypeNextObject(), returning new objects
 * or incrementing the ref count of returned objects. So if you don't
 * retain a pointer to this object you should call decrRefCount() against it.
 *
 * This function is the way to go for write operations where COW is not
 * an issue as the result will be anyway of incrementing the ref count. */
robj *setTypeNextObject(setTypeIterator *si) {
    int64_t intele;
    robj *objele;
    int encoding;

    encoding = setTypeNext(si,&objele,&intele);
    switch(encoding) {
        case -1:    return NULL;
        case REDIS_ENCODING_INTSET:
            return createStringObjectFromLongLong(intele);
        case REDIS_ENCODING_HT:
            incrRefCount(objele);
            return objele;
        default:
            redisPanic("Unsupported encoding");
    }
    return NULL; /* just to suppress warnings */
}
151

152
/* Return random element from a non empty set.
153
 * The returned element can be a int64_t value if the set is encoded
154 155 156 157 158 159
 * as an "intset" blob of integers, or a redis object if the set
 * is a regular set.
 *
 * The caller provides both pointers to be populated with the right
 * object. The return value of the function is the object->encoding
 * field of the object and is used by the caller to check if the
160
 * int64_t pointer or the redis object pointere was populated.
161 162 163
 *
 * When an object is returned (the set was a real set) the ref count
 * of the object is not incremented so this function can be considered
164 165
 * copy on write friendly. */
int setTypeRandomElement(robj *setobj, robj **objele, int64_t *llele) {
166 167
    if (setobj->encoding == REDIS_ENCODING_HT) {
        dictEntry *de = dictGetRandomKey(setobj->ptr);
168
        *objele = dictGetKey(de);
169 170
    } else if (setobj->encoding == REDIS_ENCODING_INTSET) {
        *llele = intsetRandom(setobj->ptr);
171 172 173
    } else {
        redisPanic("Unknown set encoding");
    }
174
    return setobj->encoding;
175 176 177 178 179 180 181 182 183 184 185 186 187 188 189
}

/* Number of elements stored in the set "subject", whatever its encoding. */
unsigned long setTypeSize(robj *subject) {
    switch (subject->encoding) {
    case REDIS_ENCODING_HT:
        return dictSize((dict*)subject->ptr);
    case REDIS_ENCODING_INTSET:
        return intsetLen((intset*)subject->ptr);
    default:
        redisPanic("Unknown set encoding");
    }
    return 0; /* just to suppress warnings */
}

/* Convert the set to specified encoding. The resulting dict (when converting
 * to a hashtable) is presized to hold the number of elements in the original
 * set. */
190
void setTypeConvert(robj *setobj, int enc) {
191
    setTypeIterator *si;
192 193
    redisAssertWithInfo(NULL,setobj,setobj->type == REDIS_SET &&
                             setobj->encoding == REDIS_ENCODING_INTSET);
194 195

    if (enc == REDIS_ENCODING_HT) {
196
        int64_t intele;
197
        dict *d = dictCreate(&setDictType,NULL);
198 199
        robj *element;

200
        /* Presize the dict to avoid rehashing */
201
        dictExpand(d,intsetLen(setobj->ptr));
202

203 204 205 206
        /* To add the elements we extract integers and create redis objects */
        si = setTypeInitIterator(setobj);
        while (setTypeNext(si,NULL,&intele) != -1) {
            element = createStringObjectFromLongLong(intele);
207
            redisAssertWithInfo(NULL,element,dictAdd(d,element,NULL) == DICT_OK);
208
        }
209 210
        setTypeReleaseIterator(si);

211 212 213
        setobj->encoding = REDIS_ENCODING_HT;
        zfree(setobj->ptr);
        setobj->ptr = d;
214 215 216 217 218
    } else {
        redisPanic("Unsupported set conversion");
    }
}

219 220
void saddCommand(redisClient *c) {
    robj *set;
A
antirez 已提交
221
    int j, added = 0;
222 223 224

    set = lookupKeyWrite(c->db,c->argv[1]);
    if (set == NULL) {
225
        set = setTypeCreate(c->argv[2]);
226 227 228 229 230 231 232
        dbAdd(c->db,c->argv[1],set);
    } else {
        if (set->type != REDIS_SET) {
            addReply(c,shared.wrongtypeerr);
            return;
        }
    }
A
antirez 已提交
233 234 235 236

    for (j = 2; j < c->argc; j++) {
        c->argv[j] = tryObjectEncoding(c->argv[j]);
        if (setTypeAdd(set,c->argv[j])) added++;
237
    }
A
antirez 已提交
238 239 240
    if (added) signalModifiedKey(c->db,c->argv[1]);
    server.dirty += added;
    addReplyLongLong(c,added);
241 242 243 244
}

/* SREM key member [member ...]
 *
 * Replies with the number of members actually removed. As soon as the set
 * becomes empty the key is deleted and the remaining arguments are not
 * processed. */
void sremCommand(redisClient *c) {
    robj *set = lookupKeyWriteOrReply(c,c->argv[1],shared.czero);
    int j, deleted = 0;

    if (set == NULL || checkType(c,set,REDIS_SET)) return;

    for (j = 2; j < c->argc; j++) {
        if (!setTypeRemove(set,c->argv[j])) continue;

        deleted++;
        if (setTypeSize(set) == 0) {
            dbDelete(c->db,c->argv[1]);
            break;
        }
    }

    if (deleted) {
        signalModifiedKey(c->db,c->argv[1]);
        server.dirty += deleted;
    }
    addReplyLongLong(c,deleted);
}

/* SMOVE source destination member
 *
 * Move "member" from the set at "source" to the set at "destination",
 * creating the destination set if needed and deleting the source key when
 * it becomes empty.
 *
 * Reply: 1 if the member was moved. If source and destination are the same
 * set, the reply is 1 if the member is in it. Otherwise the reply is 0,
 * i.e. when the source key does not exist or does not contain the member.
 * A wrong-typed source or destination key produces a type error. */
void smoveCommand(redisClient *c) {
    robj *srcset, *dstset, *ele;

    srcset = lookupKeyWrite(c->db,c->argv[1]);
    dstset = lookupKeyWrite(c->db,c->argv[2]);
    ele = c->argv[3] = tryObjectEncoding(c->argv[3]);

    /* If the source key does not exist return 0 */
    if (srcset == NULL) {
        addReply(c,shared.czero);
        return;
    }

    /* If the source key has the wrong type, or the destination key
     * is set and has the wrong type, return with an error. */
    if (checkType(c,srcset,REDIS_SET) ||
        (dstset && checkType(c,dstset,REDIS_SET))) return;

    /* If srcset and dstset are equal, SMOVE is a no-op. Nothing is
     * modified, but the reply must still report whether the member is
     * present: replying 1 unconditionally would claim a move of an element
     * that is not in the set. */
    if (srcset == dstset) {
        addReply(c,setTypeIsMember(srcset,ele) ? shared.cone : shared.czero);
        return;
    }

    /* If the element cannot be removed from the src set, return 0. */
    if (!setTypeRemove(srcset,ele)) {
        addReply(c,shared.czero);
        return;
    }

    /* Remove the src set from the database when empty */
    if (setTypeSize(srcset) == 0) dbDelete(c->db,c->argv[1]);
    signalModifiedKey(c->db,c->argv[1]);
    signalModifiedKey(c->db,c->argv[2]);
    server.dirty++;

    /* Create the destination set when it doesn't exist */
    if (!dstset) {
        dstset = setTypeCreate(ele);
        dbAdd(c->db,c->argv[2],dstset);
    }

    /* An extra key has changed when ele was successfully added to dstset */
    if (setTypeAdd(dstset,ele)) server.dirty++;
    addReply(c,shared.cone);
}

/* SISMEMBER key member
 *
 * Replies 1 when "member" belongs to the set stored at "key", 0 otherwise
 * (including when the key does not exist). */
void sismemberCommand(redisClient *c) {
    robj *set = lookupKeyReadOrReply(c,c->argv[1],shared.czero);

    if (set == NULL || checkType(c,set,REDIS_SET)) return;

    c->argv[2] = tryObjectEncoding(c->argv[2]);
    addReply(c,setTypeIsMember(set,c->argv[2]) ? shared.cone : shared.czero);
}

/* SCARD key
 *
 * Replies with the number of members of the set, or 0 if the key does
 * not exist. */
void scardCommand(redisClient *c) {
    robj *set = lookupKeyReadOrReply(c,c->argv[1],shared.czero);

    if (set == NULL || checkType(c,set,REDIS_SET)) return;
    addReplyLongLong(c,setTypeSize(set));
}

/* SPOP key
 *
 * Removes a random member from the set and replies with it, or with a null
 * bulk when the key does not exist. The command vector is rewritten to
 * "SREM key member", which is what gets replicated / written to the AOF. */
void spopCommand(redisClient *c) {
    robj *set, *ele, *aux;
    int64_t llele;

    set = lookupKeyWriteOrReply(c,c->argv[1],shared.nullbulk);
    if (set == NULL || checkType(c,set,REDIS_SET)) return;

    /* Take our own reference to the chosen member so it stays valid after
     * it is removed from the set. */
    if (setTypeRandomElement(set,&ele,&llele) == REDIS_ENCODING_INTSET) {
        ele = createStringObjectFromLongLong(llele);
        set->ptr = intsetRemove(set->ptr,llele,NULL);
    } else {
        incrRefCount(ele);
        setTypeRemove(set,ele);
    }

    /* Replicate/AOF this command as an SREM operation */
    aux = createStringObject("SREM",4);
    rewriteClientCommandVector(c,3,aux,c->argv[1],ele);

    /* Reply while our reference is still held, then drop it; the rewritten
     * argv keeps its own reference to both objects. */
    addReplyBulk(c,ele);
    decrRefCount(ele);
    decrRefCount(aux);

    if (setTypeSize(set) == 0) dbDelete(c->db,c->argv[1]);
    signalModifiedKey(c->db,c->argv[1]);
    server.dirty++;
}

/* SRANDMEMBER key
 *
 * Replies with a random member of the set without removing it, or with a
 * null bulk when the key does not exist. */
void srandmemberCommand(redisClient *c) {
    robj *set, *ele;
    int64_t llele;

    set = lookupKeyReadOrReply(c,c->argv[1],shared.nullbulk);
    if (set == NULL || checkType(c,set,REDIS_SET)) return;

    if (setTypeRandomElement(set,&ele,&llele) == REDIS_ENCODING_INTSET)
        addReplyBulkLongLong(c,llele);
    else
        addReplyBulk(c,ele);
}

int qsortCompareSetsByCardinality(const void *s1, const void *s2) {
380
    return setTypeSize(*(robj**)s1)-setTypeSize(*(robj**)s2);
381 382
}

383 384
void sinterGenericCommand(redisClient *c, robj **setkeys, unsigned long setnum, robj *dstkey) {
    robj **sets = zmalloc(sizeof(robj*)*setnum);
385
    setTypeIterator *si;
386 387
    robj *eleobj, *dstset = NULL;
    int64_t intobj;
388
    void *replylen = NULL;
389
    unsigned long j, cardinality = 0;
390
    int encoding;
391

392 393 394 395
    for (j = 0; j < setnum; j++) {
        robj *setobj = dstkey ?
            lookupKeyWrite(c->db,setkeys[j]) :
            lookupKeyRead(c->db,setkeys[j]);
396
        if (!setobj) {
397
            zfree(sets);
398
            if (dstkey) {
399
                if (dbDelete(c->db,dstkey)) {
400
                    signalModifiedKey(c->db,dstkey);
401
                    server.dirty++;
402
                }
403 404 405 406 407 408
                addReply(c,shared.czero);
            } else {
                addReply(c,shared.emptymultibulk);
            }
            return;
        }
409 410
        if (checkType(c,setobj,REDIS_SET)) {
            zfree(sets);
411 412
            return;
        }
413
        sets[j] = setobj;
414 415 416
    }
    /* Sort sets from the smallest to largest, this will improve our
     * algorithm's performace */
417
    qsort(sets,setnum,sizeof(robj*),qsortCompareSetsByCardinality);
418 419 420 421 422 423 424

    /* The first thing we should output is the total number of elements...
     * since this is a multi-bulk write, but at this stage we don't know
     * the intersection set size, so we use a trick, append an empty object
     * to the output list and save the pointer to later modify it with the
     * right length */
    if (!dstkey) {
425
        replylen = addDeferredMultiBulkLength(c);
426 427 428
    } else {
        /* If we have a target key where to store the resulting set
         * create this key with an empty set inside */
429
        dstset = createIntsetObject();
430 431 432 433 434
    }

    /* Iterate all the elements of the first (smallest) set, and test
     * the element against all the other sets, if at least one set does
     * not include the element it is discarded */
435
    si = setTypeInitIterator(sets[0]);
436 437
    while((encoding = setTypeNext(si,&eleobj,&intobj)) != -1) {
        for (j = 1; j < setnum; j++) {
438
            if (sets[j] == sets[0]) continue;
439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471
            if (encoding == REDIS_ENCODING_INTSET) {
                /* intset with intset is simple... and fast */
                if (sets[j]->encoding == REDIS_ENCODING_INTSET &&
                    !intsetFind((intset*)sets[j]->ptr,intobj))
                {
                    break;
                /* in order to compare an integer with an object we
                 * have to use the generic function, creating an object
                 * for this */
                } else if (sets[j]->encoding == REDIS_ENCODING_HT) {
                    eleobj = createStringObjectFromLongLong(intobj);
                    if (!setTypeIsMember(sets[j],eleobj)) {
                        decrRefCount(eleobj);
                        break;
                    }
                    decrRefCount(eleobj);
                }
            } else if (encoding == REDIS_ENCODING_HT) {
                /* Optimization... if the source object is integer
                 * encoded AND the target set is an intset, we can get
                 * a much faster path. */
                if (eleobj->encoding == REDIS_ENCODING_INT &&
                    sets[j]->encoding == REDIS_ENCODING_INTSET &&
                    !intsetFind((intset*)sets[j]->ptr,(long)eleobj->ptr))
                {
                    break;
                /* else... object to object check is easy as we use the
                 * type agnostic API here. */
                } else if (!setTypeIsMember(sets[j],eleobj)) {
                    break;
                }
            }
        }
472 473 474 475

        /* Only take action when all sets contain the member */
        if (j == setnum) {
            if (!dstkey) {
476 477 478 479
                if (encoding == REDIS_ENCODING_HT)
                    addReplyBulk(c,eleobj);
                else
                    addReplyBulkLongLong(c,intobj);
480 481
                cardinality++;
            } else {
482 483 484 485 486 487 488
                if (encoding == REDIS_ENCODING_INTSET) {
                    eleobj = createStringObjectFromLongLong(intobj);
                    setTypeAdd(dstset,eleobj);
                    decrRefCount(eleobj);
                } else {
                    setTypeAdd(dstset,eleobj);
                }
489
            }
490 491
        }
    }
492
    setTypeReleaseIterator(si);
493 494 495 496 497

    if (dstkey) {
        /* Store the resulting set into the target, if the intersection
         * is not an empty set. */
        dbDelete(c->db,dstkey);
498
        if (setTypeSize(dstset) > 0) {
499
            dbAdd(c->db,dstkey,dstset);
500
            addReplyLongLong(c,setTypeSize(dstset));
501 502 503 504
        } else {
            decrRefCount(dstset);
            addReply(c,shared.czero);
        }
505
        signalModifiedKey(c->db,dstkey);
506 507
        server.dirty++;
    } else {
508
        setDeferredMultiBulkLength(c,replylen,cardinality);
509
    }
510
    zfree(sets);
511 512 513 514 515 516 517 518 519 520
}

/* SINTER key [key ...]: reply with the intersection of the given sets. */
void sinterCommand(redisClient *c) {
    robj **keys = c->argv + 1;

    sinterGenericCommand(c,keys,c->argc - 1,NULL);
}

/* SINTERSTORE destination key [key ...]: store the intersection of the
 * given sets at "destination". */
void sinterstoreCommand(redisClient *c) {
    robj *dst = c->argv[1];
    robj **keys = c->argv + 2;

    sinterGenericCommand(c,keys,c->argc - 2,dst);
}

521 522 523
/* Operation selectors for sunionDiffGenericCommand(). */
#define REDIS_OP_UNION  0   /* SUNION / SUNIONSTORE */
#define REDIS_OP_DIFF   1   /* SDIFF / SDIFFSTORE */
#define REDIS_OP_INTER  2   /* not passed by any caller in this chunk */
524

525 526
void sunionDiffGenericCommand(redisClient *c, robj **setkeys, int setnum, robj *dstkey, int op) {
    robj **sets = zmalloc(sizeof(robj*)*setnum);
527
    setTypeIterator *si;
528 529
    robj *ele, *dstset = NULL;
    int j, cardinality = 0;
530

531 532 533 534
    for (j = 0; j < setnum; j++) {
        robj *setobj = dstkey ?
            lookupKeyWrite(c->db,setkeys[j]) :
            lookupKeyRead(c->db,setkeys[j]);
535
        if (!setobj) {
536
            sets[j] = NULL;
537 538
            continue;
        }
539 540
        if (checkType(c,setobj,REDIS_SET)) {
            zfree(sets);
541 542
            return;
        }
543
        sets[j] = setobj;
544 545 546 547 548
    }

    /* We need a temp set object to store our union. If the dstkey
     * is not NULL (that is, we are inside an SUNIONSTORE operation) then
     * this set object will be the resulting object to set into the target key*/
549
    dstset = createIntsetObject();
550 551 552

    /* Iterate all the elements of all the sets, add every element a single
     * time to the result set */
553 554 555
    for (j = 0; j < setnum; j++) {
        if (op == REDIS_OP_DIFF && j == 0 && !sets[j]) break; /* result set is empty */
        if (!sets[j]) continue; /* non existing keys are like empty sets */
556

557
        si = setTypeInitIterator(sets[j]);
558
        while((ele = setTypeNextObject(si)) != NULL) {
559
            if (op == REDIS_OP_UNION || j == 0) {
560
                if (setTypeAdd(dstset,ele)) {
561 562 563
                    cardinality++;
                }
            } else if (op == REDIS_OP_DIFF) {
564
                if (setTypeRemove(dstset,ele)) {
565 566 567
                    cardinality--;
                }
            }
568
            decrRefCount(ele);
569
        }
570
        setTypeReleaseIterator(si);
571

572
        /* Exit when result set is empty. */
573 574 575 576 577
        if (op == REDIS_OP_DIFF && cardinality == 0) break;
    }

    /* Output the content of the resulting set, if not in STORE mode */
    if (!dstkey) {
578
        addReplyMultiBulkLen(c,cardinality);
579
        si = setTypeInitIterator(dstset);
580
        while((ele = setTypeNextObject(si)) != NULL) {
581
            addReplyBulk(c,ele);
582
            decrRefCount(ele);
583
        }
584
        setTypeReleaseIterator(si);
585 586 587 588 589
        decrRefCount(dstset);
    } else {
        /* If we have a target key where to store the resulting set
         * create this key with the result set inside */
        dbDelete(c->db,dstkey);
590
        if (setTypeSize(dstset) > 0) {
591
            dbAdd(c->db,dstkey,dstset);
592
            addReplyLongLong(c,setTypeSize(dstset));
593 594 595 596
        } else {
            decrRefCount(dstset);
            addReply(c,shared.czero);
        }
597
        signalModifiedKey(c->db,dstkey);
598 599
        server.dirty++;
    }
600
    zfree(sets);
601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617
}

/* SUNION key [key ...]: reply with the union of the given sets. */
void sunionCommand(redisClient *c) {
    robj **keys = c->argv + 1;

    sunionDiffGenericCommand(c,keys,c->argc - 1,NULL,REDIS_OP_UNION);
}

/* SUNIONSTORE destination key [key ...]: store the union of the given
 * sets at "destination". */
void sunionstoreCommand(redisClient *c) {
    robj *dst = c->argv[1];
    robj **keys = c->argv + 2;

    sunionDiffGenericCommand(c,keys,c->argc - 2,dst,REDIS_OP_UNION);
}

/* SDIFF key [key ...]: reply with the members of the first set that are
 * not in any of the following sets. */
void sdiffCommand(redisClient *c) {
    robj **keys = c->argv + 1;

    sunionDiffGenericCommand(c,keys,c->argc - 1,NULL,REDIS_OP_DIFF);
}

/* SDIFFSTORE destination key [key ...]: store the difference between the
 * first set and all the following ones at "destination". */
void sdiffstoreCommand(redisClient *c) {
    robj *dst = c->argv[1];
    robj **keys = c->argv + 2;

    sunionDiffGenericCommand(c,keys,c->argc - 2,dst,REDIS_OP_DIFF);
}