data_loading module¶
Data loading and preprocessing utilities for MoE-VAE models.
This module provides functions for loading, preprocessing, and preparing remote sensing data for training and inference with VAE and MoE-VAE models. Data are read from Excel workbooks containing "Rrs" and "parameter" sheets; supported preprocessing includes per-sample MinMax and robust scaling, optional spectral differencing, quantile-based outlier removal, and log transformation of target parameters.
load_real_data(excel_path, selected_bands, split_ratio=0.7, seed=42, diff_before_norm=False, diff_after_norm=False, target_parameter='TSS', lower_quantile=0.0, upper_quantile=1.0, log_offset=0.01)¶
Load and preprocess real data using MinMax scaling.
This function loads remote sensing data from Excel files and applies MinMax normalization with optional differencing operations. Each sample is normalized independently to the range [1, 10].
Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `excel_path` | `str` | Path to Excel file containing the data. | *required* |
| `selected_bands` | `list` | List of wavelength bands to extract. | *required* |
| `split_ratio` | `float` | Train/test split ratio. | `0.7` |
| `seed` | `int` | Random seed for reproducible splits. | `42` |
| `diff_before_norm` | `bool` | Apply differencing before normalization. | `False` |
| `diff_after_norm` | `bool` | Apply differencing after normalization. | `False` |
| `target_parameter` | `str` | Target parameter column name. | `'TSS'` |
| `lower_quantile` | `float` | Lower quantile for outlier removal. | `0.0` |
| `upper_quantile` | `float` | Upper quantile for outlier removal. | `1.0` |
| `log_offset` | `float` | Offset for log transformation. | `0.01` |
Returns:

| Type | Description |
| --- | --- |
| `tuple` | A tuple containing:<br>- `train_dl` (DataLoader): Training data loader<br>- `test_dl` (DataLoader): Test data loader<br>- `input_dim` (int): Input feature dimension<br>- `output_dim` (int): Output dimension<br>- `train_ids` (list): Training sample IDs<br>- `test_ids` (list): Test sample IDs |
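A minimal usage sketch (the file path and band list below are placeholders; the workbook must contain "Rrs" and "parameter" sheets with `Rrs_<band>` and `GLORIA_ID` columns, as read by the source code that follows):

```python
from hypercoast.moe_vae.data_loading import load_real_data

train_dl, test_dl, input_dim, output_dim, train_ids, test_ids = load_real_data(
    excel_path="gloria_dataset.xlsx",    # placeholder path
    selected_bands=[443, 490, 560, 665],  # matched against Rrs_<band> columns
    split_ratio=0.7,
    seed=42,
    target_parameter="TSS",
)

# Each batch yields (Rrs, target) tensors; targets are log10(value + log_offset).
x, y = next(iter(train_dl))
print(x.shape, y.shape)  # (batch, input_dim) and (batch, 1)
```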
Source code in hypercoast/moe_vae/data_loading.py
```python
def load_real_data(
    excel_path,
    selected_bands,
    split_ratio=0.7,
    seed=42,
    diff_before_norm=False,
    diff_after_norm=False,
    target_parameter="TSS",
    lower_quantile=0.0,
    upper_quantile=1.0,
    log_offset=0.01,
):
    """Load and preprocess real data using MinMax scaling.

    This function loads remote sensing data from Excel files and applies
    MinMax normalization with optional differencing operations. Each sample
    is normalized independently to the range [1, 10].

    Args:
        excel_path (str): Path to Excel file containing the data.
        selected_bands (list): List of wavelength bands to extract.
        split_ratio (float, optional): Train/test split ratio. Defaults to 0.7.
        seed (int, optional): Random seed for reproducible splits. Defaults to 42.
        diff_before_norm (bool, optional): Apply differencing before normalization.
            Defaults to False.
        diff_after_norm (bool, optional): Apply differencing after normalization.
            Defaults to False.
        target_parameter (str, optional): Target parameter column name.
            Defaults to "TSS".
        lower_quantile (float, optional): Lower quantile for outlier removal.
            Defaults to 0.0.
        upper_quantile (float, optional): Upper quantile for outlier removal.
            Defaults to 1.0.
        log_offset (float, optional): Offset for log transformation.
            Defaults to 0.01.

    Returns:
        tuple: A tuple containing:
            - train_dl (DataLoader): Training data loader
            - test_dl (DataLoader): Test data loader
            - input_dim (int): Input feature dimension
            - output_dim (int): Output dimension
            - train_ids (list): Training sample IDs
            - test_ids (list): Test sample IDs
    """
    rounded_bands = [int(round(b)) for b in selected_bands]
    band_cols = [f"Rrs_{b}" for b in rounded_bands]
    df_rrs = pd.read_excel(excel_path, sheet_name="Rrs")
    df_param = pd.read_excel(excel_path, sheet_name="parameter")
    df_rrs_selected = df_rrs[["GLORIA_ID"] + band_cols]
    df_param_selected = df_param[["GLORIA_ID", target_parameter]]
    df_merged = pd.merge(
        df_rrs_selected, df_param_selected, on="GLORIA_ID", how="inner"
    )

    # === Filter valid samples ===
    mask_rrs_valid = df_merged[band_cols].notna().all(axis=1)
    mask_target_valid = df_merged[target_parameter].notna()
    df_filtered = df_merged[mask_rrs_valid & mask_target_valid].reset_index(drop=True)
    print(
        f"✅ Number of samples after filtering Rrs and {target_parameter}: {len(df_filtered)}"
    )

    # === Quantile clipping for target parameter ===
    lower = df_filtered[target_parameter].quantile(lower_quantile)
    upper = df_filtered[target_parameter].quantile(upper_quantile)
    df_filtered = df_filtered[
        (df_filtered[target_parameter] >= lower)
        & (df_filtered[target_parameter] <= upper)
    ].reset_index(drop=True)
    print(
        f"✅ Number of samples after removing {target_parameter} quantiles [{lower_quantile}, {upper_quantile}]: {len(df_filtered)}"
    )

    # === Extract sample IDs, Rrs, and target parameter ===
    all_sample_ids = df_filtered["GLORIA_ID"].astype(str).tolist()
    Rrs_array = df_filtered[band_cols].values
    param_array = df_filtered[[target_parameter]].values
    if diff_before_norm:
        Rrs_array = np.diff(Rrs_array, axis=1)

    # === Apply MinMax scaling to [1, 10] for each sample independently ===
    scalers_Rrs_real = [MinMaxScaler((1, 10)) for _ in range(Rrs_array.shape[0])]
    Rrs_normalized = np.array(
        [
            scalers_Rrs_real[i].fit_transform(row.reshape(-1, 1)).flatten()
            for i, row in enumerate(Rrs_array)
        ]
    )
    if diff_after_norm:
        Rrs_normalized = np.diff(Rrs_normalized, axis=1)

    # === Transform target parameter to log10(param + log_offset) ===
    param_transformed = np.log10(param_array + log_offset)

    # === Build Dataset ===
    Rrs_tensor = torch.tensor(Rrs_normalized, dtype=torch.float32)
    param_tensor = torch.tensor(param_transformed, dtype=torch.float32)
    dataset = TensorDataset(Rrs_tensor, param_tensor)

    # === Split into training and testing sets ===
    num_samples = len(dataset)
    indices = np.arange(num_samples)
    np.random.seed(seed)
    np.random.shuffle(indices)
    train_size = int(split_ratio * num_samples)
    train_indices = indices[:train_size]
    test_indices = indices[train_size:]
    train_dataset = Subset(dataset, train_indices)
    test_dataset = Subset(dataset, test_indices)
    train_ids = [all_sample_ids[i] for i in train_indices]
    test_ids = [all_sample_ids[i] for i in test_indices]
    train_dl = DataLoader(train_dataset, batch_size=1024, shuffle=True, num_workers=0)
    test_dl = DataLoader(test_dataset, batch_size=1024, shuffle=False, num_workers=0)
    input_dim = Rrs_tensor.shape[1]
    output_dim = param_tensor.shape[1]
    return (train_dl, test_dl, input_dim, output_dim, train_ids, test_ids)
```
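For intuition, the per-sample scaling loop above is equivalent to the row-wise transform below (a sketch for illustration, not part of the library API):

```python
import numpy as np

def row_minmax_1_10(Rrs_array: np.ndarray) -> np.ndarray:
    """Scale each row independently to [1, 10], mirroring the per-sample
    MinMaxScaler((1, 10)) loop in load_real_data (sketch only)."""
    row_min = Rrs_array.min(axis=1, keepdims=True)
    row_max = Rrs_array.max(axis=1, keepdims=True)
    # Guard constant rows against division by zero (sklearn does the same
    # internally, mapping a constant row to the lower bound 1).
    rng = np.where(row_max > row_min, row_max - row_min, 1.0)
    return 1.0 + 9.0 * (Rrs_array - row_min) / rng
```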
load_real_data_Robust(excel_path, selected_bands, target_parameter='TSS', split_ratio=0.7, seed=42, use_diff=False, lower_quantile=0.0, upper_quantile=1.0, Rrs_range=(0, 0.25), target_range=(-0.5, 0.5))¶
Load and preprocess real data using robust scaling methods.
This function loads remote sensing reflectance (Rrs) and parameter data from Excel files, applies robust preprocessing including quantile filtering, normalization, and data splitting for training/testing.
Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `excel_path` | `str` | Path to Excel file containing Rrs and parameter data. | *required* |
| `selected_bands` | `list` | List of wavelength bands to extract from Rrs data. | *required* |
| `target_parameter` | `str` | Name of target parameter column. | `'TSS'` |
| `split_ratio` | `float` | Train/test split ratio. | `0.7` |
| `seed` | `int` | Random seed for reproducible splits. | `42` |
| `use_diff` | `bool` | Whether to apply first difference to Rrs. | `False` |
| `lower_quantile` | `float` | Lower quantile for outlier removal. | `0.0` |
| `upper_quantile` | `float` | Upper quantile for outlier removal. | `1.0` |
| `Rrs_range` | `tuple` | Target range for Rrs normalization. | `(0, 0.25)` |
| `target_range` | `tuple` | Target range for parameter normalization. | `(-0.5, 0.5)` |
Returns:

| Type | Description |
| --- | --- |
| `tuple` | A tuple containing:<br>- `train_dl` (DataLoader): Training data loader<br>- `test_dl` (DataLoader): Test data loader<br>- `input_dim` (int): Input feature dimension<br>- `output_dim` (int): Output dimension<br>- `train_ids` (list): Training sample IDs<br>- `test_ids` (list): Test sample IDs<br>- `scaler_Rrs`: Fitted Rrs scaler object<br>- `TSS_scalers_dict` (dict): Dictionary of fitted target scalers |
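A usage sketch (the path and bands are placeholders). Unlike `load_real_data`, this variant returns the fitted scalers, which must be kept so test data can be transformed identically at inference time:

```python
from hypercoast.moe_vae.data_loading import load_real_data_Robust

(
    train_dl, test_dl, input_dim, output_dim,
    train_ids, test_ids, scaler_Rrs, TSS_scalers_dict,
) = load_real_data_Robust(
    excel_path="gloria_dataset.xlsx",    # placeholder path
    selected_bands=[443, 490, 560, 665],
    target_parameter="TSS",
    Rrs_range=(0, 0.25),
    target_range=(-0.5, 0.5),
)

# TSS_scalers_dict holds the fitted "log" and "robust" scalers; pass it,
# together with scaler_Rrs, to load_real_test_Robust at inference time.
```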
Source code in hypercoast/moe_vae/data_loading.py
```python
def load_real_data_Robust(
    excel_path,
    selected_bands,
    target_parameter="TSS",
    split_ratio=0.7,
    seed=42,
    use_diff=False,
    lower_quantile=0.0,
    upper_quantile=1.0,
    Rrs_range=(0, 0.25),
    target_range=(-0.5, 0.5),
):
    """Load and preprocess real data using robust scaling methods.

    This function loads remote sensing reflectance (Rrs) and parameter data from
    Excel files, applies robust preprocessing including quantile filtering,
    normalization, and data splitting for training/testing.

    Args:
        excel_path (str): Path to Excel file containing Rrs and parameter data.
        selected_bands (list): List of wavelength bands to extract from Rrs data.
        target_parameter (str, optional): Name of target parameter column.
            Defaults to "TSS".
        split_ratio (float, optional): Train/test split ratio. Defaults to 0.7.
        seed (int, optional): Random seed for reproducible splits. Defaults to 42.
        use_diff (bool, optional): Whether to apply first difference to Rrs.
            Defaults to False.
        lower_quantile (float, optional): Lower quantile for outlier removal.
            Defaults to 0.0.
        upper_quantile (float, optional): Upper quantile for outlier removal.
            Defaults to 1.0.
        Rrs_range (tuple, optional): Target range for Rrs normalization.
            Defaults to (0, 0.25).
        target_range (tuple, optional): Target range for parameter normalization.
            Defaults to (-0.5, 0.5).

    Returns:
        tuple: A tuple containing:
            - train_dl (DataLoader): Training data loader
            - test_dl (DataLoader): Test data loader
            - input_dim (int): Input feature dimension
            - output_dim (int): Output dimension
            - train_ids (list): Training sample IDs
            - test_ids (list): Test sample IDs
            - scaler_Rrs: Fitted Rrs scaler object
            - TSS_scalers_dict (dict): Dictionary of fitted target scalers
    """
    rounded_bands = [int(round(b)) for b in selected_bands]
    band_cols = [f"Rrs_{b}" for b in rounded_bands]
    df_rrs = pd.read_excel(excel_path, sheet_name="Rrs")
    df_param = pd.read_excel(excel_path, sheet_name="parameter")
    df_rrs_selected = df_rrs[["GLORIA_ID"] + band_cols]
    df_param_selected = df_param[["GLORIA_ID", target_parameter]]
    df_merged = pd.merge(
        df_rrs_selected, df_param_selected, on="GLORIA_ID", how="inner"
    )

    # Drop samples with missing Rrs or target values.
    mask_rrs_valid = df_merged[band_cols].notna().all(axis=1)
    mask_param_valid = df_merged[target_parameter].notna()
    df_filtered = df_merged[mask_rrs_valid & mask_param_valid].reset_index(drop=True)
    print(
        f"Number of samples after filtering Rrs and {target_parameter}: {len(df_filtered)}"
    )

    # Quantile clipping for the target parameter.
    lower = df_filtered[target_parameter].quantile(lower_quantile)
    top = df_filtered[target_parameter].quantile(upper_quantile)
    df_filtered = df_filtered[
        (df_filtered[target_parameter] >= lower)
        & (df_filtered[target_parameter] <= top)
    ].reset_index(drop=True)
    print(
        f"Number of samples after removing {target_parameter} quantiles [{lower_quantile}, {upper_quantile}]: {len(df_filtered)}"
    )

    all_sample_ids = df_filtered["GLORIA_ID"].astype(str).tolist()
    Rrs_array = df_filtered[band_cols].values
    param_array = df_filtered[[target_parameter]].values
    if use_diff:
        Rrs_array = np.diff(Rrs_array, axis=1)

    # Fit a robust MinMax scaler on the Rrs spectra.
    scaler_Rrs = RobustMinMaxScaler(feature_range=Rrs_range)
    scaler_Rrs.fit(torch.tensor(Rrs_array, dtype=torch.float32))
    Rrs_normalized = scaler_Rrs.transform(
        torch.tensor(Rrs_array, dtype=torch.float32)
    ).numpy()

    # Log-transform the target, then scale it robustly to target_range.
    log_scaler = LogScaler(shift_min=False, safety_term=1e-8)
    param_log = log_scaler.fit_transform(torch.tensor(param_array, dtype=torch.float32))
    param_scaler = RobustMinMaxScaler(
        feature_range=target_range, global_scale=True, robust=True
    )
    param_transformed = param_scaler.fit_transform(param_log).numpy()

    Rrs_tensor = torch.tensor(Rrs_normalized, dtype=torch.float32)
    param_tensor = torch.tensor(param_transformed, dtype=torch.float32)
    dataset = TensorDataset(Rrs_tensor, param_tensor)

    # Reproducible random train/test split.
    num_samples = len(dataset)
    indices = np.arange(num_samples)
    np.random.seed(seed)
    np.random.shuffle(indices)
    train_size = int(split_ratio * num_samples)
    train_indices = indices[:train_size]
    test_indices = indices[train_size:]
    train_dataset = Subset(dataset, train_indices)
    test_dataset = Subset(dataset, test_indices)
    train_ids = [all_sample_ids[i] for i in train_indices]
    test_ids = [all_sample_ids[i] for i in test_indices]
    train_dl = DataLoader(train_dataset, batch_size=1024, shuffle=True, num_workers=0)
    test_dl = DataLoader(test_dataset, batch_size=1024, shuffle=False, num_workers=0)
    input_dim = Rrs_tensor.shape[1]
    output_dim = param_tensor.shape[1]
    TSS_scalers_dict = {"log": log_scaler, "robust": param_scaler}
    return (
        train_dl,
        test_dl,
        input_dim,
        output_dim,
        train_ids,
        test_ids,
        scaler_Rrs,
        TSS_scalers_dict,
    )
```
load_real_test(excel_path, selected_bands, max_allowed_diff=1.0, diff_before_norm=False, diff_after_norm=False, target_parameter='TSS', log_offset=0.01)¶
Load and preprocess real test data using MinMax scaling.
This function loads test data from Excel files, matches wavelength bands to available bands, and applies MinMax normalization with optional differencing operations consistent with training preprocessing.
Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `excel_path` | `str` | Path to Excel file containing test data. | *required* |
| `selected_bands` | `list` | List of target wavelength bands. | *required* |
| `max_allowed_diff` | `float` | Maximum allowed wavelength difference for band matching, in nm. | `1.0` |
| `diff_before_norm` | `bool` | Apply differencing before normalization. | `False` |
| `diff_after_norm` | `bool` | Apply differencing after normalization. | `False` |
| `target_parameter` | `str` | Target parameter column name. | `'TSS'` |
| `log_offset` | `float` | Offset for log transformation. | `0.01` |
Returns:

| Type | Description |
| --- | --- |
| `tuple` | A tuple containing:<br>- `test_dl` (DataLoader): Test data loader<br>- `input_dim` (int): Input feature dimension<br>- `output_dim` (int): Output dimension<br>- `sample_ids` (list): Sample identifiers<br>- `sample_dates` (list): Sample dates |
Raises:

| Type | Description |
| --- | --- |
| `ValueError` | If table row counts don't match or wavelengths can't be matched. |
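A usage sketch for an external test file (the path and bands are placeholders; the "Rrs" sheet needs numeric wavelength column headers plus "Site Label" and "Date" columns, per the source below). The differencing flags should mirror whatever was used with `load_real_data` during training:

```python
from hypercoast.moe_vae.data_loading import load_real_test

test_dl, input_dim, output_dim, sample_ids, sample_dates = load_real_test(
    excel_path="field_campaign.xlsx",    # placeholder path
    selected_bands=[443, 490, 560, 665],
    max_allowed_diff=1.0,                # nm tolerance for nearest-band matching
    diff_before_norm=False,              # keep consistent with training
    target_parameter="TSS",
)

# The loader yields one full-size batch (batch_size=len(dataset)).
x, y = next(iter(test_dl))
```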
Source code in hypercoast/moe_vae/data_loading.py
```python
def load_real_test(
    excel_path,
    selected_bands,
    max_allowed_diff=1.0,
    diff_before_norm=False,
    diff_after_norm=False,
    target_parameter="TSS",
    log_offset=0.01,
):
    """Load and preprocess real test data using MinMax scaling.

    This function loads test data from Excel files, matches wavelength bands
    to available bands, and applies MinMax normalization with optional
    differencing operations consistent with training preprocessing.

    Args:
        excel_path (str): Path to Excel file containing test data.
        selected_bands (list): List of target wavelength bands.
        max_allowed_diff (float, optional): Maximum allowed wavelength difference
            for band matching in nm. Defaults to 1.0.
        diff_before_norm (bool, optional): Apply differencing before normalization.
            Defaults to False.
        diff_after_norm (bool, optional): Apply differencing after normalization.
            Defaults to False.
        target_parameter (str, optional): Target parameter column name.
            Defaults to "TSS".
        log_offset (float, optional): Offset for log transformation.
            Defaults to 0.01.

    Returns:
        tuple: A tuple containing:
            - test_dl (DataLoader): Test data loader
            - input_dim (int): Input feature dimension
            - output_dim (int): Output dimension
            - sample_ids (list): Sample identifiers
            - sample_dates (list): Sample dates

    Raises:
        ValueError: If table row counts don't match or wavelengths can't be matched.
    """
    df_rrs = pd.read_excel(excel_path, sheet_name="Rrs")
    df_param = pd.read_excel(excel_path, sheet_name="parameter")
    if df_rrs.shape[0] != df_param.shape[0]:
        raise ValueError(
            f"❌ The number of rows in the Rrs table and parameter table do not match. Rrs: {df_rrs.shape[0]}, parameter: {df_param.shape[0]}"
        )

    # === Extract IDs and dates ===
    sample_ids = df_rrs["Site Label"].astype(str).tolist()
    sample_dates = df_rrs["Date"].astype(str).tolist()

    # === Match target bands ===
    rrs_wavelengths = []
    rrs_cols = []
    for col in df_rrs.columns:
        try:
            wl = float(col)
            rrs_wavelengths.append(wl)
            rrs_cols.append(col)
        except Exception:
            continue
    band_cols = []
    matched_bands = []
    for target_band in selected_bands:
        diffs = [abs(wl - target_band) for wl in rrs_wavelengths]
        min_diff = min(diffs)
        if min_diff > max_allowed_diff:
            raise ValueError(
                f"Target wavelength {target_band} nm cannot be matched, error {min_diff:.2f} nm exceeds the allowed range"
            )
        best_idx = diffs.index(min_diff)
        band_cols.append(rrs_cols[best_idx])
        matched_bands.append(rrs_wavelengths[best_idx])
    print(
        f"\n✅ Band matching successful, {len(selected_bands)} target bands in total, {len(band_cols)} columns actually extracted"
    )
    print(f"Original number of test samples: {df_rrs.shape[0]}\n")

    # === Extract Rrs and target parameter (without differencing for now) ===
    Rrs_array = df_rrs[band_cols].values.astype(float)
    target_array = df_param[[target_parameter]].values.astype(float).flatten()

    # === Key: Remove rows with NaN/Inf before differencing ===
    mask_inputs_ok = np.all(np.isfinite(Rrs_array), axis=1)
    mask_target_ok = np.isfinite(target_array)
    mask_ok = mask_inputs_ok & mask_target_ok
    if not np.any(mask_ok):
        raise ValueError("❌ No valid samples (NaN/Inf found in input or target).")
    dropped = int(len(mask_ok) - mask_ok.sum())
    if dropped > 0:
        print(
            f"⚠️ Dropped {dropped} invalid samples (containing NaN/Inf) before differencing"
        )
    Rrs_array = Rrs_array[mask_ok]
    target_array = target_array[mask_ok]
    sample_ids = [sid for sid, keep in zip(sample_ids, mask_ok) if keep]
    sample_dates = [d for d, keep in zip(sample_dates, mask_ok) if keep]

    # === Differencing before normalization (optional) ===
    if diff_before_norm:
        Rrs_array = np.diff(Rrs_array, axis=1)

    # === Apply MinMaxScaler to [1, 10] for each sample ===
    scalers_Rrs_test = [MinMaxScaler((1, 10)) for _ in range(Rrs_array.shape[0])]
    Rrs_normalized = np.array(
        [
            scalers_Rrs_test[i].fit_transform(row.reshape(-1, 1)).flatten()
            for i, row in enumerate(Rrs_array)
        ]
    )

    # === Differencing after normalization (optional) ===
    if diff_after_norm:
        Rrs_normalized = np.diff(Rrs_normalized, axis=1)

    # === Transform target value to log10(x + log_offset) ===
    target_transformed = np.log10(target_array + log_offset)

    # === Construct DataLoader ===
    Rrs_tensor = torch.tensor(Rrs_normalized, dtype=torch.float32)
    target_tensor = torch.tensor(target_transformed.reshape(-1, 1), dtype=torch.float32)
    dataset = TensorDataset(Rrs_tensor, target_tensor)
    test_dl = DataLoader(dataset, batch_size=len(dataset), shuffle=False, num_workers=0)
    input_dim = Rrs_tensor.shape[1]
    output_dim = target_tensor.shape[1]
    return test_dl, input_dim, output_dim, sample_ids, sample_dates
```
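For reference, the nearest-band matching above reduces to the standalone helper below (a sketch with a hypothetical `match_bands` name, not a library function): for each requested wavelength it picks the closest numeric column header and fails if the gap exceeds `max_allowed_diff` nm.

```python
def match_bands(available, targets, max_allowed_diff=1.0):
    """Return the index of the nearest available wavelength for each target,
    mirroring the matching loop in load_real_test (sketch only)."""
    matched = []
    for t in targets:
        diffs = [abs(wl - t) for wl in available]
        best = min(range(len(available)), key=lambda i: diffs[i])
        if diffs[best] > max_allowed_diff:
            raise ValueError(f"{t} nm off by {diffs[best]:.2f} nm, beyond tolerance")
        matched.append(best)
    return matched

print(match_bands([442.5, 490.2, 560.1], [443, 490, 560]))  # [0, 1, 2]
```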
load_real_test_Robust(excel_path, selected_bands, max_allowed_diff=1.0, scaler_Rrs=None, scalers_dict=None, use_diff=False, target_parameter='SPM')¶
Load and preprocess real test data using robust scaling methods.
This function loads test data from Excel files, matches wavelength bands to the nearest available bands, and applies the same preprocessing transformations as used during training.
Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `excel_path` | `str` | Path to Excel file containing test data. | *required* |
| `selected_bands` | `list` | List of target wavelength bands. | *required* |
| `max_allowed_diff` | `float` | Maximum allowed wavelength difference for band matching, in nm. | `1.0` |
| `scaler_Rrs` |  | Pre-fitted Rrs scaler from training data. | `None` |
| `scalers_dict` | `dict` | Dictionary of pre-fitted scalers from training. | `None` |
| `use_diff` | `bool` | Whether to apply first difference. | `False` |
| `target_parameter` | `str` | Name of target parameter. | `'SPM'` |
Returns:

| Type | Description |
| --- | --- |
| `tuple` | A tuple containing:<br>- `test_dl` (DataLoader): Test data loader<br>- `input_dim` (int): Input feature dimension<br>- `output_dim` (int): Output dimension<br>- `sample_ids` (list): Sample identifiers<br>- `sample_dates` (list): Sample dates |
Raises:

| Type | Description |
| --- | --- |
| `ValueError` | If number of rows in Rrs and parameter tables don't match, or if target wavelengths cannot be matched within tolerance. |
Source code in hypercoast/moe_vae/data_loading.py
```python
def load_real_test_Robust(
    excel_path,
    selected_bands,
    max_allowed_diff=1.0,
    scaler_Rrs=None,
    scalers_dict=None,
    use_diff=False,
    target_parameter="SPM",
):
    """Load and preprocess real test data using robust scaling methods.

    This function loads test data from Excel files, matches wavelength bands
    to the nearest available bands, and applies the same preprocessing
    transformations as used during training.

    Args:
        excel_path (str): Path to Excel file containing test data.
        selected_bands (list): List of target wavelength bands.
        max_allowed_diff (float, optional): Maximum allowed wavelength difference
            for band matching. Defaults to 1.0.
        scaler_Rrs: Pre-fitted Rrs scaler from training data.
        scalers_dict (dict): Dictionary of pre-fitted scalers from training.
        use_diff (bool, optional): Whether to apply first difference.
            Defaults to False.
        target_parameter (str, optional): Name of target parameter.
            Defaults to "SPM".

    Returns:
        tuple: A tuple containing:
            - test_dl (DataLoader): Test data loader
            - input_dim (int): Input feature dimension
            - output_dim (int): Output dimension
            - sample_ids (list): Sample identifiers
            - sample_dates (list): Sample dates

    Raises:
        ValueError: If number of rows in Rrs and parameter tables don't match,
            or if target wavelengths cannot be matched within tolerance.
    """
    df_rrs = pd.read_excel(excel_path, sheet_name="Rrs")
    df_param = pd.read_excel(excel_path, sheet_name="parameter")
    if df_rrs.shape[0] != df_param.shape[0]:
        raise ValueError(
            f"❌ The number of rows in the Rrs table and parameter table do not match. Rrs: {df_rrs.shape[0]}, parameter: {df_param.shape[0]}"
        )
    sample_ids = df_rrs["Site Label"].astype(str).tolist()
    sample_dates = df_rrs["Date"].astype(str).tolist()

    # Match target bands to the nearest numeric column headers.
    rrs_wavelengths = []
    rrs_cols = []
    for col in df_rrs.columns:
        try:
            wl = float(col)
            rrs_wavelengths.append(wl)
            rrs_cols.append(col)
        except Exception:
            continue
    band_cols = []
    for target_band in selected_bands:
        diffs = [abs(wl - target_band) for wl in rrs_wavelengths]
        min_diff = min(diffs)
        if min_diff > max_allowed_diff:
            raise ValueError(
                f"Target wavelength {target_band} nm cannot be matched, error {min_diff:.2f} nm exceeds the allowed range"
            )
        best_idx = diffs.index(min_diff)
        band_cols.append(rrs_cols[best_idx])
    print(f"\n✅ Band matching successful, {len(selected_bands)} target bands in total")
    print(f"Final number of valid test samples: {df_rrs.shape[0]}\n")

    Rrs_array = df_rrs[band_cols].values
    param_array = df_param[[target_parameter]].values.flatten()

    # === Key: Remove rows with NaN/Inf before differencing ===
    mask_inputs_ok = np.all(np.isfinite(Rrs_array), axis=1)
    mask_target_ok = np.isfinite(param_array)
    mask_ok = mask_inputs_ok & mask_target_ok
    if not np.any(mask_ok):
        raise ValueError("❌ Valid samples = 0 (NaN/Inf found in input or target).")
    dropped = int(len(mask_ok) - mask_ok.sum())
    if dropped > 0:
        print(
            f"⚠️ Dropped {dropped} invalid samples (containing NaN/Inf) before differencing"
        )
    Rrs_array = Rrs_array[mask_ok]
    param_array = param_array[mask_ok]
    sample_ids = [sid for sid, keep in zip(sample_ids, mask_ok) if keep]
    sample_dates = [d for d, keep in zip(sample_dates, mask_ok) if keep]
    if use_diff:
        Rrs_array = np.diff(Rrs_array, axis=1)

    # Apply the scalers fitted on the training data.
    Rrs_tensor = torch.tensor(Rrs_array, dtype=torch.float32)
    Rrs_normalized = scaler_Rrs.transform(Rrs_tensor).numpy()
    log_scaler = scalers_dict["log"]
    robust_scaler = scalers_dict["robust"]
    param_log = log_scaler.transform(
        torch.tensor(param_array.reshape(-1, 1), dtype=torch.float32)
    )
    param_transformed = robust_scaler.transform(param_log).numpy()
    dataset = TensorDataset(
        torch.tensor(Rrs_normalized, dtype=torch.float32),
        torch.tensor(param_transformed.reshape(-1, 1), dtype=torch.float32),
    )
    test_dl = DataLoader(dataset, batch_size=len(dataset), shuffle=False, num_workers=0)
    input_dim = Rrs_tensor.shape[1]
    output_dim = 1
    return test_dl, input_dim, output_dim, sample_ids, sample_dates
```
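Putting the robust pair together: the scalers fitted by `load_real_data_Robust` are passed unchanged to `load_real_test_Robust`, so test spectra and targets receive the identical transforms (paths and bands below are placeholders):

```python
from hypercoast.moe_vae.data_loading import (
    load_real_data_Robust,
    load_real_test_Robust,
)

# Fit the scalers on the training dataset.
(train_dl, _, input_dim, output_dim, *_ids,
 scaler_Rrs, TSS_scalers_dict) = load_real_data_Robust(
    excel_path="gloria_dataset.xlsx",    # placeholder path
    selected_bands=[443, 490, 560, 665],
    target_parameter="TSS",
)

# Reuse the *same* fitted scalers on the external test file.
test_dl, _, _, sample_ids, sample_dates = load_real_test_Robust(
    excel_path="field_campaign.xlsx",    # placeholder path
    selected_bands=[443, 490, 560, 665],
    scaler_Rrs=scaler_Rrs,
    scalers_dict=TSS_scalers_dict,
    target_parameter="TSS",              # column name in the test "parameter" sheet
)
```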