Clemsim 1 rok temu
rodzic
commit
9ca92a8dee

+ 59 - 3
src/compute/compute_pipeline.rs

@@ -1,9 +1,11 @@
+use core::time;
+
 use serde::de;
 use wgpu::{include_wgsl, util::DeviceExt, BindGroup, BufferUsages, ComputePipeline, Texture};
 
 use super::SIZE;
 
-pub fn new_compute_pipeline(device: &wgpu::Device)-> (ComputePipeline,BindGroup,(u32,u32), Texture){
+pub fn new_compute_pipeline(device: &wgpu::Device)-> (ComputePipeline,BindGroup,(u32,u32), Texture, BindGroup, wgpu::Buffer){
 
     // let u32_size = std::mem::size_of::<u64>() as wgpu::BufferAddress;
     // let time_buffer = device.create_buffer(&wgpu::BufferDescriptor{
@@ -70,12 +72,66 @@ pub fn new_compute_pipeline(device: &wgpu::Device)-> (ComputePipeline,BindGroup,
             }
         ]
     });
+    //Size and Time Buffer
+    let size_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor{
+        label: Some("Size Buffer"),
+        contents: bytemuck::cast_slice(&[SIZE as u32]),
+        usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST
+    });
+    let time_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor{
+        label: Some("Time Buffer"),
+        contents: bytemuck::cast_slice(&[0.0f32]),
+        usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST
+    });
+    //Size and Time Bind group
+    let size_time_bind_group_layout = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor{
+        label: Some("Size and Time Bind Group Layout"),
+        entries: &[
+            wgpu::BindGroupLayoutEntry{
+                binding: 0,
+                visibility: wgpu::ShaderStages::COMPUTE,
+                count: None,
+                ty: wgpu::BindingType::Buffer{
+                    ty: wgpu::BufferBindingType::Uniform,
+                    has_dynamic_offset: false,
+                    min_binding_size: None
+                }
+            },
+            wgpu::BindGroupLayoutEntry{
+                binding: 1,
+                visibility: wgpu::ShaderStages::COMPUTE,
+                count: None,
+                ty: wgpu::BindingType::Buffer{
+                    ty: wgpu::BufferBindingType::Uniform,
+                    has_dynamic_offset: false,
+                    min_binding_size: None
+                }
+            }
+        ]
+    });
+    let size_time_bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor{
+        label: Some("Size and Time Bind Group"),
+        layout: &size_time_bind_group_layout,
+        entries: &[
+            wgpu::BindGroupEntry{
+                binding: 0,
+                resource: time_buffer.as_entire_binding()
+            },
+            wgpu::BindGroupEntry{
+                binding: 1,
+                resource: size_buffer.as_entire_binding()
+            }
+        ]
+    });
+
+
 
     //Pipeline Stuff
     let compute_pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor{
         label: Some("Compute Pipeline Layout 1"),
         bind_group_layouts: &[
-            &heightmap_texture_bind_group_layout
+            &heightmap_texture_bind_group_layout,
+            &size_time_bind_group_layout
         ],
         push_constant_ranges: &[]
     });
@@ -88,7 +144,7 @@ pub fn new_compute_pipeline(device: &wgpu::Device)-> (ComputePipeline,BindGroup,
 
     let dispatch = compute_work_group_count((SIZE,SIZE), (1,1));
     //let dispatch = (1,1);
-    return (compute_pipeline,heightmap_texture_bind_group,dispatch, output_texture)
+    return (compute_pipeline,heightmap_texture_bind_group,dispatch, output_texture, size_time_bind_group, time_buffer);
 }
 
 pub fn compute_work_group_count(

+ 11 - 0
src/lib.rs

@@ -58,6 +58,8 @@ struct State<'a> {
     visualisation_texture: Texture,
     visualisation_texture_view: TextureView,
     visualisation_texture_bind_group_layout: BindGroupLayout,
+    size_and_time: wgpu::BindGroup,
+    time_buffer: wgpu::Buffer,
 }
 
 impl<'a> State<'a> {
@@ -290,6 +292,8 @@ impl<'a> State<'a> {
             visualisation_texture: y.0,
             visualisation_texture_view: y.1,
             visualisation_texture_bind_group_layout: y.2,
+            size_and_time: x.4,
+            time_buffer: x.5,
         }
     }
 
