MoE-VAE
Mixture of Experts Variational Autoencoder (MoE-VAE)¶
This example demonstrates how to use the Mixture of Experts Variational Autoencoder (MoE-VAE) to predict chlorophyll-a (chl-a) using PACE data.
In [ ]:
Copied!
# Optional setup: uncomment the next line to install HyperCoast with its
# "all" extras into the kernel's environment before running the notebook.
# %pip install "hypercoast[all]"
Import libraries¶
In [ ]:
Copied!
import os
import numpy as np
import pandas as pd
import rasterio
import torch
import os
import numpy as np
import pandas as pd
import rasterio
import torch
In [ ]:
Copied!
from hypercoast import download_file
from hypercoast.moe_vae import (
    load_real_data,
    load_real_test,
    calculate_metrics,
    plot_results,
    save_results_to_excel,
    save_and_plot_results_from_excel,
    preprocess_pace_data_minmax,
    infer_and_visualize_single_model_minmax,
    MoE_VAE,
    train,
    evaluate,
)
from hypercoast import download_file
from hypercoast.moe_vae import (
    load_real_data,
    load_real_test,
    calculate_metrics,
    plot_results,
    save_results_to_excel,
    save_and_plot_results_from_excel,
    preprocess_pace_data_minmax,
    infer_and_visualize_single_model_minmax,
    MoE_VAE,
    train,
    evaluate,
)
Download data¶
In [ ]:
Copied!
# Zip archive (GitHub release asset) that is handed to download_file() in the
# next cell.
url = "https://github.com/opengeos/datasets/releases/download/hypercoast/pace_moe_data.zip"
In [ ]:
Copied!
# Fetch the data archive exactly once per run.
# NOTE(review): the paths defined in the next section expect the archive's
# contents under ./data — confirm that download_file() extracts it there.
download_file(url, quiet=False)
Set data paths¶
In [ ]:
Copied!
# === PACE scene inputs ===
# NetCDF granule passed to preprocess_pace_data_minmax() further down.
nc_path = "./data/PACE_OCI.20240929T185124.L2.OC_AOP.V3_0.nc"
# GeoTIFF whose bands 1-3 are read with rasterio to build the RGB image.
# NOTE(review): this file name is dated 2024-08-10 while the granule above is
# dated 2024-09-29 — confirm the pairing is intentional.
pace_rgb_path = "./data/snapshot-2024-08-10T00_00_00Z.tif"
# CSV with a "wavelength" column listing the PACE wavelengths.
wavelength_filepath = "./data/pace_wavelengths.csv"

# === Dataset paths ===
# Workbook used to build the train / held-out split.
excel_path_train = "./data/Gloria_updated_07242025.xlsx"
# Additional workbooks that are each evaluated separately after training.
test_files = [
    "./data/GreatLake_all_data.xlsx",
    "./data/GOA_insitu_data_07242025updated.xlsx",
    "./data/satellite_for_PACE.xlsx",
    "./data/satellite_for_PACE_LE.xlsx",
]

# Output directory for the training, evaluation and inference cells.
base_save_dir = "./test"
Set PACE wavelengths¶
In [ ]:
Copied!
# Read the "wavelength" column of the wavelength CSV into a plain Python list.
wv_PACE = pd.read_csv(wavelength_filepath)["wavelength"].tolist()
In [ ]:
Copied!
# Use the full PACE wavelength list as the band selection (no sub-setting).
# This is the same list object as wv_PACE, not a copy.
selected_bands = wv_PACE
Read data¶
In [ ]:
Copied!
# Open the RGB GeoTIFF once and pull out what the inference cell needs.
with rasterio.open(pace_rgb_path) as ds:
    # Bands 1-3 are read as the red, green and blue channels.
    R, G, B = ds.read(1), ds.read(2), ds.read(3)
    # Bounding box ordered (left, right, bottom, top).
    extent = [ds.bounds.left, ds.bounds.right, ds.bounds.bottom, ds.bounds.top]
    # Stack the three bands along a new last axis -> (rows, cols, 3).
    rgb_image = np.stack((R, G, B), axis=-1)

