并发编程基础
Rust 并发编程基础
概述
Rust 的所有权系统和类型系统（通过 `Send`/`Sync` trait）让并发编程更加安全：在安全（safe）Rust 中，编译器会在编译期拒绝存在数据竞争的代码，这是 Rust 的一大优势。
线程基础
Rust 标准库的 `std::thread` 采用 1:1 线程模型：每个 Rust 线程对应一个操作系统线程。
简单示例
use std::thread;
use std::time::Duration;

/// Runs a spawned thread alongside the main thread.
///
/// The spawned thread prints `子线程: 1` through `子线程: 9` and the main
/// thread prints `主线程: 1` through `主线程: 4`, each sleeping 1 ms after
/// every line so the output interleaves. The main thread then `join`s the
/// spawned thread, so all nine child lines are printed before returning.
fn basic_threading() {
    let pause = Duration::from_millis(1);

    let child = thread::spawn(move || {
        (1..10).for_each(|n| {
            println!("子线程: {}", n);
            thread::sleep(pause);
        });
    });

    for n in 1..5 {
        println!("主线程: {}", n);
        thread::sleep(pause);
    }

    child.join().unwrap();
}
复杂案例：实现一个并发任务调度器（线程池）
use std::sync::{Arc, Mutex, Condvar};
use std::sync::mpsc::{self, Sender, Receiver};
use std::thread;
use std::time::Duration;
use std::collections::VecDeque;// 任务类型
/// A unit of work for the pool: a boxed closure that runs once and can be
/// moved to a worker thread.
type Task = Box<dyn FnOnce() + Send + 'static>;

/// What a worker thread is currently doing.
#[derive(Debug, Clone, Copy, PartialEq)]
enum WorkerState {
    Idle,
    Busy,
    Stopped,
}

/// One pool thread plus the handle used to join it on shutdown.
struct Worker {
    id: usize,
    thread: Option<thread::JoinHandle<()>>,
}

impl Worker {
    /// Spawns worker `id`, which pulls tasks from the shared `receiver` until
    /// the channel is closed.
    ///
    /// `state` is set to `Busy` while a task runs and back to `Idle` after it.
    /// After each task the worker decrements `pending` and signals `condvar`.
    fn new(
        id: usize,
        receiver: Arc<Mutex<Receiver<Task>>>,
        state: Arc<Mutex<WorkerState>>,
        pending: Arc<Mutex<usize>>,
        condvar: Arc<Condvar>,
    ) -> Self {
        let thread = thread::spawn(move || loop {
            // The receiver guard is a temporary dropped at the end of this
            // statement, so the lock is released before the task runs and
            // other workers can receive in the meantime.
            let message = receiver.lock().unwrap().recv();
            match message {
                Ok(task) => {
                    *state.lock().unwrap() = WorkerState::Busy;
                    println!("工作线程 {} 正在执行任务", id);
                    task();
                    *state.lock().unwrap() = WorkerState::Idle;
                    // Decrement under the same mutex that `wait_completion`
                    // checks, then notify, so the waiter cannot miss it.
                    *pending.lock().unwrap() -= 1;
                    condvar.notify_all();
                }
                // `recv` fails only once every `Sender` is gone and the queue
                // is drained, i.e. when the pool is being dropped.
                Err(_) => {
                    println!("工作线程 {} 停止", id);
                    *state.lock().unwrap() = WorkerState::Stopped;
                    break;
                }
            }
        });

        Worker {
            id,
            thread: Some(thread),
        }
    }
}

/// A fixed-size pool of worker threads fed through a single mpsc channel.
struct ThreadPool {
    workers: Vec<Worker>,
    /// `Some` for the pool's whole life; taken in `Drop` to close the channel.
    sender: Option<Sender<Task>>,
    worker_states: Vec<Arc<Mutex<WorkerState>>>,
    /// Number of tasks submitted via `execute` that have not finished yet.
    pending: Arc<Mutex<usize>>,
    /// Signalled by workers each time `pending` is decremented.
    condvar: Arc<Condvar>,
}

impl ThreadPool {
    /// Creates a pool with `size` worker threads.
    ///
    /// # Panics
    ///
    /// Panics if `size` is zero.
    fn new(size: usize) -> Self {
        assert!(size > 0);
        let (sender, receiver) = mpsc::channel();
        let receiver = Arc::new(Mutex::new(receiver));
        let pending = Arc::new(Mutex::new(0));
        let condvar = Arc::new(Condvar::new());

        let mut workers = Vec::with_capacity(size);
        let mut worker_states = Vec::with_capacity(size);
        for id in 0..size {
            let state = Arc::new(Mutex::new(WorkerState::Idle));
            worker_states.push(Arc::clone(&state));
            workers.push(Worker::new(
                id,
                Arc::clone(&receiver),
                state,
                Arc::clone(&pending),
                Arc::clone(&condvar),
            ));
        }

        ThreadPool {
            workers,
            sender: Some(sender),
            worker_states,
            pending,
            condvar,
        }
    }

    /// Queues `f` to run on some worker thread.
    ///
    /// # Panics
    ///
    /// Panics if every worker thread has already exited, since the channel's
    /// receiver is then gone.
    fn execute<F>(&self, f: F)
    where
        F: FnOnce() + Send + 'static,
    {
        // Count the task before any worker can see it; otherwise a fast
        // worker could finish it and decrement before this increment.
        *self.pending.lock().unwrap() += 1;
        self.sender
            .as_ref()
            .expect("sender is only taken in Drop")
            .send(Box::new(f))
            .expect("all worker threads have exited");
    }

    /// Returns how many workers are currently running a task.
    fn active_count(&self) -> usize {
        self.worker_states
            .iter()
            .filter(|state| *state.lock().unwrap() == WorkerState::Busy)
            .count()
    }

    /// Blocks until every task submitted with `execute` has finished,
    /// including tasks that are still waiting in the channel.
    ///
    /// NOTE: a task that panics unwinds its worker before the pending count is
    /// decremented, so after such a panic this call never returns.
    fn wait_completion(&self) {
        let pending = self.pending.lock().unwrap();
        let _guard = self
            .condvar
            .wait_while(pending, |remaining| *remaining > 0)
            .unwrap();
    }
}

impl Drop for ThreadPool {
    /// Closes the task channel and joins every worker.
    ///
    /// Taking the pool's only `Sender` disconnects the channel, so each
    /// worker's `recv()` returns `Err` once the queue is drained and the worker
    /// leaves its loop.
    fn drop(&mut self) {
        drop(self.sender.take());
        for worker in &mut self.workers {
            if let Some(handle) = worker.thread.take() {
                handle.join().unwrap();
            }
        }
    }
}
// Thread-pool demo
/// Thread-pool demo: queues ten tasks (each sleeps 500 ms) on a pool of four
/// workers, then calls `wait_completion` before printing the final message.
/// The pool is dropped when this function returns.
fn demonstrate_thread_pool() {
    let pool = ThreadPool::new(4);

    println!("提交 10 个任务到线程池");
    (0..10).for_each(|task_no| {
        pool.execute(move || {
            println!("任务 {} 开始执行", task_no);
            thread::sleep(Duration::from_millis(500));
            println!("任务 {} 完成", task_no);
        })
    });

    println!("等待所有任务完成...");
    pool.wait_completion();
    println!("所有任务已完成");
}
// Shared state with Arc and Mutex
/// A thread-safe `i32` counter. Every clone (or `clone_counter` handle)
/// refers to the same underlying value.
#[derive(Clone, Default)]
struct SharedCounter {
    count: Arc<Mutex<i32>>,
}

impl SharedCounter {
    /// Creates a counter starting at 0.
    fn new() -> Self {
        Self::default()
    }

    /// Adds 1 to the shared value; the lock is held only for this update.
    fn increment(&self) {
        *self.count.lock().unwrap() += 1;
    }

    /// Returns the current value.
    fn get(&self) -> i32 {
        *self.count.lock().unwrap()
    }

    /// Returns another handle to the same counter (same as `Clone::clone`).
    fn clone_counter(&self) -> Self {
        self.clone()
    }
}

/// Shared-state demo: ten threads each increment one counter 100 times, then
/// the final value is printed.
fn demonstrate_shared_state() {
    let counter = SharedCounter::new();
    let handles: Vec<_> = (0..10)
        .map(|_| {
            let counter = counter.clone();
            thread::spawn(move || {
                for _ in 0..100 {
                    counter.increment();
                }
            })
        })
        .collect();
    for handle in handles {
        handle.join().unwrap();
    }
    println!("最终计数: {}", counter.get());
}
// Message-passing concurrency
/// Message-passing demo: five producer threads each send ten strings over one
/// channel (sleeping 100 ms after each send); the calling thread prints every
/// message it receives and then the total count.
fn demonstrate_message_passing() {
    let (tx, rx) = mpsc::channel();

    // One cloned sender per producer thread.
    for i in 0..5 {
        let sender = tx.clone();
        thread::spawn(move || {
            for j in 0..10 {
                sender.send(format!("线程 {} 发送消息 {}", i, j)).unwrap();
                thread::sleep(Duration::from_millis(100));
            }
        });
    }

    // Drop the original sender so the receive loop ends once every producer
    // thread has finished and dropped its clone.
    drop(tx);

    let total = rx
        .iter()
        .inspect(|msg| println!("收到: {}", msg))
        .count();
    println!("总共收到 {} 条消息", total);
}
// Producer-consumer pattern
/// A bounded FIFO buffer of `i32` values shared by producer and consumer
/// threads.
///
/// Handles made with `clone_pc` share one queue and one `Condvar`. The same
/// condition variable is used both for "buffer has space" and "buffer has
/// data", which is why every state change wakes all waiters with `notify_all`.
struct ProducerConsumer {
    queue: Arc<Mutex<VecDeque<i32>>>,
    condvar: Arc<Condvar>,
    max_size: usize,
}

impl ProducerConsumer {
    /// Creates an empty buffer holding at most `max_size` items.
    ///
    /// With `max_size == 0`, `produce` can never make progress and blocks
    /// forever.
    fn new(max_size: usize) -> Self {
        ProducerConsumer {
            queue: Arc::new(Mutex::new(VecDeque::new())),
            condvar: Arc::new(Condvar::new()),
            max_size,
        }
    }

    /// Appends `item`, blocking while the buffer is full, then wakes waiters.
    fn produce(&self, item: i32) {
        let mut queue = self.queue.lock().unwrap();
        // `while`, not `if`: wake-ups can be spurious, and another producer may
        // have refilled the buffer first.
        while queue.len() >= self.max_size {
            queue = self.condvar.wait(queue).unwrap();
        }
        queue.push_back(item);
        println!("生产: {}, 队列大小: {}", item, queue.len());
        self.condvar.notify_all();
    }

    /// Removes and returns the oldest item, blocking while the buffer is empty.
    ///
    /// Always returns `Some`: if nothing is ever produced again, this call
    /// blocks forever, so callers must not request more items than will be
    /// produced.
    fn consume(&self) -> Option<i32> {
        let mut queue = self.queue.lock().unwrap();
        while queue.is_empty() {
            queue = self.condvar.wait(queue).unwrap();
        }
        let item = queue.pop_front();
        if let Some(i) = item {
            println!("消费: {}, 队列大小: {}", i, queue.len());
        }
        self.condvar.notify_all();
        item
    }

    /// Returns another handle to the same buffer.
    fn clone_pc(&self) -> Self {
        ProducerConsumer {
            queue: Arc::clone(&self.queue),
            condvar: Arc::clone(&self.condvar),
            max_size: self.max_size,
        }
    }
}

/// Runs one producer thread and `consumers` consumer threads over a shared
/// buffer with room for `capacity` items, and returns how many items were
/// consumed.
///
/// The producer pushes `0..items`, sleeping `produce_delay` after each. The
/// items are split among the consumers so that the shares add up exactly
/// (e.g. 20 over 3 gives 7, 7, 6). Each consumer sleeps `consume_delay` after
/// each item. Because `consume` blocks until an item arrives, a consumer asked
/// for more items than are produced would hang the final `join`.
///
/// # Panics
///
/// Panics if `consumers` or `capacity` is zero.
fn run_producer_consumer(
    items: i32,
    consumers: usize,
    capacity: usize,
    produce_delay: Duration,
    consume_delay: Duration,
) -> usize {
    assert!(consumers > 0, "need at least one consumer");
    assert!(capacity > 0, "buffer capacity must be positive");

    let pc = ProducerConsumer::new(capacity);
    // Non-negative i32 always fits in usize on supported targets.
    let total = items.max(0) as usize;

    let producer = {
        let pc = pc.clone_pc();
        thread::spawn(move || {
            for i in 0..items {
                pc.produce(i);
                thread::sleep(produce_delay);
            }
        })
    };

    let consumer_handles: Vec<_> = (0..consumers)
        .map(|k| {
            // The first `total % consumers` consumers take one extra item.
            let share = total / consumers + usize::from(k < total % consumers);
            let pc = pc.clone_pc();
            thread::spawn(move || {
                for _ in 0..share {
                    pc.consume();
                    thread::sleep(consume_delay);
                }
                share
            })
        })
        .collect();

    producer.join().unwrap();
    consumer_handles
        .into_iter()
        .map(|h| h.join().unwrap())
        .sum()
}

/// Producer-consumer demo: 20 items produced every 50 ms into a buffer of 5,
/// drained by three consumers that each pause 150 ms per item.
fn demonstrate_producer_consumer() {
    run_producer_consumer(
        20,
        3,
        5,
        Duration::from_millis(50),
        Duration::from_millis(150),
    );
}
// Parallel computation example: parallel sum
/// Sums `data` by splitting it into contiguous chunks, one per thread.
///
/// `num_threads` is clamped to `1..=data.len()`: `0` behaves like `1`, and
/// no more threads are spawned than there are elements. An empty input
/// returns 0 without spawning anything.
///
/// # Panics
///
/// Panics if a worker thread panics (e.g. on `i32` overflow in a debug build).
fn parallel_sum(data: Vec<i32>, num_threads: usize) -> i32 {
    if data.is_empty() {
        return 0;
    }
    let num_threads = num_threads.clamp(1, data.len());
    let chunk_size = (data.len() + num_threads - 1) / num_threads;
    let data = Arc::new(data);

    // Collect the handles first: mapping lazily straight into `join` would
    // run the threads one after another instead of in parallel.
    let handles: Vec<_> = (0..num_threads)
        .map(|i| {
            let data = Arc::clone(&data);
            thread::spawn(move || {
                let start = i * chunk_size;
                let end = (start + chunk_size).min(data.len());
                // With rounded-up chunks the last threads can start past the
                // end (e.g. 5 items over 4 threads); `get` yields None there.
                data.get(start..end)
                    .map_or(0, |chunk| chunk.iter().sum::<i32>())
            })
        })
        .collect();

    handles.into_iter().map(|h| h.join().unwrap()).sum()
}

/// Sums 1..=1000 with four threads and prints the result.
fn demonstrate_parallel_sum() {
    let data: Vec<i32> = (1..=1000).collect();
    let sum = parallel_sum(data, 4);
    println!("并行求和结果: {}", sum);
}

/// Runs every demo in order.
fn main() {
    println!("=== 基础线程 ===");
    basic_threading();
    println!("\n=== 线程池 ===");
    demonstrate_thread_pool();
    println!("\n=== 共享状态 ===");
    demonstrate_shared_state();
    println!("\n=== 消息传递 ===");
    demonstrate_message_passing();
    println!("\n=== 生产者消费者 ===");
    demonstrate_producer_consumer();
    println!("\n=== 并行求和 ===");
    demonstrate_parallel_sum();
}
并发安全的数据结构
use std::sync::RwLock;

/// A `HashMap` behind an `RwLock`: many concurrent readers, one writer at a
/// time.
///
/// Cloning a `ConcurrentMap` only clones the inner `Arc`, so every clone is a
/// handle to the same map and can be moved into another thread.
struct ConcurrentMap<K, V> {
    data: Arc<RwLock<std::collections::HashMap<K, V>>>,
}

// Written by hand: `#[derive(Clone)]` would add `K: Clone` and `V: Clone`
// bounds even though only the `Arc` is cloned.
impl<K, V> Clone for ConcurrentMap<K, V> {
    fn clone(&self) -> Self {
        ConcurrentMap {
            data: Arc::clone(&self.data),
        }
    }
}

impl<K: Eq + std::hash::Hash, V> Default for ConcurrentMap<K, V> {
    fn default() -> Self {
        Self::new()
    }
}

impl<K: Eq + std::hash::Hash, V> ConcurrentMap<K, V> {
    /// Creates an empty map.
    fn new() -> Self {
        ConcurrentMap {
            data: Arc::new(RwLock::new(std::collections::HashMap::new())),
        }
    }

    /// Inserts `value` under `key`, replacing (and dropping) any previous
    /// value. Takes the write lock.
    fn insert(&self, key: K, value: V) {
        self.data.write().unwrap().insert(key, value);
    }

    /// Returns a clone of the value stored under `key`, if any. Takes the
    /// read lock, so concurrent `get`s do not block each other.
    fn get(&self, key: &K) -> Option<V>
    where
        V: Clone,
    {
        self.data.read().unwrap().get(key).cloned()
    }
}
总结
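Rust 的并发模型通过所有权和类型系统（`Send`/`Sync`）在编译期排除了数据竞争。使用 Arc、Mutex、RwLock、Condvar、通道等工具，可以安全高效地编写并发程序。但需要注意：编译器并不能防止死锁、丢失唤醒（lost wakeup）等逻辑层面的并发错误。例如，等待条件必须在与通知方相同的互斥锁保护下检查；通道的接收循环只有在所有发送端都被释放后才会结束；消费者请求的元素数量不能超过生产者实际生产的数量。这些仍需开发者自行保证。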
