当前位置: 首页 > news >正文

【逻辑回归】MAP - Charting Student Math Misunderstandings

MAP - Charting Student Math Misunderstandings

数据分析

import pandas as pd# Load the dataset
file_path = '/mnt/data/train.csv'
data = pd.read_csv(file_path)# Display the first few rows of the dataset to understand its structure
data.head()

结果

   row_id  QuestionId                                       QuestionText  \
0       0       31772  What fraction of the shape is not shaded? Give...   
1       1       31772  What fraction of the shape is not shaded? Give...   
2       2       31772  What fraction of the shape is not shaded? Give...   
3       3       31772  What fraction of the shape is not shaded? Give...   
4       4       31772  What fraction of the shape is not shaded? Give...   MC_Answer                                 StudentExplanation  \
0  \( \frac{1}{3} \)                  0ne third is equal to tree nineth   
1  \( \frac{1}{3} \)  1 / 3 because 6 over 9 is 2 thirds and 1 third...   
2  \( \frac{1}{3} \)  1 3rd is half of 3 6th, so it is simplee to un...   
3  \( \frac{1}{3} \)        1 goes into everything and 3 goes into nine   
4  \( \frac{1}{3} \)                    1 out of every 3 isn't coloured   Category Misconception  
0  True_Correct           NaN  
1  True_Correct           NaN  
2  True_Neither           NaN  
3  True_Neither           NaN  
4  True_Correct           NaN  

数据集包含以下几列:

  • row_id: 每行的唯一标识符。
  • QuestionId: 每个问题的唯一标识符。
  • QuestionText: 问题的文本描述。
  • MC_Answer: 学生选择的答案,通常是数学表达式。
  • StudentExplanation: 学生的解释或回答,可能包含错误或不清楚的地方。
  • Category: 问题类别,如“True_Correct”或“True_Neither”。
  • Misconception: 如果学生的解释中存在误解,该列为标签(缺失值可能表示没有误解)。

模型训练

接下来,进行数据预处理,并使用学生解释文本进行特征提取。我们将清理文本、进行分词、去除停用词,并准备数据进行训练。之后,我会选择一个简单的模型-逻辑回归来作为起点,进行训练并评估。

from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import classification_report
import string# Step 1: Clean the student explanation text (remove punctuation and lower case)
def clean_text(text):text = text.lower()  # convert to lower casetext = ''.join([char for char in text if char not in string.punctuation])  # remove punctuationreturn text# Apply the cleaning function to the 'StudentExplanation' column
data['cleaned_explanation'] = data['StudentExplanation'].apply(clean_text)# Step 2: Use TF-IDF to convert text to numerical features
vectorizer = TfidfVectorizer(stop_words='english', max_features=5000)
X = vectorizer.fit_transform(data['cleaned_explanation'])# Step 3: Prepare labels (Misconception column)
# We will predict if the explanation contains a misconception or not, so we need to define labels
# If 'Misconception' column is NaN, we can assume it's no misconception for simplicity
data['Misconception'] = data['Misconception'].fillna('No_Misconception')# Convert labels to binary: 'No_Misconception' -> 0, any other label -> 1
y = data['Misconception'].apply(lambda x: 0 if x == 'No_Misconception' else 1)# Step 4: Split the data into training and testing sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)# Step 5: Train a Logistic Regression model
model = LogisticRegression()
model.fit(X_train, y_train)# Step 6: Evaluate the model
y_pred = model.predict(X_test)
classification_report(y_test, y_pred)
           precision    recall  f1-score   support0       0.81      0.93      0.87      52771       0.71      0.46      0.56      2063accuracy                           0.80      7340macro avg       0.76      0.69      0.71      7340
weighted avg       0.79      0.80      0.78      7340

