深度学习——学习率调度脚本（基于 detectron2 框架）：Warmup + Cosine 退火

时间：2025-04-18 17:49:12
import math

from torch.optim.lr_scheduler import CosineAnnealingLR, _LRScheduler
from detectron2.solver import build_lr_scheduler

class WarmupCosineAnnealingLR(_LRScheduler):
    """Linear warmup followed by cosine annealing to zero, stepped per iteration.

    While ``last_epoch < warmup_iters`` the learning rate of each param group
    ramps linearly from ``base_lr * warmup_factor`` towards ``base_lr``.
    Afterwards it follows a half cosine from ``base_lr`` (at ``warmup_iters``)
    down to 0 (at ``max_iters``) and stays at 0 for any later step.

    Args:
        optimizer: Wrapped optimizer; its param groups supply ``base_lrs``.
        max_iters: Iteration index at which the cosine schedule reaches 0.
        warmup_iters: Number of linear-warmup iterations (0 disables warmup).
        warmup_factor: Multiplier applied to ``base_lr`` at iteration 0.
        last_epoch: Index of the last step taken; -1 starts a fresh schedule.
    """

    def __init__(self, optimizer, max_iters, warmup_iters, warmup_factor, last_epoch=-1):
        # Must be assigned before super().__init__, which performs an initial
        # step and therefore already calls get_lr().
        self.max_iters = max_iters
        self.warmup_iters = warmup_iters
        self.warmup_factor = warmup_factor
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return the learning rate of every param group for ``self.last_epoch``."""
        if self.last_epoch < self.warmup_iters:
            # Linear interpolation of the multiplier: warmup_factor -> 1.
            alpha = self.last_epoch / self.warmup_iters
            factor = self.warmup_factor * (1 - alpha) + alpha
        else:
            cosine_iters = self.max_iters - self.warmup_iters
            elapsed = self.last_epoch - self.warmup_iters
            # max_iters <= warmup_iters leaves no room to anneal: treat the
            # schedule as finished instead of dividing by zero. Otherwise clamp
            # so the LR stays at 0 past max_iters rather than climbing back up
            # along the next cosine half-period.
            progress = min(elapsed / cosine_iters, 1.0) if cosine_iters > 0 else 1.0
            factor = 0.5 * (1 + math.cos(math.pi * progress))
        return [base_lr * factor for base_lr in self.base_lrs]

def build_warmup_cosine_scheduler(cfg, optimizer):
    """Wrap ``optimizer`` in a WarmupCosineAnnealingLR configured from ``cfg.SOLVER``.

    Reads ``MAX_ITER``, ``WARMUP_ITERS`` and ``WARMUP_FACTOR`` from the
    solver section of the config and returns the constructed scheduler.
    """
    solver = cfg.SOLVER
    scheduler_kwargs = dict(
        max_iters=solver.MAX_ITER,
        warmup_iters=solver.WARMUP_ITERS,
        warmup_factor=solver.WARMUP_FACTOR,
    )
    return WarmupCosineAnnealingLR(optimizer, **scheduler_kwargs)