Nakonec, po dnech čtení dokumentů nebo ladění indexů vláken, jsem byl schopen implementovat flash attention od nuly v DSC na MI300X! První verze (oranžová) je základní skalární verze z původního článku s bleskovou pozorností. Druhá verze (zelená) je stejný algoritmus, ale používá maticová jádra (AMD tenzorová jádra) a jak vidíte, je *výrazně* rychlejší než skalární. Použil jsem maticová jádra k výpočtu Sij = Qi @ Kj^T a Pij @ Vj. Některé "gotchas" maticových jader AMD: - Pracují na základě vlnoplochy a vlnoplocha má na AMD 64 vláken, což znamená, že musíte sledovat jak ID aktuální vlny, tak ID vlákna v této vlně. - Výstupní rozvržení bude v registrech zamícháno kvůli skutečnosti, že základní operací maticového jádra je vnější produkt 4x1, takže je vyžadován krok přiřazení. - (Pokud je mi známo) vnitřní prvky hipcc pro maticová jádra nejsou nikde zdokumentovány. K dispozici je repo s hromadou příkladů od AMD, ale kromě toho budete muset grep LLVM kód. Teď se chystám vyleštit svůj kód a pak pravděpodobně napíšu podrobnější příspěvek o flash attention na AMD. Jo a mimochodem, křičte na @HotAisle za to, že to umožnili!