Rust常用并发示例代码

2022-09-28 14:43:49 浏览数 (1)

记录几个常用的并发用法:


1、如何让线程只创建1次

先看一段熟悉的java代码:

代码语言:javascript复制
void method1() {
    new Thread(() -> {
        while (true) {
            System.out.println(String.format("thread-id:%s,timestamp:%d",
                    Thread.currentThread().getId(), System.currentTimeMillis()));
            try {
                Thread.sleep(1000);
            } catch (InterruptedException e) {
            }
        }
    }).start();
}

如果method1()被多次调用,就会创建多个线程,如果希望不管调用多少次,只能有1个线程,在不使用线程池的前提下,有1个简单的办法:

代码语言:javascript复制
AtomicBoolean flag = new AtomicBoolean(false);

void method1() {
    //AtomicBoolean保证线程安全,getAndSet是1个原子操作,method1只有第1次执行时,才能if判断才能通过
    if (!flag.getAndSet(true)) {
        new Thread(() -> {
            while (true) {
                System.out.println(String.format("thread-id:%s,timestamp:%d",
                        Thread.currentThread().getId(), System.currentTimeMillis()));
                try {
                    Thread.sleep(1000);
                } catch (InterruptedException e) {
                }
            }
        }).start();
    }
}

在rust中也可以套用这个思路,完整代码如下:

代码语言:javascript复制
use std::sync::atomic::{AtomicBool, Ordering};
use std::thread;
use std::time::{Duration, SystemTime, UNIX_EPOCH};

//声明1个全局静态变量(AtomicBool能保证线程安全)
static FLAG: AtomicBool = AtomicBool::new(false);

fn method1() {
    //fetch_update类似java中的AtomicBoolean.getAndSet
    if FLAG
        .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |_| Some(true))
        .unwrap()
    {
        std::thread::spawn(move || loop {
            println!(
                "thread-id:{:?},timestamp:{}",
                thread::current().id(),
                timestamp()
            );
            thread::sleep(Duration::from_millis(1000));
        });
    }
}

//辅助方法,获取系统时间戳(不用太关注这个方法)
fn timestamp() -> i64 {
    let start = SystemTime::now();
    let since_the_epoch = start
        .duration_since(UNIX_EPOCH)
        .expect("Time went backwards");
    let ms = since_the_epoch.as_secs() as i64 * 1000
          (since_the_epoch.subsec_nanos() as f64 / 1_000_000.0) as i64;
    ms
}
fn main() {
    //调用2次
    method1();
    method1();

    //用1个死循环,防止main线束(仅演示用)
    loop {
        thread::sleep(Duration::from_millis(1000));
    }
}

输出:

代码语言:javascript复制
thread-id:ThreadId(2),timestamp:1662265684621
thread-id:ThreadId(2),timestamp:1662265685623
thread-id:ThreadId(2),timestamp:1662265686627
thread-id:ThreadId(2),timestamp:1662265687628
thread-id:ThreadId(2),timestamp:1662265688630
...

从输出的线程id上看,2次method1()只创建了1个线程


2、如何让线程执行完再继续

代码语言:javascript复制
fn main() {
    let mut thread_list = Vec::<thread::JoinHandle<()>>::new();
    for _i in 0..5 {
        let t = thread::spawn(|| {
            for n in 1..3 {
                println!("{:?}, n:{}", thread::current().id(), n);
                thread::sleep_ms(5);
            }
        });
        thread_list.push(t);
    }
    //运行后会发现,大概率只有下面这行会输出,因为main已经提前线束了,上面的线程没机会执行,就被顺带着被干掉了
    println!("main thread");
}

上面这段代码,如果希望在main主线程结束前,让所有创建出来的子线程执行完,可以使用join方法

代码语言:javascript复制
fn main() {
    let mut thread_list = Vec::<thread::JoinHandle<()>>::new();
    for _i in 0..5 {
        let t = thread::spawn(|| {
            for n in 1..3 {
                println!("{:?}, n:{}", thread::current().id(), n);
                thread::sleep_ms(5);
            }
        });
        thread_list.push(t);
    }
 
    //将所有线程join,强制执行完后,才继续
    for t in thread_list {
        t.join().unwrap();
    }
 
    println!("main thread");
}

3、线程互斥锁

代码语言:javascript复制
use std::sync::{Arc, Mutex};
use std::thread;
use std::time::Duration;

