Chl-a Prediction using EMIT Data
This notebook demonstrates how to predict Chlorophyll-a (Chl-a) concentrations using EMIT hyperspectral imagery and a Mixture-of-Experts Variational Autoencoder (MoE-VAE), including key steps such as:
- ✅ Model Training and Testing
- ✅ Inference
- ✅ Visualization
In [ ]:
Copied!
# Uncomment the next line to install HyperCoast with all optional dependencies.
# %pip install "hypercoast[all]"
In [ ]:
Copied!
import torch
import numpy as np
import os
from hypercoast import download_file
from hypercoast.emit_utils import *
import torch
import numpy as np
import os
from hypercoast import download_file
from hypercoast.emit_utils import *
First, download the sample data from Google Drive. The file is about 3.0 GB, so the download may take a while; please be patient. The downloaded zip file will be saved as EMIT-sample-data.zip in the current working directory and automatically unzipped into the data folder.
In [ ]:
Copied!
# Google Drive share link for the sample-data archive used throughout this notebook.
url = "https://drive.google.com/file/d/1q80PhA_vrLgIjRHuyxunVoWXmVUNgggV/view"
In [ ]:
Copied!
# Fetch the sample-data archive once; a repeated call would download it again.
download_file(url, output="EMIT-sample-data.zip")
In [ ]:
Copied!
# === Device ===
# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# === Parameters ===
# Band wavelengths used as model inputs (41 values, 403-700; presumably nm).
# The same list is passed to the data loaders and to the EMIT preprocessing step.
selected_bands = [
    403, 411, 418, 425, 433, 440, 448, 455, 463, 470,
    477, 485, 492, 500, 507, 515, 522, 530, 537, 544,
    552, 559, 567, 574, 582, 589, 597, 604, 611, 619,
    626, 634, 641, 649, 656, 664, 671, 679, 686, 693,
    700,
]

# Training spreadsheet and the external test spreadsheets.
excel_path = "data/Gloria_updated_07242025.xlsx"
test_files = [
    "data/GreatLake_all_data.xlsx",
    "data/GOA_insitu_data_09052025.xlsx",
    "data/satellite_for_EMIT_09052025.xlsx",
]

# EMIT scene used for inference (L2W) and its companion file used for the RGB backdrop (L2R).
nc_path = "data/ISS_EMIT_2024_09_29_17_42_42_L2W.nc"
rgb_path = "data/ISS_EMIT_2024_09_29_17_42_42_L2R.nc"

# Output directory for results and figures. A single-argument os.path.join()
# returns its argument unchanged, so the literal path is used directly.
save_dir = "./EMIT/test_Chla"
os.makedirs(save_dir, exist_ok=True)
In [ ]:
Copied!
# === Training dataset ===
# Load the training spreadsheet once and split it into train/test loaders.
# `input_dim` / `output_dim` are later used to size the model; `test_ids` is
# written out alongside the test-set predictions.
(
    train_real_dl,
    test_real_dl,
    input_dim,
    output_dim,
    train_ids,
    test_ids,
) = load_real_data(
    excel_path=excel_path,
    selected_bands=selected_bands,
    seed=42,  # fixed seed; presumably makes the split reproducible — confirm
    diff_before_norm=False,
    diff_after_norm=False,
    target_parameter="chl-a",
    lower_quantile=0,  # 0 and 1 presumably keep the full target range — confirm
    upper_quantile=1,
    log_offset=1,  # the evaluation and inference cells pass the same log_offset
)
In [ ]:
Copied!
# === External test datasets ===
# Build one loader per external spreadsheet, keeping the sample ids and dates
# in parallel lists so they stay aligned with `test_files`.
test_dls, test_ids_list, test_dates_list = [], [], []
for file in test_files:
    # The second and third return values are not needed here.
    dl, _, _, ids, dates = load_real_test(
        excel_path=file,
        selected_bands=selected_bands,
        diff_before_norm=False,
        diff_after_norm=False,
        max_allowed_diff=1.0,
        target_parameter="chl-a",
        log_offset=1,
    )
    test_dls.append(dl)
    test_ids_list.append(ids)
    test_dates_list.append(dates)
