use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};
/// How a count-based limiter treats scans: enforce a limit, only count, or ignore.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum LimiterMode {
    /// Count scans and refuse further ones once the limit has been reached.
    Enforcing,
    /// Count scans but never refuse them.
    Tracking,
    /// Neither count nor refuse scans.
    Untracked,
}
/// Whether a time-based limiter has a time budget at all.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum TimeLimiterMode {
    /// Scans are allowed only until the time budget has elapsed.
    Limited,
    /// Scans are never stopped for time.
    Unlimited,
}
/// Limits (or merely counts) the number of key-values a scan may read.
///
/// The counter lives behind an `Arc`, so every clone of a limiter shares the same
/// count and therefore the same budget.
#[derive(Debug, Clone)]
pub struct KeyValueScanLimiter {
    /// Whether this limiter enforces `limit`, only counts, or does nothing.
    limiter_mode: LimiterMode,
    /// Maximum number of key-values; only consulted in enforcing mode.
    limit: usize,
    /// Number of key-values recorded so far, shared between clones.
    keyvalues_scanned: Arc<AtomicUsize>,
}

impl KeyValueScanLimiter {
    /// Creates a limiter that allows at most `limit` key-values to be scanned.
    pub fn enforcing(limit: usize) -> KeyValueScanLimiter {
        KeyValueScanLimiter {
            limiter_mode: LimiterMode::Enforcing,
            limit,
            keyvalues_scanned: Arc::new(AtomicUsize::new(0)),
        }
    }

    /// Creates a limiter that counts scanned key-values but never refuses a scan.
    pub fn tracking() -> KeyValueScanLimiter {
        KeyValueScanLimiter {
            limiter_mode: LimiterMode::Tracking,
            limit: 0,
            keyvalues_scanned: Arc::new(AtomicUsize::new(0)),
        }
    }

    /// Creates a limiter that neither counts nor refuses scans.
    pub fn untracked() -> KeyValueScanLimiter {
        KeyValueScanLimiter {
            limiter_mode: LimiterMode::Untracked,
            limit: 0,
            keyvalues_scanned: Arc::new(AtomicUsize::new(0)),
        }
    }

    /// Returns the configured limit, or `usize::MAX` when the limiter is not enforcing.
    pub fn get_limit(&self) -> usize {
        match self.limiter_mode {
            LimiterMode::Enforcing => self.limit,
            LimiterMode::Tracking => usize::MAX,
            LimiterMode::Untracked => usize::MAX,
        }
    }

    /// Returns the number of key-values recorded so far (always `0` when untracked).
    pub fn get_scanned_keyvalues(&self) -> usize {
        match self.limiter_mode {
            LimiterMode::Enforcing | LimiterMode::Tracking => {
                self.keyvalues_scanned.load(Ordering::Relaxed)
            }
            LimiterMode::Untracked => 0,
        }
    }

    /// Attempts to record one key-value scan and reports whether it is allowed.
    ///
    /// - Enforcing: allowed and counted while fewer than `limit` key-values have
    ///   been recorded; otherwise refused and not counted.
    /// - Tracking: always allowed and counted.
    /// - Untracked: always allowed, nothing is counted.
    pub(crate) fn try_keyvalue_scan(&self) -> bool {
        match self.limiter_mode {
            LimiterMode::Enforcing => {
                let limit = self.limit;
                // The check and the increment must be one atomic step. With a separate
                // `load` followed by `fetch_add`, clones racing on different threads
                // could all pass the check and push the count past `limit`.
                self.keyvalues_scanned
                    .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |scanned| {
                        if scanned < limit {
                            Some(scanned + 1)
                        } else {
                            None
                        }
                    })
                    .is_ok()
            }
            LimiterMode::Tracking => {
                self.keyvalues_scanned.fetch_add(1, Ordering::SeqCst);
                true
            }
            LimiterMode::Untracked => true,
        }
    }

    /// Returns `true` only for limiters created with [`KeyValueScanLimiter::enforcing`].
    pub fn is_enforcing(&self) -> bool {
        match self.limiter_mode {
            LimiterMode::Enforcing => true,
            LimiterMode::Tracking | LimiterMode::Untracked => false,
        }
    }
}

