작성자 없음
작성자 정보가 삭제된 글입니다.
작성
·
261
0
안녕하세요.
배운 내용을 기반으로 yolonas를 학습해보는 과정에서 질문이 있습니다.
nas에서 기본적으로 사용하고 있는 transforms 대신에 albumentations 라이브러리를 사용하고 싶은데 계속해서 image 가 없다는 에러가 뜹니다.
코드를 어떻게 수정해야하는지 궁금합니다.
############## 기존 학습 코드
from super_gradients.training import Trainer
from super_gradients.training import dataloaders
from super_gradients.training.dataloaders.dataloaders import (
coco_detection_yolo_format_train,
coco_detection_yolo_format_val
)
from super_gradients.training import models
from super_gradients.training.losses import PPYoloELoss
from super_gradients.training.metrics import (
DetectionMetrics_050,
DetectionMetrics_050_095
)
from super_gradients.training.models.detection_models.pp_yolo_e import PPYoloEPostPredictionCallback
from tqdm.auto import tqdm
import os
import requests
import zipfile
import cv2
import matplotlib.pyplot as plt
import glob
import numpy as np
import random
ROOT_DIR = '/home/바탕화면/test_data'
train_imgs_dir = 'train/images'
train_labels_dir = 'train/labels'
val_imgs_dir = 'val/images'
val_labels_dir = 'val/labels'
classes = ['fallen', 'normal']
dataset_params = {
'data_dir':ROOT_DIR,
'train_images_dir':train_imgs_dir,
'train_labels_dir':train_labels_dir,
'val_images_dir':val_imgs_dir,
'val_labels_dir':val_labels_dir,
'classes':classes
}
EPOCHS = 50
BATCH_SIZE = 16
WORKERS = 8
train_data = coco_detection_yolo_format_train(
dataset_params={
'data_dir': dataset_params['data_dir'],
'images_dir': dataset_params['train_images_dir'],
'labels_dir': dataset_params['train_labels_dir'],
'classes': dataset_params['classes']
},
dataloader_params={
'batch_size':BATCH_SIZE,
'num_workers':WORKERS
}
)
val_data = coco_detection_yolo_format_val(
dataset_params={
'data_dir': dataset_params['data_dir'],
'images_dir': dataset_params['val_images_dir'],
'labels_dir': dataset_params['val_labels_dir'],
'classes': dataset_params['classes']
},
dataloader_params={
'batch_size':BATCH_SIZE,
'num_workers':WORKERS
}
)
train_params = {
'silent_mode': False,
"average_best_models":True,
"warmup_mode": "linear_epoch_step",
"warmup_initial_lr": 1e-6,
"lr_warmup_epochs": 3,
"initial_lr": 5e-4,
"lr_mode": "cosine",
"cosine_final_lr_ratio": 0.1,
"optimizer": "Adam",
"optimizer_params": {"weight_decay": 0.0001},
"zero_weight_decay_on_bias_and_bn": True,
"ema": True,
"ema_params": {"decay": 0.9, "decay_type": "threshold"},
"max_epochs": EPOCHS,
"mixed_precision": True,
"loss": PPYoloELoss(
use_static_assigner=False,
num_classes=len(dataset_params['classes']),
reg_max=16
),
"valid_metrics_list": [
DetectionMetrics_050(
score_thres=0.1,
top_k_predictions=300,
num_cls=len(dataset_params['classes']),
normalize_targets=True,
post_prediction_callback=PPYoloEPostPredictionCallback(
score_threshold=0.01,
nms_top_k=1000,
max_predictions=300,
nms_threshold=0.7
)
),
DetectionMetrics_050_095(
score_thres=0.1,
top_k_predictions=300,
num_cls=len(dataset_params['classes']),
normalize_targets=True,
post_prediction_callback=PPYoloEPostPredictionCallback(
score_threshold=0.01,
nms_top_k=1000,
max_predictions=300,
nms_threshold=0.7
)
)
],
"metric_to_watch": 'mAP@0.50:0.95'
}
trainer = Trainer(
experiment_name='yolo_nas_m',
ckpt_root_dir='checkpoints'
)
model = models.get(
'yolo_nas_m',
num_classes=len(dataset_params['classes']),
pretrained_weights="coco"
)
trainer.train(
model=model,
training_params=train_params,
train_loader=train_data,
valid_loader=val_data
)
############## 기존 학습 코드에서 변경 시킨 부분
train_data = coco_detection_yolo_format_train(
dataset_params={
'data_dir': dataset_params['data_dir'],
'images_dir': dataset_params['train_images_dir'],
'labels_dir': dataset_params['train_labels_dir'],
'classes': dataset_params['classes'],
'transforms' : [ A.CLAHE(p=1.0),
A.RandomBrightnessContrast(p=1.0),
A.RandomGamma(p=1.0),
]
},
dataloader_params={
'batch_size':BATCH_SIZE,
'num_workers':WORKERS
}
)
val_data = coco_detection_yolo_format_val(
dataset_params={
'data_dir': dataset_params['data_dir'],
'images_dir': dataset_params['val_images_dir'],
'labels_dir': dataset_params['val_labels_dir'],
'classes': dataset_params['classes'],
'transforms' : [ A.CLAHE(p=1.0),
A.RandomBrightnessContrast(p=1.0),
A.RandomGamma(p=1.0),
]
},
dataloader_params={
'batch_size':BATCH_SIZE,
'num_workers':WORKERS
}
)
답변 1
0
안녕하십니까,
제 강의에 나오는 패키지가 아닌것 같습니다만,
해당 패키지는 제가 사용해 보지 않아서 제대로 답해 드리기가 어려울 것 같습니다. 다만 오류 메시지상 apply_transform() 메소드 호출 될때 'image' key가 없다는 오류로 봐서는 albumentation 적용이 안되거나 뭔가 다른 방식으로 적용하셔야 할 것 같습니다.
감사합니다.