
Course Project: Video Prediction with Object Representations

Evaluating Image Representations for Video Prediction

A comprehensive implementation of video prediction models using hybrid Transformer and CNN architectures for both holistic and object-centric scene representations. This project explores different approaches to learning and predicting future video frames on the MOVi-C dataset.

Figure: Holistic autoencoder reconstruction of MOVi-C video frames (GIF visualization)


🎯 Overview

This project implements a two-stage video prediction pipeline:

  1. Stage 1 - Autoencoder Training: Learn compressed representations of video frames
  2. Stage 2 - Predictor Training: Predict future frame representations in latent space

The framework supports two distinct scene representation approaches:

  • Holistic Representation: Treats the entire scene as a unified entity
  • Object-Centric (OC) Representation: Decomposes scenes into individual objects using masks/bounding boxes

✨ Features

  • 🔄 Two-Stage Training Pipeline: Separate autoencoder and predictor training phases
  • 🎭 Dual Scene Representations: Support for both holistic and object-centric approaches
  • 🧠 Transformer-Based Architecture: Modern attention-based encoders and decoders
  • 🎯 Flexible Configuration: Easy-to-modify configuration system
  • 📊 Comprehensive Logging: TensorBoard integration with visualization support
  • ⚡ Mixed Precision Training: Efficient GPU utilization with AMP support
  • 🔍 Early Stopping & Scheduling: Automatic training optimization
  • 💾 Checkpoint Management: Automatic model saving and loading

๐Ÿ“ Project Structure

CourseProject_2/
├── src/
│   ├── base/                    # Base classes
│   │   ├── baseTrainer.py       # Base trainer implementation
│   │   └── baseTransformer.py   # Base transformer blocks
│   ├── datalib/                 # Data loading and processing
│   │   ├── MoviC.py             # MOVi-C dataset class
│   │   ├── load_data.py         # Data loading utilities
│   │   └── transforms.py        # Data augmentation
│   ├── model/                   # Model architectures
│   │   ├── ocvp.py              # Main model definitions (TransformerAutoEncoder, TransformerPredictor, OCVP)
│   │   ├── holistic_encoder.py  # Holistic encoder (patch-based)
│   │   ├── holistic_decoder.py  # Holistic decoder
│   │   ├── holistic_predictor.py # Holistic predictor
│   │   ├── oc_encoder.py        # Object-centric encoder (CNN + Transformer)
│   │   ├── oc_decoder.py        # Object-centric decoder (Transformer + CNN)
│   │   ├── oc_predictor.py      # Object-centric predictor
│   │   ├── predictor_wrapper.py # Autoregressive wrapper with sliding window
│   │   └── model_utils.py       # Model utilities (TransformerBlock, Patchifier, etc.)
│   ├── utils/                   # Utility functions
│   │   ├── logger.py            # Logging utilities
│   │   ├── metrics.py           # Evaluation metrics
│   │   ├── utils.py             # General utilities
│   │   └── visualization.py     # Visualization tools
│   ├── experiments/             # Experiment outputs
│   │   └── [experiment_name]/
│   │       ├── checkpoints/     # Model checkpoints
│   │       ├── config/          # Experiment config
│   │       └── tboard_logs/     # TensorBoard logs
│   ├── CONFIG.py                # Global configuration
│   ├── trainer.py               # Training entry point
│   └── ocvp.ipynb               # Analysis notebook
├── docs/                        # Documentation and reports
├── requirements.txt             # Python dependencies
└── README.md                    # This file

Why Transformer + CNN Hybrid?

The object-centric model pairs a CNN with a Transformer, combining the strengths of both:

CNN Advantages:
  • ✅ Inductive Bias: Built-in understanding of spatial locality and translation invariance
  • ✅ Efficient Downsampling: Reduces 64×64 images to compact 256-D vectors
  • ✅ Parameter Efficiency: Fewer parameters than fully linear projections
  • ✅ Better Image Reconstruction: ConvTranspose layers naturally upsample spatial features

Transformer Advantages:
  • ✅ Temporal Modeling: Captures long-range dependencies across time
  • ✅ Object Relationships: Models interactions between multiple objects
  • ✅ Attention Mechanism: Learns which objects/features are important
  • ✅ Flexible Context: Handles variable numbers of objects and temporal sequences

Combined Benefits:
  • 🎯 CNNs handle spatial features (what objects look like)
  • 🎯 Transformers handle temporal dynamics (how objects move and interact)
  • 🎯 Best of both worlds: local spatial structure + global temporal reasoning

Key Components

  1. Encoder (HolisticEncoder / ObjectCentricEncoder)
    • Holistic: Patchifies input images (16×16 patches) → Linear projection → Transformer
    • Object-Centric: CNN encoder + Transformer hybrid architecture
      • CNN Feature Extraction: 3-layer ConvNet downsampler
        • Conv2d(3→64): 64×64 → 32×32
        • Conv2d(64→128): 32×32 → 16×16
        • Conv2d(128→256): 16×16 → 8×8
        • Linear: Flatten → 256-D embedding
      • Extracts per-object features from masks/bboxes (up to 11 objects)
      • Transformer processes object tokens across time
    • Configurable depth (12 layers default)
    • Embedding dimension: 256
    • Multi-head attention (8 heads)
    • MLP size: 1024
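The per-object CNN downsampler described above can be sketched in PyTorch as follows. This is a minimal sketch: the class name, padding values, and exact argument layout are assumptions for illustration; the project's actual module lives in src/model/oc_encoder.py.

```python
import torch
import torch.nn as nn

class ObjectCNNEncoder(nn.Module):
    """Illustrative 3-layer ConvNet downsampler: 64x64 RGB -> 256-D token."""
    def __init__(self, embed_dim: int = 256):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1),    # 64x64 -> 32x32
            nn.BatchNorm2d(64), nn.ReLU(inplace=True),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),  # 32x32 -> 16x16
            nn.BatchNorm2d(128), nn.ReLU(inplace=True),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1), # 16x16 -> 8x8
            nn.BatchNorm2d(256), nn.ReLU(inplace=True),
        )
        self.proj = nn.Linear(256 * 8 * 8, embed_dim)  # flatten -> 256-D embedding

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [N, 3, 64, 64] where N folds batch, time, and object dims
        feats = self.conv(x)                          # [N, 256, 8, 8]
        return self.proj(feats.flatten(start_dim=1))  # [N, embed_dim]
```

The resulting tokens would then be reshaped to [B, T, objects, 256] before the Transformer stage.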

  2. Decoder (HolisticDecoder / ObjectCentricDecoder)
    • Holistic: Transformer → Linear projection → Unpatchify to image
    • Object-Centric: Transformer + CNN hybrid architecture
      • Transformer processes latent object representations
      • CNN Upsampling Decoder: 3-layer ConvTranspose
        • Linear: 192-D → 128×8×8 feature map
        • ConvTranspose2d(128→64): 8×8 → 16×16
        • ConvTranspose2d(64→32): 16×16 → 32×32
        • ConvTranspose2d(32→3): 32×32 → 64×64 RGB
        • Tanh activation for [-1, 1] output range
      • Combines per-object frames back into the full scene
    • Configurable depth (8 layers default)
    • Embedding dimension: 192
    • Mixed loss: MSE (0.8) + L1 (0.2)
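The CNN upsampling path and the mixed reconstruction loss above can be sketched like this. Class and function names are illustrative assumptions; the real module lives in src/model/oc_decoder.py, and padding values are guessed to reproduce the quoted spatial sizes.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ObjectCNNDecoder(nn.Module):
    """Illustrative ConvTranspose upsampler: 192-D token -> 64x64 RGB frame."""
    def __init__(self, embed_dim: int = 192):
        super().__init__()
        self.proj = nn.Linear(embed_dim, 128 * 8 * 8)  # token -> 8x8 feature map
        self.deconv = nn.Sequential(
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),  # 8 -> 16
            nn.BatchNorm2d(64), nn.ReLU(inplace=True),
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),   # 16 -> 32
            nn.BatchNorm2d(32), nn.ReLU(inplace=True),
            nn.ConvTranspose2d(32, 3, kernel_size=4, stride=2, padding=1),    # 32 -> 64
            nn.Tanh(),  # constrain output to [-1, 1]
        )

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        x = self.proj(z).view(-1, 128, 8, 8)
        return self.deconv(x)  # [N, 3, 64, 64]

def mixed_recon_loss(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """MSE (0.8) + L1 (0.2), the autoencoder loss weighting quoted above."""
    return 0.8 * F.mse_loss(pred, target) + 0.2 * F.l1_loss(pred, target)
```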

  3. Predictor (HolisticTransformerPredictor / ObjectCentricTransformerPredictor)
    • Predicts future latent representations autoregressively
    • Transformer-based temporal modeling
    • Configurable depth (8 layers default)
    • Embedding dimension: 192
    • Optional residual connections

  4. Predictor Wrapper (PredictorWrapper)
    • Autoregressive Prediction: Iteratively predicts future frames
    • Sliding Window Mechanism: Maintains a buffer of size 5
      • Concatenates new predictions to the input buffer
      • Drops the oldest frames when the buffer exceeds the window size
    • Training Strategy:
      • Random temporal slicing for data augmentation
      • Per-step loss computation with temporal consistency
    • Loss Function:
      • MSE loss (0.6): Overall structure
      • L1 loss (0.2): Sharpness and sparsity
      • Cosine similarity loss (0.2): Feature alignment
    • Generates 5 future frame predictions per forward pass
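The sliding-window rollout and the weighted loss can be sketched as follows. The `predictor` interface (buffer in, next latent out) and the function names are assumptions for illustration; the actual logic lives in src/model/predictor_wrapper.py.

```python
import torch
import torch.nn.functional as F

def autoregressive_rollout(predictor, context, num_preds=5, window=5):
    """Illustrative sliding-window rollout: predict, append, drop oldest."""
    buffer = context[:, -window:]          # keep only the last `window` latents
    preds = []
    for _ in range(num_preds):
        nxt = predictor(buffer)            # next-step latent: [B, ...]
        preds.append(nxt)
        # append the prediction and slide the window forward
        buffer = torch.cat([buffer, nxt.unsqueeze(1)], dim=1)[:, -window:]
    return torch.stack(preds, dim=1)       # [B, num_preds, ...]

def predictor_loss(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Weighted MSE (0.6) + L1 (0.2) + cosine distance (0.2), as quoted above."""
    cos = 1.0 - F.cosine_similarity(
        pred.flatten(start_dim=1), target.flatten(start_dim=1), dim=-1).mean()
    return 0.6 * F.mse_loss(pred, target) + 0.2 * F.l1_loss(pred, target) + 0.2 * cos
```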

๐Ÿ—๏ธ Architecture

Overall Pipeline

Input Video Frames → Encoder → Latent Representation → Predictor → Future Latent → Decoder → Predicted Frames
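In code, the pipeline amounts to three composed stages. The function name below is hypothetical; the real wiring lives in src/model/ocvp.py.

```python
def predict_video(encoder, predictor, decoder, seed_frames):
    """Illustrative two-stage inference: encode seed frames, roll the
    predictor forward in latent space, then decode the predicted latents."""
    latents = encoder(seed_frames)        # frames -> latent representation
    future_latents = predictor(latents)   # autoregressive latent rollout
    return decoder(future_latents)        # latents -> predicted frames
```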

Detailed Architecture: Object-Centric Model (Transformer + CNN)

┌─────────────────────────────────────────────────────────────────────────────┐
│                          OBJECT-CENTRIC ENCODER                             │
├─────────────────────────────────────────────────────────────────────────────┤
│ Input: Video [B, T, 3, 64, 64] + Masks [B, T, 64, 64]                       │
│   ↓                                                                         │
│ Object Extraction (11 objects max)                                          │
│   → Object Frames: [B, T, 11, 3, 64, 64]                                    │
│   ↓                                                                         │
│ CNN Feature Extractor (Per Object):                                         │
│   • Conv2d(3→64, k=4, s=2) + BatchNorm + ReLU    [64x64 → 32x32]            │
│   • Conv2d(64→128, k=4, s=2) + BatchNorm + ReLU  [32x32 → 16x16]            │
│   • Conv2d(128→256, k=4, s=2) + BatchNorm + ReLU [16x16 → 8x8]              │
│   • Flatten + Linear(256·8·8 → 256)                                         │
│   → Object Tokens: [B, T, 11, 256]                                          │
│   ↓                                                                         │
│ Transformer Encoder (12 layers):                                            │
│   • Positional Encoding                                                     │
│   • Multi-Head Attention (8 heads, dim=128)                                 │
│   • MLP (dim=1024)                                                          │
│   • Layer Normalization                                                     │
│   → Latent: [B, T, 11, 256]                                                 │
└─────────────────────────────────────────────────────────────────────────────┘
┌─────────────────────────────────────────────────────────────────────────────┐
│                            PREDICTOR + WRAPPER                              │
├─────────────────────────────────────────────────────────────────────────────┤
│ Input Latent: [B, T=24, 11, 256]                                            │
│   ↓                                                                         │
│ PredictorWrapper (Autoregressive):                                          │
│   • Random temporal slice (5 frames)                                        │
│   • Sliding window buffer (size=5)                                          │
│   ↓                                                                         │
│ Transformer Predictor (8 layers):                                           │
│   • Linear(256 → 192)                                                       │
│   • Transformer blocks (depth=8)                                            │
│   • Linear(192 → 256)                                                       │
│   • Optional residual connections                                           │
│   ↓                                                                         │
│ Autoregressive Loop (5 predictions):                                        │
│   For t in 1..5:                                                            │
│     • Predict next frame                                                    │
│     • Append to buffer, shift window                                        │
│     • Compute loss (MSE + L1 + Cosine)                                      │
│   → Future Latent: [B, 5, 11, 256]                                          │
└─────────────────────────────────────────────────────────────────────────────┘
┌─────────────────────────────────────────────────────────────────────────────┐
│                          OBJECT-CENTRIC DECODER                             │
├─────────────────────────────────────────────────────────────────────────────┤
│ Input Latent: [B, T, 11, 256]                                               │
│   ↓                                                                         │
│ Transformer Decoder (8 layers):                                             │
│   • Linear(256 → 192)                                                       │
│   • Positional Encoding                                                     │
│   • Transformer blocks (depth=8)                                            │
│   • Layer Normalization                                                     │
│   → [B, T, 11, 192]                                                         │
│   ↓                                                                         │
│ CNN Upsampling Decoder (Per Object):                                        │
│   • Linear(192 → 128·8·8) + Reshape to [128, 8, 8]                          │
│   • ConvTranspose2d(128→64, k=4, s=2) + BatchNorm + ReLU [8x8 → 16x16]      │
│   • ConvTranspose2d(64→32, k=4, s=2) + BatchNorm + ReLU [16x16 → 32x32]     │
│   • ConvTranspose2d(32→3, k=4, s=2) + Tanh        [32x32 → 64x64]           │
│   → Per-Object Frames: [B, T, 11, 3, 64, 64]                                │
│   ↓                                                                         │
│ Object Composition:                                                         │
│   • Sum all object frames: Σ(objects)                                       │
│   • Normalize: (x + 1) / 2  (from [-1,1] to [0,1])                          │
│   • Clamp to [0, 1]                                                         │
│   → Reconstructed Video: [B, T, 3, 64, 64]                                  │
└─────────────────────────────────────────────────────────────────────────────┘
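The composition step at the bottom of the decoder diagram is simple enough to spell out directly. The function name is illustrative; shapes follow the diagram above.

```python
import torch

def compose_objects(object_frames: torch.Tensor) -> torch.Tensor:
    """Illustrative scene composition: sum the decoded per-object frames,
    shift the Tanh output range [-1, 1] to [0, 1], and clamp."""
    scene = object_frames.sum(dim=2)   # [B, T, 11, 3, H, W] -> [B, T, 3, H, W]
    scene = (scene + 1.0) / 2.0        # [-1, 1] -> [0, 1]
    return scene.clamp(0.0, 1.0)
```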

Setup

  1. Clone the repository:

    git clone <repository-url>
    cd CourseProject_2
    

  2. Create and activate virtual environment:

    python -m venv venv
    source venv/bin/activate  # On Linux/Mac
    # or
    venv\Scripts\activate  # On Windows
    

  3. Install dependencies:

    pip install -r requirements.txt
    

📦 Dataset

This project uses the MOVi-C dataset (Multi-Object Video Dataset).

Dataset Setup

  1. Download MOVi-C dataset from the official source
  2. Extract to your preferred location
  3. Update the dataset path in src/CONFIG.py:

    config = {
        'data': {
            'dataset_path': '/path/to/movi_c/',
            ...
        }
    }

Dataset Structure

The MOVi-C dataset should have the following structure:

movi_c/
โ”œโ”€โ”€ train/
โ”œโ”€โ”€ validation/
โ””โ”€โ”€ test/

💻 Usage

Training Autoencoder

Train the autoencoder with holistic representation:

cd src
python trainer.py --ae --scene_rep holistic

Train with object-centric representation:

python trainer.py --ae --scene_rep oc

Training Predictor

After training the autoencoder, train the predictor:

python trainer.py --predictor --scene_rep holistic \
    --ackpt experiments/01_Holistic_AE_XL/checkpoints/best_01_Holistic_AE_XL.pth

For object-centric:

python trainer.py --predictor --scene_rep oc \
    --ackpt experiments/01_OC_AE_XL_64_Full_CNN/checkpoints/best_01_OC_AE_XL_64_Full_CNN.pth

Inference

Run end-to-end video prediction:

python trainer.py --inference --scene_rep holistic \
    --ackpt path/to/autoencoder.pth \
    --pckpt path/to/predictor.pth

Command-Line Arguments

Argument       Short  Description
-----------    -----  ------------------------------------------
--ae           -a     Enable autoencoder training mode
--predictor    -p     Enable predictor training mode
--inference    -i     Enable end-to-end inference mode
--ackpt        -ac    Path to pretrained autoencoder checkpoint
--pckpt        -pc    Path to pretrained predictor checkpoint
--scene_rep    -s     Scene representation type: holistic or oc
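A parser matching the table above could be sketched with argparse as follows. The real definitions live in src/trainer.py and may differ in defaults and help text; this is a hypothetical reconstruction.

```python
import argparse

def build_parser() -> argparse.ArgumentParser:
    """Illustrative CLI matching the flag table above."""
    parser = argparse.ArgumentParser(description="Video prediction trainer")
    parser.add_argument("-a", "--ae", action="store_true",
                        help="enable autoencoder training mode")
    parser.add_argument("-p", "--predictor", action="store_true",
                        help="enable predictor training mode")
    parser.add_argument("-i", "--inference", action="store_true",
                        help="enable end-to-end inference mode")
    parser.add_argument("-ac", "--ackpt", type=str, default=None,
                        help="path to pretrained autoencoder checkpoint")
    parser.add_argument("-pc", "--pckpt", type=str, default=None,
                        help="path to pretrained predictor checkpoint")
    parser.add_argument("-s", "--scene_rep", choices=["holistic", "oc"],
                        default="holistic", help="scene representation type")
    return parser
```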

🔬 Experiments

The project includes several experimental configurations:

Autoencoder Experiments

  1. Holistic Autoencoders:
    • 01_Holistic_AE_Base: Baseline holistic autoencoder
    • 02_Holistic_AE_XL: Extra-large holistic autoencoder

  2. Object-Centric Autoencoders:
    • 01_OC_AE_XL_64_Full_CNN: Full CNN-based OC autoencoder
    • 01_OC_AE_XL_64_Mixed_CNN_Decoder_Linear_ENCODER: Mixed architecture
    • Various linear and advanced configurations

Predictor Experiments

  1. Holistic Predictors:
    • 02_Holistic_Predictor_XL: Standard predictor
    • 03_Holistic_Predictor_XL: Improved version
    • 05_Holistic_Predictor_XL_NoResidual: Without residual connections

  2. Object-Centric Predictors:
    • 01_OC_Predictor_XL: Standard OC predictor

Experiment Outputs

Each experiment generates:
  • Checkpoints: Best and periodic model saves
  • TensorBoard Logs: Training curves, visualizations
  • Configuration Snapshots: Reproducible experiment configs

💾 Model Checkpoints

Pre-trained model checkpoints are available for download:

🔗 Download Model Checkpoints

Available Checkpoints

  • Holistic Autoencoder (Base & XL)
  • Object-Centric Autoencoder (Various configurations)
  • Holistic Predictor (Multiple versions)
  • Object-Centric Predictor

โš™๏ธ Configuration

The main configuration file is src/CONFIG.py. Key parameters:

Data Configuration

'data': {
    'dataset_path': '/path/to/movi_c/',
    'batch_size': 32,
    'patch_size': 16,
    'max_objects': 11,
    'num_workers': 8,
    'image_height': 64,
    'image_width': 64,
}

Training Configuration

'training': {
    'num_epochs': 300,
    'warmup_epochs': 15,
    'early_stopping_patience': 15,
    'model_name': '01_OC_AE_XL_64_Full_CNN',
    'lr': 4e-4,
    'save_frequency': 25,
    'use_scheduler': True,
    'use_early_stopping': True,
    'use_transforms': False,
    'use_amp': True,  # Mixed precision training
}
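The `use_amp` flag enables mixed precision training; one training step under AMP could look like the sketch below. The argument layout and function name are assumptions, not the project's actual trainer API, and a CUDA device is assumed for real AMP use.

```python
import torch

def train_step(model, batch, optimizer, loss_fn, scaler,
               device_type="cuda", use_amp=True):
    """Illustrative AMP training step: autocast forward, scaled backward."""
    optimizer.zero_grad(set_to_none=True)
    with torch.autocast(device_type=device_type, enabled=use_amp):
        loss = loss_fn(model(batch), batch)   # reconstruction-style loss
    scaler.scale(loss).backward()             # scaled backward pass
    scaler.step(optimizer)                    # unscales grads, then steps
    scaler.update()
    return loss.item()
```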

Model Configuration

'vit_cfg': {
    'encoder_embed_dim': 256,
    'decoder_embed_dim': 192,
    'num_heads': 8,
    'mlp_size': 1024,
    'encoder_depth': 12,
    'decoder_depth': 8,
    'predictor_depth': 8,
    'num_preds': 5,
    'predictor_window_size': 5,
    'use_masks': True,
    'use_bboxes': False,
    'residual': True,
}

📊 Results

Reconstruction Quality

The models achieve high-quality video frame reconstruction:

  • Holistic Models: Capture global scene structure effectively
  • Object-Centric Models: Better at preserving individual object details

Visualization

View results in the Jupyter notebook:

cd src
jupyter lab ocvp.ipynb

The notebook includes:
  • Training/validation loss curves
  • Reconstruction visualizations
  • Prediction quality analysis
  • Comparison between holistic and object-centric approaches

TensorBoard

Monitor training progress:

tensorboard --logdir src/experiments/[experiment_name]/tboard_logs

🎓 Citation

If you use this code in your research, please cite:

@misc{video_prediction_ocvp,
  title={Evaluating Image Representations for Video Prediction},
  author={Your Name},
  year={2025},
  howpublished={\url{https://github.com/your-repo}}
}

💬 Support

If you found this project helpful, you can support my work by buying me a coffee or via PayPal!

Buy Me a Coffee

PayPal


Location

The complete project documentation, code, and notebooks are located in:

src/CourseProject/