fn main() {
    //声明1个互斥锁Mutex,注意在多线程中使用时,必须套一层Arc
    let flag = Arc::new(Mutex::new(false));
    let mut handlers = vec![];
    for _ in 0..10 {
        let flag = Arc::clone(&flag);
        let handle = thread::spawn(move || {
            thread::sleep(Duration::from_millis(10));
            //只有1个线程会lock成功
            let mut b = flag.lock().unwrap();
            if !*b {
                //抢到锁的,把标志位改成true,其它线程就没机会执行println
                *b = true;
                println!("subt=>t{:?}", thread::current().id());
            }
        });
        handlers.push(handle);
    }
    for h in handlers {
        h.join().unwrap();
    }
    println!("maint=>t{:?}", thread::current().id());
}

上面的效果,9个子线程中,只会有1个抢到锁,并输出println,输出类似下面这样:

代码语言:javascript复制
sub     =>      ThreadId(2)
main    =>      ThreadId(1)

4、线程之间发送数据

代码语言:javascript复制
use std::sync::mpsc;
use std::sync::mpsc::{Receiver, Sender};
use std::thread;
use std::thread::JoinHandle;
use std::time::Duration;
 
fn main() {
    let (sender, receiver) = mpsc::channel();
    let t1 = send_something(sender);
    let t2 = receive_something(receiver);
 
    t1.join().unwrap();
    t2.join().unwrap();
}
 
/**
 * 线程发送消息测试
 */
fn send_something(tx: Sender<String>) -> JoinHandle<()> {
    thread::spawn(move || {
        //模拟先做其它业务处理
        thread::sleep(Duration::from_millis(100));
 
        let msg_list = vec![
            String::from("a"),
            String::from("b"),
            String::from("c"),
            //约定:n是数据的结束符
            String::from("n"),
        ];
 
        //发送一堆消息
        for msg in msg_list {
            tx.send(msg).unwrap();
        }
    })
}
 
/**
 * 线程收消息
 */
fn receive_something(rx: Receiver<String>) -> JoinHandle<()> {
    thread::spawn(move || loop {
        //try_recv 不会阻塞
        let s = rx.try_recv();
        if s.is_ok() {
            let msg = s.unwrap();
            if msg == "n" {
                //约定:收到n表示后面没数据了,可以退出
                println!("end!");
                break;
            } else {
                println!("got msg:{}", msg);
            }
        }
        //模拟没数据时干点其它事情
        println!("do another thing!");
        thread::sleep(Duration::from_millis(100));
    })
}

输出:

代码语言:javascript复制
do another thing!
do another thing!
got msg:a
do another thing!
got msg:b
do another thing!
got msg:c
do another thing!
end!

5、线程池示例

先要引用threadpool的依赖

代码语言:javascript复制
[dependencies]
threadpool="1.8.1"

然后就可以使用了:

代码语言:javascript复制
use std::thread;
use std::time::Duration;
use threadpool::ThreadPool;

fn main() {
    let n_workers = 3;
    //创建1个名为test-pool的线程池
    let pool = ThreadPool::with_name(String::from("test-pool"), n_workers);

    for _ in 0..10 {
        pool.execute(|| {
            println!(
                "{:?},{:?}",
                thread::current().id(),
                thread::current().name()
            );
            thread::sleep(Duration::from_millis(100));
        });
    }

    //待线程池中的所有任务都执行完
    pool.join();
}

输出:

代码语言:javascript复制
ThreadId(2),Some("test-pool")
ThreadId(3),Some("test-pool")
ThreadId(4),Some("test-pool")
ThreadId(2),Some("test-pool")
ThreadId(3),Some("test-pool")
ThreadId(4),Some("test-pool")
ThreadId(2),Some("test-pool")
ThreadId(3),Some("test-pool")
ThreadId(4),Some("test-pool")
ThreadId(2),Some("test-pool")

6、指定线程名称

代码语言:javascript复制
use std::thread;

fn main() {
    let t1 = thread::Builder::new()
        //子线程命名
        .name(format!("my-thread"))
        .spawn(|| {
            //打印子线程的id和name
            println!(
                "{:?},{:?}",
                thread::current().id(),
                thread::current().name()
            );
        })
        .unwrap();
    t1.join().unwrap();

    //打印主线程的id和name
    println!(
        "{:?},{:?}",
        thread::current().id(),
        thread::current().name()
    );
}

输出:

代码语言:javascript复制
ThreadId(2),Some("my-thread")
ThreadId(1),Some("main")

7、如何暂停/恢复线程运行

代码语言:javascript复制
use std::time::Duration;

use std::thread;