模型训练完成并进行了评估。分类报告的结果:

  • 准确率 (Accuracy): 80%

  • 类别 0(无误解)的表现

    • 精确度 (Precision): 82%
    • 召回率 (Recall): 93%
    • F1 分数: 87%
  • 类别 1(含误解)的表现

    • 精确度 (Precision): 72%
    • 召回率 (Recall): 46%
    • F1 分数: 56%

结果分析:

  • 模型在识别没有误解的学生解释时表现得很好(高召回率)。
  • 但是,识别含有误解的解释的表现相对较弱,特别是在召回率方面,可能表明存在不平衡的数据问题。

完整代码

import pandas as pd
import string
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import classification_report# Step 1: Load and clean the data
# Load the dataset
file_path = 'train.csv'
data = pd.read_csv(file_path)# Clean the student explanation text (remove punctuation and lower case)
def clean_text(text):text = text.lower()  # Convert to lower casetext = ''.join([char for char in text if char not in string.punctuation])  # Remove punctuationreturn text# Apply the cleaning function to the 'StudentExplanation' column
data['cleaned_explanation'] = data['StudentExplanation'].apply(clean_text)# Step 2: Feature extraction using TF-IDF
vectorizer = TfidfVectorizer(stop_words='english', max_features=5000)
X = vectorizer.fit_transform(data['cleaned_explanation'])# Step 3: Prepare labels (Misconception column)
# We will predict if the explanation contains a misconception or not
# Fill NaN values with 'No_Misconception' as default
data['Misconception'] = data['Misconception'].fillna('No_Misconception')# Convert labels to binary: 'No_Misconception' -> 0, any other label -> 1
y = data['Misconception'].apply(lambda x: 0 if x == 'No_Misconception' else 1)# Step 4: Split the data into training and testing sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)# Step 5: Train a Logistic Regression model
model = LogisticRegression(max_iter=1000)  # Increase iterations to ensure convergence
model.fit(X_train, y_train)# Step 6: Evaluate the model
y_pred = model.predict(X_test)
print(classification_report(y_test, y_pred))

下一步:

  • 调整模型:可以考虑使用其他模型,如支持向量机(SVM)或深度学习模型(例如 BERT)来提高对含有误解解释的识别能力。

  • 数据不平衡问题:我们可以尝试使用过采样或欠采样技术,或者调整类别权重来改善模型对含误解解释的识别。

http://www.dtcms.com/a/287064.html

相关文章:

  • PostgreSQL ORDER BY 语句详解
  • bash方式启动模型训练
  • tkinter绘制组件(45)——导航栏
  • EP01:【Python 第一弹】基础入门知识
  • 国产电科金仓数据库:融合进化,智领未来
  • C++进阶课程第4期——动态规划
  • FastAPI遇上GraphQL:异步解析器如何让API性能飙升?
  • C++中的list(1)
  • c#中ArrayList和List的常用方法
  • 微信小程序入门实例_____从零开始 开发一个“旅行清单 ”微信小程序
  • Flutter基础(前端教程①④-data.map和assignAll和fromJson和toList)
  • 【深度学习笔记 Ⅱ】1 数据集的划分
  • Linux:多线程---深入生产消费模型环形队列生产消费模型
  • OSPF高级特性之Overflow
  • MyBatis之缓存机制详解
  • uniapp中报错:ReferenceError: FormData is not defined
  • 【安卓笔记】RxJava的Hook机制,整体拦截器
  • JavaScript空值安全深度指南
  • 加线机 和 胶带机
  • LVS 集群技术实践:NAT 与 DR 模式的配置与对比
  • 如何在HTML5页面中嵌入视频
  • 基于单片机宠物喂食器/智能宠物窝/智能饲养
  • 车载传统ECU---MCU软件架构设计指南
  • MSTP 多生成树协议
  • 零基础学后端-PHP语言(第一期-PHP环境配置)
  • 题解:CF1690G Count the Trains
  • 【C++基础】--多态
  • PortSwigger Labs 之 点击劫持利用
  • Go语言流程控制(if / for)
  • 编程研发工作日记_廖万忠_2016_2017