Enfin, après des jours à lire des documents ou à déboguer des indices de thread, j'ai pu implémenter l'attention flash de zéro en DSC sur MI300X ! La première version (orange) est la version scalaire de base tirée du document original sur l'attention flash. La deuxième version (verte) est le même algorithme mais utilise des cœurs de matrice (cœurs tensoriels AMD) et comme vous pouvez le voir, c'est *significativement* plus rapide que la version scalaire. J'ai utilisé les cœurs de matrice pour calculer à la fois Sij = Qi @ Kj^T et Pij @ Vj. Quelques 'pièges' des cœurs de matrice AMD : - Ils fonctionnent sur une base par vague et une vague est de 64 threads sur AMD, cela signifie que vous devez garder une trace à la fois de l'ID de la vague actuelle et de l'ID du thread au sein de cette vague. - La disposition de la sortie sera mélangée dans les registres en raison du fait que l'opération principale d'un cœur de matrice est un produit extérieur 4x1, donc une étape de réordonnancement est nécessaire. - (Autant que je sache) les intrinsics hipcc pour les cœurs de matrice ne sont documentés nulle part. Il existe un dépôt avec un tas d'exemples d'AMD, mais à part ça, vous devrez fouiller dans le code source LLVM. Je vais maintenant peaufiner mon code et ensuite je vais probablement écrire un article plus approfondi sur l'attention flash sur AMD. Oh et au fait, un grand merci à @HotAisle pour avoir rendu cela possible !