Наконец, после нескольких дней чтения документации или отладки индексов потоков, я смог реализовать flash attention с нуля в DSC на MI300X!
Первая версия (оранжевая) — это базовая скалярная версия из оригинальной статьи по flash-attention.
Вторая версия (зеленая) — это тот же алгоритм, но использующий матричные ядра (тензорные ядра AMD), и, как вы можете видеть, это *значительно* быстрее, чем скалярная версия.
Я использовал матричные ядра для вычисления как Sij = Qi @ Kj^T, так и Pij @ Vj.
Некоторые "подводные камни" матричных ядер AMD:
- Они работают на основе волнового фронта, и волновой фронт состоит из 64 потоков на AMD, это означает, что вам нужно отслеживать как ID текущей волны, так и ID потока внутри этой волны.
- Выходной макет будет перемешан в регистрах из-за того, что основная операция матричного ядра — это внешнее произведение 4x1, поэтому требуется шаг переупорядочивания.
- (Насколько мне известно) встроенные функции hipcc для матричных ядер нигде не задокументированы. Есть репозиторий с множеством примеров от AMD, но кроме этого вам придется искать в кодовой базе LLVM.
Я собираюсь сейчас доработать свой код, а затем, вероятно, напишу более подробный пост о flash attention на AMD.
О, и кстати, спасибо @HotAisle за то, что это стало возможным!