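// SIMD-accelerated 3-point smoothing using the Java Vector API.
// Requires a JDK that ships the incubating jdk.incubator.vector module (OpenJDK 19+);
// incubator modules are not resolved by default, so the module must be added explicitly.
//
// To compile and run:
//   javac --add-modules jdk.incubator.vector Vector.java
//   java  --add-modules jdk.incubator.vector Vector
//
// Note: the default SIZE allocates multi-gigabyte float arrays (input and output),
// so the JVM will likely need a larger heap than its default (e.g. add -Xmx8g).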

import jdk.incubator.vector.*;

/**
 * Benchmark of a SIMD 3-point moving-average filter written with the incubating
 * Java Vector API.
 *
 * <p>NOTE: this top-level class is named {@code Vector}, which is also the name of
 * {@code jdk.incubator.vector.Vector}. Because the file imports that package on
 * demand, the simple name {@code Vector} inside this file refers to this class.
 */
public class Vector {
    /** Default number of input elements (500M floats, about 2 GB per float array). */
    private static final int SIZE = 500_000_000;

    /** Preferred float vector shape on this platform; its lane count sets the main-loop stride. */
    private static final VectorSpecies<Float> S = FloatVector.SPECIES_PREFERRED;

    /**
     * Applies a 3-point moving average:
     * {@code out[i] = (in[i - 1] + in[i] + in[i + 1]) / 3} for every interior index.
     * At the edges, where a neighbour is missing, the nearest full window is used
     * instead: {@code out[0]} averages {@code in[0..2]} and {@code out[n - 1]}
     * averages {@code in[n-3..n-1]}.
     *
     * <p>Only {@code out[0 .. in.length - 1]} is written; any extra trailing elements
     * of {@code out} are left untouched.
     *
     * @param in  input samples; must contain at least 3 elements
     * @param out destination array; must be at least as long as {@code in}
     * @throws IllegalArgumentException if {@code in} has fewer than 3 elements or
     *         {@code out} is shorter than {@code in}
     * @throws NullPointerException if either array is {@code null}
     */
    public static void smooth(float[] in, float[] out) {
        final int n = in.length;
        // Validate before writing anything so a bad call never leaves out half-updated.
        if (n < 3) {
            throw new IllegalArgumentException(
                    "smooth requires at least 3 input elements, got " + n);
        }
        if (out.length < n) {
            throw new IllegalArgumentException(
                    "output length " + out.length + " is shorter than input length " + n);
        }

        final int lanes = S.length();

        // Left boundary: in[-1] does not exist, so average the first full window.
        out[0] = (in[0] + in[1] + in[2]) / 3f;

        // Vectorized interior. An iteration starting at i loads in[i-1 .. i+lanes]
        // and stores out[i .. i+lanes-1]; requiring i < n - lanes keeps the
        // right-shifted load (ending at in[i + lanes]) in bounds. Written as
        // n - lanes rather than i + lanes < n so the comparison cannot overflow.
        final int upper = n - lanes;
        int i = 1;
        for (; i < upper; i += lanes) {
            FloatVector left = FloatVector.fromArray(S, in, i - 1);
            FloatVector mid = FloatVector.fromArray(S, in, i);
            FloatVector right = FloatVector.fromArray(S, in, i + 1);
            // Same association order ((left + mid) + right) and divisor as the
            // scalar formula, so vector and scalar paths agree.
            left.add(mid).add(right).div(3f).intoArray(out, i);
        }

        // Scalar tail: interior elements the vector loop could not cover.
        for (; i < n - 1; i++) {
            out[i] = (in[i - 1] + in[i] + in[i + 1]) / 3f;
        }

        // Right boundary: in[n] does not exist, so average the last full window.
        out[n - 1] = (in[n - 3] + in[n - 2] + in[n - 1]) / 3f;
    }

    /**
     * Runs the smoothing benchmark five times and prints timing and checksum for each run.
     *
     * @param args optional; {@code args[0]} overrides the number of elements
     *             (defaults to {@link #SIZE})
     */
    public static void main(String[] args) {
        // The default size needs roughly 2 GB per float array; a smaller size can be
        // passed on the command line for machines without that much heap.
        int size = args.length > 0 ? Integer.parseInt(args[0]) : SIZE;

        float[] array = generateData(size);
        // Allocate the output once and reuse it across runs instead of allocating a
        // fresh multi-gigabyte array per run; smooth() overwrites every element.
        float[] smoothed = new float[array.length];

        System.out.println("Preferred Vector: " + S);

        // Run the test a few times to warm up the JVM; later timings are typically
        // more representative than the first.
        for (int run = 0; run < 5; run++) {
            runTest(array, smoothed);
        }
    }

    /** Times a single smooth() call into {@code smoothed} and prints elapsed time and checksum. */
    private static void runTest(float[] array, float[] smoothed) {
        long startTime = System.nanoTime();
        smooth(array, smoothed);
        long endTime = System.nanoTime();
        double elapsedTimeMS = (endTime - startTime) / 1_000_000.0;

        System.out.printf("Vector: Time taken: %.3f ms, checksum: %d%n",
                          elapsedTimeMS, checksum(smoothed));
    }

    /**
     * Order-dependent hash (seed 17, multiplier 31) over the truncated integer part
     * of each value. Int arithmetic wraps on overflow, which is acceptable for a checksum.
     */
    private static int checksum(float[] values) {
        int hash = 17;
        for (float v : values) {
            hash = hash * 31 + (int) v;
        }
        return hash;
    }

    /**
     * Produces deterministic input data where {@code data[i] = (float) i}.
     * Indices above 2^24 are not exactly representable as float and are rounded.
     */
    private static float[] generateData(int size) {
        float[] data = new float[size];
        for (int i = 0; i < size; i++) {
            data[i] = (float) i;
        }
        return data;
    }
}