Tensorflow checkpoint 기능 활용 방법

21.11.14 15:58 작성 조회수 386

0

안녕하세요 딥러닝을 구글colab을 이용하여 공부중인데 현재 이미지 학습 부분을 공부하고 있습니다.

하지만 결제를 하여도 런타임이 24시간이 한계라 훈련이 자꾸 중단되어 epoch를 전부 학습하지 못하고 있습니다!

그래서 1epoch마다 checkpoint를 저장하는 방법을 사용하여 훈련을 하고 있고, 현재 잘 저장이 되고 있습니다.

하지만 저는 항상 3/10 epoch에서 24시간이 지나 훈련이 중단 되는데 그러면 이때 3 까지 저장된 checkpoint를 불러와서

다시 4epoch부터 재 학습을 시킬수 있는방법이  궁금합니다! 검색을 해 보았지만 전부 학습이 완료된 데이터를 불러오는 예제 밖에 없어서 질문 남깁니다!

제가 사용한 코드는 아래와 같습니다. 

from fastai.imports import *
from tensorflow.keras import datasets, layers, models, losses, Model
from tensorflow import keras
import tensorflow as tf
from keras.layers import Dense,Dropout,Activation,Add,MaxPooling2D,Conv2D,Flatten,BatchNormalization
from keras.models import Sequential 
from keras.preprocessing.image import ImageDataGenerator
from keras import layers
import seaborn as sns
from keras.preprocessing import image
import numpy as np
import cv2   
from google.colab.patches import cv2_imshow
import matplotlib.pyplot as plt
plt.style.use('seaborn-white')

data_path = '/content/drive/MyDrive/train_val_data'
train_dir = os.path.join(data_path,'train')
val_dir = os.path.join(data_path,'test')
classes = os.listdir(train_dir)

train_datagen = tf.keras.preprocessing.image.ImageDataGenerator(
    width_shift_range = 0.2,
    height_shift_range = 0.2,
    zoom_range = 0.2,
    vertical_flip=True,
    rescale = 1. / 255,
    fill_mode='nearest')

val_datagen = tf.keras.preprocessing.image.ImageDataGenerator(rescale = 1. / 255)

train_generator = train_datagen.flow_from_directory(
    train_dir,
    target_size=(224,224),
    batch_size = 32,
    class_mode = 'categorical'
)

val_generator = val_datagen.flow_from_directory(
    val_dir,
    target_size=(224,224),
    batch_size = 32,
    class_mode = 'categorical'
)

base_model = tf.keras.applications.ResNet50(weights = 'imagenet', include_top = False, input_shape = (224,224,3))
for layer in base_model.layers:
  layer.trainable = False

x = layers.Flatten()(base_model.output)
x = layers.Dense(720, activation='relu')(x)
predictions = layers.Dense(360, activation = 'softmax')(x)

opt = tf.keras.optimizers.Adam(learning_rate=0.001)

checkpoint_path = "/content/drive/MyDrive/training_resnet50/resnet50_cp.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)

# 체크포인트 콜백 만들기
cp_callback = tf.keras.callbacks.ModelCheckpoint(checkpoint_path,
                                                 save_weights_only=True,
                                                 verbose=1)

head_model = Model(inputs = base_model.input, outputs = predictions)
head_model.compile(optimizer=opt, loss='categorical_crossentropy', metrics=['accuracy'])

history = head_model.fit(train_generator, validation_data=val_generator, batch_size=32, epochs=100, callbacks = [cp_callback])
 
 

답변 0

답변을 작성해보세요.

답변을 기다리고 있는 질문이에요.
첫번째 답변을 남겨보세요!