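// Multithreaded version using Java ForkJoin: a 3-point moving-average
// ("smoothing") benchmark over a large float array.
//
// To compile and run:
//   javac ForkJoin.java
//   java -Xmx5g ForkJoin
//
// The input and output arrays hold 500M floats each (about 2 GB apiece), so the
// JVM's default maximum heap may be too small; raise it with -Xmx as shown.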

import java.util.concurrent.*;

/**
 * Parallel 3-point smoothing benchmark built on the ForkJoin framework.
 *
 * <p>{@link #smooth(float[], float[])} recursively splits the interior of the array into
 * ranges shorter than {@code THRESHOLD} and processes them on
 * {@link ForkJoinPool#commonPool()}.
 */
public class ForkJoin {
    private static final int SIZE = 500_000_000;

    private static final ForkJoinPool POOL = ForkJoinPool.commonPool();

    /** Ranges shorter than this are smoothed sequentially instead of being split further. */
    private static final int THRESHOLD = 200_000; // tuneable

    /**
     * Writes a 3-point moving average of {@code in} into {@code out}.
     *
     * <p>For interior indices {@code out[i] = (in[i - 1] + in[i] + in[i + 1]) / 3}. The two
     * boundary elements use the nearest full window instead: {@code out[0]} is the mean of
     * {@code in[0..2]} and {@code out[n - 1]} is the mean of {@code in[n-3..n-1]}. When
     * {@code in} has fewer than three elements, the window covers the whole array, so every
     * output element is the mean of all inputs. An empty input is a no-op.
     *
     * @param in  source samples; not modified
     * @param out destination; must be a different array whose length is at least
     *            {@code in.length}. Elements at index {@code in.length} and beyond are
     *            left untouched.
     * @throws NullPointerException     if either array is {@code null}
     * @throws IllegalArgumentException if {@code out} is shorter than {@code in}, or if
     *                                  {@code out} is the same array as {@code in}
     */
    public static void smooth(float[] in, float[] out) {
        int n = in.length;
        if (out.length < n) {
            throw new IllegalArgumentException(
                    "out.length (" + out.length + ") < in.length (" + n + ")");
        }
        if (in == out) {
            // Smoothing in place would read neighbours that were already overwritten,
            // and concurrent tasks would race on the shared boundary elements.
            throw new IllegalArgumentException("in and out must be distinct arrays");
        }
        if (n < 3) {
            smoothTiny(in, out, n);
            return;
        }

        // Left boundary: mean of the first full window.
        out[0] = (in[0] + in[1] + in[2]) / 3f;

        // Interior [1 .. n-2]; invoke() returns only after the task (and its joined subtasks) completes.
        POOL.invoke(new SmoothTask(in, out, 1, n - 1));

        // Right boundary: mean of the last full window.
        out[n - 1] = (in[n - 3] + in[n - 2] + in[n - 1]) / 3f;
    }

    /**
     * Handles inputs with fewer than three elements. The 3-wide window cannot fit, so each
     * output is the mean of all {@code n} inputs. Does nothing when {@code n == 0}.
     */
    private static void smoothTiny(float[] in, float[] out, int n) {
        if (n == 0) {
            return;
        }
        float sum = 0f;
        for (int i = 0; i < n; i++) {
            sum += in[i];
        }
        float mean = sum / n;
        for (int i = 0; i < n; i++) {
            out[i] = mean;
        }
    }

    /**
     * Smooths the half-open interior range {@code [start, end)} of {@code in} into {@code out}.
     * Callers must guarantee {@code start >= 1} and {@code end <= in.length - 1}, so that
     * both neighbours of every index exist.
     */
    private static class SmoothTask extends RecursiveAction {
        private final float[] in;
        private final float[] out;
        private final int start;
        private final int end;     // exclusive: computes [start .. end-1]

        public SmoothTask(float[] in, float[] out, int start, int end) {
            this.in = in;
            this.out = out;
            this.start = start;
            this.end = end;
        }

        @Override
        protected void compute() {
            int len = end - start;
            if (len < THRESHOLD) {
                computeSequential();
            } else {
                int mid = start + len / 2;
                SmoothTask left  = new SmoothTask(in, out, start, mid);
                SmoothTask right = new SmoothTask(in, out, mid, end);
                // Fork the left half and run the right half in this thread, then wait for
                // the left half.
                left.fork();
                right.compute();
                left.join();
            }
        }

        /** Smooths {@code [start, end)} on the current thread. */
        private void computeSequential() {
            for (int i = start; i < end; i++) {
                out[i] = (in[i - 1] + in[i] + in[i + 1]) / 3f;
            }
        }
    }

    public static void main(String[] args) {

        // Create some data
        float[] array = generateData(SIZE);

        // Run the test a few times to warm up the JVM
        for (int i = 0; i < 5; i++) {
            runTest(array);
        }
    }

    /**
     * Times one call to {@link #smooth(float[], float[])} on {@code array}. Prints the elapsed
     * time together with an integer checksum of the output. Only the smoothing itself is
     * timed; the allocation of the output array and the checksum pass are not.
     */
    private static void runTest(float[] array) {
        float[] smoothed = new float[array.length];

        long startTime = System.nanoTime();
        smooth(array, smoothed);
        long endTime = System.nanoTime();
        double elapsedTimeMS = (endTime - startTime) / 1_000_000.0;

        // Checksum, matching Scalar.java. Overflow of the int hash is intentional (wraps).
        int hash = 17;
        for (float v : smoothed) {
            hash = (hash * 31) + (int) v;
        }

        System.out.printf("ForkJoin: Time taken: %.3f ms, checksum: %d%n",
                          elapsedTimeMS, hash);
    }

    /** Generates deterministic ramp data: {@code data[i] = (float) i}. */
    private static float[] generateData(int size) {
        float[] data = new float[size];
        for (int i = 0; i < size; i++) {
            data[i] = (float) i;
        }
        return data;
    }
}