// Test-only equality: two limiters are equal when their mode, limit and current
// scan count all match.
#[cfg(test)]
impl PartialEq for KeyValueScanLimiter {
    fn eq(&self, other: &KeyValueScanLimiter) -> bool {
        // Read the counters with atomic loads. Dereferencing the raw pointer from
        // `as_ptr` needed `unsafe` and is a data race if a clone is being updated.
        self.keyvalues_scanned.load(Ordering::SeqCst)
            == other.keyvalues_scanned.load(Ordering::SeqCst)
            && self.limiter_mode == other.limiter_mode
            && self.limit == other.limit
    }
}
/// Limits (or merely counts) the number of bytes a scan may read.
///
/// Bytes are reported after the fact through `register_scanned_bytes`. The counter
/// lives behind an `Arc`, so all clones of a limiter share it.
#[derive(Debug, Clone)]
pub struct ByteScanLimiter {
    /// Whether this limiter enforces `limit`, only counts, or does nothing.
    limiter_mode: LimiterMode,
    /// Byte budget; only consulted in enforcing mode.
    limit: usize,
    /// Bytes registered so far, shared between clones.
    bytes_scanned: Arc<AtomicUsize>,
}

impl ByteScanLimiter {
    /// Creates a limiter that stops admitting key-values once `limit` bytes are registered.
    pub fn enforcing(limit: usize) -> ByteScanLimiter {
        ByteScanLimiter {
            limiter_mode: LimiterMode::Enforcing,
            limit,
            bytes_scanned: Arc::new(AtomicUsize::new(0)),
        }
    }

    /// Creates a limiter that counts registered bytes but never refuses a scan.
    pub fn tracking() -> ByteScanLimiter {
        ByteScanLimiter {
            limiter_mode: LimiterMode::Tracking,
            limit: 0,
            bytes_scanned: Arc::new(AtomicUsize::new(0)),
        }
    }

    /// Creates a limiter that neither counts bytes nor refuses scans.
    pub fn untracked() -> ByteScanLimiter {
        ByteScanLimiter {
            limiter_mode: LimiterMode::Untracked,
            limit: 0,
            bytes_scanned: Arc::new(AtomicUsize::new(0)),
        }
    }

    /// Returns the configured byte limit, or `usize::MAX` when the limiter is not enforcing.
    pub fn get_limit(&self) -> usize {
        match self.limiter_mode {
            LimiterMode::Enforcing => self.limit,
            LimiterMode::Tracking | LimiterMode::Untracked => usize::MAX,
        }
    }

    /// Returns the number of bytes registered so far (always `0` when untracked).
    pub fn get_scanned_bytes(&self) -> usize {
        match self.limiter_mode {
            LimiterMode::Enforcing | LimiterMode::Tracking => {
                self.bytes_scanned.load(Ordering::Relaxed)
            }
            LimiterMode::Untracked => 0,
        }
    }

    /// Returns `true` only for limiters created with [`ByteScanLimiter::enforcing`].
    pub fn is_enforcing(&self) -> bool {
        match self.limiter_mode {
            LimiterMode::Enforcing => true,
            LimiterMode::Tracking | LimiterMode::Untracked => false,
        }
    }

    /// Reports whether another key-value may be scanned.
    ///
    /// An enforcing limiter allows it only while fewer than `limit` bytes have been
    /// registered. This call records nothing itself, so the bytes of the last admitted
    /// key-value may take the total past `limit`. Other modes always allow the scan.
    pub(crate) fn try_keyvalue_scan(&self) -> bool {
        match self.limiter_mode {
            LimiterMode::Enforcing => self.bytes_scanned.load(Ordering::Relaxed) < self.limit,
            LimiterMode::Tracking | LimiterMode::Untracked => true,
        }
    }

