Tôi đã dành vài giờ để xem qua repo /karpathy/autoresearch từng dòng một. Góc nhìn "các tác nhân AI thực hiện nghiên cứu" đang thu hút tất cả sự chú ý nhưng tôi nghĩ điều thú vị hơn là những gì thực sự có trong kịch bản đào tạo và các quyết định kỹ thuật khiến vòng lặp tìm kiếm trở nên chặt chẽ. Đây là một trong những thiết lập đào tạo một tệp dày đặc nhất mà tôi đã đọc. Hãy để tôi bắt đầu với điều làm cho toàn bộ dự án trở nên khả thi: ngân sách thời gian được cố định ở 300 giây đồng hồ. Không phải bước cố định, không phải token cố định, không phải flops cố định. Giây đồng hồ. Điều này nghe có vẻ như một chi tiết nhỏ nhưng đó là lý do chính khiến vòng lặp tự động hoạt động. Tác nhân có thể làm cho mô hình lớn hơn 3 lần, cắt giảm kích thước lô xuống một nửa, thay thế bằng một kiến trúc hoàn toàn khác, và kết quả vẫn có thể so sánh trực tiếp với mọi thí nghiệm khác vì tất cả đều có đúng 5 phút đào tạo trên cùng một GPU. Nếu bạn cố định bước thay vào đó, một mô hình lớn hơn sẽ nhận được ít cập nhật gradient hơn mỗi giây và bạn sẽ phạt nó một cách không công bằng. Nếu bạn cố định token, bạn sẽ gặp phải vấn đề tương tự. Cố định thời gian thực có nghĩa là bạn đang đặt ra câu hỏi đúng: với phần cứng này và thời gian này, mô hình tốt nhất bạn có thể sản xuất là gì? Mọi thứ khác là một biến tự do. Tác nhân có thể khám phá toàn bộ bề mặt Pareto của kích thước mô hình so với thông lượng so với tốc độ hội tụ mà không có bất kỳ sự đánh đổi nào bị nhầm lẫn bởi giao thức đánh giá. Chỉ số cũng được chọn lựa cẩn thận. Đó là bit trên byte, không phải mất mát entropy chéo. Mất mát entropy chéo phụ thuộc vào kích thước từ vựng của bạn. Một mô hình với 32k token và một mô hình với 8k token sẽ có giá trị mất mát rất khác nhau ngay cả khi chúng nén dữ liệu một cách tương đương. bpb chuẩn hóa điều này bằng cách cộng tổng mất mát entropy trên mỗi token trong nats, cộng tổng chiều dài byte utf-8 của các token mục tiêu, và chuyển đổi nats trên byte thành bits trên byte. Vì vậy, ngay cả khi tác nhân thay đổi điều gì đó ảnh hưởng đến phân phối token hiệu quả, sự so sánh vẫn công bằng. Hai lựa chọn này, thời gian thực cố định và chỉ số không phụ thuộc vào từ vựng, biến một tìm kiếm không thể so sánh lộn xộn thành một bài toán tối ưu hóa sạch sẽ. Bây giờ là mô hình tự nó. Nó là một GPT nhưng với một loạt các mẹo hiện đại đáng để hiểu. Đầu tiên, RMSnorm ở khắp mọi nơi. Trên các đầu vào khối (pre-norm), và cũng trên các truy vấn và khóa ngay trước khi sản phẩm điểm chú ý. Điều này QK-norm rất quan trọng vì nếu không có nó, các chuẩn của q và k có thể phát triển không giới hạn trong quá trình đào tạo, khiến các logits chú ý trở nên sắc nét và softmax bão hòa. Chuẩn hóa q và k giữ cho các sản phẩm điểm trong một khoảng ổn định bất kể mạng sâu đến đâu hoặc động lực đào tạo phát triển như thế nào. Chú ý tự nó là FA 3, được tải qua thư viện kernels. Nó sử dụng triển khai của varunneal trên hopper (sm_90) và quay lại một bản xây dựng cộng đồng trên các GPU cũ hơn. Mẫu chú ý là "SSSL" có nghĩa là ba lớp chú ý cửa sổ trượt (cửa sổ = một nửa chiều dài chuỗi) theo sau là một lớp chú ý nguyên nhân đầy đủ, lặp lại. Đây là mẫu thưa đến dày mà bạn thấy trong mistral và gemma2. Các lớp chú ý cục bộ rẻ về mặt tính toán vì ma trận chú ý là băng, và lớp toàn cầu định kỳ cho phép thông tin chảy qua toàn bộ ngữ cảnh. Với 8 lớp và một mẫu 4 ký tự, bạn có các lớp 0,1,2 cục bộ, lớp 3 toàn cầu, các lớp 4,5,6 cục bộ, lớp 7 toàn cầu. Lớp cuối cùng được buộc phải toàn cầu bất kể mẫu. Điều nhúng giá trị là tinh tế và tôi nghĩ không được đánh giá cao. Mỗi lớp khác nhận được bảng nhúng riêng của nó, hoàn toàn tách biệt khỏi nhúng token chính, ánh xạ các id token trực tiếp đến các vector chiều giá trị. Những điều này được trộn vào các giá trị chú ý thông qua một cổng học được: v = v + 2 * sigmoid(W_gate @ x:32) * ve. Trọng số cổng được khởi tạo bằng không, vì vậy sigmoid(0) = 0.5, nhân với 2 cho 1.0, đây là điểm khởi đầu trung lập. Trong quá trình đào tạo, mô hình có thể học cách khuếch đại hoặc giảm bớt nhúng giá trị theo đầu dựa trên 32 chiều đầu tiên của trạng thái ẩn. Điều này đến từ dòng công việc ResFormer và trực giác là nó cung cấp cho chú ý một lối tắt trực tiếp đến danh tính token. Các vector giá trị có thể mang thông tin về "token nào đang ở vị trí này" mà không cần thông tin đó phải sống sót qua các biến đổi dòng dư từ các lớp trước. Nó về cơ bản là một kết nối bỏ qua từ đầu vào trực tiếp vào các giá trị chú ý, được cổng để mô hình có thể quyết định khi nào là hữu ích. Cũng có các số học có thể học được theo lớp trên dòng dư: x = lambda_residi * x + lambda_x0i * x0, trong đó x0 là nhúng chuẩn hóa từ lớp 0. Mỗi lớp có thể độc lập kiểm soát mức độ mà nó lắng nghe dòng dư đang chạy so với đầu vào gốc. Các lambda dư bắt đầu ở 1.0, các lambda x0 bắt đầu ở 0.1. Đây là một phiên bản mềm của ý tưởng "dòng dư tách rời". Trong một transformer tiêu chuẩn, dòng dư là tổng của tất cả các đầu ra lớp trước và nó ngày càng bị ô nhiễm khi bạn đi sâu hơn. Việc cho mỗi lớp quyền truy cập vào nhúng gốc sạch có nghĩa là nó không phải học cách "hoàn tác" các lớp trước để phục hồi thông tin cấp thấp. Các logits được giới hạn mềm ở 15 thông qua tanh(logits/15)*15, điều này ngăn mô hình trở nên quá tự tin sớm trong quá trình đào tạo khi các đại diện vẫn còn ồn ào. Nhưng thành thật mà nói, phần thú vị nhất của toàn bộ tệp là bộ tối ưu hóa. MuonAdamW là một bộ tối ưu hóa kết hợp mà phân phối các quy tắc cập nhật khác nhau dựa trên nhóm tham số. Các nhúng (nhúng token, nhúng giá trị, đầu ra không nhúng) và các số học theo lớp nhận được AdamW tiêu chuẩn với các tốc độ học khác nhau cho mỗi nhóm. Sự phân bố là hoang dã. Tốc độ học nhúng là 0.6, tốc độ học không nhúng là 0.004, đó là sự khác biệt 150x, và điều này là có chủ ý. Ma trận nhúng thấy từng token và cần cập nhật một cách quyết liệt. Ma trận không nhúng là một đầu dò tuyến tính trên đại diện cuối cùng và được hưởng lợi từ sự ổn định. Tốc độ học nhúng, nhúng giá trị và không nhúng đều được điều chỉnh bởi (d_model / 768)^(-0.5) mà là một điều chỉnh lấy cảm hứng từ muP. Khi chiều rộng mô hình thay đổi, các tốc độ học đó điều chỉnh để giữ cho động lực học đặc trưng không thay đổi theo quy mô. Các tốc độ học cho các lambda theo lớp được xử lý riêng và không nhận được sự điều chỉnh này. Các ma trận trọng số 2D trong transformer, các phép chiếu chú ý và trọng số mlp, nhận được Muon, và đây là nơi nó trở nên thực sự thú vị. muon lấy gradient, áp dụng động lực nesterov, sau đó chạy một lần lặp newton-schulz để xấp xỉ phân rã cực của ma trận gradient. Phân rã cực phân tách một ma trận G thành G = U * S trong đó U là trực giao và S là dương bán định nghĩa. muon tính toán U, ma trận trực giao gần nhất với gradient, và sử dụng nó làm hướng cập nhật. Lần lặp newton-schulz là 5 bước. Đối với các ma trận cao (nhiều hàng hơn cột), A = X^T @ X sau đó X -> aX + X @ (bA + cA^2). Đối với các ma trận rộng, A = X @ X^T sau đó X -> aX + (bA + cA^2) @ X. Các hệ số được mã hóa cứng từ một tính toán trước. Họ gọi nó là "polar express." Toàn bộ điều này biên dịch thành một kernel hợp nhất duy nhất thông qua torch.compile. Tại sao điều này lại quan trọng? Bởi vì đối với các ma trận trọng số, gradient chuẩn frobenius (điều mà adam và sgd sử dụng) là sai về mặt hình học. Hướng giảm dốc "đúng" cho một ma trận trọng số là hướng tối thiểu hóa mất mát với điều kiện rằng cập nhật có chuẩn phổ đơn vị, không phải chuẩn frobenius đơn vị. Yếu tố cực trực giao cung cấp cho bạn chính xác điều này. Trong thực tế, điều này có nghĩa là muon thực hiện các cập nhật hiệu quả lớn hơn nhiều vì nó không lãng phí kích thước bước vào việc điều chỉnh các giá trị riêng. Nó chỉ xoay chúng. Đây là lý do tại sao muon hội tụ nhanh hơn đáng kể so với adam trên các ma trận trọng số transformer. muon cũng duy trì các bộ đệm động lực theo phần tử (cùng hình dạng như các tham số, xếp chồng qua mỗi nhóm hình dạng), nhưng không giống như adam, nó không theo dõi các khoảnh khắc thứ hai theo phần tử. Các ước lượng khoảnh khắc thứ hai là theo hàng hoặc theo cột sau khi trực giao hóa, không phải theo phần tử. Đó là nơi NorMuon xuất hiện. Trên cơ sở muon có NorMuon, một sơ đồ giảm phương sai. Sau khi trực giao hóa, nó tính toán các ước lượng khoảnh khắc thứ hai theo hàng (hoặc theo cột tùy thuộc vào tỷ lệ khía cạnh), duy trì một trung bình di động theo cấp số nhân của những điều đó, và điều chỉnh cập nhật để mỗi chiều đầu ra nhận được kích thước bước thích ứng riêng. Nó về cơ bản là ý tưởng thích ứng của adam nhưng được áp dụng trong hệ tọa độ trực giao thay vì không gian tham số thô. Sự suy giảm trọng số cũng không tiêu chuẩn. Nó "cẩn thận," có nghĩa là nó chỉ suy giảm các tham số mà hướng cập nhật muon đồng ý với dấu tham số: mask = (g * params) >= 0. Điều này tránh được chế độ thất bại đã biết nơi suy giảm trọng số đẩy các tham số về phía không chống lại mong muốn của cập nhật, điều này có thể làm mất ổn định quá trình đào tạo. Một chi tiết nhỏ mà tôi đánh giá cao: sau bước đào tạo đầu tiên, mã gọi gc.collect(), gc.freeze(), gc.disable() để hoàn toàn tắt bộ thu gom rác của python. GC của python chạy định kỳ và gây ra các khoảng dừng ~500ms. Khi ngân sách tổng thể của bạn là 300 giây và mỗi bước có thể là 300ms, một khoảng dừng GC ngẫu nhiên khiến bạn mất gần 2 bước đào tạo. Họ kích hoạt thủ công gc.collect() mỗi 5000 bước như một sự thỏa hiệp. Đây là loại điều mà bạn chỉ học được bằng cách phân tích các lần chạy đào tạo thực tế và nhận thấy những giảm sút thông lượng bí ẩn. 11 bước đầu tiên (0 đến 10) cũng không được tính vào ngân sách thời gian. Đó là thời gian khởi động nơi torch.compile thực hiện công việc của nó và các kernel CUDA được JIT. Nếu không có sự loại trừ này, các thí nghiệm khác nhau sẽ nhận được các khoảng thời gian "thực" đào tạo khác nhau tùy thuộc vào thời gian biên dịch mất bao lâu cho cấu hình mô hình cụ thể đó. Một lần nữa, một lựa chọn thiết kế có vẻ nhỏ nhưng rất quan trọng để làm cho các thí nghiệm có thể so sánh được. Bây giờ hãy phóng to. Vòng lặp tự nghiên cứu thực tế là: tác nhân đọc program.md (một tệp markdown mô tả công việc của nó), sửa đổi train.py, cam kết, chạy trong 5 phút, kiểm tra xem val_bpb có cải thiện không, giữ lại hoặc hoàn tác, lặp lại. program.md rõ ràng nói "KHÔNG BAO GIỜ DỪNG LẠI." Tác nhân chạy vô thời hạn cho đến khi con người giết nó. ~12 thí nghiệm mỗi giờ, ~100 qua đêm trong khi bạn ngủ. ...