import math
from typing import Callable, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


class Embed(nn.Module):
    """Linearly project the last dimension of the input to ``embed_dim``.

    Applies ``nn.Linear(in_chans, embed_dim)`` to the input, then a
    normalization module built by ``norm_layer`` (or ``nn.Identity`` when no
    factory is given).

    Args:
        in_chans: Size of the input's last dimension.
        embed_dim: Size of the output's last dimension.
        norm_layer: Optional factory called as ``norm_layer(embed_dim)`` to
            build the normalization module (for example ``nn.LayerNorm``).
            When ``None``, no normalization is applied.
        bias: Whether the linear projection has a learnable bias.
    """

    def __init__(
        self,
        in_chans: int = 3,
        embed_dim: int = 768,
        norm_layer: Optional[Callable[[int], nn.Module]] = None,
        bias: bool = True,
    ):
        super().__init__()
        # Kept as plain attributes so callers can inspect the configuration.
        self.in_chans = in_chans
        self.embed_dim = embed_dim
        self.proj = nn.Linear(in_chans, embed_dim, bias=bias)
        # Compare against None explicitly: a factory object that happens to be
        # falsy (e.g. defines __bool__/__len__) must still be honoured rather
        # than silently replaced by Identity.
        self.norm = (
            norm_layer(embed_dim) if norm_layer is not None else nn.Identity()
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project ``x`` and apply the normalization module.

        Args:
            x: Tensor whose last dimension has size ``in_chans``; leading
                dimensions are passed through unchanged by ``nn.Linear``.

        Returns:
            Tensor with the same leading dimensions and a last dimension of
            size ``embed_dim``.
        """
        x = self.proj(x)
        return self.norm(x)