Skip to content

Commit aaa1368

Browse files
committed
[MLIR][Python] fix stubgen
1 parent e9aee33 commit aaa1368

File tree

3 files changed

+50
-11
lines changed

3 files changed

+50
-11
lines changed

mlir/cmake/modules/AddMLIRPython.cmake

Lines changed: 28 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,11 @@ endfunction()
111111
# Outputs:
112112
# NB_STUBGEN_CUSTOM_TARGET: The target corresponding to generation which other targets can depend on.
113113
function(generate_type_stubs MODULE_NAME DEPENDS_TARGET MLIR_DEPENDS_TARGET OUTPUT_DIR)
114+
cmake_parse_arguments(ARG
115+
""
116+
""
117+
"OUTPUTS"
118+
${ARGN})
114119
if(EXISTS ${nanobind_DIR}/../src/stubgen.py)
115120
set(NB_STUBGEN "${nanobind_DIR}/../src/stubgen.py")
116121
elseif(EXISTS ${nanobind_DIR}/../stubgen.py)
@@ -135,9 +140,9 @@ function(generate_type_stubs MODULE_NAME DEPENDS_TARGET MLIR_DEPENDS_TARGET OUTP
135140
--output-dir
136141
"${OUTPUT_DIR}")
137142

138-
set(NB_STUBGEN_OUTPUT "${OUTPUT_DIR}/${MODULE_NAME}.pyi")
143+
list(TRANSFORM ARG_OUTPUTS PREPEND "${OUTPUT_DIR}/" OUTPUT_VARIABLE _generated_type_stubs)
139144
add_custom_command(
140-
OUTPUT ${NB_STUBGEN_OUTPUT}
145+
OUTPUT ${_generated_type_stubs}
141146
COMMAND ${NB_STUBGEN_CMD}
142147
WORKING_DIRECTORY "${CMAKE_CURRENT_FUNCTION_LIST_DIR}"
143148
DEPENDS
@@ -146,7 +151,7 @@ function(generate_type_stubs MODULE_NAME DEPENDS_TARGET MLIR_DEPENDS_TARGET OUTP
146151
"${DEPENDS_TARGET}"
147152
)
148153
set(_name "MLIRPythonModuleStubs_${_module}")
149-
add_custom_target("${_name}" ALL DEPENDS ${NB_STUBGEN_OUTPUT})
154+
add_custom_target("${_name}" ALL DEPENDS ${_generated_type_stubs})
150155
set(NB_STUBGEN_CUSTOM_TARGET "${_name}" PARENT_SCOPE)
151156
endfunction()
152157

@@ -166,12 +171,13 @@ endfunction()
166171
# on. These will be collected for all extensions and put into an
167172
# aggregate dylib that is linked against.
168173
# PYTHON_BINDINGS_LIBRARY: Either pybind11 or nanobind.
169-
# GENERATE_TYPE_STUBS: Enable type stub generation.
174+
# GENERATE_TYPE_STUBS: List of generated type stubs expected from stubgen relative to _mlir_libs.
175+
# Note, these will be emitted into ${CMAKE_CURRENT_BINARY_DIR}/_mlir_libs.
170176
function(declare_mlir_python_extension name)
171177
cmake_parse_arguments(ARG
172-
"GENERATE_TYPE_STUBS"
178+
""
173179
"ROOT_DIR;MODULE_NAME;ADD_TO_PARENT;PYTHON_BINDINGS_LIBRARY"
174-
"SOURCES;PRIVATE_LINK_LIBS;EMBED_CAPI_LINK_LIBS"
180+
"SOURCES;PRIVATE_LINK_LIBS;EMBED_CAPI_LINK_LIBS;GENERATE_TYPE_STUBS"
175181
${ARGN})
176182

177183
if(NOT ARG_ROOT_DIR)
@@ -302,15 +308,26 @@ function(add_mlir_python_modules name)
302308
${_module_name}
303309
${_extension_target}
304310
${name}
305-
"${CMAKE_CURRENT_SOURCE_DIR}/mlir/_mlir_libs/_mlir"
311+
# "${CMAKE_CURRENT_SOURCE_DIR}/mlir/_mlir_libs"
312+
"${CMAKE_CURRENT_BINARY_DIR}/_mlir_libs"
313+
OUTPUTS "${_generate_type_stubs}"
306314
)
315+
add_dependencies("${modules_target}" "${NB_STUBGEN_CUSTOM_TARGET}")
316+
set(_stubgen_target "${MLIR_PYTHON_PACKAGE_PREFIX}.${_module_name}_type_stub_gen")
307317
declare_mlir_python_sources(
308-
"${MLIR_PYTHON_PACKAGE_PREFIX}.${_module_name}_type_stub_gen"
309-
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
318+
${_stubgen_target}
319+
ROOT_DIR "${CMAKE_CURRENT_BINARY_DIR}/_mlir_libs"
310320
ADD_TO_PARENT "${sources_target}"
311-
SOURCES_GLOB "_mlir_libs/${_module_name}/**/*.pyi"
321+
SOURCES "${_generate_type_stubs}"
312322
)
313-
add_dependencies("${modules_target}" "${NB_STUBGEN_CUSTOM_TARGET}")
323+
set(_pure_sources_target "${modules_target}.sources.${sources_target}_type_stub_gen")
324+
add_mlir_python_sources_target(${_pure_sources_target}
325+
INSTALL_COMPONENT ${modules_target}
326+
INSTALL_DIR "${ARG_INSTALL_PREFIX}/_mlir_libs"
327+
OUTPUT_DIRECTORY "${ARG_ROOT_PREFIX}/_mlir_libs"
328+
SOURCES_TARGETS ${_stubgen_target}
329+
)
330+
add_dependencies(${modules_target} ${_pure_sources_target})
314331
endif()
315332
else()
316333
message(SEND_ERROR "Unrecognized source type '${_source_type}' for python source target ${sources_target}")

