
package uk.co.wingpath.util;

/**
* This class provides various static methods for manipulating sequences of
* bits stored in byte arrays.
* <p>The bits are assumed to be stored "from left to right": the byte at
* index 0 holds the leftmost bits, and within each byte the most significant
* bit comes first. This is the "natural" printing order (e.g. when printed
* using {@link Bytes#toHexString}).
*/
public class Bits
{
    private Bits () {}

    /**
    * Shifts the contents of a byte array right by a specified number of bits.
    * @param src the byte array to be shifted.
    * @param shift how many bits to shift by.
    * @return a new array the same length as {@code src} with shifted contents
    */
    public static byte [] shiftRight (byte [] src, int shift)
    {
        if (shift < 0)
            return shiftRight (src, -shift);
        byte [] dest = new byte [src.length];
        int byteShift = shift / 8;
        if (byteShift >= src.length)
            return dest;
        int bitShift = shift % 8;
        byte last = 0;

        for (int i = 0 ; i < dest.length ; i++)
        {
            int j = i - byteShift;
            byte b = j >= 0 ? src [j] : 0;
            dest [i] = (byte) ((last << (8 - bitShift)) |
                ((b & 0xff) >> bitShift));
            last = b;
        }

        return dest;
    }

    /**
    * Shifts the contents of a byte array left by a specified number of bits.
    * @param src the byte array to be shifted.
    * @param shift how many bits to shift by.
    * @return a new array the same length as {@code src} with shifted contents
    */
    public static byte [] shiftLeft (byte [] src, int shift)
    {
        if (shift < 0)
            return shiftLeft (src, -shift);
        byte [] dest = new byte [src.length];
        int byteShift = shift / 8;
        if (byteShift >= src.length)
            return dest;
        int bitShift = shift % 8;
        byte b = src [byteShift];

        for (int i = 0 ; i < dest.length ; i++)
        {
            int j = i + byteShift + 1;
            byte next = j < src.length ? src [j] : 0;
            dest [i] = (byte) ((b << bitShift) |
                ((next & 0xff) >> (8 - bitShift)));
            b = next;
        }

        return dest;
    }

    /**
    * Rotates the contents of a byte array right by a specified number of bits.
    * @param src the byte array to be rotated.
    * @param shift how many bits to rotate by.
    * @return a new array the same length as {@code src} with rotated contents
    */
    public static byte [] rotateRight (byte [] src, int shift)
    {
        if (shift < 0)
            return rotateRight (src, -shift);
        byte [] dest = new byte [src.length];
        shift %= src.length * 8;
        int byteShift = shift / 8;
        int bitShift = shift % 8;
        int ir = src.length - byteShift;
        int il = ir - 1;
        assert il >= 0;
        if (ir >= src.length)
            ir = 0;

        for (int i = 0 ; i < dest.length ; i++)
        {
            byte br = src [ir];
            byte bl = src [il];
            dest [i] = (byte) ((bl << (8 - bitShift)) |
                ((br & 0xff) >> bitShift));
            ir++;
            if (ir >= src.length)
                ir = 0;
            il++;
            if (il >= src.length)
                il = 0;
        }

        return dest;
    }

    /**
    * Rotates the contents of a byte array left by a specified number of bits.
    * @param src the byte array to be rotated.
    * @param shift how many bits to rotate by.
    * @return a new array the same length as {@code src} with rotated contents
    */
    public static byte [] rotateLeft (byte [] src, int shift)
    {
        if (shift < 0)
            return rotateLeft (src, -shift);
        byte [] dest = new byte [src.length];
        shift %= src.length * 8;
        int byteShift = shift / 8;
        int bitShift = shift % 8;
        int il = byteShift;
        int ir = il + 1;
        if (ir == src.length)
            ir = 0;

        for (int i = 0 ; i < dest.length ; i++)
        {
            byte br = src [ir];
            byte bl = src [il];
            dest [i] = (byte) ((bl << bitShift) |
                ((br & 0xff) >> (8 - bitShift)));
            ir++;
            if (ir >= src.length)
                ir = 0;
            il++;
            if (il >= src.length)
                il = 0;
        }

        return dest;
    }

    private static byte [] bitCount;

    static
    {
        bitCount = new byte [256];

        for (int i = 0 ; i < 256 ; i++)
        {
            int count = 0;
            int mask = 1;

            for (int j = 0 ; j < 8 ; j++)
            {
                if ((i & mask) != 0)
                    count++;
                mask <<= 1;
            }

            bitCount [i] = (byte) count;
        }
    }

    /**
    * Counts the number of 1 bits in a sequence.
    * @param seq the bit sequence.
    * @return the number of 1 bits in the sequence.
    */
    public static int countBits (byte [] seq)
    {
        int count = 0;

        for (byte b : seq)
            count += bitCount [b & 0xff];

        return count;
    }

    /**
    * Computes the XOR auto-correlation of bit sequence.
    * @param seq the bit sequence.
    * @param len the maximum shift to use.
    * @return the auto-correlation.
    */
    public static int [] autocorrelate (byte [] seq, int len)
    {
        int [] result = new int [len];

        for (int i = 0 ; i < result.length ; i++)
        {
            byte [] rseq = rotateLeft (seq, i);
            Bytes.xor (rseq, seq);
            result [i] = seq.length * 8 - countBits (rseq);
        }

        return result;
    }
}

