Du kan inte välja fler än 25 ämnen. Ämnen måste starta med en bokstav eller siffra, kan innehålla bindestreck ('-') och vara max 35 tecken långa.

283 lines
9.4KB

  1. /// \file
  2. // Range v3 library
  3. //
  4. // Copyright Casey Carter 2016
  5. //
  6. // Use, modification and distribution is subject to the
  7. // Boost Software License, Version 1.0. (See accompanying
  8. // file LICENSE_1_0.txt or copy at
  9. // http://www.boost.org/LICENSE_1_0.txt)
  10. //
  11. // Project home: https://github.com/ericniebler/range-v3
  12. //
  13. #ifndef RANGES_V3_VIEW_SAMPLE_HPP
  14. #define RANGES_V3_VIEW_SAMPLE_HPP
  15. #include <meta/meta.hpp>
  16. #include <range/v3/algorithm/shuffle.hpp>
  17. #include <range/v3/functional/bind_back.hpp>
  18. #include <range/v3/functional/invoke.hpp>
  19. #include <range/v3/iterator/concepts.hpp>
  20. #include <range/v3/iterator/default_sentinel.hpp>
  21. #include <range/v3/iterator/operations.hpp>
  22. #include <range/v3/range/concepts.hpp>
  23. #include <range/v3/utility/static_const.hpp>
  24. #include <range/v3/view/all.hpp>
  25. #include <range/v3/view/facade.hpp>
  26. #include <range/v3/view/view.hpp>
  27. namespace ranges
  28. {
  29. /// \cond
  30. namespace detail
  31. {
  32. template<typename Rng,
  33. bool = (bool)sized_sentinel_for<sentinel_t<Rng>, iterator_t<Rng>>>
  34. class size_tracker
  35. {
  36. range_difference_t<Rng> size_;
  37. public:
  38. CPP_assert(forward_range<Rng> || sized_range<Rng>);
  39. size_tracker() = default;
  40. size_tracker(Rng & rng)
  41. : size_(ranges::distance(rng))
  42. {}
  43. void decrement()
  44. {
  45. --size_;
  46. }
  47. range_difference_t<Rng> get(Rng &, iterator_t<Rng> &) const
  48. {
  49. return size_;
  50. }
  51. };
  52. // Impl for sized_sentinel_for (no need to store anything)
  53. template<typename Rng>
  54. class size_tracker<Rng, true>
  55. {
  56. public:
  57. size_tracker() = default;
  58. size_tracker(Rng &)
  59. {}
  60. void decrement()
  61. {}
  62. range_difference_t<Rng> get(Rng & rng, iterator_t<Rng> const & it) const
  63. {
  64. return ranges::end(rng) - it;
  65. }
  66. };
  67. } // namespace detail
  68. /// \endcond
  69. /// \addtogroup group-views
  70. /// @{
  71. // Take a random sampling from another view
  72. template<typename Rng, typename URNG>
  73. class sample_view : public view_facade<sample_view<Rng, URNG>, finite>
  74. {
  75. friend range_access;
  76. using D = range_difference_t<Rng>;
  77. Rng rng_;
  78. // Mutable is OK here because sample_view is an Input view.
  79. mutable range_difference_t<Rng> size_;
  80. URNG * engine_;
  81. template<bool IsConst>
  82. class cursor
  83. {
  84. friend cursor<!IsConst>;
  85. using Base = meta::const_if_c<IsConst, Rng>;
  86. meta::const_if_c<IsConst, sample_view> * parent_;
  87. iterator_t<Base> current_;
  88. RANGES_NO_UNIQUE_ADDRESS detail::size_tracker<Base> size_;
  89. D pop_size()
  90. {
  91. return size_.get(parent_->rng_, current_);
  92. }
  93. void advance()
  94. {
  95. if(parent_->size_ > 0)
  96. {
  97. using Dist = std::uniform_int_distribution<D>;
  98. Dist dist{};
  99. URNG & engine = *parent_->engine_;
  100. for(;; ++current_, size_.decrement())
  101. {
  102. RANGES_ASSERT(current_ != ranges::end(parent_->rng_));
  103. auto n = pop_size();
  104. RANGES_EXPECT(n > 0);
  105. typename Dist::param_type const interval{0, n - 1};
  106. if(dist(engine, interval) < parent_->size_)
  107. break;
  108. }
  109. }
  110. }
  111. public:
  112. using value_type = range_value_t<Rng>;
  113. using difference_type = D;
  114. cursor() = default;
  115. explicit cursor(meta::const_if_c<IsConst, sample_view> * rng)
  116. : parent_(rng)
  117. , current_(ranges::begin(rng->rng_))
  118. , size_{rng->rng_}
  119. {
  120. auto n = pop_size();
  121. if(rng->size_ > n)
  122. rng->size_ = n;
  123. advance();
  124. }
  125. CPP_template(bool Other)( //
  126. requires IsConst && (!Other)) cursor(cursor<Other> that)
  127. : parent_(that.parent_)
  128. , current_(std::move(that.current_))
  129. , size_(that.size_)
  130. {}
  131. range_reference_t<Rng> read() const
  132. {
  133. return *current_;
  134. }
  135. bool equal(default_sentinel_t) const
  136. {
  137. RANGES_EXPECT(parent_);
  138. return parent_->size_ <= 0;
  139. }
  140. void next()
  141. {
  142. RANGES_EXPECT(parent_);
  143. RANGES_EXPECT(parent_->size_ > 0);
  144. --parent_->size_;
  145. RANGES_ASSERT(current_ != ranges::end(parent_->rng_));
  146. ++current_;
  147. size_.decrement();
  148. advance();
  149. }
  150. };
  151. cursor<false> begin_cursor()
  152. {
  153. return cursor<false>{this};
  154. }
  155. template<bool Const = true>
  156. auto begin_cursor() const -> CPP_ret(cursor<Const>)( //
  157. requires Const &&
  158. (sized_range<meta::const_if_c<Const, Rng>> ||
  159. sized_sentinel_for<sentinel_t<meta::const_if_c<Const, Rng>>,
  160. iterator_t<meta::const_if_c<Const, Rng>>> ||
  161. forward_range<meta::const_if_c<Const, Rng>>))
  162. {
  163. return cursor<true>{this};
  164. }
  165. public:
  166. sample_view() = default;
  167. explicit sample_view(Rng rng, D sample_size, URNG & generator)
  168. : rng_(std::move(rng))
  169. , size_(sample_size)
  170. , engine_(std::addressof(generator))
  171. {
  172. RANGES_EXPECT(sample_size >= 0);
  173. }
  174. Rng base() const
  175. {
  176. return rng_;
  177. }
  178. };
  179. #if RANGES_CXX_DEDUCTION_GUIDES >= RANGES_CXX_DEDUCTION_GUIDES_17
  180. template<typename Rng, typename URNG>
  181. sample_view(Rng &&, range_difference_t<Rng>, URNG &)
  182. ->sample_view<views::all_t<Rng>, URNG>;
  183. #endif
  184. namespace views
  185. {
  186. /// Returns a random sample of a range of length `size(range)`.
  187. struct sample_fn
  188. {
  189. private:
  190. friend view_access;
  191. #ifdef RANGES_WORKAROUND_MSVC_OLD_LAMBDA
  192. template<typename Size, typename URNG>
  193. struct lamduh
  194. {
  195. Size n;
  196. URNG & urng;
  197. template<typename Rng>
  198. auto operator()(Rng && rng) const
  199. -> invoke_result_t<sample_fn, Rng, range_difference_t<Rng>, URNG &>
  200. {
  201. return sample_fn{}(static_cast<Rng &&>(rng),
  202. static_cast<range_difference_t<Rng>>(n),
  203. urng);
  204. }
  205. };
  206. template<typename Size, typename URNG = detail::default_random_engine>
  207. static auto CPP_fun(bind)(sample_fn, Size n,
  208. URNG & urng = detail::get_random_engine())( //
  209. requires integral<Size> && uniform_random_bit_generator<URNG>)
  210. {
  211. return make_pipeable(lamduh<Size, URNG>{std::move(n), urng});
  212. }
  213. #else // ^^^ workaround / no workaround vvv
  214. template<typename Size, typename URNG = detail::default_random_engine>
  215. static auto CPP_fun(bind)(sample_fn, Size n,
  216. URNG & urng = detail::get_random_engine())( //
  217. requires integral<Size> && uniform_random_bit_generator<URNG>)
  218. {
  219. return make_pipeable(
  220. [n, &urng](
  221. auto && rng) -> invoke_result_t<sample_fn,
  222. decltype(rng),
  223. range_difference_t<decltype(rng)>,
  224. URNG &> {
  225. return sample_fn{}(
  226. static_cast<decltype(rng)>(rng),
  227. static_cast<range_difference_t<decltype(rng)>>(n),
  228. urng);
  229. });
  230. }
  231. #endif // RANGES_WORKAROUND_MSVC_OLD_LAMBDA
  232. public:
  233. template<typename Rng, typename URNG = detail::default_random_engine>
  234. auto operator()(Rng && rng, range_difference_t<Rng> sample_size,
  235. URNG & generator = detail::get_random_engine()) const
  236. -> CPP_ret(sample_view<all_t<Rng>, URNG>)( //
  237. requires viewable_range<Rng> && input_range<Rng> &&
  238. uniform_random_bit_generator<URNG> && convertible_to<
  239. invoke_result_t<URNG &>, range_difference_t<Rng>> &&
  240. (sized_range<Rng> ||
  241. sized_sentinel_for<sentinel_t<Rng>, iterator_t<Rng>> ||
  242. forward_range<Rng>))
  243. {
  244. return sample_view<all_t<Rng>, URNG>{
  245. all(static_cast<Rng &&>(rng)), sample_size, generator};
  246. }
  247. };
  /// \relates sample_fn
  /// \ingroup group-views
  /// The `views::sample` object: a `view<>` wrapper around sample_fn, callable
  /// as `views::sample(rng, n)` or `views::sample(rng, n, urng)`; sample_fn's
  /// `bind` overloads supply the partially-applied form.
  RANGES_INLINE_VARIABLE(view<sample_fn>, sample)
  251. } // namespace views
  252. /// @}
  253. } // namespace ranges
  254. #include <range/v3/detail/satisfy_boost_range.hpp>
  // Presumably adapts sample_view to Boost.Range's range concepts; see
  // range/v3/detail/satisfy_boost_range.hpp for what the macro expands to.
  RANGES_SATISFY_BOOST_RANGE(::ranges::sample_view)
  256. #endif