Skip to content

Commit

Permalink
Add WS data interception.
Browse files Browse the repository at this point in the history
  • Loading branch information
twitchax committed Sep 26, 2024
1 parent 123ca54 commit bb9d6d4
Show file tree
Hide file tree
Showing 5 changed files with 89 additions and 4 deletions.
14 changes: 10 additions & 4 deletions src/Core/Extensions/Ws.cs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,10 @@ internal static async Task ExecuteWsProxyOperationAsync(this HttpContext context
using var socketToClient = await context.WebSockets.AcceptWebSocketAsync(socketToEndpoint.SubProtocol).ConfigureAwait(false);

var bufferSize = options?.BufferSize ?? 4096;
await Task.WhenAll(PumpWebSocket(socketToEndpoint, socketToClient, bufferSize, context.RequestAborted), PumpWebSocket(socketToClient, socketToEndpoint, bufferSize, context.RequestAborted)).ConfigureAwait(false);
await Task.WhenAll(
PumpWebSocket(socketToEndpoint, socketToClient, WsProxyDataDirection.Downstream, wsProxy, bufferSize, context.RequestAborted),
PumpWebSocket(socketToClient, socketToEndpoint, WsProxyDataDirection.Upstream, wsProxy, bufferSize, context.RequestAborted)
).ConfigureAwait(false);
}
catch (Exception e)
{
Expand All @@ -69,7 +72,7 @@ internal static async Task ExecuteWsProxyOperationAsync(this HttpContext context
}
}

private static async Task PumpWebSocket(WebSocket source, WebSocket destination, int bufferSize, CancellationToken cancellationToken)
private static async Task PumpWebSocket(WebSocket source, WebSocket destination, WsProxyDataDirection direction, WsProxy wsProxy, int bufferSize, CancellationToken cancellationToken)
{
using var ms = new MemoryStream();
var receiveBuffer = WebSocket.CreateServerBuffer(bufferSize);
Expand Down Expand Up @@ -113,8 +116,11 @@ private static async Task PumpWebSocket(WebSocket source, WebSocket destination,

var sendBuffer = new ArraySegment<byte>(ms.GetBuffer(), 0, (int)ms.Length);

// TODO: Add handlers here to allow the developer to edit message before forwarding, and vice versa?
// Possibly in the future, if deemed useful.
// If the data intercept is set, then invoke it.
if(wsProxy.Options?.DataIntercept != null)
{
await wsProxy.Options.DataIntercept(sendBuffer, direction, result.MessageType).ConfigureAwait(false);
}

await destination.SendAsync(sendBuffer, result.MessageType, result.EndOfMessage, cancellationToken).ConfigureAwait(false);
}
Expand Down
14 changes: 14 additions & 0 deletions src/Core/Helpers/Helpers.cs
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,20 @@ public interface IBuilder<TInterface, TConcrete> where TConcrete : class
TConcrete Build();
}

/// <summary>
/// The direction of the data flow.
/// </summary>
public enum WsProxyDataDirection {
/// <summary>
/// The data is flowing from the client to the server.
/// </summary>
Upstream,
/// <summary>
/// The data is flowing from the server to the client.
/// </summary>
Downstream
}

internal static class Helpers
{
internal static readonly string HttpProxyClientName = "AspNetCore.Proxy.HttpProxyClient";
Expand Down
33 changes: 33 additions & 0 deletions src/Core/Options/WsProxyOptions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,14 @@ public interface IWsProxyOptionsBuilder : IBuilder<IWsProxyOptionsBuilder, WsPro
/// <returns>This instance.</returns>
IWsProxyOptionsBuilder WithIntercept(Func<HttpContext, ValueTask<bool>> intercept);

/// <summary>
/// A <see cref="Func{ArraySegment, WsProxyDataDirection, WebSocketMessageType, Task}"/> that is invoked upon a new message.
/// This allows for the data to be intercepted and modified before being forwarded.
/// </summary>
/// <param name="dataIntercept"></param>
/// <returns>This instance.</returns>
IWsProxyOptionsBuilder WithDataIntercept(Func<ArraySegment<byte>, WsProxyDataDirection, WebSocketMessageType, Task> dataIntercept);

/// <summary>
/// An <see cref="Func{HttpContext, ClientWebSocketOptions, Task}"/> that is invoked before the connect to the remote endpoint.
/// The <see cref="ClientWebSocketOptions"/> can be edited before the call.
Expand All @@ -49,6 +57,7 @@ public sealed class WsProxyOptionsBuilder : IWsProxyOptionsBuilder
{
private int _bufferSize = 4096;
private Func<HttpContext, ValueTask<bool>> _intercept;
private Func<ArraySegment<byte>, WsProxyDataDirection, WebSocketMessageType, Task> _dataIntercept;
private Func<HttpContext, ClientWebSocketOptions, Task> _beforeConnect;
private Func<HttpContext, Exception, Task> _handleFailure;

Expand All @@ -71,6 +80,7 @@ public IWsProxyOptionsBuilder New()
return Instance
.WithBufferSize(_bufferSize)
.WithIntercept(_intercept)
.WithDataIntercept(_dataIntercept)
.WithBeforeConnect(_beforeConnect)
.WithHandleFailure(_handleFailure);
}
Expand All @@ -81,6 +91,7 @@ public WsProxyOptions Build()
return new WsProxyOptions(
_bufferSize,
_intercept,
_dataIntercept,
_beforeConnect,
_handleFailure);
}
Expand Down Expand Up @@ -108,6 +119,18 @@ public IWsProxyOptionsBuilder WithIntercept(Func<HttpContext, ValueTask<bool>> i
return this;
}

