You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
166 lines
4.2 KiB
166 lines
4.2 KiB
use crate::{models::AssignedJob, UResult}; |
|
use futures::{future::BoxFuture, lock::Mutex}; |
|
use lazy_static::lazy_static; |
|
use std::collections::HashMap; |
|
use std::future::Future; |
|
use tokio::{ |
|
runtime::Handle, |
|
sync::mpsc::{channel, Receiver, Sender}, |
|
task::{spawn, spawn_blocking, JoinHandle}, |
|
}; |
|
use uuid::Uuid; |
|
|
|
pub type ExecResult = UResult<AssignedJob>; |
|
|
|
lazy_static! { |
|
static ref FUT_RESULTS: Mutex<HashMap<Uuid, JoinInfo>> = Mutex::new(HashMap::new()); |
|
static ref FUT_CHANNEL: (Sender<Uuid>, Mutex<Receiver<Uuid>>) = { |
|
spawn(init_receiver()); |
|
let (tx, rx) = channel(100); |
|
(tx, Mutex::new(rx)) |
|
}; |
|
} |
|
|
|
struct JoinInfo { |
|
handle: JoinHandle<JoinHandle<ExecResult>>, |
|
completed: bool, |
|
collectable: bool, // indicates if future can be popped from pool via pop_task_if_completed |
|
} |
|
|
|
impl JoinInfo { |
|
async fn wait_result(self) -> ExecResult { |
|
self.handle.await.unwrap().await.unwrap() |
|
} |
|
} |
|
|
|
fn get_sender() -> Sender<Uuid> { |
|
FUT_CHANNEL.0.clone() |
|
} |
|
|
|
pub struct Waiter { |
|
tasks: Vec<BoxFuture<'static, ExecResult>>, |
|
fids: Vec<Uuid>, |
|
} |
|
|
|
impl Waiter { |
|
pub fn new() -> Self { |
|
Self { |
|
tasks: vec![], |
|
fids: vec![], |
|
} |
|
} |
|
|
|
pub fn push(&mut self, task: impl Future<Output = ExecResult> + Send + 'static) { |
|
self.tasks.push(Box::pin(task)); |
|
} |
|
|
|
/// Spawn prepared tasks |
|
pub async fn spawn(mut self) -> Self { |
|
let collectable = true; //TODO: self.tasks.len() != 1; |
|
for f in self.tasks.drain(..) { |
|
let handle = Handle::current(); |
|
let fid = Uuid::new_v4(); |
|
let tx = get_sender(); |
|
self.fids.push(fid); |
|
let task_wrapper = async move { |
|
debug!("inside wrapper (started): {}", fid); |
|
let result = f.await; |
|
tx.send(fid).await.unwrap(); |
|
result |
|
}; |
|
let handler = JoinInfo { |
|
handle: spawn_blocking(move || handle.spawn(task_wrapper)), |
|
completed: false, |
|
collectable, |
|
}; |
|
FUT_RESULTS.lock().await.insert(fid, handler); |
|
} |
|
self |
|
} |
|
|
|
/// Wait until a bunch of tasks is finished. |
|
/// NOT GUARANTEED that all tasks will be returned due to |
|
/// possibility to pop them in other places |
|
pub async fn wait(self) -> Vec<ExecResult> { |
|
let mut result = vec![]; |
|
for fid in self.fids { |
|
if let Some(task) = pop_task(fid).await { |
|
result.push(task.wait_result().await); |
|
} |
|
} |
|
result |
|
} |
|
} |
|
|
|
async fn init_receiver() { |
|
while let Some(fid) = FUT_CHANNEL.1.lock().await.recv().await { |
|
if let Some(mut lock) = FUT_RESULTS.try_lock() { |
|
if let Some(j) = lock.get_mut(&fid) { |
|
j.completed = true; |
|
} |
|
} |
|
} |
|
} |
|
|
|
async fn pop_task(fid: Uuid) -> Option<JoinInfo> { |
|
FUT_RESULTS.lock().await.remove(&fid) |
|
} |
|
|
|
pub async fn pop_task_if_completed(fid: Uuid) -> Option<ExecResult> { |
|
let &JoinInfo { |
|
handle: _, |
|
collectable, |
|
completed, |
|
} = match FUT_RESULTS.lock().await.get(&fid) { |
|
Some(t) => t, |
|
None => return None, |
|
}; |
|
if collectable && completed { |
|
let task = pop_task(fid).await.unwrap(); |
|
Some(task.wait_result().await) |
|
} else { |
|
None |
|
} |
|
} |
|
|
|
pub async fn pop_completed() -> Vec<ExecResult> { |
|
let mut completed: Vec<ExecResult> = vec![]; |
|
let fids = FUT_RESULTS |
|
.lock() |
|
.await |
|
.keys() |
|
.copied() |
|
.collect::<Vec<Uuid>>(); |
|
for fid in fids { |
|
if let Some(r) = pop_task_if_completed(fid).await { |
|
completed.push(r) |
|
} |
|
} |
|
completed |
|
} |
|
|
|
#[cfg(test)]
mod tests {
    use super::*;

    // Sanity check of Tokio task scheduling rather than of this module's API.
    // Spawning a task does not run it immediately. Awaiting a second, later-spawned
    // task yields to the scheduler, and by the time that handle resolves the
    // earlier `writer` task has also run. This depends on the scheduler's run
    // order (presumably the FIFO queue of the default single-threaded
    // `#[tokio::test]` runtime; confirm if this ever flakes).
    #[tokio::test]
    async fn test_spawn() {
        use std::sync::Arc;

        let counter = Arc::new(Mutex::new(0));
        let writer = {
            let shared = Arc::clone(&counter);
            spawn(async move {
                *shared.lock().await = 5;
            })
        };
        // Nothing has yielded yet, so `writer` has not been polled.
        assert_eq!(0, *counter.lock().await);
        // Yield by awaiting another task; `writer` gets to run in the meantime.
        spawn(async {}).await.unwrap();
        assert_eq!(5, *counter.lock().await);
        // Joining `writer` afterwards changes nothing.
        writer.await.unwrap();
        assert_eq!(5, *counter.lock().await);
    }
}
|
|
|