Till slut, efter dagar av att antingen läsa dokument eller felsöka trådindex, kunde jag implementera flash attention från grunden i DSC på MI300X! Den första versionen (orange) är den grundläggande skalära versionen från det ursprungliga flash-attention-papperet. Den andra versionen (grön) är samma algoritm men använder matriskärnor (AMD tensorkärnor) och som du kan se är detta * betydligt * snabbare än den skalära en. Jag använde matriskärnorna för att beräkna både Sij = Qi @ Kj^T och Pij @ Vj. Några "gotchas" av AMD-matriskärnor: - De arbetar per vågfront och en vågfront är 64 trådar på AMD, det betyder att du måste hålla reda på både ID för den aktuella vågen och även ID för tråden inom den vågen. - Utdatalayouten kommer att blandas i register på grund av det faktum att kärnoperationen för en matriskärna är en 4x1 yttre produkt, så ett omordningssteg krävs. - (Så vitt jag vet) finns inte hipcc intrinsics för matriskärnor dokumenterade någonstans. Det finns ett repo med ett gäng exempel från AMD, men förutom det måste du använda LLVM-kodbasen. Jag ska putsa på min kod nu och sedan kommer jag nog att skriva ett mer djupgående inlägg om flash attention på AMD. Åh och btw, shout out till @HotAisle för att göra detta möjligt!