Python – validation and test loss for a variety of PyTorch time series forecasting models

Hi everyone, I’m trying to reduce the complexity of some of my Python code. The function below computes the validation and test loss for a variety of PyTorch time series forecasting models. I won’t go into all the intricacies, but it needs to support models that return multiple targets, models that return an output distribution + std (as opposed to a single tensor), and models that require masked elements of the target sequence. Over time this has resulted in long if/else blocks and lots of other bad practices.
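For concreteness, here is a rough sketch of the three output conventions the function has to reconcile (the shapes and names below are made up for illustration, not taken from the real models):

import torch

batch, seq_len, n_targets = 4, 20, 3

# 1. Plain model: a single tensor, possibly with several target columns.
output = torch.randn(batch, seq_len, n_targets)

# 2. Probabilistic model: a distribution object instead of a tensor,
#    from which the mean and standard deviation are extracted.
dist = torch.distributions.Normal(torch.zeros(batch, seq_len), torch.ones(batch, seq_len))
mean, std = dist.mean, dist.stddev

# 3. Decoder-style model: the prediction window of the target sequence is
#    masked out (zeroed) before being fed back in for step-by-step decoding.
pred_len = 5
masked_targ = output.clone()
masked_targ[:, -pred_len:, :] = 0.0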

I’ve used dictionaries before to replace long if/else chains, but due to the nested nature of this code it doesn’t seem like that would work well here. I also don’t really see the point in just creating more functions, as that only moves the if/else statements somewhere else and requires passing more parameters around. Does anyone have any ideas? Several unit tests currently exercise the different paths through this code, but it is still cumbersome to read, and soon I will have even more model variations to support. The full code in context can be seen at this link.
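For reference, this is roughly the kind of dispatch table I mean (the handler names and signatures here are hypothetical, just to show the shape of the idea):

from typing import Callable, Dict

def _decode_simple_transformer(model, src, targ, device):
    ...  # the greedy_decode path

def _decode_informer(model, src, targ, device):
    ...  # the masked-target Informer path

def _decode_default(model, src, targ, device):
    ...  # the simple_decode path

# Map a model class name to the handler that knows how to decode it;
# unknown models fall through to the default handler.
DECODE_HANDLERS: Dict[str, Callable] = {
    "SimpleTransformer": _decode_simple_transformer,
    "Informer": _decode_informer,
}

def decode(model, src, targ, device):
    handler = DECODE_HANDLERS.get(type(model).__name__, _decode_default)
    return handler(model, src, targ, device)

The snag is that the branches don’t just produce output: the Informer branch also mutates labels, src and multi_targets, so a flat mapping like this doesn’t drop in cleanly.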

from typing import Type

import torch
import wandb
from torch.utils.data import DataLoader

# greedy_decode, simple_decode and compute_loss are project helpers;
# see the linked full code for their definitions.


def compute_validation(validation_loader: DataLoader,
                       model,
                       epoch: int,
                       sequence_size: int,
                       criterion: Type[torch.nn.modules.loss._Loss],
                       device: torch.device,
                       decoder_structure=False,
                       meta_data_model=None,
                       use_wandb: bool = False,
                       meta_model=None,
                       multi_targets=1,
                       val_or_test="validation_loss",
                       probabilistic=False) -> float:
    """Function to compute the validation loss metrics

    :param validation_loader: The data-loader of either validation or test-data
    :type validation_loader: DataLoader
    :param model: The PyTorch forecasting model being evaluated
    :type model: torch.nn.Module
    :param epoch: The epoch where the validation/test loss is being computed.
    :type epoch: int
    :param sequence_size: The number of historical time steps passed into the model
    :type sequence_size: int
    :param criterion: The evaluation metric function
    :type criterion: Type[torch.nn.modules.loss._Loss]
    :param device: The device
    :type device: torch.device
    :param decoder_structure: Whether the model should use sequential decoding, defaults to False
    :type decoder_structure: bool, optional
    :param meta_data_model: The model to handle the meta-data, defaults to None
    :type meta_data_model: PyTorchForecast, optional
    :param use_wandb: Whether Weights and Biases is in use, defaults to False
    :type use_wandb: bool, optional
    :param meta_model: Whether the model leverages meta-data, defaults to None
    :type meta_model: bool, optional
    :param multi_targets: The number of targets the model predicts, defaults to 1
    :type multi_targets: int, optional
    :param val_or_test: Whether validation or test loss is computed, defaults to "validation_loss"
    :type val_or_test: str, optional
    :param probabilistic: Whether the model is probabilistic, defaults to False
    :type probabilistic: bool, optional
    :return: The loss of the first metric in the list.
    :rtype: float
    """
    print('Computing validation loss')
    unscaled_crit = dict.fromkeys(criterion, 0)
    scaled_crit = dict.fromkeys(criterion, 0)
    model.eval()
    output_std = None
    multi_targs1 = multi_targets
    scaler = None
    if validation_loader.dataset.no_scale:
        scaler = validation_loader.dataset
    with torch.no_grad():
        i = 0
        loss_unscaled_full = 0.0
        for src, targ in validation_loader:
            src = src if isinstance(src, list) else src.to(device)
            targ = targ if isinstance(targ, list) else targ.to(device)
            i += 1
            if decoder_structure:
                if type(model).__name__ == "SimpleTransformer":
                    targ_clone = targ.detach().clone()
                    output = greedy_decode(
                        model,
                        src,
                        targ.shape(1),
                        targ_clone,
                        device=device)(
                        :,
                        :,
                        0)
                elif type(model).__name__ == "Informer":
                    multi_targets = multi_targs1
                    filled_targ = targ(1).clone()
                    pred_len = model.pred_len
                    filled_targ(:, -pred_len:, :) = torch.zeros_like(filled_targ(:, -pred_len:, :)).float().to(device)
                    output = model(src(0).to(device), src(1).to(device), filled_targ.to(device), targ(0).to(device))
                    labels = targ(1)(:, -pred_len:, 0:multi_targets)
                    src = src(0)
                    multi_targets = False
                else:
                    output = simple_decode(model=model,
                                           src=src,
                                           max_seq_len=targ.shape[1],
                                           real_target=targ,
                                           output_len=sequence_size,
                                           multi_targets=multi_targets,
                                           probabilistic=probabilistic,
                                           scaler=scaler)
                    if probabilistic:
                        output, output_std = output[0], output[1]
                        output, output_std = output[:, :, 0], output_std[0]
                        output_dist = torch.distributions.Normal(output, output_std)
            else:
                if probabilistic:
                    output_dist = model(src.float())
                    output = output_dist.mean.detach().numpy()
                    output_std = output_dist.stddev.detach().numpy()
                else:
                    output = model(src.float())
            if multi_targets == 1:
                labels = targ[:, :, 0]
            elif multi_targets > 1:
                labels = targ[:, :, 0:multi_targets]
            validation_dataset = validation_loader.dataset
            for crit in criterion:
                if validation_dataset.scale:
                    # Should this also do loss.item() stuff?
                    if len(src.shape) == 2:
                        src = src.unsqueeze(0)
                    src1 = src[:, :, 0:multi_targets]
                    loss_unscaled_full = compute_loss(labels, output, src1, crit, validation_dataset,
                                                      probabilistic, output_std, m=multi_targets)
                    unscaled_crit[crit] += loss_unscaled_full.item() * len(labels.float())
                loss = compute_loss(labels, output, src, crit, False, probabilistic, output_std, m=multi_targets)
                scaled_crit[crit] += loss.item() * len(labels.float())
    if use_wandb:
        if loss_unscaled_full:
            scaled = {k.__class__.__name__: v / (len(validation_loader.dataset) - 1) for k, v in scaled_crit.items()}
            newD = {k.__class__.__name__: v / (len(validation_loader.dataset) - 1) for k, v in unscaled_crit.items()}
            wandb.log({'epoch': epoch,
                       val_or_test: scaled,
                       "unscaled_" + val_or_test: newD})
        else:
            scaled = {k.__class__.__name__: v / (len(validation_loader.dataset) - 1) for k, v in scaled_crit.items()}
            wandb.log({'epoch': epoch, val_or_test: scaled})
    model.train()
    return list(scaled_crit.values())[0]