diff options
author | HampusM <hampus@hampusmat.com> | 2022-08-30 18:53:23 +0200 |
---|---|---|
committer | HampusM <hampus@hampusmat.com> | 2022-08-30 18:53:23 +0200 |
commit | d6f01bd571753dc2e9628418f94f66139438bcb3 (patch) | |
tree | 1ed5492d8abdeb9231d498e9ecf349c7cc1ec3d8 /src | |
parent | 080cc42bb1da09059dbc35049a7ded0649961e0c (diff) |
refactor: replace arc cast panic with an error
Diffstat (limited to 'src')
-rw-r--r-- | src/async_di_container.rs | 26 | ||||
-rw-r--r-- | src/errors/async_di_container.rs | 4 | ||||
-rw-r--r-- | src/libs/intertrait/cast/arc.rs | 21 | ||||
-rw-r--r-- | src/libs/intertrait/cast/box.rs | 4 | ||||
-rw-r--r-- | src/libs/intertrait/cast/error.rs | 3 | ||||
-rw-r--r-- | src/libs/intertrait/cast/rc.rs | 4 | ||||
-rw-r--r-- | src/libs/intertrait/mod.rs | 25 |
7 files changed, 54 insertions, 33 deletions
diff --git a/src/async_di_container.rs b/src/async_di_container.rs index 374746f..ecf3a41 100644 --- a/src/async_di_container.rs +++ b/src/async_di_container.rs @@ -66,6 +66,7 @@ use crate::errors::async_di_container::{ AsyncDIContainerError, }; use crate::interfaces::async_injectable::AsyncInjectable; +use crate::libs::intertrait::cast::error::CastError; use crate::libs::intertrait::cast::{CastArc, CastBox}; use crate::provider::r#async::{ AsyncProvidable, @@ -410,11 +411,26 @@ impl AsyncDIContainer )) } AsyncProvidable::Singleton(singleton_binding) => { - Ok(SomeThreadsafePtr::ThreadsafeSingleton( - singleton_binding.cast::<Interface>().map_err(|_| { - AsyncDIContainerError::CastFailed(type_name::<Interface>()) - })?, - )) + Ok( + SomeThreadsafePtr::ThreadsafeSingleton( + singleton_binding.cast::<Interface>().map_err( + |err| match err { + CastError::NotArcCastable(_) => { + AsyncDIContainerError::InterfaceNotAsync(type_name::< + Interface, + >( + )) + } + CastError::CastFailed { from: _, to: _ } => { + AsyncDIContainerError::CastFailed(type_name::< + Interface, + >( + )) + } + }, + )?, + ), + ) } #[cfg(feature = "factory")] AsyncProvidable::Factory(factory_binding) => { diff --git a/src/errors/async_di_container.rs b/src/errors/async_di_container.rs index 4f5e50a..bdb6fa0 100644 --- a/src/errors/async_di_container.rs +++ b/src/errors/async_di_container.rs @@ -43,6 +43,10 @@ pub enum AsyncDIContainerError /// The name of the binding if one exists. name: Option<&'static str>, }, + + /// A interface has not been marked async. + #[error("Interface '{0}' has not been marked async")] + InterfaceNotAsync(&'static str), } /// Error type for [`AsyncBindingBuilder`]. diff --git a/src/libs/intertrait/cast/arc.rs b/src/libs/intertrait/cast/arc.rs index 94c0482..33d84d2 100644 --- a/src/libs/intertrait/cast/arc.rs +++ b/src/libs/intertrait/cast/arc.rs @@ -13,7 +13,7 @@ use std::any::type_name; use std::sync::Arc; use crate::libs::intertrait::cast::error::CastError; -use crate::libs::intertrait::{caster, CastFromSync}; +use crate::libs::intertrait::{get_caster, CastFromSync}; pub trait CastArc { @@ -31,12 +31,19 @@ impl<CastFromSelf: ?Sized + CastFromSync> CastArc for CastFromSelf self: Arc<Self>, ) -> Result<Arc<OtherTrait>, CastError> { - match caster::<OtherTrait>((*self).type_id()) { - Some(caster) => Ok((caster.cast_arc)(self.arc_any())), - None => Err(CastError::CastFailed { - from: type_name::<CastFromSelf>(), - to: type_name::<OtherTrait>(), - }), + let caster = get_caster::<OtherTrait>((*self).type_id()).map_or_else( + || { + Err(CastError::CastFailed { + from: type_name::<CastFromSelf>(), + to: type_name::<OtherTrait>(), + }) + }, + Ok, + )?; + + match caster.opt_cast_arc { + Some(cast_arc) => Ok(cast_arc(self.arc_any())), + None => Err(CastError::NotArcCastable(type_name::<OtherTrait>())), } } } diff --git a/src/libs/intertrait/cast/box.rs b/src/libs/intertrait/cast/box.rs index 31f06db..c463c2f 100644 --- a/src/libs/intertrait/cast/box.rs +++ b/src/libs/intertrait/cast/box.rs @@ -13,7 +13,7 @@ use std::any::type_name; use crate::libs::intertrait::cast::error::CastError; -use crate::libs::intertrait::{caster, CastFrom}; +use crate::libs::intertrait::{get_caster, CastFrom}; pub trait CastBox { @@ -30,7 +30,7 @@ impl<CastFromSelf: ?Sized + CastFrom> CastBox for CastFromSelf self: Box<Self>, ) -> Result<Box<OtherTrait>, CastError> { - match caster::<OtherTrait>((*self).type_id()) { + match get_caster::<OtherTrait>((*self).type_id()) { Some(caster) => Ok((caster.cast_box)(self.box_any())), None => Err(CastError::CastFailed { from: type_name::<CastFromSelf>(), diff --git a/src/libs/intertrait/cast/error.rs b/src/libs/intertrait/cast/error.rs index 74eb3ca..a834c05 100644 --- a/src/libs/intertrait/cast/error.rs +++ b/src/libs/intertrait/cast/error.rs @@ -7,4 +7,7 @@ pub enum CastError from: &'static str, to: &'static str, }, + + #[error("Trait '{0}' can't be cast to Arc")] + NotArcCastable(&'static str), } diff --git a/src/libs/intertrait/cast/rc.rs b/src/libs/intertrait/cast/rc.rs index dfb71c2..63c0024 100644 --- a/src/libs/intertrait/cast/rc.rs +++ b/src/libs/intertrait/cast/rc.rs @@ -13,7 +13,7 @@ use std::any::type_name; use std::rc::Rc; use crate::libs::intertrait::cast::error::CastError; -use crate::libs::intertrait::{caster, CastFrom}; +use crate::libs::intertrait::{get_caster, CastFrom}; pub trait CastRc { @@ -30,7 +30,7 @@ impl<CastFromSelf: ?Sized + CastFrom> CastRc for CastFromSelf self: Rc<Self>, ) -> Result<Rc<OtherTrait>, CastError> { - match caster::<OtherTrait>((*self).type_id()) { + match get_caster::<OtherTrait>((*self).type_id()) { Some(caster) => Ok((caster.cast_rc)(self.rc_any())), None => Err(CastError::CastFailed { from: type_name::<CastFromSelf>(), diff --git a/src/libs/intertrait/mod.rs b/src/libs/intertrait/mod.rs index bdae4c7..3b3e9ba 100644 --- a/src/libs/intertrait/mod.rs +++ b/src/libs/intertrait/mod.rs @@ -23,7 +23,7 @@ //! MIT license (LICENSE-MIT or <http://opensource.org/licenses/MIT>) //! //! at your option. -use std::any::{type_name, Any, TypeId}; +use std::any::{Any, TypeId}; use std::rc::Rc; use std::sync::Arc; @@ -58,22 +58,12 @@ static CASTER_MAP: Lazy<AHashMap<(TypeId, TypeId), BoxedCaster>> = Lazy::new(|| .collect() }); -fn cast_arc_panic<Trait: ?Sized + 'static>(_: Arc<dyn Any + Sync + Send>) -> Arc<Trait> -{ - panic!( - "Interface trait '{}' has not been marked async", - type_name::<Trait>() - ) -} +type CastArcFn<Trait> = fn(from: Arc<dyn Any + Sync + Send + 'static>) -> Arc<Trait>; /// A `Caster` knows how to cast a reference to or `Box` of a trait object for `Any` /// to a trait object of trait `Trait`. Each `Caster` instance is specific to a concrete /// type. That is, it knows how to cast to single specific trait implemented by single /// specific type. -/// -/// An implementation of a trait for a concrete type doesn't need to manually provide -/// a `Caster`. Instead attach `#[cast_to]` to the `impl` block. -#[doc(hidden)] pub struct Caster<Trait: ?Sized + 'static> { /// Casts a `Box` holding a trait object for `Any` to another `Box` holding a trait @@ -86,7 +76,7 @@ pub struct Caster<Trait: ?Sized + 'static> /// Casts an `Arc` holding a trait object for `Any + Sync + Send + 'static` /// to another `Arc` holding a trait object for trait `Trait`. - pub cast_arc: fn(from: Arc<dyn Any + Sync + Send + 'static>) -> Arc<Trait>, + pub opt_cast_arc: Option<CastArcFn<Trait>>, } impl<Trait: ?Sized + 'static> Caster<Trait> @@ -99,7 +89,7 @@ impl<Trait: ?Sized + 'static> Caster<Trait> Caster::<Trait> { cast_box, cast_rc, - cast_arc: cast_arc_panic, + opt_cast_arc: None, } } @@ -113,14 +103,15 @@ impl<Trait: ?Sized + 'static> Caster<Trait> Caster::<Trait> { cast_box, cast_rc, - cast_arc, + opt_cast_arc: Some(cast_arc), } } } /// Returns a `Caster<S, Trait>` from a concrete type `S` to a trait `Trait` implemented /// by it. -fn caster<Trait: ?Sized + 'static>(type_id: TypeId) -> Option<&'static Caster<Trait>> +fn get_caster<Trait: ?Sized + 'static>(type_id: TypeId) + -> Option<&'static Caster<Trait>> { CASTER_MAP .get(&(type_id, TypeId::of::<Caster<Trait>>())) @@ -250,7 +241,7 @@ mod tests let caster = Box::new(Caster::<dyn Debug> { cast_box: |from| from.downcast::<TestStruct>().unwrap(), cast_rc: |from| from.downcast::<TestStruct>().unwrap(), - cast_arc: |from| from.downcast::<TestStruct>().unwrap(), + opt_cast_arc: Some(|from| from.downcast::<TestStruct>().unwrap()), }); (type_id, caster) } |