Hi,
I’m 12 years old and was reading the train.py script. I noticed that it sets:
```python
dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16'
scaler = torch.cuda.amp.GradScaler(enabled=(dtype == 'float16'))
ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype)
```
Since MPS isn’t CUDA, torch.cuda.is_available() is False there, so dtype falls back to 'float16' and the GradScaler is constructed with enabled=True even though it comes from torch.cuda.amp. I was wondering whether fp16 and the GradScaler behave correctly in that case. Should the code explicitly disable fp16 when the device is MPS, or is it intended to rely on PyTorch’s default handling?
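If an explicit guard is the right fix, would something like the sketch below make sense? This is just my guess, not code from the repo; it assumes the config’s device string is 'mps', falls back to float32 there, and only enables the scaler when actually running fp16 on CUDA.

```python
# hypothetical sketch, not train.py itself: pick the dtype per backend and
# only enable the GradScaler for fp16 on CUDA, since it lives in torch.cuda.amp
from contextlib import nullcontext
import torch

device = 'mps'  # assumed config value for this example
device_type = 'cuda' if 'cuda' in device else 'cpu'  # what the autocast/nullcontext switch keys off

if 'mps' in device:
    dtype = 'float32'  # avoid fp16 autocast + GradScaler on MPS
elif torch.cuda.is_available() and torch.cuda.is_bf16_supported():
    dtype = 'bfloat16'
else:
    dtype = 'float16'

ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
scaler = torch.cuda.amp.GradScaler(enabled=(dtype == 'float16' and device_type == 'cuda'))
ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype)
```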
Thanks for making such an awesome repo!