为什么不能把 tokio::task::spawn 调用重构到一个 fn 中?
Why can't a tokio::task::spawn call be refactored into a fn?
抱歉,示例有点长,相关部分在最上面。我把完整代码都贴了出来,方便感兴趣的人直接动手尝试。这里是 Rust Playground 链接:
概括来说:我用广播通道(broadcast channel)搭了一条菊花链,向第一个节点发送消息,并从最后一个节点接收通知。每创建一个转发器(forwarder),就会 spawn 一个任务,在循环中等待 recv 完成。这一切都能正常工作。但是,当我想把创建任务的代码重构到一个 fn 中时,似乎碰上了某些我一直没能解决的异步规则。
我使用了以下 crate:
[dependencies]
tokio = { version = "0.3", features = ["full"] }
async-trait = "0.1.42"
以下是完整代码,方便直接在这里查看:
use std::sync::Arc;
use std::fmt;
use async_trait::async_trait;
use tokio::sync::broadcast;
use tokio::spawn;
use tokio::sync::Mutex;
use tokio::runtime;
fn main() -> Result<(), Box<dyn std::error::Error>> {
let rt = runtime::Runtime::new()?;
let mut tasks = Vec::new();
let mut first = None;
let mut prev = None;
rt.block_on(async {
// setup a daisy-chain of 5 machines
for id in 1..=5 {
// create sender and adapter wrapping forwarder
let (_, s, mut a) = build_bounded(Forwarder::new(id), 100);
// create a task to run the recv loop -- consider using stream
我想将这个 spawn 拉出到 returns 任务的 fn 中。
// why can't this be refactored into a fn that returns task?
let task = spawn(async move {
loop {
let cmd = a.receiver.recv().await;
if cmd.is_ok() {
a.machine.recv(cmd.unwrap()).await;
} else {
break
}
}
a.machine.disconnected().await;
});
// save the task
tasks.push(task);
if prev.is_none() {
// first time save the sender
first = Some(s.clone());
} else {
// tell previous sender to send to this sender
send_cmd(prev.unwrap(), TestMessage::AddSender(s.clone()));
}
prev = Some(s);
}
// create notifier and tell the last to send to it
let (s, mut r) = broadcast::channel::<TestMessage>(10);
send_cmd(prev.unwrap(), TestMessage::Notify(s, 1));
// send to the first
send_cmd(first.unwrap(), TestMessage::TestData(0));
// wait for the notification
if let Ok(_msg) = r.recv().await {
println!("got notification");
}
println!("done");
});
Ok(())
}
/// A message-driven handler that a spawned receive loop feeds.
///
/// The `Send + Sync + 'static` supertraits let an `Arc<dyn Machine<T>>` be moved
/// into a spawned task.
#[async_trait]
pub trait Machine<T>: Send + Sync + 'static {
    /// Handles one message taken off the machine's channel.
    async fn recv(&self, message: T);
    /// Invoked by the receive loops in this file after receiving has stopped.
    async fn disconnected(&self);
}
/// Trait implemented by the message types that flow through a `Machine`.
///
/// Every bound in this file uses it as `T: InstructionSet<InstructionSet = T>`, so the
/// associated type names the implementing type itself.
pub trait InstructionSet
where
    Self: Clone,
{
    /// The concrete instruction (message) type.
    type InstructionSet;
}
/// Sends `cmd` on `sender` as a best-effort delivery.
///
/// Any error from `broadcast::Sender::send` is deliberately discarded; the original
/// code also ignored it, but hid that behind an empty `if` body.
fn send_cmd<T: Send + Sync + 'static>(sender: broadcast::Sender<T>, cmd: T) {
    // Best-effort: a failed send (e.g. nobody listening) is not treated as an error here.
    let _ = sender.send(cmd);
}
/// Instructions exchanged between forwarders.
#[derive(Debug, Clone)]
pub enum TestMessage {
    /// Sequenced test payload; the `usize` is the message's sequence number.
    TestData(usize),
    /// Asks the receiving forwarder to push this sender onto its list of targets.
    AddSender(TestMessageSender),
    /// Registers a notifier. The `usize` is the received-message count at which the
    /// notifier is sent a `TestData` carrying that count.
    Notify(TestMessageSender, usize),
}

