diff --git a/router/src/infer.rs b/router/src/infer.rs index 703dacd4..5b39ec37 100644 --- a/router/src/infer.rs +++ b/router/src/infer.rs @@ -276,7 +276,7 @@ impl Infer { } /// Add a new request to the queue and return a stream of InferStreamResponse - #[instrument(skip(self))] + #[instrument(skip_all,fields(parameters = ? request.parameters))] pub(crate) async fn generate_stream( &self, request: GenerateRequest, @@ -400,7 +400,7 @@ impl Infer { } /// Add a new request to the queue and return a InferResponse - #[instrument(skip(self))] + #[instrument(skip_all,fields(parameters = ? request.parameters))] pub(crate) async fn generate( &self, request: GenerateRequest, @@ -488,7 +488,7 @@ impl Infer { } } - #[instrument(skip(self))] + #[instrument(skip_all,fields(parameters = ? request.parameters))] pub(crate) async fn embed(&self, request: EmbedRequest) -> Result { // Limit concurrent requests by acquiring a permit from the semaphore let _permit = self @@ -618,7 +618,7 @@ impl Infer { } } - #[instrument(skip(self))] + #[instrument(skip_all)] pub(crate) async fn classify( &self, request: ClassifyRequest, @@ -731,7 +731,7 @@ impl Infer { } } - #[instrument(skip(self))] + #[instrument(skip_all)] pub(crate) async fn classify_batch( &self, request: BatchClassifyRequest, @@ -861,7 +861,7 @@ impl Infer { } } - #[instrument(skip(self))] + #[instrument(skip_all,fields(parameters = ? request.parameters))] pub(crate) async fn embed_batch( &self, request: BatchEmbedRequest, @@ -996,7 +996,7 @@ impl Infer { /// Add best_of new requests to the queue and return a InferResponse of the sequence with /// the highest log probability per token - #[instrument(skip(self))] + #[instrument(skip_all,fields(parameters = ? request.parameters, best_of, prefix_caching))] pub(crate) async fn generate_best_of( &self, request: GenerateRequest, diff --git a/router/src/server.rs b/router/src/server.rs index 3eb24521..b64b1f22 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -80,7 +80,18 @@ example = json ! ({"error": "Input validation error"})), example = json ! ({"error": "Incomplete generation"})), ) )] -#[instrument(skip(infer, req))] +#[instrument( +skip_all, +fields( +parameters = ? req.0.parameters, +total_time, +validation_time, +queue_time, +inference_time, +time_per_token, +seed, +) +)] async fn compat_generate( default_return_full_text: Extension, infer: Extension, @@ -91,6 +102,10 @@ async fn compat_generate( req_headers: HeaderMap, req: Json, ) -> Result)> { + // Log some useful headers to the span. + let span = tracing::Span::current(); + trace_headers(req_headers, &span); + let mut req = req.0; // default return_full_text given the pipeline_tag @@ -147,7 +162,7 @@ example = json ! ({"error": "Input validation error"})), example = json ! ({"error": "Incomplete generation"})), ) )] -#[instrument(skip(infer, req))] +#[instrument(skip(infer, req, req_headers))] async fn completions_v1( default_return_full_text: Extension, infer: Extension, @@ -158,6 +173,8 @@ async fn completions_v1( req_headers: HeaderMap, req: Json, ) -> Result)> { + let span = tracing::Span::current(); + trace_headers(req_headers, &span); let mut req = req.0; if req.model == info.model_id.as_str() { // Allow user to specify the base model, but treat it as an empty adapter_id @@ -232,7 +249,7 @@ example = json ! ({"error": "Input validation error"})), example = json ! ({"error": "Incomplete generation"})), ) )] -#[instrument(skip(infer, req))] +#[instrument(skip(infer, req, req_headers))] async fn chat_completions_v1( default_return_full_text: Extension, infer: Extension, @@ -243,6 +260,8 @@ async fn chat_completions_v1( req_headers: HeaderMap, req: Json, ) -> Result)> { + let span = tracing::Span::current(); + trace_headers(req_headers, &span); let mut req = req.0; let model_id = info.model_id.clone(); if req.model == info.model_id.as_str() { @@ -632,6 +651,7 @@ async fn generate( mut req: Json, ) -> Result<(HeaderMap, Json), (StatusCode, Json)> { let span = tracing::Span::current(); + trace_headers(req_headers, &span); let start_time = Instant::now(); metrics::increment_counter!("lorax_request_count"); @@ -2058,3 +2078,15 @@ async fn tokenize( )) } } + +fn trace_headers(headers: HeaderMap, span: &tracing::Span) { + headers + .get("x-predibase-tenant") + .map(|value| span.record("x-predibase-tenant", value)); + headers + .get("user-agent") + .map(|value| span.record("user-agent", value)); + headers + .get("x-b3-traceid") + .map(|value| span.record("x-b3-traceid", value)); +}