2
2
// Licensed under the MIT License.
3
3
4
4
using Microsoft . CodeAnalysis ;
5
+ using System . Reflection ;
5
6
6
7
namespace Files . Core . SourceGenerator . Generators
7
8
{
@@ -12,26 +13,60 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
12
13
{
13
14
var sources = context . SyntaxProvider . ForAttributeWithMetadataName (
14
15
"Files.Shared.Attributes.GeneratedVTableFunctionAttribute" ,
15
- static ( node , token ) => true ,
16
- static ( context , token ) => context )
17
- . Collect ( ) ;
16
+ static ( node , token ) =>
17
+ {
18
+ token . ThrowIfCancellationRequested ( ) ;
19
+
20
+ // Check if the method has partial modifier and is public or internal (and not static)
21
+ if ( node is not MethodDeclarationSyntax { AttributeLists . Count : > 0 } method ||
22
+ ! method . Modifiers . Any ( SyntaxKind . PartialKeyword ) ||
23
+ ! ( method . Modifiers . Any ( SyntaxKind . PublicKeyword ) || method . Modifiers . Any ( SyntaxKind . InternalKeyword ) ) ||
24
+ method . Modifiers . Any ( SyntaxKind . StaticKeyword ) )
25
+ return false ;
26
+
27
+ // Check if the type containing the method has partial modifier and is a struct
28
+ if ( node . Parent is not TypeDeclarationSyntax { Keyword . RawKind : ( int ) SyntaxKind . StructKeyword , Modifiers : { } modifiers } ||
29
+ ! modifiers . Any ( SyntaxKind . PartialKeyword ) )
30
+ return false ;
31
+
32
+ return true ;
33
+ } ,
34
+ static ( context , token ) =>
35
+ {
36
+ token . ThrowIfCancellationRequested ( ) ;
37
+
38
+ var fullyQualifiedParentTypeName = context . TargetSymbol . ContainingType . ToString ( ) ;
39
+ var structNamespace = context . TargetSymbol . ContainingType . ContainingNamespace . ToString ( ) ;
40
+ var structName = context . TargetSymbol . ContainingType . Name ;
41
+ var methodSymbol = ( IMethodSymbol ) context . TargetSymbol ;
42
+ var functionName = methodSymbol . Name ;
43
+ var returnTypeName = methodSymbol . ReturnType . ToDisplayString ( SymbolDisplayFormat . FullyQualifiedFormat ) ;
44
+ var parameters = methodSymbol . Parameters . CastArray < ISymbol > ( ) ;
45
+ var index = ( int ) context . Attributes [ 0 ] . NamedArguments . FirstOrDefault ( x => x . Key . Equals ( "Index" ) ) . Value . Value ! ;
46
+
47
+ return new VTableFunctionInfo ( fullyQualifiedParentTypeName , structNamespace , structName , functionName , returnTypeName , index , new ( parameters ) ) ;
48
+ } )
49
+ . Where ( static item => item is not null )
50
+ . Collect ( )
51
+ . Select ( ( items , token ) =>
52
+ {
53
+ token . ThrowIfCancellationRequested ( ) ;
54
+
55
+ return items . GroupBy ( source => source . FullyQualifiedParentTypeName , StringComparer . OrdinalIgnoreCase ) ;
56
+ } ) ;
57
+
18
58
19
59
context . RegisterSourceOutput ( sources , ( context , sources ) =>
20
60
{
21
- var vtableFunctionsGroupedByStructs = sources . GroupBy ( source => source . TargetSymbol . ContainingType , SymbolEqualityComparer . Default ) ;
22
-
23
- foreach ( var vtableFunctions in vtableFunctionsGroupedByStructs )
61
+ foreach ( var source in sources )
24
62
{
25
- if ( vtableFunctions . Key is not INamedTypeSymbol structSymbol || structSymbol . Name is not { } structName )
26
- continue ;
27
-
28
- string vtableFunctionsCode = GenerateVtableFunctionsForStruct ( structSymbol , vtableFunctions ) ;
29
- context . AddSource ( $ "{ structName } _VTableFunctions.g.cs", vtableFunctionsCode ) ;
63
+ string vtableFunctionsCode = GenerateVtableFunctionsForStruct ( source ) ;
64
+ context . AddSource ( $ "{ source . Key } _VTableFunctions.g.cs", vtableFunctionsCode ) ;
30
65
}
31
66
} ) ;
32
67
}
33
68
34
- private string GenerateVtableFunctionsForStruct ( INamedTypeSymbol structSymbol , IEnumerable < GeneratorAttributeSyntaxContext > sources )
69
+ private string GenerateVtableFunctionsForStruct ( IEnumerable < VTableFunctionInfo > sources )
35
70
{
36
71
StringBuilder builder = new ( ) ;
37
72
@@ -42,13 +77,10 @@ private string GenerateVtableFunctionsForStruct(INamedTypeSymbol structSymbol, I
42
77
builder . AppendLine ( $ "#pragma warning disable") ;
43
78
builder . AppendLine ( ) ;
44
79
45
- if ( structSymbol . ContainingNamespace is { IsGlobalNamespace : false } )
46
- {
47
- builder . AppendLine ( $ "namespace { structSymbol . ContainingNamespace } ;") ;
48
- builder . AppendLine ( ) ;
49
- }
80
+ builder . AppendLine ( $ "namespace { sources . ElementAt ( 0 ) . ParentTypeNamespace } ;") ;
81
+ builder . AppendLine ( ) ;
50
82
51
- builder . AppendLine ( $ "public unsafe partial struct { structSymbol . Name } ") ;
83
+ builder . AppendLine ( $ "public unsafe partial struct { sources . ElementAt ( 0 ) . ParentTypeName } ") ;
52
84
builder . AppendLine ( $ "{{") ;
53
85
54
86
builder . AppendLine ( $ " private void** lpVtbl;") ;
@@ -59,15 +91,14 @@ private string GenerateVtableFunctionsForStruct(INamedTypeSymbol structSymbol, I
59
91
60
92
foreach ( var source in sources )
61
93
{
62
- var vtblIndex = source . Attributes [ 0 ] . NamedArguments . Where ( x => x . Key . Equals ( "Index" ) ) . FirstOrDefault ( ) . Value ;
63
- var info = GetVTableFunctionInfo ( ( IMethodSymbol ) source . TargetSymbol ) ;
94
+ var parameters = source . Parameters . Cast < IParameterSymbol > ( ) . ToDictionary ( x => x . Type . ToDisplayString ( SymbolDisplayFormat . FullyQualifiedFormat ) , x => x . Name ) ;
64
95
65
96
builder . AppendLine ( $ " [global::System.Runtime.CompilerServices.MethodImpl(global::System.Runtime.CompilerServices.MethodImplOptions.AggressiveInlining)]") ;
66
97
67
- builder . AppendLine ( $ " public partial { info . ReturnType } { info . Name } ({ string . Join ( ", " , info . Parameters . Select ( x => $ "{ x . Key } { x . Value } ") ) } )") ;
98
+ builder . AppendLine ( $ " public partial { source . ReturnTypeName } { source . Name } ({ string . Join ( ", " , parameters . Select ( x => $ "{ x . Key } { x . Value } ") ) } )") ;
68
99
builder . AppendLine ( $ " {{") ;
69
- builder . AppendLine ( $ " return ({ info . ReturnType } )((delegate* unmanaged[MemberFunction]<{ structSymbol . Name } *, { string . Join ( ", " , info . Parameters . Select ( x => $ "{ x . Key } ") ) } , int>)(lpVtbl[{ vtblIndex . Value } ]))") ;
70
- builder . AppendLine ( $ " (({ structSymbol . Name } *)global::System.Runtime.CompilerServices.Unsafe.AsPointer(ref this), { string . Join ( ", " , info . Parameters . Select ( x => $ "{ x . Value } ") ) } );") ;
100
+ builder . AppendLine ( $ " return ({ source . ReturnTypeName } )((delegate* unmanaged[MemberFunction]<{ sources . ElementAt ( 0 ) . ParentTypeName } *, { string . Join ( ", " , parameters . Select ( x => $ "{ x . Key } ") ) } , int>)(lpVtbl[{ source . Index } ]))") ;
101
+ builder . AppendLine ( $ " (({ sources . ElementAt ( 0 ) . ParentTypeName } *)global::System.Runtime.CompilerServices.Unsafe.AsPointer(ref this), { string . Join ( ", " , parameters . Select ( x => $ "{ x . Value } ") ) } );") ;
71
102
builder . AppendLine ( $ " }}") ;
72
103
73
104
if ( sourceIndex < sourceCount - 1 )
@@ -80,27 +111,5 @@ private string GenerateVtableFunctionsForStruct(INamedTypeSymbol structSymbol, I
80
111
81
112
return builder . ToString ( ) ;
82
113
}
83
-
84
- private VTableFunctionInfo GetVTableFunctionInfo ( IMethodSymbol methodSymbol )
85
- {
86
- string functionName = methodSymbol . Name ;
87
- string returnType = methodSymbol . ReturnType . ToDisplayString ( SymbolDisplayFormat . FullyQualifiedFormat ) ;
88
-
89
- Dictionary < string , string > parameters = [ ] ;
90
- foreach ( var param in methodSymbol . Parameters )
91
- {
92
- var name = param . Name ;
93
- var type = param . Type ;
94
-
95
- parameters . Add ( type . ToDisplayString ( SymbolDisplayFormat . FullyQualifiedFormat ) , name ) ;
96
- }
97
-
98
- return new VTableFunctionInfo ( )
99
- {
100
- Name = functionName ,
101
- ReturnType = returnType ,
102
- Parameters = parameters ,
103
- } ;
104
- }
105
114
}
106
115
}
0 commit comments