You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 
 
 

137 lines
4.5 KiB

use std::collections::HashMap;
use std::fmt::Debug;
use crate::{
config::MASTER_PORT,
messaging::{self, AsMsg, BaseMessage, Empty},
models::{self},
utils::opt_to_string,
UError,
};
use anyhow::{Context, Result};
use reqwest::{header::HeaderMap, Certificate, Client, Identity, Url};
use serde::de::DeserializeOwned;
use serde_json::from_str;
use uuid::Uuid;
/// Agent client identity as a PKCS#12 (DER) bundle, embedded at compile time.
/// `ClientHandler::new` loads it with an empty passphrase and hands it to the
/// reqwest client builder. The path is relative to this source file.
const AGENT_IDENTITY: &[u8] = include_bytes!("../../../certs/alice.p12");
/// Root CA certificate in PEM form, embedded at compile time.
/// `ClientHandler::new` adds it to the reqwest client's trusted root certificates.
const ROOT_CA_CERT: &[u8] = include_bytes!("../../../certs/ca.crt");
/// HTTPS client for the master server's API.
///
/// Endpoint paths are resolved against `base_url` and sent through the
/// pre-configured reqwest `client`.
#[derive(Debug, Clone)]
pub struct ClientHandler {
    /// Root URL of the master, built as `https://{server}:{MASTER_PORT}` in `new`.
    base_url: Url,
    /// reqwest client configured in `new` with the embedded identity, the embedded
    /// root CA and (optionally) a default `Authorization` header.
    client: Client,
}
impl ClientHandler {
pub fn new(server: &str, password: Option<String>) -> Self {
let identity = Identity::from_pkcs12_der(AGENT_IDENTITY, "").unwrap();
let mut client = Client::builder().identity(identity);
if let Some(pwd) = password {
client = client.default_headers(
HeaderMap::try_from(&HashMap::from([(
"Authorization".to_string(),
format!("Bearer {pwd}"),
)]))
.unwrap(),
)
}
let client = client
.add_root_certificate(Certificate::from_pem(ROOT_CA_CERT).unwrap())
.build()
.unwrap();
Self {
client,
base_url: Url::parse(&format!("https://{}:{}", server, MASTER_PORT)).unwrap(),
}
}
async fn _req<P: AsMsg + Debug, M: AsMsg + DeserializeOwned + Debug + Default>(
&self,
url: impl AsRef<str> + Debug,
payload: P,
) -> Result<M> {
let request = self
.client
.post(self.base_url.join(url.as_ref()).unwrap())
.json(&payload.as_message());
let response = request.send().await.context("send")?;
let content_len = response.content_length();
let is_success = match response.error_for_status_ref() {
Ok(_) => Ok(()),
Err(e) => Err(UError::from(e)),
};
let resp = response.text().await.context("resp")?;
debug!("url = {}, resp = {}", url.as_ref(), resp);
match is_success {
Ok(_) => from_str::<BaseMessage<M>>(&resp)
.map(|msg| msg.into_inner())
.or_else(|e| match content_len {
Some(0) => Ok(Default::default()),
_ => Err(UError::NetError(e.to_string(), resp)),
}),
Err(UError::NetError(err, _)) => Err(UError::NetError(err, resp)),
_ => unreachable!(),
}
.map_err(From::from)
}
// get jobs for client
pub async fn get_personal_jobs(&self, url_param: Uuid) -> Result<Vec<models::AssignedJob>> {
self._req(format!("get_personal_jobs/{}", url_param), Empty)
.await
}
// send something to server
pub async fn report(&self, payload: &[messaging::Reportable]) -> Result<Empty> {
self._req("report", payload).await
}
// download file
pub async fn dl(&self, file: String) -> Result<Vec<u8>> {
self._req(format!("dl/{file}"), Empty).await
}
}
//##########// Admin area //##########//
#[cfg(feature = "panel")]
impl ClientHandler {
    /// Lists agents via `get_agents/`.
    ///
    /// The path suffix is `opt_to_string(agent)`, so an id can be passed to target
    /// a single agent.
    pub async fn get_agents(&self, agent: Option<Uuid>) -> Result<Vec<models::Agent>> {
        let suffix = opt_to_string(agent);
        self._req(format!("get_agents/{suffix}"), Empty).await
    }

    /// Posts `item` to the `update_item` endpoint.
    pub async fn update_item(&self, item: impl AsMsg + Debug) -> Result<Empty> {
        let endpoint = "update_item";
        self._req(endpoint, item).await
    }

    /// Lists job metadata via `get_jobs/`, with `opt_to_string(job)` as the path suffix.
    pub async fn get_jobs(&self, job: Option<Uuid>) -> Result<Vec<models::JobMeta>> {
        let suffix = opt_to_string(job);
        self._req(format!("get_jobs/{suffix}"), Empty).await
    }

    /// Uploads a batch of job definitions to the `upload_jobs` endpoint.
    pub async fn upload_jobs(&self, payload: &[models::JobMeta]) -> Result<Empty> {
        let endpoint = "upload_jobs";
        self._req(endpoint, payload).await
    }

    /// Requests deletion of the item with id `item`.
    ///
    /// Returns the integer the server sends back.
    pub async fn del(&self, item: Uuid) -> Result<i32> {
        let path = format!("del/{}", item);
        self._req(path, Empty).await
    }

    /// Assigns the jobs named in `job_idents` to `agent`.
    ///
    /// Returns the list of UUIDs the server replies with.
    pub async fn set_jobs(&self, agent: Uuid, job_idents: &[String]) -> Result<Vec<Uuid>> {
        let path = format!("set_jobs/{}", agent);
        self._req(path, job_idents).await
    }

    /// Fetches assigned jobs through the `get_personal_jobs/` endpoint.
    ///
    /// The path suffix is `opt_to_string(agent)`.
    pub async fn get_agent_jobs(&self, agent: Option<Uuid>) -> Result<Vec<models::AssignedJob>> {
        let suffix = opt_to_string(agent);
        self._req(format!("get_personal_jobs/{suffix}"), Empty).await
    }
}