MFT

class grid.model.perception.tracking.mft.MFT(*args, **kwargs)

MFT: Custom Point Tracking Model

This class implements a point tracking model that processes video frames one at a time and tracks the points specified by the provided queries.

Parameters:
  • queries (torch.Tensor) -- Tensor containing the point queries for tracking.

  • save_results (bool) -- Whether to save the tracking results as a video.

_static_model

The model used for tracking.

Type:

torch.nn.Module

device

The device ('cuda' or 'cpu') where the model runs.

Type:

str

queries

Tensor containing point queries for tracking.

Type:

torch.Tensor

save_results

Flag indicating whether to save the tracking results to a video.

Type:

bool

initialized

Whether the model has been initialized.

Type:

bool

window_frames

List of video frames buffered for writing the output video.

Type:

List[np.ndarray]

results

List storing the per-frame tracking results as (coords, occlusions) tuples.

Type:

List[Tuple]
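
The attributes above are read directly on the instance. A brief illustration of inspecting them after some frames have been processed (the attribute names match this reference; the surrounding setup is assumed):

   # After processing some frames:
   print(tracker.device)        # 'cuda' or 'cpu'
   print(tracker.initialized)   # True once the model has been set up
   print(len(tracker.results))  # one (coords, occlusions) tuple per frame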

__init__(queries, save_results=False)

Initialize the MFT with queries and an optional flag to save results.

Parameters:
  • queries (torch.Tensor) -- Tensor containing the point queries for tracking.

  • save_results (bool) -- Whether to save the results as a video.

Return type:

None
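
A minimal construction sketch, assuming queries holds one (x, y) pixel coordinate per tracked point; the exact tensor shape MFT expects is an assumption here, not confirmed by this reference.

   import torch

   from grid.model.perception.tracking.mft import MFT

   # Hypothetical query layout: one (x, y) pixel coordinate per point.
   queries = torch.tensor(
       [[320.0, 240.0],
        [100.0, 150.0]],
       dtype=torch.float32,
   )

   tracker = MFT(queries, save_results=True)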

draw_dots(frame, coords, occlusions)

Draw dots on the frame at the locations of the tracked points.

Parameters:
  • frame (np.ndarray) -- The input video frame.

  • coords (torch.Tensor) -- Coordinates of the points to draw.

  • occlusions (torch.Tensor) -- Occlusion status for each tracked point.

Returns:

Frame with points drawn on it.

Return type:

np.ndarray
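
A rough sketch of the kind of per-point overlay draw_dots performs, not the actual implementation; it assumes coords holds (x, y) pixel locations and occlusions holds one scalar per point, with higher values meaning occluded.

   import cv2
   import numpy as np
   import torch

   def draw_dots_sketch(frame: np.ndarray,
                        coords: torch.Tensor,
                        occlusions: torch.Tensor) -> np.ndarray:
       # Illustrative only: green for visible points, red for occluded ones.
       # The coordinate layout and occlusion encoding are assumptions.
       out = frame.copy()
       for (x, y), occ in zip(coords.tolist(), occlusions.tolist()):
           color = (0, 0, 255) if occ > 0.5 else (0, 255, 0)
           cv2.circle(out, (int(x), int(y)), radius=3, color=color, thickness=-1)
       return out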

process_frame(frame)

Process a single video frame to track points based on the queries.

Parameters:

frame (np.ndarray) -- A single video frame (image).

Returns:

Coordinates and occlusions of the tracked points.

Return type:

Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]
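
Putting it together, a hedged end-to-end sketch that feeds frames from a video file through process_frame and overlays the results; the video path and query layout are placeholder assumptions, and the returned tensors may be None on frames where the tracker has no output yet.

   import cv2
   import torch

   from grid.model.perception.tracking.mft import MFT

   # Placeholder query layout and input path.
   queries = torch.tensor([[320.0, 240.0], [100.0, 150.0]], dtype=torch.float32)
   tracker = MFT(queries, save_results=False)

   cap = cv2.VideoCapture("input.mp4")
   while True:
       ok, frame = cap.read()
       if not ok:
           break
       coords, occlusions = tracker.process_frame(frame)
       if coords is not None:
           frame = tracker.draw_dots(frame, coords, occlusions)
   cap.release()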