A Simple LSTM-Based Stock Market Analysis and Prediction (Python)

The document presents a simple stock market analysis and prediction using LSTM in Python, focusing on tech stocks like Apple, Google, Microsoft, and Amazon. It details the process of data acquisition, visualization, and model training, including steps for calculating moving averages and daily returns. The LSTM model is built and trained to predict stock prices, with performance evaluated using root mean squared error.

Uploaded by

yl5404
2025/1/19 03:08 A Simple LSTM-Based Stock Market Analysis and Prediction (Python) - Zhihu

A Simple LSTM-Based Stock Market Analysis and Prediction (Python)
哥廷根数学学派
On modern signal processing, machine learning, deep learning, and fault diagnosis

import pandas as pd
import numpy as np

import matplotlib.pyplot as plt
import seaborn as sns
sns.set_style('whitegrid')
plt.style.use("fivethirtyeight")
%matplotlib inline

# For reading stock data from yahoo
from pandas_datareader.data import DataReader
import yfinance as yf
from pandas_datareader import data as pdr

yf.pdr_override()

# For time stamps
from datetime import datetime

https://zhuanlan.zhihu.com/p/712662485 1/17

# The tech stocks we'll use for this analysis
tech_list = ['AAPL', 'GOOG', 'MSFT', 'AMZN']

# Set up End and Start times for data grab
end = datetime.now()
start = datetime(end.year - 1, end.month, end.day)

for stock in tech_list:
    globals()[stock] = yf.download(stock, start, end)

company_list = [AAPL, GOOG, MSFT, AMZN]
company_name = ["APPLE", "GOOGLE", "MICROSOFT", "AMAZON"]

for company, com_name in zip(company_list, company_name):
    company["company_name"] = com_name

df = pd.concat(company_list, axis=0)
df.tail(10)
[*********************100%%**********************] 1 of 1 completed
[*********************100%%**********************] 1 of 1 completed
[*********************100%%**********************] 1 of 1 completed
[*********************100%%**********************] 1 of 1 completed

# Summary Stats
AAPL.describe()


# General info
AAPL.info()
<class 'pandas.core.frame.DataFrame'>
DatetimeIndex: 252 entries, 2023-06-05 to 2024-06-04
Data columns (total 7 columns):
 #   Column        Non-Null Count  Dtype
---  ------        --------------  -----
 0   Open          252 non-null    float64
 1   High          252 non-null    float64
 2   Low           252 non-null    float64
 3   Close         252 non-null    float64
 4   Adj Close     252 non-null    float64
 5   Volume        252 non-null    int64
 6   company_name  252 non-null    object
dtypes: float64(5), int64(1), object(1)
memory usage: 15.8+ KB
# Let's see a historical view of the closing price
plt.figure(figsize=(15, 10))
plt.subplots_adjust(top=1.25, bottom=1.2)

for i, company in enumerate(company_list, 1):
    plt.subplot(2, 2, i)
    company['Adj Close'].plot()
    plt.ylabel('Adj Close')
    plt.xlabel(None)
    plt.title(f"Closing Price of {tech_list[i - 1]}")

plt.tight_layout()


# Now let's plot the total volume of stock being traded each day
plt.figure(figsize=(15, 10))
plt.subplots_adjust(top=1.25, bottom=1.2)

for i, company in enumerate(company_list, 1):
    plt.subplot(2, 2, i)
    company['Volume'].plot()
    plt.ylabel('Volume')
    plt.xlabel(None)
    plt.title(f"Sales Volume for {tech_list[i - 1]}")

plt.tight_layout()

ma_day = [10, 20, 50]

# Add one moving-average column per window length to every company DataFrame
for ma in ma_day:
    for company in company_list:
        column_name = f"MA for {ma} days"
        company[column_name] = company['Adj Close'].rolling(ma).mean()
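
To make the rolling computation concrete, here is a minimal sketch on a made-up price series (the values and the `prices` name are illustrative assumptions, not data from the article). `rolling(n).mean()` leaves the first n-1 rows as NaN because a full window is not yet available, which is why a 50-day moving average only begins 49 trading days into the data.

```python
import pandas as pd

# Hypothetical closing prices, for illustration only
prices = pd.Series([10.0, 11.0, 12.0, 13.0, 14.0], name="Adj Close")

# 3-day simple moving average: mean of the current row and the 2 rows before it
ma_3 = prices.rolling(3).mean()

# The first two entries are NaN, the rest are 11.0, 12.0, 13.0
print(ma_3.tolist())
```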

fig, axes = plt.subplots(nrows=2, ncols=2)
fig.set_figheight(10)
fig.set_figwidth(15)

# Draw each company's price and moving averages on its own subplot
AAPL[['Adj Close', 'MA for 10 days', 'MA for 20 days', 'MA for 50 days']].plot(ax=axes[0,0])
axes[0,0].set_title('APPLE')

GOOG[['Adj Close', 'MA for 10 days', 'MA for 20 days', 'MA for 50 days']].plot(ax=axes[0,1])
axes[0,1].set_title('GOOGLE')

MSFT[['Adj Close', 'MA for 10 days', 'MA for 20 days', 'MA for 50 days']].plot(ax=axes[1,0])
axes[1,0].set_title('MICROSOFT')

AMZN[['Adj Close', 'MA for 10 days', 'MA for 20 days', 'MA for 50 days']].plot(ax=axes[1,1])
axes[1,1].set_title('AMAZON')

fig.tight_layout()

# We'll use pct_change to find the percent change for each day
for company in company_list:
    company['Daily Return'] = company['Adj Close'].pct_change()
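
As a quick check of what `pct_change` returns, the sketch below compares it with the explicit formula (today's price divided by yesterday's, minus one) on a made-up series. The numbers and variable names are assumptions for illustration only. Note that the result is a fraction (0.1 means 10%), and the first row is NaN because it has no previous day.

```python
import pandas as pd

# Hypothetical prices: up 10% on day 2, down 10% on day 3
prices = pd.Series([100.0, 110.0, 99.0])

daily_return = prices.pct_change()

# Same quantity written out by hand: p_t / p_(t-1) - 1
manual = prices / prices.shift(1) - 1

print(daily_return.round(4).tolist())
```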

# Then we'll plot the daily return percentage
fig, axes = plt.subplots(nrows=2, ncols=2)
fig.set_figheight(10)
fig.set_figwidth(15)

AAPL['Daily Return'].plot(ax=axes[0,0], legend=True, linestyle='--', marker='o')
axes[0,0].set_title('APPLE')

GOOG['Daily Return'].plot(ax=axes[0,1], legend=True, linestyle='--', marker='o')
axes[0,1].set_title('GOOGLE')

MSFT['Daily Return'].plot(ax=axes[1,0], legend=True, linestyle='--', marker='o')
axes[1,0].set_title('MICROSOFT')

AMZN['Daily Return'].plot(ax=axes[1,1], legend=True, linestyle='--', marker='o')
axes[1,1].set_title('AMAZON')

fig.tight_layout()

plt.figure(figsize=(12, 9))

# Histogram of daily returns, one subplot per company
for i, company in enumerate(company_list, 1):
    plt.subplot(2, 2, i)
    company['Daily Return'].hist(bins=50)
    plt.xlabel('Daily Return')
    plt.ylabel('Counts')
    plt.title(f'{company_name[i - 1]}')

plt.tight_layout()


# Grab all the closing prices for the tech stock list into one DataFrame
closing_df = pdr.get_data_yahoo(tech_list, start=start, end=end)['Adj Close']

# Make a new tech returns DataFrame
tech_rets = closing_df.pct_change()
tech_rets.head()

# Comparing Google to itself should show a perfectly linear relationship
sns.jointplot(x='GOOG', y='GOOG', data=tech_rets, kind='scatter', color='seagreen')


# We'll use jointplot to compare the daily returns of Google and Microsoft
sns.jointplot(x='GOOG', y='MSFT', data=tech_rets, kind='scatter')


# We can simply call pairplot on our DataFrame for an automatic visual analysis
# of all the comparisons

sns.pairplot(tech_rets, kind='reg')

# Set up our figure by naming it return_fig, call PairGrid on the returns DataFrame
return_fig = sns.PairGrid(tech_rets.dropna())

# Using map_upper we can specify what the upper triangle will look like.
return_fig.map_upper(plt.scatter, color='purple')

# We can also define the lower triangle in the figure, including the plot type
# or the color map (BluePurple)
return_fig.map_lower(sns.kdeplot, cmap='cool_d')

# Finally we'll define the diagonal as a series of histogram plots of the daily returns
return_fig.map_diag(plt.hist, bins=30)


# Set up our figure by naming it returns_fig, call PairGrid on the closing-price DataFrame
returns_fig = sns.PairGrid(closing_df)

# Using map_upper we can specify what the upper triangle will look like.
returns_fig.map_upper(plt.scatter, color='purple')

# We can also define the lower triangle in the figure, including the plot type
returns_fig.map_lower(sns.kdeplot, cmap='cool_d')

# Finally we'll define the diagonal as a series of histogram plots of the closing prices
returns_fig.map_diag(plt.hist, bins=30)


plt.figure(figsize=(12, 10))

plt.subplot(2, 2, 1)
sns.heatmap(tech_rets.corr(), annot=True, cmap='summer')
plt.title('Correlation of stock return')

plt.subplot(2, 2, 2)
sns.heatmap(closing_df.corr(), annot=True, cmap='summer')
plt.title('Correlation of stock closing price')

rets = tech_rets.dropna()

area = np.pi * 20

# Plot mean daily return (expected return) against its standard deviation (risk)
plt.figure(figsize=(10, 8))
plt.scatter(rets.mean(), rets.std(), s=area)
plt.xlabel('Expected return')
plt.ylabel('Risk')

for label, x, y in zip(rets.columns, rets.mean(), rets.std()):
    # The connectionstyle value is cut off in the source; 'arc3,rad=-0.3' is an assumed value
    plt.annotate(label, xy=(x, y), xytext=(50, 50), textcoords='offset points',
                 arrowprops=dict(arrowstyle='-', color='blue', connectionstyle='arc3,rad=-0.3'))

# Get the stock quote
df = pdr.get_data_yahoo('AAPL', start='2012-01-01', end=datetime.now())
# Show the data
df


plt.figure(figsize=(16,6))
plt.title('Close Price History')
plt.plot(df['Close'])
plt.xlabel('Date', fontsize=18)
plt.ylabel('Close Price USD ($)', fontsize=18)
plt.show()

# Create a new dataframe with only the 'Close' column
data = df.filter(['Close'])
# Convert the dataframe to a numpy array
dataset = data.values
# Get the number of rows to train the model on (the first 95% of the series)
training_data_len = int(np.ceil(len(dataset) * .95))

training_data_len
2969
# Scale the data
from sklearn.preprocessing import MinMaxScaler

scaler = MinMaxScaler(feature_range=(0,1))
scaled_data = scaler.fit_transform(dataset)

scaled_data

array([[0.00401431],
[0.00444289],
[0.00533302],
...,
[0.96818027],
[0.97784564],
[0.98387293]])
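
The cell above fits the scaler on the full series, so the minimum and maximum of the later test period influence how the training rows are scaled. One common alternative, sketched here on synthetic data as an assumption rather than as the article's method, is to fit `MinMaxScaler` on the training rows only and reuse it for the rest. The `*_demo` names are hypothetical.

```python
import numpy as np
from sklearn.preprocessing import MinMaxScaler

# Synthetic upward-trending "prices", standing in for the Close column
dataset_demo = np.arange(1.0, 101.0).reshape(-1, 1)
train_len = int(np.ceil(len(dataset_demo) * 0.95))  # same 95% split as in the article

scaler_demo = MinMaxScaler(feature_range=(0, 1))
scaler_demo.fit(dataset_demo[:train_len])           # statistics from training rows only
scaled_demo = scaler_demo.transform(dataset_demo)   # applied to the whole series

# Training rows stay within [0, 1]; later rows can exceed 1
print(scaled_demo[:train_len].max(), scaled_demo.max())
```

With this approach, test values may fall outside the [0, 1] range, which is expected when prices trend beyond what was seen during training.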
# Create the scaled training data set
train_data = scaled_data[0:int(training_data_len), :]
# Split the data into x_train and y_train data sets
x_train = []
y_train = []

# Each sample is the previous 60 scaled prices; the target is the next price
for i in range(60, len(train_data)):
    x_train.append(train_data[i-60:i, 0])
    y_train.append(train_data[i, 0])
    if i <= 61:
        print(x_train)
        print(y_train)
        print()

# Convert the x_train and y_train to numpy arrays
x_train, y_train = np.array(x_train), np.array(y_train)

# Reshape the data to (samples, timesteps, features)
x_train = np.reshape(x_train, (x_train.shape[0], x_train.shape[1], 1))
# x_train.shape
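
The loop above builds one 60-step window per target value. As a cross-check, the same windows can be produced without a Python loop. This sketch uses NumPy's `sliding_window_view` on a small synthetic array with a window of 3 instead of 60; the `make_windows` helper is hypothetical and introduced only for illustration.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

def make_windows(series, lookback):
    """Return (X, y): X[k] holds series[k:k+lookback], y[k] is series[k+lookback]."""
    series = np.asarray(series, dtype=float)
    # All windows of length `lookback`; drop the last one, which has no target
    X = sliding_window_view(series, lookback)[:-1]
    y = series[lookback:]
    # LSTM layers expect input shaped (samples, timesteps, features)
    return X.reshape(X.shape[0], lookback, 1), y

X_demo, y_demo = make_windows(np.arange(10), lookback=3)
print(X_demo.shape, y_demo.shape)
```

The first window here is [0, 1, 2] with target 3, mirroring how the article pairs prices `i-60 .. i-1` with the price at index `i`.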
from keras.models import Sequential
from keras.layers import Dense, LSTM

# Build the LSTM model
model = Sequential()
model.add(LSTM(128, return_sequences=True, input_shape=(x_train.shape[1], 1)))
model.add(LSTM(64, return_sequences=False))
model.add(Dense(25))
model.add(Dense(1))

# Compile the model
model.compile(optimizer='adam', loss='mean_squared_error')

# Train the model
model.fit(x_train, y_train, batch_size=1, epochs=1)
# Create the testing data set
# Create a new array containing scaled values from index training_data_len - 60 to the end,
# so that the first test sample has 60 preceding values available
test_data = scaled_data[training_data_len - 60:, :]
# Create the data sets x_test and y_test
x_test = []
y_test = dataset[training_data_len:, :]
for i in range(60, len(test_data)):
    x_test.append(test_data[i-60:i, 0])

# Convert the data to a numpy array
x_test = np.array(x_test)

# Reshape the data to (samples, timesteps, features)
x_test = np.reshape(x_test, (x_test.shape[0], x_test.shape[1], 1))


# Get the model's predicted price values
predictions = model.predict(x_test)
predictions = scaler.inverse_transform(predictions)

# Get the root mean squared error (RMSE)
rmse = np.sqrt(np.mean(((predictions - y_test) ** 2)))
rmse
10.469761767631676
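
An RMSE of about 10.47 is in the same units as the price (USD), and it is easier to interpret next to a reference point. The following is a minimal sketch, on made-up numbers, of the RMSE formula used above together with a naive "previous value" baseline. The `rmse_score` helper and all arrays are illustrative assumptions, not results from the article.

```python
import numpy as np

def rmse_score(y_true, y_pred):
    """Root mean squared error between two arrays of equal shape."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    if y_true.shape != y_pred.shape:
        # Mismatched shapes such as (n,) vs (n, 1) would silently broadcast to (n, n)
        raise ValueError("y_true and y_pred must have the same shape")
    return float(np.sqrt(np.mean((y_pred - y_true) ** 2)))

actual = np.array([10.0, 12.0, 11.0, 13.0])      # hypothetical prices
model_pred = np.array([11.0, 11.0, 12.0, 12.0])  # hypothetical model output
naive_pred = np.array([9.0, 10.0, 12.0, 11.0])   # previous day's price (9.0 precedes the window)

print(rmse_score(actual, model_pred), rmse_score(actual, naive_pred))
```

If a model's RMSE is not clearly below that of such a baseline on the same test rows, the model adds little beyond repeating the last observed price.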
# Plot the data
train = data[:training_data_len]
# Copy the slice so that adding a column does not trigger a SettingWithCopyWarning
valid = data[training_data_len:].copy()
valid['Predictions'] = predictions
# Visualize the data
plt.figure(figsize=(16,6))
plt.title('Model')
plt.xlabel('Date', fontsize=18)
plt.ylabel('Close Price USD ($)', fontsize=18)
plt.plot(train['Close'])
plt.plot(valid[['Close', 'Predictions']])
plt.legend(['Train', 'Val', 'Predictions'], loc='lower right')
plt.show()

# Show the valid and predicted prices
valid


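Zhihu academic consulting: https://www.zhihu.com/consult/people/792359672131756032?isMe=1

Serves as a reviewer for journals including Mechanical System and Signal Processing. Areas of expertise: modern signal processing, machine learning, deep learning, digital twins, time series analysis, equipment defect detection, equipment anomaly detection, and intelligent equipment fault diagnosis and prognostics and health management (PHM).
Edited 2024-08-04 03:26, IP location: Chongqing

LSTM, stock market analysis, financial time series analysis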
