You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
208 lines
6.1 KiB
208 lines
6.1 KiB
use std::{fs, path::Path, str::FromStr}; |
|
|
|
use chrono::DateTime; |
|
use rusqlite::{self, params, Connection}; |
|
|
|
use crate::{ |
|
error::{AppError, AppResult}, |
|
models::{Candle, CandleExtended, CandleInterval, Trade}, |
|
}; |
|
|
|
use super::Repo; |
|
|
|
pub struct SqliteRepo { |
|
conn: Connection, |
|
} |
|
|
|
impl SqliteRepo { |
|
pub fn new_init(db_path: impl AsRef<Path>) -> AppResult<Self> { |
|
let path = db_path.as_ref(); |
|
|
|
// постоянно создаём новую бд для упрощения тестирования |
|
fs::remove_file(path).ok(); |
|
|
|
let conn = Connection::open(path)?; |
|
|
|
conn.execute_batch( |
|
"BEGIN; |
|
|
|
CREATE TABLE IF NOT EXISTS trades( |
|
symbol TEXT NOT NULL, |
|
amount REAL NOT NULL, |
|
taker_side TEXT NOT NULL, |
|
quantity REAL NOT NULL, |
|
create_time INT NOT NULL, |
|
price REAL NOT NULL, |
|
id TEXT NOT NULL, |
|
ts INT NOT NULL |
|
); |
|
|
|
CREATE TABLE IF NOT EXISTS candles( |
|
low REAL NOT NULL, |
|
high REAL NOT NULL, |
|
open REAL NOT NULL, |
|
close REAL NOT NULL, |
|
amount REAL NOT NULL, |
|
quantity REAL NOT NULL, |
|
buy_taker_amount REAL NOT NULL, |
|
buy_taker_quantity REAL NOT NULL, |
|
trade_count INT NOT NULL, |
|
ts INT NOT NULL, |
|
weighted_average REAL NOT NULL, |
|
interval TEXT NOT NULL, |
|
start_time INT NOT NULL, |
|
close_time INT NOT NULL, |
|
pair TEXT NOT NULL, |
|
|
|
PRIMARY KEY(pair, interval, start_time) |
|
); |
|
|
|
COMMIT;", |
|
)?; |
|
|
|
Ok(Self { conn }) |
|
} |
|
} |
|
|
|
impl Repo for SqliteRepo { |
|
fn upsert_candle(&self, candle: &CandleExtended) -> AppResult<usize> { |
|
let q = " |
|
REPLACE INTO candles( |
|
low, |
|
high, |
|
open, |
|
close, |
|
amount, |
|
quantity, |
|
buy_taker_amount, |
|
buy_taker_quantity, |
|
trade_count, |
|
ts, |
|
weighted_average, |
|
interval, |
|
start_time, |
|
close_time, |
|
pair |
|
) VALUES ( |
|
?1, |
|
?2, |
|
?3, |
|
?4, |
|
?5, |
|
?6, |
|
?7, |
|
?8, |
|
?9, |
|
?10, |
|
?11, |
|
?12, |
|
?13, |
|
?14, |
|
?15 |
|
) |
|
"; |
|
self.conn |
|
.execute( |
|
q, |
|
params![ |
|
&candle.candle.low, |
|
&candle.candle.high, |
|
&candle.candle.open, |
|
&candle.candle.close, |
|
&candle.candle.amount, |
|
&candle.candle.quantity, |
|
&candle.candle.buy_taker_amount, |
|
&candle.candle.buy_taker_quantity, |
|
&candle.candle.trade_count, |
|
&candle.candle.ts.and_utc().timestamp_millis(), |
|
&candle.candle.weighted_average, |
|
&candle.candle.interval.as_ref(), |
|
&candle.candle.start_time.and_utc().timestamp_millis(), |
|
&candle.candle.close_time.and_utc().timestamp_millis(), |
|
&candle.pair |
|
], |
|
) |
|
.map_err(AppError::from) |
|
} |
|
|
|
fn insert_trade(&self, trade: &Trade) -> AppResult<usize> { |
|
let q = " |
|
INSERT INTO trades( |
|
symbol, |
|
amount, |
|
taker_side, |
|
quantity, |
|
create_time, |
|
price, |
|
id, |
|
ts |
|
) VALUES ( |
|
?1, |
|
?2, |
|
?3, |
|
?4, |
|
?5, |
|
?6, |
|
?7, |
|
?8 |
|
); |
|
"; |
|
|
|
self.conn |
|
.execute( |
|
&q, |
|
params![ |
|
&trade.symbol, |
|
&trade.amount, |
|
&trade.taker_side.as_ref(), |
|
&trade.quantity, |
|
&trade.create_time.and_utc().timestamp_millis(), |
|
&trade.price, |
|
&trade.id, |
|
&trade.ts.and_utc().timestamp_millis() |
|
], |
|
) |
|
.map_err(AppError::from) |
|
} |
|
|
|
fn get_latest_candle_from_interval( |
|
&self, |
|
pair: &str, |
|
interval: CandleInterval, |
|
) -> AppResult<CandleExtended> { |
|
let q = " |
|
SELECT * FROM candles |
|
WHERE pair = ?1 AND interval = ?2 |
|
"; |
|
|
|
self.conn |
|
.query_row(&q, params![pair, interval.as_ref()], |row| { |
|
Ok(CandleExtended { |
|
candle: Candle { |
|
low: row.get(0)?, |
|
high: row.get(1)?, |
|
open: row.get(2)?, |
|
close: row.get(3)?, |
|
amount: row.get(4)?, |
|
quantity: row.get(5)?, |
|
buy_taker_amount: row.get(6)?, |
|
buy_taker_quantity: row.get(7)?, |
|
trade_count: row.get(8)?, |
|
ts: DateTime::from_timestamp(row.get(9)?, 0) |
|
.unwrap() |
|
.naive_local(), |
|
weighted_average: row.get(10)?, |
|
interval: FromStr::from_str(&row.get::<_, String>(11)?).unwrap(), |
|
start_time: DateTime::from_timestamp(row.get(12)?, 0) |
|
.unwrap() |
|
.naive_local(), |
|
close_time: DateTime::from_timestamp(row.get(13)?, 0) |
|
.unwrap() |
|
.naive_local(), |
|
}, |
|
pair: row.get(14)?, |
|
}) |
|
}) |
|
.map_err(AppError::from) |
|
} |
|
}
|
|
|