Abilitato l'addestramento fp8 per un miglioramento del +4,3% nel "tempo per GPT-2", sceso a 2,91 ore ora. Vale anche la pena notare che se utilizzi i prezzi delle istanze spot 8XH100, questo repro di GPT-2 costa davvero solo ~$20. Quindi è entusiasmante - GPT-2 (7 anni fa): troppo pericoloso da rilasciare. GPT-2 (oggi): nuovo MNIST! :) Sicuramente questo può scendere ben al di sotto di 1 ora. Alcune parole in più su fp8, è stato un po' più complicato di quanto avessi previsto e mi ci è voluto un po' per arrivarci e anche ora non sono sicuro al 100% se sia una grande idea a causa del supporto complessivo ridotto. Sulla carta, fp8 su H100 è 2X i FLOPS, ma in pratica è molto meno. Non siamo al 100% vincolati al calcolo durante l'esecuzione dell'addestramento reale, c'è un sovraccarico extra dovuto alle conversioni di scala aggiuntive, i GEMM non sono abbastanza grandi su scala GPT-2 da rendere il sovraccarico chiaramente giustificato e, naturalmente, a precisione inferiore la qualità di ogni passo è minore. Per la ricetta di scaling rowwise, le curve di perdita fp8 rispetto a bf16 erano abbastanza vicine ma era nettamente più lento. Per lo scaling tensorwise, le curve di perdita si sono separate di più (cioè ogni passo è di qualità peggiore), ma ora almeno otteniamo un'accelerazione (~7,3%). Puoi recuperare naivamente le prestazioni aumentando l'orizzonte di addestramento (alleni per più passi, ma ogni passo è più veloce) e sperare che alla fine tu ne esca avvantaggiato. In questo caso e in generale, giocando un po' con queste ricette e orizzonti di addestramento, finora ho ottenuto un'accelerazione di ~5%. torchao nel loro articolo riporta un'accelerazione dell'addestramento fp8 di Llama3-8B del 25% (rispetto al mio ~7,3% senza tenere conto della capacità), che è più vicino a ciò che speravo inizialmente, anche se Llama3-8B è un modello molto più grande. Probabilmente non è la fine della saga fp8. Dovrebbe essere possibile migliorare le cose scegliendo esattamente quali strati applicare e prestando maggiore attenzione ai numeri attraverso la rete.