Наконец, после нескольких дней чтения документации или отладки индексов потоков, я смог реализовать flash attention с нуля в DSC на MI300X! Первая версия (оранжевая) — это базовая скалярная версия из оригинальной статьи по flash-attention. Вторая версия (зеленая) — это тот же алгоритм, но использующий матричные ядра (тензорные ядра AMD), и, как вы можете видеть, это *значительно* быстрее, чем скалярная версия. Я использовал матричные ядра для вычисления как Sij = Qi @ Kj^T, так и Pij @ Vj. Некоторые "подводные камни" матричных ядер AMD: - Они работают на основе волнового фронта, и волновой фронт состоит из 64 потоков на AMD, это означает, что вам нужно отслеживать как ID текущей волны, так и ID потока внутри этой волны. - Выходной макет будет перемешан в регистрах из-за того, что основная операция матричного ядра — это внешнее произведение 4x1, поэтому требуется шаг переупорядочивания. - (Насколько мне известно) встроенные функции hipcc для матричных ядер нигде не задокументированы. Есть репозиторий с множеством примеров от AMD, но кроме этого вам придется искать в кодовой базе LLVM. Я собираюсь сейчас доработать свой код, а затем, вероятно, напишу более подробный пост о flash attention на AMD. О, и кстати, спасибо @HotAisle за то, что это стало возможным!