// SPDX-FileCopyrightText: Copyright (c) 2008-2013, NVIDIA Corporation. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include <thrust/detail/config.h>

#if defined(_CCCL_IMPLICIT_SYSTEM_HEADER_GCC)
#  pragma GCC system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_CLANG)
#  pragma clang system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_MSVC)
#  pragma system_header
#endif // no system header

#include <thrust/random/detail/mod.h>

#include <cuda/std/cstdint>

THRUST_NAMESPACE_BEGIN

namespace random::detail
{
template <typename UIntType, UIntType a, unsigned long long c, UIntType m>
struct linear_congruential_engine_discard_implementation
{
  // Generic fallback: advance `state` by `z` steps by invoking
  // detail::mod<UIntType, a, c, m> once per step (presumably one generator
  // transition x -> (a * x + c) mod m; see mod.h). Cost is linear in `z`.
  _CCCL_HOST_DEVICE static void discard(UIntType& state, unsigned long long z)
  {
    // `z` is a by-value copy, so counting it down in place is harmless.
    while (z-- > 0)
    {
      state = detail::mod<UIntType, a, c, m>(state);
    }
  }
}; // end linear_congruential_engine_discard_implementation

// specialize for small integers and c == 0
// XXX figure out a robust implementation of this for any unsigned integer type later
//
// With c == 0 the recurrence is purely multiplicative, x_{n+1} = a * x_n (mod m), so
// advancing z steps collapses to x_z = (a^z mod m) * x_0 (mod m). a^z mod m is computed
// by square-and-multiply exponentiation, which needs O(log z) iterations instead of O(z).
template <std::uint32_t a, std::uint32_t m>
struct linear_congruential_engine_discard_implementation<std::uint32_t, a, 0, m>
{
  _CCCL_HOST_DEVICE static void discard(std::uint32_t& state, unsigned long long z)
  {
    // m == 0 denotes a modulus of 2^32 (numeric_limits<uint32_t>::max() + 1), the same
    // convention std::linear_congruential_engine uses. NOTE(review): assumed to match how
    // detail::mod treats m == 0 -- confirm. Using m directly here would divide by zero,
    // so widen it to 64 bits, where 2^32 is representable.
    constexpr unsigned long long modulus = (m == 0) ? (1ull << 32) : static_cast<unsigned long long>(m);

    // 64-bit intermediates: every operand multiplied below is < 2^32 (a and state are
    // 32-bit, and everything else is reduced mod modulus <= 2^32), so no product can
    // overflow 2^64.
    unsigned long long multiplier      = a;
    unsigned long long multiplier_to_z = 1;

    // see http://en.wikipedia.org/wiki/Modular_exponentiation
    while (z > 0)
    {
      if (z & 1)
      {
        // multiply in this bit's contribution while using modulus to keep result small
        multiplier_to_z = (multiplier_to_z * multiplier) % modulus;
      }

      // move to the next bit of the exponent, square (and mod) the base accordingly
      z >>= 1;
      multiplier = (multiplier * multiplier) % modulus;
    }

    // The result is < modulus <= 2^32, so narrowing back to 32 bits is lossless.
    state = static_cast<std::uint32_t>((multiplier_to_z * state) % modulus);
  }
}; // end linear_congruential_engine_discard_implementation

// Dispatcher used by linear congruential engines: selects the discard
// implementation from the engine's (result_type, multiplier, increment, modulus),
// so that std::uint32_t engines with a zero increment reach the exponentiation-based
// specialization while every other configuration uses the step-by-step fallback.
struct linear_congruential_engine_discard
{
  // Advances `lcg` by `z` steps, operating directly on its internal state `lcg.m_x`.
  template <typename LinearCongruentialEngine>
  _CCCL_HOST_DEVICE static void discard(LinearCongruentialEngine& lcg, unsigned long long z)
  {
    using engine_type = LinearCongruentialEngine;
    using value_type  = typename engine_type::result_type;

    // Each engine constant is first converted to result_type, then to the
    // implementation's template parameter type.
    using impl = linear_congruential_engine_discard_implementation<
      value_type,
      static_cast<value_type>(engine_type::multiplier),
      static_cast<value_type>(engine_type::increment),
      static_cast<value_type>(engine_type::modulus)>;

    impl::discard(lcg.m_x, z);
  }
}; // end linear_congruential_engine_discard
} // namespace random::detail

THRUST_NAMESPACE_END
