【Python】PyTorchとYOLOv5で物体の種類、座標、幅、高さを検出する
概要
PyTorchとYOLOv5を使用して、画像の物体検出を行い物体の種類・左上のxy座標・幅・高さを求めてみます。
YOLOv5はCOCO datasetを利用しているので、全部で80種類の物体を検出できます。
https://cocodataset.org/
実行環境
- Windows10
- Python 3.9.5
- YOLOv5 version5.0
動作環境準備
動作環境を設定するためpipコマンドを使用します。
1pip install -qr https://raw.githubusercontent.com/ultralytics/yolov5/master/requirements.txt
pipenvを利用する場合は下記のrequirements.txtを取得後、pipenv installを使用します。
https://raw.githubusercontent.com/ultralytics/yolov5/master/requirements.txt
1pipenv install -r requirements.txt
PyTorchとYOLOv5で物体検出
事前準備が完了したら物体検出を行います。
今回は下記の画像から人が検出できるか試してみます。

Photo by rawpixel.com from Pexels
物体検出コード
下記のコードで物体検出を行います。
1import torch
2
3model = torch.hub.load("ultralytics/yolov5", "yolov5s", pretrained=True)
4print(model.names) # 検出できる物体の種類
5
6results = model("sample_picture.jpg") # 画像パスを設定し、物体検出を行う
7objects = results.pandas().xyxy[0] # 検出結果を取得
8
9for i in range(len(objects)):
10 name = objects.name[i]
11 xmin = objects.xmin[i]
12 ymin = objects.ymin[i]
13 width = objects.xmax[i] - objects.xmin[i]
14 height = objects.ymax[i] - objects.ymin[i]
15 print(f"{i}, 種類:{name}, 座標x:{xmin}, 座標y:{ymin}, 幅:{width}, 高さ:{height}")
16
17# 0, 種類:person, 座標x:106.36381530761719, 座標y:392.3997802734375, 幅:509.06483459472656, 高さ:513.8624267578125
18# 1, 種類:person, 座標x:694.4984130859375, 座標y:410.292236328125, 幅:439.1519775390625, 高さ:497.43505859375
19
20results.show() # 検出した物体の表示
21results.crop() # 検出した物体の切り取り
コード解説
1model = torch.hub.load("ultralytics/yolov5", "yolov5s", pretrained=True)
PyTorch Hubから学習済みのモデルをダウンロードします。
既にモデルをダウンロード済みの場合は、自動的にダウンロード済みのモデルを利用します。
1print(model.names) # 検出できる物体の種類
2
3# ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train',
4# 'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign', 'parking meter',
5# 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant',
6# 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie',
7# 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat',
8# 'baseball glove', 'skateboard', 'surfboard', 'tennis racket', 'bottle',
9# 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
10# 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut',
11# 'cake', 'chair', 'couch', 'potted plant', 'bed', 'dining table', 'toilet',
12# 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave',
13# 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors',
14# 'teddy bear', 'hair drier', 'toothbrush']
model.namesをprintすると検出できる物体の種類が出力できます。
1results = model("sample_picture.jpg") # 画像パスを設定し、物体検出を行う
2objects = results.pandas().xyxy[0] # 検出結果を取得
画像パスを指定しYOLOv5で物体検出を行います。
検出結果をobjectsに入れます。
1for i in range(len(objects)):
2 name = objects.name[i]
3 xmin = objects.xmin[i]
4 ymin = objects.ymin[i]
5 width = objects.xmax[i] - objects.xmin[i]
6 height = objects.ymax[i] - objects.ymin[i]
7 print(f"{i}, 種類:{name}, 座標x:{xmin}, 座標y:{ymin}, 幅:{width}, 高さ:{height}")
8
9# 0, 種類:person, 座標x:106.36381530761719, 座標y:392.3997802734375, 幅:509.06483459472656, 高さ:513.8624267578125
10# 1, 種類:person, 座標x:694.4984130859375, 座標y:410.292236328125, 幅:439.1519775390625, 高さ:497.43505859375
objectsには検出した物体の左上のxy座標・右下のxy座標・信頼度・クラスラベル・物体名が入っています。
検出した右下のxy座標から左上のxy座標を引いて、物体の幅と高さを計算しています。
1results.show() # 検出した物体の表示
検出結果のshowメソッドを使うと検出結果を表示します。

1results.crop() # 検出した物体の切り取り
cropメソッドを使うと、検出した結果を切り出して保存してくれます。
検出結果はrunsフォルダのexpフォルダ内に保存されます。


参考サイト
https://github.com/pytorch/pytorch