migrate to sqlx

master
plazmoid 3 months ago
parent 65cb0d875b
commit ee7e437357
  1. 2
      .env
  2. 769
      Cargo.lock
  3. 2
      Cargo.toml
  4. 31
      migrations/20250210064035_initial.sql
  5. 2
      src/config.rs
  6. 2
      src/error.rs
  7. 116
      src/main.rs
  8. 12
      src/models.rs
  9. 11
      src/repos/mod.rs
  10. 224
      src/repos/sqlite.rs

@ -1,6 +1,6 @@
export PAIRS="BTC_USDT,TRX_USDT,ETH_USDT,DOGE_USDT,BCH_USDT" export PAIRS="BTC_USDT,TRX_USDT,ETH_USDT,DOGE_USDT,BCH_USDT"
export INTERVALS="MINUTE_1,MINUTE_15,HOUR_1,DAY_1" export INTERVALS="MINUTE_1,MINUTE_15,HOUR_1,DAY_1"
export DB_NAME="./poloniex_data.db" export DATABASE_URL="sqlite://./poloniex_data.db"
export POLONIEX_REST_URL="https://api.poloniex.com" export POLONIEX_REST_URL="https://api.poloniex.com"
export POLONIEX_WS_URL="wss://ws.poloniex.com/ws/public" export POLONIEX_WS_URL="wss://ws.poloniex.com/ws/public"

769
Cargo.lock generated

File diff suppressed because it is too large Load Diff

@ -12,10 +12,10 @@ envy = "0.4.2"
futures-util = "0.3.31" futures-util = "0.3.31"
reqwest = { version = "0.12.12", features = ["json"] } reqwest = { version = "0.12.12", features = ["json"] }
reqwest-websocket = "0.4.4" reqwest-websocket = "0.4.4"
rusqlite = "0.33.0"
serde = { version = "1.0.217", features = ["derive"] } serde = { version = "1.0.217", features = ["derive"] }
serde_json = "1.0.138" serde_json = "1.0.138"
serde_tuple = "1.1.0" serde_tuple = "1.1.0"
sqlx = { version = "0.8.3", features = ["chrono", "runtime-tokio", "sqlite"] }
strum = { version = "0.26.3", features = ["derive"] } strum = { version = "0.26.3", features = ["derive"] }
thiserror = "2.0.11" thiserror = "2.0.11"
tokio = { version = "1.43.0", features = ["rt-multi-thread", "macros"] } tokio = { version = "1.43.0", features = ["rt-multi-thread", "macros"] }

@ -0,0 +1,31 @@
-- Add migration script here
CREATE TABLE IF NOT EXISTS trades(
symbol TEXT NOT NULL,
amount REAL NOT NULL,
taker_side TEXT NOT NULL,
quantity REAL NOT NULL,
create_time DATETIME NOT NULL,
price REAL NOT NULL,
id TEXT NOT NULL,
ts INT NOT NULL
);
CREATE TABLE IF NOT EXISTS candles(
pair TEXT NOT NULL,
interval TEXT NOT NULL,
low REAL NOT NULL,
high REAL NOT NULL,
open REAL NOT NULL,
close REAL NOT NULL,
amount REAL NOT NULL,
quantity REAL NOT NULL,
buy_taker_amount REAL NOT NULL,
buy_taker_quantity REAL NOT NULL,
trade_count INT NOT NULL,
ts INT NOT NULL,
weighted_average REAL NOT NULL,
start_time DATETIME NOT NULL,
close_time DATETIME NOT NULL,
PRIMARY KEY(pair, interval, start_time)
);

