model module¶
Variational Autoencoder and Mixture of Experts models.
This module implements VAE and MoE-VAE architectures for remote sensing data analysis, including sparse gating mechanisms and training utilities.
MoE_VAE (LightningModule)
¶
Sparsely gated mixture-of-experts model whose experts are VAE networks: a noisy top-k gate selects k experts for each batch element and their predictions are combined.
Source code in hypercoast/moe_vae/model.py
class MoE_VAE(LightningModule):
    """Sparsely gated mixture of experts whose experts are VAE networks.

    Each sample is routed to ``k`` of the ``num_experts`` experts by a noisy
    top-k gate; the experts' predictions are then combined by a
    ``SparseDispatcher`` built from the gate values.

    Args:
        input_dim (int): Size of the input.
        output_dim (int): Size of the output.
        latent_dim (int): Latent dimension handed to every expert VAE.
        encoder_hidden_dims (list): Encoder hidden sizes handed to every expert.
        decoder_hidden_dims (list): Decoder hidden sizes handed to every expert.
        num_experts (int): Number of experts.
        k (int, optional): How many experts to use for each batch element.
        activation (str, optional): Activation function name.
        noisy_gating (bool, optional): Whether to add noise to the gate logits
            while training.
        use_norm (str or bool, optional): Normalization option forwarded to
            the experts.
        use_dropout (bool, optional): Dropout option forwarded to the experts.
        use_softplus_output (bool, optional): Softplus-output option forwarded
            to the experts.
    """

    def __init__(
        self,
        input_dim,
        output_dim,
        latent_dim,
        encoder_hidden_dims,
        decoder_hidden_dims,
        num_experts,
        k=4,
        activation="leakyrelu",
        noisy_gating=True,
        use_norm=False,
        use_dropout=False,
        use_softplus_output=False,
        **kwargs,
    ):
        """
        Initialize the MoE-VAE model.

        Args:
            input_dim (int): Dimension of input data.
            output_dim (int): Dimension of output/reconstructed data.
            latent_dim (int): Dimension of latent space.
            encoder_hidden_dims (list): List of hidden layer dimensions for encoder.
            decoder_hidden_dims (list): List of hidden layer dimensions for decoder.
            num_experts (int): Number of experts.
            k (int, optional): Number of experts to use for each batch element.
            activation (str, optional): Activation function name.
            noisy_gating (bool, optional): Whether to use noisy gating.
            use_norm (str or bool, optional): Normalization type. Can be 'batch',
                'layer', or False. Defaults to False.
            use_dropout (bool, optional): Whether to use dropout. Defaults to False.
            use_softplus_output (bool, optional): Whether to apply softplus to output.
                Defaults to False.
            **kwargs: Additional keyword arguments (ignored by this constructor).

        Raises:
            AssertionError: If ``k`` is larger than ``num_experts``.
        """
        super(MoE_VAE, self).__init__()
        self.noisy_gating = noisy_gating
        self.num_experts = num_experts
        self.output_dim = output_dim
        self.input_dim = input_dim
        self.latent_dim = latent_dim
        self.encoder_hidden_dims = encoder_hidden_dims
        self.decoder_hidden_dims = decoder_hidden_dims
        self.k = k
        self.activation = activation
        self.use_norm = use_norm
        self.use_dropout = use_dropout
        self.use_softplus_output = use_softplus_output

        # Fail before building the experts; the exception type is kept as
        # AssertionError so existing callers see the same error.
        assert self.k <= self.num_experts, "k must not exceed num_experts"

        # Instantiate the experts: every expert is a VAE with identical settings.
        self.experts = nn.ModuleList(
            [
                VAE(
                    self.input_dim,
                    self.output_dim,
                    self.latent_dim,
                    self.encoder_hidden_dims,
                    self.decoder_hidden_dims,
                    self.activation,
                    use_norm=self.use_norm,
                    use_dropout=self.use_dropout,
                    use_softplus_output=self.use_softplus_output,
                )
                for _ in range(self.num_experts)
            ]
        )

        # Gate and noise projections both start at zero.
        self.w_gate = nn.Parameter(
            torch.zeros(input_dim, num_experts, dtype=self.dtype), requires_grad=True
        )
        self.w_noise = nn.Parameter(
            torch.zeros(input_dim, num_experts, dtype=self.dtype), requires_grad=True
        )

        self.softplus = nn.Softplus()
        self.softmax = nn.Softmax(1)
        # Parameters of the standard normal used by _prob_in_top_k.
        self.register_buffer("mean", torch.tensor([0.0]))
        self.register_buffer("std", torch.tensor([1.0]))
        self.batch_gates = None

    def forward(self, x, moe_weight=1e-2):
        """
        Forward pass of the MoE-VAE model.

        Args:
            x (torch.Tensor): Input tensor of shape [batch_size, input_dim].
            moe_weight (float, optional): Multiplier for load-balancing loss.
                Defaults to 1e-2.

        Returns:
            dict: Dictionary containing:
                - 'pred_y': Predicted output tensor of shape [batch_size, output_dim].
                - 'moe_loss': Load-balancing loss.
        """
        gates, load = self.noisy_top_k_gating(x, self.training)
        self.batch_gates = gates

        # Load-balancing loss: penalise uneven importance and uneven load.
        importance = gates.sum(0)
        moe_loss = moe_weight * self.cv_squared(
            importance
        ) + moe_weight * self.cv_squared(load)

        dispatcher = SparseDispatcher(self.num_experts, gates)
        expert_inputs = dispatcher.dispatch(x)
        # The per-expert gate values are not used below; the call is kept as
        # in the original implementation.
        gates = dispatcher.expert_to_gates()

        expert_outputs = []
        for i in range(self.num_experts):
            input_i = expert_inputs[i]
            if input_i.shape[0] > 1:
                expert_outputs.append(self.experts[i](input_i)["pred_y"])
            else:
                # Experts that receive fewer than two samples are not run and
                # contribute zeros (presumably to avoid single-sample batches
                # in normalization layers -- TODO confirm). The placeholder
                # follows the input's dtype so it matches non-float32 inputs.
                expert_outputs.append(
                    torch.zeros(
                        (input_i.shape[0], self.output_dim),
                        device=input_i.device,
                        dtype=input_i.dtype,
                    )
                )
        pred_y = dispatcher.combine(expert_outputs)
        return {"pred_y": pred_y, "moe_loss": moe_loss}

    def loss_fn(self, output_dict) -> dict:
        """
        Compute the training losses from a model output dictionary.

        Args:
            output_dict (dict): Must contain 'pred_y' (predictions) and 'y'
                (targets). May contain 'moe_loss'; a zero tensor is used when
                it is absent.

        Returns:
            dict: Dictionary containing:
                - 'total_loss': MAE plus the MoE load-balancing loss.
                - 'mae_loss': Mean absolute error between 'pred_y' and 'y'.
                - 'mse_loss': Mean squared error (reported, not part of the total).
                - 'moe_loss': The load-balancing loss that was used.
        """
        pred_y = output_dict["pred_y"]
        y = output_dict["y"]

        mae_loss = F.l1_loss(pred_y, y, reduction="mean")
        mse_loss = F.mse_loss(pred_y, y, reduction="mean")
        moe_loss = output_dict.get(
            "moe_loss", torch.tensor(0.0, device=pred_y.device, dtype=pred_y.dtype)
        )
        total_loss = mae_loss + moe_loss
        return {
            "total_loss": total_loss,
            "mae_loss": mae_loss,
            "mse_loss": mse_loss,
            "moe_loss": moe_loss,
        }

    def get_batch_gates(self):
        """Get the gating weights from the last forward pass.

        Returns:
            torch.Tensor: Gating weights of shape [batch_size, num_experts],
                or None if ``forward`` has not been called yet.
        """
        return self.batch_gates

    def cv_squared(self, x):
        """Compute squared coefficient of variation for load balancing.

        Calculates the squared coefficient of variation (variance/mean²) which
        serves as a loss term to encourage uniform distribution across experts.

        Args:
            x (torch.Tensor): Input tensor (typically expert loads or importance).

        Returns:
            torch.Tensor: Scalar float tensor representing squared coefficient
                of variation. Returns 0 for single-element tensors.
        """
        eps = 1e-10
        values = x.float()
        # With a single expert the variance is undefined; return a scalar zero
        # of the same shape and dtype as the general branch.
        if values.shape[0] == 1:
            return torch.zeros((), device=x.device, dtype=values.dtype)
        return values.var() / (values.mean() ** 2 + eps)

    def _gates_to_load(self, gates):
        """Convert gate weights to expert load counts.

        Computes the number of examples assigned to each expert (with gate > 0).

        Args:
            gates (torch.Tensor): Gate weights of shape [batch_size, num_experts].

        Returns:
            torch.Tensor: Load count per expert of shape [num_experts].
        """
        return (gates > 0).sum(0)

    def _prob_in_top_k(
        self, clean_values, noisy_values, noise_stddev, noisy_top_values
    ):
        """Compute probability of expert being in top-k selection.

        Helper function for noisy top-k gating that computes the probability
        each expert would be selected given different noise realizations.
        This enables differentiable load balancing.

        Args:
            clean_values (torch.Tensor): Clean logits of shape [batch, num_experts].
            noisy_values (torch.Tensor): Noisy logits of shape [batch, num_experts].
            noise_stddev (torch.Tensor): Noise standard deviation of same shape.
            noisy_top_values (torch.Tensor): Top-k+1 noisy values for thresholding.

        Returns:
            torch.Tensor: Probability of each expert being in top-k,
                shape [batch, num_experts].
        """
        batch = clean_values.size(0)
        m = noisy_top_values.size(1)
        top_values_flat = noisy_top_values.flatten()

        # Flat position of each row's (k+1)-th largest noisy value.
        threshold_positions_if_in = (
            torch.arange(batch, device=clean_values.device) * m + self.k
        )
        threshold_if_in = torch.unsqueeze(
            torch.gather(top_values_flat, 0, threshold_positions_if_in), 1
        )
        # Is each value currently in the top k?
        is_in = torch.gt(noisy_values, threshold_if_in)

        # Flat position of each row's k-th largest noisy value.
        threshold_positions_if_out = threshold_positions_if_in - 1
        threshold_if_out = torch.unsqueeze(
            torch.gather(top_values_flat, 0, threshold_positions_if_out), 1
        )

        normal = Normal(self.mean, self.std)
        prob_if_in = normal.cdf((clean_values - threshold_if_in) / noise_stddev)
        prob_if_out = normal.cdf((clean_values - threshold_if_out) / noise_stddev)
        prob = torch.where(is_in, prob_if_in, prob_if_out)
        return prob

    def noisy_top_k_gating(self, x, train, noise_epsilon=1e-2):
        """Noisy top-k gating mechanism for expert selection.

        Implements the noisy top-k gating from "Outrageously Large Neural Networks"
        (https://arxiv.org/abs/1701.06538). Adds controlled noise during training
        to improve load balancing across experts.

        Args:
            x (torch.Tensor): Input features of shape [batch_size, input_dim].
            train (bool): Whether model is in training mode (adds noise if True).
            noise_epsilon (float, optional): Minimum noise standard deviation.
                Defaults to 1e-2.

        Returns:
            tuple: A tuple containing:
                - gates (torch.Tensor): Sparse gate weights [batch_size, num_experts]
                - load (torch.Tensor): Expert load for balancing [num_experts]
        """
        clean_logits = x @ self.w_gate
        if self.noisy_gating and train:
            raw_noise_stddev = x @ self.w_noise
            noise_stddev = self.softplus(raw_noise_stddev) + noise_epsilon
            noisy_logits = clean_logits + (
                torch.randn_like(clean_logits) * noise_stddev
            )
            logits = noisy_logits
        else:
            logits = clean_logits

        # Safety check kept from the original: shift the logits when any row
        # sums to exactly zero.
        if (logits.sum(dim=1) == 0).any():
            logits = logits + 1e-5

        # Take top k + 1 values; the extra one is needed by _prob_in_top_k.
        top_logits, top_indices = logits.topk(min(self.k + 1, self.num_experts), dim=1)
        top_k_logits = top_logits[:, : self.k]
        top_k_indices = top_indices[:, : self.k]
        top_k_gates = self.softmax(top_k_logits)

        zeros = torch.zeros_like(logits, requires_grad=True, dtype=self.dtype)
        gates = zeros.scatter(1, top_k_indices, top_k_gates)

        # Safety check - ensure at least one expert is selected per sample.
        if (gates.sum(dim=1) < 1e-6).any():
            problematic_samples = (gates.sum(dim=1) < 1e-6).nonzero().squeeze(1)
            if problematic_samples.numel() > 0:
                # Force the top expert for these samples with a minimum gate.
                top_expert = top_indices[problematic_samples, 0]
                gates[problematic_samples, top_expert] = 0.1

        if self.noisy_gating and self.k < self.num_experts and train:
            load = (
                self._prob_in_top_k(
                    clean_logits, noisy_logits, noise_stddev, top_logits
                )
            ).sum(0)
        else:
            load = self._gates_to_load(gates)
        return gates, load
__init__(self, input_dim, output_dim, latent_dim, encoder_hidden_dims, decoder_hidden_dims, num_experts, k=4, activation='leakyrelu', noisy_gating=True, use_norm=False, use_dropout=False, use_softplus_output=False, **kwargs)
special
¶
Initialize the MoE-VAE model.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
input_dim |
int |
Dimension of input data. |
required |
output_dim |
int |
Dimension of output/reconstructed data. |
required |
latent_dim |
int |
Dimension of latent space. |
required |
encoder_hidden_dims |
list |
List of hidden layer dimensions for encoder. |
required |
decoder_hidden_dims |
list |
List of hidden layer dimensions for decoder. |
required |
num_experts |
int |
Number of experts. |
required |
k |
int |
Number of experts to use for each batch element. |
4 |
activation |
str |
Activation function name. |
'leakyrelu' |
noisy_gating |
bool |
Whether to use noisy gating. |
True |
use_norm |
str or bool |
Normalization type. Can be 'batch', 'layer', or False. Defaults to False. |
False |
use_dropout |
bool |
Whether to use dropout. Defaults to False. |
False |
use_softplus_output |
bool |
Whether to apply softplus to output. Defaults to False. |
False |
**kwargs |
Additional keyword arguments. |
{} |
Source code in hypercoast/moe_vae/model.py
def __init__(
    self,
    input_dim,
    output_dim,
    latent_dim,
    encoder_hidden_dims,
    decoder_hidden_dims,
    num_experts,
    k=4,
    activation="leakyrelu",
    noisy_gating=True,
    use_norm=False,
    use_dropout=False,
    use_softplus_output=False,
    **kwargs,
):
    """
    Initialize the MoE-VAE model.

    Args:
        input_dim (int): Dimension of input data.
        output_dim (int): Dimension of output/reconstructed data.
        latent_dim (int): Dimension of latent space.
        encoder_hidden_dims (list): List of hidden layer dimensions for encoder.
        decoder_hidden_dims (list): List of hidden layer dimensions for decoder.
        num_experts (int): Number of experts.
        k (int, optional): Number of experts to use for each batch element.
        activation (str, optional): Activation function name.
        noisy_gating (bool, optional): Whether to use noisy gating.
        use_norm (str or bool, optional): Normalization type. Can be 'batch',
            'layer', or False. Defaults to False.
        use_dropout (bool, optional): Whether to use dropout. Defaults to False.
        use_softplus_output (bool, optional): Whether to apply softplus to output.
            Defaults to False.
        **kwargs: Additional keyword arguments (ignored by this constructor).

    Raises:
        AssertionError: If ``k`` is larger than ``num_experts``.
    """
    super(MoE_VAE, self).__init__()
    self.noisy_gating = noisy_gating
    self.num_experts = num_experts
    self.output_dim = output_dim
    self.input_dim = input_dim
    self.latent_dim = latent_dim
    self.encoder_hidden_dims = encoder_hidden_dims
    self.decoder_hidden_dims = decoder_hidden_dims
    self.k = k
    self.activation = activation
    self.use_norm = use_norm
    self.use_dropout = use_dropout
    self.use_softplus_output = use_softplus_output

    # Fail before building the experts; the exception type is kept as
    # AssertionError so existing callers see the same error.
    assert self.k <= self.num_experts, "k must not exceed num_experts"

    # Instantiate the experts: every expert is a VAE with identical settings.
    self.experts = nn.ModuleList(
        [
            VAE(
                self.input_dim,
                self.output_dim,
                self.latent_dim,
                self.encoder_hidden_dims,
                self.decoder_hidden_dims,
                self.activation,
                use_norm=self.use_norm,
                use_dropout=self.use_dropout,
                use_softplus_output=self.use_softplus_output,
            )
            for _ in range(self.num_experts)
        ]
    )

    # Gate and noise projections both start at zero.
    self.w_gate = nn.Parameter(
        torch.zeros(input_dim, num_experts, dtype=self.dtype), requires_grad=True
    )
    self.w_noise = nn.Parameter(
        torch.zeros(input_dim, num_experts, dtype=self.dtype), requires_grad=True
    )

    self.softplus = nn.Softplus()
    self.softmax = nn.Softmax(1)
    # Parameters of the standard normal used by _prob_in_top_k.
    self.register_buffer("mean", torch.tensor([0.0]))
    self.register_buffer("std", torch.tensor([1.0]))
    self.batch_gates = None
cv_squared(self, x)
¶
Compute squared coefficient of variation for load balancing.
Calculates the squared coefficient of variation (variance/mean²) which serves as a loss term to encourage uniform distribution across experts.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
x |
torch.Tensor |
Input tensor (typically expert loads or importance). |
required |
Returns:
Type | Description |
---|---|
torch.Tensor |
Scalar tensor representing squared coefficient of variation. Returns 0 for single-element tensors. |
Source code in hypercoast/moe_vae/model.py
def cv_squared(self, x):
    """Compute squared coefficient of variation for load balancing.

    Calculates the squared coefficient of variation (variance/mean²) which
    serves as a loss term to encourage uniform distribution across experts.

    Args:
        x (torch.Tensor): Input tensor (typically expert loads or importance).

    Returns:
        torch.Tensor: Scalar float tensor representing squared coefficient
            of variation. Returns 0 for single-element tensors.
    """
    eps = 1e-10
    values = x.float()
    # With a single expert the variance is undefined; return a scalar zero
    # of the same shape and dtype as the general branch.
    if values.shape[0] == 1:
        return torch.zeros((), device=x.device, dtype=values.dtype)
    return values.var() / (values.mean() ** 2 + eps)
forward(self, x, moe_weight=0.01)
¶
Forward pass of the MoE-VAE model.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
x |
torch.Tensor |
Input tensor of shape [batch_size, input_dim]. |
required |
moe_weight |
float |
Multiplier for load-balancing loss. Defaults to 1e-2. |
0.01 |
Returns:
Type | Description |
---|---|
dict |
Dictionary containing: - 'pred_y': Predicted output tensor of shape [batch_size, output_dim]. - 'moe_loss': Load-balancing loss. |
Source code in hypercoast/moe_vae/model.py
def forward(self, x, moe_weight=1e-2):
    """
    Forward pass of the MoE-VAE model.

    Args:
        x (torch.Tensor): Input tensor of shape [batch_size, input_dim].
        moe_weight (float, optional): Multiplier for load-balancing loss.
            Defaults to 1e-2.

    Returns:
        dict: Dictionary containing:
            - 'pred_y': Predicted output tensor of shape [batch_size, output_dim].
            - 'moe_loss': Load-balancing loss.
    """
    gates, load = self.noisy_top_k_gating(x, self.training)
    self.batch_gates = gates

    # Load-balancing loss: penalise uneven importance and uneven load.
    importance = gates.sum(0)
    moe_loss = moe_weight * self.cv_squared(
        importance
    ) + moe_weight * self.cv_squared(load)

    dispatcher = SparseDispatcher(self.num_experts, gates)
    expert_inputs = dispatcher.dispatch(x)
    # The per-expert gate values are not used below; the call is kept as in
    # the original implementation.
    gates = dispatcher.expert_to_gates()

    expert_outputs = []
    for i in range(self.num_experts):
        input_i = expert_inputs[i]
        if input_i.shape[0] > 1:
            expert_outputs.append(self.experts[i](input_i)["pred_y"])
        else:
            # Experts that receive fewer than two samples are not run and
            # contribute zeros (presumably to avoid single-sample batches in
            # normalization layers -- TODO confirm). The placeholder follows
            # the input's dtype so it matches non-float32 inputs.
            expert_outputs.append(
                torch.zeros(
                    (input_i.shape[0], self.output_dim),
                    device=input_i.device,
                    dtype=input_i.dtype,
                )
            )
    pred_y = dispatcher.combine(expert_outputs)
    return {"pred_y": pred_y, "moe_loss": moe_loss}
get_batch_gates(self)
¶
Get the gating weights from the last forward pass.
Returns:
Type | Description |
---|---|
torch.Tensor |
Gating weights of shape [batch_size, num_experts]. |
Source code in hypercoast/moe_vae/model.py
def get_batch_gates(self):
    """Return the gate tensor stored by the most recent forward pass.

    Returns:
        torch.Tensor: Gating weights of shape [batch_size, num_experts], as
            stored in ``self.batch_gates``.
    """
    latest_gates = self.batch_gates
    return latest_gates
loss_fn(self, output_dict)
¶
Compute loss between model output and target.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
output_dict |
dict |
Dictionary with 'pred_y' (predictions), 'y' (targets) and, optionally, 'moe_loss' (load-balancing loss; treated as 0 when absent). |
required |
Returns:
Type | Description |
---|---|
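dict |
Dictionary containing: - 'total_loss': MAE plus 'moe_loss' - 'mae_loss': Mean Absolute Error - 'mse_loss': Mean Squared Error - 'moe_loss': Load-balancing loss |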
Source code in hypercoast/moe_vae/model.py
def loss_fn(self, output_dict) -> dict:
    """
    Compute the training losses from a model output dictionary.

    Args:
        output_dict (dict): Must contain 'pred_y' (predictions) and 'y'
            (targets). May contain 'moe_loss'; a zero tensor is used when
            it is absent.

    Returns:
        dict: Dictionary containing:
            - 'total_loss': MAE plus the MoE load-balancing loss.
            - 'mae_loss': Mean absolute error between 'pred_y' and 'y'.
            - 'mse_loss': Mean squared error (reported, not part of the total).
            - 'moe_loss': The load-balancing loss that was used.
    """
    pred_y = output_dict["pred_y"]
    y = output_dict["y"]

    mae_loss = F.l1_loss(pred_y, y, reduction="mean")
    mse_loss = F.mse_loss(pred_y, y, reduction="mean")
    moe_loss = output_dict.get(
        "moe_loss", torch.tensor(0.0, device=pred_y.device, dtype=pred_y.dtype)
    )
    total_loss = mae_loss + moe_loss
    return {
        "total_loss": total_loss,
        "mae_loss": mae_loss,
        "mse_loss": mse_loss,
        "moe_loss": moe_loss,
    }
noisy_top_k_gating(self, x, train, noise_epsilon=0.01)
¶
Noisy top-k gating mechanism for expert selection.
Implements the noisy top-k gating from "Outrageously Large Neural Networks" (https://arxiv.org/abs/1701.06538). Adds controlled noise during training to improve load balancing across experts.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
x |
torch.Tensor |
Input features of shape [batch_size, input_dim]. |
required |
train |
bool |
Whether model is in training mode (adds noise if True). |
required |
noise_epsilon |
float |
Minimum noise standard deviation. Defaults to 1e-2. |
0.01 |
Returns:
Type | Description |
---|---|
tuple |
A tuple containing: - gates (torch.Tensor): Sparse gate weights [batch_size, num_experts] - load (torch.Tensor): Expert load for balancing [num_experts] |
Source code in hypercoast/moe_vae/model.py
def noisy_top_k_gating(self, x, train, noise_epsilon=1e-2):
    """Select the top-k experts per sample, optionally with training noise.

    Follows the noisy top-k gating of "Outrageously Large Neural Networks"
    (https://arxiv.org/abs/1701.06538): while training with noisy gating
    enabled, Gaussian noise scaled by a learned standard deviation is added
    to the gate logits before the top-k selection.

    Args:
        x (torch.Tensor): Input features of shape [batch_size, input_dim].
        train (bool): Whether model is in training mode (adds noise if True).
        noise_epsilon (float, optional): Minimum noise standard deviation.
            Defaults to 1e-2.

    Returns:
        tuple: A tuple containing:
            - gates (torch.Tensor): Sparse gate weights [batch_size, num_experts]
            - load (torch.Tensor): Expert load for balancing [num_experts]
    """
    clean_logits = x @ self.w_gate
    add_noise = self.noisy_gating and train

    if add_noise:
        noise_stddev = self.softplus(x @ self.w_noise) + noise_epsilon
        noisy_logits = clean_logits + torch.randn_like(clean_logits) * noise_stddev
        logits = noisy_logits
    else:
        logits = clean_logits

    # Safety shift applied when any row of logits sums to exactly zero.
    if (logits.sum(dim=1) == 0).any():
        logits = logits + 1e-5

    # One extra value beyond k is kept because the load estimate needs it.
    n_selected = min(self.k + 1, self.num_experts)
    top_logits, top_indices = logits.topk(n_selected, dim=1)
    top_k_indices = top_indices[:, : self.k]
    top_k_gates = self.softmax(top_logits[:, : self.k])

    blank = torch.zeros_like(logits, requires_grad=True, dtype=self.dtype)
    gates = blank.scatter(1, top_k_indices, top_k_gates)

    # Rows whose gates sum to (almost) nothing get their top expert forced
    # on with a fixed minimum weight.
    empty_rows = gates.sum(dim=1) < 1e-6
    if empty_rows.any():
        row_ids = empty_rows.nonzero().squeeze(1)
        gates[row_ids, top_indices[row_ids, 0]] = 0.1

    if add_noise and self.k < self.num_experts:
        in_top_k_prob = self._prob_in_top_k(
            clean_logits, noisy_logits, noise_stddev, top_logits
        )
        load = in_top_k_prob.sum(0)
    else:
        load = self._gates_to_load(gates)
    return gates, load
MoE_VAE_Token (LightningModule)
¶
Token-wise Mixture of Experts VAE for spectral data analysis.
This variant of MoE-VAE divides the input spectral bands among different experts, with each expert processing a specific spectral segment. This is particularly useful for hyperspectral data where different spectral regions may have distinct characteristics.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
input_dim |
int |
Total dimension of input spectral data. |
required |
output_dim |
int |
Dimension of output data. |
required |
latent_dim |
int |
Dimension of latent space for each VAE expert. |
required |
encoder_hidden_dims |
list |
Hidden layer dimensions for encoder networks. |
required |
decoder_hidden_dims |
list |
Hidden layer dimensions for decoder networks. |
required |
num_experts |
int |
Number of expert VAE models (spectral segments). |
required |
k |
int |
Kept for compatibility, unused in token-wise mode. Defaults to 4. |
4 |
activation |
str |
Activation function name. Defaults to 'leakyrelu'. |
'leakyrelu' |
noisy_gating |
bool |
Kept for compatibility, unused in token-wise mode. Defaults to True. |
True |
use_norm |
str or bool |
Normalization type. Defaults to False. |
False |
use_dropout |
bool |
Whether to use dropout. Defaults to False. |
False |
use_softplus_output |
bool |
Whether to apply softplus to output. Defaults to False. |
False |
**kwargs |
Additional keyword arguments. |
{} |
Source code in hypercoast/moe_vae/model.py
class MoE_VAE_Token(LightningModule):
    """Token-wise Mixture of Experts VAE for spectral data analysis.

    This variant of MoE-VAE divides the input spectral bands among different
    experts, with each expert processing a specific spectral segment. This is
    particularly useful for hyperspectral data where different spectral regions
    may have distinct characteristics.

    Args:
        input_dim (int): Total dimension of input spectral data.
        output_dim (int): Dimension of output data.
        latent_dim (int): Dimension of latent space for each VAE expert.
        encoder_hidden_dims (list): Hidden layer dimensions for encoder networks.
        decoder_hidden_dims (list): Hidden layer dimensions for decoder networks.
        num_experts (int): Number of expert VAE models (spectral segments).
        k (int, optional): Kept for compatibility, unused in token-wise mode.
            Defaults to 4.
        activation (str, optional): Activation function name. Defaults to 'leakyrelu'.
        noisy_gating (bool, optional): Kept for compatibility, unused in token-wise mode.
            Defaults to True.
        use_norm (str or bool, optional): Normalization type. Defaults to False.
        use_dropout (bool, optional): Whether to use dropout. Defaults to False.
        use_softplus_output (bool, optional): Whether to apply softplus to output.
            Defaults to False.
        **kwargs: Additional keyword arguments.
    """

    def __init__(
        self,
        input_dim,
        output_dim,
        latent_dim,
        encoder_hidden_dims,
        decoder_hidden_dims,
        num_experts,
        k=4,
        activation="leakyrelu",
        noisy_gating=True,
        use_norm=False,
        use_dropout=False,
        use_softplus_output=False,
        **kwargs,
    ):
        """Build one VAE expert per spectral segment of the input."""
        super(MoE_VAE_Token, self).__init__()
        self.num_experts = num_experts
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.latent_dim = latent_dim
        self.encoder_hidden_dims = encoder_hidden_dims
        self.decoder_hidden_dims = decoder_hidden_dims
        self.activation = activation
        self.use_norm = use_norm
        self.use_dropout = use_dropout
        self.use_softplus_output = use_softplus_output

        # Segment widths: the first num_experts - 1 experts get an equal
        # share, and the last expert takes whatever remains.
        self.sub_input_dims = [input_dim // num_experts] * (num_experts - 1)
        self.sub_input_dims.append(input_dim - sum(self.sub_input_dims))

        # Instantiate the experts: each reconstructs its own segment, so its
        # input and output widths are both the segment width.
        self.experts = nn.ModuleList(
            [
                VAE(
                    sub_dim,
                    sub_dim,
                    self.latent_dim,
                    self.encoder_hidden_dims,
                    self.decoder_hidden_dims,
                    self.activation,
                    use_norm=self.use_norm,
                    use_dropout=self.use_dropout,
                    use_softplus_output=self.use_softplus_output,
                )
                for sub_dim in self.sub_input_dims
            ]
        )

        # The gating members below are not read by forward() in this class;
        # they are left in place so the module's parameters and buffers stay
        # the same as before.
        self.w_gate = nn.Parameter(
            torch.zeros(input_dim, num_experts, dtype=self.dtype), requires_grad=True
        )
        self.w_noise = nn.Parameter(
            torch.zeros(input_dim, num_experts, dtype=self.dtype), requires_grad=True
        )
        self.softplus = nn.Softplus()
        self.softmax = nn.Softmax(1)
        self.register_buffer("mean", torch.tensor([0.0]))
        self.register_buffer("std", torch.tensor([1.0]))
        self.batch_gates = None

        # k is stored for compatibility only. It is not validated against
        # num_experts because nothing in token-wise mode uses it; the previous
        # check rejected e.g. num_experts=2 with the default k=4.
        self.k = k

    def forward(self, x, moe_weight=0.0):
        """
        Token-wise MoE forward pass:
        Each expert processes a different spectral segment of the input.

        Args:
            x: Tensor of shape [batch_size, input_dim]
            moe_weight: kept for compatibility, but unused in token-wise mode.

        Returns:
            A dict with:
                'pred_y': reconstructed tensor of shape [batch_size, input_dim]
                'moe_loss': dummy 0.0 (no gating loss in token-wise)
        """
        # Split the input into band segments for each expert.
        x_chunks = torch.split(x, self.sub_input_dims, dim=1)
        expert_outputs = [
            self.experts[i](x_chunks[i])["pred_y"] for i in range(self.num_experts)
        ]
        pred_y = torch.cat(expert_outputs, dim=1)
        return {
            "pred_y": pred_y,
            "moe_loss": torch.tensor(0.0, device=x.device, dtype=x.dtype),
        }

    def loss_fn(self, output_dict):
        """Compute loss for token-wise MoE-VAE model.

        Computes reconstruction loss without MoE-specific penalties since
        no gating mechanism is used in the token-wise approach.

        Args:
            output_dict (dict): Dictionary containing model outputs and targets:
                - 'pred_y': Model predictions
                - 'y': Target values
                - 'moe_loss': Optional; a zero tensor is used when absent

        Returns:
            dict: Dictionary containing loss components:
                - 'total_loss': MAE loss plus 'moe_loss'
                - 'mae_loss': Mean Absolute Error
                - 'mse_loss': Mean Squared Error
                - 'moe_loss': The 'moe_loss' value that was used
        """
        pred_y = output_dict["pred_y"]
        y = output_dict["y"]

        mae_loss = F.l1_loss(pred_y, y, reduction="mean")
        mse_loss = F.mse_loss(pred_y, y, reduction="mean")
        moe_loss = output_dict.get(
            "moe_loss", torch.tensor(0.0, device=pred_y.device, dtype=pred_y.dtype)
        )
        total_loss = mae_loss + moe_loss
        return {
            "total_loss": total_loss,
            "mae_loss": mae_loss,
            "mse_loss": mse_loss,
            "moe_loss": moe_loss,
        }
forward(self, x, moe_weight=0.0)
¶
Token-wise MoE forward pass: Each expert processes a different spectral segment of the input.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
x |
torch.Tensor |
Tensor of shape [batch_size, input_dim]. |
required |
moe_weight |
float |
Kept for compatibility, but unused in token-wise mode. |
0.0 |
Returns:
Type | Description |
---|---|
dict |
Dictionary containing: - 'pred_y': reconstructed tensor of shape [batch_size, input_dim] - 'moe_loss': dummy 0.0 (no gating loss in token-wise mode) |
Source code in hypercoast/moe_vae/model.py
def forward(self, x, moe_weight=0.0):
    """
    Run every expert on its own slice of the input and stitch the results.

    Args:
        x: Tensor of shape [batch_size, input_dim]
        moe_weight: kept for compatibility, but unused in token-wise mode.

    Returns:
        A dict with:
            'pred_y': reconstructed tensor of shape [batch_size, input_dim]
            'moe_loss': dummy 0.0 (no gating loss in token-wise)
    """
    # One column segment per expert, with widths given by sub_input_dims.
    segments = torch.split(x, self.sub_input_dims, dim=1)
    reconstructions = [
        self.experts[idx](segments[idx])["pred_y"]
        for idx in range(self.num_experts)
    ]
    zero_loss = torch.tensor(0.0, device=x.device, dtype=x.dtype)
    return {
        "pred_y": torch.cat(reconstructions, dim=1),
        "moe_loss": zero_loss,
    }
loss_fn(self, output_dict)
¶
Compute loss for token-wise MoE-VAE model.
Computes reconstruction loss without MoE-specific penalties since no gating mechanism is used in the token-wise approach.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
output_dict |
dict |
Dictionary containing model outputs and targets: - 'pred_y': Model predictions - 'y': Target values - 'moe_loss': Always zero for token-wise model |
required |
Returns:
Type | Description |
---|---|
dict |
Dictionary containing loss components: - 'total_loss': MAE loss (no MoE penalty) - 'mae_loss': Mean Absolute Error - 'mse_loss': Mean Squared Error - 'moe_loss': Zero tensor |
Source code in hypercoast/moe_vae/model.py
def loss_fn(self, output_dict):
    """Compute loss for token-wise MoE-VAE model.

    Computes reconstruction loss without MoE-specific penalties since
    no gating mechanism is used in the token-wise approach.

    Args:
        output_dict (dict): Dictionary containing model outputs and targets:
            - 'pred_y': Model predictions
            - 'y': Target values
            - 'moe_loss': Optional; a zero tensor is used when absent

    Returns:
        dict: Dictionary containing loss components:
            - 'total_loss': MAE loss plus 'moe_loss'
            - 'mae_loss': Mean Absolute Error
            - 'mse_loss': Mean Squared Error
            - 'moe_loss': The 'moe_loss' value that was used
    """
    pred_y = output_dict["pred_y"]
    y = output_dict["y"]

    mae_loss = F.l1_loss(pred_y, y, reduction="mean")
    mse_loss = F.mse_loss(pred_y, y, reduction="mean")
    moe_loss = output_dict.get(
        "moe_loss", torch.tensor(0.0, device=pred_y.device, dtype=pred_y.dtype)
    )
    total_loss = mae_loss + moe_loss
    return {
        "total_loss": total_loss,
        "mae_loss": mae_loss,
        "mse_loss": mse_loss,
        "moe_loss": moe_loss,
    }
SparseDispatcher
¶
Helper for implementing a mixture of experts with sparse gating.
This class handles the distribution of inputs to experts and combines their outputs based on gating weights. It optimizes computation by only processing inputs for experts with non-zero gates.
The class provides two main functions: - dispatch: Creates input batches for each expert based on gating weights - combine: Combines expert outputs weighted by their respective gates
Parameters:
Name | Type | Description | Default |
---|---|---|---|
num_experts |
int |
Number of expert models. |
required |
gates |
torch.Tensor |
Gating weights of shape [batch_size, num_experts]. Element [b, e] represents the weight for sending batch element b to expert e. |
required |
Examples:
>>> gates = torch.tensor([[0.8, 0.2, 0.0], [0.1, 0.0, 0.9]])
>>> dispatcher = SparseDispatcher(3, gates)
>>> expert_inputs = dispatcher.dispatch(inputs)
>>> expert_outputs = [experts[i](expert_inputs[i]) for i in range(3)]
>>> combined_output = dispatcher.combine(expert_outputs)
Note
Input and output tensors are expected to be 2D [batch, depth]. Caller is responsible for reshaping higher-dimensional tensors before dispatch and after combine operations.
Source code in hypercoast/moe_vae/model.py
class SparseDispatcher(object):
    """Helper for implementing a mixture of experts with sparse gating.

    This class handles the distribution of inputs to experts and combines their
    outputs based on gating weights. It optimizes computation by only processing
    inputs for experts with non-zero gates.

    The class provides two main functions:
        - dispatch: Creates input batches for each expert based on gating weights
        - combine: Combines expert outputs weighted by their respective gates

    Args:
        num_experts (int): Number of expert models.
        gates (torch.Tensor): Gating weights of shape [batch_size, num_experts].
            Element [b, e] represents the weight for sending batch element b
            to expert e.

    Example:
        >>> gates = torch.tensor([[0.8, 0.2, 0.0], [0.1, 0.0, 0.9]])
        >>> dispatcher = SparseDispatcher(3, gates)
        >>> expert_inputs = dispatcher.dispatch(inputs)
        >>> expert_outputs = [experts[i](expert_inputs[i]) for i in range(3)]
        >>> combined_output = dispatcher.combine(expert_outputs)

    Note:
        Input and output tensors are expected to be 2D [batch, depth]. Caller
        is responsible for reshaping higher-dimensional tensors before dispatch
        and after combine operations.
    """

    def __init__(self, num_experts, gates):
        """Initialize the SparseDispatcher.

        Experts whose gate column sums to zero are given batch element 0 with
        a weight of 1e-5 so that every expert receives at least one sample.
        That patch is applied to a private copy; the tensor passed in by the
        caller is never modified.

        Args:
            num_experts (int): Number of expert models.
            gates (torch.Tensor): Gating weights of shape [batch_size, num_experts].
        """
        self._num_experts = num_experts

        # Safety check: ensure at least one example per expert.
        empty_experts = (gates.sum(dim=0) == 0).nonzero().squeeze(1)
        if empty_experts.numel() > 0:
            # Clone first: writing into `gates` directly would silently alter
            # the caller's tensor.
            gates = gates.clone()
            gates[0, empty_experts] = 1e-5
        self._gates = gates

        # (batch, expert) coordinates of every non-zero gate; each column is
        # sorted independently, so column 1 yields the expert-major ordering.
        nonzero_coords = torch.nonzero(gates)
        sorted_experts, index_sorted_experts = nonzero_coords.sort(0)
        # Drop the batch column, keep the expert id of every routed sample.
        _, self._expert_index = sorted_experts.split(1, dim=1)
        # Batch index of every routed sample, in the same expert-major order.
        self._batch_index = nonzero_coords[index_sorted_experts[:, 1], 0]
        # Number of samples each expert receives.
        self._part_sizes = (gates > 0).sum(0).tolist()

        # Defensive fallback: never hand torch.split a zero-sized part.
        for i, size in enumerate(self._part_sizes):
            if size == 0:
                self._part_sizes[i] = 1
                if i >= len(self._expert_index):
                    # Add a dummy (expert i, batch element 0) routing entry.
                    self._expert_index = torch.cat(
                        [self._expert_index, torch.tensor([[i]], device=gates.device)]
                    )
                    self._batch_index = torch.cat(
                        [self._batch_index, torch.tensor([0], device=gates.device)]
                    )

        # Expand gates to match self._batch_index, then pick each sample's
        # gate for the expert it was routed to (result shape: [n_routed, 1]).
        gates_exp = gates[self._batch_index.flatten()]
        self._nonzero_gates = torch.gather(gates_exp, 1, self._expert_index)
        # Safety check: keep the combine weights strictly positive.
        if (self._nonzero_gates <= 0).any():
            self._nonzero_gates = torch.clamp(self._nonzero_gates, min=1e-5)

    def dispatch(self, inp):
        """Distribute input tensor to experts based on gating weights.

        Creates separate input tensors for each expert containing only the
        samples assigned to that expert (where gates[b, i] > 0).

        Args:
            inp (torch.Tensor): Input tensor of shape [batch_size, input_dim].

        Returns:
            tuple: One tensor per expert (as returned by ``torch.split``), each
                containing only the inputs assigned to that expert.
        """
        # Expand according to batch index so we can just split by _part_sizes.
        inp_exp = inp[self._batch_index]
        # Only drop a singleton axis 1 on >2-D input; squeezing a 2-D
        # [n, 1] tensor would destroy the feature axis when input_dim == 1.
        if inp_exp.dim() > 2:
            inp_exp = inp_exp.squeeze(1)
        return torch.split(inp_exp, self._part_sizes, dim=0)

    def combine(self, expert_out, multiply_by_gates=True):
        """Combine expert outputs weighted by gating values.

        Aggregates outputs from all experts for each batch element, weighted
        by the corresponding gate values. The final output for batch element b
        is the sum of expert outputs weighted by gates[b, i].

        Args:
            expert_out (list): List of expert output tensors, each with shape
                [expert_batch_size_i, output_dim].
            multiply_by_gates (bool, optional): Whether to weight outputs by
                gate values. If False, outputs are simply summed. Defaults to True.

        Returns:
            torch.Tensor: Combined output tensor of shape [batch_size, output_dim].
        """
        # Concatenate in expert-major order, matching self._batch_index.
        stitched = torch.cat(expert_out, 0)
        if multiply_by_gates:
            stitched = stitched.mul(self._nonzero_gates)
        zeros = torch.zeros(
            self._gates.size(0),
            expert_out[-1].size(1),
            requires_grad=True,
            device=stitched.device,
        )
        # Sum the contributions of all experts that processed the same sample.
        combined = zeros.index_add(0, self._batch_index, stitched.float())
        return combined

    def expert_to_gates(self):
        """Extract gate values for each expert's assigned samples.

        Returns:
            tuple: One tensor per expert, of shape [n_i, 1], holding the gate
                values of the n_i samples assigned to that expert.
        """
        # Split the routed gate values into one chunk per expert.
        return torch.split(self._nonzero_gates, self._part_sizes, dim=0)
__init__(self, num_experts, gates)
special
¶
Initialize the SparseDispatcher.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
num_experts |
int |
Number of expert models. |
required |
gates |
torch.Tensor |
Gating weights of shape [batch_size, num_experts]. |
required |
Source code in hypercoast/moe_vae/model.py
def __init__(self, num_experts, gates):
    """Initialize the SparseDispatcher.

    Experts whose gate column sums to zero are given batch element 0 with a
    weight of 1e-5 so that every expert receives at least one sample. That
    patch is applied to a private copy; the tensor passed in by the caller
    is never modified.

    Args:
        num_experts (int): Number of expert models.
        gates (torch.Tensor): Gating weights of shape [batch_size, num_experts].
    """
    self._num_experts = num_experts

    # Safety check: ensure at least one example per expert.
    empty_experts = (gates.sum(dim=0) == 0).nonzero().squeeze(1)
    if empty_experts.numel() > 0:
        # Clone first: writing into `gates` directly would silently alter
        # the caller's tensor.
        gates = gates.clone()
        gates[0, empty_experts] = 1e-5
    self._gates = gates

    # (batch, expert) coordinates of every non-zero gate; each column is
    # sorted independently, so column 1 yields the expert-major ordering.
    nonzero_coords = torch.nonzero(gates)
    sorted_experts, index_sorted_experts = nonzero_coords.sort(0)
    # Drop the batch column, keep the expert id of every routed sample.
    _, self._expert_index = sorted_experts.split(1, dim=1)
    # Batch index of every routed sample, in the same expert-major order.
    self._batch_index = nonzero_coords[index_sorted_experts[:, 1], 0]
    # Number of samples each expert receives.
    self._part_sizes = (gates > 0).sum(0).tolist()

    # Defensive fallback: never hand torch.split a zero-sized part.
    for i, size in enumerate(self._part_sizes):
        if size == 0:
            self._part_sizes[i] = 1
            if i >= len(self._expert_index):
                # Add a dummy (expert i, batch element 0) routing entry.
                self._expert_index = torch.cat(
                    [self._expert_index, torch.tensor([[i]], device=gates.device)]
                )
                self._batch_index = torch.cat(
                    [self._batch_index, torch.tensor([0], device=gates.device)]
                )

    # Expand gates to match self._batch_index, then pick each sample's gate
    # for the expert it was routed to (result shape: [n_routed, 1]).
    gates_exp = gates[self._batch_index.flatten()]
    self._nonzero_gates = torch.gather(gates_exp, 1, self._expert_index)
    # Safety check: keep the combine weights strictly positive.
    if (self._nonzero_gates <= 0).any():
        self._nonzero_gates = torch.clamp(self._nonzero_gates, min=1e-5)
combine(self, expert_out, multiply_by_gates=True)
¶
Combine expert outputs weighted by gating values.
Aggregates outputs from all experts for each batch element, weighted by the corresponding gate values. The final output for batch element b is the sum of expert outputs weighted by gates[b, i].
Parameters:
Name | Type | Description | Default |
---|---|---|---|
expert_out |
list |
List of expert output tensors, each with shape [expert_batch_size_i, output_dim]. |
required |
multiply_by_gates |
bool |
Whether to weight outputs by gate values. If False, outputs are simply summed. Defaults to True. |
True |
Returns:
Type | Description |
---|---|
torch.Tensor |
Combined output tensor of shape [batch_size, output_dim]. |
Source code in hypercoast/moe_vae/model.py
def combine(self, expert_out, multiply_by_gates=True):
    """Merge per-expert outputs back into one tensor per batch element.

    Each batch element receives the sum of the outputs of the experts it
    was routed to, optionally scaled by the matching gate value.

    Args:
        expert_out (list): Expert output tensors, each of shape
            [expert_batch_size_i, output_dim], in expert order.
        multiply_by_gates (bool, optional): Scale every expert output by its
            gate value before summing. If False, outputs are simply summed.
            Defaults to True.

    Returns:
        torch.Tensor: Combined output of shape [batch_size, output_dim].
    """
    # Concatenate in expert-major order, matching self._batch_index.
    merged = torch.cat(expert_out, 0)
    if multiply_by_gates:
        merged = merged.mul(self._nonzero_gates)

    n_rows = self._gates.size(0)
    n_cols = expert_out[-1].size(1)
    accumulator = torch.zeros(
        n_rows, n_cols, requires_grad=True, device=merged.device
    )
    # Rows sharing a batch index are added together.
    return accumulator.index_add(0, self._batch_index, merged.float())
dispatch(self, inp)
¶
Distribute input tensor to experts based on gating weights.
Creates separate input tensors for each expert containing only the samples assigned to that expert (where gates[b, i] > 0).
Parameters:
Name | Type | Description | Default |
---|---|---|---|
inp |
torch.Tensor |
Input tensor of shape [batch_size, input_dim]. |
required |
Returns:
Type | Description |
---|---|
tuple |
Tuple of tensors (as returned by torch.split), one for each expert. Each tensor contains only the inputs assigned to that expert. |
Source code in hypercoast/moe_vae/model.py
def dispatch(self, inp):
    """Distribute input tensor to experts based on gating weights.

    Creates separate input tensors for each expert containing only the
    samples assigned to that expert (where gates[b, i] > 0).

    Args:
        inp (torch.Tensor): Input tensor of shape [batch_size, input_dim].

    Returns:
        tuple: One tensor per expert (as returned by ``torch.split``), each
            containing only the inputs assigned to that expert.
    """
    # Expand according to batch index so we can just split by _part_sizes.
    inp_exp = inp[self._batch_index]
    # Only drop a singleton axis 1 on >2-D input; squeezing a 2-D [n, 1]
    # tensor would destroy the feature axis when input_dim == 1.
    if inp_exp.dim() > 2:
        inp_exp = inp_exp.squeeze(1)
    return torch.split(inp_exp, self._part_sizes, dim=0)
expert_to_gates(self)
¶
Extract gate values for each expert's assigned samples.
Returns:
Type | Description |
---|---|
tuple |
Tuple of tensors, one for each expert, each of shape [n_i, 1] and containing the gate values for the n_i samples assigned to that expert. |
Source code in hypercoast/moe_vae/model.py
def expert_to_gates(self):
    """Return the routed gate values grouped by expert.

    Returns:
        tuple: One tensor per expert (as produced by ``torch.split``) holding
            the gate values of the samples assigned to that expert.
    """
    routed_gates = self._nonzero_gates
    samples_per_expert = self._part_sizes
    # Chunk the expert-major gate values into one piece per expert.
    return torch.split(routed_gates, samples_per_expert, dim=0)
VAE (LightningModule)
¶
Variational Autoencoder implementation using PyTorch Lightning.
A standard VAE architecture with configurable encoder/decoder networks, support for various activation functions, normalization layers, and dropout regularization.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
input_dim |
int |
Dimension of input data. |
required |
output_dim |
int |
Dimension of output/reconstructed data. |
required |
latent_dim |
int |
Dimension of latent space. |
required |
encoder_hidden_dims |
list |
List of hidden layer dimensions for encoder. |
required |
decoder_hidden_dims |
list |
List of hidden layer dimensions for decoder. |
required |
activation |
str |
Activation function name. Supports 'relu', 'tanh', 'sigmoid', 'leakyrelu'. Defaults to 'leakyrelu'. |
'leakyrelu' |
use_norm |
str or bool |
Normalization type. Can be 'batch', 'layer', or False. Defaults to False. |
False |
use_dropout |
bool |
Whether to use dropout. Defaults to False. |
False |
use_softplus_output |
bool |
Whether to apply softplus to output. Defaults to False. |
False |
**kwargs |
Additional keyword arguments. |
{} |
Source code in hypercoast/moe_vae/model.py
class VAE(LightningModule):
    """Variational Autoencoder built as a PyTorch Lightning module.

    A fully-connected encoder maps the input to the mean and log-variance of
    a latent Gaussian; a fully-connected decoder maps a latent sample to the
    output. Activation, normalization and dropout are configurable.

    Args:
        input_dim (int): Dimension of input data.
        output_dim (int): Dimension of output/reconstructed data.
        latent_dim (int): Dimension of latent space.
        encoder_hidden_dims (list): List of hidden layer dimensions for encoder.
        decoder_hidden_dims (list): List of hidden layer dimensions for decoder.
        activation (str, optional): Activation function name. Supports 'relu',
            'tanh', 'sigmoid', 'leakyrelu'. Defaults to 'leakyrelu'.
        use_norm (str or bool, optional): Normalization type. Can be 'batch',
            'layer', or False. Defaults to False.
        use_dropout (bool, optional): Whether to use dropout. Defaults to False.
        use_softplus_output (bool, optional): Whether to apply softplus to output.
            Defaults to False.
        **kwargs: Additional keyword arguments (accepted and ignored).
    """

    def __init__(
        self,
        input_dim,
        output_dim,
        latent_dim,
        encoder_hidden_dims,
        decoder_hidden_dims,
        activation="leakyrelu",
        use_norm=False,
        use_dropout=False,
        use_softplus_output=False,
        **kwargs,
    ):
        """Build the encoder, the latent heads and the decoder.

        See the class docstring for the meaning of each argument.

        Raises:
            ValueError: If ``activation`` is not one of the supported names.
        """
        super().__init__()
        self.use_softplus_output = use_softplus_output

        # Name -> factory pairs, checked in order; the factory is only called
        # for the selected activation.
        activation_factories = (
            ("relu", nn.ReLU),
            ("tanh", nn.Tanh),
            ("sigmoid", nn.Sigmoid),
            ("leakyrelu", lambda: nn.LeakyReLU(0.2)),
        )
        for known_name, make_activation in activation_factories:
            if activation == known_name:
                self.activation = make_activation()
                break
        else:
            raise ValueError(f"Unsupported activation function: {activation}")

        # Encoder trunk followed by the two latent heads.
        self.encoder_layers = self.build_layers(
            input_dim, encoder_hidden_dims, use_norm, use_dropout
        )
        encoder_width = encoder_hidden_dims[-1]
        self.fc_mu = nn.Linear(encoder_width, latent_dim)
        self.fc_log_var = nn.Linear(encoder_width, latent_dim)

        # Decoder trunk plus a linear projection to the output size.
        decoder = self.build_layers(
            latent_dim, decoder_hidden_dims, use_norm, use_dropout
        )
        decoder.add_module(
            "output_layer", nn.Linear(decoder_hidden_dims[-1], output_dim)
        )
        # The output is left unbounded unless softplus is requested: the
        # target range depends on the preprocessing applied to the data.
        if use_softplus_output:
            decoder.add_module("output_activation", nn.Softplus())
        self.decoder_layers = decoder

    def build_layers(self, input_dim, hidden_dims, use_norm, use_dropout=False):
        """Stack Linear (+ norm) + activation (+ dropout) blocks.

        Args:
            input_dim (int): Input dimension for the first layer.
            hidden_dims (list): List of hidden layer dimensions.
            use_norm (str or bool): Normalization type ('batch', 'layer', or False).
            use_dropout (bool, optional): Whether to include dropout layers.
                Defaults to False.

        Returns:
            nn.Sequential: Sequential container of network layers.
        """
        modules = []
        in_features = input_dim
        for out_features in hidden_dims:
            modules.append(nn.Linear(in_features, out_features))
            if use_norm == "batch":
                modules.append(nn.BatchNorm1d(out_features))
            elif use_norm == "layer":
                modules.append(nn.LayerNorm(out_features))
            # The same activation instance is reused in every block.
            modules.append(self.activation)
            if use_dropout:
                modules.append(nn.Dropout(0.1))
            in_features = out_features
        return nn.Sequential(*modules)

    def encode(self, x):
        """Encode input to latent space parameters.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, input_dim).

        Returns:
            tuple: A tuple containing:
                - mu (torch.Tensor): Mean of latent distribution.
                - log_var (torch.Tensor): Log variance of latent distribution.
        """
        hidden = self.encoder_layers(x)
        return self.fc_mu(hidden), self.fc_log_var(hidden)

    def reparameterize(self, mu, log_var):
        """Draw a latent sample with the reparameterization trick.

        Args:
            mu (torch.Tensor): Mean of latent distribution.
            log_var (torch.Tensor): Log variance of latent distribution.

        Returns:
            torch.Tensor: Sampled latent vector ``mu + eps * std``.
        """
        sigma = (0.5 * log_var).exp()
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def decode(self, z):
        """Decode latent representation to output space.

        Args:
            z (torch.Tensor): Latent representation.

        Returns:
            torch.Tensor: Reconstructed output.
        """
        reconstruction = self.decoder_layers(z)
        return reconstruction

    def forward(self, x):
        """Forward pass through the VAE.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            dict: Dictionary containing:
                - 'pred_y': Reconstructed output
                - 'mu': Mean of latent distribution
                - 'log_var': Log variance of latent distribution
        """
        mu, log_var = self.encode(x)
        latent = self.reparameterize(mu, log_var)
        return dict(pred_y=self.decode(latent), mu=mu, log_var=log_var)

    def loss_fn(self, output_dict, kld_weight=0.0):
        """Compute VAE loss (reconstruction + KL divergence).

        Args:
            output_dict (dict): Must provide 'pred_y', 'y', 'mu' and 'log_var'.
            kld_weight (float, optional): Weight for KL divergence term.
                Defaults to 0.0.

        Returns:
            dict: Dictionary containing different loss components:
                - 'total_loss': Combined loss (MAE + weighted KLD)
                - 'mae_loss': Mean Absolute Error
                - 'mse_loss': Mean Squared Error
                - 'kld_loss': KL Divergence loss
        """
        pred_y = output_dict["pred_y"]
        y = output_dict["y"]
        mu = output_dict["mu"]
        log_var = output_dict["log_var"]
        n_samples = y.shape[0]

        mae = F.l1_loss(pred_y, y, reduction="mean")
        # MSE is reported only; the optimized objective uses MAE.
        mse = F.mse_loss(pred_y, y, reduction="mean")
        # KL divergence to a standard normal, summed over all elements and
        # divided by the batch size.
        kld_terms = 1 + log_var - mu.pow(2) - log_var.exp()
        kld = -0.5 * torch.sum(kld_terms) / n_samples

        return {
            "total_loss": mae + kld_weight * kld,
            "mae_loss": mae,
            "mse_loss": mse,
            "kld_loss": kld,
        }
__init__(self, input_dim, output_dim, latent_dim, encoder_hidden_dims, decoder_hidden_dims, activation='leakyrelu', use_norm=False, use_dropout=False, use_softplus_output=False, **kwargs)
special
¶
Initialize the VAE model.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
input_dim |
int |
Dimension of input data. |
required |
output_dim |
int |
Dimension of output/reconstructed data. |
required |
latent_dim |
int |
Dimension of latent space. |
required |
encoder_hidden_dims |
list |
List of hidden layer dimensions for encoder. |
required |
decoder_hidden_dims |
list |
List of hidden layer dimensions for decoder. |
required |
activation |
str |
Activation function name. Supports 'relu', 'tanh', 'sigmoid', 'leakyrelu'. Defaults to 'leakyrelu'. |
'leakyrelu' |
use_norm |
str or bool |
Normalization type. Can be 'batch', 'layer', or False. Defaults to False. |
False |
use_dropout |
bool |
Whether to use dropout. Defaults to False. |
False |
use_softplus_output |
bool |
Whether to apply softplus to output. Defaults to False. |
False |
**kwargs |
Additional keyword arguments. |
{} |
Source code in hypercoast/moe_vae/model.py
def __init__(
    self,
    input_dim,
    output_dim,
    latent_dim,
    encoder_hidden_dims,
    decoder_hidden_dims,
    activation="leakyrelu",
    use_norm=False,
    use_dropout=False,
    use_softplus_output=False,
    **kwargs,
):
    """Build the VAE encoder, latent heads and decoder.

    Args:
        input_dim (int): Dimension of input data.
        output_dim (int): Dimension of output/reconstructed data.
        latent_dim (int): Dimension of latent space.
        encoder_hidden_dims (list): List of hidden layer dimensions for encoder.
        decoder_hidden_dims (list): List of hidden layer dimensions for decoder.
        activation (str, optional): Activation function name. Supports 'relu',
            'tanh', 'sigmoid', 'leakyrelu'. Defaults to 'leakyrelu'.
        use_norm (str or bool, optional): Normalization type. Can be 'batch',
            'layer', or False. Defaults to False.
        use_dropout (bool, optional): Whether to use dropout. Defaults to False.
        use_softplus_output (bool, optional): Whether to apply softplus to output.
            Defaults to False.
        **kwargs: Additional keyword arguments (accepted and ignored).

    Raises:
        ValueError: If ``activation`` is not one of the supported names.
    """
    super().__init__()
    self.use_softplus_output = use_softplus_output

    # Name -> factory pairs, checked in order; the factory is only called
    # for the selected activation.
    activation_factories = (
        ("relu", nn.ReLU),
        ("tanh", nn.Tanh),
        ("sigmoid", nn.Sigmoid),
        ("leakyrelu", lambda: nn.LeakyReLU(0.2)),
    )
    for known_name, make_activation in activation_factories:
        if activation == known_name:
            self.activation = make_activation()
            break
    else:
        raise ValueError(f"Unsupported activation function: {activation}")

    # Encoder trunk followed by the two latent heads.
    self.encoder_layers = self.build_layers(
        input_dim, encoder_hidden_dims, use_norm, use_dropout
    )
    encoder_width = encoder_hidden_dims[-1]
    self.fc_mu = nn.Linear(encoder_width, latent_dim)
    self.fc_log_var = nn.Linear(encoder_width, latent_dim)

    # Decoder trunk plus a linear projection to the output size.
    decoder = self.build_layers(
        latent_dim, decoder_hidden_dims, use_norm, use_dropout
    )
    decoder.add_module(
        "output_layer", nn.Linear(decoder_hidden_dims[-1], output_dim)
    )
    # The output is left unbounded unless softplus is requested: the target
    # range depends on the preprocessing applied to the data.
    if use_softplus_output:
        decoder.add_module("output_activation", nn.Softplus())
    self.decoder_layers = decoder
build_layers(self, input_dim, hidden_dims, use_norm, use_dropout=False)
¶
Build sequential neural network layers.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
input_dim |
int |
Input dimension for the first layer. |
required |
hidden_dims |
list |
List of hidden layer dimensions. |
required |
use_norm |
str or bool |
Normalization type ('batch', 'layer', or False). |
required |
use_dropout |
bool |
Whether to include dropout layers. Defaults to False. |
False |
Returns:
Type | Description |
---|---|
nn.Sequential |
Sequential container of network layers. |
Source code in hypercoast/moe_vae/model.py
def build_layers(self, input_dim, hidden_dims, use_norm, use_dropout=False):
    """Stack Linear (+ norm) + activation (+ dropout) blocks.

    Args:
        input_dim (int): Input dimension for the first layer.
        hidden_dims (list): List of hidden layer dimensions.
        use_norm (str or bool): Normalization type ('batch', 'layer', or False).
        use_dropout (bool, optional): Whether to include dropout layers.
            Defaults to False.

    Returns:
        nn.Sequential: Sequential container of network layers.
    """
    modules = []
    in_features = input_dim
    for out_features in hidden_dims:
        modules.append(nn.Linear(in_features, out_features))
        if use_norm == "batch":
            modules.append(nn.BatchNorm1d(out_features))
        elif use_norm == "layer":
            modules.append(nn.LayerNorm(out_features))
        # The same activation instance is reused in every block.
        modules.append(self.activation)
        if use_dropout:
            modules.append(nn.Dropout(0.1))
        in_features = out_features
    return nn.Sequential(*modules)
decode(self, z)
¶
Decode latent representation to output space.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
z |
torch.Tensor |
Latent representation. |
required |
Returns:
Type | Description |
---|---|
torch.Tensor |
Reconstructed output. |
Source code in hypercoast/moe_vae/model.py
def decode(self, z):
    """Map a latent representation to the output space.

    Args:
        z (torch.Tensor): Latent representation.

    Returns:
        torch.Tensor: Reconstructed output produced by ``self.decoder_layers``.
    """
    reconstruction = self.decoder_layers(z)
    return reconstruction
encode(self, x)
¶
Encode input to latent space parameters.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
x |
torch.Tensor |
Input tensor of shape (batch_size, input_dim). |
required |
Returns:
Type | Description |
---|---|
tuple |
A tuple containing: - mu (torch.Tensor): Mean of latent distribution. - log_var (torch.Tensor): Log variance of latent distribution. |
Source code in hypercoast/moe_vae/model.py
def encode(self, x):
    """Encode input to latent space parameters.

    Args:
        x (torch.Tensor): Input tensor of shape (batch_size, input_dim).

    Returns:
        tuple: A tuple containing:
            - mu (torch.Tensor): Mean of latent distribution.
            - log_var (torch.Tensor): Log variance of latent distribution.
    """
    # Both heads read the same encoder features.
    hidden = self.encoder_layers(x)
    return self.fc_mu(hidden), self.fc_log_var(hidden)
forward(self, x)
¶
Forward pass through the VAE.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
x |
torch.Tensor |
Input tensor. |
required |
Returns:
Type | Description |
---|---|
dict |
Dictionary containing: - 'pred_y': Reconstructed output - 'mu': Mean of latent distribution - 'log_var': Log variance of latent distribution |
Source code in hypercoast/moe_vae/model.py
def forward(self, x):
    """Run encode -> sample -> decode.

    Args:
        x (torch.Tensor): Input tensor.

    Returns:
        dict: Dictionary containing:
            - 'pred_y': Reconstructed output
            - 'mu': Mean of latent distribution
            - 'log_var': Log variance of latent distribution
    """
    mu, log_var = self.encode(x)
    latent = self.reparameterize(mu, log_var)
    return dict(pred_y=self.decode(latent), mu=mu, log_var=log_var)
loss_fn(self, output_dict, kld_weight=0.0)
¶
Compute VAE loss (reconstruction + KL divergence).
Parameters:
Name | Type | Description | Default |
---|---|---|---|
output_dict |
dict |
Dictionary containing model outputs and targets. |
required |
kld_weight |
float |
Weight for KL divergence term. Defaults to 0.0. |
0.0 |
Returns:
Type | Description |
---|---|
dict |
Dictionary containing different loss components: - 'total_loss': Combined loss (MAE + weighted KLD) - 'mae_loss': Mean Absolute Error - 'mse_loss': Mean Squared Error - 'kld_loss': KL Divergence loss |
Source code in hypercoast/moe_vae/model.py
def loss_fn(self, output_dict, kld_weight=0.0):
    """Compute VAE loss (reconstruction + KL divergence).

    Args:
        output_dict (dict): Must provide 'pred_y', 'y', 'mu' and 'log_var'.
        kld_weight (float, optional): Weight for KL divergence term.
            Defaults to 0.0.

    Returns:
        dict: Dictionary containing different loss components:
            - 'total_loss': Combined loss (MAE + weighted KLD)
            - 'mae_loss': Mean Absolute Error
            - 'mse_loss': Mean Squared Error
            - 'kld_loss': KL Divergence loss
    """
    pred_y = output_dict["pred_y"]
    y = output_dict["y"]
    mu = output_dict["mu"]
    log_var = output_dict["log_var"]
    n_samples = y.shape[0]

    mae = F.l1_loss(pred_y, y, reduction="mean")
    # MSE is reported only; the optimized objective uses MAE.
    mse = F.mse_loss(pred_y, y, reduction="mean")
    # KL divergence to a standard normal, summed over all elements and
    # divided by the batch size.
    kld_terms = 1 + log_var - mu.pow(2) - log_var.exp()
    kld = -0.5 * torch.sum(kld_terms) / n_samples

    return {
        "total_loss": mae + kld_weight * kld,
        "mae_loss": mae,
        "mse_loss": mse,
        "kld_loss": kld,
    }
reparameterize(self, mu, log_var)
¶
Reparameterization trick for sampling from latent distribution.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
mu |
torch.Tensor |
Mean of latent distribution. |
required |
log_var |
torch.Tensor |
Log variance of latent distribution. |
required |
Returns:
Type | Description |
---|---|
torch.Tensor |
Sampled latent vector. |
Source code in hypercoast/moe_vae/model.py
def reparameterize(self, mu, log_var):
    """Draw a latent sample with the reparameterization trick.

    Args:
        mu (torch.Tensor): Mean of latent distribution.
        log_var (torch.Tensor): Log variance of latent distribution.

    Returns:
        torch.Tensor: Sampled latent vector ``mu + eps * std`` with
            ``eps`` drawn from a standard normal.
    """
    # std = exp(log_var / 2)
    sigma = (0.5 * log_var).exp()
    noise = torch.randn_like(sigma)
    return mu + noise * sigma
evaluate(model, test_dl, device, TSS_scalers_dict=None, log_offset=0.01)
¶
Evaluate the MoE-VAE model.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model |
torch.nn.Module |
MoE-VAE model. |
required |
test_dl |
torch.utils.data.DataLoader |
DataLoader for test data. |
required |
device |
torch.device |
Device to use for evaluation. |
required |
TSS_scalers_dict |
dict |
Dictionary containing scalers for TSS. |
None |
log_offset |
float |
Log offset for predictions. Defaults to 0.01. |
0.01 |
Returns:
Type | Description |
---|---|
tuple |
Tuple containing: - predictions_inverse (numpy.ndarray): Inverse transformed predictions. - actuals_inverse (numpy.ndarray): Inverse transformed actuals. |
Source code in hypercoast/moe_vae/model.py
def evaluate(model, test_dl, device, TSS_scalers_dict=None, log_offset=0.01):
    """Run the model over a test loader and undo the target scaling.

    Args:
        model (torch.nn.Module): MoE-VAE model; its output dict must contain
            'pred_y'.
        test_dl (torch.utils.data.DataLoader): DataLoader yielding (x, y) pairs.
        device (torch.device): Device the batches are moved to.
        TSS_scalers_dict (dict, optional): Scalers under the keys 'log' and
            'robust'. When None, values are inverted as
            ``10 ** value - log_offset``.
        log_offset (float, optional): Offset subtracted in the fallback
            inversion. Defaults to 0.01.

    Returns:
        tuple: Tuple containing:
            - predictions_inverse (numpy.ndarray): Inverse transformed predictions.
            - actuals_inverse (numpy.ndarray): Inverse transformed actuals.
    """
    model.eval()
    pred_batches, target_batches = [], []
    with torch.no_grad():
        for features, targets in test_dl:
            features, targets = features.to(device), targets.to(device)
            batch_pred = model(features)["pred_y"]
            pred_batches.append(batch_pred.cpu().numpy())
            target_batches.append(targets.cpu().numpy())
    stacked_preds = np.vstack(pred_batches)
    stacked_targets = np.vstack(target_batches)

    if TSS_scalers_dict is None:
        # No scalers supplied: values are treated as log10(x + log_offset).
        preds_out = (10 ** stacked_preds.flatten()) - log_offset
        targets_out = (10 ** stacked_targets.flatten()) - log_offset
        return preds_out, targets_out

    log_scaler = TSS_scalers_dict["log"]
    robust_scaler = TSS_scalers_dict["robust"]

    def undo_scaling(values):
        # Reverse the robust scaling first, then the log transform.
        unscaled = robust_scaler.inverse_transform(
            torch.tensor(values, dtype=torch.float32)
        )
        restored = log_scaler.inverse_transform(
            torch.tensor(unscaled, dtype=torch.float32)
        )
        return restored.numpy().flatten()

    preds_out = undo_scaling(stacked_preds)
    targets_out = undo_scaling(stacked_targets)
    return preds_out, targets_out
evaluate_token(model, test_dl, device, TSS_scalers_dict=None, log_offset=0.01)
¶
Evaluate the token-wise MoE-VAE model.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model |
torch.nn.Module |
MoE-VAE model. |
required |
test_dl |
torch.utils.data.DataLoader |
DataLoader for test data. |
required |
device |
torch.device |
Device to use for evaluation. |
required |
TSS_scalers_dict |
dict |
Dictionary containing scalers for TSS. |
None |
log_offset |
float |
Log offset for predictions. Defaults to 0.01. |
0.01 |
Returns:
Type | Description |
---|---|
tuple |
Tuple containing: - predictions_inverse (numpy.ndarray): Inverse transformed predictions. - actuals_inverse (numpy.ndarray): Inverse transformed actuals. |
Source code in hypercoast/moe_vae/model.py
def evaluate_token(model, test_dl, device, TSS_scalers_dict=None, log_offset=0.01):
    """Evaluate the token-wise model, averaging token predictions per sample.

    Args:
        model (torch.nn.Module): MoE-VAE model; its output dict must contain
            'pred_y'. A 2-D 'pred_y' is averaged over axis 1 (keeping the axis)
            before being collected.
        test_dl (torch.utils.data.DataLoader): DataLoader yielding (x, y) pairs.
        device (torch.device): Device the batches are moved to.
        TSS_scalers_dict (dict, optional): Scalers under the keys 'log' and
            'robust'. When None, values are inverted as
            ``10 ** value - log_offset``.
        log_offset (float, optional): Offset subtracted in the fallback
            inversion. Defaults to 0.01.

    Returns:
        tuple: Tuple containing:
            - predictions_inverse (numpy.ndarray): Inverse transformed predictions.
            - actuals_inverse (numpy.ndarray): Inverse transformed actuals.
    """
    model.eval()
    pred_batches, target_batches = [], []
    with torch.no_grad():
        for features, targets in test_dl:
            features, targets = features.to(device), targets.to(device)
            token_preds = model(features)["pred_y"]
            if token_preds.ndim == 2:
                # Collapse the token axis to a single value per sample.
                token_preds = token_preds.mean(dim=1, keepdim=True)
            pred_batches.append(token_preds.cpu().numpy())
            target_batches.append(targets.cpu().numpy())
    stacked_preds = np.vstack(pred_batches)
    stacked_targets = np.vstack(target_batches)

    if TSS_scalers_dict is None:
        # No scalers supplied: values are treated as log10(x + log_offset).
        preds_out = (10 ** stacked_preds.flatten()) - log_offset
        targets_out = (10 ** stacked_targets.flatten()) - log_offset
        return preds_out, targets_out

    log_scaler = TSS_scalers_dict["log"]
    robust_scaler = TSS_scalers_dict["robust"]

    def undo_scaling(values):
        # Reverse the robust scaling first, then the log transform.
        unscaled = robust_scaler.inverse_transform(
            torch.tensor(values, dtype=torch.float32)
        )
        restored = log_scaler.inverse_transform(
            torch.tensor(unscaled, dtype=torch.float32)
        )
        return restored.numpy().flatten()

    preds_out = undo_scaling(stacked_preds)
    targets_out = undo_scaling(stacked_targets)
    return preds_out, targets_out
train(model, train_dl, device, epochs=200, optimizer=None, save_dir=None)
¶
Train the MoE-VAE model.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model |
torch.nn.Module |
MoE-VAE model. |
required |
train_dl |
torch.utils.data.DataLoader |
DataLoader for training data. |
required |
device |
torch.device |
Device to use for training. |
required |
epochs |
int |
Number of epochs to train. Defaults to 200. |
200 |
optimizer |
torch.optim.Optimizer |
Optimizer to use for training. |
None |
save_dir |
str |
Directory to save the model. Defaults to None. |
None |
Returns:
Type | Description |
---|---|
dict |
Dictionary containing training metrics: - 'total_loss': List of total loss values per epoch. - 'l1_loss': List of L1 loss values per epoch. - 'best_loss': Minimum total loss value. |
Source code in hypercoast/moe_vae/model.py
def train(model, train_dl, device, epochs=200, optimizer=None, save_dir=None):
    """Train the MoE-VAE model.

    After every epoch the state dict is checkpointed whenever the average
    total loss reaches a new minimum (only when ``save_dir`` is given).

    Args:
        model (torch.nn.Module): MoE-VAE model. Must return a dict from its
            forward pass and expose ``loss_fn(output_dict)`` returning
            'total_loss' and 'mae_loss'.
        train_dl (torch.utils.data.DataLoader): DataLoader for training data.
        device (torch.device): Device the batches are moved to. The model
            itself is not moved.
        epochs (int, optional): Number of epochs to train. Defaults to 200.
        optimizer (torch.optim.Optimizer, optional): Optimizer to use for
            training. When None, ``torch.optim.Adam`` with its default
            settings is created over ``model.parameters()``.
        save_dir (str, optional): Directory to save the best model to as
            'best_model_minloss.pth'; created if missing. When None, no
            checkpoint is written. Defaults to None.

    Returns:
        dict: Dictionary containing training metrics:
            - 'total_loss': List of total loss values per epoch.
            - 'l1_loss': List of L1 loss values per epoch.
            - 'best_loss': Minimum total loss value.

    Raises:
        ZeroDivisionError: If ``train_dl`` has length zero.
    """
    # The documented defaults used to crash (None optimizer / None save_dir);
    # fall back to sensible behaviour instead.
    if optimizer is None:
        optimizer = torch.optim.Adam(model.parameters())
    best_model_path = None
    if save_dir is not None:
        os.makedirs(save_dir, exist_ok=True)
        best_model_path = os.path.join(save_dir, "best_model_minloss.pth")

    model.train()
    min_total_loss = float("inf")
    total_list = []
    l1_list = []
    num_batches = len(train_dl)

    for epoch in range(epochs):
        total_loss_epoch = 0.0
        l1_epoch = 0.0
        for x, y in train_dl:
            x, y = x.to(device), y.to(device)
            output_dict = model(x)
            # loss_fn reads the target from the same dict as the outputs.
            output_dict["y"] = y
            loss_dict = model.loss_fn(output_dict)
            loss = loss_dict["total_loss"]
            l1 = loss_dict["mae_loss"]

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss_epoch += loss.item()
            l1_epoch += l1.item()

        avg_total_loss = total_loss_epoch / num_batches
        avg_l1 = l1_epoch / num_batches
        print(f"[Epoch {epoch+1}] Total: {avg_total_loss:.4f} | L1: {avg_l1:.4f}")
        total_list.append(avg_total_loss)
        l1_list.append(avg_l1)

        if avg_total_loss < min_total_loss:
            min_total_loss = avg_total_loss
            if best_model_path is not None:
                torch.save(model.state_dict(), best_model_path)

    return {"total_loss": total_list, "l1_loss": l1_list, "best_loss": min_total_loss}