MFT
- class grid.model.perception.tracking.mft.MFT(*args, **kwargs)
MFT: Custom Point Tracking Model
This class implements a point tracking model that processes video frames and tracks points based on provided queries.
- Parameters:
queries (Tensor) -- Tensor containing the point queries for tracking.
save_results (bool) -- Whether to save the results as a video.
- _static_model
The model used for tracking.
- Type:
torch.nn.Module
- device
The device ('cuda' or 'cpu') on which the model runs.
- Type:
str
- queries
Tensor containing point queries for tracking.
- Type:
torch.Tensor
- save_results
Flag indicating whether to save the tracking results to a video.
- Type:
bool
- initialized
Tracks whether the model has been initialized.
- Type:
bool
- window_frames
List of video frames buffered for later writing.
- Type:
List[np.ndarray]
- results
List to store the tracking results (coords, occlusions) for each frame.
- Type:
List[Tuple]
- __init__(queries, save_results=False)
Initialize the MFT with queries and an optional flag to save results.
- Parameters:
queries (torch.Tensor) -- Tensor containing the point queries for tracking.
save_results (bool) -- Whether to save the results as a video.
- Return type:
None
- draw_dots(frame, coords, occlusions)
Draw points (dots) on the frame at the locations of the tracked points.
- Parameters:
frame (np.ndarray) -- The input video frame.
coords (torch.Tensor) -- Coordinates of the points to draw.
occlusions (torch.Tensor) -- Occlusion value for each point.
- Returns:
Frame with points drawn on it.
- Return type:
np.ndarray
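To illustrate what a method like draw_dots does, here is a minimal NumPy-only sketch. It is not the library's implementation. It assumes coords has shape (N, 2) in (x, y) pixel order, that occlusions has shape (N,) with values above 0.5 meaning "occluded", and that the frame is an H x W x 3 uint8 image; none of these details are stated in the reference above. The name draw_dots_sketch and the radius and occlusion_threshold parameters are hypothetical.

```python
import numpy as np


def draw_dots_sketch(frame, coords, occlusions, radius=2, occlusion_threshold=0.5):
    """Return a copy of `frame` with a small square drawn at each point.

    Assumptions (not confirmed by the MFT docs): coords is (N, 2) in (x, y)
    pixel order; occlusions is (N,) with values > threshold meaning occluded.
    """
    out = frame.copy()  # leave the caller's frame untouched
    height, width = out.shape[:2]
    coords = np.asarray(coords, dtype=float).reshape(-1, 2)
    occlusions = np.asarray(occlusions, dtype=float).reshape(-1)

    for (x, y), occ in zip(coords, occlusions):
        cx, cy = int(round(x)), int(round(y))
        if not (0 <= cx < width and 0 <= cy < height):
            continue  # skip points that fall outside the image
        # Illustrative colors: green for visible points, grey for occluded ones.
        color = (128, 128, 128) if occ > occlusion_threshold else (0, 255, 0)
        y0, y1 = max(cy - radius, 0), min(cy + radius + 1, height)
        x0, x1 = max(cx - radius, 0), min(cx + radius + 1, width)
        out[y0:y1, x0:x1] = color
    return out


# One visible point inside the image, one occluded point outside it.
frame = np.zeros((10, 10, 3), dtype=np.uint8)
annotated = draw_dots_sketch(frame, [[5, 5], [50, 50]], [0.0, 1.0])
```

Inputs that convert with np.asarray, such as CPU torch tensors, should work with this sketch; tensors on a GPU would need to be moved to the CPU first.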
- process_frame(frame)
Process a single video frame to track points based on the queries.
- Parameters:
frame (np.ndarray) -- A single video frame (image).
- Returns:
Coordinates and occlusions of the tracked points. Per the return type, either element may be None.
- Return type:
Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]