[개정판] 딥러닝 컴퓨터 비전 완벽 가이드

잘 모르겠습니다.

강사님의 ssd_mobile_net 코드를 따라 작성하던 중 아래와 같은 오류가 발생했습니다. 강사님의 코드를 그대로 복사해 붙여넣어도 같은 오류가 나서 어떻게 해야 할지 모르겠습니다.

def get_tensor_detected_image(sess, img_array, use_copied_array, score_threshold=0.3):
    """Run the frozen SSD MobileNet graph on one image and draw its detections.

    Args:
        sess: TensorFlow 1.x session whose graph contains the imported frozen
            inference graph, i.e. the tensors ``image_tensor:0``,
            ``num_detections:0``, ``detection_scores:0``,
            ``detection_boxes:0`` and ``detection_classes:0``.
        img_array: image array indexed as (rows, cols, channels) with 3
            channels; the channels are reversed before inference, so it is
            assumed to be BGR as returned by OpenCV (TODO confirm for other
            callers).
        use_copied_array: if True, draw on a copy and leave ``img_array``
            untouched; if False, draw directly on ``img_array`` (in place).
        score_threshold: only detections with a score strictly greater than
            this value are drawn (default 0.3, the original hard-coded value).

    Returns:
        The image (a copy or ``img_array`` itself, see ``use_copied_array``)
        with a green box and a red ``label:score`` caption per kept detection.

    Note:
        Relies on a module-level ``labels_to_names`` mapping from class id to
        label text, which is defined elsewhere (not shown here).
    """
    rows = img_array.shape[0]
    cols = img_array.shape[1]

    # Bug fix: without the ``else`` the False case left draw_img undefined
    # (NameError) and the True case overwrote the copy with the original.
    if use_copied_array:
        draw_img = img_array.copy()
    else:
        draw_img = img_array

    # Resize to the 300x300 input size and reverse channel order (BGR -> RGB).
    inp = cv2.resize(img_array, (300, 300))
    inp = inp[:, :, [2, 1, 0]]

    start = time.time()

    # The fetch order fixes the indices used below:
    # out[0] = num_detections, out[1] = scores, out[2] = boxes, out[3] = classes.
    out = sess.run([sess.graph.get_tensor_by_name('num_detections:0'),
                    sess.graph.get_tensor_by_name('detection_scores:0'),
                    sess.graph.get_tensor_by_name('detection_boxes:0'),
                    sess.graph.get_tensor_by_name('detection_classes:0')],
                   feed_dict={'image_tensor:0': inp.reshape(1, inp.shape[0], inp.shape[1], 3)})

    green_color = (0, 255, 0)
    red_color = (0, 0, 255)

    num_detections = int(out[0][0])

    for i in range(num_detections):
        classId = int(out[3][0][i])
        score = float(out[1][0][i])
        bbox = [float(v) for v in out[2][0][i]]
        if score > score_threshold:
            # Box values are scaled by the original image size here, with
            # index order [top, left, bottom, right] (y before x).
            left = bbox[1] * cols
            top = bbox[0] * rows
            right = bbox[3] * cols
            bottom = bbox[2] * rows
            cv2.rectangle(draw_img, (int(left), int(top)), (int(right), int(bottom)), green_color, thickness=2)
            caption = "{}:{:.4f}".format(labels_to_names[classId], score)
            cv2.putText(draw_img, caption, (int(left), int(top - 5)), cv2.FONT_HERSHEY_SIMPLEX, 0.4, red_color, 1)

    print('Detection 수행시간:', round(time.time() - start, 3), "초")

    return draw_img

import numpy as np

import tensorflow as tf

import cv2

import time

import matplotlib.pyplot as plt

%matplotlib inline

video_input_path = '../../data/video/Night_Day_Chase.mp4'
video_output_path = '../../data/output/Night_Day_Chase_tensor_ssd_mobile_01.mp4'
# NOTE(review): machine-specific absolute path -- adjust to your environment.
model_path = '/home/bgw2001/DLCV/Detection/ssd/pretrained/ssd_mobilenet_v2_coco_2018_03_29/frozen_inference_graph.pb'

cap = cv2.VideoCapture(video_input_path)
codec = cv2.VideoWriter_fourcc(*'XVID')
vid_size = (round(cap.get(cv2.CAP_PROP_FRAME_WIDTH)), round(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)))
vid_fps = cap.get(cv2.CAP_PROP_FPS)
vid_writer = cv2.VideoWriter(video_output_path, codec, vid_fps, vid_size)

frame_cnt = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
print('총 Frame의 갯수:', frame_cnt, 'FPS:', vid_fps)

