Cuối cùng, sau nhiều ngày đọc tài liệu hoặc gỡ lỗi chỉ số luồng, tôi đã có thể triển khai flash attention từ đầu trong DSC trên MI300X!
Phiên bản đầu tiên (màu cam) là phiên bản số nguyên cơ bản từ tài liệu gốc về flash-attention.
Phiên bản thứ hai (màu xanh) là cùng một thuật toán nhưng sử dụng các lõi ma trận (lõi tensor AMD) và như bạn có thể thấy, điều này *nhanh hơn đáng kể* so với phiên bản số nguyên.
Tôi đã sử dụng các lõi ma trận để tính cả Sij = Qi @ Kj^T và Pij @ Vj.
Một số 'cạm bẫy' của các lõi ma trận AMD:
- Chúng hoạt động trên cơ sở từng wavefront và một wavefront là 64 luồng trên AMD, điều này có nghĩa là bạn phải theo dõi cả ID của wave hiện tại và cũng là ID của luồng trong wave đó.
- Bố cục đầu ra sẽ bị xáo trộn trong các thanh ghi do thực tế rằng hoạt động cốt lõi của một lõi ma trận là một sản phẩm ngoài 4x1, vì vậy một bước sắp xếp lại là cần thiết.
- (Theo như tôi biết) các intrinsics hipcc cho các lõi ma trận không được tài liệu hóa ở đâu cả. Có một repo với một loạt ví dụ từ AMD nhưng ngoài điều đó, bạn sẽ phải grep mã nguồn LLVM.
Tôi sẽ làm sạch mã của mình bây giờ và sau đó có lẽ sẽ viết một bài viết chi tiết hơn về flash attention trên AMD.
Ôi và nhân tiện, cảm ơn @HotAisle vì đã làm điều này trở nên khả thi!