(PyTorch reproduction) Adaptive job shop scheduling (JSP) based on deep reinforcement learning (CNN + dueling network/DQN/DDQN/D3QN)

Posted by 码丽莲梦露


To get hands-on experience combining various deep learning networks with reinforcement learning, I implemented the following paper:

Research on Adaptive Job Shop Scheduling Problems Based on Dueling Double DQN | IEEE Journals & Magazine | IEEE Xplore

For a brief introduction to the states, actions, reward function, and experiments, see:

基于深度强化学习的自适应作业车间调度问题研究_松间沙路的博客-CSDN博客_强化学习调度

The full reproduction code is available on my GitHub: Aihong-Sun/DQN-DDQN-Dueling_networ-D3QN-_for_JSP: pytorch implementation of DQN/DDQN/Dueling_networ/D3QN for job shop scheudling problem (github.com)

1 State Feature Extraction

Starting with feature extraction: in the original paper, the state is represented by three grid matrices (shown as a figure in the original post).
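A minimal sketch of how such a state tensor is assembled (it mirrors JSP_Env.reset() in Section 5.1 below; the random processing times here are only placeholders):

import numpy as np

J_num, O_max_len = 6, 6
S1 = np.random.randint(1, 10, (J_num, O_max_len)).astype(float)  # channel 0: processing time of each operation (zeroed once scheduled)
S2 = np.zeros((J_num, O_max_len))   # channel 1: completion time of each scheduled operation
S3 = np.zeros((J_num, O_max_len))   # channel 2: machine idle-time ratio recorded at scheduling time
state = np.stack((S1, S2, S3), 0)   # shape (3, J_num, O_max_len) -- the "image" fed to the CNN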

A CNN is therefore built to extract features from this state:

If you are not familiar with CNNs, see: 卷积神经网络(CNN)详解 - 知乎 (zhihu.com)

self.conv1 = nn.Sequential(
    nn.Conv2d(
        in_channels=3,      # input shape (3, J_num, O_max_len)
        out_channels=6,
        kernel_size=3,
        stride=1,
        padding=1,          # zero padding keeps the spatial size unchanged
    ),                      # output shape (6, J_num, O_max_len)
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=2, ceil_mode=False)   # output shape (6, J_num//2, O_max_len//2)
)

1.1 Convolutional layer

The state above can be viewed as an image whose height and width are the number of jobs and the maximum number of operations per job (if the operation counts differ, take the maximum; in the instances used in the original paper every job has the same number of operations) and whose depth is 3, so in_channels=3. out_channels can be chosen freely; here I set it to 6, i.e. 6 convolution kernels, so the output depth is 6. To keep the height and width unchanged (which simplifies sizing the fully connected layers for different job/operation counts), the image is zero-padded at the edges. The output size is computed as follows:

Suppose J_num=6 and O_max_len=6, and let kernel_size=3, stride=1, padding=1. Using W2=(W1−K+2P)/S+1, we get W2=(6−3+2)/1+1=6 and H2=(6−3+2)/1+1=6, so the spatial size is unchanged.
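A quick sanity check of this shape arithmetic (a standalone sketch; the layer hyperparameters match those above):

import torch
import torch.nn as nn

conv = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=3, stride=1, padding=1)
x = torch.zeros(1, 3, 6, 6)   # a batch of one state with J_num=6, O_max_len=6
print(conv(x).shape)          # torch.Size([1, 6, 6, 6]) -- spatial size preserved, depth becomes 6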

1.2 Pooling layer

Its purpose is to progressively shrink the spatial size of the data volume, which reduces the number of parameters in the network, lowers the computational cost, and also helps control overfitting.

The pooling layer applies a MAX operation independently to every depth slice of the input, changing its spatial size. The most common form uses a 2x2 filter with stride 2 to downsample each depth slice.

Here kernel_size=2 is used, which halves the image. ceil_mode=False handles an odd number of jobs or operations: for example, with J_num=7 and O_max_len=7 the leftover edge is discarded during downsampling, so the new image size is (3,3); with ceil_mode=True it would instead be (4,4).
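A short demonstration of the ceil_mode behaviour on a 7x7 input (standalone sketch):

import torch
import torch.nn as nn

x = torch.zeros(1, 6, 7, 7)                                      # 6 channels, odd 7x7 spatial size
print(nn.MaxPool2d(kernel_size=2, ceil_mode=False)(x).shape)     # torch.Size([1, 6, 3, 3]) -- edge discarded
print(nn.MaxPool2d(kernel_size=2, ceil_mode=True)(x).shape)      # torch.Size([1, 6, 4, 4]) -- edge kept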

2 Actions

The actions are 17 dispatching rules; the full list can be found in the GitHub repository linked above, and one rule is sketched below as an example.
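As an illustration, here is a hypothetical sketch of one such rule, SPT (shortest processing time), written against the JSP_Env interface from Section 5.1; the actual Dispatch_rule(a, env) function in the repository maps each of the 17 action indices to a rule of this kind and may differ in its details:

def spt_rule(env):
    # among the unfinished jobs, pick the one whose next operation has the shortest processing time
    candidates = [j for j in env.Jobs if not j.wether_end()]
    best = min(candidates, key=lambda j: env.PT[j.idx][j.op])
    return best.idx   # job index, which is what env.step() expects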

3 DQN/DDQN/Dueling Network/D3QN

3.1 DQN vs. DDQN

The first expression below is the DQN target and the second is the DDQN target:
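In standard notation, with θ the parameters of the online (eval) network and θ⁻ the parameters of the target network:

y_t(DQN)  = r_t + γ · max_a Q(s_{t+1}, a; θ⁻)

y_t(DDQN) = r_t + γ · Q(s_{t+1}, argmax_a Q(s_{t+1}, a; θ); θ⁻)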

DDQN is almost identical to DQN; they differ in a single step, namely how Q(s_{t+1}, a_{t+1}) is obtained. DQN always takes the maximum output of the target Q network. DDQN instead first uses the online Q network to find the action with the maximum output, and then evaluates that action with the target Q network. The reason is that vanilla DQN tends to overestimate the Q values. The code difference looks like this:

if self.double:   # DDQN: the eval net chooses the action, the target net evaluates it
    q_eval = self.eval_net(batch_state).gather(1, batch_action)
    q_next_eval = self.eval_net(batch_next_state).detach()
    q_next = self.target_net(batch_next_state).detach()
    q_a = q_next_eval.argmax(dim=1).reshape(-1, 1)
    q_target = batch_reward + self.GAMMA * q_next.gather(1, q_a)
else:             # DQN: take the maximum of the target net directly
    q_eval = self.eval_net(batch_state).gather(1, batch_action)
    q_next = self.target_net(batch_next_state).detach()
    q_target = batch_reward + self.GAMMA * q_next.max(1)[0].view(-1, 1)

loss = self.loss_func(q_eval, q_target)

self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()

3.2 DQN vs. Dueling Network

The dueling network comes from a 2015 paper. That paper proposed a new network architecture which not only improves the final results but can also be combined with other algorithms for even better performance.

In the earlier DQN, the features produced by the convolution are fed into a few fully connected layers and, after training, the network directly outputs the value of each action in the current state, i.e. Q(s,a). The dueling network is different: after the convolutional layers it splits into two branches, one predicting the value of the state and the other predicting the advantage of each action, and the two branches are then combined to output Q(s,a). (The original post compares the two structures in a figure, with DQN on top and the dueling network below.)

The difference in code:

DQN:

class CNN_FNN(nn.Module):
    """CNN feature extractor followed by a plain fully connected Q-head (one Q-value per rule)."""
    def __init__(self, J_num, O_max_len):
        super(CNN_FNN, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=6, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, ceil_mode=False)
        )
        # summary(self.conv1,(3,6,6))
        self.fc1 = nn.Linear(6*int(J_num/2)*int(O_max_len/2), 258)
        self.fc2 = nn.Linear(258, 258)
        self.out = nn.Linear(258, 17)

    def forward(self, x):
        x = self.conv1(x)
        x = x.view(x.size(0), -1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        x = F.relu(x)
        action_prob = self.out(x)
        return action_prob

Dueling Network:

class CNN_dueling(nn.Module):
    """Same CNN feature extractor, but with separate value and advantage streams (dueling head)."""
    def __init__(self, J_num, O_max_len):
        super(CNN_dueling, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(
                in_channels=3,      # input shape (3, J_num, O_max_len)
                out_channels=6,
                kernel_size=3,
                stride=1,
                padding=1,          # zero padding keeps the spatial size unchanged
            ),                      # output shape (6, J_num, O_max_len)
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, ceil_mode=False)   # output shape (6, int(J_num/2), int(O_max_len/2))
        )
        # summary(self.conv1,(3,6,6))
        self.val_hidden = nn.Linear(6*int(J_num/2)*int(O_max_len/2), 258)
        self.adv_hidden = nn.Linear(6*int(J_num/2)*int(O_max_len/2), 258)

        self.val = nn.Linear(258, 1)     # state-value stream V(s)
        self.adv = nn.Linear(258, 17)    # advantage stream A(s,a), one output per rule

    def forward(self, x):
        x = self.conv1(x)
        x = x.view(x.size(0), -1)

        val_hidden = self.val_hidden(x)
        val_hidden = F.relu(val_hidden)

        adv_hidden = self.adv_hidden(x)
        adv_hidden = F.relu(adv_hidden)

        val = self.val(val_hidden)
        adv = self.adv(adv_hidden)

        adv_ave = torch.mean(adv, dim=1, keepdim=True)
        x = adv + val - adv_ave          # Q(s,a) = V(s) + A(s,a) - mean_a' A(s,a')
        return x

Advantages of the dueling architecture:

(1) The main difference between the dueling network and DQN is that state and action are separated to a certain extent. Although the final output is still the same, during the computation the state value no longer has to be inferred entirely through the action values; it can be predicted on its own. This is very useful: the model can learn both how valuable a given state is and how valuable each action is in that state, observing state and action relatively independently while still combining them tightly, which allows more flexible processing. Moreover, in states where the choice of action hardly affects the outcome, the dueling model can perform noticeably better.

(2) When there are many redundant or near-equivalent actions, the dueling network can identify the correct action in the policy faster than DQN.

The authors give two different aggregation formulas in the paper. It should be said up front that the two formulas behave similarly in practice; either can be used:
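Written out, with V the value stream, A the advantage stream and |A| the number of actions:

Q(s, a) = V(s) + ( A(s, a) − max_{a'} A(s, a') )

Q(s, a) = V(s) + ( A(s, a) − (1/|A|) · Σ_{a'} A(s, a') )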

The difference is that the first formula subtracts the maximum of A from every A(s,a), while the second subtracts the mean of A. This resolves the identifiability problem: without the subtraction, adding a constant to V and subtracting the same constant from A would leave Q unchanged, so V and A could not be learned separately. (The code above uses the mean version.)

3.3 D3QN

D3QN (Dueling Double DQN) combines the advantages of Dueling DQN and Double DQN: the dueling network architecture together with the double-Q target.
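In this reproduction that simply means enabling both options when constructing the agent (a usage sketch based on the Agent class in Section 5.3; the 10x10 sizes correspond to an instance such as la16):

agent = Agent(n=10, O_max_len=10, dueling=True, double=True)   # D3QN
# dueling=True -> use the CNN_dueling network instead of CNN_FNN
# double=True  -> use the DDQN target instead of the plain DQN target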

4 Hyperparameter Tuning and Related Notes

For a custom environment like this one, getting convergence requires repeated hyperparameter tuning. I previously came across an article that explains the tuning process well; you may want to refer to it:

启人zhr:强化学习中的调参经验与编程技巧(on policy 篇)

5 Partial Code Listings

5.1 JSP_Env.py

import copy
import matplotlib.pyplot as plt
import numpy as np

def Gantt(Machines):
    plt.rcParams['font.sans-serif'] = ['Times New Roman']  # switch to 'SimHei' if Chinese labels are needed
    plt.rcParams['axes.unicode_minus'] = False  # render minus signs correctly
    M = ['red', 'blue', 'yellow', 'orange', 'green', 'palegoldenrod', 'purple', 'pink', 'Thistle', 'Magenta',
         'SlateBlue', 'RoyalBlue', 'Cyan', 'Aqua', 'floralwhite', 'ghostwhite', 'goldenrod', 'mediumslateblue',
         'navajowhite', 'navy', 'sandybrown', 'moccasin']
    Job_text = ['J' + str(i + 1) for i in range(100)]
    Machine_text = ['M' + str(i + 1) for i in range(50)]

    for i in range(len(Machines)):
        for j in range(len(Machines[i].start)):
            if Machines[i].finish[j] - Machines[i].start[j]!= 0:
                plt.barh(i, width=Machines[i].finish[j] - Machines[i].start[j],
                         height=0.8, left=Machines[i].start[j],
                         color=M[Machines[i]._on[j]],
                         edgecolor='black')
                plt.text(x=Machines[i].start[j]+(Machines[i].finish[j] - Machines[i].start[j])/2 - 0.1,
                         y=i,
                         s=Job_text[Machines[i]._on[j]],
                         fontsize=12)
    plt.show()

class Machine:
    def __init__(self,idx):
        self.idx=idx
        self.start=[]
        self.finish=[]
        self._on=[]
        self.end=0

    def handling(self, Ji, pt):
        # schedule the next operation of job Ji (processing time pt) on this machine,
        # re-using an idle gap on the machine when one fits
        s = self.insert(Ji, pt)
        e = s + pt
        self.start.append(s)
        self.finish.append(e)
        self.start.sort()
        self.finish.sort()
        self._on.insert(self.start.index(s),Ji.idx)
        if self.end<e:
            self.end=e
        Ji.update(s,e)

    def Gap(self):
        Gap=0
        if self.start==[]:
            return 0
        else:
            Gap+=self.start[0]
            if len(self.start)>1:
                G=[self.start[i+1]-self.finish[i] for i in range(0,len(self.start)-1)]
                return Gap+sum(G)
            return Gap

    def judge_gap(self,t):
        Gap = []
        if self.start == []:
            return Gap
        else:
            if self.start[0]>0 and self.start[0]>t:
                Gap.append([0,self.start[0]])
            if len(self.start) > 1:
                Gap.extend([[self.finish[i], self.start[i + 1]] for i in range(0, len(self.start) - 1) if
                            self.start[i + 1] - self.finish[i] > 0 and self.start[i + 1] > t])
                return Gap
            return Gap

    def insert(self, Ji, pt):
        # earliest feasible start: by default after both the job's previous operation and the
        # machine's last scheduled operation; an earlier idle gap is used if the operation fits in it
        start = max(Ji.end, self.end)
        Gap=self.judge_gap(Ji.end)
        if Gap!=[]:
            for Gi in Gap:
                if Gi[0]>=Ji.end and Gi[1]-Gi[0]>=pt:
                    return Gi[0]
                elif Gi[0]<Ji.end and Gi[1]-Ji.end>=pt:
                    return Ji.end
        return start

class Job:
    def __init__(self,idx,max_ol):
        self.idx=idx
        self.start=0
        self.end=0
        self.op=0
        self.max_ol=max_ol
        self.Gap=0
        self.l=0

    def wether_end(self):
        if self.op<self.max_ol:
            return False
        else:
            return True

    def update(self,s,e):
        self.op+=1
        self.end=e
        self.start=s
        self.l=self.l+e-s

class JSP_Env:
    def __init__(self,n,m,PT,M):
        self.n,self.m=n,m
        self.O_max_len=len(PT[0])
        self.PT=copy.copy(PT)
        self.M=M
        self.finished=[]
        self.Num_finished=0
        self.g=0


    def Create_Item(self):
        self.Jobs=[]
        for i in range(self.n):
            Ji=Job(i,len(self.PT[i]))
            self.Jobs.append(Ji)
        self.Machines = []
        for i in range(self.m):   # one Machine object per machine
            Mi = Machine(i)
            self.Machines.append(Mi)

    def C_max(self):
        m=0
        for Mi in self.Machines:
            if Mi.end>m:
                m=Mi.end
        return m

    def reset(self):
        self.u=0
        self.P = 0  # total working time
        self.finished=[]
        self.Num_finished=0
        done=False
        self.Create_Item()
        self.S1_Matrix = np.array(copy.copy(self.PT))       # channel 0: processing time of each operation (zeroed once scheduled)
        self.S2_Matrix = np.zeros_like(self.S1_Matrix)      # channel 1: completion time of each scheduled operation
        self.S3_Matrix = np.zeros_like(self.S1_Matrix)      # channel 2: machine idle-time ratio recorded at scheduling time
        self.s = np.stack((self.S1_Matrix, self.S2_Matrix, self.S3_Matrix), 0)
        return self.s, done

    def Gap(self):
        G=0
        for Mi in self.Machines:
            G+=Mi.Gap()
        return G/self.C_max()

    def U(self):
        C_max = self.C_max()
        return self.P/(self.m*C_max)

    def step(self, action):
        done = False
        Ji = self.Jobs[action]
        op = Ji.op
        pt = self.PT[action][op]
        self.P += pt                             # accumulate scheduled processing time
        self.s[0][action][op] = 0                # channel 0: this operation is no longer pending
        Mi = self.Machines[self.M[action][op]]
        Mi.handling(Ji, pt)
        self.s[1][action][op] = Ji.end           # channel 1: completion time of the scheduled operation
        if Ji.wether_end():
            self.finished.append(action)
            self.Num_finished += 1
        if self.Num_finished == self.n:
            done = True
        Gap = self.Gap()
        self.s[2][action][op] = Gap              # channel 2: total machine idle time / current makespan
        u = self.U()
        r = u - self.u                           # reward: change in average machine utilization
        self.u = u
        return self.s, r, done


if __name__=="__main__":

    from Dataset.data_extract import change
    from Actor_Critic_for_JSP.action_space import Dispatch_rule
    import random
    n, m, PT, MT = change('ft', 6)
    print(PT)
    print()
    jsp=JSP_Env(n, m, PT, MT)
    os1=[]
    for i in range(len(PT)):
        for j in range(len(PT[i])):
            os1.append(i)
    s,done=jsp.reset()
    while not done:
        a=random.randint(0,16)
        print('dispatch rule',a)
        a=Dispatch_rule(a,jsp)
        print('this is action',a)
        s, r, done=jsp.step(a)
        print(r)
        print(done)
        os1.remove(a)
        shape=len(s)
    Gantt(jsp.Machines)
    print(jsp.C_max())

5.2 RL_network.py

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

class CNN_FNN(nn.Module):
    """CNN feature extractor followed by a plain fully connected Q-head (one Q-value per rule)."""
    def __init__(self, J_num, O_max_len):
        super(CNN_FNN, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=6, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, ceil_mode=False)
        )
        # summary(self.conv1,(3,6,6))
        self.fc1 = nn.Linear(6*int(J_num/2)*int(O_max_len/2), 258)
        self.fc2 = nn.Linear(258, 258)
        self.out = nn.Linear(258, 17)

    def forward(self, x):
        x = self.conv1(x)
        x = x.view(x.size(0), -1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        x = F.relu(x)
        action_prob = self.out(x)
        return action_prob


class CNN_dueling(nn.Module):
    """Same CNN feature extractor, but with separate value and advantage streams (dueling head)."""
    def __init__(self, J_num, O_max_len):
        super(CNN_dueling, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(
                in_channels=3,      # input shape (3, J_num, O_max_len)
                out_channels=6,
                kernel_size=3,
                stride=1,
                padding=1,          # zero padding keeps the spatial size unchanged
            ),                      # output shape (6, J_num, O_max_len)
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, ceil_mode=False)   # output shape (6, int(J_num/2), int(O_max_len/2))
        )
        # summary(self.conv1,(3,6,6))
        self.val_hidden = nn.Linear(6*int(J_num/2)*int(O_max_len/2), 258)
        self.adv_hidden = nn.Linear(6*int(J_num/2)*int(O_max_len/2), 258)

        self.val = nn.Linear(258, 1)     # state-value stream V(s)
        self.adv = nn.Linear(258, 17)    # advantage stream A(s,a), one output per rule

    def forward(self, x):
        x = self.conv1(x)
        x = x.view(x.size(0), -1)

        val_hidden = self.val_hidden(x)
        val_hidden = F.relu(val_hidden)

        adv_hidden = self.adv_hidden(x)
        adv_hidden = F.relu(adv_hidden)

        val = self.val(val_hidden)
        adv = self.adv(adv_hidden)

        adv_ave = torch.mean(adv, dim=1, keepdim=True)
        x = adv + val - adv_ave          # Q(s,a) = V(s) + A(s,a) - mean_a' A(s,a')
        return x

5.3 Agent.py

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from Actor_Critic_for_JSP.Agent.RL_network import CNN_FNN,CNN_dueling
from Actor_Critic_for_JSP.Memory.Memory import Memory
from Actor_Critic_for_JSP.Memory.PreMemory import preMemory

class Agent():
    """docstring for DQN"""
    def __init__(self,n,O_max_len,dueling=False,double=False,PER=False):
        self.double=double
        self.PER=PER
        self.GAMMA=1
        self.n=n
        self.O_max_len=O_max_len
        super(Agent, self).__init__()
        if dueling:
            self.eval_net, self.target_net = CNN_dueling(self.n,self.O_max_len), CNN_dueling(self.n,self.O_max_len)
        else:
            self.eval_net, self.target_net = CNN_FNN(self.n, self.O_max_len), CNN_FNN(self.n, self.O_max_len)
        self.Q_NETWORK_ITERATION=100
        self.BATCH_SIZE=256
        self.learn_step_counter = 0
        self.memory_counter = 0
        if PER:
            self.memory = preMemory()
        else:
            self.memory = Memory()
        self.EPISILO=0.8
        self.optimizer = torch.optim.Adam(self.eval_net.parameters(), lr=0.00001)
        self.loss_func = nn.MSELoss()

    def choose_action(self, state):
        state = np.reshape(state, (-1, 3, self.n, self.O_max_len))
        state = torch.FloatTensor(state)
        if np.random.rand() <= self.EPISILO:    # greedy action from the eval network
            action_value = self.eval_net.forward(state)
            action = torch.max(action_value, 1)[1].data.numpy()[0]
        else:                                   # random exploration
            action = np.random.randint(0, 17)
        # anneal EPISILO upward so the policy becomes greedier as training proceeds
        self.EPISILO = min(0.999, self.EPISILO + 0.00001)
        return action

    def PER_error(self, state, action, reward, next_state):
        # TD error of a single transition, used as its priority for prioritized replay
        state = torch.FloatTensor(np.reshape(state, (-1, 3, self.n, self.O_max_len)))
        next_state = torch.FloatTensor(np.reshape(next_state, (-1, 3, self.n, self.O_max_len)))
        p = self.eval_net.forward(state)
        p_ = self.eval_net.forward(next_state)
        p_target = self.target_net(next_state)

        if self.double:
            q_a = p_.argmax(dim=1).reshape(-1, 1)
            qt = reward + self.GAMMA * p_target.gather(1, q_a)
        else:
            qt = reward + self.GAMMA * p_target.max(1)[0].view(-1, 1)
        qt = qt.detach().numpy()
        p = p.detach().numpy()
        errors = np.abs(p[0][action] - qt[0][0])
        return errors

    def store_transition(self, state, action, reward, next_state):
        if self.PER:
            errors=self.PER_error(state, action, reward, next_state)
            self.memory.remember((state, action, reward, next_state), errors)
            self.memory_counter += 1
        else:
            self.memory.remember((state, action, reward, next_state))
            self.memory_counter+=1

    def learn(self):

        #update the parameters
        if self.learn_step_counter % self.Q_NETWORK_ITERATION ==0:
            self.target_net.load_state_dict(self.eval_net.state_dict())
        self.learn_step_counter+=1

        batch=self.memory.sample(self.BATCH_SIZE)

        # sample a batch from memory: each entry is (state, action, reward, next_state)
        batch_state = np.array([o[0] for o in batch])
        batch_action = np.array([o[1] for o in batch])
        batch_reward = np.array([o[2] for o in batch])
        batch_next_state = np.array([o[3] for o in batch])

        batch_action = torch.LongTensor(np.reshape(batch_action, (-1, 1)))   # column vector for gather()
        batch_reward = torch.FloatTensor(np.reshape(batch_reward, (-1, 1)))

        batch_state = torch.FloatTensor(np.reshape(batch_state, (-1, 3, self.n, self.O_max_len)))
        batch_next_state = torch.FloatTensor(np.reshape(batch_next_state, (-1, 3, self.n, self.O_max_len)))

        if self.double:   # DDQN: action chosen by the eval net, evaluated by the target net
            q_eval = self.eval_net(batch_state).gather(1, batch_action)
            q_next_eval = self.eval_net(batch_next_state).detach()
            q_next = self.target_net(batch_next_state).detach()
            q_a = q_next_eval.argmax(dim=1).reshape(-1, 1)
            q_target = batch_reward + self.GAMMA * q_next.gather(1, q_a)
        else:             # DQN: take the maximum of the target net directly
            q_eval = self.eval_net(batch_state).gather(1, batch_action)
            q_next = self.target_net(batch_next_state).detach()
            q_target = batch_reward + self.GAMMA * q_next.max(1)[0].view(-1, 1)

        loss = self.loss_func(q_eval, q_target)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

5.4 train.py

from Actor_Critic_for_JSP.JSP_env import JSP_Env,Gantt
import matplotlib.pyplot as plt
from Actor_Critic_for_JSP.Dataset.data_extract import change
from Actor_Critic_for_JSP.action_space import Dispatch_rule
from Actor_Critic_for_JSP.Agent.Agent import Agent

def main(Agent,env,batch_size):
    Reward_total = []
    C_total = []
    rewards_list = []
    C = []

    episodes = 8000
    print("Collecting Experience....")
    for i in range(episodes):
        print(i)
        state,done = env.reset()
        ep_reward = 0
        while True:

            action = Agent.choose_action(state)

            a=Dispatch_rule(action,env)
            try:
                next_state, reward, done = env.step(a)
            except:
                print(action,a)

            Agent.store_transition(state, action, reward, next_state)
            ep_reward += reward
            if Agent.memory_counter >= batch_size:
                Agent.learn()
                if done and i%1==0:
                    ret, f, C1, R1 = evaluate(i,Agent,env)
                    Reward_total.append(R1)
                    C_total.append(C1)
                    rewards_list.append( ep_reward)
                    C.append(env.C_max())
            if done:
                # Gantt(env.Machines)
                break
            state = next_state
    x = [_ for _ in range(len(C))]
    plt.plot(x, rewards_list)
    # plt.show()
    plt.plot(x, C)
    # plt.show()
    return Reward_total,C_total

def evaluate(i,Agent,env):
    returns = []
    C=[]
    for  total_step in range(10):
        state, done = env.reset()
        ep_reward = 0
        while True:
            action = Agent.choose_action(state)
            a = Dispatch_rule(action, env)
            try:
                next_state, reward, done = env.step(a)
            except:
                print(action,a)
            ep_reward += reward
            if done == True:
                fitness = env.C_max()
                C.append(fitness)
                break
        returns.append(ep_reward)
    print('time step:',i,'','Reward :',sum(returns)/10 ,'','C_max:',sum(C) /10)
    return sum(returns) / 10,sum(C) /10,C,returns


if __name__ == '__main__':
    import pickle
    import os

    n, m, PT, MT = change('la', 16)

    f=r'.\\result\\la'
    if not os.path.exists(f):
        os.mkdir(f)
    f1=os.path.join(f,'la'+'16')
    if not os.path.exists(f1):
        os.mkdir(f1)
    print(n, m, PT, MT)
    env = JSP_Env(n, m, PT, MT)
    # algorithm selection: dueling=False,double=False -> CNN+FNN+DQN; dueling=True,double=False -> dueling DQN;
    # dueling=False,double=True -> DDQN; dueling=True,double=True -> D3QN
    agent = Agent(env.n, env.O_max_len, dueling=True, double=True)
    Reward_total,C_total=main(agent,env,100)
    print(os.path.join(f1, 'C_max' + ".pkl"))
    with open(os.path.join(f1, 'C_max' + ".pkl"), "wb") as f2:
        pickle.dump(C_total, f2, pickle.HIGHEST_PROTOCOL)
    with open(os.path.join(f1, 'Reward' + ".pkl"), "wb") as f3:
        pickle.dump(Reward_total, f3, pickle.HIGHEST_PROTOCOL)
