model_inference module
Model inference utilities for MoE-VAE.
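This module provides functions that preprocess EMIT and PACE remote-sensing reflectance (Rrs) NetCDF files into DataLoaders and run MoE-VAE inference to map total suspended solids (TSS).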
infer_and_visualize_single_model_Robust(model, test_loader, Rrs, mask, latitude, longitude, save_folder, extent, rgb_image, structure_name, run, TSS_scalers_dict, vmin=0, vmax=50)
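    Run a single MoE-VAE model over all valid pixels, invert the training-time robust and log scalers to recover TSS, and save the results as `<base_name>.npy` (rows of latitude, longitude, TSS) and a `<base_name>.png` map resampled to a ~1 km grid over an RGB background.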
Parameters:
| Name | Type | Description | Default | 
|---|---|---|---|
| model | torch.nn.Module | MoE-VAE model. | required | 
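| test_loader | DataLoader | DataLoader of preprocessed spectra, in the same pixel order as `mask`; only the first tensor of each batch is used. | required | 
| Rrs | array | Reflectance array whose first two dimensions (H, W) define the output grid. | required | 
| mask | array | Boolean (H, W) mask of valid pixels; predictions are written to these positions. | required | 
| latitude | array | Latitude values with H × W elements (typically a 2-D array matching `mask`). | required | 
| longitude | array | Longitude values with H × W elements (typically a 2-D array matching `mask`). | required | 
| save_folder | str | Output directory; created if it does not exist. | required | 
| extent | tuple | Map extent `(lon_min, lon_max, lat_min, lat_max)` in degrees, as passed to `imshow`. | required | 
| rgb_image | array | RGB background image with 0-255 values (divided by 255 before plotting). | required | 
| structure_name | str | Plot title prefix; its base name without extension names the output files. | required | 
| run | int | Run number shown in the plot title. | required | 
| TSS_scalers_dict | dict | Scalers fitted during training, with keys `"robust"` and `"log"`; predictions are inverted with `"robust"` first, then `"log"`. | required | 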
| vmin | float | Minimum value for the colorbar. | 0 | 
| vmax | float | Maximum value for the colorbar. | 50 | 
Source code in hypercoast/moe_vae/model_inference.py
def infer_and_visualize_single_model_Robust(
    model,
    test_loader,
    Rrs,
    mask,
    latitude,
    longitude,
    save_folder,
    extent,
    rgb_image,
    structure_name,
    run,
    TSS_scalers_dict,
    vmin=0,
    vmax=50,
):
    """
    Infer and visualize results from a single MoE-VAE model.
    Args:
        model (torch.nn.Module): MoE-VAE model.
        test_loader (DataLoader): DataLoader for test data.
        Rrs (array): Rrs data.
        mask (array): Boolean mask indicating valid pixels.
        latitude (array): Latitude values.
        longitude (array): Longitude values.
        save_folder (str): Path to save the results.
        extent (tuple): Tuple containing the extent of the image.
        rgb_image (array): RGB image.
        structure_name (str): Name of the structure.
        run (int): Run number.
        TSS_scalers_dict (dict): Dictionary containing the scalers for TSS.
        vmin (float): Minimum value for the colorbar.
        vmax (float): Maximum value for the colorbar.
    """
    device = next(model.parameters()).device
    predictions_all = []
    with torch.no_grad():
        for batch in test_loader:
            batch = batch[0].to(device)
            output_dict = model(batch)
            predictions = output_dict["pred_y"]
            # === Inverse transform using TSS_scalers_dict from training ===
            predictions_log = TSS_scalers_dict["robust"].inverse_transform(
                torch.tensor(predictions.cpu().numpy(), dtype=torch.float32)
            )
            predictions_all.append(
                TSS_scalers_dict["log"].inverse_transform(predictions_log).numpy()
            )
    predictions_all = np.vstack(predictions_all).squeeze(-1)
    outputs = np.full((Rrs.shape[0], Rrs.shape[1]), np.nan)
    outputs[mask] = predictions_all
    lat_flat = latitude.flatten()
    lon_flat = longitude.flatten()
    output_flat = outputs.flatten()
    final_output = np.column_stack((lat_flat, lon_flat, output_flat))
    if np.ma.isMaskedArray(final_output):
        final_output = final_output.filled(np.nan)
    os.makedirs(save_folder, exist_ok=True)
    base_name = os.path.splitext(os.path.basename(structure_name))[0]
    npy_path = os.path.join(save_folder, f"{base_name}.npy")
    png_path = os.path.join(save_folder, f"{base_name}.png")
    np.save(npy_path, final_output)
    latitude_masked = final_output[:, 0]
    longitude_masked = final_output[:, 1]
    tss_values = final_output[:, 2]
    mean_lat = (extent[2] + extent[3]) / 2
    resolution_deg_lat = 1000 / 111000
    resolution_deg_lon = 1000 / (111000 * np.cos(np.radians(mean_lat)))
    grid_lon = np.arange(extent[0], extent[1], resolution_deg_lon)
    grid_lat = np.arange(extent[3], extent[2], -resolution_deg_lat)
    grid_lon, grid_lat = np.meshgrid(grid_lon, grid_lat)
    tss_resampled = griddata(
        (longitude_masked, latitude_masked),
        tss_values,
        (grid_lon, grid_lat),
        method="linear",
    )
    tss_resampled = np.ma.masked_invalid(tss_resampled)
    plt.figure(figsize=(24, 6))
    plt.imshow(rgb_image / 255.0, extent=extent, origin="upper")
    im = plt.imshow(
        tss_resampled,
        extent=extent,
        cmap="jet",
        alpha=1,
        origin="upper",
        vmin=vmin,
        vmax=vmax,
    )
    cbar = plt.colorbar(im)
    cbar.set_label("(mg m$^{-3}$)", fontsize=16)
    plt.title(f"{structure_name} - Run {run}", loc="left", fontsize=20)
    plt.savefig(png_path, dpi=300, bbox_inches="tight", pad_inches=0.1)
    plt.close()
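The reshaping step above (scatter the masked predictions back onto the (H, W) grid, then stack them with coordinates into an N × 3 array) can be sketched in isolation with NumPy. This is a minimal sketch: `predictions_to_table` is a hypothetical helper name, it uses `mask.shape` in place of `Rrs.shape[:2]`, and it assumes `latitude` and `longitude` are 2-D arrays shaped like `mask`, which is what `np.column_stack` needs here.

```python
import numpy as np


def predictions_to_table(predictions, mask, latitude, longitude):
    """Scatter 1-D predictions into a NaN-filled grid and return (lat, lon, value) rows.

    Hypothetical helper mirroring the inference functions; assumes
    latitude/longitude have the same shape as mask.
    """
    outputs = np.full(mask.shape, np.nan)
    # Boolean indexing fills the True cells in row-major order, the same order
    # in which filtered_Rrs[mask] extracted the spectra during preprocessing.
    outputs[mask] = predictions
    return np.column_stack((latitude.ravel(), longitude.ravel(), outputs.ravel()))


lat2d, lon2d = np.meshgrid([30.0, 29.0], [-90.0, -89.5, -89.0], indexing="ij")
mask = np.array([[True, False, True], [False, True, False]])
table = predictions_to_table(np.array([1.0, 2.0, 3.0]), mask, lat2d, lon2d)
print(table.shape)  # (6, 3)
```

Invalid pixels keep NaN in the third column, so the saved `.npy` file always has one row per grid cell.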
infer_and_visualize_single_model_minmax(model, test_loader, Rrs, mask, latitude, longitude, save_folder, extent, rgb_image, structure_name, run, vmin=0, vmax=50, log_offset=0.01)
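    Run a single MoE-VAE model over all valid pixels, convert predictions out of log10 space with `10**pred - log_offset`, and save the results as `<base_name>.npy` (rows of latitude, longitude, TSS) and a `<base_name>.png` map resampled to a ~1 km grid over an RGB background.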
Parameters:
| Name | Type | Description | Default | 
|---|---|---|---|
| model | torch.nn.Module | MoE-VAE model. | required | 
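| test_loader | DataLoader | DataLoader of preprocessed spectra, in the same pixel order as `mask`; only the first tensor of each batch is used. | required | 
| Rrs | array | Reflectance array whose first two dimensions (H, W) define the output grid. | required | 
| mask | array | Boolean (H, W) mask of valid pixels; predictions are written to these positions. | required | 
| latitude | array | Latitude values with H × W elements (typically a 2-D array matching `mask`). | required | 
| longitude | array | Longitude values with H × W elements (typically a 2-D array matching `mask`). | required | 
| save_folder | str | Output directory; created if it does not exist. | required | 
| extent | tuple | Map extent `(lon_min, lon_max, lat_min, lat_max)` in degrees, as passed to `imshow`. | required | 
| rgb_image | array | RGB background image with 0-255 values (divided by 255 before plotting). | required | 
| structure_name | str | Plot title prefix; its base name without extension names the output files. | required | 
| run | int | Run number shown in the plot title. | required | 
| vmin | float | Minimum value for the colorbar. | 0 | 
| vmax | float | Maximum value for the colorbar. | 50 | 
| log_offset | float | Offset subtracted after the inverse log transform, `10**pred - log_offset`. | 0.01 | 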
Source code in hypercoast/moe_vae/model_inference.py
def infer_and_visualize_single_model_minmax(
    model,
    test_loader,
    Rrs,
    mask,
    latitude,
    longitude,
    save_folder,
    extent,
    rgb_image,
    structure_name,
    run,
    vmin=0,
    vmax=50,
    log_offset=0.01,
):
    """
    Infer and visualize results from a single MoE-VAE model.
    Args:
        model (torch.nn.Module): MoE-VAE model.
        test_loader (DataLoader): DataLoader for test data.
        Rrs (array): Rrs data.
        mask (array): Boolean mask indicating valid pixels.
        latitude (array): Latitude values.
        longitude (array): Longitude values.
        save_folder (str): Path to save the results.
        extent (tuple): Tuple containing the extent of the image.
        rgb_image (array): RGB image.
        structure_name (str): Name of the structure.
        run (int): Run number.
        vmin (float): Minimum value for the colorbar.
        vmax (float): Maximum value for the colorbar.
        log_offset (float): Log offset for predictions.
    """
    device = next(model.parameters()).device
    predictions_all = []
    with torch.no_grad():
        for batch in test_loader:
            batch = batch[0].to(device)
            output_dict = model(batch)
            predictions = output_dict["pred_y"]
            predictions_np = predictions.cpu().numpy()
            predictions_original = (10**predictions_np) - log_offset
            predictions_all.append(predictions_original)
    predictions_all = np.vstack(predictions_all).squeeze(-1)
    outputs = np.full((Rrs.shape[0], Rrs.shape[1]), np.nan)
    outputs[mask] = predictions_all
    lat_flat = latitude.flatten()
    lon_flat = longitude.flatten()
    output_flat = outputs.flatten()
    final_output = np.column_stack((lat_flat, lon_flat, output_flat))
    if np.ma.isMaskedArray(final_output):
        final_output = final_output.filled(np.nan)
    os.makedirs(save_folder, exist_ok=True)
    base_name = os.path.splitext(os.path.basename(structure_name))[0]
    npy_path = os.path.join(save_folder, f"{base_name}.npy")
    png_path = os.path.join(save_folder, f"{base_name}.png")
    np.save(npy_path, final_output)
    latitude_masked = final_output[:, 0]
    longitude_masked = final_output[:, 1]
    tss_values = final_output[:, 2]
    mean_lat = (extent[2] + extent[3]) / 2
    resolution_deg_lat = 1000 / 111000
    resolution_deg_lon = 1000 / (111000 * np.cos(np.radians(mean_lat)))
    grid_lon = np.arange(extent[0], extent[1], resolution_deg_lon)
    grid_lat = np.arange(extent[3], extent[2], -resolution_deg_lat)
    grid_lon, grid_lat = np.meshgrid(grid_lon, grid_lat)
    tss_resampled = griddata(
        (longitude_masked, latitude_masked),
        tss_values,
        (grid_lon, grid_lat),
        method="linear",
    )
    tss_resampled = np.ma.masked_invalid(tss_resampled)
    plt.figure(figsize=(24, 6))
    plt.imshow(rgb_image / 255.0, extent=extent, origin="upper")
    im = plt.imshow(
        tss_resampled,
        extent=extent,
        cmap="jet",
        alpha=1,
        origin="upper",
        vmin=vmin,
        vmax=vmax,
    )
    cbar = plt.colorbar(im)
    cbar.set_label("(mg m$^{-3}$)", fontsize=16)
    plt.title(f"{structure_name} - Run {run}", loc="left", fontsize=20)
    plt.savefig(png_path, dpi=300, bbox_inches="tight", pad_inches=0.1)
    plt.close()
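The resampling grid uses roughly 1 km spacing: 1000 m divided by about 111 km per degree of latitude, with the longitude step widened by 1/cos(mean latitude). A standalone sketch of the same grid construction, assuming `extent` is `(lon_min, lon_max, lat_min, lat_max)` as passed to `imshow`; `make_1km_grid` is a hypothetical name.

```python
import numpy as np


def make_1km_grid(extent):
    """Return (grid_lon, grid_lat) meshes with ~1 km spacing over extent."""
    lon_min, lon_max, lat_min, lat_max = extent
    mean_lat = (lat_min + lat_max) / 2
    # ~111 km per degree of latitude; a degree of longitude shrinks with cos(latitude).
    step_lat = 1000 / 111000
    step_lon = 1000 / (111000 * np.cos(np.radians(mean_lat)))
    lons = np.arange(lon_min, lon_max, step_lon)
    # Latitude runs north to south so row 0 is the top of the image (origin="upper").
    lats = np.arange(lat_max, lat_min, -step_lat)
    return np.meshgrid(lons, lats)


grid_lon, grid_lat = make_1km_grid((-90.5, -89.5, 29.0, 30.0))
print(grid_lon.shape, grid_lat.shape)
```

The approximation treats the Earth as a sphere and uses one longitude step for the whole scene, which is reasonable for scenes spanning a few degrees of latitude.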
infer_and_visualize_token_model_Robust(model, test_loader, Rrs, mask, latitude, longitude, save_folder, extent, rgb_image, structure_name, run, TSS_scalers_dict, vmin=0, vmax=50)
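    Same workflow as `infer_and_visualize_single_model_Robust`, but for token-based MoE-VAE models: a 2-D `pred_y` of shape `[B, token_len]` is averaged over the token axis before the robust and log inverse transforms.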
Parameters:
| Name | Type | Description | Default | 
|---|---|---|---|
| model | torch.nn.Module | MoE-VAE model. | required | 
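| test_loader | DataLoader | DataLoader of preprocessed spectra, in the same pixel order as `mask`; only the first tensor of each batch is used. | required | 
| Rrs | array | Reflectance array whose first two dimensions (H, W) define the output grid. | required | 
| mask | array | Boolean (H, W) mask of valid pixels; predictions are written to these positions. | required | 
| latitude | array | Latitude values with H × W elements (typically a 2-D array matching `mask`). | required | 
| longitude | array | Longitude values with H × W elements (typically a 2-D array matching `mask`). | required | 
| save_folder | str | Output directory; created if it does not exist. | required | 
| extent | tuple | Map extent `(lon_min, lon_max, lat_min, lat_max)` in degrees, as passed to `imshow`. | required | 
| rgb_image | array | RGB background image with 0-255 values (divided by 255 before plotting). | required | 
| structure_name | str | Plot title prefix; its base name without extension names the output files. | required | 
| run | int | Run number shown in the plot title. | required | 
| TSS_scalers_dict | dict | Scalers fitted during training, with keys `"robust"` and `"log"`; predictions are inverted with `"robust"` first, then `"log"`. | required | 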
| vmin | float | Minimum value for the colorbar. | 0 | 
| vmax | float | Maximum value for the colorbar. | 50 | 
Source code in hypercoast/moe_vae/model_inference.py
def infer_and_visualize_token_model_Robust(
    model,
    test_loader,
    Rrs,
    mask,
    latitude,
    longitude,
    save_folder,
    extent,
    rgb_image,
    structure_name,
    run,
    TSS_scalers_dict,
    vmin=0,
    vmax=50,
):
    """
    Infer and visualize results from a token-based MoE-VAE model.
    Args:
        model (torch.nn.Module): MoE-VAE model.
        test_loader (DataLoader): DataLoader for test data.
        Rrs (array): Rrs data.
        mask (array): Boolean mask indicating valid pixels.
        latitude (array): Latitude values.
        longitude (array): Longitude values.
        save_folder (str): Path to save the results.
        extent (tuple): Tuple containing the extent of the image.
        rgb_image (array): RGB image.
        structure_name (str): Name of the structure.
        run (int): Run number.
        TSS_scalers_dict (dict): Dictionary containing the scalers for TSS.
        vmin (float): Minimum value for the colorbar.
        vmax (float): Maximum value for the colorbar.
    """
    device = next(model.parameters()).device
    predictions_all = []
    with torch.no_grad():
        for batch in test_loader:
            batch = batch[0].to(device)
            output_dict = model(batch)
            predictions = output_dict["pred_y"]  # shape [B, token_len]
            # === Aggregate by token ===
            if predictions.ndim == 2:
                predictions = predictions.mean(dim=1, keepdim=True)  # [B, 1]
            # === Robust + log inverse transform ===
            predictions_log = TSS_scalers_dict["robust"].inverse_transform(
                torch.tensor(predictions.cpu().numpy(), dtype=torch.float32)
            )
            predictions_all.append(
                TSS_scalers_dict["log"].inverse_transform(predictions_log).numpy()
            )
    # === Concatenate and remove extra dimensions ===
    predictions_all = np.vstack(predictions_all).reshape(-1)
    outputs = np.full((Rrs.shape[0], Rrs.shape[1]), np.nan)
    outputs[mask] = predictions_all
    lat_flat = latitude.flatten()
    lon_flat = longitude.flatten()
    output_flat = outputs.flatten()
    final_output = np.column_stack((lat_flat, lon_flat, output_flat))
    os.makedirs(save_folder, exist_ok=True)
    base_name = os.path.splitext(os.path.basename(structure_name))[0]
    npy_path = os.path.join(save_folder, f"{base_name}.npy")
    png_path = os.path.join(save_folder, f"{base_name}.png")
    np.save(npy_path, final_output)
    latitude_masked = final_output[:, 0]
    longitude_masked = final_output[:, 1]
    tss_values = final_output[:, 2]
    mean_lat = (extent[2] + extent[3]) / 2
    resolution_deg_lat = 1000 / 111000
    resolution_deg_lon = 1000 / (111000 * np.cos(np.radians(mean_lat)))
    grid_lon = np.arange(extent[0], extent[1], resolution_deg_lon)
    grid_lat = np.arange(extent[3], extent[2], -resolution_deg_lat)
    grid_lon, grid_lat = np.meshgrid(grid_lon, grid_lat)
    tss_resampled = griddata(
        (longitude_masked, latitude_masked),
        tss_values,
        (grid_lon, grid_lat),
        method="linear",
    )
    tss_resampled = np.ma.masked_invalid(tss_resampled)
    plt.figure(figsize=(24, 6))
    plt.imshow(rgb_image / 255.0, extent=extent, origin="upper")
    im = plt.imshow(
        tss_resampled,
        extent=extent,
        cmap="jet",
        alpha=1,
        origin="upper",
        vmin=vmin,
        vmax=vmax,
    )
    cbar = plt.colorbar(im)
    cbar.set_label("(mg m$^{-3}$)", fontsize=16)
    plt.title(f"{structure_name} - Run {run}", loc="left", fontsize=20)
    plt.savefig(png_path, dpi=300, bbox_inches="tight", pad_inches=0.1)
    plt.close()
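In the token-based variants, a 2-D `pred_y` of shape `[B, token_len]` is averaged over the token axis to one value per pixel before inverse scaling. The sketch below is a NumPy analogue of `predictions.mean(dim=1, keepdim=True)` (the source operates on torch tensors); the sample values are invented for illustration.

```python
import numpy as np

pred_y = np.array([[0.2, 0.4, 0.6],
                   [1.0, 1.0, 1.0]])  # [B=2, token_len=3]

if pred_y.ndim == 2:
    # NumPy spelling of torch's mean(dim=1, keepdim=True): one value per pixel.
    pred_y = pred_y.mean(axis=1, keepdims=True)  # [B, 1]

# After all batches are collected, the code flattens to one value per valid pixel.
flat = np.vstack([pred_y]).reshape(-1)
print(flat)
```

A model that already returns `[B, 1]` passes through unchanged, since the mean of a single element is the element itself.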
infer_and_visualize_token_model_minmax(model, test_loader, Rrs, mask, latitude, longitude, save_folder, extent, rgb_image, structure_name, run, vmin=0, vmax=50, log_offset=0.01)
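    Same workflow as `infer_and_visualize_single_model_minmax`, but for token-based MoE-VAE models: a 2-D `pred_y` of shape `[B, token_len]` is averaged over the token axis before the `10**pred - log_offset` inverse transform.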
Parameters:
| Name | Type | Description | Default | 
|---|---|---|---|
| model | torch.nn.Module | MoE-VAE model. | required | 
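| test_loader | DataLoader | DataLoader of preprocessed spectra, in the same pixel order as `mask`; only the first tensor of each batch is used. | required | 
| Rrs | array | Reflectance array whose first two dimensions (H, W) define the output grid. | required | 
| mask | array | Boolean (H, W) mask of valid pixels; predictions are written to these positions. | required | 
| latitude | array | Latitude values with H × W elements (typically a 2-D array matching `mask`). | required | 
| longitude | array | Longitude values with H × W elements (typically a 2-D array matching `mask`). | required | 
| save_folder | str | Output directory; created if it does not exist. | required | 
| extent | tuple | Map extent `(lon_min, lon_max, lat_min, lat_max)` in degrees, as passed to `imshow`. | required | 
| rgb_image | array | RGB background image with 0-255 values (divided by 255 before plotting). | required | 
| structure_name | str | Plot title prefix; its base name without extension names the output files. | required | 
| run | int | Run number shown in the plot title. | required | 
| vmin | float | Minimum value for the colorbar. | 0 | 
| vmax | float | Maximum value for the colorbar. | 50 | 
| log_offset | float | Offset subtracted after the inverse log transform, `10**pred - log_offset`. | 0.01 | 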
Source code in hypercoast/moe_vae/model_inference.py
def infer_and_visualize_token_model_minmax(
    model,
    test_loader,
    Rrs,
    mask,
    latitude,
    longitude,
    save_folder,
    extent,
    rgb_image,
    structure_name,
    run,
    vmin=0,
    vmax=50,
    log_offset=0.01,
):
    """
    Infer and visualize results from a token-based MoE-VAE model.
    Args:
        model (torch.nn.Module): MoE-VAE model.
        test_loader (DataLoader): DataLoader for test data.
        Rrs (array): Rrs data.
        mask (array): Boolean mask indicating valid pixels.
        latitude (array): Latitude values.
        longitude (array): Longitude values.
        save_folder (str): Path to save the results.
        extent (tuple): Tuple containing the extent of the image.
        rgb_image (array): RGB image.
        structure_name (str): Name of the structure.
        run (int): Run number.
        vmin (float): Minimum value for the colorbar.
        vmax (float): Maximum value for the colorbar.
        log_offset (float): Log offset for predictions.
    """
    device = next(model.parameters()).device
    predictions_all = []
    with torch.no_grad():
        for batch in test_loader:
            batch = batch[0].to(device)
            output_dict = model(batch)
            predictions = output_dict["pred_y"]  # shape [B, token_len]
            # === Aggregate along the token dimension ===
            if predictions.ndim == 2:
                predictions = predictions.mean(dim=1, keepdim=True)  # shape [B, 1]
            predictions_np = predictions.cpu().numpy()
            predictions_original = (10**predictions_np) - log_offset
            predictions_all.append(predictions_original)
    # === Concatenate and flatten ===
    predictions_all = np.vstack(predictions_all).reshape(-1)
    outputs = np.full((Rrs.shape[0], Rrs.shape[1]), np.nan)
    outputs[mask] = predictions_all
    lat_flat = latitude.flatten()
    lon_flat = longitude.flatten()
    output_flat = outputs.flatten()
    final_output = np.column_stack((lat_flat, lon_flat, output_flat))
    os.makedirs(save_folder, exist_ok=True)
    base_name = os.path.splitext(os.path.basename(structure_name))[0]
    npy_path = os.path.join(save_folder, f"{base_name}.npy")
    png_path = os.path.join(save_folder, f"{base_name}.png")
    np.save(npy_path, final_output)
    latitude_masked = final_output[:, 0]
    longitude_masked = final_output[:, 1]
    tss_values = final_output[:, 2]
    mean_lat = (extent[2] + extent[3]) / 2
    resolution_deg_lat = 1000 / 111000
    resolution_deg_lon = 1000 / (111000 * np.cos(np.radians(mean_lat)))
    grid_lon = np.arange(extent[0], extent[1], resolution_deg_lon)
    grid_lat = np.arange(extent[3], extent[2], -resolution_deg_lat)
    grid_lon, grid_lat = np.meshgrid(grid_lon, grid_lat)
    tss_resampled = griddata(
        (longitude_masked, latitude_masked),
        tss_values,
        (grid_lon, grid_lat),
        method="linear",
    )
    tss_resampled = np.ma.masked_invalid(tss_resampled)
    plt.figure(figsize=(24, 6))
    plt.imshow(rgb_image / 255.0, extent=extent, origin="upper")
    im = plt.imshow(
        tss_resampled,
        extent=extent,
        cmap="jet",
        alpha=1,
        origin="upper",
        vmin=vmin,
        vmax=vmax,
    )
    cbar = plt.colorbar(im)
    cbar.set_label("(mg m$^{-3}$)", fontsize=16)
    plt.title(f"{structure_name} - Run {run}", loc="left", fontsize=20)
    plt.savefig(png_path, dpi=300, bbox_inches="tight", pad_inches=0.1)
    plt.close()
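Both MinMax variants treat the model output as `log10(TSS + log_offset)` and invert it with `10**pred - log_offset`. The sketch below assumes the forward transform used during training had exactly this form (the training code is not shown on this page) and checks that the round trip recovers the input; `to_log_space` and `from_log_space` are hypothetical names.

```python
import numpy as np


def to_log_space(tss, log_offset=0.01):
    # Assumed forward transform: log10(TSS + offset); the offset keeps log10 finite at TSS = 0.
    return np.log10(np.asarray(tss, dtype=float) + log_offset)


def from_log_space(pred, log_offset=0.01):
    # Inverse used by infer_and_visualize_*_minmax: 10**pred - offset.
    return 10.0 ** np.asarray(pred, dtype=float) - log_offset


tss = np.array([0.0, 1.5, 42.0])
recovered = from_log_space(to_log_space(tss))
print(np.allclose(recovered, tss))  # True
```

If the offset passed at inference differs from the one used in training, every prediction is shifted by the difference, which matters most for low-TSS water.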
preprocess_emit_data_Robust(nc_path, scaler_Rrs, use_diff=True, full_band_wavelengths=None)
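    Read EMIT Rrs bands from a NetCDF file, keep pixels that are NaN-free and have Rrs(≈443 nm) ≤ Rrs(≈560 nm), optionally smooth and difference each spectrum, normalize with a fitted robust scaler, and wrap the result in a DataLoader (batch size 2048, no shuffling).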
Parameters:
| Name | Type | Description | Default | 
|---|---|---|---|
| nc_path | str | Path to the NetCDF file containing EMIT data. | required | 
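| scaler_Rrs | object | Robust scaler fitted during training; its `transform` must accept a `torch.Tensor` and return a tensor (`.numpy()` is called on the result). | required | 
| use_diff | bool | If True, apply Gaussian smoothing (`sigma=1`) along the band axis, then first-order differencing, which yields B - 1 features per pixel. | True | 
| full_band_wavelengths | list | Target wavelengths in nm. Required despite the `None` default (a `ValueError` is raised); a missing band is replaced by the nearest available `Rrs_*` variable. | None | 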
Returns:
| Type | Description | 
|---|---|
| DataLoader | `test_loader`: normalized spectra of the valid pixels. | 
| array | `filtered_Rrs`: (H, W, B) Rrs cube for the selected bands, before masking. | 
| array | `mask`: (H, W) boolean mask of valid pixels. | 
| array | `latitude`: values of the `lat` variable. | 
| array | `longitude`: values of the `lon` variable. | 
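With `use_diff=True`, each valid spectrum is smoothed with a Gaussian kernel (`sigma=1`, along the band axis) and then differenced, so B input bands become B - 1 features. The per-spectrum loop in the source can also be written as one vectorized call; a sketch on synthetic data, assuming SciPy is installed:

```python
import numpy as np
from scipy.ndimage import gaussian_filter1d

rng = np.random.default_rng(0)
valid_test_data = rng.random((5, 8))  # (N pixels, B bands); synthetic values

# Per-spectrum loop, as written in the source.
looped = np.array([gaussian_filter1d(s, sigma=1) for s in valid_test_data])
# Vectorized form: filter every row along the band axis in one call.
vectorized = gaussian_filter1d(valid_test_data, sigma=1, axis=1)

features = np.diff(vectorized, axis=1)  # first-order difference: (N, B - 1)
print(features.shape)  # (5, 7)
```

The feature count after differencing must match what the model saw during training, so `use_diff` should be set the same way at training and inference time.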
Source code in hypercoast/moe_vae/model_inference.py
def preprocess_emit_data_Robust(
    nc_path, scaler_Rrs, use_diff=True, full_band_wavelengths=None
):
    """
    Preprocess EMIT data for Robust scaling.
    Args:
        nc_path (str): Path to the NetCDF file containing EMIT data.
        scaler_Rrs (object): RobustScaler object for Rrs normalization.
        use_diff (bool): Whether to apply first-order differencing.
        full_band_wavelengths (list): List of target wavelength bands.
    Returns:
        test_loader (DataLoader): DataLoader for test data.
        filtered_Rrs (array): Filtered Rrs data.
        mask (array): Boolean mask indicating valid pixels.
        latitude (array): Latitude values.
        longitude (array): Longitude values.
    """
    if full_band_wavelengths is None:
        raise ValueError(
            "full_band_wavelengths must be provided to match EMIT Rrs bands"
        )
    def find_closest_band(target, available_bands):
        available_waves = [int(b.split("_")[1]) for b in available_bands]
        closest_wave = min(available_waves, key=lambda w: abs(w - target))
        return f"Rrs_{closest_wave}"
    dataset = nc(nc_path)
    latitude = dataset.variables["lat"][:]
    longitude = dataset.variables["lon"][:]
    all_vars = dataset.variables.keys()
    bands_to_extract = []
    for w in full_band_wavelengths:
        band_name = f"Rrs_{int(w)}"
        if band_name in all_vars:
            bands_to_extract.append(band_name)
        else:
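            # find_closest_band needs the candidate list; pass only Rrs_* names so
            # variables such as "lat"/"lon" are not parsed as wavelengths.
            rrs_vars = [v for v in all_vars if v.startswith("Rrs_")]
            closest = find_closest_band(int(w), rrs_vars)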
            print(f"⚠️ {band_name} does not exist, using the closest band {closest}")
            bands_to_extract.append(closest)
    filtered_Rrs = np.array([dataset.variables[band][:] for band in bands_to_extract])
    filtered_Rrs = np.moveaxis(filtered_Rrs, 0, -1)
    mask = np.all(~np.isnan(filtered_Rrs), axis=2)
    target_443 = (
        f"Rrs_443"
        if "Rrs_443" in bands_to_extract
        else find_closest_band(443, bands_to_extract)
    )
    target_560 = (
        f"Rrs_560"
        if "Rrs_560" in bands_to_extract
        else find_closest_band(560, bands_to_extract)
    )
    print(f"Using {target_443} and {target_560} for mask check.")
    idx_443 = bands_to_extract.index(target_443)
    idx_560 = bands_to_extract.index(target_560)
    mask &= filtered_Rrs[:, :, idx_443] <= filtered_Rrs[:, :, idx_560]
    valid_test_data = filtered_Rrs[mask]
    # ---- smooth + diff
    if use_diff:
        from scipy.ndimage import gaussian_filter1d
        Rrs_smoothed = np.array(
            [gaussian_filter1d(spectrum, sigma=1) for spectrum in valid_test_data]
        )
        Rrs_processed = np.diff(Rrs_smoothed, axis=1)
        print("✅ [5] Performed Gaussian smoothing + first-order differencing")
    else:
        Rrs_processed = valid_test_data
        print("✅ [5] Smoothing and differencing not enabled")
    # ---- normalize
    Rrs_normalized = scaler_Rrs.transform(
        torch.tensor(Rrs_processed, dtype=torch.float32)
    ).numpy()
    # ---- DataLoader
    test_tensor = TensorDataset(torch.tensor(Rrs_normalized).float())
    test_loader = DataLoader(test_tensor, batch_size=2048, shuffle=False)
    print("✅ [6] DataLoader construction completed")
    return test_loader, filtered_Rrs, mask, latitude, longitude
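Band selection falls back to the nearest available `Rrs_<nm>` variable when an exact wavelength is missing. A standalone sketch of that lookup, with non-`Rrs_` variables (such as `lat` and `lon`) filtered out first so the wavelength parse cannot fail; `closest_rrs_band` and the variable list are hypothetical examples, not EMIT's actual band set.

```python
def closest_rrs_band(target_nm, variable_names):
    """Return the Rrs_<nm> variable whose wavelength is nearest to target_nm."""
    waves = [int(name.split("_")[1]) for name in variable_names if name.startswith("Rrs_")]
    if not waves:
        raise ValueError("no Rrs_* variables available")
    nearest = min(waves, key=lambda w: abs(w - target_nm))
    return f"Rrs_{nearest}"


names = ["lat", "lon", "Rrs_440", "Rrs_447", "Rrs_552", "Rrs_560"]
print(closest_rrs_band(443, names))  # Rrs_440
print(closest_rrs_band(555, names))  # Rrs_552
```

Because `min` returns the first of equally distant candidates, ties resolve to whichever band appears first in the variable list.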
preprocess_emit_data_minmax(nc_path, full_band_wavelengths=None, diff_before_norm=False, diff_after_norm=False)
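    Read an EMIT NetCDF file, extract the `Rrs_*` bands closest to `full_band_wavelengths`, keep pixels that are NaN-free and have Rrs(≈560 nm) ≥ Rrs(≈440 nm), optionally smooth and difference, rescale each spectrum to [1, 10] with its own MinMax scaler, and return a DataLoader.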
Parameters:
| Name | Type | Description | Default | 
|---|---|---|---|
| nc_path | str | Path to the NetCDF file containing EMIT data. | required | 
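| full_band_wavelengths | list | Target wavelengths in nm; must be non-empty despite the `None` default. Missing bands fall back to the nearest available `Rrs_*` variable, and duplicate selections are dropped. | None | 
| diff_before_norm | bool | If True, difference the smoothed spectra before normalization. Gaussian smoothing (`sigma=1`) is applied whenever either differencing option is enabled. | False | 
| diff_after_norm | bool | If True, difference the spectra after normalization. | False | 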
Returns:
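| Type | Description | 
|---|---|
| DataLoader | `test_loader`: MinMax-normalized spectra of the valid pixels. | 
| array | `filtered_Rrs`: (H, W, B) Rrs cube for the selected bands. | 
| array | `mask`: (H, W) boolean mask of valid pixels. | 
| array | `latitude`: values of the `lat` variable (H). | 
| array | `longitude`: values of the `lon` variable (W). | 
| None | Returned instead of the tuple if any step raises an exception (the error is printed). | 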
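`preprocess_emit_data_minmax` fits a separate `MinMaxScaler((1, 10))` to every spectrum, so each pixel is rescaled by its own minimum and maximum. The same mapping, 1 + 9 * (x - min) / (max - min), can be computed for all rows at once. A sketch, assuming only NumPy; `rowwise_minmax` is a hypothetical name, and a zero range is replaced by 1 (similar to scikit-learn's zero-range handling) so a constant spectrum maps to 1.

```python
import numpy as np


def rowwise_minmax(X, low=1.0, high=10.0):
    """Rescale each row of X to [low, high] using that row's own min and max."""
    X = np.asarray(X, dtype=float)
    row_min = X.min(axis=1, keepdims=True)
    row_range = X.max(axis=1, keepdims=True) - row_min
    row_range[row_range == 0] = 1.0  # constant rows map to `low`
    return low + (high - low) * (X - row_min) / row_range


spectra = np.array([[0.002, 0.004, 0.010],
                    [0.005, 0.005, 0.005]])
print(rowwise_minmax(spectra))
```

Per-spectrum scaling discards absolute magnitude and keeps only spectral shape, and it needs no scaler fitted during training, consistent with this function taking no scaler argument.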
Source code in hypercoast/moe_vae/model_inference.py
def preprocess_emit_data_minmax(
    nc_path, full_band_wavelengths=None, diff_before_norm=False, diff_after_norm=False
):
    """
    Read EMIT NetCDF, extract Rrs_* bands according to full_band_wavelengths,
    apply (optional) smooth+diff and per-spectrum MinMax normalization to [1, 10],
    and return a DataLoader.
    Args:
        nc_path (str): Path to the NetCDF file containing EMIT data.
        full_band_wavelengths (list): List of target wavelength bands.
        diff_before_norm (bool): Whether to apply first-order differencing before normalization.
        diff_after_norm (bool): Whether to apply first-order differencing after normalization.
    Returns:
        test_loader, filtered_Rrs(H, W, B), mask(H, W), latitude(H), longitude(W)
    """
    print(f"📥 Start processing: {nc_path}")
    # ---- sanity checks
    if full_band_wavelengths is None or len(full_band_wavelengths) == 0:
        raise ValueError(
            "A non-empty full_band_wavelengths must be provided (e.g., [400, 402, ...])."
        )
    full_band_wavelengths = [int(w) for w in full_band_wavelengths]
    try:
        with nc(nc_path, "r") as dataset:
            latitude = dataset.variables["lat"][:]
            longitude = dataset.variables["lon"][:]
            all_vars = set(dataset.variables.keys())
            available_wavelengths = [
                float(v.split("_")[1]) for v in all_vars if v.startswith("Rrs_")
            ]
            def find_closest_band(target_nm: float):
                nearest = min(available_wavelengths, key=lambda w: abs(w - target_nm))
                return f"Rrs_{int(nearest)}"
            # Search according to full_band_wavelengths
            bands_to_extract = []
            for w in full_band_wavelengths:
                band_name = f"Rrs_{w}"
                if band_name in all_vars:
                    bands_to_extract.append(band_name)
                else:
                    closest = find_closest_band(w)
                    print(
                        f"⚠️ {band_name} does not exist, using the closest band {closest}"
                    )
                    bands_to_extract.append(closest)
            seen = set()
            bands_to_extract = [
                b for b in bands_to_extract if not (b in seen or seen.add(b))
            ]
            if len(bands_to_extract) == 0:
                raise ValueError("❌ No usable Rrs_* bands found in the file.")
            # ---- read and stack to (H, W, B)
            # Each variable expected shape: (lat, lon) or (y, x)
            Rrs_stack = []
            for band in bands_to_extract:
                arr = dataset.variables[band][:]  # (H, W)
                Rrs_stack.append(arr)
            Rrs = np.array(Rrs_stack)  # (B, H, W)
            Rrs = np.moveaxis(Rrs, 0, -1)  # (H, W, B)
            filtered_Rrs = Rrs  # keep naming consistent with your previous return
            # ---- build mask using 440 & 560 (or nearest present within your requested list)
            have_waves = [int(b.split("_")[1]) for b in bands_to_extract]
            def nearest_idx(target_nm: int):
                # find nearest *among bands_to_extract*
                nearest_w = min(have_waves, key=lambda w: abs(w - target_nm))
                return bands_to_extract.index(f"Rrs_{nearest_w}")
            # Prefer exact if available; otherwise nearest in the user-requested set
            idx_440 = (
                bands_to_extract.index("Rrs_440")
                if "Rrs_440" in bands_to_extract
                else nearest_idx(440)
            )
            idx_560 = (
                bands_to_extract.index("Rrs_560")
                if "Rrs_560" in bands_to_extract
                else nearest_idx(560)
            )
            print(
                f"✅ Bands used for mask check: {bands_to_extract[idx_440]} and {bands_to_extract[idx_560]}"
            )
            mask_nanfree = np.all(~np.isnan(filtered_Rrs), axis=2)
            mask_condition = filtered_Rrs[:, :, idx_560] >= filtered_Rrs[:, :, idx_440]
            mask = mask_nanfree & mask_condition
            print(f"✅ [4] Built valid mask, remaining pixels: {int(np.sum(mask))}")
            if not np.any(mask):
                raise ValueError("❌ No valid pixels passed the filtering.")
            valid_test_data = filtered_Rrs[mask]  # (N, B)
        # === Check whether smoothing is needed (only executed if any differencing is enabled) ===
        if diff_before_norm or diff_after_norm:
            from scipy.ndimage import gaussian_filter1d
            Rrs_smoothed = np.array(
                [gaussian_filter1d(spectrum, sigma=1) for spectrum in valid_test_data]
            )
            print("✅ Gaussian smoothing applied")
        else:
            Rrs_smoothed = valid_test_data
            print("✅ Smoothing not enabled")
        # === Preprocessing before differencing ===
        if diff_before_norm:
            Rrs_preprocessed = np.diff(Rrs_smoothed, axis=1)
            print("✅ Preprocessing before differencing completed")
        else:
            Rrs_preprocessed = Rrs_smoothed
            print("✅ Preprocessing before differencing not enabled")
        # === MinMax normalization to [1, 10] ===
        scalers = [MinMaxScaler((1, 10)) for _ in range(Rrs_preprocessed.shape[0])]
        Rrs_normalized = np.array(
            [
                scalers[i].fit_transform(row.reshape(-1, 1)).flatten()
                for i, row in enumerate(Rrs_preprocessed)
            ]
        )
        # === Post-processing after differencing ===
        if diff_after_norm:
            Rrs_normalized = np.diff(Rrs_normalized, axis=1)
            print("✅ Post-processing after differencing completed")
        else:
            print("✅ Post-processing after differencing not enabled")
        # === Construct DataLoader
        test_tensor = TensorDataset(torch.tensor(Rrs_normalized).float())
        test_loader = DataLoader(test_tensor, batch_size=2048, shuffle=False)
        return test_loader, Rrs, mask, latitude, longitude
    except Exception as e:
        print(f"❌ [ERROR] Failed to process file {nc_path}: {e}")
        return None
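Pixel filtering keeps only spectra with no NaN in any selected band and with Rrs near 560 nm at least as large as Rrs near 440 nm. The same test on a small synthetic (H, W, B) cube, with values invented for illustration and the band positions passed explicitly; `valid_pixel_mask` is a hypothetical helper name.

```python
import numpy as np


def valid_pixel_mask(rrs, idx_440, idx_560):
    """rrs: (H, W, B) array. Returns an (H, W) boolean mask of usable pixels."""
    nan_free = np.all(~np.isnan(rrs), axis=2)
    # Comparisons involving NaN evaluate to False, so NaN pixels fail this test too.
    green_ge_blue = rrs[:, :, idx_560] >= rrs[:, :, idx_440]
    return nan_free & green_ge_blue


cube = np.array([[[0.004, 0.006, 0.008], [0.009, 0.006, 0.003]],
                 [[np.nan, 0.005, 0.007], [0.002, 0.002, 0.002]]])
print(valid_pixel_mask(cube, idx_440=0, idx_560=2))
```

`preprocess_emit_data_Robust` writes the same condition as Rrs(443) <= Rrs(560), which is equivalent.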
preprocess_pace_data_Robust(nc_path, scaler_Rrs, use_diff=True, full_band_wavelengths=None)
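    Read PACE Rrs with `read_pace`, select exactly the bands in `full_band_wavelengths`, keep pixels that are NaN-free and have Rrs(560 nm) ≥ Rrs(440 nm), optionally smooth and difference each spectrum, and normalize with a fitted robust scaler.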
Parameters:
| Name | Type | Description | Default | 
|---|---|---|---|
| nc_path | str | Path to the NetCDF file containing PACE data. | required | 
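| scaler_Rrs | object | Robust scaler fitted during training, used to normalize the processed spectra. | required | 
| use_diff | bool | If True, apply Gaussian smoothing (`sigma=1`) along the band axis, then first-order differencing. | True | 
| full_band_wavelengths | list | Target wavelengths in nm. Required despite the `None` default; each value must exactly match the file's `wavelength` coordinate, and 440 and 560 must be included for the mask step. | None | 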
Returns:
| Type | Description | 
|---|---|
| DataLoader | `test_loader`: normalized spectra of the valid pixels. | 
| array | `filtered_Rrs`: (H, W, B) Rrs cube for the selected bands. | 
| array | `mask`: (H, W) boolean mask of valid pixels. | 
| array | `latitude`: latitude values from the `Rrs` DataArray. | 
| array | `longitude`: longitude values from the `Rrs` DataArray. | 
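Unlike the EMIT readers, the PACE reader has no nearest-band fallback: every requested wavelength must appear exactly in the file's `wavelength` coordinate. A sketch of that exact-match lookup on an illustrative wavelength axis (not the real PACE OCI band list); `exact_band_indices` is a hypothetical helper.

```python
import numpy as np


def exact_band_indices(available, requested):
    """Indices of requested wavelengths in `available`; raise if any is absent."""
    available = np.asarray(available)
    missing = [w for w in requested if w not in available]
    if missing:
        raise ValueError(f"wavelengths not present: {missing}")
    # Exact equality, as in the source; float axes must contain the values verbatim.
    return [int(np.where(available == w)[0][0]) for w in requested]


pace_waves = np.array([400.0, 440.0, 480.0, 560.0, 665.0])  # illustrative axis
idx = exact_band_indices(pace_waves, [440, 560, 665])
print(idx)  # [1, 3, 4]
```

Exact matching keeps band order and count identical to training, at the cost of failing outright when a requested wavelength is absent.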
Source code in hypercoast/moe_vae/model_inference.py
def preprocess_pace_data_Robust(
    nc_path, scaler_Rrs, use_diff=True, full_band_wavelengths=None
):
    """
    Preprocess PACE data for Robust scaling.
    Args:
        nc_path (str): Path to the NetCDF file containing PACE data.
        scaler_Rrs (object): RobustScaler object for Rrs normalization.
        use_diff (bool): Whether to apply first-order differencing.
        full_band_wavelengths (list): List of target wavelength bands.
    Returns:
        test_loader (DataLoader): DataLoader for test data.
        filtered_Rrs (array): Filtered Rrs data.
        mask (array): Boolean mask indicating valid pixels.
        latitude (array): Latitude values.
        longitude (array): Longitude values.
    """
    print(f"📥 Start processing: {nc_path}")
    try:
        PACE_dataset = read_pace(nc_path)
        print("✅ [1] Successfully read PACE data")
        da = PACE_dataset["Rrs"]
        Rrs = da.values  # [lat, lon, bands]
        latitude = da.latitude.values
        longitude = da.longitude.values
        print("✅ [2] Successfully retrieved Rrs, lat, and lon")
        # ✅ Extract wavelength
        if full_band_wavelengths is None:
            raise ValueError(
                "full_band_wavelengths must be provided to match PACE Rrs bands"
            )
        if hasattr(da, "wavelength") or "wavelength" in da.coords:
            pace_band_wavelengths = da.wavelength.values
        else:
            raise ValueError(
                "❌ Unable to extract wavelength from PACE data. Please check the NetCDF file structure."
            )
        missing = [b for b in full_band_wavelengths if b not in pace_band_wavelengths]
        if missing:
            raise ValueError(
                f"❌ The following wavelengths are not present in the PACE data: {missing}"
            )
        indices = [
            np.where(pace_band_wavelengths == b)[0][0] for b in full_band_wavelengths
        ]
        band_wavelengths = pace_band_wavelengths[indices]
        assert (
            band_wavelengths == np.array(full_band_wavelengths)
        ).all(), "❌ Band order mismatch"
        filtered_Rrs = Rrs[:, :, indices]
        print(
            f"✅ [3] Bands re-extracted based on full_band_wavelengths, total {len(indices)} bands"
        )
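        # ✅ Build mask (requires the 440 nm and 560 nm bands)
        for required in (440, 560):
            if required not in band_wavelengths:
                raise ValueError(
                    f"❌ full_band_wavelengths must include {required} nm to build the mask"
                )
        idx_440 = np.where(band_wavelengths == 440)[0][0]
        idx_560 = np.where(band_wavelengths == 560)[0][0]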
        Rrs_440 = filtered_Rrs[:, :, idx_440]
        Rrs_560 = filtered_Rrs[:, :, idx_560]
        mask_nanfree = np.all(~np.isnan(filtered_Rrs), axis=2)
        mask_condition = Rrs_560 >= Rrs_440
        mask = mask_nanfree & mask_condition
        print(f"✅ [4] Built valid mask, remaining pixels: {np.sum(mask)}")
        if not np.any(mask):
            raise ValueError("❌ No valid pixels passed the filtering.")
        valid_test_data = filtered_Rrs[mask]
        # ✅ Smoothing before differencing (enabled only if use_diff=True)
        if use_diff:
            from scipy.ndimage import gaussian_filter1d
            Rrs_smoothed = np.array(
                [gaussian_filter1d(spectrum, sigma=1) for spectrum in valid_test_data]
            )
            Rrs_processed = np.diff(Rrs_smoothed, axis=1)
            print("✅ [5] Performed Gaussian smoothing + first-order differencing")
        else:
            Rrs_processed = valid_test_data
            print("✅ [5] Smoothing and differencing not enabled")
        # ✅ Normalization with the provided scaler (expects a torch.Tensor in and out)
        Rrs_normalized = scaler_Rrs.transform(
            torch.tensor(Rrs_processed, dtype=torch.float32)
        ).numpy()
        # ✅ Construct DataLoader
        from torch.utils.data import DataLoader, TensorDataset
        test_tensor = TensorDataset(torch.tensor(Rrs_normalized).float())
        test_loader = DataLoader(test_tensor, batch_size=2048, shuffle=False)
        print("✅ [6] DataLoader construction completed")
        return test_loader, filtered_Rrs, mask, latitude, longitude
    except Exception as e:
        print(f"❌ [ERROR] Failed to process file {nc_path}: {e}")
        return None
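With `use_diff=True`, each valid spectrum is Gaussian-smoothed (sigma=1) and then first-differenced along the band axis, so the model input has one fewer feature than `full_band_wavelengths`. The sketch below, on synthetic random spectra, checks that a single `gaussian_filter1d` call along axis 1 matches the per-spectrum loop used above; the vectorized form is a suggested alternative, not the module's code:

```python
import numpy as np
from scipy.ndimage import gaussian_filter1d

rng = np.random.default_rng(0)
valid_test_data = rng.random((5, 8))  # 5 pixels x 8 bands (synthetic)

# Per-spectrum loop, as in preprocess_pace_data_Robust
loop = np.array([gaussian_filter1d(s, sigma=1) for s in valid_test_data])
# Single call that filters every row along the band axis
vectorized = gaussian_filter1d(valid_test_data, sigma=1, axis=1)

Rrs_processed = np.diff(vectorized, axis=1)
print(Rrs_processed.shape)  # (5, 7): differencing drops one band
```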
preprocess_pace_data_minmax(nc_path, full_band_wavelengths=None, diff_before_norm=False, diff_after_norm=False)
    Preprocess PACE data with per-pixel MinMax scaling to the range [1, 10].
Parameters:
| Name | Type | Description | Default | 
|---|---|---|---|
| nc_path | str | Path to the NetCDF file containing PACE data. | required | 
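| full_band_wavelengths | list | Target wavelengths (nm) to extract, in model input order. Must be provided despite the `None` default, and must include 440 and 560, which are used to build the valid-pixel mask. | None | 
| diff_before_norm | bool | Whether to apply first-order differencing before normalization. If either differencing option is enabled, spectra are first Gaussian-smoothed (sigma=1). | False | 
| diff_after_norm | bool | Whether to apply first-order differencing after normalization; the output is then no longer bounded to [1, 10]. | False | 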
Returns:
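| Type | Description | 
|---|---|
| DataLoader | `test_loader`: DataLoader over the normalized valid pixels (batch size 2048, not shuffled). | 
| array | `Rrs`: Original Rrs cube with all PACE bands, shape [lat, lon, bands] (not restricted to `full_band_wavelengths`). | 
| array | `mask`: Boolean mask indicating valid pixels, shape [lat, lon]. | 
| array | `latitude`: Latitude values. | 
| array | `longitude`: Longitude values. | 
| None | Returned instead of the tuple above if any processing step fails (the error is printed). | 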
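The MinMax step fits one `MinMaxScaler((1, 10))` per pixel on that pixel's spectrum, so each spectrum is rescaled independently: its smallest band value maps to 1 and its largest to 10. For large scenes the same result (up to floating-point rounding) can be computed without per-pixel scaler objects. A minimal NumPy sketch, assuming scikit-learn's convention that a constant spectrum (zero range) maps to the lower bound; `minmax_per_row` is a hypothetical helper:

```python
import numpy as np


def minmax_per_row(x, lo=1.0, hi=10.0):
    """Rescale each row of x to [lo, hi] independently (hypothetical helper)."""
    x = np.asarray(x, dtype=float)
    row_min = x.min(axis=1, keepdims=True)
    row_range = x.max(axis=1, keepdims=True) - row_min
    # Avoid division by zero: constant rows end up at `lo`
    row_range[row_range == 0] = 1.0
    return lo + (x - row_min) * (hi - lo) / row_range


spectra = np.array([[0.0, 5.0, 10.0],
                    [2.0, 2.0, 2.0]])
scaled = minmax_per_row(spectra)
print(scaled)
```

Because scaling is per spectrum, absolute magnitude is discarded and only spectral shape reaches the model.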
Source code in hypercoast/moe_vae/model_inference.py
def preprocess_pace_data_minmax(
    nc_path, full_band_wavelengths=None, diff_before_norm=False, diff_after_norm=False
):
    """
    Preprocess PACE data for MinMax scaling.
    Args:
        nc_path (str): Path to the NetCDF file containing PACE data.
        full_band_wavelengths (list): Target wavelengths (nm) in model input
            order. Required; must include 440 and 560.
        diff_before_norm (bool): Whether to apply first-order differencing before normalization.
        diff_after_norm (bool): Whether to apply first-order differencing after normalization.
    Returns:
        test_loader (DataLoader): DataLoader for test data.
        Rrs (array): Original Rrs cube with all PACE bands (not band-filtered).
        mask (array): Boolean mask indicating valid pixels.
        latitude (array): Latitude values.
        longitude (array): Longitude values.
        None is returned instead if any processing step fails.
    """
    try:
        # === Load data ===
        PACE_dataset = read_pace(nc_path)
        da = PACE_dataset["Rrs"]
        Rrs = da.values  # [lat, lon, bands]
        latitude = da.latitude.values
        longitude = da.longitude.values
        # === Band check ===
        if full_band_wavelengths is None:
            raise ValueError(
                "full_band_wavelengths must be provided to match PACE Rrs bands"
            )
        if hasattr(da, "wavelength") or "wavelength" in da.coords:
            pace_band_wavelengths = da.wavelength.values
        else:
            raise ValueError(
                "❌ Unable to extract wavelength from PACE data. Please check the NetCDF file structure."
            )
        # Check for missing bands
        missing = [b for b in full_band_wavelengths if b not in pace_band_wavelengths]
        if missing:
            raise ValueError(
                f"❌ The following wavelengths are not found in the PACE data: {missing}"
            )
        # Extract indices in the order of full_band_wavelengths
        indices = [
            np.where(pace_band_wavelengths == b)[0][0] for b in full_band_wavelengths
        ]
        band_wavelengths = pace_band_wavelengths[indices]
        assert (
            band_wavelengths == np.array(full_band_wavelengths)
        ).all(), "❌ Band order is inconsistent"
        # Extract Rrs for selected_bands
        filtered_Rrs = Rrs[:, :, indices]
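        # === Mask construction (requires the 440 nm and 560 nm bands) ===
        for required in (440, 560):
            if required not in band_wavelengths:
                raise ValueError(
                    f"❌ full_band_wavelengths must include {required} nm to build the mask"
                )
        idx_440 = np.where(band_wavelengths == 440)[0][0]
        idx_560 = np.where(band_wavelengths == 560)[0][0]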
        Rrs_440 = filtered_Rrs[:, :, idx_440]
        Rrs_560 = filtered_Rrs[:, :, idx_560]
        mask_nanfree = np.all(~np.isnan(filtered_Rrs), axis=2)
        mask_condition = Rrs_560 >= Rrs_440
        mask = mask_nanfree & mask_condition
        if not np.any(mask):
            raise ValueError("❌ No valid pixels passed the filtering.")
        valid_data = filtered_Rrs[mask]  # [num_pixel, num_band]
        # === Check if smoothing is needed (only executed when any differencing is enabled) ===
        if diff_before_norm or diff_after_norm:
            from scipy.ndimage import gaussian_filter1d
            Rrs_smoothed = np.array(
                [gaussian_filter1d(spectrum, sigma=1) for spectrum in valid_data]
            )
            print("✅ Gaussian smoothing applied")
        else:
            Rrs_smoothed = valid_data
            print("✅ Smoothing not enabled")
        # === Optional first-order differencing before normalization ===
        if diff_before_norm:
            Rrs_preprocessed = np.diff(Rrs_smoothed, axis=1)
            print("✅ Differencing before normalization applied")
        else:
            Rrs_preprocessed = Rrs_smoothed
            print("✅ Differencing before normalization not enabled")
        # === Per-pixel MinMax normalization to [1, 10] ===
        # One scaler per pixel: each spectrum is rescaled independently across its bands
        scalers = [MinMaxScaler((1, 10)) for _ in range(Rrs_preprocessed.shape[0])]
        Rrs_normalized = np.array(
            [
                scalers[i].fit_transform(row.reshape(-1, 1)).flatten()
                for i, row in enumerate(Rrs_preprocessed)
            ]
        )
        # === Optional first-order differencing after normalization ===
        if diff_after_norm:
            Rrs_normalized = np.diff(Rrs_normalized, axis=1)
            print("✅ Differencing after normalization applied")
        else:
            print("✅ Differencing after normalization not enabled")
        # === Construct DataLoader ===
        from torch.utils.data import DataLoader, TensorDataset

        test_tensor = TensorDataset(torch.tensor(Rrs_normalized).float())
        test_loader = DataLoader(test_tensor, batch_size=2048, shuffle=False)
        return test_loader, Rrs, mask, latitude, longitude
    except Exception as e:
        print(f"❌ [ERROR] Failed to process file {nc_path}: {e}")
        return None
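Because differencing and per-spectrum scaling run in a fixed order, the two flags in `preprocess_pace_data_minmax` change both the feature count and the value range the model sees. Differencing before normalization gives `n_bands - 1` features in [1, 10]. Differencing after normalization gives `n_bands - 1` features in [-9, 9]. Enabling both differences twice, giving `n_bands - 2` features. The sketch below reproduces that ordering on synthetic spectra; `preprocess_like_minmax` is a hypothetical stand-in, and Gaussian smoothing is omitted because it changes neither shapes nor these bounds:

```python
import numpy as np


def minmax_rows(x, lo=1.0, hi=10.0):
    # Per-spectrum rescaling to [lo, hi]; constant rows map to lo
    row_min = x.min(axis=1, keepdims=True)
    row_range = x.max(axis=1, keepdims=True) - row_min
    row_range[row_range == 0] = 1.0
    return lo + (x - row_min) * (hi - lo) / row_range


def preprocess_like_minmax(x, diff_before_norm=False, diff_after_norm=False):
    """Hypothetical stand-in for the order of operations (smoothing omitted)."""
    if diff_before_norm:
        x = np.diff(x, axis=1)
    x = minmax_rows(x)
    if diff_after_norm:
        x = np.diff(x, axis=1)
    return x


rng = np.random.default_rng(42)
spectra = rng.random((4, 10))  # 4 pixels x 10 bands (synthetic)

for before, after in [(False, False), (True, False), (False, True), (True, True)]:
    out = preprocess_like_minmax(spectra, before, after)
    print(before, after, out.shape, round(out.min(), 2), round(out.max(), 2))
```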