开发者社区 > 大数据与机器学习 > 正文

强化学习训练的很好,但是模型无法有效复用,请各路高手指点迷津

主程序gymnasium_example.py内容为:

import gymnasium as gym
from stable_baselines3 import DQN
from stable_baselines3.common.env_util import make_vec_env
from gymnasium import spaces
import gymnasium
import gymnasium_env
import random
import time 
import cv2
from collections import deque
import gymnasium as gym
import torch
import sys 
from stable_baselines3 import PPO
from stable_baselines3.common.evaluation import evaluate_policy
import time
import numpy as np
import socket  
import threading  
import traceback 
import json  

#写入文件,追加
def log_add(neirong,file_log):
    f = open(file_log,'a',encoding='utf-8')
    neirong = neirong + "\n"
    f.write(neirong)
    f.close()
#写入文件,代替
def log_write(neirong,file_log):
    f = open(file_log,'w',encoding='utf-8')
    neirong = neirong
    f.write(neirong)
    f.close()
score = 0
total_timesteps = 100000
vec_env = make_vec_env("gymnasium_env/GridWorld-v0", n_envs=1)
#device = torch.device('gpu')
device = torch.device('cuda')
model = DQN("MultiInputPolicy", vec_env, verbose=1,device=device)
model.learn(total_timesteps=total_timesteps)
model.save("gymnasium_fangkuai_28")
#shenglv_xunlian = vec_env.times_win/total_timesteps
log_add("------------------------------------------------------------","game_result.txt")
log_add(f"------------------------------------------------------------","fangkuai_log.txt")
#input('开始演示吗?')
del model # remove to demonstrate saving and loading

vec_env = make_vec_env("gymnasium_env/GridWorld-v0", n_envs=1)
device = torch.device('cuda')
model = DQN.load("gymnasium_fangkuai_28",env=vec_env)

obs = vec_env.reset()
#vec_env2.reset()

score = 0
print(f'obs:{type(obs)}')
#while True:
for i in range(total_timesteps):
    action, _states = model.predict(obs)
    print(f'type_action:{type(action)}')
    print(f'action:{action}')
    obs, rewards, dones, info = vec_env.step(action)
    if dones:
        vec_env.reset()

游戏程序grid_world.py内容为:

from enum import Enum
import gymnasium as gym
from gymnasium import spaces
import pygame
import numpy as np
import random
from collections import deque
import cv2


#写入文件,追加
def log_add(neirong,file_log):
    f = open(file_log,'a',encoding='utf-8')
    neirong = neirong + "\n"
    f.write(neirong)
    f.close()
#写入文件,代替
def log_write(neirong,file_log):
    f = open(file_log,'w',encoding='utf-8')
    neirong = neirong
    f.write(neirong)
    f.close()


