Stable Baselines: DQN Not performing properly?
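【Title】Stable Baselines: DQN Not performing properly? 【Posted】2021-02-17 13:47:08 【Question】I am having trouble using DQN with a diagonal line and a sine wave as the price movement. When the price goes up, there is a reward, and that step is colored green in the chart. When the price goes down and is marked red, the reward goes up. Please see this link: the DQN in the link learns much better than the stable-baselines DQN.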
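Even with a diagonal line, I am struggling with DQN.
Sine wave: it would be great if the result were the opposite, with green for rising and red for falling.
What I have tried is changing the learning rate from 0.01 to 10 and changing epsilon to 1.
With PPO2 I can get a decent result. For the sine wave: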
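To make the effect of setting epsilon to 1 concrete, here is a minimal sketch of textbook epsilon-greedy action selection. The function name, the Q-values, and the seed are hypothetical and are not taken from stable-baselines; the sketch only assumes the standard definition in which epsilon is the probability of taking a random action.

```python
import random

def epsilon_greedy(q_values, epsilon, rng):
    """Pick a random action with probability epsilon, else the greedy one."""
    if rng.random() < epsilon:
        return rng.randrange(len(q_values))
    return max(range(len(q_values)), key=lambda a: q_values[a])

rng = random.Random(0)
q_values = [0.1, 0.9]  # hypothetical Q-values: action 1 looks better

# With epsilon = 0 the greedy action is always chosen.
greedy_picks = [epsilon_greedy(q_values, 0.0, rng) for _ in range(1000)]

# With epsilon = 1 every action is random, regardless of the Q-values.
random_picks = [epsilon_greedy(q_values, 1.0, rng) for _ in range(1000)]
share_of_action_1 = sum(random_picks) / len(random_picks)
```

Under that definition, a final epsilon of 1 means the agent never exploits what it has learned while it is collecting experience, which may be worth ruling out before comparing DQN against PPO2.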
model = PPO2(MlpPolicy, env, verbose=1,learning_rate=.01)
model.learn(total_timesteps=500000)
For the diagonal line it works as well!
Here is my code. Just comment and uncomment the relevant lines to test PPO2 versus DQN.
from copy import deepcopy
import numpy as np
import pandas as pd
import gym
import gym_anytrading
from stable_baselines import A2C , DQN ,ACKTR
from stable_baselines.common.vec_env import DummyVecEnv
from stable_baselines.deepq.policies import MlpPolicy
import matplotlib.pyplot as plt
import math as m
from stable_baselines.deepq.policies import FeedForwardPolicy
from stable_baselines.common.policies import MlpPolicy
from stable_baselines.common import make_vec_env
from stable_baselines import PPO2
class CustomDQNPolicy(FeedForwardPolicy):
    def __init__(self, *args, **kwargs):
        super(CustomDQNPolicy, self).__init__(*args, **kwargs,
                                              layers=[64,64,64],
                                              layer_norm=True,
                                              feature_extraction="mlp")
def main():
    n_cpu = 16
    # df = gym_anytrading.datasets.STOCKS_GOOGL.copy()
    # print(df)
    # arraysin =[]
    # for x in range(0,200,1):
    #     arraysin = np.append(arraysin,(m.sin(x/10)+1))
    # print(arraysin)
    arraysin = np.arange(200/10.0) #linearly increasing prices
    df = pd.DataFrame(arraysin)
    # # convert the column (it's a string) to datetime type
    # datetime_series = pd.to_datetime(df['date_of_birth'])
    # # create datetime index passing the datetime series
    # datetime_index = pd.DatetimeIndex(datetime_series.values)
    df = pd.DataFrame(arraysin)
    print(df)
    df.columns=['Close']
    # df=df.set_index(datetime_index)
    window_size = 1
    print(df)
    start_index = window_size
    end_index = len(df)
    env_maker = lambda: gym.make(
        'stocks-v0',
        df = df,
        window_size = window_size,
        frame_bound = (start_index, end_index)
    )
    print(df)
    env = DummyVecEnv([env_maker for _ in range(n_cpu)])
    # policy_kwargs = dict(net_arch=[64, 'lstm', dict(vf=[128, 128, 128], pi=[64, 64])])
    # model = A2C('MlpLstmPolicy', env, verbose=1, policy_kwargs=policy_kwargs)
    # model = A2C(MlpPolicy, env, verbose=1,learning_rate=.01)
    # model = ACKTR(MlpPolicy, env, verbose=1,learning_rate=1)
    model = PPO2(MlpPolicy, env, verbose=1,learning_rate=.01)
    # model = DQN(policy=CustomDQNPolicy,env=env, verbose=1,
    #             learning_rate= .01,
    #             buffer_size= 10000,
    #             double_q = False,
    #             exploration_final_eps= 1,
    #             prioritized_replay= True)
    model.learn(total_timesteps=100000)
    # model.save('nzdusdDQN')
    env = env_maker()
    observation = env.reset()
    while True:
        # observation = observation[np.newaxis, ...]
        # action = env.action_space.sample()
        action, _states = model.predict(observation)
        observation, reward, done, info = env.step(action)
        # env.render()
        if done:
            print("info:", info)
            break
    # for e in env.envs:
    #     plt.figure(figsize=(16, 6))
    #     e.render_all()
    #     plt.show()
    plt.figure(figsize=(16, 6))
    env.render_all()
    plt.show()

if __name__ == '__main__':
    main()
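For reference, the two synthetic price series used in the script can be sketched in plain Python, without numpy or gym. The helper names below are hypothetical; the formulas mirror `np.arange(200/10.0)` and the commented-out `m.sin(x/10)+1` loop from the code above.

```python
import math

def diagonal_prices(n=20):
    """Linearly increasing prices, like np.arange(200 / 10.0) in the script."""
    return [float(i) for i in range(n)]

def sine_prices(n=200):
    """Sine-wave prices shifted above zero, like m.sin(x / 10) + 1."""
    return [math.sin(x / 10) + 1 for x in range(n)]

def price_changes(prices):
    """Step-to-step price differences."""
    return [b - a for a, b in zip(prices, prices[1:])]

diag = diagonal_prices()
sine = sine_prices()

# Count how many steps move upward in each series.
diag_up = sum(1 for d in price_changes(diag) if d > 0)
sine_up = sum(1 for d in price_changes(sine) if d > 0)
```

Note that `200/10.0` evaluates to `20.0`, so the diagonal series has only 20 points and every step is an upward move, whereas the sine series has 200 points with both rising and falling stretches.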
System info (describe the characteristics of your environment):
- Windows 10
- TensorFlow 1.15.0
- Stable Baselines 2.10.2a0 dev_0
- gym-anytrading 1.2.0
conda list:
PS E:\ML\reinforcementlearning\tradeorig> conda list
# packages in environment at C:\anaconda\envs\gymorig:
#
# Name Version Build Channel
_tflow_select 2.2.0 eigen
absl-py 0.11.0 py37haa95532_0
alabaster 0.7.12 py37_0
apipkg 1.5 pypi_0 pypi
argh 0.26.2 py37_0
asn1crypto 1.4.0 py_0
astor 0.8.1 py37_0
astroid 2.4.2 py37_0
async_generator 1.10 py37h28b3542_0
atari-py 0.2.6 pypi_0 pypi
atomicwrites 1.4.0 py_0
attrs 20.2.0 py_0
autopep8 1.5.4 py_0
babel 2.8.0 py_0
backcall 0.2.0 py_0
bcrypt 3.2.0 py37he774522_0
blas 1.0 mkl
bleach 3.2.1 py_0
brotlipy 0.7.0 py37he774522_1000
ca-certificates 2020.10.14 0
certifi 2020.6.20 py37haa95532_2
cffi 1.14.3 py37h7a1dbc1_0
chardet 3.0.4 py37_1003
cloudpickle 1.6.0 py_0
colorama 0.4.4 py_0
coverage 5.3 pypi_0 pypi
cryptography 2.3.1 py37h74b6da3_0
cycler 0.10.0 pypi_0 pypi
decorator 4.4.2 py_0
defusedxml 0.6.0 py_0
diff-match-patch 20200713 py_0
docutils 0.16 py37_1
entrypoints 0.3 py37_0
execnet 1.7.1 pypi_0 pypi
flake8 3.8.4 py_0
future 0.18.2 py37_1
gast 0.2.2 py37_0
google-pasta 0.2.0 py_0
grpcio 1.14.1 py37h5c4b210_0
gym 0.17.3 pypi_0 pypi
gym-anytrading 1.2.0 pypi_0 pypi
h5py 2.10.0 py37h5e291fa_0
hdf5 1.10.4 h7ebc959_0
icc_rt 2019.0.0 h0cc432a_1
icu 58.2 ha925a31_3
idna 2.10 py_0
imagesize 1.2.0 py_0
importlab 0.5.1 pypi_0 pypi
importlib-metadata 2.0.0 py_1
importlib_metadata 2.0.0 1
iniconfig 1.0.1 pypi_0 pypi
intel-openmp 2020.2 254
intervaltree 3.1.0 py_0
ipykernel 5.3.4 py37h5ca1d4c_0
ipython 7.18.1 py37h5ca1d4c_0
ipython_genutils 0.2.0 py37_0
isort 5.6.4 py_0
jedi 0.17.1 py37_0
jinja2 2.11.2 py_0
joblib 0.17.0 pypi_0 pypi
jpeg 9b hb83a4c4_2
jsonschema 3.2.0 py_2
jupyter_client 6.1.7 py_0
jupyter_core 4.6.3 py37_0
jupyterlab_pygments 0.1.2 py_0
keras-applications 1.0.8 py_1
keras-base 2.3.1 py37_0
keras-preprocessing 1.1.0 py_1
keyring 21.4.0 py37_1
kiwisolver 1.2.0 pypi_0 pypi
lazy-object-proxy 1.4.3 py37he774522_0
libpng 1.6.37 h2a8f88b_0
libprotobuf 3.13.0.1 h200bbdf_0
libsodium 1.0.18 h62dcd97_0
libspatialindex 1.9.3 h33f27b4_0
livereload 2.6.3 pypi_0 pypi
lxml 4.5.2 pypi_0 pypi
markdown 3.3.2 py37_0
markupsafe 1.1.1 py37hfa6e2cd_1
matplotlib 3.3.2 pypi_0 pypi
mccabe 0.6.1 py37_1
mistune 0.8.4 py37hfa6e2cd_1001
mkl 2020.2 256
mkl-service 2.3.0 py37hb782905_0
mkl_fft 1.2.0 py37h45dec08_0
mkl_random 1.1.1 py37h47e9c7a_0
mpi4py 3.0.3 pypi_0 pypi
msgpack 1.0.0 pypi_0 pypi
multitasking 0.0.9 pypi_0 pypi
nbclient 0.5.1 py_0
nbconvert 6.0.7 py37_0
nbformat 5.0.8 py_0
nest-asyncio 1.4.1 py_0
networkx 2.5 pypi_0 pypi
ninja 1.10.0.post2 pypi_0 pypi
numpy 1.19.2 py37hadc3359_0
numpy-base 1.19.2 py37ha3acd2a_0
numpydoc 1.1.0 py_0
opencv-python 4.4.0.44 pypi_0 pypi
openssl 1.0.2u he774522_0
opt_einsum 3.1.0 py_0
packaging 20.4 py_0
pandas 1.1.3 py37ha925a31_0
pandoc 2.11 h9490d1a_0
pandocfilters 1.4.2 py37_1
paramiko 2.4.2 py37_0
parso 0.7.0 py_0
pathtools 0.1.2 py_1
pexpect 4.8.0 py37_1
pickleshare 0.7.5 py37_1001
pillow 7.2.0 pypi_0 pypi
pip 20.2.4 py37_0
pluggy 0.13.1 py37_0
prompt-toolkit 3.0.8 py_0
protobuf 3.13.0.1 py37ha925a31_1
psutil 5.7.2 py37he774522_0
py 1.9.0 pypi_0 pypi
pyasn1 0.4.8 py_0
pycodestyle 2.6.0 py_0
pycparser 2.20 py_2
pydocstyle 5.1.1 py_0
pyflakes 2.2.0 py_0
pyglet 1.5.0 pypi_0 pypi
pygments 2.7.1 py_0
pylint 2.6.0 py37_0
pynacl 1.4.0 py37h62dcd97_1
pyopenssl 19.0.0 py37_0
pyparsing 2.4.7 py_0
pyqt 5.6.0 py37ha878b3d_6
pyreadline 2.1 py37_1
pyrsistent 0.17.3 py37he774522_0
pysocks 1.7.1 py37_1
pytest 6.1.1 pypi_0 pypi
pytest-cov 2.10.1 pypi_0 pypi
pytest-env 0.6.2 pypi_0 pypi
pytest-forked 1.3.0 pypi_0 pypi
pytest-xdist 2.1.0 pypi_0 pypi
python 3.7.1 h33f27b4_4
python-dateutil 2.8.1 py_0
python-jsonrpc-server 0.4.0 py_0
python-language-server 0.35.1 py_0
pytype 2020.9.29 pypi_0 pypi
pytz 2020.1 py_0
pywin32 227 py37he774522_1
pywin32-ctypes 0.2.0 py37_1001
pyyaml 5.3.1 pypi_0 pypi
pyzmq 19.0.2 py37ha925a31_1
qdarkstyle 2.8.1 py_0
qt 5.6.2 vc14h6f8c307_12
qtawesome 1.0.1 py_0
qtconsole 4.7.7 py_0
qtpy 1.9.0 py_0
quantstats 0.0.25 pypi_0 pypi
requests 2.24.0 py_0
rope 0.18.0 py_0
rtree 0.9.4 py37h21ff451_1
ruamel-yaml 0.16.12 pypi_0 pypi
ruamel-yaml-clib 0.2.2 pypi_0 pypi
scipy 1.5.2 py37h9439919_0
seaborn 0.11.0 pypi_0 pypi
setuptools 50.3.0 py37h9490d1a_1
sip 4.18.1 py37h6538335_2
six 1.15.0 py_0
snowballstemmer 2.0.0 py_0
sortedcontainers 2.2.2 py_0
sphinx 3.2.1 py_0
sphinx-autobuild 2020.9.1 pypi_0 pypi
sphinx-rtd-theme 0.5.0 pypi_0 pypi
sphinxcontrib-applehelp 1.0.2 py_0
sphinxcontrib-devhelp 1.0.2 py_0
sphinxcontrib-htmlhelp 1.0.3 py_0
sphinxcontrib-jsmath 1.0.1 py_0
sphinxcontrib-qthelp 1.0.3 py_0
sphinxcontrib-serializinghtml 1.1.4 py_0
spyder 4.1.5 py37_0
spyder-kernels 1.9.4 py37_0
sqlite 3.33.0 h2a8f88b_0
stable-baselines 2.10.2a0 dev_0 <develop>
tabulate 0.8.7 pypi_0 pypi
tensorboard 2.0.0 pyhb38c66f_1
tensorflow 1.15.0 eigen_py37h9f89a44_0
tensorflow-base 1.15.0 eigen_py37h07d2309_0
tensorflow-estimator 1.15.1 pyh2649769_0
termcolor 1.1.0 py37_1
testpath 0.4.4 py_0
toml 0.10.1 py_0
tornado 6.0.4 py37he774522_1
traitlets 5.0.5 py_0
typed-ast 1.4.1 py37he774522_0
ujson 4.0.1 py37ha925a31_0
urllib3 1.25.11 py_0
vc 14.1 h0510ff6_4
vs2015_runtime 14.16.27012 hf0eaf9b_3
watchdog 0.10.3 py37_0
wcwidth 0.2.5 py_0
webencodings 0.5.1 py37_1
werkzeug 0.16.1 py_0
wheel 0.35.1 py_0
win_inet_pton 1.1.0 py37_0
wincertstore 0.2 py37_0
wrapt 1.11.2 py37he774522_0
yaml 0.2.5 he774522_0
yapf 0.30.0 py_0
yfinance 0.1.55 pypi_0 pypi
zeromq 4.3.2 ha925a31_3
zipp 3.3.1 py_0
zlib 1.2.11 h62dcd97_4
【Comments】:
【Answer 1】: I think the problem is that you are using the default network structure in stable-baselines. In the linked example you can see:
model = Sequential()
model.add(Dense(4, init='lecun_uniform', input_shape=(2,)))
model.add(Activation('relu'))
model.add(Dense(4, init='lecun_uniform'))
model.add(Activation('relu'))
model.add(Dense(4, init='lecun_uniform'))
model.add(Activation('linear'))
rms = RMSprop()
model.compile(loss='mse', optimizer=rms)
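So this is a very simple network with 3 layers of 4 neurons each. In stable-baselines you are using the default MlpPolicy, which has two layers of 64 neurons. You can easily specify the network structure by passing the policy_kwargs argument to the model, like this: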
policy_kwargs = dict(
net_arch=[4, 4, 4]
)
Your DQN model can then be initialized as follows:
model = DQN('MlpPolicy', env, policy_kwargs=policy_kwargs, verbose=1)
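To give a rough sense of the size difference between the two architectures, here is a small sketch that counts weights and biases in a stack of fully connected layers. The helper name is hypothetical, and the input size of 2 is an assumption borrowed from `input_shape=(2,)` in the Keras example; for the 64-unit case only the two hidden layers are counted, not the output layer.

```python
def dense_param_count(input_dim, layer_sizes):
    """Count weights plus biases for a stack of fully connected layers."""
    total = 0
    fan_in = input_dim
    for units in layer_sizes:
        total += fan_in * units + units  # weight matrix + bias vector
        fan_in = units
    return total

# Three layers of 4 units, as in the Keras example.
small = dense_param_count(2, [4, 4, 4])   # 52

# Two hidden layers of 64 units, as described for the default policy.
wide = dense_param_count(2, [64, 64])     # 4352
```

Under these assumptions the wider network has roughly eighty times as many parameters to fit, on a problem with very little data.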
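Also, in your first example the author built a simple DQN model with a single network. In frameworks such as stable-baselines, however, the DQN algorithm uses two networks with the same structure, one for training and one for evaluation. This is useful for more complex problems, but for a problem as simple as yours it may not work as well.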
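The two-network idea can be sketched as follows. This is a toy illustration, not stable-baselines code: the class, the plain list of weights, and the sync interval are all hypothetical. It assumes the common scheme in which one network (often called the online network) is updated every step and a second copy (often called the target network) is only refreshed periodically.

```python
class TwoNetworkAgent:
    """Toy agent holding an online weight vector and a lagging target copy."""

    def __init__(self, weights, target_update_interval):
        self.online = list(weights)
        self.target = list(weights)
        self.target_update_interval = target_update_interval
        self.steps = 0

    def train_step(self, gradient, learning_rate):
        # Only the online weights are changed by a gradient step.
        self.online = [w - learning_rate * g
                       for w, g in zip(self.online, gradient)]
        self.steps += 1
        # The target copy is refreshed only every few steps.
        if self.steps % self.target_update_interval == 0:
            self.target = list(self.online)

agent = TwoNetworkAgent([1.0, 1.0], target_update_interval=3)
for _ in range(2):
    agent.train_step(gradient=[1.0, 1.0], learning_rate=0.5)
lagging = agent.target != agent.online  # target still holds the old weights

agent.train_step(gradient=[1.0, 1.0], learning_rate=0.5)
synced = agent.target == agent.online   # third step triggers the copy
```

Between syncs the target copy lags behind the online weights, so on a very short training run the values used for evaluation may stay stale for a noticeable fraction of the updates.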
【Comments】: