Python基于dtw实现股票预测(多线程)

Posted 云山之巅

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了Python基于dtw实现股票预测(多线程)相关的知识,希望对你有一定的参考价值。

  1 # -*- coding: utf-8 -*-
  2 """
  3 Created on Tue Dec  4 08:53:08 2018
  4 
  5 @author: zhen
  6 """
  7 from dtw import fastdtw
  8 import matplotlib.pyplot as plt
  9 import numpy as np
 10 import pandas as pd
 11 import threading
 12 import time
 13 from datetime import datetime
 14 
 15 def normalization(x): # np.std:计算矩阵的标准差(方差的算术平方根)
 16     return (x - np.mean(x)) / np.std(x)
 17 
 18 def corrcoef(a,b):
 19     corrc = np.corrcoef(a,b) # 计算皮尔逊相关系数,用于度量两个变量之间的相关性,其值介于-1到1之间
 20     corrc = corrc[0,1]
 21     return (16 * ((1 - corrc) / (1 + corrc)) ** 1) # ** 表示乘方
 22         
 23 print("begin Main Thread")
 24 startTimeStamp = datetime.now() # 获取当前时间
 25 # 加载数据
 26 filename = \'C:/Users/zhen/.spyder-py3/sh000300_2017.csv\'
 27 # 获取第一,二列的数据
 28 all_date = pd.read_csv(filename,usecols=[0, 1], dtype = \'str\')
 29 all_date = np.array(all_date)
 30 data = all_date[:, 0]
 31 times = all_date[:, 1]
 32 
 33 data_points = pd.read_csv(filename,usecols=[3])
 34 data_points = np.array(data_points)
 35 data_points = data_points[:,0] #数据
 36 
 37 topk = 10 #只显示top-10
 38 baselen = 100 # 假设在50到150之间变化
 39 basebegin = 361
 40 basedata = data[basebegin]+\' \'+times[basebegin]+\'~\'+data[basebegin+baselen-1]+\' \'+times[basebegin+baselen-1]
 41 length = len(data_points) #数据长度
 42 
 43 # 定义自定义线程类
 44 class Thread_Local(threading.Thread):
 45     def __init__(self, thread_id, name, counter):
 46         threading.Thread.__init__(self)
 47         self.thread_id = thread_id
 48         self.name = name
 49         self.counter = counter
 50         self.__running = threading.Event() # 标识停止线程
 51         self.__running.set() # 设置为True
 52         
 53     def run(self):
 54         print("starting %s" % self.name)
 55         split_data(self, self.counter) # 执行代码逻辑
 56         
 57     def stop(self):
 58         self.__running.clear()
 59         
 60 # 分割片段并执行匹配,多线程
 61 def split_data(self, split_len):
 62     base = data_points[basebegin:basebegin+baselen]  # 获取初始要匹配的数据
 63     subseries = []
 64     dateseries = []
 65     for j in range(0, length): 
 66         if (j < (basebegin - split_len) or j > (basebegin + split_len - 1)) and j <length - split_len:
 67             subseries.append(data_points[j:j+split_len])
 68             dateseries.append(j) #开始位置
 69     search(self, subseries, base, dateseries)  # 调用模式匹配
 70 
 71 # 定义结果变量
 72 result = []  
 73 base_list = []
 74 date_list = []
 75 def search(self, subseries, base, dateseries):
 76      # 片段搜索
 77     listdistance = []
 78     for i in range(0, len(subseries)):
 79         tt = np.array(subseries[i])
 80         dist, cost, acc, path = fastdtw(base, tt, dist=\'euclidean\')
 81         listdistance.append(dist)
 82         # distance = corrcoef(base, tt)
 83         # listdistance.append(distance)
 84     # 排序
 85     index = np.argsort(listdistance, kind=\'quicksort\') #排序,返回排序后的索引序列
 86     result.append(subseries[index[0]])
 87     print("result length is %d" % len(result))
 88     base_list.append(base)
 89     date_list.append(dateseries[index[0]])
 90     # 关闭线程
 91     self.stop()
 92     
 93 # 变换数据(收缩或扩展),生成50到150之间的数据,间隔为10
 94 loc = 0
 95 for split_len in range(round(0.5 * baselen), round(1.5 * baselen), 10):
 96     # 执行匹配
 97    thread = Thread_Local(1, "Thread" + str(loc), split_len)
 98    loc += 1
 99    # 开启线程
100    thread.start()
101 
102 boo = 1
103 
104 while(boo > 0):
105     if(len(result) < 10):
106         if(boo % 100 == 0):
107             print("has running %d s" % boo)
108         boo += 1
109         time.sleep(1)
110     else:
111         boo = 0
112         
113  # 片段搜索
# Refuse to continue with missing or partial worker output: the three lists are
# indexed together below, so they must be non-empty and of equal length.
if not result:
    raise RuntimeError("no matching segment was found")
if not (len(result) == len(base_list) == len(date_list)):
    raise RuntimeError("worker results are incomplete")

# Re-score the best match of every window length against its base pattern.
listdistance = []
for base_series, candidate in zip(base_list, result):
    dist, *_ = fastdtw(base_series, np.array(candidate), dist='euclidean')
    listdistance.append(dist)

# Final ranking: indices of the matches in ascending order of distance.
index = np.argsort(listdistance, kind='quicksort')
print("closed Main Thread")
endTimeStamp = datetime.now()

best = index[0]
best_match = result[best]
best_start = date_list[best]
best_stop = best_start + len(best_match) - 1  # index of the match's last sample

# Plot the normalised base pattern against the normalised best match.
plt.figure(0)
plt.plot(normalization(base_list[best]), label=basedata, linewidth=2)
begin = data[best_start] + ' ' + times[best_start]
end = data[best_stop] + ' ' + times[best_stop]
label = begin + '~' + end
plt.plot(normalization(best_match), label=label, linewidth=2)
plt.legend(loc='lower right')
plt.title('normal similarity search')
plt.show()

# total_seconds() covers the whole interval; .seconds would drop whole days.
print('run time', int((endTimeStamp - startTimeStamp).total_seconds()), "s")

结果:

has running 100 s
has running 200 s
has running 300 s
has running 400 s
has running 500 s
has running 600 s
has running 700 s
has running 800 s
has running 900 s
has running 1000 s
has running 1100 s
has running 1200 s
has running 1300 s
has running 1400 s
has running 1500 s
has running 1600 s
has running 1700 s
has running 1800 s
has running 1900 s
has running 2000 s
has running 2100 s
has running 2200 s
has running 2300 s
has running 2400 s
has running 2500 s
has running 2600 s
has running 2700 s
has running 2800 s
has running 2900 s
has running 3000 s
has running 3100 s
has running 3200 s
has running 3300 s
has running 3400 s
has running 3500 s
has running 3600 s
has running 3700 s
has running 3800 s
has running 3900 s
has running 4000 s
has running 4100 s
has running 4200 s
has running 4300 s
has running 4400 s
has running 4500 s
has running 4600 s
has running 4700 s
has running 4800 s
has running 4900 s
has running 5000 s
has running 5100 s
has running 5200 s
has running 5300 s
has running 5400 s
has running 5500 s
has running 5600 s
has running 5700 s
has running 5800 s
has running 5900 s
has running 6000 s
has running 6100 s
has running 6200 s
has running 6300 s
has running 6400 s
has running 6500 s
has running 6600 s
has running 6700 s
has running 6800 s
has running 6900 s
has running 7000 s
has running 7100 s
has running 7200 s
has running 7300 s
has running 7400 s
has running 7500 s
has running 7600 s
has running 7700 s
has running 7800 s
has running 7900 s
has running 8000 s
has running 8100 s
has running 8200 s
has running 8300 s
has running 8400 s
has running 8500 s
has running 8600 s
has running 8700 s
has running 8800 s
has running 8900 s
has running 9000 s
has running 9100 s
has running 9200 s
has running 9300 s
has running 9400 s
has running 9500 s
has running 9600 s
has running 9700 s
has running 9800 s
has running 9900 s
has running 10000 s
has running 10100 s
has running 10200 s
has running 10300 s
has running 10400 s
has running 10500 s
has running 10600 s
has running 10700 s
has running 10800 s
has running 10900 s
has running 11000 s
has running 11100 s
has running 11200 s
has running 11300 s
has running 11400 s
has running 11500 s
has running 11600 s
has running 11700 s
has running 11800 s
has running 11900 s
has running 12000 s
has running 12100 s
has running 12200 s
has running 12300 s
has running 12400 s
has running 12500 s
has running 12600 s
has running 12700 s
has running 12800 s
has running 12900 s
has running 13000 s
has running 13100 s
has running 13200 s
has running 13300 s
has running 13400 s
has running 13500 s
has running 13600 s
has running 13700 s
has running 13800 s
has running 13900 s
has running 14000 s
has running 14100 s
has running 14200 s
has running 14300 s
has running 14400 s
result length is 1
result length is 2
has running 14500 s
has running 14600 s
has running 14700 s
has running 14800 s
result length is 3
has running 14900 s
has running 15000 s
result length is 4
has running 15100 s
has running 15200 s
has running 15300 s
has running 15400 s
result length is 5
has running 15500 s
has running 15600 s
has running 15700 s
has running 15800 s
has running 15900 s
has running 16000 s
has running 16100 s
has running 16200 s
result length is 6
has running 16300 s
has running 16400 s
has running 16500 s
has running 16600 s
result length is 7
result length is 8
has running 16700 s
result length is 9
result length is 10
closed Main Thread
View Code

 

以上是关于Python基于dtw实现股票预测(多线程)的主要内容,如果未能解决你的问题,请参考以下文章

使用 python yfinance 多线程下载雅虎股票历史

python深度学习之基于LSTM时间序列的股票价格预测

python多线程实现预测饭局组织者

干货基于TensorFlow卷积神经网络的短期股票预测

Python 多线程爬虫

DL之GRU:基于2022年6月最新上证指数数据集结合Pytorch框架利用GRU算法预测最新股票上证指数实现回归预测