@@ -348,6 +352,7 @@ impl<'a> State<'a> {
             });
             compute_pass.set_pipeline(&self.compute_pipeline);
             compute_pass.set_bind_group(0, &self.heightmap2_bindgroup, &[]);
+            compute_pass.set_bind_group(1, &self.size_and_time, &[]);
             compute_pass.dispatch_workgroups(self.dispatch.0, self.dispatch.1, 1);
             //compute_pass.dispatch_workgroups(256, 256, 1);
         }
@@ -459,6 +464,12 @@ impl<'a> State<'a> {
             self.last_fps = Instant::now();
             println!("{:?}", smoothed_fps.ceil());
         }
+        self.framecount += 1;
+        self.time_buffer = self.device.create_buffer_init(&wgpu::util::BufferInitDescriptor{
+            label: Some("Time Buffer"),
+            contents: bytemuck::cast_slice(&[self.framecount]),
+            usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST
+        });
         Ok(())
     }
 }

+ 0 - 57
src/shaders/CS_FFTHorizontal.wgsl

@@ -1,57 +0,0 @@
-@group(0) @binding(0)
-var<uniform> u_input: texture_2d<f32>; // readonly image2D
-
-@group(0) @binding(1)
-var<storage, write> u_output: texture_2d<f32>; // writeonly image2D
-
-let PI: f32 = 3.14159265358979323846;
-
-struct Complex {
-  x: f32,
-  y: f32,
-};
-
-fn multiply_complex(a: Complex, b: Complex) -> Complex {
-  return Complex(
-    a.x * b.x - a.y * b.y,
-    a.y * b.x + a.x * b.y
-  );
-}
-
-fn add_complex(a: Complex, b: Complex) -> Complex{
-  return Complex(
-    a.x + b.x,
-    a.y + b.y
-  )
-}
-
-fn butterfly_operation(a: Complex, b: Complex, twiddle: Complex) -> vec4<f32> {
-  let twiddle_b: Complex = multiply_complex(twiddle, b);
-  let result: vec4<f32> = vec4<f32>(a.x + twiddle_b.x, a.y + twiddle_b.y, a.x - twiddle_b.x, a.y - twiddle_b.y);
-  return result;
-}
-
-@compute @workgroup_size(256)
-fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
-  let pixel_coord: vec2<u32> = vec2<u32>(global_id.x, global_id.y);
-
-  let thread_count: u32 = u32(u_total_count >> 1); // Integer division by 2
-  let thread_idx: u32 = pixel_coord.x;
-
-  let in_idx: u32 = thread_idx & (u_subseq_count - 1u);
-  let out_idx: u32 = ((thread_idx - in_idx) << 1) + in_idx;
-
-  let angle: f32 = -PI * (f32(in_idx) / f32(u_subseq_count));
-  let twiddle: Complex = Complex(cos(angle), sin(angle));
-
-  let a: vec4<f32> = textureSample(u_input, pixel_coord);
-  let b: vec4<f32> = textureSample(u_input, vec2<u32>(pixel_coord.x + thread_count, pixel_coord.y));
-
-  // Transforming two complex sequences independently and simultaneously
-
-  let result0: vec4<f32> = butterfly_operation(Complex(a.x, a.y), Complex(b.x, b.y), twiddle);
-  let result1: vec4<f32> = butterfly_operation(Complex(a.z, a.w), Complex(b.z, b.w), twiddle);
-
-  textureStore(u_output, vec2<u32>(out_idx, pixel_coord.y), vec4<f32>(result0.xy, result1.xy));
-  textureStore(u_output, vec2<u32>(out_idx + u_subseq_count, pixel_coord.y), vec4<f32>(result0.zw, result1.zw));
-}

