2.8 - Vertex Shader Optimization


What We're Learning

You've built amazing vertex shaders: flags that wave in the wind, fields of thousands of unique grass blades, and surfaces that pulse and deform. You've moved from "does it work?" to "does it look good?" Now it's time to ask the third crucial question: how fast is it?

Performance optimization isn't about making code cryptic or sacrificing quality. It's a fundamental shift in mindset to understand how the GPU works and to write shaders that work with the hardware, not against it. A well-optimized shader isn't just about hitting a target framerate; it's a creative enabler. The performance budget you save can be spent on richer effects, more complex geometry, and denser scenes. An optimized shader is the difference between rendering a single tree and rendering an entire forest.

To unlock this performance, we need to think like a GPU. The GPU is a massively parallel machine that achieves incredible speed through specialization. It wants to perform simple, predictable work on huge batches of data. Throughout this article, we'll learn to align our code with this model by focusing on a few key principles:

  • Performing simple, uniform calculations across many vertices at once.

  • Avoiding divergent branches that force parts of the GPU to wait.

  • Accessing memory in efficient, predictable patterns.

  • Moving work that is constant for all vertices from the GPU to the CPU.

  • Leveraging the GPU's highly optimized built-in functions.

By the end of this article, you will have a deep understanding of not just what to optimize, but why it works, enabling you to write high-performance shaders for any task.

You will learn:

  • How vertex shaders execute on the GPU (the SIMD/wavefront model).

  • Why certain operations are expensive and how to avoid them.

  • How to move constant calculations from the GPU to the CPU.

  • The true cost of branching and the techniques to minimize or eliminate it.

  • The importance of memory access patterns and cache efficiency.

  • How to leverage GPU-optimized built-in functions for maximum speed.

  • How to use profiling tools and techniques in Bevy to find bottlenecks.

  • A complete, practical optimization workflow: measure, optimize, and verify.

Understanding Vertex Shader Execution

To optimize a shader, you must first understand how it actually runs on the GPU. The GPU is not a faster CPU; it's a fundamentally different kind of processor built for massive parallelism. Picturing it as processing vertices one by one is a common and costly mistake for shader programmers.

The SIMD/Wavefront Execution Model

GPUs achieve their speed by processing data in large, synchronized groups. Instead of executing your vertex shader on one vertex at a time, the GPU runs it on a batch of vertices simultaneously, typically 32 or 64 depending on the hardware. This batch is called a wavefront on AMD hardware, a warp on NVIDIA, a SIMD-group in Metal, and a subgroup in Vulkan. This article uses "wavefront" throughout.

Think of a wavefront as a small platoon of soldiers. The entire platoon receives the same command at the same time and must execute it in perfect lockstep. This architecture is called SIMD (Single Instruction, Multiple Data). One instruction (e.g., "add 5 to the Y position") is applied to a large set of different data (the positions of all 64 vertices in the wavefront).

CPU Model (One at a time):           GPU Model (Wavefront in lockstep):

Vertex 1 → Process → Done            [Vertex 1, Vertex 2, ..., Vertex 64]
Vertex 2 → Process → Done                      ▼
Vertex 3 → Process → Done            Process ALL with the SAME instruction
...                                            ▼
Vertex N → Process → Done            All 64 are done simultaneously

Time to process 64 vertices: ~64x    Time to process 64 vertices: ~1x

This lockstep execution is the source of the GPU's incredible power, but it comes with a critical trade-off that has profound implications for our shader code.

What This Means for Your Code

Let's see how this plays out with two simple examples.

Uniform Code: The Ideal Scenario

First, consider a simple, uniform calculation where every vertex in the wavefront performs the exact same operation.

// Every vertex does the same calculation. This is ideal.
let height = position.y * 2.0;

In this case, the entire wavefront executes the instruction in perfect unison. There is no waiting and no wasted time. All hardware resources are used with maximum efficiency.

All 64 vertices execute the same instruction.
[✓✓✓✓✓✓✓✓✓✓✓✓✓✓✓✓✓✓✓✓✓✓✓✓✓✓✓✓✓✓✓✓✓✓✓✓✓✓✓✓✓✓✓✓✓✓✓✓✓✓✓✓✓✓✓✓✓✓✓✓✓✓✓✓]
    ↓
  All vertices multiply their `position.y` by 2.0.
    ↓
  Done!

Total Time: 1x (Optimal)

This is the GPU working at its best. The code is uniform across all vertices.

Divergent Code: The Performance Problem

Now, let's introduce an if/else statement where the condition depends on each vertex's unique position.

// Different vertices will take different paths. This is a problem.
var height: f32;
if position.y > 0.5 {
    // Path A: Some vertices do this
    height = position.y * 2.0;
} else {
    // Path B: The other vertices do this
    height = position.y * 0.5;
}

Because the entire wavefront must execute in lockstep, it cannot split up and run both paths simultaneously. Instead, the hardware serializes the execution:

  1. The if condition is evaluated for all 64 vertices. Some are true, some are false.

  2. Path A Execution: The GPU executes the code for the true branch (height = position.y * 2.0). During this step, all vertices that evaluated to false are temporarily disabled and forced to wait.

  3. Path B Execution: The GPU executes the code for the else branch (height = position.y * 0.5). Now, all the vertices that took the first path are disabled and wait.

  4. The paths converge, and the full wavefront becomes active again.

The wavefront was forced to execute both branches, so the if/else cost roughly the sum of both paths instead of just one. This phenomenon is called branch divergence, and it is one of the most common sources of lost performance in shader code, especially when the branches are long.
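To make the serialization concrete, here is a toy CPU model of a single wavefront running the if/else above. It is a deliberate simplification: one "pass" per branch, a boolean execution mask, and no real hardware timing. The function name `run_wavefront` is hypothetical. It only counts how many branch passes the wavefront needs.

```rust
/// Toy model of one wavefront executing `if y > 0.5 { y * 2.0 } else { y * 0.5 }`
/// in lockstep. Each branch that any lane needs costs one "pass" over the whole
/// wavefront. Lanes that did not take that branch are masked off and sit idle.
/// Returns the per-lane results and the number of passes.
fn run_wavefront(ys: &[f32]) -> (Vec<f32>, u32) {
    // 1. Evaluate the condition for every lane.
    let mask: Vec<bool> = ys.iter().map(|&y| y > 0.5).collect();
    let mut heights = vec![0.0_f32; ys.len()];
    let mut passes = 0;

    // 2. Path A: runs if at least one lane took it.
    if mask.iter().any(|&m| m) {
        passes += 1;
        for (i, &y) in ys.iter().enumerate() {
            if mask[i] {
                heights[i] = y * 2.0; // active lane
            } // else: masked-off lane, idle during this pass
        }
    }

    // 3. Path B: runs if at least one lane took it.
    if mask.iter().any(|&m| !m) {
        passes += 1;
        for (i, &y) in ys.iter().enumerate() {
            if !mask[i] {
                heights[i] = y * 0.5;
            }
        }
    }

    // 4. The paths converge.
    (heights, passes)
}
```

A wavefront whose lanes all agree needs one pass. A half-and-half wavefront needs two, even though each lane only uses the result of one.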

Visualizing Wavefront Divergence

Uniform Code (Coherent Branch, No Divergence):

Imagine a wavefront where all vertices happen to have position.y > 0.5.

All 64 vertices take the same path.
[✓✓✓✓✓✓✓✓✓✓✓✓✓✓✓✓✓✓✓✓✓✓✓✓✓✓✓✓✓✓✓✓✓✓✓✓✓✓✓✓✓✓✓✓✓✓✓✓✓✓✓✓✓✓✓✓✓✓✓✓✓✓✓✓]
    ↓
  All vertices execute Path A. Path B is skipped entirely.
    ↓
  Done!

Total Time: 1x (Optimal)

Divergent Code (Incoherent Branch):

Now, imagine half the vertices are above 0.5 and half are below.

32 vertices take Path A (✓), 32 take Path B (✗)
[✓✓✓✓✓✓✓✓✓✓✓✓✓✓✓✓✓✓✓✓✓✓✓✓✓✓✓✓✓✓✓✓✗✗✗✗✗✗✗✗✗✗✗✗✗✗✗✗✗✗✗✗✗✗✗✗✗✗✗✗✗✗✗✗]
    ↓
  Step 1: Execute Path A.
  (The ✗ vertices are inactive and wait.)
    ↓
  Step 2: Execute Path B.
  (The ✓ vertices are now inactive and wait.)
    ↓
  Done!

Total Time: ~2x (or worse, if the branches are complex)

The key takeaway is that divergence forces sequential execution on parallel hardware, neutralizing its main advantage.

When Branches Are Acceptable

Not all branches are bad. The performance penalty only occurs with divergent branches. A branch is perfectly fine if it is coherent - meaning all vertices in a wavefront take the same path.

This typically happens under two conditions:

  1. The condition is based on uniform data: Since uniforms are the same for all vertices, the entire wavefront will always choose the same path.

  2. The condition is based on data that is likely to be the same for large groups of vertices: For example, skipping expensive animation based on distance to the camera. A wavefront that straddles the distance threshold might diverge, but the vast majority will be coherently "in" or "out."

// ✓ OK - Coherent Branch.
// The material's `mode` is a uniform, so all vertices in this
// draw call will take the same path.
if material.mode == 0u {
    // All vertices do this OR none do.
    displacement = wave_displacement(position, time);
} else {
    // All vertices do this OR none do.
    displacement = noise_displacement(position, time);
}

// ✓ OK - Mostly Coherent Branch for an early exit.
// This saves a huge amount of work for the many wavefronts that are far away.
if distance_to_camera > 100.0 {
    // Skip expensive animation for distant objects.
    return simple_vertex_output(position);
}

// ✗ BAD - Divergent Branch.
// The result of this `if` will be different for nearly every
// vertex, causing maximum divergence.
if sin(position.x * 10.0) > 0.0 {
    // Some vertices do this...
    height = complex_calculation_a();
} else {
    // ...while others do this.
    height = complex_calculation_b();
}
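A small sketch can show why sorted, slowly changing conditions stay coherent while noisy ones diverge almost everywhere. As a simplification, it assumes that consecutive vertices are grouped into consecutive wavefronts of 64. Real hardware batches vertices from the index stream, so treat this as a model rather than a measurement. `count_divergent_wavefronts` is a hypothetical helper.

```rust
/// Counts how many wavefronts diverge for a per-vertex condition,
/// assuming consecutive vertices form consecutive wavefronts.
/// A wavefront diverges when its lanes disagree on the condition.
/// `wave_size` must be non-zero.
fn count_divergent_wavefronts(conditions: &[bool], wave_size: usize) -> usize {
    conditions
        .chunks(wave_size)
        .filter(|wave| {
            let first = wave[0];
            // Any lane that disagrees with the first one forces both paths.
            wave.iter().any(|&c| c != first)
        })
        .count()
}
```

In the usage below, a distance-style condition over 1,024 sorted vertices diverges only in the one wavefront that straddles the threshold. A condition that flips every other vertex, like the `sin(position.x * 10.0)` example, diverges in all 16.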

The Golden Rule

The fundamental principle of vertex shader optimization can be summarized in one rule:

Minimize divergence, maximize uniformity.

If all vertices in a wavefront are doing the same simple thing, the GPU will fly. The more you introduce divergence and complex, per-vertex decision-making, the more you force the GPU to slow down and serialize its work.

Optimization Strategy 1: Move Constant Calculations to CPU

Now that we understand the GPU's execution model, we can begin with our first and often easiest optimization strategy: if a calculation produces the same result for every single vertex in a draw call, it does not belong on the GPU.

The vertex shader's job is to process unique, per-vertex data. Any work that is constant across all vertices is a waste of the GPU's massively parallel power. It's like asking an entire army of 40,000 soldiers to individually calculate sin(0.5) - they will all get the same answer, but you've wasted 39,999 identical calculations.

Identifying Expensive Operations

To understand why this is so important, let's look at the relative costs of common shader operations. While the exact numbers vary between GPU architectures, the relative differences are a powerful mental model.

| Operation Type | Relative Cost (Approximate) | Notes |
|---|---|---|
| Addition / Subtraction | 1x (baseline) | The fastest operations. |
| Multiplication | 1x | As fast as addition. |
| Division | ~4-8x | Significantly slower. Avoid where possible. |
| sqrt (square root) | ~4-8x | An expensive operation. For vector normalization, inverseSqrt is a faster alternative to dividing by sqrt. We'll explore this in "Strategy 3." |
| sin, cos (trig) | ~8-16x | Very expensive. Minimize usage. |
| pow (power) | ~10-20x | Also very expensive. |
| Texture sample | ~20-200x | Involves memory access. Very slow. |
| Divergent branch | Cost of A + cost of B | Can be extremely costly. |

The takeaway is clear: trigonometry, powers, square roots, and divisions are significantly more expensive than simple addition and multiplication. Our first goal should be to minimize these expensive operations, especially when they are redundant.

The Precomputation Pattern

Let's look at a common mistake: calculating a time-based value inside the vertex shader.

✗ BAD: Computing sin/cos Per-Vertex

// in: MyMaterial.wgsl
struct MyMaterial {
    time: f32,
    // ... other uniforms
}

@group(2) @binding(0) var<uniform> material: MyMaterial;

@vertex
fn vertex(in: VertexInput) -> VertexOutput {
    // ✗ BAD: Every vertex computes sin(time) and cos(time).
    // These values are identical for all vertices in this draw call.
    let wave = sin(material.time) * 0.5;
    let rotation_amount = cos(material.time * 0.5);

    // ... use wave and rotation_amount ...
}

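If our mesh has 10,000 vertices, this code can perform 20,000 expensive trigonometric calculations every frame. Some shader compilers can detect that a value is identical across a wavefront and compute it once per wavefront instead of once per vertex. You cannot count on that across every GPU and driver, though.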

✓ GOOD: Precomputing on the CPU

The solution is to perform the calculation just once on the CPU each frame and pass the result to the GPU as a uniform.

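// in: my_material.rs (Rust code)
use bevy::prelude::*;

// A system that runs once per frame.
// Assumes `MyMaterial` holds its shader data in a `uniforms` field
// (marked `#[uniform(0)]`) whose layout matches the WGSL struct.
fn update_material(
    time: Res<Time>,
    mut materials: ResMut<Assets<MyMaterial>>,
) {
    // Read the elapsed time once; it's the same for every material.
    let t = time.elapsed_secs();

    // Iterate over all instances of our material that exist
    for (_, material) in materials.iter_mut() {
        material.uniforms.time = t;
        // Compute these expensive values ONCE on the CPU
        material.uniforms.time_sin = t.sin();
        material.uniforms.time_cos = (t * 0.5).cos();
    }
}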

Now, we update our shader to use these precomputed values.

// in: MyMaterial.wgsl
struct MyMaterial {
    time: f32,
    time_sin: f32,  // Precomputed on CPU
    time_cos: f32,  // Precomputed on CPU
}
@group(2) @binding(0) var<uniform> material: MyMaterial;


@vertex
fn vertex(in: VertexInput) -> VertexOutput {
    // ✓ GOOD: Just read the precomputed values from the uniform buffer.
    // This is just a fast memory read.
    let wave = material.time_sin * 0.5;
    let rotation_amount = material.time_cos;

    // ... use wave and rotation_amount ...
}

The savings add up quickly. We've replaced up to 20,000 trigonometric evaluations on the GPU with just two on the CPU. What remains in the shader is a uniform read and a multiply.
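The same split between per-frame work and per-vertex work can be mirrored in plain Rust to check that the precomputed path gives identical results. This is a CPU-side sketch only. `MyMaterialUniforms`, `vertex_values`, and `vertex_values_naive` are hypothetical names chosen to match the shader struct above, and no Bevy types are involved.

```rust
/// CPU-side mirror of the shader's uniform struct (hypothetical names).
#[derive(Debug, Clone, Copy, PartialEq)]
struct MyMaterialUniforms {
    time: f32,
    time_sin: f32,
    time_cos: f32,
}

impl MyMaterialUniforms {
    /// Runs once per frame on the CPU: two trig calls in total.
    fn for_time(t: f32) -> Self {
        Self {
            time: t,
            time_sin: t.sin(),
            time_cos: (t * 0.5).cos(),
        }
    }
}

/// Per-vertex work after precomputation: reads and one multiply.
fn vertex_values(u: &MyMaterialUniforms) -> (f32, f32) {
    let wave = u.time_sin * 0.5;
    let rotation_amount = u.time_cos;
    (wave, rotation_amount)
}

/// The original per-vertex version: two trig calls for every vertex.
fn vertex_values_naive(time: f32) -> (f32, f32) {
    let wave = time.sin() * 0.5;
    let rotation_amount = (time * 0.5).cos();
    (wave, rotation_amount)
}
```

Both paths produce the same numbers. Only the place where the trigonometry happens changes.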

Important Caveat: Position-Dependent Calculations

This technique is powerful, but it has a critical limitation: it only works for calculations that are independent of per-vertex attributes like position, normal, or uv.

Consider a spatial wave effect:

// This CANNOT be precomputed on the CPU
let wave = sin(position.x * frequency + time);

Here, sin() depends on both time (a uniform) and position.x (a vertex attribute). Since position.x is different for every vertex, the result of the sin() call will also be different. This calculation must happen in the vertex shader. For these cases, we rely on other optimization strategies, like reducing the complexity or using a Level of Detail (LOD) system, which we'll cover later.

What to Precompute

Always precompute on the CPU if a value is the same for all vertices. Common candidates include:

  • Time-based animations: Any sin(time), cos(time), or other function that only depends on a time uniform.

  • Complex uniform expressions: If you are combining multiple uniforms in a complex formula, do it once on the CPU.

  • Matrix inversions and transposes: The normal matrix (transpose(inverse(model))) is a classic example. It's an expensive calculation that is constant for all vertices of a given mesh.

Example: The Normal Matrix

// ✓ GOOD - Compute normal matrix on CPU (conceptual Rust, glam types)
let model_matrix: Mat4 = transform.compute_matrix();
// The normal matrix is the inverse-transpose of the upper-left 3x3.
let normal_matrix: Mat3 = Mat3::from_mat4(model_matrix).inverse().transpose();
// ... upload both model_matrix and normal_matrix to the GPU ...

// in shader: just read the precomputed matrix.
// Re-normalize, since scaling changes the transformed normal's length.
let world_normal = normalize(material.normal_matrix * in.normal);

This is vastly superior to calculating it per-vertex:

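// ✗ BAD - Computing inverse-transpose per vertex
let model_3x3 = mat3x3(model[0].xyz, model[1].xyz, model[2].xyz);
// WGSL has no built-in inverse(), so `inverse_3x3` stands in for a
// hand-written helper. This is extremely expensive, and every vertex pays for it!
let normal_matrix = transpose(inverse_3x3(model_3x3));
let world_normal = normalize(normal_matrix * in.normal);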
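To see how much work the per-vertex version hides, here is a sketch of a 3x3 inverse-transpose in plain Rust using the cofactor method. Each row of the cofactor matrix is the cross product of the other two rows, and dividing by the determinant gives the inverse-transpose directly. The `Mat3` alias and the function names are hypothetical (this is not glam's `Mat3`), and `mul` assumes row-major storage. Because the result is an inverse-transpose, `inverse_transpose` itself gives consistent results for row-major or column-major input, as long as you read the output the same way.

```rust
type Vec3 = [f32; 3];
type Mat3 = [Vec3; 3]; // three rows

fn cross(a: Vec3, b: Vec3) -> Vec3 {
    [
        a[1] * b[2] - a[2] * b[1],
        a[2] * b[0] - a[0] * b[2],
        a[0] * b[1] - a[1] * b[0],
    ]
}

fn dot(a: Vec3, b: Vec3) -> f32 {
    a[0] * b[0] + a[1] * b[1] + a[2] * b[2]
}

/// Inverse-transpose of a 3x3 matrix: cofactor matrix / determinant.
/// Returns None if the matrix is (nearly) singular.
fn inverse_transpose(m: Mat3) -> Option<Mat3> {
    let c0 = cross(m[1], m[2]);
    let c1 = cross(m[2], m[0]);
    let c2 = cross(m[0], m[1]);
    let det = dot(m[0], c0);
    if det.abs() < 1e-8 {
        return None;
    }
    let inv_det = 1.0 / det;
    let scale = |v: Vec3| [v[0] * inv_det, v[1] * inv_det, v[2] * inv_det];
    Some([scale(c0), scale(c1), scale(c2)])
}

/// Row-major matrix times column vector.
fn mul(m: Mat3, v: Vec3) -> Vec3 {
    [dot(m[0], v), dot(m[1], v), dot(m[2], v)]
}
```

That is three cross products, a dot product, a division, and nine multiplies. The CPU version performs this once per object per frame rather than once per vertex. The usage below checks that a tangent and a normal that start perpendicular stay perpendicular under a non-uniform scale.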

Precomputation Checklist

Review your vertex shader and ask these questions for every calculation:

  1. Does this value depend only on uniforms (like time, color, mode)? → Precompute on CPU.

  2. Does this involve expensive operations (sin, cos, pow, sqrt, inverse, transpose)? → Precompute if possible.

  3. Is this value calculated every frame but rarely changes? → Cache it on the CPU and only update it when needed.

  4. Is this value the same for every single vertex being drawn? → Precompute on CPU.

Optimization Strategy 2: Avoid Complex Branching

We've established that branch divergence is a primary enemy of GPU performance. An if/else statement based on per-vertex data forces the GPU's parallel hardware into a sequential, one-path-then-the-other execution model, destroying its efficiency.

The solution is to transform our code from a "control flow" problem (choosing which code to run) into a "data flow" problem (calculating a result with math). GPUs are phenomenal at math. By replacing if statements with mathematical equivalents, we can create branchless code that runs uniformly across all vertices in a wavefront, keeping the hardware fully saturated and performing at its peak.

Technique 1: Replace "Choosing" with "Blending"

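The most common pattern for eliminating a branch is to calculate the results of both paths and then use a mathematical function to blend between them. This might seem counterintuitive (why do more math?), but a divergent wavefront already ends up executing both paths. The branchless version does the same work as a single, straight-line instruction stream, with no branch instructions or lane masking. It also costs the same whether or not the wavefront would have diverged. This works best when both paths are cheap.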

Our main tools for this are step() and mix().

  • step(edge, x): This is a simple threshold function. If x is less than edge, it returns 0.0. Otherwise, it returns 1.0. It's a perfect mathematical switch.

  • mix(a, b, t): This performs a linear interpolation. It returns a when t is 0.0, returns b when t is 1.0, and blends linearly in between (WGSL defines it as a * (1.0 - t) + b * t).

By combining them, we can create a powerful branchless equivalent to if/else.

✗ BAD: Branching

var height: f32;
// If the vertex is above y=0, make it taller.
// If it's below, make it shorter.
if position.y > 0.0 {
    height = position.y * 2.0; // Path A
} else {
    height = position.y * 0.5; // Path B
}

✓ GOOD: Branchless Blending

// 1. Create a switch that is 0.0 for Path B and 1.0 for Path A.
let is_positive = step(0.0, position.y); // Returns 0.0 or 1.0

// 2. Use the switch as the blend factor in mix().
let height = mix(
    position.y * 0.5,   // Value when is_positive is 0.0 (Path B)
    position.y * 2.0,   // Value when is_positive is 1.0 (Path A)
    is_positive         // The blend factor (our switch)
);

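On a modern GPU, this branchless version runs as one uniform instruction stream. Every vertex computes both position.y * 0.5 and position.y * 2.0 and then uses is_positive to pick the result, so the wavefront never diverges. WGSL also provides select(false_value, true_value, condition), which expresses the same pick directly: select(position.y * 0.5, position.y * 2.0, position.y > 0.0). For a branch this small, shader compilers often generate this kind of select on their own, so profile before assuming a big win.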
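Here is a CPU-side mirror in Rust that checks the branchless version against the branching one. `step` and `mix` reimplement the WGSL definitions, with mix written as a * (1 - t) + b * t. The Rust `if` inside `step` is just how the CPU version is written; on the GPU, step typically compiles to a compare-and-select rather than a jump. All function names are hypothetical.

```rust
/// Mirrors WGSL `step(edge, x)`: 0.0 if x < edge, else 1.0.
fn step(edge: f32, x: f32) -> f32 {
    if x < edge {
        0.0
    } else {
        1.0
    }
}

/// Mirrors WGSL `mix(a, b, t)`, defined as a * (1 - t) + b * t.
fn mix(a: f32, b: f32, t: f32) -> f32 {
    a * (1.0 - t) + b * t
}

/// The branching version from the text.
fn height_branching(y: f32) -> f32 {
    if y > 0.0 {
        y * 2.0
    } else {
        y * 0.5
    }
}

/// The branchless version: compute both paths, then blend with a 0/1 switch.
fn height_branchless(y: f32) -> f32 {
    let is_positive = step(0.0, y);
    mix(y * 0.5, y * 2.0, is_positive)
}
```

The two versions agree for all finite inputs. With an infinite y they differ, because the unused path is multiplied by 0.0 and inf * 0.0 is NaN. Keep this in mind for any blend-based branch removal.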

Technique 2: Use min, max, and clamp for Range Checks

Another common use for if is to clamp a value within a certain range. This should almost always be replaced with the dedicated built-in functions, which typically compile to one or two fast hardware instructions.

✗ BAD: Branching

if height < 0.0 {
    height = 0.0;
}
if value > 1.0 {
    value = 1.0;
} else if value < 0.0 {
    value = 0.0;
}

✓ GOOD: Branchless

height = max(height, 0.0);
value = clamp(value, 0.0, 1.0);

Technique 3: Move the Condition from Code to Data

Sometimes, you can eliminate a branch by encoding the condition directly into your mesh data as a vertex attribute. This is an advanced but powerful technique for cases like a cloth simulation where some vertices are "pinned" and should not move.

✗ BAD: Branching on Vertex ID or Position

// This would be extremely divergent if pinned vertices
// are scattered throughout the mesh.
if is_pinned_position(position) {
    // Don't move this vertex
    return position;
} else {
    // Apply physics
    return position + displacement;
}

✓ GOOD: Using a Vertex Attribute as a Mask

In your Rust code, when you generate the mesh, you add a custom attribute.

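// In mesh generation (conceptual Rust)
// A custom vertex attribute. The name and id are arbitrary, but the id
// must not collide with any other attribute's id.
const ATTRIBUTE_PINNED: MeshVertexAttribute =
    MeshVertexAttribute::new("Vertex_Pinned", 988540917, VertexFormat::Float32);

let pinned_flags: Vec<f32> = my_vertices
    .iter()
    .map(|v| if v.is_pinned { 1.0 } else { 0.0 })
    .collect();

mesh.insert_attribute(ATTRIBUTE_PINNED, pinned_flags);
// The material's `specialize` function must map ATTRIBUTE_PINNED to
// @location(4) when it builds the vertex buffer layout.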

Now, the shader can use this attribute as a simple multiplier, completely avoiding a branch.

// In the VertexInput struct: read the custom attribute
@location(4) is_pinned: f32, // 0.0 = free, 1.0 = pinned

// ... inside the vertex shader ...

// No branch needed! Just multiply the displacement by 0.0 or 1.0.
let final_displacement = displacement * (1.0 - is_pinned);
return position + final_displacement;
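The masking idea can be mirrored on the CPU to check that pinned vertices stay put. This plain-Rust sketch uses hypothetical helpers (`pinned_flags`, `apply_displacement`) and stores positions as `[f32; 3]`. It mirrors the shader math rather than calling any Bevy API.

```rust
/// Builds the per-vertex attribute data: 1.0 = pinned, 0.0 = free.
fn pinned_flags(pinned: &[bool]) -> Vec<f32> {
    pinned.iter().map(|&p| if p { 1.0 } else { 0.0 }).collect()
}

/// CPU mirror of `position + displacement * (1.0 - is_pinned)`.
fn apply_displacement(
    positions: &[[f32; 3]],
    is_pinned: &[f32],
    displacement: [f32; 3],
) -> Vec<[f32; 3]> {
    positions
        .iter()
        .zip(is_pinned)
        .map(|(p, &pinned)| {
            // No branch: pinned vertices scale the displacement by 0.0.
            let k = 1.0 - pinned;
            [
                p[0] + displacement[0] * k,
                p[1] + displacement[1] * k,
                p[2] + displacement[2] * k,
            ]
        })
        .collect()
}
```

Every vertex runs the same multiply-add. The mask data, not the control flow, decides which vertices move.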

When You Absolutely Must Branch: Keep It Coherent

If a branch is unavoidable, the goal is to ensure it is as coherent as possible, meaning that large, contiguous groups of vertices are likely to take the same path.

  1. Branch Early: Perform your branch check as early as possible in the shader to skip the maximum amount of expensive work.

  2. Branch on Uniform Data: As discussed, branching on a uniform is always perfectly coherent and fast.

  3. Branch on Slowly-Changing Data: Branching on values like distance_to_camera is generally acceptable. While there will be divergence right at the threshold, the vast majority of vertices will be either clearly near or clearly far, resulting in highly coherent wavefronts.

This is the principle behind Level of Detail (LOD) systems.

✓ OK: A Coherent Branch for LOD

// A mostly coherent branch that saves a huge amount of work.
if distance_to_camera > material.lod_distance {
    // Simple, cheap path for distant vertices.
    // Skip all the expensive calculations below.
    return simple_vertex_transform(position);
}

// Full-quality, expensive path for nearby vertices.
let wave = calculate_complex_wave(position, time); // Expensive
let foam = calculate_procedural_foam(uv);         // Also expensive
// ...

This is a good trade-off. We accept a small amount of divergence at the LOD boundary in exchange for saving a massive amount of computation on the thousands of vertices that are far away.

Branchless Patterns Cheat Sheet

Use this table as a quick reference for converting common if statements into faster, branchless equivalents. Many of the logical patterns assume you are working with float values of 0.0 for false and 1.0 for true, a common convention in shader programming.

| Instead of... | Use... |
|---|---|
| if (x > threshold) { a } else { b } | let result = mix(b, a, step(threshold, x)); |
| if (x < 0.0) { 0.0 } else { x } | let result = max(x, 0.0); |
| if (x > 1.0) { 1.0 } else if ... { x } | let result = clamp(x, 0.0, 1.0); |
| if (my_bool) { value } else { 0.0 } | let result = value * f32(my_bool); |
| if (a > 0.5 && b > 0.5) { value } | let result = value * step(0.5, a) * step(0.5, b); |
| if (a > 0.5 \|\| b > 0.5) { value } | let result = value * max(step(0.5, a), step(0.5, b)); |
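The logical patterns can be checked in plain Rust. This sketch uses hypothetical helper names. `step` mirrors WGSL's definition, so it treats the edge itself as true: a value of exactly 0.5 counts as "above", which differs slightly from a strict `>` comparison. AND becomes a product of masks, OR becomes their maximum, and NOT becomes 1.0 minus the mask.

```rust
/// Mirrors WGSL `step(edge, x)`: 0.0 if x < edge, else 1.0.
fn step(edge: f32, x: f32) -> f32 {
    if x < edge {
        0.0
    } else {
        1.0
    }
}

/// AND: multiply the 0.0/1.0 masks (1.0 only if both are 1.0).
fn and_mask(a: f32, b: f32) -> f32 {
    step(0.5, a) * step(0.5, b)
}

/// OR: take the larger mask (1.0 if either is 1.0).
fn or_mask(a: f32, b: f32) -> f32 {
    step(0.5, a).max(step(0.5, b))
}

/// NOT: flip a 0.0/1.0 mask.
fn not_mask(m: f32) -> f32 {
    1.0 - m
}
```

Multiplying a value by one of these masks keeps it or zeroes it without any branch.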

Optimization Strategy 3: Use Built-in Functions

GPU vendors invest enormous engineering effort in their shader compilers and hardware. In Bevy, your WGSL is translated by wgpu's shader translator (naga) into the platform's native shading language. The vendor's compiler then maps each built-in function to the fastest implementation it knows for that GPU. Our third strategy is simple but powerful: trust their work and use it.

Whenever you are tempted to write a common mathematical function yourself (like linear interpolation or vector normalization), stop and check if a built-in function already exists. A custom implementation is rarely faster and is often slower, because the compiler may not recognize your hand-written version as an operation it knows how to accelerate.

Why Built-ins Are Faster

Built-in functions aren't just convenient wrappers. They often map directly to:

  • Specialized Hardware Instructions: Many built-ins map to one or a few native instructions. For example, inverseSqrt() is typically a single special-function instruction, and dot() and mix() reduce to a short chain of fused multiply-adds.

  • Optimized Microcode: For more complex functions, the driver uses a highly optimized sequence of low-level instructions that are tuned for the specific GPU architecture you are running on.

  • Numerical Stability: Driver implementations are carefully designed to handle edge cases and avoid precision errors that can easily creep into custom code.

Common Built-ins to Favor

Here is a non-exhaustive list of essential, highly-optimized functions available in WGSL. You should always prefer these over manual implementations.

Vector Operations

length(v)
distance(a, b)
normalize(v)
dot(a, b)
cross(a, b)

Scalar Math & Clamping

min(a, b)
max(a, b)
clamp(x, low, high)
abs(x)
sign(x)

Interpolation & Stepping

mix(a, b, t)          // Linear interpolation (lerp)
smoothstep(e0, e1, x) // Smooth Hermite interpolation
step(edge, x)         // 0.0 if x < edge, else 1.0

Exponentials & Trigonometry

exp(x), exp2(x)
log(x), log2(x)
pow(x, y)
sqrt(x)
inverseSqrt(x)       // 1.0 / sqrt(x)
sin(x), cos(x), tan(x)

Example 1: The Power of inverseSqrt and normalize

As promised, let's explore inverseSqrt. Its primary purpose is to accelerate vector normalization, one of the most common operations in 3D graphics. The goal of normalization is to make a vector's length equal to 1. The formula is v / length(v).

✗ BAD: Manual Normalization

// This involves a dot product, an expensive `sqrt`, and an expensive vector division.
let len = sqrt(dot(v, v));
let normalized = v / len;

✓ GOOD: Using inverseSqrt to Avoid Division

We can rewrite the formula to replace the slow division with a fast multiplication: v / len is the same as v * (1.0 / len). The term 1.0 / sqrt(dot(v, v)) is exactly what inverseSqrt calculates, and GPUs typically provide it as a single, highly optimized hardware instruction.

// This is better. We've replaced a sqrt and a division with a single, fast
// inverseSqrt and a fast multiplication.
let inv_len = inverseSqrt(dot(v, v));
let normalized = v * inv_len;

✨ BEST: Using normalize

The best approach is to use the highest-level function that describes your intent. The normalize() function is the clearest and gives the driver the most freedom to use the absolute fastest method available on the hardware, which might be even more optimized than a manual inverseSqrt.

// Ideal. This is readable and concise, and it gives the compiler the best chance to pick the fastest method.
let normalized = normalize(v);
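On the CPU you can at least check that the two manual forms agree. This plain-Rust sketch mirrors the division form and the reciprocal-square-root form. Rust's standard library has no rsqrt, so `normalize_inv_sqrt` computes `1.0 / sqrt` once and multiplies. It illustrates the algebra, not GPU timing. The function names are hypothetical.

```rust
fn dot(a: [f32; 3], b: [f32; 3]) -> f32 {
    a[0] * b[0] + a[1] * b[1] + a[2] * b[2]
}

/// Manual version: one sqrt, then three divisions.
fn normalize_div(v: [f32; 3]) -> [f32; 3] {
    let len = dot(v, v).sqrt();
    [v[0] / len, v[1] / len, v[2] / len]
}

/// inverseSqrt-style version: one reciprocal square root, then three
/// multiplications.
fn normalize_inv_sqrt(v: [f32; 3]) -> [f32; 3] {
    let inv_len = 1.0 / dot(v, v).sqrt();
    [v[0] * inv_len, v[1] * inv_len, v[2] * inv_len]
}
```

Both Rust versions produce NaN for a zero-length vector, and normalizing a zero vector gives no useful answer on the GPU either. If your input can be zero, guard against that case.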

Example 2: Vector Distance

✗ BAD: Manual Implementation

// Manual, verbose, and misses potential optimizations.
let d = b - a;
let dist = sqrt(d.x*d.x + d.y*d.y + d.z*d.z);

✓ GOOD: Built-in Function

// The driver can optimize this much more effectively.
let dist = distance(a, b);

Example 3: Color Interpolation

✗ BAD: Manual Linear Interpolation (lerp)

// This is the definition of lerp, but `mix` states the intent directly.
let result = color_a + (color_b - color_a) * t;

✓ GOOD: Built-in mix

// `mix` is the WGSL name for lerp.
let result = mix(color_a, color_b, t);

Built-in Patterns Cheat Sheet

| For this task... | Use this function... | Instead of this... |
|---|---|---|
| Vector normalization | let n = normalize(v); | v / length(v) |
| Distance between points | let d = distance(a, b); | length(b - a) |
| Clamping a value | let c = clamp(x, 0.0, 1.0); | min(max(x, 0.0), 1.0) |
| Linear interpolation | let lerp = mix(a, b, t); | a + (b - a) * t |
| Smooth interpolation | let s = smoothstep(0.0, 1.0, x); | A custom x*x*(3.0-2.0*x) curve |
| Absolute value | let abs_val = abs(x); | if (x < 0.0) { -x } else { x } |
| Sign of a value | let s = sign(x); | if (x < 0.0) { -1.0 } else if (x > 0.0) { 1.0 } else { 0.0 } |
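Two of these rows hide small behavioral differences, and a plain-Rust mirror makes them easy to see. WGSL's smoothstep clamps its input to the [0, 1] range before applying the curve, while the hand-written curve does not. WGSL's sign also returns 0.0 for zero. The function names below are hypothetical CPU stand-ins for the WGSL built-ins.

```rust
/// Mirrors WGSL smoothstep(low, high, x): clamp first, then the Hermite curve.
fn smoothstep(low: f32, high: f32, x: f32) -> f32 {
    let t = ((x - low) / (high - low)).clamp(0.0, 1.0);
    t * t * (3.0 - 2.0 * t)
}

/// The hand-written curve from the table, without the clamp.
fn custom_curve(x: f32) -> f32 {
    x * x * (3.0 - 2.0 * x)
}

/// Mirrors WGSL sign(x), which returns 0.0 for zero.
fn wgsl_sign(x: f32) -> f32 {
    if x > 0.0 {
        1.0
    } else if x < 0.0 {
        -1.0
    } else {
        0.0
    }
}
```

Inside [0, 1] the two curves agree. Outside it, the unclamped curve overshoots: at x = 2.0 it returns -4.0, while smoothstep returns 1.0. Also note that Rust's own `f32::signum` returns 1.0 for +0.0, so it is not a drop-in stand-in for WGSL's sign when you port shader math to the CPU.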

Optimization Strategy 4: Memory Access Patterns

Modern GPUs are mathematical beasts. They can perform trillions of floating-point operations per second. In many cases, the bottleneck in a shader is not the math (being ALU-bound), but the time it takes to fetch data from memory (being memory-bound). Every texture sample, every uniform read, requires a trip to memory. Our fourth strategy is to make those trips as short and infrequent as possible.

Understanding the Memory Hierarchy

Not all memory on a GPU is created equal. There is a hierarchy of memory types, each with a trade-off between speed and size.

Fastest & Smallest   │ Registers          → Extremely fast, on-chip. Holds local variables.
                     ├────────────────────
                     │ L1 / Texture Cache → Very fast, small. Holds recently accessed data.
                     ├────────────────────
                     │ L2 Cache           → Fast, larger. A shared cache for more data.
                     ├────────────────────
Slowest & Largest    │ VRAM (Global)      → Slow. Holds all your textures and buffers.

Think of it like a workshop:

  • Registers are the tools in your hand. Access is instant.

  • Caches are the tools on your workbench. Quick to grab.

  • VRAM is the supply closet down the hall. Every trip costs you time.

Your goal as a shader programmer is to keep the data you need in your hand or on the workbench, minimizing trips to the supply closet. You do this by following two key principles of memory access.

The Two Rules of Efficient Memory Access

  1. Temporal Locality (Reuse What You Fetch): If you take the time to fetch an item from memory, keep it in a local variable for as long as you need it. Don't go back to memory for the same piece of data multiple times. This ensures the data stays in the fastest registers.

  2. Spatial Locality (Access Nearby Data): When you access memory (especially textures), the GPU is smart. It fetches not just the single piece of data you asked for, but also a small block of its neighbors, storing them in the cache. If your next memory access is for one of those neighbors, it will be an incredibly fast "cache hit."

Texture Sampling: The Most Expensive Memory Access

A texture sample is often the single most expensive operation in a shader because it involves a complex memory lookup. Minimizing texture fetches is one of the biggest performance wins you can achieve.

✓ Rule 1: Sample Once, Use All Channels

Never sample the same texture multiple times with the same UV coordinates. Fetch it once into a local vec4 variable and reuse that variable. In a vertex shader, WGSL requires textureSampleLevel with an explicit mip level; plain textureSample is only allowed in fragment shaders, because it relies on screen-space derivatives.

@vertex
fn vertex(in: VertexInput) -> VertexOutput {
    // ✗ BAD: Three separate, expensive trips to memory for the same location.
    let noise1 = textureSampleLevel(noise_tex, samp, in.uv, 0.0).r;
    let noise2 = textureSampleLevel(noise_tex, samp, in.uv, 0.0).g;
    let noise3 = textureSampleLevel(noise_tex, samp, in.uv, 0.0).b;

    // ... use noise1, noise2, noise3 ...
}

@vertex
fn vertex(in: VertexInput) -> VertexOutput {
    // ✓ GOOD: One trip to memory. The result is stored in a fast register.
    let noise_sample = textureSampleLevel(noise_tex, samp, in.uv, 0.0);
    let noise1 = noise_sample.r;
    let noise2 = noise_sample.g;
    let noise3 = noise_sample.b;

    // Same result with one fetch instead of three, without relying on
    // the compiler to merge the duplicate samples.
}

✓ Rule 2: Pack Your Data

Following from the first rule, you can drastically reduce your total sample count by packing multiple grayscale masks or data values into the R, G, B, and A channels of a single texture. This is a standard technique in game development.

Instead of using three separate textures:

  • Texture 1 (R): Height displacement

  • Texture 2 (R): Roughness value

  • Texture 3 (R): Ambient occlusion mask

Combine them into one:

  • Combined_Texture.r: Height displacement

  • Combined_Texture.g: Roughness value

  • Combined_Texture.b: Ambient occlusion mask

  • Combined_Texture.a: (another value, like metallic)

This technique reduces three expensive texture fetches down to just one.
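Channel packing usually happens at asset-build time. Here is a minimal plain-Rust sketch, assuming each mask is an 8-bit grayscale image with the same resolution. `pack_rgba` is a hypothetical helper that interleaves the masks into RGBA bytes. Data like roughness and AO should be uploaded in a linear format such as `Rgba8Unorm` rather than an sRGB format, since these values are not colors.

```rust
/// Interleaves four single-channel masks into one RGBA8 buffer:
/// R = height, G = roughness, B = ambient occlusion, A = metallic.
/// Panics if the masks have different texel counts.
fn pack_rgba(height: &[u8], roughness: &[u8], ao: &[u8], metallic: &[u8]) -> Vec<u8> {
    let n = height.len();
    assert!(
        roughness.len() == n && ao.len() == n && metallic.len() == n,
        "all masks must have the same number of texels"
    );
    let mut packed = Vec::with_capacity(n * 4);
    for (((&h, &r), &a), &m) in height.iter().zip(roughness).zip(ao).zip(metallic) {
        // One texel: R, G, B, A.
        packed.extend_from_slice(&[h, r, a, m]);
    }
    packed
}
```

In the shader, one textureSampleLevel call on the packed texture then returns all four values in a single vec4.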

Uniform and Storage Buffer Access

The principle of temporal locality applies equally to uniform and storage buffers. This data is usually cached, but you can still make the reuse explicit.

✓ Rule 3: Cache Buffer Values in Local Variables

When you access a uniform like material.time multiple times, each access is, in principle, a read from the uniform buffer. Most shader compilers will merge these repeated reads on their own. Reading the value once into a local variable makes the reuse explicit, so it no longer depends on the compiler spotting it.

@vertex
fn vertex(in: VertexInput) -> VertexOutput {
    // ✗ BAD: Each use of `material.time` is a potential memory access.
    let wave1 = sin(in.position.x + material.time);
    let wave2 = cos(in.position.y + material.time);
    let wave3 = sin(in.position.z * material.time);
    // ...
}

@vertex
fn vertex(in: VertexInput) -> VertexOutput {
    // ✓ GOOD: Read from the uniform buffer ONCE.
    let time = material.time;
    // `time` now lives in a local variable for the rest of the shader.
    let wave1 = sin(in.position.x + time);
    let wave2 = cos(in.position.y + time);
    let wave3 = sin(in.position.z * time);
    // ...
}

This pattern is especially important for storage buffers in instanced rendering. Access the instance data array once, store the result in a local struct, and then use the fields from that local struct.

Memory Access Checklist

  • Minimize texture fetches. They are your most expensive memory operation.

  • Pack data into RGBA channels to reduce the total number of textures and samples.

  • Sample once, reuse often. Store texture samples in a local vec4 and use its components.

  • Cache uniform and storage buffer values in local variables at the top of your shader.

  • Prioritize local variables (registers) for all frequently used data to avoid round trips to slower memory.

Optimization Strategy 5: Reduce Vertex Shader Complexity

The optimizations we've covered so far - precomputation, branchless math, built-ins, and efficient memory access - are about making your existing calculations run faster. Our final strategy is simpler and often more impactful: just do less work.

A complex vertex shader with multiple waves, noise functions, and procedural animations might look fantastic up close, but that detail is completely wasted on an object that is a few pixels wide on the horizon. By strategically reducing or eliminating work for distant or insignificant vertices, you can free up enormous amounts of GPU time to be spent where it matters.

Technique 1: Level of Detail (LOD)

Level of Detail is the most powerful technique in this category. The concept is simple: check how far a vertex is from the camera, and run a cheaper, lower-quality version of your shader for it if it's far away.

This is a prime example of an "acceptable branch." While the if/else will cause some divergence for the few wavefronts at the LOD boundaries, the vast majority of wavefronts will be coherently near or far. The performance saved by skipping expensive calculations for the thousands of distant vertices far outweighs the cost of the branch itself.

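#import bevy_pbr::mesh_functions
#import bevy_pbr::mesh_view_bindings::view

@vertex
fn vertex(in: VertexInput) -> VertexOutput {
    let model = mesh_functions::get_world_from_local(in.instance_index);
    let world_pos = mesh_functions::mesh_position_local_to_world(
        model,
        vec4<f32>(in.position, 1.0)
    ).xyz;
    // `view.world_position` is the camera's position in world space.
    let distance_to_camera = distance(world_pos, view.world_position);

    var displacement: vec3<f32>;

    // Check distance and choose a code path.
    // calculate_complex_waves_with_noise and calculate_simple_sine_wave are
    // placeholders for your own effect functions.
    if distance_to_camera < 20.0 {
        // CLOSE: Full quality. Use multiple sine waves and noise.
        displacement = calculate_complex_waves_with_noise(in.position, 4u);
    } else if distance_to_camera < 50.0 {
        // MEDIUM: Reduced quality. Just a single sine wave.
        displacement = calculate_simple_sine_wave(in.position);
    } else {
        // FAR: Minimal quality. No displacement at all.
        displacement = vec3<f32>(0.0);
    }

    // Apply displacement and continue...
}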

This technique ensures that your GPU's budget is spent rendering beautiful detail on the objects the player is actually looking at, not wasting cycles on distant scenery.
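Because tier boundaries are easy to get wrong, it can help to mirror the selection logic on the CPU, for example to sanity-check thresholds or to estimate what share of a mesh lands in each tier. Below is a minimal plain-Rust sketch, assuming the same 20/50 thresholds as the shader above. The grid and camera position are hypothetical.

```rust
/// LOD tiers mirroring the WGSL example.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Lod {
    Near,   // full waves + noise
    Medium, // single sine wave
    Far,    // no displacement
}

// Same thresholds as the shader; keep the two in sync.
const LOD_NEAR: f32 = 20.0;
const LOD_FAR: f32 = 50.0;

fn distance(a: [f32; 3], b: [f32; 3]) -> f32 {
    let (dx, dy, dz) = (a[0] - b[0], a[1] - b[1], a[2] - b[2]);
    (dx * dx + dy * dy + dz * dz).sqrt()
}

/// The same decision the vertex shader makes for each vertex.
fn select_lod(world_pos: [f32; 3], camera_pos: [f32; 3]) -> Lod {
    let d = distance(world_pos, camera_pos);
    if d < LOD_NEAR {
        Lod::Near
    } else if d < LOD_FAR {
        Lod::Medium
    } else {
        Lod::Far
    }
}

fn main() {
    // Hypothetical 100 x 100 grid of vertices, 1 unit apart, camera above one corner.
    let camera = [0.0, 5.0, 0.0];
    let mut counts = [0usize; 3]; // [near, medium, far]
    for x in 0..100 {
        for z in 0..100 {
            let tier = select_lod([x as f32, 0.0, z as f32], camera);
            counts[tier as usize] += 1;
        }
    }
    println!("near: {}, medium: {}, far: {}", counts[0], counts[1], counts[2]);
}
```

The counts give a rough idea of how many vertices actually pay for the expensive path. For a camera hovering over a large plane, most vertices usually fall into the cheaper tiers.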

Technique 2: Back-Face Optimization

For any solid, opaque object (like a character model or a rock), any vertex on a face pointing away from the camera will never be visible. We can detect this and skip expensive calculations for those vertices.

To do this, we need to know which way the vertex's normal is pointing and which way the camera is looking.

  • The view direction is the vector from the vertex's position to the camera's position.

  • The dot() product of the vertex normal and the view direction tells us if they are pointing in roughly the same direction.

    • If dot(normal, view_dir) > 0, the face is pointing towards the camera (front-facing).

    • If dot(normal, view_dir) < 0, the face is pointing away from the camera (back-facing).

We can use this as an early exit for expensive work.

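#import bevy_pbr::mesh_functions
#import bevy_pbr::mesh_view_bindings::view

@vertex
fn vertex(in: VertexInput) -> VertexOutput {
    let model = mesh_functions::get_world_from_local(in.instance_index);
    let world_pos = mesh_functions::mesh_position_local_to_world(
        model,
        vec4<f32>(in.position, 1.0)
    ).xyz;
    let world_normal = mesh_functions::mesh_normal_local_to_world(in.normal, in.instance_index);
    let view_dir = normalize(view.world_position - world_pos);

    // Is this vertex on the back of the object?
    if dot(world_normal, view_dir) < 0.0 {
        // Yes, it's hidden. Skip all expensive displacement.
        // simple_output is a placeholder that builds an undisplaced VertexOutput.
        return simple_output(world_pos);
    }

    // No, it's visible. Proceed with the full, expensive calculation.
    // calculate_expensive_displacement is likewise a placeholder.
    let displacement = calculate_expensive_displacement(world_pos);
    // ...
}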

Critical Caveat: This optimization is only valid for closed, opaque meshes. Never use it on:

  • Transparent objects: You need to see the back faces through the front.

  • Single-sided planes: Grass, leaves, cloth, paper. The back face is the only other side.

  • Objects with holes: You might see the "back" of an interior surface through a hole.

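Use this technique carefully, but for the right meshes it can skip the expensive work for roughly half of the vertices. Watch the silhouette, too: triangles that straddle the front/back boundary end up with some displaced and some undisplaced vertices, which can cause visible popping if the displacement is large. Comparing against a small negative threshold instead of 0.0 gives some margin.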
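To get a feel for how much work the early exit can skip, the same dot-product test can be run on the CPU over points on a sphere. This is a plain-Rust sketch: f64 arrays stand in for vec3<f32>, and the sphere and camera positions are hypothetical. It suggests that "half" is a good rule of thumb. With a very distant camera, half of a sphere's vertices face away. With the camera close by, slightly more do, because less of the sphere is visible.

```rust
type Vec3 = [f64; 3];

fn dot(a: Vec3, b: Vec3) -> f64 {
    a[0] * b[0] + a[1] * b[1] + a[2] * b[2]
}

fn sub(a: Vec3, b: Vec3) -> Vec3 {
    [a[0] - b[0], a[1] - b[1], a[2] - b[2]]
}

fn normalize(v: Vec3) -> Vec3 {
    let len = dot(v, v).sqrt();
    [v[0] / len, v[1] / len, v[2] / len]
}

/// The same test as the shader: a negative dot(normal, view_dir) means back-facing.
fn is_back_facing(pos: Vec3, normal: Vec3, camera: Vec3) -> bool {
    let view_dir = normalize(sub(camera, pos));
    dot(normal, view_dir) < 0.0
}

/// Evenly spread points on a unit sphere (Fibonacci lattice).
/// For a unit sphere centred at the origin, the outward normal equals the position.
fn fibonacci_sphere(n: usize) -> Vec<Vec3> {
    let golden_angle = std::f64::consts::PI * (3.0 - 5f64.sqrt());
    (0..n)
        .map(|i| {
            let y = 1.0 - 2.0 * (i as f64 + 0.5) / n as f64;
            let r = (1.0 - y * y).sqrt();
            let theta = golden_angle * i as f64;
            [r * theta.cos(), y, r * theta.sin()]
        })
        .collect()
}

/// Fraction of sphere vertices that would take the early exit.
fn back_facing_fraction(n: usize, camera: Vec3) -> f64 {
    let back = fibonacci_sphere(n)
        .into_iter()
        .filter(|&p| is_back_facing(p, p, camera))
        .count();
    back as f64 / n as f64
}

fn main() {
    // Camera 10 radii away vs. effectively infinitely far away.
    println!("near camera: {:.3}", back_facing_fraction(1000, [0.0, 10.0, 0.0]));
    println!("far camera:  {:.3}", back_facing_fraction(1000, [0.0, 1.0e6, 0.0]));
}
```

As with LOD, the GPU only saves time when an entire wavefront takes the early exit, so the real-world gain depends on how coherently front- and back-facing vertices are grouped in the mesh.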

Technique 3: Simplify Your Math

Sometimes, a visually identical or "good enough" effect can be achieved with a much cheaper mathematical formula. Always look for opportunities to simplify.

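  • Powers: pow(x, 2.0) can be expensive (unless the compiler rewrites it for you), while x * x is cheap and gives the same result for x >= 0. It is also safer: GPUs commonly evaluate pow(x, y) as exp2(y * log2(x)), which is undefined for negative x, whereas x * x works for any x. Likewise, pow(x, 3.0) can be x * x * x.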

  • Approximations: Do you need a perfectly circular falloff, or would a cheaper, linear falloff look just as good in motion? Do you need a high-quality sin wave, or can you use a simpler triangular wave pattern for a background effect?

  • Reduce Frequency: For effects like noise, often you can use a lower-frequency (larger scale) noise pattern that requires fewer calculations (fewer octaves) with little to no perceptible difference from a distance.
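Both the pow caveat and the triangle-wave idea can be checked numerically. This is a small Rust sketch, not shader code. gpu_style_pow models the exp2(y * log2(x)) formulation that GPUs commonly use for pow; whether your hardware and compiler do exactly this is an assumption. triangle_wave is one possible cheap stand-in for sin, not a built-in.

```rust
use std::f32::consts::PI;

/// Models the common GPU formulation of pow(x, y) as exp2(y * log2(x)).
/// log2 of a negative number is NaN, so this breaks for x < 0.
fn gpu_style_pow(x: f32, y: f32) -> f32 {
    (y * x.log2()).exp2()
}

/// Cheap triangle wave with the same period (2π), peaks, and zero crossings as sin(x).
/// Just a few arithmetic ops plus floor and abs: no transcendental function.
fn triangle_wave(x: f32) -> f32 {
    let t = x / (2.0 * PI) + 0.25;
    let f = t - t.floor(); // equivalent to WGSL's fract(t)
    1.0 - 4.0 * (f - 0.5).abs()
}

fn main() {
    let x = -2.0_f32;
    println!("x * x = {}, gpu_style_pow(x, 2.0) = {}", x * x, gpu_style_pow(x, 2.0));

    // Largest gap between the triangle wave and a true sine over two periods.
    let max_err = (0..=1000)
        .map(|i| {
            let x = -2.0 * PI + 4.0 * PI * i as f32 / 1000.0;
            (triangle_wave(x) - x.sin()).abs()
        })
        .fold(0.0_f32, f32::max);
    println!("max |triangle - sin| = {max_err:.3}");
}
```

The triangle wave stays within about 0.21 of the true sine, which is often hard to notice in background motion. In WGSL, the same function is a couple of lines using fract and abs.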

Always start with the simplest effect that achieves your goal, and only add mathematical complexity if it provides a clear visual benefit.

Profiling and Measuring Performance

There is a golden rule in all software optimization: you cannot optimize what you cannot measure.

It's easy to fall into the trap of "premature optimization" - guessing where the slow parts of your code are and making changes based on intuition. This often leads to wasted time, more complex code, and minimal performance gains. A professional workflow is always driven by data. You must first profile your application to find the actual bottleneck before you try to fix it.

Step 1: Is the CPU or GPU the Bottleneck?

Your application's frame rate is limited by whichever processor finishes its work last.

  • If the CPU takes 20ms to prepare a frame and the GPU only takes 5ms to render it, your frame time will be 20ms (~50 FPS). You are CPU-bound. Optimizing your shader will have zero effect on your frame rate.

  • If the CPU takes 5ms and the GPU takes 20ms, your frame time will also be 20ms. You are GPU-bound. In this case, optimizing your shader is critical.
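The rule above can be written as a tiny model. It is a simplification that assumes the CPU and GPU overlap perfectly (real frames also include synchronization and presentation waits), but it makes the consequence concrete: speeding up the shader only helps when the GPU is the slower side.

```rust
/// Simplified model: the CPU prepares the next frame while the GPU renders
/// the current one, so whichever side is slower sets the frame time.
fn frame_time_ms(cpu_ms: f32, gpu_ms: f32) -> f32 {
    cpu_ms.max(gpu_ms)
}

fn fps(frame_ms: f32) -> f32 {
    1000.0 / frame_ms
}

#[derive(Debug, PartialEq)]
enum Bottleneck {
    Cpu,
    Gpu,
}

fn bottleneck(cpu_ms: f32, gpu_ms: f32) -> Bottleneck {
    if cpu_ms >= gpu_ms { Bottleneck::Cpu } else { Bottleneck::Gpu }
}

fn main() {
    // CPU-bound: halving GPU time (e.g. a faster shader) changes nothing.
    println!("{} ms", frame_time_ms(20.0, 5.0)); // 20 ms
    println!("{} ms", frame_time_ms(20.0, 2.5)); // still 20 ms
    // GPU-bound: the same shader speed-up pays off directly.
    println!("{} ms", frame_time_ms(5.0, 20.0)); // 20 ms
    println!("{} ms", frame_time_ms(5.0, 10.0)); // 10 ms
    println!("{:?} at {} FPS", bottleneck(20.0, 5.0), fps(frame_time_ms(20.0, 5.0)));
}
```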

Bevy's built-in diagnostics are a good place to start this investigation.

Enabling Bevy's Frame Time Diagnostics

Add the following plugins to your Bevy App:

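use bevy::diagnostic::{FrameTimeDiagnosticsPlugin, LogDiagnosticsPlugin};
use bevy::prelude::*;

fn main() {
    App::new()
        .add_plugins(DefaultPlugins)
        // Records the fps, frame_time, and frame_count diagnostics.
        .add_plugins(FrameTimeDiagnosticsPlugin::default())
        // Periodically prints every recorded diagnostic to the console.
        .add_plugins(LogDiagnosticsPlugin::default())
        // ... your other plugins and systems
        .run();
}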

When you run your application from the command line, Bevy will now periodically log fps and frame_time (in milliseconds). Keep in mind that frame_time covers the whole frame, so on its own it does not tell you which processor is the bottleneck. To split the two, compare it with how long the CPU side of each frame takes, which a CPU profiler can show you (Bevy supports Tracy through its trace_tracy cargo feature):

  • If CPU time per frame is high (e.g., > 16.6ms when targeting 60 FPS): Your bottleneck is likely on the CPU. This is often caused by having too many draw calls (which instancing solves), running too many complex systems, or spawning/despawning too many entities per frame.

  • If CPU time per frame is low (e.g., < 5ms) but frame_time is still high: Your bottleneck is almost certainly on the GPU. This means your shaders are too complex, you have too much geometry, or you're rendering too many pixels (fill-rate limited).

Step 2: Profiling the GPU

Once you've confirmed you are GPU-bound, you need to dig deeper. Specialized GPU profiling tools are essential for this. They can capture a single frame of your application and give you a detailed breakdown of exactly how long each draw call and shader took to execute.

  • RenderDoc (Windows & Linux): The industry standard for graphics debugging. It allows you to inspect every stage of the pipeline and provides detailed GPU timings for each event.

  • Xcode Instruments (macOS): Provides excellent GPU profiling tools for Metal, allowing you to see shader execution times and identify bottlenecks.

  • PIX (Windows): Microsoft's dedicated performance tuning and debugging tool for Direct3D 12. It can only capture your Bevy app when the app is running on the DX12 backend.

The process generally involves:

  1. Launching your Bevy application through the profiler.

  2. Pressing a key to "capture" a single, representative frame.

  3. Analyzing the captured frame in the tool's UI to see a timeline of GPU work. You can find your expensive draw call and see precisely how many milliseconds were spent in its vertex and fragment shaders.

Step 3: An Iterative Optimization Workflow

With these tools, you can adopt a professional, data-driven workflow:

  1. Establish a Baseline: Run your scene and record the performance metrics. What is the current FPS? What does the GPU profiler say your vertex shader time is?

  2. Form a Hypothesis: Based on the principles in this article, identify a potential optimization. For example, "I believe the 6-octave noise function in my vertex shader is the main bottleneck."

  3. Implement One Change: Apply a single optimization. For example, add a LOD system to reduce the noise to 1 octave for distant vertices.

  4. Measure Again: Run the profiler again under the exact same conditions. Did the frame time improve? Did the vertex shader execution time decrease as expected?

  5. Verify: If the performance improved, you've confirmed your hypothesis. If not, revert the change and go back to step 2 with a new hypothesis.

This iterative cycle of measure → hypothesize → change → measure ensures that you are always making meaningful, data-backed improvements to your code.
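When recording the baseline and the re-measurement, it helps to compare frame times in milliseconds rather than FPS. FPS is the reciprocal of frame time, so the same FPS gain means very different amounts of saved work depending on where you start. A small sketch with hypothetical helper names:

```rust
/// Converts a frame rate into a per-frame time in milliseconds.
fn ms_from_fps(fps: f64) -> f64 {
    1000.0 / fps
}

/// Percentage of frame time saved by a change.
fn percent_saved(before_ms: f64, after_ms: f64) -> f64 {
    (before_ms - after_ms) / before_ms * 100.0
}

fn main() {
    // The same "+30 FPS" can mean very different amounts of saved GPU time.
    for (before, after) in [(30.0, 60.0), (60.0, 90.0)] {
        let (b, a) = (ms_from_fps(before), ms_from_fps(after));
        println!(
            "{before} -> {after} FPS: {b:.2} ms -> {a:.2} ms (saved {:.2} ms, {:.0}%)",
            b - a,
            percent_saved(b, a)
        );
    }
}
```

Going from 30 to 60 FPS saves about 16.7 ms per frame, while 60 to 90 FPS saves only about 5.6 ms, even though both look like "+30 FPS".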

Manual Debugging and Visualization

Sometimes a full profiling tool is overkill. You can often get a good sense of what your shader is doing by outputting debug information as colors.

Example: Visualizing LOD Distance

You can visualize which code path your LOD system is taking by passing the distance to the fragment shader and outputting it as a color.

// In the VertexOutput struct (shared by the vertex and fragment stages):
@location(4) distance_to_camera: f32,

// In the vertex shader, after computing the distance:
out.distance_to_camera = distance_to_camera;

// In the fragment shader: color each fragment by the LOD tier it falls into.
// Keep these thresholds in sync with the ones used in the vertex shader.
var debug_color: vec3<f32>;
if in.distance_to_camera < 20.0 {
    debug_color = vec3(1.0, 0.0, 0.0); // Red = High Detail
} else if in.distance_to_camera < 50.0 {
    debug_color = vec3(0.0, 1.0, 0.0); // Green = Medium Detail
} else {
    debug_color = vec3(0.0, 0.0, 1.0); // Blue = Low Detail
}
return vec4<f32>(debug_color, 1.0);

This will instantly show you whether your LOD thresholds are set correctly and behaving as you expect, turning your scene into a color-coded map of your LOD tiers.


Complete Example: Optimizing a Complex Vertex Shader

Theory is essential, but seeing optimization in action provides the crucial "aha!" moment. We will now apply every principle we've learned in a practical, step-by-step optimization of a complex ocean shader.

Our Goal

Our starting point is an unoptimized ocean shader that renders a large, detailed water plane with 40,000 vertices. It uses a multi-octave sine wave function for realistic ripples and a noise texture to generate sea foam. While it looks nice, its performance is poor, running at around 30 FPS on a typical GPU.

Our goal is to substantially improve its performance without visibly sacrificing quality in the parts of the ocean closest to the camera. We will do this by creating a second, optimized version of the shader and a Bevy application that lets us toggle between the two in real time to see the performance difference.

What This Project Demonstrates

  • Profiling and Baseline: How to establish a clear performance baseline.

  • Applying Optimizations: A step-by-step application of our optimization strategies.

    • Reducing mathematical complexity (fewer wave octaves).

    • Replacing branches with branchless math (mix, smoothstep).

    • Optimizing memory access (single texture sample).

    • Implementing a Level of Detail (LOD) system to do less work.

  • Verification: How to measure the concrete FPS improvement from each change.

  • Real-World Trade-offs: Understanding that optimization is about spending your performance budget intelligently, not just making code run faster.

The Unoptimized Shader: Identifying the Anti-Patterns

First, let's analyze the unoptimized shader and its associated Rust code. This is our starting point, running at ~30 FPS.

Dependency Note: Before adding the application code, you'll need to add one dependency to your project. The demo uses this to generate a noise texture on the CPU. Open your Cargo.toml file and add the following line under [dependencies]:

[dependencies]
bevy = "0.16" # Ensure this matches your Bevy version
noise = "0.9"

The Unoptimized Shader (assets/shaders/d02_08_ocean_unoptimized.wgsl)

This shader is full of common performance mistakes that we can now easily identify.

  • calculate_wave_height:

    • It uses a for loop to sum 6 directional sine waves for every single vertex, which is incredibly expensive.

    • The expensive sin function is called once per wave inside the loop, six times per vertex.

  • calculate_foam:

    • It samples the same noise texture three times with slightly different UVs, resulting in three expensive memory fetches where one would suffice.

    • It uses a chain of if/else if/else statements based on wave_height, which is per-vertex data. This will cause significant branch divergence.

    • It manually clamps the foam value with if statements instead of using the built-in clamp() function.

  • vertex function:

    • It performs the full, expensive wave and foam calculations for every vertex, regardless of its distance from the camera. There is no LOD system.

    • It calculates the distance to the camera manually instead of using the built-in distance() function.

    • It has another divergent branch to calculate detail_level based on distance.

#import bevy_pbr::mesh_functions
#import bevy_pbr::view_transformations::position_world_to_clip

struct OceanMaterial {
    time: f32,
    camera_position: vec3<f32>,
    wave_amplitude: f32,
    wave_frequency: f32,
}

@group(2) @binding(0)
var<uniform> material: OceanMaterial;

@group(2) @binding(1)
var noise_texture: texture_2d<f32>;

@group(2) @binding(2)
var noise_sampler: sampler;

struct VertexInput {
    @builtin(instance_index) instance_index: u32,
    @location(0) position: vec3<f32>,
    @location(1) normal: vec3<f32>,
    @location(2) uv: vec2<f32>,
}

struct VertexOutput {
    @builtin(position) clip_position: vec4<f32>,
    @location(0) world_position: vec3<f32>,
    @location(1) world_normal: vec3<f32>,
    @location(2) foam_amount: f32,
    @location(3) wave_height: f32,
}

// A helper for a single directional wave.
fn wave(position: vec2<f32>, direction: vec2<f32>, frequency: f32, amplitude: f32, speed: f32, time: f32) -> f32 {
    let angle = dot(direction, position);
    return sin(angle * frequency + time * speed) * amplitude;
}

// ✗ ANTI-PATTERN: Excessive Complexity. Uses a long loop for ALL vertices.
fn calculate_wave_height(pos: vec3<f32>, time: f32) -> f32 {
    var height = 0.0;
    let base_amp = material.wave_amplitude;
    let base_freq = material.wave_frequency;

    // Use a loop with 6 waves to be expensive.
    let directions = array<vec2<f32>, 6>(
        normalize(vec2(1.0, 0.5)), normalize(vec2(0.8, 1.0)),
        normalize(vec2(1.0, 1.3)), normalize(vec2(-0.2, 1.0)),
        normalize(vec2(0.7, 0.7)), normalize(vec2(1.0, -0.3))
    );
    for (var i = 0u; i < 6u; i = i + 1u) {
        // We vary the parameters inside the loop to make it look complex.
        height += wave(pos.xz, directions[i], base_freq * (1.0 + f32(i)*0.2), base_amp * (1.0 - f32(i)*0.1), 1.0 + f32(i)*0.1, time);
    }
    return height;
}

// ✗ ANTI-PATTERN: Redundant Texture Samples & Divergent Branching.
fn calculate_foam(pos: vec3<f32>, wave_height: f32) -> f32 {
    let uv1 = fract(pos.xz * 0.1);
    let uv2 = fract(pos.xz * 0.15);
    let uv3 = fract(pos.xz * 0.2);

    // ✗ ANTI-PATTERN: Three separate, expensive texture fetches.
    let noise1 = textureSampleLevel(noise_texture, noise_sampler, uv1, 0.0).r;
    let noise2 = textureSampleLevel(noise_texture, noise_sampler, uv2, 0.0).g;
    let noise3 = textureSampleLevel(noise_texture, noise_sampler, uv3, 0.0).b;

    // ✗ ANTI-PATTERN: Divergent branch on per-vertex data.
    var foam = 0.0;
    if wave_height > 0.5 {
        foam = noise1;
    } else if wave_height > 0.25 {
        foam = noise2;
    } else {
        foam = noise3;
    }

    // ✗ ANTI-PATTERN: Manual clamping instead of built-in.
    if foam > 1.0 { foam = 1.0; }
    if foam < 0.0 { foam = 0.0; }

    return foam;
}

@vertex
fn vertex(in: VertexInput) -> VertexOutput {
    var out: VertexOutput;

    let model = mesh_functions::get_world_from_local(in.instance_index);
    let world_position = mesh_functions::mesh_position_local_to_world(
        model,
        vec4<f32>(in.position, 1.0)
    ).xyz;

    // ✗ ANTI-PATTERN: No LOD system. All vertices get the full calculation.
    let wave_height = calculate_wave_height(world_position, material.time);

    var displaced_position = world_position;
    displaced_position.y += wave_height;

    let foam = calculate_foam(displaced_position, wave_height);

    // ✗ ANTI-PATTERN: Not using built-in distance function.
    let dx = displaced_position.x - material.camera_position.x;
    let dy = displaced_position.y - material.camera_position.y;
    let dz = displaced_position.z - material.camera_position.z;
    let distance_to_camera = sqrt(dx*dx + dy*dy + dz*dz);

    // ✗ ANTI-PATTERN: A second divergent branch.
    var detail_level: f32;
    if distance_to_camera < 20.0 {
        detail_level = 1.0;
    } else if distance_to_camera < 50.0 {
        detail_level = 0.5;
    } else {
        detail_level = 0.25;
    }

    let world_normal = mesh_functions::mesh_normal_local_to_world(
        in.normal,
        in.instance_index
    );

    out.clip_position = position_world_to_clip(displaced_position);
    out.world_position = displaced_position;
    out.world_normal = normalize(world_normal);
    // Scale foam by the (pointless) detail_level so the distance branch above feeds into the output.
    out.foam_amount = foam * detail_level;
    out.wave_height = wave_height;

    return out;
}

@fragment
fn fragment(in: VertexOutput) -> @location(0) vec4<f32> {
    let normal = normalize(in.world_normal);

    let light_dir = normalize(vec3<f32>(1.0, 1.0, 1.0));
    let diffuse = max(0.0, dot(normal, light_dir)) * 0.7;
    let ambient = 0.3;

    let base_color = vec3<f32>(0.1, 0.3, 0.5);
    let foam_color = vec3<f32>(0.9, 0.9, 0.95);

    let final_color = mix(base_color, foam_color, in.foam_amount);
    let lit_color = final_color * (ambient + diffuse);

    return vec4<f32>(lit_color, 1.0);
}

The Unoptimized Rust Material (src/materials/d02_08_ocean_unoptimized.rs)

This is a standard Bevy material setup. It defines the uniforms struct and the Material implementation that links to our unoptimized WGSL file.

use bevy::prelude::*;
use bevy::render::render_resource::{AsBindGroup, ShaderRef};

mod uniforms {
    #![allow(dead_code)]

    use bevy::prelude::*;
    use bevy::render::render_resource::ShaderType;

    #[derive(ShaderType, Debug, Clone, Copy)]
    pub struct OceanUnoptimizedUniforms {
        pub time: f32,
        pub camera_position: Vec3,
        pub wave_amplitude: f32,
        pub wave_frequency: f32,
    }

    impl Default for OceanUnoptimizedUniforms {
        fn default() -> Self {
            Self {
                time: 0.0,
                camera_position: Vec3::ZERO,
                wave_amplitude: 0.5,
                wave_frequency: 1.0,
            }
        }
    }
}

pub use uniforms::OceanUnoptimizedUniforms;

#[derive(Asset, TypePath, AsBindGroup, Debug, Clone)]
pub struct OceanMaterialUnoptimized {
    #[uniform(0)]
    pub uniforms: OceanUnoptimizedUniforms,

    #[texture(1)]
    #[sampler(2)]
    pub noise_texture: Handle<Image>,
}

impl Material for OceanMaterialUnoptimized {
    fn vertex_shader() -> ShaderRef {
        "shaders/d02_08_ocean_unoptimized.wgsl".into()
    }

    fn fragment_shader() -> ShaderRef {
        "shaders/d02_08_ocean_unoptimized.wgsl".into()
    }
}

Don't forget to add it to src/materials/mod.rs:

// ... other materials
pub mod d02_08_ocean_unoptimized;
// We will add the optimized version later

Now that we have our baseline, let's start improving it.

The Optimized Version: A Breakdown of the Changes

We will now create a new set of files, d02_08_ocean_optimized.wgsl and d02_08_ocean_optimized.rs, that implement the fixes for all the performance problems we identified. We will then create a single Bevy application that allows us to switch between the unoptimized and optimized materials on the fly, making the performance difference immediately obvious.

A Critical Note: Why Not Precompute sin(time)?

Our very first optimization strategy was to move constant calculations to the CPU. A sharp-eyed reader might wonder why we aren't precomputing sin(material.time) and cos(material.time) on the CPU for our wave calculations.

This is a crucial point: precomputation only works for calculations that are independent of per-vertex attributes.

Our wave formula is sin(angle * frequency + time * speed), where angle comes from the vertex position and time is uniform. The angle-addition identity, sin(a + b) = sin(a) * cos(b) + cos(a) * sin(b), does let you precompute sin(b) and cos(b) on the CPU, but the per-vertex half still needs both sin(a) and cos(a). You would trade one per-vertex sin call for a sin and a cos, and gain nothing. Simply precomputing sin(time) and using it in place of the full expression is worse: it breaks the spatial wave effect, and the entire ocean plane bobs up and down as a single flat sheet.

Therefore, for this specific effect, the expensive sin and cos calls must remain in the vertex shader. Our primary strategies will be to do less work (by calling them less often in loops and for distant vertices via LOD) and to work more efficiently with the rest of the code.
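The angle-addition identity sin(a + b) = sin(a) * cos(b) + cos(a) * sin(b) can be checked numerically. This Rust sketch (function names are hypothetical) evaluates one wave both ways. The results agree, but the split version still needs a sin and a cos of the per-vertex term, two transcendental calls where the direct form needs one.

```rust
/// Direct form, as in the shader's wave(): one sin per wave, per vertex.
fn wave_direct(angle: f32, frequency: f32, time: f32, speed: f32) -> f32 {
    (angle * frequency + time * speed).sin()
}

/// Split form via sin(a + b) = sin(a)cos(b) + cos(a)sin(b).
/// `time_sin` / `time_cos` could be precomputed once per frame on the CPU,
/// but the per-vertex term `a` still needs BOTH a sin and a cos.
fn wave_split(angle: f32, frequency: f32, time_sin: f32, time_cos: f32) -> f32 {
    let a = angle * frequency;
    a.sin() * time_cos + a.cos() * time_sin
}

fn main() {
    let (time, speed, frequency) = (3.7_f32, 1.5_f32, 1.2_f32);
    // "CPU side": computed once per frame.
    let (time_sin, time_cos) = (time * speed).sin_cos();

    for angle in [0.0_f32, 0.8, 2.5, 7.3] {
        let direct = wave_direct(angle, frequency, time, speed);
        let split = wave_split(angle, frequency, time_sin, time_cos);
        println!("angle {angle}: direct {direct:.5}, split {split:.5}");
    }
}
```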

The Optimized Shader (assets/shaders/d02_08_ocean_optimized.wgsl)

Here is the complete optimized shader. Read the comments carefully to see how each of our strategies has been applied.

  • LOD System: The main vertex function now checks distance_to_camera first, calls a cheaper version of the wave calculation for medium-distance vertices, and skips it entirely for far ones. This is the single biggest performance win.

  • Reduced Complexity: The most detailed wave function, calculate_wave_height_detailed, now sums 4 directional waves instead of 6. The _simple version sums only the 2 largest.

  • Optimized Memory Access: calculate_foam now performs only a single textureSampleLevel and uses the R, G, and B channels of the result.

  • Branchless Math: The divergent if/else if chain in calculate_foam has been replaced with a branchless blend using mix and smoothstep. It is a close approximation rather than an exact equivalent: it fades smoothly between the three noise channels instead of switching abruptly at fixed heights, and it reads all three channels from one sample at a single tiling scale instead of three samples at different scales.

  • Built-in Functions: All manual calculations like distance and clamping have been replaced with their fast, built-in equivalents.

  • Cached Uniforms: Values like camera_pos and time are read from the uniform buffer once at the top of the shader and stored in local variables.

#import bevy_pbr::mesh_functions
#import bevy_pbr::view_transformations::position_world_to_clip

struct OceanMaterial {
    time: f32,
    time_sin: f32,
    time_cos: f32,
    time_sin_slow: f32,
    time_cos_slow: f32,
    camera_position: vec3<f32>,
    wave_amplitude: f32,
    wave_frequency: f32,
    lod_near: f32,
    lod_far: f32,
}

@group(2) @binding(0)
var<uniform> material: OceanMaterial;

@group(2) @binding(1)
var noise_texture: texture_2d<f32>;

@group(2) @binding(2)
var noise_sampler: sampler;

struct VertexInput {
    @builtin(instance_index) instance_index: u32,
    @location(0) position: vec3<f32>,
    @location(1) normal: vec3<f32>,
    @location(2) uv: vec2<f32>,
}

struct VertexOutput {
    @builtin(position) clip_position: vec4<f32>,
    @location(0) world_position: vec3<f32>,
    @location(1) world_normal: vec3<f32>,
    @location(2) foam_amount: f32,
    @location(3) wave_height: f32,
    @location(4) distance_to_camera: f32,
}

// A helper function for a single directional wave. The core of a realistic water effect.
// It takes a 2D position and calculates the wave height at that point.
fn wave(
    position: vec2<f32>,  // The world-space XZ position of the vertex
    direction: vec2<f32>, // A normalized 2D vector for the wave's direction
    frequency: f32,     // Controls the distance between wave crests (higher = choppier)
    amplitude: f32,     // Controls the height of the wave crests
    speed: f32,         // Controls how fast the wave travels in its direction
    time: f32           // The global time uniform to animate the wave
) -> f32 {
    // The dot product projects the vertex position onto the wave's direction vector.
    // This is the key to making the wave move in a specific direction instead of just along the X or Z axis.
    let angle = dot(direction, position);
    // The classic sine wave formula, using our calculated angle.
    return sin(angle * frequency + time * speed) * amplitude;
}

// Detailed wave calculation for nearby vertices. It sums four different directional waves.
// This is what creates the chaotic, overlapping, and natural-looking surface of water.
fn calculate_wave_height_detailed(pos: vec3<f32>, time: f32) -> f32 {
    let amp = material.wave_amplitude;
    let freq = material.wave_frequency;
    var height = 0.0;

    // We add four different waves together. Each one has a unique direction, frequency,
    // amplitude, and speed. This combination is what breaks the repetitive grid-like
    // pattern and makes the surface look organic.
    height += wave(pos.xz, normalize(vec2(1.0, 0.5)), freq * 1.2, amp * 0.5, 1.5, time);
    height += wave(pos.xz, normalize(vec2(0.8, 1.0)), freq * 0.8, amp * 0.3, 1.2, time);
    height += wave(pos.xz, normalize(vec2(1.0, 1.3)), freq * 2.2, amp * 0.15, 2.1, time);
    height += wave(pos.xz, normalize(vec2(-0.2, 1.0)), freq * 1.5, amp * 0.25, 1.8, time);

    return height;
}

// A much cheaper version for medium-distance vertices.
// It only calculates two of the four waves, saving half the work.
fn calculate_wave_height_simple(pos: vec3<f32>, time: f32) -> f32 {
    let amp = material.wave_amplitude;
    let freq = material.wave_frequency;
    var height = 0.0;

    // We use the two largest, most noticeable waves for the medium LOD.
    // The smaller, high-frequency detail waves are skipped as they wouldn't be visible from a distance anyway.
    height += wave(pos.xz, normalize(vec2(1.0, 0.5)), freq * 1.2, amp * 0.5, 1.5, time);
    height += wave(pos.xz, normalize(vec2(0.8, 1.0)), freq * 0.8, amp * 0.3, 1.2, time);
    return height;
}

// Foam calculation with optimized texture sampling
// ✓ Single texture sample (vs 3 in unoptimized)
// ✓ Uses all RGB channels from one sample
// ✓ Branchless calculation using smoothstep (vs if/else branching)
fn calculate_foam(pos: vec3<f32>, wave_height: f32) -> f32 {
    // ✓ Single texture sample, use all channels
    // We start with our world position, scaled for the desired tiling size.
    let tiling_coords = pos.xz * 0.05;

    // We use fract() to wrap the coordinates into the [0.0, 1.0] range.
    let uv = fract(tiling_coords);

    // We now sample the texture using these correctly wrapped UVs.
    let noise_sample = textureSampleLevel(
        noise_texture,
        noise_sampler,
        uv,
        0.0
    );

    // ✓ Branchless foam calculation using smoothstep
    let t1 = smoothstep(0.25, 0.5, wave_height);
    let t2 = smoothstep(0.5, 0.75, wave_height);
    // Blend b -> g over 0.25..0.5, then toward r over 0.5..0.75.
    let foam = mix(mix(noise_sample.b, noise_sample.g, t1), noise_sample.r, t2);

    // ✓ Use built-in clamp
    return clamp(foam, 0.0, 1.0);
}

@vertex
fn vertex(in: VertexInput) -> VertexOutput {
    var out: VertexOutput;

    // ✓ Cache uniform values
    let camera_pos = material.camera_position;
    let time = material.time;
    let lod_near = material.lod_near;
    let lod_far = material.lod_far;

    let model = mesh_functions::get_world_from_local(in.instance_index);
    let world_position = mesh_functions::mesh_position_local_to_world(
        model,
        vec4<f32>(in.position, 1.0)
    );

    // ✓ Use built-in distance function
    let distance_to_camera = distance(world_position.xyz, camera_pos);

    // ✓ LOD system: adjust detail based on distance
    var wave_height: f32;
    var foam: f32;

    if distance_to_camera < lod_near {
        // Close: Full quality (4 waves) plus foam
        wave_height = calculate_wave_height_detailed(world_position.xyz, time);
        foam = calculate_foam(world_position.xyz, wave_height);
    } else if distance_to_camera < lod_far {
        // Medium: Simplified waves (2 waves), no foam
        wave_height = calculate_wave_height_simple(world_position.xyz, time);
        foam = 0.0;
    } else {
        // Far: Minimal processing
        wave_height = 0.0;
        foam = 0.0;
    }

    var displaced_position = world_position.xyz;
    displaced_position.y += wave_height;

    // Pass normal through (will be normalized in fragment shader after interpolation)
    let world_normal = mesh_functions::mesh_normal_local_to_world(
        in.normal,
        in.instance_index
    );

    out.clip_position = position_world_to_clip(displaced_position);
    out.world_position = displaced_position;
    out.world_normal = world_normal;
    out.foam_amount = foam;
    out.wave_height = wave_height;
    out.distance_to_camera = distance_to_camera;

    return out;
}

@fragment
fn fragment(in: VertexOutput) -> @location(0) vec4<f32> {
    // Normalize after interpolation
    let normal = normalize(in.world_normal);

    // Simple directional lighting
    let light_dir = normalize(vec3<f32>(1.0, 1.0, 1.0));
    let diffuse = max(0.0, dot(normal, light_dir)) * 0.7;
    let ambient = 0.3;

    // Ocean color
    let base_color = vec3<f32>(0.1, 0.3, 0.5);
    let foam_color = vec3<f32>(0.9, 0.9, 0.95);

    // Mix base and foam based on calculated amount
    let final_color = mix(base_color, foam_color, in.foam_amount);
    let lit_color = final_color * (ambient + diffuse);

    // Optional: Fade distant ocean to horizon color
    // 1.0 up close, falling to 0.0 beyond 100 units (avoids a reversed smoothstep range).
    let fade = 1.0 - smoothstep(50.0, 100.0, in.distance_to_camera);
    let horizon_color = vec3<f32>(0.6, 0.7, 0.8);
    let color_with_distance = mix(horizon_color, lit_color, fade);

    return vec4<f32>(color_with_distance, 1.0);
}
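To sanity-check a branchless rewrite, it helps to evaluate both versions on the CPU. In this Rust sketch, smoothstep and mix reimplement the WGSL built-ins, and the channel values are arbitrary. It also shows that the nesting of the two mix calls matters. Blending b toward g with t1, then toward r with t2, reproduces the b/g/r ordering of the branching code. The alternative nesting mix(b, mix(g, r, t1), t2) never selects the green channel at all.

```rust
/// WGSL-style smoothstep(low, high, x).
fn smoothstep(low: f32, high: f32, x: f32) -> f32 {
    let t = ((x - low) / (high - low)).clamp(0.0, 1.0);
    t * t * (3.0 - 2.0 * t)
}

/// WGSL-style mix(a, b, t) = a * (1 - t) + b * t.
fn mix(a: f32, b: f32, t: f32) -> f32 {
    a * (1.0 - t) + b * t
}

/// Hard-switching version from the unoptimized shader.
fn foam_branching(h: f32, r: f32, g: f32, b: f32) -> f32 {
    if h > 0.5 { r } else if h > 0.25 { g } else { b }
}

/// Smooth branchless version: blend b -> g over 0.25..0.5, then toward r over 0.5..0.75.
fn foam_smooth(h: f32, r: f32, g: f32, b: f32) -> f32 {
    let t1 = smoothstep(0.25, 0.5, h);
    let t2 = smoothstep(0.5, 0.75, h);
    mix(mix(b, g, t1), r, t2)
}

/// The other nesting, mix(b, mix(g, r, t1), t2): g's weight is t2 * (1 - t1),
/// and t2 > 0 only once t1 has already reached 1, so g is never used.
fn foam_swapped(h: f32, r: f32, g: f32, b: f32) -> f32 {
    let t1 = smoothstep(0.25, 0.5, h);
    let t2 = smoothstep(0.5, 0.75, h);
    mix(b, mix(g, r, t1), t2)
}

fn main() {
    // Distinct channel values make it obvious which one wins.
    let (r, g, b) = (1.0, 0.5, 0.0);
    for h in [0.0, 0.3, 0.5, 0.6, 1.0] {
        println!(
            "h = {h}: branching {:.3}, smooth {:.3}, swapped {:.3}",
            foam_branching(h, r, g, b),
            foam_smooth(h, r, g, b),
            foam_swapped(h, r, g, b)
        );
    }
}
```

Outside the 0.25 to 0.75 transition range, the smooth and branching versions agree. Inside it, the smooth version fades instead of snapping, which is the intended visual trade-off.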

The Optimized Rust Material (src/materials/d02_08_ocean_optimized.rs)

The Rust code for our optimized material is very similar to the unoptimized one, but we add the new lod_near and lod_far fields to the OceanUniforms struct. This allows us to control the Level of Detail distances from our Bevy application.

We also include the time_sin, time_cos, time_sin_slow, and time_cos_slow fields to match the WGSL struct, even though the wave effect itself doesn't use them. They show how you would pass precomputed values if you had other effects that don't depend on vertex position.

use bevy::prelude::*;
use bevy::render::render_resource::{AsBindGroup, ShaderRef};

mod uniforms {
    #![allow(dead_code)]

    use bevy::prelude::*;
    use bevy::render::render_resource::ShaderType;

    #[derive(ShaderType, Debug, Clone, Copy)]
    pub struct OceanUniforms {
        pub time: f32,
        pub time_sin: f32,
        pub time_cos: f32,
        pub time_sin_slow: f32,
        pub time_cos_slow: f32,
        pub camera_position: Vec3,
        pub wave_amplitude: f32,
        pub wave_frequency: f32,
        pub lod_near: f32,
        pub lod_far: f32,
    }

    impl Default for OceanUniforms {
        fn default() -> Self {
            Self {
                time: 0.0,
                time_sin: 0.0,
                time_cos: 1.0,
                time_sin_slow: 0.0,
                time_cos_slow: 1.0,
                camera_position: Vec3::ZERO,
                wave_amplitude: 0.5,
                wave_frequency: 0.5, // Lower frequency for more visible waves
                lod_near: 80.0,      // Increased to show waves for most of the ocean
                lod_far: 120.0,      // Increased to match scene size
            }
        }
    }
}

pub use uniforms::OceanUniforms;

#[derive(Asset, TypePath, AsBindGroup, Debug, Clone)]
pub struct OceanMaterial {
    #[uniform(0)]
    pub uniforms: OceanUniforms,

    #[texture(1)]
    #[sampler(2)]
    pub noise_texture: Handle<Image>,
}

impl Material for OceanMaterial {
    fn vertex_shader() -> ShaderRef {
        "shaders/d02_08_ocean_optimized.wgsl".into()
    }

    fn fragment_shader() -> ShaderRef {
        "shaders/d02_08_ocean_optimized.wgsl".into()
    }
}

Don't forget to add it to src/materials/mod.rs:

// ... other materials
pub mod d02_08_ocean_unoptimized;
pub mod d02_08_ocean_optimized;

The Demo Module (src/demos/d02_08_ocean_demo.rs)

This is the core of our project. This Bevy application sets up a scene with a single large ocean plane and allows the user to press a key to hot-swap the material between the unoptimized and optimized versions.

Key components of this file:

  • OceanMaterials Resource: We create a custom resource to hold the handles for both our OceanMaterial and OceanMaterialUnoptimized. This gives us a single, reliable place to access them from any system.

  • setup function: Creates the high-vertex ocean mesh, generates a procedural noise texture, and creates both material assets. It spawns the ocean entity initially using the optimized material.

  • toggle_optimization system: This system performs the hot-swap. When the user presses P, it looks up the Entity of our ocean plane. It then uses commands to remove the currently active material component (e.g., MeshMaterial3d<OceanMaterial>) and insert the other one (e.g., MeshMaterial3d<OceanMaterialUnoptimized>). Bevy's renderer picks up the change and draws the mesh with the new shader from the next frame on.

  • update_ocean_materials system: This system runs every frame to update the time and camera_position uniforms for both materials, ensuring a seamless visual transition when toggling.

  • Other Systems: Standard systems are included to handle camera controls, UI updates, and input for adjusting wave parameters.

use crate::materials::d02_08_ocean_optimized::{OceanMaterial, OceanUniforms};
use crate::materials::d02_08_ocean_unoptimized::{OceanMaterialUnoptimized, OceanUnoptimizedUniforms};
use bevy::diagnostic::{FrameTimeDiagnosticsPlugin, LogDiagnosticsPlugin};
use bevy::prelude::*;
use std::f32::consts::PI;

#[derive(Component)]
struct OceanPlane;

#[derive(Component)]
struct OrbitCamera {
    radius: f32,
    angle: f32,
    height: f32,
}

// Resource to store both material handles
#[derive(Resource)]
struct OceanMaterials {
    optimized: Handle<OceanMaterial>,
    unoptimized: Handle<OceanMaterialUnoptimized>,
    using_optimized: bool,
}

pub fn run() {
    App::new()
        .add_plugins(DefaultPlugins)
        .add_plugins(FrameTimeDiagnosticsPlugin::default())
        .add_plugins(LogDiagnosticsPlugin::default())
        .add_plugins(MaterialPlugin::<OceanMaterial>::default())
        .add_plugins(MaterialPlugin::<OceanMaterialUnoptimized>::default())
        .add_systems(Startup, setup)
        .add_systems(
            Update,
            (
                update_ocean_materials,
                handle_input,
                update_camera,
                toggle_optimization,
                update_ui,
            ),
        )
        .run();
}

fn setup(
    mut commands: Commands,
    mut meshes: ResMut<Assets<Mesh>>,
    mut optimized_materials: ResMut<Assets<OceanMaterial>>,
    mut unoptimized_materials: ResMut<Assets<OceanMaterialUnoptimized>>,
    mut images: ResMut<Assets<Image>>,
) {
    // Generate noise texture
    let noise_texture = generate_noise_texture(256);
    let noise_handle = images.add(noise_texture);

    // Create ocean plane with high vertex count
    let ocean_mesh = create_ocean_plane(200, 200, 100.0);
    let mesh_handle = meshes.add(ocean_mesh);

    println!("Ocean mesh: 200x200 segments = 201x201 grid = 40,401 vertices");

    // Create both materials
    let optimized_handle = optimized_materials.add(OceanMaterial {
        uniforms: OceanUniforms::default(),
        noise_texture: noise_handle.clone(),
    });

    let unoptimized_handle = unoptimized_materials.add(OceanMaterialUnoptimized {
        uniforms: OceanUnoptimizedUniforms::default(),
        noise_texture: noise_handle.clone(),
    });

    // Store material handles in a resource
    commands.insert_resource(OceanMaterials {
        optimized: optimized_handle.clone(),
        unoptimized: unoptimized_handle.clone(),
        using_optimized: true,
    });

    // Spawn ocean with optimized material initially
    commands.spawn((
        Mesh3d(mesh_handle.clone()),
        MeshMaterial3d(optimized_handle.clone()),
        Transform::from_xyz(0.0, 0.0, 0.0),
        OceanPlane,
    ));

    // Lighting
    commands.spawn((
        DirectionalLight {
            illuminance: 15000.0,
            shadows_enabled: false,
            ..default()
        },
        Transform::from_rotation(Quat::from_euler(EulerRot::XYZ, -PI / 3.0, PI / 4.0, 0.0)),
    ));

    // Camera
    commands.spawn((
        Camera3d::default(),
        Transform::from_xyz(0.0, 30.0, 50.0).looking_at(Vec3::ZERO, Vec3::Y),
        OrbitCamera {
            radius: 50.0,
            angle: 0.0,
            height: 30.0,
        },
    ));

    // UI
    commands.spawn((
        Text::new(
            "[P] Toggle Optimized/Unoptimized Shader\n\
             [Arrow Keys] Rotate Camera | [Z/X] Camera Height\n\
             [+/-] Wave Amplitude | [[ / ]] Wave Frequency\n\
             [1/2] LOD Near Distance (optimized only)\n\
             \n\
             Mode: OPTIMIZED | FPS: -- | Vertices: 40,401",
        ),
        Node {
            position_type: PositionType::Absolute,
            top: Val::Px(10.0),
            left: Val::Px(10.0),
            padding: UiRect::all(Val::Px(10.0)),
            ..default()
        },
        TextFont {
            font_size: 16.0,
            ..default()
        },
        TextColor(Color::WHITE),
        BackgroundColor(Color::srgba(0.0, 0.0, 0.0, 0.7)),
    ));
}

fn create_ocean_plane(width_segments: u32, height_segments: u32, size: f32) -> Mesh {
    use bevy::render::mesh::{Indices, PrimitiveTopology};
    use bevy::render::render_asset::RenderAssetUsages;

    let mut positions = Vec::new();
    let mut normals = Vec::new();
    let mut uvs = Vec::new();
    let mut indices = Vec::new();

    for y in 0..=height_segments {
        for x in 0..=width_segments {
            let u = x as f32 / width_segments as f32;
            let v = y as f32 / height_segments as f32;

            let pos_x = (u - 0.5) * size;
            let pos_z = (v - 0.5) * size;

            positions.push([pos_x, 0.0, pos_z]);
            normals.push([0.0, 1.0, 0.0]);
            uvs.push([u, v]);
        }
    }

    for y in 0..height_segments {
        for x in 0..width_segments {
            let quad_start = y * (width_segments + 1) + x;

            indices.push(quad_start);
            indices.push(quad_start + width_segments + 1);
            indices.push(quad_start + 1);

            indices.push(quad_start + 1);
            indices.push(quad_start + width_segments + 1);
            indices.push(quad_start + width_segments + 2);
        }
    }

    let mut mesh = Mesh::new(
        PrimitiveTopology::TriangleList,
        RenderAssetUsages::default(),
    );

    mesh.insert_attribute(Mesh::ATTRIBUTE_POSITION, positions);
    mesh.insert_attribute(Mesh::ATTRIBUTE_NORMAL, normals);
    mesh.insert_attribute(Mesh::ATTRIBUTE_UV_0, uvs);
    mesh.insert_indices(Indices::U32(indices));

    mesh
}

fn generate_noise_texture(size: u32) -> Image {
    use bevy::render::render_asset::RenderAssetUsages;
    use bevy::render::render_resource::{Extent3d, TextureDimension, TextureFormat};
    use noise::{NoiseFn, Perlin};
    use std::f64::consts::PI;

    // Use a 4D Perlin noise function for generating seamless 2D noise.
    let perlin = Perlin::new(42);
    let mut data = Vec::with_capacity((size * size * 4) as usize);

    for y in 0..size {
        for x in 0..size {
            // Map each axis onto its own circle in 4D space (together, a torus).
            // This is the mathematical trick that makes the noise tileable.
            let angle_x = (x as f64 / size as f64) * 2.0 * PI;
            let angle_y = (y as f64 / size as f64) * 2.0 * PI;

            // We use sin/cos to wrap the coordinates around, ensuring the
            // start and end points of the texture match up perfectly.
            let p_x = angle_x.cos();
            let p_y = angle_x.sin();
            let p_z = angle_y.cos();
            let p_w = angle_y.sin();

            // The scale factor determines the "zoom" of the noise pattern.
            let scale = 2.0;
            let noise_value = perlin.get([p_x * scale, p_y * scale, p_z * scale, p_w * scale]);

            // Map the noise value from [-1, 1] to [0, 255] for the texture.
            let byte_value = ((noise_value + 1.0) * 0.5 * 255.0) as u8;

            data.push(byte_value); // R
            data.push(byte_value); // G
            data.push(byte_value); // B
            data.push(255); // A
        }
    }

    Image::new(
        Extent3d {
            width: size,
            height: size,
            depth_or_array_layers: 1,
        },
        TextureDimension::D2,
        data,
        TextureFormat::Rgba8Unorm,
        RenderAssetUsages::default(),
    )
}

fn update_ocean_materials(
    time: Res<Time>,
    camera_query: Query<&Transform, With<Camera3d>>,
    mut optimized_materials: ResMut<Assets<OceanMaterial>>,
    mut unoptimized_materials: ResMut<Assets<OceanMaterialUnoptimized>>,
) {
    let t = time.elapsed_secs();

    if let Ok(camera_transform) = camera_query.single() {
        // Update optimized material (with precomputed values)
        for (_, material) in optimized_materials.iter_mut() {
            material.uniforms.time = t;
            material.uniforms.time_sin = t.sin();
            material.uniforms.time_cos = t.cos();
            material.uniforms.time_sin_slow = (t * 0.5).sin();
            material.uniforms.time_cos_slow = (t * 0.5).cos();
            material.uniforms.camera_position = camera_transform.translation;
        }

        // Update unoptimized material (just time and camera)
        for (_, material) in unoptimized_materials.iter_mut() {
            material.uniforms.time = t;
            material.uniforms.camera_position = camera_transform.translation;
        }
    }
}

fn handle_input(
    keyboard: Res<ButtonInput<KeyCode>>,
    time: Res<Time>,
    mut optimized_materials: ResMut<Assets<OceanMaterial>>,
    mut unoptimized_materials: ResMut<Assets<OceanMaterialUnoptimized>>,
) {
    let delta = time.delta_secs();

    // Update wave amplitude for optimized materials
    for (_, material) in optimized_materials.iter_mut() {
        // Wave amplitude controls
        if keyboard.pressed(KeyCode::Equal) {
            material.uniforms.wave_amplitude =
                (material.uniforms.wave_amplitude + delta * 0.5).min(3.0);
            println!("Wave amplitude: {:.2}", material.uniforms.wave_amplitude);
        }
        if keyboard.pressed(KeyCode::Minus) {
            material.uniforms.wave_amplitude =
                (material.uniforms.wave_amplitude - delta * 0.5).max(0.0);
            println!("Wave amplitude: {:.2}", material.uniforms.wave_amplitude);
        }

        // Wave frequency controls
        if keyboard.pressed(KeyCode::BracketRight) {
            material.uniforms.wave_frequency =
                (material.uniforms.wave_frequency + delta * 0.2).min(2.0);
            println!("Wave frequency: {:.2}", material.uniforms.wave_frequency);
        }
        if keyboard.pressed(KeyCode::BracketLeft) {
            material.uniforms.wave_frequency =
                (material.uniforms.wave_frequency - delta * 0.2).max(0.1);
            println!("Wave frequency: {:.2}", material.uniforms.wave_frequency);
        }

        // LOD distance controls
        if keyboard.pressed(KeyCode::Digit1) {
            material.uniforms.lod_near = (material.uniforms.lod_near + delta * 10.0).min(200.0);
            println!(
                "LOD near: {:.1}, far: {:.1}",
                material.uniforms.lod_near, material.uniforms.lod_far
            );
        }
        if keyboard.pressed(KeyCode::Digit2) {
            material.uniforms.lod_near = (material.uniforms.lod_near - delta * 10.0).max(10.0);
            println!(
                "LOD near: {:.1}, far: {:.1}",
                material.uniforms.lod_near, material.uniforms.lod_far
            );
        }
    }

    // Update wave amplitude for unoptimized materials
    for (_, material) in unoptimized_materials.iter_mut() {
        if keyboard.pressed(KeyCode::Equal) {
            material.uniforms.wave_amplitude =
                (material.uniforms.wave_amplitude + delta * 0.5).min(3.0);
        }
        if keyboard.pressed(KeyCode::Minus) {
            material.uniforms.wave_amplitude =
                (material.uniforms.wave_amplitude - delta * 0.5).max(0.0);
        }

        if keyboard.pressed(KeyCode::BracketRight) {
            material.uniforms.wave_frequency =
                (material.uniforms.wave_frequency + delta * 0.2).min(2.0);
        }
        if keyboard.pressed(KeyCode::BracketLeft) {
            material.uniforms.wave_frequency =
                (material.uniforms.wave_frequency - delta * 0.2).max(0.1);
        }
    }
}

fn update_camera(
    keyboard: Res<ButtonInput<KeyCode>>,
    time: Res<Time>,
    mut camera_query: Query<(&mut Transform, &mut OrbitCamera), With<Camera3d>>,
) {
    if let Ok((mut transform, mut orbit)) = camera_query.single_mut() {
        let delta = time.delta_secs();

        if keyboard.pressed(KeyCode::ArrowLeft) {
            orbit.angle += delta;
        }
        if keyboard.pressed(KeyCode::ArrowRight) {
            orbit.angle -= delta;
        }

        if keyboard.pressed(KeyCode::KeyZ) {
            orbit.height = (orbit.height - delta * 20.0).max(5.0);
        }
        if keyboard.pressed(KeyCode::KeyX) {
            orbit.height = (orbit.height + delta * 20.0).min(100.0);
        }

        let x = orbit.angle.cos() * orbit.radius;
        let z = orbit.angle.sin() * orbit.radius;

        transform.translation = Vec3::new(x, orbit.height, z);
        transform.look_at(Vec3::ZERO, Vec3::Y);
    }
}

fn toggle_optimization(
    keyboard: Res<ButtonInput<KeyCode>>,
    mut materials_res: ResMut<OceanMaterials>,
    mut commands: Commands,
    ocean_entity: Query<Entity, With<OceanPlane>>,
) {
    if keyboard.just_pressed(KeyCode::KeyP) {
        materials_res.using_optimized = !materials_res.using_optimized;

        println!(
            "Switched to {} shader",
            if materials_res.using_optimized {
                "OPTIMIZED"
            } else {
                "UNOPTIMIZED"
            }
        );

        // Get the ocean entity
        if let Ok(entity) = ocean_entity.single() {
            // Remove old material component and add new one
            if materials_res.using_optimized {
                // Switch to optimized
                commands
                    .entity(entity)
                    .remove::<MeshMaterial3d<OceanMaterialUnoptimized>>()
                    .insert(MeshMaterial3d(materials_res.optimized.clone()));
            } else {
                // Switch to unoptimized
                commands
                    .entity(entity)
                    .remove::<MeshMaterial3d<OceanMaterial>>()
                    .insert(MeshMaterial3d(materials_res.unoptimized.clone()));
            }
        }
    }
}

fn update_ui(
    diagnostics: Res<bevy::diagnostic::DiagnosticsStore>,
    materials_res: Res<OceanMaterials>,
    optimized_materials: Res<Assets<OceanMaterial>>,
    unoptimized_materials: Res<Assets<OceanMaterialUnoptimized>>,
    mut text_query: Query<&mut Text>,
) {
    if let Some(fps_diagnostic) = diagnostics.get(&FrameTimeDiagnosticsPlugin::FPS) {
        if let Some(fps_smoothed) = fps_diagnostic.smoothed() {
            let mode_text = if materials_res.using_optimized {
                "OPTIMIZED"
            } else {
                "UNOPTIMIZED"
            };

            // Get current material settings
            let (wave_amp, wave_freq, lod_info) = if materials_res.using_optimized {
                if let Some(mat) = optimized_materials.get(&materials_res.optimized) {
                    (
                        mat.uniforms.wave_amplitude,
                        mat.uniforms.wave_frequency,
                        format!(
                            " | LOD: {:.0}/{:.0}",
                            mat.uniforms.lod_near, mat.uniforms.lod_far
                        ),
                    )
                } else {
                    (0.5, 0.5, String::new())
                }
            } else {
                if let Some(mat) = unoptimized_materials.get(&materials_res.unoptimized) {
                    (
                        mat.uniforms.wave_amplitude,
                        mat.uniforms.wave_frequency,
                        String::new(),
                    )
                } else {
                    (0.5, 0.5, String::new())
                }
            };

            for mut text in text_query.iter_mut() {
                **text = format!(
                    "[P] Toggle Optimized/Unoptimized Shader\n\
                     [Arrow Keys] Rotate Camera | [Z/X] Camera Height\n\
                     [+/-] Wave Amplitude | [[ / ]] Wave Frequency\n\
                     [1/2] LOD Near Distance (optimized only)\n\
                     \n\
                     Mode: {} | FPS: {:.0} | Vertices: 40,401\n\
                     Amplitude: {:.2} | Frequency: {:.2}{}",
                    mode_text, fps_smoothed, wave_amp, wave_freq, lod_info
                );
            }
        }
    }
}

Don't forget to add it to src/demos/mod.rs:

// ... other demos
pub mod d02_08_ocean_demo;

And register it in src/main.rs:

Demo {
    number: "2.8",
    title: "Vertex Shader Optimization",
    run: demos::d02_08_ocean_demo::run,
},

Running the Demo

Now that we have all the code in place, run the application. You'll see a large ocean plane with rolling waves and sea foam. The UI displays the current performance and allows you to control the scene. The goal of this demo is not necessarily to see a slideshow turn into a smooth experience, but to directly compare the relative cost of two different approaches to writing a shader.

Controls

| Key | Action |
| --- | --- |
| P | Toggle between the optimized and unoptimized shader. |
| Arrow Keys | Orbit the camera around the center of the ocean. |
| Z / X | Lower / raise the camera height. |
| + / - | Increase / decrease the wave amplitude. |
| [ / ] | Decrease / increase the wave frequency. |
| 1 / 2 | Increase / decrease the LOD near distance (optimized only). |

What You're Seeing

This is the core of the lesson. Press P to toggle between the two shaders and watch the "FPS" counter in the UI.

The performance difference you see will depend heavily on your GPU.

  • On a lower-end or older GPU, you may see a very significant FPS difference, illustrating the real-world cost of the unoptimized code.

  • On a high-end, modern GPU, both versions might run at a high frame rate. However, the unoptimized shader still uses significantly more GPU time per frame to get there. In a real game, that wasted time would mean less budget for everything else: other models, other effects, and higher resolutions.

The key takeaway is that the unoptimized version is demonstrably more expensive, and this demo gives you a tool to see that cost on your own hardware.
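
When you compare the two modes, convert the FPS readings to milliseconds per frame. FPS is the reciprocal of frame time, so the same FPS drop can mean very different costs at different frame rates. Here is a minimal Rust sketch; the helper names are hypothetical and not part of the demo.

```rust
/// Convert a frames-per-second reading into milliseconds per frame.
fn frame_time_ms(fps: f32) -> f32 {
    1000.0 / fps
}

/// How many extra milliseconds per frame the slower reading represents.
fn cost_difference_ms(fast_fps: f32, slow_fps: f32) -> f32 {
    frame_time_ms(slow_fps) - frame_time_ms(fast_fps)
}
```

For example, dropping from 60 to 50 FPS adds about 3.3 ms per frame. Dropping from 300 to 250 FPS adds only about 0.67 ms, even though the second drop looks five times larger on the FPS counter. Bevy windows also use a vsync present mode by default, which caps the frame rate at your display's refresh rate. If both modes sit at that cap, the counter cannot show the difference.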

Optimization Breakdown

In our simple demo, some of the individual changes we've made might only provide a small, or even unmeasurable, FPS boost on a modern GPU. However, in a large-scale game with millions of vertices, complex lighting, and dozens of different materials, these optimizations are not just good practice; they are the difference between a playable frame rate and an unworkable one.

Let's break down the cost of each anti-pattern in the unoptimized shader and how the optimized version solves it.

Anti-Pattern 1: Redundant Texture Samples

  • The Cost: The unoptimized calculate_foam function calls textureSampleLevel three times per vertex. For our 40,401-vertex mesh, that is over 120,000 texture fetches per frame just for the foam. In a real game with multiple textures per material, this kind of redundant sampling can quickly become a major memory bottleneck.

  • The Solution: The optimized version samples the texture only once and stores the result in a local vec4 variable. It then reuses the .r, .g, and .b components of this variable. This is a 3x reduction in memory fetches for this effect, a critical habit for writing scalable code.
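
Reading .r, .g, and .b from a single sample only gives three independent values if the texture stores different data in each channel. As written, generate_noise_texture writes the same byte_value into R, G, and B, so all three channels are identical. To get several independent layers from one fetch, pack them when you build the texture. Here is a minimal Rust sketch. It assumes you already have four grayscale layers of equal size, and pack_rgba is a hypothetical helper.

```rust
/// Interleave four equally sized single-channel layers into one RGBA8 buffer,
/// the byte layout used by `TextureFormat::Rgba8Unorm` (R, G, B, A per texel).
fn pack_rgba(r: &[u8], g: &[u8], b: &[u8], a: &[u8]) -> Vec<u8> {
    assert!(
        r.len() == g.len() && g.len() == b.len() && b.len() == a.len(),
        "all layers must have the same number of texels"
    );
    let mut data = Vec::with_capacity(r.len() * 4);
    for i in 0..r.len() {
        // One texel = four consecutive bytes, one per channel.
        data.extend_from_slice(&[r[i], g[i], b[i], a[i]]);
    }
    data
}
```

You can pass the returned buffer as the data argument of Image::new with TextureFormat::Rgba8Unorm, the same way generate_noise_texture passes its data. Each layer could be, for example, the same tileable Perlin noise sampled at a different scale.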

Anti-Pattern 2: Divergent Branching

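  • The Cost: The if/else if/else chain in the unoptimized foam logic diverges whenever vertices in the same wavefront take different branches. The hardware then runs each taken path in turn, masking off the lanes that didn't take it, so the whole wavefront pays for every path any of its vertices needs. GPUs don't speculatively predict branches the way CPUs do; the best they can do is skip a path that no vertex in the wavefront takes.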

  • The Solution: The optimized version uses a branchless equivalent with smoothstep and mix. This transforms the problem from "choosing which code to run" to "calculating a value," which is what GPUs are designed to do with maximum efficiency.
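
To see that a branchless rewrite computes the same values, here is a CPU-side Rust sketch that mirrors WGSL's step, mix, and smoothstep. The tiered foam rule is hypothetical and is not the demo's calculate_foam. The sketch assumes low < high.

```rust
/// CPU mirror of WGSL `step(edge, x)`: 1.0 when edge <= x, else 0.0.
fn step(edge: f32, x: f32) -> f32 {
    f32::from(u8::from(edge <= x))
}

/// CPU mirror of WGSL `mix(a, b, t)`: a * (1 - t) + b * t.
fn mix(a: f32, b: f32, t: f32) -> f32 {
    a * (1.0 - t) + b * t
}

/// CPU mirror of WGSL `smoothstep(low, high, x)` (assumes low < high).
fn smoothstep(low: f32, high: f32, x: f32) -> f32 {
    let t = ((x - low) / (high - low)).clamp(0.0, 1.0);
    t * t * (3.0 - 2.0 * t)
}

/// Hypothetical tiered foam rule written with control flow.
fn foam_branchy(h: f32, low: f32, high: f32) -> f32 {
    if h >= high {
        1.0
    } else if h >= low {
        0.5
    } else {
        0.0
    }
}

/// The same rule as arithmetic: each step() contributes 0.5 once its edge is reached.
/// Equivalent to `foam_branchy` whenever low < high.
fn foam_branchless(h: f32, low: f32, high: f32) -> f32 {
    0.5 * step(low, h) + 0.5 * step(high, h)
}

/// Soft version: a continuous ramp between the edges, blended from `base` to 1.0 with mix().
fn foam_soft(h: f32, low: f32, high: f32, base: f32) -> f32 {
    mix(base, 1.0, smoothstep(low, high, h))
}
```

On GPU hardware, the comparison inside step typically compiles to a compare-and-select rather than a jump. Every vertex in a wavefront therefore executes the same instruction stream whatever its height. The smoothstep version replaces the hard tiers with a continuous ramp, which changes the look slightly but is usually what you want for foam.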

Anti-Pattern 3: No Level of Detail (LOD)

  • The Cost: The unoptimized shader calculates 6 complex waves for every single vertex, including distant ones whose triangles cover only a few pixels near the horizon. This is the single biggest waste of computation in the shader.

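  • The Solution: The optimized shader implements a simple LOD system, and this is the highest-impact optimization. It runs cheaper math (or no math at all) for vertices beyond the LOD distances, which skips much of the unoptimized shader's work. That work is 6 waves for each of the 40,401 vertices, or over 240,000 wave evaluations per frame (more than 14 million per second at 60 FPS). The savings scale with how much of the mesh is far from the camera, which is why LOD is a foundational technique for rendering large, detailed worlds.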
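
You can check how much work the LOD can skip for a given view by bucketing the mesh's vertices by camera distance on the CPU. Here is a minimal Rust sketch, with a few assumptions. Distances below lod_near are treated as full detail, distances between lod_near and lod_far as a blend, and distances beyond lod_far as flat; the exact banding in the demo's WGSL may differ. The sketch rebuilds vertex positions with the same formula as create_ocean_plane and ignores wave displacement. lod_histogram is a hypothetical helper.

```rust
/// Count the vertices of a `segments x segments` ocean grid (the layout built by
/// `create_ocean_plane`) that fall into each LOD band for a given camera position.
/// Returns (full_detail, blended, beyond_far).
fn lod_histogram(
    segments: u32,
    size: f32,
    camera: [f32; 3],
    lod_near: f32,
    lod_far: f32,
) -> (u32, u32, u32) {
    let (mut near, mut blended, mut far) = (0u32, 0u32, 0u32);
    for y in 0..=segments {
        for x in 0..=segments {
            // Same position formula as create_ocean_plane; the plane lies at height 0.
            let px = (x as f32 / segments as f32 - 0.5) * size;
            let pz = (y as f32 / segments as f32 - 0.5) * size;
            let (dx, dy, dz) = (px - camera[0], -camera[1], pz - camera[2]);
            let dist = (dx * dx + dy * dy + dz * dz).sqrt();
            if dist < lod_near {
                near += 1;
            } else if dist < lod_far {
                blended += 1;
            } else {
                far += 1;
            }
        }
    }
    (near, blended, far)
}
```

After update_camera runs with its starting angle of 0, the camera sits at (50, 30, 0). For that position, lod_histogram(200, 100.0, [50.0, 30.0, 0.0], 80.0, 120.0) counts all 40,401 vertices, with none beyond lod_far (the farthest corner is about 116 units away) and most inside lod_near. So with the default settings, the LOD leaves most of this small plane at full detail. The savings grow when you lower the near distance with the 2 key, or when you render a much larger ocean.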

Anti-Pattern 4: Manual Math and Incorrect Data Handling

  • The Cost: The unoptimized shader uses several less-than-ideal practices: it calculates distance manually instead of using the built-in, and it passes incorrect coordinates to the texture sampler. The cost of a single manual distance calculation is too small to measure on its own, but thousands of these small inefficiencies across a large project add up to a significant performance drain.

  • The Solution: The optimized version follows best practices. It uses the fast, built-in distance() function, and it uses fract() to wrap coordinates correctly for texture tiling. The first is a micro-optimization and the second is a correctness fix. Practiced consistently, habits like these lead to robust, high-performance code that is also easier to read and maintain.
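
WGSL defines fract(x) as x - floor(x), which keeps tiled coordinates in [0, 1] even when they are negative. If you prototype shader math on the CPU in Rust, be aware that the standard library's f32::fract behaves differently. It computes x - x.trunc(), so it returns a negative result for negative inputs. Here is a minimal sketch of the WGSL behavior; wgsl_fract and wrap_uv are hypothetical helpers.

```rust
/// Mirrors WGSL `fract(x)`, defined as `x - floor(x)`.
/// The result lies in [0, 1]; rounding can yield exactly 1.0 for tiny negative inputs.
fn wgsl_fract(x: f32) -> f32 {
    x - x.floor()
}

/// Wrap a UV pair into the tile range before sampling with a non-repeating sampler.
fn wrap_uv(uv: [f32; 2]) -> [f32; 2] {
    [wgsl_fract(uv[0]), wgsl_fract(uv[1])]
}
```

For example, wrap_uv([1.5, -1.75]) gives [0.5, 0.25], whereas Rust's (-1.75f32).fract() gives -0.75.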

Key Takeaways

This phase has been a deep dive into the art of vertex manipulation. Before we move on to coloring our creations, let's solidify the core principles of vertex shader programming and optimization.

  1. Think Like the GPU: Minimize Divergence: The GPU processes vertices in lockstep. Avoid if/else statements that depend on per-vertex data. Whenever possible, use branchless math (mix, step, clamp) to turn control flow problems into data calculation problems.

  2. Do Less Work: The most significant performance gains often come from simply doing less. A Level of Detail (LOD) system that runs cheaper calculations for distant objects is the most powerful tool in your optimization arsenal.

  3. Work Efficiently: Your shader's performance is not just about the math, but how you access memory. Minimize expensive texture fetches by packing data into RGBA channels and sampling only once. Always prefer the GPU's highly optimized built-in functions over manual implementations.

  4. Profile First, Optimize Second: Don't guess where your performance bottlenecks are. Use profiling tools to gather data, form a hypothesis, make one change, and measure the result. A data-driven approach is the hallmark of a professional graphics programmer.

  5. Correct Your Data: The most subtle bugs can come from a misunderstanding of your data's coordinate space. Make sure the coordinates you pass to functions like textureSampleLevel behave the way you intend with your sampler. Unless the sampler uses a repeat address mode, coordinates outside [0, 1] won't tile, so wrap them with fract().

What's Next?

You now know how to control the position, shape, and animation of every vertex on the GPU, and how to do so with high performance for thousands of objects at once. We have completed our journey through the first major programmable stage of the graphics pipeline.

In the next phase, we will shift our focus from the "what" and "where" to the "how it looks." We will dive headfirst into the Fragment Shader, the part of the pipeline that runs for every pixel on your screen and is responsible for giving our creations color, texture, and life.

Next up: 3.1 - Fragment Shaders Fundamentals


Quick Reference

Use this section as a quick reminder of the core concepts, patterns, and functions covered in this phase.

The Golden Rule of Shader Optimization

Minimize divergence, maximize uniformity. Write code that allows the GPU to perform the same simple operations on large batches of vertices in lockstep.

The Optimization Workflow

  1. Profile First: Use tools to measure performance and identify the actual bottleneck (CPU vs. GPU, Vertex vs. Fragment). Don't guess.

  2. Hypothesize: Form a theory about what is slow (e.g., "This loop is too complex").

  3. Change One Thing: Apply a single optimization strategy.

  4. Measure Again: Verify that the change had the intended positive impact.

  5. Repeat.

Key Optimization Strategies Checklist

  • Move to CPU: Is the calculation the same for all vertices (e.g., sin(time))? Pre-calculate it once on the CPU and pass it in a uniform.

  • Implement LOD: Are you doing expensive work for distant vertices? Add a distance_to_camera check to run a cheaper version of the shader (or do no work at all) for far-away objects.

  • Reduce Divergence: Do you have if/else statements that depend on per-vertex data like position or uv? Rewrite them using branchless math.

  • Minimize Memory Access:

    • Are you calling textureSample (or textureSampleLevel, which vertex shaders must use) multiple times? Sample once and store the result in a local vec4.

    • Are you using multiple grayscale textures? Pack them into the R, G, B, and A channels of a single texture.

  • Use Built-ins: Are you writing common math functions by hand (e.g., v / length(v))? Replace them with the highly optimized built-in function (normalize(v)).

  • Correct Your Data: Are you passing coordinates to a sampler that are outside the expected range? Use fract() to wrap them for correct tiling.

Branchless Patterns

| Instead of... | Use... |
| --- | --- |
| if (x >= threshold) { a } else { b } | let result = mix(b, a, step(threshold, x)); |
| if (x < 0.0) { 0.0 } else { x } | let result = max(x, 0.0); |
| if (x > 1.0) { 1.0 } else if ... { x } | let result = clamp(x, 0.0, 1.0); |

Performance Targets (for a 60 FPS goal)

  • Total Frame Budget: about 16.7 ms (1000 ms / 60 frames)

  • Typical Vertex Shader Budget: < 4 ms

If a GPU profiler shows your vertex shader is taking more than a few milliseconds, it's a prime candidate for optimization.