diff --git a/compiler/rustc_codegen_llvm/src/attributes.rs b/compiler/rustc_codegen_llvm/src/attributes.rs index e8c42d16733ec..378df6eb7fcad 100644 --- a/compiler/rustc_codegen_llvm/src/attributes.rs +++ b/compiler/rustc_codegen_llvm/src/attributes.rs @@ -28,6 +28,19 @@ pub(crate) fn apply_to_callsite(callsite: &Value, idx: AttributePlace, attrs: &[ } } +pub(crate) fn has_attr(llfn: &Value, place: AttributePlace, attr: AttributeKind) -> bool { + llvm::HasAttributeAtIndex(llfn, place, attr) +} + +pub(crate) fn remove_str_attr_from_llfn(llfn: &Value, place: AttributePlace, attr: &Attribute) { + llvm::RemoveStringAttributeAtIndex(llfn, place, attr) +} + +pub(crate) fn remove_from_llfn(llfn: &Value, place: AttributePlace, kind: AttributeKind) { + llvm::RemoveEnumAttributeAtIndex(llfn, place, kind); +} + + /// Get LLVM attribute for the provided inline heuristic. #[inline] fn inline_attr<'ll>(cx: &CodegenCx<'ll, '_>, inline: InlineAttr) -> Option<&'ll Attribute> { diff --git a/compiler/rustc_codegen_llvm/src/back/lto.rs b/compiler/rustc_codegen_llvm/src/back/lto.rs index a8b49e9552c30..64feacb7bb21d 100644 --- a/compiler/rustc_codegen_llvm/src/back/lto.rs +++ b/compiler/rustc_codegen_llvm/src/back/lto.rs @@ -28,8 +28,9 @@ use crate::back::write::{ use crate::errors::{ DynamicLinkingWithLTO, LlvmError, LtoBitcodeFromRlib, LtoDisallowed, LtoDylib, LtoProcMacro, }; -use crate::llvm::{self, build_string}; -use crate::{LlvmCodegenBackend, ModuleLlvm}; +use crate::llvm::AttributePlace::Function; +use crate::llvm::{self, build_string, get_value_name, LLVMDumpModule}; +use crate::{LlvmCodegenBackend, ModuleLlvm, SimpleCx, attributes}; /// We keep track of the computed LTO cache keys from the previous /// session to determine which CGUs we can reuse. @@ -584,12 +585,10 @@ fn thin_lto( } } -fn enable_autodiff_settings(ad: &[config::AutoDiff], module: &mut ModuleCodegen) { +fn enable_autodiff_settings(ad: &[config::AutoDiff]) { for &val in ad { + // We intentionally don't use a wildcard, to not forget handling anything new. match val { - config::AutoDiff::PrintModBefore => { - unsafe { llvm::LLVMDumpModule(module.module_llvm.llmod()) }; - } config::AutoDiff::PrintPerf => { llvm::set_print_perf(true); } @@ -603,17 +602,23 @@ fn enable_autodiff_settings(ad: &[config::AutoDiff], module: &mut ModuleCodegen< llvm::set_inline(true); } config::AutoDiff::LooseTypes => { - llvm::set_loose_types(false); + llvm::set_loose_types(true); } config::AutoDiff::PrintSteps => { llvm::set_print(true); } - // We handle this below + // We handle this in the PassWrapper.cpp + config::AutoDiff::PrintPasses => {} + // We handle this in the PassWrapper.cpp + config::AutoDiff::PrintModBefore => {} + // We handle this in the PassWrapper.cpp config::AutoDiff::PrintModAfter => {} - // We handle this below + // We handle this in the PassWrapper.cpp config::AutoDiff::PrintModFinal => {} // This is required and already checked config::AutoDiff::Enable => {} + // We handle this below + config::AutoDiff::NoPostopt => {} } } // This helps with handling enums for now. @@ -647,11 +652,14 @@ pub(crate) fn run_pass_manager( // We then run the llvm_optimize function a second time, to optimize the code which we generated // in the enzyme differentiation pass. let enable_ad = config.autodiff.contains(&config::AutoDiff::Enable); - let stage = - if enable_ad { write::AutodiffStage::DuringAD } else { write::AutodiffStage::PostAD }; + let stage = if thin { + write::AutodiffStage::PreAD + } else { + if enable_ad { write::AutodiffStage::DuringAD } else { write::AutodiffStage::PostAD } + }; if enable_ad { - enable_autodiff_settings(&config.autodiff, module); + enable_autodiff_settings(&config.autodiff); } unsafe { @@ -664,10 +672,49 @@ pub(crate) fn run_pass_manager( unsafe { llvm::LLVMDumpModule(module.module_llvm.llmod()) }; } + let cx = + SimpleCx::new(module.module_llvm.llmod(), &module.module_llvm.llcx, cgcx.pointer_size); + let enzyme_marker_attr = unsafe { llvm::CreateAttrString(cx.llcx, "enzyme_marker") }; + + for function in cx.get_functions() { + let c_attr_str: CString = CString::new("enzyme_marker").unwrap(); + let c_attr_str2 = c_attr_str.as_ptr(); + let has_attribute = unsafe {llvm::LLVMRustHasFnAttribute(function, c_attr_str2)}; + + if has_attribute { + dbg!("Found enzyme marker attribute"); + unsafe {LLVMDumpModule(cx.llmod())}; + // Doesn't do anything, maybe wrong Attribute Kind for C/C++ level? + attributes::remove_from_llfn(function, Function, llvm::AttributeKind::NoInline); + // This works + unsafe { + llvm::LLVMRustRemoveEnumAttributeAtIndex( + function, + !0, + llvm::AttributeKind::NoInline, + ); + } + unsafe {LLVMDumpModule(cx.llmod())}; + // This works though for string attributes + unsafe {llvm::LLVMRustRemoveFnAttribute(function, c_attr_str2)}; + unsafe {LLVMDumpModule(cx.llmod())}; + let has_attribute = unsafe {llvm::LLVMRustHasFnAttribute(function, c_attr_str2)}; + assert!(!has_attribute, "Expected function to not have 'enzyme_marker'"); + let attr = llvm::AttributeKind::AlwaysInline.create_attr(cx.llcx); + attributes::apply_to_llfn(function, Function, &[attr]); + dbg!(&function); + } else { + continue; + } + + } + let opt_stage = llvm::OptStage::FatLTO; let stage = write::AutodiffStage::PostAD; - unsafe { - write::llvm_optimize(cgcx, dcx, module, None, config, opt_level, opt_stage, stage)?; + if !config.autodiff.contains(&config::AutoDiff::NoPostopt) { + unsafe { + write::llvm_optimize(cgcx, dcx, module, None, config, opt_level, opt_stage, stage)?; + } } // This is the final IR, so people should be able to inspect the optimized autodiff output, diff --git a/compiler/rustc_codegen_llvm/src/back/write.rs b/compiler/rustc_codegen_llvm/src/back/write.rs index 18d221d232e82..c5d9f64247411 100644 --- a/compiler/rustc_codegen_llvm/src/back/write.rs +++ b/compiler/rustc_codegen_llvm/src/back/write.rs @@ -572,6 +572,9 @@ pub(crate) unsafe fn llvm_optimize( let consider_ad = cfg!(llvm_enzyme) && config.autodiff.contains(&config::AutoDiff::Enable); let run_enzyme = autodiff_stage == AutodiffStage::DuringAD; + let print_before_enzyme = config.autodiff.contains(&config::AutoDiff::PrintModBefore); + let print_after_enzyme = config.autodiff.contains(&config::AutoDiff::PrintModAfter); + let print_passes = config.autodiff.contains(&config::AutoDiff::PrintPasses); let unroll_loops; let vectorize_slp; let vectorize_loop; @@ -670,6 +673,9 @@ pub(crate) unsafe fn llvm_optimize( config.no_builtins, config.emit_lifetime_markers, run_enzyme, + print_before_enzyme, + print_after_enzyme, + print_passes, sanitizer_options.as_ref(), pgo_gen_path.as_ref().map_or(std::ptr::null(), |s| s.as_ptr()), pgo_use_path.as_ref().map_or(std::ptr::null(), |s| s.as_ptr()), diff --git a/compiler/rustc_codegen_llvm/src/builder/autodiff.rs b/compiler/rustc_codegen_llvm/src/builder/autodiff.rs index 5e7ef27143b14..3d5751f27a2b8 100644 --- a/compiler/rustc_codegen_llvm/src/builder/autodiff.rs +++ b/compiler/rustc_codegen_llvm/src/builder/autodiff.rs @@ -334,6 +334,17 @@ fn generate_enzyme_call<'ll>( let attr = llvm::AttributeKind::NoInline.create_attr(cx.llcx); attributes::apply_to_llfn(ad_fn, Function, &[attr]); + // We add a made-up attribute just such that we can recognize it after AD to update + // (no)-inline attributes. We'll then also remove this attribute. + let enzyme_marker_attr = unsafe { llvm::CreateAttrString(cx.llcx, "enzyme_marker") }; + //Some(llvm::CreateAttrStringValue(cx.llcx, "no-jump-tables", "true")) + // &llvm::ffi::Attribute + attributes::apply_to_llfn( + outer_fn, + Function, + &[enzyme_marker_attr], + ); + // first, remove all calls from fnc let entry = llvm::LLVMGetFirstBasicBlock(outer_fn); let br = llvm::LLVMRustGetTerminator(entry); diff --git a/compiler/rustc_codegen_llvm/src/context.rs b/compiler/rustc_codegen_llvm/src/context.rs index 4ec6999551898..ed50515b70716 100644 --- a/compiler/rustc_codegen_llvm/src/context.rs +++ b/compiler/rustc_codegen_llvm/src/context.rs @@ -698,6 +698,16 @@ impl<'ll, CX: Borrow>> GenericCx<'ll, CX> { llvm::LLVMMDStringInContext2(self.llcx(), name.as_ptr() as *const c_char, name.len()) }) } + + pub(crate) fn get_functions(&self) -> Vec<&'ll Value> { + let mut functions = vec![]; + let mut func = unsafe { llvm::LLVMGetFirstFunction(self.llmod()) }; + while let Some(f) = func { + functions.push(f); + func = unsafe { llvm::LLVMGetNextFunction(f) } + } + functions + } } impl<'ll, 'tcx> MiscCodegenMethods<'tcx> for CodegenCx<'ll, 'tcx> { diff --git a/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs b/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs index a9b3bdf7344be..d764522b0b152 100644 --- a/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs +++ b/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs @@ -19,6 +19,13 @@ unsafe extern "C" { pub(crate) fn LLVMRustVerifyFunction(V: &Value, action: LLVMRustVerifierFailureAction) -> Bool; pub(crate) fn LLVMRustHasAttributeAtIndex(V: &Value, i: c_uint, Kind: AttributeKind) -> bool; pub(crate) fn LLVMRustGetArrayNumElements(Ty: &Type) -> u64; + pub(crate) fn LLVMRustHasFnAttribute(F: &Value, Name: *const c_char) -> bool; + pub(crate) fn LLVMRustRemoveFnAttribute(F: &Value, Name: *const c_char); + pub(crate) fn LLVMRustRemoveEnumAttributeAtIndex( + F: &Value, + i: c_uint, + Kind: AttributeKind, + ); } unsafe extern "C" { diff --git a/compiler/rustc_codegen_llvm/src/llvm/ffi.rs b/compiler/rustc_codegen_llvm/src/llvm/ffi.rs index 9ff04f729030c..e03dbc071af69 100644 --- a/compiler/rustc_codegen_llvm/src/llvm/ffi.rs +++ b/compiler/rustc_codegen_llvm/src/llvm/ffi.rs @@ -1157,6 +1157,11 @@ unsafe extern "C" { pub(crate) fn LLVMGetNamedGlobal(M: &Module, Name: *const c_char) -> Option<&Value>; pub(crate) fn LLVMGetFirstGlobal(M: &Module) -> Option<&Value>; pub(crate) fn LLVMGetNextGlobal(GlobalVar: &Value) -> Option<&Value>; + + pub(crate) fn LLVMGetFirstFunction(M: &Module) -> Option<&Value>; + pub(crate) fn LLVMGetNextFunction(Fn: &Value) -> Option<&Value>; + pub(crate) fn LLVMRemoveEnumAttributeAtIndex(Fn: &Value, index: c_uint, kind: AttributeKind); + pub(crate) fn LLVMDeleteGlobal(GlobalVar: &Value); pub(crate) fn LLVMGetInitializer(GlobalVar: &Value) -> Option<&Value>; pub(crate) fn LLVMSetInitializer<'a>(GlobalVar: &'a Value, ConstantVal: &'a Value); @@ -1875,6 +1880,11 @@ unsafe extern "C" { Attrs: *const &'a Attribute, AttrsLen: size_t, ); + pub(crate) fn LLVMRustRemoveFunctionAttribute<'a>( + Fn: &'a Value, + index: c_uint, + Attrs: &'a Attribute, + ); // Operations on call sites pub(crate) fn LLVMRustAddCallSiteAttributes<'a>( @@ -2454,6 +2464,9 @@ unsafe extern "C" { DisableSimplifyLibCalls: bool, EmitLifetimeMarkers: bool, RunEnzyme: bool, + PrintBeforeEnzyme: bool, + PrintAfterEnzyme: bool, + PrintPasses: bool, SanitizerOptions: Option<&SanitizerOptions>, PGOGenPath: *const c_char, PGOUsePath: *const c_char, diff --git a/compiler/rustc_codegen_llvm/src/llvm/mod.rs b/compiler/rustc_codegen_llvm/src/llvm/mod.rs index 6ca81c651ed42..7aa3c5bfac497 100644 --- a/compiler/rustc_codegen_llvm/src/llvm/mod.rs +++ b/compiler/rustc_codegen_llvm/src/llvm/mod.rs @@ -41,6 +41,21 @@ pub(crate) fn AddFunctionAttributes<'ll>( } } +pub(crate) fn HasAttributeAtIndex<'ll>(llfn: &'ll Value, place: AttributePlace, kind: AttributeKind) -> bool { + unsafe { LLVMRustHasAttributeAtIndex(llfn, place.as_uint(), kind) } +} +pub(crate) fn RemoveStringAttributeAtIndex<'ll>(llfn: &Value, place: AttributePlace, attr: &'ll Attribute) { + unsafe { + LLVMRustRemoveFunctionAttribute(llfn, place.as_uint(), attr) + } +} + +pub(crate) fn RemoveEnumAttributeAtIndex(llfn: &Value, place: AttributePlace, kind: AttributeKind) { + unsafe { + LLVMRemoveEnumAttributeAtIndex(llfn, place.as_uint(), kind); + } +} + pub(crate) fn AddCallSiteAttributes<'ll>( callsite: &'ll Value, idx: AttributePlace, diff --git a/compiler/rustc_codegen_llvm/src/type_.rs b/compiler/rustc_codegen_llvm/src/type_.rs index b89ce90d1a1dc..169036f515298 100644 --- a/compiler/rustc_codegen_llvm/src/type_.rs +++ b/compiler/rustc_codegen_llvm/src/type_.rs @@ -128,6 +128,10 @@ impl<'ll, CX: Borrow>> GenericCx<'ll, CX> { (**self).borrow().llcx } + pub(crate) fn llmod(&self) -> &'ll llvm::Module { + (**self).borrow().llmod + } + pub(crate) fn isize_ty(&self) -> &'ll Type { (**self).borrow().isize_ty } diff --git a/compiler/rustc_llvm/llvm-wrapper/PassWrapper.cpp b/compiler/rustc_llvm/llvm-wrapper/PassWrapper.cpp index e02c80c50b14a..8bee051dd4c3c 100644 --- a/compiler/rustc_llvm/llvm-wrapper/PassWrapper.cpp +++ b/compiler/rustc_llvm/llvm-wrapper/PassWrapper.cpp @@ -14,6 +14,7 @@ #include "llvm/IR/LegacyPassManager.h" #include "llvm/IR/PassManager.h" #include "llvm/IR/Verifier.h" +#include "llvm/IRPrinter/IRPrintingPasses.h" #include "llvm/LTO/LTO.h" #include "llvm/MC/MCSubtargetInfo.h" #include "llvm/MC/TargetRegistry.h" @@ -703,7 +704,8 @@ extern "C" LLVMRustResult LLVMRustOptimize( bool LintIR, LLVMRustThinLTOBuffer **ThinLTOBufferRef, bool EmitThinLTO, bool EmitThinLTOSummary, bool MergeFunctions, bool UnrollLoops, bool SLPVectorize, bool LoopVectorize, bool DisableSimplifyLibCalls, - bool EmitLifetimeMarkers, bool RunEnzyme, + bool EmitLifetimeMarkers, bool RunEnzyme, bool PrintBeforeEnzyme, + bool PrintAfterEnzyme, bool PrintPasses, LLVMRustSanitizerOptions *SanitizerOptions, const char *PGOGenPath, const char *PGOUsePath, bool InstrumentCoverage, const char *InstrProfileOutput, const char *PGOSampleUsePath, @@ -1048,14 +1050,38 @@ extern "C" LLVMRustResult LLVMRustOptimize( // now load "-enzyme" pass: #ifdef ENZYME if (RunEnzyme) { - registerEnzymeAndPassPipeline(PB, true); + + if (PrintBeforeEnzyme) { + // Handle the Rust flag `-Zautodiff=PrintModBefore`. + std::string Banner = "Module before EnzymeNewPM"; + MPM.addPass(PrintModulePass(outs(), Banner, true, false)); + } + + registerEnzymeAndPassPipeline(PB, false); if (auto Err = PB.parsePassPipeline(MPM, "enzyme")) { std::string ErrMsg = toString(std::move(Err)); LLVMRustSetLastError(ErrMsg.c_str()); return LLVMRustResult::Failure; } + + if (PrintAfterEnzyme) { + // Handle the Rust flag `-Zautodiff=PrintModAfter`. + std::string Banner = "Module after EnzymeNewPM"; + MPM.addPass(PrintModulePass(outs(), Banner, true, false)); + } } #endif + if (PrintPasses) { + // Print all passes from the PM: + std::string Pipeline; + raw_string_ostream SOS(Pipeline); + MPM.printPipeline(SOS, [&PIC](StringRef ClassName) { + auto PassName = PIC.getPassNameForClassName(ClassName); + return PassName.empty() ? ClassName : PassName; + }); + outs() << Pipeline; + outs() << "\n"; + } // Upgrade all calls to old intrinsics first. for (Module::iterator I = TheModule->begin(), E = TheModule->end(); I != E;) diff --git a/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp b/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp index 5f0e4d745e833..bfe0a7ac23ff6 100644 --- a/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp +++ b/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp @@ -958,6 +958,30 @@ extern "C" void LLVMRustEraseInstUntilInclusive(LLVMBasicBlockRef bb, It->eraseFromParent(); } + +extern "C" void LLVMRustRemoveEnumAttributeAtIndex(LLVMValueRef F, size_t index, + LLVMRustAttributeKind RustAttr) { + LLVMRemoveEnumAttributeAtIndex(F, index, fromRust(RustAttr)); +} + +extern "C" void LLVMRustRemoveFnAttribute(LLVMValueRef Fn, const char *Name) { + Function *F = unwrap(Fn); + assert(F); + assert(F->hasFnAttribute(Name) && + "Function does not have the attribute to be removed"); + F->removeFnAttr(Name); + //AttributeList PAL = F->getAttributes(); + //PAL = PAL.removeParamAttribute(F->getContext(), Index, fromRust(RustAttr)); + //F->setAttributes(PAL); +} + +extern "C" bool LLVMRustHasFnAttribute(LLVMValueRef F, const char *Name) { + if (auto *Fn = dyn_cast(unwrap(F))) { + return Fn->hasFnAttribute(Name); + } + return false; +} + extern "C" bool LLVMRustHasMetadata(LLVMValueRef inst, unsigned kindID) { if (auto *I = dyn_cast(unwrap(inst))) { return I->hasMetadata(kindID); diff --git a/compiler/rustc_session/src/config.rs b/compiler/rustc_session/src/config.rs index bdd54a15147b4..e0ad562a032e1 100644 --- a/compiler/rustc_session/src/config.rs +++ b/compiler/rustc_session/src/config.rs @@ -244,6 +244,10 @@ pub enum AutoDiff { /// Print the module after running autodiff and optimizations. PrintModFinal, + /// Print all passes scheduled by LLVM + PrintPasses, + /// Disable extra opt run after running autodiff + NoPostopt, /// Enzyme's loose type debug helper (can cause incorrect gradients!!) /// Usable in cases where Enzyme errors with `can not deduce type of X`. LooseTypes, diff --git a/compiler/rustc_session/src/options.rs b/compiler/rustc_session/src/options.rs index c70f1500d3930..2531b0c9d4219 100644 --- a/compiler/rustc_session/src/options.rs +++ b/compiler/rustc_session/src/options.rs @@ -711,7 +711,7 @@ mod desc { pub(crate) const parse_list: &str = "a space-separated list of strings"; pub(crate) const parse_list_with_polarity: &str = "a comma-separated list of strings, with elements beginning with + or -"; - pub(crate) const parse_autodiff: &str = "a comma separated list of settings: `Enable`, `PrintSteps`, `PrintTA`, `PrintAA`, `PrintPerf`, `PrintModBefore`, `PrintModAfter`, `PrintModFinal`, `LooseTypes`, `Inline`"; + pub(crate) const parse_autodiff: &str = "a comma separated list of settings: `Enable`, `PrintSteps`, `PrintTA`, `PrintAA`, `PrintPerf`, `PrintModBefore`, `PrintModAfter`, `PrintModFinal`, `PrintPasses`, `NoPostopt`, `LooseTypes`, `Inline`"; pub(crate) const parse_comma_list: &str = "a comma-separated list of strings"; pub(crate) const parse_opt_comma_list: &str = parse_comma_list; pub(crate) const parse_number: &str = "a number"; @@ -1360,6 +1360,8 @@ pub mod parse { "PrintModBefore" => AutoDiff::PrintModBefore, "PrintModAfter" => AutoDiff::PrintModAfter, "PrintModFinal" => AutoDiff::PrintModFinal, + "NoPostopt" => AutoDiff::NoPostopt, + "PrintPasses" => AutoDiff::PrintPasses, "LooseTypes" => AutoDiff::LooseTypes, "Inline" => AutoDiff::Inline, _ => { @@ -2095,6 +2097,8 @@ options! { `=PrintModBefore` `=PrintModAfter` `=PrintModFinal` + `=PrintPasses`, + `=NoPostopt` `=LooseTypes` `=Inline` Multiple options can be combined with commas."), diff --git a/tests/codegen/autodiff/inline.rs b/tests/codegen/autodiff/inline.rs new file mode 100644 index 0000000000000..8e9c63c43c796 --- /dev/null +++ b/tests/codegen/autodiff/inline.rs @@ -0,0 +1,26 @@ +//@ compile-flags: -Zautodiff=Enable -C opt-level=3 -Clto=fat -Zautodiff=NoPostopt +//@ no-prefer-dynamic +//@ needs-enzyme + + +#![feature(autodiff)] + +use std::autodiff::autodiff; + +#[autodiff(d_square, Reverse, Duplicated, Active)] +fn square(x: &f64) -> f64 { + x * x +} + + +// CHECK: ; inline::d_square +// CHECK-NEXT: ; Function Attrs: alwaysinline +// CHECK-NOT: noinline +// CHECK-NEXT: define internal fastcc void @_ZN6inline8d_square17h021c74e92c259cdeE + +fn main() { + let x = std::hint::black_box(3.0); + let mut dx1 = std::hint::black_box(1.0); + let _ = d_square(&x, &mut dx1, 1.0); + assert_eq!(dx1, 6.0); +}