# Inference Examples
## Basic Inference
import torch
from models import create_model
# Load model
model = create_model('panformer_lite', ms_bands=4)
# map_location='cpu' lets a checkpoint that was saved on a GPU load on a
# CPU-only machine; the model can be moved to another device afterwards.
# NOTE: torch.load unpickles the file, so only load checkpoints you trust.
state_dict = torch.load('checkpoints/panformer_lite_best.pth', map_location='cpu')
model.load_state_dict(state_dict)
model.eval()  # evaluation behaviour for layers such as dropout / batch norm

# Prepare inputs: random placeholders with a batch of 1; ms has 4 channels
# (matching ms_bands=4) and pan has a single channel.
ms = torch.randn(1, 4, 256, 256)
pan = torch.randn(1, 1, 256, 256)

# Run inference without building the autograd graph
with torch.no_grad():
    fused = model(ms, pan)
print(f"Output shape: {fused.shape}")
## GPU Inference
# Use the GPU when one is available, otherwise stay on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# The model and both inputs must live on the same device.
model = model.to(device)
ms, pan = ms.to(device), pan.to(device)

# Gradients are not needed for inference.
with torch.no_grad():
    fused = model(ms, pan)
## Inference with GeoTIFF
import rasterio
from utils.data_utils import load_geotiff, save_geotiff
# Load images
ms, ms_profile = load_geotiff('data/ms.tif')
pan, pan_profile = load_geotiff('data/pan.tif')

# Convert to float32 tensors with a leading batch dimension, placed on the
# same device as the model (it may have been moved to the GPU earlier).
# Casting in NumPy first avoids torch.from_numpy rejecting dtypes it does not
# support (e.g. uint16 rasters on older torch releases).
# NOTE(review): assumes load_geotiff returns (bands, H, W) arrays, including
# a band axis for the single-band pan image - TODO confirm.
device = next(model.parameters()).device
ms_tensor = torch.from_numpy(ms.astype('float32', copy=False)).unsqueeze(0).to(device)
pan_tensor = torch.from_numpy(pan.astype('float32', copy=False)).unsqueeze(0).to(device)

# Run model
with torch.no_grad():
    fused = model(ms_tensor, pan_tensor)

# Save with geospatial metadata
# .cpu() is required before .numpy() when the model ran on the GPU; it is a
# no-op for tensors that are already on the CPU.
fused_np = fused.squeeze(0).cpu().numpy()
# NOTE(review): pan_profile describes the single-band pan raster while
# fused_np has the multispectral band count - verify that save_geotiff
# updates the band count (and dtype) in the profile before writing.
save_geotiff('results/fused.tif', fused_np, pan_profile)
## Batch Inference
# Process multiple patches
# NOTE(review): assumes patches_ms and patches_pan are aligned, equally long
# and sliceable along their first axis (e.g. stacked tensors) - TODO confirm.
batch_size = 8
results = []
# One no_grad context around the whole loop; gradients are never needed here.
with torch.no_grad():
    # Iterate over patches_ms (the original referenced an undefined `patches`).
    for i in range(0, len(patches_ms), batch_size):
        batch_ms = patches_ms[i:i+batch_size]
        batch_pan = patches_pan[i:i+batch_size]
        batch_fused = model(batch_ms, batch_pan)
        results.append(batch_fused)
# Stitch the per-batch outputs back together along the batch axis.
# torch.cat needs at least one tensor, so there must be at least one patch.
fused = torch.cat(results, dim=0)
## Classic Methods Inference
Classic methods require no training and can be applied directly:
from methods.classic import brovey, ihs, sfim, gram_schmidt, hpf
# Run every classic method on the same image pair, in a fixed order.
# Note the argument order here is (pan, ms), the reverse of model(ms, pan).
fused_brovey, fused_ihs, fused_sfim, fused_gs, fused_hpf = (
    method(pan, ms)
    for method in (brovey, ihs, sfim, gram_schmidt, hpf)
)
## Evaluate Results
from utils.metrics import calculate_metrics
# Compare the fused output against the reference image; the result is indexed
# by the metric keys 'psnr', 'ssim', 'sam' and 'ergas'.
metrics = calculate_metrics(fused, ground_truth)

# Report one metric per line.
print(
    f"PSNR: {metrics['psnr']:.2f} dB",
    f"SSIM: {metrics['ssim']:.4f}",
    f"SAM: {metrics['sam']:.4f}",
    f"ERGAS: {metrics['ergas']:.4f}",
    sep="\n",
)
## Visualization
import matplotlib.pyplot as plt
def _to_hwc(image, bands=slice(0, 3)):
    """Return the selected bands of a tensor as an (H, W, C) CPU tensor.

    Accepts a (C, H, W) tensor or a batched (N, C, H, W) tensor; for a batch
    only the first sample is used. Detaching and moving to the CPU lets
    matplotlib convert the result even when the tensor lives on the GPU.
    """
    image = image.detach().cpu()
    if image.dim() == 4:
        image = image[0]
    return image[bands].permute(1, 2, 0)


def _to_gray(image):
    """Return a 2-D (H, W) CPU tensor: first sample / first channel of image."""
    image = image.detach().cpu()
    while image.dim() > 2:
        image = image[0]
    return image


# NOTE(review): the first three bands are displayed as RGB - confirm the band
# order of the data. imshow clips float RGB values to [0, 1], so normalize
# the images beforehand if they use another range.
fig, axes = plt.subplots(1, 4, figsize=(16, 4))
axes[0].imshow(_to_hwc(ms))
axes[0].set_title('MS (RGB)')
axes[1].imshow(_to_gray(pan), cmap='gray')
axes[1].set_title('PAN')
axes[2].imshow(_to_hwc(fused))
axes[2].set_title('Fused')
axes[3].imshow(_to_hwc(ground_truth))
axes[3].set_title('Ground Truth')
plt.tight_layout()
plt.savefig('results/comparison.png', dpi=150)
plt.show()