/// Shorthand for a broadcast sender carrying `TestMessage` instructions.
pub type TestMessageSender = broadcast::Sender<TestMessage>;

impl InstructionSet for TestMessage {
    type InstructionSet = Self;
}
/// A `Machine` that relays each action message to every sender it has been given.
#[derive(Default)]
struct Forwarder {
    /// Message-handling state behind an async mutex, so `&self` methods can mutate it.
    mutable: Mutex<ForwarderMutable>,
    /// Numeric identifier; appears in the log lines this forwarder prints.
    id: usize,
}
/// This is the mutable part of the Forwarder
///
/// Field order is significant: struct fields drop in declaration order, so `senders`
/// is released before `notify_sender`.
pub struct ForwarderMutable {
    /// Collection of senders; each is sent any forwarded action message.
    senders: Vec<TestMessageSender>,
    /// Number of messages accepted by `validate_sequence` since the last notification.
    received_count: usize,
    /// Number of `TestData` messages sent; also used as the outgoing sequence number.
    send_count: usize,
    /// `received_count` value at which a notification is sent.
    notify_count: usize,
    /// Sent a `TestData` message whose payload is the number of messages received.
    notify_sender: Option<TestMessageSender>,
    /// How many copies of each forwarded message go to each sender.
    forwarding_multiplier: usize,
    /// For `TestData`, the next expected sequence number.
    next_seq: usize,
}
/// Default state; delegates to `ForwarderMutable::new`.
impl Default for ForwarderMutable {
    fn default() -> Self {
        Self::new()
    }
}
impl ForwarderMutable {
fn new() -> Self {
Self {
senders: Vec::<TestMessageSender>::new(),
received_count: 0,
send_count: 0,
notify_count: 0,
notify_sender: None,
forwarding_multiplier: 1,
next_seq: 0,
}
}
fn drop_all_senders(&mut self) {
self.senders.clear();
self.notify_sender = None;
}
/// if msg is TestData, validate the sequence or reset if 0
fn validate_sequence(&mut self, msg: TestMessage) -> Result<TestMessage, TestMessage> {
match msg {
TestMessage::TestData(seq) if seq == self.next_seq => self.next_seq += 1,
TestMessage::TestData(seq) if seq == 0 => self.next_seq = 1,
TestMessage::TestData(_) => return Err(msg),
_ => (),
}
// bump received count
self.received_count += 1;
Ok(msg)
}
/// If msg is a configuration msg, handle it otherwise return it as an error
fn handle_config(&mut self, msg: TestMessage, id: usize) -> Result<(), TestMessage> {
match msg {
TestMessage::Notify(sender, on_receive_count) => {
println!("forwarder {}: added notifier", id);
self.notify_sender = Some(sender);
self.notify_count = on_receive_count;
},
TestMessage::AddSender(sender) => {
println!("forwarder {}: added sender", id);
self.senders.push(sender);
},
msg => return Err(msg),
}
Ok(())
}
/// handle the action messages
fn handle_action(&mut self, message: TestMessage, id: usize) {
match message {
TestMessage::TestData(_) => {
println!("forwarder {}: received TestData", id);
for sender in &self.senders {
for _ in 0 .. self.forwarding_multiplier {
send_cmd(sender.clone(), TestMessage::TestData(self.send_count));
self.send_count += 1;
}
}
},
_ => self.senders.iter().for_each(|sender| {
for _ in 0 .. self.forwarding_multiplier {
send_cmd(sender.clone(), message.clone());
}
}),
}
}
/// handle sending out a notification and resetting counters when notificaiton is sent
fn handle_notification(&mut self, id: usize) {
if self.received_count == self.notify_count {
let count = self.get_and_clear_received_count();
if let Some(notifier) = self.notify_sender.as_ref() {
println!("forwarder {}: sending notification", id);
send_cmd(notifier.clone(), TestMessage::TestData(count));
}
}
}
/// get the current received count and clear counters
fn get_and_clear_received_count(&mut self) -> usize {
let received_count = self.received_count;
self.received_count = 0;
self.send_count = 0;
received_count
}
}
impl Forwarder {
pub fn new(id: usize) -> Self { Self { id, ..Default::default() } }
pub const fn get_id(&self) -> usize { self.id }
}
#[async_trait]
impl Machine<TestMessage> for Forwarder {
    /// Logs the disconnect, then releases every sender this forwarder holds.
    async fn disconnected(&self) {
        let id = self.get_id();
        println!("forwarder {}: disconnected", id);
        self.mutable.lock().await.drop_all_senders();
    }

