/Users/deen/code/yugabyte-db/src/yb/rpc/thread_pool-test.cc
Line | Count | Source (jump to first uncovered line) |
1 | | // |
2 | | // Copyright (c) YugaByte, Inc. |
3 | | // |
4 | | // Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except |
5 | | // in compliance with the License. You may obtain a copy of the License at |
6 | | // |
7 | | // http://www.apache.org/licenses/LICENSE-2.0 |
8 | | // |
9 | | // Unless required by applicable law or agreed to in writing, software distributed under the License |
10 | | // is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express |
11 | | // or implied. See the License for the specific language governing permissions and limitations |
12 | | // under the License. |
13 | | // |
14 | | // |
15 | | |
16 | | #include <atomic> |
17 | | #include <thread> |
18 | | |
19 | | #include <gtest/gtest.h> |
20 | | |
21 | | #include "yb/rpc/strand.h" |
22 | | #include "yb/rpc/thread_pool.h" |
23 | | |
24 | | #include "yb/util/countdown_latch.h" |
25 | | #include "yb/util/test_util.h" |
26 | | #include "yb/util/thread.h" |
27 | | #include "yb/util/tsan_util.h" |
28 | | |
29 | | DECLARE_int32(TEST_strand_done_inject_delay_ms); |
30 | | |
31 | | using namespace std::literals; |
32 | | |
33 | | namespace yb { |
34 | | namespace rpc { |
35 | | |
36 | | class ThreadPoolTest : public YBTest { |
37 | | }; |
38 | | |
39 | | enum class TestTaskState { |
40 | | IDLE, |
41 | | EXECUTED, |
42 | | COMPLETED, |
43 | | FAILED, |
44 | | }; |
45 | | |
46 | | class TestTask final : public ThreadPoolTask { |
47 | | public: |
48 | 40.1k | TestTask() {} |
49 | | |
50 | 40.1k | ~TestTask() {} |
51 | | |
52 | 30.1k | bool IsCompleted() const { |
53 | 30.1k | return state_ == TestTaskState::COMPLETED; |
54 | 30.1k | } |
55 | | |
56 | 0 | bool IsFailed() const { |
57 | 0 | return state_ == TestTaskState::FAILED; |
58 | 0 | } |
59 | | |
60 | 10.0k | bool IsDone() const { |
61 | 10.0k | return state_ == TestTaskState::COMPLETED || state_ == TestTaskState::FAILED; |
62 | 10.0k | } |
63 | | |
64 | 40.0k | void SetLatch(CountDownLatch* latch) { |
65 | 40.0k | latch_ = latch; |
66 | 40.0k | } |
67 | | |
68 | | private: |
69 | 29.4k | void Run() override { |
70 | 29.4k | auto expected = TestTaskState::IDLE; |
71 | 29.4k | ASSERT_TRUE(state_.compare_exchange_strong(expected, TestTaskState::EXECUTED)); |
72 | 30.0k | ASSERT_EQ(expected, TestTaskState::IDLE); |
73 | 30.0k | } |
74 | | |
75 | 40.0k | void Done(const Status& status) override { |
76 | 30.0k | auto expected = status.ok() ? TestTaskState::EXECUTED : TestTaskState::IDLE; |
77 | 30.0k | const auto target_state = status.ok() ? TestTaskState::COMPLETED : TestTaskState::FAILED; |
78 | 40.0k | ASSERT_TRUE(state_.compare_exchange_strong(expected, target_state)); |
79 | 40.0k | if (latch_) { |
80 | 40.0k | latch_->CountDown(); |
81 | 40.0k | } |
82 | 40.0k | } |
83 | | |
84 | | CountDownLatch* latch_ = nullptr; |
85 | | std::atomic<TestTaskState> state_ = { TestTaskState::IDLE }; |
86 | | }; |
87 | | |
88 | 1 | TEST_F(ThreadPoolTest, TestSingleThread) { |
89 | 1 | constexpr size_t kTotalTasks = 100; |
90 | 1 | constexpr size_t kTotalWorkers = 1; |
91 | 1 | ThreadPool pool("test", kTotalTasks, kTotalWorkers); |
92 | | |
93 | 1 | CountDownLatch latch(kTotalTasks); |
94 | 1 | std::vector<TestTask> tasks(kTotalTasks); |
95 | 100 | for (auto& task : tasks) { |
96 | 100 | task.SetLatch(&latch); |
97 | 100 | ASSERT_TRUE(pool.Enqueue(&task)); |
98 | 100 | } |
99 | 1 | latch.Wait(); |
100 | 100 | for (auto& task : tasks) { |
101 | 100 | ASSERT_TRUE(task.IsCompleted()); |
102 | 100 | } |
103 | 1 | } |
104 | | |
105 | 1 | TEST_F(ThreadPoolTest, TestSingleProducer) { |
106 | 1 | constexpr size_t kTotalTasks = 10000; |
107 | 1 | constexpr size_t kTotalWorkers = 4; |
108 | 1 | ThreadPool pool("test", kTotalTasks, kTotalWorkers); |
109 | | |
110 | 1 | CountDownLatch latch(kTotalTasks); |
111 | 1 | std::vector<TestTask> tasks(kTotalTasks); |
112 | 10.0k | for (auto& task : tasks) { |
113 | 10.0k | task.SetLatch(&latch); |
114 | 10.0k | ASSERT_TRUE(pool.Enqueue(&task)); |
115 | 10.0k | } |
116 | 1 | latch.Wait(); |
117 | 10.0k | for (auto& task : tasks) { |
118 | 10.0k | ASSERT_TRUE(task.IsCompleted()); |
119 | 10.0k | } |
120 | 1 | } |
121 | | |
122 | 1 | TEST_F(ThreadPoolTest, TestMultiProducers) { |
123 | 1 | constexpr size_t kTotalTasks = 10000; |
124 | 1 | constexpr size_t kTotalWorkers = 4; |
125 | 1 | constexpr size_t kProducers = 4; |
126 | 1 | ThreadPool pool("test", kTotalTasks, kTotalWorkers); |
127 | | |
128 | 1 | CountDownLatch latch(kTotalTasks); |
129 | 1 | std::vector<TestTask> tasks(kTotalTasks); |
130 | 1 | std::vector<std::thread> threads; |
131 | 1 | size_t begin = 0; |
132 | 5 | for (size_t i = 0; i != kProducers; ++i) { |
133 | 4 | size_t end = kTotalTasks * (i + 1) / kProducers; |
134 | 4 | threads.emplace_back([&pool, &latch, &tasks, begin, end] { |
135 | 4 | CDSAttacher attacher; |
136 | 10.0k | for (size_t i = begin; i != end; ++i) { |
137 | 9.99k | tasks[i].SetLatch(&latch); |
138 | 9.99k | ASSERT_TRUE(pool.Enqueue(&tasks[i])); |
139 | 9.99k | } |
140 | 4 | }); |
141 | 4 | begin = end; |
142 | 4 | } |
143 | 1 | latch.Wait(); |
144 | 10.0k | for (auto& task : tasks) { |
145 | 10.0k | ASSERT_TRUE(task.IsCompleted()); |
146 | 10.0k | } |
147 | 4 | for (auto& thread : threads) { |
148 | 4 | thread.join(); |
149 | 4 | } |
150 | 1 | } |
151 | | |
152 | 1 | TEST_F(ThreadPoolTest, TestQueueOverflow) { |
153 | 1 | constexpr size_t kTotalTasks = 10000; |
154 | 1 | constexpr size_t kTotalWorkers = 4; |
155 | 1 | constexpr size_t kProducers = 4; |
156 | 1 | ThreadPool pool("test", kTotalTasks, kTotalWorkers); |
157 | | |
158 | 1 | CountDownLatch latch(kTotalTasks); |
159 | 1 | std::vector<TestTask> tasks(kTotalTasks); |
160 | 1 | std::vector<std::thread> threads; |
161 | 1 | size_t begin = 0; |
162 | 1 | std::atomic<size_t> enqueue_failed(0); |
163 | 5 | for (size_t i = 0; i != kProducers; ++i) { |
164 | 4 | size_t end = kTotalTasks * (i + 1) / kProducers; |
165 | 4 | threads.emplace_back([&pool, &latch, &tasks, &enqueue_failed, begin, end] { |
166 | 4 | CDSAttacher attacher; |
167 | 10.0k | for (size_t i = begin; i != end; ++i) { |
168 | 9.99k | tasks[i].SetLatch(&latch); |
169 | 9.99k | if(!pool.Enqueue(&tasks[i])) { |
170 | 0 | ++enqueue_failed; |
171 | 0 | } |
172 | 9.99k | } |
173 | 4 | }); |
174 | 4 | begin = end; |
175 | 4 | } |
176 | 1 | latch.Wait(); |
177 | 1 | size_t failed = 0; |
178 | 10.0k | for (auto& task : tasks) { |
179 | 10.0k | if(!task.IsCompleted()) { |
180 | 0 | ASSERT_TRUE(task.IsFailed()); |
181 | 0 | ++failed; |
182 | 0 | } |
183 | 10.0k | } |
184 | 1 | ASSERT_EQ(enqueue_failed, failed); |
185 | 4 | for (auto& thread : threads) { |
186 | 4 | thread.join(); |
187 | 4 | } |
188 | 1 | } |
189 | | |
190 | 1 | TEST_F(ThreadPoolTest, TestShutdown) { |
191 | 1 | constexpr size_t kTotalTasks = 10000; |
192 | 1 | constexpr size_t kTotalWorkers = 4; |
193 | 1 | constexpr size_t kProducers = 4; |
194 | 1 | ThreadPool pool("test", kTotalTasks, kTotalWorkers); |
195 | | |
196 | 1 | CountDownLatch latch(kTotalTasks); |
197 | 1 | std::vector<TestTask> tasks(kTotalTasks); |
198 | 1 | std::vector<std::thread> threads; |
199 | 1 | size_t begin = 0; |
200 | 5 | for (size_t i = 0; i != kProducers; ++i) { |
201 | 4 | size_t end = kTotalTasks * (i + 1) / kProducers; |
202 | 4 | threads.emplace_back([&pool, &latch, &tasks, begin, end] { |
203 | 4 | CDSAttacher attacher; |
204 | 9.95k | for (size_t i = begin; i != end; ++i) { |
205 | 9.94k | tasks[i].SetLatch(&latch); |
206 | 9.94k | pool.Enqueue(&tasks[i]); |
207 | 9.94k | } |
208 | 4 | }); |
209 | 4 | begin = end; |
210 | 4 | } |
211 | 1 | pool.Shutdown(); |
212 | 1 | latch.Wait(); |
213 | 10.0k | for (auto& task : tasks) { |
214 | 10.0k | ASSERT_TRUE(task.IsDone()); |
215 | 10.0k | } |
216 | 4 | for (auto& thread : threads) { |
217 | 4 | thread.join(); |
218 | 4 | } |
219 | 1 | } |
220 | | |
221 | 1 | TEST_F(ThreadPoolTest, TestOwns) { |
222 | 1 | class TestTask : public ThreadPoolTask { |
223 | 1 | public: |
224 | 1 | explicit TestTask(ThreadPool* thread_pool) : thread_pool_(thread_pool) {} |
225 | | |
226 | 1 | void Run() { |
227 | 1 | thread_ = Thread::current_thread(); |
228 | 1 | ASSERT_TRUE(thread_pool_->OwnsThisThread()); |
229 | 1 | } |
230 | | |
231 | 1 | void Done(const Status& status) { |
232 | 1 | latch_.CountDown(); |
233 | 1 | } |
234 | | |
235 | 1 | Thread* thread() { |
236 | 1 | return thread_; |
237 | 1 | } |
238 | | |
239 | 1 | void Wait() { |
240 | 1 | return latch_.Wait(); |
241 | 1 | } |
242 | | |
243 | 1 | virtual ~TestTask() {} |
244 | | |
245 | 1 | private: |
246 | 1 | ThreadPool* const thread_pool_; |
247 | 1 | Thread* thread_ = nullptr; |
248 | 1 | CountDownLatch latch_{1}; |
249 | 1 | }; |
250 | | |
251 | 1 | constexpr size_t kTotalTasks = 1; |
252 | 1 | constexpr size_t kTotalWorkers = 1; |
253 | | |
254 | 1 | ThreadPool pool("test", kTotalTasks, kTotalWorkers); |
255 | 1 | ASSERT_FALSE(pool.OwnsThisThread()); |
256 | 1 | TestTask task(&pool); |
257 | 1 | pool.Enqueue(&task); |
258 | 1 | task.Wait(); |
259 | 1 | ASSERT_TRUE(pool.Owns(task.thread())); |
260 | 1 | } |
261 | | |
262 | | namespace strand { |
263 | | |
264 | | constexpr size_t kPoolMaxTasks = 100; |
265 | | constexpr size_t kPoolTotalWorkers = 4; |
266 | | |
267 | 1 | TEST_F(ThreadPoolTest, Strand) { |
268 | 1 | ThreadPool pool("test", kPoolMaxTasks, kPoolTotalWorkers); |
269 | 1 | Strand strand(&pool); |
270 | | |
271 | 1 | CountDownLatch latch(kPoolMaxTasks); |
272 | 1 | std::atomic<int> counter(0); |
273 | 101 | for (auto i = 0; i != kPoolMaxTasks; ++i) { |
274 | 100 | strand.EnqueueFunctor([&counter, &latch] { |
275 | 100 | ASSERT_EQ(++counter, 1); |
276 | 100 | std::this_thread::sleep_for(1ms); |
277 | 100 | ASSERT_EQ(--counter, 0); |
278 | 100 | latch.CountDown(); |
279 | 100 | }); |
280 | 100 | } |
281 | | |
282 | 1 | latch.Wait(); |
283 | 1 | strand.Shutdown(); |
284 | 1 | } |
285 | | |
286 | 1 | TEST_F(ThreadPoolTest, StrandShutdown) { |
287 | 1 | ThreadPool pool("test", kPoolMaxTasks, kPoolTotalWorkers); |
288 | 1 | Strand strand(&pool); |
289 | | |
290 | 1 | CountDownLatch latch1(1); |
291 | 1 | strand.EnqueueFunctor([&latch1] { |
292 | 1 | latch1.CountDown(); |
293 | 1 | std::this_thread::sleep_for(500ms); |
294 | 1 | }); |
295 | 1 | class AbortedTask : public StrandTask { |
296 | 1 | public: |
297 | 0 | void Run() override { |
298 | 0 | ASSERT_TRUE(false); |
299 | 0 | } |
300 | | |
301 | 1 | void Done(const Status& status) override { |
302 | 1 | ASSERT_TRUE(status.IsAborted()); |
303 | 1 | } |
304 | | |
305 | 1 | virtual ~AbortedTask() = default; |
306 | 1 | }; |
307 | 1 | AbortedTask aborted_task; |
308 | 1 | strand.Enqueue(&aborted_task); |
309 | 1 | latch1.Wait(); |
310 | 1 | strand.Shutdown(); |
311 | 1 | } |
312 | | |
313 | 1 | TEST_F(ThreadPoolTest, NotUsedStrandShutdown) { |
314 | 1 | ThreadPool pool("test", kPoolMaxTasks, kPoolTotalWorkers); |
315 | | |
316 | 1 | Strand strand(&pool); |
317 | | |
318 | 1 | std::atomic<bool> shutdown_completed{false}; |
319 | 1 | std::thread shutdown_thread([strand = &strand, &shutdown_completed]{ |
320 | 1 | strand->Shutdown(); |
321 | 1 | shutdown_completed = true; |
322 | 1 | }); |
323 | | |
324 | 1 | ASSERT_OK(LoggedWaitFor( |
325 | 1 | [&shutdown_completed] { |
326 | 1 | return shutdown_completed.load(); |
327 | 1 | }, |
328 | 1 | 5s * kTimeMultiplier, "Waiting for strand shutdown")); |
329 | | |
330 | 1 | shutdown_thread.join(); |
331 | 1 | } |
332 | | |
333 | 1 | TEST_F(ThreadPoolTest, StrandShutdownAndDestroyRace) { |
334 | 1 | constexpr size_t kNumIters = 10; |
335 | | |
336 | 1 | ThreadPool pool("test", kPoolMaxTasks, kPoolTotalWorkers); |
337 | | |
338 | 10 | auto task = []{}; |
339 | | |
340 | 11 | for (size_t iter = 0; iter < kNumIters; ++iter) { |
341 | 10 | Strand strand(&pool); |
342 | | |
343 | 10 | ANNOTATE_UNPROTECTED_WRITE(FLAGS_TEST_strand_done_inject_delay_ms) = 0; |
344 | 10 | strand.EnqueueFunctor(task); |
345 | | // Give enough time for Strand::Done to be finished. |
346 | 10 | std::this_thread::sleep_for(10ms); |
347 | 10 | ANNOTATE_UNPROTECTED_WRITE(FLAGS_TEST_strand_done_inject_delay_ms) = 10; |
348 | 10 | strand.EnqueueFunctor(task); |
349 | | |
350 | 10 | strand.Shutdown(); |
351 | 10 | } |
352 | 1 | } |
353 | | |
354 | | } // namespace strand |
355 | | |
356 | | } // namespace rpc |
357 | | } // namespace yb |