Skip to content

model_inference module

Model inference and visualization utilities for EMIT data.

This module provides functions for running model inference on EMIT hyperspectral data, visualizing results with RGB backgrounds, and converting predictions to GeoTIFF format.

infer_and_visualize_single_model_Robust(model, test_loader, Rrs, mask, latitude, longitude, save_folder, rgb_nc_file, structure_name, TSS_scalers_dict, vmin=0, vmax=50, exposure_coefficient=5.0)

Run model inference and visualize results with RGB background using robust scaling.

This function performs model inference on preprocessed EMIT data, applies inverse transformations using robust scalers, creates a visualization overlaid on an RGB composite, and saves both array and image outputs.

Parameters:

Name Type Description Default
model Module

Trained PyTorch model for inference.

required
test_loader DataLoader

DataLoader containing preprocessed test data.

required
Rrs ndarray

Array of remote sensing reflectance data with shape (H, W, B).

required
mask ndarray

Boolean mask indicating valid pixels with shape (H, W).

required
latitude ndarray

Array of latitude values.

required
longitude ndarray

Array of longitude values.

required
save_folder str

Directory path to save outputs.

required
rgb_nc_file str

Path to NetCDF file containing RGB bands for visualization.

required
structure_name str

Name for output files.

required
TSS_scalers_dict Dict[str, Any]

Dictionary containing 'log' and 'robust' scalers for inverse transformation.

required
vmin float

Minimum value for colormap scaling.

0
vmax float

Maximum value for colormap scaling.

50
exposure_coefficient float

Coefficient for adjusting RGB brightness.

5.0

Returns:

Type Description
final_output

Array of shape (N, 3) containing [lat, lon, value] for each pixel.

