diff options
Diffstat (limited to 'src')
| -rw-r--r-- | src/di_container.rs | 199 | ||||
| -rw-r--r-- | src/di_container_binding_map.rs | 55 | ||||
| -rw-r--r-- | src/lib.rs | 1 | 
3 files changed, 134 insertions, 121 deletions
diff --git a/src/di_container.rs b/src/di_container.rs index e59d8f4..e18fc3a 100644 --- a/src/di_container.rs +++ b/src/di_container.rs @@ -1,18 +1,29 @@ -use std::any::{type_name, TypeId}; +use std::any::type_name;  use std::marker::PhantomData; -use std::rc::Rc; -use ahash::AHashMap; -use error_stack::{Report, ResultExt}; +use error_stack::{report, Report, ResultExt};  #[cfg(feature = "factory")]  use crate::castable_factory::CastableFactory; +use crate::di_container_binding_map::DIContainerBindingMap;  use crate::errors::di_container::{BindingBuilderError, DIContainerError};  use crate::interfaces::injectable::Injectable; +use crate::libs::intertrait::cast::error::CastError;  use crate::libs::intertrait::cast::{CastBox, CastRc}; -use crate::provider::{IProvider, Providable, SingletonProvider, TransientTypeProvider}; +use crate::provider::{Providable, SingletonProvider, TransientTypeProvider};  use crate::ptr::{SingletonPtr, TransientPtr}; +fn unable_to_cast_binding<Interface>(err: Report<CastError>) -> Report<DIContainerError> +where +    Interface: 'static + ?Sized, +{ +    err.change_context(DIContainerError) +        .attach_printable(format!( +            "Unable to cast binding for interface '{}'", +            type_name::<Interface>() +        )) +} +  /// Binding builder for type `Interface` inside a [`DIContainer`].  pub struct BindingBuilder<'di_container_lt, Interface>  where @@ -40,12 +51,9 @@ where      where          Implementation: Injectable,      { -        let interface_typeid = TypeId::of::<Interface>(); - -        self.di_container.bindings.insert( -            interface_typeid, -            Rc::new(TransientTypeProvider::<Implementation>::new()), -        ); +        self.di_container +            .bindings +            .set::<Interface>(Box::new(TransientTypeProvider::<Implementation>::new()));      }      /// Creates a binding of type `Interface` to a new singleton of type `Implementation` @@ -59,8 +67,6 @@ where      where          Implementation: Injectable,      { -        let interface_typeid = TypeId::of::<Interface>(); -          let singleton: SingletonPtr<Implementation> = SingletonPtr::from(              Implementation::resolve(self.di_container)                  .change_context(BindingBuilderError)?, @@ -68,7 +74,7 @@ where          self.di_container              .bindings -            .insert(interface_typeid, Rc::new(SingletonProvider::new(singleton))); +            .set::<Interface>(Box::new(SingletonProvider::new(singleton)));          Ok(())      } @@ -84,16 +90,13 @@ where          Return: 'static + ?Sized,          Interface: crate::interfaces::factory::IFactory<Args, Return>,      { -        let interface_typeid = TypeId::of::<Interface>(); -          let factory_impl = CastableFactory::new(factory_func); -        self.di_container.bindings.insert( -            interface_typeid, -            Rc::new(crate::provider::FactoryProvider::new( -                crate::ptr::FactoryPtr::new(factory_impl), +        self.di_container.bindings.set::<Interface>(Box::new( +            crate::provider::FactoryProvider::new(crate::ptr::FactoryPtr::new( +                factory_impl,              )), -        ); +        ));      }  } @@ -144,7 +147,7 @@ where  /// ```  pub struct DIContainer  { -    bindings: AHashMap<TypeId, Rc<dyn IProvider>>, +    bindings: DIContainerBindingMap,  }  impl DIContainer @@ -154,7 +157,7 @@ impl DIContainer      pub fn new() -> Self      {          Self { -            bindings: AHashMap::new(), +            bindings: DIContainerBindingMap::new(),          }      } @@ -180,39 +183,18 @@ impl DIContainer      where          Interface: 'static + ?Sized,      { -        let interface_typeid = TypeId::of::<Interface>(); - -        let interface_name = type_name::<Interface>(); +        let binding_providable = self.get_binding_providable::<Interface>()?; -        let binding = self.bindings.get(&interface_typeid).ok_or_else(|| { -            Report::new(DIContainerError) -                .attach_printable(format!("No binding exists for {}", interface_name)) -        })?; - -        let binding_providable = binding -            .provide(self) -            .change_context(DIContainerError) -            .attach_printable(format!( -            "Failed to resolve binding for interface {}", -            interface_name -        ))?; - -        match binding_providable { -            Providable::Transient(binding_injectable) => { -                let interface_result = binding_injectable.cast::<Interface>(); - -                match interface_result { -                    Ok(interface) => Ok(interface), -                    Err(_) => Err(Report::new(DIContainerError).attach_printable( -                        format!("Unable to cast binding for {}", interface_name), -                    )), -                } -            } -            _ => Err(Report::new(DIContainerError).attach_printable(format!( -                "Binding for {} is not injectable", -                interface_name -            ))), +        if let Providable::Transient(binding_transient) = binding_providable { +            return binding_transient +                .cast::<Interface>() +                .map_err(unable_to_cast_binding::<Interface>);          } + +        Err(report!(DIContainerError).attach_printable(format!( +            "Binding for interface '{}' is not transient", +            type_name::<Interface>() +        )))      }      /// Returns the singleton instance bound with `Interface`. @@ -229,39 +211,18 @@ impl DIContainer      where          Interface: 'static + ?Sized,      { -        let interface_typeid = TypeId::of::<Interface>(); - -        let interface_name = type_name::<Interface>(); +        let binding_providable = self.get_binding_providable::<Interface>()?; -        let binding = self.bindings.get(&interface_typeid).ok_or_else(|| { -            Report::new(DIContainerError) -                .attach_printable(format!("No binding exists for {}", interface_name)) -        })?; - -        let binding_providable = binding -            .provide(self) -            .change_context(DIContainerError) -            .attach_printable(format!( -            "Failed to resolve binding for interface {}", -            interface_name -        ))?; - -        match binding_providable { -            Providable::Singleton(binding_singleton) => { -                let interface_result = binding_singleton.cast::<Interface>(); - -                match interface_result { -                    Ok(interface) => Ok(interface), -                    Err(_) => Err(Report::new(DIContainerError).attach_printable( -                        format!("Unable to cast binding for {}", interface_name), -                    )), -                } -            } -            _ => Err(Report::new(DIContainerError).attach_printable(format!( -                "Binding for {} is not a singleton", -                interface_name -            ))), +        if let Providable::Singleton(binding_singleton) = binding_providable { +            return binding_singleton +                .cast::<Interface>() +                .map_err(unable_to_cast_binding::<Interface>);          } + +        Err(report!(DIContainerError).attach_printable(format!( +            "Binding for interface '{}' is not a singleton", +            type_name::<Interface>() +        )))      }      /// Returns the factory bound with factory type `Interface`. @@ -279,39 +240,34 @@ impl DIContainer      where          Interface: 'static + ?Sized,      { -        let interface_typeid = TypeId::of::<Interface>(); +        let binding_providable = self.get_binding_providable::<Interface>()?; -        let interface_name = type_name::<Interface>(); +        if let Providable::Factory(binding_factory) = binding_providable { +            return binding_factory +                .cast::<Interface>() +                .map_err(unable_to_cast_binding::<Interface>); +        } -        let binding = self.bindings.get(&interface_typeid).ok_or_else(|| { -            Report::new(DIContainerError) -                .attach_printable(format!("No binding exists for {}", interface_name)) -        })?; +        Err(report!(DIContainerError).attach_printable(format!( +            "Binding for interface '{}' is not a factory", +            type_name::<Interface>() +        ))) +    } -        let binding_providable = binding +    fn get_binding_providable<Interface>( +        &self, +    ) -> error_stack::Result<Providable, DIContainerError> +    where +        Interface: 'static + ?Sized, +    { +        self.bindings +            .get::<Interface>()?              .provide(self)              .change_context(DIContainerError)              .attach_printable(format!( -            "Failed to resolve binding for interface {}", -            interface_name -        ))?; - -        match binding_providable { -            Providable::Factory(binding_factory) => { -                let factory_result = binding_factory.cast::<Interface>(); - -                match factory_result { -                    Ok(factory) => Ok(factory), -                    Err(_) => Err(Report::new(DIContainerError).attach_printable( -                        format!("Unable to cast binding for {}", interface_name), -                    )), -                } -            } -            _ => Err(Report::new(DIContainerError).attach_printable(format!( -                "Binding for {} is not a factory", -                interface_name -            ))), -        } +                "Failed to resolve binding for interface '{}'", +                type_name::<Interface>() +            ))      }  } @@ -332,6 +288,7 @@ mod tests      use super::*;      use crate::errors::injectable::ResolveError; +    use crate::provider::IProvider;      use crate::ptr::TransientPtr;      #[test] @@ -376,11 +333,11 @@ mod tests          let mut di_container: DIContainer = DIContainer::new(); -        assert_eq!(di_container.bindings.len(), 0); +        assert_eq!(di_container.bindings.count(), 0);          di_container.bind::<dyn IUserManager>().to::<UserManager>(); -        assert_eq!(di_container.bindings.len(), 1); +        assert_eq!(di_container.bindings.count(), 1);      }      #[test] @@ -425,13 +382,13 @@ mod tests          let mut di_container: DIContainer = DIContainer::new(); -        assert_eq!(di_container.bindings.len(), 0); +        assert_eq!(di_container.bindings.count(), 0);          di_container              .bind::<dyn IUserManager>()              .to_singleton::<UserManager>()?; -        assert_eq!(di_container.bindings.len(), 1); +        assert_eq!(di_container.bindings.count(), 1);          Ok(())      } @@ -475,7 +432,7 @@ mod tests          let mut di_container: DIContainer = DIContainer::new(); -        assert_eq!(di_container.bindings.len(), 0); +        assert_eq!(di_container.bindings.count(), 0);          di_container.bind::<IUserManagerFactory>().to_factory(&|| {              let user_manager: TransientPtr<dyn IUserManager> = @@ -484,7 +441,7 @@ mod tests              user_manager          }); -        assert_eq!(di_container.bindings.len(), 1); +        assert_eq!(di_container.bindings.count(), 1);      }      #[test] @@ -546,7 +503,7 @@ mod tests          di_container              .bindings -            .insert(TypeId::of::<dyn IUserManager>(), Rc::new(mock_provider)); +            .set::<dyn IUserManager>(Box::new(mock_provider));          di_container.get::<dyn IUserManager>()?; @@ -635,7 +592,7 @@ mod tests          di_container              .bindings -            .insert(TypeId::of::<dyn INumber>(), Rc::new(mock_provider)); +            .set::<dyn INumber>(Box::new(mock_provider));          let first_number_rc = di_container.get_singleton::<dyn INumber>()?; @@ -723,7 +680,7 @@ mod tests          di_container              .bindings -            .insert(TypeId::of::<IUserManagerFactory>(), Rc::new(mock_provider)); +            .set::<IUserManagerFactory>(Box::new(mock_provider));          di_container.get_factory::<IUserManagerFactory>()?; diff --git a/src/di_container_binding_map.rs b/src/di_container_binding_map.rs new file mode 100644 index 0000000..b505321 --- /dev/null +++ b/src/di_container_binding_map.rs @@ -0,0 +1,55 @@ +use std::any::{type_name, TypeId}; + +use ahash::AHashMap; +use error_stack::report; + +use crate::{errors::di_container::DIContainerError, provider::IProvider}; + +pub struct DIContainerBindingMap +{ +    bindings: AHashMap<TypeId, Box<dyn IProvider>>, +} + +impl DIContainerBindingMap +{ +    pub fn new() -> Self +    { +        Self { +            bindings: AHashMap::new(), +        } +    } + +    pub fn get<Interface>(&self) -> error_stack::Result<&dyn IProvider, DIContainerError> +    where +        Interface: 'static + ?Sized, +    { +        let interface_typeid = TypeId::of::<Interface>(); + +        Ok(self +            .bindings +            .get(&interface_typeid) +            .ok_or_else(|| { +                report!(DIContainerError).attach_printable(format!( +                    "No binding exists for {}", +                    type_name::<Interface>() +                )) +            })? +            .as_ref()) +    } + +    pub fn set<Interface>(&mut self, provider: Box<dyn IProvider>) +    where +        Interface: 'static + ?Sized, +    { +        let interface_typeid = TypeId::of::<Interface>(); + +        self.bindings.insert(interface_typeid, provider); +    } + +    /// Only used by tests in the ``di_container`` module. +    #[cfg(test)] +    pub fn count(&self) -> usize +    { +        self.bindings.len() +    } +} @@ -22,6 +22,7 @@ pub mod castable_factory;  pub mod libs;  // Private +mod di_container_binding_map;  mod provider;  /// Shortcut for creating a DI container binding for a injectable without a declared interface.  | 
