class OpenAILLM(BaseLLM):
"""LLM provider implementation for OpenAI API models.
Handles interactions with models like GPT-4, GPT-4o, etc., supporting
standard generation, streaming, JSON mode, and structured outputs where available.
Includes retry logic for common API errors.
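
    Example (a minimal usage sketch; the API key is a placeholder):

        llm = OpenAILLM(api_key="sk-...", model="gpt-4o-mini")
        print(llm.generate("Say hello in one sentence."))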
"""
_STRUCTURED_OUTPUT_SUPPORTING_MODELS = {
"gpt-4o",
"gpt-4o-mini",
"gpt-4o-2024-08-06",
"gpt-4o-mini-2024-07-18",
}
_JSON_MODE_SUPPORTING_MODELS = {
"gpt-4",
"gpt-4-turbo",
"gpt-4-turbo-preview",
"gpt-3.5-turbo-1106",
"gpt-3.5-turbo-0125",
} | _STRUCTURED_OUTPUT_SUPPORTING_MODELS
def __init__(
self,
api_key: str,
        model: str = "gpt-4o-mini",
temperature: float = 0.7,
max_tokens: int = 512,
stop: Optional[List[str]] = None,
seed: Optional[int] = 33,
retry_attempts: int = 3,
retry_backoff: float = 1.0,
) -> None:
"""Initializes the OpenAILLM provider.
Args:
api_key: Your OpenAI API key.
model: The OpenAI model identifier (e.g., "gpt-4o", "gpt-3.5-turbo").
temperature: The sampling temperature for generation.
max_tokens: The maximum number of tokens to generate.
stop: A list of sequences where the API will stop generation.
seed: The random seed for reproducibility (if supported by the model).
retry_attempts: Number of retries upon API call failure.
retry_backoff: Initial backoff factor for retries (exponential).
"""
super().__init__() # Call BaseLLM init
self.client = SyncOpenAI(api_key=api_key)
self.async_client = AsyncOpenAI(api_key=api_key)
self.model = model
self.temperature = temperature
self.max_tokens = max_tokens
self.stop = stop
self.seed = seed
self.retry_attempts = retry_attempts
self.retry_backoff = retry_backoff
# Load tiktoken encoding
try:
self.encoding = tiktoken.encoding_for_model(self.model)
except KeyError:
logger.warning(f"No tiktoken encoding found for model {self.model}. Using cl100k_base.")
self.encoding = tiktoken.get_encoding("cl100k_base")
logger.info(f"Initialized OpenAILLM with model: {self.model}")
def _update_token_counts(
self, prompt: str, response: Any, completion_text: Optional[str]
) -> None:
"""Updates token counts using API response or tiktoken."""
prompt_tokens = None
completion_tokens = None
source = "unknown"
if hasattr(response, "usage") and response.usage:
prompt_tokens = getattr(response.usage, "prompt_tokens", None)
completion_tokens = getattr(response.usage, "completion_tokens", None)
source = "api"
        # Fall back to tiktoken when the API response lacks usage data
if prompt_tokens is None or completion_tokens is None:
source = "tiktoken"
try:
if prompt:
prompt_tokens = len(self.encoding.encode(prompt))
if completion_text:
completion_tokens = len(self.encoding.encode(completion_text))
else:
completion_tokens = 0 # Set completion to 0 if text is None or empty
except Exception as e:
logger.warning(f"tiktoken encoding failed during fallback: {e}", exc_info=False)
# Keep existing API values if only fallback failed, otherwise reset
if not (hasattr(response, "usage") and response.usage):
self._reset_token_counts()
return # Exit if fallback fails and no API data exists
# Store the determined values
self._last_prompt_tokens = prompt_tokens
self._last_completion_tokens = completion_tokens
logger.debug(
f"Token usage ({source}): P={self._last_prompt_tokens}, C={self._last_completion_tokens}"
)
def _prepare_api_params(
self,
is_json_mode: bool = False,
response_schema: Optional[Type[BaseModel]] = None,
**kwargs: Any,
) -> Tuple[Dict[str, Any], Optional[str]]:
"""Prepares the parameters dictionary for the OpenAI API call.
Determines the appropriate 'response_format' based on whether JSON mode
or structured output is requested, the provided schema, and model support.
Args:
is_json_mode: Flag indicating if JSON output is requested.
response_schema: The Pydantic model if structured output is desired.
**kwargs: Additional parameters to pass to the API call, overriding defaults.
Returns:
A tuple containing:
- The dictionary of parameters ready for the API call.
- A string indicating the JSON mode used ('json_schema', 'json_object', None),
used for downstream processing logic.
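
        Example (illustrative; `Answer` is a hypothetical Pydantic model):

            class Answer(BaseModel):
                text: str

            params, mode = llm._prepare_api_params(
                is_json_mode=True,
                response_schema=Answer,
                model="gpt-4o",
                messages=[{"role": "user", "content": "..."}],
            )
            assert mode == "json_schema"
            assert params["response_format"]["type"] == "json_schema"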
"""
params = kwargs.copy()
mode_used: Optional[str] = None
supports_structured = any(
self.model.startswith(known) for known in self._STRUCTURED_OUTPUT_SUPPORTING_MODELS
)
supports_json_object = any(
self.model.startswith(known) for known in self._JSON_MODE_SUPPORTING_MODELS
)
        # Builds the strict json_schema response_format from the Pydantic model.
        # Note: strict mode expects additionalProperties: false on every object;
        # only the top level is set here, so deeply nested models may need
        # additional schema post-processing.
        def _schema_response_format() -> Dict[str, Any]:
            schema_dict = response_schema.model_json_schema()
            if schema_dict.get("type") == "object":
                schema_dict["additionalProperties"] = False
            return {
                "type": "json_schema",
                "json_schema": {
                    "name": response_schema.__name__,
                    "description": response_schema.__doc__
                    or f"Schema for {response_schema.__name__}",
                    "strict": True,  # Enable strict schema validation
                    "schema": schema_dict,
                },
            }

        if is_json_mode:
            if response_schema:
                if supports_structured:
                    try:
                        params["response_format"] = _schema_response_format()
                        mode_used = "json_schema"
                        logger.debug(
                            f"Using OpenAI Structured Outputs (json_schema) for model: {self.model}"
                        )
                    except Exception as e:
                        logger.warning(
                            f"Failed to generate/set JSON schema for {response_schema.__name__}: {e}. Falling back."
                        )
                        # Fall back to json_object if schema generation fails but the model supports it
                        if supports_json_object:
                            params["response_format"] = {"type": "json_object"}
                            mode_used = "json_object"
                            logger.debug(
                                f"Fell back to OpenAI JSON mode (json_object) after schema failure for model: {self.model}"
                            )
                        else:
                            mode_used = None  # Cannot use JSON mode
                            logger.debug(
                                "Fallback failed; JSON mode not supported. Relying on extraction."
                            )
                elif supports_json_object:
                    # Model supports json_object but not full structured output
                    params["response_format"] = {"type": "json_object"}
                    mode_used = "json_object"
                    logger.debug(
                        f"Model {self.model} supports only json_object; using it despite a schema being provided."
                    )
                else:
                    # Model is not known to support either mode; attempt structured
                    # output anyway since a schema was provided.
                    logger.warning(
                        f"Model {self.model} not known to support JSON modes. Attempting json_schema anyway as a schema was provided..."
                    )
                    try:
                        params["response_format"] = _schema_response_format()
                        mode_used = "json_schema"
                        logger.debug(
                            "Attempting OpenAI Structured Outputs (json_schema) on a potentially unsupported model..."
                        )
                    except Exception as e:
                        logger.warning(
                            f"Failed to generate/set JSON schema for unsupported model attempt: {e}. Relying on extraction."
                        )
                        mode_used = None
            else:  # is_json_mode is True, but no response_schema provided
                if supports_json_object:
                    params["response_format"] = {"type": "json_object"}
                    mode_used = "json_object"
                    logger.debug("Using OpenAI JSON mode (json_object) as no schema was provided.")
                else:
                    mode_used = None  # Cannot use JSON mode
                    logger.debug(
                        "JSON requested, no schema, and model does not support json_object. Relying on extraction."
                    )
        else:  # is_json_mode is False
            mode_used = None
        # Add the instance seed if the caller did not supply one
        if "seed" not in params and self.seed is not None:
            params["seed"] = self.seed
        # Ensure seed is an integer, as the API requires
        if params.get("seed") is not None:
            try:
                params["seed"] = int(params["seed"])
            except (ValueError, TypeError):
                logger.warning(
                    f"Could not convert seed value {params['seed']} to int. Omitting seed."
                )
                del params["seed"]
return params, mode_used
def _call_api(
self,
is_json_mode: bool = False,
response_schema: Optional[Type[BaseModel]] = None,
**kwargs: Any,
) -> Tuple[Any, Optional[str]]:
"""Makes a synchronous call to the OpenAI chat completions API with retries."""
attempts = 0
api_params, mode_used = self._prepare_api_params(
is_json_mode=is_json_mode, response_schema=response_schema, **kwargs
)
        # Capture the prompt text (last message) for tiktoken fallback counting
prompt_for_count = (
api_params["messages"][-1]["content"] if api_params.get("messages") else ""
)
self._reset_token_counts() # Reset before attempting call
while True:
try:
completion = self.client.chat.completions.create(**api_params)
# Extract completion text for token counting
completion_text = ""
if completion.choices:
msg = getattr(completion.choices[0], "message", None)
if msg:
completion_text = getattr(msg, "content", "") or ""
# Update token counts using API response or tiktoken
self._update_token_counts(prompt_for_count, completion, completion_text)
return completion, mode_used
except openai.OpenAIError as e:
self._reset_token_counts() # Reset counts on error
attempts += 1
if attempts > self.retry_attempts:
logger.error(f"OpenAI API call failed after {attempts} attempts: {e}")
raise
logger.warning(
f"OpenAI API error (attempt {attempts}/{self.retry_attempts + 1}): {e}. Retrying..."
)
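                # Exponential backoff: retry_backoff * 2**(attempt - 1) seconds
                # (e.g. 1s, 2s, 4s with the defaults).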
time.sleep(self.retry_backoff * (2 ** (attempts - 1)))
except Exception as e:
self._reset_token_counts() # Reset counts on unexpected error
logger.error(f"Unexpected error during OpenAI API call: {e}", exc_info=True)
raise
async def _call_api_async(
self,
is_json_mode: bool = False,
response_schema: Optional[Type[BaseModel]] = None,
**kwargs: Any,
) -> Tuple[Any, Optional[str]]:
"""Makes an asynchronous call to the OpenAI chat completions API with retries."""
attempts = 0
api_params, mode_used = self._prepare_api_params(
is_json_mode=is_json_mode, response_schema=response_schema, **kwargs
)
prompt_for_count = (
api_params["messages"][-1]["content"] if api_params.get("messages") else ""
)
self._reset_token_counts() # Reset before attempting call
while True:
try:
completion = await self.async_client.chat.completions.create(**api_params)
# Extract completion text for token counting
completion_text = ""
if completion.choices:
msg = getattr(completion.choices[0], "message", None)
if msg:
completion_text = getattr(msg, "content", "") or ""
# Update token counts using API response or tiktoken
self._update_token_counts(prompt_for_count, completion, completion_text)
return completion, mode_used
except openai.OpenAIError as e:
self._reset_token_counts() # Reset counts on error
attempts += 1
if attempts > self.retry_attempts:
logger.error(f"Async OpenAI API call failed after {attempts} attempts: {e}")
raise
logger.warning(
f"Async OpenAI API error (attempt {attempts}/{self.retry_attempts + 1}): {e}. Retrying..."
)
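                # Exponential backoff: retry_backoff * 2**(attempt - 1) seconds
                # (e.g. 1s, 2s, 4s with the defaults).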
await asyncio.sleep(self.retry_backoff * (2 ** (attempts - 1)))
except Exception as e:
self._reset_token_counts() # Reset counts on unexpected error
logger.error(f"Unexpected error during async OpenAI API call: {e}", exc_info=True)
raise
def generate(
self,
prompt: str,
temperature: Optional[float] = None,
max_tokens: Optional[int] = None,
stop: Optional[List[str]] = None,
**kwargs: Any,
) -> str:
"""Generates a single text completion using the configured OpenAI model."""
# _reset_token_counts is handled by _call_api
call_kwargs = {
"model": self.model,
"messages": [{"role": "user", "content": prompt}],
"temperature": temperature if temperature is not None else self.temperature,
"max_tokens": max_tokens or self.max_tokens,
"stop": stop or self.stop,
**kwargs,
}
resp, _ = self._call_api(is_json_mode=False, **call_kwargs)
choices = resp.choices or []
if not choices or not choices[0].message or choices[0].message.content is None:
logger.warning(
f"OpenAI response missing choices or content for prompt: {prompt[:100]}..."
)
# Reset counts if response is invalid, as _update might not have run correctly
self._reset_token_counts()
raise RuntimeError("OpenAI returned empty choices or content")
text = choices[0].message.content
# Note: _update_token_counts was already called in _call_api
return text.strip()
async def generate_async(
self,
prompt: str,
temperature: Optional[float] = None,
max_tokens: Optional[int] = None,
stop: Optional[List[str]] = None,
**kwargs: Any,
) -> str:
"""Asynchronously generates a single text completion using OpenAI."""
# _reset_token_counts is handled by _call_api_async
call_kwargs = {
"model": self.model,
"messages": [{"role": "user", "content": prompt}],
"temperature": temperature if temperature is not None else self.temperature,
"max_tokens": max_tokens or self.max_tokens,
"stop": stop or self.stop,
**kwargs,
}
resp, _ = await self._call_api_async(is_json_mode=False, **call_kwargs)
choices = resp.choices or []
if not choices or not choices[0].message or choices[0].message.content is None:
logger.warning(
f"Async OpenAI response missing choices or content for prompt: {prompt[:100]}..."
)
# Reset counts if response is invalid
self._reset_token_counts()
raise RuntimeError("Async OpenAI returned empty choices or content")
text = choices[0].message.content
# Note: _update_token_counts was already called in _call_api_async
return text.strip()
def _generate_json_internal(
self, prompt: str, response_model: Type[BaseModel], **kwargs: Any
) -> Tuple[str, Optional[str]]:
"""Internal method for OpenAI JSON generation."""
# _reset_token_counts is handled by _call_api
call_kwargs = {
"model": self.model,
"messages": [{"role": "user", "content": prompt}],
"max_tokens": kwargs.pop("max_tokens", self.max_tokens),
"temperature": kwargs.pop("temperature", 0.1), # Default lower temp for JSON
**kwargs,
}
resp, mode_used = self._call_api(
is_json_mode=True, response_schema=response_model, **call_kwargs
)
choices = resp.choices or []
if not choices or not choices[0].message or choices[0].message.content is None:
logger.warning(
f"OpenAI JSON response missing choices or content for prompt: {prompt[:100]}..."
)
# Reset counts if response is invalid
self._reset_token_counts()
raise RuntimeError("OpenAI returned empty choices or content for JSON request")
# Token counts updated within _call_api
return choices[0].message.content, mode_used
async def _generate_json_internal_async(
self, prompt: str, response_model: Type[BaseModel], **kwargs: Any
) -> Tuple[str, Optional[str]]:
"""Asynchronous internal method for OpenAI JSON generation."""
# _reset_token_counts is handled by _call_api_async
call_kwargs = {
"model": self.model,
"messages": [{"role": "user", "content": prompt}],
"max_tokens": kwargs.pop("max_tokens", self.max_tokens),
"temperature": kwargs.pop("temperature", 0.1), # Default lower temp for JSON
**kwargs,
}
resp, mode_used = await self._call_api_async(
is_json_mode=True, response_schema=response_model, **call_kwargs
)
choices = resp.choices or []
if not choices or not choices[0].message or choices[0].message.content is None:
logger.warning(
f"Async OpenAI JSON response missing choices or content for prompt: {prompt[:100]}..."
)
# Reset counts if response is invalid
self._reset_token_counts()
raise RuntimeError("Async OpenAI returned empty choices or content for JSON request")
# Token counts updated within _call_api_async
return choices[0].message.content, mode_used
def generate_stream(
self,
prompt: str,
temperature: Optional[float] = None,
max_tokens: Optional[int] = None,
stop: Optional[List[str]] = None,
**kwargs: Any,
) -> Iterator[str]:
"""Generates a stream of text chunks using the configured OpenAI model.
Note: Token counts are reset but not reliably updated during streaming.
Use `get_last...` methods after streaming for potential approximations
based on the prompt if needed, but completion counts will be unreliable.
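
        Where supported by the API, passing `stream_options={"include_usage": True}`
        via kwargs adds a final chunk whose `usage` field carries exact counts;
        this method does not consume it automatically.

        Example (illustrative usage sketch):

            for chunk in llm.generate_stream("Tell me a joke."):
                print(chunk, end="", flush=True)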
"""
self._reset_token_counts() # Reset counts for stream start
# Estimate prompt tokens before starting stream
try:
self._last_prompt_tokens = len(self.encoding.encode(prompt))
except Exception as e:
logger.warning(f"tiktoken encoding failed for stream prompt: {e}")
self._last_prompt_tokens = None
call_kwargs = {
"model": self.model,
"messages": [{"role": "user", "content": prompt}],
"temperature": temperature if temperature is not None else self.temperature,
"max_tokens": max_tokens or self.max_tokens,
"stop": stop or self.stop,
"stream": True,
**kwargs,
}
try:
resp_stream = self.client.chat.completions.create(**call_kwargs)
for chunk in resp_stream:
if chunk.choices:
delta = getattr(chunk.choices[0], "delta", None)
if delta and delta.content:
yield delta.content
            # Usage stats appear in the stream only if requested via
            # stream_options={"include_usage": True} (where supported); not handled here
except openai.OpenAIError as e:
logger.error(f"OpenAI stream API call failed: {e}")
self._reset_token_counts()
raise RuntimeError(f"OpenAI stream failed: {e}") from e
except Exception as e:
logger.error(f"Unexpected error during OpenAI stream call: {e}", exc_info=True)
self._reset_token_counts()
raise
async def generate_stream_async(
self,
prompt: str,
temperature: Optional[float] = None,
max_tokens: Optional[int] = None,
stop: Optional[List[str]] = None,
**kwargs: Any,
) -> AsyncIterator[str]:
"""Asynchronously generates a stream of text chunks using OpenAI.
Note: Token counts are reset but not reliably updated during streaming.
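
        Example (illustrative usage sketch, inside an async function):

            async for chunk in llm.generate_stream_async("Tell me a joke."):
                print(chunk, end="", flush=True)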
"""
self._reset_token_counts() # Reset counts for stream start
# Estimate prompt tokens before starting stream
try:
self._last_prompt_tokens = len(self.encoding.encode(prompt))
except Exception as e:
logger.warning(f"tiktoken encoding failed for async stream prompt: {e}")
self._last_prompt_tokens = None
call_kwargs = {
"model": self.model,
"messages": [{"role": "user", "content": prompt}],
"temperature": temperature if temperature is not None else self.temperature,
"max_tokens": max_tokens or self.max_tokens,
"stop": stop or self.stop,
"stream": True,
**kwargs,
}
try:
resp_stream = await self.async_client.chat.completions.create(**call_kwargs)
async for chunk in resp_stream:
if chunk.choices:
delta = getattr(chunk.choices[0], "delta", None)
if delta and delta.content:
yield delta.content
            # Usage stats appear in the stream only if requested via
            # stream_options={"include_usage": True} (where supported); not handled here
except openai.OpenAIError as e:
logger.error(f"Async OpenAI stream API call failed: {e}")
self._reset_token_counts()
raise RuntimeError(f"Async OpenAI stream failed: {e}") from e
except Exception as e:
logger.error(f"Unexpected error during async OpenAI stream call: {e}", exc_info=True)
self._reset_token_counts()
raise
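

# Example (illustrative sketch): generating schema-validated JSON. This assumes
# the BaseLLM base class exposes a public `generate_json(prompt, response_model)`
# wrapper around `_generate_json_internal`; adjust to the actual base-class API.
#
#     from pydantic import BaseModel
#
#     class City(BaseModel):
#         """A city and its country."""
#         name: str
#         country: str
#
#     llm = OpenAILLM(api_key="sk-...", model="gpt-4o")
#     city = llm.generate_json("Name one large European city.", response_model=City)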