# Course Project: Video Prediction with Object Representations

## Evaluating Image Representations for Video Prediction

A comprehensive implementation of video prediction models using hybrid Transformer and CNN architectures for both holistic and object-centric scene representations. The project explores different approaches to learning and predicting future video frames on the MOVi-C dataset.

*Figure: Holistic autoencoder reconstruction of MOVi-C video frames (GIF visualization)*

## Table of Contents
- Overview
- Features
- Project Structure
- Architecture
- Dataset
- Usage
- Experiments
- Model Checkpoints
- Configuration
- Results
- Citation
## Overview

This project implements a two-stage video prediction pipeline:

1. **Stage 1 (Autoencoder Training)**: learn compressed representations of video frames
2. **Stage 2 (Predictor Training)**: predict future frame representations in latent space

The framework supports two distinct scene representation approaches:

- **Holistic Representation**: treats the entire scene as a unified entity
- **Object-Centric (OC) Representation**: decomposes the scene into individual objects using masks/bounding boxes
## Features

- **Two-Stage Training Pipeline**: separate autoencoder and predictor training phases
- **Dual Scene Representations**: support for both holistic and object-centric approaches
- **Transformer-Based Architecture**: modern attention-based encoders and decoders
- **Flexible Configuration**: easy-to-modify configuration system
- **Comprehensive Logging**: TensorBoard integration with visualization support
- **Mixed Precision Training**: efficient GPU utilization with AMP support
- **Early Stopping & Scheduling**: automatic training optimization
- **Checkpoint Management**: automatic model saving and loading
## Project Structure

```
CourseProject_2/
├── src/
│   ├── base/                      # Base classes
│   │   ├── baseTrainer.py         # Base trainer implementation
│   │   └── baseTransformer.py     # Base transformer blocks
│   ├── datalib/                   # Data loading and processing
│   │   ├── MoviC.py               # MOVi-C dataset class
│   │   ├── load_data.py           # Data loading utilities
│   │   └── transforms.py          # Data augmentation
│   ├── model/                     # Model architectures
│   │   ├── ocvp.py                # Main model definitions (TransformerAutoEncoder, TransformerPredictor, OCVP)
│   │   ├── holistic_encoder.py    # Holistic encoder (patch-based)
│   │   ├── holistic_decoder.py    # Holistic decoder
│   │   ├── holistic_predictor.py  # Holistic predictor
│   │   ├── oc_encoder.py          # Object-centric encoder (CNN + Transformer)
│   │   ├── oc_decoder.py          # Object-centric decoder (Transformer + CNN)
│   │   ├── oc_predictor.py        # Object-centric predictor
│   │   ├── predictor_wrapper.py   # Autoregressive wrapper with sliding window
│   │   └── model_utils.py         # Model utilities (TransformerBlock, Patchifier, etc.)
│   ├── utils/                     # Utility functions
│   │   ├── logger.py              # Logging utilities
│   │   ├── metrics.py             # Evaluation metrics
│   │   ├── utils.py               # General utilities
│   │   └── visualization.py       # Visualization tools
│   ├── experiments/               # Experiment outputs
│   │   └── [experiment_name]/
│   │       ├── checkpoints/       # Model checkpoints
│   │       ├── config/            # Experiment config
│   │       └── tboard_logs/       # TensorBoard logs
│   ├── CONFIG.py                  # Global configuration
│   ├── trainer.py                 # Training entry point
│   └── ocvp.ipynb                 # Analysis notebook
├── docs/                          # Documentation and reports
├── requirements.txt               # Python dependencies
└── README.md                      # This file
```
### Why a Transformer + CNN Hybrid?

The object-centric model uses a hybrid Transformer + CNN architecture:

**CNN advantages:**
- **Inductive bias**: built-in spatial locality and translation equivariance
- **Efficient downsampling**: reduces 64×64 images to compact 256-D vectors
- **Parameter efficiency**: fewer parameters than fully linear projections
- **Better image reconstruction**: ConvTranspose layers naturally upsample spatial features

**Transformer advantages:**
- **Temporal modeling**: captures long-range dependencies across time
- **Object relationships**: models interactions between multiple objects
- **Attention mechanism**: learns which objects/features are important
- **Flexible context**: handles a variable number of objects and temporal sequences

**Combined benefits:**
- CNNs handle spatial features (what objects look like)
- Transformers handle temporal dynamics (how objects move and interact)
- Best of both worlds: local spatial structure + global temporal reasoning
### Key Components

- **Encoder** (`HolisticEncoder` / `ObjectCentricEncoder`)
  - Holistic: patchifies input images (16×16 patches) → linear projection → Transformer
  - Object-centric: CNN + Transformer hybrid architecture
    - CNN feature extraction: 3-layer ConvNet downsampler
      - Conv2d(3→64): 64×64 → 32×32
      - Conv2d(64→128): 32×32 → 16×16
      - Conv2d(128→256): 16×16 → 8×8
      - Linear: flatten → 256-D embedding
    - Extracts per-object features from masks/bboxes (up to 11 objects)
    - Transformer processes object tokens across time
  - Configurable depth (12 layers by default)
  - Embedding dimension: 256
  - Multi-head attention (8 heads)
  - MLP size: 1024

- **Decoder** (`HolisticDecoder` / `ObjectCentricDecoder`)
  - Holistic: Transformer → linear projection → unpatchify to image
  - Object-centric: Transformer + CNN hybrid architecture
    - Transformer processes latent object representations
    - CNN upsampling decoder: 3-layer ConvTranspose
      - Linear: 192-D → 128×8×8 feature map
      - ConvTranspose2d(128→64): 8×8 → 16×16
      - ConvTranspose2d(64→32): 16×16 → 32×32
      - ConvTranspose2d(32→3): 32×32 → 64×64 RGB
      - Tanh activation for [-1, 1] output range
    - Combines per-object frames back into the full scene
  - Configurable depth (8 layers by default)
  - Embedding dimension: 192
  - Mixed loss: MSE (0.8) + L1 (0.2)

- **Predictor** (`HolisticTransformerPredictor` / `ObjectCentricTransformerPredictor`)
  - Predicts future latent representations autoregressively
  - Transformer-based temporal modeling
  - Configurable depth (8 layers by default)
  - Embedding dimension: 192
  - Optional residual connections

- **Predictor Wrapper** (`PredictorWrapper`)
  - Autoregressive prediction: iteratively predicts future frames
  - Sliding-window mechanism: maintains a buffer of size 5
    - Concatenates new predictions to the input buffer
    - Drops the oldest frames when the buffer exceeds the window size
  - Training strategy:
    - Random temporal slicing for data augmentation
    - Per-step loss computation with temporal consistency
  - Mixed loss function:
    - MSE loss (0.6): overall structure
    - L1 loss (0.2): sharpness and sparsity
    - Cosine similarity loss (0.2): feature alignment
  - Generates 5 future-frame predictions per forward pass
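The wrapper's mixed latent-space loss can be sketched in PyTorch as follows. The 0.6/0.2/0.2 weights come from the list above; the reduction details (mean over all elements, cosine similarity taken over the feature dimension) are assumptions, and the function name is illustrative:

```python
import torch
import torch.nn.functional as F

def predictor_loss(pred, target, w_mse=0.6, w_l1=0.2, w_cos=0.2):
    """Weighted MSE + L1 + cosine-similarity loss on latent predictions."""
    mse = F.mse_loss(pred, target)   # overall structure
    l1 = F.l1_loss(pred, target)     # sharpness and sparsity
    # cosine similarity over the feature dimension, turned into a loss term
    cos = 1.0 - F.cosine_similarity(pred, target, dim=-1).mean()
    return w_mse * mse + w_l1 * l1 + w_cos * cos

# [B, num_preds, num_objects, embed_dim] latents, as in the predictor output
loss = predictor_loss(torch.randn(2, 5, 11, 256), torch.randn(2, 5, 11, 256))
```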
## Architecture

### Overall Pipeline

Input Video Frames → Encoder → Latent Representation → Predictor → Future Latents → Decoder → Predicted Frames

### Detailed Architecture: Object-Centric Model (Transformer + CNN)
```
OBJECT-CENTRIC ENCODER
======================
Input: Video [B, T, 3, 64, 64] + Masks [B, T, 64, 64]
  ↓
Object Extraction (up to 11 objects)
  → Object Frames: [B, T, 11, 3, 64, 64]
  ↓
CNN Feature Extractor (per object):
  • Conv2d(3→64, k=4, s=2) + BatchNorm + ReLU     [64×64 → 32×32]
  • Conv2d(64→128, k=4, s=2) + BatchNorm + ReLU   [32×32 → 16×16]
  • Conv2d(128→256, k=4, s=2) + BatchNorm + ReLU  [16×16 → 8×8]
  • Flatten + Linear(256·8·8 → 256)
  → Object Tokens: [B, T, 11, 256]
  ↓
Transformer Encoder (12 layers):
  • Positional Encoding
  • Multi-Head Attention (8 heads)
  • MLP (dim=1024)
  • Layer Normalization
  → Latent: [B, T, 11, 256]
```
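The CNN feature-extraction path above can be sketched in PyTorch. This is a minimal sketch: the class name `ObjectCNNEncoder` and the padding choice are illustrative assumptions, and the Transformer stage is omitted; the actual implementation lives in `src/model/oc_encoder.py`:

```python
import torch
import torch.nn as nn

class ObjectCNNEncoder(nn.Module):
    """Per-object CNN downsampler: [3, 64, 64] frame -> 256-D token."""

    def __init__(self, embed_dim=256):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1),    # 64x64 -> 32x32
            nn.BatchNorm2d(64), nn.ReLU(inplace=True),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),  # 32x32 -> 16x16
            nn.BatchNorm2d(128), nn.ReLU(inplace=True),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1), # 16x16 -> 8x8
            nn.BatchNorm2d(256), nn.ReLU(inplace=True),
        )
        self.proj = nn.Linear(256 * 8 * 8, embed_dim)  # flatten -> embedding

    def forward(self, obj_frames):
        # obj_frames: [B, T, N, 3, 64, 64] -> object tokens [B, T, N, embed_dim]
        B, T, N, C, H, W = obj_frames.shape
        feats = self.conv(obj_frames.reshape(B * T * N, C, H, W))
        return self.proj(feats.flatten(1)).reshape(B, T, N, -1)

tokens = ObjectCNNEncoder()(torch.randn(2, 4, 11, 3, 64, 64))
print(tokens.shape)  # torch.Size([2, 4, 11, 256])
```

The resulting object tokens are what the 12-layer Transformer encoder then processes across objects and time.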
```
PREDICTOR + WRAPPER
===================
Input Latent: [B, T=24, 11, 256]
  ↓
PredictorWrapper (autoregressive):
  • Random temporal slice (5 frames)
  • Sliding window buffer (size=5)
  ↓
Transformer Predictor (8 layers):
  • Linear(256 → 192)
  • Transformer blocks (depth=8)
  • Linear(192 → 256)
  • Optional residual connections
  ↓
Autoregressive Loop (5 predictions):
  For t in 1..5:
    • Predict next frame
    • Append to buffer, shift window
    • Compute loss (MSE + L1 + Cosine)
  → Future Latent: [B, 5, 11, 256]
```
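The sliding-window loop can be sketched as follows. Here `predictor` is any callable mapping a window of latents `[B, W, N, D]` to the next latent `[B, N, D]`, standing in for the Transformer predictor; function and argument names are illustrative, not the project's actual API:

```python
import torch

def autoregressive_predict(predictor, seed_latents, num_preds=5, window_size=5):
    """Predict `num_preds` future latents with a fixed-size sliding window."""
    # seed_latents: [B, T, N, D]; keep only the most recent `window_size` frames
    buffer = seed_latents[:, -window_size:]
    preds = []
    for _ in range(num_preds):
        next_latent = predictor(buffer)  # [B, N, D]
        preds.append(next_latent)
        # append the new prediction, drop the oldest frame to keep the window fixed
        buffer = torch.cat([buffer, next_latent.unsqueeze(1)], dim=1)[:, 1:]
    return torch.stack(preds, dim=1)     # [B, num_preds, N, D]

# Toy predictor: the mean over the window stands in for the transformer
toy = lambda buf: buf.mean(dim=1)
future = autoregressive_predict(toy, torch.randn(2, 24, 11, 256))
print(future.shape)  # torch.Size([2, 5, 11, 256])
```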
```
OBJECT-CENTRIC DECODER
======================
Input Latent: [B, T, 11, 256]
  ↓
Transformer Decoder (8 layers):
  • Linear(256 → 192)
  • Positional Encoding
  • Transformer blocks (depth=8)
  • Layer Normalization
  → [B, T, 11, 192]
  ↓
CNN Upsampling Decoder (per object):
  • Linear(192 → 128·8·8) + Reshape to [128, 8, 8]
  • ConvTranspose2d(128→64, k=4, s=2) + BatchNorm + ReLU  [8×8 → 16×16]
  • ConvTranspose2d(64→32, k=4, s=2) + BatchNorm + ReLU   [16×16 → 32×32]
  • ConvTranspose2d(32→3, k=4, s=2) + Tanh                [32×32 → 64×64]
  → Per-Object Frames: [B, T, 11, 3, 64, 64]
  ↓
Object Composition:
  • Sum all object frames: Σ(objects)
  • Rescale: (x + 1) / 2  (from [-1, 1] to [0, 1])
  • Clamp to [0, 1]
  → Reconstructed Video: [B, T, 3, 64, 64]
```
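The CNN upsampling and object-composition stages above can be sketched in PyTorch. This is a minimal sketch of the shapes involved: the class name and padding choice are illustrative assumptions, the Transformer stage is omitted, and the real module lives in `src/model/oc_decoder.py`:

```python
import torch
import torch.nn as nn

class ObjectCNNDecoder(nn.Module):
    """Per-object ConvTranspose upsampler + sum-based scene composition."""

    def __init__(self, embed_dim=192):
        super().__init__()
        self.proj = nn.Linear(embed_dim, 128 * 8 * 8)
        self.deconv = nn.Sequential(
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),  # 8x8 -> 16x16
            nn.BatchNorm2d(64), nn.ReLU(inplace=True),
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),   # 16x16 -> 32x32
            nn.BatchNorm2d(32), nn.ReLU(inplace=True),
            nn.ConvTranspose2d(32, 3, kernel_size=4, stride=2, padding=1),    # 32x32 -> 64x64
            nn.Tanh(),                                                        # output in [-1, 1]
        )

    def forward(self, tokens):
        # tokens: [B, T, N, embed_dim] -> per-object frames [B, T, N, 3, 64, 64]
        B, T, N, D = tokens.shape
        x = self.proj(tokens).reshape(B * T * N, 128, 8, 8)
        frames = self.deconv(x).reshape(B, T, N, 3, 64, 64)
        # compose the scene: sum objects, rescale [-1, 1] -> [0, 1], clamp
        return ((frames.sum(dim=2) + 1) / 2).clamp(0, 1)

video = ObjectCNNDecoder()(torch.randn(2, 4, 11, 192))
print(video.shape)  # torch.Size([2, 4, 3, 64, 64])
```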
## Setup

1. Clone the repository
2. Create and activate a virtual environment
3. Install dependencies from `requirements.txt`
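A sketch of these steps (the repository URL is a placeholder, not the actual location):

```shell
# clone the repository (placeholder URL)
git clone https://github.com/your-repo/CourseProject_2.git
cd CourseProject_2

# create and activate a virtual environment
python3 -m venv .venv
source .venv/bin/activate

# install dependencies
pip install -r requirements.txt
```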
## Dataset

This project uses the MOVi-C dataset (Multi-Object Video dataset).

### Dataset Setup

1. Download the MOVi-C dataset from the official source
2. Extract it to your preferred location
3. Update the dataset path in `src/CONFIG.py`:
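For example (assuming, as the fragments in the Configuration section suggest, that the settings live in a plain dictionary; the variable name `CONFIG` is an assumption):

```python
# src/CONFIG.py (fragment): point 'dataset_path' at your local MOVi-C root
CONFIG = {
    'data': {
        'dataset_path': '/path/to/movi_c/',
        # remaining 'data' keys as listed under Data Configuration
    },
}
```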
### Dataset Structure

The MOVi-C dataset should have the following structure:
## Usage

### Training the Autoencoder

Train the autoencoder with either the holistic or the object-centric representation:
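Assuming the flags listed under Command-Line Arguments, the two invocations would look like this (presumably run from `src/`, where `trainer.py` lives):

```shell
# holistic representation
python trainer.py --ae --scene_rep holistic

# object-centric representation
python trainer.py --ae --scene_rep oc
```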
### Training the Predictor

After training the autoencoder, train the predictor:

```shell
python trainer.py --predictor --scene_rep holistic \
    --ackpt experiments/01_Holistic_AE_XL/checkpoints/best_01_Holistic_AE_XL.pth
```

For the object-centric model:

```shell
python trainer.py --predictor --scene_rep oc \
    --ackpt experiments/01_OC_AE_XL_64_Full_CNN/checkpoints/best_01_OC_AE_XL_64_Full_CNN.pth
```
### Inference

Run end-to-end video prediction:

```shell
python trainer.py --inference --scene_rep holistic \
    --ackpt path/to/autoencoder.pth \
    --pckpt path/to/predictor.pth
```
### Command-Line Arguments

| Argument | Short | Description |
|---|---|---|
| `--ae` | `-a` | Enable autoencoder training mode |
| `--predictor` | `-p` | Enable predictor training mode |
| `--inference` | `-i` | Enable end-to-end inference mode |
| `--ackpt` | `-ac` | Path to pretrained autoencoder checkpoint |
| `--pckpt` | `-pc` | Path to pretrained predictor checkpoint |
| `--scene_rep` | `-s` | Scene representation type: `holistic` or `oc` |
## Experiments

The project includes several experimental configurations:

### Autoencoder Experiments

- **Holistic autoencoders:**
  - `01_Holistic_AE_Base`: baseline holistic autoencoder
  - `02_Holistic_AE_XL`: extra-large holistic autoencoder
- **Object-centric autoencoders:**
  - `01_OC_AE_XL_64_Full_CNN`: full CNN-based OC autoencoder
  - `01_OC_AE_XL_64_Mixed_CNN_Decoder_Linear_ENCODER`: mixed architecture
  - Various linear and advanced configurations

### Predictor Experiments

- **Holistic predictors:**
  - `02_Holistic_Predictor_XL`: standard predictor
  - `03_Holistic_Predictor_XL`: improved version
  - `05_Holistic_Predictor_XL_NoResidual`: without residual connections
- **Object-centric predictors:**
  - `01_OC_Predictor_XL`: standard OC predictor
### Experiment Outputs

Each experiment generates:

- **Checkpoints**: best and periodic model saves
- **TensorBoard logs**: training curves and visualizations
- **Configuration snapshots**: reproducible experiment configs
## Model Checkpoints

Pre-trained model checkpoints are available for download.

### Available Checkpoints

- Holistic autoencoder (Base & XL)
- Object-centric autoencoder (various configurations)
- Holistic predictor (multiple versions)
- Object-centric predictor
## Configuration

The main configuration file is `src/CONFIG.py`. Key parameters:

### Data Configuration

```python
'data': {
    'dataset_path': '/path/to/movi_c/',
    'batch_size': 32,
    'patch_size': 16,
    'max_objects': 11,
    'num_workers': 8,
    'image_height': 64,
    'image_width': 64,
}
```
### Training Configuration

```python
'training': {
    'num_epochs': 300,
    'warmup_epochs': 15,
    'early_stopping_patience': 15,
    'model_name': '01_OC_AE_XL_64_Full_CNN',
    'lr': 4e-4,
    'save_frequency': 25,
    'use_scheduler': True,
    'use_early_stopping': True,
    'use_transforms': False,
    'use_amp': True,  # mixed precision training
}
```
### Model Configuration

```python
'vit_cfg': {
    'encoder_embed_dim': 256,
    'decoder_embed_dim': 192,
    'num_heads': 8,
    'mlp_size': 1024,
    'encoder_depth': 12,
    'decoder_depth': 8,
    'predictor_depth': 8,
    'num_preds': 5,
    'predictor_window_size': 5,
    'use_masks': True,
    'use_bboxes': False,
    'residual': True,
}
```
## Results

### Reconstruction Quality

The models achieve high-quality video frame reconstruction:

- **Holistic models**: capture global scene structure effectively
- **Object-centric models**: better at preserving individual object details

### Visualization

View the results in the analysis notebook (`src/ocvp.ipynb`). The notebook includes:

- Training/validation loss curves
- Reconstruction visualizations
- Prediction quality analysis
- A comparison between the holistic and object-centric approaches

### TensorBoard

Monitor training progress:
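For example, pointing TensorBoard at an experiment's log directory (the path follows the layout shown in Project Structure):

```shell
tensorboard --logdir src/experiments/<experiment_name>/tboard_logs
```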
## Citation

If you use this code in your research, please cite:

```bibtex
@misc{video_prediction_ocvp,
  title={Evaluating Image Representations for Video Prediction},
  author={Your Name},
  year={2025},
  howpublished={\url{https://github.com/your-repo}}
}
```
## Support

If you found this project helpful, you can support my work by buying me a coffee or via PayPal!
## Location

The complete project documentation, code, and notebooks are located in: