Finalmente, após dias a ler documentos ou a depurar índices de threads, consegui implementar a atenção flash do zero em DSC no MI300X!
A primeira versão (laranja) é a versão escalar básica do artigo original sobre atenção flash.
A segunda versão (verde) é o mesmo algoritmo, mas utiliza núcleos de matriz (núcleos tensorais da AMD) e, como podem ver, isto é *significativamente* mais rápido do que a versão escalar.
Usei os núcleos de matriz para calcular tanto Sij = Qi @ Kj^T como Pij @ Vj.
Alguns 'pontos a ter em conta' dos núcleos de matriz da AMD:
- Eles funcionam com base em cada onda e uma onda é 64 threads na AMD, o que significa que você tem que acompanhar tanto o ID da onda atual quanto o ID da thread dentro dessa onda.
- O layout de saída será embaralhado nos registos devido ao fato de que a operação principal de um núcleo de matriz é um produto externo 4x1, portanto, um passo de reordenação é necessário.
- (Até onde sei) as intrínsecas hipcc para núcleos de matriz não estão documentadas em nenhum lugar. Há um repositório com um monte de exemplos da AMD, mas além disso, você terá que grep no código base do LLVM.
Vou polir meu código agora e depois provavelmente escreverei um post mais detalhado sobre a atenção flash na AMD.
Ah, e a propósito, um agradecimento ao @HotAisle por tornar isso possível!