    /// Applies a configuration message, or validates and forwards an action message.
    ///
    /// # Panics
    /// Panics when `validate_sequence` rejects a `TestData` message.
    async fn recv(&self, message: TestMessage) {
        let id = self.get_id();
        let mut state = self.mutable.lock().await;
        let action = match state.handle_config(message, id) {
            Ok(()) => return,
            Err(unhandled) => unhandled,
        };
        match state.validate_sequence(action) {
            Ok(msg) => {
                state.handle_action(msg, id);
                state.handle_notification(id);
            }
            Err(msg) => panic!("sequence error fwd {}, msg {:#?}", id, msg),
        }
    }
}
/// Pairs a machine with the broadcast channel that feeds it.
///
/// `receiver` is what a receive loop reads from; `machine` is the handler it drives.
struct Adapter<T: InstructionSet<InstructionSet = T>> {
    sender: broadcast::Sender<T>,
    pub receiver: broadcast::Receiver<T>,
    pub machine: Arc<dyn Machine<<T as InstructionSet>::InstructionSet>>,
}

impl<T> Clone for Adapter<T>
where
    T: InstructionSet<InstructionSet = T>,
{
    /// Clones the sender and the machine handle. The clone gets a fresh subscription to
    /// the same channel rather than a copy of the existing receiver.
    fn clone(&self) -> Self {
        let sender = self.sender.clone();
        let receiver = sender.subscribe();
        let machine = Arc::clone(&self.machine);
        Adapter { sender, receiver, machine }
    }
}
/// Wraps `raw` in an `Arc` and creates a broadcast channel of `capacity` for it.
///
/// Returns the shared machine, a sender into the channel, and an `Adapter` holding
/// the channel's receiver plus a type-erased handle to the same machine.
fn build_bounded<T, U>(raw: U, capacity: usize) -> (Arc<U>, broadcast::Sender<T>, Adapter<T>)
where
    U: Machine<T> + Send + Sync + 'static,
    T: InstructionSet<InstructionSet = T>,
{
    let shared = Arc::new(raw);
    let (tx, rx) = broadcast::channel::<T>(capacity);
    let handle_tx = tx.clone();
    let erased: Arc<dyn Machine<T>> = shared.clone();
    (
        shared,
        handle_tx,
        Adapter::<T> {
            sender: tx,
            receiver: rx,
            machine: erased,
        },
    )
}
/// A simple error carrying a human-readable message.
#[derive(Debug, Clone, PartialEq)]
pub struct Error {
    message: String,
}

impl fmt::Display for Error {
    /// Writes the message verbatim.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(&self.message)
    }
}

