Mixture of Experts Variational Autoencoder (MoE-VAE) for aCDOM440
This example demonstrates how to use the Mixture of Experts Variational Autoencoder (MoE-VAE) to predict the absorption coefficient of colored dissolved organic matter at 440 nm (aCDOM440) using PACE data.
In [ ]:
# %pip install "hypercoast[all]"
Import libraries
In [ ]:
import os
import pickle
import numpy as np
import pandas as pd
import rasterio
import torch
In [ ]:
from hypercoast import download_file
from hypercoast.emit_utils.plot_and_save import plot_results_with_density
from hypercoast.moe_vae import (
    load_real_data_Robust,
    load_real_test_Robust,
    calculate_metrics,
    plot_results,
    save_results_to_excel,
    save_results_from_excel_for_test,
    preprocess_pace_data_Robust,
    infer_and_visualize_single_model_Robust,
    npy_to_tif,
    MoE_VAE,
    train,
    evaluate,
)
Download data
In [ ]:
url = "https://github.com/opengeos/datasets/releases/download/hypercoast/pace_moe_data.zip"
In [ ]:
download_file(url, quiet=False)
Set data paths
In [ ]:
nc_path = "./data/PACE_OCI.20240929T185124.L2.OC_AOP.V3_0.nc"
pace_rgb_path = "./data/snapshot-2024-08-10T00_00_00Z.tif"
wavelength_filepath = "./data/pace_wavelengths.csv"
# === Dataset paths ===
excel_path_train = "./data/Gloria_updated_07242025.xlsx"
test_files = [
    "./data/GOA_insitu_data_07242025updated.xlsx",
]
base_save_dir = "./test"
save_dir = os.path.join(base_save_dir, "aCDOM440")
os.makedirs(save_dir, exist_ok=True)
Set PACE wavelengths for aCDOM440
The wavelength selection for aCDOM440 differs from chl-a, focusing on 150 bands between 400 and 720 nm.
In [ ]:
wv_PACE = pd.read_csv(wavelength_filepath)["wavelength"].tolist()
selected_bands = wv_PACE
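The CSV shipped with the sample data is expected to already contain the 150 bands used for aCDOM440. If you instead start from a full PACE OCI wavelength list, a minimal sketch of restricting it to the 400–720 nm range (the cutoffs mirror the description above; the variable names are only for illustration) could look like this:

# Hypothetical subsetting sketch: keep only bands between 400 and 720 nm
all_wavelengths = pd.read_csv(wavelength_filepath)["wavelength"].tolist()
selected_bands = [wv for wv in all_wavelengths if 400 <= wv <= 720]
print(f"Selected {len(selected_bands)} bands between 400 and 720 nm")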
Read RGB image
In [ ]:
with rasterio.open(pace_rgb_path) as ds:
    R, G, B = ds.read(1), ds.read(2), ds.read(3)
    extent = [ds.bounds.left, ds.bounds.right, ds.bounds.bottom, ds.bounds.top]
rgb_image = np.stack((R, G, B), axis=-1)
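If you want a quick look at the RGB composite before running inference, a minimal matplotlib sketch (assuming the band values are already display-ready, e.g. 8-bit) is:

import matplotlib.pyplot as plt

plt.figure(figsize=(8, 6))
plt.imshow(rgb_image, extent=extent)
plt.title("PACE RGB quicklook")
plt.show()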
Load training data
Robust scaling is used for aCDOM440 prediction, which reduces the influence of outliers in the in situ data.
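As background, robust scaling centers each feature on its median and divides by its interquartile range, so extreme values have less influence than with standard mean/std scaling. The sketch below uses scikit-learn's RobustScaler purely to illustrate the idea; the HyperCoast loader below returns its own fitted scalers:

from sklearn.preprocessing import RobustScaler

# Toy example: scale a small array that contains one outlier
x = np.array([[0.1], [0.2], [0.3], [5.0]])
scaler = RobustScaler()  # (x - median) / IQR
x_scaled = scaler.fit_transform(x)
print(x_scaled.ravel())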
In [ ]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
In [ ]:
(
    train_real_dl,
    test_real_dl,
    input_dim,
    output_dim,
    train_ids,
    test_ids,
    scalers_Rrs_real,
    scalers_dict,
) = load_real_data_Robust(
    excel_path=excel_path_train,
    selected_bands=selected_bands,
    target_parameter="aCDOM440",
)
# Save scalers for later use
torch.save(scalers_dict, os.path.join(save_dir, "scaler.pt"))
with open(os.path.join(save_dir, "scalers_Rrs_real.pkl"), "wb") as f:
    pickle.dump(scalers_Rrs_real, f)
Load test data
In [ ]:
test_dls, test_ids_list, test_dates_list = [], [], []
for file in test_files:
    dl, _, _, ids, dates = load_real_test_Robust(
        excel_path=file,
        selected_bands=selected_bands,
        scaler_Rrs=scalers_Rrs_real,
        scalers_dict=scalers_dict,
        target_parameter="aCDOM440",
    )
    test_dls.append(dl)
    test_ids_list.append(ids)
    test_dates_list.append(dates)
Load PACE data
In [ ]:
test_loader, Rrs, mask, latitude, longitude = preprocess_pace_data_Robust(
    nc_path=nc_path,
    scaler_Rrs=scalers_Rrs_real,
    use_diff=False,
    full_band_wavelengths=np.array(selected_bands),
)
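As a quick sanity check, you can inspect the arrays returned above before running inference (this assumes Rrs, mask, latitude, and longitude are array-like, which matches how they are used below):

# Print array shapes to confirm the PACE scene loaded as expected
print("Rrs:", np.asarray(Rrs).shape)
print("mask:", np.asarray(mask).shape)
print("lat/lon:", np.asarray(latitude).shape, np.asarray(longitude).shape)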
Initialize model
The model architecture for aCDOM440 uses larger hidden dimensions than the chl-a model.
In [ ]:
model = MoE_VAE(
    input_dim=input_dim,
    output_dim=output_dim,
    latent_dim=32,
    encoder_hidden_dims=[256, 128, 64],
    decoder_hidden_dims=[64, 128, 256],
    activation="leakyrelu",
    use_norm="layer",
    use_dropout=False,
    use_softplus_output=False,
    num_experts=4,
    k=2,
    noisy_gating=True,
).to(device)
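For readers unfamiliar with the mixture-of-experts idea behind num_experts=4 and k=2, the sketch below shows a generic top-k gating layer in plain PyTorch. It is not HyperCoast's MoE_VAE implementation, just a minimal illustration of how a gate can route each sample to its top-k experts and combine their outputs with softmax weights:

import torch
import torch.nn as nn

class TinyTopKMoE(nn.Module):
    """Illustrative top-k mixture of experts (not the HyperCoast implementation)."""

    def __init__(self, in_dim, out_dim, num_experts=4, k=2):
        super().__init__()
        self.experts = nn.ModuleList(
            [nn.Linear(in_dim, out_dim) for _ in range(num_experts)]
        )
        self.gate = nn.Linear(in_dim, num_experts)
        self.k = k

    def forward(self, x):
        scores = self.gate(x)                                # (batch, num_experts)
        topk_scores, topk_idx = scores.topk(self.k, dim=-1)  # keep the k best experts
        weights = torch.softmax(topk_scores, dim=-1)         # (batch, k)
        expert_out = torch.stack([e(x) for e in self.experts], dim=1)  # (batch, E, out)
        chosen = torch.gather(
            expert_out,
            1,
            topk_idx.unsqueeze(-1).expand(-1, -1, expert_out.size(-1)),
        )                                                     # (batch, k, out)
        return (weights.unsqueeze(-1) * chosen).sum(dim=1)   # weighted combination

# Quick check with random data
moe = TinyTopKMoE(in_dim=150, out_dim=1)
print(moe(torch.randn(8, 150)).shape)  # torch.Size([8, 1])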
Model training
In [ ]:
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
train_log = train(
    model,
    train_real_dl,
    device,
    epochs=400,
    optimizer=optimizer,
    save_dir=save_dir,
)
best_train_loss = train_log["best_loss"]
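The train function receives save_dir, so it presumably stores its own checkpoints there. If you also want a standalone copy of the final weights under a name of your choosing (the filename below is arbitrary), a plain PyTorch sketch is:

# Save and reload the final model weights (filename chosen for this example)
weights_path = os.path.join(save_dir, "moe_vae_acdom440_final.pt")
torch.save(model.state_dict(), weights_path)
model.load_state_dict(torch.load(weights_path, map_location=device))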
Model evaluation
In [ ]:
predictions, actuals = evaluate(
    model, test_real_dl, device, TSS_scalers_dict=scalers_dict
)
epsilon, beta, rmse, rmsle, mape, bias, mae = calculate_metrics(predictions, actuals)
test_loss = mae
save_results_to_excel(
    test_ids, actuals, predictions, os.path.join(save_dir, "test_results.xlsx")
)
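calculate_metrics returns the individual error statistics unpacked above, so a simple way to report them together (purely a formatting convenience, assuming each value is a scalar) is:

# Summarize the held-out metrics in one printout
metrics = {
    "epsilon": epsilon,
    "beta": beta,
    "RMSE": rmse,
    "RMSLE": rmsle,
    "MAPE": mape,
    "bias": bias,
    "MAE": mae,
}
for name, value in metrics.items():
    print(f"{name:>8s}: {value:.4f}")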
In [ ]:
plot_results_with_density(
    predictions, actuals, save_dir, mode="test_results", xlim=(-4, 2), ylim=(-4, 2)
)
In [ ]:
plot_results(predictions, actuals, save_dir, mode="test_results")
In [ ]:
for dl, ids, dates, path in zip(test_dls, test_ids_list, test_dates_list, test_files):
    preds, acts = evaluate(model, dl, device, TSS_scalers_dict=scalers_dict)
    save_results_from_excel_for_test(preds, acts, ids, dates, path, save_dir)
Model inference
In [ ]:
Output = infer_and_visualize_single_model_Robust(
    model=model,
    test_loader=test_loader,
    Rrs=Rrs,
    mask=mask,
    latitude=latitude,
    longitude=longitude,
    save_folder=save_dir,
    extent=extent,
    rgb_image=rgb_image,
    structure_name="aCDOM440",
    TSS_scalers_dict=scalers_dict,
    vmin=0,
    vmax=5,
)
print(
    f"✅ Finished training, train loss: {best_train_loss:.4f}, test loss: {test_loss:.4f}"
)
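If you want to keep the raw prediction grid alongside the GeoTIFF produced in the next step (assuming Output is a NumPy array, which is how npy_to_tif consumes it below), you can save it directly:

# Persist the raw aCDOM440 prediction array for later reuse
np.save(os.path.join(save_dir, "aCDOM440_predictions.npy"), Output)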
Save as GeoTIFF (optional)
In [ ]:
import matplotlib.pyplot as plt

tif_path = os.path.join(save_dir, "aCDOM440.tif")
npy_to_tif(npy_input=Output, out_tif=tif_path, resolution_m=1000)

with rasterio.open(tif_path) as src:
    img = src.read(1)
    transform = src.transform
    bounds = src.bounds

img_masked = np.where(img < 0, np.nan, img)
extent = [bounds.left, bounds.right, bounds.bottom, bounds.top]

plt.figure(figsize=(8, 6))
im = plt.imshow(img_masked, cmap="jet", vmin=0, vmax=5, extent=extent, origin="upper")
plt.colorbar(im, label="aCDOM440")
plt.title("aCDOM440")
plt.show()
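To view the exported GeoTIFF interactively, one option is leafmap, which HyperCoast builds on. This is a minimal sketch; check the leafmap documentation for the add_raster options available in your version:

import leafmap

m = leafmap.Map()
m.add_raster(tif_path, colormap="jet", vmin=0, vmax=5, layer_name="aCDOM440")
m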