LCOV - code coverage report
Current view: top level - core/include/scipp/core - dict.h (source / functions) Hit Total Coverage
Test: coverage.info Lines: 121 125 96.8 %
Date: 2024-04-28 01:25:40 Functions: 194 202 96.0 %

          Line data    Source code
       1             : // SPDX-License-Identifier: BSD-3-Clause
       2             : // Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
       3             : /// @file
       4             : /// @author Jan-Lukas Wynen
       5             : ///
       6             : /// Dict is a container similar to Python's dict. It differs from
       7             : /// std::map and std::unordered_map in that it stores elements in the
       8             : /// order of insertion. In addition, its iterators throw an exception
       9             : /// if the dict has changed size during iteration. This matches Python's
      10             : /// behavior and avoids segfaults when misusing the dict.
      11             : #pragma once
      12             : 
      13             : #include <functional>
      14             : #include <string>
      15             : #include <string_view>
      16             : #include <vector>
      17             : 
      18             : #include "scipp/common/index.h"
      19             : 
      20             : #include "scipp/core/except.h"
      21             : #include "scipp/core/string.h"
      22             : 
      23             : namespace scipp::core::dict_detail {
      24             : template <class It1, class It2 = void> struct ValueType {
                         // Two iterators: the value type is a std::pair of both iterators'
                         // value types.
      25             :   using type = std::pair<typename It1::value_type, typename It2::value_type>;
      26             : };
      27             : 
      28             : template <class It1> struct ValueType<It1, void> {
                         // Single iterator (It2 left at its default, void): reuse that
                         // iterator's value type unchanged.
      29             :   using type = typename It1::value_type;
      30             : };
      31             : 
      32             : template <class It1, class It2 = void> struct ReferenceType {
                         // Two iterators: an rvalue reference to a std::pair made of both
                         // iterators' reference types.
      33             :   using type = std::add_rvalue_reference_t<
      34             :       std::pair<typename It1::reference, typename It2::reference>>;
      35             : };
      36             : 
      37             : template <class It1> struct ReferenceType<It1, void> {
                         // Single iterator (It2 left at its default, void): reuse that
                         // iterator's reference type unchanged.
      38             :   using type = typename It1::reference;
      39             : };
      40             : 
      41             : template <class BaseIterator, class Func> class TransformIterator;
                       // Forward declaration: Iterator::transform constructs a TransformIterator,
                       // whose definition only follows after Iterator.
      42             : 
      43             : // This iterator mostly conforms to the standard library's iterator
      44             : // requirements, but it violates the rule that *it must return a reference to
      45             : // value_type. The deviation is needed because the keys must be returned as
      46             : // const refs but are stored in the dict as non-const.
      47             : template <class Container, class... It> class Iterator {
                         // Bundles one or two underlying iterators (enforced by the static_assert
                         // below) with a reference to the container they traverse. The
                         // container's data() pointer and size() are recorded at construction so
                         // that later operations can detect that the container has changed.
      48             :   static_assert(sizeof...(It) > 0 && sizeof...(It) < 3);
      49             : 
      50             : public:
      51             :   using difference_type = std::ptrdiff_t;
      52             :   using value_type = typename ValueType<It...>::type;
      53             :   using pointer = std::add_pointer_t<std::remove_reference_t<value_type>>;
      54             :   using reference = typename ReferenceType<It...>::type;
      55             : 
                         // Store the underlying iterator(s) and snapshot container.data() and
                         // container.size() for expect_container_unchanged().
      56             :   template <class... T>
      57     2264060 :   explicit Iterator(std::reference_wrapper<Container> container, T &&...it)
      58     2264060 :       : m_iterators{std::forward<T>(it)...}, m_container(container),
      59     2264060 :         m_base_address(container.get().data()), m_size(container.get().size()) {
      60     2264060 :   }
      61             : 
                         // Dereference. With one underlying iterator the result of dereferencing
                         // it is forwarded as is. With two, a pair is returned by value;
                         // std::make_pair unwraps the reference_wrappers produced by
                         // std::cref / std::ref, so the pair holds references to the elements
                         // (the first one const).
                         // Throws std::runtime_error if the container changed.
      62      708111 :   decltype(auto) operator*() const {
      63      708111 :     expect_container_unchanged();
      64             :     if constexpr (sizeof...(It) == 1) {
      65       43874 :       return *std::get<0>(m_iterators);
      66             :     } else {
      67      664237 :       return std::make_pair(std::cref(*std::get<0>(m_iterators)),
      68     1328474 :                             std::ref(*std::get<1>(m_iterators)));
      69             :     }
      70             :   }
      71             : 
                         // Member access. With one underlying iterator that iterator itself is
                         // returned, so the language applies its operator-> in turn. With two,
                         // the temporary pair from operator* is wrapped in a TemporaryItem.
      72        4256 :   decltype(auto) operator->() const {
      73             :     if constexpr (sizeof...(It) == 1) {
      74           8 :       expect_container_unchanged();
      75           8 :       return std::get<0>(m_iterators);
      76             :     } else {
      77             :       // No need to use expect_container_unchanged
      78             :       // because we delegate to operator*
      79        4248 :       return TemporaryItem<reference>(**this);
      80             :     }
      81             :   }
      82             : 
                         // Pre-increment: advances every underlying iterator by one step.
                         // Throws std::runtime_error if the container changed.
      83      698918 :   Iterator &operator++() {
      84      698918 :     expect_container_unchanged();
      85      698918 :     ++std::get<0>(m_iterators);
      86             :     if constexpr (sizeof...(It) == 2)
      87      655099 :       ++std::get<1>(m_iterators);
      88      698918 :     return *this;
      89             :   }
      90             : 
                         // Equality is decided by the first underlying iterator only. Only this
                         // iterator's container is checked for changes, not other's.
      91     1850160 :   bool operator==(const Iterator<Container, It...> &other) const {
      92     1850160 :     expect_container_unchanged();
      93             :     // Assuming m_iterators are always in sync.
      94     1850160 :     return std::get<0>(m_iterators) == std::get<0>(other.m_iterators);
      95             :   }
      96             : 
      97     1788205 :   bool operator!=(const Iterator<Container, It...> &other) const {
      98     1788205 :     return !(*this == other); // NOLINT
      99             :   }
     100             : 
                         // Wrap this iterator in a TransformIterator that applies func to every
                         // dereferenced item. The lvalue overload copies *this.
     101           2 :   template <class F> auto transform(F &&func) const & {
     102           2 :     return TransformIterator{*this, std::forward<F>(func)};
     103             :   }
     104             : 
                         // Rvalue overload of transform: moves *this into the TransformIterator.
     105       82334 :   template <class F> auto transform(F &&func) && {
     106       82334 :     return TransformIterator{std::move(*this), std::forward<F>(func)};
     107             :   }
     108             : 
                         // Exchanges all four members of a and b. The unqualified swap calls rely
                         // on argument-dependent lookup for the tuple and the reference_wrapper.
     109             :   friend void swap(Iterator &a, Iterator &b) {
     110             :     swap(a.m_iterators, b.m_iterators);
     111             :     swap(a.m_container, b.m_container);
     112             :     std::swap(a.m_base_address, b.m_base_address);
     113             :     std::swap(a.m_size, b.m_size);
     114             :   }
     115             : 
     116             : protected:
     117             :   // operator-> needs to return a pointer or something that has operator->
     118             :   // But we cannot take the address of the temporary pair or transform result.
     119             :   // So store it in this wrapper to make it accessible via its address.
     120             :   template <class T> class TemporaryItem {
     121             :   public:
     122        4292 :     explicit TemporaryItem(T &&item) : m_item(std::move(item)) {}
     123        4292 :     auto *operator->() { return &m_item; }
     124             : 
     125             :   private:
                           // Held by value (decayed T) so that it stays valid for as long as the
                           // wrapper itself exists.
     126             :     std::decay_t<T> m_item;
     127             :   };
     128             : 
     129             : private:
     130             :   using IteratorStorage = std::tuple<It...>;
     131             : 
                         // The one or two underlying iterators.
     132             :   IteratorStorage m_iterators;
                         // The container being traversed; queried again on every check.
     133             :   std::reference_wrapper<Container> m_container;
                         // container.data() and container.size() as seen at construction time.
     134             :   const void *m_base_address;
     135             :   size_t m_size;
     136             : 
                         // Throws std::runtime_error if the container's data() pointer or its
                         // size() differs from the values recorded at construction.
     137     3257197 :   void expect_container_unchanged() const {
     138     6514394 :     if (m_container.get().data() != m_base_address ||
     139     3257197 :         m_container.get().size() != m_size) {
     140           0 :       throw std::runtime_error("dictionary changed size during iteration");
     141             :     }
     142     3257197 :   }
     143             : };
     144             : 
     145             : template <class BaseIterator, class Func>
     146             : class TransformIterator : public BaseIterator {
                         // Iterator adaptor: publicly derives from BaseIterator (so increment and
                         // comparison come from the base) and applies the stored callable m_func
                         // to every dereferenced item.
     147             : public:
     148             :   using difference_type = std::ptrdiff_t;
                         // The value type is whatever Func returns for the base's value_type.
     149             :   using value_type =
     150             :       std::invoke_result_t<Func, typename BaseIterator::value_type>;
     151             :   using pointer = std::add_pointer_t<std::remove_reference_t<value_type>>;
     152             :   using reference = std::add_lvalue_reference_t<value_type>;
     153             : 
                         // Forwards base into the BaseIterator subobject and func into m_func.
     154             :   template <class It, class F>
     155       82336 :   TransformIterator(It &&base, F &&func)
     156       82336 :       : BaseIterator(std::forward<It>(base)), m_func(std::forward<F>(func)) {}
     157             : 
                         // Dereferences the base iterator and passes the result through m_func.
     158       80268 :   decltype(auto) operator*() const { return m_func(BaseIterator::operator*()); }
     159             : 
                         // The transformed value is a temporary, so it is wrapped in the base's
                         // TemporaryItem helper to give operator-> an object it can point to.
     160          44 :   decltype(auto) operator->() const {
     161             :     using Result = typename BaseIterator::template TemporaryItem<
     162             :         std::decay_t<decltype(**this)>>;
     163          44 :     return Result(**this);
     164             :   }
     165             : 
                         // Chains another transformation: the new iterator applies the stored
                         // function first and func second. Both callables are captured by value.
     166             :   template <class F> auto transform(F &&func) const & {
     167             :     return BaseIterator::transform(
     168             :         [new_f = std::forward<F>(func), old_f = this->m_func](const auto &x) {
     169             :           return new_f(old_f(x));
     170             :         });
     171             :   }
     172             : 
                         // Rvalue overload of transform with the same composition as above.
                         // NOTE(review): inside a member function *this is an lvalue, so the
                         // BaseIterator::transform call below selects the base's `const &`
                         // overload and copies the base iterator rather than moving it.
                         // Confirm whether a move was intended.
     173           2 :   template <class F> auto transform(F &&func) && {
     174             :     // Make a copy for old_f to avoid referencing a member of *this.
     175             :     return BaseIterator::transform(
     176           3 :         [new_f = std::forward<F>(func), old_f = this->m_func](const auto &x) {
     177           1 :           return new_f(old_f(x));
     178           2 :         });
     179             :   }
     180             : 
     181             : private:
                         // The callable applied on dereference, stored by value.
     182             :   std::decay_t<Func> m_func;
     183             : };
     184             : 
     185             : template <class I, class F>
                       // Deduction guide: TransformIterator{it, f} deduces the decayed types of
                       // both arguments as the class template arguments.
     186             : TransformIterator(I, F) -> TransformIterator<std::decay_t<I>, std::decay_t<F>>;
     187             : } // namespace scipp::core::dict_detail
     188             : 
     189             : namespace std {
     190             : template <class Container, class... It>
     191             : struct iterator_traits<scipp::core::dict_detail::Iterator<Container, It...>> {
                         // Specialization for dict_detail::Iterator: forwards the iterator's own
                         // member typedefs and tags it as a forward iterator.
     192             : private:
     193             :   using I = scipp::core::dict_detail::Iterator<Container, It...>;
     194             : 
     195             : public:
     196             :   using difference_type = typename I::difference_type;
     197             :   using value_type = typename I::value_type;
     198             :   using pointer = typename I::pointer;
     199             :   using reference = typename I::reference;
     200             : 
     201             :   // It is a forward iterator for most use cases.
     202             :   // But it misses post-increment:
     203             :   //   it++ and *it++  (easy, but not needed right now)
     204             :   using iterator_category = std::forward_iterator_tag;
     205             : };
     206             : 
     207             : template <class BaseIterator, class Func>
     208             : struct iterator_traits<
     209             :     scipp::core::dict_detail::TransformIterator<BaseIterator, Func>> {
                         // Specialization for dict_detail::TransformIterator: forwards the
                         // iterator's own member typedefs and tags it as a forward iterator.
     210             : private:
     211             :   using I = scipp::core::dict_detail::TransformIterator<BaseIterator, Func>;
     212             : 
     213             : public:
     214             :   using difference_type = typename I::difference_type;
     215             :   using value_type = typename I::value_type;
     216             :   using pointer = typename I::pointer;
     217             :   using reference = typename I::reference;
     218             : 
     219             :   // It is a forward iterator for most use cases.
     220             :   // But it misses post-increment:
     221             :   //   it++ and *it++  (easy, but not needed right now)
     222             :   using iterator_category = std::forward_iterator_tag;
     223             : };
     224             : } // namespace std
     225             : 
     226             : namespace scipp::core {
     227             : template <class Key, class Value> class Dict {
                         // Insertion-ordered dictionary. Keys and values live in two parallel
                         // vectors (m_keys / m_values); the entry for m_keys[i] is m_values[i].
                         // Lookup is a linear std::find over m_keys (see find_key).
     228             :   using Keys = std::vector<Key>;
     229             :   using Values = std::vector<Value>;
     230             : 
     231             : public:
     232             :   using key_type = Key;
     233             :   using mapped_type = Value;
     234             :   using value_type = std::pair<const Key, Value>;
     235             :   using value_iterator =
     236             :       typename dict_detail::Iterator<Values, typename Values::iterator>;
     237             :   using iterator =
     238             :       typename dict_detail::Iterator<Keys, typename Keys::const_iterator,
     239             :                                      typename Values::iterator>;
     240             :   using const_key_iterator =
     241             :       typename dict_detail::Iterator<const Keys, typename Keys::const_iterator>;
     242             :   using const_value_iterator =
     243             :       typename dict_detail::Iterator<const Values,
     244             :                                      typename Values::const_iterator>;
     245             :   using const_iterator =
     246             :       typename dict_detail::Iterator<const Keys, typename Keys::const_iterator,
     247             :                                      typename Values::const_iterator>;
     248             : 
                         /// Construct from a list of key-value pairs, keeping the listed order.
                         /// Throws std::invalid_argument if the same key is listed twice.
     249        2249 :   Dict(std::initializer_list<std::pair<const Key, Value>> items) {
     250        2249 :     reserve(items.size());
     251        2272 :     for (const auto &[key, value] : items) {
     252          23 :       if (contains(key))
     253           0 :         throw std::invalid_argument("duplicate key in initializer");
     254          23 :       insert_or_assign(key, value);
     255             :     }
     256        2249 :   }
     257             : 
     258      871455 :   Dict() = default;
     259             : 
     260             :   /// Return the number of elements.
     261       48899 :   [[nodiscard]] index size() const noexcept { return scipp::size(m_keys); }
     262             :   /// Return true if there are 0 elements.
     263        2505 :   [[nodiscard]] bool empty() const noexcept { return size() == 0; }
     264             :   /// Return the number of elements that space is currently allocated for.
     265         966 :   [[nodiscard]] index capacity() const noexcept { return m_keys.capacity(); }
     266             : 
                         /// Reserve storage for new_capacity elements in both vectors.
     267        2524 :   void reserve(const index new_capacity) {
     268        2524 :     m_keys.reserve(new_capacity);
     269        2524 :     m_values.reserve(new_capacity);
     270        2524 :   }
     271             : 
                         /// Return true if key is present.
     272      923929 :   [[nodiscard]] bool contains(const Key &key) const noexcept {
     273      923929 :     return find_key(key) != m_keys.end();
     274             :   }
     275             : 
                         /// Append (key, value) if key is absent; otherwise overwrite the value
                         /// stored for key, leaving its position unchanged.
     276      554409 :   template <class V> void insert_or_assign(const key_type &key, V &&value) {
     277      554409 :     if (const auto key_it = find_key(key); key_it == m_keys.end()) {
     278      553428 :       m_keys.push_back(key);
     279      553428 :       m_values.emplace_back(std::forward<V>(value));
     280             :     } else {
     281         981 :       m_values[index_of(key_it)] = std::forward<V>(value);
     282             :     }
     283      554409 :   }
     284             : 
                         /// Remove the entry for key and discard its value.
                         /// Throws except::NotFoundError if key is absent (via extract).
     285          38 :   void erase(const key_type &key) { static_cast<void>(extract(key)); }
     286             : 
                         /// Remove the entry for key and return its value (moved out).
                         /// Throws except::NotFoundError if key is absent.
     287        7961 :   mapped_type extract(const key_type &key) {
     288        7961 :     const auto key_it = expect_find_key(key);
                           // Locate the value before erasing the key: index_of needs key_it,
                           // which the erase below invalidates.
     289        7959 :     const auto value_it = std::next(m_values.begin(), index_of(key_it));
     290        7959 :     m_keys.erase(key_it);
     291        7959 :     mapped_type value = std::move(*value_it);
     292        7959 :     m_values.erase(value_it);
     293       15918 :     return value;
     294           0 :   }
     295             : 
                         /// Remove all elements.
     296           3 :   void clear() {
     297           3 :     m_keys.clear();
     298           3 :     m_values.clear();
     299           3 :   }
     300             : 
                         /// Return the value stored for key. This never inserts:
                         /// it throws except::NotFoundError if key is absent.
     301      242589 :   [[nodiscard]] const mapped_type &operator[](const key_type &key) const {
     302      242589 :     return m_values[expect_find_index(key)];
     303             :   }
     304             : 
                         /// Mutable overload; it also throws except::NotFoundError instead of
                         /// inserting when key is absent.
     305       14953 :   [[nodiscard]] mapped_type &operator[](const key_type &key) {
     306       14953 :     return m_values[expect_find_index(key)];
     307             :   }
     308             : 
                         /// Same as operator[] const.
     309      242589 :   [[nodiscard]] const mapped_type &at(const key_type &key) const {
     310      242589 :     return (*this)[key];
     311             :   }
     312             : 
                         /// Same as the mutable operator[].
     313       11255 :   [[nodiscard]] mapped_type &at(const key_type &key) { return (*this)[key]; }
     314             : 
                         /// Return an iterator to the entry for key, or end() if key is absent.
     315        6849 :   [[nodiscard]] const_iterator find(const key_type &key) const {
     316        6849 :     if (const auto key_it = find_key(key); key_it == m_keys.end()) {
     317        1694 :       return end();
     318             :     } else {
     319        5155 :       return const_iterator(m_keys, key_it,
     320        5155 :                             std::next(m_values.begin(), index_of(key_it)));
     321             :     }
     322             :   }
     323             : 
                         /// Mutable overload of find; returns end() if key is absent.
     324        6307 :   [[nodiscard]] iterator find(const key_type &key) {
     325        6307 :     if (const auto key_it = find_key(key); key_it == m_keys.end()) {
     326        4918 :       return end();
     327             :     } else {
     328        1389 :       return iterator(m_keys, key_it,
     329        1389 :                       std::next(m_values.begin(), index_of(key_it)));
     330             :     }
     331             :   }
     332             : 
                         /// Iterators over the keys only (read-only).
     333       35883 :   [[nodiscard]] auto keys_begin() const noexcept {
     334       35883 :     return const_key_iterator(m_keys, m_keys.cbegin());
     335             :   }
     336             : 
     337       35885 :   [[nodiscard]] auto keys_end() const noexcept {
     338       35885 :     return const_key_iterator(m_keys, m_keys.cend());
     339             :   }
     340             : 
                         /// Iterators over the values only (mutable).
     341             :   [[nodiscard]] auto values_begin() noexcept {
     342             :     return value_iterator(m_values, m_values.begin());
     343             :   }
     344             : 
     345             :   [[nodiscard]] auto values_end() noexcept {
     346             :     return value_iterator(m_values, m_values.end());
     347             :   }
     348             : 
                         /// Iterators over the values only (read-only).
     349        3254 :   [[nodiscard]] auto values_begin() const noexcept {
     350        3254 :     return const_value_iterator(m_values, m_values.cbegin());
     351             :   }
     352             : 
     353        3278 :   [[nodiscard]] auto values_end() const noexcept {
     354        3278 :     return const_value_iterator(m_values, m_values.cend());
     355             :   }
     356             : 
                         /// Iterators over (key, value) entries; keys are const, values mutable.
     357      478172 :   [[nodiscard]] auto begin() noexcept {
     358      478172 :     return iterator(m_keys, m_keys.cbegin(), m_values.begin());
     359             :   }
     360             : 
     361      488597 :   [[nodiscard]] auto end() noexcept {
     362      488597 :     return iterator(m_keys, m_keys.cend(), m_values.end());
     363             :   }
     364             : 
                         /// Iterators over (key, value) entries; keys and values are const.
     365      603788 :   [[nodiscard]] auto begin() const noexcept {
     366      603788 :     return const_iterator(m_keys, m_keys.cbegin(), m_values.cbegin());
     367             :   }
     368             : 
     369      608659 :   [[nodiscard]] auto end() const noexcept {
                           // Pair the key sentinel with the value sentinel (cend, not cbegin) so
                           // that both halves of the end iterator are in sync, exactly as in the
                           // non-const end().
     370      608659 :     return const_iterator(m_keys, m_keys.cend(), m_values.cend());
     371             :   }
     372             : 
     373             : private:
     374             :   Keys m_keys;
     375             :   Values m_values;
     376             : 
                         // Linear search; returns m_keys.end() if key is absent.
     377     1756997 :   auto find_key(const Key &key) const noexcept {
     378     1756997 :     return std::find(m_keys.begin(), m_keys.end(), key);
     379             :   }
     380             : 
                         // Like find_key, but throws except::NotFoundError with a message that
                         // lists the dict's keys if key is absent.
     381      265503 :   auto expect_find_key(const Key &key) const {
     382      265503 :     if (const auto key_it = find_key(key); key_it != m_keys.end()) {
     383      265501 :       return key_it;
     384             :     }
                           // Bring both to_string overload sets into scope so that the unqualified
                           // call below can resolve for either kind of key type.
     385             :     using scipp::core::to_string;
     386             :     using std::to_string;
     387           2 :     throw except::NotFoundError("Expected " + dict_keys_to_string(*this) +
     388           0 :                                 " to contain " + to_string(key) + ".");
     389             :   }
     390             : 
                         // Position of the key iterator it within m_keys; the matching value is
                         // stored at the same position in m_values.
     391      273026 :   auto index_of(const typename Keys::const_iterator &it) const noexcept {
     392      273026 :     return std::distance(m_keys.begin(), it);
     393             :   }
     394             : 
                         // Position of key; throws except::NotFoundError if key is absent.
     395      257542 :   scipp::index expect_find_index(const Key &key) const {
     396      257542 :     return index_of(expect_find_key(key));
     397             :   }
     398             : };
     399             : 
     400             : template <class It>
                       // Declaration only: no definition of this iterator-range overload appears
                       // in this header. It takes a range [it, end) and a name for the dict.
     401             : std::string dict_keys_to_string(It it, It end,
     402             :                                 const std::string_view &dict_name);
     403             : 
     404             : template <class Key, class Value>
                       // Convenience overload: forwards the dict's key range
                       // (keys_begin() .. keys_end()) and dict_name, which defaults to "Dict",
                       // to the iterator-range overload of dict_keys_to_string.
     405           2 : std::string dict_keys_to_string(const Dict<Key, Value> &dict,
     406             :                                 const std::string_view &dict_name = "Dict") {
     407           2 :   return dict_keys_to_string(dict.keys_begin(), dict.keys_end(), dict_name);
     408             : }
     409             : } // namespace scipp::core

Generated by: LCOV version 1.14