• 카테고리

    질문 & 답변
  • 세부 분야

    컴퓨터 비전

  • 해결 여부

    해결됨

save_weights_only=True로 했을 때 load_model 오류

23.08.07 09:21 작성 조회수 324

0

안녕하세요 교수님!

ModelCheckpoint에서

ModelCheckpoint('best_model.h5', save_weights_only=True, monitor='val_loss', save_best_only=True, mode='min')

save_weights_only = True로 했을 때 아래와 같은 load_model 에러가 나더라구요..

그래서 구글링을 해봤는데 저렇게 설정할 경우에 모델 아키텍처가 저장이 안되어서 load_model을 할 수 없다고 json 파일로 모델을 따로 저장하고 나중에 json 모델을 다시 불러오는 방법을 사용하라고 나왔습니다.

 

  1. 강의 중에도 언급해주셨지만 save_weights_only = True로 했을 때의 이점이 있을까요..? False로 했을 때 교수님께서 모델을 불러올 때 충돌..? 비슷한 것이 난다고 하셨는데 좀 더 세부적인 내용을 알고 싶습니다..!

  2. 만약에 True로 설정했다면 매번 json으로 모델을 저장하는 과정을 거쳐야 하는 것인지 궁금합니다!

    model.save() 함수도 있던데 이거는 modelcheckpoint와 달리 학습 중에 저장은 안되는 것 같아서요..

 

항상 감사합니다 교수님!!

답변 1

답변을 작성해보세요.

0

안녕하십니까,

원래 초창기 keras는 weight와 layer기반 모델, optimizer 정보만을 하나의 파일로 저장합니다. 그리고 weight만 저장하더라도 하나의 파일로만 저장합니다. 반면에 Tensorflow는 weight는 별도의 파일로 만들고, layer기반 아키텍처와 optimizer, metric, loss정보는 또 다른 파일에 저장합니다.

또한 Keras의 경우는 fit()을 수행할 때 callback에서 모델과 weight를 함께 저장하는 방식을 적용할 때, 버전에 따라 알수 없는(?) 오류를 맞는 경우들이 있는 것 같습니다. tensorflow 저장 방식은 checkpoint(weight)와 모델 저장이 별도로 되는데 keras는 옛날부터 h5 형식으로 단일 파일에서 저장합니다. 이게 뭔가 tensorflow와 keras가 합쳐지면서 잘 안 맞는 느낌입니다(지금은 버전이 향상되어서 잘 될지도 모릅니다)

  1. save_weights_only = Fasle로 한번 시도해 보시지요. 다만 저는 말씀 드렸듯이 오류가 발생하는 경우가 있었습니다.

     

  2. 저는 fit()에서 callback으로 ModelCheckPoint 사용 시 weight만 저장 하는 방식을 선호 합니다. 모델은 다시 메모리로 아래와 같이 만든 저장된 checkpoint weight를 로딩하는 방식입니다.

    sequential_model = keras.Sequential(  

    [         keras.Input(shape=(784,), name="digits"),        

    keras.layers.Dense(64, activation="relu", name="dense_1"),        

    keras.layers.Dense(64, activation="relu", name="dense_2"),        

    keras.layers.Dense(10, name="predictions"),     ] )

load_status = sequential_model.load_weights("ckpt")

만약 모델과 weight를 하나의 파일로 다 저장하시려면 학습이 다 완료된 후에 model.save("my_h5_model.h5") 로 모델과 weight를 h5로 다 저장하신 후 아래와 같이 로딩하시면 됩니다.

reconstructed_model = keras.models.load_model("my_h5_model.h5")

보다 자세한 모델 저장 및 로드는 아래 URL 참조 부탁드립니다.

https://www.tensorflow.org/guide/keras/save_and_serialize?hl=ko

감사합니다.