/// Lets `Error` be used where a `std::error::Error` is expected, e.g. boxed as the
/// `Box<dyn std::error::Error>` that `main` returns.
impl std::error::Error for Error {}
可以把这段代码重构为独立函数,但 Adapter 参数所用的泛型类型需要额外的 trait 约束(trait bound),才能满足 Tokio spawn 的要求。内联写法中,编译器面对的是具体类型(TestMessage),可以直接验证这些要求;提取到泛型函数后,这些约束就必须显式写出来。
由 Tokio spawn 的任务(future)必须满足 `Send` 和 `'static`。请注意,这里的 `'static` 约束仅表示被约束的类型可以一直存活到程序结束(即不包含非 `'static` 的借用),与作用于引用的 `&'static` 生命周期不是一回事。
Tokio 的 spawning 教程对此有深入讨论:https://tokio.rs/tokio/tutorial/spawning ——请查看其中关于 `'static` 约束("'static bound")和 `Send` 约束("Send bound")的小节。
这是重构的结果:
fn spawn_task<T>(mut a: Adapter<T>) -> JoinHandle<()>
where
T: InstructionSet<InstructionSet = T> + Send + 'static
{
spawn(async move {
loop {
let cmd = a.receiver.recv().await;
if cmd.is_ok() {
a.machine.recv(cmd.unwrap()).await;
} else {
break
}
}
a.machine.disconnected().await;
})
}
抱歉,示例有点长,相关部分在最上面。我把完整代码都贴了出来,方便感兴趣的人直接动手尝试。这里是 Rust Playground 链接:
概括来说:我用广播通道(broadcast channel)搭了一条菊花链,向第一个节点发送消息,并从最后一个节点接收通知。每创建一个转发器(forwarder),就会 spawn 一个任务,在循环中等待 recv 完成。这一切都能正常工作。但是,当我想把创建任务的代码重构到一个 fn 中时,似乎碰上了某些我一直没能解决的异步规则。
我使用了以下 crate:
[dependencies]
tokio = { version = "0.3", features = ["full"] }
async-trait = "0.1.42"
以下是完整代码,方便直接在这里查看:
use std::sync::Arc;
use std::fmt;
use async_trait::async_trait;
use tokio::sync::broadcast;
use tokio::spawn;
use tokio::sync::Mutex;
use tokio::runtime;
fn main() -> Result<(), Box<dyn std::error::Error>> {
let rt = runtime::Runtime::new()?;
let mut tasks = Vec::new();
let mut first = None;
let mut prev = None;
rt.block_on(async {
// setup a daisy-chain of 5 machines
for id in 1..=5 {
// create sender and adapter wrapping forwarder
let (_, s, mut a) = build_bounded(Forwarder::new(id), 100);
// create a task to run the recv loop -- consider using stream
我想将这个 spawn 拉出到 returns 任务的 fn 中。
// why can't this be refactored into a fn that returns task?
let task = spawn(async move {
loop {
let cmd = a.receiver.recv().await;
if cmd.is_ok() {
a.machine.recv(cmd.unwrap()).await;
} else {
break
}
}
a.machine.disconnected().await;
});
// save the task
tasks.push(task);
if prev.is_none() {
// first time save the sender
first = Some(s.clone());
} else {
// tell previous sender to send to this sender
send_cmd(prev.unwrap(), TestMessage::AddSender(s.clone()));
}
prev = Some(s);
}
// create notifier and tell the last to send to it
let (s, mut r) = broadcast::channel::<TestMessage>(10);
send_cmd(prev.unwrap(), TestMessage::Notify(s, 1));
// send to the first
send_cmd(first.unwrap(), TestMessage::TestData(0));
// wait for the notification
if let Ok(_msg) = r.recv().await {
println!("got notification");
}
println!("done");
});
Ok(())
}
/// A message-driven handler that a spawned receive loop feeds.
///
/// The `Send + Sync + 'static` supertraits let an `Arc<dyn Machine<T>>` be moved
/// into a spawned task.
#[async_trait]
pub trait Machine<T>: Send + Sync + 'static {
    /// Handles one message taken off the machine's channel.
    async fn recv(&self, message: T);
    /// Invoked by the receive loops in this file after receiving has stopped.
    async fn disconnected(&self);
}
/// Trait implemented by the message types that flow through a `Machine`.
///
/// Every bound in this file uses it as `T: InstructionSet<InstructionSet = T>`, so the
/// associated type names the implementing type itself.
pub trait InstructionSet
where
    Self: Clone,
{
    /// The concrete instruction (message) type.
    type InstructionSet;
}
/// Sends `cmd` on `sender` as a best-effort delivery.
///
/// Any error from `broadcast::Sender::send` is deliberately discarded; the original
/// code also ignored it, but hid that behind an empty `if` body.
fn send_cmd<T: Send + Sync + 'static>(sender: broadcast::Sender<T>, cmd: T) {
    // Best-effort: a failed send (e.g. nobody listening) is not treated as an error here.
    let _ = sender.send(cmd);
}
/// Instructions exchanged between forwarders.
#[derive(Debug, Clone)]
pub enum TestMessage {
    /// Sequenced test payload; the `usize` is the message's sequence number.
    TestData(usize),
    /// Asks the receiving forwarder to push this sender onto its list of targets.
    AddSender(TestMessageSender),
    /// Registers a notifier. The `usize` is the received-message count at which the
    /// notifier is sent a `TestData` carrying that count.
    Notify(TestMessageSender, usize),
}