fn main() {
    let (tx, rx) = std::sync::mpsc::channel();
    let t = thread::spawn(move || loop {
        //获取当前时间的秒数
        let seconds = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap()
            .as_secs();
        println!("{}", seconds);
        //每当秒数为5的倍数,就把自己暂停,同时对外发消息pause
        if seconds % 5 == 0 {
            tx.send("pause").unwrap();
            println!("nwill be parked !!!");
            //将自己暂停
            thread::park();
        }
        thread::sleep(Duration::from_millis(1000));
    });

    //不断收消息,发现是pause后,过3秒将线程t解封
    loop {
        let flag = rx.recv();
        if flag.is_ok() && flag.unwrap() == "pause" {
            thread::sleep(Duration::from_millis(3000));
            //解封t
            t.thread().unpark();
            println!("unparked !!!n");
        }
    }
}

这样就实现了一个简易版的ScheudleThread,可以周期性的运行,运行效果:

代码语言:javascript复制
1662278909
1662278910

will be parked !!!
unparked !!!

1662278914
1662278915

will be parked !!!
unparked !!!

1662278919
1662278920
...

8、信号量

推荐使用tokio的信号量实现

代码语言:javascript复制
[dependencies]
tokio = { version = "1.21.0", features = ["full"] }

示例:

代码语言:javascript复制
use std::sync::Arc;
use tokio::sync::Semaphore;

#[tokio::main]
async fn main() {
    let semaphore = Arc::new(Semaphore::new(1));
    println!("1-{:?}", semaphore);

    let _s = semaphore.clone().acquire_owned().await.unwrap();
    //消耗了1个信号量后,只剩下0
    println!("2-{:?}", semaphore);

    //此时再尝试获取信号量,会卡在这里,直到有人把信号号释放归还
    let _s = semaphore.clone().acquire_owned().await.unwrap();
    println!("3-{:?}", semaphore);
    println!("done");
}

输出:

代码语言:javascript复制
1-Semaphore { ll_sem: Semaphore { permits: 1 } }
2-Semaphore { ll_sem: Semaphore { permits: 0 } }
...会卡在这里

要归还信号号,可以使用drop方法

代码语言:javascript复制
use std::sync::Arc;
use tokio::sync::Semaphore;

#[tokio::main]
async fn main() {
    let semaphore = Arc::new(Semaphore::new(1));
    println!("1-{:?}", semaphore);

    let _s = semaphore.clone().acquire_owned().await.unwrap();
    println!("2-{:?}", semaphore);

    //信号号使用后,要记得归原
    drop(_s);
    println!("归原后-{:?}", semaphore);

    //只要剩余信号量>0,就不会卡住了
    let _s = semaphore.clone().acquire_owned().await.unwrap();
    println!("3-{:?}", semaphore);
    println!("done");
}

输出:

代码语言:javascript复制
1-Semaphore { ll_sem: Semaphore { permits: 1 } }
2-Semaphore { ll_sem: Semaphore { permits: 0 } }
归原后-Semaphore { ll_sem: Semaphore { permits: 1 } }
3-Semaphore { ll_sem: Semaphore { permits: 0 } }
done

9、条件变量Condvar

这个东西,要跟Mutex互斥锁一起使用,不要问为什么,Condvar的wait方法签名设计就是这样的!

但其实使用过程中,Mutex的值完全可以跟Condvar没任何关系,把官网的示例修改了下(注:可能没啥实际意义,只是出于演示)

代码语言:javascript复制
use std::sync::{Arc, Condvar, Mutex};
use std::thread;
use std::time::Duration;

fn current_seconds() -> u64 {
    std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .unwrap()
        .as_secs()
}

fn main() {
    println!("11111t -> now:{}", current_seconds());
    let pair = Arc::new((Mutex::new(true), Condvar::new()));
    let pair2 = pair.clone();

    thread::spawn(move || {
        thread::sleep(Duration::from_secs(2));
        let &(_, ref cvar) = &*pair2;
        //唤醒被block的线程
        cvar.notify_one();
        println!("threadt -> now:{}", current_seconds());
    });

    let &(_, ref cvar) = &*pair;
    println!("22222t -> now:{}", current_seconds());
    //这里会阻塞住,直到子线程里notify_one通知
    //这里可以看出cvar的wait中完全可以传1个不相关的mutex!
    let no_use = Mutex::new(0);
    let _ = cvar.wait(no_use.lock().unwrap()).unwrap();

    println!("33333t -> now:{}", current_seconds());
}

这里main主线程在调用cvar.wait方法时会block住,直到子线程2秒后,cvar.notify_one()将其唤醒,输出:

代码语言:javascript复制
11111    -> now:1662285716
22222    -> now:1662285716
thread   -> now:1662285718
33333    -> now:1662285718

参考文章:

Rust语言圣经-多线程并发编程

0 人点赞