import java.io.IOException;
import java.net.DatagramPacket;
import java.net.InetAddress;
import java.net.SocketException;
import java.net.UnknownHostException;
import java.util.BitSet;

/**
 * Minimal TFTP client that downloads one fixed file ({@value #FILE_NAME}) in
 * "octet" mode from a server over a {@link UdpModeledConnection}.
 *
 * <p>The client sends a read request and acknowledges every DATA packet it
 * receives. The transfer is complete once the final, short block (fewer than
 * {@value #BLOCK_SIZE} data bytes) is known and every block up to it has
 * arrived. After the transfer ends, the client keeps listening for
 * {@value #LINGER_MILLIS} ms so that retransmitted DATA packets are still
 * acknowledged.
 *
 * <p>An instance is meant to be run once: its transfer state is never reset
 * and {@link #run()} closes the connection. It is not thread-safe.
 */
public class TftpClient implements Runnable {
    private static final String HOST_NAME = "localhost";
    private static final int HOST_PORT = 6900;
    private static final String FILE_NAME = "src/words.txt";

    /** Payload size of a full DATA block; a shorter block marks the end of the file. */
    private static final int BLOCK_SIZE = 512;
    /** How long to keep answering retransmissions after the transfer has ended. */
    private static final int LINGER_MILLIS = 3000;
    /** Buffer size for outgoing read-request and error packets. */
    private static final int SMALL_PACKET_SIZE = 80;
    /** Buffer size for outgoing ACK packets. */
    private static final int ACK_BUFFER_SIZE = 6;
    /** Delay used by {@link #runBackground()} before the request is sent. */
    private static final int DEFAULT_INIT_WAIT_MILLIS = 500;

    /** RFC 1350 error code 4: illegal TFTP operation. */
    private static final int ERR_ILLEGAL_OPERATION = 4;
    /** RFC 1350 error code 5: unknown transfer ID (packet from an unexpected port). */
    private static final int ERR_UNKNOWN_TID = 5;

    private final int initWaitMillis;
    private final UdpModeledConnection connection;
    private final InetAddress serverAddr;
    /** Server transfer port, learned from the first accepted reply; -1 until then. */
    private int serverPort = -1;
    /** Block numbers seen so far, used to ignore duplicates when counting. */
    private final BitSet blocksReceived = new BitSet();
    private int blocksReceivedCount = 0;
    /** Number of the final (short) block once it has been seen; -1 while unknown. */
    private int blocksMax = -1;

    /**
     * Creates a client.
     *
     * @param serverAddr     address of the TFTP server; replies must come from it
     * @param connection     connection used for all traffic; it is closed when
     *                       {@link #run()} finishes
     * @param initWaitMillis delay before sending the read request; values
     *                       {@code <= 0} mean no delay
     */
    TftpClient(InetAddress serverAddr, UdpModeledConnection connection, int initWaitMillis) {
        this.serverAddr = serverAddr;
        this.connection = connection;
        this.initWaitMillis = initWaitMillis;
    }

    /**
     * Performs the whole download: optional initial wait, read request, then
     * receive/acknowledge until done. The connection is always closed when this
     * method returns or throws.
     */
    @Override
    public void run() {
        long elapsedNanos;
        try {
            if (!waitBeforeStart() || !sendReadRequest()) {
                return;
            }
            elapsedNanos = receiveTransfer();
        } finally {
            connection.close();
        }
        log("transmission complete: %.3fs elapsed, %d packets sent, %d packets dropped",
                elapsedNanos / 1e9, connection.getTotalPackets(), connection.getDroppedPackets());
    }

    /**
     * Sleeps for the configured initial delay, if any.
     *
     * @return {@code false} if the thread was interrupted while waiting; the
     *         interrupt status is restored and the transfer should be abandoned
     */
    private boolean waitBeforeStart() {
        if (initWaitMillis <= 0) {
            return true;
        }
        try {
            Thread.sleep(initWaitMillis);
            return true;
        } catch (InterruptedException e) {
            // Preserve the interrupt so the owner of this thread can observe it.
            Thread.currentThread().interrupt();
            log("interrupted before sending request");
            return false;
        }
    }

    /**
     * Sends the read request for {@value #FILE_NAME} to the server's well-known port.
     *
     * @return {@code true} if the request was sent
     */
    private boolean sendReadRequest() {
        try {
            log("request '%s' from server", FILE_NAME);
            DatagramPacket packet =
                    new DatagramPacket(new byte[SMALL_PACKET_SIZE], SMALL_PACKET_SIZE, serverAddr, HOST_PORT);
            TftpPacket.packReadRequest(packet, FILE_NAME, "octet");
            connection.sendSafely(packet);
            return true;
        } catch (IOException e) {
            log("I/O exception while sending request");
            return false;
        }
    }

    /**
     * Receives and processes packets until the transfer ends. It then lingers
     * for {@value #LINGER_MILLIS} ms, answering late packets, until a receive
     * times out or fails.
     *
     * @return nanoseconds from the start of receiving until the transfer ended;
     *         the linger period is not included
     */
    private long receiveTransfer() {
        long start = System.nanoTime();
        long stop = 0;
        boolean keepGoing = true;
        while (true) {
            DatagramPacket packet;
            try {
                if (keepGoing) {
                    packet = connection.receive();
                } else {
                    packet = connection.receiveTimeout(LINGER_MILLIS);
                    if (packet == null) {
                        break;
                    }
                }
            } catch (IOException e) {
                log("I/O exception while receiving packet");
                break;
            }
            boolean more = processPacket(packet);
            // Ending is one-way: a stray packet during the linger phase must not
            // send us back to the blocking, untimed receive() above.
            if (keepGoing && !more) {
                stop = System.nanoTime();
                keepGoing = false;
            }
        }
        if (keepGoing) {
            // Receiving failed before the transfer had ended.
            stop = System.nanoTime();
        }
        return stop - start;
    }

