基于carla的模仿学习(附数据集CORL2017)更新中........
1.模仿学习
把专家的示范数据,拆分成状态和动作对以后,就看到这些数据,看起来就是些有标记的数据。也就是说我们可以看到每一个状态上面,专家都做了什么动作?让机器学习连续动作的直观想法是,用一种监督学习的方式,可以把这个状态作为我们监督学习里面的样本,动作作为监督学习里面的标记,把我们的状态当成神经网络的输入,把神经网络输出当成动作,最后用期望的动作,教机器学习状态和动作之间的相对应关系。如果我们这里的动作是一个理想的动作,我们可以用一些分类的算法,如果这里是一个控制的动作,我们可以用一些回归的方法。
以CORL2017数据集为例
state: 图片
action:acc brake steer
2.模仿学习分类
2.1行为克隆
2.2逆强化学习
(本文使用COLR数据集进行基于行为克隆的模仿学习)
3.代码
3.1使用行为克隆进行模仿学习的代码((python+tensorflow)一个模板,了解一下大概流程,后面会有详细解析))
借鉴AI天才学院
import tensorflow as tf
from tensorflow.keras import layers, models
import numpy as np
# 生成示例专家示范数据
def generate_expert_data(num_samples):
states = np.random.rand(num_samples, 4)
actions = np.random.rand(num_samples, 2)
return states, actions
# 构建行为克隆模型
def build_model(input_shape, output_shape):
model = models.Sequential()
model.add(layers.Dense(64, activation='relu', input_shape=input_shape))
model.add(layers.Dense(64, activation='relu'))
model.add(layers.Dense(output_shape))
model.compile(optimizer='adam', loss='mse')
return model
# 训练行为克隆模型
def train_model(model, states, actions, epochs=50, batch_size=32):
model.fit(states, actions, epochs=epochs, batch_size=batch_size)
# 评估行为克隆模型
def evaluate_model(model, states, actions):
loss = model.evaluate(states, actions)
print(f"Evaluation loss: {loss}")
# 主函数
if __name__ == "__main__":
num_samples = 1000
states, actions = generate_expert_data(num_samples)
model = build_model((4,), 2)
train_model(model, states, actions)
evaluate_model(model, states, actions)
3.2CORL数据集的数据采集
对h5文件进行分析
import h5py
import pandas as pd
import os
def read_corl2017_h5(file_path):
data = {}
with h5py.File(file_path, 'r') as hf:
# 遍历文件中的所有数据集并读取数据
for key in hf.keys():
data[key] = hf[key][:]
return data
# 替换为实际的文件路径
file_path = 'data_03664.h5'
train_file_path='G:\数据集\CORL2017ImitationLearningData\AgentHuman\SeqTrain'
val_file_path='G:\数据集\CORL2017ImitationLearningData\AgentHuman\Seqval'
h5_files = os.listdir(train_file_path)
image_dataset=[]
action_dataset=[]
for i in range(len(h5_files)):
#corl_data=image_files[i]
corl_data = read_corl2017_h5(train_file_path+'/'+h5_files[i])
#print(corl_data['targets'].shape)
print(corl_data['rgb'][0].shape)
for j in range(200):
#每个数据包含200个图像+200*28个车辆参数
corl_data['targets'][j][0]
corl_data['targets'][j][1]
corl_data['targets'][j][2]
if corl_data['targets'][j][1]>0:
break