def receive(self, chunk: ModelResponseStream) -> "StreamResponse | None":
    """Consume one streamed LLM chunk and emit a `StreamResponse` when the target field is active.

    This is a small state machine over three flags and two buffers:

    - ``self.stream_start`` / ``self.stream_end``: whether the target field's value is
      currently being streamed / has finished streaming.
    - ``self.cache_hit``: set when a single chunk contains both the start and end markers
      (full response delivered at once, e.g. a cache hit or a model that batches tokens).
    - ``self.field_start_queue`` (list): accumulates chunks while looking for the full
      start identifier split across chunk boundaries.
    - ``self.field_end_queue`` (Queue): rolling buffer of the last ~10 chunks, used to
      detect the end identifier that may also be split across chunk boundaries.

    Args:
        chunk: One streamed model response chunk; the text payload is read from
            ``chunk.choices[0].delta.content``.

    Returns:
        A ``StreamResponse`` carrying the next token(s) of the target field while the
        field is streaming, or ``None`` when this chunk produces no output (not started
        yet, already ended, or the chunk was buffered while matching an identifier).

    Raises:
        ValueError: If the active adapter is not one of the streaming-capable adapters.
    """
    adapter_name = settings.adapter.__class__.__name__ if settings.adapter else "ChatAdapter"
    if adapter_name not in self.adapter_identifiers:
        raise ValueError(
            f"Unsupported adapter for streaming: {adapter_name}, please use one of the following adapters: "
            f"{', '.join([a.__name__ for a in ADAPTER_SUPPORT_STREAMING])}"
        )
    # Adapter-specific markers: start_identifier is matched as a literal substring
    # (via `in` / `str.find`), end_identifier is a regex pattern (used with re.search),
    # and start_indicator is a short prefix of the start identifier (e.g. "[[" for
    # ChatAdapter) used as a cheap pre-filter before buffering begins.
    start_identifier = self.adapter_identifiers[adapter_name]["start_identifier"]
    end_identifier = self.adapter_identifiers[adapter_name]["end_identifier"]
    start_indicator = self.adapter_identifiers[adapter_name]["start_indicator"]

    if self.stream_end:
        if self.allow_reuse:
            # Clear up the state for the next stream.
            self.stream_end = False
            self.cache_hit = False
            self.field_start_queue = []
            self.field_end_queue = Queue()
            self.stream_start = False
        else:
            # Already finished and reuse is disabled: ignore all further chunks.
            return

    try:
        chunk_message = chunk.choices[0].delta.content
        if chunk_message is None:
            return
    except Exception:
        # Best-effort: chunks without the expected choices/delta shape are skipped.
        return

    if chunk_message and start_identifier in chunk_message:
        # If the cache is hit, the chunk_message could be the full response. When it happens we can
        # directly end the stream listening. In some models like gemini, each stream chunk can be multiple
        # tokens, so it's possible that response only has one chunk, we also fall back to this logic.
        message_after_start_identifier = chunk_message[
            chunk_message.find(start_identifier) + len(start_identifier) :
        ]
        if re.search(end_identifier, message_after_start_identifier):
            self.cache_hit = True
            self.stream_start = True
            self.stream_end = True
            return

    if len(self.field_start_queue) == 0 and not self.stream_start and start_indicator in chunk_message:
        # We look for the pattern of start_identifier, i.e., "[[ ## {self.signature_field_name} ## ]]" for
        # ChatAdapter to identify the start of the stream of our target field. Once the start_indicator, i.e., "[["
        # for ChatAdapter, is found, we start checking the next tokens
        self.field_start_queue.append(chunk_message)
        return

    if len(self.field_start_queue) > 0 and not self.stream_start:
        # We keep appending the tokens to the queue until we have a full identifier or the concatenated
        # tokens no longer match our expected identifier.
        self.field_start_queue.append(chunk_message)
        concat_message = "".join(self.field_start_queue)

        if start_identifier in concat_message:
            # We have a full identifier, we can start the stream.
            self.stream_start = True
            self.field_start_queue = []
            # Keep the part after the start_identifier from the concat_message, we need to write it to the buffer.
            value_start_index = concat_message.find(start_identifier) + len(start_identifier)
            chunk_message = concat_message[value_start_index:].lstrip()
            if isinstance(settings.adapter, JSONAdapter) and chunk_message.startswith('"'):
                # For JSONAdapter, we need to remove the leading ". We cannot do this with the start_identifier
                # because there could be a few splitters between ':' and '"', e.g., '"name": "value"'.
                chunk_message = chunk_message[1:]

        elif self._buffered_message_end_with_start_identifier(concat_message.strip(), start_identifier):
            # If the buffered message ends with part of the start_identifier, we keep looking for the
            # start_identifier from the token stream.
            return
        else:
            # Doesn't match the expected identifier, reset the queue.
            self.field_start_queue = []
            return

    if self.stream_start:
        # The stream is started, we keep returning the token until we see the start of the next field.
        token = None
        self.field_end_queue.put(chunk_message)
        if self.field_end_queue.qsize() > 10:
            # We keep the last 10 tokens in the buffer to check if they form a valid identifier for end_identifier,
            # i.e., "[[ ## {next_field_name} ## ]]" for ChatAdapter to identify the end of the current field.
            # In most cases 10 tokens are enough to cover the end_identifier for all adapters.
            token = self.field_end_queue.get()
        # NOTE: reading Queue.queue directly is a peek at the underlying deque without
        # removing items — safe here since this listener is fed sequentially.
        concat_message = "".join(self.field_end_queue.queue).strip()
        if re.search(end_identifier, concat_message):
            # The next field is identified, we can end the stream and flush out all tokens in the buffer.
            self.stream_end = True
            # self.flush() presumably drains field_end_queue up to the end identifier — confirm in its definition.
            last_token = self.flush()
            token = token + last_token if token else last_token
            token = token.rstrip()  # Remove the trailing \n\n

        if token:
            return StreamResponse(
                self.predict_name,
                self.signature_field_name,
                token,
                is_last_chunk=self.stream_end,
            )