    /// Adds `bytes` to the shared counter (ignored when untracked).
    pub(crate) fn register_scanned_bytes(&self, bytes: usize) {
        if let LimiterMode::Enforcing | LimiterMode::Tracking = self.limiter_mode {
            self.bytes_scanned.fetch_add(bytes, Ordering::SeqCst);
        }
    }
}

// Test-only equality: two limiters are equal when their mode, limit and current
// byte count all match.
#[cfg(test)]
impl PartialEq for ByteScanLimiter {
    fn eq(&self, other: &ByteScanLimiter) -> bool {
        // Read the counters with atomic loads. Dereferencing the raw pointer from
        // `as_ptr` needed `unsafe` and is a data race if a clone is being updated.
        self.bytes_scanned.load(Ordering::SeqCst) == other.bytes_scanned.load(Ordering::SeqCst)
            && self.limiter_mode == other.limiter_mode
            && self.limit == other.limit
    }
}
/// Limits a scan by wall-clock time, measured from when the limiter was created.
#[derive(Debug, Clone, PartialEq)]
pub struct TimeScanLimiter {
    /// `Limited` when a time budget applies, `Unlimited` otherwise.
    limiter_mode: TimeLimiterMode,
    /// Moment the budget started (set by `new`, `None` from `unlimited`).
    start_instant: Option<Instant>,
    /// Length of the budget (set by `new`, `None` from `unlimited`).
    time_limit_duration: Option<Duration>,
}

impl TimeScanLimiter {
    /// Creates a limiter with a budget of `time_limit` milliseconds, starting now.
    pub fn new(time_limit: u64) -> TimeScanLimiter {
        let start_instant = Some(Instant::now());
        let time_limit_duration = Some(Duration::from_millis(time_limit));
        Self {
            limiter_mode: TimeLimiterMode::Limited,
            start_instant,
            time_limit_duration,
        }
    }

    /// Creates a limiter that never runs out of time.
    pub fn unlimited() -> TimeScanLimiter {
        Self {
            limiter_mode: TimeLimiterMode::Unlimited,
            start_instant: None,
            time_limit_duration: None,
        }
    }

    /// Returns `true` while the elapsed time is strictly below the budget.
    ///
    /// A zero-millisecond budget therefore never admits a scan. Unlimited limiters
    /// always return `true`.
    pub(crate) fn try_keyvalue_scan(&self) -> bool {
        match self.limiter_mode {
            TimeLimiterMode::Unlimited => true,
            TimeLimiterMode::Limited => {
                // `new`, the constructor that yields `Limited`, fills both options.
                let elapsed = self.start_instant.as_ref().unwrap().elapsed();
                let budget = self.time_limit_duration.unwrap();
                elapsed < budget
            }
        }
    }
}
/// Bundles the optional key-value, byte and time limiters that govern one scan.
///
/// `PartialEq` is derived only in test builds, because the key-value and byte
/// limiters implement it only under `cfg(test)`. A single `cfg_attr` replaces two
/// otherwise-identical struct definitions that could drift apart.
#[derive(Debug, Clone)]
#[cfg_attr(test, derive(PartialEq))]
pub struct ScanLimiter {
    keyvalue_scan_limiter: Option<KeyValueScanLimiter>,
    byte_scan_limiter: Option<ByteScanLimiter>,
    time_scan_limiter: Option<TimeScanLimiter>,
}

impl ScanLimiter {
    /// Assembles a scan limiter from its optional parts; `None` means "no limiter of that kind".
    pub fn new(
        keyvalue_scan_limiter: Option<KeyValueScanLimiter>,
        byte_scan_limiter: Option<ByteScanLimiter>,
        time_scan_limiter: Option<TimeScanLimiter>,
    ) -> ScanLimiter {
        ScanLimiter {
            keyvalue_scan_limiter,
            byte_scan_limiter,
            time_scan_limiter,
        }
    }

