/**
 * Command-line benchmark for the heat-diffusion models.
 *
 * <p>Builds a {@code HeatModelSimple} when the subgrid covers the whole grid, otherwise a
 * {@code HeatModelThreaded}. It then pins the bottom-right cell, runs the requested number of
 * steps, and samples two cells after step {@value #SAMPLE_STEP}. Finally it prints the average
 * wall-clock time per step, excluding the sampling cost.
 *
 * <p>Usage: {@code java HeatTest [GRID_SIZE NUM_STEPS SUB_WIDTH SUB_HEIGHT]}. With no arguments
 * the defaults below are used. Conductivity is not configurable from the command line.
 */
public class HeatTest {
    private static final int DEFAULT_SIZE = 512;
    private static final int DEFAULT_STEPS = 1000;
    private static final int DEFAULT_SUB_WIDTH = 512;
    private static final int DEFAULT_SUB_HEIGHT = 128;
    private static final double DEFAULT_CONDUCTIVITY = 0.2;
    private static final double INITIAL_TEMPERATURE = 0.0;

    /** Step after which the probe cells are sampled. */
    private static final int SAMPLE_STEP = 1000;
    /** Row and column of the second probe cell; only sampled when it lies inside the grid. */
    private static final int PROBE = 9;

    /** Command-line settings, initialised to the defaults and overwritten by {@link #readArgs}. */
    private static class Args {
        int gridSize = DEFAULT_SIZE;
        double conductivity = DEFAULT_CONDUCTIVITY;
        int steps = DEFAULT_STEPS;
        int subWidth = DEFAULT_SUB_WIDTH;
        int subHeight = DEFAULT_SUB_HEIGHT;
    }

    /**
     * Runs the benchmark and prints the sampled temperatures and average time per step.
     *
     * @param argStrings either empty (use defaults) or exactly
     *     {@code GRID_SIZE NUM_STEPS SUB_WIDTH SUB_HEIGHT}
     */
    public static void main(String[] argStrings) {
        Args args = readArgs(argStrings);
        int gridSize = args.gridSize;

        HeatModel model = createModel(args);
        model.setPinned(gridSize - 1, gridSize - 1,
            100.0 - INITIAL_TEMPERATURE);

        // NaN means "not sampled". Either fewer than SAMPLE_STEP steps ran, or (for heat99)
        // the grid is too small to contain the probe cell. Double.MIN_VALUE is a tiny
        // *positive* number and would print as a plausible-looking 0.000.
        double heat00 = Double.NaN;
        double heat99 = Double.NaN;
        long toSubtract = 0;
        long start = System.nanoTime();
        for (int step = 1; step <= args.steps; step++) {
            model.step();
            if (step == SAMPLE_STEP) {
                // Time spent copying the grid out is excluded from the reported per-step time.
                long subtractStart = System.nanoTime();
                double[][] heat = new double[gridSize][gridSize];
                model.getAll(heat);
                heat00 = heat[0][0];
                // Grid sizes 2, 4 and 8 are accepted by readArgs but have no cell (9,9).
                if (PROBE < gridSize) {
                    heat99 = heat[PROBE][PROBE];
                }
                toSubtract += System.nanoTime() - subtractStart;
            }
        }
        double elapsedMs = 1e-6 * (System.nanoTime() - start - toSubtract);

        if (args.steps < SAMPLE_STEP) {
            System.out.printf("(only %d steps run; temperatures are sampled at step %d)%n",
                args.steps, SAMPLE_STEP);
        }
        System.out.printf("%d: temperature at 0,0: %7.3f (expect 58.096)%n",
            SAMPLE_STEP, heat00);
        System.out.printf("%d: temperature at %d,%d: %7.3f (expect 12.241)%n",
            SAMPLE_STEP, PROBE, PROBE, heat99);
        System.out.printf("# subgrids:               %3d%n",
            gridSize * gridSize / args.subWidth / args.subHeight);
        System.out.printf("avg time per step:        %7.3fms%n",
            elapsedMs / args.steps);
        // Explicit exit, as in the original; presumably needed so that any worker threads
        // started by the threaded model do not keep the JVM alive.
        System.exit(0);
    }

    /**
     * Chooses the model implementation.
     *
     * <p>Returns the single-threaded model when one subgrid spans the whole grid, and the
     * threaded model otherwise.
     */
    private static HeatModel createModel(Args args) {
        int gridSize = args.gridSize;
        if (args.subWidth == gridSize && args.subHeight == gridSize) {
            return new HeatModelSimple(gridSize, gridSize,
                INITIAL_TEMPERATURE, args.conductivity);
        }
        return new HeatModelThreaded(gridSize, gridSize,
            INITIAL_TEMPERATURE, args.conductivity,
            args.subWidth, args.subHeight);
    }

    /**
     * Parses and validates the command line, terminating the JVM on any error.
     *
     * @param argStrings raw arguments; an empty array selects all defaults
     * @return validated settings (the grid size and both subgrid dimensions are powers of two,
     *     and each subgrid dimension is no larger than the grid)
     */
    private static Args readArgs(String[] argStrings) {
        Args args = new Args();
        if (argStrings.length == 0) {
            return args;
        } else if (argStrings.length != 4) {
            exitError("usage: java HeatTest GRID_SIZE NUM_STEPS SUB_WIDTH SUB_HEIGHT");
        }
        args.gridSize = parseIntArg(argStrings[0], "grid size must be integer");
        args.steps = parseIntArg(argStrings[1], "number of steps must be integer");
        args.subWidth = parseIntArg(argStrings[2], "subgrid width must be integer");
        args.subHeight = parseIntArg(argStrings[3], "subgrid height must be integer");

        if (args.gridSize <= 1 || args.gridSize > 2048) {
            exitError("grid size must be between 2 and 2048");
        } else if (!isPowerOfTwo(args.gridSize)) {
            exitError("grid size must be power of 2");
        } else if (args.steps <= 0) {
            exitError("number of steps must be positive");
        } else if (args.subWidth < 1 || args.subWidth > args.gridSize) {
            exitError("subgrid width must be between 1 and grid size");
        } else if (!isPowerOfTwo(args.subWidth)) {
            exitError("subgrid width must be power of 2");
        } else if (args.subHeight < 1 || args.subHeight > args.gridSize) {
            exitError("subgrid height must be between 1 and grid size");
        } else if (!isPowerOfTwo(args.subHeight)) {
            exitError("subgrid height must be power of 2");
        }

        return args;
    }

    /** Parses {@code text} as an int, or exits with {@code errorMessage} if it is not one. */
    private static int parseIntArg(String text, String errorMessage) {
        try {
            return Integer.parseInt(text);
        } catch (NumberFormatException e) {
            exitError(errorMessage);
            return 0; // unreachable: exitError terminates the JVM
        }
    }

    /** Returns true when a positive {@code value} has exactly one bit set. */
    private static boolean isPowerOfTwo(int value) {
        return (value & (value - 1)) == 0;
    }

    /** Prints {@code message} to stderr and terminates the JVM with status -1. */
    private static void exitError(String message) {
        System.err.println(message);
        System.exit(-1);
    }
}