mlir/python/CMakeLists.txt

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -506,6 +506,10 @@ declare_mlir_python_extension(MLIRPythonExtension.Core
506506
# Dialects
507507
MLIRCAPIFunc
508508
GENERATE_TYPE_STUBS
509+
"_mlir/__init__.pyi"
510+
"_mlir/ir.pyi"
511+
"_mlir/passmanager.pyi"
512+
"_mlir/rewrite.pyi"
509513
)
510514

511515
# This extension exposes an API to register all dialects, extensions, and passes
@@ -528,6 +532,7 @@ declare_mlir_python_extension(MLIRPythonExtension.RegisterEverything
528532
MLIRCAPITransforms
529533
MLIRCAPIRegisterEverything
530534
GENERATE_TYPE_STUBS
535+
"_mlirRegisterEverything.pyi"
531536
)
532537

533538
declare_mlir_python_extension(MLIRPythonExtension.Dialects.Linalg.Pybind
@@ -543,6 +548,7 @@ declare_mlir_python_extension(MLIRPythonExtension.Dialects.Linalg.Pybind
543548
MLIRCAPIIR
544549
MLIRCAPILinalg
545550
GENERATE_TYPE_STUBS
551+
"_mlirDialectsLinalg.pyi"
546552
)
547553

548554
declare_mlir_python_extension(MLIRPythonExtension.Dialects.GPU.Pybind
@@ -558,6 +564,7 @@ declare_mlir_python_extension(MLIRPythonExtension.Dialects.GPU.Pybind
558564
MLIRCAPIIR
559565
MLIRCAPIGPU
560566
GENERATE_TYPE_STUBS
567+
"_mlirDialectsGPU.pyi"
561568
)
562569

563570
declare_mlir_python_extension(MLIRPythonExtension.Dialects.LLVM.Pybind
@@ -573,6 +580,7 @@ declare_mlir_python_extension(MLIRPythonExtension.Dialects.LLVM.Pybind
573580
MLIRCAPIIR
574581
MLIRCAPILLVM
575582
GENERATE_TYPE_STUBS
583+
"_mlirDialectsLLVM.pyi"
576584
)
577585

578586
declare_mlir_python_extension(MLIRPythonExtension.Dialects.Quant.Pybind
@@ -588,6 +596,7 @@ declare_mlir_python_extension(MLIRPythonExtension.Dialects.Quant.Pybind
588596
MLIRCAPIIR
589597
MLIRCAPIQuant
590598
GENERATE_TYPE_STUBS
599+
"_mlirDialectsQuant.pyi"
591600
)
592601

593602
declare_mlir_python_extension(MLIRPythonExtension.Dialects.NVGPU.Pybind
@@ -603,6 +612,7 @@ declare_mlir_python_extension(MLIRPythonExtension.Dialects.NVGPU.Pybind
603612
MLIRCAPIIR
604613
MLIRCAPINVGPU
605614
GENERATE_TYPE_STUBS
615+
"_mlirDialectsNVGPU.pyi"
606616
)
607617

608618
declare_mlir_python_extension(MLIRPythonExtension.Dialects.PDL.Pybind
@@ -618,6 +628,7 @@ declare_mlir_python_extension(MLIRPythonExtension.Dialects.PDL.Pybind
618628
MLIRCAPIIR
619629
MLIRCAPIPDL
620630
GENERATE_TYPE_STUBS
631+
"_mlirDialectsPDL.pyi"
621632
)
622633

623634
declare_mlir_python_extension(MLIRPythonExtension.Dialects.SparseTensor.Pybind
@@ -633,6 +644,7 @@ declare_mlir_python_extension(MLIRPythonExtension.Dialects.SparseTensor.Pybind
633644
MLIRCAPIIR
634645
MLIRCAPISparseTensor
635646
GENERATE_TYPE_STUBS
647+
"_mlirDialectsSparseTensor.pyi"
636648
)
637649

638650
declare_mlir_python_extension(MLIRPythonExtension.Dialects.Transform.Pybind
@@ -648,6 +660,7 @@ declare_mlir_python_extension(MLIRPythonExtension.Dialects.Transform.Pybind
648660
MLIRCAPIIR
649661
MLIRCAPITransformDialect
650662
GENERATE_TYPE_STUBS
663+
"_mlirDialectsTransform.pyi"
651664
)
652665

653666
declare_mlir_python_extension(MLIRPythonExtension.AsyncDialectPasses
@@ -662,6 +675,7 @@ declare_mlir_python_extension(MLIRPythonExtension.AsyncDialectPasses
662675
EMBED_CAPI_LINK_LIBS
663676
MLIRCAPIAsync
664677
GENERATE_TYPE_STUBS
678+
"_mlirAsyncPasses.pyi"
665679
)
666680

667681
if(MLIR_ENABLE_EXECUTION_ENGINE)
@@ -677,6 +691,7 @@ if(MLIR_ENABLE_EXECUTION_ENGINE)
677691
EMBED_CAPI_LINK_LIBS
678692
MLIRCAPIExecutionEngine
679693
GENERATE_TYPE_STUBS
694+
"_mlirExecutionEngine.pyi"
680695
)
681696
endif()
682697

@@ -692,6 +707,7 @@ declare_mlir_python_extension(MLIRPythonExtension.GPUDialectPasses
692707
EMBED_CAPI_LINK_LIBS
693708
MLIRCAPIGPU
694709
GENERATE_TYPE_STUBS
710+
"_mlirGPUPasses.pyi"
695711
)
696712

697713
declare_mlir_python_extension(MLIRPythonExtension.LinalgPasses
@@ -706,6 +722,7 @@ declare_mlir_python_extension(MLIRPythonExtension.LinalgPasses
706722
EMBED_CAPI_LINK_LIBS
707723
MLIRCAPILinalg
708724
GENERATE_TYPE_STUBS
725+
"_mlirLinalgPasses.pyi"
709726
)
710727

711728
declare_mlir_python_extension(MLIRPythonExtension.Dialects.SMT.Pybind
@@ -724,6 +741,7 @@ declare_mlir_python_extension(MLIRPythonExtension.Dialects.SMT.Pybind
724741
MLIRCAPISMT
725742
MLIRCAPIExportSMTLIB
726743
GENERATE_TYPE_STUBS
744+
"_mlirDialectsSMT.pyi"
727745
)
728746

729747
declare_mlir_python_extension(MLIRPythonExtension.SparseTensorDialectPasses
@@ -738,6 +756,7 @@ declare_mlir_python_extension(MLIRPythonExtension.SparseTensorDialectPasses
738756
EMBED_CAPI_LINK_LIBS
739757
MLIRCAPISparseTensor
740758
GENERATE_TYPE_STUBS
759+
"_mlirSparseTensorPasses.pyi"
741760
)
742761

743762
declare_mlir_python_extension(MLIRPythonExtension.TransformInterpreter
@@ -752,6 +771,7 @@ declare_mlir_python_extension(MLIRPythonExtension.TransformInterpreter
752771
EMBED_CAPI_LINK_LIBS
753772
MLIRCAPITransformDialectTransforms
754773
GENERATE_TYPE_STUBS
774+
"_mlirTransformInterpreter.pyi"
755775
)
756776

757777
# TODO: Figure out how to put this in the test tree.
@@ -811,6 +831,7 @@ if(MLIR_INCLUDE_TESTS)
811831
EMBED_CAPI_LINK_LIBS
812832
MLIRCAPIPythonTestDialect
813833
GENERATE_TYPE_STUBS
834+
"_mlirPythonTestNanobind.pyi"
814835
)
815836
endif()
816837

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
11
_mlir/**/*.pyi
2+
*.pyi

0 commit comments

Comments
 (0)