inflearn logo
강의

강의

N
챌린지

챌린지

멘토링

멘토링

N
클립

클립

로드맵

로드맵

지식공유

Tensorflow checkpoint 기능 활용 방법

606

bubibubibaby

작성한 질문수 2

0

안녕하세요 딥러닝을 구글colab을 이용하여 공부중인데 현재 이미지 학습 부분을 공부하고 있습니다.

하지만 결제를 하여도 런타임이 24시간이 한계라 훈련이 자꾸 중단되어 epoch를 전부 학습하지 못하고 있습니다!

그래서 1epoch마다 checkpoint를 저장하는 방법을 사용하여 훈련을 하고 있고, 현재 잘 저장이 되고 있습니다.

하지만 저는 항상 3/10 epoch에서 24시간이 지나 훈련이 중단 되는데 그러면 이때 3 까지 저장된 checkpoint를 불러와서

다시 4epoch부터 재 학습을 시킬수 있는방법이  궁금합니다! 검색을 해 보았지만 전부 학습이 완료된 데이터를 불러오는 예제 밖에 없어서 질문 남깁니다!

제가 사용한 코드는 아래와 같습니다. 

from fastai.imports import *
from tensorflow.keras import datasets, layers, models, losses, Model
from tensorflow import keras
import tensorflow as tf
from keras.layers import Dense,Dropout,Activation,Add,MaxPooling2D,Conv2D,Flatten,BatchNormalization
from keras.models import Sequential 
from keras.preprocessing.image import ImageDataGenerator
from keras import layers
import seaborn as sns
from keras.preprocessing import image
import numpy as np
import cv2   
from google.colab.patches import cv2_imshow
import matplotlib.pyplot as plt
plt.style.use('seaborn-white')

data_path = '/content/drive/MyDrive/train_val_data'
train_dir = os.path.join(data_path,'train')
val_dir = os.path.join(data_path,'test')
classes = os.listdir(train_dir)

train_datagen = tf.keras.preprocessing.image.ImageDataGenerator(
    width_shift_range = 0.2,
    height_shift_range = 0.2,
    zoom_range = 0.2,
    vertical_flip=True,
    rescale = 1. / 255,
    fill_mode='nearest')

val_datagen = tf.keras.preprocessing.image.ImageDataGenerator(rescale = 1. / 255)

train_generator = train_datagen.flow_from_directory(
    train_dir,
    target_size=(224,224),
    batch_size = 32,
    class_mode = 'categorical'
)

val_generator = val_datagen.flow_from_directory(
    val_dir,
    target_size=(224,224),
    batch_size = 32,
    class_mode = 'categorical'
)

base_model = tf.keras.applications.ResNet50(weights = 'imagenet', include_top = False, input_shape = (224,224,3))
for layer in base_model.layers:
  layer.trainable = False

x = layers.Flatten()(base_model.output)
x = layers.Dense(720, activation='relu')(x)
predictions = layers.Dense(360, activation = 'softmax')(x)

opt = tf.keras.optimizers.Adam(learning_rate=0.001)

checkpoint_path = "/content/drive/MyDrive/training_resnet50/resnet50_cp.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)

# 체크포인트 콜백 만들기
cp_callback = tf.keras.callbacks.ModelCheckpoint(checkpoint_path,
                                                 save_weights_only=True,
                                                 verbose=1)

head_model = Model(inputs = base_model.input, outputs = predictions)
head_model.compile(optimizer=opt, loss='categorical_crossentropy', metrics=['accuracy'])

history = head_model.fit(train_generator, validation_data=val_generator, batch_size=32, epochs=100, callbacks = [cp_callback])
 
 

tensorflow deeplearning python 파이썬 딥러닝 텐서플로우 checkpoint 체크포인트

답변 0

6-6

0

5

0

작업형 1 유형 부분

0

10

1

수강평 이벤트

0

17

2

import torch가 안되는 경우는 어떻게 하나요?

0

16

1

작업형 1 (삭제예정, 구 버전)

0

30

2

강의노트는 어디있나요?

0

17

1

노션 학습 자료 권한 요청

0

17

1

수강기간 연장 문의드립니다.

0

21

1

2유형 레이블 인코딩 VS 원핫 인코딩

0

24

3

part2강의 문의사항입니다.

0

19

2

수강기간 연장 문의드립니다.

0

26

1

인덱스 슬라이싱

0

27

2

코드를 첨부해야하는 이유가 있나요?

0

20

2

소리가 겹쳐서 들려요

0

20

2

데스크톱과 노트북 연결

0

26

1

dict, zip

0

21

2

노션 : 파트3번 링크와 권한 , 파트4번 권한요청, 파트 5번도 미리 요청 드립니다.

0

27

4

6-6 실습 문의

0

23

2

미션 06-02

0

24

2

yes24 수집 md 파일 만들 때

0

24

2

python main.py 실행시 게임이 실행이 안돼요

0

27

2

antigravity 대신 cursor를 활용해도 되나요?

0

26

1

뉴스 검색 분류 한도초과

0

36

2

완성자료

0

25

2