Applying Accelerated Raymarching to Reduce Rendering Time

Aspiring_MLE
TL;DR Summary
I am attempting to add accelerated raymarching to HashNeRF, a PyTorch implementation of NeRF, with the goal of reducing rendering time. The issue is that when I run HashNeRF (100 iterations on a Google Colab T4 GPU) with the early termination implementation, the `rendering_time` is an order of magnitude larger than with the default HashNeRF (`rendering_time_early = 170 s` vs. `rendering_time = 12 s`), whereas I expected `rendering_time_early < rendering_time`.
Hello world! My name is Jorge and this is my first post on Physics Forums. It is great to be part of the community!

I hope this post can also be helpful to others so I will include as much detail as possible. Please feel free to ask for more material if needed.

I am attempting to add accelerated raymarching to HashNeRF (https://github.com/yashbhalgat/HashNeRF-pytorch/issues/28), a PyTorch implementation of NeRF (https://github.com/yashbhalgat/HashNeRF-pytorch), with the goal of reducing rendering time. Specifically, I am trying to implement an early ray termination technique in order to stop adding samples along any ray that has already accumulated enough opacity (accumulated opacity close to 1, i.e. transmittance close to 0).
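
For reference, the quantities involved can be written with the standard NeRF quadrature. The symbols below are notation chosen for this post (they are not identifiers from HashNeRF): $\sigma_i$ is the density at sample $i$, $\delta_i$ the distance to the next sample, $c_i$ the color, and $\varepsilon$ the termination threshold.

```latex
\alpha_i = 1 - e^{-\sigma_i \delta_i}, \qquad
T_i = \prod_{j<i} (1 - \alpha_j), \qquad
w_i = T_i \, \alpha_i, \qquad
\hat{C} = \sum_{i=1}^{N} w_i \, c_i

% Since w_i = T_i - T_{i+1}, the weights after sample k telescope:
\sum_{i>k} w_i = T_{k+1} - T_{N+1} \le T_{k+1}
```

So if a ray is stopped as soon as $T_{k+1} < \varepsilon$, the samples that were skipped could have changed each color channel by at most $\varepsilon$, assuming colors lie in $[0, 1]$.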

I found inspiration in the early ray termination method used in the instant-ngp implementation (https://github.com/NVlabs/instant-ngp/blob/master/src/testbed_nerf.cu).

The issue is that when I run HashNeRF (100 iterations on a Google Colab T4 GPU) with the early termination implementation, the `rendering_time` is an order of magnitude larger than with the default HashNeRF (`rendering_time_early = 170 s` vs. `rendering_time = 12 s`), whereas I expected `rendering_time_early < rendering_time`.

I took the following steps to incorporate accelerated raymarching into HashNeRF:

1. Defining an early termination technique


I started by adding an `early_ray_termination.py` file:

early_ray_termination.py:
import torch
import torch.nn.functional as F


def apply_early_termination(raw, z_vals, rays_d, transmittance, accumulated_rgb, accumulated_weights, i, early_termination_threshold):
    """
    Apply early ray termination logic during raymarching.

    Args:
        raw: The raw model output at the current sample [num_rays, 1, 4] or [num_rays, 4].
        z_vals: Sampled depth values along the ray [num_rays, num_samples along the ray].
        rays_d: Ray directions [num_rays, 3].
        transmittance: Current transmittance value (T) [num_rays, 1].
        accumulated_rgb: Accumulated RGB values for rays [num_rays, 3].
        accumulated_weights: Per-sample weights [num_rays, num_samples along the ray].
        i: Current sample index.
        early_termination_threshold: Threshold for stopping raymarching early.

    Returns:
        updated_transmittance, updated_rgb, updated_weights, stop_raymarching (0-dim bool tensor)
    """

    # Drop the singleton sample dimension so that every quantity below is per ray.
    # Without this, sigma [num_rays, 1] times delta [num_rays] broadcasts to
    # [num_rays, num_rays] instead of [num_rays].
    raw = raw.reshape(-1, raw.shape[-1])  # [num_rays, 4]

    # Extract sigma and RGB values from raw model output
    sigma = F.relu(raw[..., 3])  # Density [num_rays]
    rgb = torch.sigmoid(raw[..., :3])  # Color [num_rays, 3]

    # Compute delta between samples
    if i + 1 < z_vals.shape[1]:
        delta = z_vals[:, i + 1] - z_vals[:, i]
    else:
        delta = torch.full_like(z_vals[:, i], 1e10)  # Far plane
    delta = delta * torch.norm(rays_d, dim=-1)  # Convert to real-world distances [num_rays]

    # Compute alpha and weights
    alpha = 1. - torch.exp(-sigma * delta)  # Alpha for this sample [num_rays]
    weights = transmittance[:, 0] * alpha  # Contribution of this sample [num_rays]

    # Update transmittance [num_rays, 1]
    transmittance = transmittance * (1. - alpha[:, None])

    # Accumulate the weighted RGB contribution of this sample [num_rays, 3]
    accumulated_rgb = accumulated_rgb + weights[:, None] * rgb

    # Track the weight of sample i for every ray
    accumulated_weights[:, i] = weights

    # Check for early termination: true only once EVERY ray in the batch is below the threshold
    stop_raymarching = torch.max(transmittance) < early_termination_threshold

    return transmittance, accumulated_rgb, accumulated_weights, stop_raymarching

Let me explain the logic behind it. The `sigma` variable is the (volumetric) density at the sampled point along the ray, which is needed to compute the opacity at that point. I used `ReLU` so that the density is never negative. The `rgb` variable stores the color of the sampled point; I used `sigmoid` to map the color values to the range `[0, 1]`.

Next, the distance (`delta`) between consecutive sample points along the ray is calculated. Multiplying by the norm of `rays_d` converts the spacing in `z_vals` into a real-world distance, since the ray directions are not necessarily unit vectors. `delta` then determines the opacity of the sample (`alpha`), i.e. the fraction of the remaining light that is absorbed or scattered there, and `weights` is the contribution of that sample to the final rendered result.

In the last part, I add the sample's weighted color to the accumulated color, store its weight, and update the remaining light (`transmittance`). Raymarching is stopped when the remaining light of every ray in the batch (the maximum transmittance) is below the given threshold.
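
To make the arithmetic concrete, here is a minimal single-ray sketch of the same front-to-back compositing in plain Python. It is an illustration only: `composite_ray` and its arguments are hypothetical names, not part of HashNeRF, and it processes one ray at a time instead of a batch of tensors.

```python
import math


def composite_ray(sigmas, colors, deltas, threshold=1e-4):
    """Front-to-back compositing of one ray with early termination.

    sigmas: non-negative density per sample
    colors: (r, g, b) tuple in [0, 1] per sample
    deltas: distance to the next sample, per sample
    Returns (rgb, transmittance, samples_used).
    """
    transmittance = 1.0  # fraction of light still reaching the camera
    rgb = [0.0, 0.0, 0.0]
    used = 0
    for sigma, color, delta in zip(sigmas, colors, deltas):
        alpha = 1.0 - math.exp(-sigma * delta)  # opacity of this sample
        weight = transmittance * alpha          # its contribution to the pixel
        rgb = [acc + weight * c for acc, c in zip(rgb, color)]
        transmittance *= 1.0 - alpha            # light left for later samples
        used += 1
        if transmittance < threshold:           # nothing behind can matter
            break
    return rgb, transmittance, used


# Empty space, then an opaque red sample that hides the green samples behind it.
rgb, T, used = composite_ray(
    sigmas=[0.0, 50.0, 50.0, 50.0],
    colors=[(0, 0, 1), (1, 0, 0), (0, 1, 0), (0, 1, 0)],
    deltas=[1.0, 1.0, 1.0, 1.0],
)
print(used, rgb, T)
```

In this toy case the loop stops after the second sample, so the two green samples behind the surface are never evaluated. That skipped work is where any saving from early termination has to come from.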

2. Modifying the `render_rays` method


The `render_rays` function (https://github.com/yashbhalgat/HashNeRF-pytorch/blob/main/run_nerf.py) processes each ray by sampling points along it, querying the neural network and accumulating the results (color, depth, opacity). All the computations related to ray sampling and accumulation happen here, so this is where the early ray termination method has to be applied. Note that `train` (also in `run_nerf.py`) calls `render`, not `render_rays`. `render` handles the big picture: it divides the rays into batches, calls `render_rays` for each batch, and combines the results. `batchify_rays` bridges `render` and `render_rays` by splitting the large ray tensor from `render` into smaller chunks and calling `render_rays` on each chunk.
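
The chunking pattern described above can be sketched in a few lines. This is a simplified stand-in using Python lists: `batchify` and `fake_render_chunk` are hypothetical names, and the real functions in `run_nerf.py` operate on PyTorch tensors rather than lists.

```python
def batchify(rays, render_chunk, chunk=4):
    """Split `rays` into chunks, render each chunk, and merge the results per key."""
    merged = {}
    for start in range(0, len(rays), chunk):
        out = render_chunk(rays[start:start + chunk])
        for key, values in out.items():
            merged.setdefault(key, []).extend(values)
    return merged


def fake_render_chunk(ray_chunk):
    # Hypothetical stand-in for render_rays: one depth and one opacity per ray.
    return {
        "depth_map": [float(r) for r in ray_chunk],
        "acc_map": [1.0 for _ in ray_chunk],
    }


# Ten "rays" rendered in chunks of 4, 4 and 2.
result = batchify(list(range(10)), fake_render_chunk, chunk=4)
print(len(result["depth_map"]), len(result["acc_map"]))
```

The point of the pattern is that any early-termination logic placed inside the per-chunk function only ever sees one chunk of rays at a time.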

render_rays:
def render_rays(ray_batch,
                network_fn,
                network_query_fn,
                N_samples,
                embed_fn=None,
                retraw=False,
                lindisp=False,
                perturb=0.,
                N_importance=0,
                network_fine=None,
                white_bkgd=False,
                raw_noise_std=0.,
                verbose=False,
                pytest=False,
                early_termination_threshold=1e-4,
                enable_early_termination=True):

[...]

    # Initialize sparsity_loss to ensure it is defined when enable_early_termination=True
    sparsity_loss = torch.zeros((N_rays,), device=ray_batch.device)  # Match batch size

    if enable_early_termination:
        # Initialize accumulators for early termination
        transmittance = torch.ones((N_rays, 1), device=ray_batch.device)  # Start with T = 1
        accumulated_rgb = torch.zeros((N_rays, 3), device=ray_batch.device)  # Accumulated color
        accumulated_weights = torch.zeros_like(z_vals)  # [N_rays, N_samples]

        # Iterate through samples with early termination
        for i in range(N_samples):
            raw = network_query_fn(pts[:, i:i + 1], viewdirs, network_fn)

            transmittance, accumulated_rgb, accumulated_weights, stop_raymarching = apply_early_termination(
                raw, z_vals, rays_d, transmittance, accumulated_rgb, accumulated_weights, i, early_termination_threshold
            )

            if stop_raymarching:
                break

        # Normalize the weights; rays without any contribution keep zero weights
        weight_sum = torch.sum(accumulated_weights, dim=-1, keepdim=True)
        weights = torch.where(
            weight_sum > 0,
            accumulated_weights / (weight_sum + 1e-10),
            torch.zeros_like(accumulated_weights)
        )
        rgb_map = accumulated_rgb
        depth_map = torch.sum(accumulated_weights * z_vals, -1) / torch.sum(accumulated_weights, -1)
        disp_map = 1. / torch.max(1e-10 * torch.ones_like(depth_map), depth_map)
        # Compute accumulated opacity map
        acc_map = torch.sum(accumulated_weights, -1)

    else:
        # Original HashNeRF behavior (no early termination)
        raw = network_query_fn(pts, viewdirs, network_fn)
        rgb_map, disp_map, acc_map, weights, depth_map, sparsity_loss = raw2outputs(raw, z_vals, rays_d, raw_noise_std, white_bkgd, pytest=pytest)

[...]

The `early_termination_threshold` is set to be `1e-4`.

I initialized accumulators for transmittance, color (RGB) and weights (opacity) for each ray. Then the code iterates over the samples along the ray, querying the neural network for predictions (`raw`) at each sampled point. The `apply_early_termination` function calculates the contribution of the current sample to the final color and opacity, updates the accumulators and checks whether the ray's contribution has become negligible, in which case the loop breaks early. After the loop, the weights are normalized to ensure they sum to 1 for rays with contributions and the final RGB color (`rgb_map`), depth (`depth_map`), disparity (`disp_map`) and accumulated opacity (`acc_map`) are computed using the accumulated values.
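
As a sanity check on the stopping rule, the toy count below compares how many per-sample evaluations happen under two rules: stopping the whole batch only when the maximum transmittance drops below the threshold (which is what `torch.max(transmittance) < early_termination_threshold` tests), and stopping each ray individually. It is a plain-Python sketch with made-up alpha values; `count_queries` is a hypothetical helper, not HashNeRF code.

```python
def count_queries(alphas_per_ray, threshold=1e-4):
    """Count per-sample evaluations under two stopping rules.

    alphas_per_ray: one list of alpha values per ray (all the same length).
    Returns (batch_rule_count, per_ray_rule_count).
    """
    n_rays = len(alphas_per_ray)
    n_samples = len(alphas_per_ray[0])

    # Rule A: the batch stops only when max(T) over all rays is below the threshold.
    T = [1.0] * n_rays
    batch_count = 0
    for i in range(n_samples):
        batch_count += n_rays  # every ray is evaluated at this step
        T = [t * (1.0 - ray[i]) for t, ray in zip(T, alphas_per_ray)]
        if max(T) < threshold:
            break

    # Rule B: each ray stops as soon as its own T is below the threshold.
    per_ray_count = 0
    for ray in alphas_per_ray:
        t = 1.0
        for alpha in ray:
            per_ray_count += 1
            t *= 1.0 - alpha
            if t < threshold:
                break
    return batch_count, per_ray_count


# Three rays hit an opaque surface at the first sample; one ray sees only empty space.
batch = [[1.0] * 8, [1.0] * 8, [1.0] * 8, [0.0] * 8]
print(count_queries(batch))
```

With the batch-wide rule, a single ray that passes through empty space keeps its transmittance at 1, so the loop never breaks and every sample of every ray is still evaluated. If that is what happens here, the early-termination path does at least as much network work as the original single query over all samples, only split into `N_samples` small sequential calls, which may be one reason for the slowdown.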

Any help is really appreciated.
 