diff --git a/lib/ConsumerImpl.h b/lib/ConsumerImpl.h index dd7163fb..61d96b1c 100644 --- a/lib/ConsumerImpl.h +++ b/lib/ConsumerImpl.h @@ -23,6 +23,7 @@ #include #include +#include #include #include diff --git a/lib/Future.h b/lib/Future.h index 5ee937ee..290ebc6f 100644 --- a/lib/Future.h +++ b/lib/Future.h @@ -20,14 +20,11 @@ #define LIB_FUTURE_H_ #include -#include +#include +#include #include -#include -#include #include #include -#include -#include namespace pulsar { @@ -38,71 +35,70 @@ class InternalState { using Pair = std::pair; using Lock = std::unique_lock; + enum Status : uint8_t + { + INITIAL, + COMPLETING, + COMPLETED + }; + // NOTE: Add the constructor explicitly just to be compatible with GCC 4.8 InternalState() {} void addListener(Listener listener) { Lock lock{mutex_}; - listeners_.emplace_back(listener); - lock.unlock(); - if (completed()) { - Type value; - Result result = get(value); - triggerListeners(result, value); + auto result = result_; + auto value = value_; + lock.unlock(); + listener(result, value); + } else { + tailListener_ = listeners_.emplace_after(tailListener_, std::move(listener)); } } bool complete(Result result, const Type &value) { - bool expected = false; - if (!completed_.compare_exchange_strong(expected, true)) { + Status expected = Status::INITIAL; + if (!status_.compare_exchange_strong(expected, Status::COMPLETING)) { return false; } - triggerListeners(result, value); - promise_.set_value(std::make_pair(result, value)); - return true; - } - - bool completed() const noexcept { return completed_; } - Result get(Type &result) { - const auto &pair = future_.get(); - result = pair.second; - return pair.first; - } + // Ensure if another thread calls `addListener` at the same time, that thread can get the value by + // `get` before the existing listeners are executed + Lock lock{mutex_}; + result_ = result; + value_ = value; + status_ = COMPLETED; + cond_.notify_all(); - // Only public for test - void triggerListeners(Result result, const Type &value) { - while (true) { - Lock lock{mutex_}; - if (listeners_.empty()) { - return; + if (!listeners_.empty()) { + auto listeners = std::move(listeners_); + lock.unlock(); + for (auto &&listener : listeners) { + listener(result, value); } + } - bool expected = false; - if (!listenerRunning_.compare_exchange_strong(expected, true)) { - // There is another thread that polled a listener that is running, skip polling and release - // the lock. Here we wait for some time to avoid busy waiting. - std::this_thread::sleep_for(std::chrono::milliseconds(1)); - continue; - } - auto listener = std::move(listeners_.front()); - listeners_.pop_front(); - lock.unlock(); + return true; + } - listener(result, value); - listenerRunning_ = false; - } + bool completed() const noexcept { return status_.load() == COMPLETED; } + + Result get(Type &value) const { + Lock lock{mutex_}; + cond_.wait(lock, [this] { return completed(); }); + value = value_; + return result_; } private: - std::atomic_bool completed_{false}; - std::promise promise_; - std::shared_future future_{promise_.get_future()}; - - std::list listeners_; mutable std::mutex mutex_; - std::atomic_bool listenerRunning_{false}; + mutable std::condition_variable cond_; + std::forward_list listeners_; + decltype(listeners_.before_begin()) tailListener_{listeners_.before_begin()}; + Result result_; + Type value_; + std::atomic status_{INITIAL}; }; template diff --git a/lib/ProducerImpl.h b/lib/ProducerImpl.h index 770ac45f..91b95443 100644 --- a/lib/ProducerImpl.h +++ b/lib/ProducerImpl.h @@ -20,6 +20,7 @@ #define LIB_PRODUCERIMPL_H_ #include +#include #include #include "Future.h" diff --git a/tests/PromiseTest.cc b/tests/PromiseTest.cc index 29ee2a3d..ad67e7df 100644 --- a/tests/PromiseTest.cc +++ b/tests/PromiseTest.cc @@ -19,10 +19,13 @@ #include #include +#include +#include #include #include #include +#include "WaitUtils.h" #include "lib/Future.h" #include "lib/LogUtils.h" @@ -88,26 +91,38 @@ TEST(PromiseTest, testListeners) { ASSERT_EQ(values, (std::vector(2, "hello"))); } -TEST(PromiseTest, testTriggerListeners) { - InternalState state; - state.addListener([](int, const int&) { - LOG_INFO("Start task 1..."); - std::this_thread::sleep_for(std::chrono::seconds(1)); - LOG_INFO("Finish task 1..."); +TEST(PromiseTest, testListenerDeadlock) { + Promise promise; + auto future = promise.getFuture(); + auto mutex = std::make_shared(); + auto done = std::make_shared(false); + + future.addListener([mutex, done](int, int) { + LOG_INFO("Listener-1 before acquiring the lock"); + std::lock_guard lock{*mutex}; + LOG_INFO("Listener-1 after acquiring the lock"); + done->store(true); }); - state.addListener([](int, const int&) { - LOG_INFO("Start task 2..."); + + std::thread t1{[mutex, &future] { + std::lock_guard lock{*mutex}; + // Make it a great chance that `t2` executes `promise.setValue` first + std::this_thread::sleep_for(std::chrono::seconds(2)); + + // Since the future is completed, `Future::get` will be called in `addListener` to get the result + LOG_INFO("Before adding Listener-2 (acquired the mutex)") + future.addListener([](int, int) { LOG_INFO("Listener-2 is triggered"); }); + LOG_INFO("After adding Listener-2 (releasing the mutex)"); + }}; + t1.detach(); + std::thread t2{[mutex, promise] { + // Make there a great chance that `t1` acquires `mutex` first std::this_thread::sleep_for(std::chrono::seconds(1)); - LOG_INFO("Finish task 2..."); - }); + LOG_INFO("Before setting value"); + promise.setValue(0); // the 1st listener is called, which is blocked at acquiring `mutex` + LOG_INFO("After setting value"); + }}; + t2.detach(); - auto start = std::chrono::high_resolution_clock::now(); - auto future1 = std::async(std::launch::async, [&state] { state.triggerListeners(0, 0); }); - auto future2 = std::async(std::launch::async, [&state] { state.triggerListeners(0, 0); }); - future1.wait(); - future2.wait(); - auto elapsed = std::chrono::duration_cast( - std::chrono::high_resolution_clock::now() - start) - .count(); - ASSERT_TRUE(elapsed > 2000) << "elapsed: " << elapsed << "ms"; + ASSERT_TRUE(waitUntil(std::chrono::seconds(5000), [done] { return done->load(); })); } diff --git a/tests/WaitUtils.h b/tests/WaitUtils.h index d7db82e5..4a03e534 100644 --- a/tests/WaitUtils.h +++ b/tests/WaitUtils.h @@ -25,13 +25,13 @@ namespace pulsar { template -inline void waitUntil(std::chrono::duration timeout, const std::function& condition, +inline bool waitUntil(std::chrono::duration timeout, const std::function& condition, long durationMs = 10) { auto timeoutMs = std::chrono::duration_cast(timeout).count(); while (timeoutMs > 0) { auto now = std::chrono::high_resolution_clock::now(); if (condition()) { - break; + return true; } std::this_thread::sleep_for(std::chrono::milliseconds(durationMs)); auto elapsed = std::chrono::duration_cast( @@ -39,6 +39,7 @@ inline void waitUntil(std::chrono::duration timeout, const std::fun .count(); timeoutMs -= elapsed; } + return false; } } // namespace pulsar