
Reference for ultralytics/models/sam/sam3/necks.py

Improvements

This page is sourced from https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam/sam3/necks.py. Have an improvement or example to add? Open a Pull Request — thank you! 🙏


class ultralytics.models.sam.sam3.necks.Sam3DualViTDetNeck

def __init__(
    self,
    trunk: nn.Module,
    position_encoding: nn.Module,
    d_model: int,
    scale_factors=(4.0, 2.0, 1.0, 0.5),
    add_sam2_neck: bool = False,
)

Bases: nn.Module

A neck that implements a simple FPN as in ViTDet, with support for dual necks (for SAM3 and SAM2).

Adapted (very lightly) from detectron2, it supports a "dual neck" setting with two structurally identical necks, one for SAM3 and one for SAM2, each with its own weights.

Args

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `trunk` | `nn.Module` | The backbone. | *required* |
| `position_encoding` | `nn.Module` | The positional encoding to use. | *required* |
| `d_model` | `int` | The dimension of the model. | *required* |
| `scale_factors` | `tuple` | Scale factors of the FPN levels relative to the trunk output. | `(4.0, 2.0, 1.0, 0.5)` |
| `add_sam2_neck` | `bool` | Whether to add a second, identical neck with separate weights for SAM2. | `False` |

Methods

| Name | Description |
| --- | --- |
| `forward` | Get the feature maps and positional encodings from the neck. |
| `set_imgsz` | Set the image size for the trunk backbone. |
Source code in `ultralytics/models/sam/sam3/necks.py`
class Sam3DualViTDetNeck(nn.Module):
    """A neck that implements a simple FPN as in ViTDet, with support for dual necks (for SAM3 and SAM2)."""

    def __init__(
        self,
        trunk: nn.Module,
        position_encoding: nn.Module,
        d_model: int,
        scale_factors=(4.0, 2.0, 1.0, 0.5),
        add_sam2_neck: bool = False,
    ):
        """
        SimpleFPN neck a la ViTDet
        (From detectron2, very lightly adapted)
        It supports a "dual neck" setting, where we have two identical necks (for SAM3 and SAM2), with different weights.

        :param trunk: the backbone
        :param position_encoding: the positional encoding to use
        :param d_model: the dimension of the model
        """
        super().__init__()
        self.trunk = trunk
        self.position_encoding = position_encoding
        self.convs = nn.ModuleList()

        self.scale_factors = scale_factors
        use_bias = True
        dim: int = self.trunk.channel_list[-1]

        for _, scale in enumerate(scale_factors):
            current = nn.Sequential()

            if scale == 4.0:
                current.add_module(
                    "dconv_2x2_0",
                    nn.ConvTranspose2d(dim, dim // 2, kernel_size=2, stride=2),
                )
                current.add_module(
                    "gelu",
                    nn.GELU(),
                )
                current.add_module(
                    "dconv_2x2_1",
                    nn.ConvTranspose2d(dim // 2, dim // 4, kernel_size=2, stride=2),
                )
                out_dim = dim // 4
            elif scale == 2.0:
                current.add_module(
                    "dconv_2x2",
                    nn.ConvTranspose2d(dim, dim // 2, kernel_size=2, stride=2),
                )
                out_dim = dim // 2
            elif scale == 1.0:
                out_dim = dim
            elif scale == 0.5:
                current.add_module(
                    "maxpool_2x2",
                    nn.MaxPool2d(kernel_size=2, stride=2),
                )
                out_dim = dim
            else:
                raise NotImplementedError(f"scale_factor={scale} is not supported yet.")

            current.add_module(
                "conv_1x1",
                nn.Conv2d(
                    in_channels=out_dim,
                    out_channels=d_model,
                    kernel_size=1,
                    bias=use_bias,
                ),
            )
            current.add_module(
                "conv_3x3",
                nn.Conv2d(
                    in_channels=d_model,
                    out_channels=d_model,
                    kernel_size=3,
                    padding=1,
                    bias=use_bias,
                ),
            )
            self.convs.append(current)

        self.sam2_convs = None
        if add_sam2_neck:
            # Assumes sam2 neck is just a clone of the original neck
            self.sam2_convs = deepcopy(self.convs)
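
The four scale factors map one-to-one onto FPN levels built from the trunk's last feature map: 4.0 applies two stride-2 transposed convolutions with a GELU in between, 2.0 applies one, 1.0 passes the map through unchanged, and 0.5 max-pools it; every level then passes through the same 1x1 and 3x3 convolutions to reach `d_model` channels. The sketch below wires the neck to stub modules to show the interfaces it expects; `DummyTrunk` and `DummyPosEnc` are invented stand-ins for illustration, not part of the library.

import torch
import torch.nn as nn

from ultralytics.models.sam.sam3.necks import Sam3DualViTDetNeck


class DummyTrunk(nn.Module):
    """Stand-in backbone exposing the channel_list attribute the neck reads."""

    channel_list = [768]

    def forward(self, x):
        # The neck only consumes the last feature map of the returned list.
        return [torch.randn(1, 768, 63, 63)]


class DummyPosEnc(nn.Module):
    """Stand-in positional encoding returning a tensor shaped like its input."""

    def forward(self, x):
        return torch.zeros_like(x)


neck = Sam3DualViTDetNeck(
    trunk=DummyTrunk(),
    position_encoding=DummyPosEnc(),
    d_model=256,
    add_sam2_neck=True,  # deep-copies the convs so SAM2 gets separate weights
)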


method ultralytics.models.sam.sam3.necks.Sam3DualViTDetNeck.forward

def forward(
    self, tensor_list: list[torch.Tensor]
) -> tuple[list[torch.Tensor], list[torch.Tensor], list[torch.Tensor], list[torch.Tensor]]

Get the feature maps and positional encodings from the neck.

Args

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `tensor_list` | `list[torch.Tensor]` | Input tensors passed to the trunk backbone. | *required* |
Source code in `ultralytics/models/sam/sam3/necks.py`
def forward(
    self, tensor_list: list[torch.Tensor]
) -> tuple[list[torch.Tensor], list[torch.Tensor], list[torch.Tensor], list[torch.Tensor]]:
    """Get the feature maps and positional encodings from the neck."""
    xs = self.trunk(tensor_list)
    sam3_out, sam3_pos = [], []
    sam2_out, sam2_pos = None, None
    if self.sam2_convs is not None:
        sam2_out, sam2_pos = [], []
    x = xs[-1]  # simpleFPN
    for i in range(len(self.convs)):
        sam3_x_out = self.convs[i](x)
        sam3_pos_out = self.position_encoding(sam3_x_out).to(sam3_x_out.dtype)
        sam3_out.append(sam3_x_out)
        sam3_pos.append(sam3_pos_out)

        if self.sam2_convs is not None:
            sam2_x_out = self.sam2_convs[i](x)
            sam2_pos_out = self.position_encoding(sam2_x_out).to(sam2_x_out.dtype)
            sam2_out.append(sam2_x_out)
            sam2_pos.append(sam2_pos_out)
    return sam3_out, sam3_pos, sam2_out, sam2_pos
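
Continuing the sketch above: a single call runs the trunk once, feeds its last feature map through every per-scale conv stack, and returns the four SAM3 levels in scale-factor order, plus the SAM2 levels from the cloned stack (`sam2_out` and `sam2_pos` are `None` when `add_sam2_neck=False`).

# The dummy trunk ignores its input; a real trunk would consume the images.
sam3_out, sam3_pos, sam2_out, sam2_pos = neck(torch.randn(1, 3, 1008, 1008))

for feat in sam3_out:
    print(tuple(feat.shape))
# With a 63x63 trunk feature and d_model=256, the levels come out as:
# (1, 256, 252, 252)  <- scale 4.0
# (1, 256, 126, 126)  <- scale 2.0
# (1, 256, 63, 63)    <- scale 1.0
# (1, 256, 31, 31)    <- scale 0.5
assert len(sam2_out) == len(sam3_out)  # the dual neck mirrors the SAM3 levels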


method ultralytics.models.sam.sam3.necks.Sam3DualViTDetNeck.set_imgsz

def set_imgsz(self, imgsz: list[int] = [1008, 1008])

Set the image size for the trunk backbone.

Args

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `imgsz` | `list[int]` | Target image size for the trunk backbone. | `[1008, 1008]` |
Source code in `ultralytics/models/sam/sam3/necks.py`
def set_imgsz(self, imgsz: list[int] = [1008, 1008]):
    """Set the image size for the trunk backbone."""
    self.trunk.set_imgsz(imgsz)
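
Note that `set_imgsz` simply delegates to the trunk, so the underlying backbone must implement a matching `set_imgsz` method (the dummy trunk in the sketch above does not). With a real trunk the call is a one-liner:

neck.set_imgsz([1008, 1008])  # forwarded verbatim to self.trunk.set_imgsz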




