import java.util.Map;
import java.util.concurrent.locks.ReentrantLock;

/**
 * Demonstrates a lost-update race on a shared static counter, and a
 * {@code tryLock}-based wrapper that detects when two threads enter the
 * non-thread-safe method at the same time.
 */
class ThreadUnsafe {
    // Deliberately unsynchronized and non-volatile: this is the shared state the demo races on.
    private static int unsafeCounter = 0;
    private static final int LOOPS = 1000000;

    // This method is not thread-safe: "+=" is a separate read, add and write, so
    // concurrent callers can overwrite each other's updates.
    private static void unsafeMethod() {
        unsafeCounter += 1;
    }

    // Guards unsafeMethod() when it is called through wrapUnsafeMethod(). It is only
    // ever acquired with tryLock(), so contention is reported rather than waited out.
    private static final ReentrantLock wrapUnsafeLock = new ReentrantLock();

    /**
     * Calls {@link #unsafeMethod()} while holding {@code wrapUnsafeLock}, and fails
     * loudly if another thread is already inside.
     *
     * <p>On contention, the stacks of all live threads are printed so the colliding
     * callers can be identified. The calling thread is then terminated with an
     * exception (the "crash" in {@code CrashCounterThread}). The counter is not
     * touched on that path, and the lock is only ever released by the thread that
     * actually acquired it.
     *
     * @throws IllegalStateException if another thread holds the lock at the time of the call
     */
    private static void wrapUnsafeMethod() {
        // I would normally add this code to the method itself but I need
        // both versions in one file.
        if (!wrapUnsafeLock.tryLock()) {
            // Two callers at the same time. Do not call unsafeMethod() unguarded, and do
            // not unlock(): this thread does not own the lock, and ReentrantLock.unlock()
            // would throw IllegalMonitorStateException.
            dumpAllStacks();
            throw new IllegalStateException(
                "Concurrent call to unsafeMethod() detected in " + Thread.currentThread());
        }
        try {
            unsafeMethod();
        } finally {
            // Always release what we acquired, even if the guarded call throws.
            wrapUnsafeLock.unlock();
        }
    }

    /**
     * Resets the counter, runs both threads to completion, and prints the final
     * counter next to the value it would have if every increment had been applied.
     *
     * @param one first thread to start; must not have been started yet
     * @param two second thread to start; must not have been started yet
     * @throws InterruptedException if interrupted while waiting for either thread
     */
    private static void doThreads(Thread one, Thread two)
            throws InterruptedException {
        unsafeCounter = 0;
        one.start();
        two.start();
        // join() makes both threads' writes to unsafeCounter visible to this thread.
        one.join();
        two.join();

        int expected = 2*LOOPS;
        System.out.println("unsafeCounter == " + unsafeCounter +
            " (should be " + expected + ")");
    }

    /**
     * Runs the race twice: first with threads calling {@link #unsafeMethod()}
     * directly, then with threads calling it through {@link #wrapUnsafeMethod()}.
     *
     * @param argv unused
     * @throws InterruptedException if interrupted while waiting for the worker threads
     */
    public static void main(String[] argv) throws InterruptedException {
        Thread one = new CounterThread();
        Thread two = new CounterThread();
        doThreads(one, two);

        one = new CrashCounterThread();
        two = new CrashCounterThread();
        doThreads(one, two);
    }

    // Increments the shared counter LOOPS times with no synchronization.
    private static class CounterThread extends Thread {
        @Override
        public void run() {
            for (int i = 0; i < LOOPS; i++) {
                unsafeMethod();
            }
        }
    }

    // Increments the shared counter LOOPS times through the collision-detecting
    // wrapper; the thread terminates with an exception on the first collision it sees.
    private static class CrashCounterThread extends Thread {
        @Override
        public void run() {
            for (int i = 0; i < LOOPS; i++) {
                wrapUnsafeMethod();
            }
        }
    }

    // Prints every live thread followed by its stack frames, one frame per line.
    private static void dumpAllStacks() {
        Map<Thread, StackTraceElement[]> stackTraces = Thread.getAllStackTraces();
        for (Map.Entry<Thread, StackTraceElement[]> entry : stackTraces.entrySet()) {
            System.out.println(entry.getKey());
            for (StackTraceElement frame : entry.getValue()) {
                System.out.println("\t" + frame);
            }
        }
    }
}

