Skip to content

Commit

Permalink
- Assembly Loader work; implemented custom dictionary key and table.
Browse files Browse the repository at this point in the history
  • Loading branch information
MapleWheels committed Dec 9, 2024
1 parent 7b786f0 commit a428325
Showing 1 changed file with 171 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
using FluentResults.LuaCs;
using LightInject;
using Microsoft.CodeAnalysis.CSharp;
using OneOf;
using Path = Barotrauma.IO.Path;

[assembly: InternalsVisibleTo(IAssemblyLoaderService.InternalsAwareAssemblyName)]
Expand All @@ -42,12 +43,9 @@ public bool IsDisposed
/// </summary>
private readonly ReaderWriterLockSlim _operationsLock = new(LockRecursionPolicy.SupportsRecursion);
private readonly ConcurrentDictionary<string, AssemblyDependencyResolver> _dependencyResolvers = new();
private readonly ConcurrentDictionary<string, (Assembly, byte[])> _memoryCompiledAssemblies = new();
private readonly ConcurrentDictionary<Assembly, ImmutableArray<MetadataReference>> _loadedAssemblyReferences = new();
private readonly ConcurrentDictionary<Assembly, ImmutableArray<Type>> _typeCache = new();
private readonly ConcurrentDictionary<AssemblyOrStringKey, AssemblyData> _loadedAssemblyData = new();

private ThreadLocal<bool> _isResolving = new(static()=>false); // cyclic resolution exit


#region PublicAPI

Expand Down Expand Up @@ -82,7 +80,8 @@ public FluentResults.Result AddDependencyPaths(ImmutableArray<string> paths)
}
catch (Exception ex)
{
return res.WithError(new ExceptionalError(ex));
return res.WithError(new ExceptionalError(ex)
.WithMetadata(MetadataType.Sources, path));
}
}
return FluentResults.Result.Ok();
Expand All @@ -102,7 +101,7 @@ public FluentResults.Result<Assembly> CompileScriptAssembly(
.WithMetadata(MetadataType.RootObject, syntaxTrees));
}

