
package uk.co.wingpath.modbus;

import java.io.*;
import uk.co.wingpath.io.*;
import uk.co.wingpath.util.*;

/**
* Implementation of the {@code PacketType} interface for sending and receiving
* Modbus messages using the Modbus/RTU protocol.
* <p>
* An RTU frame, as built by {@code pack}, consists of the slave ID byte, the
* function code byte, the message data, and a two-byte CRC (low byte first).
* Incoming data is read in chunks; after each chunk the accumulated bytes,
* and each suffix of them starting at a chunk boundary, are tried as a frame.
* This lets garbage that precedes a valid frame be discarded.
*/
public class RtuPacketType
    implements PacketType
{
    // The timeouts are volatile, presumably so that setTimeout/setEomTimeout
    // may be called from a thread other than the one receiving - confirm.

    /** Timeout passed to the first read of a message. */
    private volatile int preTimeout = 1000;
    /** Timeout passed to each subsequent read of the same message. */
    private volatile int eomTimeout = 100;
    /** Largest PDU size (in bytes) accepted when receiving. */
    private int maxPdu = Modbus.MAX_IMP_PDU_SIZE;
    /** Passed to {@code ModbusMessage.checkSize} when validating frames. */
    private boolean allowLongMessages = false;

    /**
    * Thrown by {@code unpack} when a byte sequence is not a valid RTU frame.
    * Static, since it needs no reference to the enclosing packet type.
    */
    private static class FormatException
        extends Exception
    {
        private static final long serialVersionUID = 1L;

        private final String helpId;

        FormatException (String helpId, String message)
        {
            super (message);
            this.helpId = helpId;
        }

        String getHelpId ()
        {
            return helpId;
        }
    }

    @Override
    public void setMaxPdu (int maxPdu)
    {
        assert maxPdu >= Modbus.MIN_PDU_SIZE &&
            maxPdu <= Modbus.MAX_IMP_PDU_SIZE : maxPdu;
        this.maxPdu = maxPdu;
    }

    @Override
    public void setAllowLongMessages (boolean allowLongMessages)
    {
        this.allowLongMessages = allowLongMessages;
    }

    /**
    * Packs the message into an RTU frame, traces it (if the message has a
    * tracer), and writes it to the connection.
    */
    @Override
    public void send (Connection connection, ModbusMessage message)
        throws IOException, InterruptedException
    {
        byte [] buf = pack (message);
        Tracer tracer = message.getTracer ();
        if (tracer != null)
        {
            tracer.traceRaw (">", buf, 0, buf.length);
            message.traceSend ();
        }
        connection.write (buf, 0, buf.length);
    }

    /**
    * Records a receive communication error in the counters.
    * @param counters counters to update, or {@code null} for none.
    */
    private void commError (ModbusCounters counters)
    {
        if (counters != null)
        {
            counters.incCommErrorCount ();
            counters.addEvent (Modbus.COMM_EVENT_RCV_COMM_ERROR);
        }
    }

    /**
    * Updates the counters for a successfully decoded message.
    * @param message the message received.
    * @param counters counters to update, or {@code null} for none.
    */
    private void countReceived (ModbusMessage message, ModbusCounters counters)
    {
        if (counters == null)
            return;
        counters.incBusMessageCount ();
        if (message.isRequest ())
        {
            int event = Modbus.COMM_EVENT_RCV_OK;
            if (message.getSlaveId () == 0)
                event |= Modbus.COMM_EVENT_RCV_BROADCAST;
            counters.addEvent (event);
        }
    }

    /**
    * Reads one RTU frame from the connection.
    * <p>
    * The first read uses {@code preTimeout}; subsequent reads of the same
    * frame use {@code eomTimeout}. After each chunk is read, the data is
    * decoded starting at each chunk boundary in turn. If a frame is found at
    * a later boundary, the preceding bytes are traced as discarded and
    * counted as a communication error.
    * @param connection connection to read from.
    * @param expectRequest {@code true} if a request is expected,
    *     {@code false} if a response is expected.
    * @param tracer tracer for raw data and messages, or {@code null}.
    * @param counters counters to update, or {@code null}.
    * @return the decoded message.
    * @throws InterruptedIOException if the first read times out.
    * @throws RecoverableIOException if the data could not be decoded, the
    *     PDU is too large, or the connection reported a recoverable error.
    * @throws IOException on other I/O errors.
    * @throws InterruptedException if the thread is interrupted.
    */
    private ModbusMessage receive (Connection connection,
            boolean expectRequest, Tracer tracer, ModbusCounters counters)
        throws InterruptedIOException, IOException, InterruptedException
    {
        // Room for slave ID + largest implemented PDU + 2-byte CRC.
        byte [] buf = new byte [Modbus.MAX_IMP_PDU_SIZE + 3];
        // offsets [k] is where chunk k starts in buf, so offsets [nChunks]
        // is the total number of bytes read so far.
        int [] offsets = new int [buf.length];
        // errorMsg [k] / helpId [k] say why bytes 0 .. offsets [k + 1] could
        // not be decoded; reported if those bytes are later discarded.
        String [] errorMsg = new String [buf.length];
        String [] helpId = new String [buf.length];
        int bytesRead;

        int nChunks = 0;
        offsets [0] = 0;

        for (;;)
        {
            // Have we filled the buffer?

            assert nChunks < offsets.length;
            assert offsets [nChunks] >= nChunks;
            if (offsets [nChunks] >= buf.length)
            {
                // Discard what we've read, and throw exception to caller

                assert (nChunks != 0);
                errorMsg [nChunks - 1] =
                    "Message too long (> " + buf.length + " bytes)";
                helpId [nChunks - 1] = "S503";
                break;
            }

            // Read a chunk of data

            try
            {
                bytesRead = connection.read (buf, offsets [nChunks],
                    buf.length - offsets [nChunks],
                    nChunks == 0 ? preTimeout : eomTimeout,
                    nChunks == 0);
                if (tracer != null)
                    tracer.traceRaw ("<", buf, offsets [nChunks], bytesRead);
                assert bytesRead > 0;
            }
            catch (InterruptedIOException e)
            {
                // Simple timeout. With data already read, the partial data
                // is traced as discarded and reported after the loop.
                if (nChunks != 0)
                    break;
                throw e;
            }
            catch (RecoverableIOException e)
            {
                String id = e.getHelpId ();
                String msg = e.getMessage ();
                if (nChunks != 0)
                {
                    ModbusMessage.traceDiscard (tracer,
                        id, msg, buf, 0, offsets [nChunks]);
                }
                else
                {
                    ModbusMessage.traceError (tracer, id, msg);
                }
                if (e instanceof OverrunException)
                {
                    if (counters != null)
                    {
                        counters.incOverrunCount ();
                        counters.addEvent (Modbus.COMM_EVENT_RCV_OVERRUN);
                    }
                }
                else
                {
                    commError (counters);
                }
                throw e;
            }

            assert bytesRead > 0;
            assert nChunks + 1 < offsets.length;
            offsets [nChunks + 1] = offsets [nChunks] + bytesRead;
            nChunks++;
            assert nChunks < offsets.length;
            assert offsets [nChunks] >= nChunks;

            // Reject data whose PDU (total minus slave ID and CRC) is too big.

            if (offsets [nChunks] - 3 > maxPdu)
            {
                String msg = "PDU size (" + (offsets [nChunks] - 3) +
                    " bytes) exceeds maximum (" + maxPdu + " bytes)";
                ModbusMessage.traceDiscard (tracer, "S308", msg, buf, 0,
                    offsets [nChunks]);
                commError (counters);
                throw new RecoverableIOException ("S308", msg);
            }

            // Attempt to build a message from the chunks we have received,
            // starting at each chunk boundary in turn.

            for (int i = 0 ; i < nChunks ; i++)
            {
                try
                {
                    ModbusMessage message = unpack (buf, offsets [i],
                        offsets [nChunks] - offsets [i], expectRequest, tracer);
                    if (i != 0)
                    {
                        // The bytes before chunk i were not part of this
                        // message; report them using the error recorded
                        // when they were the whole of the data.
                        ModbusMessage.traceDiscard (tracer, helpId [i - 1],
                            errorMsg [i - 1], buf, 0, offsets [i]);
                        commError (counters);
                    }
                    countReceived (message, counters);
                    message.traceReceive ();
                    return message;
                }
                catch (FormatException e)
                {
                    // Only the failure for the complete data is recorded.
                    if (i == 0)
                    {
                        errorMsg [nChunks - 1] = e.getMessage ();
                        helpId [nChunks - 1] = e.getHelpId ();
                    }
                }
            }
        }

        // The loop only exits after at least one chunk has been read (buffer
        // full, or a timeout following data), so an error is always recorded.
        assert nChunks != 0;
        ModbusMessage.traceDiscard (tracer, helpId [nChunks - 1],
            errorMsg [nChunks - 1],
            buf, 0, offsets [nChunks]);
        commError (counters);

        throw new RecoverableIOException (helpId [nChunks - 1],
            errorMsg [nChunks - 1]);
    }

    /**
    * Receives a request. Frames that decode as responses are traced as
    * discarded, and the next frame is read.
    */
    @Override
    public ModbusMessage receiveRequest (Connection connection, Tracer tracer,
            ModbusCounters counters)
        throws InterruptedIOException, IOException, InterruptedException
    {
        for (;;)
        {
            ModbusMessage message = receive (connection, true, tracer,
                counters);
            if (message.isRequest ())
                return message;
            // Presumably a response from another slave.
            ModbusMessage.traceDiscard (tracer, "S401", "Unexpected");
        }
    }

    /** Receives a frame, decoding it as a response. */
    @Override
    public ModbusMessage receiveResponse (Connection connection, Tracer tracer,
            ModbusCounters counters)
        throws InterruptedIOException, IOException, InterruptedException
    {
        return receive (connection, false, tracer, counters);
    }

    /**
    * Receives a frame, decoding it as a request. If that fails, it is
    * decoded as a response.
    */
    @Override
    public ModbusMessage receive (Connection connection, Tracer tracer,
            ModbusCounters counters)
        throws InterruptedIOException, IOException, InterruptedException
    {
        return receive (connection, true, tracer, counters);
    }

    /**
    * Does this packet type carry transaction identifiers?
    * Modbus/RTU frames have no transaction ID field, so this always returns
    * {@code false}.
    * @return {@code false}
    */
    @Override
    public boolean hasTransactionIds ()
    {
        return false;
    }

    /** Sets the timeout used for the first read of each incoming message. */
    @Override
    public void setTimeout (int timeout)
    {
        preTimeout = timeout;
    }

    /** Sets the timeout used for subsequent reads of the same message. */
    @Override
    public void setEomTimeout (int timeout)
    {
        eomTimeout = timeout;
    }

    /**
    * Packs fields of the supplied message into a byte array, ready for
    * writing to a serial connection.
    * @param message the message to be packed.
    * @return slave ID, function code, data, then CRC (low byte first).
    */
    private byte [] pack (ModbusMessage message)
    {
        byte [] data = message.getData ();
        byte [] buf = new byte [data.length + 4];
        buf [0] = (byte) message.getSlaveId ();
        buf [1] = (byte) message.getFunctionCode ();
        System.arraycopy (data, 0, buf, 2, data.length);
        int crc = calculateCrc (buf, 0, data.length + 2);
        buf [data.length + 2] = (byte) crc;
        buf [data.length + 3] = (byte) (crc >> 8);
        return buf;
    }

    /**
    * Unpacks the supplied byte array (which has been read from a serial
    * connection) into fields of a Modbus message.
    * @param buf the array to be unpacked.
    * @param offset offset in {@code buf} at which to start unpacking.
    * @param len number of bytes to be unpacked.
    * @param expectRequest {@code true} if the message is expected to be a
    * request, {@code false} if it is a response.
    * @param tracer tracer to attach to the constructed message.
    * @return message constructed from the unpacked fields.
    * @throws FormatException if the frame is too short, fails the CRC check,
    * or has an invalid size for its function code.
    */
    private ModbusMessage unpack (byte [] buf, int offset, int len,
            boolean expectRequest, Tracer tracer)
        throws FormatException
    {
        if (len < 4)
        {
            throw new FormatException ("S502",
                "Message too short (" + len + " bytes)");
        }
        // The CRC is transmitted low byte first.
        int receivedCrc = (((int) buf [offset + len - 1] & 0xff) << 8) |
            ((int) buf [offset + len - 2] & 0xff);
        int calculatedCrc = calculateCrc (buf, offset, len - 2);
        if (receivedCrc != calculatedCrc)
        {
            String msg = String.format (
                "CRC failed: %02x %02x instead of %02x %02x",
                buf [offset + len - 2], buf [offset + len - 1],
                calculatedCrc & 0xff, (calculatedCrc >> 8) & 0xff);
            throw new FormatException ("S501", msg);
        }
        int slaveId = buf [offset] & 0xff;
        int function = buf [offset + 1] & 0xff;
        MessageBuilder body = new MessageBuilder ();
        body.addData (buf, offset + 2, len - 4);
        if (expectRequest && (function & 0x80) != 0)
        {
            // We are expecting a request, but the error bit is set in
            // the function code byte, so this is presumably a response.
            expectRequest = false;
        }
        try
        {
            ModbusMessage message = new ModbusMessage (expectRequest, slaveId,
                function, -1, body.getData (), tracer);
            message.checkSize (allowLongMessages);
            return message;
        }
        catch (ModbusException e)
        {
            if (expectRequest)
            {
                // In a multi-drop setup we may see responses from other
                // slaves, so see if size is valid for a response.
                try
                {
                    ModbusMessage message = new ModbusMessage (false, slaveId,
                        function, -1, body.getData (), tracer);
                    message.checkSize (allowLongMessages);
                    return message;
                }
                catch (ModbusException ignored)
                {
                    // Not a valid response either: report the original
                    // (request) error below, which is the more relevant one.
                }
            }
            throw new FormatException (e.getHelpId (), e.getExplanation ());
        }
    }

    /**
    * Computes the Modbus CRC-16 of a byte range: initial value 0xffff,
    * reflected polynomial 0xa001, processed least significant bit first.
    * @param buf data buffer.
    * @param offset offset of the first byte to include.
    * @param len number of bytes to include.
    * @return the 16-bit CRC; the low byte is transmitted first.
    */
    private static int calculateCrc (byte [] buf, int offset, int len)
    {
        int result = 0xffff;

        while (len-- > 0)
        {
            result ^= (int) buf [offset++] & 0xff;

            for (int i = 0 ; i < 8 ; i++)
            {
                if ((result & 1) != 0)
                {
                    result >>= 1;
                    result ^= 0xa001;
                }
                else
                {
                    result >>= 1;
                }
            }
        }

        return result;
    }
}