# Make sure the output directory exists before anything is written to it.
os.makedirs(base_save_dir, exist_ok=True)
Load training data¶
In [ ]:
Copied!
# Prefer the first CUDA GPU when one is available; otherwise run on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
In [ ]:
Copied!
# Load the training workbook once and unpack the six values returned by
# load_real_data(): train / held-out loaders, the model input and output
# dimensions, and the sample ids of each split (meanings taken from the
# variable names — see the hypercoast.moe_vae docs for the exact contract).
train_real_dl, test_real_dl, input_dim, output_dim, train_ids, test_ids = (
    load_real_data(
        excel_path=excel_path_train,
        selected_bands=selected_bands,
        seed=42,
        diff_before_norm=False,
        diff_after_norm=False,
        target_parameter="chl-a",
        # Quantiles 0 and 1 span the whole distribution, so presumably no
        # target values are trimmed — TODO confirm against load_real_data.
        lower_quantile=0,
        upper_quantile=1,
        # The same log_offset=1 is passed to load_real_test(), evaluate() and
        # the inference call later in this notebook; keep them in sync.
        log_offset=1,
    )
)
Load test data¶
In [ ]:
Copied!
# Build one loader (plus ids and dates) per external test workbook. The three
# lists stay index-aligned with test_files so they can be zipped together in
# the evaluation cell.
test_dls, test_ids_list, test_dates_list = [], [], []
for file in test_files:
    # The second and third return values are not needed here and are discarded.
    dl, _, _, ids, dates = load_real_test(
        excel_path=file,
        selected_bands=selected_bands,
        diff_before_norm=False,
        diff_after_norm=False,
        max_allowed_diff=1.0,
        target_parameter="chl-a",
        # Must match the log_offset used when loading the training data.
        log_offset=1,
    )
    test_dls.append(dl)
    test_ids_list.append(ids)
    test_dates_list.append(dates)
Load PACE data¶
In [ ]:
Copied!
# Preprocess the PACE granule once. The five returned values are passed
# unchanged to infer_and_visualize_single_model_minmax() in the inference cell.
test_loader, Rrs, mask, latitude, longitude = preprocess_pace_data_minmax(
    nc_path=nc_path,
    # Same differencing flags as used for the training and test workbooks.
    diff_before_norm=False,
    diff_after_norm=False,
    # The wavelength list is converted to a NumPy array for this call.
    full_band_wavelengths=np.array(selected_bands),
)
Initialize model¶
In [ ]:
Copied!
# Seed PyTorch's global RNG so that model construction and the training that
# follows are repeatable under Restart & Run All. The value mirrors the
# seed=42 already used for the data split.
torch.manual_seed(42)

# Build the MoE-VAE once, sized from the dimensions reported by
# load_real_data(), and move it to the selected device.
model = MoE_VAE(
    input_dim=input_dim,
    output_dim=output_dim,
    latent_dim=32,
    encoder_hidden_dims=[64, 64],
    decoder_hidden_dims=[64, 64],
    activation="leakyrelu",
    use_norm="layer",
    use_dropout=False,
    use_softplus_output=True,
    # num_experts / k: presumably 4 experts with top-2 routing per sample —
    # TODO confirm against the MoE_VAE documentation.
    num_experts=4,
    k=2,
    noisy_gating=True,
).to(device)
Model training¶
In [ ]:
Copied!
# Train exactly once. Repeating this cell body would create a fresh optimizer
# and run another 400 epochs on the already-trained model, overwriting
# train_log and best_train_loss.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
train_log = train(
    model,
    train_real_dl,
    device,
    epochs=400,
    optimizer=optimizer,
    save_dir=base_save_dir,
)

# Kept for the summary line printed at the end of the notebook.
best_train_loss = train_log["best_loss"]
Model evaluation¶
In [ ]:
Copied!
# --- Held-out split of the training workbook ---
predictions, actuals = evaluate(model, test_real_dl, device, log_offset=1)
epsilon, beta, rmse, rmsle, mape, bias, mae = calculate_metrics(predictions, actuals)
# The MAE of the held-out split is what the final summary reports as "test loss".
test_loss = mae
save_results_to_excel(
    test_ids, actuals, predictions, os.path.join(base_save_dir, "test_results.xlsx")
)
plot_results(predictions, actuals, base_save_dir, mode="test_results")

# --- External test workbooks ---
# The three lists were filled in the same order as test_files, so zip() pairs
# each loader with its ids, dates and source path.
for dl, ids, dates, path in zip(test_dls, test_ids_list, test_dates_list, test_files):
    preds, acts = evaluate(model, dl, device, log_offset=1)
    save_and_plot_results_from_excel(preds, acts, ids, dates, path, base_save_dir)
Model inference¶
In [ ]:
Copied!
# Run the trained model over the preprocessed PACE scene once, passing the RGB
# image and its extent along with the outputs of preprocess_pace_data_minmax().
infer_and_visualize_single_model_minmax(
    model=model,
    test_loader=test_loader,
    Rrs=Rrs,
    mask=mask,
    latitude=latitude,
    longitude=longitude,
    save_folder=base_save_dir,
    extent=extent,
    rgb_image=rgb_image,
    # Matches the acquisition date in the granule file name (2024-09-29);
    # presumably used to label the outputs — TODO confirm.
    structure_name="09292024",
    run=1,
    # Presumably the colour-scale limits of the rendered map — TODO confirm.
    vmin=0,
    vmax=30,
    # Same offset as used for data loading and evaluation.
    log_offset=1,
)

# Final summary: best training loss and the held-out MAE ("test loss").
print(
    f"✅ Finished training, train loss: {best_train_loss:.4f}, test loss: {test_loss:.4f}"
)