How to have batch norm not forget the batch statistics it just used in PyTorch?

I am in an unusual setting where I should not use running statistics, since that would be considered cheating (e.g. in meta-learning). However, I often run a forward pass on a set of points (5, in fact) and then want to evaluate on only 1 point using the statistics from that batch, but batch norm forgets the batch statistics it just used. I've tried to hard-code the values the statistics should have, but I get strange errors (even when I borrow checks from the PyTorch source itself, like the dimension-size check).

How do I hard-code the previous batch statistics so that batch norm works on a new single data point, and then reset them for a fresh next batch?

Note: I don't want to change the batch norm layer type.
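
To make the goal concrete, here is a rough sketch of the behavior I'm after, written with the functional API purely for illustration (the `num_features = 4` and the shapes are placeholders for my use case):

import torch
import torch.nn.functional as F

num_features = 4                       # placeholder; my real layers differ
batch = torch.randn(5, num_features)   # the 5-point batch I forward first

# the statistics of that batch
mean = batch.mean(dim=0)
var = batch.var(dim=0, unbiased=False)

# what I want: normalize one new point with *those* statistics...
single = torch.randn(1, num_features)
out = F.batch_norm(single, mean, var, training=False)

# ...and then discard mean/var before the next 5-point batch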

Sample code I tried:

import torch

def set_tracking_running_stats(model):
    for attr in dir(model):
        if 'bn' in attr:  # match attributes named like 'bn1', 'bn2', ...
            target_attr = getattr(model, attr)
            target_attr.track_running_stats = True
            # overwrite the running statistics with their initial values
            target_attr.running_mean = torch.nn.Parameter(torch.zeros(target_attr.num_features), requires_grad=False)
            target_attr.running_var = torch.nn.Parameter(torch.ones(target_attr.num_features), requires_grad=False)
            target_attr.num_batches_tracked = torch.nn.Parameter(torch.tensor(0, dtype=torch.long), requires_grad=False)
            # target_attr.reset_running_stats()
    return
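
For what it's worth, in the PyTorch source `running_mean` and `running_var` are registered as buffers, not `Parameter`s, so perhaps the helper should assign plain tensors instead. A variant I'm considering (the function name is mine, and it's untested beyond this sketch):

import torch

def set_tracking_running_stats_buffers(model):  # hypothetical variant of the helper above
    for attr in dir(model):
        if 'bn' in attr:
            target_attr = getattr(model, attr)
            target_attr.track_running_stats = True
            # plain tensors stay registered as buffers, matching _BatchNorm
            target_attr.running_mean = torch.zeros(target_attr.num_features)
            target_attr.running_var = torch.ones(target_attr.num_features)
            target_attr.num_batches_tracked = torch.tensor(0, dtype=torch.long)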

My most common errors:

    raise ValueError('expected 2D or 3D input (got {}D input)'
ValueError: expected 2D or 3D input (got 1D input)

and

IndexError: Dimension out of range (expected to be in range of (-1, 0), but got 1)
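
As far as I can tell, the first error is about input shape rather than the statistics themselves: `BatchNorm1d` expects `(N, C)` or `(N, C, L)` input, so a single point has to keep its batch dimension. A minimal reproduction (assuming `BatchNorm1d`; my actual layers may differ):

import torch
import torch.nn as nn

bn = nn.BatchNorm1d(4)
bn.eval()

single = torch.randn(4)          # shape (4,): 1D input
# bn(single)                     # ValueError: expected 2D or 3D input (got 1D input)
out = bn(single.unsqueeze(0))    # shape (1, 4): batch dimension kept

I suspect the IndexError has the same root cause (a 1D tensor where a batched one is expected), but I'm not certain.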

pytorch forum: https://discuss.pytorch.org/t/how-to-use-have-batch-norm-not-forget-batch-statistics-it-just-used/103437
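
For completeness, one direction I've been experimenting with, sketched under the assumption that `momentum=1.0` makes the running stats equal exactly the last batch's stats (since the update is `new = (1 - momentum) * old + momentum * observed`):

import torch
import torch.nn as nn

# momentum=1.0 -> running stats are overwritten with each batch's stats
bn = nn.BatchNorm1d(4, momentum=1.0, track_running_stats=True)

batch = torch.randn(5, 4)
bn.train()
_ = bn(batch)             # running_mean/var now equal this batch's stats
                          # (running_var stores the unbiased variance)

bn.eval()
single = torch.randn(1, 4)
out = bn(single)          # normalized with the previous batch's statistics

bn.reset_running_stats()  # fresh stats for the next 5-point batch

This seems to do roughly what I want, but it relies on toggling train/eval modes, which feels fragile; I'd prefer a cleaner way.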