inflearn logo
강의

강의

N
챌린지

챌린지

멘토링

멘토링

N
클립

클립

로드맵

로드맵

지식공유

Tensorflow 사용메뉴얼

Model Implementation

epoch 1부터 loss가 너무 낮게 나와 학습이 안되네요

552

임해빈

작성한 질문수 20

0


ratio = 0.2
x_train = tf.random.normal(shape = (300,), dtype = tf.float32)
y_train = 3 * x_train + 1 + ratio * tf.random.normal(shape = (300, ), dtype = tf.float32)

x_test = tf.random.normal(shape = (30,), dtype = tf.float32)
y_test = 3 * x_test + 1 + ratio * tf.random.normal(shape = (30, ), dtype = tf.float32)

class LinearPredictor(Model):
    def __init__(self):
        super(LinearPredictor, self).__init__()

        self.d1 = Dense(1, activation = "linear")

    def call(self, x):
        x = self.d1(x)
        return x

model = LinearPredictor()

loss_object = tf.keras.losses.MeanSquaredError()
optimizer = SGD(learning_rate=0.01)

for epoch in range(5):
    for x, y in zip(x_train, y_train):
        x = tf.reshape(x, (1, 1))
        with tf.GradientTape() as tape:
            predictions = model(x)
            loss = loss_object(y, predictions)
        gradients = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    print(f"Epoch: {epoch + 1}")
    print(f"Train Loss: {loss:.4f}")
    

 

강사님이랑 똑같이 코드를 짠것 같은데 train loss가 너무 낮게 나와 학습이 안되네요. 틀린곳이 있는건지 데이터가 너무 심플하게 생성되서 그런건지 잘 모르겠습니다.

 

딥러닝 tensorflow

답변 1

0

인프런 AI 인턴

안녕하세요, 인프런 AI 인턴이에요.

먼저 코드를 확인해보면, 생성된 데이터(x_train, y_train)가 비교적 간단한 형태인 것 같습니다. 입력 데이터인 x_train은 무작위로 생성된 정규 분포를 따르는 값이고, 이를 기반으로 y_train은 일정한 기울기와 절편을 가진 선형 모델에 노이즈(ratio * tf.random.normal)를 추가한 값으로 생성되어 있습니다.

만약 데이터가 너무 심플하다면, 학습이 제대로 이루어지지 않을 수 있습니다. 모델이 학습할 패턴이 부족하여 일정 수준 이상의 정확도를 달성할 수 없습니다.

또한, 학습 횟수인 epoch도 5로 상대적으로 낮게 설정되어 있습니다. 실제로 학습 횟수(epoch)를 늘릴수록 모델의 성능이 향상될 가능성이 있습니다. epoch를 5에서 10, 20 등 더 큰 수로 늘려보고 결과를 확인해보는 것을 추천드립니다.

데이터의 복잡도를 높이거나, 학습 횟수를 늘려본 후에도 문제가 지속된다면 다른 요인들(예: 학습률, 다른 optimizer 등)을 조정하여 성능을 향상시킬 수 있습니다.
성능 향상에는 여러 가지 요인이 작용할 수 있으므로, 여러 가지 시도를 통해 최적의 결과를 찾아보시기 바랍니다.

더 도움이 필요하시다면 언제든지 질문해주세요.

0

임해빈

혹시 다른 코드 이상은 없나요?

[긴급요청2] 28강 동영상 오류(27강 음성 재생) 수정 예정일자를 알려주세요

0

10

1

[긴급요청] 28강 동영상 전체에 오디오가 잘못들어가있습니다.

0

25

1

17강 강의 마지막 부분에서 헷갈리는 부분이 있습니다.

0

24

0

14강의 예시 데이터를 통한 답을 구하는 과정에서의 문의

0

26

1

강의 자료 열람 요청이 안되네요

0

18

1

예측에 사용하는 경제지표의 활용에 대해

1

32

2

강의 자료를 어디서 확인 할 수 있나요?

1

44

3

실습 코드

0

35

2

MMDetection 버전 이슈

0

59

2

[업로드오류] 강의 내용과 제목 불일치

0

46

2

섹션22 퀴즈 질문

0

59

1

import torch가 안되는 경우는 어떻게 하나요?

0

57

1

소리가 겹쳐서 들려요

0

66

2

20강에서 파인튜닝 때 사용한 데이터가 없어졌습니다. LoRA Trainer 매개변수도 라이브러리 업그레이드로 수정되었습니다.

0

46

1

[개정판] 딥러닝 컴퓨터 비전 완벽 가이드 먼저? 구현하며 배우는 Transformer 먼저?

0

64

1

수업자료

0

43

2

동영상 재생오류

0

62

1

multiple inputs

0

327

1

텍스트 데이터일 때의 dtype

0

275

1

12강 data split take와 skip

0

326

1

<tensorflow사용메뉴얼> 강의파일

0

296

0

SyntaxError: keyword can't be an expression

0

522

0

강의 감사합니다. 다섯번째 강의인 Model Implementation가 재생이 안됩니다.

0

176

0

unsupported operand type(s) for *: 'float' and 'NoneType'

0

1974

1