Commit 2762e688 authored by drbh

fix: include fsm_grammar_states in FlashMistralBatch from_pb

parent ff42d33e
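
The bug: FlashMistralBatch.from_pb rebuilt a batch from its protobuf message but never carried over each request's grammar FSM state, so the batch's next-token chooser was constructed without it. The toy sketch below (not TGI code; the FSM, token ids, and helper names are invented for illustration) shows why one state per request matters in grammar-constrained decoding: every request in a batch can be at a different point in the grammar, and the set of legal next tokens depends on that state.

    from typing import Dict, List, Set

    # Hypothetical toy FSM: state -> {allowed token id -> next state}
    TOY_FSM: Dict[int, Dict[int, int]] = {
        0: {10: 1, 11: 1},   # from the start state, tokens 10 or 11 are legal
        1: {12: 2},          # then only token 12
        2: {},               # accepting state: the grammar is complete
    }

    def allowed_tokens(state: int) -> Set[int]:
        """Token ids the grammar permits from this state."""
        return set(TOY_FSM[state])

    def advance(states: List[int], chosen: List[int]) -> List[int]:
        """Advance each request's FSM state by its chosen token."""
        return [TOY_FSM[s][t] for s, t in zip(states, chosen)]

    # A batch of three requests, each at a different point in the grammar:
    fsm_grammar_states = [0, 1, 0]
    print([allowed_tokens(s) for s in fsm_grammar_states])  # [{10, 11}, {12}, {10, 11}]
    fsm_grammar_states = advance(fsm_grammar_states, [10, 12, 11])
    print(fsm_grammar_states)                               # [1, 2, 1]

Dropping these per-request states (as the code below did before the fix) means every request is treated as if it were at a single default state, breaking exactly this bookkeeping.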
@@ -98,6 +98,7 @@ class FlashMistralBatch(FlashCausalLMBatch):
         prefill_cu_outlens = [0]
         next_token_chooser_parameters = []
+        fsm_grammar_states = []
         stopping_criterias = []
         top_n_tokens = []
@@ -136,6 +137,7 @@ class FlashMistralBatch(FlashCausalLMBatch):
             cu_seqlen_prefill.append(cumulative_length + input_length)
             next_token_chooser_parameters.append(r.parameters)
+            fsm_grammar_states.append(r.fsm_grammar_state)
             stopping_criteria = StoppingCriteria.from_pb(
                 r.stopping_parameters, tokenizer
@@ -204,7 +206,7 @@ class FlashMistralBatch(FlashCausalLMBatch):
         )
         next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
-            next_token_chooser_parameters, dtype, device, tokenizer
+            next_token_chooser_parameters, dtype, device, tokenizer, fsm_grammar_states
        )
         start_slots = torch.tensor(start_slots, dtype=torch.int64)
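
Effect of the change: since the new call site simply appends fsm_grammar_states as a trailing argument, the parameter was presumably optional on HeterogeneousNextTokenChooser.from_pb, which would explain why the omission was silent. Ordinary sampling kept working; only grammar-constrained requests lost their place in the grammar whenever a FlashMistralBatch was rebuilt from its protobuf message.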