diff options
-rw-r--r-- | macros/src/libs/intertrait_macros/gen_caster.rs | 40 | ||||
-rw-r--r-- | src/di_container/asynchronous/mod.rs | 27 | ||||
-rw-r--r-- | src/libs/intertrait/cast/arc.rs | 13 | ||||
-rw-r--r-- | src/libs/intertrait/cast/box.rs | 13 | ||||
-rw-r--r-- | src/libs/intertrait/cast/error.rs | 7 | ||||
-rw-r--r-- | src/libs/intertrait/cast/rc.rs | 13 | ||||
-rw-r--r-- | src/libs/intertrait/mod.rs | 82 |
7 files changed, 148 insertions, 47 deletions
diff --git a/macros/src/libs/intertrait_macros/gen_caster.rs b/macros/src/libs/intertrait_macros/gen_caster.rs index df743e2..a703a62 100644 --- a/macros/src/libs/intertrait_macros/gen_caster.rs +++ b/macros/src/libs/intertrait_macros/gen_caster.rs @@ -33,16 +33,46 @@ pub fn generate_caster( let new_caster = if sync { quote! { syrette::libs::intertrait::Caster::<dyn #dst_trait>::new_sync( - |from| from.downcast::<#ty>().unwrap(), - |from| from.downcast::<#ty>().unwrap(), - |from| from.downcast::<#ty>().unwrap() + |from| { + let concrete = from + .downcast::<#ty>() + .map_err(|_| syrette::libs::intertrait::CasterError::CastBoxFailed)?; + + Ok(concrete as Box<dyn #dst_trait>) + }, + |from| { + let concrete = from + .downcast::<#ty>() + .map_err(|_| syrette::libs::intertrait::CasterError::CastRcFailed)?; + + Ok(concrete as std::rc::Rc<dyn #dst_trait>) + }, + |from| { + let concrete = from + .downcast::<#ty>() + .map_err(|_| syrette::libs::intertrait::CasterError::CastArcFailed)?; + + Ok(concrete as std::sync::Arc<dyn #dst_trait>) + }, ) } } else { quote! { syrette::libs::intertrait::Caster::<dyn #dst_trait>::new( - |from| from.downcast::<#ty>().unwrap(), - |from| from.downcast::<#ty>().unwrap(), + |from| { + let concrete = from + .downcast::<#ty>() + .map_err(|_| syrette::libs::intertrait::CasterError::CastBoxFailed)?; + + Ok(concrete as Box<dyn #dst_trait>) + }, + |from| { + let concrete = from + .downcast::<#ty>() + .map_err(|_| syrette::libs::intertrait::CasterError::CastRcFailed)?; + + Ok(concrete as std::rc::Rc<dyn #dst_trait>) + }, ) } }; diff --git a/src/di_container/asynchronous/mod.rs b/src/di_container/asynchronous/mod.rs index 7dda1d7..89b2fba 100644 --- a/src/di_container/asynchronous/mod.rs +++ b/src/di_container/asynchronous/mod.rs @@ -270,7 +270,12 @@ impl AsyncDIContainer >( )) } - CastError::CastFailed { from: _, to: _ } => { + CastError::CastFailed { + source: _, + from: _, + to: _, + } + | CastError::GetCasterFailed(_) => { AsyncDIContainerError::CastFailed { interface: type_name::<Interface>(), binding_kind: "singleton", @@ -291,7 +296,12 @@ impl AsyncDIContainer type_name::<Interface>(), ) } - CastError::CastFailed { from: _, to: _ } => { + CastError::CastFailed { + source: _, + from: _, + to: _, + } + | CastError::GetCasterFailed(_) => { AsyncDIContainerError::CastFailed { interface: type_name::<Interface>(), binding_kind: "factory", @@ -348,12 +358,15 @@ impl AsyncDIContainer CastError::NotArcCastable(_) => { AsyncDIContainerError::InterfaceNotAsync(type_name::<Type>()) } - CastError::CastFailed { from: _, to: _ } => { - AsyncDIContainerError::CastFailed { - interface: type_name::<Type>(), - binding_kind, - } + CastError::CastFailed { + source: _, + from: _, + to: _, } + | CastError::GetCasterFailed(_) => AsyncDIContainerError::CastFailed { + interface: type_name::<Type>(), + binding_kind, + }, }) } diff --git a/src/libs/intertrait/cast/arc.rs b/src/libs/intertrait/cast/arc.rs index 1742c32..135cf64 100644 --- a/src/libs/intertrait/cast/arc.rs +++ b/src/libs/intertrait/cast/arc.rs @@ -31,16 +31,17 @@ impl<CastFromSelf: ?Sized + CastFromSync> CastArc for CastFromSelf self: Arc<Self>, ) -> Result<Arc<OtherTrait>, CastError> { - let caster = - get_caster::<OtherTrait>((*self).type_id()).ok_or(CastError::CastFailed { - from: type_name::<CastFromSelf>(), - to: type_name::<OtherTrait>(), - })?; + let caster = get_caster::<OtherTrait>((*self).type_id()) + .map_err(CastError::GetCasterFailed)?; let cast_arc = caster .opt_cast_arc .ok_or(CastError::NotArcCastable(type_name::<OtherTrait>()))?; - Ok(cast_arc(self.arc_any())) + cast_arc(self.arc_any()).map_err(|err| CastError::CastFailed { + source: err, + from: type_name::<Self>(), + to: type_name::<OtherTrait>(), + }) } } diff --git a/src/libs/intertrait/cast/box.rs b/src/libs/intertrait/cast/box.rs index 5694d97..67fd949 100644 --- a/src/libs/intertrait/cast/box.rs +++ b/src/libs/intertrait/cast/box.rs @@ -30,12 +30,13 @@ impl<CastFromSelf: ?Sized + CastFrom> CastBox for CastFromSelf self: Box<Self>, ) -> Result<Box<OtherTrait>, CastError> { - let caster = - get_caster::<OtherTrait>((*self).type_id()).ok_or(CastError::CastFailed { - from: type_name::<CastFromSelf>(), - to: type_name::<OtherTrait>(), - })?; + let caster = get_caster::<OtherTrait>((*self).type_id()) + .map_err(CastError::GetCasterFailed)?; - Ok((caster.cast_box)(self.box_any())) + (caster.cast_box)(self.box_any()).map_err(|err| CastError::CastFailed { + source: err, + from: type_name::<Self>(), + to: type_name::<OtherTrait>(), + }) } } diff --git a/src/libs/intertrait/cast/error.rs b/src/libs/intertrait/cast/error.rs index a834c05..e6d86a5 100644 --- a/src/libs/intertrait/cast/error.rs +++ b/src/libs/intertrait/cast/error.rs @@ -1,9 +1,16 @@ +use crate::libs::intertrait::{CasterError, GetCasterError}; + #[derive(thiserror::Error, Debug)] pub enum CastError { + #[error("Failed to get caster")] + GetCasterFailed(#[from] GetCasterError), + #[error("Failed to cast from trait {from} to trait {to}")] CastFailed { + #[source] + source: CasterError, from: &'static str, to: &'static str, }, diff --git a/src/libs/intertrait/cast/rc.rs b/src/libs/intertrait/cast/rc.rs index 805bcd7..ec70544 100644 --- a/src/libs/intertrait/cast/rc.rs +++ b/src/libs/intertrait/cast/rc.rs @@ -30,12 +30,13 @@ impl<CastFromSelf: ?Sized + CastFrom> CastRc for CastFromSelf self: Rc<Self>, ) -> Result<Rc<OtherTrait>, CastError> { - let caster = - get_caster::<OtherTrait>((*self).type_id()).ok_or(CastError::CastFailed { - from: type_name::<CastFromSelf>(), - to: type_name::<OtherTrait>(), - })?; + let caster = get_caster::<OtherTrait>((*self).type_id()) + .map_err(CastError::GetCasterFailed)?; - Ok((caster.cast_rc)(self.rc_any())) + (caster.cast_rc)(self.rc_any()).map_err(|err| CastError::CastFailed { + source: err, + from: type_name::<Self>(), + to: type_name::<OtherTrait>(), + }) } } diff --git a/src/libs/intertrait/mod.rs b/src/libs/intertrait/mod.rs index dc0f19e..78f98b1 100644 --- a/src/libs/intertrait/mod.rs +++ b/src/libs/intertrait/mod.rs @@ -56,7 +56,12 @@ static CASTER_MAP: Lazy<AHashMap<(TypeId, TypeId), BoxedCaster>> = Lazy::new(|| .collect() }); -type CastArcFn<Trait> = fn(from: Arc<dyn Any + Sync + Send + 'static>) -> Arc<Trait>; +type CastBoxFn<Trait> = fn(from: Box<dyn Any>) -> Result<Box<Trait>, CasterError>; + +type CastRcFn<Trait> = fn(from: Rc<dyn Any>) -> Result<Rc<Trait>, CasterError>; + +type CastArcFn<Trait> = + fn(from: Arc<dyn Any + Sync + Send + 'static>) -> Result<Arc<Trait>, CasterError>; /// A `Caster` knows how to cast a reference to or `Box` of a trait object for `Any` /// to a trait object of trait `Trait`. Each `Caster` instance is specific to a concrete @@ -66,11 +71,11 @@ pub struct Caster<Trait: ?Sized + 'static> { /// Casts a `Box` holding a trait object for `Any` to another `Box` holding a trait /// object for trait `Trait`. - pub cast_box: fn(from: Box<dyn Any>) -> Box<Trait>, + pub cast_box: CastBoxFn<Trait>, /// Casts an `Rc` holding a trait object for `Any` to another `Rc` holding a trait /// object for trait `Trait`. - pub cast_rc: fn(from: Rc<dyn Any>) -> Rc<Trait>, + pub cast_rc: CastRcFn<Trait>, /// Casts an `Arc` holding a trait object for `Any + Sync + Send + 'static` /// to another `Arc` holding a trait object for trait `Trait`. @@ -79,10 +84,7 @@ pub struct Caster<Trait: ?Sized + 'static> impl<Trait: ?Sized + 'static> Caster<Trait> { - pub fn new( - cast_box: fn(from: Box<dyn Any>) -> Box<Trait>, - cast_rc: fn(from: Rc<dyn Any>) -> Rc<Trait>, - ) -> Caster<Trait> + pub fn new(cast_box: CastBoxFn<Trait>, cast_rc: CastRcFn<Trait>) -> Caster<Trait> { Caster::<Trait> { cast_box, @@ -93,9 +95,9 @@ impl<Trait: ?Sized + 'static> Caster<Trait> #[allow(clippy::similar_names)] pub fn new_sync( - cast_box: fn(from: Box<dyn Any>) -> Box<Trait>, - cast_rc: fn(from: Rc<dyn Any>) -> Rc<Trait>, - cast_arc: fn(from: Arc<dyn Any + Sync + Send>) -> Arc<Trait>, + cast_box: CastBoxFn<Trait>, + cast_rc: CastRcFn<Trait>, + cast_arc: CastArcFn<Trait>, ) -> Caster<Trait> { Caster::<Trait> { @@ -106,14 +108,42 @@ impl<Trait: ?Sized + 'static> Caster<Trait> } } +#[derive(Debug, thiserror::Error)] +pub enum CasterError +{ + #[error("Failed to cast Box")] + CastBoxFailed, + + #[error("Failed to cast Rc")] + CastRcFailed, + + #[error("Failed to cast Arc")] + CastArcFailed, +} + /// Returns a `Caster<S, Trait>` from a concrete type `S` to a trait `Trait` implemented /// by it. -fn get_caster<Trait: ?Sized + 'static>(type_id: TypeId) - -> Option<&'static Caster<Trait>> +fn get_caster<Trait: ?Sized + 'static>( + type_id: TypeId, +) -> Result<&'static Caster<Trait>, GetCasterError> { - CASTER_MAP + let any_caster = CASTER_MAP .get(&(type_id, TypeId::of::<Caster<Trait>>())) - .and_then(|caster| caster.downcast_ref::<Caster<Trait>>()) + .ok_or(GetCasterError::NotFound)?; + + any_caster + .downcast_ref::<Caster<Trait>>() + .ok_or(GetCasterError::DowncastFailed) +} + +#[derive(Debug, thiserror::Error)] +pub enum GetCasterError +{ + #[error("Caster not found")] + NotFound, + + #[error("Failed to downcast caster")] + DowncastFailed, } /// `CastFrom` must be extended by a trait that wants to allow for casting into another @@ -237,9 +267,27 @@ mod tests { let type_id = TypeId::of::<TestStruct>(); let caster = Box::new(Caster::<dyn Debug> { - cast_box: |from| from.downcast::<TestStruct>().unwrap(), - cast_rc: |from| from.downcast::<TestStruct>().unwrap(), - opt_cast_arc: Some(|from| from.downcast::<TestStruct>().unwrap()), + cast_box: |from| { + let concrete = from + .downcast::<TestStruct>() + .map_err(|_| CasterError::CastBoxFailed)?; + + Ok(concrete as Box<dyn Debug>) + }, + cast_rc: |from| { + let concrete = from + .downcast::<TestStruct>() + .map_err(|_| CasterError::CastRcFailed)?; + + Ok(concrete as Rc<dyn Debug>) + }, + opt_cast_arc: Some(|from| { + let concrete = from + .downcast::<TestStruct>() + .map_err(|_| CasterError::CastArcFailed)?; + + Ok(concrete as Arc<dyn Debug>) + }), }); (type_id, caster) } |