@ -14,7 +14,7 @@ pub struct Config {
pub poloniex_rest_url: Url, pub poloniex_rest_url: Url,
#[serde(deserialize_with = "deser_url")] #[serde(deserialize_with = "deser_url")]
pub poloniex_ws_url: Url, pub poloniex_ws_url: Url,
pub db_name: String, pub database_url: String,
} }
fn deser_url<'de, D: Deserializer<'de>>(deserialize: D) -> Result<Url, D::Error> { fn deser_url<'de, D: Deserializer<'de>>(deserialize: D) -> Result<Url, D::Error> {

@ -17,7 +17,7 @@ pub enum AppError {
SerdeError(#[from] serde_json::Error), SerdeError(#[from] serde_json::Error),
#[error(transparent)] #[error(transparent)]
DbError(#[from] rusqlite::Error), DbError(#[from] sqlx::Error),
#[error(transparent)] #[error(transparent)]
StrumError(#[from] strum::ParseError), StrumError(#[from] strum::ParseError),

@ -5,7 +5,7 @@ use config::get_config;
use error::AppResult; use error::AppResult;
use futures_util::{future::try_join_all, StreamExt}; use futures_util::{future::try_join_all, StreamExt};
use markets::{poloniex::PoloniexClient, Market}; use markets::{poloniex::PoloniexClient, Market};
use models::{Candle, CandleExtended, CandleInterval, TradeDirection}; use models::{Candle, CandleExtended, CandleInterval, Trade, TradeDirection};
use repos::{sqlite::SqliteRepo, Repo}; use repos::{sqlite::SqliteRepo, Repo};
mod config; mod config;
@ -27,7 +27,11 @@ async fn fetch_candles_until_now(
let limit = 500; let limit = 500;
loop { loop {
println!("pulling candles from {start_time}"); println!(
"pulling {}:{} candles from {start_time}",
pair,
interval.as_ref()
);
let candles = market_client let candles = market_client
.get_historical_candles(&pair, interval, start_time, Utc::now().naive_utc(), limit) .get_historical_candles(&pair, interval, start_time, Utc::now().naive_utc(), limit)
.await?; .await?;
@ -64,31 +68,14 @@ async fn fetch_candles_until_now(
Ok((result, pair.to_string())) Ok((result, pair.to_string()))
} }
async fn trades_processor( async fn calculate_new_candles(
repo: Arc<impl Repo>, repo: Arc<impl Repo>,
market_client: Arc<impl Market>,
pairs: &[String],
interval: CandleInterval, interval: CandleInterval,
trade: Trade,
) -> AppResult<()> { ) -> AppResult<()> {
let mut trades = market_client.recent_trades_stream(&pairs).await?;
while let Some(t) = trades.next().await {
println!("{t:?}");
let Ok(trade) = t else { break };
let mut last_candle = repo.get_latest_candle_from_interval(&trade.symbol, interval)?;
let interval_delta = match last_candle.candle.interval {
CandleInterval::M1 => TimeDelta::minutes(1),
CandleInterval::M15 => TimeDelta::minutes(15),
CandleInterval::H1 => TimeDelta::hours(1),
CandleInterval::D1 => TimeDelta::days(1),
};
let is_buy = matches!(trade.taker_side, TradeDirection::Buy); let is_buy = matches!(trade.taker_side, TradeDirection::Buy);
let insert_new_candle = || async {
// если трейд не входит в интервал последней свечи, то создаём новую свечу, иначе обновляем предыдущую let interval_secs = match interval {
if trade.ts > (last_candle.candle.ts + interval_delta) {
let interval_secs = match last_candle.candle.interval {
CandleInterval::M1 => 60, CandleInterval::M1 => 60,
CandleInterval::M15 => 60 * 15, CandleInterval::M15 => 60 * 15,
CandleInterval::H1 => 60 * 60, CandleInterval::H1 => 60 * 60,
@ -121,7 +108,26 @@ async fn trades_processor(
pair: trade.symbol.clone(), pair: trade.symbol.clone(),
}; };
repo.upsert_candle(&new_candle)?; repo.upsert_candles(&[new_candle]).await?;
AppResult::Ok(())
};
let last_candle = repo
.get_latest_candle_from_interval(&trade.symbol, interval)
.await?;
if let Some(mut last_candle) = last_candle {
let interval_delta = match last_candle.candle.interval {
CandleInterval::M1 => TimeDelta::minutes(1),
CandleInterval::M15 => TimeDelta::minutes(15),
CandleInterval::H1 => TimeDelta::hours(1),
CandleInterval::D1 => TimeDelta::days(1),
};
// если трейд не входит в интервал последней свечи, то создаём новую свечу, иначе обновляем предыдущую
if trade.ts > (last_candle.candle.ts + interval_delta) {
insert_new_candle().await?;
} else { } else {
last_candle.candle.low = last_candle.candle.low.min(trade.price); last_candle.candle.low = last_candle.candle.low.min(trade.price);
last_candle.candle.high = last_candle.candle.high.max(trade.price); last_candle.candle.high = last_candle.candle.high.max(trade.price);
@ -140,12 +146,14 @@ async fn trades_processor(
last_candle.candle.buy_taker_quantity += trade.quantity; last_candle.candle.buy_taker_quantity += trade.quantity;
} }
repo.upsert_candle(&last_candle)?; repo.upsert_candles(&[last_candle]).await?;
} }
} else {
repo.insert_trade(&trade)?; insert_new_candle().await?;
} }
repo.insert_trade(&trade).await?;
Ok(()) Ok(())
} }
@ -157,9 +165,8 @@ async fn _main() -> AppResult<()> {
&config.poloniex_rest_url, &config.poloniex_rest_url,
&config.poloniex_ws_url, &config.poloniex_ws_url,
)); ));
let repo = Arc::new(SqliteRepo::new_init(config.db_name)?); let repo = Arc::new(SqliteRepo::new(&config.database_url).await?);
let base_start_time = NaiveDate::from_ymd_opt(2024, 12, 1)
let start_time = NaiveDate::from_ymd_opt(2024, 12, 1)
.unwrap() .unwrap()
.and_hms_opt(0, 0, 0) .and_hms_opt(0, 0, 0)
.unwrap(); .unwrap();
@ -168,6 +175,17 @@ async fn _main() -> AppResult<()> {
for pair in &config.pairs { for pair in &config.pairs {
for interval in &config.intervals { for interval in &config.intervals {
let start_time = {
let last_candle = repo
.get_latest_candle_from_interval(&pair, *interval)
.await?;
match last_candle {
Some(c) => c.candle.close_time + TimeDelta::seconds(1),
None => base_start_time,
}
};
let fetcher = fetch_candles_until_now( let fetcher = fetch_candles_until_now(
poloniex_client.clone(), poloniex_client.clone(),
pair.to_string(), pair.to_string(),
@ -185,38 +203,48 @@ async fn _main() -> AppResult<()> {
// config.interval.as_ref() // config.interval.as_ref()
// ); // );
// нельзя так делать, нужно использовать транзакцию let candles_to_upsert = fetched_candles
// и батч-вставку для уменьшения количества обращений к бд, .into_iter()
// но в контексте тестового и так сойдёт .flat_map(|(candles, pair)| {
for (candles, pair) in fetched_candles { candles
for candle in candles { .into_iter()
repo.upsert_candle(&CandleExtended { .map(|candle| CandleExtended {
candle, candle,
pair: pair.clone(), pair: pair.clone(),
})?; })
} .collect::<Vec<_>>()
} })
.collect::<Vec<CandleExtended>>();
repo.upsert_candles(&candles_to_upsert).await?;
let mut trades = poloniex_client.recent_trades_stream(&config.pairs).await?;
while let Some(t) = trades.next().await {
println!("{t:?}");
let Ok(trade) = t else { break };
for interval in &config.intervals { for interval in &config.intervals {
tokio::spawn({ tokio::spawn({
let poloniex_client = poloniex_client.clone();
let repo = repo.clone(); let repo = repo.clone();
let pairs = config.pairs.clone();
let interval = *interval; let interval = *interval;
let trade = trade.clone();
async move { async move {
let result = trades_processor(repo, poloniex_client, &pairs, interval).await; let result = calculate_new_candles(repo, interval, trade).await;
if let Err(e) = result { if let Err(e) = result {
eprintln!("processor stopped with error: {e}") eprintln!("processor stopped with error: {e:?}")
} }
} }
}); });
} }
}
Ok(()) Ok(())
} }
#[tokio::main] #[tokio::main]
async fn main() { async fn main() {
if let Err(e) = _main().await { if let Err(e) = _main().await {
eprintln!("{e}"); eprintln!("{e:?}");
} }
} }

@ -1,8 +1,9 @@
use chrono::{DateTime, NaiveDateTime}; use chrono::{DateTime, NaiveDateTime};
use serde::{Deserialize, Deserializer, Serialize}; use serde::{Deserialize, Deserializer, Serialize};
use serde_tuple::Deserialize_tuple; use serde_tuple::Deserialize_tuple;
use sqlx::prelude::{FromRow, Type};
#[derive(strum::EnumString, strum::AsRefStr, Clone, Copy, Deserialize, Debug)] #[derive(strum::EnumString, strum::AsRefStr, Clone, Copy, Deserialize, Debug, Type)]
pub enum CandleInterval { pub enum CandleInterval {
#[strum(serialize = "MINUTE_1")] #[strum(serialize = "MINUTE_1")]
#[serde(rename = "MINUTE_1")] #[serde(rename = "MINUTE_1")]
@ -18,7 +19,7 @@ pub enum CandleInterval {
D1, D1,
} }
#[derive(Debug, Deserialize_tuple)] #[derive(Debug, Deserialize_tuple, FromRow)]
pub struct Candle { pub struct Candle {
#[serde(deserialize_with = "deser_str_to_int")] #[serde(deserialize_with = "deser_str_to_int")]
pub low: f64, pub low: f64,
@ -48,7 +49,9 @@ pub struct Candle {
pub close_time: NaiveDateTime, pub close_time: NaiveDateTime,
} }
#[derive(FromRow)]
pub struct CandleExtended { pub struct CandleExtended {
#[sqlx(flatten)]
pub candle: Candle, pub candle: Candle,
pub pair: String, pub pair: String,
} }
@ -66,7 +69,7 @@ fn deser_naive_dt<'de, D: Deserializer<'de>>(deserialize: D) -> Result<NaiveDate
.map(|dt| dt.naive_utc()) .map(|dt| dt.naive_utc())
} }
#[derive(Deserialize, Debug)] #[derive(Deserialize, Debug, Clone)]
#[serde(rename_all = "camelCase")] #[serde(rename_all = "camelCase")]
pub struct Trade { pub struct Trade {
pub symbol: String, pub symbol: String,
@ -84,8 +87,9 @@ pub struct Trade {
pub ts: NaiveDateTime, pub ts: NaiveDateTime,
} }
#[derive(Deserialize, Debug, strum::AsRefStr)] #[derive(Deserialize, Debug, strum::AsRefStr, Clone, Copy, Type)]
#[serde(rename_all = "camelCase")] #[serde(rename_all = "camelCase")]
#[sqlx(rename_all = "camelCase")]
pub enum TradeDirection { pub enum TradeDirection {
Buy, Buy,
Sell, Sell,

@ -5,14 +5,15 @@ use crate::{
pub mod sqlite; pub mod sqlite;
pub trait Repo { #[async_trait]
fn upsert_candle(&self, candle: &CandleExtended) -> AppResult<usize>; pub trait Repo: Send + Sync {
async fn upsert_candles(&self, candles: &[CandleExtended]) -> AppResult<u64>;
fn insert_trade(&self, trade: &Trade) -> AppResult<usize>; async fn insert_trade(&self, trade: &Trade) -> AppResult<u64>;
fn get_latest_candle_from_interval( async fn get_latest_candle_from_interval(
&self, &self,
pair: &str, pair: &str,
interval: CandleInterval, interval: CandleInterval,
) -> AppResult<CandleExtended>; ) -> AppResult<Option<CandleExtended>>;
} }

@ -1,71 +1,27 @@
use std::{fs, path::Path, str::FromStr}; use sqlx::SqlitePool;
use chrono::DateTime;
use rusqlite::{self, params, Connection};
use crate::{ use crate::{
error::{AppError, AppResult}, error::{AppError, AppResult},
models::{Candle, CandleExtended, CandleInterval, Trade}, models::{CandleExtended, CandleInterval, Trade},
}; };
use super::Repo; use super::Repo;
pub struct SqliteRepo { pub struct SqliteRepo {
conn: Connection, pool: SqlitePool,
} }
impl SqliteRepo { impl SqliteRepo {
pub fn new_init(db_path: impl AsRef<Path>) -> AppResult<Self> { pub async fn new(db_url: &str) -> AppResult<Self> {
let path = db_path.as_ref(); let pool = SqlitePool::connect(db_url).await?;
// постоянно создаём новую бд для упрощения тестирования
fs::remove_file(path).ok();
let conn = Connection::open(path)?;
conn.execute_batch(
"BEGIN;
CREATE TABLE IF NOT EXISTS trades(
symbol TEXT NOT NULL,
amount REAL NOT NULL,
taker_side TEXT NOT NULL,
quantity REAL NOT NULL,
create_time INT NOT NULL,
price REAL NOT NULL,
id TEXT NOT NULL,
ts INT NOT NULL
);
CREATE TABLE IF NOT EXISTS candles(
low REAL NOT NULL,
high REAL NOT NULL,
open REAL NOT NULL,
close REAL NOT NULL,
amount REAL NOT NULL,
quantity REAL NOT NULL,
buy_taker_amount REAL NOT NULL,
buy_taker_quantity REAL NOT NULL,
trade_count INT NOT NULL,
ts INT NOT NULL,
weighted_average REAL NOT NULL,
interval TEXT NOT NULL,
start_time INT NOT NULL,
close_time INT NOT NULL,
pair TEXT NOT NULL,
PRIMARY KEY(pair, interval, start_time)
);
COMMIT;",
)?;
Ok(Self { conn }) Ok(Self { pool })
} }
} }
#[async_trait]
impl Repo for SqliteRepo { impl Repo for SqliteRepo {
fn upsert_candle(&self, candle: &CandleExtended) -> AppResult<usize> { async fn upsert_candles(&self, candles: &[CandleExtended]) -> AppResult<u64> {
let q = " let q = "
REPLACE INTO candles( REPLACE INTO candles(
low, low,
@ -84,48 +40,53 @@ impl Repo for SqliteRepo {
close_time, close_time,
pair pair
) VALUES ( ) VALUES (
?1, $1,
?2, $2,
?3, $3,
?4, $4,
?5, $5,
?6, $6,
?7, $7,
?8, $8,
?9, $9,
?10, $10,
?11, $11,
?12, $12,
?13, $13,
?14, $14,
?15 $15
) )
"; ";
self.conn
.execute( let mut tx = self.pool.begin().await?;
q, let mut affected = 0;
params![ for candle in candles {
&candle.candle.low, affected += sqlx::query(q)
&candle.candle.high, .bind(&candle.candle.low)
&candle.candle.open, .bind(&candle.candle.high)
&candle.candle.close, .bind(&candle.candle.open)
&candle.candle.amount, .bind(&candle.candle.close)
&candle.candle.quantity, .bind(&candle.candle.amount)
&candle.candle.buy_taker_amount, .bind(&candle.candle.quantity)
&candle.candle.buy_taker_quantity, .bind(&candle.candle.buy_taker_amount)
&candle.candle.trade_count, .bind(&candle.candle.buy_taker_quantity)
&candle.candle.ts.and_utc().timestamp_millis(), .bind(&candle.candle.trade_count)
&candle.candle.weighted_average, .bind(&candle.candle.ts)
&candle.candle.interval.as_ref(), .bind(&candle.candle.weighted_average)
&candle.candle.start_time.and_utc().timestamp_millis(), .bind(&candle.candle.interval)
&candle.candle.close_time.and_utc().timestamp_millis(), .bind(&candle.candle.start_time)
&candle.pair .bind(&candle.candle.close_time)
], .bind(&candle.pair)
) .execute(&mut *tx)
.map_err(AppError::from) .await?
.rows_affected();
}
tx.commit().await?;
Ok(affected)
} }
fn insert_trade(&self, trade: &Trade) -> AppResult<usize> { async fn insert_trade(&self, trade: &Trade) -> AppResult<u64> {
let q = " let q = "
INSERT INTO trades( INSERT INTO trades(
symbol, symbol,
@ -137,72 +98,49 @@ impl Repo for SqliteRepo {
id, id,
ts ts
) VALUES ( ) VALUES (
?1, $1,
?2, $2,
?3, $3,
?4, $4,
?5, $5,
?6, $6,
?7, $7,
?8 $8
); );
"; ";
self.conn sqlx::query(&q)
.execute( .bind(&trade.symbol)
&q, .bind(&trade.amount)
params![ .bind(&trade.taker_side)
&trade.symbol, .bind(&trade.quantity)
&trade.amount, .bind(&trade.create_time)
&trade.taker_side.as_ref(), .bind(&trade.price)
&trade.quantity, .bind(&trade.id)
&trade.create_time.and_utc().timestamp_millis(), .bind(&trade.ts)
&trade.price, .execute(&self.pool)
&trade.id, .await
&trade.ts.and_utc().timestamp_millis()
],
)
.map_err(AppError::from) .map_err(AppError::from)
.map(|r| r.rows_affected())
} }
fn get_latest_candle_from_interval( async fn get_latest_candle_from_interval(
&self, &self,
pair: &str, pair: &str,
interval: CandleInterval, interval: CandleInterval,
) -> AppResult<CandleExtended> { ) -> AppResult<Option<CandleExtended>> {
let q = " let q = "
SELECT * FROM candles SELECT * FROM candles
WHERE pair = ?1 AND interval = ?2 WHERE pair = $1 AND interval = $2
ORDER BY start_time DESC
LIMIT 1
"; ";
self.conn sqlx::query_as::<_, CandleExtended>(q)
.query_row(&q, params![pair, interval.as_ref()], |row| { .bind(pair)
Ok(CandleExtended { .bind(interval)
candle: Candle { .fetch_optional(&self.pool)
low: row.get(0)?, .await
high: row.get(1)?,
open: row.get(2)?,
close: row.get(3)?,
amount: row.get(4)?,
quantity: row.get(5)?,
buy_taker_amount: row.get(6)?,
buy_taker_quantity: row.get(7)?,
trade_count: row.get(8)?,
ts: DateTime::from_timestamp(row.get(9)?, 0)
.unwrap()
.naive_local(),
weighted_average: row.get(10)?,
interval: FromStr::from_str(&row.get::<_, String>(11)?).unwrap(),
start_time: DateTime::from_timestamp(row.get(12)?, 0)
.unwrap()
.naive_local(),
close_time: DateTime::from_timestamp(row.get(13)?, 0)
.unwrap()
.naive_local(),
},
pair: row.get(14)?,
})
})
.map_err(AppError::from) .map_err(AppError::from)
} }
} }

Loading…
Cancel
Save