Source code in hypercoast/emit_utils/model_inference.py
def infer_and_visualize_single_model_Robust(
    model: torch.nn.Module,
    test_loader: DataLoader,
    Rrs: np.ndarray,
    mask: np.ndarray,
    latitude: np.ndarray,
    longitude: np.ndarray,
    save_folder: str,
    rgb_nc_file: str,
    structure_name: str,
    TSS_scalers_dict: Dict[str, Any],
    vmin: float = 0,
    vmax: float = 50,
    exposure_coefficient: float = 5.0,
) -> np.ndarray:
    """Run model inference and visualize results with RGB background using robust scaling.

    This function performs model inference on preprocessed EMIT data, applies
    inverse transformations using robust scalers, creates a visualization overlaid
    on an RGB composite, and saves both array and image outputs.

    Args:
        model: Trained PyTorch model for inference.
        test_loader: DataLoader containing preprocessed test data.
        Rrs: Array of remote sensing reflectance data with shape (H, W, B).
        mask: Boolean mask indicating valid pixels with shape (H, W).
        latitude: Array of latitude values.
        longitude: Array of longitude values.
        save_folder: Directory path to save outputs.
        rgb_nc_file: Path to NetCDF file containing RGB bands for visualization.
        structure_name: Name for output files.
        TSS_scalers_dict: Dictionary containing 'log' and 'robust' scalers for
            inverse transformation.
        vmin: Minimum value for colormap scaling.
        vmax: Maximum value for colormap scaling.
        exposure_coefficient: Coefficient for adjusting RGB brightness.

    Returns:
        final_output: Array of shape (N, 3) containing [lat, lon, value] for each pixel.
    """
    # Run inference on whichever device the model already lives on.
    device = next(model.parameters()).device
    predictions_all = []

    # === Model inference ===
    with torch.no_grad():
        for batch in test_loader:
            # TensorDataset batches are 1-tuples; unpack the features tensor.
            batch = batch[0].to(device)
            output_dict = model(batch)

            # === Inverse transform using scalers ===
            # Undo training-time scaling in reverse order: first the robust
            # scaler, then the log transform. Both scalers operate on torch
            # tensors (hence the round-trip through numpy -> tensor -> numpy).
            predictions_log = TSS_scalers_dict["robust"].inverse_transform(
                torch.tensor(output_dict["pred_y"].cpu().numpy(), dtype=torch.float32)
            )
            predictions_real = (
                TSS_scalers_dict["log"].inverse_transform(predictions_log).numpy()
            )
            predictions_all.append(predictions_real)

    # Concatenate batches and drop the trailing singleton channel: (N, 1) -> (N,).
    predictions_all = np.vstack(predictions_all).squeeze(-1)

    # Fill predictions into 2D array according to mask
    # Invalid pixels stay NaN so they are excluded from interpolation/plotting.
    outputs = np.full((Rrs.shape[0], Rrs.shape[1]), np.nan)
    outputs[mask] = predictions_all

    # Save as [lat, lon, value]
    lat_flat = latitude.flatten()
    lon_flat = longitude.flatten()
    output_flat = outputs.flatten()
    final_output = np.column_stack((lat_flat, lon_flat, output_flat))
    # NetCDF coordinates may be masked arrays; convert masked entries to NaN
    # so downstream np.save/griddata see plain floats.
    if np.ma.isMaskedArray(final_output):
        final_output = final_output.filled(np.nan)

    os.makedirs(save_folder, exist_ok=True)
    # structure_name may carry an extension/path; keep only the stem for filenames.
    base_name = os.path.splitext(os.path.basename(structure_name))[0]
    npy_path = os.path.join(save_folder, f"{base_name}.npy")
    png_path = os.path.join(save_folder, f"{base_name}.png")
    np.save(npy_path, final_output)

    # === Construct RGB image from EMIT L2R .nc ===
    with Dataset(rgb_nc_file) as ds:
        # Latitude
        if "lat" in ds.variables:
            lat_var = ds.variables["lat"][:]
        elif "latitude" in ds.variables:
            lat_var = ds.variables["latitude"][:]
        else:
            raise KeyError("Latitude variable not found")

        # Longitude
        if "lon" in ds.variables:
            lon_var = ds.variables["lon"][:]
        elif "longitude" in ds.variables:
            lon_var = ds.variables["longitude"][:]
        else:
            raise KeyError("Longitude variable not found")

        # rhos band list
        # Collect all surface-reflectance bands named like "rhos_<wavelength>"
        # (integer or decimal wavelengths), keeping (wavelength, var_name) pairs.
        band_list = []
        for name in ds.variables:
            m = re.match(r"^rhos_(\d+(?:\.\d+)?)$", name)
            if m:
                wl = float(m.group(1))
                band_list.append((wl, name))
        if not band_list:
            raise ValueError("No rhos_* bands found")

        # Select nearest RGB bands
        # Target wavelengths (nm) approximating red/green/blue band centers.
        targets = {"R": 664.0, "G": 559.0, "B": 492.0}

        def pick_nearest(target_nm):
            # Return the variable name whose wavelength is closest to target_nm.
            return min(band_list, key=lambda x: abs(x[0] - target_nm))[1]

        var_R = pick_nearest(targets["R"])
        var_G = pick_nearest(targets["G"])
        var_B = pick_nearest(targets["B"])

        R = ds.variables[var_R][:]
        G = ds.variables[var_G][:]
        B = ds.variables[var_B][:]

        # Convert masked fill values to NaN for consistent interpolation.
        if isinstance(R, np.ma.MaskedArray):
            R = R.filled(np.nan)
        if isinstance(G, np.ma.MaskedArray):
            G = G.filled(np.nan)
        if isinstance(B, np.ma.MaskedArray):
            B = B.filled(np.nan)

    # Lat/lon grid
    # 1-D coordinate vectors become 2-D grids; "ij" keeps latitude on axis 0.
    if lat_var.ndim == 1 and lon_var.ndim == 1:
        lat2d, lon2d = np.meshgrid(lat_var, lon_var, indexing="ij")
    else:
        lat2d, lon2d = lat_var, lon_var

    H, W = R.shape
    lat_flat = lat2d.reshape(-1)
    lon_flat = lon2d.reshape(-1)
    R_flat, G_flat, B_flat = R.reshape(-1), G.reshape(-1), B.reshape(-1)

    # grid_lat runs from max (top) to min (bottom) so row order matches the
    # origin="upper" convention used by imshow below.
    lat_top, lat_bot = np.nanmax(lat2d), np.nanmin(lat2d)
    lon_min, lon_max = np.nanmin(lon2d), np.nanmax(lon2d)
    grid_lat = np.linspace(lat_top, lat_bot, H)
    grid_lon = np.linspace(lon_min, lon_max, W)
    grid_lon, grid_lat = np.meshgrid(grid_lon, grid_lat)

    # Resample each swath band onto the regular lat/lon grid.
    R_interp = griddata(
        (lon_flat, lat_flat), R_flat, (grid_lon, grid_lat), method="linear"
    )
    G_interp = griddata(
        (lon_flat, lat_flat), G_flat, (grid_lon, grid_lat), method="linear"
    )
    B_interp = griddata(
        (lon_flat, lat_flat), B_flat, (grid_lon, grid_lat), method="linear"
    )

    rgb_image = np.stack((R_interp, G_interp, B_interp), axis=-1)
    rgb_max = np.nanmax(rgb_image)
    # Guard against an all-NaN or all-zero composite (avoid divide-by-zero/NaN).
    if not np.isfinite(rgb_max) or rgb_max == 0:
        rgb_max = 1.0
    # Normalize to [0, 1], brighten by exposure_coefficient, then clip.
    rgb_image = np.clip((rgb_image / rgb_max) * exposure_coefficient, 0, 1)
    extent_raw = [lon_min, lon_max, lat_bot, lat_top]

    # Interpolate predictions to same grid
    interp_output = griddata(
        (final_output[:, 1], final_output[:, 0]),  # lon, lat
        final_output[:, 2],
        (grid_lon, grid_lat),
        method="linear",
    )
    # Mask NaNs so imshow leaves invalid pixels unpainted (RGB shows through).
    interp_output = np.ma.masked_invalid(interp_output)

    # Plot and save PNG
    plt.figure(figsize=(24, 6))
    plt.imshow(rgb_image, extent=extent_raw, origin="upper")
    im = plt.imshow(
        interp_output,
        extent=extent_raw,
        cmap="jet",
        alpha=1,
        origin="upper",
        vmin=vmin,
        vmax=vmax,
    )
    cbar = plt.colorbar(im)
    # cbar.set_label('(mg m$^{-3}$)', fontsize=16)
    plt.title(f"{structure_name}", loc="left", fontsize=20)
    plt.savefig(png_path, dpi=300, bbox_inches="tight", pad_inches=0.1)
    plt.show()

    print(f"✅ Saved {png_path}")
    print(f"✅ Saved {npy_path} (for npy_to_tif)")

    # Return numpy array for direct use
    return final_output

infer_and_visualize_single_model_minmax(model, test_loader, Rrs, mask, latitude, longitude, save_folder, rgb_nc_file, structure_name, vmin=0, vmax=50, log_offset=0.01, exposure_coefficient=5.0)

Run model inference and visualize results with RGB background using MinMax scaling.

This function performs model inference on preprocessed EMIT data, applies inverse log transformation, creates a visualization overlaid on an RGB composite, and saves both array and image outputs.

