Skip to main content

hydro_lang/
handoff_ref.rs

1//! Reference handles for capturing singletons, optionals, and streams in `q!()` closures.
2//!
3//! Each handle type wraps a `&RefCell<HydroNode>` and, when captured inside a `q!()` closure,
4//! registers itself with the current capture scope. At codegen time, the IR node is lowered
5//! to the corresponding DFIR pseudo-operator (`singleton()`, `optional()`, or `handoff()`),
6//! and the reference resolves to the appropriate borrow type.
7
8use std::cell::RefCell;
9use std::marker::PhantomData;
10use std::rc::Rc;
11
12use proc_macro2::Span;
13use quote::quote;
14use stageleft::runtime_support::{FreeVariableWithContextWithProps, QuoteTokens};
15
16use crate::compile::ir::{AccessCounter, HydroNode, SharedNode};
17use crate::location::Location;
18
19/// Determines which DFIR pseudo-operator a reference node lowers to.
20#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
21pub enum HandoffRefKind {
22    /// `-> singleton()` — exactly one item, `#var` gives `&T`.
23    Singleton,
24    /// `-> optional()` — zero or one item, `#var` gives `&Option<T>`.
25    Optional,
26    /// `-> handoff()` — zero or more items, `#var` gives `&Vec<T>`.
27    Vec,
28}
29
30// Thread-local storage for handoff references captured during `q!()` expansion.
31// Stores the `HydroNode::Reference` and `is_mut: bool` for each reference captured in the current closure.
32// The index determines the ident name via `handoff_ref_ident`.
33thread_local! {
34    static CAPTURED_REFS: RefCell<Option<Vec<(HydroNode, bool)>>> = const { RefCell::new(None) };
35}
36
37/// Returns the canonical ident for a captured ref at the given index within a closure.
38pub(crate) fn handoff_ref_ident(index: usize) -> syn::Ident {
39    syn::Ident::new(
40        &format!("__hydro_singleton_ref_{}", index),
41        Span::call_site(),
42    )
43}
44
45/// Activate the reference capture context. Must be called before `q!()` expansion
46/// that may capture handoff references. Returns a `ClosureExpr` bundling the expression with any
47/// captured references.
48pub fn with_ref_capture(
49    f: impl FnOnce() -> crate::compile::ir::DebugExpr,
50) -> crate::compile::ir::ClosureExpr {
51    CAPTURED_REFS.with(|cell| {
52        let prev = cell.borrow_mut().replace(Vec::new());
53        assert!(
54            prev.is_none(),
55            "nested handoff reference capture scopes are not supported"
56        );
57    });
58    let expr = (f)();
59    let captured_refs = CAPTURED_REFS.with(|cell| cell.borrow_mut().take().unwrap());
60    crate::compile::ir::ClosureExpr::new(expr, captured_refs)
61}
62
63/// Shared registration logic: wraps the IR node in `HydroNode::Reference` if needed,
64/// pushes it to the capture list, and returns the ident to use in the closure body.
65fn register_handoff_ref(
66    ir_node: &RefCell<HydroNode>,
67    is_mut: bool,
68    kind: HandoffRefKind,
69) -> syn::Ident {
70    CAPTURED_REFS.with(|cell| {
71        let mut guard = cell.borrow_mut();
72        let refs = guard.as_mut().expect(
73            "HandoffRef used inside q!() but no reference capture scope is active. \
74             This is a bug — reference capture should be set up by the operator that uses q!().",
75        );
76
77        let index = refs.len();
78        let ident = handoff_ref_ident(index);
79
80        let metadata = ir_node.borrow().metadata().clone();
81
82        // Wrap in HydroNode::Reference for materialization + identity tracking.
83        // If already a Reference node, reuse it.
84        if !matches!(&*ir_node.borrow(), HydroNode::Reference { .. }) {
85            let orig = ir_node.replace(HydroNode::Placeholder);
86            *ir_node.borrow_mut() = HydroNode::Reference {
87                inner: SharedNode(Rc::new(RefCell::new(orig))),
88                kind,
89                access_counter: AccessCounter::new(),
90                metadata: metadata.clone(),
91            };
92        }
93
94        let borrow: std::cell::Ref<'_, HydroNode> = ir_node.borrow();
95        let HydroNode::Reference {
96            inner,
97            access_counter,
98            ..
99        } = &*borrow
100        else {
101            unreachable!()
102        };
103
104        // Compute access group at staging time (code order).
105        let group = access_counter.next_group(is_mut);
106
107        refs.push((
108            HydroNode::Reference {
109                inner: SharedNode(Rc::clone(&inner.0)),
110                kind,
111                access_counter: group,
112                metadata,
113            },
114            is_mut,
115        ));
116
117        ident
118    })
119}
120
/// Generates one reference-handle type per entry.
///
/// Each entry is `Name, is_mut, kind, output_type`, optionally preceded by attributes
/// such as doc comments. For every entry this emits:
/// - a handle struct borrowing the collection's `RefCell<HydroNode>`,
/// - a crate-internal `new` constructor,
/// - `Clone` + `Copy` impls, and
/// - a `FreeVariableWithContextWithProps` impl with `O = output_type`. Its `to_tokens`
///   registers the handle with the active capture scope and splices in the returned ident.
macro_rules! define_handoff_ref {
    (
        $(
            $(#[$attr:meta])*
            $handle:ident, $mutable:expr, $ref_kind:expr, $resolved:ty
        )+
    ) => {
        $(
            $(#[$attr])*
            pub struct $handle<'a, 'slf, T, L> {
                pub(crate) ir_node: &'slf RefCell<HydroNode>,
                _phantom: PhantomData<(&'a T, L)>,
            }

            impl<'a, 'slf, T, L> $handle<'a, 'slf, T, L> {
                /// Creates a new reference handle from an IR node cell.
                pub(crate) fn new(ir_node: &'slf RefCell<HydroNode>) -> Self {
                    Self { ir_node, _phantom: PhantomData }
                }
            }

            // Hand-written rather than derived so the handle is `Clone`/`Copy` without
            // requiring `T: Copy` or `L: Copy`; it only holds a shared borrow.
            impl<T, L> Clone for $handle<'_, '_, T, L> {
                fn clone(&self) -> Self {
                    *self
                }
            }
            impl<T, L> Copy for $handle<'_, '_, T, L> {}

            impl<'a, 'slf, T: 'a, L> FreeVariableWithContextWithProps<L, ()> for $handle<'a, 'slf, T, L>
            where
                L: Location<'a>,
            {
                type O = $resolved;

                fn to_tokens(self, _ctx: &L) -> (QuoteTokens, ()) {
                    let ident = register_handoff_ref(self.ir_node, $mutable, $ref_kind);
                    let tokens = QuoteTokens {
                        prelude: None,
                        expr: Some(quote!(#ident)),
                    };
                    (tokens, ())
                }
            }
        )+
    };
}
177
// Instantiates the six public handle types. Each entry below is
// `Name, is_mut, HandoffRefKind variant, type the handle has inside a q!() closure`.
// The `stageleft::export` list names the same six types; NOTE(review): presumably this
// makes them reachable from staged code — confirm against stageleft's docs.
#[stageleft::export(
    SingletonRef,
    SingletonMut,
    OptionalRef,
    OptionalMut,
    StreamRef,
    StreamMut
)]
define_handoff_ref!(
    /// A shared reference handle to a singleton, resolves to `&T` at runtime.
    ///
    /// Created via [`Singleton::by_ref()`](crate::live_collections::Singleton::by_ref).
    SingletonRef, false, HandoffRefKind::Singleton, &'a T

    /// A mutable reference handle to a singleton, resolves to `&mut T` at runtime.
    ///
    /// Created via [`Singleton::by_mut()`](crate::live_collections::Singleton::by_mut).
    SingletonMut, true, HandoffRefKind::Singleton, &'a mut T

    /// A shared reference handle to an optional, resolves to `&Option<T>` at runtime.
    ///
    /// Created via [`Optional::by_ref()`](crate::live_collections::Optional::by_ref).
    OptionalRef, false, HandoffRefKind::Optional, &'a Option<T>

    /// A mutable reference handle to an optional, resolves to `&mut Option<T>` at runtime.
    ///
    /// Created via [`Optional::by_mut()`](crate::live_collections::Optional::by_mut).
    OptionalMut, true, HandoffRefKind::Optional, &'a mut Option<T>

    /// A shared reference handle to a stream's handoff buffer, resolves to `&Vec<T>` at runtime.
    ///
    /// Created via [`Stream::by_ref()`](crate::live_collections::Stream::by_ref).
    StreamRef, false, HandoffRefKind::Vec, &'a Vec<T>

    /// A mutable reference handle to a stream's handoff buffer, resolves to `&mut Vec<T>` at runtime.
    ///
    /// Created via [`Stream::by_mut()`](crate::live_collections::Stream::by_mut).
    StreamMut, true, HandoffRefKind::Vec, &'a mut Vec<T>
);
217
#[cfg(test)]
#[cfg(feature = "build")]
mod tests {
    use std::cell::RefCell;

    use stageleft::q;

    use super::{HandoffRefKind, register_handoff_ref, with_ref_capture};
    use crate::compile::builder::FlowBuilder;
    use crate::compile::ir::HydroNode;
    use crate::location::Location;

    struct P1 {}

    /// A nested `with_ref_capture` call must be rejected rather than silently sharing or
    /// replacing the outer scope.
    #[test]
    #[should_panic(expected = "nested handoff reference capture scopes are not supported")]
    fn nested_ref_capture_scope_panics() {
        let _ = with_ref_capture(|| {
            let _ = with_ref_capture(|| unreachable!());
            unreachable!()
        });
    }

    /// Registering a handle outside any capture scope is a bug and must panic loudly.
    #[test]
    #[should_panic(expected = "no reference capture scope is active")]
    fn register_outside_capture_scope_panics() {
        let node = RefCell::new(HydroNode::Placeholder);
        let _ = register_handoff_ref(&node, false, HandoffRefKind::Singleton);
    }

    /// Compile-only test: verifies that `by_ref()` + `q!()` produces valid IR.
    #[test]
    fn singleton_by_ref_compiles() {
        let mut flow = FlowBuilder::new();
        let node = flow.process::<P1>();

        let my_count = node
            .source_iter(q!(0..5i32))
            .fold(q!(|| 0i32), q!(|acc: &mut i32, x| *acc += x));
        let count_ref = my_count.by_ref();

        node.source_iter(q!(1..=3i32))
            .map(q!(|x| x + *count_ref))
            .for_each(q!(|_| {}));

        my_count.into_stream().for_each(q!(|_| {}));
        let _built = flow.finalize();
    }

    /// Test with a non-Copy type (Vec) to ensure we're borrowing, not copying.
    #[test]
    fn singleton_by_ref_non_copy() {
        let mut flow = FlowBuilder::new();
        let node = flow.process::<P1>();

        let my_vec = node.source_iter(q!(0..5i32)).fold(
            q!(|| Vec::<i32>::new()),
            q!(|acc: &mut Vec<i32>, x| acc.push(x)),
        );
        let vec_ref = my_vec.by_ref();

        node.source_iter(q!(1..=3i32))
            .map(q!(|x| x + vec_ref.len() as i32))
            .for_each(q!(|_| {}));

        my_vec.into_stream().for_each(q!(|_| {}));
        let _built = flow.finalize();
    }

    /// Compile-only: singleton ref inside filter closure.
    #[test]
    fn singleton_by_ref_filter() {
        let mut flow = FlowBuilder::new();
        let node = flow.process::<P1>();

        let threshold = node
            .source_iter(q!(0..5i32))
            .fold(q!(|| 0i32), q!(|acc: &mut i32, x| *acc += x));
        let threshold_ref = threshold.by_ref();

        node.source_iter(q!(1..=10i32))
            .filter(q!(|x| *x > *threshold_ref))
            .for_each(q!(|_| {}));

        threshold.into_stream().for_each(q!(|_| {}));
        let _built = flow.finalize();
    }

    /// Compile-only: singleton ref inside flat_map closure.
    #[test]
    fn singleton_by_ref_flat_map() {
        let mut flow = FlowBuilder::new();
        let node = flow.process::<P1>();

        let count = node
            .source_iter(q!(0..3i32))
            .fold(q!(|| 0i32), q!(|acc: &mut i32, _| *acc += 1));
        let count_ref = count.by_ref();

        node.source_iter(q!(1..=2i32))
            .flat_map_ordered(q!(|x| (0..*count_ref).map(move |i| x + i)))
            .for_each(q!(|_| {}));

        count.into_stream().for_each(q!(|_| {}));
        let _built = flow.finalize();
    }

    /// Compile-only: singleton ref inside inspect closure.
    #[test]
    fn singleton_by_ref_inspect() {
        let mut flow = FlowBuilder::new();
        let node = flow.process::<P1>();

        let count = node
            .source_iter(q!(0..5i32))
            .fold(q!(|| 0i32), q!(|acc: &mut i32, _| *acc += 1));
        let count_ref = count.by_ref();

        node.source_iter(q!(1..=3i32))
            .inspect(q!(|x| println!("count={}, x={}", *count_ref, x)))
            .for_each(q!(|_| {}));

        count.into_stream().for_each(q!(|_| {}));
        let _built = flow.finalize();
    }

    /// Compile-only: singleton ref inside partition predicate.
    #[test]
    fn singleton_by_ref_partition() {
        let mut flow = FlowBuilder::new();
        let node = flow.process::<P1>();

        let threshold = node
            .source_iter(q!(0..5i32))
            .fold(q!(|| 0i32), q!(|acc: &mut i32, x| *acc += x));
        let threshold_ref = threshold.by_ref();

        let (above, below) = node
            .source_iter(q!(1..=10i32))
            .partition(q!(|x| *x > *threshold_ref));

        above.for_each(q!(|_| {}));
        below.for_each(q!(|_| {}));
        threshold.into_stream().for_each(q!(|_| {}));
        let _built = flow.finalize();
    }

    /// Compile-only: singleton ref inside partition with downstream operators on both branches.
    #[test]
    fn singleton_by_ref_partition_with_downstream_ops() {
        let mut flow = FlowBuilder::new();
        let node = flow.process::<P1>();

        let threshold = node
            .source_iter(q!(0..5i32))
            .fold(q!(|| 0i32), q!(|acc: &mut i32, x| *acc += x));
        let threshold_ref = threshold.by_ref();

        let (above, below) = node
            .source_iter(q!(1..=10i32))
            .partition(q!(|x| *x > *threshold_ref));

        above.map(q!(|x| x * 2)).for_each(q!(|_| {}));
        below.map(q!(|x| x + 100)).for_each(q!(|_| {}));
        threshold.into_stream().for_each(q!(|_| {}));
        let _built = flow.finalize();
    }

    /// Compile-only test: singleton by_mut.
    #[test]
    fn singleton_by_mut_compiles() {
        let mut flow = FlowBuilder::new();
        let node = flow.process::<P1>();

        let my_count = node
            .source_iter(q!(0..5i32))
            .fold(q!(|| 0i32), q!(|acc: &mut i32, x| *acc += x));
        let count_mut = my_count.by_mut();

        node.source_iter(q!(1..=3i32))
            .map(q!(|x| {
                *count_mut += x;
                x
            }))
            .for_each(q!(|_| {}));

        my_count.into_stream().for_each(q!(|_| {}));
        let _built = flow.finalize();
    }

    /// Compile-only test: optional by_ref.
    #[test]
    fn optional_by_ref_compiles() {
        let mut flow = FlowBuilder::new();
        let node = flow.process::<P1>();

        let my_opt = node.source_iter(q!(0..5i32)).reduce(q!(|a, b| *a += b));
        let opt_ref = my_opt.by_ref();

        node.source_iter(q!(1..=3i32))
            .map(q!(|x| x + opt_ref.unwrap_or(0)))
            .for_each(q!(|_| {}));

        my_opt.into_stream().for_each(q!(|_| {}));
        let _built = flow.finalize();
    }

    /// Compile-only test: stream by_ref.
    #[test]
    fn stream_by_ref_compiles() {
        let mut flow = FlowBuilder::new();
        let node = flow.process::<P1>();

        let my_stream = node.source_iter(q!(0..5i32));
        let stream_ref = my_stream.by_ref();

        node.source_iter(q!(1..=3i32))
            .map(q!(|x| x + stream_ref.len() as i32))
            .for_each(q!(|_| {}));

        my_stream.for_each(q!(|_| {}));
        let _built = flow.finalize();
    }

    /// Compile-only test: singleton by_mut in filter (TotalOrder).
    #[test]
    fn singleton_by_mut_filter() {
        let mut flow = FlowBuilder::new();
        let node = flow.process::<P1>();

        let my_count = node
            .source_iter(q!(0..5i32))
            .fold(q!(|| 0i32), q!(|acc: &mut i32, x| *acc += x));
        let count_mut = my_count.by_mut();

        node.source_iter(q!(1..=3i32))
            .filter(q!(|x| {
                *count_mut += *x;
                *count_mut > 0
            }))
            .for_each(q!(|_| {}));

        my_count.into_stream().for_each(q!(|_| {}));
        let _built = flow.finalize();
    }

    /// Compile-only test: singleton by_mut in flat_map_ordered (TotalOrder).
    #[test]
    fn singleton_by_mut_flat_map() {
        let mut flow = FlowBuilder::new();
        let node = flow.process::<P1>();

        let my_count = node
            .source_iter(q!(0..5i32))
            .fold(q!(|| 0i32), q!(|acc: &mut i32, x| *acc += x));
        let count_mut = my_count.by_mut();

        node.source_iter(q!(1..=3i32))
            .flat_map_ordered(q!(|x| {
                *count_mut += x;
                vec![*count_mut]
            }))
            .for_each(q!(|_| {}));

        my_count.into_stream().for_each(q!(|_| {}));
        let _built = flow.finalize();
    }

    /// Compile-only test: singleton by_mut in filter_map (TotalOrder).
    #[test]
    fn singleton_by_mut_filter_map() {
        let mut flow = FlowBuilder::new();
        let node = flow.process::<P1>();

        let my_count = node
            .source_iter(q!(0..5i32))
            .fold(q!(|| 0i32), q!(|acc: &mut i32, x| *acc += x));
        let count_mut = my_count.by_mut();

        node.source_iter(q!(1..=3i32))
            .filter_map(q!(|x| {
                *count_mut += x;
                Some(*count_mut)
            }))
            .for_each(q!(|_| {}));

        my_count.into_stream().for_each(q!(|_| {}));
        let _built = flow.finalize();
    }

    /// Compile-only test: singleton by_mut in inspect (TotalOrder).
    #[test]
    fn singleton_by_mut_inspect() {
        let mut flow = FlowBuilder::new();
        let node = flow.process::<P1>();

        let my_count = node
            .source_iter(q!(0..5i32))
            .fold(q!(|| 0i32), q!(|acc: &mut i32, x| *acc += x));
        let count_mut = my_count.by_mut();

        node.source_iter(q!(1..=3i32))
            .inspect(q!(|x| {
                *count_mut += *x;
            }))
            .for_each(q!(|_| {}));

        my_count.into_stream().for_each(q!(|_| {}));
        let _built = flow.finalize();
    }

    /// Compile-only test: singleton by_ref in for_each.
    #[test]
    fn singleton_by_ref_for_each() {
        let mut flow = FlowBuilder::new();
        let node = flow.process::<P1>();

        let my_count = node
            .source_iter(q!(0..5i32))
            .fold(q!(|| 0i32), q!(|acc: &mut i32, x| *acc += x));
        let count_ref = my_count.by_ref();

        node.source_iter(q!(1..=3i32))
            .for_each(q!(|x| println!("{}", x + *count_ref)));

        my_count.into_stream().for_each(q!(|_| {}));
        let _built = flow.finalize();
    }

    /// Compile-only test: singleton by_mut in for_each.
    #[test]
    fn singleton_by_mut_for_each() {
        let mut flow = FlowBuilder::new();
        let node = flow.process::<P1>();

        let my_count = node
            .source_iter(q!(0..5i32))
            .fold(q!(|| 0i32), q!(|acc: &mut i32, x| *acc += x));
        let count_mut = my_count.by_mut();

        node.source_iter(q!(1..=3i32)).for_each(q!(|x| {
            *count_mut += x;
        }));

        my_count.into_stream().for_each(q!(|_| {}));
        let _built = flow.finalize();
    }
}