inflearn logo
강의

강의

N
챌린지

챌린지

멘토링

멘토링

N
클립

클립

로드맵

로드맵

지식공유

Tensorflow 사용메뉴얼

Model Implementation

epoch 1부터 loss가 너무 낮게 나와 학습이 안되네요

542

임해빈

작성한 질문수 20

0


ratio = 0.2
x_train = tf.random.normal(shape = (300,), dtype = tf.float32)
y_train = 3 * x_train + 1 + ratio * tf.random.normal(shape = (300, ), dtype = tf.float32)

x_test = tf.random.normal(shape = (30,), dtype = tf.float32)
y_test = 3 * x_test + 1 + ratio * tf.random.normal(shape = (30, ), dtype = tf.float32)

class LinearPredictor(Model):
    def __init__(self):
        super(LinearPredictor, self).__init__()

        self.d1 = Dense(1, activation = "linear")

    def call(self, x):
        x = self.d1(x)
        return x

model = LinearPredictor()

loss_object = tf.keras.losses.MeanSquaredError()
optimizer = SGD(learning_rate=0.01)

for epoch in range(5):
    for x, y in zip(x_train, y_train):
        x = tf.reshape(x, (1, 1))
        with tf.GradientTape() as tape:
            predictions = model(x)
            loss = loss_object(y, predictions)
        gradients = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    print(f"Epoch: {epoch + 1}")
    print(f"Train Loss: {loss:.4f}")
    

 

강사님이랑 똑같이 코드를 짠것 같은데 train loss가 너무 낮게 나와 학습이 안되네요. 틀린곳이 있는건지 데이터가 너무 심플하게 생성되서 그런건지 잘 모르겠습니다.

 

딥러닝 tensorflow

답변 1

0

인프런 AI 인턴

안녕하세요, 인프런 AI 인턴이에요.

먼저 코드를 확인해보면, 생성된 데이터(x_train, y_train)가 비교적 간단한 형태인 것 같습니다. 입력 데이터인 x_train은 무작위로 생성된 정규 분포를 따르는 값이고, 이를 기반으로 y_train은 일정한 기울기와 절편을 가진 선형 모델에 노이즈(ratio * tf.random.normal)를 추가한 값으로 생성되어 있습니다.

만약 데이터가 너무 심플하다면, 학습이 제대로 이루어지지 않을 수 있습니다. 모델이 학습할 패턴이 부족하여 일정 수준 이상의 정확도를 달성할 수 없습니다.

또한, 학습 횟수인 epoch도 5로 상대적으로 낮게 설정되어 있습니다. 실제로 학습 횟수(epoch)를 늘릴수록 모델의 성능이 향상될 가능성이 있습니다. epoch를 5에서 10, 20 등 더 큰 수로 늘려보고 결과를 확인해보는 것을 추천드립니다.

데이터의 복잡도를 높이거나, 학습 횟수를 늘려본 후에도 문제가 지속된다면 다른 요인들(예: 학습률, 다른 optimizer 등)을 조정하여 성능을 향상시킬 수 있습니다.
성능 향상에는 여러 가지 요인이 작용할 수 있으므로, 여러 가지 시도를 통해 최적의 결과를 찾아보시기 바랍니다.

더 도움이 필요하시다면 언제든지 질문해주세요.

0

임해빈

혹시 다른 코드 이상은 없나요?

[업로드오류] 강의 내용과 제목 불일치

0

24

2

섹션22 퀴즈 질문

0

30

1

import torch가 안되는 경우는 어떻게 하나요?

0

31

1

소리가 겹쳐서 들려요

0

33

2

20강에서 파인튜닝 때 사용한 데이터가 없어졌습니다. LoRA Trainer 매개변수도 라이브러리 업그레이드로 수정되었습니다.

0

24

1

[개정판] 딥러닝 컴퓨터 비전 완벽 가이드 먼저? 구현하며 배우는 Transformer 먼저?

0

34

1

수업자료

0

29

2

전 강의와 전혀 이어지지가 않음

0

34

1

pytorch local 설치 옵션에 conda 가 없습니다.

0

40

3

pc에서는 괜찮은데 탭으로 들으니 화면확대시 화면이 까맙니다

0

24

1

강의 환경설정 질문

0

43

2

모든 자료 다운로드 누를때마다 똑같은 excel파일이 다운로드 받아짐. 노션 주소 공유되나요?

0

35

2

오토인코더+ Knn, SVC 로 해석하는경우

0

46

3

강의자료에 소스코드가 없는데요

0

51

3

강화학습저장 및 로드

0

61

1

Custom Dataset에서의 polygon 정보 관련

0

90

3

동영상 재생오류

0

55

1

multiple inputs

0

320

1

텍스트 데이터일 때의 dtype

0

267

1

12강 data split take와 skip

0

319

1

<tensorflow사용메뉴얼> 강의파일

0

291

0

SyntaxError: keyword can't be an expression

0

519

0

강의 감사합니다. 다섯번째 강의인 Model Implementation가 재생이 안됩니다.

0

173

0

unsupported operand type(s) for *: 'float' and 'NoneType'

0

1969

1