Parameters:

Name Type Description Default
model Module

Trained PyTorch model for inference.

required
test_loader DataLoader

DataLoader containing preprocessed test data.

required
Rrs ndarray

Array of remote sensing reflectance data with shape (H, W, B).

required
mask ndarray

Boolean mask indicating valid pixels with shape (H, W).

required
latitude ndarray

Array of latitude values.

required
longitude ndarray

Array of longitude values.

required
save_folder str

Directory path to save outputs.

required
rgb_nc_file str

Path to NetCDF file containing RGB bands for visualization.

required
structure_name str

Name for output files.

required
vmin float

Minimum value for colormap scaling.

0
vmax float

Maximum value for colormap scaling.

50
log_offset float

Offset used in log transformation during preprocessing.

0.01
exposure_coefficient float

Coefficient for adjusting RGB brightness.

5.0

Returns:

Type Description
final_output

Array of shape (N, 3) containing [lat, lon, value] for each pixel.

Source code in hypercoast/emit_utils/model_inference.py
def infer_and_visualize_single_model_minmax(
    model: torch.nn.Module,
    test_loader: DataLoader,
    Rrs: np.ndarray,
    mask: np.ndarray,
    latitude: np.ndarray,
    longitude: np.ndarray,
    save_folder: str,
    rgb_nc_file: str,
    structure_name: str,
    vmin: float = 0,
    vmax: float = 50,
    log_offset: float = 0.01,
    exposure_coefficient: float = 5.0,
) -> np.ndarray:
    """Run model inference and visualize results with RGB background using MinMax scaling.

    This function performs model inference on preprocessed EMIT data, applies
    inverse log transformation, creates a visualization overlaid on an RGB
    composite, and saves both array and image outputs.

    Args:
        model: Trained PyTorch model for inference.
        test_loader: DataLoader containing preprocessed test data.
        Rrs: Array of remote sensing reflectance data with shape (H, W, B).
        mask: Boolean mask indicating valid pixels with shape (H, W).
        latitude: Array of latitude values.
        longitude: Array of longitude values.
        save_folder: Directory path to save outputs.
        rgb_nc_file: Path to NetCDF file containing RGB bands for visualization.
        structure_name: Name for output files.
        vmin: Minimum value for colormap scaling.
        vmax: Maximum value for colormap scaling.
        log_offset: Offset used in log transformation during preprocessing.
        exposure_coefficient: Coefficient for adjusting RGB brightness.

    Returns:
        final_output: Array of shape (N, 3) containing [lat, lon, value] for each pixel.
    """
    # Run inference on whichever device the model already lives on.
    device = next(model.parameters()).device
    predictions_all = []

    # === Model inference ===
    with torch.no_grad():
        for batch in test_loader:
            # TensorDataset batches are 1-tuples; unpack the features tensor.
            batch = batch[0].to(device)
            output_dict = model(batch)
            predictions = output_dict["pred_y"]
            predictions_np = predictions.cpu().numpy()
            # Invert the training-time log10(y + offset) transform.
            predictions_original = (10**predictions_np) - log_offset
            predictions_all.append(predictions_original)

    # Concatenate batches and drop the trailing singleton channel: (N, 1) -> (N,).
    predictions_all = np.vstack(predictions_all).squeeze(-1)

    # Fill predictions into 2D array according to mask
    # Invalid pixels stay NaN so they are excluded from interpolation/plotting.
    outputs = np.full((Rrs.shape[0], Rrs.shape[1]), np.nan)
    outputs[mask] = predictions_all

    # Flatten lat/lon and combine with predictions
    lat_flat = latitude.flatten()
    lon_flat = longitude.flatten()
    output_flat = outputs.flatten()
    final_output = np.column_stack((lat_flat, lon_flat, output_flat))
    # NetCDF coordinates may be masked arrays; convert masked entries to NaN.
    if np.ma.isMaskedArray(final_output):
        final_output = final_output.filled(np.nan)

    # Save .npy file (lat, lon, value)
    os.makedirs(save_folder, exist_ok=True)
    base_name = os.path.splitext(os.path.basename(structure_name))[0]
    npy_path = os.path.join(save_folder, f"{base_name}.npy")
    png_path = os.path.join(save_folder, f"{base_name}.png")
    np.save(npy_path, final_output)

    # === Read RGB bands from .nc file ===
    with Dataset(rgb_nc_file) as ds:
        if "lat" in ds.variables:
            lat_var = ds.variables["lat"][:]
        elif "latitude" in ds.variables:
            lat_var = ds.variables["latitude"][:]
        else:
            raise KeyError("Latitude variable not found (lat/latitude)")

        if "lon" in ds.variables:
            lon_var = ds.variables["lon"][:]
        elif "longitude" in ds.variables:
            lon_var = ds.variables["longitude"][:]
        else:
            raise KeyError("Longitude variable not found (lon/longitude)")

        # Collect all surface-reflectance bands named like "rhos_<wavelength>".
        band_list = []
        for name in ds.variables.keys():
            m = re.match(r"^rhos_(\d+(?:\.\d+)?)$", name)
            if m:
                wl = float(m.group(1))
                band_list.append((wl, name))
        if not band_list:
            raise ValueError("No rhos_* bands found in file")

        # Target wavelengths (nm) approximating red/green/blue band centers.
        targets = {"R": 664.0, "G": 559.0, "B": 492.0}

        def pick_nearest(target_nm):
            # Return (wavelength, variable name) of the closest available band.
            idx = int(np.argmin([abs(w - target_nm) for w, _ in band_list]))
            wl_sel, name_sel = band_list[idx]
            return wl_sel, name_sel

        wl_R, var_R = pick_nearest(targets["R"])
        wl_G, var_G = pick_nearest(targets["G"])
        wl_B, var_B = pick_nearest(targets["B"])

        # BUGFIX: the original f-strings were missing the opening parenthesis
        # before the wavelength offset, producing output like "R→rhos_664+0.3nm)".
        print(
            f"RGB band selection: R→{var_R} ({wl_R - targets['R']:+.1f}nm), "
            f"G→{var_G} ({wl_G - targets['G']:+.1f}nm), "
            f"B→{var_B} ({wl_B - targets['B']:+.1f}nm)"
        )

        R = ds.variables[var_R][:]
        G = ds.variables[var_G][:]
        B = ds.variables[var_B][:]
        # Convert masked fill values to NaN for consistent interpolation.
        if isinstance(R, np.ma.MaskedArray):
            R = R.filled(np.nan)
        if isinstance(G, np.ma.MaskedArray):
            G = G.filled(np.nan)
        if isinstance(B, np.ma.MaskedArray):
            B = B.filled(np.nan)

    # 1-D coordinate vectors become 2-D grids; "ij" keeps latitude on axis 0.
    if lat_var.ndim == 1 and lon_var.ndim == 1:
        lat2d, lon2d = np.meshgrid(
            np.asarray(lat_var), np.asarray(lon_var), indexing="ij"
        )
    else:
        lat2d, lon2d = np.asarray(lat_var), np.asarray(lon_var)

    H, W = R.shape
    lat_flat = lat2d.reshape(-1)
    lon_flat = lon2d.reshape(-1)
    R_flat, G_flat, B_flat = R.reshape(-1), G.reshape(-1), B.reshape(-1)

    # grid_lat runs from max (top) to min (bottom) so row order matches
    # origin="upper" in imshow below.
    lat_top, lat_bot = np.nanmax(lat2d), np.nanmin(lat2d)
    lon_min, lon_max = np.nanmin(lon2d), np.nanmax(lon2d)
    grid_lat = np.linspace(lat_top, lat_bot, H)
    grid_lon = np.linspace(lon_min, lon_max, W)
    grid_lon, grid_lat = np.meshgrid(grid_lon, grid_lat)

    # Resample each swath band onto the regular lat/lon grid.
    R_interp = griddata(
        (lon_flat, lat_flat), R_flat, (grid_lon, grid_lat), method="linear"
    )
    G_interp = griddata(
        (lon_flat, lat_flat), G_flat, (grid_lon, grid_lat), method="linear"
    )
    B_interp = griddata(
        (lon_flat, lat_flat), B_flat, (grid_lon, grid_lat), method="linear"
    )

    rgb_image = np.stack((R_interp, G_interp, B_interp), axis=-1)
    rgb_max = np.nanmax(rgb_image)
    # Guard against an all-NaN or all-zero composite (avoid divide-by-zero/NaN).
    if not np.isfinite(rgb_max) or rgb_max == 0:
        rgb_max = 1.0
    # Normalize to [0, 1], brighten by exposure_coefficient, then clip.
    rgb_image = np.clip((rgb_image / rgb_max) * exposure_coefficient, 0, 1)
    extent_raw = [lon_min, lon_max, lat_bot, lat_top]

    # Interpolate predictions onto the same grid as the RGB background.
    interp_output = griddata(
        (final_output[:, 1], final_output[:, 0]),
        final_output[:, 2],
        (grid_lon, grid_lat),
        method="linear",
    )
    # Mask NaNs so imshow leaves invalid pixels unpainted (RGB shows through).
    interp_output = np.ma.masked_invalid(interp_output)

    plt.figure(figsize=(24, 6))
    plt.imshow(rgb_image, extent=extent_raw, origin="upper")
    im = plt.imshow(
        interp_output,
        extent=extent_raw,
        cmap="jet",
        alpha=1,
        origin="upper",
        vmin=vmin,
        vmax=vmax,
    )
    cbar = plt.colorbar(im)
    # cbar.set_label('(mg m$^{-3}$)', fontsize=16)
    plt.title(f"{structure_name}", loc="left", fontsize=20)
    plt.savefig(png_path, dpi=300, bbox_inches="tight", pad_inches=0.1)
    plt.show()

    print(f"✅ Saved {png_path}")
    print(f"✅ Saved {npy_path} (for npy_to_tif)")

    # Return numpy array for direct use
    return final_output

