/Users/deen/code/yugabyte-db/src/yb/util/mt-threadlocal-test.cc
Line | Count | Source (jump to first uncovered line) |
1 | | // Licensed to the Apache Software Foundation (ASF) under one |
2 | | // or more contributor license agreements. See the NOTICE file |
3 | | // distributed with this work for additional information |
4 | | // regarding copyright ownership. The ASF licenses this file |
5 | | // to you under the Apache License, Version 2.0 (the |
6 | | // "License"); you may not use this file except in compliance |
7 | | // with the License. You may obtain a copy of the License at |
8 | | // |
9 | | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | | // |
11 | | // Unless required by applicable law or agreed to in writing, |
12 | | // software distributed under the License is distributed on an |
13 | | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | | // KIND, either express or implied. See the License for the |
15 | | // specific language governing permissions and limitations |
16 | | // under the License. |
17 | | // |
18 | | // The following only applies to changes made to this file as part of YugaByte development. |
19 | | // |
20 | | // Portions Copyright (c) YugaByte, Inc. |
21 | | // |
22 | | // Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except |
23 | | // in compliance with the License. You may obtain a copy of the License at |
24 | | // |
25 | | // http://www.apache.org/licenses/LICENSE-2.0 |
26 | | // |
27 | | // Unless required by applicable law or agreed to in writing, software distributed under the License |
28 | | // is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express |
29 | | // or implied. See the License for the specific language governing permissions and limitations |
30 | | // under the License. |
31 | | // |
32 | | #include <mutex> |
33 | | #include <unordered_set> |
34 | | |
35 | | #include <glog/logging.h> |
36 | | |
37 | | #include "yb/gutil/macros.h" |
38 | | #include "yb/gutil/map-util.h" |
39 | | #include "yb/gutil/ref_counted.h" |
40 | | #include "yb/gutil/stl_util.h" |
41 | | #include "yb/util/countdown_latch.h" |
42 | | #include "yb/util/env.h" |
43 | | #include "yb/util/locks.h" |
44 | | #include "yb/util/status_log.h" |
45 | | #include "yb/util/test_util.h" |
46 | | #include "yb/util/thread.h" |
47 | | #include "yb/util/threadlocal.h" |
48 | | |
49 | | using std::unordered_set; |
50 | | using std::vector; |
51 | | using strings::Substitute; |
52 | | |
53 | | namespace yb { |
54 | | namespace threadlocal { |
55 | | |
56 | | class ThreadLocalTest : public YBTest {}; |
57 | | |
58 | | const int kTargetCounterVal = 1000000; |
59 | | |
60 | | class Counter; |
61 | | typedef unordered_set<Counter*> CounterPtrSet; |
62 | | typedef Mutex RegistryLockType; |
63 | | typedef simple_spinlock CounterLockType; |
64 | | |
65 | | // Registry to provide reader access to the thread-local Counters. |
66 | | // The methods are only thread-safe if the calling thread holds the lock. |
67 | | class CounterRegistry { |
68 | | public: |
69 | 1 | CounterRegistry() { |
70 | 1 | } |
71 | | |
72 | 60 | RegistryLockType* get_lock() const { |
73 | 60 | return &lock_; |
74 | 60 | } |
75 | | |
76 | 24 | bool RegisterUnlocked(Counter* counter) { |
77 | 24 | LOG(INFO) << "Called RegisterUnlocked()"; |
78 | 24 | return InsertIfNotPresent(&counters_, counter); |
79 | 24 | } |
80 | | |
81 | 24 | bool UnregisterUnlocked(Counter* counter) { |
82 | 24 | LOG(INFO) << "Called UnregisterUnlocked()"; |
83 | 24 | return counters_.erase(counter) > 0; |
84 | 24 | } |
85 | | |
86 | 12 | CounterPtrSet* GetCountersUnlocked() { |
87 | 12 | return &counters_; |
88 | 12 | } |
89 | | |
90 | | private: |
91 | | mutable RegistryLockType lock_; |
92 | | CounterPtrSet counters_; |
93 | | DISALLOW_COPY_AND_ASSIGN(CounterRegistry); |
94 | | }; |
95 | | |
96 | | // A simple Counter class that registers itself with a CounterRegistry. |
97 | | class Counter { |
98 | | public: |
99 | | Counter(CounterRegistry* registry, int val) |
100 | | : tid_(Env::Default()->gettid()), |
101 | | registry_(CHECK_NOTNULL(registry)), |
102 | 24 | val_(val) { |
103 | 24 | LOG(INFO) << "Counter::~Counter(): tid = " << tid_ << ", addr = " << this << ", val = " << val_; |
104 | 24 | std::lock_guard<RegistryLockType> reg_lock(*registry_->get_lock()); |
105 | 24 | CHECK(registry_->RegisterUnlocked(this)); |
106 | 24 | } |
107 | | |
108 | 24 | ~Counter() { |
109 | 24 | LOG(INFO) << "Counter::~Counter(): tid = " << tid_ << ", addr = " << this << ", val = " << val_; |
110 | 24 | std::lock_guard<RegistryLockType> reg_lock(*registry_->get_lock()); |
111 | 24 | std::lock_guard<CounterLockType> self_lock(lock_); |
112 | 24 | LOG(INFO) << tid_ << ": deleting self from registry..."; |
113 | 24 | CHECK(registry_->UnregisterUnlocked(this)); |
114 | 24 | } |
115 | | |
116 | 96 | uint64_t tid() { |
117 | 96 | return tid_; |
118 | 96 | } |
119 | | |
120 | 20.6M | CounterLockType* get_lock() const { |
121 | 20.6M | return &lock_; |
122 | 20.6M | } |
123 | | |
124 | 22.6M | void IncrementUnlocked() { |
125 | 22.6M | val_++; |
126 | 22.6M | } |
127 | | |
128 | 96 | int GetValueUnlocked() { |
129 | 96 | return val_; |
130 | 96 | } |
131 | | |
132 | | private: |
133 | | // We expect that most of the time this lock will be uncontended. |
134 | | mutable CounterLockType lock_; |
135 | | |
136 | | // TID of thread that constructed this object. |
137 | | const uint64_t tid_; |
138 | | |
139 | | // Register / unregister ourselves with this on construction / destruction. |
140 | | CounterRegistry* const registry_; |
141 | | |
142 | | // Current value of the counter. |
143 | | int val_; |
144 | | |
145 | | DISALLOW_COPY_AND_ASSIGN(Counter); |
146 | | }; |
147 | | |
148 | | // Create a new THREAD_LOCAL Counter and loop an increment operation on it. |
149 | | static void RegisterCounterAndLoopIncr(CounterRegistry* registry, |
150 | | CountDownLatch* counters_ready, |
151 | | CountDownLatch* reader_ready, |
152 | | CountDownLatch* counters_done, |
153 | 24 | CountDownLatch* reader_done) { |
154 | 24 | BLOCK_STATIC_THREAD_LOCAL(Counter, counter, registry, 0); |
155 | | // Inform the reader that we are alive. |
156 | 24 | counters_ready->CountDown(); |
157 | | // Let the reader initialize before we start counting. |
158 | 24 | reader_ready->Wait(); |
159 | | // Now rock & roll on the counting loop. |
160 | 21.8M | for (int i = 0; i < kTargetCounterVal; i++) { |
161 | 21.8M | std::lock_guard<CounterLockType> l(*counter->get_lock()); |
162 | 21.8M | counter->IncrementUnlocked(); |
163 | 21.8M | } |
164 | | // Let the reader know we're ready for him to verify our counts. |
165 | 24 | counters_done->CountDown(); |
166 | | // Wait until the reader is done before we exit the thread, which will call |
167 | | // delete on the Counter. |
168 | 24 | reader_done->Wait(); |
169 | 24 | } |
170 | | |
171 | | // Iterate over the registered counters and their values. |
172 | 12 | static uint64_t Iterate(CounterRegistry* registry, int expected_counters) { |
173 | 12 | uint64_t sum = 0; |
174 | 12 | int seen_counters = 0; |
175 | 12 | std::lock_guard<RegistryLockType> l(*registry->get_lock()); |
176 | 96 | for (Counter* counter : *registry->GetCountersUnlocked()) { |
177 | 96 | uint64_t value; |
178 | 96 | { |
179 | 96 | std::lock_guard<CounterLockType> l(*counter->get_lock()); |
180 | 96 | value = counter->GetValueUnlocked(); |
181 | 96 | } |
182 | 96 | LOG(INFO) << "tid " << counter->tid() << " (counter " << counter << "): " << value; |
183 | 96 | sum += value; |
184 | 96 | seen_counters++; |
185 | 96 | } |
186 | 12 | CHECK_EQ(expected_counters, seen_counters); |
187 | 12 | return sum; |
188 | 12 | } |
189 | | |
190 | 3 | static void TestThreadLocalCounters(CounterRegistry* registry, const int num_threads) { |
191 | 3 | LOG(INFO) << "Starting threads..."; |
192 | 3 | vector<scoped_refptr<yb::Thread> > threads; |
193 | | |
194 | 3 | CountDownLatch counters_ready(num_threads); |
195 | 3 | CountDownLatch reader_ready(1); |
196 | 3 | CountDownLatch counters_done(num_threads); |
197 | 3 | CountDownLatch reader_done(1); |
198 | 27 | for (int i = 0; i < num_threads; i++) { |
199 | 24 | scoped_refptr<yb::Thread> new_thread; |
200 | 24 | CHECK_OK(yb::Thread::Create("test", strings::Substitute("t$0", i), |
201 | 24 | &RegisterCounterAndLoopIncr, registry, &counters_ready, &reader_ready, |
202 | 24 | &counters_done, &reader_done, &new_thread)); |
203 | 24 | threads.push_back(new_thread); |
204 | 24 | } |
205 | | |
206 | | // Wait for all threads to start and register their Counters. |
207 | 3 | counters_ready.Wait(); |
208 | 3 | CHECK_EQ(0, Iterate(registry, num_threads)); |
209 | 3 | LOG(INFO) << "--"; |
210 | | |
211 | | // Let the counters start spinning. |
212 | 3 | reader_ready.CountDown(); |
213 | | |
214 | | // Try to catch them in the act, just for kicks. |
215 | 9 | for (int i = 0; i < 2; i++) { |
216 | 6 | Iterate(registry, num_threads); |
217 | 6 | LOG(INFO) << "--"; |
218 | 6 | SleepFor(MonoDelta::FromMicroseconds(1)); |
219 | 6 | } |
220 | | |
221 | | // Wait until they're done and assure they sum up properly. |
222 | 3 | counters_done.Wait(); |
223 | 3 | LOG(INFO) << "Checking Counter sums..."; |
224 | 3 | CHECK_EQ(kTargetCounterVal * num_threads, Iterate(registry, num_threads)); |
225 | 3 | LOG(INFO) << "Counter sums add up!"; |
226 | 3 | reader_done.CountDown(); |
227 | | |
228 | 3 | LOG(INFO) << "Joining & deleting threads..."; |
229 | 24 | for (scoped_refptr<yb::Thread> thread : threads) { |
230 | 24 | CHECK_OK(ThreadJoiner(thread.get()).Join()); |
231 | 24 | } |
232 | 3 | LOG(INFO) << "Done."; |
233 | 3 | } |
234 | | |
235 | 1 | TEST_F(ThreadLocalTest, TestConcurrentCounters) { |
236 | | // Run this multiple times to ensure we don't leave remnants behind in the |
237 | | // CounterRegistry. |
238 | 1 | CounterRegistry registry; |
239 | 4 | for (int i = 0; i < 3; i++) { |
240 | 3 | TestThreadLocalCounters(®istry, 8); |
241 | 3 | } |
242 | 1 | } |
243 | | |
244 | | // Test class that stores a string in a static thread local member. |
245 | | // This class cannot be instantiated. The methods are all static. |
246 | | class ThreadLocalString { |
247 | | public: |
248 | | static void set(std::string value); |
249 | | static const std::string& get(); |
250 | | private: |
251 | 0 | ThreadLocalString() { |
252 | 0 | } |
253 | | DECLARE_STATIC_THREAD_LOCAL(std::string, value_); |
254 | | DISALLOW_COPY_AND_ASSIGN(ThreadLocalString); |
255 | | }; |
256 | | |
257 | | DEFINE_STATIC_THREAD_LOCAL(std::string, ThreadLocalString, value_); |
258 | | |
259 | 8 | void ThreadLocalString::set(std::string value) { |
260 | 8 | INIT_STATIC_THREAD_LOCAL(std::string, value_); |
261 | 8 | *value_ = value; |
262 | 8 | } |
263 | | |
264 | 16 | const std::string& ThreadLocalString::get() { |
265 | 16 | INIT_STATIC_THREAD_LOCAL(std::string, value_); |
266 | 16 | return *value_; |
267 | 16 | } |
268 | | |
269 | | static void RunAndAssign(CountDownLatch* writers_ready, |
270 | | CountDownLatch *readers_ready, |
271 | | CountDownLatch *all_done, |
272 | | CountDownLatch *threads_exiting, |
273 | | const std::string& in, |
274 | 8 | std::string* out) { |
275 | 8 | writers_ready->Wait(); |
276 | | // Ensure it starts off as an empty string. |
277 | 8 | CHECK_EQ("", ThreadLocalString::get()); |
278 | 8 | ThreadLocalString::set(in); |
279 | | |
280 | 8 | readers_ready->Wait(); |
281 | 8 | out->assign(ThreadLocalString::get()); |
282 | 8 | all_done->Wait(); |
283 | 8 | threads_exiting->CountDown(); |
284 | 8 | } |
285 | | |
286 | 1 | TEST_F(ThreadLocalTest, TestTLSMember) { |
287 | 1 | const int num_threads = 8; |
288 | | |
289 | 1 | vector<CountDownLatch*> writers_ready; |
290 | 1 | vector<CountDownLatch*> readers_ready; |
291 | 1 | vector<std::string*> out_strings; |
292 | 1 | vector<scoped_refptr<yb::Thread> > threads; |
293 | | |
294 | 1 | ElementDeleter writers_deleter(&writers_ready); |
295 | 1 | ElementDeleter readers_deleter(&readers_ready); |
296 | 1 | ElementDeleter out_strings_deleter(&out_strings); |
297 | | |
298 | 1 | CountDownLatch all_done(1); |
299 | 1 | CountDownLatch threads_exiting(num_threads); |
300 | | |
301 | 1 | LOG(INFO) << "Starting threads..."; |
302 | 9 | for (int i = 0; i < num_threads; i++) { |
303 | 8 | writers_ready.push_back(new CountDownLatch(1)); |
304 | 8 | readers_ready.push_back(new CountDownLatch(1)); |
305 | 8 | out_strings.push_back(new std::string()); |
306 | 8 | scoped_refptr<yb::Thread> new_thread; |
307 | 8 | CHECK_OK(yb::Thread::Create("test", strings::Substitute("t$0", i), |
308 | 8 | &RunAndAssign, writers_ready[i], readers_ready[i], |
309 | 8 | &all_done, &threads_exiting, Substitute("$0", i), out_strings[i], &new_thread)); |
310 | 8 | threads.push_back(new_thread); |
311 | 8 | } |
312 | | |
313 | | // Unlatch the threads in order. |
314 | 1 | LOG(INFO) << "Writing to thread locals..."; |
315 | 9 | for (int i = 0; i < num_threads; i++) { |
316 | 8 | writers_ready[i]->CountDown(); |
317 | 8 | } |
318 | 1 | LOG(INFO) << "Reading from thread locals..."; |
319 | 9 | for (int i = 0; i < num_threads; i++) { |
320 | 8 | readers_ready[i]->CountDown(); |
321 | 8 | } |
322 | 1 | all_done.CountDown(); |
323 | | // threads_exiting acts as a memory barrier. |
324 | 1 | threads_exiting.Wait(); |
325 | 9 | for (int i = 0; i < num_threads; i++) { |
326 | 8 | ASSERT_EQ(Substitute("$0", i), *out_strings[i]); |
327 | 8 | LOG(INFO) << "Read " << *out_strings[i]; |
328 | 8 | } |
329 | | |
330 | 1 | LOG(INFO) << "Joining & deleting threads..."; |
331 | 8 | for (scoped_refptr<yb::Thread> thread : threads) { |
332 | 8 | CHECK_OK(ThreadJoiner(thread.get()).Join()); |
333 | 8 | } |
334 | 1 | } |
335 | | |
336 | | } // namespace threadlocal |
337 | | } // namespace yb |