/// Shorthand for a broadcast sender carrying `TestMessage` instructions.
pub type TestMessageSender = broadcast::Sender<TestMessage>;

impl InstructionSet for TestMessage {
    type InstructionSet = Self;
}
/// A `Machine` that relays each action message to every sender it has been given.
#[derive(Default)]
struct Forwarder {
    /// Message-handling state behind an async mutex, so `&self` methods can mutate it.
    mutable: Mutex<ForwarderMutable>,
    /// Numeric identifier; appears in the log lines this forwarder prints.
    id: usize,
}
/// This is the mutable part of the Forwarder
///
/// Field order is significant: struct fields drop in declaration order, so `senders`
/// is released before `notify_sender`.
pub struct ForwarderMutable {
    /// Collection of senders; each is sent any forwarded action message.
    senders: Vec<TestMessageSender>,
    /// Number of messages accepted by `validate_sequence` since the last notification.
    received_count: usize,
    /// Number of `TestData` messages sent; also used as the outgoing sequence number.
    send_count: usize,
    /// `received_count` value at which a notification is sent.
    notify_count: usize,
    /// Sent a `TestData` message whose payload is the number of messages received.
    notify_sender: Option<TestMessageSender>,
    /// How many copies of each forwarded message go to each sender.
    forwarding_multiplier: usize,
    /// For `TestData`, the next expected sequence number.
    next_seq: usize,
}
/// Default state; delegates to `ForwarderMutable::new`.
impl Default for ForwarderMutable {
    fn default() -> Self {
        Self::new()
    }
}
impl ForwarderMutable {
fn new() -> Self {
Self {
senders: Vec::<TestMessageSender>::new(),
received_count: 0,
send_count: 0,
notify_count: 0,
notify_sender: None,
forwarding_multiplier: 1,
next_seq: 0,
}
}
fn drop_all_senders(&mut self) {
self.senders.clear();
self.notify_sender = None;
}
/// if msg is TestData, validate the sequence or reset if 0
fn validate_sequence(&mut self, msg: TestMessage) -> Result<TestMessage, TestMessage> {
match msg {
TestMessage::TestData(seq) if seq == self.next_seq => self.next_seq += 1,
TestMessage::TestData(seq) if seq == 0 => self.next_seq = 1,
TestMessage::TestData(_) => return Err(msg),
_ => (),
}
// bump received count
self.received_count += 1;
Ok(msg)
}
/// If msg is a configuration msg, handle it otherwise return it as an error
fn handle_config(&mut self, msg: TestMessage, id: usize) -> Result<(), TestMessage> {
match msg {
TestMessage::Notify(sender, on_receive_count) => {
println!("forwarder {}: added notifier", id);
self.notify_sender = Some(sender);
self.notify_count = on_receive_count;
},
TestMessage::AddSender(sender) => {
println!("forwarder {}: added sender", id);
self.senders.push(sender);
},
msg => return Err(msg),
}
Ok(())
}
/// handle the action messages
fn handle_action(&mut self, message: TestMessage, id: usize) {
match message {
TestMessage::TestData(_) => {
println!("forwarder {}: received TestData", id);
for sender in &self.senders {
for _ in 0 .. self.forwarding_multiplier {
send_cmd(sender.clone(), TestMessage::TestData(self.send_count));
self.send_count += 1;
}
}
},
_ => self.senders.iter().for_each(|sender| {
for _ in 0 .. self.forwarding_multiplier {
send_cmd(sender.clone(), message.clone());
}
}),
}
}
/// handle sending out a notification and resetting counters when notificaiton is sent
fn handle_notification(&mut self, id: usize) {
if self.received_count == self.notify_count {
let count = self.get_and_clear_received_count();
if let Some(notifier) = self.notify_sender.as_ref() {
println!("forwarder {}: sending notification", id);
send_cmd(notifier.clone(), TestMessage::TestData(count));
}
}
}
/// get the current received count and clear counters
fn get_and_clear_received_count(&mut self) -> usize {
let received_count = self.received_count;
self.received_count = 0;
self.send_count = 0;
received_count
}
}
impl Forwarder {
pub fn new(id: usize) -> Self { Self { id, ..Default::default() } }
pub const fn get_id(&self) -> usize { self.id }
}
#[async_trait]
impl Machine<TestMessage> for Forwarder {
    /// Logs the disconnect, then releases every sender this forwarder holds.
    async fn disconnected(&self) {
        let id = self.get_id();
        println!("forwarder {}: disconnected", id);
        self.mutable.lock().await.drop_all_senders();
    }

