Skip to content

moe_vae module

Mixture of Experts Variational Autoencoder (MoE-VAE) models.

This module implements VAE and MoE-VAE architectures for water quality parameter estimation from hyperspectral remote sensing data. It includes sparse dispatching for efficient expert routing and training/evaluation utilities.

MoE_VAE (LightningModule)

Mixture of Experts Variational Autoencoder for water quality estimation.

This class implements a sparsely-gated MoE architecture where multiple VAE experts specialize in different regions of the input space. A learned gating network dynamically routes each input to the top-k most relevant experts.

Parameters:

Name Type Description Default
input_dim int

Number of input features (spectral bands).

required
output_dim int

Number of output features (water quality parameters).

required
latent_dim int

Dimension of the latent space for each expert VAE.

required
encoder_hidden_dims List[int]

List of hidden layer dimensions for encoders.

required
decoder_hidden_dims List[int]

List of hidden layer dimensions for decoders.

required
num_experts int

Total number of expert VAEs.

required
k int

Number of experts to activate for each input sample.

4
activation str

Activation function ('relu', 'tanh', 'sigmoid', 'leakyrelu').

'leakyrelu'
noisy_gating bool

Whether to add noise to gating for exploration during training.

True
use_norm Union[str, bool]

Type of normalization ('batch', 'layer', 'group', or False).

False
use_dropout bool

Whether to use dropout regularization.

False
use_softplus_output bool

Whether to apply softplus activation to output.

False
**kwargs

Additional keyword arguments.

