import inspect
import sys
from functools import wraps
from typing import TYPE_CHECKING
import sentry_sdk.utils
from sentry_sdk import start_span
from sentry_sdk.ai.utils import _set_span_data_attribute
from sentry_sdk.consts import SPANDATA
from sentry_sdk.traces import StreamedSpan
from sentry_sdk.tracing import Span
from sentry_sdk.utils import ContextVar, capture_internal_exceptions, reraise
if TYPE_CHECKING:
from typing import Any, Awaitable, Callable, Optional, TypeVar, Union
F = TypeVar("F", bound=Union[Callable[..., Any], Callable[..., Awaitable[Any]]])
# Pipeline name for the current execution context. Starts out as ``None``;
# written by ``set_ai_pipeline_name`` and by the ``ai_track`` decorator.
_ai_pipeline_name = ContextVar("ai_pipeline_name", default=None)


def set_ai_pipeline_name(name: "Optional[str]") -> None:
    """Record *name* as the active AI pipeline name (``None`` clears it)."""
    pipeline_var = _ai_pipeline_name
    pipeline_var.set(name)


def get_ai_pipeline_name() -> "Optional[str]":
    """Return the active AI pipeline name, or ``None`` if none is set."""
    current_name = _ai_pipeline_name.get()
    return current_name
def _consume_sentry_call_kwargs(span: "Any", kwargs: "Any") -> None:
    """Pop ``sentry_tags`` / ``sentry_data`` from *kwargs* onto *span*.

    The keys are removed from *kwargs*, so they are never forwarded to the
    decorated function. A value of ``None`` is treated like an empty mapping.
    """
    for key, value in (kwargs.pop("sentry_tags", None) or {}).items():
        span.set_tag(key, value)
    for key, value in (kwargs.pop("sentry_data", None) or {}).items():
        span.set_data(key, value)


def _capture_ai_exception(exc: "Exception") -> None:
    """Report *exc* to Sentry as an unhandled ``ai_monitoring`` error.

    Runs under ``capture_internal_exceptions`` so that an SDK-side failure
    while building or sending the event is contained here instead of
    replacing the user's exception.
    """
    with capture_internal_exceptions():
        event, hint = sentry_sdk.utils.event_from_exception(
            exc,
            client_options=sentry_sdk.get_client().options,
            mechanism={"type": "ai_monitoring", "handled": False},
        )
        sentry_sdk.capture_event(event, hint=hint)


