1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
use std::{
sync::{Arc, Mutex},
thread,
time::Duration,
};
use serde::Deserialize;
use crate::{
dataflow::{graph::default_graph, Data, Message},
node::NodeId,
scheduler::channel_manager::ChannelManager,
};
use super::{
errors::{ReadError, TryReadError},
InternalReadStream, ReadStream, StreamId,
};
/// Reads the messages of an existing stream, identified by the id of the `ReadStream`
/// handed to `ExtractStream::new`.
///
/// The underlying channel is attached lazily: a `ChannelManager` is delivered later through a
/// setup hook, and the receive endpoint is only taken from it on the first read attempt made
/// after that.
pub struct ExtractStream<D>
where
    D: Data + for<'de> Deserialize<'de>,
{
    /// Id of the extracted stream (copied from the `ReadStream` given to `new`).
    id: StreamId,
    /// Node id supplied when this extract stream was created.
    node_id: NodeId,
    /// Attached read stream; stays `None` until the receive endpoint has been taken from the
    /// channel manager.
    read_stream_option: Option<ReadStream<D>>,
    /// Shared slot that the setup hook registered in `new` fills with the `ChannelManager`.
    channel_manager_option: Arc<Mutex<Option<Arc<Mutex<ChannelManager>>>>>,
}
impl<D> ExtractStream<D>
where
    for<'a> D: Data + Deserialize<'a>,
{
    /// Creates an `ExtractStream` for the stream read by `read_stream`.
    ///
    /// Only the id of `read_stream` is kept. The new value is registered with the default graph
    /// through `default_graph::add_extract_stream`, together with a setup hook that stores the
    /// `ChannelManager` it is given. The actual `ReadStream` is attached lazily by the first
    /// `try_read`/`read` call made after that hook has run.
    pub fn new(node_id: NodeId, read_stream: &ReadStream<D>) -> Self {
        let id = read_stream.get_id();
        let extract_stream = Self {
            id,
            node_id,
            read_stream_option: None,
            channel_manager_option: Arc::new(Mutex::new(None)),
        };
        // The hook owns its own handle to the shared slot so it can fill it in whenever the
        // graph invokes it.
        let channel_manager_option_copy = Arc::clone(&extract_stream.channel_manager_option);
        let setup_hook = move |channel_manager: Arc<Mutex<ChannelManager>>| {
            channel_manager_option_copy
                .lock()
                .unwrap()
                .replace(channel_manager);
        };
        default_graph::add_extract_stream(&extract_stream, setup_hook);
        extract_stream
    }

    /// Returns the id of the extracted stream (the id of the `ReadStream` passed to `new`).
    pub fn get_id(&self) -> StreamId {
        self.id
    }

    /// Returns the `NodeId` this extract stream was created with.
    pub fn get_node_id(&self) -> NodeId {
        self.node_id
    }

    /// Returns `true` if the attached `ReadStream` is closed.
    ///
    /// Also returns `true` while no `ReadStream` has been attached yet, i.e. before a
    /// `try_read`/`read` call has managed to obtain the receive endpoint.
    pub fn is_closed(&self) -> bool {
        self.read_stream_option
            .as_ref()
            .map(ReadStream::is_closed)
            .unwrap_or(true)
    }

    /// Attempts to read a message by delegating to `ReadStream::try_read`.
    ///
    /// If no `ReadStream` is attached yet, this first tries to attach one using the
    /// `ChannelManager` delivered by the setup hook. Once attached, the stream is reused by
    /// every later call.
    ///
    /// # Errors
    ///
    /// Returns `TryReadError::Disconnected` in two cases: no `ChannelManager` has been delivered
    /// yet, or the receive endpoint could not be taken from it (that failure is logged to
    /// stderr). Otherwise returns whatever `ReadStream::try_read` returns.
    pub fn try_read(&mut self) -> Result<Message<D>, TryReadError> {
        match self.ensure_read_stream() {
            Some(read_stream) => read_stream.try_read(),
            None => Err(TryReadError::Disconnected),
        }
    }

    /// Reads a message, waiting until the underlying `ReadStream` is available.
    ///
    /// Until a `ReadStream` can be attached (see `try_read`), this sleeps for 100 ms between
    /// attempts. Once attached, it first calls `ReadStream::try_read`. If that reports `Empty`,
    /// it falls back to `ReadStream::read` (presumably blocking — confirm in `ReadStream`).
    ///
    /// NOTE(review): if taking the receive endpoint keeps failing, this loops indefinitely and
    /// logs every attempt to stderr.
    ///
    /// # Errors
    ///
    /// `TryReadError::{Disconnected, SerializationError, Closed}` from the attached stream map
    /// to the `ReadError` variant of the same name. Errors from `ReadStream::read` are passed
    /// through unchanged.
    pub fn read(&mut self) -> Result<Message<D>, ReadError> {
        const SETUP_POLL_INTERVAL: Duration = Duration::from_millis(100);
        loop {
            if let Some(read_stream) = self.ensure_read_stream() {
                return match read_stream.try_read() {
                    Ok(msg) => Ok(msg),
                    Err(TryReadError::Disconnected) => Err(ReadError::Disconnected),
                    Err(TryReadError::Empty) => read_stream.read(),
                    Err(TryReadError::SerializationError) => Err(ReadError::SerializationError),
                    Err(TryReadError::Closed) => Err(ReadError::Closed),
                };
            }
            thread::sleep(SETUP_POLL_INTERVAL);
        }
    }

    /// Returns the attached `ReadStream`, attaching it first if necessary.
    ///
    /// Returns `None` in two cases: the setup hook has not delivered a `ChannelManager` yet,
    /// or the channel manager failed to hand over the receive endpoint for this stream (the
    /// error is logged to stderr).
    fn ensure_read_stream(&mut self) -> Option<&ReadStream<D>> {
        if self.read_stream_option.is_none() {
            // Copy the `Arc` out so the outer slot's lock is released before the channel
            // manager itself is locked.
            let channel_manager = self
                .channel_manager_option
                .lock()
                .unwrap()
                .as_ref()
                .map(Arc::clone)?;
            match channel_manager.lock().unwrap().take_recv_endpoint(self.id) {
                Ok(recv_endpoint) => {
                    self.read_stream_option = Some(ReadStream::from(
                        InternalReadStream::from_endpoint(recv_endpoint, self.id),
                    ));
                }
                Err(msg) => {
                    eprintln!(
                        "ExtractStream {}: error getting endpoint from channel manager \"{}\"",
                        self.id, msg
                    );
                    return None;
                }
            }
        }
        self.read_stream_option.as_ref()
    }
}
// SAFETY: NOTE(review): no justification for this manual `Send` impl is visible here. It is
// presumably needed because `ReadStream<D>` (or something it wraps) is not automatically
// `Send`. Confirm that moving an `ExtractStream`, including its attached `ReadStream` and the
// shared `ChannelManager` slot, to another thread is sound before relying on this impl.
unsafe impl<D> Send for ExtractStream<D> where for<'a> D: Data + Deserialize<'a> {}