From 7ee9f27471629f7d02d8d74ef57d3c842765da24 Mon Sep 17 00:00:00 2001 From: yoiannis Date: Sun, 2 Mar 2025 17:21:21 +0800 Subject: [PATCH] =?UTF-8?q?=E5=9B=BA=E5=AE=9A=E6=8D=9F=E5=A4=B1=E6=9D=83?= =?UTF-8?q?=E9=87=8D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ultralytics/nn/tasks.py | 4 ++-- ultralytics/utils/loss.py | 29 ++++++++++++++++++++++++++++- 2 files changed, 30 insertions(+), 3 deletions(-) diff --git a/ultralytics/nn/tasks.py b/ultralytics/nn/tasks.py index 75f4576..c9fbd5c 100644 --- a/ultralytics/nn/tasks.py +++ b/ultralytics/nn/tasks.py @@ -69,6 +69,7 @@ from ultralytics.utils.loss import ( v8ClassificationLoss, v8MTLClassificationLoss, v8DetectionLoss, + v8MTLUWClassificationLoss, v8OBBLoss, v8PoseLoss, v8SegmentationLoss, @@ -95,7 +96,6 @@ from ultralytics.nn.backbone.revcol import * from ultralytics.nn.backbone.lsknet import * from ultralytics.nn.backbone.SwinTransformer import * from ultralytics.nn.backbone.repvit import * -from ultralytics.nn.backbone.resnet import * from ultralytics.nn.backbone.CSwomTramsformer import * from ultralytics.nn.backbone.UniRepLKNet import * from ultralytics.nn.backbone.TransNext import * @@ -625,7 +625,7 @@ class MTLClassificationModel(BaseModel): def init_criterion(self): """Initialize the loss criterion for the ClassificationModel.""" - return v8MTLClassificationLoss(self) + return v8MTLUWClassificationLoss(self) class RTDETRDetectionModel(DetectionModel): """ diff --git a/ultralytics/utils/loss.py b/ultralytics/utils/loss.py index 9251049..76b0c3b 100644 --- a/ultralytics/utils/loss.py +++ b/ultralytics/utils/loss.py @@ -996,12 +996,39 @@ class v8MTLClassificationLoss: for i in range(len(preds)): loss[i + 1] = torch.nn.functional.cross_entropy(preds[i], batch["cls"][i], reduction="mean") loss[0] = loss.sum() - return loss.sum(), loss.detach() # loss(box, cls, dfl) + return loss[0], loss.detach() # loss(box, cls, dfl) else: loss = (torch.nn.functional.cross_entropy(preds, batch["cls"], reduction="mean")) loss_items = loss.detach() return loss, loss_items +class v8MTLUWClassificationLoss(nn.Module): + """Criterion class for computing training losses with learnable uncertainty weights.""" + def __init__(self, model,task_numbers = 3): + super().__init__() + self.device = next(model.parameters()).device + + self.logvars = torch.ones(task_numbers, device=self.device) + + # self.register_parameter('logvars', self.logvars) + # self.task_numbers = task_numbers + + def forward(self, preds, batch): + """Compute the classification loss between predictions and true labels.""" + loss = torch.zeros(len(preds) + 1, device=self.device) + total_loss = torch.zeros(1, device=self.device) + if isinstance(preds, list): + for i in range(len(preds)): + loss[i + 1] = torch.nn.functional.cross_entropy(preds[i], batch["cls"][i], reduction="mean") + total_loss += (1.0 / (self.logvars[i] ** 2) * loss[i + 1] + torch.log(self.logvars[i])) + loss[0] = total_loss + return loss[0], loss.detach() # loss(box, cls, dfl) + else: + loss = (torch.nn.functional.cross_entropy(preds, batch["cls"], reduction="mean")) + loss_items = loss.detach() + return loss, loss_items + + class v8OBBLoss(v8DetectionLoss): def __init__(self, model): """