【Python】PyTorchとYOLOv5で物体の種類、座標、幅、高さを検出する

概要

PyTorchとYOLOv5を使用して、画像の物体検出を行い物体の種類・左上のxy座標・幅・高さを求めてみます。

YOLOv5はCOCO datasetを利用しているので、全部で80種類の物体を検出できます。
https://cocodataset.org/

実行環境

  • Windows10
  • Python 3.9.5
  • YOLOv5 version5.0

動作環境準備

動作環境を設定するためpipコマンドを使用します。

1pip install -qr https://raw.githubusercontent.com/ultralytics/yolov5/master/requirements.txt

pipenvを利用する場合は下記のrequirements.txtを取得後、pipenv installを使用します。
https://raw.githubusercontent.com/ultralytics/yolov5/master/requirements.txt

1pipenv install -r requirements.txt

PyTorchとYOLOv5で物体検出

事前準備が完了したら物体検出を行います。
今回は下記の画像から人が検出できるか試してみます。

Photo by rawpixel.com from Pexels

物体検出コード

下記のコードで物体検出を行います。

 1import torch
 2
 3model = torch.hub.load("ultralytics/yolov5", "yolov5s", pretrained=True)
 4print(model.names)  # 検出できる物体の種類
 5
 6results = model("sample_picture.jpg")  # 画像パスを設定し、物体検出を行う
 7objects = results.pandas().xyxy[0]  # 検出結果を取得
 8
 9for i in range(len(objects)):
10    name = objects.name[i]
11    xmin = objects.xmin[i]
12    ymin = objects.ymin[i]
13    width = objects.xmax[i] - objects.xmin[i]
14    height = objects.ymax[i] - objects.ymin[i]
15    print(f"{i}, 種類:{name}, 座標x:{xmin}, 座標y:{ymin}, 幅:{width}, 高さ:{height}")
16
17# 0, 種類:person, 座標x:106.36381530761719, 座標y:392.3997802734375, 幅:509.06483459472656, 高さ:513.8624267578125
18# 1, 種類:person, 座標x:694.4984130859375, 座標y:410.292236328125, 幅:439.1519775390625, 高さ:497.43505859375
19
20results.show()  # 検出した物体の表示
21results.crop()  # 検出した物体の切り取り

コード解説

1model = torch.hub.load("ultralytics/yolov5", "yolov5s", pretrained=True)

PyTorch Hubから学習済みのモデルをダウンロードします。
既にモデルをダウンロード済みの場合は、自動的にダウンロード済みのモデルを利用します。

 1print(model.names)  # 検出できる物体の種類
 2
 3# ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train',
 4# 'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign', 'parking meter',
 5# 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant',
 6# 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie',
 7# 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat',
 8# 'baseball glove', 'skateboard', 'surfboard', 'tennis racket', 'bottle',
 9# 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
10# 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut',
11# 'cake', 'chair', 'couch', 'potted plant', 'bed', 'dining table', 'toilet',
12#  'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave',
13# 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors',
14# 'teddy bear', 'hair drier', 'toothbrush']

model.namesをprintすると検出できる物体の種類が出力できます。

1results = model("sample_picture.jpg")  # 画像パスを設定し、物体検出を行う
2objects = results.pandas().xyxy[0]  # 検出結果を取得

画像パスを指定しYOLOv5で物体検出を行います。
検出結果をobjectsに入れます。

 1for i in range(len(objects)):
 2    name = objects.name[i]
 3    xmin = objects.xmin[i]
 4    ymin = objects.ymin[i]
 5    width = objects.xmax[i] - objects.xmin[i]
 6    height = objects.ymax[i] - objects.ymin[i]
 7    print(f"{i}, 種類:{name}, 座標x:{xmin}, 座標y:{ymin}, 幅:{width}, 高さ:{height}")
 8
 9# 0, 種類:person, 座標x:106.36381530761719, 座標y:392.3997802734375, 幅:509.06483459472656, 高さ:513.8624267578125
10# 1, 種類:person, 座標x:694.4984130859375, 座標y:410.292236328125, 幅:439.1519775390625, 高さ:497.43505859375

objectsには検出した物体の左上のxy座標・右下のxy座標・信頼度・クラスラベル・物体名が入っています。
検出した右下のxy座標から左上のxy座標を引いて、物体の幅と高さを計算しています。

1results.show()  # 検出した物体の表示

検出結果のshowメソッドを使うと検出結果を表示します。

1results.crop()  # 検出した物体の切り取り

cropメソッドを使うと、検出した結果を切り出して保存してくれます。
検出結果はrunsフォルダのexpフォルダ内に保存されます。

参考サイト

https://github.com/pytorch/pytorch

https://pytorch.org/hub/ultralytics_yolov5/

https://kikaben.com/yolov5-starter/

関連ページ