
Automatic Chain-of-Thought

An implementation of the Automatic Chain-of-Thought (Auto-CoT) prompting strategy from Zhang et al. (2022).

cogitator.strategies.auto_cot.AutoCoT

Implements the Automatic Chain-of-Thought (Auto-CoT) prompting strategy.

Auto-CoT aims to automatically construct demonstrations for few-shot CoT prompting by clustering questions and selecting diverse examples, then generating CoT reasoning for them using zero-shot prompts.

Reference

Zhang et al. (2022) "Automatic Chain of Thought Prompting in Large Language Models". https://arxiv.org/abs/2210.03493
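
A minimal usage sketch; my_llm is a placeholder for any concrete BaseLLM implementation, and the question list is illustrative:

from cogitator.strategies.auto_cot import AutoCoT

# my_llm: placeholder for a concrete BaseLLM implementation (hypothetical).
strategy = AutoCoT(my_llm, n_demos=4, rand_seed=42)

# Build demonstrations from an unlabeled question pool (>= n_demos questions).
pool = [
    "If a pen costs 2 dollars, how much do 7 pens cost?",
    "A train travels 60 km in 1.5 hours. What is its average speed?",
    # ... more questions, at least n_demos in total
]
strategy.fit(pool)

# Answer a new question with the auto-generated few-shot CoT prompt.
answer = strategy.run("Sam has 12 apples and gives away 5. How many are left?")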

Source code in cogitator/strategies/auto_cot.py
class AutoCoT:
    """Implements the Automatic Chain-of-Thought (Auto-CoT) prompting strategy.

    Auto-CoT aims to automatically construct demonstrations for few-shot CoT prompting
    by clustering questions and selecting diverse examples, then generating CoT
    reasoning for them using zero-shot prompts.

    Reference:
        Zhang et al. (2022) "Automatic Chain of Thought Prompting in Large Language Models".
        https://arxiv.org/abs/2210.03493
    """

    def __init__(
        self,
        llm: BaseLLM,
        n_demos: int = 8,
        max_q_tokens: int = 60,
        max_steps: int = 5,
        *,
        prompt_template: str = "Let's think step by step.",
        max_retries: int = 2,
        max_tokens: Optional[int] = None,
        rand_seed: Optional[int] = None,
        embedder: Optional[BaseEmbedder] = None,
        clusterer: Optional[BaseClusterer] = None,
    ) -> None:
        """Initializes the AutoCoT strategy handler.

        Args:
            llm: The language model instance to use for generation.
            n_demos: The desired number of demonstrations to generate.
            max_q_tokens: Maximum approximate token length for questions selected as demos.
            max_steps: Maximum number of reasoning steps allowed in a generated demo CoT.
            prompt_template: The zero-shot prompt template used to generate CoT reasoning.
            max_retries: Maximum number of retries for generating a CoT demo if LLM fails.
            max_tokens: Maximum tokens for LLM generation calls (demos and final answer).
            rand_seed: Base random seed for clustering and LLM seeding. LLM calls will
                       use variations of this seed.
            embedder: The embedding model instance. Defaults to SentenceTransformerEmbedder.
            clusterer: The clustering algorithm instance. Defaults to KMeansClusterer.
        """
        self.llm = llm
        self.n_demos = n_demos
        self.max_q_tokens = max_q_tokens
        self.max_steps = max_steps
        self.prompt_template = prompt_template
        self.max_retries = max_retries
        self.max_tokens = max_tokens
        self.rand_seed = rand_seed
        self.embedder = embedder or SentenceTransformerEmbedder()
        self.clusterer = clusterer or KMeansClusterer()
        self.demos: Optional[List[str]] = None

    def fit(self, questions: List[str]) -> None:
        """Builds the demonstration pool using the Auto-CoT process.

        This involves embedding questions, clustering them, selecting diverse
        representatives, generating CoT reasoning for them using varied seeds,
        and filtering based on length and step count criteria.

        Args:
            questions: A list of questions to build demonstrations from.

        Raises:
            ValueError: If the number of questions is lower than `n_demos`.
            RuntimeError: If embedding or clustering fails, or if no valid demos
                can be generated.
        """
        if len(questions) < self.n_demos:
            raise ValueError(f"Need >= {self.n_demos} questions, got {len(questions)}")

        logger.info("Encoding questions for AutoCoT fitting...")
        embs_list = self.embedder.encode(questions)
        if len(embs_list) == 0:
            raise RuntimeError("Embedding failed to produce results.")
        embs = np.stack(embs_list)

        logger.info("Clustering questions...")
        labels, centers = self.clusterer.cluster(
            embs, self.n_demos, random_seed=self.rand_seed or 0
        )

        logger.info("Selecting candidate demonstrations...")
        candidate_demos: List[Tuple[int, str]] = []
        for c in range(self.n_demos):
            idxs_in_cluster = np.where(labels == c)[0]
            if idxs_in_cluster.size == 0:
                logger.debug(f"Cluster {c} is empty, skipping.")
                continue

            # Calculate distances within the cluster
            cluster_embs = embs[idxs_in_cluster]
            dists = np.linalg.norm(cluster_embs - centers[c], axis=1)
            sorted_relative_indices = np.argsort(dists)

            # Iterate through questions closest to the centroid first
            found_candidate = False
            for relative_idx in sorted_relative_indices:
                original_idx = idxs_in_cluster[relative_idx]
                q = questions[original_idx]
                if approx_token_length(q) <= self.max_q_tokens:
                    candidate_demos.append((original_idx, q))
                    logger.debug(f"Selected candidate from cluster {c}: Q index {original_idx}")
                    found_candidate = True
                    break  # Take only the first valid one closest to centroid

            if not found_candidate:
                logger.debug(f"No suitable candidate found for cluster {c} within token limits.")

        logger.info(f"Generating CoT reasoning for {len(candidate_demos)} candidates...")
        demos: List[str] = []
        # Generate a CoT for each selected candidate, retrying with varied seeds
        for original_q_idx, q in candidate_demos:
            prompt = f"Q: {q}\nA: {self.prompt_template}"
            cot: Optional[str] = None
            for attempt in range(self.max_retries + 1):
                # Vary the seed by question index and attempt number so each
                # candidate and each retry uses a distinct LLM seed
                iter_seed: Optional[int] = None
                if self.rand_seed is not None:
                    iter_seed = self.rand_seed + original_q_idx + attempt * 101

                try:
                    logger.debug(
                        f"Attempt {attempt + 1} for Q idx {original_q_idx} with seed {iter_seed}"
                    )
                    cot = self.llm.generate(
                        prompt,
                        max_tokens=self.max_tokens,
                        seed=iter_seed,  # Use the varied seed
                    )
                    break  # Success
                except Exception as e:
                    logger.warning(
                        f"Retry {attempt + 1}/{self.max_retries + 1} for demo Q idx {original_q_idx}: {e}",
                        exc_info=(logger.getEffectiveLevel() <= logging.DEBUG),  # traceback at debug level
                    )
                    if attempt < self.max_retries:
                        # Exponential backoff before retrying
                        time.sleep(0.5 * (2**attempt))

            if cot is None:
                logger.error(
                    "Failed to generate demo for Q idx %d ('%s') after %d retries",
                    original_q_idx,
                    q[:50] + "...",
                    self.max_retries + 1,
                )
                continue  # Skip this candidate

            # Filter based on step count
            steps_count = count_steps(cot)
            if steps_count <= self.max_steps:
                demos.append(f"Q: {q}\nA: {cot}")
                logger.debug(
                    f"Successfully generated and filtered demo for Q idx {original_q_idx} ({steps_count} steps)"
                )
            else:
                logger.debug(
                    f"Generated demo for Q idx {original_q_idx} discarded ({steps_count} steps > max {self.max_steps})"
                )

        if len(demos) < self.n_demos:
            logger.warning(
                "Could only build %d final demos (needed %d). Proceeding with available demos.",
                len(demos),
                self.n_demos,
            )
        if not demos:
            logger.error("Failed to build any valid demos after generation and filtering.")
            raise RuntimeError("Failed to build any valid demos.")

        self.demos = demos
        logger.info(f"AutoCoT fitting complete. Generated {len(demos)} demonstrations.")

    async def fit_async(
        self, questions: List[str], semaphore: Optional[asyncio.Semaphore] = None
    ) -> None:
        """Asynchronously builds the demonstration pool using the Auto-CoT process.

        Similar to `fit`, but performs LLM generation calls asynchronously
        using varied seeds.

        Args:
            questions: A list of questions to build demonstrations from.
            semaphore: An optional asyncio.Semaphore to limit concurrent LLM calls.

        Raises:
            ValueError: If the number of questions is lower than `n_demos`.
            RuntimeError: If embedding or clustering fails, or if no valid demos
                can be generated.
        """
        if len(questions) < self.n_demos:
            raise ValueError(f"Need >= {self.n_demos} questions, got {len(questions)}")

        logger.info("Encoding questions for async AutoCoT fitting...")
        embs_list = self.embedder.encode(questions)
        if len(embs_list) == 0:
            raise RuntimeError("Embedding failed to produce results.")
        embs = np.stack(embs_list)

        logger.info("Clustering questions async...")
        labels, centers = self.clusterer.cluster(
            embs, self.n_demos, random_seed=self.rand_seed or 0
        )

        logger.info("Selecting candidate demonstrations async...")
        candidate_demos_info: List[Tuple[int, str]] = []
        for c in range(self.n_demos):
            idxs_in_cluster = np.where(labels == c)[0]
            if idxs_in_cluster.size == 0:
                continue
            cluster_embs = embs[idxs_in_cluster]
            dists = np.linalg.norm(cluster_embs - centers[c], axis=1)
            sorted_relative_indices = np.argsort(dists)
            for relative_idx in sorted_relative_indices:
                original_idx = idxs_in_cluster[relative_idx]
                q = questions[original_idx]
                if approx_token_length(q) <= self.max_q_tokens:
                    candidate_demos_info.append((original_idx, q))
                    break

        logger.info(f"Generating CoT reasoning async for {len(candidate_demos_info)} candidates...")

        async def generate_demo(idx: int, q: str) -> Tuple[int, str, Optional[str]]:
            prompt = f"Q: {q}\nA: {self.prompt_template}"
            for attempt in range(self.max_retries + 1):
                iter_seed: Optional[int] = None
                if self.rand_seed is not None:
                    iter_seed = self.rand_seed + idx + attempt * 101

                try:
                    logger.debug(
                        f"Async attempt {attempt + 1} for Q idx {idx} with seed {iter_seed}"
                    )
                    gen_args = {
                        "max_tokens": self.max_tokens,
                        "seed": iter_seed,
                    }

                    # When no semaphore is supplied, each call creates its own
                    # Semaphore(1), so concurrency is not limited across tasks.
                    local_semaphore = semaphore or asyncio.Semaphore(1)
                    async with local_semaphore:
                        cot = await self.llm.generate_async(prompt, **gen_args)
                    return idx, q, cot
                except Exception as e:
                    logger.warning(
                        f"Async retry {attempt + 1}/{self.max_retries + 1} for demo Q idx {idx}: {e}",
                        exc_info=(logger.getEffectiveLevel() <= logging.DEBUG),
                    )
                    if attempt < self.max_retries:
                        await asyncio.sleep(0.5 * (2**attempt))

            logger.error(
                "Failed to generate async demo for Q idx %d ('%s') after %d retries",
                idx,
                q[:50] + "...",
                self.max_retries + 1,
            )
            return idx, q, None  # Failed after retries

        tasks = [generate_demo(idx, q) for idx, q in candidate_demos_info]
        results = await asyncio.gather(*tasks, return_exceptions=True)

        demos: List[str] = []
        successful_generations = 0
        for res in results:
            if isinstance(res, Exception):
                if not isinstance(res, asyncio.CancelledError):
                    logger.error(f"Async demo generation task failed: {res}", exc_info=True)
                continue

            if isinstance(res, tuple) and len(res) == 3:
                _idx, q, cot = res
                if cot is not None:
                    successful_generations += 1
                    steps_count = count_steps(cot)
                    if steps_count <= self.max_steps:
                        demos.append(f"Q: {q}\nA: {cot}")
                        logger.debug(
                            f"Successfully generated and filtered async demo for Q idx {_idx} ({steps_count} steps)"
                        )
                    else:
                        logger.debug(
                            f"Async demo for Q idx {_idx} discarded ({steps_count} steps > max {self.max_steps})"
                        )
            else:
                logger.error(f"Unexpected result type from gather: {type(res)} - {res}")

        if len(demos) < self.n_demos:
            logger.warning(
                "Could only build %d final demos async (needed %d). Proceeding with available demos.",
                len(demos),
                self.n_demos,
            )
        if not demos:
            logger.error(
                "Failed to build any valid demos asynchronously after generation and filtering."
            )
            raise RuntimeError("Failed to build any valid demos asynchronously.")

        self.demos = demos
        logger.info(f"Async AutoCoT fitting complete. Generated {len(demos)} demonstrations.")

    def run(self, test_q: str, **kwargs: Any) -> str:
        """Runs the Auto-CoT strategy for a given test question.

        Constructs a prompt using the generated demonstrations and the test question,
        then calls the LLM to generate the final answer. The base seed is used
        for this final generation unless overridden in kwargs.

        Args:
            test_q: The test question to answer.
            **kwargs: Additional arguments passed to the LLM generation call,
                      potentially overriding default seed, max_tokens, etc.

        Returns:
            The LLM-generated answer string.

        Raises:
            RuntimeError: If `fit` or `fit_async` has not been called successfully first.
        """
        if self.demos is None:
            raise RuntimeError("Call fit() or fit_async() before run()")

        context = "\n\n".join(self.demos)
        payload = f"{context}\n\nQ: {test_q}\nA: {self.prompt_template}"
        logger.debug(
            "AutoCoT final inference payload:\n%s", payload[:500] + "..."
        )  # Log truncated payload

        final_seed = kwargs.pop("seed", self.rand_seed)
        final_max_tokens = kwargs.pop("max_tokens", self.max_tokens)

        logger.info(f"Running final AutoCoT generation for question: '{test_q[:50]}...'")
        try:
            result = self.llm.generate(
                payload,
                max_tokens=final_max_tokens,
                seed=final_seed,
                **kwargs,
            )
            logger.info("Final generation successful.")
            return result
        except Exception as e:
            logger.error(f"Final AutoCoT generation failed: {e}", exc_info=True)
            # Depending on desired behavior, either raise e or return an error marker
            # raise e
            return "[ERROR: Final generation failed]"

    async def run_async(self, test_q: str, **kwargs: Any) -> str:
        """Asynchronously runs the Auto-CoT strategy for a given test question.

        Constructs a prompt using the generated demonstrations and the test question,
        then calls the LLM asynchronously to generate the final answer. The base
        seed is used for this final generation unless overridden in kwargs.

        Args:
            test_q: The test question to answer.
            **kwargs: Additional arguments passed to the async LLM generation call,
                      potentially overriding default seed, max_tokens, etc.

        Returns:
            The LLM-generated answer string.

        Raises:
            RuntimeError: If `fit` or `fit_async` has not been called successfully first.
        """
        if self.demos is None:
            raise RuntimeError("Call fit() or fit_async() before run_async()")

        context = "\n\n".join(self.demos)
        payload = f"{context}\n\nQ: {test_q}\nA: {self.prompt_template}"
        logger.debug("Async AutoCoT final inference payload:\n%s", payload[:500] + "...")

        final_seed = kwargs.pop("seed", self.rand_seed)
        final_max_tokens = kwargs.pop("max_tokens", self.max_tokens)

        logger.info(f"Running final async AutoCoT generation for question: '{test_q[:50]}...'")
        try:
            semaphore = kwargs.pop("semaphore", None)
            gen_args = {
                "max_tokens": final_max_tokens,
                "seed": final_seed,
                **kwargs,
            }
            if semaphore:
                async with semaphore:
                    result = await self.llm.generate_async(payload, **gen_args)
            else:
                result = await self.llm.generate_async(payload, **gen_args)

            logger.info("Final async generation successful.")
            return result
        except Exception as e:
            logger.error(f"Final async AutoCoT generation failed: {e}", exc_info=True)
            # raise e
            return "[ERROR: Final async generation failed]"

__init__(llm, n_demos=8, max_q_tokens=60, max_steps=5, *, prompt_template="Let's think step by step.", max_retries=2, max_tokens=None, rand_seed=None, embedder=None, clusterer=None)

Initializes the AutoCoT strategy handler.

Parameters:

llm (BaseLLM, required):
    The language model instance to use for generation.
n_demos (int, default 8):
    The desired number of demonstrations to generate.
max_q_tokens (int, default 60):
    Maximum approximate token length for questions selected as demos.
max_steps (int, default 5):
    Maximum number of reasoning steps allowed in a generated demo CoT.
prompt_template (str, default "Let's think step by step."):
    The zero-shot prompt template used to generate CoT reasoning.
max_retries (int, default 2):
    Maximum number of retries for generating a CoT demo if LLM fails.
max_tokens (Optional[int], default None):
    Maximum tokens for LLM generation calls (demos and final answer).
rand_seed (Optional[int], default None):
    Base random seed for clustering and LLM seeding. LLM calls will use variations of this seed.
embedder (Optional[BaseEmbedder], default None):
    The embedding model instance. Defaults to SentenceTransformerEmbedder.
clusterer (Optional[BaseClusterer], default None):
    The clustering algorithm instance. Defaults to KMeansClusterer.
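
A configuration sketch of the keyword options (values illustrative; my_llm is again a placeholder BaseLLM instance):

strategy = AutoCoT(
    my_llm,             # placeholder BaseLLM instance (hypothetical)
    n_demos=4,          # also the number of clusters requested during fitting
    max_q_tokens=60,    # skip overly long questions when selecting demos
    max_steps=5,        # discard generated CoTs with more reasoning steps
    max_retries=2,      # per-demo retry budget for failed LLM calls
    rand_seed=42,       # seeds clustering; LLM calls use variations of it
)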

fit(questions)

Builds the demonstration pool using the Auto-CoT process.

This involves embedding questions, clustering them, selecting diverse representatives, generating CoT reasoning for them using varied seeds, and filtering based on length and step count criteria.
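
Each accepted demonstration is stored as the question followed by its generated reasoning, in the same Q/A format used at inference time (contents illustrative):

Q: Olivia has 23 dollars and buys 5 bagels at 3 dollars each. How much money does she have left?
A: Let's think step by step. The bagels cost 5 * 3 = 15 dollars, and 23 - 15 = 8 dollars are left.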

Parameters:

questions (List[str], required):
    A list of questions to build demonstrations from.

Raises:

ValueError:
    If the number of questions is lower than n_demos.
RuntimeError:
    If embedding or clustering fails, or if no valid demos can be generated.
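
Continuing the earlier sketch (load_questions is a hypothetical helper returning a List[str]):

questions = load_questions()
try:
    strategy.fit(questions)
except ValueError:
    print("Provide at least strategy.n_demos questions.")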


fit_async(questions, semaphore=None) async

Asynchronously builds the demonstration pool using the Auto-CoT process.

Similar to fit, but performs LLM generation calls asynchronously using varied seeds.

Parameters:

questions (List[str], required):
    A list of questions to build demonstrations from.
semaphore (Optional[asyncio.Semaphore], default None):
    An optional asyncio.Semaphore to limit concurrent LLM calls.

Raises:

ValueError:
    If the number of questions is lower than n_demos.
RuntimeError:
    If embedding or clustering fails, or if no valid demos can be generated.
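
A sketch of bounded-concurrency fitting, continuing the earlier example; the semaphore size is illustrative:

import asyncio

async def build_demos() -> None:
    sem = asyncio.Semaphore(4)  # at most 4 concurrent LLM calls
    await strategy.fit_async(questions, semaphore=sem)

asyncio.run(build_demos())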


run(test_q, **kwargs)

Runs the Auto-CoT strategy for a given test question.

Constructs a prompt using the generated demonstrations and the test question, then calls the LLM to generate the final answer. The base seed is used for this final generation unless overridden in kwargs.

Parameters:

test_q (str, required):
    The test question to answer.
**kwargs (Any, default {}):
    Additional arguments passed to the LLM generation call, potentially overriding default seed, max_tokens, etc.

Returns:

str:
    The LLM-generated answer string.

Raises:

RuntimeError:
    If fit or fit_async has not been called successfully first.
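
For example, overriding the instance defaults for a single call (values illustrative; strategy is assumed to be fitted):

answer = strategy.run(
    "A shop sells 3 pencils for 2 dollars. How much do 9 pencils cost?",
    seed=123,        # overrides rand_seed for this call only
    max_tokens=256,  # overrides the instance-level max_tokens
)
print(answer)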


run_async(test_q, **kwargs) async

Asynchronously runs the Auto-CoT strategy for a given test question.

Constructs a prompt using the generated demonstrations and the test question, then calls the LLM asynchronously to generate the final answer. The base seed is used for this final generation unless overridden in kwargs.

Parameters:

test_q (str, required):
    The test question to answer.
**kwargs (Any, default {}):
    Additional arguments passed to the async LLM generation call, potentially overriding default seed, max_tokens, etc.

Returns:

str:
    The LLM-generated answer string.

Raises:

RuntimeError:
    If fit or fit_async has not been called successfully first.
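
A sketch of answering several questions concurrently; run_async also accepts a semaphore through kwargs to bound concurrent LLM calls (strategy is assumed to be fitted, and the question strings are illustrative):

import asyncio
from typing import List

async def answer_many(test_questions: List[str]) -> List[str]:
    sem = asyncio.Semaphore(4)  # shared limit across concurrent runs
    tasks = [strategy.run_async(q, semaphore=sem) for q in test_questions]
    return await asyncio.gather(*tasks)

answers = asyncio.run(answer_many(["Question one?", "Question two?"]))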
