• 카테고리

    질문 & 답변
  • 세부 분야

    컴퓨터 비전

  • 해결 여부

    미해결

.pt 파일에서 .tflite로 어떻게 변환하는지 궁금합니다.

21.09.23 02:08 작성 조회수 4.97k

0

custom dataset으로 weight모델을 만들어 android에 연동을 하고자 합니다.

android는 .pt이 아닌 .tflite를 사용하는데 혹시 변환하는 방법을 알수있을까요?

답변 2

·

답변을 작성해보세요.

1

Jina Shin님의 프로필

Jina Shin

2023.08.17

torch2tflite는 아니지만, 쉽게 변환해주는 서비스가 있네요.
모델 올리면, onnx2TFlite, TF2TFlite 자동으로 변환해줍니다.
https://launchx.netspresso.ai/main

0

안녕하십니까, 

저도 torch를 tflite로는 변환해 보지 않아서 검색을 해봤는데, 아래와 같은 방식으로 변환을 해야 할것 같습니다. 

PyTorch -> Onnx -> Tensorflow 2 -> TFLite

변환이 생각만큼 자연스럽게 되는것 같지는 않습니다.

아래에 변환 사례가 있습니다. 

https://bekusib.tistory.com/210

변환 해주는 util도 있군요. 

https://github.com/omerferhatt/torch2tflite

 

dhyu님의 프로필

dhyu

질문자

2021.10.04

ERROR:root:Can not load PyTorch model. Please make surethat model saved like `torch.save(model, PATH)`

깃허브에 나와있는 유틸로 테스트 해보고있는데 계속 이와 같이 오류가 뜹니다.

 

Path to local PyTorch model, please save whole model e.g. torch.save(model, PATH)

깃허브에선 이와같이 설명이 되어있는데 torch.save 형식이 이해가 안갑니다.

파이토치에서 torch.save로 모델을 저장하면 .pt 형식으로 저장되는걸로 아는데 그럼 경로를 .pt로 주면 되는게 아니였나요..? ㅠㅠㅠ 

저도 해당 util을 안써봐서 잘 모르겠습니다만, 

mmdetection 용 pt 인지, yolo용 pt인지는 모르겠지만, 변환하려는 pt 파일이 torch.save(model)로 저장된 형태인지 소스 코드 확인이 필요해 보입니다.  아마 checkpoint로 저장시 torch.save(model) 이 아니라 torch.save(model.state_dict(), Path) 로 모델 구조를 저장하지 않고 weight 만 저장하는 구조로 만들었을 가능성이 높습니다. torch.save(model, path)로 모델 구조와 weight를 함께 저장하는 구조로 되어 있는지 확인이 필요해 보입니다. 

dhyu님의 프로필

dhyu

질문자

2021.10.06

답변해주셔서 감사합니다. 제가 강의의 ultralytics/yolov3 을 이용하여 .pt파일을 만들어서 진행하였습니다. train.py파일을 살펴보니 torch.save(model.state_dict(), Path)  구조가 맞는거 같은데 머가 문제인지 잘 이해가 안되네요 ㅠㅠㅠ 'weight를 함께 저장하는 구조'도 무슨 말인지 이해가 안갑니다 ㅠㅠㅠ

torch.save()가 두가지 형태가 있습니다. 하나는 모델의 전체 구조와 weight 값을 함께 저장하는 형태, 그러니까 모델의 layer가 몇개이고, 어떻게 각각이 연결되어 있고,  activation은 어떻게 되고 등의 모델 전체 구조와 convolution, dense, batch normalization 등의 layer에 weight값을 함께 저장하는 형를 말합니다. 

다른 하나는 모델의 전체 구조를 저장하지 않고 weight값만 저장합니다. 이런 경우는 load()를 호출하는 쪽에서 이미 모델의 전체 구조는 만들어져 있고, 개별 layer에 어떤 weight값을 load만 하는 경우에 사용됩니다. 

torch.save(model.state_dict(), Path) 는 model의 구조가 아닌 weight만 저장하는 방식입니다. 

