参考资料 ultralytics/utils/tal.py
备注
该文件可从https://github.com/ultralytics/ultralytics/blob/main/ ultralytics/utils/tal .py。如果您发现问题,请通过提交 Pull Request🛠️ 帮助修复。谢谢🙏!
ultralytics.utils.tal.TaskAlignedAssigner
垒球 Module
用于物体检测的任务对齐分配器
该类根据任务对齐度量标准将地面实况(gt)对象分配给锚点,该标准结合了分类和定位信息。 分类和定位信息。
属性
名称 | 类型 | 说明 |
---|---|---|
topk |
int
|
需要考虑的最佳人选的数量。 |
num_classes |
int
|
对象类别的数量。 |
alpha |
float
|
任务对齐度量中分类部分的 alpha 参数。 |
beta |
float
|
任务对齐度量中定位部分的贝塔参数。 |
eps |
float
|
一个小值,防止除以零。 |
源代码 ultralytics/utils/tal.py
13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 |
|
__init__(topk=13, num_classes=80, alpha=1.0, beta=6.0, eps=1e-09)
使用可定制的超参数初始化一个 TaskAlignedAssigner 对象。
源代码 ultralytics/utils/tal.py
forward(pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt)
计算任务对齐分配。参考代码见 https://github.com/Nioolek/PPYOLOE_pytorch/blob/master/ppyoloe/assigner/tal_assigner.py。
参数
名称 | 类型 | 说明 | 默认值 |
---|---|---|---|
pd_scores |
Tensor
|
shape(bs, num_total_anchors, num_classes) |
所需 |
pd_bboxes |
Tensor
|
shape(bs, num_total_anchors, 4) |
所需 |
anc_points |
Tensor
|
shape(num_total_anchors, 2) |
所需 |
gt_labels |
Tensor
|
shape(bs, n_max_boxes, 1) |
所需 |
gt_bboxes |
Tensor
|
shape(bs, n_max_boxes, 4) |
所需 |
mask_gt |
Tensor
|
shape(bs, n_max_boxes, 1) |
所需 |
返回:
名称 | 类型 | 说明 |
---|---|---|
target_labels |
Tensor
|
shape(bs, num_total_anchors) |
target_bboxes |
Tensor
|
shape(bs, num_total_anchors, 4) |
target_scores |
Tensor
|
shape(bs, num_total_anchors, num_classes) |
fg_mask |
Tensor
|
shape(bs, num_total_anchors) |
target_gt_idx |
Tensor
|
shape(bs, num_total_anchors) |
源代码 ultralytics/utils/tal.py
get_box_metrics(pd_scores, pd_bboxes, gt_labels, gt_bboxes, mask_gt)
给定预测边界框和地面实况边界框,计算对齐度量。
源代码 ultralytics/utils/tal.py
get_pos_mask(pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points, mask_gt)
获取 in_gts mask、(b、max_num_obj、h*w)。
源代码 ultralytics/utils/tal.py
get_targets(gt_labels, gt_bboxes, target_gt_idx, fg_mask)
计算正锚点的目标标签、目标边界框和目标分数。
参数
名称 | 类型 | 说明 | 默认值 |
---|---|---|---|
gt_labels |
Tensor
|
形状为 (b, max_num_obj, 1) 的地面实况标签,其中 b 是批次大小,max_num_obj 是对象的最大数量。 是对象的最大数量。 |
所需 |
gt_bboxes |
Tensor
|
形状为 (b, max_num_obj, 4) 的地面真实边界框。 |
所需 |
target_gt_idx |
Tensor
|
为正锚点分配的地面实况对象索引 形状为(b, hw),其中 hw是锚点总数。 锚点总数。 |
所需 |
fg_mask |
Tensor
|
布尔值tensor ,形状为 (b,h*w),表示正(前景)锚点。 (前景)锚点。 |
所需 |
返回:
类型 | 说明 |
---|---|
Tuple[Tensor, Tensor, Tensor]
|
包含以下张量的元组: - target_labels (Tensor):形状(b,hw),包含正锚点的目标标签。 正锚点的目标标签。 - target_bboxes (Tensor):形状(b, hw, 4),包含正锚点的目标边界框 正锚点的目标边界框。 - target_scores (Tensor):形状(b, h*w, num_classes),包含正锚点的目标分数。 其中,num_classes 是对象类别的数量。 类的数量。 |
源代码 ultralytics/utils/tal.py
iou_calculation(gt_bboxes, pd_bboxes)
select_candidates_in_gts(xy_centers, gt_bboxes, eps=1e-09)
staticmethod
在 gt 中选择正锚中心。
参数
名称 | 类型 | 说明 | 默认值 |
---|---|---|---|
xy_centers |
Tensor
|
shape(h*w, 2) |
所需 |
gt_bboxes |
Tensor
|
shape(b, n_boxes, 4) |
所需 |
返回:
类型 | 说明 |
---|---|
Tensor
|
shape(b, n_boxes, h*w) |
源代码 ultralytics/utils/tal.py
select_highest_overlaps(mask_pos, overlaps, n_max_boxes)
staticmethod
如果一个锚箱分配给多个 gts,则会选择 IoU 最高的那个。
参数
名称 | 类型 | 说明 | 默认值 |
---|---|---|---|
mask_pos |
Tensor
|
shape(b, n_max_boxes, h*w) |
所需 |
overlaps |
Tensor
|
shape(b, n_max_boxes, h*w) |
所需 |
返回:
名称 | 类型 | 说明 |
---|---|---|
target_gt_idx |
Tensor
|
形状(b,h*w) |
fg_mask |
Tensor
|
形状(b,h*w) |
mask_pos |
Tensor
|
shape(b, n_max_boxes, h*w) |
源代码 ultralytics/utils/tal.py
select_topk_candidates(metrics, largest=True, topk_mask=None)
根据给定的指标选择前 k 个候选者。
参数
名称 | 类型 | 说明 | 默认值 |
---|---|---|---|
metrics |
Tensor
|
tensor 的形状(b, max_num_obj,hw),其中 b 是批量大小、 max_num_obj 是对象的最大数量,hw代表锚点总数。 锚点总数。 |
所需 |
largest |
bool
|
如果为 True,则选择最大值;否则选择最小值。 |
True
|
topk_mask |
Tensor
|
一个可选的布尔型tensor ,形状为(b, max_num_obj, topk),其中 topk 是要考虑的顶级候选对象的数量。如果未提供、 则会根据给定的指标自动计算出 top-k 值。 |
None
|
返回:
类型 | 说明 |
---|---|
Tensor
|
形状为(b、max_num_obj、h*w)的tensor ,其中包含选定的前 k 个候选对象。 |
源代码 ultralytics/utils/tal.py
ultralytics.utils.tal.RotatedTaskAlignedAssigner
源代码 ultralytics/utils/tal.py
iou_calculation(gt_bboxes, pd_bboxes)
select_candidates_in_gts(xy_centers, gt_bboxes)
staticmethod
在 gt 中为旋转边界框选择正锚中心。
参数
名称 | 类型 | 说明 | 默认值 |
---|---|---|---|
xy_centers |
Tensor
|
shape(h*w, 2) |
所需 |
gt_bboxes |
Tensor
|
shape(b, n_boxes, 5) |
所需 |
返回:
类型 | 说明 |
---|---|
Tensor
|
shape(b, n_boxes, h*w) |
源代码 ultralytics/utils/tal.py
ultralytics.utils.tal.make_anchors(feats, strides, grid_cell_offset=0.5)
根据特征生成锚点
源代码 ultralytics/utils/tal.py
ultralytics.utils.tal.dist2bbox(distance, anchor_points, xywh=True, dim=-1)
将距离(ltrb)转换为方框(xywh 或 xyxy)。
源代码 ultralytics/utils/tal.py
ultralytics.utils.tal.bbox2dist(anchor_points, bbox, reg_max)
ultralytics.utils.tal.dist2rbox(pred_dist, pred_angle, anchor_points, dim=-1)
根据锚点和分布解码预测对象边界框坐标。
参数
名称 | 类型 | 说明 | 默认值 |
---|---|---|---|
pred_dist |
Tensor
|
预测的旋转距离,(bs,h*w,4)。 |
所需 |
pred_angle |
Tensor
|
预测角度,(bs,h*w,1)。 |
所需 |
anchor_points |
Tensor
|
锚点,(h*w,2)。 |
所需 |
返回: (torch.Tensor):预测的旋转边界框,(bs, h*w, 4)。