1   /*
2       This file is part of quExec.
3   
4       quExec is free software; you can redistribute it and/or modify
5       it under the terms of the GNU Lesser General Public License as published by
6       the Free Software Foundation; either version 2 of the License, or
7       (at your option) any later version.
8   
9       quExec is distributed in the hope that it will be useful,
10      but WITHOUT ANY WARRANTY; without even the implied warranty of
11      MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12      GNU Lesser General Public License for more details.
13  
14      You should have received a copy of the GNU Lesser General Public License
15      along with quExec; if not, write to the Free Software
16      Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301  USA
17  */
18  
19  package net.sourceforge.quexec.testutil;
20  
21  
22  import java.util.ArrayList;
23  import java.util.Collections;
24  import java.util.List;
25  import java.util.concurrent.CountDownLatch;
26  
27  import org.apache.commons.logging.Log;
28  import org.apache.commons.logging.LogFactory;
29  import org.junit.After;
30  import org.junit.Before;
31  import org.junit.BeforeClass;
32  
33  /**
34   * Utility base class for JUnit tests which use multiple threads.
35   * 
36   * This utility class provides some methods for conveniently defining
37   * and executing multi-threaded tests in JUnit. In detail, it performs
38   * the following tasks:
39   * 
40   * - catch exceptions in test threads and rethrow them in the JUnit main
41   * thread as unchecked exceptions
42   * - provide a convenient interface for defining test threads (replacing
43   * Runnable which is inconvenient to use since run() may not throw
44   * exceptions) 
45   * - ensure that all test threads start in parallel (by blocking their
46   * execution until all threads are ready to start)
47   * - wait for the completion of all test threads
48   * 
49   * @author schickin
50   *
51   */
52  public abstract class AbstractMultithreadedTest {
53  	
54  	public interface Task {
55  		public void doIt() throws Throwable;
56  	}
57  	
58  	private static final Log log = LogFactory.getLog(AbstractMultithreadedTest.class);
59  	
60  	private static final List<Throwable> thrownByTest =
61  		Collections.synchronizedList(new ArrayList<Throwable>(10));
62  	
63  	private final List<Thread> testThreads =
64  		Collections.synchronizedList(new ArrayList<Thread>(10));
65  	
66  	private final CountDownLatch testStartLatch = new CountDownLatch(1);
67  
68  	/**
69  	 * This method must be executed as @BeforeClass method by subclasses.
70  	 * No action is required if subclasses do not define their own such
71  	 * method.
72  	 */
73  	@BeforeClass
74  	public static void setUpMultithreadingBeforeClass() throws Throwable {
75  		Thread.setDefaultUncaughtExceptionHandler(
76  				new Thread.UncaughtExceptionHandler() {
77  					public void uncaughtException(Thread t, Throwable e) {
78  						thrownByTest.add(e);
79  						log.error("Exception in thread '" + t.getName(),
80  								e.fillInStackTrace());
81  						throw new IllegalStateException(
82  								"Exception in thread '" + t.getName() + "'",
83  								e);
84  					}
85  				});
86  	}
87  	
88  	/**
89  	 * This method must be executed as @Before method by subclasses.
90  	 * No action is required if subclasses do not define their own such
91  	 * method.
92  	 */
93  	@Before
94  	public void setUpMultithreading() {
95  		thrownByTest.clear();
96  		testThreads.clear();
97  	}
98  	
99  	/**
100 	 * This method must be executed as @After method by subclasses.
101 	 * No action is required if subclasses do not define their own such
102 	 * method.
103 	 */
104 	@After
105 	public void tearDownMultithreading() throws Throwable {
106 		if (!thrownByTest.isEmpty()) {
107 			Throwable e = thrownByTest.get(0);
108 			log.warn("exception by test thread detected: ", e.fillInStackTrace());
109 			throw e;
110 		}
111 	}
112 	
113 	protected final Thread scheduleTask(final Task task, String name) {
114 		Thread t = new Thread(name) {
115 			public void run() {
116 				try {
117 					testStartLatch.await();
118 					task.doIt();
119 				}
120 				catch (Throwable e) {
121 					rethrowAsUnchecked(e);
122 				}
123 			}
124 		};
125 		t.start();
126 		testThreads.add(t);
127 		return t;
128 	}
129 	
130 	protected final void startAll() {
131 		testStartLatch.countDown();
132 	}
133 
134 	protected final void joinAll() throws InterruptedException {
135 		for (Thread t : testThreads) {
136 			log.debug("Waiting for thread: " + t);
137 			t.join();
138 		}
139 	}
140 	
141 	protected final void runAll() throws InterruptedException {
142 		startAll();
143 		joinAll();
144 	}
145 	
146 	private void rethrowAsUnchecked(Throwable e) {
147 		if (e instanceof RuntimeException) {
148 			throw (RuntimeException) e;
149 		}
150 		else if (e instanceof Error) {
151 			throw (Error) e;
152 		}
153 		else {
154 			throw new IllegalStateException(
155 					"wrapped checked exception", e);
156 		}
157 	}
158 }