model_inference module¶
Model inference utilities for MoE-VAE.
This module provides functions for preprocessing and inference using MoE-VAE models.
infer_and_visualize_single_model_Robust(model, test_loader, Rrs, mask, latitude, longitude, save_folder, extent, rgb_image, structure_name, run, TSS_scalers_dict, vmin=0, vmax=50)¶
Infer and visualize results from a single MoE-VAE model.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| model | torch.nn.Module | MoE-VAE model. | required |
| test_loader | DataLoader | DataLoader for test data. | required |
| Rrs | array | Rrs data. | required |
| mask | array | Boolean mask indicating valid pixels. | required |
| latitude | array | Latitude values. | required |
| longitude | array | Longitude values. | required |
| save_folder | str | Path to save the results. | required |
| extent | tuple | Tuple containing the extent of the image. | required |
| rgb_image | array | RGB image. | required |
| structure_name | str | Name of the structure. | required |
| run | int | Run number. | required |
| TSS_scalers_dict | dict | Dictionary containing the scalers for TSS. | required |
| vmin | float | Minimum value for the colorbar. | 0 |
| vmax | float | Maximum value for the colorbar. | 50 |
Source code in hypercoast/moe_vae/model_inference.py
def infer_and_visualize_single_model_Robust(
model,
test_loader,
Rrs,
mask,
latitude,
longitude,
save_folder,
extent,
rgb_image,
structure_name,
run,
TSS_scalers_dict,
vmin=0,
vmax=50,
):
"""
Infer and visualize results from a single MoE-VAE model.
Args:
model (torch.nn.Module): MoE-VAE model.
test_loader (DataLoader): DataLoader for test data.
Rrs (array): Rrs data.
mask (array): Boolean mask indicating valid pixels.
latitude (array): Latitude values.
longitude (array): Longitude values.
save_folder (str): Path to save the results.
extent (tuple): Tuple containing the extent of the image.
rgb_image (array): RGB image.
structure_name (str): Name of the structure.
run (int): Run number.
TSS_scalers_dict (dict): Dictionary containing the scalers for TSS.
vmin (float): Minimum value for the colorbar.
vmax (float): Maximum value for the colorbar.
"""
device = next(model.parameters()).device
predictions_all = []
with torch.no_grad():
for batch in test_loader:
batch = batch[0].to(device)
output_dict = model(batch)
predictions = output_dict["pred_y"]
# === Inverse transform using TSS_scalers_dict from training ===
predictions_log = TSS_scalers_dict["robust"].inverse_transform(
torch.tensor(predictions.cpu().numpy(), dtype=torch.float32)
)
predictions_all.append(
TSS_scalers_dict["log"].inverse_transform(predictions_log).numpy()
)
predictions_all = np.vstack(predictions_all).squeeze(-1)
outputs = np.full((Rrs.shape[0], Rrs.shape[1]), np.nan)
outputs[mask] = predictions_all
lat_flat = latitude.flatten()
lon_flat = longitude.flatten()
output_flat = outputs.flatten()
final_output = np.column_stack((lat_flat, lon_flat, output_flat))
if np.ma.isMaskedArray(final_output):
final_output = final_output.filled(np.nan)
os.makedirs(save_folder, exist_ok=True)
base_name = os.path.splitext(os.path.basename(structure_name))[0]
npy_path = os.path.join(save_folder, f"{base_name}.npy")
png_path = os.path.join(save_folder, f"{base_name}.png")
np.save(npy_path, final_output)
latitude_masked = final_output[:, 0]
longitude_masked = final_output[:, 1]
tss_values = final_output[:, 2]
mean_lat = (extent[2] + extent[3]) / 2
resolution_deg_lat = 1000 / 111000
resolution_deg_lon = 1000 / (111000 * np.cos(np.radians(mean_lat)))
grid_lon = np.arange(extent[0], extent[1], resolution_deg_lon)
grid_lat = np.arange(extent[3], extent[2], -resolution_deg_lat)
grid_lon, grid_lat = np.meshgrid(grid_lon, grid_lat)
tss_resampled = griddata(
(longitude_masked, latitude_masked),
tss_values,
(grid_lon, grid_lat),
method="linear",
)
tss_resampled = np.ma.masked_invalid(tss_resampled)
plt.figure(figsize=(24, 6))
plt.imshow(rgb_image / 255.0, extent=extent, origin="upper")
im = plt.imshow(
tss_resampled,
extent=extent,
cmap="jet",
alpha=1,
origin="upper",
vmin=vmin,
vmax=vmax,
)
cbar = plt.colorbar(im)
cbar.set_label("(mg m$^{-3}$)", fontsize=16)
plt.title(f"{structure_name} - Run {run}", loc="left", fontsize=20)
plt.savefig(png_path, dpi=300, bbox_inches="tight", pad_inches=0.1)
plt.close()
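Example (a minimal usage sketch, not part of the module): the loader, mask, and coordinate arrays come from one of the preprocessing helpers documented below, and `TSS_scalers_dict` must carry the same "robust" and "log" scalers fitted on the training targets. The paths, model object, scalers, and band list below are hypothetical placeholders.

```python
# Hypothetical sketch: chain a preprocessing helper into the Robust visualizer.
# Note that the preprocessing helpers return None on failure (see below).
test_loader, Rrs, mask, lat, lon = preprocess_pace_data_Robust(
    "pace_granule.nc",                  # placeholder NetCDF path
    scaler_Rrs=rrs_scaler,              # RobustScaler fitted on training Rrs
    full_band_wavelengths=train_bands,  # band list used at training time
)
infer_and_visualize_single_model_Robust(
    model, test_loader, Rrs, mask, lat, lon,
    save_folder="results",
    extent=(-91.5, -89.0, 28.5, 30.5),  # (lon_min, lon_max, lat_min, lat_max)
    rgb_image=rgb,                      # uint8 RGB backdrop (0-255)
    structure_name="moe_vae_run1",
    run=1,
    TSS_scalers_dict={"robust": tss_robust_scaler, "log": tss_log_scaler},
)
```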
infer_and_visualize_single_model_minmax(model, test_loader, Rrs, mask, latitude, longitude, save_folder, extent, rgb_image, structure_name, run, vmin=0, vmax=50, log_offset=0.01)¶
Infer and visualize results from a single MoE-VAE model.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| model | torch.nn.Module | MoE-VAE model. | required |
| test_loader | DataLoader | DataLoader for test data. | required |
| Rrs | array | Rrs data. | required |
| mask | array | Boolean mask indicating valid pixels. | required |
| latitude | array | Latitude values. | required |
| longitude | array | Longitude values. | required |
| save_folder | str | Path to save the results. | required |
| extent | tuple | Tuple containing the extent of the image. | required |
| rgb_image | array | RGB image. | required |
| structure_name | str | Name of the structure. | required |
| run | int | Run number. | required |
| vmin | float | Minimum value for the colorbar. | 0 |
| vmax | float | Maximum value for the colorbar. | 50 |
| log_offset | float | Log offset for predictions. | 0.01 |
Source code in hypercoast/moe_vae/model_inference.py
def infer_and_visualize_single_model_minmax(
model,
test_loader,
Rrs,
mask,
latitude,
longitude,
save_folder,
extent,
rgb_image,
structure_name,
run,
vmin=0,
vmax=50,
log_offset=0.01,
):
"""
Infer and visualize results from a single MoE-VAE model.
Args:
model (torch.nn.Module): MoE-VAE model.
test_loader (DataLoader): DataLoader for test data.
Rrs (array): Rrs data.
mask (array): Boolean mask indicating valid pixels.
latitude (array): Latitude values.
longitude (array): Longitude values.
save_folder (str): Path to save the results.
extent (tuple): Tuple containing the extent of the image.
rgb_image (array): RGB image.
structure_name (str): Name of the structure.
run (int): Run number.
vmin (float): Minimum value for the colorbar.
vmax (float): Maximum value for the colorbar.
log_offset (float): Log offset for predictions.
"""
device = next(model.parameters()).device
predictions_all = []
with torch.no_grad():
for batch in test_loader:
batch = batch[0].to(device)
output_dict = model(batch)
predictions = output_dict["pred_y"]
predictions_np = predictions.cpu().numpy()
predictions_original = (10**predictions_np) - log_offset
predictions_all.append(predictions_original)
predictions_all = np.vstack(predictions_all).squeeze(-1)
outputs = np.full((Rrs.shape[0], Rrs.shape[1]), np.nan)
outputs[mask] = predictions_all
lat_flat = latitude.flatten()
lon_flat = longitude.flatten()
output_flat = outputs.flatten()
final_output = np.column_stack((lat_flat, lon_flat, output_flat))
if np.ma.isMaskedArray(final_output):
final_output = final_output.filled(np.nan)
os.makedirs(save_folder, exist_ok=True)
base_name = os.path.splitext(os.path.basename(structure_name))[0]
npy_path = os.path.join(save_folder, f"{base_name}.npy")
png_path = os.path.join(save_folder, f"{base_name}.png")
np.save(npy_path, final_output)
latitude_masked = final_output[:, 0]
longitude_masked = final_output[:, 1]
tss_values = final_output[:, 2]
mean_lat = (extent[2] + extent[3]) / 2
resolution_deg_lat = 1000 / 111000
resolution_deg_lon = 1000 / (111000 * np.cos(np.radians(mean_lat)))
grid_lon = np.arange(extent[0], extent[1], resolution_deg_lon)
grid_lat = np.arange(extent[3], extent[2], -resolution_deg_lat)
grid_lon, grid_lat = np.meshgrid(grid_lon, grid_lat)
tss_resampled = griddata(
(longitude_masked, latitude_masked),
tss_values,
(grid_lon, grid_lat),
method="linear",
)
tss_resampled = np.ma.masked_invalid(tss_resampled)
plt.figure(figsize=(24, 6))
plt.imshow(rgb_image / 255.0, extent=extent, origin="upper")
im = plt.imshow(
tss_resampled,
extent=extent,
cmap="jet",
alpha=1,
origin="upper",
vmin=vmin,
vmax=vmax,
)
cbar = plt.colorbar(im)
cbar.set_label("(mg m$^{-3}$)", fontsize=16)
plt.title(f"{structure_name} - Run {run}", loc="left", fontsize=20)
plt.savefig(png_path, dpi=300, bbox_inches="tight", pad_inches=0.1)
plt.close()
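The MinMax variant assumes the model was trained on log10(TSS + log_offset) targets, so the inverse applied above is simply 10**pred − log_offset. A quick numeric check of that mapping:

```python
import numpy as np

log_offset = 0.01
pred_y = np.array([0.5, 1.0, 1.7])  # model outputs in log10 space
tss = 10**pred_y - log_offset       # approx [3.152, 9.990, 50.109] in original TSS units
```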
infer_and_visualize_token_model_Robust(model, test_loader, Rrs, mask, latitude, longitude, save_folder, extent, rgb_image, structure_name, run, TSS_scalers_dict, vmin=0, vmax=50)¶
Infer and visualize results from a token-based MoE-VAE model.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| model | torch.nn.Module | MoE-VAE model. | required |
| test_loader | DataLoader | DataLoader for test data. | required |
| Rrs | array | Rrs data. | required |
| mask | array | Boolean mask indicating valid pixels. | required |
| latitude | array | Latitude values. | required |
| longitude | array | Longitude values. | required |
| save_folder | str | Path to save the results. | required |
| extent | tuple | Tuple containing the extent of the image. | required |
| rgb_image | array | RGB image. | required |
| structure_name | str | Name of the structure. | required |
| run | int | Run number. | required |
| TSS_scalers_dict | dict | Dictionary containing the scalers for TSS. | required |
| vmin | float | Minimum value for the colorbar. | 0 |
| vmax | float | Maximum value for the colorbar. | 50 |
Source code in hypercoast/moe_vae/model_inference.py
def infer_and_visualize_token_model_Robust(
model,
test_loader,
Rrs,
mask,
latitude,
longitude,
save_folder,
extent,
rgb_image,
structure_name,
run,
TSS_scalers_dict,
vmin=0,
vmax=50,
):
"""
Infer and visualize results from a token-based MoE-VAE model.
Args:
model (torch.nn.Module): MoE-VAE model.
test_loader (DataLoader): DataLoader for test data.
Rrs (array): Rrs data.
mask (array): Boolean mask indicating valid pixels.
latitude (array): Latitude values.
longitude (array): Longitude values.
save_folder (str): Path to save the results.
extent (tuple): Tuple containing the extent of the image.
rgb_image (array): RGB image.
structure_name (str): Name of the structure.
run (int): Run number.
TSS_scalers_dict (dict): Dictionary containing the scalers for TSS.
vmin (float): Minimum value for the colorbar.
vmax (float): Maximum value for the colorbar.
"""
device = next(model.parameters()).device
predictions_all = []
with torch.no_grad():
for batch in test_loader:
batch = batch[0].to(device)
output_dict = model(batch)
predictions = output_dict["pred_y"] # shape [B, token_len]
# === Aggregate by token ===
if predictions.ndim == 2:
predictions = predictions.mean(dim=1, keepdim=True) # [B, 1]
# === Robust + log inverse transform ===
predictions_log = TSS_scalers_dict["robust"].inverse_transform(
torch.tensor(predictions.cpu().numpy(), dtype=torch.float32)
)
predictions_all.append(
TSS_scalers_dict["log"].inverse_transform(predictions_log).numpy()
)
# === Concatenate and remove extra dimensions ===
predictions_all = np.vstack(predictions_all).reshape(-1)
outputs = np.full((Rrs.shape[0], Rrs.shape[1]), np.nan)
outputs[mask] = predictions_all
lat_flat = latitude.flatten()
lon_flat = longitude.flatten()
output_flat = outputs.flatten()
final_output = np.column_stack((lat_flat, lon_flat, output_flat))
os.makedirs(save_folder, exist_ok=True)
base_name = os.path.splitext(os.path.basename(structure_name))[0]
npy_path = os.path.join(save_folder, f"{base_name}.npy")
png_path = os.path.join(save_folder, f"{base_name}.png")
np.save(npy_path, final_output)
latitude_masked = final_output[:, 0]
longitude_masked = final_output[:, 1]
tss_values = final_output[:, 2]
mean_lat = (extent[2] + extent[3]) / 2
resolution_deg_lat = 1000 / 111000
resolution_deg_lon = 1000 / (111000 * np.cos(np.radians(mean_lat)))
grid_lon = np.arange(extent[0], extent[1], resolution_deg_lon)
grid_lat = np.arange(extent[3], extent[2], -resolution_deg_lat)
grid_lon, grid_lat = np.meshgrid(grid_lon, grid_lat)
tss_resampled = griddata(
(longitude_masked, latitude_masked),
tss_values,
(grid_lon, grid_lat),
method="linear",
)
tss_resampled = np.ma.masked_invalid(tss_resampled)
plt.figure(figsize=(24, 6))
plt.imshow(rgb_image / 255.0, extent=extent, origin="upper")
im = plt.imshow(
tss_resampled,
extent=extent,
cmap="jet",
alpha=1,
origin="upper",
vmin=vmin,
vmax=vmax,
)
cbar = plt.colorbar(im)
cbar.set_label("(mg m$^{-3}$)", fontsize=16)
plt.title(f"{structure_name} - Run {run}", loc="left", fontsize=20)
plt.savefig(png_path, dpi=300, bbox_inches="tight", pad_inches=0.1)
plt.close()
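The only difference from the single-pixel Robust function is the aggregation step: token-level outputs of shape [B, token_len] are collapsed to one value per sample before the inverse transforms, e.g.:

```python
import torch

pred_y = torch.rand(4, 8)                      # [batch=4, token_len=8]
per_sample = pred_y.mean(dim=1, keepdim=True)  # [4, 1], mean over the token dimension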
infer_and_visualize_token_model_minmax(model, test_loader, Rrs, mask, latitude, longitude, save_folder, extent, rgb_image, structure_name, run, vmin=0, vmax=50, log_offset=0.01)¶
Infer and visualize results from a token-based MoE-VAE model.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| model | torch.nn.Module | MoE-VAE model. | required |
| test_loader | DataLoader | DataLoader for test data. | required |
| Rrs | array | Rrs data. | required |
| mask | array | Boolean mask indicating valid pixels. | required |
| latitude | array | Latitude values. | required |
| longitude | array | Longitude values. | required |
| save_folder | str | Path to save the results. | required |
| extent | tuple | Tuple containing the extent of the image. | required |
| rgb_image | array | RGB image. | required |
| structure_name | str | Name of the structure. | required |
| run | int | Run number. | required |
| vmin | float | Minimum value for the colorbar. | 0 |
| vmax | float | Maximum value for the colorbar. | 50 |
| log_offset | float | Log offset for predictions. | 0.01 |
Source code in hypercoast/moe_vae/model_inference.py
def infer_and_visualize_token_model_minmax(
model,
test_loader,
Rrs,
mask,
latitude,
longitude,
save_folder,
extent,
rgb_image,
structure_name,
run,
vmin=0,
vmax=50,
log_offset=0.01,
):
"""
Infer and visualize results from a token-based MoE-VAE model.
Args:
model (torch.nn.Module): MoE-VAE model.
test_loader (DataLoader): DataLoader for test data.
Rrs (array): Rrs data.
mask (array): Boolean mask indicating valid pixels.
latitude (array): Latitude values.
longitude (array): Longitude values.
save_folder (str): Path to save the results.
extent (tuple): Tuple containing the extent of the image.
rgb_image (array): RGB image.
structure_name (str): Name of the structure.
run (int): Run number.
vmin (float): Minimum value for the colorbar.
vmax (float): Maximum value for the colorbar.
log_offset (float): Log offset for predictions.
"""
device = next(model.parameters()).device
predictions_all = []
with torch.no_grad():
for batch in test_loader:
batch = batch[0].to(device)
output_dict = model(batch)
predictions = output_dict["pred_y"] # shape [B, token_len]
# === Aggregate along the token dimension ===
if predictions.ndim == 2:
predictions = predictions.mean(dim=1, keepdim=True) # shape [B, 1]
predictions_np = predictions.cpu().numpy()
predictions_original = (10**predictions_np) - log_offset
predictions_all.append(predictions_original)
# === Concatenate and flatten ===
predictions_all = np.vstack(predictions_all).reshape(-1)
outputs = np.full((Rrs.shape[0], Rrs.shape[1]), np.nan)
outputs[mask] = predictions_all
lat_flat = latitude.flatten()
lon_flat = longitude.flatten()
output_flat = outputs.flatten()
final_output = np.column_stack((lat_flat, lon_flat, output_flat))
os.makedirs(save_folder, exist_ok=True)
base_name = os.path.splitext(os.path.basename(structure_name))[0]
npy_path = os.path.join(save_folder, f"{base_name}.npy")
png_path = os.path.join(save_folder, f"{base_name}.png")
np.save(npy_path, final_output)
latitude_masked = final_output[:, 0]
longitude_masked = final_output[:, 1]
tss_values = final_output[:, 2]
mean_lat = (extent[2] + extent[3]) / 2
resolution_deg_lat = 1000 / 111000
resolution_deg_lon = 1000 / (111000 * np.cos(np.radians(mean_lat)))
grid_lon = np.arange(extent[0], extent[1], resolution_deg_lon)
grid_lat = np.arange(extent[3], extent[2], -resolution_deg_lat)
grid_lon, grid_lat = np.meshgrid(grid_lon, grid_lat)
tss_resampled = griddata(
(longitude_masked, latitude_masked),
tss_values,
(grid_lon, grid_lat),
method="linear",
)
tss_resampled = np.ma.masked_invalid(tss_resampled)
plt.figure(figsize=(24, 6))
plt.imshow(rgb_image / 255.0, extent=extent, origin="upper")
im = plt.imshow(
tss_resampled,
extent=extent,
cmap="jet",
alpha=1,
origin="upper",
vmin=vmin,
vmax=vmax,
)
cbar = plt.colorbar(im)
cbar.set_label("(mg m$^{-3}$)", fontsize=16)
plt.title(f"{structure_name} - Run {run}", loc="left", fontsize=20)
plt.savefig(png_path, dpi=300, bbox_inches="tight", pad_inches=0.1)
plt.close()
preprocess_emit_data_Robust(nc_path, scaler_Rrs, use_diff=True, full_band_wavelengths=None)¶
Preprocess EMIT data for Robust scaling.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| nc_path | str | Path to the NetCDF file containing EMIT data. | required |
| scaler_Rrs | object | RobustScaler object for Rrs normalization. | required |
| use_diff | bool | Whether to apply first-order differencing. | True |
| full_band_wavelengths | list | List of target wavelength bands. | None |

Returns:

| Name (Type) | Description |
|---|---|
| test_loader (DataLoader) | DataLoader for test data. |
| filtered_Rrs (array) | Filtered Rrs data. |
| mask (array) | Boolean mask indicating valid pixels. |
| latitude (array) | Latitude values. |
| longitude (array) | Longitude values. |
Source code in hypercoast/moe_vae/model_inference.py
def preprocess_emit_data_Robust(
nc_path, scaler_Rrs, use_diff=True, full_band_wavelengths=None
):
"""
Preprocess EMIT data for Robust scaling.
Args:
nc_path (str): Path to the NetCDF file containing EMIT data.
scaler_Rrs (object): RobustScaler object for Rrs normalization.
use_diff (bool): Whether to apply first-order differencing.
full_band_wavelengths (list): List of target wavelength bands.
Returns:
test_loader (DataLoader): DataLoader for test data.
filtered_Rrs (array): Filtered Rrs data.
mask (array): Boolean mask indicating valid pixels.
latitude (array): Latitude values.
longitude (array): Longitude values.
"""
if full_band_wavelengths is None:
raise ValueError(
"full_band_wavelengths must be provided to match EMIT Rrs bands"
)
def find_closest_band(target, available_bands):
available_waves = [int(b.split("_")[1]) for b in available_bands]
closest_wave = min(available_waves, key=lambda w: abs(w - target))
return f"Rrs_{closest_wave}"
dataset = nc(nc_path)
latitude = dataset.variables["lat"][:]
longitude = dataset.variables["lon"][:]
all_vars = dataset.variables.keys()
bands_to_extract = []
for w in full_band_wavelengths:
band_name = f"Rrs_{int(w)}"
if band_name in all_vars:
bands_to_extract.append(band_name)
else:
            closest = find_closest_band(
                int(w), [v for v in all_vars if v.startswith("Rrs_")]
            )
print(f"⚠️ {band_name} does not exist, using the closest band {closest}")
bands_to_extract.append(closest)
filtered_Rrs = np.array([dataset.variables[band][:] for band in bands_to_extract])
filtered_Rrs = np.moveaxis(filtered_Rrs, 0, -1)
mask = np.all(~np.isnan(filtered_Rrs), axis=2)
    target_443 = (
        "Rrs_443"
        if "Rrs_443" in bands_to_extract
        else find_closest_band(443, bands_to_extract)
    )
    target_560 = (
        "Rrs_560"
        if "Rrs_560" in bands_to_extract
        else find_closest_band(560, bands_to_extract)
    )
print(f"Using {target_443} and {target_560} for mask check.")
idx_443 = bands_to_extract.index(target_443)
idx_560 = bands_to_extract.index(target_560)
mask &= filtered_Rrs[:, :, idx_443] <= filtered_Rrs[:, :, idx_560]
valid_test_data = filtered_Rrs[mask]
# ---- smooth + diff
if use_diff:
from scipy.ndimage import gaussian_filter1d
Rrs_smoothed = np.array(
[gaussian_filter1d(spectrum, sigma=1) for spectrum in valid_test_data]
)
Rrs_processed = np.diff(Rrs_smoothed, axis=1)
print("✅ [5] Performed Gaussian smoothing + first-order differencing")
else:
Rrs_processed = valid_test_data
print("✅ [5] Smoothing and differencing not enabled")
# ---- normalize
Rrs_normalized = scaler_Rrs.transform(
torch.tensor(Rrs_processed, dtype=torch.float32)
).numpy()
# ---- DataLoader
test_tensor = TensorDataset(torch.tensor(Rrs_normalized).float())
test_loader = DataLoader(test_tensor, batch_size=2048, shuffle=False)
print("✅ [6] DataLoader construction completed")
return test_loader, filtered_Rrs, mask, latitude, longitude
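When a requested band such as Rrs_443 is missing from the file, the function substitutes the nearest available band by wavelength distance. A standalone illustration of that rule (the band names here are invented):

```python
available = ["Rrs_441", "Rrs_446", "Rrs_558", "Rrs_563"]  # invented band names
waves = [int(b.split("_")[1]) for b in available]
closest = f"Rrs_{min(waves, key=lambda w: abs(w - 443))}"
print(closest)  # Rrs_441, since |441 - 443| = 2 beats |446 - 443| = 3
```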
preprocess_emit_data_minmax(nc_path, full_band_wavelengths=None, diff_before_norm=False, diff_after_norm=False)¶
Read EMIT NetCDF, extract Rrs_* bands according to full_band_wavelengths, apply optional smoothing/differencing and per-pixel MinMax normalization to [1, 10], and return a DataLoader.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| nc_path | str | Path to the NetCDF file containing EMIT data. | required |
| full_band_wavelengths | list | List of target wavelength bands. | None |
| diff_before_norm | bool | Whether to apply first-order differencing before normalization. | False |
| diff_after_norm | bool | Whether to apply first-order differencing after normalization. | False |

Returns:

| Name (Type) | Description |
|---|---|
| test_loader (DataLoader) | DataLoader for test data. |
| filtered_Rrs (array) | Filtered Rrs data, shape (H, W, B). |
| mask (array) | Boolean mask indicating valid pixels, shape (H, W). |
| latitude (array) | Latitude values, shape (H,). |
| longitude (array) | Longitude values, shape (W,). |
Source code in hypercoast/moe_vae/model_inference.py
def preprocess_emit_data_minmax(
nc_path, full_band_wavelengths=None, diff_before_norm=False, diff_after_norm=False
):
"""
    Read EMIT NetCDF, extract Rrs_* bands according to full_band_wavelengths,
    apply optional smoothing/differencing and per-pixel MinMax normalization
    to [1, 10], and return a DataLoader.
Args:
nc_path (str): Path to the NetCDF file containing EMIT data.
full_band_wavelengths (list): List of target wavelength bands.
diff_before_norm (bool): Whether to apply first-order differencing before normalization.
diff_after_norm (bool): Whether to apply first-order differencing after normalization.
Returns:
test_loader, filtered_Rrs(H, W, B), mask(H, W), latitude(H), longitude(W)
"""
print(f"📥 Start processing: {nc_path}")
# ---- sanity checks
if full_band_wavelengths is None or len(full_band_wavelengths) == 0:
raise ValueError(
"A non-empty full_band_wavelengths must be provided (e.g., [400, 402, ...])."
)
full_band_wavelengths = [int(w) for w in full_band_wavelengths]
try:
with nc(nc_path, "r") as dataset:
latitude = dataset.variables["lat"][:]
longitude = dataset.variables["lon"][:]
all_vars = set(dataset.variables.keys())
available_wavelengths = [
float(v.split("_")[1]) for v in all_vars if v.startswith("Rrs_")
]
def find_closest_band(target_nm: float):
nearest = min(available_wavelengths, key=lambda w: abs(w - target_nm))
return f"Rrs_{int(nearest)}"
# Search according to full_band_wavelengths
bands_to_extract = []
for w in full_band_wavelengths:
band_name = f"Rrs_{w}"
if band_name in all_vars:
bands_to_extract.append(band_name)
else:
closest = find_closest_band(w)
print(
f"⚠️ {band_name} does not exist, using the closest band {closest}"
)
bands_to_extract.append(closest)
seen = set()
bands_to_extract = [
b for b in bands_to_extract if not (b in seen or seen.add(b))
]
if len(bands_to_extract) == 0:
raise ValueError("❌ No usable Rrs_* bands found in the file.")
# ---- read and stack to (H, W, B)
# Each variable expected shape: (lat, lon) or (y, x)
Rrs_stack = []
for band in bands_to_extract:
arr = dataset.variables[band][:] # (H, W)
Rrs_stack.append(arr)
Rrs = np.array(Rrs_stack) # (B, H, W)
Rrs = np.moveaxis(Rrs, 0, -1) # (H, W, B)
            filtered_Rrs = Rrs  # alias to keep naming consistent with the Robust variant's return
# ---- build mask using 440 & 560 (or nearest present within your requested list)
have_waves = [int(b.split("_")[1]) for b in bands_to_extract]
def nearest_idx(target_nm: int):
# find nearest *among bands_to_extract*
nearest_w = min(have_waves, key=lambda w: abs(w - target_nm))
return bands_to_extract.index(f"Rrs_{nearest_w}")
# Prefer exact if available; otherwise nearest in the user-requested set
idx_440 = (
bands_to_extract.index("Rrs_440")
if "Rrs_440" in bands_to_extract
else nearest_idx(440)
)
idx_560 = (
bands_to_extract.index("Rrs_560")
if "Rrs_560" in bands_to_extract
else nearest_idx(560)
)
print(
f"✅ Bands used for mask check: {bands_to_extract[idx_440]} and {bands_to_extract[idx_560]}"
)
mask_nanfree = np.all(~np.isnan(filtered_Rrs), axis=2)
mask_condition = filtered_Rrs[:, :, idx_560] >= filtered_Rrs[:, :, idx_440]
mask = mask_nanfree & mask_condition
print(f"✅ [4] Built valid mask, remaining pixels: {int(np.sum(mask))}")
if not np.any(mask):
raise ValueError("❌ No valid pixels passed the filtering.")
valid_test_data = filtered_Rrs[mask] # (N, B)
# === Check whether smoothing is needed (only executed if any differencing is enabled) ===
if diff_before_norm or diff_after_norm:
from scipy.ndimage import gaussian_filter1d
Rrs_smoothed = np.array(
[gaussian_filter1d(spectrum, sigma=1) for spectrum in valid_test_data]
)
print("✅ Gaussian smoothing applied")
else:
Rrs_smoothed = valid_test_data
print("✅ Smoothing not enabled")
# === Preprocessing before differencing ===
if diff_before_norm:
Rrs_preprocessed = np.diff(Rrs_smoothed, axis=1)
print("✅ Preprocessing before differencing completed")
else:
Rrs_preprocessed = Rrs_smoothed
print("✅ Preprocessing before differencing not enabled")
# === MinMax normalization to [1, 10] ===
scalers = [MinMaxScaler((1, 10)) for _ in range(Rrs_preprocessed.shape[0])]
Rrs_normalized = np.array(
[
scalers[i].fit_transform(row.reshape(-1, 1)).flatten()
for i, row in enumerate(Rrs_preprocessed)
]
)
# === Post-processing after differencing ===
if diff_after_norm:
Rrs_normalized = np.diff(Rrs_normalized, axis=1)
print("✅ Post-processing after differencing completed")
else:
print("✅ Post-processing after differencing not enabled")
# === Construct DataLoader
test_tensor = TensorDataset(torch.tensor(Rrs_normalized).float())
test_loader = DataLoader(test_tensor, batch_size=2048, shuffle=False)
return test_loader, Rrs, mask, latitude, longitude
except Exception as e:
print(f"❌ [ERROR] Failed to process file {nc_path}: {e}")
return None
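Unlike the Robust variant, this function catches exceptions and returns None instead of raising, so callers should guard the unpack (the file path and band list below are placeholders):

```python
result = preprocess_emit_data_minmax(
    "emit_granule.nc", full_band_wavelengths=train_bands  # placeholders
)
if result is None:
    raise RuntimeError("EMIT preprocessing failed; see the printed error above")
test_loader, Rrs, mask, latitude, longitude = result
```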
preprocess_pace_data_Robust(nc_path, scaler_Rrs, use_diff=True, full_band_wavelengths=None)¶
Preprocess PACE data for Robust scaling.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| nc_path | str | Path to the NetCDF file containing PACE data. | required |
| scaler_Rrs | object | RobustScaler object for Rrs normalization. | required |
| use_diff | bool | Whether to apply first-order differencing. | True |
| full_band_wavelengths | list | List of target wavelength bands. | None |

Returns:

| Name (Type) | Description |
|---|---|
| test_loader (DataLoader) | DataLoader for test data. |
| filtered_Rrs (array) | Filtered Rrs data. |
| mask (array) | Boolean mask indicating valid pixels. |
| latitude (array) | Latitude values. |
| longitude (array) | Longitude values. |
Source code in hypercoast/moe_vae/model_inference.py
def preprocess_pace_data_Robust(
nc_path, scaler_Rrs, use_diff=True, full_band_wavelengths=None
):
"""
Preprocess PACE data for Robust scaling.
Args:
nc_path (str): Path to the NetCDF file containing PACE data.
scaler_Rrs (object): RobustScaler object for Rrs normalization.
use_diff (bool): Whether to apply first-order differencing.
full_band_wavelengths (list): List of target wavelength bands.
Returns:
test_loader (DataLoader): DataLoader for test data.
filtered_Rrs (array): Filtered Rrs data.
mask (array): Boolean mask indicating valid pixels.
latitude (array): Latitude values.
longitude (array): Longitude values.
"""
print(f"📥 Start processing: {nc_path}")
try:
PACE_dataset = read_pace(nc_path)
print("✅ [1] Successfully read PACE data")
da = PACE_dataset["Rrs"]
Rrs = da.values # [lat, lon, bands]
latitude = da.latitude.values
longitude = da.longitude.values
print("✅ [2] Successfully retrieved Rrs, lat, and lon")
# ✅ Extract wavelength
if full_band_wavelengths is None:
raise ValueError(
"full_band_wavelengths must be provided to match PACE Rrs bands"
)
if hasattr(da, "wavelength") or "wavelength" in da.coords:
pace_band_wavelengths = da.wavelength.values
else:
raise ValueError(
"❌ Unable to extract wavelength from PACE data. Please check the NetCDF file structure."
)
missing = [b for b in full_band_wavelengths if b not in pace_band_wavelengths]
if missing:
raise ValueError(
f"❌ The following wavelengths are not present in the PACE data: {missing}"
)
indices = [
np.where(pace_band_wavelengths == b)[0][0] for b in full_band_wavelengths
]
band_wavelengths = pace_band_wavelengths[indices]
assert (
band_wavelengths == np.array(full_band_wavelengths)
).all(), "❌ Band order mismatch"
filtered_Rrs = Rrs[:, :, indices]
print(
f"✅ [3] Bands re-extracted based on full_band_wavelengths, total {len(indices)} bands"
)
# ✅ Build mask
idx_440 = np.where(band_wavelengths == 440)[0][0]
idx_560 = np.where(band_wavelengths == 560)[0][0]
Rrs_440 = filtered_Rrs[:, :, idx_440]
Rrs_560 = filtered_Rrs[:, :, idx_560]
mask_nanfree = np.all(~np.isnan(filtered_Rrs), axis=2)
mask_condition = Rrs_560 >= Rrs_440
mask = mask_nanfree & mask_condition
print(f"✅ [4] Built valid mask, remaining pixels: {np.sum(mask)}")
if not np.any(mask):
raise ValueError("❌ No valid pixels passed the filtering.")
valid_test_data = filtered_Rrs[mask]
# ✅ Smoothing before differencing (enabled only if use_diff=True)
if use_diff:
from scipy.ndimage import gaussian_filter1d
Rrs_smoothed = np.array(
[gaussian_filter1d(spectrum, sigma=1) for spectrum in valid_test_data]
)
Rrs_processed = np.diff(Rrs_smoothed, axis=1)
print("✅ [5] Performed Gaussian smoothing + first-order differencing")
else:
Rrs_processed = valid_test_data
print("✅ [5] Smoothing and differencing not enabled")
# ✅ Normalization (RobustScaler provided)
Rrs_normalized = scaler_Rrs.transform(
torch.tensor(Rrs_processed, dtype=torch.float32)
).numpy()
# ✅ Construct DataLoader
from torch.utils.data import DataLoader, TensorDataset
test_tensor = TensorDataset(torch.tensor(Rrs_normalized).float())
test_loader = DataLoader(test_tensor, batch_size=2048, shuffle=False)
print("✅ [6] DataLoader construction completed")
return test_loader, filtered_Rrs, mask, latitude, longitude
except Exception as e:
print(f"❌ [ERROR] Failed to process file {nc_path}: {e}")
return None
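PACE preprocessing requires exact wavelength matches and fails if any requested band is absent from the granule. One defensive pattern, assuming read_pace exposes the wavelength coordinate as the function above does, is to intersect the training band list with the granule's wavelengths first (the path and train_bands are placeholders):

```python
da = read_pace("pace_granule.nc")["Rrs"]  # placeholder path
available = set(da.wavelength.values)
full_band_wavelengths = [w for w in train_bands if w in available]
```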
preprocess_pace_data_minmax(nc_path, full_band_wavelengths=None, diff_before_norm=False, diff_after_norm=False)¶
Preprocess PACE data for MinMax scaling.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| nc_path | str | Path to the NetCDF file containing PACE data. | required |
| full_band_wavelengths | list | List of target wavelength bands. | None |
| diff_before_norm | bool | Whether to apply first-order differencing before normalization. | False |
| diff_after_norm | bool | Whether to apply first-order differencing after normalization. | False |

Returns:

| Name (Type) | Description |
|---|---|
| test_loader (DataLoader) | DataLoader for test data. |
| Rrs (array) | Rrs data. |
| mask (array) | Boolean mask indicating valid pixels. |
| latitude (array) | Latitude values. |
| longitude (array) | Longitude values. |
Source code in hypercoast/moe_vae/model_inference.py
def preprocess_pace_data_minmax(
nc_path, full_band_wavelengths=None, diff_before_norm=False, diff_after_norm=False
):
"""
Preprocess PACE data for MinMax scaling.
Args:
nc_path (str): Path to the NetCDF file containing PACE data.
full_band_wavelengths (list): List of target wavelength bands.
diff_before_norm (bool): Whether to apply first-order differencing before normalization.
diff_after_norm (bool): Whether to apply first-order differencing after normalization.
Returns:
test_loader (DataLoader): DataLoader for test data.
Rrs (array): Rrs data.
mask (array): Boolean mask indicating valid pixels.
latitude (array): Latitude values.
longitude (array): Longitude values.
"""
try:
# === Load data ===
PACE_dataset = read_pace(nc_path)
da = PACE_dataset["Rrs"]
Rrs = da.values # [lat, lon, bands]
latitude = da.latitude.values
longitude = da.longitude.values
# === Band check ===
if full_band_wavelengths is None:
raise ValueError(
"full_band_wavelengths must be provided to match PACE Rrs bands"
)
if hasattr(da, "wavelength") or "wavelength" in da.coords:
pace_band_wavelengths = da.wavelength.values
else:
raise ValueError(
"❌ Unable to extract wavelength from PACE data. Please check the NetCDF file structure."
)
# Check for missing bands
missing = [b for b in full_band_wavelengths if b not in pace_band_wavelengths]
if missing:
raise ValueError(
f"❌ The following wavelengths are not found in the PACE data: {missing}"
)
# Extract indices in the order of full_band_wavelengths
indices = [
np.where(pace_band_wavelengths == b)[0][0] for b in full_band_wavelengths
]
band_wavelengths = pace_band_wavelengths[indices]
assert (
band_wavelengths == np.array(full_band_wavelengths)
).all(), "❌ Band order is inconsistent"
# Extract Rrs for selected_bands
filtered_Rrs = Rrs[:, :, indices]
# === Mask construction ===
idx_440 = np.where(band_wavelengths == 440)[0][0]
idx_560 = np.where(band_wavelengths == 560)[0][0]
Rrs_440 = filtered_Rrs[:, :, idx_440]
Rrs_560 = filtered_Rrs[:, :, idx_560]
mask_nanfree = np.all(~np.isnan(filtered_Rrs), axis=2)
mask_condition = Rrs_560 >= Rrs_440
mask = mask_nanfree & mask_condition
if not np.any(mask):
raise ValueError("❌ No valid pixels passed the filtering.")
valid_data = filtered_Rrs[mask] # [num_pixel, num_band]
# === Check if smoothing is needed (only executed when any differencing is enabled) ===
if diff_before_norm or diff_after_norm:
from scipy.ndimage import gaussian_filter1d
Rrs_smoothed = np.array(
[gaussian_filter1d(spectrum, sigma=1) for spectrum in valid_data]
)
print("✅ Gaussian smoothing applied")
else:
Rrs_smoothed = valid_data
print("✅ Smoothing not enabled")
# === Preprocessing before differencing ===
if diff_before_norm:
Rrs_preprocessed = np.diff(Rrs_smoothed, axis=1)
print("✅ Preprocessing before differencing completed")
else:
Rrs_preprocessed = Rrs_smoothed
print("✅ Preprocessing before differencing not enabled")
# === MinMax normalization to [1, 10] ===
scalers = [MinMaxScaler((1, 10)) for _ in range(Rrs_preprocessed.shape[0])]
Rrs_normalized = np.array(
[
scalers[i].fit_transform(row.reshape(-1, 1)).flatten()
for i, row in enumerate(Rrs_preprocessed)
]
)
# === Post-processing after differencing ===
if diff_after_norm:
Rrs_normalized = np.diff(Rrs_normalized, axis=1)
print("✅ Post-processing after differencing completed")
else:
print("✅ Post-processing after differencing not enabled")
# === Construct DataLoader ===
test_tensor = TensorDataset(torch.tensor(Rrs_normalized).float())
test_loader = DataLoader(test_tensor, batch_size=2048, shuffle=False)
return test_loader, Rrs, mask, latitude, longitude
except Exception as e:
print(f"❌ [ERROR] Failed to process file {nc_path}: {e}")
return None
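Putting the MinMax pipeline together end to end (all names below are placeholders; the None check mirrors the function's silent-failure behavior):

```python
result = preprocess_pace_data_minmax(
    "pace_granule.nc", full_band_wavelengths=train_bands  # placeholders
)
if result is not None:
    test_loader, Rrs, mask, lat, lon = result
    infer_and_visualize_single_model_minmax(
        model, test_loader, Rrs, mask, lat, lon,
        save_folder="results", extent=extent, rgb_image=rgb,
        structure_name="moe_vae_minmax_run1", run=1,
        vmin=0, vmax=50, log_offset=0.01,
    )
```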