Uncertainty propagation in Deep Neural Networks using Ensembles


This post is mainly a reproduction of the paper Simple and Scalable Predictive Uncertainty Estimation using Deep Ensembles (DeepMind, NIPS 2017).

Deep neural networks (NNs) are powerful black-box predictors that have recently achieved impressive performance on a wide spectrum of tasks. However, quantifying their predictive uncertainty is still challenging. Bayesian approximation and ensemble learning are two of the most widely used uncertainty estimation techniques in the literature. In this post I'll discuss the implementation of deep ensembles, one of the simplest methods (which still gets great results) for estimating the predictive uncertainty of a deep neural network.


The approach combines two ingredients: a probabilistic output that captures the inherent ambiguity in the outputs for a given input (aleatoric uncertainty), and an ensemble of networks that captures model uncertainty (epistemic uncertainty). To that aim, we use a Gaussian parameterization of the form $p_\theta(y|z)=\mathcal{N}(\mu_\theta(z),\Sigma_\theta(z))$, where $y$ is the output and $z$ is the input. The mean $\mu_\theta(z)$ and diagonal covariance $\Sigma_\theta(z)$ are given by a neural network. With a deterministic NN, we simply output the mean value $\mu_\theta(z)$ and use the mean squared error as the loss function during training: $\mathcal{L}(\theta)= \sum_{n}(y_n-\mu_\theta(z_n))^{2}$. For a probabilistic NN, we treat the output $y$ as a sample from a Gaussian distribution whose mean and variance depend on the input $z$, and train with the negative log-likelihood loss: \begin{equation} \mathcal{L}(\theta)=-\log{p_\theta(y|z)}=\frac{\log{\sigma^{2}_\theta(z)}}{2}+\frac{(y-\mu_\theta(z))^{2}}{2\sigma^{2}_\theta(z)}+ \text{const.} \end{equation}
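
To make the loss concrete, here is a minimal sketch of this Gaussian negative log-likelihood in PyTorch (the tensors mean, var and target are placeholders for illustration); the same expression is used later inside the model's update step:

import torch

def gaussian_nll(mean, var, target):
    # negative log-likelihood of a diagonal Gaussian, dropping the constant term
    return torch.mean(0.5 * torch.log(var) + 0.5 * (target - mean).pow(2) / var)

# toy check with placeholder predictions: zero mean, unit variance
mean = torch.zeros(4, 1)
var = torch.ones(4, 1)
target = torch.randn(4, 1)
print(gaussian_nll(mean, var, target))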

During training we initialize each model $p_{\theta_m}$ with different random parameters and train it on a different batch of the data. The ensemble of $M$ models is then treated as a uniformly weighted Gaussian mixture, which we further approximate as a single Gaussian: \begin{equation} \mathcal{N}(\mu(z),\sigma^{2}(z)) \approx M^{-1}\sum_{m}{p_{\theta_m} (y|z)} \end{equation} where the mean and variance of the mixture are given by: \begin{equation} \mu(z) = M^{-1}\sum_{m}\mu_{\theta_m}(z) \end{equation} \begin{equation} \sigma^{2}(z) = M^{-1}\sum_{m}\left(\sigma^{2}_{\theta_m}(z)+\mu^{2}_{\theta_m}(z)\right) - \mu^{2}(z) \end{equation}
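
This combination rule translates directly into a few lines of NumPy. Below is a small sketch with made-up per-model predictions (member_means and member_vars are placeholder arrays, one row per ensemble member); the same computation appears later in the ensemble's forward method:

import numpy as np

# per-member predictive means and variances, shape (M, N) for M models and N inputs
member_means = np.array([[0.9, 2.1], [1.1, 1.9], [1.0, 2.0]])
member_vars = np.array([[0.10, 0.20], [0.12, 0.18], [0.11, 0.22]])

# mixture mean: average of the member means
mix_mean = member_means.mean(axis=0)

# mixture variance: average of (variance + mean^2) minus the squared mixture mean
mix_var = (member_vars + member_means ** 2).mean(axis=0) - mix_mean ** 2

print(mix_mean, mix_var)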


Alright, let's dig into the implementation. First, we import all the necessary packages:

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.distributions
import torch.nn.functional
import torch.utils.data
import torch.nn as nn

Then we implement a simple multilayer perceptron (MLP):

class MLP(nn.Module):
    def __init__(self,
        input_dim,
        output_dim,
        n_layers,
        size,
        device,
        deterministic,
        dropout_p,
        activation = nn.Tanh()):
        nn.Module.__init__(self)

        self.deterministic = deterministic
        self.output_dim = output_dim

        if not self.deterministic:
            self.output_dim *= 2  # the probabilistic network outputs both a mean and a variance

        # network architecture
        self.mlp = nn.ModuleList()
        self.mlp.append(nn.Linear(input_dim, size)) #first hidden layer
        self.mlp.append(activation)
        # self.mlp.append((nn.Dropout(p=dropout_p)))

        for h in range(n_layers - 1): #additional hidden layers
            self.mlp.append(nn.Linear(size, size))
            self.mlp.append(activation)
            self.mlp.append((nn.Dropout(p=dropout_p)))

        self.mlp.append(nn.Linear(size, self.output_dim)) #output layer, no activation function

        self.to(device)

    def forward(self, x):
        for layer in self.mlp:
            x = layer(x)
        if self.deterministic:
            return x
        else:
            # split the output into a mean and a raw variance for each output dimension
            mean, variance = torch.split(x, int(self.output_dim/2), dim=1)
            # softplus enforces a positive variance; the small constant adds numerical stability
            variance = torch.nn.functional.softplus(variance) + 1e-6
            return (mean, variance)

    def save(self, filepath):
        torch.save(self.state_dict(), filepath, _use_new_zipfile_serialization=False)

    def restore(self, filepath):
        self.load_state_dict(torch.load(filepath))
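
As a quick sanity check on the shapes (the hyperparameter values here are arbitrary and only for illustration), a probabilistic MLP returns a mean and a variance per output dimension:

mlp = MLP(input_dim=1, output_dim=1, n_layers=3, size=100,
          device="cpu", deterministic=False, dropout_p=0.2)

x_batch = torch.randn(8, 1)        # batch of 8 one-dimensional inputs
mean, variance = mlp(x_batch)
print(mean.shape, variance.shape)  # torch.Size([8, 1]) torch.Size([8, 1])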

Then we wrap the MLP in a model class that handles predictions, training updates, and evaluation:

class Model:
    def __init__(self, out_dim, ob_dim, n_layers, size, device, deterministic, optimizer,dropout_p, learning_rate = 0.001):
        # init vars
        self.device = device
        self.deterministic = deterministic

        self.mlp = MLP(input_dim = ob_dim,
                              output_dim = out_dim,
                              n_layers = n_layers,
                              size = size,
                              device = self.device,
                              deterministic = deterministic,
                              dropout_p= dropout_p)


        self.optimizer = getattr(torch.optim, optimizer)(self.mlp.parameters(), lr = learning_rate, weight_decay=1e-5)

    #############################

    def get_prediction(self, obs):

        if len(obs.shape) == 1:
            obs = np.squeeze(obs)[None]  # add a batch dimension for a single observation

        obs = torch.Tensor(obs).to(self.device)

        if self.deterministic:
            output = self.mlp(obs).cpu().detach().numpy()
            return output
        else:
            out = self.mlp(obs)
            output_mean, output_var = out[0].cpu().detach().numpy(), out[1].cpu().detach().numpy()
            return output_mean, output_var

    def update(self, observations, true_output):

        pred_output = self.mlp(torch.Tensor(observations).to(self.device))
        true_output = torch.Tensor(true_output).to(self.device)

        if self.deterministic:
            loss = nn.functional.mse_loss(true_output, pred_output)
        else:
            # Negative log-likelihood loss function.
            mean = pred_output[0]
            var = pred_output[1]
            loss = torch.mean(0.5*torch.log(var) + 0.5*((true_output - mean).pow(2))/var)#.sum()

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        return loss.item()

    def eval(self, observations, true_output):

        with torch.no_grad():
            pred_output = self.mlp(torch.Tensor(observations).to(self.device))

        true_output = torch.Tensor(true_output).to(self.device)

        if self.deterministic:
            loss = nn.functional.mse_loss(true_output, pred_output)
        else:
            # Negative log-likelihood loss function.
            mean = pred_output[0]
            var = pred_output[1]
            loss = torch.mean(0.5*torch.log(var) + 0.5*((true_output - mean).pow(2))/var)#.sum()

        return loss.item()

    def save_model(self,path):
        self.mlp.save(path)

    def load_model(self,path):
        self.mlp.restore(path)
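
A quick smoke test for a single model (again with arbitrary, illustrative hyperparameters): training on a small random dataset should drive the negative log-likelihood down, and get_prediction should return a mean and a variance per input:

model = Model(out_dim=1, ob_dim=1, n_layers=2, size=32, device="cpu",
              deterministic=False, optimizer="Adam", dropout_p=0.0)

xs = np.random.uniform(-1, 1, size=(64, 1))
ys = xs ** 3

for step in range(200):
    loss = model.update(xs, ys)
print("final training loss:", loss)

mean, var = model.get_prediction(xs[:5])
print(mean.shape, var.shape)  # (5, 1) (5, 1)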

And finally the Ensemble class:

class ModelEnsemble():
    def __init__(self, params):

        self.params = params
        self.ensemble_size = self.params['ensemble_size']

        self.models = []
        for i in range(self.ensemble_size):
            model = Model(self.params['out_dim'],
                            self.params['obs_dim'],
                            self.params['n_layers'],
                            self.params['size'],
                            self.params['device'],
                            self.params['deterministic'],
                            self.params['optimizer'],
                            self.params['dropout_p'],
                            self.params['learning_rate'])

            self.models.append(model)

    def forward(self, obs):
        means = []
        variances = []
        outputs = []

        if self.models[0].deterministic:
            # ensemble of deterministic networks: empirical mean and variance of the predictions
            for model in self.models:
                output = model.get_prediction(obs)
                outputs.append(output)
            mean = np.mean(outputs, axis=0)
            variance = np.var(outputs, axis=0, ddof=1)
        else:
            # ensemble of probabilistic networks: combine them as a Gaussian mixture
            for model in self.models:
                mean_m, var_m = model.get_prediction(obs)
                means.append(mean_m)
                variances.append(var_m)
            mean = np.mean(means, axis=0)
            variance = np.mean(np.array(variances) + np.power(means, 2), axis=0) - np.power(mean, 2)

        return mean, variance

    def train(self, obs, output):

        # each model in the ensemble is updated on a different slice of the batch
        losses = []
        num_data = obs.shape[0]
        num_data_per_ens = int(num_data / self.ensemble_size)

        start = 0
        for model in self.models:
            # select which datapoints to use for this model of the ensemble
            finish = start + num_data_per_ens

            observations = obs[start:finish]
            outputs = output[start:finish]

            # use these datapoints to update one of the models
            loss = model.update(observations, outputs)
            losses.append(loss)

            start = finish

        avg_loss = np.mean(losses)

        return avg_loss

    def eval(self, obs, output):

        losses = []
        num_data = obs.shape[0]
        num_data_per_ens = int(num_data / self.ensemble_size)

        start = 0
        for model in self.models:
            # select which datapoints to use for this model of the ensemble
            finish = start + num_data_per_ens

            observations = obs[start:finish]
            outputs = output[start:finish]

            # use these datapoints to evaluate one of the models
            loss = model.eval(observations, outputs)
            losses.append(loss)

            start = finish

        avg_loss = np.mean(losses)

        return avg_loss

    def save_models(self, path):
        for i, model in enumerate(self.models):
            model.save_model(path + str(i))

    def load_models(self, path):

        self.models = []
        for i in range(self.ensemble_size):
            model = Model(self.params['out_dim'],
                            self.params['obs_dim'],
                            self.params['n_layers'],
                            self.params['size'],
                            self.params['device'],
                            self.params['deterministic'],
                            self.params['optimizer'],
                            self.params['dropout_p'],
                            self.params['learning_rate'])
            model.load_model(path + str(i))
            self.models.append(model)

Toy Example

Let's consider the function $y = x^3$ as the ground truth.

The dataset is generated by sampling from $y = x^3+\epsilon$, where $\epsilon \sim \mathcal{N}(0,3^2)$.

X = torch.tensor([[np.random.uniform(-4, 4)] for i in range(20)])
Y = torch.tensor([[x.item()**3 + np.random.normal(0, 3)] for x in X])
x = np.linspace(-6, 6, 100).reshape(100, 1)
y = x**3
plt.plot(x, y, 'b-', label='ground truth: $y=x^3$')
plt.plot(X.numpy(),Y.numpy(),'or', label='data points')
plt.grid()
plt.xlabel('x')
plt.ylabel('y')
plt.legend()
computation_graph_args = {
    'out_dim': 1,
    'obs_dim': 1,
    'ensemble_size': 3,
    'n_layers': 3,
    'size': 100,
    'learning_rate': 0.001,
    'device': "cpu",
    'deterministic': False,
    'batch_size': 32 * 3,
    'display_step': 100,
    'epoch': 10,
    'optimizer': "Adam",
    'dropout_p': 0.2,
}
# Init the ensemble (Gaussian mixture model)
GE = ModelEnsemble(computation_graph_args)


epochs = computation_graph_args['epoch']
display_step = computation_graph_args['display_step']
batch_size = computation_graph_args['batch_size']

# wrap the toy dataset in a simple data loader
dataset = torch.utils.data.TensorDataset(X.float(), Y.float())
train_generator = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

COSTS = []
epoch_cost = []

for epoch in range(epochs):
    print("\nStart of epoch %d" % (epoch,))

    for batch_idx, (batch_x, batch_y) in enumerate(train_generator):

        batch_x, batch_y = batch_x.numpy(), batch_y.numpy()

        cost = GE.train(batch_x, batch_y)
        COSTS.append(cost)

    mean_train_loss = np.mean(COSTS[-len(train_generator):])
    print('Epoch train loss : ' + str(mean_train_loss))

    epoch_cost.append(mean_train_loss)

print("Optimization Finished!")

means = []
variances = []
for i, model in enumerate(GE.models):
    mean, var = model.get_prediction(x)  # per-model predictive mean and variance
    means.append(mean)
    variances.append(var)
    std = np.sqrt(var)
    plt.plot(x, mean, label='model ' + str(i+1) + ' (NLL)', alpha=0.5)
    plt.fill_between(x.reshape(100,), (mean-std).reshape(100,), (mean+std).reshape(100,), alpha=0.1)
plt.plot(x, y, label='ground truth $y=x^3$', color='b')
plt.plot(X.numpy(), Y.numpy(), 'or', label='data points')
plt.title('Outputs of the networks in the ensemble')
plt.xlabel('x')
plt.ylabel('y')
plt.legend()
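
We can also visualize the combined predictive distribution of the whole ensemble, i.e. the Gaussian mixture approximation from above, using the ensemble's forward method:

ens_mean, ens_var = GE.forward(x)
ens_std = np.sqrt(ens_var)

plt.figure()
plt.plot(x, y, 'b-', label='ground truth $y=x^3$')
plt.plot(X.numpy(), Y.numpy(), 'or', label='data points')
plt.plot(x, ens_mean, 'g-', label='ensemble mean')
plt.fill_between(x.reshape(100,), (ens_mean - ens_std).reshape(100,),
                 (ens_mean + ens_std).reshape(100,), alpha=0.2, label=r'$\pm 1$ std')
plt.title('Combined prediction of the ensemble')
plt.xlabel('x')
plt.ylabel('y')
plt.legend()
plt.show()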
