Eindelijk, na dagen van het lezen van documentatie of het debuggen van thread-indices, was ik in staat om flash attention vanaf nul te implementeren in DSC op MI300X!
De eerste versie (oranje) is de basis scalare versie van het originele flash-attention paper.
De tweede versie (groen) is hetzelfde algoritme, maar gebruikt matrixkernen (AMD tensor cores) en zoals je kunt zien is dit *significant* sneller dan de scalare versie.
Ik heb de matrixkernen gebruikt om zowel Sij = Qi @ Kj^T als Pij @ Vj te berekenen.
Enkele 'gotchas' van AMD matrixkernen:
- Ze werken op basis van per-golffront en een golffront is 64 threads op AMD, dit betekent dat je zowel de ID van de huidige golf als de ID van de thread binnen die golf moet bijhouden.
- De uitvoerindeling zal in registers worden geschud vanwege het feit dat de kernoperatie van een matrixkern een 4x1 buitenproduct is, dus een herschikkingsstap is vereist.
- (Voor zover ik weet) zijn hipcc intrinsics voor matrixkernen nergens gedocumenteerd. Er is een repo met een aantal voorbeelden van AMD, maar verder moet je de LLVM-codebase doorzoeken.
Ik ga nu mijn code polijsten en zal waarschijnlijk een meer diepgaande post schrijven over flash attention op AMD.
Oh en trouwens, shout out naar @HotAisle voor het mogelijk maken hiervan!