Skip to content

Commit 46b34cd

Browse files
committed
Incrementally generate
1 parent c6f3ff1 commit 46b34cd

File tree

5 files changed

+542
-53
lines changed

5 files changed

+542
-53
lines changed

src/Files.Core.SourceGenerator/Data/VTableFunctionInfo.cs

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,12 @@
33

44
namespace Files.Core.SourceGenerator.Data
55
{
6-
internal class VTableFunctionInfo
7-
{
8-
public required string Name { get; init; }
9-
10-
public required string ReturnType { get; init; }
11-
12-
public required Dictionary<string, string> Parameters { get; init; }
13-
}
6+
internal record VTableFunctionInfo(
7+
string FullyQualifiedParentTypeName,
8+
string ParentTypeNamespace,
9+
string ParentTypeName,
10+
string Name,
11+
string ReturnTypeName,
12+
int Index,
13+
EquatableArray<ISymbol> Parameters);
1414
}

src/Files.Core.SourceGenerator/Files.Core.SourceGenerator.csproj

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
<PlatformTarget>AnyCPU</PlatformTarget>
1010
<EnforceExtendedAnalyzerRules>true</EnforceExtendedAnalyzerRules>
1111
<Configurations>Debug;Release</Configurations>
12+
<AllowUnsafeBlocks>true</AllowUnsafeBlocks>
1213
</PropertyGroup>
1314

1415
<ItemGroup>

src/Files.Core.SourceGenerator/Generators/VTableFunctionGenerator.cs

Lines changed: 54 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
// Licensed under the MIT License.
33

44
using Microsoft.CodeAnalysis;
5+
using System.Reflection;
56

67
namespace Files.Core.SourceGenerator.Generators
78
{
@@ -12,26 +13,60 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
1213
{
1314
var sources = context.SyntaxProvider.ForAttributeWithMetadataName(
1415
"Files.Shared.Attributes.GeneratedVTableFunctionAttribute",
15-
static (node, token) => true,
16-
static (context, token) => context)
17-
.Collect();
16+
static (node, token) =>
17+
{
18+
token.ThrowIfCancellationRequested();
19+
20+
// Check if the method has partial modifier and is public or internal (and not static)
21+
if (node is not MethodDeclarationSyntax { AttributeLists.Count: > 0 } method ||
22+
!method.Modifiers.Any(SyntaxKind.PartialKeyword) ||
23+
!(method.Modifiers.Any(SyntaxKind.PublicKeyword) || method.Modifiers.Any(SyntaxKind.InternalKeyword)) ||
24+
method.Modifiers.Any(SyntaxKind.StaticKeyword))
25+
return false;
26+
27+
// Check if the type containing the method has partial modifier and is a struct
28+
if (node.Parent is not TypeDeclarationSyntax { Keyword.RawKind: (int)SyntaxKind.StructKeyword, Modifiers: { } modifiers } ||
29+
!modifiers.Any(SyntaxKind.PartialKeyword))
30+
return false;
31+
32+
return true;
33+
},
34+
static (context, token) =>
35+
{
36+
token.ThrowIfCancellationRequested();
37+
38+
var fullyQualifiedParentTypeName = context.TargetSymbol.ContainingType.ToString();
39+
var structNamespace = context.TargetSymbol.ContainingType.ContainingNamespace.ToString();
40+
var structName = context.TargetSymbol.ContainingType.Name;
41+
var methodSymbol = (IMethodSymbol)context.TargetSymbol;
42+
var functionName = methodSymbol.Name;
43+
var returnTypeName = methodSymbol.ReturnType.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat);
44+
var parameters = methodSymbol.Parameters.CastArray<ISymbol>();
45+
var index = (int)context.Attributes[0].NamedArguments.FirstOrDefault(x => x.Key.Equals("Index")).Value.Value!;
46+
47+
return new VTableFunctionInfo(fullyQualifiedParentTypeName, structNamespace, structName, functionName, returnTypeName, index, new(parameters));
48+
})
49+
.Where(static item => item is not null)
50+
.Collect()
51+
.Select((items, token) =>
52+
{
53+
token.ThrowIfCancellationRequested();
54+
55+
return items.GroupBy(source => source.FullyQualifiedParentTypeName, StringComparer.OrdinalIgnoreCase);
56+
});
57+
1858