class GridWorldEnv(gym.Env):
    global score
    metadata = {"render_modes": ["human", "rgb_array"], "render_fps": 4}

    def __init__(self, render_mode=None, size=5):

        self.jishu = 0

    #++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
        self.k_boss = random.choice((115,119,97,100))
        #初始化纸牌


        self.times_lose = 0
        self.times_win = 0

        self.qipan_size = 50
        self.action_list = {}
        self.action_list[0] = 107
        self.action_list[1] = 105
        self.action_list[2] = 106
        self.action_list[3] = 108
        self.action_list[4] = -1


        len_x = self.qipan_size
        len_y = self.qipan_size

        len_x_boss = 10
        len_y_boss = 10
        len_x_me = 5
        len_y_me = 5
        len_x_base = 13 
        len_y_base = 13 

        self.cmds = deque()
        self.qipan_list = []
        self.boss_list = []
        self.me_list = []
        self.base_list = []

        self.qipan = np.zeros((len_x, len_y))
        self.qipan = self.qipan.astype(np.uint8)
        self.qipan = np.full((len_x, len_y),0, dtype=np.uint8)
        self.qipan_show = self.qipan.copy()
        self.me = np.full((len_x_me, len_y_me),100, dtype=np.uint8)
        self.boss = np.full((len_x_boss, len_y_boss),200, dtype=np.uint8)
        self.base = np.full((len_x_base, len_y_base),50, dtype=np.uint8)

        self.boss_x1 = 0
        self.boss_x2 = self.boss.shape[1]
        self.boss_y1 = self.qipan.shape[0] - self.boss.shape[0]
        self.boss_y2 = self.qipan.shape[0]

        self.me_x1 = 0
        self.me_x2 = self.me.shape[1]
        self.me_y1 = 0 
        self.me_y2 = self.me.shape[0]

        self.base_x1 = self.qipan.shape[1] - self.base.shape[1]
        self.base_x2 = self.qipan.shape[1]
        self.base_y1 = self.qipan.shape[0] - self.base.shape[0] 
        self.base_y2 = self.qipan.shape[0]
        self.observation_space = spaces.Dict(
            {
                #"qipan": spaces.Box(low=0, high=(self.qipan_size-1), shape=(self.qipan_size*self.qipan_size,), dtype=int),
                "qipan": spaces.Box(low=0, high=(self.qipan_size*self.qipan_size-1), shape=(self.qipan_size*self.qipan_size,), dtype=int),
            }
        )


        self.action_space = spaces.Discrete(5)  # 5个动作
        #self.state = None


    #-----------------------------------------------------------------


    def reset(self, seed=None, options=None):
        self.jishu = 0


    #+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
        super().reset(seed=seed)
        self.boss_x1 = 0
        self.boss_x2 = self.boss.shape[1]
        self.boss_y1 = self.qipan.shape[0] - self.boss.shape[0]
        self.boss_y2 = self.qipan.shape[0]

        self.me_x1 = 0
        self.me_x2 = self.me.shape[1]
        self.me_y1 = 0 
        self.me_y2 = self.me.shape[0]

        self.base_x1 = self.qipan.shape[1] - self.base.shape[1]
        self.base_x2 = self.qipan.shape[1]
        self.base_y1 = self.qipan.shape[0] - self.base.shape[0] 
        self.base_y2 = self.qipan.shape[0]




        self.qipan_show[self.boss_x1:self.boss_x2,self.boss_y1:self.boss_y2] = self.boss
        self.qipan_show[self.me_x1:self.me_x2,self.me_y1:self.me_y2] = self.me
        self.qipan_show[self.base_x1:self.base_x2,self.base_y1:self.base_y2] = self.base

        #cv2.imshow('Gobang Board', cv2.resize(self.qipan_show,(700,700)))  
        #self.k = cv2.waitKey(1)  # 等待按键按下  


        observation_fangkuai = {"qipan":self.qipan.flatten()}
        info_fangkuai = {}
        return observation_fangkuai, info_fangkuai

    #--------------------------------------------------------------------------------------
    def baohan_panduan(self,sites_x,point):
        if point[0] > sites_x[0] and point[0] <sites_x[1] and point[1] > sites_x[2] and point[1] <sites_x[3]:
            return True
        return False 

    def pengzhuang_panduan(self,sites_a,sites_b):
        point_a_zuoshang = (sites_a[0],sites_a[2])
        point_a_youshang = (sites_a[1],sites_a[2])
        point_a_zuoxia = (sites_a[0],sites_a[3])
        point_a_youxia = (sites_a[1],sites_a[3])

        point_b_zuoshang = (sites_b[0],sites_b[2])
        point_b_youshang = (sites_b[1],sites_b[2])
        point_b_zuoxia = (sites_b[0],sites_b[3])
        point_b_youxia = (sites_b[1],sites_b[3])
        points_a = []
        points_b = []

        points_a.append(point_a_zuoshang)
        points_a.append(point_a_youshang)
        points_a.append(point_a_zuoxia)
        points_a.append(point_a_youxia)

        points_b.append(point_b_zuoshang)
        points_b.append(point_b_youshang)
        points_b.append(point_b_zuoxia)
        points_b.append(point_b_youxia)
        for point in points_a:
            if self.baohan_panduan(sites_b,point):
                return True

        for point in points_b:
            if self.baohan_panduan(sites_a,point):
                return True
        return False

    def step(self, action):
        self.jishu += 1
    #+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
        #cv2.imshow('Gobang Board', cv2.resize(self.qipan_show,(700,700)))  
        #self.k = cv2.waitKey(1)  # 等待按键按下  

        self.k = self.action_list[action]

        if self.k == -1:
            pass
        if self.k == 107 and self.me_x2 < self.qipan.shape[1]:
            self.me_x1 += 1
            self.me_x2 += 1
        if self.k == 105 and self.me_x1 > 0:
            self.me_x1 -= 1
            self.me_x2 -= 1
        if self.k == 106 and self.me_y1 > 0:
            self.me_y1 -= 1
            self.me_y2 -= 1
        if self.k == 108 and self.me_y2 < self.qipan.shape[0]:
            self.me_y1 += 1
            self.me_y2 += 1


        if self.k_boss == -1:
            pass
        if len(self.cmds) == 0:
            cmd1 = random.choice((115,119,97,100))
            cmd2 = random.choice((115,119,97,100))
            cmd_times = random.randint(1,self.qipan_size)
            for i in range(cmd_times):
                self.cmds.append(cmd1)
                self.cmds.append(cmd2)

        self.k_boss = self.cmds.popleft()
        if self.k_boss == 115 and self.boss_x2 < self.qipan.shape[1]:
            self.boss_x1 += 1
            self.boss_x2 += 1
        if self.k_boss == 119 and self.boss_x1 > 0:
            self.boss_x1 -= 1
            self.boss_x2 -= 1
        if self.k_boss == 97 and self.boss_y1 > 0:
            self.boss_y1 -= 1
            self.boss_y2 -= 1
        if self.k_boss == 100 and self.boss_y2 < self.qipan.shape[0]:
            self.boss_y1 += 1
            self.boss_y2 += 1
        boss_site = (self.boss_x1,self.boss_x2,self.boss_y1,self.boss_y2)
        me_site = (self.me_x1,self.me_x2,self.me_y1,self.me_y2)
        base_site = (self.base_x1,self.base_x2,self.base_y1,self.base_y2)
        self.qipan_show = self.qipan.copy()
        self.qipan_show[self.boss_x1:self.boss_x2,self.boss_y1:self.boss_y2] = self.boss
        self.qipan_show[self.me_x1:self.me_x2,self.me_y1:self.me_y2] = self.me
        self.qipan_show[self.base_x1:self.base_x2,self.base_y1:self.base_y2] = self.base
        reward = 1 
        terminated = False
        if self.pengzhuang_panduan(boss_site,me_site) or self.pengzhuang_panduan(me_site,boss_site):
            print('pengzhuang!!!!!!!!!!!!!!!!!!!!')
            reward = -100
            terminated = True
            log_add(f"lose {self.jishu}","fangkuai_log.txt")
            self.times_lose += 1
        if self.pengzhuang_panduan(base_site,me_site) or self.pengzhuang_panduan(me_site,base_site):
            print('base!!!!!!!!!!!!!!!!!!!!')
            reward = 100
            log_add(f"win {self.jishu}","fangkuai_log.txt")
            terminated = True
            self.times_win += 1

        observation_fangkuai = {"qipan":self.qipan.flatten()}
        truncated = False
        info_fangkuai = {}
        return observation_fangkuai, reward, terminated, False, info_fangkuai

    #-------------------------------------------------------------
    def render(self):
        #if self.render_mode == "rgb_array":
        #    return self._render_frame()
        pass

    def close(self):
        if self.window is not None:
            pygame.display.quit()
            pygame.quit()

