@@ -208,6 +208,56 @@ def detect_image(self, image, crop = False, count = False):
208208 del draw
209209
210210 return image
211+
212+ def get_FPS (self , image , test_interval ):
213+ image_shape = np .array (np .shape (image )[0 :2 ])
214+ #---------------------------------------------------------#
215+ # 在这里将图像转换成RGB图像,防止灰度图在预测时报错。
216+ # 代码仅仅支持RGB图像的预测,所有其它类型的图像都会转化成RGB
217+ #---------------------------------------------------------#
218+ image = cvtColor (image )
219+ #---------------------------------------------------------#
220+ # 给图像增加灰条,实现不失真的resize
221+ # 也可以直接resize进行识别
222+ #---------------------------------------------------------#
223+ image_data = resize_image (image , (self .input_shape [1 ],self .input_shape [0 ]), self .letterbox_image )
224+ #---------------------------------------------------------#
225+ # 添加上batch_size维度
226+ #---------------------------------------------------------#
227+ image_data = np .expand_dims (np .transpose (preprocess_input (np .array (image_data , dtype = 'float32' )), (2 , 0 , 1 )), 0 )
228+
229+ with torch .no_grad ():
230+ images = torch .from_numpy (image_data )
231+ if self .cuda :
232+ images = images .cuda ()
233+ #---------------------------------------------------------#
234+ # 将图像输入网络当中进行预测!
235+ #---------------------------------------------------------#
236+ outputs = self .net (images )
237+ outputs = decode_outputs (outputs , self .input_shape )
238+ #---------------------------------------------------------#
239+ # 将预测框进行堆叠,然后进行非极大抑制
240+ #---------------------------------------------------------#
241+ results = non_max_suppression (outputs , self .num_classes , self .input_shape ,
242+ image_shape , self .letterbox_image , conf_thres = self .confidence , nms_thres = self .nms_iou )
243+
244+ t1 = time .time ()
245+ for _ in range (test_interval ):
246+ with torch .no_grad ():
247+ #---------------------------------------------------------#
248+ # 将图像输入网络当中进行预测!
249+ #---------------------------------------------------------#
250+ outputs = self .net (images )
251+ outputs = decode_outputs (outputs , self .input_shape )
252+ #---------------------------------------------------------#
253+ # 将预测框进行堆叠,然后进行非极大抑制
254+ #---------------------------------------------------------#
255+ results = non_max_suppression (outputs , self .num_classes , self .input_shape ,
256+ image_shape , self .letterbox_image , conf_thres = self .confidence , nms_thres = self .nms_iou )
257+
258+ t2 = time .time ()
259+ tact_time = (t2 - t1 ) / test_interval
260+ return tact_time
211261
212262 def detect_heatmap (self , image , heatmap_save_path ):
213263 import cv2
@@ -265,56 +315,6 @@ def sigmoid(x):
265315 plt .savefig (heatmap_save_path , dpi = 200 )
266316 print ("Save to the " + heatmap_save_path )
267317 plt .cla ()
268-
269- def get_FPS (self , image , test_interval ):
270- image_shape = np .array (np .shape (image )[0 :2 ])
271- #---------------------------------------------------------#
272- # 在这里将图像转换成RGB图像,防止灰度图在预测时报错。
273- # 代码仅仅支持RGB图像的预测,所有其它类型的图像都会转化成RGB
274- #---------------------------------------------------------#
275- image = cvtColor (image )
276- #---------------------------------------------------------#
277- # 给图像增加灰条,实现不失真的resize
278- # 也可以直接resize进行识别
279- #---------------------------------------------------------#
280- image_data = resize_image (image , (self .input_shape [1 ],self .input_shape [0 ]), self .letterbox_image )
281- #---------------------------------------------------------#
282- # 添加上batch_size维度
283- #---------------------------------------------------------#
284- image_data = np .expand_dims (np .transpose (preprocess_input (np .array (image_data , dtype = 'float32' )), (2 , 0 , 1 )), 0 )
285-
286- with torch .no_grad ():
287- images = torch .from_numpy (image_data )
288- if self .cuda :
289- images = images .cuda ()
290- #---------------------------------------------------------#
291- # 将图像输入网络当中进行预测!
292- #---------------------------------------------------------#
293- outputs = self .net (images )
294- outputs = decode_outputs (outputs , self .input_shape )
295- #---------------------------------------------------------#
296- # 将预测框进行堆叠,然后进行非极大抑制
297- #---------------------------------------------------------#
298- results = non_max_suppression (outputs , self .num_classes , self .input_shape ,
299- image_shape , self .letterbox_image , conf_thres = self .confidence , nms_thres = self .nms_iou )
300-
301- t1 = time .time ()
302- for _ in range (test_interval ):
303- with torch .no_grad ():
304- #---------------------------------------------------------#
305- # 将图像输入网络当中进行预测!
306- #---------------------------------------------------------#
307- outputs = self .net (images )
308- outputs = decode_outputs (outputs , self .input_shape )
309- #---------------------------------------------------------#
310- # 将预测框进行堆叠,然后进行非极大抑制
311- #---------------------------------------------------------#
312- results = non_max_suppression (outputs , self .num_classes , self .input_shape ,
313- image_shape , self .letterbox_image , conf_thres = self .confidence , nms_thres = self .nms_iou )
314-
315- t2 = time .time ()
316- tact_time = (t2 - t1 ) / test_interval
317- return tact_time
318318
319319 def convert_to_onnx (self , simplify , model_path ):
320320 import onnx
0 commit comments