Endlich, nach Tagen des Lesens von Dokumentationen oder des Debuggens von Thread-Indizes, konnte ich Flash Attention von Grund auf in DSC auf MI300X implementieren!
Die erste Version (orange) ist die grundlegende Skalarversion aus dem ursprünglichen Flash-Attention-Papier.
Die zweite Version (grün) ist derselbe Algorithmus, verwendet jedoch Matrixkerne (AMD Tensor-Kerne) und wie Sie sehen können, ist dies *deutlich* schneller als die Skalarversion.
Ich habe die Matrixkerne verwendet, um sowohl Sij = Qi @ Kj^T als auch Pij @ Vj zu berechnen.
Einige 'Hürden' der AMD-Matrixkerne:
- Sie arbeiten auf Basis von Wellenfronten, und eine Wellenfront besteht aus 64 Threads bei AMD, das bedeutet, dass Sie sowohl die ID der aktuellen Welle als auch die ID des Threads innerhalb dieser Welle im Auge behalten müssen.
- Das Ausgabeformat wird in Registern aufgrund der Tatsache, dass die Kernoperation eines Matrixkerns ein 4x1-Außenprodukt ist, umsortiert, sodass ein Umordnungs-Schritt erforderlich ist.
- (Soweit ich informiert bin) sind die hipcc-Intrinsiken für Matrixkerne nirgendwo dokumentiert. Es gibt ein Repository mit einer Menge von Beispielen von AMD, aber abgesehen davon müssen Sie den LLVM-Code durchsuchen.
Ich werde jetzt meinen Code verfeinern und dann wahrscheinlich einen ausführlicheren Beitrag über Flash Attention auf AMD schreiben.
Oh und übrigens, ein Dankeschön an @HotAisle, dass dies möglich gemacht wurde!