From 609acaad02bd419b0fd866006fc3ca0803bfa1db Mon Sep 17 00:00:00 2001 From: Moti Cohen Date: Sun, 18 Jan 2026 14:52:50 +0200 Subject: [PATCH] Optimize zset to use dict with no_value=1 (#14701) * Embed sds element inside skiplist nodes: Changed zset dict to store zskiplistNode* as keys (with no_value=1) instead of storing sds keys and double* values, eliminating redundant sds storage and enabling single-allocation nodes * Single allocation for skiplist nodes: Each node now contains: fixed fields + level[] array + embedded sds, reducing memory fragmentation and allocation overhead. This optimization is based on https://github.com/valkey-io/valkey/pull/1427 * Optimize lookups with dictFindLink: Use dictFindLink in zsetAdd to avoid double hash table lookup when inserting new elements (find + add becomes single operation) * Simplify score updates --- src/aof.c | 7 +- src/db.c | 6 +- src/debug.c | 6 +- src/defrag.c | 101 ++++------ src/geo.c | 6 +- src/module.c | 20 +- src/object.c | 7 +- src/rdb.c | 8 +- src/server.c | 11 -- src/server.h | 16 +- src/sort.c | 4 +- src/t_zset.c | 550 ++++++++++++++++++++++++++++++--------------------- 12 files changed, 414 insertions(+), 328 deletions(-) diff --git a/src/aof.c b/src/aof.c index 0e0653d5f..3ace67011 100644 --- a/src/aof.c +++ b/src/aof.c @@ -2042,8 +2042,9 @@ int rewriteSortedSetObject(rio *r, robj *key, robj *o) { dictInitIterator(&di, zs->dict); while((de = dictNext(&di)) != NULL) { - sds ele = dictGetKey(de); - double *score = dictGetVal(de); + zskiplistNode *znode = dictGetKey(de); + sds ele = zslGetNodeElement(znode); + double score = znode->score; if (count == 0) { int cmd_items = (items > AOF_REWRITE_ITEMS_PER_CMD) ? @@ -2057,7 +2058,7 @@ int rewriteSortedSetObject(rio *r, robj *key, robj *o) { return 0; } } - if (!rioWriteBulkDouble(r,*score) || + if (!rioWriteBulkDouble(r,score) || !rioWriteBulkString(r,ele,sdslen(ele))) { dictResetIterator(&di); diff --git a/src/db.c b/src/db.c index e54691b3c..2a64147f4 100644 --- a/src/db.c +++ b/src/db.c @@ -1576,12 +1576,16 @@ void scanCallback(void *privdata, const dictEntry *de, dictEntryLink plink) { serverAssert(!((data->type != LLONG_MAX) && o)); kvobj *kv = NULL; + zskiplistNode *znode = NULL; if (!o) { /* If scanning keyspace */ kv = dictGetKV(de); keyStr = kvobjGetKey(kv); } else if (o->type == OBJ_HASH) { hashEntry = dictGetKey(de); keyStr = entryGetField(hashEntry); + } else if (o->type == OBJ_ZSET) { + znode = dictGetKey(de); + keyStr = zslGetNodeElement(znode); } else { keyStr = dictGetKey(de); } @@ -1624,7 +1628,7 @@ void scanCallback(void *privdata, const dictEntry *de, dictEntryLink plink) { } else if (o->type == OBJ_ZSET) { char buf[MAX_LONG_DOUBLE_CHARS]; - int len = ld2string(buf, sizeof(buf), *(double *)dictGetVal(de), LD_STR_AUTO); + int len = ld2string(buf, sizeof(buf), znode->score, LD_STR_AUTO); key = sdsdup(keyStr); val = sdsnewlen(buf, len); } else { diff --git a/src/debug.c b/src/debug.c index 523180064..0b9b78d7a 100644 --- a/src/debug.c +++ b/src/debug.c @@ -199,9 +199,9 @@ void xorObjectDigest(redisDb *db, robj *keyobj, unsigned char *digest, robj *o) dictInitIterator(&di, zs->dict); while((de = dictNext(&di)) != NULL) { - sds sdsele = dictGetKey(de); - double *score = dictGetVal(de); - const int len = fpconv_dtoa(*score, buf); + zskiplistNode *znode = dictGetKey(de); + sds sdsele = zslGetNodeElement(znode); + const int len = fpconv_dtoa(znode->score, buf); buf[len] = '\0'; memset(eledigest,0,20); mixDigest(eledigest,sdsele,sdslen(sdsele)); diff --git a/src/defrag.c b/src/defrag.c index d2ae9985e..f0bff158f 100644 --- a/src/defrag.c +++ b/src/defrag.c @@ -382,7 +382,7 @@ dict *dictDefragTables(dict *d) { return ret; } -/* Internal function used by zslDefrag */ +/* Internal function used by activeDefragZsetNode */ void zslUpdateNode(zskiplist *zsl, zskiplistNode *oldnode, zskiplistNode *newnode, zskiplistNode **update) { int i; for (i = 0; i < zsl->level; i++) { @@ -399,60 +399,40 @@ void zslUpdateNode(zskiplist *zsl, zskiplistNode *oldnode, zskiplistNode *newnod } } -/* Defrag helper for sorted set. - * Update the robj pointer, defrag the skiplist struct and return the new score - * reference. We may not access oldele pointer (not even the pointer stored in - * the skiplist), as it was already freed. Newele may be null, in which case we - * only need to defrag the skiplist, but not update the obj pointer. - * When return value is non-NULL, it is the score reference that must be updated - * in the dict record. */ -double *zslDefrag(zskiplist *zsl, double score, sds oldele, sds newele) { - zskiplistNode *update[ZSKIPLIST_MAXLEVEL], *x, *newx; +/* Defrag a single zset node, update dictEntry and skiplist struct */ +void activeDefragZsetNode(zset *zs, dictEntry *de, dictEntryLink plink) { + zskiplistNode *znode = dictGetKey(de); + + /* Try to defrag the skiplist node first */ + zskiplistNode *newnode = activeDefragAllocWithoutFree(znode); + if (!newnode) return; /* No defrag needed */ + + /* Node was defragged, now we need to update all skiplist pointers */ + zskiplistNode *update[ZSKIPLIST_MAXLEVEL], *iter; int i; - sds ele = newele? newele: oldele; + double score = newnode->score; + sds ele = zslGetNodeElement(newnode); - /* find the skiplist node referring to the object that was moved, - * and all pointers that need to be updated if we'll end up moving the skiplist node. */ - x = zsl->header; - for (i = zsl->level-1; i >= 0; i--) { - while (x->level[i].forward && - x->level[i].forward->ele != oldele && /* make sure not to access the - ->obj pointer if it matches - oldele */ - (x->level[i].forward->score < score || - (x->level[i].forward->score == score && - sdscmp(x->level[i].forward->ele,ele) < 0))) - x = x->level[i].forward; - update[i] = x; + /* Find all pointers that need to be updated */ + iter = zs->zsl->header; + for (i = zs->zsl->level-1; i >= 0; i--) { + while (iter->level[i].forward && + iter->level[i].forward != znode && + zslCompareWithNode(score, ele, iter->level[i].forward) > 0) + iter = iter->level[i].forward; + update[i] = iter; } - /* update the robj pointer inside the skip list record. */ - x = x->level[0].forward; - serverAssert(x && score == x->score && x->ele==oldele); - if (newele) - x->ele = newele; + /* Verify we found the right node */ + iter = iter->level[0].forward; + serverAssert(iter && iter == znode); - /* try to defrag the skiplist record itself */ - newx = activeDefragAlloc(x); - if (newx) { - zslUpdateNode(zsl, x, newx, update); - return &newx->score; - } - return NULL; -} + /* Update all skiplist pointers and dict key */ + zslUpdateNode(zs->zsl, znode, newnode, update); + dictSetKeyAtLink(zs->dict, newnode, &plink, 0); -/* Defrag helper for sorted set. - * Defrag a single dict entry key name, and corresponding skiplist struct */ -void activeDefragZsetEntry(zset *zs, dictEntry *de) { - sds newsds; - double* newscore; - sds sdsele = dictGetKey(de); - if ((newsds = activeDefragSds(sdsele))) - dictSetKey(zs->dict, de, newsds); - newscore = zslDefrag(zs->zsl, *(double*)dictGetVal(de), sdsele, newsds); - if (newscore) { - dictSetVal(zs->dict, de, newscore); - } + /* Free the old node now that all pointers have been updated */ + activeDefragFree(znode); } #define DEFRAG_SDS_DICT_NO_VAL 0 @@ -627,11 +607,10 @@ typedef struct { zset *zs; } scanLaterZsetData; -void scanLaterZsetCallback(void *privdata, const dictEntry *_de, dictEntryLink plink) { - UNUSED(plink); +void scanZsetCallback(void *privdata, const dictEntry *_de, dictEntryLink plink) { dictEntry *de = (dictEntry*)_de; scanLaterZsetData *data = privdata; - activeDefragZsetEntry(data->zs, de); + activeDefragZsetNode(data->zs, de, plink); server.stat_active_defrag_scanned++; } @@ -641,7 +620,7 @@ void scanLaterZset(robj *ob, unsigned long *cursor) { dict *d = zs->dict; scanLaterZsetData data = {zs}; dictDefragFunctions defragfns = {.defragAlloc = activeDefragAlloc}; - *cursor = dictScanDefrag(d, *cursor, scanLaterZsetCallback, &defragfns, &data); + *cursor = dictScanDefrag(d, *cursor, scanZsetCallback, &defragfns, &data); } /* Used as scan callback when all the work is done in the dictDefragFunctions. */ @@ -723,7 +702,6 @@ void defragZsetSkiplist(defragKeysCtx *ctx, kvobj *ob) { zset *newzs; zskiplist *newzsl; dict *newdict; - dictEntry *de; struct zskiplistNode *newheader; serverAssert(ob->type == OBJ_ZSET && ob->encoding == OBJ_ENCODING_SKIPLIST); if ((newzs = activeDefragAlloc(zs))) @@ -735,12 +713,15 @@ void defragZsetSkiplist(defragKeysCtx *ctx, kvobj *ob) { if (dictSize(zs->dict) > server.active_defrag_max_scan_fields) defragLater(ctx, ob); else { - dictIterator di; - dictInitIterator(&di, zs->dict); - while((de = dictNext(&di)) != NULL) { - activeDefragZsetEntry(zs, de); - } - dictResetIterator(&di); + /* Use dictScanDefrag to iterate and defrag both dictEntry structures and skiplist nodes. + * dictScanDefrag handles defragging dictEntry/dictEntryNoValue structures via defragfns, + * and calls our callback with plink for each entry so we can defrag skiplist nodes. */ + scanLaterZsetData data = {zs}; + dictDefragFunctions defragfns = {.defragAlloc = activeDefragAlloc}; + unsigned long cursor = 0; + do { + cursor = dictScanDefrag(zs->dict, cursor, scanZsetCallback, &defragfns, &data); + } while (cursor != 0); } /* defrag the dict struct and tables */ if ((newdict = dictDefragTables(zs->dict))) diff --git a/src/geo.c b/src/geo.c index 6d8181897..ce890f7f0 100644 --- a/src/geo.c +++ b/src/geo.c @@ -313,7 +313,8 @@ int geoGetPointsInRange(robj *zobj, double min, double max, GeoShape *shape, geo break; if (geoWithinShape(shape, ln->score, xy, &distance) == C_OK) { /* Append the new element. */ - geoArrayAppend(ga, xy, distance, ln->score, sdsdup(ln->ele)); + sds ele = zslGetNodeElement(ln); + geoArrayAppend(ga, xy, distance, ln->score, sdsdup(ele)); } if (ga->used && limit && ga->used >= limit) break; ln = ln->level[0].forward; @@ -822,7 +823,8 @@ void georadiusGeneric(client *c, int srcKeyIndex, int flags) { if (maxelelen < elelen) maxelelen = elelen; totelelen += elelen; znode = zslInsert(zs->zsl,score,gp->member); - serverAssert(dictAdd(zs->dict,gp->member,&znode->score) == DICT_OK); + serverAssert(dictAdd(zs->dict, znode, NULL) == DICT_OK); + sdsfree(gp->member); /* zslInsert copies the sds, so free the original */ gp->member = NULL; } diff --git a/src/module.c b/src/module.c index 29f588fec..3cabd9ca5 100644 --- a/src/module.c +++ b/src/module.c @@ -5431,7 +5431,8 @@ RedisModuleString *RM_ZsetRangeCurrentElement(RedisModuleKey *key, double *score } else if (key->kv->encoding == OBJ_ENCODING_SKIPLIST) { zskiplistNode *ln = key->u.zset.current; if (score) *score = ln->score; - str = createStringObject(ln->ele,sdslen(ln->ele)); + sds ele = zslGetNodeElement(ln); + str = createStringObject(ele,sdslen(ele)); } else { serverPanic("Unsupported zset encoding"); } @@ -5490,7 +5491,7 @@ int RM_ZsetRangeNext(RedisModuleKey *key) { key->u.zset.er = 1; return 0; } else if (key->u.zset.type == REDISMODULE_ZSET_RANGE_LEX) { - if (!zslLexValueLteMax(next->ele,&key->u.zset.lrs)) { + if (!zslLexValueLteMax(zslGetNodeElement(next),&key->u.zset.lrs)) { key->u.zset.er = 1; return 0; } @@ -5554,7 +5555,7 @@ int RM_ZsetRangePrev(RedisModuleKey *key) { key->u.zset.er = 1; return 0; } else if (key->u.zset.type == REDISMODULE_ZSET_RANGE_LEX) { - if (!zslLexValueGteMin(prev->ele,&key->u.zset.lrs)) { + if (!zslLexValueGteMin(zslGetNodeElement(prev),&key->u.zset.lrs)) { key->u.zset.er = 1; return 0; } @@ -11836,6 +11837,7 @@ static void moduleScanKeyCallback(void *privdata, const dictEntry *de, dictEntry robj *field = NULL; robj *value = NULL; if (kv->type == OBJ_SET) { + field = createStringObject(key, sdslen(key)); value = NULL; } else if (kv->type == OBJ_HASH) { Entry *e = (Entry *) key; @@ -11852,13 +11854,13 @@ static void moduleScanKeyCallback(void *privdata, const dictEntry *de, dictEntry field = createStringObject(fieldStr, sdslen(fieldStr)); value = createStringObject(val, sdslen(val)); } else if (kv->type == OBJ_ZSET) { - double *val = (double*)dictGetVal(de); - value = createStringObjectFromLongDouble(*val, 0); + zskiplistNode *znode = (zskiplistNode *) key; + sds fieldStr = zslGetNodeElement(znode); + field = createStringObject(fieldStr, sdslen(fieldStr)); + value = createStringObjectFromLongDouble(znode->score, 0); } - - /* if type is OBJ_HASH then key is of type entry*. Otherwise sds. */ - if (!field) field = createStringObject(key, sdslen(key)); - + + serverAssert(field != NULL); data->fn(data->key, field, value, data->user_data); decrRefCount(field); if (value) decrRefCount(value); diff --git a/src/object.c b/src/object.c index 1241dff8b..14a521868 100644 --- a/src/object.c +++ b/src/object.c @@ -711,10 +711,11 @@ void dismissZsetObject(robj *o, size_t size_hint) { /* We iterate all nodes only when average member size is bigger than a * page size, and there's a high chance we'll actually dismiss something. */ if (size_hint / zsl->length >= server.page_size) { - zskiplistNode *zn = zsl->tail; + zskiplistNode *zn = zsl->header->level[0].forward; while (zn != NULL) { - dismissSds(zn->ele); - zn = zn->backward; + zskiplistNode *next = zn->level[0].forward; + dismissMemory(zn, 0); + zn = next; } } diff --git a/src/rdb.c b/src/rdb.c index 08f3dfbb1..8f9da2e8c 100644 --- a/src/rdb.c +++ b/src/rdb.c @@ -1095,8 +1095,9 @@ ssize_t rdbSaveObject(rio *rdb, robj *o, robj *key, int dbid) { * O(1) instead of O(log(N)). */ zskiplistNode *zn = zsl->tail; while (zn != NULL) { + sds ele = zslGetNodeElement(zn); if ((n = rdbSaveRawString(rdb, - (unsigned char*)zn->ele,sdslen(zn->ele))) == -1) + (unsigned char*)ele,sdslen(ele))) == -1) { return -1; } @@ -2374,12 +2375,13 @@ robj *rdbLoadObject(int rdbtype, rio *rdb, sds key, int dbid, int *error) totelelen += sdslen(sdsele); znode = zslInsert(zs->zsl,score,sdsele); - if (dictAdd(zs->dict,sdsele,&znode->score) != DICT_OK) { + if (dictAdd(zs->dict, znode, NULL) != DICT_OK) { rdbReportCorruptRDB("Duplicate zset fields detected"); decrRefCount(o); - /* no need to free 'sdsele', will be released by zslFree together with 'o' */ + sdsfree(sdsele); /* zslInsert copies the sds, so we need to free the original */ return NULL; } + sdsfree(sdsele); /* zslInsert copies the sds into the node, so free the original */ } /* Convert *after* loading, since sorted sets are not stored ordered. */ diff --git a/src/server.c b/src/server.c index f7c368bb5..3c040aec6 100644 --- a/src/server.c +++ b/src/server.c @@ -590,17 +590,6 @@ dictType setDictType = { .dictMetadataBytes = setDictMetadataBytes, }; -/* Sorted sets hash (note: a skiplist is used in addition to the hash table) */ -dictType zsetDictType = { - dictSdsHash, /* hash function */ - NULL, /* key dup */ - NULL, /* val dup */ - dictSdsKeyCompare, /* key compare */ - NULL, /* Note: SDS string shared & freed by skiplist */ - NULL, /* val destructor */ - NULL, /* allow to expand */ -}; - /* Db->dict, keys are of type kvobj, unification of key and value */ dictType dbDictType = { dictSdsHash, /* hash function */ diff --git a/src/server.h b/src/server.h index 943400122..cd72f4863 100644 --- a/src/server.h +++ b/src/server.h @@ -1595,17 +1595,24 @@ struct sharedObjectsStruct { }; /* ZSETs use a specialized version of Skiplists */ + +/* Node info placed in level[0].span since it's unused at level 0 (static assert verified) */ +typedef struct zskiplistNodeInfo { + uint16_t sdsoffset; /* Offset from node start to sds data (after sds header) */ + uint8_t levels; /* Number of levels in this node (1-32) */ + uint8_t reserved; +} zskiplistNodeInfo; + typedef struct zskiplistNode { - sds ele; double score; struct zskiplistNode *backward; struct zskiplistLevel { struct zskiplistNode *forward; /* Span is the number of elements between this node and the next node at this level. - * At level 0, span is always 1 (or 0 for the last node), so we repurpose it to store - * the node's level. This enables O(1) access to node level for rank calculations. */ + * At level 0, span is repurposed to store zskiplistNodeInfo for regular nodes, */ unsigned long span; } level[]; + /* sds ele is embedded after level[] array (assist zslGetNodeElement(node) to access it) */ } zskiplistNode; typedef struct zskiplist { @@ -3418,9 +3425,10 @@ typedef struct { zskiplist *zslCreate(void); void zslFree(zskiplist *zsl); size_t zslAllocSize(const zskiplist *zsl); +sds zslGetNodeElement(const zskiplistNode *node); +int zslCompareWithNode(double score, sds ele, const zskiplistNode *n); zskiplistNode *zslInsert(zskiplist *zsl, double score, sds ele); unsigned char *zzlInsert(unsigned char *zl, sds ele, double score); -int zslDelete(zskiplist *zsl, double score, sds ele, zskiplistNode **node); zskiplistNode *zslNthInRange(zskiplist *zsl, zrangespec *range, long n, unsigned long *out_rank); double zzlGetScore(unsigned char *sptr); void zzlNext(unsigned char *zl, unsigned char **eptr, unsigned char **sptr); diff --git a/src/sort.c b/src/sort.c index d397b5ed5..997ca62b7 100644 --- a/src/sort.c +++ b/src/sort.c @@ -467,7 +467,7 @@ void sortCommandGeneric(client *c, int readonly) { while(rangelen--) { serverAssertWithInfo(c,sortval,ln != NULL); - sdsele = ln->ele; + sdsele = zslGetNodeElement(ln); vector[j].obj = createStringObject(sdsele,sdslen(sdsele)); vector[j].u.score = 0; vector[j].u.cmpobj = NULL; @@ -487,7 +487,7 @@ void sortCommandGeneric(client *c, int readonly) { oldsize = zsetAllocSize(sortval); dictInitIterator(&di, set); while((setele = dictNext(&di)) != NULL) { - sdsele = dictGetKey(setele); + sdsele = zslGetNodeElement(dictGetKey(setele)); vector[j].obj = createStringObject(sdsele,sdslen(sdsele)); vector[j].u.score = 0; vector[j].u.cmpobj = NULL; diff --git a/src/t_zset.c b/src/t_zset.c index 4b178d47c..f531dc430 100644 --- a/src/t_zset.c +++ b/src/t_zset.c @@ -45,6 +45,24 @@ #include "intset.h" /* Compact integer set structure */ #include +#define ZSL_OFFSET_MAX_ELE UINT16_MAX +#define ZSL_OFFSET_NO_ELE UINT16_MAX + +const void *zslGetNodeElementForDict(const void *node); + +/* dictType for zset's dict (maps sds to zskiplistNode*) */ +dictType zsetDictType = { + dictSdsHash, /* hash function */ + NULL, /* key dup */ + NULL, /* val dup */ + dictSdsKeyCompare, /* compares embedded sds by keyFromStoredKey */ + NULL, /* key destructor - skiplist owns the node memory */ + NULL, /* val destructor */ + NULL, /* allow to expand */ + .no_value = 1, /* no values stored (only nodes) */ + .keyFromStoredKey = zslGetNodeElementForDict, /* extract embedded sds from node */ +}; + /*----------------------------------------------------------------------------- * Skiplist implementation of the low level API *----------------------------------------------------------------------------*/ @@ -58,7 +76,8 @@ static inline unsigned long zslGetNodeSpanAtLevel(zskiplistNode *x, int level) { /* At level 0, span stores node level instead of distance, so return the actual span value: * 1 for all nodes except the last node (which has span 0). */ if (level > 0) return x->level[level].span; - return x->level[level].forward ? 1 : 0; + /* For level 0, if regular node, span is 1. If tail node, span is 0. */ + return x->level[0].forward ? 1 : 0; } static inline void zslSetNodeSpanAtLevel(zskiplistNode *x, int level, unsigned long span) { @@ -79,33 +98,114 @@ static inline void zslDecrNodeSpanAtLevel(zskiplistNode *x, int level, unsigned x->level[level].span -= decr; } -static inline unsigned long zslGetNodeLevel(zskiplistNode *x) { - /* Level 0 span is repurposed to store node level. */ - return x->level[0].span; +/* Get zskiplistNodeInfo from node (stored in level[0].span). */ +static_assert(sizeof(zskiplistNodeInfo) <= sizeof(((zskiplistNode *)0)->level[0].span), "Must fit in level[0].span"); +static inline zskiplistNodeInfo *zslGetNodeInfo(const zskiplistNode *node) { + return (zskiplistNodeInfo *)&node->level[0].span; } -static inline void zslSetNodeLevel(zskiplistNode *x, int level) { - /* Level 0 span is repurposed to store node level. */ - x->level[0].span = level; +/* Set zskiplistNodeInfo in node (stored in level[0].span) */ +static inline void zslSetNodeInfo(zskiplistNode *node, uint8_t levels, uint16_t sdsoffset) { + union { + zskiplistNodeInfo info; + unsigned long span; + } u = { .info = { .levels = levels, .sdsoffset = sdsoffset } }; + node->level[0].span = u.span; +} + +/* Compare {score, ele} with node. Returns: 1=bigger 0=equal -1=smaller + * + * Ordering is by score first, then lexicographically by element. + * NULL is treated as +infinity (comes after any real node). */ +int zslCompareWithNode(double score, sds ele, const zskiplistNode *n) { + if (/*score < */ n == NULL) return -1; /* NULL is +infinity, comes after any real node */ + if (score < n->score) return -1; + if (score > n->score) return 1; + /* Scores are equal, compare elements lexicographically */ + return sdscmp(ele, zslGetNodeElement(n)); +} + +/* Get embedded sds from node. Uses the stored offset to directly access the sds data */ +sds zslGetNodeElement(const zskiplistNode *node) { + zskiplistNodeInfo *info = zslGetNodeInfo(node); + debugServerAssert(info->sdsoffset != ZSL_OFFSET_NO_ELE); + return (char*)node + info->sdsoffset; +} + +/* Wrapper for dict getKeyId callback - extracts sds from node pointer. + * This allows the dict to store zskiplistNode* but look them up using sds. */ +const void *zslGetNodeElementForDict(const void *node) { + return zslGetNodeElement((zskiplistNode*)node); +} + +/* Create a skiplist header node with ZSKIPLIST_MAXLEVEL levels */ +static zskiplistNode *zslCreateHeaderNode(zskiplist *zsl) { + size_t usable; + zskiplistNode *zn = zmalloc_usable(sizeof(*zn) + ZSKIPLIST_MAXLEVEL * sizeof(struct zskiplistLevel), &usable); + + /* Initialize all fields */ + zn->score = 0; + zn->backward = NULL; + + /* Initialize all level pointers and spans */ + for (int j = 0; j < ZSKIPLIST_MAXLEVEL; j++) { + zn->level[j].forward = NULL; + zn->level[j].span = 0; /* Will be overwritten for level[0] below */ + } + + /* Use ZSL_OFFSET_NO_ELE as sentinel to indicate no embedded sds (header node) */ + zslSetNodeInfo(zn, ZSKIPLIST_MAXLEVEL, ZSL_OFFSET_NO_ELE); + + /* Track allocation size */ + zsl->alloc_size += usable; + + return zn; } /* Create a skiplist node with the specified number of levels. - * The SDS string 'ele' is referenced by the node after the call. */ + * The SDS string 'ele' is COPIED into an embedded sds within the node allocation. + * This creates a single allocation containing: node + level[] + embedded sds. + * The caller is responsible for freeing 'ele' if it's no longer needed. */ static zskiplistNode *zslCreateNode(zskiplist *zsl, int level, double score, sds ele) { size_t usable; - zskiplistNode *zn = - zmalloc_usable(sizeof(*zn)+level*sizeof(struct zskiplistLevel), &usable); + size_t ele_len = sdslen(ele); + char sds_type = sdsReqType(ele_len); + size_t sds_hdr_len = sdsHdrSize(sds_type); + + /* Calculate total size: node fixed part + level[] + sds buffer space */ + size_t node_size = sizeof(zskiplistNode) + level * sizeof(struct zskiplistLevel); + size_t sds_buf_size = sds_hdr_len + ele_len + 1; /* header + data + null terminator */ + size_t total_size = node_size + sds_buf_size; + + /* Allocate single block for everything */ + zskiplistNode *zn = zmalloc_usable(total_size, &usable); + + /* Initialize node fields */ zn->score = score; - zn->ele = ele; - zslSetNodeLevel(zn, level); + zn->backward = NULL; + + /* Calculate offset from node start to sds data (after sds header) */ + size_t sds_offset = node_size + sds_hdr_len; + debugServerAssert(sds_offset < ZSL_OFFSET_MAX_ELE); + + /* Initialize embedded sds using sdsnewplacement */ + char *sds_buf = (char*)zn + node_size; + sds embedded_sds = sdsnewplacement(sds_buf, sds_buf_size, sds_type, ele, ele_len); + + /* Store node info in level[0].span */ + zslSetNodeInfo(zn, level, sds_offset); + + /* Verify that embedded_sds matches our calculated offset */ + serverAssert(embedded_sds == (sds)((char*)zn + sds_offset)); + + /* Update allocation size tracking */ zsl->alloc_size += usable; - if (ele) zsl->alloc_size += sdsAllocSize(ele); + return zn; } /* Create a new skiplist. */ zskiplist *zslCreate(void) { - int j; zskiplist *zsl; size_t zsl_size; @@ -113,25 +213,17 @@ zskiplist *zslCreate(void) { zsl->level = 1; zsl->length = 0; zsl->alloc_size = zsl_size; - zsl->header = zslCreateNode(zsl,ZSKIPLIST_MAXLEVEL,0,NULL); - for (j = 0; j < ZSKIPLIST_MAXLEVEL; j++) { - zsl->header->level[j].forward = NULL; - zsl->header->level[j].span = 0; - } + zsl->header = zslCreateHeaderNode(zsl); zsl->header->backward = NULL; zsl->tail = NULL; return zsl; } -/* Free the specified skiplist node. The referenced SDS string representation - * of the element is freed too, unless node->ele is set to NULL before calling - * this function. */ +/* Free the specified skiplist node. The embedded SDS is freed as part of + * the single allocation (node + level[] + embedded sds). */ static void zslFreeNode(zskiplist *zsl, zskiplistNode *node) { size_t usable; - if (node->ele) { - zsl->alloc_size -= sdsAllocSize(node->ele); - sdsfree(node->ele); - } + /* No separate sdsfree() needed - embedded sds is part of node allocation */ zfree_usable(node, &usable); zsl->alloc_size -= usable; } @@ -167,34 +259,32 @@ static int zslRandomLevel(void) { return (levelscore; + sds ele = zslGetNodeElement(node); + level = zslGetNodeInfo(node)->levels; serverAssert(!isnan(score)); + + /* Find the position where this node should be inserted */ x = zsl->header; for (i = zsl->level-1; i >= 0; i--) { /* store rank that is crossed to reach the insert position */ rank[i] = i == (zsl->level-1) ? 0 : rank[i+1]; - while (x->level[i].forward && - (x->level[i].forward->score < score || - (x->level[i].forward->score == score && - sdscmp(x->level[i].forward->ele,ele) < 0))) - { + while (zslCompareWithNode(score, ele, x->level[i].forward) > 0) { rank[i] += zslGetNodeSpanAtLevel(x, i); x = x->level[i].forward; } update[i] = x; } - /* we assume the element is not already inside, since we allow duplicated - * scores, reinserting the same element should never happen since the - * caller of zslInsert() should test in the hash table if the element is - * already inside or not. */ - level = zslRandomLevel(); + + /* Update skiplist level if needed */ if (level > zsl->level) { for (i = zsl->level; i < level; i++) { rank[i] = 0; @@ -202,14 +292,16 @@ zskiplistNode *zslInsert(zskiplist *zsl, double score, sds ele) { zslSetNodeSpanAtLevel(update[i], i, zsl->length); } zsl->level = level; + zslGetNodeInfo(zsl->header)->levels = level; } - x = zslCreateNode(zsl,level,score,ele); - for (i = 0; i < level; i++) { - x->level[i].forward = update[i]->level[i].forward; - update[i]->level[i].forward = x; - /* update span covered by update[i] as x is inserted here */ - zslSetNodeSpanAtLevel(x, i, zslGetNodeSpanAtLevel(update[i], i) - (rank[0] - rank[i])); + /* Insert the node at the found position */ + for (i = 0; i < level; i++) { + node->level[i].forward = update[i]->level[i].forward; + update[i]->level[i].forward = node; + + /* update span covered by update[i] as node is inserted here */ + zslSetNodeSpanAtLevel(node, i, zslGetNodeSpanAtLevel(update[i], i) - (rank[0] - rank[i])); zslSetNodeSpanAtLevel(update[i], i, (rank[0] - rank[i]) + 1); } @@ -218,18 +310,39 @@ zskiplistNode *zslInsert(zskiplist *zsl, double score, sds ele) { zslIncrNodeSpanAtLevel(update[i], i, 1); } - x->backward = (update[0] == zsl->header) ? NULL : update[0]; - if (x->level[0].forward) - x->level[0].forward->backward = x; + /* Update backward pointers */ + node->backward = (update[0] == zsl->header) ? NULL : update[0]; + if (node->level[0].forward) + node->level[0].forward->backward = node; else - zsl->tail = x; + zsl->tail = node; + zsl->length++; - return x; +} + +/* Insert a new node in the skiplist. Assumes the element does not already + * exist (up to the caller to enforce that). The element 'ele' is COPIED + * into the new node, so the caller retains ownership and can free it. */ +zskiplistNode *zslInsert(zskiplist *zsl, double score, sds ele) { + int level; + + serverAssert(!isnan(score)); + + /* we assume the element is not already inside, since we allow duplicated + * scores, reinserting the same element should never happen since the + * caller of zslInsert() should test in the hash table if the element is + * already inside or not. */ + level = zslRandomLevel(); + zskiplistNode *node = zslCreateNode(zsl, level, score, ele); + zslInsertNode(zsl, node); + return node; } /* Internal function used by zslDelete, zslDeleteRangeByScore and - * zslDeleteRangeByRank. */ -static void zslDeleteNode(zskiplist *zsl, zskiplistNode *x, zskiplistNode **update) { + * zslDeleteRangeByRank. + * This function only unlinks the node from the skiplist structure but does NOT free it. + * The caller is responsible for freeing the node with zslFreeNode(). */ +static void zslUnlinkNode(zskiplist *zsl, zskiplistNode *x, zskiplistNode **update) { int i; for (i = 0; i < zsl->level; i++) { if (update[i]->level[i].forward == x) { @@ -252,97 +365,68 @@ static void zslDeleteNode(zskiplist *zsl, zskiplistNode *x, zskiplistNode **upda zsl->length--; } -/* Delete an element with matching score/element from the skiplist. - * The function returns 1 if the node was found and deleted, otherwise - * 0 is returned. - * - * If 'node' is NULL the deleted node is freed by zslFreeNode(), otherwise - * it is not freed (but just unlinked) and *node is set to the node pointer, - * so that it is possible for the caller to reuse the node (including the - * referenced SDS string at node->ele). */ -int zslDelete(zskiplist *zsl, double score, sds ele, zskiplistNode **node) { +/* Delete the specified node from the skiplist. + * The node is unlinked from all levels and then freed by zslFreeNode(), + * which also frees the embedded SDS string. */ +static void zslDelete(zskiplist *zsl, zskiplistNode *node) { zskiplistNode *update[ZSKIPLIST_MAXLEVEL], *x; int i; + double score = node->score; + sds ele = zslGetNodeElement(node); x = zsl->header; for (i = zsl->level-1; i >= 0; i--) { - while (x->level[i].forward && - (x->level[i].forward->score < score || - (x->level[i].forward->score == score && - sdscmp(x->level[i].forward->ele,ele) < 0))) - { + while (zslCompareWithNode(score, ele, x->level[i].forward) > 0) { x = x->level[i].forward; } update[i] = x; } - /* We may have multiple elements with the same score, what we need - * is to find the element with both the right score and object. */ - x = x->level[0].forward; - if (x && score == x->score && sdscmp(x->ele,ele) == 0) { - zslDeleteNode(zsl, x, update); - if (!node) - zslFreeNode(zsl, x); - else - *node = x; - return 1; - } - return 0; /* not found */ + + /* Verify we truly found the node */ + serverAssert(x->level[0].forward == node); + + zslUnlinkNode(zsl, node, update); + zslFreeNode(zsl, node); } /* Update the score of an element inside the sorted set skiplist. - * Note that the element must exist and must match 'score'. - * This function does not update the score in the hash table side, the - * caller should take care of it. - * - * Note that this function attempts to just update the node, in case after - * the score update, the node would be exactly at the same position. - * Otherwise the skiplist is modified by removing and re-adding a new - * element, which is more costly. - * - * The function returns the updated element skiplist node pointer. */ -static zskiplistNode *zslUpdateScore(zskiplist *zsl, double curscore, sds ele, double newscore) { + * If the new score would keep the node in its current position, updates in-place and returns NULL. + * Otherwise, unlinks the node, updates score, reinserts at correct position, and returns node. + * Anyway, the node pointer stays the same (no dict update needed). */ +static void zslUpdateScore(zskiplist *zsl, zskiplistNode *node, double newscore) { + /* Fast path: if the node, after the score update, would be still exactly + * at the same position, we can just update the score without + * actually removing and re-inserting the element in the skiplist. */ + if ((node->backward == NULL || node->backward->score < newscore) && + (node->level[0].forward == NULL || node->level[0].forward->score > newscore)) + { + node->score = newscore; + return; + } + + /* Slow path: need to reposition the node. + * Find the update[] array for unlinking. */ zskiplistNode *update[ZSKIPLIST_MAXLEVEL], *x; int i; + double curscore = node->score; + sds ele = zslGetNodeElement(node); - /* We need to seek to element to update to start: this is useful anyway, - * we'll have to update or remove it. */ x = zsl->header; for (i = zsl->level-1; i >= 0; i--) { - while (x->level[i].forward && - (x->level[i].forward->score < curscore || - (x->level[i].forward->score == curscore && - sdscmp(x->level[i].forward->ele,ele) < 0))) - { + while (zslCompareWithNode(curscore, ele, x->level[i].forward) > 0) { x = x->level[i].forward; } update[i] = x; } - /* Jump to our element: note that this function assumes that the - * element with the matching score exists. */ - x = x->level[0].forward; - serverAssert(x && curscore == x->score && sdscmp(x->ele,ele) == 0); + /* Verify we found the right node */ + serverAssert(x->level[0].forward == node); - /* If the node, after the score update, would be still exactly - * at the same position, we can just update the score without - * actually removing and re-inserting the element in the skiplist. */ - if ((x->backward == NULL || x->backward->score < newscore) && - (x->level[0].forward == NULL || x->level[0].forward->score > newscore)) - { - x->score = newscore; - return x; - } - - /* No way to reuse the old node: we need to remove and insert a new - * one at a different place. */ - zslDeleteNode(zsl, x, update); - zskiplistNode *newnode = zslInsert(zsl,newscore,x->ele); - /* We reused the old node x->ele SDS string, free the node now - * since zslInsert created a new one. */ - if (x->ele) zsl->alloc_size -= sdsAllocSize(x->ele); - x->ele = NULL; - zslFreeNode(zsl, x); - return newnode; + /* Unlink, update score, and reinsert at new position. + * We reuse the same node to avoid dict updates. */ + zslUnlinkNode(zsl, node, update); + node->score = newscore; + zslInsertNode(zsl, node); } int zslValueGteMin(double value, zrangespec *spec) { @@ -486,8 +570,8 @@ static unsigned long zslDeleteRangeByScore(zskiplist *zsl, zrangespec *range, di /* Delete nodes while in range. */ while (x && zslValueLteMax(x->score, range)) { zskiplistNode *next = x->level[0].forward; - zslDeleteNode(zsl,x,update); - dictDelete(dict,x->ele); + zslUnlinkNode(zsl,x,update); + dictDelete(dict,zslGetNodeElement(x)); zslFreeNode(zsl, x); /* Here is where x->ele is actually released. */ removed++; x = next; @@ -504,7 +588,7 @@ static unsigned long zslDeleteRangeByLex(zskiplist *zsl, zlexrangespec *range, d x = zsl->header; for (i = zsl->level-1; i >= 0; i--) { while (x->level[i].forward && - !zslLexValueGteMin(x->level[i].forward->ele,range)) + !zslLexValueGteMin(zslGetNodeElement(x->level[i].forward),range)) x = x->level[i].forward; update[i] = x; } @@ -513,10 +597,10 @@ static unsigned long zslDeleteRangeByLex(zskiplist *zsl, zlexrangespec *range, d x = x->level[0].forward; /* Delete nodes while in range. */ - while (x && zslLexValueLteMax(x->ele,range)) { + while (x && zslLexValueLteMax(zslGetNodeElement(x),range)) { zskiplistNode *next = x->level[0].forward; - zslDeleteNode(zsl,x,update); - dictDelete(dict,x->ele); + zslUnlinkNode(zsl,x,update); + dictDelete(dict,zslGetNodeElement(x)); zslFreeNode(zsl, x); /* Here is where x->ele is actually released. */ removed++; x = next; @@ -544,8 +628,8 @@ static unsigned long zslDeleteRangeByRank(zskiplist *zsl, unsigned int start, un x = x->level[0].forward; while (x && traversed <= end) { zskiplistNode *next = x->level[0].forward; - zslDeleteNode(zsl,x,update); - dictDelete(dict,x->ele); + zslUnlinkNode(zsl,x,update); + dictDelete(dict,zslGetNodeElement(x)); zslFreeNode(zsl, x); removed++; traversed++; @@ -565,16 +649,12 @@ unsigned long zslGetRank(zskiplist *zsl, double score, sds ele) { x = zsl->header; for (i = zsl->level-1; i >= 0; i--) { - while (x->level[i].forward && - (x->level[i].forward->score < score || - (x->level[i].forward->score == score && - sdscmp(x->level[i].forward->ele,ele) <= 0))) { + while (zslCompareWithNode(score, ele, x->level[i].forward) >= 0) { rank += zslGetNodeSpanAtLevel(x, i); x = x->level[i].forward; } - /* x might be equal to zsl->header, so test if obj is non-NULL */ - if (x->ele && x->score == score && sdscmp(x->ele,ele) == 0) { + if (x != zsl->header && zslCompareWithNode(score, ele, x) == 0) { return rank; } } @@ -595,7 +675,7 @@ unsigned long zslGetRankByNode(zskiplist *zsl, zskiplistNode *x) { /* Walk forward from x to the end, using top level of each node for fast jumps */ while (x) { - level = zslGetNodeLevel(x) - 1; + level = zslGetNodeInfo(x)->levels - 1; distance_to_end += zslGetNodeSpanAtLevel(x, level); x = x->level[level].forward; } @@ -769,10 +849,10 @@ static int zslIsInLexRange(zskiplist *zsl, zlexrangespec *range) { if (cmp > 0 || (cmp == 0 && (range->minex || range->maxex))) return 0; x = zsl->tail; - if (x == NULL || !zslLexValueGteMin(x->ele,range)) + if ((x == NULL) || (!zslLexValueGteMin(zslGetNodeElement(x),range))) return 0; x = zsl->header->level[0].forward; - if (x == NULL || !zslLexValueLteMax(x->ele,range)) + if ((x == NULL) || (!zslLexValueLteMax(zslGetNodeElement(x),range))) return 0; return 1; } @@ -795,7 +875,7 @@ zskiplistNode *zslNthInLexRange(zskiplist *zsl, zlexrangespec *range, long n, un /* Go forward while *OUT* of range at level of zsl->level-1. */ x = zsl->header; i = zsl->level - 1; - while (x->level[i].forward && !zslLexValueGteMin(x->level[i].forward->ele, range)) { + while (x->level[i].forward && !zslLexValueGteMin(zslGetNodeElement(x->level[i].forward), range)) { edge_rank += zslGetNodeSpanAtLevel(x, i); x = x->level[i].forward; } @@ -806,7 +886,7 @@ zskiplistNode *zslNthInLexRange(zskiplist *zsl, zlexrangespec *range, long n, un if (n >= 0) { for (i = zsl->level - 2; i >= 0; i--) { /* Go forward while *OUT* of range. */ - while (x->level[i].forward && !zslLexValueGteMin(x->level[i].forward->ele, range)) { + while (x->level[i].forward && !zslLexValueGteMin(zslGetNodeElement(x->level[i].forward), range)) { /* Count the rank of the last element smaller than the range. */ edge_rank += zslGetNodeSpanAtLevel(x, i); x = x->level[i].forward; @@ -826,13 +906,13 @@ zskiplistNode *zslNthInLexRange(zskiplist *zsl, zlexrangespec *range, long n, un x = zslGetElementByRankFromNode(last_highest_level_node, zsl->level - 1, rank_diff); } /* Check if score <= max. */ - if (x && !zslLexValueLteMax(x->ele,range)) return NULL; + if (x && !zslLexValueLteMax(zslGetNodeElement(x),range)) return NULL; /* Store rank if requested. For n >= 0, the returned node is at rank edge_rank + n + 1. */ if (x && out_rank) *out_rank = edge_rank + n + 1; } else { for (i = zsl->level - 1; i >= 0; i--) { /* Go forward while *IN* range. */ - while (x->level[i].forward && zslLexValueLteMax(x->level[i].forward->ele, range)) { + while (x->level[i].forward && zslLexValueLteMax(zslGetNodeElement(x->level[i].forward), range)) { /* Count the rank of the last element in range. */ edge_rank += zslGetNodeSpanAtLevel(x, i); x = x->level[i].forward; @@ -852,7 +932,7 @@ zskiplistNode *zslNthInLexRange(zskiplist *zsl, zlexrangespec *range, long n, un x = zslGetElementByRankFromNode(last_highest_level_node, zsl->level - 1, rank_diff); } /* Check if score >= min. */ - if (x && !zslLexValueGteMin(x->ele, range)) return NULL; + if (x && !zslLexValueGteMin(zslGetNodeElement(x), range)) return NULL; /* Store rank if requested. For n < 0, the returned node is at rank edge_rank + n + 1. */ if (x && out_rank) *out_rank = edge_rank + n + 1; } @@ -1297,14 +1377,6 @@ static unsigned char *zzlDeleteRangeByRank(unsigned char *zl, unsigned int start * Common sorted set API *----------------------------------------------------------------------------*/ -/* Get the skiplist node from a dict entry. The dict value points to &node->score, - * so we use pointer arithmetic to get the node address. This enables O(1) lookup - * of the skiplist node from the hash table, used by optimized ZRANK. */ -static inline zskiplistNode *zsetGetSLNodeByEntry(dictEntry *de) { - char *score_ref = ((char *)dictGetVal(de)); - return (zskiplistNode *)(score_ref - offsetof(zskiplistNode, score)); -} - unsigned long zsetLength(const robj *zobj) { unsigned long length = 0; if (zobj->encoding == OBJ_ENCODING_LISTPACK) { @@ -1411,7 +1483,8 @@ void zsetConvertAndExpand(robj *zobj, int encoding, unsigned long cap) { ele = sdsnewlen((char*)vstr,vlen); node = zslInsert(zs->zsl,score,ele); - serverAssert(dictAdd(zs->dict,ele,&node->score) == DICT_OK); + serverAssert(dictAdd(zs->dict, node, NULL) == DICT_OK); + sdsfree(ele); /* zslInsert copied it, we can free our copy */ zzlNext(zl,&eptr,&sptr); } @@ -1432,7 +1505,7 @@ void zsetConvertAndExpand(robj *zobj, int encoding, unsigned long cap) { zfree(zs->zsl->header); while (node) { - zl = zzlInsertAt(zl,NULL,node->ele,node->score); + zl = zzlInsertAt(zl,NULL,zslGetNodeElement(node),node->score); next = node->level[0].forward; zslFreeNode(zs->zsl, node); node = next; @@ -1475,7 +1548,8 @@ int zsetScore(robj *zobj, sds member, double *score) { zset *zs = zobj->ptr; dictEntry *de = dictFind(zs->dict, member); if (de == NULL) return C_ERR; - *score = *(double*)dictGetVal(de); + zskiplistNode *znode = dictGetKey(de); + *score = znode->score; } else { serverPanic("Unknown sorted set encoding"); } @@ -1604,16 +1678,25 @@ int zsetAdd(robj *zobj, double score, sds ele, int in_flags, int *out_flags, dou zset *zs = zobj->ptr; zskiplistNode *znode; dictEntry *de; + dictEntryLink bucket, link; + + /* Use dictFindLink to find the element and get the bucket for potential insertion. + * This avoids a second lookup in dictAdd() if the element doesn't exist. */ + link = dictFindLink(zs->dict, ele, &bucket); + + if (link != NULL) { + /* Element exists - get the dictEntry from the link */ + de = *link; - de = dictFind(zs->dict,ele); - if (de != NULL) { /* NX? Return, same element already exists. */ if (nx) { *out_flags |= ZADD_OUT_NOP; return 1; } - curscore = *(double*)dictGetVal(de); + /* Get the node pointer from dict entry */ + znode = dictGetKey(de); + curscore = znode->score; /* Prepare the score for the increment if needed. */ if (incr) { @@ -1634,18 +1717,20 @@ int zsetAdd(robj *zobj, double score, sds ele, int in_flags, int *out_flags, dou /* Remove and re-insert when score changes. */ if (score != curscore) { - znode = zslUpdateScore(zs->zsl,curscore,ele,score); - /* Note that we did not removed the original element from - * the hash table representing the sorted set, so we just - * update the score. */ - dictSetVal(zs->dict, de, &znode->score); /* Update score ptr. */ + zslUpdateScore(zs->zsl, znode, score); + /* Note that we did not remove the original element from + * the hash table representing the sorted set, so we don't + * need to update the dict - the node pointer stays the same. */ *out_flags |= ZADD_OUT_UPDATED; } return 1; } else if (!xx) { - ele = sdsdup(ele); - znode = zslInsert(zs->zsl,score,ele); - serverAssert(dictAdd(zs->dict,ele,&znode->score) == DICT_OK); + /* Element doesn't exist - create node with embedded sds and add to skiplist */ + znode = zslInsert(zs->zsl, score, ele); + + /* Add node pointer to dict using the bucket we already found */ + dictSetKeyAtLink(zs->dict, znode, &bucket, 1); + *out_flags |= ZADD_OUT_ADDED; if (newscore) *newscore = score; return 1; @@ -1665,12 +1750,11 @@ int zsetAdd(robj *zobj, double score, sds ele, int in_flags, int *out_flags, dou * element. */ static int zsetRemoveFromSkiplist(zset *zs, sds ele) { dictEntry *de; - double score; de = dictUnlink(zs->dict,ele); if (de != NULL) { - /* Get the score in order to delete from the skiplist later. */ - score = *(double*)dictGetVal(de); + /* Get the node and score in order to delete from the skiplist later. */ + zskiplistNode *znode = dictGetKey(de); /* Delete from the hash table and later from the skiplist. * Note that the order is important: deleting from the skiplist @@ -1680,8 +1764,7 @@ static int zsetRemoveFromSkiplist(zset *zs, sds ele) { dictFreeUnlinkedEntry(zs->dict,de); /* Delete from skiplist. */ - int retval = zslDelete(zs->zsl,score,ele,NULL); - serverAssert(retval); + zslDelete(zs->zsl, znode); return 1; } @@ -1763,7 +1846,7 @@ long zsetRank(robj *zobj, sds ele, int reverse, double *output_score) { de = dictFind(zs->dict,ele); if (de != NULL) { - zskiplistNode *n = zsetGetSLNodeByEntry(de); + zskiplistNode *n = dictGetKey(de); rank = zslGetRankByNode(zsl, n); /* Existing elements always have a rank. */ serverAssert(rank != 0); @@ -1819,10 +1902,9 @@ robj *zsetDup(robj *o) { * O(1) instead of O(log(N)). */ ln = zsl->tail; while (llen--) { - ele = ln->ele; - sds new_ele = sdsdup(ele); - zskiplistNode *znode = zslInsert(new_zs->zsl,ln->score,new_ele); - dictAdd(new_zs->dict,new_ele,&znode->score); + ele = zslGetNodeElement(ln); + zskiplistNode *znode = zslInsert(new_zs->zsl,ln->score,ele); + dictAdd(new_zs->dict, znode, NULL); ln = ln->backward; } } else { @@ -1853,11 +1935,13 @@ void zsetTypeRandomElement(robj *zsetobj, unsigned long zsetsize, listpackEntry if (zsetobj->encoding == OBJ_ENCODING_SKIPLIST) { zset *zs = zsetobj->ptr; dictEntry *de = dictGetFairRandomKey(zs->dict); - sds s = dictGetKey(de); + zskiplistNode *znode = dictGetKey(de); + sds s = zslGetNodeElement(znode); key->sval = (unsigned char*)s; key->slen = sdslen(s); - if (score) - *score = *(double*)dictGetVal(de); + if (score) { + *score = znode->score; + } } else if (zsetobj->encoding == OBJ_ENCODING_LISTPACK) { listpackEntry val; lpRandomPair(zsetobj->ptr, zsetsize, key, &val, 2); @@ -2209,6 +2293,8 @@ void zremrangebylexCommand(client *c) { zremrangeGenericCommand(c,ZRANGE_LEX); } +/* Unified iterator source for set operations (ZUNION/ZINTER/ZDIFF). + * Provides polymorphic iteration over sets and sorted sets with different encodings. */ typedef struct { robj *subject; int type; /* Set, sorted set */ @@ -2441,7 +2527,7 @@ int zuiNext(zsetopsrc *op, zsetopval *val) { } else if (op->encoding == OBJ_ENCODING_SKIPLIST) { if (it->sl.node == NULL) return 0; - val->ele = it->sl.node->ele; + val->ele = zslGetNodeElement(it->sl.node); val->score = it->sl.node->score; /* Move to next element. (going backwards, see zuiInitIterator) */ @@ -2545,7 +2631,8 @@ int zuiFind(zsetopsrc *op, zsetopval *val, double *score) { zset *zs = op->subject->ptr; dictEntry *de; if ((de = dictFind(zs->dict,val->ele)) != NULL) { - *score = *(double*)dictGetVal(de); + zskiplistNode *znode = dictGetKey(de); + *score = znode->score; return 1; } else { return 0; @@ -2573,7 +2660,6 @@ static int zuiCompareByRevCardinality(const void *s1, const void *s2) { #define REDIS_AGGR_SUM 1 #define REDIS_AGGR_MIN 2 #define REDIS_AGGR_MAX 3 -#define zunionInterDictValue(_e) (dictGetVal(_e) == NULL ? 1.0 : *(double*)dictGetVal(_e)) inline static void zunionInterAggregate(double *target, double val, int aggregate) { if (aggregate == REDIS_AGGR_SUM) { @@ -2600,7 +2686,9 @@ static size_t zsetDictGetMaxElementLength(dict *d, size_t *totallen) { dictInitIterator(&di, d); while((de = dictNext(&di)) != NULL) { - sds ele = dictGetKey(de); + /* Extract sds from the node (key is zskiplistNode*) */ + zskiplistNode *znode = dictGetKey(de); + sds ele = zslGetNodeElement(znode); if (sdslen(ele) > maxelelen) maxelelen = sdslen(ele); if (totallen) (*totallen) += sdslen(ele); @@ -2657,9 +2745,10 @@ static void zdiffAlgorithm1(zsetopsrc *src, long setnum, zset *dstzset, size_t * if (!exists) { tmp = zuiNewSdsFromValue(&zval); znode = zslInsert(dstzset->zsl,zval.score,tmp); - dictAdd(dstzset->dict,tmp,&znode->score); + dictAdd(dstzset->dict, znode, NULL); if (sdslen(tmp) > *maxelelen) *maxelelen = sdslen(tmp); (*totelelen) += sdslen(tmp); + sdsfree(tmp); /* zslInsert copied it, we can free our copy */ } } zuiClearIterator(&src[0]); @@ -2697,8 +2786,9 @@ static void zdiffAlgorithm2(zsetopsrc *src, long setnum, zset *dstzset, size_t * if (j == 0) { tmp = zuiNewSdsFromValue(&zval); znode = zslInsert(dstzset->zsl,zval.score,tmp); - dictAdd(dstzset->dict,tmp,&znode->score); + dictAdd(dstzset->dict, znode, NULL); cardinality++; + sdsfree(tmp); /* zslInsert copied it, we can free our copy */ } else { dictPauseAutoResize(dstzset->dict); tmp = zuiSdsFromValue(&zval); @@ -2962,16 +3052,17 @@ void zunionInterDiffGenericCommand(client *c, robj *dstkey, int numkeysIndex, in } else if (j == setnum) { tmp = zuiNewSdsFromValue(&zval); znode = zslInsert(dstzset->zsl,score,tmp); - dictAdd(dstzset->dict,tmp,&znode->score); + dictAdd(dstzset->dict, znode, NULL); totelelen += sdslen(tmp); if (sdslen(tmp) > maxelelen) maxelelen = sdslen(tmp); + sdsfree(tmp); /* zslInsert copied it, we can free our copy */ } } zuiClearIterator(&src[0]); } } else if (op == SET_OP_UNION) { dictIterator di; - dictEntry *de, *existing; + dictEntry *de; double score; if (setnum) { @@ -2980,8 +3071,8 @@ void zunionInterDiffGenericCommand(client *c, robj *dstkey, int numkeysIndex, in dictExpand(dstzset->dict,zuiLength(&src[setnum-1])); } - /* Step 1: Create a dictionary of elements -> aggregated-scores - * by iterating one sorted set after the other. */ + /* Step 1: Iterate all sorted sets and aggregate scores. + * For each element, either insert into skiplist (new) or update score (existing). */ for (i = 0; i < setnum; i++) { if (zuiLength(&src[i]) == 0) continue; @@ -2991,41 +3082,42 @@ void zunionInterDiffGenericCommand(client *c, robj *dstkey, int numkeysIndex, in score = src[i].weight * zval.score; if (isnan(score)) score = 0; - /* Search for this element in the accumulating dictionary. */ - de = dictAddRaw(dstzset->dict,zuiSdsFromValue(&zval),&existing); - /* If we don't have it, we need to create a new entry. */ - if (!existing) { + /* Search for this element in the dict (which stores node pointers). */ + dictEntryLink bucket, link; + link = dictFindLink(dstzset->dict, zuiSdsFromValue(&zval), &bucket); + + if (link == NULL) { /* if not exists */ + /* New element: create node and insert into dict */ tmp = zuiNewSdsFromValue(&zval); /* Remember the longest single element encountered, * to understand if it's possible to convert to listpack * at the end. */ totelelen += sdslen(tmp); if (sdslen(tmp) > maxelelen) maxelelen = sdslen(tmp); - /* Update the element with its initial score. */ - dictSetKey(dstzset->dict, de, tmp); - dictSetDoubleVal(de,score); + + /* Create node with embedded sds and score */ + znode = zslCreateNode(dstzset->zsl, zslRandomLevel(), score, tmp); + /* Add node pointer to dict using the bucket we already found */ + dictSetKeyAtLink(dstzset->dict, znode, &bucket, 1); + sdsfree(tmp); /* zslCreateNode copied it, we can free our copy */ } else { - /* Update the score with the score of the new instance - * of the element found in the current sorted set. - * - * Here we access directly the dictEntry double - * value inside the union as it is a big speedup - * compared to using the getDouble/setDouble API. */ - double *existing_score_ptr = dictGetDoubleValPtr(existing); - zunionInterAggregate(existing_score_ptr, score, aggregate); + /* Existing element: aggregate score */ + de = *link; + znode = dictGetKey(de); + double newscore = znode->score; + zunionInterAggregate(&newscore, score, aggregate); + znode->score = newscore; } } zuiClearIterator(&src[i]); } - /* Step 2: convert the dictionary into the final sorted set. */ + /* Step 2: Done filling dict with nodes and updating scores. Now insert skiplist */ dictInitIterator(&di, dstzset->dict); while((de = dictNext(&di)) != NULL) { - sds ele = dictGetKey(de); - score = dictGetDoubleVal(de); - znode = zslInsert(dstzset->zsl,score,ele); - dictSetVal(dstzset->dict,de,&znode->score); + zskiplistNode *znode = dictGetKey(de); + zslInsertNode(dstzset->zsl, znode); } dictResetIterator(&di); } else if (op == SET_OP_DIFF) { @@ -3077,7 +3169,7 @@ void zunionInterDiffGenericCommand(client *c, robj *dstkey, int numkeysIndex, in while (zn != NULL) { if (withscores && c->resp > 2) addReplyArrayLen(c,2); - addReplyBulkCBuffer(c,zn->ele,sdslen(zn->ele)); + sds ele = zslGetNodeElement(zn); addReplyBulkCBuffer(c,ele,sdslen(ele)); if (withscores) addReplyDouble(c,zn->score); zn = zn->level[0].forward; } @@ -3389,7 +3481,7 @@ void genericZrangebyrankCommand(zrange_result_handler *handler, while(rangelen--) { serverAssertWithInfo(c,zobj,ln != NULL); - sds ele = ln->ele; + sds ele = zslGetNodeElement(ln); handler->emitResultFromCBuffer(handler, ele, sdslen(ele), ln->score); ln = reverse ? ln->backward : ln->level[0].forward; } @@ -3511,7 +3603,8 @@ void genericZrangebyscoreCommand(zrange_result_handler *handler, } rangelen++; - handler->emitResultFromCBuffer(handler, ln->ele, sdslen(ln->ele), ln->score); + sds ele = zslGetNodeElement(ln); + handler->emitResultFromCBuffer(handler, ele, sdslen(ele), ln->score); /* Move to next node */ if (reverse) { @@ -3778,13 +3871,14 @@ void genericZrangebylexCommand(zrange_result_handler *handler, while (ln && limit--) { /* Abort when the node is no longer in range. */ if (reverse) { - if (!zslLexValueGteMin(ln->ele,range)) break; + if (!zslLexValueGteMin(zslGetNodeElement(ln),range)) break; } else { - if (!zslLexValueLteMax(ln->ele,range)) break; + if (!zslLexValueLteMax(zslGetNodeElement(ln),range)) break; } rangelen++; - handler->emitResultFromCBuffer(handler, ln->ele, sdslen(ln->ele), ln->score); + sds ele = zslGetNodeElement(ln); + handler->emitResultFromCBuffer(handler, ele, sdslen(ele), ln->score); /* Move to next node */ if (reverse) { @@ -4218,7 +4312,7 @@ void genericZpopCommand(client *c, robj **keyv, int keyc, int where, int emitkey /* There must be an element in the sorted set. */ serverAssertWithInfo(c,zobj,zln != NULL); - ele = sdsdup(zln->ele); + ele = sdsdup(zslGetNodeElement(zln)); score = zln->score; } else { serverPanic("Unknown sorted set encoding"); @@ -4439,12 +4533,14 @@ void zrandmemberWithCountCommand(client *c, long l, int withscores) { zset *zs = zsetobj->ptr; while (count--) { dictEntry *de = dictGetFairRandomKey(zs->dict); - sds key = dictGetKey(de); + zskiplistNode *znode = dictGetKey(de); + sds key = zslGetNodeElement(znode); if (withscores && c->resp > 2) addReplyArrayLen(c,2); addReplyBulkCBuffer(c, key, sdslen(key)); - if (withscores) - addReplyDouble(c, *(double*)dictGetVal(de)); + if (withscores) { + addReplyDouble(c, znode->score); + } if (c->flags & CLIENT_CLOSE_ASAP) break; } @@ -4738,7 +4834,7 @@ static void zslDebugVerifyStruct(zskiplist *zsl) { /* Verify header node */ serverAssert(zsl->header != NULL); - serverAssert(zsl->header->ele == NULL); + serverAssert(zslGetNodeInfo(zsl->header)->sdsoffset == ZSL_OFFSET_NO_ELE); serverAssert(zsl->header->backward == NULL); /* Verify level is in valid range */ @@ -4766,17 +4862,17 @@ static void zslDebugVerifyStruct(zskiplist *zsl) { serverAssert(x->backward == prev); /* Verify node has valid element */ - serverAssert(x->ele != NULL); + serverAssert(zslGetNodeInfo(x)->sdsoffset != ZSL_OFFSET_NO_ELE); /* Verify node level is in valid range */ - unsigned long node_level = zslGetNodeLevel(x); + unsigned long node_level = zslGetNodeInfo(x)->levels; serverAssert(node_level >= 1 && node_level <= ZSKIPLIST_MAXLEVEL); /* Verify score ordering */ if (x->level[0].forward) { zskiplistNode *next = x->level[0].forward; serverAssert(next->score > x->score || - (next->score == x->score && sdscmp(next->ele, x->ele) > 0)); + (next->score == x->score && sdscmp(zslGetNodeElement(next), zslGetNodeElement(x)) > 0)); } /* Verify spans are correct for all levels this node participates in. @@ -4876,7 +4972,7 @@ int zsetTest(int argc, char **argv, int flags) { /* Store for later deletion - keep a copy of the element name */ elements[i].score = score; - elements[i].ele = sdsdup(ele); + elements[i].ele = ele; elements[i].node = node; /* Verify structure after each insertion */ @@ -4910,7 +5006,7 @@ int zsetTest(int argc, char **argv, int flags) { assert(rank >= 1 && rank <= (unsigned long)(i + 1)); /* Delete the element - zslDelete frees the node's SDS string */ - assert(zslDelete(zsl, score, ele, NULL)); + zslDelete(zsl, elements[i].node); /* Verify structure after each deletion */ zslDebugVerifyStruct(zsl);