diff --git a/include/crucible/namedptr.h b/include/crucible/namedptr.h new file mode 100644 index 0000000..adf21a5 --- /dev/null +++ b/include/crucible/namedptr.h @@ -0,0 +1,196 @@ +#ifndef CRUCIBLE_NAMEDPTR_H +#define CRUCIBLE_NAMEDPTR_H + +#include "crucible/lockset.h" + +#include +#include +#include +#include +#include + +namespace crucible { + using namespace std; + + /// Storage for objects with unique names + + template + class NamedPtr { + public: + using Key = tuple; + using Ptr = shared_ptr; + using Func = function; + private: + struct Value; + using WeakPtr = weak_ptr; + using MapType = map; + struct MapRep { + MapType m_map; + mutex m_mutex; + }; + using MapPtr = shared_ptr; + struct Value { + Ptr m_ret_ptr; + MapPtr m_map_rep; + Key m_ret_key; + ~Value(); + Value(Ptr&& ret_ptr, const Key &key, const MapPtr &map_rep); + }; + + Func m_fn; + MapPtr m_map_rep = make_shared(); + LockSet m_lockset; + + Ptr lookup_item(const Key &k); + Ptr insert_item(Func fn, Arguments... args); + + public: + NamedPtr(Func f = Func()); + + void func(Func f); + + Ptr operator()(Arguments... args); + Ptr insert(const Ptr &r, Arguments... args); + }; + + template + NamedPtr::NamedPtr(Func f) : + m_fn(f) + { + } + + template + NamedPtr::Value::Value(Ptr&& ret_ptr, const Key &key, const MapPtr &map_rep) : + m_ret_ptr(ret_ptr), + m_map_rep(map_rep), + m_ret_key(key) + { + } + + template + NamedPtr::Value::~Value() + { + unique_lock lock(m_map_rep->m_mutex); + // We are called from the shared_ptr destructor, so we + // know that the weak_ptr in the map has already expired; + // however, if another thread already noticed that the + // map entry expired while we were waiting for the lock, + // the other thread will have already replaced the map + // entry with a pointer to some other object, and that + // object now owns the map entry. So we do a key lookup + // here instead of storing a map iterator, and only erase + // "our" map entry if it exists and is expired. The other + // thread would have done the same for us if the race had + // a different winner. + auto found = m_map_rep->m_map.find(m_ret_key); + if (found != m_map_rep->m_map.end() && found->second.expired()) { + m_map_rep->m_map.erase(found); + } + } + + template + typename NamedPtr::Ptr + NamedPtr::lookup_item(const Key &k) + { + // Must be called with lock held + auto found = m_map_rep->m_map.find(k); + if (found != m_map_rep->m_map.end()) { + // Get the strong pointer back + auto rv = found->second.lock(); + if (rv) { + // Have strong pointer. Return value that shares map entry. + return shared_ptr(rv, rv->m_ret_ptr.get()); + } + // Have expired weak pointer. Another thread is trying to delete it, + // but we got the lock first. Leave the map entry alone here. + // The other thread will erase it, or we will put a different entry + // in the same map entry. + } + return Ptr(); + } + + template + typename NamedPtr::Ptr + NamedPtr::insert_item(Func fn, Arguments... args) + { + Key k(args...); + + // Is it already in the map? + unique_lock lock(m_map_rep->m_mutex); + auto rv = lookup_item(k); + if (rv) { + return rv; + } + + // Release map lock and acquire key lock + lock.unlock(); + auto key_lock = m_lockset.make_lock(k); + + // Did item appear in map while we were waiting for key? + lock.lock(); + rv = lookup_item(k); + if (rv) { + return rv; + } + + // We now hold key and index locks, but item not in map (or expired). + // Release map lock + lock.unlock(); + + // Call the function and create a new Value + auto new_value_ptr = make_shared(fn(args...), k, m_map_rep); + // Function must return a non-null pointer + THROW_CHECK0(runtime_error, new_value_ptr->m_ret_ptr); + + // Reacquire index lock for map insertion + lock.lock(); + + // Insert return value in map or overwrite existing + // empty or expired weak_ptr value. + WeakPtr &new_item_ref = m_map_rep->m_map[k]; + + // We searched the map while holding both locks and + // found no entry or an expired weak_ptr; therefore, no + // other thread could have inserted a new non-expired + // weak_ptr, and the weak_ptr in the map is expired + // or was default-constructed as a nullptr. So if the + // new_item_ref is not expired, we have a bug we need + // to find and fix. + assert(new_item_ref.expired()); + + // Update the empty map slot + new_item_ref = new_value_ptr; + + // Drop lock so we don't deadlock in constructor exceptions + lock.unlock(); + + // Return shared_ptr to Return using strong pointer's reference counter + return shared_ptr(new_value_ptr, new_value_ptr->m_ret_ptr.get()); + } + + template + void + NamedPtr::func(Func func) + { + unique_lock lock(m_map_rep->m_mutex); + m_fn = func; + } + + template + typename NamedPtr::Ptr + NamedPtr::operator()(Arguments... args) + { + return insert_item(m_fn, args...); + } + + template + typename NamedPtr::Ptr + NamedPtr::insert(const Ptr &r, Arguments... args) + { + THROW_CHECK0(invalid_argument, r); + return insert_item([&](Arguments...) -> Ptr { return r; }, args...); + } + +} + +#endif // NAMEDPTR_H diff --git a/test/Makefile b/test/Makefile index 5814cd4..a0ba01c 100644 --- a/test/Makefile +++ b/test/Makefile @@ -3,6 +3,7 @@ PROGRAMS = \ crc64 \ fd \ limits \ + namedptr \ path \ process \ progress \ diff --git a/test/namedptr.cc b/test/namedptr.cc new file mode 100644 index 0000000..985aa8f --- /dev/null +++ b/test/namedptr.cc @@ -0,0 +1,84 @@ +#include "tests.h" +#include "crucible/error.h" +#include "crucible/namedptr.h" + +#include +#include + +using namespace crucible; + +struct named_thing { + static set s_set; + int m_a, m_b; + named_thing() = delete; + named_thing(const named_thing &that) : + m_a(that.m_a), + m_b(that.m_b) + { + cerr << "named_thing(" << m_a << ", " << m_b << ") " << this << " copied from " << &that << "." << endl; + auto rv = s_set.insert(this); + THROW_CHECK1(runtime_error, *rv.first, rv.second); + } + named_thing(int a, int b) : + m_a(a), m_b(b) + { + cerr << "named_thing(" << a << ", " << b << ") " << this << " constructed." << endl; + auto rv = s_set.insert(this); + THROW_CHECK1(runtime_error, *rv.first, rv.second); + } + ~named_thing() { + auto rv = s_set.erase(this); + assert(rv == 1); + cerr << "named_thing(" << m_a << ", " << m_b << ") " << this << " destroyed." << endl; + m_a = ~m_a; + m_b = ~m_b; + } + void check(int a, int b) { + THROW_CHECK2(runtime_error, m_a, a, m_a == a); + THROW_CHECK2(runtime_error, m_b, b, m_b == b); + } + static void check_empty() { + THROW_CHECK1(runtime_error, s_set.size(), s_set.empty()); + } +}; + +set named_thing::s_set; + +static +void +test_namedptr() +{ + NamedPtr names; + names.func([](int a, int b) -> shared_ptr { return make_shared(a, b); }); + + auto a_3_5 = names(3, 5); + auto b_3_5 = names(3, 5); + { + auto c_2_7 = names(2, 7); + b_3_5 = a_3_5; + a_3_5->check(3, 5); + b_3_5->check(3, 5); + c_2_7->check(2, 7); + } + auto d_2_7 = names(2, 7); + a_3_5->check(3, 5); + a_3_5.reset(); + b_3_5->check(3, 5); + d_2_7->check(2, 7); +} + +static +void +test_leak() +{ + named_thing::check_empty(); +} + +int +main(int, char**) +{ + RUN_A_TEST(test_namedptr()); + RUN_A_TEST(test_leak()); + + exit(EXIT_SUCCESS); +}