diff --git a/batch.go b/batch.go new file mode 100644 index 0000000..b73e916 --- /dev/null +++ b/batch.go @@ -0,0 +1,288 @@ +package anthropic + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "net/url" + "time" +) + +type ResultType string + +const ( + ResultTypeSucceeded ResultType = "succeeded" + ResultTypeErrored ResultType = "errored" + ResultTypeCanceled ResultType = "canceled" + ResultTypeExpired ResultType = "expired" +) + +type BatchId string + +type BatchResponseType string + +const ( + BatchResponseTypeMessageBatch BatchResponseType = "message_batch" +) + +type ProcessingStatus string + +const ( + ProcessingStatusInProgress ProcessingStatus = "in_progress" + ProcessingStatusCanceling ProcessingStatus = "canceling" + ProcessingStatusEnded ProcessingStatus = "ended" +) + +// While in beta, batches may contain up to 10,000 requests and be up to 32 MB in total size. +type BatchRequest struct { + Requests []InnerRequests `json:"requests"` +} + +type InnerRequests struct { + CustomId string `json:"custom_id"` + Params MessagesRequest `json:"params"` +} + +// All times returned in RFC 3339 +type BatchResponse struct { + httpHeader + + BatchRespCore +} + +type BatchRespCore struct { + Id BatchId `json:"id"` + Type BatchResponseType `json:"type"` + ProcessingStatus ProcessingStatus `json:"processing_status"` + RequestCounts RequestCounts `json:"request_counts"` + EndedAt *time.Time `json:"ended_at"` + CreatedAt time.Time `json:"created_at"` + ExpiresAt time.Time `json:"expires_at"` + ArchivedAt *time.Time `json:"archived_at"` + CancelInitiatedAt *time.Time `json:"cancel_initiated_at"` + ResultsUrl *string `json:"results_url"` +} + +type RequestCounts struct { + Processing int `json:"processing"` + Succeeded int `json:"succeeded"` + Errored int `json:"errored"` + Canceled int `json:"canceled"` + Expired int `json:"expired"` +} + +func (c *Client) CreateBatch( + ctx context.Context, + request BatchRequest, +) (*BatchResponse, error) { + var setters []requestSetter + if len(c.config.BetaVersion) > 0 { + setters = append(setters, withBetaVersion(c.config.BetaVersion...)) + } + + urlSuffix := "/messages/batches" + req, err := c.newRequest(ctx, http.MethodPost, urlSuffix, request, setters...) + if err != nil { + return nil, err + } + + var response BatchResponse + err = c.sendRequest(req, &response) + + return &response, err +} + +func (c *Client) RetrieveBatch( + ctx context.Context, + batchId BatchId, +) (*BatchResponse, error) { + var setters []requestSetter + if len(c.config.BetaVersion) > 0 { + setters = append(setters, withBetaVersion(c.config.BetaVersion...)) + } + + urlSuffix := "/messages/batches/" + string(batchId) + req, err := c.newRequest(ctx, http.MethodGet, urlSuffix, nil, setters...) + if err != nil { + return nil, err + } + + var response BatchResponse + err = c.sendRequest(req, &response) + + return &response, err +} + +type BatchResultCore struct { + Type ResultType `json:"type"` + Result MessagesResponse `json:"message"` +} + +type BatchResult struct { + CustomId string `json:"custom_id"` + Result BatchResultCore `json:"result"` +} + +type RetrieveBatchResultsResponse struct { + httpHeader + + // Each line in the file is a JSON object containing the result of a + // single request in the Message Batch. Results are not guaranteed to + // be in the same order as requests. Use the custom_id field to match + // results to requests. + + Responses []BatchResult + RawResponse []byte +} + +func (c *Client) RetrieveBatchResults( + ctx context.Context, + batchId BatchId, +) (*RetrieveBatchResultsResponse, error) { + var setters []requestSetter + if len(c.config.BetaVersion) > 0 { + setters = append(setters, withBetaVersion(c.config.BetaVersion...)) + } + + // The documentation states that the URL should be obtained from the results_url field in the batch response. + // It clearly states that the URL should 'not be assumed'. However this seems to work fine. + urlSuffix := "/messages/batches/" + string(batchId) + "/results" + req, err := c.newRequest(ctx, http.MethodGet, urlSuffix, nil, setters...) + if err != nil { + return nil, err + } + + var response RetrieveBatchResultsResponse + + res, err := c.config.HTTPClient.Do(req) + if err != nil { + return nil, err + } + defer res.Body.Close() + + response.SetHeader(res.Header) + + if err := c.handlerRequestError(res); err != nil { + return nil, err + } + + response.RawResponse, err = io.ReadAll(res.Body) + if err != nil { + return nil, err + } + + response.Responses, err = decodeRawResponse(response.RawResponse) + if err != nil { + return nil, err + } + + return &response, err +} + +func decodeRawResponse(rawResponse []byte) ([]BatchResult, error) { + // This looks fishy, but this logic works. + // https://goplay.tools/snippet/tDPm3GJVv0_s + var results []BatchResult + for _, line := range bytes.Split(rawResponse, []byte("\n")) { + if len(line) == 0 { + continue + } + + var parsed BatchResult + err := json.Unmarshal(line, &parsed) + if err != nil { + return nil, err + } + + results = append(results, parsed) + } + + return results, nil +} + +type ListBatchesResponse struct { + httpHeader + + Data []BatchRespCore `json:"data"` + HasMore bool `json:"has_more"` + FirstId *BatchId `json:"first_id"` + LastId *BatchId `json:"last_id"` +} + +type ListBatchesRequest struct { + BeforeId *string `json:"before_id,omitempty"` + AfterId *string `json:"after_id,omitempty"` + Limit *int `json:"limit,omitempty"` +} + +func (l ListBatchesRequest) validate() error { + if l.Limit != nil && (*l.Limit < 1 || *l.Limit > 100) { + return errors.New("limit must be between 1 and 100") + } + + return nil +} + +func (c *Client) ListBatches( + ctx context.Context, + lBatchReq ListBatchesRequest, +) (*ListBatchesResponse, error) { + var setters []requestSetter + if len(c.config.BetaVersion) > 0 { + setters = append(setters, withBetaVersion(c.config.BetaVersion...)) + } + + if err := lBatchReq.validate(); err != nil { + return nil, err + } + + urlSuffix := "/messages/batches" + + v := url.Values{} + if lBatchReq.BeforeId != nil { + v.Set("before_id", *lBatchReq.BeforeId) + } + if lBatchReq.AfterId != nil { + v.Set("after_id", *lBatchReq.AfterId) + } + if lBatchReq.Limit != nil { + v.Set("limit", fmt.Sprintf("%d", *lBatchReq.Limit)) + } + + // encode the query parameters into the URL + urlSuffix += "?" + v.Encode() + req, err := c.newRequest(ctx, http.MethodGet, urlSuffix, nil, setters...) + if err != nil { + return nil, err + } + + var response ListBatchesResponse + err = c.sendRequest(req, &response) + + return &response, err +} + +func (c *Client) CancelBatch( + ctx context.Context, + batchId BatchId, +) (*BatchResponse, error) { + var setters []requestSetter + if len(c.config.BetaVersion) > 0 { + setters = append(setters, withBetaVersion(c.config.BetaVersion...)) + } + + urlSuffix := "/messages/batches/" + string(batchId) + "/cancel" + req, err := c.newRequest(ctx, http.MethodPost, urlSuffix, nil, setters...) + if err != nil { + return nil, err + } + + var response BatchResponse + err = c.sendRequest(req, &response) + + return &response, err +} diff --git a/batch_test.go b/batch_test.go new file mode 100644 index 0000000..f634933 --- /dev/null +++ b/batch_test.go @@ -0,0 +1,435 @@ +package anthropic_test + +import ( + "context" + "encoding/json" + "net/http" + "strconv" + "strings" + "testing" + "time" + + "github.com/liushuangls/go-anthropic/v2" + "github.com/liushuangls/go-anthropic/v2/internal/test" +) + +func TestCreateBatch(t *testing.T) { + server := test.NewTestServer() + server.RegisterHandler("/v1/messages/batches", handleCreateBatchEndpoint) + + ts := server.AnthropicTestServer() + ts.Start() + defer ts.Close() + + baseUrl := ts.URL + "/v1" + client := anthropic.NewClient( + test.GetTestToken(), + anthropic.WithBaseURL(baseUrl), + anthropic.WithBetaVersion(anthropic.BetaMessageBatches20240924), + ) + + t.Run("create batch success", func(t *testing.T) { + resp, err := client.CreateBatch(context.Background(), anthropic.BatchRequest{ + Requests: []anthropic.InnerRequests{ + { + CustomId: "custom-identifier-not-real-this-is-a-test", + Params: anthropic.MessagesRequest{ + Model: anthropic.ModelClaude3Haiku20240307, + MultiSystem: anthropic.NewMultiSystemMessages( + "you are an assistant", + "you are snarky", + ), + MaxTokens: 10, + Messages: []anthropic.Message{ + anthropic.NewUserTextMessage("What is your name?"), + anthropic.NewAssistantTextMessage("My name is Claude."), + anthropic.NewUserTextMessage("What is your favorite color?"), + }, + }, + }, + }, + }) + if err != nil { + t.Fatalf("CreateBatch error: %s", err) + } + t.Logf("Create Batch resp: %+v", resp) + }) + + t.Run("fails with missing beta version header", func(t *testing.T) { + clientWithoutBeta := anthropic.NewClient( + test.GetTestToken(), + anthropic.WithBaseURL(baseUrl), + anthropic.WithBetaVersion(anthropic.BetaMessageBatches20240924), + ) + _, err := clientWithoutBeta.CreateBatch(context.Background(), anthropic.BatchRequest{}) + if err == nil { + t.Fatalf("CreateBatch expected error, got nil") + } + }) + +} + +func handleCreateBatchEndpoint(w http.ResponseWriter, r *http.Request) { + var err error + var resBytes []byte + + // Creating batches only accepts POST requests + if r.Method != "POST" { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + } + + var completeReq anthropic.BatchRequest + if completeReq, err = getRequest[anthropic.BatchRequest](r); err != nil { + http.Error(w, "could not read request", http.StatusInternalServerError) + return + } + + betaHeaders := r.Header.Get("Anthropic-Beta") + if !strings.Contains(betaHeaders, string(anthropic.BetaMessageBatches20240924)) { + http.Error(w, "missing beta version header", http.StatusBadRequest) + return + } + + custId := completeReq.Requests[0].CustomId + if custId == "" { + // I think this should be a bad request. TODO check docs + http.Error(w, "custom_id is required", http.StatusBadRequest) + return + } + + t1 := time.Now().Add(-time.Hour * 2) + + res := anthropic.BatchResponse{ + BatchRespCore: anthropic.BatchRespCore{ + Id: anthropic.BatchId( + "batch_id_" + strconv.FormatInt(time.Now().Unix(), 10), + ), + Type: anthropic.BatchResponseTypeMessageBatch, + ProcessingStatus: anthropic.ProcessingStatusInProgress, + RequestCounts: anthropic.RequestCounts{ + Processing: 1, + Succeeded: 2, + Canceled: 3, + Errored: 4, + Expired: 5, + }, + EndedAt: nil, + CreatedAt: t1, + ExpiresAt: t1.Add(time.Hour * 4), + ArchivedAt: nil, + CancelInitiatedAt: nil, + ResultsUrl: nil, + }, + } + resBytes, _ = json.Marshal(res) + _, _ = w.Write(resBytes) +} + +func TestRetrieveBatch(t *testing.T) { + server := test.NewTestServer() + server.RegisterHandler("/v1/messages/batches/batch_id_1234", handleRetrieveBatchEndpoint) + server.RegisterHandler("/v1/messages/batches/batch_id_not_found", handleRetrieveBatchEndpoint) + + ts := server.AnthropicTestServer() + ts.Start() + defer ts.Close() + + baseUrl := ts.URL + "/v1" + client := anthropic.NewClient( + test.GetTestToken(), + anthropic.WithBaseURL(baseUrl), + anthropic.WithBetaVersion(anthropic.BetaMessageBatches20240924), + ) + + t.Run("retrieve batch success", func(t *testing.T) { + resp, err := client.RetrieveBatch(context.Background(), "batch_id_1234") + if err != nil { + t.Fatalf("RetrieveBatch error: %s", err) + } + t.Logf("Retrieve Batch resp: %+v", resp) + }) + + t.Run("retrieve batch failure", func(t *testing.T) { + _, err := client.RetrieveBatch(context.Background(), "batch_id_not_found") + if err == nil { + t.Fatalf("RetrieveBatch expected error, got nil") + } + }) +} + +func handleRetrieveBatchEndpoint(w http.ResponseWriter, r *http.Request) { + var resBytes []byte + + // retrieving batches only accepts GET requests + if r.Method != "GET" { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + } + + batchId := strings.TrimPrefix(r.URL.Path, "/v1/messages/batches/") + if batchId == "" { + http.Error(w, "missing batch id", http.StatusBadRequest) + return + } + + if batchId == "batch_id_not_found" { + http.Error(w, "batch not found", http.StatusNotFound) + return + } + + res := anthropic.BatchResponse{ + BatchRespCore: forgeBatchResponse(batchId), + } + resBytes, _ = json.Marshal(res) + _, _ = w.Write(resBytes) +} + +func TestRetrieveBatchResults(t *testing.T) { + server := test.NewTestServer() + server.RegisterHandler( + "/v1/messages/batches/batch_id_1234/results", + handleRetrieveBatchResultsEndpoint, + ) + server.RegisterHandler( + "/v1/messages/batches/batch_id_not_found/results", + handleRetrieveBatchResultsEndpoint, + ) + + ts := server.AnthropicTestServer() + ts.Start() + defer ts.Close() + + baseUrl := ts.URL + "/v1" + client := anthropic.NewClient( + test.GetTestToken(), + anthropic.WithBaseURL(baseUrl), + anthropic.WithBetaVersion(anthropic.BetaMessageBatches20240924), + ) + + t.Run("retrieve batch results success", func(t *testing.T) { + resp, err := client.RetrieveBatchResults(context.Background(), "batch_id_1234") + if err != nil { + t.Fatalf("RetrieveBatchResults error: %s", err) + } + t.Logf("Retrieve Batch Results resp: %+v", resp) + }) + + t.Run("retrieve batch results failure", func(t *testing.T) { + _, err := client.RetrieveBatchResults(context.Background(), "batch_id_not_found") + if err == nil { + t.Fatalf("RetrieveBatchResults expected error, got nil") + } + }) +} + +func handleRetrieveBatchResultsEndpoint(w http.ResponseWriter, r *http.Request) { + var resBytes []byte + + // retrieving batch results only accepts GET requests + if r.Method != "GET" { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + } + + batchId := strings.TrimPrefix(r.URL.Path, "/v1/messages/batches/") + batchId = strings.TrimSuffix(batchId, "/results") + if batchId == "" { + http.Error(w, "missing batch id", http.StatusBadRequest) + return + } + + if batchId == "batch_id_not_found" { + http.Error(w, "batch not found", http.StatusNotFound) + return + } + + res := anthropic.RetrieveBatchResultsResponse{ + Responses: []anthropic.BatchResult{ + { + CustomId: "custom_id_1234", + Result: forgeBatchResult("batch_id_1234"), + }, + }, + } + resBytes, _ = json.Marshal(res) + _, _ = w.Write(resBytes) +} + +func TestListBatches(t *testing.T) { + server := test.NewTestServer() + server.RegisterHandler("/v1/messages/batches", handleListBatchesEndpoint) + + ts := server.AnthropicTestServer() + ts.Start() + defer ts.Close() + + baseUrl := ts.URL + "/v1" + client := anthropic.NewClient( + test.GetTestToken(), + anthropic.WithBaseURL(baseUrl), + anthropic.WithBetaVersion(anthropic.BetaMessageBatches20240924), + ) + + t.Run("list batches success", func(t *testing.T) { + resp, err := client.ListBatches(context.Background(), anthropic.ListBatchesRequest{ + Limit: toPtr(10), + BeforeId: nil, + AfterId: nil, + }) + if err != nil { + t.Fatalf("ListBatches error: %s", err) + } + t.Logf("List Batches resp: %+v", resp) + }) + + t.Run("list failure: limit too high", func(t *testing.T) { + _, err := client.ListBatches(context.Background(), anthropic.ListBatchesRequest{ + Limit: toPtr(101), + BeforeId: nil, + AfterId: nil, + }) + if err == nil { + t.Fatalf("ListBatches expected error, got nil") + } + }) + + t.Run("list batches with before_id and after_id", func(t *testing.T) { + _, err := client.ListBatches(context.Background(), anthropic.ListBatchesRequest{ + Limit: toPtr(10), + BeforeId: toPtr("batch_id_1234"), + AfterId: toPtr("batch_id_567"), + }) + if err != nil { + t.Fatalf("ListBatches error: %s", err) + } + }) +} + +func handleListBatchesEndpoint(w http.ResponseWriter, r *http.Request) { + var resBytes []byte + + // listing batches only accepts GET requests + if r.Method != "GET" { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + } + + res := anthropic.ListBatchesResponse{ + Data: []anthropic.BatchRespCore{ + forgeBatchResponse("batch_id_1234"), + forgeBatchResponse("batch_id_567"), + }, + HasMore: false, + FirstId: nil, + LastId: nil, + } + + resBytes, _ = json.Marshal(res) + _, _ = w.Write(resBytes) +} + +func TestCancelBatch(t *testing.T) { + server := test.NewTestServer() + server.RegisterHandler("/v1/messages/batches/batch_id_1234/cancel", handleCancelBatchEndpoint) + server.RegisterHandler( + "/v1/messages/batches/batch_id_not_found/cancel", + handleCancelBatchEndpoint, + ) + + ts := server.AnthropicTestServer() + ts.Start() + defer ts.Close() + + baseUrl := ts.URL + "/v1" + client := anthropic.NewClient( + test.GetTestToken(), + anthropic.WithBaseURL(baseUrl), + anthropic.WithBetaVersion(anthropic.BetaMessageBatches20240924), + ) + + t.Run("cancel batch success", func(t *testing.T) { + resp, err := client.CancelBatch(context.Background(), "batch_id_1234") + if err != nil { + t.Fatalf("CancelBatch error: %s", err) + } + t.Logf("Cancel Batch resp: %+v", resp) + }) + + t.Run("cancel batch failure", func(t *testing.T) { + resp, err := client.CancelBatch(context.Background(), "batch_id_not_found") + if err == nil { + t.Fatalf("CancelBatch expected error, got nil") + } + t.Logf("Cancel Batch resp: %+v", resp) + }) +} + +func handleCancelBatchEndpoint(w http.ResponseWriter, r *http.Request) { + var resBytes []byte + + // canceling batches only accepts POST requests + if r.Method != "POST" { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + } + + batchId := strings.TrimPrefix(r.URL.Path, "/v1/messages/batches/") + batchId = strings.TrimSuffix(batchId, "/cancel") + if batchId == "" { + http.Error(w, "missing batch id", http.StatusBadRequest) + return + } + + if batchId == "batch_id_not_found" { + http.Error(w, "batch not found", http.StatusNotFound) + return + } + + res := anthropic.BatchResponse{ + BatchRespCore: forgeBatchResponse(batchId), + } + resBytes, _ = json.Marshal(res) + _, _ = w.Write(resBytes) +} + +func forgeBatchResult(customId string) anthropic.BatchResultCore { + return anthropic.BatchResultCore{ + Type: anthropic.ResultTypeSucceeded, + Result: anthropic.MessagesResponse{ + ID: customId, + Type: anthropic.MessagesResponseTypeMessage, + Role: anthropic.RoleAssistant, + Content: []anthropic.MessageContent{ + { + Type: anthropic.MessagesContentTypeText, + Text: toPtr("My name is Claude."), + }, + }, + Model: anthropic.ModelClaude3Haiku20240307, + StopReason: anthropic.MessagesStopReasonEndTurn, + StopSequence: "", + Usage: anthropic.MessagesUsage{ + InputTokens: 10, + OutputTokens: 10, + }, + }, + } +} + +func forgeBatchResponse(batchId string) anthropic.BatchRespCore { + t1 := time.Now().Add(-time.Hour * 2) + return anthropic.BatchRespCore{ + Id: anthropic.BatchId(batchId), + Type: anthropic.BatchResponseTypeMessageBatch, + ProcessingStatus: anthropic.ProcessingStatusInProgress, + RequestCounts: anthropic.RequestCounts{ + Processing: 1, + Succeeded: 2, + Canceled: 3, + Errored: 4, + Expired: 5, + }, + EndedAt: nil, + CreatedAt: t1, + ExpiresAt: t1.Add(time.Hour * 4), + ArchivedAt: nil, + CancelInitiatedAt: nil, + ResultsUrl: nil, + } +} diff --git a/common_test.go b/common_test.go new file mode 100644 index 0000000..c8ab1af --- /dev/null +++ b/common_test.go @@ -0,0 +1,24 @@ +package anthropic_test + +import ( + "encoding/json" + "io" + "net/http" +) + +func toPtr[T any](s T) *T { + return &s +} + +func getRequest[T any](r *http.Request) (req T, err error) { + reqBody, err := io.ReadAll(r.Body) + if err != nil { + return + } + + err = json.Unmarshal(reqBody, &req) + if err != nil { + return + } + return +} diff --git a/complete_stream_test.go b/complete_stream_test.go index cd1a6c6..c39a2f1 100644 --- a/complete_stream_test.go +++ b/complete_stream_test.go @@ -107,7 +107,7 @@ func TestCompleteStreamError(t *testing.T) { } func handlerCompleteStream(w http.ResponseWriter, r *http.Request) { - request, err := getCompleteRequest(r) + request, err := getRequest[anthropic.CompleteStreamRequest](r) if err != nil { http.Error(w, "request error", http.StatusBadRequest) return diff --git a/complete_test.go b/complete_test.go index 3b91cd4..b87b3fe 100644 --- a/complete_test.go +++ b/complete_test.go @@ -3,7 +3,6 @@ package anthropic_test import ( "context" "encoding/json" - "io" "net/http" "strconv" "testing" @@ -105,7 +104,7 @@ func handleCompleteEndpoint(w http.ResponseWriter, r *http.Request) { } var completeReq anthropic.CompleteRequest - if completeReq, err = getCompleteRequest(r); err != nil { + if completeReq, err = getRequest[anthropic.CompleteRequest](r); err != nil { http.Error(w, "could not read request", http.StatusInternalServerError) return } @@ -119,16 +118,3 @@ func handleCompleteEndpoint(w http.ResponseWriter, r *http.Request) { resBytes, _ = json.Marshal(res) _, _ = w.Write(resBytes) } - -func getCompleteRequest(r *http.Request) (req anthropic.CompleteRequest, err error) { - reqBody, err := io.ReadAll(r.Body) - if err != nil { - return - } - - err = json.Unmarshal(reqBody, &req) - if err != nil { - return - } - return -} diff --git a/config.go b/config.go index c0f7742..ea54dd7 100644 --- a/config.go +++ b/config.go @@ -18,9 +18,10 @@ const ( type BetaVersion string const ( - BetaTools20240404 BetaVersion = "tools-2024-04-04" - BetaTools20240516 BetaVersion = "tools-2024-05-16" - BetaPromptCaching20240731 BetaVersion = "prompt-caching-2024-07-31" + BetaTools20240404 BetaVersion = "tools-2024-04-04" + BetaTools20240516 BetaVersion = "tools-2024-05-16" + BetaPromptCaching20240731 BetaVersion = "prompt-caching-2024-07-31" + BetaMessageBatches20240924 BetaVersion = "message-batches-2024-09-24" BetaMaxTokens35Sonnet20240715 BetaVersion = "max-tokens-3-5-sonnet-2024-07-15" ) diff --git a/integrationtest/batch_test.go b/integrationtest/batch_test.go new file mode 100644 index 0000000..5334d43 --- /dev/null +++ b/integrationtest/batch_test.go @@ -0,0 +1,125 @@ +package integrationtest + +import ( + "context" + "math/rand" + "testing" + + "github.com/liushuangls/go-anthropic/v2" +) + +func randomString(l int) string { + const charset = "1234567890abcdef" + b := make([]byte, l) + for i := range b { + b[i] = charset[rand.Intn(len(charset))] + } + return string(b) +} + +func TestIntegrationBatch(t *testing.T) { + testAPIKey(t) + ctx := context.Background() + + myId := "rand_id_" + randomString(5) + createBatchRequest := anthropic.BatchRequest{ + Requests: []anthropic.InnerRequests{ + { + CustomId: myId, + Params: anthropic.MessagesRequest{ + Model: anthropic.ModelClaude3Haiku20240307, + MultiSystem: anthropic.NewMultiSystemMessages( + "you are an assistant", + "you are snarky", + ), + MaxTokens: 10, + Messages: []anthropic.Message{ + anthropic.NewUserTextMessage("What is your name?"), + anthropic.NewAssistantTextMessage("My name is Claude."), + anthropic.NewUserTextMessage("What is your favorite color?"), + }, + }, + }, + }, + } + + betaOpts := anthropic.WithBetaVersion( + anthropic.BetaTools20240404, + anthropic.BetaMaxTokens35Sonnet20240715, + anthropic.BetaMessageBatches20240924, + ) + client := anthropic.NewClient(APIKey, betaOpts) + + // this will be set by the CreateBatch call below, and used in later tests + var batchID anthropic.BatchId + + t.Run("CreateBatch on real API", func(t *testing.T) { + resp, err := client.CreateBatch(ctx, createBatchRequest) + if err != nil { + t.Fatalf("CreateBatch error: %s", err) + } + t.Logf("CreateBatch resp: %+v", resp) + + // Save batchID for later tests + batchID = resp.Id + }) + + t.Run("RetrieveBatch on real API", func(t *testing.T) { + resp, err := client.RetrieveBatch(ctx, batchID) + if err != nil { + t.Fatalf("RetrieveBatch error: %s", err) + } + t.Logf("RetrieveBatch resp: %+v", resp) + }) + + var completedBatch *anthropic.BatchId + t.Run("ListBatches on real API", func(t *testing.T) { + req := anthropic.ListBatchesRequest{ + Limit: toPtr(99), + } + resp, err := client.ListBatches(ctx, req) + if err != nil { + t.Fatalf("ListBatches error: %s", err) + } + t.Logf("ListBatches resp: %+v", resp) + + for _, batch := range resp.Data { + if batch.ProcessingStatus == "ended" && batch.CancelInitiatedAt == nil { + completedBatch = &batch.Id + break + } + } + }) + + if completedBatch == nil { + // We probably need a better way to test this, but for now we'll skip if there's no completed batch + t.Skip("No completed batch to test RetrieveBatchResults") + } else { + // This test should be run after the first batch has completed. + // You should have a completed batch in your account to run this test! + // You can have a batch you run to completion by commenting out the CancelBatch call below. + t.Run("RetrieveBatchResults on real API", func(t *testing.T) { + resp, err := client.RetrieveBatchResults(ctx, *completedBatch) + if err != nil { + t.Fatalf("RetrieveBatchResults error: %s", err) + } + t.Logf("RetrieveBatchResults resp: %+v", resp) + + if len(resp.Responses) == 0 { + t.Fatalf("RetrieveBatchResults returned no responses") + } + + if resp.Responses[0].CustomId == "" { + t.Fatalf("RetrieveBatchResults returned a response with no CustomId. Parse error?") + } + }) + } + + t.Run("CancelBatch on real API", func(t *testing.T) { + resp, err := client.CancelBatch(ctx, batchID) + if err != nil { + t.Fatalf("CancelBatch error: %s", err) + } + t.Logf("CancelBatch resp: %+v", resp) + }) +} diff --git a/integrationtest/common_test.go b/integrationtest/common_test.go new file mode 100644 index 0000000..c6da12e --- /dev/null +++ b/integrationtest/common_test.go @@ -0,0 +1,5 @@ +package integrationtest + +func toPtr[T any](s T) *T { + return &s +} diff --git a/internal/test/server.go b/internal/test/server.go index 3eebc3d..c0d5034 100644 --- a/internal/test/server.go +++ b/internal/test/server.go @@ -40,10 +40,12 @@ func (ts *ServerTest) AnthropicTestServer() *httptest.Server { handlerCall, ok := ts.handlers[r.URL.Path] if !ok { + log.Printf("path %q not found\n", r.URL.Path) http.Error(w, "the resource path doesn't exist", http.StatusNotFound) return } handlerCall(w, r) + log.Printf("request handled successfully\n") }), ) } diff --git a/message_stream_test.go b/message_stream_test.go index 730e403..f66c247 100644 --- a/message_stream_test.go +++ b/message_stream_test.go @@ -289,7 +289,7 @@ func TestMessagesStreamToolUse(t *testing.T) { } func handlerMessagesStream(w http.ResponseWriter, r *http.Request) { - request, err := getMessagesRequest(r) + request, err := getRequest[anthropic.MessagesRequest](r) if err != nil { http.Error(w, "request error", http.StatusBadRequest) return @@ -362,7 +362,7 @@ func handlerMessagesStream(w http.ResponseWriter, r *http.Request) { } func handlerMessagesStreamToolUse(w http.ResponseWriter, r *http.Request) { - messagesReq, err := getMessagesRequest(r) + messagesReq, err := getRequest[anthropic.MessagesRequest](r) if err != nil { http.Error(w, "request error", http.StatusBadRequest) return @@ -440,7 +440,7 @@ func handlerMessagesStreamToolUse(w http.ResponseWriter, r *http.Request) { func handlerMessagesStreamEmptyMessages(numEmptyMessages int, payload string) test.Handler { return func(w http.ResponseWriter, r *http.Request) { - _, err := getMessagesRequest(r) + _, err := getRequest[anthropic.MessagesRequest](r) if err != nil { http.Error(w, "request error", http.StatusBadRequest) return diff --git a/message_test.go b/message_test.go index 42a3029..6274c7b 100644 --- a/message_test.go +++ b/message_test.go @@ -789,7 +789,7 @@ func handleMessagesEndpoint(headers map[string]string) func(http.ResponseWriter, } var messagesReq anthropic.MessagesRequest - if messagesReq, err = getMessagesRequest(r); err != nil { + if messagesReq, err = getRequest[anthropic.MessagesRequest](r); err != nil { http.Error(w, "could not read request", http.StatusInternalServerError) return } @@ -848,20 +848,3 @@ func handleMessagesEndpoint(headers map[string]string) func(http.ResponseWriter, _, _ = w.Write(resBytes) } } - -func getMessagesRequest(r *http.Request) (req anthropic.MessagesRequest, err error) { - reqBody, err := io.ReadAll(r.Body) - if err != nil { - return - } - - err = json.Unmarshal(reqBody, &req) - if err != nil { - return - } - return -} - -func toPtr(s string) *string { - return &s -}