Raytracing In 2D, Using A GPU Generated BVH
Another month, another vaguely raytracing-related post.
If you've read some of the other stuff I've written about 2D raytracing (Raytracing In 2D, Hybrid Distance Field / Raytracing In 2D, etc.), you may have noticed that all of those require the acceleration structure to be generated on the CPU.
While that's fine for most use cases, it would be handy to be able to generate a BVH on the GPU itself (procedural or animated maps come to mind).
Building A BVH On The GPU
Originally, I attempted to do a similar thing to Raytracing In 2D, but on the GPU.
Needless to say, the number of passes required, combined with needing to perform a lot of nth_element operations, seemed rather impractical.
Rather than starting at the root node and working down, I flip this: start with buckets and work up towards the root node.
Basically, dice a grid by some power of 2 and bin lines by their center, then build up a BVH by combining the cells.



[Figures: the grid at the [2x4] / [4x2] level and at the [1x2] / [2x1] level]
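As a rough CPU-side illustration of the binning step (written in Python rather than GLSL so it can be run standalone), each line goes into the cell that contains its midpoint, and that cell's bounding box grows to fit it. The function name `bin_lines` and the assumption that coordinates lie in the unit square are mine, not taken from the shader source.

```python
def bin_lines(lines, grid_size):
    """Bin (x0, y0, x1, y1) lines by midpoint into a grid_size x grid_size grid.

    Assumes coordinates lie in [0, 1]. Returns {(cx, cy): (bbox, [line ids])},
    where bbox is (min_x, min_y, max_x, max_y) of the lines in that cell.
    """
    cells = {}
    for line_id, (x0, y0, x1, y1) in enumerate(lines):
        # Cell containing the midpoint (clamped so 1.0 lands in the last cell).
        cx = min(int((x0 + x1) * 0.5 * grid_size), grid_size - 1)
        cy = min(int((y0 + y1) * 0.5 * grid_size), grid_size - 1)
        line_bbox = (min(x0, x1), min(y0, y1), max(x0, x1), max(y0, y1))
        if (cx, cy) in cells:
            bbox, ids = cells[(cx, cy)]
            bbox = (min(bbox[0], line_bbox[0]), min(bbox[1], line_bbox[1]),
                    max(bbox[2], line_bbox[2]), max(bbox[3], line_bbox[3]))
            ids.append(line_id)
            cells[(cx, cy)] = (bbox, ids)
        else:
            cells[(cx, cy)] = (line_bbox, [line_id])
    return cells
```

Note that a cell's bounding box can extend beyond the cell itself, since only the midpoint decides which cell a line belongs to.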
With the node layout being:
struct BvhNodeV2
{
    uint leftType : 1;    // 0 = node, 1 = line bucket
    uint leftCount : 31;  // always 0 for node
    uint leftOffset;

    uint rightType : 1;   // 0 = node, 1 = line bucket
    uint rightCount : 31; // always 0 for node
    uint rightOffset;

    float4 leftBBox;
    float4 rightBBox;
};
Happily fitting within three float4s.
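To make the bit layout concrete: the type and count for each side share one 32-bit word, so the two metadata words plus the two offsets make up the first float4, followed by one float4 per bounding box. Below is a small Python sketch of that packing, mirroring the `type | (count << 1)` expression used in the shader further down. The helper names `pack_metadata` and `unpack_metadata` are hypothetical.

```python
NODE = 0         # type bit for an inner node
LINE_BUCKET = 1  # type bit for a line bucket

def pack_metadata(node_type, count):
    """Pack a 1-bit type and a 31-bit count into one 32-bit word."""
    assert node_type in (NODE, LINE_BUCKET)
    assert 0 <= count < (1 << 31)
    return (node_type | (count << 1)) & 0xFFFFFFFF

def unpack_metadata(word):
    """Return (type, count) from a packed 32-bit word."""
    return word & 1, word >> 1
```

For example, a bucket holding 16 lines packs to `pack_metadata(LINE_BUCKET, 16)`, and `unpack_metadata` recovers both fields from that word.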
We can also determine the amount of memory required up front.
Let's say we have a grid size of 8x8 and a node per cell:
8x8 + 4x8 + 4x4 + 2x4 + 2x2 + 1x2 + 1x1
= 64 + 32 + 16 + 8 + 4 + 2 + 1
= 127
= 2 * gridSize * gridSize - 1
Except our bottom level nodes always come in pairs, so the 8x8 level needs no nodes of its own, leaving us instead with:
4x8 + 4x4 + 2x4 + 2x2 + 1x2 + 1x1
= 32 + 16 + 8 + 4 + 2 + 1
= 63
= gridSize * gridSize - 1
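We can prove this via the geometric series.
You may have noticed the power of two sequence we have going on. Writing gridSize as $2^k$:
$$\sum_{i=0}^{2k-1} 2^i = 2^{2k} - 1 = \text{gridSize}^2 - 1$$
Alternatively, for every 4 cells (2x2), there will be 3 nodes, so summing over the $k$ merge levels:
$$\sum_{l=0}^{k-1} 3 \cdot 4^l = 3 \cdot \frac{4^k - 1}{4 - 1} = 4^k - 1 = \text{gridSize}^2 - 1$$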
Pedantic math aside, once you also account for the lines themselves (one float4 each), the total size is:
const uint bvhFloat4Size = 3;
uint numFloat4s = (gridSize * gridSize - 1) * bvhFloat4Size
                + numLines;
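As a quick sanity check of the arithmetic above, here is a short Python sketch that counts nodes level by level and compares the result with the closed form. The names `node_count` and `num_float4s` are illustrative only.

```python
BVH_FLOAT4_SIZE = 3  # float4s per node, as in the snippet above

def node_count(grid_size):
    """Sum the nodes per level: halve the cell count until one node is left."""
    count = 0
    cells = grid_size * grid_size
    while cells > 1:
        cells //= 2
        count += cells
    return count

def num_float4s(grid_size, num_lines):
    """Total buffer size in float4s: all nodes, then one float4 per line."""
    return (grid_size * grid_size - 1) * BVH_FLOAT4_SIZE + num_lines
```

For an 8x8 grid this gives 63 nodes, so a map with 100 lines needs 63 * 3 + 100 = 289 float4s.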
Tiny Bucket Problem
One problem we may encounter is that we end up with a lot of tiny buckets, which cause needless overhead when traversing.
Ideally we would "merge" line buckets when they are smaller than some threshold (say 16).
Iterating over buckets in Morton order means we can easily merge 4 neighbouring buckets together, because their lines are stored contiguously. As a bonus, spatially close data ends up close together in memory, increasing the likelihood of cache hits.
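The reason Morton order makes merging easy is that the four children of any 2x2 block get four consecutive codes, starting at four times the parent's code. Here is a Python sketch of the same bit-interleaving trick used by the shader's `encode16_morton2` (8-bit inputs, 16-bit result). The helper `children_codes` is hypothetical and only there to demonstrate the property.

```python
def encode16_morton2(x, y):
    """Interleave the low 8 bits of x and y into a 16-bit Morton code."""
    res = (x & 0xFF) | ((y & 0xFF) << 16)
    res = (res | (res << 4)) & 0x0F0F0F0F
    res = (res | (res << 2)) & 0x33333333
    res = (res | (res << 1)) & 0x55555555
    # x now sits in the even bits of the low half and y in the even bits of
    # the high half; shifting by 15 drops y into the odd bits.
    return (res | (res >> 15)) & 0xFFFF

def children_codes(px, py):
    """Morton codes of the 2x2 block of cells under parent cell (px, py)."""
    return [encode16_morton2(2 * px + dx, 2 * py + dy)
            for dy in (0, 1) for dx in (0, 1)]
```

For instance, the cells under parent (1, 1) get codes 12 to 15, so if their lines are written in bucket order, a merged bucket can reuse the first child's offset with the combined count.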

Putting It Together
We can generate the BVH in a single pass using a compute shader and shared memory (now possible to do in a browser with WebGPU).
Here is an example of how that may look:
Example BVH Gen Shader
#version 460 core
#extension GL_EXT_control_flow_attributes : require
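/*
    level count : grid dims
    1 : 2x2
    2 : 4x4
    3 : 8x8
    4 : 16x16
*/
#ifndef NUM_LEVELS
#define NUM_LEVELS 3
#endif // NUM_LEVELS

// Assumed definitions (likely provided by a shared include); the values are
// inferred from the node layout: 0 = node, 1 = line bucket, 3 float4s per node.
#define GRID_SIZE_DIM (1 << NUM_LEVELS)
#define BVH_V2_NODE_FLOAT4_STRIDE 3u
#define BVH_V2_METADATA_NODE 0u
#define BVH_V2_METADATA_LINE_BUCKET 1u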
layout(local_size_x=GRID_SIZE_DIM, local_size_y=GRID_SIZE_DIM) in;
readonly layout(std430, binding = 0) buffer inLines_ { vec4 inLines[]; };
writeonly layout(std430, binding = 1) buffer outBVH_ { vec4 outBVH[]; };
layout(binding=2) uniform numLines_ { uint numLines; };
// https://gist.github.com/JarkkoPFC/0e4e599320b0cc7ea92df45fb416d79a
uint encode16_morton2(uint x, uint y)
{
    uint res = x | (uint(y) << 16);
    res = (res | (res<<4)) & 0x0f0f0f0fu;
    res = (res | (res<<2)) & 0x33333333u;
    res = (res | (res<<1)) & 0x55555555u;
    return uint(res | (res>>15)) & 0xffffu;
}
// When accessing shared memory, we use a linear index
// based upon the thread id, to prevent bank conflicts.
// When accessing the BVH memory, we use morton encoding
// to create better cache coherency and to allow us to
// merge small line buckets together.
#define MAKE_LINEAR_ID(tid, brickSize) ((tid).y * (brickSize) + (tid).x)
#define MAKE_MORTON_ID(tid) (encode16_morton2((tid).x, (tid).y))
vec4 mergeBounds(vec4 a, vec4 b)
{
    return vec4(min(a.xy, b.xy), max(a.zw, b.zw));
}

float boundsArea(vec4 a)
{
    vec2 deriv = a.zw - a.xy;
    return max(0.0, deriv.x * deriv.y);
}
struct BVHNodeEntry
{
    uint type;
    uint offset;
    vec4 bbox;
    uint count;
};

BVHNodeEntry makeBottomLevelBVHEntry(uint offset, vec4 bbox, uint count)
{
    BVHNodeEntry result;
    // Bit 0 holds the type, the upper 31 bits hold the line count.
    result.type = BVH_V2_METADATA_LINE_BUCKET | (count << 1);
    result.offset = offset;
    result.bbox = bbox;
    result.count = count;
    return result;
}
BVHNodeEntry writeBVHNode(uint outputOffset,
                          BVHNodeEntry left,
                          BVHNodeEntry right)
{
    outBVH[outputOffset] = uintBitsToFloat(uvec4(
        left.type, left.offset, right.type, right.offset
    ));
    outBVH[outputOffset + 1] = left.count == 0 ? vec4(0) : left.bbox;
    outBVH[outputOffset + 2] = right.count == 0 ? vec4(0) : right.bbox;

    // If we didn't do a proper merge, just return whichever
    // value was valid, or return an invalid value
    // if neither were valid.
    if(right.count == 0) { return left; }
    else if(left.count == 0) { return right; }

    BVHNodeEntry merged;
    merged.type = BVH_V2_METADATA_NODE;
    merged.offset = outputOffset;
    merged.count = left.count + right.count;
    merged.bbox = mergeBounds(left.bbox, right.bbox);
    return merged;
}
BVHNodeEntry writeBVHLevel(uint outputOffset,
                           BVHNodeEntry A, BVHNodeEntry B,
                           BVHNodeEntry C, BVHNodeEntry D)
{
    bool done = false;
    BVHNodeEntry left;
    BVHNodeEntry right;

    // If all four children are small line buckets, merge them into one.
    const uint maxLinesToMerge = 16u;
    uint mergedCount = A.count + B.count + C.count + D.count;
    bool hasSmallCount = mergedCount <= maxLinesToMerge;
    bool allBuckets = all(equal(uvec4(A.type, B.type, C.type, D.type) & 1u,
                                uvec4(BVH_V2_METADATA_LINE_BUCKET)));
    if(hasSmallCount && allBuckets)
    {
        done = true;
        // Lines are stored in Morton order, so A's offset starts the range.
        uint offset = A.offset;
        vec4 mergedBbox = mergeBounds(mergeBounds(A.bbox, B.bbox),
                                      mergeBounds(C.bbox, D.bbox));
        left = makeBottomLevelBVHEntry(offset, mergedBbox, mergedCount);
        right = makeBottomLevelBVHEntry(0, vec4(1.0, 1.0, -1.0, -1.0), 0);
    }

    if(!done)
    {
        // Pick either [[A B] | [C D]]
        //          or [[A C] | [B D]]
        float ABarea = boundsArea(mergeBounds(A.bbox, B.bbox));
        float ACarea = boundsArea(mergeBounds(A.bbox, C.bbox));
        float CDarea = boundsArea(mergeBounds(C.bbox, D.bbox));
        float BDarea = boundsArea(mergeBounds(B.bbox, D.bbox));
        if((ABarea + CDarea) > (ACarea + BDarea))
        {
            BVHNodeEntry tmp = B;
            B = C;
            C = tmp;
        }

        left = writeBVHNode(outputOffset + BVH_V2_NODE_FLOAT4_STRIDE,
                            A, B);
        right = writeBVHNode(outputOffset + BVH_V2_NODE_FLOAT4_STRIDE * 2,
                             C, D);
    }

    return writeBVHNode(outputOffset, left, right);
}
shared uint bucketAllocator[GRID_SIZE_DIM * GRID_SIZE_DIM]; // 256 bytes
shared BVHNodeEntry nodeDataSwapA[GRID_SIZE_DIM * GRID_SIZE_DIM]; // 1792 or 2048 bytes
shared BVHNodeEntry nodeDataSwapB[(GRID_SIZE_DIM / 2) * (GRID_SIZE_DIM / 2)]; // 448 or 512 bytes
uint getNodeLevelOffset(uint level)
{
    // ((2^level)^2 - 1) nodes
    return ((1u << (level + level)) - 1u) * BVH_V2_NODE_FLOAT4_STRIDE;
}
void fetchNodeDataLDS(uvec4 ids,
                      uint nodeDataSide,
                      out BVHNodeEntry A,
                      out BVHNodeEntry B,
                      out BVHNodeEntry C,
                      out BVHNodeEntry D)
{
    // Even sides read from swap buffer A, odd sides from swap buffer B.
    if((nodeDataSide & 1u) == 0u)
    {
        A = nodeDataSwapA[ids.x];
        B = nodeDataSwapA[ids.y];
        C = nodeDataSwapA[ids.z];
        D = nodeDataSwapA[ids.w];
    }
    else
    {
        A = nodeDataSwapB[ids.x];
        B = nodeDataSwapB[ids.y];
        C = nodeDataSwapB[ids.z];
        D = nodeDataSwapB[ids.w];
    }
}
void buildNodeTree(uvec2 tid,
                   uint level,
                   uint nodeDataSide)
{
    uint inputBucketDim = (1u << (level+1u));
    uint outputBucketDim = (1u << level);

    if(all(lessThan(tid, uvec2(outputBucketDim))))
    {
        uint localBucketId = MAKE_LINEAR_ID(tid, outputBucketDim);
        uint Aid = MAKE_LINEAR_ID((tid * 2u + uvec2(0, 0)), inputBucketDim);
        uint Bid = MAKE_LINEAR_ID((tid * 2u + uvec2(1, 0)), inputBucketDim);
        uint Cid = MAKE_LINEAR_ID((tid * 2u + uvec2(0, 1)), inputBucketDim);
        uint Did = MAKE_LINEAR_ID((tid * 2u + uvec2(1, 1)), inputBucketDim);

        BVHNodeEntry A;
        BVHNodeEntry B;
        BVHNodeEntry C;
        BVHNodeEntry D;
        fetchNodeDataLDS(uvec4(Aid, Bid, Cid, Did),
                         nodeDataSide,
                         A, B, C, D);

        // Each output cell writes three nodes: one parent and two children.
        uint levelWriteStart = getNodeLevelOffset(level);
        uint nodeOffset = MAKE_MORTON_ID(tid);
        BVHNodeEntry bvhEntry = writeBVHLevel(levelWriteStart + 3u * BVH_V2_NODE_FLOAT4_STRIDE * nodeOffset,
                                              A, B, C, D);

        if(level != 0u)
        {
            // Write to the swap buffer we did not read from.
            if((nodeDataSide & 1u) == 0u)
            {
                nodeDataSwapB[localBucketId] = bvhEntry;
            }
            else
            {
                nodeDataSwapA[localBucketId] = bvhEntry;
            }
        }
    }

    if(level != 0u)
    {
        // Make this level's results visible before the next level reads them.
        barrier();
    }
}
void main()
{
    uvec2 tid = gl_LocalInvocationID.xy;
    uint bucketId = MAKE_MORTON_ID(tid);

    // We bucket lines based upon where their center is.
    // (Assumes line coordinates are normalised to the unit square.)
    vec4 bucketBbox = vec4(tid, tid + 1u) / float(GRID_SIZE_DIM);

    // This is the final bbox of the current cell, which starts
    // off as being invalid, but as we find relevant lines, we will
    // expand it to fit them all.
    vec4 finalBbox = vec4(1.0, 1.0, -1.0, -1.0);

    // Keep track of how many lines we need to home and the first and last
    // line id (to make it quicker to iterate over).
    uint bucketCount = 0u;
    uint firstLine = 0u;
    uint endLine = 0u;
    for(uint lineId = 0; lineId < numLines; ++lineId)
    {
        vec4 line = inLines[lineId];
        vec2 center = (line.xy + line.zw) * 0.5;
        if(all(greaterThanEqual(center.xy, bucketBbox.xy))
            && all(lessThan(center.xy, bucketBbox.zw)))
        {
            if(bucketCount == 0)
            {
                firstLine = lineId;
            }
            endLine = lineId + 1u;
            ++bucketCount;
            vec4 lineBbox = vec4(min(line.xy, line.zw), max(line.xy, line.zw));
            finalBbox = mergeBounds(finalBbox, lineBbox);
        }
    }

    bucketAllocator[bucketId] = bucketCount;
    // Every bucket's count must be written before we sum them.
    barrier();

    // Write lines after the BVH data
    uint bucketWriteOffset = getNodeLevelOffset(NUM_LEVELS);
    for(uint id=0; id<bucketId; ++id)
    {
        bucketWriteOffset += bucketAllocator[id];
    }

    uint ldsBucketId = MAKE_LINEAR_ID(tid, (1u << NUM_LEVELS));
    nodeDataSwapA[ldsBucketId] = makeBottomLevelBVHEntry(bucketWriteOffset,
                                                         finalBbox,
                                                         bucketCount);

    // Write out lines into their respective places
    for(uint lineId = firstLine; lineId < endLine; ++lineId)
    {
        vec4 line = inLines[lineId];
        vec2 center = (line.xy + line.zw) * 0.5;
        if(all(greaterThanEqual(center.xy, bucketBbox.xy))
            && all(lessThan(center.xy, bucketBbox.zw)))
        {
            outBVH[bucketWriteOffset++] = vec4(line.xy, line.xy - line.zw);
        }
    }

    // Bottom level entries must be in shared memory before merging starts.
    barrier();

    uint nodeDataSide = 0u;
    for(int level=(NUM_LEVELS-1); level >= 0; --level, ++nodeDataSide)
    {
        buildNodeTree(tid, uint(level), nodeDataSide);
    }
}
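The pairing choice in `writeBVHLevel` above can be tried out on the CPU. The following Python sketch mirrors `mergeBounds`, `boundsArea` and the comparison that decides between [[A B] | [C D]] and [[A C] | [B D]]. Bounding boxes are (min_x, min_y, max_x, max_y) tuples, and `choose_pairing` is a hypothetical helper name rather than anything from the shader source.

```python
def merge_bounds(a, b):
    """Smallest box containing both a and b."""
    return (min(a[0], b[0]), min(a[1], b[1]), max(a[2], b[2]), max(a[3], b[3]))

def bounds_area(a):
    """Area of a box, never negative."""
    return max(0.0, (a[2] - a[0]) * (a[3] - a[1]))

def choose_pairing(A, B, C, D):
    """Pick the pairing whose two merged boxes have the smaller summed area."""
    ab_cd = bounds_area(merge_bounds(A, B)) + bounds_area(merge_bounds(C, D))
    ac_bd = bounds_area(merge_bounds(A, C)) + bounds_area(merge_bounds(B, D))
    if ab_cd > ac_bd:
        return (A, C), (B, D)
    return (A, B), (C, D)
```

With two boxes stacked on the left and two stacked far to the right, pairing vertically gives the smaller total area, so the sketch returns (A, C) and (B, D). When the two options tie, it keeps the default (A, B) and (C, D), matching the strict `>` in the shader.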
The ray traversal is largely the same as previous stuff I've done (a stack-based method), so I'm not going to get into it here.
As always, there's a demo (at the top). This one allows you to edit the lines, with the editing handled entirely on the GPU side (picking a vertex and updating the lines buffer).
This was my first time using WebGPU. The API is sort of like a Vulkan-lite, or a primitive RHI abstraction (which I guess is what it kind of is), and is quite nice compared to WebGL (although it still has teething issues).
Here are all the relevant source code files:
- BVH Gen (GLSL): v2_tracing_generate_bvh.comp
- Ray Traversal (GLSL): v2_tracing.glsli
- Visualisers Ray (GLSL): v2_tracing_test.frag
- Visualisers BVH (GLSL): draw_tree.vert and draw_tree.frag
- JS used for demo: rtrace_bvh_v2.js