Finalmente, después de días de leer documentos o depurar índices de subprocesos, ¡pude implementar la atención flash desde cero en DSC en MI300X!
La primera versión (naranja) es la versión escalar básica del artículo original de flash-attention.
La segunda versión (verde) es el mismo algoritmo pero usa núcleos matriciales (núcleos tensoriales AMD) y, como puede ver, es *significativamente* más rápido que el escalar.
Utilicé los núcleos de la matriz para calcular tanto Sij = Qi @ Kj^T como Pij @ Vj.
Algunas "trampas" de los núcleos de matriz AMD:
- Funcionan por frente de onda y un frente de onda es de 64 hilos en AMD, esto significa que debe realizar un seguimiento tanto del ID de la onda actual como del ID del hilo dentro de esa onda.
- El diseño de salida se barajará en registros debido al hecho de que la operación central de un núcleo de matriz es un producto externo 4x1, por lo que se requiere un paso de reordenamiento.
- (Hasta donde yo sé) los intrínsecos de hipcc para los núcleos de matriz no están documentados en ninguna parte. Hay un repositorio con un montón de ejemplos de AMD, pero aparte de eso, tendrás que grep LLVM codebase.
Voy a pulir mi código de vez en cuando y luego probablemente escribiré una publicación más detallada sobre la atención flash en AMD.
Ah, y por cierto, ¡un saludo a @HotAisle por hacer esto posible!