    /// Applies a configuration message, or validates and forwards an action message.
    ///
    /// # Panics
    /// Panics when `validate_sequence` rejects a `TestData` message.
    async fn recv(&self, message: TestMessage) {
        let id = self.get_id();
        let mut state = self.mutable.lock().await;
        let action = match state.handle_config(message, id) {
            Ok(()) => return,
            Err(unhandled) => unhandled,
        };
        match state.validate_sequence(action) {
            Ok(msg) => {
                state.handle_action(msg, id);
                state.handle_notification(id);
            }
            Err(msg) => panic!("sequence error fwd {}, msg {:#?}", id, msg),
        }
    }
}
/// Pairs a machine with the broadcast channel that feeds it.
///
/// `receiver` is what a receive loop reads from; `machine` is the handler it drives.
struct Adapter<T: InstructionSet<InstructionSet = T>> {
    sender: broadcast::Sender<T>,
    pub receiver: broadcast::Receiver<T>,
    pub machine: Arc<dyn Machine<<T as InstructionSet>::InstructionSet>>,
}

impl<T> Clone for Adapter<T>
where
    T: InstructionSet<InstructionSet = T>,
{
    /// Clones the sender and the machine handle. The clone gets a fresh subscription to
    /// the same channel rather than a copy of the existing receiver.
    fn clone(&self) -> Self {
        let sender = self.sender.clone();
        let receiver = sender.subscribe();
        let machine = Arc::clone(&self.machine);
        Adapter { sender, receiver, machine }
    }
}
/// Wraps `raw` in an `Arc` and creates a broadcast channel of `capacity` for it.
///
/// Returns the shared machine, a sender into the channel, and an `Adapter` holding
/// the channel's receiver plus a type-erased handle to the same machine.
fn build_bounded<T, U>(raw: U, capacity: usize) -> (Arc<U>, broadcast::Sender<T>, Adapter<T>)
where
    U: Machine<T> + Send + Sync + 'static,
    T: InstructionSet<InstructionSet = T>,
{
    let shared = Arc::new(raw);
    let (tx, rx) = broadcast::channel::<T>(capacity);
    let handle_tx = tx.clone();
    let erased: Arc<dyn Machine<T>> = shared.clone();
    (
        shared,
        handle_tx,
        Adapter::<T> {
            sender: tx,
            receiver: rx,
            machine: erased,
        },
    )
}
/// A simple error carrying a human-readable message.
#[derive(Debug, Clone, PartialEq)]
pub struct Error {
    message: String,
}

impl fmt::Display for Error {
    /// Writes the message verbatim.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(&self.message)
    }
}

/// Lets `Error` be used where a `std::error::Error` is expected, e.g. boxed as the
/// `Box<dyn std::error::Error>` that `main` returns.
impl std::error::Error for Error {}
可以把这段代码重构为独立函数,但 Adapter 参数所用的泛型类型需要额外的 trait 约束(trait bound),才能满足 Tokio spawn 的要求。内联写法中,编译器面对的是具体类型(TestMessage),可以直接验证这些要求;提取到泛型函数后,这些约束就必须显式写出来。
由 Tokio spawn 的任务(future)必须满足 `Send` 和 `'static`。请注意,这里的 `'static` 约束仅表示被约束的类型可以一直存活到程序结束(即不包含非 `'static` 的借用),与作用于引用的 `&'static` 生命周期不是一回事。
Tokio 的 spawning 教程对此有深入讨论:https://tokio.rs/tokio/tutorial/spawning ——请查看其中关于 `'static` 约束("'static bound")和 `Send` 约束("Send bound")的小节。
这是重构的结果:
fn spawn_task<T>(mut a: Adapter<T>) -> JoinHandle<()>
where
T: InstructionSet<InstructionSet = T> + Send + 'static
{
spawn(async move {
loop {
let cmd = a.receiver.recv().await;
if cmd.is_ok() {
a.machine.recv(cmd.unwrap()).await;
} else {
break
}
}
a.machine.disconnected().await;
})
}