1. 我们会用到什么模型呢?
在之前的投票分类器中,我们用到了一个分类模型和一个回归模型。 在回归模型中,我们在用于计算分类时,用预测价格替代预测价格走势(下跌、上涨、不变)。 然而,在这种情况下,我们不能依据分类得到概率分布,而对于所谓的“软投票”这样是不允许的。
我们已准备了 3 个分类模型。 在“如何在 MQL5 中集成 ONNX 模型的示例”一文中已用到两个模型。 第一个模型(回归)被转换为分类模型。 基于 10 个 OHLC 价格序列进行了训练。 第二个模型是分类模型。 基于 63 个收盘价序列进行了训练。
//| https://www.mql5.com |
//+------------------------------------------------------------------+
//--- price movement prediction
//--- Class labels for the predicted price movement (up, same or down).
#define PRICE_UP   0
#define PRICE_SAME 1
#define PRICE_DOWN 2
//+------------------------------------------------------------------+
//| Base class for models based on trained symbol and period |
//+------------------------------------------------------------------+
//--- Holds the state shared by models that are tied to the symbol and
//--- timeframe of their training data.
class CModelSymbolPeriod
  {
protected:
   long              m_handle;      // created model session handle
   string            m_symbol;      // symbol of trained data
   ENUM_TIMEFRAMES   m_period;      // timeframe of trained data
   datetime          m_next_bar;    // time of next bar (we work at bar begin only)
   double            m_class_delta; // delta to recognize "price the same" in regression models
public:
//+------------------------------------------------------------------+
//| Constructor |
//+------------------------------------------------------------------+
//--- Stores the symbol and timeframe of the training data together with the
//--- delta used to recognize "price the same".
//---   symbol      - symbol of the trained data
//---   period      - timeframe of the trained data
//---   class_delta - "price the same" delta (defaults to 0.0001)
//--- No model is created here: the session handle starts as INVALID_HANDLE
//--- and the next-bar time starts at 0.
CModelSymbolPeriod(const string symbol,const ENUM_TIMEFRAMES period,const double class_delta=0.0001)
  {
   m_handle=INVALID_HANDLE;
   m_symbol=symbol;
   m_period=period;
   m_next_bar=0;
   m_class_delta=class_delta;
  }
//+------------------------------------------------------------------+
//| Destructor |
//--- NOTE: the destructor definition is absent from this excerpt
//| Check for initialization, create model |
//+------------------------------------------------------------------+
//--- Validates the requested symbol/timeframe against the ones this object
//--- was constructed with, then creates the ONNX session from a model buffer.
//---   symbol - symbol the caller wants to run the model on
//---   period - timeframe the caller wants to run the model on
//---   model  - buffer holding the ONNX model
//--- Returns true when the session handle was created; returns false (after
//--- printing the reason) on a symbol/period mismatch or a creation error.
bool CheckInit(const string symbol,const ENUM_TIMEFRAMES period,const uchar& model[])
  {
//--- check symbol, period
   if(symbol!=m_symbol || period!=m_period)
     {
      PrintFormat("Model must work with %s,%s",m_symbol,EnumToString(m_period));
      return(false);
     }
//--- create a model from static buffer
   m_handle=OnnxCreateFromBuffer(model,ONNX_DEFAULT);
   if(m_handle==INVALID_HANDLE)
     {
      Print("OnnxCreateFromBuffer error ",GetLastError());
      return(false);
     }
//--- ok
   return(true);
  }
//+------------------------------------------------------------------+
//--- NOTE: the header and opening lines of the following routine are missing from this excerpt
m_next_bar=TimeCurrent();8 w z, c* g! T, x
m_next_bar-=m_next_bar%PeriodSeconds(m_period);
) F2 y- j6 ~; e: xm_next_bar+=PeriodSeconds(m_period);8 a, o0 X( s0 y' k
//--- work on new day bar
7 E% G& j3 S% D, g0 n# @& |return(true);3 s/ } i9 q& Y% @! Q- H0 S
}
//+------------------------------------------------------------------+
//| virtual stub for PredictPrice (regression model) |
//+------------------------------------------------------------------+
//--- Base-class stub: always returns DBL_MAX.
//--- PredictClass() in this class compares the result with DBL_MAX and
//--- returns -1 in that case, so DBL_MAX acts as "no price prediction".
virtual double PredictPrice(void)
  {
   return(DBL_MAX);
  }
//+------------------------------------------------------------------+
//| Predict class (regression -> classification) |
//+------------------------------------------------------------------+
virtual int PredictClass(void)
) `5 j5 m( r7 D: o& K{
$ n, d+ w& W+ M# [- h! l; b' Cdouble predicted_price=PredictPrice();0 j$ e9 b: k# I* i! ^3 W0 `
if(predicted_price==DBL_MAX)) y( z" m& u) W, {% y
return(-1);
0 _- a6 W5 a% h& W: W% E* M+ L! Oint predicted_class=-1;
+ Q$ W) Z- [ t+ W3 c$ i6 ldouble last_close=iClose(m_symbol,m_period,1);, ?6 `& E- g/ O$ n
//--- classify predicted price movement; Y8 ]5 b9 K1 U
double delta=last_close-predicted_price;
5 y% ?, |" ]+ Q( j5 }6 Wif(fabs(delta)<=m_class_delta)
3 t0 V7 F* n( N9 i6 c: G6 w& y$ Cpredicted_class=PRICE_SAME;
# O8 Q; k1 a* m1 @else
7 F% S+ v, V, Dprivate:2 f2 @+ s. N M' W( K
int m_sample_size;
3 }" [0 I( s0 L//+------------------------------------------------------------------+; y9 g) D- b% c7 O' F
//--- Initializes the model for the given symbol and timeframe.
//--- Steps: (1) CheckInit() validates symbol/period and creates the session
//--- from the model_eurusd_D1_10_class buffer; (2) the input tensor shape is
//--- set to {1, m_sample_size, 4}; (3) the output tensor shape is set to {1, 3}.
//---   symbol - symbol the model is asked to work on
//---   period - timeframe the model is asked to work on
//--- Returns true when all three steps succeed; otherwise prints the failing
//--- step and returns false.
virtual bool Init(const string symbol, const ENUM_TIMEFRAMES period)
  {
//--- check symbol, period, create model
   if(!CModelSymbolPeriod::CheckInit(symbol,period,model_eurusd_D1_10_class))
     {
      Print("model_eurusd_D1_10_class : initialization error");
      return(false);
     }
//--- since not all sizes defined in the input tensor we must set them explicitly
//--- first index - batch size, second index - series size, third index - number of series (OHLC)
   const long input_shape[] = {1,m_sample_size,4};
   if(!OnnxSetInputShape(m_handle,0,input_shape))
     {
      Print("model_eurusd_D1_10_class : OnnxSetInputShape error ",GetLastError());
      return(false);
     }
//--- since not all sizes defined in the output tensor we must set them explicitly
//--- first index - batch size, must match the batch size of the input tensor
//--- second index - number of classes (up, same or down)
   const long output_shape[] = {1,3};
   if(!OnnxSetOutputShape(m_handle,0,output_shape))
     {
      Print("model_eurusd_D1_10_class : OnnxSetOutputShape error ",GetLastError());
      return(false);
     }
//--- ok
   return(true);
  }
//+------------------------------------------------------------------+
//| Predict class |
//+------------------------------------------------------------------+
3 _& B, \% _* R0 E/ s1 Kvirtual int PredictClass(void)
; |! _, k* v# W3 M4 N) _2 i{ |