Models API Reference¶
Factory Function¶
create_model¶
Create a pansharpening model by name.
Parameters:
| Parameter | Type | Default | Description |
|---|---|---|---|
| `model_name` | `str` | required | Model identifier |
| `ms_bands` | `int` | 4 | Number of multispectral bands |
| `**kwargs` | `dict` | {} | Additional model-specific arguments |
Returns: PyTorch nn.Module
Example:
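This page does not show the package's import path or the accepted model identifiers, so the sketch below does not import the library. It illustrates, with placeholder classes, how a name-based factory with this signature could dispatch and forward `**kwargs`. The `_REGISTRY` dictionary, the `_Placeholder` base class, and the lowercase keys `"pnn"` and `"pannet"` are assumptions made for illustration, not the library's internals.

```python
from typing import Any, Callable, Dict


class _Placeholder:
    """Stand-in for a real nn.Module subclass; it only records its arguments."""

    def __init__(self, ms_bands: int = 4, **kwargs: Any) -> None:
        self.ms_bands = ms_bands
        self.kwargs = kwargs


class PNN(_Placeholder):
    pass


class PanNet(_Placeholder):
    pass


# Hypothetical registry: the identifiers the real create_model accepts may differ.
_REGISTRY: Dict[str, Callable[..., _Placeholder]] = {
    "pnn": PNN,
    "pannet": PanNet,
}


def create_model(model_name: str, ms_bands: int = 4, **kwargs: Any) -> _Placeholder:
    """Look up model_name and forward ms_bands plus any extra keyword arguments."""
    try:
        model_cls = _REGISTRY[model_name.lower()]
    except KeyError:
        known = ", ".join(sorted(_REGISTRY))
        raise ValueError(f"Unknown model '{model_name}'. Known models: {known}") from None
    return model_cls(ms_bands=ms_bands, **kwargs)


# Model-specific arguments such as n_blocks travel through **kwargs.
model = create_model("pannet", ms_bands=8, n_blocks=4)
print(type(model).__name__, model.ms_bands, model.kwargs)
```

With the real library, the call would presumably look the same, with the returned object being an `nn.Module` rather than a placeholder.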
CNN Models¶
PNN¶
Basic 3-layer CNN for pansharpening.
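No signature is listed for `PNN` here, so the following is only a rough sketch of what a 3-layer pansharpening CNN can look like. The class name `PNNSketch`, the kernel sizes (9, 5, 5), and the channel widths (64, 32) follow the commonly cited PNN design and are assumptions; the package's own `PNN` may differ.

```python
import torch
import torch.nn as nn


class PNNSketch(nn.Module):
    """Illustrative 3-layer CNN; not the package's PNN implementation."""

    def __init__(self, ms_bands: int = 4) -> None:
        super().__init__()
        self.net = nn.Sequential(
            # Input: the ms bands plus the single pan band, stacked on the channel axis.
            nn.Conv2d(ms_bands + 1, 64, kernel_size=9, padding=4),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 32, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            # Output: one channel per multispectral band, same spatial size as the input.
            nn.Conv2d(32, ms_bands, kernel_size=5, padding=2),
        )

    def forward(self, ms: torch.Tensor, pan: torch.Tensor) -> torch.Tensor:
        return self.net(torch.cat([ms, pan], dim=1))


model = PNNSketch(ms_bands=4)
out = model(torch.rand(2, 4, 32, 32), torch.rand(2, 1, 32, 32))
print(out.shape)
```

The padding on each layer is chosen so the spatial size is preserved, which keeps the output aligned with the `(B, C, H, W)` input.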
PanNet¶
class PanNet(nn.Module):
    def __init__(
        self,
        ms_bands: int = 4,
        n_blocks: int = 4,
        n_features: int = 64
    )
ResNet-style pansharpening with high-pass filtering.
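The high-pass filtering step can be sketched as subtracting a local mean from the image, which leaves only edges and fine detail. The helper name `high_pass`, the box (average) filter, and the 5x5 kernel are assumptions for illustration; the filter this package actually uses is not documented on this page.

```python
import torch
import torch.nn.functional as F


def high_pass(x: torch.Tensor, kernel_size: int = 5) -> torch.Tensor:
    """Return x minus its local mean (a box low-pass), keeping the input shape."""
    if kernel_size % 2 == 0:
        raise ValueError("kernel_size must be odd so the output keeps the input size")
    low = F.avg_pool2d(
        x,
        kernel_size=kernel_size,
        stride=1,
        padding=kernel_size // 2,
        count_include_pad=False,  # average only over real pixels at the borders
    )
    return x - low


pan = torch.rand(2, 1, 32, 32)
pan_hp = high_pass(pan)
print(pan_hp.shape)
```

A constant image has no high-frequency content, so `high_pass` maps it to (numerically) zero, which is a quick sanity check for any replacement filter.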
DRPNN¶
class DRPNN(nn.Module):
    def __init__(
        self,
        ms_bands: int = 4,
        n_blocks: int = 8,
        n_features: int = 64
    )
Deep residual pansharpening network.
PanNetCBAM¶
class PanNetCBAM(nn.Module):
    def __init__(
        self,
        ms_bands: int = 4,
        n_blocks: int = 4,
        n_features: int = 64,
        reduction: int = 16
    )
PanNet with CBAM attention modules.
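To show what the `reduction` argument plausibly controls, here is a minimal sketch of the channel-attention half of a CBAM block, where a shared bottleneck of `n_features // reduction` channels scores each feature map. The class name `ChannelAttentionSketch` is hypothetical, and the spatial-attention half of CBAM is omitted; the package's module may be organised differently.

```python
import torch
import torch.nn as nn


class ChannelAttentionSketch(nn.Module):
    """CBAM-style channel attention; illustrative only."""

    def __init__(self, n_features: int = 64, reduction: int = 16) -> None:
        super().__init__()
        hidden = max(n_features // reduction, 1)  # bottleneck width, 4 for the defaults
        # Shared MLP, written as 1x1 convolutions so it works on (B, C, 1, 1) tensors.
        self.mlp = nn.Sequential(
            nn.Conv2d(n_features, hidden, kernel_size=1, bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, n_features, kernel_size=1, bias=False),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        avg_desc = self.mlp(x.mean(dim=(2, 3), keepdim=True))  # global average pooling
        max_desc = self.mlp(x.amax(dim=(2, 3), keepdim=True))  # global max pooling
        weights = torch.sigmoid(avg_desc + max_desc)  # one weight in (0, 1) per channel
        return x * weights


features = torch.rand(2, 64, 16, 16)
attended = ChannelAttentionSketch(n_features=64, reduction=16)(features)
print(attended.shape)
```

A larger `reduction` shrinks the bottleneck and the parameter count, at the cost of a coarser channel weighting.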
MultiScalePanNet¶
Multi-scale feature pyramid network.
Transformer Models¶
PanFormer¶
class PanFormer(nn.Module):
    def __init__(
        self,
        ms_bands: int = 4,
        embed_dim: int = 128,
        depth: int = 4,
        num_heads: int = 8,
        patch_size: int = 4,
        mlp_ratio: float = 4.0
    )
Full transformer with cross-attention fusion.
PanFormerLite¶
class PanFormerLite(nn.Module):
    def __init__(
        self,
        ms_bands: int = 4,
        embed_dim: int = 64,
        depth: int = 2,
        num_heads: int = 4,
        window_size: int = 8,
        patch_size: int = 4
    )
Lightweight transformer with window attention.
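A back-of-the-envelope calculation helps explain why window attention is the lighter option. The sketch below counts query-key pairs per attention layer, assuming non-overlapping patches of `patch_size` pixels and non-overlapping windows of `window_size` tokens per side; both assumptions, and the helper name `attention_pairs`, are illustrative and not taken from the package.

```python
from typing import Optional


def attention_pairs(
    height: int,
    width: int,
    patch_size: int = 4,
    window_size: Optional[int] = None,
) -> int:
    """Count query-key pairs for one attention layer (global if window_size is None)."""
    if height % patch_size or width % patch_size:
        raise ValueError("image size must be divisible by patch_size")
    grid_h, grid_w = height // patch_size, width // patch_size

    if window_size is None:
        tokens = grid_h * grid_w
        return tokens * tokens  # every token attends to every token

    if grid_h % window_size or grid_w % window_size:
        raise ValueError("token grid must be divisible by window_size")
    n_windows = (grid_h // window_size) * (grid_w // window_size)
    tokens_per_window = window_size * window_size
    return n_windows * tokens_per_window * tokens_per_window  # attention stays inside a window


global_pairs = attention_pairs(256, 256, patch_size=4)
window_pairs = attention_pairs(256, 256, patch_size=4, window_size=8)
print(global_pairs, window_pairs, global_pairs // window_pairs)
```

Under these assumptions a 256x256 input yields a 64x64 token grid, and restricting attention to 8x8 windows cuts the number of pairs by a factor of 64.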
Forward Pass¶
All models share the same forward signature and take the two input tensors described below.
Parameters:
| Parameter | Shape | Description |
|---|---|---|
| `ms` | (B, C, H, W) | Multispectral image |
| `pan` | (B, 1, H, W) | Panchromatic image |
Returns: Tensor of shape (B, C, H, W)
Example:
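Since the package import path is not shown here, the sketch below uses a tiny stand-in module, `StandInModel` (hypothetical), that follows the documented `(ms, pan)` contract. It also assumes the multispectral input starts at a lower resolution and is upsampled to the panchromatic grid before the call, because the table above lists `ms` and `pan` with the same `H` and `W`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class StandInModel(nn.Module):
    """Hypothetical placeholder with the documented forward(ms, pan) contract."""

    def __init__(self, ms_bands: int = 4) -> None:
        super().__init__()
        self.fuse = nn.Conv2d(ms_bands + 1, ms_bands, kernel_size=3, padding=1)

    def forward(self, ms: torch.Tensor, pan: torch.Tensor) -> torch.Tensor:
        # Predict a residual from the stacked inputs and add it to the ms image.
        return ms + self.fuse(torch.cat([ms, pan], dim=1))


ms_lr = torch.rand(2, 4, 16, 16)  # low-resolution multispectral batch (assumed 4x coarser)
pan = torch.rand(2, 1, 64, 64)    # panchromatic batch at full resolution

# Bring ms onto the pan grid so both inputs share H and W.
ms = F.interpolate(ms_lr, size=pan.shape[-2:], mode="bicubic", align_corners=False)

model = StandInModel(ms_bands=4).eval()
with torch.no_grad():  # inference only, no gradients needed
    fused = model(ms, pan)

print(fused.shape)
```

Any model documented on this page should be usable in place of `StandInModel`, provided `ms_bands` matches the channel count of `ms`.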