
package uk.co.wingpath.modbus;

import java.io.*;
import uk.co.wingpath.io.*;
import uk.co.wingpath.util.*;

/**
* Implementation of the {@code PacketType} interface for sending and receiving
* Modbus messages using the Modbus/RTU protocol.
* <p>
* An RTU frame is laid out as: slave ID (1 byte), function code (1 byte),
* function-specific data, CRC (2 bytes, low byte first). A frame has no
* length field. {@link #receive} therefore accumulates incoming chunks and,
* after each chunk, tries to parse a frame that ends at the end of the data
* received so far.
*/
public class RtuPacketType
    implements PacketType
{
    // Timeout for the first read of a message (passed to Connection.read).
    private volatile int preTimeout = 1000;
    // Timeout for subsequent reads within a message.
    private volatile int eomTimeout = 100;
    private int maxPdu = Modbus.MAX_IMP_PDU_SIZE;
    private boolean allowLongMessages = false;

    /**
    * Thrown by {@link #unpack} when a byte sequence is not a valid RTU frame.
    * Carries a help ID that identifies the kind of error.
    */
    private static class FormatException
        extends Exception
    {
        private static final long serialVersionUID = 1L;

        private final String helpId;

        FormatException (String helpId, String message)
        {
            super (message);
            this.helpId = helpId;
        }

        String getHelpId ()
        {
            return helpId;
        }
    }

    /**
    * Sets the maximum PDU size accepted by {@link #receive}. Received data
    * whose length, minus 3 bytes (slave ID and CRC), exceeds this value is
    * rejected.
    * @param maxPdu maximum PDU size in bytes. It must lie between
    * {@code Modbus.MIN_PDU_SIZE} and {@code Modbus.MAX_IMP_PDU_SIZE}; this
    * is checked only by an assertion.
    */
    public void setMaxPdu (int maxPdu)
    {
        assert maxPdu >= Modbus.MIN_PDU_SIZE &&
            maxPdu <= Modbus.MAX_IMP_PDU_SIZE : maxPdu;
        this.maxPdu = maxPdu;
    }

    /**
    * Sets the flag passed to {@code ModbusMessage.checkSize} when validating
    * received messages.
    * @param allowLongMessages value passed to {@code checkSize}.
    */
    public void setAllowLongMessages (boolean allowLongMessages)
    {
        this.allowLongMessages = allowLongMessages;
    }

    /**
    * Packs a message into an RTU frame and writes it to the connection.
    * @param connection connection to write to.
    * @param tracer tracer for the raw bytes and the decoded message, or
    * {@code null} for no tracing.
    * @param message message to send.
    * @throws HEOFException if {@code connection} is {@code null}.
    * @throws IOException if the write fails.
    * @throws InterruptedException if interrupted while writing.
    */
    public void send (Connection connection, Tracer tracer,
            ModbusMessage message)
        throws IOException, InterruptedException
    {
        if (connection == null)
            throw new HEOFException ("I100", "Connection closed");

        byte [] buf = pack (message);
        if (tracer != null)
        {
            tracer.traceRaw (">", buf, 0, buf.length);
            message.traceSend (tracer);
        }
        connection.write (buf, 0, buf.length);
    }

    /**
    * Reads one RTU frame from the connection and unpacks it.
    * <p>
    * Data is read in chunks. After each chunk, the method tries to parse a
    * frame starting at each chunk boundary in turn and ending at the end of
    * the data received. Trying later start points lets it resynchronise when
    * leading bytes are garbage; any bytes before the accepted start are
    * traced and reported as discarded. The first read uses the timeout set
    * by {@link #setTimeout}, and later reads use the one set by
    * {@link #setEomTimeout}.
    * @param connection connection to read from.
    * @param tracer tracer for raw and discarded data, or {@code null}.
    * @param isRequest {@code true} to parse a request, {@code false} to
    * parse a response.
    * @param reporter receives warnings about invalid data, or {@code null}.
    * @return the received message.
    * @throws HEOFException if {@code connection} is {@code null}.
    * @throws InterruptedIOException if the first read times out.
    * @throws RecoverableIOException if invalid data is received, including
    * when a later read times out before a valid frame has been recognised.
    * @throws IOException on other I/O errors.
    * @throws InterruptedException if interrupted while reading.
    */
    public ModbusMessage receive (Connection connection,
            Tracer tracer, boolean isRequest, Reporter reporter)
        throws InterruptedIOException, IOException, InterruptedException
    {
        if (connection == null)
            throw new HEOFException ("I100", "Connection closed");

        // Big enough for the largest frame: slave ID + PDU + 2-byte CRC.
        byte buf [] = new byte [Modbus.MAX_IMP_PDU_SIZE + 3];

        // offsets [k] is the offset in buf at which chunk k starts, and
        // offsets [nChunks] is the total number of bytes read so far. Every
        // chunk holds at least one byte, so nChunks can reach buf.length.
        // That needs buf.length + 1 entries.
        int [] offsets = new int [buf.length + 1];

        // errorMsg [k] and helpId [k] record why buf [0 .. offsets [k + 1])
        // could not be parsed as a message.
        String [] errorMsg = new String [buf.length];
        String [] helpId = new String [buf.length];

        int nChunks = 0;
        offsets [0] = 0;

        for (;;)
        {
            assert nChunks < offsets.length;
            assert offsets [nChunks] >= nChunks;

            // If the buffer is full, give up: discard what we've read and
            // throw an exception to the caller (after the loop).
            if (offsets [nChunks] >= buf.length)
            {
                assert (nChunks != 0);
                errorMsg [nChunks - 1] =
                    "Message too long (> " + buf.length + " bytes)";
                helpId [nChunks - 1] = "S503";
                break;
            }

            // Read a chunk of data.
            int bytesRead;
            try
            {
                bytesRead = connection.read (buf, offsets [nChunks],
                    buf.length - offsets [nChunks],
                    nChunks == 0 ? preTimeout : eomTimeout,
                    nChunks == 0);
                if (tracer != null)
                    tracer.traceRaw ("<", buf, offsets [nChunks], bytesRead);
            }
            catch (InterruptedIOException e)
            {
                // A timeout before any data has arrived goes straight to the
                // caller. Otherwise the data received so far did not form a
                // valid message: leave the loop, and the data is traced as
                // discarded and reported after it.
                if (nChunks == 0)
                    throw e;
                break;
            }
            catch (RecoverableIOException e)
            {
                String id = e.getHelpId ();
                String msg = e.getMessage ();
                if (nChunks != 0)
                {
                    ModbusMessage.traceDiscard (tracer,
                        id, msg, buf, 0, offsets [nChunks]);
                }
                else
                {
                    ModbusMessage.traceError (tracer, id, msg);
                }
                if (reporter != null)
                    reporter.warning (id, "Invalid data received: " + msg);
                throw e;
            }

            assert bytesRead > 0;
            assert nChunks + 1 < offsets.length;
            offsets [nChunks + 1] = offsets [nChunks] + bytesRead;
            nChunks++;
            int total = offsets [nChunks];

            // Reject data that is already too long to hold an acceptable
            // PDU (the other 3 bytes are the slave ID and the CRC).
            if (total - 3 > maxPdu)
            {
                String msg = "PDU size (" + (total - 3) +
                    " bytes) exceeds maximum (" + maxPdu + " bytes)";
                ModbusMessage.traceDiscard (tracer, "S504", msg, buf, 0,
                    total);
                if (reporter != null)
                    reporter.warning ("S504", "Invalid data received: " + msg);
                throw new RecoverableIOException ("S504", msg);
            }

            // Try to parse a message that ends at the end of the data,
            // starting at each chunk boundary in turn.
            for (int i = 0 ; i < nChunks ; i++)
            {
                try
                {
                    ModbusMessage message = unpack (buf, offsets [i],
                        total - offsets [i], isRequest);
                    if (i != 0)
                    {
                        // Bytes before offsets [i] are discarded. Report
                        // them with the error recorded when they failed to
                        // parse as a message on their own.
                        ModbusMessage.traceDiscard (tracer, helpId [i - 1],
                            errorMsg [i - 1], buf, 0, offsets [i]);
                        if (reporter != null)
                        {
                            reporter.warning (helpId [i - 1],
                                "Invalid data received: " + errorMsg [i - 1]);
                        }
                    }
                    return message;
                }
                catch (FormatException e)
                {
                    // Record why the whole of the data so far is invalid.
                    if (i == 0)
                    {
                        errorMsg [nChunks - 1] = e.getMessage ();
                        helpId [nChunks - 1] = e.getHelpId ();
                    }
                }
            }
        }

        // No valid message was found: discard everything received. The loop
        // is only left after at least one chunk has been read.
        assert nChunks != 0;
        ModbusMessage.traceDiscard (tracer, helpId [nChunks - 1],
            errorMsg [nChunks - 1],
            buf, 0, offsets [nChunks]);
        if (reporter != null)
        {
            reporter.warning (helpId [nChunks - 1],
                "Invalid data received: " + errorMsg [nChunks - 1]);
        }

        throw new RecoverableIOException (helpId [nChunks - 1],
            errorMsg [nChunks - 1]);
    }

    /**
        Does packet type support packet identifiers?
        @return true if packet identifiers supported (always false for RTU)
    */

    public boolean hasTransactionIds ()
    {
        return false;
    }

    /**
    * Sets the timeout used for the first read of a message in
    * {@link #receive}.
    * @param timeout timeout value, passed unchanged to
    * {@code Connection.read} (units as defined there).
    */
    public void setTimeout (int timeout)
    {
        preTimeout = timeout;
    }

    /**
    * Sets the timeout used for reads after the first one within a message
    * in {@link #receive}.
    * @param timeout timeout value, passed unchanged to
    * {@code Connection.read} (units as defined there).
    */
    public void setEomTimeout (int timeout)
    {
        eomTimeout = timeout;
    }

    /**
    * Packs fields of the supplied message into a byte array, ready for
    * writing to a serial connection. The layout is: slave ID, function code,
    * data, CRC low byte, CRC high byte.
    * @param message the message to be packed.
    * @return packed fields from the message.
    */
    private byte [] pack (ModbusMessage message)
    {
        byte [] data = message.getData ();
        byte [] buf = new byte [data.length + 4];
        buf [0] = (byte) message.getSlaveId ();
        buf [1] = (byte) message.getFunctionCode ();
        System.arraycopy (data, 0, buf, 2, data.length);
        int crc = calculateCrc (buf, 0, data.length + 2);
        buf [data.length + 2] = (byte) crc;
        buf [data.length + 3] = (byte) (crc >> 8);
        return buf;
    }

    /**
    * Unpacks the supplied byte array (which has been read from a serial
    * connection) into fields of a Modbus message.
    * @param buf the array to be unpacked.
    * @param offset offset in {@code buf} at which to start unpacking. It is
    * non-zero when {@link #receive} skips leading garbage.
    * @param len number of bytes to be unpacked.
    * @param isRequest {@code true} if the message is a request, {@code false}
    * if it is a response.
    * @return message constructed from the unpacked fields.
    * @throws FormatException if the data is too short, the CRC is wrong, or
    * the message fails {@code checkSize}.
    */
    private ModbusMessage unpack (byte [] buf, int offset, int len,
            boolean isRequest)
        throws FormatException
    {
        // Minimum frame: slave ID, function code and 2-byte CRC.
        if (len < 4)
        {
            throw new FormatException ("S502",
                "Message too short (" + len + " bytes)");
        }
        int end = offset + len;

        // The CRC is stored low byte first.
        int receivedCrc = (((int) buf [end - 1] & 0xff) << 8) |
            ((int) buf [end - 2] & 0xff);
        int calculatedCrc = calculateCrc (buf, offset, len - 2);
        if (receivedCrc != calculatedCrc)
        {
            String msg = String.format (
                "Invalid CRC %02x %02x (should be %02x %02x)",
                buf [end - 2], buf [end - 1],
                calculatedCrc & 0xff, (calculatedCrc >> 8) & 0xff);
            throw new FormatException ("S501", msg);
        }

        // All fields are relative to offset, not to the start of buf.
        int slaveId = buf [offset] & 0xff;
        int function = buf [offset + 1] & 0xff;
        MessageBuilder body = new MessageBuilder ();
        body.addData (buf, offset + 2, len - 4);
        ModbusMessage message;
        if (isRequest)
        {
            message = new ModbusRequest (slaveId, function, -1,
                body.getData ());
        }
        else if ((function & 0x80) != 0 && len > 4)
        {
            // Error response: the function code has bit 0x80 set, and the
            // first data byte is passed as the error code.
            message = new ModbusErrorResponse (slaveId, function, -1,
                buf [offset + 2] & 0xff);
        }
        else
        {
            message = new ModbusResponse (slaveId, function, -1,
                body.getData ());
        }
        try
        {
            message.checkSize (allowLongMessages);
        }
        catch (ModbusException e)
        {
            throw new FormatException (e.getHelpId (), e.getExplanation ());
        }
        return message;
    }

    /**
    * Calculates the Modbus CRC-16: initial value 0xFFFF and reflected
    * polynomial 0xA001, processing each byte least-significant bit first.
    * @param buf data to checksum.
    * @param offset offset of the first byte.
    * @param len number of bytes.
    * @return the 16-bit CRC. The low byte is transmitted first.
    */
    private static int calculateCrc (byte [] buf, int offset, int len)
    {
        int result = 0xffff;

        while (len-- > 0)
        {
            result ^= (int) buf [offset++] & 0xff;

            for (int i = 0 ; i < 8 ; i++)
            {
                if ((result & 1) != 0)
                {
                    result >>= 1;
                    result ^= 0xa001;
                }
                else
                {
                    result >>= 1;
                }
            }
        }

        return result;
    }

    /**
    * RTU framing does not use a delimiter byte, so this method has no
    * effect.
    * @param delimiter requested delimiter.
    * @return {@code delimiter}, unchanged.
    */
    public int setDelimiter (int delimiter)
    {
        return delimiter;
    }
}