    /// Consumes the scan limiter and returns its parts in the same order `new` takes them.
    pub fn into_parts(
        self,
    ) -> (
        Option<KeyValueScanLimiter>,
        Option<ByteScanLimiter>,
        Option<TimeScanLimiter>,
    ) {
        let ScanLimiter {
            keyvalue_scan_limiter,
            byte_scan_limiter,
            time_scan_limiter,
        } = self;
        (keyvalue_scan_limiter, byte_scan_limiter, time_scan_limiter)
    }

    /// Borrows the key-value limiter, if one was supplied.
    pub fn get_keyvalue_scan_limiter_ref(&self) -> Option<&KeyValueScanLimiter> {
        self.keyvalue_scan_limiter.as_ref()
    }

    /// Borrows the byte limiter, if one was supplied.
    pub fn get_byte_scan_limiter_ref(&self) -> Option<&ByteScanLimiter> {
        self.byte_scan_limiter.as_ref()
    }

    /// Borrows the time limiter, if one was supplied.
    pub fn get_time_scan_limiter_ref(&self) -> Option<&TimeScanLimiter> {
        self.time_scan_limiter.as_ref()
    }
}
// Unit tests for the key-value, byte, time and combined scan limiters.
#[cfg(test)]
mod tests {
    mod key_value_scan_limiter {
        use super::super::KeyValueScanLimiter;

        #[test]
        fn get_limit() {
            let enforcing = KeyValueScanLimiter::enforcing(3);
            let tracking = KeyValueScanLimiter::tracking();
            let untracked = KeyValueScanLimiter::untracked();
            assert_eq!(enforcing.get_limit(), 3);
            assert_eq!(tracking.get_limit(), usize::MAX);
            assert_eq!(untracked.get_limit(), usize::MAX);
        }

        #[test]
        fn get_scanned_keyvalues() {
            let enforcing = KeyValueScanLimiter::enforcing(3);
            let tracking = KeyValueScanLimiter::tracking();
            let untracked = KeyValueScanLimiter::untracked();
            assert_eq!(enforcing.get_scanned_keyvalues(), 0);
            assert_eq!(tracking.get_scanned_keyvalues(), 0);
            assert_eq!(untracked.get_scanned_keyvalues(), 0);
            assert!(enforcing.try_keyvalue_scan());
            assert!(enforcing.try_keyvalue_scan());
            assert!(tracking.try_keyvalue_scan());
            assert!(tracking.try_keyvalue_scan());
            assert!(tracking.try_keyvalue_scan());
            assert!(untracked.try_keyvalue_scan());
            assert!(untracked.try_keyvalue_scan());
            assert!(untracked.try_keyvalue_scan());
            assert!(untracked.try_keyvalue_scan());
            assert_eq!(enforcing.get_scanned_keyvalues(), 2);
            assert_eq!(tracking.get_scanned_keyvalues(), 3);
            assert_eq!(untracked.get_scanned_keyvalues(), 0);
            assert!(enforcing.try_keyvalue_scan());
            assert_eq!(enforcing.get_scanned_keyvalues(), 3);
            // A refused scan must not be counted.
            assert!(!enforcing.try_keyvalue_scan());
            assert_eq!(enforcing.get_scanned_keyvalues(), 3);
        }

        #[test]
        fn try_keyvalue_scan() {
            let enforcing = KeyValueScanLimiter::enforcing(3);
            let tracking = KeyValueScanLimiter::tracking();
            let untracked = KeyValueScanLimiter::untracked();
            assert!(enforcing.try_keyvalue_scan());
            assert!(enforcing.try_keyvalue_scan());
            assert!(enforcing.try_keyvalue_scan());
            assert!(!enforcing.try_keyvalue_scan());
            assert!(tracking.try_keyvalue_scan());
            assert!(tracking.try_keyvalue_scan());
            assert!(tracking.try_keyvalue_scan());
            assert!(untracked.try_keyvalue_scan());
            assert!(untracked.try_keyvalue_scan());
            assert!(untracked.try_keyvalue_scan());
            assert!(untracked.try_keyvalue_scan());
        }