def ai_track(description: str, **span_kwargs: "Any") -> "Callable[[F], F]":
    """Decorate a sync or async function so each call is recorded as a span.

    Every call runs inside ``start_span(name=description, op=..., ...)``.

    * When no AI pipeline name is set, the call acts as the pipeline:
      ``op`` defaults to ``"ai.pipeline"``, the pipeline name is set to
      *description* for the duration of the call (and cleared afterwards),
      and any ``Exception`` raised is reported to Sentry (mechanism
      ``ai_monitoring``, ``handled=False``) and then re-raised.
    * When a pipeline name is already set, ``op`` defaults to ``"ai.run"``
      and the pipeline name is stored on the span as
      ``SPANDATA.GEN_AI_PIPELINE_NAME``. Exceptions propagate without being
      captured at this level.

    Callers of the decorated function may pass ``sentry_tags`` and
    ``sentry_data`` keyword arguments. These are applied to the span and
    are not forwarded to the wrapped function.

    :param description: Span name. For outermost calls it is also the
        pipeline name.
    :param span_kwargs: Extra keyword arguments for ``start_span``. An
        ``op`` entry overrides the default op on every call.
    :returns: A decorator that preserves the wrapped function's metadata.
    """
    # Strip "op" once, up front. The op is resolved per call, and
    # ``span_kwargs`` must never be mutated: popping from it would drop a
    # user-supplied op after the first invocation.
    start_span_kwargs = {k: v for k, v in span_kwargs.items() if k != "op"}

    def _resolve_op(curr_pipeline: "Optional[str]") -> "Any":
        # An explicit ``op`` wins; otherwise the op depends on pipeline depth.
        default_op = "ai.run" if curr_pipeline else "ai.pipeline"
        return span_kwargs.get("op", default_op)

    def decorator(f: "F") -> "F":
        def sync_wrapped(*args: "Any", **kwargs: "Any") -> "Any":
            curr_pipeline = _ai_pipeline_name.get()
            op = _resolve_op(curr_pipeline)
            with start_span(name=description, op=op, **start_span_kwargs) as span:
                _consume_sentry_call_kwargs(span, kwargs)
                if curr_pipeline:
                    span.set_data(SPANDATA.GEN_AI_PIPELINE_NAME, curr_pipeline)
                    return f(*args, **kwargs)
                # Outermost call: publish this pipeline's name to nested calls.
                _ai_pipeline_name.set(description)
                try:
                    res = f(*args, **kwargs)
                except Exception as e:
                    exc_info = sys.exc_info()
                    _capture_ai_exception(e)
                    reraise(*exc_info)
                finally:
                    _ai_pipeline_name.set(None)
                return res

        async def async_wrapped(*args: "Any", **kwargs: "Any") -> "Any":
            curr_pipeline = _ai_pipeline_name.get()
            op = _resolve_op(curr_pipeline)
            with start_span(name=description, op=op, **start_span_kwargs) as span:
                _consume_sentry_call_kwargs(span, kwargs)
                if curr_pipeline:
                    span.set_data(SPANDATA.GEN_AI_PIPELINE_NAME, curr_pipeline)
                    return await f(*args, **kwargs)
                # Outermost call: publish this pipeline's name to nested calls.
                _ai_pipeline_name.set(description)
                try:
                    res = await f(*args, **kwargs)
                except Exception as e:
                    exc_info = sys.exc_info()
                    _capture_ai_exception(e)
                    reraise(*exc_info)
                finally:
                    _ai_pipeline_name.set(None)
                return res

        if inspect.iscoroutinefunction(f):
            return wraps(f)(async_wrapped)  # type: ignore
        else:
            return wraps(f)(sync_wrapped)  # type: ignore

    return decorator
def record_token_usage(
    span: "Union[Span, StreamedSpan]",
    input_tokens: "Optional[int]" = None,
    input_tokens_cached: "Optional[int]" = None,
    input_tokens_cache_write: "Optional[int]" = None,
    output_tokens: "Optional[int]" = None,
    output_tokens_reasoning: "Optional[int]" = None,
    total_tokens: "Optional[int]" = None,
) -> None:
    """Write token-usage counters, plus the active pipeline name, onto *span*.

    Counters passed as ``None`` are skipped. If *total_tokens* is omitted but
    both *input_tokens* and *output_tokens* are given, the total is derived
    as their sum. Values are written through ``_set_span_data_attribute``.
    """
    # TODO: move pipeline name elsewhere
    pipeline = get_ai_pipeline_name()
    if pipeline:
        _set_span_data_attribute(span, SPANDATA.GEN_AI_PIPELINE_NAME, pipeline)

    total = total_tokens
    if total is None and not (input_tokens is None or output_tokens is None):
        total = input_tokens + output_tokens

    # Ordered so attributes are written in the same sequence as before.
    counters = (
        (SPANDATA.GEN_AI_USAGE_INPUT_TOKENS, input_tokens),
        (SPANDATA.GEN_AI_USAGE_INPUT_TOKENS_CACHED, input_tokens_cached),
        (SPANDATA.GEN_AI_USAGE_INPUT_TOKENS_CACHE_WRITE, input_tokens_cache_write),
        (SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS, output_tokens),
        (SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS_REASONING, output_tokens_reasoning),
        (SPANDATA.GEN_AI_USAGE_TOTAL_TOKENS, total),
    )
    for attribute, count in counters:
        if count is None:
            continue
        _set_span_data_attribute(span, attribute, count)