diff options
author | David Blajda <blajda@hotmail.com> | 2019-02-03 22:30:15 +0000 |
---|---|---|
committer | David Blajda <blajda@hotmail.com> | 2019-02-03 22:30:15 +0000 |
commit | 96715ceb58b24ee7220d98e421701daa550f44db (patch) | |
tree | 2d00984339efab0549fa07079be623b2a7b634f8 /src/client.rs | |
parent | 0a5892c67fb02e09a621ac8796ac84232935f5c3 (diff) |
Add Helix and Kraken scopes. Client Config and allow injecting of responses
Diffstat (limited to 'src/client.rs')
-rw-r--r-- | src/client.rs | 406 |
1 files changed, 342 insertions, 64 deletions
diff --git a/src/client.rs b/src/client.rs index a8bc0b5..0307a05 100644 --- a/src/client.rs +++ b/src/client.rs @@ -1,6 +1,10 @@ +use crate::models::Message; +use std::convert::TryFrom; use futures::future::Future; use std::sync::{Arc, Mutex}; use reqwest::r#async::Client as ReqwestClient; +use reqwest::Error as ReqwestError; +use reqwest::r#async::{Request, Response}; use std::collections::{HashSet, HashMap}; use super::error::Error; @@ -9,6 +13,8 @@ use futures::Poll; use serde::de::DeserializeOwned; use futures::Async; use futures::try_ready; +use serde_json::Value; +use futures::future::Either; use crate::error::ConditionError; @@ -18,7 +24,10 @@ pub use super::types; pub enum RatelimitKey { Default, } -type RatelimitMap = HashMap<RatelimitKey, Ratelimit>; + +pub struct RatelimitMap { + pub inner: HashMap<RatelimitKey, Ratelimit> +} const API_DOMAIN: &'static str = "api.twitch.tv"; const AUTH_DOMAIN: &'static str = "id.twitch.tv"; @@ -35,12 +44,170 @@ pub struct Client { inner: Arc<ClientType>, } +#[derive(Debug)] +pub struct ScopeParseError {} +use std::fmt; +impl fmt::Display for ScopeParseError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "Scope Parse Error") + } +} + /*TODO*/ -#[derive(PartialEq, Hash, Eq, Clone)] +#[derive(PartialEq, Hash, Eq, Clone, Debug)] pub enum Scope { + Helix(HelixScope), + Kraken(KrakenScope), +} + +impl TryFrom<&str> for Scope { + type Error = ScopeParseError; + fn try_from(s: &str) -> Result<Scope, Self::Error> { + if let Ok(scope) = HelixScope::try_from(s) { + return Ok(Scope::Helix(scope)); + } + if let Ok(scope) = KrakenScope::try_from(s) { + return Ok(Scope::Kraken(scope)); + } + Err(ScopeParseError {}) + } +} +use serde::{Deserialize, Deserializer}; +impl<'de> Deserialize<'de> for Scope { + + fn deserialize<D>(deserializer: D) -> Result<Self, D::Error> + where D: Deserializer<'de> + { + let id = String::deserialize(deserializer)?; + Scope::try_from(&id[0..]).map_err(serde::de::Error::custom) + } +} + +#[derive(PartialEq, Hash, Eq, Clone, Debug)] +pub enum HelixScope { + AnalyticsReadExtensions, + AnalyticsReadGames, + BitsRead, + ChannelReadSubscriptions, + ClipsEdit, + UserEdit, + UserEditBroadcast, + UserReadBroadcast, UserReadEmail, } +impl HelixScope { + pub fn to_str(&self) -> &'static str { + use self::HelixScope::*; + match self { + AnalyticsReadExtensions => "analytics:read:extensions", + AnalyticsReadGames => "analytics:read:games", + BitsRead => "bits:read", + ChannelReadSubscriptions => "channel:read:subscriptions", + ClipsEdit => "clips:edit", + UserEdit => "user:edit", + UserEditBroadcast => "user:edit:broadcast", + UserReadBroadcast => "user:read:broadcast", + UserReadEmail => "user:read:email", + } + } +} + +impl TryFrom<&str> for HelixScope { + type Error = ScopeParseError; + fn try_from(s: &str) -> Result<HelixScope, Self::Error> { + use self::HelixScope::*; + Ok( match s { + "analytics:read:extensions" => AnalyticsReadExtensions, + "analytics:read:games" => AnalyticsReadGames, + "bits:read" => BitsRead, + "channel:read:subscriptions" => ChannelReadSubscriptions, + "clips:edit" => ClipsEdit, + "user:edit" => UserEdit, + "user:edit:broadcast" => UserEditBroadcast, + "user:read:broadcast" => UserReadBroadcast, + "user:read:email" => UserReadEmail, + _ => return Err(ScopeParseError{}) + }) + } +} + +#[derive(PartialEq, Hash, Eq, Clone, Debug)] +pub enum KrakenScope { + ChannelCheckSubscription, + ChannelCommercial, + ChannelEditor, + ChannelFeedEdit, + ChannelFeedRead, + ChannelRead, + ChannelStream, + ChannelSubscriptions, + CollectionsEdit, + CommunitiesEdit, + CommunitiesModerate, + Openid, + UserBlocksEdit, + UserBlocksRead, + UserFollowsEdit, + UserRead, + UserSubscriptions, + ViewingActivityRead, +} + +impl KrakenScope { + pub fn to_str(&self) -> &'static str { + use self::KrakenScope::*; + match self { + ChannelCheckSubscription => "channel_check_subscription", + ChannelCommercial => "channel_commercial", + ChannelEditor => "channel_editor", + ChannelFeedEdit => "channel_feed_edit", + ChannelFeedRead => "channel_feed_read", + ChannelRead => "channel_read", + ChannelStream => "channel_stream", + ChannelSubscriptions => "channel_subscriptions", + CollectionsEdit => "collections_edit", + CommunitiesEdit => "communities_edit", + CommunitiesModerate => "communities_moderate", + Openid => "openid", + UserBlocksEdit => "user_blocks_edit", + UserBlocksRead => "user_blocks_read", + UserFollowsEdit => "user_follows_edit", + UserRead => "user_read", + UserSubscriptions => "user_subscriptions", + ViewingActivityRead => "viewing_activity_read", + } + } +} + +impl TryFrom<&str> for KrakenScope { + type Error = ScopeParseError; + fn try_from(s: &str) -> Result<KrakenScope, Self::Error> { + use self::KrakenScope::*; + Ok( match s { + "channel_check_subscription" => ChannelCheckSubscription, + "channel_commercial" => ChannelCommercial, + "channel_editor" => ChannelEditor, + "channel_feed_edit" => ChannelFeedEdit, + "channel_feed_read" => ChannelFeedRead, + "channel_read" => ChannelRead, + "channel_stream" => ChannelStream, + "channel_subscriptions" => ChannelSubscriptions, + "collections_edit" => CollectionsEdit, + "communities_edit" => CommunitiesEdit, + "communities_moderate" => CommunitiesModerate, + "openid" => Openid, + "user_blocks_edit" => UserBlocksEdit, + "user_blocks_read" => UserBlocksRead, + "user_follows_edit" => UserFollowsEdit, + "user_read" => UserRead, + "user_subscriptions" => UserSubscriptions, + "viewing_activity_read" => ViewingActivityRead, + _ => return Err(ScopeParseError {}) + }) + } +} + #[derive(Clone)] pub enum Version { Helix, @@ -62,17 +229,94 @@ impl Client { } } +pub struct TestConfigRef { + pub requests: Vec<Result<Request, ReqwestError>>, + pub responses: Vec<Response>, +} + +#[derive(Clone)] +pub struct TestConfig { + pub inner: Arc<Mutex<TestConfigRef>> +} + +impl TestConfig { + + pub fn push_response(&self, response: Response) { + let inner = &mut self.inner.lock().unwrap(); + inner.responses.push(response); + } +} + +impl Default for TestConfig { + + fn default() -> Self { + TestConfig { + inner: Arc::new( + Mutex::new( + TestConfigRef { + requests: Vec::new(), + responses: Vec::new(), + } + ) + ) + } + } +} + enum ClientType { Unauth(UnauthClient), Auth(AuthClient), } + +pub struct ClientConfig { + pub reqwest: ReqwestClient, + pub domain: String, + pub auth_domain: String, + pub ratelimits: RatelimitMap, + pub max_retrys: u32, + pub test_config: Option<TestConfig>, +} + +impl Default for RatelimitMap { + + fn default() -> Self { + let mut limits = HashMap::new(); + limits.insert(RatelimitKey::Default, Ratelimit::new(30, "Ratelimit-Limit", "Ratelimit-Remaining", "Ratelimit-Reset")); + RatelimitMap { + inner: limits + } + } +} + +impl RatelimitMap { + pub fn empty() -> RatelimitMap { + RatelimitMap { + inner: HashMap::new() + } + } +} + +impl Default for ClientConfig { + + fn default() -> Self { + let reqwest = ReqwestClient::new(); + let ratelimits = RatelimitMap::default(); + + ClientConfig { + reqwest, + domain: API_DOMAIN.to_owned(), + auth_domain: AUTH_DOMAIN.to_owned(), + ratelimits, + max_retrys: 1, + test_config: None, + } + } +} + pub struct UnauthClient { id: String, - reqwest: ReqwestClient, - domain: String, - auth_domain: String, - ratelimits: RatelimitMap, + config: ClientConfig, version: Version, } @@ -86,6 +330,7 @@ pub struct AuthClient { pub trait ClientTrait { fn id<'a>(&'a self) -> &'a str; + fn config<'a>(&'a self) -> &'a ClientConfig; fn domain<'a>(&'a self) -> &'a str; fn auth_domain<'a>(&'a self) -> &'a str; fn ratelimit<'a>(&'a self, key: RatelimitKey) -> Option<&'a Ratelimit>; @@ -100,21 +345,25 @@ impl ClientTrait for UnauthClient { } fn domain<'a>(&'a self) -> &'a str { - &self.domain + &self.config.domain } fn auth_domain<'a>(&'a self) -> &'a str { - &self.auth_domain + &self.config.auth_domain } fn ratelimit<'a>(&'a self, key: RatelimitKey) -> Option<&'a Ratelimit> { - self.ratelimits.get(&key) + self.config.ratelimits.inner.get(&key) } fn authenticated(&self) -> bool { false } + fn config<'a>(&'a self) -> &'a ClientConfig { + &self.config + } + fn scopes(&self) -> Vec<Scope> { Vec::with_capacity(0) } @@ -146,6 +395,14 @@ impl ClientTrait for Client { } } + fn config<'a>(&'a self) -> &'a ClientConfig { + use self::ClientType::*; + match self.inner.as_ref() { + Unauth(inner) => inner.config(), + Auth(inner) => inner.config(), + } + } + fn ratelimit<'a>(&'a self, key: RatelimitKey) -> Option<&'a Ratelimit> { use self::ClientType::*; match self.inner.as_ref() { @@ -194,6 +451,13 @@ impl ClientTrait for AuthClient { } } + fn config<'a>(&'a self) -> &'a ClientConfig { + match self.previous.inner.as_ref() { + ClientType::Auth(auth) => auth.config(), + ClientType::Unauth(unauth) => unauth.config(), + } + } + fn ratelimit<'a>(&'a self, key: RatelimitKey) -> Option<&'a Ratelimit> { match self.previous.inner.as_ref() { ClientType::Auth(auth) => auth.ratelimit(key), @@ -208,7 +472,7 @@ impl ClientTrait for AuthClient { fn scopes(&self) -> Vec<Scope> { let auth = self.auth_state.lock().expect("Auth Lock is poisoned"); - Vec::with_capacity(0) + auth.scopes.clone() } } @@ -225,28 +489,13 @@ struct AuthStateRef { } impl Client { - pub fn new(id: &str, version: Version) -> Client { + pub fn new(id: &str, config: ClientConfig, version: Version) -> Client { let client = ReqwestClient::new(); - Client::new_with_client(id, client, version) - } - - fn default_ratelimits() -> RatelimitMap { - let mut limits = RatelimitMap::new(); - limits.insert(RatelimitKey::Default, Ratelimit::new(30, "Ratelimit-Limit", "Ratelimit-Remaining", "Ratelimit-Reset")); - - limits - } - - pub fn new_with_client(id: &str, reqwest: ReqwestClient, version: Version) -> Client { - Client { inner: Arc::new( ClientType::Unauth(UnauthClient { id: id.to_owned(), - reqwest: reqwest, - domain: API_DOMAIN.to_owned(), - auth_domain: AUTH_DOMAIN.to_owned(), - ratelimits: Self::default_ratelimits(), + config: config, version: version, })) } @@ -271,11 +520,23 @@ impl Client { fn reqwest(&self) -> ReqwestClient { use self::ClientType::*; match self.inner.as_ref() { - Unauth(inner) => inner.reqwest.clone(), + Unauth(inner) => inner.config.reqwest.clone(), Auth(inner) => inner.previous.reqwest(), } } + fn send(&self, builder: RequestBuilder) -> Box<dyn Future<Item=Response, Error=reqwest::Error> + Send> { + if let Some(test_config) = &self.config().test_config { + let config: &mut TestConfigRef = &mut test_config.inner.lock().expect("Test Config poisoned"); + println!("{}", config.responses.len()); + config.requests.push(builder.build()); + let res = config.responses.pop().expect("Ran out of test responses!"); + Box::new(futures::future::ok(res)) + } else { + Box::new(builder.send()) + } + } + /* The 'bottom' client must always be a client that is not authorized. * This which allows for calls to Auth endpoints using the same control flow * as other requests. @@ -437,7 +698,7 @@ enum RequestState<T> { SetupRatelimit, WaitLimit(WaiterState<RatelimitWaiter>), WaitRequest, - PollParse(Box<dyn Future<Item=T, Error=reqwest::Error> + Send>), + PollParse(Box<dyn Future<Item=T, Error=Error> + Send>), } pub struct ApiRequest<T> { @@ -468,11 +729,12 @@ impl<T: DeserializeOwned + PaginationTrait + 'static + Send> ApiRequest<T> { ratelimit: Option<RatelimitKey>, ) -> ApiRequest<T> { + let max_attempts = client.config().max_retrys; ApiRequest { inner: Arc::new(RequestRef::new(url, params, client, method, ratelimit)), state: RequestState::SetupRequest, attempt: 0, - max_attempts: 1, + max_attempts, pagination: None, } } @@ -582,7 +844,6 @@ impl RatelimitRef { } } -use futures::future::SharedError; use crate::sync::barrier::Barrier; use crate::sync::waiter::Waiter; @@ -662,10 +923,13 @@ impl Waiter for AuthWaiter { let mut auth = inner.auth_state.lock().unwrap(); auth.state = AuthState::Auth; auth.token = Some(credentials.access_token.clone()); + if let Some(scopes) = credentials.scope { + for scope in scopes { auth.scopes.push(scope) } + } } () }) - .map_err(|_| ConditionError{}); + .map_err(|err| err.into()); Future::shared(Box::new(auth_future)) } @@ -694,15 +958,11 @@ impl Waiter for RatelimitWaiter { limits.remaining = limits.limit; () }) - .map_err(|_| ConditionError{}) + .map_err(|err| Error::from(err).into()) )) } } -/* Todo: If the polled futures returns an error than all the waiters should - * get that error - */ - /* Macro ripped directly from try_ready and simplies retries if any error occurs * and there are remaning retry attempt */ @@ -739,7 +999,7 @@ impl<T: DeserializeOwned + PaginationTrait + 'static + Send> Stream for Iterable inner: self.inner.clone(), state: RequestState::SetupRequest, attempt: 0, - max_attempts: 1, + max_attempts: self.inner.client.config().max_retrys, pagination: None }); }, @@ -762,7 +1022,7 @@ impl<T: DeserializeOwned + PaginationTrait + 'static + Send> Stream for Iterable inner: self.inner.clone(), state: RequestState::SetupRequest, attempt: 0, - max_attempts: 1, + max_attempts: self.inner.client.config().max_retrys, pagination: Some(cursor.to_owned()), }); }, @@ -865,37 +1125,55 @@ impl<T: DeserializeOwned + 'static + Send> Future for ApiRequest<T> { }; - let key_err = self.inner.ratelimit.clone(); - let key_ok = self.inner.ratelimit.clone(); - let client_err = client.clone(); - let client_ok = client.clone(); - + let ratelimit_key = self.inner.ratelimit.clone(); + let client_cloned = client.clone(); + /* + Allow testing by capturing the request and returning a predetermined response + If testing is set in the client config then `Pending` is captured and saved and a future::ok(Resposne) is returned. + */ let f = - builder.send() - .map_err(move |err| { - if let Some(key) = key_err { - if let Some(limits) = client_err.ratelimit(key) { + client.send(builder) + .then(move |result| { + trace!("[TWITCH_API] {:?}", result); + if let Some(ratelimit_key) = ratelimit_key { + if let Some(limits) = client_cloned.ratelimit(ratelimit_key) { let mut mut_limits = limits.inner.lock().unwrap(); mut_limits.inflight = mut_limits.inflight - 1; } } - - err + result }) - .map(move |mut response| { - println!("{:?}", response); - if let Some(key) = key_ok { - if let Some(limits) = client_ok.ratelimit(key) { - let mut mut_limits = limits.inner.lock().unwrap(); - mut_limits.inflight = mut_limits.inflight - 1; - mut_limits.update_from_headers(response.headers()); - } + .map_err(|err| err.into()) + .and_then(|mut response| { + let status = response.status(); + if status.is_success() { + Either::A( + response.json().map_err(|err| Error::from(err)).and_then(|json| { + trace!("[TWITCH_API] {}", json); + serde_json::from_value(json).map_err(|err| err.into()) + }) + ) + } else { + Either::B( + response.json::<Message>() + .then(|res| { + match res { + Ok(message) => futures::future::err(Some(message)), + Err(_err) => futures::future::err(None) + } + }) + .map_err(move |maybe_message| { + let status = response.status(); + if status == 401 || status == 403 { + Error::auth_error(maybe_message) + } else if status == 429 { + Error::ratelimit_error(maybe_message) + } else { + Error::auth_error(maybe_message) + } + }) + ) } - - response.json::<T>() - }) - .and_then(|json| { - json }); self.state = RequestState::PollParse(Box::new(f)); }, @@ -906,4 +1184,4 @@ impl<T: DeserializeOwned + 'static + Send> Future for ApiRequest<T> { } } } -} +}
\ No newline at end of file |