+ 0 - 56
src/shaders/CS_FFTVertical.wgsl

@@ -1,56 +0,0 @@
-@group(0) @binding(0)
-var<uniform> u_input: texture_2d<f32>; // readonly image2D
-
-@group(0) @binding(1)
-var<storage, write> u_output: texture_2d<f32>; // writeonly image2D
-
-let PI: f32 = 3.14159265358979323846;
-
-struct Complex {
-  x: f32,
-  y: f32,
-};
-
-fn multiply_complex(a: Complex, b: Complex) -> Complex {
-  return Complex(
-    a.x * b.x - a.y * b.y,
-    a.y * b.x + a.x * b.y
-  );
-}
-fn add_complex(a: Complex, b: Complex) -> Complex{
-  return Complex(
-    a.x + b.x,
-    a.y + b.y
-  )
-}
-
-fn butterfly_operation(a: Complex, b: Complex, twiddle: Complex) -> vec4<f32> {
-  let twiddle_b: Complex = multiply_complex(twiddle, b);
-  let result: vec4<f32> = vec4<f32>(a.x + twiddle_b.x, a.y + twiddle_b.y, a.x - twiddle_b.x, a.y - twiddle_b.y);
-  return result;
-}
-
-@compute @workgroup_size(256)
-fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
-  let pixel_coord: vec2<u32> = vec2<u32>(global_id.x, global_id.y);
-
-  let thread_count: u32 = u32(u_total_count >> 1); // Integer division by 2
-  let thread_idx: u32 = pixel_coord.y;
-
-  let in_idx: u32 = thread_idx & (u_subseq_count - 1u);
-  let out_idx: u32 = ((thread_idx - in_idx) << 1) + in_idx;
-
-  let angle: f32 = -PI * (f32(in_idx) / f32(u_subseq_count));
-  let twiddle: Complex = Complex(cos(angle), sin(angle));
-
-  let a: vec4<f32> = textureSample(u_input, pixel_coord);
-  let b: vec4<f32> = textureSample(u_input, vec2<u32>(pixel_coord.x, pixel_coord.y + thread_count));
-
-  // Transforming two complex sequences independently and simultaneously
-
-  let result0: vec4<f32> = butterfly_operation(Complex(a.x, a.y), Complex(b.x, b.y), twiddle);
-  let result1: vec4<f32> = butterfly_operation(Complex(a.z, a.w), Complex(b.z, b.w), twiddle);
-
-  textureStore(u_output, vec2<u32>(pixel_coord.x, out_idx), vec4<f32>(result0.xy, result1.xy));
-  textureStore(u_output, vec2<u32>(pixel_coord.x, out_idx + u_subseq_count), vec4<f32>(result0.zw, result1.zw));
-}

+ 0 - 237
src/shaders/CS_mine.wgsl

