Detailed look at Optimizer
In the Gradient-based optimization tutorial, Mitsuba’s Optimizer class was used to build an optimization loop. In this tutorial, we study the optimizer API in more detail. It is designed to be convenient when directly optimizing parameters of a Mitsuba scene, yet flexible enough for more advanced use cases (e.g. chaining together a neural network and a differentiable rendering step in the same computation graph).
Let’s start by setting an AD-compatible variant.
[2]:
import mitsuba as mi
import drjit as dr
mi.set_variant('llvm_ad_rgb')
The basics
To perform gradient-based optimization, Mitsuba ships with standard optimizers, including Stochastic Gradient Descent (SGD) with and without momentum, as well as Adam. Both inherit from the Optimizer base class and can be found in the mitsuba.ad submodule.
The Optimizer class behaves like a Python dict with some extra logic and methods. This how-to guide takes you through its API and highlights the best practices and pitfalls related to this class.
Let’s first construct a simple SGD optimizer with a learning rate of 0.25.
[3]:
opt = mi.ad.SGD(lr=0.25)
We can now specify a variable to be optimized. The Optimizer automatically enables gradient tracking on the stored variable, since this is required for any subsequent computation to produce gradients.
[4]:
opt['x'] = mi.Float(1.0)
opt
[4]:
SGD[
variables = ['x'],
lr = {'default': 0.25},
momentum = 0
]
It is also possible to pass a set of variables to be optimized directly to the Optimizer constructor via the params argument:
[5]:
opt = mi.ad.SGD(lr=0.25, params={'x': mi.Float(1.0), 'y': mi.Float(2.0)})
opt
[5]:
SGD[
variables = ['x', 'y'],
lr = {'default': 0.25},
momentum = 0
]
The Optimizer also provides a dict-like API for basic dictionary manipulations, such as iterating over its items:
[6]:
for k, v in opt.items():
    print(f"{k}: {v}")
x: [1.0]
y: [2.0]
⚠ It is important to note that a copy of the variable is made when it is assigned to the Optimizer via the __setitem__ method. For instance, in the following code, the original variable is not attached to the AD graph, and its value remains unchanged.
[7]:
y = mi.Float(2.0)
opt['y'] = y # A copy is made here
opt['y'] += 1.0
print(f"Original: {y}, grad_enabled={dr.grad_enabled(y)}.")
print(f"Optimizer: {opt['y']}, grad_enabled={dr.grad_enabled(opt['y'])}.")
Original: [2.0], grad_enabled=False.
Optimizer: [3.0], grad_enabled=True.
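The copy-on-assignment behavior can be mimicked with a toy, pure-Python container. The class CopyingDict below is hypothetical (it is not part of Mitsuba) and only models the copy: it does not enable gradient tracking, and it uses a list as a mutable stand-in for a Mitsuba array.

```python
import copy


class CopyingDict(dict):
    """Hypothetical dict that stores a copy of every assigned value."""

    def __setitem__(self, key, value):
        # Store a copy, so later in-place changes do not touch the caller's object
        super().__setitem__(key, copy.deepcopy(value))


y = [2.0]              # mutable stand-in for a Mitsuba array
opt = CopyingDict()
opt['y'] = y           # a copy is made here
opt['y'][0] += 1.0     # modifies the stored copy only

print(f"Original: {y}")          # Original: [2.0]
print(f"Container: {opt['y']}")  # Container: [3.0]
```

As with the real Optimizer, any computation that should affect the stored value has to go through the container, not through the original object.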
It is therefore crucial to use the variable held by the optimizer in the differentiable computation in order to produce the proper gradients. Here is a simple example where x and y are used in a calculation whose gradients we then backpropagate. We can verify that the gradients are correctly propagated to the optimizer variables.
[8]:
z = opt['x'] + 2.0 * opt['y']
dr.backward(z)
print(f"x grad={dr.grad(opt['x'])}")
print(f"y grad={dr.grad(opt['y'])}")
x grad=[1.0]
y grad=[2.0]
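These values match the analytic derivatives of the expression computed above:

$$
z = x + 2y \quad\Longrightarrow\quad \frac{\partial z}{\partial x} = 1, \qquad \frac{\partial z}{\partial y} = 2
$$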
During the optimization, the role of the optimizer is to take a gradient step according to its update rule. In the case of a simple SGD optimizer with no momentum, the update rule for a variable $x$ with gradient $\nabla_x$ and learning rate $\mathrm{lr}$ is:
$$x_{i+1} = x_i - \mathrm{lr} \cdot \nabla_x$$
The step() method of the Optimizer applies this update rule to all of its variables.
[9]:
print(f"Before the gradient step: x={opt['x']}, y={opt['y']}")
opt.step()
print(f"After the gradient step: x={opt['x']}, y={opt['y']}")
Before the gradient step: x=[1.0], y=[3.0]
After the gradient step: x=[0.75], y=[2.5]
After applying the update rule, the Optimizer also resets the gradients of its variables to 0.0 and ensures that gradient tracking is still enabled on all of them. This way, everything is ready for the next iteration of the optimization loop.
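The step logic described above can be sketched in plain Python. ToySGD is a hypothetical, simplified model (not Mitsuba's class): it stores values and gradients as floats in dicts, and it models "resetting the gradients" by zeroing them, whereas the real optimizer works on Dr.Jit arrays and also re-enables gradient tracking.

```python
class ToySGD:
    """Hypothetical stand-in for an SGD optimizer without momentum."""

    def __init__(self, lr):
        self.lr = lr
        self.values = {}  # parameter name -> current value
        self.grads = {}   # parameter name -> accumulated gradient

    def __setitem__(self, key, value):
        self.values[key] = float(value)
        self.grads[key] = 0.0

    def __getitem__(self, key):
        return self.values[key]

    def step(self):
        for key in self.values:
            # Update rule: x <- x - lr * grad
            self.values[key] -= self.lr * self.grads[key]
            # Reset the gradient so the next iteration starts from zero
            self.grads[key] = 0.0


opt = ToySGD(lr=0.25)
opt['x'], opt['y'] = 1.0, 3.0
# Gradients of z = x + 2y, as produced by dr.backward() in the example above
opt.grads['x'], opt.grads['y'] = 1.0, 2.0
opt.step()
print(opt['x'], opt['y'])  # 0.75 2.5
```

The printed values reproduce the output of the step() cell above.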
Optimizing scene parameters
In the context of differentiable rendering, we are often interested in optimizing the parameters of a Mitsuba scene, which are exposed via the traverse() mechanism (https://mitsuba.readthedocs.io/en/latest/src/api_reference.html#mitsuba.traverse). This function returns a SceneParameters object (https://mitsuba.readthedocs.io/en/latest/src/api_reference.html#mitsuba.SceneParameters), which exposes a dictionary-like interface.
[26]:
scene = mi.load_file('../scenes/cbox.xml')
params = mi.traverse(scene)
params
[26]:
SceneParameters[
  ----------------------------------------------------------------------------------------
  Name                               Flags    Type            Parent
  ----------------------------------------------------------------------------------------
  sensor.near_clip                            float           PerspectiveCamera
  sensor.far_clip                             float           PerspectiveCamera
  sensor.shutter_open                         float           PerspectiveCamera
  sensor.shutter_open_time                    float           PerspectiveCamera
  sensor.x_fov                                float           PerspectiveCamera
  sensor.to_world                             Transform4f     PerspectiveCamera
  gray.reflectance.value             ∂        Color3f         SRGBReflectanceSpectrum
  white.reflectance.value            ∂        Color3f         SRGBReflectanceSpectrum
  green.reflectance.value            ∂        Color3f         SRGBReflectanceSpectrum
  red.reflectance.value              ∂        Color3f         SRGBReflectanceSpectrum
  glass.eta                                   float           SmoothDielectric
  mirror.eta.value                   ∂, D     Float           UniformSpectrum
  mirror.k.value                     ∂, D     Float           UniformSpectrum
  mirror.specular_reflectance.value  ∂        Float           UniformSpectrum
  light.emitter.radiance.value       ∂        Color3f         SRGBEmitterSpectrum
  light.vertex_count                          int             OBJMesh
  light.face_count                            int             OBJMesh
  light.faces                                 UInt            OBJMesh
  light.vertex_positions             ∂, D     Float           OBJMesh
  light.vertex_normals               ∂, D     Float           OBJMesh
  light.vertex_texcoords             ∂        Float           OBJMesh
  floor.vertex_count                          int             OBJMesh
  floor.face_count                            int             OBJMesh
  floor.faces                                 UInt            OBJMesh
  floor.vertex_positions             ∂, D     Float           OBJMesh
  floor.vertex_normals               ∂, D     Float           OBJMesh
  floor.vertex_texcoords             ∂        Float           OBJMesh
  ceiling.vertex_count                        int             OBJMesh
  ceiling.face_count                          int             OBJMesh
  ceiling.faces                               UInt            OBJMesh
  ceiling.vertex_positions           ∂, D     Float           OBJMesh
  ceiling.vertex_normals             ∂, D     Float           OBJMesh
  ceiling.vertex_texcoords           ∂        Float           OBJMesh
  back.vertex_count                           int             OBJMesh
  back.face_count                             int             OBJMesh
  back.faces                                  UInt            OBJMesh
  back.vertex_positions              ∂, D     Float           OBJMesh
  back.vertex_normals                ∂, D     Float           OBJMesh
  back.vertex_texcoords              ∂        Float           OBJMesh
  greenwall.vertex_count                      int             OBJMesh
  greenwall.face_count                        int             OBJMesh
  greenwall.faces                             UInt            OBJMesh
  greenwall.vertex_positions         ∂, D     Float           OBJMesh
  greenwall.vertex_normals           ∂, D     Float           OBJMesh
  greenwall.vertex_texcoords         ∂        Float           OBJMesh
  redwall.vertex_count                        int             OBJMesh
  redwall.face_count                          int             OBJMesh
  redwall.faces                               UInt            OBJMesh
  redwall.vertex_positions           ∂, D     Float           OBJMesh
  redwall.vertex_normals             ∂, D     Float           OBJMesh
  redwall.vertex_texcoords           ∂        Float           OBJMesh
  mirrorsphere.to_world                       Transform4f     Sphere
  glasssphere.to_world                        Transform4f     Sphere
]
This is very convenient, as this object can be passed directly to the Optimizer constructor to set all the scene parameters to be optimized. In this case, the Optimizer ignores all variables that are not marked as differentiable (see the ∂ flag in the printout of the params object above).
However, it is sometimes preferable to optimize only a subset of the parameters. This can easily be achieved by filtering the items of the params object with the keep method, which accepts either a list of keys or a regular expression.
[27]:
params.keep(r'.*\.reflectance\.value')
params
[27]:
SceneParameters[
  ------------------------------------------------------------------------------
  Name                      Flags    Type        Parent
  ------------------------------------------------------------------------------
  gray.reflectance.value    ∂        Color3f     SRGBReflectanceSpectrum
  white.reflectance.value   ∂        Color3f     SRGBReflectanceSpectrum
  green.reflectance.value   ∂        Color3f     SRGBReflectanceSpectrum
  red.reflectance.value     ∂        Color3f     SRGBReflectanceSpectrum
]
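The key filtering performed by keep() can be approximated in plain Python. keep_keys is a hypothetical helper, and its matching semantics are an assumption: every argument is treated as a regular expression that must match the whole key (re.fullmatch), while Mitsuba's implementation may use a different matching mode.

```python
import re

keys = [
    'sensor.x_fov',
    'gray.reflectance.value',
    'white.reflectance.value',
    'mirror.eta.value',
    'red.reflectance.value',
    'light.emitter.radiance.value',
]


def keep_keys(keys, patterns):
    """Hypothetical: keep the keys matching a regex or a list of regexes/keys."""
    # Accept a single pattern or a list of patterns (plain keys also work as patterns)
    if isinstance(patterns, str):
        patterns = [patterns]
    regexes = [re.compile(p) for p in patterns]
    # Keep a key if it fully matches at least one pattern, preserving the original order
    return [k for k in keys if any(r.fullmatch(k) for r in regexes)]


print(keep_keys(keys, r'.*\.reflectance\.value'))
# ['gray.reflectance.value', 'white.reflectance.value', 'red.reflectance.value']
```

Note that 'light.emitter.radiance.value' is filtered out, since the pattern requires the key to end in ".reflectance.value".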
We can now construct an Optimizer that will optimize those reflectance values:
[30]:
opt = mi.ad.SGD(lr=0.25, params=params)
opt
[30]:
SGD[
variables = ['gray.reflectance.value', 'white.reflectance.value', 'green.reflectance.value', 'red.reflectance.value'],
lr = {'default': 0.25},
momentum = 0
]
Of course, as before, it is also possible to start from an empty optimizer and set variables one by one from the params object.
[31]:
opt = mi.ad.SGD(lr=0.25)
opt['red.reflectance.value'] = params['red.reflectance.value']
opt
[31]:
SGD[
variables = ['red.reflectance.value'],
lr = {'default': 0.25},
momentum = 0
]
As explained above, the parameters are copied internally, so any change to their values in the optimizer is not directly reflected in params.
[32]:
opt['red.reflectance.value'] *= 0.5
print(f"params: {params['red.reflectance.value']}")
print(f"optimize: {opt['red.reflectance.value']}")
params: [[0.5700680017471313, 0.043013498187065125, 0.04437059909105301]]
optimize: [[0.2850340008735657, 0.021506749093532562, 0.022185299545526505]]
To propagate those changes to params (and to the Scene itself), we use the update() method of SceneParameters. Internally, this method looks for optimizer variables whose keys match entries in params and overwrites the corresponding values. It then updates the internal scene state (e.g. rebuilding the BVH when geometry data was modified).
[33]:
params.update(opt);
print(f"params: {params['red.reflectance.value']}")
print(f"optimize: {opt['red.reflectance.value']}")
params: [[0.2850340008735657, 0.021506749093532562, 0.022185299545526505]]
optimize: [[0.2850340008735657, 0.021506749093532562, 0.022185299545526505]]
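The key-matching step of update() can be sketched with plain dictionaries. update_from is a hypothetical helper, not a Mitsuba function; it assumes that optimizer keys absent from params are simply skipped, and it does not model the subsequent scene update (e.g. the BVH rebuild).

```python
def update_from(params, values):
    """Hypothetical: overwrite entries of `params` whose keys also appear in `values`."""
    updated = []
    for key, value in values.items():
        if key in params:
            params[key] = value
            updated.append(key)
    # Keys present only in `values` (e.g. latent variables) are left untouched
    return updated


params = {'red.reflectance.value': 0.57, 'green.reflectance.value': 0.1}
opt_values = {'red.reflectance.value': 0.285, 'translation': (0.0, 0.0, 0.0)}

print(update_from(params, opt_values))  # ['red.reflectance.value']
print(params['red.reflectance.value'])  # 0.285
```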
Optimizing latent variables
In more complex optimization scenarios, scene parameters might be described as a function of some other parameters. In such a scenario, we would be interested in optimizing those other parameters instead of the scene parameters directly.
For example, this would be needed when generating the vertex positions of a mesh using a neural network: we’d want to optimize the weights of the neural network, not the vertex positions themselves. Another example could be a procedurally generated texture, maybe from a physically-based model that can be tweaked with a few parameters.
We refer to this type of “external” parameters as latent variables. Many of the design decisions of the Optimizer class were made to support latent variables.
For a simpler example, let’s consider the case where we aim to optimize the translation vector of a 3D mesh object in our scene. Even from a convexity standpoint, optimizing those three translation values is much easier than optimizing all the vertex positions simultaneously and hoping for the best.
Let’s fetch the scene parameters again and initialize our optimizer one more time.
[34]:
params = mi.traverse(scene)
opt = mi.ad.SGD(lr=0.25)
We can then add a latent variable to the optimizer, similar to what we did in the first section of this how-to guide.
[35]:
opt['translation'] = mi.Vector3f(0, 0, 0)
In this scenario, it is the user’s responsibility to propagate the changes of the latent variable to the scene parameters every time the latent variable is updated. To this end, it is helpful to define a dedicated update function.
Note that Mitsuba stores the vertex positions of a mesh in a linearized fashion (xyzxyzxyzxyz...). To apply the translation, we first convert the initial coordinates into 3D points. Once the translation has been applied, we write the new values back into a linearized array before assigning them to the scene parameters. For those operations, we use the dr.unravel and dr.ravel helper functions of Dr.Jit.
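The linearized layout, and what unraveling and raveling do to it, can be illustrated with plain Python lists. unravel3, ravel3, and translate are hypothetical stand-ins that only mimic the xyzxyz... layout; they are not the Dr.Jit functions.

```python
def unravel3(flat):
    """Hypothetical: group a flat [x0, y0, z0, x1, ...] list into (x, y, z) points."""
    assert len(flat) % 3 == 0, "length must be a multiple of 3"
    return [tuple(flat[i:i + 3]) for i in range(0, len(flat), 3)]


def ravel3(points):
    """Hypothetical: flatten (x, y, z) points back into [x0, y0, z0, x1, ...]."""
    return [c for p in points for c in p]


def translate(points, t):
    # Add the translation vector t to every point
    return [(x + t[0], y + t[1], z + t[2]) for (x, y, z) in points]


flat = [0.0, 0.0, 0.0, 1.0, 2.0, 3.0]  # two vertices
points = unravel3(flat)                # [(0.0, 0.0, 0.0), (1.0, 2.0, 3.0)]
moved = ravel3(translate(points, (0.5, 0.0, -1.0)))
print(moved)  # [0.5, 0.0, -1.0, 1.5, 2.0, 2.0]
```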
[36]:
# Copy of our vertex positions (converted to 3D points)
initial_vertex_pos = dr.unravel(mi.Point3f, params['redwall.vertex_positions'])
# Now we define the update rule
def update_vertex_pos():
    # Create the translation transformation
    T = mi.Transform4f.translate(opt['translation'])
    # Apply the transformation to the vertex positions
    new_vertex_pos = T @ initial_vertex_pos
    # Flatten the vertex position array before assigning it to `params`
    params['redwall.vertex_positions'] = dr.ravel(new_vertex_pos)
    # Propagate changes through the scene (e.g. rebuild BVH)
    params.update()
With this function implemented, all we need to do is call it after every optimizer step to update the vertex positions accordingly.
💭 In addition to propagating the new values, calling update_vertex_pos() also builds the computational graph that the automatic differentiation layer needs in order to later backpropagate gradients from the scene parameters to the latent variable.
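To see why this graph matters, consider the chain rule for the translation example. Writing $\mathbf{v}_j = \mathbf{v}_j^0 + \mathbf{t}$ for the translated vertex positions, where $\mathbf{v}_j^0$ are the initial positions, $\mathbf{t}$ the translation, and $L$ the loss (notation introduced here for illustration), the gradient with respect to the latent variable accumulates the per-vertex gradients:

$$
\frac{\partial L}{\partial \mathbf{t}}
= \sum_{j} \left(\frac{\partial \mathbf{v}_j}{\partial \mathbf{t}}\right)^{\top} \frac{\partial L}{\partial \mathbf{v}_j}
= \sum_{j} \frac{\partial L}{\partial \mathbf{v}_j},
\qquad \text{since } \frac{\partial \mathbf{v}_j}{\partial \mathbf{t}} = I.
$$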
Optimizer state
In addition to the variables to optimize, the Optimizer also holds an internal state used by its update rule. For instance, the momentum-based SGD optimizer tracks the velocity from the previous iteration to apply momentum in its step() method. In most cases, this state is stored on a per-parameter basis.
In some cases, it is useful to reset the state of an Optimizer (e.g. when an optimization schedule resizes the optimized volume grid). For this, we can use the reset() method, which zero-initializes the internal state associated with a specific parameter.
[37]:
opt.reset('translation')
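Per-parameter state and its reset can be modeled with a small, hypothetical class. ToyMomentumSGD is not Mitsuba's implementation, and its update uses one common (heavy-ball) formulation of momentum; Mitsuba's exact expression may differ.

```python
class ToyMomentumSGD:
    """Hypothetical SGD with momentum, keeping one velocity per parameter."""

    def __init__(self, lr, momentum):
        self.lr = lr
        self.momentum = momentum
        self.values = {}
        self.state = {}  # parameter name -> velocity from the previous step

    def __setitem__(self, key, value):
        self.values[key] = float(value)
        self.state[key] = 0.0

    def step(self, grads):
        for key, g in grads.items():
            # Heavy-ball momentum: blend the previous velocity with the new gradient
            v = self.momentum * self.state[key] + self.lr * g
            self.state[key] = v
            self.values[key] -= v

    def reset(self, key):
        # Zero-initialize the state of a single parameter; its value is kept
        self.state[key] = 0.0


opt = ToyMomentumSGD(lr=0.5, momentum=0.5)
opt['t'] = 0.0
opt.step({'t': 1.0})  # v = 0.5,  t = -0.5
opt.step({'t': 1.0})  # v = 0.75, t = -1.25
opt.reset('t')
print(opt.values['t'], opt.state['t'])  # -1.25 0.0
```

After the reset, the next step behaves like the very first one, since no velocity is carried over.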
Another useful feature of the Mitsuba Optimizer is its ability to mask the update of its internal state depending on whether gradients were received for each variable.
This can be useful in Monte Carlo simulations such as differentiable path tracing, where not all optimized parameters necessarily receive gradients at each iteration. We found that, in this situation, updating the optimizer state for parameters that did not receive gradients may degrade convergence; such updates should be discarded instead. When constructing the Optimizer, the mask_updates argument specifies whether to discard updates where the gradients are zero (the default is False).
[38]:
opt = mi.ad.SGD(lr=0.25, mask_updates=True)
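The effect of masking can be seen on a two-element parameter where only the first entry receives a gradient. masked_momentum_step is a hypothetical, element-wise sketch based on our reading of the description above (entries with a zero gradient keep both their value and their momentum state); it is not Mitsuba's code, and it reuses the heavy-ball momentum formulation as an assumption.

```python
def masked_momentum_step(values, velocity, grads, lr, momentum, mask_updates):
    """Hypothetical element-wise momentum step with optional update masking."""
    new_values, new_velocity = [], []
    for p, v, g in zip(values, velocity, grads):
        if mask_updates and g == 0.0:
            # No gradient reached this entry: discard the update entirely
            new_values.append(p)
            new_velocity.append(v)
            continue
        v = momentum * v + lr * g
        new_values.append(p - v)
        new_velocity.append(v)
    return new_values, new_velocity


values, velocity = [1.0, 1.0], [0.5, 0.5]
grads = [1.0, 0.0]  # the second entry received no gradient this iteration

print(masked_momentum_step(values, velocity, grads, 0.25, 0.5, mask_updates=False))
# ([0.5, 0.75], [0.5, 0.25])
print(masked_momentum_step(values, velocity, grads, 0.25, 0.5, mask_updates=True))
# ([0.5, 1.0], [0.5, 0.5])
```

Without masking, the stale velocity keeps moving the second entry (and decays its state) even though no gradient was observed for it.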
Finally, the Mitsuba implementation of the SGD optimizer also supports per-parameter learning rates, which is useful to control the magnitude of the gradient step taken on individual parameters. This can be achieved using the set_learning_rate() method, which takes an optional key argument specifying the parameter whose learning rate should be set, or alternatively a dict containing all keys for which to override the learning rate.
[39]:
opt['x'] = mi.Float(1.0)
opt['y'] = mi.Float(1.0)
opt.set_learning_rate({'x': 0.125, 'y': 0.25})
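One way per-parameter learning rates with a default fallback can be resolved is sketched below, using a plain dict shaped like the lr = {'default': 0.25} entry in the optimizer printouts. resolve_lr and sgd_step are hypothetical helpers, not Mitsuba API.

```python
def resolve_lr(lrs, key):
    """Hypothetical: use a parameter-specific learning rate, else the default."""
    return lrs.get(key, lrs['default'])


def sgd_step(values, grads, lrs):
    # Plain SGD, except that each parameter uses its own learning rate
    return {k: v - resolve_lr(lrs, k) * grads[k] for k, v in values.items()}


lrs = {'default': 0.25}
lrs.update({'x': 0.125, 'y': 0.25})  # analogous to the dict passed to set_learning_rate()

values = {'x': 1.0, 'y': 1.0, 'z': 1.0}
grads = {'x': 1.0, 'y': 1.0, 'z': 1.0}
print(sgd_step(values, grads, lrs))  # {'x': 0.875, 'y': 0.75, 'z': 0.75}
```

Here z has no override and therefore falls back to the default learning rate.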