        #[test]
        fn is_enforcing() {
            let enforcing = KeyValueScanLimiter::enforcing(3);
            let tracking = KeyValueScanLimiter::tracking();
            let untracked = KeyValueScanLimiter::untracked();
            assert!(enforcing.is_enforcing());
            assert!(!tracking.is_enforcing());
            assert!(!untracked.is_enforcing());
        }
    }

    mod byte_scan_limiter {
        use super::super::ByteScanLimiter;

        #[test]
        fn get_limit() {
            let enforcing = ByteScanLimiter::enforcing(1000);
            let tracking = ByteScanLimiter::tracking();
            let untracked = ByteScanLimiter::untracked();
            assert_eq!(enforcing.get_limit(), 1000);
            assert_eq!(tracking.get_limit(), usize::MAX);
            assert_eq!(untracked.get_limit(), usize::MAX);
        }

        #[test]
        fn get_scanned_bytes() {
            let enforcing = ByteScanLimiter::enforcing(1000);
            let tracking = ByteScanLimiter::tracking();
            let untracked = ByteScanLimiter::untracked();
            assert_eq!(enforcing.get_scanned_bytes(), 0);
            assert_eq!(tracking.get_scanned_bytes(), 0);
            assert_eq!(untracked.get_scanned_bytes(), 0);
            enforcing.register_scanned_bytes(100);
            tracking.register_scanned_bytes(100);
            untracked.register_scanned_bytes(100);
            assert_eq!(enforcing.get_scanned_bytes(), 100);
            assert_eq!(tracking.get_scanned_bytes(), 100);
            assert_eq!(untracked.get_scanned_bytes(), 0);
            enforcing.register_scanned_bytes(100);
            tracking.register_scanned_bytes(100);
            untracked.register_scanned_bytes(100);
            assert_eq!(enforcing.get_scanned_bytes(), 200);
            assert_eq!(tracking.get_scanned_bytes(), 200);
            assert_eq!(untracked.get_scanned_bytes(), 0);
        }

        #[test]
        fn try_keyvalue_scan() {
            let enforcing = ByteScanLimiter::enforcing(1000);
            let tracking = ByteScanLimiter::tracking();
            let untracked = ByteScanLimiter::untracked();
            tracking.register_scanned_bytes(10000);
            untracked.register_scanned_bytes(10000);
            assert!(tracking.try_keyvalue_scan());
            assert!(untracked.try_keyvalue_scan());
            enforcing.register_scanned_bytes(500);
            assert!(enforcing.try_keyvalue_scan());
            enforcing.register_scanned_bytes(500);
            assert!(!enforcing.try_keyvalue_scan());
            enforcing.register_scanned_bytes(500);
            assert!(!enforcing.try_keyvalue_scan());
        }

        #[test]
        fn is_enforcing() {
            let enforcing = ByteScanLimiter::enforcing(1000);
            let tracking = ByteScanLimiter::tracking();
            let untracked = ByteScanLimiter::untracked();
            assert!(enforcing.is_enforcing());
            assert!(!tracking.is_enforcing());
            assert!(!untracked.is_enforcing());
        }

        #[test]
        fn register_scanned_bytes() {
            let enforcing = ByteScanLimiter::enforcing(1000);
            let tracking = ByteScanLimiter::tracking();
            let untracked = ByteScanLimiter::untracked();
            untracked.register_scanned_bytes(1000);
            assert_eq!(untracked.get_scanned_bytes(), 0);
            enforcing.register_scanned_bytes(1000);
            tracking.register_scanned_bytes(1000);
            assert_eq!(enforcing.get_scanned_bytes(), 1000);
            assert_eq!(tracking.get_scanned_bytes(), 1000);
            enforcing.register_scanned_bytes(1000);
            tracking.register_scanned_bytes(1000);
            assert_eq!(enforcing.get_scanned_bytes(), 2000);
            assert_eq!(tracking.get_scanned_bytes(), 2000);
        }

