CenterNet:把目标检测视为中心点检测.
- paper:Objects as Points
CenterNet是一种anchor-free的目标检测模型,该模型直接检测目标的中心点和大小。
网络整体采用关键点检测结构。假设输入图像的尺寸为$(H,W,3)$,在预测的时候,产生关键点的热点图尺寸为$(H/R,W/R,C+4)$,其中$R$为输出对应原图的步长。默认是$C=80$(COCO格式)个类别、2个预测的高度和宽度、2个中心点的位置偏置。
class resnet50_Head(nn.Module):
def __init__(self, num_classes=80, channel=64, bn_momentum=0.1):
super(resnet50_Head, self).__init__()
#-----------------------------------------------------------------#
# 对获取到的特征进行上采样,进行分类预测和回归预测
# 128, 128, 64 -> 128, 128, 64 -> 128, 128, num_classes
# -> 128, 128, 64 -> 128, 128, 2
# -> 128, 128, 64 -> 128, 128, 2
#-----------------------------------------------------------------#
# 热力图预测部分
self.cls_head = nn.Sequential(
nn.Conv2d(64, channel,
kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(64, momentum=bn_momentum),
nn.ReLU(inplace=True),
nn.Conv2d(channel, num_classes,
kernel_size=1, stride=1, padding=0))
# 宽高预测的部分
self.wh_head = nn.Sequential(
nn.Conv2d(64, channel,
kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(64, momentum=bn_momentum),
nn.ReLU(inplace=True),
nn.Conv2d(channel, 2,
kernel_size=1, stride=1, padding=0))
# 中心点预测的部分
self.reg_head = nn.Sequential(
nn.Conv2d(64, channel,
kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(64, momentum=bn_momentum),
nn.ReLU(inplace=True),
nn.Conv2d(channel, 2,
kernel_size=1, stride=1, padding=0))
def forward(self, x):
hm = self.cls_head(x).sigmoid_()
wh = self.wh_head(x)
offset = self.reg_head(x)
return hm, wh, offset
损失函数由三部分构造:
⚪ 中心点预测损失
通过focal loss构造,并额外抑制与中心点足够接近的点的预测:
\[L_k=\frac{-1}{N} \sum_{x y c}\left\{\begin{array}{cc} \left(1-\hat{Y}_{x y c}\right)^\alpha \log \left(\hat{Y}_{x y c}\right) & \text { if } Y_{x y c}=1 \\ \left(1-Y_{x y c}\right)^\beta\left(\hat{Y}_{x y c}\right)^\alpha \log \left(1-\hat{Y}_{x y c}\right) & \text { otherwise } \end{array}\right.\]def focal_loss(pred, target):
pred = pred.permute(0, 2, 3, 1)
#-------------------------------------------------------------------------#
# 找到每张图片的正样本和负样本
# 一个真实框对应一个正样本
# 除去正样本的特征点,其余为负样本
#-------------------------------------------------------------------------#
pos_inds = target.eq(1).float()
neg_inds = target.lt(1).float()
#-------------------------------------------------------------------------#
# 正样本特征点附近的负样本的权值更小一些
#-------------------------------------------------------------------------#
neg_weights = torch.pow(1 - target, 4)
pred = torch.clamp(pred, 1e-6, 1 - 1e-6)
#-------------------------------------------------------------------------#
# 计算focal loss。难分类样本权重大,易分类样本权重小。
#-------------------------------------------------------------------------#
pos_loss = torch.log(pred) * torch.pow(1 - pred, 2) * pos_inds
neg_loss = torch.log(1 - pred) * torch.pow(pred, 2) * neg_weights * neg_inds
#-------------------------------------------------------------------------#
# 进行损失的归一化
#-------------------------------------------------------------------------#
num_pos = pos_inds.float().sum()
pos_loss = pos_loss.sum()
neg_loss = neg_loss.sum()
if num_pos == 0:
loss = -neg_loss
else:
loss = -(pos_loss + neg_loss) / num_pos
return loss
⚪ 目标大小与中心位置偏移损失
目标大小损失与中心位置偏移损失均采用l1损失:
def reg_l1_loss(pred, target, mask):
#--------------------------------#
# 计算l1_loss
#--------------------------------#
pred = pred.permute(0,2,3,1)
expand_mask = torch.unsqueeze(mask,-1).repeat(1,1,1,2)
loss = F.l1_loss(pred * expand_mask, target * expand_mask, reduction='sum')
loss = loss / (mask.sum() + 1e-4)
return loss
在预测阶段,首先针对一张图像进行下采样,随后对下采样后的图像进行预测,对于每个类在下采样的特征图中预测中心点,然后将输出图中的每个类的热点单独地提取出来。检测当前热点的值是否比周围的八个近邻点都大(或者等于),然后取100个这样的点。根据模型预测出来的当前中心点存在物体的概率值,设置阈值为0.3,从结果中选出大于该阈值的中心点作为最终的结果。整个过程没有显式地使用NMS。
def pool_nms(heat, kernel = 3):
pad = (kernel - 1) // 2
hmax = nn.functional.max_pool2d(heat, (kernel, kernel), stride=1, padding=pad)
keep = (hmax == heat).float()
return heat * keep
def decode_bbox(pred_hms, pred_whs, pred_offsets, confidence, cuda):
#-------------------------------------------------------------------------#
# 当利用512x512x3图片进行coco数据集预测的时候
# h = w = 128 num_classes = 80
# Hot map热力图 -> b, 80, 128, 128,
# 进行热力图的非极大抑制,利用3x3的卷积对热力图进行最大值筛选
# 找出一定区域内,得分最大的特征点。
#-------------------------------------------------------------------------#
pred_hms = pool_nms(pred_hms)
b, c, output_h, output_w = pred_hms.shape
detects = []
#-------------------------------------------------------------------------#
# 只传入一张图片,循环只进行一次
#-------------------------------------------------------------------------#
for batch in range(b):
#-------------------------------------------------------------------------#
# heat_map 128*128, num_classes 热力图
# pred_wh 128*128, 2 特征点的预测宽高
# pred_offset 128*128, 2 特征点的xy轴偏移情况
#-------------------------------------------------------------------------#
heat_map = pred_hms[batch].permute(1, 2, 0).view([-1, c])
pred_wh = pred_whs[batch].permute(1, 2, 0).view([-1, 2])
pred_offset = pred_offsets[batch].permute(1, 2, 0).view([-1, 2])
yv, xv = torch.meshgrid(torch.arange(0, output_h), torch.arange(0, output_w))
#-------------------------------------------------------------------------#
# xv 128*128, 特征点的x轴坐标
# yv 128*128, 特征点的y轴坐标
#-------------------------------------------------------------------------#
xv, yv = xv.flatten().float(), yv.flatten().float()
if cuda:
xv = xv.cuda()
yv = yv.cuda()
#-------------------------------------------------------------------------#
# class_conf 128*128, 特征点的种类置信度
# class_pred 128*128, 特征点的种类
#-------------------------------------------------------------------------#
class_conf, class_pred = torch.max(heat_map, dim = -1)
mask = class_conf > confidence
#-----------------------------------------#
# 取出得分筛选后对应的结果
#-----------------------------------------#
pred_wh_mask = pred_wh[mask]
pred_offset_mask = pred_offset[mask]
if len(pred_wh_mask) == 0:
detects.append([])
continue
#----------------------------------------#
# 计算调整后预测框的中心
#----------------------------------------#
xv_mask = torch.unsqueeze(xv[mask] + pred_offset_mask[..., 0], -1)
yv_mask = torch.unsqueeze(yv[mask] + pred_offset_mask[..., 1], -1)
#----------------------------------------#
# 计算预测框的宽高
#----------------------------------------#
half_w, half_h = pred_wh_mask[..., 0:1] / 2, pred_wh_mask[..., 1:2] / 2
#----------------------------------------#
# 获得预测框的左上角和右下角
#----------------------------------------#
bboxes = torch.cat([xv_mask - half_w, yv_mask - half_h, xv_mask + half_w, yv_mask + half_h], dim=1)
bboxes[:, [0, 2]] /= output_w
bboxes[:, [1, 3]] /= output_h
detect = torch.cat([bboxes, torch.unsqueeze(class_conf[mask],-1), torch.unsqueeze(class_pred[mask],-1).float()], dim=-1)
detects.append(detect)
return detects
CenterNet的完整实现可参考centernet-pytorch。