• Category

    Q&A
  • Subfield

    Computer Vision

  • Status

    Resolved

Question about modifying a config file

Posted 23.03.27 22:45, edited 23.03.27 23:21


Hello, instructor.

Thanks to your lecture, I was able to run a Faster R-CNN model on a custom dataset.

I'd now like to apply another model (Swin) to the same custom dataset, using the mask_rcnn_swin-t-p4-w7_fpn_1x_coco.py config from https://github.com/open-mmlab/mmdetection/tree/master/configs/swin. After changing the config file and checkpoint accordingly and starting training, I got a mask-related error, which seems to be caused by the model being Mask R-CNN. Googling suggested commenting out the offending part. Is there a way to comment that part out in Colab? If there is a better approach, I'd be grateful if you could let me know.


2023-03-27 14:19:05,247 - mmdet - INFO - Automatic scaling of learning rate (LR) has been disabled.
<ipython-input-14-f8ce61995cc8>:47: DeprecationWarning: `np.long` is a deprecated alias for `np.compat.long`. To silence this warning, use `np.compat.long` by itself. In the likely event your code does not need to work on Python 2 you can use the builtin `int` for which `np.compat.long` is itself an alias. Doing this will not modify any behaviour and is safe. When replacing `np.long`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  'labels': np.array(gt_labels, dtype=np.long),
<ipython-input-14-f8ce61995cc8>:49: DeprecationWarning: `np.long` is a deprecated alias for `np.compat.long`. To silence this warning, use `np.compat.long` by itself. In the likely event your code does not need to work on Python 2 you can use the builtin `int` for which `np.compat.long` is itself an alias. Doing this will not modify any behaviour and is safe. When replacing `np.long`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  'label_ignore':np.array(gt_labels_ignore, dtype=np.long)
2023-03-27 14:19:08,688 - mmdet - INFO - load checkpoint from local path: checkpoints/mask_rcnn_swin-t-p4-w7_fpn_1x_coco_20210902_120937-9d6b7cfa.pth
2023-03-27 14:19:08,849 - mmdet - WARNING - The model and loaded state dict do not match exactly

size mismatch for roi_head.bbox_head.fc_cls.weight: copying a param with shape torch.Size([81, 1024]) from checkpoint, the shape in current model is torch.Size([16, 1024]).
size mismatch for roi_head.bbox_head.fc_cls.bias: copying a param with shape torch.Size([81]) from checkpoint, the shape in current model is torch.Size([16]).
size mismatch for roi_head.bbox_head.fc_reg.weight: copying a param with shape torch.Size([320, 1024]) from checkpoint, the shape in current model is torch.Size([60, 1024]).
size mismatch for roi_head.bbox_head.fc_reg.bias: copying a param with shape torch.Size([320]) from checkpoint, the shape in current model is torch.Size([60]).
size mismatch for roi_head.mask_head.conv_logits.weight: copying a param with shape torch.Size([80, 256, 1, 1]) from checkpoint, the shape in current model is torch.Size([15, 256, 1, 1]).
size mismatch for roi_head.mask_head.conv_logits.bias: copying a param with shape torch.Size([80]) from checkpoint, the shape in current model is torch.Size([15]).
2023-03-27 14:19:08,856 - mmdet - INFO - Start running, host: root@06d3ab7dae34, work_dir: /content/gdrive/MyDrive/htp_dir_swin
2023-03-27 14:19:08,858 - mmdet - INFO - Hooks will be executed in the following order:
before_run:
(VERY_HIGH   ) StepLrUpdaterHook                  
(NORMAL      ) CheckpointHook                     
(LOW         ) EvalHook                           
(VERY_LOW    ) TextLoggerHook                     
 -------------------- 
before_train_epoch:
(VERY_HIGH   ) StepLrUpdaterHook                  
(NORMAL      ) NumClassCheckHook                  
(LOW         ) IterTimerHook                      
(LOW         ) EvalHook                           
(VERY_LOW    ) TextLoggerHook                     
 -------------------- 
before_train_iter:
(VERY_HIGH   ) StepLrUpdaterHook                  
(LOW         ) IterTimerHook                      
(LOW         ) EvalHook                           
 -------------------- 
after_train_iter:
(ABOVE_NORMAL) OptimizerHook                      
(NORMAL      ) CheckpointHook                     
(LOW         ) IterTimerHook                      
(LOW         ) EvalHook                           
(VERY_LOW    ) TextLoggerHook                     
 -------------------- 
after_train_epoch:
(NORMAL      ) CheckpointHook                     
(LOW         ) EvalHook                           
(VERY_LOW    ) TextLoggerHook                     
 -------------------- 
before_val_epoch:
(NORMAL      ) NumClassCheckHook                  
(LOW         ) IterTimerHook                      
(VERY_LOW    ) TextLoggerHook                     
 -------------------- 
before_val_iter:
(LOW         ) IterTimerHook                      
 -------------------- 
after_val_iter:
(LOW         ) IterTimerHook                      
 -------------------- 
after_val_epoch:
(VERY_LOW    ) TextLoggerHook                     
 -------------------- 
after_run:
(VERY_LOW    ) TextLoggerHook                     
 -------------------- 
2023-03-27 14:19:08,859 - mmdet - INFO - workflow: [('train', 1)], max: 5 epochs
2023-03-27 14:19:08,860 - mmdet - INFO - Checkpoints will be saved to /content/gdrive/MyDrive/htp_dir_swin by HardDiskBackend.
---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
<ipython-input-35-c8cc0d536607> in <module>
      4 mmcv.mkdir_or_exist(os.path.abspath(cfg.work_dir))
      5 # the number of epochs is set by the runner parameter in the config (default: 12)
----> 6 train_detector(model, datasets, cfg, distributed=False, validate=True)


6 frames


/usr/local/lib/python3.9/dist-packages/mmdet-2.28.2-py3.9.egg/mmdet/apis/train.py in train_detector(model, dataset, cfg, distributed, validate, timestamp, meta)
    244     elif cfg.load_from:
    245         runner.load_checkpoint(cfg.load_from)
--> 246     runner.run(data_loaders, cfg.workflow)

/usr/local/lib/python3.9/dist-packages/mmcv/runner/epoch_based_runner.py in run(self, data_loaders, workflow, max_epochs, **kwargs)
    134                     if mode == 'train' and self.epoch >= self._max_epochs:
    135                         break
--> 136                     epoch_runner(data_loaders[i], **kwargs)
    137 
    138         time.sleep(1)  # wait for some hooks like loggers to finish

/usr/local/lib/python3.9/dist-packages/mmcv/runner/epoch_based_runner.py in train(self, data_loader, **kwargs)
     47         self.call_hook('before_train_epoch')
     48         time.sleep(2)  # Prevent possible deadlock during epoch transition
---> 49         for i, data_batch in enumerate(self.data_loader):
     50             self.data_batch = data_batch
     51             self._inner_iter = i

/usr/local/lib/python3.9/dist-packages/torch/utils/data/dataloader.py in __next__(self)
    626                 # TODO(https://github.com/pytorch/pytorch/issues/76750)
    627                 self._reset()  # type: ignore[call-arg]
--> 628             data = self._next_data()
    629             self._num_yielded += 1
    630             if self._dataset_kind == _DatasetKind.Iterable and \

/usr/local/lib/python3.9/dist-packages/torch/utils/data/dataloader.py in _next_data(self)
   1331             else:
   1332                 del self._task_info[idx]
-> 1333                 return self._process_data(data)
   1334 
   1335     def _try_put_index(self):

/usr/local/lib/python3.9/dist-packages/torch/utils/data/dataloader.py in _process_data(self, data)
   1357         self._try_put_index()
   1358         if isinstance(data, ExceptionWrapper):
-> 1359             data.reraise()
   1360         return data
   1361 

/usr/local/lib/python3.9/dist-packages/torch/_utils.py in reraise(self)
    541             # instantiate since we don't know how to
    542             raise RuntimeError(msg) from None
--> 543         raise exception
    544 
    545 

KeyError: Caught KeyError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/usr/local/lib/python3.9/dist-packages/torch/utils/data/_utils/worker.py", line 302, in _worker_loop
    data = fetcher.fetch(index)
  File "/usr/local/lib/python3.9/dist-packages/torch/utils/data/_utils/fetch.py", line 58, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/usr/local/lib/python3.9/dist-packages/torch/utils/data/_utils/fetch.py", line 58, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/usr/local/lib/python3.9/dist-packages/mmdet-2.28.2-py3.9.egg/mmdet/datasets/custom.py", line 220, in __getitem__
    data = self.prepare_train_img(idx)
  File "/usr/local/lib/python3.9/dist-packages/mmdet-2.28.2-py3.9.egg/mmdet/datasets/custom.py", line 243, in prepare_train_img
    return self.pipeline(results)
  File "/usr/local/lib/python3.9/dist-packages/mmdet-2.28.2-py3.9.egg/mmdet/datasets/pipelines/compose.py", line 41, in __call__
    data = t(data)
  File "/usr/local/lib/python3.9/dist-packages/mmdet-2.28.2-py3.9.egg/mmdet/datasets/pipelines/loading.py", line 398, in __call__
    results = self._load_masks(results)
  File "/usr/local/lib/python3.9/dist-packages/mmdet-2.28.2-py3.9.egg/mmdet/datasets/pipelines/loading.py", line 347, in _load_masks
    gt_masks = results['ann_info']['masks']
KeyError: 'masks'
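Two things in this log are worth separating. The `size mismatch` lines are warnings, not the error: the COCO checkpoint's heads are sized for 80 classes while this model was configured for 15, and every shape follows from the class count (the classifier gets one extra background logit, box regression gets 4 outputs per class, mask logits get one channel per class). The actual failure is `KeyError: 'masks'`: the Mask R-CNN pipeline's `LoadAnnotations` step reads `results['ann_info']['masks']`, and the custom dataset's annotation dict only holds boxes and labels. The sketch below checks that shape arithmetic against the log and shows a box-only annotation dict built with `np.int64` instead of `np.long` (deprecated since NumPy 1.20, as the warnings say). It assumes the mmdet 2.x custom-dataset keys `bboxes`, `labels`, `bboxes_ignore`, `labels_ignore`; `head_shapes` and `build_ann` are hypothetical names. Note that the log's code uses the key `'label_ignore'`, whereas mmdet 2.x documents the optional key as `'labels_ignore'`.

```python
import numpy as np


def head_shapes(num_classes, in_features=1024, mask_channels=256):
    """Weight shapes of the Mask R-CNN RoI heads for a given class count.

    Assumes class-specific box regression and mask prediction
    (not class-agnostic), which is what the shapes in the log imply.
    """
    return {
        'fc_cls.weight': (num_classes + 1, in_features),           # +1 background logit
        'fc_reg.weight': (4 * num_classes, in_features),           # 4 box deltas per class
        'conv_logits.weight': (num_classes, mask_channels, 1, 1),  # one mask per class
    }


def build_ann(gt_bboxes, gt_labels, gt_bboxes_ignore, gt_labels_ignore):
    """Box-only annotation dict in the (assumed) mmdet 2.x custom-dataset layout.

    Uses np.int64 instead of the deprecated np.long. There is no 'masks'
    key, which is exactly what LoadAnnotations(with_mask=True) looks up.
    """
    return {
        'bboxes': np.array(gt_bboxes, dtype=np.float32).reshape(-1, 4),
        'labels': np.array(gt_labels, dtype=np.int64),
        'bboxes_ignore': np.array(gt_bboxes_ignore, dtype=np.float32).reshape(-1, 4),
        'labels_ignore': np.array(gt_labels_ignore, dtype=np.int64),
    }


coco = head_shapes(80)    # checkpoint side of the warnings
custom = head_shapes(15)  # current-model side of the warnings
print(coco['fc_cls.weight'], custom['fc_cls.weight'])  # (81, 1024) (16, 1024)
print(coco['fc_reg.weight'], custom['fc_reg.weight'])  # (320, 1024) (60, 1024)

ann = build_ann([[10, 20, 50, 80]], [3], [], [])
print(ann['labels'].dtype, ann['bboxes_ignore'].shape)  # int64 (0, 4)
print('masks' in ann)  # False, hence KeyError: 'masks' in a mask pipeline
```

The mismatched head weights are simply not loaded and get trained from scratch, which is the usual situation when fine-tuning a COCO checkpoint on a dataset with a different number of classes.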

1 answer



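Hello,

Hmm, Faster R-CNN is an object detection model, and Mask R-CNN additionally needs mask annotations, so I don't think Mask R-CNN can easily be applied to a dataset built for object detection.

Which dataset are you currently using? Is it the dataset from the lecture, or a different one?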
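The "comment it out" fix mentioned in the question amounts to turning off mask loading in the training pipeline. In Colab there is no need to edit the installed mmdet files: the config is an ordinary Python object, so the same change can be made in a notebook cell before the dataset is built. A minimal sketch, assuming the mmdet 2.x pipeline layout used by the Mask R-CNN configs (`LoadAnnotations` with `with_mask=True`, `Collect` with `gt_masks` in `keys`); `drop_mask_steps` is a hypothetical helper, and plain dicts stand in for `cfg.data.train.pipeline` so the sketch runs without mmdet. As the answer above suggests, this only stops the loader from asking for `masks`: the Mask R-CNN mask head still expects `gt_masks` during training, so on its own this would most likely just move the failure further along.

```python
import copy


def drop_mask_steps(pipeline):
    """Return a copy of an mmdet-2.x-style pipeline with mask loading disabled.

    Hypothetical helper: `pipeline` is a list of dicts such as
    cfg.data.train.pipeline. Only LoadAnnotations and Collect are changed.
    """
    new_pipeline = copy.deepcopy(pipeline)
    for step in new_pipeline:
        if step.get('type') == 'LoadAnnotations':
            # Stop the loader from reading results['ann_info']['masks'].
            step['with_mask'] = False
        elif step.get('type') == 'Collect':
            # gt_masks will no longer exist, so stop collecting it.
            step['keys'] = [k for k in step['keys'] if k != 'gt_masks']
    return new_pipeline


# Plain-dict stand-in for part of a Mask R-CNN training pipeline.
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
    dict(type='RandomFlip', flip_ratio=0.5),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']),
]

patched = drop_mask_steps(train_pipeline)
print(patched[1])
print(patched[-1]['keys'])  # ['img', 'gt_bboxes', 'gt_labels']
```

In the notebook this would be applied as `cfg.data.train.pipeline = drop_mask_steps(cfg.data.train.pipeline)`; editing `cfg.train_pipeline` after the config is loaded would not propagate to `cfg.data.train.pipeline`.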

 

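I'm currently using a different dataset. I ran object detection with a Faster R-CNN model and would like to compare its performance with other models using the mmdetection package. Is Mask R-CNN really that hard to apply? I'm also trying SSD, but it keeps throwing errors :( Since the mmdetection configs have a folder named swin, I assumed it was a Swin Transformer detector, but it seems to be a Mask R-CNN model that only uses Swin Transformer as its backbone.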

To use Mask R-CNN, you need a segmentation dataset, i.e. one that also has mask annotations. A later lecture in the segmentation section uses mmdetection's Mask R-CNN, so I'd suggest taking that lecture and then applying it to your data.
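Concretely, a "segmentation dataset" here means each image's annotation also carries one mask per box. A minimal sketch, assuming the mmdet 2.x custom-dataset layout in which `LoadAnnotations(with_mask=True)` reads `ann['masks']` as one entry per instance, each a list of flat `[x1, y1, x2, y2, ...]` polygons (RLE dicts are the other accepted form and are not covered here). `check_mask_ann` is a hypothetical helper for sanity-checking a dataset before training, not an mmdet API.

```python
import numpy as np


def check_mask_ann(ann):
    """Hypothetical sanity check for a segmentation-ready annotation dict.

    Assumes the polygon form: ann['masks'] holds one entry per box, and
    each entry is a list of flat [x1, y1, x2, y2, ...] polygons.
    Returns a list of problems (empty if none were found).
    """
    if 'masks' not in ann:
        return ["no 'masks' key: a box-only dataset cannot feed a mask head"]
    problems = []
    if len(ann['masks']) != len(ann['bboxes']):
        problems.append('number of masks does not match number of boxes')
    for i, polygons in enumerate(ann['masks']):
        for poly in polygons:
            # A usable polygon has x/y pairs and at least 3 points.
            if len(poly) % 2 != 0 or len(poly) < 6:
                problems.append(f'instance {i}: invalid polygon of length {len(poly)}')
    return problems


box_only = {'bboxes': np.zeros((1, 4), dtype=np.float32),
            'labels': np.array([0], dtype=np.int64)}
# Same annotation plus one square polygon for the single instance.
with_masks = dict(box_only, masks=[[[10, 10, 40, 10, 40, 40, 10, 40]]])

print(check_mask_ann(box_only))
print(check_mask_ann(with_masks))  # []
```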

mmdetection's SSD has many issues and does not perform well.