W końcu, po dniach czytania dokumentacji lub debugowania indeksów wątków, udało mi się zaimplementować flash attention od podstaw w DSC na MI300X!
Pierwsza wersja (pomarańczowa) to podstawowa wersja skalarna z oryginalnego artykułu o flash-attention.
Druga wersja (zielona) to ten sam algorytm, ale wykorzystuje rdzenie macierzowe (rdzenie tensorowe AMD) i, jak widać, jest *znacząco* szybsza niż wersja skalarna.
Użyłem rdzeni macierzowych do obliczenia zarówno Sij = Qi @ Kj^T, jak i Pij @ Vj.
Kilka 'pułapek' rdzeni macierzowych AMD:
- Działają na zasadzie per-wavefront, a wavefront to 64 wątki na AMD, co oznacza, że musisz śledzić zarówno ID bieżącego wave, jak i ID wątku w tym wave.
- Układ wyjściowy będzie przetasowany w rejestrach z powodu faktu, że podstawowa operacja rdzenia macierzowego to iloczyn zewnętrzny 4x1, więc wymagany jest krok przetasowania.
- (O ile mi wiadomo) intrinsics hipcc dla rdzeni macierzowych nie są nigdzie udokumentowane. Jest repozytorium z wieloma przykładami od AMD, ale poza tym będziesz musiał przeszukać kod LLVM.
Zamierzam teraz dopracować mój kod, a potem prawdopodobnie napiszę bardziej szczegółowy post na temat flash attention na AMD.
A tak przy okazji, pozdrowienia dla @HotAisle za umożliwienie tego!