• 카테고리

    질문 & 답변
  • 세부 분야

    컴퓨터 비전

  • 해결 여부

    미해결

깃허브에 있는 딥러닝 모델 가져오기(깃허브 오픈소스 사용법)

21.09.16 14:53 작성 조회수 1.82k

0

안녕하세요 아직 많이 듣지는 않았지만 ㅎ 유익한 강의 정말 잘 듣고 있습니다.!

다름이 아니라 최근 초해상화(Super resolution)에 관심이 다양한 딥러닝 모델들을 찾아보았습니다.

아직 논문을 보고 코드를 구현하기에는 많이 부족해서 다른 사람들이 깃허브에 업로드한 딥러닝 모델들을 적용해보고 싶은데  활용방법에 익숙치 않습니다. 

예를들어 SRCNN모델을 적용해보고싶어 https://github.com/TheStarkor/SRCNN-tensorflow2 이 깃허브를 활용하고 싶을때 어떻게 자신만의 데이터를 학습시키고 output 데이터를 도출할 수 있는지 잘 감이 잡히지 않아 질문드립니다. 

 

답변 1

답변을 작성해보세요.

0

안녕하십니까, 

음, 저도 이걸 구동해 보진 않았지만 main.py를 해석해 봤습니다.  너무 길어서 여기서 적어서 설명드릴수는 없을것 같습니다. 스스로 main.py를 함 보셔야 할 것 같습니다. 

개요적으로 말씀드리면 아래 구성에서 

DATA_DIR = "../src/"

FILE_PATH = "./models/srcnn_div2k.hdf5"

TRAIN_PATH = "DIV2K_train_HR"

TEST_PATH = "DIV2K_valid_HR"

 

N_TRAIN_DATA = args.N_TRAIN_DATA

N_TEST_DATA = args.N_TEST_DATA

BATCH_SIZE = args.BATCH_SIZE

EPOCHS = args.EPOCHS

 

train_data_generator() 함수 내부를 보시면 ImageDataGenerator.flow_from_directory() 구조로 되어 있습니다. 학습 데이터를 flow_from_directory()에서 사용가능할 수 있도록 DATA_DIR = "../src/" 디렉토리 구조에 맞게 옮겨 주셔야 학습이 될 것 같습니다.