@@ -1,237 +0,0 @@
-//Structs
-struct Complex{
-    x: f32,
-    y:f32
-};
-
-struct SpectrumParameters {
-  scale: f32,
-  angle: f32,
-  spreadBlend: f32,
-  swell: f32,
-  alpha: f32,
-  peakOmega: f32,
-  gamma: f32,
-  shortWavesFade: f32,
-};
-//Constants
-let PI: f32 = 3.14159265358979323846;
-let GRAVITY = 9.81;
-let DEPTH = 20;
-let N = 5;
-let SEED = 325235235;
-let LENGTHSCALE0 = 1;
-let LENGTHSCALE1 = 1;
-let LENGTHSCALE2 = 1;
-let LENGTHSCALE3 = 1;
-
-
-fn add_complex(a:Complex, b:Complex)-> Complex{
-    return Complex(
-        a.x + b.x,
-        a.y + b.y
-    )
-}
-
-fn multiply_complex(a:Complex, b:Complex)-> Complex{
-    return Complex(
-        (a.x * b.x) - (a.y * b.y),
-        (a.x * b.y) + (b.x * a.y)
-    )
-}
-
-fn gaussian(x: f32, y:f32)-> f32{
-    //Mean = 0
-    let spread = 1.0;
-    let sigma_squared = spread * spread;
-    return (1 / (sqrt(2 * PI)* spread)) * exp(-((x * x) + (y * y)) / (2 * sigmaSqu));
-}
-
-fn angle_to_complex(x:f32)-> Complex{
-    return Complex(
-        cos(x),
-        sin(x)
-    )
-}
-
-// fn hash(n: u32) -> f32 {
-//   let mut n = (n << 13u);
-//   n1 ^= n1;
-//   n = n * (n * n * 15731u + 0x789221u) + 0x1376312589u;
-//   return f32(n & 0x7fffffffu) / 0.7976709; // 0x7fffffff is the same as 2^31 - 1
-// }
-
-fn UniformToGaussian(u1: f32, u2: f32) -> vec2<f32> {
-  // Calculate the radius (R) from the first uniform value
-  let R = sqrt(-2.0 * log(max(u1, 0.00001))); // Clamp u1 to avoid log(0)
-
-  // Calculate the angle (theta) from the second uniform value
-  let theta = 2.0 * PI * u2;
-
-  // Return a 2D vector with R * cos(theta) and R * sin(theta)
-  return vec2<f32>(R * cos(theta), R * sin(theta));
-}
-
-fn Dispersion(kMag: f32) -> f32 {
-  let depth = constant _Depth; // Assuming _Depth is a constant
-  return sqrt(GRAVITY * kMag * tanh(min(kMag * depth, 20.0)));
-}
-
-fn DispersionDerivative(kMag: f32) -> f32 {
-  let th = tanh(min(kMag * constant _Depth, 20.0)); // Assuming _Depth is a constant
-  let ch = cosh(kMag * constant _Depth); // Assuming _Depth is a constant
-  return GRAVITY * (depth * kMag / ch / ch + th) / Dispersion(kMag) / 2.0;
-}
-
-fn NormalizationFactor(s: f32) -> f32 {
-  let s2 = s * s;
-  let s3 = s2 * s;
-  let s4 = s3 * s;
-  let threshold = 5.0;
-  return if s < threshold {
-    -0.000564 * s4 + 0.00776 * s3 - 0.044 * s2 + 0.192 * s + 0.163
-  } else {
-    -4.80e-08 * s4 + 1.07e-05 * s3 - 9.53e-04 * s2 + 5.90e-02 * s + 3.93e-01
-  };
-}
-fn DonelanBannerBeta(x: f32) -> f32 {
-  let threshold1 = 0.95;
-  let threshold2 = 1.6;
-  let factor1 = 2.61;
-  let factor2 = 2.28;
-  let power1 = 1.3;
-  let power2 = -1.3;
-
-  if x < threshold1 {
-    return factor1 * abs(x) ^ power1;
-  } else if x < threshold2 {
-    return factor2 * abs(x) ^ power2;
-  } else {
-    let p = -0.4 + 0.8393 * exp(-0.567 * log(x * x));
-    return pow(10.0, p);
-  }
-}
-
-fn DonelanBanner(theta: f32, omega: f32, peakOmega: f32) -> f32 {
-  let beta = DonelanBannerBeta(omega / peakOmega);
-  let sech = 1.0 / cosh(beta * theta);
-  return beta / 2.0 / tanh(3.14159 * beta) * sech * sech;
-}
-
-fn Cosine2s(theta: f32, s: f32) -> f32 {
-  return NormalizationFactor(s) * pow(abs(cos(0.5 * theta)), 2.0 * s);
-}
-
-fn SpreadPower(omega: f32, peakOmega: f32) -> f32 {
-  let threshold = peakOmega;
-  let factor1 = 9.77;
-  let factor2 = 6.97;
-  let power1 = -2.5;
-  let power2 = 5.0;
-
-  if omega > threshold {
-    return factor1 * abs(omega / peakOmega) ^ power1;
-  } else {
-    return factor2 * abs(omega / peakOmega) ^ power2;
-  }
-}
-fn DirectionSpectrum(theta: f32, omega: f32, spectrum: SpectrumParameters) -> f32 {
-  let peakOmega = spectrum.peakOmega;
-  let swell = spectrum.swell;
-  let spreadBlend = spectrum.spreadBlend;
-  let s = SpreadPower(omega, peakOmega) + 16.0 * tanh(min(omega / peakOmega, 20.0)) * swell * swell;
-  return lerp(2.0 / 3.14159 * cos(theta) * cos(theta), Cosine2s(theta - spectrum.angle, s), spreadBlend);
-}
-
-fn TMACorrection(omega: f32) -> f32 {
-  let depth = constant _Depth; // Assuming _Depth is a constant
-  let omegaH = omega * sqrt(depth / GRAVITY);
-
-  if omegaH <= 1.0 {
-    return 0.5 * omegaH * omegaH;
-  } else if omegaH < 2.0 {
-    return 1.0 - 0.5 * (2.0 - omegaH) * (2.0 - omegaH);
-  } else {
-    return 1.0;
-  }
-}
-
-fn JONSWAP(omega: f32, spectrum: SpectrumParameters) -> f32 {
-  let peakOmega = spectrum.peakOmega;
-  let sigma = if omega <= peakOmega { 0.07 } else { 0.09 };
-  let r = exp(-(omega - peakOmega) * (omega - peakOmega) / 2.0 / sigma / sigma / peakOmega / peakOmega);
-  let oneOverOmega = 1.0 / omega;
-  let peakOmegaOverOmega = peakOmega / omega;
-  return spectrum.scale * TMACorrection(omega) * spectrum.alpha * GRAVITY * GRAVITY *
-         oneOverOmega * oneOverOmega * oneOverOmega * oneOverOmega * oneOverOmega *
-         exp(-1.25 * peakOmegaOverOmega * peakOmegaOverOmega * peakOmegaOverOmega * peakOmegaOverOmega) *
-         pow(abs(spectrum.gamma), r);
-}
-
-fn ShortWavesFade(kLength: f32, spectrum: SpectrumParameters) -> f32 {
-  return exp(-spectrum.shortWavesFade * spectrum.shortWavesFade * kLength * kLength);
-}
-
-// fn CS_InitializeSpectrum(id: u32) {
-//   // Seed generation based on thread ID
-//   let seed = u32(id.x) + N * u32(id.y) + N;
-//   seed += SEED;
-
-//   // Pre-defined length scales
-//   let lengthScales: array<f32, 4> = [LENGTHSCALE0, LENGTHSCALE1, LENGTHSCALE2, LENGTHSCALE3];
-
-//   // Loop through all scales
-//   for (let i: u32 = 0; i < 4; i++) {
-//     let halfN = N / 2.0;
-//     let deltaK = 2.0 * PI / lengthScales[i];
-//     let K = vec2<f32>(f32(id.x) - halfN, f32(id.y) - halfN) * deltaK;
-//     let kLength = length(K);
-
-//     // Update seed with scrambling
-//     seed += i + hash(seed) * 10.0;
-//     let uniformRandSamples = vec4<f32>(hash(seed), hash(seed * 2.0), hash(seed * 3.0), hash(seed * 4.0));
-//     let gauss1 = UniformToGaussian(uniformRandSamples.x, uniformRandSamples.y);
-//     let gauss2 = UniformToGaussian(uniformRandSamples.z, uniformRandSamples.w);
-
-//     // Check if wave number is within cut-off range
-//     if (_LowCutoff <= kLength && kLength <= _HighCutoff) {
-//       let kAngle = atan2(K.y, K.x);
-//       let omega = Dispersion(kLength);
-//       let dOmegadk = DispersionDerivative(kLength);
-
-//       // Calculate spectrum for this wave number and scale
-//       let spectrum = JONSWAP(omega, _Spectrums[i * 2]) * DirectionSpectrum(kAngle, omega, _Spectrums[i * 2]) * ShortWavesFade(kLength, _Spectrums[i * 2]);
-
-//       // Add contribution from secondary spectrum if present
-//       if (_Spectrums[i * 2 + 1].scale > 0.0) {
-//         spectrum += JONSWAP(omega, _Spectrums[i * 2 + 1]) * DirectionSpectrum(kAngle, omega, _Spectrums[i * 2 + 1]) * ShortWavesFade(kLength, _Spectrums[i * 2 + 1]);
-//       }
-
-//       // Fill initial spectrum texture with calculated values
-//       _InitialSpectrumTextures[u32(id.xy, i)] = vec4(vec2(gauss2.x, gauss1.y) * sqrt(2.0 * spectrum * abs(dOmegadk) / kLength * deltaK * deltaK), 0.0, 0.0);
-//     } else {
-//       // Set to zero if wave number is outside cut-off range
-//       _InitialSpectrumTextures[u32(id.xy, i)] = vec4(0.0);
-//     }
-//   }
-// }
-
-fn CS_PackSpectrumConjugate(id: u32) {
-  // Loop through all scales
-  for (let i: u32 = 0; i < 4; i++) {
-    // Access initial spectrum data (assuming vec2 for real and imaginary parts)
-    let h0 = vec2(_InitialSpectrumTextures[u32(id.xy, i)].x, _InitialSpectrumTextures[u32(id.xy, i)].y);
-
-    // Calculate mirrored thread ID for conjugate access
-    let mirrored_id = u32(N - id.x % N, N - id.y % N, i);
-
-    // Access conjugate spectrum data
-    let h0conj = vec2(_InitialSpectrumTextures[mirrored_id].x, -_InitialSpectrumTextures[mirrored_id].y);
-
-    // Pack spectrum and conjugate into a vec4
-    _InitialSpectrumTextures[u32(id.xy, i)] = vec4(h0, h0conj.x, -h0conj.y);
-  }
-}
-
-
-

