diff --git a/bolt/include/bolt/Core/MCPlusBuilder.h b/bolt/include/bolt/Core/MCPlusBuilder.h
index 7cf846728acebd2100313193ed0d02ab93c481ea..7d9af341b3e86abdde815ff301f0818a29147351 100644
--- a/bolt/include/bolt/Core/MCPlusBuilder.h
+++ b/bolt/include/bolt/Core/MCPlusBuilder.h
@@ -1650,6 +1650,48 @@ public:
     return true;
   }
 
+  /// Extract a symbol and an addend out of the fixup value expression.
+  ///
+  /// Only the following limited expression types are supported:
+  ///   Symbol + Addend
+  ///   Symbol + Constant + Addend
+  ///   Const + Addend
+  ///   Symbol
+  std::pair<MCSymbol *, uint64_t> extractFixupExpr(const MCFixup &Fixup) const {
+    uint64_t Addend = 0;
+    MCSymbol *Symbol = nullptr;
+    const MCExpr *ValueExpr = Fixup.getValue();
+    if (ValueExpr->getKind() == MCExpr::Binary) {
+      const auto *BinaryExpr = cast<MCBinaryExpr>(ValueExpr);
+      assert(BinaryExpr->getOpcode() == MCBinaryExpr::Add &&
+             "unexpected binary expression");
+      const MCExpr *LHS = BinaryExpr->getLHS();
+      if (LHS->getKind() == MCExpr::Constant) {
+        Addend = cast<MCConstantExpr>(LHS)->getValue();
+      } else if (LHS->getKind() == MCExpr::Binary) {
+        const auto *LHSBinaryExpr = cast<MCBinaryExpr>(LHS);
+        assert(LHSBinaryExpr->getOpcode() == MCBinaryExpr::Add &&
+               "unexpected binary expression");
+        const MCExpr *LLHS = LHSBinaryExpr->getLHS();
+        assert(LLHS->getKind() == MCExpr::SymbolRef && "unexpected LLHS");
+        Symbol = const_cast<MCSymbol *>(this->getTargetSymbol(LLHS));
+        const MCExpr *RLHS = LHSBinaryExpr->getRHS();
+        assert(RLHS->getKind() == MCExpr::Constant && "unexpected RLHS");
+        Addend = cast<MCConstantExpr>(RLHS)->getValue();
+      } else {
+        assert(LHS->getKind() == MCExpr::SymbolRef && "unexpected LHS");
+        Symbol = const_cast<MCSymbol *>(this->getTargetSymbol(LHS));
+      }
+      const MCExpr *RHS = BinaryExpr->getRHS();
+      assert(RHS->getKind() == MCExpr::Constant && "unexpected RHS");
+      Addend += cast<MCConstantExpr>(RHS)->getValue();
+    } else {
+      assert(ValueExpr->getKind() == MCExpr::SymbolRef && "unexpected value");
+      Symbol = const_cast<MCSymbol *>(this->getTargetSymbol(ValueExpr));
+    }
+    return std::make_pair(Symbol, Addend);
+  }
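+
+  // Editor's illustration (hypothetical symbol `foo`, not part of this
+  // patch): for a fixup whose value expression is `foo + 8 + 4`, the helper
+  // above returns {foo, 12}; for a bare `foo`, it returns {foo, 0}:
+  //
+  //   auto [Sym, Addend] = extractFixupExpr(Fixup);
+  //   // Sym->getName() == "foo", Addend == 12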
+
   /// Return annotation index matching the \p Name.
   Optional<unsigned> getAnnotationIndex(StringRef Name) const {
     auto AI = AnnotationNameIndexMap.find(Name);
diff --git a/bolt/lib/Core/Relocation.cpp b/bolt/lib/Core/Relocation.cpp
index f989ab1e0c47be2d0efd89d1f13d9598c0fefe58..34247f3daec32a07c6e268c59233d173498d9a56 100644
--- a/bolt/lib/Core/Relocation.cpp
+++ b/bolt/lib/Core/Relocation.cpp
@@ -273,6 +273,13 @@ uint64_t adjustValueAArch64(uint64_t Type, uint64_t Value, uint64_t PC) {
   case ELF::R_AARCH64_PREL64:
     Value -= PC;
     break;
+  case ELF::R_AARCH64_CALL26:
+    Value -= PC;
+    assert(isInt<28>(Value) && "only PC +/- 128MB is allowed for direct call");
+    // Immediate goes in bits 25:0 of BL.
+    // OP 1001_01 goes in bits 31:26 of BL.
+    Value = (Value >> 2) | 0x94000000ULL;
+    break;
   }
   return Value;
 }
diff --git a/bolt/lib/Target/AArch64/AArch64MCPlusBuilder.cpp b/bolt/lib/Target/AArch64/AArch64MCPlusBuilder.cpp
index c736196a84ca4fe92e4361fbae361b648661245f..e9b494dc3b3180b31b5eab421557b76e4ffbe71d 100644
--- a/bolt/lib/Target/AArch64/AArch64MCPlusBuilder.cpp
+++ b/bolt/lib/Target/AArch64/AArch64MCPlusBuilder.cpp
@@ -11,11 +11,13 @@
 //===----------------------------------------------------------------------===//
 
 #include "MCTargetDesc/AArch64AddressingModes.h"
+#include "MCTargetDesc/AArch64FixupKinds.h"
 #include "MCTargetDesc/AArch64MCExpr.h"
 #include "MCTargetDesc/AArch64MCTargetDesc.h"
 #include "Utils/AArch64BaseInfo.h"
 #include "bolt/Core/MCPlusBuilder.h"
 #include "llvm/BinaryFormat/ELF.h"
+#include "llvm/MC/MCFixupKindInfo.h"
 #include "llvm/MC/MCInstrInfo.h"
 #include "llvm/MC/MCRegisterInfo.h"
 #include "llvm/Support/Debug.h"
@@ -1135,6 +1137,52 @@ public:
                                ELF::R_AARCH64_ADD_ABS_LO12_NC);
     return Insts;
   }
+
+  std::optional<Relocation>
+  createRelocation(const MCFixup &Fixup,
+                   const MCAsmBackend &MAB) const override {
+    const MCFixupKindInfo &FKI = MAB.getFixupKindInfo(Fixup.getKind());
+
+    assert(FKI.TargetOffset == 0 && "0-bit relocation offset expected");
+    const uint64_t RelOffset = Fixup.getOffset();
+
+    uint64_t RelType;
+    if (Fixup.getKind() == MCFixupKind(AArch64::fixup_aarch64_pcrel_call26))
+      RelType = ELF::R_AARCH64_CALL26;
+    else if (FKI.Flags & MCFixupKindInfo::FKF_IsPCRel) {
+      switch (FKI.TargetSize) {
+      default:
+        return std::nullopt;
+      case 16:
+        RelType = ELF::R_AARCH64_PREL16;
+        break;
+      case 32:
+        RelType = ELF::R_AARCH64_PREL32;
+        break;
+      case 64:
+        RelType = ELF::R_AARCH64_PREL64;
+        break;
+      }
+    } else {
+      switch (FKI.TargetSize) {
+      default:
+        return std::nullopt;
+      case 16:
+        RelType = ELF::R_AARCH64_ABS16;
+        break;
+      case 32:
+        RelType = ELF::R_AARCH64_ABS32;
+        break;
+      case 64:
+        RelType = ELF::R_AARCH64_ABS64;
+        break;
+      }
+    }
+
+    auto [RelSymbol, RelAddend] = extractFixupExpr(Fixup);
+
+    return Relocation({RelOffset, RelSymbol, RelType, RelAddend, 0});
+  }
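+
+  // Editor's sketch (illustrative, based on the logic above): a `bl func1`
+  // instruction carries the fixup_aarch64_pcrel_call26 fixup and so maps to
+  // ELF::R_AARCH64_CALL26, while a plain PC-relative data fixup with
+  // TargetSize == 32 would map to ELF::R_AARCH64_PREL32; in both cases the
+  // symbol and addend are recovered by extractFixupExpr.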
 };
 
 } // end anonymous namespace
diff --git a/bolt/lib/Target/X86/X86MCPlusBuilder.cpp b/bolt/lib/Target/X86/X86MCPlusBuilder.cpp
index b6343aada995dd05a5c44735975de51d6f0651f0..df6d6795d8150227d0139364da33c24bae422c2a 100644
--- a/bolt/lib/Target/X86/X86MCPlusBuilder.cpp
+++ b/bolt/lib/Target/X86/X86MCPlusBuilder.cpp
@@ -2617,30 +2617,9 @@ public:
       }
     }
 
-    // Extract a symbol and an addend out of the fixup value expression.
-    //
-    // Only the following limited expression types are supported:
-    //   Symbol + Addend
-    //   Symbol
-    uint64_t Addend = 0;
-    MCSymbol *Symbol = nullptr;
-    const MCExpr *ValueExpr = Fixup.getValue();
-    if (ValueExpr->getKind() == MCExpr::Binary) {
-      const auto *BinaryExpr = cast<MCBinaryExpr>(ValueExpr);
-      assert(BinaryExpr->getOpcode() == MCBinaryExpr::Add &&
-             "unexpected binary expression");
-      const MCExpr *LHS = BinaryExpr->getLHS();
-      assert(LHS->getKind() == MCExpr::SymbolRef && "unexpected LHS");
-      Symbol = const_cast<MCSymbol *>(this->getTargetSymbol(LHS));
-      const MCExpr *RHS = BinaryExpr->getRHS();
-      assert(RHS->getKind() == MCExpr::Constant && "unexpected RHS");
-      Addend = cast<MCConstantExpr>(RHS)->getValue();
-    } else {
-      assert(ValueExpr->getKind() == MCExpr::SymbolRef && "unexpected value");
-      Symbol = const_cast<MCSymbol *>(this->getTargetSymbol(ValueExpr));
-    }
+    auto [RelSymbol, RelAddend] = extractFixupExpr(Fixup);
 
-    return Relocation({RelOffset, Symbol, RelType, Addend, 0});
+    return Relocation({RelOffset, RelSymbol, RelType, RelAddend, 0});
   }
 
   bool replaceImmWithSymbolRef(MCInst &Inst, const MCSymbol *Symbol,
diff --git a/bolt/test/AArch64/reloc-call26.s b/bolt/test/AArch64/reloc-call26.s
new file mode 100644
index 0000000000000000000000000000000000000000..834bf6f91cd959a76cb3babf4af63b58f922a9d5
--- /dev/null
+++ b/bolt/test/AArch64/reloc-call26.s
@@ -0,0 +1,29 @@
+## This test checks the processing of the R_AARCH64_CALL26 relocation
+## when the `--funcs` option is used.
+
+# REQUIRES: system-linux
+
+# RUN: llvm-mc -filetype=obj -triple aarch64-unknown-unknown \
+# RUN:   %s -o %t.o
+# RUN: %clang %cflags %t.o -o %t.exe -Wl,-q
+# RUN: llvm-bolt %t.exe -o %t.bolt --funcs=func1
+# RUN: llvm-objdump -d --disassemble-symbols='_start' %t.bolt | \
+# RUN:   FileCheck %s
+
+# CHECK: {{.*}} bl {{.*}}
+
+  .text
+  .align 4
+  .global _start
+  .type _start, %function
+_start:
+  bl func1
+  mov w8, #93
+  svc #0
+  .size _start, .-_start
+
+  .global func1
+  .type func1, %function
+func1:
+  ret
+  .size func1, .-func1
\ No newline at end of file
diff --git a/clang/lib/Driver/ToolChains/Gnu.cpp b/clang/lib/Driver/ToolChains/Gnu.cpp
index 665cdc3132fb8fc69f0a33feffdd0a82a5510810..3ee5624e85f201bdcc78e07f783895b752cc637f 100644
--- a/clang/lib/Driver/ToolChains/Gnu.cpp
+++ b/clang/lib/Driver/ToolChains/Gnu.cpp
@@ -2156,6 +2156,11 @@ void Generic_GCC::GCCInstallationDetector::AddDefaultGCCPrefixes(
       Prefixes.push_back("/opt/rh/devtoolset-2/root/usr");
   }
 
+  // The openEuler Embedded nativesdk uses this directory.
+  if (SysRoot.empty() && TargetTriple.getVendor() == llvm::Triple::OpenEmbedded &&
+      D.getVFS().exists("/opt/buildtools/nativesdk/sysroots"))
+    Prefixes.push_back("/opt/buildtools/nativesdk/sysroots/" + TargetTriple.getTriple());
+
   // Fall back to /usr which is used by most non-Solaris systems.
   Prefixes.push_back(concat(SysRoot, "/usr"));
 }
 
@@ -2201,7 +2206,7 @@ void Generic_GCC::GCCInstallationDetector::AddDefaultGCCPrefixes(
   static const char *const CSKYTriples[] = {
       "csky-linux-gnuabiv2", "csky-linux-uclibcabiv2", "csky-elf-noneabiv2"};
 
-  static const char *const X86_64LibDirs[] = {"/lib64", "/lib"};
+  static const char *const X86_64LibDirs[] = {"/lib64", "/lib", "/usr/lib"};
   static const char *const X86_64Triples[] = {
       "x86_64-linux-gnu", "x86_64-unknown-linux-gnu", "x86_64-pc-linux-gnu",
       "x86_64-redhat-linux6E",
diff --git a/clang/lib/Driver/ToolChains/Linux.cpp b/clang/lib/Driver/ToolChains/Linux.cpp
index ceb1a982c3a4ceb28575f9ffb4a962d67b77254b..d3c5ae53a3b28b14893f3541292e65df8897092a 100644
--- a/clang/lib/Driver/ToolChains/Linux.cpp
+++ b/clang/lib/Driver/ToolChains/Linux.cpp
@@ -376,7 +376,7 @@ std::string Linux::computeSysRoot() const {
     return std::string();
   }
 
-  if (!GCCInstallation.isValid() || !getTriple().isMIPS())
+  if (!GCCInstallation.isValid() || (!getTriple().isMIPS() && getTriple().getVendor() != llvm::Triple::OpenEmbedded))
     return std::string();
 
   // Standalone MIPS toolchains use different names for sysroot folder
@@ -396,6 +396,11 @@ std::string Linux::computeSysRoot() const {
 
   Path = (InstallDir + "/../../../../sysroot" + Multilib.osSuffix()).str();
 
+  if (getVFS().exists(Path))
+    return Path;
+
+  Path = (InstallDir + "/../../../../../" + TripleStr).str();
+
   if (getVFS().exists(Path))
     return Path;
 
@@ -454,7 +459,7 @@ std::string Linux::getDynamicLinker(const ArgList &Args) const {
     llvm_unreachable("unsupported architecture");
 
   case llvm::Triple::aarch64:
-    LibDir = "lib";
+    LibDir = "lib64";
     Loader = "ld-linux-aarch64.so.1";
     break;
   case llvm::Triple::aarch64_be:
@@ -545,9 +550,12 @@ std::string Linux::getDynamicLinker(const ArgList &Args) const {
     break;
   case llvm::Triple::x86_64: {
     bool X32 = Triple.isX32();
+    bool OE = (Triple.getVendor() == llvm::Triple::OpenEmbedded);
 
     LibDir = X32 ? "libx32" : "lib64";
     Loader = X32 ? "ld-linux-x32.so.2" : "ld-linux-x86-64.so.2";
"ld-linux-x32.so.2" : "ld-linux-x86-64.so.2"; + if (OE) + return "/opt/buildtools/nativesdk/sysroots/" + Triple.str() + "/lib/"+ Loader; break; } case llvm::Triple::ve: diff --git a/llvm/include/llvm/ADT/ArrayView.h b/llvm/include/llvm/ADT/ArrayView.h new file mode 100644 index 0000000000000000000000000000000000000000..574042e1e5eec4d1389ec4bc25044f1b4719c707 --- /dev/null +++ b/llvm/include/llvm/ADT/ArrayView.h @@ -0,0 +1,55 @@ +#ifndef LLVM_ADT_ARRAY_VIEW +#define LLVM_ADT_ARRAY_VIEW + +template +class ArrayView { +public: + using iterator = typename ArrayBaseType::iterator; + using reverse_iterator = typename ArrayBaseType::reverse_iterator; + using value_type = typename ArrayBaseType::value_type; + +private: + iterator Begin; + iterator End; + reverse_iterator RBegin; + reverse_iterator REnd; + size_t Size; + +public: + + ArrayView(ArrayBaseType &Arr) { + Begin=Arr.begin(); + End=Arr.end(); + RBegin=Arr.rbegin(); + REnd=Arr.rend(); + Size=End-Begin; + } + + ArrayView(iterator Begin, iterator End, + reverse_iterator RBegin, reverse_iterator REnd) + : Begin(Begin), End(End), RBegin(RBegin), REnd(REnd) { + Size = End-Begin; + } + + iterator begin() { return Begin; } + iterator end() { return End; } + reverse_iterator rbegin() { return RBegin; } + reverse_iterator rend() { return REnd; } + + size_t size() { return Size; } + + void sliceWindow(size_t StartOffset, size_t EndOffset) { + End = Begin+EndOffset; + Begin = Begin+StartOffset; + REnd = RBegin+(Size-StartOffset); + RBegin = RBegin+(Size-EndOffset); + Size = End-Begin; + } + + value_type &operator[](size_t Index) { + return *(Begin+Index); + } + +}; + +#endif diff --git a/llvm/include/llvm/ADT/SADiagonalWindows.h b/llvm/include/llvm/ADT/SADiagonalWindows.h new file mode 100644 index 0000000000000000000000000000000000000000..09e4a467000bec15ebb240a49e50da4c12d0612c --- /dev/null +++ b/llvm/include/llvm/ADT/SADiagonalWindows.h @@ -0,0 +1,74 @@ +template> +class DiagonalWindowsSA : public SequenceAligner { +private: + using BaseType = SequenceAligner; + + size_t WindowSize; + +public: + DiagonalWindowsSA(ScoringSystem Scoring, MatchFnTy Match, size_t WindowSize) : BaseType(Scoring, Match), WindowSize(WindowSize) {} + + virtual size_t getMemoryRequirement(ContainerType &Seq1, + ContainerType &Seq2) { + size_t MemorySize = sizeof(ScoreSystemType)*(WindowSize+1)*(WindowSize+1); + + if (BaseType::getMatchOperation() != nullptr) + MemorySize += WindowSize*WindowSize*sizeof(bool); + + return MemorySize; + } + + virtual AlignedSequence getAlignment(ContainerType &Seq1, ContainerType &Seq2) { + + AlignedSequence Res; + + size_t Offset1 = 0; + size_t Offset2 = 0; + + + while (Offset1 View1(Seq1); + size_t EndWindow1 = ((Offset1+WindowSize)>View1.size())?View1.size():(Offset1+WindowSize); + View1.sliceWindow(Offset1, EndWindow1); + + ArrayView< ContainerType > View2(Seq2); + size_t EndWindow2 = ((Offset2+WindowSize)>View2.size())?View2.size():(Offset2+WindowSize); + View2.sliceWindow(Offset2, EndWindow2); + + NeedlemanWunschSA, Ty, Blank, MatchFnTy> SA( + BaseType::getScoring(), + BaseType::getMatchOperation()); + + AlignedSequence NWRes = SA.getAlignment(View1, View2); + + Res.splice(NWRes); + + Offset1 = EndWindow1; + Offset2 = EndWindow2; + + //Finished Seq1 or Seq2 + if (Offset1>=Seq1.size()) { + //Copy the remaining entries from Seq2 + if (Offset2 View2(Seq2); + View2.sliceWindow(Offset2, Seq2.size()); + for (auto Char : View2) + Res.Data.push_back(typename BaseType::EntryType(Blank,Char,false)); + } + } else if (Offset2>=Seq2.size()) { + 
+
+  value_type &operator[](size_t Index) {
+    return *(Begin + Index);
+  }
+
+};
+
+#endif
diff --git a/llvm/include/llvm/ADT/SADiagonalWindows.h b/llvm/include/llvm/ADT/SADiagonalWindows.h
new file mode 100644
index 0000000000000000000000000000000000000000..09e4a467000bec15ebb240a49e50da4c12d0612c
--- /dev/null
+++ b/llvm/include/llvm/ADT/SADiagonalWindows.h
@@ -0,0 +1,74 @@
+template <typename ContainerType, typename Ty, Ty Blank,
+          typename MatchFnTy = std::function<bool(Ty, Ty)>>
+class DiagonalWindowsSA : public SequenceAligner<ContainerType, Ty, Blank, MatchFnTy> {
+private:
+  using BaseType = SequenceAligner<ContainerType, Ty, Blank, MatchFnTy>;
+
+  size_t WindowSize;
+
+public:
+  DiagonalWindowsSA(ScoringSystem Scoring, MatchFnTy Match, size_t WindowSize)
+      : BaseType(Scoring, Match), WindowSize(WindowSize) {}
+
+  virtual size_t getMemoryRequirement(ContainerType &Seq1,
+                                      ContainerType &Seq2) {
+    size_t MemorySize = sizeof(ScoreSystemType) * (WindowSize + 1) * (WindowSize + 1);
+
+    if (BaseType::getMatchOperation() != nullptr)
+      MemorySize += WindowSize * WindowSize * sizeof(bool);
+
+    return MemorySize;
+  }
+
+  virtual AlignedSequence<Ty, Blank> getAlignment(ContainerType &Seq1, ContainerType &Seq2) {
+
+    AlignedSequence<Ty, Blank> Res;
+
+    size_t Offset1 = 0;
+    size_t Offset2 = 0;
+
+    while (Offset1 < Seq1.size() || Offset2 < Seq2.size()) {
+      ArrayView<ContainerType> View1(Seq1);
+      size_t EndWindow1 = ((Offset1 + WindowSize) > View1.size()) ? View1.size() : (Offset1 + WindowSize);
+      View1.sliceWindow(Offset1, EndWindow1);
+
+      ArrayView<ContainerType> View2(Seq2);
+      size_t EndWindow2 = ((Offset2 + WindowSize) > View2.size()) ? View2.size() : (Offset2 + WindowSize);
+      View2.sliceWindow(Offset2, EndWindow2);
+
+      NeedlemanWunschSA<ArrayView<ContainerType>, Ty, Blank, MatchFnTy> SA(
+          BaseType::getScoring(), BaseType::getMatchOperation());
+
+      AlignedSequence<Ty, Blank> NWRes = SA.getAlignment(View1, View2);
+
+      Res.splice(NWRes);
+
+      Offset1 = EndWindow1;
+      Offset2 = EndWindow2;
+
+      // Finished Seq1 or Seq2
+      if (Offset1 >= Seq1.size()) {
+        // Copy the remaining entries from Seq2
+        if (Offset2 < Seq2.size()) {
+          ArrayView<ContainerType> View2(Seq2);
+          View2.sliceWindow(Offset2, Seq2.size());
+          for (auto Char : View2)
+            Res.Data.push_back(typename BaseType::EntryType(Blank, Char, false));
+        }
+      } else if (Offset2 >= Seq2.size()) {
+        // Copy the remaining entries from Seq1
+        if (Offset1 < Seq1.size()) {
+          ArrayView<ContainerType> View1(Seq1);
+          View1.sliceWindow(Offset1, Seq1.size());
+          for (auto Char : View1)
+            Res.Data.push_back(typename BaseType::EntryType(Char, Blank, false));
+        }
+      }
+    }
+
+    return Res;
+  }
+
+};
diff --git a/llvm/include/llvm/ADT/SAHirschberg.h b/llvm/include/llvm/ADT/SAHirschberg.h
new file mode 100644
index 0000000000000000000000000000000000000000..8513cfae781403d9b9cf903f5e25e97bb5fc00e9
--- /dev/null
+++ b/llvm/include/llvm/ADT/SAHirschberg.h
@@ -0,0 +1,165 @@
+template <typename ContainerType, typename Ty, Ty Blank,
+          typename MatchFnTy = std::function<bool(Ty, Ty)>>
+class HirschbergSA : public SequenceAligner<ContainerType, Ty, Blank, MatchFnTy> {
+private:
+  ScoreSystemType *FinalScore;
+  ScoreSystemType *ScoreAux;
+  ScoreSystemType *ScoreCache;
+
+  using BaseType = SequenceAligner<ContainerType, Ty, Blank, MatchFnTy>;
+
+  template <typename iterator1, typename iterator2>
+  void NWScore(iterator1 Begin1, iterator1 End1, iterator2 Begin2, iterator2 End2) {
+    const size_t SizeSeq1 = End1 - Begin1;
+    const size_t SizeSeq2 = End2 - Begin2;
+
+    ScoringSystem &Scoring = BaseType::getScoring();
+    const ScoreSystemType Gap = Scoring.getGapPenalty();
+    const ScoreSystemType Match = Scoring.getMatchProfit();
+    const bool AllowMismatch = Scoring.getAllowMismatch();
+    const ScoreSystemType Mismatch = AllowMismatch
+                                         ? Scoring.getMismatchPenalty()
+                                         : std::numeric_limits<ScoreSystemType>::min();
+
+    FinalScore[0] = 0;
+    for (size_t j = 1; j <= SizeSeq2; j++) {
+      FinalScore[j] = FinalScore[j-1] + Gap; // Ins(F2[j-1]);
+    }
+
+    if (BaseType::getMatchOperation() == nullptr) {
+      if (AllowMismatch) {
+        for (size_t i = 1; i <= SizeSeq1; i++) {
+          ScoreAux[0] = FinalScore[0] + Gap; // Del(*(Begin1+(i-1)));
+          for (size_t j = 1; j <= SizeSeq2; j++) {
+            ScoreSystemType Similarity = (*(Begin1+(i-1)) == *(Begin2+(j-1))) ? Match : Mismatch;
+            ScoreSystemType ScoreSub = FinalScore[j-1] + Similarity; // Sub(F1[i-1],F2[j-1]);
+            ScoreSystemType ScoreDel = FinalScore[j] + Gap;          // Del(F1[i-1]);
+            ScoreSystemType ScoreIns = ScoreAux[j-1] + Gap;          // Ins(F2[j-1]);
+            ScoreAux[j] = std::max(std::max(ScoreSub, ScoreDel), ScoreIns);
+          }
+          std::swap(FinalScore, ScoreAux);
+        }
+      } else {
+        for (size_t i = 1; i <= SizeSeq1; i++) {
+          ScoreAux[0] = FinalScore[0] + Gap; // Del(F1[i-1]);
+          for (size_t j = 1; j <= SizeSeq2; j++) {
+            ScoreSystemType ScoreSub = (*(Begin1+(i-1)) == *(Begin2+(j-1))) ? (FinalScore[j-1] + Match) : Mismatch;
+            ScoreSystemType ScoreDel = FinalScore[j] + Gap;          // Del(F1[i-1]);
+            ScoreSystemType ScoreIns = ScoreAux[j-1] + Gap;          // Ins(F2[j-1]);
+            ScoreAux[j] = std::max(std::max(ScoreSub, ScoreDel), ScoreIns);
+          }
+          std::swap(FinalScore, ScoreAux);
+        }
+      }
+    } else {
+      if (AllowMismatch) {
+        for (size_t i = 1; i <= SizeSeq1; i++) {
+          ScoreAux[0] = FinalScore[0] + Gap; // Del(*(Begin1+(i-1)));
+          for (size_t j = 1; j <= SizeSeq2; j++) {
+            ScoreSystemType Similarity = BaseType::match(*(Begin1+(i-1)), *(Begin2+(j-1))) ? Match : Mismatch;
+            ScoreSystemType ScoreSub = FinalScore[j-1] + Similarity; // Sub(F1[i-1],F2[j-1]);
+            ScoreSystemType ScoreDel = FinalScore[j] + Gap;          // Del(F1[i-1]);
+            ScoreSystemType ScoreIns = ScoreAux[j-1] + Gap;          // Ins(F2[j-1]);
+            ScoreAux[j] = std::max(std::max(ScoreSub, ScoreDel), ScoreIns);
+          }
+          std::swap(FinalScore, ScoreAux);
+        }
+      } else {
+        for (size_t i = 1; i <= SizeSeq1; i++) {
+          ScoreAux[0] = FinalScore[0] + Gap; // Del(F1[i-1]);
+          for (size_t j = 1; j <= SizeSeq2; j++) {
+            ScoreSystemType ScoreSub = BaseType::match(*(Begin1+(i-1)), *(Begin2+(j-1)))
+                                           ? (FinalScore[j-1] + Match)
+                                           : Mismatch;
+            ScoreSystemType ScoreDel = FinalScore[j] + Gap;          // Del(F1[i-1]);
+            ScoreSystemType ScoreIns = ScoreAux[j-1] + Gap;          // Ins(F2[j-1]);
+            ScoreAux[j] = std::max(std::max(ScoreSub, ScoreDel), ScoreIns);
+          }
+          std::swap(FinalScore, ScoreAux);
+        }
+      }
+    }
+    // last score is in FinalScore
+  }
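+
+  // Editor's note: NWScore is the linear-space half of Hirschberg's trick.
+  // It keeps only two rows (FinalScore/ScoreAux) of the Needleman-Wunsch
+  // matrix, so scoring costs O(len(Seq2)) memory instead of
+  // O(len(Seq1) * len(Seq2)); HirschbergRec below recurses on the partition
+  // point that maximizes ScoreL[i] + ScoreR[Size2 - i].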
+
+  template <typename ArrayType>
+  void HirschbergRec(ArrayType &Seq1, ArrayType &Seq2, AlignedSequence<Ty, Blank> &Res) {
+    if (Seq1.size() == 0) {
+      for (auto Char : Seq2) {
+        Res.Data.push_back(typename BaseType::EntryType(Blank, Char, false));
+      }
+    } else if (Seq2.size() == 0) {
+      for (auto Char : Seq1) {
+        Res.Data.push_back(typename BaseType::EntryType(Char, Blank, false));
+      }
+    } else if (Seq1.size() == 1 || Seq2.size() == 1) {
+      NeedlemanWunschSA<ArrayView<ContainerType>, Ty, Blank, MatchFnTy> SA(
+          BaseType::getScoring(), BaseType::getMatchOperation());
+      AlignedSequence<Ty, Blank> NWResult = SA.getAlignment(Seq1, Seq2);
+      Res.splice(NWResult);
+    } else {
+      int Seq1Mid = Seq1.size() / 2;
+
+      NWScore(Seq1.begin(), Seq1.begin() + Seq1Mid, Seq2.begin(), Seq2.end());
+      std::swap(FinalScore, ScoreCache);
+
+      ArrayType SlicedSeq1(Seq1);
+      SlicedSeq1.sliceWindow(Seq1Mid, Seq1.size());
+      NWScore(SlicedSeq1.rbegin(), SlicedSeq1.rend(), Seq2.rbegin(), Seq2.rend());
+
+      size_t Seq2Mid = 0;
+      int MaxScore = std::numeric_limits<int>::min();
+      size_t Size2 = Seq2.size();
+      for (size_t i = 0; i < Size2 + 1; i++) {
+        ScoreSystemType S = ScoreCache[i] + FinalScore[Size2 - i];
+        if (S >= MaxScore) {
+          MaxScore = S;
+          Seq2Mid = i;
+        }
+      }
+
+      ArrayType NewSeq1L(Seq1);
+      NewSeq1L.sliceWindow(0, Seq1Mid);
+      ArrayType NewSeq2L(Seq2);
+      NewSeq2L.sliceWindow(0, Seq2Mid);
+      HirschbergRec(NewSeq1L, NewSeq2L, Res);
+
+      ArrayType NewSeq1R(Seq1);
+      NewSeq1R.sliceWindow(Seq1Mid, Seq1.size());
+      ArrayType NewSeq2R(Seq2);
+      NewSeq2R.sliceWindow(Seq2Mid, Seq2.size());
+      HirschbergRec(NewSeq1R, NewSeq2R, Res);
+    }
+  }
+
+public:
+
+  HirschbergSA()
+      : BaseType(NeedlemanWunschSA<ArrayView<ContainerType>, Ty, Blank,
+                                   MatchFnTy>::getDefaultScoring(),
+                 nullptr) {}
+
+  HirschbergSA(ScoringSystem Scoring, MatchFnTy Match = nullptr)
+      : BaseType(Scoring, Match) {}
+
+  virtual size_t getMemoryRequirement(ContainerType &Seq1,
+                                      ContainerType &Seq2) {
+    size_t MemorySize = sizeof(ScoreSystemType) * (3 * (Seq2.size() + 1));
+
+    if (BaseType::getMatchOperation() != nullptr)
+      MemorySize += sizeof(bool) * (3 * (Seq2.size() + 1));
+
+    return MemorySize;
+  }
+
+  virtual AlignedSequence<Ty, Blank> getAlignment(ContainerType &Seq1, ContainerType &Seq2) {
+    AlignedSequence<Ty, Blank> Result;
+    ScoreSystemType *ScoreContainer = new ScoreSystemType[3 * (Seq2.size() + 1)];
+    FinalScore = &ScoreContainer[0];
+    ScoreAux = &ScoreContainer[Seq2.size() + 1];
+    ScoreCache = &ScoreContainer[2 * (Seq2.size() + 1)];
+    ArrayView<ContainerType> View1(Seq1);
+    ArrayView<ContainerType> View2(Seq2);
+    HirschbergRec(View1, View2, Result);
+    delete[] ScoreContainer;
+    return Result;
+  }
+
+};
diff --git a/llvm/include/llvm/ADT/SANeedlemanWunsch.h b/llvm/include/llvm/ADT/SANeedlemanWunsch.h
new file mode 100644
index 0000000000000000000000000000000000000000..2238717c79745caae84a45e96fab93f47e8ccb98
--- /dev/null
+++ b/llvm/include/llvm/ADT/SANeedlemanWunsch.h
@@ -0,0 +1,268 @@
+template <typename ContainerType, typename Ty, Ty Blank,
+          typename MatchFnTy = std::function<bool(Ty, Ty)>>
+class NeedlemanWunschSA
+    : public SequenceAligner<ContainerType, Ty, Blank, MatchFnTy> {
+private:
+  ScoreSystemType *Matrix;
+  size_t MatrixRows;
+  size_t MatrixCols;
+  bool *Matches;
+  size_t MatchesRows;
+  size_t MatchesCols;
+
+  const static unsigned END = 0;
+  const static unsigned DIAGONAL = 1;
+  const static unsigned UP = 2;
+  const static unsigned LEFT = 3;
+
+  size_t MaxRow;
+  size_t MaxCol;
+
+  using BaseType = SequenceAligner<ContainerType, Ty, Blank, MatchFnTy>;
+
+  void cacheAllMatches(ContainerType &Seq1, ContainerType &Seq2) {
+    if (BaseType::getMatchOperation() == nullptr) {
+      Matches = nullptr;
+      return;
+    }
+    const size_t SizeSeq1 = Seq1.size();
+    const size_t SizeSeq2 = Seq2.size();
+
+    MatchesRows = SizeSeq1;
+    MatchesCols = SizeSeq2;
+    Matches = new bool[SizeSeq1 * SizeSeq2];
+    for (unsigned i = 0; i < SizeSeq1; i++)
+      for (unsigned j = 0; j < SizeSeq2; j++)
+        Matches[i * SizeSeq2 + j] = BaseType::match(Seq1[i], Seq2[j]);
+  }
+
+  void computeScoreMatrix(ContainerType &Seq1, ContainerType &Seq2) {
+    const size_t SizeSeq1 = Seq1.size();
+    const size_t SizeSeq2 = Seq2.size();
+
+    const size_t NumRows = SizeSeq1 + 1;
+    const size_t NumCols = SizeSeq2 + 1;
+    Matrix = new ScoreSystemType[NumRows * NumCols];
+    MatrixRows = NumRows;
+    MatrixCols = NumCols;
+
+    ScoringSystem &Scoring = BaseType::getScoring();
+    const ScoreSystemType Gap = Scoring.getGapPenalty();
+    const ScoreSystemType Match = Scoring.getMatchProfit();
+    const bool AllowMismatch = Scoring.getAllowMismatch();
+    const ScoreSystemType Mismatch =
+        AllowMismatch ? Scoring.getMismatchPenalty()
+                      : std::numeric_limits<ScoreSystemType>::min();
+
+    for (unsigned i = 0; i < NumRows; i++)
+      Matrix[i * NumCols + 0] = i * Gap;
+    for (unsigned j = 0; j < NumCols; j++)
+      Matrix[0 * NumCols + j] = j * Gap;
+
+    ScoreSystemType MaxScore = std::numeric_limits<ScoreSystemType>::min();
+    if (Matches) {
+      if (AllowMismatch) {
+        for (unsigned i = 1; i < NumRows; i++) {
+          for (unsigned j = 1; j < NumCols; j++) {
+            ScoreSystemType Similarity =
+                Matches[(i - 1) * MatchesCols + j - 1] ? Match : Mismatch;
+            ScoreSystemType Diagonal =
+                Matrix[(i - 1) * NumCols + j - 1] + Similarity;
+            ScoreSystemType Upper = Matrix[(i - 1) * NumCols + j] + Gap;
+            ScoreSystemType Left = Matrix[i * NumCols + j - 1] + Gap;
+            ScoreSystemType Score = std::max(std::max(Diagonal, Upper), Left);
+            Matrix[i * NumCols + j] = Score;
+            if (Score >= MaxScore) {
+              MaxScore = Score;
+              MaxRow = i;
+              MaxCol = j;
+            }
+          }
+        }
+      } else {
+        for (unsigned i = 1; i < NumRows; i++) {
+          for (unsigned j = 1; j < NumCols; j++) {
+            ScoreSystemType Diagonal =
+                Matches[(i - 1) * MatchesCols + j - 1]
+                    ? (Matrix[(i - 1) * NumCols + j - 1] + Match)
+                    : Mismatch;
+            ScoreSystemType Upper = Matrix[(i - 1) * NumCols + j] + Gap;
+            ScoreSystemType Left = Matrix[i * NumCols + j - 1] + Gap;
+            ScoreSystemType Score = std::max(std::max(Diagonal, Upper), Left);
+            Matrix[i * NumCols + j] = Score;
+            if (Score >= MaxScore) {
+              MaxScore = Score;
+              MaxRow = i;
+              MaxCol = j;
+            }
+          }
+        }
+      }
+    } else {
+      if (AllowMismatch) {
+        for (unsigned i = 1; i < NumRows; i++) {
+          for (unsigned j = 1; j < NumCols; j++) {
+            ScoreSystemType Similarity =
+                (Seq1[i - 1] == Seq2[j - 1]) ? Match : Mismatch;
+            ScoreSystemType Diagonal =
+                Matrix[(i - 1) * NumCols + j - 1] + Similarity;
+            ScoreSystemType Upper = Matrix[(i - 1) * NumCols + j] + Gap;
+            ScoreSystemType Left = Matrix[i * NumCols + j - 1] + Gap;
+            ScoreSystemType Score = std::max(std::max(Diagonal, Upper), Left);
+            Matrix[i * NumCols + j] = Score;
+            if (Score >= MaxScore) {
+              MaxScore = Score;
+              MaxRow = i;
+              MaxCol = j;
+            }
+          }
+        }
+      } else {
+        for (unsigned i = 1; i < NumRows; i++) {
+          for (unsigned j = 1; j < NumCols; j++) {
+            ScoreSystemType Diagonal =
+                (Seq1[i - 1] == Seq2[j - 1])
+                    ? (Matrix[(i - 1) * NumCols + j - 1] + Match)
+                    : Mismatch;
+            ScoreSystemType Upper = Matrix[(i - 1) * NumCols + j] + Gap;
+            ScoreSystemType Left = Matrix[i * NumCols + j - 1] + Gap;
+            ScoreSystemType Score = std::max(std::max(Diagonal, Upper), Left);
+            Matrix[i * NumCols + j] = Score;
+            if (Score >= MaxScore) {
+              MaxScore = Score;
+              MaxRow = i;
+              MaxCol = j;
+            }
+          }
+        }
+      }
+    }
+  }
+
+  void buildResult(ContainerType &Seq1, ContainerType &Seq2,
+                   AlignedSequence<Ty, Blank> &Result) {
+    auto &Data = Result.Data;
+
+    ScoringSystem &Scoring = BaseType::getScoring();
+    const ScoreSystemType Gap = Scoring.getGapPenalty();
+    const ScoreSystemType Match = Scoring.getMatchProfit();
+    const bool AllowMismatch = Scoring.getAllowMismatch();
+    const ScoreSystemType Mismatch =
+        AllowMismatch ? Scoring.getMismatchPenalty()
+                      : std::numeric_limits<ScoreSystemType>::min();
+
+    int i = MatrixRows - 1, j = MatrixCols - 1;
+
+    size_t LongestMatch = 0;
+    size_t CurrentMatch = 0;
+
+    while (i > 0 || j > 0) {
+      if (i > 0 && j > 0) {
+        // Diagonal
+
+        bool IsValidMatch = false;
+
+        ScoreSystemType Score = std::numeric_limits<ScoreSystemType>::min();
+        if (Matches) {
+          IsValidMatch = Matches[(i - 1) * MatchesCols + j - 1];
+        } else {
+          IsValidMatch = (Seq1[i - 1] == Seq2[j - 1]);
+        }
+
+        if (!IsValidMatch) {
+          if (CurrentMatch > LongestMatch)
+            LongestMatch = CurrentMatch;
+          CurrentMatch = 0;
+        } else
+          CurrentMatch += 1;
+
+        if (AllowMismatch) {
+          Score = Matrix[(i - 1) * MatrixCols + j - 1] +
+                  (IsValidMatch ? Match : Mismatch);
+        } else {
+          Score = IsValidMatch ? (Matrix[(i - 1) * MatrixCols + j - 1] + Match)
+                               : Mismatch;
+        }
+
+        if (Matrix[i * MatrixCols + j] == Score) {
+          if (IsValidMatch || AllowMismatch) {
+            Data.push_front(typename BaseType::EntryType(
+                Seq1[i - 1], Seq2[j - 1], IsValidMatch));
+          } else {
+            Data.push_front(
+                typename BaseType::EntryType(Seq1[i - 1], Blank, false));
+            Data.push_front(
+                typename BaseType::EntryType(Blank, Seq2[j - 1], false));
+          }
+
+          i--;
+          j--;
+          continue;
+        }
+      }
+      if (i > 0 && Matrix[i * MatrixCols + j] ==
+                       (Matrix[(i - 1) * MatrixCols + j] + Gap)) {
+        // Up
+        Data.push_front(
+            typename BaseType::EntryType(Seq1[i - 1], Blank, false));
+        i--;
+      } else if (j > 0 && Matrix[i * MatrixCols + j] ==
+                              (Matrix[i * MatrixCols + (j - 1)] + Gap)) {
+        // Left
+        Data.push_front(
+            typename BaseType::EntryType(Blank, Seq2[j - 1], false));
+        j--;
+      }
+    }
+
+    if (CurrentMatch > LongestMatch)
+      LongestMatch = CurrentMatch;
+  }
+
+  void clearAll() {
+    if (Matrix)
+      delete[] Matrix;
+    if (Matches)
+      delete[] Matches;
+    Matrix = nullptr;
+    Matches = nullptr;
+  }
+
+public:
+  static ScoringSystem getDefaultScoring() { return ScoringSystem(-1, 2, -1); }
+
+  NeedlemanWunschSA()
+      : BaseType(getDefaultScoring(), nullptr), Matrix(nullptr),
+        Matches(nullptr) {}
+
+  NeedlemanWunschSA(ScoringSystem Scoring, MatchFnTy Match = nullptr)
+      : BaseType(Scoring, Match), Matrix(nullptr), Matches(nullptr) {}
+
+  ~NeedlemanWunschSA() { clearAll(); }
+
+  virtual size_t getMemoryRequirement(ContainerType &Seq1,
+                                      ContainerType &Seq2) override {
+    const size_t SizeSeq1 = Seq1.size();
+    const size_t SizeSeq2 = Seq2.size();
+    size_t MemorySize = 0;
+
+    MemorySize += sizeof(ScoreSystemType) * (SizeSeq1 + 1) * (SizeSeq2 + 1);
+
+    if (BaseType::getMatchOperation() != nullptr)
+      MemorySize += SizeSeq1 * SizeSeq2 * sizeof(bool);
+
+    return MemorySize;
+  }
+
+  virtual AlignedSequence<Ty, Blank> getAlignment(ContainerType &Seq1,
+                                                  ContainerType &Seq2) override {
+    AlignedSequence<Ty, Blank> Result;
+    cacheAllMatches(Seq1, Seq2);
+    computeScoreMatrix(Seq1, Seq2);
+    buildResult(Seq1, Seq2, Result);
+    clearAll();
+    return Result;
+  }
+};
diff --git a/llvm/include/llvm/ADT/SequenceAlignment.h b/llvm/include/llvm/ADT/SequenceAlignment.h
new file mode 100644
index 0000000000000000000000000000000000000000..e874c4678242b285172c75b68eb72ae50c4fac5c
--- /dev/null
+++ b/llvm/include/llvm/ADT/SequenceAlignment.h
@@ -0,0 +1,168 @@
+//===-- llvm/ADT/SequenceAlignment.h - Sequence Alignment -------*- C++ -*-===//
+//
+// The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+//
+// Provides efficient implementations of different algorithms for sequence
+// alignment.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_ADT_SEQUENCE_ALIGNMENT_H
+#define LLVM_ADT_SEQUENCE_ALIGNMENT_H
+
+#include <cassert>
+#include <functional>
+#include <limits>
+#include <list>
+#include <climits> // INT_MIN
+#include <utility>
+
+#include "llvm/ADT/ArrayView.h"
+
+#define ScoreSystemType int
+
+// Store alignment result here
+template <typename Ty, Ty Blank>
+class AlignedSequence {
+public:
+
+  class Entry {
+  private:
+    // TODO: change it to a vector for Multi-Sequence Alignment
+    std::pair<Ty, Ty> Pair;
+    bool IsMatchingPair;
+  public:
+    Entry() { IsMatchingPair = false; }
+
+    Entry(Ty V1, Ty V2) : Pair(V1, V2) { IsMatchingPair = !hasBlank(); }
+
+    Entry(Ty V1, Ty V2, bool Matching) : Pair(V1, V2), IsMatchingPair(Matching) {}
+
+    Ty get(size_t index) const {
+      assert((index == 0 || index == 1) && "Index out of bounds!");
+      if (index == 0) return Pair.first;
+      else return Pair.second;
+    }
+
+    bool empty() const { return (Pair.first == Blank && Pair.second == Blank); }
+    bool hasBlank() const { return (Pair.first == Blank || Pair.second == Blank); }
+
+    bool match() const { return IsMatchingPair; }
+    bool mismatch() const { return (!IsMatchingPair); }
+
+    Ty getNonBlank() const {
+      if (Pair.first != Blank)
+        return Pair.first;
+      else
+        return Pair.second;
+    }
+
+  };
+
+  std::list<Entry> Data;
+  size_t LargestMatch{0};
+
+  AlignedSequence() = default;
+
+  AlignedSequence(const AlignedSequence &Other)
+      : Data(Other.Data), LargestMatch(Other.LargestMatch) {}
+  AlignedSequence(AlignedSequence &&Other)
+      : Data(std::move(Other.Data)), LargestMatch(Other.LargestMatch) {}
+
+  AlignedSequence &operator=(const AlignedSequence &Other) {
+    Data = Other.Data;
+    LargestMatch = Other.LargestMatch;
+    return (*this);
+  }
+
+  void append(const AlignedSequence &Other) {
+    Data.insert(Data.end(), Other.Data.begin(), Other.Data.end());
+  }
+
+  void splice(AlignedSequence &Other) {
+    Data.splice(Data.end(), Other.Data);
+  }
+
+  typename std::list<Entry>::iterator begin() { return Data.begin(); }
+  typename std::list<Entry>::iterator end() { return Data.end(); }
+  typename std::list<Entry>::const_iterator begin() const { return Data.cbegin(); }
+  typename std::list<Entry>::const_iterator end() const { return Data.cend(); }
+
+  size_t size() { return Data.size(); }
+
+};
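+
+// Editor's illustration (char sequences, with Blank == '-'): aligning "CAT"
+// against "CT" with the default scoring can yield the entry list
+//   {('C','C',match), ('A','-',gap), ('T','T',match)},
+// and Entry::getNonBlank() recovers the non-gap character of each column.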
+
+class ScoringSystem {
+  ScoreSystemType Gap;
+  ScoreSystemType Match;
+  ScoreSystemType Mismatch;
+  bool AllowMismatch;
+public:
+  ScoringSystem(ScoreSystemType Gap, ScoreSystemType Match) {
+    this->Gap = Gap;
+    this->Match = Match;
+    this->Mismatch = std::numeric_limits<ScoreSystemType>::min();
+    this->AllowMismatch = false;
+  }
+
+  ScoringSystem(ScoreSystemType Gap, ScoreSystemType Match,
+                ScoreSystemType Mismatch, bool AllowMismatch = true) {
+    this->Gap = Gap;
+    this->Match = Match;
+    this->Mismatch = Mismatch;
+    this->AllowMismatch = AllowMismatch;
+  }
+
+  bool getAllowMismatch() {
+    return AllowMismatch;
+  }
+
+  ScoreSystemType getMismatchPenalty() {
+    return Mismatch;
+  }
+
+  ScoreSystemType getGapPenalty() {
+    return Gap;
+  }
+
+  ScoreSystemType getMatchProfit() {
+    return Match;
+  }
+};
+
+template <typename ContainerType, typename Ty, Ty Blank,
+          typename MatchFnTy = std::function<bool(Ty, Ty)>>
+class SequenceAligner {
+private:
+  ScoringSystem Scoring;
+  MatchFnTy Match;
+
+public:
+
+  using EntryType = typename AlignedSequence<Ty, Blank>::Entry;
+
+  SequenceAligner(ScoringSystem Scoring, MatchFnTy Match = nullptr)
+      : Scoring(Scoring), Match(Match) {}
+
+  virtual ~SequenceAligner() = default;
+
+  ScoringSystem &getScoring() { return Scoring; }
+
+  bool match(Ty Val1, Ty Val2) {
+    return Match(Val1, Val2);
+  }
+
+  MatchFnTy getMatchOperation() { return Match; }
+
+  Ty getBlank() { return Blank; }
+
+  virtual AlignedSequence<Ty, Blank> getAlignment(ContainerType &Seq0, ContainerType &Seq1) = 0;
+  virtual size_t getMemoryRequirement(ContainerType &Seq0, ContainerType &Seq1) = 0;
+};
+
+#include "llvm/ADT/SANeedlemanWunsch.h"
+#include "llvm/ADT/SAHirschberg.h"
+#include "llvm/ADT/SADiagonalWindows.h"
+
+#endif
diff --git a/llvm/include/llvm/IR/Attributes.inc b/llvm/include/llvm/IR/Attributes.inc
new file mode 100644
index 0000000000000000000000000000000000000000..a497a8a329973fbd307ccaac73ef59306b9f53c8
--- /dev/null
+++ b/llvm/include/llvm/IR/Attributes.inc
@@ -0,0 +1,487 @@
+------------- Classes -----------------
+class Attr<string Attr:S = ?, list<AttrProperty> Attr:P = ?> {
+  string AttrString = Attr:S;
+  list<AttrProperty> Properties = Attr:P;
+}
+class AttrProperty {
+}
+class CompatRule<string CompatRule:F = ?> {
+  string CompatFunc = CompatRule:F;
+}
+class EnumAttr<string EnumAttr:S = ?, list<AttrProperty> EnumAttr:P = ?> { // Attr
+  string AttrString = EnumAttr:S;
+  list<AttrProperty> Properties = EnumAttr:P;
+}
+class IntAttr<string IntAttr:S = ?, list<AttrProperty> IntAttr:P = ?> { // Attr
+  string AttrString = IntAttr:S;
+  list<AttrProperty> Properties = IntAttr:P;
+}
+class MergeRule<string MergeRule:F = ?> {
+  string MergeFunc = MergeRule:F;
+}
+class StrBoolAttr<string StrBoolAttr:S = ?> { // Attr
+  string AttrString = StrBoolAttr:S;
+  list<AttrProperty> Properties = [];
+}
+class TypeAttr<string TypeAttr:S = ?, list<AttrProperty> TypeAttr:P = ?> { // Attr
+  string AttrString = TypeAttr:S;
+  list<AttrProperty> Properties = TypeAttr:P;
+}
+------------- Defs -----------------
+def Alignment { // Attr IntAttr
+  string AttrString = "align";
+  list<AttrProperty> Properties = [ParamAttr, RetAttr];
+}
+def AllocAlign { // Attr EnumAttr
+  string AttrString = "allocalign";
+  list<AttrProperty> Properties = [ParamAttr];
+}
+def AllocKind { // Attr IntAttr
+  string AttrString = "allockind";
+  list<AttrProperty> Properties = [FnAttr];
+}
+def AllocSize { // Attr IntAttr
+  string AttrString = "allocsize";
+  list<AttrProperty> Properties = [FnAttr];
+}
+def AllocatedPointer { // Attr EnumAttr
+  string AttrString = "allocptr";
+  list<AttrProperty> Properties = [ParamAttr];
+}
+def AlwaysInline { // Attr EnumAttr
+  string AttrString = "alwaysinline";
+  list<AttrProperty> Properties = [FnAttr];
+}
+def ApproxFuncFPMath { // Attr StrBoolAttr
+  string AttrString = "approx-func-fp-math";
+  list<AttrProperty> Properties = [];
+}
+def ArgMemOnly { // Attr EnumAttr
+  string AttrString = "argmemonly";
+  list<AttrProperty> Properties = [FnAttr];
+}
+def Builtin { // Attr EnumAttr
+  string AttrString = "builtin";
+  list<AttrProperty> Properties = [FnAttr];
+}
+def ByRef { // Attr TypeAttr
+  string AttrString = "byref";
+  list<AttrProperty> Properties = [ParamAttr];
+}
+def ByVal { // Attr TypeAttr
+  string AttrString = "byval";
+  list<AttrProperty> Properties = [ParamAttr];
+}
+def Cold { // Attr EnumAttr
+  string AttrString = "cold";
+  list<AttrProperty> Properties = [FnAttr];
+}
+def Convergent { // Attr EnumAttr
+  string AttrString = "convergent";
+  list<AttrProperty> Properties = [FnAttr];
+}
+def Dereferenceable { // Attr IntAttr
+  string AttrString = "dereferenceable";
+  list<AttrProperty> Properties = [ParamAttr, RetAttr];
+}
+def DereferenceableOrNull { // Attr IntAttr
+  string AttrString = "dereferenceable_or_null";
+  list<AttrProperty> Properties = [ParamAttr, RetAttr];
+}
+def DisableSanitizerInstrumentation { // Attr EnumAttr
+  string AttrString = "disable_sanitizer_instrumentation";
+  list<AttrProperty> Properties = [FnAttr];
+}
+def ElementType { // Attr TypeAttr
+  string AttrString = "elementtype";
+  list<AttrProperty> Properties = [ParamAttr];
+}
+def FnAttr { // AttrProperty
+}
+def FnRetThunkExtern { // Attr EnumAttr
+  string AttrString = "fn_ret_thunk_extern";
+  list<AttrProperty> Properties = [FnAttr];
+}
+def Hot { // Attr EnumAttr
+  string AttrString = "hot";
+  list<AttrProperty> Properties = [FnAttr];
+}
+def ImmArg { // Attr EnumAttr
+  string AttrString = "immarg";
+  list<AttrProperty> Properties = [ParamAttr];
+}
+def InAlloca { // Attr TypeAttr
+  string AttrString = "inalloca";
+  list<AttrProperty> Properties = [ParamAttr];
+}
+def InReg { // Attr EnumAttr
+  string AttrString = "inreg";
+  list<AttrProperty> Properties = [ParamAttr, RetAttr];
+}
+def InaccessibleMemOnly { // Attr EnumAttr
+  string AttrString = "inaccessiblememonly";
+  list<AttrProperty> Properties = [FnAttr];
+}
+def InaccessibleMemOrArgMemOnly { // Attr EnumAttr
+  string AttrString = "inaccessiblemem_or_argmemonly";
+  list<AttrProperty> Properties = [FnAttr];
+}
+def InlineHint { // Attr EnumAttr
+  string AttrString = "inlinehint";
+  list<AttrProperty> Properties = [FnAttr];
+}
+def JumpTable { // Attr EnumAttr
+  string AttrString = "jumptable";
+  list<AttrProperty> Properties = [FnAttr];
+}
+def LessPreciseFPMAD { // Attr StrBoolAttr
+  string AttrString = "less-precise-fpmad";
+  list<AttrProperty> Properties = [];
+}
+def MinSize { // Attr EnumAttr
+  string AttrString = "minsize";
+  list<AttrProperty> Properties = [FnAttr];
+}
+def MustProgress { // Attr EnumAttr
+  string AttrString = "mustprogress";
+  list<AttrProperty> Properties = [FnAttr];
+}
+def Naked { // Attr EnumAttr
+  string AttrString = "naked";
+  list<AttrProperty> Properties = [FnAttr];
+}
+def Nest { // Attr EnumAttr
+  string AttrString = "nest";
+  list<AttrProperty> Properties = [ParamAttr];
+}
+def NoAlias { // Attr EnumAttr
+  string AttrString = "noalias";
+  list<AttrProperty> Properties = [ParamAttr, RetAttr];
+}
+def NoBuiltin { // Attr EnumAttr
+  string AttrString = "nobuiltin";
+  list<AttrProperty> Properties = [FnAttr];
+}
+def NoCallback { // Attr EnumAttr
+  string AttrString = "nocallback";
+  list<AttrProperty> Properties = [FnAttr];
+}
+def NoCapture { // Attr EnumAttr
+  string AttrString = "nocapture";
+  list<AttrProperty> Properties = [ParamAttr];
+}
+def NoCfCheck { // Attr EnumAttr
+  string AttrString = "nocf_check";
+  list<AttrProperty> Properties = [FnAttr];
+}
+def NoDuplicate { // Attr EnumAttr
+  string AttrString = "noduplicate";
+  list<AttrProperty> Properties = [FnAttr];
+}
+def NoFree { // Attr EnumAttr
+  string AttrString = "nofree";
+  list<AttrProperty> Properties = [FnAttr, ParamAttr];
+}
+def NoImplicitFloat { // Attr EnumAttr
+  string AttrString = "noimplicitfloat";
+  list<AttrProperty> Properties = [FnAttr];
+}
+def NoInfsFPMath { // Attr StrBoolAttr
+  string AttrString = "no-infs-fp-math";
+  list<AttrProperty> Properties = [];
+}
+def NoInline { // Attr EnumAttr
+  string AttrString = "noinline";
+  list<AttrProperty> Properties = [FnAttr];
+}
+def NoInlineLineTables { // Attr StrBoolAttr
+  string AttrString = "no-inline-line-tables";
+  list<AttrProperty> Properties = [];
+}
+def NoJumpTables { // Attr StrBoolAttr
+  string AttrString = "no-jump-tables";
+  list<AttrProperty> Properties = [];
+}
+def NoMerge { // Attr EnumAttr
+  string AttrString = "nomerge";
+  list<AttrProperty> Properties = [FnAttr];
+}
+def NoNansFPMath { // Attr StrBoolAttr
+  string AttrString = "no-nans-fp-math";
+  list<AttrProperty> Properties = [];
+}
+def NoProfile { // Attr EnumAttr
+  string AttrString = "noprofile";
+  list<AttrProperty> Properties = [FnAttr];
+}
+def NoRecurse { // Attr EnumAttr
string AttrString = "norecurse"; + list Properties = [FnAttr]; +} +def NoRedZone { // Attr EnumAttr + string AttrString = "noredzone"; + list Properties = [FnAttr]; +} +def NoReturn { // Attr EnumAttr + string AttrString = "noreturn"; + list Properties = [FnAttr]; +} +def NoSanitizeBounds { // Attr EnumAttr + string AttrString = "nosanitize_bounds"; + list Properties = [FnAttr]; +} +def NoSanitizeCoverage { // Attr EnumAttr + string AttrString = "nosanitize_coverage"; + list Properties = [FnAttr]; +} +def NoSignedZerosFPMath { // Attr StrBoolAttr + string AttrString = "no-signed-zeros-fp-math"; + list Properties = []; +} +def NoSync { // Attr EnumAttr + string AttrString = "nosync"; + list Properties = [FnAttr]; +} +def NoUndef { // Attr EnumAttr + string AttrString = "noundef"; + list Properties = [ParamAttr, RetAttr]; +} +def NoUnwind { // Attr EnumAttr + string AttrString = "nounwind"; + list Properties = [FnAttr]; +} +def NonLazyBind { // Attr EnumAttr + string AttrString = "nonlazybind"; + list Properties = [FnAttr]; +} +def NonNull { // Attr EnumAttr + string AttrString = "nonnull"; + list Properties = [ParamAttr, RetAttr]; +} +def NullPointerIsValid { // Attr EnumAttr + string AttrString = "null_pointer_is_valid"; + list Properties = [FnAttr]; +} +def OptForFuzzing { // Attr EnumAttr + string AttrString = "optforfuzzing"; + list Properties = [FnAttr]; +} +def OptimizeForSize { // Attr EnumAttr + string AttrString = "optsize"; + list Properties = [FnAttr]; +} +def OptimizeNone { // Attr EnumAttr + string AttrString = "optnone"; + list Properties = [FnAttr]; +} +def ParamAttr { // AttrProperty +} +def Preallocated { // Attr TypeAttr + string AttrString = "preallocated"; + list Properties = [FnAttr, ParamAttr]; +} +def PresplitCoroutine { // Attr EnumAttr + string AttrString = "presplitcoroutine"; + list Properties = [FnAttr]; +} +def ProfileSampleAccurate { // Attr StrBoolAttr + string AttrString = "profile-sample-accurate"; + list Properties = []; +} +def ReadNone { // Attr EnumAttr + string AttrString = "readnone"; + list Properties = [FnAttr, ParamAttr]; +} +def ReadOnly { // Attr EnumAttr + string AttrString = "readonly"; + list Properties = [FnAttr, ParamAttr]; +} +def RetAttr { // AttrProperty +} +def Returned { // Attr EnumAttr + string AttrString = "returned"; + list Properties = [ParamAttr]; +} +def ReturnsTwice { // Attr EnumAttr + string AttrString = "returns_twice"; + list Properties = [FnAttr]; +} +def SExt { // Attr EnumAttr + string AttrString = "signext"; + list Properties = [ParamAttr, RetAttr]; +} +def SafeStack { // Attr EnumAttr + string AttrString = "safestack"; + list Properties = [FnAttr]; +} +def SanitizeAddress { // Attr EnumAttr + string AttrString = "sanitize_address"; + list Properties = [FnAttr]; +} +def SanitizeHWAddress { // Attr EnumAttr + string AttrString = "sanitize_hwaddress"; + list Properties = [FnAttr]; +} +def SanitizeMemTag { // Attr EnumAttr + string AttrString = "sanitize_memtag"; + list Properties = [FnAttr]; +} +def SanitizeMemory { // Attr EnumAttr + string AttrString = "sanitize_memory"; + list Properties = [FnAttr]; +} +def SanitizeThread { // Attr EnumAttr + string AttrString = "sanitize_thread"; + list Properties = [FnAttr]; +} +def ShadowCallStack { // Attr EnumAttr + string AttrString = "shadowcallstack"; + list Properties = [FnAttr]; +} +def Speculatable { // Attr EnumAttr + string AttrString = "speculatable"; + list Properties = [FnAttr]; +} +def SpeculativeLoadHardening { // Attr EnumAttr + string AttrString = 
"speculative_load_hardening"; + list Properties = [FnAttr]; +} +def StackAlignment { // Attr IntAttr + string AttrString = "alignstack"; + list Properties = [FnAttr, ParamAttr]; +} +def StackProtect { // Attr EnumAttr + string AttrString = "ssp"; + list Properties = [FnAttr]; +} +def StackProtectReq { // Attr EnumAttr + string AttrString = "sspreq"; + list Properties = [FnAttr]; +} +def StackProtectStrong { // Attr EnumAttr + string AttrString = "sspstrong"; + list Properties = [FnAttr]; +} +def StrictFP { // Attr EnumAttr + string AttrString = "strictfp"; + list Properties = [FnAttr]; +} +def StructRet { // Attr TypeAttr + string AttrString = "sret"; + list Properties = [ParamAttr]; +} +def SwiftAsync { // Attr EnumAttr + string AttrString = "swiftasync"; + list Properties = [ParamAttr]; +} +def SwiftError { // Attr EnumAttr + string AttrString = "swifterror"; + list Properties = [ParamAttr]; +} +def SwiftSelf { // Attr EnumAttr + string AttrString = "swiftself"; + list Properties = [ParamAttr]; +} +def UWTable { // Attr IntAttr + string AttrString = "uwtable"; + list Properties = [FnAttr]; +} +def UnsafeFPMath { // Attr StrBoolAttr + string AttrString = "unsafe-fp-math"; + list Properties = []; +} +def UseSampleProfile { // Attr StrBoolAttr + string AttrString = "use-sample-profile"; + list Properties = []; +} +def VScaleRange { // Attr IntAttr + string AttrString = "vscale_range"; + list Properties = [FnAttr]; +} +def WillReturn { // Attr EnumAttr + string AttrString = "willreturn"; + list Properties = [FnAttr]; +} +def WriteOnly { // Attr EnumAttr + string AttrString = "writeonly"; + list Properties = [FnAttr, ParamAttr]; +} +def ZExt { // Attr EnumAttr + string AttrString = "zeroext"; + list Properties = [ParamAttr, RetAttr]; +} +def anonymous_0 { // CompatRule + string CompatFunc = "isEqual"; +} +def anonymous_1 { // CompatRule + string CompatFunc = "isEqual"; +} +def anonymous_10 { // MergeRule + string MergeFunc = "setAND"; +} +def anonymous_11 { // MergeRule + string MergeFunc = "setAND"; +} +def anonymous_12 { // MergeRule + string MergeFunc = "setAND"; +} +def anonymous_13 { // MergeRule + string MergeFunc = "setAND"; +} +def anonymous_14 { // MergeRule + string MergeFunc = "setAND"; +} +def anonymous_15 { // MergeRule + string MergeFunc = "setOR"; +} +def anonymous_16 { // MergeRule + string MergeFunc = "setOR"; +} +def anonymous_17 { // MergeRule + string MergeFunc = "setOR"; +} +def anonymous_18 { // MergeRule + string MergeFunc = "setOR"; +} +def anonymous_19 { // MergeRule + string MergeFunc = "adjustCallerSSPLevel"; +} +def anonymous_2 { // CompatRule + string CompatFunc = "isEqual"; +} +def anonymous_20 { // MergeRule + string MergeFunc = "adjustCallerStackProbes"; +} +def anonymous_21 { // MergeRule + string MergeFunc = "adjustCallerStackProbeSize"; +} +def anonymous_22 { // MergeRule + string MergeFunc = "adjustMinLegalVectorWidth"; +} +def anonymous_23 { // MergeRule + string MergeFunc = "adjustNullPointerValidAttr"; +} +def anonymous_24 { // MergeRule + string MergeFunc = "setAND"; +} +def anonymous_3 { // CompatRule + string CompatFunc = "isEqual"; +} +def anonymous_4 { // CompatRule + string CompatFunc = "isEqual"; +} +def anonymous_5 { // CompatRule + string CompatFunc = "isEqual"; +} +def anonymous_6 { // CompatRule + string CompatFunc = "isEqual"; +} +def anonymous_7 { // CompatRule + string CompatFunc = "isEqual"; +} +def anonymous_8 { // CompatRule + string CompatFunc = "isEqual"; +} +def anonymous_9 { // MergeRule + string MergeFunc = "setAND"; +} diff --git 
index 7945c64c86103f39dff8356a2636ff845a17cd01..7f4395f94f53b6adbd207aade1cbb63239c41985 100644
--- a/llvm/include/llvm/IR/Function.h
+++ b/llvm/include/llvm/IR/Function.h
@@ -662,7 +662,9 @@ public:
 
   /// Optimize this function for size (-Os) or minimum size (-Oz).
   bool hasOptSize() const {
-    return hasFnAttribute(Attribute::OptimizeForSize) || hasMinSize();
+    // For size: treat every function as if it were optimizing for size.
+    return true;
+    // return hasFnAttribute(Attribute::OptimizeForSize) || hasMinSize();
   }
 
   /// Returns the denormal handling type for the default rounding mode of the
diff --git a/llvm/include/llvm/Transforms/IPO/FunctionMerging.h b/llvm/include/llvm/Transforms/IPO/FunctionMerging.h
new file mode 100644
index 0000000000000000000000000000000000000000..62ebbe7ed496cbbb55288245424cb531f113dd2a
--- /dev/null
+++ b/llvm/include/llvm/Transforms/IPO/FunctionMerging.h
@@ -0,0 +1,425 @@
+//===- FunctionMerging.h - A function merging pass ----------------------===//
+//
+// The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the general function merging optimization.
+//
+// It identifies similarities between functions and, if profitable, merges
+// them into a single function, replacing the originals. Functions do not
+// need to be identical to be merged. In fact, there are very few
+// restrictions on merging two functions; however, the merged function can
+// be larger than the two original functions together. For that reason, the
+// pass uses the TargetTransformInfo analysis to estimate the code-size
+// costs of instructions and, from them, the profitability of merging two
+// functions.
+//
+// This function merging transformation has three major parts:
+// 1. The input functions are linearized, representing their CFGs as
+//    sequences of labels and instructions.
+// 2. We apply a sequence alignment algorithm, namely the Needleman-Wunsch
+//    algorithm, to identify similar code between the two linearized
+//    functions.
+// 3. We use the aligned sequences to perform code generation, producing the
+//    new merged function, using an extra parameter to represent the
+//    function identifier.
+//
+// This pass integrates the function merging transformation with an
+// exploration framework. For every function, the other functions are ranked
+// based on their degree of similarity, which is computed from the
+// functions' fingerprints. Only the top candidates are analyzed in a greedy
+// manner, and if one of them produces a profitable result, the merged
+// function is taken.
+//
+//===----------------------------------------------------------------------===//
+//
+// This optimization was proposed in
+//
+// Function Merging by Sequence Alignment: An Interprocedural Code-Size
+// Optimization
+// Rodrigo C. O. Rocha, Pavlos Petoumenos, Zheng Wang, Murray Cole,
+// Hugh Leather
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_TRANSFORMS_IPO_FUNCTIONMERGING_F3M_H
+#define LLVM_TRANSFORMS_IPO_FUNCTIONMERGING_F3M_H
+
+#include "llvm/ADT/SequenceAlignment.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/StringSet.h"
+
+#include "llvm/Analysis/BlockFrequencyInfo.h"
+#include "llvm/Analysis/ProfileSummaryInfo.h"
+#include "llvm/Analysis/TargetTransformInfo.h"
+
+#include "llvm/InitializePasses.h"
+
+#include "llvm/IR/Function.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IR/PassManager.h"
+
+#include "llvm/Transforms/Utils/Cloning.h"
+#include "llvm/Transforms/IPO/SearchStrategy.h"
+
+#include <map>
+#include <vector>
+
+namespace llvm {
+
+/// A set of parameters used to control the transforms by MergeFunctions.
+struct FunctionMergingOptions {
+  bool MaximizeParamScore;
+  bool IdenticalTypesOnly;
+  bool EnableUnifiedReturnType;
+
+  FunctionMergingOptions(bool MaximizeParamScore = true,
+                         bool IdenticalTypesOnly = true,
+                         bool EnableUnifiedReturnType = true)
+      : MaximizeParamScore(MaximizeParamScore),
+        IdenticalTypesOnly(IdenticalTypesOnly),
+        EnableUnifiedReturnType(EnableUnifiedReturnType) {}
+
+  FunctionMergingOptions &maximizeParameterScore(bool MPS) {
+    MaximizeParamScore = MPS;
+    return *this;
+  }
+
+  FunctionMergingOptions &matchOnlyIdenticalTypes(bool IT) {
+    IdenticalTypesOnly = IT;
+    return *this;
+  }
+
+  FunctionMergingOptions &enableUnifiedReturnTypes(bool URT) {
+    EnableUnifiedReturnType = URT;
+    return *this;
+  }
+};
+
+class AlignedCode : public AlignedSequence<Value *, nullptr> {
+public:
+  int Insts{0};
+  int Matches{0};
+  int CoreMatches{0};
+
+  AlignedCode() = default;
+
+  AlignedCode(const AlignedCode &Other)
+      : AlignedSequence<Value *, nullptr>(Other), Insts{Other.Insts},
+        Matches{Other.Matches}, CoreMatches{Other.CoreMatches} {}
+
+  AlignedCode(AlignedCode &&Other)
+      : AlignedSequence<Value *, nullptr>(Other), Insts{Other.Insts},
+        Matches{Other.Matches}, CoreMatches{Other.CoreMatches} {}
+
+  AlignedCode(const AlignedSequence<Value *, nullptr> &Other)
+      : AlignedSequence<Value *, nullptr>(Other) {}
+
+  AlignedCode(AlignedSequence<Value *, nullptr> &&Other)
+      : AlignedSequence<Value *, nullptr>(Other) {}
+
+  AlignedCode(BasicBlock *B1, BasicBlock *B2);
+
+  AlignedCode &operator=(const AlignedCode &Other) {
+    Data = Other.Data;
+    LargestMatch = Other.LargestMatch;
+    Insts = Other.Insts;
+    Matches = Other.Matches;
+    CoreMatches = Other.CoreMatches;
+    return (*this);
+  }
+
+  void extend(const AlignedCode &Other);
+  void extend(int index, const BasicBlock *BB);
+
+  bool hasMatches() const { return (Matches == Insts) || (CoreMatches > 0); };
+  bool isProfitable() const;
+};
+
+class FunctionMergeResult {
+private:
+  Function *F1;
+  Function *F2;
+  Function *MergedFunction;
+  bool HasIdArg;
+  bool NeedUnifiedReturn;
+  std::map<unsigned, unsigned> ParamMap1;
+  std::map<unsigned, unsigned> ParamMap2;
+
+  FunctionMergeResult()
+      : F1(nullptr), F2(nullptr), MergedFunction(nullptr), HasIdArg(false),
+        NeedUnifiedReturn(false) {}
+
+public:
+  // feise: to check if the function is successfully merged
+  bool Success = true;
+
+  FunctionMergeResult(bool success)
+      : F1(nullptr), F2(nullptr), MergedFunction(nullptr), HasIdArg(false),
+        NeedUnifiedReturn(false), Success(success) {}
+
+  FunctionMergeResult(Function *F1, Function *F2, Function *MergedFunction,
+                      bool NeedUnifiedReturn = false)
+      : F1(F1), F2(F2), MergedFunction(MergedFunction), HasIdArg(true),
+        NeedUnifiedReturn(NeedUnifiedReturn) {}
+
+  std::pair<Function *, Function *> getFunctions() {
+    return std::pair<Function *, Function *>(F1, F2);
+  }
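+
+  // Editor's sketch of intended use (illustrative; MergeFunctions is the
+  // free function declared at the end of this header):
+  //   FunctionMergeResult MFR = MergeFunctions(F1, F2);
+  //   if (MFR) { // operator bool: merge succeeded
+  //     Value *Id1 = MFR.getFunctionIdValue(F1); // i1 true selects F1's path
+  //     Value *Id2 = MFR.getFunctionIdValue(F2); // i1 false selects F2's path
+  //   }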
&getArgumentMapping(Function *F) { + return (F1 == F) ? ParamMap1 : ParamMap2; + } + + Value *getFunctionIdValue(Function *F) { + if (F == F1) + return ConstantInt::getTrue(IntegerType::get(F1->getContext(), 1)); + else if (F == F2) + return ConstantInt::getFalse(IntegerType::get(F2->getContext(), 1)); + else + return nullptr; + } + + void setFunctionIdArgument(bool HasFuncIdArg) { HasIdArg = HasFuncIdArg; } + + bool hasFunctionIdArgument() { return HasIdArg; } + + void setUnifiedReturn(bool NeedUnifiedReturn) { + this->NeedUnifiedReturn = NeedUnifiedReturn; + } + + bool needUnifiedReturn() { return NeedUnifiedReturn; } + + // returns whether or not the merge operation was successful + operator bool() const { return (MergedFunction != nullptr); } + + void setArgumentMapping(Function *F, std::map &ParamMap) { + if (F == F1) + ParamMap1 = ParamMap; + else if (F == F2) + ParamMap2 = ParamMap; + } + + void addArgumentMapping(Function *F, unsigned SrcArg, unsigned DstArg) { + if (F == F1) + ParamMap1[SrcArg] = DstArg; + else if (F == F2) + ParamMap2[SrcArg] = DstArg; + } + + Function *getMergedFunction() { return MergedFunction; } + + // static const FunctionMergeResult Error; +}; + +class FunctionMerger { +private: + Module *M; + + // ProfileSummaryInfo *PSI; + function_ref LookupBFI; + + Type *IntPtrTy; + + const DataLayout *DL; + LLVMContext *ContextPtr; + + // cache of linear functions + // KeyValueCache> LFCache; + + // statistics for analyzing this optimization for future improvements + // unsigned LastMaxParamScore = 0; + // unsigned TotalParamScore = 0; + // int CountOpReorder = 0; + // int CountBinOps = 0; + + enum LinearizationKind { LK_Random, LK_Canonical }; + + void linearize(Function *F, SmallVectorImpl &FVec, + LinearizationKind LK = LinearizationKind::LK_Canonical); + + void replaceByCall(Function *F, FunctionMergeResult &MergedFunc, + const FunctionMergingOptions &Options = {}); + bool replaceCallsWith(Function *F, FunctionMergeResult &MergedFunc, + const FunctionMergingOptions &Options = {}); + + void updateCallGraph(Function *F, FunctionMergeResult &MFR, + StringSet<> &AlwaysPreserved, + const FunctionMergingOptions &Options); + +public: + FunctionMerger(Module *M) : M(M), IntPtrTy(nullptr) { + //, ProfileSummaryInfo *PSI=nullptr, function_ref LookupBFI=nullptr) : M(M), PSI(PSI), LookupBFI(LookupBFI), + // IntPtrTy(nullptr) { + if (M) { + DL = &M->getDataLayout(); + ContextPtr = &M->getContext(); + IntPtrTy = DL->getIntPtrType(*ContextPtr); + } + } + + bool validMergeTypes(Function *F1, Function *F2, + const FunctionMergingOptions &Options = {}); + + static bool areTypesEquivalent(Type *Ty1, Type *Ty2, const DataLayout *DL, + const FunctionMergingOptions &Options = {}); + + + static bool match(Value *V1, Value *V2); + static bool matchInstructions(Instruction *I1, Instruction *I2, + const FunctionMergingOptions &Options = {}); + static bool matchWholeBlocks(Value *V1, Value *V2); + static bool matchBlocks(BasicBlock *B1, BasicBlock *B2); + + void updateCallGraph(FunctionMergeResult &Result, + StringSet<> &AlwaysPreserved, + const FunctionMergingOptions &Options = {}); + + FunctionMergeResult merge(Function *F1, Function *F2, std::string Name = "", + const FunctionMergingOptions &Options = {}); + + class CodeGenerator { + private: + LLVMContext *ContextPtr; + Type *IntPtrTy; + + Value *IsFunc1; + + std::vector Blocks1; + std::vector Blocks2; + + BasicBlock *EntryBB1; + BasicBlock *EntryBB2; + BasicBlock *PreBB; + + Type *RetType1; + Type *RetType2; + Type *ReturnType; + + bool 
+
+    CodeGenerator &setContext(LLVMContext *ContextPtr) {
+      this->ContextPtr = ContextPtr;
+      return *this;
+    }
+
+    CodeGenerator &setIntPtrType(Type *IntPtrTy) {
+      this->IntPtrTy = IntPtrTy;
+      return *this;
+    }
+
+    CodeGenerator &setFunctionIdentifier(Value *IsFunc1) {
+      this->IsFunc1 = IsFunc1;
+      return *this;
+    }
+
+    CodeGenerator &setEntryPoints(BasicBlock *EntryBB1, BasicBlock *EntryBB2) {
+      this->EntryBB1 = EntryBB1;
+      this->EntryBB2 = EntryBB2;
+      return *this;
+    }
+
+    CodeGenerator &setReturnTypes(Type *RetType1, Type *RetType2) {
+      this->RetType1 = RetType1;
+      this->RetType2 = RetType2;
+      return *this;
+    }
+
+    CodeGenerator &setMergedEntryPoint(BasicBlock *PreBB) {
+      this->PreBB = PreBB;
+      return *this;
+    }
+
+    CodeGenerator &setMergedReturnType(Type *ReturnType,
+                                       bool RequiresUnifiedReturn = false) {
+      this->ReturnType = ReturnType;
+      this->RequiresUnifiedReturn = RequiresUnifiedReturn;
+      return *this;
+    }
+
+    CodeGenerator &setMergedFunction(Function *MergedFunc) {
+      this->MergedFunc = MergedFunc;
+      return *this;
+    }
+
+    Function *getMergedFunction() { return MergedFunc; }
+    Type *getMergedReturnType() { return ReturnType; }
+    bool getRequiresUnifiedReturn() { return RequiresUnifiedReturn; }
+
+    Value *getFunctionIdentifier() { return IsFunc1; }
+
+    LLVMContext &getContext() { return *ContextPtr; }
+
+    std::vector<BasicBlock *> &getBlocks1() { return Blocks1; }
+    std::vector<BasicBlock *> &getBlocks2() { return Blocks2; }
+
+    BasicBlock *getEntryBlock1() { return EntryBB1; }
+    BasicBlock *getEntryBlock2() { return EntryBB2; }
+    BasicBlock *getPreBlock() { return PreBB; }
+
+    Type *getReturnType1() { return RetType1; }
+    Type *getReturnType2() { return RetType2; }
+
+    Type *getIntPtrType() { return IntPtrTy; }
+
+    void insert(BasicBlock *BB) { CreatedBBs.insert(BB); }
+    void insert(Instruction *I) { CreatedInsts.insert(I); }
+
+    void erase(BasicBlock *BB) { CreatedBBs.erase(BB); }
+    void erase(Instruction *I) { CreatedInsts.erase(I); }
+
+    virtual bool generate(AlignedCode &AlignedSeq,
+                          ValueToValueMapTy &VMap,
+                          const FunctionMergingOptions &Options = {}) = 0;
+
+    void destroyGeneratedCode();
+
+    SmallPtrSet<Instruction *, 8>::const_iterator begin() const {
+      return CreatedInsts.begin();
+    }
+    SmallPtrSet<Instruction *, 8>::const_iterator end() const {
+      return CreatedInsts.end();
+    }
+  };
+
+  class SALSSACodeGen : public FunctionMerger::CodeGenerator {
+
+  public:
+    SALSSACodeGen(Function *F1, Function *F2) : CodeGenerator(F1, F2) {}
+    virtual ~SALSSACodeGen() {}
+    virtual bool generate(AlignedCode &AlignedSeq,
+                          ValueToValueMapTy &VMap,
+                          const FunctionMergingOptions &Options = {}) override;
+  };
+};
+
+FunctionMergeResult MergeFunctions(Function *F1, Function *F2,
+                                   const FunctionMergingOptions &Options = {});
+
+class FunctionMergingPass : public PassInfoMixin<FunctionMergingPass> {
+public:
+  PreservedAnalyses run(Module &M, ModuleAnalysisManager &AM);
+};
+
+} // namespace llvm
+
+#endif
diff --git a/llvm/include/llvm/Transforms/IPO/SearchStrategy.h b/llvm/include/llvm/Transforms/IPO/SearchStrategy.h
new file mode 100644
index 0000000000000000000000000000000000000000..4bdedbb6414c463d75f39972f6285cef0bf8b540
--- /dev/null
b/llvm/include/llvm/Transforms/IPO/SearchStrategy.h
@@ -0,0 +1,196 @@
+#include <algorithm>
+#include <cstdint>
+#include <limits>
+#include <random>
+#include <unordered_set>
+#include <vector>
+
+class SearchStrategy {
+private:
+
+  // Default values
+  const size_t nHashes{200};
+  const size_t rows{2};
+  const size_t bands{100};
+  std::vector<uint32_t> randomHashFuncs;
+
+public:
+  SearchStrategy() = default;
+
+  SearchStrategy(size_t rows, size_t bands)
+      : nHashes(rows * bands), rows(rows), bands(bands) {
+    updateRandomHashFunctions(nHashes - 1);
+  }
+
+  // 32-bit FNV-1a over a sequence of shingle hashes.
+  uint32_t fnv1a(const std::vector<uint32_t> &Seq) {
+    uint32_t hash = 2166136261u; // FNV offset basis
+    int len = Seq.size();
+
+    for (int i = 0; i < len; i++) {
+      hash ^= Seq[i];
+      hash *= 16777619u; // 32-bit FNV prime
+    }
+
+    return hash;
+  }
+
+  // Same as above, but seeded with a previous hash so hashes can be chained.
+  uint32_t fnv1a(const std::vector<uint32_t> &Seq, uint32_t newHash) {
+    uint32_t hash = newHash;
+    int len = Seq.size();
+
+    for (int i = 0; i < len; i++) {
+      hash ^= Seq[i];
+      hash *= 16777619u; // 32-bit FNV prime
+    }
+
+    return hash;
+  }
+
+  // Generate shingles using a single hash -- unused as not effective for
+  // function merging.
+  template <size_t K>
+  std::vector<uint32_t> &
+  generateShinglesSingleHashPipelineTurbo(const std::vector<uint32_t> &Seq,
+                                          std::vector<uint32_t> &ret) {
+    uint32_t pipeline[K] = {0};
+    int len = Seq.size();
+
+    ret.resize(nHashes);
+
+    std::unordered_set<uint32_t> set;
+    // set.reserve(nHashes);
+    uint32_t last = 0;
+
+    for (int i = 0; i < len; i++) {
+
+      for (int k = 0; k < K; k++) {
+        pipeline[k] ^= Seq[i];
+        pipeline[k] *= 16777619u;
+      }
+
+      // Collect the head of the pipeline until ret holds nHashes values, then
+      // keep ret sorted in ascending order.
+      if (last <= nHashes - 1) {
+        ret[last++] = pipeline[0];
+
+        if (last > nHashes - 1)
+          std::sort(ret.begin(), ret.end());
+      }
+
+      // Once full, replace the largest kept hash whenever a smaller, unseen
+      // hash comes out of the pipeline.
+      if (last > nHashes - 1 && pipeline[0] < ret.back()) {
+        if (set.find(pipeline[0]) == set.end()) {
+          set.insert(pipeline[0]);
+
+          ret.back() = pipeline[0];
+
+          std::sort(ret.begin(), ret.end());
+        }
+      }
+
+      // Shift pipeline
+      for (int k = 0; k < K - 1; k++) {
+        pipeline[k] = pipeline[k + 1];
+      }
+      pipeline[K - 1] = 2166136261u;
+    }
+
+    return ret;
+  }
+
+  // Generate MinHash fingerprint with multiple hash functions
+  template <size_t K>
+  std::vector<uint32_t> &
+  generateShinglesMultipleHashPipelineTurbo(const std::vector<uint32_t> &Seq,
+                                            std::vector<uint32_t> &ret) {
+    uint32_t pipeline[K] = {0};
+    uint32_t len = Seq.size();
+
+    uint32_t smallest = std::numeric_limits<uint32_t>::max();
+
+    std::vector<uint32_t> shingleHashes(len);
+
+    ret.resize(nHashes);
+
+    // Pipeline to hash all shingles using fnv1a, storing every shingle hash
+    // while tracking the smallest one. Then, for each of the nHashes-1 extra
+    // hash functions, rehash each shingle with an XOR of a 32-bit random
+    // number and keep the smallest, quickly obtaining nHashes min-hashes in
+    // total. Sort the hashes at the end.
+
+    for (uint32_t i = 0; i < len; i++) {
+      for (uint32_t k = 0; k < K; k++) {
+        pipeline[k] ^= Seq[i];
+        pipeline[k] *= 16777619u;
+      }
+
+      // Collect head of pipeline
+      if (pipeline[0] < smallest)
+        smallest = pipeline[0];
+      shingleHashes[i] = pipeline[0];
+
+      // Shift pipeline
+      for (uint32_t k = 0; k < K - 1; k++)
+        pipeline[k] = pipeline[k + 1];
+      pipeline[K - 1] = 2166136261u;
+    }
+
+    ret[0] = smallest;
+
+    // Now for each hash function, rehash each shingle and store the smallest
+    // each time.
+    for (uint32_t i = 0; i < randomHashFuncs.size(); i++) {
+      smallest = std::numeric_limits<uint32_t>::max();
+
+      for (uint32_t j = 0; j < shingleHashes.size(); j++) {
+        uint32_t temp = shingleHashes[j] ^ randomHashFuncs[i];
+
+        if (temp < smallest)
+          smallest = temp;
+      }
+
+      ret[i + 1] = smallest;
+    }
+
+    std::sort(ret.begin(), ret.end());
+
+    return ret;
+  }
+
+  void updateRandomHashFunctions(size_t num) {
+    size_t old_num = randomHashFuncs.size();
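+    // std::vector::resize keeps any values already generated, so fingerprints
+    // computed with the old mixers stay comparable; only the new slots
+    // [old_num, num) get fresh random numbers below.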
+    randomHashFuncs.resize(num);
+
+    // if we shrunk the vector, there is nothing more to do
+    if (num <= old_num)
+      return;
+
+    // If we enlarged it, we need to generate new random numbers
+    // std::random_device rd;
+    // std::mt19937 gen(rd());
+    std::mt19937 gen(0);
+    std::uniform_int_distribution<uint32_t> distribution(
+        0, std::numeric_limits<uint32_t>::max());
+
+    // generating a random integer:
+    for (size_t i = old_num; i < num; i++)
+      randomHashFuncs[i] = distribution(gen);
+  }
+
+  std::vector<uint32_t> &generateBands(const std::vector<uint32_t> &minHashes,
+                                       std::vector<uint32_t> &LSHBands) {
+    LSHBands.resize(bands);
+
+    // Generate a hash for each band
+    for (size_t i = 0; i < bands; i++) {
+      // Perform fnv1a on the rows
+      auto first = minHashes.begin() + (i * rows);
+      auto last = minHashes.begin() + (i * rows) + rows;
+      LSHBands[i] = fnv1a(std::vector<uint32_t>{first, last});
+    }
+
+    // Remove duplicate bands -- no need to place twice in the same bucket
+    std::sort(LSHBands.begin(), LSHBands.end());
+    auto last = std::unique(LSHBands.begin(), LSHBands.end());
+    LSHBands.erase(last, LSHBands.end());
+
+    return LSHBands;
+  }
+
+  uint32_t item_footprint() { return sizeof(uint32_t) * bands * (rows + 1); }
+};
diff --git a/llvm/include/llvm/Transforms/IPO/tsl/robin_growth_policy.h b/llvm/include/llvm/Transforms/IPO/tsl/robin_growth_policy.h
new file mode 100644
index 0000000000000000000000000000000000000000..cdaf6bda2bf444435640a86253b793f84434f02f
--- /dev/null
+++ b/llvm/include/llvm/Transforms/IPO/tsl/robin_growth_policy.h
@@ -0,0 +1,351 @@
+/**
+ * MIT License
+ *
+ * Copyright (c) 2017 Thibaut Goetghebuer-Planchon <tessil@gmx.com>
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#ifndef TSL_ROBIN_GROWTH_POLICY_H
+#define TSL_ROBIN_GROWTH_POLICY_H
+
+
+#include <algorithm>
+#include <array>
+#include <climits>
+#include <cmath>
+#include <cstddef>
+#include <cstdint>
+#include <iterator>
+#include <limits>
+#include <ratio>
+#include <stdexcept>
+
+
+#ifdef TSL_DEBUG
+# define tsl_rh_assert(expr) assert(expr)
+#else
+# define tsl_rh_assert(expr) (static_cast<void>(0))
+#endif
+
+
+/**
+ * If exceptions are enabled, throw the exception passed in parameter, otherwise call std::terminate.
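+ *
+ * A usage sketch (this is how the macro is invoked throughout these headers):
+ * \code
+ * TSL_RH_THROW_OR_TERMINATE(std::length_error, "The hash table exceeds its maximum size.");
+ * \endcode
+ * With exceptions enabled this throws std::length_error; without them it terminates,
+ * printing the message first in non-NDEBUG builds.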
+ */
+#if (defined(__cpp_exceptions) || defined(__EXCEPTIONS) || (defined (_MSC_VER) && defined (_CPPUNWIND))) && !defined(TSL_NO_EXCEPTIONS)
+# define TSL_RH_THROW_OR_TERMINATE(ex, msg) throw ex(msg)
+#else
+# define TSL_RH_NO_EXCEPTIONS
+# ifdef NDEBUG
+# define TSL_RH_THROW_OR_TERMINATE(ex, msg) std::terminate()
+# else
+# include <iostream>
+# define TSL_RH_THROW_OR_TERMINATE(ex, msg) do { std::cerr << msg << std::endl; std::terminate(); } while(0)
+# endif
+#endif
+
+
+#if defined(__GNUC__) || defined(__clang__)
+# define TSL_RH_LIKELY(exp) (__builtin_expect(!!(exp), true))
+#else
+# define TSL_RH_LIKELY(exp) (exp)
+#endif
+
+
+#define TSL_RH_UNUSED(x) static_cast<void>(x)
+
+
+namespace tsl {
+namespace rh {
+
+/**
+ * Grow the hash table by a factor of GrowthFactor keeping the bucket count to a power of two. It allows
+ * the table to use a mask operation instead of a modulo operation to map a hash to a bucket.
+ *
+ * GrowthFactor must be a power of two >= 2.
+ */
+template<std::size_t GrowthFactor>
+class power_of_two_growth_policy {
+public:
+    /**
+     * Called on the hash table creation and on rehash. The number of buckets for the table is passed in parameter.
+     * This number is a minimum, the policy may update this value with a higher value if needed (but not lower).
+     *
+     * If 0 is given, min_bucket_count_in_out must still be 0 after the policy creation and
+     * bucket_for_hash must always return 0 in this case.
+     */
+    explicit power_of_two_growth_policy(std::size_t& min_bucket_count_in_out) {
+        if(min_bucket_count_in_out > max_bucket_count()) {
+            TSL_RH_THROW_OR_TERMINATE(std::length_error, "The hash table exceeds its maximum size.");
+        }
+
+        if(min_bucket_count_in_out > 0) {
+            min_bucket_count_in_out = round_up_to_power_of_two(min_bucket_count_in_out);
+            m_mask = min_bucket_count_in_out - 1;
+        }
+        else {
+            m_mask = 0;
+        }
+    }
+
+    /**
+     * Return the bucket [0, bucket_count()) to which the hash belongs.
+     * If bucket_count() is 0, it must always return 0.
+     */
+    std::size_t bucket_for_hash(std::size_t hash) const noexcept {
+        return hash & m_mask;
+    }
+
+    /**
+     * Return the number of buckets that should be used on next growth.
+     */
+    std::size_t next_bucket_count() const {
+        if((m_mask + 1) > max_bucket_count() / GrowthFactor) {
+            TSL_RH_THROW_OR_TERMINATE(std::length_error, "The hash table exceeds its maximum size.");
+        }
+
+        return (m_mask + 1) * GrowthFactor;
+    }
+
+    /**
+     * Return the maximum number of buckets supported by the policy.
+     */
+    std::size_t max_bucket_count() const {
+        // Largest power of two.
+        return (std::numeric_limits<std::size_t>::max() / 2) + 1;
+    }
+
+    /**
+     * Reset the growth policy as if it was created with a bucket count of 0.
+     * After a clear, the policy must always return 0 when bucket_for_hash is called.
+     */
+    void clear() noexcept {
+        m_mask = 0;
+    }
+
+private:
+    static std::size_t round_up_to_power_of_two(std::size_t value) {
+        if(is_power_of_two(value)) {
+            return value;
+        }
+
+        if(value == 0) {
+            return 1;
+        }
+
+        --value;
+        for(std::size_t i = 1; i < sizeof(std::size_t) * CHAR_BIT; i *= 2) {
+            value |= value >> i;
+        }
+
+        return value + 1;
+    }
+
+    static constexpr bool is_power_of_two(std::size_t value) {
+        return value != 0 && (value & (value - 1)) == 0;
+    }
+
+protected:
+    static_assert(is_power_of_two(GrowthFactor) && GrowthFactor >= 2, "GrowthFactor must be a power of two >= 2.");
+
+    std::size_t m_mask;
+};
+
+
+/**
+ * Grow the hash table by GrowthFactor::num / GrowthFactor::den and use a modulo to map a hash
+ * to a bucket. Slower but it can be useful if you want a slower growth.
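+ *
+ * A minimal sketch, assuming the default GrowthFactor of std::ratio<3, 2>:
+ * \code
+ * std::size_t count = 100;
+ * tsl::rh::mod_growth_policy<std::ratio<3, 2>> policy(count); // count stays 100
+ * std::size_t bucket = policy.bucket_for_hash(h);             // h % 100
+ * std::size_t next   = policy.next_bucket_count();            // ceil(100 * 1.5) == 150
+ * \endcode
+ * where h is any std::size_t hash value.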
+ */
+template<class GrowthFactor = std::ratio<3, 2>>
+class mod_growth_policy {
+public:
+    explicit mod_growth_policy(std::size_t& min_bucket_count_in_out) {
+        if(min_bucket_count_in_out > max_bucket_count()) {
+            TSL_RH_THROW_OR_TERMINATE(std::length_error, "The hash table exceeds its maximum size.");
+        }
+
+        if(min_bucket_count_in_out > 0) {
+            m_mod = min_bucket_count_in_out;
+        }
+        else {
+            m_mod = 1;
+        }
+    }
+
+    std::size_t bucket_for_hash(std::size_t hash) const noexcept {
+        return hash % m_mod;
+    }
+
+    std::size_t next_bucket_count() const {
+        if(m_mod == max_bucket_count()) {
+            TSL_RH_THROW_OR_TERMINATE(std::length_error, "The hash table exceeds its maximum size.");
+        }
+
+        const double next_bucket_count = std::ceil(double(m_mod) * REHASH_SIZE_MULTIPLICATION_FACTOR);
+        if(!std::isnormal(next_bucket_count)) {
+            TSL_RH_THROW_OR_TERMINATE(std::length_error, "The hash table exceeds its maximum size.");
+        }
+
+        if(next_bucket_count > double(max_bucket_count())) {
+            return max_bucket_count();
+        }
+        else {
+            return std::size_t(next_bucket_count);
+        }
+    }
+
+    std::size_t max_bucket_count() const {
+        return MAX_BUCKET_COUNT;
+    }
+
+    void clear() noexcept {
+        m_mod = 1;
+    }
+
+private:
+    static constexpr double REHASH_SIZE_MULTIPLICATION_FACTOR = 1.0 * GrowthFactor::num / GrowthFactor::den;
+    static const std::size_t MAX_BUCKET_COUNT =
+            std::size_t(double(
+                    std::numeric_limits<std::size_t>::max() / REHASH_SIZE_MULTIPLICATION_FACTOR
+            ));
+
+    static_assert(REHASH_SIZE_MULTIPLICATION_FACTOR >= 1.1, "Growth factor should be >= 1.1.");
+
+    std::size_t m_mod;
+};
+
+
+
+namespace detail {
+
+#if SIZE_MAX >= ULLONG_MAX
+#define TSL_RH_NB_PRIMES 51
+#elif SIZE_MAX >= ULONG_MAX
+#define TSL_RH_NB_PRIMES 40
+#else
+#define TSL_RH_NB_PRIMES 23
+#endif
+
+static constexpr const std::array<std::size_t, TSL_RH_NB_PRIMES> PRIMES = {{
+    1u, 5u, 17u, 29u, 37u, 53u, 67u, 79u, 97u, 131u, 193u, 257u, 389u, 521u, 769u, 1031u,
+    1543u, 2053u, 3079u, 6151u, 12289u, 24593u, 49157u,
+#if SIZE_MAX >= ULONG_MAX
+    98317ul, 196613ul, 393241ul, 786433ul, 1572869ul, 3145739ul, 6291469ul, 12582917ul,
+    25165843ul, 50331653ul, 100663319ul, 201326611ul, 402653189ul, 805306457ul, 1610612741ul,
+    3221225473ul, 4294967291ul,
+#endif
+#if SIZE_MAX >= ULLONG_MAX
+    6442450939ull, 12884901893ull, 25769803751ull, 51539607551ull, 103079215111ull, 206158430209ull,
+    412316860441ull, 824633720831ull, 1649267441651ull, 3298534883309ull, 6597069766657ull,
+#endif
+}};
+
+template<std::size_t IPrime>
+static constexpr std::size_t mod(std::size_t hash) { return hash % PRIMES[IPrime]; }
+
+// MOD_PRIME[iprime](hash) returns hash % PRIMES[iprime]. This table allows for faster modulo as the
+// compiler can optimize the modulo code better with a constant known at the compilation.
+static constexpr const std::array<std::size_t(*)(std::size_t), TSL_RH_NB_PRIMES> MOD_PRIME = {{
+    &mod<0>, &mod<1>, &mod<2>, &mod<3>, &mod<4>, &mod<5>, &mod<6>, &mod<7>, &mod<8>, &mod<9>, &mod<10>,
+    &mod<11>, &mod<12>, &mod<13>, &mod<14>, &mod<15>, &mod<16>, &mod<17>, &mod<18>, &mod<19>, &mod<20>,
+    &mod<21>, &mod<22>,
+#if SIZE_MAX >= ULONG_MAX
+    &mod<23>, &mod<24>, &mod<25>, &mod<26>, &mod<27>, &mod<28>, &mod<29>, &mod<30>, &mod<31>, &mod<32>,
+    &mod<33>, &mod<34>, &mod<35>, &mod<36>, &mod<37>, &mod<38>, &mod<39>,
+#endif
+#if SIZE_MAX >= ULLONG_MAX
+    &mod<40>, &mod<41>, &mod<42>, &mod<43>, &mod<44>, &mod<45>, &mod<46>, &mod<47>, &mod<48>, &mod<49>,
+    &mod<50>,
+#endif
+}};
+
+}
+
+/**
+ * Grow the hash table by using prime numbers as bucket count. Slower than tsl::rh::power_of_two_growth_policy in
+ * general but will probably distribute the values around better in the buckets with a poor hash function.
+ *
+ * To allow the compiler to optimize the modulo operation, a lookup table is used with constant prime numbers.
+ *
+ * With a switch the code would look like:
+ * \code
+ * switch(iprime) { // iprime is the current prime of the hash table
+ *     case 0: hash % 5ul;
+ *             break;
+ *     case 1: hash % 17ul;
+ *             break;
+ *     case 2: hash % 29ul;
+ *             break;
+ *     ...
+ * }
+ * \endcode
+ *
+ * Due to the constant variable in the modulo the compiler is able to optimize the operation
+ * by a series of multiplications, subtractions and shifts.
+ *
+ * The 'hash % 5' could become something like 'hash - ((hash * 0xCCCCCCCD) >> 34) * 5' in a 64-bit environment.
+ */
+class prime_growth_policy {
+public:
+    explicit prime_growth_policy(std::size_t& min_bucket_count_in_out) {
+        auto it_prime = std::lower_bound(detail::PRIMES.begin(),
+                                         detail::PRIMES.end(), min_bucket_count_in_out);
+        if(it_prime == detail::PRIMES.end()) {
+            TSL_RH_THROW_OR_TERMINATE(std::length_error, "The hash table exceeds its maximum size.");
+        }
+
+        m_iprime = static_cast<unsigned int>(std::distance(detail::PRIMES.begin(), it_prime));
+        if(min_bucket_count_in_out > 0) {
+            min_bucket_count_in_out = *it_prime;
+        }
+        else {
+            min_bucket_count_in_out = 0;
+        }
+    }
+
+    std::size_t bucket_for_hash(std::size_t hash) const noexcept {
+        return detail::MOD_PRIME[m_iprime](hash);
+    }
+
+    std::size_t next_bucket_count() const {
+        if(m_iprime + 1 >= detail::PRIMES.size()) {
+            TSL_RH_THROW_OR_TERMINATE(std::length_error, "The hash table exceeds its maximum size.");
+        }
+
+        return detail::PRIMES[m_iprime + 1];
+    }
+
+    std::size_t max_bucket_count() const {
+        return detail::PRIMES.back();
+    }
+
+    void clear() noexcept {
+        m_iprime = 0;
+    }
+
+private:
+    unsigned int m_iprime;
+
+    static_assert(std::numeric_limits<decltype(m_iprime)>::max() >= detail::PRIMES.size(),
+                  "The type of m_iprime is not big enough.");
+};
+
+}
+}
+
+#endif
\ No newline at end of file
diff --git a/llvm/include/llvm/Transforms/IPO/tsl/robin_hash.h b/llvm/include/llvm/Transforms/IPO/tsl/robin_hash.h
new file mode 100644
index 0000000000000000000000000000000000000000..bc24adfb425be3cd4e28cee899eaacd084fc8949
--- /dev/null
+++ b/llvm/include/llvm/Transforms/IPO/tsl/robin_hash.h
@@ -0,0 +1,1619 @@
+/**
+ * MIT License
+ *
+ * Copyright (c) 2017 Thibaut Goetghebuer-Planchon <tessil@gmx.com>
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */ +#ifndef TSL_ROBIN_HASH_H +#define TSL_ROBIN_HASH_H + + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "llvm/Transforms/IPO/tsl/robin_growth_policy.h" + + +namespace tsl { + +namespace detail_robin_hash { + +template +struct make_void { + using type = void; +}; + +template +struct has_is_transparent: std::false_type { +}; + +template +struct has_is_transparent::type>: std::true_type { +}; + +template +struct is_power_of_two_policy: std::false_type { +}; + +template +struct is_power_of_two_policy>: std::true_type { +}; + +// Only available in C++17, we need to be compatible with C++11 +template +const T& clamp( const T& v, const T& lo, const T& hi) { + return std::min(hi, std::max(lo, v)); +} + +template +static T numeric_cast(U value, const char* error_message = "numeric_cast() failed.") { + T ret = static_cast(value); + if(static_cast(ret) != value) { + TSL_RH_THROW_OR_TERMINATE(std::runtime_error, error_message); + } + + const bool is_same_signedness = (std::is_unsigned::value && std::is_unsigned::value) || + (std::is_signed::value && std::is_signed::value); + if(!is_same_signedness && (ret < T{}) != (value < U{})) { + TSL_RH_THROW_OR_TERMINATE(std::runtime_error, error_message); + } + + return ret; +} + +template +static T deserialize_value(Deserializer& deserializer) { + // MSVC < 2017 is not conformant, circumvent the problem by removing the template keyword +#if defined (_MSC_VER) && _MSC_VER < 1910 + return deserializer.Deserializer::operator()(); +#else + return deserializer.Deserializer::template operator()(); +#endif +} + + +/** + * Fixed size type used to represent size_type values on serialization. Need to be big enough + * to represent a std::size_t on 32 and 64 bits platforms, and must be the same size on both platforms. + */ +using slz_size_type = std::uint64_t; +static_assert(std::numeric_limits::max() >= std::numeric_limits::max(), + "slz_size_type must be >= std::size_t"); + +using truncated_hash_type = std::uint32_t; + + +/** + * Helper class that stores a truncated hash if StoreHash is true and nothing otherwise. + */ +template +class bucket_entry_hash { +public: + bool bucket_hash_equal(std::size_t /*hash*/) const noexcept { + return true; + } + + truncated_hash_type truncated_hash() const noexcept { + return 0; + } + +protected: + void set_hash(truncated_hash_type /*hash*/) noexcept { + } +}; + +template<> +class bucket_entry_hash { +public: + bool bucket_hash_equal(std::size_t hash) const noexcept { + return m_hash == truncated_hash_type(hash); + } + + truncated_hash_type truncated_hash() const noexcept { + return m_hash; + } + +protected: + void set_hash(truncated_hash_type hash) noexcept { + m_hash = truncated_hash_type(hash); + } + +private: + truncated_hash_type m_hash; +}; + + +/** + * Each bucket entry has: + * - A value of type `ValueType`. + * - An integer to store how far the value of the bucket, if any, is from its ideal bucket + * (ex: if the current bucket 5 has the value 'foo' and `hash('foo') % nb_buckets` == 3, + * `dist_from_ideal_bucket()` will return 2 as the current value of the bucket is two + * buckets away from its ideal bucket) + * If there is no value in the bucket (i.e. `empty()` is true) `dist_from_ideal_bucket()` will be < 0. + * - A marker which tells us if the bucket is the last bucket of the bucket array (useful for the + * iterator of the hash table). 
+ * - If `StoreHash` is true, 32 bits of the hash of the value, if any, are also stored in the bucket. + * If the size of the hash is more than 32 bits, it is truncated. We don't store the full hash + * as storing the hash is a potential opportunity to use the unused space due to the alignment + * of the bucket_entry structure. We can thus potentially store the hash without any extra space + * (which would not be possible with 64 bits of the hash). + */ +template +class bucket_entry: public bucket_entry_hash { + using bucket_hash = bucket_entry_hash; + +public: + using value_type = ValueType; + using distance_type = std::int16_t; + + + bucket_entry() noexcept: bucket_hash(), m_dist_from_ideal_bucket(EMPTY_MARKER_DIST_FROM_IDEAL_BUCKET), + m_last_bucket(false) + { + tsl_rh_assert(empty()); + } + + bucket_entry(bool last_bucket) noexcept: bucket_hash(), m_dist_from_ideal_bucket(EMPTY_MARKER_DIST_FROM_IDEAL_BUCKET), + m_last_bucket(last_bucket) + { + tsl_rh_assert(empty()); + } + + bucket_entry(const bucket_entry& other) noexcept(std::is_nothrow_copy_constructible::value): + bucket_hash(other), + m_dist_from_ideal_bucket(EMPTY_MARKER_DIST_FROM_IDEAL_BUCKET), + m_last_bucket(other.m_last_bucket) + { + if(!other.empty()) { + ::new (static_cast(std::addressof(m_value))) value_type(other.value()); + m_dist_from_ideal_bucket = other.m_dist_from_ideal_bucket; + } + } + + /** + * Never really used, but still necessary as we must call resize on an empty `std::vector`. + * and we need to support move-only types. See robin_hash constructor for details. + */ + bucket_entry(bucket_entry&& other) noexcept(std::is_nothrow_move_constructible::value): + bucket_hash(std::move(other)), + m_dist_from_ideal_bucket(EMPTY_MARKER_DIST_FROM_IDEAL_BUCKET), + m_last_bucket(other.m_last_bucket) + { + if(!other.empty()) { + ::new (static_cast(std::addressof(m_value))) value_type(std::move(other.value())); + m_dist_from_ideal_bucket = other.m_dist_from_ideal_bucket; + } + } + + bucket_entry& operator=(const bucket_entry& other) + noexcept(std::is_nothrow_copy_constructible::value) + { + if(this != &other) { + clear(); + + bucket_hash::operator=(other); + if(!other.empty()) { + ::new (static_cast(std::addressof(m_value))) value_type(other.value()); + } + + m_dist_from_ideal_bucket = other.m_dist_from_ideal_bucket; + m_last_bucket = other.m_last_bucket; + } + + return *this; + } + + bucket_entry& operator=(bucket_entry&& ) = delete; + + ~bucket_entry() noexcept { + clear(); + } + + void clear() noexcept { + if(!empty()) { + destroy_value(); + m_dist_from_ideal_bucket = EMPTY_MARKER_DIST_FROM_IDEAL_BUCKET; + } + } + + bool empty() const noexcept { + return m_dist_from_ideal_bucket == EMPTY_MARKER_DIST_FROM_IDEAL_BUCKET; + } + + value_type& value() noexcept { + tsl_rh_assert(!empty()); + return *reinterpret_cast(std::addressof(m_value)); + } + + const value_type& value() const noexcept { + tsl_rh_assert(!empty()); + return *reinterpret_cast(std::addressof(m_value)); + } + + distance_type dist_from_ideal_bucket() const noexcept { + return m_dist_from_ideal_bucket; + } + + bool last_bucket() const noexcept { + return m_last_bucket; + } + + void set_as_last_bucket() noexcept { + m_last_bucket = true; + } + + template + void set_value_of_empty_bucket(distance_type dist_from_ideal_bucket, + truncated_hash_type hash, Args&&... 
value_type_args) + { + tsl_rh_assert(dist_from_ideal_bucket >= 0); + tsl_rh_assert(empty()); + + ::new (static_cast(std::addressof(m_value))) value_type(std::forward(value_type_args)...); + this->set_hash(hash); + m_dist_from_ideal_bucket = dist_from_ideal_bucket; + + tsl_rh_assert(!empty()); + } + + void swap_with_value_in_bucket(distance_type& dist_from_ideal_bucket, + truncated_hash_type& hash, value_type& value) + { + tsl_rh_assert(!empty()); + + using std::swap; + swap(value, this->value()); + swap(dist_from_ideal_bucket, m_dist_from_ideal_bucket); + + if(StoreHash) { + const truncated_hash_type tmp_hash = this->truncated_hash(); + this->set_hash(hash); + hash = tmp_hash; + } + else { + // Avoid warning of unused variable if StoreHash is false + TSL_RH_UNUSED(hash); + } + } + + static truncated_hash_type truncate_hash(std::size_t hash) noexcept { + return truncated_hash_type(hash); + } + +private: + void destroy_value() noexcept { + tsl_rh_assert(!empty()); + value().~value_type(); + } + +public: + static const distance_type EMPTY_MARKER_DIST_FROM_IDEAL_BUCKET = -1; + static const distance_type DIST_FROM_IDEAL_BUCKET_LIMIT = 4096; + static_assert(DIST_FROM_IDEAL_BUCKET_LIMIT <= std::numeric_limits::max() - 1, + "DIST_FROM_IDEAL_BUCKET_LIMIT must be <= std::numeric_limits::max() - 1."); + +private: + using storage = typename std::aligned_storage::type; + + distance_type m_dist_from_ideal_bucket; + bool m_last_bucket; + storage m_value; +}; + + + +/** + * Internal common class used by `robin_map` and `robin_set`. + * + * ValueType is what will be stored by `robin_hash` (usually `std::pair` for map and `Key` for set). + * + * `KeySelect` should be a `FunctionObject` which takes a `ValueType` in parameter and returns a + * reference to the key. + * + * `ValueSelect` should be a `FunctionObject` which takes a `ValueType` in parameter and returns a + * reference to the value. `ValueSelect` should be void if there is no value (in a set for example). + * + * The strong exception guarantee only holds if the expression + * `std::is_nothrow_swappable::value && std::is_nothrow_move_constructible::value` is true. + * + * Behaviour is undefined if the destructor of `ValueType` throws. + */ +template +class robin_hash: private Hash, private KeyEqual, private GrowthPolicy { +private: + template + using has_mapped_type = typename std::integral_constant::value>; + + static_assert(noexcept(std::declval().bucket_for_hash(std::size_t(0))), "GrowthPolicy::bucket_for_hash must be noexcept."); + static_assert(noexcept(std::declval().clear()), "GrowthPolicy::clear must be noexcept."); + +public: + template + class robin_iterator; + + using key_type = typename KeySelect::key_type; + using value_type = ValueType; + using size_type = std::size_t; + using difference_type = std::ptrdiff_t; + using hasher = Hash; + using key_equal = KeyEqual; + using allocator_type = Allocator; + using reference = value_type&; + using const_reference = const value_type&; + using pointer = value_type*; + using const_pointer = const value_type*; + using iterator = robin_iterator; + using const_iterator = robin_iterator; + + +private: + /** + * Either store the hash because we are asked by the `StoreHash` template parameter + * or store the hash because it doesn't cost us anything in size and can be used to speed up rehash. 
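+ *
+ * As an illustration (the exact layout is ABI-dependent): with a ValueType aligned to
+ * 8 bytes, the int16_t distance and the bool last-bucket marker leave enough padding in
+ * bucket_entry to fit the 32-bit truncated hash, so the sizeof comparison below holds
+ * and the hash is effectively stored for free.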
+ */ + static constexpr bool STORE_HASH = StoreHash || + ( + (sizeof(tsl::detail_robin_hash::bucket_entry) == + sizeof(tsl::detail_robin_hash::bucket_entry)) + && + (sizeof(std::size_t) == sizeof(truncated_hash_type) || + is_power_of_two_policy::value) + && + // Don't store the hash for primitive types with default hash. + (!std::is_arithmetic::value || + !std::is_same>::value) + ); + + /** + * Only use the stored hash on lookup if we are explicitly asked. We are not sure how slow + * the KeyEqual operation is. An extra comparison may slow things down with a fast KeyEqual. + */ + static constexpr bool USE_STORED_HASH_ON_LOOKUP = StoreHash; + + /** + * We can only use the hash on rehash if the size of the hash type is the same as the stored one or + * if we use a power of two modulo. In the case of the power of two modulo, we just mask + * the least significant bytes, we just have to check that the truncated_hash_type didn't truncated + * more bytes. + */ + static bool USE_STORED_HASH_ON_REHASH(size_type bucket_count) { + if(STORE_HASH && sizeof(std::size_t) == sizeof(truncated_hash_type)) { + TSL_RH_UNUSED(bucket_count); + return true; + } + else if(STORE_HASH && is_power_of_two_policy::value) { + tsl_rh_assert(bucket_count > 0); + return (bucket_count - 1) <= std::numeric_limits::max(); + } + else { + TSL_RH_UNUSED(bucket_count); + return false; + } + } + + using bucket_entry = tsl::detail_robin_hash::bucket_entry; + using distance_type = typename bucket_entry::distance_type; + + using buckets_allocator = typename std::allocator_traits::template rebind_alloc; + using buckets_container_type = std::vector; + + +public: + /** + * The 'operator*()' and 'operator->()' methods return a const reference and const pointer respectively to the + * stored value type. + * + * In case of a map, to get a mutable reference to the value associated to a key (the '.second' in the + * stored pair), you have to call 'value()'. + * + * The main reason for this is that if we returned a `std::pair&` instead + * of a `const std::pair&`, the user may modify the key which will put the map in a undefined state. + */ + template + class robin_iterator { + friend class robin_hash; + + private: + using bucket_entry_ptr = typename std::conditional::type; + + + robin_iterator(bucket_entry_ptr bucket) noexcept: m_bucket(bucket) { + } + + public: + using iterator_category = std::forward_iterator_tag; + using value_type = const typename robin_hash::value_type; + using difference_type = std::ptrdiff_t; + using reference = value_type&; + using pointer = value_type*; + + + robin_iterator() noexcept { + } + + // Copy constructor from iterator to const_iterator. 
+ template::type* = nullptr> + robin_iterator(const robin_iterator& other) noexcept: m_bucket(other.m_bucket) { + } + + robin_iterator(const robin_iterator& other) = default; + robin_iterator(robin_iterator&& other) = default; + robin_iterator& operator=(const robin_iterator& other) = default; + robin_iterator& operator=(robin_iterator&& other) = default; + + const typename robin_hash::key_type& key() const { + return KeySelect()(m_bucket->value()); + } + + template::value && IsConst>::type* = nullptr> + const typename U::value_type& value() const { + return U()(m_bucket->value()); + } + + template::value && !IsConst>::type* = nullptr> + typename U::value_type& value() const { + return U()(m_bucket->value()); + } + + reference operator*() const { + return m_bucket->value(); + } + + pointer operator->() const { + return std::addressof(m_bucket->value()); + } + + robin_iterator& operator++() { + while(true) { + if(m_bucket->last_bucket()) { + ++m_bucket; + return *this; + } + + ++m_bucket; + if(!m_bucket->empty()) { + return *this; + } + } + } + + robin_iterator operator++(int) { + robin_iterator tmp(*this); + ++*this; + + return tmp; + } + + friend bool operator==(const robin_iterator& lhs, const robin_iterator& rhs) { + return lhs.m_bucket == rhs.m_bucket; + } + + friend bool operator!=(const robin_iterator& lhs, const robin_iterator& rhs) { + return !(lhs == rhs); + } + + private: + bucket_entry_ptr m_bucket; + }; + + +public: +#if defined(__cplusplus) && __cplusplus >= 201402L + robin_hash(size_type bucket_count, + const Hash& hash, + const KeyEqual& equal, + const Allocator& alloc, + float min_load_factor = DEFAULT_MIN_LOAD_FACTOR, + float max_load_factor = DEFAULT_MAX_LOAD_FACTOR): + Hash(hash), + KeyEqual(equal), + GrowthPolicy(bucket_count), + m_buckets_data( + [&]() { + if(bucket_count > max_bucket_count()) { + TSL_RH_THROW_OR_TERMINATE(std::length_error, + "The map exceeds its maximum bucket count."); + } + + return bucket_count; + }(), alloc + ), + m_buckets(m_buckets_data.empty()?static_empty_bucket_ptr():m_buckets_data.data()), + m_bucket_count(bucket_count), + m_nb_elements(0), + m_grow_on_next_insert(false), + m_try_shrink_on_next_insert(false) + { + if(m_bucket_count > 0) { + tsl_rh_assert(!m_buckets_data.empty()); + m_buckets_data.back().set_as_last_bucket(); + } + + this->min_load_factor(min_load_factor); + this->max_load_factor(max_load_factor); + } +#else + /** + * C++11 doesn't support the creation of a std::vector with a custom allocator and 'count' default-inserted elements. + * The needed contructor `explicit vector(size_type count, const Allocator& alloc = Allocator());` is only + * available in C++14 and later. We thus must resize after using the `vector(const Allocator& alloc)` constructor. + * + * We can't use `vector(size_type count, const T& value, const Allocator& alloc)` as it requires the + * value T to be copyable. 
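+ *
+ * i.e. the C++14 branch above can construct std::vector<bucket_entry>(count, alloc)
+ * directly, while this C++11 branch constructs std::vector<bucket_entry>(alloc) and
+ * then resizes it to m_bucket_count, as done in the constructor body below.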
+ */ + robin_hash(size_type bucket_count, + const Hash& hash, + const KeyEqual& equal, + const Allocator& alloc, + float min_load_factor = DEFAULT_MIN_LOAD_FACTOR, + float max_load_factor = DEFAULT_MAX_LOAD_FACTOR): + Hash(hash), + KeyEqual(equal), + GrowthPolicy(bucket_count), + m_buckets_data(alloc), + m_buckets(static_empty_bucket_ptr()), + m_bucket_count(bucket_count), + m_nb_elements(0), + m_grow_on_next_insert(false), + m_try_shrink_on_next_insert(false) + { + if(bucket_count > max_bucket_count()) { + TSL_RH_THROW_OR_TERMINATE(std::length_error, "The map exceeds its maximum bucket count."); + } + + if(m_bucket_count > 0) { + m_buckets_data.resize(m_bucket_count); + m_buckets = m_buckets_data.data(); + + tsl_rh_assert(!m_buckets_data.empty()); + m_buckets_data.back().set_as_last_bucket(); + } + + this->min_load_factor(min_load_factor); + this->max_load_factor(max_load_factor); + } +#endif + + robin_hash(const robin_hash& other): Hash(other), + KeyEqual(other), + GrowthPolicy(other), + m_buckets_data(other.m_buckets_data), + m_buckets(m_buckets_data.empty()?static_empty_bucket_ptr():m_buckets_data.data()), + m_bucket_count(other.m_bucket_count), + m_nb_elements(other.m_nb_elements), + m_load_threshold(other.m_load_threshold), + m_min_load_factor(other.m_min_load_factor), + m_max_load_factor(other.m_max_load_factor), + m_grow_on_next_insert(other.m_grow_on_next_insert), + m_try_shrink_on_next_insert(other.m_try_shrink_on_next_insert) + { + } + + robin_hash(robin_hash&& other) noexcept(std::is_nothrow_move_constructible::value && + std::is_nothrow_move_constructible::value && + std::is_nothrow_move_constructible::value && + std::is_nothrow_move_constructible::value) + : Hash(std::move(static_cast(other))), + KeyEqual(std::move(static_cast(other))), + GrowthPolicy(std::move(static_cast(other))), + m_buckets_data(std::move(other.m_buckets_data)), + m_buckets(m_buckets_data.empty()?static_empty_bucket_ptr():m_buckets_data.data()), + m_bucket_count(other.m_bucket_count), + m_nb_elements(other.m_nb_elements), + m_load_threshold(other.m_load_threshold), + m_min_load_factor(other.m_min_load_factor), + m_max_load_factor(other.m_max_load_factor), + m_grow_on_next_insert(other.m_grow_on_next_insert), + m_try_shrink_on_next_insert(other.m_try_shrink_on_next_insert) + { + other.clear_and_shrink(); + } + + robin_hash& operator=(const robin_hash& other) { + if(&other != this) { + Hash::operator=(other); + KeyEqual::operator=(other); + GrowthPolicy::operator=(other); + + m_buckets_data = other.m_buckets_data; + m_buckets = m_buckets_data.empty()?static_empty_bucket_ptr(): + m_buckets_data.data(); + m_bucket_count = other.m_bucket_count; + m_nb_elements = other.m_nb_elements; + + m_load_threshold = other.m_load_threshold; + m_min_load_factor = other.m_min_load_factor; + m_max_load_factor = other.m_max_load_factor; + + m_grow_on_next_insert = other.m_grow_on_next_insert; + m_try_shrink_on_next_insert = other.m_try_shrink_on_next_insert; + } + + return *this; + } + + robin_hash& operator=(robin_hash&& other) { + other.swap(*this); + other.clear(); + + return *this; + } + + allocator_type get_allocator() const { + return m_buckets_data.get_allocator(); + } + + + /* + * Iterators + */ + iterator begin() noexcept { + std::size_t i = 0; + while(i < m_bucket_count && m_buckets[i].empty()) { + i++; + } + + return iterator(m_buckets + i); + } + + const_iterator begin() const noexcept { + return cbegin(); + } + + const_iterator cbegin() const noexcept { + std::size_t i = 0; + while(i < m_bucket_count && 
m_buckets[i].empty()) { + i++; + } + + return const_iterator(m_buckets + i); + } + + iterator end() noexcept { + return iterator(m_buckets + m_bucket_count); + } + + const_iterator end() const noexcept { + return cend(); + } + + const_iterator cend() const noexcept { + return const_iterator(m_buckets + m_bucket_count); + } + + + /* + * Capacity + */ + bool empty() const noexcept { + return m_nb_elements == 0; + } + + size_type size() const noexcept { + return m_nb_elements; + } + + size_type max_size() const noexcept { + return m_buckets_data.max_size(); + } + + /* + * Modifiers + */ + void clear() noexcept { + if(m_min_load_factor > 0.0f) { + clear_and_shrink(); + } + else { + for(auto& bucket: m_buckets_data) { + bucket.clear(); + } + + m_nb_elements = 0; + m_grow_on_next_insert = false; + } + } + + + + template + std::pair insert(P&& value) { + return insert_impl(KeySelect()(value), std::forward
<P>
(value)); + } + + template + iterator insert_hint(const_iterator hint, P&& value) { + if(hint != cend() && compare_keys(KeySelect()(*hint), KeySelect()(value))) { + return mutable_iterator(hint); + } + + return insert(std::forward
<P>
(value)).first; + } + + template + void insert(InputIt first, InputIt last) { + if(std::is_base_of::iterator_category>::value) + { + const auto nb_elements_insert = std::distance(first, last); + const size_type nb_free_buckets = m_load_threshold - size(); + tsl_rh_assert(m_load_threshold >= size()); + + if(nb_elements_insert > 0 && nb_free_buckets < size_type(nb_elements_insert)) { + reserve(size() + size_type(nb_elements_insert)); + } + } + + for(; first != last; ++first) { + insert(*first); + } + } + + + + template + std::pair insert_or_assign(K&& key, M&& obj) { + auto it = try_emplace(std::forward(key), std::forward(obj)); + if(!it.second) { + it.first.value() = std::forward(obj); + } + + return it; + } + + template + iterator insert_or_assign(const_iterator hint, K&& key, M&& obj) { + if(hint != cend() && compare_keys(KeySelect()(*hint), key)) { + auto it = mutable_iterator(hint); + it.value() = std::forward(obj); + + return it; + } + + return insert_or_assign(std::forward(key), std::forward(obj)).first; + } + + + template + std::pair emplace(Args&&... args) { + return insert(value_type(std::forward(args)...)); + } + + template + iterator emplace_hint(const_iterator hint, Args&&... args) { + return insert_hint(hint, value_type(std::forward(args)...)); + } + + + + template + std::pair try_emplace(K&& key, Args&&... args) { + return insert_impl(key, std::piecewise_construct, + std::forward_as_tuple(std::forward(key)), + std::forward_as_tuple(std::forward(args)...)); + } + + template + iterator try_emplace_hint(const_iterator hint, K&& key, Args&&... args) { + if(hint != cend() && compare_keys(KeySelect()(*hint), key)) { + return mutable_iterator(hint); + } + + return try_emplace(std::forward(key), std::forward(args)...).first; + } + + /** + * Here to avoid `template size_type erase(const K& key)` being used when + * we use an `iterator` instead of a `const_iterator`. + */ + iterator erase(iterator pos) { + erase_from_bucket(pos); + + /** + * Erase bucket used a backward shift after clearing the bucket. + * Check if there is a new value in the bucket, if not get the next non-empty. + */ + if(pos.m_bucket->empty()) { + ++pos; + } + + m_try_shrink_on_next_insert = true; + + return pos; + } + + iterator erase(const_iterator pos) { + return erase(mutable_iterator(pos)); + } + + iterator erase(const_iterator first, const_iterator last) { + if(first == last) { + return mutable_iterator(first); + } + + auto first_mutable = mutable_iterator(first); + auto last_mutable = mutable_iterator(last); + for(auto it = first_mutable.m_bucket; it != last_mutable.m_bucket; ++it) { + if(!it->empty()) { + it->clear(); + m_nb_elements--; + } + } + + if(last_mutable == end()) { + m_try_shrink_on_next_insert = true; + return end(); + } + + + /* + * Backward shift on the values which come after the deleted values. + * We try to move the values closer to their ideal bucket. 
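+ *
+ * Illustrative trace (dist_from_ideal_bucket in parentheses): erasing buckets b3..b5
+ * from ... b3:x(0) b4:y(1) b5:z(2) b6:w(1) b7:<empty> ... clears x, y and z, then
+ * shifts w from b6 into b5 with distance 0 and stops at the empty bucket b7.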
+ */ + std::size_t icloser_bucket = static_cast(first_mutable.m_bucket - m_buckets); + std::size_t ito_move_closer_value = static_cast(last_mutable.m_bucket - m_buckets); + tsl_rh_assert(ito_move_closer_value > icloser_bucket); + + const std::size_t ireturn_bucket = ito_move_closer_value - + std::min(ito_move_closer_value - icloser_bucket, + std::size_t(m_buckets[ito_move_closer_value].dist_from_ideal_bucket())); + + while(ito_move_closer_value < m_bucket_count && m_buckets[ito_move_closer_value].dist_from_ideal_bucket() > 0) { + icloser_bucket = ito_move_closer_value - + std::min(ito_move_closer_value - icloser_bucket, + std::size_t(m_buckets[ito_move_closer_value].dist_from_ideal_bucket())); + + + tsl_rh_assert(m_buckets[icloser_bucket].empty()); + const distance_type new_distance = distance_type(m_buckets[ito_move_closer_value].dist_from_ideal_bucket() - + (ito_move_closer_value - icloser_bucket)); + m_buckets[icloser_bucket].set_value_of_empty_bucket(new_distance, + m_buckets[ito_move_closer_value].truncated_hash(), + std::move(m_buckets[ito_move_closer_value].value())); + m_buckets[ito_move_closer_value].clear(); + + + ++icloser_bucket; + ++ito_move_closer_value; + } + + m_try_shrink_on_next_insert = true; + + return iterator(m_buckets + ireturn_bucket); + } + + + template + size_type erase(const K& key) { + return erase(key, hash_key(key)); + } + + template + size_type erase(const K& key, std::size_t hash) { + auto it = find(key, hash); + if(it != end()) { + erase_from_bucket(it); + m_try_shrink_on_next_insert = true; + + return 1; + } + else { + return 0; + } + } + + + + + + void swap(robin_hash& other) { + using std::swap; + + swap(static_cast(*this), static_cast(other)); + swap(static_cast(*this), static_cast(other)); + swap(static_cast(*this), static_cast(other)); + swap(m_buckets_data, other.m_buckets_data); + swap(m_buckets, other.m_buckets); + swap(m_bucket_count, other.m_bucket_count); + swap(m_nb_elements, other.m_nb_elements); + swap(m_load_threshold, other.m_load_threshold); + swap(m_min_load_factor, other.m_min_load_factor); + swap(m_max_load_factor, other.m_max_load_factor); + swap(m_grow_on_next_insert, other.m_grow_on_next_insert); + swap(m_try_shrink_on_next_insert, other.m_try_shrink_on_next_insert); + } + + + /* + * Lookup + */ + template::value>::type* = nullptr> + typename U::value_type& at(const K& key) { + return at(key, hash_key(key)); + } + + template::value>::type* = nullptr> + typename U::value_type& at(const K& key, std::size_t hash) { + return const_cast(static_cast(this)->at(key, hash)); + } + + + template::value>::type* = nullptr> + const typename U::value_type& at(const K& key) const { + return at(key, hash_key(key)); + } + + template::value>::type* = nullptr> + const typename U::value_type& at(const K& key, std::size_t hash) const { + auto it = find(key, hash); + if(it != cend()) { + return it.value(); + } + else { + TSL_RH_THROW_OR_TERMINATE(std::out_of_range, "Couldn't find key."); + } + } + + template::value>::type* = nullptr> + typename U::value_type& operator[](K&& key) { + return try_emplace(std::forward(key)).first.value(); + } + + + template + size_type count(const K& key) const { + return count(key, hash_key(key)); + } + + template + size_type count(const K& key, std::size_t hash) const { + if(find(key, hash) != cend()) { + return 1; + } + else { + return 0; + } + } + + + template + iterator find(const K& key) { + return find_impl(key, hash_key(key)); + } + + template + iterator find(const K& key, std::size_t hash) { + return find_impl(key, 
hash); + } + + + template + const_iterator find(const K& key) const { + return find_impl(key, hash_key(key)); + } + + template + const_iterator find(const K& key, std::size_t hash) const { + return find_impl(key, hash); + } + + + template + bool contains(const K& key) const { + return contains(key, hash_key(key)); + } + + template + bool contains(const K& key, std::size_t hash) const { + return count(key, hash) != 0; + } + + + template + std::pair equal_range(const K& key) { + return equal_range(key, hash_key(key)); + } + + template + std::pair equal_range(const K& key, std::size_t hash) { + iterator it = find(key, hash); + return std::make_pair(it, (it == end())?it:std::next(it)); + } + + + template + std::pair equal_range(const K& key) const { + return equal_range(key, hash_key(key)); + } + + template + std::pair equal_range(const K& key, std::size_t hash) const { + const_iterator it = find(key, hash); + return std::make_pair(it, (it == cend())?it:std::next(it)); + } + + /* + * Bucket interface + */ + size_type bucket_count() const { + return m_bucket_count; + } + + size_type max_bucket_count() const { + return std::min(GrowthPolicy::max_bucket_count(), m_buckets_data.max_size()); + } + + /* + * Hash policy + */ + float load_factor() const { + if(bucket_count() == 0) { + return 0; + } + + return float(m_nb_elements)/float(bucket_count()); + } + + float min_load_factor() const { + return m_min_load_factor; + } + + float max_load_factor() const { + return m_max_load_factor; + } + + void min_load_factor(float ml) { + m_min_load_factor = clamp(ml, float(MINIMUM_MIN_LOAD_FACTOR), + float(MAXIMUM_MIN_LOAD_FACTOR)); + } + + void max_load_factor(float ml) { + m_max_load_factor = clamp(ml, float(MINIMUM_MAX_LOAD_FACTOR), + float(MAXIMUM_MAX_LOAD_FACTOR)); + m_load_threshold = size_type(float(bucket_count())*m_max_load_factor); + } + + void rehash(size_type count) { + count = std::max(count, size_type(std::ceil(float(size())/max_load_factor()))); + rehash_impl(count); + } + + void reserve(size_type count) { + rehash(size_type(std::ceil(float(count)/max_load_factor()))); + } + + /* + * Observers + */ + hasher hash_function() const { + return static_cast(*this); + } + + key_equal key_eq() const { + return static_cast(*this); + } + + + /* + * Other + */ + iterator mutable_iterator(const_iterator pos) { + return iterator(const_cast(pos.m_bucket)); + } + + template + void serialize(Serializer& serializer) const { + serialize_impl(serializer); + } + + template + void deserialize(Deserializer& deserializer, bool hash_compatible) { + deserialize_impl(deserializer, hash_compatible); + } + +private: + template + std::size_t hash_key(const K& key) const { + return Hash::operator()(key); + } + + template + bool compare_keys(const K1& key1, const K2& key2) const { + return KeyEqual::operator()(key1, key2); + } + + std::size_t bucket_for_hash(std::size_t hash) const { + const std::size_t bucket = GrowthPolicy::bucket_for_hash(hash); + tsl_rh_assert(bucket < m_bucket_count || (bucket == 0 && m_bucket_count == 0)); + + return bucket; + } + + template::value>::type* = nullptr> + std::size_t next_bucket(std::size_t index) const noexcept { + tsl_rh_assert(index < bucket_count()); + + return (index + 1) & this->m_mask; + } + + template::value>::type* = nullptr> + std::size_t next_bucket(std::size_t index) const noexcept { + tsl_rh_assert(index < bucket_count()); + + index++; + return (index != bucket_count())?index:0; + } + + + + template + iterator find_impl(const K& key, std::size_t hash) { + return 
mutable_iterator(static_cast(this)->find(key, hash)); + } + + template + const_iterator find_impl(const K& key, std::size_t hash) const { + std::size_t ibucket = bucket_for_hash(hash); + distance_type dist_from_ideal_bucket = 0; + + while(dist_from_ideal_bucket <= m_buckets[ibucket].dist_from_ideal_bucket()) { + if(TSL_RH_LIKELY((!USE_STORED_HASH_ON_LOOKUP || m_buckets[ibucket].bucket_hash_equal(hash)) && + compare_keys(KeySelect()(m_buckets[ibucket].value()), key))) + { + return const_iterator(m_buckets + ibucket); + } + + ibucket = next_bucket(ibucket); + dist_from_ideal_bucket++; + } + + return cend(); + } + + void erase_from_bucket(iterator pos) { + pos.m_bucket->clear(); + m_nb_elements--; + + /** + * Backward shift, swap the empty bucket, previous_ibucket, with the values on its right, ibucket, + * until we cross another empty bucket or if the other bucket has a distance_from_ideal_bucket == 0. + * + * We try to move the values closer to their ideal bucket. + */ + std::size_t previous_ibucket = static_cast(pos.m_bucket - m_buckets); + std::size_t ibucket = next_bucket(previous_ibucket); + + while(m_buckets[ibucket].dist_from_ideal_bucket() > 0) { + tsl_rh_assert(m_buckets[previous_ibucket].empty()); + + const distance_type new_distance = distance_type(m_buckets[ibucket].dist_from_ideal_bucket() - 1); + m_buckets[previous_ibucket].set_value_of_empty_bucket(new_distance, m_buckets[ibucket].truncated_hash(), + std::move(m_buckets[ibucket].value())); + m_buckets[ibucket].clear(); + + previous_ibucket = ibucket; + ibucket = next_bucket(ibucket); + } + } + + template + std::pair insert_impl(const K& key, Args&&... value_type_args) { + const std::size_t hash = hash_key(key); + + std::size_t ibucket = bucket_for_hash(hash); + distance_type dist_from_ideal_bucket = 0; + + while(dist_from_ideal_bucket <= m_buckets[ibucket].dist_from_ideal_bucket()) { + if((!USE_STORED_HASH_ON_LOOKUP || m_buckets[ibucket].bucket_hash_equal(hash)) && + compare_keys(KeySelect()(m_buckets[ibucket].value()), key)) + { + return std::make_pair(iterator(m_buckets + ibucket), false); + } + + ibucket = next_bucket(ibucket); + dist_from_ideal_bucket++; + } + + if(rehash_on_extreme_load()) { + ibucket = bucket_for_hash(hash); + dist_from_ideal_bucket = 0; + + while(dist_from_ideal_bucket <= m_buckets[ibucket].dist_from_ideal_bucket()) { + ibucket = next_bucket(ibucket); + dist_from_ideal_bucket++; + } + } + + + if(m_buckets[ibucket].empty()) { + m_buckets[ibucket].set_value_of_empty_bucket(dist_from_ideal_bucket, bucket_entry::truncate_hash(hash), + std::forward(value_type_args)...); + } + else { + insert_value(ibucket, dist_from_ideal_bucket, bucket_entry::truncate_hash(hash), + std::forward(value_type_args)...); + } + + + m_nb_elements++; + /* + * The value will be inserted in ibucket in any case, either because it was + * empty or by stealing the bucket (robin hood). + */ + return std::make_pair(iterator(m_buckets + ibucket), true); + } + + + template + void insert_value(std::size_t ibucket, distance_type dist_from_ideal_bucket, + truncated_hash_type hash, Args&&... 
value_type_args) + { + value_type value(std::forward(value_type_args)...); + insert_value_impl(ibucket, dist_from_ideal_bucket, hash, value); + } + + void insert_value(std::size_t ibucket, distance_type dist_from_ideal_bucket, + truncated_hash_type hash, value_type&& value) + { + insert_value_impl(ibucket, dist_from_ideal_bucket, hash, value); + } + + /* + * We don't use `value_type&& value` as last argument due to a bug in MSVC when `value_type` is a pointer, + * The compiler is not able to see the difference between `std::string*` and `std::string*&&` resulting in + * a compilation error. + * + * The `value` will be in a moved state at the end of the function. + */ + void insert_value_impl(std::size_t ibucket, distance_type dist_from_ideal_bucket, + truncated_hash_type hash, value_type& value) + { + m_buckets[ibucket].swap_with_value_in_bucket(dist_from_ideal_bucket, hash, value); + ibucket = next_bucket(ibucket); + dist_from_ideal_bucket++; + + while(!m_buckets[ibucket].empty()) { + if(dist_from_ideal_bucket > m_buckets[ibucket].dist_from_ideal_bucket()) { + if(dist_from_ideal_bucket >= bucket_entry::DIST_FROM_IDEAL_BUCKET_LIMIT) { + /** + * The number of probes is really high, rehash the map on the next insert. + * Difficult to do now as rehash may throw an exception. + */ + m_grow_on_next_insert = true; + } + + m_buckets[ibucket].swap_with_value_in_bucket(dist_from_ideal_bucket, hash, value); + } + + ibucket = next_bucket(ibucket); + dist_from_ideal_bucket++; + } + + m_buckets[ibucket].set_value_of_empty_bucket(dist_from_ideal_bucket, hash, std::move(value)); + } + + + void rehash_impl(size_type count) { + robin_hash new_table(count, static_cast(*this), static_cast(*this), + get_allocator(), m_min_load_factor, m_max_load_factor); + + const bool use_stored_hash = USE_STORED_HASH_ON_REHASH(new_table.bucket_count()); + for(auto& bucket: m_buckets_data) { + if(bucket.empty()) { + continue; + } + + const std::size_t hash = use_stored_hash?bucket.truncated_hash(): + new_table.hash_key(KeySelect()(bucket.value())); + + new_table.insert_value_on_rehash(new_table.bucket_for_hash(hash), 0, + bucket_entry::truncate_hash(hash), std::move(bucket.value())); + } + + new_table.m_nb_elements = m_nb_elements; + new_table.swap(*this); + } + + void clear_and_shrink() noexcept { + GrowthPolicy::clear(); + m_buckets_data.clear(); + m_buckets = static_empty_bucket_ptr(); + m_bucket_count = 0; + m_nb_elements = 0; + m_load_threshold = 0; + m_grow_on_next_insert = false; + m_try_shrink_on_next_insert = false; + } + + void insert_value_on_rehash(std::size_t ibucket, distance_type dist_from_ideal_bucket, + truncated_hash_type hash, value_type&& value) + { + while(true) { + if(dist_from_ideal_bucket > m_buckets[ibucket].dist_from_ideal_bucket()) { + if(m_buckets[ibucket].empty()) { + m_buckets[ibucket].set_value_of_empty_bucket(dist_from_ideal_bucket, hash, std::move(value)); + return; + } + else { + m_buckets[ibucket].swap_with_value_in_bucket(dist_from_ideal_bucket, hash, value); + } + } + + dist_from_ideal_bucket++; + ibucket = next_bucket(ibucket); + } + } + + + + /** + * Grow the table if m_grow_on_next_insert is true or we reached the max_load_factor. + * Shrink the table if m_try_shrink_on_next_insert is true (an erase occurred) and + * we're below the min_load_factor. + * + * Return true if the table has been rehashed. 
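+ *
+ * For example, with 128 buckets and max_load_factor == 0.5, the grow path fires once
+ * size() reaches the load threshold of 64; with min_load_factor == 0.1, an insert that
+ * follows erases leaving fewer than 13 elements triggers the shrink path via
+ * reserve(size() + 1).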
+ */ + bool rehash_on_extreme_load() { + if(m_grow_on_next_insert || size() >= m_load_threshold) { + rehash_impl(GrowthPolicy::next_bucket_count()); + m_grow_on_next_insert = false; + + return true; + } + + if(m_try_shrink_on_next_insert) { + m_try_shrink_on_next_insert = false; + if(m_min_load_factor != 0.0f && load_factor() < m_min_load_factor) { + reserve(size() + 1); + + return true; + } + } + + return false; + } + + template + void serialize_impl(Serializer& serializer) const { + const slz_size_type version = SERIALIZATION_PROTOCOL_VERSION; + serializer(version); + + // Indicate if the truncated hash of each bucket is stored. Use a std::int16_t instead + // of a bool to avoid the need for the serializer to support an extra 'bool' type. + const std::int16_t hash_stored_for_bucket = static_cast(STORE_HASH); + serializer(hash_stored_for_bucket); + + const slz_size_type nb_elements = m_nb_elements; + serializer(nb_elements); + + const slz_size_type bucket_count = m_buckets_data.size(); + serializer(bucket_count); + + const float min_load_factor = m_min_load_factor; + serializer(min_load_factor); + + const float max_load_factor = m_max_load_factor; + serializer(max_load_factor); + + for(const bucket_entry& bucket: m_buckets_data) { + if(bucket.empty()) { + const std::int16_t empty_bucket = bucket_entry::EMPTY_MARKER_DIST_FROM_IDEAL_BUCKET; + serializer(empty_bucket); + } + else { + const std::int16_t dist_from_ideal_bucket = bucket.dist_from_ideal_bucket(); + serializer(dist_from_ideal_bucket); + if(STORE_HASH) { + const std::uint32_t truncated_hash = bucket.truncated_hash(); + serializer(truncated_hash); + } + serializer(bucket.value()); + } + } + } + + template + void deserialize_impl(Deserializer& deserializer, bool hash_compatible) { + tsl_rh_assert(m_buckets_data.empty()); // Current hash table must be empty + + const slz_size_type version = deserialize_value(deserializer); + // For now we only have one version of the serialization protocol. + // If it doesn't match there is a problem with the file. + if(version != SERIALIZATION_PROTOCOL_VERSION) { + TSL_RH_THROW_OR_TERMINATE(std::runtime_error, "Can't deserialize the ordered_map/set. " + "The protocol version header is invalid."); + } + + const bool hash_stored_for_bucket = deserialize_value(deserializer)?true:false; + if(hash_compatible && STORE_HASH != hash_stored_for_bucket) { + TSL_RH_THROW_OR_TERMINATE(std::runtime_error, "Can't deserialize a map with a different StoreHash " + "than the one used during the serialization when " + "hash compatibility is used"); + } + + const slz_size_type nb_elements = deserialize_value(deserializer); + const slz_size_type bucket_count_ds = deserialize_value(deserializer); + const float min_load_factor = deserialize_value(deserializer); + const float max_load_factor = deserialize_value(deserializer); + + if(min_load_factor < MINIMUM_MIN_LOAD_FACTOR || min_load_factor > MAXIMUM_MIN_LOAD_FACTOR) { + TSL_RH_THROW_OR_TERMINATE(std::runtime_error, "Invalid min_load_factor. Check that the serializer " + "and deserializer support floats correctly as they " + "can be converted implicitly to ints."); + } + + if(max_load_factor < MINIMUM_MAX_LOAD_FACTOR || max_load_factor > MAXIMUM_MAX_LOAD_FACTOR) { + TSL_RH_THROW_OR_TERMINATE(std::runtime_error, "Invalid max_load_factor. 
Check that the serializer " + "and deserializer support floats correctly as they " + "can be converted implicitly to ints."); + } + + this->min_load_factor(min_load_factor); + this->max_load_factor(max_load_factor); + + if(bucket_count_ds == 0) { + tsl_rh_assert(nb_elements == 0); + return; + } + + + if(!hash_compatible) { + reserve(numeric_cast(nb_elements, "Deserialized nb_elements is too big.")); + for(slz_size_type ibucket = 0; ibucket < bucket_count_ds; ibucket++) { + const distance_type dist_from_ideal_bucket = deserialize_value(deserializer); + if(dist_from_ideal_bucket != bucket_entry::EMPTY_MARKER_DIST_FROM_IDEAL_BUCKET) { + if(hash_stored_for_bucket) { + TSL_RH_UNUSED(deserialize_value(deserializer)); + } + + insert(deserialize_value(deserializer)); + } + } + + tsl_rh_assert(nb_elements == size()); + } + else { + m_bucket_count = numeric_cast(bucket_count_ds, "Deserialized bucket_count is too big."); + + GrowthPolicy::operator=(GrowthPolicy(m_bucket_count)); + // GrowthPolicy should not modify the bucket count we got from deserialization + if(m_bucket_count != bucket_count_ds) { + TSL_RH_THROW_OR_TERMINATE(std::runtime_error, "The GrowthPolicy is not the same even though hash_compatible is true."); + } + + m_nb_elements = numeric_cast(nb_elements, "Deserialized nb_elements is too big."); + m_buckets_data.resize(m_bucket_count); + m_buckets = m_buckets_data.data(); + + for(bucket_entry& bucket: m_buckets_data) { + const distance_type dist_from_ideal_bucket = deserialize_value(deserializer); + if(dist_from_ideal_bucket != bucket_entry::EMPTY_MARKER_DIST_FROM_IDEAL_BUCKET) { + truncated_hash_type truncated_hash = 0; + if(hash_stored_for_bucket) { + tsl_rh_assert(hash_stored_for_bucket); + truncated_hash = deserialize_value(deserializer); + } + + bucket.set_value_of_empty_bucket(dist_from_ideal_bucket, truncated_hash, + deserialize_value(deserializer)); + } + } + + if(!m_buckets_data.empty()) { + m_buckets_data.back().set_as_last_bucket(); + } + } + } + + +public: + static const size_type DEFAULT_INIT_BUCKETS_SIZE = 0; + + static constexpr float DEFAULT_MAX_LOAD_FACTOR = 0.5f; + static constexpr float MINIMUM_MAX_LOAD_FACTOR = 0.2f; + static constexpr float MAXIMUM_MAX_LOAD_FACTOR = 0.95f; + + static constexpr float DEFAULT_MIN_LOAD_FACTOR = 0.0f; + static constexpr float MINIMUM_MIN_LOAD_FACTOR = 0.0f; + static constexpr float MAXIMUM_MIN_LOAD_FACTOR = 0.15f; + + static_assert(MINIMUM_MAX_LOAD_FACTOR < MAXIMUM_MAX_LOAD_FACTOR, + "MINIMUM_MAX_LOAD_FACTOR should be < MAXIMUM_MAX_LOAD_FACTOR"); + static_assert(MINIMUM_MIN_LOAD_FACTOR < MAXIMUM_MIN_LOAD_FACTOR, + "MINIMUM_MIN_LOAD_FACTOR should be < MAXIMUM_MIN_LOAD_FACTOR"); + static_assert(MAXIMUM_MIN_LOAD_FACTOR < MINIMUM_MAX_LOAD_FACTOR, + "MAXIMUM_MIN_LOAD_FACTOR should be < MINIMUM_MAX_LOAD_FACTOR"); + +private: + /** + * Protocol version currenlty used for serialization. + */ + static const slz_size_type SERIALIZATION_PROTOCOL_VERSION = 1; + + /** + * Return an always valid pointer to an static empty bucket_entry with last_bucket() == true. + */ + bucket_entry* static_empty_bucket_ptr() noexcept { + static bucket_entry empty_bucket(true); + return &empty_bucket; + } + +private: + buckets_container_type m_buckets_data; + + /** + * Points to m_buckets_data.data() if !m_buckets_data.empty() otherwise points to static_empty_bucket_ptr. + * This variable is useful to avoid the cost of checking if m_buckets_data is empty when trying + * to find an element. 
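`serialize_impl` writes a flat record stream: protocol version, a `std::int16_t` store-hash flag, element and bucket counts, the two load factors, then one record per bucket (distance marker, optional truncated hash, value). A function object satisfying the documented `Serializer`/`Deserializer` call shapes can be sketched as below; `BufferSerializer`/`BufferDeserializer` are hypothetical names, and the raw-`memcpy` approach is only valid for trivially copyable types (a real implementation must also handle the table's `value_type`, endianness, and float layout).

```cpp
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <vector>

// Hypothetical buffer-backed function objects matching the documented
// call signatures. memcpy-based, so correct only for trivially
// copyable U; binary compatibility is the caller's responsibility.
struct BufferSerializer {
  std::vector<char> buf;
  template <class U>
  void operator()(const U& value) {
    const char* p = reinterpret_cast<const char*>(&value);
    buf.insert(buf.end(), p, p + sizeof(U));
  }
};

struct BufferDeserializer {
  const std::vector<char>& buf;
  std::size_t pos = 0;
  template <class U>
  U operator()() {
    U value{};
    std::memcpy(&value, buf.data() + pos, sizeof(U));
    pos += sizeof(U);
    return value;
  }
};
```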
+ * + * TODO Remove m_buckets_data and only use a pointer instead of a pointer+vector to save some space in the robin_hash object. + * Manage the Allocator manually. + */ + bucket_entry* m_buckets; + + /** + * Used a lot in find, avoid the call to m_buckets_data.size() which is a bit slower. + */ + size_type m_bucket_count; + + size_type m_nb_elements; + + size_type m_load_threshold; + + float m_min_load_factor; + float m_max_load_factor; + + bool m_grow_on_next_insert; + + /** + * We can't shrink down the map on erase operations as the erase methods need to return the next iterator. + * Shrinking the map would invalidate all the iterators and we could not return the next iterator in a meaningful way, + * On erase, we thus just indicate on erase that we should try to shrink the hash table on the next insert + * if we go below the min_load_factor. + */ + bool m_try_shrink_on_next_insert; +}; + +} + +} + +#endif \ No newline at end of file diff --git a/llvm/include/llvm/Transforms/IPO/tsl/robin_map.h b/llvm/include/llvm/Transforms/IPO/tsl/robin_map.h new file mode 100644 index 0000000000000000000000000000000000000000..1b7eb8d88ebf05f075103bec04c369b39eb96d79 --- /dev/null +++ b/llvm/include/llvm/Transforms/IPO/tsl/robin_map.h @@ -0,0 +1,757 @@ +/** + * MIT License + * + * Copyright (c) 2017 Thibaut Goetghebuer-Planchon + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#ifndef TSL_ROBIN_MAP_H +#define TSL_ROBIN_MAP_H + + +#include +#include +#include +#include +#include +#include +#include "llvm/Transforms/IPO/tsl/robin_hash.h" + + +namespace tsl { + + +/** + * Implementation of a hash map using open-addressing and the robin hood hashing algorithm with backward shift deletion. + * + * For operations modifying the hash map (insert, erase, rehash, ...), the strong exception guarantee + * is only guaranteed when the expression `std::is_nothrow_swappable>::value && + * std::is_nothrow_move_constructible>::value` is true, otherwise if an exception + * is thrown during the swap or the move, the hash map may end up in a undefined state. Per the standard + * a `Key` or `T` with a noexcept copy constructor and no move constructor also satisfies the + * `std::is_nothrow_move_constructible>::value` criterion (and will thus guarantee the + * strong exception for the map). + * + * When `StoreHash` is true, 32 bits of the hash are stored alongside the values. 
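The exception-guarantee condition quoted above can be verified at compile time. A sketch for an assumed `Key = int`, `T = std::string` instantiation (C++17, for `std::is_nothrow_swappable`):

```cpp
#include <string>
#include <type_traits>
#include <utility>

// If this holds, insert/erase/rehash on the corresponding map give the
// strong exception guarantee per the documentation above.
using V = std::pair<int, std::string>;
static_assert(std::is_nothrow_swappable<V>::value &&
                  std::is_nothrow_move_constructible<V>::value,
              "only the basic exception guarantee would apply");

int main() {}
```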
It can improve + * the performance during lookups if the `KeyEqual` function takes time (if it engenders a cache-miss for example) + * as we then compare the stored hashes before comparing the keys. When `tsl::rh::power_of_two_growth_policy` is used + * as `GrowthPolicy`, it may also speed-up the rehash process as we can avoid to recalculate the hash. + * When it is detected that storing the hash will not incur any memory penalty due to alignment (i.e. + * `sizeof(tsl::detail_robin_hash::bucket_entry) == + * sizeof(tsl::detail_robin_hash::bucket_entry)`) and `tsl::rh::power_of_two_growth_policy` is + * used, the hash will be stored even if `StoreHash` is false so that we can speed-up the rehash (but it will + * not be used on lookups unless `StoreHash` is true). + * + * `GrowthPolicy` defines how the map grows and consequently how a hash value is mapped to a bucket. + * By default the map uses `tsl::rh::power_of_two_growth_policy`. This policy keeps the number of buckets + * to a power of two and uses a mask to map the hash to a bucket instead of the slow modulo. + * Other growth policies are available and you may define your own growth policy, + * check `tsl::rh::power_of_two_growth_policy` for the interface. + * + * `std::pair` must be swappable. + * + * `Key` and `T` must be copy and/or move constructible. + * + * If the destructor of `Key` or `T` throws an exception, the behaviour of the class is undefined. + * + * Iterators invalidation: + * - clear, operator=, reserve, rehash: always invalidate the iterators. + * - insert, emplace, emplace_hint, operator[]: if there is an effective insert, invalidate the iterators. + * - erase: always invalidate the iterators. + */ +template, + class KeyEqual = std::equal_to, + class Allocator = std::allocator>, + bool StoreHash = false, + class GrowthPolicy = tsl::rh::power_of_two_growth_policy<2>> +class robin_map { +private: + template + using has_is_transparent = tsl::detail_robin_hash::has_is_transparent; + + class KeySelect { + public: + using key_type = Key; + + const key_type& operator()(const std::pair& key_value) const noexcept { + return key_value.first; + } + + key_type& operator()(std::pair& key_value) noexcept { + return key_value.first; + } + }; + + class ValueSelect { + public: + using value_type = T; + + const value_type& operator()(const std::pair& key_value) const noexcept { + return key_value.second; + } + + value_type& operator()(std::pair& key_value) noexcept { + return key_value.second; + } + }; + + using ht = detail_robin_hash::robin_hash, KeySelect, ValueSelect, + Hash, KeyEqual, Allocator, StoreHash, GrowthPolicy>; + +public: + using key_type = typename ht::key_type; + using mapped_type = T; + using value_type = typename ht::value_type; + using size_type = typename ht::size_type; + using difference_type = typename ht::difference_type; + using hasher = typename ht::hasher; + using key_equal = typename ht::key_equal; + using allocator_type = typename ht::allocator_type; + using reference = typename ht::reference; + using const_reference = typename ht::const_reference; + using pointer = typename ht::pointer; + using const_pointer = typename ht::const_pointer; + using iterator = typename ht::iterator; + using const_iterator = typename ht::const_iterator; + + +public: + /* + * Constructors + */ + robin_map(): robin_map(ht::DEFAULT_INIT_BUCKETS_SIZE) { + } + + explicit robin_map(size_type bucket_count, + const Hash& hash = Hash(), + const KeyEqual& equal = KeyEqual(), + const Allocator& alloc = Allocator()): + m_ht(bucket_count, 
hash, equal, alloc) + { + } + + robin_map(size_type bucket_count, + const Allocator& alloc): robin_map(bucket_count, Hash(), KeyEqual(), alloc) + { + } + + robin_map(size_type bucket_count, + const Hash& hash, + const Allocator& alloc): robin_map(bucket_count, hash, KeyEqual(), alloc) + { + } + + explicit robin_map(const Allocator& alloc): robin_map(ht::DEFAULT_INIT_BUCKETS_SIZE, alloc) { + } + + template + robin_map(InputIt first, InputIt last, + size_type bucket_count = ht::DEFAULT_INIT_BUCKETS_SIZE, + const Hash& hash = Hash(), + const KeyEqual& equal = KeyEqual(), + const Allocator& alloc = Allocator()): robin_map(bucket_count, hash, equal, alloc) + { + insert(first, last); + } + + template + robin_map(InputIt first, InputIt last, + size_type bucket_count, + const Allocator& alloc): robin_map(first, last, bucket_count, Hash(), KeyEqual(), alloc) + { + } + + template + robin_map(InputIt first, InputIt last, + size_type bucket_count, + const Hash& hash, + const Allocator& alloc): robin_map(first, last, bucket_count, hash, KeyEqual(), alloc) + { + } + + robin_map(std::initializer_list init, + size_type bucket_count = ht::DEFAULT_INIT_BUCKETS_SIZE, + const Hash& hash = Hash(), + const KeyEqual& equal = KeyEqual(), + const Allocator& alloc = Allocator()): + robin_map(init.begin(), init.end(), bucket_count, hash, equal, alloc) + { + } + + robin_map(std::initializer_list init, + size_type bucket_count, + const Allocator& alloc): + robin_map(init.begin(), init.end(), bucket_count, Hash(), KeyEqual(), alloc) + { + } + + robin_map(std::initializer_list init, + size_type bucket_count, + const Hash& hash, + const Allocator& alloc): + robin_map(init.begin(), init.end(), bucket_count, hash, KeyEqual(), alloc) + { + } + + robin_map& operator=(std::initializer_list ilist) { + m_ht.clear(); + + m_ht.reserve(ilist.size()); + m_ht.insert(ilist.begin(), ilist.end()); + + return *this; + } + + allocator_type get_allocator() const { return m_ht.get_allocator(); } + + + /* + * Iterators + */ + iterator begin() noexcept { return m_ht.begin(); } + const_iterator begin() const noexcept { return m_ht.begin(); } + const_iterator cbegin() const noexcept { return m_ht.cbegin(); } + + iterator end() noexcept { return m_ht.end(); } + const_iterator end() const noexcept { return m_ht.end(); } + const_iterator cend() const noexcept { return m_ht.cend(); } + + + /* + * Capacity + */ + bool empty() const noexcept { return m_ht.empty(); } + size_type size() const noexcept { return m_ht.size(); } + size_type max_size() const noexcept { return m_ht.max_size(); } + + /* + * Modifiers + */ + void clear() noexcept { m_ht.clear(); } + + + + std::pair insert(const value_type& value) { + return m_ht.insert(value); + } + + template::value>::type* = nullptr> + std::pair insert(P&& value) { + return m_ht.emplace(std::forward
<P>
(value)); + } + + std::pair insert(value_type&& value) { + return m_ht.insert(std::move(value)); + } + + + iterator insert(const_iterator hint, const value_type& value) { + return m_ht.insert_hint(hint, value); + } + + template::value>::type* = nullptr> + iterator insert(const_iterator hint, P&& value) { + return m_ht.emplace_hint(hint, std::forward
<P>
(value)); + } + + iterator insert(const_iterator hint, value_type&& value) { + return m_ht.insert_hint(hint, std::move(value)); + } + + + template + void insert(InputIt first, InputIt last) { + m_ht.insert(first, last); + } + + void insert(std::initializer_list ilist) { + m_ht.insert(ilist.begin(), ilist.end()); + } + + + + + template + std::pair insert_or_assign(const key_type& k, M&& obj) { + return m_ht.insert_or_assign(k, std::forward(obj)); + } + + template + std::pair insert_or_assign(key_type&& k, M&& obj) { + return m_ht.insert_or_assign(std::move(k), std::forward(obj)); + } + + template + iterator insert_or_assign(const_iterator hint, const key_type& k, M&& obj) { + return m_ht.insert_or_assign(hint, k, std::forward(obj)); + } + + template + iterator insert_or_assign(const_iterator hint, key_type&& k, M&& obj) { + return m_ht.insert_or_assign(hint, std::move(k), std::forward(obj)); + } + + + + /** + * Due to the way elements are stored, emplace will need to move or copy the key-value once. + * The method is equivalent to insert(value_type(std::forward(args)...)); + * + * Mainly here for compatibility with the std::unordered_map interface. + */ + template + std::pair emplace(Args&&... args) { + return m_ht.emplace(std::forward(args)...); + } + + + + /** + * Due to the way elements are stored, emplace_hint will need to move or copy the key-value once. + * The method is equivalent to insert(hint, value_type(std::forward(args)...)); + * + * Mainly here for compatibility with the std::unordered_map interface. + */ + template + iterator emplace_hint(const_iterator hint, Args&&... args) { + return m_ht.emplace_hint(hint, std::forward(args)...); + } + + + + + template + std::pair try_emplace(const key_type& k, Args&&... args) { + return m_ht.try_emplace(k, std::forward(args)...); + } + + template + std::pair try_emplace(key_type&& k, Args&&... args) { + return m_ht.try_emplace(std::move(k), std::forward(args)...); + } + + template + iterator try_emplace(const_iterator hint, const key_type& k, Args&&... args) { + return m_ht.try_emplace_hint(hint, k, std::forward(args)...); + } + + template + iterator try_emplace(const_iterator hint, key_type&& k, Args&&... args) { + return m_ht.try_emplace_hint(hint, std::move(k), std::forward(args)...); + } + + + + + iterator erase(iterator pos) { return m_ht.erase(pos); } + iterator erase(const_iterator pos) { return m_ht.erase(pos); } + iterator erase(const_iterator first, const_iterator last) { return m_ht.erase(first, last); } + size_type erase(const key_type& key) { return m_ht.erase(key); } + + /** + * Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same + * as hash_function()(key). Useful to speed-up the lookup to the value if you already have the hash. + */ + size_type erase(const key_type& key, std::size_t precalculated_hash) { + return m_ht.erase(key, precalculated_hash); + } + + /** + * This overload only participates in the overload resolution if the typedef KeyEqual::is_transparent exists. + * If so, K must be hashable and comparable to Key. + */ + template::value>::type* = nullptr> + size_type erase(const K& key) { return m_ht.erase(key); } + + /** + * @copydoc erase(const K& key) + * + * Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same + * as hash_function()(key). Useful to speed-up the lookup to the value if you already have the hash. 
+ */ + template::value>::type* = nullptr> + size_type erase(const K& key, std::size_t precalculated_hash) { + return m_ht.erase(key, precalculated_hash); + } + + + + void swap(robin_map& other) { other.m_ht.swap(m_ht); } + + + + /* + * Lookup + */ + T& at(const Key& key) { return m_ht.at(key); } + + /** + * Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same + * as hash_function()(key). Useful to speed-up the lookup if you already have the hash. + */ + T& at(const Key& key, std::size_t precalculated_hash) { return m_ht.at(key, precalculated_hash); } + + + const T& at(const Key& key) const { return m_ht.at(key); } + + /** + * @copydoc at(const Key& key, std::size_t precalculated_hash) + */ + const T& at(const Key& key, std::size_t precalculated_hash) const { return m_ht.at(key, precalculated_hash); } + + + /** + * This overload only participates in the overload resolution if the typedef KeyEqual::is_transparent exists. + * If so, K must be hashable and comparable to Key. + */ + template::value>::type* = nullptr> + T& at(const K& key) { return m_ht.at(key); } + + /** + * @copydoc at(const K& key) + * + * Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same + * as hash_function()(key). Useful to speed-up the lookup if you already have the hash. + */ + template::value>::type* = nullptr> + T& at(const K& key, std::size_t precalculated_hash) { return m_ht.at(key, precalculated_hash); } + + + /** + * @copydoc at(const K& key) + */ + template::value>::type* = nullptr> + const T& at(const K& key) const { return m_ht.at(key); } + + /** + * @copydoc at(const K& key, std::size_t precalculated_hash) + */ + template::value>::type* = nullptr> + const T& at(const K& key, std::size_t precalculated_hash) const { return m_ht.at(key, precalculated_hash); } + + + + + T& operator[](const Key& key) { return m_ht[key]; } + T& operator[](Key&& key) { return m_ht[std::move(key)]; } + + + + + size_type count(const Key& key) const { return m_ht.count(key); } + + /** + * Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same + * as hash_function()(key). Useful to speed-up the lookup if you already have the hash. + */ + size_type count(const Key& key, std::size_t precalculated_hash) const { + return m_ht.count(key, precalculated_hash); + } + + /** + * This overload only participates in the overload resolution if the typedef KeyEqual::is_transparent exists. + * If so, K must be hashable and comparable to Key. + */ + template::value>::type* = nullptr> + size_type count(const K& key) const { return m_ht.count(key); } + + /** + * @copydoc count(const K& key) const + * + * Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same + * as hash_function()(key). Useful to speed-up the lookup if you already have the hash. + */ + template::value>::type* = nullptr> + size_type count(const K& key, std::size_t precalculated_hash) const { return m_ht.count(key, precalculated_hash); } + + + + + iterator find(const Key& key) { return m_ht.find(key); } + + /** + * Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same + * as hash_function()(key). Useful to speed-up the lookup if you already have the hash. 
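A hypothetical use of the precalculated-hash overloads (include path and values illustrative); the hash passed in must equal `hash_function()(key)`:

```cpp
#include <cstdio>
#include <string>
#include "tsl/robin_map.h"

int main() {
  tsl::robin_map<std::string, int> map;
  map.insert({"key", 1});

  // Hash once, reuse for several operations on the same key.
  const std::size_t h = map.hash_function()("key");
  auto it = map.find("key", h);            // skips re-hashing the key
  if (it != map.end())
    std::printf("%d\n", it->second);
  map.erase("key", h);
}
```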
+ */ + iterator find(const Key& key, std::size_t precalculated_hash) { return m_ht.find(key, precalculated_hash); } + + const_iterator find(const Key& key) const { return m_ht.find(key); } + + /** + * @copydoc find(const Key& key, std::size_t precalculated_hash) + */ + const_iterator find(const Key& key, std::size_t precalculated_hash) const { + return m_ht.find(key, precalculated_hash); + } + + /** + * This overload only participates in the overload resolution if the typedef KeyEqual::is_transparent exists. + * If so, K must be hashable and comparable to Key. + */ + template::value>::type* = nullptr> + iterator find(const K& key) { return m_ht.find(key); } + + /** + * @copydoc find(const K& key) + * + * Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same + * as hash_function()(key). Useful to speed-up the lookup if you already have the hash. + */ + template::value>::type* = nullptr> + iterator find(const K& key, std::size_t precalculated_hash) { return m_ht.find(key, precalculated_hash); } + + /** + * @copydoc find(const K& key) + */ + template::value>::type* = nullptr> + const_iterator find(const K& key) const { return m_ht.find(key); } + + /** + * @copydoc find(const K& key) + * + * Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same + * as hash_function()(key). Useful to speed-up the lookup if you already have the hash. + */ + template::value>::type* = nullptr> + const_iterator find(const K& key, std::size_t precalculated_hash) const { + return m_ht.find(key, precalculated_hash); + } + + + + + bool contains(const Key& key) const { return m_ht.contains(key); } + + /** + * Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same + * as hash_function()(key). Useful to speed-up the lookup if you already have the hash. + */ + bool contains(const Key& key, std::size_t precalculated_hash) const { + return m_ht.contains(key, precalculated_hash); + } + + /** + * This overload only participates in the overload resolution if the typedef KeyEqual::is_transparent exists. + * If so, K must be hashable and comparable to Key. + */ + template::value>::type* = nullptr> + bool contains(const K& key) const { return m_ht.contains(key); } + + /** + * @copydoc contains(const K& key) const + * + * Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same + * as hash_function()(key). Useful to speed-up the lookup if you already have the hash. + */ + template::value>::type* = nullptr> + bool contains(const K& key, std::size_t precalculated_hash) const { + return m_ht.contains(key, precalculated_hash); + } + + + + + std::pair equal_range(const Key& key) { return m_ht.equal_range(key); } + + /** + * Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same + * as hash_function()(key). Useful to speed-up the lookup if you already have the hash. 
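A sketch of the transparent overloads described above: defining `KeyEqual::is_transparent` opts in, and `K` must then be hashable by `Hash` as well, so both functors below (hypothetical `StringHash`/`StringEq`) accept `std::string_view`:

```cpp
#include <functional>
#include <string>
#include <string_view>
#include "tsl/robin_map.h"

struct StringHash {
  std::size_t operator()(std::string_view sv) const {
    return std::hash<std::string_view>{}(sv);
  }
};
struct StringEq {
  using is_transparent = void;  // enables the heterogeneous overloads
  bool operator()(std::string_view a, std::string_view b) const {
    return a == b;
  }
};

int main() {
  tsl::robin_map<std::string, int, StringHash, StringEq> map;
  map.insert({"alpha", 1});
  // No temporary std::string is built for the lookup key.
  auto it = map.find(std::string_view("alpha"));
  return it == map.end();
}
```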
+ */ + std::pair equal_range(const Key& key, std::size_t precalculated_hash) { + return m_ht.equal_range(key, precalculated_hash); + } + + std::pair equal_range(const Key& key) const { return m_ht.equal_range(key); } + + /** + * @copydoc equal_range(const Key& key, std::size_t precalculated_hash) + */ + std::pair equal_range(const Key& key, std::size_t precalculated_hash) const { + return m_ht.equal_range(key, precalculated_hash); + } + + /** + * This overload only participates in the overload resolution if the typedef KeyEqual::is_transparent exists. + * If so, K must be hashable and comparable to Key. + */ + template::value>::type* = nullptr> + std::pair equal_range(const K& key) { return m_ht.equal_range(key); } + + + /** + * @copydoc equal_range(const K& key) + * + * Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same + * as hash_function()(key). Useful to speed-up the lookup if you already have the hash. + */ + template::value>::type* = nullptr> + std::pair equal_range(const K& key, std::size_t precalculated_hash) { + return m_ht.equal_range(key, precalculated_hash); + } + + /** + * @copydoc equal_range(const K& key) + */ + template::value>::type* = nullptr> + std::pair equal_range(const K& key) const { return m_ht.equal_range(key); } + + /** + * @copydoc equal_range(const K& key, std::size_t precalculated_hash) + */ + template::value>::type* = nullptr> + std::pair equal_range(const K& key, std::size_t precalculated_hash) const { + return m_ht.equal_range(key, precalculated_hash); + } + + + + + /* + * Bucket interface + */ + size_type bucket_count() const { return m_ht.bucket_count(); } + size_type max_bucket_count() const { return m_ht.max_bucket_count(); } + + + /* + * Hash policy + */ + float load_factor() const { return m_ht.load_factor(); } + + float min_load_factor() const { return m_ht.min_load_factor(); } + float max_load_factor() const { return m_ht.max_load_factor(); } + + /** + * Set the `min_load_factor` to `ml`. When the `load_factor` of the map goes + * below `min_load_factor` after some erase operations, the map will be + * shrunk when an insertion occurs. The erase method itself never shrinks + * the map. + * + * The default value of `min_load_factor` is 0.0f, the map never shrinks by default. + */ + void min_load_factor(float ml) { m_ht.min_load_factor(ml); } + void max_load_factor(float ml) { m_ht.max_load_factor(ml); } + + void rehash(size_type count) { m_ht.rehash(count); } + void reserve(size_type count) { m_ht.reserve(count); } + + + /* + * Observers + */ + hasher hash_function() const { return m_ht.hash_function(); } + key_equal key_eq() const { return m_ht.key_eq(); } + + /* + * Other + */ + + /** + * Convert a const_iterator to an iterator. + */ + iterator mutable_iterator(const_iterator pos) { + return m_ht.mutable_iterator(pos); + } + + /** + * Serialize the map through the `serializer` parameter. + * + * The `serializer` parameter must be a function object that supports the following call: + * - `template void operator()(const U& value);` where the types `std::int16_t`, `std::uint32_t`, + * `std::uint64_t`, `float` and `std::pair` must be supported for U. + * + * The implementation leaves binary compatibility (endianness, IEEE 754 for floats, ...) of the types it serializes + * in the hands of the `Serializer` function object if compatibility is required. 
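Continuing the hypothetical `BufferSerializer`/`BufferDeserializer` sketch from the robin_hash section, a round trip could look like this (trivially copyable `Key`/`T` assumed so the `memcpy`-based buffers are valid):

```cpp
#include "tsl/robin_map.h"  // include path illustrative

int main() {
  tsl::robin_map<int, int> map;
  map.insert({1, 2});

  BufferSerializer out;   // defined in the earlier sketch
  map.serialize(out);

  BufferDeserializer in{out.buf};
  auto copy = tsl::robin_map<int, int>::deserialize(in);
  return copy.at(1) == 2 ? 0 : 1;
}
```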
+ */ + template + void serialize(Serializer& serializer) const { + m_ht.serialize(serializer); + } + + /** + * Deserialize a previously serialized map through the `deserializer` parameter. + * + * The `deserializer` parameter must be a function object that supports the following call: + * - `template U operator()();` where the types `std::int16_t`, `std::uint32_t`, `std::uint64_t`, `float` + * and `std::pair` must be supported for U. + * + * If the deserialized hash map type is hash compatible with the serialized map, the deserialization process can be + * sped up by setting `hash_compatible` to true. To be hash compatible, the Hash, KeyEqual and GrowthPolicy must behave the + * same way than the ones used on the serialized map and the StoreHash must have the same value. The `std::size_t` must also + * be of the same size as the one on the platform used to serialize the map. If these criteria are not met, the behaviour is + * undefined with `hash_compatible` sets to true. + * + * The behaviour is undefined if the type `Key` and `T` of the `robin_map` are not the same as the + * types used during serialization. + * + * The implementation leaves binary compatibility (endianness, IEEE 754 for floats, size of int, ...) of the types it + * deserializes in the hands of the `Deserializer` function object if compatibility is required. + */ + template + static robin_map deserialize(Deserializer& deserializer, bool hash_compatible = false) { + robin_map map(0); + map.m_ht.deserialize(deserializer, hash_compatible); + + return map; + } + + friend bool operator==(const robin_map& lhs, const robin_map& rhs) { + if(lhs.size() != rhs.size()) { + return false; + } + + for(const auto& element_lhs: lhs) { + const auto it_element_rhs = rhs.find(element_lhs.first); + if(it_element_rhs == rhs.cend() || element_lhs.second != it_element_rhs->second) { + return false; + } + } + + return true; + } + + friend bool operator!=(const robin_map& lhs, const robin_map& rhs) { + return !operator==(lhs, rhs); + } + + friend void swap(robin_map& lhs, robin_map& rhs) { + lhs.swap(rhs); + } + +private: + ht m_ht; +}; + + +/** + * Same as `tsl::robin_map`. + */ +template, + class KeyEqual = std::equal_to, + class Allocator = std::allocator>, + bool StoreHash = false> +using robin_pg_map = robin_map; + +} // end namespace tsl + +#endif \ No newline at end of file diff --git a/llvm/include/llvm/Transforms/IPO/tsl/robin_set.h b/llvm/include/llvm/Transforms/IPO/tsl/robin_set.h new file mode 100644 index 0000000000000000000000000000000000000000..4bd4c1adfaf957db2ab4f55ce440bea60709ad0d --- /dev/null +++ b/llvm/include/llvm/Transforms/IPO/tsl/robin_set.h @@ -0,0 +1,622 @@ +/** + * MIT License + * + * Copyright (c) 2017 Thibaut Goetghebuer-Planchon + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. 
+ * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#ifndef TSL_ROBIN_SET_H +#define TSL_ROBIN_SET_H + + +#include +#include +#include +#include +#include +#include +#include "llvm/Transforms/IPO/tsl/robin_hash.h" + + +namespace tsl { + + +/** + * Implementation of a hash set using open-addressing and the robin hood hashing algorithm with backward shift deletion. + * + * For operations modifying the hash set (insert, erase, rehash, ...), the strong exception guarantee + * is only guaranteed when the expression `std::is_nothrow_swappable::value && + * std::is_nothrow_move_constructible::value` is true, otherwise if an exception + * is thrown during the swap or the move, the hash set may end up in a undefined state. Per the standard + * a `Key` with a noexcept copy constructor and no move constructor also satisfies the + * `std::is_nothrow_move_constructible::value` criterion (and will thus guarantee the + * strong exception for the set). + * + * When `StoreHash` is true, 32 bits of the hash are stored alongside the values. It can improve + * the performance during lookups if the `KeyEqual` function takes time (or engenders a cache-miss for example) + * as we then compare the stored hashes before comparing the keys. When `tsl::rh::power_of_two_growth_policy` is used + * as `GrowthPolicy`, it may also speed-up the rehash process as we can avoid to recalculate the hash. + * When it is detected that storing the hash will not incur any memory penalty due to alignment (i.e. + * `sizeof(tsl::detail_robin_hash::bucket_entry) == + * sizeof(tsl::detail_robin_hash::bucket_entry)`) and `tsl::rh::power_of_two_growth_policy` is + * used, the hash will be stored even if `StoreHash` is false so that we can speed-up the rehash (but it will + * not be used on lookups unless `StoreHash` is true). + * + * `GrowthPolicy` defines how the set grows and consequently how a hash value is mapped to a bucket. + * By default the set uses `tsl::rh::power_of_two_growth_policy`. This policy keeps the number of buckets + * to a power of two and uses a mask to set the hash to a bucket instead of the slow modulo. + * Other growth policies are available and you may define your own growth policy, + * check `tsl::rh::power_of_two_growth_policy` for the interface. + * + * `Key` must be swappable. + * + * `Key` must be copy and/or move constructible. + * + * If the destructor of `Key` throws an exception, the behaviour of the class is undefined. + * + * Iterators invalidation: + * - clear, operator=, reserve, rehash: always invalidate the iterators. + * - insert, emplace, emplace_hint, operator[]: if there is an effective insert, invalidate the iterators. + * - erase: always invalidate the iterators. 
+ */ +template, + class KeyEqual = std::equal_to, + class Allocator = std::allocator, + bool StoreHash = false, + class GrowthPolicy = tsl::rh::power_of_two_growth_policy<2>> +class robin_set { +private: + template + using has_is_transparent = tsl::detail_robin_hash::has_is_transparent; + + class KeySelect { + public: + using key_type = Key; + + const key_type& operator()(const Key& key) const noexcept { + return key; + } + + key_type& operator()(Key& key) noexcept { + return key; + } + }; + + using ht = detail_robin_hash::robin_hash; + +public: + using key_type = typename ht::key_type; + using value_type = typename ht::value_type; + using size_type = typename ht::size_type; + using difference_type = typename ht::difference_type; + using hasher = typename ht::hasher; + using key_equal = typename ht::key_equal; + using allocator_type = typename ht::allocator_type; + using reference = typename ht::reference; + using const_reference = typename ht::const_reference; + using pointer = typename ht::pointer; + using const_pointer = typename ht::const_pointer; + using iterator = typename ht::iterator; + using const_iterator = typename ht::const_iterator; + + + /* + * Constructors + */ + robin_set(): robin_set(ht::DEFAULT_INIT_BUCKETS_SIZE) { + } + + explicit robin_set(size_type bucket_count, + const Hash& hash = Hash(), + const KeyEqual& equal = KeyEqual(), + const Allocator& alloc = Allocator()): + m_ht(bucket_count, hash, equal, alloc) + { + } + + robin_set(size_type bucket_count, + const Allocator& alloc): robin_set(bucket_count, Hash(), KeyEqual(), alloc) + { + } + + robin_set(size_type bucket_count, + const Hash& hash, + const Allocator& alloc): robin_set(bucket_count, hash, KeyEqual(), alloc) + { + } + + explicit robin_set(const Allocator& alloc): robin_set(ht::DEFAULT_INIT_BUCKETS_SIZE, alloc) { + } + + template + robin_set(InputIt first, InputIt last, + size_type bucket_count = ht::DEFAULT_INIT_BUCKETS_SIZE, + const Hash& hash = Hash(), + const KeyEqual& equal = KeyEqual(), + const Allocator& alloc = Allocator()): robin_set(bucket_count, hash, equal, alloc) + { + insert(first, last); + } + + template + robin_set(InputIt first, InputIt last, + size_type bucket_count, + const Allocator& alloc): robin_set(first, last, bucket_count, Hash(), KeyEqual(), alloc) + { + } + + template + robin_set(InputIt first, InputIt last, + size_type bucket_count, + const Hash& hash, + const Allocator& alloc): robin_set(first, last, bucket_count, hash, KeyEqual(), alloc) + { + } + + robin_set(std::initializer_list init, + size_type bucket_count = ht::DEFAULT_INIT_BUCKETS_SIZE, + const Hash& hash = Hash(), + const KeyEqual& equal = KeyEqual(), + const Allocator& alloc = Allocator()): + robin_set(init.begin(), init.end(), bucket_count, hash, equal, alloc) + { + } + + robin_set(std::initializer_list init, + size_type bucket_count, + const Allocator& alloc): + robin_set(init.begin(), init.end(), bucket_count, Hash(), KeyEqual(), alloc) + { + } + + robin_set(std::initializer_list init, + size_type bucket_count, + const Hash& hash, + const Allocator& alloc): + robin_set(init.begin(), init.end(), bucket_count, hash, KeyEqual(), alloc) + { + } + + + robin_set& operator=(std::initializer_list ilist) { + m_ht.clear(); + + m_ht.reserve(ilist.size()); + m_ht.insert(ilist.begin(), ilist.end()); + + return *this; + } + + allocator_type get_allocator() const { return m_ht.get_allocator(); } + + + /* + * Iterators + */ + iterator begin() noexcept { return m_ht.begin(); } + const_iterator begin() const noexcept { return 
m_ht.begin(); } + const_iterator cbegin() const noexcept { return m_ht.cbegin(); } + + iterator end() noexcept { return m_ht.end(); } + const_iterator end() const noexcept { return m_ht.end(); } + const_iterator cend() const noexcept { return m_ht.cend(); } + + + /* + * Capacity + */ + bool empty() const noexcept { return m_ht.empty(); } + size_type size() const noexcept { return m_ht.size(); } + size_type max_size() const noexcept { return m_ht.max_size(); } + + /* + * Modifiers + */ + void clear() noexcept { m_ht.clear(); } + + + + + std::pair insert(const value_type& value) { + return m_ht.insert(value); + } + + std::pair insert(value_type&& value) { + return m_ht.insert(std::move(value)); + } + + iterator insert(const_iterator hint, const value_type& value) { + return m_ht.insert_hint(hint, value); + } + + iterator insert(const_iterator hint, value_type&& value) { + return m_ht.insert_hint(hint, std::move(value)); + } + + template + void insert(InputIt first, InputIt last) { + m_ht.insert(first, last); + } + + void insert(std::initializer_list ilist) { + m_ht.insert(ilist.begin(), ilist.end()); + } + + + + + /** + * Due to the way elements are stored, emplace will need to move or copy the key-value once. + * The method is equivalent to insert(value_type(std::forward(args)...)); + * + * Mainly here for compatibility with the std::unordered_map interface. + */ + template + std::pair emplace(Args&&... args) { + return m_ht.emplace(std::forward(args)...); + } + + + + /** + * Due to the way elements are stored, emplace_hint will need to move or copy the key-value once. + * The method is equivalent to insert(hint, value_type(std::forward(args)...)); + * + * Mainly here for compatibility with the std::unordered_map interface. + */ + template + iterator emplace_hint(const_iterator hint, Args&&... args) { + return m_ht.emplace_hint(hint, std::forward(args)...); + } + + + + iterator erase(iterator pos) { return m_ht.erase(pos); } + iterator erase(const_iterator pos) { return m_ht.erase(pos); } + iterator erase(const_iterator first, const_iterator last) { return m_ht.erase(first, last); } + size_type erase(const key_type& key) { return m_ht.erase(key); } + + /** + * Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same + * as hash_function()(key). Useful to speed-up the lookup to the value if you already have the hash. + */ + size_type erase(const key_type& key, std::size_t precalculated_hash) { + return m_ht.erase(key, precalculated_hash); + } + + /** + * This overload only participates in the overload resolution if the typedef KeyEqual::is_transparent exists. + * If so, K must be hashable and comparable to Key. + */ + template::value>::type* = nullptr> + size_type erase(const K& key) { return m_ht.erase(key); } + + /** + * @copydoc erase(const K& key) + * + * Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same + * as hash_function()(key). Useful to speed-up the lookup to the value if you already have the hash. + */ + template::value>::type* = nullptr> + size_type erase(const K& key, std::size_t precalculated_hash) { + return m_ht.erase(key, precalculated_hash); + } + + + + void swap(robin_set& other) { other.m_ht.swap(m_ht); } + + + + /* + * Lookup + */ + size_type count(const Key& key) const { return m_ht.count(key); } + + /** + * Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same + * as hash_function()(key). 
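Because `erase` returns the iterator to the next element and, as documented, never shrinks the table, erasing while iterating is well defined. A small assumed usage:

```cpp
#include "tsl/robin_set.h"  // include path illustrative

int main() {
  tsl::robin_set<int> set{1, 2, 3, 4};
  for (auto it = set.begin(); it != set.end();) {
    if (*it % 2 == 0)
      it = set.erase(it);  // continue with the returned iterator
    else
      ++it;
  }
  return static_cast<int>(set.size());  // 2 odd keys remain
}
```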
Useful to speed-up the lookup if you already have the hash. + */ + size_type count(const Key& key, std::size_t precalculated_hash) const { return m_ht.count(key, precalculated_hash); } + + /** + * This overload only participates in the overload resolution if the typedef KeyEqual::is_transparent exists. + * If so, K must be hashable and comparable to Key. + */ + template::value>::type* = nullptr> + size_type count(const K& key) const { return m_ht.count(key); } + + /** + * @copydoc count(const K& key) const + * + * Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same + * as hash_function()(key). Useful to speed-up the lookup if you already have the hash. + */ + template::value>::type* = nullptr> + size_type count(const K& key, std::size_t precalculated_hash) const { return m_ht.count(key, precalculated_hash); } + + + + + iterator find(const Key& key) { return m_ht.find(key); } + + /** + * Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same + * as hash_function()(key). Useful to speed-up the lookup if you already have the hash. + */ + iterator find(const Key& key, std::size_t precalculated_hash) { return m_ht.find(key, precalculated_hash); } + + const_iterator find(const Key& key) const { return m_ht.find(key); } + + /** + * @copydoc find(const Key& key, std::size_t precalculated_hash) + */ + const_iterator find(const Key& key, std::size_t precalculated_hash) const { return m_ht.find(key, precalculated_hash); } + + /** + * This overload only participates in the overload resolution if the typedef KeyEqual::is_transparent exists. + * If so, K must be hashable and comparable to Key. + */ + template::value>::type* = nullptr> + iterator find(const K& key) { return m_ht.find(key); } + + /** + * @copydoc find(const K& key) + * + * Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same + * as hash_function()(key). Useful to speed-up the lookup if you already have the hash. + */ + template::value>::type* = nullptr> + iterator find(const K& key, std::size_t precalculated_hash) { return m_ht.find(key, precalculated_hash); } + + /** + * @copydoc find(const K& key) + */ + template::value>::type* = nullptr> + const_iterator find(const K& key) const { return m_ht.find(key); } + + /** + * @copydoc find(const K& key) + * + * Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same + * as hash_function()(key). Useful to speed-up the lookup if you already have the hash. + */ + template::value>::type* = nullptr> + const_iterator find(const K& key, std::size_t precalculated_hash) const { return m_ht.find(key, precalculated_hash); } + + + + + bool contains(const Key& key) const { return m_ht.contains(key); } + + /** + * Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same + * as hash_function()(key). Useful to speed-up the lookup if you already have the hash. + */ + bool contains(const Key& key, std::size_t precalculated_hash) const { + return m_ht.contains(key, precalculated_hash); + } + + /** + * This overload only participates in the overload resolution if the typedef KeyEqual::is_transparent exists. + * If so, K must be hashable and comparable to Key. 
+ */ + template::value>::type* = nullptr> + bool contains(const K& key) const { return m_ht.contains(key); } + + /** + * @copydoc contains(const K& key) const + * + * Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same + * as hash_function()(key). Useful to speed-up the lookup if you already have the hash. + */ + template::value>::type* = nullptr> + bool contains(const K& key, std::size_t precalculated_hash) const { + return m_ht.contains(key, precalculated_hash); + } + + + + + std::pair equal_range(const Key& key) { return m_ht.equal_range(key); } + + /** + * Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same + * as hash_function()(key). Useful to speed-up the lookup if you already have the hash. + */ + std::pair equal_range(const Key& key, std::size_t precalculated_hash) { + return m_ht.equal_range(key, precalculated_hash); + } + + std::pair equal_range(const Key& key) const { return m_ht.equal_range(key); } + + /** + * @copydoc equal_range(const Key& key, std::size_t precalculated_hash) + */ + std::pair equal_range(const Key& key, std::size_t precalculated_hash) const { + return m_ht.equal_range(key, precalculated_hash); + } + + /** + * This overload only participates in the overload resolution if the typedef KeyEqual::is_transparent exists. + * If so, K must be hashable and comparable to Key. + */ + template::value>::type* = nullptr> + std::pair equal_range(const K& key) { return m_ht.equal_range(key); } + + /** + * @copydoc equal_range(const K& key) + * + * Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same + * as hash_function()(key). Useful to speed-up the lookup if you already have the hash. + */ + template::value>::type* = nullptr> + std::pair equal_range(const K& key, std::size_t precalculated_hash) { + return m_ht.equal_range(key, precalculated_hash); + } + + /** + * @copydoc equal_range(const K& key) + */ + template::value>::type* = nullptr> + std::pair equal_range(const K& key) const { return m_ht.equal_range(key); } + + /** + * @copydoc equal_range(const K& key, std::size_t precalculated_hash) + */ + template::value>::type* = nullptr> + std::pair equal_range(const K& key, std::size_t precalculated_hash) const { + return m_ht.equal_range(key, precalculated_hash); + } + + + + + /* + * Bucket interface + */ + size_type bucket_count() const { return m_ht.bucket_count(); } + size_type max_bucket_count() const { return m_ht.max_bucket_count(); } + + + /* + * Hash policy + */ + float load_factor() const { return m_ht.load_factor(); } + + float min_load_factor() const { return m_ht.min_load_factor(); } + float max_load_factor() const { return m_ht.max_load_factor(); } + + /** + * Set the `min_load_factor` to `ml`. When the `load_factor` of the set goes + * below `min_load_factor` after some erase operations, the set will be + * shrunk when an insertion occurs. The erase method itself never shrinks + * the set. + * + * The default value of `min_load_factor` is 0.0f, the set never shrinks by default. 
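A hypothetical illustration of the shrink policy described above: erasing merely flags the table, and the shrink, if the load factor fell below `min_load_factor`, happens on the next insert:

```cpp
#include "tsl/robin_set.h"  // include path illustrative

int main() {
  tsl::robin_set<int> set;
  set.min_load_factor(0.1f);   // must lie within [0.0f, 0.15f]
  set.max_load_factor(0.5f);   // must lie within [0.2f, 0.95f]

  for (int i = 0; i < 1024; ++i) set.insert(i);
  for (int i = 0; i < 1020; ++i) set.erase(i);  // no shrink yet

  set.insert(2048);  // load factor is far below 0.1f: shrink happens here
  return 0;
}
```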
+ */ + void min_load_factor(float ml) { m_ht.min_load_factor(ml); } + void max_load_factor(float ml) { m_ht.max_load_factor(ml); } + + void rehash(size_type count) { m_ht.rehash(count); } + void reserve(size_type count) { m_ht.reserve(count); } + + + /* + * Observers + */ + hasher hash_function() const { return m_ht.hash_function(); } + key_equal key_eq() const { return m_ht.key_eq(); } + + + /* + * Other + */ + + /** + * Convert a const_iterator to an iterator. + */ + iterator mutable_iterator(const_iterator pos) { + return m_ht.mutable_iterator(pos); + } + + friend bool operator==(const robin_set& lhs, const robin_set& rhs) { + if(lhs.size() != rhs.size()) { + return false; + } + + for(const auto& element_lhs: lhs) { + const auto it_element_rhs = rhs.find(element_lhs); + if(it_element_rhs == rhs.cend()) { + return false; + } + } + + return true; + } + + /** + * Serialize the set through the `serializer` parameter. + * + * The `serializer` parameter must be a function object that supports the following call: + * - `template void operator()(const U& value);` where the types `std::int16_t`, `std::uint32_t`, + * `std::uint64_t`, `float` and `Key` must be supported for U. + * + * The implementation leaves binary compatibility (endianness, IEEE 754 for floats, ...) of the types it serializes + * in the hands of the `Serializer` function object if compatibility is required. + */ + template + void serialize(Serializer& serializer) const { + m_ht.serialize(serializer); + } + + /** + * Deserialize a previously serialized set through the `deserializer` parameter. + * + * The `deserializer` parameter must be a function object that supports the following call: + * - `template U operator()();` where the types `std::int16_t`, `std::uint32_t`, `std::uint64_t`, `float` and + * `Key` must be supported for U. + * + * If the deserialized hash set type is hash compatible with the serialized set, the deserialization process can be + * sped up by setting `hash_compatible` to true. To be hash compatible, the Hash, KeyEqual and GrowthPolicy must behave the + * same way than the ones used on the serialized set and the StoreHash must have the same value. The `std::size_t` must also + * be of the same size as the one on the platform used to serialize the set. If these criteria are not met, the behaviour is + * undefined with `hash_compatible` sets to true. + * + * The behaviour is undefined if the type `Key` of the `robin_set` is not the same as the type used during serialization. + * + * The implementation leaves binary compatibility (endianness, IEEE 754 for floats, size of int, ...) of the types it + * deserializes in the hands of the `Deserializer` function object if compatibility is required. + */ + template + static robin_set deserialize(Deserializer& deserializer, bool hash_compatible = false) { + robin_set set(0); + set.m_ht.deserialize(deserializer, hash_compatible); + + return set; + } + + friend bool operator!=(const robin_set& lhs, const robin_set& rhs) { + return !operator==(lhs, rhs); + } + + friend void swap(robin_set& lhs, robin_set& rhs) { + lhs.swap(rhs); + } + +private: + ht m_ht; +}; + + +/** + * Same as `tsl::robin_set`. 
+ */ +template, + class KeyEqual = std::equal_to, + class Allocator = std::allocator, + bool StoreHash = false> +using robin_pg_set = robin_set; + +} // end namespace tsl + +#endif \ No newline at end of file diff --git a/llvm/lib/CodeGen/BreakFalseDeps.cpp b/llvm/lib/CodeGen/BreakFalseDeps.cpp index 57170c58db144b2e29128036f5899464591acabf..43da6922ff6cc5c7e5a0b5abc435ef003c1352fc 100644 --- a/llvm/lib/CodeGen/BreakFalseDeps.cpp +++ b/llvm/lib/CodeGen/BreakFalseDeps.cpp @@ -213,6 +213,10 @@ void BreakFalseDeps::processDefs(MachineInstr *MI) { if (MF->getFunction().hasMinSize()) return; + //========== code size ============== + return; + //========== code size ============== + for (unsigned i = 0, e = MI->isVariadic() ? MI->getNumOperands() : MCID.getNumDefs(); i != e; ++i) { @@ -237,6 +241,10 @@ void BreakFalseDeps::processUndefReads(MachineBasicBlock *MBB) { if (MF->getFunction().hasMinSize()) return; + //========== code size ============== + return; + //========== code size ============== + // Collect this block's live out register units. LiveRegSet.init(*TRI); // We do not need to care about pristine registers as they are just preserved diff --git a/llvm/lib/CodeGen/ExpandMemCmp.cpp b/llvm/lib/CodeGen/ExpandMemCmp.cpp index b2639636dda799222455bb5ac43e70b91eb24d2a..479a2f224fd9462a39505338ee73eac53a6a78aa 100644 --- a/llvm/lib/CodeGen/ExpandMemCmp.cpp +++ b/llvm/lib/CodeGen/ExpandMemCmp.cpp @@ -747,6 +747,10 @@ static bool expandMemCmp(CallInst *CI, const TargetTransformInfo *TTI, if (CI->getFunction()->hasMinSize()) return false; + // ========= code size + return false; + //============ + // Early exit from expansion if size is not a constant. ConstantInt *SizeCast = dyn_cast(CI->getArgOperand(2)); if (!SizeCast) { diff --git a/llvm/lib/CodeGen/TargetPassConfig.cpp b/llvm/lib/CodeGen/TargetPassConfig.cpp index 0bd229f4fc6822224dcccf84301717e95b8eb1eb..7b0c9a9872a6d2e26336c03af43d7a325cab195d 100644 --- a/llvm/lib/CodeGen/TargetPassConfig.cpp +++ b/llvm/lib/CodeGen/TargetPassConfig.cpp @@ -53,6 +53,11 @@ using namespace llvm; +static cl::opt EnableCodeSizeMO( + "enable-code-size-MO", cl::init(true), cl::Hidden, + cl::desc("Enable optimizations for code size as part of the optimization " + "pipeline")); + static cl::opt EnableIPRA("enable-ipra", cl::init(false), cl::Hidden, cl::desc("Enable interprocedural register allocation " @@ -1268,6 +1273,11 @@ void TargetPassConfig::addMachinePasses() { addPass(&StackMapLivenessID); addPass(&LiveDebugValuesID); + //====== code size === + if(EnableCodeSizeMO && TM->Options.SupportsDefaultOutlining){ + addPass(createMachineOutlinerPass(true)); + }else{ + //==================== if (TM->Options.EnableMachineOutliner && getOptLevel() != CodeGenOpt::None && EnableMachineOutliner != RunOutliner::NeverOutline) { bool RunOnAllFunctions = @@ -1277,6 +1287,7 @@ void TargetPassConfig::addMachinePasses() { if (AddOutliner) addPass(createMachineOutlinerPass(RunOnAllFunctions)); } + } // Machine function splitter uses the basic block sections feature. Both // cannot be enabled at the same time. Basic block sections takes precedence. 
diff --git a/llvm/lib/Passes/PassBuilder.cpp b/llvm/lib/Passes/PassBuilder.cpp index 42fde3752724ea27a74cee91c47e1464a225aa32..4a080e42b4acb7f9fa735b8b04170ec1d0b1f022 100644 --- a/llvm/lib/Passes/PassBuilder.cpp +++ b/llvm/lib/Passes/PassBuilder.cpp @@ -102,6 +102,7 @@ #include "llvm/Transforms/IPO/ForceFunctionAttrs.h" #include "llvm/Transforms/IPO/FunctionAttrs.h" #include "llvm/Transforms/IPO/FunctionImport.h" +#include "llvm/Transforms/IPO/FunctionMerging.h" #include "llvm/Transforms/IPO/GlobalDCE.h" #include "llvm/Transforms/IPO/GlobalOpt.h" #include "llvm/Transforms/IPO/GlobalSplit.h" diff --git a/llvm/lib/Passes/PassBuilderPipelines.cpp b/llvm/lib/Passes/PassBuilderPipelines.cpp index 945ef512391b02e73b2235394ca163940b81e742..67abc6d2983f389b0c09a21c7f771bed600efa6d 100644 --- a/llvm/lib/Passes/PassBuilderPipelines.cpp +++ b/llvm/lib/Passes/PassBuilderPipelines.cpp @@ -47,6 +47,7 @@ #include "llvm/Transforms/IPO/ElimAvailExtern.h" #include "llvm/Transforms/IPO/ForceFunctionAttrs.h" #include "llvm/Transforms/IPO/FunctionAttrs.h" +#include "llvm/Transforms/IPO/FunctionMerging.h" //func-merging #include "llvm/Transforms/IPO/GlobalDCE.h" #include "llvm/Transforms/IPO/GlobalOpt.h" #include "llvm/Transforms/IPO/GlobalSplit.h" @@ -55,6 +56,7 @@ #include "llvm/Transforms/IPO/InferFunctionAttrs.h" #include "llvm/Transforms/IPO/Inliner.h" #include "llvm/Transforms/IPO/LowerTypeTests.h" +#include "llvm/Transforms/IPO/MergeFunctions.h" //run before func-merging #include "llvm/Transforms/IPO/MergeFunctions.h" #include "llvm/Transforms/IPO/ModuleInliner.h" #include "llvm/Transforms/IPO/OpenMPOpt.h" @@ -130,6 +132,10 @@ using namespace llvm; +static cl::opt EnableFuncMerging( + "enable-func-merging", cl::init(0), cl::Hidden, + cl::desc("Enable function merging as part of the optimization pipeline")); + static cl::opt UseInlineAdvisor( "enable-ml-inliner", cl::init(InliningAdvisorMode::Default), cl::Hidden, cl::desc("Enable ML policy for inliner. Currently trained for -Oz only"), @@ -182,6 +188,11 @@ static cl::opt EnableMergeFunctions( "enable-merge-functions", cl::init(false), cl::Hidden, cl::desc("Enable function merging as part of the optimization pipeline")); +static cl::opt EnableCodeSize( + "enable-code-size", cl::init(true), cl::Hidden, + cl::desc("Enable optimizations for code size as part of the optimization " + "pipeline")); + PipelineTuningOptions::PipelineTuningOptions() { LoopInterleaving = true; LoopVectorization = true; @@ -481,9 +492,24 @@ PassBuilder::buildFunctionSimplificationPipeline(OptimizationLevel Level, LPM1.addPass(LICMPass(PTO.LicmMssaOptCap, PTO.LicmMssaNoAccForPromotionCap, /*AllowSpeculation=*/false)); + //====for size==================== + if (EnableCodeSize && false) { + if (Level == OptimizationLevel::O2) { + LPM1.addPass(LoopRotatePass(false, isLTOPreLink(Phase))); + } else { + LPM1.addPass( + LoopRotatePass(Level != OptimizationLevel::Oz, isLTOPreLink(Phase))); + } + } else { + LPM1.addPass( + LoopRotatePass(Level != OptimizationLevel::Oz, isLTOPreLink(Phase))); + } + //======================== // Disable header duplication in loop rotation at -Oz. - LPM1.addPass( - LoopRotatePass(Level != OptimizationLevel::Oz, isLTOPreLink(Phase))); + if (!EnableCodeSize) { + LPM1.addPass( + LoopRotatePass(Level != OptimizationLevel::Oz, isLTOPreLink(Phase))); + } // TODO: Investigate promotion cap for O1. 
LPM1.addPass(LICMPass(PTO.LicmMssaOptCap, PTO.LicmMssaNoAccForPromotionCap, /*AllowSpeculation=*/true)); @@ -708,6 +734,12 @@ void PassBuilder::addPGOInstrPassesForO0(ModulePassManager &MPM, } static InlineParams getInlineParamsFromOptLevel(OptimizationLevel Level) { + //===for size==================== + if (EnableCodeSize) { + if (Level == OptimizationLevel::O2) + return getInlineParams(2, 1); + } + //===for size==================== return getInlineParams(Level.getSpeedupLevel(), Level.getSizeLevel()); } @@ -1086,7 +1118,8 @@ void PassBuilder::addVectorPasses(OptimizationLevel Level, } // Optimize parallel scalar instruction chains into SIMD instructions. - if (PTO.SLPVectorization) { + //======== code size + if (PTO.SLPVectorization && !EnableCodeSize) { FPM.addPass(SLPVectorizerPass()); if (Level.getSpeedupLevel() > 1 && ExtraVectorizerPasses) { FPM.addPass(EarlyCSEPass()); @@ -1212,9 +1245,19 @@ PassBuilder::buildModuleOptimizationPipeline(OptimizationLevel Level, C(OptimizePM, Level); LoopPassManager LPM; + //====for size==================== + if (EnableCodeSize && false) { + if (Level == OptimizationLevel::O2) { + LPM.addPass(LoopRotatePass(false, LTOPreLink)); + } else { + LPM.addPass(LoopRotatePass(Level != OptimizationLevel::Oz, LTOPreLink)); + } + } else { + LPM.addPass(LoopRotatePass(Level != OptimizationLevel::Oz, LTOPreLink)); + } + //======================== // First rotate loops that may have been un-rotated by prior passes. // Disable header duplication at -Oz. - LPM.addPass(LoopRotatePass(Level != OptimizationLevel::Oz, LTOPreLink)); // Some loops may have become dead by now. Try to delete them. // FIXME: see discussion in https://reviews.llvm.org/D112851, // this may need to be revisited once we run GVN before loop deletion @@ -1324,6 +1367,11 @@ PassBuilder::buildPerModuleDefaultPipeline(OptimizationLevel Level, const ThinOrFullLTOPhase LTOPhase = LTOPreLink ? ThinOrFullLTOPhase::FullLTOPreLink : ThinOrFullLTOPhase::None; + + if (EnableCodeSize) { + MPM.addPass(MergeFunctionsPass()); + MPM.addPass(FunctionMergingPass()); + } // Add the core simplification pipeline. 
   MPM.addPass(buildModuleSimplificationPipeline(Level, LTOPhase));
@@ -1689,7 +1737,6 @@
   MainFPM.addPass(DSEPass());
   MainFPM.addPass(MergedLoadStoreMotionPass());
-
   if (EnableConstraintElimination)
     MainFPM.addPass(ConstraintEliminationPass());
diff --git a/llvm/lib/Passes/PassRegistry.def b/llvm/lib/Passes/PassRegistry.def
index 7c29bffbc327bfd6e050f18e5339742f2ca56e15..fd339057ab591c5c5b133b95e09d7c346948c1c7 100644
--- a/llvm/lib/Passes/PassRegistry.def
+++ b/llvm/lib/Passes/PassRegistry.def
@@ -59,6 +59,7 @@
 MODULE_PASS("elim-avail-extern", EliminateAvailableExternallyPass())
 MODULE_PASS("extract-blocks", BlockExtractorPass())
 MODULE_PASS("forceattrs", ForceFunctionAttrsPass())
+MODULE_PASS("func-merging", FunctionMergingPass())
 MODULE_PASS("function-import", FunctionImportPass())
 MODULE_PASS("function-specialization", FunctionSpecializationPass())
 MODULE_PASS("globaldce", GlobalDCEPass())
 MODULE_PASS("globalopt", GlobalOptPass())
diff --git a/llvm/lib/Support/Triple.cpp b/llvm/lib/Support/Triple.cpp
index 6696d158b2c1ae3d31f0819880ec0814f74a3a16..0d348df5974e2652b3bfc7d200e7c8796325d7db 100644
--- a/llvm/lib/Support/Triple.cpp
+++ b/llvm/lib/Support/Triple.cpp
@@ -543,6 +543,7 @@
     .Case("mesa", Triple::Mesa)
     .Case("suse", Triple::SUSE)
     .Case("oe", Triple::OpenEmbedded)
+    .Case("pokysdk", Triple::OpenEmbedded)
     .Default(Triple::UnknownVendor);
 }
diff --git a/llvm/lib/Target/AArch64/AArch64ConditionalCompares.cpp b/llvm/lib/Target/AArch64/AArch64ConditionalCompares.cpp
index 343f888b7552e8651fb2e58a2a3fe15d855aa48d..6b1b6e31e40ef6052fcfbe2f12b0a18c2060058e 100644
--- a/llvm/lib/Target/AArch64/AArch64ConditionalCompares.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ConditionalCompares.cpp
@@ -942,6 +942,10 @@
   MinInstr = nullptr;
   MinSize = MF.getFunction().hasMinSize();
 
+  // Code-size tuning: run the conversion as if every function were built
+  // with the minsize attribute.
+  MinSize = true;
+
   bool Changed = false;
   CmpConv.runOnMachineFunction(MF, MBPI);
diff --git a/llvm/lib/Transforms/IPO/CMakeLists.txt b/llvm/lib/Transforms/IPO/CMakeLists.txt
index f9833224d1424914ab3d126f1a4275cca67bca26..e4c999b73c420786b34867a7c2123cfc2793fe2e 100644
--- a/llvm/lib/Transforms/IPO/CMakeLists.txt
+++ b/llvm/lib/Transforms/IPO/CMakeLists.txt
@@ -15,6 +15,7 @@ add_llvm_component_library(LLVMipo
   ForceFunctionAttrs.cpp
   FunctionAttrs.cpp
   FunctionImport.cpp
+  FunctionMerging.cpp
   FunctionSpecialization.cpp
   GlobalDCE.cpp
   GlobalOpt.cpp
diff --git a/llvm/lib/Transforms/IPO/FunctionMerging.cpp b/llvm/lib/Transforms/IPO/FunctionMerging.cpp
new file mode 100644
index
0000000000000000000000000000000000000000..01093d0c96d37628701c88604d73bfe92a3741aa
--- /dev/null
+++ b/llvm/lib/Transforms/IPO/FunctionMerging.cpp
@@ -0,0 +1,5929 @@
+//===- FunctionMerging.cpp - A function merging pass ----------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM
+// Exceptions. See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the general function merging optimization.
+//
+// It identifies similarities between functions and, if profitable, merges them
+// into a single function, replacing the originals. Functions do not need to be
+// identical to be merged. In fact, there are very few restrictions on merging
+// two functions; however, the merged function can end up larger than the two
+// original functions combined. For that reason, this pass uses the
+// TargetTransformInfo analysis to estimate the code-size cost of instructions,
+// and from that the profitability of merging two functions.
+//
+// The function merging transformation has three major parts:
+// 1. The input functions are linearized, representing their CFGs as sequences
+//    of labels and instructions.
+// 2. We apply a sequence alignment algorithm, namely the Needleman-Wunsch
+//    algorithm, to identify similar code between the two linearized functions.
+// 3. We use the aligned sequences to perform code generation, producing the
+//    new merged function, using an extra parameter to represent the function
+//    identifier.
+//
+// This pass integrates the function merging transformation with an exploration
+// framework. For every function, the other functions are ranked based on their
+// degree of similarity, which is computed from the functions' fingerprints.
+// Only the top candidates are analyzed in a greedy manner, and if one of them
+// produces a profitable result, the merged function is kept.
+//
+//===----------------------------------------------------------------------===//
+//
+// This optimization was proposed in
+//
+// Function Merging by Sequence Alignment (CGO'19)
+// Rodrigo C. O. Rocha, Pavlos Petoumenos, Zheng Wang, Murray Cole, Hugh
+// Leather
+//
+// Effective Function Merging in the SSA Form (PLDI'20)
+// Rodrigo C. O. Rocha, Pavlos Petoumenos, Zheng Wang, Murray Cole, Hugh
+// Leather
+//
+// HyFM: Function Merging for Free (LCTES'21)
+// Rodrigo C. O. Rocha, Pavlos Petoumenos, Zheng Wang, Murray Cole, Kim
+// Hazelwood, Hugh Leather
+//
+// F3M: Fast Focused Function Merging (CGO'22)
+// Sean Sterling, Rodrigo C. O. Rocha, Hugh Leather, Kim Hazelwood, Michael
+// O'Boyle, Pavlos Petoumenos
+//
+//===----------------------------------------------------------------------===//
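As a concrete picture of step 2 above, here is a minimal, self-contained sketch of Needleman-Wunsch-style scoring over two already-linearized sequences. It is illustrative only: `match` stands in for the pass's instruction-matching predicate, the integer sequences stand in for linearized code, and none of these names belong to the pass itself.

#include <algorithm>
#include <functional>
#include <vector>

// Illustrative sketch: count the entries of two linearized sequences that a
// Needleman-Wunsch alignment can pair up (+1 for a match, 0 for gaps).
static int countAlignedMatches(const std::vector<int> &A,
                               const std::vector<int> &B,
                               const std::function<bool(int, int)> &match) {
  size_t N = A.size(), M = B.size();
  // DP[i][j] = best score aligning A[0..i) against B[0..j).
  std::vector<std::vector<int>> DP(N + 1, std::vector<int>(M + 1, 0));
  for (size_t i = 1; i <= N; ++i)
    for (size_t j = 1; j <= M; ++j) {
      int Diag = DP[i - 1][j - 1] + (match(A[i - 1], B[j - 1]) ? 1 : 0);
      DP[i][j] = std::max({Diag, DP[i - 1][j], DP[i][j - 1]});
    }
  return DP[N][M];
}

Matched entries are emitted once in the merged body; unmatched ones end up guarded by the extra function-identifier parameter described in the header comment.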
+#include "llvm/Transforms/IPO/FunctionMerging.h"
+
+#include "llvm/IR/BasicBlock.h"
+#include "llvm/IR/CFG.h"
+#include "llvm/IR/Dominators.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/GlobalValue.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/InstIterator.h"
+#include "llvm/IR/InstrTypes.h"
+#include "llvm/IR/Instruction.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/Verifier.h"
+
+#include "llvm/Support/Error.h"
+#include "llvm/Support/Timer.h"
+
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/FormatVariadic.h"
+
+#include "llvm/Analysis/LoopInfo.h"
+//#include "llvm/Analysis/ValueTracking.h"
+#include "llvm/Analysis/CFG.h"
+#include "llvm/Analysis/CallGraph.h"
+#include "llvm/Analysis/InstructionSimplify.h"
+#include "llvm/Analysis/IteratedDominanceFrontier.h"
+#include "llvm/Analysis/PostDominators.h"
+
+#include "llvm/Transforms/Utils/Cloning.h"
+#include "llvm/Transforms/Utils/CodeExtractor.h"
+
+#include "llvm/Support/RandomNumberGenerator.h"
+
+#include "llvm/ADT/BreadthFirstIterator.h"
+#include "llvm/ADT/PostOrderIterator.h"
+#include "llvm/ADT/SmallSet.h"
+#include "llvm/ADT/SmallVector.h"
+
+#include "llvm/Analysis/Utils/Local.h"
+#include "llvm/Transforms/Utils/Local.h"
+
+#include "llvm/Transforms/InstCombine/InstCombine.h"
+#include "llvm/Transforms/Utils/FunctionComparator.h"
+#include "llvm/Transforms/Utils/Mem2Reg.h"
+#include "llvm/Transforms/Utils/PromoteMemToReg.h"
+
+#include "llvm/Bitcode/BitcodeWriter.h"
+#include "llvm/Transforms/IPO.h"
+
+#include "llvm/IR/LegacyPassManager.h"
+#include "llvm/Transforms/Scalar.h"
+#include "llvm/Transforms/Utils.h"
+
+// #include "llvm/Transforms/IPO/FeisenDebug.h"
+
+#include "llvm/Analysis/InlineSizeEstimatorAnalysis.h"
+
+#include <algorithm>
+#include <array>
+#include <fstream>
+#include <limits>
+#include <list>
+#include <map>
+#include <set>
+#include <string>
+#include <vector>
+
+#include <chrono>
+#include <random>
+#include <unordered_map>
+
+#ifdef __unix__
+/* __unix__ is usually defined by compilers targeting Unix systems */
+#include <unistd.h>
+#elif defined(_WIN32) || defined(WIN32)
+/* _WIN32 is usually defined by compilers targeting 32- or 64-bit Windows
+ * systems */
+#include <windows.h>
+#endif
+
+#define DEBUG_TYPE "func-merging"
+
+//#define ENABLE_DEBUG_CODE
+
+//#define SKIP_MERGING
+
+// #define TIME_STEPS_DEBUG
+
+#define CHANGES
+
+using namespace llvm;
+
+// feisen: auxiliary functions
+static size_t getNumPredecessors(BasicBlock *BB) {
+  return std::distance(pred_begin(BB), pred_end(BB));
+}
+static bool isEntryBlock(BasicBlock *BB) {
+  Function *func = BB->getParent();
+  return BB == &func->getEntryBlock();
+}
+static bool isUnreachableBlock(BasicBlock *BB) {
+  return !isEntryBlock(BB) && getNumPredecessors(BB) == 0;
+}
+
+// feisen
+static Value *getPossibleValue_fm(Type *t, BasicBlock *pred) {
+  Value *v = nullptr;
+  for (Instruction &I : *pred) {
+    if (LandingPadInst *landingPadInst = dyn_cast<LandingPadInst>(&I)) {
+      if (landingPadInst->getType() == t) {
+        v = landingPadInst;
+        // v = landingPadInst->getOperand(0);
+      }
+    }
+    if (I.getType() == t)
+      v = &I;
+  }
+  return v;
+}
+
+// feisen
+static bool HandlePHINode_fm(PHINode *phiInst) {
+  bool changed = false;
+  for (BasicBlock *pred : phiInst->blocks()) {
+    Value *v = phiInst->getIncomingValueForBlock(pred);
+    if (UndefValue::classof(v)) {
+      Type *t = v->getType();
+      Value *possibleValue = getPossibleValue_fm(t, pred);
+      if (possibleValue != nullptr) {
+        phiInst->setIncomingValueForBlock(pred, possibleValue);
+        changed = true;
+      }
+    }
+  }
+  return changed;
+}
+
+// feisen
+static PreservedAnalyses resolvePHI_fm(Module &M, ModuleAnalysisManager &AM) {
+  bool changed = false;
+  for (Function &F : M) {
+    // Only visit merged functions, which are named with the "_m_f_" prefix.
+    if (!F.getName().startswith("_m_f_"))
+      continue;
+
+    for (BasicBlock &B : F)
+      for (Instruction &I : B)
+        if (PHINode *phiInst = dyn_cast<PHINode>(&I))
+          changed |= HandlePHINode_fm(phiInst);
+  }
+  return changed ? PreservedAnalyses::none() : PreservedAnalyses::all();
+}
+
+static bool resolvePHI(Function &F) {
+  bool changed = false;
+  if (!F.getName().startswith("_m_f_"))
+    return false;
+
+  for (BasicBlock &B : F)
+    for (Instruction &I : B)
+      if (PHINode *phiInst = dyn_cast<PHINode>(&I))
+        changed |= HandlePHINode_fm(phiInst);
+  return changed;
+}
+//-----
+
+static cl::opt<unsigned> ExplorationThreshold(
+    "func-merging-explore", cl::init(10), cl::Hidden,
+    cl::desc("Exploration threshold of evaluated functions"));
+static cl::opt<unsigned> RankingThreshold(
+    "func-merging-ranking-threshold", cl::init(100), cl::Hidden,
+    cl::desc("Threshold of how many candidates should be ranked"));
+static cl::opt<int> MergingOverheadThreshold(
+    "func-merging-threshold", cl::init(50), cl::Hidden,
+    cl::desc("Threshold of allowed overhead for merging functions"));
+static cl::opt<bool>
+    MaxParamScore("func-merging-max-param", cl::init(true), cl::Hidden,
+                  cl::desc("Maximizing the score for merging parameters"));
+static cl::opt<bool> Debug("func-merging-debug", cl::init(false), cl::Hidden,
+                           cl::desc("Outputs debug information"));
+static cl::opt<bool> Verbose("func-merging-verbose", cl::init(false),
+                             cl::Hidden,
+                             cl::desc("Outputs verbose information"));
+static cl::opt<bool>
+    IdenticalType("func-merging-identic-type", cl::init(false),
+                  cl::desc("Match only values with identical types"));
+static cl::opt<bool> EnableUnifiedReturnType(
+    "func-merging-unify-return", cl::init(true), cl::Hidden,
+    cl::desc("Enable unified return types"));
+static cl::opt<bool> EnableOperandReordering(
+    "func-merging-operand-reorder", cl::init(false), cl::Hidden,
+    cl::desc("Enable operand reordering"));
+// feisen: this must remain false to avoid a known bug.
+static cl::opt<bool>
+    HasWholeProgram("func-merging-whole-program", cl::init(false),
+                    cl::desc("Function merging applied on whole program"));
+static cl::opt<bool>
+    EnableHyFMPA("func-merging-hyfm-pa", cl::init(false), cl::Hidden,
+                 cl::desc("Enable HyFM with the Pairwise Alignment"));
+// feisen: debugging a crash seen with this enabled:
+//   thread #1, queue = 'com.apple.main-thread',
+//   stop reason = EXC_BAD_ACCESS (code=EXC_I386_GPFLT)
+static cl::opt<bool>
+    EnableHyFMNW("func-merging-hyfm-nw", cl::init(true), cl::Hidden,
+                 cl::desc("Enable HyFM with the Needleman-Wunsch alignment"));
+// feisen: debugging a crash with this frame:
+//   frame #0: 0x0000000100c48ee9
+//   clang++`llvm::BasicBlock::getSinglePredecessor() const + 25
+static cl::opt<bool> EnableSALSSACoalescing(
+    "func-merging-coalescing", cl::init(true), cl::Hidden,
+    cl::desc("Enable phi-node coalescing during SSA reconstruction"));
+static cl::opt<bool> ReuseMergedFunctions(
+    "func-merging-reuse-merges", cl::init(true), cl::Hidden,
+    cl::desc("Try to reuse merged functions for another merge operation"));
+static cl::opt<unsigned>
+    MaxNumSelection("func-merging-max-selects", cl::init(500), cl::Hidden,
+                    cl::desc("Maximum number of allowed operand selections"));
+static cl::opt<bool> HyFMProfitability(
+    "hyfm-profitability", cl::init(true), cl::Hidden,
+    cl::desc("Check the profitability of merges found by HyFM"));
+static cl::opt<bool> EnableF3M(
+    "func-merging-f3m", cl::init(true), cl::Hidden,
+    cl::desc("Enable function pairing based on MinHashes and LSH"));
+static cl::opt<unsigned>
+    LSHRows("hyfm-f3m-rows", cl::init(2), cl::Hidden,
+            cl::desc("Number of rows in the LSH structure"));
+static cl::opt<unsigned>
+    LSHBands("hyfm-f3m-bands", cl::init(100), cl::Hidden,
+             cl::desc("Number of bands in the LSH structure"));
+static cl::opt<bool>
+    ShingleCrossBBs("shingling-cross-basic-blocks", cl::init(true),
+                    cl::desc("Do shingles in MinHash cross basic blocks"));
+static cl::opt<bool> AdaptiveThreshold(
+    "adaptive-threshold", cl::init(true), cl::Hidden,
+    cl::desc("Adaptively define a new threshold based on the application"));
+static cl::opt<bool> AdaptiveBands(
+    "adaptive-bands", cl::init(true), cl::Hidden,
+    cl::desc("Adaptively define the LSH geometry based on the application"));
+static cl::opt<double>
+    RankingDistance("ranking-distance", cl::init(1.0), cl::Hidden,
+                    cl::desc("Distance threshold used when ranking candidates"));
+static cl::opt<bool> EnableThunkPrediction(
+    "thunk-predictor", cl::init(false), cl::Hidden,
+    cl::desc("Enable dismissal of candidates caused by thunk "
+             "non-profitability"));
ReportStats("func-merging-report", cl::init(false), cl::Hidden); +static cl::opt MatcherStats("func-merging-matcher-report", cl::init(false), cl::Hidden); +static cl::opt Deterministic("func-merging-deterministic", cl::init(true), cl::Hidden); +static cl::opt BucketSizeCap("bucket-size-cap", cl::init(1000000000), cl::Hidden); + +// static cl::opt ExplorationThreshold( +// "func-merging-explore", cl::init(1), cl::Hidden, +// cl::desc("Exploration threshold of evaluated functions")); + +// static cl::opt RankingThreshold( +// "func-merging-ranking-threshold", cl::init(1), cl::Hidden, +// cl::desc("Threshold of how many candidates should be ranked")); + +// static cl::opt MergingOverheadThreshold( +// "func-merging-threshold", cl::init(1), cl::Hidden, +// cl::desc("Threshold of allowed overhead for merging function")); + +// static cl::opt +// MaxParamScore("func-merging-max-param", cl::init(false), cl::Hidden, +// cl::desc("Maximizing the score for merging parameters")); + +// static cl::opt Debug("func-merging-debug", cl::init(false), cl::Hidden, +// cl::desc("Outputs debug information")); + +// static cl::opt Verbose("func-merging-verbose", cl::init(false), +// cl::Hidden, cl::desc("Outputs debug information")); + +// static cl::opt +// IdenticalType("func-merging-identic-type", cl::init(false), cl::Hidden, +// cl::desc("Match only values with identical types")); + +// static cl::opt +// EnableUnifiedReturnType("func-merging-unify-return", cl::init(false), +// cl::Hidden, +// cl::desc("Enable unified return types")); + +// static cl::opt +// EnableOperandReordering("func-merging-operand-reorder", cl::init(false), +// cl::Hidden, cl::desc("Enable operand reordering")); + +// static cl::opt +// HasWholeProgram("func-merging-whole-program", cl::init(false), cl::Hidden, +// cl::desc("Function merging applied on whole program")); + +// static cl::opt +// EnableHyFMPA("func-merging-hyfm-pa", cl::init(false), cl::Hidden, +// cl::desc("Enable HyFM with the Pairwise Alignment")); + +// static cl::opt +// EnableHyFMNW("func-merging-hyfm-nw", cl::init(true), cl::Hidden, +// cl::desc("Enable HyFM with the Needleman-Wunsch alignment")); + +// static cl::opt EnableSALSSACoalescing( +// "func-merging-coalescing", cl::init(false), cl::Hidden, +// cl::desc("Enable phi-node coalescing during SSA reconstruction")); + +// static cl::opt ReuseMergedFunctions( +// "func-merging-reuse-merges", cl::init(false), cl::Hidden, +// cl::desc("Try to reuse merged functions for another merge operation")); + +// static cl::opt +// MaxNumSelection("func-merging-max-selects", cl::init(500), cl::Hidden, +// cl::desc("Maximum number of allowed operand selection")); + +// static cl::opt HyFMProfitability( +// "hyfm-profitability", cl::init(false), cl::Hidden, +// cl::desc("Try to reuse merged functions for another merge operation")); + +// static cl::opt EnableF3M( +// "func-merging-f3m", cl::init(false), cl::Hidden, +// cl::desc("Enable function pairing based on MinHashes and LSH")); + +// static cl::opt LSHRows( +// "hyfm-f3m-rows", cl::init(2), cl::Hidden, +// cl::desc("Number of rows in the LSH structure")); + +// static cl::opt LSHBands( +// "hyfm-f3m-bands", cl::init(100), cl::Hidden, +// cl::desc("Number of bands in the LSH structure")); + +// static cl::opt ShingleCrossBBs( +// "shingling-cross-basic-blocks", cl::init(true), cl::Hidden, +// cl::desc("Do shingles in MinHash cross basic blocks")); + +// static cl::opt AdaptiveThreshold( +// "adaptive-threshold", cl::init(false), cl::Hidden, +// cl::desc("Adaptively 
+static std::string GetValueName(const Value *V);
+
+#ifdef __unix__ /* __unix__ is usually defined by compilers targeting Unix \
+                   systems */
+
+unsigned long long getTotalSystemMemory() {
+  long pages = sysconf(_SC_PHYS_PAGES);
+  long page_size = sysconf(_SC_PAGE_SIZE);
+  return pages * page_size;
+}
+
+#elif defined(_WIN32) || \
+    defined(WIN32) /* _WIN32 is usually defined by compilers targeting 32- \
+                      or 64-bit Windows systems */
+
+unsigned long long getTotalSystemMemory() {
+  MEMORYSTATUSEX status;
+  status.dwLength = sizeof(status);
+  GlobalMemoryStatusEx(&status);
+  return status.ullTotalPhys;
+}
+
+#elif defined(__APPLE__)
+// add apple :feisen
+#include <sys/types.h>
+#include <sys/sysctl.h>
+unsigned long long getTotalSystemMemory() {
+  int mib[2];
+  mib[0] = CTL_HW;
+  mib[1] = HW_MEMSIZE;
+
+  unsigned long long physicalMemory;
+  size_t len = sizeof(physicalMemory);
+
+  if (sysctl(mib, 2, &physicalMemory, &len, NULL, 0) == 0)
+    return physicalMemory;
+  return 1024 * 1024; // query failed; fall back to a tiny default
+}
+
+#endif
+
+class FunctionMerging {
+public:
+  bool runImpl(Module &M) {
+    TargetTransformInfo TTI(M.getDataLayout());
+    auto GTTI = [&](Function &F) -> TargetTransformInfo * { return &TTI; };
+    return runImpl(M, GTTI);
+  }
+  bool runImpl(Module &M,
+               function_ref<TargetTransformInfo *(Function &)> GTTI);
+};
+
+FunctionMergeResult MergeFunctions(Function *F1, Function *F2,
+                                   const FunctionMergingOptions &Options) {
+  if (F1->getParent() != F2->getParent())
+    return FunctionMergeResult(F1, F2, nullptr);
+  FunctionMerger Merger(F1->getParent());
+  return Merger.merge(F1, F2, "", Options);
+}
+
+// Returns true iff the two numbers differ.
+static bool CmpNumbers(uint64_t L, uint64_t R) { return L != R; }
+
+// Any two pointers in the same address space are equivalent; intptr_t and
+// pointers are equivalent. Otherwise, standard type equivalence rules apply.
+// Returns true iff the two types differ under the equivalence rules above.
+static bool CmpTypes(Type *TyL, Type *TyR, const DataLayout *DL) {
+  auto *PTyL = dyn_cast<PointerType>(TyL);
+  auto *PTyR = dyn_cast<PointerType>(TyR);
+
+  // const DataLayout &DL = FnL->getParent()->getDataLayout();
+  if (PTyL && PTyL->getAddressSpace() == 0)
+    TyL = DL->getIntPtrType(TyL);
+  if (PTyR && PTyR->getAddressSpace() == 0)
+    TyR = DL->getIntPtrType(TyR);
+
+  if (TyL == TyR)
+    return false;
+
+  if (int Res = CmpNumbers(TyL->getTypeID(), TyR->getTypeID()))
+    return Res;
+
+  switch (TyL->getTypeID()) {
+  default:
+    llvm_unreachable("Unknown type!");
+  case Type::IntegerTyID:
+    return CmpNumbers(cast<IntegerType>(TyL)->getBitWidth(),
+                      cast<IntegerType>(TyR)->getBitWidth());
+  // TyL == TyR would have been caught earlier, because types are uniqued.
+  case Type::VoidTyID:
+  case Type::FloatTyID:
+  case Type::DoubleTyID:
+  case Type::X86_FP80TyID:
+  case Type::FP128TyID:
+  case Type::PPC_FP128TyID:
+  case Type::LabelTyID:
+  case Type::MetadataTyID:
+  case Type::TokenTyID:
+    return false;
+
+  case Type::PointerTyID:
+    assert(PTyL && PTyR && "Both types must be pointers here.");
+    return CmpNumbers(PTyL->getAddressSpace(), PTyR->getAddressSpace());
+
+  case Type::StructTyID: {
+    auto *STyL = cast<StructType>(TyL);
+    auto *STyR = cast<StructType>(TyR);
+    if (STyL->getNumElements() != STyR->getNumElements())
+      return CmpNumbers(STyL->getNumElements(), STyR->getNumElements());
+
+    if (STyL->isPacked() != STyR->isPacked())
+      return CmpNumbers(STyL->isPacked(), STyR->isPacked());
+
+    for (unsigned i = 0, e = STyL->getNumElements(); i != e; ++i) {
+      if (int Res =
+              CmpTypes(STyL->getElementType(i), STyR->getElementType(i), DL))
+        return Res;
+    }
+    return false;
+  }
+
+  case Type::FunctionTyID: {
+    auto *FTyL = cast<FunctionType>(TyL);
+    auto *FTyR = cast<FunctionType>(TyR);
+    if (FTyL->getNumParams() != FTyR->getNumParams())
+      return CmpNumbers(FTyL->getNumParams(), FTyR->getNumParams());
+
+    if (FTyL->isVarArg() != FTyR->isVarArg())
+      return CmpNumbers(FTyL->isVarArg(), FTyR->isVarArg());
+
+    if (int Res = CmpTypes(FTyL->getReturnType(), FTyR->getReturnType(), DL))
+      return Res;
+
+    for (unsigned i = 0, e = FTyL->getNumParams(); i != e; ++i) {
+      if (int Res = CmpTypes(FTyL->getParamType(i), FTyR->getParamType(i), DL))
+        return Res;
+    }
+    return false;
+  }
+
+  case Type::ArrayTyID: {
+    auto *ATyL = cast<ArrayType>(TyL);
+    auto *ATyR = cast<ArrayType>(TyR);
+    if (ATyL->getNumElements() != ATyR->getNumElements())
+      return CmpNumbers(ATyL->getNumElements(), ATyR->getNumElements());
+    return CmpTypes(ATyL->getElementType(), ATyR->getElementType(), DL);
+  }
+  case Type::FixedVectorTyID:
+  case Type::ScalableVectorTyID: {
+    auto *VTyL = cast<VectorType>(TyL);
+    auto *VTyR = cast<VectorType>(TyR);
+    if (VTyL->getElementCount().isScalable() !=
+        VTyR->getElementCount().isScalable())
+      return CmpNumbers(VTyL->getElementCount().isScalable(),
+                        VTyR->getElementCount().isScalable());
+    if (VTyL->getElementCount() != VTyR->getElementCount())
+      return CmpNumbers(VTyL->getElementCount().getKnownMinValue(),
+                        VTyR->getElementCount().getKnownMinValue());
+    return CmpTypes(VTyL->getElementType(), VTyR->getElementType(), DL);
+  }
+  }
+}
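For intuition about the canonicalization above: on a target with 64-bit pointers, an i64 and a pointer in address space 0 are treated as equivalent, because both are rewritten to the DataLayout's intptr type before comparison. A small sketch under that assumption (the helper name is illustrative, not part of the pass):

// Sketch: on a 64-bit DataLayout, i64 and a default-address-space pointer
// compare as "not different" under CmpTypes.
static bool intPtrMatchesPointer(LLVMContext &Ctx, const DataLayout &DL) {
  Type *I64 = Type::getInt64Ty(Ctx);
  Type *Ptr = PointerType::get(Ctx, /*AddressSpace=*/0);
  return !CmpTypes(I64, Ptr, &DL); // true when pointers are 64 bits wide
}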
+// Any two pointers in the same address space are equivalent; intptr_t and
+// pointers are equivalent. Otherwise, standard type equivalence rules apply.
+bool FunctionMerger::areTypesEquivalent(Type *Ty1, Type *Ty2,
+                                        const DataLayout *DL,
+                                        const FunctionMergingOptions &Options) {
+  if (Ty1 == Ty2)
+    return true;
+  if (Options.IdenticalTypesOnly)
+    return false;
+
+  return !CmpTypes(Ty1, Ty2, DL);
+}
+
+static bool matchIntrinsicCalls(Intrinsic::ID ID, const CallBase *CI1,
+                                const CallBase *CI2) {
+  Function *F = CI1->getCalledFunction();
+  if (!F)
+    return false;
+  auto ID1 = (Intrinsic::ID)F->getIntrinsicID();
+
+  F = CI2->getCalledFunction();
+  if (!F)
+    return false;
+  auto ID2 = (Intrinsic::ID)F->getIntrinsicID();
+
+  if (ID1 != ID)
+    return false;
+  if (ID1 != ID2)
+    return false;
+
+  switch (ID) {
+  default:
+    break;
+  case Intrinsic::coro_id: {
+    /*
+    auto *InfoArg = CS.getArgOperand(3)->stripPointerCasts();
+    if (isa<ConstantPointerNull>(InfoArg))
+      break;
+    auto *GV = dyn_cast<GlobalVariable>(InfoArg);
+    Assert(GV && GV->isConstant() && GV->hasDefinitiveInitializer(),
+           "info argument of llvm.coro.begin must refer to an initialized "
+           "constant");
+    Constant *Init = GV->getInitializer();
+    Assert(isa<ConstantStruct>(Init) || isa<ConstantArray>(Init),
+           "info argument of llvm.coro.begin must refer to either a struct or "
+           "an array");
+    */
+    break;
+  }
+  case Intrinsic::ctlz: // llvm.ctlz
+  case Intrinsic::cttz: // llvm.cttz
+    // is_zero_undef argument of bit counting intrinsics must be a constant int
+    return CI1->getArgOperand(1) == CI2->getArgOperand(1);
+  case Intrinsic::experimental_constrained_fadd:
+  case Intrinsic::experimental_constrained_fsub:
+  case Intrinsic::experimental_constrained_fmul:
+  case Intrinsic::experimental_constrained_fdiv:
+  case Intrinsic::experimental_constrained_frem:
+  case Intrinsic::experimental_constrained_fma:
+  case Intrinsic::experimental_constrained_sqrt:
+  case Intrinsic::experimental_constrained_pow:
+  case Intrinsic::experimental_constrained_powi:
+  case Intrinsic::experimental_constrained_sin:
+  case Intrinsic::experimental_constrained_cos:
+  case Intrinsic::experimental_constrained_exp:
+  case Intrinsic::experimental_constrained_exp2:
+  case Intrinsic::experimental_constrained_log:
+  case Intrinsic::experimental_constrained_log10:
+  case Intrinsic::experimental_constrained_log2:
+  case Intrinsic::experimental_constrained_rint:
+  case Intrinsic::experimental_constrained_nearbyint:
+    // visitConstrainedFPIntrinsic(
+    //     cast<ConstrainedFPIntrinsic>(*CS.getInstruction()));
+    break;
+  case Intrinsic::dbg_declare: // llvm.dbg.declare
+    // Assert(isa<MetadataAsValue>(CS.getArgOperand(0)),
+    //        "invalid llvm.dbg.declare intrinsic call 1", CS);
+    // visitDbgIntrinsic("declare",
+    //                   cast<DbgVariableIntrinsic>(*CS.getInstruction()));
+    break;
+  case Intrinsic::dbg_value: // llvm.dbg.value
+    // visitDbgIntrinsic("value",
+    //                   cast<DbgVariableIntrinsic>(*CS.getInstruction()));
+    break;
+  case Intrinsic::dbg_label: // llvm.dbg.label
+    // visitDbgLabelIntrinsic("label",
+    //                        cast<DbgLabelInst>(*CS.getInstruction()));
+    break;
+  case Intrinsic::memcpy:
+  case Intrinsic::memmove:
+  case Intrinsic::memset: {
+    // isvolatile argument of memory intrinsics must be a constant int
+    return CI1->getArgOperand(3) == CI2->getArgOperand(3);
+  }
+  case Intrinsic::memcpy_element_unordered_atomic:
+  case Intrinsic::memmove_element_unordered_atomic:
+  case Intrinsic::memset_element_unordered_atomic: {
+    const auto *AMI1 = cast<AtomicMemIntrinsic>(CI1);
+    const auto *AMI2 = cast<AtomicMemIntrinsic>(CI2);
+
+    auto *ElementSizeCI1 =
+        dyn_cast<ConstantInt>(AMI1->getRawElementSizeInBytes());
+
+    auto *ElementSizeCI2 =
+        dyn_cast<ConstantInt>(AMI2->getRawElementSizeInBytes());
+
+    return (ElementSizeCI1 != nullptr && ElementSizeCI1 == ElementSizeCI2);
+  }
+  case Intrinsic::gcroot:
+  case Intrinsic::gcwrite:
+  case Intrinsic::gcread:
+    // llvm.gcroot
+    // parameter #2 must be a constant.
+    return CI1->getArgOperand(1) == CI2->getArgOperand(1);
+  case Intrinsic::init_trampoline:
+    break;
+  case Intrinsic::prefetch:
+    // arguments #2 and #3 in llvm.prefetch must be constants
+    return CI1->getArgOperand(1) == CI2->getArgOperand(1) &&
+           CI1->getArgOperand(2) == CI2->getArgOperand(2);
+  case Intrinsic::stackprotector:
+    /*
+    Assert(isa<AllocaInst>(CS.getArgOperand(1)->stripPointerCasts()),
+           "llvm.stackprotector parameter #2 must resolve to an alloca.", CS);
+    */
+    break;
+  case Intrinsic::lifetime_start:
+  case Intrinsic::lifetime_end:
+  case Intrinsic::invariant_start:
+    // size argument of memory use markers must be a constant integer
+    return CI1->getArgOperand(0) == CI2->getArgOperand(0);
+  case Intrinsic::invariant_end:
+    // llvm.invariant.end parameter #2 must be a constant integer
+    return CI1->getArgOperand(1) == CI2->getArgOperand(1);
+  case Intrinsic::localescape: {
+    /*
+    BasicBlock *BB = CS.getParent();
+    Assert(BB == &BB->getParent()->front(),
+           "llvm.localescape used outside of entry block", CS);
+    Assert(!SawFrameEscape,
+           "multiple calls to llvm.localescape in one function", CS);
+    for (Value *Arg : CS.args()) {
+      if (isa<ConstantPointerNull>(Arg))
+        continue; // Null values are allowed as placeholders.
+      auto *AI = dyn_cast<AllocaInst>(Arg->stripPointerCasts());
+      Assert(AI && AI->isStaticAlloca(),
+             "llvm.localescape only accepts static allocas", CS);
+    }
+    FrameEscapeInfo[BB->getParent()].first = CS.arg_size();
+    SawFrameEscape = true;
+    */
+    break;
+  }
+  case Intrinsic::localrecover: {
+    /*
+    Value *FnArg = CS.getArgOperand(0)->stripPointerCasts();
+    Function *Fn = dyn_cast<Function>(FnArg);
+    Assert(Fn && !Fn->isDeclaration(),
+           "llvm.localrecover first "
+           "argument must be function defined in this module",
+           CS);
+    auto *IdxArg = dyn_cast<ConstantInt>(CS.getArgOperand(2));
+    Assert(IdxArg, "idx argument of llvm.localrecover must be a constant int",
+           CS);
+    auto &Entry = FrameEscapeInfo[Fn];
+    Entry.second = unsigned(
+        std::max(uint64_t(Entry.second), IdxArg->getLimitedValue(~0U) + 1));
+    */
+    break;
+  }
+  /*
+  case Intrinsic::experimental_gc_statepoint:
+    Assert(!CS.isInlineAsm(),
+           "gc.statepoint support for inline assembly unimplemented", CS);
+    Assert(CS.getParent()->getParent()->hasGC(),
+           "Enclosing function does not use GC.", CS);
+
+    verifyStatepoint(CS);
+    break;
+  case Intrinsic::experimental_gc_result: {
+    Assert(CS.getParent()->getParent()->hasGC(),
+           "Enclosing function does not use GC.", CS);
+    // Are we tied to a statepoint properly?
+    CallSite StatepointCS(CS.getArgOperand(0));
+    const Function *StatepointFn =
+        StatepointCS.getInstruction() ? StatepointCS.getCalledFunction()
+                                      : nullptr;
+    Assert(StatepointFn && StatepointFn->isDeclaration() &&
+               StatepointFn->getIntrinsicID() ==
+                   Intrinsic::experimental_gc_statepoint,
+           "gc.result operand #1 must be from a statepoint", CS,
+           CS.getArgOperand(0));
+
+    // Assert that result type matches wrapped callee.
+ const Value *Target = StatepointCS.getArgument(2); + auto *PT = cast(Target->getType()); + auto *TargetFuncType = cast(PT->getElementType()); + Assert(CS.getType() == TargetFuncType->getReturnType(), + "gc.result result type does not match wrapped callee", CS); + break; + } + case Intrinsic::experimental_gc_relocate: { + Assert(CS.get_size() == 3, "wrong number of arguments", CS); + + Assert(isa(CS.getType()->getScalarType()), + "gc.relocate must return a pointer or a vector of pointers", CS); + + // Check that this relocate is correctly tied to the statepoint + + // This is case for relocate on the unwinding path of an invoke statepoint + if (LandingPadInst *LandingPad = + dyn_cast(CS.getArgOperand(0))) { + + const BasicBlock *InvokeBB = + LandingPad->getParent()->getUniquePredecessor(); + + // Landingpad relocates should have only one predecessor with invoke + // statepoint terminator + Assert(InvokeBB, "safepoints should have unique landingpads", + LandingPad->getParent()); + Assert(InvokeBB->getTerminator(), "safepoint block should be well + formed", InvokeBB); Assert(isStatepoint(InvokeBB->getTerminator()), "gc + relocate should be linked to a statepoint", InvokeBB); + } + else { + // In all other cases relocate should be tied to the statepoint + directly. + // This covers relocates on a normal return path of invoke statepoint + and + // relocates of a call statepoint. + auto Token = CS.getArgOperand(0); + Assert(isa(Token) && + isStatepoint(cast(Token)), "gc relocate is incorrectly tied to + the statepoint", CS, Token); + } + + // Verify rest of the relocate arguments. + + ImmutableCallSite StatepointCS( + cast(*CS.getInstruction()).getStatepoint()); + + // Both the base and derived must be piped through the safepoint. + Value* Base = CS.getArgOperand(1); + Assert(isa(Base), + "gc.relocate operand #2 must be integer offset", CS); + + Value* Derived = CS.getArgOperand(2); + Assert(isa(Derived), + "gc.relocate operand #3 must be integer offset", CS); + + const int BaseIndex = cast(Base)->getZExtValue(); + const int DerivedIndex = cast(Derived)->getZExtValue(); + // Check the bounds + Assert(0 <= BaseIndex && BaseIndex < (int)StatepointCS.arg_size(), + "gc.relocate: statepoint base index out of bounds", CS); + Assert(0 <= DerivedIndex && DerivedIndex < (int)StatepointCS.arg_size(), + "gc.relocate: statepoint derived index out of bounds", CS); + + // Check that BaseIndex and DerivedIndex fall within the 'gc parameters' + // section of the statepoint's argument. 
+ Assert(StatepointCS.arg_size() > 0, + "gc.statepoint: insufficient arguments"); + Assert(isa(StatepointCS.getArgument(3)), + "gc.statement: number of call arguments must be constant integer"); + const unsigned NumCallArgs = + cast(StatepointCS.getArgument(3))->getZExtValue(); + Assert(StatepointCS.arg_size() > NumCallArgs + 5, + "gc.statepoint: mismatch in number of call arguments"); + Assert(isa(StatepointCS.getArgument(NumCallArgs + 5)), + "gc.statepoint: number of transition arguments must be " + "a constant integer"); + const int NumTransitionArgs = + cast(StatepointCS.getArgument(NumCallArgs + 5)) + ->getZExtValue(); + const int DeoptArgsStart = 4 + NumCallArgs + 1 + NumTransitionArgs + 1; + Assert(isa(StatepointCS.getArgument(DeoptArgsStart)), + "gc.statepoint: number of deoptimization arguments must be " + "a constant integer"); + const int NumDeoptArgs = + cast(StatepointCS.getArgument(DeoptArgsStart)) + ->getZExtValue(); + const int GCParamArgsStart = DeoptArgsStart + 1 + NumDeoptArgs; + const int GCParamArgsEnd = StatepointCS.arg_size(); + Assert(GCParamArgsStart <= BaseIndex && BaseIndex < GCParamArgsEnd, + "gc.relocate: statepoint base index doesn't fall within the " + "'gc parameters' section of the statepoint call", + CS); + Assert(GCParamArgsStart <= DerivedIndex && DerivedIndex < GCParamArgsEnd, + "gc.relocate: statepoint derived index doesn't fall within the " + "'gc parameters' section of the statepoint call", + CS); + + // Relocated value must be either a pointer type or vector-of-pointer + type, + // but gc_relocate does not need to return the same pointer type as the + // relocated pointer. It can be casted to the correct type later if it's + // desired. However, they must have the same address space and + 'vectorness' GCRelocateInst &Relocate = + cast(*CS.getInstruction()); + Assert(Relocate.getDerivedPtr()->getType()->isPtrOrPtrVectorTy(), + "gc.relocate: relocated value must be a gc pointer", CS); + + auto ResultType = CS.getType(); + auto DerivedType = Relocate.getDerivedPtr()->getType(); + Assert(ResultType->isVectorTy() == DerivedType->isVectorTy(), + "gc.relocate: vector relocates to vector and pointer to pointer", + CS); + Assert( + ResultType->getPointerAddressSpace() == + DerivedType->getPointerAddressSpace(), + "gc.relocate: relocating a pointer shouldn't change its address + space", CS); break; + } + case Intrinsic::eh_exceptioncode: + case Intrinsic::eh_exceptionpointer: { + Assert(isa(CS.getArgOperand(0)), + "eh.exceptionpointer argument must be a catchpad", CS); + break; + } + case Intrinsic::masked_load: { + Assert(CS.getType()->isVectorTy(), "masked_load: must return a vector", + CS); + + Value *Ptr = CS.getArgOperand(0); + //Value *Alignment = CS.getArgOperand(1); + Value *Mask = CS.getArgOperand(2); + Value *PassThru = CS.getArgOperand(3); + Assert(Mask->getType()->isVectorTy(), + "masked_load: mask must be vector", CS); + + // DataTy is the overloaded type + Type *DataTy = cast(Ptr->getType())->getElementType(); + Assert(DataTy == CS.getType(), + "masked_load: return must match pointer type", CS); + Assert(PassThru->getType() == DataTy, + "masked_load: pass through and data type must match", CS); + Assert(Mask->getType()->getVectorNumElements() == + DataTy->getVectorNumElements(), + "masked_load: vector mask must be same length as data", CS); + break; + } + case Intrinsic::masked_store: { + Value *Val = CS.getArgOperand(0); + Value *Ptr = CS.getArgOperand(1); + //Value *Alignment = CS.getArgOperand(2); + Value *Mask = CS.getArgOperand(3); + 
+    Assert(Mask->getType()->isVectorTy(),
+           "masked_store: mask must be vector", CS);
+
+    // DataTy is the overloaded type
+    Type *DataTy = cast<PointerType>(Ptr->getType())->getElementType();
+    Assert(DataTy == Val->getType(),
+           "masked_store: storee must match pointer type", CS);
+    Assert(Mask->getType()->getVectorNumElements() ==
+               DataTy->getVectorNumElements(),
+           "masked_store: vector mask must be same length as data", CS);
+    break;
+  }
+
+  case Intrinsic::experimental_guard: {
+    Assert(CS.isCall(), "experimental_guard cannot be invoked", CS);
+    Assert(CS.countOperandBundlesOfType(LLVMContext::OB_deopt) == 1,
+           "experimental_guard must have exactly one "
+           "\"deopt\" operand bundle");
+    break;
+  }
+
+  case Intrinsic::experimental_deoptimize: {
+    Assert(CS.isCall(), "experimental_deoptimize cannot be invoked", CS);
+    Assert(CS.countOperandBundlesOfType(LLVMContext::OB_deopt) == 1,
+           "experimental_deoptimize must have exactly one "
+           "\"deopt\" operand bundle");
+    Assert(CS.getType() == CS.getInstruction()->getFunction()->getReturnType(),
+           "experimental_deoptimize return type must match caller return "
+           "type");
+
+    if (CS.isCall()) {
+      auto *DeoptCI = CS.getInstruction();
+      auto *RI = dyn_cast<ReturnInst>(DeoptCI->getNextNode());
+      Assert(RI,
+             "calls to experimental_deoptimize must be followed by a return");
+
+      if (!CS.getType()->isVoidTy() && RI)
+        Assert(RI->getReturnValue() == DeoptCI,
+               "calls to experimental_deoptimize must be followed by a return "
+               "of the value computed by experimental_deoptimize");
+    }
+
+    break;
+  }
+  */
+  };
+  return false; // conservatively treat the remaining intrinsic cases as
+                // non-matching
+}
+
+// bool FunctionMerger::matchLandingPad(LandingPadInst *LP1, LandingPadInst
+// *LP2) {
+static bool matchLandingPad(LandingPadInst *LP1, LandingPadInst *LP2) {
+  if (LP1->getType() != LP2->getType())
+    return false;
+  if (LP1->isCleanup() != LP2->isCleanup())
+    return false;
+  if (LP1->getNumClauses() != LP2->getNumClauses())
+    return false;
+  for (unsigned i = 0; i < LP1->getNumClauses(); i++) {
+    if (LP1->isCatch(i) != LP2->isCatch(i))
+      return false;
+    if (LP1->isFilter(i) != LP2->isFilter(i))
+      return false;
+    if (LP1->getClause(i) != LP2->getClause(i))
+      return false;
+  }
+  return true;
+}
+
+static bool matchLoadInsts(const LoadInst *LI1, const LoadInst *LI2) {
+  return LI1->isVolatile() == LI2->isVolatile() &&
+         LI1->getAlign() == LI2->getAlign() &&
+         LI1->getOrdering() == LI2->getOrdering();
+}
+
+static bool matchStoreInsts(const StoreInst *SI1, const StoreInst *SI2) {
+  return SI1->isVolatile() == SI2->isVolatile() &&
+         SI1->getAlign() == SI2->getAlign() &&
+         SI1->getOrdering() == SI2->getOrdering();
+}
+
+static bool matchAllocaInsts(const AllocaInst *AI1, const AllocaInst *AI2) {
+  // feisen: debug — the allocated type also matters when matching allocas.
+  if (AI1->getArraySize() != AI2->getArraySize() ||
+      AI1->getAlign() != AI2->getAlign() ||
+      AI1->getAllocatedType() != AI2->getAllocatedType())
+    return false;
+
+  /*
+  // If size is known, I2 can be seen as equivalent to I1 if it allocates
+  // the same or less memory.
+  if (DL->getTypeAllocSize(AI->getAllocatedType())
+        < DL->getTypeAllocSize(cast<AllocaInst>(I2)->getAllocatedType()))
+    return false;
+  */
+
+  return true;
+}
+
+static bool matchGetElementPtrInsts(const GetElementPtrInst *GEP1,
+                                    const GetElementPtrInst *GEP2) {
+  Type *Ty1 = GEP1->getSourceElementType();
+  SmallVector<Value *, 8> Idxs1(GEP1->idx_begin(), GEP1->idx_end());
+
+  Type *Ty2 = GEP2->getSourceElementType();
+  SmallVector<Value *, 8> Idxs2(GEP2->idx_begin(), GEP2->idx_end());
+
+  if (Ty1 != Ty2)
+    return false;
+  if (Idxs1.size() != Idxs2.size())
+    return false;
+
+  if (Idxs1.empty())
+    return true;
+
+  for (unsigned i = 1; i < Idxs1.size(); i++) {
+    Value *V1 = Idxs1[i];
+    Value *V2 = Idxs2[i];
+
+    // Structs must have constant indices; therefore the indices must be
+    // identical constants for the GEPs to match.
+    if (isa<StructType>(Ty1)) {
+      if (V1 != V2)
+        return false;
+    }
+    Ty1 = GetElementPtrInst::getTypeAtIndex(Ty1, V1);
+    Ty2 = GetElementPtrInst::getTypeAtIndex(Ty2, V2);
+    if (Ty1 != Ty2)
+      return false;
+  }
+  return true;
+}
+
+static bool matchSwitchInsts(const SwitchInst *SI1, const SwitchInst *SI2) {
+  if (SI1->getNumCases() == SI2->getNumCases()) {
+    auto CaseIt1 = SI1->case_begin(), CaseEnd1 = SI1->case_end();
+    auto CaseIt2 = SI2->case_begin(), CaseEnd2 = SI2->case_end();
+    do {
+      if (CaseIt1->getCaseValue() != CaseIt2->getCaseValue())
+        return false; // TODO: could allow permutation!
+      ++CaseIt1;
+      ++CaseIt2;
+    } while (CaseIt1 != CaseEnd1 && CaseIt2 != CaseEnd2);
+    return true;
+  }
+  return false;
+}
+
+static bool matchCallInsts(const CallBase *CI1, const CallBase *CI2) {
+  if (CI1->isInlineAsm() || CI2->isInlineAsm())
+    return false;
+
+  // if (CI1->getCalledFunction()==nullptr) return false;
+
+  if (CI1->getCalledFunction() != CI2->getCalledFunction())
+    return false;
+
+  if (Function *F = CI1->getCalledFunction()) {
+    if (auto ID = (Intrinsic::ID)F->getIntrinsicID()) {
+      if (!matchIntrinsicCalls(ID, CI1, CI2))
+        return false;
+    }
+  }
+
+  return CI1->arg_size() == CI2->arg_size() &&
+         CI1->getCallingConv() == CI2->getCallingConv() &&
+         CI1->getAttributes() == CI2->getAttributes();
+}
+
+static bool matchInvokeInsts(const InvokeInst *II1, const InvokeInst *II2) {
+  return matchCallInsts(II1, II2) &&
+         II1->getCallingConv() == II2->getCallingConv() &&
+         II1->getAttributes() == II2->getAttributes() &&
+         matchLandingPad(II1->getLandingPadInst(), II2->getLandingPadInst());
+}
+
+static bool matchInsertValueInsts(const InsertValueInst *IV1,
+                                  const InsertValueInst *IV2) {
+  return IV1->getIndices() == IV2->getIndices();
+}
+
+static bool matchExtractValueInsts(const ExtractValueInst *EV1,
+                                   const ExtractValueInst *EV2) {
+  return EV1->getIndices() == EV2->getIndices();
+}
+
+static bool matchFenceInsts(const FenceInst *FI1, const FenceInst *FI2) {
+  return FI1->getOrdering() == FI2->getOrdering() &&
+         FI1->getSyncScopeID() == FI2->getSyncScopeID();
+}
+
+bool FunctionMerger::matchInstructions(Instruction *I1, Instruction *I2,
+                                       const FunctionMergingOptions &Options) {
+
+  if (I1->getOpcode() != I2->getOpcode())
+    return false;
+
+  if (I1->getOpcode() == Instruction::CallBr)
+    return false;
+
+  // Returns are special cases that can differ in the number of operands
+  if (I1->getOpcode() == Instruction::Ret)
+    return true;
+
+  if (I1->getNumOperands() != I2->getNumOperands())
+    return false;
+
+  const DataLayout *DL =
+      &I1->getParent()->getParent()->getParent()->getDataLayout();
+
+  bool sameType = false;
+  if (Options.IdenticalTypesOnly) {
+    sameType = (I1->getType() == I2->getType());
+    for (unsigned i = 0; i < I1->getNumOperands(); i++) {
+      sameType = sameType && (I1->getOperand(i)->getType() ==
+                              I2->getOperand(i)->getType());
+    }
+  } else {
+    sameType = areTypesEquivalent(I1->getType(), I2->getType(), DL, Options);
+    for (unsigned i = 0; i < I1->getNumOperands(); i++) {
+      sameType = sameType &&
+                 areTypesEquivalent(I1->getOperand(i)->getType(),
+                                    I2->getOperand(i)->getType(), DL, Options);
+    }
+  }
+  if (!sameType)
+    return false;
+
+  switch (I1->getOpcode()) {
+  // case Instruction::Br: return false; //{ return (I1->getNumOperands()==1); }
+
+  //#define MatchCaseInst(Kind, I1, I2) case Instruction::#Kind
+  case Instruction::ShuffleVector: // feisen: 2024/03/09
+    return cast<ShuffleVectorInst>(I1)->getShuffleMask() ==
+           cast<ShuffleVectorInst>(I2)->getShuffleMask();
+  case Instruction::Load:
+    return matchLoadInsts(dyn_cast<LoadInst>(I1), dyn_cast<LoadInst>(I2));
+  case Instruction::Store:
+    return matchStoreInsts(dyn_cast<StoreInst>(I1), dyn_cast<StoreInst>(I2));
+  case Instruction::Alloca:
+    return matchAllocaInsts(dyn_cast<AllocaInst>(I1), dyn_cast<AllocaInst>(I2));
+  case Instruction::GetElementPtr:
+    return matchGetElementPtrInsts(dyn_cast<GetElementPtrInst>(I1),
+                                   dyn_cast<GetElementPtrInst>(I2));
+  case Instruction::Switch:
+    return matchSwitchInsts(dyn_cast<SwitchInst>(I1), dyn_cast<SwitchInst>(I2));
+  case Instruction::Call:
+    return matchCallInsts(dyn_cast<CallInst>(I1), dyn_cast<CallInst>(I2));
+  case Instruction::Invoke:
+    return matchInvokeInsts(dyn_cast<InvokeInst>(I1), dyn_cast<InvokeInst>(I2));
+  case Instruction::InsertValue:
+    return matchInsertValueInsts(dyn_cast<InsertValueInst>(I1),
+                                 dyn_cast<InsertValueInst>(I2));
+  case Instruction::ExtractValue:
+    return matchExtractValueInsts(dyn_cast<ExtractValueInst>(I1),
+                                  dyn_cast<ExtractValueInst>(I2));
+  case Instruction::Fence:
+    return matchFenceInsts(dyn_cast<FenceInst>(I1), dyn_cast<FenceInst>(I2));
+  case Instruction::AtomicCmpXchg: {
+    const AtomicCmpXchgInst *CXI = dyn_cast<AtomicCmpXchgInst>(I1);
+    const AtomicCmpXchgInst *CXI2 = cast<AtomicCmpXchgInst>(I2);
+    return CXI->isVolatile() == CXI2->isVolatile() &&
+           CXI->isWeak() == CXI2->isWeak() &&
+           CXI->getSuccessOrdering() == CXI2->getSuccessOrdering() &&
+           CXI->getFailureOrdering() == CXI2->getFailureOrdering() &&
+           CXI->getSyncScopeID() == CXI2->getSyncScopeID();
+  }
+  case Instruction::AtomicRMW: {
+    const AtomicRMWInst *RMWI = dyn_cast<AtomicRMWInst>(I1);
+    return RMWI->getOperation() == cast<AtomicRMWInst>(I2)->getOperation() &&
+           RMWI->isVolatile() == cast<AtomicRMWInst>(I2)->isVolatile() &&
+           RMWI->getOrdering() == cast<AtomicRMWInst>(I2)->getOrdering() &&
+           RMWI->getSyncScopeID() == cast<AtomicRMWInst>(I2)->getSyncScopeID();
+  }
+  default:
+    if (auto *CI = dyn_cast<CmpInst>(I1))
+      return CI->getPredicate() == cast<CmpInst>(I2)->getPredicate();
+    if (isa<OverflowingBinaryOperator>(I1)) {
+      if (!isa<OverflowingBinaryOperator>(I2))
+        return false;
+      if (I1->hasNoUnsignedWrap() != I2->hasNoUnsignedWrap())
+        return false;
+      if (I1->hasNoSignedWrap() != I2->hasNoSignedWrap())
+        return false;
+    }
+    if (isa<PossiblyExactOperator>(I1)) {
+      if (!isa<PossiblyExactOperator>(I2))
+        return false;
+      if (I1->isExact() != I2->isExact())
+        return false;
+    }
+    if (isa<FPMathOperator>(I1)) {
+      if (!isa<FPMathOperator>(I2))
+        return false;
+      if (I1->isFast() != I2->isFast())
+        return false;
+      if (I1->hasAllowReassoc() != I2->hasAllowReassoc())
+        return false;
+      if (I1->hasNoNaNs() != I2->hasNoNaNs())
+        return false;
+      if (I1->hasNoInfs() != I2->hasNoInfs())
+        return false;
+      if (I1->hasNoSignedZeros() != I2->hasNoSignedZeros())
+        return false;
+      if (I1->hasAllowReciprocal() != I2->hasAllowReciprocal())
+        return false;
+      if (I1->hasAllowContract() != I2->hasAllowContract())
+        return false;
+      if (I1->hasApproxFunc() != I2->hasApproxFunc())
+        return false;
+    }
+  }
+
+  return true;
+}
+
+bool FunctionMerger::match(Value *V1, Value *V2) {
+  if (auto *I1 = dyn_cast<Instruction>(V1))
+    if (auto *I2 = dyn_cast<Instruction>(V2))
+      return matchInstructions(I1, I2);
+
+  if (auto *BB1 = dyn_cast<BasicBlock>(V1))
+    if (auto *BB2 = dyn_cast<BasicBlock>(V2))
+      return matchBlocks(BB1, BB2);
+
+  return false;
+}
+bool FunctionMerger::matchBlocks(BasicBlock *BB1, BasicBlock *BB2) {
+  if (BB1 == nullptr || BB2 == nullptr)
+    return false;
+  if (BB1->isLandingPad() || BB2->isLandingPad()) {
+    LandingPadInst *LP1 = BB1->getLandingPadInst();
+    LandingPadInst *LP2 = BB2->getLandingPadInst();
+    if (LP1 == nullptr || LP2 == nullptr)
+      return false;
+    return matchLandingPad(LP1, LP2);
+  }
+  return true;
+}
+
+bool FunctionMerger::matchWholeBlocks(Value *V1, Value *V2) {
+  auto *BB1 = dyn_cast<BasicBlock>(V1);
+  auto *BB2 = dyn_cast<BasicBlock>(V2);
+  if (BB1 == nullptr || BB2 == nullptr)
+    return false;
+
+  if (!matchBlocks(BB1, BB2))
+    return false;
+
+  auto It1 = BB1->begin();
+  auto It2 = BB2->begin();
+
+  while (isa<PHINode>(*It1) || isa<LandingPadInst>(*It1))
+    It1++;
+  while (isa<PHINode>(*It2) || isa<LandingPadInst>(*It2))
+    It2++;
+
+  while (It1 != BB1->end() && It2 != BB2->end()) {
+    if (!matchInstructions(&*It1, &*It2))
+      return false;
+
+    It1++;
+    It2++;
+  }
+
+  if (It1 != BB1->end() || It2 != BB2->end())
+    return false;
+
+  return true;
+}
+
+static unsigned
+RandomLinearizationOfBlocks(BasicBlock *BB,
+                            std::list<BasicBlock *> &OrderedBBs,
+                            std::set<BasicBlock *> &Visited) {
+  if (Visited.find(BB) != Visited.end())
+    return 0;
+  Visited.insert(BB);
+
+  Instruction *TI = BB->getTerminator();
+
+  std::vector<BasicBlock *> NextBBs;
+  for (unsigned i = 0; i < TI->getNumSuccessors(); i++)
+    NextBBs.push_back(TI->getSuccessor(i));
+  std::random_device rd;
+  std::shuffle(NextBBs.begin(), NextBBs.end(), std::mt19937(rd()));
+
+  unsigned SumSizes = 0;
+  for (BasicBlock *NextBlock : NextBBs)
+    SumSizes += RandomLinearizationOfBlocks(NextBlock, OrderedBBs, Visited);
+
+  OrderedBBs.push_front(BB);
+  return SumSizes + BB->size();
+}
+
+static unsigned
+RandomLinearizationOfBlocks(Function *F, std::list<BasicBlock *> &OrderedBBs) {
+  std::set<BasicBlock *> Visited;
+  return RandomLinearizationOfBlocks(&F->getEntryBlock(), OrderedBBs, Visited);
+}
+
+static unsigned
+CanonicalLinearizationOfBlocks(BasicBlock *BB,
+                               std::list<BasicBlock *> &OrderedBBs,
+                               std::set<BasicBlock *> &Visited) {
+  if (Visited.find(BB) != Visited.end())
+    return 0;
+  Visited.insert(BB);
+
+  Instruction *TI = BB->getTerminator();
+
+  unsigned SumSizes = 0;
+  for (unsigned i = 0; i < TI->getNumSuccessors(); i++) {
+    SumSizes += CanonicalLinearizationOfBlocks(TI->getSuccessor(i), OrderedBBs,
+                                               Visited);
+  }
+  // for (unsigned i = 1; i <= TI->getNumSuccessors(); i++) {
+  //   SumSizes += CanonicalLinearizationOfBlocks(
+  //       TI->getSuccessor(TI->getNumSuccessors() - i), OrderedBBs, Visited);
+  // }
+
+  OrderedBBs.push_front(BB);
+  return SumSizes + BB->size();
+}
+
+static unsigned
+CanonicalLinearizationOfBlocks(Function *F,
+                               std::list<BasicBlock *> &OrderedBBs) {
+  std::set<BasicBlock *> Visited;
+  return CanonicalLinearizationOfBlocks(&F->getEntryBlock(), OrderedBBs,
+                                        Visited);
+}
+
+static void vectorizeBB(SmallVectorImpl<Value *> &Vec, BasicBlock *BB) {
+  Vec.push_back(BB);
+  for (Instruction &I : *BB)
+    if (!isa<PHINode>(&I) && !isa<LandingPadInst>(&I))
+      Vec.push_back(&I);
+}
+
+void FunctionMerger::linearize(Function *F, SmallVectorImpl<Value *> &FVec,
+                               FunctionMerger::LinearizationKind LK) {
+  std::list<BasicBlock *> OrderedBBs;
+
+  unsigned FReserve = 0;
+  switch (LK) {
+  case LinearizationKind::LK_Random:
+    FReserve = RandomLinearizationOfBlocks(F, OrderedBBs);
+    break;
+  case LinearizationKind::LK_Canonical:
+  default:
+    FReserve = CanonicalLinearizationOfBlocks(F, OrderedBBs);
+    break;
+  }
+
+  FVec.reserve(FReserve + OrderedBBs.size());
+  for (BasicBlock *BB : OrderedBBs)
+    vectorizeBB(FVec, BB);
+}
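Linearization flattens a function into a single vector in which every BasicBlock appears as a label followed by its non-PHI, non-landingpad instructions. A hedged usage sketch, assuming LinearizationKind is accessible to the caller (the helper name is illustrative):

// Sketch: linearize F canonically and count labels vs. instructions.
static void countLinearized(FunctionMerger &Merger, Function *F,
                            unsigned &Labels, unsigned &Insts) {
  SmallVector<Value *, 8> FVec;
  Merger.linearize(F, FVec, FunctionMerger::LinearizationKind::LK_Canonical);
  Labels = Insts = 0;
  for (Value *V : FVec) {
    if (isa<BasicBlock>(V))
      ++Labels;
    else
      ++Insts;
  }
}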
&Options) { + bool EquivTypes = + areTypesEquivalent(F1->getReturnType(), F2->getReturnType(), DL, Options); + if (!EquivTypes && !F1->getReturnType()->isVoidTy() && + !F2->getReturnType()->isVoidTy()) { + return false; + } + return true; +} + +#ifdef TIME_STEPS_DEBUG +Timer TimeLin("Merge::CodeGen::Lin", "Merge::CodeGen::Lin"); +Timer TimeAlign("Merge::CodeGen::Align", "Merge::CodeGen::Align"); +Timer TimeAlignRank("Merge::CodeGen::Align::Rank", "Merge::CodeGen::Align::Rank"); +Timer TimeParam("Merge::CodeGen::Param", "Merge::CodeGen::Param"); +Timer TimeCodeGen("Merge::CodeGen::Gen", "Merge::CodeGen::Gen"); +Timer TimeCodeGenFix("Merge::CodeGen::Fix", "Merge::CodeGen::Fix"); +Timer TimePostOpt("Merge::CodeGen::PostOpt", "Merge::CodeGen::PostOpt"); +Timer TimeCodeGenTotal("Merge::CodeGen::Total", "Merge::CodeGen::Total"); + +Timer TimePreProcess("Merge::Preprocess", "Merge::Preprocess"); +Timer TimeRank("Merge::Rank", "Merge::Rank"); +Timer TimeVerify("Merge::Verify", "Merge::Verify"); +Timer TimeUpdate("Merge::Update", "Merge::Update"); +Timer TimePrinting("Merge::Printing", "Merge::Printing"); +Timer TimeTotal("Merge::Total", "Merge::Total"); + +std::chrono::time_point time_ranking_start; +std::chrono::time_point time_ranking_end; +std::chrono::time_point time_align_start; +std::chrono::time_point time_align_end; +std::chrono::time_point time_codegen_start; +std::chrono::time_point time_codegen_end; +std::chrono::time_point time_verify_start; +std::chrono::time_point time_verify_end; +std::chrono::time_point time_update_start; +std::chrono::time_point time_update_end; +std::chrono::time_point time_iteration_end; +#endif + + +static bool validMergePair(Function *F1, Function *F2) { + if (!HasWholeProgram && (F1->hasAvailableExternallyLinkage() || + F2->hasAvailableExternallyLinkage())) + return false; + + if (!HasWholeProgram && + (F1->hasLinkOnceLinkage() || F2->hasLinkOnceLinkage())) + return false; + + // if (!F1->getSection().equals(F2->getSection())) return false; + // if (F1->hasSection()!=F2->hasSection()) return false; + // if (F1->hasSection() && !F1->getSection().equals(F2->getSection())) return + // false; + + if (F1->hasComdat() != F2->hasComdat()) + return false; + if (F1->hasComdat() && F1->getComdat() != F2->getComdat()) + return false; + + if (F1->hasPersonalityFn() != F2->hasPersonalityFn()) + return false; + if (F1->hasPersonalityFn()) { + Constant *PersonalityFn1 = F1->getPersonalityFn(); + Constant *PersonalityFn2 = F2->getPersonalityFn(); + if (PersonalityFn1 != PersonalityFn2) + return false; + } + + return true; +} + +static void MergeArguments(LLVMContext &Context, Function *F1, Function *F2, + AlignedCode &AlignedSeq, + std::map &ParamMap1, + std::map &ParamMap2, + std::vector &Args, + const FunctionMergingOptions &Options) { + + std::vector ArgsList1; + for (Argument &arg : F1->args()) { + ArgsList1.push_back(&arg); + } + + Args.push_back(IntegerType::get(Context, 1)); // push the function Id argument + unsigned ArgId = 0; + for (auto I = F1->arg_begin(), E = F1->arg_end(); I != E; I++) { + ParamMap1[ArgId] = Args.size(); + Args.push_back((*I).getType()); + ArgId++; + } + + auto AttrList1 = F1->getAttributes(); + auto AttrList2 = F2->getAttributes(); + + // merge arguments from Function2 with Function1 + ArgId = 0; + for (auto I = F2->arg_begin(), E = F2->arg_end(); I != E; I++) { + + std::map MatchingScore; + // first try to find an argument with the same name/type + // otherwise try to match by type only + for (unsigned i = 0; i < ArgsList1.size(); i++) { + if 
+    for (unsigned i = 0; i < ArgsList1.size(); i++) {
+      if (ArgsList1[i]->getType() == (*I).getType()) {
+
+        auto AttrSet1 = AttrList1.getParamAttrs(ArgsList1[i]->getArgNo());
+        auto AttrSet2 = AttrList2.getParamAttrs((*I).getArgNo());
+        if (AttrSet1 != AttrSet2)
+          continue;
+
+        bool hasConflict = false; // check for conflict from a previous matching
+        for (auto ParamPair : ParamMap2) {
+          if (ParamPair.second == ParamMap1[i]) {
+            hasConflict = true;
+            break;
+          }
+        }
+        if (hasConflict)
+          continue;
+        MatchingScore[i] = 0;
+        if (!Options.MaximizeParamScore)
+          break; // if not maximizing the score, take the first match
+      }
+    }
+
+    // TODO: feisen — there is a known problem in this matching logic.
+    if (MatchingScore.size() > 0) { // maximize scores
+      for (auto &Entry : AlignedSeq) {
+        if (Entry.match()) {
+          auto *I1 = dyn_cast<Instruction>(Entry.get(0));
+          auto *I2 = dyn_cast<Instruction>(Entry.get(1));
+          if (I1 != nullptr && I2 != nullptr) { // test both for sanity
+            for (unsigned i = 0; i < I1->getNumOperands(); i++) {
+              for (auto KV : MatchingScore) {
+                if (I1->getOperand(i) == ArgsList1[KV.first]) {
+                  if (i < I2->getNumOperands() && I2->getOperand(i) == &(*I)) {
+                    MatchingScore[KV.first]++;
+                  }
+                }
+              }
+            }
+          }
+        }
+      }
+
+      int MaxScore = -1;
+      unsigned MaxId = 0;
+
+      for (auto KV : MatchingScore) {
+        if (KV.second > MaxScore) {
+          MaxScore = KV.second;
+          MaxId = KV.first;
+        }
+      }
+
+      ParamMap2[ArgId] = ParamMap1[MaxId];
+    } else {
+      ParamMap2[ArgId] = Args.size();
+      Args.push_back((*I).getType());
+    }
+
+    ArgId++;
+  }
+  // errs() << "Args.size() = " << Args.size() << "\n";
+  // Args[0]->print(errs());
+}
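To make the argument-coalescing score concrete: a candidate pairing of an argument of F2 with a type-compatible argument of F1 earns one point for every aligned instruction pair that uses the two arguments in the same operand slot, and the highest-scoring pairing is reused instead of appending a fresh parameter. Reduced to its core and detached from the pass's own types (all names illustrative):

#include <utility>
#include <vector>

// Sketch: score one candidate pairing (CandArg1 from F1, Arg2 from F2) by
// counting aligned instruction pairs that use them in the same operand slot.
// Each element of Uses is (argument used by I1 at slot k, argument used by I2
// at slot k), encoded as small integers.
static int scorePairing(const std::vector<std::pair<int, int>> &Uses,
                        int CandArg1, int Arg2) {
  int Score = 0;
  for (const auto &U : Uses)
    if (U.first == CandArg1 && U.second == Arg2)
      ++Score;
  return Score;
}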
(F1->hasPersonalityFn() && F2->hasPersonalityFn()) { + Constant *PersonalityFn1 = F1->getPersonalityFn(); + Constant *PersonalityFn2 = F2->getPersonalityFn(); + if (PersonalityFn1 == PersonalityFn2) { + MergedFunc->setPersonalityFn(PersonalityFn1); + } else { +#ifdef ENABLE_DEBUG_CODE + PersonalityFn1->dump(); + PersonalityFn2->dump(); +#endif + // errs() << "ERROR: different personality function!\n"; + if (Debug) + errs() << "WARNING: different personality function!\n"; + } + } else if (F1->hasPersonalityFn()) { + // errs() << "Only F1 has PersonalityFn\n"; + // TODO: check if this is valid: merge function with personality with + // function without it + MergedFunc->setPersonalityFn(F1->getPersonalityFn()); + if (Debug) + errs() << "WARNING: only one personality function!\n"; + } else if (F2->hasPersonalityFn()) { + // errs() << "Only F2 has PersonalityFn\n"; + // TODO: check if this is valid: merge function with personality with + // function without it + MergedFunc->setPersonalityFn(F2->getPersonalityFn()); + if (Debug) + errs() << "WARNING: only one personality function!\n"; + } + + if (F1->hasComdat() && F2->hasComdat()) { + auto *Comdat1 = F1->getComdat(); + auto *Comdat2 = F2->getComdat(); + if (Comdat1 == Comdat2) { + MergedFunc->setComdat(Comdat1); + } else if (Debug) { + errs() << "WARNING: different comdats!\n"; + } + } else if (F1->hasComdat()) { + // errs() << "Only F1 has Comdat\n"; + MergedFunc->setComdat(F1->getComdat()); // TODO: check if this is valid: + // merge function with comdat with + // function without it + if (Debug) + errs() << "WARNING: only one comdat!\n"; + } else if (F2->hasComdat()) { + // errs() << "Only F2 has Comdat\n"; + MergedFunc->setComdat(F2->getComdat()); // TODO: check if this is valid: + // merge function with comdat with + // function without it + if (Debug) + errs() << "WARNING: only one comdat!\n"; + } + + //feisen:debug:attributes of function : merge function attributes + // for(int i = Attribute::AttrKind::None; i < Attribute::AttrKind::EndAttrKinds; i++) { + // if(F1->hasFnAttribute((Attribute::AttrKind)i) && F2->hasFnAttribute((Attribute::AttrKind)i)) { + // // if(F1->getFnAttribute((Attribute::AttrKind)i) == F2->getFnAttribute((Attribute::AttrKind)i)) { + // MergedFunc->addFnAttr(F1->getFnAttribute((Attribute::AttrKind)i)); + // // } + // } + // } + + + if (F1->hasSection()) { + MergedFunc->setSection(F1->getSection()); + } +} + +static Value *createCastIfNeeded(Value *V, Type *DstType, IRBuilder<> &Builder, + Type *IntPtrTy, + const FunctionMergingOptions &Options = {}); + +/* +bool CodeGenerator(Value *IsFunc1, BasicBlock *EntryBB1, BasicBlock *EntryBB2, +BasicBlock *PreBB, std::list> &AlignedInsts, + ValueToValueMapTy &VMap, Function *MergedFunc, +Type *RetType1, Type *RetType2, Type *ReturnType, bool RequiresUnifiedReturn, +LLVMContext &Context, Type *IntPtrTy, const FunctionMergingOptions &Options = +{}) { +*/ + +void FunctionMerger::CodeGenerator::destroyGeneratedCode() { + for (Instruction *I : CreatedInsts) { + I->dropAllReferences(); + } + for (Instruction *I : CreatedInsts) { + I->eraseFromParent(); + } + for (BasicBlock *BB : CreatedBBs) { + BB->eraseFromParent(); + } + CreatedInsts.clear(); + CreatedBBs.clear(); +} + +unsigned instToInt(Instruction *I); + +inst_range getInstructions(Function *F) { return instructions(F); } + +iterator_range getInstructions(BasicBlock *BB) { + return make_range(BB->begin(), BB->end()); +} + + +template class FingerprintMH { +private: + // The number of instructions defining a shingle. 
2 or 3 is best.
+  static constexpr size_t K = 2;
+  static constexpr double threshold = 0.3;
+  static constexpr size_t MaxOpcode = 68;
+  const uint32_t _footprint;
+
+public:
+  uint64_t magnitude{0};
+  std::vector<uint32_t> hash;
+  std::vector<uint32_t> bandHash;
+
+public:
+  FingerprintMH() = default;
+
+  FingerprintMH(T owner, SearchStrategy &searchStrategy)
+      : _footprint(searchStrategy.item_footprint()) {
+    std::vector<uint32_t> integers;
+    std::array<uint32_t, MaxOpcode> OpcodeFreq;
+
+    for (size_t i = 0; i < MaxOpcode; i++)
+      OpcodeFreq[i] = 0;
+
+    if (ShingleCrossBBs) {
+      for (Instruction &I : getInstructions(owner)) {
+        integers.push_back(instToInt(&I));
+        OpcodeFreq[I.getOpcode()]++;
+        if (I.isTerminator())
+          OpcodeFreq[0] += I.getNumSuccessors();
+      }
+    } else {
+      for (BasicBlock &BB : *owner) {
+
+        // Process normal instructions
+        for (Instruction &I : BB) {
+          integers.push_back(instToInt(&I));
+          OpcodeFreq[I.getOpcode()]++;
+          if (I.isTerminator())
+            OpcodeFreq[0] += I.getNumSuccessors();
+        }
+
+        // Add dummy instructions between basic blocks
+        for (size_t i = 0; i(integers, hash);
+    searchStrategy.generateBands(hash, bandHash);
+  }
+
+  uint32_t footprint() const { return _footprint; }
+
+  float distance(const FingerprintMH &FP2) const {
+    size_t nintersect = 0;
+    size_t pos1 = 0;
+    size_t pos2 = 0;
+    size_t nHashes = hash.size();
+
+    while (pos1 != nHashes && pos2 != nHashes) {
+      if (hash[pos1] == FP2.hash[pos2]) {
+        nintersect++;
+        pos1++;
+        pos2++;
+      } else if (hash[pos1] < FP2.hash[pos2]) {
+        pos1++;
+      } else {
+        pos2++;
+      }
+    }
+
+    int nunion = 2 * nHashes - nintersect;
+    return 1.f - (nintersect / (float)nunion);
+  }
+
+  float distance_under(const FingerprintMH &FP2, float best_distance) const {
+    size_t mismatches = 0;
+    size_t pos1 = 0;
+    size_t pos2 = 0;
+    size_t nHashes = hash.size();
+    size_t best_nintersect = static_cast<size_t>(
+        2.0 * nHashes * (1.f - best_distance) / (2.f - best_distance));
+    size_t best_mismatches = 2 * (nHashes - best_nintersect);
+
+    while (pos1 != nHashes && pos2 != nHashes) {
+      if (hash[pos1] == FP2.hash[pos2]) {
+        pos1++;
+        pos2++;
+      } else if (hash[pos1] < FP2.hash[pos2]) {
+        mismatches++;
+        pos1++;
+      } else {
+        mismatches++;
+        pos2++;
+      }
+      if (mismatches > best_mismatches)
+        break;
+    }
+
+    size_t nintersect = nHashes - (mismatches / 2);
+    int nunion = 2 * nHashes - nintersect;
+    return 1.f - (nintersect / (float)nunion);
+  }
+};
+
+template <typename T> class Fingerprint {
+public:
+  uint64_t magnitude{0};
+  static const size_t MaxOpcode = 68;
+  std::array<uint32_t, MaxOpcode> OpcodeFreq;
+
+  Fingerprint() = default;
+
+  Fingerprint(T owner) {
+    // feisen:debug may have segmentation fault bbb
+    assert(owner != nullptr);
+    // memset(OpcodeFreq, 0, sizeof(int) * MaxOpcode);
+    for (size_t i = 0; i < MaxOpcode; i++)
+      OpcodeFreq[i] = 0;
+
+    for (Instruction &I : getInstructions(owner)) {
+      // feisen:debug---
+      if (I.getOpcode() >= MaxOpcode) { // getOpcode() is unsigned, so only
+                                        // the upper bound can be violated
+        errs() << "Opcode is " << I.getOpcode() << "\n";
+        continue;
+      }
+      OpcodeFreq[I.getOpcode()]++;
+      magnitude++;
+      if (I.isTerminator())
+        OpcodeFreq[0] += I.getNumSuccessors();
+    }
+  }
+
+  uint32_t footprint() const { return sizeof(uint32_t) * MaxOpcode; }
+
+  // Manhattan distance over the opcode-frequency vectors.
+  float distance(const Fingerprint &FP2) const {
+    int Distance = 0;
+    for (size_t i = 0; i < MaxOpcode; i++)
+      Distance += std::abs((int)OpcodeFreq[i] - (int)FP2.OpcodeFreq[i]);
+    return static_cast<float>(Distance);
+  }
+};
+
+class BlockFingerprint : public Fingerprint<BasicBlock *> {
+public:
+  BasicBlock *BB{nullptr};
+  size_t Size{0};
+
+  BlockFingerprint(BasicBlock *BB) : Fingerprint(BB), BB(BB) {
+    for (Instruction &I : *BB) {
+      if (!isa<PHINode>(&I) && !isa<LandingPadInst>(&I)) {
+        Size++;
+      }
+    }
+  }
+};
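+
+// Worked example for FingerprintMH::distance above (illustrative numbers):
+// with nHashes = 4, hash1 = {2,5,7,9} and hash2 = {2,6,7,9}, the sorted sweep
+// finds nintersect = 3, so nunion = 2*4 - 3 = 5 and the distance is
+// 1 - 3/5 = 0.4, an estimate of the Jaccard distance between the two shingle
+// sets. The LSH tables built from bandHash later catch a pair of similarity s
+// in at least one band with probability 1 - (1 - s^rows)^bands.
+
+template <typename T> class MatchInfo {
+public:
+  T candidate{nullptr};
+  size_t Size{0};
+  size_t OtherSize{0};
+  size_t MergedSize{0};
+  size_t Magnitude{0};
+  size_t OtherMagnitude{0};
+  float Distance{0};
+  bool Valid{false};
+  bool Profitable{false};
+
+  MatchInfo() = default;
+  MatchInfo(T candidate) : candidate(candidate){};
+  MatchInfo(T candidate,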
size_t Size) : candidate(candidate), Size(Size) {}; +}; + +template class Matcher { +public: + Matcher() = default; + virtual ~Matcher() = default; + + virtual void add_candidate(T candidate, size_t size) = 0; + virtual void remove_candidate(T candidate) = 0; + virtual T next_candidate() = 0; + virtual std::vector> &get_matches(T candidate) = 0; + virtual size_t size() = 0; + virtual void print_stats() = 0; +}; + +template class FPTy = Fingerprint> class MatcherManual : public Matcher{ +private: + struct MatcherEntry { + T candidate; + size_t size; + FPTy FP; + MatcherEntry() : MatcherEntry(nullptr, 0){}; + + template, typename T2 = Fingerprint> + MatcherEntry(T candidate, size_t size, + typename std::enable_if_t::value, int> * = nullptr) + : candidate(candidate), size(size), FP(candidate){} + + template , typename T2 = FingerprintMH> + MatcherEntry(T candidate, size_t size, SearchStrategy &strategy, + typename std::enable_if_t::value, int> * = nullptr) + : candidate(candidate), size(size), FP(candidate, strategy){} + }; + using MatcherIt = typename std::list::iterator; + + bool initialized{false}; + FunctionMerger &FM; + FunctionMergingOptions &Options; + std::list candidates; + std::unordered_map cache; + std::vector> matches; + std::unordered_map matchNames; + +public: + MatcherManual() = default; + MatcherManual(FunctionMerger &FM, FunctionMergingOptions &Options, std::string Filename) + : FM(FM), Options(Options) { + std::ifstream File{Filename}; + std::string FuncName1, FuncName2; + while (File >> FuncName1 >> FuncName2) { + matchNames[FuncName1] = FuncName2; + matchNames[FuncName2] = FuncName1; + } + } + + virtual ~MatcherManual() = default; + + void add_candidate(T candidate, size_t size) override { + if (matchNames.count(GetValueName(candidate)) == 0) + return; + add_candidate_helper(candidate, size); + cache[candidate] = candidates.begin(); + } + + template, typename T2 = Fingerprint> + void add_candidate_helper(T candidate, size_t size, + typename std::enable_if_t::value, int> * = nullptr) + { + candidates.emplace_front(candidate, size); + } + + void remove_candidate(T candidate) override { + auto cache_it = cache.find(candidate); + assert(cache_it != cache.end()); + candidates.erase(cache_it->second); + } + + T next_candidate() override { + if (!initialized) { + candidates.sort([&](auto &item1, auto &item2) -> bool { + return item1.FP.magnitude > item2.FP.magnitude; + }); + initialized = true; + } + update_matches(candidates.begin()); + return candidates.front().candidate; + } + + std::vector> &get_matches(T candidate) override { + return matches; + } + + size_t size() override { return candidates.size(); } + + void print_stats() override { + int Sum = 0; + int Count = 0; + float MinDistance = std::numeric_limits::max(); + float MaxDistance = 0; + + int Index1 = 0; + for (auto It1 = candidates.begin(), E1 = candidates.end(); It1!=E1; It1++) { + + int BestIndex = 0; + bool FoundCandidate = false; + float BestDist = std::numeric_limits::max(); + + unsigned CountCandidates = 0; + int Index2 = Index1; + for (auto It2 = It1, E2 = candidates.end(); It2 != E2; It2++) { + + if (It1->candidate == It2->candidate || Index1 == Index2) { + Index2++; + continue; + } + + if ((!FM.validMergeTypes(It1->candidate, It2->candidate, Options) && + !Options.EnableUnifiedReturnType) || + !validMergePair(It1->candidate, It2->candidate)) + continue; + + auto Dist = It1->FP.distance(It2->FP); + if (Dist < BestDist) { + BestDist = Dist; + FoundCandidate = true; + BestIndex = Index2; + } + if 
(RankingThreshold && CountCandidates > RankingThreshold) { + break; + } + CountCandidates++; + Index2++; + } + if (FoundCandidate) { + int Distance = std::abs(Index1 - BestIndex); + Sum += Distance; + if (Distance > MaxDistance) MaxDistance = Distance; + if (Distance < MinDistance) MinDistance = Distance; + Count++; + } + Index1++; + } + if(Debug){ + errs() << "Total: " << Count << "\n"; + errs() << "Min Distance: " << MinDistance << "\n"; + errs() << "Max Distance: " << MaxDistance << "\n"; + errs() << "Average Distance: " << (((double)Sum)/((double)Count)) << "\n"; + } + } + + +private: + void update_matches(MatcherIt it) { + matches.clear(); + + MatchInfo best_match; + best_match.OtherSize = it->size; + best_match.OtherMagnitude = it->FP.magnitude; + best_match.Distance = std::numeric_limits::max(); + + for (auto entry = std::next(candidates.cbegin()); entry != candidates.cend(); ++entry) { + if ((!FM.validMergeTypes(it->candidate, entry->candidate, Options) && + !Options.EnableUnifiedReturnType) || + !validMergePair(it->candidate, entry->candidate)) + continue; + if (matchNames[GetValueName(it->candidate)] == GetValueName(entry->candidate)) { + best_match.candidate = entry->candidate; + best_match.Size = entry->size; + best_match.Magnitude = entry->FP.magnitude; + best_match.Distance = 0; + break; + } + } + + if (best_match.candidate != nullptr) + matches.push_back(std::move(best_match)); + return; + } +}; + +template class FPTy = Fingerprint> class MatcherFQ : public Matcher{ +private: + struct MatcherEntry { + T candidate; + size_t size; + FPTy FP; + MatcherEntry() : MatcherEntry(nullptr, 0){}; + + template, typename T2 = Fingerprint> + MatcherEntry(T candidate, size_t size, + typename std::enable_if_t::value, int> * = nullptr) + : candidate(candidate), size(size), FP(candidate){} + + template , typename T2 = FingerprintMH> + MatcherEntry(T candidate, size_t size, SearchStrategy &strategy, + typename std::enable_if_t::value, int> * = nullptr) + : candidate(candidate), size(size), FP(candidate, strategy){} + }; + using MatcherIt = typename std::list::iterator; + + bool initialized{false}; + FunctionMerger &FM; + FunctionMergingOptions &Options; + std::list candidates; + std::unordered_map cache; + std::vector> matches; + SearchStrategy strategy; + +public: + MatcherFQ() = default; + MatcherFQ(FunctionMerger &FM, FunctionMergingOptions &Options, size_t rows=2, size_t bands=100) + : FM(FM), Options(Options), strategy(rows, bands){}; + + virtual ~MatcherFQ() = default; + + void add_candidate(T candidate, size_t size) override { + add_candidate_helper(candidate, size); + cache[candidate] = candidates.begin(); + } + + template, typename T2 = Fingerprint> + void add_candidate_helper(T candidate, size_t size, + typename std::enable_if_t::value, int> * = nullptr) + { + candidates.emplace_front(candidate, size); + } + + template, typename T2 = Fingerprint> + void add_candidate_helper(T candidate, size_t size, + typename std::enable_if_t::value, int> * = nullptr) + { + candidates.emplace_front(candidate, size, strategy); + } + + void remove_candidate(T candidate) override { + auto cache_it = cache.find(candidate); + assert(cache_it != cache.end()); + candidates.erase(cache_it->second); + } + + T next_candidate() override { + if (!initialized) { + candidates.sort([&](auto &item1, auto &item2) -> bool { + return item1.FP.magnitude > item2.FP.magnitude; + }); + initialized = true; + } + update_matches(candidates.begin()); + return candidates.front().candidate; + } + + std::vector> &get_matches(T 
candidate) override { + return matches; + } + + size_t size() override { return candidates.size(); } + + void print_stats() override { + int Sum = 0; + int Count = 0; + float MinDistance = std::numeric_limits::max(); + float MaxDistance = 0; + + int Index1 = 0; + for (auto It1 = candidates.begin(), E1 = candidates.end(); It1!=E1; It1++) { + + int BestIndex = 0; + bool FoundCandidate = false; + float BestDist = std::numeric_limits::max(); + + unsigned CountCandidates = 0; + int Index2 = Index1; + for (auto It2 = It1, E2 = candidates.end(); It2 != E2; It2++) { + + if (It1->candidate == It2->candidate || Index1 == Index2) { + Index2++; + continue; + } + + if ((!FM.validMergeTypes(It1->candidate, It2->candidate, Options) && + !Options.EnableUnifiedReturnType) || + !validMergePair(It1->candidate, It2->candidate)) + continue; + + auto Dist = It1->FP.distance(It2->FP); + if (Dist < BestDist) { + BestDist = Dist; + FoundCandidate = true; + BestIndex = Index2; + } + if (RankingThreshold && CountCandidates > RankingThreshold) { + break; + } + CountCandidates++; + Index2++; + } + if (FoundCandidate) { + int Distance = std::abs(Index1 - BestIndex); + Sum += Distance; + if (Distance > MaxDistance) MaxDistance = Distance; + if (Distance < MinDistance) MinDistance = Distance; + Count++; + } + Index1++; + } + if(Debug){ + errs() << "Total: " << Count << "\n"; + errs() << "Min Distance: " << MinDistance << "\n"; + errs() << "Max Distance: " << MaxDistance << "\n"; + errs() << "Average Distance: " << (((double)Sum)/((double)Count)) << "\n"; + } + } + + +private: + void update_matches(MatcherIt it) { + size_t CountCandidates = 0; + matches.clear(); + + MatchInfo best_match; + best_match.OtherSize = it->size; + best_match.OtherMagnitude = it->FP.magnitude; + best_match.Distance = std::numeric_limits::max(); + + if (ExplorationThreshold == 1) { + for (auto entry = std::next(candidates.cbegin()); entry != candidates.cend(); ++entry) { + if ((!FM.validMergeTypes(it->candidate, entry->candidate, Options) && + !Options.EnableUnifiedReturnType) || + !validMergePair(it->candidate, entry->candidate)) + continue; + auto new_distance = it->FP.distance(entry->FP); + if (new_distance < best_match.Distance) { + best_match.candidate = entry->candidate; + best_match.Size = entry->size; + best_match.Magnitude = entry->FP.magnitude; + best_match.Distance = new_distance; + } + if (RankingThreshold && (CountCandidates > RankingThreshold)) + break; + CountCandidates++; + } + if (best_match.candidate != nullptr) + if (!EnableF3M || best_match.Distance < RankingDistance) + /*if (EnableThunkPrediction) + { + if (std::max(best_match.size, best_match.OtherSize) + EstimateThunkOverhead(it->candidate, best_match->candidate)) // Needs AlwaysPreserved + return; + }*/ + matches.push_back(std::move(best_match)); + return; + } + + for (auto &entry : candidates) { + if (entry.candidate == it->candidate) + continue; + if ((!FM.validMergeTypes(it->candidate, entry.candidate, Options) && + !Options.EnableUnifiedReturnType) || + !validMergePair(it->candidate, entry.candidate)) + continue; + MatchInfo new_match(entry.candidate, entry.size); + new_match.Distance = it->FP.distance(entry.FP); + new_match.OtherSize = it->size; + new_match.OtherMagnitude = it->FP.magnitude; + new_match.Magnitude = entry.FP.magnitude; + if (!EnableF3M || new_match.Distance < RankingDistance) + matches.push_back(std::move(new_match)); + if (RankingThreshold && (CountCandidates > RankingThreshold)) + break; + CountCandidates++; + } + + + if (ExplorationThreshold < 
matches.size()) { + std::partial_sort(matches.begin(), matches.begin() + ExplorationThreshold, + matches.end(), [&](auto &match1, auto &match2) -> bool { + return match1.Distance < match2.Distance; + }); + matches.resize(ExplorationThreshold); + std::reverse(matches.begin(), matches.end()); + } else { + std::sort(matches.begin(), matches.end(), + [&](auto &match1, auto &match2) -> bool { + return match1.Distance > match2.Distance; + }); + } + } +}; + +template class MatcherLSH : public Matcher { +private: + struct MatcherEntry { + T candidate; + size_t size; + FingerprintMH FP; + MatcherEntry() : MatcherEntry(nullptr, 0){}; + MatcherEntry(T candidate, size_t size, SearchStrategy &strategy) + : candidate(candidate), size(size), + FP(candidate, strategy){}; + }; + using MatcherIt = typename std::list::iterator; + + bool initialized{false}; + const size_t rows{2}; + const size_t bands{100}; + FunctionMerger &FM; + FunctionMergingOptions &Options; + SearchStrategy strategy; + + std::list candidates; + std::unordered_map> lsh; + std::vector> cache; + std::vector> matches; + +public: + MatcherLSH() = default; + MatcherLSH(FunctionMerger &FM, FunctionMergingOptions &Options, size_t rows, size_t bands) + : rows(rows), bands(bands), FM(FM), Options(Options), strategy(rows, bands) {}; + + virtual ~MatcherLSH() = default; + + void add_candidate(T candidate, size_t size) override { + candidates.emplace_front(candidate, size, strategy); + + auto it = candidates.begin(); + auto &bandHash = it->FP.bandHash; + for (size_t i = 0; i < bands; ++i) { + if (lsh.count(bandHash[i]) > 0) + lsh.at(bandHash[i]).push_back(it); + else + lsh.insert(std::make_pair(bandHash[i], std::vector(1, it))); + } + } + + void remove_candidate(T candidate) override { + auto cache_it = candidates.end(); + for (auto &cache_item : cache) { + if (cache_item.first == candidate) { + cache_it = cache_item.second; + break; + } + } + assert(cache_it != candidates.end()); + + auto &FP = cache_it->FP; + for (size_t i = 0; i < bands; ++i) { + if (lsh.count(FP.bandHash[i]) == 0) + continue; + + auto &foundFs = lsh.at(FP.bandHash[i]); + for (size_t j = 0; j < foundFs.size(); ++j) + if (foundFs[j]->candidate == candidate) + lsh.at(FP.bandHash[i]).erase(lsh.at(FP.bandHash[i]).begin() + j); + } + candidates.erase(cache_it); + } + + T next_candidate() override { + if (!initialized) { + candidates.sort([&](auto &item1, auto &item2) -> bool { + return item1.FP.magnitude > item2.FP.magnitude; + }); + initialized = true; + } + update_matches(candidates.begin()); + return candidates.front().candidate; + } + + std::vector> &get_matches(T candidate) override { + return matches; + } + + size_t size() override { return candidates.size(); } + + void print_stats() override { + std::unordered_set seen; + std::vector hist_bucket_size(20); + std::vector hist_distances(21); + std::vector hist_distances_diff(21); + uint32_t duplicate_hashes = 0; + + for (auto it = lsh.cbegin(); it != lsh.cend(); ++it) { + size_t idx = 31 - __builtin_clz(it->second.size()); + idx = idx < 20 ? 
idx : 19; + hist_bucket_size[idx]++; + } + for (size_t i = 0; i < 20; i++) + errs() << "STATS: Histogram Bucket Size " << (1 << i) << " : " << hist_bucket_size[i] << "\n"; + return; + + for (auto it = candidates.begin(); it != candidates.end(); ++it) { + seen.clear(); + seen.reserve(candidates.size() / 10); + + float best_distance = std::numeric_limits::max(); + std::unordered_set temp(it->FP.hash.begin(), it->FP.hash.end()); + duplicate_hashes += it->FP.hash.size() - temp.size(); + + for (size_t i = 0; i < bands; ++i) { + auto &foundFs = lsh.at(it->FP.bandHash[i]); + size_t idx = 31 - __builtin_clz(foundFs.size()); + idx = idx < 20 ? idx : 19; + hist_bucket_size[idx]++; + for (size_t j = 0; j < foundFs.size(); ++j) { + auto match_it = foundFs[j]; + if ((match_it->candidate == NULL) || + (match_it->candidate == it->candidate)) + continue; + if ((!FM.validMergeTypes(it->candidate, match_it->candidate, Options) && + !Options.EnableUnifiedReturnType) || + !validMergePair(it->candidate, match_it->candidate)) + continue; + + if (seen.count(match_it->candidate) == 1) + continue; + seen.insert(match_it->candidate); + + auto distance = it->FP.distance(match_it->FP); + best_distance = distance < best_distance ? distance : best_distance; + auto idx2 = static_cast(distance * 20); + idx2 = idx2 < 21 ? idx2 : 20; + hist_distances[idx2]++; + auto idx3 = static_cast((distance - best_distance) * 20); + idx3 = idx3 < 21 ? idx3 : 20; + hist_distances_diff[idx3]++; + } + } + } + if(Debug){ + errs() << "STATS: Avg Duplicate Hashes: " << (1.0*duplicate_hashes) / candidates.size() << "\n"; + for (size_t i = 0; i < 20; i++) + errs() << "STATS: Histogram Bucket Size " << (1 << i) << " : " << hist_bucket_size[i] << "\n"; + for (size_t i = 0; i < 21; i++) + errs() << "STATS: Histogram Distances " << i * 0.05 << " : " << hist_distances[i] << "\n"; + for (size_t i = 0; i < 21; i++) + errs() << "STATS: Histogram Distances Diff " << i * 0.05 << " : " << hist_distances_diff[i] << "\n"; + } + } + +private: + void update_matches(MatcherIt it) { + size_t CountCandidates = 0; + std::unordered_set seen; + seen.reserve(candidates.size() / 10); + matches.clear(); + cache.clear(); + cache.emplace_back(it->candidate, it); + + auto &FP = it->FP; + MatchInfo best_match; + best_match.Distance = std::numeric_limits::max(); + for (size_t i = 0; i < bands; ++i) { + assert(lsh.count(FP.bandHash[i]) > 0); + + auto &foundFs = lsh.at(FP.bandHash[i]); + for (size_t j = 0; j < foundFs.size() && j < BucketSizeCap; ++j) { + auto match_it = foundFs[j]; + if ((match_it->candidate == NULL) || + (match_it->candidate == it->candidate)) + continue; + if ((!FM.validMergeTypes(it->candidate, match_it->candidate, Options) && + !Options.EnableUnifiedReturnType) || + !validMergePair(it->candidate, match_it->candidate)) + continue; + + if (seen.count(match_it->candidate) == 1) + continue; + seen.insert(match_it->candidate); + + MatchInfo new_match(match_it->candidate, match_it->size); + if (best_match.Distance < 0.1) + new_match.Distance = FP.distance_under(match_it->FP, best_match.Distance); + else + new_match.Distance = FP.distance(match_it->FP); + new_match.OtherSize = it->size; + new_match.OtherMagnitude = FP.magnitude; + new_match.Magnitude = match_it->FP.magnitude; + if (new_match.Distance < best_match.Distance && new_match.Distance < RankingDistance ) + best_match = new_match; + if (ExplorationThreshold > 1) + if (new_match.Distance < RankingDistance) + matches.push_back(new_match); + cache.emplace_back(match_it->candidate, match_it); + if 
(RankingThreshold && (CountCandidates > RankingThreshold)) + break; + CountCandidates++; + } + // If we've gone through i = 0 without finding a distance of 0.0 + // the minimum distance we might ever find is 2.0 / (nHashes + 1) + if ((ExplorationThreshold == 1) && (best_match.Distance < (2.0 / (rows * bands) ))) + break; + if (RankingThreshold && (CountCandidates > RankingThreshold)) + break; + } + + if (ExplorationThreshold == 1) + if (best_match.candidate != nullptr) + matches.push_back(std::move(best_match)); + + if (matches.size() <= 1) + return; + + size_t toRank = std::min((size_t)ExplorationThreshold, matches.size()); + + std::partial_sort(matches.begin(), matches.begin() + toRank, matches.end(), + [&](auto &match1, auto &match2) -> bool { + return match1.Distance < match2.Distance; + }); + matches.resize(toRank); + std::reverse(matches.begin(), matches.end()); + } +}; + + +template class MatcherReport { +private: + struct MatcherEntry { + T candidate; + Fingerprint FPF; + FingerprintMH FPMH; + MatcherEntry(T candidate, SearchStrategy &strategy) + : candidate(candidate), FPF(candidate), FPMH(candidate, strategy){}; + }; + using MatcherIt = typename std::list::iterator; + + FunctionMerger &FM; + FunctionMergingOptions &Options; + SearchStrategy strategy; + std::vector candidates; + +public: + MatcherReport() = default; + MatcherReport(size_t rows, size_t bands, FunctionMerger &FM, FunctionMergingOptions &Options) + : FM(FM), Options(Options), strategy(rows, bands) {}; + + ~MatcherReport() = default; + + void add_candidate(T candidate) { + candidates.emplace_back(candidate, strategy); + } + + void report() const { + char distance_mh_str[20]; + + for (auto &entry: candidates) { + uint64_t val = 0; + for (auto &num: entry.FPF.OpcodeFreq) + val += num; + if(Debug){ + errs() << "Function Name: " << GetValueName(entry.candidate) + << " Fingerprint Size: " << val << "\n"; + } + } + + std::string Name("_m_f_"); + for (auto it1 = candidates.cbegin(); it1 != candidates.cend(); ++it1) { + for (auto it2 = std::next(it1); it2 != candidates.cend(); ++it2) { + if ((!FM.validMergeTypes(it1->candidate, it2->candidate, Options) && + !Options.EnableUnifiedReturnType) || + !validMergePair(it1->candidate, it2->candidate)) + continue; + + auto distance_fq = it1->FPF.distance(it2->FPF); + auto distance_mh = it1->FPMH.distance(it2->FPMH); + std::snprintf(distance_mh_str, 20, "%.5f", distance_mh); + if(Debug){ + errs() << "F1: " << it1 - candidates.cbegin() << " + " + << "F2: " << it2 - candidates.cbegin() << " " + << "FQ: " << static_cast(distance_fq) << " " + << "MH: " << distance_mh_str << "\n"; + } + FunctionMergeResult Result = FM.merge(it1->candidate, it2->candidate, Name, Options); + } + } + } +}; + +AlignedCode::AlignedCode(BasicBlock *BB1, BasicBlock *BB2) { + // this should never happen + assert(BB1 != nullptr || BB2 != nullptr); + + // Add only BB1, skipping Phi nodes and Landing Pads + if (BB1 != nullptr && BB2 == nullptr) { + Data.emplace_back(BB1, nullptr, false); + for (Instruction &I : *BB1) { + if (isa(&I) || isa(&I)) + continue; + Data.emplace_back(&I, nullptr, false); + } + return; + } + + // Add only BB2, skipping Phi nodes and Landing Pads + if (BB1 == nullptr && BB2 != nullptr) { + Data.emplace_back(nullptr, BB2, false); + for (Instruction &I : *BB2) { + if (isa(&I) || isa(&I)) + continue; + Data.emplace_back(nullptr, &I, false); + } + return; + } + + // Add both, skipping Phi nodes and Landing Pads + Data.emplace_back(BB1, BB2, FunctionMerger::matchBlocks(BB1, BB2)); + + auto It1 = 
BB1->begin();
+  while (isa<PHINode>(*It1) || isa<LandingPadInst>(*It1))
+    It1++;
+
+  auto It2 = BB2->begin();
+  while (isa<PHINode>(*It2) || isa<LandingPadInst>(*It2))
+    It2++;
+
+  while (It1 != BB1->end() && It2 != BB2->end()) {
+    Instruction *I1 = &*It1;
+    Instruction *I2 = &*It2;
+
+    if (FunctionMerger::matchInstructions(I1, I2)) {
+      Data.emplace_back(I1, I2, true);
+    } else {
+      Data.emplace_back(I1, nullptr, false);
+      Data.emplace_back(nullptr, I2, false);
+    }
+
+    It1++;
+    It2++;
+  }
+  assert((It1 == BB1->end()) && (It2 == BB2->end()));
+}
+
+bool AlignedCode::isProfitable() const {
+  int OriginalCost = 0;
+  int MergedCost = 0;
+
+  bool InsideSplit = false;
+
+  for (auto &Entry : Data) {
+    Instruction *I1 = nullptr;
+    if (Entry.get(0))
+      I1 = dyn_cast<Instruction>(Entry.get(0));
+
+    Instruction *I2 = nullptr;
+    if (Entry.get(1))
+      I2 = dyn_cast<Instruction>(Entry.get(1));
+
+    bool IsInstruction = I1 != nullptr || I2 != nullptr;
+    if (Entry.match()) {
+      if (IsInstruction) {
+        OriginalCost += 2;
+        MergedCost += 1;
+      }
+      if (InsideSplit) {
+        InsideSplit = false;
+        MergedCost += 2;
+      }
+    } else {
+      if (IsInstruction) {
+        OriginalCost += 1;
+        MergedCost += 1;
+      }
+      if (!InsideSplit) {
+        InsideSplit = true;
+        MergedCost += 1;
+      }
+    }
+  }
+
+  bool Profitable = (MergedCost <= OriginalCost);
+  if (Verbose)
+    errs() << ((Profitable) ? "Profitable" : "Unprofitable") << "\n";
+  return Profitable;
+}
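+
+// Worked example of the cost model above, counting instruction entries only:
+// aligning {a, b, c} with {a, x, c} matches a and c (OriginalCost += 2 and
+// MergedCost += 1 each) and leaves b and x in a split region (OriginalCost
+// += 1 and MergedCost += 1 each, plus 1 to enter the split and 2 to re-join
+// at the next match). That gives OriginalCost = 6 and MergedCost = 7, so
+// this alignment alone would be rejected.
+
+void AlignedCode::extend(const AlignedCode &Other) {
+  for (auto &Entry : Other) {
+    Instruction *I1 = nullptr;
+    if (Entry.get(0))
+      I1 = dyn_cast<Instruction>(Entry.get(0));
+
+    Instruction *I2 = nullptr;
+    if (Entry.get(1))
+      I2 = dyn_cast<Instruction>(Entry.get(1));
+
+    bool IsInstruction = I1 != nullptr || I2 != nullptr;
+
+    Data.emplace_back(Entry.get(0), Entry.get(1), Entry.match());
+
+    if (IsInstruction) {
+      Insts++;
+      if (Entry.match()) {
+        Matches++;
+        Instruction *I = I1 ?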
I1 : I2; + if (!I->isTerminator()) + CoreMatches++; + } + } + } +} + +bool AcrossBlocks; + +FunctionMergeResult +FunctionMerger::merge(Function *F1, Function *F2, std::string Name, const FunctionMergingOptions &Options) { + bool ProfitableFn = true; + LLVMContext &Context = *ContextPtr; + FunctionMergeResult ErrorResponse(F1, F2, nullptr); + + if (!validMergePair(F1, F2)) + return ErrorResponse; + +#ifdef TIME_STEPS_DEBUG + TimeAlign.startTimer(); + time_align_start = std::chrono::steady_clock::now(); +#endif + + AlignedCode AlignedSeq; + NeedlemanWunschSA> SA(ScoringSystem(-1, 2), FunctionMerger::match); + + if (EnableHyFMNW || EnableHyFMPA) { // Processing individual pairs of blocks + + int B1Max{0}, B2Max{0}; + size_t MaxMem{0}; + + int NumBB1{0}, NumBB2{0}; + size_t MemSize{0}; + +#ifdef TIME_STEPS_DEBUG + TimeAlignRank.startTimer(); +#endif + + // Fingerprints for all Blocks in F1 organized by size + std::map> Blocks; + for (BasicBlock &BB1 : *F1) { + BlockFingerprint BD1(&BB1); + NumBB1++; + MemSize += BD1.footprint(); + Blocks[BD1.Size].push_back(std::move(BD1)); + } + +#ifdef TIME_STEPS_DEBUG + TimeAlignRank.stopTimer(); +#endif + + for (BasicBlock &BIt : *F2) { +#ifdef TIME_STEPS_DEBUG + TimeAlignRank.startTimer(); +#endif + BasicBlock *BB2 = &BIt; + BlockFingerprint BD2(BB2); + NumBB2++; + MemSize += BD2.footprint(); + + // list all the map entries in Blocks in order of distance from BD2.Size + auto ItSetIncr = Blocks.lower_bound(BD2.Size); + // auto ItSetDecr = std::reverse_iterator(ItSetIncr); //todo: fix bug :feisen + // auto ItSetDecr = std::reverse_iterator>::iterator>(ItSetIncr); // + auto ItSetDecr = std::reverse_iterator>::iterator>(ItSetIncr); // fix bug + + + std::vector ItSets; + + if (EnableHyFMNW) { + while (ItSetDecr != Blocks.rend() && ItSetIncr != Blocks.end()) { + if (BD2.Size - ItSetDecr->first < ItSetIncr->first - BD2.Size){ + ItSets.push_back(std::prev(ItSetDecr.base())); + ItSetDecr++; + } else { + ItSets.push_back(ItSetIncr); + ItSetIncr++; + } + } + + while (ItSetDecr != Blocks.rend()) { + ItSets.push_back(std::prev(ItSetDecr.base())); + ItSetDecr++; + } + + while (ItSetIncr != Blocks.end()) { + ItSets.push_back(ItSetIncr); + ItSetIncr++; + } + } else { + ItSetIncr = Blocks.find(BD2.Size); + if (ItSetIncr != Blocks.end()) + ItSets.push_back(ItSetIncr); + } + + // Find the closest block starting from blocks with similar size + std::vector::iterator BestIt; + std::map>::iterator BestSet; + float BestDist = std::numeric_limits::max(); + + for (auto ItSet : ItSets) { + for (auto BDIt = ItSet->second.begin(), E = ItSet->second.end(); BDIt != E; BDIt++) { + auto D = BD2.distance(*BDIt); + if (D < BestDist) { + BestDist = D; + BestIt = BDIt; + BestSet = ItSet; + if (BestDist < std::numeric_limits::epsilon()) + break; + } + } + if (BestDist < std::numeric_limits::epsilon()) + break; + } + +#ifdef TIME_STEPS_DEBUG + TimeAlignRank.stopTimer(); +#endif + + bool MergedBlock = false; + if (BestDist < std::numeric_limits::max()) { + BasicBlock *BB1 = BestIt->BB; + AlignedCode AlignedBlocks; + + if (EnableHyFMNW) { + SmallVector BB1Vec; + vectorizeBB(BB1Vec, BB1); + + SmallVector BB2Vec; + vectorizeBB(BB2Vec, BB2); + + AlignedBlocks = SA.getAlignment(BB1Vec, BB2Vec); + + if (Verbose) { + auto MemReq = SA.getMemoryRequirement(BB1Vec, BB2Vec); + errs() << "MStats: " << BB1Vec.size() << " , " << BB2Vec.size() << " , " << MemReq << "\n"; + + if (MemReq > MaxMem) { + MaxMem = MemReq; + B1Max = BB1Vec.size(); + B2Max = BB2Vec.size(); + } + } + } else if (EnableHyFMPA) { + 
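+        // HyFM-PA avoids the quadratic alignment entirely: only blocks with
+        // an identical fingerprint Size are candidates (Blocks.find(BD2.Size)
+        // above), and AlignedCode's two-block constructor pairs them in
+        // lockstep, e.g. {a1, a2, a3} against {b1, b2, b3} yields the entries
+        // (a1,b1), (a2,b2), (a3,b3), each one matched only if
+        // matchInstructions accepts the pair.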
AlignedBlocks = AlignedCode(BB1, BB2); + + if (Verbose) { + auto MemReq = AlignedBlocks.size() * (sizeof(AlignedCode::Entry) + 2 * sizeof(void*)); + errs() << "MStats: " << BB1->size() << " , " << BB2->size() << " , " << MemReq << "\n"; + + if (MemReq > MaxMem) { + MaxMem = MemReq; + B1Max = BB1->size(); + B2Max = BB2->size(); + } + } + } + + if (!HyFMProfitability || AlignedBlocks.isProfitable()) { + AlignedSeq.extend(AlignedBlocks); + BestSet->second.erase(BestIt); + MergedBlock = true; + } + } + + if (!MergedBlock) + AlignedSeq.extend(AlignedCode(nullptr, BB2)); + } + + for (auto &Pair : Blocks) + for (auto &BD1 : Pair.second) + AlignedSeq.extend(AlignedCode(BD1.BB, nullptr)); + + if (Verbose) { + errs() << "SStats: " << B1Max << " , " << B2Max << " , " << MaxMem << "\n"; + errs() << "RStats: " << NumBB1 << " , " << NumBB2 << " , " << MemSize << "\n"; + } + + ProfitableFn = AlignedSeq.hasMatches(); + + } else { //default SALSSA + SmallVector F1Vec; + SmallVector F2Vec; + +#ifdef TIME_STEPS_DEBUG + TimeLin.startTimer(); +#endif + linearize(F1, F1Vec); + linearize(F2, F2Vec); +#ifdef TIME_STEPS_DEBUG + TimeLin.stopTimer(); +#endif + + auto MemReq = SA.getMemoryRequirement(F1Vec, F2Vec); + auto MemAvailable = getTotalSystemMemory(); + if(Debug) + errs() << "MStats: " << F1Vec.size() << " , " << F2Vec.size() << " , " << MemReq << "\n"; + if (MemReq > MemAvailable * 0.9) { + errs() << "Insufficient Memory\n"; +#ifdef TIME_STEPS_DEBUG + TimeAlign.stopTimer(); + time_align_end = std::chrono::steady_clock::now(); +#endif + return ErrorResponse; + } + + AlignedSeq = SA.getAlignment(F1Vec, F2Vec); + } + +#ifdef TIME_STEPS_DEBUG + TimeAlign.stopTimer(); + time_align_end = std::chrono::steady_clock::now(); +#endif + if (!ProfitableFn && !ReportStats) { + if (Verbose) + errs() << "Skipped: Not profitable enough!!\n"; + return ErrorResponse; + } + + unsigned NumMatches = 0; + unsigned TotalEntries = 0; + AcrossBlocks = false; + BasicBlock *CurrBB0 = nullptr; + BasicBlock *CurrBB1 = nullptr; + for (auto &Entry : AlignedSeq) { + TotalEntries++; + if (Entry.match()) { + NumMatches++; + if (isa(Entry.get(1))) { + CurrBB1 = cast(Entry.get(1)); + } else if (auto *I = dyn_cast(Entry.get(1))) { + if (CurrBB1 == nullptr) + CurrBB1 = I->getParent(); + else if (CurrBB1 != I->getParent()) { + AcrossBlocks = true; + } + } + if (isa(Entry.get(0))) { + CurrBB0 = cast(Entry.get(0)); + } else if (auto *I = dyn_cast(Entry.get(0))) { + if (CurrBB0 == nullptr) + CurrBB0 = I->getParent(); + else if (CurrBB0 != I->getParent()) { + AcrossBlocks = true; + } + } + } else { + if (isa_and_nonnull(Entry.get(0))) + CurrBB1 = nullptr; + if (isa_and_nonnull(Entry.get(1))) + CurrBB0 = nullptr; + } + } + if (AcrossBlocks) { + if (Verbose) { + errs() << "Across Basic Blocks\n"; + } + } + if (Verbose || ReportStats) { + errs() << "Matches: " << NumMatches << ", " << TotalEntries << ", " << ( (double) NumMatches/ (double) TotalEntries) << "\n"; + } + + if (ReportStats) + return ErrorResponse; + + // errs() << "Code Gen\n"; +#ifdef ENABLE_DEBUG_CODE + if (Verbose) { + for (auto &Entry : AlignedSeq) { + if (Entry.match()) { + errs() << "1: "; + if (isa(Entry.get(0))) + errs() << "BB " << GetValueName(Entry.get(0)) << "\n"; + else + Entry.get(0)->dump(); + errs() << "2: "; + if (isa(Entry.get(1))) + errs() << "BB " << GetValueName(Entry.get(1)) << "\n"; + else + Entry.get(1)->dump(); + errs() << "----\n"; + } else { + if (Entry.get(0)) { + errs() << "1: "; + if (isa(Entry.get(0))) + errs() << "BB " << GetValueName(Entry.get(0)) << "\n"; + 
else + Entry.get(0)->dump(); + errs() << "2: -\n"; + } else if (Entry.get(1)) { + errs() << "1: -\n"; + errs() << "2: "; + if (isa(Entry.get(1))) + errs() << "BB " << GetValueName(Entry.get(1)) << "\n"; + else + Entry.get(1)->dump(); + } + errs() << "----\n"; + } + } + } +#endif + +#ifdef TIME_STEPS_DEBUG + TimeParam.startTimer(); +#endif + + // errs() << "Creating function type\n"; + + // Merging parameters + std::map ParamMap1; + std::map ParamMap2; + std::vector Args; + + // errs() << "Merging arguments\n"; + MergeArguments(Context, F1, F2, AlignedSeq, ParamMap1, ParamMap2, Args, + Options); + + Type *RetType1 = F1->getReturnType(); + Type *RetType2 = F2->getReturnType(); + Type *ReturnType = nullptr; + + bool RequiresUnifiedReturn = false; + + // Value *RetUnifiedAddr = nullptr; + // Value *RetAddr1 = nullptr; + // Value *RetAddr2 = nullptr; + + if (validMergeTypes(F1, F2, Options)) { + // errs() << "Simple return types\n"; + ReturnType = RetType1; + if (ReturnType->isVoidTy()) { + ReturnType = RetType2; + } + } else if (Options.EnableUnifiedReturnType) { + // errs() << "Unifying return types\n"; + RequiresUnifiedReturn = true; + + auto SizeOfTy1 = DL->getTypeStoreSize(RetType1); + auto SizeOfTy2 = DL->getTypeStoreSize(RetType2); + if (SizeOfTy1 >= SizeOfTy2) { + ReturnType = RetType1; + } else { + ReturnType = RetType2; + } + } else { +#ifdef TIME_STEPS_DEBUG + TimeParam.stopTimer(); +#endif + return ErrorResponse; + } + FunctionType *FTy = + FunctionType::get(ReturnType, ArrayRef(Args), false); + + if (Name.empty()) { + // Name = ".m.f"; + Name = "_m_f"; + } + /* + if (!HasWholeProgram) { + Name = M->getModuleIdentifier() + std::string("."); + } + Name = Name + std::string("m.f"); + */ + Function *MergedFunc = + Function::Create(FTy, // GlobalValue::LinkageTypes::InternalLinkage, + GlobalValue::LinkageTypes::PrivateLinkage, Twine(Name), + M); // merged.function + // MergedFunc->setLinkage(GlobalValue::ExternalLinkage); + +//feisen: + // Function *MergedFunc = + // Function::Create(FTy, // GlobalValue::LinkageTypes::InternalLinkage, + // GlobalValue::LinkageTypes::ExternalLinkage, Twine(Name), + // M); // merged.function + + // errs() << "Initializing VMap\n"; + ValueToValueMapTy VMap; + + std::vector ArgsList; + for (Argument &arg : MergedFunc->args()) { + ArgsList.push_back(&arg); + } + Value *FuncId = ArgsList[0]; + + //feisen:debug:attribute + + // for(int i = Attribute::AttrKind::None; i < Attribute::AttrKind::EndAttrKinds; i++) { + // if(F1->hasFnAttribute((Attribute::AttrKind)i) && F2->hasFnAttribute((Attribute::AttrKind)i)) { + // // if(F1->getFnAttribute((Attribute::AttrKind)i) == F2->getFnAttribute((Attribute::AttrKind)i)) { + // MergedFunc->addFnAttr(F1->getFnAttribute((Attribute::AttrKind)i)); + // // break; + // // } + // } + // } + + ////TODO: merging attributes might create compilation issues if we are not careful. + ////Therefore, attributes are not being merged right now. 
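+  //// A conservative alternative (a sketch, assuming the AttrBuilder
+  //// constructor that takes an LLVMContext) would keep only the attributes
+  //// the two matched parameters have in common, since copying one side
+  //// verbatim (e.g. 'noalias') can be unsound for calls coming from the
+  //// other function:
+  ////   AttrBuilder Common(Context);
+  ////   for (const Attribute &A : AttrSet1)
+  ////     if (A.isEnumAttribute() && AttrSet2.hasAttribute(A.getKindAsEnum()))
+  ////       Common.addAttribute(A.getKindAsEnum());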
+ //auto AttrList1 = F1->getAttributes(); + //auto AttrList2 = F2->getAttributes(); + //auto AttrListM = MergedFunc->getAttributes(); + + int ArgId = 0; + for (auto I = F1->arg_begin(), E = F1->arg_end(); I != E; I++) { + VMap[&(*I)] = ArgsList[ParamMap1[ArgId]]; + + //auto AttrSet1 = AttrList1.getParamAttributes((*I).getArgNo()); + //AttrBuilder Attrs(AttrSet1); + //AttrListM = AttrListM.addParamAttributes( + // Context, ArgsList[ParamMap1[ArgId]]->getArgNo(), Attrs); + + ArgId++; + } + + ArgId = 0; + for (auto I = F2->arg_begin(), E = F2->arg_end(); I != E; I++) { + VMap[&(*I)] = ArgsList[ParamMap2[ArgId]]; + + //auto AttrSet2 = AttrList2.getParamAttributes((*I).getArgNo()); + //AttrBuilder Attrs(AttrSet2); + //AttrListM = AttrListM.addParamAttributes( + // Context, ArgsList[ParamMap2[ArgId]]->getArgNo(), Attrs); + + ArgId++; + } + //MergedFunc->setAttributes(AttrListM); + +#ifdef TIME_STEPS_DEBUG + TimeParam.stopTimer(); +#endif + + // errs() << "Setting attributes\n"; + SetFunctionAttributes(F1, F2, MergedFunc); + + Value *IsFunc1 = FuncId; + + // errs() << "Running code generator\n"; + + auto Gen = [&](auto &CG) { + CG.setFunctionIdentifier(IsFunc1) + .setEntryPoints(&F1->getEntryBlock(), &F2->getEntryBlock()) + .setReturnTypes(RetType1, RetType2) + .setMergedFunction(MergedFunc) + .setMergedEntryPoint(BasicBlock::Create(Context, "entry", MergedFunc)) + .setMergedReturnType(ReturnType, RequiresUnifiedReturn) + .setContext(ContextPtr) + .setIntPtrType(IntPtrTy); + if (!CG.generate(AlignedSeq, VMap, Options)) { + // F1->dump(); + // F2->dump(); + // MergedFunc->dump(); + if(Debug){ + errs()<<"f1:\n"; + F1->print(errs()); + errs()<<"f2:\n"; + F2->print(errs()); + // errs()<<"merged:\n"; + // MergedFunc->print(errs()); + } + + MergedFunc->eraseFromParent(); + MergedFunc = nullptr; + if (Debug) + errs() << "ERROR: Failed to generate the merged function!\n"; + //feisen: fail to generate the merged function and return false + return false; + //feisen + } + return true; + }; + + SALSSACodeGen CG(F1, F2); + //feisen: check if the code generation is successful; if false then return nullptr; + if(!Gen(CG)){ + if(Debug) + errs()<<"feisen\n"; + FunctionMergeResult Result(false); + return Result; + } + + FunctionMergeResult Result(F1, F2, MergedFunc, RequiresUnifiedReturn); + Result.setArgumentMapping(F1, ParamMap1); + Result.setArgumentMapping(F2, ParamMap2); + Result.setFunctionIdArgument(FuncId != nullptr); + return Result; +} + +void FunctionMerger::replaceByCall(Function *F, FunctionMergeResult &MFR, + const FunctionMergingOptions &Options) { + LLVMContext &Context = M->getContext(); + + Value *FuncId = MFR.getFunctionIdValue(F); + Function *MergedF = MFR.getMergedFunction(); + + // Make sure we preserve its linkage + auto Linkage = F->getLinkage(); + + F->deleteBody(); + BasicBlock *NewBB = BasicBlock::Create(Context, "", F); + IRBuilder<> Builder(NewBB); + + std::vector args; + for (unsigned i = 0; i < MergedF->getFunctionType()->getNumParams(); i++) { + args.push_back(nullptr); + } + + if (MFR.hasFunctionIdArgument()) { + args[0] = FuncId; + } + + std::vector ArgsList; + for (Argument &arg : F->args()) { + ArgsList.push_back(&arg); + } + + for (auto Pair : MFR.getArgumentMapping(F)) { + args[Pair.second] = ArgsList[Pair.first]; + } + + for (unsigned i = 0; i < args.size(); i++) { + if (args[i] == nullptr) { + args[i] = UndefValue::get(MergedF->getFunctionType()->getParamType(i)); + } + } + + F->setLinkage(Linkage); + + CallInst *CI = + (CallInst *)Builder.CreateCall(MergedF, 
ArrayRef(args)); + CI->setTailCall(); + CI->setCallingConv(MergedF->getCallingConv()); + CI->setAttributes(MergedF->getAttributes()); + CI->setIsNoInline(); + + if (F->getReturnType()->isVoidTy()) { + Builder.CreateRetVoid(); + } else { + Value *CastedV; + if (MFR.needUnifiedReturn()) { + Value *AddrCI = Builder.CreateAlloca(CI->getType()); + Builder.CreateStore(CI, AddrCI); + Value *CastedAddr = Builder.CreatePointerCast( + AddrCI, + PointerType::get(F->getReturnType(), DL->getAllocaAddrSpace())); + CastedV = Builder.CreateLoad(F->getReturnType(), CastedAddr); + } else { + CastedV = createCastIfNeeded(CI, F->getReturnType(), Builder, IntPtrTy, + Options); + } + Builder.CreateRet(CastedV); + } +} + +bool FunctionMerger::replaceCallsWith(Function *F, FunctionMergeResult &MFR, + const FunctionMergingOptions &Options) { + + Value *FuncId = MFR.getFunctionIdValue(F); + Function *MergedF = MFR.getMergedFunction(); + + unsigned CountUsers = 0; + std::vector Calls; + for (User *U : F->users()) { + CountUsers++; + if (auto *CI = dyn_cast(U)) { + if (CI->getCalledFunction() == F) { + Calls.push_back(CI); + } + } else if (auto *II = dyn_cast(U)) { + if (II->getCalledFunction() == F) { + Calls.push_back(II); + } + } + } + + if (Calls.size() < CountUsers) + return false; + + for (CallBase *CI : Calls) { + IRBuilder<> Builder(CI); + + std::vector args; + for (unsigned i = 0; i < MergedF->getFunctionType()->getNumParams(); i++) { + args.push_back(nullptr); + } + + if (MFR.hasFunctionIdArgument()) { + args[0] = FuncId; + } + + for (auto Pair : MFR.getArgumentMapping(F)) { + args[Pair.second] = CI->getArgOperand(Pair.first); + } + + for (unsigned i = 0; i < args.size(); i++) { + if (args[i] == nullptr) { + args[i] = UndefValue::get(MergedF->getFunctionType()->getParamType(i)); + } + } + + CallBase *NewCB = nullptr; + if (CI->getOpcode() == Instruction::Call) { + NewCB = (CallInst *)Builder.CreateCall(MergedF->getFunctionType(), + MergedF, args); + } else if (CI->getOpcode() == Instruction::Invoke) { + auto *II = dyn_cast(CI); + NewCB = (InvokeInst *)Builder.CreateInvoke(MergedF->getFunctionType(), + MergedF, II->getNormalDest(), + II->getUnwindDest(), args); + // MergedF->dump(); + // MergedF->getFunctionType()->dump(); + // errs() << "Invoke CallUpdate:\n"; + // II->dump(); + // NewCB->dump(); + } + NewCB->setCallingConv(MergedF->getCallingConv()); + NewCB->setAttributes(MergedF->getAttributes()); + NewCB->setIsNoInline(); + Value *CastedV = NewCB; + if (!F->getReturnType()->isVoidTy()) { + if (MFR.needUnifiedReturn()) { + Value *AddrCI = Builder.CreateAlloca(NewCB->getType()); + Builder.CreateStore(NewCB, AddrCI); + Value *CastedAddr = Builder.CreatePointerCast( + AddrCI, + PointerType::get(F->getReturnType(), DL->getAllocaAddrSpace())); + CastedV = Builder.CreateLoad(F->getReturnType(), CastedAddr); + } else { + CastedV = createCastIfNeeded(NewCB, F->getReturnType(), Builder, + IntPtrTy, Options); + } + } + + // if (F->getReturnType()==MergedF->getReturnType()) + if (CI->getNumUses() > 0) { + CI->replaceAllUsesWith(CastedV); + } + // assert( (CI->getNumUses()>0) && "ERROR: Function Call has uses!"); + CI->eraseFromParent(); + } + + return true; +} + +static bool ShouldPreserveGV(const GlobalValue *GV) { + // Function must be defined here + if (GV->isDeclaration()) + return true; + + // Available externally is really just a "declaration with a body". 
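+  // Candidates with this linkage are already filtered out when the matcher
+  // is populated (see runImpl below), so the check here stays disabled.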
+ // if (GV->hasAvailableExternallyLinkage()) + // return true; + + // Assume that dllexported symbols are referenced elsewhere + if (GV->hasDLLExportStorageClass()) + return true; + + // Already local, has nothing to do. + if (GV->hasLocalLinkage()) + return false; + + return false; +} + +static int RequiresOriginalInterface(Function *F, FunctionMergeResult &MFR, + StringSet<> &AlwaysPreserved) { + bool CanErase = !F->hasAddressTaken(); + CanErase = + CanErase && (AlwaysPreserved.find(F->getName()) == AlwaysPreserved.end()); + if (!HasWholeProgram) { + CanErase = CanErase && F->isDiscardableIfUnused(); + } + return !CanErase; +} + +static int RequiresOriginalInterfaces(FunctionMergeResult &MFR, + StringSet<> &AlwaysPreserved) { + auto FPair = MFR.getFunctions(); + Function *F1 = FPair.first; + Function *F2 = FPair.second; + return (RequiresOriginalInterface(F1, MFR, AlwaysPreserved) ? 1 : 0) + + (RequiresOriginalInterface(F2, MFR, AlwaysPreserved) ? 1 : 0); +} + +void FunctionMerger::updateCallGraph(Function *F, FunctionMergeResult &MFR, + StringSet<> &AlwaysPreserved, + const FunctionMergingOptions &Options) { + replaceByCall(F, MFR, Options); + if (!RequiresOriginalInterface(F, MFR, AlwaysPreserved)) { + bool CanErase = replaceCallsWith(F, MFR, Options); + CanErase = CanErase && F->use_empty(); + CanErase = CanErase && + (AlwaysPreserved.find(F->getName()) == AlwaysPreserved.end()); + if (!HasWholeProgram) { + CanErase = CanErase && !ShouldPreserveGV(F); + CanErase = CanErase && F->isDiscardableIfUnused(); + } + if (CanErase) + F->eraseFromParent(); + } +} + +void FunctionMerger::updateCallGraph(FunctionMergeResult &MFR, + StringSet<> &AlwaysPreserved, + const FunctionMergingOptions &Options) { + auto FPair = MFR.getFunctions(); + Function *F1 = FPair.first; + Function *F2 = FPair.second; + updateCallGraph(F1, MFR, AlwaysPreserved, Options); + updateCallGraph(F2, MFR, AlwaysPreserved, Options); +} + +static int EstimateThunkOverhead(FunctionMergeResult &MFR, + StringSet<> &AlwaysPreserved) { + // return RequiresOriginalInterfaces(MFR, AlwaysPreserved) * 3; + return RequiresOriginalInterfaces(MFR, AlwaysPreserved) * + (2 + MFR.getMergedFunction()->getFunctionType()->getNumParams()); +} + +/*static int EstimateThunkOverhead(Function* F1, Function* F2, + StringSet<> &AlwaysPreserved) { + int fParams = F1->getFunctionType()->getNumParams() + F2->getFunctionType()->getNumParams(); + return RequiresOriginalInterfaces(F1, F2, AlwaysPreserved) * (2 + fParams); +}*/ + +static size_t EstimateFunctionSize(Function *F, TargetTransformInfo *TTI) { + float size = 0; + for (Instruction &I : instructions(F)) { + switch (I.getOpcode()) { + // case Instruction::Alloca: + case Instruction::PHI: + size += 0.2; + break; + // case Instruction::Select: + // size += 1.2; + // break; + default: + auto cost = TTI->getInstructionCost(&I, TargetTransformInfo::TargetCostKind::TCK_CodeSize); + size += cost.getValue().value(); + } + } + return size_t(std::ceil(size)); +} + + +unsigned instToInt(Instruction *I) { + uint32_t value = 0; + static uint32_t pseudorand_value = 100; + + if (pseudorand_value > 10000) + pseudorand_value = 100; + + // std::ofstream myfile; + // std::string newPath = "/home/sean/similarityChecker.txt"; + + // Opcodes must be equivalent for instructions to match -- use opcode value as + // base + value = I->getOpcode(); + + // Number of operands must be equivalent -- except in the case where the + // instruction is a return instruction -- +1 to stop being zero + uint32_t operands = + 
I->getOpcode() == Instruction::Ret ? 1 : I->getNumOperands();
+  value = value * (operands + 1);
+
+  // Instruction type must be equivalent, pairwise operand types must be
+  // equivalent -- use the TypeID cast to an integer. This may not be perfect,
+  // as my understanding of this is limited.
+  auto instTypeID = static_cast<uint32_t>(I->getType()->getTypeID());
+  value = value * (instTypeID + 1);
+  auto *ITypePtr = I->getType();
+  if (ITypePtr) {
+    value = value * (reinterpret_cast<uintptr_t>(ITypePtr) + 1);
+  }
+
+  for (size_t i = 0; i < I->getNumOperands(); i++) {
+    auto operTypeID =
+        static_cast<uint32_t>(I->getOperand(i)->getType()->getTypeID());
+    value = value * (operTypeID + 1);
+
+    auto *IOperTypePtr = I->getOperand(i)->getType();
+    if (IOperTypePtr) {
+      value = value *
+              (reinterpret_cast<uintptr_t>(I->getOperand(i)->getType()) + 1);
+    }
+
+    value = value * (i + 1);
+  }
+  // NOTE: this early return makes the opcode-specific refinement below dead
+  // code; it is kept for experimentation.
+  return value;
+
+  // Now for the funky stuff -- this is gonna be a wild ride
+  switch (I->getOpcode()) {
+
+  case Instruction::Load: {
+    const LoadInst *LI = dyn_cast<LoadInst>(I);
+    uint32_t lValue = LI->isVolatile() ? 1 : 10;        // Volatility
+    lValue += LI->getAlign().value();                   // Alignment
+    lValue += static_cast<uint32_t>(LI->getOrdering()); // Ordering
+
+    value = value * lValue;
+
+    break;
+  }
+
+  case Instruction::Store: {
+    const StoreInst *SI = dyn_cast<StoreInst>(I);
+    uint32_t sValue = SI->isVolatile() ? 2 : 20;        // Volatility
+    sValue += SI->getAlign().value();                   // Alignment
+    sValue += static_cast<uint32_t>(SI->getOrdering()); // Ordering
+
+    value = value * sValue;
+
+    break;
+  }
+
+  case Instruction::Alloca: {
+    const AllocaInst *AI = dyn_cast<AllocaInst>(I);
+    uint32_t aValue = AI->getAlign().value(); // Alignment
+
+    if (AI->getArraySize()) {
+      aValue += reinterpret_cast<uintptr_t>(AI->getArraySize());
+    }
+
+    value = value * (aValue + 1);
+
+    break;
+  }
+
+  case Instruction::GetElementPtr: // Important
+  {
+    auto *GEP = dyn_cast<GetElementPtrInst>(I);
+    uint32_t gValue = 1;
+
+    SmallVector<Value *> Indices(GEP->idx_begin(), GEP->idx_end());
+    gValue = Indices.size() + 1;
+
+    gValue += GEP->isInBounds() ?
3 : 30; + + Type *AggTy = GEP->getSourceElementType(); + gValue += static_cast(AggTy->getTypeID()); + + unsigned curIndex = 1; + for (; curIndex != Indices.size(); ++curIndex) { + // CompositeType* CTy = dyn_cast(AggTy); + + if (!AggTy || AggTy->isPointerTy()) { + if (Deterministic) + value = pseudorand_value++; + else + value = std::rand() % 10000 + 100; + break; + } + + Value *Idx = Indices[curIndex]; + + if (isa(AggTy)) { + if (!isa(Idx)) { + if (Deterministic) + value = pseudorand_value++; + else + value = std::rand() % 10000 + 100; // Use a random number as we don't + // want this to match with anything + break; + } + + auto i = 0; + if (Idx) { + i = reinterpret_cast(Idx); + } + gValue += i; + } + } + + value = value * gValue; + + break; + } + + case Instruction::Switch: { + auto *SI = dyn_cast(I); + uint32_t sValue = 1; + sValue = SI->getNumCases(); + + auto CaseIt = SI->case_begin(), CaseEnd = SI->case_end(); + + while (CaseIt != CaseEnd) { + auto *Case = &*CaseIt; + if (Case) { + sValue += reinterpret_cast(Case); + } + CaseIt++; + } + + value = value * sValue; + + break; + } + + case Instruction::Call: { + auto *CI = dyn_cast(I); + uint32_t cValue = 1; + + if (CI->isInlineAsm()) { + if (Deterministic) + value = pseudorand_value++; + else + value = std::rand() % 10000 + 100; + break; + } + + if (CI->getCalledFunction()) { + cValue = reinterpret_cast(CI->getCalledFunction()); + } + + if (Function *F = CI->getCalledFunction()) { + if (auto ID = (Intrinsic::ID)F->getIntrinsicID()) { + cValue += static_cast(ID); + } + } + + cValue += static_cast(CI->getCallingConv()); + + value = value * cValue; + + break; + } + + case Instruction::Invoke: // Need to look at matching landing pads + { + auto *II = dyn_cast(I); + uint32_t iValue = 1; + + iValue = static_cast(II->getCallingConv()); + + if (II->getAttributes().getRawPointer()) { + iValue += + reinterpret_cast(II->getAttributes().getRawPointer()); + } + + value = value * iValue; + + break; + } + + case Instruction::InsertValue: { + auto *IVI = dyn_cast(I); + + uint32_t ivValue = 1; + + ivValue = IVI->getNumIndices(); + + // check element wise equality + auto Idx = IVI->getIndices(); + const auto *IdxIt = Idx.begin(); + const auto *IdxEnd = Idx.end(); + + while (IdxIt != IdxEnd) { + auto *val = &*IdxIt; + if (val) { + ivValue += reinterpret_cast(*val); + } + IdxIt++; + } + + value = value * ivValue; + + break; + } + + case Instruction::ExtractValue: { + auto *EVI = dyn_cast(I); + + uint32_t evValue = 1; + + evValue = EVI->getNumIndices(); + + // check element wise equality + auto Idx = EVI->getIndices(); + const auto *IdxIt = Idx.begin(); + const auto *IdxEnd = Idx.end(); + + while (IdxIt != IdxEnd) { + auto *val = &*IdxIt; + if (val) { + evValue += reinterpret_cast(*val); + } + IdxIt++; + } + + value = value * evValue; + + break; + } + + case Instruction::Fence: { + auto *FI = dyn_cast(I); + + uint32_t fValue = 1; + + fValue = static_cast(FI->getOrdering()); + + fValue += static_cast(FI->getSyncScopeID()); + + value = value * fValue; + + break; + } + + case Instruction::AtomicCmpXchg: { + auto *AXI = dyn_cast(I); + + uint32_t axValue = 1; + + axValue = AXI->isVolatile() ? 4 : 40; + axValue += AXI->isWeak() ? 
5 : 50;
+    axValue += static_cast<uint32_t>(AXI->getSuccessOrdering());
+    axValue += static_cast<uint32_t>(AXI->getFailureOrdering());
+    axValue += static_cast<uint32_t>(AXI->getSyncScopeID());
+
+    value = value * axValue;
+
+    break;
+  }
+
+  case Instruction::AtomicRMW: {
+    auto *ARI = dyn_cast<AtomicRMWInst>(I);
+
+    uint32_t arValue = 1;
+
+    arValue = static_cast<uint32_t>(ARI->getOperation());
+    arValue += ARI->isVolatile() ? 6 : 60;
+    arValue += static_cast<uint32_t>(ARI->getOrdering());
+    arValue += static_cast<uint32_t>(ARI->getSyncScopeID());
+
+    value = value * arValue;
+    break;
+  }
+
+  case Instruction::PHI: {
+    if (Deterministic)
+      value = pseudorand_value++;
+    else
+      value = std::rand() % 10000 + 100;
+    break;
+  }
+
+  default:
+    if (auto *CI = dyn_cast<CmpInst>(I)) {
+      uint32_t cmpValue = 1;
+
+      cmpValue = static_cast<uint32_t>(CI->getPredicate()) + 1;
+
+      value = value * cmpValue;
+    }
+  }
+
+  // Return
+  return value;
+}
+
+//feisen=====
+bool detectASM_fm(Function &F) {
+  for (BasicBlock &B : F) {
+    for (Instruction &I : B) {
+      if (CallInst *callInst = dyn_cast<CallInst>(&I)) {
+        if (callInst->isInlineAsm()) {
+          return true;
+        }
+      }
+    }
+  }
+  return false;
+}
+
+bool detect_bad_ndelay(Function &F) {
+  for (BasicBlock &B : F) {
+    for (Instruction &I : B) {
+      if (CallInst *callInst = dyn_cast<CallInst>(&I)) {
+        if (Function *calledFunction = callInst->getCalledFunction()) {
+          if (calledFunction->getName() == "__bad_ndelay") {
+            return true;
+          }
+        }
+      }
+    }
+  }
+  return false;
+}
+//==========
+
+static bool checkAsmVolatility(Function &F) {
+  for (BasicBlock &B : F) {
+    for (Instruction &I : B) {
+      if (CallInst *callInst = dyn_cast<CallInst>(&I)) {
+        if (callInst->isInlineAsm()) {
+          // if (auto *IA = dyn_cast(callInst->getCalledValue())) {
+          //   return true;
+          // }
+          return true;
+        }
+      }
+    }
+  }
+  return false;
+}
+
+static bool skipFunction(Function &F) {
+  if (F.getName().equals("ftab_insert")) {
+    return true;
+  }
+  return false;
+}
+
+bool ignoreFunction(Function &F) {
+  if (skipFunction(F)) {
+    return true;
+  }
+  if (checkAsmVolatility(F)) {
+    return true;
+  }
+  // if (F.hasInternalLinkage()) // feisen: ignore internal functions
+  //   return true;
+  // feisen: remove asm
+  // if (detectASM_fm(F)) {
+  //   return true;
+  // }
+  // if (F.getName().equals("vfprintf")) {
+  //   return true;
+  // }
+  if (detect_bad_ndelay(F)) {
+    return true;
+  }
+  for (Instruction &I : instructions(F)) {
+    if (auto *CB = dyn_cast<CallBase>(&I)) {
+      if (Function *F2 = CB->getCalledFunction()) {
+        if (auto ID = (Intrinsic::ID)F2->getIntrinsicID()) {
+          if (Intrinsic::isOverloaded(ID))
+            continue;
+          if (Intrinsic::getName(ID).contains("permvar"))
+            return true;
+          if (Intrinsic::getName(ID).contains("vcvtps"))
+            return true;
+          if (Intrinsic::getName(ID).contains("avx"))
+            return true;
+          if (Intrinsic::getName(ID).contains("x86"))
+            return true;
+          if (Intrinsic::getName(ID).contains("arm"))
+            return true;
+        }
+      }
+    }
+  }
+  return false;
+}
+
+bool FunctionMerging::runImpl(
+    Module &M, function_ref<TargetTransformInfo *(Function &)> GTTI) {
+
+#ifdef TIME_STEPS_DEBUG
+  TimeTotal.startTimer();
+  TimePreProcess.startTimer();
+#endif
+
+  StringSet<> AlwaysPreserved;
+  AlwaysPreserved.insert("main");
+
+  srand(time(nullptr)); // set the random seed
+
+  FunctionMergingOptions Options =
+      FunctionMergingOptions()
+          .maximizeParameterScore(MaxParamScore)
+          .matchOnlyIdenticalTypes(IdenticalType)
+          .enableUnifiedReturnTypes(EnableUnifiedReturnType);
+
+  // auto *PSI = &this->getAnalysis().getPSI();
+  // auto LookupBFI = [this](Function &F) {
+  //   return &this->getAnalysis(F).getBFI();
+  // };
+
+  // TODO: We could use a TTI ModulePass instead, but the current TTI analysis
+  // pass is a FunctionPass.
+
+  FunctionMerger FM(&M);
+
+  if (ReportStats) {
+    MatcherReport<Function *> reporter(LSHRows, LSHBands, FM, Options);
+    for (auto &F : M) {
+      if (F.isDeclaration() || F.isVarArg() ||
+          (!HasWholeProgram && F.hasAvailableExternallyLinkage()))
+        continue;
+      reporter.add_candidate(&F);
+    }
+    reporter.report();
+#ifdef TIME_STEPS_DEBUG
+    TimeTotal.stopTimer();
+    TimePreProcess.stopTimer();
+    TimeRank.clear();
+    TimeCodeGenTotal.clear();
+    TimeAlign.clear();
+    TimeAlignRank.clear();
+    TimeParam.clear();
+    TimeCodeGen.clear();
+    TimeCodeGenFix.clear();
+    TimePostOpt.clear();
+    TimeVerify.clear();
+    TimePreProcess.clear();
+    TimeLin.clear();
+    TimeUpdate.clear();
+    TimePrinting.clear();
+    TimeTotal.clear();
+#endif
+    return false;
+  }
+
+  std::unique_ptr<Matcher<Function *>> matcher;
+
+  // Check whether to use a linear scan instead
+  int size = 0;
+  for (auto &F : M) {
+    // skip declarations, varargs, and (without whole-program info)
+    // available_externally definitions
+    if (F.isDeclaration() || F.isVarArg() ||
+        (!HasWholeProgram && F.hasAvailableExternallyLinkage()))
+      continue;
+    size++;
+  }
+
+  // Create a threshold based on the application's size
+  if (AdaptiveThreshold || AdaptiveBands) {
+    double x = std::log10(size) / 10;
+    RankingDistance = (double)(x - 0.3);
+    if (RankingDistance < 0.05)
+      RankingDistance = 0.05;
+    if (RankingDistance > 0.4)
+      RankingDistance = 0.4;
+
+    if (AdaptiveBands) {
+      float target_probability = 0.9;
+      float offset = 0.1;
+      unsigned tempBands = std::ceil(
+          std::log(1.0 - target_probability) /
+          std::log(1.0 - std::pow(RankingDistance + offset, LSHRows)));
+      if (tempBands < LSHBands)
+        LSHBands = tempBands;
+    }
+    if (AdaptiveThreshold)
+      RankingDistance = 1 - RankingDistance;
+    else
+      RankingDistance = 1.0;
+  }
+  if (Debug) {
+    errs() << "Threshold: " << RankingDistance << "\n";
+    errs() << "LSHRows: " << LSHRows << "\n";
+    errs() << "LSHBands: " << LSHBands << "\n";
+  }
+
+  if (!ToMergeFile.empty()) {
+    matcher =
+        std::make_unique<MatcherManual<Function *>>(FM, Options, ToMergeFile);
+  } else if (EnableF3M) {
+    matcher = std::make_unique<MatcherLSH<Function *>>(FM, Options, LSHRows,
+                                                       LSHBands);
+    // errs() << "LSH MH\n";
+  } else {
+    matcher = std::make_unique<MatcherFQ<Function *>>(FM, Options);
+    // errs() << "LIN SCAN FP\n";
+  }
+
+  SearchStrategy strategy(LSHRows, LSHBands);
+  for (auto &F : M) {
+    if (F.isDeclaration() || F.isVarArg() ||
+        (!HasWholeProgram && F.hasAvailableExternallyLinkage()))
+      continue;
+    if (ignoreFunction(F))
+      continue;
+    matcher->add_candidate(&F, EstimateFunctionSize(&F, GTTI(F)));
+  }
+
+#ifdef TIME_STEPS_DEBUG
+  TimePreProcess.stopTimer();
+#endif
+  if (Debug) {
+    errs() << "Number of Functions: " << matcher->size() << "\n";
+    if (MatcherStats) {
+#ifdef TIME_STEPS_DEBUG
+      matcher->print_stats();
+      TimeRank.clear();
+      TimeCodeGenTotal.clear();
+      TimeAlign.clear();
+      TimeAlignRank.clear();
+      TimeParam.clear();
+      TimeCodeGen.clear();
+      TimeCodeGenFix.clear();
+      TimePostOpt.clear();
+      TimeVerify.clear();
+      TimePreProcess.clear();
+      TimeLin.clear();
+      TimeUpdate.clear();
+      TimePrinting.clear();
+      TimeTotal.clear();
+#endif
+      return false;
+    }
+  }
+
+  unsigned TotalMerges = 0;
+  unsigned TotalOpReorder = 0;
+  unsigned TotalBinOps = 0;
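+
+  // Worked example of the adaptive banding above: with LSHRows = 2, a target
+  // probability of 0.9, and offset 0.1, a pair whose MinHash similarity is
+  // s = RankingDistance + 0.1 collides in at least one band with probability
+  // 1 - (1 - s^2)^bands, so ceil(log(1 - 0.9) / log(1 - s^2)) is the smallest
+  // number of bands that meets the target.
+
+  while (matcher->size() > 0) {
+#ifdef TIME_STEPS_DEBUG
+    TimeRank.startTimer();
+    time_ranking_start = std::chrono::steady_clock::now();
+
+    time_ranking_end = time_ranking_start;
+    time_align_start = time_ranking_start;
+    time_align_end = time_ranking_start;
+    time_codegen_start = time_ranking_start;
+    time_codegen_end = time_ranking_start;
+    time_verify_start = time_ranking_start;
+    time_verify_end =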
time_ranking_start; + time_update_start = time_ranking_start; + time_update_end = time_ranking_start; + time_iteration_end = time_ranking_start; +#endif + + Function *F1 = matcher->next_candidate(); + auto &Rank = matcher->get_matches(F1); + matcher->remove_candidate(F1); + + if(F1==nullptr) + continue; + //feisen:debug + // if(F1->getName().equals("")||F1->getName().equals(" ")){ + // // errs()<<"feisen:debug:empty function\n"; + // continue; + // } + +#ifdef TIME_STEPS_DEBUG + TimeRank.stopTimer(); + time_ranking_end = std::chrono::steady_clock::now(); +#endif + unsigned MergingTrialsCount = 0; + float OtherDistance = 0.0; + +//feisen:debug +// errs()<<"A"; +std::string F1Name(GetValueName(F1)); +// errs()<<"F1: "<0){ + // break; + // }else{ + // counter++; + // } + //feisen:debug + // errs()<<"i2.0"; +#ifdef TIME_STEPS_DEBUG + TimeCodeGenTotal.startTimer(); + time_codegen_start = std::chrono::steady_clock::now(); +#endif + MatchInfo match = Rank.back(); + Rank.pop_back(); + Function *F2 = match.candidate; + + if(F1==nullptr||F2==nullptr) continue; //=====0321====== + //verify function + if(verifyFunction(*F1)||verifyFunction(*F2)){ + errs()<<"feisen:debug:verify function error\n"; + continue; + } + + //feisen:debug + if(Debug){ + errs()<<"feisen:debug:F1: "<getName()<<" F2: "<getName()<<"\n"; + errs()<getName().equals("")<<"\n"; + } + F1->getName().equals("");F2->getName().equals(""); + // errs()<<"i2.9"; + // if(F1->getName().equals("")){ + // continue; + // } + // errs()<getName().equals("")<<"[b"; + // errs()<getName()<<"[a"; + // for(BasicBlock &b: *F1){ + // errs()<<"0;"; + // b.print(errs()); + // errs()<<"1;"; + // } + // errs()<<"i2.1"; + // for(BasicBlock &b: *F2){ + + // } + if(F1!=F11){ + errs()<<"F1 Changed"; + break; + } + if(F2==nullptr){ + errs()<<"F2 error"; + break; + } + // if(F1->getName().equals("")){ + // // errs()<<"feisen:debug:empty function\n"; + // continue; + // } + //feisen:debug + + //feisen:debug + // std::string F1Name(GetValueName(F1)); + //feisen:debug + // errs()<<"i2.2"; + + std::string F2Name(GetValueName(F2)); + + //feisen:debug + // errs()<<"i2.3"; + + if (Verbose) { + if (EnableF3M) { + Fingerprint FP1(F1); + Fingerprint FP2(F2); + OtherDistance = FP1.distance(FP2); + } else { + FingerprintMH FP1(F1, strategy); + FingerprintMH FP2(F2, strategy); + OtherDistance = FP1.distance(FP2); + } + } + + MergingTrialsCount++; + + + if (Debug) + errs() << "Attempting: " << F1Name << ", " << F2Name << " : " << match.Distance << "\n"; + + std::string Name = "_m_f_" + std::to_string(TotalMerges); + //feisen:debug:FunctionMergeResult Result + FunctionMergeResult Result = FM.merge(F1, F2, Name, Options); +#ifdef TIME_STEPS_DEBUG + TimeCodeGenTotal.stopTimer(); + time_codegen_end = std::chrono::steady_clock::now(); +#endif + //feisen:debug:Result must be successful + if (Result.getMergedFunction() != nullptr && Result.Success) { + // if (Result.getMergedFunction() != nullptr) { +#ifdef TIME_STEPS_DEBUG + TimeVerify.startTimer(); + time_verify_start = std::chrono::steady_clock::now(); +#endif + match.Valid = !verifyFunction(*Result.getMergedFunction()); +#ifdef TIME_STEPS_DEBUG + TimeVerify.stopTimer(); + time_verify_end = std::chrono::steady_clock::now(); +#endif + +#ifdef ENABLE_DEBUG_CODE + if (Debug) { + errs() << "F1:\n"; + F1->dump(); + errs() << "F2:\n"; + F2->dump(); + errs() << "F1-F2:\n"; + Result.getMergedFunction()->dump(); + } +#endif + +//feisen:debug +// errs()<<"L2"; + +#ifdef TIME_STEPS_DEBUG + TimeUpdate.startTimer(); + time_update_start = 
std::chrono::steady_clock::now(); +#endif + if (!match.Valid) { + Result.getMergedFunction()->eraseFromParent(); + } else { + size_t MergedSize = EstimateFunctionSize(Result.getMergedFunction(), GTTI(*Result.getMergedFunction())); + size_t Overhead = EstimateThunkOverhead(Result, AlwaysPreserved); + + size_t SizeF12 = MergedSize + Overhead; + size_t SizeF1F2 = match.OtherSize + match.Size; + + match.MergedSize = SizeF12; + match.Profitable = (SizeF12 + MergingOverheadThreshold) < SizeF1F2; + +#ifdef SKIP_MERGING + Result.getMergedFunction()->eraseFromParent(); +#else + if (!ToMergeFile.empty() || match.Profitable) { + + //debug:print f1,f2,m: + // errs()<<"F1: "<getName()<<" F2: "<getName()<<" M: "<getName()<<"\n"; + //endl + TotalMerges++; + matcher->remove_candidate(F2); + + FM.updateCallGraph(Result, AlwaysPreserved, Options); + + //resolve phinode + resolvePHI( *(Result.getMergedFunction()) ); + + //feisen:debug:resue merged functions + if (ReuseMergedFunctions + ){ + //feisen:debug + // && !Result.getMergedFunction()->getName().equals("")) { + // feed new function back into the working lists + matcher->add_candidate( + Result.getMergedFunction(), + EstimateFunctionSize(Result.getMergedFunction(), GTTI(*Result.getMergedFunction()))); + } + break; //========0321=========== + } else { + Result.getMergedFunction()->eraseFromParent(); + } +#endif + } +#ifdef TIME_STEPS_DEBUG + TimeUpdate.stopTimer(); + time_update_end = std::chrono::steady_clock::now(); +#endif + } + +#ifdef TIME_STEPS_DEBUG + time_iteration_end = std::chrono::steady_clock::now(); +#endif + +#ifdef TIME_STEPS_DEBUG + TimePrinting.startTimer(); +#endif + if (Debug){ + errs() << F1Name << " + " << F2Name << " <= " << Name + << " Tries: " << MergingTrialsCount + << " Valid: " << match.Valid + << " BinSizes: " << match.OtherSize << " + " << match.Size << " <= " << match.MergedSize + << " IRSizes: " << match.OtherMagnitude << " + " << match.Magnitude + << " AcrossBlocks: " << AcrossBlocks + << " Profitable: " << match.Profitable + << " Distance: " << match.Distance; + } + if (Verbose) + errs() << " OtherDistance: " << OtherDistance; +#ifdef TIME_STEPS_DEBUG + using namespace std::chrono_literals; + if(Debug){ + errs() << " TotalTime: " << (time_iteration_end - time_ranking_start) / 1us + << " RankingTime: " << (time_ranking_end - time_ranking_start) / 1us + << " AlignTime: " << (time_align_end - time_align_start) / 1us + << " CodegenTime: " << ((time_codegen_end - time_codegen_start) - (time_align_end - time_align_start)) / 1us + << " VerifyTime: " << (time_verify_end - time_verify_start) / 1us + << " UpdateTime: " << (time_update_end - time_update_start) / 1us; + } +#endif + if(Debug) + errs() << "\n"; + + +#ifdef TIME_STEPS_DEBUG + TimePrinting.stopTimer(); +#endif + + //if (match.Profitable || (MergingTrialsCount >= ExplorationThreshold)) + if (MergingTrialsCount >= ExplorationThreshold) + break; + } + } + + double MergingAverageDistance = 0; + unsigned MergingMaxDistance = 0; + + if (Debug || Verbose) { + errs() << "Total operand reordering: " << TotalOpReorder << "/" + << TotalBinOps << " (" + << 100.0 * (((double)TotalOpReorder) / ((double)TotalBinOps)) + << " %)\n"; + + // errs() << "Total parameter score: " << TotalParamScore << "\n"; + + // errs() << "Total number of merges: " << MergingDistance.size() << + // "\n"; + errs() << "Average number of trials before merging: " + << MergingAverageDistance << "\n"; + errs() << "Maximum number of trials before merging: " << MergingMaxDistance + << "\n"; + } + +#ifdef 
TIME_STEPS_DEBUG + TimeTotal.stopTimer(); + if(Debug){ + errs() << "Timer:Rank: " << TimeRank.getTotalTime().getWallTime() << "\n"; + TimeRank.clear(); + + errs() << "Timer:CodeGen:Total: " << TimeCodeGenTotal.getTotalTime().getWallTime() << "\n"; + TimeCodeGenTotal.clear(); + + errs() << "Timer:CodeGen:Align: " << TimeAlign.getTotalTime().getWallTime() << "\n"; + TimeAlign.clear(); + + errs() << "Timer:CodeGen:Align:Rank: " << TimeAlignRank.getTotalTime().getWallTime() << "\n"; + TimeAlignRank.clear(); + + errs() << "Timer:CodeGen:Param: " << TimeParam.getTotalTime().getWallTime() << "\n"; + TimeParam.clear(); + + errs() << "Timer:CodeGen:Gen: " << TimeCodeGen.getTotalTime().getWallTime() + << "\n"; + TimeCodeGen.clear(); + + errs() << "Timer:CodeGen:Fix: " << TimeCodeGenFix.getTotalTime().getWallTime() + << "\n"; + TimeCodeGenFix.clear(); + + errs() << "Timer:CodeGen:PostOpt: " << TimePostOpt.getTotalTime().getWallTime() + << "\n"; + TimePostOpt.clear(); + + errs() << "Timer:Verify: " << TimeVerify.getTotalTime().getWallTime() << "\n"; + TimeVerify.clear(); + + errs() << "Timer:PreProcess: " << TimePreProcess.getTotalTime().getWallTime() + << "\n"; + TimePreProcess.clear(); + + errs() << "Timer:Lin: " << TimeLin.getTotalTime().getWallTime() << "\n"; + TimeLin.clear(); + + errs() << "Timer:Update: " << TimeUpdate.getTotalTime().getWallTime() << "\n"; + TimeUpdate.clear(); + + errs() << "Timer:Printing: " << TimePrinting.getTotalTime().getWallTime() << "\n"; + TimePrinting.clear(); + + errs() << "Timer:Total: " << TimeTotal.getTotalTime().getWallTime() << "\n"; + TimeTotal.clear(); + } +#endif + + return true; +} + +PreservedAnalyses FunctionMergingPass::run(Module &M, + ModuleAnalysisManager &AM) { + FunctionMerging FM; + if(Debug){ + errs() << "ExplorationThreshold: "<< ExplorationThreshold << "\n"; //feisen + errs() << "RankingThreshold: "<< RankingThreshold << "\n"; + errs() << "MergingOverheadThreshold" <printAsOperand(namestream, false); + return namestream.str(); + } + return "[null]"; +} + +/// Create a cast instruction if needed to cast V to type DstType. We treat +/// pointer and integer types of the same bitwidth as equivalent, so this can be +/// used to cast them to each other where needed. The function returns the Value +/// itself if no cast is needed, or a new CastInst instance inserted before +/// InsertBefore. The integer type equivalent to pointers must be passed as +/// IntPtrType (get it from DataLayout). This is guaranteed to generate no-op +/// casts, otherwise it will assert. 
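+// A minimal usage sketch (hypothetical caller, not part of this patch; DL,
+// DstTy, V, InsertPt and Options are assumed to be in scope):
+//   Type *IntPtrTy = DL.getIntPtrType(V->getContext());
+//   IRBuilder<> Builder(InsertPt);
+//   Value *Casted = createCastIfNeeded(V, DstTy, Builder, IntPtrTy, Options);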
+// Value *FunctionMerger::createCastIfNeeded(Value *V, Type *DstType, +// IRBuilder<> &Builder, const FunctionMergingOptions &Options) { +Value *createCastIfNeeded(Value *V, Type *DstType, IRBuilder<> &Builder, + Type *IntPtrTy, + const FunctionMergingOptions &Options) { + + if (V->getType() == DstType || Options.IdenticalTypesOnly) + return V; + + Value *Result; + Type *OrigType = V->getType(); + + if (OrigType->isStructTy()) { + assert(DstType->isStructTy()); + assert(OrigType->getStructNumElements() == DstType->getStructNumElements()); + + Result = UndefValue::get(DstType); + for (unsigned int I = 0, E = OrigType->getStructNumElements(); I < E; ++I) { + Value *ExtractedValue = + Builder.CreateExtractValue(V, ArrayRef(I)); + Value *Element = + createCastIfNeeded(ExtractedValue, DstType->getStructElementType(I), + Builder, IntPtrTy, Options); + Result = + Builder.CreateInsertValue(Result, Element, ArrayRef(I)); + } + return Result; + } + assert(!DstType->isStructTy()); + + if (OrigType->isPointerTy() && + (DstType->isIntegerTy() || DstType->isPointerTy())) { + Result = Builder.CreatePointerCast(V, DstType, "merge_cast"); + } else if (OrigType->isIntegerTy() && DstType->isPointerTy() && + OrigType == IntPtrTy) { + // Int -> Ptr + Result = Builder.CreateCast(CastInst::IntToPtr, V, DstType, "merge_cast"); + } else { + llvm_unreachable("Can only cast int -> ptr or ptr -> (ptr or int)"); + } + + // assert(cast(Result)->isNoopCast(InsertAtEnd->getParent()->getParent()->getDataLayout()) + // && + // "Cast is not a no-op cast. Potential loss of precision"); + + return Result; +} + +void FunctionMerger::CodeGenerator::removeRedundantInstructions( + std::vector &WorkInst, DominatorTree &DT) { + std::set SkipList; + + std::map> UpdateList; + + for (Instruction *I1 : WorkInst) { + if (SkipList.find(I1) != SkipList.end()) + continue; + for (Instruction *I2 : WorkInst) { + if (I1 == I2) + continue; + if (SkipList.find(I2) != SkipList.end()) + continue; + assert(I1->getNumOperands() == I2->getNumOperands() && + "Should have the same num of operands!"); + bool AllEqual = true; + for (unsigned i = 0; i < I1->getNumOperands(); ++i) { + AllEqual = AllEqual && (I1->getOperand(i) == I2->getOperand(i)); + } + + if (AllEqual && DT.dominates(I1, I2)) { + UpdateList[I1].push_back(I2); + SkipList.insert(I2); + SkipList.insert(I1); + } + } + } + + for (auto &kv : UpdateList) { + for (auto *I : kv.second) { + erase(I); + I->replaceAllUsesWith(kv.first); + I->eraseFromParent(); + } + } + //feisen:debug + errs()<<"L3"; +} + +//////////////////////////////////// SALSSA //////////////////////////////// + +static void postProcessFunction(Function &F) { + legacy::FunctionPassManager FPM(F.getParent()); + + // FPM.add(createPromoteMemoryToRegisterPass()); + FPM.add(createCFGSimplificationPass()); + // FPM.add(createInstructionCombiningPass(2)); + // FPM.add(createCFGSimplificationPass()); + + FPM.doInitialization(); + FPM.run(F); + FPM.doFinalization(); +} + +template +static void CodeGen(BlockListType &Blocks1, BlockListType &Blocks2, + BasicBlock *EntryBB1, BasicBlock *EntryBB2, + Function *MergedFunc, Value *IsFunc1, BasicBlock *PreBB, + AlignedCode &AlignedSeq, + ValueToValueMapTy &VMap, + std::unordered_map &BlocksF1, + std::unordered_map &BlocksF2, + std::unordered_map &MaterialNodes) { + + auto CloneInst = [](IRBuilder<> &Builder, Function *MF, + Instruction *I) -> Instruction * { + Instruction *NewI = nullptr; + if (I->getOpcode() == Instruction::Ret) { + if (MF->getReturnType()->isVoidTy()) { + NewI = 
Builder.CreateRetVoid(); + } else { + NewI = Builder.CreateRet(UndefValue::get(MF->getReturnType())); + } + } else { + // assert(I1->getNumOperands() == I2->getNumOperands() && + // "Num of Operands SHOULD be EQUAL!"); + NewI = I->clone(); + for (unsigned i = 0; i < NewI->getNumOperands(); i++) { + if (!isa(I->getOperand(i))) + NewI->setOperand(i, nullptr); + } + Builder.Insert(NewI); + } + + // NewI->dropPoisonGeneratingFlags(); //TODO: NOT SURE IF THIS IS VALID + + // TODO: temporarily removing metadata + + SmallVector, 8> MDs; + NewI->getAllMetadata(MDs); + for (std::pair MDPair : MDs) { + NewI->setMetadata(MDPair.first, nullptr); + } + + // if (isa(NewI)) { + // GetElementPtrInst * GEP = dyn_cast(I); + // GetElementPtrInst * GEP2 = dyn_cast(I2); + // dyn_cast(NewI)->setIsInBounds(GEP->isInBounds()); + //} + + /* + if (auto *CB = dyn_cast(I)) { + auto *NewCB = dyn_cast(NewI); + auto AttrList = CB->getAttributes(); + NewCB->setAttributes(AttrList); + }*/ + + return NewI; + }; + + for (auto &Entry : AlignedSeq) { + if (Entry.match()) { + + auto *I1 = dyn_cast(Entry.get(0)); + auto *I2 = dyn_cast(Entry.get(1)); + + std::string BBName = + (I1 == nullptr) ? "m.label.bb" + : (I1->isTerminator() ? "m.term.bb" : "m.inst.bb"); + + BasicBlock *MergedBB = + BasicBlock::Create(MergedFunc->getContext(), BBName, MergedFunc); + + MaterialNodes[Entry.get(0)] = MergedBB; + MaterialNodes[Entry.get(1)] = MergedBB; + + if (I1 != nullptr && I2 != nullptr) { + IRBuilder<> Builder(MergedBB); + Instruction *NewI = CloneInst(Builder, MergedFunc, I1); + + //feisen:debug + // if(BasicBlock *fbb = dyn_cast(I1)){ + // //feisen:debug + // errs()<<"feisen:debug::"; + // for(Instruction &I : *fbb){ + // if(PHINode *phi = dyn_cast(&I)){ + // phi->print(errs()); + // errs()<<"\n"; + // } + // } + // }else if(PHINode *phi = dyn_cast(I1)){ + // phi->print(errs()); + // errs()<<"\n"; + // } + + + VMap[I1] = NewI; + VMap[I2] = NewI; + BlocksF1[MergedBB] = I1->getParent(); + BlocksF2[MergedBB] = I2->getParent(); + } else { + assert(isa(Entry.get(0)) && isa(Entry.get(1)) && + "Both nodes must be basic blocks!"); + auto *BB1 = dyn_cast(Entry.get(0)); + auto *BB2 = dyn_cast(Entry.get(1)); + + VMap[BB1] = MergedBB; + VMap[BB2] = MergedBB; + BlocksF1[MergedBB] = BB1; + BlocksF2[MergedBB] = BB2; + + // IMPORTANT: make sure any use in a blockaddress constant + // operation is updated correctly + for (User *U : BB1->users()) { + if (auto *BA = dyn_cast(U)) { + VMap[BA] = BlockAddress::get(MergedFunc, MergedBB); + } + } + for (User *U : BB2->users()) { + if (auto *BA = dyn_cast(U)) { + VMap[BA] = BlockAddress::get(MergedFunc, MergedBB); + } + } + + IRBuilder<> Builder(MergedBB); + for (Instruction &I : *BB1) { + if (isa(&I)) { + //feisen:debug + // errs()<<":\n"; + // BB1->print(errs()); + VMap[&I] = Builder.CreatePHI(I.getType(), 0); + } + } + for (Instruction &I : *BB2) { + if (isa(&I)) { + //feisen:debug + // errs()<<":\n"; + // BB2->print(errs()); + VMap[&I] = Builder.CreatePHI(I.getType(), 0); + } + } + } // end if(instruction)-else + + // feisen + // if(BBName=="m.inst.bb") + // MergedBB->print(errs()); + } + } + + auto ChainBlocks = [](BasicBlock *SrcBB, BasicBlock *TargetBB, + Value *IsFunc1) { + IRBuilder<> Builder(SrcBB); + if (SrcBB->getTerminator() == nullptr) { + Builder.CreateBr(TargetBB); + } else { + auto *Br = dyn_cast(SrcBB->getTerminator()); + assert(Br && Br->isUnconditional() && + "Branch should be unconditional at this point!"); + BasicBlock *SuccBB = Br->getSuccessor(0); + // if (SuccBB != TargetBB) { + 
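+      // If the merged block already branches unconditionally (a successor
+      // chained in while emitting the first function), the branch is
+      // rewritten into a conditional one: the existing successor is taken
+      // when IsFunc1 is true, and the newly chained target otherwise.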
Br->eraseFromParent(); + Builder.CreateCondBr(IsFunc1, SuccBB, TargetBB); + //} + } + }; + + auto ProcessEachFunction = + [&](BlockListType &Blocks, + std::unordered_map &BlocksFX, + Value *IsFunc1) { + for (BasicBlock *BB : Blocks) { + BasicBlock *LastMergedBB = nullptr; + BasicBlock *NewBB = nullptr; + bool HasBeenMerged = MaterialNodes.find(BB) != MaterialNodes.end(); + if (HasBeenMerged) { + LastMergedBB = MaterialNodes[BB]; + } else { + std::string BBName = std::string("src.bb"); + NewBB = BasicBlock::Create(MergedFunc->getContext(), BBName, + MergedFunc); + VMap[BB] = NewBB; + BlocksFX[NewBB] = BB; + + // IMPORTANT: make sure any use in a blockaddress constant + // operation is updated correctly + for (User *U : BB->users()) { + if (auto *BA = dyn_cast(U)) { + VMap[BA] = BlockAddress::get(MergedFunc, NewBB); + } + } + + // errs() << "NewBB: " << NewBB->getName() << "\n"; + IRBuilder<> Builder(NewBB); + for (Instruction &I : *BB) { + if (isa(&I)) { + VMap[&I] = Builder.CreatePHI(I.getType(), 0); + } + } + } + for (Instruction &I : *BB) { + if (isa(&I)) + continue; + if (isa(&I)) + continue; + + bool HasBeenMerged = MaterialNodes.find(&I) != MaterialNodes.end(); + if (HasBeenMerged) { + BasicBlock *NodeBB = MaterialNodes[&I]; + if (LastMergedBB) { + // errs() << "Chaining last merged " << LastMergedBB->getName() + // << " with " << NodeBB->getName() << "\n"; + ChainBlocks(LastMergedBB, NodeBB, IsFunc1); + } else { + IRBuilder<> Builder(NewBB); + Builder.CreateBr(NodeBB); + // errs() << "Chaining newBB " << NewBB->getName() << " with " + // << NodeBB->getName() << "\n"; + } + // end keep track + LastMergedBB = NodeBB; + } else { + if (LastMergedBB) { + std::string BBName = std::string("split.bb"); + NewBB = BasicBlock::Create(MergedFunc->getContext(), BBName, + MergedFunc); + ChainBlocks(LastMergedBB, NewBB, IsFunc1); + BlocksFX[NewBB] = BB; + // errs() << "Splitting last merged " << LastMergedBB->getName() + // << " into " << NewBB->getName() << "\n"; + } + LastMergedBB = nullptr; + + IRBuilder<> Builder(NewBB); + Instruction *NewI = CloneInst(Builder, MergedFunc, &I); + VMap[&I] = NewI; + // errs() << "Cloned into " << NewBB->getName() << " : " << + // NewI->getName() << " " << NewI->getOpcodeName() << "\n"; + // I.dump(); + } + } + } + }; + + auto ProcessEachFunction_NonSeq = + [&](int FuncIdx, + std::unordered_map &BlocksFX, + Value *IsFunc1) { + + BasicBlock *LastMergedBB = nullptr; + BasicBlock *NewBB = nullptr; + + for (auto &Entry: AlignedSeq) { + Value *V = Entry.get(FuncIdx); + if (V == nullptr) + continue; + + if (BasicBlock *BB = dyn_cast(V)) { + LastMergedBB = nullptr; + NewBB = nullptr; + if (auto It = MaterialNodes.find(BB); It != MaterialNodes.end()) { + LastMergedBB = It->second; + // BB->print(errs()); + } else { + std::string BBName = std::string("src.bb"); + NewBB = BasicBlock::Create(MergedFunc->getContext(), BBName, + MergedFunc); + VMap[BB] = NewBB; + BlocksFX[NewBB] = BB; + + // IMPORTANT: make sure any use in a blockaddress constant + // operation is updated correctly + for (User *U : BB->users()) { + if (auto *BA = dyn_cast(U)) { + VMap[BA] = BlockAddress::get(MergedFunc, NewBB); + } + } + + IRBuilder<> Builder(NewBB); + for (Instruction &I : *BB) { + if (isa(&I)) { + VMap[&I] = Builder.CreatePHI(I.getType(), 0); + } + } + } + } else if (Instruction *I = dyn_cast(V)) { + if (isa(I)) + continue; + if (isa(I)) + continue; + + if (auto It = MaterialNodes.find(I); It != MaterialNodes.end()) { + BasicBlock *NodeBB = It->second; + if (LastMergedBB) { + 
ChainBlocks(LastMergedBB, NodeBB, IsFunc1); + } else { + IRBuilder<> Builder(NewBB); + Builder.CreateBr(NodeBB); + } + // end keep track + LastMergedBB = NodeBB; + } else { + if (LastMergedBB) { + std::string BBName = std::string("split.bb"); + NewBB = BasicBlock::Create(MergedFunc->getContext(), BBName, + MergedFunc); + ChainBlocks(LastMergedBB, NewBB, IsFunc1); + BlocksFX[NewBB] = BB; + } + LastMergedBB = nullptr; + + IRBuilder<> Builder(NewBB); + Instruction *NewI = CloneInst(Builder, MergedFunc, I); + VMap[I] = NewI; + } + } else { + errs() << "Should never get here!\n"; + } + } + }; + +#ifdef CHANGES + ProcessEachFunction_NonSeq(0, BlocksF1, IsFunc1); + ProcessEachFunction_NonSeq(1, BlocksF2, IsFunc1); +#else + ProcessEachFunction(Blocks1, BlocksF1, IsFunc1); + ProcessEachFunction(Blocks2, BlocksF2, IsFunc1); +#endif + // errs()<<"AlignedSeq Size: " << AlignedSeq.size()<<"\n"; + // errs()<<"CodeGen: 4740\n"; + // errs()<(VMap[EntryBB1]); + auto *BB2 = dyn_cast(VMap[EntryBB2]); + + BlocksF1[PreBB] = BB1; + BlocksF2[PreBB] = BB2; + + if (BB1 == BB2) { + IRBuilder<> Builder(PreBB); + Builder.CreateBr(BB1); + } else { + IRBuilder<> Builder(PreBB); + Builder.CreateCondBr(IsFunc1, BB1, BB2); + } +} + +bool FunctionMerger::SALSSACodeGen::generate( + AlignedCode &AlignedSeq, ValueToValueMapTy &VMap, + const FunctionMergingOptions &Options) { + +#ifdef TIME_STEPS_DEBUG + TimeCodeGen.startTimer(); +#endif + + LLVMContext &Context = CodeGenerator::getContext(); + Function *MergedFunc = CodeGenerator::getMergedFunction(); + Value *IsFunc1 = CodeGenerator::getFunctionIdentifier(); + Type *ReturnType = CodeGenerator::getMergedReturnType(); + bool RequiresUnifiedReturn = + CodeGenerator::getRequiresUnifiedReturn(); + BasicBlock *EntryBB1 = CodeGenerator::getEntryBlock1(); + BasicBlock *EntryBB2 = CodeGenerator::getEntryBlock2(); + BasicBlock *PreBB = CodeGenerator::getPreBlock(); + + Type *RetType1 = CodeGenerator::getReturnType1(); + Type *RetType2 = CodeGenerator::getReturnType2(); + + Type *IntPtrTy = CodeGenerator::getIntPtrType(); + + std::vector &Blocks1 = CodeGenerator::getBlocks1(); + std::vector &Blocks2 = CodeGenerator::getBlocks2(); + + std::list LinearOffendingInsts; + std::set OffendingInsts; + std::map> + CoalescingCandidates; + + std::vector ListSelects; + + std::vector Allocas; + + Value *RetUnifiedAddr = nullptr; + Value *RetAddr1 = nullptr; + Value *RetAddr2 = nullptr; + + // maps new basic blocks in the merged function to their original + // correspondents + std::unordered_map BlocksF1; + std::unordered_map BlocksF2; + std::unordered_map MaterialNodes; + + CodeGen(Blocks1, Blocks2, EntryBB1, EntryBB2, MergedFunc, IsFunc1, PreBB, + AlignedSeq, VMap, BlocksF1, BlocksF2, MaterialNodes); + + if (RequiresUnifiedReturn) { + IRBuilder<> Builder(PreBB); + RetUnifiedAddr = Builder.CreateAlloca(ReturnType); + CodeGenerator::insert(dyn_cast(RetUnifiedAddr)); + + RetAddr1 = Builder.CreateAlloca(RetType1); + RetAddr2 = Builder.CreateAlloca(RetType2); + CodeGenerator::insert(dyn_cast(RetAddr1)); + CodeGenerator::insert(dyn_cast(RetAddr2)); + } + + // errs() << "Assigning label operands\n"; + + std::set XorBrConds; + // assigning label operands + + for (auto &Entry : AlignedSeq) { + Instruction *I1 = nullptr; + Instruction *I2 = nullptr; + + if (Entry.get(0) != nullptr) + I1 = dyn_cast(Entry.get(0)); + if (Entry.get(1) != nullptr) + I2 = dyn_cast(Entry.get(1)); + + // Skip non-instructions + if (I1 == nullptr && I2 == nullptr) + continue; + + if (Entry.match()) { + + Instruction *I = I1; + if 
(I1->getOpcode() == Instruction::Ret) { + I = (I1->getNumOperands() >= I2->getNumOperands()) ? I1 : I2; + } else { + assert(I1->getNumOperands() == I2->getNumOperands() && + "Num of Operands SHOULD be EQUAL\n"); + } + + auto *NewI = dyn_cast(VMap[I]); + + bool Handled = false; + /* + BranchInst *NewBr = dyn_cast(NewI); + if (EnableOperandReordering && NewBr!=nullptr && NewBr->isConditional()) { + BranchInst *Br1 = dyn_cast(I1); + BranchInst *Br2 = dyn_cast(I2); + + BasicBlock *SuccBB10 = + dyn_cast(MapValue(Br1->getSuccessor(0), VMap)); BasicBlock + *SuccBB11 = dyn_cast(MapValue(Br1->getSuccessor(1), VMap)); + + BasicBlock *SuccBB20 = + dyn_cast(MapValue(Br2->getSuccessor(0), VMap)); BasicBlock + *SuccBB21 = dyn_cast(MapValue(Br2->getSuccessor(1), VMap)); + + if (SuccBB10!=nullptr && SuccBB11!=nullptr && SuccBB10==SuccBB21 && + SuccBB20==SuccBB11) { if (Debug) errs() << "OptimizationTriggered: Labels of Conditional Branch Reordering\n"; + + XorBrConds.insert(NewBr); + NewBr->setSuccessor(0,SuccBB20); + NewBr->setSuccessor(1,SuccBB21); + Handled = true; + } + } + */ + if (!Handled) { + for (unsigned i = 0; i < I->getNumOperands(); i++) { + + Value *F1V = nullptr; + Value *V1 = nullptr; + if (i < I1->getNumOperands()) { + F1V = I1->getOperand(i); + V1 = MapValue(F1V, VMap); + // assert(V1!=nullptr && "Mapped value should NOT be NULL!"); + if (V1 == nullptr) { + if (Debug) + errs() << "ERROR: Null value mapped: V1 = " + "MapValue(I1->getOperand(i), " + "VMap);\n"; + // MergedFunc->eraseFromParent(); +#ifdef TIME_STEPS_DEBUG + TimeCodeGen.stopTimer(); +#endif + return false; + } + } else { + V1 = UndefValue::get(I2->getOperand(i)->getType()); + } + + Value *F2V = nullptr; + Value *V2 = nullptr; + if (i < I2->getNumOperands()) { + F2V = I2->getOperand(i); + V2 = MapValue(F2V, VMap); + // assert(V2!=nullptr && "Mapped value should NOT be NULL!"); + + if (V2 == nullptr) { + if (Debug) + errs() << "ERROR: Null value mapped: V2 = " + "MapValue(I2->getOperand(i), " + "VMap);\n"; + // MergedFunc->eraseFromParent(); +#ifdef TIME_STEPS_DEBUG + TimeCodeGen.stopTimer(); +#endif + return false; + } + + } else { + V2 = UndefValue::get(I1->getOperand(i)->getType()); + } + + assert(V1 != nullptr && "Value should NOT be null!"); + assert(V2 != nullptr && "Value should NOT be null!"); + + Value *V = V1; // first assume that V1==V2 + + // handling just label operands for now + if (!isa(V)) + continue; + + auto *F1BB = dyn_cast(F1V); + auto *F2BB = dyn_cast(F2V); + + if (V1 != V2) { + auto *BB1 = dyn_cast(V1); + auto *BB2 = dyn_cast(V2); + + // auto CacheKey = std::pair(BB1, BB2); + BasicBlock *SelectBB = + BasicBlock::Create(Context, "bb.select", MergedFunc); + IRBuilder<> BuilderBB(SelectBB); + + BlocksF1[SelectBB] = I1->getParent(); + BlocksF2[SelectBB] = I2->getParent(); + + BuilderBB.CreateCondBr(IsFunc1, BB1, BB2); + V = SelectBB; + } + + if (F1BB->isLandingPad() || F2BB->isLandingPad()) { + LandingPadInst *LP1 = F1BB->getLandingPadInst(); + LandingPadInst *LP2 = F2BB->getLandingPadInst(); + assert((LP1 != nullptr && LP2 != nullptr) && + "Should be both as per the BasicBlock match!"); + (void)LP2; + + BasicBlock *LPadBB = + BasicBlock::Create(Context, "lpad.bb", MergedFunc); + IRBuilder<> BuilderBB(LPadBB); + + Instruction *NewLP = LP1->clone(); + BuilderBB.Insert(NewLP); + + BuilderBB.CreateBr(dyn_cast(V)); + + BlocksF1[LPadBB] = I1->getParent(); + BlocksF2[LPadBB] = I2->getParent(); + + VMap[F1BB->getLandingPadInst()] = NewLP; + VMap[F2BB->getLandingPadInst()] = NewLP; + + V = LPadBB; + } + 
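+            // By this point V designates the label operand for the merged
+            // instruction: either the common mapped successor (V1 == V2), a
+            // fresh bb.select block that branches on IsFunc1 to the two
+            // distinct successors, or an lpad.bb trampoline that re-anchors
+            // the cloned landingpad before jumping on.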
NewI->setOperand(i, V); + } + } + + } else { // if(entry.match())-else + + auto AssignLabelOperands = + [&](Instruction *I, + std::unordered_map &BlocksReMap) + -> bool { + auto *NewI = dyn_cast(VMap[I]); + // if (isa(I)) + // errs() << "Setting operand in " << NewI->getParent()->getName() << " + // : " << NewI->getName() << " " << NewI->getOpcodeName() << "\n"; + for (unsigned i = 0; i < I->getNumOperands(); i++) { + // handling just label operands for now + if (!isa(I->getOperand(i))) + continue; + auto *FXBB = dyn_cast(I->getOperand(i)); + + Value *V = MapValue(FXBB, VMap); + // assert( V!=nullptr && "Mapped value should NOT be NULL!"); + if (V == nullptr) + return false; // ErrorResponse; + + if (FXBB->isLandingPad()) { + + LandingPadInst *LP = FXBB->getLandingPadInst(); + assert(LP != nullptr && "Should have a landingpad inst!"); + + BasicBlock *LPadBB = + BasicBlock::Create(Context, "lpad.bb", MergedFunc); + IRBuilder<> BuilderBB(LPadBB); + + Instruction *NewLP = LP->clone(); + BuilderBB.Insert(NewLP); + VMap[LP] = NewLP; + BlocksReMap[LPadBB] = I->getParent(); //FXBB; + + BuilderBB.CreateBr(dyn_cast(V)); + + V = LPadBB; + } + + NewI->setOperand(i, V); + // if (isa(NewI)) + // errs() << "Operand " << i << ": " << V->getName() << "\n"; + } + return true; + }; + + if (I1 != nullptr && !AssignLabelOperands(I1, BlocksF1)) { + if (Debug) + errs() << "ERROR: Value should NOT be null\n"; + // MergedFunc->eraseFromParent(); + +#ifdef TIME_STEPS_DEBUG + TimeCodeGen.stopTimer(); +#endif + return false; + } + if (I2 != nullptr && !AssignLabelOperands(I2, BlocksF2)) { + if (Debug) + errs() << "ERROR: Value should NOT be null\n"; + // MergedFunc->eraseFromParent(); + +#ifdef TIME_STEPS_DEBUG + TimeCodeGen.stopTimer(); +#endif + return false; + } + } + } + + // errs() << "Assigning value operands\n"; + + auto MergeValues = [&](Value *V1, Value *V2, + Instruction *InsertPt) -> Value * { + if (V1 == V2) + return V1; + + if (V1 == ConstantInt::getTrue(Context) && V2 == ConstantInt::getFalse(Context)) + return IsFunc1; + + if (V1 == ConstantInt::getFalse(Context) && V2 == ConstantInt::getTrue(Context)) { + IRBuilder<> Builder(InsertPt); + /// TODO: create a single not(IsFunc1) for each merged function that needs it + return Builder.CreateNot(IsFunc1); + } + + auto *IV1 = dyn_cast(V1); + auto *IV2 = dyn_cast(V2); + + if (IV1 && IV2) { + // if both IV1 and IV2 are non-merged values + if (BlocksF2.find(IV1->getParent()) == BlocksF2.end() && + BlocksF1.find(IV2->getParent()) == BlocksF1.end()) { + CoalescingCandidates[IV1][IV2]++; + CoalescingCandidates[IV2][IV1]++; + } + } + + IRBuilder<> Builder(InsertPt); + Instruction *Sel = (Instruction *)Builder.CreateSelect(IsFunc1, V1, V2); + ListSelects.push_back(dyn_cast(Sel)); + return Sel; + }; + + auto AssignOperands = [&](Instruction *I, bool IsFuncId1) -> bool { + auto *NewI = dyn_cast(VMap[I]); + IRBuilder<> Builder(NewI); + + if (I->getOpcode() == Instruction::Ret && RequiresUnifiedReturn) { + Value *V = MapValue(I->getOperand(0), VMap); + if (V == nullptr) { + return false; // ErrorResponse; + } + if (V->getType() != ReturnType) { + // Value *Addr = (IsFuncId1 ? 
RetAddr1 : RetAddr2); + Value *Addr = Builder.CreateAlloca(V->getType()); + Builder.CreateStore(V, Addr); + Value *CastedAddr = + Builder.CreatePointerCast(Addr, RetUnifiedAddr->getType()); + V = Builder.CreateLoad(ReturnType, CastedAddr); + } + NewI->setOperand(0, V); + } else { + for (unsigned i = 0; i < I->getNumOperands(); i++) { + if (isa(I->getOperand(i))) + continue; + + Value *V = MapValue(I->getOperand(i), VMap); + // assert( V!=nullptr && "Mapped value should NOT be NULL!"); + if (V == nullptr) { + return false; // ErrorResponse; + } + + // Value *CastedV = createCastIfNeeded(V, + // NewI->getOperand(i)->getType(), Builder, IntPtrTy); + NewI->setOperand(i, V); + } + } + + return true; + }; + + for (auto &Entry : AlignedSeq) { + Instruction *I1 = nullptr; + Instruction *I2 = nullptr; + + if (Entry.get(0) != nullptr) + I1 = dyn_cast(Entry.get(0)); + if (Entry.get(1) != nullptr) + I2 = dyn_cast(Entry.get(1)); + + if (I1 != nullptr && I2 != nullptr) { + + // Instruction *I1 = dyn_cast(MN->N1->getValue()); + // Instruction *I2 = dyn_cast(MN->N2->getValue()); + + Instruction *I = I1; + if (I1->getOpcode() == Instruction::Ret) { + I = (I1->getNumOperands() >= I2->getNumOperands()) ? I1 : I2; + } else { + assert(I1->getNumOperands() == I2->getNumOperands() && + "Num of Operands SHOULD be EQUAL\n"); + } + + auto *NewI = dyn_cast(VMap[I]); + + IRBuilder<> Builder(NewI); + + if (EnableOperandReordering && isa(NewI) && + I->isCommutative()) { + + auto *BO1 = dyn_cast(I1); + auto *BO2 = dyn_cast(I2); + Value *VL1 = MapValue(BO1->getOperand(0), VMap); + Value *VL2 = MapValue(BO2->getOperand(0), VMap); + Value *VR1 = MapValue(BO1->getOperand(1), VMap); + Value *VR2 = MapValue(BO2->getOperand(1), VMap); + if (VL1 == VR2 && VL2 != VR2) { + std::swap(VL2, VR2); + // CountOpReorder++; + } else if (VL2 == VR1 && VL1 != VR1) { + std::swap(VL1, VR1); + } + + std::vector> Vs; + Vs.emplace_back(VL1, VL2); + Vs.emplace_back(VR1, VR2); + + for (unsigned i = 0; i < Vs.size(); i++) { + Value *V1 = Vs[i].first; + Value *V2 = Vs[i].second; + + Value *V = MergeValues(V1, V2, NewI); + if (V == nullptr) { + if (Debug) { + errs() << "Could Not select:\n"; + errs() << "ERROR: Value should NOT be null\n"; + } + // MergedFunc->eraseFromParent(); +#ifdef TIME_STEPS_DEBUG + TimeCodeGen.stopTimer(); +#endif + return false; // ErrorResponse; + } + + // TODO: cache the created instructions + // Value *CastedV = CreateCast(Builder, V, + // NewI->getOperand(i)->getType()); + Value *CastedV = createCastIfNeeded(V, NewI->getOperand(i)->getType(), + Builder, IntPtrTy); + NewI->setOperand(i, CastedV); + } + } else { + for (unsigned i = 0; i < I->getNumOperands(); i++) { + if (isa(I->getOperand(i))) + continue; + + Value *V1 = nullptr; + if (i < I1->getNumOperands()) { + V1 = MapValue(I1->getOperand(i), VMap); + // assert(V1!=nullptr && "Mapped value should NOT be NULL!"); + if (V1 == nullptr) { + if (Debug) + errs() << "ERROR: Null value mapped: V1 = " + "MapValue(I1->getOperand(i), " + "VMap);\n"; + // MergedFunc->eraseFromParent(); +#ifdef TIME_STEPS_DEBUG + TimeCodeGen.stopTimer(); +#endif + return false; + } + } else { + V1 = UndefValue::get(I2->getOperand(i)->getType()); + } + + Value *V2 = nullptr; + if (i < I2->getNumOperands()) { + V2 = MapValue(I2->getOperand(i), VMap); + // assert(V2!=nullptr && "Mapped value should NOT be NULL!"); + + if (V2 == nullptr) { + if (Debug) + errs() << "ERROR: Null value mapped: V2 = " + "MapValue(I2->getOperand(i), " + "VMap);\n"; + // MergedFunc->eraseFromParent(); +#ifdef 
TIME_STEPS_DEBUG + TimeCodeGen.stopTimer(); +#endif + return false; + } + + } else { + V2 = UndefValue::get(I1->getOperand(i)->getType()); + } + + assert(V1 != nullptr && "Value should NOT be null!"); + assert(V2 != nullptr && "Value should NOT be null!"); + + Value *V = MergeValues(V1, V2, NewI); + if (V == nullptr) { + if (Debug) { + errs() << "Could Not select:\n"; + errs() << "ERROR: Value should NOT be null\n"; + } + // MergedFunc->eraseFromParent(); +#ifdef TIME_STEPS_DEBUG + TimeCodeGen.stopTimer(); +#endif + return false; // ErrorResponse; + } + + // Value *CastedV = createCastIfNeeded(V, + // NewI->getOperand(i)->getType(), Builder, IntPtrTy); + NewI->setOperand(i, V); + + } // end for operands + } + } // end if isomorphic + else { + // PDGNode *N = MN->getUniqueNode(); + if (I1 != nullptr && !AssignOperands(I1, true)) { + if (Debug) + errs() << "ERROR: Value should NOT be null\n"; + // MergedFunc->eraseFromParent(); +#ifdef TIME_STEPS_DEBUG + TimeCodeGen.stopTimer(); +#endif + return false; + } + if (I2 != nullptr && !AssignOperands(I2, false)) { + if (Debug) + errs() << "ERROR: Value should NOT be null\n"; + // MergedFunc->eraseFromParent(); +#ifdef TIME_STEPS_DEBUG + TimeCodeGen.stopTimer(); +#endif + return false; + } + } // end 'if-else' non-isomorphic + + } // end for nodes + if(Debug) + errs() << "NumSelects: " << ListSelects.size() << "\n"; + if (ListSelects.size() > MaxNumSelection) { + if(Debug) + errs() << "Bailing out: Operand selection threshold\n"; +#ifdef TIME_STEPS_DEBUG + TimeCodeGen.stopTimer(); +#endif + return false; + } + + // errs() << "Assigning PHI operands\n"; + + auto AssignPHIOperandsInBlock = + [&](BasicBlock *BB, + std::unordered_map &BlocksReMap) -> bool { + for (Instruction &I : *BB) { + if (auto *PHI = dyn_cast(&I)) { + auto *NewPHI = dyn_cast(VMap[PHI]); + + std::set FoundIndices; + + for (auto It = pred_begin(NewPHI->getParent()), + E = pred_end(NewPHI->getParent()); + It != E; It++) { + + BasicBlock *NewPredBB = *It; + + Value *V = nullptr; + + // if (BlocksReMap.find(NewPredBB) != BlocksReMap.end()) { + if (BlocksReMap.find(NewPredBB) != BlocksReMap.end()) { + int Index = PHI->getBasicBlockIndex(BlocksReMap[NewPredBB]); + if (Index >= 0) { + V = MapValue(PHI->getIncomingValue(Index), VMap); + FoundIndices.insert(Index); + } + } + + if (V == nullptr){ + V = UndefValue::get(NewPHI->getType()); + // errs()<<"feisen:1;"; + } + // errs()<<"feisen:1|"; + + // IRBuilder<> Builder(NewPredBB->getTerminator()); + // Value *CastedV = createCastIfNeeded(V, NewPHI->getType(), Builder, + // IntPtrTy); + NewPHI->addIncoming(V, NewPredBB); + } + // errs()<<"feisen:1]"; + if (FoundIndices.size() != PHI->getNumIncomingValues()){ + if(Debug){ + // print the PHI node / do not use dump + PHI->print(errs()); + errs()<<"\n"; + + errs()<<"feisen: "; + errs()<<"FoundIndices.size(): "<getNumIncomingValues()"<getNumIncomingValues()<<" \n"; + } + return false; + } + } + } + return true; + }; + + for (BasicBlock *BB1 : Blocks1) { + if (!AssignPHIOperandsInBlock(BB1, BlocksF1)) { + if (Debug) + errs() << "ERROR: PHI assignment\n"; + //MergedFunc->eraseFromParent(); + +#ifdef TIME_STEPS_DEBUG + TimeCodeGen.stopTimer(); +#endif + return false; + } + } + for (BasicBlock *BB2 : Blocks2) { + if (!AssignPHIOperandsInBlock(BB2, BlocksF2)) { + if (Debug) + errs() << "ERROR: PHI assignment\n"; + //MergedFunc->eraseFromParent(); +#ifdef TIME_STEPS_DEBUG + TimeCodeGen.stopTimer(); +#endif + return false; + } + } + +#ifdef CHANGES + // Replace select statements by merged PHIs + + // 
Collect candidate pairs of PHI Nodes + SmallSet, 16> CandPHI; + for (Instruction *I: ListSelects) { + SelectInst *SI = dyn_cast(I); + assert(SI != nullptr); + + PHINode *PT = dyn_cast(SI->getTrueValue()); + PHINode *PF = dyn_cast(SI->getFalseValue()); + + if (PT == nullptr || PF == nullptr) + continue; + + // Only pair PHI Nodes in the same block + if (PT->getParent() != PF->getParent()) + continue; + + CandPHI.insert({PT, PF}); + } + + SmallSet RemovedPHIs; + for (auto [PT, PF] : CandPHI) { + if ((RemovedPHIs.count(PT) > 0) || (RemovedPHIs.count(PF) > 0)) + continue; + // Merge PT and PF if: + // 1) their defined incoming values do not overlap + // 2) their uses are only select statements on IsFunc1 + bool valid = true; + SmallVector CandSel; + + // Are PHIs mergeable? + for (unsigned i = 0; i < PT->getNumIncomingValues() && valid; ++i) { + // if PT incoming value is Undef, this edge pair is mergeable + Value *VT = PT->getIncomingValue(i); + if (dyn_cast(VT) != nullptr) + continue; + + // if the PF incoming value for the same block is Undef, + // this edge pair is mergeable + BasicBlock *PredBB = PT->getIncomingBlock(i); + if (PF->getBasicBlockIndex(PredBB) < 0) { + errs() << "PHI ERROR\n"; + //Comment out this code temporarily to eliminate the linking error. :feisen + // PT->dump(); + // PF->dump(); + // MergedFunc->dump(); + //Comment out this code temporarily to eliminate the linking error. :feisen + } + Value *VF = PF->getIncomingValueForBlock(PredBB); + if(dyn_cast(VF) != nullptr) + continue; + + // If the two incoming values are the same, then we can merge them + if (VT == VF) + continue; + + valid = false; + } + + if (!valid) + continue; + + // Are PHIs only used together in select statements? + for (auto *UI: PT->users()) { + SelectInst *SI = dyn_cast(UI); + if (SI == nullptr) { + valid = false; + break; + } + + if ((SI->getTrueValue() != PT) || (SI->getFalseValue() != PF)) { + valid = false; + break; + } + + if (SI->getCondition() != IsFunc1) { + valid = false; + break; + } + CandSel.push_back(SI); + } + + if (!valid) + continue; + + // Do the actual PHI merging using PT + for (unsigned i = 0; i < PT->getNumIncomingValues() && valid; ++i) { + // If edge is set, use it + if (dyn_cast(PT->getIncomingValue(i)) == nullptr) + continue; + + // If edge not set, copy it from PF + BasicBlock *PredBB = PT->getIncomingBlock(i); + PT->setIncomingValue(i, PF->getIncomingValueForBlock(PredBB)); + // errs()<<"feisen:5"; + } + + PF->replaceAllUsesWith(PT); + PF->eraseFromParent(); + RemovedPHIs.insert(PF); + + // Replace all uses of the select statements with PT + for (SelectInst *SI: CandSel) { + SI->replaceAllUsesWith(PT); + SI->eraseFromParent(); + } + } +#endif + + // errs() << "Collecting offending instructions\n"; + DominatorTree DT(*MergedFunc); + + for (Instruction &I : instructions(MergedFunc)) { + if (auto *PHI = dyn_cast(&I)) { + for (unsigned i = 0; i < PHI->getNumIncomingValues(); i++) { + BasicBlock *BB = PHI->getIncomingBlock(i); + if (BB == nullptr) + errs() << "Null incoming block\n"; + Value *V = PHI->getIncomingValue(i); + if (V == nullptr) + errs() << "Null incoming value\n"; + if (auto *IV = dyn_cast(V)) { + if (BB->getTerminator() == nullptr) { + if (Debug) + errs() << "ERROR: Null terminator\n"; + // MergedFunc->eraseFromParent(); +#ifdef TIME_STEPS_DEBUG + TimeCodeGen.stopTimer(); +#endif + return false; + } + if (!DT.dominates(IV, BB->getTerminator())) { + if (OffendingInsts.count(IV) == 0) { + OffendingInsts.insert(IV); + LinearOffendingInsts.push_back(IV); + } + } + } 
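+          // An "offending" instruction is a definition that, after the two
+          // CFGs are interleaved, no longer dominates one of its uses. Such
+          // values are collected here and later demoted to stack slots (see
+          // MemfyInst below) so PromoteMemToReg can rebuild valid SSA form.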
+ } + } else { + for (unsigned i = 0; i < I.getNumOperands(); i++) { + if (I.getOperand(i) == nullptr) { + // MergedFunc->dump(); + // I.getParent()->dump(); + // errs() << "Null operand\n"; + // I.dump(); + if (Debug) + errs() << "ERROR: Null operand\n"; + // MergedFunc->eraseFromParent(); +#ifdef TIME_STEPS_DEBUG + TimeCodeGen.stopTimer(); +#endif + return false; + } + if (auto *IV = dyn_cast(I.getOperand(i))) { + if (!DT.dominates(IV, &I)) { + if (OffendingInsts.count(IV) == 0) { + OffendingInsts.insert(IV); + LinearOffendingInsts.push_back(IV); + } + } + } + } + } + } + + for (BranchInst *NewBr : XorBrConds) { + IRBuilder<> Builder(NewBr); + Value *XorCond = Builder.CreateXor(NewBr->getCondition(), IsFunc1); + NewBr->setCondition(XorCond); + } + +#ifdef TIME_STEPS_DEBUG + TimeCodeGen.stopTimer(); +#endif + +#ifdef TIME_STEPS_DEBUG + TimeCodeGenFix.startTimer(); +#endif + + auto StoreInstIntoAddr = [](Instruction *IV, Value *Addr) { + IRBuilder<> Builder(IV->getParent()); + if (IV->isTerminator()) { + BasicBlock *SrcBB = IV->getParent(); + if (auto *II = dyn_cast(IV)) { + BasicBlock *DestBB = II->getNormalDest(); + + Builder.SetInsertPoint(&*DestBB->getFirstInsertionPt()); + // create PHI + PHINode *PHI = Builder.CreatePHI(IV->getType(), 0); + for (auto PredIt = pred_begin(DestBB), PredE = pred_end(DestBB); + PredIt != PredE; PredIt++) { + BasicBlock *PredBB = *PredIt; + if (PredBB == SrcBB) { + PHI->addIncoming(IV, PredBB); + // errs()<<"feisen:7"; + } else { + PHI->addIncoming(UndefValue::get(IV->getType()), PredBB); + // errs()<<"feisen:2;"; + } + } + Builder.CreateStore(PHI, Addr); + } else { + for (auto SuccIt = succ_begin(SrcBB), SuccE = succ_end(SrcBB); + SuccIt != SuccE; SuccIt++) { + BasicBlock *DestBB = *SuccIt; + + Builder.SetInsertPoint(&*DestBB->getFirstInsertionPt()); + // create PHI + PHINode *PHI = Builder.CreatePHI(IV->getType(), 0); + for (auto PredIt = pred_begin(DestBB), PredE = pred_end(DestBB); + PredIt != PredE; PredIt++) { + BasicBlock *PredBB = *PredIt; + if (PredBB == SrcBB) { + PHI->addIncoming(IV, PredBB); + // errs()<<"feisen:8"; + } else { + PHI->addIncoming(UndefValue::get(IV->getType()), PredBB); + // errs()<<"feisen:3;"; + } + } + Builder.CreateStore(PHI, Addr); + } + } + } else { + Instruction *LastI = nullptr; + Instruction *InsertPt = nullptr; + for (Instruction &I : *IV->getParent()) { + InsertPt = &I; + if (LastI == IV) + break; + LastI = &I; + } + if (isa(InsertPt) || isa(InsertPt)) { + Builder.SetInsertPoint(&*IV->getParent()->getFirstInsertionPt()); + //Builder.SetInsertPoint(IV->getParent()->getTerminator()); + } else + Builder.SetInsertPoint(InsertPt); + + Builder.CreateStore(IV, Addr); + } + }; + + auto MemfyInst = [&](std::set &InstSet) -> AllocaInst * { + if (InstSet.empty()) + return nullptr; + IRBuilder<> Builder(&*PreBB->getFirstInsertionPt()); + AllocaInst *Addr = Builder.CreateAlloca((*InstSet.begin())->getType()); + Type *Ty = Addr->getAllocatedType(); + + for (Instruction *I : InstSet) { + for (auto UIt = I->use_begin(), E = I->use_end(); UIt != E;) { + Use &UI = *UIt; + UIt++; + + auto *User = cast(UI.getUser()); + + if (auto *PHI = dyn_cast(User)) { + /// TODO: make sure getOperandNo is getting the correct incoming edge + auto InsertionPt = PHI->getIncomingBlock(UI.getOperandNo())->getTerminator(); + /// TODO: If the terminator of the incoming block is the producer of + // the value we want to store, the load cannot be inserted between + // the producer and the user. Something more complex is needed. 
+ if (InsertionPt == I) + continue; + IRBuilder<> Builder(InsertionPt); + UI.set(Builder.CreateLoad(Ty, Addr)); + } else { + IRBuilder<> Builder(User); + UI.set(Builder.CreateLoad(Ty, Addr)); + } + } + } + + for (Instruction *I : InstSet) + StoreInstIntoAddr(I, Addr); + + return Addr; + }; + + auto isCoalescingProfitable = [&](Instruction *I1, Instruction *I2) -> bool { + std::set BBSet1; + std::set UnionBB; + for (User *U : I1->users()) { + if (auto *UI = dyn_cast(U)) { + BasicBlock *BB1 = UI->getParent(); + BBSet1.insert(BB1); + UnionBB.insert(BB1); + } + } + + unsigned Intersection = 0; + for (User *U : I2->users()) { + if (auto *UI = dyn_cast(U)) { + BasicBlock *BB2 = UI->getParent(); + UnionBB.insert(BB2); + if (BBSet1.find(BB2) != BBSet1.end()) + Intersection++; + } + } + + const float Threshold = 0.7; + return (float(Intersection) / float(UnionBB.size()) > Threshold); + }; + + auto OptimizeCoalescing = + [&](Instruction *I, std::set &InstSet, + std::map> + &CoalescingCandidates, + std::set &Visited) { + Instruction *OtherI = nullptr; + unsigned Score = 0; + if (CoalescingCandidates.find(I) != CoalescingCandidates.end()) { + for (auto &Pair : CoalescingCandidates[I]) { + if (Pair.second > Score && + Visited.find(Pair.first) == Visited.end()) { + if (isCoalescingProfitable(I, Pair.first)) { + OtherI = Pair.first; + Score = Pair.second; + } + } + } + } + /* + if (OtherI==nullptr) { + for (Instruction *OI : OffendingInsts) { + if (OI->getType()!=I->getType()) continue; + if (Visited.find(OI)!=Visited.end()) continue; + if (CoalescingCandidates.find(OI)!=CoalescingCandidates.end()) + continue; if( (BlocksF2.find(I->getParent())==BlocksF2.end() && + BlocksF1.find(OI->getParent())==BlocksF1.end()) || + (BlocksF2.find(OI->getParent())==BlocksF2.end() && + BlocksF1.find(I->getParent())==BlocksF1.end()) ) { OtherI = OI; break; + } + } + } + */ + if (OtherI) { + InstSet.insert(OtherI); + // errs() << "Coalescing: " << GetValueName(I->getParent()) << ":"; + // I->dump(); errs() << "With: " << GetValueName(OtherI->getParent()) + // << ":"; OtherI->dump(); + } + }; + + // errs() << "Finishing code\n"; + if (MergedFunc != nullptr) { + // errs() << "Offending: " << OffendingInsts.size() << " "; + // errs() << ((float)OffendingInsts.size())/((float)AlignedSeq.size()) << " + // : "; if (OffendingInsts.size()>1000) { if (false) { + if (((float)OffendingInsts.size()) / ((float)AlignedSeq.size()) > 4.5) { + if (Debug) + errs() << "Bailing out\n"; +#ifdef TIME_STEPS_DEBUG + TimeCodeGenFix.stopTimer(); +#endif + return false; + } + //errs() << "Fixing Domination:\n"; + //MergedFunc->dump(); + std::set Visited; + for (Instruction *I : LinearOffendingInsts) { + if (Visited.find(I) != Visited.end()) + continue; + + std::set InstSet; + InstSet.insert(I); + + // Create a coalescing group in InstSet + if (EnableSALSSACoalescing) + OptimizeCoalescing(I, InstSet, CoalescingCandidates, Visited); + + for (Instruction *OtherI : InstSet) + Visited.insert(OtherI); + + AllocaInst *Addr = MemfyInst(InstSet); + if (Addr) + Allocas.push_back(Addr); + } + + //errs() << "Fixed Domination:\n"; + //MergedFunc->dump(); + + DominatorTree DT(*MergedFunc); + PromoteMemToReg(Allocas, DT, nullptr); + + //errs() << "Mem2Reg:\n"; + //MergedFunc->dump(); + + if (verifyFunction(*MergedFunc)) { + if (Verbose) + errs() << "ERROR: Produced Broken Function!\n"; +#ifdef TIME_STEPS_DEBUG + TimeCodeGenFix.stopTimer(); +#endif + return false; + } +#ifdef TIME_STEPS_DEBUG + TimeCodeGenFix.stopTimer(); +#endif +#ifdef TIME_STEPS_DEBUG + 
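+  // Note: postProcessFunction below currently runs only CFG simplification
+  // over the merged function; the mem2reg and instcombine steps inside it
+  // remain commented out.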
TimePostOpt.startTimer(); +#endif + postProcessFunction(*MergedFunc); +#ifdef TIME_STEPS_DEBUG + TimePostOpt.stopTimer(); +#endif + // errs() << "PostProcessing:\n"; + // MergedFunc->dump(); + } + + + //feisen:debug:use this to avoid some bugs int EarlyCSE + bool Success = true; + int counter = 0; + //array deque basicblock + std::deque myDeque; + for(BasicBlock &b: *MergedFunc){ + // if(b.getSinglePredecessor()==nullptr){ + // // return false; + // } + + // if(Debug) + // errs()<<"in fm.cpp block name: "<getName()+"_block_"+std::to_string(counter++)); + // if(b.getName().equals("")){ + // // myDeque.push_back(&b); + // if(Debug) + // errs()<<"cannot set basic block name"<eraseFromParent(); + // } + if(!Success) return false; + + return MergedFunc != nullptr; +} diff --git a/llvm/lib/Transforms/IPO/MergeFunctions.cpp b/llvm/lib/Transforms/IPO/MergeFunctions.cpp index b850591b4aa65dcbf7e5724b72618ce9e0620012..9e82d72189f132ce06225c845f0a62c5583cd165 100644 --- a/llvm/lib/Transforms/IPO/MergeFunctions.cpp +++ b/llvm/lib/Transforms/IPO/MergeFunctions.cpp @@ -325,6 +325,13 @@ ModulePass *llvm::createMergeFunctionsPass() { PreservedAnalyses MergeFunctionsPass::run(Module &M, ModuleAnalysisManager &AM) { MergeFunctions MF; + if(M.getName().find("/gold")!=std::string::npos + || M.getName().find("/binutils")!=std::string::npos + || M.getName().find("xlog")!=std::string::npos + || M.getName().find("xcommon")!=std::string::npos + || M.getName().find("exportfs")!=std::string::npos){ + return PreservedAnalyses::all(); + } if (!MF.runOnModule(M)) return PreservedAnalyses::all(); return PreservedAnalyses::none(); diff --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp index 71c763de43b4cb981157099ecad20ec0b5733ad9..539c11950c2c6cda652a650b406fe484d9d7b73e 100644 --- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp +++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp @@ -134,6 +134,11 @@ static constexpr unsigned InstCombineDefaultInfiniteLoopThreshold = 100; static constexpr unsigned InstCombineDefaultInfiniteLoopThreshold = 1000; #endif +static cl::opt EnableCodeSizeInst( + "enable-code-size-inst", cl::init(true), cl::Hidden, + cl::desc("Enable optimizations for code size as part of the optimization " + "pipeline")); + static cl::opt EnableCodeSinking("instcombine-code-sinking", cl::desc("Enable code sinking"), cl::init(true)); @@ -4568,12 +4573,28 @@ static bool combineInstructionsOverFunction( MadeIRChange |= prepareICWorklistFromFunction(F, DL, &TLI, Worklist); - InstCombinerImpl IC(Worklist, Builder, F.hasMinSize(), AA, AC, TLI, TTI, DT, + // ======== code size ======= + if(EnableCodeSizeInst) { + InstCombinerImpl IC(Worklist, Builder, true, AA, AC, TLI, TTI, DT, + ORE, BFI, PSI, DL, LI); + IC.MaxArraySizeForCombine = MaxArraySize; + if (!IC.run()) + break; + } + else { + InstCombinerImpl IC(Worklist, Builder, F.hasMinSize(), AA, AC, TLI, TTI, DT, ORE, BFI, PSI, DL, LI); - IC.MaxArraySizeForCombine = MaxArraySize; + IC.MaxArraySizeForCombine = MaxArraySize; + if (!IC.run()) + break; + } + // ========================== + // InstCombinerImpl IC(Worklist, Builder, F.hasMinSize(), AA, AC, TLI, TTI, DT, + // ORE, BFI, PSI, DL, LI); + //IC.MaxArraySizeForCombine = MaxArraySize; - if (!IC.run()) - break; + //if (!IC.run()) + // break; MadeIRChange = true; } diff --git a/llvm/lib/Transforms/Scalar/CMakeLists.txt b/llvm/lib/Transforms/Scalar/CMakeLists.txt index 
eb008c15903a744b5685197a50598d3ae950f486..760534647ea4ae4aa8a82709729396f637a2d7f3 100644 --- a/llvm/lib/Transforms/Scalar/CMakeLists.txt +++ b/llvm/lib/Transforms/Scalar/CMakeLists.txt @@ -79,6 +79,7 @@ add_llvm_component_library(LLVMScalarOpts TLSVariableHoist.cpp WarnMissedTransforms.cpp + ADDITIONAL_HEADER_DIRS ${LLVM_MAIN_INCLUDE_DIR}/llvm/Transforms ${LLVM_MAIN_INCLUDE_DIR}/llvm/Transforms/Scalar diff --git a/llvm/lib/Transforms/Scalar/EarlyCSE.cpp b/llvm/lib/Transforms/Scalar/EarlyCSE.cpp index cf282495412277e6d9c51d3d1c79c1433dbc782b..d51899bdd31d37ef22112d75be699f050a074c9c 100644 --- a/llvm/lib/Transforms/Scalar/EarlyCSE.cpp +++ b/llvm/lib/Transforms/Scalar/EarlyCSE.cpp @@ -1208,9 +1208,15 @@ bool EarlyCSE::processNode(DomTreeNode *Node) { // have invalidated the live-out memory values of our parent value. For now, // just be conservative and invalidate memory if this block has multiple // predecessors. + + //feisen:debug + // errs()<<"BB->getSinglePredecessor() before:\n"; if (!BB->getSinglePredecessor()) ++CurrentGeneration; + //feisen:debug + // errs()<<"BB->getSinglePredecessor() after:\n"; + // If this node has a single predecessor which ends in a conditional branch, // we can infer the value of the branch condition given that we took this // path. We need the single predecessor to ensure there's not another path @@ -1226,6 +1232,9 @@ bool EarlyCSE::processNode(DomTreeNode *Node) { } } + //feisen:debug + // errs()<<"BB->getSinglePredecessor() after 2:\n"; + /// LastStore - Keep track of the last non-volatile store that we saw... for /// as long as there in no instruction that reads memory. If we see a store /// to the same location, we delete the dead store. This zaps trivial dead @@ -1234,7 +1243,59 @@ bool EarlyCSE::processNode(DomTreeNode *Node) { // See if any instructions in the block can be eliminated. If so, do it. If // not, add them to AvailableValues. 
+ + //feisen:debug + // ----------------------------- +// int counter = 0; +// errs()<<"BB->getSinglePredecessor() after 3:\n"; +// BB->getInstList(); +// errs()<<"BB->getSinglePredecessor() after 3.1:\n"; +// make_early_inc_range(BB->getInstList()); +// errs()<<"BB->getSinglePredecessor() after 3.2:\n"; +// for (Instruction &Inst : make_early_inc_range(BB->getInstList())) {break;} +// errs()<<"BB->getSinglePredecessor() after 3.3:\n"; +// errs()<<"bb name?= "<getName()<<"\n"; +// if (BB->getInstList().begin() == BB->getInstList().end()) { +// errs() << "BB's instruction list is empty\n"; +// } +// for (Instruction &Inst : BB->getInstList()) { +// errs()<<"check inst: "<<"\n"; +// bool inst_null = &Inst==nullptr; +// errs()<<"inst parent?= "<getInstList().size()<<"\n"; +// int counter2 = 0; +// BB->print(errs()); +// errs()<<"print BB end \n"; +// for(Instruction &Inst: *BB){ +// errs()<<"loop0: "<getInstList())) { + + //feisen:debug + // errs()<<"BB->getSinglePredecessor() after 4:\n"; + + //feisen:debug + // errs()<<"int loop: "<(F).getMSSA() : nullptr; EarlyCSE CSE(F.getParent()->getDataLayout(), TLI, TTI, DT, AC, MSSA); + //feisen:debug aaaa + // errs()<<"EarlyCSEPass::run1\n"; + // ----------------------------- + // errs()<<"----------------------\n"; + // F.print(errs()); + // errs()<<"----------------------\n"; + // int counter = 0; + // for(BasicBlock &BB : F){ + // BB.setName(F.getName()+"_"+std::to_string(counter++)); + // errs()<<"----\n"; + // BB.print(errs()); + // errs()<<"----\n"; + // } + // errs()<<"2----------------------\n"; + // ----------------------------- if (!CSE.run()) return PreservedAnalyses::all(); + //feisen:debug + // errs()<<"EarlyCSEPass::run2\n"; + PreservedAnalyses PA; PA.preserveSet(); + + //feisen:debug + // errs()<<"EarlyCSEPass::run3\n"; + if (UseMemorySSA) PA.preserve(); return PA; diff --git a/llvm/lib/Transforms/Scalar/LoopUnrollPass.cpp b/llvm/lib/Transforms/Scalar/LoopUnrollPass.cpp index de5833f60adc922f70fe688fe12bf4190a6256c8..c85b2da866827706ce2b36bfc71c7729dd905142 100644 --- a/llvm/lib/Transforms/Scalar/LoopUnrollPass.cpp +++ b/llvm/lib/Transforms/Scalar/LoopUnrollPass.cpp @@ -72,6 +72,11 @@ using namespace llvm; #define DEBUG_TYPE "loop-unroll" +static cl::opt EnableCodeSizeLoop( + "enable-code-size-loop", cl::init(true), cl::Hidden, + cl::desc("Enable optimizations for code size as part of the optimization " + "pipeline")); + cl::opt llvm::ForgetSCEVInLoopUnroll( "forget-scev-loop-unroll", cl::init(false), cl::Hidden, cl::desc("Forget everything in SCEV when doing LoopUnroll, instead of just" @@ -222,6 +227,10 @@ TargetTransformInfo::UnrollingPreferences llvm::gatherUnrollingPreferences( (hasUnrollTransformation(L) != TM_ForcedByUser && llvm::shouldOptimizeForSize(L->getHeader(), PSI, BFI, PGSOQueryType::IRPass)); + + //for code size + if(EnableCodeSizeLoop) OptForSize = true; + if (OptForSize) { UP.Threshold = UP.OptSizeThreshold; UP.PartialThreshold = UP.PartialOptSizeThreshold; @@ -403,6 +412,9 @@ static Optional analyzeLoopUnrollCost( RootI.getFunction()->hasMinSize() ? TargetTransformInfo::TCK_CodeSize : TargetTransformInfo::TCK_SizeAndLatency; + // ============ code size + if(EnableCodeSizeLoop) CostKind = TargetTransformInfo::TCK_CodeSize; + // ============ code size for (;; --Iteration) { do { Instruction *I = CostWorklist.pop_back_val(); @@ -486,6 +498,11 @@ static Optional analyzeLoopUnrollCost( TargetTransformInfo::TargetCostKind CostKind = L->getHeader()->getParent()->hasMinSize() ? 
diff --git a/llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp b/llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp
index 0535608244cc2354978519b25e6014a01b994fb0..f2384c0ae0a47384fa7fd7b4732e842152973858 100644
--- a/llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp
+++ b/llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp
@@ -2863,6 +2863,10 @@ static bool unswitchBestCondition(
       L.getHeader()->getParent()->hasMinSize() ?
       TargetTransformInfo::TCK_CodeSize : TargetTransformInfo::TCK_SizeAndLatency;
+
+  // For code size: always cost candidate unswitches with the code-size model.
+  CostKind = TargetTransformInfo::TCK_CodeSize;
+
   InstructionCost LoopCost = 0;
   for (auto *BB : L.blocks()) {
     InstructionCost Cost = 0;
diff --git a/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp b/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp
index 372cd74ea01dc9641678e52fbe31eaed3d71b083..74ab2bf102f3e191d97309edf3ceeae6bfab20ae 100644
--- a/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp
+++ b/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp
@@ -2265,6 +2265,10 @@ bool SCEVExpander::isHighCostExpansionHelper(
           ? TargetTransformInfo::TCK_CodeSize
           : TargetTransformInfo::TCK_RecipThroughput;
+
+  // For code size: always cost SCEV expansion with the code-size model.
+  CostKind = TargetTransformInfo::TCK_CodeSize;
+
   switch (S->getSCEVType()) {
   case scCouldNotCompute:
     llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
diff --git a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
index 1806081678a867bb5f775db555cc139579695607..6f2f6757815fc18345252b3142c116b85e825258 100644
--- a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
+++ b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
@@ -2646,6 +2646,10 @@ static bool validateAndCostRequiredSelects(BasicBlock *BB, BasicBlock *ThenBB,
       BB->getParent()->hasMinSize() ?
       TargetTransformInfo::TCK_CodeSize : TargetTransformInfo::TCK_SizeAndLatency;
+
+  // For code size: always cost speculated selects with the code-size model.
+  CostKind = TargetTransformInfo::TCK_CodeSize;
+
   bool HaveRewritablePHIs = false;
   for (PHINode &PN : EndBB->phis()) {
@@ -3556,6 +3560,10 @@ bool llvm::FoldBranchToCommonDest(BranchInst *BI, DomTreeUpdater *DTU,
       BB->getParent()->hasMinSize() ?
       TargetTransformInfo::TCK_CodeSize : TargetTransformInfo::TCK_SizeAndLatency;
+
+  // For code size: always cost the folded branch with the code-size model.
+  CostKind = TargetTransformInfo::TCK_CodeSize;
+
   Instruction *Cond = dyn_cast<Instruction>(BI->getCondition());
   if (!Cond ||
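For context, this is the shape of code those SimplifyCFG costs gate (an illustrative snippet, not part of the patch): when the 'then' block is cheap enough under the chosen cost kind, SimplifyCFG speculates it and rewrites the two-entry phi into a select.

define i32 @speculate(i1 %c, i32 %x) {
entry:
  br i1 %c, label %then, label %end

then:             ; cheap, side-effect-free block: a speculation candidate
  %y = add nsw i32 %x, 1
  br label %end

end:              ; the phi becomes: %r = select i1 %c, i32 %y, i32 %x
  %r = phi i32 [ %y, %then ], [ %x, %entry ]
  ret i32 %r
}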
diff --git a/llvm/test/Transforms/FunctionMerging/address-spaces.ll b/llvm/test/Transforms/FunctionMerging/address-spaces.ll
new file mode 100644
index 0000000000000000000000000000000000000000..0d239b6cfcd328b4a54d4da379869b4d25ad51b1
--- /dev/null
+++ b/llvm/test/Transforms/FunctionMerging/address-spaces.ll
@@ -0,0 +1,34 @@
+; RUN: opt -passes=func-merging -S < %s | FileCheck %s
+
+target datalayout = "p:32:32:32-p1:32:32:32-p2:16:16:16"
+
+declare void @foo(i32) nounwind
+
+; None of these functions should be merged
+
+define i32 @store_as0(i32* %x) {
+; CHECK-LABEL: @store_as0(
+; CHECK: call void @foo(
+  %gep = getelementptr i32, i32* %x, i32 4
+  %y = load i32, i32* %gep
+  call void @foo(i32 %y) nounwind
+  ret i32 %y
+}
+
+define i32 @store_as1(i32 addrspace(1)* %x) {
+; CHECK-LABEL: @store_as1(
+; CHECK: call void @foo(
+  %gep = getelementptr i32, i32 addrspace(1)* %x, i32 4
+  %y = load i32, i32 addrspace(1)* %gep
+  call void @foo(i32 %y) nounwind
+  ret i32 %y
+}
+
+define i32 @store_as2(i32 addrspace(2)* %x) {
+; CHECK-LABEL: @store_as2(
+; CHECK: call void @foo(
+  %gep = getelementptr i32, i32 addrspace(2)* %x, i32 4
+  %y = load i32, i32 addrspace(2)* %gep
+  call void @foo(i32 %y) nounwind
+  ret i32 %y
+}
diff --git a/llvm/test/Transforms/FunctionMerging/alloca.ll b/llvm/test/Transforms/FunctionMerging/alloca.ll
new file mode 100644
index 0000000000000000000000000000000000000000..bdef36eab95935e578736e97c9b9af1f77e11652
--- /dev/null
+++ b/llvm/test/Transforms/FunctionMerging/alloca.ll
@@ -0,0 +1,61 @@
+; RUN: opt -passes=func-merging -S < %s | FileCheck %s
+
+;; Make sure that two different allocas are not treated as equal.
+
+target datalayout = "e-m:w-p:32:32-i64:64-f80:32-n8:16:32-S32"
+
+%kv1 = type { i32, i32 }
+%kv2 = type { i8 }
+%kv3 = type { i64, i64 }
+
+; Size difference.
+
+; CHECK-LABEL: define void @size1
+; CHECK-NOT: call void @
+define void @size1(i8 *%f) {
+  %v = alloca %kv1, align 8
+  %f_2 = bitcast i8* %f to void (%kv1 *)*
+  call void %f_2(%kv1 * %v)
+  call void %f_2(%kv1 * %v)
+  call void %f_2(%kv1 * %v)
+  call void %f_2(%kv1 * %v)
+  ret void
+}
+
+; CHECK-LABEL: define void @size2
+; CHECK-NOT: call void @
+define void @size2(i8 *%f) {
+  %v = alloca %kv2, align 8
+  %f_2 = bitcast i8* %f to void (%kv2 *)*
+  call void %f_2(%kv2 * %v)
+  call void %f_2(%kv2 * %v)
+  call void %f_2(%kv2 * %v)
+  call void %f_2(%kv2 * %v)
+  ret void
+}
+
+; Alignment difference.
+
+; CHECK-LABEL: define void @align1
+; CHECK-NOT: call void @
+define void @align1(i8 *%f) {
+  %v = alloca %kv3, align 8
+  %f_2 = bitcast i8* %f to void (%kv3 *)*
+  call void %f_2(%kv3 * %v)
+  call void %f_2(%kv3 * %v)
+  call void %f_2(%kv3 * %v)
+  call void %f_2(%kv3 * %v)
+  ret void
+}
+
+; CHECK-LABEL: define void @align2
+; CHECK-NOT: call void @
+define void @align2(i8 *%f) {
+  %v = alloca %kv3, align 16
+  %f_2 = bitcast i8* %f to void (%kv3 *)*
+  call void %f_2(%kv3 * %v)
+  call void %f_2(%kv3 * %v)
+  call void %f_2(%kv3 * %v)
+  call void %f_2(%kv3 * %v)
+  ret void
+}
diff --git a/llvm/test/Transforms/FunctionMerging/fm-test.ll b/llvm/test/Transforms/FunctionMerging/fm-test.ll
new file mode 100644
index 0000000000000000000000000000000000000000..3f660eaf4a8534a75030312a916541281e67c09a
--- /dev/null
+++ b/llvm/test/Transforms/FunctionMerging/fm-test.ll
@@ -0,0 +1,283 @@
+; RUN: opt -passes=func-merging -S < %s | FileCheck %s
+; REQUIRES: target=x86_64{{.*}}
+
+target datalayout = "e-m:o-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128"
+target triple = "x86_64-apple-macosx13.0.0"
+
+; Function Attrs: noinline nounwind optnone ssp uwtable
+define void @insertionsort(ptr noundef %0, i32 noundef %1) #0 {
+; CHECK-LABEL: @insertionsort(ptr noundef %0, i32 noundef %1) #0 {
+; CHECK: %3 = tail call i32 @_m_f_0(i1 false, ptr %0, i32 %1) #1
+; CHECK: ret void
+  %3 = alloca ptr, align 8
+  %4 = alloca i32, align 4
+  %5 = alloca i32, align 4
+  %6 = alloca i32, align 4
+  %7 = alloca i32, align 4
+  store ptr %0, ptr %3, align 8
+  store i32 %1, ptr %4, align 4
+  store i32 1, ptr %5, align 4
+  br label %8
+
+8:                                                ; preds = %53, %2
+  %9 = load i32, ptr %5, align 4
+  %10 = load i32, ptr %4, align 4
+  %11 = icmp slt i32 %9, %10
+  br i1 %11, label %12, label %56
+
+12:                                               ; preds = %8
+  %13 = load ptr, ptr %3, align 8
+  %14 = load i32, ptr %5, align 4
+  %15 = sext i32 %14 to i64
+  %16 = getelementptr inbounds i32, ptr %13, i64 %15
+  %17 = load i32, ptr %16, align 4
+  store i32 %17, ptr %6, align 4
+  %18 = load i32, ptr %5, align 4
+  %19 = sub nsw i32 %18, 1
+  store i32 %19, ptr %7, align 4
+  br label %20
+
+20:                                               ; preds = %33, %12
+  %21 = load i32, ptr %7, align 4
+  %22 = icmp sge i32 %21, 0
+  br i1 %22, label %23, label %31
+
+23:                                               ; preds = %20
+  %24 = load ptr, ptr %3, align 8
+  %25 = load i32, ptr %7, align 4
+  %26 = sext i32 %25 to i64
+  %27 = getelementptr inbounds i32, ptr %24, i64 %26
+  %28 = load i32, ptr %27, align 4
+  %29 = load i32, ptr %6, align 4
+  %30 = icmp sgt i32 %28, %29
+  br label %31
+
+31:                                               ; preds = %23, %20
+  %32 = phi i1 [ false, %20 ], [ %30, %23 ]
+  br i1 %32, label %33, label %46
+
+33:                                               ; preds = %31
+  %34 = load ptr, ptr %3, align 8
+  %35 = load i32, ptr %7, align 4
+  %36 = sext i32 %35 to i64
+  %37 = getelementptr inbounds i32, ptr %34, i64 %36
+  %38 = load i32, ptr %37, align 4
+  %39 = load ptr, ptr %3, align 8
+  %40 = load i32, ptr %7, align 4
+  %41 = add nsw i32 %40, 1
+  %42 = sext i32 %41 to i64
+  %43 = getelementptr inbounds i32, ptr %39, i64 %42
+  store i32 %38, ptr %43, align 4
+  %44 = load i32, ptr %7, align 4
+  %45 = sub nsw i32 %44, 1
+  store i32 %45, ptr %7, align 4
+  br label %20, !llvm.loop !6
+
+46:                                               ; preds = %31
+  %47 = load i32, ptr %6, align 4
+  %48 = load ptr, ptr %3, align 8
+  %49 = load i32, ptr %7, align 4
+  %50 = add nsw i32 %49, 1
+  %51 = sext i32 %50 to i64
+  %52 = getelementptr inbounds i32, ptr %48, i64 %51
+  store i32 %47, ptr %52, align 4
+  br label %53
+
+53:                                               ; preds = %46
+  %54 = load i32, ptr %5, align 4
+  %55 = add nsw i32 %54, 1
+  store i32 %55, ptr %5, align 4
+  br label %8, !llvm.loop !8
+
+56:                                               ; preds = %8
+  ret void
+}
+
+; Function Attrs: noinline nounwind optnone ssp uwtable
+define i32 @insertionsort2(ptr noundef %0, i32 noundef %1) #0 {
+; CHECK-LABEL: @insertionsort2(ptr noundef %0, i32 noundef %1) #0 {
+; CHECK: %3 = tail call i32 @_m_f_0(i1 true, ptr %0, i32 %1) #1
+; CHECK: ret i32 %3
+  %3 = alloca ptr, align 8
+  %4 = alloca i32, align 4
+  %5 = alloca i32, align 4
+  %6 = alloca i32, align 4
+  %7 = alloca i32, align 4
+  store ptr %0, ptr %3, align 8
+  store i32 %1, ptr %4, align 4
+  store i32 1, ptr %5, align 4
+  br label %8
+
+8:                                                ; preds = %53, %2
+  %9 = load i32, ptr %5, align 4
+  %10 = load i32, ptr %4, align 4
+  %11 = icmp slt i32 %9, %10
+  br i1 %11, label %12, label %56
+
+12:                                               ; preds = %8
+  %13 = load ptr, ptr %3, align 8
+  %14 = load i32, ptr %5, align 4
+  %15 = sext i32 %14 to i64
+  %16 = getelementptr inbounds i32, ptr %13, i64 %15
+  %17 = load i32, ptr %16, align 4
+  store i32 %17, ptr %6, align 4
+  %18 = load i32, ptr %5, align 4
+  %19 = sub nsw i32 %18, 1
+  store i32 %19, ptr %7, align 4
+  br label %20
+
+20:                                               ; preds = %33, %12
+  %21 = load i32, ptr %7, align 4
+  %22 = icmp sge i32 %21, 0
+  br i1 %22, label %23, label %31
+
+23:                                               ; preds = %20
+  %24 = load ptr, ptr %3, align 8
+  %25 = load i32, ptr %7, align 4
+  %26 = sext i32 %25 to i64
+  %27 = getelementptr inbounds i32, ptr %24, i64 %26
+  %28 = load i32, ptr %27, align 4
+  %29 = load i32, ptr %6, align 4
+  %30 = icmp sgt i32 %28, %29
+  br label %31
+
+31:                                               ; preds = %23, %20
+  %32 = phi i1 [ false, %20 ], [ %30, %23 ]
+  br i1 %32, label %33, label %46
+
+33:                                               ; preds = %31
+  %34 = load ptr, ptr %3, align 8
+  %35 = load i32, ptr %7, align 4
+  %36 = sext i32 %35 to i64
+  %37 = getelementptr inbounds i32, ptr %34, i64 %36
+  %38 = load i32, ptr %37, align 4
+  %39 = load ptr, ptr %3, align 8
+  %40 = load i32, ptr %7, align 4
+  %41 = add nsw i32 %40, 1
+  %42 = sext i32 %41 to i64
+  %43 = getelementptr inbounds i32, ptr %39, i64 %42
+  store i32 %38, ptr %43, align 4
+  %44 = load i32, ptr %7, align 4
+  %45 = sub nsw i32 %44, 1
+  store i32 %45, ptr %7, align 4
+  br label %20, !llvm.loop !9
+
+46:                                               ; preds = %31
+  %47 = load i32, ptr %6, align 4
+  %48 = load ptr, ptr %3, align 8
+  %49 = load i32, ptr %7, align 4
+  %50 = add nsw i32 %49, 1
+  %51 = sext i32 %50 to i64
+  %52 = getelementptr inbounds i32, ptr %48, i64 %51
+  store i32 %47, ptr %52, align 4
+  br label %53
+
+53:                                               ; preds = %46
+  %54 = load i32, ptr %5, align 4
+  %55 = add nsw i32 %54, 1
+  store i32 %55, ptr %5, align 4
+  br label %8, !llvm.loop !10
+
+56:                                               ; preds = %8
+  ret i32 1
+}
+
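+; The CHECK lines below document the expected shape of the merged function: a
+; single body parameterized by an extra i1 discriminator. Callers pass false
+; for the @insertionsort behavior and true for @insertionsort2, and the shared
+; return block selects i32 1 when the discriminator is true and undef otherwise
+; (the void caller ignores the result).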
+; CHECK-LABEL: @_m_f_0(i1 %0, ptr %1, i32 %2) {
+; CHECK: entry:
+; CHECK: %3 = alloca ptr, align 8
+; CHECK: %4 = alloca i32, align 4
+; CHECK: %5 = alloca i32, align 4
+; CHECK: %6 = alloca i32, align 4
+; CHECK: %7 = alloca i32, align 4
+; CHECK: store ptr %1, ptr %3, align 8
+; CHECK: store i32 %2, ptr %4, align 4
+; CHECK: store i32 1, ptr %5, align 4
+; CHECK: br label %m.label.bb8
+
+; CHECK: m.label.bb8: ; preds = %m.inst.bb42, %entry
+; CHECK: %8 = load i32, ptr %5, align 4
+; CHECK: %9 = load i32, ptr %4, align 4
+; CHECK: %10 = icmp slt i32 %8, %9
+; CHECK: br i1 %10, label %m.inst.bb16, label %m.term.bb14
+
+; CHECK: m.term.bb14: ; preds = %m.label.bb8
+; CHECK: %11 = select i1 %0, i32 1, i32 undef
+; CHECK: ret i32 %11
+
+; CHECK: m.inst.bb16: ; preds = %m.label.bb8
+; CHECK: %12 = load ptr, ptr %3, align 8
+; CHECK: %13 = load i32, ptr %5, align 4
+; CHECK: %14 = sext i32 %13 to i64
+; CHECK: %15 = getelementptr inbounds i32, ptr %12, i64 %14
+; CHECK: %16 = load i32, ptr %15, align 4
+; CHECK: store i32 %16, ptr %6, align 4
+; CHECK: %17 = load i32, ptr %5, align 4
+; CHECK: %18 = sub nsw i32 %17, 1
+; CHECK: store i32 %18, ptr %7, align 4
+; CHECK: br label %m.label.bb26
+
+; CHECK: m.label.bb26: ; preds = %m.inst.bb56, %m.inst.bb16
+; CHECK: %19 = load i32, ptr %7, align 4
+; CHECK: %20 = icmp sge i32 %19, 0
+; CHECK: br i1 %20, label %m.inst.bb31, label %m.inst.bb42
+
+; CHECK: m.inst.bb31: ; preds = %m.label.bb26
+; CHECK: %21 = load ptr, ptr %3, align 8
+; CHECK: %22 = load i32, ptr %7, align 4
+; CHECK: %23 = sext i32 %22 to i64
+; CHECK: %24 = getelementptr inbounds i32, ptr %21, i64 %23
+; CHECK: %25 = load i32, ptr %24, align 4
+; CHECK: %26 = load i32, ptr %6, align 4
+; CHECK: %27 = icmp sgt i32 %25, %26
+; CHECK: br i1 %27, label %m.inst.bb56, label %m.inst.bb42
+
+; CHECK: m.inst.bb42: ; preds = %m.label.bb26, %m.inst.bb31
+; CHECK: %28 = load i32, ptr %6, align 4
+; CHECK: %29 = load ptr, ptr %3, align 8
+; CHECK: %30 = load i32, ptr %7, align 4
+; CHECK: %31 = add nsw i32 %30, 1
+; CHECK: %32 = sext i32 %31 to i64
+; CHECK: %33 = getelementptr inbounds i32, ptr %29, i64 %32
+; CHECK: store i32 %28, ptr %33, align 4
+; CHECK: %34 = load i32, ptr %5, align 4
+; CHECK: %35 = add nsw i32 %34, 1
+; CHECK: store i32 %35, ptr %5, align 4
+; CHECK: br label %m.label.bb8
+
+; CHECK: m.inst.bb56: ; preds = %m.inst.bb31
+; CHECK: %36 = load ptr, ptr %3, align 8
+; CHECK: %37 = load i32, ptr %7, align 4
+; CHECK: %38 = sext i32 %37 to i64
+; CHECK: %39 = getelementptr inbounds i32, ptr %36, i64 %38
+; CHECK: %40 = load i32, ptr %39, align 4
+; CHECK: %41 = load ptr, ptr %3, align 8
+; CHECK: %42 = load i32, ptr %7, align 4
+; CHECK: %43 = add nsw i32 %42, 1
+; CHECK: %44 = sext i32 %43 to i64
+; CHECK: %45 = getelementptr inbounds i32, ptr %41, i64 %44
+; CHECK: store i32 %40, ptr %45, align 4
+; CHECK: %46 = load i32, ptr %7, align 4
+; CHECK: %47 = sub nsw i32 %46, 1
+; CHECK: store i32 %47, ptr %7, align 4
+; CHECK: br label %m.label.bb26
+
+attributes #0 = { noinline nounwind optnone ssp uwtable "frame-pointer"="all" "min-legal-vector-width"="0" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-cpu"="penryn" "target-features"="+cx16,+cx8,+fxsr,+mmx,+sahf,+sse,+sse2,+sse3,+sse4.1,+ssse3,+x87" "tune-cpu"="generic" }
+
+!llvm.module.flags = !{!0, !1, !2, !3, !4}
+!llvm.ident = !{!5}
+
+!0 = !{i32 2, !"SDK Version", [2 x i32] [i32 13, i32 3]}
+!1 = !{i32 1, !"wchar_size", i32 4}
+!2 = !{i32 7, !"PIC Level", i32 2}
+!3 = !{i32 7, !"uwtable", i32 2}
+!4 = !{i32 7, !"frame-pointer", i32 2}
+!5 = !{!"clang version 15.0.7 (git@gitee.com:h836419908_2062810111/llvm-project.git 6c8d0437527ffd58e84041594d6cfad743ebdab9)"}
+!6 = distinct !{!6, !7}
+!7 = !{!"llvm.loop.mustprogress"}
+!8 = distinct !{!8, !7}
+!9 = distinct !{!9, !7}
+!10 = distinct !{!10, !7}
diff --git a/llvm/test/Transforms/FunctionMerging/functions.ll b/llvm/test/Transforms/FunctionMerging/functions.ll
new file mode 100644
index 0000000000000000000000000000000000000000..49e765ae5f2edf725c6c86eaadd778ed472874aa
--- /dev/null
+++ b/llvm/test/Transforms/FunctionMerging/functions.ll
@@ -0,0 +1,27 @@
+; RUN: opt -passes=func-merging -S < %s | FileCheck %s
+
+; Be sure we don't merge cross-referenced functions of the same type.
+
+; CHECK-LABEL: @left
+; CHECK-LABEL: entry-block
+; CHECK: call void @right(i64 %p)
+define void @left(i64 %p) {
+entry-block:
+  call void @right(i64 %p)
+  call void @right(i64 %p)
+  call void @right(i64 %p)
+  call void @right(i64 %p)
+  ret void
+}
+
+; CHECK-LABEL: @right
+; CHECK-LABEL: entry-block
+; CHECK: call void @left(i64 %p)
+define void @right(i64 %p) {
+entry-block:
+  call void @left(i64 %p)
+  call void @left(i64 %p)
+  call void @left(i64 %p)
+  call void @left(i64 %p)
+  ret void
+}
\ No newline at end of file
diff --git a/llvm/test/Transforms/FunctionMerging/gep-base-type.ll b/llvm/test/Transforms/FunctionMerging/gep-base-type.ll
new file mode 100644
index 0000000000000000000000000000000000000000..e5e7f9c7ba7bcd1e94b6846bda4e9164debd949e
--- /dev/null
+++ b/llvm/test/Transforms/FunctionMerging/gep-base-type.ll
@@ -0,0 +1,45 @@
+; RUN: opt -passes=func-merging -S < %s | FileCheck %s
+target datalayout = "e-p:64:64:64-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-f32:32:32-f64:64:64-v64:64:64-v128:128:128-a0:0:64-s0:64:64-f80:128:128-n8:16:32:64-S128"
+
+; These should not be merged; the GEP base types do not have
+; the same stride.
+
+%"struct1" = type <{ i8*, i32, [4 x i8] }>
+%"struct2" = type { i8*, { i64, i64 } }
+
+define internal %struct2* @Ffunc(%struct2* %P, i64 %i) {
+; CHECK-LABEL: @Ffunc(
+; CHECK-NEXT: getelementptr
+; CHECK-NEXT: getelementptr
+; CHECK-NEXT: getelementptr
+; CHECK-NEXT: getelementptr
+; CHECK-NEXT: getelementptr
+; CHECK-NEXT: getelementptr
+; CHECK-NEXT: ret
+  %1 = getelementptr inbounds %"struct2", %"struct2"* %P, i64 %i
+  %2 = getelementptr inbounds %"struct2", %"struct2"* %P, i64 %i
+  %3 = getelementptr inbounds %"struct2", %"struct2"* %P, i64 %i
+  %4 = getelementptr inbounds %"struct2", %"struct2"* %P, i64 %i
+  %5 = getelementptr inbounds %"struct2", %"struct2"* %P, i64 %i
+  %6 = getelementptr inbounds %"struct2", %"struct2"* %P, i64 %i
+  ret %struct2* %6
+}
+
+
+define internal %struct1* @Gfunc(%struct1* %P, i64 %i) {
+; CHECK-LABEL: @Gfunc(
+; CHECK-NEXT: getelementptr
+; CHECK-NEXT: getelementptr
+; CHECK-NEXT: getelementptr
+; CHECK-NEXT: getelementptr
+; CHECK-NEXT: getelementptr
+; CHECK-NEXT: getelementptr
+; CHECK-NEXT: ret
+  %1 = getelementptr inbounds %"struct1", %"struct1"* %P, i64 %i
+  %2 = getelementptr inbounds %"struct1", %"struct1"* %P, i64 %i
+  %3 = getelementptr inbounds %"struct1", %"struct1"* %P, i64 %i
+  %4 = getelementptr inbounds %"struct1", %"struct1"* %P, i64 %i
+  %5 = getelementptr inbounds %"struct1", %"struct1"* %P, i64 %i
+  %6 = getelementptr inbounds %"struct1", %"struct1"* %P, i64 %i
+  ret %struct1* %6
+}
diff --git a/llvm/test/Transforms/FunctionMerging/merge-block-address-other-function.ll b/llvm/test/Transforms/FunctionMerging/merge-block-address-other-function.ll
new file mode 100644
index 0000000000000000000000000000000000000000..e2e81dc45c1986d4c6fdf4fa62fba8454b1e048b
--- /dev/null
+++ b/llvm/test/Transforms/FunctionMerging/merge-block-address-other-function.ll
@@ -0,0 +1,50 @@
+; RUN: opt -passes=func-merging -S < %s | FileCheck %s
+; REQUIRES: target=x86_64{{.*}}
+
+target datalayout = "e-m:e-i64:64-f80:128-n8:16:32:64-S128"
+target triple = "x86_64-unknown-linux-gnu"
+
+define i32 @_Z1fi(i32 %i) #0 {
+entry:
+  %retval = alloca i32, align 4
+  %i.addr = alloca i32, align 4
+  store i32 %i, i32* %i.addr, align 4
+  %0 = load i32, i32* %i.addr, align 4
+  %cmp = icmp eq i32 %0, 1
+  br i1 %cmp, label %if.then, label %if.end
+
+if.then:
+  store i32 3, i32* %retval
+  br label %return
+
+if.end:
+  %1 = load i32, i32* %i.addr, align 4
+  %cmp1 = icmp eq i32 %1, 3
+  br i1 %cmp1, label %if.then.2, label %if.end.3
+
+if.then.2:
+  store i32 56, i32* %retval
+  br label %return
+
+if.end.3:
+  store i32 0, i32* %retval
+  br label %return
+
+return:
+  %2 = load i32, i32* %retval
+  ret i32 %2
+}
+
+
+define internal i8* @Afunc(i32* %P) {
+  store i32 1, i32* %P
+  store i32 3, i32* %P
+  ret i8* blockaddress(@_Z1fi, %if.then.2)
+}
+
+define internal i8* @Bfunc(i32* %P) {
+; CHECK-NOT: @Bfunc
+  store i32 1, i32* %P
+  store i32 3, i32* %P
+  ret i8* blockaddress(@_Z1fi, %if.then.2)
+}
diff --git a/llvm/test/Transforms/FunctionMerging/merge-const-ptr-and-int.ll b/llvm/test/Transforms/FunctionMerging/merge-const-ptr-and-int.ll
new file mode 100644
index 0000000000000000000000000000000000000000..c53d39864e676080b47117c603f967d17f3b9943
--- /dev/null
+++ b/llvm/test/Transforms/FunctionMerging/merge-const-ptr-and-int.ll
@@ -0,0 +1,19 @@
+; RUN: opt -passes=func-merging -S < %s | FileCheck %s
+target datalayout = "e-p:64:64:64-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-f32:32:32-f64:64:64-v64:64:64-v128:128:128-a0:0:64-s0:64:64-f80:128:128-n8:16:32:64-S128"
+
+; Afunc and Bfunc differ only in that one returns i64, the other a pointer.
+; These should be merged.
+define internal i64 @Afunc(i32* %P, i32* %Q) {
+; CHECK-LABEL: define internal i64 @Afunc
+  store i32 4, i32* %P
+  store i32 6, i32* %Q
+  ret i64 0
+}
+
+define internal i64* @Bfunc(i32* %P, i32* %Q) {
+; MERGE-NOT: @Bfunc
+  store i32 4, i32* %P
+  store i32 6, i32* %Q
+  ret i64* null
+}
+
diff --git a/llvm/test/Transforms/FunctionMerging/mismatching-attr-crash.ll b/llvm/test/Transforms/FunctionMerging/mismatching-attr-crash.ll
new file mode 100644
index 0000000000000000000000000000000000000000..5a2520d53abb68c656c0064648073e7878314aa3
--- /dev/null
+++ b/llvm/test/Transforms/FunctionMerging/mismatching-attr-crash.ll
@@ -0,0 +1,21 @@
+; RUN: opt -passes=func-merging -S < %s | FileCheck %s
+
+; CHECK-LABEL: define void @foo
+; CHECK: call void %bc
+define void @foo(i8* byval(i8) %a0, i8* swiftself %a4) {
+entry:
+  %bc = bitcast i8* %a0 to void (i8*, i8*)*
+  call void %bc(i8* byval(i8) %a0, i8* swiftself %a4)
+  ret void
+}
+
+; CHECK-LABEL: define void @bar
+; CHECK: call void %bc
+define void @bar(i8* byval(i8) %a0, i8** swifterror %a4) {
+entry:
+  %bc = bitcast i8* %a0 to void (i8*, i8**)*
+  call void %bc(i8* byval(i8) %a0, i8** swifterror %a4)
+  ret void
+}
+
+
diff --git a/llvm/test/Transforms/FunctionMerging/no-merge-ptr-different-sizes.ll b/llvm/test/Transforms/FunctionMerging/no-merge-ptr-different-sizes.ll
new file mode 100644
index 0000000000000000000000000000000000000000..f593e098dc2ed5a7dba1743e35d210e9dfa5fc48
--- /dev/null
+++ b/llvm/test/Transforms/FunctionMerging/no-merge-ptr-different-sizes.ll
@@ -0,0 +1,24 @@
+; RUN: opt -passes=func-merging -S < %s | FileCheck %s
+target datalayout = "e-p:64:64:64-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-f32:32:32-f64:64:64-v64:64:64-v128:128:128-a0:0:64-s0:64:64-f80:128:128-n8:16:32:64-S128"
+
+; These should not be merged, as the datalayout says a pointer is 64 bits. No
+; sext/zext is specified, so these functions could lower differently.
+define internal i32 @Ffunc(i32* %P, i32* %Q) {
+; CHECK-LABEL: define internal i32 @Ffunc
+; CHECK-NEXT: store
+; CHECK-NEXT: store
+; CHECK-NEXT: ret
+  store i32 1, i32* %P
+  store i32 3, i32* %Q
+  ret i32 0
+}
+
+define internal i64* @Gfunc(i32* %P, i32* %Q) {
+; CHECK-LABEL: define internal i64* @Gfunc
+; CHECK-NEXT: store
+; CHECK-NEXT: store
+; CHECK-NEXT: ret
+  store i32 1, i32* %P
+  store i32 3, i32* %Q
+  ret i64* null
+}
diff --git a/llvm/test/Transforms/FunctionMerging/no-merge-ptr-int-different-values.ll b/llvm/test/Transforms/FunctionMerging/no-merge-ptr-int-different-values.ll
new file mode 100644
index 0000000000000000000000000000000000000000..0b0434b7c9123ba38c821039c868e720e3408638
--- /dev/null
+++ b/llvm/test/Transforms/FunctionMerging/no-merge-ptr-int-different-values.ll
@@ -0,0 +1,23 @@
+; RUN: opt -passes=func-merging -S < %s | FileCheck %s
+target datalayout = "e-p:64:64:64-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-f32:32:32-f64:64:64-v64:64:64-v128:128:128-a0:0:64-s0:64:64-f80:128:128-n8:16:32:64-S128"
+
+; These should not be merged, as 1 != 0.
+define internal i64 @Ifunc(i32* %P, i32* %Q) {
+; CHECK-LABEL: define internal i64 @Ifunc
+; CHECK-NEXT: store
+; CHECK-NEXT: store
+; CHECK-NEXT: ret
+  store i32 10, i32* %P
+  store i32 10, i32* %Q
+  ret i64 1
+}
+
+define internal i64* @Jfunc(i32* %P, i32* %Q) {
+; CHECK-LABEL: define internal i64* @Jfunc
+; CHECK-NEXT: store
+; CHECK-NEXT: store
+; CHECK-NEXT: ret
+  store i32 10, i32* %P
+  store i32 10, i32* %Q
+  ret i64* null
+}
diff --git a/llvm/test/Transforms/FunctionMerging/phi-check-blocks.ll b/llvm/test/Transforms/FunctionMerging/phi-check-blocks.ll
new file mode 100644
index 0000000000000000000000000000000000000000..483bdffca491e919c09854907c4ac14fcc56990c
--- /dev/null
+++ b/llvm/test/Transforms/FunctionMerging/phi-check-blocks.ll
@@ -0,0 +1,50 @@
+; RUN: opt -passes=func-merging -S < %s | FileCheck %s
+
+; Ensure that we do not merge functions that are identical with the
+; exception of the order of the incoming blocks to a phi.
+
+; CHECK-LABEL: define linkonce_odr hidden i1 @first(i2 %0)
+define linkonce_odr hidden i1 @first(i2 %0) {
+entry:
+; CHECK: switch i2
+  switch i2 %0, label %default [
+    i2 0, label %L1
+    i2 1, label %L2
+    i2 -2, label %L3
+  ]
+default:
+  unreachable
+L1:
+  br label %done
+L2:
+  br label %done
+L3:
+  br label %done
+done:
+  %result = phi i1 [ true, %L1 ], [ false, %L2 ], [ false, %L3 ]
+; CHECK: ret i1
+  ret i1 %result
+}
+
+; CHECK-LABEL: define linkonce_odr hidden i1 @second(i2 %0)
+define linkonce_odr hidden i1 @second(i2 %0) {
+entry:
+; CHECK: switch i2
+  switch i2 %0, label %default [
+    i2 0, label %L1
+    i2 1, label %L2
+    i2 -2, label %L3
+  ]
+default:
+  unreachable
+L1:
+  br label %done
+L2:
+  br label %done
+L3:
+  br label %done
+done:
+  %result = phi i1 [ true, %L3 ], [ false, %L2 ], [ false, %L1 ]
+; CHECK: ret i1
+  ret i1 %result
+}
diff --git a/llvm/test/Transforms/FunctionMerging/tailcall.ll b/llvm/test/Transforms/FunctionMerging/tailcall.ll
new file mode 100644
index 0000000000000000000000000000000000000000..92363c3be6bf410e2ea9218dfaffa0436291c089
--- /dev/null
+++ b/llvm/test/Transforms/FunctionMerging/tailcall.ll
@@ -0,0 +1,21 @@
+; RUN: opt -passes=func-merging -S < %s | FileCheck %s
+
+declare void @dummy()
+
+; CHECK-LABEL: define{{.*}}@foo
+; CHECK: call {{.*}}@dummy
+; CHECK: musttail {{.*}}@dummy
+define void @foo() {
+  call void @dummy()
+  musttail call void @dummy()
+  ret void
+}
+
+; CHECK-LABEL: define{{.*}}@bar
+; CHECK: call {{.*}}@dummy
+; CHECK: call {{.*}}@dummy
+define void @bar() {
+  call void @dummy()
+  call void @dummy()
+  ret void
+}
diff --git a/llvm/test/Transforms/FunctionMerging/too-small.ll b/llvm/test/Transforms/FunctionMerging/too-small.ll
new file mode 100644
index 0000000000000000000000000000000000000000..2fb7e385869884d1d9b8b2a12d1da2ce4b1526e9
--- /dev/null
+++ b/llvm/test/Transforms/FunctionMerging/too-small.ll
@@ -0,0 +1,14 @@
+; RUN: opt -passes=func-merging -S < %s | FileCheck %s
+
+define void @foo(i32 %x) {
+; CHECK-LABEL: @foo(
+; CHECK-NOT: call
+  ret void
+}
+
+define void @bar(i32 %x) {
+; CHECK-LABEL: @bar(
+; CHECK-NOT: call
+  ret void
+}
+
diff --git a/llvm/test/Transforms/FunctionMerging/weak-small.ll b/llvm/test/Transforms/FunctionMerging/weak-small.ll
new file mode 100644
index 0000000000000000000000000000000000000000..b17bd3f1bb2061fd53755471c6f228cac0b689a1
--- /dev/null
+++ b/llvm/test/Transforms/FunctionMerging/weak-small.ll
@@ -0,0 +1,16 @@
+; RUN: opt -passes=func-merging -S < %s | FileCheck %s
+
+; Weak functions too small for merging to be profitable
+
+; CHECK: define weak i32 @foo(i8* %0, i32 %1)
+; CHECK-NEXT: ret i32 %1
+; CHECK: define weak i32 @bar(i8* %0, i32 %1)
+; CHECK-NEXT: ret i32 %1
+
+define weak i32 @foo(i8* %0, i32 %1) #0 {
+  ret i32 %1
+}
+
+define weak i32 @bar(i8* %0, i32 %1) #0 {
+  ret i32 %1
+}