Caustics optimization#


This tutorial contains an advanced inverse rendering example: recover the surface displacement (heightmap) of a slab of glass such that light passing through it focuses into a specific desired image.

This reproduces the results showcased in Section 4.3 of Mitsuba 2: A Retargetable Forward and Inverse Renderer.

🚀 You will learn how to:

  • Create a simple mesh from Python

  • Use the particle tracer integrator

  • Load a scene defined procedurally from Python

  • Apply a heightmap to a mesh from Python

  • Optimize “latent” variables, i.e. variables that are not directly defined as part of the scene but affect it

The scene will be set up as follows:

  1. A directional area light (white or colorful, depending on the target image)

  2. Light from the emitter passes through a glass slab. We will optimize the slab’s surface (via a heightmap)…

  3. …so that light is focused on a receiving plane in a way that reproduces a desired target image.

In order to efficiently render and optimize this scene, we will use the Particle Tracer integrator (ptracer), which traces rays from the emitter rather than the sensor.

Caustic Optimization diagram


We start by importing Mitsuba and selecting an appropriate variant supporting automatic differentiation (AD) as it is required to compute gradients with respect to the slab’s surface.

import os
from os.path import realpath, join

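import drjit as dr
import mitsuba as mi

# Select an AD-capable variant (for example 'cuda_ad_rgb' when a GPU is available)
mi.set_variant('llvm_ad_rgb')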


Choose a configuration#

In this tutorial, we can attempt to reproduce either a grayscale image using a uniform emitter, or a color image using an RGB emitter. Here, we define those two options and select one.

Feel free to define additional configurations, e.g. to target a different reference image of your choice.

SCENE_DIR = realpath('../scenes')

CONFIGS = {
    'wave': {
        'emitter': 'gray',
        'reference': join(SCENE_DIR, 'references/wave-1024.jpg'),
    },
    'sunday': {
        'emitter': 'bayer',
        'reference': join(SCENE_DIR, 'references/sunday-512.jpg'),
    },
}

# Pick one of the available configs
config_name = 'sunday'
# config_name = 'wave'

config = CONFIGS[config_name]

# Add common options
config.update({
    'render_resolution': (64, 64),
    'heightmap_resolution': (128, 128),
    'spp': 32,
    'max_iterations': 100,
    'learning_rate': 3e-5,
})

print('[i] Reference image selected:', config['reference'])

output_dir = realpath(join('.', 'outputs', config_name))
os.makedirs(output_dir, exist_ok=True)
print('[i] Results will be saved to:', output_dir)
[i] Reference image selected: /Users/speierer/projects/mitsuba3/tutorials/scenes/references/sunday-512.jpg
[i] Results will be saved to: /Users/speierer/projects/mitsuba3/tutorials/getting_started/inverse_rendering/outputs/sunday

Creating the scene#

Depending on the chosen configuration, a different type of emitter will need to be used. For this reason, we define the scene dynamically directly from Python as a dictionary and load it with load_dict().

# Make sure that resources from the scene directory can be found
mi.Thread.thread().file_resolver().append(SCENE_DIR)

Creating the lens mesh#

The goal of the optimization is to recover the heightfield that needs to be applied to a slab of glass so that it focuses light in just the right way to reproduce the desired target image.
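To build some intuition for why small surface variations can be enough, the following minimal sketch applies Snell's law to a single ray leaving the glass. It is not part of the tutorial's pipeline: the helper names `refract` and `landing_offset` are hypothetical, the facet is reduced to a single slope value, and the index of refraction (1.5) and the 7-unit distance to the receiving plane are taken from the scene defined in this tutorial.

```python
import math

def refract(d, n, eta):
    """Refract unit direction d at a surface with unit normal n.

    n must point against the incoming ray (dot(d, n) < 0) and eta is the
    ratio n_incident / n_transmitted. Returns None on total internal reflection.
    """
    cos_i = -(d[0] * n[0] + d[1] * n[1] + d[2] * n[2])
    sin2_t = eta * eta * (1.0 - cos_i * cos_i)
    if sin2_t > 1.0:
        return None
    cos_t = math.sqrt(1.0 - sin2_t)
    k = eta * cos_i - cos_t
    return tuple(eta * di + k * ni for di, ni in zip(d, n))

def landing_offset(slope, distance=7.0, ior=1.5):
    """Sideways shift on the receiving plane for a facet tilted by `slope`."""
    # The ray travels along -y inside the glass
    d = (0.0, -1.0, 0.0)
    # Facet normal pointing back into the glass, tilted in the x direction
    norm = math.hypot(slope, 1.0)
    n = (slope / norm, 1.0 / norm, 0.0)
    # Going from glass (ior) to air (1.0)
    t = refract(d, n, eta=ior)
    if t is None:
        return None
    # Follow the refracted ray until it has travelled `distance` along -y
    return distance * t[0] / -t[1]

for slope in (0.0, 0.005, 0.01):
    print(f"slope={slope:.3f} -> offset={landing_offset(slope):.4f}")
```

Under these assumptions, a facet slope of 0.01 shifts the hit point by roughly 0.035 scene units, so even very shallow height variations redistribute a noticeable amount of light over a 7-unit throw.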

The heightmap will be represented as a texture and applied to the slab’s vertices. For this technique to be effective, the slab must have enough geometric resolution (vertices) to match the heightmap texture.

Here, we generate the appropriate mesh directly from Python: a simple tessellated plane with the desired resolution, which we then save to disk.
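The vertex numbering behind such a plane can be illustrated without Mitsuba. The sketch below is a pure-Python illustration (the helper `grid_faces` is hypothetical, not a Mitsuba function) that numbers the vertices row by row and emits two triangles per grid cell, so a grid of R x C vertices yields 2 (R - 1)(C - 1) faces.

```python
def grid_faces(resolution):
    """Return triangle index triples for a grid of rows x cols vertices.

    Vertices are numbered row-major: vertex (i, j) has index i * cols + j.
    """
    rows, cols = resolution
    faces = []
    for i in range(rows - 1):
        for j in range(cols - 1):
            v00 = i * cols + j        # current corner
            v01 = v00 + 1             # next column
            v10 = (i + 1) * cols + j  # next row
            v11 = v10 + 1             # next row, next column
            # Two triangles covering the cell
            faces.append((v00, v10, v01))
            faces.append((v01, v10, v11))
    return faces

faces = grid_faces((128, 128))
print(len(faces), "faces for", 128 * 128, "vertices")
```

For a 128 x 128 grid this gives 16384 vertices and 32258 faces, which matches the mesh statistics logged when the displaced lens is saved later in this tutorial.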

def create_flat_lens_mesh(resolution):
    # Generate UV coordinates
    U, V = dr.meshgrid(
        dr.linspace(mi.Float, 0, 1, resolution[0]),
        dr.linspace(mi.Float, 0, 1, resolution[1]),
        indexing='ij'  # matches the i * resolution[1] + j vertex indexing below
    )
    texcoords = mi.Vector2f(U, V)

    # Generate vertex coordinates
    X = 2.0 * (U - 0.5)
    Y = 2.0 * (V - 0.5)
    vertices = mi.Vector3f(X, Y, 0.0)

    # Create two triangles per grid cell
    faces_x, faces_y, faces_z = [], [], []
    for i in range(resolution[0] - 1):
        for j in range(resolution[1] - 1):
            v00 = i * resolution[1] + j
            v01 = v00 + 1
            v10 = (i + 1) * resolution[1] + j
            v11 = v10 + 1
            faces_x.extend([v00, v01])
            faces_y.extend([v10, v10])
            faces_z.extend([v01, v11])

    # Assemble face buffer
    faces = mi.Vector3u(faces_x, faces_y, faces_z)

    # Instantiate the mesh object
    mesh = mi.Mesh("lens-mesh", resolution[0] * resolution[1], len(faces_x), has_vertex_texcoords=True)

    # Set its buffers
    mesh_params = mi.traverse(mesh)
    mesh_params['vertex_positions'] = dr.ravel(vertices)
    mesh_params['vertex_texcoords'] = dr.ravel(texcoords)
    mesh_params['faces'] = dr.ravel(faces)

    return mesh

lens_res = config.get('lens_res', config['heightmap_resolution'])
lens_fname = join(output_dir, 'lens_{}_{}.ply'.format(*lens_res))

if not os.path.isfile(lens_fname):
    m = create_flat_lens_mesh(lens_res)
    # Save the flat lens to disk so that the scene can load it as a PLY file
    m.write_ply(lens_fname)
    print('[+] Wrote lens mesh ({}x{} tessellation) file to: {}'.format(*lens_res, lens_fname))

Creating the emitter#

As explained previously, depending on whether we are trying to reproduce a grayscale or a color target image, we set up the emitter to emit either constant white light or an RGB Bayer pattern. In the latter case, the pattern is generated on-the-fly and passed to the emitter as an in-memory Bitmap texture.
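To make the layout of such a pattern concrete, here is a small pure-Python sketch that builds the same kind of tile with nested lists instead of a tensor. The helper `bayer_tile` is hypothetical; it mirrors the three strided assignments used for the emitter in this tutorial (even row and even column is blue, even row and odd column is green, odd row and odd column is red), which leaves the remaining texel of each 2x2 cell black.

```python
def bayer_tile(size, value=2.2):
    """Build a size x size x 3 nested list holding an RGB Bayer-like pattern."""
    tile = [[[0.0, 0.0, 0.0] for _ in range(size)] for _ in range(size)]
    for r in range(size):
        for c in range(size):
            if r % 2 == 0 and c % 2 == 0:
                tile[r][c][2] = value  # blue channel
            elif r % 2 == 0 and c % 2 == 1:
                tile[r][c][1] = value  # green channel
            elif r % 2 == 1 and c % 2 == 1:
                tile[r][c][0] = value  # red channel
            # (odd row, even column) is left at zero
    return tile

tile = bayer_tile(4)
for row in tile[:2]:
    print(row)
```

Because each texel emits a single primary, the optimization can steer red, green and blue light independently, which is what allows a colored caustic to form.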

emitter = None
if config['emitter'] == 'gray':
    # Constant white directional area light
    emitter = {
        'type': 'directionalarea',
        'radiance': {
            'type': 'spectrum',
            'value': 0.8
        },
    }
elif config['emitter'] == 'bayer':
    # RGB Bayer pattern, generated on-the-fly
    bayer = dr.zeros(mi.TensorXf, (32, 32, 3))
    bayer[ ::2,  ::2, 2] = 2.2
    bayer[ ::2, 1::2, 1] = 2.2
    bayer[1::2, 1::2, 0] = 2.2

    emitter = {
        'type': 'directionalarea',
        'radiance': {
            'type': 'bitmap',
            'bitmap': mi.Bitmap(bayer),
            'raw': True,
            'filter_type': 'nearest'
        },
    }

Assemble the scene#

The sensor looks directly at the receiving plane where the caustic will be formed. The light source and optimized lens will stand behind the camera. Note that since the camera is an idealized pinhole camera and does not occupy any space, it will not cast any shadow on the receiving plane.
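As a quick sanity check of this layout, the sketch below estimates how wide a region of the receiving plane the camera sees. It assumes the values used in this tutorial's scene definition (sensor origin at y = -4.65, receiving plane at y = -7, a 45 degree field of view) and a hypothetical helper `visible_width`; it is plain pinhole geometry, not a Mitsuba API.

```python
import math

def visible_width(sensor_y, plane_y, fov_deg):
    """Width of the region seen by a pinhole camera facing a plane head-on."""
    distance = abs(sensor_y - plane_y)
    return 2.0 * distance * math.tan(math.radians(fov_deg) / 2.0)

width = visible_width(sensor_y=-4.65, plane_y=-7.0, fov_deg=45.0)
print(f"Visible width on the receiving plane: {width:.3f} scene units")
```

The result is a little under 2 scene units, close to the 2-unit width of the lens mesh (whose vertices span [-1, 1]), so the rendered image roughly frames the area reached by the light that passed through the lens.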

# Looking at the receiving plane, not looking through the lens
sensor_to_world = mi.ScalarTransform4f.look_at(
    target=[0, -20, 0],
    origin=[0, -4.65, 0],
    up=[0, 0, 1]
)
resx, resy = config['render_resolution']
sensor = {
    'type': 'perspective',
    'near_clip': 1,
    'far_clip': 1000,
    'fov': 45,
    'to_world': sensor_to_world,

    'sampler': {
        'type': 'independent',
        'sample_count': 512  # Not really used
    },
    'film': {
        'type': 'hdrfilm',
        'width': resx,
        'height': resy,
        'pixel_format': 'rgb',
        'rfilter': {
            # Important: smooth reconstruction filter with a footprint larger than 1 pixel.
            'type': 'gaussian'
        }
    },
}

The chosen light source emits light in a single direction, which would be very difficult (or impossible) to sample correctly with a standard path tracer. For this reason, we use a particle tracer (ptracer), which starts rays from the emitters rather than the sensor.

integrator = {
    'type': 'ptracer',
    'samples_per_pass': 256,
    'max_depth': 4,
    'hide_emitters': False,
}

We can now put everything together into a single large dictionary, where we also define the remaining scene elements (BSDFs, receiving plane, glass slab, emitter shape, etc.).

scene = {
    'type': 'scene',
    'sensor': sensor,
    'integrator': integrator,
    # Glass BSDF
    'simple-glass': {
        'type': 'dielectric',
        'id': 'simple-glass-bsdf',
        'ext_ior': 'air',
        'int_ior': 1.5,
        'specular_reflectance': { 'type': 'spectrum', 'value': 0 },
    },
    'white-bsdf': {
        'type': 'diffuse',
        'id': 'white-bsdf',
        'reflectance': { 'type': 'rgb', 'value': (1, 1, 1) },
    },
    'black-bsdf': {
        'type': 'diffuse',
        'id': 'black-bsdf',
        'reflectance': { 'type': 'spectrum', 'value': 0 },
    },
    # Receiving plane
    'receiving-plane': {
        'type': 'obj',
        'id': 'receiving-plane',
        'filename': 'meshes/rectangle.obj',
        'to_world': \
            mi.ScalarTransform4f.look_at(
                target=[0, 1, 0],
                origin=[0, -7, 0],
                up=[0, 0, 1]
            ).scale((5, 5, 5)),
        'bsdf': {'type': 'ref', 'id': 'white-bsdf'},
    },
    # Glass slab, excluding the 'exit' face (added separately below)
    'slab': {
        'type': 'obj',
        'id': 'slab',
        'filename': 'meshes/slab.obj',
        'to_world': mi.ScalarTransform4f.rotate(axis=(1, 0, 0), angle=90),
        'bsdf': {'type': 'ref', 'id': 'simple-glass'},
    },
    # Glass rectangle, to be optimized
    'lens': {
        'type': 'ply',
        'id': 'lens',
        'filename': lens_fname,
        'to_world': mi.ScalarTransform4f.rotate(axis=(1, 0, 0), angle=90),
        'bsdf': {'type': 'ref', 'id': 'simple-glass'},
    },

    # Directional area emitter placed behind the glass slab
    'focused-emitter-shape': {
        'type': 'obj',
        'filename': 'meshes/rectangle.obj',
        'to_world': mi.ScalarTransform4f.look_at(
            target=[0, 0, 0],
            origin=[0, 5, 0],
            up=[0, 0, 1]
        ),
        'bsdf': {'type': 'ref', 'id': 'black-bsdf'},
        'focused-emitter': emitter,
    },
}

Finally, we load the scene, which instantiates all of the appropriate plugins, loads the geometry, etc.

scene = mi.load_dict(scene)

Reference image#

Now that the sensor has been defined, we can load the reference image and ensure that its resolution matches the render resolution.

def load_ref_image(config, resolution, output_dir):
    b = mi.Bitmap(config['reference'])
    b = b.convert(mi.Bitmap.PixelFormat.RGB, mi.Bitmap.Float32, False)
    if b.size() != resolution:
        b = b.resample(resolution)

    mi.util.write_bitmap(join(output_dir, 'out_ref.exr'), b)

    print('[i] Loaded reference image from:', config['reference'])
    return mi.TensorXf(b)

# Make sure the reference image will have a resolution matching the sensor
sensor = scene.sensors()[0]
crop_size = sensor.film().crop_size()
image_ref = load_ref_image(config, crop_size, output_dir=output_dir)
scaled_image_ref = image_ref / dr.mean(image_ref)
[i] Loaded reference image from: /Users/speierer/projects/mitsuba3/tutorials/scenes/references/sunday-512.jpg

Displacement texture#

Rather than optimizing the unconstrained vertex positions of the lens directly, we optimize values of a high-resolution heightmap. Here, we create the heightmap texture and create an optimizer that will work on its values.

Notice how the traverse() method is used directly on our new texture object, rather than on the scene loaded earlier.

heightmap_texture = mi.load_dict({
    'type': 'bitmap',
    'id': 'heightmap_texture',
    'bitmap': mi.Bitmap(dr.zeros(mi.TensorXf, config['heightmap_resolution'])),
    'raw': True,
})

# Actually optimized: the heightmap texture
params = mi.traverse(heightmap_texture)
opt = mi.ad.Adam(lr=config['learning_rate'], params=params)

At each iteration, the lens’ vertices will be displaced from their original positions along their normals by the value of the heightmap. Don’t forget that the geometric resolution of the lens mesh (number of vertices) must also be high enough for this technique to work as expected.
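The displacement itself is a simple per-vertex operation. The following pure-Python sketch (with a hypothetical helper `displace` and made-up vertex data) shows the idea on plain tuples: every vertex moves along its normal by its height value, always starting from the initial positions so that displacements do not accumulate from one iteration to the next.

```python
def displace(positions, normals, heights, amplitude=1.0):
    """Move each vertex along its normal by its height value."""
    return [
        tuple(p + h * amplitude * n for p, n in zip(pos, nrm))
        for pos, nrm, h in zip(positions, normals, heights)
    ]

# Two vertices of a flat patch whose normals point along +z
positions_initial = [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0)]
normals_initial = [(0.0, 0.0, 1.0), (0.0, 0.0, 1.0)]

first = displace(positions_initial, normals_initial, [0.01, -0.005])
# Re-applying the same heights to the *initial* positions gives the same result
second = displace(positions_initial, normals_initial, [0.01, -0.005])
print(first, first == second)
```

Keeping the initial positions and normals around also means the heightmap stays the only quantity being optimized; the mesh is merely a function of it.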

params_scene = mi.traverse(scene)

# We will always apply displacements along the original normals and
# starting from the original positions.
positions_initial = dr.unravel(mi.Vector3f, params_scene['lens.vertex_positions'])
normals_initial   = dr.unravel(mi.Vector3f, params_scene['lens.vertex_normals'])

lens_si = dr.zeros(mi.SurfaceInteraction3f, dr.width(positions_initial))
lens_si.uv = dr.unravel(type(lens_si.uv), params_scene['lens.vertex_texcoords'])

def apply_displacement(amplitude = 1.):
    # Enforce reasonable range. For reference, the receiving plane
    # is 7 scene units away from the lens.
    vmax = 1 / 100.
    params['data'] = dr.clamp(params['data'], -vmax, vmax)
    # Track gradients with respect to the clamped heightmap values
    dr.enable_grad(params['data'])

    height_values = heightmap_texture.eval_1(lens_si)
    new_positions = (height_values * normals_initial * amplitude + positions_initial)
    params_scene['lens.vertex_positions'] = dr.ravel(new_positions)
    # Propagate the new vertex positions to the scene
    params_scene.update()

Running the optimization#

We’re finally ready to start the optimization itself!

At each iteration, we apply the current heightmap displacement to the lens surface and render the scene with automatic differentiation enabled.

We then compare the render to our target image with a scale-independent L2 loss. We divide out the average brightness in the loss so that the general brightness of the emitter (set arbitrarily) does not interfere with the optimization.
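The effect of dividing out the average brightness can be checked on a toy example. The sketch below uses plain Python lists and hypothetical helpers (`mean`, `scale_independent_l2`); it only illustrates the forward value of such a loss, whereas the loop in this tutorial additionally detaches the image mean so that it is treated as a constant during differentiation.

```python
def mean(values):
    return sum(values) / len(values)

def scale_independent_l2(image, reference):
    """Mean squared difference after normalizing both inputs by their mean."""
    img_mean = mean(image)
    ref_mean = mean(reference)
    return mean([(i / img_mean - r / ref_mean) ** 2
                 for i, r in zip(image, reference)])

reference = [2.0, 1.0, 4.0, 1.0]
image = [1.0, 2.0, 3.0, 2.0]
brighter = [3.0 * v for v in image]

# Scaling the rendered image does not change the loss value
print(scale_independent_l2(image, reference))
print(scale_independent_l2(brighter, reference))
```

Both calls print (up to rounding) the same value, and an image that is an exact multiple of the reference gives a loss of zero, which is the behavior wanted when the emitter intensity is arbitrary.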

After backpropagating through the computation graph, we use the gradients of the loss w.r.t. the heightmap values to update the heightmap.
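As an illustration of what such an update does, the sketch below performs one plain gradient descent step followed by a clamp to the allowed height range. This is a deliberately simplified stand-in: the optimizer object used in the tutorial may apply a more elaborate rule, and the helper `gradient_step` is hypothetical. The default learning rate (3e-5) and height bound (1/100) mirror the values used in this tutorial.

```python
def gradient_step(heights, grads, learning_rate=3e-5, vmax=0.01):
    """One projected gradient descent step on a flat list of height values."""
    updated = []
    for h, g in zip(heights, grads):
        h_new = h - learning_rate * g                  # move against the gradient
        updated.append(max(-vmax, min(vmax, h_new)))  # keep within [-vmax, vmax]
    return updated

heights = [0.0, 0.0099]
grads = [100.0, -1e4]
print(gradient_step(heights, grads))
```

The first value moves slightly downhill, while the second would overshoot the allowed range and is clamped to the upper bound.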

import time
start_time = time.time()
iterations = config['max_iterations']
loss_values = []

for it in range(iterations):
    t0 = time.time()

    # Apply displacement and update the scene BVH accordingly
    apply_displacement()

    # Perform a differentiable rendering of the scene
    image = mi.render(scene, params, seed=it, spp=config['spp'])

    # Scale-independent L2 function
    loss = dr.mean(dr.sqr((image / dr.mean(dr.detach(image))) - scaled_image_ref))

    # Back-propagate errors to input parameters
    dr.backward(loss)

    # Take a gradient step
    opt.step()

    # Carry over the update to our "latent variable" (the heightmap values)
    params.update(opt)

    # Log progress
    elapsed_ms = 1000. * (time.time() - t0)
    current_loss = loss[0]
    loss_values.append(current_loss)
    mi.Thread.thread().logger().log_progress(
        it / (iterations-1),
        f'Iteration {it:03d}: loss={current_loss:g} (took {elapsed_ms:.0f}ms)',
        'Caustic Optimization', '')

end_time = time.time()
print(((end_time - start_time) * 1000) / iterations, ' ms per iteration on average')
171.21176958084106  ms per iteration on average

Once the optimization has completed, we save the final heightmap and the corresponding lens with displacement applied.

fname = join(output_dir, 'heightmap_final.exr')
mi.util.write_bitmap(fname, params['data'])
print('[+] Saved final heightmap state to:', fname)

fname = join(output_dir, 'lens_displaced.ply')
# Make sure the mesh reflects the final heightmap before writing it
apply_displacement()
lens_mesh = [m for m in scene.shapes() if m.id() == 'lens'][0]
lens_mesh.write_ply(fname)
print('[+] Saved displaced lens to:', fname)
[+] Saved final heightmap state to: /Users/speierer/projects/mitsuba3/tutorials/getting_started/inverse_rendering/outputs/sunday/heightmap_final.exr
2022-07-01 13:54:17 INFO main [Mesh] Writing mesh to "/Users/speierer/projects/mitsuba3/tutorials/getting_started/inverse_rendering/outputs/sunday/lens_displaced.ply" ..
2022-07-01 13:54:17 INFO main [Mesh] "/Users/speierer/projects/mitsuba3/tutorials/getting_started/inverse_rendering/outputs/sunday/lens_displaced.ply": wrote 32258 faces, 16384 vertices (890 KiB in 1ms)
[+] Saved displaced lens to: /Users/speierer/projects/mitsuba3/tutorials/getting_started/inverse_rendering/outputs/sunday/lens_displaced.ply

Visualize results#

Plot the evolution of the loss and show the final result next to the reference.

import matplotlib.pyplot as plt

def show_image(ax, img, title):
    ax.imshow(mi.util.convert_to_bitmap(img))
    ax.axis('off')
    ax.set_title(title)

def show_heightmap(fig, ax, values, title):
    # Use the values passed in, reshaped to the 2D heightmap resolution
    values = values.numpy().reshape(config['heightmap_resolution'])
    im = ax.imshow(values, vmax=1e-4)
    fig.colorbar(im, ax=ax)
    ax.axis('off')
    ax.set_title(title)

fig, ax = plt.subplots(2, 2, figsize=(11, 10))
ax = ax.ravel()
ax[0].plot(loss_values)
ax[0].set_xlabel('Iteration'); ax[0].set_ylabel('Loss value'); ax[0].set_title('Convergence plot')

show_heightmap(fig, ax[1], params['data'], 'Final heightmap')
show_image(ax[2], image_ref, 'Reference')
show_image(ax[3], image,     'Final state')
plt.show()

Congratulations! Feel free to define your own target images at the top of the notebook and run this tutorial again.

If you would like to improve the quality of the results, you could try the following:

  • Letting the optimization run for more iterations

  • Tweaking the learning rate and sample count

  • Progressively increasing the resolution of the heightmap during the optimization, e.g. starting from a 16x16 heightmap and doubling the resolution every N iterations
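The last suggestion can be prototyped independently of the renderer. The sketch below is one possible approach on plain nested lists, with hypothetical helpers `resolution_schedule` and `upsample_2x`; it uses nearest-neighbor repetition for simplicity, and a smoother interpolation may well work better in practice. In the actual notebook, the heightmap texture and the optimizer state would likely need to be updated to the new resolution as well, which is not shown here.

```python
def resolution_schedule(start, target):
    """List of heightmap sizes obtained by doubling from start up to target."""
    sizes = [start]
    while sizes[-1] < target:
        sizes.append(min(2 * sizes[-1], target))
    return sizes

def upsample_2x(heightmap):
    """Double the resolution of a 2D nested list by repeating each texel."""
    out = []
    for row in heightmap:
        wide = [v for v in row for _ in range(2)]  # repeat along columns
        out.append(wide)
        out.append(list(wide))                     # repeat along rows
    return out

print(resolution_schedule(16, 128))
print(upsample_2x([[1, 2], [3, 4]]))
```

Starting coarse lets the early iterations settle the large-scale light distribution cheaply, and each doubling then adds finer detail on top of the current solution.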

See also#