+ 87 - 9
src/shaders/Heightmap_compute.wgsl

@@ -1,7 +1,17 @@
 //Constants
 const PI: f32 = 3.14159265358979323846;
-const GRAVITY:f32= 9.81;
-
+const G:f32= 9.81;
+const EPSILON:f32 = 0.0001;
+const ALPHA0= pow(7.6,10e-3);
+//Distance à la côte
+const FETCH : u32 = 100;
+//Vitesse du vent à 10m au dessus de la mer
+const U10 : f32 = 10.0;
+//Paramètre artistique
+const GAMMA : f32 = 3.3;
+@group(0) @binding(0) var heightmap_texture: texture_storage_2d<r32float, write>;
+@group(1) @binding(0) var<uniform> time: f32;
+@group(1) @binding(1) var<uniform> SIZE: u32;
 //Structs
 struct Complex {
   a: f32,
@@ -21,7 +31,15 @@ fn complex_multiply(a: Complex, b: Complex) -> Complex {
   let imaginary = a.a * b.b + a.b * b.a;
   return Complex(real, imaginary);
 }
-
+fn complex_conjugate(a: Complex) -> Complex {
+    return Complex(a.a, -a.b);
+}
+fn complex_magnitude(a: Complex) -> f32 {
+    return sqrt(a.a * a.a + a.b * a.b);
+}
+fn complex_phase(a: Complex) -> f32 {
+    return atan2(a.b, a.a);
+}
 fn angle_to_complex(x:f32)-> Complex{
     return Complex(
         cos(x),
@@ -29,6 +47,22 @@ fn angle_to_complex(x:f32)-> Complex{
     );
 }
 
+fn alpha()->f32{
+    return ALPHA0 * pow(((U10 * U10) / FETCH * G), 0.22);
+}
+
+fn omega_w()->f32{
+  return 22.0 * pow((G * G)/(U10* FETCH), 0.3333);
+}
+
+fn sigma(omega:f32)->f32{
+  if omega < omega_w(){
+    return 0.07;
+  }else{
+    return 0.09;
+  }
+}
+
 struct ComputeInput{
     @builtin(local_invocation_id) local_invocation_id: vec3<u32>,
     @builtin(local_invocation_index) local_invocation_index: u32,
@@ -37,10 +71,41 @@ struct ComputeInput{
     @builtin(num_workgroups) num_workgroups: vec3<u32>
 }
 
-//Actual code
+fn JONSWAP(omega:f32)->f32{
+  let alpha = alpha();
+  let r = exp(-pow((omega - omega_w()), 2) / (2 * pow(sigma(omega), 2) * pow(omega_w(), 2)));
+
+  let premier_terme = alpha * (G * G)/(pow(omega, 5));
+  let deuxieme_terme = exp(-1.25 * pow((omega_w() / omega), 4));
+  let troisieme_terme = pow(GAMMA, r);
+  return premier_terme * deuxieme_terme * troisieme_terme;
+}
+
+
+fn h_0(in:ComputeInput,k:vec2<f32>)->Complex{
+  let omega = sqrt(G * (k.x + k.y));
+  let spectre = JONSWAP(omega);
+  let h = sqrt(spectre / 2);
+  return Complex(h*gaussian(k.x), h*gaussian(k.y));
+}
+
+fn inverse_fft(k: vec2<f32>, spectrum: array<array<Complex, SIZE>, SIZE>, N: u32) -> f32 {
+    var sum: Complex = Complex(0.0, 0.0);
+
+    for (i = 0u; i < N; i = i + 1u) {
+        for (j = 0u; j < N; j = j + 1u) {
+            let n = vec2<f32>(f32(i), f32(j));
+            let exponent = -2.0 * PI * (k.x * n.x + k.y * n.y) / f32(N);
+            let phase = Complex(cos(exponent), sin(exponent));
+            sum = complex_add(sum, complex_multiply(spectrum[i][j], phase));
+        }
+    }
+
+    return sum.a / f32(N * N);
+}
+
+
 
-//@group(0) @binding(0) var heightmap_texture: texture_2d<f32>;
-@group(0) @binding(0) var heightmap_texture: texture_storage_2d<r32float, write>;
 fn gaussian(x: f32, y:f32)-> f32{
     let mean = 0.0;
     let spread = 1.0;
@@ -55,11 +120,24 @@ fn compute_main(in: ComputeInput){
     let coordonnees = in.global_invocation_id.xy;
     let dimensions = textureDimensions(heightmap_texture);
 
-    let pixel_color_1_channel = sin(f32(coordonnees.x + coordonnees.y)/10.);
+    if coords.x >= dimensions.x || coords.y >= dimensions.y {
+        return;
+    }
+    let k = vec2<f32>(f32(coords.x) / f32(dimensions.x), f32(coords.y) / f32(dimensions.y));
+    
+    // Create the spectrum based on the grid size N
+    var spectrum: array<array<Complex, SIZE>, SIZE>;
+    
+    for (i = 0u; i < N; i = i + 1u) {
+        for (j = 0u; j < N; j = j + 1u) {
+            let k_wave = vec2<f32>(f32(i) - f32(N / 2), f32(j) - f32(N / 2));
+            spectrum[i][j] = h_0(k_wave);
+        }
+    }
 
-    let vec4_pixel_color = vec4<f32>(pixel_color_1_channel);
+    let ht = inverse_fft(k, spectrum, N) * cos(time);
 
-    textureStore(heightmap_texture, coordonnees, vec4_pixel_color);
+    textureStore(heightmap_texture, coords, vec4<f32>(ht, 0.0, 0.0, 1.0));
 }