st_encoder_decoder.py
Spatio-Temporal encoder-decoder for Monthly Prediction. The main model class is SpatioTemporalModel.
MonthlyConvDecoder(embed_dim=128, patch_h=4, patch_w=4, hidden=128, overlap=1, num_months=12)
Bases: Module
Decoder to reconstruct 2D maps from patch tokens.
The MonthlyConvDecoder converts latent patch tokens back to pixel space:

- Applies a 1×1 convolution to mix features on the patch grid.
- Uses a transposed convolution (deconvolution) to upsample tokens to the original spatial resolution.
- Applies a convolutional refinement block to smooth patch boundaries.
- Applies a small convolutional head to produce the final single-channel output.
- Optionally masks out land regions using a boolean mask.
Args:

- embed_dim: Dimension of the patch embedding. The default is 128. Many vision transformers use embedding dimensions that are multiples of 64 (e.g., 64, 128, 256). This can be tuned.
- patch_h: Patch height.
- patch_w: Patch width.
- hidden: Hidden dimension in the decoder for mixing channel features. The default is 128, which can be tuned.
- overlap: Overlap size for the deconvolution, which creates smooth blending between adjacent upsampled patches. Default is 1; there is no overlap at the map edges.
- num_months: Number of months. Default is 12.
Source code in climanet/st_encoder_decoder.py
forward(latent, M, out_H, out_W, land_mask=None)
Reconstruct 2D maps from latent patch tokens.

Args:

- latent: Tensor of shape (B, M*Hp*Wp, C), where C is the embedding dimension.
- M: Number of months (temporal patches).
- out_H: Target output height (must be divisible by patch_h).
- out_W: Target output width (must be divisible by patch_w).
- land_mask: Optional boolean tensor of shape (out_H, out_W). Values set to True are masked out (set to 0) in the output, so only ocean pixels remain.

Returns: Tensor of shape (B, M, out_H, out_W) representing the monthly variable, e.g. SST.
Source code in climanet/st_encoder_decoder.py
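The decoding steps above can be sketched as follows. This is a minimal illustration under stated assumptions, not the actual implementation: the class name TinyMonthlyDecoder, the exact kernel sizes, and the cropping of the overlap margin are all assumptions.

```python
import torch
import torch.nn as nn

class TinyMonthlyDecoder(nn.Module):
    """Illustrative decoder: 1x1 mix -> transposed-conv upsample -> refine -> head."""

    def __init__(self, embed_dim=128, patch_h=4, patch_w=4, hidden=128, overlap=1):
        super().__init__()
        self.patch_h, self.patch_w = patch_h, patch_w
        # 1x1 convolution mixes channels on the patch grid
        self.mix = nn.Conv2d(embed_dim, hidden, kernel_size=1)
        # transposed convolution upsamples each patch; a small kernel overlap
        # blends neighbouring patches
        self.up = nn.ConvTranspose2d(
            hidden, hidden,
            kernel_size=(patch_h + overlap, patch_w + overlap),
            stride=(patch_h, patch_w),
        )
        # 3x3 refinement smooths seams between upsampled patches
        self.refine = nn.Conv2d(hidden, hidden, kernel_size=3, padding=1)
        # 1x1 head produces the single-channel output
        self.head = nn.Conv2d(hidden, 1, kernel_size=1)

    def forward(self, latent, M, out_H, out_W, land_mask=None):
        B, N, C = latent.shape
        Hp, Wp = out_H // self.patch_h, out_W // self.patch_w
        # fold the month dim into the batch and reshape tokens to a 2D patch grid
        x = latent.view(B * M, Hp, Wp, C).permute(0, 3, 1, 2)
        x = self.up(self.mix(x))[..., :out_H, :out_W]  # crop the overlap margin
        x = x + self.refine(x)                         # smooth patch boundaries
        x = self.head(x).view(B, M, out_H, out_W)
        if land_mask is not None:
            x = x.masked_fill(land_mask, 0.0)          # zero out land pixels
        return x

dec = TinyMonthlyDecoder()
lat = torch.randn(2, 12 * 8 * 8, 128)                  # B=2, M=12, 8x8 patch grid
out = dec(lat, M=12, out_H=32, out_W=32)               # (2, 12, 32, 32)
land = torch.zeros(32, 32, dtype=torch.bool)
land[:5] = True                                        # pretend the top rows are land
out_masked = dec(lat, M=12, out_H=32, out_W=32, land_mask=land)
```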
SpatialPositionalEncoding2D(embed_dim=128, max_H=1024, max_W=1024)
Bases: Module
2D Spatial Positional Encoding using sine and cosine functions.
This module generates fixed sinusoidal positional encodings for a 2D spatial grid, following the formulation in "Attention Is All You Need" (Vaswani et al., 2017).
The returned positional encodings are intended to be added to spatial tokens by the caller. The encodings are not learnable.
Initialize the positional encoding.

Args:

- embed_dim: Dimension of the embedding; it must be even. The default is 128. Embedding dimensions are usually multiples of 64 (e.g., 64, 128, 256). This can be tuned.
- max_H: Maximum height. Default is 1024, which should be sufficient.
- max_W: Maximum width. Default is 1024, which should be sufficient.
Source code in climanet/st_encoder_decoder.py
build_pe(H, W, embed_dim)
staticmethod
Build the 2D positional encoding tensor.

Args:

- H: Height of the grid.
- W: Width of the grid.
- embed_dim: Dimension of the embedding (must be even).

Returns: Tensor of shape (H, W, embed_dim) containing fixed positional encodings, constructed by combining sine/cosine encodings along height and width. Not learnable.
Source code in climanet/st_encoder_decoder.py
forward(Hp, Wp)
Get the positional encoding for size (Hp, Wp).

Args:

- Hp: Height after patching (≤ max_H).
- Wp: Width after patching (≤ max_W).

Returns: Tensor of shape (Hp*Wp, embed_dim) containing positional encodings flattened in row-major order (height first, then width).
Source code in climanet/st_encoder_decoder.py
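A sketch of how such a 2D sinusoidal encoding can be assembled from two 1D encodings. The split-half channel layout (first half encodes the row index, second half the column index) is an assumption, and build_pe_2d / pe_1d are hypothetical names; the real build_pe may combine the axes differently.

```python
import torch

def build_pe_2d(H, W, embed_dim):
    """Fixed 2D sinusoidal PE: half the channels encode the row, half the column."""
    assert embed_dim % 2 == 0
    d = embed_dim // 2

    def pe_1d(length, dim):
        # standard "Attention Is All You Need" sine/cosine encoding
        pos = torch.arange(length, dtype=torch.float32).unsqueeze(1)   # (length, 1)
        div = torch.exp(torch.arange(0, dim, 2, dtype=torch.float32)
                        * (-torch.log(torch.tensor(10000.0)) / dim))
        pe = torch.zeros(length, dim)
        pe[:, 0::2] = torch.sin(pos * div)
        pe[:, 1::2] = torch.cos(pos * div)
        return pe

    pe_h = pe_1d(H, d).unsqueeze(1).expand(H, W, d)   # varies along height only
    pe_w = pe_1d(W, d).unsqueeze(0).expand(H, W, d)   # varies along width only
    return torch.cat([pe_h, pe_w], dim=-1)            # (H, W, embed_dim)

pe = build_pe_2d(8, 8, 128)
flat = pe.reshape(-1, 128)   # row-major (Hp*Wp, embed_dim), as returned by forward()
```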
SpatialTransformer(embed_dim=128, depth=2, num_heads=4, mlp_ratio=4.0, dropout=0.0)
Bases: Module
Spatial Transformer for spatial feature mixing.
This module applies a standard Transformer encoder to a sequence of spatial tokens (patch embeddings), allowing information to be mixed across all spatial locations.
Key points: - Uses multi-head self-attention and feedforward layers. - Designed to operate on flattened spatial tokens.
Initialize the spatial transformer.

Args:

- embed_dim: Dimension of the embedding. Default is 128. Embedding dimensions are usually multiples of 64 (e.g., 64, 128, 256). This can be tuned.
- depth: Number of transformer encoder layers. Default is 2. This can be increased for more complex spatial mixing.
- num_heads: Number of attention heads in each layer. Default is 4. When embed_dim is 128, 4 heads is a common choice.
- mlp_ratio: Ratio of the feedforward hidden dimension to embed_dim. Default is 4.0.
- dropout: Dropout rate applied to attention and feedforward layers. Default is 0.0.
Source code in climanet/st_encoder_decoder.py
forward(x)
Forward pass of the spatial transformer.

Args:

- x: Input tensor of shape (B, N, C), where N is the number of spatial tokens (H'*W') and C is the embedding dimension.

Returns: Tensor of shape (B, N, C) with spatially mixed features across patches.
Source code in climanet/st_encoder_decoder.py
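A minimal stand-in built from PyTorch's stock nn.TransformerEncoder illustrates the behavior described above. TinySpatialTransformer is a hypothetical name, and the real module may differ in normalization or activation details.

```python
import torch
import torch.nn as nn

class TinySpatialTransformer(nn.Module):
    """Standard Transformer encoder over flattened spatial tokens."""

    def __init__(self, embed_dim=128, depth=2, num_heads=4, mlp_ratio=4.0, dropout=0.0):
        super().__init__()
        layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            dim_feedforward=int(embed_dim * mlp_ratio),
            dropout=dropout,
            batch_first=True,           # accept (B, N, C) directly
        )
        self.encoder = nn.TransformerEncoder(layer, num_layers=depth)

    def forward(self, x):               # x: (B, N, C)
        return self.encoder(x)          # (B, N, C), mixed across all N tokens

st = TinySpatialTransformer()
tokens = torch.randn(2, 64, 128)        # e.g. an 8x8 patch grid, flattened
mixed = st(tokens)
```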
SpatioTemporalModel(in_chans=1, embed_dim=128, patch_size=(1, 4, 4), max_days=31, max_months=12, hidden=128, overlap=1, max_H=1024, max_W=1024, spatial_depth=2, spatial_heads=4)
Bases: Module
Spatio-Temporal Model for Monthly Prediction.
Processes daily data in a video-style format with shape (B, C, T, H, W):

- B: batch size
- C: number of channels (e.g., 1 for SST; can include additional channels like masks)
- T: temporal dimension (number of days, e.g., 31 for a month)
- H: spatial height
- W: spatial width
The model pipeline:

1. Encode spatio-temporal patches using VideoEncoder.
2. Aggregate temporal information for each spatial patch via TemporalAttentionAggregator.
3. Add 2D spatial positional encodings and mix spatial features with SpatialTransformer.
4. Decode aggregated tokens into a full-resolution 2D map using MonthlyConvDecoder.
Output:

- Reconstructed monthly (e.g. SST) map of shape (B, M, H, W)
Initialize the Spatio-Temporal Model.
Args:

- in_chans: Number of input channels (e.g., 1 for SST; additional channels are possible).
- embed_dim: Dimension of the patch embedding.
- patch_size: Tuple of (T, H, W) patch sizes for temporal and spatial patching.
- max_days: Maximum number of days for temporal positional encoding.
- max_months: Maximum number of months for temporal positional encoding.
- hidden: Hidden dimension used in the decoder.
- overlap: Overlap for the deconvolution in the decoder.
- max_H: Maximum spatial height for 2D positional encoding.
- max_W: Maximum spatial width for 2D positional encoding.
- spatial_depth: Number of layers in the spatial Transformer.
- spatial_heads: Number of attention heads in the spatial Transformer.
Source code in climanet/st_encoder_decoder.py
forward(daily_data, daily_mask, land_mask_patch, padded_days_mask=None)
Forward pass of the Spatio-Temporal model.
Args:

- daily_data: Tensor of shape (B, C, M, T, H, W) containing daily data, where C is the number of channels (e.g., 1 for SST).
- daily_mask: Boolean tensor of the same shape as daily_data indicating missing values.
- land_mask_patch: Boolean tensor of shape (H, W) used to mask land areas in the output.
- padded_days_mask: Optional boolean tensor of shape (B, M, T) indicating which day tokens are padded (True for padded tokens). Used to mask out padded tokens in temporal attention.

Returns: monthly_pred: Tensor of shape (B, M, H, W) representing the reconstructed monthly map.
Source code in climanet/st_encoder_decoder.py
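The shape flow through the four pipeline stages can be traced with toy stand-ins: mean pooling replaces the attention aggregator, the spatial-mixing stage is only marked, and a per-token linear head replaces the convolutional decoder. All dimensions are example values, not fixed by the model.

```python
import torch
import torch.nn as nn

# Shape-flow sketch of the pipeline (toy stand-ins, not the real submodules)
B, C, M, T, H, W = 2, 1, 12, 31, 32, 32
pt, ph, pw = 1, 4, 4
embed_dim = 128

daily = torch.randn(B, C, M * T, H, W)            # months folded into the time axis
patch = nn.Conv3d(C, embed_dim, kernel_size=(pt, ph, pw), stride=(pt, ph, pw))

# 1. VideoEncoder: spatio-temporal patch embedding
tok = patch(daily)                                # (B, 128, M*T, H/4, W/4)
Tp, Hp, Wp = tok.shape[2:]
tok = tok.flatten(2).transpose(1, 2)              # (B, M*T*Hp*Wp, 128)

# 2. Temporal aggregation (stand-in: mean over days instead of attention)
tok = tok.view(B, M, T, Hp * Wp, embed_dim).mean(dim=2)   # (B, M, Hp*Wp, 128)

# 3. Spatial mixing would add 2D positional encodings and run the
#    SpatialTransformer here; the token shape is unchanged
tok = tok.reshape(B, M * Hp * Wp, embed_dim)

# 4. Decoder maps tokens back to pixels (stand-in: linear per-token patch head)
head = nn.Linear(embed_dim, ph * pw)
pix = head(tok).view(B, M, Hp, Wp, ph, pw).permute(0, 1, 2, 4, 3, 5)
monthly_pred = pix.reshape(B, M, H, W)            # (B, M, H, W)
```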
TemporalAttentionAggregator(embed_dim=128, max_days=31, max_months=12)
Bases: Module
Temporal attention-based aggregator.
This module aggregates temporal information for each spatial patch by applying attention across the temporal dimension. It consists of two main steps:

1. Day attention: For each month, it computes attention weights across the temporal tokens (days) and performs a weighted sum to obtain one token per spatial location for each month.
2. Cross-month mixing: After temporal aggregation, it applies a Transformer encoder layer to mix information across months at each spatial location.
For each spatial location, the day attention allows the model to learn which days are most important for predicting the monthly average, while the cross-month mixing allows the model to learn interactions between different months.
Initialize the temporal attention aggregator.
Args:

- embed_dim: Dimension of the embedding. The default is 128. Many vision transformers use embedding dimensions that are multiples of 64 (e.g., 64, 128, 256). This can be tuned.
- max_days: Maximum length of the temporal dimension to precompute encodings for. Default is 31, which is sufficient for a month of daily data.
- max_months: Maximum number of months (temporal patches) to precompute encodings for. Default is 12, which is sufficient for a year of monthly data.
Source code in climanet/st_encoder_decoder.py
forward(x, M, T, H, W, padded_days_mask=None)
Args:

- x: Tensor of shape (B, M*T*H*W, C) containing spatio-temporal tokens, where C is the embedding dimension.
- M: Number of months.
- T: Number of temporal tokens per month after temporal patching (Tp).
- H: Spatial height after spatial patching.
- W: Spatial width after spatial patching.
- padded_days_mask: Optional boolean tensor of shape (B, M, T), where True indicates day tokens that are padded (because some months have fewer days). This is used to mask out padded tokens in the attention computation.

Returns: Tensor of shape (B, M*H*W, C) with one temporally aggregated token per spatial location and month, where C is the embedding dimension.
Source code in climanet/st_encoder_decoder.py
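Step 1 (day attention) with padded-day masking can be sketched as follows. TinyDayAttention and the single-logit scoring are illustrative assumptions, and the cross-month mixing step is omitted here.

```python
import torch
import torch.nn as nn

class TinyDayAttention(nn.Module):
    """Sketch of day attention: score each day token, softmax over days, weighted sum."""

    def __init__(self, embed_dim=128):
        super().__init__()
        self.score = nn.Linear(embed_dim, 1)    # one attention logit per day token

    def forward(self, x, M, T, H, W, padded_days_mask=None):
        B, N, C = x.shape                       # N = M*T*H*W
        x = x.view(B, M, T, H * W, C)
        logits = self.score(x).squeeze(-1)      # (B, M, T, H*W)
        if padded_days_mask is not None:        # (B, M, T): True = padded day
            logits = logits.masked_fill(
                padded_days_mask.unsqueeze(-1), float("-inf"))
        w = logits.softmax(dim=2)               # attend over the day axis
        agg = (w.unsqueeze(-1) * x).sum(dim=2)  # (B, M, H*W, C)
        return agg.reshape(B, M * H * W, C)     # one token per location per month

day_att = TinyDayAttention()
x = torch.randn(2, 12 * 31 * 64, 128)           # B=2, M=12, T=31, 8x8 patch grid
mask = torch.zeros(2, 12, 31, dtype=torch.bool)
mask[:, :, 28:] = True                          # pretend the last three days are padding
out = day_att(x, M=12, T=31, H=8, W=8, padded_days_mask=mask)
```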
TemporalPositionalEncoding(embed_dim=128, max_len=31)
Bases: Module
Temporal Positional Encoding using sine and cosine functions.
This module generates fixed (non-learnable) sinusoidal positional encodings for the temporal dimension, following the formulation in "Attention Is All You Need" (Vaswani et al., 2017).
The returned positional encodings are intended to be added to temporal embeddings by the caller, but this module itself does not perform the addition.
Initialize the temporal positional encoding.

Args:

- embed_dim: Dimension of the embedding. The default is 128. Many vision transformers use embedding dimensions that are multiples of 64 (e.g., 64, 128, 256). This can be tuned.
- max_len: Maximum length of the temporal dimension to precompute encodings for. Default is 31, which is sufficient for a month of daily data.
Source code in climanet/st_encoder_decoder.py
forward(T)
Return positional encodings for a temporal sequence.

Args:

- T: Temporal length (must be <= max_len).

Returns: Tensor of shape (T, embed_dim) containing sinusoidal positional encodings.
Source code in climanet/st_encoder_decoder.py
VideoEncoder(in_chans=1, embed_dim=128, patch_size=(1, 4, 4), drop=0.0)
Bases: Module
Video Encoder with spatio-temporal patch embedding.
This module converts an input video into a sequence of non-overlapping spatio-temporal patch embeddings using a 3D convolution.
Masking is handled by:

- zeroing out masked (missing) pixels
- concatenating a validity mask as an additional input channel
The convolution uses kernel size and stride equal to the patch size. The output is a sequence of patch embeddings, as used in VideoMAE: https://arxiv.org/abs/2203.12602
Args:

- in_chans: Number of input channels (e.g., 1 for SST).
- embed_dim: Dimension of the patch embedding. The default is 128. Many vision transformers use embedding dimensions that are multiples of 64 (e.g., 64, 128, 256). This can be tuned.
- patch_size: Tuple of (T, H, W) patch sizes. Default is (1, 4, 4).
- drop: Probability of an element being zeroed by dropout. Default is 0.0. Increase it if there is overfitting.
Source code in climanet/st_encoder_decoder.py
forward(x, mask)
Forward pass with masking support via an additional validity channel.

Args:

- x: Input video tensor of shape (B, C, T, H, W).
- mask: Boolean mask tensor of shape (B, C, T, H, W), where True indicates masked pixels.

Returns: Embedded patches of shape (B, N_patches, embed_dim).

Notes:

- Masked pixels are zeroed out before patch embedding.
- A validity mask (1 = observed, 0 = missing) is concatenated as an additional input channel.
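A minimal sketch of this masking-plus-patch-embedding scheme. TinyVideoEncoder is a hypothetical name, and the sketch assumes the validity channel is derived from the first channel of the mask.

```python
import torch
import torch.nn as nn

class TinyVideoEncoder(nn.Module):
    """Patch embedding via 3D conv, with the validity mask as an extra channel."""

    def __init__(self, in_chans=1, embed_dim=128, patch_size=(1, 4, 4), drop=0.0):
        super().__init__()
        # +1 input channel for the validity mask (1 = observed, 0 = missing)
        self.proj = nn.Conv3d(in_chans + 1, embed_dim,
                              kernel_size=patch_size, stride=patch_size)
        self.drop = nn.Dropout(drop)

    def forward(self, x, mask):
        x = x.masked_fill(mask, 0.0)                  # zero out missing pixels
        valid = (~mask).float()[:, :1]                # validity channel from the mask
        x = torch.cat([x, valid], dim=1)              # (B, C+1, T, H, W)
        x = self.proj(x)                              # (B, E, T', H', W')
        return self.drop(x.flatten(2).transpose(1, 2))  # (B, N_patches, E)

enc = TinyVideoEncoder()
video = torch.randn(2, 1, 31, 32, 32)
mask = torch.rand(2, 1, 31, 32, 32) > 0.7             # ~30% of pixels missing
tokens = enc(video, mask)                             # (2, 31*8*8, 128)
```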