        // Clones share one counter: bytes registered through a dropped clone still count.
        #[test]
        fn arc() {
            let enforcing1 = ByteScanLimiter::enforcing(1000);
            let enforcing2 = enforcing1.clone();
            enforcing1.register_scanned_bytes(500);
            assert!(enforcing1.try_keyvalue_scan());
            drop(enforcing1);
            enforcing2.register_scanned_bytes(500);
            assert!(!enforcing2.try_keyvalue_scan());
        }
    }

    mod time_scan_limiter {
        use std::thread::sleep;
        use std::time::Duration;

        use super::super::TimeScanLimiter;

        #[test]
        fn try_keyvalue_scan() {
            let unlimited = TimeScanLimiter::unlimited();
            assert!(unlimited.try_keyvalue_scan());

            // A zero budget is exhausted immediately: elapsed time is never below zero.
            let limited_zero = TimeScanLimiter::new(0);
            assert!(!limited_zero.try_keyvalue_scan());

            // Use a generous budget for the "still open" check. The earlier version
            // slept 50 ms against a 100 ms budget, which failed whenever the thread
            // was descheduled for another ~50 ms on a loaded machine.
            let generous = TimeScanLimiter::new(60_000);
            assert!(generous.try_keyvalue_scan());

            // `sleep` waits at least the requested time, so this check is deterministic.
            let limited = TimeScanLimiter::new(50);
            sleep(Duration::from_millis(50));
            assert!(!limited.try_keyvalue_scan());
        }
    }

    mod scan_limiter {
        use super::super::{ByteScanLimiter, KeyValueScanLimiter, ScanLimiter, TimeScanLimiter};

        #[test]
        fn from_parts() {
            let scan_limiter = ScanLimiter::new(
                Some(KeyValueScanLimiter::untracked()),
                Some(ByteScanLimiter::untracked()),
                Some(TimeScanLimiter::unlimited()),
            );
            let (keyvalue_scan_limiter, byte_scan_limiter, time_scan_limiter) =
                scan_limiter.into_parts();
            let keyvalue_scan_limiter = keyvalue_scan_limiter.unwrap();
            let byte_scan_limiter = byte_scan_limiter.unwrap();
            let time_scan_limiter = time_scan_limiter.unwrap();
            assert!(!keyvalue_scan_limiter.is_enforcing());
            assert!(!byte_scan_limiter.is_enforcing());
            assert!(time_scan_limiter.try_keyvalue_scan());
        }

        #[test]
        fn get_keyvalue_scan_limiter_ref() {
            let scan_limiter = ScanLimiter::new(None, None, None);
            assert!(scan_limiter.get_keyvalue_scan_limiter_ref().is_none());
            let scan_limiter = ScanLimiter::new(Some(KeyValueScanLimiter::untracked()), None, None);
            assert!(!scan_limiter
                .get_keyvalue_scan_limiter_ref()
                .unwrap()
                .is_enforcing());
        }

        #[test]
        fn get_byte_scan_limiter_ref() {
            let scan_limiter = ScanLimiter::new(None, None, None);
            assert!(scan_limiter.get_byte_scan_limiter_ref().is_none());
            let scan_limiter = ScanLimiter::new(None, Some(ByteScanLimiter::untracked()), None);
            assert!(!scan_limiter
                .get_byte_scan_limiter_ref()
                .unwrap()
                .is_enforcing());
        }

        #[test]
        fn get_time_scan_limiter_ref() {
            let scan_limiter = ScanLimiter::new(None, None, None);
            assert!(scan_limiter.get_time_scan_limiter_ref().is_none());
            let scan_limiter = ScanLimiter::new(None, None, Some(TimeScanLimiter::unlimited()));
            assert!(scan_limiter
                .get_time_scan_limiter_ref()
                .unwrap()
                .try_keyvalue_scan());
        }
    }
}