Нарешті, після кількох днів читання документації або налагодження індексів потоків, я зміг реалізувати flash attention з нуля в DSC на MI300X! Перша версія (помаранчева) є базовою скалярною версією з оригінального паперу про спалах-увагу. Друга версія (зелена) - це той же алгоритм, але використовуються матричні ядра (тензорні ядра AMD), і, як бачите, це *значно* швидше, ніж скалярне. Я використовував ядра матриці для обчислення обох Sij = Qi @ Kj^T і Pij @ Vj. Деякі «фіччі» ядер матриць AMD: - Вони працюють на основі хвильового фронту, а хвильовий фронт становить 64 потоки на AMD, це означає, що вам потрібно відстежувати як ID поточної хвилі, так і ID потоку в межах цієї хвилі. - Розкладка виводу буде перемішана в регістрах у зв'язку з тим, що основною роботою ядра матриці є зовнішній продукт 4х1, тому потрібен крок перезамовлення. - (Наскільки мені відомо) власне значення hipcc для ядер матриць ніде не задокументовані. Є репозиторій з купою прикладів від AMD, але крім цього вам доведеться grep кодової бази LLVM. Я збираюся час від часу шліфувати свій код, а потім, мабуть, напишу більш глибокий пост про увагу до спалаху на AMD. Ох і, до речі, кричіть @HotAisle за те, що вони зробили це можливим!