npy_to_tif(npy_input, out_tif, resolution_m=30, method='linear', nodata_val=-9999.0, bbox_padding=0.0, lat_col=0, lon_col=1, band_cols=None, band_names=None, wavelengths=None, crs='EPSG:4326', compress='deflate', bigtiff='IF_SAFER')

Convert scattered point data to a multi-band GeoTIFF.

This function takes point data in the format [lat, lon, band1, band2, ...] and interpolates it onto a regular grid to create a georeferenced GeoTIFF file.

Parameters:

Name Type Description Default
npy_input Union[str, numpy.ndarray]

Path to .npy file or numpy array of shape [N, M] where N is the number of points and M includes lat, lon, and band values.

required
out_tif str

Output path for the GeoTIFF file.

required
resolution_m float

Spatial resolution in meters.

30
method str

Interpolation method ('linear', 'nearest', 'cubic').

'linear'
nodata_val float

Value to use for NoData pixels.

-9999.0
bbox_padding float

Padding to add to bounding box in degrees.

0.0
lat_col int

Column index containing latitude values.

0
lon_col int

Column index containing longitude values.

1
band_cols Union[int, List[int]]

Column indices to rasterize as bands. If None, uses all columns except lat/lon.

None
band_names Optional[List[str]]

Optional list of band description names.

