Skip to content

model_inference module

Model inference utilities for MoE-VAE.

This module provides functions for preprocessing and inference using MoE-VAE models.

infer_and_visualize_single_model_Robust(model, test_loader, Rrs, mask, latitude, longitude, save_folder, extent, rgb_image, structure_name, run, TSS_scalers_dict, vmin=0, vmax=50)

Infer and visualize results from a single MoE-VAE model.

Parameters:

Name Type Description Default
model torch.nn.Module

MoE-VAE model.

required
test_loader DataLoader

DataLoader for test data.

required
Rrs array

Rrs data.

required
mask array

Boolean mask indicating valid pixels.

required
latitude array

Latitude values.

required
longitude array

Longitude values.

required
save_folder str

Path to save the results.

required
extent tuple

Tuple containing the extent of the image.

required
rgb_image array

RGB image.

required
structure_name str

Name of the structure.

required
run int

Run number.

required
TSS_scalers_dict dict

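Dictionary of fitted TSS scalers. It must contain the keys "robust" and "log"; their inverse_transform methods are applied in that order to the model predictions.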

required
vmin float

Minimum value for the colorbar.

0
vmax float

Maximum value for the colorbar.

50
Source code in hypercoast/moe_vae/model_inference.py
def infer_and_visualize_single_model_Robust(
    model,
    test_loader,
    Rrs,
    mask,
    latitude,
    longitude,
    save_folder,
    extent,
    rgb_image,
    structure_name,
    run,
    TSS_scalers_dict,
    vmin=0,
    vmax=50,
):
    """
    Infer and visualize results from a single MoE-VAE model.

    Runs ``model`` over ``test_loader`` and undoes the target scaling
    (``TSS_scalers_dict["robust"]`` then ``TSS_scalers_dict["log"]``). It then
    scatters the predictions onto the valid pixels of the (H, W) scene and
    writes two files:

    * ``<save_folder>/<base_name>.npy``: an (H*W, 3) array of
      ``[lat, lon, value]`` rows. Invalid pixels hold NaN.
    * ``<save_folder>/<base_name>.png``: the values resampled to a ~1 km grid
      and drawn over ``rgb_image``.

    Args:
        model (torch.nn.Module): MoE-VAE model. Its output dict must contain
            ``"pred_y"``.
        test_loader (DataLoader): DataLoader for test data. The first element
            of each batch is the model input. Samples must follow the order of
            the ``True`` entries of ``mask`` (row-major).
        Rrs (array): Rrs data. Only its first two dimensions (H, W) are used.
        mask (array): Boolean (H, W) mask indicating valid pixels.
        latitude (array): Latitude values, either per pixel (H, W) or a 1-D
            axis of length H.
        longitude (array): Longitude values, either per pixel (H, W) or a 1-D
            axis of length W.
        save_folder (str): Path to save the results. Created if missing.
        extent (tuple): ``(lon_min, lon_max, lat_min, lat_max)`` of the image,
            also passed to ``imshow``.
        rgb_image (array): RGB background image. It is divided by 255, so
            8-bit values are expected.
        structure_name (str): Name of the structure. Its basename without
            extension names the output files.
        run (int): Run number, shown in the plot title.
        TSS_scalers_dict (dict): Scalers used at training time. Must contain
            ``"robust"`` and ``"log"``.
        vmin (float): Minimum value for the colorbar.
        vmax (float): Maximum value for the colorbar.
    """
    device = next(model.parameters()).device
    batch_predictions = []
    with torch.no_grad():
        for batch in test_loader:
            inputs = batch[0].to(device)
            predictions = model(inputs)["pred_y"]

            # === Inverse transform using TSS_scalers_dict from training ===
            predictions_log = TSS_scalers_dict["robust"].inverse_transform(
                torch.tensor(predictions.cpu().numpy(), dtype=torch.float32)
            )
            batch_predictions.append(
                TSS_scalers_dict["log"].inverse_transform(predictions_log).numpy()
            )

    # np.vstack raises on an empty list (loader with no samples).
    predictions_all = (
        np.vstack(batch_predictions).squeeze(-1) if batch_predictions else np.empty(0)
    )
    outputs = np.full((Rrs.shape[0], Rrs.shape[1]), np.nan)
    outputs[mask] = predictions_all

    # Coordinates may be per-pixel (H, W) grids or 1-D axes (lat: H, lon: W).
    # Expand 1-D axes to the image grid so every pixel gets its own (lat, lon).
    # Replace masked coordinates by NaN rather than their raw fill values.
    lat_grid = np.ma.filled(latitude, np.nan)
    lon_grid = np.ma.filled(longitude, np.nan)
    if (
        lat_grid.ndim == 1
        and lon_grid.ndim == 1
        and outputs.shape == (lat_grid.size, lon_grid.size)
    ):
        lon_grid, lat_grid = np.meshgrid(lon_grid, lat_grid)
    final_output = np.column_stack(
        (lat_grid.ravel(), lon_grid.ravel(), outputs.ravel())
    )

    os.makedirs(save_folder, exist_ok=True)
    base_name = os.path.splitext(os.path.basename(structure_name))[0]
    npy_path = os.path.join(save_folder, f"{base_name}.npy")
    png_path = os.path.join(save_folder, f"{base_name}.png")
    np.save(npy_path, final_output)

    lat_col, lon_col, tss_values = final_output.T
    # Triangulation cannot use points without finite coordinates.
    has_xy = np.isfinite(lat_col) & np.isfinite(lon_col)

    # Regular grid of ~1 km spacing (111 km per degree), rows north -> south.
    mean_lat = (extent[2] + extent[3]) / 2
    resolution_deg_lat = 1000 / 111000
    resolution_deg_lon = 1000 / (111000 * np.cos(np.radians(mean_lat)))
    grid_lon = np.arange(extent[0], extent[1], resolution_deg_lon)
    grid_lat = np.arange(extent[3], extent[2], -resolution_deg_lat)
    grid_lon, grid_lat = np.meshgrid(grid_lon, grid_lat)
    tss_resampled = griddata(
        (lon_col[has_xy], lat_col[has_xy]),
        tss_values[has_xy],
        (grid_lon, grid_lat),
        method="linear",
    )
    tss_resampled = np.ma.masked_invalid(tss_resampled)

    fig = plt.figure(figsize=(24, 6))
    try:
        plt.imshow(rgb_image / 255.0, extent=extent, origin="upper")
        im = plt.imshow(
            tss_resampled,
            extent=extent,
            cmap="jet",
            alpha=1,
            origin="upper",
            vmin=vmin,
            vmax=vmax,
        )
        cbar = plt.colorbar(im)
        cbar.set_label("(mg m$^{-3}$)", fontsize=16)
        plt.title(f"{structure_name} - Run {run}", loc="left", fontsize=20)
        plt.savefig(png_path, dpi=300, bbox_inches="tight", pad_inches=0.1)
    finally:
        # Release the figure even if plotting fails.
        plt.close(fig)

infer_and_visualize_single_model_minmax(model, test_loader, Rrs, mask, latitude, longitude, save_folder, extent, rgb_image, structure_name, run, vmin=0, vmax=50, log_offset=0.01)

Infer and visualize results from a single MoE-VAE model.

Parameters:

Name Type Description Default
model torch.nn.Module

MoE-VAE model.

required
test_loader DataLoader

DataLoader for test data.

required
Rrs array

Rrs data.

required
mask array

Boolean mask indicating valid pixels.

required
latitude array

Latitude values.

required
longitude array

Longitude values.

required
save_folder str

Path to save the results.

required
extent tuple

Tuple containing the extent of the image.

required
rgb_image array

RGB image.

required
structure_name str

Name of the structure.

required
run int

Run number.

required
vmin float

Minimum value for the colorbar.

0
vmax float

Maximum value for the colorbar.

50
log_offset float

Offset subtracted after undoing the base-10 log transform of the predictions (values = 10**prediction - log_offset).

0.01
Source code in hypercoast/moe_vae/model_inference.py
def infer_and_visualize_single_model_minmax(
    model,
    test_loader,
    Rrs,
    mask,
    latitude,
    longitude,
    save_folder,
    extent,
    rgb_image,
    structure_name,
    run,
    vmin=0,
    vmax=50,
    log_offset=0.01,
):
    """
    Infer and visualize results from a single MoE-VAE model.

    Runs ``model`` over ``test_loader`` and converts predictions back from
    log10 space (``10**pred - log_offset``). It then scatters them onto the
    valid pixels of the (H, W) scene and writes two files:

    * ``<save_folder>/<base_name>.npy``: an (H*W, 3) array of
      ``[lat, lon, value]`` rows. Invalid pixels hold NaN.
    * ``<save_folder>/<base_name>.png``: the values resampled to a ~1 km grid
      and drawn over ``rgb_image``.

    Args:
        model (torch.nn.Module): MoE-VAE model. Its output dict must contain
            ``"pred_y"``.
        test_loader (DataLoader): DataLoader for test data. The first element
            of each batch is the model input. Samples must follow the order of
            the ``True`` entries of ``mask`` (row-major).
        Rrs (array): Rrs data. Only its first two dimensions (H, W) are used.
        mask (array): Boolean (H, W) mask indicating valid pixels.
        latitude (array): Latitude values, either per pixel (H, W) or a 1-D
            axis of length H.
        longitude (array): Longitude values, either per pixel (H, W) or a 1-D
            axis of length W.
        save_folder (str): Path to save the results. Created if missing.
        extent (tuple): ``(lon_min, lon_max, lat_min, lat_max)`` of the image,
            also passed to ``imshow``.
        rgb_image (array): RGB background image. It is divided by 255, so
            8-bit values are expected.
        structure_name (str): Name of the structure. Its basename without
            extension names the output files.
        run (int): Run number, shown in the plot title.
        vmin (float): Minimum value for the colorbar.
        vmax (float): Maximum value for the colorbar.
        log_offset (float): Offset subtracted after undoing the log10
            transform.
    """
    device = next(model.parameters()).device
    batch_predictions = []

    with torch.no_grad():
        for batch in test_loader:
            inputs = batch[0].to(device)
            predictions_np = model(inputs)["pred_y"].cpu().numpy()
            # Predictions are in log10 space; undo that and remove the offset.
            batch_predictions.append((10**predictions_np) - log_offset)

    # np.vstack raises on an empty list (loader with no samples).
    predictions_all = (
        np.vstack(batch_predictions).squeeze(-1) if batch_predictions else np.empty(0)
    )

    outputs = np.full((Rrs.shape[0], Rrs.shape[1]), np.nan)
    outputs[mask] = predictions_all

    # Coordinates may be per-pixel (H, W) grids or 1-D axes (lat: H, lon: W).
    # Expand 1-D axes to the image grid so every pixel gets its own (lat, lon).
    # Replace masked coordinates by NaN rather than their raw fill values.
    lat_grid = np.ma.filled(latitude, np.nan)
    lon_grid = np.ma.filled(longitude, np.nan)
    if (
        lat_grid.ndim == 1
        and lon_grid.ndim == 1
        and outputs.shape == (lat_grid.size, lon_grid.size)
    ):
        lon_grid, lat_grid = np.meshgrid(lon_grid, lat_grid)
    final_output = np.column_stack(
        (lat_grid.ravel(), lon_grid.ravel(), outputs.ravel())
    )

    os.makedirs(save_folder, exist_ok=True)
    base_name = os.path.splitext(os.path.basename(structure_name))[0]
    npy_path = os.path.join(save_folder, f"{base_name}.npy")
    png_path = os.path.join(save_folder, f"{base_name}.png")
    np.save(npy_path, final_output)

    lat_col, lon_col, tss_values = final_output.T
    # Triangulation cannot use points without finite coordinates.
    has_xy = np.isfinite(lat_col) & np.isfinite(lon_col)

    # Regular grid of ~1 km spacing (111 km per degree), rows north -> south.
    mean_lat = (extent[2] + extent[3]) / 2
    resolution_deg_lat = 1000 / 111000
    resolution_deg_lon = 1000 / (111000 * np.cos(np.radians(mean_lat)))
    grid_lon = np.arange(extent[0], extent[1], resolution_deg_lon)
    grid_lat = np.arange(extent[3], extent[2], -resolution_deg_lat)
    grid_lon, grid_lat = np.meshgrid(grid_lon, grid_lat)

    tss_resampled = griddata(
        (lon_col[has_xy], lat_col[has_xy]),
        tss_values[has_xy],
        (grid_lon, grid_lat),
        method="linear",
    )
    tss_resampled = np.ma.masked_invalid(tss_resampled)

    fig = plt.figure(figsize=(24, 6))
    try:
        plt.imshow(rgb_image / 255.0, extent=extent, origin="upper")
        im = plt.imshow(
            tss_resampled,
            extent=extent,
            cmap="jet",
            alpha=1,
            origin="upper",
            vmin=vmin,
            vmax=vmax,
        )
        cbar = plt.colorbar(im)
        cbar.set_label("(mg m$^{-3}$)", fontsize=16)
        plt.title(f"{structure_name} - Run {run}", loc="left", fontsize=20)
        plt.savefig(png_path, dpi=300, bbox_inches="tight", pad_inches=0.1)
    finally:
        # Release the figure even if plotting fails.
        plt.close(fig)

infer_and_visualize_token_model_Robust(model, test_loader, Rrs, mask, latitude, longitude, save_folder, extent, rgb_image, structure_name, run, TSS_scalers_dict, vmin=0, vmax=50)

Infer and visualize results from a token-based MoE-VAE model.

Parameters:

Name Type Description Default
model torch.nn.Module

MoE-VAE model.

required
test_loader DataLoader

DataLoader for test data.

required
Rrs array

Rrs data.

required
mask array

Boolean mask indicating valid pixels.

required
latitude array

Latitude values.

required
longitude array

Longitude values.

required
save_folder str

Path to save the results.

required
extent tuple

Tuple containing the extent of the image.

required
rgb_image array

RGB image.

required
structure_name str

Name of the structure.

required
run int

Run number.

required
TSS_scalers_dict dict

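Dictionary of fitted TSS scalers. It must contain the keys "robust" and "log"; their inverse_transform methods are applied in that order to the model predictions.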

required
vmin float

Minimum value for the colorbar.

0
vmax float

Maximum value for the colorbar.

50
Source code in hypercoast/moe_vae/model_inference.py
def infer_and_visualize_token_model_Robust(
    model,
    test_loader,
    Rrs,
    mask,
    latitude,
    longitude,
    save_folder,
    extent,
    rgb_image,
    structure_name,
    run,
    TSS_scalers_dict,
    vmin=0,
    vmax=50,
):
    """
    Infer and visualize results from a token-based MoE-VAE model.

    If ``pred_y`` is 2-D (presumably [B, token_len]), it is averaged over
    dim 1 to give one value per sample. The target scaling is then undone
    (``TSS_scalers_dict["robust"]`` then ``TSS_scalers_dict["log"]``), and the
    values are scattered onto the valid pixels of the (H, W) scene. Two files
    are written:

    * ``<save_folder>/<base_name>.npy``: an (H*W, 3) array of
      ``[lat, lon, value]`` rows. Invalid pixels hold NaN.
    * ``<save_folder>/<base_name>.png``: the values resampled to a ~1 km grid
      and drawn over ``rgb_image``.

    Args:
        model (torch.nn.Module): MoE-VAE model. Its output dict must contain
            ``"pred_y"``.
        test_loader (DataLoader): DataLoader for test data. The first element
            of each batch is the model input. Samples must follow the order of
            the ``True`` entries of ``mask`` (row-major).
        Rrs (array): Rrs data. Only its first two dimensions (H, W) are used.
        mask (array): Boolean (H, W) mask indicating valid pixels.
        latitude (array): Latitude values, either per pixel (H, W) or a 1-D
            axis of length H.
        longitude (array): Longitude values, either per pixel (H, W) or a 1-D
            axis of length W.
        save_folder (str): Path to save the results. Created if missing.
        extent (tuple): ``(lon_min, lon_max, lat_min, lat_max)`` of the image,
            also passed to ``imshow``.
        rgb_image (array): RGB background image. It is divided by 255, so
            8-bit values are expected.
        structure_name (str): Name of the structure. Its basename without
            extension names the output files.
        run (int): Run number, shown in the plot title.
        TSS_scalers_dict (dict): Scalers used at training time. Must contain
            ``"robust"`` and ``"log"``.
        vmin (float): Minimum value for the colorbar.
        vmax (float): Maximum value for the colorbar.
    """
    device = next(model.parameters()).device
    batch_predictions = []

    with torch.no_grad():
        for batch in test_loader:
            inputs = batch[0].to(device)
            predictions = model(inputs)["pred_y"]  # presumably [B, token_len]

            # === Aggregate by token ===
            if predictions.ndim == 2:
                predictions = predictions.mean(dim=1, keepdim=True)  # [B, 1]

            # === Robust + log inverse transform ===
            predictions_log = TSS_scalers_dict["robust"].inverse_transform(
                torch.tensor(predictions.cpu().numpy(), dtype=torch.float32)
            )
            batch_predictions.append(
                TSS_scalers_dict["log"].inverse_transform(predictions_log).numpy()
            )

    # np.vstack raises on an empty list (loader with no samples).
    predictions_all = (
        np.vstack(batch_predictions).reshape(-1) if batch_predictions else np.empty(0)
    )

    outputs = np.full((Rrs.shape[0], Rrs.shape[1]), np.nan)
    outputs[mask] = predictions_all

    # Coordinates may be per-pixel (H, W) grids or 1-D axes (lat: H, lon: W).
    # Expand 1-D axes to the image grid so every pixel gets its own (lat, lon).
    # Replace masked coordinates by NaN rather than their raw fill values.
    lat_grid = np.ma.filled(latitude, np.nan)
    lon_grid = np.ma.filled(longitude, np.nan)
    if (
        lat_grid.ndim == 1
        and lon_grid.ndim == 1
        and outputs.shape == (lat_grid.size, lon_grid.size)
    ):
        lon_grid, lat_grid = np.meshgrid(lon_grid, lat_grid)
    final_output = np.column_stack(
        (lat_grid.ravel(), lon_grid.ravel(), outputs.ravel())
    )

    os.makedirs(save_folder, exist_ok=True)
    base_name = os.path.splitext(os.path.basename(structure_name))[0]
    npy_path = os.path.join(save_folder, f"{base_name}.npy")
    png_path = os.path.join(save_folder, f"{base_name}.png")
    np.save(npy_path, final_output)

    lat_col, lon_col, tss_values = final_output.T
    # Triangulation cannot use points without finite coordinates.
    has_xy = np.isfinite(lat_col) & np.isfinite(lon_col)

    # Regular grid of ~1 km spacing (111 km per degree), rows north -> south.
    mean_lat = (extent[2] + extent[3]) / 2
    resolution_deg_lat = 1000 / 111000
    resolution_deg_lon = 1000 / (111000 * np.cos(np.radians(mean_lat)))
    grid_lon = np.arange(extent[0], extent[1], resolution_deg_lon)
    grid_lat = np.arange(extent[3], extent[2], -resolution_deg_lat)
    grid_lon, grid_lat = np.meshgrid(grid_lon, grid_lat)

    tss_resampled = griddata(
        (lon_col[has_xy], lat_col[has_xy]),
        tss_values[has_xy],
        (grid_lon, grid_lat),
        method="linear",
    )
    tss_resampled = np.ma.masked_invalid(tss_resampled)

    fig = plt.figure(figsize=(24, 6))
    try:
        plt.imshow(rgb_image / 255.0, extent=extent, origin="upper")
        im = plt.imshow(
            tss_resampled,
            extent=extent,
            cmap="jet",
            alpha=1,
            origin="upper",
            vmin=vmin,
            vmax=vmax,
        )
        cbar = plt.colorbar(im)
        cbar.set_label("(mg m$^{-3}$)", fontsize=16)
        plt.title(f"{structure_name} - Run {run}", loc="left", fontsize=20)
        plt.savefig(png_path, dpi=300, bbox_inches="tight", pad_inches=0.1)
    finally:
        # Release the figure even if plotting fails.
        plt.close(fig)

infer_and_visualize_token_model_minmax(model, test_loader, Rrs, mask, latitude, longitude, save_folder, extent, rgb_image, structure_name, run, vmin=0, vmax=50, log_offset=0.01)

Infer and visualize results from a token-based MoE-VAE model.

Parameters:

Name Type Description Default
model torch.nn.Module

MoE-VAE model.

required
test_loader DataLoader

DataLoader for test data.

required
Rrs array

Rrs data.

required
mask array

Boolean mask indicating valid pixels.

required
latitude array

Latitude values.

required
longitude array

Longitude values.

required
save_folder str

Path to save the results.

required
extent tuple

Tuple containing the extent of the image.

required
rgb_image array

RGB image.

required
structure_name str

Name of the structure.

required
run int

Run number.

required
vmin float

Minimum value for the colorbar.

0
vmax float

Maximum value for the colorbar.

50
log_offset float

Offset subtracted after undoing the base-10 log transform of the predictions (values = 10**prediction - log_offset).

0.01
Source code in hypercoast/moe_vae/model_inference.py
def infer_and_visualize_token_model_minmax(
    model,
    test_loader,
    Rrs,
    mask,
    latitude,
    longitude,
    save_folder,
    extent,
    rgb_image,
    structure_name,
    run,
    vmin=0,
    vmax=50,
    log_offset=0.01,
):
    """
    Infer and visualize results from a token-based MoE-VAE model.

    If ``pred_y`` is 2-D (presumably [B, token_len]), it is averaged over
    dim 1 to give one value per sample. Values are then converted back from
    log10 space (``10**pred - log_offset``) and scattered onto the valid pixels
    of the (H, W) scene. Two files are written:

    * ``<save_folder>/<base_name>.npy``: an (H*W, 3) array of
      ``[lat, lon, value]`` rows. Invalid pixels hold NaN.
    * ``<save_folder>/<base_name>.png``: the values resampled to a ~1 km grid
      and drawn over ``rgb_image``.

    Args:
        model (torch.nn.Module): MoE-VAE model. Its output dict must contain
            ``"pred_y"``.
        test_loader (DataLoader): DataLoader for test data. The first element
            of each batch is the model input. Samples must follow the order of
            the ``True`` entries of ``mask`` (row-major).
        Rrs (array): Rrs data. Only its first two dimensions (H, W) are used.
        mask (array): Boolean (H, W) mask indicating valid pixels.
        latitude (array): Latitude values, either per pixel (H, W) or a 1-D
            axis of length H.
        longitude (array): Longitude values, either per pixel (H, W) or a 1-D
            axis of length W.
        save_folder (str): Path to save the results. Created if missing.
        extent (tuple): ``(lon_min, lon_max, lat_min, lat_max)`` of the image,
            also passed to ``imshow``.
        rgb_image (array): RGB background image. It is divided by 255, so
            8-bit values are expected.
        structure_name (str): Name of the structure. Its basename without
            extension names the output files.
        run (int): Run number, shown in the plot title.
        vmin (float): Minimum value for the colorbar.
        vmax (float): Maximum value for the colorbar.
        log_offset (float): Offset subtracted after undoing the log10
            transform.
    """
    device = next(model.parameters()).device
    batch_predictions = []

    with torch.no_grad():
        for batch in test_loader:
            inputs = batch[0].to(device)
            predictions = model(inputs)["pred_y"]  # presumably [B, token_len]

            # === Aggregate along the token dimension ===
            if predictions.ndim == 2:
                predictions = predictions.mean(dim=1, keepdim=True)  # shape [B, 1]

            # Predictions are in log10 space; undo that and remove the offset.
            predictions_np = predictions.cpu().numpy()
            batch_predictions.append((10**predictions_np) - log_offset)

    # np.vstack raises on an empty list (loader with no samples).
    predictions_all = (
        np.vstack(batch_predictions).reshape(-1) if batch_predictions else np.empty(0)
    )

    outputs = np.full((Rrs.shape[0], Rrs.shape[1]), np.nan)
    outputs[mask] = predictions_all

    # Coordinates may be per-pixel (H, W) grids or 1-D axes (lat: H, lon: W).
    # Expand 1-D axes to the image grid so every pixel gets its own (lat, lon).
    # Replace masked coordinates by NaN rather than their raw fill values.
    lat_grid = np.ma.filled(latitude, np.nan)
    lon_grid = np.ma.filled(longitude, np.nan)
    if (
        lat_grid.ndim == 1
        and lon_grid.ndim == 1
        and outputs.shape == (lat_grid.size, lon_grid.size)
    ):
        lon_grid, lat_grid = np.meshgrid(lon_grid, lat_grid)
    final_output = np.column_stack(
        (lat_grid.ravel(), lon_grid.ravel(), outputs.ravel())
    )

    os.makedirs(save_folder, exist_ok=True)
    base_name = os.path.splitext(os.path.basename(structure_name))[0]
    npy_path = os.path.join(save_folder, f"{base_name}.npy")
    png_path = os.path.join(save_folder, f"{base_name}.png")
    np.save(npy_path, final_output)

    lat_col, lon_col, tss_values = final_output.T
    # Triangulation cannot use points without finite coordinates.
    has_xy = np.isfinite(lat_col) & np.isfinite(lon_col)

    # Regular grid of ~1 km spacing (111 km per degree), rows north -> south.
    mean_lat = (extent[2] + extent[3]) / 2
    resolution_deg_lat = 1000 / 111000
    resolution_deg_lon = 1000 / (111000 * np.cos(np.radians(mean_lat)))
    grid_lon = np.arange(extent[0], extent[1], resolution_deg_lon)
    grid_lat = np.arange(extent[3], extent[2], -resolution_deg_lat)
    grid_lon, grid_lat = np.meshgrid(grid_lon, grid_lat)

    tss_resampled = griddata(
        (lon_col[has_xy], lat_col[has_xy]),
        tss_values[has_xy],
        (grid_lon, grid_lat),
        method="linear",
    )
    tss_resampled = np.ma.masked_invalid(tss_resampled)

    fig = plt.figure(figsize=(24, 6))
    try:
        plt.imshow(rgb_image / 255.0, extent=extent, origin="upper")
        im = plt.imshow(
            tss_resampled,
            extent=extent,
            cmap="jet",
            alpha=1,
            origin="upper",
            vmin=vmin,
            vmax=vmax,
        )
        cbar = plt.colorbar(im)
        cbar.set_label("(mg m$^{-3}$)", fontsize=16)
        plt.title(f"{structure_name} - Run {run}", loc="left", fontsize=20)
        plt.savefig(png_path, dpi=300, bbox_inches="tight", pad_inches=0.1)
    finally:
        # Release the figure even if plotting fails.
        plt.close(fig)

preprocess_emit_data_Robust(nc_path, scaler_Rrs, use_diff=True, full_band_wavelengths=None)

Preprocess EMIT data for Robust scaling.

Parameters:

Name Type Description Default
nc_path str

Path to the NetCDF file containing EMIT data.

required
scaler_Rrs object

RobustScaler object for Rrs normalization.

required
use_diff bool

Whether to apply first-order differencing.

True
full_band_wavelengths list

List of target wavelength bands.

None

Returns:

Type Description
test_loader (DataLoader): DataLoader for test data.
filtered_Rrs (array): Filtered Rrs data.
mask (array): Boolean mask indicating valid pixels.
latitude (array): Latitude values.
longitude (array): Longitude values.

Source code in hypercoast/moe_vae/model_inference.py
def preprocess_emit_data_Robust(
    nc_path, scaler_Rrs, use_diff=True, full_band_wavelengths=None
):
    """
    Preprocess EMIT data for Robust scaling.

    Args:
        nc_path (str): Path to the NetCDF file containing EMIT data.
        scaler_Rrs (object): RobustScaler object for Rrs normalization.
        use_diff (bool): Whether to apply first-order differencing.
        full_band_wavelengths (list): List of target wavelength bands.

    Returns:
        test_loader (DataLoader): DataLoader for test data.
        filtered_Rrs (array): Filtered Rrs data.
        mask (array): Boolean mask indicating valid pixels.
        latitude (array): Latitude values.
        longitude (array): Longitude values.
    """

    if full_band_wavelengths is None:
        raise ValueError(
            "full_band_wavelengths must be provided to match EMIT Rrs bands"
        )

    def find_closest_band(target, available_bands):
        available_waves = [int(b.split("_")[1]) for b in available_bands]
        closest_wave = min(available_waves, key=lambda w: abs(w - target))
        return f"Rrs_{closest_wave}"

    dataset = nc(nc_path)
    latitude = dataset.variables["lat"][:]
    longitude = dataset.variables["lon"][:]

    all_vars = dataset.variables.keys()

    bands_to_extract = []
    for w in full_band_wavelengths:
        band_name = f"Rrs_{int(w)}"
        if band_name in all_vars:
            bands_to_extract.append(band_name)
        else:
            closest = find_closest_band(int(w))
            print(f"⚠️ {band_name} does not exist, using the closest band {closest}")
            bands_to_extract.append(closest)
    filtered_Rrs = np.array([dataset.variables[band][:] for band in bands_to_extract])
    filtered_Rrs = np.moveaxis(filtered_Rrs, 0, -1)

    mask = np.all(~np.isnan(filtered_Rrs), axis=2)

    target_443 = (
        f"Rrs_443"
        if "Rrs_443" in bands_to_extract
        else find_closest_band(443, bands_to_extract)
    )
    target_560 = (
        f"Rrs_560"
        if "Rrs_560" in bands_to_extract
        else find_closest_band(560, bands_to_extract)
    )

    print(f"Using {target_443} and {target_560} for mask check.")

    idx_443 = bands_to_extract.index(target_443)
    idx_560 = bands_to_extract.index(target_560)
    mask &= filtered_Rrs[:, :, idx_443] <= filtered_Rrs[:, :, idx_560]

    valid_test_data = filtered_Rrs[mask]

    # ---- smooth + diff
    if use_diff:
        from scipy.ndimage import gaussian_filter1d

        Rrs_smoothed = np.array(
            [gaussian_filter1d(spectrum, sigma=1) for spectrum in valid_test_data]
        )
        Rrs_processed = np.diff(Rrs_smoothed, axis=1)
        print("✅ [5] Performed Gaussian smoothing + first-order differencing")
    else:
        Rrs_processed = valid_test_data
        print("✅ [5] Smoothing and differencing not enabled")

    # ---- normalize
    Rrs_normalized = scaler_Rrs.transform(
        torch.tensor(Rrs_processed, dtype=torch.float32)
    ).numpy()

    # ---- DataLoader
    test_tensor = TensorDataset(torch.tensor(Rrs_normalized).float())
    test_loader = DataLoader(test_tensor, batch_size=2048, shuffle=False)
    print("✅ [6] DataLoader construction completed")

    return test_loader, filtered_Rrs, mask, latitude, longitude

preprocess_emit_data_minmax(nc_path, full_band_wavelengths=None, diff_before_norm=False, diff_after_norm=False)

Read an EMIT NetCDF file, extract the Rrs_* bands matching full_band_wavelengths, optionally apply Gaussian smoothing with first-order differencing, apply per-pixel min-max normalization to [1, 10], and return a DataLoader.

Parameters:

Name Type Description Default
nc_path str

Path to the NetCDF file containing EMIT data.

required
full_band_wavelengths list

List of target wavelength bands.

None
diff_before_norm bool

Whether to apply first-order differencing before normalization.

False
diff_after_norm bool

Whether to apply first-order differencing after normalization.

False

Returns:

Type Description

(test_loader, filtered_Rrs (H, W, B), mask (H, W), latitude (H), longitude (W)), or None if reading or processing the file fails (the error is printed).

Source code in hypercoast/moe_vae/model_inference.py
def preprocess_emit_data_minmax(
    nc_path, full_band_wavelengths=None, diff_before_norm=False, diff_after_norm=False
):
    """
    Read EMIT NetCDF, extract Rrs_* bands according to full_band_wavelengths,
    apply (optional) smooth+diff and robust normalization, and return a DataLoader.

    Args:
        nc_path (str): Path to the NetCDF file containing EMIT data.
        full_band_wavelengths (list): List of target wavelength bands.
        diff_before_norm (bool): Whether to apply first-order differencing before normalization.
        diff_after_norm (bool): Whether to apply first-order differencing after normalization.

    Returns:
        test_loader, filtered_Rrs(H, W, B), mask(H, W), latitude(H), longitude(W)
    """
    print(f"📥 Start processing: {nc_path}")

    # ---- sanity checks
    if full_band_wavelengths is None or len(full_band_wavelengths) == 0:
        raise ValueError(
            "A non-empty full_band_wavelengths must be provided (e.g., [400, 402, ...])."
        )

    full_band_wavelengths = [int(w) for w in full_band_wavelengths]

    try:
        with nc(nc_path, "r") as dataset:
            latitude = dataset.variables["lat"][:]
            longitude = dataset.variables["lon"][:]
            all_vars = set(dataset.variables.keys())
            available_wavelengths = [
                float(v.split("_")[1]) for v in all_vars if v.startswith("Rrs_")
            ]

            def find_closest_band(target_nm: float):
                nearest = min(available_wavelengths, key=lambda w: abs(w - target_nm))
                return f"Rrs_{int(nearest)}"

            # Search according to full_band_wavelengths
            bands_to_extract = []
            for w in full_band_wavelengths:
                band_name = f"Rrs_{w}"
                if band_name in all_vars:
                    bands_to_extract.append(band_name)
                else:
                    closest = find_closest_band(w)
                    print(
                        f"⚠️ {band_name} does not exist, using the closest band {closest}"
                    )
                    bands_to_extract.append(closest)

            seen = set()
            bands_to_extract = [
                b for b in bands_to_extract if not (b in seen or seen.add(b))
            ]

            if len(bands_to_extract) == 0:
                raise ValueError("❌ No usable Rrs_* bands found in the file.")
            # ---- read and stack to (H, W, B)
            # Each variable expected shape: (lat, lon) or (y, x)
            Rrs_stack = []
            for band in bands_to_extract:
                arr = dataset.variables[band][:]  # (H, W)
                Rrs_stack.append(arr)

            Rrs = np.array(Rrs_stack)  # (B, H, W)
            Rrs = np.moveaxis(Rrs, 0, -1)  # (H, W, B)
            filtered_Rrs = Rrs  # keep naming consistent with your previous return

            # ---- build mask using 440 & 560 (or nearest present within your requested list)
            have_waves = [int(b.split("_")[1]) for b in bands_to_extract]

            def nearest_idx(target_nm: int):
                # find nearest *among bands_to_extract*
                nearest_w = min(have_waves, key=lambda w: abs(w - target_nm))
                return bands_to_extract.index(f"Rrs_{nearest_w}")

            # Prefer exact if available; otherwise nearest in the user-requested set
            idx_440 = (
                bands_to_extract.index("Rrs_440")
                if "Rrs_440" in bands_to_extract
                else nearest_idx(440)
            )
            idx_560 = (
                bands_to_extract.index("Rrs_560")
                if "Rrs_560" in bands_to_extract
                else nearest_idx(560)
            )

            print(
                f"✅ Bands used for mask check: {bands_to_extract[idx_440]} and {bands_to_extract[idx_560]}"
            )

            mask_nanfree = np.all(~np.isnan(filtered_Rrs), axis=2)
            mask_condition = filtered_Rrs[:, :, idx_560] >= filtered_Rrs[:, :, idx_440]
            mask = mask_nanfree & mask_condition
            print(f"✅ [4] Built valid mask, remaining pixels: {int(np.sum(mask))}")

            if not np.any(mask):
                raise ValueError("❌ No valid pixels passed the filtering.")

            valid_test_data = filtered_Rrs[mask]  # (N, B)

        # === Check whether smoothing is needed (only executed if any differencing is enabled) ===
        if diff_before_norm or diff_after_norm:
            from scipy.ndimage import gaussian_filter1d

            Rrs_smoothed = np.array(
                [gaussian_filter1d(spectrum, sigma=1) for spectrum in valid_test_data]
            )
            print("✅ Gaussian smoothing applied")
        else:
            Rrs_smoothed = valid_test_data
            print("✅ Smoothing not enabled")

        # === Preprocessing before differencing ===
        if diff_before_norm:
            Rrs_preprocessed = np.diff(Rrs_smoothed, axis=1)
            print("✅ Preprocessing before differencing completed")
        else:
            Rrs_preprocessed = Rrs_smoothed
            print("✅ Preprocessing before differencing not enabled")

        # === MinMax normalization to [1, 10] ===
        scalers = [MinMaxScaler((1, 10)) for _ in range(Rrs_preprocessed.shape[0])]
        Rrs_normalized = np.array(
            [
                scalers[i].fit_transform(row.reshape(-1, 1)).flatten()
                for i, row in enumerate(Rrs_preprocessed)
            ]
        )

        # === Post-processing after differencing ===
        if diff_after_norm:
            Rrs_normalized = np.diff(Rrs_normalized, axis=1)
            print("✅ Post-processing after differencing completed")
        else:
            print("✅ Post-processing after differencing not enabled")

        # === Construct DataLoader
        test_tensor = TensorDataset(torch.tensor(Rrs_normalized).float())
        test_loader = DataLoader(test_tensor, batch_size=2048, shuffle=False)

        return test_loader, Rrs, mask, latitude, longitude

    except Exception as e:
        print(f"❌ [ERROR] Failed to process file {nc_path}: {e}")
        return None

preprocess_pace_data_Robust(nc_path, scaler_Rrs, use_diff=True, full_band_wavelengths=None)

Preprocess PACE data for Robust scaling.

Parameters:

Name Type Description Default
nc_path str

Path to the NetCDF file containing PACE data.

required
scaler_Rrs object

RobustScaler object for Rrs normalization.

required
use_diff bool

Whether to apply first-order differencing.

True
full_band_wavelengths list

List of target wavelength bands.

None

Returns:

Type Description
test_loader (DataLoader): DataLoader for test data.
filtered_Rrs (array): Filtered Rrs data.
mask (array): Boolean mask indicating valid pixels.
latitude (array): Latitude values.
longitude (array): Longitude values.

Source code in hypercoast/moe_vae/model_inference.py
def preprocess_pace_data_Robust(
    nc_path, scaler_Rrs, use_diff=True, full_band_wavelengths=None
):
    """
    Preprocess PACE data for Robust scaling.

    Reads the ``Rrs`` variable from a PACE NetCDF file and keeps only the
    bands listed in ``full_band_wavelengths``, in that order. It then masks
    out pixels that contain NaNs or whose Rrs(560) is lower than Rrs(440).
    Optionally it applies Gaussian smoothing plus first-order differencing
    along the band axis. Finally it normalizes the valid spectra with
    ``scaler_Rrs`` and wraps them in a DataLoader.

    Args:
        nc_path (str): Path to the NetCDF file containing PACE data.
        scaler_Rrs (object): Fitted robust scaler used to normalize Rrs. Its
            ``transform`` is called with a ``torch.float32`` tensor, and its
            result must be a torch tensor (``.numpy()`` is called on it).
        use_diff (bool): Whether to apply Gaussian smoothing (sigma=1)
            followed by first-order differencing along the band axis.
        full_band_wavelengths (list): Target wavelength bands, in the order
            expected by the model. Must include 440 and 560, which are used
            to build the valid-pixel mask.

    Returns:
        tuple or None: ``(test_loader, filtered_Rrs, mask, latitude, longitude)``
        on success, where

        - test_loader (DataLoader): DataLoader for test data (batch size 2048,
          not shuffled) over the normalized valid spectra.
        - filtered_Rrs (array): Rrs restricted to ``full_band_wavelengths``,
          laid out as ``[lat, lon, bands]``.
        - mask (array): Boolean ``[lat, lon]`` mask indicating valid pixels.
        - latitude (array): Latitude values.
        - longitude (array): Longitude values.

        Returns ``None`` if any step fails. The error is printed, not raised.
    """
    print(f"📥 Start processing: {nc_path}")
    try:
        PACE_dataset = read_pace(nc_path)
        print("✅ [1] Successfully read PACE data")

        da = PACE_dataset["Rrs"]
        Rrs = da.values  # [lat, lon, bands]
        latitude = da.latitude.values
        longitude = da.longitude.values
        print("✅ [2] Successfully retrieved Rrs, lat, and lon")

        # ✅ Extract wavelength
        if full_band_wavelengths is None:
            raise ValueError(
                "full_band_wavelengths must be provided to match PACE Rrs bands"
            )

        if hasattr(da, "wavelength") or "wavelength" in da.coords:
            pace_band_wavelengths = da.wavelength.values
        else:
            raise ValueError(
                "❌ Unable to extract wavelength from PACE data. Please check the NetCDF file structure."
            )

        missing = [b for b in full_band_wavelengths if b not in pace_band_wavelengths]
        if missing:
            raise ValueError(
                f"❌ The following wavelengths are not present in the PACE data: {missing}"
            )

        # Indices of the requested bands, in the order given by the caller.
        indices = [
            np.where(pace_band_wavelengths == b)[0][0] for b in full_band_wavelengths
        ]
        band_wavelengths = pace_band_wavelengths[indices]
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if not (band_wavelengths == np.array(full_band_wavelengths)).all():
            raise ValueError("❌ Band order mismatch")

        filtered_Rrs = Rrs[:, :, indices]
        print(
            f"✅ [3] Bands re-extracted based on full_band_wavelengths, total {len(indices)} bands"
        )

        # ✅ Build mask
        # The mask compares Rrs(560) with Rrs(440), so both bands must be
        # selected; otherwise np.where(...)[0][0] below would raise an
        # uninformative IndexError.
        absent = [b for b in (440, 560) if b not in band_wavelengths]
        if absent:
            raise ValueError(
                f"❌ full_band_wavelengths must include {absent} to build the valid-pixel mask"
            )
        idx_440 = np.where(band_wavelengths == 440)[0][0]
        idx_560 = np.where(band_wavelengths == 560)[0][0]

        Rrs_440 = filtered_Rrs[:, :, idx_440]
        Rrs_560 = filtered_Rrs[:, :, idx_560]

        mask_nanfree = np.all(~np.isnan(filtered_Rrs), axis=2)
        mask_condition = Rrs_560 >= Rrs_440
        mask = mask_nanfree & mask_condition
        print(f"✅ [4] Built valid mask, remaining pixels: {np.sum(mask)}")

        if not np.any(mask):
            raise ValueError("❌ No valid pixels passed the filtering.")

        valid_test_data = filtered_Rrs[mask]  # [num_pixel, num_band]

        # ✅ Smoothing before differencing (enabled only if use_diff=True)
        if use_diff:
            from scipy.ndimage import gaussian_filter1d

            # One call filters every spectrum (row) independently along the
            # band axis; this matches filtering each row separately.
            Rrs_smoothed = gaussian_filter1d(valid_test_data, sigma=1, axis=1)
            Rrs_processed = np.diff(Rrs_smoothed, axis=1)
            print("✅ [5] Performed Gaussian smoothing + first-order differencing")
        else:
            Rrs_processed = valid_test_data
            print("✅ [5] Smoothing and differencing not enabled")

        # ✅ Normalization (RobustScaler provided)
        Rrs_normalized = scaler_Rrs.transform(
            torch.tensor(Rrs_processed, dtype=torch.float32)
        ).numpy()

        # ✅ Construct DataLoader
        from torch.utils.data import DataLoader, TensorDataset

        test_tensor = TensorDataset(torch.tensor(Rrs_normalized).float())
        test_loader = DataLoader(test_tensor, batch_size=2048, shuffle=False)
        print("✅ [6] DataLoader construction completed")

        return test_loader, filtered_Rrs, mask, latitude, longitude

    except Exception as e:
        # Best-effort contract: report and return None so batch callers can
        # skip bad files.
        print(f"❌ [ERROR] Failed to process file {nc_path}: {e}")
        return None

preprocess_pace_data_minmax(nc_path, full_band_wavelengths=None, diff_before_norm=False, diff_after_norm=False)

Preprocess PACE data for MinMax scaling.

Parameters:

Name Type Description Default
nc_path str

Path to the NetCDF file containing PACE data.

required
full_band_wavelengths list

List of target wavelength bands.

None
diff_before_norm bool

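Whether to apply first-order differencing before normalization. Enabling either differencing option also applies Gaussian smoothing (sigma=1) to each spectrum first.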

False
diff_after_norm bool

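Whether to apply first-order differencing after normalization. Enabling either differencing option also applies Gaussian smoothing (sigma=1) to each spectrum first.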

False

Returns:

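Type Description
tuple or None

On success, a tuple (test_loader, Rrs, mask, latitude, longitude):
- test_loader (DataLoader) is the DataLoader for test data.
- Rrs (array) is the full Rrs array as read from the file; it covers all PACE bands and is not restricted to full_band_wavelengths.
- mask (array) is the boolean mask indicating valid pixels.
- latitude (array) and longitude (array) hold the latitude and longitude values.

Returns None if processing fails; the error is printed.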

Source code in hypercoast/moe_vae/model_inference.py
def preprocess_pace_data_minmax(
    nc_path, full_band_wavelengths=None, diff_before_norm=False, diff_after_norm=False
):
    """
    Preprocess PACE data for MinMax scaling.

    Reads the ``Rrs`` variable from a PACE NetCDF file and selects the bands
    in ``full_band_wavelengths``, in that order. It masks out pixels that
    contain NaNs or whose Rrs(560) is lower than Rrs(440). It then min-max
    scales every valid spectrum independently to [1, 10] and wraps the result
    in a DataLoader.

    Args:
        nc_path (str): Path to the NetCDF file containing PACE data.
        full_band_wavelengths (list): Target wavelength bands, in the order
            expected by the model. Must include 440 and 560, which are used
            to build the valid-pixel mask.
        diff_before_norm (bool): Whether to apply first-order differencing
            before normalization.
        diff_after_norm (bool): Whether to apply first-order differencing
            after normalization.
        If either differencing flag is set, each spectrum is first
        Gaussian-smoothed (sigma=1) along the band axis.

    Returns:
        tuple or None: ``(test_loader, Rrs, mask, latitude, longitude)`` on
        success, where

        - test_loader (DataLoader): DataLoader for test data (batch size 2048,
          not shuffled) over the normalized valid spectra.
        - Rrs (array): The full Rrs array as read from the file, covering all
          PACE bands and not restricted to ``full_band_wavelengths``.
        - mask (array): Boolean ``[lat, lon]`` mask indicating valid pixels.
        - latitude (array): Latitude values.
        - longitude (array): Longitude values.

        Returns ``None`` if any step fails. The error is printed, not raised.
    """
    try:
        # === Load data ===
        PACE_dataset = read_pace(nc_path)
        da = PACE_dataset["Rrs"]
        Rrs = da.values  # [lat, lon, bands]
        latitude = da.latitude.values
        longitude = da.longitude.values

        # === Band check ===
        if full_band_wavelengths is None:
            raise ValueError(
                "full_band_wavelengths must be provided to match PACE Rrs bands"
            )

        if hasattr(da, "wavelength") or "wavelength" in da.coords:
            pace_band_wavelengths = da.wavelength.values
        else:
            raise ValueError(
                "❌ Unable to extract wavelength from PACE data. Please check the NetCDF file structure."
            )

        # Check for missing bands
        missing = [b for b in full_band_wavelengths if b not in pace_band_wavelengths]
        if missing:
            raise ValueError(
                f"❌ The following wavelengths are not found in the PACE data: {missing}"
            )

        # Extract indices in the order of full_band_wavelengths
        indices = [
            np.where(pace_band_wavelengths == b)[0][0] for b in full_band_wavelengths
        ]
        band_wavelengths = pace_band_wavelengths[indices]
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if not (band_wavelengths == np.array(full_band_wavelengths)).all():
            raise ValueError("❌ Band order is inconsistent")

        # Extract Rrs for selected_bands
        filtered_Rrs = Rrs[:, :, indices]

        # === Mask construction ===
        # The mask compares Rrs(560) with Rrs(440), so both bands must be
        # selected; otherwise np.where(...)[0][0] would raise an opaque IndexError.
        absent = [b for b in (440, 560) if b not in band_wavelengths]
        if absent:
            raise ValueError(
                f"❌ full_band_wavelengths must include {absent} to build the valid-pixel mask"
            )
        idx_440 = np.where(band_wavelengths == 440)[0][0]
        idx_560 = np.where(band_wavelengths == 560)[0][0]
        Rrs_440 = filtered_Rrs[:, :, idx_440]
        Rrs_560 = filtered_Rrs[:, :, idx_560]

        mask_nanfree = np.all(~np.isnan(filtered_Rrs), axis=2)
        mask_condition = Rrs_560 >= Rrs_440
        mask = mask_nanfree & mask_condition
        if not np.any(mask):
            raise ValueError("❌ No valid pixels passed the filtering.")

        valid_data = filtered_Rrs[mask]  # [num_pixel, num_band]

        # === Check if smoothing is needed (only executed when any differencing is enabled) ===
        if diff_before_norm or diff_after_norm:
            from scipy.ndimage import gaussian_filter1d

            # One call filters every spectrum (row) independently along the
            # band axis; this matches filtering each row separately.
            Rrs_smoothed = gaussian_filter1d(valid_data, sigma=1, axis=1)
            print("✅ Gaussian smoothing applied")
        else:
            Rrs_smoothed = valid_data
            print("✅ Smoothing not enabled")

        # === Preprocessing before differencing ===
        if diff_before_norm:
            Rrs_preprocessed = np.diff(Rrs_smoothed, axis=1)
            print("✅ Preprocessing before differencing completed")
        else:
            Rrs_preprocessed = Rrs_smoothed
            print("✅ Preprocessing before differencing not enabled")

        # === MinMax normalization to [1, 10] ===
        # Each spectrum is scaled on its own. The vectorized helper replaces
        # fitting one sklearn MinMaxScaler per pixel.
        Rrs_normalized = _rowwise_minmax_scale(Rrs_preprocessed, feature_range=(1, 10))

        # === Post-processing after differencing ===
        if diff_after_norm:
            Rrs_normalized = np.diff(Rrs_normalized, axis=1)
            print("✅ Post-processing after differencing completed")
        else:
            print("✅ Post-processing after differencing not enabled")

        # === Construct DataLoader ===
        test_tensor = TensorDataset(torch.tensor(Rrs_normalized).float())
        test_loader = DataLoader(test_tensor, batch_size=2048, shuffle=False)

        # NOTE(review): this returns the unfiltered `Rrs` cube, whereas
        # preprocess_pace_data_Robust returns `filtered_Rrs`. Kept as-is for
        # caller compatibility.
        return test_loader, Rrs, mask, latitude, longitude

    except Exception as e:
        # Best-effort contract: report and return None so batch callers can
        # skip bad files.
        print(f"❌ [ERROR] Failed to process file {nc_path}: {e}")
        return None


def _rowwise_minmax_scale(data, feature_range=(1, 10)):
    """
    Min-max scale every row of a 2-D array independently to ``feature_range``.

    This is a vectorized equivalent of fitting a separate
    ``MinMaxScaler(feature_range)`` on each row reshaped to a column and
    flattening the result. Each row is mapped linearly so that its minimum
    becomes ``feature_range[0]`` and its maximum becomes ``feature_range[1]``.
    Rows whose range is (near) zero are treated as constant and map to
    ``feature_range[0]``, mirroring how scikit-learn handles zero-range
    features.

    Args:
        data (array): 2-D array ``[num_rows, num_values]``.
        feature_range (tuple): Desired ``(min, max)`` of each scaled row.

    Returns:
        array: Scaled array with the same shape as ``data``. Floating inputs
        keep their dtype; other inputs are converted to float64.
    """
    data = np.asarray(data)
    if not np.issubdtype(data.dtype, np.floating):
        data = data.astype(np.float64)
    lo, hi = feature_range
    row_min = data.min(axis=1, keepdims=True)
    row_range = data.max(axis=1, keepdims=True) - row_min
    # Treat tiny ranges as zero (same threshold scikit-learn uses) so constant
    # rows map to `lo` instead of dividing by ~0.
    row_range[row_range < 10 * np.finfo(row_range.dtype).eps] = 1.0
    # Same arithmetic as MinMaxScaler: X * scale_ + min_.
    scale = (hi - lo) / row_range
    offset = lo - row_min * scale
    return data * scale + offset