모델을 함께 저장하려면 torch.save(model, PATH)를 사용합니다. 그런데 이 방식은 python의 pickle을 적용하는데 제대로 잘 안되는 경우가 많습니다. 모델 그 자체를 저장한다기 보다 그냥 pickle로 직렬화를 시키는 것입니다. 이때 모델을 만들때 필요한 클래스들이 제대로 호환되지 않을 수 있습니다. 그래서 보통은 torch.save(model.state_dict(), path)와 같이 weight만 저장하는 경우가 많습니다.

dhyu님의 프로필

dhyu

질문자

2021.10.07

자꾸 되물어서 죄송합니다. ㅠㅠ

선생님이 말씀하신대로 Ultralytics/yolov3의 train.py의 코드를 살펴보니 torch.save(model,path)형식으로 저장되는것 같은데 util이용하면 똑같이 상기 명시된 에러가 뜨네요 ㅠㅠ

포기하는 마음으로 직접 .pt -> .onnx -> .pb -> . tflite로 변환하고자 했는데 pt->onnx는 되었지만 onnx-> pb가 안됩니다 

제가 사용한 방법은 

1) https://github.com/onnx/onnx-tensorflow.git에서 깃클론후 converter.py을 실행했습니다. 버전 다맞추고 에러없이 실행에는 성공했는데 반응이 없으며 (어떠한 log도 나오지 않았습니다. converter.py코드상으로는 convert시작하면 시작한다고 log가 뜨게끔 되어있습니다.) output파일 또한 나오질 않았습니다.

2) onnx-tf install후 아래와 같이 코드실행

import onnx
from onnx_tf.backend import prepare

onnx_model = onnx.load("/content/drive/MyDrive/ultra_workdir/eyeshop/weights/best.onnx") # load onnx model
tf_rep = prepare(onnx_model) # prepare tf representation
tf_rep.export_graph("/content/drive/MyDrive/ultra_workdir/eyeshop/weights/best.pb") # export the model
 
버젼을 맞추고 이와 같이 실행을 하였는데 아래와 같은 오류가 뜹니다.
버젼은 tensorflow == 2.6(1.15로 해보니 더이상 지원하지 않는다며 에러가 발생했었습니다.) , onnx == 1.8, onnx-tf==1.8 , torch == 1.7 입니다.

혹시 tflite로 변경해서 하는 방법 말고 다른방법으로 andorid studio로 연동할수 있는 방법이 있으면 알고 싶습니다.
 

음, tflite로 바꾸시려는 의도는 알겠습니다만, Ultralytics나 mmdetection의 학습된 모델을 tflite로 변경해도 inference 하기가 어려울것 같습니다. ultralytics나 mmdetection 학습 모델을 inference 시에는 ultralytics/mmdetection 내부의 inference api call을 사용합니다. 학습 모델을 tflite로 바꾸어도 inference 호출이 어려울 것 같습니다.

tflite로 바꾸시려면 automl efficientdet은 어떠신지요? 성능도 yolo v5와 비슷하고 tensorflow 2.x로 되어 있는데다가 tflite 변환 후에도 automl efficientdet inference가 잘되는 걸로 알고 있습니다. 

dhyu님의 프로필

dhyu

질문자

2021.10.08

넵 알겠습니다! 답변해주셔서 감사합니다!!

dhyu님의 프로필

dhyu

질문자

2021.10.09

혹시 제가 custom Data셋으로 모델을 train 하고 싶은데 CVAT 툴의 tfrecord로 데이터 셋을 export 한다음 automl efficientdet 으로 train이 가능할까요? 또한 강의에서 tflite으로 변환하는 과정도 다루는 지 알 수 있을까요?

제가 CVAT 툴로 tfrecord를 만들어보지는 않았습니다. 

efficientdet 강의를 들으시면 일반 이미지와 annotation 파일로 어떻게 tfrecord를 만드는지 강의 영상이 있으니 참조하시면 될 것 같습니다.

전체 강의에서 tflite로 변환하는 과정은 없습니다.