if (_memoryCompiledAssemblies.ContainsKey(assemblyName))
if (_loadedAssemblyData.ContainsKey(assemblyName))
{
return new FluentResults.Result<Assembly>().WithError(new Error($"The name provided is already assigned to an assembly!")
.WithMetadata(MetadataType.ExceptionObject, this)
Expand All @@ -118,9 +117,12 @@ public FluentResults.Result<Assembly> CompileScriptAssembly(
reportSuppressedDiagnostics: true,
allowUnsafe: true);

typeof(CSharpCompilationOptions)
if (!compileWithInternalAccess)
{
typeof(CSharpCompilationOptions)
.GetProperty("TopLevelBinderFlags", BindingFlags.Instance | BindingFlags.NonPublic)
?.SetValue(compilationOptions, (uint)1 << 22);
}

using var asmMemoryStream = new MemoryStream();
var result = CSharpCompilation.Create(compilationAssemblyName, syntaxTrees, metadataReferences, compilationOptions).Emit(asmMemoryStream);
Expand All @@ -141,11 +143,9 @@ public FluentResults.Result<Assembly> CompileScriptAssembly(
asmMemoryStream.Seek(0, SeekOrigin.Begin);
try
{
var assembly = LoadFromStream(asmMemoryStream);
var assemblyImage = asmMemoryStream.ToArray();
_memoryCompiledAssemblies[assemblyName] = (assembly, assemblyImage);
_typeCache[assembly] = assembly.GetSafeTypes().ToImmutableArray();
return new FluentResults.Result<Assembly>().WithSuccess($"Compiled assembly {assemblyName} successful.").WithValue(assembly);
var data = new AssemblyData(LoadFromStream(asmMemoryStream), asmMemoryStream.ToArray());
_loadedAssemblyData[data.Assembly] = data;
return new FluentResults.Result<Assembly>().WithSuccess($"Compiled assembly {assemblyName} successful.").WithValue(data.Assembly);
}
catch (Exception ex)
{
Expand All @@ -156,8 +156,86 @@ public FluentResults.Result<Assembly> CompileScriptAssembly(
public FluentResults.Result<Assembly> LoadAssemblyFromFile(string assemblyFilePath,
ImmutableArray<string> additionalDependencyPaths)
{
// TODO: Include runtime error diagnostics from Github issue.
throw new NotImplementedException();
if (assemblyFilePath.IsNullOrWhiteSpace())
return new FluentResults.Result<Assembly>().WithError(new Error($"The path provided is null!"));

if (additionalDependencyPaths.Any())
{
var r = AddDependencyPaths(additionalDependencyPaths);
if (!r.IsFailed)
{
// we have errors, loading may not work.
return FluentResults.Result.Fail(new Error($"Failed to load dependency paths")
.WithMetadata(MetadataType.ExceptionObject, this)
.WithMetadata(MetadataType.RootObject, assemblyFilePath))
.WithErrors(r.Errors);
}
}

string sanitizedFilePath = Path.GetFullPath(assemblyFilePath.CleanUpPath());
string directoryKey = Path.GetDirectoryName(sanitizedFilePath);

if (directoryKey is null)
{
return FluentResults.Result.Fail(new Error($"Unable to load assembly: bath file path: {assemblyFilePath}")
.WithMetadata(MetadataType.ExceptionObject, this)
.WithMetadata(MetadataType.RootObject, sanitizedFilePath));
}

try
{
var assembly = LoadFromAssemblyPath(sanitizedFilePath);
_loadedAssemblyData[assembly] = new AssemblyData(assembly, sanitizedFilePath);
return new Result<Assembly>().WithSuccess($"Loaded assembly'{assembly.GetName()}'").WithValue(assembly);
}
catch (ArgumentNullException ane)
{
return FluentResults.Result.Fail<Assembly>(new ExceptionalError(ane)
.WithMetadata(MetadataType.ExceptionObject, this)
.WithMetadata(MetadataType.RootObject, assemblyFilePath)
.WithMetadata(MetadataType.ExceptionDetails, ane.Message)
.WithMetadata(MetadataType.StackTrace, ane.StackTrace));
}
catch (ArgumentException ae)
{
return FluentResults.Result.Fail<Assembly>(new ExceptionalError(ae)
.WithMetadata(MetadataType.ExceptionObject, this)
.WithMetadata(MetadataType.RootObject, assemblyFilePath)
.WithMetadata(MetadataType.ExceptionDetails, ae.Message)
.WithMetadata(MetadataType.StackTrace, ae.StackTrace));
}
catch (FileLoadException fle)
{
return FluentResults.Result.Fail<Assembly>(new ExceptionalError(fle)
.WithMetadata(MetadataType.ExceptionObject, this)
.WithMetadata(MetadataType.RootObject, assemblyFilePath)
.WithMetadata(MetadataType.ExceptionDetails, fle.Message)
.WithMetadata(MetadataType.StackTrace, fle.StackTrace));
}
catch (FileNotFoundException fnfe)
{
return FluentResults.Result.Fail<Assembly>(new ExceptionalError(fnfe)
.WithMetadata(MetadataType.ExceptionObject, this)
.WithMetadata(MetadataType.RootObject, assemblyFilePath)
.WithMetadata(MetadataType.ExceptionDetails, fnfe.Message)
.WithMetadata(MetadataType.StackTrace, fnfe.StackTrace));
}
catch (BadImageFormatException bife)
{
return FluentResults.Result.Fail<Assembly>(new ExceptionalError(bife)
.WithMetadata(MetadataType.ExceptionObject, this)
.WithMetadata(MetadataType.RootObject, assemblyFilePath)
.WithMetadata(MetadataType.ExceptionDetails, bife.Message)
.WithMetadata(MetadataType.StackTrace, bife.StackTrace));
}
catch (Exception e)
{
return FluentResults.Result.Fail<Assembly>(new ExceptionalError(e)
.WithMetadata(MetadataType.ExceptionObject, this)
.WithMetadata(MetadataType.RootObject, assemblyFilePath)
.WithMetadata(MetadataType.ExceptionDetails, e.Message)
.WithMetadata(MetadataType.StackTrace, e.StackTrace));
}
}

public FluentResults.Result<Assembly> GetAssemblyByName(string assemblyName)
Expand All @@ -168,14 +246,14 @@ public FluentResults.Result<Assembly> GetAssemblyByName(string assemblyName)
.WithMetadata(MetadataType.ExceptionObject, this));
}

if (_memoryCompiledAssemblies.TryGetValue(assemblyName, out var assembly))
if (_loadedAssemblyData.TryGetValue(assemblyName, out var data))
{
return new FluentResults.Result<Assembly>().WithSuccess(new Success($"Assembly found")).WithValue(assembly.Item1);
return new FluentResults.Result<Assembly>().WithSuccess(new Success($"Assembly found")).WithValue(data.Assembly);
}

foreach (var assembly1 in Assemblies)
foreach (var assembly1 in this.Assemblies)
{
if (assembly1.GetName().Name == assemblyName)
if (assembly1.GetName().FullName == assemblyName)
{
return new FluentResults.Result<Assembly>().WithSuccess(new Success($"Assembly found")).WithValue(assembly1);
}
Expand All @@ -188,9 +266,7 @@ public FluentResults.Result<ImmutableArray<Type>> GetTypesInAssemblies()
{
try
{
return new FluentResults.Result<ImmutableArray<Type>>().WithValue([
.._typeCache.SelectMany(kvp => kvp.Value)
]);
return new FluentResults.Result<ImmutableArray<Type>>().WithValue([.._loadedAssemblyData.SelectMany(kvp=> kvp.Value.Types)]);
}
catch (Exception e)
{
Expand All @@ -204,7 +280,14 @@ public FluentResults.Result<ImmutableArray<Type>> GetTypesInAssemblies()

protected override Assembly Load(AssemblyName assemblyName)
{
throw new NotImplementedException();
if (_isResolving.Value)
return null;

try
{
_isResolving.Value = true;
throw new NotImplementedException();
}

Check failure on line 290 in Barotrauma/BarotraumaShared/SharedSource/LuaCs/_Plugins/AssemblyLoader.cs

View workflow job for this annotation

GitHub Actions / run-tests / run-tests

Expected catch or finally
}

protected override IntPtr LoadUnmanagedDll(string unmanagedDllName)
Expand Down Expand Up @@ -245,4 +328,70 @@ private void Dispose(bool disposing)
}
}
}

private readonly record struct AssemblyData
{
public readonly Assembly Assembly;
public readonly OneOf<byte[], string> AssemblyImageOrPath;
public readonly MetadataReference AssemblyReference;
public readonly ImmutableArray<Type> Types;

public AssemblyData(Assembly assembly, byte[] assemblyImage)
{
Assembly = assembly ?? throw new ArgumentNullException(nameof(assembly));
AssemblyImageOrPath = assemblyImage ?? throw new ArgumentNullException(nameof(assemblyImage));
AssemblyReference = MetadataReference.CreateFromImage(assemblyImage);
Types = [..assembly.GetSafeTypes()];
}

public AssemblyData(Assembly assembly, string path)
{
Assembly = assembly ?? throw new ArgumentNullException(nameof(assembly));
AssemblyImageOrPath = path ?? throw new ArgumentNullException(nameof(path));
AssemblyReference = MetadataReference.CreateFromFile(path);
Types = [..assembly.GetSafeTypes()];
}
}

private readonly record struct AssemblyOrStringKey : IEquatable<AssemblyOrStringKey>, IEqualityComparer<AssemblyOrStringKey>
{
public Assembly Assembly { get; init; }
public string AssemblyName { get; init; }
public readonly int HashCode;

public AssemblyOrStringKey(Assembly assembly)
{
if(assembly == null)
throw new ArgumentNullException(nameof(assembly));
Assembly = assembly;
AssemblyName = assembly.GetName().FullName;
if (AssemblyName == null)
throw new ArgumentNullException(nameof(AssemblyName));
HashCode = AssemblyName.GetHashCode();
}

public AssemblyOrStringKey(string assemblyName)
{
if (assemblyName.IsNullOrWhiteSpace())
throw new ArgumentNullException(nameof(assemblyName));
Assembly = null;
AssemblyName = assemblyName;
HashCode = AssemblyName.GetHashCode();
}

public bool Equals(AssemblyOrStringKey x, AssemblyOrStringKey y)
{
if (x.Assembly is not null && y.Assembly is not null)
return x.Assembly == y.Assembly;
return x.AssemblyName == y.AssemblyName;
}

public int GetHashCode(AssemblyOrStringKey obj)
{
return obj.HashCode;
}

public static implicit operator AssemblyOrStringKey(Assembly assembly) => new AssemblyOrStringKey(assembly);
public static implicit operator AssemblyOrStringKey(string name) => new AssemblyOrStringKey(name);
}
}

0 comments on commit a428325

Please sign in to comment.