diff --git a/modules/vector-sets/.gitignore b/modules/vector-sets/.gitignore new file mode 100644 index 000000000..c72b1b8e3 --- /dev/null +++ b/modules/vector-sets/.gitignore @@ -0,0 +1,11 @@ +__pycache__ +misc +*.so +*.xo +*.o +.DS_Store +w2v +word2vec.bin +TODO +*.txt +*.rdb diff --git a/modules/vector-sets/LICENSE b/modules/vector-sets/LICENSE new file mode 100644 index 000000000..79fb7e399 --- /dev/null +++ b/modules/vector-sets/LICENSE @@ -0,0 +1,2 @@ +This code is Copyright (c) 2024-Present, Redis Ltd. +All Rights Reserved. diff --git a/modules/vector-sets/Makefile b/modules/vector-sets/Makefile new file mode 100644 index 000000000..407ed08ce --- /dev/null +++ b/modules/vector-sets/Makefile @@ -0,0 +1,84 @@ +# Compiler settings +CC = cc + +ifdef SANITIZER +ifeq ($(SANITIZER),address) + SAN=-fsanitize=address +else +ifeq ($(SANITIZER),undefined) + SAN=-fsanitize=undefined +else +ifeq ($(SANITIZER),thread) + SAN=-fsanitize=thread +else + $(error "unknown sanitizer=${SANITIZER}") +endif +endif +endif +endif + +CFLAGS = -O2 -Wall -Wextra -g $(SAN) -std=c11 +LDFLAGS = -lm $(SAN) + +# Detect OS +uname_S := $(shell sh -c 'uname -s 2>/dev/null || echo not') +uname_M := $(shell sh -c 'uname -m 2>/dev/null || echo not') + +# Shared library compile flags for linux / osx +ifeq ($(uname_S),Linux) + SHOBJ_CFLAGS ?= -W -Wall -fno-common -g -ggdb -std=c11 -O2 + SHOBJ_LDFLAGS ?= -shared +ifneq (,$(findstring armv,$(uname_M))) + SHOBJ_LDFLAGS += -latomic +endif +ifneq (,$(findstring aarch64,$(uname_M))) + SHOBJ_LDFLAGS += -latomic +endif +else + SHOBJ_CFLAGS ?= -W -Wall -dynamic -fno-common -g -ggdb -std=c11 -O3 + SHOBJ_LDFLAGS ?= -bundle -undefined dynamic_lookup +endif + +# OS X 11.x doesn't have /usr/lib/libSystem.dylib and needs an explicit setting. +ifeq ($(uname_S),Darwin) +ifeq ("$(wildcard /usr/lib/libSystem.dylib)","") +LIBS = -L /Library/Developer/CommandLineTools/SDKs/MacOSX.sdk/usr/lib -lsystem +endif +endif + +.SUFFIXES: .c .so .xo .o + +all: vset.so + +.c.xo: + $(CC) -I. $(CFLAGS) $(SHOBJ_CFLAGS) -fPIC -c $< -o $@ + +vset.xo: redismodule.h expr.c + +vset.so: vset.xo hnsw.xo cJSON.xo + $(CC) -o $@ $^ $(SHOBJ_LDFLAGS) $(LIBS) $(SAN) -lc + +# Example sources / objects +SRCS = hnsw.c w2v.c +OBJS = $(SRCS:.c=.o) + +TARGET = w2v +MODULE = vset.so + +# Default target +all: $(TARGET) $(MODULE) + +# Example linking rule +$(TARGET): $(OBJS) + $(CC) $(OBJS) $(LDFLAGS) -o $(TARGET) + +# Compilation rule for object files +%.o: %.c + $(CC) $(CFLAGS) -c $< -o $@ + +# Clean rule +clean: + rm -f $(TARGET) $(OBJS) *.xo *.so + +# Declare phony targets +.PHONY: all clean diff --git a/modules/vector-sets/README.md b/modules/vector-sets/README.md new file mode 100644 index 000000000..1ed1cb2cf --- /dev/null +++ b/modules/vector-sets/README.md @@ -0,0 +1,633 @@ +This module implements Vector Sets for Redis, a new Redis data type similar +to Sorted Sets but having string elements associated to a vector instead of +a score. The fundamental goal of Vector Sets is to make possible adding items, +and later get a subset of the added items that are the most similar to a +specified vector (often a learned embedding), or the most similar to the vector +of an element that is already part of the Vector Set. + +Moreover, Vector sets implement optional filtered search capabilities: it is possible to associate attributes to all or to a subset of elements in the set, and then, using the `FILTER` option of the `VSIM` command, to ask for items similar to a given vector but also passing a filter specified as a simple mathematical expression (Like `".year > 1950"` or similar). This means that **you can have vector similarity and scalar filters at the same time**. + +## Installation + +Build with: + + make + +Then load the module with the following command line, or by inserting the needed directives in the `redis.conf` file. + + ./redis-server --loadmodule vset.so + +To run tests, I suggest using this: + + ./redis-server --save "" --enable-debug-command yes + +The execute the tests with: + + ./test.py + +## Reference of available commands + +**VADD: add items into a vector set** + + VADD key [REDUCE dim] FP32|VALUES vector element [CAS] [NOQUANT | Q8 | BIN] + [EF build-exploration-factor] [SETATTR ] [M ] + +Add a new element into the vector set specified by the key. +The vector can be provided as FP32 blob of values, or as floating point +numbers as strings, prefixed by the number of elements (3 in the example): + + VADD mykey VALUES 3 0.1 1.2 0.5 my-element + +Meaning of the options: + +`REDUCE` implements random projection, in order to reduce the +dimensionality of the vector. The projection matrix is saved and reloaded +along with the vector set. **Please note that** the `REDUCE` option must be passed immediately before the vector, like in `REDUCE 50 VALUES ...`. + +`CAS` performs the operation partially using threads, in a +check-and-set style. The neighbor candidates collection, which is slow, is +performed in the background, while the command is executed in the main thread. + +`NOQUANT` forces the vector to be created (in the first VADD call to a given key) without integer 8 quantization, which is otherwise the default. + +`BIN` forces the vector to use binary quantization instead of int8. This is much faster and uses less memory, but has impacts on the recall quality. + +`Q8` forces the vector to use signed 8 bit quantization. This is the default, and the option only exists in order to make sure to check at insertion time if the vector set is of the same format. + +`EF` plays a role in the effort made to find good candidates when connecting the new node to the existing HNSW graph. The default is 200. Using a larger value, may help to have a better recall. To improve the recall it is also possible to increase `EF` during `VSIM` searches. + +`SETATTR` associates attributes to the newly created entry or update the entry attributes (if it already exists). It is the same as calling the `VSETATTR` attribute separately, so please check the documentation of that command in the filtered search section of this documentation. + +`M` defaults to 16 and is the HNSW famous `M` parameters. It is the maximum number of connections that each node of the graph have with other nodes: more connections mean more memory, but a better ability to explore the graph. Nodes at layer zero (every node exists at least at layer zero) have `M*2` connections, while the other layers only have `M` connections. This means that, for instance, an `M` of 64 will use at least 1024 bytes of memory for each node! That is, `64 links * 2 times * 8 bytes pointers`, and even more, since on average each node has something like 1.33 layers (but the other layers have just `M` connections, instead of `M*2`). If you don't have a recall quality problem, the default is fine, and uses a limited amount of memory. + +**VSIM: return elements by vector similarity** + + VSIM key [ELE|FP32|VALUES] [WITHSCORES] [COUNT num] [EF search-exploration-factor] [FILTER expression] [FILTER-EF max-filtering-effort] [TRUTH] [NOTHREAD] + +The command returns similar vectors, for simplicity (and verbosity) in the following example, instead of providing a vector using FP32 or VALUES (like in `VADD`), we will ask for elements having a vector similar to a given element already in the sorted set: + + > VSIM word_embeddings ELE apple + 1) "apple" + 2) "apples" + 3) "pear" + 4) "fruit" + 5) "berry" + 6) "pears" + 7) "strawberry" + 8) "peach" + 9) "potato" + 10) "grape" + +It is possible to specify a `COUNT` and also to get the similarity score (from 1 to 0, where 1 is identical, 0 is opposite vector) between the query and the returned items. + + > VSIM word_embeddings ELE apple WITHSCORES COUNT 3 + 1) "apple" + 2) "0.9998867657923256" + 3) "apples" + 4) "0.8598527610301971" + 5) "pear" + 6) "0.8226882219314575" + +The `EF` argument is the exploration factor: the higher it is, the slower the command becomes, but the better the index is explored to find nodes that are near to our query. Sensible values are from 50 to 1000. + +The `TRUTH` option forces the command to perform a linear scan of all the entries inside the set, without using the graph search inside the HNSW, so it returns the best matching elements (the perfect result set) that can be used in order to easily calculate the recall. Of course the linear scan is `O(N)`, so it is much slower than the `log(N)` (considering a small `COUNT`) provided by the HNSW index. + +The `NOTHREAD` option forces the command to execute the search on the data structure in the main thread. Normally `VSIM` spawns a thread instead. This may be useful for benchmarking purposes, or when we work with extremely small vector sets and don't want to pay the cost of spawning a thread. It is possible that in the future this option will be automatically used by Redis when we detect small vector sets. Note that this option blocks the server for all the time needed to complete the command, so it is a source of potential latency issues: if you are in doubt, never use it. + +For `FILTER` and `FILTER-EF` options, please check the filtered search section of this documentation. + +**VDIM: return the dimension of the vectors inside the vector set** + + VDIM keyname + +Example: + + > VDIM word_embeddings + (integer) 300 + +Note that in the case of vectors that were populated using the `REDUCE` +option, for random projection, the vector set will report the size of +the projected (reduced) dimension. Yet the user should perform all the +queries using full-size vectors. + +**VCARD: return the number of elements in a vector set** + + VCARD key + +Example: + + > VCARD word_embeddings + (integer) 3000000 + + +**VREM: remove elements from vector set** + + VREM key element + +Example: + + > VADD vset VALUES 3 1 0 1 bar + (integer) 1 + > VREM vset bar + (integer) 1 + > VREM vset bar + (integer) 0 + +VREM does not perform thumstone / logical deletion, but will actually reclaim +the memory from the vector set, so it is save to add and remove elements +in a vector set in the context of long running applications that continuously +update the same index. + +**VEMB: return the approximated vector of an element** + + VEMB key element + +Example: + + > VEMB word_embeddings SQL + 1) "0.18208661675453186" + 2) "0.08535309880971909" + 3) "0.1365649551153183" + 4) "-0.16501599550247192" + 5) "0.14225517213344574" + ... 295 more elements ... + +Because vector sets perform insertion time normalization and optional +quantization, the returned vector could be approximated. `VEMB` will take +care to de-quantized and de-normalize the vector before returning it. + +It is possible to ask VEMB to return raw data, that is, the interal representation used by the vector: fp32, int8, or a bitmap for binary quantization. This behavior is triggered by the `RAW` option of of VEMB: + + VEMB word_embedding apple RAW + +In this case the return value of the command is an array of three or more elements: +1. The name of the quantization used, that is one of: "fp32", "bin", "q8". +2. The a string blob containing the raw data, 4 bytes fp32 floats for fp32, a bitmap for binary quants, or int8 bytes array for q8 quants. +3. A float representing the l2 of the vector before normalization. You need to multiply by this vector if you want to de-normalize the value for any reason. + +For q8 quantization, an additional elements is also returned: the quantization +range, so the integers from -127 to 127 represent (normalized) components +in the range `-range`, `+range`. + +**VLINKS: introspection command that shows neighbors for a node** + + VLINKS key element [WITHSCORES] + +The command reports the neighbors for each level. + +**VINFO: introspection command that shows info about a vector set** + + VINFO key + +Example: + + > VINFO word_embeddings + 1) quant-type + 2) int8 + 3) vector-dim + 4) (integer) 300 + 5) size + 6) (integer) 3000000 + 7) max-level + 8) (integer) 12 + 9) vset-uid + 10) (integer) 1 + 11) hnsw-max-node-uid + 12) (integer) 3000000 + +**VSETATTR: associate or remove the JSON attributes of elements** + + VSETATTR key element "{... json ...}" + +Each element of a vector set can be optionally associated with a JSON string +in order to use the `FILTER` option of `VSIM` to filter elements by scalars +(see the filtered search section for more information). This command can set, +update (if already set) or delete (if you set to an empty string) the +associated JSON attributes of an element. + +The command returns 0 if the element or the key don't exist, without +raising an error, otherwise 1 is returned, and the element attributes +are set or updated. + +**VGETATTR: retrieve the JSON attributes of elements** + + VGETATTR key element + +The command returns the JSON attribute associated with an element, or +null if there is no element associated, or no element at all, or no key. + +**VRANDMEMBER: return random members from a vector set** + + VRANDMEMBER key [count] + +Return one or more random elements from a vector set. + +The semantics of this command are similar to Redis's native SRANDMEMBER command: + +- When called without count, returns a single random element from the set, as a single string (no array reply). +- When called with a positive count, returns up to count distinct random elements (no duplicates). +- When called with a negative count, returns count random elements, potentially with duplicates. +- If the count value is larger than the set size (and positive), only the entire set is returned. + +If the key doesn't exist, returns a Null reply if count is not given, or an empty array if a count is provided. + +Examples: + + > VADD vset VALUES 3 1 0 0 elem1 + (integer) 1 + > VADD vset VALUES 3 0 1 0 elem2 + (integer) 1 + > VADD vset VALUES 3 0 0 1 elem3 + (integer) 1 + + # Return a single random element + > VRANDMEMBER vset + "elem2" + + # Return 2 distinct random elements + > VRANDMEMBER vset 2 + 1) "elem1" + 2) "elem3" + + # Return 3 random elements with possible duplicates + > VRANDMEMBER vset -3 + 1) "elem2" + 2) "elem2" + 3) "elem1" + + # Return more elements than in the set (returns all elements) + > VRANDMEMBER vset 10 + 1) "elem1" + 2) "elem2" + 3) "elem3" + + # When key doesn't exist + > VRANDMEMBER nonexistent + (nil) + > VRANDMEMBER nonexistent 3 + (empty array) + +This command is particularly useful for: + +1. Selecting random samples from a vector set for testing or training. +2. Performance testing by retrieving random elements for subsequent similarity searches. + +When the user asks for unique elements (positev count) the implementation optimizes for two scenarios: +- For small sample sizes (less than 20% of the set size), it uses a dictionary to avoid duplicates, and performs a real random walk inside the graph. +- For large sample sizes (more than 20% of the set size), it starts from a random node and sequentially traverses the internal list, providing faster performances but not really "random" elements. + +The command has `O(N)` worst-case time complexity when requesting many unique elements (it uses linear scanning), or `O(M*log(N))` complexity when the users asks for `M` random elements in a sorted set of `N` elements, with `M` much smaller than `N`. + +# Filtered search + +Each element of the vector set can be associated with a set of attributes specified as a JSON blob: + + > VADD vset VALUES 3 1 1 1 a SETATTR '{"year": 1950}' + (integer) 1 + > VADD vset VALUES 3 -1 -1 -1 b SETATTR '{"year": 1951}' + (integer) 1 + +Specifying an attribute with the `SETATTR` option of `VADD` is exactly equivalent to adding an element and then setting (or updating, if already set) the attributes JSON string. Also the symmetrical `VGETATTR` command returns the attribute associated to a given element. + + > VADD vset VALUES 3 0 1 0 c + (integer) 1 + > VSETATTR vset c '{"year": 1952}' + (integer) 1 + > VGETATTR vset c + "{\"year\": 1952}" + +At this point, I may use the FILTER option of VSIM to only ask for the subset of elements that are verified by my expression: + + > VSIM vset VALUES 3 0 0 0 FILTER '.year > 1950' + 1) "c" + 2) "b" + +The items will be returned again in order of similarity (most similar first), but only the items with the year field matching the expression is returned. + +The expressions are similar to what you would write inside the `if` statement of JavaScript or other familiar programming languages: you can use `and`, `or`, the obvious math operators like `+`, `-`, `/`, `>=`, `<`, ... and so forth (see the expressions section for more info). The selectors of the JSON object attributes start with a dot followed by the name of the key inside the JSON objects. + +Elements with invalid JSON or not having a given specified field **are considered as not matching** the expression, but will not generate any error at runtime. + +## FILTER expressions capabilities + +FILTER expressions allow you to perform complex filtering on vector similarity results using a JavaScript-like syntax. The expression is evaluated against each element's JSON attributes, with only elements that satisfy the expression being included in the results. + +### Expression Syntax + +Expressions support the following operators and capabilities: + +1. **Arithmetic operators**: `+`, `-`, `*`, `/`, `%` (modulo), `**` (exponentiation) +2. **Comparison operators**: `>`, `>=`, `<`, `<=`, `==`, `!=` +3. **Logical operators**: `and`/`&&`, `or`/`||`, `!`/`not` +4. **Containment operator**: `in` +5. **Parentheses** for grouping: `(...)` + +### Selector Notation + +Attributes are accessed using dot notation: + +- `.year` references the "year" attribute +- `.movie.year` would **NOT** reference the "year" field inside a "movie" object, only keys that are at the first level of the JSON object are accessible. + +### JSON and expressions data types + +Expressions can work with: + +- Numbers (dobule precision floats) +- Strings (enclosed in single or double quotes) +- Booleans (no native type: they are represented as 1 for true, 0 for false) +- Arrays (for use with the `in` operator: `value in [1, 2, 3]`) + +JSON attributes are converted in this way: + +- Numbers will be converted to numbers. +- Strings to strings. +- Booleans to 0 or 1 number. +- Arrays to tuples (for "in" operator), but only if composed of just numbers and strings. + +Any other type is ignored, and accessig it will make the expression evaluate to false. + +### Examples + +``` +# Find items from the 1980s +VSIM movies VALUES 3 0.5 0.8 0.2 FILTER '.year >= 1980 and .year < 1990' + +# Find action movies with high ratings +VSIM movies VALUES 3 0.5 0.8 0.2 FILTER '.genre == "action" and .rating > 8.0' + +# Find movies directed by either Spielberg or Nolan +VSIM movies VALUES 3 0.5 0.8 0.2 FILTER '.director in ["Spielberg", "Nolan"]' + +# Complex condition with numerical operations +VSIM movies VALUES 3 0.5 0.8 0.2 FILTER '(.year - 2000) ** 2 < 100 and .rating / 2 > 4' +``` + +### Error Handling + +Elements with any of the following conditions are considered not matching: +- Missing the queried JSON attribute +- Having invalid JSON in their attributes +- Having a JSON value that cannot be converted to the expected type + +This behavior allows you to safely filter on optional attributes without generating errors. + +### FILTER effort + +The `FILTER-EF` option controls the maximum effort spent when filtering vector search results. + +When performing vector similarity search with filtering, Vector Sets perform the standard similarity search as they apply the filter expression to each node. Since many results might be filtered out, Vector Sets may need to examine a lot more candidates than the requested `COUNT` to ensure sufficient matching results are returned. Actually, if the elements matching the filter are very rare or if there are less than elements matching than the specified count, this would trigger a full scan of the HNSW graph. + +For this reason, by default, the maximum effort is limited to a reasonable amount of nodes explored. + +### Modifying the FILTER effort + +1. By default, Vector Sets will explore up to `COUNT * 100` candidates to find matching results. +2. You can control this exploration with the `FILTER-EF` parameter. +3. A higher `FILTER-EF` value increases the chances of finding all relevant matches at the cost of increased processing time. +4. A `FILTER-EF` of zero will explore as many nodes as needed in order to actually return the number of elements specified by `COUNT`. +5. Even when a high `FILTER-EF` value is specified **the implementation will do a lot less work** if the elements passing the filter are very common, because of the early stop conditions of the HNSW implementation (once the specified amount of elements is reached and the quality check of the other candidates trigger an early stop). + +``` +VSIM key [ELE|FP32|VALUES] COUNT 10 FILTER '.year > 2000' FILTER-EF 500 +``` + +In this example, Vector Sets will examine up to 500 potential nodes. Of course if count is reached before exploring 500 nodes, and the quality checks show that it is not possible to make progresses on similarity, the search is ended sooner. + +### Performance Considerations + +- If you have highly selective filters (few items match), use a higher `FILTER-EF`, or just design your application in order to handle a result set that is smaller than the requested count. Note that anyway the additional elements may be too distant than the query vector. +- For less selective filters, the default should be sufficient. +- Very selective filters with low `FILTER-EF` values may return fewer items than requested. +- Extremely high values may impact performance without significantly improving results. + +The optimal `FILTER-EF` value depends on: +1. The selectivity of your filter. +2. The distribution of your data. +3. The required recall quality. + +A good practice is to start with the default and increase if needed when you observe fewer results than expected. + +### Testing a larg-ish data set + +To really see how things work at scale, you can [download](https://antirez.com/word2vec_with_attribs.rdb) the following dataset: + + wget https://antirez.com/word2vec_with_attribs.rdb + +It contains the 3 million words in Word2Vec having as attribute a JSON with just the length of the word. Because of the length distribution of words in large amounts of texts, where longer words become less and less common, this is ideal to check how filtering behaves with a filter verifying as true with less and less elements in a vector set. + +For instance: + + > VSIM word_embeddings_bin ele "pasta" FILTER ".len == 6" + 1) "pastas" + 2) "rotini" + 3) "gnocci" + 4) "panino" + 5) "salads" + 6) "breads" + 7) "salame" + 8) "sauces" + 9) "cheese" + 10) "fritti" + +This will easily retrieve the desired amount of items (`COUNT` is 10 by default) since there are many items of length 6. However: + + > VSIM word_embeddings_bin ele "pasta" FILTER ".len == 33" + 1) "skinless_boneless_chicken_breasts" + 2) "boneless_skinless_chicken_breasts" + 3) "Boneless_skinless_chicken_breasts" + +This time even if we asked for 10 items, we only get 3, since the default filter effort will be `10*100 = 1000`. We can tune this giving the effort in an explicit way, with the risk of our query being slower, of course: + + > VSIM word_embeddings_bin ele "pasta" FILTER ".len == 33" FILTER-EF 10000 + 1) "skinless_boneless_chicken_breasts" + 2) "boneless_skinless_chicken_breasts" + 3) "Boneless_skinless_chicken_breasts" + 4) "mozzarella_feta_provolone_cheddar" + 5) "Greatfood.com_R_www.greatfood.com" + 6) "Pepperidge_Farm_Goldfish_crackers" + 7) "Prosecuted_Mobsters_Rebuilt_Dying" + 8) "Crispy_Snacker_Sandwiches_Popcorn" + 9) "risultati_delle_partite_disputate" + 10) "Peppermint_Mocha_Twist_Gingersnap" + +This time we get all the ten items, even if the last one will be quite far from our query vector. We encourage to experiment with this test dataset in order to understand better the dynamics of the implementation and the natural tradeoffs of filtered search. + +**Keep in mind** that by default, Redis Vector Sets will try to avoid a likely very useless huge scan of the HNSW graph, and will be more happy to return few or no elements at all, since this is almost always what the user actually wants in the context of retrieving *similar* items to the query. + +# Single Instance Scalability and Latency + +Vector Sets implement a threading model that allows Redis to handle many concurrent requests: by default `VSIM` is always threaded, and `VADD` is not (but can be partially threaded using the `CAS` option). This section explains how the threading and locking mechanisms work, and what to expect in terms of performance. + +## Threading Model + +- The `VSIM` command runs in a separate thread by default, allowing Redis to continue serving other commands. +- A maximum of 32 threads can run concurrently (defined by `HNSW_MAX_THREADS`). +- When this limit is reached, additional `VSIM` requests are queued - Redis remains responsive, no latency event is generated. +- The `VADD` command with the `CAS` option also leverages threading for the computation-heavy candidate search phase, but the insertion itself is performed in the main thread. `VADD` always runs in a sub-millisecond time, so this is not a source of latency, but having too many hundreds of writes per second can be challenging to handle with a single instance. Please, look at the next section about multiple instances scalability. +- Commands run within Lua scripts, MULTI/EXEC blocks, or from replication are executed in the main thread to ensure consistency. + +``` +> VSIM vset VALUES 3 1 1 1 FILTER '.year > 2000' # This runs in a thread. +> VADD vset VALUES 3 1 1 1 element CAS # Candidate search runs in a thread. +``` + +## Locking Mechanism + +Vector Sets use a read/write locking mechanism to coordinate access: + +- Reads (`VSIM`, `VEMB`, etc.) acquire a read lock, allowing multiple concurrent reads. +- Writes (`VADD`, `VREM`, etc.) acquire a write lock, temporarily blocking all reads. +- When a write lock is requested while reads are in progress, the write operation waits for all reads to complete. +- Once a write lock is granted, all reads are blocked until the write completes. +- Each thread has a dedicated slot for tracking visited nodes during graph traversal, avoiding contention. This improves performances but limits the maximum number of concurrent threads, since each node has a memory cost proportional to the number of slots. + +## DEL latency + +Deleting a very large vector set (millions of elements) can cause latency spikes, as deletion rebuilds connections between nodes. This may change in the future. +The deletion latency is most noticeable when using `DEL` on a key containing a large vector set or when the key expires. + +## Performance Characteristics + +- Search operations (`VSIM`) scale almost linearly with the number of CPU cores available, up to the thread limit. You can expect a Vector Set composed of million of items associated with components of dimension 300, with the default int8 quantization, to deliver around 50k VSIM operations per second in a single host. +- Insertion operations (`VADD`) are more computationally expensive than searches, and can't be threaded: expect much lower throughput, in the range of a few thousands inserts per second. +- Binary quantization offers significantly faster search performance at the cost of some recall quality, while int8 quantization, the default, seems to have very small impacts on recall quality, while it significantly improves performances and space efficiency. +- The `EF` parameter has a major impact on both search quality and performance - higher values mean better recall but slower searches. +- Graph traversal time scales logarithmically with the number of elements, making Vector Sets efficient even with millions of vectors + +## Loading / Saving performances + +Vector Sets are able to serialize on disk the graph structure as it is in memory, so loading back the data does not need to rebuild the HNSW graph. This means that Redis can load millions of items per minute. For instance 3 million items with 300 components vectors can be loaded back into memory into around 15 seconds. + +# Scaling vector sets to multiple instances + +The fundamental way vector sets can be scaled to very large data sets +and to many Redis instances is that a given very large set of vectors +can be partitioned into N different Redis keys, that can also live into +different Redis instances. + +For instance, I could add my elements into `key0`, `key1`, `key2`, by hashing +the item in some way, like doing `crc32(item)%3`, effectively splitting +the dataset into three different parts. However once I want all the vectors +of my dataset near to a given query vector, I could simply perform the +`VSIM` command against all the three keys, merging the results by +score (so the commands must be called using the `WITHSCORES` option) on +the client side: once the union of the results are ordered by the +similarity score, the query is equivalent to having a single key `key1+2+3` +containing all the items. + +There are a few interesting facts to note about this pattern: + +1. It is possible to have a logical sorted set that is as big as the sum of all the Redis instances we are using. +2. Deletion operations remain simple, we can hash the key and select the key where our item belongs. +3. However, even if I use 10 different Redis instances, I'm not going to reach 10x the **read** operations per second, compared to using a single server: for each logical query, I need to query all the instances. Yet, smaller graphs are faster to navigate, so there is some win even from the point of view of CPU usage. +4. Insertions, so **write** queries, will be scaled linearly: I can add N items against N instances at the same time, splitting the insertion load evenly. This is very important since vector sets, being based on HNSW data structures, are slower to add items than to query similar items, by a very big factor. +5. While it cannot guarantee always the best results, with proper timeout management this system may be considered *highly available*, since if a subset of N instances are reachable, I'll be still be able to return similar items to my query vector. + +Notably, this pattern can be implemented in a way that avoids paying the sum of the round trip time with all the servers: it is possible to send the queries at the same time to all the instances, so that latency will be equal the slower reply out of of the N servers queries. + +# Optimizing memory usage + +Vector Sets, or better, HNSWs, the underlying data structure used by Vector Sets, combined with the features provided by the Vector Sets themselves (quantization, random projection, filtering, ...) form an implementation that has a non-trivial space of parameters that can be tuned. Despite to the complexity of the implementation and of vector similarity problems, here there is a list of simple ideas that can drive the user to pick the best settings: + +* 8 bit quantization (the default) is almost always a win. It reduces the memory usage of vectors by a factor of 4, yet the performance penality in terms of recall is minimal. It also reduces insertion and search time by around 2 times or more. +* Binary quantization is much more extreme: it makes vector sets a lot faster, but increases the recall error in a sensible way, for instance from 95% to 80% if all the parameters remain the same. Yet, the speedup is really big, and the memory usage of vectors, compaerd to full precision vectors, 32 times smaller. +* Vectors memory usage are not the only responsible for Vector Set high memory usage per entry: nodes contain, on average `M*2 + M*0.33` pointers, where M is by default 16 (but can be tuned in `VADD`, see the `M` option). Also each node has the string item and the optional JSON attributes: those should be as small as possible in order to avoid contributing more to the memory usage. +* The `M` parameter should be incresed to 32 or more only when a near perfect recall is really needed. +* It is possible to gain space (less memory usage) sacrificing time (more CPU time) by using a low `M` (the default of 16, for instance) and a high `EF` (the effort parameter of `VSIM`) in order to scan the graph more deeply. +* When memory usage is seriosu concern, and there is the suspect the vectors we are storing don't contain as much information - at least for our use case - to justify the number of components they feature, random projection (the `REDUCE` option of `VADD`) could be tested to see if dimensionality reduction is possible with acceptable precision loss. + +## Random projection tradeoffs + +Sometimes learned vectors are not as information dense as we could guess, that +is there are components having similar meanings in the space, and components +having values that don't really represent features that matter in our use case. + +At the same time, certain vectors are very big, 1024 components or more. In this cases, it is possible to use the random projection feature of Redis Vector Sets in order to reduce both space (less RAM used) and space (more operstions per second). The feature is accessible via the `REDUCE` option of the `VADD` command. However, keep in mind that you need to test how much reduction impacts the performances of your vectors in term of recall and quality of the results you get back. + +## What is a random projection? + +The concept of Random Projection is relatively simple to grasp. For instance, a projection that turns a 100 components vector into a 10 components vector will perform a different linear transformation between the 100 components and each of the target 10 components. Please note that *each of the target components* will get some random amount of all the 100 original components. It is mathematically proved that this process results in a vector space where elements still have similar distances among them, but still some information will get lost. + +## Examples of projections and loss of precision + +To show you a bit of a extreme case, let's take Word2Vec 3 million items and compress them from 300 to 100, 50 and 25 components vectors. Then, we check the recall compared to the ground truth against each of the vector sets produced in this way (using different `REDUCE` parameters of `VADD`). This is the result, obtained asking for the top 10 elements. + +``` +---------------------------------------------------------------------- +Key Average Recall % Std Dev +---------------------------------------------------------------------- +word_embeddings_int8 95.98 12.14 + ^ This is the same key used for ground truth, but without TRUTH option +word_embeddings_reduced_100 40.20 20.13 +word_embeddings_reduced_50 24.42 16.89 +word_embeddings_reduced_25 14.31 9.99 +``` + +Here the dimensionality reduction we are using is quite extreme: from 300 to 100 means that 66.6% of the original information is lost. The recall drops from 96% to 40%, down to 24% and 14% for even more extreme dimension reduction. + +Reducing the dimension of vectors that are already relatively small, like the above example, of 300 components, will provide only relatively small memory savings, especially because by default Vector Sets use `int8` quantization, that will use only one byte per component: + +``` +> MEMORY USAGE word_embeddings_int8 +(integer) 3107002888 +> MEMORY USAGE word_embeddings_reduced_100 +(integer) 2507122888 +``` + +Of course going, for example, from 2048 component vectors to 1024 would provide a much more sensible memory saving, even with the `int8` quantization used by Vector Sets, assuming the recall loss is acceptable. Other than the memory saving, there is also the reduction in CPU time, translating to more operations per second. + +Another thing to note is that, with certain embedding models, binary quantization (that offers a 8x reduction of memory usage compared to 8 bit quants, and a very big speedup in computation) performs much better than reducing the dimension of vectors of the same amount via random projections: + +``` +word_embeddings_bin 35.48 19.78 +``` + +Here in the same test did above: we have a 35% recall which is not too far than the 40% obtained with a random projection from 300 to 100 components. However, while the first technique reduces the size by 3 times, the size reduced of binary quantization is by 8 times. + +``` +> memory usage word_embeddings_bin +(integer) 2327002888 +``` + +In this specific case the key uses JSON attributes and has a graph connection overhead that is much bigger than the 300 bits each vector takes, but, as already said, for big vectors (1024 components, for instance) or for lower values of `M` (see `VADD`, the `M` parameter connects the level of connectivity, so it changes the amount of pointers used per node) the memory saving is much stronger. + +# Vector Sets troubleshooting and understandability + +## Debugging poor recall or unexpected results + +Vector graphs and similarity queries pose many challenges mainly due to the following three problems: + +1. The error due to the approximated nature of Vector Sets is hard to evaluate. +2. The error added by the quantization is often depends on the exact vector space (the embedding we are using **and** how far apart the elements we represent into such embeddings are). +3. We live in the illusion that learned embeddings capture the best similarity possible among elements, which is obviously not always true, and highly application dependent. + +The only way to debug such problems, is the ability to inspect step by step what is happening inside our application, and the structure of the HNSW graph itself. To do so, we suggest to consider the following tools: + +1. The `TRUTH` option of the `VSIM` command is able to return the ground truth of the most similar elements, without using the HNSW graph, but doing a linear scan. +2. The `VLINKS` command allows to explore the graph to see if the connections among nodes make sense, and to investigate why a given node may be more isolated than expected. Such command can also be used in a different way, when we want very fast "similar items" without paying the HNSW traversal time. It exploits the fact that we have a direct reference from each element in our vector set to each node in our HNSW graph. +3. The `WITHSCORES` option, in the supported commands, return a value that is directly related to the *cosine similarity* between the query and the items vectors, the interval of the similarity is simply rescaled from the -1, 1 original range to 0, 1, otherwise the metric is identical. + +## Clients, latency and bandwidth usage + +During Vector Sets testing, we discovered that often clients introduce considerable latecy and CPU usage (in the client side, not in Redis) for two main reasons: + +1. Often the serialization to `VALUES ... list of floats ...` can be very slow. +2. The vector payload of floats represented as strings is very large, resulting in high bandwidth usage and latency, compared to other Redis commands. + +Switching from `VALUES` to `FP32` as a method for transmitting vectors may easily provide 10-20x speedups. + +# Known bugs + +* Replication code is pretty much untested, and very vanilla (replicating the commands verbatim). + +# Implementation details + +Vector sets are based on the `hnsw.c` implementation of the HNSW data structure with extensions for speed and functionality. + +The main features are: + +* Proper nodes deletion with relinking. +* 8 bits and binary quantization. +* Threaded queries. +* Filtered search with predicate callback. diff --git a/modules/vector-sets/cJSON.c b/modules/vector-sets/cJSON.c new file mode 100644 index 000000000..bcbb3414f --- /dev/null +++ b/modules/vector-sets/cJSON.c @@ -0,0 +1,3164 @@ +/* + Copyright (c) 2009-2017 Dave Gamble and cJSON contributors + + Permission is hereby granted, free of charge, to any person obtaining a copy + of this software and associated documentation files (the "Software"), to deal + in the Software without restriction, including without limitation the rights + to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + copies of the Software, and to permit persons to whom the Software is + furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in + all copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + THE SOFTWARE. +*/ + +/* cJSON */ +/* JSON parser in C. */ + +/* disable warnings about old C89 functions in MSVC */ +#if !defined(_CRT_SECURE_NO_DEPRECATE) && defined(_MSC_VER) +#define _CRT_SECURE_NO_DEPRECATE +#endif + +#ifdef __GNUC__ +#pragma GCC visibility push(default) +#endif +#if defined(_MSC_VER) +#pragma warning (push) +/* disable warning about single line comments in system headers */ +#pragma warning (disable : 4001) +#endif + +#include +#include +#include +#include +#include +#include +#include + +#ifdef ENABLE_LOCALES +#include +#endif + +#if defined(_MSC_VER) +#pragma warning (pop) +#endif +#ifdef __GNUC__ +#pragma GCC visibility pop +#endif + +#include "cJSON.h" + +/* define our own boolean type */ +#ifdef true +#undef true +#endif +#define true ((cJSON_bool)1) + +#ifdef false +#undef false +#endif +#define false ((cJSON_bool)0) + +/* define isnan and isinf for ANSI C, if in C99 or above, isnan and isinf has been defined in math.h */ +#ifndef isinf +#define isinf(d) (isnan((d - d)) && !isnan(d)) +#endif +#ifndef isnan +#define isnan(d) (d != d) +#endif + +#ifndef NAN +#ifdef _WIN32 +#define NAN sqrt(-1.0) +#else +#define NAN 0.0/0.0 +#endif +#endif + +typedef struct { + const unsigned char *json; + size_t position; +} error; +static error global_error = { NULL, 0 }; + +CJSON_PUBLIC(const char *) cJSON_GetErrorPtr(void) +{ + return (const char*) (global_error.json + global_error.position); +} + +CJSON_PUBLIC(char *) cJSON_GetStringValue(const cJSON * const item) +{ + if (!cJSON_IsString(item)) + { + return NULL; + } + + return item->valuestring; +} + +CJSON_PUBLIC(double) cJSON_GetNumberValue(const cJSON * const item) +{ + if (!cJSON_IsNumber(item)) + { + return (double) NAN; + } + + return item->valuedouble; +} + +/* This is a safeguard to prevent copy-pasters from using incompatible C and header files */ +#if (CJSON_VERSION_MAJOR != 1) || (CJSON_VERSION_MINOR != 7) || (CJSON_VERSION_PATCH != 18) + #error cJSON.h and cJSON.c have different versions. Make sure that both have the same. +#endif + +CJSON_PUBLIC(const char*) cJSON_Version(void) +{ + static char version[15]; + snprintf(version, sizeof(version), "%i.%i.%i", CJSON_VERSION_MAJOR, CJSON_VERSION_MINOR, CJSON_VERSION_PATCH); + + return version; +} + +/* Case insensitive string comparison, doesn't consider two NULL pointers equal though */ +static int case_insensitive_strcmp(const unsigned char *string1, const unsigned char *string2) +{ + if ((string1 == NULL) || (string2 == NULL)) + { + return 1; + } + + if (string1 == string2) + { + return 0; + } + + for(; tolower(*string1) == tolower(*string2); (void)string1++, string2++) + { + if (*string1 == '\0') + { + return 0; + } + } + + return tolower(*string1) - tolower(*string2); +} + +typedef struct internal_hooks +{ + void *(CJSON_CDECL *allocate)(size_t size); + void (CJSON_CDECL *deallocate)(void *pointer); + void *(CJSON_CDECL *reallocate)(void *pointer, size_t size); +} internal_hooks; + +#if defined(_MSC_VER) +/* work around MSVC error C2322: '...' address of dllimport '...' is not static */ +static void * CJSON_CDECL internal_malloc(size_t size) +{ + return malloc(size); +} +static void CJSON_CDECL internal_free(void *pointer) +{ + free(pointer); +} +static void * CJSON_CDECL internal_realloc(void *pointer, size_t size) +{ + return realloc(pointer, size); +} +#else +#define internal_malloc malloc +#define internal_free free +#define internal_realloc realloc +#endif + +/* strlen of character literals resolved at compile time */ +#define static_strlen(string_literal) (sizeof(string_literal) - sizeof("")) + +static internal_hooks global_hooks = { internal_malloc, internal_free, internal_realloc }; + +static unsigned char* cJSON_strdup(const unsigned char* string, const internal_hooks * const hooks) +{ + size_t length = 0; + unsigned char *copy = NULL; + + if (string == NULL) + { + return NULL; + } + + length = strlen((const char*)string) + sizeof(""); + copy = (unsigned char*)hooks->allocate(length); + if (copy == NULL) + { + return NULL; + } + memcpy(copy, string, length); + + return copy; +} + +CJSON_PUBLIC(void) cJSON_InitHooks(cJSON_Hooks* hooks) +{ + if (hooks == NULL) + { + /* Reset hooks */ + global_hooks.allocate = malloc; + global_hooks.deallocate = free; + global_hooks.reallocate = realloc; + return; + } + + global_hooks.allocate = malloc; + if (hooks->malloc_fn != NULL) + { + global_hooks.allocate = hooks->malloc_fn; + } + + global_hooks.deallocate = free; + if (hooks->free_fn != NULL) + { + global_hooks.deallocate = hooks->free_fn; + } + + /* use realloc only if both free and malloc are used */ + global_hooks.reallocate = NULL; + if ((global_hooks.allocate == malloc) && (global_hooks.deallocate == free)) + { + global_hooks.reallocate = realloc; + } +} + +/* Internal constructor. */ +static cJSON *cJSON_New_Item(const internal_hooks * const hooks) +{ + cJSON* node = (cJSON*)hooks->allocate(sizeof(cJSON)); + if (node) + { + memset(node, '\0', sizeof(cJSON)); + } + + return node; +} + +/* Delete a cJSON structure. */ +CJSON_PUBLIC(void) cJSON_Delete(cJSON *item) +{ + cJSON *next = NULL; + while (item != NULL) + { + next = item->next; + if (!(item->type & cJSON_IsReference) && (item->child != NULL)) + { + cJSON_Delete(item->child); + } + if (!(item->type & cJSON_IsReference) && (item->valuestring != NULL)) + { + global_hooks.deallocate(item->valuestring); + item->valuestring = NULL; + } + if (!(item->type & cJSON_StringIsConst) && (item->string != NULL)) + { + global_hooks.deallocate(item->string); + item->string = NULL; + } + global_hooks.deallocate(item); + item = next; + } +} + +/* get the decimal point character of the current locale */ +static unsigned char get_decimal_point(void) +{ +#ifdef ENABLE_LOCALES + struct lconv *lconv = localeconv(); + return (unsigned char) lconv->decimal_point[0]; +#else + return '.'; +#endif +} + +typedef struct +{ + const unsigned char *content; + size_t length; + size_t offset; + size_t depth; /* How deeply nested (in arrays/objects) is the input at the current offset. */ + internal_hooks hooks; +} parse_buffer; + +/* check if the given size is left to read in a given parse buffer (starting with 1) */ +#define can_read(buffer, size) ((buffer != NULL) && (((buffer)->offset + size) <= (buffer)->length)) +/* check if the buffer can be accessed at the given index (starting with 0) */ +#define can_access_at_index(buffer, index) ((buffer != NULL) && (((buffer)->offset + index) < (buffer)->length)) +#define cannot_access_at_index(buffer, index) (!can_access_at_index(buffer, index)) +/* get a pointer to the buffer at the position */ +#define buffer_at_offset(buffer) ((buffer)->content + (buffer)->offset) + +/* Parse the input text to generate a number, and populate the result into item. */ +static cJSON_bool parse_number(cJSON * const item, parse_buffer * const input_buffer) +{ + double number = 0; + unsigned char *after_end = NULL; + unsigned char number_c_string[64]; + unsigned char decimal_point = get_decimal_point(); + size_t i = 0; + + if ((input_buffer == NULL) || (input_buffer->content == NULL)) + { + return false; + } + + /* copy the number into a temporary buffer and replace '.' with the decimal point + * of the current locale (for strtod) + * This also takes care of '\0' not necessarily being available for marking the end of the input */ + for (i = 0; (i < (sizeof(number_c_string) - 1)) && can_access_at_index(input_buffer, i); i++) + { + switch (buffer_at_offset(input_buffer)[i]) + { + case '0': + case '1': + case '2': + case '3': + case '4': + case '5': + case '6': + case '7': + case '8': + case '9': + case '+': + case '-': + case 'e': + case 'E': + number_c_string[i] = buffer_at_offset(input_buffer)[i]; + break; + + case '.': + number_c_string[i] = decimal_point; + break; + + default: + goto loop_end; + } + } +loop_end: + number_c_string[i] = '\0'; + + number = strtod((const char*)number_c_string, (char**)&after_end); + if (number_c_string == after_end) + { + return false; /* parse_error */ + } + + item->valuedouble = number; + + /* use saturation in case of overflow */ + if (number >= INT_MAX) + { + item->valueint = INT_MAX; + } + else if (number <= (double)INT_MIN) + { + item->valueint = INT_MIN; + } + else + { + item->valueint = (int)number; + } + + item->type = cJSON_Number; + + input_buffer->offset += (size_t)(after_end - number_c_string); + return true; +} + +/* don't ask me, but the original cJSON_SetNumberValue returns an integer or double */ +CJSON_PUBLIC(double) cJSON_SetNumberHelper(cJSON *object, double number) +{ + if (number >= INT_MAX) + { + object->valueint = INT_MAX; + } + else if (number <= (double)INT_MIN) + { + object->valueint = INT_MIN; + } + else + { + object->valueint = (int)number; + } + + return object->valuedouble = number; +} + +/* Note: when passing a NULL valuestring, cJSON_SetValuestring treats this as an error and return NULL */ +CJSON_PUBLIC(char*) cJSON_SetValuestring(cJSON *object, const char *valuestring) +{ + char *copy = NULL; + size_t v1_len; + size_t v2_len; + /* if object's type is not cJSON_String or is cJSON_IsReference, it should not set valuestring */ + if ((object == NULL) || !(object->type & cJSON_String) || (object->type & cJSON_IsReference)) + { + return NULL; + } + /* return NULL if the object is corrupted or valuestring is NULL */ + if (object->valuestring == NULL || valuestring == NULL) + { + return NULL; + } + + v1_len = strlen(valuestring); + v2_len = strlen(object->valuestring); + + if (v1_len <= v2_len) + { + /* strcpy does not handle overlapping string: [X1, X2] [Y1, Y2] => X2 < Y1 or Y2 < X1 */ + if (!( valuestring + v1_len < object->valuestring || object->valuestring + v2_len < valuestring )) + { + return NULL; + } + strcpy(object->valuestring, valuestring); + return object->valuestring; + } + copy = (char*) cJSON_strdup((const unsigned char*)valuestring, &global_hooks); + if (copy == NULL) + { + return NULL; + } + if (object->valuestring != NULL) + { + cJSON_free(object->valuestring); + } + object->valuestring = copy; + + return copy; +} + +typedef struct +{ + unsigned char *buffer; + size_t length; + size_t offset; + size_t depth; /* current nesting depth (for formatted printing) */ + cJSON_bool noalloc; + cJSON_bool format; /* is this print a formatted print */ + internal_hooks hooks; +} printbuffer; + +/* realloc printbuffer if necessary to have at least "needed" bytes more */ +static unsigned char* ensure(printbuffer * const p, size_t needed) +{ + unsigned char *newbuffer = NULL; + size_t newsize = 0; + + if ((p == NULL) || (p->buffer == NULL)) + { + return NULL; + } + + if ((p->length > 0) && (p->offset >= p->length)) + { + /* make sure that offset is valid */ + return NULL; + } + + if (needed > INT_MAX) + { + /* sizes bigger than INT_MAX are currently not supported */ + return NULL; + } + + needed += p->offset + 1; + if (needed <= p->length) + { + return p->buffer + p->offset; + } + + if (p->noalloc) { + return NULL; + } + + /* calculate new buffer size */ + if (needed > (INT_MAX / 2)) + { + /* overflow of int, use INT_MAX if possible */ + if (needed <= INT_MAX) + { + newsize = INT_MAX; + } + else + { + return NULL; + } + } + else + { + newsize = needed * 2; + } + + if (p->hooks.reallocate != NULL) + { + /* reallocate with realloc if available */ + newbuffer = (unsigned char*)p->hooks.reallocate(p->buffer, newsize); + if (newbuffer == NULL) + { + p->hooks.deallocate(p->buffer); + p->length = 0; + p->buffer = NULL; + + return NULL; + } + } + else + { + /* otherwise reallocate manually */ + newbuffer = (unsigned char*)p->hooks.allocate(newsize); + if (!newbuffer) + { + p->hooks.deallocate(p->buffer); + p->length = 0; + p->buffer = NULL; + + return NULL; + } + + memcpy(newbuffer, p->buffer, p->offset + 1); + p->hooks.deallocate(p->buffer); + } + p->length = newsize; + p->buffer = newbuffer; + + return newbuffer + p->offset; +} + +/* calculate the new length of the string in a printbuffer and update the offset */ +static void update_offset(printbuffer * const buffer) +{ + const unsigned char *buffer_pointer = NULL; + if ((buffer == NULL) || (buffer->buffer == NULL)) + { + return; + } + buffer_pointer = buffer->buffer + buffer->offset; + + buffer->offset += strlen((const char*)buffer_pointer); +} + +/* securely comparison of floating-point variables */ +static cJSON_bool compare_double(double a, double b) +{ + double maxVal = fabs(a) > fabs(b) ? fabs(a) : fabs(b); + return (fabs(a - b) <= maxVal * DBL_EPSILON); +} + +/* Render the number nicely from the given item into a string. */ +static cJSON_bool print_number(const cJSON * const item, printbuffer * const output_buffer) +{ + unsigned char *output_pointer = NULL; + double d = item->valuedouble; + int length = 0; + size_t i = 0; + unsigned char number_buffer[26] = {0}; /* temporary buffer to print the number into */ + unsigned char decimal_point = get_decimal_point(); + double test = 0.0; + + if (output_buffer == NULL) + { + return false; + } + + /* This checks for NaN and Infinity */ + if (isnan(d) || isinf(d)) + { + length = snprintf((char*)number_buffer, sizeof(number_buffer), "null"); + } + else if(d == (double)item->valueint) + { + length = snprintf((char*)number_buffer, sizeof(number_buffer), "%d", item->valueint); + } + else + { + /* Try 15 decimal places of precision to avoid nonsignificant nonzero digits */ + length = snprintf((char*)number_buffer, sizeof(number_buffer), "%1.15g", d); + + /* Check whether the original double can be recovered */ + if ((sscanf((char*)number_buffer, "%lg", &test) != 1) || !compare_double((double)test, d)) + { + /* If not, print with 17 decimal places of precision */ + length = snprintf((char*)number_buffer, sizeof(number_buffer), "%1.17g", d); + } + } + + /* snprintf failed or buffer overrun occurred */ + if ((length < 0) || (length > (int)(sizeof(number_buffer) - 1))) + { + return false; + } + + /* reserve appropriate space in the output */ + output_pointer = ensure(output_buffer, (size_t)length + sizeof("")); + if (output_pointer == NULL) + { + return false; + } + + /* copy the printed number to the output and replace locale + * dependent decimal point with '.' */ + for (i = 0; i < ((size_t)length); i++) + { + if (number_buffer[i] == decimal_point) + { + output_pointer[i] = '.'; + continue; + } + + output_pointer[i] = number_buffer[i]; + } + output_pointer[i] = '\0'; + + output_buffer->offset += (size_t)length; + + return true; +} + +/* parse 4 digit hexadecimal number */ +static unsigned parse_hex4(const unsigned char * const input) +{ + unsigned int h = 0; + size_t i = 0; + + for (i = 0; i < 4; i++) + { + /* parse digit */ + if ((input[i] >= '0') && (input[i] <= '9')) + { + h += (unsigned int) input[i] - '0'; + } + else if ((input[i] >= 'A') && (input[i] <= 'F')) + { + h += (unsigned int) 10 + input[i] - 'A'; + } + else if ((input[i] >= 'a') && (input[i] <= 'f')) + { + h += (unsigned int) 10 + input[i] - 'a'; + } + else /* invalid */ + { + return 0; + } + + if (i < 3) + { + /* shift left to make place for the next nibble */ + h = h << 4; + } + } + + return h; +} + +/* converts a UTF-16 literal to UTF-8 + * A literal can be one or two sequences of the form \uXXXX */ +static unsigned char utf16_literal_to_utf8(const unsigned char * const input_pointer, const unsigned char * const input_end, unsigned char **output_pointer) +{ + long unsigned int codepoint = 0; + unsigned int first_code = 0; + const unsigned char *first_sequence = input_pointer; + unsigned char utf8_length = 0; + unsigned char utf8_position = 0; + unsigned char sequence_length = 0; + unsigned char first_byte_mark = 0; + + if ((input_end - first_sequence) < 6) + { + /* input ends unexpectedly */ + goto fail; + } + + /* get the first utf16 sequence */ + first_code = parse_hex4(first_sequence + 2); + + /* check that the code is valid */ + if (((first_code >= 0xDC00) && (first_code <= 0xDFFF))) + { + goto fail; + } + + /* UTF16 surrogate pair */ + if ((first_code >= 0xD800) && (first_code <= 0xDBFF)) + { + const unsigned char *second_sequence = first_sequence + 6; + unsigned int second_code = 0; + sequence_length = 12; /* \uXXXX\uXXXX */ + + if ((input_end - second_sequence) < 6) + { + /* input ends unexpectedly */ + goto fail; + } + + if ((second_sequence[0] != '\\') || (second_sequence[1] != 'u')) + { + /* missing second half of the surrogate pair */ + goto fail; + } + + /* get the second utf16 sequence */ + second_code = parse_hex4(second_sequence + 2); + /* check that the code is valid */ + if ((second_code < 0xDC00) || (second_code > 0xDFFF)) + { + /* invalid second half of the surrogate pair */ + goto fail; + } + + + /* calculate the unicode codepoint from the surrogate pair */ + codepoint = 0x10000 + (((first_code & 0x3FF) << 10) | (second_code & 0x3FF)); + } + else + { + sequence_length = 6; /* \uXXXX */ + codepoint = first_code; + } + + /* encode as UTF-8 + * takes at maximum 4 bytes to encode: + * 11110xxx 10xxxxxx 10xxxxxx 10xxxxxx */ + if (codepoint < 0x80) + { + /* normal ascii, encoding 0xxxxxxx */ + utf8_length = 1; + } + else if (codepoint < 0x800) + { + /* two bytes, encoding 110xxxxx 10xxxxxx */ + utf8_length = 2; + first_byte_mark = 0xC0; /* 11000000 */ + } + else if (codepoint < 0x10000) + { + /* three bytes, encoding 1110xxxx 10xxxxxx 10xxxxxx */ + utf8_length = 3; + first_byte_mark = 0xE0; /* 11100000 */ + } + else if (codepoint <= 0x10FFFF) + { + /* four bytes, encoding 1110xxxx 10xxxxxx 10xxxxxx 10xxxxxx */ + utf8_length = 4; + first_byte_mark = 0xF0; /* 11110000 */ + } + else + { + /* invalid unicode codepoint */ + goto fail; + } + + /* encode as utf8 */ + for (utf8_position = (unsigned char)(utf8_length - 1); utf8_position > 0; utf8_position--) + { + /* 10xxxxxx */ + (*output_pointer)[utf8_position] = (unsigned char)((codepoint | 0x80) & 0xBF); + codepoint >>= 6; + } + /* encode first byte */ + if (utf8_length > 1) + { + (*output_pointer)[0] = (unsigned char)((codepoint | first_byte_mark) & 0xFF); + } + else + { + (*output_pointer)[0] = (unsigned char)(codepoint & 0x7F); + } + + *output_pointer += utf8_length; + + return sequence_length; + +fail: + return 0; +} + +/* Parse the input text into an unescaped cinput, and populate item. */ +static cJSON_bool parse_string(cJSON * const item, parse_buffer * const input_buffer) +{ + const unsigned char *input_pointer = buffer_at_offset(input_buffer) + 1; + const unsigned char *input_end = buffer_at_offset(input_buffer) + 1; + unsigned char *output_pointer = NULL; + unsigned char *output = NULL; + + /* not a string */ + if (buffer_at_offset(input_buffer)[0] != '\"') + { + goto fail; + } + + { + /* calculate approximate size of the output (overestimate) */ + size_t allocation_length = 0; + size_t skipped_bytes = 0; + while (((size_t)(input_end - input_buffer->content) < input_buffer->length) && (*input_end != '\"')) + { + /* is escape sequence */ + if (input_end[0] == '\\') + { + if ((size_t)(input_end + 1 - input_buffer->content) >= input_buffer->length) + { + /* prevent buffer overflow when last input character is a backslash */ + goto fail; + } + skipped_bytes++; + input_end++; + } + input_end++; + } + if (((size_t)(input_end - input_buffer->content) >= input_buffer->length) || (*input_end != '\"')) + { + goto fail; /* string ended unexpectedly */ + } + + /* This is at most how much we need for the output */ + allocation_length = (size_t) (input_end - buffer_at_offset(input_buffer)) - skipped_bytes; + output = (unsigned char*)input_buffer->hooks.allocate(allocation_length + sizeof("")); + if (output == NULL) + { + goto fail; /* allocation failure */ + } + } + + output_pointer = output; + /* loop through the string literal */ + while (input_pointer < input_end) + { + if (*input_pointer != '\\') + { + *output_pointer++ = *input_pointer++; + } + /* escape sequence */ + else + { + unsigned char sequence_length = 2; + if ((input_end - input_pointer) < 1) + { + goto fail; + } + + switch (input_pointer[1]) + { + case 'b': + *output_pointer++ = '\b'; + break; + case 'f': + *output_pointer++ = '\f'; + break; + case 'n': + *output_pointer++ = '\n'; + break; + case 'r': + *output_pointer++ = '\r'; + break; + case 't': + *output_pointer++ = '\t'; + break; + case '\"': + case '\\': + case '/': + *output_pointer++ = input_pointer[1]; + break; + + /* UTF-16 literal */ + case 'u': + sequence_length = utf16_literal_to_utf8(input_pointer, input_end, &output_pointer); + if (sequence_length == 0) + { + /* failed to convert UTF16-literal to UTF-8 */ + goto fail; + } + break; + + default: + goto fail; + } + input_pointer += sequence_length; + } + } + + /* zero terminate the output */ + *output_pointer = '\0'; + + item->type = cJSON_String; + item->valuestring = (char*)output; + + input_buffer->offset = (size_t) (input_end - input_buffer->content); + input_buffer->offset++; + + return true; + +fail: + if (output != NULL) + { + input_buffer->hooks.deallocate(output); + output = NULL; + } + + if (input_pointer != NULL) + { + input_buffer->offset = (size_t)(input_pointer - input_buffer->content); + } + + return false; +} + +/* Render the cstring provided to an escaped version that can be printed. */ +static cJSON_bool print_string_ptr(const unsigned char * const input, printbuffer * const output_buffer) +{ + const unsigned char *input_pointer = NULL; + unsigned char *output = NULL; + unsigned char *output_pointer = NULL; + size_t output_length = 0; + /* numbers of additional characters needed for escaping */ + size_t escape_characters = 0; + + if (output_buffer == NULL) + { + return false; + } + + /* empty string */ + if (input == NULL) + { + output = ensure(output_buffer, sizeof("\"\"")); + if (output == NULL) + { + return false; + } + strcpy((char*)output, "\"\""); + + return true; + } + + /* set "flag" to 1 if something needs to be escaped */ + for (input_pointer = input; *input_pointer; input_pointer++) + { + switch (*input_pointer) + { + case '\"': + case '\\': + case '\b': + case '\f': + case '\n': + case '\r': + case '\t': + /* one character escape sequence */ + escape_characters++; + break; + default: + if (*input_pointer < 32) + { + /* UTF-16 escape sequence uXXXX */ + escape_characters += 5; + } + break; + } + } + output_length = (size_t)(input_pointer - input) + escape_characters; + + output = ensure(output_buffer, output_length + sizeof("\"\"")); + if (output == NULL) + { + return false; + } + + /* no characters have to be escaped */ + if (escape_characters == 0) + { + output[0] = '\"'; + memcpy(output + 1, input, output_length); + output[output_length + 1] = '\"'; + output[output_length + 2] = '\0'; + + return true; + } + + output[0] = '\"'; + output_pointer = output + 1; + /* copy the string */ + for (input_pointer = input; *input_pointer != '\0'; (void)input_pointer++, output_pointer++) + { + if ((*input_pointer > 31) && (*input_pointer != '\"') && (*input_pointer != '\\')) + { + /* normal character, copy */ + *output_pointer = *input_pointer; + } + else + { + /* character needs to be escaped */ + *output_pointer++ = '\\'; + switch (*input_pointer) + { + case '\\': + *output_pointer = '\\'; + break; + case '\"': + *output_pointer = '\"'; + break; + case '\b': + *output_pointer = 'b'; + break; + case '\f': + *output_pointer = 'f'; + break; + case '\n': + *output_pointer = 'n'; + break; + case '\r': + *output_pointer = 'r'; + break; + case '\t': + *output_pointer = 't'; + break; + default: + /* escape and print as unicode codepoint */ + snprintf((char*)output_pointer, sizeof(output_pointer), "u%04x", *input_pointer); + output_pointer += 4; + break; + } + } + } + output[output_length + 1] = '\"'; + output[output_length + 2] = '\0'; + + return true; +} + +/* Invoke print_string_ptr (which is useful) on an item. */ +static cJSON_bool print_string(const cJSON * const item, printbuffer * const p) +{ + return print_string_ptr((unsigned char*)item->valuestring, p); +} + +/* Predeclare these prototypes. */ +static cJSON_bool parse_value(cJSON * const item, parse_buffer * const input_buffer); +static cJSON_bool print_value(const cJSON * const item, printbuffer * const output_buffer); +static cJSON_bool parse_array(cJSON * const item, parse_buffer * const input_buffer); +static cJSON_bool print_array(const cJSON * const item, printbuffer * const output_buffer); +static cJSON_bool parse_object(cJSON * const item, parse_buffer * const input_buffer); +static cJSON_bool print_object(const cJSON * const item, printbuffer * const output_buffer); + +/* Utility to jump whitespace and cr/lf */ +static parse_buffer *buffer_skip_whitespace(parse_buffer * const buffer) +{ + if ((buffer == NULL) || (buffer->content == NULL)) + { + return NULL; + } + + if (cannot_access_at_index(buffer, 0)) + { + return buffer; + } + + while (can_access_at_index(buffer, 0) && (buffer_at_offset(buffer)[0] <= 32)) + { + buffer->offset++; + } + + if (buffer->offset == buffer->length) + { + buffer->offset--; + } + + return buffer; +} + +/* skip the UTF-8 BOM (byte order mark) if it is at the beginning of a buffer */ +static parse_buffer *skip_utf8_bom(parse_buffer * const buffer) +{ + if ((buffer == NULL) || (buffer->content == NULL) || (buffer->offset != 0)) + { + return NULL; + } + + if (can_access_at_index(buffer, 4) && (strncmp((const char*)buffer_at_offset(buffer), "\xEF\xBB\xBF", 3) == 0)) + { + buffer->offset += 3; + } + + return buffer; +} + +CJSON_PUBLIC(cJSON *) cJSON_ParseWithOpts(const char *value, const char **return_parse_end, cJSON_bool require_null_terminated) +{ + size_t buffer_length; + + if (NULL == value) + { + return NULL; + } + + /* Adding null character size due to require_null_terminated. */ + buffer_length = strlen(value) + sizeof(""); + + return cJSON_ParseWithLengthOpts(value, buffer_length, return_parse_end, require_null_terminated); +} + +/* Parse an object - create a new root, and populate. */ +CJSON_PUBLIC(cJSON *) cJSON_ParseWithLengthOpts(const char *value, size_t buffer_length, const char **return_parse_end, cJSON_bool require_null_terminated) +{ + parse_buffer buffer = { 0, 0, 0, 0, { 0, 0, 0 } }; + cJSON *item = NULL; + + /* reset error position */ + global_error.json = NULL; + global_error.position = 0; + + if (value == NULL || 0 == buffer_length) + { + goto fail; + } + + buffer.content = (const unsigned char*)value; + buffer.length = buffer_length; + buffer.offset = 0; + buffer.hooks = global_hooks; + + item = cJSON_New_Item(&global_hooks); + if (item == NULL) /* memory fail */ + { + goto fail; + } + + if (!parse_value(item, buffer_skip_whitespace(skip_utf8_bom(&buffer)))) + { + /* parse failure. ep is set. */ + goto fail; + } + + /* if we require null-terminated JSON without appended garbage, skip and then check for a null terminator */ + if (require_null_terminated) + { + buffer_skip_whitespace(&buffer); + if ((buffer.offset >= buffer.length) || buffer_at_offset(&buffer)[0] != '\0') + { + goto fail; + } + } + if (return_parse_end) + { + *return_parse_end = (const char*)buffer_at_offset(&buffer); + } + + return item; + +fail: + if (item != NULL) + { + cJSON_Delete(item); + } + + if (value != NULL) + { + error local_error; + local_error.json = (const unsigned char*)value; + local_error.position = 0; + + if (buffer.offset < buffer.length) + { + local_error.position = buffer.offset; + } + else if (buffer.length > 0) + { + local_error.position = buffer.length - 1; + } + + if (return_parse_end != NULL) + { + *return_parse_end = (const char*)local_error.json + local_error.position; + } + + global_error = local_error; + } + + return NULL; +} + +/* Default options for cJSON_Parse */ +CJSON_PUBLIC(cJSON *) cJSON_Parse(const char *value) +{ + return cJSON_ParseWithOpts(value, 0, 0); +} + +CJSON_PUBLIC(cJSON *) cJSON_ParseWithLength(const char *value, size_t buffer_length) +{ + return cJSON_ParseWithLengthOpts(value, buffer_length, 0, 0); +} + +#define cjson_min(a, b) (((a) < (b)) ? (a) : (b)) + +static unsigned char *print(const cJSON * const item, cJSON_bool format, const internal_hooks * const hooks) +{ + static const size_t default_buffer_size = 256; + printbuffer buffer[1]; + unsigned char *printed = NULL; + + memset(buffer, 0, sizeof(buffer)); + + /* create buffer */ + buffer->buffer = (unsigned char*) hooks->allocate(default_buffer_size); + buffer->length = default_buffer_size; + buffer->format = format; + buffer->hooks = *hooks; + if (buffer->buffer == NULL) + { + goto fail; + } + + /* print the value */ + if (!print_value(item, buffer)) + { + goto fail; + } + update_offset(buffer); + + /* check if reallocate is available */ + if (hooks->reallocate != NULL) + { + printed = (unsigned char*) hooks->reallocate(buffer->buffer, buffer->offset + 1); + if (printed == NULL) { + goto fail; + } + buffer->buffer = NULL; + } + else /* otherwise copy the JSON over to a new buffer */ + { + printed = (unsigned char*) hooks->allocate(buffer->offset + 1); + if (printed == NULL) + { + goto fail; + } + memcpy(printed, buffer->buffer, cjson_min(buffer->length, buffer->offset + 1)); + printed[buffer->offset] = '\0'; /* just to be sure */ + + /* free the buffer */ + hooks->deallocate(buffer->buffer); + buffer->buffer = NULL; + } + + return printed; + +fail: + if (buffer->buffer != NULL) + { + hooks->deallocate(buffer->buffer); + buffer->buffer = NULL; + } + + if (printed != NULL) + { + hooks->deallocate(printed); + printed = NULL; + } + + return NULL; +} + +/* Render a cJSON item/entity/structure to text. */ +CJSON_PUBLIC(char *) cJSON_Print(const cJSON *item) +{ + return (char*)print(item, true, &global_hooks); +} + +CJSON_PUBLIC(char *) cJSON_PrintUnformatted(const cJSON *item) +{ + return (char*)print(item, false, &global_hooks); +} + +CJSON_PUBLIC(char *) cJSON_PrintBuffered(const cJSON *item, int prebuffer, cJSON_bool fmt) +{ + printbuffer p = { 0, 0, 0, 0, 0, 0, { 0, 0, 0 } }; + + if (prebuffer < 0) + { + return NULL; + } + + p.buffer = (unsigned char*)global_hooks.allocate((size_t)prebuffer); + if (!p.buffer) + { + return NULL; + } + + p.length = (size_t)prebuffer; + p.offset = 0; + p.noalloc = false; + p.format = fmt; + p.hooks = global_hooks; + + if (!print_value(item, &p)) + { + global_hooks.deallocate(p.buffer); + p.buffer = NULL; + return NULL; + } + + return (char*)p.buffer; +} + +CJSON_PUBLIC(cJSON_bool) cJSON_PrintPreallocated(cJSON *item, char *buffer, const int length, const cJSON_bool format) +{ + printbuffer p = { 0, 0, 0, 0, 0, 0, { 0, 0, 0 } }; + + if ((length < 0) || (buffer == NULL)) + { + return false; + } + + p.buffer = (unsigned char*)buffer; + p.length = (size_t)length; + p.offset = 0; + p.noalloc = true; + p.format = format; + p.hooks = global_hooks; + + return print_value(item, &p); +} + +/* Parser core - when encountering text, process appropriately. */ +static cJSON_bool parse_value(cJSON * const item, parse_buffer * const input_buffer) +{ + if ((input_buffer == NULL) || (input_buffer->content == NULL)) + { + return false; /* no input */ + } + + /* parse the different types of values */ + /* null */ + if (can_read(input_buffer, 4) && (strncmp((const char*)buffer_at_offset(input_buffer), "null", 4) == 0)) + { + item->type = cJSON_NULL; + input_buffer->offset += 4; + return true; + } + /* false */ + if (can_read(input_buffer, 5) && (strncmp((const char*)buffer_at_offset(input_buffer), "false", 5) == 0)) + { + item->type = cJSON_False; + input_buffer->offset += 5; + return true; + } + /* true */ + if (can_read(input_buffer, 4) && (strncmp((const char*)buffer_at_offset(input_buffer), "true", 4) == 0)) + { + item->type = cJSON_True; + item->valueint = 1; + input_buffer->offset += 4; + return true; + } + /* string */ + if (can_access_at_index(input_buffer, 0) && (buffer_at_offset(input_buffer)[0] == '\"')) + { + return parse_string(item, input_buffer); + } + /* number */ + if (can_access_at_index(input_buffer, 0) && ((buffer_at_offset(input_buffer)[0] == '-') || ((buffer_at_offset(input_buffer)[0] >= '0') && (buffer_at_offset(input_buffer)[0] <= '9')))) + { + return parse_number(item, input_buffer); + } + /* array */ + if (can_access_at_index(input_buffer, 0) && (buffer_at_offset(input_buffer)[0] == '[')) + { + return parse_array(item, input_buffer); + } + /* object */ + if (can_access_at_index(input_buffer, 0) && (buffer_at_offset(input_buffer)[0] == '{')) + { + return parse_object(item, input_buffer); + } + + return false; +} + +/* Render a value to text. */ +static cJSON_bool print_value(const cJSON * const item, printbuffer * const output_buffer) +{ + unsigned char *output = NULL; + + if ((item == NULL) || (output_buffer == NULL)) + { + return false; + } + + switch ((item->type) & 0xFF) + { + case cJSON_NULL: + output = ensure(output_buffer, 5); + if (output == NULL) + { + return false; + } + strcpy((char*)output, "null"); + return true; + + case cJSON_False: + output = ensure(output_buffer, 6); + if (output == NULL) + { + return false; + } + strcpy((char*)output, "false"); + return true; + + case cJSON_True: + output = ensure(output_buffer, 5); + if (output == NULL) + { + return false; + } + strcpy((char*)output, "true"); + return true; + + case cJSON_Number: + return print_number(item, output_buffer); + + case cJSON_Raw: + { + size_t raw_length = 0; + if (item->valuestring == NULL) + { + return false; + } + + raw_length = strlen(item->valuestring) + sizeof(""); + output = ensure(output_buffer, raw_length); + if (output == NULL) + { + return false; + } + memcpy(output, item->valuestring, raw_length); + return true; + } + + case cJSON_String: + return print_string(item, output_buffer); + + case cJSON_Array: + return print_array(item, output_buffer); + + case cJSON_Object: + return print_object(item, output_buffer); + + default: + return false; + } +} + +/* Build an array from input text. */ +static cJSON_bool parse_array(cJSON * const item, parse_buffer * const input_buffer) +{ + cJSON *head = NULL; /* head of the linked list */ + cJSON *current_item = NULL; + + if (input_buffer->depth >= CJSON_NESTING_LIMIT) + { + return false; /* to deeply nested */ + } + input_buffer->depth++; + + if (buffer_at_offset(input_buffer)[0] != '[') + { + /* not an array */ + goto fail; + } + + input_buffer->offset++; + buffer_skip_whitespace(input_buffer); + if (can_access_at_index(input_buffer, 0) && (buffer_at_offset(input_buffer)[0] == ']')) + { + /* empty array */ + goto success; + } + + /* check if we skipped to the end of the buffer */ + if (cannot_access_at_index(input_buffer, 0)) + { + input_buffer->offset--; + goto fail; + } + + /* step back to character in front of the first element */ + input_buffer->offset--; + /* loop through the comma separated array elements */ + do + { + /* allocate next item */ + cJSON *new_item = cJSON_New_Item(&(input_buffer->hooks)); + if (new_item == NULL) + { + goto fail; /* allocation failure */ + } + + /* attach next item to list */ + if (head == NULL) + { + /* start the linked list */ + current_item = head = new_item; + } + else + { + /* add to the end and advance */ + current_item->next = new_item; + new_item->prev = current_item; + current_item = new_item; + } + + /* parse next value */ + input_buffer->offset++; + buffer_skip_whitespace(input_buffer); + if (!parse_value(current_item, input_buffer)) + { + goto fail; /* failed to parse value */ + } + buffer_skip_whitespace(input_buffer); + } + while (can_access_at_index(input_buffer, 0) && (buffer_at_offset(input_buffer)[0] == ',')); + + if (cannot_access_at_index(input_buffer, 0) || buffer_at_offset(input_buffer)[0] != ']') + { + goto fail; /* expected end of array */ + } + +success: + input_buffer->depth--; + + if (head != NULL) { + head->prev = current_item; + } + + item->type = cJSON_Array; + item->child = head; + + input_buffer->offset++; + + return true; + +fail: + if (head != NULL) + { + cJSON_Delete(head); + } + + return false; +} + +/* Render an array to text */ +static cJSON_bool print_array(const cJSON * const item, printbuffer * const output_buffer) +{ + unsigned char *output_pointer = NULL; + size_t length = 0; + cJSON *current_element = item->child; + + if (output_buffer == NULL) + { + return false; + } + + /* Compose the output array. */ + /* opening square bracket */ + output_pointer = ensure(output_buffer, 1); + if (output_pointer == NULL) + { + return false; + } + + *output_pointer = '['; + output_buffer->offset++; + output_buffer->depth++; + + while (current_element != NULL) + { + if (!print_value(current_element, output_buffer)) + { + return false; + } + update_offset(output_buffer); + if (current_element->next) + { + length = (size_t) (output_buffer->format ? 2 : 1); + output_pointer = ensure(output_buffer, length + 1); + if (output_pointer == NULL) + { + return false; + } + *output_pointer++ = ','; + if(output_buffer->format) + { + *output_pointer++ = ' '; + } + *output_pointer = '\0'; + output_buffer->offset += length; + } + current_element = current_element->next; + } + + output_pointer = ensure(output_buffer, 2); + if (output_pointer == NULL) + { + return false; + } + *output_pointer++ = ']'; + *output_pointer = '\0'; + output_buffer->depth--; + + return true; +} + +/* Build an object from the text. */ +static cJSON_bool parse_object(cJSON * const item, parse_buffer * const input_buffer) +{ + cJSON *head = NULL; /* linked list head */ + cJSON *current_item = NULL; + + if (input_buffer->depth >= CJSON_NESTING_LIMIT) + { + return false; /* to deeply nested */ + } + input_buffer->depth++; + + if (cannot_access_at_index(input_buffer, 0) || (buffer_at_offset(input_buffer)[0] != '{')) + { + goto fail; /* not an object */ + } + + input_buffer->offset++; + buffer_skip_whitespace(input_buffer); + if (can_access_at_index(input_buffer, 0) && (buffer_at_offset(input_buffer)[0] == '}')) + { + goto success; /* empty object */ + } + + /* check if we skipped to the end of the buffer */ + if (cannot_access_at_index(input_buffer, 0)) + { + input_buffer->offset--; + goto fail; + } + + /* step back to character in front of the first element */ + input_buffer->offset--; + /* loop through the comma separated array elements */ + do + { + /* allocate next item */ + cJSON *new_item = cJSON_New_Item(&(input_buffer->hooks)); + if (new_item == NULL) + { + goto fail; /* allocation failure */ + } + + /* attach next item to list */ + if (head == NULL) + { + /* start the linked list */ + current_item = head = new_item; + } + else + { + /* add to the end and advance */ + current_item->next = new_item; + new_item->prev = current_item; + current_item = new_item; + } + + if (cannot_access_at_index(input_buffer, 1)) + { + goto fail; /* nothing comes after the comma */ + } + + /* parse the name of the child */ + input_buffer->offset++; + buffer_skip_whitespace(input_buffer); + if (!parse_string(current_item, input_buffer)) + { + goto fail; /* failed to parse name */ + } + buffer_skip_whitespace(input_buffer); + + /* swap valuestring and string, because we parsed the name */ + current_item->string = current_item->valuestring; + current_item->valuestring = NULL; + + if (cannot_access_at_index(input_buffer, 0) || (buffer_at_offset(input_buffer)[0] != ':')) + { + goto fail; /* invalid object */ + } + + /* parse the value */ + input_buffer->offset++; + buffer_skip_whitespace(input_buffer); + if (!parse_value(current_item, input_buffer)) + { + goto fail; /* failed to parse value */ + } + buffer_skip_whitespace(input_buffer); + } + while (can_access_at_index(input_buffer, 0) && (buffer_at_offset(input_buffer)[0] == ',')); + + if (cannot_access_at_index(input_buffer, 0) || (buffer_at_offset(input_buffer)[0] != '}')) + { + goto fail; /* expected end of object */ + } + +success: + input_buffer->depth--; + + if (head != NULL) { + head->prev = current_item; + } + + item->type = cJSON_Object; + item->child = head; + + input_buffer->offset++; + return true; + +fail: + if (head != NULL) + { + cJSON_Delete(head); + } + + return false; +} + +/* Render an object to text. */ +static cJSON_bool print_object(const cJSON * const item, printbuffer * const output_buffer) +{ + unsigned char *output_pointer = NULL; + size_t length = 0; + cJSON *current_item = item->child; + + if (output_buffer == NULL) + { + return false; + } + + /* Compose the output: */ + length = (size_t) (output_buffer->format ? 2 : 1); /* fmt: {\n */ + output_pointer = ensure(output_buffer, length + 1); + if (output_pointer == NULL) + { + return false; + } + + *output_pointer++ = '{'; + output_buffer->depth++; + if (output_buffer->format) + { + *output_pointer++ = '\n'; + } + output_buffer->offset += length; + + while (current_item) + { + if (output_buffer->format) + { + size_t i; + output_pointer = ensure(output_buffer, output_buffer->depth); + if (output_pointer == NULL) + { + return false; + } + for (i = 0; i < output_buffer->depth; i++) + { + *output_pointer++ = '\t'; + } + output_buffer->offset += output_buffer->depth; + } + + /* print key */ + if (!print_string_ptr((unsigned char*)current_item->string, output_buffer)) + { + return false; + } + update_offset(output_buffer); + + length = (size_t) (output_buffer->format ? 2 : 1); + output_pointer = ensure(output_buffer, length); + if (output_pointer == NULL) + { + return false; + } + *output_pointer++ = ':'; + if (output_buffer->format) + { + *output_pointer++ = '\t'; + } + output_buffer->offset += length; + + /* print value */ + if (!print_value(current_item, output_buffer)) + { + return false; + } + update_offset(output_buffer); + + /* print comma if not last */ + length = ((size_t)(output_buffer->format ? 1 : 0) + (size_t)(current_item->next ? 1 : 0)); + output_pointer = ensure(output_buffer, length + 1); + if (output_pointer == NULL) + { + return false; + } + if (current_item->next) + { + *output_pointer++ = ','; + } + + if (output_buffer->format) + { + *output_pointer++ = '\n'; + } + *output_pointer = '\0'; + output_buffer->offset += length; + + current_item = current_item->next; + } + + output_pointer = ensure(output_buffer, output_buffer->format ? (output_buffer->depth + 1) : 2); + if (output_pointer == NULL) + { + return false; + } + if (output_buffer->format) + { + size_t i; + for (i = 0; i < (output_buffer->depth - 1); i++) + { + *output_pointer++ = '\t'; + } + } + *output_pointer++ = '}'; + *output_pointer = '\0'; + output_buffer->depth--; + + return true; +} + +/* Get Array size/item / object item. */ +CJSON_PUBLIC(int) cJSON_GetArraySize(const cJSON *array) +{ + cJSON *child = NULL; + size_t size = 0; + + if (array == NULL) + { + return 0; + } + + child = array->child; + + while(child != NULL) + { + size++; + child = child->next; + } + + /* FIXME: Can overflow here. Cannot be fixed without breaking the API */ + + return (int)size; +} + +static cJSON* get_array_item(const cJSON *array, size_t index) +{ + cJSON *current_child = NULL; + + if (array == NULL) + { + return NULL; + } + + current_child = array->child; + while ((current_child != NULL) && (index > 0)) + { + index--; + current_child = current_child->next; + } + + return current_child; +} + +CJSON_PUBLIC(cJSON *) cJSON_GetArrayItem(const cJSON *array, int index) +{ + if (index < 0) + { + return NULL; + } + + return get_array_item(array, (size_t)index); +} + +static cJSON *get_object_item(const cJSON * const object, const char * const name, const cJSON_bool case_sensitive) +{ + cJSON *current_element = NULL; + + if ((object == NULL) || (name == NULL)) + { + return NULL; + } + + current_element = object->child; + if (case_sensitive) + { + while ((current_element != NULL) && (current_element->string != NULL) && (strcmp(name, current_element->string) != 0)) + { + current_element = current_element->next; + } + } + else + { + while ((current_element != NULL) && (case_insensitive_strcmp((const unsigned char*)name, (const unsigned char*)(current_element->string)) != 0)) + { + current_element = current_element->next; + } + } + + if ((current_element == NULL) || (current_element->string == NULL)) { + return NULL; + } + + return current_element; +} + +CJSON_PUBLIC(cJSON *) cJSON_GetObjectItem(const cJSON * const object, const char * const string) +{ + return get_object_item(object, string, false); +} + +CJSON_PUBLIC(cJSON *) cJSON_GetObjectItemCaseSensitive(const cJSON * const object, const char * const string) +{ + return get_object_item(object, string, true); +} + +CJSON_PUBLIC(cJSON_bool) cJSON_HasObjectItem(const cJSON *object, const char *string) +{ + return cJSON_GetObjectItem(object, string) ? 1 : 0; +} + +/* Utility for array list handling. */ +static void suffix_object(cJSON *prev, cJSON *item) +{ + prev->next = item; + item->prev = prev; +} + +/* Utility for handling references. */ +static cJSON *create_reference(const cJSON *item, const internal_hooks * const hooks) +{ + cJSON *reference = NULL; + if (item == NULL) + { + return NULL; + } + + reference = cJSON_New_Item(hooks); + if (reference == NULL) + { + return NULL; + } + + memcpy(reference, item, sizeof(cJSON)); + reference->string = NULL; + reference->type |= cJSON_IsReference; + reference->next = reference->prev = NULL; + return reference; +} + +static cJSON_bool add_item_to_array(cJSON *array, cJSON *item) +{ + cJSON *child = NULL; + + if ((item == NULL) || (array == NULL) || (array == item)) + { + return false; + } + + child = array->child; + /* + * To find the last item in array quickly, we use prev in array + */ + if (child == NULL) + { + /* list is empty, start new one */ + array->child = item; + item->prev = item; + item->next = NULL; + } + else + { + /* append to the end */ + if (child->prev) + { + suffix_object(child->prev, item); + array->child->prev = item; + } + } + + return true; +} + +/* Add item to array/object. */ +CJSON_PUBLIC(cJSON_bool) cJSON_AddItemToArray(cJSON *array, cJSON *item) +{ + return add_item_to_array(array, item); +} + +#if defined(__clang__) || (defined(__GNUC__) && ((__GNUC__ > 4) || ((__GNUC__ == 4) && (__GNUC_MINOR__ > 5)))) + #pragma GCC diagnostic push +#endif +#ifdef __GNUC__ +#pragma GCC diagnostic ignored "-Wcast-qual" +#endif +/* helper function to cast away const */ +static void* cast_away_const(const void* string) +{ + return (void*)string; +} +#if defined(__clang__) || (defined(__GNUC__) && ((__GNUC__ > 4) || ((__GNUC__ == 4) && (__GNUC_MINOR__ > 5)))) + #pragma GCC diagnostic pop +#endif + + +static cJSON_bool add_item_to_object(cJSON * const object, const char * const string, cJSON * const item, const internal_hooks * const hooks, const cJSON_bool constant_key) +{ + char *new_key = NULL; + int new_type = cJSON_Invalid; + + if ((object == NULL) || (string == NULL) || (item == NULL) || (object == item)) + { + return false; + } + + if (constant_key) + { + new_key = (char*)cast_away_const(string); + new_type = item->type | cJSON_StringIsConst; + } + else + { + new_key = (char*)cJSON_strdup((const unsigned char*)string, hooks); + if (new_key == NULL) + { + return false; + } + + new_type = item->type & ~cJSON_StringIsConst; + } + + if (!(item->type & cJSON_StringIsConst) && (item->string != NULL)) + { + hooks->deallocate(item->string); + } + + item->string = new_key; + item->type = new_type; + + return add_item_to_array(object, item); +} + +CJSON_PUBLIC(cJSON_bool) cJSON_AddItemToObject(cJSON *object, const char *string, cJSON *item) +{ + return add_item_to_object(object, string, item, &global_hooks, false); +} + +/* Add an item to an object with constant string as key */ +CJSON_PUBLIC(cJSON_bool) cJSON_AddItemToObjectCS(cJSON *object, const char *string, cJSON *item) +{ + return add_item_to_object(object, string, item, &global_hooks, true); +} + +CJSON_PUBLIC(cJSON_bool) cJSON_AddItemReferenceToArray(cJSON *array, cJSON *item) +{ + if (array == NULL) + { + return false; + } + + return add_item_to_array(array, create_reference(item, &global_hooks)); +} + +CJSON_PUBLIC(cJSON_bool) cJSON_AddItemReferenceToObject(cJSON *object, const char *string, cJSON *item) +{ + if ((object == NULL) || (string == NULL)) + { + return false; + } + + return add_item_to_object(object, string, create_reference(item, &global_hooks), &global_hooks, false); +} + +CJSON_PUBLIC(cJSON*) cJSON_AddNullToObject(cJSON * const object, const char * const name) +{ + cJSON *null = cJSON_CreateNull(); + if (add_item_to_object(object, name, null, &global_hooks, false)) + { + return null; + } + + cJSON_Delete(null); + return NULL; +} + +CJSON_PUBLIC(cJSON*) cJSON_AddTrueToObject(cJSON * const object, const char * const name) +{ + cJSON *true_item = cJSON_CreateTrue(); + if (add_item_to_object(object, name, true_item, &global_hooks, false)) + { + return true_item; + } + + cJSON_Delete(true_item); + return NULL; +} + +CJSON_PUBLIC(cJSON*) cJSON_AddFalseToObject(cJSON * const object, const char * const name) +{ + cJSON *false_item = cJSON_CreateFalse(); + if (add_item_to_object(object, name, false_item, &global_hooks, false)) + { + return false_item; + } + + cJSON_Delete(false_item); + return NULL; +} + +CJSON_PUBLIC(cJSON*) cJSON_AddBoolToObject(cJSON * const object, const char * const name, const cJSON_bool boolean) +{ + cJSON *bool_item = cJSON_CreateBool(boolean); + if (add_item_to_object(object, name, bool_item, &global_hooks, false)) + { + return bool_item; + } + + cJSON_Delete(bool_item); + return NULL; +} + +CJSON_PUBLIC(cJSON*) cJSON_AddNumberToObject(cJSON * const object, const char * const name, const double number) +{ + cJSON *number_item = cJSON_CreateNumber(number); + if (add_item_to_object(object, name, number_item, &global_hooks, false)) + { + return number_item; + } + + cJSON_Delete(number_item); + return NULL; +} + +CJSON_PUBLIC(cJSON*) cJSON_AddStringToObject(cJSON * const object, const char * const name, const char * const string) +{ + cJSON *string_item = cJSON_CreateString(string); + if (add_item_to_object(object, name, string_item, &global_hooks, false)) + { + return string_item; + } + + cJSON_Delete(string_item); + return NULL; +} + +CJSON_PUBLIC(cJSON*) cJSON_AddRawToObject(cJSON * const object, const char * const name, const char * const raw) +{ + cJSON *raw_item = cJSON_CreateRaw(raw); + if (add_item_to_object(object, name, raw_item, &global_hooks, false)) + { + return raw_item; + } + + cJSON_Delete(raw_item); + return NULL; +} + +CJSON_PUBLIC(cJSON*) cJSON_AddObjectToObject(cJSON * const object, const char * const name) +{ + cJSON *object_item = cJSON_CreateObject(); + if (add_item_to_object(object, name, object_item, &global_hooks, false)) + { + return object_item; + } + + cJSON_Delete(object_item); + return NULL; +} + +CJSON_PUBLIC(cJSON*) cJSON_AddArrayToObject(cJSON * const object, const char * const name) +{ + cJSON *array = cJSON_CreateArray(); + if (add_item_to_object(object, name, array, &global_hooks, false)) + { + return array; + } + + cJSON_Delete(array); + return NULL; +} + +CJSON_PUBLIC(cJSON *) cJSON_DetachItemViaPointer(cJSON *parent, cJSON * const item) +{ + if ((parent == NULL) || (item == NULL) || (item != parent->child && item->prev == NULL)) + { + return NULL; + } + + if (item != parent->child) + { + /* not the first element */ + item->prev->next = item->next; + } + if (item->next != NULL) + { + /* not the last element */ + item->next->prev = item->prev; + } + + if (item == parent->child) + { + /* first element */ + parent->child = item->next; + } + else if (item->next == NULL) + { + /* last element */ + parent->child->prev = item->prev; + } + + /* make sure the detached item doesn't point anywhere anymore */ + item->prev = NULL; + item->next = NULL; + + return item; +} + +CJSON_PUBLIC(cJSON *) cJSON_DetachItemFromArray(cJSON *array, int which) +{ + if (which < 0) + { + return NULL; + } + + return cJSON_DetachItemViaPointer(array, get_array_item(array, (size_t)which)); +} + +CJSON_PUBLIC(void) cJSON_DeleteItemFromArray(cJSON *array, int which) +{ + cJSON_Delete(cJSON_DetachItemFromArray(array, which)); +} + +CJSON_PUBLIC(cJSON *) cJSON_DetachItemFromObject(cJSON *object, const char *string) +{ + cJSON *to_detach = cJSON_GetObjectItem(object, string); + + return cJSON_DetachItemViaPointer(object, to_detach); +} + +CJSON_PUBLIC(cJSON *) cJSON_DetachItemFromObjectCaseSensitive(cJSON *object, const char *string) +{ + cJSON *to_detach = cJSON_GetObjectItemCaseSensitive(object, string); + + return cJSON_DetachItemViaPointer(object, to_detach); +} + +CJSON_PUBLIC(void) cJSON_DeleteItemFromObject(cJSON *object, const char *string) +{ + cJSON_Delete(cJSON_DetachItemFromObject(object, string)); +} + +CJSON_PUBLIC(void) cJSON_DeleteItemFromObjectCaseSensitive(cJSON *object, const char *string) +{ + cJSON_Delete(cJSON_DetachItemFromObjectCaseSensitive(object, string)); +} + +/* Replace array/object items with new ones. */ +CJSON_PUBLIC(cJSON_bool) cJSON_InsertItemInArray(cJSON *array, int which, cJSON *newitem) +{ + cJSON *after_inserted = NULL; + + if (which < 0 || newitem == NULL) + { + return false; + } + + after_inserted = get_array_item(array, (size_t)which); + if (after_inserted == NULL) + { + return add_item_to_array(array, newitem); + } + + if (after_inserted != array->child && after_inserted->prev == NULL) { + /* return false if after_inserted is a corrupted array item */ + return false; + } + + newitem->next = after_inserted; + newitem->prev = after_inserted->prev; + after_inserted->prev = newitem; + if (after_inserted == array->child) + { + array->child = newitem; + } + else + { + newitem->prev->next = newitem; + } + return true; +} + +CJSON_PUBLIC(cJSON_bool) cJSON_ReplaceItemViaPointer(cJSON * const parent, cJSON * const item, cJSON * replacement) +{ + if ((parent == NULL) || (parent->child == NULL) || (replacement == NULL) || (item == NULL)) + { + return false; + } + + if (replacement == item) + { + return true; + } + + replacement->next = item->next; + replacement->prev = item->prev; + + if (replacement->next != NULL) + { + replacement->next->prev = replacement; + } + if (parent->child == item) + { + if (parent->child->prev == parent->child) + { + replacement->prev = replacement; + } + parent->child = replacement; + } + else + { /* + * To find the last item in array quickly, we use prev in array. + * We can't modify the last item's next pointer where this item was the parent's child + */ + if (replacement->prev != NULL) + { + replacement->prev->next = replacement; + } + if (replacement->next == NULL) + { + parent->child->prev = replacement; + } + } + + item->next = NULL; + item->prev = NULL; + cJSON_Delete(item); + + return true; +} + +CJSON_PUBLIC(cJSON_bool) cJSON_ReplaceItemInArray(cJSON *array, int which, cJSON *newitem) +{ + if (which < 0) + { + return false; + } + + return cJSON_ReplaceItemViaPointer(array, get_array_item(array, (size_t)which), newitem); +} + +static cJSON_bool replace_item_in_object(cJSON *object, const char *string, cJSON *replacement, cJSON_bool case_sensitive) +{ + if ((replacement == NULL) || (string == NULL)) + { + return false; + } + + /* replace the name in the replacement */ + if (!(replacement->type & cJSON_StringIsConst) && (replacement->string != NULL)) + { + cJSON_free(replacement->string); + } + replacement->string = (char*)cJSON_strdup((const unsigned char*)string, &global_hooks); + if (replacement->string == NULL) + { + return false; + } + + replacement->type &= ~cJSON_StringIsConst; + + return cJSON_ReplaceItemViaPointer(object, get_object_item(object, string, case_sensitive), replacement); +} + +CJSON_PUBLIC(cJSON_bool) cJSON_ReplaceItemInObject(cJSON *object, const char *string, cJSON *newitem) +{ + return replace_item_in_object(object, string, newitem, false); +} + +CJSON_PUBLIC(cJSON_bool) cJSON_ReplaceItemInObjectCaseSensitive(cJSON *object, const char *string, cJSON *newitem) +{ + return replace_item_in_object(object, string, newitem, true); +} + +/* Create basic types: */ +CJSON_PUBLIC(cJSON *) cJSON_CreateNull(void) +{ + cJSON *item = cJSON_New_Item(&global_hooks); + if(item) + { + item->type = cJSON_NULL; + } + + return item; +} + +CJSON_PUBLIC(cJSON *) cJSON_CreateTrue(void) +{ + cJSON *item = cJSON_New_Item(&global_hooks); + if(item) + { + item->type = cJSON_True; + } + + return item; +} + +CJSON_PUBLIC(cJSON *) cJSON_CreateFalse(void) +{ + cJSON *item = cJSON_New_Item(&global_hooks); + if(item) + { + item->type = cJSON_False; + } + + return item; +} + +CJSON_PUBLIC(cJSON *) cJSON_CreateBool(cJSON_bool boolean) +{ + cJSON *item = cJSON_New_Item(&global_hooks); + if(item) + { + item->type = boolean ? cJSON_True : cJSON_False; + } + + return item; +} + +CJSON_PUBLIC(cJSON *) cJSON_CreateNumber(double num) +{ + cJSON *item = cJSON_New_Item(&global_hooks); + if(item) + { + item->type = cJSON_Number; + item->valuedouble = num; + + /* use saturation in case of overflow */ + if (num >= INT_MAX) + { + item->valueint = INT_MAX; + } + else if (num <= (double)INT_MIN) + { + item->valueint = INT_MIN; + } + else + { + item->valueint = (int)num; + } + } + + return item; +} + +CJSON_PUBLIC(cJSON *) cJSON_CreateString(const char *string) +{ + cJSON *item = cJSON_New_Item(&global_hooks); + if(item) + { + item->type = cJSON_String; + item->valuestring = (char*)cJSON_strdup((const unsigned char*)string, &global_hooks); + if(!item->valuestring) + { + cJSON_Delete(item); + return NULL; + } + } + + return item; +} + +CJSON_PUBLIC(cJSON *) cJSON_CreateStringReference(const char *string) +{ + cJSON *item = cJSON_New_Item(&global_hooks); + if (item != NULL) + { + item->type = cJSON_String | cJSON_IsReference; + item->valuestring = (char*)cast_away_const(string); + } + + return item; +} + +CJSON_PUBLIC(cJSON *) cJSON_CreateObjectReference(const cJSON *child) +{ + cJSON *item = cJSON_New_Item(&global_hooks); + if (item != NULL) { + item->type = cJSON_Object | cJSON_IsReference; + item->child = (cJSON*)cast_away_const(child); + } + + return item; +} + +CJSON_PUBLIC(cJSON *) cJSON_CreateArrayReference(const cJSON *child) { + cJSON *item = cJSON_New_Item(&global_hooks); + if (item != NULL) { + item->type = cJSON_Array | cJSON_IsReference; + item->child = (cJSON*)cast_away_const(child); + } + + return item; +} + +CJSON_PUBLIC(cJSON *) cJSON_CreateRaw(const char *raw) +{ + cJSON *item = cJSON_New_Item(&global_hooks); + if(item) + { + item->type = cJSON_Raw; + item->valuestring = (char*)cJSON_strdup((const unsigned char*)raw, &global_hooks); + if(!item->valuestring) + { + cJSON_Delete(item); + return NULL; + } + } + + return item; +} + +CJSON_PUBLIC(cJSON *) cJSON_CreateArray(void) +{ + cJSON *item = cJSON_New_Item(&global_hooks); + if(item) + { + item->type=cJSON_Array; + } + + return item; +} + +CJSON_PUBLIC(cJSON *) cJSON_CreateObject(void) +{ + cJSON *item = cJSON_New_Item(&global_hooks); + if (item) + { + item->type = cJSON_Object; + } + + return item; +} + +/* Create Arrays: */ +CJSON_PUBLIC(cJSON *) cJSON_CreateIntArray(const int *numbers, int count) +{ + size_t i = 0; + cJSON *n = NULL; + cJSON *p = NULL; + cJSON *a = NULL; + + if ((count < 0) || (numbers == NULL)) + { + return NULL; + } + + a = cJSON_CreateArray(); + + for(i = 0; a && (i < (size_t)count); i++) + { + n = cJSON_CreateNumber(numbers[i]); + if (!n) + { + cJSON_Delete(a); + return NULL; + } + if(!i) + { + a->child = n; + } + else + { + suffix_object(p, n); + } + p = n; + } + + if (a && a->child) { + a->child->prev = n; + } + + return a; +} + +CJSON_PUBLIC(cJSON *) cJSON_CreateFloatArray(const float *numbers, int count) +{ + size_t i = 0; + cJSON *n = NULL; + cJSON *p = NULL; + cJSON *a = NULL; + + if ((count < 0) || (numbers == NULL)) + { + return NULL; + } + + a = cJSON_CreateArray(); + + for(i = 0; a && (i < (size_t)count); i++) + { + n = cJSON_CreateNumber((double)numbers[i]); + if(!n) + { + cJSON_Delete(a); + return NULL; + } + if(!i) + { + a->child = n; + } + else + { + suffix_object(p, n); + } + p = n; + } + + if (a && a->child) { + a->child->prev = n; + } + + return a; +} + +CJSON_PUBLIC(cJSON *) cJSON_CreateDoubleArray(const double *numbers, int count) +{ + size_t i = 0; + cJSON *n = NULL; + cJSON *p = NULL; + cJSON *a = NULL; + + if ((count < 0) || (numbers == NULL)) + { + return NULL; + } + + a = cJSON_CreateArray(); + + for(i = 0; a && (i < (size_t)count); i++) + { + n = cJSON_CreateNumber(numbers[i]); + if(!n) + { + cJSON_Delete(a); + return NULL; + } + if(!i) + { + a->child = n; + } + else + { + suffix_object(p, n); + } + p = n; + } + + if (a && a->child) { + a->child->prev = n; + } + + return a; +} + +CJSON_PUBLIC(cJSON *) cJSON_CreateStringArray(const char *const *strings, int count) +{ + size_t i = 0; + cJSON *n = NULL; + cJSON *p = NULL; + cJSON *a = NULL; + + if ((count < 0) || (strings == NULL)) + { + return NULL; + } + + a = cJSON_CreateArray(); + + for (i = 0; a && (i < (size_t)count); i++) + { + n = cJSON_CreateString(strings[i]); + if(!n) + { + cJSON_Delete(a); + return NULL; + } + if(!i) + { + a->child = n; + } + else + { + suffix_object(p,n); + } + p = n; + } + + if (a && a->child) { + a->child->prev = n; + } + + return a; +} + +/* Duplication */ +cJSON * cJSON_Duplicate_rec(const cJSON *item, size_t depth, cJSON_bool recurse); + +CJSON_PUBLIC(cJSON *) cJSON_Duplicate(const cJSON *item, cJSON_bool recurse) +{ + return cJSON_Duplicate_rec(item, 0, recurse ); +} + +cJSON * cJSON_Duplicate_rec(const cJSON *item, size_t depth, cJSON_bool recurse) +{ + cJSON *newitem = NULL; + cJSON *child = NULL; + cJSON *next = NULL; + cJSON *newchild = NULL; + + /* Bail on bad ptr */ + if (!item) + { + goto fail; + } + /* Create new item */ + newitem = cJSON_New_Item(&global_hooks); + if (!newitem) + { + goto fail; + } + /* Copy over all vars */ + newitem->type = item->type & (~cJSON_IsReference); + newitem->valueint = item->valueint; + newitem->valuedouble = item->valuedouble; + if (item->valuestring) + { + newitem->valuestring = (char*)cJSON_strdup((unsigned char*)item->valuestring, &global_hooks); + if (!newitem->valuestring) + { + goto fail; + } + } + if (item->string) + { + newitem->string = (item->type&cJSON_StringIsConst) ? item->string : (char*)cJSON_strdup((unsigned char*)item->string, &global_hooks); + if (!newitem->string) + { + goto fail; + } + } + /* If non-recursive, then we're done! */ + if (!recurse) + { + return newitem; + } + /* Walk the ->next chain for the child. */ + child = item->child; + while (child != NULL) + { + if(depth >= CJSON_CIRCULAR_LIMIT) { + goto fail; + } + newchild = cJSON_Duplicate_rec(child, depth + 1, true); /* Duplicate (with recurse) each item in the ->next chain */ + if (!newchild) + { + goto fail; + } + if (next != NULL) + { + /* If newitem->child already set, then crosswire ->prev and ->next and move on */ + next->next = newchild; + newchild->prev = next; + next = newchild; + } + else + { + /* Set newitem->child and move to it */ + newitem->child = newchild; + next = newchild; + } + child = child->next; + } + if (newitem && newitem->child) + { + newitem->child->prev = newchild; + } + + return newitem; + +fail: + if (newitem != NULL) + { + cJSON_Delete(newitem); + } + + return NULL; +} + +static void skip_oneline_comment(char **input) +{ + *input += static_strlen("//"); + + for (; (*input)[0] != '\0'; ++(*input)) + { + if ((*input)[0] == '\n') { + *input += static_strlen("\n"); + return; + } + } +} + +static void skip_multiline_comment(char **input) +{ + *input += static_strlen("/*"); + + for (; (*input)[0] != '\0'; ++(*input)) + { + if (((*input)[0] == '*') && ((*input)[1] == '/')) + { + *input += static_strlen("*/"); + return; + } + } +} + +static void minify_string(char **input, char **output) { + (*output)[0] = (*input)[0]; + *input += static_strlen("\""); + *output += static_strlen("\""); + + + for (; (*input)[0] != '\0'; (void)++(*input), ++(*output)) { + (*output)[0] = (*input)[0]; + + if ((*input)[0] == '\"') { + (*output)[0] = '\"'; + *input += static_strlen("\""); + *output += static_strlen("\""); + return; + } else if (((*input)[0] == '\\') && ((*input)[1] == '\"')) { + (*output)[1] = (*input)[1]; + *input += static_strlen("\""); + *output += static_strlen("\""); + } + } +} + +CJSON_PUBLIC(void) cJSON_Minify(char *json) +{ + char *into = json; + + if (json == NULL) + { + return; + } + + while (json[0] != '\0') + { + switch (json[0]) + { + case ' ': + case '\t': + case '\r': + case '\n': + json++; + break; + + case '/': + if (json[1] == '/') + { + skip_oneline_comment(&json); + } + else if (json[1] == '*') + { + skip_multiline_comment(&json); + } else { + json++; + } + break; + + case '\"': + minify_string(&json, (char**)&into); + break; + + default: + into[0] = json[0]; + json++; + into++; + } + } + + /* and null-terminate. */ + *into = '\0'; +} + +CJSON_PUBLIC(cJSON_bool) cJSON_IsInvalid(const cJSON * const item) +{ + if (item == NULL) + { + return false; + } + + return (item->type & 0xFF) == cJSON_Invalid; +} + +CJSON_PUBLIC(cJSON_bool) cJSON_IsFalse(const cJSON * const item) +{ + if (item == NULL) + { + return false; + } + + return (item->type & 0xFF) == cJSON_False; +} + +CJSON_PUBLIC(cJSON_bool) cJSON_IsTrue(const cJSON * const item) +{ + if (item == NULL) + { + return false; + } + + return (item->type & 0xff) == cJSON_True; +} + + +CJSON_PUBLIC(cJSON_bool) cJSON_IsBool(const cJSON * const item) +{ + if (item == NULL) + { + return false; + } + + return (item->type & (cJSON_True | cJSON_False)) != 0; +} +CJSON_PUBLIC(cJSON_bool) cJSON_IsNull(const cJSON * const item) +{ + if (item == NULL) + { + return false; + } + + return (item->type & 0xFF) == cJSON_NULL; +} + +CJSON_PUBLIC(cJSON_bool) cJSON_IsNumber(const cJSON * const item) +{ + if (item == NULL) + { + return false; + } + + return (item->type & 0xFF) == cJSON_Number; +} + +CJSON_PUBLIC(cJSON_bool) cJSON_IsString(const cJSON * const item) +{ + if (item == NULL) + { + return false; + } + + return (item->type & 0xFF) == cJSON_String; +} + +CJSON_PUBLIC(cJSON_bool) cJSON_IsArray(const cJSON * const item) +{ + if (item == NULL) + { + return false; + } + + return (item->type & 0xFF) == cJSON_Array; +} + +CJSON_PUBLIC(cJSON_bool) cJSON_IsObject(const cJSON * const item) +{ + if (item == NULL) + { + return false; + } + + return (item->type & 0xFF) == cJSON_Object; +} + +CJSON_PUBLIC(cJSON_bool) cJSON_IsRaw(const cJSON * const item) +{ + if (item == NULL) + { + return false; + } + + return (item->type & 0xFF) == cJSON_Raw; +} + +CJSON_PUBLIC(cJSON_bool) cJSON_Compare(const cJSON * const a, const cJSON * const b, const cJSON_bool case_sensitive) +{ + if ((a == NULL) || (b == NULL) || ((a->type & 0xFF) != (b->type & 0xFF))) + { + return false; + } + + /* check if type is valid */ + switch (a->type & 0xFF) + { + case cJSON_False: + case cJSON_True: + case cJSON_NULL: + case cJSON_Number: + case cJSON_String: + case cJSON_Raw: + case cJSON_Array: + case cJSON_Object: + break; + + default: + return false; + } + + /* identical objects are equal */ + if (a == b) + { + return true; + } + + switch (a->type & 0xFF) + { + /* in these cases and equal type is enough */ + case cJSON_False: + case cJSON_True: + case cJSON_NULL: + return true; + + case cJSON_Number: + if (compare_double(a->valuedouble, b->valuedouble)) + { + return true; + } + return false; + + case cJSON_String: + case cJSON_Raw: + if ((a->valuestring == NULL) || (b->valuestring == NULL)) + { + return false; + } + if (strcmp(a->valuestring, b->valuestring) == 0) + { + return true; + } + + return false; + + case cJSON_Array: + { + cJSON *a_element = a->child; + cJSON *b_element = b->child; + + for (; (a_element != NULL) && (b_element != NULL);) + { + if (!cJSON_Compare(a_element, b_element, case_sensitive)) + { + return false; + } + + a_element = a_element->next; + b_element = b_element->next; + } + + /* one of the arrays is longer than the other */ + if (a_element != b_element) { + return false; + } + + return true; + } + + case cJSON_Object: + { + cJSON *a_element = NULL; + cJSON *b_element = NULL; + cJSON_ArrayForEach(a_element, a) + { + /* TODO This has O(n^2) runtime, which is horrible! */ + b_element = get_object_item(b, a_element->string, case_sensitive); + if (b_element == NULL) + { + return false; + } + + if (!cJSON_Compare(a_element, b_element, case_sensitive)) + { + return false; + } + } + + /* doing this twice, once on a and b to prevent true comparison if a subset of b + * TODO: Do this the proper way, this is just a fix for now */ + cJSON_ArrayForEach(b_element, b) + { + a_element = get_object_item(a, b_element->string, case_sensitive); + if (a_element == NULL) + { + return false; + } + + if (!cJSON_Compare(b_element, a_element, case_sensitive)) + { + return false; + } + } + + return true; + } + + default: + return false; + } +} + +CJSON_PUBLIC(void *) cJSON_malloc(size_t size) +{ + return global_hooks.allocate(size); +} + +CJSON_PUBLIC(void) cJSON_free(void *object) +{ + global_hooks.deallocate(object); + object = NULL; +} diff --git a/modules/vector-sets/cJSON.h b/modules/vector-sets/cJSON.h new file mode 100644 index 000000000..37520bbcf --- /dev/null +++ b/modules/vector-sets/cJSON.h @@ -0,0 +1,306 @@ +/* + Copyright (c) 2009-2017 Dave Gamble and cJSON contributors + + Permission is hereby granted, free of charge, to any person obtaining a copy + of this software and associated documentation files (the "Software"), to deal + in the Software without restriction, including without limitation the rights + to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + copies of the Software, and to permit persons to whom the Software is + furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in + all copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + THE SOFTWARE. +*/ + +#ifndef cJSON__h +#define cJSON__h + +#ifdef __cplusplus +extern "C" +{ +#endif + +#if !defined(__WINDOWS__) && (defined(WIN32) || defined(WIN64) || defined(_MSC_VER) || defined(_WIN32)) +#define __WINDOWS__ +#endif + +#ifdef __WINDOWS__ + +/* When compiling for windows, we specify a specific calling convention to avoid issues where we are being called from a project with a different default calling convention. For windows you have 3 define options: + +CJSON_HIDE_SYMBOLS - Define this in the case where you don't want to ever dllexport symbols +CJSON_EXPORT_SYMBOLS - Define this on library build when you want to dllexport symbols (default) +CJSON_IMPORT_SYMBOLS - Define this if you want to dllimport symbol + +For *nix builds that support visibility attribute, you can define similar behavior by + +setting default visibility to hidden by adding +-fvisibility=hidden (for gcc) +or +-xldscope=hidden (for sun cc) +to CFLAGS + +then using the CJSON_API_VISIBILITY flag to "export" the same symbols the way CJSON_EXPORT_SYMBOLS does + +*/ + +#define CJSON_CDECL __cdecl +#define CJSON_STDCALL __stdcall + +/* export symbols by default, this is necessary for copy pasting the C and header file */ +#if !defined(CJSON_HIDE_SYMBOLS) && !defined(CJSON_IMPORT_SYMBOLS) && !defined(CJSON_EXPORT_SYMBOLS) +#define CJSON_EXPORT_SYMBOLS +#endif + +#if defined(CJSON_HIDE_SYMBOLS) +#define CJSON_PUBLIC(type) type CJSON_STDCALL +#elif defined(CJSON_EXPORT_SYMBOLS) +#define CJSON_PUBLIC(type) __declspec(dllexport) type CJSON_STDCALL +#elif defined(CJSON_IMPORT_SYMBOLS) +#define CJSON_PUBLIC(type) __declspec(dllimport) type CJSON_STDCALL +#endif +#else /* !__WINDOWS__ */ +#define CJSON_CDECL +#define CJSON_STDCALL + +#if (defined(__GNUC__) || defined(__SUNPRO_CC) || defined (__SUNPRO_C)) && defined(CJSON_API_VISIBILITY) +#define CJSON_PUBLIC(type) __attribute__((visibility("default"))) type +#else +#define CJSON_PUBLIC(type) type +#endif +#endif + +/* project version */ +#define CJSON_VERSION_MAJOR 1 +#define CJSON_VERSION_MINOR 7 +#define CJSON_VERSION_PATCH 18 + +#include + +/* cJSON Types: */ +#define cJSON_Invalid (0) +#define cJSON_False (1 << 0) +#define cJSON_True (1 << 1) +#define cJSON_NULL (1 << 2) +#define cJSON_Number (1 << 3) +#define cJSON_String (1 << 4) +#define cJSON_Array (1 << 5) +#define cJSON_Object (1 << 6) +#define cJSON_Raw (1 << 7) /* raw json */ + +#define cJSON_IsReference 256 +#define cJSON_StringIsConst 512 + +/* The cJSON structure: */ +typedef struct cJSON +{ + /* next/prev allow you to walk array/object chains. Alternatively, use GetArraySize/GetArrayItem/GetObjectItem */ + struct cJSON *next; + struct cJSON *prev; + /* An array or object item will have a child pointer pointing to a chain of the items in the array/object. */ + struct cJSON *child; + + /* The type of the item, as above. */ + int type; + + /* The item's string, if type==cJSON_String and type == cJSON_Raw */ + char *valuestring; + /* writing to valueint is DEPRECATED, use cJSON_SetNumberValue instead */ + int valueint; + /* The item's number, if type==cJSON_Number */ + double valuedouble; + + /* The item's name string, if this item is the child of, or is in the list of subitems of an object. */ + char *string; +} cJSON; + +typedef struct cJSON_Hooks +{ + /* malloc/free are CDECL on Windows regardless of the default calling convention of the compiler, so ensure the hooks allow passing those functions directly. */ + void *(CJSON_CDECL *malloc_fn)(size_t sz); + void (CJSON_CDECL *free_fn)(void *ptr); +} cJSON_Hooks; + +typedef int cJSON_bool; + +/* Limits how deeply nested arrays/objects can be before cJSON rejects to parse them. + * This is to prevent stack overflows. */ +#ifndef CJSON_NESTING_LIMIT +#define CJSON_NESTING_LIMIT 1000 +#endif + +/* Limits the length of circular references can be before cJSON rejects to parse them. + * This is to prevent stack overflows. */ +#ifndef CJSON_CIRCULAR_LIMIT +#define CJSON_CIRCULAR_LIMIT 10000 +#endif + +/* returns the version of cJSON as a string */ +CJSON_PUBLIC(const char*) cJSON_Version(void); + +/* Supply malloc, realloc and free functions to cJSON */ +CJSON_PUBLIC(void) cJSON_InitHooks(cJSON_Hooks* hooks); + +/* Memory Management: the caller is always responsible to free the results from all variants of cJSON_Parse (with cJSON_Delete) and cJSON_Print (with stdlib free, cJSON_Hooks.free_fn, or cJSON_free as appropriate). The exception is cJSON_PrintPreallocated, where the caller has full responsibility of the buffer. */ +/* Supply a block of JSON, and this returns a cJSON object you can interrogate. */ +CJSON_PUBLIC(cJSON *) cJSON_Parse(const char *value); +CJSON_PUBLIC(cJSON *) cJSON_ParseWithLength(const char *value, size_t buffer_length); +/* ParseWithOpts allows you to require (and check) that the JSON is null terminated, and to retrieve the pointer to the final byte parsed. */ +/* If you supply a ptr in return_parse_end and parsing fails, then return_parse_end will contain a pointer to the error so will match cJSON_GetErrorPtr(). */ +CJSON_PUBLIC(cJSON *) cJSON_ParseWithOpts(const char *value, const char **return_parse_end, cJSON_bool require_null_terminated); +CJSON_PUBLIC(cJSON *) cJSON_ParseWithLengthOpts(const char *value, size_t buffer_length, const char **return_parse_end, cJSON_bool require_null_terminated); + +/* Render a cJSON entity to text for transfer/storage. */ +CJSON_PUBLIC(char *) cJSON_Print(const cJSON *item); +/* Render a cJSON entity to text for transfer/storage without any formatting. */ +CJSON_PUBLIC(char *) cJSON_PrintUnformatted(const cJSON *item); +/* Render a cJSON entity to text using a buffered strategy. prebuffer is a guess at the final size. guessing well reduces reallocation. fmt=0 gives unformatted, =1 gives formatted */ +CJSON_PUBLIC(char *) cJSON_PrintBuffered(const cJSON *item, int prebuffer, cJSON_bool fmt); +/* Render a cJSON entity to text using a buffer already allocated in memory with given length. Returns 1 on success and 0 on failure. */ +/* NOTE: cJSON is not always 100% accurate in estimating how much memory it will use, so to be safe allocate 5 bytes more than you actually need */ +CJSON_PUBLIC(cJSON_bool) cJSON_PrintPreallocated(cJSON *item, char *buffer, const int length, const cJSON_bool format); +/* Delete a cJSON entity and all subentities. */ +CJSON_PUBLIC(void) cJSON_Delete(cJSON *item); + +/* Returns the number of items in an array (or object). */ +CJSON_PUBLIC(int) cJSON_GetArraySize(const cJSON *array); +/* Retrieve item number "index" from array "array". Returns NULL if unsuccessful. */ +CJSON_PUBLIC(cJSON *) cJSON_GetArrayItem(const cJSON *array, int index); +/* Get item "string" from object. Case insensitive. */ +CJSON_PUBLIC(cJSON *) cJSON_GetObjectItem(const cJSON * const object, const char * const string); +CJSON_PUBLIC(cJSON *) cJSON_GetObjectItemCaseSensitive(const cJSON * const object, const char * const string); +CJSON_PUBLIC(cJSON_bool) cJSON_HasObjectItem(const cJSON *object, const char *string); +/* For analysing failed parses. This returns a pointer to the parse error. You'll probably need to look a few chars back to make sense of it. Defined when cJSON_Parse() returns 0. 0 when cJSON_Parse() succeeds. */ +CJSON_PUBLIC(const char *) cJSON_GetErrorPtr(void); + +/* Check item type and return its value */ +CJSON_PUBLIC(char *) cJSON_GetStringValue(const cJSON * const item); +CJSON_PUBLIC(double) cJSON_GetNumberValue(const cJSON * const item); + +/* These functions check the type of an item */ +CJSON_PUBLIC(cJSON_bool) cJSON_IsInvalid(const cJSON * const item); +CJSON_PUBLIC(cJSON_bool) cJSON_IsFalse(const cJSON * const item); +CJSON_PUBLIC(cJSON_bool) cJSON_IsTrue(const cJSON * const item); +CJSON_PUBLIC(cJSON_bool) cJSON_IsBool(const cJSON * const item); +CJSON_PUBLIC(cJSON_bool) cJSON_IsNull(const cJSON * const item); +CJSON_PUBLIC(cJSON_bool) cJSON_IsNumber(const cJSON * const item); +CJSON_PUBLIC(cJSON_bool) cJSON_IsString(const cJSON * const item); +CJSON_PUBLIC(cJSON_bool) cJSON_IsArray(const cJSON * const item); +CJSON_PUBLIC(cJSON_bool) cJSON_IsObject(const cJSON * const item); +CJSON_PUBLIC(cJSON_bool) cJSON_IsRaw(const cJSON * const item); + +/* These calls create a cJSON item of the appropriate type. */ +CJSON_PUBLIC(cJSON *) cJSON_CreateNull(void); +CJSON_PUBLIC(cJSON *) cJSON_CreateTrue(void); +CJSON_PUBLIC(cJSON *) cJSON_CreateFalse(void); +CJSON_PUBLIC(cJSON *) cJSON_CreateBool(cJSON_bool boolean); +CJSON_PUBLIC(cJSON *) cJSON_CreateNumber(double num); +CJSON_PUBLIC(cJSON *) cJSON_CreateString(const char *string); +/* raw json */ +CJSON_PUBLIC(cJSON *) cJSON_CreateRaw(const char *raw); +CJSON_PUBLIC(cJSON *) cJSON_CreateArray(void); +CJSON_PUBLIC(cJSON *) cJSON_CreateObject(void); + +/* Create a string where valuestring references a string so + * it will not be freed by cJSON_Delete */ +CJSON_PUBLIC(cJSON *) cJSON_CreateStringReference(const char *string); +/* Create an object/array that only references it's elements so + * they will not be freed by cJSON_Delete */ +CJSON_PUBLIC(cJSON *) cJSON_CreateObjectReference(const cJSON *child); +CJSON_PUBLIC(cJSON *) cJSON_CreateArrayReference(const cJSON *child); + +/* These utilities create an Array of count items. + * The parameter count cannot be greater than the number of elements in the number array, otherwise array access will be out of bounds.*/ +CJSON_PUBLIC(cJSON *) cJSON_CreateIntArray(const int *numbers, int count); +CJSON_PUBLIC(cJSON *) cJSON_CreateFloatArray(const float *numbers, int count); +CJSON_PUBLIC(cJSON *) cJSON_CreateDoubleArray(const double *numbers, int count); +CJSON_PUBLIC(cJSON *) cJSON_CreateStringArray(const char *const *strings, int count); + +/* Append item to the specified array/object. */ +CJSON_PUBLIC(cJSON_bool) cJSON_AddItemToArray(cJSON *array, cJSON *item); +CJSON_PUBLIC(cJSON_bool) cJSON_AddItemToObject(cJSON *object, const char *string, cJSON *item); +/* Use this when string is definitely const (i.e. a literal, or as good as), and will definitely survive the cJSON object. + * WARNING: When this function was used, make sure to always check that (item->type & cJSON_StringIsConst) is zero before + * writing to `item->string` */ +CJSON_PUBLIC(cJSON_bool) cJSON_AddItemToObjectCS(cJSON *object, const char *string, cJSON *item); +/* Append reference to item to the specified array/object. Use this when you want to add an existing cJSON to a new cJSON, but don't want to corrupt your existing cJSON. */ +CJSON_PUBLIC(cJSON_bool) cJSON_AddItemReferenceToArray(cJSON *array, cJSON *item); +CJSON_PUBLIC(cJSON_bool) cJSON_AddItemReferenceToObject(cJSON *object, const char *string, cJSON *item); + +/* Remove/Detach items from Arrays/Objects. */ +CJSON_PUBLIC(cJSON *) cJSON_DetachItemViaPointer(cJSON *parent, cJSON * const item); +CJSON_PUBLIC(cJSON *) cJSON_DetachItemFromArray(cJSON *array, int which); +CJSON_PUBLIC(void) cJSON_DeleteItemFromArray(cJSON *array, int which); +CJSON_PUBLIC(cJSON *) cJSON_DetachItemFromObject(cJSON *object, const char *string); +CJSON_PUBLIC(cJSON *) cJSON_DetachItemFromObjectCaseSensitive(cJSON *object, const char *string); +CJSON_PUBLIC(void) cJSON_DeleteItemFromObject(cJSON *object, const char *string); +CJSON_PUBLIC(void) cJSON_DeleteItemFromObjectCaseSensitive(cJSON *object, const char *string); + +/* Update array items. */ +CJSON_PUBLIC(cJSON_bool) cJSON_InsertItemInArray(cJSON *array, int which, cJSON *newitem); /* Shifts pre-existing items to the right. */ +CJSON_PUBLIC(cJSON_bool) cJSON_ReplaceItemViaPointer(cJSON * const parent, cJSON * const item, cJSON * replacement); +CJSON_PUBLIC(cJSON_bool) cJSON_ReplaceItemInArray(cJSON *array, int which, cJSON *newitem); +CJSON_PUBLIC(cJSON_bool) cJSON_ReplaceItemInObject(cJSON *object,const char *string,cJSON *newitem); +CJSON_PUBLIC(cJSON_bool) cJSON_ReplaceItemInObjectCaseSensitive(cJSON *object,const char *string,cJSON *newitem); + +/* Duplicate a cJSON item */ +CJSON_PUBLIC(cJSON *) cJSON_Duplicate(const cJSON *item, cJSON_bool recurse); +/* Duplicate will create a new, identical cJSON item to the one you pass, in new memory that will + * need to be released. With recurse!=0, it will duplicate any children connected to the item. + * The item->next and ->prev pointers are always zero on return from Duplicate. */ +/* Recursively compare two cJSON items for equality. If either a or b is NULL or invalid, they will be considered unequal. + * case_sensitive determines if object keys are treated case sensitive (1) or case insensitive (0) */ +CJSON_PUBLIC(cJSON_bool) cJSON_Compare(const cJSON * const a, const cJSON * const b, const cJSON_bool case_sensitive); + +/* Minify a strings, remove blank characters(such as ' ', '\t', '\r', '\n') from strings. + * The input pointer json cannot point to a read-only address area, such as a string constant, + * but should point to a readable and writable address area. */ +CJSON_PUBLIC(void) cJSON_Minify(char *json); + +/* Helper functions for creating and adding items to an object at the same time. + * They return the added item or NULL on failure. */ +CJSON_PUBLIC(cJSON*) cJSON_AddNullToObject(cJSON * const object, const char * const name); +CJSON_PUBLIC(cJSON*) cJSON_AddTrueToObject(cJSON * const object, const char * const name); +CJSON_PUBLIC(cJSON*) cJSON_AddFalseToObject(cJSON * const object, const char * const name); +CJSON_PUBLIC(cJSON*) cJSON_AddBoolToObject(cJSON * const object, const char * const name, const cJSON_bool boolean); +CJSON_PUBLIC(cJSON*) cJSON_AddNumberToObject(cJSON * const object, const char * const name, const double number); +CJSON_PUBLIC(cJSON*) cJSON_AddStringToObject(cJSON * const object, const char * const name, const char * const string); +CJSON_PUBLIC(cJSON*) cJSON_AddRawToObject(cJSON * const object, const char * const name, const char * const raw); +CJSON_PUBLIC(cJSON*) cJSON_AddObjectToObject(cJSON * const object, const char * const name); +CJSON_PUBLIC(cJSON*) cJSON_AddArrayToObject(cJSON * const object, const char * const name); + +/* When assigning an integer value, it needs to be propagated to valuedouble too. */ +#define cJSON_SetIntValue(object, number) ((object) ? (object)->valueint = (object)->valuedouble = (number) : (number)) +/* helper for the cJSON_SetNumberValue macro */ +CJSON_PUBLIC(double) cJSON_SetNumberHelper(cJSON *object, double number); +#define cJSON_SetNumberValue(object, number) ((object != NULL) ? cJSON_SetNumberHelper(object, (double)number) : (number)) +/* Change the valuestring of a cJSON_String object, only takes effect when type of object is cJSON_String */ +CJSON_PUBLIC(char*) cJSON_SetValuestring(cJSON *object, const char *valuestring); + +/* If the object is not a boolean type this does nothing and returns cJSON_Invalid else it returns the new type*/ +#define cJSON_SetBoolValue(object, boolValue) ( \ + (object != NULL && ((object)->type & (cJSON_False|cJSON_True))) ? \ + (object)->type=((object)->type &(~(cJSON_False|cJSON_True)))|((boolValue)?cJSON_True:cJSON_False) : \ + cJSON_Invalid\ +) + +/* Macro for iterating over an array or object */ +#define cJSON_ArrayForEach(element, array) for(element = (array != NULL) ? (array)->child : NULL; element != NULL; element = element->next) + +/* malloc/free objects using the malloc/free functions that have been set with cJSON_InitHooks */ +CJSON_PUBLIC(void *) cJSON_malloc(size_t size); +CJSON_PUBLIC(void) cJSON_free(void *object); + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/modules/vector-sets/examples/cli-tool/.gitignore b/modules/vector-sets/examples/cli-tool/.gitignore new file mode 100644 index 000000000..5ceb3864c --- /dev/null +++ b/modules/vector-sets/examples/cli-tool/.gitignore @@ -0,0 +1 @@ +venv diff --git a/modules/vector-sets/examples/cli-tool/README.md b/modules/vector-sets/examples/cli-tool/README.md new file mode 100644 index 000000000..ad217447f --- /dev/null +++ b/modules/vector-sets/examples/cli-tool/README.md @@ -0,0 +1,44 @@ +This tool is similar to redis-cli (but very basic) but allows +to specify arguments that are expanded as vectors by calling +ollama to get the embedding. + +Whatever is passed as !"foo bar" gets expanded into + VALUES ... embedding ... + +You must have ollama running with the mxbai-emb-large model +already installed for this to work. + +Example: + + redis> KEYS * + 1) food_items + 2) glove_embeddings_bin + 3) many_movies_mxbai-embed-large_BIN + 4) many_movies_mxbai-embed-large_NOQUANT + 5) word_embeddings + 6) word_embeddings_bin + 7) glove_embeddings_fp32 + + redis> VSIM food_items !"drinks with fruit" + 1) (Fruit)Juices,Lemonade,100ml,50 cal,210 kJ + 2) (Fruit)Juices,Limeade,100ml,128 cal,538 kJ + 3) CannedFruit,Canned Fruit Cocktail,100g,81 cal,340 kJ + 4) (Fruit)Juices,Energy-Drink,100ml,87 cal,365 kJ + 5) Fruits,Lime,100g,30 cal,126 kJ + 6) (Fruit)Juices,Coconut Water,100ml,19 cal,80 kJ + 7) Fruits,Lemon,100g,29 cal,122 kJ + 8) (Fruit)Juices,Clamato,100ml,60 cal,252 kJ + 9) Fruits,Fruit salad,100g,50 cal,210 kJ + 10) (Fruit)Juices,Capri-Sun,100ml,41 cal,172 kJ + + redis> vsim food_items !"barilla" + 1) Pasta&Noodles,Spirelli,100g,367 cal,1541 kJ + 2) Pasta&Noodles,Farfalle,100g,358 cal,1504 kJ + 3) Pasta&Noodles,Capellini,100g,353 cal,1483 kJ + 4) Pasta&Noodles,Spaetzle,100g,368 cal,1546 kJ + 5) Pasta&Noodles,Cappelletti,100g,164 cal,689 kJ + 6) Pasta&Noodles,Penne,100g,351 cal,1474 kJ + 7) Pasta&Noodles,Shells,100g,353 cal,1483 kJ + 8) Pasta&Noodles,Linguine,100g,357 cal,1499 kJ + 9) Pasta&Noodles,Rotini,100g,353 cal,1483 kJ + 10) Pasta&Noodles,Rigatoni,100g,353 cal,1483 kJ diff --git a/modules/vector-sets/examples/cli-tool/cli.py b/modules/vector-sets/examples/cli-tool/cli.py new file mode 100755 index 000000000..a60c5facc --- /dev/null +++ b/modules/vector-sets/examples/cli-tool/cli.py @@ -0,0 +1,138 @@ +#!/usr/bin/env python3 +import redis +import requests +import re +import shlex +from prompt_toolkit import PromptSession +from prompt_toolkit.history import InMemoryHistory + +def get_embedding(text): + """Get embedding from local Ollama API""" + url = "http://localhost:11434/api/embeddings" + payload = { + "model": "mxbai-embed-large", + "prompt": text + } + try: + response = requests.post(url, json=payload) + response.raise_for_status() + return response.json()['embedding'] + except requests.exceptions.RequestException as e: + raise Exception(f"Failed to get embedding: {str(e)}") + +def process_embedding_patterns(text): + """Process !"text" and !!"text" patterns in the command""" + + def replace_with_embedding(match): + text = match.group(1) + embedding = get_embedding(text) + return f"VALUES {len(embedding)} {' '.join(map(str, embedding))}" + + def replace_with_embedding_and_text(match): + text = match.group(1) + embedding = get_embedding(text) + # Return both the embedding values and the original text as next argument + return f'VALUES {len(embedding)} {" ".join(map(str, embedding))} "{text}"' + + # First handle !!"text" pattern (must be done before !"text") + text = re.sub(r'!!"([^"]*)"', replace_with_embedding_and_text, text) + # Then handle !"text" pattern + text = re.sub(r'!"([^"]*)"', replace_with_embedding, text) + return text + +def parse_command(command): + """Parse command respecting quoted strings""" + try: + # Use shlex to properly handle quoted strings + return shlex.split(command) + except ValueError as e: + raise Exception(f"Invalid command syntax: {str(e)}") + +def format_response(response): + """Format the response to match Redis protocol style""" + if response is None: + return "(nil)" + elif isinstance(response, bool): + return "+OK" if response else "(error) Operation failed" + elif isinstance(response, (list, set)): + if not response: + return "(empty list or set)" + return "\n".join(f"{i+1}) {item}" for i, item in enumerate(response)) + elif isinstance(response, int): + return f"(integer) {response}" + else: + return str(response) + +def main(): + # Default connection to localhost:6379 + r = redis.Redis(host='localhost', port=6379, decode_responses=True) + + try: + # Test connection + r.ping() + print("Connected to Redis. Type your commands (CTRL+D to exit):") + print("Special syntax:") + print(" !\"text\" - Replace with embedding") + print(" !!\"text\" - Replace with embedding and append text as value") + print(" \"text\" - Quote strings containing spaces") + except redis.ConnectionError: + print("Error: Could not connect to Redis server") + return + + # Setup prompt session with history + session = PromptSession(history=InMemoryHistory()) + + # Main loop + while True: + try: + # Read input with line editing support + command = session.prompt("redis> ") + + # Skip empty commands + if not command.strip(): + continue + + # Process any embedding patterns before parsing + try: + processed_command = process_embedding_patterns(command) + except Exception as e: + print(f"(error) Embedding processing failed: {str(e)}") + continue + + # Parse the command respecting quoted strings + try: + parts = parse_command(processed_command) + except Exception as e: + print(f"(error) {str(e)}") + continue + + if not parts: + continue + + cmd = parts[0].lower() + args = parts[1:] + + # Execute command + try: + method = getattr(r, cmd, None) + if method is not None: + result = method(*args) + else: + # Use execute_command for unknown commands + result = r.execute_command(cmd, *args) + print(format_response(result)) + except AttributeError: + print(f"(error) Unknown command '{cmd}'") + + except EOFError: + print("\nGoodbye!") + break + except KeyboardInterrupt: + continue # Allow Ctrl+C to clear current line + except redis.RedisError as e: + print(f"(error) {str(e)}") + except Exception as e: + print(f"(error) {str(e)}") + +if __name__ == "__main__": + main() diff --git a/modules/vector-sets/examples/glove-100/README b/modules/vector-sets/examples/glove-100/README new file mode 100644 index 000000000..e8bb6dd6f --- /dev/null +++ b/modules/vector-sets/examples/glove-100/README @@ -0,0 +1,3 @@ +wget http://ann-benchmarks.com/glove-100-angular.hdf5 +python insert.py +python recall.py (use --k optionally, default top-10) diff --git a/modules/vector-sets/examples/glove-100/insert.py b/modules/vector-sets/examples/glove-100/insert.py new file mode 100644 index 000000000..fe9658343 --- /dev/null +++ b/modules/vector-sets/examples/glove-100/insert.py @@ -0,0 +1,47 @@ +import h5py +import redis +from tqdm import tqdm + +# Initialize Redis connection +redis_client = redis.Redis(host='localhost', port=6379, decode_responses=True, encoding='utf-8') + +def add_to_redis(index, embedding): + """Add embedding to Redis using VADD command""" + args = ["VADD", "glove_embeddings", "VALUES", "100"] # 100 is vector dimension + args.extend(map(str, embedding)) + args.append(f"{index}") # Using index as identifier since we don't have words + args.append("EF") + args.append("200") + # args.append("NOQUANT") + # args.append("BIN") + redis_client.execute_command(*args) + +def main(): + with h5py.File('glove-100-angular.hdf5', 'r') as f: + # Get the train dataset + train_vectors = f['train'] + total_vectors = train_vectors.shape[0] + + print(f"Starting to process {total_vectors} vectors...") + + # Process in batches to avoid memory issues + batch_size = 1000 + + for i in tqdm(range(0, total_vectors, batch_size)): + batch_end = min(i + batch_size, total_vectors) + batch = train_vectors[i:batch_end] + + for j, vector in enumerate(batch): + try: + current_index = i + j + add_to_redis(current_index, vector) + + except Exception as e: + print(f"Error processing vector {current_index}: {str(e)}") + continue + + if (i + batch_size) % 10000 == 0: + print(f"Processed {i + batch_size} vectors") + +if __name__ == "__main__": + main() diff --git a/modules/vector-sets/examples/glove-100/recall.py b/modules/vector-sets/examples/glove-100/recall.py new file mode 100644 index 000000000..28982b3e9 --- /dev/null +++ b/modules/vector-sets/examples/glove-100/recall.py @@ -0,0 +1,78 @@ +import h5py +import redis +import numpy as np +from tqdm import tqdm +import argparse + +# Initialize Redis connection +redis_client = redis.Redis(host='localhost', port=6379, decode_responses=True, encoding='utf-8') + +def get_redis_neighbors(query_vector, k): + """Get nearest neighbors using Redis VSIM command""" + args = ["VSIM", "glove_embeddings_bin", "VALUES", "100"] + args.extend(map(str, query_vector)) + args.extend(["COUNT", str(k)]) + args.extend(["EF", 100]) + if False: + print(args) + exit(1) + results = redis_client.execute_command(*args) + return [int(res) for res in results] + +def calculate_recall(ground_truth, predicted, k): + """Calculate recall@k""" + relevant = set(ground_truth[:k]) + retrieved = set(predicted[:k]) + return len(relevant.intersection(retrieved)) / len(relevant) + +def main(): + parser = argparse.ArgumentParser(description='Evaluate Redis VSIM recall') + parser.add_argument('--k', type=int, default=10, help='Number of neighbors to evaluate (default: 10)') + parser.add_argument('--batch', type=int, default=100, help='Progress update frequency (default: 100)') + args = parser.parse_args() + + k = args.k + batch_size = args.batch + + with h5py.File('glove-100-angular.hdf5', 'r') as f: + test_vectors = f['test'][:] + ground_truth_neighbors = f['neighbors'][:] + + num_queries = len(test_vectors) + recalls = [] + + print(f"Evaluating recall@{k} for {num_queries} test queries...") + + for i in tqdm(range(num_queries)): + try: + # Get Redis results + redis_neighbors = get_redis_neighbors(test_vectors[i], k) + + # Get ground truth for this query + true_neighbors = ground_truth_neighbors[i] + + # Calculate recall + recall = calculate_recall(true_neighbors, redis_neighbors, k) + recalls.append(recall) + + if (i + 1) % batch_size == 0: + current_avg_recall = np.mean(recalls) + print(f"Current average recall@{k} after {i+1} queries: {current_avg_recall:.4f}") + + except Exception as e: + print(f"Error processing query {i}: {str(e)}") + continue + + final_recall = np.mean(recalls) + print("\nFinal Results:") + print(f"Average recall@{k}: {final_recall:.4f}") + print(f"Total queries evaluated: {len(recalls)}") + + # Save detailed results + with open(f'recall_evaluation_results_k{k}.txt', 'w') as f: + f.write(f"Average recall@{k}: {final_recall:.4f}\n") + f.write(f"Total queries evaluated: {len(recalls)}\n") + f.write(f"Individual query recalls: {recalls}\n") + +if __name__ == "__main__": + main() diff --git a/modules/vector-sets/examples/movies/.gitignore b/modules/vector-sets/examples/movies/.gitignore new file mode 100644 index 000000000..e736c6ad2 --- /dev/null +++ b/modules/vector-sets/examples/movies/.gitignore @@ -0,0 +1,2 @@ +mpst_full_data.csv +partition.json diff --git a/modules/vector-sets/examples/movies/README b/modules/vector-sets/examples/movies/README new file mode 100644 index 000000000..3931a6d2f --- /dev/null +++ b/modules/vector-sets/examples/movies/README @@ -0,0 +1,30 @@ +This example maps long form movies plots to movies titles. +It will create fp32 and binary vectors (the two extremes). + +1. Install ollama, and install the embedding model "mxbai-embed-large" +2. Download mpst_full_data.csv from https://www.kaggle.com/datasets/cryptexcode/mpst-movie-plot-synopses-with-tags +3. python insert.py + +127.0.0.1:6379> VSIM many_movies_mxbai-embed-large_NOQUANT ELE "The Matrix" + 1) "The Matrix" + 2) "The Matrix Reloaded" + 3) "The Matrix Revolutions" + 4) "Commando" + 5) "Avatar" + 6) "Forbidden Planet" + 7) "Terminator Salvation" + 8) "Mandroid" + 9) "The Omega Code" +10) "Coherence" + +127.0.0.1:6379> VSIM many_movies_mxbai-embed-large_BIN ELE "The Matrix" + 1) "The Matrix" + 2) "The Matrix Reloaded" + 3) "The Matrix Revolutions" + 4) "The Omega Code" + 5) "Forbidden Planet" + 6) "Avatar" + 7) "John Carter" + 8) "System Shock 2" + 9) "Coherence" +10) "Tomorrowland" diff --git a/modules/vector-sets/examples/movies/insert.py b/modules/vector-sets/examples/movies/insert.py new file mode 100644 index 000000000..576243667 --- /dev/null +++ b/modules/vector-sets/examples/movies/insert.py @@ -0,0 +1,48 @@ +import csv +import requests +import redis + +ModelName="mxbai-embed-large" + +# Initialize Redis connection, setting encoding to utf-8 +redis_client = redis.Redis(host='localhost', port=6379, decode_responses=True, encoding='utf-8') + +def get_embedding(text): + """Get embedding from local API""" + url = "http://localhost:11434/api/embeddings" + payload = { + "model": ModelName, + "prompt": "Represent this movie plot and genre: "+text + } + response = requests.post(url, json=payload) + return response.json()['embedding'] + +def add_to_redis(title, embedding, quant_type): + """Add embedding to Redis using VADD command""" + args = ["VADD", "many_movies_"+ModelName+"_"+quant_type, "VALUES", str(len(embedding))] + args.extend(map(str, embedding)) + args.append(title) + args.append(quant_type) + redis_client.execute_command(*args) + +def main(): + with open('mpst_full_data.csv', 'r', encoding='utf-8') as file: + reader = csv.DictReader(file) + + for movie in reader: + try: + text_to_embed = f"{movie['title']} {movie['plot_synopsis']} {movie['tags']}" + + print(f"Getting embedding for: {movie['title']}") + embedding = get_embedding(text_to_embed) + + add_to_redis(movie['title'], embedding, "BIN") + add_to_redis(movie['title'], embedding, "NOQUANT") + print(f"Successfully processed: {movie['title']}") + + except Exception as e: + print(f"Error processing {movie['title']}: {str(e)}") + continue + +if __name__ == "__main__": + main() diff --git a/modules/vector-sets/expr.c b/modules/vector-sets/expr.c new file mode 100644 index 000000000..d9712921e --- /dev/null +++ b/modules/vector-sets/expr.c @@ -0,0 +1,995 @@ +/* Filtering of objects based on simple expressions. + * This powers the FILTER option of Vector Sets, but it is otherwise + * general code to be used when we want to tell if a given object (with fields) + * passes or fails a given test for scalars, strings, ... + * + * Copyright(C) 2024-Present, Redis Ltd. All Rights Reserved. + * Originally authored by: Salvatore Sanfilippo. + */ + +#include +#include +#include +#include +#include +#include "cJSON.h" + +#ifdef TEST_MAIN +#define RedisModule_Alloc malloc +#define RedisModule_Realloc realloc +#define RedisModule_Free free +#define RedisModule_Strdup strdup +#endif + +#define EXPR_TOKEN_EOF 0 +#define EXPR_TOKEN_NUM 1 +#define EXPR_TOKEN_STR 2 +#define EXPR_TOKEN_TUPLE 3 +#define EXPR_TOKEN_SELECTOR 4 +#define EXPR_TOKEN_OP 5 + +#define EXPR_OP_OPAREN 0 /* ( */ +#define EXPR_OP_CPAREN 1 /* ) */ +#define EXPR_OP_NOT 2 /* ! */ +#define EXPR_OP_POW 3 /* ** */ +#define EXPR_OP_MULT 4 /* * */ +#define EXPR_OP_DIV 5 /* / */ +#define EXPR_OP_MOD 6 /* % */ +#define EXPR_OP_SUM 7 /* + */ +#define EXPR_OP_DIFF 8 /* - */ +#define EXPR_OP_GT 9 /* > */ +#define EXPR_OP_GTE 10 /* >= */ +#define EXPR_OP_LT 11 /* < */ +#define EXPR_OP_LTE 12 /* <= */ +#define EXPR_OP_EQ 13 /* == */ +#define EXPR_OP_NEQ 14 /* != */ +#define EXPR_OP_IN 15 /* in */ +#define EXPR_OP_AND 16 /* and */ +#define EXPR_OP_OR 17 /* or */ + +/* This structure represents a token in our expression. It's either + * literals like 4, "foo", or operators like "+", "-", "and", or + * json selectors, that start with a dot: ".age", ".properties.somearray[1]" */ +typedef struct exprtoken { + int refcount; // Reference counting for memory reclaiming. + int token_type; // Token type of the just parsed token. + int offset; // Chars offset in expression. + union { + double num; // Value for EXPR_TOKEN_NUM. + struct { + char *start; // String pointer for EXPR_TOKEN_STR / SELECTOR. + size_t len; // String len for EXPR_TOKEN_STR / SELECTOR. + char *heapstr; // True if we have a private allocation for this + // string. When possible, it just references to the + // string expression we compiled, exprstate->expr. + } str; + int opcode; // Opcode ID for EXPR_TOKEN_OP. + struct { + struct exprtoken **ele; + size_t len; + } tuple; // Tuples are like [1, 2, 3] for "in" operator. + }; +} exprtoken; + +/* Simple stack of expr tokens. This is used both to represent the stack + * of values and the stack of operands during VM execution. */ +typedef struct exprstack { + exprtoken **items; + int numitems; + int allocsize; +} exprstack; + +typedef struct exprstate { + char *expr; /* Expression string to compile. Note that + * expression token strings point directly to this + * string. */ + char *p; // Currnet position inside 'expr', while parsing. + + // Virtual machine state. + exprstack values_stack; + exprstack ops_stack; // Operator stack used during compilation. + exprstack tokens; // Expression processed into a sequence of tokens. + exprstack program; // Expression compiled into opcodes and values. +} exprstate; + +/* Valid operators. */ +struct { + char *opname; + int oplen; + int opcode; + int precedence; + int arity; +} ExprOptable[] = { + {"(", 1, EXPR_OP_OPAREN, 7, 0}, + {")", 1, EXPR_OP_CPAREN, 7, 0}, + {"!", 1, EXPR_OP_NOT, 6, 1}, + {"not", 3, EXPR_OP_NOT, 6, 1}, + {"**", 2, EXPR_OP_POW, 5, 2}, + {"*", 1, EXPR_OP_MULT, 4, 2}, + {"/", 1, EXPR_OP_DIV, 4, 2}, + {"%", 1, EXPR_OP_MOD, 4, 2}, + {"+", 1, EXPR_OP_SUM, 3, 2}, + {"-", 1, EXPR_OP_DIFF, 3, 2}, + {">", 1, EXPR_OP_GT, 2, 2}, + {">=", 2, EXPR_OP_GTE, 2, 2}, + {"<", 1, EXPR_OP_LT, 2, 2}, + {"<=", 2, EXPR_OP_LTE, 2, 2}, + {"==", 2, EXPR_OP_EQ, 2, 2}, + {"!=", 2, EXPR_OP_NEQ, 2, 2}, + {"in", 2, EXPR_OP_IN, 2, 2}, + {"and", 3, EXPR_OP_AND, 1, 2}, + {"&&", 2, EXPR_OP_AND, 1, 2}, + {"or", 2, EXPR_OP_OR, 0, 2}, + {"||", 2, EXPR_OP_OR, 0, 2}, + {NULL, 0, 0, 0, 0} // Terminator. +}; + +#define EXPR_OP_SPECIALCHARS "+-*%/!()<>=|&" +#define EXPR_SELECTOR_SPECIALCHARS "_-" + +/* ================================ Expr token ============================== */ + +/* Return an heap allocated token of the specified type, setting the + * reference count to 1. */ +exprtoken *exprNewToken(int type) { + exprtoken *t = RedisModule_Alloc(sizeof(exprtoken)); + memset(t,0,sizeof(*t)); + t->token_type = type; + t->refcount = 1; + return t; +} + +/* Generic free token function, can be used to free stack allocated + * objects (in this case the pointer itself will not be freed) or + * heap allocated objects. See the wrappers below. */ +void exprTokenRelease(exprtoken *t) { + if (t == NULL) return; + + if (t->refcount <= 0) { + printf("exprTokenRelease() against a token with refcount %d!\n" + "Aborting program execution\n", + t->refcount); + exit(1); + } + t->refcount--; + if (t->refcount > 0) return; + + // We reached refcount 0: free the object. + if (t->token_type == EXPR_TOKEN_STR) { + if (t->str.heapstr != NULL) RedisModule_Free(t->str.heapstr); + } else if (t->token_type == EXPR_TOKEN_TUPLE) { + for (size_t j = 0; j < t->tuple.len; j++) + exprTokenRelease(t->tuple.ele[j]); + if (t->tuple.ele) RedisModule_Free(t->tuple.ele); + } + RedisModule_Free(t); +} + +void exprTokenRetain(exprtoken *t) { + t->refcount++; +} + +/* ============================== Stack handling ============================ */ + +#include +#include + +#define EXPR_STACK_INITIAL_SIZE 16 + +/* Initialize a new expression stack. */ +void exprStackInit(exprstack *stack) { + stack->items = RedisModule_Alloc(sizeof(exprtoken*) * EXPR_STACK_INITIAL_SIZE); + stack->numitems = 0; + stack->allocsize = EXPR_STACK_INITIAL_SIZE; +} + +/* Push a token pointer onto the stack. Does not increment the refcount + * of the token: it is up to the caller doing this. */ +void exprStackPush(exprstack *stack, exprtoken *token) { + /* Check if we need to grow the stack. */ + if (stack->numitems == stack->allocsize) { + size_t newsize = stack->allocsize * 2; + exprtoken **newitems = + RedisModule_Realloc(stack->items, sizeof(exprtoken*) * newsize); + stack->items = newitems; + stack->allocsize = newsize; + } + stack->items[stack->numitems] = token; + stack->numitems++; +} + +/* Pop a token pointer from the stack. Return NULL if the stack is + * empty. Does NOT recrement the refcount of the token, it's up to the + * caller to do so, as the new owner of the reference. */ +exprtoken *exprStackPop(exprstack *stack) { + if (stack->numitems == 0) return NULL; + stack->numitems--; + return stack->items[stack->numitems]; +} + +/* Just return the last element pushed, without consuming it nor altering + * the reference count. */ +exprtoken *exprStackPeek(exprstack *stack) { + if (stack->numitems == 0) return NULL; + return stack->items[stack->numitems-1]; +} + +/* Free the stack structure state, including the items it contains, that are + * assumed to be heap allocated. The passed pointer itself is not freed. */ +void exprStackFree(exprstack *stack) { + for (int j = 0; j < stack->numitems; j++) + exprTokenRelease(stack->items[j]); + RedisModule_Free(stack->items); +} + +/* Just reset the stack removing all the items, but leaving it in a state + * that makes it still usable for new elements. */ +void exprStackReset(exprstack *stack) { + for (int j = 0; j < stack->numitems; j++) + exprTokenRelease(stack->items[j]); + stack->numitems = 0; +} + +/* =========================== Expression compilation ======================= */ + +void exprConsumeSpaces(exprstate *es) { + while(es->p[0] && isspace(es->p[0])) es->p++; +} + +/* Parse an operator, trying to match the longer match in the + * operators table. */ +exprtoken *exprParseOperator(exprstate *es) { + exprtoken *t = exprNewToken(EXPR_TOKEN_OP); + char *start = es->p; + + while(es->p[0] && + (isalpha(es->p[0]) || + strchr(EXPR_OP_SPECIALCHARS,es->p[0]) != NULL)) + { + es->p++; + } + + int matchlen = es->p - start; + int bestlen = 0; + int j; + + // Find the longest matching operator. + for (j = 0; ExprOptable[j].opname != NULL; j++) { + if (ExprOptable[j].oplen > matchlen) continue; + if (memcmp(ExprOptable[j].opname, start, ExprOptable[j].oplen) != 0) + { + continue; + } + if (ExprOptable[j].oplen > bestlen) { + t->opcode = ExprOptable[j].opcode; + bestlen = ExprOptable[j].oplen; + } + } + if (bestlen == 0) { + exprTokenRelease(t); + return NULL; + } else { + es->p = start + bestlen; + } + return t; +} + +// Valid selector charset. +static int is_selector_char(int c) { + return (isalpha(c) || + isdigit(c) || + strchr(EXPR_SELECTOR_SPECIALCHARS,c) != NULL); +} + +/* Parse selectors, they start with a dot and can have alphanumerical + * or few special chars. */ +exprtoken *exprParseSelector(exprstate *es) { + exprtoken *t = exprNewToken(EXPR_TOKEN_SELECTOR); + es->p++; // Skip dot. + char *start = es->p; + + while(es->p[0] && is_selector_char(es->p[0])) es->p++; + int matchlen = es->p - start; + t->str.start = start; + t->str.len = matchlen; + return t; +} + +exprtoken *exprParseNumber(exprstate *es) { + exprtoken *t = exprNewToken(EXPR_TOKEN_NUM); + char num[64]; + int idx = 0; + while(isdigit(es->p[0]) || es->p[0] == '.' || es->p[0] == 'e' || + es->p[0] == 'E' || (idx == 0 && es->p[0] == '-')) + { + if (idx >= (int)sizeof(num)-1) { + exprTokenRelease(t); + return NULL; + } + num[idx++] = es->p[0]; + es->p++; + } + num[idx] = 0; + + char *endptr; + t->num = strtod(num, &endptr); + if (*endptr != '\0') { + exprTokenRelease(t); + return NULL; + } + return t; +} + +exprtoken *exprParseString(exprstate *es) { + char quote = es->p[0]; /* Store the quote type (' or "). */ + es->p++; /* Skip opening quote. */ + + exprtoken *t = exprNewToken(EXPR_TOKEN_STR); + t->str.start = es->p; + + while(es->p[0] != '\0') { + if (es->p[0] == '\\' && es->p[1] != '\0') { + es->p += 2; // Skip escaped char. + continue; + } + if (es->p[0] == quote) { + t->str.len = es->p - t->str.start; + es->p++; // Skip closing quote. + return t; + } + es->p++; + } + /* If we reach here, string was not terminated. */ + exprTokenRelease(t); + return NULL; +} + +/* Parse a tuple of the form [1, "foo", 42]. No nested tuples are + * supported. This type is useful mostly to be used with the "IN" + * operator. */ +exprtoken *exprParseTuple(exprstate *es) { + exprtoken *t = exprNewToken(EXPR_TOKEN_TUPLE); + t->tuple.ele = NULL; + t->tuple.len = 0; + es->p++; /* Skip opening '['. */ + + size_t allocated = 0; + while(1) { + exprConsumeSpaces(es); + + /* Check for empty tuple or end. */ + if (es->p[0] == ']') { + es->p++; + break; + } + + /* Grow tuple array if needed. */ + if (t->tuple.len == allocated) { + size_t newsize = allocated == 0 ? 4 : allocated * 2; + exprtoken **newele = RedisModule_Realloc(t->tuple.ele, + sizeof(exprtoken*) * newsize); + t->tuple.ele = newele; + allocated = newsize; + } + + /* Parse tuple element. */ + exprtoken *ele = NULL; + if (isdigit(es->p[0]) || es->p[0] == '-') { + ele = exprParseNumber(es); + } else if (es->p[0] == '"' || es->p[0] == '\'') { + ele = exprParseString(es); + } else { + exprTokenRelease(t); + return NULL; + } + + /* Error parsing number/string? */ + if (ele == NULL) { + exprTokenRelease(t); + return NULL; + } + + /* Store element if no error was detected. */ + t->tuple.ele[t->tuple.len] = ele; + t->tuple.len++; + + /* Check for next element. */ + exprConsumeSpaces(es); + if (es->p[0] == ']') { + es->p++; + break; + } + if (es->p[0] != ',') { + exprTokenRelease(t); + return NULL; + } + es->p++; /* Skip comma. */ + } + return t; +} + +/* Deallocate the object returned by exprCompile(). */ +void exprFree(exprstate *es) { + if (es == NULL) return; + + /* Free the original expression string. */ + if (es->expr) RedisModule_Free(es->expr); + + /* Free all stacks. */ + exprStackFree(&es->values_stack); + exprStackFree(&es->ops_stack); + exprStackFree(&es->tokens); + exprStackFree(&es->program); + + /* Free the state object itself. */ + RedisModule_Free(es); +} + +/* Split the provided expression into a stack of tokens. Returns + * 0 on success, 1 on error. */ +int exprTokenize(exprstate *es, int *errpos) { + /* Main parsing loop. */ + while(1) { + exprConsumeSpaces(es); + + /* Set a flag to see if we can consider the - part of the + * number, or an operator. */ + int minus_is_number = 0; // By default is an operator. + + exprtoken *last = exprStackPeek(&es->tokens); + if (last == NULL) { + /* If we are at the start of an expression, the minus is + * considered a number. */ + minus_is_number = 1; + } else if (last->token_type == EXPR_TOKEN_OP && + last->opcode != EXPR_OP_CPAREN) + { + /* Also, if the previous token was an operator, the minus + * is considered a number, unless the previous operator is + * a closing parens. In such case it's like (...) -5, or alike + * and we want to emit an operator. */ + minus_is_number = 1; + } + + /* Parse based on the current character. */ + exprtoken *current = NULL; + if (*es->p == '\0') { + current = exprNewToken(EXPR_TOKEN_EOF); + } else if (isdigit(*es->p) || + (minus_is_number && *es->p == '-' && isdigit(es->p[1]))) + { + current = exprParseNumber(es); + } else if (*es->p == '"' || *es->p == '\'') { + current = exprParseString(es); + } else if (*es->p == '.' && is_selector_char(es->p[1])) { + current = exprParseSelector(es); + } else if (isalpha(*es->p) || strchr(EXPR_OP_SPECIALCHARS, *es->p)) { + current = exprParseOperator(es); + } else if (*es->p == '[') { + current = exprParseTuple(es); + } + + if (current == NULL) { + if (errpos) *errpos = es->p - es->expr; + return 1; // Syntax Error. + } + + /* Push the current token to tokens stack. */ + exprStackPush(&es->tokens, current); + if (current->token_type == EXPR_TOKEN_EOF) break; + } + return 0; +} + +/* Helper function to get operator precedence from the operator table. */ +int exprGetOpPrecedence(int opcode) { + for (int i = 0; ExprOptable[i].opname != NULL; i++) { + if (ExprOptable[i].opcode == opcode) + return ExprOptable[i].precedence; + } + return -1; +} + +/* Helper function to get operator arity from the operator table. */ +int exprGetOpArity(int opcode) { + for (int i = 0; ExprOptable[i].opname != NULL; i++) { + if (ExprOptable[i].opcode == opcode) + return ExprOptable[i].arity; + } + return -1; +} + +/* Process an operator during compilation. Returns 0 on success, 1 on error. + * This function will retain a reference of the operator 'op' in case it + * is pushed on the operators stack. */ +int exprProcessOperator(exprstate *es, exprtoken *op, int *stack_items, int *errpos) { + if (op->opcode == EXPR_OP_OPAREN) { + // This is just a marker for us. Do nothing. + exprStackPush(&es->ops_stack, op); + exprTokenRetain(op); + return 0; + } + + if (op->opcode == EXPR_OP_CPAREN) { + /* Process operators until we find the matching opening parenthesis. */ + while (1) { + exprtoken *top_op = exprStackPop(&es->ops_stack); + if (top_op == NULL) { + if (errpos) *errpos = op->offset; + return 1; + } + + if (top_op->opcode == EXPR_OP_OPAREN) { + /* Open parethesis found. Our work finished. */ + exprTokenRelease(top_op); + return 0; + } + + int arity = exprGetOpArity(top_op->opcode); + if (*stack_items < arity) { + exprTokenRelease(top_op); + if (errpos) *errpos = top_op->offset; + return 1; + } + + /* Move the operator on the program stack. */ + exprStackPush(&es->program, top_op); + *stack_items = *stack_items - arity + 1; + } + } + + int curr_prec = exprGetOpPrecedence(op->opcode); + + /* Process operators with higher or equal precedence. */ + while (1) { + exprtoken *top_op = exprStackPeek(&es->ops_stack); + if (top_op == NULL || top_op->opcode == EXPR_OP_OPAREN) break; + + int top_prec = exprGetOpPrecedence(top_op->opcode); + if (top_prec < curr_prec) break; + /* Special case for **: only pop if precedence is strictly higher + * so that the operator is right associative, that is: + * 2 ** 3 ** 2 is evaluated as 2 ** (3 ** 2) == 512 instead + * of (2 ** 3) ** 2 == 64. */ + if (op->opcode == EXPR_OP_POW && top_prec <= curr_prec) break; + + /* Pop and add to program. */ + top_op = exprStackPop(&es->ops_stack); + int arity = exprGetOpArity(top_op->opcode); + if (*stack_items < arity) { + exprTokenRelease(top_op); + if (errpos) *errpos = top_op->offset; + return 1; + } + + /* Move to the program stack. */ + exprStackPush(&es->program, top_op); + *stack_items = *stack_items - arity + 1; + } + + /* Push current operator. */ + exprStackPush(&es->ops_stack, op); + exprTokenRetain(op); + return 0; +} + +/* Compile the expression into a set of push-value and exec-operator + * that exprRun() can execute. The function returns an expstate object + * that can be used for execution of the program. On error, NULL + * is returned, and optionally the position of the error into the + * expression is returned by reference. */ +exprstate *exprCompile(char *expr, int *errpos) { + /* Initialize expression state. */ + exprstate *es = RedisModule_Alloc(sizeof(exprstate)); + es->expr = RedisModule_Strdup(expr); + es->p = es->expr; + + /* Initialize all stacks. */ + exprStackInit(&es->values_stack); + exprStackInit(&es->ops_stack); + exprStackInit(&es->tokens); + exprStackInit(&es->program); + + /* Tokenization. */ + if (exprTokenize(es, errpos)) { + exprFree(es); + return NULL; + } + + /* Compile the expression into a sequence of operations. */ + int stack_items = 0; // Track # of items that would be on the stack + // during execution. This way we can detect arity + // issues at compile time. + + /* Process each token. */ + for (int i = 0; i < es->tokens.numitems; i++) { + exprtoken *token = es->tokens.items[i]; + + if (token->token_type == EXPR_TOKEN_EOF) break; + + /* Handle values (numbers, strings, selectors). */ + if (token->token_type == EXPR_TOKEN_NUM || + token->token_type == EXPR_TOKEN_STR || + token->token_type == EXPR_TOKEN_TUPLE || + token->token_type == EXPR_TOKEN_SELECTOR) + { + exprStackPush(&es->program, token); + exprTokenRetain(token); + stack_items++; + continue; + } + + /* Handle operators. */ + if (token->token_type == EXPR_TOKEN_OP) { + if (exprProcessOperator(es, token, &stack_items, errpos)) { + exprFree(es); + return NULL; + } + continue; + } + } + + /* Process remaining operators on the stack. */ + while (es->ops_stack.numitems > 0) { + exprtoken *op = exprStackPop(&es->ops_stack); + if (op->opcode == EXPR_OP_OPAREN) { + if (errpos) *errpos = op->offset; + exprTokenRelease(op); + exprFree(es); + return NULL; + } + + int arity = exprGetOpArity(op->opcode); + if (stack_items < arity) { + if (errpos) *errpos = op->offset; + exprTokenRelease(op); + exprFree(es); + return NULL; + } + + exprStackPush(&es->program, op); + stack_items = stack_items - arity + 1; + } + + /* Verify that exactly one value would remain on the stack after + * execution. We could also check that such value is a number, but this + * would make the code more complex without much gains. */ + if (stack_items != 1) { + if (errpos) { + /* Point to the last token's offset for error reporting. */ + exprtoken *last = es->tokens.items[es->tokens.numitems - 1]; + *errpos = last->offset; + } + exprFree(es); + return NULL; + } + return es; +} + +/* ============================ Expression execution ======================== */ + +/* Convert a token to its numeric value. For strings we attempt to parse them + * as numbers, returning 0 if conversion fails. */ +double exprTokenToNum(exprtoken *t) { + char buf[128]; + if (t->token_type == EXPR_TOKEN_NUM) { + return t->num; + } else if (t->token_type == EXPR_TOKEN_STR && t->str.len < sizeof(buf)) { + memcpy(buf, t->str.start, t->str.len); + buf[t->str.len] = '\0'; + char *endptr; + double val = strtod(buf, &endptr); + return *endptr == '\0' ? val : 0; + } else { + return 0; + } +} + +/* Conver obejct to true/false (0 or 1) */ +double exprTokenToBool(exprtoken *t) { + if (t->token_type == EXPR_TOKEN_NUM) { + return t->num != 0; + } else if (t->token_type == EXPR_TOKEN_STR && t->str.len == 0) { + return 0; // Empty string are false, like in Javascript. + } else { + return 1; // Every non numerical type is true. + } +} + +/* Compare two tokens. Returns true if they are equal. */ +int exprTokensEqual(exprtoken *a, exprtoken *b) { + // If both are strings, do string comparison. + if (a->token_type == EXPR_TOKEN_STR && b->token_type == EXPR_TOKEN_STR) { + return a->str.len == b->str.len && + memcmp(a->str.start, b->str.start, a->str.len) == 0; + } + + // If both are numbers, do numeric comparison. + if (a->token_type == EXPR_TOKEN_NUM && b->token_type == EXPR_TOKEN_NUM) { + return a->num == b->num; + } + + // Mixed types - convert to numbers and compare. + return exprTokenToNum(a) == exprTokenToNum(b); +} + +/* Convert a json object to an expression token. There is only + * limited support for JSON arrays: they must be composed of + * just numbers and strings. Returns NULL if the JSON object + * cannot be converted. */ +exprtoken *exprJsonToToken(cJSON *js) { + if (cJSON_IsNumber(js)) { + exprtoken *obj = exprNewToken(EXPR_TOKEN_NUM); + obj->num = cJSON_GetNumberValue(js); + return obj; + } else if (cJSON_IsString(js)) { + exprtoken *obj = exprNewToken(EXPR_TOKEN_STR); + char *strval = cJSON_GetStringValue(js); + obj->str.heapstr = RedisModule_Strdup(strval); + obj->str.start = obj->str.heapstr; + obj->str.len = strlen(obj->str.heapstr); + return obj; + } else if (cJSON_IsBool(js)) { + exprtoken *obj = exprNewToken(EXPR_TOKEN_NUM); + obj->num = cJSON_IsTrue(js); + return obj; + } else if (cJSON_IsArray(js)) { + // First, scan the array to ensure it only + // contains strings and numbers. Otherwise the + // expression will evaluate to false. + int array_size = cJSON_GetArraySize(js); + + for (int j = 0; j < array_size; j++) { + cJSON *item = cJSON_GetArrayItem(js, j); + if (!cJSON_IsNumber(item) && !cJSON_IsString(item)) return NULL; + } + + // Create a tuple token for the array. + exprtoken *obj = exprNewToken(EXPR_TOKEN_TUPLE); + obj->tuple.len = array_size; + obj->tuple.ele = NULL; + if (obj->tuple.len == 0) return obj; // No elements, already ok. + + obj->tuple.ele = + RedisModule_Alloc(sizeof(exprtoken*) * obj->tuple.len); + + // Convert each array element to a token. + for (size_t j = 0; j < obj->tuple.len; j++) { + cJSON *item = cJSON_GetArrayItem(js, j); + if (cJSON_IsNumber(item)) { + exprtoken *eleToken = exprNewToken(EXPR_TOKEN_NUM); + eleToken->num = cJSON_GetNumberValue(item); + obj->tuple.ele[j] = eleToken; + } else if (cJSON_IsString(item)) { + exprtoken *eleToken = exprNewToken(EXPR_TOKEN_STR); + char *strval = cJSON_GetStringValue(item); + eleToken->str.heapstr = RedisModule_Strdup(strval); + eleToken->str.start = eleToken->str.heapstr; + eleToken->str.len = strlen(eleToken->str.heapstr); + obj->tuple.ele[j] = eleToken; + } + } + return obj; + } + return NULL; // No conversion possible for this type. +} + +/* Execute the compiled expression program. Returns 1 if the final stack value + * evaluates to true, 0 otherwise. Also returns 0 if any selector callback + * fails. */ +int exprRun(exprstate *es, char *json, size_t json_len) { + exprStackReset(&es->values_stack); + cJSON *parsed_json = NULL; + + // Execute each instruction in the program. + for (int i = 0; i < es->program.numitems; i++) { + exprtoken *t = es->program.items[i]; + + // Handle selectors by calling the callback. + if (t->token_type == EXPR_TOKEN_SELECTOR) { + if (json != NULL) { + cJSON *attrib = NULL; + if (parsed_json == NULL) { + parsed_json = cJSON_ParseWithLength(json,json_len); + // Will be left to NULL if the above fails. + } + if (parsed_json) { + char item_name[128]; + if (t->str.len > 0 && t->str.len < sizeof(item_name)) { + memcpy(item_name,t->str.start,t->str.len); + item_name[t->str.len] = 0; + attrib = cJSON_GetObjectItem(parsed_json,item_name); + } + /* Fill the token according to the JSON type stored + * at the attribute. */ + if (attrib) { + exprtoken *obj = exprJsonToToken(attrib); + if (obj) { + exprStackPush(&es->values_stack, obj); + continue; + } + } + } + } + + // Selector not found or JSON object not convertible to + // expression tokens. Evaluate the expression to false. + if (parsed_json) cJSON_Delete(parsed_json); + return 0; + } + + // Push non-operator values directly onto the stack. + if (t->token_type != EXPR_TOKEN_OP) { + exprStackPush(&es->values_stack, t); + exprTokenRetain(t); + continue; + } + + // Handle operators. + exprtoken *result = exprNewToken(EXPR_TOKEN_NUM); + + // Pop operands - we know we have enough from compile-time checks. + exprtoken *b = exprStackPop(&es->values_stack); + exprtoken *a = NULL; + if (exprGetOpArity(t->opcode) == 2) { + a = exprStackPop(&es->values_stack); + } + + switch(t->opcode) { + case EXPR_OP_NOT: + result->num = exprTokenToBool(b) == 0 ? 1 : 0; + break; + case EXPR_OP_POW: { + double base = exprTokenToNum(a); + double exp = exprTokenToNum(b); + result->num = pow(base, exp); + break; + } + case EXPR_OP_MULT: + result->num = exprTokenToNum(a) * exprTokenToNum(b); + break; + case EXPR_OP_DIV: + result->num = exprTokenToNum(a) / exprTokenToNum(b); + break; + case EXPR_OP_MOD: { + double va = exprTokenToNum(a); + double vb = exprTokenToNum(b); + result->num = fmod(va, vb); + break; + } + case EXPR_OP_SUM: + result->num = exprTokenToNum(a) + exprTokenToNum(b); + break; + case EXPR_OP_DIFF: + result->num = exprTokenToNum(a) - exprTokenToNum(b); + break; + case EXPR_OP_GT: + result->num = exprTokenToNum(a) > exprTokenToNum(b) ? 1 : 0; + break; + case EXPR_OP_GTE: + result->num = exprTokenToNum(a) >= exprTokenToNum(b) ? 1 : 0; + break; + case EXPR_OP_LT: + result->num = exprTokenToNum(a) < exprTokenToNum(b) ? 1 : 0; + break; + case EXPR_OP_LTE: + result->num = exprTokenToNum(a) <= exprTokenToNum(b) ? 1 : 0; + break; + case EXPR_OP_EQ: + result->num = exprTokensEqual(a, b) ? 1 : 0; + break; + case EXPR_OP_NEQ: + result->num = !exprTokensEqual(a, b) ? 1 : 0; + break; + case EXPR_OP_IN: { + // For 'in' operator, b must be a tuple. + result->num = 0; // Default to false. + if (b->token_type == EXPR_TOKEN_TUPLE) { + for (size_t j = 0; j < b->tuple.len; j++) { + if (exprTokensEqual(a, b->tuple.ele[j])) { + result->num = 1; // Found a match. + break; + } + } + } + break; + } + case EXPR_OP_AND: + result->num = + exprTokenToBool(a) != 0 && exprTokenToBool(b) != 0 ? 1 : 0; + break; + case EXPR_OP_OR: + result->num = + exprTokenToBool(a) != 0 || exprTokenToBool(b) != 0 ? 1 : 0; + break; + default: + // Do nothing: we don't want runtime errors. + break; + } + + // Free operands and push result. + if (a) exprTokenRelease(a); + exprTokenRelease(b); + exprStackPush(&es->values_stack, result); + } + + if (parsed_json) cJSON_Delete(parsed_json); + + // Get final result from stack. + exprtoken *final = exprStackPop(&es->values_stack); + if (final == NULL) return 0; + + // Convert result to boolean. + int retval = exprTokenToBool(final); + exprTokenRelease(final); + return retval; +} + +/* ============================ Simple test main ============================ */ + +#ifdef TEST_MAIN +void exprPrintToken(exprtoken *t) { + switch(t->token_type) { + case EXPR_TOKEN_EOF: + printf("EOF"); + break; + case EXPR_TOKEN_NUM: + printf("NUM:%g", t->num); + break; + case EXPR_TOKEN_STR: + printf("STR:\"%.*s\"", (int)t->str.len, t->str.start); + break; + case EXPR_TOKEN_SELECTOR: + printf("SEL:%.*s", (int)t->str.len, t->str.start); + break; + case EXPR_TOKEN_OP: + printf("OP:"); + for (int i = 0; ExprOptable[i].opname != NULL; i++) { + if (ExprOptable[i].opcode == t->opcode) { + printf("%s", ExprOptable[i].opname); + break; + } + } + break; + default: + printf("UNKNOWN"); + break; + } +} + +void exprPrintStack(exprstack *stack, const char *name) { + printf("%s (%d items):", name, stack->numitems); + for (int j = 0; j < stack->numitems; j++) { + printf(" "); + exprPrintToken(stack->items[j]); + } + printf("\n"); +} + +int main(int argc, char **argv) { + char *testexpr = "(5+2)*3 and .year > 1980 and 'foo' == 'foo'"; + char *testjson = "{\"year\": 1984, \"name\": \"The Matrix\"}"; + if (argc >= 2) testexpr = argv[1]; + if (argc >= 3) testjson = argv[2]; + + printf("Compiling expression: %s\n", testexpr); + + int errpos = 0; + exprstate *es = exprCompile(testexpr,&errpos); + if (es == NULL) { + printf("Compilation failed near \"...%s\"\n", testexpr+errpos); + return 1; + } + + exprPrintStack(&es->tokens, "Tokens"); + exprPrintStack(&es->program, "Program"); + printf("Running against object: %s\n", testjson); + int result = exprRun(es,testjson,strlen(testjson)); + printf("Result1: %s\n", result ? "True" : "False"); + result = exprRun(es,testjson,strlen(testjson)); + printf("Result2: %s\n", result ? "True" : "False"); + + exprFree(es); + return 0; +} +#endif diff --git a/modules/vector-sets/hnsw.c b/modules/vector-sets/hnsw.c new file mode 100644 index 000000000..a9a2695ad --- /dev/null +++ b/modules/vector-sets/hnsw.c @@ -0,0 +1,2718 @@ +/* HNSW (Hierarchical Navigable Small World) Implementation. + * + * Based on the paper by Yu. A. Malkov, D. A. Yashunin. + * + * Many details of this implementation, not covered in the paper, were + * obtained simulating different workloads and checking the connection + * quality of the graph. + * + * Notably, this implementation: + * + * 1. Only uses bi-directional links, implementing strategies in order to + * link new nodes even when candidates are full, and our new node would + * be not close enough to replace old links in candidate. + * + * 2. We normalize on-insert, making cosine similarity and dot product the + * same. This means we can't use euclidian distance or alike here. + * Together with quantization, this provides an important speedup that + * makes HNSW more practical. + * + * 3. The quantization used is int8. And it is performed per-vector, so the + * "range" (max abs value) is also stored alongside with the quantized data. + * + * 4. This library implements true elements deletion, not just marking the + * element as deleted, but removing it (we can do it since our links are + * bidirectional), and reliking the nodes orphaned of one link among + * them. + * + * Copyright(C) 2024-Present, Redis Ltd. All Rights Reserved. + * Originally authored by: Salvatore Sanfilippo. + */ + +#define _DEFAULT_SOURCE +#define _POSIX_C_SOURCE 200809L + +#include +#include +#include +#include +#include +#include /* for INFINITY if not in math.h */ +#include +#include "hnsw.h" + +#if 0 +#define debugmsg printf +#else +#define debugmsg if(0) printf +#endif + +#ifndef INFINITY +#define INFINITY (1.0/0.0) +#endif + +#define MIN(a,b) ((a) < (b) ? (a) : (b)) + +/* Algorithm parameters. */ + +#define HNSW_P 0.25 /* Probability of level increase. */ +#define HNSW_MAX_LEVEL 16 /* Max level nodes can reach. */ +#define HNSW_EF_C 200 /* Default size of dynamic candidate list while + * inserting a new node, in case 0 is passed to + * the 'ef' argument while inserting. This is also + * used when deleting nodes for the search step + * needed sometimes to reconnect nodes that remain + * orphaned of one link. */ + +static void (*hfree)(void *p) = free; +static void *(*hmalloc)(size_t s) = malloc; +static void *(*hrealloc)(void *old, size_t s) = realloc; + +void hnsw_set_allocator(void (*free_ptr)(void*), void *(*malloc_ptr)(size_t), + void *(*realloc_ptr)(void*, size_t)) +{ + hfree = free_ptr; + hmalloc = malloc_ptr; + hrealloc = realloc_ptr; +} + +// Get a warning if you use the libc allocator functions for mistake. +#define malloc use_hmalloc_instead +#define realloc use_hrealloc_instead +#define free use_hfree_instead + +/* ============================== Prototypes ================================ */ +void hnsw_cursor_element_deleted(HNSW *index, hnswNode *deleted); + +/* ============================ Priority queue ================================ + * We need a priority queue to take an ordered list of candidates. Right now + * it is implemented as a linear array, since it is relatively small. + * + * You may find it to be odd that we take the best element (smaller distance) + * at the end of the array, but this way popping from the pqueue is O(1), as + * we need to just decrement the count, and this is a very used operation + * in a critical code path. This makes the priority queue implementation a + * bit more complex in the insertion, but for good reasons. */ + +/* Maximum number of candidates we'll ever need (cit. Bill Gates). */ +#define HNSW_MAX_CANDIDATES 256 + +typedef struct { + hnswNode *node; + float distance; +} pqitem; + +typedef struct { + pqitem *items; /* Array of items. */ + uint32_t count; /* Current number of items. */ + uint32_t cap; /* Maximum capacity. */ +} pqueue; + +/* The HNSW algorithms access the pqueue conceptually from nearest (index 0) + * to farest (larger indexes) node, so the following macros are used to + * access the pqueue in this fashion, even if the internal order is + * actually reversed. */ +#define pq_get_node(q,i) ((q)->items[(q)->count-(i+1)].node) +#define pq_get_distance(q,i) ((q)->items[(q)->count-(i+1)].distance) + +/* Create a new priority queue with given capacity. Adding to the + * pqueue only retains 'capacity' elements with the shortest distance. */ +pqueue *pq_new(uint32_t capacity) { + pqueue *pq = hmalloc(sizeof(*pq)); + if (!pq) return NULL; + + pq->items = hmalloc(sizeof(pqitem) * capacity); + if (!pq->items) { + hfree(pq); + return NULL; + } + + pq->count = 0; + pq->cap = capacity; + return pq; +} + +/* Free a priority queue. */ +void pq_free(pqueue *pq) { + if (!pq) return; + hfree(pq->items); + hfree(pq); +} + +/* Insert maintaining distance order (higher distances first). */ +void pq_push(pqueue *pq, hnswNode *node, float distance) { + if (pq->count < pq->cap) { + /* Queue not full: shift right from high distances to make room. */ + uint32_t i = pq->count; + while (i > 0 && pq->items[i-1].distance < distance) { + pq->items[i] = pq->items[i-1]; + i--; + } + pq->items[i].node = node; + pq->items[i].distance = distance; + pq->count++; + } else { + /* Queue full: if new item is worse than worst, ignore it. */ + if (distance >= pq->items[0].distance) return; + + /* Otherwise shift left from low distances to drop worst. */ + uint32_t i = 0; + while (i < pq->cap-1 && pq->items[i+1].distance > distance) { + pq->items[i] = pq->items[i+1]; + i++; + } + pq->items[i].node = node; + pq->items[i].distance = distance; + } +} + +/* Remove and return the top (closest) element, which is at count-1 + * since we store elements with higher distances first. + * Runs in constant time. */ +hnswNode *pq_pop(pqueue *pq, float *distance) { + if (pq->count == 0) return NULL; + pq->count--; + *distance = pq->items[pq->count].distance; + return pq->items[pq->count].node; +} + +/* Get distance of the furthest element. + * An empty priority queue has infinite distance as its furthest element, + * note that this behavior is needed by the algorithms below. */ +float pq_max_distance(pqueue *pq) { + if (pq->count == 0) return INFINITY; + return pq->items[0].distance; +} + +/* ============================ HNSW algorithm ============================== */ + +/* Dot product: our vectors are already normalized. + * Version for not quantized vectors of floats. */ +float vectors_distance_float(const float *x, const float *y, uint32_t dim) { + /* Use two accumulators to reduce dependencies among multiplications. + * This provides a clear speed boost in Apple silicon, but should be + * help in general. */ + float dot0 = 0.0f, dot1 = 0.0f; + uint32_t i; + + // Process 8 elements per iteration, 50/50 with the two accumulators. + for (i = 0; i + 7 < dim; i += 8) { + dot0 += x[i] * y[i] + + x[i+1] * y[i+1] + + x[i+2] * y[i+2] + + x[i+3] * y[i+3]; + + dot1 += x[i+4] * y[i+4] + + x[i+5] * y[i+5] + + x[i+6] * y[i+6] + + x[i+7] * y[i+7]; + } + + /* Handle the remaining elements. These are a minority in the case + * of a smal vector, don't optimze this part. */ + for (; i < dim; i++) dot0 += x[i] * y[i]; + + /* The following line may be counter intuitive. The dot product of + * normalized vectors is equivalent to their cosine similarity. The + * cosine will be from -1 (vectors facing opposite directions in the + * N-dim space) to 1 (vectors are facing in the same direction). + * + * We kinda want a "score" of distance from 0 to 2 (this is a distance + * function and we want minimize the distance for K-NN searches), so we + * can't just add 1: that would return a number in the 0-2 range, with + * 0 meaning opposite vectors and 2 identical vectors, so this is + * similarity, not distance. + * + * Returning instead (1 - dotprod) inverts the meaning: 0 is identical + * and 2 is opposite, hence it is their distance. + * + * Why don't normalize the similarity right now, and return from 0 to + * 1? Because division is costly. */ + return 1.0f - (dot0 + dot1); +} + +/* Q8 quants dotproduct. We do integer math and later fix it by range. */ +float vectors_distance_q8(const int8_t *x, const int8_t *y, uint32_t dim, + float range_a, float range_b) { + // Handle zero vectors special case. + if (range_a == 0 || range_b == 0) { + /* Zero vector distance from anything is 1.0 + * (since 1.0 - dot_product where dot_product = 0). */ + return 1.0f; + } + + /* Each vector is quantized from [-max_abs, +max_abs] to [-127, 127] + * where range = 2*max_abs. */ + const float scale_product = (range_a/127) * (range_b/127); + + int32_t dot0 = 0, dot1 = 0; + uint32_t i; + + // Process 8 elements at a time for better pipeline utilization. + for (i = 0; i + 7 < dim; i += 8) { + dot0 += ((int32_t)x[i]) * ((int32_t)y[i]) + + ((int32_t)x[i+1]) * ((int32_t)y[i+1]) + + ((int32_t)x[i+2]) * ((int32_t)y[i+2]) + + ((int32_t)x[i+3]) * ((int32_t)y[i+3]); + + dot1 += ((int32_t)x[i+4]) * ((int32_t)y[i+4]) + + ((int32_t)x[i+5]) * ((int32_t)y[i+5]) + + ((int32_t)x[i+6]) * ((int32_t)y[i+6]) + + ((int32_t)x[i+7]) * ((int32_t)y[i+7]); + } + + // Handle remaining elements. + for (; i < dim; i++) dot0 += ((int32_t)x[i]) * ((int32_t)y[i]); + + // Convert to original range. + float dotf = (dot0 + dot1) * scale_product; + float distance = 1.0f - dotf; + + // Clamp distance to [0, 2]. + if (distance < 0) distance = 0; + else if (distance > 2) distance = 2; + return distance; +} + +static inline int popcount64(uint64_t x) { + x = (x & 0x5555555555555555) + ((x >> 1) & 0x5555555555555555); + x = (x & 0x3333333333333333) + ((x >> 2) & 0x3333333333333333); + x = (x & 0x0F0F0F0F0F0F0F0F) + ((x >> 4) & 0x0F0F0F0F0F0F0F0F); + x = (x & 0x00FF00FF00FF00FF) + ((x >> 8) & 0x00FF00FF00FF00FF); + x = (x & 0x0000FFFF0000FFFF) + ((x >> 16) & 0x0000FFFF0000FFFF); + x = (x & 0x00000000FFFFFFFF) + ((x >> 32) & 0x00000000FFFFFFFF); + return x; +} + +/* Binary vectors distance. */ +float vectors_distance_bin(const uint64_t *x, const uint64_t *y, uint32_t dim) { + uint32_t len = (dim+63)/64; + uint32_t opposite = 0; + for (uint32_t j = 0; j < len; j++) { + uint64_t xor = x[j]^y[j]; + opposite += popcount64(xor); + } + return (float)opposite*2/dim; +} + +/* Dot product between nodes. Will call the right version depending on the + * quantization used. */ +float hnsw_distance(HNSW *index, hnswNode *a, hnswNode *b) { + switch(index->quant_type) { + case HNSW_QUANT_NONE: + return vectors_distance_float(a->vector,b->vector,index->vector_dim); + case HNSW_QUANT_Q8: + return vectors_distance_q8(a->vector,b->vector,index->vector_dim,a->quants_range,b->quants_range); + case HNSW_QUANT_BIN: + return vectors_distance_bin(a->vector,b->vector,index->vector_dim); + default: + assert(1 != 1); + return 0; + } +} + +/* This do Q8 'range' quantization. + * For people looking at this code thinking: Oh, I could use min/max + * quants instead! Well: I tried with min/max normalization but the dot + * product needs to accumulate the sum for later correction, and it's slower. */ +void quantize_to_q8(float *src, int8_t *dst, uint32_t dim, float *rangeptr) { + float max_abs = 0; + for (uint32_t j = 0; j < dim; j++) { + if (src[j] > max_abs) max_abs = src[j]; + if (-src[j] > max_abs) max_abs = -src[j]; + } + + if (max_abs == 0) { + if (rangeptr) *rangeptr = 0; + memset(dst, 0, dim); + return; + } + + const float scale = 127.0f / max_abs; // Scale to map to [-127, 127]. + + for (uint32_t j = 0; j < dim; j++) { + dst[j] = (int8_t)roundf(src[j] * scale); + } + if (rangeptr) *rangeptr = max_abs; // Return max_abs instead of 2*max_abs. +} + +/* Binary quantization of vector 'src' to 'dst'. We use full words of + * 64 bit as smallest unit, we will just set all the unused bits to 0 + * so that they'll be the same in all the vectors, and when xor+popcount + * is used to compute the distance, such bits are not considered. This + * allows to go faster. */ +void quantize_to_bin(float *src, uint64_t *dst, uint32_t dim) { + memset(dst,0,(dim+63)/64*sizeof(uint64_t)); + for (uint32_t j = 0; j < dim; j++) { + uint32_t word = j/64; + uint32_t bit = j&63; + /* Since cosine similarity checks the vector direction and + * not magnitudo, we do likewise in the binary quantization and + * just remember if the component is positive or negative. */ + if (src[j] > 0) dst[word] |= 1ULL< HNSW_MAX_M) m = HNSW_MAX_M; + + index->M = m; + index->quant_type = quant_type; + index->enter_point = NULL; + index->max_level = 0; + index->vector_dim = vector_dim; + index->node_count = 0; + index->last_id = 0; + index->head = NULL; + index->cursors = NULL; + + /* Initialize epochs array. */ + for (int i = 0; i < HNSW_MAX_THREADS; i++) + index->current_epoch[i] = 0; + + /* Initialize locks. */ + if (pthread_rwlock_init(&index->global_lock, NULL) != 0) { + hfree(index); + return NULL; + } + + for (int i = 0; i < HNSW_MAX_THREADS; i++) { + if (pthread_mutex_init(&index->slot_locks[i], NULL) != 0) { + /* Clean up previously initialized mutexes. */ + for (int j = 0; j < i; j++) + pthread_mutex_destroy(&index->slot_locks[j]); + pthread_rwlock_destroy(&index->global_lock); + hfree(index); + return NULL; + } + } + + /* Initialize atomic variables. */ + index->next_slot = 0; + index->version = 0; + return index; +} + +/* Fill 'vec' with the node vector, de-normalizing and de-quantizing it + * as needed. Note that this function will return an approximated version + * of the original vector. */ +void hnsw_get_node_vector(HNSW *index, hnswNode *node, float *vec) { + if (index->quant_type == HNSW_QUANT_NONE) { + memcpy(vec,node->vector,index->vector_dim*sizeof(float)); + } else if (index->quant_type == HNSW_QUANT_Q8) { + int8_t *quants = node->vector; + for (uint32_t j = 0; j < index->vector_dim; j++) + vec[j] = (quants[j]*node->quants_range)/127; + } else if (index->quant_type == HNSW_QUANT_BIN) { + uint64_t *bits = node->vector; + for (uint32_t j = 0; j < index->vector_dim; j++) { + uint32_t word = j/64; + uint32_t bit = j&63; + vec[j] = (bits[word] & (1ULL<quant_type != HNSW_QUANT_BIN) { + for (uint32_t j = 0; j < index->vector_dim; j++) + vec[j] *= node->l2; + } +} + +/* Return the number of bytes needed to represent a vector in the index, + * that is function of the dimension of the vectors and the quantization + * type used. */ +uint32_t hnsw_quants_bytes(HNSW *index) { + switch(index->quant_type) { + case HNSW_QUANT_NONE: return index->vector_dim * sizeof(float); + case HNSW_QUANT_Q8: return index->vector_dim; + case HNSW_QUANT_BIN: return (index->vector_dim+63)/64*8; + default: assert(0 && "Quantization type not supported."); + } +} + +/* Create new node. Returns NULL on out of memory. + * It is possible to pass the vector as floats or, in case this index + * was already stored on disk and is being loaded, or serialized and + * transmitted in any form, the already quantized version in + * 'qvector'. + * + * Only vector or qvector should be non-NULL. The reason why passing + * a quantized vector is useful, is that because re-normalizing and + * re-quantizing several times the same vector may accumulate rounding + * errors. So if you work with quantized indexes, you should save + * the quantized indexes. + * + * Note that, together with qvector, the quantization range is needed, + * since this library uses per-vector quantization. In case of quantized + * vectors the l2 is considered to be '1', so if you want to restore + * the right l2 (to use the API that returns an approximation of the + * original vector) make sure to save the l2 on disk and set it back + * after the node creation (see later for the serialization API that + * handles this and more). */ +hnswNode *hnsw_node_new(HNSW *index, uint64_t id, const float *vector, const int8_t *qvector, float qrange, uint32_t level, int normalize) { + hnswNode *node = hmalloc(sizeof(hnswNode)+(sizeof(hnswNodeLayer)*(level+1))); + if (!node) return NULL; + + if (id == 0) id = ++index->last_id; + node->level = level; + node->id = id; + node->next = NULL; + node->vector = NULL; + node->l2 = 1; // Default in case of already quantized vectors. It is + // up to the caller to fill this later, if needed. + + /* Initialize visited epoch array. */ + for (int i = 0; i < HNSW_MAX_THREADS; i++) + node->visited_epoch[i] = 0; + + if (qvector == NULL) { + /* Copy input vector. */ + node->vector = hmalloc(sizeof(float) * index->vector_dim); + if (!node->vector) { + hfree(node); + return NULL; + } + memcpy(node->vector, vector, sizeof(float) * index->vector_dim); + if (normalize) + hnsw_normalize_vector(node->vector,&node->l2,index->vector_dim); + + /* Handle quantization. */ + if (index->quant_type != HNSW_QUANT_NONE) { + void *quants = hmalloc(hnsw_quants_bytes(index)); + if (quants == NULL) { + hfree(node->vector); + hfree(node); + return NULL; + } + + // Quantize. + switch(index->quant_type) { + case HNSW_QUANT_Q8: + quantize_to_q8(node->vector,quants,index->vector_dim,&node->quants_range); + break; + case HNSW_QUANT_BIN: + quantize_to_bin(node->vector,quants,index->vector_dim); + break; + default: + assert(0 && "Quantization type not handled."); + break; + } + + // Discard the full precision vector. + hfree(node->vector); + node->vector = quants; + } + } else { + // We got the already quantized vector. Just copy it. + assert(index->quant_type != HNSW_QUANT_NONE); + uint32_t vector_bytes = hnsw_quants_bytes(index); + node->vector = hmalloc(vector_bytes); + node->quants_range = qrange; + if (node->vector == NULL) { + hfree(node); + return NULL; + } + memcpy(node->vector,qvector,vector_bytes); + } + + /* Initialize each layer. */ + for (uint32_t i = 0; i <= level; i++) { + uint32_t max_links = (i == 0) ? index->M*2 : index->M; + node->layers[i].max_links = max_links; + node->layers[i].num_links = 0; + node->layers[i].worst_distance = 0; + node->layers[i].worst_idx = 0; + node->layers[i].links = hmalloc(sizeof(hnswNode*) * max_links); + if (!node->layers[i].links) { + for (uint32_t j = 0; j < i; j++) hfree(node->layers[j].links); + hfree(node->vector); + hfree(node); + return NULL; + } + } + + return node; +} + +/* Free a node. */ +void hnsw_node_free(hnswNode *node) { + if (!node) return; + + for (uint32_t i = 0; i <= node->level; i++) + hfree(node->layers[i].links); + + hfree(node->vector); + hfree(node); +} + +/* Free the entire index. */ +void hnsw_free(HNSW *index,void(*free_value)(void*value)) { + if (!index) return; + + hnswNode *current = index->head; + while (current) { + hnswNode *next = current->next; + if (free_value) free_value(current->value); + hnsw_node_free(current); + current = next; + } + + /* Destroy locks */ + pthread_rwlock_destroy(&index->global_lock); + for (int i = 0; i < HNSW_MAX_THREADS; i++) { + pthread_mutex_destroy(&index->slot_locks[i]); + } + + hfree(index); +} + +/* Add node to linked list of nodes. We may need to scan the whole + * HNSW graph for several reasons. The list is doubly linked since we + * also need the ability to remove a node without scanning the whole thing. */ +void hnsw_add_node(HNSW *index, hnswNode *node) { + node->next = index->head; + node->prev = NULL; + if (index->head) + index->head->prev = node; + index->head = node; + index->node_count++; +} + +/* Search the specified layer starting from the specified entry point + * to collect 'ef' nodes that are near to 'query'. + * + * This function implements optional hybrid search, so that each node + * can be accepted or not based on its associated value. In this case + * a callback 'filter_callback' should be passed, together with a maximum + * effort for the search (number of candidates to evaluate), since even + * with a a low "EF" value we risk that there are too few nodes that satisfy + * the provided filter, and we could trigger a full scan. */ +pqueue *search_layer_with_filter( + HNSW *index, hnswNode *query, hnswNode *entry_point, + uint32_t ef, uint32_t layer, uint32_t slot, + int (*filter_callback)(void *value, void *privdata), + void *filter_privdata, uint32_t max_candidates) +{ + // Mark visited nodes with a never seen epoch. + index->current_epoch[slot]++; + + pqueue *candidates = pq_new(HNSW_MAX_CANDIDATES); + pqueue *results = pq_new(ef); + if (!candidates || !results) { + if (candidates) pq_free(candidates); + if (results) pq_free(results); + return NULL; + } + + // Take track of the total effort: only used when filtering via + // a callback to have a bound effort. + uint32_t evaluated_candidates = 1; + + // Add entry point. + float dist = hnsw_distance(index, query, entry_point); + pq_push(candidates, entry_point, dist); + if (filter_callback == NULL || + filter_callback(entry_point->value, filter_privdata)) + { + pq_push(results, entry_point, dist); + } + entry_point->visited_epoch[slot] = index->current_epoch[slot]; + + // Process candidates. + while (candidates->count > 0) { + // Max effort. If zero, we keep scanning. + if (filter_callback && + max_candidates && + evaluated_candidates >= max_candidates) break; + + float cur_dist; + hnswNode *current = pq_pop(candidates, &cur_dist); + evaluated_candidates++; + + float furthest = pq_max_distance(results); + if (results->count >= ef && cur_dist > furthest) break; + + /* Check neighbors. */ + for (uint32_t i = 0; i < current->layers[layer].num_links; i++) { + hnswNode *neighbor = current->layers[layer].links[i]; + + if (neighbor->visited_epoch[slot] == index->current_epoch[slot]) + continue; // Already visited during this scan. + + neighbor->visited_epoch[slot] = index->current_epoch[slot]; + float neighbor_dist = hnsw_distance(index, query, neighbor); + + furthest = pq_max_distance(results); + if (filter_callback == NULL) { + /* Original HNSW logic when no filtering: + * Add to results if better than current max or + * results not full. */ + if (neighbor_dist < furthest || results->count < ef) { + pq_push(candidates, neighbor, neighbor_dist); + pq_push(results, neighbor, neighbor_dist); + } + } else { + /* With filtering: we add candidates even if doesn't match + * the filter, in order to continue to explore the graph. */ + if (neighbor_dist < furthest || candidates->count < ef) { + pq_push(candidates, neighbor, neighbor_dist); + } + + /* Add results only if passes filter. */ + if (filter_callback(neighbor->value, filter_privdata)) { + if (neighbor_dist < furthest || results->count < ef) { + pq_push(results, neighbor, neighbor_dist); + } + } + } + } + } + + pq_free(candidates); + return results; +} + +/* Just a wrapper without hybrid search callback. */ +pqueue *search_layer(HNSW *index, hnswNode *query, hnswNode *entry_point, + uint32_t ef, uint32_t layer, uint32_t slot) +{ + return search_layer_with_filter(index, query, entry_point, ef, layer, slot, + NULL, NULL, 0); +} + +/* This function is used in order to initialize a node allocated in the + * function stack with the specified vector. The idea is that we can + * easily use hnsw_distance() from a vector and the HNSW nodes this way: + * + * hnswNode myQuery; + * hnsw_init_tmp_node(myIndex,&myQuery,0,some_vector); + * hnsw_distance(&myQuery, some_hnsw_node); + * + * Make sure to later free the node with: + * + * hnsw_free_tmp_node(&myQuery,some_vector); + * You have to pass the vector to the free function, because sometimes + * hnsw_init_tmp_node() may just avoid allocating a vector at all, + * just reusing 'some_vector' pointer. + * + * Return 0 on out of memory, 1 on success. + */ +int hnsw_init_tmp_node(HNSW *index, hnswNode *node, int is_normalized, const float *vector) { + node->vector = NULL; + + /* Work on a normalized query vector if the input vector is + * not normalized. */ + if (!is_normalized) { + node->vector = hmalloc(sizeof(float)*index->vector_dim); + if (node->vector == NULL) return 0; + memcpy(node->vector,vector,sizeof(float)*index->vector_dim); + hnsw_normalize_vector(node->vector,NULL,index->vector_dim); + } else { + node->vector = (float*)vector; + } + + /* If quantization is enabled, our query fake node should be + * quantized as well. */ + if (index->quant_type != HNSW_QUANT_NONE) { + void *quants = hmalloc(hnsw_quants_bytes(index)); + if (quants == NULL) { + if (node->vector != vector) hfree(node->vector); + return 0; + } + switch(index->quant_type) { + case HNSW_QUANT_Q8: + quantize_to_q8(node->vector, quants, index->vector_dim, &node->quants_range); + break; + case HNSW_QUANT_BIN: + quantize_to_bin(node->vector, quants, index->vector_dim); + } + if (node->vector != vector) hfree(node->vector); + node->vector = quants; + } + return 1; +} + +/* Free the stack allocated node initialized by hnsw_init_tmp_node(). */ +void hnsw_free_tmp_node(hnswNode *node, const float *vector) { + if (node->vector != vector) hfree(node->vector); +} + +/* Return approximated K-NN items. Note that neighbors and distances + * arrays must have space for at least 'k' items. + * norm_query should be set to 1 if the query vector is already + * normalized, otherwise, if 0, the function will copy the vector, + * L2-normalize the copy and search using the normalized version. + * + * If the filter_privdata callback is passed, only elements passing the + * specified filter (invoked with privdata and the value associated + * to the node as arguments) are returned. In such case, if max_candidates + * is not NULL, it represents the maximum number of nodes to explore, since + * the search may be otherwise unbound if few or no elements pass the + * filter. */ +int hnsw_search_with_filter + (HNSW *index, const float *query_vector, uint32_t k, + hnswNode **neighbors, float *distances, uint32_t slot, + int query_vector_is_normalized, + int (*filter_callback)(void *value, void *privdata), + void *filter_privdata, uint32_t max_candidates) + +{ + if (!index || !query_vector || !neighbors || k == 0) return -1; + if (!index->enter_point) return 0; // Empty index. + + /* Use a fake node that holds the query vector, this way we can + * use our normal node to node distance functions when checking + * the distance between query and graph nodes. */ + hnswNode query; + if (hnsw_init_tmp_node(index,&query,query_vector_is_normalized,query_vector) == 0) return -1; + + // Start searching from the entry point. + hnswNode *curr_ep = index->enter_point; + + /* Start from higher layer to layer 1 (layer 0 is handled later) + * in the next section. Descend to the most similar node found + * so far. */ + for (int lc = index->max_level; lc > 0; lc--) { + pqueue *results = search_layer(index, &query, curr_ep, 1, lc, slot); + if (!results) continue; + + if (results->count > 0) { + curr_ep = pq_get_node(results,0); + } + pq_free(results); + } + + /* Search bottom layer (the most densely populated) with ef = k */ + pqueue *results = search_layer_with_filter( + index, &query, curr_ep, k, 0, slot, filter_callback, + filter_privdata, max_candidates); + if (!results) { + hnsw_free_tmp_node(&query, query_vector); + return -1; + } + + /* Copy results. */ + uint32_t found = MIN(k, results->count); + for (uint32_t i = 0; i < found; i++) { + neighbors[i] = pq_get_node(results,i); + if (distances) { + distances[i] = pq_get_distance(results,i); + } + } + + pq_free(results); + hnsw_free_tmp_node(&query, query_vector); + return found; +} + +/* Wrapper to hnsw_search_with_filter() when no filter is needed. */ +int hnsw_search(HNSW *index, const float *query_vector, uint32_t k, + hnswNode **neighbors, float *distances, uint32_t slot, + int query_vector_is_normalized) +{ + return hnsw_search_with_filter(index,query_vector,k,neighbors, + distances,slot,query_vector_is_normalized, + NULL,NULL,0); +} + +/* Rescan a node and update the wortst neighbor index. + * The followinng two functions are variants of this function to be used + * when links are added or removed: they may do less work than a full scan. */ +void hnsw_update_worst_neighbor(HNSW *index, hnswNode *node, uint32_t layer) { + float worst_dist = 0; + uint32_t worst_idx = 0; + for (uint32_t i = 0; i < node->layers[layer].num_links; i++) { + float dist = hnsw_distance(index, node, node->layers[layer].links[i]); + if (dist > worst_dist) { + worst_dist = dist; + worst_idx = i; + } + } + node->layers[layer].worst_distance = worst_dist; + node->layers[layer].worst_idx = worst_idx; +} + +/* Update node worst neighbor distance information when a new neighbor + * is added. */ +void hnsw_update_worst_neighbor_on_add(HNSW *index, hnswNode *node, uint32_t layer, uint32_t added_index, float distance) { + (void) index; // Unused but here for API symmetry. + if (node->layers[layer].num_links == 1 || // First neighbor? + distance > node->layers[layer].worst_distance) // New worst? + { + node->layers[layer].worst_distance = distance; + node->layers[layer].worst_idx = added_index; + } +} + +/* Update node worst neighbor distance information when a linked neighbor + * is removed. */ +void hnsw_update_worst_neighbor_on_remove(HNSW *index, hnswNode *node, uint32_t layer, uint32_t removed_idx) +{ + if (node->layers[layer].num_links == 0) { + node->layers[layer].worst_distance = 0; + node->layers[layer].worst_idx = 0; + } else if (removed_idx == node->layers[layer].worst_idx) { + hnsw_update_worst_neighbor(index,node,layer); + } else if (removed_idx < node->layers[layer].worst_idx) { + // Just update index if we removed element before worst. + node->layers[layer].worst_idx--; + } +} + +/* We have a list of candidate nodes to link to the new node, when iserting + * one. This function selects which nodes to link and performs the linking. + * + * Parameters: + * + * - 'candidates' is the priority queue of potential good nodes to link to the + * new node 'new_node'. + * - 'required_links' is as many links we would like our new_node to get + * at the specified layer. + * - 'aggressive' changes the startegy used to find good neighbors as follows: + * + * This function is called with aggressive=0 for all the layers, including + * layer 0. When called like that, it will use the diversity of links and + * quality of links checks before linking our new node with some candidate. + * + * However if the insert function finds that at layer 0, with aggressive=0, + * few connections were made, it calls this function again with agressiveness + * levels greater up to 2. + * + * At aggressive=1, the diversity checks are disabled, and the candidate + * node for linking is accepted even if it is nearest to an already accepted + * neighbor than it is to the new node. + * + * When we link our new node by replacing the link of a candidate neighbor + * that already has the max number of links, inevitably some other node loses + * a connection (to make space for our new node link). In this case: + * + * 1. If such "dropped" node would remain with too little links, we try with + * some different neighbor instead, however as the 'aggressive' paramter + * has incremental values (0, 1, 2) we are more and more willing to leave + * the dropped node with fever connections. + * 2. If aggressive=2, we will scan the candidate neighbor node links to + * find a different linked-node to replace, one better connected even if + * its distance is not the worse. + * + * Note: this function is also called during deletion of nodes in order to + * provide certain nodes with additional links. + */ +void select_neighbors(HNSW *index, pqueue *candidates, hnswNode *new_node, + uint32_t layer, uint32_t required_links, int aggressive) +{ + for (uint32_t i = 0; i < candidates->count; i++) { + hnswNode *neighbor = pq_get_node(candidates,i); + if (neighbor == new_node) continue; // Don't link node with itself. + + /* Use our cached distance among the new node and the candidate. */ + float dist = pq_get_distance(candidates,i); + + /* First of all, since our links are all bidirectional, if the + * new node for any reason has no longer room, or if it accumulated + * the required number of links, return ASAP. */ + if (new_node->layers[layer].num_links >= new_node->layers[layer].max_links || + new_node->layers[layer].num_links >= required_links) return; + + /* If aggressive is true, it is possible that the new node + * already got some link among the candidates (see the top comment, + * this function gets re-called in case of too few links). + * So we need to check if this candidate is already linked to + * the new node. */ + if (aggressive) { + int duplicated = 0; + for (uint32_t j = 0; j < new_node->layers[layer].num_links; j++) { + if (new_node->layers[layer].links[j] == neighbor) { + duplicated = 1; + break; + } + } + if (duplicated) continue; + } + + /* Diversity check. We accept new candidates + * only if there is no element already accepted that is nearest + * to the candidate than the new element itself. + * However this check is disabled if we have pressure to find + * new links (aggressive != 0) */ + if (!aggressive) { + int diversity_failed = 0; + for (uint32_t j = 0; j < new_node->layers[layer].num_links; j++) { + float link_dist = hnsw_distance(index, neighbor, + new_node->layers[layer].links[j]); + if (link_dist < dist) { + diversity_failed = 1; + break; + } + } + if (diversity_failed) continue; + } + + /* If potential neighbor node has space, simply add the new link. + * We will have space as well. */ + uint32_t n = neighbor->layers[layer].num_links; + if (n < neighbor->layers[layer].max_links) { + /* Link candidate to new node. */ + neighbor->layers[layer].links[n] = new_node; + neighbor->layers[layer].num_links++; + + /* Update candidate worst link info. */ + hnsw_update_worst_neighbor_on_add(index,neighbor,layer,n,dist); + + /* Link new node to candidate. */ + uint32_t new_links = new_node->layers[layer].num_links; + new_node->layers[layer].links[new_links] = neighbor; + new_node->layers[layer].num_links++; + + /* Update new node worst link info. */ + hnsw_update_worst_neighbor_on_add(index,new_node,layer,new_links,dist); + continue; + } + + /* ==================================================================== + * Replacing existing candidate neighbor link step. + * ================================================================== */ + + /* If we are here, our accepted candidate for linking is full. + * + * If new node is more distant to candidate than its current worst link + * then we skip it: we would not be able to establish a bidirectional + * connection without compromising link quality of candidate. + * + * At aggressiveness > 0 we don't care about this check. */ + if (!aggressive && dist >= neighbor->layers[layer].worst_distance) + continue; + + /* We can add it: we are ready to replace the candidate neighbor worst + * link with the new node, assuming certain conditions are met. */ + hnswNode *worst_node = neighbor->layers[layer].links[neighbor->layers[layer].worst_idx]; + + /* The worst node linked to our candidate may remain too disconnected + * if we remove the candidate node as its link. Let's check if + * this is the case: */ + if (aggressive == 0 && + worst_node->layers[layer].num_links <= index->M/2) + continue; + + /* Aggressive level = 1. It's ok if the node remains with just + * HNSW_M/4 links. */ + else if (aggressive == 1 && + worst_node->layers[layer].num_links <= index->M/4) + continue; + + /* If aggressive is set to 2, then the new node we are adding failed + * to find enough neighbors. We can't insert an almost orphaned new + * node, so let's see if the target node has some other link + * that is well connected in the graph: we could drop it instead + * of the worst link. */ + if (aggressive == 2 && worst_node->layers[layer].num_links <= + index->M/4) + { + /* Let's see if we can find at least a candidate link that + * would remain with a few connections. Track the one + * that is the farest away (worst distance) from our candidate + * neighbor (in order to remove the less interesting link). */ + worst_node = NULL; + uint32_t worst_idx = 0; + float max_dist = 0; + for (uint32_t j = 0; j < neighbor->layers[layer].num_links; j++) { + hnswNode *to_drop = neighbor->layers[layer].links[j]; + + /* Skip this if it would remain too disconnected as well. + * + * NOTE about index->M/4 min connections requirement: + * + * It is not too strict, since leaving a node with just a + * single link does not just leave it too weakly connected, but + * also sometimes creates cycles with few disconnected + * nodes linked among them. */ + if (to_drop->layers[layer].num_links <= index->M/4) continue; + + float link_dist = hnsw_distance(index, neighbor, to_drop); + if (worst_node == NULL || link_dist > max_dist) { + worst_node = to_drop; + max_dist = link_dist; + worst_idx = j; + } + } + + if (worst_node != NULL) { + /* We found a node that we can drop. Let's pretend this is + * the worst node of the candidate to unify the following + * code path. Later we will fix the worst node info anyway. */ + neighbor->layers[layer].worst_distance = max_dist; + neighbor->layers[layer].worst_idx = worst_idx; + } else { + /* Otherwise we have no other option than reallocating + * the max number of links for this target node, and + * ensure at least a few connections for our new node. */ + uint32_t reallocation_limit = layer == 0 ? + index->M * 3 : index->M *2; + if (neighbor->layers[layer].max_links >= reallocation_limit) + continue; + + uint32_t new_max_links = neighbor->layers[layer].max_links+1; + hnswNode **new_links = hrealloc(neighbor->layers[layer].links, + sizeof(hnswNode*) * new_max_links); + if (new_links == NULL) continue; // Non critical. + + /* Update neighbor's link capacity. */ + neighbor->layers[layer].links = new_links; + neighbor->layers[layer].max_links = new_max_links; + + /* Establish bidirectional link. */ + uint32_t n = neighbor->layers[layer].num_links; + neighbor->layers[layer].links[n] = new_node; + neighbor->layers[layer].num_links++; + hnsw_update_worst_neighbor_on_add(index, neighbor, layer, + n, dist); + + n = new_node->layers[layer].num_links; + new_node->layers[layer].links[n] = neighbor; + new_node->layers[layer].num_links++; + hnsw_update_worst_neighbor_on_add(index, new_node, layer, + n, dist); + continue; + } + } + + // Remove backlink from the worst node of our candidate. + for (uint64_t j = 0; j < worst_node->layers[layer].num_links; j++) { + if (worst_node->layers[layer].links[j] == neighbor) { + memmove(&worst_node->layers[layer].links[j], + &worst_node->layers[layer].links[j+1], + (worst_node->layers[layer].num_links - j - 1) * sizeof(hnswNode*)); + worst_node->layers[layer].num_links--; + hnsw_update_worst_neighbor_on_remove(index,worst_node,layer,j); + break; + } + } + + /* Replace worst link with the new node. */ + neighbor->layers[layer].links[neighbor->layers[layer].worst_idx] = new_node; + + /* Update the worst link in the target node, at this point + * the link that we replaced may no longer be the worst. */ + hnsw_update_worst_neighbor(index,neighbor,layer); + + // Add new node -> candidate link. + uint32_t new_links = new_node->layers[layer].num_links; + new_node->layers[layer].links[new_links] = neighbor; + new_node->layers[layer].num_links++; + + // Update new node worst link. + hnsw_update_worst_neighbor_on_add(index,new_node,layer,new_links,dist); + } +} + +/* This function implements node reconnection after a node deletion in HNSW. + * When a node is deleted, other nodes at the specified layer lose one + * connection (all the neighbors of the deleted node). This function attempts + * to pair such nodes together in a way that maximizes connection quality + * among the M nodes that were former neighbors of our deleted node. + * + * The algorithm works by first building a distance matrix among the nodes: + * + * N0 N1 N2 N3 + * N0 0 1.2 0.4 0.9 + * N1 1.2 0 0.8 0.5 + * N2 0.4 0.8 0 1.1 + * N3 0.9 0.5 1.1 0 + * + * For each potential pairing (i,j) we compute a score that combines: + * 1. The direct cosine distance between the two nodes + * 2. The average distance to other nodes that would no longer be + * available for pairing if we select this pair + * + * We want to balance local node-to-node requirements and global requirements. + * For instance sometimes connecting A with B, while optimal, would leave + * C and D to be connected without other choices, and this could be a very + * bad connection. Maybe instead A and C and B and D are both relatively high + * quality connections. + * + * The formula used to calculate the score of each connection is: + * + * score[i,j] = W1*(2-distance[i,j]) + W2*((new_avg_i + new_avg_j)/2) + * where new_avg_x is the average of distances in row x excluding distance[i,j] + * + * So the score is directly proportional to the SIMILARITY of the two nodes + * and also directly proportional to the DISTANCE of the potential other + * connections that we lost by pairign i,j. So we have a cost for missed + * opportunities, or better, in this case, a reward if the missing + * opportunities are not so good (big average distance). + * + * W1 and W2 are weights (defaults: 0.7 and 0.3) that determine the relative + * importance of immediate connection quality vs future pairing potential. + * + * After the initial pairing phase, any nodes that couldn't be paired + * (due to odd count or existing connections) are handled by searching + * the broader graph using the standard HNSW neighbor selection logic. + */ +void hnsw_reconnect_nodes(HNSW *index, hnswNode **nodes, int count, uint32_t layer) { + if (count <= 0) return; + debugmsg("Reconnecting %d nodes\n", count); + + /* Step 1: Build the distance matrix between all nodes. + * Since distance(i,j) = distance(j,i), we only compute the upper triangle + * and mirror it to the lower triangle. */ + float *distances = hmalloc(count * count * sizeof(float)); + if (!distances) return; + + for (int i = 0; i < count; i++) { + distances[i*count + i] = 0; // Distance to self is 0 + for (int j = i+1; j < count; j++) { + float dist = hnsw_distance(index, nodes[i], nodes[j]); + distances[i*count + j] = dist; // Upper triangle. + distances[j*count + i] = dist; // Lower triangle. + } + } + + /* Step 2: Calculate row averages (will be used in scoring): + * please note that we just calculate row averages and not + * colums averages since the matrix is symmetrical, so those + * are the same: check the image in the top comment if you have any + * doubt about this. */ + float *row_avgs = hmalloc(count * sizeof(float)); + if (!row_avgs) { + hfree(distances); + return; + } + + for (int i = 0; i < count; i++) { + float sum = 0; + int valid_count = 0; + for (int j = 0; j < count; j++) { + if (i != j) { + sum += distances[i*count + j]; + valid_count++; + } + } + row_avgs[i] = valid_count ? sum / valid_count : 0; + } + + /* Step 3: Build scoring matrix. What we do here is to combine how + * good is a given i,j nodes connection, with how badly connecting + * i,j will affect the remaining quality of connections left to + * pair the other nodes. */ + float *scores = hmalloc(count * count * sizeof(float)); + if (!scores) { + hfree(distances); + hfree(row_avgs); + return; + } + + /* Those weights were obtained manually... No guarantee that they + * are optimal. However with these values the algorithm is certain + * better than its greedy version that just attempts to pick the + * best pair each time (verified experimentally). */ + const float W1 = 0.7; // Weight for immediate distance. + const float W2 = 0.3; // Weight for future potential. + + for (int i = 0; i < count; i++) { + for (int j = 0; j < count; j++) { + if (i == j) { + scores[i*count + j] = -1; // Invalid pairing. + continue; + } + + // Check for existing connection between i and j. + int already_linked = 0; + for (uint32_t k = 0; k < nodes[i]->layers[layer].num_links; k++) + { + if (nodes[i]->layers[layer].links[k] == nodes[j]) { + scores[i*count + j] = -1; // Already linked. + already_linked = 1; + break; + } + } + if (already_linked) continue; + + float dist = distances[i*count + j]; + + /* Calculate new averages excluding this pair. + * Handle edge case where we might have too few elements. + * Note that it would be not very smart to recompute the average + * each time scanning the row, we can remove the element + * and adjust the average without it. */ + float new_avg_i = 0, new_avg_j = 0; + if (count > 2) { + new_avg_i = (row_avgs[i] * (count-1) - dist) / (count-2); + new_avg_j = (row_avgs[j] * (count-1) - dist) / (count-2); + } + + /* Final weighted score: the more similar i,j, the better + * the score. The more distant are the pairs we lose by + * connecting i,j, the better the score. */ + scores[i*count + j] = W1*(2-dist) + W2*((new_avg_i + new_avg_j)/2); + } + } + + // Step 5: Pair nodes greedily based on scores. + int *used = hmalloc(count*sizeof(int)); + memset(used,0,count*sizeof(int)); + if (!used) { + hfree(distances); + hfree(row_avgs); + hfree(scores); + return; + } + + /* Scan the matrix looking each time for the potential + * link with the best score. */ + while(1) { + float max_score = -1; + int best_j = -1, best_i = -1; + + // Seek best score i,j values. + for (int i = 0; i < count; i++) { + if (used[i]) continue; // Already connected. + + /* No space left? Not possible after a node deletion but makes + * this function more future-proof. */ + if (nodes[i]->layers[layer].num_links >= + nodes[i]->layers[layer].max_links) continue; + + for (int j = 0; j < count; j++) { + if (i == j) continue; // Same node, skip. + if (used[j]) continue; // Already connected. + float score = scores[i*count + j]; + if (score < 0) continue; // Invalid link. + + /* If the target node has space, and its score is better + * than any other seen so far... remember it is the best. */ + if (score > max_score && + nodes[j]->layers[layer].num_links < + nodes[j]->layers[layer].max_links) + { + // Track the best connection found so far. + max_score = score; + best_j = j; + best_i = i; + } + } + } + + // Possible link found? Connect i and j. + if (best_j != -1) { + debugmsg("[%d] linking %d with %d: %f\n", layer, (int)best_i, (int)best_j, max_score); + // Link i -> j. + int link_idx = nodes[best_i]->layers[layer].num_links; + nodes[best_i]->layers[layer].links[link_idx] = nodes[best_j]; + nodes[best_i]->layers[layer].num_links++; + + // Update worst distance if needed. + float dist = distances[best_i*count + best_j]; + hnsw_update_worst_neighbor_on_add(index,nodes[best_i],layer,link_idx,dist); + + // Link j -> i. + link_idx = nodes[best_j]->layers[layer].num_links; + nodes[best_j]->layers[layer].links[link_idx] = nodes[best_i]; + nodes[best_j]->layers[layer].num_links++; + + // Update worst distance if needed. + hnsw_update_worst_neighbor_on_add(index,nodes[best_j],layer,link_idx,dist); + + // Mark connection as used. + used[best_i] = used[best_j] = 1; + } else { + break; // No more valid connections available. + } + } + + /* Step 6: Handle remaining unpaired nodes using the standard HNSW + * neighbor selection. */ + for (int i = 0; i < count; i++) { + if (used[i]) continue; + + // Skip if node is already at max connections. + if (nodes[i]->layers[layer].num_links >= + nodes[i]->layers[layer].max_links) + continue; + + debugmsg("[%d] Force linking %d\n", layer, i); + + /* First, try with local nodes as candidates. + * Some candidate may have space. */ + pqueue *candidates = pq_new(count); + if (!candidates) continue; + + /* Add all the local nodes having some space as candidates + * to be linked with this node. */ + for (int j = 0; j < count; j++) { + if (i != j && // Must not be itself. + nodes[j]->layers[layer].num_links < // Must not be full. + nodes[j]->layers[layer].max_links) + { + float dist = distances[i*count + j]; + pq_push(candidates, nodes[j], dist); + } + } + + /* Try local candidates first with aggressive = 1. + * So we will link only if there is space. + * We want one link more than the links we already have. */ + uint32_t wanted_links = nodes[i]->layers[layer].num_links+1; + if (candidates->count > 0) { + select_neighbors(index, candidates, nodes[i], layer, + wanted_links, 1); + debugmsg("Final links after attempt with local nodes: %d (wanted: %d)\n", (int)nodes[i]->layers[layer].num_links, wanted_links); + } + + // If still no connection, search the broader graph. + if (nodes[i]->layers[layer].num_links != wanted_links) { + debugmsg("No force linking possible with local candidats\n"); + pq_free(candidates); + + // Find entry point for target layer by descending through levels. + hnswNode *curr_ep = index->enter_point; + for (uint32_t lc = index->max_level; lc > layer; lc--) { + pqueue *results = search_layer(index, nodes[i], curr_ep, 1, lc, 0); + if (results) { + if (results->count > 0) { + curr_ep = pq_get_node(results,0); + } + pq_free(results); + } + } + + if (curr_ep) { + /* Search this layer for candidates. + * Use the defalt EF_C in this case, since it's not an + * "insert" operation, and we don't know the user + * specified "EF". */ + candidates = search_layer(index, nodes[i], curr_ep, HNSW_EF_C, layer, 0); + if (candidates) { + /* Try to connect with aggressiveness proportional to the + * node linking condition. */ + int aggressiveness = + (nodes[i]->layers[layer].num_links > index->M / 2) + ? 1 : 2; + select_neighbors(index, candidates, nodes[i], layer, + wanted_links, aggressiveness); + debugmsg("Final links with broader search: %d (wanted: %d)\n", (int)nodes[i]->layers[layer].num_links, wanted_links); + pq_free(candidates); + } + } + } else { + pq_free(candidates); + } + } + + // Cleanup. + hfree(distances); + hfree(row_avgs); + hfree(scores); + hfree(used); +} + +/* This is an helper function in order to support node deletion. + * It's goal is just to: + * + * 1. Remove the node from the bidirectional links of neighbors in the graph. + * 2. Remove the node from the linked list of nodes. + * 3. Fix the entry point in the graph. We just select one of the neighbors + * of the deleted node at a lower level. If none is found, we do + * a full scan. + * 4. The node itself amd its aux value field are NOT freed. It's up to the + * caller to do it, by using hnsw_node_free(). + * 5. The node associated value (node->value) is NOT freed. + * + * Why this function will not free the node? Because in node updates it + * could be a good idea to reuse the node allocation for different reasons + * (currently not implemented). + * In general it is more future-proof to be able to reuse the node if + * needed. Right now this library reuses the node only when links are + * not touched (see hnsw_update() for more information). */ +void hnsw_unlink_node(HNSW *index, hnswNode *node) { + if (!index || !node) return; + + index->version++; // This node may be missing in an already compiled list + // of neighbors. Make optimistic concurrent inserts fail. + + /* Remove all bidirectional links at each level. + * Note that in this implementation all the + * links are guaranteed to be bedirectional. */ + + /* For each level of the deleted node... */ + for (uint32_t level = 0; level <= node->level; level++) { + /* For each linked node of the deleted node... */ + for (uint32_t i = 0; i < node->layers[level].num_links; i++) { + hnswNode *linked = node->layers[level].links[i]; + /* Find and remove the backlink in the linked node */ + for (uint32_t j = 0; j < linked->layers[level].num_links; j++) { + if (linked->layers[level].links[j] == node) { + /* Remove by shifting remaining links left */ + memmove(&linked->layers[level].links[j], + &linked->layers[level].links[j + 1], + (linked->layers[level].num_links - j - 1) * sizeof(hnswNode*)); + linked->layers[level].num_links--; + hnsw_update_worst_neighbor_on_remove(index,linked,level,j); + break; + } + } + } + } + + /* Update cursors pointing at this element. */ + if (index->cursors) hnsw_cursor_element_deleted(index,node); + + /* Update the previous node's next pointer. */ + if (node->prev) { + node->prev->next = node->next; + } else { + /* If there's no previous node, this is the head. */ + index->head = node->next; + } + + /* Update the next node's prev pointer. */ + if (node->next) node->next->prev = node->prev; + + /* Update node count. */ + index->node_count--; + + /* If this node was the enter_point, we need to update it. */ + if (node == index->enter_point) { + /* Reset entry point - we'll find a new one (unless the HNSW is + * now empty) */ + index->enter_point = NULL; + index->max_level = 0; + + /* Step 1: Try to find a replacement by scanning levels + * from top to bottom. Under normal conditions, if there is + * any other node at the same level, we have a link. Anyway + * we descend levels to find any neighbor at the higher level + * possible. */ + for (int level = node->level; level >= 0; level--) { + if (node->layers[level].num_links > 0) { + index->enter_point = node->layers[level].links[0]; + break; + } + } + + /* Step 2: If no links were found at any level, do a full scan. + * This should never happen in practice if the HNSW is not + * empty. */ + if (!index->enter_point) { + uint32_t new_max_level = 0; + hnswNode *current = index->head; + + while (current) { + if (current != node && current->level >= new_max_level) { + new_max_level = current->level; + index->enter_point = current; + } + current = current->next; + } + } + + /* Update max_level. */ + if (index->enter_point) + index->max_level = index->enter_point->level; + } + + /* Clear the node's links but don't free the node itself */ + node->prev = node->next = NULL; +} + +/* Higher level API for hnsw_unlink_node() + hnsw_reconnect_nodes() actual work. + * This will get the write lock, will delete the node, free it, + * reconnect the node neighbors among themselves, and unlock again. + * If free_value function pointer is not NULL, then the function provided is + * used to free node->value. + * + * The function returns 0 on error (inability to acquire the lock), otherwise + * 1 is returned. */ +int hnsw_delete_node(HNSW *index, hnswNode *node, void(*free_value)(void*value)) { + if (pthread_rwlock_wrlock(&index->global_lock) != 0) return 0; + hnsw_unlink_node(index,node); + if (free_value && node->value) free_value(node->value); + + /* Relink all the nodes orphaned of this node link. + * Do it for all the levels. */ + for (unsigned int j = 0; j <= node->level; j++) { + hnsw_reconnect_nodes(index, node->layers[j].links, + node->layers[j].num_links, j); + } + hnsw_node_free(node); + pthread_rwlock_unlock(&index->global_lock); + return 1; +} + +/* ============================ Threaded API ================================ + * Concurent readers should use the following API to get a slot assigned + * (and a lock, too), do their read-only call, and unlock the slot. + * + * There is a reason why read operations don't implement opaque transparent + * locking directly on behalf of the user: when we return a result set + * with hnsw_search(), we report a set of nodes. The caller will do something + * with the nodes and the associated values, so the unlocking of the + * slot should happen AFTER the result was already used, otherwise we may + * have changes to the HNSW nodes as the result is being accessed. */ + +/* Try to acquire a read slot. Returns the slot number (0 to HNSW_MAX_THREADS-1) + * on success, -1 on error (pthread mutex errors). */ +int hnsw_acquire_read_slot(HNSW *index) { + /* First try a non-blocking approach on all slots. */ + for (uint32_t i = 0; i < HNSW_MAX_THREADS; i++) { + if (pthread_mutex_trylock(&index->slot_locks[i]) == 0) { + if (pthread_rwlock_rdlock(&index->global_lock) != 0) { + pthread_mutex_unlock(&index->slot_locks[i]); + return -1; + } + return i; + } + } + + /* All trylock attempts failed, use atomic increment to select slot. */ + uint32_t slot = index->next_slot++ % HNSW_MAX_THREADS; + + /* Try to lock the selected slot. */ + if (pthread_mutex_lock(&index->slot_locks[slot]) != 0) return -1; + + /* Get read lock. */ + if (pthread_rwlock_rdlock(&index->global_lock) != 0) { + pthread_mutex_unlock(&index->slot_locks[slot]); + return -1; + } + + return slot; +} + +/* Release a previously acquired read slot: note that it is important that + * nodes returned by hnsw_search() are accessed while the read lock is + * still active, to be sure that nodes are not freed. */ +void hnsw_release_read_slot(HNSW *index, int slot) { + if (slot < 0 || slot >= HNSW_MAX_THREADS) return; + pthread_rwlock_unlock(&index->global_lock); + pthread_mutex_unlock(&index->slot_locks[slot]); +} + +/* ============================ Nodes insertion ============================= + * We have an optimistic API separating the read-only candidates search + * and the write side (actual node insertion). We internally also use + * this API to provide the plain hnsw_insert() function for code unification. */ + +struct InsertContext { + pqueue *level_queues[HNSW_MAX_LEVEL]; /* Candidates for each level. */ + hnswNode *node; /* Pre-allocated node ready for insertion */ + uint64_t version; /* Index version at preparation time. This is used + * for CAS-like locking during change commit. */ +}; + +/* Optimistic insertion API. + * + * WARNING: Note that this is an internal function: users should call + * hnsw_prepare_insert() instead. + * + * This is how it works: you use hnsw_prepare_insert() and it will return + * a context where good candidate neighbors are already pre-selected. + * This step only uses read locks. + * + * Then finally you try to actually commit the new node with + * hnsw_try_commit_insert(): this time we will require a write lock, but + * for less time than it would be otherwise needed if using directly + * hnsw_insert(). When you try to commit the write, if no node was deleted in + * the meantime, your operation will succeed, otherwise it will fail, and + * you should try to just use the hnsw_insert() API, since there is + * contention. + * + * See hnsw_node_new() for information about 'vector' and 'qvector' + * arguments, and which one to pass. */ +InsertContext *hnsw_prepare_insert_nolock(HNSW *index, const float *vector, + const int8_t *qvector, float qrange, uint64_t id, + int slot, int ef) +{ + InsertContext *ctx = hmalloc(sizeof(*ctx)); + if (!ctx) return NULL; + + memset(ctx, 0, sizeof(*ctx)); + ctx->version = index->version; + + /* Crete a new node that we may be able to insert into the + * graph later, when calling the commit function. */ + uint32_t level = random_level(); + ctx->node = hnsw_node_new(index, id, vector, qvector, qrange, level, 1); + if (!ctx->node) { + hfree(ctx); + return NULL; + } + + hnswNode *curr_ep = index->enter_point; + + /* Empty graph, no need to collect candidates. */ + if (curr_ep == NULL) return ctx; + + /* Phase 1: Find good entry point on the highest level of the new + * node we are going to insert. */ + for (unsigned int lc = index->max_level; lc > level; lc--) { + pqueue *results = search_layer(index, ctx->node, curr_ep, 1, lc, slot); + + if (results) { + if (results->count > 0) curr_ep = pq_get_node(results,0); + pq_free(results); + } + } + + /* Phase 2: Collect a set of potential connections for each layer of + * the new node. */ + for (int lc = MIN(level, index->max_level); lc >= 0; lc--) { + pqueue *candidates = + search_layer(index, ctx->node, curr_ep, ef, lc, slot); + + if (!candidates) continue; + curr_ep = (candidates->count > 0) ? pq_get_node(candidates,0) : curr_ep; + ctx->level_queues[lc] = candidates; + } + + return ctx; +} + +/* External API for hnsw_prepare_insert_nolock(), handling locking. */ +InsertContext *hnsw_prepare_insert(HNSW *index, const float *vector, + const int8_t *qvector, float qrange, uint64_t id, + int ef) +{ + InsertContext *ctx; + int slot = hnsw_acquire_read_slot(index); + ctx = hnsw_prepare_insert_nolock(index,vector,qvector,qrange,id,slot,ef); + hnsw_release_read_slot(index,slot); + return ctx; +} + +/* Free an insert context and all its resources. */ +void hnsw_free_insert_context(InsertContext *ctx) { + if (!ctx) return; + for (uint32_t i = 0; i < HNSW_MAX_LEVEL; i++) { + if (ctx->level_queues[i]) pq_free(ctx->level_queues[i]); + } + if (ctx->node) hnsw_node_free(ctx->node); + hfree(ctx); +} + +/* Commit a prepared insert operation. This function is a low level API that + * should not be called by the user. See instead hnsw_try_commit_insert(), that + * will perform the CAS check and acquire the write lock. + * + * See the top comment in hnsw_prepare_insert() for more information + * on the optimistic insertion API. + * + * This function can't fail and always returns the pointer to the + * just inserted node. Out of memory is not possible since no critical + * allocation is never performed in this code path: we populate links + * on already allocated nodes. */ +hnswNode *hnsw_commit_insert_nolock(HNSW *index, InsertContext *ctx, void *value) { + hnswNode *node = ctx->node; + node->value = value; + + /* Handle first node case. */ + if (index->enter_point == NULL) { + index->version++; // First node, make concurrent inserts fail. + index->enter_point = node; + index->max_level = node->level; + hnsw_add_node(index, node); + ctx->node = NULL; // So hnsw_free_insert_context() will not free it. + hnsw_free_insert_context(ctx); + return node; + } + + /* Connect the node with near neighbors at each level. */ + for (int lc = MIN(node->level,index->max_level); lc >= 0; lc--) { + if (ctx->level_queues[lc] == NULL) continue; + + /* Try to provide index->M connections to our node. The call + * is not guaranteed to be able to provide all the links we would + * like to have for the new node: they must be bi-directional, obey + * certain quality checks, and so forth, so later there are further + * calls to force the hand a bit if needed. + * + * Let's start with aggressiveness = 0. */ + select_neighbors(index, ctx->level_queues[lc], node, lc, index->M, 0); + + /* Layer 0 and too few connections? Let's be more aggressive. */ + if (lc == 0 && node->layers[0].num_links < index->M/2) { + select_neighbors(index, ctx->level_queues[lc], node, lc, + index->M, 1); + + /* Still too few connections? Let's go to + * aggressiveness level '2' in linking strategy. */ + if (node->layers[0].num_links < index->M/4) { + select_neighbors(index, ctx->level_queues[lc], node, lc, + index->M/4, 2); + } + } + } + + /* If new node level is higher than current max, update entry point. */ + if (node->level > index->max_level) { + index->version++; // Entry point changed, make concurrent inserts fail. + index->enter_point = node; + index->max_level = node->level; + } + + /* Add node to the linked list. */ + hnsw_add_node(index, node); + ctx->node = NULL; // So hnsw_free_insert_context() will not free the node. + hnsw_free_insert_context(ctx); + return node; +} + +/* If the context obtained with hnsw_prepare_insert() is still valid + * (nodes not deleted in the meantime) then add the new node to the HNSW + * index and return its pointer. Otherwise NULL is returned and the operation + * should be either performed with the blocking API hnsw_insert() or attempted + * again. */ +hnswNode *hnsw_try_commit_insert(HNSW *index, InsertContext *ctx, void *value) { + /* Check if the version changed since preparation. Note that we + * should access index->version under the write lock in order to + * be sure we can safely commit the write: this is just a fast-path + * in order to return ASAP without acquiring the write lock in case + * the version changed. */ + if (ctx->version != index->version) { + hnsw_free_insert_context(ctx); + return NULL; + } + + /* Try to acquire write lock. */ + if (pthread_rwlock_wrlock(&index->global_lock) != 0) { + hnsw_free_insert_context(ctx); + return NULL; + } + + /* Check version again under write lock. */ + if (ctx->version != index->version) { + pthread_rwlock_unlock(&index->global_lock); + hnsw_free_insert_context(ctx); + return NULL; + } + + /* Commit the change: note that it's up to hnsw_commit_insert_nolock() + * to free the insertion context. */ + hnswNode *node = hnsw_commit_insert_nolock(index, ctx, value); + + /* Release the write lock. */ + pthread_rwlock_unlock(&index->global_lock); + return node; +} + +/* Insert a new element into the graph. + * See hnsw_node_new() for information about 'vector' and 'qvector' + * arguments, and which one to pass. + * + * Return NULL on out of memory during insert. Otherwise the newly + * inserted node pointer is returned. */ +hnswNode *hnsw_insert(HNSW *index, const float *vector, const int8_t *qvector, float qrange, uint64_t id, void *value, int ef) { + /* Write lock. We acquire the write lock even for the prepare() + * operation (that is a read-only operation) since we want this function + * to don't fail in the check-and-set stage of commit(). + * + * Basically here we are using the optimistic API in a non-optimistinc + * way in order to have a single insertion code in the implementation. */ + if (pthread_rwlock_wrlock(&index->global_lock) != 0) return NULL; + + // Prepare the insertion - note we pass slot 0 since we're single threaded. + InsertContext *ctx = hnsw_prepare_insert_nolock(index, vector, qvector, + qrange, id, 0, ef); + if (!ctx) { + pthread_rwlock_unlock(&index->global_lock); + return NULL; + } + + // Commit the prepared insertion without version checking. + hnswNode *node = hnsw_commit_insert_nolock(index, ctx, value); + + // Release write lock and return our node pointer. + pthread_rwlock_unlock(&index->global_lock); + return node; +} + +/* Helper function for qsort call in hnsw_should_reuse_node(). */ +static int compare_floats(const float *a, const float *b) { + if (*a < *b) return 1; + if (*a > *b) return -1; + return 0; +} + +/* This function determines if a node can be reused with a new vector by: + * + * 1. Computing average of worst 25% of current distances. + * 2. Checking if at least 50% of new distances stay below this threshold. + * 3. Requiring a minimum number of links for the check to be meaningful. + * + * This check is useful when we want to just update a node that already + * exists in the graph. Often the new vector is a learned embedding generated + * by some model, and the embedding represents some document that perhaps + * changed just slightly compared to the past, so the new embedding will + * be very nearby. We need to find a way do determine if the current node + * neighbors (practically speaking its location in the grapb) are good + * enough even with the new vector. + * + * XXX: this function needs improvements: successive updates to the same + * node with more and more distant vectors will make the node drift away + * from its neighbors. One of the additional metrics used could be + * neighbor-to-neighbor distance, that represents a more absolute check + * of fit for the new vector. */ +int hnsw_should_reuse_node(HNSW *index, hnswNode *node, int is_normalized, const float *new_vector) { + /* Step 1: Not enough links? Advice to avoid reuse. */ + const uint32_t min_links_for_reuse = 4; + uint32_t layer0_connections = node->layers[0].num_links; + if (layer0_connections < min_links_for_reuse) return 0; + + /* Step2: get all current distances and run our heuristic. */ + float *old_distances = hmalloc(sizeof(float) * layer0_connections); + if (!old_distances) return 0; + + // Temporary node with the new vector, to simplify the next logic. + hnswNode tmp_node; + if (hnsw_init_tmp_node(index,&tmp_node,is_normalized,new_vector) == 0) { + hfree(old_distances); + return 0; + } + + /* Get old dinstances and sort them to access the 25% worst + * (bigger) ones. */ + for (uint32_t i = 0; i < layer0_connections; i++) { + old_distances[i] = hnsw_distance(index, node, node->layers[0].links[i]); + } + qsort(old_distances, layer0_connections, sizeof(float), + (int (*)(const void*, const void*))(&compare_floats)); + + uint32_t count = (layer0_connections+3)/4; // 25% approx to larger int. + if (count > layer0_connections) count = layer0_connections; // Futureproof. + float worst_avg = 0; + + // Compute average of 25% worst dinstances. + for (uint32_t i = 0; i < count; i++) worst_avg += old_distances[i]; + worst_avg /= count; + hfree(old_distances); + + // Count how many new distances stay below the threshold. + uint32_t good_distances = 0; + for (uint32_t i = 0; i < layer0_connections; i++) { + float new_dist = hnsw_distance(index, &tmp_node, node->layers[0].links[i]); + if (new_dist <= worst_avg) good_distances++; + } + hnsw_free_tmp_node(&tmp_node,new_vector); + + /* At least 50% of the nodes should pass our quality test, for the + * node to be reused. */ + return good_distances >= layer0_connections/2; +} + +/** + * Return a random node from the HNSW graph. + * + * This function performs a random walk starting from the entry point, + * using only level 0 connections for navigation. It uses log^2(N) steps + * to ensure proper mixing time. + */ + +hnswNode *hnsw_random_node(HNSW *index, int slot) { + if (index->node_count == 0 || index->enter_point == NULL) + return NULL; + + (void)slot; // Unused, but we need the caller to acquire the lock. + + /* First phase: descend from max level to level 0 taking random paths. + * Note that we don't need a more conservative log^2(N) steps for + * proper mixing, since we already descend to a random cluster here. */ + hnswNode *current = index->enter_point; + for (uint32_t level = index->max_level; level > 0; level--) { + /* If current node doesn't have this level or no links, continue + * to lower level. */ + if (current->level < level || current->layers[level].num_links == 0) + continue; + + /* Choose random neighbor at this level. */ + uint32_t rand_neighbor = rand() % current->layers[level].num_links; + current = current->layers[level].links[rand_neighbor]; + } + + /* Second phase: at level 0, take log(N) * c random steps. */ + const int c = 3; // Multiplier for more thorough exploration. + double logN = log2(index->node_count + 1); + uint32_t num_walks = (uint32_t)(logN * c); + + // Perform random walk at level 0. + for (uint32_t i = 0; i < num_walks; i++) { + if (current->layers[0].num_links == 0) return current; + + // Choose random neighbor. + uint32_t rand_neighbor = rand() % current->layers[0].num_links; + current = current->layers[0].links[rand_neighbor]; + } + return current; +} + +/* ============================= Serialization ============================== + * + * TO SERIALIZE + * ============ + * + * To serialize on disk, you need to persist the vector dimension, number + * of elements, and the quantization type index->quant_type. These are + * global values for the whole index. + * + * Then, to serialize each node: + * + * call hnsw_serialize_node() with each node you find in the linked list + * of nodes, starting at index->head (each node has a next pointer). + * The function will return an hnswSerNode structure, you will need + * to store the following on disk (for each node): + * + * - The sernode->vector data, that is sernode->vector_size bytes. + * - The sernode->params array, that points to an array of uint64_t + * integers. There are sernode->params_count total items. These + * parameters contain everything there is to need about your node: how + * many levels it has, its ID, the list of neighbors for each level (as node + * IDs), and so forth. + * + * You need to to save your own node->value in some way as well, but it already + * belongs to the user of the API, since, for this library, it's just a pointer, + * so the user should know how to serialized its private data. + * + * RELOADING FROM DISK / NET + * ========================= + * + * When reloading nodes, you first load the index vector dimension and + * quantization type, and create the index with: + * + * HNSW *hnsw_new(uint32_t vector_dim, uint32_t quant_type); + * + * Then you load back, for each node (you stored how many nodes you had) + * the vector and the params array / count. + * You also load the value associated with your node. + * + * At this point you add back the loaded elements into the index with: + * + * hnsw_insert_serialized(HNSW *index, void *vector, uint64_t params, + * uint32_t params_len, void *value); + * + * Once you added all the nodes back, you need to resolve the pointers + * (since so far they are added just with the node IDs as reference), so + * you call: + * + * hnsw_deserialize_index(index); + * + * The index is now ready to be used like if it has been always in memory. + * + * DESIGN NOTES + * ============ + * + * Why this API does not just give you a binary blob to save? Because in + * many systems (and in Redis itself) to save integers / floats can have + * more interesting encodings that just storing a 64 bit value. Many vector + * indexes will be small, and their IDs will be small numbers, so the storage + * system can exploit that and use less disk space, less network bandwidth + * and so forth. + * + * How is the data stored in these arrays of numbers? Oh well, we have + * things that are obviously numbers like node ID, number of levels for the + * node and so forth. Also each of our nodes have an unique incremental ID, + * so we can store a node set of links in terms of linked node IDs. This + * data is put directly in the loaded node pointer space! We just cast the + * integer to the pointer (so THIS IS NOT SAFE for 32 bit systems). Then + * we want to translate such IDs into pointers. To do that, we build an + * hash table, then scan all the nodes again and fix all the links converting + * the ID to the pointer. */ + +/* Return the serialized node information as specified in the top comment + * above. Note that the returned information is true as long as the node + * provided is not deleted or modified, so this function should be called + * when there are no concurrent writes. + * + * The function hnsw_serialize_node() should be called in order to + * free the result of this function. */ +hnswSerNode *hnsw_serialize_node(HNSW *index, hnswNode *node) { + /* The first step is calculating the number of uint64_t parameters + * that we need in order to serialize the node. */ + uint32_t num_params = 0; + num_params += 2; // node ID, number of layers. + for (uint32_t i = 0; i <= node->level; i++) { + num_params += 2; // max_links and num_links info for this layer. + num_params += node->layers[i].num_links; // The IDs of linked nodes. + } + + /* We use another 64bit value to store two floats that are about + * the vector: l2 and quantization range (that is only used if the + * vector is quantized). */ + num_params++; + + /* Allocate the return object and the parameters array. */ + hnswSerNode *sn = hmalloc(sizeof(hnswSerNode)); + if (sn == NULL) return NULL; + sn->params = hmalloc(sizeof(uint64_t)*num_params); + if (sn->params == NULL) { + hfree(sn); + return NULL; + } + + /* Fill data. */ + sn->params_count = num_params; + sn->vector = node->vector; + sn->vector_size = hnsw_quants_bytes(index); + + uint32_t param_idx = 0; + sn->params[param_idx++] = node->id; + sn->params[param_idx++] = node->level; + for (uint32_t i = 0; i <= node->level; i++) { + sn->params[param_idx++] = node->layers[i].num_links; + sn->params[param_idx++] = node->layers[i].max_links; + for (uint32_t j = 0; j < node->layers[i].num_links; j++) { + sn->params[param_idx++] = node->layers[i].links[j]->id; + } + } + uint64_t l2_and_range = 0; + unsigned char *aux = (unsigned char*)&l2_and_range; + memcpy(aux,&node->l2,sizeof(float)); + memcpy(aux+4,&node->quants_range,sizeof(float)); + sn->params[param_idx++] = l2_and_range; + + /* Better safe than sorry: */ + assert(param_idx == num_params); + return sn; +} + +/* This is needed in order to free hnsw_serialize_node() returned + * structure. */ +void hnsw_free_serialized_node(hnswSerNode *sn) { + hfree(sn->params); + hfree(sn); +} + +/* Load a serialized node. See the top comment in this section of code + * for the documentation about how to use this. + * + * The function returns NULL both on out of memory and if the remaining + * parameters length does not match the number of links or other items + * to load. */ +hnswNode *hnsw_insert_serialized(HNSW *index, void *vector, uint64_t *params, uint32_t params_len, void *value) +{ + if (params_len < 2) return NULL; + + uint64_t id = params[0]; + uint32_t level = params[1]; + + /* Keep track of maximum ID seen while loading. */ + if (id >= index->last_id) index->last_id = id; + + /* Create node, passing vector data directly based on quantization type. */ + hnswNode *node; + if (index->quant_type != HNSW_QUANT_NONE) { + node = hnsw_node_new(index, id, NULL, vector, 0, level, 0); + } else { + node = hnsw_node_new(index, id, vector, NULL, 0, level, 0); + } + if (!node) return NULL; + + /* Load params array into the node. */ + uint32_t param_idx = 2; + for (uint32_t i = 0; i <= level; i++) { + /* Sanity check. */ + if (param_idx + 2 > params_len) { + hnsw_node_free(node); + return NULL; + } + + uint32_t num_links = params[param_idx++]; + uint32_t max_links = params[param_idx++]; + + /* If max_links is larger than current allocation, reallocate. + * It could happen in select_neighbors() that we over-allocate the + * node under very unlikely to happen conditions. */ + if (max_links > node->layers[i].max_links) { + hnswNode **new_links = hrealloc(node->layers[i].links, + sizeof(hnswNode*) * max_links); + if (!new_links) { + hnsw_node_free(node); + return NULL; + } + node->layers[i].links = new_links; + node->layers[i].max_links = max_links; + } + node->layers[i].num_links = num_links; + + /* Sanity check. */ + if (param_idx + num_links > params_len) { + hnsw_node_free(node); + return NULL; + } + + /* Fill links for this layer with the IDs. Note that this + * is going to not work in 32 bit systems. Deleting / adding-back + * nodes can produce IDs larger than 2^32-1 even if we can't never + * fit more than 2^32 nodes in a 32 bit system. */ + for (uint32_t j = 0; j < num_links; j++) + node->layers[i].links[j] = (hnswNode*)params[param_idx++]; + } + + /* Get l2 and quantization range. */ + if (param_idx >= params_len) { + hnsw_node_free(node); + return NULL; + } + uint64_t l2_and_range = params[param_idx]; + unsigned char *aux = (unsigned char*)&l2_and_range; + memcpy(&node->l2, aux, sizeof(float)); + memcpy(&node->quants_range, aux+4, sizeof(float)); + + node->value = value; + hnsw_add_node(index, node); + + /* Keep track of higher node level and set the entry point to the + * greatest level node seen so far: thanks to this check we don't + * need to remember what our entry point was during serialization. */ + if (index->enter_point == NULL || level > index->max_level) { + index->max_level = level; + index->enter_point = node; + } + return node; +} + +/* Integer hashing, used by hnsw_deserialize_index(). + * MurmurHash3's 64-bit finalizer function. */ +uint64_t hnsw_hash_node_id(uint64_t id) { + id ^= id >> 33; + id *= 0xff51afd7ed558ccd; + id ^= id >> 33; + id *= 0xc4ceb9fe1a85ec53; + id ^= id >> 33; + return id; +} + +/* Fix pointers of neighbors nodes: after loading the serialized nodes, the + * neighbors links are just IDs (casted to pointers), instead of the actual + * pointers. We need to resolve IDs into pointers. + * + * Return 0 on error (out of memory or some ID that can't be resolved), 1 on + * success. */ +int hnsw_deserialize_index(HNSW *index) { + /* We will use simple linear probing, so over-allocating is a good + * idea: anyway this flat array of pointers will consume a fraction + * of the memory of the loaded index. */ + uint64_t min_size = index->node_count*2; + uint64_t table_size = 1; + while(table_size < min_size) table_size <<= 1; + + hnswNode **table = hmalloc(sizeof(hnswNode*) * table_size); + if (table == NULL) return 0; + memset(table,0,sizeof(hnswNode*) * table_size); + + /* First pass: populate the ID -> pointer hash table. */ + hnswNode *node = index->head; + while(node) { + uint64_t bucket = hnsw_hash_node_id(node->id) & (table_size-1); + for (uint64_t j = 0; j < table_size; j++) { + if (table[bucket] == NULL) { + table[bucket] = node; + break; + } + bucket = (bucket+1) & (table_size-1); + } + node = node->next; + } + + /* Second pass: fix pointers of all the neighbors links. */ + node = index->head; // Rewind. + while(node) { + for (uint32_t i = 0; i <= node->level; i++) { + for (uint32_t j = 0; j < node->layers[i].num_links; j++) { + uint64_t linked_id = (uint64_t) node->layers[i].links[j]; + uint64_t bucket = hnsw_hash_node_id(linked_id) & (table_size-1); + hnswNode *neighbor = NULL; + for (uint64_t k = 0; k < table_size; k++) { + if (table[bucket] && table[bucket]->id == linked_id) { + neighbor = table[bucket]; + break; + } + bucket = (bucket+1) & (table_size-1); + } + if (neighbor == NULL) { + /* Unresolved link! Either a bug in this code + * or broken serialization data. */ + hfree(table); + return 0; + } + node->layers[i].links[j] = neighbor; + } + } + node = node->next; + } + hfree(table); + return 1; +} + +/* ================================ Iterator ================================ */ + +/* Get a cursor that can be used as argument of hnsw_cursor_next() to iterate + * all the elements that remain there from the start to the end of the + * iteration, excluding newly added elements. + * + * The function returns NULL on out of memory. */ +hnswCursor *hnsw_cursor_init(HNSW *index) { + if (pthread_rwlock_wrlock(&index->global_lock) != 0) return NULL; + hnswCursor *cursor = hmalloc(sizeof(*cursor)); + if (cursor == NULL) { + pthread_rwlock_unlock(&index->global_lock); + return NULL; + } + cursor->index = index; + cursor->next = index->cursors; + cursor->current = index->head; + index->cursors = cursor; + pthread_rwlock_unlock(&index->global_lock); + return cursor; +} + +/* Free the cursor. Can be called both at the end of the iteration, when + * hnsw_cursor_next() returned NULL, or before. */ +void hnsw_cursor_free(hnswCursor *cursor) { + if (pthread_rwlock_wrlock(&cursor->index->global_lock) != 0) { + // No easy way to recover from that. We will leak memory. + return; + } + + hnswCursor *x = cursor->index->cursors; + hnswCursor *prev = NULL; + while(x) { + if (x == cursor) { + if (prev) + prev->next = cursor->next; + else + cursor->index->cursors = cursor->next; + hfree(cursor); + break; + } + x = x->next; + } + pthread_rwlock_unlock(&cursor->index->global_lock); +} + +/* Acquire a lock to use the cursor. Returns 1 if the lock was acquired + * with success, otherwise zero is returned. The returned element is + * protected after calling hnsw_cursor_next() for all the time required to + * access it, then hnsw_cursor_release_lock() should be called in order + * to unlock the HNSW index. */ +int hnsw_cursor_acquire_lock(hnswCursor *cursor) { + return pthread_rwlock_rdlock(&cursor->index->global_lock) == 0; +} + +/* Release the cursor lock, see hnsw_cursor_acquire_lock() top comment + * for more information. */ +void hnsw_cursor_release_lock(hnswCursor *cursor) { + pthread_rwlock_unlock(&cursor->index->global_lock); +} + +/* Return the next element of the HNSW. See hnsw_cursor_init() for + * the guarantees of the function. */ +hnswNode *hnsw_cursor_next(hnswCursor *cursor) { + hnswNode *ret = cursor->current; + if (ret) cursor->current = ret->next; + return ret; +} + +/* Called by hnsw_unlink_node() if there is at least an active cursor. + * Will scan the cursors to see if any cursor is going to yeld this + * one, and in this case, updates the current element to the next. */ +void hnsw_cursor_element_deleted(HNSW *index, hnswNode *deleted) { + hnswCursor *x = index->cursors; + while(x) { + if (x->current == deleted) x->current = deleted->next; + x = x->next; + } +} + +/* ============================ Debugging stuff ============================= */ + +/* Show stats about nodes connections. */ +void hnsw_print_stats(HNSW *index) { + if (!index || !index->head) { + printf("Empty index or NULL pointer passed\n"); + return; + } + + long long total_links = 0; + int min_links = -1; // We'll set this to first node's count. + int isolated_nodes = 0; + uint32_t node_count = 0; + + // Iterate through all nodes using the linked list. + hnswNode *current = index->head; + while (current) { + // Count total links for this node across all layers. + int node_total_links = 0; + for (uint32_t layer = 0; layer <= current->level; layer++) + node_total_links += current->layers[layer].num_links; + + // Update statistics. + total_links += node_total_links; + + // Initialize or update minimum links. + if (min_links == -1 || node_total_links < min_links) { + min_links = node_total_links; + } + + // Check if node is isolated (no links at all). + if (node_total_links == 0) isolated_nodes++; + + node_count++; + current = current->next; + } + + // Print statistics + printf("HNSW Graph Statistics:\n"); + printf("----------------------\n"); + printf("Total nodes: %u\n", node_count); + if (node_count > 0) { + printf("Average links per node: %.2f\n", + (float)total_links / node_count); + printf("Minimum links in a single node: %d\n", min_links); + printf("Number of isolated nodes: %d (%.1f%%)\n", + isolated_nodes, + (float)isolated_nodes * 100 / node_count); + } +} + +/* Validate graph connectivity and link reciprocity. Takes pointers to store results: + * - connected_nodes: will contain number of reachable nodes from entry point. + * - reciprocal_links: will contain 1 if all links are reciprocal, 0 otherwise. + * Returns 0 on success, -1 on error (NULL parameters and such). + */ +int hnsw_validate_graph(HNSW *index, uint64_t *connected_nodes, int *reciprocal_links) { + if (!index || !connected_nodes || !reciprocal_links) return -1; + if (!index->enter_point) { + *connected_nodes = 0; + *reciprocal_links = 1; // Empty graph is valid. + return 0; + } + + // Initialize connectivity check. + index->current_epoch[0]++; + *connected_nodes = 0; + *reciprocal_links = 1; + + // Initialize node stack. + uint64_t stack_size = index->node_count; + hnswNode **stack = hmalloc(sizeof(hnswNode*) * stack_size); + if (!stack) return -1; + uint64_t stack_top = 0; + + // Start from entry point. + index->enter_point->visited_epoch[0] = index->current_epoch[0]; + (*connected_nodes)++; + stack[stack_top++] = index->enter_point; + + // Process all reachable nodes. + while (stack_top > 0) { + hnswNode *current = stack[--stack_top]; + + // Explore all neighbors at each level. + for (uint32_t level = 0; level <= current->level; level++) { + for (uint64_t i = 0; i < current->layers[level].num_links; i++) { + hnswNode *neighbor = current->layers[level].links[i]; + + // Check reciprocity. + int found_backlink = 0; + for (uint64_t j = 0; j < neighbor->layers[level].num_links; j++) { + if (neighbor->layers[level].links[j] == current) { + found_backlink = 1; + break; + } + } + if (!found_backlink) { + *reciprocal_links = 0; + } + + // If we haven't visited this neighbor yet. + if (neighbor->visited_epoch[0] != index->current_epoch[0]) { + neighbor->visited_epoch[0] = index->current_epoch[0]; + (*connected_nodes)++; + if (stack_top < stack_size) { + stack[stack_top++] = neighbor; + } else { + // This should never happen in a valid graph. + hfree(stack); + return -1; + } + } + } + } + } + + hfree(stack); + + // Now scan for unreachable nodes and print debug info. + printf("\nUnreachable nodes debug information:\n"); + printf("=====================================\n"); + + hnswNode *current = index->head; + while (current) { + if (current->visited_epoch[0] != index->current_epoch[0]) { + printf("\nUnreachable node found:\n"); + printf("- Node pointer: %p\n", (void*)current); + printf("- Node ID: %llu\n", (unsigned long long)current->id); + printf("- Node level: %u\n", current->level); + + // Print info about all its links at each level. + for (uint32_t level = 0; level <= current->level; level++) { + printf(" Level %u links (%u):\n", level, + current->layers[level].num_links); + for (uint64_t i = 0; i < current->layers[level].num_links; i++) { + hnswNode *neighbor = current->layers[level].links[i]; + // Check reciprocity for this specific link + int found_backlink = 0; + for (uint64_t j = 0; j < neighbor->layers[level].num_links; j++) { + if (neighbor->layers[level].links[j] == current) { + found_backlink = 1; + break; + } + } + printf(" - Link %llu: pointer=%p, id=%llu, visited=%s,recpr=%s\n", + (unsigned long long)i, (void*)neighbor, + (unsigned long long)neighbor->id, + neighbor->visited_epoch[0] == index->current_epoch[0] ? + "yes" : "no", + found_backlink ? "yes" : "no"); + } + } + } + current = current->next; + } + + printf("Total connected nodes: %llu\n", (unsigned long long)*connected_nodes); + printf("All links are bi-directiona? %s\n", (*reciprocal_links)?"yes":"no"); + return 0; +} + +/* Test graph recall ability by verifying each node can be found searching + * for its own vector. This helps validate that the majority of nodes are + * properly connected and easily reachable in the graph structure. Every + * unreachable node is reported. + * + * Normally only a small percentage of nodes will be not reachable when + * visited. This is expected and part of the statistical properties + * of HNSW. This happens especially with entries that have an ambiguous + * meaning in the represented space, and are across two or multiple clusters + * of items. + * + * The function works by: + * 1. Iterating through all nodes in the linked list + * 2. Using each node's vector to perform a search with specified EF + * 3. Verifying the node can find itself as nearest neighbor + * 4. Collecting and reporting statistics about reachability + * + * This is just a debugging function that reports stuff in the standard + * output, part of the implementation because this kind of functions + * provide some visiblity on what happens inside the HNSW. + */ +void hnsw_test_graph_recall(HNSW *index, int test_ef, int verbose) { + // Stats + uint32_t total_nodes = 0; + uint32_t unreachable_nodes = 0; + uint32_t perfectly_reachable = 0; // Node finds itself as first result + + // For storing search results + hnswNode **neighbors = hmalloc(sizeof(hnswNode*) * test_ef); + float *distances = hmalloc(sizeof(float) * test_ef); + float *test_vector = hmalloc(sizeof(float) * index->vector_dim); + if (!neighbors || !distances || !test_vector) { + hfree(neighbors); + hfree(distances); + hfree(test_vector); + return; + } + + // Get a read slot for searching (even if it's highly unlikely that + // this test will be run threaded...). + int slot = hnsw_acquire_read_slot(index); + if (slot < 0) { + hfree(neighbors); + hfree(distances); + return; + } + + printf("\nTesting graph recall\n"); + printf("====================\n"); + + // Process one node at a time using the linked list + hnswNode *current = index->head; + while (current) { + total_nodes++; + + // If using quantization, we need to reconstruct the normalized vector + if (index->quant_type == HNSW_QUANT_Q8) { + int8_t *quants = current->vector; + // Reconstruct normalized vector from quantized data + for (uint32_t j = 0; j < index->vector_dim; j++) { + test_vector[j] = (quants[j] * current->quants_range) / 127; + } + } else if (index->quant_type == HNSW_QUANT_NONE) { + memcpy(test_vector,current->vector,sizeof(float)*index->vector_dim); + } else { + assert(0 && "Quantization type not supported."); + } + + // Search using the node's own vector with high ef + int found = hnsw_search(index, test_vector, test_ef, neighbors, + distances, slot, 1); + + if (found == 0) continue; // Empty HNSW? + + // Look for the node itself in the results + int found_self = 0; + int self_position = -1; + for (int i = 0; i < found; i++) { + if (neighbors[i] == current) { + found_self = 1; + self_position = i; + break; + } + } + + if (!found_self || self_position != 0) { + unreachable_nodes++; + if (verbose) { + if (!found_self) + printf("\nNode %s cannot find itself:\n", (char*)current->value); + else + printf("\nNode %s is not top result:\n", (char*)current->value); + printf("- Node ID: %llu\n", (unsigned long long)current->id); + printf("- Node level: %u\n", current->level); + printf("- Found %d neighbors but self not among them\n", found); + printf("- Closest neighbor distance: %f\n", distances[0]); + printf("- Neighbors: "); + for (uint32_t i = 0; i < current->layers[0].num_links; i++) { + printf("%s ", (char*)current->layers[0].links[i]->value); + } + printf("\n"); + printf("\nFound instead: "); + for (int j = 0; j < found && j < 10; j++) { + printf("%s ", (char*)neighbors[j]->value); + } + printf("\n"); + } + } else { + perfectly_reachable++; + } + current = current->next; + } + + // Release read slot + hnsw_release_read_slot(index, slot); + + // Free resources + hfree(neighbors); + hfree(distances); + hfree(test_vector); + + // Print final statistics + printf("Total nodes tested: %u\n", total_nodes); + printf("Perfectly reachable nodes: %u (%.1f%%)\n", + perfectly_reachable, + total_nodes ? (float)perfectly_reachable * 100 / total_nodes : 0); + printf("Unreachable/suboptimal nodes: %u (%.1f%%)\n", + unreachable_nodes, + total_nodes ? (float)unreachable_nodes * 100 / total_nodes : 0); +} + +/* Return exact K-NN items by performing a linear scan of all nodes. + * This function has the same signature as hnsw_search_with_filter() but + * instead of using the graph structure, it scans all nodes to find the + * true nearest neighbors. + * + * Note that neighbors and distances arrays must have space for at least 'k' items. + * norm_query should be set to 1 if the query vector is already normalized. + * + * If the filter_callback is passed, only elements passing the specified filter + * are returned. The slot parameter is ignored but kept for API consistency. */ +int hnsw_ground_truth_with_filter + (HNSW *index, const float *query_vector, uint32_t k, + hnswNode **neighbors, float *distances, uint32_t slot, + int query_vector_is_normalized, + int (*filter_callback)(void *value, void *privdata), + void *filter_privdata) +{ + /* Note that we don't really use the slot here: it's a linear scan. + * Yet we want the user to acquire the slot as this will hold the + * global lock in read only mode. */ + (void) slot; + + /* Take our query vector into a temporary node. */ + hnswNode query; + if (hnsw_init_tmp_node(index, &query, query_vector_is_normalized, query_vector) == 0) return -1; + + /* Accumulate best results into a priority queue. */ + pqueue *results = pq_new(k); + if (!results) { + hnsw_free_tmp_node(&query, query_vector); + return -1; + } + + /* Scan all nodes linearly. */ + hnswNode *current = index->head; + while (current) { + /* Apply filter if needed. */ + if (filter_callback && + !filter_callback(current->value, filter_privdata)) + { + current = current->next; + continue; + } + + /* Calculate distance to query. */ + float dist = hnsw_distance(index, &query, current); + + /* Add to results to pqueue. Will be accepted only if better than + * the current worse or pqueue not full. */ + pq_push(results, current, dist); + current = current->next; + } + + /* Copy results to output arrays. */ + uint32_t found = MIN(k, results->count); + for (uint32_t i = 0; i < found; i++) { + neighbors[i] = pq_get_node(results, i); + if (distances) distances[i] = pq_get_distance(results, i); + } + + /* Clean up. */ + pq_free(results); + hnsw_free_tmp_node(&query, query_vector); + return found; +} diff --git a/modules/vector-sets/hnsw.h b/modules/vector-sets/hnsw.h new file mode 100644 index 000000000..877302e50 --- /dev/null +++ b/modules/vector-sets/hnsw.h @@ -0,0 +1,183 @@ +/* + * HNSW (Hierarchical Navigable Small World) Implementation + * Based on the paper by Yu. A. Malkov, D. A. Yashunin + * + * Copyright(C) 2024-Pesent Redis Ltd. All Rights Reserved. + */ + +#ifndef HNSW_H +#define HNSW_H + +#include +#include + +#define HNSW_DEFAULT_M 16 /* Used when 0 is given at creation time. */ +#define HNSW_MIN_M 4 /* Probably even too low already. */ +#define HNSW_MAX_M 4096 /* Safeguard sanity limit. */ +#define HNSW_MAX_THREADS 32 /* Maximum number of concurrent threads */ + +/* Quantization types you can enable at creation time in hnsw_new() */ +#define HNSW_QUANT_NONE 0 // No quantization. +#define HNSW_QUANT_Q8 1 // Q8 quantization. +#define HNSW_QUANT_BIN 2 // Binary quantization. + +/* Layer structure for HNSW nodes. Each node will have from one to a few + * of this depending on its level. */ +typedef struct { + struct hnswNode **links; /* Array of neighbors for this layer */ + uint32_t num_links; /* Number of used links */ + uint32_t max_links; /* Maximum links for this layer. We may + * reallocate the node in very particular + * conditions in order to allow linking of + * new inserted nodes, so this may change + * dynamically and be > M*2 for a small set of + * nodes. */ + float worst_distance; /* Distance to the worst neighbor */ + uint32_t worst_idx; /* Index of the worst neighbor */ +} hnswNodeLayer; + +/* Node structure for HNSW graph */ +typedef struct hnswNode { + uint32_t level; /* Node's maximum level */ + uint64_t id; /* Unique identifier, may be useful in order to + * have a bitmap of visited notes to use as + * alternative to epoch / visited_epoch. + * Also used in serialization in order to retain + * links specifying IDs. */ + void *vector; /* The vector, quantized or not. */ + float quants_range; /* Quantization range for this vector: + * min/max values will be in the range + * -quants_range, +quants_range */ + float l2; /* L2 before normalization. */ + + /* Last time (epoch) this node was visited. We need one per thread. + * This avoids having a different data structure where we track + * visited nodes, but costs memory per node. */ + uint64_t visited_epoch[HNSW_MAX_THREADS]; + + void *value; /* Associated value */ + struct hnswNode *prev, *next; /* Prev/Next node in the list starting at + * HNSW->head. */ + + /* Links (and links info) per each layer. Note that this is part + * of the node allocation to be more cache friendly: reliable 3% speedup + * on Apple silicon, and does not make anything more complex. */ + hnswNodeLayer layers[]; +} hnswNode; + +struct HNSW; + +/* It is possible to navigate an HNSW with a cursor that guarantees + * visiting all the elements that remain in the HNSW from the start to the + * end of the process (but not the new ones, so that the process will + * eventually finish). Check hnsw_cursor_init(), hnsw_cursor_next() and + * hnsw_cursor_free(). */ +typedef struct hnswCursor { + struct HNSW *index; // Reference to the index of this cursor. + hnswNode *current; // Element to report when hnsw_cursor_next() is called. + struct hnswCursor *next; // Next cursor active. +} hnswCursor; + +/* Main HNSW index structure */ +typedef struct HNSW { + hnswNode *enter_point; /* Entry point for the graph */ + uint32_t M; /* M as in the paper: layer 0 has M*2 max + neighbors (M populated at insertion time) + while all the other layers have M neighbors. */ + uint32_t max_level; /* Current maximum level in the graph */ + uint32_t vector_dim; /* Dimensionality of stored vectors */ + uint64_t node_count; /* Total number of nodes */ + _Atomic uint64_t last_id; /* Last node ID used */ + uint64_t current_epoch[HNSW_MAX_THREADS]; /* Current epoch for visit tracking */ + hnswNode *head; /* Linked list of nodes. Last first */ + + /* We have two locks here: + * 1. A global_lock that is used to perform write operations blocking all + * the readers. + * 2. One mutex per epoch slot, in order for read operations to acquire + * a lock on a specific slot to use epochs tracking of visited nodes. */ + pthread_rwlock_t global_lock; /* Global read-write lock */ + pthread_mutex_t slot_locks[HNSW_MAX_THREADS]; /* Per-slot locks */ + + _Atomic uint32_t next_slot; /* Next thread slot to try */ + _Atomic uint64_t version; /* Version for optimistic concurrency, this is + * incremented on deletions and entry point + * updates. */ + uint32_t quant_type; /* Quantization used. HNSW_QUANT_... */ + hnswCursor *cursors; +} HNSW; + +/* Serialized node. This structure is used as return value of + * hnsw_serialize_node(). */ +typedef struct hnswSerNode { + void *vector; + uint32_t vector_size; + uint64_t *params; + uint32_t params_count; +} hnswSerNode; + +/* Insert preparation context */ +typedef struct InsertContext InsertContext; + +/* Core HNSW functions */ +HNSW *hnsw_new(uint32_t vector_dim, uint32_t quant_type, uint32_t m); +void hnsw_free(HNSW *index,void(*free_value)(void*value)); +void hnsw_node_free(hnswNode *node); +void hnsw_print_stats(HNSW *index); +hnswNode *hnsw_insert(HNSW *index, const float *vector, const int8_t *qvector, + float qrange, uint64_t id, void *value, int ef); +int hnsw_search(HNSW *index, const float *query, uint32_t k, + hnswNode **neighbors, float *distances, uint32_t slot, + int query_vector_is_normalized); +int hnsw_search_with_filter + (HNSW *index, const float *query_vector, uint32_t k, + hnswNode **neighbors, float *distances, uint32_t slot, + int query_vector_is_normalized, + int (*filter_callback)(void *value, void *privdata), + void *filter_privdata, uint32_t max_candidates); +void hnsw_get_node_vector(HNSW *index, hnswNode *node, float *vec); +int hnsw_delete_node(HNSW *index, hnswNode *node, void(*free_value)(void*value)); +hnswNode *hnsw_random_node(HNSW *index, int slot); + +/* Thread safety functions. */ +int hnsw_acquire_read_slot(HNSW *index); +void hnsw_release_read_slot(HNSW *index, int slot); + +/* Optimistic insertion API. */ +InsertContext *hnsw_prepare_insert(HNSW *index, const float *vector, const int8_t *qvector, float qrange, uint64_t id, int ef); +hnswNode *hnsw_try_commit_insert(HNSW *index, InsertContext *ctx, void *value); +void hnsw_free_insert_context(InsertContext *ctx); + +/* Serialization. */ +hnswSerNode *hnsw_serialize_node(HNSW *index, hnswNode *node); +void hnsw_free_serialized_node(hnswSerNode *sn); +hnswNode *hnsw_insert_serialized(HNSW *index, void *vector, uint64_t *params, uint32_t params_len, void *value); +int hnsw_deserialize_index(HNSW *index); + +// Helper function in case the user wants to directly copy +// the vector bytes. +uint32_t hnsw_quants_bytes(HNSW *index); + +/* Cursors. */ +hnswCursor *hnsw_cursor_init(HNSW *index); +void hnsw_cursor_free(hnswCursor *cursor); +hnswNode *hnsw_cursor_next(hnswCursor *cursor); +int hnsw_cursor_acquire_lock(hnswCursor *cursor); +void hnsw_cursor_release_lock(hnswCursor *cursor); + +/* Allocator selection. */ +void hnsw_set_allocator(void (*free_ptr)(void*), void *(*malloc_ptr)(size_t), + void *(*realloc_ptr)(void*, size_t)); + +/* Testing. */ +int hnsw_validate_graph(HNSW *index, uint64_t *connected_nodes, int *reciprocal_links); +void hnsw_test_graph_recall(HNSW *index, int test_ef, int verbose); +float hnsw_distance(HNSW *index, hnswNode *a, hnswNode *b); +int hnsw_ground_truth_with_filter + (HNSW *index, const float *query_vector, uint32_t k, + hnswNode **neighbors, float *distances, uint32_t slot, + int query_vector_is_normalized, + int (*filter_callback)(void *value, void *privdata), + void *filter_privdata); + +#endif /* HNSW_H */ diff --git a/modules/vector-sets/redismodule.h b/modules/vector-sets/redismodule.h new file mode 100644 index 000000000..b84913b1e --- /dev/null +++ b/modules/vector-sets/redismodule.h @@ -0,0 +1,1704 @@ +#ifndef REDISMODULE_H +#define REDISMODULE_H + +#include +#include +#include +#include + + +typedef struct RedisModuleString RedisModuleString; +typedef struct RedisModuleKey RedisModuleKey; + +/* -------------- Defines NOT common between core and modules ------------- */ + +#if defined REDISMODULE_CORE +/* Things only defined for the modules core (server), not exported to modules + * that include this file. */ + +#define RedisModuleString robj + +#endif /* defined REDISMODULE_CORE */ + +#if !defined REDISMODULE_CORE && !defined REDISMODULE_CORE_MODULE +/* Things defined for modules, but not for core-modules. */ + +typedef long long mstime_t; +typedef long long ustime_t; + +#endif /* !defined REDISMODULE_CORE && !defined REDISMODULE_CORE_MODULE */ + +/* ---------------- Defines common between core and modules --------------- */ + +/* Error status return values. */ +#define REDISMODULE_OK 0 +#define REDISMODULE_ERR 1 + +/* Module Based Authentication status return values. */ +#define REDISMODULE_AUTH_HANDLED 0 +#define REDISMODULE_AUTH_NOT_HANDLED 1 + +/* API versions. */ +#define REDISMODULE_APIVER_1 1 + +/* Version of the RedisModuleTypeMethods structure. Once the RedisModuleTypeMethods + * structure is changed, this version number needs to be changed synchronistically. */ +#define REDISMODULE_TYPE_METHOD_VERSION 5 + +/* API flags and constants */ +#define REDISMODULE_READ (1<<0) +#define REDISMODULE_WRITE (1<<1) + +/* RedisModule_OpenKey extra flags for the 'mode' argument. + * Avoid touching the LRU/LFU of the key when opened. */ +#define REDISMODULE_OPEN_KEY_NOTOUCH (1<<16) +/* Don't trigger keyspace event on key misses. */ +#define REDISMODULE_OPEN_KEY_NONOTIFY (1<<17) +/* Don't update keyspace hits/misses counters. */ +#define REDISMODULE_OPEN_KEY_NOSTATS (1<<18) +/* Avoid deleting lazy expired keys. */ +#define REDISMODULE_OPEN_KEY_NOEXPIRE (1<<19) +/* Avoid any effects from fetching the key */ +#define REDISMODULE_OPEN_KEY_NOEFFECTS (1<<20) +/* Allow access expired key that haven't deleted yet */ +#define REDISMODULE_OPEN_KEY_ACCESS_EXPIRED (1<<21) + +/* Mask of all REDISMODULE_OPEN_KEY_* values. Any new mode should be added to this list. + * Should not be used directly by the module, use RM_GetOpenKeyModesAll instead. + * Located here so when we will add new modes we will not forget to update it. */ +#define _REDISMODULE_OPEN_KEY_ALL REDISMODULE_READ | REDISMODULE_WRITE | REDISMODULE_OPEN_KEY_NOTOUCH | REDISMODULE_OPEN_KEY_NONOTIFY | REDISMODULE_OPEN_KEY_NOSTATS | REDISMODULE_OPEN_KEY_NOEXPIRE | REDISMODULE_OPEN_KEY_NOEFFECTS | REDISMODULE_OPEN_KEY_ACCESS_EXPIRED + +/* List push and pop */ +#define REDISMODULE_LIST_HEAD 0 +#define REDISMODULE_LIST_TAIL 1 + +/* Key types. */ +#define REDISMODULE_KEYTYPE_EMPTY 0 +#define REDISMODULE_KEYTYPE_STRING 1 +#define REDISMODULE_KEYTYPE_LIST 2 +#define REDISMODULE_KEYTYPE_HASH 3 +#define REDISMODULE_KEYTYPE_SET 4 +#define REDISMODULE_KEYTYPE_ZSET 5 +#define REDISMODULE_KEYTYPE_MODULE 6 +#define REDISMODULE_KEYTYPE_STREAM 7 + +/* Reply types. */ +#define REDISMODULE_REPLY_UNKNOWN -1 +#define REDISMODULE_REPLY_STRING 0 +#define REDISMODULE_REPLY_ERROR 1 +#define REDISMODULE_REPLY_INTEGER 2 +#define REDISMODULE_REPLY_ARRAY 3 +#define REDISMODULE_REPLY_NULL 4 +#define REDISMODULE_REPLY_MAP 5 +#define REDISMODULE_REPLY_SET 6 +#define REDISMODULE_REPLY_BOOL 7 +#define REDISMODULE_REPLY_DOUBLE 8 +#define REDISMODULE_REPLY_BIG_NUMBER 9 +#define REDISMODULE_REPLY_VERBATIM_STRING 10 +#define REDISMODULE_REPLY_ATTRIBUTE 11 +#define REDISMODULE_REPLY_PROMISE 12 + +/* Postponed array length. */ +#define REDISMODULE_POSTPONED_ARRAY_LEN -1 /* Deprecated, please use REDISMODULE_POSTPONED_LEN */ +#define REDISMODULE_POSTPONED_LEN -1 + +/* Expire */ +#define REDISMODULE_NO_EXPIRE -1 + +/* Sorted set API flags. */ +#define REDISMODULE_ZADD_XX (1<<0) +#define REDISMODULE_ZADD_NX (1<<1) +#define REDISMODULE_ZADD_ADDED (1<<2) +#define REDISMODULE_ZADD_UPDATED (1<<3) +#define REDISMODULE_ZADD_NOP (1<<4) +#define REDISMODULE_ZADD_GT (1<<5) +#define REDISMODULE_ZADD_LT (1<<6) + +/* Hash API flags. */ +#define REDISMODULE_HASH_NONE 0 +#define REDISMODULE_HASH_NX (1<<0) +#define REDISMODULE_HASH_XX (1<<1) +#define REDISMODULE_HASH_CFIELDS (1<<2) +#define REDISMODULE_HASH_EXISTS (1<<3) +#define REDISMODULE_HASH_COUNT_ALL (1<<4) + +#define REDISMODULE_CONFIG_DEFAULT 0 /* This is the default for a module config. */ +#define REDISMODULE_CONFIG_IMMUTABLE (1ULL<<0) /* Can this value only be set at startup? */ +#define REDISMODULE_CONFIG_SENSITIVE (1ULL<<1) /* Does this value contain sensitive information */ +#define REDISMODULE_CONFIG_HIDDEN (1ULL<<4) /* This config is hidden in `config get ` (used for tests/debugging) */ +#define REDISMODULE_CONFIG_PROTECTED (1ULL<<5) /* Becomes immutable if enable-protected-configs is enabled. */ +#define REDISMODULE_CONFIG_DENY_LOADING (1ULL<<6) /* This config is forbidden during loading. */ + +#define REDISMODULE_CONFIG_MEMORY (1ULL<<7) /* Indicates if this value can be set as a memory value */ +#define REDISMODULE_CONFIG_BITFLAGS (1ULL<<8) /* Indicates if this value can be set as a multiple enum values */ + +/* StreamID type. */ +typedef struct RedisModuleStreamID { + uint64_t ms; + uint64_t seq; +} RedisModuleStreamID; + +/* StreamAdd() flags. */ +#define REDISMODULE_STREAM_ADD_AUTOID (1<<0) +/* StreamIteratorStart() flags. */ +#define REDISMODULE_STREAM_ITERATOR_EXCLUSIVE (1<<0) +#define REDISMODULE_STREAM_ITERATOR_REVERSE (1<<1) +/* StreamIteratorTrim*() flags. */ +#define REDISMODULE_STREAM_TRIM_APPROX (1<<0) + +/* Context Flags: Info about the current context returned by + * RM_GetContextFlags(). */ + +/* The command is running in the context of a Lua script */ +#define REDISMODULE_CTX_FLAGS_LUA (1<<0) +/* The command is running inside a Redis transaction */ +#define REDISMODULE_CTX_FLAGS_MULTI (1<<1) +/* The instance is a master */ +#define REDISMODULE_CTX_FLAGS_MASTER (1<<2) +/* The instance is a slave */ +#define REDISMODULE_CTX_FLAGS_SLAVE (1<<3) +/* The instance is read-only (usually meaning it's a slave as well) */ +#define REDISMODULE_CTX_FLAGS_READONLY (1<<4) +/* The instance is running in cluster mode */ +#define REDISMODULE_CTX_FLAGS_CLUSTER (1<<5) +/* The instance has AOF enabled */ +#define REDISMODULE_CTX_FLAGS_AOF (1<<6) +/* The instance has RDB enabled */ +#define REDISMODULE_CTX_FLAGS_RDB (1<<7) +/* The instance has Maxmemory set */ +#define REDISMODULE_CTX_FLAGS_MAXMEMORY (1<<8) +/* Maxmemory is set and has an eviction policy that may delete keys */ +#define REDISMODULE_CTX_FLAGS_EVICT (1<<9) +/* Redis is out of memory according to the maxmemory flag. */ +#define REDISMODULE_CTX_FLAGS_OOM (1<<10) +/* Less than 25% of memory available according to maxmemory. */ +#define REDISMODULE_CTX_FLAGS_OOM_WARNING (1<<11) +/* The command was sent over the replication link. */ +#define REDISMODULE_CTX_FLAGS_REPLICATED (1<<12) +/* Redis is currently loading either from AOF or RDB. */ +#define REDISMODULE_CTX_FLAGS_LOADING (1<<13) +/* The replica has no link with its master, note that + * there is the inverse flag as well: + * + * REDISMODULE_CTX_FLAGS_REPLICA_IS_ONLINE + * + * The two flags are exclusive, one or the other can be set. */ +#define REDISMODULE_CTX_FLAGS_REPLICA_IS_STALE (1<<14) +/* The replica is trying to connect with the master. + * (REPL_STATE_CONNECT and REPL_STATE_CONNECTING states) */ +#define REDISMODULE_CTX_FLAGS_REPLICA_IS_CONNECTING (1<<15) +/* THe replica is receiving an RDB file from its master. */ +#define REDISMODULE_CTX_FLAGS_REPLICA_IS_TRANSFERRING (1<<16) +/* The replica is online, receiving updates from its master. */ +#define REDISMODULE_CTX_FLAGS_REPLICA_IS_ONLINE (1<<17) +/* There is currently some background process active. */ +#define REDISMODULE_CTX_FLAGS_ACTIVE_CHILD (1<<18) +/* The next EXEC will fail due to dirty CAS (touched keys). */ +#define REDISMODULE_CTX_FLAGS_MULTI_DIRTY (1<<19) +/* Redis is currently running inside background child process. */ +#define REDISMODULE_CTX_FLAGS_IS_CHILD (1<<20) +/* The current client does not allow blocking, either called from + * within multi, lua, or from another module using RM_Call */ +#define REDISMODULE_CTX_FLAGS_DENY_BLOCKING (1<<21) +/* The current client uses RESP3 protocol */ +#define REDISMODULE_CTX_FLAGS_RESP3 (1<<22) +/* Redis is currently async loading database for diskless replication. */ +#define REDISMODULE_CTX_FLAGS_ASYNC_LOADING (1<<23) +/* Redis is starting. */ +#define REDISMODULE_CTX_FLAGS_SERVER_STARTUP (1<<24) + +/* Next context flag, must be updated when adding new flags above! +This flag should not be used directly by the module. + * Use RedisModule_GetContextFlagsAll instead. */ +#define _REDISMODULE_CTX_FLAGS_NEXT (1<<25) + +/* Keyspace changes notification classes. Every class is associated with a + * character for configuration purposes. + * NOTE: These have to be in sync with NOTIFY_* in server.h */ +#define REDISMODULE_NOTIFY_KEYSPACE (1<<0) /* K */ +#define REDISMODULE_NOTIFY_KEYEVENT (1<<1) /* E */ +#define REDISMODULE_NOTIFY_GENERIC (1<<2) /* g */ +#define REDISMODULE_NOTIFY_STRING (1<<3) /* $ */ +#define REDISMODULE_NOTIFY_LIST (1<<4) /* l */ +#define REDISMODULE_NOTIFY_SET (1<<5) /* s */ +#define REDISMODULE_NOTIFY_HASH (1<<6) /* h */ +#define REDISMODULE_NOTIFY_ZSET (1<<7) /* z */ +#define REDISMODULE_NOTIFY_EXPIRED (1<<8) /* x */ +#define REDISMODULE_NOTIFY_EVICTED (1<<9) /* e */ +#define REDISMODULE_NOTIFY_STREAM (1<<10) /* t */ +#define REDISMODULE_NOTIFY_KEY_MISS (1<<11) /* m (Note: This one is excluded from REDISMODULE_NOTIFY_ALL on purpose) */ +#define REDISMODULE_NOTIFY_LOADED (1<<12) /* module only key space notification, indicate a key loaded from rdb */ +#define REDISMODULE_NOTIFY_MODULE (1<<13) /* d, module key space notification */ +#define REDISMODULE_NOTIFY_NEW (1<<14) /* n, new key notification */ + +/* Next notification flag, must be updated when adding new flags above! +This flag should not be used directly by the module. + * Use RedisModule_GetKeyspaceNotificationFlagsAll instead. */ +#define _REDISMODULE_NOTIFY_NEXT (1<<15) + +#define REDISMODULE_NOTIFY_ALL (REDISMODULE_NOTIFY_GENERIC | REDISMODULE_NOTIFY_STRING | REDISMODULE_NOTIFY_LIST | REDISMODULE_NOTIFY_SET | REDISMODULE_NOTIFY_HASH | REDISMODULE_NOTIFY_ZSET | REDISMODULE_NOTIFY_EXPIRED | REDISMODULE_NOTIFY_EVICTED | REDISMODULE_NOTIFY_STREAM | REDISMODULE_NOTIFY_MODULE) /* A */ + +/* A special pointer that we can use between the core and the module to signal + * field deletion, and that is impossible to be a valid pointer. */ +#define REDISMODULE_HASH_DELETE ((RedisModuleString*)(long)1) + +/* Error messages. */ +#define REDISMODULE_ERRORMSG_WRONGTYPE "WRONGTYPE Operation against a key holding the wrong kind of value" + +#define REDISMODULE_POSITIVE_INFINITE (1.0/0.0) +#define REDISMODULE_NEGATIVE_INFINITE (-1.0/0.0) + +/* Cluster API defines. */ +#define REDISMODULE_NODE_ID_LEN 40 +#define REDISMODULE_NODE_MYSELF (1<<0) +#define REDISMODULE_NODE_MASTER (1<<1) +#define REDISMODULE_NODE_SLAVE (1<<2) +#define REDISMODULE_NODE_PFAIL (1<<3) +#define REDISMODULE_NODE_FAIL (1<<4) +#define REDISMODULE_NODE_NOFAILOVER (1<<5) + +#define REDISMODULE_CLUSTER_FLAG_NONE 0 +#define REDISMODULE_CLUSTER_FLAG_NO_FAILOVER (1<<1) +#define REDISMODULE_CLUSTER_FLAG_NO_REDIRECTION (1<<2) + +#define REDISMODULE_NOT_USED(V) ((void) V) + +/* Logging level strings */ +#define REDISMODULE_LOGLEVEL_DEBUG "debug" +#define REDISMODULE_LOGLEVEL_VERBOSE "verbose" +#define REDISMODULE_LOGLEVEL_NOTICE "notice" +#define REDISMODULE_LOGLEVEL_WARNING "warning" + +/* Bit flags for aux_save_triggers and the aux_load and aux_save callbacks */ +#define REDISMODULE_AUX_BEFORE_RDB (1<<0) +#define REDISMODULE_AUX_AFTER_RDB (1<<1) + +/* RM_Yield flags */ +#define REDISMODULE_YIELD_FLAG_NONE (1<<0) +#define REDISMODULE_YIELD_FLAG_CLIENTS (1<<1) + +/* RM_BlockClientOnKeysWithFlags flags */ +#define REDISMODULE_BLOCK_UNBLOCK_DEFAULT (0) +#define REDISMODULE_BLOCK_UNBLOCK_DELETED (1<<0) + +/* This type represents a timer handle, and is returned when a timer is + * registered and used in order to invalidate a timer. It's just a 64 bit + * number, because this is how each timer is represented inside the radix tree + * of timers that are going to expire, sorted by expire time. */ +typedef uint64_t RedisModuleTimerID; + +/* CommandFilter Flags */ + +/* Do filter RedisModule_Call() commands initiated by module itself. */ +#define REDISMODULE_CMDFILTER_NOSELF (1<<0) + +/* Declare that the module can handle errors with RedisModule_SetModuleOptions. */ +#define REDISMODULE_OPTIONS_HANDLE_IO_ERRORS (1<<0) + +/* When set, Redis will not call RedisModule_SignalModifiedKey(), implicitly in + * RedisModule_CloseKey, and the module needs to do that when manually when keys + * are modified from the user's perspective, to invalidate WATCH. */ +#define REDISMODULE_OPTION_NO_IMPLICIT_SIGNAL_MODIFIED (1<<1) + +/* Declare that the module can handle diskless async replication with RedisModule_SetModuleOptions. */ +#define REDISMODULE_OPTIONS_HANDLE_REPL_ASYNC_LOAD (1<<2) + +/* Declare that the module want to get nested key space notifications. + * If enabled, the module is responsible to break endless loop. */ +#define REDISMODULE_OPTIONS_ALLOW_NESTED_KEYSPACE_NOTIFICATIONS (1<<3) + +/* Next option flag, must be updated when adding new module flags above! + * This flag should not be used directly by the module. + * Use RedisModule_GetModuleOptionsAll instead. */ +#define _REDISMODULE_OPTIONS_FLAGS_NEXT (1<<4) + +/* Definitions for RedisModule_SetCommandInfo. */ + +typedef enum { + REDISMODULE_ARG_TYPE_STRING, + REDISMODULE_ARG_TYPE_INTEGER, + REDISMODULE_ARG_TYPE_DOUBLE, + REDISMODULE_ARG_TYPE_KEY, /* A string, but represents a keyname */ + REDISMODULE_ARG_TYPE_PATTERN, + REDISMODULE_ARG_TYPE_UNIX_TIME, + REDISMODULE_ARG_TYPE_PURE_TOKEN, + REDISMODULE_ARG_TYPE_ONEOF, /* Must have sub-arguments */ + REDISMODULE_ARG_TYPE_BLOCK /* Must have sub-arguments */ +} RedisModuleCommandArgType; + +#define REDISMODULE_CMD_ARG_NONE (0) +#define REDISMODULE_CMD_ARG_OPTIONAL (1<<0) /* The argument is optional (like GET in SET command) */ +#define REDISMODULE_CMD_ARG_MULTIPLE (1<<1) /* The argument may repeat itself (like key in DEL) */ +#define REDISMODULE_CMD_ARG_MULTIPLE_TOKEN (1<<2) /* The argument may repeat itself, and so does its token (like `GET pattern` in SORT) */ +#define _REDISMODULE_CMD_ARG_NEXT (1<<3) + +typedef enum { + REDISMODULE_KSPEC_BS_INVALID = 0, /* Must be zero. An implicitly value of + * zero is provided when the field is + * absent in a struct literal. */ + REDISMODULE_KSPEC_BS_UNKNOWN, + REDISMODULE_KSPEC_BS_INDEX, + REDISMODULE_KSPEC_BS_KEYWORD +} RedisModuleKeySpecBeginSearchType; + +typedef enum { + REDISMODULE_KSPEC_FK_OMITTED = 0, /* Used when the field is absent in a + * struct literal. Don't use this value + * explicitly. */ + REDISMODULE_KSPEC_FK_UNKNOWN, + REDISMODULE_KSPEC_FK_RANGE, + REDISMODULE_KSPEC_FK_KEYNUM +} RedisModuleKeySpecFindKeysType; + +/* Key-spec flags. For details, see the documentation of + * RedisModule_SetCommandInfo and the key-spec flags in server.h. */ +#define REDISMODULE_CMD_KEY_RO (1ULL<<0) +#define REDISMODULE_CMD_KEY_RW (1ULL<<1) +#define REDISMODULE_CMD_KEY_OW (1ULL<<2) +#define REDISMODULE_CMD_KEY_RM (1ULL<<3) +#define REDISMODULE_CMD_KEY_ACCESS (1ULL<<4) +#define REDISMODULE_CMD_KEY_UPDATE (1ULL<<5) +#define REDISMODULE_CMD_KEY_INSERT (1ULL<<6) +#define REDISMODULE_CMD_KEY_DELETE (1ULL<<7) +#define REDISMODULE_CMD_KEY_NOT_KEY (1ULL<<8) +#define REDISMODULE_CMD_KEY_INCOMPLETE (1ULL<<9) +#define REDISMODULE_CMD_KEY_VARIABLE_FLAGS (1ULL<<10) + +/* Channel flags, for details see the documentation of + * RedisModule_ChannelAtPosWithFlags. */ +#define REDISMODULE_CMD_CHANNEL_PATTERN (1ULL<<0) +#define REDISMODULE_CMD_CHANNEL_PUBLISH (1ULL<<1) +#define REDISMODULE_CMD_CHANNEL_SUBSCRIBE (1ULL<<2) +#define REDISMODULE_CMD_CHANNEL_UNSUBSCRIBE (1ULL<<3) + +typedef struct RedisModuleCommandArg { + const char *name; + RedisModuleCommandArgType type; + int key_spec_index; /* If type is KEY, this is a zero-based index of + * the key_spec in the command. For other types, + * you may specify -1. */ + const char *token; /* If type is PURE_TOKEN, this is the token. */ + const char *summary; + const char *since; + int flags; /* The REDISMODULE_CMD_ARG_* macros. */ + const char *deprecated_since; + struct RedisModuleCommandArg *subargs; + const char *display_text; +} RedisModuleCommandArg; + +typedef struct { + const char *since; + const char *changes; +} RedisModuleCommandHistoryEntry; + +typedef struct { + const char *notes; + uint64_t flags; /* REDISMODULE_CMD_KEY_* macros. */ + RedisModuleKeySpecBeginSearchType begin_search_type; + union { + struct { + /* The index from which we start the search for keys */ + int pos; + } index; + struct { + /* The keyword that indicates the beginning of key args */ + const char *keyword; + /* An index in argv from which to start searching. + * Can be negative, which means start search from the end, in reverse + * (Example: -2 means to start in reverse from the penultimate arg) */ + int startfrom; + } keyword; + } bs; + RedisModuleKeySpecFindKeysType find_keys_type; + union { + struct { + /* Index of the last key relative to the result of the begin search + * step. Can be negative, in which case it's not relative. -1 + * indicating till the last argument, -2 one before the last and so + * on. */ + int lastkey; + /* How many args should we skip after finding a key, in order to + * find the next one. */ + int keystep; + /* If lastkey is -1, we use limit to stop the search by a factor. 0 + * and 1 mean no limit. 2 means 1/2 of the remaining args, 3 means + * 1/3, and so on. */ + int limit; + } range; + struct { + /* Index of the argument containing the number of keys to come + * relative to the result of the begin search step */ + int keynumidx; + /* Index of the fist key. (Usually it's just after keynumidx, in + * which case it should be set to keynumidx + 1.) */ + int firstkey; + /* How many args should we skip after finding a key, in order to + * find the next one, relative to the result of the begin search + * step. */ + int keystep; + } keynum; + } fk; +} RedisModuleCommandKeySpec; + +typedef struct { + int version; + size_t sizeof_historyentry; + size_t sizeof_keyspec; + size_t sizeof_arg; +} RedisModuleCommandInfoVersion; + +static const RedisModuleCommandInfoVersion RedisModule_CurrentCommandInfoVersion = { + .version = 1, + .sizeof_historyentry = sizeof(RedisModuleCommandHistoryEntry), + .sizeof_keyspec = sizeof(RedisModuleCommandKeySpec), + .sizeof_arg = sizeof(RedisModuleCommandArg) +}; + +#define REDISMODULE_COMMAND_INFO_VERSION (&RedisModule_CurrentCommandInfoVersion) + +typedef struct { + /* Always set version to REDISMODULE_COMMAND_INFO_VERSION */ + const RedisModuleCommandInfoVersion *version; + /* Version 1 fields (added in Redis 7.0.0) */ + const char *summary; /* Summary of the command */ + const char *complexity; /* Complexity description */ + const char *since; /* Debut module version of the command */ + RedisModuleCommandHistoryEntry *history; /* History */ + /* A string of space-separated tips meant for clients/proxies regarding this + * command */ + const char *tips; + /* Number of arguments, it is possible to use -N to say >= N */ + int arity; + RedisModuleCommandKeySpec *key_specs; + RedisModuleCommandArg *args; +} RedisModuleCommandInfo; + +/* Eventloop definitions. */ +#define REDISMODULE_EVENTLOOP_READABLE 1 +#define REDISMODULE_EVENTLOOP_WRITABLE 2 +typedef void (*RedisModuleEventLoopFunc)(int fd, void *user_data, int mask); +typedef void (*RedisModuleEventLoopOneShotFunc)(void *user_data); + +/* Server events definitions. + * Those flags should not be used directly by the module, instead + * the module should use RedisModuleEvent_* variables. + * Note: This must be synced with moduleEventVersions */ +#define REDISMODULE_EVENT_REPLICATION_ROLE_CHANGED 0 +#define REDISMODULE_EVENT_PERSISTENCE 1 +#define REDISMODULE_EVENT_FLUSHDB 2 +#define REDISMODULE_EVENT_LOADING 3 +#define REDISMODULE_EVENT_CLIENT_CHANGE 4 +#define REDISMODULE_EVENT_SHUTDOWN 5 +#define REDISMODULE_EVENT_REPLICA_CHANGE 6 +#define REDISMODULE_EVENT_MASTER_LINK_CHANGE 7 +#define REDISMODULE_EVENT_CRON_LOOP 8 +#define REDISMODULE_EVENT_MODULE_CHANGE 9 +#define REDISMODULE_EVENT_LOADING_PROGRESS 10 +#define REDISMODULE_EVENT_SWAPDB 11 +#define REDISMODULE_EVENT_REPL_BACKUP 12 /* Deprecated since Redis 7.0, not used anymore. */ +#define REDISMODULE_EVENT_FORK_CHILD 13 +#define REDISMODULE_EVENT_REPL_ASYNC_LOAD 14 +#define REDISMODULE_EVENT_EVENTLOOP 15 +#define REDISMODULE_EVENT_CONFIG 16 +#define REDISMODULE_EVENT_KEY 17 +#define _REDISMODULE_EVENT_NEXT 18 /* Next event flag, should be updated if a new event added. */ + +typedef struct RedisModuleEvent { + uint64_t id; /* REDISMODULE_EVENT_... defines. */ + uint64_t dataver; /* Version of the structure we pass as 'data'. */ +} RedisModuleEvent; + +struct RedisModuleCtx; +struct RedisModuleDefragCtx; +typedef void (*RedisModuleEventCallback)(struct RedisModuleCtx *ctx, RedisModuleEvent eid, uint64_t subevent, void *data); + +/* IMPORTANT: When adding a new version of one of below structures that contain + * event data (RedisModuleFlushInfoV1 for example) we have to avoid renaming the + * old RedisModuleEvent structure. + * For example, if we want to add RedisModuleFlushInfoV2, the RedisModuleEvent + * structures should be: + * RedisModuleEvent_FlushDB = { + * REDISMODULE_EVENT_FLUSHDB, + * 1 + * }, + * RedisModuleEvent_FlushDBV2 = { + * REDISMODULE_EVENT_FLUSHDB, + * 2 + * } + * and NOT: + * RedisModuleEvent_FlushDBV1 = { + * REDISMODULE_EVENT_FLUSHDB, + * 1 + * }, + * RedisModuleEvent_FlushDB = { + * REDISMODULE_EVENT_FLUSHDB, + * 2 + * } + * The reason for that is forward-compatibility: We want that module that + * compiled with a new redismodule.h to be able to work with a old server, + * unless the author explicitly decided to use the newer event type. + */ +static const RedisModuleEvent + RedisModuleEvent_ReplicationRoleChanged = { + REDISMODULE_EVENT_REPLICATION_ROLE_CHANGED, + 1 + }, + RedisModuleEvent_Persistence = { + REDISMODULE_EVENT_PERSISTENCE, + 1 + }, + RedisModuleEvent_FlushDB = { + REDISMODULE_EVENT_FLUSHDB, + 1 + }, + RedisModuleEvent_Loading = { + REDISMODULE_EVENT_LOADING, + 1 + }, + RedisModuleEvent_ClientChange = { + REDISMODULE_EVENT_CLIENT_CHANGE, + 1 + }, + RedisModuleEvent_Shutdown = { + REDISMODULE_EVENT_SHUTDOWN, + 1 + }, + RedisModuleEvent_ReplicaChange = { + REDISMODULE_EVENT_REPLICA_CHANGE, + 1 + }, + RedisModuleEvent_CronLoop = { + REDISMODULE_EVENT_CRON_LOOP, + 1 + }, + RedisModuleEvent_MasterLinkChange = { + REDISMODULE_EVENT_MASTER_LINK_CHANGE, + 1 + }, + RedisModuleEvent_ModuleChange = { + REDISMODULE_EVENT_MODULE_CHANGE, + 1 + }, + RedisModuleEvent_LoadingProgress = { + REDISMODULE_EVENT_LOADING_PROGRESS, + 1 + }, + RedisModuleEvent_SwapDB = { + REDISMODULE_EVENT_SWAPDB, + 1 + }, + /* Deprecated since Redis 7.0, not used anymore. */ + __attribute__ ((deprecated)) + RedisModuleEvent_ReplBackup = { + REDISMODULE_EVENT_REPL_BACKUP, + 1 + }, + RedisModuleEvent_ReplAsyncLoad = { + REDISMODULE_EVENT_REPL_ASYNC_LOAD, + 1 + }, + RedisModuleEvent_ForkChild = { + REDISMODULE_EVENT_FORK_CHILD, + 1 + }, + RedisModuleEvent_EventLoop = { + REDISMODULE_EVENT_EVENTLOOP, + 1 + }, + RedisModuleEvent_Config = { + REDISMODULE_EVENT_CONFIG, + 1 + }, + RedisModuleEvent_Key = { + REDISMODULE_EVENT_KEY, + 1 + }; + +/* Those are values that are used for the 'subevent' callback argument. */ +#define REDISMODULE_SUBEVENT_PERSISTENCE_RDB_START 0 +#define REDISMODULE_SUBEVENT_PERSISTENCE_AOF_START 1 +#define REDISMODULE_SUBEVENT_PERSISTENCE_SYNC_RDB_START 2 +#define REDISMODULE_SUBEVENT_PERSISTENCE_ENDED 3 +#define REDISMODULE_SUBEVENT_PERSISTENCE_FAILED 4 +#define REDISMODULE_SUBEVENT_PERSISTENCE_SYNC_AOF_START 5 +#define _REDISMODULE_SUBEVENT_PERSISTENCE_NEXT 6 + +#define REDISMODULE_SUBEVENT_LOADING_RDB_START 0 +#define REDISMODULE_SUBEVENT_LOADING_AOF_START 1 +#define REDISMODULE_SUBEVENT_LOADING_REPL_START 2 +#define REDISMODULE_SUBEVENT_LOADING_ENDED 3 +#define REDISMODULE_SUBEVENT_LOADING_FAILED 4 +#define _REDISMODULE_SUBEVENT_LOADING_NEXT 5 + +#define REDISMODULE_SUBEVENT_CLIENT_CHANGE_CONNECTED 0 +#define REDISMODULE_SUBEVENT_CLIENT_CHANGE_DISCONNECTED 1 +#define _REDISMODULE_SUBEVENT_CLIENT_CHANGE_NEXT 2 + +#define REDISMODULE_SUBEVENT_MASTER_LINK_UP 0 +#define REDISMODULE_SUBEVENT_MASTER_LINK_DOWN 1 +#define _REDISMODULE_SUBEVENT_MASTER_NEXT 2 + +#define REDISMODULE_SUBEVENT_REPLICA_CHANGE_ONLINE 0 +#define REDISMODULE_SUBEVENT_REPLICA_CHANGE_OFFLINE 1 +#define _REDISMODULE_SUBEVENT_REPLICA_CHANGE_NEXT 2 + +#define REDISMODULE_EVENT_REPLROLECHANGED_NOW_MASTER 0 +#define REDISMODULE_EVENT_REPLROLECHANGED_NOW_REPLICA 1 +#define _REDISMODULE_EVENT_REPLROLECHANGED_NEXT 2 + +#define REDISMODULE_SUBEVENT_FLUSHDB_START 0 +#define REDISMODULE_SUBEVENT_FLUSHDB_END 1 +#define _REDISMODULE_SUBEVENT_FLUSHDB_NEXT 2 + +#define REDISMODULE_SUBEVENT_MODULE_LOADED 0 +#define REDISMODULE_SUBEVENT_MODULE_UNLOADED 1 +#define _REDISMODULE_SUBEVENT_MODULE_NEXT 2 + +#define REDISMODULE_SUBEVENT_CONFIG_CHANGE 0 +#define _REDISMODULE_SUBEVENT_CONFIG_NEXT 1 + +#define REDISMODULE_SUBEVENT_LOADING_PROGRESS_RDB 0 +#define REDISMODULE_SUBEVENT_LOADING_PROGRESS_AOF 1 +#define _REDISMODULE_SUBEVENT_LOADING_PROGRESS_NEXT 2 + +/* Replication Backup events are deprecated since Redis 7.0 and are never fired. */ +#define REDISMODULE_SUBEVENT_REPL_BACKUP_CREATE 0 +#define REDISMODULE_SUBEVENT_REPL_BACKUP_RESTORE 1 +#define REDISMODULE_SUBEVENT_REPL_BACKUP_DISCARD 2 +#define _REDISMODULE_SUBEVENT_REPL_BACKUP_NEXT 3 + +#define REDISMODULE_SUBEVENT_REPL_ASYNC_LOAD_STARTED 0 +#define REDISMODULE_SUBEVENT_REPL_ASYNC_LOAD_ABORTED 1 +#define REDISMODULE_SUBEVENT_REPL_ASYNC_LOAD_COMPLETED 2 +#define _REDISMODULE_SUBEVENT_REPL_ASYNC_LOAD_NEXT 3 + +#define REDISMODULE_SUBEVENT_FORK_CHILD_BORN 0 +#define REDISMODULE_SUBEVENT_FORK_CHILD_DIED 1 +#define _REDISMODULE_SUBEVENT_FORK_CHILD_NEXT 2 + +#define REDISMODULE_SUBEVENT_EVENTLOOP_BEFORE_SLEEP 0 +#define REDISMODULE_SUBEVENT_EVENTLOOP_AFTER_SLEEP 1 +#define _REDISMODULE_SUBEVENT_EVENTLOOP_NEXT 2 + +#define REDISMODULE_SUBEVENT_KEY_DELETED 0 +#define REDISMODULE_SUBEVENT_KEY_EXPIRED 1 +#define REDISMODULE_SUBEVENT_KEY_EVICTED 2 +#define REDISMODULE_SUBEVENT_KEY_OVERWRITTEN 3 +#define _REDISMODULE_SUBEVENT_KEY_NEXT 4 + +#define _REDISMODULE_SUBEVENT_SHUTDOWN_NEXT 0 +#define _REDISMODULE_SUBEVENT_CRON_LOOP_NEXT 0 +#define _REDISMODULE_SUBEVENT_SWAPDB_NEXT 0 + +/* RedisModuleClientInfo flags. */ +#define REDISMODULE_CLIENTINFO_FLAG_SSL (1<<0) +#define REDISMODULE_CLIENTINFO_FLAG_PUBSUB (1<<1) +#define REDISMODULE_CLIENTINFO_FLAG_BLOCKED (1<<2) +#define REDISMODULE_CLIENTINFO_FLAG_TRACKING (1<<3) +#define REDISMODULE_CLIENTINFO_FLAG_UNIXSOCKET (1<<4) +#define REDISMODULE_CLIENTINFO_FLAG_MULTI (1<<5) + +/* Here we take all the structures that the module pass to the core + * and the other way around. Notably the list here contains the structures + * used by the hooks API RedisModule_RegisterToServerEvent(). + * + * The structures always start with a 'version' field. This is useful + * when we want to pass a reference to the structure to the core APIs, + * for the APIs to fill the structure. In that case, the structure 'version' + * field is initialized before passing it to the core, so that the core is + * able to cast the pointer to the appropriate structure version. In this + * way we obtain ABI compatibility. + * + * Here we'll list all the structure versions in case they evolve over time, + * however using a define, we'll make sure to use the last version as the + * public name for the module to use. */ + +#define REDISMODULE_CLIENTINFO_VERSION 1 +typedef struct RedisModuleClientInfo { + uint64_t version; /* Version of this structure for ABI compat. */ + uint64_t flags; /* REDISMODULE_CLIENTINFO_FLAG_* */ + uint64_t id; /* Client ID. */ + char addr[46]; /* IPv4 or IPv6 address. */ + uint16_t port; /* TCP port. */ + uint16_t db; /* Selected DB. */ +} RedisModuleClientInfoV1; + +#define RedisModuleClientInfo RedisModuleClientInfoV1 + +#define REDISMODULE_CLIENTINFO_INITIALIZER_V1 { .version = 1 } + +#define REDISMODULE_REPLICATIONINFO_VERSION 1 +typedef struct RedisModuleReplicationInfo { + uint64_t version; /* Not used since this structure is never passed + from the module to the core right now. Here + for future compatibility. */ + int master; /* true if master, false if replica */ + char *masterhost; /* master instance hostname for NOW_REPLICA */ + int masterport; /* master instance port for NOW_REPLICA */ + char *replid1; /* Main replication ID */ + char *replid2; /* Secondary replication ID */ + uint64_t repl1_offset; /* Main replication offset */ + uint64_t repl2_offset; /* Offset of replid2 validity */ +} RedisModuleReplicationInfoV1; + +#define RedisModuleReplicationInfo RedisModuleReplicationInfoV1 + +#define REDISMODULE_FLUSHINFO_VERSION 1 +typedef struct RedisModuleFlushInfo { + uint64_t version; /* Not used since this structure is never passed + from the module to the core right now. Here + for future compatibility. */ + int32_t sync; /* Synchronous or threaded flush?. */ + int32_t dbnum; /* Flushed database number, -1 for ALL. */ +} RedisModuleFlushInfoV1; + +#define RedisModuleFlushInfo RedisModuleFlushInfoV1 + +#define REDISMODULE_MODULE_CHANGE_VERSION 1 +typedef struct RedisModuleModuleChange { + uint64_t version; /* Not used since this structure is never passed + from the module to the core right now. Here + for future compatibility. */ + const char* module_name;/* Name of module loaded or unloaded. */ + int32_t module_version; /* Module version. */ +} RedisModuleModuleChangeV1; + +#define RedisModuleModuleChange RedisModuleModuleChangeV1 + +#define REDISMODULE_CONFIGCHANGE_VERSION 1 +typedef struct RedisModuleConfigChange { + uint64_t version; /* Not used since this structure is never passed + from the module to the core right now. Here + for future compatibility. */ + uint32_t num_changes; /* how many redis config options were changed */ + const char **config_names; /* the config names that were changed */ +} RedisModuleConfigChangeV1; + +#define RedisModuleConfigChange RedisModuleConfigChangeV1 + +#define REDISMODULE_CRON_LOOP_VERSION 1 +typedef struct RedisModuleCronLoopInfo { + uint64_t version; /* Not used since this structure is never passed + from the module to the core right now. Here + for future compatibility. */ + int32_t hz; /* Approximate number of events per second. */ +} RedisModuleCronLoopV1; + +#define RedisModuleCronLoop RedisModuleCronLoopV1 + +#define REDISMODULE_LOADING_PROGRESS_VERSION 1 +typedef struct RedisModuleLoadingProgressInfo { + uint64_t version; /* Not used since this structure is never passed + from the module to the core right now. Here + for future compatibility. */ + int32_t hz; /* Approximate number of events per second. */ + int32_t progress; /* Approximate progress between 0 and 1024, or -1 + * if unknown. */ +} RedisModuleLoadingProgressV1; + +#define RedisModuleLoadingProgress RedisModuleLoadingProgressV1 + +#define REDISMODULE_SWAPDBINFO_VERSION 1 +typedef struct RedisModuleSwapDbInfo { + uint64_t version; /* Not used since this structure is never passed + from the module to the core right now. Here + for future compatibility. */ + int32_t dbnum_first; /* Swap Db first dbnum */ + int32_t dbnum_second; /* Swap Db second dbnum */ +} RedisModuleSwapDbInfoV1; + +#define RedisModuleSwapDbInfo RedisModuleSwapDbInfoV1 + +#define REDISMODULE_KEYINFO_VERSION 1 +typedef struct RedisModuleKeyInfo { + uint64_t version; /* Not used since this structure is never passed + from the module to the core right now. Here + for future compatibility. */ + RedisModuleKey *key; /* Opened key. */ +} RedisModuleKeyInfoV1; + +#define RedisModuleKeyInfo RedisModuleKeyInfoV1 + +typedef enum { + REDISMODULE_ACL_LOG_AUTH = 0, /* Authentication failure */ + REDISMODULE_ACL_LOG_CMD, /* Command authorization failure */ + REDISMODULE_ACL_LOG_KEY, /* Key authorization failure */ + REDISMODULE_ACL_LOG_CHANNEL /* Channel authorization failure */ +} RedisModuleACLLogEntryReason; + +/* Incomplete structures needed by both the core and modules. */ +typedef struct RedisModuleIO RedisModuleIO; +typedef struct RedisModuleDigest RedisModuleDigest; +typedef struct RedisModuleInfoCtx RedisModuleInfoCtx; +typedef struct RedisModuleDefragCtx RedisModuleDefragCtx; + +/* Function pointers needed by both the core and modules, these needs to be + * exposed since you can't cast a function pointer to (void *). */ +typedef void (*RedisModuleInfoFunc)(RedisModuleInfoCtx *ctx, int for_crash_report); +typedef void (*RedisModuleDefragFunc)(RedisModuleDefragCtx *ctx); +typedef void (*RedisModuleUserChangedFunc) (uint64_t client_id, void *privdata); + +/* ------------------------- End of common defines ------------------------ */ + +/* ----------- The rest of the defines are only for modules ----------------- */ +#if !defined REDISMODULE_CORE || defined REDISMODULE_CORE_MODULE +/* Things defined for modules and core-modules. */ + +/* Macro definitions specific to individual compilers */ +#ifndef REDISMODULE_ATTR_UNUSED +# ifdef __GNUC__ +# define REDISMODULE_ATTR_UNUSED __attribute__((unused)) +# else +# define REDISMODULE_ATTR_UNUSED +# endif +#endif + +#ifndef REDISMODULE_ATTR_PRINTF +# ifdef __GNUC__ +# define REDISMODULE_ATTR_PRINTF(idx,cnt) __attribute__((format(printf,idx,cnt))) +# else +# define REDISMODULE_ATTR_PRINTF(idx,cnt) +# endif +#endif + +#ifndef REDISMODULE_ATTR_COMMON +# if defined(__GNUC__) && !(defined(__clang__) && defined(__cplusplus)) +# define REDISMODULE_ATTR_COMMON __attribute__((__common__)) +# else +# define REDISMODULE_ATTR_COMMON +# endif +#endif + +/* Incomplete structures for compiler checks but opaque access. */ +typedef struct RedisModuleCtx RedisModuleCtx; +typedef struct RedisModuleCommand RedisModuleCommand; +typedef struct RedisModuleCallReply RedisModuleCallReply; +typedef struct RedisModuleType RedisModuleType; +typedef struct RedisModuleBlockedClient RedisModuleBlockedClient; +typedef struct RedisModuleClusterInfo RedisModuleClusterInfo; +typedef struct RedisModuleDict RedisModuleDict; +typedef struct RedisModuleDictIter RedisModuleDictIter; +typedef struct RedisModuleCommandFilterCtx RedisModuleCommandFilterCtx; +typedef struct RedisModuleCommandFilter RedisModuleCommandFilter; +typedef struct RedisModuleServerInfoData RedisModuleServerInfoData; +typedef struct RedisModuleScanCursor RedisModuleScanCursor; +typedef struct RedisModuleUser RedisModuleUser; +typedef struct RedisModuleKeyOptCtx RedisModuleKeyOptCtx; +typedef struct RedisModuleRdbStream RedisModuleRdbStream; + +typedef int (*RedisModuleCmdFunc)(RedisModuleCtx *ctx, RedisModuleString **argv, int argc); +typedef void (*RedisModuleDisconnectFunc)(RedisModuleCtx *ctx, RedisModuleBlockedClient *bc); +typedef int (*RedisModuleNotificationFunc)(RedisModuleCtx *ctx, int type, const char *event, RedisModuleString *key); +typedef void (*RedisModulePostNotificationJobFunc) (RedisModuleCtx *ctx, void *pd); +typedef void *(*RedisModuleTypeLoadFunc)(RedisModuleIO *rdb, int encver); +typedef void (*RedisModuleTypeSaveFunc)(RedisModuleIO *rdb, void *value); +typedef int (*RedisModuleTypeAuxLoadFunc)(RedisModuleIO *rdb, int encver, int when); +typedef void (*RedisModuleTypeAuxSaveFunc)(RedisModuleIO *rdb, int when); +typedef void (*RedisModuleTypeRewriteFunc)(RedisModuleIO *aof, RedisModuleString *key, void *value); +typedef size_t (*RedisModuleTypeMemUsageFunc)(const void *value); +typedef size_t (*RedisModuleTypeMemUsageFunc2)(RedisModuleKeyOptCtx *ctx, const void *value, size_t sample_size); +typedef void (*RedisModuleTypeDigestFunc)(RedisModuleDigest *digest, void *value); +typedef void (*RedisModuleTypeFreeFunc)(void *value); +typedef size_t (*RedisModuleTypeFreeEffortFunc)(RedisModuleString *key, const void *value); +typedef size_t (*RedisModuleTypeFreeEffortFunc2)(RedisModuleKeyOptCtx *ctx, const void *value); +typedef void (*RedisModuleTypeUnlinkFunc)(RedisModuleString *key, const void *value); +typedef void (*RedisModuleTypeUnlinkFunc2)(RedisModuleKeyOptCtx *ctx, const void *value); +typedef void *(*RedisModuleTypeCopyFunc)(RedisModuleString *fromkey, RedisModuleString *tokey, const void *value); +typedef void *(*RedisModuleTypeCopyFunc2)(RedisModuleKeyOptCtx *ctx, const void *value); +typedef int (*RedisModuleTypeDefragFunc)(RedisModuleDefragCtx *ctx, RedisModuleString *key, void **value); +typedef void (*RedisModuleClusterMessageReceiver)(RedisModuleCtx *ctx, const char *sender_id, uint8_t type, const unsigned char *payload, uint32_t len); +typedef void (*RedisModuleTimerProc)(RedisModuleCtx *ctx, void *data); +typedef void (*RedisModuleCommandFilterFunc) (RedisModuleCommandFilterCtx *filter); +typedef void (*RedisModuleForkDoneHandler) (int exitcode, int bysignal, void *user_data); +typedef void (*RedisModuleScanCB)(RedisModuleCtx *ctx, RedisModuleString *keyname, RedisModuleKey *key, void *privdata); +typedef void (*RedisModuleScanKeyCB)(RedisModuleKey *key, RedisModuleString *field, RedisModuleString *value, void *privdata); +typedef RedisModuleString * (*RedisModuleConfigGetStringFunc)(const char *name, void *privdata); +typedef long long (*RedisModuleConfigGetNumericFunc)(const char *name, void *privdata); +typedef int (*RedisModuleConfigGetBoolFunc)(const char *name, void *privdata); +typedef int (*RedisModuleConfigGetEnumFunc)(const char *name, void *privdata); +typedef int (*RedisModuleConfigSetStringFunc)(const char *name, RedisModuleString *val, void *privdata, RedisModuleString **err); +typedef int (*RedisModuleConfigSetNumericFunc)(const char *name, long long val, void *privdata, RedisModuleString **err); +typedef int (*RedisModuleConfigSetBoolFunc)(const char *name, int val, void *privdata, RedisModuleString **err); +typedef int (*RedisModuleConfigSetEnumFunc)(const char *name, int val, void *privdata, RedisModuleString **err); +typedef int (*RedisModuleConfigApplyFunc)(RedisModuleCtx *ctx, void *privdata, RedisModuleString **err); +typedef void (*RedisModuleOnUnblocked)(RedisModuleCtx *ctx, RedisModuleCallReply *reply, void *private_data); +typedef int (*RedisModuleAuthCallback)(RedisModuleCtx *ctx, RedisModuleString *username, RedisModuleString *password, RedisModuleString **err); + +typedef struct RedisModuleTypeMethods { + uint64_t version; + RedisModuleTypeLoadFunc rdb_load; + RedisModuleTypeSaveFunc rdb_save; + RedisModuleTypeRewriteFunc aof_rewrite; + RedisModuleTypeMemUsageFunc mem_usage; + RedisModuleTypeDigestFunc digest; + RedisModuleTypeFreeFunc free; + RedisModuleTypeAuxLoadFunc aux_load; + RedisModuleTypeAuxSaveFunc aux_save; + int aux_save_triggers; + RedisModuleTypeFreeEffortFunc free_effort; + RedisModuleTypeUnlinkFunc unlink; + RedisModuleTypeCopyFunc copy; + RedisModuleTypeDefragFunc defrag; + RedisModuleTypeMemUsageFunc2 mem_usage2; + RedisModuleTypeFreeEffortFunc2 free_effort2; + RedisModuleTypeUnlinkFunc2 unlink2; + RedisModuleTypeCopyFunc2 copy2; + RedisModuleTypeAuxSaveFunc aux_save2; +} RedisModuleTypeMethods; + +#define REDISMODULE_GET_API(name) \ + RedisModule_GetApi("RedisModule_" #name, ((void **)&RedisModule_ ## name)) + +/* Default API declaration prefix (not 'extern' for backwards compatibility) */ +#ifndef REDISMODULE_API +#define REDISMODULE_API +#endif + +/* Default API declaration suffix (compiler attributes) */ +#ifndef REDISMODULE_ATTR +#define REDISMODULE_ATTR REDISMODULE_ATTR_COMMON +#endif + +REDISMODULE_API void * (*RedisModule_Alloc)(size_t bytes) REDISMODULE_ATTR; +REDISMODULE_API void * (*RedisModule_TryAlloc)(size_t bytes) REDISMODULE_ATTR; +REDISMODULE_API void * (*RedisModule_Realloc)(void *ptr, size_t bytes) REDISMODULE_ATTR; +REDISMODULE_API void * (*RedisModule_TryRealloc)(void *ptr, size_t bytes) REDISMODULE_ATTR; +REDISMODULE_API void (*RedisModule_Free)(void *ptr) REDISMODULE_ATTR; +REDISMODULE_API void * (*RedisModule_Calloc)(size_t nmemb, size_t size) REDISMODULE_ATTR; +REDISMODULE_API void * (*RedisModule_TryCalloc)(size_t nmemb, size_t size) REDISMODULE_ATTR; +REDISMODULE_API char * (*RedisModule_Strdup)(const char *str) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_GetApi)(const char *, void *) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_CreateCommand)(RedisModuleCtx *ctx, const char *name, RedisModuleCmdFunc cmdfunc, const char *strflags, int firstkey, int lastkey, int keystep) REDISMODULE_ATTR; +REDISMODULE_API RedisModuleCommand *(*RedisModule_GetCommand)(RedisModuleCtx *ctx, const char *name) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_CreateSubcommand)(RedisModuleCommand *parent, const char *name, RedisModuleCmdFunc cmdfunc, const char *strflags, int firstkey, int lastkey, int keystep) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_SetCommandInfo)(RedisModuleCommand *command, const RedisModuleCommandInfo *info) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_SetCommandACLCategories)(RedisModuleCommand *command, const char *ctgrsflags) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_AddACLCategory)(RedisModuleCtx *ctx, const char *name) REDISMODULE_ATTR; +REDISMODULE_API void (*RedisModule_SetModuleAttribs)(RedisModuleCtx *ctx, const char *name, int ver, int apiver) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_IsModuleNameBusy)(const char *name) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_WrongArity)(RedisModuleCtx *ctx) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_ReplyWithLongLong)(RedisModuleCtx *ctx, long long ll) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_GetSelectedDb)(RedisModuleCtx *ctx) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_SelectDb)(RedisModuleCtx *ctx, int newid) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_KeyExists)(RedisModuleCtx *ctx, RedisModuleString *keyname) REDISMODULE_ATTR; +REDISMODULE_API RedisModuleKey * (*RedisModule_OpenKey)(RedisModuleCtx *ctx, RedisModuleString *keyname, int mode) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_GetOpenKeyModesAll)(void) REDISMODULE_ATTR; +REDISMODULE_API void (*RedisModule_CloseKey)(RedisModuleKey *kp) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_KeyType)(RedisModuleKey *kp) REDISMODULE_ATTR; +REDISMODULE_API size_t (*RedisModule_ValueLength)(RedisModuleKey *kp) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_ListPush)(RedisModuleKey *kp, int where, RedisModuleString *ele) REDISMODULE_ATTR; +REDISMODULE_API RedisModuleString * (*RedisModule_ListPop)(RedisModuleKey *key, int where) REDISMODULE_ATTR; +REDISMODULE_API RedisModuleString * (*RedisModule_ListGet)(RedisModuleKey *key, long index) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_ListSet)(RedisModuleKey *key, long index, RedisModuleString *value) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_ListInsert)(RedisModuleKey *key, long index, RedisModuleString *value) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_ListDelete)(RedisModuleKey *key, long index) REDISMODULE_ATTR; +REDISMODULE_API RedisModuleCallReply * (*RedisModule_Call)(RedisModuleCtx *ctx, const char *cmdname, const char *fmt, ...) REDISMODULE_ATTR; +REDISMODULE_API const char * (*RedisModule_CallReplyProto)(RedisModuleCallReply *reply, size_t *len) REDISMODULE_ATTR; +REDISMODULE_API void (*RedisModule_FreeCallReply)(RedisModuleCallReply *reply) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_CallReplyType)(RedisModuleCallReply *reply) REDISMODULE_ATTR; +REDISMODULE_API long long (*RedisModule_CallReplyInteger)(RedisModuleCallReply *reply) REDISMODULE_ATTR; +REDISMODULE_API double (*RedisModule_CallReplyDouble)(RedisModuleCallReply *reply) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_CallReplyBool)(RedisModuleCallReply *reply) REDISMODULE_ATTR; +REDISMODULE_API const char* (*RedisModule_CallReplyBigNumber)(RedisModuleCallReply *reply, size_t *len) REDISMODULE_ATTR; +REDISMODULE_API const char* (*RedisModule_CallReplyVerbatim)(RedisModuleCallReply *reply, size_t *len, const char **format) REDISMODULE_ATTR; +REDISMODULE_API RedisModuleCallReply * (*RedisModule_CallReplySetElement)(RedisModuleCallReply *reply, size_t idx) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_CallReplyMapElement)(RedisModuleCallReply *reply, size_t idx, RedisModuleCallReply **key, RedisModuleCallReply **val) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_CallReplyAttributeElement)(RedisModuleCallReply *reply, size_t idx, RedisModuleCallReply **key, RedisModuleCallReply **val) REDISMODULE_ATTR; +REDISMODULE_API void (*RedisModule_CallReplyPromiseSetUnblockHandler)(RedisModuleCallReply *reply, RedisModuleOnUnblocked on_unblock, void *private_data) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_CallReplyPromiseAbort)(RedisModuleCallReply *reply, void **private_data) REDISMODULE_ATTR; +REDISMODULE_API RedisModuleCallReply * (*RedisModule_CallReplyAttribute)(RedisModuleCallReply *reply) REDISMODULE_ATTR; +REDISMODULE_API size_t (*RedisModule_CallReplyLength)(RedisModuleCallReply *reply) REDISMODULE_ATTR; +REDISMODULE_API RedisModuleCallReply * (*RedisModule_CallReplyArrayElement)(RedisModuleCallReply *reply, size_t idx) REDISMODULE_ATTR; +REDISMODULE_API RedisModuleString * (*RedisModule_CreateString)(RedisModuleCtx *ctx, const char *ptr, size_t len) REDISMODULE_ATTR; +REDISMODULE_API RedisModuleString * (*RedisModule_CreateStringFromLongLong)(RedisModuleCtx *ctx, long long ll) REDISMODULE_ATTR; +REDISMODULE_API RedisModuleString * (*RedisModule_CreateStringFromULongLong)(RedisModuleCtx *ctx, unsigned long long ull) REDISMODULE_ATTR; +REDISMODULE_API RedisModuleString * (*RedisModule_CreateStringFromDouble)(RedisModuleCtx *ctx, double d) REDISMODULE_ATTR; +REDISMODULE_API RedisModuleString * (*RedisModule_CreateStringFromLongDouble)(RedisModuleCtx *ctx, long double ld, int humanfriendly) REDISMODULE_ATTR; +REDISMODULE_API RedisModuleString * (*RedisModule_CreateStringFromString)(RedisModuleCtx *ctx, const RedisModuleString *str) REDISMODULE_ATTR; +REDISMODULE_API RedisModuleString * (*RedisModule_CreateStringFromStreamID)(RedisModuleCtx *ctx, const RedisModuleStreamID *id) REDISMODULE_ATTR; +REDISMODULE_API RedisModuleString * (*RedisModule_CreateStringPrintf)(RedisModuleCtx *ctx, const char *fmt, ...) REDISMODULE_ATTR_PRINTF(2,3) REDISMODULE_ATTR; +REDISMODULE_API void (*RedisModule_FreeString)(RedisModuleCtx *ctx, RedisModuleString *str) REDISMODULE_ATTR; +REDISMODULE_API const char * (*RedisModule_StringPtrLen)(const RedisModuleString *str, size_t *len) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_ReplyWithError)(RedisModuleCtx *ctx, const char *err) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_ReplyWithErrorFormat)(RedisModuleCtx *ctx, const char *fmt, ...) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_ReplyWithSimpleString)(RedisModuleCtx *ctx, const char *msg) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_ReplyWithArray)(RedisModuleCtx *ctx, long len) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_ReplyWithMap)(RedisModuleCtx *ctx, long len) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_ReplyWithSet)(RedisModuleCtx *ctx, long len) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_ReplyWithAttribute)(RedisModuleCtx *ctx, long len) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_ReplyWithNullArray)(RedisModuleCtx *ctx) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_ReplyWithEmptyArray)(RedisModuleCtx *ctx) REDISMODULE_ATTR; +REDISMODULE_API void (*RedisModule_ReplySetArrayLength)(RedisModuleCtx *ctx, long len) REDISMODULE_ATTR; +REDISMODULE_API void (*RedisModule_ReplySetMapLength)(RedisModuleCtx *ctx, long len) REDISMODULE_ATTR; +REDISMODULE_API void (*RedisModule_ReplySetSetLength)(RedisModuleCtx *ctx, long len) REDISMODULE_ATTR; +REDISMODULE_API void (*RedisModule_ReplySetAttributeLength)(RedisModuleCtx *ctx, long len) REDISMODULE_ATTR; +REDISMODULE_API void (*RedisModule_ReplySetPushLength)(RedisModuleCtx *ctx, long len) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_ReplyWithStringBuffer)(RedisModuleCtx *ctx, const char *buf, size_t len) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_ReplyWithCString)(RedisModuleCtx *ctx, const char *buf) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_ReplyWithString)(RedisModuleCtx *ctx, RedisModuleString *str) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_ReplyWithEmptyString)(RedisModuleCtx *ctx) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_ReplyWithVerbatimString)(RedisModuleCtx *ctx, const char *buf, size_t len) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_ReplyWithVerbatimStringType)(RedisModuleCtx *ctx, const char *buf, size_t len, const char *ext) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_ReplyWithNull)(RedisModuleCtx *ctx) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_ReplyWithBool)(RedisModuleCtx *ctx, int b) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_ReplyWithLongDouble)(RedisModuleCtx *ctx, long double d) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_ReplyWithDouble)(RedisModuleCtx *ctx, double d) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_ReplyWithBigNumber)(RedisModuleCtx *ctx, const char *bignum, size_t len) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_ReplyWithCallReply)(RedisModuleCtx *ctx, RedisModuleCallReply *reply) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_StringToLongLong)(const RedisModuleString *str, long long *ll) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_StringToULongLong)(const RedisModuleString *str, unsigned long long *ull) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_StringToDouble)(const RedisModuleString *str, double *d) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_StringToLongDouble)(const RedisModuleString *str, long double *d) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_StringToStreamID)(const RedisModuleString *str, RedisModuleStreamID *id) REDISMODULE_ATTR; +REDISMODULE_API void (*RedisModule_AutoMemory)(RedisModuleCtx *ctx) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_Replicate)(RedisModuleCtx *ctx, const char *cmdname, const char *fmt, ...) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_ReplicateVerbatim)(RedisModuleCtx *ctx) REDISMODULE_ATTR; +REDISMODULE_API const char * (*RedisModule_CallReplyStringPtr)(RedisModuleCallReply *reply, size_t *len) REDISMODULE_ATTR; +REDISMODULE_API RedisModuleString * (*RedisModule_CreateStringFromCallReply)(RedisModuleCallReply *reply) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_DeleteKey)(RedisModuleKey *key) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_UnlinkKey)(RedisModuleKey *key) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_StringSet)(RedisModuleKey *key, RedisModuleString *str) REDISMODULE_ATTR; +REDISMODULE_API char * (*RedisModule_StringDMA)(RedisModuleKey *key, size_t *len, int mode) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_StringTruncate)(RedisModuleKey *key, size_t newlen) REDISMODULE_ATTR; +REDISMODULE_API mstime_t (*RedisModule_GetExpire)(RedisModuleKey *key) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_SetExpire)(RedisModuleKey *key, mstime_t expire) REDISMODULE_ATTR; +REDISMODULE_API mstime_t (*RedisModule_GetAbsExpire)(RedisModuleKey *key) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_SetAbsExpire)(RedisModuleKey *key, mstime_t expire) REDISMODULE_ATTR; +REDISMODULE_API void (*RedisModule_ResetDataset)(int restart_aof, int async) REDISMODULE_ATTR; +REDISMODULE_API unsigned long long (*RedisModule_DbSize)(RedisModuleCtx *ctx) REDISMODULE_ATTR; +REDISMODULE_API RedisModuleString * (*RedisModule_RandomKey)(RedisModuleCtx *ctx) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_ZsetAdd)(RedisModuleKey *key, double score, RedisModuleString *ele, int *flagsptr) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_ZsetIncrby)(RedisModuleKey *key, double score, RedisModuleString *ele, int *flagsptr, double *newscore) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_ZsetScore)(RedisModuleKey *key, RedisModuleString *ele, double *score) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_ZsetRem)(RedisModuleKey *key, RedisModuleString *ele, int *deleted) REDISMODULE_ATTR; +REDISMODULE_API void (*RedisModule_ZsetRangeStop)(RedisModuleKey *key) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_ZsetFirstInScoreRange)(RedisModuleKey *key, double min, double max, int minex, int maxex) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_ZsetLastInScoreRange)(RedisModuleKey *key, double min, double max, int minex, int maxex) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_ZsetFirstInLexRange)(RedisModuleKey *key, RedisModuleString *min, RedisModuleString *max) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_ZsetLastInLexRange)(RedisModuleKey *key, RedisModuleString *min, RedisModuleString *max) REDISMODULE_ATTR; +REDISMODULE_API RedisModuleString * (*RedisModule_ZsetRangeCurrentElement)(RedisModuleKey *key, double *score) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_ZsetRangeNext)(RedisModuleKey *key) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_ZsetRangePrev)(RedisModuleKey *key) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_ZsetRangeEndReached)(RedisModuleKey *key) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_HashSet)(RedisModuleKey *key, int flags, ...) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_HashGet)(RedisModuleKey *key, int flags, ...) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_StreamAdd)(RedisModuleKey *key, int flags, RedisModuleStreamID *id, RedisModuleString **argv, int64_t numfields) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_StreamDelete)(RedisModuleKey *key, RedisModuleStreamID *id) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_StreamIteratorStart)(RedisModuleKey *key, int flags, RedisModuleStreamID *startid, RedisModuleStreamID *endid) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_StreamIteratorStop)(RedisModuleKey *key) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_StreamIteratorNextID)(RedisModuleKey *key, RedisModuleStreamID *id, long *numfields) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_StreamIteratorNextField)(RedisModuleKey *key, RedisModuleString **field_ptr, RedisModuleString **value_ptr) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_StreamIteratorDelete)(RedisModuleKey *key) REDISMODULE_ATTR; +REDISMODULE_API long long (*RedisModule_StreamTrimByLength)(RedisModuleKey *key, int flags, long long length) REDISMODULE_ATTR; +REDISMODULE_API long long (*RedisModule_StreamTrimByID)(RedisModuleKey *key, int flags, RedisModuleStreamID *id) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_IsKeysPositionRequest)(RedisModuleCtx *ctx) REDISMODULE_ATTR; +REDISMODULE_API void (*RedisModule_KeyAtPos)(RedisModuleCtx *ctx, int pos) REDISMODULE_ATTR; +REDISMODULE_API void (*RedisModule_KeyAtPosWithFlags)(RedisModuleCtx *ctx, int pos, int flags) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_IsChannelsPositionRequest)(RedisModuleCtx *ctx) REDISMODULE_ATTR; +REDISMODULE_API void (*RedisModule_ChannelAtPosWithFlags)(RedisModuleCtx *ctx, int pos, int flags) REDISMODULE_ATTR; +REDISMODULE_API unsigned long long (*RedisModule_GetClientId)(RedisModuleCtx *ctx) REDISMODULE_ATTR; +REDISMODULE_API RedisModuleString * (*RedisModule_GetClientUserNameById)(RedisModuleCtx *ctx, uint64_t id) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_GetClientInfoById)(void *ci, uint64_t id) REDISMODULE_ATTR; +REDISMODULE_API RedisModuleString * (*RedisModule_GetClientNameById)(RedisModuleCtx *ctx, uint64_t id) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_SetClientNameById)(uint64_t id, RedisModuleString *name) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_PublishMessage)(RedisModuleCtx *ctx, RedisModuleString *channel, RedisModuleString *message) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_PublishMessageShard)(RedisModuleCtx *ctx, RedisModuleString *channel, RedisModuleString *message) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_GetContextFlags)(RedisModuleCtx *ctx) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_AvoidReplicaTraffic)(void) REDISMODULE_ATTR; +REDISMODULE_API void * (*RedisModule_PoolAlloc)(RedisModuleCtx *ctx, size_t bytes) REDISMODULE_ATTR; +REDISMODULE_API RedisModuleType * (*RedisModule_CreateDataType)(RedisModuleCtx *ctx, const char *name, int encver, RedisModuleTypeMethods *typemethods) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_ModuleTypeSetValue)(RedisModuleKey *key, RedisModuleType *mt, void *value) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_ModuleTypeReplaceValue)(RedisModuleKey *key, RedisModuleType *mt, void *new_value, void **old_value) REDISMODULE_ATTR; +REDISMODULE_API RedisModuleType * (*RedisModule_ModuleTypeGetType)(RedisModuleKey *key) REDISMODULE_ATTR; +REDISMODULE_API void * (*RedisModule_ModuleTypeGetValue)(RedisModuleKey *key) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_IsIOError)(RedisModuleIO *io) REDISMODULE_ATTR; +REDISMODULE_API void (*RedisModule_SetModuleOptions)(RedisModuleCtx *ctx, int options) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_SignalModifiedKey)(RedisModuleCtx *ctx, RedisModuleString *keyname) REDISMODULE_ATTR; +REDISMODULE_API void (*RedisModule_SaveUnsigned)(RedisModuleIO *io, uint64_t value) REDISMODULE_ATTR; +REDISMODULE_API uint64_t (*RedisModule_LoadUnsigned)(RedisModuleIO *io) REDISMODULE_ATTR; +REDISMODULE_API void (*RedisModule_SaveSigned)(RedisModuleIO *io, int64_t value) REDISMODULE_ATTR; +REDISMODULE_API int64_t (*RedisModule_LoadSigned)(RedisModuleIO *io) REDISMODULE_ATTR; +REDISMODULE_API void (*RedisModule_EmitAOF)(RedisModuleIO *io, const char *cmdname, const char *fmt, ...) REDISMODULE_ATTR; +REDISMODULE_API void (*RedisModule_SaveString)(RedisModuleIO *io, RedisModuleString *s) REDISMODULE_ATTR; +REDISMODULE_API void (*RedisModule_SaveStringBuffer)(RedisModuleIO *io, const char *str, size_t len) REDISMODULE_ATTR; +REDISMODULE_API RedisModuleString * (*RedisModule_LoadString)(RedisModuleIO *io) REDISMODULE_ATTR; +REDISMODULE_API char * (*RedisModule_LoadStringBuffer)(RedisModuleIO *io, size_t *lenptr) REDISMODULE_ATTR; +REDISMODULE_API void (*RedisModule_SaveDouble)(RedisModuleIO *io, double value) REDISMODULE_ATTR; +REDISMODULE_API double (*RedisModule_LoadDouble)(RedisModuleIO *io) REDISMODULE_ATTR; +REDISMODULE_API void (*RedisModule_SaveFloat)(RedisModuleIO *io, float value) REDISMODULE_ATTR; +REDISMODULE_API float (*RedisModule_LoadFloat)(RedisModuleIO *io) REDISMODULE_ATTR; +REDISMODULE_API void (*RedisModule_SaveLongDouble)(RedisModuleIO *io, long double value) REDISMODULE_ATTR; +REDISMODULE_API long double (*RedisModule_LoadLongDouble)(RedisModuleIO *io) REDISMODULE_ATTR; +REDISMODULE_API void * (*RedisModule_LoadDataTypeFromString)(const RedisModuleString *str, const RedisModuleType *mt) REDISMODULE_ATTR; +REDISMODULE_API void * (*RedisModule_LoadDataTypeFromStringEncver)(const RedisModuleString *str, const RedisModuleType *mt, int encver) REDISMODULE_ATTR; +REDISMODULE_API RedisModuleString * (*RedisModule_SaveDataTypeToString)(RedisModuleCtx *ctx, void *data, const RedisModuleType *mt) REDISMODULE_ATTR; +REDISMODULE_API void (*RedisModule_Log)(RedisModuleCtx *ctx, const char *level, const char *fmt, ...) REDISMODULE_ATTR REDISMODULE_ATTR_PRINTF(3,4); +REDISMODULE_API void (*RedisModule_LogIOError)(RedisModuleIO *io, const char *levelstr, const char *fmt, ...) REDISMODULE_ATTR REDISMODULE_ATTR_PRINTF(3,4); +REDISMODULE_API void (*RedisModule__Assert)(const char *estr, const char *file, int line) REDISMODULE_ATTR; +REDISMODULE_API void (*RedisModule_LatencyAddSample)(const char *event, mstime_t latency) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_StringAppendBuffer)(RedisModuleCtx *ctx, RedisModuleString *str, const char *buf, size_t len) REDISMODULE_ATTR; +REDISMODULE_API void (*RedisModule_TrimStringAllocation)(RedisModuleString *str) REDISMODULE_ATTR; +REDISMODULE_API void (*RedisModule_RetainString)(RedisModuleCtx *ctx, RedisModuleString *str) REDISMODULE_ATTR; +REDISMODULE_API RedisModuleString * (*RedisModule_HoldString)(RedisModuleCtx *ctx, RedisModuleString *str) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_StringCompare)(const RedisModuleString *a, const RedisModuleString *b) REDISMODULE_ATTR; +REDISMODULE_API RedisModuleCtx * (*RedisModule_GetContextFromIO)(RedisModuleIO *io) REDISMODULE_ATTR; +REDISMODULE_API const RedisModuleString * (*RedisModule_GetKeyNameFromIO)(RedisModuleIO *io) REDISMODULE_ATTR; +REDISMODULE_API const RedisModuleString * (*RedisModule_GetKeyNameFromModuleKey)(RedisModuleKey *key) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_GetDbIdFromModuleKey)(RedisModuleKey *key) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_GetDbIdFromIO)(RedisModuleIO *io) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_GetDbIdFromOptCtx)(RedisModuleKeyOptCtx *ctx) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_GetToDbIdFromOptCtx)(RedisModuleKeyOptCtx *ctx) REDISMODULE_ATTR; +REDISMODULE_API const RedisModuleString * (*RedisModule_GetKeyNameFromOptCtx)(RedisModuleKeyOptCtx *ctx) REDISMODULE_ATTR; +REDISMODULE_API const RedisModuleString * (*RedisModule_GetToKeyNameFromOptCtx)(RedisModuleKeyOptCtx *ctx) REDISMODULE_ATTR; +REDISMODULE_API mstime_t (*RedisModule_Milliseconds)(void) REDISMODULE_ATTR; +REDISMODULE_API uint64_t (*RedisModule_MonotonicMicroseconds)(void) REDISMODULE_ATTR; +REDISMODULE_API ustime_t (*RedisModule_Microseconds)(void) REDISMODULE_ATTR; +REDISMODULE_API ustime_t (*RedisModule_CachedMicroseconds)(void) REDISMODULE_ATTR; +REDISMODULE_API void (*RedisModule_DigestAddStringBuffer)(RedisModuleDigest *md, const char *ele, size_t len) REDISMODULE_ATTR; +REDISMODULE_API void (*RedisModule_DigestAddLongLong)(RedisModuleDigest *md, long long ele) REDISMODULE_ATTR; +REDISMODULE_API void (*RedisModule_DigestEndSequence)(RedisModuleDigest *md) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_GetDbIdFromDigest)(RedisModuleDigest *dig) REDISMODULE_ATTR; +REDISMODULE_API const RedisModuleString * (*RedisModule_GetKeyNameFromDigest)(RedisModuleDigest *dig) REDISMODULE_ATTR; +REDISMODULE_API RedisModuleDict * (*RedisModule_CreateDict)(RedisModuleCtx *ctx) REDISMODULE_ATTR; +REDISMODULE_API void (*RedisModule_FreeDict)(RedisModuleCtx *ctx, RedisModuleDict *d) REDISMODULE_ATTR; +REDISMODULE_API uint64_t (*RedisModule_DictSize)(RedisModuleDict *d) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_DictSetC)(RedisModuleDict *d, void *key, size_t keylen, void *ptr) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_DictReplaceC)(RedisModuleDict *d, void *key, size_t keylen, void *ptr) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_DictSet)(RedisModuleDict *d, RedisModuleString *key, void *ptr) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_DictReplace)(RedisModuleDict *d, RedisModuleString *key, void *ptr) REDISMODULE_ATTR; +REDISMODULE_API void * (*RedisModule_DictGetC)(RedisModuleDict *d, void *key, size_t keylen, int *nokey) REDISMODULE_ATTR; +REDISMODULE_API void * (*RedisModule_DictGet)(RedisModuleDict *d, RedisModuleString *key, int *nokey) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_DictDelC)(RedisModuleDict *d, void *key, size_t keylen, void *oldval) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_DictDel)(RedisModuleDict *d, RedisModuleString *key, void *oldval) REDISMODULE_ATTR; +REDISMODULE_API RedisModuleDictIter * (*RedisModule_DictIteratorStartC)(RedisModuleDict *d, const char *op, void *key, size_t keylen) REDISMODULE_ATTR; +REDISMODULE_API RedisModuleDictIter * (*RedisModule_DictIteratorStart)(RedisModuleDict *d, const char *op, RedisModuleString *key) REDISMODULE_ATTR; +REDISMODULE_API void (*RedisModule_DictIteratorStop)(RedisModuleDictIter *di) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_DictIteratorReseekC)(RedisModuleDictIter *di, const char *op, void *key, size_t keylen) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_DictIteratorReseek)(RedisModuleDictIter *di, const char *op, RedisModuleString *key) REDISMODULE_ATTR; +REDISMODULE_API void * (*RedisModule_DictNextC)(RedisModuleDictIter *di, size_t *keylen, void **dataptr) REDISMODULE_ATTR; +REDISMODULE_API void * (*RedisModule_DictPrevC)(RedisModuleDictIter *di, size_t *keylen, void **dataptr) REDISMODULE_ATTR; +REDISMODULE_API RedisModuleString * (*RedisModule_DictNext)(RedisModuleCtx *ctx, RedisModuleDictIter *di, void **dataptr) REDISMODULE_ATTR; +REDISMODULE_API RedisModuleString * (*RedisModule_DictPrev)(RedisModuleCtx *ctx, RedisModuleDictIter *di, void **dataptr) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_DictCompareC)(RedisModuleDictIter *di, const char *op, void *key, size_t keylen) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_DictCompare)(RedisModuleDictIter *di, const char *op, RedisModuleString *key) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_RegisterInfoFunc)(RedisModuleCtx *ctx, RedisModuleInfoFunc cb) REDISMODULE_ATTR; +REDISMODULE_API void (*RedisModule_RegisterAuthCallback)(RedisModuleCtx *ctx, RedisModuleAuthCallback cb) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_InfoAddSection)(RedisModuleInfoCtx *ctx, const char *name) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_InfoBeginDictField)(RedisModuleInfoCtx *ctx, const char *name) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_InfoEndDictField)(RedisModuleInfoCtx *ctx) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_InfoAddFieldString)(RedisModuleInfoCtx *ctx, const char *field, RedisModuleString *value) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_InfoAddFieldCString)(RedisModuleInfoCtx *ctx, const char *field,const char *value) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_InfoAddFieldDouble)(RedisModuleInfoCtx *ctx, const char *field, double value) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_InfoAddFieldLongLong)(RedisModuleInfoCtx *ctx, const char *field, long long value) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_InfoAddFieldULongLong)(RedisModuleInfoCtx *ctx, const char *field, unsigned long long value) REDISMODULE_ATTR; +REDISMODULE_API RedisModuleServerInfoData * (*RedisModule_GetServerInfo)(RedisModuleCtx *ctx, const char *section) REDISMODULE_ATTR; +REDISMODULE_API void (*RedisModule_FreeServerInfo)(RedisModuleCtx *ctx, RedisModuleServerInfoData *data) REDISMODULE_ATTR; +REDISMODULE_API RedisModuleString * (*RedisModule_ServerInfoGetField)(RedisModuleCtx *ctx, RedisModuleServerInfoData *data, const char* field) REDISMODULE_ATTR; +REDISMODULE_API const char * (*RedisModule_ServerInfoGetFieldC)(RedisModuleServerInfoData *data, const char* field) REDISMODULE_ATTR; +REDISMODULE_API long long (*RedisModule_ServerInfoGetFieldSigned)(RedisModuleServerInfoData *data, const char* field, int *out_err) REDISMODULE_ATTR; +REDISMODULE_API unsigned long long (*RedisModule_ServerInfoGetFieldUnsigned)(RedisModuleServerInfoData *data, const char* field, int *out_err) REDISMODULE_ATTR; +REDISMODULE_API double (*RedisModule_ServerInfoGetFieldDouble)(RedisModuleServerInfoData *data, const char* field, int *out_err) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_SubscribeToServerEvent)(RedisModuleCtx *ctx, RedisModuleEvent event, RedisModuleEventCallback callback) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_SetLRU)(RedisModuleKey *key, mstime_t lru_idle) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_GetLRU)(RedisModuleKey *key, mstime_t *lru_idle) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_SetLFU)(RedisModuleKey *key, long long lfu_freq) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_GetLFU)(RedisModuleKey *key, long long *lfu_freq) REDISMODULE_ATTR; +REDISMODULE_API RedisModuleBlockedClient * (*RedisModule_BlockClientOnKeys)(RedisModuleCtx *ctx, RedisModuleCmdFunc reply_callback, RedisModuleCmdFunc timeout_callback, void (*free_privdata)(RedisModuleCtx*,void*), long long timeout_ms, RedisModuleString **keys, int numkeys, void *privdata) REDISMODULE_ATTR; +REDISMODULE_API RedisModuleBlockedClient * (*RedisModule_BlockClientOnKeysWithFlags)(RedisModuleCtx *ctx, RedisModuleCmdFunc reply_callback, RedisModuleCmdFunc timeout_callback, void (*free_privdata)(RedisModuleCtx*,void*), long long timeout_ms, RedisModuleString **keys, int numkeys, void *privdata, int flags) REDISMODULE_ATTR; +REDISMODULE_API void (*RedisModule_SignalKeyAsReady)(RedisModuleCtx *ctx, RedisModuleString *key) REDISMODULE_ATTR; +REDISMODULE_API RedisModuleString * (*RedisModule_GetBlockedClientReadyKey)(RedisModuleCtx *ctx) REDISMODULE_ATTR; +REDISMODULE_API RedisModuleScanCursor * (*RedisModule_ScanCursorCreate)(void) REDISMODULE_ATTR; +REDISMODULE_API void (*RedisModule_ScanCursorRestart)(RedisModuleScanCursor *cursor) REDISMODULE_ATTR; +REDISMODULE_API void (*RedisModule_ScanCursorDestroy)(RedisModuleScanCursor *cursor) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_Scan)(RedisModuleCtx *ctx, RedisModuleScanCursor *cursor, RedisModuleScanCB fn, void *privdata) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_ScanKey)(RedisModuleKey *key, RedisModuleScanCursor *cursor, RedisModuleScanKeyCB fn, void *privdata) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_GetContextFlagsAll)(void) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_GetModuleOptionsAll)(void) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_GetKeyspaceNotificationFlagsAll)(void) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_IsSubEventSupported)(RedisModuleEvent event, uint64_t subevent) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_GetServerVersion)(void) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_GetTypeMethodVersion)(void) REDISMODULE_ATTR; +REDISMODULE_API void (*RedisModule_Yield)(RedisModuleCtx *ctx, int flags, const char *busy_reply) REDISMODULE_ATTR; +REDISMODULE_API RedisModuleBlockedClient * (*RedisModule_BlockClient)(RedisModuleCtx *ctx, RedisModuleCmdFunc reply_callback, RedisModuleCmdFunc timeout_callback, void (*free_privdata)(RedisModuleCtx*,void*), long long timeout_ms) REDISMODULE_ATTR; +REDISMODULE_API void * (*RedisModule_BlockClientGetPrivateData)(RedisModuleBlockedClient *blocked_client) REDISMODULE_ATTR; +REDISMODULE_API void (*RedisModule_BlockClientSetPrivateData)(RedisModuleBlockedClient *blocked_client, void *private_data) REDISMODULE_ATTR; +REDISMODULE_API RedisModuleBlockedClient * (*RedisModule_BlockClientOnAuth)(RedisModuleCtx *ctx, RedisModuleAuthCallback reply_callback, void (*free_privdata)(RedisModuleCtx*,void*)) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_UnblockClient)(RedisModuleBlockedClient *bc, void *privdata) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_IsBlockedReplyRequest)(RedisModuleCtx *ctx) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_IsBlockedTimeoutRequest)(RedisModuleCtx *ctx) REDISMODULE_ATTR; +REDISMODULE_API void * (*RedisModule_GetBlockedClientPrivateData)(RedisModuleCtx *ctx) REDISMODULE_ATTR; +REDISMODULE_API RedisModuleBlockedClient * (*RedisModule_GetBlockedClientHandle)(RedisModuleCtx *ctx) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_AbortBlock)(RedisModuleBlockedClient *bc) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_BlockedClientMeasureTimeStart)(RedisModuleBlockedClient *bc) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_BlockedClientMeasureTimeEnd)(RedisModuleBlockedClient *bc) REDISMODULE_ATTR; +REDISMODULE_API RedisModuleCtx * (*RedisModule_GetThreadSafeContext)(RedisModuleBlockedClient *bc) REDISMODULE_ATTR; +REDISMODULE_API RedisModuleCtx * (*RedisModule_GetDetachedThreadSafeContext)(RedisModuleCtx *ctx) REDISMODULE_ATTR; +REDISMODULE_API void (*RedisModule_FreeThreadSafeContext)(RedisModuleCtx *ctx) REDISMODULE_ATTR; +REDISMODULE_API void (*RedisModule_ThreadSafeContextLock)(RedisModuleCtx *ctx) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_ThreadSafeContextTryLock)(RedisModuleCtx *ctx) REDISMODULE_ATTR; +REDISMODULE_API void (*RedisModule_ThreadSafeContextUnlock)(RedisModuleCtx *ctx) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_SubscribeToKeyspaceEvents)(RedisModuleCtx *ctx, int types, RedisModuleNotificationFunc cb) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_AddPostNotificationJob)(RedisModuleCtx *ctx, RedisModulePostNotificationJobFunc callback, void *pd, void (*free_pd)(void*)) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_NotifyKeyspaceEvent)(RedisModuleCtx *ctx, int type, const char *event, RedisModuleString *key) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_GetNotifyKeyspaceEvents)(void) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_BlockedClientDisconnected)(RedisModuleCtx *ctx) REDISMODULE_ATTR; +REDISMODULE_API void (*RedisModule_RegisterClusterMessageReceiver)(RedisModuleCtx *ctx, uint8_t type, RedisModuleClusterMessageReceiver callback) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_SendClusterMessage)(RedisModuleCtx *ctx, const char *target_id, uint8_t type, const char *msg, uint32_t len) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_GetClusterNodeInfo)(RedisModuleCtx *ctx, const char *id, char *ip, char *master_id, int *port, int *flags) REDISMODULE_ATTR; +REDISMODULE_API char ** (*RedisModule_GetClusterNodesList)(RedisModuleCtx *ctx, size_t *numnodes) REDISMODULE_ATTR; +REDISMODULE_API void (*RedisModule_FreeClusterNodesList)(char **ids) REDISMODULE_ATTR; +REDISMODULE_API RedisModuleTimerID (*RedisModule_CreateTimer)(RedisModuleCtx *ctx, mstime_t period, RedisModuleTimerProc callback, void *data) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_StopTimer)(RedisModuleCtx *ctx, RedisModuleTimerID id, void **data) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_GetTimerInfo)(RedisModuleCtx *ctx, RedisModuleTimerID id, uint64_t *remaining, void **data) REDISMODULE_ATTR; +REDISMODULE_API const char * (*RedisModule_GetMyClusterID)(void) REDISMODULE_ATTR; +REDISMODULE_API size_t (*RedisModule_GetClusterSize)(void) REDISMODULE_ATTR; +REDISMODULE_API void (*RedisModule_GetRandomBytes)(unsigned char *dst, size_t len) REDISMODULE_ATTR; +REDISMODULE_API void (*RedisModule_GetRandomHexChars)(char *dst, size_t len) REDISMODULE_ATTR; +REDISMODULE_API void (*RedisModule_SetDisconnectCallback)(RedisModuleBlockedClient *bc, RedisModuleDisconnectFunc callback) REDISMODULE_ATTR; +REDISMODULE_API void (*RedisModule_SetClusterFlags)(RedisModuleCtx *ctx, uint64_t flags) REDISMODULE_ATTR; +REDISMODULE_API unsigned int (*RedisModule_ClusterKeySlot)(RedisModuleString *key) REDISMODULE_ATTR; +REDISMODULE_API const char *(*RedisModule_ClusterCanonicalKeyNameInSlot)(unsigned int slot) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_ExportSharedAPI)(RedisModuleCtx *ctx, const char *apiname, void *func) REDISMODULE_ATTR; +REDISMODULE_API void * (*RedisModule_GetSharedAPI)(RedisModuleCtx *ctx, const char *apiname) REDISMODULE_ATTR; +REDISMODULE_API RedisModuleCommandFilter * (*RedisModule_RegisterCommandFilter)(RedisModuleCtx *ctx, RedisModuleCommandFilterFunc cb, int flags) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_UnregisterCommandFilter)(RedisModuleCtx *ctx, RedisModuleCommandFilter *filter) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_CommandFilterArgsCount)(RedisModuleCommandFilterCtx *fctx) REDISMODULE_ATTR; +REDISMODULE_API RedisModuleString * (*RedisModule_CommandFilterArgGet)(RedisModuleCommandFilterCtx *fctx, int pos) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_CommandFilterArgInsert)(RedisModuleCommandFilterCtx *fctx, int pos, RedisModuleString *arg) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_CommandFilterArgReplace)(RedisModuleCommandFilterCtx *fctx, int pos, RedisModuleString *arg) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_CommandFilterArgDelete)(RedisModuleCommandFilterCtx *fctx, int pos) REDISMODULE_ATTR; +REDISMODULE_API unsigned long long (*RedisModule_CommandFilterGetClientId)(RedisModuleCommandFilterCtx *fctx) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_Fork)(RedisModuleForkDoneHandler cb, void *user_data) REDISMODULE_ATTR; +REDISMODULE_API void (*RedisModule_SendChildHeartbeat)(double progress) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_ExitFromChild)(int retcode) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_KillForkChild)(int child_pid) REDISMODULE_ATTR; +REDISMODULE_API float (*RedisModule_GetUsedMemoryRatio)(void) REDISMODULE_ATTR; +REDISMODULE_API size_t (*RedisModule_MallocSize)(void* ptr) REDISMODULE_ATTR; +REDISMODULE_API size_t (*RedisModule_MallocUsableSize)(void *ptr) REDISMODULE_ATTR; +REDISMODULE_API size_t (*RedisModule_MallocSizeString)(RedisModuleString* str) REDISMODULE_ATTR; +REDISMODULE_API size_t (*RedisModule_MallocSizeDict)(RedisModuleDict* dict) REDISMODULE_ATTR; +REDISMODULE_API RedisModuleUser * (*RedisModule_CreateModuleUser)(const char *name) REDISMODULE_ATTR; +REDISMODULE_API void (*RedisModule_FreeModuleUser)(RedisModuleUser *user) REDISMODULE_ATTR; +REDISMODULE_API void (*RedisModule_SetContextUser)(RedisModuleCtx *ctx, const RedisModuleUser *user) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_SetModuleUserACL)(RedisModuleUser *user, const char* acl) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_SetModuleUserACLString)(RedisModuleCtx * ctx, RedisModuleUser *user, const char* acl, RedisModuleString **error) REDISMODULE_ATTR; +REDISMODULE_API RedisModuleString * (*RedisModule_GetModuleUserACLString)(RedisModuleUser *user) REDISMODULE_ATTR; +REDISMODULE_API RedisModuleString * (*RedisModule_GetCurrentUserName)(RedisModuleCtx *ctx) REDISMODULE_ATTR; +REDISMODULE_API RedisModuleUser * (*RedisModule_GetModuleUserFromUserName)(RedisModuleString *name) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_ACLCheckCommandPermissions)(RedisModuleUser *user, RedisModuleString **argv, int argc) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_ACLCheckKeyPermissions)(RedisModuleUser *user, RedisModuleString *key, int flags) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_ACLCheckChannelPermissions)(RedisModuleUser *user, RedisModuleString *ch, int literal) REDISMODULE_ATTR; +REDISMODULE_API void (*RedisModule_ACLAddLogEntry)(RedisModuleCtx *ctx, RedisModuleUser *user, RedisModuleString *object, RedisModuleACLLogEntryReason reason) REDISMODULE_ATTR; +REDISMODULE_API void (*RedisModule_ACLAddLogEntryByUserName)(RedisModuleCtx *ctx, RedisModuleString *user, RedisModuleString *object, RedisModuleACLLogEntryReason reason) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_AuthenticateClientWithACLUser)(RedisModuleCtx *ctx, const char *name, size_t len, RedisModuleUserChangedFunc callback, void *privdata, uint64_t *client_id) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_AuthenticateClientWithUser)(RedisModuleCtx *ctx, RedisModuleUser *user, RedisModuleUserChangedFunc callback, void *privdata, uint64_t *client_id) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_DeauthenticateAndCloseClient)(RedisModuleCtx *ctx, uint64_t client_id) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_RedactClientCommandArgument)(RedisModuleCtx *ctx, int pos) REDISMODULE_ATTR; +REDISMODULE_API RedisModuleString * (*RedisModule_GetClientCertificate)(RedisModuleCtx *ctx, uint64_t id) REDISMODULE_ATTR; +REDISMODULE_API int *(*RedisModule_GetCommandKeys)(RedisModuleCtx *ctx, RedisModuleString **argv, int argc, int *num_keys) REDISMODULE_ATTR; +REDISMODULE_API int *(*RedisModule_GetCommandKeysWithFlags)(RedisModuleCtx *ctx, RedisModuleString **argv, int argc, int *num_keys, int **out_flags) REDISMODULE_ATTR; +REDISMODULE_API const char *(*RedisModule_GetCurrentCommandName)(RedisModuleCtx *ctx) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_RegisterDefragFunc)(RedisModuleCtx *ctx, RedisModuleDefragFunc func) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_RegisterDefragCallbacks)(RedisModuleCtx *ctx, RedisModuleDefragFunc start, RedisModuleDefragFunc end) REDISMODULE_ATTR; +REDISMODULE_API void *(*RedisModule_DefragAlloc)(RedisModuleDefragCtx *ctx, void *ptr) REDISMODULE_ATTR; +REDISMODULE_API void *(*RedisModule_DefragAllocRaw)(RedisModuleDefragCtx *ctx, size_t size) REDISMODULE_ATTR; +REDISMODULE_API void (*RedisModule_DefragFreeRaw)(RedisModuleDefragCtx *ctx, void *ptr) REDISMODULE_ATTR; +REDISMODULE_API RedisModuleString *(*RedisModule_DefragRedisModuleString)(RedisModuleDefragCtx *ctx, RedisModuleString *str) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_DefragShouldStop)(RedisModuleDefragCtx *ctx) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_DefragCursorSet)(RedisModuleDefragCtx *ctx, unsigned long cursor) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_DefragCursorGet)(RedisModuleDefragCtx *ctx, unsigned long *cursor) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_GetDbIdFromDefragCtx)(RedisModuleDefragCtx *ctx) REDISMODULE_ATTR; +REDISMODULE_API const RedisModuleString * (*RedisModule_GetKeyNameFromDefragCtx)(RedisModuleDefragCtx *ctx) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_EventLoopAdd)(int fd, int mask, RedisModuleEventLoopFunc func, void *user_data) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_EventLoopDel)(int fd, int mask) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_EventLoopAddOneShot)(RedisModuleEventLoopOneShotFunc func, void *user_data) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_RegisterBoolConfig)(RedisModuleCtx *ctx, const char *name, int default_val, unsigned int flags, RedisModuleConfigGetBoolFunc getfn, RedisModuleConfigSetBoolFunc setfn, RedisModuleConfigApplyFunc applyfn, void *privdata) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_RegisterNumericConfig)(RedisModuleCtx *ctx, const char *name, long long default_val, unsigned int flags, long long min, long long max, RedisModuleConfigGetNumericFunc getfn, RedisModuleConfigSetNumericFunc setfn, RedisModuleConfigApplyFunc applyfn, void *privdata) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_RegisterStringConfig)(RedisModuleCtx *ctx, const char *name, const char *default_val, unsigned int flags, RedisModuleConfigGetStringFunc getfn, RedisModuleConfigSetStringFunc setfn, RedisModuleConfigApplyFunc applyfn, void *privdata) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_RegisterEnumConfig)(RedisModuleCtx *ctx, const char *name, int default_val, unsigned int flags, const char **enum_values, const int *int_values, int num_enum_vals, RedisModuleConfigGetEnumFunc getfn, RedisModuleConfigSetEnumFunc setfn, RedisModuleConfigApplyFunc applyfn, void *privdata) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_LoadConfigs)(RedisModuleCtx *ctx) REDISMODULE_ATTR; +REDISMODULE_API RedisModuleRdbStream *(*RedisModule_RdbStreamCreateFromFile)(const char *filename) REDISMODULE_ATTR; +REDISMODULE_API void (*RedisModule_RdbStreamFree)(RedisModuleRdbStream *stream) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_RdbLoad)(RedisModuleCtx *ctx, RedisModuleRdbStream *stream, int flags) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_RdbSave)(RedisModuleCtx *ctx, RedisModuleRdbStream *stream, int flags) REDISMODULE_ATTR; + +#define RedisModule_IsAOFClient(id) ((id) == UINT64_MAX) + +/* This is included inline inside each Redis module. */ +static int RedisModule_Init(RedisModuleCtx *ctx, const char *name, int ver, int apiver) REDISMODULE_ATTR_UNUSED; +static int RedisModule_Init(RedisModuleCtx *ctx, const char *name, int ver, int apiver) { + void *getapifuncptr = ((void**)ctx)[0]; + RedisModule_GetApi = (int (*)(const char *, void *)) (unsigned long)getapifuncptr; + REDISMODULE_GET_API(Alloc); + REDISMODULE_GET_API(TryAlloc); + REDISMODULE_GET_API(Calloc); + REDISMODULE_GET_API(TryCalloc); + REDISMODULE_GET_API(Free); + REDISMODULE_GET_API(Realloc); + REDISMODULE_GET_API(TryRealloc); + REDISMODULE_GET_API(Strdup); + REDISMODULE_GET_API(CreateCommand); + REDISMODULE_GET_API(GetCommand); + REDISMODULE_GET_API(CreateSubcommand); + REDISMODULE_GET_API(SetCommandInfo); + REDISMODULE_GET_API(SetCommandACLCategories); + REDISMODULE_GET_API(AddACLCategory); + REDISMODULE_GET_API(SetModuleAttribs); + REDISMODULE_GET_API(IsModuleNameBusy); + REDISMODULE_GET_API(WrongArity); + REDISMODULE_GET_API(ReplyWithLongLong); + REDISMODULE_GET_API(ReplyWithError); + REDISMODULE_GET_API(ReplyWithErrorFormat); + REDISMODULE_GET_API(ReplyWithSimpleString); + REDISMODULE_GET_API(ReplyWithArray); + REDISMODULE_GET_API(ReplyWithMap); + REDISMODULE_GET_API(ReplyWithSet); + REDISMODULE_GET_API(ReplyWithAttribute); + REDISMODULE_GET_API(ReplyWithNullArray); + REDISMODULE_GET_API(ReplyWithEmptyArray); + REDISMODULE_GET_API(ReplySetArrayLength); + REDISMODULE_GET_API(ReplySetMapLength); + REDISMODULE_GET_API(ReplySetSetLength); + REDISMODULE_GET_API(ReplySetAttributeLength); + REDISMODULE_GET_API(ReplySetPushLength); + REDISMODULE_GET_API(ReplyWithStringBuffer); + REDISMODULE_GET_API(ReplyWithCString); + REDISMODULE_GET_API(ReplyWithString); + REDISMODULE_GET_API(ReplyWithEmptyString); + REDISMODULE_GET_API(ReplyWithVerbatimString); + REDISMODULE_GET_API(ReplyWithVerbatimStringType); + REDISMODULE_GET_API(ReplyWithNull); + REDISMODULE_GET_API(ReplyWithBool); + REDISMODULE_GET_API(ReplyWithCallReply); + REDISMODULE_GET_API(ReplyWithDouble); + REDISMODULE_GET_API(ReplyWithBigNumber); + REDISMODULE_GET_API(ReplyWithLongDouble); + REDISMODULE_GET_API(GetSelectedDb); + REDISMODULE_GET_API(SelectDb); + REDISMODULE_GET_API(KeyExists); + REDISMODULE_GET_API(OpenKey); + REDISMODULE_GET_API(GetOpenKeyModesAll); + REDISMODULE_GET_API(CloseKey); + REDISMODULE_GET_API(KeyType); + REDISMODULE_GET_API(ValueLength); + REDISMODULE_GET_API(ListPush); + REDISMODULE_GET_API(ListPop); + REDISMODULE_GET_API(ListGet); + REDISMODULE_GET_API(ListSet); + REDISMODULE_GET_API(ListInsert); + REDISMODULE_GET_API(ListDelete); + REDISMODULE_GET_API(StringToLongLong); + REDISMODULE_GET_API(StringToULongLong); + REDISMODULE_GET_API(StringToDouble); + REDISMODULE_GET_API(StringToLongDouble); + REDISMODULE_GET_API(StringToStreamID); + REDISMODULE_GET_API(Call); + REDISMODULE_GET_API(CallReplyProto); + REDISMODULE_GET_API(FreeCallReply); + REDISMODULE_GET_API(CallReplyInteger); + REDISMODULE_GET_API(CallReplyDouble); + REDISMODULE_GET_API(CallReplyBool); + REDISMODULE_GET_API(CallReplyBigNumber); + REDISMODULE_GET_API(CallReplyVerbatim); + REDISMODULE_GET_API(CallReplySetElement); + REDISMODULE_GET_API(CallReplyMapElement); + REDISMODULE_GET_API(CallReplyAttributeElement); + REDISMODULE_GET_API(CallReplyPromiseSetUnblockHandler); + REDISMODULE_GET_API(CallReplyPromiseAbort); + REDISMODULE_GET_API(CallReplyAttribute); + REDISMODULE_GET_API(CallReplyType); + REDISMODULE_GET_API(CallReplyLength); + REDISMODULE_GET_API(CallReplyArrayElement); + REDISMODULE_GET_API(CallReplyStringPtr); + REDISMODULE_GET_API(CreateStringFromCallReply); + REDISMODULE_GET_API(CreateString); + REDISMODULE_GET_API(CreateStringFromLongLong); + REDISMODULE_GET_API(CreateStringFromULongLong); + REDISMODULE_GET_API(CreateStringFromDouble); + REDISMODULE_GET_API(CreateStringFromLongDouble); + REDISMODULE_GET_API(CreateStringFromString); + REDISMODULE_GET_API(CreateStringFromStreamID); + REDISMODULE_GET_API(CreateStringPrintf); + REDISMODULE_GET_API(FreeString); + REDISMODULE_GET_API(StringPtrLen); + REDISMODULE_GET_API(AutoMemory); + REDISMODULE_GET_API(Replicate); + REDISMODULE_GET_API(ReplicateVerbatim); + REDISMODULE_GET_API(DeleteKey); + REDISMODULE_GET_API(UnlinkKey); + REDISMODULE_GET_API(StringSet); + REDISMODULE_GET_API(StringDMA); + REDISMODULE_GET_API(StringTruncate); + REDISMODULE_GET_API(GetExpire); + REDISMODULE_GET_API(SetExpire); + REDISMODULE_GET_API(GetAbsExpire); + REDISMODULE_GET_API(SetAbsExpire); + REDISMODULE_GET_API(ResetDataset); + REDISMODULE_GET_API(DbSize); + REDISMODULE_GET_API(RandomKey); + REDISMODULE_GET_API(ZsetAdd); + REDISMODULE_GET_API(ZsetIncrby); + REDISMODULE_GET_API(ZsetScore); + REDISMODULE_GET_API(ZsetRem); + REDISMODULE_GET_API(ZsetRangeStop); + REDISMODULE_GET_API(ZsetFirstInScoreRange); + REDISMODULE_GET_API(ZsetLastInScoreRange); + REDISMODULE_GET_API(ZsetFirstInLexRange); + REDISMODULE_GET_API(ZsetLastInLexRange); + REDISMODULE_GET_API(ZsetRangeCurrentElement); + REDISMODULE_GET_API(ZsetRangeNext); + REDISMODULE_GET_API(ZsetRangePrev); + REDISMODULE_GET_API(ZsetRangeEndReached); + REDISMODULE_GET_API(HashSet); + REDISMODULE_GET_API(HashGet); + REDISMODULE_GET_API(StreamAdd); + REDISMODULE_GET_API(StreamDelete); + REDISMODULE_GET_API(StreamIteratorStart); + REDISMODULE_GET_API(StreamIteratorStop); + REDISMODULE_GET_API(StreamIteratorNextID); + REDISMODULE_GET_API(StreamIteratorNextField); + REDISMODULE_GET_API(StreamIteratorDelete); + REDISMODULE_GET_API(StreamTrimByLength); + REDISMODULE_GET_API(StreamTrimByID); + REDISMODULE_GET_API(IsKeysPositionRequest); + REDISMODULE_GET_API(KeyAtPos); + REDISMODULE_GET_API(KeyAtPosWithFlags); + REDISMODULE_GET_API(IsChannelsPositionRequest); + REDISMODULE_GET_API(ChannelAtPosWithFlags); + REDISMODULE_GET_API(GetClientId); + REDISMODULE_GET_API(GetClientUserNameById); + REDISMODULE_GET_API(GetContextFlags); + REDISMODULE_GET_API(AvoidReplicaTraffic); + REDISMODULE_GET_API(PoolAlloc); + REDISMODULE_GET_API(CreateDataType); + REDISMODULE_GET_API(ModuleTypeSetValue); + REDISMODULE_GET_API(ModuleTypeReplaceValue); + REDISMODULE_GET_API(ModuleTypeGetType); + REDISMODULE_GET_API(ModuleTypeGetValue); + REDISMODULE_GET_API(IsIOError); + REDISMODULE_GET_API(SetModuleOptions); + REDISMODULE_GET_API(SignalModifiedKey); + REDISMODULE_GET_API(SaveUnsigned); + REDISMODULE_GET_API(LoadUnsigned); + REDISMODULE_GET_API(SaveSigned); + REDISMODULE_GET_API(LoadSigned); + REDISMODULE_GET_API(SaveString); + REDISMODULE_GET_API(SaveStringBuffer); + REDISMODULE_GET_API(LoadString); + REDISMODULE_GET_API(LoadStringBuffer); + REDISMODULE_GET_API(SaveDouble); + REDISMODULE_GET_API(LoadDouble); + REDISMODULE_GET_API(SaveFloat); + REDISMODULE_GET_API(LoadFloat); + REDISMODULE_GET_API(SaveLongDouble); + REDISMODULE_GET_API(LoadLongDouble); + REDISMODULE_GET_API(SaveDataTypeToString); + REDISMODULE_GET_API(LoadDataTypeFromString); + REDISMODULE_GET_API(LoadDataTypeFromStringEncver); + REDISMODULE_GET_API(EmitAOF); + REDISMODULE_GET_API(Log); + REDISMODULE_GET_API(LogIOError); + REDISMODULE_GET_API(_Assert); + REDISMODULE_GET_API(LatencyAddSample); + REDISMODULE_GET_API(StringAppendBuffer); + REDISMODULE_GET_API(TrimStringAllocation); + REDISMODULE_GET_API(RetainString); + REDISMODULE_GET_API(HoldString); + REDISMODULE_GET_API(StringCompare); + REDISMODULE_GET_API(GetContextFromIO); + REDISMODULE_GET_API(GetKeyNameFromIO); + REDISMODULE_GET_API(GetKeyNameFromModuleKey); + REDISMODULE_GET_API(GetDbIdFromModuleKey); + REDISMODULE_GET_API(GetDbIdFromIO); + REDISMODULE_GET_API(GetKeyNameFromOptCtx); + REDISMODULE_GET_API(GetToKeyNameFromOptCtx); + REDISMODULE_GET_API(GetDbIdFromOptCtx); + REDISMODULE_GET_API(GetToDbIdFromOptCtx); + REDISMODULE_GET_API(Milliseconds); + REDISMODULE_GET_API(MonotonicMicroseconds); + REDISMODULE_GET_API(Microseconds); + REDISMODULE_GET_API(CachedMicroseconds); + REDISMODULE_GET_API(DigestAddStringBuffer); + REDISMODULE_GET_API(DigestAddLongLong); + REDISMODULE_GET_API(DigestEndSequence); + REDISMODULE_GET_API(GetKeyNameFromDigest); + REDISMODULE_GET_API(GetDbIdFromDigest); + REDISMODULE_GET_API(CreateDict); + REDISMODULE_GET_API(FreeDict); + REDISMODULE_GET_API(DictSize); + REDISMODULE_GET_API(DictSetC); + REDISMODULE_GET_API(DictReplaceC); + REDISMODULE_GET_API(DictSet); + REDISMODULE_GET_API(DictReplace); + REDISMODULE_GET_API(DictGetC); + REDISMODULE_GET_API(DictGet); + REDISMODULE_GET_API(DictDelC); + REDISMODULE_GET_API(DictDel); + REDISMODULE_GET_API(DictIteratorStartC); + REDISMODULE_GET_API(DictIteratorStart); + REDISMODULE_GET_API(DictIteratorStop); + REDISMODULE_GET_API(DictIteratorReseekC); + REDISMODULE_GET_API(DictIteratorReseek); + REDISMODULE_GET_API(DictNextC); + REDISMODULE_GET_API(DictPrevC); + REDISMODULE_GET_API(DictNext); + REDISMODULE_GET_API(DictPrev); + REDISMODULE_GET_API(DictCompare); + REDISMODULE_GET_API(DictCompareC); + REDISMODULE_GET_API(RegisterInfoFunc); + REDISMODULE_GET_API(RegisterAuthCallback); + REDISMODULE_GET_API(InfoAddSection); + REDISMODULE_GET_API(InfoBeginDictField); + REDISMODULE_GET_API(InfoEndDictField); + REDISMODULE_GET_API(InfoAddFieldString); + REDISMODULE_GET_API(InfoAddFieldCString); + REDISMODULE_GET_API(InfoAddFieldDouble); + REDISMODULE_GET_API(InfoAddFieldLongLong); + REDISMODULE_GET_API(InfoAddFieldULongLong); + REDISMODULE_GET_API(GetServerInfo); + REDISMODULE_GET_API(FreeServerInfo); + REDISMODULE_GET_API(ServerInfoGetField); + REDISMODULE_GET_API(ServerInfoGetFieldC); + REDISMODULE_GET_API(ServerInfoGetFieldSigned); + REDISMODULE_GET_API(ServerInfoGetFieldUnsigned); + REDISMODULE_GET_API(ServerInfoGetFieldDouble); + REDISMODULE_GET_API(GetClientInfoById); + REDISMODULE_GET_API(GetClientNameById); + REDISMODULE_GET_API(SetClientNameById); + REDISMODULE_GET_API(PublishMessage); + REDISMODULE_GET_API(PublishMessageShard); + REDISMODULE_GET_API(SubscribeToServerEvent); + REDISMODULE_GET_API(SetLRU); + REDISMODULE_GET_API(GetLRU); + REDISMODULE_GET_API(SetLFU); + REDISMODULE_GET_API(GetLFU); + REDISMODULE_GET_API(BlockClientOnKeys); + REDISMODULE_GET_API(BlockClientOnKeysWithFlags); + REDISMODULE_GET_API(SignalKeyAsReady); + REDISMODULE_GET_API(GetBlockedClientReadyKey); + REDISMODULE_GET_API(ScanCursorCreate); + REDISMODULE_GET_API(ScanCursorRestart); + REDISMODULE_GET_API(ScanCursorDestroy); + REDISMODULE_GET_API(Scan); + REDISMODULE_GET_API(ScanKey); + REDISMODULE_GET_API(GetContextFlagsAll); + REDISMODULE_GET_API(GetModuleOptionsAll); + REDISMODULE_GET_API(GetKeyspaceNotificationFlagsAll); + REDISMODULE_GET_API(IsSubEventSupported); + REDISMODULE_GET_API(GetServerVersion); + REDISMODULE_GET_API(GetTypeMethodVersion); + REDISMODULE_GET_API(Yield); + REDISMODULE_GET_API(GetThreadSafeContext); + REDISMODULE_GET_API(GetDetachedThreadSafeContext); + REDISMODULE_GET_API(FreeThreadSafeContext); + REDISMODULE_GET_API(ThreadSafeContextLock); + REDISMODULE_GET_API(ThreadSafeContextTryLock); + REDISMODULE_GET_API(ThreadSafeContextUnlock); + REDISMODULE_GET_API(BlockClient); + REDISMODULE_GET_API(BlockClientGetPrivateData); + REDISMODULE_GET_API(BlockClientSetPrivateData); + REDISMODULE_GET_API(BlockClientOnAuth); + REDISMODULE_GET_API(UnblockClient); + REDISMODULE_GET_API(IsBlockedReplyRequest); + REDISMODULE_GET_API(IsBlockedTimeoutRequest); + REDISMODULE_GET_API(GetBlockedClientPrivateData); + REDISMODULE_GET_API(GetBlockedClientHandle); + REDISMODULE_GET_API(AbortBlock); + REDISMODULE_GET_API(BlockedClientMeasureTimeStart); + REDISMODULE_GET_API(BlockedClientMeasureTimeEnd); + REDISMODULE_GET_API(SetDisconnectCallback); + REDISMODULE_GET_API(SubscribeToKeyspaceEvents); + REDISMODULE_GET_API(AddPostNotificationJob); + REDISMODULE_GET_API(NotifyKeyspaceEvent); + REDISMODULE_GET_API(GetNotifyKeyspaceEvents); + REDISMODULE_GET_API(BlockedClientDisconnected); + REDISMODULE_GET_API(RegisterClusterMessageReceiver); + REDISMODULE_GET_API(SendClusterMessage); + REDISMODULE_GET_API(GetClusterNodeInfo); + REDISMODULE_GET_API(GetClusterNodesList); + REDISMODULE_GET_API(FreeClusterNodesList); + REDISMODULE_GET_API(CreateTimer); + REDISMODULE_GET_API(StopTimer); + REDISMODULE_GET_API(GetTimerInfo); + REDISMODULE_GET_API(GetMyClusterID); + REDISMODULE_GET_API(GetClusterSize); + REDISMODULE_GET_API(GetRandomBytes); + REDISMODULE_GET_API(GetRandomHexChars); + REDISMODULE_GET_API(SetClusterFlags); + REDISMODULE_GET_API(ClusterKeySlot); + REDISMODULE_GET_API(ClusterCanonicalKeyNameInSlot); + REDISMODULE_GET_API(ExportSharedAPI); + REDISMODULE_GET_API(GetSharedAPI); + REDISMODULE_GET_API(RegisterCommandFilter); + REDISMODULE_GET_API(UnregisterCommandFilter); + REDISMODULE_GET_API(CommandFilterArgsCount); + REDISMODULE_GET_API(CommandFilterArgGet); + REDISMODULE_GET_API(CommandFilterArgInsert); + REDISMODULE_GET_API(CommandFilterArgReplace); + REDISMODULE_GET_API(CommandFilterArgDelete); + REDISMODULE_GET_API(CommandFilterGetClientId); + REDISMODULE_GET_API(Fork); + REDISMODULE_GET_API(SendChildHeartbeat); + REDISMODULE_GET_API(ExitFromChild); + REDISMODULE_GET_API(KillForkChild); + REDISMODULE_GET_API(GetUsedMemoryRatio); + REDISMODULE_GET_API(MallocSize); + REDISMODULE_GET_API(MallocUsableSize); + REDISMODULE_GET_API(MallocSizeString); + REDISMODULE_GET_API(MallocSizeDict); + REDISMODULE_GET_API(CreateModuleUser); + REDISMODULE_GET_API(FreeModuleUser); + REDISMODULE_GET_API(SetContextUser); + REDISMODULE_GET_API(SetModuleUserACL); + REDISMODULE_GET_API(SetModuleUserACLString); + REDISMODULE_GET_API(GetModuleUserACLString); + REDISMODULE_GET_API(GetCurrentUserName); + REDISMODULE_GET_API(GetModuleUserFromUserName); + REDISMODULE_GET_API(ACLCheckCommandPermissions); + REDISMODULE_GET_API(ACLCheckKeyPermissions); + REDISMODULE_GET_API(ACLCheckChannelPermissions); + REDISMODULE_GET_API(ACLAddLogEntry); + REDISMODULE_GET_API(ACLAddLogEntryByUserName); + REDISMODULE_GET_API(DeauthenticateAndCloseClient); + REDISMODULE_GET_API(AuthenticateClientWithACLUser); + REDISMODULE_GET_API(AuthenticateClientWithUser); + REDISMODULE_GET_API(RedactClientCommandArgument); + REDISMODULE_GET_API(GetClientCertificate); + REDISMODULE_GET_API(GetCommandKeys); + REDISMODULE_GET_API(GetCommandKeysWithFlags); + REDISMODULE_GET_API(GetCurrentCommandName); + REDISMODULE_GET_API(RegisterDefragFunc); + REDISMODULE_GET_API(RegisterDefragCallbacks); + REDISMODULE_GET_API(DefragAlloc); + REDISMODULE_GET_API(DefragAllocRaw); + REDISMODULE_GET_API(DefragFreeRaw); + REDISMODULE_GET_API(DefragRedisModuleString); + REDISMODULE_GET_API(DefragShouldStop); + REDISMODULE_GET_API(DefragCursorSet); + REDISMODULE_GET_API(DefragCursorGet); + REDISMODULE_GET_API(GetKeyNameFromDefragCtx); + REDISMODULE_GET_API(GetDbIdFromDefragCtx); + REDISMODULE_GET_API(EventLoopAdd); + REDISMODULE_GET_API(EventLoopDel); + REDISMODULE_GET_API(EventLoopAddOneShot); + REDISMODULE_GET_API(RegisterBoolConfig); + REDISMODULE_GET_API(RegisterNumericConfig); + REDISMODULE_GET_API(RegisterStringConfig); + REDISMODULE_GET_API(RegisterEnumConfig); + REDISMODULE_GET_API(LoadConfigs); + REDISMODULE_GET_API(RdbStreamCreateFromFile); + REDISMODULE_GET_API(RdbStreamFree); + REDISMODULE_GET_API(RdbLoad); + REDISMODULE_GET_API(RdbSave); + + if (RedisModule_IsModuleNameBusy && RedisModule_IsModuleNameBusy(name)) return REDISMODULE_ERR; + RedisModule_SetModuleAttribs(ctx,name,ver,apiver); + return REDISMODULE_OK; +} + +#define RedisModule_Assert(_e) ((_e)?(void)0 : (RedisModule__Assert(#_e,__FILE__,__LINE__),exit(1))) + +#define RMAPI_FUNC_SUPPORTED(func) (func != NULL) + +#endif /* REDISMODULE_CORE */ +#endif /* REDISMODULE_H */ diff --git a/modules/vector-sets/test.py b/modules/vector-sets/test.py new file mode 100755 index 000000000..2e38ba013 --- /dev/null +++ b/modules/vector-sets/test.py @@ -0,0 +1,225 @@ +#!/usr/bin/env python3 +# +# Vector set tests. +# A Redis instance should be running in the default port. +# Copyright(C) 2024-2025 Salvatore Sanfilippo. +# All Rights Reserved. + +#!/usr/bin/env python3 +import redis +import random +import struct +import math +import time +import sys +import os +import importlib +import inspect +from typing import List, Tuple, Optional +from dataclasses import dataclass + +def colored(text: str, color: str) -> str: + colors = { + 'red': '\033[91m', + 'green': '\033[92m' + } + reset = '\033[0m' + return f"{colors.get(color, '')}{text}{reset}" + +@dataclass +class VectorData: + vectors: List[List[float]] + names: List[str] + + def find_k_nearest(self, query_vector: List[float], k: int) -> List[Tuple[str, float]]: + """Find k-nearest neighbors using the same scoring as Redis VSIM WITHSCORES.""" + similarities = [] + query_norm = math.sqrt(sum(x*x for x in query_vector)) + if query_norm == 0: + return [] + + for i, vec in enumerate(self.vectors): + vec_norm = math.sqrt(sum(x*x for x in vec)) + if vec_norm == 0: + continue + + dot_product = sum(a*b for a,b in zip(query_vector, vec)) + cosine_sim = dot_product / (query_norm * vec_norm) + distance = 1.0 - cosine_sim + redis_similarity = 1.0 - (distance/2.0) + similarities.append((self.names[i], redis_similarity)) + + similarities.sort(key=lambda x: x[1], reverse=True) + return similarities[:k] + +def generate_random_vector(dim: int) -> List[float]: + """Generate a random normalized vector.""" + vec = [random.gauss(0, 1) for _ in range(dim)] + norm = math.sqrt(sum(x*x for x in vec)) + return [x/norm for x in vec] + +def fill_redis_with_vectors(r: redis.Redis, key: str, count: int, dim: int, + with_reduce: Optional[int] = None) -> VectorData: + """Fill Redis with random vectors and return a VectorData object for verification.""" + vectors = [] + names = [] + + r.delete(key) + for i in range(count): + vec = generate_random_vector(dim) + name = f"{key}:item:{i}" + vectors.append(vec) + names.append(name) + + vec_bytes = struct.pack(f'{dim}f', *vec) + args = [key] + if with_reduce: + args.extend(['REDUCE', with_reduce]) + args.extend(['FP32', vec_bytes, name]) + r.execute_command('VADD', *args) + + return VectorData(vectors=vectors, names=names) + +class TestCase: + def __init__(self): + self.error_msg = None + self.error_details = None + self.test_key = f"test:{self.__class__.__name__.lower()}" + # Primary Redis instance (default port) + self.redis = redis.Redis() + # Replica Redis instance (port 6380) + self.replica = redis.Redis(port=6380) + # Replication status + self.replication_setup = False + + def setup(self): + self.redis.delete(self.test_key) + + def teardown(self): + self.redis.delete(self.test_key) + + def setup_replication(self) -> bool: + """ + Setup replication between primary and replica Redis instances. + Returns True if replication is successfully established, False otherwise. + """ + # Configure replica to replicate from primary + self.replica.execute_command('REPLICAOF', '127.0.0.1', 6379) + + # Wait for replication to be established + max_attempts = 10 + for attempt in range(max_attempts): + # Check replication info + repl_info = self.replica.info('replication') + + # Check if replication is established + if (repl_info.get('role') == 'slave' and + repl_info.get('master_host') == '127.0.0.1' and + repl_info.get('master_port') == 6379 and + repl_info.get('master_link_status') == 'up'): + + self.replication_setup = True + return True + + # Wait before next attempt + time.sleep(0.5) + + # If we get here, replication wasn't established + self.error_msg = "Failed to establish replication between primary and replica" + return False + + def test(self): + raise NotImplementedError("Subclasses must implement test method") + + def run(self): + try: + self.setup() + self.test() + return True + except AssertionError as e: + self.error_msg = str(e) + import traceback + self.error_details = traceback.format_exc() + return False + except Exception as e: + self.error_msg = f"Unexpected error: {str(e)}" + import traceback + self.error_details = traceback.format_exc() + return False + finally: + self.teardown() + + def getname(self): + """Each test class should override this to provide its name""" + return self.__class__.__name__ + + def estimated_runtime(self): + """"Each test class should override this if it takes a significant amount of time to run. Default is 100ms""" + return 0.1 + +def find_test_classes(): + test_classes = [] + tests_dir = 'tests' + + if not os.path.exists(tests_dir): + return [] + + for file in os.listdir(tests_dir): + if file.endswith('.py'): + module_name = f"tests.{file[:-3]}" + try: + module = importlib.import_module(module_name) + for name, obj in inspect.getmembers(module): + if inspect.isclass(obj) and obj.__name__ != 'TestCase' and hasattr(obj, 'test'): + test_classes.append(obj()) + except Exception as e: + print(f"Error loading {file}: {e}") + + return test_classes + +def run_tests(): + print("================================================\n"+ + "Make sure to have Redis running in the localhost\n"+ + "with --enable-debug-command yes\n"+ + "Both primary (6379) and replica (6380) instances\n"+ + "================================================\n") + + tests = find_test_classes() + if not tests: + print("No tests found!") + return + + # Sort tests by estimated runtime + tests.sort(key=lambda t: t.estimated_runtime()) + + passed = 0 + total = len(tests) + + for test in tests: + print(f"{test.getname()}: ", end="") + sys.stdout.flush() + + start_time = time.time() + success = test.run() + duration = time.time() - start_time + + if success: + print(colored("OK", "green"), f"({duration:.2f}s)") + passed += 1 + else: + print(colored("ERR", "red"), f"({duration:.2f}s)") + print(f"Error: {test.error_msg}") + if test.error_details: + print("\nTraceback:") + print(test.error_details) + + print("\n" + "="*50) + print(f"\nTest Summary: {passed}/{total} tests passed") + + if passed == total: + print(colored("\nALL TESTS PASSED!", "green")) + else: + print(colored(f"\n{total-passed} TESTS FAILED!", "red")) + +if __name__ == "__main__": + run_tests() diff --git a/modules/vector-sets/tests/basic_commands.py b/modules/vector-sets/tests/basic_commands.py new file mode 100644 index 000000000..8481a3668 --- /dev/null +++ b/modules/vector-sets/tests/basic_commands.py @@ -0,0 +1,21 @@ +from test import TestCase, generate_random_vector +import struct + +class BasicCommands(TestCase): + def getname(self): + return "VADD, VDIM, VCARD basic usage" + + def test(self): + # Test VADD + vec = generate_random_vector(4) + vec_bytes = struct.pack('4f', *vec) + result = self.redis.execute_command('VADD', self.test_key, 'FP32', vec_bytes, f'{self.test_key}:item:1') + assert result == 1, "VADD should return 1 for first item" + + # Test VDIM + dim = self.redis.execute_command('VDIM', self.test_key) + assert dim == 4, f"VDIM should return 4, got {dim}" + + # Test VCARD + card = self.redis.execute_command('VCARD', self.test_key) + assert card == 1, f"VCARD should return 1, got {card}" diff --git a/modules/vector-sets/tests/basic_similarity.py b/modules/vector-sets/tests/basic_similarity.py new file mode 100644 index 000000000..11c3c9b17 --- /dev/null +++ b/modules/vector-sets/tests/basic_similarity.py @@ -0,0 +1,35 @@ +from test import TestCase + +class BasicSimilarity(TestCase): + def getname(self): + return "VSIM reported distance makes sense with 4D vectors" + + def test(self): + # Add two very similar vectors, one different + vec1 = [1, 0, 0, 0] + vec2 = [0.99, 0.01, 0, 0] + vec3 = [0.1, 1, -1, 0.5] + + # Add vectors using VALUES format + self.redis.execute_command('VADD', self.test_key, 'VALUES', 4, + *[str(x) for x in vec1], f'{self.test_key}:item:1') + self.redis.execute_command('VADD', self.test_key, 'VALUES', 4, + *[str(x) for x in vec2], f'{self.test_key}:item:2') + self.redis.execute_command('VADD', self.test_key, 'VALUES', 4, + *[str(x) for x in vec3], f'{self.test_key}:item:3') + + # Query similarity with vec1 + result = self.redis.execute_command('VSIM', self.test_key, 'VALUES', 4, + *[str(x) for x in vec1], 'WITHSCORES') + + # Convert results to dictionary + results_dict = {} + for i in range(0, len(result), 2): + key = result[i].decode() + score = float(result[i+1]) + results_dict[key] = score + + # Verify results + assert results_dict[f'{self.test_key}:item:1'] > 0.99, "Self-similarity should be very high" + assert results_dict[f'{self.test_key}:item:2'] > 0.99, "Similar vector should have high similarity" + assert results_dict[f'{self.test_key}:item:3'] < 0.8, "Not very similar vector should have low similarity" diff --git a/modules/vector-sets/tests/concurrent_vadd_cas_del_vsim.py b/modules/vector-sets/tests/concurrent_vadd_cas_del_vsim.py new file mode 100644 index 000000000..f4b3a1212 --- /dev/null +++ b/modules/vector-sets/tests/concurrent_vadd_cas_del_vsim.py @@ -0,0 +1,156 @@ +from test import TestCase, generate_random_vector +import threading +import time +import struct + +class ThreadingStressTest(TestCase): + def getname(self): + return "Concurrent VADD/DEL/VSIM operations stress test" + + def estimated_runtime(self): + return 10 # Test runs for 10 seconds + + def test(self): + # Constants - easy to modify if needed + NUM_VADD_THREADS = 10 + NUM_VSIM_THREADS = 1 + NUM_DEL_THREADS = 1 + TEST_DURATION = 10 # seconds + VECTOR_DIM = 100 + DEL_INTERVAL = 1 # seconds + + # Shared flags and state + stop_event = threading.Event() + error_list = [] + error_lock = threading.Lock() + + def log_error(thread_name, error): + with error_lock: + error_list.append(f"{thread_name}: {error}") + + def vadd_worker(thread_id): + """Thread function to perform VADD operations""" + thread_name = f"VADD-{thread_id}" + try: + vector_count = 0 + while not stop_event.is_set(): + try: + # Generate random vector + vec = generate_random_vector(VECTOR_DIM) + vec_bytes = struct.pack(f'{VECTOR_DIM}f', *vec) + + # Add vector with CAS option + self.redis.execute_command( + 'VADD', + self.test_key, + 'FP32', + vec_bytes, + f'{self.test_key}:item:{thread_id}:{vector_count}', + 'CAS' + ) + + vector_count += 1 + + # Small sleep to reduce CPU pressure + if vector_count % 10 == 0: + time.sleep(0.001) + except Exception as e: + log_error(thread_name, f"Error: {str(e)}") + time.sleep(0.1) # Slight backoff on error + except Exception as e: + log_error(thread_name, f"Thread error: {str(e)}") + + def del_worker(): + """Thread function that deletes the key periodically""" + thread_name = "DEL" + try: + del_count = 0 + while not stop_event.is_set(): + try: + # Sleep first, then delete + time.sleep(DEL_INTERVAL) + if stop_event.is_set(): + break + + self.redis.delete(self.test_key) + del_count += 1 + except Exception as e: + log_error(thread_name, f"Error: {str(e)}") + except Exception as e: + log_error(thread_name, f"Thread error: {str(e)}") + + def vsim_worker(thread_id): + """Thread function to perform VSIM operations""" + thread_name = f"VSIM-{thread_id}" + try: + search_count = 0 + while not stop_event.is_set(): + try: + # Generate query vector + query_vec = generate_random_vector(VECTOR_DIM) + query_str = [str(x) for x in query_vec] + + # Perform similarity search + args = ['VSIM', self.test_key, 'VALUES', VECTOR_DIM] + args.extend(query_str) + args.extend(['COUNT', 10]) + self.redis.execute_command(*args) + + search_count += 1 + + # Small sleep to reduce CPU pressure + if search_count % 10 == 0: + time.sleep(0.005) + except Exception as e: + # Don't log empty array errors, as they're expected when key doesn't exist + if "empty array" not in str(e).lower(): + log_error(thread_name, f"Error: {str(e)}") + time.sleep(0.1) # Slight backoff on error + except Exception as e: + log_error(thread_name, f"Thread error: {str(e)}") + + # Start all threads + threads = [] + + # VADD threads + for i in range(NUM_VADD_THREADS): + thread = threading.Thread(target=vadd_worker, args=(i,)) + thread.start() + threads.append(thread) + + # DEL threads + for _ in range(NUM_DEL_THREADS): + thread = threading.Thread(target=del_worker) + thread.start() + threads.append(thread) + + # VSIM threads + for i in range(NUM_VSIM_THREADS): + thread = threading.Thread(target=vsim_worker, args=(i,)) + thread.start() + threads.append(thread) + + # Let the test run for the specified duration + time.sleep(TEST_DURATION) + + # Signal all threads to stop + stop_event.set() + + # Wait for threads to finish + for thread in threads: + thread.join(timeout=2.0) + + # Check if Redis is still responsive + try: + ping_result = self.redis.ping() + assert ping_result, "Redis did not respond to PING after stress test" + except Exception as e: + assert False, f"Redis connection failed after stress test: {str(e)}" + + # Report any errors for diagnosis, but don't fail the test unless PING fails + if error_list: + error_count = len(error_list) + print(f"\nEncountered {error_count} errors during stress test.") + print("First 5 errors:") + for error in error_list[:5]: + print(f"- {error}") diff --git a/modules/vector-sets/tests/concurrent_vsim_and_del.py b/modules/vector-sets/tests/concurrent_vsim_and_del.py new file mode 100644 index 000000000..9bbf01116 --- /dev/null +++ b/modules/vector-sets/tests/concurrent_vsim_and_del.py @@ -0,0 +1,48 @@ +from test import TestCase, fill_redis_with_vectors, generate_random_vector +import threading, time + +class ConcurrentVSIMAndDEL(TestCase): + def getname(self): + return "Concurrent VSIM and DEL operations" + + def estimated_runtime(self): + return 2 + + def test(self): + # Fill the key with 5000 random vectors + dim = 128 + count = 5000 + fill_redis_with_vectors(self.redis, self.test_key, count, dim) + + # List to store results from threads + thread_results = [] + + def vsim_thread(): + """Thread function to perform VSIM operations until the key is deleted""" + while True: + query_vec = generate_random_vector(dim) + result = self.redis.execute_command('VSIM', self.test_key, 'VALUES', dim, + *[str(x) for x in query_vec], 'COUNT', 10) + if not result: + # Empty array detected, key is deleted + thread_results.append(True) + break + + # Start multiple threads to perform VSIM operations + threads = [] + for _ in range(4): # Start 4 threads + t = threading.Thread(target=vsim_thread) + t.start() + threads.append(t) + + # Delete the key while threads are still running + time.sleep(1) + self.redis.delete(self.test_key) + + # Wait for all threads to finish (they will exit once they detect the key is deleted) + for t in threads: + t.join() + + # Verify that all threads detected an empty array or error + assert len(thread_results) == len(threads), "Not all threads detected the key deletion" + assert all(thread_results), "Some threads did not detect an empty array or error after DEL" diff --git a/modules/vector-sets/tests/debug_digest.py b/modules/vector-sets/tests/debug_digest.py new file mode 100644 index 000000000..78f06d8ef --- /dev/null +++ b/modules/vector-sets/tests/debug_digest.py @@ -0,0 +1,39 @@ +from test import TestCase, generate_random_vector +import struct + +class DebugDigestTest(TestCase): + def getname(self): + return "[regression] DEBUG DIGEST-VALUE with attributes" + + def test(self): + # Generate random vectors + vec1 = generate_random_vector(4) + vec2 = generate_random_vector(4) + vec_bytes1 = struct.pack('4f', *vec1) + vec_bytes2 = struct.pack('4f', *vec2) + + # Add vectors to the key, one with attribute, one without + self.redis.execute_command('VADD', self.test_key, 'FP32', vec_bytes1, f'{self.test_key}:item:1') + self.redis.execute_command('VADD', self.test_key, 'FP32', vec_bytes2, f'{self.test_key}:item:2', 'SETATTR', '{"color":"red"}') + + # Call DEBUG DIGEST-VALUE on the key + try: + digest1 = self.redis.execute_command('DEBUG', 'DIGEST-VALUE', self.test_key) + assert digest1 is not None, "DEBUG DIGEST-VALUE should return a value" + + # Change attribute and verify digest changes + self.redis.execute_command('VSETATTR', self.test_key, f'{self.test_key}:item:2', '{"color":"blue"}') + + digest2 = self.redis.execute_command('DEBUG', 'DIGEST-VALUE', self.test_key) + assert digest2 is not None, "DEBUG DIGEST-VALUE should return a value after attribute change" + assert digest1 != digest2, "Digest should change when an attribute is modified" + + # Remove attribute and verify digest changes again + self.redis.execute_command('VSETATTR', self.test_key, f'{self.test_key}:item:2', '') + + digest3 = self.redis.execute_command('DEBUG', 'DIGEST-VALUE', self.test_key) + assert digest3 is not None, "DEBUG DIGEST-VALUE should return a value after attribute removal" + assert digest2 != digest3, "Digest should change when an attribute is removed" + + except Exception as e: + raise AssertionError(f"DEBUG DIGEST-VALUE command failed: {str(e)}") diff --git a/modules/vector-sets/tests/deletion.py b/modules/vector-sets/tests/deletion.py new file mode 100644 index 000000000..cb919591b --- /dev/null +++ b/modules/vector-sets/tests/deletion.py @@ -0,0 +1,173 @@ +from test import TestCase, fill_redis_with_vectors, generate_random_vector +import random + +""" +A note about this test: +It was experimentally tried to modify hnsw.c in order to +avoid calling hnsw_reconnect_nodes(). In this case, the test +fails very often with EF set to 250, while it hardly +fails at all with the same parameters if hnsw_reconnect_nodes() +is called. + +Note that for the nature of the test (it is very strict) it can +still fail from time to time, without this signaling any +actual bug. +""" + +class VREM(TestCase): + def getname(self): + return "Deletion and graph state after deletion" + + def estimated_runtime(self): + return 2.0 + + def format_neighbors_with_scores(self, links_result, old_links=None, items_to_remove=None): + """Format neighbors with their similarity scores and status indicators""" + if not links_result: + return "No neighbors" + + output = [] + for level, neighbors in enumerate(links_result): + level_num = len(links_result) - level - 1 + output.append(f"Level {level_num}:") + + # Get neighbors and scores + neighbors_with_scores = [] + for i in range(0, len(neighbors), 2): + neighbor = neighbors[i].decode() if isinstance(neighbors[i], bytes) else neighbors[i] + score = float(neighbors[i+1]) if i+1 < len(neighbors) else None + status = "" + + # For old links, mark deleted ones + if items_to_remove and neighbor in items_to_remove: + status = " [lost]" + # For new links, mark newly added ones + elif old_links is not None: + # Check if this neighbor was in the old links at this level + was_present = False + if old_links and level < len(old_links): + old_neighbors = [n.decode() if isinstance(n, bytes) else n + for n in old_links[level]] + was_present = neighbor in old_neighbors + if not was_present: + status = " [gained]" + + if score is not None: + neighbors_with_scores.append(f"{len(neighbors_with_scores)+1}. {neighbor} ({score:.6f}){status}") + else: + neighbors_with_scores.append(f"{len(neighbors_with_scores)+1}. {neighbor}{status}") + + output.extend([" " + n for n in neighbors_with_scores]) + return "\n".join(output) + + def test(self): + # 1. Fill server with random elements + dim = 128 + count = 5000 + data = fill_redis_with_vectors(self.redis, self.test_key, count, dim) + + # 2. Do VSIM to get 200 items + query_vec = generate_random_vector(dim) + results = self.redis.execute_command('VSIM', self.test_key, 'VALUES', dim, + *[str(x) for x in query_vec], + 'COUNT', 200, 'WITHSCORES') + + # Convert results to list of (item, score) pairs, sorted by score + items = [] + for i in range(0, len(results), 2): + item = results[i].decode() + score = float(results[i+1]) + items.append((item, score)) + items.sort(key=lambda x: x[1], reverse=True) # Sort by similarity + + # Store the graph structure for all items before deletion + neighbors_before = {} + for item, _ in items: + links = self.redis.execute_command('VLINKS', self.test_key, item, 'WITHSCORES') + if links: # Some items might not have links + neighbors_before[item] = links + + # 3. Remove 100 random items + items_to_remove = set(item for item, _ in random.sample(items, 100)) + # Keep track of top 10 non-removed items + top_remaining = [] + for item, score in items: + if item not in items_to_remove: + top_remaining.append((item, score)) + if len(top_remaining) == 10: + break + + # Remove the items + for item in items_to_remove: + result = self.redis.execute_command('VREM', self.test_key, item) + assert result == 1, f"VREM failed to remove {item}" + + # 4. Do VSIM again with same vector + new_results = self.redis.execute_command('VSIM', self.test_key, 'VALUES', dim, + *[str(x) for x in query_vec], + 'COUNT', 200, 'WITHSCORES', + 'EF', 500) + + # Convert new results to dict of item -> score + new_scores = {} + for i in range(0, len(new_results), 2): + item = new_results[i].decode() + score = float(new_results[i+1]) + new_scores[item] = score + + failure = False + failed_item = None + failed_reason = None + # 5. Verify all top 10 non-removed items are still found with similar scores + for item, old_score in top_remaining: + if item not in new_scores: + failure = True + failed_item = item + failed_reason = "missing" + break + new_score = new_scores[item] + if abs(new_score - old_score) >= 0.01: + failure = True + failed_item = item + failed_reason = f"score changed: {old_score:.6f} -> {new_score:.6f}" + break + + if failure: + print("\nTest failed!") + print(f"Problem with item: {failed_item} ({failed_reason})") + + print("\nOriginal neighbors (with similarity scores):") + if failed_item in neighbors_before: + print(self.format_neighbors_with_scores( + neighbors_before[failed_item], + items_to_remove=items_to_remove)) + else: + print("No neighbors found in original graph") + + print("\nCurrent neighbors (with similarity scores):") + current_links = self.redis.execute_command('VLINKS', self.test_key, + failed_item, 'WITHSCORES') + if current_links: + print(self.format_neighbors_with_scores( + current_links, + old_links=neighbors_before.get(failed_item))) + else: + print("No neighbors in current graph") + + print("\nOriginal results (top 20):") + for item, score in items[:20]: + deleted = "[deleted]" if item in items_to_remove else "" + print(f"{item}: {score:.6f} {deleted}") + + print("\nNew results after removal (top 20):") + new_items = [] + for i in range(0, len(new_results), 2): + item = new_results[i].decode() + score = float(new_results[i+1]) + new_items.append((item, score)) + new_items.sort(key=lambda x: x[1], reverse=True) + for item, score in new_items[:20]: + print(f"{item}: {score:.6f}") + + raise AssertionError(f"Test failed: Problem with item {failed_item} ({failed_reason}). *** IMPORTANT *** This test may fail from time to time without indicating that there is a bug. However normally it should pass. The fact is that it's a quite extreme test where we destroy 50% of nodes of top results and still expect perfect recall, with vectors that are very hostile because of the distribution used.") + diff --git a/modules/vector-sets/tests/dimension_validation.py b/modules/vector-sets/tests/dimension_validation.py new file mode 100644 index 000000000..f0811529a --- /dev/null +++ b/modules/vector-sets/tests/dimension_validation.py @@ -0,0 +1,67 @@ +from test import TestCase, generate_random_vector +import struct +import redis.exceptions + +class DimensionValidation(TestCase): + def getname(self): + return "[regression] Dimension Validation with Projection" + + def estimated_runtime(self): + return 0.5 + + def test(self): + # Test scenario 1: Create a set with projection + original_dim = 100 + reduced_dim = 50 + + # Create the initial vector and set with projection + vec1 = generate_random_vector(original_dim) + vec1_bytes = struct.pack(f'{original_dim}f', *vec1) + + # Add first vector with projection + result = self.redis.execute_command('VADD', self.test_key, + 'REDUCE', reduced_dim, + 'FP32', vec1_bytes, f'{self.test_key}:item:1') + assert result == 1, "First VADD with REDUCE should return 1" + + # Check VINFO returns the correct projection information + info = self.redis.execute_command('VINFO', self.test_key) + info_map = {k.decode('utf-8'): v for k, v in zip(info[::2], info[1::2])} + assert 'vector-dim' in info_map, "VINFO should contain vector-dim" + assert info_map['vector-dim'] == reduced_dim, f"Expected reduced dimension {reduced_dim}, got {info['vector-dim']}" + assert 'projection-input-dim' in info_map, "VINFO should contain projection-input-dim" + assert info_map['projection-input-dim'] == original_dim, f"Expected original dimension {original_dim}, got {info['projection-input-dim']}" + + # Test scenario 2: Try adding a mismatched vector - should fail + wrong_dim = 80 + wrong_vec = generate_random_vector(wrong_dim) + wrong_vec_bytes = struct.pack(f'{wrong_dim}f', *wrong_vec) + + # This should fail with dimension mismatch error + try: + self.redis.execute_command('VADD', self.test_key, + 'REDUCE', reduced_dim, + 'FP32', wrong_vec_bytes, f'{self.test_key}:item:2') + assert False, "VADD with wrong dimension should fail" + except redis.exceptions.ResponseError as e: + assert "Input dimension mismatch for projection" in str(e), f"Expected dimension mismatch error, got: {e}" + + # Test scenario 3: Add a correctly-sized vector + vec2 = generate_random_vector(original_dim) + vec2_bytes = struct.pack(f'{original_dim}f', *vec2) + + # This should succeed + result = self.redis.execute_command('VADD', self.test_key, + 'REDUCE', reduced_dim, + 'FP32', vec2_bytes, f'{self.test_key}:item:3') + assert result == 1, "VADD with correct dimensions should succeed" + + # Check VSIM also validates input dimensions + wrong_query = generate_random_vector(wrong_dim) + try: + self.redis.execute_command('VSIM', self.test_key, + 'VALUES', wrong_dim, *[str(x) for x in wrong_query], + 'COUNT', 10) + assert False, "VSIM with wrong dimension should fail" + except redis.exceptions.ResponseError as e: + assert "Input dimension mismatch for projection" in str(e), f"Expected dimension mismatch error in VSIM, got: {e}" diff --git a/modules/vector-sets/tests/evict_empty.py b/modules/vector-sets/tests/evict_empty.py new file mode 100644 index 000000000..6c78c825d --- /dev/null +++ b/modules/vector-sets/tests/evict_empty.py @@ -0,0 +1,27 @@ +from test import TestCase, generate_random_vector +import struct + +class VREM_LastItemDeletesKey(TestCase): + def getname(self): + return "VREM last item deletes key" + + def test(self): + # Generate a random vector + vec = generate_random_vector(4) + vec_bytes = struct.pack('4f', *vec) + + # Add the vector to the key + result = self.redis.execute_command('VADD', self.test_key, 'FP32', vec_bytes, f'{self.test_key}:item:1') + assert result == 1, "VADD should return 1 for first item" + + # Verify the key exists + exists = self.redis.exists(self.test_key) + assert exists == 1, "Key should exist after VADD" + + # Remove the item + result = self.redis.execute_command('VREM', self.test_key, f'{self.test_key}:item:1') + assert result == 1, "VREM should return 1 for successful removal" + + # Verify the key no longer exists + exists = self.redis.exists(self.test_key) + assert exists == 0, "Key should no longer exist after VREM of last item" diff --git a/modules/vector-sets/tests/filter_expr.py b/modules/vector-sets/tests/filter_expr.py new file mode 100644 index 000000000..13abf7b65 --- /dev/null +++ b/modules/vector-sets/tests/filter_expr.py @@ -0,0 +1,177 @@ +from test import TestCase + +class VSIMFilterExpressions(TestCase): + def getname(self): + return "VSIM FILTER expressions basic functionality" + + def test(self): + # Create a small set of vectors with different attributes + + # Basic vectors for testing - all orthogonal for clear results + vec1 = [1, 0, 0, 0] + vec2 = [0, 1, 0, 0] + vec3 = [0, 0, 1, 0] + vec4 = [0, 0, 0, 1] + vec5 = [0.5, 0.5, 0, 0] + + # Add vectors with various attributes + self.redis.execute_command('VADD', self.test_key, 'VALUES', 4, + *[str(x) for x in vec1], f'{self.test_key}:item:1') + self.redis.execute_command('VSETATTR', self.test_key, f'{self.test_key}:item:1', + '{"age": 25, "name": "Alice", "active": true, "scores": [85, 90, 95], "city": "New York"}') + + self.redis.execute_command('VADD', self.test_key, 'VALUES', 4, + *[str(x) for x in vec2], f'{self.test_key}:item:2') + self.redis.execute_command('VSETATTR', self.test_key, f'{self.test_key}:item:2', + '{"age": 30, "name": "Bob", "active": false, "scores": [70, 75, 80], "city": "Boston"}') + + self.redis.execute_command('VADD', self.test_key, 'VALUES', 4, + *[str(x) for x in vec3], f'{self.test_key}:item:3') + self.redis.execute_command('VSETATTR', self.test_key, f'{self.test_key}:item:3', + '{"age": 35, "name": "Charlie", "scores": [60, 65, 70], "city": "Seattle"}') + + self.redis.execute_command('VADD', self.test_key, 'VALUES', 4, + *[str(x) for x in vec4], f'{self.test_key}:item:4') + # Item 4 has no attribute at all + + self.redis.execute_command('VADD', self.test_key, 'VALUES', 4, + *[str(x) for x in vec5], f'{self.test_key}:item:5') + self.redis.execute_command('VSETATTR', self.test_key, f'{self.test_key}:item:5', + 'invalid json') # Intentionally malformed JSON + + # Test 1: Basic equality with numbers + result = self.redis.execute_command('VSIM', self.test_key, 'VALUES', 4, + *[str(x) for x in vec1], + 'FILTER', '.age == 25') + assert len(result) == 1, "Expected 1 result for age == 25" + assert result[0].decode() == f'{self.test_key}:item:1', "Expected item:1 for age == 25" + + # Test 2: Greater than + result = self.redis.execute_command('VSIM', self.test_key, 'VALUES', 4, + *[str(x) for x in vec1], + 'FILTER', '.age > 25') + assert len(result) == 2, "Expected 2 results for age > 25" + + # Test 3: Less than or equal + result = self.redis.execute_command('VSIM', self.test_key, 'VALUES', 4, + *[str(x) for x in vec1], + 'FILTER', '.age <= 30') + assert len(result) == 2, "Expected 2 results for age <= 30" + + # Test 4: String equality + result = self.redis.execute_command('VSIM', self.test_key, 'VALUES', 4, + *[str(x) for x in vec1], + 'FILTER', '.name == "Alice"') + assert len(result) == 1, "Expected 1 result for name == Alice" + assert result[0].decode() == f'{self.test_key}:item:1', "Expected item:1 for name == Alice" + + # Test 5: String inequality + result = self.redis.execute_command('VSIM', self.test_key, 'VALUES', 4, + *[str(x) for x in vec1], + 'FILTER', '.name != "Alice"') + assert len(result) == 2, "Expected 2 results for name != Alice" + + # Test 6: Boolean value + result = self.redis.execute_command('VSIM', self.test_key, 'VALUES', 4, + *[str(x) for x in vec1], + 'FILTER', '.active') + assert len(result) == 1, "Expected 1 result for .active being true" + + # Test 7: Logical AND + result = self.redis.execute_command('VSIM', self.test_key, 'VALUES', 4, + *[str(x) for x in vec1], + 'FILTER', '.age > 20 and .age < 30') + assert len(result) == 1, "Expected 1 result for 20 < age < 30" + assert result[0].decode() == f'{self.test_key}:item:1', "Expected item:1 for 20 < age < 30" + + # Test 8: Logical OR + result = self.redis.execute_command('VSIM', self.test_key, 'VALUES', 4, + *[str(x) for x in vec1], + 'FILTER', '.age < 30 or .age > 35') + assert len(result) == 1, "Expected 1 result for age < 30 or age > 35" + + # Test 9: Logical NOT + result = self.redis.execute_command('VSIM', self.test_key, 'VALUES', 4, + *[str(x) for x in vec1], + 'FILTER', '!(.age == 25)') + assert len(result) == 2, "Expected 2 results for NOT(age == 25)" + + # Test 10: The "in" operator with array + result = self.redis.execute_command('VSIM', self.test_key, 'VALUES', 4, + *[str(x) for x in vec1], + 'FILTER', '.age in [25, 35]') + assert len(result) == 2, "Expected 2 results for age in [25, 35]" + + # Test 11: The "in" operator with strings in array + result = self.redis.execute_command('VSIM', self.test_key, 'VALUES', 4, + *[str(x) for x in vec1], + 'FILTER', '.name in ["Alice", "David"]') + assert len(result) == 1, "Expected 1 result for name in [Alice, David]" + + # Test 12: Arithmetic operations - addition + result = self.redis.execute_command('VSIM', self.test_key, 'VALUES', 4, + *[str(x) for x in vec1], + 'FILTER', '.age + 10 > 40') + assert len(result) == 1, "Expected 1 result for age + 10 > 40" + + # Test 13: Arithmetic operations - multiplication + result = self.redis.execute_command('VSIM', self.test_key, 'VALUES', 4, + *[str(x) for x in vec1], + 'FILTER', '.age * 2 > 60') + assert len(result) == 1, "Expected 1 result for age * 2 > 60" + + # Test 14: Arithmetic operations - division + result = self.redis.execute_command('VSIM', self.test_key, 'VALUES', 4, + *[str(x) for x in vec1], + 'FILTER', '.age / 5 == 5') + assert len(result) == 1, "Expected 1 result for age / 5 == 5" + + # Test 15: Arithmetic operations - modulo + result = self.redis.execute_command('VSIM', self.test_key, 'VALUES', 4, + *[str(x) for x in vec1], + 'FILTER', '.age % 2 == 0') + assert len(result) == 1, "Expected 1 result for age % 2 == 0" + + # Test 16: Power operator + result = self.redis.execute_command('VSIM', self.test_key, 'VALUES', 4, + *[str(x) for x in vec1], + 'FILTER', '.age ** 2 > 900') + assert len(result) == 1, "Expected 1 result for age^2 > 900" + + # Test 17: Missing attribute (should exclude items missing that attribute) + result = self.redis.execute_command('VSIM', self.test_key, 'VALUES', 4, + *[str(x) for x in vec1], + 'FILTER', '.missing_field == "value"') + assert len(result) == 0, "Expected 0 results for missing_field == value" + + # Test 18: No attribute set at all + result = self.redis.execute_command('VSIM', self.test_key, 'VALUES', 4, + *[str(x) for x in vec1], + 'FILTER', '.any_field') + assert f'{self.test_key}:item:4' not in [item.decode() for item in result], "Item with no attribute should be excluded" + + # Test 19: Malformed JSON + result = self.redis.execute_command('VSIM', self.test_key, 'VALUES', 4, + *[str(x) for x in vec1], + 'FILTER', '.any_field') + assert f'{self.test_key}:item:5' not in [item.decode() for item in result], "Item with malformed JSON should be excluded" + + # Test 20: Complex expression combining multiple operators + result = self.redis.execute_command('VSIM', self.test_key, 'VALUES', 4, + *[str(x) for x in vec1], + 'FILTER', '(.age > 20 and .age < 40) and (.city == "Boston" or .city == "New York")') + assert len(result) == 2, "Expected 2 results for the complex expression" + expected_items = [f'{self.test_key}:item:1', f'{self.test_key}:item:2'] + assert set([item.decode() for item in result]) == set(expected_items), "Expected item:1 and item:2 for the complex expression" + + # Test 21: Parentheses to control operator precedence + result = self.redis.execute_command('VSIM', self.test_key, 'VALUES', 4, + *[str(x) for x in vec1], + 'FILTER', '.age > (20 + 10)') + assert len(result) == 1, "Expected 1 result for age > (20 + 10)" + + # Test 22: Array access (arrays evaluate to true) + result = self.redis.execute_command('VSIM', self.test_key, 'VALUES', 4, + *[str(x) for x in vec1], + 'FILTER', '.scores') + assert len(result) == 3, "Expected 3 results for .scores (arrays evaluate to true)" diff --git a/modules/vector-sets/tests/filter_int.py b/modules/vector-sets/tests/filter_int.py new file mode 100644 index 000000000..0fd1dc132 --- /dev/null +++ b/modules/vector-sets/tests/filter_int.py @@ -0,0 +1,668 @@ +from test import TestCase, generate_random_vector +import struct +import random +import math +import json +import time + +class VSIMFilterAdvanced(TestCase): + def getname(self): + return "VSIM FILTER comprehensive functionality testing" + + def estimated_runtime(self): + return 15 # This test might take up to 15 seconds for the large dataset + + def setup(self): + super().setup() + self.dim = 32 # Vector dimension + self.count = 5000 # Number of vectors for large tests + self.small_count = 50 # Number of vectors for small/quick tests + + # Categories for attributes + self.categories = ["electronics", "furniture", "clothing", "books", "food"] + self.cities = ["New York", "London", "Tokyo", "Paris", "Berlin", "Sydney", "Toronto", "Singapore"] + self.price_ranges = [(10, 50), (50, 200), (200, 1000), (1000, 5000)] + self.years = list(range(2000, 2025)) + + def create_attributes(self, index): + """Create realistic attributes for a vector""" + category = random.choice(self.categories) + city = random.choice(self.cities) + min_price, max_price = random.choice(self.price_ranges) + price = round(random.uniform(min_price, max_price), 2) + year = random.choice(self.years) + in_stock = random.random() > 0.3 # 70% chance of being in stock + rating = round(random.uniform(1, 5), 1) + views = int(random.expovariate(1/1000)) # Exponential distribution for page views + tags = random.sample(["popular", "sale", "new", "limited", "exclusive", "clearance"], + k=random.randint(0, 3)) + + # Add some specific patterns for testing + # Every 10th item has a specific property combination for testing + is_premium = (index % 10 == 0) + + # Create attributes dictionary + attrs = { + "id": index, + "category": category, + "location": city, + "price": price, + "year": year, + "in_stock": in_stock, + "rating": rating, + "views": views, + "tags": tags + } + + if is_premium: + attrs["is_premium"] = True + attrs["special_features"] = ["premium", "warranty", "support"] + + # Add sub-categories for more complex filters + if category == "electronics": + attrs["subcategory"] = random.choice(["phones", "computers", "cameras", "audio"]) + elif category == "furniture": + attrs["subcategory"] = random.choice(["chairs", "tables", "sofas", "beds"]) + elif category == "clothing": + attrs["subcategory"] = random.choice(["shirts", "pants", "dresses", "shoes"]) + + # Add some intentionally missing fields for testing + if random.random() > 0.9: # 10% chance of missing price + del attrs["price"] + + # Some items have promotion field + if random.random() > 0.7: # 30% chance of having a promotion + attrs["promotion"] = random.choice(["discount", "bundle", "gift"]) + + # Create invalid JSON for a small percentage of vectors + if random.random() > 0.98: # 2% chance of having invalid JSON + return "{{invalid json}}" + + return json.dumps(attrs) + + def create_vectors_with_attributes(self, key, count): + """Create vectors and add attributes to them""" + vectors = [] + names = [] + attribute_map = {} # To store attributes for verification + + # Create vectors + for i in range(count): + vec = generate_random_vector(self.dim) + vectors.append(vec) + name = f"{key}:item:{i}" + names.append(name) + + # Add to Redis + vec_bytes = struct.pack(f'{self.dim}f', *vec) + self.redis.execute_command('VADD', key, 'FP32', vec_bytes, name) + + # Create and add attributes + attrs = self.create_attributes(i) + self.redis.execute_command('VSETATTR', key, name, attrs) + + # Store attributes for later verification + try: + attribute_map[name] = json.loads(attrs) if '{' in attrs else None + except json.JSONDecodeError: + attribute_map[name] = None + + return vectors, names, attribute_map + + def filter_linear_search(self, vectors, names, query_vector, filter_expr, attribute_map, k=10): + """Perform a linear search with filtering for verification""" + similarities = [] + query_norm = math.sqrt(sum(x*x for x in query_vector)) + + if query_norm == 0: + return [] + + for i, vec in enumerate(vectors): + name = names[i] + attributes = attribute_map.get(name) + + # Skip if doesn't match filter + if not self.matches_filter(attributes, filter_expr): + continue + + vec_norm = math.sqrt(sum(x*x for x in vec)) + if vec_norm == 0: + continue + + dot_product = sum(a*b for a,b in zip(query_vector, vec)) + cosine_sim = dot_product / (query_norm * vec_norm) + distance = 1.0 - cosine_sim + redis_similarity = 1.0 - (distance/2.0) + similarities.append((name, redis_similarity)) + + similarities.sort(key=lambda x: x[1], reverse=True) + return similarities[:k] + + def matches_filter(self, attributes, filter_expr): + """Filter matching for verification - uses Python eval to handle complex expressions""" + if attributes is None: + return False # No attributes or invalid JSON + + # Replace JSON path selectors with Python dictionary access + py_expr = filter_expr + + # Handle `.field` notation (replace with attributes['field']) + i = 0 + while i < len(py_expr): + if py_expr[i] == '.' and (i == 0 or not py_expr[i-1].isalnum()): + # Find the end of the selector (stops at operators or whitespace) + j = i + 1 + while j < len(py_expr) and (py_expr[j].isalnum() or py_expr[j] == '_'): + j += 1 + + if j > i + 1: # Found a valid selector + field = py_expr[i+1:j] + # Use a safe access pattern that returns a default value based on context + py_expr = py_expr[:i] + f"attributes.get('{field}')" + py_expr[j:] + i = i + len(f"attributes.get('{field}')") + else: + i += 1 + else: + i += 1 + + # Convert not operator if needed + py_expr = py_expr.replace('!', ' not ') + + try: + # Custom evaluation that handles exceptions for missing fields + # by returning False for the entire expression + + # Split the expression on logical operators + parts = [] + for op in [' and ', ' or ']: + if op in py_expr: + parts = py_expr.split(op) + break + + if not parts: # No logical operators found + parts = [py_expr] + + # Try to evaluate each part - if any part fails, + # the whole expression should fail + try: + result = eval(py_expr, {"attributes": attributes}) + return bool(result) + except (TypeError, AttributeError): + # This typically happens when trying to compare None with + # numbers or other types, or when an attribute doesn't exist + return False + except Exception as e: + print(f"Error evaluating filter expression '{filter_expr}' as '{py_expr}': {e}") + return False + + except Exception as e: + print(f"Error evaluating filter expression '{filter_expr}' as '{py_expr}': {e}") + return False + + def safe_decode(self,item): + return item.decode() if isinstance(item, bytes) else item + + def calculate_recall(self, redis_results, linear_results, k=10): + """Calculate recall (percentage of correct results retrieved)""" + redis_set = set(self.safe_decode(item) for item in redis_results) + linear_set = set(item[0] for item in linear_results[:k]) + + if not linear_set: + return 1.0 # If no linear results, consider it perfect recall + + intersection = redis_set.intersection(linear_set) + return len(intersection) / len(linear_set) + + def test_recall_with_filter(self, filter_expr, ef=500, filter_ef=None): + """Test recall for a given filter expression""" + # Create query vector + query_vec = generate_random_vector(self.dim) + + # First, get ground truth using linear scan + linear_results = self.filter_linear_search( + self.vectors, self.names, query_vec, filter_expr, self.attribute_map, k=50) + + # Calculate true selectivity from ground truth + true_selectivity = len(linear_results) / len(self.names) if self.names else 0 + + # Perform Redis search with filter + cmd_args = ['VSIM', self.test_key, 'VALUES', self.dim] + cmd_args.extend([str(x) for x in query_vec]) + cmd_args.extend(['COUNT', 50, 'WITHSCORES', 'EF', ef, 'FILTER', filter_expr]) + if filter_ef: + cmd_args.extend(['FILTER-EF', filter_ef]) + + start_time = time.time() + redis_results = self.redis.execute_command(*cmd_args) + query_time = time.time() - start_time + + # Convert Redis results to dict + redis_items = {} + for i in range(0, len(redis_results), 2): + key = redis_results[i].decode() if isinstance(redis_results[i], bytes) else redis_results[i] + score = float(redis_results[i+1]) + redis_items[key] = score + + # Calculate metrics + recall = self.calculate_recall(redis_items.keys(), linear_results) + selectivity = len(redis_items) / len(self.names) if redis_items else 0 + + # Compare against the true selectivity from linear scan + assert abs(selectivity - true_selectivity) < 0.1, \ + f"Redis selectivity {selectivity:.3f} differs significantly from ground truth {true_selectivity:.3f}" + + # We expect high recall for standard parameters + if ef >= 500 and (filter_ef is None or filter_ef >= 1000): + try: + assert recall >= 0.7, \ + f"Low recall {recall:.2f} for filter '{filter_expr}'" + except AssertionError as e: + # Get items found in each set + redis_items_set = set(redis_items.keys()) + linear_items_set = set(item[0] for item in linear_results) + + # Find items in each set + only_in_redis = redis_items_set - linear_items_set + only_in_linear = linear_items_set - redis_items_set + in_both = redis_items_set & linear_items_set + + # Build comprehensive debug message + debug = f"\nGround Truth: {len(linear_results)} matching items (total vectors: {len(self.vectors)})" + debug += f"\nRedis Found: {len(redis_items)} items with FILTER-EF: {filter_ef or 'default'}" + debug += f"\nItems in both sets: {len(in_both)} (recall: {recall:.4f})" + debug += f"\nItems only in Redis: {len(only_in_redis)}" + debug += f"\nItems only in Ground Truth: {len(only_in_linear)}" + + # Show some example items from each set with their scores + if only_in_redis: + debug += "\n\nTOP 5 ITEMS ONLY IN REDIS:" + sorted_redis = sorted([(k, v) for k, v in redis_items.items()], key=lambda x: x[1], reverse=True) + for i, (item, score) in enumerate(sorted_redis[:5]): + if item in only_in_redis: + debug += f"\n {i+1}. {item} (Score: {score:.4f})" + + # Show attribute that should match filter + attr = self.attribute_map.get(item) + if attr: + debug += f" - Attrs: {attr.get('category', 'N/A')}, Price: {attr.get('price', 'N/A')}" + + if only_in_linear: + debug += "\n\nTOP 5 ITEMS ONLY IN GROUND TRUTH:" + for i, (item, score) in enumerate(linear_results[:5]): + if item in only_in_linear: + debug += f"\n {i+1}. {item} (Score: {score:.4f})" + + # Show attribute that should match filter + attr = self.attribute_map.get(item) + if attr: + debug += f" - Attrs: {attr.get('category', 'N/A')}, Price: {attr.get('price', 'N/A')}" + + # Help identify parsing issues + debug += "\n\nPARSING CHECK:" + debug += f"\nRedis command: VSIM {self.test_key} VALUES {self.dim} [...] FILTER '{filter_expr}'" + + # Check for WITHSCORES handling issues + if len(redis_results) > 0 and len(redis_results) % 2 == 0: + debug += f"\nRedis returned {len(redis_results)} items (looks like item,score pairs)" + debug += f"\nFirst few results: {redis_results[:4]}" + + # Check the filter implementation + debug += "\n\nFILTER IMPLEMENTATION CHECK:" + debug += f"\nFilter expression: '{filter_expr}'" + debug += "\nSample attribute matches from attribute_map:" + count_matching = 0 + for i, (name, attrs) in enumerate(self.attribute_map.items()): + if attrs and self.matches_filter(attrs, filter_expr): + count_matching += 1 + if i < 3: # Show first 3 matches + debug += f"\n - {name}: {attrs}" + debug += f"\nTotal items matching filter in attribute_map: {count_matching}" + + # Check if results array handling could be wrong + debug += "\n\nRESULT ARRAYS CHECK:" + if len(linear_results) >= 1: + debug += f"\nlinear_results[0]: {linear_results[0]}" + if isinstance(linear_results[0], tuple) and len(linear_results[0]) == 2: + debug += " (correct tuple format: (name, score))" + else: + debug += " (UNEXPECTED FORMAT!)" + + # Debug sort order + debug += "\n\nSORTING CHECK:" + if len(linear_results) >= 2: + debug += f"\nGround truth first item score: {linear_results[0][1]}" + debug += f"\nGround truth second item score: {linear_results[1][1]}" + debug += f"\nCorrectly sorted by similarity? {linear_results[0][1] >= linear_results[1][1]}" + + # Re-raise with detailed information + raise AssertionError(str(e) + debug) + + return recall, selectivity, query_time, len(redis_items) + + def test(self): + print(f"\nRunning comprehensive VSIM FILTER tests...") + + # Create a larger dataset for testing + print(f"Creating dataset with {self.count} vectors and attributes...") + self.vectors, self.names, self.attribute_map = self.create_vectors_with_attributes( + self.test_key, self.count) + + # ==== 1. Recall and Precision Testing ==== + print("Testing recall for various filters...") + + # Test basic filters with different selectivity + results = {} + results["category"] = self.test_recall_with_filter('.category == "electronics"') + results["price_high"] = self.test_recall_with_filter('.price > 1000') + results["in_stock"] = self.test_recall_with_filter('.in_stock') + results["rating"] = self.test_recall_with_filter('.rating >= 4') + results["complex1"] = self.test_recall_with_filter('.category == "electronics" and .price < 500') + + print("Filter | Recall | Selectivity | Time (ms) | Results") + print("----------------------------------------------------") + for name, (recall, selectivity, time_ms, count) in results.items(): + print(f"{name:7} | {recall:.3f} | {selectivity:.3f} | {time_ms*1000:.1f} | {count}") + + # ==== 2. Filter Selectivity Performance ==== + print("\nTesting filter selectivity performance...") + + # High selectivity (very few matches) + high_sel_recall, _, high_sel_time, _ = self.test_recall_with_filter('.is_premium') + + # Medium selectivity + med_sel_recall, _, med_sel_time, _ = self.test_recall_with_filter('.price > 100 and .price < 1000') + + # Low selectivity (many matches) + low_sel_recall, _, low_sel_time, _ = self.test_recall_with_filter('.year > 2000') + + print(f"High selectivity recall: {high_sel_recall:.3f}, time: {high_sel_time*1000:.1f}ms") + print(f"Med selectivity recall: {med_sel_recall:.3f}, time: {med_sel_time*1000:.1f}ms") + print(f"Low selectivity recall: {low_sel_recall:.3f}, time: {low_sel_time*1000:.1f}ms") + + # ==== 3. FILTER-EF Parameter Testing ==== + print("\nTesting FILTER-EF parameter...") + + # Test with different FILTER-EF values + filter_expr = '.category == "electronics" and .price > 200' + ef_values = [100, 500, 2000, 5000] + + print("FILTER-EF | Recall | Time (ms)") + print("-----------------------------") + for filter_ef in ef_values: + recall, _, query_time, _ = self.test_recall_with_filter( + filter_expr, ef=500, filter_ef=filter_ef) + print(f"{filter_ef:9} | {recall:.3f} | {query_time*1000:.1f}") + + # Assert that higher FILTER-EF generally gives better recall + low_ef_recall, _, _, _ = self.test_recall_with_filter(filter_expr, filter_ef=100) + high_ef_recall, _, _, _ = self.test_recall_with_filter(filter_expr, filter_ef=5000) + + # This might not always be true due to randomness, but generally holds + # We use a softer assertion to avoid flaky tests + assert high_ef_recall >= low_ef_recall * 0.8, \ + f"Higher FILTER-EF should generally give better recall: {high_ef_recall:.3f} vs {low_ef_recall:.3f}" + + # ==== 4. Complex Filter Expressions ==== + print("\nTesting complex filter expressions...") + + # Test a variety of complex expressions + complex_filters = [ + '.price > 100 and (.category == "electronics" or .category == "furniture")', + '(.rating > 4 and .in_stock) or (.price < 50 and .views > 1000)', + '.category in ["electronics", "clothing"] and .price > 200 and .rating >= 3', + '(.category == "electronics" and .subcategory == "phones") or (.category == "furniture" and .price > 1000)', + '.year > 2010 and !(.price < 100) and .in_stock' + ] + + print("Expression | Results | Time (ms)") + print("-----------------------------") + for i, expr in enumerate(complex_filters): + try: + _, _, query_time, result_count = self.test_recall_with_filter(expr) + print(f"Complex {i+1} | {result_count:7} | {query_time*1000:.1f}") + except Exception as e: + print(f"Complex {i+1} | Error: {str(e)}") + + # ==== 5. Attribute Type Testing ==== + print("\nTesting different attribute types...") + + type_filters = [ + ('.price > 500', "Numeric"), + ('.category == "books"', "String equality"), + ('.in_stock', "Boolean"), + ('.tags in ["sale", "new"]', "Array membership"), + ('.rating * 2 > 8', "Arithmetic") + ] + + for expr, type_name in type_filters: + try: + _, _, query_time, result_count = self.test_recall_with_filter(expr) + print(f"{type_name:16} | {expr:30} | {result_count:5} results | {query_time*1000:.1f}ms") + except Exception as e: + print(f"{type_name:16} | {expr:30} | Error: {str(e)}") + + # ==== 6. Filter + Count Interaction ==== + print("\nTesting COUNT parameter with filters...") + + filter_expr = '.category == "electronics"' + counts = [5, 20, 100] + + for count in counts: + query_vec = generate_random_vector(self.dim) + cmd_args = ['VSIM', self.test_key, 'VALUES', self.dim] + cmd_args.extend([str(x) for x in query_vec]) + cmd_args.extend(['COUNT', count, 'WITHSCORES', 'FILTER', filter_expr]) + + results = self.redis.execute_command(*cmd_args) + result_count = len(results) // 2 # Divide by 2 because WITHSCORES returns pairs + + # We expect result count to be at most the requested count + assert result_count <= count, f"Got {result_count} results with COUNT {count}" + print(f"COUNT {count:3} | Got {result_count:3} results") + + # ==== 7. Edge Cases ==== + print("\nTesting edge cases...") + + # Test with no matching items + no_match_expr = '.category == "nonexistent_category"' + results = self.redis.execute_command('VSIM', self.test_key, 'VALUES', self.dim, + *[str(x) for x in generate_random_vector(self.dim)], + 'FILTER', no_match_expr) + assert len(results) == 0, f"Expected 0 results for non-matching filter, got {len(results)}" + print(f"No matching items: {len(results)} results (expected 0)") + + # Test with invalid filter syntax + try: + self.redis.execute_command('VSIM', self.test_key, 'VALUES', self.dim, + *[str(x) for x in generate_random_vector(self.dim)], + 'FILTER', '.category === "books"') # Triple equals is invalid + assert False, "Expected error for invalid filter syntax" + except: + print("Invalid filter syntax correctly raised an error") + + # Test with extremely long complex expression + long_expr = ' and '.join([f'.rating > {i/10}' for i in range(10)]) + try: + results = self.redis.execute_command('VSIM', self.test_key, 'VALUES', self.dim, + *[str(x) for x in generate_random_vector(self.dim)], + 'FILTER', long_expr) + print(f"Long expression: {len(results)} results") + except Exception as e: + print(f"Long expression error: {str(e)}") + + print("\nComprehensive VSIM FILTER tests completed successfully") + + +class VSIMFilterSelectivityTest(TestCase): + def getname(self): + return "VSIM FILTER selectivity performance benchmark" + + def estimated_runtime(self): + return 8 # This test might take up to 8 seconds + + def setup(self): + super().setup() + self.dim = 32 + self.count = 10000 + self.test_key = f"{self.test_key}:selectivity" # Use a different key + + def create_vector_with_age_attribute(self, name, age): + """Create a vector with a specific age attribute""" + vec = generate_random_vector(self.dim) + vec_bytes = struct.pack(f'{self.dim}f', *vec) + self.redis.execute_command('VADD', self.test_key, 'FP32', vec_bytes, name) + self.redis.execute_command('VSETATTR', self.test_key, name, json.dumps({"age": age})) + + def test(self): + print("\nRunning VSIM FILTER selectivity benchmark...") + + # Create a dataset where we control the exact selectivity + print(f"Creating controlled dataset with {self.count} vectors...") + + # Create vectors with age attributes from 1 to 100 + for i in range(self.count): + age = (i % 100) + 1 # Ages from 1 to 100 + name = f"{self.test_key}:item:{i}" + self.create_vector_with_age_attribute(name, age) + + # Create a query vector + query_vec = generate_random_vector(self.dim) + + # Test filters with different selectivities + selectivities = [0.01, 0.05, 0.10, 0.25, 0.50, 0.75, 0.99] + results = [] + + print("\nSelectivity | Filter | Results | Time (ms)") + print("--------------------------------------------------") + + for target_selectivity in selectivities: + # Calculate age threshold for desired selectivity + # For example, age <= 10 gives 10% selectivity + age_threshold = int(target_selectivity * 100) + filter_expr = f'.age <= {age_threshold}' + + # Run query and measure time + start_time = time.time() + cmd_args = ['VSIM', self.test_key, 'VALUES', self.dim] + cmd_args.extend([str(x) for x in query_vec]) + cmd_args.extend(['COUNT', 100, 'FILTER', filter_expr]) + + results = self.redis.execute_command(*cmd_args) + query_time = time.time() - start_time + + actual_selectivity = len(results) / min(100, int(target_selectivity * self.count)) + print(f"{target_selectivity:.2f} | {filter_expr:15} | {len(results):7} | {query_time*1000:.1f}") + + # Add assertion to ensure reasonable performance for different selectivities + # For very selective queries (1%), we might need more exploration + if target_selectivity <= 0.05: + # For very selective queries, ensure we can find some results + assert len(results) > 0, f"No results found for {filter_expr}" + else: + # For less selective queries, performance should be reasonable + assert query_time < 1.0, f"Query too slow: {query_time:.3f}s for {filter_expr}" + + print("\nSelectivity benchmark completed successfully") + + +class VSIMFilterComparisonTest(TestCase): + def getname(self): + return "VSIM FILTER EF parameter comparison" + + def estimated_runtime(self): + return 8 # This test might take up to 8 seconds + + def setup(self): + super().setup() + self.dim = 32 + self.count = 5000 + self.test_key = f"{self.test_key}:efparams" # Use a different key + + def create_dataset(self): + """Create a dataset with specific attribute patterns for testing FILTER-EF""" + vectors = [] + names = [] + + # Create vectors with category and quality score attributes + for i in range(self.count): + vec = generate_random_vector(self.dim) + name = f"{self.test_key}:item:{i}" + + # Add vector to Redis + vec_bytes = struct.pack(f'{self.dim}f', *vec) + self.redis.execute_command('VADD', self.test_key, 'FP32', vec_bytes, name) + + # Create attributes - we want a very selective filter + # Only 2% of items have category=premium AND quality>90 + category = "premium" if random.random() < 0.1 else random.choice(["standard", "economy", "basic"]) + quality = random.randint(1, 100) + + attrs = { + "id": i, + "category": category, + "quality": quality + } + + self.redis.execute_command('VSETATTR', self.test_key, name, json.dumps(attrs)) + vectors.append(vec) + names.append(name) + + return vectors, names + + def test(self): + print("\nRunning VSIM FILTER-EF parameter comparison...") + + # Create dataset + vectors, names = self.create_dataset() + + # Create a selective filter that matches ~2% of items + filter_expr = '.category == "premium" and .quality > 90' + + # Create query vector + query_vec = generate_random_vector(self.dim) + + # Test different FILTER-EF values + ef_values = [50, 100, 500, 1000, 5000] + results = [] + + print("\nFILTER-EF | Results | Time (ms) | Notes") + print("---------------------------------------") + + baseline_count = None + + for ef in ef_values: + # Run query and measure time + start_time = time.time() + cmd_args = ['VSIM', self.test_key, 'VALUES', self.dim] + cmd_args.extend([str(x) for x in query_vec]) + cmd_args.extend(['COUNT', 100, 'FILTER', filter_expr, 'FILTER-EF', ef]) + + query_results = self.redis.execute_command(*cmd_args) + query_time = time.time() - start_time + + # Set baseline for comparison + if baseline_count is None: + baseline_count = len(query_results) + + recall_rate = len(query_results) / max(1, baseline_count) if baseline_count > 0 else 1.0 + + notes = "" + if ef == 5000: + notes = "Baseline" + elif recall_rate < 0.5: + notes = "Low recall!" + + print(f"{ef:9} | {len(query_results):7} | {query_time*1000:.1f} | {notes}") + results.append((ef, len(query_results), query_time)) + + # If we have enough results at highest EF, check that recall improves with higher EF + if results[-1][1] >= 5: # At least 5 results for highest EF + # Extract result counts + result_counts = [r[1] for r in results] + + # The last result (highest EF) should typically find more results than the first (lowest EF) + # but we use a soft assertion to avoid flaky tests + assert result_counts[-1] >= result_counts[0], \ + f"Higher FILTER-EF should find at least as many results: {result_counts[-1]} vs {result_counts[0]}" + + print("\nFILTER-EF parameter comparison completed successfully") diff --git a/modules/vector-sets/tests/large_scale.py b/modules/vector-sets/tests/large_scale.py new file mode 100644 index 000000000..eac5dca52 --- /dev/null +++ b/modules/vector-sets/tests/large_scale.py @@ -0,0 +1,56 @@ +from test import TestCase, fill_redis_with_vectors, generate_random_vector +import random + +class LargeScale(TestCase): + def getname(self): + return "Large Scale Comparison" + + def estimated_runtime(self): + return 10 + + def test(self): + dim = 300 + count = 20000 + k = 50 + + # Fill Redis and get reference data for comparison + random.seed(42) # Make test deterministic + data = fill_redis_with_vectors(self.redis, self.test_key, count, dim) + + # Generate query vector + query_vec = generate_random_vector(dim) + + # Get results from Redis with good exploration factor + redis_raw = self.redis.execute_command('VSIM', self.test_key, 'VALUES', dim, + *[str(x) for x in query_vec], + 'COUNT', k, 'WITHSCORES', 'EF', 500) + + # Convert Redis results to dict + redis_results = {} + for i in range(0, len(redis_raw), 2): + key = redis_raw[i].decode() + score = float(redis_raw[i+1]) + redis_results[key] = score + + # Get results from linear scan + linear_results = data.find_k_nearest(query_vec, k) + linear_items = {name: score for name, score in linear_results} + + # Compare overlap + redis_set = set(redis_results.keys()) + linear_set = set(linear_items.keys()) + overlap = len(redis_set & linear_set) + + # If test fails, print comparison for debugging + if overlap < k * 0.7: + data.print_comparison({'items': redis_results, 'query_vector': query_vec}, k) + + assert overlap >= k * 0.7, \ + f"Expected at least 70% overlap in top {k} results, got {overlap/k*100:.1f}%" + + # Verify scores for common items + for item in redis_set & linear_set: + redis_score = redis_results[item] + linear_score = linear_items[item] + assert abs(redis_score - linear_score) < 0.01, \ + f"Score mismatch for {item}: Redis={redis_score:.3f} Linear={linear_score:.3f}" diff --git a/modules/vector-sets/tests/memory_usage.py b/modules/vector-sets/tests/memory_usage.py new file mode 100644 index 000000000..d0f3f0967 --- /dev/null +++ b/modules/vector-sets/tests/memory_usage.py @@ -0,0 +1,36 @@ +from test import TestCase, generate_random_vector +import struct + +class MemoryUsageTest(TestCase): + def getname(self): + return "[regression] MEMORY USAGE with attributes" + + def test(self): + # Generate random vectors + vec1 = generate_random_vector(4) + vec2 = generate_random_vector(4) + vec_bytes1 = struct.pack('4f', *vec1) + vec_bytes2 = struct.pack('4f', *vec2) + + # Add vectors to the key, one with attribute, one without + self.redis.execute_command('VADD', self.test_key, 'FP32', vec_bytes1, f'{self.test_key}:item:1') + self.redis.execute_command('VADD', self.test_key, 'FP32', vec_bytes2, f'{self.test_key}:item:2', 'SETATTR', '{"color":"red"}') + + # Get memory usage for the key + try: + memory_usage = self.redis.execute_command('MEMORY', 'USAGE', self.test_key) + # If we got here without exception, the command worked + assert memory_usage > 0, "MEMORY USAGE should return a positive value" + + # Add more attributes to increase complexity + self.redis.execute_command('VSETATTR', self.test_key, f'{self.test_key}:item:1', '{"color":"blue","size":10}') + + # Check memory usage again + new_memory_usage = self.redis.execute_command('MEMORY', 'USAGE', self.test_key) + assert new_memory_usage > 0, "MEMORY USAGE should still return a positive value after setting attributes" + + # Memory usage should be higher after adding attributes + assert new_memory_usage > memory_usage, "Memory usage increase after adding attributes" + + except Exception as e: + raise AssertionError(f"MEMORY USAGE command failed: {str(e)}") diff --git a/modules/vector-sets/tests/node_update.py b/modules/vector-sets/tests/node_update.py new file mode 100644 index 000000000..53aa2dd56 --- /dev/null +++ b/modules/vector-sets/tests/node_update.py @@ -0,0 +1,85 @@ +from test import TestCase, generate_random_vector +import struct +import math +import random + +class VectorUpdateAndClusters(TestCase): + def getname(self): + return "VADD vector update with cluster relocation" + + def estimated_runtime(self): + return 2.0 # Should take around 2 seconds + + def generate_cluster_vector(self, base_vec, noise=0.1): + """Generate a vector that's similar to base_vec with some noise.""" + vec = [x + random.gauss(0, noise) for x in base_vec] + # Normalize + norm = math.sqrt(sum(x*x for x in vec)) + return [x/norm for x in vec] + + def test(self): + dim = 128 + vectors_per_cluster = 5000 + + # Create two very different base vectors for our clusters + cluster1_base = generate_random_vector(dim) + cluster2_base = [-x for x in cluster1_base] # Opposite direction + + # Add vectors from first cluster + for i in range(vectors_per_cluster): + vec = self.generate_cluster_vector(cluster1_base) + vec_bytes = struct.pack(f'{dim}f', *vec) + self.redis.execute_command('VADD', self.test_key, 'FP32', vec_bytes, + f'{self.test_key}:cluster1:{i}') + + # Add vectors from second cluster + for i in range(vectors_per_cluster): + vec = self.generate_cluster_vector(cluster2_base) + vec_bytes = struct.pack(f'{dim}f', *vec) + self.redis.execute_command('VADD', self.test_key, 'FP32', vec_bytes, + f'{self.test_key}:cluster2:{i}') + + # Pick a test vector from cluster1 + test_key = f'{self.test_key}:cluster1:0' + + # Verify it's in cluster1 using VSIM + initial_vec = self.generate_cluster_vector(cluster1_base) + results = self.redis.execute_command('VSIM', self.test_key, 'VALUES', dim, + *[str(x) for x in initial_vec], + 'COUNT', 100, 'WITHSCORES') + + # Count how many cluster1 items are in top results + cluster1_count = sum(1 for i in range(0, len(results), 2) + if b'cluster1' in results[i]) + assert cluster1_count > 80, "Initial clustering check failed" + + # Now update the test vector to be in cluster2 + new_vec = self.generate_cluster_vector(cluster2_base, noise=0.05) + vec_bytes = struct.pack(f'{dim}f', *new_vec) + self.redis.execute_command('VADD', self.test_key, 'FP32', vec_bytes, test_key) + + # Verify the embedding was actually updated using VEMB + emb_result = self.redis.execute_command('VEMB', self.test_key, test_key) + updated_vec = [float(x) for x in emb_result] + + # Verify updated vector matches what we inserted + dot_product = sum(a*b for a,b in zip(updated_vec, new_vec)) + similarity = dot_product / (math.sqrt(sum(x*x for x in updated_vec)) * + math.sqrt(sum(x*x for x in new_vec))) + assert similarity > 0.9, "Vector was not properly updated" + + # Verify it's now in cluster2 using VSIM + results = self.redis.execute_command('VSIM', self.test_key, 'VALUES', dim, + *[str(x) for x in cluster2_base], + 'COUNT', 100, 'WITHSCORES') + + # Verify our updated vector is among top results + found = False + for i in range(0, len(results), 2): + if results[i].decode() == test_key: + found = True + similarity = float(results[i+1]) + assert similarity > 0.80, f"Updated vector has low similarity: {similarity}" + break + + assert found, "Updated vector not found in cluster2 proximity" diff --git a/modules/vector-sets/tests/persistence.py b/modules/vector-sets/tests/persistence.py new file mode 100644 index 000000000..021c8b6e3 --- /dev/null +++ b/modules/vector-sets/tests/persistence.py @@ -0,0 +1,83 @@ +from test import TestCase, fill_redis_with_vectors, generate_random_vector +import random + +class HNSWPersistence(TestCase): + def getname(self): + return "HNSW Persistence" + + def estimated_runtime(self): + return 30 + + def _verify_results(self, key, dim, query_vec, reduced_dim=None): + """Run a query and return results dict""" + k = 10 + args = ['VSIM', key] + + if reduced_dim: + args.extend(['VALUES', dim]) + args.extend([str(x) for x in query_vec]) + else: + args.extend(['VALUES', dim]) + args.extend([str(x) for x in query_vec]) + + args.extend(['COUNT', k, 'WITHSCORES']) + results = self.redis.execute_command(*args) + + results_dict = {} + for i in range(0, len(results), 2): + key = results[i].decode() + score = float(results[i+1]) + results_dict[key] = score + return results_dict + + def test(self): + # Setup dimensions + dim = 128 + reduced_dim = 32 + count = 5000 + random.seed(42) + + # Create two datasets - one normal and one with dimension reduction + normal_data = fill_redis_with_vectors(self.redis, f"{self.test_key}:normal", count, dim) + projected_data = fill_redis_with_vectors(self.redis, f"{self.test_key}:projected", + count, dim, reduced_dim) + + # Generate query vectors we'll use before and after reload + query_vec_normal = generate_random_vector(dim) + query_vec_projected = generate_random_vector(dim) + + # Get initial results for both sets + initial_normal = self._verify_results(f"{self.test_key}:normal", + dim, query_vec_normal) + initial_projected = self._verify_results(f"{self.test_key}:projected", + dim, query_vec_projected, reduced_dim) + + # Force Redis to save and reload the dataset + self.redis.execute_command('DEBUG', 'RELOAD') + + # Verify results after reload + reloaded_normal = self._verify_results(f"{self.test_key}:normal", + dim, query_vec_normal) + reloaded_projected = self._verify_results(f"{self.test_key}:projected", + dim, query_vec_projected, reduced_dim) + + # Verify normal vectors results + assert len(initial_normal) == len(reloaded_normal), \ + "Normal vectors: Result count mismatch before/after reload" + + for key in initial_normal: + assert key in reloaded_normal, f"Normal vectors: Missing item after reload: {key}" + assert abs(initial_normal[key] - reloaded_normal[key]) < 0.0001, \ + f"Normal vectors: Score mismatch for {key}: " + \ + f"before={initial_normal[key]:.6f}, after={reloaded_normal[key]:.6f}" + + # Verify projected vectors results + assert len(initial_projected) == len(reloaded_projected), \ + "Projected vectors: Result count mismatch before/after reload" + + for key in initial_projected: + assert key in reloaded_projected, \ + f"Projected vectors: Missing item after reload: {key}" + assert abs(initial_projected[key] - reloaded_projected[key]) < 0.0001, \ + f"Projected vectors: Score mismatch for {key}: " + \ + f"before={initial_projected[key]:.6f}, after={reloaded_projected[key]:.6f}" diff --git a/modules/vector-sets/tests/reduce.py b/modules/vector-sets/tests/reduce.py new file mode 100644 index 000000000..e39164f3b --- /dev/null +++ b/modules/vector-sets/tests/reduce.py @@ -0,0 +1,71 @@ +from test import TestCase, fill_redis_with_vectors, generate_random_vector + +class Reduce(TestCase): + def getname(self): + return "Dimension Reduction" + + def estimated_runtime(self): + return 0.2 + + def test(self): + original_dim = 100 + reduced_dim = 80 + count = 1000 + k = 50 # Number of nearest neighbors to check + + # Fill Redis with vectors using REDUCE and get reference data + data = fill_redis_with_vectors(self.redis, self.test_key, count, original_dim, reduced_dim) + + # Verify dimension is reduced + dim = self.redis.execute_command('VDIM', self.test_key) + assert dim == reduced_dim, f"Expected dimension {reduced_dim}, got {dim}" + + # Generate query vector and get nearest neighbors using Redis + query_vec = generate_random_vector(original_dim) + redis_raw = self.redis.execute_command('VSIM', self.test_key, 'VALUES', + original_dim, *[str(x) for x in query_vec], + 'COUNT', k, 'WITHSCORES') + + # Convert Redis results to dict + redis_results = {} + for i in range(0, len(redis_raw), 2): + key = redis_raw[i].decode() + score = float(redis_raw[i+1]) + redis_results[key] = score + + # Get results from linear scan with original vectors + linear_results = data.find_k_nearest(query_vec, k) + linear_items = {name: score for name, score in linear_results} + + # Compare overlap between reduced and non-reduced results + redis_set = set(redis_results.keys()) + linear_set = set(linear_items.keys()) + overlap = len(redis_set & linear_set) + overlap_ratio = overlap / k + + # With random projection, we expect some loss of accuracy but should + # maintain at least some similarity structure. + # Note that gaussian distribution is the worse with this test, so + # in real world practice, things will be better. + min_expected_overlap = 0.1 # At least 10% overlap in top-k + assert overlap_ratio >= min_expected_overlap, \ + f"Dimension reduction lost too much structure. Only {overlap_ratio*100:.1f}% overlap in top {k}" + + # For items that appear in both results, scores should be reasonably correlated + common_items = redis_set & linear_set + for item in common_items: + redis_score = redis_results[item] + linear_score = linear_items[item] + # Allow for some deviation due to dimensionality reduction + assert abs(redis_score - linear_score) < 0.2, \ + f"Score mismatch too high for {item}: Redis={redis_score:.3f} Linear={linear_score:.3f}" + + # If test fails, print comparison for debugging + if overlap_ratio < min_expected_overlap: + print("\nLow overlap in results. Details:") + print("\nTop results from linear scan (original vectors):") + for name, score in linear_results: + print(f"{name}: {score:.3f}") + print("\nTop results from Redis (reduced vectors):") + for item, score in sorted(redis_results.items(), key=lambda x: x[1], reverse=True): + print(f"{item}: {score:.3f}") diff --git a/modules/vector-sets/tests/replication.py b/modules/vector-sets/tests/replication.py new file mode 100644 index 000000000..91dfdf736 --- /dev/null +++ b/modules/vector-sets/tests/replication.py @@ -0,0 +1,92 @@ +from test import TestCase, generate_random_vector +import struct +import random +import time + +class ComprehensiveReplicationTest(TestCase): + def getname(self): + return "Comprehensive Replication Test with mixed operations" + + def estimated_runtime(self): + # This test will take longer than the default 100ms + return 20.0 # 20 seconds estimate + + def test(self): + # Setup replication between primary and replica + assert self.setup_replication(), "Failed to setup replication" + + # Test parameters + num_vectors = 5000 + vector_dim = 8 + delete_probability = 0.1 + cas_probability = 0.3 + + # Keep track of added items for potential deletion + added_items = [] + + # Add vectors and occasionally delete + for i in range(num_vectors): + # Generate a random vector + vec = generate_random_vector(vector_dim) + vec_bytes = struct.pack(f'{vector_dim}f', *vec) + item_name = f"{self.test_key}:item:{i}" + + # Decide whether to use CAS or not + use_cas = random.random() < cas_probability + + if use_cas and added_items: + # Get an existing item for CAS reference (if available) + cas_item = random.choice(added_items) + try: + # Add with CAS + result = self.redis.execute_command('VADD', self.test_key, 'FP32', vec_bytes, + item_name, 'CAS') + # Only add to our list if actually added (CAS might fail) + if result == 1: + added_items.append(item_name) + except Exception as e: + print(f" CAS VADD failed: {e}") + else: + try: + # Add without CAS + result = self.redis.execute_command('VADD', self.test_key, 'FP32', vec_bytes, item_name) + # Only add to our list if actually added + if result == 1: + added_items.append(item_name) + except Exception as e: + print(f" VADD failed: {e}") + + # Randomly delete items (with 10% probability) + if random.random() < delete_probability and added_items: + try: + # Select a random item to delete + item_to_delete = random.choice(added_items) + # Delete the item using VREM (not VDEL) + self.redis.execute_command('VREM', self.test_key, item_to_delete) + # Remove from our list + added_items.remove(item_to_delete) + except Exception as e: + print(f" VREM failed: {e}") + + # Allow time for replication to complete + time.sleep(2.0) + + # Verify final VCARD matches + primary_card = self.redis.execute_command('VCARD', self.test_key) + replica_card = self.replica.execute_command('VCARD', self.test_key) + assert primary_card == replica_card, f"Final VCARD mismatch: primary={primary_card}, replica={replica_card}" + + # Verify VDIM matches + primary_dim = self.redis.execute_command('VDIM', self.test_key) + replica_dim = self.replica.execute_command('VDIM', self.test_key) + assert primary_dim == replica_dim, f"VDIM mismatch: primary={primary_dim}, replica={replica_dim}" + + # Verify digests match using DEBUG DIGEST + primary_digest = self.redis.execute_command('DEBUG', 'DIGEST-VALUE', self.test_key) + replica_digest = self.replica.execute_command('DEBUG', 'DIGEST-VALUE', self.test_key) + assert primary_digest == replica_digest, f"Digest mismatch: primary={primary_digest}, replica={replica_digest}" + + # Print summary + print(f"\n Added and maintained {len(added_items)} vectors with dimension {vector_dim}") + print(f" Final vector count: {primary_card}") + print(f" Final digest: {primary_digest[0].decode()}") diff --git a/modules/vector-sets/tests/vadd_cas.py b/modules/vector-sets/tests/vadd_cas.py new file mode 100644 index 000000000..3cb3508e5 --- /dev/null +++ b/modules/vector-sets/tests/vadd_cas.py @@ -0,0 +1,98 @@ +from test import TestCase, generate_random_vector +import threading +import struct +import math +import time +import random +from typing import List, Dict + +class ConcurrentCASTest(TestCase): + def getname(self): + return "Concurrent VADD with CAS" + + def estimated_runtime(self): + return 1.5 + + def worker(self, vectors: List[List[float]], start_idx: int, end_idx: int, + dim: int, results: Dict[str, bool]): + """Worker thread that adds a subset of vectors using VADD CAS""" + for i in range(start_idx, end_idx): + vec = vectors[i] + name = f"{self.test_key}:item:{i}" + vec_bytes = struct.pack(f'{dim}f', *vec) + + # Try to add the vector with CAS + try: + result = self.redis.execute_command('VADD', self.test_key, 'FP32', + vec_bytes, name, 'CAS') + results[name] = (result == 1) # Store if it was actually added + except Exception as e: + results[name] = False + print(f"Error adding {name}: {e}") + + def verify_vector_similarity(self, vec1: List[float], vec2: List[float]) -> float: + """Calculate cosine similarity between two vectors""" + dot_product = sum(a*b for a,b in zip(vec1, vec2)) + norm1 = math.sqrt(sum(x*x for x in vec1)) + norm2 = math.sqrt(sum(x*x for x in vec2)) + return dot_product / (norm1 * norm2) if norm1 > 0 and norm2 > 0 else 0 + + def test(self): + # Test parameters + dim = 128 + total_vectors = 5000 + num_threads = 8 + vectors_per_thread = total_vectors // num_threads + + # Generate all vectors upfront + random.seed(42) # For reproducibility + vectors = [generate_random_vector(dim) for _ in range(total_vectors)] + + # Prepare threads and results dictionary + threads = [] + results = {} # Will store success/failure for each vector + + # Launch threads + for i in range(num_threads): + start_idx = i * vectors_per_thread + end_idx = start_idx + vectors_per_thread if i < num_threads-1 else total_vectors + thread = threading.Thread(target=self.worker, + args=(vectors, start_idx, end_idx, dim, results)) + threads.append(thread) + thread.start() + + # Wait for all threads to complete + for thread in threads: + thread.join() + + # Verify cardinality + card = self.redis.execute_command('VCARD', self.test_key) + assert card == total_vectors, \ + f"Expected {total_vectors} elements, but found {card}" + + # Verify each vector + num_verified = 0 + for i in range(total_vectors): + name = f"{self.test_key}:item:{i}" + + # Verify the item was successfully added + assert results[name], f"Vector {name} was not successfully added" + + # Get the stored vector + stored_vec_raw = self.redis.execute_command('VEMB', self.test_key, name) + stored_vec = [float(x) for x in stored_vec_raw] + + # Verify vector dimensions + assert len(stored_vec) == dim, \ + f"Stored vector dimension mismatch for {name}: {len(stored_vec)} != {dim}" + + # Calculate similarity with original vector + similarity = self.verify_vector_similarity(vectors[i], stored_vec) + assert similarity > 0.99, \ + f"Low similarity ({similarity}) for {name}" + + num_verified += 1 + + # Final verification + assert num_verified == total_vectors, \ + f"Only verified {num_verified} out of {total_vectors} vectors" diff --git a/modules/vector-sets/tests/vemb.py b/modules/vector-sets/tests/vemb.py new file mode 100644 index 000000000..0f4cf77a7 --- /dev/null +++ b/modules/vector-sets/tests/vemb.py @@ -0,0 +1,41 @@ +from test import TestCase +import struct +import math + +class VEMB(TestCase): + def getname(self): + return "VEMB Command" + + def test(self): + dim = 4 + + # Add same vector in both formats + vec = [1, 0, 0, 0] + norm = math.sqrt(sum(x*x for x in vec)) + vec = [x/norm for x in vec] # Normalize the vector + + # Add using FP32 + vec_bytes = struct.pack(f'{dim}f', *vec) + self.redis.execute_command('VADD', self.test_key, 'FP32', vec_bytes, f'{self.test_key}:item:1') + + # Add using VALUES + self.redis.execute_command('VADD', self.test_key, 'VALUES', dim, + *[str(x) for x in vec], f'{self.test_key}:item:2') + + # Get both back with VEMB + result1 = self.redis.execute_command('VEMB', self.test_key, f'{self.test_key}:item:1') + result2 = self.redis.execute_command('VEMB', self.test_key, f'{self.test_key}:item:2') + + retrieved_vec1 = [float(x) for x in result1] + retrieved_vec2 = [float(x) for x in result2] + + # Compare both vectors with original (allow for small quantization errors) + for i in range(dim): + assert abs(vec[i] - retrieved_vec1[i]) < 0.01, \ + f"FP32 vector component {i} mismatch: expected {vec[i]}, got {retrieved_vec1[i]}" + assert abs(vec[i] - retrieved_vec2[i]) < 0.01, \ + f"VALUES vector component {i} mismatch: expected {vec[i]}, got {retrieved_vec2[i]}" + + # Test non-existent item + result = self.redis.execute_command('VEMB', self.test_key, 'nonexistent') + assert result is None, "Non-existent item should return nil" diff --git a/modules/vector-sets/tests/vrandmember.py b/modules/vector-sets/tests/vrandmember.py new file mode 100644 index 000000000..ca9e0064a --- /dev/null +++ b/modules/vector-sets/tests/vrandmember.py @@ -0,0 +1,55 @@ +from test import TestCase, generate_random_vector, fill_redis_with_vectors +import struct + +class VRANDMEMBERTest(TestCase): + def getname(self): + return "VRANDMEMBER basic functionality" + + def test(self): + # Test with empty key + result = self.redis.execute_command('VRANDMEMBER', self.test_key) + assert result is None, "VRANDMEMBER on non-existent key should return NULL" + + result = self.redis.execute_command('VRANDMEMBER', self.test_key, 5) + assert isinstance(result, list) and len(result) == 0, "VRANDMEMBER with count on non-existent key should return empty array" + + # Fill with vectors + dim = 4 + count = 100 + data = fill_redis_with_vectors(self.redis, self.test_key, count, dim) + + # Test single random member + result = self.redis.execute_command('VRANDMEMBER', self.test_key) + assert result is not None, "VRANDMEMBER should return a random member" + assert result.decode() in data.names, "Random member should be in the set" + + # Test multiple unique members (positive count) + positive_count = 10 + result = self.redis.execute_command('VRANDMEMBER', self.test_key, positive_count) + assert isinstance(result, list), "VRANDMEMBER with positive count should return an array" + assert len(result) == positive_count, f"Should return {positive_count} members" + + # Check for uniqueness + decoded_results = [r.decode() for r in result] + assert len(decoded_results) == len(set(decoded_results)), "Results should be unique with positive count" + for item in decoded_results: + assert item in data.names, "All returned items should be in the set" + + # Test more members than in the set + result = self.redis.execute_command('VRANDMEMBER', self.test_key, count + 10) + assert len(result) == count, "Should return only the available members when asking for more than exist" + + # Test with duplicates (negative count) + negative_count = -20 + result = self.redis.execute_command('VRANDMEMBER', self.test_key, negative_count) + assert isinstance(result, list), "VRANDMEMBER with negative count should return an array" + assert len(result) == abs(negative_count), f"Should return {abs(negative_count)} members" + + # Check that all returned elements are valid + decoded_results = [r.decode() for r in result] + for item in decoded_results: + assert item in data.names, "All returned items should be in the set" + + # Test with count = 0 (edge case) + result = self.redis.execute_command('VRANDMEMBER', self.test_key, 0) + assert isinstance(result, list) and len(result) == 0, "VRANDMEMBER with count=0 should return empty array" diff --git a/modules/vector-sets/vset.c b/modules/vector-sets/vset.c new file mode 100644 index 000000000..c83a4a485 --- /dev/null +++ b/modules/vector-sets/vset.c @@ -0,0 +1,1974 @@ +/* Redis implementation for vector sets. The data structure itself + * is implemented in hnsw.c. + * + * Copyright(C) 2024-Present, Redis Ltd. All Rights Reserved. + * Originally authored by: Salvatore Sanfilippo. + * + * ======================== Understand threading model ========================= + * This code implements threaded operarations for two of the commands: + * + * 1. VSIM, by default. + * 2. VADD, if the CAS option is specified. + * + * Note that even if the second operation, VADD, is a write operation, only + * the neighbors collection for the new node is performed in a thread: then, + * the actual insert is performed in the reply callback VADD_CASReply(), + * which is executed in the main thread. + * + * Threaded operations need us to protect various operations with mutexes, + * even if a certain degree of protection is already provided by the HNSW + * library. Here are a few very important things about this implementation + * and the way locking is performed. + * + * 1. All the write operations are performed in the main Redis thread: + * this also include VADD_CASReply() callback, that is called by Redis + * internals only in the context of the main thread. However the HNSW + * library allows background threads in hnsw_search() (VSIM) to modify + * nodes metadata to speedup search (to understand if a node was already + * visited), but this only happens after acquiring a specific lock + * for a given "read slot". + * + * 2. We use a global lock for each Vector Set object, called "in_use". This + * lock is a read-write lock, and is acquired in read mode by all the + * threads that perform reads in the background. It is only acquired in + * write mode by vectorSetWaitAllBackgroundClients(): the function acquires + * the lock and immediately releases it, with the effect of waiting all the + * background threads still running from ending their execution. + * + * Note that no ther thread can be spawned, since we only call + * vectorSetWaitAllBackgroundClients() from the main Redis thread, that + * is also the only thread spawning other threads. + * + * vectorSetWaitAllBackgroundClients() is used in two ways: + * A) When we need to delete a vector set because of (DEL) or other + * operations destroying the object, we need to wait that all the + * background threads working with this object finished their work. + * B) When we modify the HNSW nodes bypassing the normal locking + * provided by the HNSW library. This only happens when we update + * an existing node attribute so far, in VSETATTR and when we call + * VADD to update a node with the SETATTR option. + * + * 3. Often during read operations performed by Redis commands in the + * main thread (VCARD, VEMB, VRANDMEMBER, ...) we don't acquire any + * lock at all. The commands run in the main Redis thread, we can only + * have, at the same time, background reads against the same data + * structure. Note that VSIM_thread() and VADD_thread() still modify the + * read slot metadata, that is node->visited_epoch[slot], but as long as + * our read commands running in the main thread don't need to use + * hnsw_search() or other HNSW functions using the visited epochs slots + * we are safe. + * + * 4. There is a race from the moment we create a thread, passing the + * vector set object, to the moment the thread can actually lock the + * result win the in_use_lock mutex: as the thread starts, in the meanwhile + * a DEL/expire could trigger and remove the object. For this reason + * we use an atomic counter that protects our object for this small + * time in vectorSetWaitAllBackgroundClients(). This prevents removal + * of objects that are about to be taken by threads. + * + * Note that other competing soltuions could be used to fix the problem + * but have their set of issues, however they are worth documenting here + * and evaluating in the future: + * + * A. Using a conditional variable we could "wait" for the thread to + * acquire the lock. However this means waiting before returning + * to the event loop, and would make the command execution slower. + * B. We could use again an atomic variable, like we did, but this time + * as a refcount for the object, with a vsetAcquire() vsetRelease(). + * In this case, the command could retain the object in the main thread + * before starting the thread, and the thread, after the work is done, + * could release it. This way sometimes the object would be freed by + * the thread, and it's while now can be safe to do the kind of resource + * deallocation that vectorSetReleaseObject() does, given that the + * Redis Modules API is not always thread safe this solution may not + * be future-proof. However there is to evaluate it better in the + * future. + * C. We could use the "B" solution but instead of freeing the object + * in the thread, in this specific case we could just put it into a + * list and defer it for later freeing (for instance in the reply + * callback), so that the object is always freed in the main thread. + * This would require a list of objects to free. + * + * However the current solution only disadvantage is the potential busy + * loop, but this busy loop in practical terms will almost never do + * much: to trigger it, a number of circumnstances must happen: deleting + * Vector Set keys while using them, hitting the small window needed to + * start the thread and read-lock the mutex. + */ + +#define _DEFAULT_SOURCE +#define _USE_MATH_DEFINES +#define _POSIX_C_SOURCE 200809L + +#include "redismodule.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "hnsw.h" + +// We inline directly the expression implementation here so that building +// the module is trivial. +#include "expr.c" + +static RedisModuleType *VectorSetType; +static uint64_t VectorSetTypeNextId = 0; + +// Default EF value if not specified during creation. +#define VSET_DEFAULT_C_EF 200 + +// Default EF value if not specified during search. +#define VSET_DEFAULT_SEARCH_EF 100 + +// Default num elements returned by VSIM. +#define VSET_DEFAULT_COUNT 10 + +/* ========================== Internal data structure ====================== */ + +/* Our abstract data type needs a dual representation similar to Redis + * sorted set: the proximity graph, and also a element -> graph-node map + * that will allow us to perform deletions and other operations that have + * as input the element itself. */ +struct vsetObject { + HNSW *hnsw; // Proximity graph. + RedisModuleDict *dict; // Element -> node mapping. + float *proj_matrix; // Random projection matrix, NULL if no projection + uint32_t proj_input_size; // Input dimension after projection. + // Output dimension is implicit in + // hnsw->vector_dim. + pthread_rwlock_t in_use_lock; // Lock needed to destroy the object safely. + uint64_t id; // Unique ID used by threaded VADD to know the + // object is still the same. + uint64_t numattribs; // Number of nodes associated with an attribute. + atomic_int thread_creation_pending; // Number of threads that are currently + // pending to lock the object. +}; + +/* Each node has two associated values: the associated string (the item + * in the set) and potentially a JSON string, that is, the attributes, used + * for hybrid search with the VSIM FILTER option. */ +struct vsetNodeVal { + RedisModuleString *item; + RedisModuleString *attrib; +}; + +/* Count the number of set bits in an integer (population count/Hamming weight). + * This is a portable implementation that doesn't rely on compiler + * extensions. */ +static inline uint32_t bit_count(uint32_t n) { + uint32_t count = 0; + while (n) { + count += n & 1; + n >>= 1; + } + return count; +} + +/* Create a Hadamard-based projection matrix for dimensionality reduction. + * Uses {-1, +1} entries with a pattern based on bit operations. + * The pattern is matrix[i][j] = (i & j) % 2 == 0 ? 1 : -1 + * Matrix is scaled by 1/sqrt(input_dim) for normalization. + * Returns NULL on allocation failure. + * + * Note that compared to other approaches (random gaussian weights), what + * we have here is deterministic, it means that our replicas will have + * the same set of weights. Also this approach seems to work much better + * in pratice, and the distances between elements are better guaranteed. + * + * Note that we still save the projection matrix in the RDB file, because + * in the future we may change the weights generation, and we want everything + * to be backward compatible. */ +float *createProjectionMatrix(uint32_t input_dim, uint32_t output_dim) { + float *matrix = RedisModule_Alloc(sizeof(float) * input_dim * output_dim); + + /* Scale factor to normalize the projection. */ + const float scale = 1.0f / sqrt(input_dim); + + /* Fill the matrix using Hadamard pattern. */ + for (uint32_t i = 0; i < output_dim; i++) { + for (uint32_t j = 0; j < input_dim; j++) { + /* Calculate position in the flattened matrix. */ + uint32_t pos = i * input_dim + j; + + /* Hadamard pattern: use bit operations to determine sign + * If the count of 1-bits in the bitwise AND of i and j is even, + * the value is 1, otherwise -1. */ + int value = (bit_count(i & j) % 2 == 0) ? 1 : -1; + + /* Store the scaled value. */ + matrix[pos] = value * scale; + } + } + return matrix; +} + +/* Apply random projection to input vector. Returns new allocated vector. */ +float *applyProjection(const float *input, const float *proj_matrix, + uint32_t input_dim, uint32_t output_dim) +{ + float *output = RedisModule_Alloc(sizeof(float) * output_dim); + + for (uint32_t i = 0; i < output_dim; i++) { + const float *row = &proj_matrix[i * input_dim]; + float sum = 0.0f; + for (uint32_t j = 0; j < input_dim; j++) { + sum += row[j] * input[j]; + } + output[i] = sum; + } + return output; +} + +/* Create the vector as HNSW+Dictionary combined data structure. */ +struct vsetObject *createVectorSetObject(unsigned int dim, uint32_t quant_type, uint32_t hnsw_M) { + struct vsetObject *o; + o = RedisModule_Alloc(sizeof(*o)); + + o->id = VectorSetTypeNextId++; + o->hnsw = hnsw_new(dim,quant_type,hnsw_M); + if (!o->hnsw) { // May fail because of mutex creation. + RedisModule_Free(o); + return NULL; + } + + o->dict = RedisModule_CreateDict(NULL); + o->proj_matrix = NULL; + o->proj_input_size = 0; + o->numattribs = 0; + o->thread_creation_pending = 0; + RedisModule_Assert(pthread_rwlock_init(&o->in_use_lock,NULL) == 0); + return o; +} + +void vectorSetReleaseNodeValue(void *v) { + struct vsetNodeVal *nv = v; + RedisModule_FreeString(NULL,nv->item); + if (nv->attrib) RedisModule_FreeString(NULL,nv->attrib); + RedisModule_Free(nv); +} + +/* Free the vector set object. */ +void vectorSetReleaseObject(struct vsetObject *o) { + if (!o) return; + if (o->hnsw) hnsw_free(o->hnsw,vectorSetReleaseNodeValue); + if (o->dict) RedisModule_FreeDict(NULL,o->dict); + if (o->proj_matrix) RedisModule_Free(o->proj_matrix); + pthread_rwlock_destroy(&o->in_use_lock); + RedisModule_Free(o); +} + +/* Wait for all the threads performing operations on this + * index to terminate their work (locking for write will + * wait for all the other threads). + * + * if 'for_del' is set to 1, we also wait for all the pending threads + * that still didn't acquire the lock to finish their work. This + * is useful only if we are going to call this function to delete + * the object, and not if we want to just to modify it. */ +void vectorSetWaitAllBackgroundClients(struct vsetObject *vset, int for_del) { + if (for_del) { + // If we are going to destroy the object, after this call, let's + // wait for threads that are being created and still didn't had + // a chance to acquire the lock. + while (vset->thread_creation_pending > 0); + } + RedisModule_Assert(pthread_rwlock_wrlock(&vset->in_use_lock) == 0); + pthread_rwlock_unlock(&vset->in_use_lock); +} + +/* Return a string representing the quantization type name of a vector set. */ +const char *vectorSetGetQuantName(struct vsetObject *o) { + switch(o->hnsw->quant_type) { + case HNSW_QUANT_NONE: return "f32"; + case HNSW_QUANT_Q8: return "int8"; + case HNSW_QUANT_BIN: return "bin"; + default: return "unknown"; + } +} + +/* Insert the specified element into the Vector Set. + * If update is '1', the existing node will be updated. + * + * Returns 1 if the element was added, or 0 if the element was already there + * and was just updated. */ +int vectorSetInsert(struct vsetObject *o, float *vec, int8_t *qvec, float qrange, RedisModuleString *val, RedisModuleString *attrib, int update, int ef) +{ + hnswNode *node = RedisModule_DictGet(o->dict,val,NULL); + if (node != NULL) { + if (update) { + /* Wait for clients in the background: background VSIM + * operations touch the nodes attributes we are going + * to touch. */ + vectorSetWaitAllBackgroundClients(o,0); + + struct vsetNodeVal *nv = node->value; + /* Pass NULL as value-free function. We want to reuse + * the old value. */ + hnsw_delete_node(o->hnsw, node, NULL); + node = hnsw_insert(o->hnsw,vec,qvec,qrange,0,nv,ef); + RedisModule_Assert(node != NULL); + RedisModule_DictReplace(o->dict,val,node); + + /* If attrib != NULL, the user wants that in case of an update we + * update the attribute as well (otherwise it reamins as it was). + * Note that the order of operations is conceinved so that it + * works in case the old attrib and the new attrib pointer is the + * same. */ + if (attrib) { + // Empty attribute string means: unset the attribute during + // the update. + size_t attrlen; + RedisModule_StringPtrLen(attrib,&attrlen); + if (attrlen != 0) { + RedisModule_RetainString(NULL,attrib); + o->numattribs++; + } else { + attrib = NULL; + } + + if (nv->attrib) { + o->numattribs--; + RedisModule_FreeString(NULL,nv->attrib); + } + nv->attrib = attrib; + } + } + return 0; + } + + struct vsetNodeVal *nv = RedisModule_Alloc(sizeof(*nv)); + nv->item = val; + nv->attrib = attrib; + node = hnsw_insert(o->hnsw,vec,qvec,qrange,0,nv,ef); + if (node == NULL) { + // XXX Technically in Redis-land we don't have out of memory, as we + // crash on OOM. However the HNSW library may fail for error in the + // locking libc call. Probably impossible in practical terms. + RedisModule_Free(nv); + return 0; + } + if (attrib != NULL) o->numattribs++; + RedisModule_DictSet(o->dict,val,node); + RedisModule_RetainString(NULL,val); + if (attrib) RedisModule_RetainString(NULL,attrib); + return 1; +} + +/* Parse vector from FP32 blob or VALUES format, with optional REDUCE. + * Format: [REDUCE dim] FP32|VALUES ... + * Returns allocated vector and sets dimension in *dim. + * If reduce_dim is not NULL, sets it to the requested reduction dimension. + * Returns NULL on parsing error. + * + * The function sets as a reference *consumed_args, so that the caller + * knows how many arguments we consumed in order to parse the input + * vector. Remaining arguments are often command options. */ +float *parseVector(RedisModuleString **argv, int argc, int start_idx, + size_t *dim, uint32_t *reduce_dim, int *consumed_args) +{ + int consumed = 0; // Argumnets consumed. + + /* Check for REDUCE option first. */ + if (reduce_dim) *reduce_dim = 0; + if (reduce_dim && argc > start_idx + 2 && + !strcasecmp(RedisModule_StringPtrLen(argv[start_idx],NULL),"REDUCE")) + { + long long rdim; + if (RedisModule_StringToLongLong(argv[start_idx+1],&rdim) + != REDISMODULE_OK || rdim <= 0) + { + return NULL; + } + if (reduce_dim) *reduce_dim = rdim; + start_idx += 2; // Skip REDUCE and its argument. + consumed += 2; + } + + /* Now parse the vector format as before. */ + float *vec = NULL; + const char *vec_format = RedisModule_StringPtrLen(argv[start_idx],NULL); + + if (!strcasecmp(vec_format,"FP32")) { + if (argc < start_idx + 2) return NULL; // Need FP32 + vector + value. + size_t vec_raw_len; + const char *blob = + RedisModule_StringPtrLen(argv[start_idx+1],&vec_raw_len); + + // Must be 4 bytes per component. + if (vec_raw_len % 4 || vec_raw_len < 4) return NULL; + *dim = vec_raw_len/4; + + vec = RedisModule_Alloc(vec_raw_len); + if (!vec) return NULL; + memcpy(vec,blob,vec_raw_len); + consumed += 2; + } else if (!strcasecmp(vec_format,"VALUES")) { + if (argc < start_idx + 2) return NULL; // Need at least the dimension. + long long vdim; // Vector dimension passed by the user. + if (RedisModule_StringToLongLong(argv[start_idx+1],&vdim) + != REDISMODULE_OK || vdim < 1) return NULL; + + // Check that all the arguments are available. + if (argc < start_idx + 2 + vdim) return NULL; + + *dim = vdim; + vec = RedisModule_Alloc(sizeof(float) * vdim); + if (!vec) return NULL; + + for (int j = 0; j < vdim; j++) { + double val; + if (RedisModule_StringToDouble(argv[start_idx+2+j],&val) + != REDISMODULE_OK) + { + RedisModule_Free(vec); + return NULL; + } + vec[j] = val; + } + consumed += vdim + 2; + } else { + return NULL; // Unknown format. + } + + if (consumed_args) *consumed_args = consumed; + return vec; +} + +/* ========================== Commands implementation ======================= */ + +/* VADD thread handling the "CAS" version of the command, that is + * performed blocking the client, accumulating here, in the thread, the + * set of potential candidates, and later inserting the element in the + * key (if it still exists, and if it is still the *same* vector set) + * in the Reply callback. */ +void *VADD_thread(void *arg) { + pthread_detach(pthread_self()); + + void **targ = (void**)arg; + RedisModuleBlockedClient *bc = targ[0]; + struct vsetObject *vset = targ[1]; + float *vec = targ[3]; + int ef = (uint64_t)targ[6]; + + /* Lock the object and signal that we are no longer pending + * the lock acquisition. */ + RedisModule_Assert(pthread_rwlock_rdlock(&vset->in_use_lock) == 0); + vset->thread_creation_pending--; + + /* Look for candidates... */ + InsertContext *ic = hnsw_prepare_insert(vset->hnsw, vec, NULL, 0, 0, ef); + targ[5] = ic; // Pass the context to the reply callback. + + /* Unblock the client so that our read reply will be invoked. */ + pthread_rwlock_unlock(&vset->in_use_lock); + RedisModule_BlockedClientMeasureTimeEnd(bc); + RedisModule_UnblockClient(bc,targ); // Use targ as privdata. + return NULL; +} + +/* Reply callback for CAS variant of VADD. + * Note: this is called in the main thread, in the background thread + * we just do the read operation of gathering the neighbors. */ +int VADD_CASReply(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) { + (void)argc; + RedisModule_AutoMemory(ctx); /* Use automatic memory management. */ + + int retval = REDISMODULE_OK; + void **targ = (void**)RedisModule_GetBlockedClientPrivateData(ctx); + uint64_t vset_id = (unsigned long) targ[2]; + float *vec = targ[3]; + RedisModuleString *val = targ[4]; + InsertContext *ic = targ[5]; + int ef = (uint64_t)targ[6]; + RedisModuleString *attrib = targ[7]; + RedisModule_Free(targ); + + /* Open the key: there are no guarantees it still exists, or contains + * a vector set, or even the SAME vector set. */ + RedisModuleKey *key = RedisModule_OpenKey(ctx,argv[1], + REDISMODULE_READ|REDISMODULE_WRITE); + int type = RedisModule_KeyType(key); + struct vsetObject *vset = NULL; + + if (type != REDISMODULE_KEYTYPE_EMPTY && + RedisModule_ModuleTypeGetType(key) == VectorSetType) + { + vset = RedisModule_ModuleTypeGetValue(key); + // Same vector set? + if (vset->id != vset_id) vset = NULL; + + /* Also, if the element was already inserted, we just pretend + * the other insert won. We don't even start a threaded VADD + * if this was an udpate, since the deletion of the element itself + * in order to perform the update would invalidate the CAS state. */ + if (vset && RedisModule_DictGet(vset->dict,val,NULL) != NULL) + vset = NULL; + } + + if (vset == NULL) { + /* If the object does not match the start of the operation, we + * just pretend the VADD was performed BEFORE the key was deleted + * or replaced. We return success but don't do anything. */ + hnsw_free_insert_context(ic); + } else { + /* Otherwise try to insert the new element with the neighbors + * collected in background. If we fail, do it synchronously again + * from scratch. */ + + // First: allocate the dual-ported value for the node. + struct vsetNodeVal *nv = RedisModule_Alloc(sizeof(*nv)); + nv->item = val; + nv->attrib = attrib; + + /* Then: insert the node in the HNSW data structure. Note that + * 'ic' could be NULL in case hnsw_prepare_insert() failed because of + * locking failure (likely impossible in practical terms). */ + hnswNode *newnode; + if (ic == NULL || + (newnode = hnsw_try_commit_insert(vset->hnsw, ic, nv)) == NULL) + { + /* If we are here, the CAS insert failed. We need to insert + * again with full locking for neighbors selection and + * actual insertion. This time we can't fail: */ + newnode = hnsw_insert(vset->hnsw, vec, NULL, 0, 0, nv, ef); + RedisModule_Assert(newnode != NULL); + } + RedisModule_DictSet(vset->dict,val,newnode); + val = NULL; // Don't free it later. + attrib = NULL; // Dont' free it later. + + RedisModule_ReplicateVerbatim(ctx); + } + + // Whatever happens is a success... :D + RedisModule_ReplyWithBool(ctx,1); + if (val) RedisModule_FreeString(ctx,val); // Not added? Free it. + if (attrib) RedisModule_FreeString(ctx,attrib); // Not added? Free it. + RedisModule_Free(vec); + return retval; +} + +/* VADD key [REDUCE dim] FP32|VALUES vector value [CAS] [NOQUANT] [BIN] [Q8] + * [M count] */ +int VADD_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) { + RedisModule_AutoMemory(ctx); /* Use automatic memory management. */ + + if (argc < 5) return RedisModule_WrongArity(ctx); + + /* Parse vector with optional REDUCE */ + size_t dim = 0; + uint32_t reduce_dim = 0; + int consumed_args; + int cas = 0; // Threaded check-and-set style insert. + long long ef = VSET_DEFAULT_C_EF; // HNSW creation time EF for new nodes. + long long hnsw_create_M = HNSW_DEFAULT_M; // HNSW creation default M value. + float *vec = parseVector(argv, argc, 2, &dim, &reduce_dim, &consumed_args); + RedisModuleString *attrib = NULL; // Attributes if passed via ATTRIB. + if (!vec) + return RedisModule_ReplyWithError(ctx,"ERR invalid vector specification"); + + /* Missing element string at the end? */ + if (argc-2-consumed_args < 1) return RedisModule_WrongArity(ctx); + + /* Parse options after the element string. */ + uint32_t quant_type = HNSW_QUANT_Q8; // Default quantization type. + + for (int j = 2 + consumed_args + 1; j < argc; j++) { + const char *opt = RedisModule_StringPtrLen(argv[j], NULL); + if (!strcasecmp(opt, "CAS")) { + cas = 1; + } else if (!strcasecmp(opt, "EF") && j+1 < argc) { + if (RedisModule_StringToLongLong(argv[j+1], &ef) + != REDISMODULE_OK || ef <= 0 || ef > 1000000) + { + RedisModule_Free(vec); + return RedisModule_ReplyWithError(ctx, "ERR invalid EF"); + } + j++; // skip argument. + } else if (!strcasecmp(opt, "M") && j+1 < argc) { + if (RedisModule_StringToLongLong(argv[j+1], &hnsw_create_M) + != REDISMODULE_OK || hnsw_create_M < HNSW_MIN_M || + hnsw_create_M > HNSW_MAX_M) + { + RedisModule_Free(vec); + return RedisModule_ReplyWithError(ctx, "ERR invalid M"); + } + j++; // skip argument. + } else if (!strcasecmp(opt, "SETATTR") && j+1 < argc) { + attrib = argv[j+1]; + j++; // skip argument. + } else if (!strcasecmp(opt, "NOQUANT")) { + quant_type = HNSW_QUANT_NONE; + } else if (!strcasecmp(opt, "BIN")) { + quant_type = HNSW_QUANT_BIN; + } else if (!strcasecmp(opt, "Q8")) { + quant_type = HNSW_QUANT_Q8; + } else { + RedisModule_Free(vec); + return RedisModule_ReplyWithError(ctx,"ERR invalid option after element"); + } + } + + /* Drop CAS if this is a replica and we are getting the command from the + * replication link: we want to add/delete items in the same order as + * the master, while with CAS the timing would be different. + * + * Also for Lua scripts and MULTI/EXEC, we want to run the command + * on the main thread. */ + if (RedisModule_GetContextFlags(ctx) & + (REDISMODULE_CTX_FLAGS_REPLICATED| + REDISMODULE_CTX_FLAGS_LUA| + REDISMODULE_CTX_FLAGS_MULTI)) + { + cas = 0; + } + + /* Open/create key */ + RedisModuleKey *key = RedisModule_OpenKey(ctx,argv[1], + REDISMODULE_READ|REDISMODULE_WRITE); + int type = RedisModule_KeyType(key); + if (type != REDISMODULE_KEYTYPE_EMPTY && + RedisModule_ModuleTypeGetType(key) != VectorSetType) + { + RedisModule_Free(vec); + return RedisModule_ReplyWithError(ctx,REDISMODULE_ERRORMSG_WRONGTYPE); + } + + /* Get the correct value argument based on format and REDUCE */ + RedisModuleString *val = argv[2 + consumed_args]; + + /* Create or get existing vector set */ + struct vsetObject *vset; + if (type == REDISMODULE_KEYTYPE_EMPTY) { + cas = 0; /* Do synchronous insert at creation, otherwise the + * key would be left empty until the threaded part + * does not return. It's also pointless to try try + * doing threaded first elemetn insertion. */ + vset = createVectorSetObject(reduce_dim ? reduce_dim : dim, quant_type, hnsw_create_M); + if (vset == NULL) { + // We can't fail for OOM in Redis, but the mutex initialization + // at least theoretically COULD fail. Likely this code path + // is not reachable in practical terms. + RedisModule_Free(vec); + return RedisModule_ReplyWithError(ctx, + "ERR unable to create a Vector Set: system resources issue?"); + } + + /* Initialize projection if requested */ + if (reduce_dim) { + vset->proj_matrix = createProjectionMatrix(dim, reduce_dim); + vset->proj_input_size = dim; + + /* Project the vector */ + float *projected = applyProjection(vec, vset->proj_matrix, + dim, reduce_dim); + RedisModule_Free(vec); + vec = projected; + } + RedisModule_ModuleTypeSetValue(key,VectorSetType,vset); + } else { + vset = RedisModule_ModuleTypeGetValue(key); + + if (vset->hnsw->quant_type != quant_type) { + RedisModule_Free(vec); + return RedisModule_ReplyWithError(ctx, + "ERR asked quantization mismatch with existing vector set"); + } + + if (vset->hnsw->M != hnsw_create_M) { + RedisModule_Free(vec); + return RedisModule_ReplyWithError(ctx, + "ERR asked M value mismatch with existing vector set"); + } + + if ((vset->proj_matrix == NULL && vset->hnsw->vector_dim != dim) || + (vset->proj_matrix && vset->hnsw->vector_dim != reduce_dim)) + { + RedisModule_Free(vec); + return RedisModule_ReplyWithErrorFormat(ctx, + "ERR Vector dimension mismatch - got %d but set has %d", + (int)dim, (int)vset->hnsw->vector_dim); + } + + /* Check REDUCE compatibility */ + if (reduce_dim) { + if (!vset->proj_matrix) { + RedisModule_Free(vec); + return RedisModule_ReplyWithError(ctx, + "ERR cannot add projection to existing set without projection"); + } + if (reduce_dim != vset->hnsw->vector_dim) { + RedisModule_Free(vec); + return RedisModule_ReplyWithError(ctx, + "ERR projection dimension mismatch with existing set"); + } + } + + /* Apply projection if needed */ + if (vset->proj_matrix) { + /* Ensure input dimension matches the projection matrix's expected input dimension */ + if (dim != vset->proj_input_size) { + RedisModule_Free(vec); + return RedisModule_ReplyWithErrorFormat(ctx, + "ERR Input dimension mismatch for projection - got %d but projection expects %d", + (int)dim, (int)vset->proj_input_size); + } + + float *projected = applyProjection(vec, vset->proj_matrix, + vset->proj_input_size, + vset->hnsw->vector_dim); + RedisModule_Free(vec); + vec = projected; + dim = vset->hnsw->vector_dim; + } + } + + /* For existing keys don't do CAS updates. For how things work now, the + * CAS state would be invalidated by the detetion before adding back. */ + if (cas && RedisModule_DictGet(vset->dict,val,NULL) != NULL) + cas = 0; + + /* Here depending on the CAS option we directly insert in a blocking + * way, or use a thread to do candidate neighbors selection and only + * later, in the reply callback, actually add the element. */ + if (cas) { + RedisModuleBlockedClient *bc = RedisModule_BlockClient(ctx,VADD_CASReply,NULL,NULL,0); + pthread_t tid; + void **targ = RedisModule_Alloc(sizeof(void*)*8); + targ[0] = bc; + targ[1] = vset; + targ[2] = (void*)(unsigned long)vset->id; + targ[3] = vec; + targ[4] = val; + targ[5] = NULL; // Used later for insertion context. + targ[6] = (void*)(unsigned long)ef; + targ[7] = attrib; + RedisModule_RetainString(ctx,val); + if (attrib) RedisModule_RetainString(ctx,attrib); + RedisModule_BlockedClientMeasureTimeStart(bc); + vset->thread_creation_pending++; + if (pthread_create(&tid,NULL,VADD_thread,targ) != 0) { + vset->thread_creation_pending--; + RedisModule_AbortBlock(bc); + RedisModule_Free(targ); + RedisModule_FreeString(ctx,val); + if (attrib) RedisModule_FreeString(ctx,attrib); + + // Fall back to synchronous insert, see later in the code. + } else { + return REDISMODULE_OK; + } + } + + /* Insert vector synchronously: we reach this place even + * if cas was true but thread creation failed. */ + int added = vectorSetInsert(vset,vec,NULL,0,val,attrib,1,ef); + RedisModule_Free(vec); + + RedisModule_ReplyWithBool(ctx,added); + if (added) RedisModule_ReplicateVerbatim(ctx); + return REDISMODULE_OK; +} + +/* HNSW callback to filter items according to a predicate function + * (our FILTER expression in this case). */ +int vectorSetFilterCallback(void *value, void *privdata) { + exprstate *expr = privdata; + struct vsetNodeVal *nv = value; + if (nv->attrib == NULL) return 0; // No attributes? No match. + size_t json_len; + char *json = (char*)RedisModule_StringPtrLen(nv->attrib,&json_len); + return exprRun(expr,json,json_len); +} + +/* Common path for the execution of the VSIM command both threaded and + * not threaded. Note that 'ctx' may be normal context of a thread safe + * context obtained from a blocked client. The locking that is specific + * to the vset object is handled by the caller, however the function + * handles the HNSW locking explicitly. */ +void VSIM_execute(RedisModuleCtx *ctx, struct vsetObject *vset, + float *vec, unsigned long count, float epsilon, unsigned long withscores, + unsigned long ef, exprstate *filter_expr, unsigned long filter_ef, + int ground_truth) +{ + /* In our scan, we can't just collect 'count' elements as + * if count is small we would explore the graph in an insufficient + * way to provide enough recall. + * + * If the user didn't asked for a specific exploration, we use + * VSET_DEFAULT_SEARCH_EF as minimum, or we match count if count + * is greater than that. Otherwise the minumim will be the specified + * EF argument. */ + if (ef == 0) ef = VSET_DEFAULT_SEARCH_EF; + if (count > ef) ef = count; + + /* Perform search */ + hnswNode **neighbors = RedisModule_Alloc(sizeof(hnswNode*)*ef); + float *distances = RedisModule_Alloc(sizeof(float)*ef); + int slot = hnsw_acquire_read_slot(vset->hnsw); + unsigned int found; + if (ground_truth) { + found = hnsw_ground_truth_with_filter(vset->hnsw, vec, ef, neighbors, + distances, slot, 0, + filter_expr ? vectorSetFilterCallback : NULL, + filter_expr); + } else { + if (filter_expr == NULL) { + found = hnsw_search(vset->hnsw, vec, ef, neighbors, + distances, slot, 0); + } else { + found = hnsw_search_with_filter(vset->hnsw, vec, ef, neighbors, + distances, slot, 0, vectorSetFilterCallback, + filter_expr, filter_ef); + } + } + + /* Return results */ + if (withscores) + RedisModule_ReplyWithMap(ctx, REDISMODULE_POSTPONED_LEN); + else + RedisModule_ReplyWithArray(ctx, REDISMODULE_POSTPONED_LEN); + long long arraylen = 0; + + for (unsigned int i = 0; i < found && i < count; i++) { + if (distances[i] > epsilon) break; + struct vsetNodeVal *nv = neighbors[i]->value; + RedisModule_ReplyWithString(ctx, nv->item); + arraylen++; + if (withscores) { + /* The similarity score is provided in a 0-1 range. */ + RedisModule_ReplyWithDouble(ctx, 1.0 - distances[i]/2.0); + } + } + hnsw_release_read_slot(vset->hnsw,slot); + + if (withscores) + RedisModule_ReplySetMapLength(ctx, arraylen); + else + RedisModule_ReplySetArrayLength(ctx, arraylen); + + RedisModule_Free(vec); + RedisModule_Free(neighbors); + RedisModule_Free(distances); + if (filter_expr) exprFree(filter_expr); +} + +/* VSIM thread handling the blocked client request. */ +void *VSIM_thread(void *arg) { + pthread_detach(pthread_self()); + + // Extract arguments. + void **targ = (void**)arg; + RedisModuleBlockedClient *bc = targ[0]; + struct vsetObject *vset = targ[1]; + float *vec = targ[2]; + unsigned long count = (unsigned long)targ[3]; + float epsilon = *((float*)targ[4]); + unsigned long withscores = (unsigned long)targ[5]; + unsigned long ef = (unsigned long)targ[6]; + exprstate *filter_expr = targ[7]; + unsigned long filter_ef = (unsigned long)targ[8]; + unsigned long ground_truth = (unsigned long)targ[9]; + RedisModule_Free(targ[4]); + RedisModule_Free(targ); + + /* Lock the object and signal that we are no longer pending + * the lock acquisition. */ + RedisModule_Assert(pthread_rwlock_rdlock(&vset->in_use_lock) == 0); + vset->thread_creation_pending--; + + // Accumulate reply in a thread safe context: no contention. + RedisModuleCtx *ctx = RedisModule_GetThreadSafeContext(bc); + + // Run the query. + VSIM_execute(ctx, vset, vec, count, epsilon, withscores, ef, filter_expr, filter_ef, ground_truth); + pthread_rwlock_unlock(&vset->in_use_lock); + + // Cleanup. + RedisModule_FreeThreadSafeContext(ctx); + RedisModule_BlockedClientMeasureTimeEnd(bc); + RedisModule_UnblockClient(bc,NULL); + return NULL; +} + +/* VSIM key [ELE|FP32|VALUES] [WITHSCORES] [COUNT num] [EPSILON eps] [EF exploration-factor] [FILTER expression] [FILTER-EF exploration-factor] */ +int VSIM_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) { + RedisModule_AutoMemory(ctx); + + /* Basic argument check: need at least key and vector specification + * method. */ + if (argc < 4) return RedisModule_WrongArity(ctx); + + /* Defaults */ + int withscores = 0; + long long count = VSET_DEFAULT_COUNT; /* New default value */ + long long ef = 0; /* Exploration factor (see HNSW paper) */ + double epsilon = 2.0; /* Max cosine distance */ + long long ground_truth = 0; /* Linear scan instead of HNSW search? */ + int no_thread = 0; /* NOTHREAD option: exec on main thread. */ + + /* Things computed later. */ + long long filter_ef = 0; + exprstate *filter_expr = NULL; + + /* Get key and vector type */ + RedisModuleString *key = argv[1]; + const char *vectorType = RedisModule_StringPtrLen(argv[2], NULL); + + /* Get vector set */ + RedisModuleKey *keyptr = RedisModule_OpenKey(ctx, key, REDISMODULE_READ); + int type = RedisModule_KeyType(keyptr); + if (type == REDISMODULE_KEYTYPE_EMPTY) + return RedisModule_ReplyWithEmptyArray(ctx); + + if (RedisModule_ModuleTypeGetType(keyptr) != VectorSetType) + return RedisModule_ReplyWithError(ctx, REDISMODULE_ERRORMSG_WRONGTYPE); + + struct vsetObject *vset = RedisModule_ModuleTypeGetValue(keyptr); + + /* Vector parsing stage */ + float *vec = NULL; + size_t dim = 0; + int vector_args = 0; /* Number of args consumed by vector specification */ + + if (!strcasecmp(vectorType, "ELE")) { + /* Get vector from existing element */ + RedisModuleString *ele = argv[3]; + hnswNode *node = RedisModule_DictGet(vset->dict, ele, NULL); + if (!node) { + return RedisModule_ReplyWithError(ctx, "ERR element not found in set"); + } + vec = RedisModule_Alloc(sizeof(float) * vset->hnsw->vector_dim); + hnsw_get_node_vector(vset->hnsw,node,vec); + dim = vset->hnsw->vector_dim; + vector_args = 2; /* ELE + element name */ + } else { + /* Parse vector. */ + int consumed_args; + + vec = parseVector(argv, argc, 2, &dim, NULL, &consumed_args); + if (!vec) { + return RedisModule_ReplyWithError(ctx, + "ERR invalid vector specification"); + } + vector_args = consumed_args; + + /* Apply projection if the set uses it, with the exception + * of ELE type, that will already have the right dimension. */ + if (vset->proj_matrix && dim != vset->hnsw->vector_dim) { + /* Ensure input dimension matches the projection matrix's expected input dimension */ + if (dim != vset->proj_input_size) { + RedisModule_Free(vec); + return RedisModule_ReplyWithErrorFormat(ctx, + "ERR Input dimension mismatch for projection - got %d but projection expects %d", + (int)dim, (int)vset->proj_input_size); + } + + float *projected = applyProjection(vec, vset->proj_matrix, + vset->proj_input_size, + vset->hnsw->vector_dim); + RedisModule_Free(vec); + vec = projected; + dim = vset->hnsw->vector_dim; + } + + /* Count consumed arguments */ + if (!strcasecmp(vectorType, "FP32")) { + vector_args = 2; /* FP32 + vector blob */ + } else if (!strcasecmp(vectorType, "VALUES")) { + long long vdim; + if (RedisModule_StringToLongLong(argv[3], &vdim) != REDISMODULE_OK) { + RedisModule_Free(vec); + return RedisModule_ReplyWithError(ctx, "ERR invalid vector dimension"); + } + vector_args = 2 + vdim; /* VALUES + dim + values */ + } else { + RedisModule_Free(vec); + return RedisModule_ReplyWithError(ctx, + "ERR vector type must be ELE, FP32 or VALUES"); + } + } + + /* Check vector dimension matches set */ + if (dim != vset->hnsw->vector_dim) { + RedisModule_Free(vec); + return RedisModule_ReplyWithErrorFormat(ctx, + "ERR Vector dimension mismatch - got %d but set has %d", + (int)dim, (int)vset->hnsw->vector_dim); + } + + /* Parse optional arguments - start after vector specification */ + int j = 2 + vector_args; + while (j < argc) { + const char *opt = RedisModule_StringPtrLen(argv[j], NULL); + if (!strcasecmp(opt, "WITHSCORES")) { + withscores = 1; + j++; + } else if (!strcasecmp(opt, "TRUTH")) { + ground_truth = 1; + j++; + } else if (!strcasecmp(opt, "NOTHREAD")) { + no_thread = 1; + j++; + } else if (!strcasecmp(opt, "COUNT") && j+1 < argc) { + if (RedisModule_StringToLongLong(argv[j+1], &count) + != REDISMODULE_OK || count <= 0) + { + RedisModule_Free(vec); + return RedisModule_ReplyWithError(ctx, "ERR invalid COUNT"); + } + j += 2; + } else if (!strcasecmp(opt, "EPSILON") && j+1 < argc) { + if (RedisModule_StringToDouble(argv[j+1], &epsilon) != + REDISMODULE_OK || epsilon <= 0) + { + RedisModule_Free(vec); + return RedisModule_ReplyWithError(ctx, "ERR invalid EPSILON"); + } + j += 2; + } else if (!strcasecmp(opt, "EF") && j+1 < argc) { + if (RedisModule_StringToLongLong(argv[j+1], &ef) != + REDISMODULE_OK || ef <= 0) + { + RedisModule_Free(vec); + return RedisModule_ReplyWithError(ctx, "ERR invalid EF"); + } + j += 2; + } else if (!strcasecmp(opt, "FILTER-EF") && j+1 < argc) { + if (RedisModule_StringToLongLong(argv[j+1], &filter_ef) != + REDISMODULE_OK || filter_ef <= 0) + { + RedisModule_Free(vec); + return RedisModule_ReplyWithError(ctx, "ERR invalid FILTER-EF"); + } + j += 2; + } else if (!strcasecmp(opt, "FILTER") && j+1 < argc) { + RedisModuleString *exprarg = argv[j+1]; + size_t exprlen; + char *exprstr = (char*)RedisModule_StringPtrLen(exprarg,&exprlen); + int errpos; + filter_expr = exprCompile(exprstr,&errpos); + if (filter_expr == NULL) { + if ((size_t)errpos >= exprlen) errpos = 0; + RedisModule_Free(vec); + return RedisModule_ReplyWithErrorFormat(ctx, + "ERR syntax error in FILTER expression near: %s", + exprstr+errpos); + } + j += 2; + } else { + RedisModule_Free(vec); + return RedisModule_ReplyWithError(ctx, + "ERR syntax error in VSIM command"); + } + } + + int threaded_request = 1; // Run on a thread, by default. + if (filter_ef == 0) filter_ef = count * 100; // Max filter visited nodes. + + /* Disable threaded for MULTI/EXEC and Lua, or if explicitly + * requsted by the user via the NOTHREAD option. */ + if (no_thread || (RedisModule_GetContextFlags(ctx) & + (REDISMODULE_CTX_FLAGS_LUA| + REDISMODULE_CTX_FLAGS_MULTI))) + { + threaded_request = 0; + } + + if (threaded_request) { + /* Note: even if we create one thread per request, the underlying + * HNSW library has a fixed number of slots for the threads, as it's + * defined in HNSW_MAX_THREADS (beware that if you increase it, + * every node will use more memory). This means that while this request + * is threaded, and will NOT block Redis, it may end waiting for a + * free slot if all the HNSW_MAX_THREADS slots are used. */ + RedisModuleBlockedClient *bc = RedisModule_BlockClient(ctx,NULL,NULL,NULL,0); + pthread_t tid; + void **targ = RedisModule_Alloc(sizeof(void*)*10); + targ[0] = bc; + targ[1] = vset; + targ[2] = vec; + targ[3] = (void*)count; + targ[4] = RedisModule_Alloc(sizeof(float)); + *((float*)targ[4]) = epsilon; + targ[5] = (void*)(unsigned long)withscores; + targ[6] = (void*)(unsigned long)ef; + targ[7] = (void*)filter_expr; + targ[8] = (void*)(unsigned long)filter_ef; + targ[9] = (void*)(unsigned long)ground_truth; + RedisModule_BlockedClientMeasureTimeStart(bc); + vset->thread_creation_pending++; + if (pthread_create(&tid,NULL,VSIM_thread,targ) != 0) { + vset->thread_creation_pending--; + RedisModule_AbortBlock(bc); + RedisModule_Free(targ[4]); + RedisModule_Free(targ); + VSIM_execute(ctx, vset, vec, count, epsilon, withscores, ef, filter_expr, filter_ef, ground_truth); + } + } else { + VSIM_execute(ctx, vset, vec, count, epsilon, withscores, ef, filter_expr, filter_ef, ground_truth); + } + + return REDISMODULE_OK; +} + +/* VDIM : return the dimension of vectors in the vector set. */ +int VDIM_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) { + RedisModule_AutoMemory(ctx); + + if (argc != 2) return RedisModule_WrongArity(ctx); + + RedisModuleKey *key = RedisModule_OpenKey(ctx, argv[1], REDISMODULE_READ); + int type = RedisModule_KeyType(key); + + if (type == REDISMODULE_KEYTYPE_EMPTY) + return RedisModule_ReplyWithError(ctx, "ERR key does not exist"); + + if (RedisModule_ModuleTypeGetType(key) != VectorSetType) + return RedisModule_ReplyWithError(ctx, REDISMODULE_ERRORMSG_WRONGTYPE); + + struct vsetObject *vset = RedisModule_ModuleTypeGetValue(key); + return RedisModule_ReplyWithLongLong(ctx, vset->hnsw->vector_dim); +} + +/* VCARD : return cardinality (num of elements) of the vector set. */ +int VCARD_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) { + RedisModule_AutoMemory(ctx); + + if (argc != 2) return RedisModule_WrongArity(ctx); + + RedisModuleKey *key = RedisModule_OpenKey(ctx, argv[1], REDISMODULE_READ); + int type = RedisModule_KeyType(key); + + if (type == REDISMODULE_KEYTYPE_EMPTY) + return RedisModule_ReplyWithLongLong(ctx, 0); + + if (RedisModule_ModuleTypeGetType(key) != VectorSetType) + return RedisModule_ReplyWithError(ctx, REDISMODULE_ERRORMSG_WRONGTYPE); + + struct vsetObject *vset = RedisModule_ModuleTypeGetValue(key); + return RedisModule_ReplyWithLongLong(ctx, vset->hnsw->node_count); +} + +/* VREM key element + * Remove an element from a vector set. + * Returns 1 if the element was found and removed, 0 if not found. */ +int VREM_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) { + RedisModule_AutoMemory(ctx); /* Use automatic memory management. */ + + if (argc != 3) return RedisModule_WrongArity(ctx); + + /* Get key and value */ + RedisModuleString *key = argv[1]; + RedisModuleString *element = argv[2]; + + /* Open key */ + RedisModuleKey *keyptr = RedisModule_OpenKey(ctx, key, + REDISMODULE_READ|REDISMODULE_WRITE); + int type = RedisModule_KeyType(keyptr); + + /* Handle non-existing key or wrong type */ + if (type == REDISMODULE_KEYTYPE_EMPTY) { + return RedisModule_ReplyWithBool(ctx, 0); + } + if (RedisModule_ModuleTypeGetType(keyptr) != VectorSetType) { + return RedisModule_ReplyWithError(ctx, REDISMODULE_ERRORMSG_WRONGTYPE); + } + + /* Get vector set from key */ + struct vsetObject *vset = RedisModule_ModuleTypeGetValue(keyptr); + + /* Find the node for this element */ + hnswNode *node = RedisModule_DictGet(vset->dict, element, NULL); + if (!node) { + return RedisModule_ReplyWithBool(ctx, 0); + } + + /* Remove from dictionary */ + RedisModule_DictDel(vset->dict, element, NULL); + + /* Remove from HNSW graph using the high-level API that handles + * locking and cleanup. We pass RedisModule_FreeString as the value + * free function since the strings were retained at insertion time. */ + struct vsetNodeVal *nv = node->value; + if (nv->attrib != NULL) vset->numattribs--; + RedisModule_Assert(hnsw_delete_node(vset->hnsw, node, vectorSetReleaseNodeValue) == 1); + + /* Destroy empty vector set. */ + if (RedisModule_DictSize(vset->dict) == 0) { + RedisModule_DeleteKey(keyptr); + } + + /* Reply and propagate the command */ + RedisModule_ReplyWithBool(ctx, 1); + RedisModule_ReplicateVerbatim(ctx); + return REDISMODULE_OK; +} + +/* VEMB key element + * Returns the embedding vector associated with an element, or NIL if not + * found. The vector is returned in the same format it was added, but the + * return value will have some lack of precision due to quantization and + * normalization of vectors. Also, if items were added using REDUCE, the + * reduced vector is returned instead. */ +int VEMB_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) { + RedisModule_AutoMemory(ctx); + int raw_output = 0; // RAW option. + + if (argc < 3) return RedisModule_WrongArity(ctx); + + /* Parse arguments. */ + for (int j = 3; j < argc; j++) { + const char *opt = RedisModule_StringPtrLen(argv[j], NULL); + if (!strcasecmp(opt,"raw")) { + raw_output = 1; + } else { + return RedisModule_ReplyWithError(ctx,"ERR invalid option"); + } + } + + /* Get key and element. */ + RedisModuleString *key = argv[1]; + RedisModuleString *element = argv[2]; + + /* Open key. */ + RedisModuleKey *keyptr = RedisModule_OpenKey(ctx, key, REDISMODULE_READ); + int type = RedisModule_KeyType(keyptr); + + /* Handle non-existing key and key of wrong type. */ + if (type == REDISMODULE_KEYTYPE_EMPTY) { + return RedisModule_ReplyWithNull(ctx); + } else if (RedisModule_ModuleTypeGetType(keyptr) != VectorSetType) { + return RedisModule_ReplyWithError(ctx, REDISMODULE_ERRORMSG_WRONGTYPE); + } + + /* Lookup the node about the specified element. */ + struct vsetObject *vset = RedisModule_ModuleTypeGetValue(keyptr); + hnswNode *node = RedisModule_DictGet(vset->dict, element, NULL); + if (!node) { + return RedisModule_ReplyWithNull(ctx); + } + + if (raw_output) { + int output_qrange = vset->hnsw->quant_type == HNSW_QUANT_Q8; + RedisModule_ReplyWithArray(ctx, 3+output_qrange); + RedisModule_ReplyWithSimpleString(ctx, vectorSetGetQuantName(vset)); + RedisModule_ReplyWithStringBuffer(ctx, node->vector, hnsw_quants_bytes(vset->hnsw)); + RedisModule_ReplyWithDouble(ctx, node->l2); + if (output_qrange) RedisModule_ReplyWithDouble(ctx, node->quants_range); + } else { + /* Get the vector associated with the node. */ + float *vec = RedisModule_Alloc(sizeof(float) * vset->hnsw->vector_dim); + hnsw_get_node_vector(vset->hnsw, node, vec); // May dequantize/denorm. + + /* Return as array of doubles. */ + RedisModule_ReplyWithArray(ctx, vset->hnsw->vector_dim); + for (uint32_t i = 0; i < vset->hnsw->vector_dim; i++) + RedisModule_ReplyWithDouble(ctx, vec[i]); + RedisModule_Free(vec); + } + return REDISMODULE_OK; +} + +/* VSETATTR key element json + * Set or remove the JSON attribute associated with an element. + * Setting an empty string removes the attribute. + * The command returns one if the attribute was actually updated or + * zero if there is no key or element. */ +int VSETATTR_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) { + RedisModule_AutoMemory(ctx); + + if (argc != 4) return RedisModule_WrongArity(ctx); + + RedisModuleKey *key = RedisModule_OpenKey(ctx, argv[1], + REDISMODULE_READ|REDISMODULE_WRITE); + int type = RedisModule_KeyType(key); + + if (type == REDISMODULE_KEYTYPE_EMPTY) + return RedisModule_ReplyWithBool(ctx, 0); + + if (RedisModule_ModuleTypeGetType(key) != VectorSetType) + return RedisModule_ReplyWithError(ctx, REDISMODULE_ERRORMSG_WRONGTYPE); + + struct vsetObject *vset = RedisModule_ModuleTypeGetValue(key); + hnswNode *node = RedisModule_DictGet(vset->dict, argv[2], NULL); + if (!node) + return RedisModule_ReplyWithBool(ctx, 0); + + struct vsetNodeVal *nv = node->value; + RedisModuleString *new_attr = argv[3]; + + /* Background VSIM operations use the node attributes, so + * wait for background operations before messing with them. */ + vectorSetWaitAllBackgroundClients(vset,0); + + /* Set or delete the attribute based on the fact it's an empty + * string or not. */ + size_t attrlen; + RedisModule_StringPtrLen(new_attr, &attrlen); + if (attrlen == 0) { + // If we had an attribute before, decrease the count and free it. + if (nv->attrib) { + vset->numattribs--; + RedisModule_FreeString(NULL, nv->attrib); + nv->attrib = NULL; + } + } else { + // If we didn't have an attribute before, increase the count. + // Otherwise free the old one. + if (nv->attrib) { + RedisModule_FreeString(NULL, nv->attrib); + } else { + vset->numattribs++; + } + // Set new attribute. + RedisModule_RetainString(NULL, new_attr); + nv->attrib = new_attr; + } + + RedisModule_ReplyWithBool(ctx, 1); + RedisModule_ReplicateVerbatim(ctx); + return REDISMODULE_OK; +} + +/* VGETATTR key element + * Get the JSON attribute associated with an element. + * Returns NIL if the element has no attribute or doesn't exist. */ +int VGETATTR_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) { + RedisModule_AutoMemory(ctx); + + if (argc != 3) return RedisModule_WrongArity(ctx); + + RedisModuleKey *key = RedisModule_OpenKey(ctx, argv[1], REDISMODULE_READ); + int type = RedisModule_KeyType(key); + + if (type == REDISMODULE_KEYTYPE_EMPTY) + return RedisModule_ReplyWithNull(ctx); + + if (RedisModule_ModuleTypeGetType(key) != VectorSetType) + return RedisModule_ReplyWithError(ctx, REDISMODULE_ERRORMSG_WRONGTYPE); + + struct vsetObject *vset = RedisModule_ModuleTypeGetValue(key); + hnswNode *node = RedisModule_DictGet(vset->dict, argv[2], NULL); + if (!node) + return RedisModule_ReplyWithNull(ctx); + + struct vsetNodeVal *nv = node->value; + if (!nv->attrib) + return RedisModule_ReplyWithNull(ctx); + + return RedisModule_ReplyWithString(ctx, nv->attrib); +} + +/* ============================== Reflection ================================ */ + +/* VLINKS key element [WITHSCORES] + * Returns the neighbors of an element at each layer in the HNSW graph. + * Reply is an array of arrays, where each nested array represents one level + * of neighbors, from highest level to level 0. If WITHSCORES is specified, + * each neighbor is followed by its distance from the element. */ +int VLINKS_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) { + RedisModule_AutoMemory(ctx); + + if (argc < 3 || argc > 4) return RedisModule_WrongArity(ctx); + + RedisModuleString *key = argv[1]; + RedisModuleString *element = argv[2]; + + /* Parse WITHSCORES option. */ + int withscores = 0; + if (argc == 4) { + const char *opt = RedisModule_StringPtrLen(argv[3], NULL); + if (strcasecmp(opt, "WITHSCORES") != 0) { + return RedisModule_WrongArity(ctx); + } + withscores = 1; + } + + RedisModuleKey *keyptr = RedisModule_OpenKey(ctx, key, REDISMODULE_READ); + int type = RedisModule_KeyType(keyptr); + + /* Handle non-existing key or wrong type. */ + if (type == REDISMODULE_KEYTYPE_EMPTY) + return RedisModule_ReplyWithNull(ctx); + + if (RedisModule_ModuleTypeGetType(keyptr) != VectorSetType) + return RedisModule_ReplyWithError(ctx, REDISMODULE_ERRORMSG_WRONGTYPE); + + /* Find the node for this element. */ + struct vsetObject *vset = RedisModule_ModuleTypeGetValue(keyptr); + hnswNode *node = RedisModule_DictGet(vset->dict, element, NULL); + if (!node) + return RedisModule_ReplyWithNull(ctx); + + /* Reply with array of arrays, one per level. */ + RedisModule_ReplyWithArray(ctx, node->level + 1); + + /* For each level, from highest to lowest: */ + for (int i = node->level; i >= 0; i--) { + /* Reply with array of neighbors at this level. */ + if (withscores) + RedisModule_ReplyWithMap(ctx,node->layers[i].num_links); + else + RedisModule_ReplyWithArray(ctx,node->layers[i].num_links); + + /* Add each neighbor's element value to the array. */ + for (uint32_t j = 0; j < node->layers[i].num_links; j++) { + struct vsetNodeVal *nv = node->layers[i].links[j]->value; + RedisModule_ReplyWithString(ctx, nv->item); + if (withscores) { + float distance = hnsw_distance(vset->hnsw, node, node->layers[i].links[j]); + /* Convert distance to similarity score to match + * VSIM behavior.*/ + float similarity = 1.0 - distance/2.0; + RedisModule_ReplyWithDouble(ctx, similarity); + } + } + } + return REDISMODULE_OK; +} + +/* VINFO key + * Returns information about a vector set, both visible and hidden + * features of the HNSW data structure. */ +int VINFO_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) { + RedisModule_AutoMemory(ctx); + + if (argc != 2) return RedisModule_WrongArity(ctx); + + RedisModuleKey *key = RedisModule_OpenKey(ctx, argv[1], REDISMODULE_READ); + int type = RedisModule_KeyType(key); + + if (type == REDISMODULE_KEYTYPE_EMPTY) + return RedisModule_ReplyWithNullArray(ctx); + + if (RedisModule_ModuleTypeGetType(key) != VectorSetType) + return RedisModule_ReplyWithError(ctx, REDISMODULE_ERRORMSG_WRONGTYPE); + + struct vsetObject *vset = RedisModule_ModuleTypeGetValue(key); + + /* Reply with hash */ + RedisModule_ReplyWithMap(ctx, 9); + + /* Quantization type */ + RedisModule_ReplyWithSimpleString(ctx, "quant-type"); + RedisModule_ReplyWithSimpleString(ctx, vectorSetGetQuantName(vset)); + + /* HNSW M value */ + RedisModule_ReplyWithSimpleString(ctx, "hnsw-m"); + RedisModule_ReplyWithLongLong(ctx, vset->hnsw->M); + + /* Vector dimensionality. */ + RedisModule_ReplyWithSimpleString(ctx, "vector-dim"); + RedisModule_ReplyWithLongLong(ctx, vset->hnsw->vector_dim); + + /* Original input dimension before projection. + * This is zero for vector sets without a random projection matrix. */ + RedisModule_ReplyWithSimpleString(ctx, "projection-input-dim"); + RedisModule_ReplyWithLongLong(ctx, vset->proj_input_size); + + /* Number of elements. */ + RedisModule_ReplyWithSimpleString(ctx, "size"); + RedisModule_ReplyWithLongLong(ctx, vset->hnsw->node_count); + + /* Max level of HNSW. */ + RedisModule_ReplyWithSimpleString(ctx, "max-level"); + RedisModule_ReplyWithLongLong(ctx, vset->hnsw->max_level); + + /* Number of nodes with attributes. */ + RedisModule_ReplyWithSimpleString(ctx, "attributes-count"); + RedisModule_ReplyWithLongLong(ctx, vset->numattribs); + + /* Vector set ID. */ + RedisModule_ReplyWithSimpleString(ctx, "vset-uid"); + RedisModule_ReplyWithLongLong(ctx, vset->id); + + /* HNSW max node ID. */ + RedisModule_ReplyWithSimpleString(ctx, "hnsw-max-node-uid"); + RedisModule_ReplyWithLongLong(ctx, vset->hnsw->last_id); + + return REDISMODULE_OK; +} + +/* VRANDMEMBER key [count] + * Return random members from a vector set. + * + * Without count: returns a single random member. + * With positive count: N unique random members (no duplicates). + * With negative count: N random members (with possible duplicates). + * + * If the key doesn't exist, returns NULL if count is not given, or + * an empty array if a count was given. */ +int VRANDMEMBER_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) { + RedisModule_AutoMemory(ctx); /* Use automatic memory management. */ + + /* Check arguments. */ + if (argc != 2 && argc != 3) return RedisModule_WrongArity(ctx); + + /* Parse optional count argument. */ + long long count = 1; /* Default is to return a single element. */ + int with_count = (argc == 3); + + if (with_count) { + if (RedisModule_StringToLongLong(argv[2], &count) != REDISMODULE_OK) { + return RedisModule_ReplyWithError(ctx, + "ERR COUNT value is not an integer"); + } + /* Count = 0 is a special case, return empty array */ + if (count == 0) { + return RedisModule_ReplyWithEmptyArray(ctx); + } + } + + /* Open key. */ + RedisModuleKey *key = RedisModule_OpenKey(ctx, argv[1], REDISMODULE_READ); + int type = RedisModule_KeyType(key); + + /* Handle non-existing key. */ + if (type == REDISMODULE_KEYTYPE_EMPTY) { + if (!with_count) { + return RedisModule_ReplyWithNull(ctx); + } else { + return RedisModule_ReplyWithEmptyArray(ctx); + } + } + + /* Check key type. */ + if (RedisModule_ModuleTypeGetType(key) != VectorSetType) { + return RedisModule_ReplyWithError(ctx, REDISMODULE_ERRORMSG_WRONGTYPE); + } + + /* Get vector set from key. */ + struct vsetObject *vset = RedisModule_ModuleTypeGetValue(key); + uint64_t set_size = vset->hnsw->node_count; + + /* No elements in the set? */ + if (set_size == 0) { + if (!with_count) { + return RedisModule_ReplyWithNull(ctx); + } else { + return RedisModule_ReplyWithEmptyArray(ctx); + } + } + + /* Case 1: No count specified: return a single element. */ + if (!with_count) { + hnswNode *random_node = hnsw_random_node(vset->hnsw, 0); + if (random_node) { + struct vsetNodeVal *nv = random_node->value; + return RedisModule_ReplyWithString(ctx, nv->item); + } else { + return RedisModule_ReplyWithNull(ctx); + } + } + + /* Case 2: COUNT option given, return an array of elements. */ + int allow_duplicates = (count < 0); + long long abs_count = (count < 0) ? -count : count; + + /* Cap the count to the set size if we are not allowing duplicates. */ + if (!allow_duplicates && abs_count > (long long)set_size) + abs_count = set_size; + + /* Prepare reply. */ + RedisModule_ReplyWithArray(ctx, abs_count); + + if (allow_duplicates) { + /* Simple case: With duplicates, just pick random nodes + * abs_count times. */ + for (long long i = 0; i < abs_count; i++) { + hnswNode *random_node = hnsw_random_node(vset->hnsw,0); + struct vsetNodeVal *nv = random_node->value; + RedisModule_ReplyWithString(ctx, nv->item); + } + } else { + /* Case where count is positive: we need unique elements. + * But, if the user asked for many elements, selecting so + * many (> 20%) random nodes may be too expansive: we just start + * from a random element and follow the next link. + * + * Otherwisem for the <= 20% case, a dictionary is used to + * reject duplicates. */ + int use_dict = (abs_count <= set_size * 0.2); + + if (use_dict) { + RedisModuleDict *returned = RedisModule_CreateDict(ctx); + + long long returned_count = 0; + while (returned_count < abs_count) { + hnswNode *random_node = hnsw_random_node(vset->hnsw, 0); + struct vsetNodeVal *nv = random_node->value; + + /* Check if we've already returned this element. */ + if (RedisModule_DictGet(returned, nv->item, NULL) == NULL) { + /* Mark as returned and add to results. */ + RedisModule_DictSet(returned, nv->item, (void*)1); + RedisModule_ReplyWithString(ctx, nv->item); + returned_count++; + } + } + RedisModule_FreeDict(ctx, returned); + } else { + /* For large samples, get a random starting node and walk + * the list. + * + * IMPORTANT: doing so does not really generate random + * elements: it's just a linear scan, but we have no choices. + * If we generate too many random elements, more and more would + * fail the check of being novel (not yet collected in the set + * to return) if the % of elements to emit is too large, we would + * spend too much CPU. */ + hnswNode *start_node = hnsw_random_node(vset->hnsw, 0); + hnswNode *current = start_node; + + long long returned_count = 0; + while (returned_count < abs_count) { + if (current == NULL) { + /* Restart from head if we hit the end. */ + current = vset->hnsw->head; + } + struct vsetNodeVal *nv = current->value; + RedisModule_ReplyWithString(ctx, nv->item); + returned_count++; + current = current->next; + } + } + } + return REDISMODULE_OK; +} + +/* ============================== vset type methods ========================= */ + +#define SAVE_FLAG_HAS_PROJMATRIX (1<<0) +#define SAVE_FLAG_HAS_ATTRIBS (1<<1) + +/* Save object to RDB */ +void VectorSetRdbSave(RedisModuleIO *rdb, void *value) { + struct vsetObject *vset = value; + RedisModule_SaveUnsigned(rdb, vset->hnsw->vector_dim); + RedisModule_SaveUnsigned(rdb, vset->hnsw->node_count); + + uint32_t hnsw_config = (vset->hnsw->quant_type & 0xff) | + ((vset->hnsw->M & 0xffff) << 8); + RedisModule_SaveUnsigned(rdb, hnsw_config); + + uint32_t save_flags = 0; + if (vset->proj_matrix) save_flags |= SAVE_FLAG_HAS_PROJMATRIX; + if (vset->numattribs != 0) save_flags |= SAVE_FLAG_HAS_ATTRIBS; + RedisModule_SaveUnsigned(rdb, save_flags); + + /* Save projection matrix if present */ + if (vset->proj_matrix) { + uint32_t input_dim = vset->proj_input_size; + uint32_t output_dim = vset->hnsw->vector_dim; + RedisModule_SaveUnsigned(rdb, input_dim); + // Output dim is the same as the first value saved + // above, so we don't save it. + + // Save projection matrix as binary blob + size_t matrix_size = sizeof(float) * input_dim * output_dim; + RedisModule_SaveStringBuffer(rdb, (const char *)vset->proj_matrix, matrix_size); + } + + hnswNode *node = vset->hnsw->head; + while(node) { + struct vsetNodeVal *nv = node->value; + RedisModule_SaveString(rdb, nv->item); + if (vset->numattribs) { + if (nv->attrib) + RedisModule_SaveString(rdb, nv->attrib); + else + RedisModule_SaveStringBuffer(rdb, "", 0); + } + hnswSerNode *sn = hnsw_serialize_node(vset->hnsw,node); + RedisModule_SaveStringBuffer(rdb, (const char *)sn->vector, sn->vector_size); + RedisModule_SaveUnsigned(rdb, sn->params_count); + for (uint32_t j = 0; j < sn->params_count; j++) + RedisModule_SaveUnsigned(rdb, sn->params[j]); + hnsw_free_serialized_node(sn); + node = node->next; + } +} + +/* Load object from RDB. Please note that we don't do any cleanup + * on errors, and just return NULL, as Redis will abort completely + * not just the module but the server itself in this case. */ +void *VectorSetRdbLoad(RedisModuleIO *rdb, int encver) { + if (encver != 0) return NULL; // Invalid version + + uint32_t dim = RedisModule_LoadUnsigned(rdb); + uint64_t elements = RedisModule_LoadUnsigned(rdb); + uint32_t hnsw_config = RedisModule_LoadUnsigned(rdb); + uint32_t quant_type = hnsw_config & 0xff; + uint32_t hnsw_m = (hnsw_config >> 8) & 0xffff; + + if (hnsw_m == 0) hnsw_m = 16; // Default, useful for RDB files predating + // this configuration parameter: it was fixed + // to 16. + struct vsetObject *vset = createVectorSetObject(dim,quant_type,hnsw_m); + RedisModule_Assert(vset != NULL); + + /* Load projection matrix if present */ + uint32_t save_flags = RedisModule_LoadUnsigned(rdb); + int has_projection = save_flags & SAVE_FLAG_HAS_PROJMATRIX; + int has_attribs = save_flags & SAVE_FLAG_HAS_ATTRIBS; + if (has_projection) { + uint32_t input_dim = RedisModule_LoadUnsigned(rdb); + uint32_t output_dim = dim; + size_t matrix_size = sizeof(float) * input_dim * output_dim; + + vset->proj_matrix = RedisModule_Alloc(matrix_size); + if (!vset->proj_matrix) { + vectorSetReleaseObject(vset); + return NULL; + } + vset->proj_input_size = input_dim; + + // Load projection matrix as a binary blob + char *matrix_blob = RedisModule_LoadStringBuffer(rdb, NULL); + memcpy(vset->proj_matrix, matrix_blob, matrix_size); + RedisModule_Free(matrix_blob); + } + + while(elements--) { + // Load associated string element. + RedisModuleString *ele = RedisModule_LoadString(rdb); + RedisModuleString *attrib = NULL; + if (has_attribs) { + attrib = RedisModule_LoadString(rdb); + size_t attrlen; + RedisModule_StringPtrLen(attrib,&attrlen); + if (attrlen == 0) { + RedisModule_FreeString(NULL,attrib); + attrib = NULL; + } + } + size_t vector_len; + void *vector = RedisModule_LoadStringBuffer(rdb, &vector_len); + uint32_t vector_bytes = hnsw_quants_bytes(vset->hnsw); + if (vector_len != vector_bytes) { + RedisModule_LogIOError(rdb,"warning", + "Mismatching vector dimension"); + return NULL; // Loading error. + } + + // Load node parameters back. + uint32_t params_count = RedisModule_LoadUnsigned(rdb); + uint64_t *params = RedisModule_Alloc(params_count*sizeof(uint64_t)); + for (uint32_t j = 0; j < params_count; j++) + params[j] = RedisModule_LoadUnsigned(rdb); + + struct vsetNodeVal *nv = RedisModule_Alloc(sizeof(*nv)); + nv->item = ele; + nv->attrib = attrib; + hnswNode *node = hnsw_insert_serialized(vset->hnsw, vector, params, params_count, nv); + if (node == NULL) { + RedisModule_LogIOError(rdb,"warning", + "Vector set node index loading error"); + return NULL; // Loading error. + } + if (nv->attrib) vset->numattribs++; + RedisModule_DictSet(vset->dict,ele,node); + RedisModule_Free(vector); + RedisModule_Free(params); + } + hnsw_deserialize_index(vset->hnsw); + return vset; +} + +/* Calculate memory usage */ +size_t VectorSetMemUsage(const void *value) { + const struct vsetObject *vset = value; + size_t size = sizeof(*vset); + + /* Account for HNSW index base structure */ + size += sizeof(HNSW); + + /* Account for projection matrix if present */ + if (vset->proj_matrix) { + /* For the matrix size, we need the input dimension. We can get it + * from the first node if the set is not empty. */ + uint32_t input_dim = vset->proj_input_size; + uint32_t output_dim = vset->hnsw->vector_dim; + size += sizeof(float) * input_dim * output_dim; + } + + /* Account for each node's memory usage. */ + hnswNode *node = vset->hnsw->head; + if (node == NULL) return size; + + /* Base node structure. */ + size += sizeof(*node) * vset->hnsw->node_count; + + /* Vector storage. */ + uint64_t vec_storage = hnsw_quants_bytes(vset->hnsw); + size += vec_storage * vset->hnsw->node_count; + + /* Layers array. We use 1.33 as average nodes layers count. */ + uint64_t layers_storage = sizeof(hnswNodeLayer) * vset->hnsw->node_count; + layers_storage = layers_storage * 4 / 3; // 1.33 times. + size += layers_storage; + + /* All the nodes have layer 0 links. */ + uint64_t level0_links = node->layers[0].max_links; + uint64_t other_levels_links = level0_links/2; + size += sizeof(hnswNode*) * level0_links * vset->hnsw->node_count; + + /* Add the 0.33 remaining part, but upper layers have less links. */ + size += (sizeof(hnswNode*) * other_levels_links * vset->hnsw->node_count)/3; + + /* Associated string value and attributres. + * Use Redis Module API to get string size, and guess that all the + * elements have similar size as the first few. */ + size_t items_scanned = 0, items_size = 0; + size_t attribs_scanned = 0, attribs_size = 0; + int scan_effort = 20; + while(scan_effort > 0 && node) { + struct vsetNodeVal *nv = node->value; + items_size += RedisModule_MallocSizeString(nv->item); + items_scanned++; + if (nv->attrib) { + attribs_size += RedisModule_MallocSizeString(nv->attrib); + attribs_scanned++; + } + scan_effort--; + node = node->next; + } + + /* Add the memory usage due to items. */ + if (items_scanned) + size += items_size / items_scanned * vset->hnsw->node_count; + + /* Add memory usage due to attributres. */ + if (attribs_scanned == 0) { + /* We were not lucky enough to find a single attribute in the + * first few items? Let's use a fixed arbitrary value. */ + attribs_scanned = 1; + attribs_size = 64; + } + size += attribs_size / attribs_scanned * vset->numattribs; + + /* Account for dictionary overhead - this is an approximation. */ + size += RedisModule_DictSize(vset->dict) * (sizeof(void*) * 2); + + return size; +} + +/* Free the entire data structure */ +void VectorSetFree(void *value) { + struct vsetObject *vset = value; + + vectorSetWaitAllBackgroundClients(vset,1); + vectorSetReleaseObject(value); +} + +/* Add object digest to the digest context */ +void VectorSetDigest(RedisModuleDigest *md, void *value) { + struct vsetObject *vset = value; + + /* Add consistent order-independent hash of all vectors */ + hnswNode *node = vset->hnsw->head; + + /* Hash the vector dimension and number of nodes. */ + RedisModule_DigestAddLongLong(md, vset->hnsw->node_count); + RedisModule_DigestAddLongLong(md, vset->hnsw->vector_dim); + RedisModule_DigestEndSequence(md); + + while(node) { + struct vsetNodeVal *nv = node->value; + /* Hash each vector component */ + RedisModule_DigestAddStringBuffer(md, node->vector, hnsw_quants_bytes(vset->hnsw)); + /* Hash the associated value */ + size_t len; + const char *str = RedisModule_StringPtrLen(nv->item, &len); + RedisModule_DigestAddStringBuffer(md, (char*)str, len); + if (nv->attrib) { + str = RedisModule_StringPtrLen(nv->attrib, &len); + RedisModule_DigestAddStringBuffer(md, (char*)str, len); + } + node = node->next; + RedisModule_DigestEndSequence(md); + } +} + +/* This function must be present on each Redis module. It is used in order to + * register the commands into the Redis server. */ +#ifdef MERGED_REDIS_MODULE +#define ONLOAD VectorSets_OnLoad +#else +#define ONLOAD RedisModule_OnLoad +#endif +int ONLOAD(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) { + REDISMODULE_NOT_USED(argv); + REDISMODULE_NOT_USED(argc); + + if (RedisModule_Init(ctx,"vectorset",1,REDISMODULE_APIVER_1) + == REDISMODULE_ERR) return REDISMODULE_ERR; + + RedisModuleTypeMethods tm = { + .version = REDISMODULE_TYPE_METHOD_VERSION, + .rdb_load = VectorSetRdbLoad, + .rdb_save = VectorSetRdbSave, + .aof_rewrite = NULL, + .mem_usage = VectorSetMemUsage, + .free = VectorSetFree, + .digest = VectorSetDigest + }; + + VectorSetType = RedisModule_CreateDataType(ctx,"vectorset",0,&tm); + if (VectorSetType == NULL) return REDISMODULE_ERR; + + if (RedisModule_CreateCommand(ctx,"VADD", + VADD_RedisCommand,"write deny-oom",1,1,1) == REDISMODULE_ERR) + return REDISMODULE_ERR; + + if (RedisModule_CreateCommand(ctx,"VREM", + VREM_RedisCommand,"write",1,1,1) == REDISMODULE_ERR) + return REDISMODULE_ERR; + + if (RedisModule_CreateCommand(ctx,"VSIM", + VSIM_RedisCommand,"readonly",1,1,1) == REDISMODULE_ERR) + return REDISMODULE_ERR; + + if (RedisModule_CreateCommand(ctx, "VDIM", + VDIM_RedisCommand, "readonly fast", 1, 1, 1) == REDISMODULE_ERR) + return REDISMODULE_ERR; + + if (RedisModule_CreateCommand(ctx, "VCARD", + VCARD_RedisCommand, "readonly fast", 1, 1, 1) == REDISMODULE_ERR) + return REDISMODULE_ERR; + + if (RedisModule_CreateCommand(ctx, "VEMB", + VEMB_RedisCommand, "readonly fast", 1, 1, 1) == REDISMODULE_ERR) + return REDISMODULE_ERR; + + if (RedisModule_CreateCommand(ctx, "VLINKS", + VLINKS_RedisCommand, "readonly fast", 1, 1, 1) == REDISMODULE_ERR) + return REDISMODULE_ERR; + + if (RedisModule_CreateCommand(ctx, "VINFO", + VINFO_RedisCommand, "readonly fast", 1, 1, 1) == REDISMODULE_ERR) + return REDISMODULE_ERR; + + if (RedisModule_CreateCommand(ctx, "VSETATTR", + VSETATTR_RedisCommand, "write fast", 1, 1, 1) == REDISMODULE_ERR) + return REDISMODULE_ERR; + + if (RedisModule_CreateCommand(ctx, "VGETATTR", + VGETATTR_RedisCommand, "readonly fast", 1, 1, 1) == REDISMODULE_ERR) + return REDISMODULE_ERR; + + if (RedisModule_CreateCommand(ctx, "VRANDMEMBER", + VRANDMEMBER_RedisCommand, "readonly", 1, 1, 1) == REDISMODULE_ERR) + return REDISMODULE_ERR; + + hnsw_set_allocator(RedisModule_Free, RedisModule_Alloc, + RedisModule_Realloc); + + return REDISMODULE_OK; +} diff --git a/modules/vector-sets/w2v.c b/modules/vector-sets/w2v.c new file mode 100644 index 000000000..8d7614d2e --- /dev/null +++ b/modules/vector-sets/w2v.c @@ -0,0 +1,510 @@ +/* + * HNSW (Hierarchical Navigable Small World) Implementation + * Based on the paper by Yu. A. Malkov, D. A. Yashunin + * + * Copyright(C) 2024-Present, Redis Ltd. All Rights Reserved. + * Originally authored by: Salvatore Sanfilippo + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "hnsw.h" + +/* Get current time in milliseconds */ +uint64_t ms_time(void) { + struct timeval tv; + gettimeofday(&tv, NULL); + return (uint64_t)tv.tv_sec * 1000 + (tv.tv_usec / 1000); +} + +/* Implementation of the recall test with random vectors. */ +void test_recall(HNSW *index, int ef) { + const int num_test_vectors = 10000; + const int k = 100; // Number of nearest neighbors to find. + if (ef < k) ef = k; + + // Add recall distribution counters (2% bins from 0-100%). + int recall_bins[50] = {0}; + + // Create array to store vectors for mixing. + int num_source_vectors = 1000; // Enough, since we mix them. + float **source_vectors = malloc(sizeof(float*) * num_source_vectors); + if (!source_vectors) { + printf("Failed to allocate memory for source vectors\n"); + return; + } + + // Allocate memory for each source vector. + for (int i = 0; i < num_source_vectors; i++) { + source_vectors[i] = malloc(sizeof(float) * 300); + if (!source_vectors[i]) { + printf("Failed to allocate memory for source vector %d\n", i); + // Clean up already allocated vectors. + for (int j = 0; j < i; j++) free(source_vectors[j]); + free(source_vectors); + return; + } + } + + /* Populate source vectors from the index, we just scan the + * first N items. */ + int source_count = 0; + hnswNode *current = index->head; + while (current && source_count < num_source_vectors) { + hnsw_get_node_vector(index, current, source_vectors[source_count]); + source_count++; + current = current->next; + } + + if (source_count < num_source_vectors) { + printf("Warning: Only found %d nodes for source vectors\n", + source_count); + num_source_vectors = source_count; + } + + // Allocate memory for test vector. + float *test_vector = malloc(sizeof(float) * 300); + if (!test_vector) { + printf("Failed to allocate memory for test vector\n"); + for (int i = 0; i < num_source_vectors; i++) { + free(source_vectors[i]); + } + free(source_vectors); + return; + } + + // Allocate memory for results. + hnswNode **hnsw_results = malloc(sizeof(hnswNode*) * ef); + hnswNode **linear_results = malloc(sizeof(hnswNode*) * ef); + float *hnsw_distances = malloc(sizeof(float) * ef); + float *linear_distances = malloc(sizeof(float) * ef); + + if (!hnsw_results || !linear_results || !hnsw_distances || !linear_distances) { + printf("Failed to allocate memory for results\n"); + if (hnsw_results) free(hnsw_results); + if (linear_results) free(linear_results); + if (hnsw_distances) free(hnsw_distances); + if (linear_distances) free(linear_distances); + for (int i = 0; i < num_source_vectors; i++) free(source_vectors[i]); + free(source_vectors); + free(test_vector); + return; + } + + // Initialize random seed. + srand(time(NULL)); + + // Perform recall test. + printf("\nPerforming recall test with EF=%d on %d random vectors...\n", + ef, num_test_vectors); + double total_recall = 0.0; + + for (int t = 0; t < num_test_vectors; t++) { + // Create a random vector by mixing 3 existing vectors. + float weights[3] = {0.0}; + int src_indices[3] = {0}; + + // Generate random weights. + float weight_sum = 0.0; + for (int i = 0; i < 3; i++) { + weights[i] = (float)rand() / RAND_MAX; + weight_sum += weights[i]; + src_indices[i] = rand() % num_source_vectors; + } + + // Normalize weights. + for (int i = 0; i < 3; i++) weights[i] /= weight_sum; + + // Mix vectors. + memset(test_vector, 0, sizeof(float) * 300); + for (int i = 0; i < 3; i++) { + for (int j = 0; j < 300; j++) { + test_vector[j] += + weights[i] * source_vectors[src_indices[i]][j]; + } + } + + // Perform HNSW search with the specified EF parameter. + int slot = hnsw_acquire_read_slot(index); + int hnsw_found = hnsw_search(index, test_vector, ef, hnsw_results, hnsw_distances, slot, 0); + + // Perform linear search (ground truth). + int linear_found = hnsw_ground_truth_with_filter(index, test_vector, ef, linear_results, linear_distances, slot, 0, NULL, NULL); + hnsw_release_read_slot(index, slot); + + // Calculate recall for this query (intersection size / k). + if (hnsw_found > k) hnsw_found = k; + if (linear_found > k) linear_found = k; + int intersection_count = 0; + for (int i = 0; i < linear_found; i++) { + for (int j = 0; j < hnsw_found; j++) { + if (linear_results[i] == hnsw_results[j]) { + intersection_count++; + break; + } + } + } + + double recall = (double)intersection_count / linear_found; + total_recall += recall; + + // Add to distribution bins (2% steps) + int bin_index = (int)(recall * 50); + if (bin_index >= 50) bin_index = 49; // Handle 100% recall case + recall_bins[bin_index]++; + + // Show progress. + if ((t+1) % 1000 == 0 || t == num_test_vectors-1) { + printf("Processed %d/%d queries, current avg recall: %.2f%%\n", + t+1, num_test_vectors, (total_recall / (t+1)) * 100); + } + } + + // Calculate and print final average recall. + double avg_recall = (total_recall / num_test_vectors) * 100; + printf("\nRecall Test Results:\n"); + printf("Average recall@%d (EF=%d): %.2f%%\n", k, ef, avg_recall); + + // Print recall distribution histogram. + printf("\nRecall Distribution (2%% bins):\n"); + printf("================================\n"); + + // Find the maximum bin count for scaling. + int max_count = 0; + for (int i = 0; i < 50; i++) { + if (recall_bins[i] > max_count) max_count = recall_bins[i]; + } + + // Scale factor for histogram (max 50 chars wide) + const int max_bars = 50; + double scale = (max_count > max_bars) ? (double)max_bars / max_count : 1.0; + + // Print the histogram. + for (int i = 0; i < 50; i++) { + int bar_len = (int)(recall_bins[i] * scale); + printf("%3d%%-%-3d%% | %-6d |", i*2, (i+1)*2, recall_bins[i]); + for (int j = 0; j < bar_len; j++) printf("#"); + printf("\n"); + } + + // Cleanup. + free(hnsw_results); + free(linear_results); + free(hnsw_distances); + free(linear_distances); + free(test_vector); + for (int i = 0; i < num_source_vectors; i++) free(source_vectors[i]); + free(source_vectors); +} + +/* Example usage in main() */ +int w2v_single_thread(int m_param, int quantization, uint64_t numele, int massdel, int self_recall, int recall_ef) { + /* Create index */ + HNSW *index = hnsw_new(300, quantization, m_param); + float v[300]; + uint16_t wlen; + + FILE *fp = fopen("word2vec.bin","rb"); + if (fp == NULL) { + perror("word2vec.bin file missing"); + exit(1); + } + unsigned char header[8]; + fread(header,8,1,fp); // Skip header + + uint64_t id = 0; + uint64_t start_time = ms_time(); + char *word = NULL; + hnswNode *search_node = NULL; + + while(id < numele) { + if (fread(&wlen,2,1,fp) == 0) break; + word = malloc(wlen+1); + fread(word,wlen,1,fp); + word[wlen] = 0; + fread(v,300*sizeof(float),1,fp); + + // Plain API that acquires a write lock for the whole time. + hnswNode *added = hnsw_insert(index, v, NULL, 0, id++, word, 200); + + if (!strcmp(word,"banana")) search_node = added; + if (!(id % 10000)) printf("%llu added\n", (unsigned long long)id); + } + uint64_t elapsed = ms_time() - start_time; + fclose(fp); + + printf("%llu words added (%llu words/sec), last word: %s\n", + (unsigned long long)index->node_count, + (unsigned long long)id*1000/elapsed, word); + + /* Search query */ + if (search_node == NULL) search_node = index->head; + hnsw_get_node_vector(index,search_node,v); + hnswNode *neighbors[10]; + float distances[10]; + + int found, j; + start_time = ms_time(); + for (j = 0; j < 20000; j++) + found = hnsw_search(index, v, 10, neighbors, distances, 0, 0); + elapsed = ms_time() - start_time; + printf("%d searches performed (%llu searches/sec), nodes found: %d\n", + j, (unsigned long long)j*1000/elapsed, found); + + if (found > 0) { + printf("Found %d neighbors:\n", found); + for (int i = 0; i < found; i++) { + printf("Node ID: %llu, distance: %f, word: %s\n", + (unsigned long long)neighbors[i]->id, + distances[i], (char*)neighbors[i]->value); + } + } + + // Self-recall test (ability to find the node by its own vector). + if (self_recall) { + hnsw_print_stats(index); + hnsw_test_graph_recall(index,200,0); + } + + // Recall test with random vectors. + if (recall_ef > 0) { + test_recall(index, recall_ef); + } + + uint64_t connected_nodes; + int reciprocal_links; + hnsw_validate_graph(index, &connected_nodes, &reciprocal_links); + + if (massdel) { + int remove_perc = 95; + printf("\nRemoving %d%% of nodes...\n", remove_perc); + uint64_t initial_nodes = index->node_count; + + hnswNode *current = index->head; + while (current && index->node_count > initial_nodes*(100-remove_perc)/100) { + hnswNode *next = current->next; + hnsw_delete_node(index,current,free); + current = next; + // In order to don't remove only contiguous nodes, from time + // skip a node. + if (current && !(random() % remove_perc)) current = current->next; + } + printf("%llu nodes left\n", (unsigned long long)index->node_count); + + // Test again. + hnsw_validate_graph(index, &connected_nodes, &reciprocal_links); + hnsw_test_graph_recall(index,200,0); + } + + hnsw_free(index,free); + return 0; +} + +struct threadContext { + pthread_mutex_t FileAccessMutex; + uint64_t numele; + _Atomic uint64_t SearchesDone; + _Atomic uint64_t id; + FILE *fp; + HNSW *index; + float *search_vector; +}; + +// Note that in practical terms inserting with many concurrent threads +// may be *slower* and not faster, because there is a lot of +// contention. So this is more a robustness test than anything else. +// +// The optimistic commit API goal is actually to exploit the ability to +// add faster when there are many concurrent reads. +void *threaded_insert(void *ctxptr) { + struct threadContext *ctx = ctxptr; + char *word; + float v[300]; + uint16_t wlen; + + while(1) { + pthread_mutex_lock(&ctx->FileAccessMutex); + if (fread(&wlen,2,1,ctx->fp) == 0) break; + pthread_mutex_unlock(&ctx->FileAccessMutex); + word = malloc(wlen+1); + fread(word,wlen,1,ctx->fp); + word[wlen] = 0; + fread(v,300*sizeof(float),1,ctx->fp); + + // Check-and-set API that performs the costly scan for similar + // nodes concurrently with other read threads, and finally + // applies the check if the graph wasn't modified. + InsertContext *ic; + uint64_t next_id = ctx->id++; + ic = hnsw_prepare_insert(ctx->index, v, NULL, 0, next_id, 200); + if (hnsw_try_commit_insert(ctx->index, ic, word) == NULL) { + // This time try locking since the start. + hnsw_insert(ctx->index, v, NULL, 0, next_id, word, 200); + } + + if (next_id >= ctx->numele) break; + if (!((next_id+1) % 10000)) + printf("%llu added\n", (unsigned long long)next_id+1); + } + return NULL; +} + +void *threaded_search(void *ctxptr) { + struct threadContext *ctx = ctxptr; + + /* Search query */ + hnswNode *neighbors[10]; + float distances[10]; + int found = 0; + uint64_t last_id = 0; + + while(ctx->id < 1000000) { + int slot = hnsw_acquire_read_slot(ctx->index); + found = hnsw_search(ctx->index, ctx->search_vector, 10, neighbors, distances, slot, 0); + hnsw_release_read_slot(ctx->index,slot); + last_id = ++ctx->id; + } + + if (found > 0 && last_id == 1000000) { + printf("Found %d neighbors:\n", found); + for (int i = 0; i < found; i++) { + printf("Node ID: %llu, distance: %f, word: %s\n", + (unsigned long long)neighbors[i]->id, + distances[i], (char*)neighbors[i]->value); + } + } + return NULL; +} + +int w2v_multi_thread(int m_param, int numthreads, int quantization, uint64_t numele) { + /* Create index */ + struct threadContext ctx; + + ctx.index = hnsw_new(300, quantization, m_param); + + ctx.fp = fopen("word2vec.bin","rb"); + if (ctx.fp == NULL) { + perror("word2vec.bin file missing"); + exit(1); + } + + unsigned char header[8]; + fread(header,8,1,ctx.fp); // Skip header + pthread_mutex_init(&ctx.FileAccessMutex,NULL); + + uint64_t start_time = ms_time(); + ctx.id = 0; + ctx.numele = numele; + pthread_t threads[numthreads]; + for (int j = 0; j < numthreads; j++) + pthread_create(&threads[j], NULL, threaded_insert, &ctx); + + // Wait for all the threads to terminate adding items. + for (int j = 0; j < numthreads; j++) + pthread_join(threads[j],NULL); + + uint64_t elapsed = ms_time() - start_time; + fclose(ctx.fp); + + // Obtain the last word. + hnswNode *node = ctx.index->head; + char *word = node->value; + + // We will search this last inserted word in the next test. + // Let's save its embedding. + ctx.search_vector = malloc(sizeof(float)*300); + hnsw_get_node_vector(ctx.index,node,ctx.search_vector); + + printf("%llu words added (%llu words/sec), last word: %s\n", + (unsigned long long)ctx.index->node_count, + (unsigned long long)ctx.id*1000/elapsed, word); + + /* Search query */ + start_time = ms_time(); + ctx.id = 0; // We will use this atomic field to stop at N queries done. + + for (int j = 0; j < numthreads; j++) + pthread_create(&threads[j], NULL, threaded_search, &ctx); + + // Wait for all the threads to terminate searching. + for (int j = 0; j < numthreads; j++) + pthread_join(threads[j],NULL); + + elapsed = ms_time() - start_time; + printf("%llu searches performed (%llu searches/sec)\n", + (unsigned long long)ctx.id, + (unsigned long long)ctx.id*1000/elapsed); + + hnsw_print_stats(ctx.index); + uint64_t connected_nodes; + int reciprocal_links; + hnsw_validate_graph(ctx.index, &connected_nodes, &reciprocal_links); + printf("%llu connected nodes. Links all reciprocal: %d\n", + (unsigned long long)connected_nodes, reciprocal_links); + hnsw_free(ctx.index,free); + return 0; +} + +int main(int argc, char **argv) { + int quantization = HNSW_QUANT_NONE; + int numthreads = 0; + uint64_t numele = 20000; + int m_param = 0; // Default value (0 means use HNSW_DEFAULT_M) + + /* This you can enable in single thread mode for testing: */ + int massdel = 0; // If true, does the mass deletion test. + int self_recall = 0; // If true, does the self-recall test. + int recall_ef = 0; // If not 0, does the recall test with this EF value. + + for (int j = 1; j < argc; j++) { + int moreargs = argc-j-1; + + if (!strcasecmp(argv[j],"--quant")) { + quantization = HNSW_QUANT_Q8; + } else if (!strcasecmp(argv[j],"--bin")) { + quantization = HNSW_QUANT_BIN; + } else if (!strcasecmp(argv[j],"--mass-del")) { + massdel = 1; + } else if (!strcasecmp(argv[j],"--self-recall")) { + self_recall = 1; + } else if (moreargs >= 1 && !strcasecmp(argv[j],"--recall")) { + recall_ef = atoi(argv[j+1]); + j++; + } else if (moreargs >= 1 && !strcasecmp(argv[j],"--threads")) { + numthreads = atoi(argv[j+1]); + j++; + } else if (moreargs >= 1 && !strcasecmp(argv[j],"--numele")) { + numele = strtoll(argv[j+1],NULL,0); + j++; + if (numele < 1) numele = 1; + } else if (moreargs >= 1 && !strcasecmp(argv[j],"--m")) { + m_param = atoi(argv[j+1]); + j++; + } else if (!strcasecmp(argv[j],"--help")) { + printf("%s [--quant] [--bin] [--thread ] [--numele ] [--m ] [--mass-del] [--self-recall] [--recall ]\n", argv[0]); + exit(0); + } else { + printf("Unrecognized option or wrong number of arguments: %s\n", argv[j]); + exit(1); + } + } + + if (quantization == HNSW_QUANT_NONE) { + printf("You can enable quantization with --quant\n"); + } + + if (numthreads > 0) { + w2v_multi_thread(m_param, numthreads, quantization, numele); + } else { + printf("Single thread execution. Use --threads 4 for concurrent API\n"); + w2v_single_thread(m_param, quantization, numele, massdel, self_recall, recall_ef); + } +}