In [ ]:
Copied!
# === EMIT image preprocess ===
# Prepare the EMIT scene for inference. The returned loader feeds the model,
# while Rrs, mask, latitude and longitude are passed on to the inference cell.
test_loader, Rrs, mask, latitude, longitude = preprocess_emit_data_minmax(
    nc_path=nc_path,
    diff_before_norm=False,
    diff_after_norm=False,
    # Same band list as used for training, as a NumPy array.
    full_band_wavelengths=np.array(selected_bands),
)
In [ ]:
Copied!
# Define the Mixture-of-Experts Variational Autoencoder (MoE-VAE) model.
# Built once and moved to the selected device; input/output sizes come from
# the training data loader cell.
model = MoE_VAE(
    input_dim=input_dim,
    output_dim=output_dim,
    latent_dim=16,
    encoder_hidden_dims=[128, 64, 32],
    decoder_hidden_dims=[32, 64, 128],  # mirror of the encoder widths
    activation="leakyrelu",
    use_norm="layer",
    use_dropout=False,
    use_softplus_output=True,
    num_experts=4,
    k=2,  # presumably the number of experts selected per sample — confirm
    noisy_gating=True,
).to(device)
In [ ]:
Copied!
# Define the optimizer and train the MoE-VAE model.
# This must run exactly once: a repeated run would continue training the same
# model for another 400 epochs with a freshly initialised optimizer.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
train_log = train(
    model,
    train_real_dl,
    device,
    epochs=400,
    optimizer=optimizer,
    save_dir=save_dir,
)
In [ ]:
Copied!
# Evaluate the model on the held-out test split, compute metrics, then save and plot the results.
predictions, actuals = evaluate(model, test_real_dl, device, log_offset=1)
epsilon, beta, rmse, rmsle, mape, bias, mae = calculate_metrics(predictions, actuals)
test_loss = mae  # MAE is the value kept as the test loss

# Per-sample results spreadsheet, keyed by the test-split sample ids.
save_results_to_excel(
    test_ids,
    actuals,
    predictions,
    os.path.join(save_dir, "test_results.xlsx"),
)

# Predicted-vs-actual density plot, written under save_dir.
plot_results_with_density(
    predictions,
    actuals,
    save_dir,
    mode="test_results",
    xlim=(-2, 4),
    ylim=(-2, 4),
)
In [ ]:
Copied!
# Evaluate on the external test datasets.
# The four sequences were built in the same order, so zip() pairs each loader
# with its ids, dates and source spreadsheet.
for dl, ids, dates, path in zip(test_dls, test_ids_list, test_dates_list, test_files):
    preds, acts = evaluate(model, dl, device, log_offset=1)
    save_results_from_excel_for_test(preds, acts, ids, dates, path, save_dir)

    # Derive a unique mode name from the spreadsheet file name (without extension).
    mode_name = os.path.splitext(os.path.basename(path))[0]

    # Save an individual plot for this dataset.
    plot_results(
        preds,
        acts,
        save_dir,
        threshold=10,
        mode=f"test_{mode_name}",
        xlim=(-2, 4),
        ylim=(-2, 4),
    )
In [ ]:
Copied!
# Perform model inference on the EMIT scene and generate a spatial map of
# predicted Chlorophyll-a concentration. The result is kept in `Output` for
# the GeoTIFF export cell.
Output = infer_and_visualize_single_model_minmax(
    model=model,
    test_loader=test_loader,
    Rrs=Rrs,
    mask=mask,
    latitude=latitude,
    longitude=longitude,
    save_folder=save_dir,
    rgb_nc_file=rgb_path,
    structure_name="Chl-a",
    log_offset=1,  # same offset as used when loading the training data
    vmin=0,
    vmax=30,
)
In [ ]:
Copied!
# Save predicted Chlorophyll-a as a GeoTIFF and visualize its spatial distribution.
# NOTE(review): `rasterio` and `plt` are not imported explicitly in the import
# cell; presumably they come from `from hypercoast.emit_utils import *` — confirm.
tif_path = os.path.join(save_dir, "chla.tif")
npy_to_tif(npy_input=Output, out_tif=tif_path, resolution_m=30)

# Read what is needed while the dataset is open; the file is closed on exit.
with rasterio.open(tif_path) as src:
    img = src.read(1)  # first band
    transform = src.transform  # not used below; kept in case later cells need it
    bounds = src.bounds

# Replace negative values with NaN before plotting.
img_masked = np.where(img < 0, np.nan, img)

# Axis extent in the raster's coordinates: [left, right, bottom, top].
extent = [bounds.left, bounds.right, bounds.bottom, bounds.top]

plt.figure(figsize=(8, 6))
im = plt.imshow(
    img_masked,
    cmap="jet",
    vmin=0,
    vmax=30,
    extent=extent,
    origin="upper",
)
plt.colorbar(im, label="Chl-a")
plt.title("Chl-a")
plt.show()