/// <summary>
/// Sets the <see cref="Func{ArraySegment, WsProxyDataDirection, WebSocketMessageType, Task}"/> that is invoked upon a new data message.
/// </summary>
/// <param name="dataIntercept"></param>
/// <returns>The current instance with the specified option set.</returns>
public IWsProxyOptionsBuilder WithDataIntercept(Func<ArraySegment<byte>, WsProxyDataDirection, WebSocketMessageType, Task> dataIntercept)
{
_dataIntercept = dataIntercept;
return this;
}


/// <summary>
/// Sets the <see cref="Func{HttpContext, ClientWebSocketOptions, Task}"/> that is invoked upon a new connection.
/// The <see cref="ClientWebSocketOptions"/> can be edited before the response is written to the client.
Expand Down Expand Up @@ -155,6 +178,14 @@ public sealed class WsProxyOptions
/// </value>
public Func<HttpContext, ValueTask<bool>> Intercept { get; }

/// <summary>
/// DataIntercept property.
/// </summary>
/// <value>
/// A <see cref="Func{ArraySegment, WsProxyDataDirection, WebSocketMessageType, Task}"/> that is invoked upon a data call.
/// </value>
public Func<ArraySegment<byte>, WsProxyDataDirection, WebSocketMessageType, Task> DataIntercept { get; }

/// <summary>
/// BeforeConnect property.
/// </summary>
Expand All @@ -175,11 +206,13 @@ public sealed class WsProxyOptions
internal WsProxyOptions(
int bufferSize,
Func<HttpContext, ValueTask<bool>> intercept,
Func<ArraySegment<byte>, WsProxyDataDirection, WebSocketMessageType, Task> dataIntercept,
Func<HttpContext, ClientWebSocketOptions, Task> beforeConnect,
Func<HttpContext, Exception, Task> handleFailure)
{
BufferSize = bufferSize;
Intercept = intercept;
DataIntercept = dataIntercept;
BeforeConnect = beforeConnect;
HandleFailure = handleFailure;
}
Expand Down
6 changes: 6 additions & 0 deletions src/Test/Ws/WsHelpers.cs
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,12 @@ internal static Task RunWsServers(CancellationToken token)
.UseWs("ws://localhost:5002/", options => options
.WithBufferSize(8192)
.WithIntercept(context => new ValueTask<bool>(context.WebSockets.WebSocketRequestedProtocols.Contains("interceptedProtocol")))
.WithDataIntercept((data, direction, type) => {
if(direction == WsProxyDataDirection.Downstream && System.Text.Encoding.Default.GetString(data.Array).StartsWith("[should_be_intercepted]"))
data.Array[0] = (byte)']';

return Task.CompletedTask;
})
.WithBeforeConnect((context, wso) =>
{
wso.AddSubProtocol("myRandomProto");
Expand Down
26 changes: 26 additions & 0 deletions src/Test/Ws/WsIntegrationTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,32 @@ public async Task CanIntercept()
Assert.Equal(message, exception.Message);
}

[Fact]
public async Task CanDataIntercept()
{
var send1 = "should_be_intercepted";
var expected1 = $"]{send1}]";

await _client.ConnectAsync(new Uri("ws://localhost:5001/ws"), CancellationToken.None);
Assert.Equal(Extensions.SupportedProtocol, _client.SubProtocol);

// Send a message.
await _client.SendMessageAsync(send1);
await _client.SendShortMessageAsync(Extensions.CloseMessage);

// Receive responses.
var response1 = await _client.ReceiveMessageAsync();
Assert.Equal(expected1, response1);

// Receive close.
var result = await _client.ReceiveAsync(new ArraySegment<byte>(new byte[4096]), CancellationToken.None);
Assert.Equal(WebSocketMessageType.Close, result.MessageType);
Assert.Equal(WebSocketCloseStatus.NormalClosure, result.CloseStatus);
Assert.Equal(Extensions.CloseDescription, result.CloseStatusDescription);

await _client.CloseOutputAsync(WebSocketCloseStatus.NormalClosure, Extensions.CloseDescription, CancellationToken.None);
}

[Fact]
public async Task CanRunBeforeConnectAndHandleFailure()
{
Expand Down

0 comments on commit bb9d6d4

Please sign in to comment.