migrate to sqlx

master
plazmoid 3 months ago
parent 65cb0d875b
commit ee7e437357
  1. 2
      .env
  2. 769
      Cargo.lock
  3. 2
      Cargo.toml
  4. 31
      migrations/20250210064035_initial.sql
  5. 2
      src/config.rs
  6. 2
      src/error.rs
  7. 180
      src/main.rs
  8. 12
      src/models.rs
  9. 11
      src/repos/mod.rs
  10. 224
      src/repos/sqlite.rs

@ -1,6 +1,6 @@
export PAIRS="BTC_USDT,TRX_USDT,ETH_USDT,DOGE_USDT,BCH_USDT"
export INTERVALS="MINUTE_1,MINUTE_15,HOUR_1,DAY_1"
export DB_NAME="./poloniex_data.db"
export DATABASE_URL="sqlite://./poloniex_data.db"
export POLONIEX_REST_URL="https://api.poloniex.com"
export POLONIEX_WS_URL="wss://ws.poloniex.com/ws/public"

769
Cargo.lock generated

File diff suppressed because it is too large Load Diff

@ -12,10 +12,10 @@ envy = "0.4.2"
futures-util = "0.3.31"
reqwest = { version = "0.12.12", features = ["json"] }
reqwest-websocket = "0.4.4"
rusqlite = "0.33.0"
serde = { version = "1.0.217", features = ["derive"] }
serde_json = "1.0.138"
serde_tuple = "1.1.0"
sqlx = { version = "0.8.3", features = ["chrono", "runtime-tokio", "sqlite"] }
strum = { version = "0.26.3", features = ["derive"] }
thiserror = "2.0.11"
tokio = { version = "1.43.0", features = ["rt-multi-thread", "macros"] }

@ -0,0 +1,31 @@
-- Add migration script here
CREATE TABLE IF NOT EXISTS trades(
symbol TEXT NOT NULL,
amount REAL NOT NULL,
taker_side TEXT NOT NULL,
quantity REAL NOT NULL,
create_time DATETIME NOT NULL,
price REAL NOT NULL,
id TEXT NOT NULL,
ts INT NOT NULL
);
CREATE TABLE IF NOT EXISTS candles(
pair TEXT NOT NULL,
interval TEXT NOT NULL,
low REAL NOT NULL,
high REAL NOT NULL,
open REAL NOT NULL,
close REAL NOT NULL,
amount REAL NOT NULL,
quantity REAL NOT NULL,
buy_taker_amount REAL NOT NULL,
buy_taker_quantity REAL NOT NULL,
trade_count INT NOT NULL,
ts INT NOT NULL,
weighted_average REAL NOT NULL,
start_time DATETIME NOT NULL,
close_time DATETIME NOT NULL,
PRIMARY KEY(pair, interval, start_time)
);

