-
Notifications
You must be signed in to change notification settings - Fork 14
/
Copy pathheaders_inspection_handler.go
120 lines (105 loc) · 4.07 KB
/
headers_inspection_handler.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
package nethttplibrary
import (
nethttp "net/http"
abstractions "github.com/microsoft/kiota-abstractions-go"
"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/trace"
)
// HeadersInspectionOptions is the options to use when inspecting headers.
// The Inspect* flags select which side of the HTTP exchange is captured, and
// the captured headers are added to RequestHeaders / ResponseHeaders so the
// caller can read them once the request has completed.
type HeadersInspectionOptions struct {
	// InspectRequestHeaders enables copying the outgoing request headers
	// into RequestHeaders.
	InspectRequestHeaders bool

	// InspectResponseHeaders enables copying the received response headers
	// into ResponseHeaders.
	InspectResponseHeaders bool

	// RequestHeaders receives the inspected request headers. It should be
	// non-nil when InspectRequestHeaders is set; NewHeadersInspectionOptions
	// allocates it.
	RequestHeaders *abstractions.RequestHeaders

	// ResponseHeaders receives the inspected response headers. It should be
	// non-nil when InspectResponseHeaders is set; NewHeadersInspectionOptions
	// allocates it.
	ResponseHeaders *abstractions.ResponseHeaders
}
// NewHeadersInspectionOptions returns options whose header collections are
// freshly allocated by abstractions.NewRequestHeaders and
// abstractions.NewResponseHeaders. Both inspection flags are left at their
// zero value (false), so nothing is captured until the caller enables them.
func NewHeadersInspectionOptions() *HeadersInspectionOptions {
	opts := new(HeadersInspectionOptions)
	opts.RequestHeaders = abstractions.NewRequestHeaders()
	opts.ResponseHeaders = abstractions.NewResponseHeaders()
	return opts
}
// headersInspectionOptionsInt is the request-option contract that Intercept
// looks up in the request context. *HeadersInspectionOptions implements it.
type headersInspectionOptionsInt interface {
	abstractions.RequestOption

	// Flags selecting which headers are captured.
	GetInspectRequestHeaders() bool
	GetInspectResponseHeaders() bool

	// Destinations for the captured headers.
	GetRequestHeaders() *abstractions.RequestHeaders
	GetResponseHeaders() *abstractions.ResponseHeaders
}
// headersInspectionKeyValue identifies HeadersInspectionOptions: GetKey
// returns it, and Intercept uses it to look the options up in the request
// context.
var headersInspectionKeyValue = abstractions.RequestOptionKey{Key: "nethttplibrary.HeadersInspectionOptions"}
// GetKey reports the option key (headersInspectionKeyValue) under which these
// options are registered.
func (o *HeadersInspectionOptions) GetKey() abstractions.RequestOptionKey {
	return headersInspectionKeyValue
}

// GetInspectRequestHeaders reports whether outgoing request headers are to be
// captured.
func (o *HeadersInspectionOptions) GetInspectRequestHeaders() bool {
	return o.InspectRequestHeaders
}

// GetInspectResponseHeaders reports whether received response headers are to
// be captured.
func (o *HeadersInspectionOptions) GetInspectResponseHeaders() bool {
	return o.InspectResponseHeaders
}

// GetRequestHeaders exposes the collection that captured request headers are
// added to.
func (o *HeadersInspectionOptions) GetRequestHeaders() *abstractions.RequestHeaders {
	return o.RequestHeaders
}

// GetResponseHeaders exposes the collection that captured response headers
// are added to.
func (o *HeadersInspectionOptions) GetResponseHeaders() *abstractions.ResponseHeaders {
	return o.ResponseHeaders
}
// HeadersInspectionHandler is a middleware that copies request and/or response
// headers into a HeadersInspectionOptions so that callers can inspect them.
type HeadersInspectionHandler struct {
	// options is consulted only when the request context carries no
	// HeadersInspectionOptions of its own.
	options HeadersInspectionOptions
}
// NewHeadersInspectionHandler builds a handler whose fallback options come
// from NewHeadersInspectionOptions, so neither inspection flag is enabled.
func NewHeadersInspectionHandler() *HeadersInspectionHandler {
	defaults := NewHeadersInspectionOptions()
	return NewHeadersInspectionHandlerWithOptions(*defaults)
}
// NewHeadersInspectionHandlerWithOptions builds a handler that falls back to
// the supplied options. The struct is stored by value; its RequestHeaders and
// ResponseHeaders pointers are therefore still shared with the caller's copy.
func NewHeadersInspectionHandlerWithOptions(options HeadersInspectionOptions) *HeadersInspectionHandler {
	handler := &HeadersInspectionHandler{}
	handler.options = options
	return handler
}
// Intercept runs this handler as a step of the middleware pipeline. It copies
// the outgoing request headers and/or the received response headers into the
// HeadersInspectionOptions found in the request context. When the context holds
// none, it falls back to the handler's own options.
//
// Headers are copied only when the matching Inspect flag is set and the target
// collection is non-nil. Header keys that carry no values are skipped. Errors
// from the rest of the pipeline are returned unchanged, and the response is
// returned as received.
func (middleware HeadersInspectionHandler) Intercept(pipeline Pipeline, middlewareIndex int, req *nethttp.Request) (*nethttp.Response, error) {
	obsOptions := GetObservabilityOptionsFromRequest(req)
	ctx := req.Context()
	var span trace.Span
	var observabilityName string
	if obsOptions != nil {
		observabilityName = obsOptions.GetTracerInstrumentationName()
		ctx, span = otel.GetTracerProvider().Tracer(observabilityName).Start(ctx, "HeadersInspectionHandler_Intercept")
		span.SetAttributes(attribute.Bool("com.microsoft.kiota.handler.headersInspection.enable", true))
		defer span.End()
		req = req.WithContext(ctx)
	}
	reqOption, ok := req.Context().Value(headersInspectionKeyValue).(headersInspectionOptionsInt)
	if !ok {
		// middleware is a value receiver, so this points at the copy's options;
		// the header collections are pointers and stay shared with the handler.
		reqOption = &middleware.options
	}
	if reqOption.GetInspectRequestHeaders() {
		if target := reqOption.GetRequestHeaders(); target != nil {
			for key, values := range req.Header {
				// net/http.Header may map a key to an empty slice; values[0]
				// would panic, and there is nothing to record anyway.
				if len(values) == 0 {
					continue
				}
				target.Add(key, values[0], values[1:]...)
			}
		}
	}
	response, err := pipeline.Next(req, middlewareIndex)
	if err != nil || response == nil {
		// Nothing to inspect; propagate exactly what the pipeline returned.
		return response, err
	}
	if reqOption.GetInspectResponseHeaders() {
		if target := reqOption.GetResponseHeaders(); target != nil {
			for key, values := range response.Header {
				if len(values) == 0 {
					continue
				}
				target.Add(key, values[0], values[1:]...)
			}
		}
	}
	return response, nil
}