Line data Source code
1 : // SPDX-License-Identifier: BSD-3-Clause
2 : // Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
3 : /// @file
4 : /// @author Jan-Lukas Wynen
5 : ///
6 : /// Dict is a container similar to Python's dict. It differs from
7 : /// std::map and std::unordered_map in that it stores elements in the
8 : /// order of insertion. In addition, its iterators throw an exception
9 : /// if the dict has changed size during iteration. This matches Python's
10 : /// behavior and avoids segfaults when misusing the dict.
11 : #pragma once
12 :
13 : #include <functional>
14 : #include <string>
15 : #include <string_view>
16 : #include <vector>
17 :
18 : #include "scipp/common/index.h"
19 :
20 : #include "scipp/core/except.h"
21 : #include "scipp/core/string.h"
22 :
23 : namespace scipp::core::dict_detail {
/// Computes the value type of an Iterator built from one or two underlying
/// iterators.
///
/// Two-iterator case: `type` is a std::pair holding the value types of both
/// iterators, in the order the iterators were given.
template <class First, class Second = void> struct ValueType {
private:
  using first_value = typename First::value_type;
  using second_value = typename Second::value_type;

public:
  using type = std::pair<first_value, second_value>;
};

/// Single-iterator case: `type` is simply that iterator's value type.
template <class Only> struct ValueType<Only, void> {
  using type = typename Only::value_type;
};
31 :
/// Computes the `reference` alias of an Iterator built from one or two
/// underlying iterators.
///
/// Two-iterator case: `type` is an rvalue reference to a std::pair of the
/// two iterators' reference types.
template <class First, class Second = void> struct ReferenceType {
private:
  using reference_pair =
      std::pair<typename First::reference, typename Second::reference>;

public:
  using type = std::add_rvalue_reference_t<reference_pair>;
};

/// Single-iterator case: `type` is that iterator's own reference type.
template <class Only> struct ReferenceType<Only, void> {
  using type = typename Only::reference;
};
40 :
41 : template <class BaseIterator, class Func> class TransformIterator;
42 :
43 : // This iterator is mostly standard library conform. But it violates the
44 : // requirement that *it must return a reference to value_type.
45 : // This is required because the keys must be returned as const refs but
46 : // stored in the dict as non-const.
47 : template <class Container, class... It> class Iterator {
48 : static_assert(sizeof...(It) > 0 && sizeof...(It) < 3);
49 :
50 : public:
51 : using difference_type = std::ptrdiff_t;
52 : using value_type = typename ValueType<It...>::type;
53 : using pointer = std::add_pointer_t<std::remove_reference_t<value_type>>;
54 : using reference = typename ReferenceType<It...>::type;
55 :
56 : template <class... T>
57 2264060 : explicit Iterator(std::reference_wrapper<Container> container, T &&...it)
58 2264060 : : m_iterators{std::forward<T>(it)...}, m_container(container),
59 2264060 : m_base_address(container.get().data()), m_size(container.get().size()) {
60 2264060 : }
61 :
62 708111 : decltype(auto) operator*() const {
63 708111 : expect_container_unchanged();
64 : if constexpr (sizeof...(It) == 1) {
65 43874 : return *std::get<0>(m_iterators);
66 : } else {
67 664237 : return std::make_pair(std::cref(*std::get<0>(m_iterators)),
68 1328474 : std::ref(*std::get<1>(m_iterators)));
69 : }
70 : }
71 :
72 4256 : decltype(auto) operator->() const {
73 : if constexpr (sizeof...(It) == 1) {
74 8 : expect_container_unchanged();
75 8 : return std::get<0>(m_iterators);
76 : } else {
77 : // No need to use expect_container_unchanged
78 : // because we delegate to operator*
79 4248 : return TemporaryItem<reference>(**this);
80 : }
81 : }
82 :
83 698918 : Iterator &operator++() {
84 698918 : expect_container_unchanged();
85 698918 : ++std::get<0>(m_iterators);
86 : if constexpr (sizeof...(It) == 2)
87 655099 : ++std::get<1>(m_iterators);
88 698918 : return *this;
89 : }
90 :
91 1850160 : bool operator==(const Iterator<Container, It...> &other) const {
92 1850160 : expect_container_unchanged();
93 : // Assuming m_iterators are always in sync.
94 1850160 : return std::get<0>(m_iterators) == std::get<0>(other.m_iterators);
95 : }
96 :
97 1788205 : bool operator!=(const Iterator<Container, It...> &other) const {
98 1788205 : return !(*this == other); // NOLINT
99 : }
100 :
101 2 : template <class F> auto transform(F &&func) const & {
102 2 : return TransformIterator{*this, std::forward<F>(func)};
103 : }
104 :
105 82334 : template <class F> auto transform(F &&func) && {
106 82334 : return TransformIterator{std::move(*this), std::forward<F>(func)};
107 : }
108 :
109 : friend void swap(Iterator &a, Iterator &b) {
110 : swap(a.m_iterators, b.m_iterators);
111 : swap(a.m_container, b.m_container);
112 : std::swap(a.m_base_address, b.m_base_address);
113 : std::swap(a.m_size, b.m_size);
114 : }
115 :
116 : protected:
117 : // operator-> needs to return a pointer or something that has operator->
118 : // But we cannot take the address of the temporary pair or transform result.
119 : // So store it in this wrapper to make it accessible via its address.
120 : template <class T> class TemporaryItem {
121 : public:
122 4292 : explicit TemporaryItem(T &&item) : m_item(std::move(item)) {}
123 4292 : auto *operator->() { return &m_item; }
124 :
125 : private:
126 : std::decay_t<T> m_item;
127 : };
128 :
129 : private:
130 : using IteratorStorage = std::tuple<It...>;
131 :
132 : IteratorStorage m_iterators;
133 : std::reference_wrapper<Container> m_container;
134 : const void *m_base_address;
135 : size_t m_size;
136 :
137 3257197 : void expect_container_unchanged() const {
138 6514394 : if (m_container.get().data() != m_base_address ||
139 3257197 : m_container.get().size() != m_size) {
140 0 : throw std::runtime_error("dictionary changed size during iteration");
141 : }
142 3257197 : }
143 : };
144 :
/// Iterator adaptor that applies a function to every element produced by
/// BaseIterator.
///
/// The class derives publicly from BaseIterator; every operation that is not
/// redefined here (for example comparison) is the one provided by
/// BaseIterator.
///
/// @tparam BaseIterator Iterator type that is wrapped. It must provide
///         value_type, operator*, operator++, transform() and the
///         TemporaryItem helper template.
/// @tparam Func Callable invoked with the result of BaseIterator::operator*.
template <class BaseIterator, class Func>
class TransformIterator : public BaseIterator {
public:
  using difference_type = std::ptrdiff_t;
  using value_type =
      std::invoke_result_t<Func, typename BaseIterator::value_type>;
  using pointer = std::add_pointer_t<std::remove_reference_t<value_type>>;
  using reference = std::add_lvalue_reference_t<value_type>;

  /// Construct from a base iterator and the function to apply.
  /// Both arguments are perfectly forwarded into the stored members.
  template <class It, class F>
  TransformIterator(It &&base, F &&func)
      : BaseIterator(std::forward<It>(base)), m_func(std::forward<F>(func)) {}

  /// Dereference the base iterator and pass the result through the function.
  decltype(auto) operator*() const { return m_func(BaseIterator::operator*()); }

  /// Member access on the transformed element.
  ///
  /// The transformed element is a temporary, so it is stored in a
  /// TemporaryItem whose operator-> yields its address.
  decltype(auto) operator->() const {
    using Result = typename BaseIterator::template TemporaryItem<
        std::decay_t<decltype(**this)>>;
    return Result(**this);
  }

  /// Advance to the next element.
  ///
  /// Defined here rather than inherited so that the result keeps the
  /// TransformIterator type. The inherited operator returns a reference to
  /// the base class, which makes `*++it` dereference the base iterator and
  /// skip the function.
  TransformIterator &operator++() {
    BaseIterator::operator++();
    return *this;
  }

  /// Chain another function: the returned iterator yields
  /// `func(m_func(element))`.
  template <class F> auto transform(F &&func) const & {
    return BaseIterator::transform(
        [new_f = std::forward<F>(func), old_f = this->m_func](const auto &x) {
          return new_f(old_f(x));
        });
  }

  /// Chain another function: the returned iterator yields
  /// `func(m_func(element))`.
  template <class F> auto transform(F &&func) && {
    // Make a copy for old_f to avoid referencing a member of *this.
    return BaseIterator::transform(
        [new_f = std::forward<F>(func), old_f = this->m_func](const auto &x) {
          return new_f(old_f(x));
        });
  }

private:
  std::decay_t<Func> m_func;
};

/// Deduce the class template arguments from the decayed constructor
/// arguments, so that references and cv-qualifiers are stripped.
template <class I, class F>
TransformIterator(I, F) -> TransformIterator<std::decay_t<I>, std::decay_t<F>>;
187 : } // namespace scipp::core::dict_detail
188 :
189 : namespace std {
190 : template <class Container, class... It>
191 : struct iterator_traits<scipp::core::dict_detail::Iterator<Container, It...>> {
192 : private:
193 : using I = scipp::core::dict_detail::Iterator<Container, It...>;
194 :
195 : public:
196 : using difference_type = typename I::difference_type;
197 : using value_type = typename I::value_type;
198 : using pointer = typename I::pointer;
199 : using reference = typename I::reference;
200 :
201 : // It is a forward iterator for most use cases.
202 : // But it misses post-increment:
203 : // it++ and *it++ (easy, but not needed right now)
204 : using iterator_category = std::forward_iterator_tag;
205 : };
206 :
207 : template <class BaseIterator, class Func>
208 : struct iterator_traits<
209 : scipp::core::dict_detail::TransformIterator<BaseIterator, Func>> {
210 : private:
211 : using I = scipp::core::dict_detail::TransformIterator<BaseIterator, Func>;
212 :
213 : public:
214 : using difference_type = typename I::difference_type;
215 : using value_type = typename I::value_type;
216 : using pointer = typename I::pointer;
217 : using reference = typename I::reference;
218 :
219 : // It is a forward iterator for most use cases.
220 : // But it misses post-increment:
221 : // it++ and *it++ (easy, but not needed right now)
222 : using iterator_category = std::forward_iterator_tag;
223 : };
224 : } // namespace std
225 :
226 : namespace scipp::core {
227 : template <class Key, class Value> class Dict {
228 : using Keys = std::vector<Key>;
229 : using Values = std::vector<Value>;
230 :
231 : public:
232 : using key_type = Key;
233 : using mapped_type = Value;
234 : using value_type = std::pair<const Key, Value>;
235 : using value_iterator =
236 : typename dict_detail::Iterator<Values, typename Values::iterator>;
237 : using iterator =
238 : typename dict_detail::Iterator<Keys, typename Keys::const_iterator,
239 : typename Values::iterator>;
240 : using const_key_iterator =
241 : typename dict_detail::Iterator<const Keys, typename Keys::const_iterator>;
242 : using const_value_iterator =
243 : typename dict_detail::Iterator<const Values,
244 : typename Values::const_iterator>;
245 : using const_iterator =
246 : typename dict_detail::Iterator<const Keys, typename Keys::const_iterator,
247 : typename Values::const_iterator>;
248 :
249 2249 : Dict(std::initializer_list<std::pair<const Key, Value>> items) {
250 2249 : reserve(items.size());
251 2272 : for (const auto &[key, value] : items) {
252 23 : if (contains(key))
253 0 : throw std::invalid_argument("duplicate key in initializer");
254 23 : insert_or_assign(key, value);
255 : }
256 2249 : }
257 :
258 871455 : Dict() = default;
259 :
260 : /// Return the number of elements.
261 48899 : [[nodiscard]] index size() const noexcept { return scipp::size(m_keys); }
262 : /// Return true if there are 0 elements.
263 2505 : [[nodiscard]] bool empty() const noexcept { return size() == 0; }
264 : /// Return the number of elements that space is currently allocated for.
265 966 : [[nodiscard]] index capacity() const noexcept { return m_keys.capacity(); }
266 :
267 2524 : void reserve(const index new_capacity) {
268 2524 : m_keys.reserve(new_capacity);
269 2524 : m_values.reserve(new_capacity);
270 2524 : }
271 :
272 923929 : [[nodiscard]] bool contains(const Key &key) const noexcept {
273 923929 : return find_key(key) != m_keys.end();
274 : }
275 :
276 554409 : template <class V> void insert_or_assign(const key_type &key, V &&value) {
277 554409 : if (const auto key_it = find_key(key); key_it == m_keys.end()) {
278 553428 : m_keys.push_back(key);
279 553428 : m_values.emplace_back(std::forward<V>(value));
280 : } else {
281 981 : m_values[index_of(key_it)] = std::forward<V>(value);
282 : }
283 554409 : }
284 :
285 38 : void erase(const key_type &key) { static_cast<void>(extract(key)); }
286 :
287 7961 : mapped_type extract(const key_type &key) {
288 7961 : const auto key_it = expect_find_key(key);
289 7959 : const auto value_it = std::next(m_values.begin(), index_of(key_it));
290 7959 : m_keys.erase(key_it);
291 7959 : mapped_type value = std::move(*value_it);
292 7959 : m_values.erase(value_it);
293 15918 : return value;
294 0 : }
295 :
296 3 : void clear() {
297 3 : m_keys.clear();
298 3 : m_values.clear();
299 3 : }
300 :
301 242589 : [[nodiscard]] const mapped_type &operator[](const key_type &key) const {
302 242589 : return m_values[expect_find_index(key)];
303 : }
304 :
305 14953 : [[nodiscard]] mapped_type &operator[](const key_type &key) {
306 14953 : return m_values[expect_find_index(key)];
307 : }
308 :
309 242589 : [[nodiscard]] const mapped_type &at(const key_type &key) const {
310 242589 : return (*this)[key];
311 : }
312 :
313 11255 : [[nodiscard]] mapped_type &at(const key_type &key) { return (*this)[key]; }
314 :
315 6849 : [[nodiscard]] const_iterator find(const key_type &key) const {
316 6849 : if (const auto key_it = find_key(key); key_it == m_keys.end()) {
317 1694 : return end();
318 : } else {
319 5155 : return const_iterator(m_keys, key_it,
320 5155 : std::next(m_values.begin(), index_of(key_it)));
321 : }
322 : }
323 :
324 6307 : [[nodiscard]] iterator find(const key_type &key) {
325 6307 : if (const auto key_it = find_key(key); key_it == m_keys.end()) {
326 4918 : return end();
327 : } else {
328 1389 : return iterator(m_keys, key_it,
329 1389 : std::next(m_values.begin(), index_of(key_it)));
330 : }
331 : }
332 :
333 35883 : [[nodiscard]] auto keys_begin() const noexcept {
334 35883 : return const_key_iterator(m_keys, m_keys.cbegin());
335 : }
336 :
337 35885 : [[nodiscard]] auto keys_end() const noexcept {
338 35885 : return const_key_iterator(m_keys, m_keys.cend());
339 : }
340 :
341 : [[nodiscard]] auto values_begin() noexcept {
342 : return value_iterator(m_values, m_values.begin());
343 : }
344 :
345 : [[nodiscard]] auto values_end() noexcept {
346 : return value_iterator(m_values, m_values.end());
347 : }
348 :
349 3254 : [[nodiscard]] auto values_begin() const noexcept {
350 3254 : return const_value_iterator(m_values, m_values.cbegin());
351 : }
352 :
353 3278 : [[nodiscard]] auto values_end() const noexcept {
354 3278 : return const_value_iterator(m_values, m_values.cend());
355 : }
356 :
357 478172 : [[nodiscard]] auto begin() noexcept {
358 478172 : return iterator(m_keys, m_keys.cbegin(), m_values.begin());
359 : }
360 :
361 488597 : [[nodiscard]] auto end() noexcept {
362 488597 : return iterator(m_keys, m_keys.cend(), m_values.end());
363 : }
364 :
365 603788 : [[nodiscard]] auto begin() const noexcept {
366 603788 : return const_iterator(m_keys, m_keys.cbegin(), m_values.cbegin());
367 : }
368 :
369 608659 : [[nodiscard]] auto end() const noexcept {
370 608659 : return const_iterator(m_keys, m_keys.cend(), m_values.cbegin());
371 : }
372 :
373 : private:
374 : Keys m_keys;
375 : Values m_values;
376 :
377 1756997 : auto find_key(const Key &key) const noexcept {
378 1756997 : return std::find(m_keys.begin(), m_keys.end(), key);
379 : }
380 :
381 265503 : auto expect_find_key(const Key &key) const {
382 265503 : if (const auto key_it = find_key(key); key_it != m_keys.end()) {
383 265501 : return key_it;
384 : }
385 : using scipp::core::to_string;
386 : using std::to_string;
387 2 : throw except::NotFoundError("Expected " + dict_keys_to_string(*this) +
388 0 : " to contain " + to_string(key) + ".");
389 : }
390 :
391 273026 : auto index_of(const typename Keys::const_iterator &it) const noexcept {
392 273026 : return std::distance(m_keys.begin(), it);
393 : }
394 :
395 257542 : scipp::index expect_find_index(const Key &key) const {
396 257542 : return index_of(expect_find_key(key));
397 : }
398 : };
399 :
400 : template <class It>
401 : std::string dict_keys_to_string(It it, It end,
402 : const std::string_view &dict_name);
403 :
404 : template <class Key, class Value>
405 2 : std::string dict_keys_to_string(const Dict<Key, Value> &dict,
406 : const std::string_view &dict_name = "Dict") {
407 2 : return dict_keys_to_string(dict.keys_begin(), dict.keys_end(), dict_name);
408 : }
409 : } // namespace scipp::core
|