
package uk.co.wingpath.util;

import java.util.concurrent.*;

/**
* Thread pool that reports exceptions thrown by its tasks instead of
* silently recording them.
* <p>The static factory methods {@link #createCached} and
* {@link #createSingle} create pools configured in the same way as the
* corresponding {@link Executors} factory methods.
* <p>A shutdown hook is added for each thread pool to shut down the
* thread pool when the program exits. The hook is removed again when the
* pool terminates.
* <p>Each {@link Runnable} task that is executed by (or submitted to) a
* pool is wrapped so that any exception thrown by the task is passed to the
* pool's uncaught-exception handler, which may, for example, abort the
* program. If no handler was supplied, the exception is printed to
* {@code System.err}. Exceptions thrown after shutdown of the pool has
* begun are not reported.
* <p>The default behaviour of {@code ThreadPoolExecutor} is to print the
* exception or save it in a {@link Future}, and continue running. An
* uncaught-exception handler set on the pool threads never sees exceptions
* from submitted tasks, because those are captured by the task's
* {@code Future}. {@link java.util.concurrent.Callable} tasks are not
* wrapped, so their exceptions are still saved in the returned
* {@code Future}.
*/
public class ThreadPool
    extends ThreadPoolExecutor
{
    private final String name;
    private final Thread.UncaughtExceptionHandler exceptionHandler;
    private final Thread shutdownHook;

    /**
    * Creates a new {@code ThreadPool} with the given initial
    * parameters and default thread factory and rejected execution handler.
    *
    * @param corePoolSize the number of threads to keep in the
    * pool, even if they are idle.
    * @param maximumPoolSize the maximum number of threads to allow in the
    * pool.
    * @param keepAliveTime when the number of threads is greater than
    * the core, this is the maximum time that excess idle threads
    * will wait for new tasks before terminating.
    * @param unit the time unit for the keepAliveTime
    * argument.
    * @param workQueue the queue to use for holding tasks before they
    * are executed. This queue will hold only the {@code Runnable}
    * tasks submitted by the {@code execute} method.
    * @param name name for the thread pool. This is used to generate thread
    * names.
    * @param exceptionHandler handler for exceptions thrown by tasks.
    * May be {@code null} to report them to {@code System.err}.
    * @throws IllegalArgumentException if corePoolSize or
    * keepAliveTime less than zero, or if maximumPoolSize less than or
    * equal to zero, or if corePoolSize greater than maximumPoolSize.
    * @throws NullPointerException if {@code workQueue} is null
    */
    public ThreadPool (int corePoolSize, int maximumPoolSize,
        long keepAliveTime, TimeUnit unit,
        BlockingQueue<Runnable> workQueue,
        final String name,
        Thread.UncaughtExceptionHandler exceptionHandler)
    {
        super (corePoolSize, maximumPoolSize,
            keepAliveTime, unit,
            workQueue, new NamedThreadFactory (name));
        this.name = name;
        this.exceptionHandler = exceptionHandler;
        shutdownHook = new Thread (
            new Runnable ()
            {
                public void run ()
                {
                    try
                    {
                        shutdownNow ();
                        awaitTermination (10, TimeUnit.SECONDS);
                    }
                    catch (InterruptedException e)
                    {
                        // Stop waiting, but preserve the interrupt status.
                        Thread.currentThread ().interrupt ();
                    }
                    catch (Throwable e)
                    {
                        // Best effort only: the JVM is exiting anyway.
                    }
                }
            },
            name + "-Pool-Shutdown");
        try
        {
            Runtime.getRuntime ().addShutdownHook (shutdownHook);
        }
        catch (IllegalStateException e)
        {
            // Happens if already shutting down.
        }
        catch (SecurityException e)
        {
            // May happen if called from applet.
        }
    }

    /**
    * Removes the shutdown hook once the pool has terminated, since it is
    * no longer needed.
    */
    @Override
    protected void terminated ()
    {
        super.terminated ();
        try
        {
            Runtime.getRuntime ().removeShutdownHook (shutdownHook);
        }
        catch (IllegalStateException e)
        {
            // Happens if already shutting down.
        }
        catch (SecurityException e)
        {
            // May happen if called from applet.
        }
    }

    /**
    * This class implements {@link ThreadFactory} by delegating to
    * {@link Executors#defaultThreadFactory()}.
    * It generates thread names from the pool name and the thread ID, but
    * otherwise behaves exactly like the default thread factory.
    */
    private static class NamedThreadFactory
        implements ThreadFactory
    {
        private final String name;
        private final ThreadFactory factory;

        private NamedThreadFactory (String name)
        {
            this.name = name;
            factory = Executors.defaultThreadFactory ();
        }

        public Thread newThread (Runnable r)
        {
            Thread t = factory.newThread (r);
            String tname = name + "-" + t.getId ();
            t.setName (tname);
            return t;
        }
    }

    /**
    * Reports an exception thrown by a task to the pool's exception handler,
    * or prints it to {@code System.err} if there is no handler.
    *
    * @param thread the thread on which the task was running.
    * @param e the exception thrown by the task.
    */
    private void reportException (Thread thread, Throwable e)
    {
        if (exceptionHandler != null)
        {
            exceptionHandler.uncaughtException (thread, e);
        }
        else
        {
            System.err.print ("Exception in thread \"" +
                thread.getName () + "\" ");
            e.printStackTrace (System.err);
        }
    }

    /**
    * Wraps a task so that anything it throws is reported via
    * {@link #reportException} rather than being swallowed by the executor.
    */
    private class AbortRunnable
        implements Runnable
    {
        private final Runnable task;

        /**
        * @param task the task to wrap.
        * @throws NullPointerException if {@code task} is null. Checking
        * here makes {@code execute} and {@code submit} fail immediately,
        * as the {@code Executor} contract requires, instead of failing
        * later on a pool thread.
        */
        AbortRunnable (Runnable task)
        {
            if (task == null)
                throw new NullPointerException ();
            this.task = task;
        }

        public void run ()
        {
            try
            {
                task.run ();
            }
            catch (Throwable e)
            {
                // Exceptions thrown once shutdown has begun (presumably
                // side effects of shutdownNow interrupting the task) are
                // not reported.
                if (!isTerminating ())
                    reportException (Thread.currentThread (), e);
            }
        }
    }

    /**
    * Executes the given task, wrapping it so that exceptions are reported.
    * Tasks that are already {@link Future}s (e.g. those created by
    * {@code submit} or {@code invokeAll}) record their own outcome, so they
    * are executed unwrapped.
    * <p>If the task is rejected (normally because the pool has been shut
    * down), it is silently dropped. A rejected {@code Future} is cancelled
    * so that threads waiting on it do not block forever.
    *
    * @param task the task to execute.
    * @throws NullPointerException if {@code task} is null.
    */
    @Override
    public void execute (Runnable task)
    {
        Runnable toRun = (task instanceof Future) ?
            task : new AbortRunnable (task);
        try
        {
            super.execute (toRun);
        }
        catch (RejectedExecutionException e)
        {
            // Should only happen if executor has been shut down.
            if (task instanceof Future)
                ((Future<?>) task).cancel (false);
        }
    }

    /**
    * Submits a task, wrapping it so that exceptions are reported rather
    * than saved in the returned {@code Future}.
    * If the pool has been shut down, the returned {@code Future} is
    * cancelled.
    *
    * @param task the task to submit.
    * @return a {@code Future} representing completion of the task.
    * @throws NullPointerException if {@code task} is null.
    */
    @Override
    public Future<?> submit (Runnable task)
    {
        return super.submit (new AbortRunnable (task));
    }

    /**
    * Submits a task, wrapping it so that exceptions are reported rather
    * than saved in the returned {@code Future}.
    * If the pool has been shut down, the returned {@code Future} is
    * cancelled.
    *
    * @param task the task to submit.
    * @param result the result to return on successful completion.
    * @return a {@code Future} representing completion of the task.
    * @throws NullPointerException if {@code task} is null.
    */
    @Override
    public <T> Future<T> submit (Runnable task, T result)
    {
        return super.submit (new AbortRunnable (task), result);
    }

    /**
    * Creates a cached thread pool, configured in the same way as
    * {@link java.util.concurrent.Executors#newCachedThreadPool}.
    * @param name name for the thread pool. This is used to generate thread
    * names.
    * @param exceptionHandler handler for exceptions thrown by tasks.
    * May be {@code null} to report them to {@code System.err}.
    * @return the new thread pool.
    */
    public static ExecutorService createCached (String name,
        Thread.UncaughtExceptionHandler exceptionHandler)
    {
        return new ThreadPool (0, Integer.MAX_VALUE,
            60L, TimeUnit.SECONDS,
            new SynchronousQueue<Runnable> (),
            name, exceptionHandler);
    }

    /**
    * Creates a single-thread thread pool, configured in the same way as
    * {@link java.util.concurrent.Executors#newSingleThreadExecutor}.
    * @param name name for the thread pool. This is used to generate thread
    * names.
    * @param exceptionHandler handler for exceptions thrown by tasks.
    * May be {@code null} to report them to {@code System.err}.
    * @return the new thread pool.
    */
    public static ExecutorService createSingle (String name,
        Thread.UncaughtExceptionHandler exceptionHandler)
    {
        return new ThreadPool (1, 1,
            0L, TimeUnit.MILLISECONDS,
            new LinkedBlockingQueue<Runnable> (),
            name, exceptionHandler);
    }
}