# Import the frozen graph into a *fresh* tf.Graph instead of the notebook's
# shared default graph. Re-running this cell without restarting the kernel
# otherwise imports the same graph_def a second time into the same graph,
# which uniquifies node names (e.g. the "_1" suffix in
# "Preprocessor/map/while/Merge_2_1") and can break the while-loop wiring,
# surfacing as "InvalidArgumentError: ... passed int32 ... expected float".
detection_graph = tf.Graph()
with detection_graph.as_default():
    graph_def = tf.GraphDef()
    # tf.gfile.GFile replaces the deprecated FastGFile.
    with tf.gfile.GFile(model_path, 'rb') as f:
        # Bug fix: the serialized graph must actually be parsed; otherwise
        # graph_def stays empty and no detection tensors exist.
        graph_def.ParseFromString(f.read())
    tf.import_graph_def(graph_def, name='')

try:
    with tf.Session(graph=detection_graph) as sess:
        index = 0
        while True:
            hasFrame, img_frame = cap.read()
            if not hasFrame:
                print('더 이상 처리할 frame이 없습니다.')
                # Bug fix: stop here; continuing would pass None to the detector.
                break
            draw_img_frame = get_tensor_detected_image(sess=sess, img_array=img_frame, use_copied_array=False)
            vid_writer.write(draw_img_frame)
finally:
    # Always flush/close the output file and free the capture, even on error.
    vid_writer.release()
    cap.release()



InvalidArgumentError                      Traceback (most recent call last)
~/anaconda3/envs/tf113/lib/python3.6/site-packages/tensorflow/python/client/session.py in _do_call(self, fn, *args)
   1333     try:
-> 1334       return fn(*args)
   1335     except errors.OpError as e:

~/anaconda3/envs/tf113/lib/python3.6/site-packages/tensorflow/python/client/session.py in _run_fn(feed_dict, fetch_list, target_list, options, run_metadata)
   1316       # Ensure any changes to the graph are reflected in the runtime.
-> 1317       self._extend_graph()
   1318       return self._call_tf_sessionrun(

~/anaconda3/envs/tf113/lib/python3.6/site-packages/tensorflow/python/client/session.py in _extend_graph(self)
   1351     with self._graph._session_run_lock():  # pylint: disable=protected-access
-> 1352       tf_session.ExtendSession(self._session)

InvalidArgumentError: Input 1 of node Preprocessor/map/while/Merge_2_1 was passed int32 from Preprocessor/map/while/NextIteration_2:0 incompatible with expected float.

During handling of the above exception, another exception occurred:

InvalidArgumentError                      Traceback (most recent call last)
<ipython-input-48-071e2e179d8f> in <module>
     12             print('더 이상 처리할 frame이 없습니다.')
     13             break
---> 14         draw_img_frame = get_tensor_detected_image(sess=sess, img_array=img_frame, use_copied_array=False)
     15         vid_writer.write(draw_img_frame)
     16 vid_writer.release()

<ipython-input-46-52db48ba4bca> in get_tensor_detected_image(sess, img_array, use_copied_array)
     17                    sess.graph.get_tensor_by_name('detection_boxes:0'),
     18                    sess.graph.get_tensor_by_name('detection_classes:0')],
---> 19                   feed_dict={'image_tensor:0':inp.reshape(1, inp.shape[0], inp.shape[1], 3)})
     20     green_color = (0,255,0)
     21     red_color = (0,0,255)

~/anaconda3/envs/tf113/lib/python3.6/site-packages/tensorflow/python/client/session.py in run(self, fetches, feed_dict, options, run_metadata)
    927     try:
    928       result = self._run(None, fetches, feed_dict, options_ptr,
--> 929                          run_metadata_ptr)
    930       if run_metadata:
    931         proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)

~/anaconda3/envs/tf113/lib/python3.6/site-packages/tensorflow/python/client/session.py in _run(self, handle, fetches, feed_dict, options, run_metadata)
   1150     if final_fetches or final_targets or (handle and feed_dict_tensor):
   1151       results = self._do_run(handle, final_targets, final_fetches,
-> 1152                              feed_dict_tensor, options, run_metadata)
   1153     else:
   1154       results = []

~/anaconda3/envs/tf113/lib/python3.6/site-packages/tensorflow/python/client/session.py in _do_run(self, handle, target_list, fetch_list, feed_dict, options, run_metadata)
   1326     if handle is None:
   1327       return self._do_call(_run_fn, feeds, fetches, targets, options,
-> 1328                            run_metadata)
   1329     else:
   1330       return self._do_call(_prun_fn, handle, feeds, fetches)

~/anaconda3/envs/tf113/lib/python3.6/site-packages/tensorflow/python/client/session.py in _do_call(self, fn, *args)
   1346           pass
   1347       message = error_interpolation.interpolate(message, self._graph)
-> 1348       raise type(e)(node_def, op, message)
   1350   def _extend_graph(self):

InvalidArgumentError: Input 1 of node Preprocessor/map/while/Merge_2_1 was passed int32 from Preprocessor/map/while/NextIteration_2:0 incompatible with expected float.

권 철민


문제가 해결돼서 다행입니다.



재기동하라는 말이 shutdown을 뜻하는 줄 모르고 계속 reset만 해서 에러가 난 것 같습니다~! 죄송합니다~!

