1. 我们会用到什么模型呢?
在之前的投票分类器中,我们用到了一个分类模型和一个回归模型。 在回归模型中,我们在用于计算分类时,用预测价格替代预测价格走势(下跌、上涨、不变)。 然而,在这种情况下,我们不能依据分类得到概率分布,而对于所谓的“软投票”这样是不允许的。
我们已准备了 3 个分类模型。 在“如何在 MQL5 中集成 ONNX 模型的示例”一文中已用到两个模型。 第一个模型(回归)被转换为分类模型。 基于 10 个 OHLC 价格序列进行了训练。 第二个模型是分类模型。 基于 63 个收盘价序列进行了训练。
//| https://www.mql5.com |
//+------------------------------------------------------------------+
//--- price movement prediction
//--- integer class codes used to label the predicted direction of the price
#define PRICE_UP   0
#define PRICE_SAME 1
#define PRICE_DOWN 2
//+------------------------------------------------------------------+
//| Base class for models based on trained symbol and period |
//+------------------------------------------------------------------+
class CModelSymbolPeriod
  {
protected:
   long              m_handle;       // created model session handle
   string            m_symbol;       // symbol of trained data
   ENUM_TIMEFRAMES   m_period;       // timeframe of trained data
   datetime          m_next_bar;     // time of next bar (we work at bar begin only)
   double            m_class_delta;  // delta to recognize "price the same" in regression models

public:
//+------------------------------------------------------------------+
//| Constructor |
//+------------------------------------------------------------------+
//--- Remembers the symbol and timeframe passed in and the class_delta threshold.
//--- No model session is created here: the handle starts as INVALID_HANDLE
//--- and the next-bar time starts at 0.
CModelSymbolPeriod(const string symbol,const ENUM_TIMEFRAMES period,const double class_delta=0.0001)
  {
   m_handle=INVALID_HANDLE;
   m_symbol=symbol;
   m_period=period;
   m_next_bar=0;
   m_class_delta=class_delta;
  }
//+------------------------------------------------------------------+
//| Destructor |
//| Check for initialization, create model |
//+------------------------------------------------------------------+
//--- Compares the requested symbol/timeframe with the ones stored in this object,
//--- then creates a model session from the supplied buffer and stores it in m_handle.
//--- Returns false (after printing a diagnostic) on a mismatch or when the session
//--- cannot be created; returns true otherwise.
bool CheckInit(const string symbol,const ENUM_TIMEFRAMES period,const uchar& model[])
  {
   //--- check symbol, period
   if(symbol!=m_symbol || period!=m_period)
     {
      PrintFormat("Model must work with %s,%s",m_symbol,EnumToString(m_period));
      return(false);
     }
   //--- create a model from static buffer
   m_handle=OnnxCreateFromBuffer(model,ONNX_DEFAULT);
   if(m_handle==INVALID_HANDLE)
     {
      Print("OnnxCreateFromBuffer error ",GetLastError());
      return(false);
     }
   //--- ok
   return(true);
  }
//+------------------------------------------------------------------+
m_next_bar=TimeCurrent();
- R. z( o) v: Pm_next_bar-=m_next_bar%PeriodSeconds(m_period);% `( Y4 x: z, a
m_next_bar+=PeriodSeconds(m_period);
. K" f1 e" L4 M. }//--- work on new day bar
8 z0 t/ C$ @- Y8 h! e% v Z1 Kreturn(true);
7 e0 W' N/ w: p! _}8 [6 k' v3 _+ p
//+------------------------------------------------------------------+
//| virtual stub for PredictPrice (regression model) |
//+------------------------------------------------------------------+
//--- Default implementation: returns DBL_MAX, which PredictClass treats as
//--- "no price available" and maps to -1.
virtual double PredictPrice(void)
  {
   return(DBL_MAX);
  }
//+------------------------------------------------------------------+
//| Predict class (regression -> classification) |
//+------------------------------------------------------------------+
virtual int PredictClass(void)
7 l7 |1 U+ t9 T{
# o8 h$ \$ D% ?1 _0 |: x8 U3 Edouble predicted_price=PredictPrice();
, G5 Y) Q" C2 N% \if(predicted_price==DBL_MAX)
' z& {2 @/ T$ F+ rreturn(-1);5 s X3 K/ g! {* r
int predicted_class=-1;5 p, h. M0 k# D. Y# a
double last_close=iClose(m_symbol,m_period,1);( p& w h3 i( r0 \! W: d! R
//--- classify predicted price movement; b1 f/ R# v- M
double delta=last_close-predicted_price;
4 B! L' o7 c8 V: J/ z! @if(fabs(delta)<=m_class_delta): o+ w: F* C6 G: {* v
predicted_class=PRICE_SAME;
6 L+ @$ [/ l# x1 m `+ Welse9 I4 w) Y% f6 d7 ?
private:
   int               m_sample_size;  // series size; used as the second dimension of the input tensor shape
//+------------------------------------------------------------------+
//--- Creates the model session through CModelSymbolPeriod::CheckInit and then sets
//--- the input and output tensor shapes explicitly.
//--- Returns false (after printing a diagnostic) as soon as any step fails.
virtual bool Init(const string symbol, const ENUM_TIMEFRAMES period)
  {
   //--- check symbol, period, create model
   if(!CModelSymbolPeriod::CheckInit(symbol,period,model_eurusd_D1_10_class))
     {
      Print("model_eurusd_D1_10_class : initialization error");
      return(false);
     }
   //--- since not all sizes defined in the input tensor we must set them explicitly
   //--- first index - batch size, second index - series size, third index - number of series (OHLC)
   const long input_shape[] = {1,m_sample_size,4};
   if(!OnnxSetInputShape(m_handle,0,input_shape))
     {
      Print("model_eurusd_D1_10_class : OnnxSetInputShape error ",GetLastError());
      return(false);
     }
   //--- since not all sizes defined in the output tensor we must set them explicitly
   //--- first index - batch size, must match the batch size of the input tensor
   //--- second index - number of classes (up, same or down)
   const long output_shape[] = {1,3};
   if(!OnnxSetOutputShape(m_handle,0,output_shape))
     {
      Print("model_eurusd_D1_10_class : OnnxSetOutputShape error ",GetLastError());
      return(false);
     }
   //--- ok
   return(true);
  }
//+------------------------------------------------------------------+
//| Predict class |
//+------------------------------------------------------------------+
( X: [# X/ Z* A8 n- d. N+ ]1 ^9 V: J mvirtual int PredictClass(void)! A L; F+ ]1 B) \" Q
{ |