@ -14,7 +14,7 @@ pub struct Config {
pub poloniex_rest_url: Url,
#[serde(deserialize_with = "deser_url")]
pub poloniex_ws_url: Url,
pub db_name: String,
pub database_url: String,
}
fn deser_url<'de, D: Deserializer<'de>>(deserialize: D) -> Result<Url, D::Error> {

@ -17,7 +17,7 @@ pub enum AppError {
SerdeError(#[from] serde_json::Error),
#[error(transparent)]
DbError(#[from] rusqlite::Error),
DbError(#[from] sqlx::Error),
#[error(transparent)]
StrumError(#[from] strum::ParseError),

@ -5,7 +5,7 @@ use config::get_config;
use error::AppResult;
use futures_util::{future::try_join_all, StreamExt};
use markets::{poloniex::PoloniexClient, Market};
use models::{Candle, CandleExtended, CandleInterval, TradeDirection};
use models::{Candle, CandleExtended, CandleInterval, Trade, TradeDirection};
use repos::{sqlite::SqliteRepo, Repo};
mod config;
@ -27,7 +27,11 @@ async fn fetch_candles_until_now(
let limit = 500;
loop {
println!("pulling candles from {start_time}");
println!(
"pulling {}:{} candles from {start_time}",
pair,
interval.as_ref()
);
let candles = market_client
.get_historical_candles(&pair, interval, start_time, Utc::now().naive_utc(), limit)
.await?;
@ -64,19 +68,56 @@ async fn fetch_candles_until_now(
Ok((result, pair.to_string()))
}
async fn trades_processor(
async fn calculate_new_candles(
repo: Arc<impl Repo>,
market_client: Arc<impl Market>,
pairs: &[String],
interval: CandleInterval,
trade: Trade,
) -> AppResult<()> {
let mut trades = market_client.recent_trades_stream(&pairs).await?;
let is_buy = matches!(trade.taker_side, TradeDirection::Buy);
let insert_new_candle = || async {
let interval_secs = match interval {
CandleInterval::M1 => 60,
CandleInterval::M15 => 60 * 15,
CandleInterval::H1 => 60 * 60,
CandleInterval::D1 => 60 * 60 * 24,
};
let new_candle_ts = DateTime::from_timestamp(
(trade.ts.and_utc().timestamp() / interval_secs) * interval_secs,
0,
)
.unwrap()
.naive_utc();
let new_candle = CandleExtended {
candle: Candle {
low: trade.price,
high: trade.price,
open: trade.price,
close: trade.price,
amount: trade.amount,
quantity: trade.quantity,
buy_taker_amount: if is_buy { trade.amount } else { 0.0 },
buy_taker_quantity: if is_buy { trade.quantity } else { 0.0 },
trade_count: 1,
ts: trade.ts,
weighted_average: trade.amount / trade.quantity,
interval,
start_time: new_candle_ts,
close_time: NaiveDateTime::UNIX_EPOCH,
},
pair: trade.symbol.clone(),
};
while let Some(t) = trades.next().await {
println!("{t:?}");
repo.upsert_candles(&[new_candle]).await?;
let Ok(trade) = t else { break };
let mut last_candle = repo.get_latest_candle_from_interval(&trade.symbol, interval)?;
AppResult::Ok(())
};
let last_candle = repo
.get_latest_candle_from_interval(&trade.symbol, interval)
.await?;
if let Some(mut last_candle) = last_candle {
let interval_delta = match last_candle.candle.interval {
CandleInterval::M1 => TimeDelta::minutes(1),
CandleInterval::M15 => TimeDelta::minutes(15),
@ -84,44 +125,9 @@ async fn trades_processor(
CandleInterval::D1 => TimeDelta::days(1),
};
let is_buy = matches!(trade.taker_side, TradeDirection::Buy);
// если трейд не входит в интервал последней свечи, то создаём новую свечу, иначе обновляем предыдущую
if trade.ts > (last_candle.candle.ts + interval_delta) {
let interval_secs = match last_candle.candle.interval {
CandleInterval::M1 => 60,
CandleInterval::M15 => 60 * 15,
CandleInterval::H1 => 60 * 60,
CandleInterval::D1 => 60 * 60 * 24,
};
let new_candle_ts = DateTime::from_timestamp(
(trade.ts.and_utc().timestamp() / interval_secs) * interval_secs,
0,
)
.unwrap()
.naive_utc();
let new_candle = CandleExtended {
candle: Candle {
low: trade.price,
high: trade.price,
open: trade.price,
close: trade.price,
amount: trade.amount,
quantity: trade.quantity,
buy_taker_amount: if is_buy { trade.amount } else { 0.0 },
buy_taker_quantity: if is_buy { trade.quantity } else { 0.0 },
trade_count: 1,
ts: trade.ts,
weighted_average: trade.amount / trade.quantity,
interval,
start_time: new_candle_ts,
close_time: NaiveDateTime::UNIX_EPOCH,
},
pair: trade.symbol.clone(),
};
repo.upsert_candle(&new_candle)?;
insert_new_candle().await?;
} else {
last_candle.candle.low = last_candle.candle.low.min(trade.price);
last_candle.candle.high = last_candle.candle.high.max(trade.price);
@ -140,12 +146,14 @@ async fn trades_processor(
last_candle.candle.buy_taker_quantity += trade.quantity;
}
repo.upsert_candle(&last_candle)?;
repo.upsert_candles(&[last_candle]).await?;
}
repo.insert_trade(&trade)?;
} else {
insert_new_candle().await?;
}
repo.insert_trade(&trade).await?;
Ok(())
}
@ -157,9 +165,8 @@ async fn _main() -> AppResult<()> {
&config.poloniex_rest_url,
&config.poloniex_ws_url,
));
let repo = Arc::new(SqliteRepo::new_init(config.db_name)?);
let start_time = NaiveDate::from_ymd_opt(2024, 12, 1)
let repo = Arc::new(SqliteRepo::new(&config.database_url).await?);
let base_start_time = NaiveDate::from_ymd_opt(2024, 12, 1)
.unwrap()
.and_hms_opt(0, 0, 0)
.unwrap();
@ -168,6 +175,17 @@ async fn _main() -> AppResult<()> {
for pair in &config.pairs {
for interval in &config.intervals {
let start_time = {
let last_candle = repo
.get_latest_candle_from_interval(&pair, *interval)
.await?;
match last_candle {
Some(c) => c.candle.close_time + TimeDelta::seconds(1),
None => base_start_time,
}
};
let fetcher = fetch_candles_until_now(
poloniex_client.clone(),
pair.to_string(),
@ -185,31 +203,41 @@ async fn _main() -> AppResult<()> {
// config.interval.as_ref()
// );
// нельзя так делать, нужно использовать транзакцию
// и батч-вставку для уменьшения количества обращений к бд,
// но в контексте тестового и так сойдёт
for (candles, pair) in fetched_candles {
for candle in candles {
repo.upsert_candle(&CandleExtended {
candle,
pair: pair.clone(),
})?;
}
}
let candles_to_upsert = fetched_candles
.into_iter()
.flat_map(|(candles, pair)| {
candles
.into_iter()
.map(|candle| CandleExtended {
candle,
pair: pair.clone(),
})
.collect::<Vec<_>>()
})
.collect::<Vec<CandleExtended>>();
repo.upsert_candles(&candles_to_upsert).await?;
for interval in &config.intervals {
tokio::spawn({
let poloniex_client = poloniex_client.clone();
let repo = repo.clone();
let pairs = config.pairs.clone();
let interval = *interval;
async move {
let result = trades_processor(repo, poloniex_client, &pairs, interval).await;
if let Err(e) = result {
eprintln!("processor stopped with error: {e}")
let mut trades = poloniex_client.recent_trades_stream(&config.pairs).await?;
while let Some(t) = trades.next().await {
println!("{t:?}");
let Ok(trade) = t else { break };
for interval in &config.intervals {
tokio::spawn({
let repo = repo.clone();
let interval = *interval;
let trade = trade.clone();
async move {
let result = calculate_new_candles(repo, interval, trade).await;
if let Err(e) = result {
eprintln!("processor stopped with error: {e:?}")
}
}
}
});
});
}
}
Ok(())
}
@ -217,6 +245,6 @@ async fn _main() -> AppResult<()> {
#[tokio::main]
async fn main() {
if let Err(e) = _main().await {
eprintln!("{e}");
eprintln!("{e:?}");
}
}

@ -1,8 +1,9 @@
use chrono::{DateTime, NaiveDateTime};
use serde::{Deserialize, Deserializer, Serialize};
use serde_tuple::Deserialize_tuple;
use sqlx::prelude::{FromRow, Type};
#[derive(strum::EnumString, strum::AsRefStr, Clone, Copy, Deserialize, Debug)]
#[derive(strum::EnumString, strum::AsRefStr, Clone, Copy, Deserialize, Debug, Type)]
pub enum CandleInterval {
#[strum(serialize = "MINUTE_1")]
#[serde(rename = "MINUTE_1")]
@ -18,7 +19,7 @@ pub enum CandleInterval {
D1,
}
#[derive(Debug, Deserialize_tuple)]
#[derive(Debug, Deserialize_tuple, FromRow)]
pub struct Candle {
#[serde(deserialize_with = "deser_str_to_int")]
pub low: f64,
@ -48,7 +49,9 @@ pub struct Candle {
pub close_time: NaiveDateTime,
}
#[derive(FromRow)]
pub struct CandleExtended {
#[sqlx(flatten)]
pub candle: Candle,
pub pair: String,
}
@ -66,7 +69,7 @@ fn deser_naive_dt<'de, D: Deserializer<'de>>(deserialize: D) -> Result<NaiveDate
.map(|dt| dt.naive_utc())
}
#[derive(Deserialize, Debug)]
#[derive(Deserialize, Debug, Clone)]
#[serde(rename_all = "camelCase")]
pub struct Trade {
pub symbol: String,
@ -84,8 +87,9 @@ pub struct Trade {
pub ts: NaiveDateTime,
}
#[derive(Deserialize, Debug, strum::AsRefStr)]
#[derive(Deserialize, Debug, strum::AsRefStr, Clone, Copy, Type)]
#[serde(rename_all = "camelCase")]
#[sqlx(rename_all = "camelCase")]
pub enum TradeDirection {
Buy,
Sell,

@ -5,14 +5,15 @@ use crate::{
pub mod sqlite;
pub trait Repo {
fn upsert_candle(&self, candle: &CandleExtended) -> AppResult<usize>;
#[async_trait]
pub trait Repo: Send + Sync {
async fn upsert_candles(&self, candles: &[CandleExtended]) -> AppResult<u64>;
fn insert_trade(&self, trade: &Trade) -> AppResult<usize>;
async fn insert_trade(&self, trade: &Trade) -> AppResult<u64>;
fn get_latest_candle_from_interval(
async fn get_latest_candle_from_interval(
&self,
pair: &str,
interval: CandleInterval,
) -> AppResult<CandleExtended>;
) -> AppResult<Option<CandleExtended>>;
}

@ -1,71 +1,27 @@
use std::{fs, path::Path, str::FromStr};
use chrono::DateTime;
use rusqlite::{self, params, Connection};
use sqlx::SqlitePool;
use crate::{
error::{AppError, AppResult},
models::{Candle, CandleExtended, CandleInterval, Trade},
models::{CandleExtended, CandleInterval, Trade},
};
use super::Repo;
pub struct SqliteRepo {
conn: Connection,
pool: SqlitePool,
}
impl SqliteRepo {
pub fn new_init(db_path: impl AsRef<Path>) -> AppResult<Self> {
let path = db_path.as_ref();
// постоянно создаём новую бд для упрощения тестирования
fs::remove_file(path).ok();
let conn = Connection::open(path)?;
conn.execute_batch(
"BEGIN;
CREATE TABLE IF NOT EXISTS trades(
symbol TEXT NOT NULL,
amount REAL NOT NULL,
taker_side TEXT NOT NULL,
quantity REAL NOT NULL,
create_time INT NOT NULL,
price REAL NOT NULL,
id TEXT NOT NULL,
ts INT NOT NULL
);
CREATE TABLE IF NOT EXISTS candles(
low REAL NOT NULL,
high REAL NOT NULL,
open REAL NOT NULL,
close REAL NOT NULL,
amount REAL NOT NULL,
quantity REAL NOT NULL,
buy_taker_amount REAL NOT NULL,
buy_taker_quantity REAL NOT NULL,
trade_count INT NOT NULL,
ts INT NOT NULL,
weighted_average REAL NOT NULL,
interval TEXT NOT NULL,
start_time INT NOT NULL,
close_time INT NOT NULL,
pair TEXT NOT NULL,
pub async fn new(db_url: &str) -> AppResult<Self> {
let pool = SqlitePool::connect(db_url).await?;
PRIMARY KEY(pair, interval, start_time)
);
COMMIT;",
)?;
Ok(Self { conn })
Ok(Self { pool })
}
}
#[async_trait]
impl Repo for SqliteRepo {
fn upsert_candle(&self, candle: &CandleExtended) -> AppResult<usize> {
async fn upsert_candles(&self, candles: &[CandleExtended]) -> AppResult<u64> {
let q = "
REPLACE INTO candles(
low,
@ -84,48 +40,53 @@ impl Repo for SqliteRepo {
close_time,
pair
) VALUES (
?1,
?2,
?3,
?4,
?5,
?6,
?7,
?8,
?9,
?10,
?11,
?12,
?13,
?14,
?15
$1,
$2,
$3,
$4,
$5,
$6,
$7,
$8,
$9,
$10,
$11,
$12,
$13,
$14,
$15
)
";
self.conn
.execute(
q,
params![
&candle.candle.low,
&candle.candle.high,
&candle.candle.open,
&candle.candle.close,
&candle.candle.amount,
&candle.candle.quantity,
&candle.candle.buy_taker_amount,
&candle.candle.buy_taker_quantity,
&candle.candle.trade_count,
&candle.candle.ts.and_utc().timestamp_millis(),
&candle.candle.weighted_average,
&candle.candle.interval.as_ref(),
&candle.candle.start_time.and_utc().timestamp_millis(),
&candle.candle.close_time.and_utc().timestamp_millis(),
&candle.pair
],
)
.map_err(AppError::from)
let mut tx = self.pool.begin().await?;
let mut affected = 0;
for candle in candles {
affected += sqlx::query(q)
.bind(&candle.candle.low)
.bind(&candle.candle.high)
.bind(&candle.candle.open)
.bind(&candle.candle.close)
.bind(&candle.candle.amount)
.bind(&candle.candle.quantity)
.bind(&candle.candle.buy_taker_amount)
.bind(&candle.candle.buy_taker_quantity)
.bind(&candle.candle.trade_count)
.bind(&candle.candle.ts)
.bind(&candle.candle.weighted_average)
.bind(&candle.candle.interval)
.bind(&candle.candle.start_time)
.bind(&candle.candle.close_time)
.bind(&candle.pair)
.execute(&mut *tx)
.await?
.rows_affected();
}
tx.commit().await?;
Ok(affected)
}
fn insert_trade(&self, trade: &Trade) -> AppResult<usize> {
async fn insert_trade(&self, trade: &Trade) -> AppResult<u64> {
let q = "
INSERT INTO trades(
symbol,
@ -137,72 +98,49 @@ impl Repo for SqliteRepo {
id,
ts
) VALUES (
?1,
?2,
?3,
?4,
?5,
?6,
?7,
?8
$1,
$2,
$3,
$4,
$5,
$6,
$7,
$8
);
";
self.conn
.execute(
&q,
params![
&trade.symbol,
&trade.amount,
&trade.taker_side.as_ref(),
&trade.quantity,
&trade.create_time.and_utc().timestamp_millis(),
&trade.price,
&trade.id,
&trade.ts.and_utc().timestamp_millis()
],
)
sqlx::query(&q)
.bind(&trade.symbol)
.bind(&trade.amount)
.bind(&trade.taker_side)
.bind(&trade.quantity)
.bind(&trade.create_time)
.bind(&trade.price)
.bind(&trade.id)
.bind(&trade.ts)
.execute(&self.pool)
.await
.map_err(AppError::from)
.map(|r| r.rows_affected())
}
fn get_latest_candle_from_interval(
async fn get_latest_candle_from_interval(
&self,
pair: &str,
interval: CandleInterval,
) -> AppResult<CandleExtended> {
) -> AppResult<Option<CandleExtended>> {
let q = "
SELECT * FROM candles
WHERE pair = ?1 AND interval = ?2
WHERE pair = $1 AND interval = $2
ORDER BY start_time DESC
LIMIT 1
";
self.conn
.query_row(&q, params![pair, interval.as_ref()], |row| {
Ok(CandleExtended {
candle: Candle {
low: row.get(0)?,
high: row.get(1)?,
open: row.get(2)?,
close: row.get(3)?,
amount: row.get(4)?,
quantity: row.get(5)?,
buy_taker_amount: row.get(6)?,
buy_taker_quantity: row.get(7)?,
trade_count: row.get(8)?,
ts: DateTime::from_timestamp(row.get(9)?, 0)
.unwrap()
.naive_local(),
weighted_average: row.get(10)?,
interval: FromStr::from_str(&row.get::<_, String>(11)?).unwrap(),
start_time: DateTime::from_timestamp(row.get(12)?, 0)
.unwrap()
.naive_local(),
close_time: DateTime::from_timestamp(row.get(13)?, 0)
.unwrap()
.naive_local(),
},
pair: row.get(14)?,
})
})
sqlx::query_as::<_, CandleExtended>(q)
.bind(pair)
.bind(interval)
.fetch_optional(&self.pool)
.await
.map_err(AppError::from)
}
}

Loading…
Cancel
Save