행렬곱 오차 문제
안녕하세요. 큰 질문은 아니고 사소한 질문일 수도 있습니다만..
다름이 아니라, 행렬곱 강의에서 구현한 코드에서는 곱해주는 행렬 크기가 커질수록 오차가 누적되는 듯한(정확히 말하면 파이토치 내장 matmul과 계산 결과가 점점 더 달라지는듯한) 현상이 관찰되어 질문드립니다.
먼저, 실습에서 정의한 코드에서부터
x = torch.randn(16,16,device = 'cuda')
y = torch.randn(16,16,device = 'cuda')
a = matmul(x,y)
b = torch.matmul(x,y)
assert torch.allclose(a,b)torch.allclose 의 기본 인자(atol=1e-8, rtol=1e-5) 세팅에서는 assertion error가 발생하여 조건을 완화시켜야(atol=1e-5, rtol=1e-5) assertion이 통과되는 모습을 보였고
x = torch.randn(2048,1024,device = 'cuda')
y = torch.randn(1024,256,device = 'cuda')
x, y의 크기를 이와 같이 키웠을 경우엔 atol=1e-4, rtol=1e-4로 조건을 완화시켜야 assertion을 통과하는 모습을 보였습니다.
triton kernel로 구현한 행렬곱 연산과 PyTorch 내장 matmul 연산 모두 fp32로 연산이 이루어지고 있는데, 이러한 오차가 발생할 수 있는 원인에 무엇이 있는지 궁금해서 질문 드립니다.
Answer 2
3
안녕하세요? 아담한 고슴도치님,
먼저 강의를 수강해주셔서 감사합니다. 계산의 결과가 다른 이유는 크게 2가지가 있습니다.
첫째, 다른 데이터 타입을 사용함에 따라 오차가 발생할 수 있습니다. 예제의 경우 행렬을 곱을 tl.dot(x, y, allow_tf32=False)를 호출해서 계산했습니다. 만약 allow_tf32에 True가 설정되어 있거나 allow_tf32가 정의되어 있지 않는 경우에 오차가 발생할 수 있습니다. Triton이 행렬을 빠르게 계산하기 위해 float32를 tf32로 변환한 뒤 Tensor Core를 사용하기 때문입니다. tf32의 경우 float32보다 정밀도가 낮은데, 이 차이로 인해 계산의 오차가 발생할 수 있습니다.
둘째, 계산 순서에 따라 결과가 달라질 수 있습니다. float32는 IEEE 754 표준에 맞춰서 구현되어 있습니다. 지수에 8비트가 사용되고 가수에 23비트가 사용됩니다. 그러므로 float32는 실수를 다 표현할 수 없습니다. 이러한 한계 때문에 계산 순서에 따라 오차가 발생할 수 있습니다. 이러한 현상은 쉽게 확인할 수 있습니다. 크기가 20000인 배열에 실수가 저장되어 있는 경우, 순서대로 실수의 합을 더할때와 역순으로 실수의 합을 더할때의 결과가 다른 것을 확인할 수 있습니다.
이 2가지의 경우 하드웨어의 한계로 발생하는 오차입니다. 개인적으로 저는 이러한 오차를 오차라고 생각하지 않습니다.
마지막으로 예제 코드의 경우 경계 검사가 되어있지 않습니다. 텐서의 크기와 블록의 크기가 정확히 나누어 떨어지지 않는다면 쓰레기 값이 임시 텐서에 로드되게 되고 이것이 잘못된 결과를 만들 수 있습니다.
요즘 제가 바빠서 답변이 늦었는데 죄송합니다. 궁금하신점 있으시면 계속 물어봐주세요. 감사합니다!
작업형 1 유형 부분
0
9
1
수강평 이벤트
0
15
2
import torch가 안되는 경우는 어떻게 하나요?
0
15
1
작업형 1 (삭제예정, 구 버전)
0
28
2
강의노트는 어디있나요?
0
15
1
노션 학습 자료 권한 요청
0
15
1
수강기간 연장 문의드립니다.
0
20
1
2유형 레이블 인코딩 VS 원핫 인코딩
0
20
3
part2강의 문의사항입니다.
0
17
2
수강기간 연장 문의드립니다.
0
26
1
인덱스 슬라이싱
0
26
2
코드를 첨부해야하는 이유가 있나요?
0
20
2
소리가 겹쳐서 들려요
0
19
2
데스크톱과 노트북 연결
0
26
1
18강 smithery 를 이용한 mcp 실습(업데이트 요청)
0
17
1
autotune은 아직 안 올라온 건가요?
0
49
1
강의만 봐서는 triton 커널이 pytorch에 비해 빨라 보이지 않네요..
0
165
2
block ptr 질문
0
62
2
디스코드 커뮤
0
87
1
앞으로의 강의 계획에 대하여
0
243
2
코드가 실행되는 순서에 관하여
0
366
1
실행을 위한 최적 환경
1
830
1
강의 계획에 대하여
0
336
1
실습 코드
0
376
1