{}
Source code in hypercoast/emit_utils/MoE_VAE.py
class MoE_VAE(LightningModule):
    """Mixture of Experts Variational Autoencoder for water quality estimation.

    This class implements a sparsely-gated MoE architecture where multiple VAE
    experts specialize in different regions of the input space. A learned gating
    network dynamically routes each input to the top-k most relevant experts.

    Args:
        input_dim: Number of input features (spectral bands).
        output_dim: Number of output features (water quality parameters).
        latent_dim: Dimension of the latent space for each expert VAE.
        encoder_hidden_dims: List of hidden layer dimensions for encoders.
        decoder_hidden_dims: List of hidden layer dimensions for decoders.
        num_experts: Total number of expert VAEs.
        k: Number of experts to activate for each input sample. Must not
            exceed ``num_experts``.
        activation: Activation function ('relu', 'tanh', 'sigmoid', 'leakyrelu').
        noisy_gating: Whether to add noise to gating for exploration during training.
        use_norm: Type of normalization ('batch', 'layer', 'group', or False).
        use_dropout: Whether to use dropout regularization.
        use_softplus_output: Whether to apply softplus activation to output.
        **kwargs: Additional keyword arguments. They are accepted but not used
            by this constructor.
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        latent_dim: int,
        encoder_hidden_dims: List[int],
        decoder_hidden_dims: List[int],
        num_experts: int,
        k: int = 4,
        activation: str = "leakyrelu",
        noisy_gating: bool = True,
        use_norm: Union[str, bool] = False,
        use_dropout: bool = False,
        use_softplus_output: bool = False,
        **kwargs,
    ):
        super().__init__()
        self.noisy_gating = noisy_gating
        self.num_experts = num_experts
        self.output_dim = output_dim
        self.input_dim = input_dim
        self.latent_dim = latent_dim
        self.encoder_hidden_dims = encoder_hidden_dims
        self.decoder_hidden_dims = decoder_hidden_dims
        self.k = k
        self.activation = activation
        self.use_norm = use_norm
        self.use_dropout = use_dropout
        self.use_softplus_output = use_softplus_output

        # One independent VAE per expert, all built with the same configuration.
        self.experts = nn.ModuleList(
            [
                VAE(
                    self.input_dim,
                    self.output_dim,
                    self.latent_dim,
                    self.encoder_hidden_dims,
                    self.decoder_hidden_dims,
                    self.activation,
                    use_norm=self.use_norm,
                    use_dropout=self.use_dropout,
                    use_softplus_output=self.use_softplus_output,
                )
                for _ in range(self.num_experts)
            ]
        )

        # Gating and noise projections start at zero, so every expert has the
        # same clean logit until training updates these weights.
        self.w_gate = nn.Parameter(
            torch.zeros(input_dim, num_experts, dtype=self.dtype), requires_grad=True
        )
        self.w_noise = nn.Parameter(
            torch.zeros(input_dim, num_experts, dtype=self.dtype), requires_grad=True
        )

        self.softplus = nn.Softplus()
        self.softmax = nn.Softmax(1)
        # Parameters of the standard normal used in _prob_in_top_k; registered
        # as buffers so they follow the module across devices.
        self.register_buffer("mean", torch.tensor([0.0]))
        self.register_buffer("std", torch.tensor([1.0]))
        # Gates of the most recent forward pass (None until forward is called).
        self.batch_gates = None

        assert self.k <= self.num_experts

    def forward(
        self, x: torch.Tensor, moe_weight: float = 1e-2
    ) -> Dict[str, torch.Tensor]:
        """Forward pass through the MoE-VAE model.

        Args:
            x: Input tensor of shape (batch_size, input_dim).
            moe_weight: Weight for the MoE load balancing loss.

        Returns:
            Dictionary containing:
                - 'pred_y': Predicted output tensor (batch_size, output_dim).
                - 'moe_loss': Load balancing loss to encourage uniform expert usage.
        """
        gates, load = self.noisy_top_k_gating(x, self.training)
        self.batch_gates = gates

        # Balance both the summed gate weight (importance) and the per-expert
        # sample count / selection probability (load).
        importance = gates.sum(0)
        moe_loss = moe_weight * self.cv_squared(
            importance
        ) + moe_weight * self.cv_squared(load)

        dispatcher = SparseDispatcher(self.num_experts, gates)
        expert_inputs = dispatcher.dispatch(x)

        expert_outputs = []
        for expert, input_i in zip(self.experts, expert_inputs):
            n_assigned = input_i.shape[0]
            # While training, an expert that received exactly one sample is
            # skipped and contributes zeros (presumably because batch-statistic
            # normalization cannot run on a single sample -- confirm against
            # VAE). In eval mode single samples are processed, so that
            # single-sample inference does not silently return zeros.
            skip_single = self.training and n_assigned == 1
            if n_assigned > 0 and not skip_single:
                expert_outputs.append(expert(input_i)["pred_y"])
            else:
                # Placeholder with the right row count so combine() lines up.
                expert_outputs.append(
                    input_i.new_zeros((n_assigned, self.output_dim))
                )
        pred_y = dispatcher.combine(expert_outputs)
        return {"pred_y": pred_y, "moe_loss": moe_loss}

    def loss_fn(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Compute MoE-VAE loss including reconstruction and load balancing terms.

        The total loss is the mean absolute error plus the load balancing
        loss; the mean squared error is reported but not optimised.

        Args:
            output_dict: Dictionary containing 'pred_y', 'y', and optionally
                'moe_loss' (treated as zero when absent).

        Returns:
            Dictionary containing 'total_loss', 'mae_loss', 'mse_loss', and 'moe_loss'.
        """
        pred_y = output_dict["pred_y"]
        y = output_dict["y"]
        mae_loss = F.l1_loss(pred_y, y, reduction="mean")
        mse_loss = F.mse_loss(pred_y, y, reduction="mean")
        moe_loss = output_dict.get(
            "moe_loss", torch.tensor(0.0, device=pred_y.device, dtype=pred_y.dtype)
        )
        total_loss = mae_loss + moe_loss
        return {
            "total_loss": total_loss,
            "mae_loss": mae_loss,
            "mse_loss": mse_loss,
            "moe_loss": moe_loss,
        }

    def get_batch_gates(self) -> Optional[torch.Tensor]:
        """Get the gating weights from the most recent forward pass.

        Returns:
            Tensor of shape (batch_size, num_experts) or None if forward hasn't been called.
        """
        return self.batch_gates

    def cv_squared(self, x: torch.Tensor) -> torch.Tensor:
        """Compute squared coefficient of variation for load balancing.

        This metric encourages uniform distribution across experts by penalizing
        high variance in expert usage.

        Args:
            x: Input tensor (e.g., expert loads or importance weights).

        Returns:
            Squared coefficient of variation as a 0-dim float tensor. Zero when
            ``x`` has a single element (only one expert).
        """
        eps = 1e-10
        values = x.float()
        # The variance of a single element is undefined; with one expert there
        # is nothing to balance, so return a scalar zero like the general case.
        if values.shape[0] == 1:
            return values.new_zeros(())
        return values.var() / (values.mean() ** 2 + eps)

    def _gates_to_load(self, gates: torch.Tensor) -> torch.Tensor:
        """Compute the load (number of samples) assigned to each expert.

        Args:
            gates: Gating tensor of shape (batch_size, num_experts).

        Returns:
            Load tensor of shape (num_experts,) with counts of assigned samples.
        """
        return (gates > 0).sum(0)

    def _prob_in_top_k(
        self,
        clean_values: torch.Tensor,
        noisy_values: torch.Tensor,
        noise_stddev: torch.Tensor,
        noisy_top_values: torch.Tensor,
    ) -> torch.Tensor:
        """Compute probability that each value is in top-k after adding noise.

        This is used for load balancing during training by computing differentiable
        probabilities of expert selection.

        Args:
            clean_values: Clean logits of shape (batch, num_experts).
            noisy_values: Noisy logits of shape (batch, num_experts).
            noise_stddev: Noise standard deviations of shape (batch, num_experts).
            noisy_top_values: Top-k noisy values of shape (batch, k+1).

        Returns:
            Probabilities of shape (batch, num_experts).
        """
        batch = clean_values.size(0)
        m = noisy_top_values.size(1)
        top_values_flat = noisy_top_values.flatten()

        # Flat position of column k in every row: the value a currently
        # selected expert must stay above to remain selected.
        threshold_positions_if_in = (
            torch.arange(batch, device=clean_values.device) * m + self.k
        )
        threshold_if_in = torch.unsqueeze(
            torch.gather(top_values_flat, 0, threshold_positions_if_in), 1
        )
        # Whether each noisy value is currently above that threshold.
        is_in = torch.gt(noisy_values, threshold_if_in)
        # Column k-1: the value a currently unselected expert would have to beat.
        threshold_positions_if_out = threshold_positions_if_in - 1
        threshold_if_out = torch.unsqueeze(
            torch.gather(top_values_flat, 0, threshold_positions_if_out), 1
        )
        normal = Normal(self.mean, self.std)
        prob_if_in = normal.cdf((clean_values - threshold_if_in) / noise_stddev)
        prob_if_out = normal.cdf((clean_values - threshold_if_out) / noise_stddev)
        prob = torch.where(is_in, prob_if_in, prob_if_out)
        return prob

    def noisy_top_k_gating(
        self, x: torch.Tensor, train: bool, noise_epsilon: float = 1e-2
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute noisy top-k gating for expert selection.

        Implements the gating mechanism from "Outrageously Large Neural Networks:
        The Sparsely-Gated Mixture-of-Experts Layer" (https://arxiv.org/abs/1701.06538).
        Adds tunable noise during training for exploration.

        Args:
            x: Input tensor of shape (batch_size, input_dim).
            train: Whether model is in training mode (adds noise if True).
            noise_epsilon: Lower bound added to the learned noise standard
                deviation.

        Returns:
            gates: Sparse gating weights of shape (batch_size, num_experts).
            load: Load assigned to each expert of shape (num_experts,).
        """
        clean_logits = x @ self.w_gate
        if self.noisy_gating and train:
            raw_noise_stddev = x @ self.w_noise
            noise_stddev = self.softplus(raw_noise_stddev) + noise_epsilon
            noisy_logits = clean_logits + (
                torch.randn_like(clean_logits) * noise_stddev
            )
            logits = noisy_logits
        else:
            logits = clean_logits

        # Safety check: if any row of logits sums to exactly zero, shift all
        # logits by a small positive constant.
        if (logits.sum(dim=1) == 0).any():
            logits = logits + 1e-5

        # calculate topk + 1 that will be needed for the noisy gates
        top_logits, top_indices = logits.topk(min(self.k + 1, self.num_experts), dim=1)
        top_k_logits = top_logits[:, : self.k]
        top_k_indices = top_indices[:, : self.k]
        top_k_gates = self.softmax(top_k_logits)

        # The dense gate buffer takes its dtype from the softmax output, because
        # scatter needs source and destination dtypes to agree. It is a plain
        # constant; gradients reach the gates through top_k_gates.
        zeros = torch.zeros_like(logits, dtype=top_k_gates.dtype)
        gates = zeros.scatter(1, top_k_indices, top_k_gates)

        # Safety check - ensure at least one expert is selected per sample
        if (gates.sum(dim=1) < 1e-6).any():
            # Force selection of the top expert for samples with no experts
            problematic_samples = (gates.sum(dim=1) < 1e-6).nonzero().squeeze(1)
            if problematic_samples.numel() > 0:  # If there are problematic samples
                # Select the top expert for these samples
                top_expert = top_indices[problematic_samples, 0]
                # Set a minimum value for the gate
                gates[problematic_samples, top_expert] = 0.1

        if self.noisy_gating and self.k < self.num_experts and train:
            # Differentiable load: summed probability of each expert being selected.
            load = (
                self._prob_in_top_k(
                    clean_logits, noisy_logits, noise_stddev, top_logits
                )
            ).sum(0)
        else:
            # Hard load: number of samples with a positive gate per expert.
            load = self._gates_to_load(gates)
        return gates, load

cv_squared(self, x)

Compute squared coefficient of variation for load balancing.

This metric encourages uniform distribution across experts by penalizing high variance in expert usage.

Parameters:

Name Type Description Default
x Tensor

Input tensor (e.g., expert loads or importance weights).

required

Returns:

Type Description
Tensor

Squared coefficient of variation scalar.

Source code in hypercoast/emit_utils/MoE_VAE.py
def cv_squared(self, x: torch.Tensor) -> torch.Tensor:
    """Compute squared coefficient of variation for load balancing.

    This metric encourages uniform distribution across experts by penalizing
    high variance in expert usage.

    Args:
        x: Input tensor (e.g., expert loads or importance weights).

    Returns:
        Squared coefficient of variation as a 0-dim float tensor. Zero when
        ``x`` has a single element (only one expert).
    """
    eps = 1e-10
    values = x.float()
    # The variance of a single element is undefined; with one expert there is
    # nothing to balance, so return a scalar zero like the general case.
    if values.shape[0] == 1:
        return values.new_zeros(())
    return values.var() / (values.mean() ** 2 + eps)

forward(self, x, moe_weight=0.01)

Forward pass through the MoE-VAE model.

Parameters:

Name Type Description Default
x Tensor

Input tensor of shape (batch_size, input_dim).

required
moe_weight float

Weight for the MoE load balancing loss.

0.01

Returns:

Type Description
Dict[str, torch.Tensor]

Dictionary containing:
  • 'pred_y': Predicted output tensor (batch_size, output_dim).
  • 'moe_loss': Load balancing loss to encourage uniform expert usage.
Source code in hypercoast/emit_utils/MoE_VAE.py
def forward(
    self, x: torch.Tensor, moe_weight: float = 1e-2
) -> Dict[str, torch.Tensor]:
    """Forward pass through the MoE-VAE model.

    Args:
        x: Input tensor of shape (batch_size, input_dim).
        moe_weight: Weight for the MoE load balancing loss.

    Returns:
        Dictionary containing:
            - 'pred_y': Predicted output tensor (batch_size, output_dim).
            - 'moe_loss': Load balancing loss to encourage uniform expert usage.
    """
    gates, load = self.noisy_top_k_gating(x, self.training)
    self.batch_gates = gates

    # Balance both the summed gate weight (importance) and the per-expert
    # sample count / selection probability (load).
    importance = gates.sum(0)
    moe_loss = moe_weight * self.cv_squared(
        importance
    ) + moe_weight * self.cv_squared(load)

    dispatcher = SparseDispatcher(self.num_experts, gates)
    expert_inputs = dispatcher.dispatch(x)

    expert_outputs = []
    for expert, input_i in zip(self.experts, expert_inputs):
        n_assigned = input_i.shape[0]
        # While training, an expert that received exactly one sample is
        # skipped and contributes zeros (presumably because batch-statistic
        # normalization cannot run on a single sample -- confirm against VAE).
        # In eval mode single samples are processed, so that single-sample
        # inference does not silently return zeros.
        skip_single = self.training and n_assigned == 1
        if n_assigned > 0 and not skip_single:
            expert_outputs.append(expert(input_i)["pred_y"])
        else:
            # Placeholder with the right row count so combine() lines up.
            expert_outputs.append(input_i.new_zeros((n_assigned, self.output_dim)))
    pred_y = dispatcher.combine(expert_outputs)
    return {"pred_y": pred_y, "moe_loss": moe_loss}

get_batch_gates(self)

Get the gating weights from the most recent forward pass.

Returns:

Type Description
Optional[torch.Tensor]

Tensor of shape (batch_size, num_experts) or None if forward hasn't been called.

Source code in hypercoast/emit_utils/MoE_VAE.py
def get_batch_gates(self) -> Optional[torch.Tensor]:
    """Return the gate tensor stored by the latest forward pass.

    Returns:
        The stored ``batch_gates`` attribute: a (batch_size, num_experts)
        tensor, or None when no forward pass has stored one yet.
    """
    latest_gates = self.batch_gates
    return latest_gates

loss_fn(self, output_dict)

Compute MoE-VAE loss including reconstruction and load balancing terms.

Parameters:

Name Type Description Default
output_dict Dict[str, torch.Tensor]

Dictionary containing 'pred_y', 'y', and 'moe_loss'.

required

Returns:

Type Description
Dict[str, torch.Tensor]

Dictionary containing 'total_loss', 'mae_loss', 'mse_loss', and 'moe_loss'.

Source code in hypercoast/emit_utils/MoE_VAE.py
def loss_fn(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    """Assemble the training losses from a model output dictionary.

    The total loss is the mean absolute error plus the load balancing loss;
    the mean squared error is reported alongside but is not part of the total.

    Args:
        output_dict: Dictionary with 'pred_y' and 'y', and optionally
            'moe_loss' (a zero tensor is used when the key is absent).

    Returns:
        Dictionary containing 'total_loss', 'mae_loss', 'mse_loss', and 'moe_loss'.
    """
    prediction = output_dict["pred_y"]
    target = output_dict["y"]

    if "moe_loss" in output_dict:
        balance_term = output_dict["moe_loss"]
    else:
        balance_term = torch.tensor(
            0.0, device=prediction.device, dtype=prediction.dtype
        )

    abs_err = F.l1_loss(prediction, target, reduction="mean")
    sq_err = F.mse_loss(prediction, target, reduction="mean")

    losses = {"total_loss": abs_err + balance_term}
    losses["mae_loss"] = abs_err
    losses["mse_loss"] = sq_err
    losses["moe_loss"] = balance_term
    return losses

noisy_top_k_gating(self, x, train, noise_epsilon=0.01)

Compute noisy top-k gating for expert selection.

Implements the gating mechanism from "Outrageously Large Neural Networks: The Sparsely-Gated Mixture-of-Experts Layer" (https://arxiv.org/abs/1701.06538). Adds tunable noise during training for exploration.

Parameters:

Name Type Description Default
x Tensor

Input tensor of shape (batch_size, input_dim).

required
train bool

Whether model is in training mode (adds noise if True).

required
noise_epsilon float

Small constant for numerical stability.

0.01

Returns:

Type Description
Tuple[torch.Tensor, torch.Tensor]

A tuple (gates, load):
  • gates: Sparse gating weights of shape (batch_size, num_experts).
  • load: Load assigned to each expert of shape (num_experts,).

Source code in hypercoast/emit_utils/MoE_VAE.py
def noisy_top_k_gating(
    self, x: torch.Tensor, train: bool, noise_epsilon: float = 1e-2
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Compute noisy top-k gating for expert selection.

    Implements the gating mechanism from "Outrageously Large Neural Networks:
    The Sparsely-Gated Mixture-of-Experts Layer" (https://arxiv.org/abs/1701.06538).
    Adds tunable noise during training for exploration.

    Args:
        x: Input tensor of shape (batch_size, input_dim).
        train: Whether model is in training mode (adds noise if True).
        noise_epsilon: Lower bound added to the learned noise standard
            deviation.

    Returns:
        gates: Sparse gating weights of shape (batch_size, num_experts).
        load: Load assigned to each expert of shape (num_experts,).
    """
    clean_logits = x @ self.w_gate
    if self.noisy_gating and train:
        raw_noise_stddev = x @ self.w_noise
        noise_stddev = self.softplus(raw_noise_stddev) + noise_epsilon
        noisy_logits = clean_logits + (
            torch.randn_like(clean_logits) * noise_stddev
        )
        logits = noisy_logits
    else:
        logits = clean_logits

    # Safety check: if any row of logits sums to exactly zero, shift all
    # logits by a small positive constant.
    if (logits.sum(dim=1) == 0).any():
        logits = logits + 1e-5

    # calculate topk + 1 that will be needed for the noisy gates
    top_logits, top_indices = logits.topk(min(self.k + 1, self.num_experts), dim=1)
    top_k_logits = top_logits[:, : self.k]
    top_k_indices = top_indices[:, : self.k]
    top_k_gates = self.softmax(top_k_logits)

    # The dense gate buffer takes its dtype from the softmax output, because
    # scatter needs source and destination dtypes to agree. It is a plain
    # constant; gradients reach the gates through top_k_gates.
    zeros = torch.zeros_like(logits, dtype=top_k_gates.dtype)
    gates = zeros.scatter(1, top_k_indices, top_k_gates)

    # Safety check - ensure at least one expert is selected per sample
    if (gates.sum(dim=1) < 1e-6).any():
        # Force selection of the top expert for samples with no experts
        problematic_samples = (gates.sum(dim=1) < 1e-6).nonzero().squeeze(1)
        if problematic_samples.numel() > 0:  # If there are problematic samples
            # Select the top expert for these samples
            top_expert = top_indices[problematic_samples, 0]
            # Set a minimum value for the gate
            gates[problematic_samples, top_expert] = 0.1

    if self.noisy_gating and self.k < self.num_experts and train:
        # Differentiable load: summed probability of each expert being selected.
        load = (
            self._prob_in_top_k(
                clean_logits, noisy_logits, noise_stddev, top_logits
            )
        ).sum(0)
    else:
        # Hard load: number of samples with a positive gate per expert.
        load = self._gates_to_load(gates)
    return gates, load

SparseDispatcher

Helper for implementing sparse mixture of experts routing.

This class efficiently dispatches inputs to different experts based on gating weights and combines their outputs. It leverages sparsity by only sending batch elements to experts with non-zero gate values.

The dispatcher performs two key operations: 1. dispatch(): Distribute input samples to appropriate experts. 2. combine(): Aggregate expert outputs weighted by gate values.

Examples:

    gates = torch.tensor([[0.7, 0.3, 0.0], [0.0, 0.5, 0.5]])  # (batch=2, experts=3)
    dispatcher = SparseDispatcher(num_experts=3, gates=gates)
    expert_inputs = dispatcher.dispatch(inputs)
    expert_outputs = [expert(expert_inputs[i]) for i, expert in enumerate(experts)]
    combined_output = dispatcher.combine(expert_outputs)

Parameters:

Name Type Description Default
num_experts int

Total number of experts.

required
gates Tensor

Tensor of shape (batch_size, num_experts) with gating weights. Non-zero values indicate which experts receive which samples.

required
Source code in hypercoast/emit_utils/MoE_VAE.py
class SparseDispatcher(object):
    """Helper for implementing sparse mixture of experts routing.

    This class efficiently dispatches inputs to different experts based on
    gating weights and combines their outputs. It leverages sparsity by only
    sending batch elements to experts with non-zero gate values.

    The dispatcher performs two key operations:
    1. dispatch(): Distribute input samples to appropriate experts.
    2. combine(): Aggregate expert outputs weighted by gate values.

    Experts that receive no sample are given the first batch element with a
    small dummy gate (1e-5), so every expert gets a non-empty input.

    Example:
        gates = torch.tensor([[0.7, 0.3, 0.0], [0.0, 0.5, 0.5]])  # (batch=2, experts=3)
        dispatcher = SparseDispatcher(num_experts=3, gates=gates)
        expert_inputs = dispatcher.dispatch(inputs)
        expert_outputs = [expert(expert_inputs[i]) for i, expert in enumerate(experts)]
        combined_output = dispatcher.combine(expert_outputs)

    Args:
        num_experts: Total number of experts.
        gates: Tensor of shape (batch_size, num_experts) with gating weights.
            Non-zero values indicate which experts receive which samples.
            The tensor passed in is not modified.
    """

    def __init__(self, num_experts: int, gates: torch.Tensor):
        """Create a SparseDispatcher."""
        self._num_experts = num_experts

        # Safety check: ensure at least one example per expert. Experts whose
        # gate column sums to zero are assigned the first example with a small
        # weight.
        empty_experts = (gates.sum(dim=0) == 0).nonzero().squeeze(1)
        if empty_experts.numel() > 0:
            # Patch a copy so the caller's gate tensor stays untouched; clone
            # keeps the autograd link to the original gates.
            gates = gates.clone()
            gates[0, empty_experts] = 1e-5
        self._gates = gates

        # (batch, expert) index pairs of all non-zero gates.
        nonzero_pairs = torch.nonzero(gates)
        # Sort experts (each column is sorted independently)
        sorted_experts, index_sorted_experts = nonzero_pairs.sort(0)
        # Drop the batch column, keep the sorted expert ids
        _, self._expert_index = sorted_experts.split(1, dim=1)
        # Get according batch index for each expert
        self._batch_index = nonzero_pairs[index_sorted_experts[:, 1], 0]
        # Calculate num samples that each expert gets
        self._part_sizes = (gates > 0).sum(0).tolist()

        # Safety check: ensure no expert has 0 examples.
        # NOTE(review): after the dummy assignments above this is expected to
        # matter only when a column has no positive entries -- confirm.
        for i, size in enumerate(self._part_sizes):
            if size == 0:
                # Add a dummy example to this expert
                self._part_sizes[i] = 1
                if i >= len(self._expert_index):
                    # Add a new dummy index if needed
                    self._expert_index = torch.cat(
                        [self._expert_index, torch.tensor([[i]], device=gates.device)]
                    )
                    self._batch_index = torch.cat(
                        [self._batch_index, torch.tensor([0], device=gates.device)]
                    )

        # Expand gates to match with self._batch_index
        gates_exp = gates[self._batch_index.flatten()]
        self._nonzero_gates = torch.gather(gates_exp, 1, self._expert_index)

        # Safety check for nonzero gates
        if (self._nonzero_gates <= 0).any():
            self._nonzero_gates = torch.clamp(self._nonzero_gates, min=1e-5)

    def dispatch(self, inp: torch.Tensor) -> List[torch.Tensor]:
        """Dispatch input samples to their assigned experts.

        Args:
            inp: Input tensor of shape (batch_size, input_dim).

        Returns:
            List of tensors, one for each expert, containing the inputs
            assigned to that expert.
        """
        # assigns samples to experts whose gate is nonzero

        # expand according to batch index so we can just split by _part_sizes
        inp_exp = inp[self._batch_index].squeeze(1)
        return torch.split(inp_exp, self._part_sizes, dim=0)

    def combine(
        self, expert_out: List[torch.Tensor], multiply_by_gates: bool = True
    ) -> torch.Tensor:
        """Combine expert outputs weighted by gate values.

        Args:
            expert_out: List of expert output tensors, each with shape
                (expert_batch_size_i, output_dim).
            multiply_by_gates: Whether to weight outputs by gate values.

        Returns:
            Combined output tensor of shape (batch_size, output_dim).
        """
        # Stack all expert outputs in expert order, matching _batch_index.
        stitched = torch.cat(expert_out, 0)

        if multiply_by_gates:
            stitched = stitched.mul(self._nonzero_gates)
        zeros = torch.zeros(
            self._gates.size(0),
            expert_out[-1].size(1),
            requires_grad=True,
            device=stitched.device,
        )
        # Sum the rows that belong to the same batch element across experts.
        combined = zeros.index_add(0, self._batch_index, stitched.float())
        return combined

    def expert_to_gates(self) -> List[torch.Tensor]:
        """Get gate values for each expert's assigned samples.

        Returns:
            List of tensors containing gate values for each expert's samples.
        """
        # split nonzero gates for each expert
        return torch.split(self._nonzero_gates, self._part_sizes, dim=0)

__init__(self, num_experts, gates) special

Create a SparseDispatcher.

Source code in hypercoast/emit_utils/MoE_VAE.py
def __init__(self, num_experts: int, gates: torch.Tensor):
    """Create a SparseDispatcher.

    Args:
        num_experts: Total number of experts.
        gates: Tensor of shape (batch_size, num_experts) with gating weights.
            The tensor passed in is not modified; dummy assignments for
            empty experts are written to a copy.
    """
    self._num_experts = num_experts

    # Safety check: ensure at least one example per expert. Experts whose gate
    # column sums to zero are assigned the first example with a small weight.
    empty_experts = (gates.sum(dim=0) == 0).nonzero().squeeze(1)
    if empty_experts.numel() > 0:
        # Patch a copy so the caller's gate tensor stays untouched; clone keeps
        # the autograd link to the original gates.
        gates = gates.clone()
        gates[0, empty_experts] = 1e-5
    self._gates = gates

    # (batch, expert) index pairs of all non-zero gates.
    nonzero_pairs = torch.nonzero(gates)
    # Sort experts (each column is sorted independently)
    sorted_experts, index_sorted_experts = nonzero_pairs.sort(0)
    # Drop the batch column, keep the sorted expert ids
    _, self._expert_index = sorted_experts.split(1, dim=1)
    # Get according batch index for each expert
    self._batch_index = nonzero_pairs[index_sorted_experts[:, 1], 0]
    # Calculate num samples that each expert gets
    self._part_sizes = (gates > 0).sum(0).tolist()

    # Safety check: ensure no expert has 0 examples.
    # NOTE(review): after the dummy assignments above this is expected to
    # matter only when a column has no positive entries -- confirm.
    for i, size in enumerate(self._part_sizes):
        if size == 0:
            # Add a dummy example to this expert
            self._part_sizes[i] = 1
            if i >= len(self._expert_index):
                # Add a new dummy index if needed
                self._expert_index = torch.cat(
                    [self._expert_index, torch.tensor([[i]], device=gates.device)]
                )
                self._batch_index = torch.cat(
                    [self._batch_index, torch.tensor([0], device=gates.device)]
                )

    # Expand gates to match with self._batch_index
    gates_exp = gates[self._batch_index.flatten()]
    self._nonzero_gates = torch.gather(gates_exp, 1, self._expert_index)

    # Safety check for nonzero gates
    if (self._nonzero_gates <= 0).any():
        self._nonzero_gates = torch.clamp(self._nonzero_gates, min=1e-5)

combine(self, expert_out, multiply_by_gates=True)

Combine expert outputs weighted by gate values.

Parameters:

Name Type Description Default
expert_out List[torch.Tensor]

List of expert output tensors, each with shape (expert_batch_size_i, output_dim).

required
multiply_by_gates bool

Whether to weight outputs by gate values.

True

Returns:

Type Description
Tensor

Combined output tensor of shape (batch_size, output_dim).

Source code in hypercoast/emit_utils/MoE_VAE.py
def combine(
    self, expert_out: List[torch.Tensor], multiply_by_gates: bool = True
) -> torch.Tensor:
    """Merge per-expert outputs back into one row per batch element.

    Args:
        expert_out: List of expert output tensors, each with shape
            (expert_batch_size_i, output_dim), in expert order.
        multiply_by_gates: When True, scale every output row by its gate
            value before summing.

    Returns:
        Combined output tensor of shape (batch_size, output_dim).
    """
    # Concatenate the expert outputs along the row axis.
    stacked = torch.cat(expert_out, dim=0)
    if multiply_by_gates:
        stacked = stacked * self._nonzero_gates

    n_rows = self._gates.size(0)
    n_cols = expert_out[-1].size(1)
    accumulator = torch.zeros(
        n_rows, n_cols, requires_grad=True, device=stacked.device
    )
    # Rows sharing a batch index (same sample, different experts) are summed.
    return accumulator.index_add(0, self._batch_index, stacked.float())

dispatch(self, inp)

Dispatch input samples to their assigned experts.

Parameters:

Name Type Description Default
inp Tensor

Input tensor of shape (batch_size, input_dim).

required

Returns:

Type Description
List[torch.Tensor]

List of tensors, one for each expert, containing the inputs assigned to that expert.

Source code in hypercoast/emit_utils/MoE_VAE.py
def dispatch(self, inp: torch.Tensor) -> List[torch.Tensor]:
    """Split the input batch into one chunk per expert.

    Args:
        inp: Input tensor of shape (batch_size, input_dim).

    Returns:
        Sequence of tensors, one per entry of ``self._part_sizes``, holding
        the input rows routed to the corresponding expert.
    """
    # Repeat/reorder rows by batch index so that consecutive slices of
    # length ``_part_sizes[i]`` belong to expert ``i``.
    routed_rows = inp[self._batch_index]
    routed_rows = routed_rows.squeeze(1)
    return routed_rows.split(self._part_sizes, dim=0)

expert_to_gates(self)

Get gate values for each expert's assigned samples.

Returns:

Type Description
List[torch.Tensor]

List of tensors containing gate values for each expert's samples.

Source code in hypercoast/emit_utils/MoE_VAE.py
def expert_to_gates(self) -> List[torch.Tensor]:
    """Return the gate values grouped by expert.

    Returns:
        Sequence of tensors obtained by splitting ``self._nonzero_gates``
        along dim 0 into chunks of sizes ``self._part_sizes``.
    """
    gate_values = self._nonzero_gates
    return gate_values.split(self._part_sizes, dim=0)

VAE (LightningModule)

Variational Autoencoder for water quality parameter estimation.

This class implements a standard VAE architecture with configurable encoder and decoder networks for estimating water quality parameters from hyperspectral remote sensing reflectance data.

Parameters:

Name Type Description Default
input_dim int

Number of input features (spectral bands).

required
output_dim int

Number of output features (water quality parameters).

required
latent_dim int

Dimension of the latent space.

required
encoder_hidden_dims List[int]

List of hidden layer dimensions for the encoder.

required
decoder_hidden_dims List[int]

List of hidden layer dimensions for the decoder.

required
activation str

Activation function ('relu', 'tanh', 'sigmoid', 'leakyrelu').

'leakyrelu'
use_norm Union[str, bool]

Type of normalization ('batch', 'layer', 'group', or False).

False
use_dropout bool

Whether to use dropout regularization.

False
use_softplus_output bool

Whether to apply softplus activation to output.

False
**kwargs

Additional keyword arguments.

{}
Source code in hypercoast/emit_utils/MoE_VAE.py
class VAE(LightningModule):
    """Variational Autoencoder for water quality parameter estimation.

    This class implements a standard VAE architecture with configurable encoder
    and decoder networks for estimating water quality parameters from hyperspectral
    remote sensing reflectance data.

    Args:
        input_dim: Number of input features (spectral bands).
        output_dim: Number of output features (water quality parameters).
        latent_dim: Dimension of the latent space.
        encoder_hidden_dims: List of hidden layer dimensions for the encoder.
        decoder_hidden_dims: List of hidden layer dimensions for the decoder.
        activation: Activation function ('relu', 'tanh', 'sigmoid', 'leakyrelu').
        use_norm: Type of normalization ('batch', 'layer', 'group', or False).
        use_dropout: Whether to use dropout regularization.
        use_softplus_output: Whether to apply softplus activation to output.
        **kwargs: Additional keyword arguments (accepted and ignored here).

    Raises:
        ValueError: If ``activation`` is not one of the supported names.
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        latent_dim: int,
        encoder_hidden_dims: List[int],
        decoder_hidden_dims: List[int],
        activation: str = "leakyrelu",
        use_norm: Union[str, bool] = False,
        use_dropout: bool = False,
        use_softplus_output: bool = False,
        **kwargs,
    ):
        super().__init__()
        self.use_softplus_output = use_softplus_output

        # Resolve the activation name; the same module instance is reused in
        # every hidden layer built by ``build_layers``.
        if activation == "relu":
            self.activation = nn.ReLU()
        elif activation == "tanh":
            self.activation = nn.Tanh()
        elif activation == "sigmoid":
            self.activation = nn.Sigmoid()
        elif activation == "leakyrelu":
            self.activation = nn.LeakyReLU(0.2)
        else:
            raise ValueError(f"Unsupported activation function: {activation}")

        # Encoder: hidden stack followed by two linear heads (mean, log-variance).
        self.encoder_layers = self.build_layers(
            input_dim, encoder_hidden_dims, use_norm, use_dropout
        )
        self.fc_mu = nn.Linear(encoder_hidden_dims[-1], latent_dim)
        self.fc_log_var = nn.Linear(encoder_hidden_dims[-1], latent_dim)

        # Decoder: hidden stack, linear output layer, optional Softplus.
        self.decoder_layers = self.build_layers(
            latent_dim, decoder_hidden_dims, use_norm, use_dropout
        )
        self.decoder_layers.add_module(
            "output_layer", nn.Linear(decoder_hidden_dims[-1], output_dim)
        )
        if self.use_softplus_output:
            self.decoder_layers.add_module("output_activation", nn.Softplus())
        # NOTE: no bounded output activation (e.g. Tanh) is applied. Targets
        # lie in [-1, 1] only with the classic robust preprocessing method;
        # with other preprocessing they may not.

    def build_layers(
        self,
        input_dim: int,
        hidden_dims: List[int],
        use_norm: Union[str, bool],
        use_dropout: bool = False,
    ) -> nn.Sequential:
        """Build a sequential network with specified layers and normalization.

        Each hidden dimension contributes, in order: ``Linear``, an optional
        normalization layer, the shared activation, and optionally
        ``Dropout(0.1)``.

        Args:
            input_dim: Number of input features.
            hidden_dims: List of hidden layer dimensions.
            use_norm: Type of normalization ('batch', 'layer', 'group', or False).
                Any other value adds no normalization layer.
            use_dropout: Whether to add dropout layers.

        Returns:
            Sequential module containing the network layers.
        """
        layers = []
        current_size = input_dim
        for hidden_dim in hidden_dims:
            layers.append(nn.Linear(current_size, hidden_dim))
            if use_norm == "batch":
                layers.append(nn.BatchNorm1d(hidden_dim))
            elif use_norm == "layer":
                layers.append(nn.LayerNorm(hidden_dim))
            elif use_norm == "group":
                # Aim for about four channels per group. GroupNorm rejects a
                # channel count that is not a multiple of the group count, so
                # step down to the nearest divisor (1 always qualifies).
                num_groups = max(1, hidden_dim // 4)
                while hidden_dim % num_groups:
                    num_groups -= 1
                layers.append(
                    nn.GroupNorm(num_groups=num_groups, num_channels=hidden_dim)
                )
            layers.append(self.activation)
            if use_dropout:
                layers.append(nn.Dropout(0.1))
            current_size = hidden_dim
        return nn.Sequential(*layers)

    def encode(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encode input to latent distribution parameters.

        Args:
            x: Input tensor of shape (batch, input_dim).

        Returns:
            Tuple ``(mu, log_var)``: the mean and the log variance of the
            latent distribution.
        """
        x = self.encoder_layers(x)
        mu = self.fc_mu(x)
        log_var = self.fc_log_var(x)
        return mu, log_var

    def reparameterize(self, mu: torch.Tensor, log_var: torch.Tensor) -> torch.Tensor:
        """Apply reparameterization trick for sampling from latent distribution.

        Args:
            mu: Mean of latent distribution.
            log_var: Log variance of latent distribution.

        Returns:
            Sampled latent vector ``mu + eps * std`` with ``eps ~ N(0, I)``.
        """
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        z = mu + eps * std
        return z

    def decode(self, z: torch.Tensor) -> torch.Tensor:
        """Decode latent vector to output.

        Args:
            z: Latent vector.

        Returns:
            Decoded output tensor.
        """
        return self.decoder_layers(z)

    def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Forward pass through the VAE.

        Args:
            x: Input tensor of shape (batch, input_dim).

        Returns:
            Dictionary containing 'pred_y', 'mu', and 'log_var'.
        """
        mu, log_var = self.encode(x)
        z = self.reparameterize(mu, log_var)
        pred_y = self.decode(z)
        return {"pred_y": pred_y, "mu": mu, "log_var": log_var}

    def loss_fn(
        self, output_dict: Dict[str, torch.Tensor], kld_weight: float = 0.0
    ) -> Dict[str, torch.Tensor]:
        """Compute VAE loss including reconstruction and KL divergence terms.

        The optimized objective is ``MAE + kld_weight * KLD``; the MSE is
        computed and returned but is not part of 'total_loss'.

        Args:
            output_dict: Dictionary containing 'pred_y', 'y', 'mu', and 'log_var'.
            kld_weight: Weight for the KL divergence term.

        Returns:
            Dictionary containing 'total_loss', 'mae_loss', 'mse_loss', and 'kld_loss'.
        """
        pred_y, y, mu, log_var = (
            output_dict["pred_y"],
            output_dict["y"],
            output_dict["mu"],
            output_dict["log_var"],
        )
        batch_size = y.shape[0]
        MAE = F.l1_loss(pred_y, y, reduction="mean")
        MSE = F.mse_loss(pred_y, y, reduction="mean")
        # KL divergence summed over all elements, then averaged over the batch.
        KLD = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp()) / batch_size
        return {
            "total_loss": MAE + kld_weight * KLD,
            "mae_loss": MAE,
            "mse_loss": MSE,
            "kld_loss": KLD,
        }

build_layers(self, input_dim, hidden_dims, use_norm, use_dropout=False)

Build a sequential network with specified layers and normalization.

Parameters:

Name Type Description Default
input_dim int

Number of input features.

required
hidden_dims List[int]

List of hidden layer dimensions.

required
use_norm Union[str, bool]

Type of normalization ('batch', 'layer', 'group', or False).

required
use_dropout bool

Whether to add dropout layers.

False

Returns:

Type Description
Sequential

Sequential module containing the network layers.

Source code in hypercoast/emit_utils/MoE_VAE.py
def build_layers(
    self,
    input_dim: int,
    hidden_dims: List[int],
    use_norm: Union[str, bool],
    use_dropout: bool = False,
) -> nn.Sequential:
    """Build a sequential network with specified layers and normalization.

    Each hidden dimension contributes, in order: ``Linear``, an optional
    normalization layer, ``self.activation``, and optionally ``Dropout(0.1)``.

    Args:
        input_dim: Number of input features.
        hidden_dims: List of hidden layer dimensions.
        use_norm: Type of normalization ('batch', 'layer', 'group', or False).
            Any other value adds no normalization layer.
        use_dropout: Whether to add dropout layers.

    Returns:
        Sequential module containing the network layers.
    """
    layers = []
    current_size = input_dim
    for hidden_dim in hidden_dims:
        layers.append(nn.Linear(current_size, hidden_dim))
        if use_norm == "batch":
            layers.append(nn.BatchNorm1d(hidden_dim))
        elif use_norm == "layer":
            layers.append(nn.LayerNorm(hidden_dim))
        elif use_norm == "group":
            # Aim for about four channels per group. GroupNorm rejects a
            # channel count that is not a multiple of the group count, so
            # step down to the nearest divisor (1 always qualifies).
            num_groups = max(1, hidden_dim // 4)
            while hidden_dim % num_groups:
                num_groups -= 1
            layers.append(
                nn.GroupNorm(num_groups=num_groups, num_channels=hidden_dim)
            )
        layers.append(self.activation)
        if use_dropout:
            layers.append(nn.Dropout(0.1))
        current_size = hidden_dim
    return nn.Sequential(*layers)

decode(self, z)

Decode latent vector to output.

Parameters:

Name Type Description Default
z Tensor

Latent vector.

required

Returns:

Type Description
Tensor

Decoded output tensor.

Source code in hypercoast/emit_utils/MoE_VAE.py
def decode(self, z: torch.Tensor) -> torch.Tensor:
    """Run a latent vector through the decoder stack.

    Args:
        z: Latent vector.

    Returns:
        Output of ``self.decoder_layers`` applied to ``z``.
    """
    decoder = self.decoder_layers
    return decoder(z)

encode(self, x)

Encode input to latent distribution parameters.

Parameters:

Name Type Description Default
x Tensor

Input tensor of shape (batch, input_dim).

required

Returns:

Type Description
Tuple[torch.Tensor, torch.Tensor]

Tuple (mu, log_var): mean of the latent distribution and log variance of the latent distribution.

Source code in hypercoast/emit_utils/MoE_VAE.py
def encode(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Map an input batch to the parameters of its latent distribution.

    Args:
        x: Input tensor of shape (batch, input_dim).

    Returns:
        Tuple ``(mu, log_var)``: the mean and the log variance of the latent
        distribution, both computed from the same encoder features.
    """
    hidden = self.encoder_layers(x)
    return self.fc_mu(hidden), self.fc_log_var(hidden)

forward(self, x)

Forward pass through the VAE.

Parameters:

Name Type Description Default
x Tensor

Input tensor of shape (batch, input_dim).

required

Returns:

Type Description
Dict[str, torch.Tensor]

Dictionary containing 'pred_y', 'mu', and 'log_var'.

Source code in hypercoast/emit_utils/MoE_VAE.py
def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
    """Encode, sample a latent vector, and decode it.

    Args:
        x: Input tensor of shape (batch, input_dim).

    Returns:
        Dictionary with the decoded prediction under 'pred_y' and the latent
        distribution parameters under 'mu' and 'log_var'.
    """
    latent_mean, latent_log_var = self.encode(x)
    sample = self.reparameterize(latent_mean, latent_log_var)
    return {
        "pred_y": self.decode(sample),
        "mu": latent_mean,
        "log_var": latent_log_var,
    }

loss_fn(self, output_dict, kld_weight=0.0)

Compute VAE loss including reconstruction and KL divergence terms.

Parameters:

Name Type Description Default
output_dict Dict[str, torch.Tensor]

Dictionary containing 'pred_y', 'y', 'mu', and 'log_var'.

required
kld_weight float

Weight for the KL divergence term.

0.0

Returns:

Type Description
Dict[str, torch.Tensor]

Dictionary containing 'total_loss', 'mae_loss', 'mse_loss', and 'kld_loss'.

Source code in hypercoast/emit_utils/MoE_VAE.py
def loss_fn(
    self, output_dict: Dict[str, torch.Tensor], kld_weight: float = 0.0
) -> Dict[str, torch.Tensor]:
    """Compute the VAE training losses.

    The objective returned as 'total_loss' is ``MAE + kld_weight * KLD``.
    The MSE is reported alongside but does not enter the total.

    Args:
        output_dict: Dictionary containing 'pred_y', 'y', 'mu', and 'log_var'.
        kld_weight: Weight for the KL divergence term.

    Returns:
        Dictionary containing 'total_loss', 'mae_loss', 'mse_loss', and 'kld_loss'.
    """
    pred_y = output_dict["pred_y"]
    y = output_dict["y"]
    mu = output_dict["mu"]
    log_var = output_dict["log_var"]

    n_samples = y.shape[0]
    mae = F.l1_loss(pred_y, y, reduction="mean")
    mse = F.mse_loss(pred_y, y, reduction="mean")

    # KL term: summed over every element, then divided by the batch size.
    kl_terms = 1 + log_var - mu.pow(2) - log_var.exp()
    kld = -0.5 * torch.sum(kl_terms) / n_samples

    losses = {
        "total_loss": mae + kld_weight * kld,
        "mae_loss": mae,
        "mse_loss": mse,
        "kld_loss": kld,
    }
    return losses

reparameterize(self, mu, log_var)

Apply reparameterization trick for sampling from latent distribution.

Parameters:

Name Type Description Default
mu Tensor

Mean of latent distribution.

required
log_var Tensor

Log variance of latent distribution.

required

Returns:

Type Description
Tensor

z: Sampled latent vector.

Source code in hypercoast/emit_utils/MoE_VAE.py
def reparameterize(self, mu: torch.Tensor, log_var: torch.Tensor) -> torch.Tensor:
    """Draw a latent sample using the reparameterization trick.

    Args:
        mu: Mean of latent distribution.
        log_var: Log variance of latent distribution.

    Returns:
        Sampled latent vector ``mu + noise * sigma``, where ``sigma`` is
        ``exp(0.5 * log_var)`` and ``noise`` is standard normal.
    """
    sigma = (0.5 * log_var).exp()
    noise = torch.randn_like(sigma)
    return mu + noise * sigma

evaluate(model, test_dl, device, TSS_scalers_dict=None, log_offset=0.01)

Evaluate a VAE or MoE-VAE model on test data.

Parameters:

Name Type Description Default
model Union[hypercoast.emit_utils.MoE_VAE.VAE, hypercoast.emit_utils.MoE_VAE.MoE_VAE]

Trained VAE or MoE_VAE model.

required
test_dl DataLoader

DataLoader providing test batches.

required
device device

Device to run evaluation on (CPU or CUDA).

required
TSS_scalers_dict Optional[Dict[str, Any]]

Optional dictionary with 'log' and 'robust' scalers for inverse transformation. If None, uses simple log offset.

None
log_offset float

Offset for inverse log transformation if TSS_scalers_dict is None.

0.01

Returns:

Type Description
Tuple[np.ndarray, np.ndarray]

Tuple (predictions_inverse, actuals_inverse): predictions in original scale and ground truth values in original scale.

Source code in hypercoast/emit_utils/MoE_VAE.py
def evaluate(
    model: Union[VAE, MoE_VAE],
    test_dl: DataLoader,
    device: torch.device,
    TSS_scalers_dict: Optional[Dict[str, Any]] = None,
    log_offset: float = 0.01,
) -> Tuple[np.ndarray, np.ndarray]:
    """Run a VAE or MoE-VAE model over test data and undo the target scaling.

    The model is switched to eval mode and run without gradient tracking.
    Predictions and targets from all batches are stacked and mapped back to
    the original scale.

    Args:
        model: Trained VAE or MoE_VAE model.
        test_dl: DataLoader providing test batches of ``(x, y)``.
        device: Device to run evaluation on (CPU or CUDA).
        TSS_scalers_dict: Optional dictionary with 'log' and 'robust' scalers
            for inverse transformation. If None, uses simple log offset.
        log_offset: Offset for inverse log transformation if TSS_scalers_dict is None.

    Returns:
        Tuple ``(predictions_inverse, actuals_inverse)``: flattened predictions
        and ground truth values in original scale.
    """
    model.eval()

    pred_batches, true_batches = [], []
    with torch.no_grad():
        for features, targets in test_dl:
            features, targets = features.to(device), targets.to(device)
            batch_pred = model(features)["pred_y"]
            pred_batches.append(batch_pred.cpu().numpy())
            true_batches.append(targets.cpu().numpy())

    stacked_pred = np.vstack(pred_batches)
    stacked_true = np.vstack(true_batches)

    if TSS_scalers_dict is None:
        # No scalers supplied: invert with 10 ** value - log_offset.
        return (
            (10 ** stacked_pred.flatten()) - log_offset,
            (10 ** stacked_true.flatten()) - log_offset,
        )

    log_scaler = TSS_scalers_dict["log"]
    robust_scaler = TSS_scalers_dict["robust"]

    def _to_original_scale(values):
        # Undo the robust scaling first, then the log transform.
        unscaled = robust_scaler.inverse_transform(
            torch.tensor(values, dtype=torch.float32)
        )
        unlogged = log_scaler.inverse_transform(
            torch.tensor(unscaled, dtype=torch.float32)
        )
        return unlogged.numpy().flatten()

    predictions_inverse = _to_original_scale(stacked_pred)
    actuals_inverse = _to_original_scale(stacked_true)
    return predictions_inverse, actuals_inverse

train(model, train_dl, device, epochs=200, optimizer=None, save_dir=None)

Train a VAE or MoE-VAE model.

Parameters:

Name Type Description Default
model Union[hypercoast.emit_utils.MoE_VAE.VAE, hypercoast.emit_utils.MoE_VAE.MoE_VAE]

VAE or MoE_VAE model to train.

required
train_dl DataLoader

DataLoader providing training batches.

required
device device

Device to train on (CPU or CUDA).

required
epochs int

Number of training epochs.

200
optimizer Optional[torch.optim.optimizer.Optimizer]

PyTorch optimizer. If None, must be configured externally.

None
save_dir Optional[str]

Directory to save the best model checkpoint.

None

Returns:

Type Description
Dict[str, Any]

Dictionary containing:
  • 'total_loss': List of total losses per epoch.
  • 'l1_loss': List of L1 losses per epoch.
  • 'best_loss': Best total loss achieved.
Source code in hypercoast/emit_utils/MoE_VAE.py
def train(
    model: Union[VAE, MoE_VAE],
    train_dl: DataLoader,
    device: torch.device,
    epochs: int = 200,
    optimizer: Optional[torch.optim.Optimizer] = None,
    save_dir: Optional[str] = None,
) -> Dict[str, Any]:
    """Train a VAE or MoE-VAE model.

    Each epoch iterates over ``train_dl``, optimizes ``model.loss_fn``'s
    'total_loss', and records the per-epoch mean of the total and MAE losses.
    Whenever the mean total loss improves, the model's ``state_dict`` is
    written to ``<save_dir>/best_model_minloss.pth`` (if ``save_dir`` is set).

    Args:
        model: VAE or MoE_VAE model to train. It is not moved to ``device``
            here; only the batches are.
        train_dl: DataLoader providing training batches of ``(x, y)``.
        device: Device to train on (CPU or CUDA).
        epochs: Number of training epochs.
        optimizer: PyTorch optimizer, already bound to the model's parameters.
            Required despite the ``None`` default.
        save_dir: Directory to save the best model checkpoint. It is created
            if missing. If None, no checkpoint is written.

    Returns:
        Dictionary containing:
            - 'total_loss': List of total losses per epoch.
            - 'l1_loss': List of L1 losses per epoch.
            - 'best_loss': Best total loss achieved.

    Raises:
        ValueError: If ``optimizer`` is None.
    """
    if optimizer is None:
        # Fail before any work is done instead of crashing with an
        # AttributeError on the first batch.
        raise ValueError(
            "train() requires an optimizer; pass e.g. "
            "torch.optim.Adam(model.parameters())."
        )

    # Checkpointing is optional: the signature defaults save_dir to None, so
    # only build (and create) the target path when a directory was given.
    best_model_path = None
    if save_dir is not None:
        os.makedirs(save_dir, exist_ok=True)
        best_model_path = os.path.join(save_dir, "best_model_minloss.pth")

    model.train()
    min_total_loss = float("inf")

    total_list = []
    l1_list = []

    for epoch in range(epochs):
        total_loss_epoch = 0.0
        l1_epoch = 0.0

        for x, y in train_dl:
            x, y = x.to(device), y.to(device)

            output_dict = model(x)
            # loss_fn reads the targets from the same dictionary.
            output_dict["y"] = y

            loss_dict = model.loss_fn(output_dict)

            loss = loss_dict["total_loss"]
            l1 = loss_dict["mae_loss"]

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss_epoch += loss.item()
            l1_epoch += l1.item()

        avg_total_loss = total_loss_epoch / len(train_dl)
        avg_l1 = l1_epoch / len(train_dl)

        print(f"[Epoch {epoch+1}] Total: {avg_total_loss:.4f} | L1: {avg_l1:.4f}")
        total_list.append(avg_total_loss)
        l1_list.append(avg_l1)

        if avg_total_loss < min_total_loss:
            min_total_loss = avg_total_loss
            if best_model_path is not None:
                torch.save(model.state_dict(), best_model_path)

    return {"total_loss": total_list, "l1_loss": l1_list, "best_loss": min_total_loss}