None
wavelengths Optional[List[float]]

Optional list of wavelengths for band descriptions (e.g., [440, 619]).

None
crs str

Coordinate reference system string.

'EPSG:4326'
compress str

Compression method for GeoTIFF.

'deflate'
bigtiff str

BigTIFF creation option ('YES', 'NO', 'IF_NEEDED', 'IF_SAFER').

'IF_SAFER'

Exceptions:

Type Description
TypeError

If npy_input is neither a path string nor numpy array.

ValueError

If input array has fewer than 3 columns or no bands are selected.

Source code in hypercoast/emit_utils/model_inference.py
def npy_to_tif(
    npy_input: Union[str, np.ndarray],
    out_tif: str,
    resolution_m: float = 30,
    method: str = "linear",
    nodata_val: float = -9999.0,
    bbox_padding: float = 0.0,
    lat_col: int = 0,
    lon_col: int = 1,
    band_cols: Optional[Union[int, List[int]]] = None,
    band_names: Optional[List[str]] = None,
    wavelengths: Optional[List[float]] = None,
    crs: str = "EPSG:4326",
    compress: str = "deflate",
    bigtiff: str = "IF_SAFER",
) -> None:
    """Convert scattered point data to a multi-band GeoTIFF.

    This function takes point data in the format [lat, lon, band1, band2, ...]
    and interpolates it onto a regular grid to create a georeferenced GeoTIFF file.

    Args:
        npy_input: Path to .npy file or numpy array of shape [N, M] where N is
            the number of points and M includes lat, lon, and band values.
        out_tif: Output path for the GeoTIFF file.
        resolution_m: Spatial resolution in meters.
        method: Interpolation method ('linear', 'nearest', 'cubic').
        nodata_val: Value to use for NoData pixels.
        bbox_padding: Padding to add to bounding box in degrees.
        lat_col: Column index containing latitude values.
        lon_col: Column index containing longitude values.
        band_cols: Column indices to rasterize as bands. If None, uses all columns
            except lat/lon.
        band_names: Optional list of band description names.
        wavelengths: Optional list of wavelengths for band descriptions (e.g., [440, 619]).
        crs: Coordinate reference system string.
        compress: Compression method for GeoTIFF.
        bigtiff: BigTIFF creation option ('YES', 'NO', 'IF_NEEDED', 'IF_SAFER').

    Raises:
        TypeError: If npy_input is neither a path string nor numpy array.
        ValueError: If input array has fewer than 3 columns or no bands are selected.
    """

    # --- 1) Load data ---
    if isinstance(npy_input, str):
        arr = np.load(npy_input)
    elif isinstance(npy_input, np.ndarray):
        arr = npy_input
    else:
        raise TypeError("npy_input must be either a path string or a numpy.ndarray.")

    if arr.ndim != 2 or arr.shape[1] < 3:
        raise ValueError("Input must be 2D with >=3 columns (lat, lon, values...).")

    lat = arr[:, lat_col].astype(float)
    lon = arr[:, lon_col].astype(float)

    # --- 2) Band selection ---
    if band_cols is None:
        band_cols = [i for i in range(arr.shape[1]) if i not in (lat_col, lon_col)]
    if isinstance(band_cols, (int, np.integer)):
        band_cols = [int(band_cols)]
    if len(band_cols) == 0:
        raise ValueError("No value columns selected for bands.")

    # --- 3) Bounds (+ padding) ---
    lat_min, lat_max = np.nanmin(lat), np.nanmax(lat)
    lon_min, lon_max = np.nanmin(lon), np.nanmax(lon)
    lat_min -= bbox_padding
    lat_max += bbox_padding
    lon_min -= bbox_padding
    lon_max += bbox_padding

    # --- 4) Resolution conversion ---
    # Approximate meters-per-degree at the scene's central latitude
    # (~111 km per degree of latitude; longitude shrinks with cos(lat)).
    lat_center = (lat_min + lat_max) / 2.0
    deg_per_m_lat = 1.0 / 111000.0
    deg_per_m_lon = 1.0 / (111000.0 * np.cos(np.radians(lat_center)))
    res_lat_deg = resolution_m * deg_per_m_lat
    res_lon_deg = resolution_m * deg_per_m_lon

    # --- 5) Grid ---
    lon_axis = np.arange(lon_min, lon_max + res_lon_deg, res_lon_deg)
    lat_axis = np.arange(lat_min, lat_max + res_lat_deg, res_lat_deg)
    Lon, Lat = np.meshgrid(lon_axis, lat_axis)

    # North-up affine transform: origin at top-left (min lon, max lat).
    transform = from_origin(lon_axis.min(), lat_axis.max(), res_lon_deg, res_lat_deg)

    # --- 6) Interpolation ---
    grids = []
    for idx in band_cols:
        vals = arr[:, idx].astype(float)

        g = griddata(points=(lon, lat), values=vals, xi=(Lon, Lat), method=method)
        # BUGFIX: the original fallback re-interpolated with the SAME method,
        # so g_near had identical NaNs and the gap fill was a no-op. Use
        # nearest-neighbour to fill pixels outside the convex hull / gaps.
        # (Skipped when method is already "nearest" — nothing new to fill.)
        if method != "nearest" and np.isnan(g).any():
            g_near = griddata(
                points=(lon, lat), values=vals, xi=(Lon, Lat), method="nearest"
            )
            g = np.where(np.isnan(g), g_near, g)

        # Flip rows so row 0 is the northernmost line, matching the transform.
        grids.append(np.flipud(g).astype(np.float32))

    data_stack = np.stack(grids, axis=0)

    # --- 7) Write GeoTIFF ---
    profile = {
        "driver": "GTiff",
        "height": data_stack.shape[1],
        "width": data_stack.shape[2],
        "count": data_stack.shape[0],
        "dtype": rasterio.float32,
        "crs": crs,
        "transform": transform,
        "nodata": nodata_val,
        "compress": compress,
        "tiled": True,
        "blockxsize": 256,
        "blockysize": 256,
        "BIGTIFF": bigtiff,
    }

    os.makedirs(os.path.dirname(out_tif) or ".", exist_ok=True)
    with rasterio.open(out_tif, "w", **profile) as dst:
        for b in range(data_stack.shape[0]):
            band = data_stack[b]
            # Replace any remaining NaN/inf with the declared NoData value.
            band[~np.isfinite(band)] = nodata_val
            dst.write(band, b + 1)

        # Descriptions: explicit names win, then wavelength-derived names,
        # then a generic "band_<source column>" fallback.
        n_bands = data_stack.shape[0]
        if band_names is not None and len(band_names) == n_bands:
            descriptions = list(map(str, band_names))
        elif wavelengths is not None and len(wavelengths) == n_bands:
            descriptions = [f"aphy_{int(wl)}" for wl in wavelengths]
        else:
            descriptions = [f"band_{band_cols[b]}" for b in range(n_bands)]

        for b in range(1, n_bands + 1):
            dst.set_band_description(b, descriptions[b - 1])

    print(f"✅ GeoTIFF saved: {out_tif}")

