Understanding how to guide Large Language Models
Disclaimer: This is a text about a library that is used to build applications with LLMs (Large Language Models). This article is rather dense, so don’t feel frustrated if you don’t understand everything at first. Beware: I might also have misunderstood some (or many) things in the process.
So if you have no idea what guidance is about, I’d recommend that you start reading the article I wrote about using it to build a docstring generator.
Also, you can try the following:
- The official library documentation.
- A small library I wrote to serve an LLM with guidance behind an HTTP server.
- Take a peek at the Brain Chulo repository. There are some really interesting guidance prompts in there, written by another developer, which deserve their own article.
On another day, I might go back to the basics, but today I want to dive deep into how the guidance library works.
Why a deep dive?
First, we’d like to fully understand how it works so that we can better grasp its limitations.
Second, we’d like to be able to extend the library to backends other than Hugging Face transformers. For instance:
- The infamous Recurrent-Neural-Network Language Model — RWKV
- ExLlama — a memory-efficient implementation of the LLaMA model
So buckle up. Today we’re diving deep into the guidance source code. Before we start, let’s set up some reproducibility: I’m looking at the main branch of guidance on June 20, 2023, at this commit: 5f7fa7f6eef6455e6940fe743c5bfdb557330d0b.
To keep things simple, I often copy snippets of code from the repository and delete the commented-out lines. Believe me, there’s a lot of commented code in this library.
Guidance — A Template Language
The top-level construct in Guidance is called a Program. Through a program, one defines the guided generation flow that the LLM must follow.
But how is it implemented? From the outside, a guidance program looks very similar to a Jinja2 template. Which means there’s probably some parsing happening, right? Exactly, the implementation uses a Grammar Language.
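Before we dive in, here’s a minimal sketch of what such a program looks like from the user’s side (assuming a small local Hugging Face model such as gpt2, purely for illustration):

import guidance

# assumption: a small local Hugging Face model, just to make the example concrete
guidance.llm = guidance.llms.Transformers("gpt2")

# a handlebars-like template: plain text interleaved with {{...}} commands
program = guidance("Here is a joke about {{topic}}: {{gen 'joke' max_tokens=30}}")
result = program(topic="parrots")
print(result["joke"])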
Let’s take a look at the class called ProgramExecutor, which is found in the file guidance/_program_executor.py.
class ProgramExecutor():

    def __init__(self, program):
        """ Attaches this executor to a program object.
        """
        self.program = program
        self.block_content = []
        self.executing = True
        self.should_stop = False
        self.caught_stop_iteration = False
        self.llm_session = None
        self._logging = hasattr(self.program.log, "append")

        try:
            self.parse_tree = grammar.parse_string(program._text)
        except (pp.ParseException, pp.ParseSyntaxException) as e:
As you can see, there is indeed an AST (Abstract Syntax Tree) being built from the guidance program string passed as input.
Note: the main entry point of guidance is actually in guidance/_guidance.py, but we’ll skip this one. I’ll mention that it controls a lot of UI and event handling, using asyncio. It’s quite elaborate but hard to follow. It’s also not the part I’m most interested in exploring.
Moving forward, let’s take a peek at guidance/_grammar.py:
import pyparsing as pp
pp.ParserElement.enable_packrat()
program = pp.Forward()
program_chunk = pp.Forward()
ws = pp.White()
opt_ws = pp.Optional(ws)
## comments ##
# long-form comments {{!-- my comment --}}
command_end = pp.Suppress(opt_ws + "}}") | pp.Suppress(opt_ws + "~}}" + opt_ws)
long_comment_start = pp.Suppress(pp.Literal("{{!--"))
(...more similar statements below)
That’s quite elegant! I didn’t know about pyparsing (https://github.com/pyparsing/pyparsing) before looking at this source code; it looks very interesting and lean. So here, the library defines the grammar language of the guidance program. Every possible statement in a guidance program must be defined here.
Do you want to define a new crazy construct for your guidance templates? Well, extend the grammar in this file. Perhaps a topic for another day.
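To get a feel for how pyparsing works, here’s a toy grammar in the same spirit (my own simplified sketch, not guidance’s actual grammar):

import pyparsing as pp

# A toy grammar in the spirit of guidance/_grammar.py: it only recognizes
# plain text plus a simplified {{command 'argument'}} construct.
command_start = pp.Suppress("{{")
command_end = pp.Suppress("}}")
command_name = pp.Word(pp.alphas + "_")("name")
command_arg = pp.Optional(pp.QuotedString("'"))("arg")
command = pp.Group(command_start + command_name + command_arg + command_end)

literal_text = pp.Group(pp.CharsNotIn("{")("text"))
toy_program = pp.ZeroOrMore(command | literal_text)

print(toy_program.parse_string("Here is a joke: {{gen 'joke'}}"))
# roughly: [['Here is a joke: '], ['gen', 'joke']]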
One “minor” thing to notice. To execute a program, guidance also needs to store some state, and it does so using a VariableStack (_variable_stack.py).
class VariableStack:
    """This represents the variables scope stack of a Guidance program."""

    def __init__(self, stack, executor):
        """Build a new variable stack object with the given stack and program executor."""
        self._stack = stack
        self._executor = executor

    def push(self, variables):
        self._stack.append(variables)

    def pop(self):
        out = self._stack.pop()
(a lot more code below...)
We don’t need to get into the details of this stack (it’s a stack, after all), but keep in mind it’s there, as it will appear quite often in the code blocks of the next section.
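To build a little intuition for what such a scope stack buys us, here’s a minimal toy version of the idea (my own sketch, not guidance’s actual implementation): lookups walk the stack from the innermost scope outwards.

class ToyVariableStack:
    def __init__(self):
        self._stack = [{}]  # the global scope

    def push(self, variables):
        self._stack.append(variables)

    def pop(self):
        return self._stack.pop()

    def __getitem__(self, name):
        # search from the innermost scope outwards
        for scope in reversed(self._stack):
            if name in scope:
                return scope[name]
        raise KeyError(name)


stack = ToyVariableStack()
stack.push({"topic": "parrots"})   # e.g. the variables of the program call
stack.push({"joke": "..."})        # e.g. a nested block scope
print(stack["topic"])              # found in an outer scope -> "parrots"
stack.pop()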
Guidance — Library Commands
If you’ve used guidance, you’re probably familiar with how the library defines different commands, such as gen, or select.
They’re all defined under the guidance/library directory.
Let’s start with gen. The source code for this command alone is rather long (around 250 lines). From what I see, every command starts by pulling the relevant variables out of the parser context (we briefly covered the parser in the previous section):
parser = _parser_context['parser']
variable_stack = _parser_context['variable_stack']
next_node = _parser_context["next_node"]
next_next_node = _parser_context["next_next_node"]
prev_node = _parser_context["prev_node"]
So the library implements some look-ahead functionality that checks the next node in the tree. Here’s what the code looks like:
next_text = getattr(next_node, "text", next_node) if next_node is not None else ""
For instance, to decide whether it should stop generating text, the command uses the following code:
# auto-detect role stop tags
if stop is None:
    m = re.match(r"^{{~?/\w*(user|assistant|system|role)\w*~?}}.*", next_text)
    if m:
        stop = parser.program.llm.role_end(m.group(1))
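To make that regex concrete, here’s a quick check against a role-closing tag (the input string is my own example):

import re

# if the node right after {{gen}} closes a role block, capture the role name
pattern = r"^{{~?/\w*(user|assistant|system|role)\w*~?}}.*"
m = re.match(pattern, "{{~/assistant}}{{#user}}...")
print(m.group(1))  # -> "assistant", so the generation stops at the assistant role end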
In any case, several parameters (temperature, etc.) are checked and adjusted to control the generation’s behavior. At some point in the source code, the command starts generating text. Here’s how it does that:
# call the LLM
gen_obj = await parser.llm_session(
    variable_stack["@prefix"]+prefix, stop=stop, stop_regex=stop_regex, max_tokens=max_tokens, n=n, pattern=pattern,
    temperature=temperature, top_p=top_p, logprobs=logprobs, cache_seed=cache_seed, token_healing=token_healing,
    echo=parser.program.logprobs is not None, stream=stream, caching=parser.program.caching, **llm_kwargs
)
So, for this, we need the llm_session, which we’ll cover in the next section. To keep things short, let’s only look at how the results are read when a single generation is requested. In this case, guidance can generate content asynchronously, hence the returned gen_obj is actually an AsyncGeneratorType.
async for resp in gen_obj:
    await asyncio.sleep(0)
    if parser.should_stop:
        break
    generated_value += resp["choices"][0]["text"]
    variable_stack["@raw_prefix"] += resp["choices"][0]["text"]
    if logprobs is not None:
        logprobs_out.extend(resp["choices"][0]["logprobs"]["top_logprobs"])

if list_append:
    variable_stack[name][list_ind] = generated_value
    if logprobs is not None:
        variable_stack[name+"_logprobs"][list_ind] = logprobs_out
elif name is not None:
    variable_stack[name] = generated_value
    if logprobs is not None:
        variable_stack[name+"_logprobs"] = logprobs_out
We can see here that the command expects a certain format from the generated response: a dictionary containing an array under choices, with a single item holding the text and a dict of logprobs. Since this block is only used when exactly one generation is requested, the logic makes sense.
It’s good to note that the generated values are also pushed onto the variable_stack, and I suspect this is also how they end up in the program result, which is accessible like a dictionary.
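To make the expected format concrete, here is roughly the shape each streamed chunk (resp) is assumed to have; the values below are made up for illustration:

resp = {
    "choices": [
        {
            "text": " Why did the parrot",
            "logprobs": {
                # one dict of top alternatives per generated token
                "top_logprobs": [{" Why": -0.3, " What": -1.8}],
            },
        }
    ]
}

generated_value = resp["choices"][0]["text"]
top_logprobs = resp["choices"][0]["logprobs"]["top_logprobs"]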
Let’s move on and take a look at the next command: select. This is a really interesting one.
I’ll skip the boring parts and jump straight into the golden parts.
First, this command figures out which options are available for the generation. Don’t worry about understanding everything in this block. The important piece is that it finds the correct options while handling the edge case where one option is a prefix of another.
# find what text follows the select command and append it to the options.
# we do this so we can differentiate between select options where one is a prefix of another
next_text = next_node.text if next_node is not None else ""
if next_next_node and next_next_node.text.startswith("{{~"):
    next_text = next_text.lstrip()
if next_next_node and next_text == "":
    next_text = next_next_node.text
if next_text == "": # if we have nothing after us then we are at the end of the text
    next_text = parser.program.llm.end_of_text()
options = [option + next_text for option in options]

# TODO: this retokenizes the whole prefix many times, perhaps this could become a bottleneck?
options_tokens = [parser.program.llm.encode(variable_stack["@prefix"] + option) for option in options]

# encoding the prefix and then decoding it might change the length, so we need to account for that
recoded_parser_prefix_length = len(parser.program.llm.decode(parser.program.llm.encode(variable_stack["@prefix"])))
Then it uses a Trie data structure to map the options tokens in an efficient way through the pygtrie library:
# build a trie of the options
token_map = pygtrie.Trie()
for i,option in enumerate(options_tokens):
    token_map[option] = i
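Here’s a small illustration of why a trie over token-id sequences is handy here (the token ids are made up): items(prefix=...) returns exactly the options that are still compatible with what has been generated so far.

import pygtrie

token_map = pygtrie.Trie()
token_map[(464, 3290)] = 0   # say, the tokens for "the dog"
token_map[(464, 3797)] = 1   # say, the tokens for "the cat"
token_map[(7454, 257)] = 2   # say, the tokens for "once a"

print(token_map.items(prefix=(464,)))
# roughly: [((464, 3290), 0), ((464, 3797), 1)] -- only the options starting with token 464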
Now we have a really big recursive function. The good part is that it’s thoroughly commented, so it’s not so hard to understand the individual blocks.
async def recursive_select(current_prefix, allow_token_extension=True):
    # A lot of code in here...
    return logprobs_out
Still, let’s go over this function in chunks. First, the function filters out the options that are no longer possible given the current prefix, and it initializes the logprobs output dictionary:
# find which select options are possible
try:
    extension_options = token_map.items(prefix=current_prefix)
except KeyError:
    return {}

# this is the dictionary of logprobs for each option we will return
# note that the logprobs are just for this branch point and below in the decision tree
logprobs_out = {option[0]: -1000 for option in extension_options}
If only one option remains for the current prefix, the function exits early. Otherwise, it extends the prefix with the longest common prefix of the remaining options. I imagine this is purely an optimization: it saves asking the LLM to generate tokens that would be deterministically chosen anyway.
# extend the prefix with the longest common prefix among the valid options
# we also stop early if we have one option
if len(extension_options) == 1:
    logprobs_out[extension_options[0][0]] = 0 # probability of 1.0 that we will select the only valid option
    return logprobs_out
else:
    match_index = len(current_prefix)
    for i in range(len(current_prefix), min([len(o[0]) for o in extension_options])):
        if len(set([o[0][i] for o in extension_options])) > 1:
            break
        match_index += 1
    if match_index > len(current_prefix):
        current_prefix += extension_options[0][0][len(current_prefix):match_index]
    # extension_options = [(option[i:], index) for option,index in extension_options]
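Here’s a standalone re-run of that longest-common-prefix logic with made-up token ids, to show what it computes:

current_prefix = [464]
extension_options = [([464, 3290, 13], 0), ([464, 3290, 0], 1)]

match_index = len(current_prefix)
for i in range(len(current_prefix), min(len(o[0]) for o in extension_options)):
    if len({o[0][i] for o in extension_options}) > 1:
        break
    match_index += 1

print(match_index)  # -> 2: token 3290 is shared by all options, so it is appended for free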
Now there’s a really interesting block:
# bias the logits towards valid options
logit_bias = {}
for option_tokens,index in extension_options:
    logit_bias[option_tokens[match_index]] = 100
So the select command adds a strong bias to the tokens that can continue one of the valid options, which forces the LLM to generate content matching one of them. The next step is naturally to pass this bias to the LLM session:
gen_obj = await parser.llm_session(
    parser.program.llm.decode(current_prefix), # TODO: perhaps we should allow passing of token ids directly? (this could allow us to avoid retokenizing the whole prefix many times)
    max_tokens=1,
    logit_bias=logit_bias,
    logprobs=len(logit_bias),
    cache_seed=0,
    token_healing=False # we manage token boundary healing ourselves for this function
)
Notice how the logit_bias is passed along with the current prefix, and exactly one token is generated at a time. This is probably because, after each generated token, the set of valid options (and hence the logit_bias) needs to be recomputed from the trie.
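Conceptually, here’s what such a logit bias does to the model’s scores right before sampling (a simplified sketch, not the library’s BiasLogitsProcessor):

import torch

vocab_size = 8
scores = torch.zeros(1, vocab_size)   # pretend these are the model's logits
logit_bias = {2: 100.0, 5: 100.0}     # only tokens 2 and 5 can continue a valid option

bias_vector = torch.zeros(vocab_size)
for token_id, bias in logit_bias.items():
    bias_vector[token_id] = bias

biased_scores = scores + bias_vector
print(biased_scores.argmax(dim=-1))   # the winning token is now guaranteed to be 2 or 5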
For each generated token, the result is interpreted and, in the greedy case, only the top logprobs choice is kept:
# no need to explore all branches if we are just taking the greedy max
if logprobs is None:
    max_key = max(top_logprobs, key=top_logprobs.get)
    top_logprobs = {max_key: top_logprobs[max_key]}
And then the recursion continues:
# for each possible next token, see if it grows the prefix in a valid way
for token,logprob in top_logprobs.items():
    sub_logprobs = await recursive_select(current_prefix + [token])
All right, that’s it for the select command. Note that I’m omitting a lot of details of this implementation; I’m trying to focus on the essential parts.
If we wanted to support this command with a library other than Hugging Face, we would need at least the following operations on the llm object:
llm.decode
llm.encode
llm.end_of_text
llm.token_to_id
And, of course, also support the logit_bias parameter when calling the llm_session.
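Putting it together, here’s a rough skeleton of what a new backend would need to expose (the class names are mine; only the method names come from the list above):

class MyBackendLLM:
    def encode(self, string):
        """Text -> list of token ids."""
        ...

    def decode(self, tokens):
        """List of token ids -> text."""
        ...

    def token_to_id(self, token):
        """Single token string -> token id."""
        ...

    def end_of_text(self):
        """The model's end-of-text marker, e.g. '</s>'."""
        ...


class MyBackendSession:
    async def __call__(self, prompt, max_tokens=1, logit_bias=None,
                       logprobs=None, cache_seed=0, token_healing=False):
        # must honor logit_bias so that only the allowed next tokens can win
        ...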
Guidance — The LLM
So we saw that to call the guidance commands, we need an llm object that complies with a certain interface. Here it is (certain methods omitted for clarity):
class LLM(metaclass=LLMMeta):
    cache_version = 1
    default_system_prompt = "You are a helpful assistant."
    llm_name: str = "unknown"

    def session(self, asynchronous=False):
        if asynchronous:
            return LLMSession(self)
        else:
            return SyncSession(LLMSession(self))

    def encode(self, string, **kwargs):
        return self._tokenizer.encode(string, **kwargs)

    def decode(self, tokens, **kwargs):
        return self._tokenizer.decode(tokens, **kwargs)

    def id_to_token(self, id):
        return self.decode([id])

    def token_to_id(self, token):
        return self.encode(token)[0]
Similarly, we must also implement the LLMSession class itself, although its interface is not very informative.
Let’s take a look at the main implementation classes, found in llms/_transformers.py.
In here, we’ll see a few important classes:
class Transformers(LLM):
class TransformersSession(LLMSession):
class TokenHealingLogitsProcessor():
class BiasLogitsProcessor():
class RegexLogitsProcessor():
class RegexStoppingCriteria():
class TransformersStringBuilder():
class TransformersStreamer():
Overall, this spans some 763 lines of code, which is a lot to unpack. It also hints that transformer-based LLMs have the best support in terms of library features, as many of these classes have no counterpart in the abstract LLM interface.
So while a simple implementation of our own interface could just add classes equivalent to Transformers and TransformersSession, an implementation with feature parity with the transformers backend requires a lot of effort.
Still, let’s go over at least some of these classes.
The Transformers class:
class Transformers(LLM):
    """ A HuggingFace transformers language model with Guidance support.
    """

    llm_name: str = "transformers"

    def __init__(self, model=None, tokenizer=None, caching=True, token_healing=True, acceleration=True, \
                 temperature=0.0, device=None, **kwargs):
        super().__init__()
        # (omitted constructor here)

    def new_string_builder(self, starting_ids=None):
        return TransformersStringBuilder(self.tokenizer, starting_ids)

    def prefix_matches(self, prefix):
        """ Return the list of tokens that match the given prefix.
        """
        return [v for arr in self._token_prefix_map.values(prefix=prefix) for v in arr]

    def encode(self, string, **kwargs):
        return self.tokenizer.encode(string, **kwargs)

    def decode(self, tokens, **kwargs):
        return self.tokenizer.decode(tokens, **kwargs)

    def id_to_token(self, id):
        return self.tokenizer.convert_ids_to_tokens([id])[0]

    def token_to_id(self, token):
        return self.tokenizer.convert_tokens_to_ids([token])[0]

    def end_of_text(self):
        return self.tokenizer.eos_token

    def _build_token_prefix_map(self, model_name):
        """ Build a map from token to index.
        """
        token_map = pygtrie.CharTrie()
        for i in range(self.tokenizer.vocab_size):
            s = self.id_to_token(i)
            if s in token_map:
                token_map[s].append(i) # handle duplicate token encodings... (GPT2 BPE has this oddly enough)
            else:
                token_map[s] = [i]

        return token_map
There are some interesting points here:
- There’s another Trie prebuilt when this class is instantiated.
- The TransformersStringBuilder class is used behind the new_string_builder method; it’s not yet clear why it’s needed.
Other than that, it’s really just implementing the abstract interface we saw above.
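Here’s a small illustration of what that prebuilt CharTrie enables (the token strings and ids below are made up): prefix_matches can return every vocabulary token whose string form starts with a given prefix, which is handy for things like token healing.

import pygtrie

token_map = pygtrie.CharTrie()
token_map["the"] = [262]
token_map["ther"] = [8782]
token_map["therefore"] = [19246]
token_map["cat"] = [9246]

matches = [v for arr in token_map.values(prefix="the") for v in arr]
print(matches)  # ids of every token starting with "the": [262, 8782, 19246] (order aside)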
Guidance — The TransformersSession
The next class, TransformersSession, is a beast, spanning over 250 lines of code. As we saw in the gen and select commands, this implements the generation interface that the guidance program will use to generate text.
Again, let’s tackle it step by step, but we’ll keep it “simple” and skip some parts.
The constructor is rather simple, and the __enter__ dunder method applies token acceleration if enabled, which relies on a prefix cache. By default, acceleration is enabled by the Transformers LLM class.
class TransformersSession(LLMSession):

    def __init__(self, llm):
        super().__init__(llm)
        self._past_key_values = None
        self._prefix_cache = []

    def __enter__(self):
        # we only need decorators if we are using token acceleration
        if self.llm.acceleration:
            # Decorators that apply token acceleration
            # For brevity, we'll skip this logic
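My reading of “token acceleration” is that it boils down to reusing the attention key/value cache for an already-seen prompt prefix, so the model only has to process the new tokens. Here’s a conceptual sketch of that idea with plain transformers (not guidance’s actual decorator logic):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

# run the shared prefix once and keep the key/value cache
prefix_ids = tokenizer("A long shared prompt prefix", return_tensors="pt").input_ids
out = model(prefix_ids, use_cache=True)
past_key_values = out.past_key_values

# later calls only feed the new tokens, reusing the cached prefix
new_ids = tokenizer(" and a short continuation", return_tensors="pt").input_ids
out = model(new_ids, past_key_values=past_key_values, use_cache=True)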
The next function is where the main logic of the TransformersSession is found. Here’s the signature (reformatted for clarity):
async def __call__(
    self,
    prompt,
    stop=None,
    stop_regex=None,
    temperature=None,
    n=1,
    max_tokens=1000,
    logprobs=None,
    top_p=1.0,
    echo=False,
    logit_bias=None,
    token_healing=None,
    pattern=None,
    stream=False,
    cache_seed=0,
    caching=None,
    **generate_kwargs
):
    """Generate a completion of the given prompt."""
The next code snippets are all part of the same function above.
As we saw in the previous sections, while a guidance program evaluates the Abstract Syntax Tree, each parsed node from the grammar may call this session directly with any of these parameters. The function starts by simply checking the parameters and filling in some defaults. Then the code checks whether the answer is already cached and returns the cached value if present:
# handle caching
in_cache = key in llm_cache
not_caching = (caching is not True and not self.llm.caching) or caching is False
if not in_cache or not_caching:
    # compute value
    # (omitted block here...)

return llm_cache[key]
Let’s zoom in on how llm_cache[key] gets filled (that is, inside the omitted block above).
So the code checks that we have a valid prompt and imports the appropriate libraries (transformers, torch). Then the prompt is encoded and sent to the appropriate device (most likely cuda). Funnily enough, the prompt is then decoded again and stored in a variable for later.
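That step looks roughly like this (a sketch with names of my own, not the library’s exact code):

import torch
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
device = "cuda" if torch.cuda.is_available() else "cpu"

input_ids = tokenizer("Tell me a joke:", return_tensors="pt").input_ids.to(device)
coded_prompt = tokenizer.decode(input_ids[0])  # the "recoded" prompt, kept for later offsets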
Then we see a series of setup blocks. First, we initialize the TokenHealingLogitsProcessor (token healing, roughly speaking, backs up over the last prompt token so the model can regenerate a cleaner token boundary).
# setup token healing
if token_healing:
    healer = TokenHealingLogitsProcessor(self.llm, model_config.vocab_size, input_ids[0])
    healed_token_ids = healer.healed_token_ids
    if len(healed_token_ids) > 0:
        input_ids = input_ids[:,:-len(healed_token_ids)]
        # attention_mask = attention_mask[:,:-len(healed_token_ids)]
        max_tokens += len(healed_token_ids) # increase to account for the tokens we regen for token healing

    processors.append(healer)
Then the BiasLogitsProcessor:
# setup logit biasing
if logit_bias is not None:
    processors.append(BiasLogitsProcessor(self.llm, model_config.vocab_size, logit_bias))
Next, the code looks for the maximum context length supported by the model and checks that the input doesn’t exceed this limit:
# find the max context length
possible_attributes = ["max_sequence_length", "max_seq_len", "model_max_length", "n_positions", "max_position_embeddings"]
max_context = None
for obj in [model_config, self.llm.tokenizer]:
    for attr in possible_attributes:
        if max_context is None:
            max_context = getattr(obj, attr, None)
        else:
            break
assert max_context is not None, "Could not find a max context length for the model! Tried: "+", ".join(possible_attributes)

# make sure we don't run off the end of the model
if max_tokens + len(input_ids[0]) > max_context:
    max_tokens = max_context - len(input_ids[0])
It seems like we must always generate at least one token, even if the whole prefix is already cached, so there’s some logic to check that:
# find how much of the prompt is cached
prefix_match_len = 0
for token in input_ids[0]:
    if prefix_match_len >= len(self._prefix_cache) or token != self._prefix_cache[prefix_match_len]:
        break
    else:
        prefix_match_len += 1

# we always need to run the model on at least one token so transformers is happy
if prefix_match_len == len(input_ids[0]):
    prefix_match_len -= 1
Then a couple more classes are optionally initialized: RegexLogitsProcessor and RegexStoppingCriteria.
# add support for pattern guidance
if pattern is not None:
    processors.append(RegexLogitsProcessor(pattern, stop_regex, self.llm, model_config.vocab_size, temperature == 0, len(coded_prompt), self.llm.tokenizer.eos_token_id))

if stop_regex is not None:
    stoppers.append(RegexStoppingCriteria(stop_regex, self.llm, len(coded_prompt)))
We initialize a TransformersStreamer:
# a streamer to handle potentially partial output
streamer = TransformersStreamer(
    input_ids=input_ids,
    stop_regex=stop_regex,
    healed_token_ids=healed_token_ids,
    prefix_length=len(coded_prompt),
    llm=self.llm,
    max_new_tokens=max_tokens,
    logprobs=logprobs
)
We define all the arguments for generation with this code:
# the args for the transformers generate call
generate_args = dict(
    inputs=input_ids,
    temperature=temperature,
    max_new_tokens=max_tokens,
    top_p=top_p,
    pad_token_id=model_config.pad_token_id if model_config.pad_token_id is not None else self.llm.tokenizer.eos_token_id,
    logits_processor=transformers.LogitsProcessorList(processors),
    stopping_criteria=transformers.StoppingCriteriaList(stoppers),
    output_scores=logprobs is not None and logprobs > 0,
    return_dict_in_generate=True,
    **generate_kwargs
)
Let’s skip some further parameter magic for the generation and finally see the end of this function:
if stream:
    generate_args["streamer"] = streamer
    thread = threading.Thread(target=self.llm.model_obj.generate, kwargs=generate_args)
    thread.start()
    return self._stream_then_save(streamer, key, thread)

# if we are not streaming we still manually use the streamer for consistency
else:
    generated_sequence = self.llm.model_obj.generate(**generate_args)
    streamer.put(generated_sequence)
    self.llm.cache[key] = streamer.__next__()
    self._update_prefix_cache(streamer)
So here, we finally use the llm.model_obj.generate function to generate the text for the given prompt. If you recall from the Transformers class, model_obj is the actual instance of a model loaded by the transformers library from Hugging Face.
Let’s backtrack a little bit. So what is this TransformersStreamer doing and how is the logit bias being applied?
So the TransformersStreamer can be understood as a queue of tokens. If we look at its __next__ method, it becomes quite obvious:
def __next__(self):
    value = self.out_queue.get(timeout=self.timeout)
    if value is None:
        raise StopIteration()
    else:
        return value
And also the signature of the put method:
def put(self, token_obj):
Of course, there’s a lot of logic in there, but let’s keep it simple. Now, if you paid attention, you might have noticed these generation arguments:
logits_processor=transformers.LogitsProcessorList(processors),
stopping_criteria=transformers.StoppingCriteriaList(stoppers)
So, inside LogitsProcessorList we pass the token healer and the logit bias processor. And inside StoppingCriteriaList, the RegexStoppingCriteria is used.
So this was completely new to me: Hugging Face generative transformers actually support these two parameters (logits_processor and stopping_criteria) out of the box: https://huggingface.co/docs/transformers/v4.30.0/en/main_classes/text_generation#transformers.GenerationMixin.generate
About logits_processor:
Custom logits processors that complement the default logits processors built from arguments and generation config. If a logit processor is passed that is already created with the arguments or a generation config an error is thrown. This feature is intended for advanced users.
A LogitsProcessor can be used to modify the prediction scores of a language model head for generation.
Similarly, about the stopping_criteria:
Custom stopping criteria that complement the default stopping criteria built from arguments and a generation config. If a stopping criteria is passed that is already created with the arguments or a generation config an error is thrown. This feature is intended for advanced users.
A StoppingCriteria can be used to change when to stop generation (other than EOS token).
Very advanced stuff! So this is the magic behind Guidance, using a custom logits_processor from the transformers library to influence the outcome, along with a custom stopping_criteria.
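To see how little is needed to hook into these two parameters, here’s a minimal, self-contained example with toy classes of my own (far simpler than guidance’s TokenHealingLogitsProcessor, BiasLogitsProcessor, and RegexStoppingCriteria):

import torch
from transformers import (AutoModelForCausalLM, AutoTokenizer,
                          LogitsProcessor, LogitsProcessorList,
                          StoppingCriteria, StoppingCriteriaList)


class ToyBiasProcessor(LogitsProcessor):
    """Adds a fixed bias to a set of token ids before sampling."""
    def __init__(self, bias):
        self.bias = bias  # {token_id: bias_value}

    def __call__(self, input_ids, scores):
        for token_id, value in self.bias.items():
            scores[:, token_id] += value
        return scores


class ToySubstringStop(StoppingCriteria):
    """Stops generation once the decoded text contains a given substring."""
    def __init__(self, tokenizer, substring):
        self.tokenizer = tokenizer
        self.substring = substring

    def __call__(self, input_ids, scores, **kwargs):
        return self.substring in self.tokenizer.decode(input_ids[0])


tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tokenizer("The opposite of up is", return_tensors="pt")

output = model.generate(
    **inputs,
    max_new_tokens=20,
    logits_processor=LogitsProcessorList([ToyBiasProcessor({tokenizer.encode(" down")[0]: 10.0})]),
    stopping_criteria=StoppingCriteriaList([ToySubstringStop(tokenizer, ".")]),
    pad_token_id=tokenizer.eos_token_id,
)
print(tokenizer.decode(output[0]))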
Alright, let’s wrap up. Here are the symbols this Session class expects to find on the llm object. There could be more, as I didn’t dig deep into every class in this module:
llm.temperature
llm.token_healing
llm.cache
llm.llm_name
llm.encode
llm.decode
llm.tokenizer.eos_token
llm.tokenizer.eos_token_id
llm.model_obj.config
llm.model_obj.config.vocab_size
llm.model_obj.config.pad_token_id
Conclusion
All right, I think that’s enough to satisfy my curiosity. While we could go into more detail, I don’t think I’d benefit much. I understand this text is rather dense and difficult to read, but I hope you learned a bit, as I did, by doing this deep dive.
One could also analyse how the OpenAI API is integrated into guidance, but I’ll leave that as an exercise to the interested reader.
In any case, it seems that extending guidance to support additional libraries beyond transformers doesn’t require that much. From what I understand, here are the key components:
- A Large Language Model class that encapsulates some abstractions for the generation Session.
- A tokenizer that exposes certain information about the tokens and offers a way to encode/decode them.
- A Session that implements the communication between the guidance parser and the LLM and its tokenizer in a way that overrides stopping criteria and influences the bias of the predictions through a logit processor or similar functionality.
Perhaps the next step would be to pick our candidate library and check how much it offers. In any case, that’s for another day!