Finalmente, dopo giorni passati a leggere documenti o a fare debug sugli indici dei thread, sono riuscito a implementare l'attenzione flash da zero in DSC su MI300X! La prima versione (arancione) è la versione scalare di base dell'articolo originale sull'attenzione flash. La seconda versione (verde) è lo stesso algoritmo ma utilizza i core matrice (core tensor AMD) e come puoi vedere è *significativamente* più veloce rispetto a quella scalare. Ho usato i core matrice per calcolare sia Sij = Qi @ Kj^T che Pij @ Vj. Alcuni 'problemi' dei core matrice AMD: - Funzionano su base per wavefront e un wavefront è composto da 64 thread su AMD, il che significa che devi tenere traccia sia dell'ID dell'attuale wave che dell'ID del thread all'interno di quella wave. - Il layout dell'output sarà mescolato nei registri a causa del fatto che l'operazione principale di un core matrice è un prodotto esterno 4x1, quindi è necessario un passaggio di riordino. - (Per quanto ne so) le intrinseche hipcc per i core matrice non sono documentate da nessuna parte. C'è un repository con un sacco di esempi da AMD, ma a parte questo dovrai cercare nel codice sorgente di LLVM. Adesso andrò a rifinire il mio codice e poi probabilmente scriverò un post più approfondito sull'attenzione flash su AMD. Oh e a proposito, un saluto a @HotAisle per aver reso tutto questo possibile!