
st_encoder_decoder.py

Spatio-Temporal encoder-decoder for Monthly Prediction. The main model class is SpatioTemporalModel.

MonthlyConvDecoder(embed_dim=128, patch_h=4, patch_w=4, hidden=128, overlap=1, num_months=12)

Bases: Module

Decoder to reconstruct 2D maps from patch tokens.

The MonthlyConvDecoder converts latent patch tokens back to pixel space:

- Applies a 1×1 convolution to mix features on the patch grid.
- Uses a transposed convolution (deconvolution) to upsample tokens to the original spatial resolution.
- Applies a convolutional refinement block to smooth patch boundaries.
- Applies a small convolutional head to produce the final single-channel output.
- Optionally masks out land regions using a boolean mask.

Args:

- embed_dim: Dimension of the patch embedding. The default is 128. Many vision transformers use embedding dimensions that are multiples of 64 (e.g., 64, 128, 256). This can be tuned.
- patch_h: Patch height.
- patch_w: Patch width.
- hidden: Hidden dimension in the decoder for mixing channel features. The default is 128, which can be tuned.
- overlap: Overlap size for the deconvolution; it creates smooth blending between adjacent upsampled patches. Default is 1; a value of 0 gives no overlap at patch edges.
- num_months: Number of months. Default is 12.

Source code in climanet/st_encoder_decoder.py
def __init__(
    self, embed_dim=128, patch_h=4, patch_w=4, hidden=128, overlap=1, num_months=12
):
    """
    Args:
        embed_dim: Dimension of the patch embedding. The default is 128.
            Many vision transformers use embedding dimensions that are
            multiples of 64 (e.g., 64, 128, 256). This can be tuned.
        patch_h: Patch height
        patch_w: Patch width
        hidden: Hidden dimension in the decoder for mixing channel features.
            The default is 128, which can be tuned.
        overlap: Overlap size for the deconvolution; it creates smooth blending
            between adjacent upsampled patches. Default is 1; a value of 0
            gives no overlap at patch edges.
        num_months: Number of months. Default is 12.
    """
    super().__init__()
    self.patch_h = patch_h
    self.patch_w = patch_w
    self.overlap = overlap

    # Mix channel features on the patch grid (Hp, Wp)
    # Input shape: (B, embed_dim, Hp, Wp) → Output shape: (B, hidden, Hp, Wp)
    # here kernel_size=1 means we are mixing features at each patch location
    # without spatial interaction
    in_channels, out_channels = embed_dim, hidden
    self.proj = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    # Upsample to full resolution
    # With kernel = stride + 2*overlap and padding=overlap,
    # output size is exact: H = Hp*patch_h, W = Wp*patch_w (no output_padding needed).
    k_h = patch_h + 2 * overlap
    k_w = patch_w + 2 * overlap
    # As spatial size increases, channel count decreases to keep computation
    # manageable; here  hidden // 2 is a design choice.
    in_channels, out_channels = hidden, hidden // 2
    self.deconv = nn.ConvTranspose2d(
        in_channels,
        out_channels,
        kernel_size=(k_h, k_w),
        stride=(patch_h, patch_w),
        padding=overlap,
        output_padding=0,
        bias=True,
    )

    # Refinement block: a small stack of 3x3 conv layers to smooth patch
    # boundaries. kernel_size=3 is the most common choice for spatial
    # convolutions; it's the smallest kernel that captures spatial context
    # in all directions.
    in_channels, out_channels = hidden // 2, hidden // 2
    self.refine = nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
        nn.BatchNorm2d(out_channels),
        nn.GELU(),
        nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
        nn.BatchNorm2d(out_channels),
        nn.GELU(),
    )

    # Final conv head to map to single-channel output
    self.head = nn.Conv2d(out_channels, 1, kernel_size=1)

    # Learnable scale and bias (mean and std) to improve predictions
    self.scale = nn.Parameter(torch.ones(num_months))
    self.bias = nn.Parameter(torch.zeros(num_months))
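The deconvolution sizing above can be checked with plain arithmetic. This is a minimal sketch (not from the source): for `nn.ConvTranspose2d`, the output size is `(in - 1) * stride - 2 * padding + kernel`, so the choice `kernel = patch + 2*overlap`, `stride = patch`, `padding = overlap` always yields exactly `in * patch`.

```python
# Sanity check of the transposed-conv arithmetic used by MonthlyConvDecoder.
# For nn.ConvTranspose2d: out = (in - 1) * stride - 2 * padding + kernel.
def deconv_out(in_size, patch, overlap):
    kernel = patch + 2 * overlap
    return (in_size - 1) * patch - 2 * overlap + kernel

for Hp in (5, 8, 16):
    for patch in (2, 4):
        for overlap in (0, 1, 2):
            # overlap cancels out: the output is exactly Hp * patch
            assert deconv_out(Hp, patch, overlap) == Hp * patch
print("exact upsampling for any overlap")
```

This is why `output_padding=0` suffices: the overlap terms cancel, so no size correction is ever needed.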

forward(latent, M, out_H, out_W, land_mask=None)

Reconstruct 2D maps from latent patch tokens.

Args:

- latent: Tensor of shape (B, M*Hp*Wp, C) where C is the embedding dimension.
- M: Number of months (temporal patches).
- out_H: Target output height (must be divisible by patch_h).
- out_W: Target output width (must be divisible by patch_w).
- land_mask: Optional boolean tensor of shape (out_H, out_W). Values set to True are masked out (set to 0) in the output, so only ocean pixels remain.

Returns: Tensor of shape (B, M, out_H, out_W) representing the monthly variable, e.g. SST.

Source code in climanet/st_encoder_decoder.py
def forward(self, latent, M, out_H, out_W, land_mask=None):
    """Reconstruct 2D maps from latent patch tokens.
    Args:
        latent: Tensor of shape (B, M*Hp*Wp, C) where C is the embedding dimension.
        M: Number of months (temporal patches)
        out_H: Target output height (must be divisible by patch_h)
        out_W: Target output width (must be divisible by patch_w)
        land_mask: Optional boolean tensor of shape (out_H, out_W). Values set to True
            will be masked out (set to 0) in the output (only ocean pixels exist).
    Returns:
        Tensor of shape (B, M, out_H, out_W) representing the monthly variable e.g. SST.
    """
    B, MHW, C = latent.shape
    Hp = out_H // self.patch_h
    Wp = out_W // self.patch_w
    assert MHW == M * Hp * Wp, f"Token mismatch: got {MHW}, expected {M * Hp * Wp}"

    # transforms the latent tensor from sequence format to image format for
    # convolution operations; (B, M*Hp*Wp, C) -> (B*M, C, Hp, Wp)
    out = latent.view(B, M, Hp, Wp, C).permute(0, 1, 4, 2, 3).contiguous()
    out = out.view(B * M, C, Hp, Wp)

    # Apply 1x1 convolution to mix features
    out = self.proj(out)  # (B*M, hidden, Hp, Wp)

    # Use transposed convolution to upsample
    out = self.deconv(out)  # (B*M, hidden//2, H, W)

    # Refinement CNN to smooth boundaries
    out = self.refine(out)  # (B*M, hidden//2, H, W)

    # Apply final conv head to get single channel output
    out = self.head(out)  # (B*M, 1, H, W)

    # Apply scale and bias per month to improve predictions; reshape to (B*M, 1, 1, 1) for broadcasting
    scale = self.scale[:M].unsqueeze(0).expand(B, M).reshape(B * M, 1, 1, 1)
    bias = self.bias[:M].unsqueeze(0).expand(B, M).reshape(B * M, 1, 1, 1)
    out = out * scale + bias
    out = out.view(B, M, out_H, out_W)  # (B, M, H, W)

    # Mask out land areas if land_mask is provided
    if land_mask is not None:
        out = out.masked_fill(land_mask.bool()[None, None, :, :], 0.0)
    return out  # (B, M, out_H, out_W)
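The token-to-grid reshape at the top of `forward` can be verified with toy sizes. This sketch (sizes are illustrative, not from the source) checks that token `m*Hp*Wp + h*Wp + w` lands at grid position `(h, w)` of folded sample `b*M + m`:

```python
import torch

# Sketch of the sequence-to-image reshape in MonthlyConvDecoder.forward:
# (B, M*Hp*Wp, C) -> (B*M, C, Hp, Wp), tokens laid out month-major, then
# row-major over the (Hp, Wp) patch grid.
B, M, Hp, Wp, C = 2, 3, 4, 5, 8
latent = torch.arange(B * M * Hp * Wp * C, dtype=torch.float32).view(B, M * Hp * Wp, C)

out = latent.view(B, M, Hp, Wp, C).permute(0, 1, 4, 2, 3).contiguous()
out = out.view(B * M, C, Hp, Wp)

# Token m*Hp*Wp + h*Wp + w ends up at grid position (h, w) of sample b*M + m
b, m, h, w = 1, 2, 3, 4
assert torch.equal(out[b * M + m, :, h, w], latent[b, m * Hp * Wp + h * Wp + w])
print(tuple(out.shape))  # (6, 8, 4, 5)
```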

SpatialPositionalEncoding2D(embed_dim=128, max_H=1024, max_W=1024)

Bases: Module

2D Spatial Positional Encoding using sine and cosine functions.

This module generates fixed sinusoidal positional encodings for a 2D spatial grid, following the formulation in "Attention Is All You Need" (Vaswani et al., 2017).

The returned positional encodings are intended to be added to spatial tokens by the caller. The encodings are not learnable.

Initialize the positional encoding.

Args:

- embed_dim: Dimension of the embedding; it must be even. The default is 128. Embedding dimensions are usually multiples of 64 (e.g., 64, 128, 256). This can be tuned.
- max_H: Maximum height. Default is 1024, which should be sufficient.
- max_W: Maximum width. Default is 1024, which should be sufficient.

Source code in climanet/st_encoder_decoder.py
def __init__(self, embed_dim=128, max_H=1024, max_W=1024):
    """Initialize the positional encoding.
    Args:
        embed_dim: Dimension of the embedding, it must be even. The default is 128.
            Embedding dimensions are usually multiples of 64 (e.g., 64, 128,
            256). This can be tuned.
        max_H: Maximum height. Default is 1024, which should be sufficient.
        max_W: Maximum width. Default is 1024, which should be sufficient.
    """
    super().__init__()
    self.embed_dim = embed_dim
    self.max_H = max_H
    self.max_W = max_W
    self.register_buffer(
        "pe", self.build_pe(max_H, max_W, embed_dim), persistent=False
    )

build_pe(H, W, embed_dim) staticmethod

Build the 2D positional encoding tensor.

Args:

- H: Height of the grid.
- W: Width of the grid.
- embed_dim: Dimension of the embedding (must be even).

Returns: Tensor of shape (H, W, embed_dim) containing fixed positional encodings. Encodings are constructed by combining sine/cosine encodings along height and width. Not learnable.

Source code in climanet/st_encoder_decoder.py
@staticmethod
def build_pe(H, W, embed_dim):
    """Build the 2D positional encoding encoding tensor.
    Args:
        H: Height of the grid
        W: Width of the grid
        embed_dim: Dimension of the embedding (must be even)
    Returns:
        Tensor of shape (H, W, embed_dim) containing fixed positional encodings.
        Encodings are constructed by combining sine/cosine encodings along
        height and width. Not learnable.
    """
    assert embed_dim % 2 == 0, "embed_dim must be even"
    pe_h = torch.zeros(H, embed_dim // 2)
    pe_w = torch.zeros(W, embed_dim // 2)
    pos_h = torch.arange(H).unsqueeze(1)
    pos_w = torch.arange(W).unsqueeze(1)
    div = torch.exp(
        torch.arange(0, embed_dim // 2, 2) * (-math.log(10000.0) / (embed_dim // 2))
    )
    pe_h[:, 0::2] = torch.sin(pos_h * div)
    pe_h[:, 1::2] = torch.cos(pos_h * div)
    pe_w[:, 0::2] = torch.sin(pos_w * div)
    pe_w[:, 1::2] = torch.cos(pos_w * div)
    pe_2d = pe_h.unsqueeze(1) + pe_w.unsqueeze(0)  # (H, W, embed_dim/2)
    # concatenate to reach embed_dim
    pe = torch.cat([pe_2d, pe_2d], dim=-1)  # (H, W, embed_dim)
    return pe  # not learned
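A small worked example of `build_pe`, replicated standalone with toy sizes (6×8 grid, embed_dim=16): the per-axis sinusoidal encodings are summed over the grid, then duplicated along the channel axis to reach `embed_dim`. At the origin, every sin term is 0 and every cos term contributes 1 from each axis.

```python
import math
import torch

# Standalone replica of SpatialPositionalEncoding2D.build_pe (toy sizes).
def build_pe(H, W, embed_dim):
    pe_h = torch.zeros(H, embed_dim // 2)
    pe_w = torch.zeros(W, embed_dim // 2)
    pos_h = torch.arange(H).unsqueeze(1)
    pos_w = torch.arange(W).unsqueeze(1)
    div = torch.exp(
        torch.arange(0, embed_dim // 2, 2) * (-math.log(10000.0) / (embed_dim // 2))
    )
    pe_h[:, 0::2] = torch.sin(pos_h * div)
    pe_h[:, 1::2] = torch.cos(pos_h * div)
    pe_w[:, 0::2] = torch.sin(pos_w * div)
    pe_w[:, 1::2] = torch.cos(pos_w * div)
    pe_2d = pe_h.unsqueeze(1) + pe_w.unsqueeze(0)  # (H, W, embed_dim // 2)
    return torch.cat([pe_2d, pe_2d], dim=-1)       # (H, W, embed_dim)

pe = build_pe(6, 8, 16)
print(tuple(pe.shape))      # (6, 8, 16)
print(pe[0, 0, 1].item())   # 2.0: cos(0) from the height axis + cos(0) from the width axis
```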

forward(Hp, Wp)

Get the positional encoding for size (Hp, Wp).

Args:

- Hp: Height after patching (≤ max_H).
- Wp: Width after patching (≤ max_W).

Returns: Tensor of shape (Hp*Wp, embed_dim) containing positional encodings flattened in row-major order (height, then width).

Source code in climanet/st_encoder_decoder.py
def forward(self, Hp, Wp):
    """Get positional encoding for size (Hp, Wp).
    Args:
        Hp: Height after patching (≤ max_H)
        Wp: Width after patching (≤ max_W)
    Returns:
        Tensor of shape (Hp*Wp, embed_dim) containing positional encodings
        flattened in row-major order (height * width).
    """
    # returns (Hp*Wp, embed_dim)
    pe_hw = self.pe[:Hp, :Wp, :].reshape(Hp * Wp, -1)
    return pe_hw

SpatialTransformer(embed_dim=128, depth=2, num_heads=4, mlp_ratio=4.0, dropout=0.0)

Bases: Module

Spatial Transformer for spatial feature mixing.

This module applies a standard Transformer encoder to a sequence of spatial tokens (patch embeddings), allowing information to be mixed across all spatial locations.

Key points: - Uses multi-head self-attention and feedforward layers. - Designed to operate on flattened spatial tokens.

Initialize the spatial transformer.

Args:

- embed_dim: Dimension of the embedding. Default is 128. Embedding dimensions are usually multiples of 64 (e.g., 64, 128, 256). This can be tuned.
- depth: Number of transformer encoder layers. Default is 2. This can be increased for more complex spatial mixing.
- num_heads: Number of attention heads in each layer. Default is 4. When embed_dim is 128, 4 heads is a common choice.
- mlp_ratio: Ratio of the feedforward hidden dimension to embed_dim. Default is 4.0.
- dropout: Dropout rate applied to the attention and feedforward layers. Default is 0.0.

Source code in climanet/st_encoder_decoder.py
def __init__(self, embed_dim=128, depth=2, num_heads=4, mlp_ratio=4.0, dropout=0.0):
    """Initialize the spatial transformer.
    Args:
        embed_dim: Dimension of the embedding. Default is 128.
            The embedding dimensions are multiples of 64 (e.g., 64, 128,
            256). This can be tuned.
        depth: Number of transformer encoder layers. Default is 2. This can be
            increased for more complex spatial mixing.
        num_heads: Number of attention heads in each layer. Default is 4.
            When embed_dim is 128, 4 heads is a common choice.
        mlp_ratio: Ratio of feedforward hidden dimension to embed_dim. Default is 4.0.
        dropout: Dropout rate applied to attention and feedforward layers. Default is 0.0.
    """
    super().__init__()

    # a single Transformer encoder block that
    # performs self-attention and feedforward processing
    encoder_layer = nn.TransformerEncoderLayer(
        d_model=embed_dim,
        nhead=num_heads,
        dim_feedforward=int(embed_dim * mlp_ratio),
        batch_first=True,
        dropout=dropout,
        activation="gelu",
    )
    # stack multiple layers to form the full encoder
    self.enc = nn.TransformerEncoder(encoder_layer, num_layers=depth)

forward(x)

Forward pass of the spatial transformer.

Args:

- x: Input tensor of shape (B, N, C), where N is the number of spatial tokens (H'*W') and C is the embedding dimension.

Returns: Tensor of shape (B, N, C) with features mixed spatially across patches.

Source code in climanet/st_encoder_decoder.py
def forward(self, x):
    """Forward pass of the spatial transformer.
    Args:
        x: Input tensor of shape (B, N, C), where N = number of spatial tokens (H'*W') and
            C = embedding dimension
    Returns:
        Tensor of shape (B, N, C) with spatially mixed features across patches
    """
    return self.enc(x)
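As the constructor shows, the spatial mixer is a stock `nn.TransformerEncoder` over flattened patch tokens. A minimal sketch with toy sizes (embed_dim=32, 256 tokens; illustrative, not the defaults) confirms the pass is shape-preserving:

```python
import torch
import torch.nn as nn

# Sketch of SpatialTransformer with toy sizes: stacked encoder layers over
# flattened spatial tokens; output shape equals input shape.
embed_dim, depth, num_heads, mlp_ratio = 32, 2, 4, 4.0
layer = nn.TransformerEncoderLayer(
    d_model=embed_dim,
    nhead=num_heads,
    dim_feedforward=int(embed_dim * mlp_ratio),
    batch_first=True,
    activation="gelu",
)
enc = nn.TransformerEncoder(layer, num_layers=depth)

x = torch.randn(2, 16 * 16, embed_dim)  # (B, Hp*Wp tokens, C)
y = enc(x)
print(tuple(y.shape))  # (2, 256, 32): shape-preserving
```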

SpatioTemporalModel(in_chans=1, embed_dim=128, patch_size=(1, 4, 4), max_days=31, max_months=12, hidden=128, overlap=1, max_H=1024, max_W=1024, spatial_depth=2, spatial_heads=4)

Bases: Module

Spatio-Temporal Model for Monthly Prediction.

Processes daily data in a video-style format with shape (B, C, T, H, W):

- B: batch size
- C: number of channels (e.g., 1 for SST; can include additional channels like masks)
- T: temporal dimension (number of days, e.g., 31 for a month)
- H: spatial height
- W: spatial width

The model pipeline:

1. Encode spatio-temporal patches using VideoEncoder.
2. Aggregate temporal information for each spatial patch via TemporalAttentionAggregator.
3. Add 2D spatial positional encodings and mix spatial features with SpatialTransformer.
4. Decode aggregated tokens into a full-resolution 2D map using MonthlyConvDecoder.

Output:

- Reconstructed monthly (SST) map of shape (B, M, H, W)

Initialize the Spatio-Temporal Model.

Args:

- in_chans: Number of input channels (e.g., 1 for SST; additional channels possible)
- embed_dim: Dimension of the patch embedding
- patch_size: Tuple of (T, H, W) patch sizes for temporal and spatial patching
- max_days: Maximum number of days for temporal positional encoding
- max_months: Maximum number of months for temporal positional encoding
- hidden: Hidden dimension used in the decoder
- overlap: Overlap for deconvolution in the decoder
- max_H: Maximum spatial height for 2D positional encoding
- max_W: Maximum spatial width for 2D positional encoding
- spatial_depth: Number of layers in the spatial Transformer
- spatial_heads: Number of attention heads in the spatial Transformer

Source code in climanet/st_encoder_decoder.py
def __init__(
    self,
    in_chans=1,
    embed_dim=128,
    patch_size=(1, 4, 4),
    max_days=31,
    max_months=12,
    hidden=128,
    overlap=1,
    max_H=1024,
    max_W=1024,
    spatial_depth=2,
    spatial_heads=4,
):
    """Initialize the Spatio-Temporal Model.

    Args:
        in_chans: Number of input channels (e.g., 1 for SST, additional channels possible)
        embed_dim: Dimension of the patch embedding
        patch_size: Tuple of (T, H, W) patch sizes for temporal and spatial patching
        max_days: Maximum number of days for temporal positional encoding
        max_months: Maximum number of months for temporal positional encoding
        hidden: Hidden dimension used in the decoder
        overlap: Overlap for deconvolution in the decoder
        max_H: Maximum spatial height for 2D positional encoding
        max_W: Maximum spatial width for 2D positional encoding
        spatial_depth: Number of layers in the spatial Transformer
        spatial_heads: Number of attention heads in the spatial Transformer
    """
    super().__init__()
    self.encoder = VideoEncoder(
        in_chans=in_chans, embed_dim=embed_dim, patch_size=patch_size
    )
    self.temporal = TemporalAttentionAggregator(
        embed_dim=embed_dim, max_days=max_days, max_months=max_months
    )
    self.spatial_pe = SpatialPositionalEncoding2D(
        embed_dim=embed_dim, max_H=max_H, max_W=max_W
    )
    self.spatial_tr = SpatialTransformer(
        embed_dim=embed_dim, depth=spatial_depth, num_heads=spatial_heads
    )
    self.decoder = MonthlyConvDecoder(
        embed_dim=embed_dim,
        patch_h=patch_size[1],
        patch_w=patch_size[2],
        hidden=hidden,
        overlap=overlap,
        num_months=max_months,
    )
    self.patch_size = patch_size

forward(daily_data, daily_mask, land_mask_patch, padded_days_mask=None)

Forward pass of the Spatio-Temporal model.

Args:

- daily_data: Tensor of shape (B, C, M, T, H, W) containing daily data, where C is the number of channels (e.g., 1 for SST)
- daily_mask: Boolean tensor of the same shape as daily_data indicating missing values
- land_mask_patch: Boolean tensor of shape (H, W) to mask land areas in the output
- padded_days_mask: Optional boolean tensor of shape (B, M, T) indicating which day tokens are padded (True for padded tokens). Used to mask out padded tokens in temporal attention.

Returns: monthly_pred: Tensor of shape (B, M, H, W) representing the reconstructed monthly map

Source code in climanet/st_encoder_decoder.py
def forward(self, daily_data, daily_mask, land_mask_patch, padded_days_mask=None):
    """Forward pass of the Spatio-Temporal model.

    Args:
        daily_data: Tensor of shape (B, C, M, T, H, W) containing daily
            data, where C is the number of channels (e.g., 1 for SST)
        daily_mask: Boolean tensor of same shape as daily_data indicating missing values
        land_mask_patch: Boolean tensor of shape (H, W) to mask land areas in the output
        padded_days_mask: Optional boolean tensor of shape (B, M, T) indicating which day tokens are padded
             (True for padded tokens). Used to mask out padded tokens in temporal attention.
    Returns:
        monthly_pred: Tensor of shape (B, M, H, W) representing the reconstructed monthly map
    """
    B, C, M, T, H, W = daily_data.shape

    Tp = T // self.patch_size[0]
    Hp = H // self.patch_size[1]
    Wp = W // self.patch_size[2]
    Np = Tp * Hp * Wp

    # check shape and patch compatibility
    assert daily_mask.shape == daily_data.shape, (
        "daily_mask must have the same shape as daily_data"
    )
    assert H % self.patch_size[1] == 0 and W % self.patch_size[2] == 0, (
        "H and W must be divisible by patch size"
    )
    assert T % self.patch_size[0] == 0, "T must be divisible by patch size"

    # Step 1: Encode spatio-temporal patches
    # each month independently by folding M into batch
    daily_data_reshaped = daily_data.reshape(B * M, C, T, H, W)
    daily_mask_reshaped = daily_mask.reshape(B * M, C, T, H, W)
    latent = self.encoder(
        daily_data_reshaped, daily_mask_reshaped
    )  # (B*M, N_patches, embed_dim)

    # Step 2: Aggregate temporal information for each spatial patch
    # latent -> (B, M*Np, embed_dim) to match the aggregator input x: (B, M*Tp*Hp*Wp, embed_dim)
    latent = latent.reshape(B, M * Np, -1)

    if padded_days_mask is not None and self.patch_size[0] > 1:
        B, M, T_days = padded_days_mask.shape
        if T_days % self.patch_size[0] != 0:
            raise ValueError(
                f"T_days={T_days} must be divisible by patch_size[0]={self.patch_size[0]}"
            )
        padded_days_mask = padded_days_mask.view(
            B, M, T_days // self.patch_size[0], self.patch_size[0]
        ).all(dim=-1)  # (B, M, Tp)

    agg_latent = self.temporal(
        latent, M, Tp, Hp, Wp, padded_days_mask=padded_days_mask
    )  # (B, M*Hp*Wp, embed_dim)

    # Step 3: Add spatial positional encodings and mix spatial features
    E = agg_latent.shape[-1]
    agg_latent = agg_latent.view(B, M, Hp * Wp, E)
    pe = (
        self.spatial_pe(Hp, Wp).to(agg_latent.device).to(agg_latent.dtype)
    )  # (Hp*Wp, E)
    x = agg_latent + pe[None, None, :, :]

    # Step 4: Spatial mixing with Transformer
    x = x.view(B * M, Hp * Wp, E)
    x = self.spatial_tr(x)  # (B*M, Hp*Wp, E)
    x = x.view(B, M * Hp * Wp, E)  # back to (B, M*Hp*Wp, E)

    # Step 5: Decode to full-resolution 2D map
    monthly_pred = self.decoder(x, M, H, W, land_mask_patch)  # (B, M, H, W)
    return monthly_pred
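The token bookkeeping through the five steps can be traced with plain arithmetic. The sizes below are illustrative (not from the source): B=2, C=1, M=12 months, T=31 days, a 64×64 grid, and patch_size=(1, 4, 4).

```python
# Shape walk-through of SpatioTemporalModel.forward with illustrative sizes.
B, C, M, T, H, W = 2, 1, 12, 31, 64, 64
pt, ph, pw = 1, 4, 4  # patch_size = (T, H, W)

Tp, Hp, Wp = T // pt, H // ph, W // pw
Np = Tp * Hp * Wp          # tokens per month out of the VideoEncoder

print((B * M, Np))         # encoder output: (24, 7936) tokens of size embed_dim
print((B, M * Hp * Wp))    # after temporal aggregation: (2, 3072)
print((B, M, H, W))        # decoder output: (2, 12, 64, 64)
```

Temporal aggregation collapses the Tp axis, which is why the token count drops from M*Tp*Hp*Wp to M*Hp*Wp before spatial mixing and decoding.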

TemporalAttentionAggregator(embed_dim=128, max_days=31, max_months=12)

Bases: Module

Temporal attention-based aggregator.

This module aggregates temporal information for each spatial patch by applying attention across the temporal dimension. It consists of two main steps:

1. Day attention: for each month, compute attention weights across the temporal tokens (days) and take a weighted sum to get one token per spatial location for each month.
2. Cross-month mixing: after temporal aggregation, apply a Transformer encoder layer to mix information across months at each spatial location.

For each spatial location, the day attention allows the model to learn which days are most important for predicting the monthly average, while the cross-month mixing allows the model to learn interactions between different months.

Initialize the temporal attention aggregator.

Args:

- embed_dim: Dimension of the embedding. The default is 128. Many vision transformers use embedding dimensions that are multiples of 64 (e.g., 64, 128, 256). This can be tuned.
- max_days: Maximum length of the temporal dimension to precompute encodings for. Default is 31, which is sufficient for a month of daily data.
- max_months: Maximum number of months (temporal patches) to precompute encodings for. Default is 12, which is sufficient for a year of monthly data.

Source code in climanet/st_encoder_decoder.py
def __init__(self, embed_dim=128, max_days=31, max_months=12):
    """Initialize the temporal attention aggregator.

    Args:
        embed_dim: Dimension of the embedding. The default is 128.
            Many vision transformers use embedding dimensions that are multiples
            of 64 (e.g., 64, 128, 256). This can be tuned.
        max_days: Maximum length of the temporal dimension to precompute
            encodings for. Default is 31, which is sufficient for a month of
            daily data.
        max_months: Maximum number of months (temporal patches) to precompute
            encodings for. Default is 12, which is sufficient for a year of
            monthly data.
    """
    super().__init__()

    # Positional encodings for days and months
    self.pos_days = TemporalPositionalEncoding(embed_dim, max_len=max_days)
    self.pos_months = TemporalPositionalEncoding(embed_dim, max_len=max_months)

    # Day scorer (within each month)
    self.day_scorer = nn.Sequential(
        nn.LayerNorm(embed_dim),  # normalizing features
        nn.Linear(embed_dim, embed_dim),  # learns temporal feature transformation
        nn.GELU(),  # adds non-linearity to capture complex temporal patterns
        nn.Linear(embed_dim, 1),  # project to a single score
    )

    # Cross month mixing
    self.month_ln = nn.LayerNorm(embed_dim)
    self.month_attn = nn.MultiheadAttention(
        embed_dim=embed_dim,
        num_heads=4,
        dropout=0.0,
        batch_first=True,
    )
    self.month_ffn = nn.Sequential(
        nn.LayerNorm(embed_dim),
        nn.Linear(
            embed_dim, 4 * embed_dim
        ),  # 4 is a common factor in transformer feedforward layers
        nn.GELU(),
        nn.Linear(4 * embed_dim, embed_dim),
    )

forward(x, M, T, H, W, padded_days_mask=None)

Args:

- x: Tensor of shape (B, M*T*H*W, C) containing spatio-temporal tokens, where C is the embedding dimension.
- M: number of months.
- T: number of temporal tokens per month after temporal patching (Tp).
- H: spatial height after spatial patching.
- W: spatial width after spatial patching.
- padded_days_mask: Optional boolean tensor of shape (B, M, T); True marks day tokens that are padded (because some months have fewer days). Padded tokens are masked out in the attention computation.

Returns: Tensor of shape (B, M*H*W, C) with one temporally aggregated token per spatial location and month, where C is the embedding dimension.

Source code in climanet/st_encoder_decoder.py
def forward(self, x, M, T, H, W, padded_days_mask=None):
    """
    Args:
        x: (B, M*T*H*W, C) containing spatio-temporal tokens, where C is the embedding dimension.
        M: number of months
        T: number of temporal tokens per month after temporal patching (Tp)
        H: spatial height after spatial patching
        W: spatial width after spatial patching
        padded_days_mask: Optional boolean tensor of shape (B, M, T); True
            marks day tokens that are padded (because some months have fewer
            days). Padded tokens are masked out in the attention computation.
    Returns:
        Tensor of shape (B, M*H*W, C) with one temporally aggregated token per
            spatial location and month, where C is the embedding dimension.
    """
    seq = rearrange(x, "b (m t h w) c -> b (h w) m t c", m=M, t=T, h=H, w=W)

    pe_days = self.pos_days(T).to(seq.device).to(seq.dtype)  # (T, C)
    pe_months = self.pos_months(M).to(seq.device).to(seq.dtype)  # (M, C)

    seq = seq + pe_days[None, None, None, :, :]  # add day PE
    seq = seq + pe_months[None, None, :, None, :]  # add month PE

    # Day attention per month
    day_logits = self.day_scorer(seq).squeeze(-1)  # (B, HW, M, T)

    # padded_days_mask is (B, M, T) true=padded, -> (B, HW, M, T)
    if padded_days_mask is not None:
        pad = padded_days_mask[:, None, :, :].expand(x.shape[0], H * W, M, T)
        day_logits = day_logits.masked_fill(pad, float("-inf"))

    day_w = torch.softmax(day_logits, dim=-1)
    month_tokens = (seq * day_w.unsqueeze(-1)).sum(dim=3)  # (B, HW, M, C)

    # Cross-month attention at each spatial location
    z = rearrange(month_tokens, "b s m c -> (b s) m c")
    z_ln = self.month_ln(z)
    attn_out, _ = self.month_attn(z_ln, z_ln, z_ln, need_weights=False)
    z = z + attn_out
    z = z + self.month_ffn(z)

    # Back to flattened tokens with month kept
    z = rearrange(z, "(b s) m c -> b s m c", b=x.shape[0], s=H * W)
    out = rearrange(z, "b (h w) m c -> b (m h w) c", h=H, w=W)
    return out  # (B, M*H*W, C)  C: embedding dimension
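The padded-day masking can be isolated in a few lines. This sketch uses toy sizes (one month, four day tokens; not from the source) and uniform logits to show that `-inf` logits give padded days exactly zero softmax weight:

```python
import torch

# Sketch of the masked day attention: padded days get -inf logits, so the
# softmax assigns them zero weight and the month token averages real days only.
B, S, M, T = 1, 2, 1, 4                 # batch, spatial locations, months, day tokens
day_logits = torch.zeros(B, S, M, T)    # uniform scores for illustration
padded = torch.tensor([[[False, False, False, True]]])  # (B, M, T): last day padded

pad = padded[:, None, :, :].expand(B, S, M, T)
day_logits = day_logits.masked_fill(pad, float("-inf"))
day_w = torch.softmax(day_logits, dim=-1)

print(day_w[0, 0, 0])  # tensor([0.3333, 0.3333, 0.3333, 0.0000])
```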

TemporalPositionalEncoding(embed_dim=128, max_len=31)

Bases: Module

Temporal Positional Encoding using sine and cosine functions.

This module generates fixed (non-learnable) sinusoidal positional encodings for the temporal dimension, following the formulation in "Attention Is All You Need" (Vaswani et al., 2017).

The returned positional encodings are intended to be added to temporal embeddings by the caller, but this module itself does not perform the addition.

Initialize the temporal positional encoding.

Args:

- embed_dim: Dimension of the embedding. The default is 128. Many vision transformers use embedding dimensions that are multiples of 64 (e.g., 64, 128, 256). This can be tuned.
- max_len: Maximum length of the temporal dimension to precompute encodings for. Default is 31, which is sufficient for a month of daily data.

Source code in climanet/st_encoder_decoder.py
def __init__(self, embed_dim=128, max_len=31):
    """Initialize the temporal positional encoding.
    Args:
        embed_dim: Dimension of the embedding. The default is 128.
            Many vision transformers use embedding dimensions that are multiples
            of 64 (e.g., 64, 128, 256). This can be tuned.
        max_len: Maximum length of the temporal dimension to precompute
            encodings for. Default is 31, which is sufficient for a month of
            daily data.
    """
    super().__init__()
    pe = torch.zeros(max_len, embed_dim)
    position = torch.arange(0, max_len).unsqueeze(1)
    div_term = torch.exp(
        torch.arange(0, embed_dim, 2) * (-math.log(10000.0) / embed_dim)
    )
    pe[:, 0::2] = torch.sin(position * div_term)
    pe[:, 1::2] = torch.cos(position * div_term)
    self.register_buffer("pe", pe)  # (max_len, embeddim)

forward(T)

Return positional encodings for a temporal sequence.

Args:

- T: Temporal length (must be ≤ max_len).

Returns: Tensor of shape (T, embed_dim) containing sinusoidal positional encodings.

Source code in climanet/st_encoder_decoder.py
def forward(self, T):
    """Return positional encodings for a temporal sequence.
    Args:
        T: Temporal length (must be <= max_len)
    Returns:
        Tensor of shape (T, embed_dim) containing sinusoidal positional encodings
    """
    return self.pe[:T]  # (T, embed_dim)
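Two properties worth checking, replicated standalone with toy sizes (embed_dim=8): position 0 encodes as alternating [0, 1, 0, 1, …] (sin terms 0, cos terms 1), and `forward(T)` simply slices the first T rows of the precomputed buffer.

```python
import math
import torch

# Standalone replica of the buffer built in TemporalPositionalEncoding.__init__.
embed_dim, max_len = 8, 31
pe = torch.zeros(max_len, embed_dim)
position = torch.arange(0, max_len).unsqueeze(1)
div_term = torch.exp(
    torch.arange(0, embed_dim, 2) * (-math.log(10000.0) / embed_dim)
)
pe[:, 0::2] = torch.sin(position * div_term)  # even channels: sin
pe[:, 1::2] = torch.cos(position * div_term)  # odd channels: cos

print(tuple(pe[:5].shape))  # forward(5) would return pe[:5] -> (5, 8)
```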

VideoEncoder(in_chans=1, embed_dim=128, patch_size=(1, 4, 4), drop=0.0)

Bases: Module

Video Encoder with spatio-temporal patch embedding.

This module converts an input video into a sequence of non-overlapping spatio-temporal patch embeddings using a 3D convolution.

Masking is handled by:

- zeroing out masked (missing) pixels
- concatenating a validity mask as an additional input channel

The convolution uses kernel size and stride equal to the patch size. The output is a sequence of patch embeddings, as used in VideoMAE: https://arxiv.org/abs/2203.12602

Args:

- in_chans: Number of input channels (1 for SST).
- embed_dim: Dimension of the patch embedding. The default is 128. Many vision transformers use embedding dimensions that are multiples of 64 (e.g., 64, 128, 256). This can be tuned.
- patch_size: Tuple of (T, H, W) patch sizes. Default is (1, 4, 4).
- drop: Probability of an element being zeroed. Default is 0.0. Increase it if the model overfits.

Source code in climanet/st_encoder_decoder.py
def __init__(self, in_chans=1, embed_dim=128, patch_size=(1, 4, 4), drop=0.0):
    """
    Args:
        in_chans: Number of input channels (1 for SST)
        embed_dim: Dimension of the patch embedding. The default is 128.
            Many vision transformers use embedding dimensions that are multiples
            of 64 (e.g., 64, 128, 256). This can be tuned.
        patch_size: Tuple of (T, H, W) patch size. Default is (1, 4, 4).
        drop: Probability of an element to be zeroed. Default is 0.0.
            Increase it if there is overfitting.
    """
    super().__init__()
    self.patch_size = patch_size

    # proj is a Conv3d with kernel and stride = patch_size to create non-overlapping patches
    # 2 * in_chans because we add a validity channel
    self.proj = nn.Conv3d(
        2 * in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
    )

    # norm is LayerNorm over the embedding dimension to normalize patch embeddings
    self.norm = nn.LayerNorm(embed_dim)

    # dropout for regularization
    self.drop = nn.Dropout(drop)

forward(x, mask)

Forward pass with masking support via an additional validity channel.

Args:

- x: Input video tensor of shape (B, C, T, H, W).
- mask: Boolean mask tensor of shape (B, C, T, H, W), where True indicates masked pixels.

Returns: Embedded patches of shape (B, N_patches, embed_dim).

Notes:

- Masked pixels are zeroed out before patch embedding.
- A validity mask (1 = observed, 0 = missing) is concatenated as an additional input channel.

Source code in climanet/st_encoder_decoder.py
def forward(self, x, mask):
    """Forward pass with masking support via an additional validity channel.
    Args:
        x: Input video tensor of shape (B, C, T, H, W)
        mask: Boolean mask tensor of shape (B, C, T, H, W), where True
            indicates masked pixels

    Returns:
        Embedded patches of shape (B, N_patches, embed_dim)

    Notes:
        - Masked pixels are zeroed out before patch embedding
        - A validity mask (1 = observed, 0 = missing) is concatenated
          as an additional input channel
    """
    # x: (B,1,T,H,W), mask: (B,1,T,H,W) where True means missing
    valid = (~mask).float()
    x = x * valid  # zero-out missing values
    x = torch.cat([x, valid], dim=1)  # add validity as a channel
    x = self.proj(x)  # (B, C, T', H', W')
    x = rearrange(x, "b c t h w -> b (t h w) c")
    x = self.norm(x)
    x = self.drop(x)
    return x  # (B, N_patches, embed_dim)
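The masking-plus-patchify path can be reproduced in a few lines. This sketch uses toy sizes (4 days, an 8×8 grid, embed_dim=16; illustrative, not the defaults) and flattens with tensor ops instead of `einops.rearrange`:

```python
import torch
import torch.nn as nn

# Minimal sketch of the VideoEncoder path: zero out missing pixels, append the
# validity mask as an extra channel, then patchify with a strided Conv3d.
B, Cin, T, H, W = 2, 1, 4, 8, 8
patch_size = (1, 4, 4)
embed_dim = 16

x = torch.randn(B, Cin, T, H, W)
mask = torch.rand(B, Cin, T, H, W) > 0.8        # True = missing
valid = (~mask).float()
x = torch.cat([x * valid, valid], dim=1)        # (B, 2*Cin, T, H, W)

proj = nn.Conv3d(2 * Cin, embed_dim, kernel_size=patch_size, stride=patch_size)
tokens = proj(x).flatten(2).transpose(1, 2)     # (B, T'*H'*W', embed_dim)
print(tuple(tokens.shape))                      # (2, 16, 16): 4*2*2 patches of dim 16
```

Because the validity channel goes through the same convolution, each patch embedding carries information about how many of its pixels were actually observed.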