Finalmente, después de días de leer documentos o depurar índices de hilos, ¡pude implementar atención flash desde cero en DSC en MI300X! La primera versión (naranja) es la versión escalar básica del artículo original sobre atención flash. La segunda versión (verde) es el mismo algoritmo pero utiliza núcleos de matriz (núcleos tensoriales de AMD) y, como puedes ver, esto es *significativamente* más rápido que la versión escalar. Usé los núcleos de matriz para calcular tanto Sij = Qi @ Kj^T como Pij @ Vj. Algunas 'trampas' de los núcleos de matriz de AMD: - Funcionan en una base por ola y una ola son 64 hilos en AMD, esto significa que tienes que llevar un seguimiento tanto del ID de la ola actual como del ID del hilo dentro de esa ola. - El diseño de salida se barajará en los registros debido al hecho de que la operación principal de un núcleo de matriz es un producto exterior 4x1, por lo que se requiere un paso de reordenamiento. - (Hasta donde sé) las intrínsecas de hipcc para núcleos de matriz no están documentadas en ningún lugar. Hay un repositorio con un montón de ejemplos de AMD, pero aparte de eso tendrás que buscar en la base de código de LLVM. Voy a pulir mi código ahora y luego probablemente escribiré un post más detallado sobre la atención flash en AMD. Oh, y por cierto, ¡un saludo a @HotAisle por hacer esto posible!