1959
context.RegisterSourceOutput(sources, (context, sources) =>
2060
{
21-
var vtableFunctionsGroupedByStructs = sources.GroupBy(source => source.TargetSymbol.ContainingType, SymbolEqualityComparer.Default);
22-
23-
foreach (var vtableFunctions in vtableFunctionsGroupedByStructs)
61+
foreach (var source in sources)
2462
{
25-
if (vtableFunctions.Key is not INamedTypeSymbol structSymbol || structSymbol.Name is not { } structName)
26-
continue;
27-
28-
string vtableFunctionsCode = GenerateVtableFunctionsForStruct(structSymbol, vtableFunctions);
29-
context.AddSource($"{structName}_VTableFunctions.g.cs", vtableFunctionsCode);
63+
string vtableFunctionsCode = GenerateVtableFunctionsForStruct(source);
64+
context.AddSource($"{source.Key}_VTableFunctions.g.cs", vtableFunctionsCode);
3065
}
3166
});
3267
}
3368

34-
private string GenerateVtableFunctionsForStruct(INamedTypeSymbol structSymbol, IEnumerable<GeneratorAttributeSyntaxContext> sources)
69+
private string GenerateVtableFunctionsForStruct(IEnumerable<VTableFunctionInfo> sources)
3570
{
3671
StringBuilder builder = new();
3772

@@ -42,13 +77,10 @@ private string GenerateVtableFunctionsForStruct(INamedTypeSymbol structSymbol, I
4277
builder.AppendLine($"#pragma warning disable");
4378
builder.AppendLine();
4479

45-
if (structSymbol.ContainingNamespace is { IsGlobalNamespace: false })
46-
{
47-
builder.AppendLine($"namespace {structSymbol.ContainingNamespace};");
48-
builder.AppendLine();
49-
}
80+
builder.AppendLine($"namespace {sources.ElementAt(0).ParentTypeNamespace};");
81+
builder.AppendLine();
5082

51-
builder.AppendLine($"public unsafe partial struct {structSymbol.Name}");
83+
builder.AppendLine($"public unsafe partial struct {sources.ElementAt(0).ParentTypeName}");
5284
builder.AppendLine($"{{");
5385

5486
builder.AppendLine($" private void** lpVtbl;");
@@ -59,15 +91,14 @@ private string GenerateVtableFunctionsForStruct(INamedTypeSymbol structSymbol, I
5991

6092
foreach (var source in sources)
6193
{
62-
var vtblIndex = source.Attributes[0].NamedArguments.Where(x => x.Key.Equals("Index")).FirstOrDefault().Value;
63-
var info = GetVTableFunctionInfo((IMethodSymbol)source.TargetSymbol);
94+
var parameters = source.Parameters.Cast<IParameterSymbol>().ToDictionary(x => x.Type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat), x => x.Name);
6495

6596
builder.AppendLine($" [global::System.Runtime.CompilerServices.MethodImpl(global::System.Runtime.CompilerServices.MethodImplOptions.AggressiveInlining)]");
6697

67-
builder.AppendLine($" public partial {info.ReturnType} {info.Name}({string.Join(", ", info.Parameters.Select(x => $"{x.Key} {x.Value}"))})");
98+
builder.AppendLine($" public partial {source.ReturnTypeName} {source.Name}({string.Join(", ", parameters.Select(x => $"{x.Key} {x.Value}"))})");
6899
builder.AppendLine($" {{");
69-
builder.AppendLine($" return ({info.ReturnType})((delegate* unmanaged[MemberFunction]<{structSymbol.Name}*, {string.Join(", ", info.Parameters.Select(x => $"{x.Key}"))}, int>)(lpVtbl[{vtblIndex.Value}]))");
70-
builder.AppendLine($" (({structSymbol.Name}*)global::System.Runtime.CompilerServices.Unsafe.AsPointer(ref this), {string.Join(", ", info.Parameters.Select(x => $"{x.Value}"))});");
100+
builder.AppendLine($" return ({source.ReturnTypeName})((delegate* unmanaged[MemberFunction]<{sources.ElementAt(0).ParentTypeName}*, {string.Join(", ", parameters.Select(x => $"{x.Key}"))}, int>)(lpVtbl[{source.Index}]))");
101+
builder.AppendLine($" (({sources.ElementAt(0).ParentTypeName}*)global::System.Runtime.CompilerServices.Unsafe.AsPointer(ref this), {string.Join(", ", parameters.Select(x => $"{x.Value}"))});");
71102
builder.AppendLine($" }}");
72103

73104
if (sourceIndex < sourceCount - 1)
@@ -80,27 +111,5 @@ private string GenerateVtableFunctionsForStruct(INamedTypeSymbol structSymbol, I
80111

81112
return builder.ToString();
82113
}
83-
84-
private VTableFunctionInfo GetVTableFunctionInfo(IMethodSymbol methodSymbol)
85-
{
86-
string functionName = methodSymbol.Name;
87-
string returnType = methodSymbol.ReturnType.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat);
88-
89-
Dictionary<string, string> parameters = [];
90-
foreach (var param in methodSymbol.Parameters)
91-
{
92-
var name = param.Name;
93-
var type = param.Type;
94-
95-
parameters.Add(type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat), name);
96-
}
97-
98-
return new VTableFunctionInfo()
99-
{
100-
Name = functionName,
101-
ReturnType = returnType,
102-
Parameters = parameters,
103-
};
104-
}
105114
}
106115
}
Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
// Copyright (c) Files Community
2+
// Licensed under the MIT License.
3+
4+
using System;
5+
using System.Collections;
6+
using System.Collections.Generic;
7+
using System.Collections.Immutable;
8+
using System.Linq;
9+
using System.Runtime.CompilerServices;
10+
using System.Runtime.InteropServices;
11+
12+
namespace Files.Core.SourceGenerator.Utilities;
13+
14+
internal readonly struct EquatableArray<T>(ImmutableArray<T> array) : IEquatable<EquatableArray<T>>, IEnumerable<T>
15+
where T : IEquatable<T>
16+
{
17+
private readonly T[]? array = ImmutableCollectionsMarshal.AsArray(array);
18+
19+
public ref readonly T this[int index]
20+
{
21+
[MethodImpl(MethodImplOptions.AggressiveInlining)]
22+
get => ref AsImmutableArray().ItemRef(index);
23+
}
24+
25+
public bool IsEmpty
26+
{
27+
[MethodImpl(MethodImplOptions.AggressiveInlining)]
28+
get => AsImmutableArray().IsEmpty;
29+
}
30+
31+
public bool IsDefaultOrEmpty
32+
{
33+
[MethodImpl(MethodImplOptions.AggressiveInlining)]
34+
get => AsImmutableArray().IsDefaultOrEmpty;
35+
}
36+
37+
public int Length
38+
{
39+
[MethodImpl(MethodImplOptions.AggressiveInlining)]
40+
get => AsImmutableArray().Length;
41+
}
42+
43+
public bool Equals(EquatableArray<T> array)
44+
{
45+
return AsSpan().SequenceEqual(array.AsSpan());
46+
}
47+
48+
public override bool Equals(object? obj)
49+
{
50+
return obj is EquatableArray<T> array && Equals(this, array);
51+
}
52+
53+
public override unsafe int GetHashCode()
54+
{
55+
if (this.array is not T[] array)
56+
return 0;
57+
58+
HashCode hashCode = default;
59+
60+
if (typeof(T) == typeof(byte))
61+
{
62+
ReadOnlySpan<T> span = array;
63+
ref T r0 = ref MemoryMarshal.GetReference(span);
64+
ref byte r1 = ref Unsafe.As<T, byte>(ref r0);
65+
66+
fixed (byte* p = &r1)
67+
{
68+
ReadOnlySpan<byte> bytes = new(p, span.Length);
69+
70+
hashCode.AddBytes(bytes);
71+
}
72+
}
73+
else
74+
{
75+
foreach (T item in array)
76+
hashCode.Add(item);
77+
}
78+
79+
return hashCode.ToHashCode();
80+
}
81+
82+
[MethodImpl(MethodImplOptions.AggressiveInlining)]
83+
public ImmutableArray<T> AsImmutableArray()
84+
{
85+
return ImmutableCollectionsMarshal.AsImmutableArray(this.array);
86+
}
87+
88+
public static EquatableArray<T> FromImmutableArray(ImmutableArray<T> array)
89+
{
90+
return new(array);
91+
}
92+
93+
public ReadOnlySpan<T> AsSpan()
94+
{
95+
return AsImmutableArray().AsSpan();
96+
}
97+
98+
public T[] ToArray()
99+
{
100+
return [.. AsImmutableArray()];
101+
}
102+
103+
public ImmutableArray<T>.Enumerator GetEnumerator()
104+
{
105+
return AsImmutableArray().GetEnumerator();
106+
}
107+
108+
IEnumerator<T> IEnumerable<T>.GetEnumerator()
109+
{
110+
return ((IEnumerable<T>)AsImmutableArray()).GetEnumerator();
111+
}
112+
113+
IEnumerator IEnumerable.GetEnumerator()
114+
{
115+
return ((IEnumerable)AsImmutableArray()).GetEnumerator();
116+
}
117+
118+
public static implicit operator EquatableArray<T>(ImmutableArray<T> array) => FromImmutableArray(array);
119+
120+
public static implicit operator ImmutableArray<T>(EquatableArray<T> array) => array.AsImmutableArray();
121+
122+
public static bool operator ==(EquatableArray<T> left, EquatableArray<T> right) => left.Equals(right);
123+
124+
public static bool operator !=(EquatableArray<T> left, EquatableArray<T> right) => !left.Equals(right);
125+
}

0 commit comments

Comments
 (0)