YugabyteDB (2.13.0.0-b42, bfc6a6643e7399ac8a0e81d06a3ee6d6571b33ab)

Coverage Report

Created: 2022-03-09 17:30

/Users/deen/code/yugabyte-db/src/yb/rpc/thread_pool-test.cc
Line
Count
Source (jump to first uncovered line)
1
//
2
// Copyright (c) YugaByte, Inc.
3
//
4
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
5
// in compliance with the License.  You may obtain a copy of the License at
6
//
7
// http://www.apache.org/licenses/LICENSE-2.0
8
//
9
// Unless required by applicable law or agreed to in writing, software distributed under the License
10
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
11
// or implied.  See the License for the specific language governing permissions and limitations
12
// under the License.
13
//
14
//
15
16
#include <atomic>
17
#include <thread>
18
19
#include <gtest/gtest.h>
20
21
#include "yb/rpc/strand.h"
22
#include "yb/rpc/thread_pool.h"
23
24
#include "yb/util/countdown_latch.h"
25
#include "yb/util/test_util.h"
26
#include "yb/util/thread.h"
27
#include "yb/util/tsan_util.h"
28
29
DECLARE_int32(TEST_strand_done_inject_delay_ms);
30
31
using namespace std::literals;
32
33
namespace yb {
34
namespace rpc {
35
36
// Fixture shared by the ThreadPool and Strand tests in this file; it adds nothing to YBTest.
class ThreadPoolTest : public YBTest {};
38
39
// States of a TestTask, advanced atomically by its Run() and Done() overrides:
//   IDLE -> EXECUTED -> COMPLETED   the task ran, then Done() was called with an OK status;
//   IDLE -> FAILED                  Done() was called with an error status without Run().
enum class TestTaskState { IDLE, EXECUTED, COMPLETED, FAILED };
45
46
class TestTask final : public ThreadPoolTask {
47
 public:
48
40.1k
  TestTask() {}
49
50
40.1k
  ~TestTask() {}
51
52
30.1k
  bool IsCompleted() const {
53
30.1k
    return state_ == TestTaskState::COMPLETED;
54
30.1k
  }
55
56
0
  bool IsFailed() const {
57
0
    return state_ == TestTaskState::FAILED;
58
0
  }
59
60
10.0k
  bool IsDone() const {
61
10.0k
    return state_ == TestTaskState::COMPLETED || state_ == TestTaskState::FAILED;
62
10.0k
  }
63
64
40.0k
  void SetLatch(CountDownLatch* latch) {
65
40.0k
    latch_ = latch;
66
40.0k
  }
67
68
 private:
69
29.4k
  void Run() override {
70
29.4k
    auto expected = TestTaskState::IDLE;
71
29.4k
    ASSERT_TRUE(state_.compare_exchange_strong(expected, TestTaskState::EXECUTED));
72
30.0k
    ASSERT_EQ(expected, TestTaskState::IDLE);
73
30.0k
  }
74
75
40.0k
  void Done(const Status& status) override {
76
30.0k
    auto expected = status.ok() ? TestTaskState::EXECUTED : TestTaskState::IDLE;
77
30.0k
    const auto target_state = status.ok() ? TestTaskState::COMPLETED : TestTaskState::FAILED;
78
40.0k
    ASSERT_TRUE(state_.compare_exchange_strong(expected, target_state));
79
40.0k
    if (latch_) {
80
40.0k
      latch_->CountDown();
81
40.0k
    }
82
40.0k
  }
83
84
  CountDownLatch* latch_ = nullptr;
85
  std::atomic<TestTaskState> state_ = { TestTaskState::IDLE };
86
};
87
88
1
TEST_F(ThreadPoolTest, TestSingleThread) {
  // With a single worker, every task enqueued from the test thread must run to completion.
  constexpr size_t kTotalTasks = 100;
  constexpr size_t kTotalWorkers = 1;
  ThreadPool pool("test", kTotalTasks, kTotalWorkers);

  CountDownLatch latch(kTotalTasks);
  std::vector<TestTask> tasks(kTotalTasks);
  for (size_t idx = 0; idx < kTotalTasks; ++idx) {
    auto& task = tasks[idx];
    task.SetLatch(&latch);
    ASSERT_TRUE(pool.Enqueue(&task));
  }

  // Each TestTask counts the latch down in Done(), so all tasks are finished once Wait() returns.
  latch.Wait();
  for (size_t idx = 0; idx < kTotalTasks; ++idx) {
    auto& task = tasks[idx];
    ASSERT_TRUE(task.IsCompleted());
  }
}
104
105
1
TEST_F(ThreadPoolTest, TestSingleProducer) {
  // The test thread is the only producer, feeding a pool with several workers.
  constexpr size_t kTotalTasks = 10000;
  constexpr size_t kTotalWorkers = 4;
  ThreadPool pool("test", kTotalTasks, kTotalWorkers);

  CountDownLatch latch(kTotalTasks);
  std::vector<TestTask> tasks(kTotalTasks);
  for (auto it = tasks.begin(); it != tasks.end(); ++it) {
    auto& task = *it;
    task.SetLatch(&latch);
    ASSERT_TRUE(pool.Enqueue(&task));
  }
  latch.Wait();

  for (const auto& task : tasks) {
    ASSERT_TRUE(task.IsCompleted());
  }
}
121
122
1
TEST_F(ThreadPoolTest, TestMultiProducers) {
  // kProducers threads concurrently enqueue disjoint slices of `tasks`; every task must complete.
  constexpr size_t kTotalTasks = 10000;
  constexpr size_t kTotalWorkers = 4;
  constexpr size_t kProducers = 4;
  ThreadPool pool("test", kTotalTasks, kTotalWorkers);

  CountDownLatch latch(kTotalTasks);
  std::vector<TestTask> tasks(kTotalTasks);
  std::vector<std::thread> threads;
  threads.reserve(kProducers);
  size_t begin = 0;
  for (size_t producer = 0; producer != kProducers; ++producer) {
    const size_t end = kTotalTasks * (producer + 1) / kProducers;
    threads.emplace_back([&pool, &latch, &tasks, begin, end] {
      CDSAttacher attacher;
      for (size_t idx = begin; idx != end; ++idx) {
        tasks[idx].SetLatch(&latch);
        // Non-fatal on purpose: returning early would skip enqueueing the rest of this slice,
        // whose tasks would then never count the latch down, hanging latch.Wait() below.
        EXPECT_TRUE(pool.Enqueue(&tasks[idx]));
      }
    });
    begin = end;
  }

  // Join the producers before any fatal assertion can return from the test body: destroying a
  // joinable std::thread calls std::terminate() and would abort the whole test binary.
  for (auto& thread : threads) {
    thread.join();
  }
  latch.Wait();
  for (auto& task : tasks) {
    ASSERT_TRUE(task.IsCompleted());
  }
}
151
152
1
TEST_F(ThreadPoolTest, TestQueueOverflow) {
  // Producers enqueue far more tasks than the pool's queue holds. Each rejected Enqueue() must
  // correspond to exactly one task that ended up FAILED; every other task must be COMPLETED.
  constexpr size_t kTotalTasks = 10000;
  constexpr size_t kTotalWorkers = 4;
  constexpr size_t kProducers = 4;
  // Must be well below kTotalTasks: with a limit of kTotalTasks every task fits, and the
  // rejection path checked below is never exercised.
  // NOTE(review): assumes the second ThreadPool constructor argument is the queue limit.
  constexpr size_t kQueueLimit = 100;
  ThreadPool pool("test", kQueueLimit, kTotalWorkers);

  CountDownLatch latch(kTotalTasks);
  std::vector<TestTask> tasks(kTotalTasks);
  std::vector<std::thread> threads;
  threads.reserve(kProducers);
  size_t begin = 0;
  std::atomic<size_t> enqueue_failed(0);
  for (size_t producer = 0; producer != kProducers; ++producer) {
    const size_t end = kTotalTasks * (producer + 1) / kProducers;
    threads.emplace_back([&pool, &latch, &tasks, &enqueue_failed, begin, end] {
      CDSAttacher attacher;
      for (size_t idx = begin; idx != end; ++idx) {
        tasks[idx].SetLatch(&latch);
        if (!pool.Enqueue(&tasks[idx])) {
          ++enqueue_failed;
        }
      }
    });
    begin = end;
  }

  // Join the producers before any fatal assertion can return from the test body: destroying a
  // joinable std::thread calls std::terminate() and would abort the whole test binary.
  for (auto& thread : threads) {
    thread.join();
  }
  latch.Wait();

  size_t failed = 0;
  for (auto& task : tasks) {
    if (!task.IsCompleted()) {
      ASSERT_TRUE(task.IsFailed());
      ++failed;
    }
  }
  ASSERT_EQ(enqueue_failed.load(), failed);
}
189
190
1
TEST_F(ThreadPoolTest, TestShutdown) {
  // Shut the pool down while producers are still enqueueing. Every task must still reach a
  // terminal state: either completed, or failed without running.
  constexpr size_t kTotalTasks = 10000;
  constexpr size_t kTotalWorkers = 4;
  constexpr size_t kProducers = 4;
  ThreadPool pool("test", kTotalTasks, kTotalWorkers);

  CountDownLatch latch(kTotalTasks);
  std::vector<TestTask> tasks(kTotalTasks);
  std::vector<std::thread> threads;
  threads.reserve(kProducers);
  size_t begin = 0;
  for (size_t producer = 0; producer != kProducers; ++producer) {
    const size_t end = kTotalTasks * (producer + 1) / kProducers;
    threads.emplace_back([&pool, &latch, &tasks, begin, end] {
      CDSAttacher attacher;
      for (size_t idx = begin; idx != end; ++idx) {
        tasks[idx].SetLatch(&latch);
        // Result deliberately ignored: Enqueue() may be rejected once Shutdown() has started;
        // the IsDone() check below accepts such tasks as FAILED.
        pool.Enqueue(&tasks[idx]);
      }
    });
    begin = end;
  }
  pool.Shutdown();
  latch.Wait();

  // Join the producers before any fatal assertion can return from the test body: destroying a
  // joinable std::thread calls std::terminate() and would abort the whole test binary.
  for (auto& thread : threads) {
    thread.join();
  }
  for (auto& task : tasks) {
    ASSERT_TRUE(task.IsDone());
  }
}
220
221
1
TEST_F(ThreadPoolTest, TestOwns) {
  // Task that records the thread it runs on and checks, from inside the pool, that the pool
  // reports owning the current thread.
  class OwnsTestTask : public ThreadPoolTask {
   public:
    explicit OwnsTestTask(ThreadPool* thread_pool) : thread_pool_(thread_pool) {}

    virtual ~OwnsTestTask() = default;

    void Run() override {
      thread_ = Thread::current_thread();
      ASSERT_TRUE(thread_pool_->OwnsThisThread());
    }

    void Done(const Status& /* status */) override {
      latch_.CountDown();
    }

    // Thread that executed Run(), or null if Run() was not invoked. Read it only after Wait().
    Thread* thread() const {
      return thread_;
    }

    // Blocks until Done() has been invoked.
    void Wait() {
      latch_.Wait();
    }

   private:
    ThreadPool* const thread_pool_;
    Thread* thread_ = nullptr;
    CountDownLatch latch_{1};
  };

  constexpr size_t kTotalTasks = 1;
  constexpr size_t kTotalWorkers = 1;

  ThreadPool pool("test", kTotalTasks, kTotalWorkers);
  // The test thread itself is not a pool worker.
  ASSERT_FALSE(pool.OwnsThisThread());
  OwnsTestTask task(&pool);
  // Check the result so a rejected enqueue fails here with a clear message, rather than
  // surfacing later as a confusing null-thread Owns() failure (or a hang in Wait()).
  ASSERT_TRUE(pool.Enqueue(&task));
  task.Wait();
  ASSERT_TRUE(pool.Owns(task.thread()));
}
261
262
namespace strand {
263
264
// ThreadPool sizing shared by the Strand tests in this namespace.
constexpr size_t kPoolMaxTasks = 100;     // Passed as the pool's task limit.
constexpr size_t kPoolTotalWorkers = 4;   // Number of pool workers.
266
267
1
TEST_F(ThreadPoolTest, Strand) {
  // Functors enqueued on one strand must never run concurrently, even though the pool has
  // several workers: `counter` would exceed 1 if two of them overlapped.
  ThreadPool pool("test", kPoolMaxTasks, kPoolTotalWorkers);
  Strand strand(&pool);

  CountDownLatch latch(kPoolMaxTasks);
  std::atomic<int> counter(0);
  // size_t index: comparing an `auto i = 0` (int) with the size_t bound mixes signedness.
  for (size_t i = 0; i != kPoolMaxTasks; ++i) {
    strand.EnqueueFunctor([&counter, &latch] {
      // Non-fatal on purpose: a fatal assertion would return before CountDown(), and the test
      // would hang in latch.Wait() instead of reporting the overlap.
      EXPECT_EQ(++counter, 1);
      std::this_thread::sleep_for(1ms);
      EXPECT_EQ(--counter, 0);
      latch.CountDown();
    });
  }

  latch.Wait();
  strand.Shutdown();
}
285
286
1
TEST_F(ThreadPoolTest, StrandShutdown) {
  ThreadPool pool("test", kPoolMaxTasks, kPoolTotalWorkers);
  Strand strand(&pool);

  // First functor: signals that it has started, then keeps the strand busy for 500ms
  // (presumably so the task enqueued after it is still pending when Shutdown() is called).
  CountDownLatch latch1(1);
  strand.EnqueueFunctor([&latch1] {
    latch1.CountDown();
    std::this_thread::sleep_for(500ms);
  });

  // Task that must never run; its Done() requires an Aborted status.
  class AbortedTask : public StrandTask {
   public:
    virtual ~AbortedTask() = default;

    void Run() override {
      ASSERT_TRUE(false);
    }

    void Done(const Status& status) override {
      ASSERT_TRUE(status.IsAborted());
    }
  };

  AbortedTask aborted_task;
  strand.Enqueue(&aborted_task);

  // Shut down only once the first functor is known to be running.
  latch1.Wait();
  strand.Shutdown();
}
312
313
1
TEST_F(ThreadPoolTest, NotUsedStrandShutdown) {
  // Shutting down a strand that never received a task must return promptly.
  ThreadPool pool("test", kPoolMaxTasks, kPoolTotalWorkers);

  Strand strand(&pool);

  std::atomic<bool> shutdown_completed{false};
  // Shut down from a helper thread so the wait below can bound how long Shutdown() takes.
  std::thread shutdown_thread([&strand, &shutdown_completed] {
    strand.Shutdown();
    shutdown_completed = true;
  });

  const auto is_shutdown_completed = [&shutdown_completed] {
    return shutdown_completed.load();
  };
  ASSERT_OK(LoggedWaitFor(
      is_shutdown_completed, 5s * kTimeMultiplier, "Waiting for strand shutdown"));

  shutdown_thread.join();
}
332
333
1
TEST_F(ThreadPoolTest, StrandShutdownAndDestroyRace) {
  // Repeatedly shut down and destroy a strand while the TEST_strand_done_inject_delay_ms flag
  // delays Strand::Done, presumably to widen the window for a Done/destroy race.
  constexpr size_t kNumIters = 10;

  ThreadPool pool("test", kPoolMaxTasks, kPoolTotalWorkers);

  auto noop = []{};

  for (size_t iter = 0; iter != kNumIters; ++iter) {
    Strand strand(&pool);

    // First functor: no injected delay.
    ANNOTATE_UNPROTECTED_WRITE(FLAGS_TEST_strand_done_inject_delay_ms) = 0;
    strand.EnqueueFunctor(noop);
    // Give enough time for Strand::Done to be finished.
    std::this_thread::sleep_for(10ms);

    // Second functor: Strand::Done is delayed by 10ms, then the strand is shut down and, at the
    // end of this iteration, destroyed.
    ANNOTATE_UNPROTECTED_WRITE(FLAGS_TEST_strand_done_inject_delay_ms) = 10;
    strand.EnqueueFunctor(noop);

    strand.Shutdown();
  }
}
353
354
} // namespace strand
355
356
} // namespace rpc
357
} // namespace yb