MoE-VAE
Mixture of Experts Variational Autoencoder (MoE-VAE)
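This example demonstrates how to use the Mixture of Experts Variational Autoencoder (MoE-VAE) to estimate chlorophyll-a (chl-a) concentration from PACE remote-sensing reflectance (Rrs) data. The model is trained on in-situ match-up data, evaluated on several independent test sets, and then applied to a PACE OCI Level-2 scene to produce a chl-a map.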
In [ ]:
Copied!
# %pip install "hypercoast[all]"
Import libraries
In [ ]:
Copied!
import os
import numpy as np
import pandas as pd
import rasterio
import torch
import os
import numpy as np
import pandas as pd
import rasterio
import torch
In [ ]:
Copied!
from hypercoast import download_file
from hypercoast.moe_vae import (
load_real_data,
load_real_test,
calculate_metrics,
plot_results,
save_results_to_excel,
save_and_plot_results_from_excel,
preprocess_pace_data_minmax,
infer_and_visualize_single_model_minmax,
MoE_VAE,
train,
evaluate,
)
from hypercoast import download_file
from hypercoast.moe_vae import (
load_real_data,
load_real_test,
calculate_metrics,
plot_results,
save_results_to_excel,
save_and_plot_results_from_excel,
preprocess_pace_data_minmax,
infer_and_visualize_single_model_minmax,
MoE_VAE,
train,
evaluate,
)
Download data
In [ ]:
Copied!
# Tutorial data archive. The paths set below assume its contents end up
# under ./data — presumably where download_file extracts it; confirm.
url = "https://github.com/opengeos/datasets/releases/download/hypercoast/pace_moe_data.zip"
In [ ]:
Copied!
# Download the tutorial data archive once; quiet=False presumably keeps the
# helper's progress output visible.
download_file(url, quiet=False)
Set data paths
In [ ]:
Copied!
# PACE OCI Level-2 ocean-colour (OC_AOP) granule that the trained model is
# applied to in the inference step.
nc_path = "./data/PACE_OCI.20240929T185124.L2.OC_AOP.V3_0.nc"
# RGB GeoTIFF passed to the inference/visualisation step (presumably as the
# background image of the chl-a map).
# NOTE(review): the snapshot is dated 2024-08-10 while the granule is from
# 2024-09-29 — confirm this background image is the intended one.
pace_rgb_path = "./data/snapshot-2024-08-10T00_00_00Z.tif"
# CSV whose "wavelength" column lists the PACE bands used as model inputs.
wavelength_filepath = "./data/pace_wavelengths.csv"

# === Dataset paths ===
# Spreadsheet used for training (and its held-out split).
excel_path_train = "./data/Gloria_updated_07242025.xlsx"
# Independent test spreadsheets; each one is loaded and evaluated separately.
test_files = [
    "./data/GreatLake_all_data.xlsx",
    "./data/GOA_insitu_data_07242025updated.xlsx",
    "./data/satellite_for_PACE.xlsx",
    "./data/satellite_for_PACE_LE.xlsx",
]
# Output folder handed to the training, evaluation and inference helpers.
base_save_dir = "./test"
Set PACE wavelengths
In [ ]:
Copied!
# List of PACE wavelengths, taken from the "wavelength" column of the CSV.
wv_PACE = pd.read_csv(wavelength_filepath)["wavelength"].tolist()
# wv_PACE  # uncomment to display the wavelength list
In [ ]:
Copied!
# Use every PACE wavelength as a model input. The same list is passed to the
# training, test and PACE preprocessing loaders so their inputs line up;
# presumably it can be subset to train on fewer bands.
selected_bands = wv_PACE
Read the RGB background image
In [ ]:
Copied!
# Read bands 1-3 of the RGB snapshot and its bounding box. The extent list is
# ordered [left, right, bottom, top], matching matplotlib imshow's `extent`
# convention.
with rasterio.open(pace_rgb_path) as ds:
    R, G, B = ds.read(1), ds.read(2), ds.read(3)
    extent = [ds.bounds.left, ds.bounds.right, ds.bounds.bottom, ds.bounds.top]
# Stack the three 2-D bands into a (rows, cols, 3) image.
rgb_image = np.stack((R, G, B), axis=-1)

# Make sure the output folder exists before any helper writes into it.
os.makedirs(base_save_dir, exist_ok=True)
Load training data
In [ ]:
Copied!
# Run on the first CUDA GPU when one is available, otherwise on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
In [ ]:
Copied!
# Load the training spreadsheet and split it into a training loader and a
# held-out loader. Also returns the input/output sizes the model needs and the
# sample IDs of each split.
#   - seed=42: presumably fixes the split so runs are reproducible.
#   - lower_quantile=0 / upper_quantile=1: presumably keeps the full chl-a
#     range (no quantile-based trimming) — confirm in hypercoast.moe_vae.
#   - log_offset=1: presumably the constant added before log-transforming
#     chl-a; the same value is passed to the evaluate/inference calls later.
train_real_dl, test_real_dl, input_dim, output_dim, train_ids, test_ids = (
    load_real_data(
        excel_path=excel_path_train,
        selected_bands=selected_bands,
        seed=42,
        diff_before_norm=False,
        diff_after_norm=False,
        target_parameter="chl-a",
        lower_quantile=0,
        upper_quantile=1,
        log_offset=1,
    )
)
Load test data
In [ ]:
Copied!
# Build one DataLoader per independent test spreadsheet. Each file's sample
# IDs and dates are kept in parallel lists so that results can later be
# saved per file (see the evaluation cell).
test_dls, test_ids_list, test_dates_list = [], [], []
for file in test_files:
    # The two values discarded here are not used in this tutorial.
    dl, _, _, ids, dates = load_real_test(
        excel_path=file,
        selected_bands=selected_bands,
        diff_before_norm=False,
        diff_after_norm=False,
        max_allowed_diff=1.0,
        target_parameter="chl-a",
        log_offset=1,
    )
    test_dls.append(dl)
    test_ids_list.append(ids)
    test_dates_list.append(dates)
Load PACE data
In [ ]:
Copied!
# Convert the PACE L2 granule into model inputs. Besides the DataLoader, this
# returns the Rrs array, a mask (presumably of valid pixels) and the
# latitude/longitude arrays. The inference step uses these to place
# predictions back on the scene. Passing the same selected_bands as in
# training keeps the band layout consistent with the training data.
test_loader, Rrs, mask, latitude, longitude = preprocess_pace_data_minmax(
    nc_path=nc_path,
    diff_before_norm=False,
    diff_after_norm=False,
    full_band_wavelengths=np.array(selected_bands),
)
Initialize model
In [ ]:
Copied!
# Build the MoE-VAE once and move it to the selected device. Input/output
# sizes come from the training data loader.
#
# Presumed meaning of the settings (confirm in hypercoast.moe_vae.MoE_VAE):
#   - latent_dim=32 with two 64-unit hidden layers in both encoder and decoder.
#   - LeakyReLU activations with layer normalisation and no dropout.
#   - use_softplus_output=True keeps the predicted values positive.
#   - k=2 of the num_experts=4 experts are used per sample, with a noisy gate.
model = MoE_VAE(
    input_dim=input_dim,
    output_dim=output_dim,
    latent_dim=32,
    encoder_hidden_dims=[64, 64],
    decoder_hidden_dims=[64, 64],
    activation="leakyrelu",
    use_norm="layer",
    use_dropout=False,
    use_softplus_output=True,
    num_experts=4,
    k=2,
    noisy_gating=True,
).to(device)
Model training
In [ ]:
Copied!
# Train once with Adam (lr=1e-3) for 400 epochs. train() receives save_dir,
# so it presumably writes its checkpoints/logs there. The returned log's
# "best_loss" is reported in the summary printed after inference.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
train_log = train(
    model,
    train_real_dl,
    device,
    epochs=400,
    optimizer=optimizer,
    save_dir=base_save_dir,
)
best_train_loss = train_log["best_loss"]
Model evaluation
In [ ]:
Copied!
# Score the model on the held-out split of the training spreadsheet.
# log_offset matches the value used when the data were loaded.
predictions, actuals = evaluate(model, test_real_dl, device, log_offset=1)
epsilon, beta, rmse, rmsle, mape, bias, mae = calculate_metrics(predictions, actuals)
# MAE is what the final summary reports as the "test loss".
test_loss = mae
save_results_to_excel(
    test_ids, actuals, predictions, os.path.join(base_save_dir, "test_results.xlsx")
)
plot_results(predictions, actuals, base_save_dir, mode="test_results")

# Evaluate each independent test spreadsheet and save/plot its results,
# pairing every loader with the IDs, dates and source path of its file.
for dl, ids, dates, path in zip(test_dls, test_ids_list, test_dates_list, test_files):
    preds, acts = evaluate(model, dl, device, log_offset=1)
    save_and_plot_results_from_excel(preds, acts, ids, dates, path, base_save_dir)
Model inference
In [ ]:
Copied!
# Apply the trained model to the preprocessed PACE granule and visualise the
# chl-a prediction. The mask and latitude/longitude place predictions on the
# scene, and the RGB image/extent provide geographic context. Presumed
# meanings (confirm in hypercoast.moe_vae):
#   - vmin/vmax are the colour-scale limits of the map.
#   - structure_name/run only label the files written to save_folder.
infer_and_visualize_single_model_minmax(
    model=model,
    test_loader=test_loader,
    Rrs=Rrs,
    mask=mask,
    latitude=latitude,
    longitude=longitude,
    save_folder=base_save_dir,
    extent=extent,
    rgb_image=rgb_image,
    structure_name="09292024",
    run=1,
    vmin=0,
    vmax=30,
    log_offset=1,
)
# Summary: best training loss from train() and the held-out MAE.
print(
    f"✅ Finished training, train loss: {best_train_loss:.4f}, test loss: {test_loss:.4f}"
)