preprocess_emit_data_Robust(nc_path, scaler_Rrs, use_diff=False, full_band_wavelengths=None)

Preprocess EMIT NetCDF data using robust scaling for model inference.

This function reads EMIT L2 data from NetCDF, extracts specified wavelength bands, applies filtering and masking, optionally performs spectral smoothing and differencing, and returns a DataLoader ready for inference.

Parameters:

Name Type Description Default
nc_path str

Path to EMIT L2 NetCDF file.

required
scaler_Rrs Any

Pre-fitted robust scaler from training.

required
use_diff bool

Whether to apply Gaussian smoothing and first-order differencing.

False
full_band_wavelengths Optional[List[float]]

List of wavelengths (nm) to extract. If a band is not present, the closest available band is used.

None

Returns:

Type Description
test_loader

A tuple of (test_loader, filtered_Rrs, mask, latitude, longitude): test_loader is a DataLoader containing preprocessed data for inference; filtered_Rrs is an array of extracted Rrs bands with shape (H, W, B); mask is a Boolean mask indicating valid pixels with shape (H, W); latitude and longitude are arrays of coordinate values.

Exceptions:

Type Description
ValueError

If full_band_wavelengths is not provided or no Rrs bands are found.

Source code in hypercoast/emit_utils/model_inference.py
def preprocess_emit_data_Robust(
    nc_path: str,
    scaler_Rrs: Any,
    use_diff: bool = False,
    full_band_wavelengths: Optional[List[float]] = None,
) -> Tuple[DataLoader, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Preprocess EMIT NetCDF data using robust scaling for model inference.

    This function reads EMIT L2 data from NetCDF, extracts specified wavelength
    bands, applies filtering and masking, optionally performs spectral smoothing
    and differencing, and returns a DataLoader ready for inference.

    Args:
        nc_path: Path to EMIT L2 NetCDF file.
        scaler_Rrs: Pre-fitted robust scaler from training.
        use_diff: Whether to apply Gaussian smoothing and first-order differencing.
        full_band_wavelengths: List of wavelengths (nm) to extract. If a band is
            not present, the closest available band is used.

    Returns:
        test_loader: DataLoader containing preprocessed data for inference.
        filtered_Rrs: Array of extracted Rrs bands with shape (H, W, B).
        mask: Boolean mask indicating valid pixels with shape (H, W).
        latitude: Array of latitude values.
        longitude: Array of longitude values.

    Raises:
        ValueError: If full_band_wavelengths is not provided or no Rrs bands are found.
    """

    if full_band_wavelengths is None:
        raise ValueError(
            "full_band_wavelengths must be provided to match EMIT Rrs bands"
        )

    def find_closest_band(target, available_bands):
        # Band names are expected to be integer-wavelength, e.g. "Rrs_443";
        # a decimal suffix would raise in int() — TODO confirm against the data.
        rrs_bands = [b for b in available_bands if b.startswith("Rrs_")]
        available_waves = [int(b.split("_")[1]) for b in rrs_bands]
        if not available_waves:
            raise ValueError("❌ No Rrs_* bands found in dataset")
        closest_wave = min(available_waves, key=lambda w: abs(w - target))
        return f"Rrs_{closest_wave}"

    # BUGFIX: the original opened the dataset without ever closing it, leaking
    # the file handle. Use a context manager, matching preprocess_emit_data_minmax.
    with Dataset(nc_path) as dataset:
        latitude = dataset.variables["lat"][:]
        longitude = dataset.variables["lon"][:]

        all_vars = dataset.variables.keys()

        # Resolve each requested wavelength to an existing Rrs_* variable,
        # falling back to the nearest available band with a warning.
        bands_to_extract = []
        for w in full_band_wavelengths:
            band_name = f"Rrs_{int(w)}"
            if band_name in all_vars:
                bands_to_extract.append(band_name)
            else:
                closest = find_closest_band(int(w), all_vars)
                print(
                    f"⚠️ {band_name} does not exist, using the closest band {closest}"
                )
                bands_to_extract.append(closest)
        # Stack as (B, H, W) then move bands to the last axis -> (H, W, B).
        filtered_Rrs = np.array(
            [dataset.variables[band][:] for band in bands_to_extract]
        )
    filtered_Rrs = np.moveaxis(filtered_Rrs, 0, -1)

    # Valid pixels must be NaN-free across every extracted band.
    mask = np.all(~np.isnan(filtered_Rrs), axis=2)

    # Water-pixel sanity check: Rrs(443) should not exceed Rrs(560).
    target_443 = (
        "Rrs_443"
        if "Rrs_443" in bands_to_extract
        else find_closest_band(443, bands_to_extract)
    )
    target_560 = (
        "Rrs_560"
        if "Rrs_560" in bands_to_extract
        else find_closest_band(560, bands_to_extract)
    )

    print(f"Using {target_443} and {target_560} for mask check.")

    idx_443 = bands_to_extract.index(target_443)
    idx_560 = bands_to_extract.index(target_560)
    mask &= filtered_Rrs[:, :, idx_443] <= filtered_Rrs[:, :, idx_560]

    # Flatten valid pixels into a (N_valid, B) sample matrix.
    valid_test_data = filtered_Rrs[mask]

    # ---- smooth + diff
    if use_diff:
        from scipy.ndimage import gaussian_filter1d

        # Smooth each spectrum along the band axis, then take the
        # first-order spectral difference (drops one band: B -> B-1).
        Rrs_smoothed = np.array(
            [gaussian_filter1d(spectrum, sigma=1) for spectrum in valid_test_data]
        )
        Rrs_processed = np.diff(Rrs_smoothed, axis=1)
        print("✅ [5] Performed Gaussian smoothing + first-order differencing")
    else:
        Rrs_processed = valid_test_data
        print("✅ [5] Smoothing and differencing not enabled")

    # ---- normalize
    # The robust scaler operates on torch tensors (see inverse use in
    # infer_and_visualize_single_model_Robust).
    Rrs_normalized = scaler_Rrs.transform(
        torch.tensor(Rrs_processed, dtype=torch.float32)
    ).numpy()

    # ---- DataLoader
    test_tensor = TensorDataset(torch.tensor(Rrs_normalized).float())
    test_loader = DataLoader(test_tensor, batch_size=2048, shuffle=False)
    print("✅ [6] DataLoader construction completed")

    return test_loader, filtered_Rrs, mask, latitude, longitude

preprocess_emit_data_minmax(nc_path, full_band_wavelengths=None, diff_before_norm=False, diff_after_norm=False)

Preprocess EMIT NetCDF data using sample-wise MinMax scaling for inference.

This function reads EMIT L2 data from NetCDF, extracts specified wavelength bands, applies filtering and masking, optionally performs spectral smoothing and differencing, applies sample-wise MinMax normalization, and returns a DataLoader ready for inference.

Parameters:

Name Type Description Default
nc_path str

Path to EMIT L2 NetCDF file.

required
full_band_wavelengths Optional[List[float]]

List of wavelengths (nm) to extract. If a band is not present, the closest available band is used.

None
diff_before_norm bool

Whether to apply differencing before normalization.

False
diff_after_norm bool

Whether to apply differencing after normalization.

False

Returns:

Type Description
test_loader

DataLoader containing preprocessed data for inference. Rrs: Array of extracted Rrs bands with shape (H, W, B). mask: Boolean mask indicating valid pixels with shape (H, W). latitude: Array of latitude values. longitude: Array of longitude values. Returns None if an error occurs during processing.

Exceptions:

Type Description
ValueError

If full_band_wavelengths is empty, no Rrs bands are found, or no valid pixels pass filtering.

Source code in hypercoast/emit_utils/model_inference.py
def preprocess_emit_data_minmax(
    nc_path: str,
    full_band_wavelengths: Optional[List[float]] = None,
    diff_before_norm: bool = False,
    diff_after_norm: bool = False,
) -> Optional[Tuple[DataLoader, np.ndarray, np.ndarray, np.ndarray, np.ndarray]]:
    """Preprocess EMIT NetCDF data using sample-wise MinMax scaling for inference.

    This function reads EMIT L2 data from NetCDF, extracts specified wavelength
    bands, applies filtering and masking, optionally performs spectral smoothing
    and differencing, applies sample-wise MinMax normalization to the range
    [1, 10], and returns a DataLoader ready for inference.

    Args:
        nc_path: Path to EMIT L2 NetCDF file.
        full_band_wavelengths: List of wavelengths (nm) to extract. If a band is
            not present, the closest available band is used.
        diff_before_norm: Whether to apply differencing before normalization.
        diff_after_norm: Whether to apply differencing after normalization.

    Returns:
        test_loader: DataLoader containing preprocessed data for inference.
        Rrs: Array of extracted Rrs bands with shape (H, W, B).
        mask: Boolean mask indicating valid pixels with shape (H, W).
        latitude: Array of latitude values.
        longitude: Array of longitude values.
        Returns None if an error occurs during processing.

    Raises:
        ValueError: If full_band_wavelengths is empty, no Rrs bands are found,
            or no valid pixels pass filtering.
    """
    print(f"📥 Start processing: {nc_path}")

    # ---- sanity checks
    if full_band_wavelengths is None or len(full_band_wavelengths) == 0:
        raise ValueError(
            "A non-empty full_band_wavelengths must be provided (e.g., [400, 402, ...])."
        )

    full_band_wavelengths = [int(w) for w in full_band_wavelengths]

    try:
        with Dataset(nc_path) as dataset:
            latitude = dataset.variables["lat"][:]
            longitude = dataset.variables["lon"][:]
            all_vars = set(dataset.variables.keys())

            # Map each available wavelength to its actual variable name so the
            # closest-band lookup never has to reconstruct a name. Rebuilding
            # via f"Rrs_{int(nearest)}" would truncate a non-integer wavelength
            # (e.g. "Rrs_440.5") and yield a variable that does not exist.
            band_by_wavelength = {
                float(v.split("_")[1]): v for v in all_vars if v.startswith("Rrs_")
            }

            # Check this BEFORE any nearest-band lookup; otherwise min() over an
            # empty sequence raises an unhelpful ValueError first.
            if not band_by_wavelength:
                raise ValueError("❌ No usable Rrs_* bands found in the file.")

            def find_closest_band(target_nm: float) -> str:
                # Variable name of the available band nearest to target_nm.
                nearest = min(band_by_wavelength, key=lambda w: abs(w - target_nm))
                return band_by_wavelength[nearest]

            # Resolve every requested wavelength to an existing variable name.
            bands_to_extract = []
            for w in full_band_wavelengths:
                band_name = f"Rrs_{w}"
                if band_name in all_vars:
                    bands_to_extract.append(band_name)
                else:
                    closest = find_closest_band(w)
                    print(
                        f"⚠️ {band_name} does not exist, using the closest band {closest}"
                    )
                    bands_to_extract.append(closest)

            # De-duplicate while preserving order (the nearest-band fallback can
            # map several requested wavelengths onto the same variable).
            seen = set()
            bands_to_extract = [
                b for b in bands_to_extract if not (b in seen or seen.add(b))
            ]

            # ---- read and stack to (H, W, B)
            # Each variable expected shape: (lat, lon) or (y, x)
            Rrs_stack = [dataset.variables[band][:] for band in bands_to_extract]
            Rrs = np.moveaxis(np.array(Rrs_stack), 0, -1)  # (B, H, W) -> (H, W, B)
            filtered_Rrs = Rrs  # keep naming consistent with previous return

            # ---- build mask using 440 & 560 (or nearest present within the
            #      user-requested set)
            def nearest_idx(target_nm: float) -> int:
                # Index within bands_to_extract of the band nearest target_nm.
                # Parse as float: int() would raise on a non-integer band name.
                waves = [float(b.split("_")[1]) for b in bands_to_extract]
                return min(range(len(waves)), key=lambda i: abs(waves[i] - target_nm))

            # Prefer exact if available; otherwise nearest in the requested set.
            idx_440 = (
                bands_to_extract.index("Rrs_440")
                if "Rrs_440" in bands_to_extract
                else nearest_idx(440)
            )
            idx_560 = (
                bands_to_extract.index("Rrs_560")
                if "Rrs_560" in bands_to_extract
                else nearest_idx(560)
            )

            print(
                f"✅ Bands used for mask check: {bands_to_extract[idx_440]} and {bands_to_extract[idx_560]}"
            )

            # Valid pixels: NaN-free across all bands AND Rrs(560) >= Rrs(440).
            mask_nanfree = np.all(~np.isnan(filtered_Rrs), axis=2)
            mask_condition = filtered_Rrs[:, :, idx_560] >= filtered_Rrs[:, :, idx_440]
            mask = mask_nanfree & mask_condition
            print(f"✅ [4] Built valid mask, remaining pixels: {int(np.sum(mask))}")

            if not np.any(mask):
                raise ValueError("❌ No valid pixels passed the filtering.")

            valid_test_data = filtered_Rrs[mask]  # (N, B)

        # === Check whether smoothing is needed (only executed if any differencing is enabled) ===
        if diff_before_norm or diff_after_norm:
            from scipy.ndimage import gaussian_filter1d

            # Smooth each spectrum along the band axis in one vectorized call
            # (equivalent to filtering every row independently).
            Rrs_smoothed = gaussian_filter1d(valid_test_data, sigma=1, axis=1)
            print("✅ Gaussian smoothing applied")
        else:
            Rrs_smoothed = valid_test_data
            print("✅ Smoothing not enabled")

        # === Preprocessing before differencing ===
        if diff_before_norm:
            Rrs_preprocessed = np.diff(Rrs_smoothed, axis=1)
            print("✅ Preprocessing before differencing completed")
        else:
            Rrs_preprocessed = Rrs_smoothed
            print("✅ Preprocessing before differencing not enabled")

        # === Sample-wise MinMax normalization to [1, 10] (vectorized) ===
        # Equivalent to fitting a MinMaxScaler((1, 10)) per spectrum, without
        # allocating N scaler objects in a Python loop.
        data = np.asarray(Rrs_preprocessed, dtype=np.float64)
        row_min = data.min(axis=1, keepdims=True)
        row_range = data.max(axis=1, keepdims=True) - row_min
        # Match sklearn's zero-range handling: a constant spectrum maps to the
        # lower bound (1) instead of dividing by zero.
        row_range[row_range == 0] = 1.0
        Rrs_normalized = (data - row_min) / row_range * 9.0 + 1.0

        # === Post-processing after differencing ===
        if diff_after_norm:
            Rrs_normalized = np.diff(Rrs_normalized, axis=1)
            print("✅ Post-processing after differencing completed")
        else:
            print("✅ Post-processing after differencing not enabled")

        # === Construct DataLoader
        test_tensor = TensorDataset(torch.tensor(Rrs_normalized).float())
        test_loader = DataLoader(test_tensor, batch_size=2048, shuffle=False)

        return test_loader, Rrs, mask, latitude, longitude

    except Exception as e:
        print(f"❌ [ERROR] Failed to process file {nc_path}: {e}")
        return None