    /**
     * Sends an ERROR packet to the given endpoint.
     *
     * @param addr       destination address
     * @param port       destination port
     * @param errNum     TFTP error code
     * @param errMessage human-readable error text
     * @throws IOException if the packet cannot be built or sent
     */
    private void sendError(InetAddress addr, int port, int errNum, String errMessage) throws IOException {
        log("send error '%s' to server [%d]", errMessage, errNum);
        DatagramPacket packet = new DatagramPacket(new byte[SMALL_PACKET_SIZE], SMALL_PACKET_SIZE, addr, port);
        TftpPacket.packError(packet, errNum, errMessage);
        connection.send(packet);
    }

    /**
     * Handles one received packet.
     *
     * @return {@code true} if more packets are expected, {@code false} if the
     *         transfer has ended (complete, or aborted by an error)
     */
    private boolean processPacket(DatagramPacket packet) {
        try {
            if (!acceptSender(packet)) {
                // Foreign packet: already answered with an error; keep waiting for the server.
                return true;
            }

            int code = TftpPacket.unpackCode(packet);
            switch (code) {
            case TftpPacket.DATA:
                return handleData(packet);
            case TftpPacket.ERROR:
                log("server error '%s'", TftpPacket.unpackError(packet));
                return false;
            default:
                throw new TftpException(ERR_ILLEGAL_OPERATION,
                        String.format("expected data or error message, got %d", code));
            }
        } catch (TftpException e) {
            log("internal error '%s'", e.getMessage());
            try {
                sendError(serverAddr, serverPort, e.getErrorCode(), e.getMessage());
            } catch (IOException e2) {
                log("I/O exception while sending error packet");
            }
            return false;
        } catch (IOException e) {
            log("I/O exception while processing packet");
            return false;
        }
    }

    /**
     * Checks that a packet comes from the server's transfer endpoint.
     *
     * <p>The first packet from the server's address fixes the transfer port.
     * Any other sender gets an "unknown transfer ID" error and is ignored.
     *
     * @return {@code true} if the packet should be processed
     * @throws IOException if the error reply cannot be sent
     */
    private boolean acceptSender(DatagramPacket packet) throws IOException {
        int packetPort = packet.getPort();
        InetAddress packetAddr = packet.getAddress();
        boolean fromServerHost = packetAddr.equals(serverAddr);

        if (fromServerHost && serverPort < 0) {
            serverPort = packetPort;
            log("connected to port %d", serverPort);
            return true;
        }
        if (fromServerHost && packetPort == serverPort) {
            return true;
        }

        if (!fromServerHost) {
            log("received message from %s, expected from %s",
                    packetAddr.getHostAddress(), serverAddr.getHostAddress());
        } else {
            log("received message from port %d, expected from port %d", packetPort, serverPort);
        }
        sendError(packetAddr, packetPort, ERR_UNKNOWN_TID, "unknown UDP port");
        return false;
    }

    /**
     * Records a DATA block and acknowledges it, including duplicates, so that
     * retransmissions caused by lost ACKs are answered.
     *
     * @return {@code true} while blocks are still missing
     * @throws IOException   if the ACK cannot be built or sent
     * @throws TftpException if the packet cannot be decoded
     */
    private boolean handleData(DatagramPacket packet) throws IOException, TftpException {
        int block = TftpPacket.unpackBlockNumber(packet);
        int nbytes = TftpPacket.unpackDataLength(packet);
        if (nbytes < BLOCK_SIZE) {
            blocksMax = block;
        }
        if (!blocksReceived.get(block)) {
            blocksReceived.set(block);
            blocksReceivedCount++;
        }

        log("acknowledge block %d [%d received of %d]", block, blocksReceivedCount, blocksMax);
        DatagramPacket ack = new DatagramPacket(new byte[ACK_BUFFER_SIZE], ACK_BUFFER_SIZE, serverAddr, serverPort);
        TftpPacket.packAcknowledgement(ack, block);
        connection.send(ack);
        // RFC 1350 numbers DATA blocks from 1, so with final block N known the
        // download is complete once N distinct blocks have arrived.
        return blocksMax < 0 || blocksReceivedCount < blocksMax;
    }

    /**
     * Resolves {@value #HOST_NAME}, opens a client connection and starts a
     * client on a new thread. Setup failures are reported on stderr.
     */
    public static void runBackground() {
        InetAddress server;
        try {
            server = InetAddress.getByName(HOST_NAME);
        } catch (UnknownHostException e) {
            System.err.printf("unknown server name '%s'\n", HOST_NAME);
            return;
        }

        UdpModeledConnection connection;
        try {
            connection = new UdpModeledConnection();
        } catch (SocketException e) {
            System.err.println("could not create client socket");
            return;
        } catch (SecurityException e) {
            System.err.println("no permission to create client socket");
            return;
        }

        TftpClient client = new TftpClient(server, connection, DEFAULT_INIT_WAIT_MILLIS);
        new Thread(client, "tftp-client").start();
    }

    /** Prints a formatted message prefixed with "client: " to stdout. */
    private static void log(String format, Object... parms) {
        System.out.printf("client: %s\n", String.format(format, parms));
    }
}
