.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "intermediate/ensembling.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        Click :ref:`here <sphx_glr_download_intermediate_ensembling.py>`
        to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_intermediate_ensembling.py:


Model ensembling
================

This tutorial illustrates how to vectorize model ensembling using ``torch.vmap``.

What is model ensembling?
-------------------------
Model ensembling combines the predictions from multiple models together.
Traditionally this is done by running each model on some inputs separately
and then combining the predictions. However, if you're running models with
the same architecture, then it may be possible to combine them together
using ``torch.vmap``. ``vmap`` is a function transform that maps a function
across a dimension of its input tensors. One of its uses is replacing Python
for-loops with a single vectorized call, which is usually faster.

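As a minimal illustration of that idea, separate from the ensembling code, the sketch below replaces a Python loop over ``torch.dot`` calls with a single ``torch.vmap`` call. The tensors ``a`` and ``b`` and their sizes are arbitrary choices made for this example.

.. code-block:: python

    import torch

    # A batch of 5 pairs of vectors, each of length 3 (arbitrary sizes).
    a = torch.randn(5, 3)
    b = torch.randn(5, 3)

    # For-loop version: one torch.dot call per pair.
    loop_result = torch.stack([torch.dot(x, y) for x, y in zip(a, b)])

    # vmap version: torch.dot only accepts 1-D tensors, but vmap maps it over
    # dimension 0 of both inputs, so the whole batch is handled in one call.
    vmap_result = torch.vmap(torch.dot)(a, b)

    print(vmap_result.shape)  # torch.Size([5])
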
Let's demonstrate how to do this using an ensemble of simple MLPs.

.. note::

   This tutorial requires PyTorch 2.0.0 or later.

.. GENERATED FROM PYTHON SOURCE LINES 24-47

.. code-block:: default


    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    torch.manual_seed(0)

    # Here's a simple MLP
    class SimpleMLP(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc1 = nn.Linear(784, 128)
            self.fc2 = nn.Linear(128, 128)
            self.fc3 = nn.Linear(128, 10)

        def forward(self, x):
            x = x.flatten(1)
            x = self.fc1(x)
            x = F.relu(x)
            x = self.fc2(x)
            x = F.relu(x)
            x = self.fc3(x)
            return x


.. GENERATED FROM PYTHON SOURCE LINES 48-52

Let’s generate some dummy data and pretend that we’re working with an
MNIST dataset. The dummy images are 28 by 28 with a single channel, and we
create 100 minibatches of size 64. Furthermore, let’s say we want to combine
the predictions from 10 different models.

.. GENERATED FROM PYTHON SOURCE LINES 52-61

.. code-block:: default


    # Fall back to the CPU when no CUDA device is available.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    num_models = 10

    data = torch.randn(100, 64, 1, 28, 28, device=device)
    targets = torch.randint(10, (6400,), device=device)

    models = [SimpleMLP().to(device) for _ in range(num_models)]

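The shapes are easy to get wrong, so here is a small CPU-only sanity check. It recreates ``data`` on the CPU rather than on ``device`` (an assumption made only for this illustration) and mirrors the ``x.flatten(1)`` call in ``SimpleMLP.forward``.

.. code-block:: python

    import torch

    num_models = 10
    # Same layout as above: 100 minibatches of 64 single-channel 28x28 images.
    data = torch.randn(100, 64, 1, 28, 28)

    # One minibatch per model.
    minibatches = data[:num_models]
    print(minibatches.shape)  # torch.Size([10, 64, 1, 28, 28])

    # flatten(1) keeps the batch dimension and collapses the rest.
    flat = data[0].flatten(1)
    print(flat.shape)  # torch.Size([64, 784])

Flattening gives 1 x 28 x 28 = 784 features per image, which is why ``fc1`` is ``nn.Linear(784, 128)``.
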

.. GENERATED FROM PYTHON SOURCE LINES 62-66

We have a couple of options for generating predictions. Maybe we want to
give each model a different randomized minibatch of data. Alternatively,
maybe we want to run the same minibatch of data through each model (e.g.
if we were testing the effect of different model initializations).

.. GENERATED FROM PYTHON SOURCE LINES 68-69

Option 1: different minibatch for each model

.. GENERATED FROM PYTHON SOURCE LINES 69-73

.. code-block:: default


    minibatches = data[:num_models]
    predictions_diff_minibatch_loop = [model(minibatch) for model, minibatch in zip(models, minibatches)]


.. GENERATED FROM PYTHON SOURCE LINES 74-75

Option 2: same minibatch for all models

.. GENERATED FROM PYTHON SOURCE LINES 75-79

.. code-block:: default


    minibatch = data[0]
    predictions2 = [model(minibatch) for model in models]

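Either option leaves us with a list of per-model outputs that still has to be combined. This tutorial does not prescribe a combination rule; the sketch below assumes one common choice, averaging the logits over the model dimension, and uses random tensors as stand-ins for ``predictions2``.

.. code-block:: python

    import torch

    num_models, batch_size, num_classes = 10, 64, 10

    # Stand-in for the per-model outputs: a list of [64, 10] logit tensors.
    predictions = [torch.randn(batch_size, num_classes) for _ in range(num_models)]

    # Stack into a single [num_models, batch_size, num_classes] tensor.
    stacked = torch.stack(predictions)

    # Combine by averaging logits over the model dimension (an assumed rule;
    # averaging softmax probabilities or majority voting are alternatives).
    ensemble_logits = stacked.mean(dim=0)            # [64, 10]
    ensemble_labels = ensemble_logits.argmax(dim=1)  # [64]
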

.. GENERATED FROM PYTHON SOURCE LINES 80-93

Using ``vmap`` to vectorize the ensemble
----------------------------------------

Let's use ``vmap`` to speed up the for-loop. We must first prepare the models
for use with ``vmap``.

First, let’s combine the states of the models by stacking each parameter.
For example, ``models[i].fc1.weight`` has shape ``[128, 784]``, because
``nn.Linear`` stores its weight as ``[out_features, in_features]``. We are
going to stack the ``.fc1.weight`` of each of the 10 models to produce a big
weight of shape ``[10, 128, 784]``.

PyTorch offers the ``torch.func.stack_module_state`` convenience function to do
this.

.. GENERATED FROM PYTHON SOURCE LINES 93-97

.. code-block:: default

    from torch.func import stack_module_state

    params, buffers = stack_module_state(models)

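To see what ``stack_module_state`` produces, here is a small sketch that uses three standalone ``nn.Linear(784, 128)`` layers (shaped like ``fc1``) in place of the full models. It also confirms that ``nn.Linear`` stores its weight as ``[out_features, in_features]``.

.. code-block:: python

    import torch
    import torch.nn as nn
    from torch.func import stack_module_state

    # Three independently initialized layers shaped like fc1.
    layers = [nn.Linear(784, 128) for _ in range(3)]
    print(layers[0].weight.shape)  # torch.Size([128, 784])

    params, buffers = stack_module_state(layers)

    # Each parameter gains a leading dimension of size len(layers).
    print(params["weight"].shape)  # torch.Size([3, 128, 784])
    print(params["bias"].shape)    # torch.Size([3, 128])
    print(buffers)                 # {} because nn.Linear has no buffers
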

.. GENERATED FROM PYTHON SOURCE LINES 98-102

Next, we need to define a function to ``vmap`` over. Given parameters,
buffers, and inputs, the function should run the model with those parameters
and buffers on those inputs. We'll use ``torch.func.functional_call``
to help out:

.. GENERATED FROM PYTHON SOURCE LINES 102-114

.. code-block:: default


    from torch.func import functional_call
    import copy

    # Construct a "stateless" version of one of the models. It is "stateless" in
    # the sense that the parameters are meta Tensors and do not have storage.
    base_model = copy.deepcopy(models[0])
    base_model = base_model.to('meta')

    def fmodel(params, buffers, x):
        return functional_call(base_model, (params, buffers), (x,))

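The key property of ``functional_call`` is that it runs a module with parameters you supply for that one call, without modifying the module itself. A minimal sketch with a hypothetical ``nn.Linear(3, 2)`` layer and hand-picked parameters:

.. code-block:: python

    import torch
    import torch.nn as nn
    from torch.func import functional_call

    layer = nn.Linear(3, 2)
    x = torch.ones(4, 3)

    # Parameters used for this call only; the layer's own weights are untouched.
    new_params = {"weight": torch.zeros(2, 3), "bias": torch.full((2,), 5.0)}
    out = functional_call(layer, new_params, (x,))

    print(out)  # every entry is 5.0, since 0 * x + 5 = 5

Because only the module's structure is used, the same idea works with the meta-device ``base_model``, whose parameters have no storage.
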

.. GENERATED FROM PYTHON SOURCE LINES 115-121

Option 1: get predictions using a different minibatch for each model.

By default, ``vmap`` maps a function across the first dimension of every input
passed to it. After ``stack_module_state``, each tensor in ``params`` and
``buffers`` has an additional leading dimension of size ``num_models``, and
``minibatches`` also has a leading dimension of size ``num_models``.

.. GENERATED FROM PYTHON SOURCE LINES 121-133

.. code-block:: default


    print([p.size(0) for p in params.values()]) # show the leading 'num_models' dimension

    assert minibatches.shape == (num_models, 64, 1, 28, 28) # verify minibatch has leading dimension of size 'num_models'

    from torch import vmap

    predictions1_vmap = vmap(fmodel)(params, buffers, minibatches)

    # verify the ``vmap`` predictions match the for-loop predictions
    assert torch.allclose(predictions1_vmap, torch.stack(predictions_diff_minibatch_loop), atol=1e-3, rtol=1e-5)

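For reference, the whole option 1 pipeline fits in a few lines. The sketch below is a scaled-down, CPU-only version that uses hypothetical tiny ``nn.Sequential`` models (8 inputs, 3 outputs) in place of ``SimpleMLP`` so that it can run on its own.

.. code-block:: python

    import copy
    import torch
    import torch.nn as nn
    from torch.func import functional_call, stack_module_state

    torch.manual_seed(0)
    num_models = 4

    # Hypothetical small models, all with the same architecture.
    models = [nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 3))
              for _ in range(num_models)]
    params, buffers = stack_module_state(models)

    # Stateless skeleton on the meta device, used only for its structure.
    base = copy.deepcopy(models[0]).to("meta")

    def fmodel(params, buffers, x):
        return functional_call(base, (params, buffers), (x,))

    # One minibatch of 5 examples per model: [num_models, 5, 8].
    minibatches = torch.randn(num_models, 5, 8)

    # in_dims defaults to 0 for every argument, so model i receives the
    # i-th slice of each stacked parameter and minibatches[i].
    vmap_out = torch.vmap(fmodel)(params, buffers, minibatches)
    loop_out = torch.stack([m(mb) for m, mb in zip(models, minibatches)])
    print(vmap_out.shape)  # torch.Size([4, 5, 3])
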

.. GENERATED FROM PYTHON SOURCE LINES 134-139

Option 2: get predictions using the same minibatch of data.

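``vmap`` has an ``in_dims`` argument that specifies which dimension of each input
to map over. Passing ``None`` for an input tells ``vmap`` not to map over it, so
``in_dims=(0, 0, None)`` maps over ``params`` and ``buffers`` while the same
``minibatch`` is used for all 10 models.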

.. GENERATED FROM PYTHON SOURCE LINES 139-144

.. code-block:: default


    predictions2_vmap = vmap(fmodel, in_dims=(0, 0, None))(params, buffers, minibatch)

    assert torch.allclose(predictions2_vmap, torch.stack(predictions2), atol=1e-3, rtol=1e-5)

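The effect of ``in_dims=(0, None)`` can be seen without any modules at all. In this sketch, ``weights`` stands in for the stacked parameters of three hypothetical linear models, ``x`` is a single shared minibatch, and ``apply_linear`` is a helper defined only for illustration.

.. code-block:: python

    import torch

    # Stacked weights for 3 hypothetical linear models: [3, 2, 4].
    weights = torch.randn(3, 2, 4)
    # A single shared minibatch of 5 examples: [5, 4].
    x = torch.randn(5, 4)

    def apply_linear(w, x):
        # Inside vmap, w is a single [2, 4] weight: [5, 4] @ [4, 2] -> [5, 2].
        return x @ w.t()

    # Map over dim 0 of `weights`, but pass the same `x` to every call.
    out = torch.vmap(apply_linear, in_dims=(0, None))(weights, x)
    print(out.shape)  # torch.Size([3, 5, 2])
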

.. GENERATED FROM PYTHON SOURCE LINES 145-151

A quick note: there are limitations on what types of functions can be
transformed by ``vmap``. The best functions to transform are pure functions:
functions whose outputs are determined only by their inputs and that have no
side effects (e.g. mutation). ``vmap`` is unable to handle mutation of
arbitrary Python data structures, but it is able to handle many in-place
PyTorch operations.

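The sketch below illustrates the distinction for in-place PyTorch operations: an in-place op on a tensor computed inside the function works under ``vmap``, while an in-place op that would write per-example results into an unbatched input raises an error. The function names are hypothetical.

.. code-block:: python

    import torch

    def pure(x, y):
        # out is computed from the batched y, so it is batched too and
        # the in-place multiply is fine under vmap.
        out = x + y
        out.mul_(2)
        return out

    def writes_into_unbatched(x, y):
        # x is not mapped over, so this would have to store 3 different
        # results in a single tensor; vmap rejects it.
        x.add_(y)
        return x

    x = torch.randn(1)
    y = torch.randn(3, 1)
    expected = (x + y) * 2

    ok = torch.vmap(pure, in_dims=(None, 0))(x, y)  # shape [3, 1]

    try:
        torch.vmap(writes_into_unbatched, in_dims=(None, 0))(x, y)
        raised = False
    except RuntimeError:
        raised = True
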
.. GENERATED FROM PYTHON SOURCE LINES 153-156

Performance
-----------
Curious about performance numbers? The benchmark below times the for-loop and
the ``vmap`` version of option 1.

.. GENERATED FROM PYTHON SOURCE LINES 156-167

.. code-block:: default


    from torch.utils.benchmark import Timer
    without_vmap = Timer(
        stmt="[model(minibatch) for model, minibatch in zip(models, minibatches)]",
        globals=globals())
    with_vmap = Timer(
        stmt="vmap(fmodel)(params, buffers, minibatches)",
        globals=globals())
    print(f'Predictions without vmap {without_vmap.timeit(100)}')
    print(f'Predictions with vmap {with_vmap.timeit(100)}')

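If you don't have a GPU handy, the following self-contained sketch runs the same comparison on the CPU with small hypothetical ``nn.Linear(64, 64)`` models. It passes an explicit ``globals`` dictionary to ``Timer`` instead of ``globals()``. Expect the numbers, and the size of any speedup, to differ from a GPU run.

.. code-block:: python

    import copy
    import torch
    import torch.nn as nn
    from torch.func import functional_call, stack_module_state
    from torch.utils.benchmark import Timer

    num_models = 8
    models = [nn.Linear(64, 64) for _ in range(num_models)]
    params, buffers = stack_module_state(models)
    base = copy.deepcopy(models[0]).to("meta")

    def fmodel(params, buffers, x):
        return functional_call(base, (params, buffers), (x,))

    minibatches = torch.randn(num_models, 32, 64)
    env = {"models": models, "minibatches": minibatches, "fmodel": fmodel,
           "params": params, "buffers": buffers, "vmap": torch.vmap}

    loop_t = Timer("[m(mb) for m, mb in zip(models, minibatches)]", globals=env).timeit(20)
    vmap_t = Timer("vmap(fmodel)(params, buffers, minibatches)", globals=env).timeit(20)

    # Median seconds per run; results depend on hardware.
    print(loop_t.median, vmap_t.median)
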

.. GENERATED FROM PYTHON SOURCE LINES 168-176

On a GPU, there's a large speedup using ``vmap``! The exact numbers depend on
your hardware.

In general, vectorization with ``vmap`` should be faster than running a function
in a for-loop and competitive with manual batching. There are some exceptions,
though: for example, an operation may not have a ``vmap`` batching rule yet, or
the underlying kernels may not be optimized for older hardware (GPUs). If you
see any of these cases, please let us know by opening an issue on GitHub.


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 0 minutes  0.000 seconds)


.. _sphx_glr_download_intermediate_ensembling.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example


    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: ensembling.py <ensembling.py>`

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: ensembling.ipynb <ensembling.ipynb>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_