"""Stateful scrubber for reasoning/thinking blocks in streamed assistant text. ``run_agent._strip_think_blocks`` is regex-based and correct for a complete string, but when it runs *per-delta* in ``_fire_stream_delta`` it destroys the state that downstream consumers (CLI ``_stream_delta``, gateway ``GatewayStreamConsumer._filter_and_accumulate``) rely on. Concretely, when MiniMax-M2.7 streams delta1 = "" delta2 = "Let me check their config" delta3 = "" the per-delta regex erases delta1 entirely (case 2: unterminated-open at boundary matches ``^...``), so the downstream state machine never sees the open tag, treats delta2 as regular content, and leaks reasoning to the user. Consumers that don't run their own state machine (ACP, api_server, TTS) never had any defence at all — they just emitted whatever survived the upstream regex. This module centralises the tag-suppression state machine at the upstream layer so every stream_delta_callback sees text that has already had reasoning blocks removed. Partial tags at delta boundaries are held back until the next delta resolves them, and end-of-stream flushing surfaces any held-back prose that turned out not to be a real tag. Usage:: scrubber = StreamingThinkScrubber() for delta in stream: visible = scrubber.feed(delta) if visible: emit(visible) tail = scrubber.flush() # at end of stream if tail: emit(tail) The scrubber is re-entrant per agent instance. Call ``reset()`` at the top of each new turn so a hung block from an interrupted prior stream cannot taint the next turn's output. Tag variants handled (case-insensitive): ````, ````, ````, ````, ````. Block-boundary rule for opens: an opening tag is only treated as a reasoning-block opener when it appears at the start of the stream, after a newline (optionally followed by whitespace), or when only whitespace has been emitted on the current line. This prevents prose that *mentions* the tag name (e.g. ``"use tags here"``) from being incorrectly suppressed. Closed pairs (``X``) are always suppressed regardless of boundary; a closed pair is an intentional, bounded construct. """ from __future__ import annotations from typing import Tuple __all__ = ["StreamingThinkScrubber"] class StreamingThinkScrubber: """Stateful scrubber for streaming reasoning/thinking blocks. State machine: - ``_in_block``: True while inside an opened block, waiting for a close tag. All text inside is discarded. - ``_buf``: held-back partial-tag tail. Emitted / discarded on the next ``feed()`` call or by ``flush()``. - ``_last_emitted_ended_newline``: True iff the most recent emission to the consumer ended with ``\\n``, or nothing has been emitted yet (start-of-stream counts as a boundary). Used to decide whether an open tag at buffer position 0 is at a block boundary. """ _OPEN_TAG_NAMES: Tuple[str, ...] = ( "think", "thinking", "reasoning", "thought", "REASONING_SCRATCHPAD", ) # Materialise literal tag strings so the hot path does string # operations, not regex compilation per feed(). _OPEN_TAGS: Tuple[str, ...] = tuple(f"<{name}>" for name in _OPEN_TAG_NAMES) _CLOSE_TAGS: Tuple[str, ...] = tuple(f"" for name in _OPEN_TAG_NAMES) # Pre-compute the longest tag (for partial-tag hold-back bound). _MAX_TAG_LEN: int = max(len(tag) for tag in _OPEN_TAGS + _CLOSE_TAGS) def __init__(self) -> None: self._in_block: bool = False self._buf: str = "" self._last_emitted_ended_newline: bool = True def reset(self) -> None: """Reset all state. Call at the top of every new turn.""" self._in_block = False self._buf = "" self._last_emitted_ended_newline = True def feed(self, text: str) -> str: """Feed one delta; return the scrubbed visible portion. May return an empty string when the entire delta is reasoning content or is being held back pending resolution of a partial tag at the boundary. """ if not text: return "" buf = self._buf + text self._buf = "" out: list[str] = [] while buf: if self._in_block: # Hunt for the earliest close tag. close_idx, close_len = self._find_first_tag( buf, self._CLOSE_TAGS, ) if close_idx == -1: # No close yet — hold back a potential partial # close-tag prefix; discard everything else. held = self._max_partial_suffix(buf, self._CLOSE_TAGS) self._buf = buf[-held:] if held else "" return "".join(out) # Found close: discard block content + tag, continue. buf = buf[close_idx + close_len:] self._in_block = False else: # Priority 1 — closed X pair anywhere in # buf. Closed pairs are always an intentional, # bounded construct (even mid-line prose containing # an open/close pair is almost certainly a model # leaking reasoning inline), so no boundary gating. pair = self._find_earliest_closed_pair(buf) # Priority 2 — unterminated open tag at a block # boundary. Boundary-gated so prose that mentions # '' isn't over-stripped. open_idx, open_len = self._find_open_at_boundary( buf, out, ) # Pick whichever match comes earliest in the buffer. if pair is not None and ( open_idx == -1 or pair[0] <= open_idx ): start_idx, end_idx = pair preceding = buf[:start_idx] if preceding: preceding = self._strip_orphan_close_tags(preceding) if preceding: out.append(preceding) self._last_emitted_ended_newline = ( preceding.endswith("\n") ) buf = buf[end_idx:] continue if open_idx != -1: # Unterminated open at boundary — emit preceding, # enter block, continue loop with remainder. preceding = buf[:open_idx] if preceding: preceding = self._strip_orphan_close_tags(preceding) if preceding: out.append(preceding) self._last_emitted_ended_newline = ( preceding.endswith("\n") ) self._in_block = True buf = buf[open_idx + open_len:] continue # No resolvable tag structure in buf. Hold back any # partial-tag prefix at the tail so a split tag # across deltas isn't missed, then emit the rest. held = self._max_partial_suffix(buf, self._OPEN_TAGS) held_close = self._max_partial_suffix( buf, self._CLOSE_TAGS, ) held = max(held, held_close) if held: emit_text = buf[:-held] self._buf = buf[-held:] else: emit_text = buf self._buf = "" if emit_text: emit_text = self._strip_orphan_close_tags(emit_text) if emit_text: out.append(emit_text) self._last_emitted_ended_newline = ( emit_text.endswith("\n") ) return "".join(out) return "".join(out) def flush(self) -> str: """End-of-stream flush. If still inside an unterminated block, held-back content is discarded — leaking partial reasoning is worse than a truncated answer. Otherwise the held-back partial-tag tail is emitted verbatim (it turned out not to be a real tag prefix). """ if self._in_block: self._buf = "" self._in_block = False return "" tail = self._buf self._buf = "" if not tail: return "" tail = self._strip_orphan_close_tags(tail) if tail: self._last_emitted_ended_newline = tail.endswith("\n") return tail # ── internal helpers ─────────────────────────────────────────────── @staticmethod def _find_first_tag( buf: str, tags: Tuple[str, ...], ) -> Tuple[int, int]: """Return (earliest_index, tag_length) over *tags*, or (-1, 0). Case-insensitive match. """ buf_lower = buf.lower() best_idx = -1 best_len = 0 for tag in tags: idx = buf_lower.find(tag.lower()) if idx != -1 and (best_idx == -1 or idx < best_idx): best_idx = idx best_len = len(tag) return best_idx, best_len def _find_earliest_closed_pair(self, buf: str): """Return (start_idx, end_idx) of the earliest closed pair, else None. A closed pair is ``...`` of any variant. Matches are case-insensitive and non-greedy (the closest close tag after an open tag wins), matching the regex ``.*?`` semantics of ``_strip_think_blocks`` case 1. When two tag variants could both match, the one whose open tag appears earlier wins. """ buf_lower = buf.lower() best: "tuple[int, int] | None" = None for open_tag, close_tag in zip(self._OPEN_TAGS, self._CLOSE_TAGS): open_lower = open_tag.lower() close_lower = close_tag.lower() open_idx = buf_lower.find(open_lower) if open_idx == -1: continue close_idx = buf_lower.find( close_lower, open_idx + len(open_lower), ) if close_idx == -1: continue end_idx = close_idx + len(close_lower) if best is None or open_idx < best[0]: best = (open_idx, end_idx) return best def _find_open_at_boundary( self, buf: str, already_emitted: list[str], ) -> Tuple[int, int]: """Return the earliest block-boundary open-tag (idx, len). Returns (-1, 0) if no boundary-legal opener is present. """ buf_lower = buf.lower() best_idx = -1 best_len = 0 for tag in self._OPEN_TAGS: tag_lower = tag.lower() search_start = 0 while True: idx = buf_lower.find(tag_lower, search_start) if idx == -1: break if self._is_block_boundary(buf, idx, already_emitted): if best_idx == -1 or idx < best_idx: best_idx = idx best_len = len(tag) break # first boundary hit for this tag is enough search_start = idx + 1 return best_idx, best_len def _is_block_boundary( self, buf: str, idx: int, already_emitted: list[str], ) -> bool: """True iff position *idx* in *buf* is a block boundary. A block boundary is: - buf position 0 AND the most recent emission ended with a newline (or nothing has been emitted yet) - any position whose preceding text on the current line (since the last newline in buf) is whitespace-only, AND if there is no newline in the preceding buf portion, the most recent prior emission ended with a newline """ if idx == 0: # Check whether the last already-emitted chunk in THIS # feed() call ended with a newline, otherwise fall back # to the cross-feed flag. if already_emitted: return already_emitted[-1].endswith("\n") return self._last_emitted_ended_newline preceding = buf[:idx] last_nl = preceding.rfind("\n") if last_nl == -1: # No newline in buf before the tag — boundary only if the # prior emission ended with a newline AND everything since # is whitespace. if already_emitted: prior_newline = already_emitted[-1].endswith("\n") else: prior_newline = self._last_emitted_ended_newline return prior_newline and preceding.strip() == "" # Newline present — text between it and the tag must be # whitespace-only. return preceding[last_nl + 1:].strip() == "" @classmethod def _max_partial_suffix( cls, buf: str, tags: Tuple[str, ...], ) -> int: """Return the longest buf-suffix that is a prefix of any tag. Only prefixes strictly shorter than the tag itself count (full-length suffixes are the tag and are handled as matches, not held-back partials). Case-insensitive. """ if not buf: return 0 buf_lower = buf.lower() max_check = min(len(buf_lower), cls._MAX_TAG_LEN - 1) for i in range(max_check, 0, -1): suffix = buf_lower[-i:] for tag in tags: tag_lower = tag.lower() if len(tag_lower) > i and tag_lower.startswith(suffix): return i return 0 @classmethod def _strip_orphan_close_tags(cls, text: str) -> str: """Remove any close tags from *text* (orphan-close handling). An orphan close tag has no matching open in the current scrubber state; it's always noise, stripped with any trailing whitespace so the surrounding prose flows naturally. """ if "