Parameter containers.

Many models depend on several parameters of different types. In this notebook, we demonstrate flattening and folding “containers”, which are patterns containing other patterns.

Parameter dictionaries.

import autograd
import autograd.numpy as np
import example_utils
import matplotlib.pyplot as plt
import paragami
import time

Let’s consider a multivariate normal example. Define some parameters and draw some data.

dim = 2
num_obs = 1000

mean_true = np.random.random(dim)
cov_true = np.random.random((dim, dim))
cov_true = 0.1 * np.eye(dim) + np.full((dim, dim), 1.0)

x = np.random.multivariate_normal(mean=mean_true, cov=cov_true, size=(num_obs, ))
plt.plot(x[:, 0], x[:, 1], 'k.')

A multivariate normal distribution depends on two parameters – a mean and a positive definite covariance matrix. In Python, we might store these parameters in a dictionary with a member for the mean and a member for the covariance.

This can be represented with a pattern dictionary, i.e., a paragami.PatternDict pattern. Each member of a pattern dictionary must itself be a pattern.

mvn_pattern = paragami.PatternDict(free_default=True)
mvn_pattern['mean'] = paragami.NumericVectorPattern(length=dim)
mvn_pattern['cov'] = paragami.PSDSymmetricMatrixPattern(size=dim)

Flattening and folding work with pattern dictionaries just as with ordinary patterns. Note that folded pattern dictionaries are OrderedDict by default.

true_mvn_par = dict()
true_mvn_par['mean'] = mean_true
true_mvn_par['cov'] = cov_true

print('\nA dictionary of MVN parameters:\n{}'.format(
    true_mvn_par))

mvn_par_free = mvn_pattern.flatten(true_mvn_par)
print('\nA flat representation:\n{}'.format(
    mvn_par_free))

print('\nFolding recovers the original parameters:\n{}'.format(
    mvn_pattern.fold(mvn_par_free)))

A dictionary of MVN parameters:
{'cov': array([[1.1, 1. ],
       [1. , 1.1]]), 'mean': array([0.87367236, 0.21280422])}

A flat representation:
[ 0.87367236  0.21280422  0.04765509  0.95346259 -0.82797896]

Folding recovers the original parameters:
OrderedDict([('mean', array([0.87367236, 0.21280422])), ('cov', array([[1.1, 1. ],
       [1. , 1.1]]))])
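
The last three entries of the flat representation printed above can be reproduced by hand. The sketch below assumes that the unconstrained form of a positive definite matrix is the lower triangle of its Cholesky factor, with the diagonal replaced by its logarithm. This is inferred from the printed numbers, not taken from paragami's source. The helper names psd_to_free and free_to_psd are hypothetical.

```python
import numpy as np

def psd_to_free(mat):
    """Map a positive definite matrix to an unconstrained vector (log-Cholesky)."""
    size = mat.shape[0]
    chol = np.linalg.cholesky(mat)            # lower-triangular factor
    diag = np.diag_indices(size)
    chol[diag] = np.log(chol[diag])           # the log removes the positivity constraint
    return chol[np.tril_indices(size)]        # lower triangle in row-major order

def free_to_psd(free, size):
    """Invert psd_to_free: rebuild the Cholesky factor, then the matrix."""
    chol = np.zeros((size, size))
    chol[np.tril_indices(size)] = free
    diag = np.diag_indices(size)
    chol[diag] = np.exp(chol[diag])
    return chol @ chol.T

cov = np.array([[1.1, 1.0], [1.0, 1.1]])
cov_free = psd_to_free(cov)
print(cov_free)
print(free_to_psd(cov_free, 2))
```

Under this assumption, any real vector of length 3 folds to a valid 2 x 2 covariance matrix, which is what makes the free representation suitable for unconstrained optimization.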

Parameter dictionaries are particularly convenient for optimization problems involving multiple parameters. A good working style is to implement the loss function using named arguments and then wrap it using a lambda function.

To illustrate this, let us use get_normal_log_prob from example_utils.

# ``example_utils.get_normal_log_prob`` returns the log probability of
# each datapoint x up to a constant.
def get_loss(x, sigma, mu):
    return -1 * np.sum(
        example_utils.get_normal_log_prob(x, sigma, mu))

get_free_loss = paragami.FlattenFunctionInput(
    lambda mvn_par: get_loss(x=x, sigma=mvn_par['cov'], mu=mvn_par['mean']),
    patterns=mvn_pattern, free=True)

print('Free loss:\t{}'.format(get_free_loss(mvn_par_free)))
print('Original loss:\t{}'.format(get_loss(x, true_mvn_par['cov'], true_mvn_par['mean'])))
Free loss:      219.8143711666299
Original loss:  219.8143711666299
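
The wrapping step can be pictured as "fold, then call". The following is a minimal sketch of that idea using only numpy. It is not paragami's implementation: the toy fold and flatten below work on the constrained values directly, and normal_neg_log_lik is a hypothetical stand-in for the loss built on example_utils.

```python
import numpy as np

def normal_neg_log_lik(x, sigma, mu):
    """Negative Gaussian log likelihood of the rows of x, up to a constant."""
    resid = x - mu
    quad = np.einsum('ni,ij,nj->', resid, np.linalg.inv(sigma), resid)
    _, logdet = np.linalg.slogdet(sigma)
    return 0.5 * quad + 0.5 * x.shape[0] * logdet

def flatten(par):
    """Toy flatten: the mean entries followed by the covariance entries."""
    return np.concatenate([par['mean'].ravel(), par['cov'].ravel()])

def fold(flat, dim):
    """Toy fold: inverse of flatten for a mean of length dim."""
    return {'mean': flat[:dim], 'cov': flat[dim:].reshape(dim, dim)}

def flatten_function_input(fun, fold_fun):
    """Return a function of a flat vector that folds its input, then calls fun."""
    return lambda flat: fun(fold_fun(flat))

rng = np.random.default_rng(0)
dim = 2
par = {'mean': np.array([0.5, -0.2]),
       'cov': np.array([[1.1, 1.0], [1.0, 1.1]])}
x_demo = rng.multivariate_normal(par['mean'], par['cov'], size=100)

flat_loss = flatten_function_input(
    lambda p: normal_neg_log_lik(x=x_demo, sigma=p['cov'], mu=p['mean']),
    lambda flat: fold(flat, dim))

print('Flat loss:\t{}'.format(flat_loss(flatten(par))))
print('Original loss:\t{}'.format(
    normal_neg_log_lik(x_demo, par['cov'], par['mean'])))
```

Writing the loss with named arguments keeps it readable, while the wrapper presents the single-vector interface that optimizers and autograd expect.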

As with other parameters, autograd works with parameter dictionaries.

get_free_loss_grad = autograd.grad(get_free_loss)
print(get_free_loss_grad(mvn_par_free))
[-59.66882221  78.5304     -24.49767189 -12.36127117  36.00719091]

Pattern dictionaries containing pattern dictionaries.

Pattern dictionaries can contain pattern dictionaries (and so on).

mvns_pattern = paragami.PatternDict(free_default=True)
mvns_pattern['mvn1'] = mvn_pattern
mvns_pattern['mvn2'] = mvn_pattern
mvns_pattern['ez'] = paragami.SimplexArrayPattern(array_shape=(1, ), simplex_size=2)

mvns_par = dict()
mvns_par['mvn1'] = true_mvn_par
mvns_par['mvn2'] = mvn_pattern.random()
mvns_par['ez'] = np.array([[0.3, 0.7]])

print('Folded mvns_par:\n{}'.format(mvns_par))
print('\nFree mvns_par:\n{}'.format(mvns_pattern.flatten(mvns_par)))
Folded mvns_par:
{'mvn1': {'cov': array([[1.1, 1. ],
       [1. , 1.1]]), 'mean': array([0.87367236, 0.21280422])}, 'mvn2': OrderedDict([('mean', array([0.9666736 , 0.24848757])), ('cov', array([[4.77849291, 1.56784344],
       [1.56784344, 1.84355085]]))]), 'ez': array([[0.3, 0.7]])}

Free mvns_par:
[ 0.87367236  0.21280422  0.04765509  0.95346259 -0.82797896  0.9666736
  0.24848757  0.7820626   0.71722798  0.14226413  0.84729786]
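
In the output above, the first five entries of the free vector coincide with the flat representation of the single MVN printed earlier: the flat vector of a nested dictionary is the concatenation of its members' flat vectors. The sketch below illustrates that recursion on plain nested dicts of numpy arrays. It is a toy under stated assumptions: flatten_nested and fold_nested are hypothetical helpers that operate on the constrained values, so the lengths differ from the free lengths shown above.

```python
import numpy as np

def flatten_nested(par):
    """Concatenate the leaves of a nested dict in insertion order."""
    if isinstance(par, dict):
        return np.concatenate([flatten_nested(v) for v in par.values()])
    return np.asarray(par, dtype=float).ravel()

def fold_nested(flat, template):
    """Rebuild a nested dict with the structure of template from a flat vector."""
    def _fold(offset, node):
        if isinstance(node, dict):
            out = {}
            for key, child in node.items():
                out[key], offset = _fold(offset, child)
            return out, offset
        shape = np.shape(node)
        size = int(np.prod(shape))
        return flat[offset:offset + size].reshape(shape), offset + size

    folded, used = _fold(0, template)
    if used != flat.size:
        raise ValueError('The flat vector has the wrong length.')
    return folded

inner = {'mean': np.array([1.0, 2.0]), 'cov': np.eye(2)}
outer = {'mvn1': inner,
         'mvn2': {'mean': np.zeros(2), 'cov': 2 * np.eye(2)},
         'ez': np.array([[0.3, 0.7]])}

flat = flatten_nested(outer)
print(flat)
print(fold_nested(flat, outer)['mvn2']['cov'])
```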

Members of a pattern dictionary are just patterns, and can be used directly.

mvn_par_free = mvns_pattern['mvn1'].flatten(true_mvn_par)
print('\nA flat representation of true_mvn_par:\n{}'.format(
    mvn_par_free))

A flat representation of true_mvn_par:
[ 0.87367236  0.21280422  0.04765509  0.95346259 -0.82797896]

Locking parameter dictionaries.

The meaning of a parameter dictionary changes as you add or assign elements.

example_pattern = paragami.PatternDict(free_default=True)

example_pattern['par1'] = paragami.NumericVectorPattern(length=2)
print('The flat length of example_pattern with par1:\t\t\t{}'.format(
    example_pattern.flat_length()))

example_pattern['par2'] = paragami.NumericVectorPattern(length=3)
print('The flat length of example_pattern with par1 and par2:\t\t{}'.format(
    example_pattern.flat_length()))

example_pattern['par1'] = paragami.NumericVectorPattern(length=10)
print('The flat length of example_pattern with new par1 and par2:\t{}'.format(
    example_pattern.flat_length()))

The flat length of example_pattern with par1:                   2
The flat length of example_pattern with par1 and par2:          5
The flat length of example_pattern with new par1 and par2:      13

Sometimes you want to make sure the meaning of a parameter dictionary stays fixed. To prevent a pattern dictionary from having elements added or reassigned, you can lock() it.

example_pattern.lock()

try:
    example_pattern['par3'] = paragami.NumericVectorPattern(length=4)
except ValueError as err:
    print('Adding a new pattern failed with the following error:\n{}'.format(err))

try:
    example_pattern['par1'] = paragami.NumericVectorPattern(length=4)
except ValueError as err:
    print('\nChanging an existing pattern failed with the following error:\n{}'.format(err))
Adding a new pattern failed with the following error:
The dictionary is locked, and its values cannot be changed.

Changing an existing pattern failed with the following error:
The dictionary is locked, and its values cannot be changed.
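
For intuition, here is one way such locking behavior could be written in plain Python. This is a sketch only: LockableDict is a hypothetical class, not paragami's code, and it guards item assignment and deletion only.

```python
class LockableDict(dict):
    """Hypothetical dict whose items cannot be set or deleted once locked."""

    _message = 'The dictionary is locked, and its values cannot be changed.'

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._locked = False

    def lock(self):
        """Freeze the current set of keys and values."""
        self._locked = True

    def __setitem__(self, key, value):
        # Covers both adding a new key and reassigning an existing one.
        if self._locked:
            raise ValueError(self._message)
        super().__setitem__(key, value)

    def __delitem__(self, key):
        if self._locked:
            raise ValueError(self._message)
        super().__delitem__(key)

    # Note: dict.update() and dict.setdefault() bypass __setitem__,
    # so a complete implementation would need to guard them as well.

params = LockableDict()
params['par1'] = 2
params.lock()

try:
    params['par2'] = 3
except ValueError as err:
    message = str(err)
    print('Adding a new entry failed with the following error:\n{}'.format(err))
```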

Parameter arrays.

Sometimes it is useful to have arrays of patterns. A classic use case is mixture distributions. Suppose that, for some \(K\) and probabilities \(\pi_{k}\), \(k = 1, \ldots, K\), that sum to one,

\[\begin{align} P(y_n | \pi, \mu_1,...,\mu_K, \Sigma_1,...,\Sigma_K) &= \sum_{k=1}^K \pi_{k} \mathcal{N}\left(y_n | \mu_k, \Sigma_k\right). \end{align}\]

Then the random variable \(y_n\) is distributed according to a mixture of normals. Let us define parameters and draw some data.

num_components = 4
prob_true = np.arange(1, num_components + 1)
prob_true = prob_true / np.sum(prob_true)

k_true = np.random.choice(range(num_components), p=prob_true, size=num_obs)

means_true = np.array([
    np.full(dim, float(k)) for k in range(num_components) ])

covs_true = np.array([
    0.1 * (k + 1) * np.eye(dim) for k in range(num_components) ])

y = np.array([
    np.random.multivariate_normal(
        means_true[k_true[n], :],
        covs_true[k_true[n], :, :]) for n in range(num_obs) ])

for k in range(num_components):
    k_rows = k_true == k
    plt.plot(y[k_rows, 0], y[k_rows, 1], '.')
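
The mixture density defined above can be evaluated directly with numpy. The sketch below computes its logarithm with the log-sum-exp trick for numerical stability. The function names and the small demonstration values are illustrative assumptions, not part of paragami or example_utils.

```python
import numpy as np

def mvn_log_pdf(y, mean, cov):
    """Log density of N(mean, cov) at each row of y."""
    dim = mean.shape[0]
    resid = y - mean
    solved = np.linalg.solve(cov, resid.T).T      # rows are cov^{-1} (y_n - mean)
    quad = np.sum(resid * solved, axis=1)
    _, logdet = np.linalg.slogdet(cov)
    return -0.5 * (quad + logdet + dim * np.log(2 * np.pi))

def mixture_log_pdf(y, probs, means, covs):
    """Log of sum_k probs[k] * N(y_n | means[k], covs[k]) for each row of y."""
    log_terms = np.stack([
        np.log(probs[k]) + mvn_log_pdf(y, means[k], covs[k])
        for k in range(len(probs))])              # shape (K, N)
    top = np.max(log_terms, axis=0)               # log-sum-exp shift
    return top + np.log(np.sum(np.exp(log_terms - top), axis=0))

demo_probs = np.array([0.25, 0.75])
demo_means = np.array([[0.0, 0.0], [1.0, 1.0]])
demo_covs = np.array([0.1 * np.eye(2), 0.2 * np.eye(2)])
y_demo = np.array([[0.0, 0.0], [1.0, 1.0]])

print(mixture_log_pdf(y_demo, demo_probs, demo_means, demo_covs))
```

Summing the negative of this quantity over observations gives a loss in the parameters \(\pi\), \(\mu_k\), and \(\Sigma_k\).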

To define parameters for the mixture, we have defined arrays with leading dimension num_components containing the means and covariances. The means can be represented with an ordinary NumericArrayPattern. But to represent the array of covariances with paragami, we can use pattern arrays, i.e., paragami.PatternArray.

mix_pattern = paragami.PatternDict(free_default=True)
mix_pattern['prob'] = paragami.SimplexArrayPattern(
    array_shape=(1, ), simplex_size=num_components)
mix_pattern['means'] = paragami.NumericArrayPattern(shape=(num_components, dim))
mix_pattern['covs'] = paragami.PatternArray(
    array_shape=(num_components, ),
    base_pattern=paragami.PSDSymmetricMatrixPattern(size=dim))

true_mix_par = dict()
true_mix_par['means'] = means_true
true_mix_par['covs'] = covs_true
true_mix_par['prob'] = np.expand_dims(prob_true, axis=0)

true_mix_free = mix_pattern.flatten(true_mix_par)
print('Folding recovers the true mixture parameters:\n{}'.format(
    mix_pattern.fold(true_mix_free)))
Folding recovers the true mixture parameters:
OrderedDict([('prob', array([[0.1, 0.2, 0.3, 0.4]])), ('means', array([[0., 0.],
       [1., 1.],
       [2., 2.],
       [3., 3.]])), ('covs', array([[[0.1, 0. ],
        [0. , 0.1]],

       [[0.2, 0. ],
        [0. , 0.2]],

       [[0.3, 0. ],
        [0. , 0.3]],

       [[0.4, 0. ],
        [0. , 0.4]]]))])

Pattern arrays have some limitations. For one, they can only contain patterns whose folded values are numeric arrays (numpy.ndarray). This means you cannot have arrays of pattern dictionaries.

example_pattern = paragami.PatternDict()
example_pattern['a'] = paragami.NumericVectorPattern(length=2)

try:
    paragami.PatternArray(array_shape=(2, ), base_pattern=example_pattern)
except NotImplementedError as err:
    print('Attempting to create a PatternArray of PatternDicts failed with the error:\n{}'.format(
        err))
Attempting to create a PatternArray of PatternDicts failed with the error:
PatternArray does not support patterns whose folded values are not numpy.ndarray types.

Also, under the hood, PatternArray types fold and flatten using a for loop over the array elements. For this reason, it is more efficient to use a different type if possible.

In particular, a single SimplexArrayPattern will be more efficient than a PatternArray of SimplexArrayPattern, as the following example shows.

import time

array_shape = (50, 10)
test_array = paragami.PatternArray(
    array_shape=array_shape,
    base_pattern=paragami.SimplexArrayPattern(array_shape=(1, ), simplex_size=5))

test_simplex = paragami.SimplexArrayPattern(array_shape=array_shape, simplex_size=5)

simplex_val = test_simplex.random()
simplex_array_val = np.expand_dims(simplex_val, axis=2)

test_times = 10

simplex_time = time.time()
for i in range(test_times):
    test_simplex.flatten(simplex_val, free=True)
simplex_time = time.time() - simplex_time

array_time = time.time()
for i in range(test_times):
    test_array.flatten(simplex_array_val, free=True)
array_time = time.time() - array_time

print('Array time:\t{}\nSimplex time:\t{}'.format(array_time, simplex_time))
Array time:     0.30515265464782715
Simplex time:   0.0015544891357421875
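
The gap comes from looping in Python rather than from the transform itself. The sketch below shows the same contrast with numpy alone, assuming a log-ratio map for the simplex (free value log(p_k / p_0) for k > 0). That assumption is consistent with the value 0.84729786 printed earlier for ez = [0.3, 0.7], but it is not taken from paragami's source, and simplex_to_free is a hypothetical helper.

```python
import numpy as np

def simplex_to_free(p):
    """Log-ratio map along the last axis: log(p[..., 1:] / p[..., :1])."""
    return np.log(p[..., 1:]) - np.log(p[..., :1])

rng = np.random.default_rng(0)
probs = rng.dirichlet(np.ones(5), size=(50, 10))   # shape (50, 10, 5)

# One vectorized call over the whole array.
vectorized = simplex_to_free(probs)

# The same computation, one simplex at a time.
looped = np.empty((50, 10, 4))
for i in range(50):
    for j in range(10):
        looped[i, j] = simplex_to_free(probs[i, j])

print(np.allclose(vectorized, looped))
print(simplex_to_free(np.array([0.3, 0.7])))
```

Both versions return the same values. The looped version makes 500 small calls where the vectorized version makes one, which is the kind of overhead the timings above reflect.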