日志文件fangkuai_log.txt内容为:

lose 32
lose 34
lose 32
lose 83
win 705
lose 538
lose 33
lose 28
lose 1394
lose 80
lose 490
lose 107
lose 938
lose 175
lose 681
lose 742
lose 930
lose 349
lose 301
lose 680
lose 480
lose 541
lose 36
lose 237
lose 55
lose 209
lose 544
lose 72
lose 321
lose 346
lose 628
lose 463
lose 450
lose 252
lose 95
lose 510
lose 71
lose 2320
lose 123
lose 36
lose 34
lose 667
lose 190
lose 36
lose 1029
lose 135
lose 36
lose 918
lose 206
lose 88
lose 71
lose 1013
lose 36
lose 234
lose 72
lose 71
lose 743
lose 665
lose 27
lose 294
lose 556
lose 231
lose 72
lose 657
lose 164
lose 410
win 1688
win 590
lose 413
lose 18
lose 18
lose 18
lose 19
win 108
lose 558
lose 35
lose 249
lose 81
win 294
win 825
lose 328
lose 262
lose 417
lose 72
lose 61
lose 131
lose 19
lose 36
lose 37
lose 39
lose 36
lose 39
lose 78
lose 178
lose 1282
lose 73
lose 174
lose 415
lose 239
lose 36
lose 61
lose 72
lose 174
lose 850
win 337
lose 71
lose 65
lose 36
lose 36
win 278
win 363
lose 35
lose 962
lose 72
lose 1173
lose 71
lose 377
lose 194
lose 314
lose 137
lose 208
lose 18
lose 18
lose 18
lose 19
lose 38
lose 37
lose 18
lose 34
lose 769
lose 345
lose 704
lose 516
lose 422
lose 843
win 747
lose 37
lose 38
lose 648
win 591
win 183
win 507
lose 36
lose 39
lose 99
lose 187
lose 346
lose 1003
win 112
lose 306
lose 81
lose 425
lose 149
lose 512
lose 25
lose 24
lose 39
lose 37
lose 28
lose 27
lose 103
lose 139
lose 562
win 1056
lose 566
lose 157
lose 331
lose 128
lose 349
win 194
win 1054
lose 40
lose 375
lose 19
lose 24
lose 38
lose 470
lose 374
lose 66
lose 345
lose 36
lose 673
lose 416
lose 36
win 340
lose 70
lose 36
lose 1018
lose 362
lose 257
lose 40
lose 40
lose 31
lose 308
lose 214
lose 436
lose 1013
win 209
lose 778
lose 38
lose 1006
lose 70
lose 306
win 334
lose 260
lose 721
lose 39
lose 85
lose 36
lose 145
lose 593
lose 659
lose 36
win 390
lose 42
win 69
lose 130
lose 199
win 1064
lose 255
lose 283
lose 386
lose 132
lose 148
lose 1823
lose 770
lose 48
lose 38
lose 367
lose 1327
lose 65
lose 485
lose 555
lose 35
lose 37
lose 82
win 2519
lose 596
win 394
lose 71
lose 71
lose 362
lose 153
lose 70
lose 704
lose 36
lose 376
lose 35
win 108
win 144
win 428
lose 242
lose 312
lose 112
win 116
win 176
lose 76
lose 37
win 418
lose 36
win 378
lose 36
lose 37
win 264
lose 76
lose 36
lose 93
lose 76
lose 133
lose 108
lose 26
lose 18
lose 19
lose 204
lose 96
win 159
win 343
win 390
lose 37
lose 40
lose 36
lose 411
lose 244
win 243
win 142
win 233
win 71
lose 417
lose 456
lose 29
lose 521
win 280
lose 37
lose 34
win 270
lose 212
lose 255
lose 492
win 376
lose 37
win 124
lose 555
win 368
lose 533
lose 49
lose 154
lose 72
lose 126
win 586
lose 136
win 929
win 83
win 1187
lose 121
win 312
lose 40
lose 37
lose 38
win 403
lose 27
lose 26
lose 26
win 219
lose 24
lose 30
lose 319
lose 53
lose 39
lose 38
lose 178
lose 115
win 603
lose 142
lose 701
lose 24
lose 19
lose 22
lose 121
lose 414
lose 292
lose 18
lose 18
win 94
------------------------------------------------------------
lose 324
lose 36
lose 36
lose 215
lose 249
lose 36
lose 1204
lose 173
lose 61
lose 36
lose 36
lose 531
lose 670
lose 1096
lose 3630
lose 36
lose 481
lose 72
lose 147
lose 72
lose 402
lose 90
lose 1157
lose 125
lose 74
lose 320
lose 937
lose 1042
lose 1695
lose 72
lose 266
lose 49
lose 36
lose 36
lose 966
lose 216
lose 756
lose 85
lose 417
lose 1642
lose 36
lose 617
lose 87
lose 36
lose 352
lose 451
lose 1663
lose 256
lose 73
lose 554
lose 89
lose 36
lose 148
lose 118
lose 2038
lose 727
lose 45
lose 130
lose 84
lose 1458
lose 337
lose 600
lose 387
lose 1357
lose 1241
lose 57
lose 118
lose 72
lose 262
lose 281
lose 550
lose 439
lose 153
lose 36
lose 79
lose 682
lose 1036
lose 239
lose 647
lose 52
lose 127
lose 72
lose 852
lose 1242
lose 72
lose 959
lose 71
lose 384
lose 126
lose 2896
lose 642
lose 72
lose 37
lose 475
lose 53
lose 1000
lose 36
lose 533
lose 550
lose 928
lose 36
lose 146
lose 1197
lose 350
lose 72
lose 2149
lose 72
lose 310
lose 71
lose 381
lose 1826
lose 634
lose 668
lose 678
lose 72
lose 40
lose 36
lose 152
lose 604
lose 599
lose 409
lose 82
lose 391
lose 394
lose 718
lose 1584
lose 301
lose 766
lose 38
lose 864
lose 436
lose 56
lose 340
lose 149
lose 36
lose 220
lose 36
lose 57
lose 881
lose 71
lose 540
lose 469
lose 159
lose 71
lose 85
lose 1666
lose 1732
lose 42
lose 1353
lose 37
lose 36
lose 953
lose 70
lose 50
lose 36
lose 222
lose 177
lose 601
lose 432
lose 511
lose 1445
lose 390
lose 417
lose 1448
lose 154
lose 36
lose 67
lose 36
lose 633
lose 689
lose 2284
lose 775
lose 476
lose 745
lose 1388
lose 567
lose 308
lose 1076
lose 1249
lose 71
lose 129
lose 1988
lose 497
lose 130
lose 774
lose 3409
lose 259
lose 252
lose 1680
lose 291
lose 285

从以上描述看出,我编写了一个游戏,用stablebaselines3训练它,收敛了,水平也提高了,但是在演示的时候,水平却远远低于训练时候的水平,跟没有任何训练一样,请问这是为什么

展开
收起
laixbreth 2024-12-05 08:05:12 23 0
0 条回答
写回答
取消 提交回答

大数据领域前沿技术分享与交流,这里不止有技术干货、学习心得、企业实践、社区活动,还有未来。

相关电子书

更多
低代码开发师(初级)实战教程 立即下载
冬季实战营第三期:MySQL数据库进阶实战 立即下载
阿里巴巴DevOps 最佳实践手册 立即下载