diff --git a/servant-client/test/Servant/ClientTestUtils.hs b/servant-client/test/Servant/ClientTestUtils.hs index 4b70a7a9e..7f92fafbf 100644 --- a/servant-client/test/Servant/ClientTestUtils.hs +++ b/servant-client/test/Servant/ClientTestUtils.hs @@ -237,7 +237,7 @@ basicAuthHandler = if username == "servant" && password == "server" then return (Authorized ()) else return Unauthorized - in BasicAuthCheck check + in BasicAuthCheck True check basicServerContext :: Context '[ BasicAuthCheck () ] basicServerContext = basicAuthHandler :. EmptyContext diff --git a/servant-http-streams/test/Servant/ClientSpec.hs b/servant-http-streams/test/Servant/ClientSpec.hs index 41e7fbe44..c2a21fbe2 100644 --- a/servant-http-streams/test/Servant/ClientSpec.hs +++ b/servant-http-streams/test/Servant/ClientSpec.hs @@ -222,7 +222,7 @@ basicAuthHandler = if username == "servant" && password == "server" then return (Authorized ()) else return Unauthorized - in BasicAuthCheck check + in BasicAuthCheck True check basicServerContext :: Context '[ BasicAuthCheck () ] basicServerContext = basicAuthHandler :. EmptyContext diff --git a/servant-server/src/Servant/Server.hs b/servant-server/src/Servant/Server.hs index 5d40eb6f6..a38689d28 100644 --- a/servant-server/src/Servant/Server.hs +++ b/servant-server/src/Servant/Server.hs @@ -43,7 +43,7 @@ module Servant.Server , descendIntoNamedContext -- * Basic Authentication - , BasicAuthCheck(BasicAuthCheck, unBasicAuthCheck) + , BasicAuthCheck(BasicAuthCheck, basicAuthRunCheck, basicAuthPresentChallenge) , BasicAuthResult(..) -- * General Authentication diff --git a/servant-server/src/Servant/Server/Internal/BasicAuth.hs b/servant-server/src/Servant/Server/Internal/BasicAuth.hs index b92e4b02a..6c68d10a6 100644 --- a/servant-server/src/Servant/Server/Internal/BasicAuth.hs +++ b/servant-server/src/Servant/Server/Internal/BasicAuth.hs @@ -44,9 +44,12 @@ data BasicAuthResult usr deriving (Eq, Show, Read, Generic, Typeable, Functor) -- | Datatype wrapping a function used to check authentication. -newtype BasicAuthCheck usr = BasicAuthCheck - { unBasicAuthCheck :: BasicAuthData - -> IO (BasicAuthResult usr) +data BasicAuthCheck usr + = BasicAuthCheck + { basicAuthPresentChallenge :: Bool + -- ^ Decides if we'll send a @WWW-Authenticate@ HTTP header. Sending the header causes browser to + -- surface a prompt for user name and password, which may be undesirable for APIs. + , basicAuthRunCheck :: BasicAuthData -> IO (BasicAuthResult usr) } deriving (Generic, Typeable, Functor) @@ -68,7 +71,7 @@ decodeBAHdr req = do -- | Run and check basic authentication, returning the appropriate http error per -- the spec. runBasicAuth :: Request -> BS.ByteString -> BasicAuthCheck usr -> DelayedIO usr -runBasicAuth req realm (BasicAuthCheck ba) = +runBasicAuth req realm (BasicAuthCheck presentChallenge ba) = case decodeBAHdr req of Nothing -> plzAuthenticate Just e -> liftIO (ba e) >>= \res -> case res of @@ -76,4 +79,6 @@ runBasicAuth req realm (BasicAuthCheck ba) = NoSuchUser -> plzAuthenticate Unauthorized -> delayedFailFatal err403 Authorized usr -> return usr - where plzAuthenticate = delayedFailFatal err401 { errHeaders = [mkBAChallengerHdr realm] } + where + plzAuthenticate = + delayedFailFatal err401 { errHeaders = [mkBAChallengerHdr realm | presentChallenge] } diff --git a/servant-server/test/Servant/Server/ErrorSpec.hs b/servant-server/test/Servant/Server/ErrorSpec.hs index 72251b21c..b9a8f2bfa 100644 --- a/servant-server/test/Servant/Server/ErrorSpec.hs +++ b/servant-server/test/Servant/Server/ErrorSpec.hs @@ -44,7 +44,7 @@ errorOrderAuthCheck = if username == "servant" && password == "server" then return (Authorized ()) else return Unauthorized - in BasicAuthCheck check + in BasicAuthCheck True check ------------------------------------------------------------------------------ -- * Error Order {{{ diff --git a/servant-server/test/Servant/ServerSpec.hs b/servant-server/test/Servant/ServerSpec.hs index e3dec48e9..aa711ce52 100644 --- a/servant-server/test/Servant/ServerSpec.hs +++ b/servant-server/test/Servant/ServerSpec.hs @@ -63,7 +63,7 @@ import qualified Servant.Types.SourceT as S import Test.Hspec (Spec, context, describe, it, shouldBe, shouldContain) import Test.Hspec.Wai - (get, liftIO, matchHeaders, matchStatus, shouldRespondWith, + (get, liftIO, matchHeaders, MatchHeader(..), matchStatus, shouldRespondWith, with, (<:>)) import qualified Test.Hspec.Wai as THW @@ -742,9 +742,9 @@ basicAuthServer = const (return jerry) :<|> (Tagged $ \ _ sendResponse -> sendResponse $ responseLBS imATeapot418 [] "") -basicAuthContext :: Context '[ BasicAuthCheck () ] -basicAuthContext = - let basicHandler = BasicAuthCheck $ \(BasicAuthData usr pass) -> +basicAuthContext :: Bool -> Context '[ BasicAuthCheck () ] +basicAuthContext withRealm = + let basicHandler = BasicAuthCheck withRealm $ \(BasicAuthData usr pass) -> if usr == "servant" && pass == "server" then return (Authorized ()) else return Unauthorized @@ -753,7 +753,17 @@ basicAuthContext = basicAuthSpec :: Spec basicAuthSpec = do describe "Servant.API.BasicAuth" $ do - with (return (serveWithContext basicAuthApi basicAuthContext basicAuthServer)) $ do + with (return (serveWithContext basicAuthApi (basicAuthContext False) basicAuthServer)) $ do + context "Basic Authentication without realm" $ do + it "does not send WWW-Authenticate headers on 401" $ do + let noWWW = + MatchHeader $ \headers _ -> + if "WWW-Authenticate" `elem` map fst headers + then Just "WWW-Authenticate header is unexpected, " + else Nothing + get "/basic" `shouldRespondWith` "" {matchStatus = 401, matchHeaders = [noWWW]} + + with (return (serveWithContext basicAuthApi (basicAuthContext True) basicAuthServer)) $ do context "Basic Authentication" $ do let basicAuthHeaders user password = @@ -761,6 +771,9 @@ basicAuthSpec = do it "returns 401 when no credentials given" $ do get "/basic" `shouldRespondWith` 401 + it "returns 401 WWW-Authenticate headers" $ do + get "/basic" `shouldRespondWith` "" {matchStatus = 401, matchHeaders = ["WWW-Authenticate" <:> "Basic realm=\"foo\""]} + it "returns 403 when invalid credentials given" $ do THW.request methodGet "/basic" (basicAuthHeaders "servant" "wrong") "" `shouldRespondWith` 403