Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: key not from file & other improvements #6

Draft
wants to merge 9 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 11 additions & 6 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
[package]
name = "fcm"
version = "1.0.0"
version = "0.10.0"
authors = [
"Suvish Varghese Thoovamalayil <vishy1618@gmail.com>",
"panicbit <panicbit.dev@gmail.com>",
"Julius de Bruijn <pimeys@gmail.com>",
"Richard Jansen <demo_epso@proton.me>"
"Richard Jansen <demo_epso@proton.me>",
]
description = "An API to talk to FCM (Firebase Cloud Messaging) in Rust"
license = "MIT"
Expand All @@ -16,22 +16,27 @@ keywords = ["fcm", "firebase", "notification"]
edition = "2018"

[features]
default = ["native-tls"]
default = ["native-tls", "dotenv"]
native-tls = ["reqwest/native-tls"]
rustls = ["reqwest/rustls-tls"]
vendored-tls = ["reqwest/native-tls-vendored"]
dotenv = ["dep:dotenv"]

[dependencies]
serde = { version = "1", features = ["derive"] }
serde_json = { version = "1", features = ["preserve_order"] }
erased-serde = "0.4.1"
reqwest = {version = "0.11.0", features = ["json"], default-features=false}
reqwest = { version = "0.12.4", features = ["json"], default-features = false }
chrono = "0.4"
log = "0.4"
gauth = "0.7.0"
dotenv = "0.15.0"
gauth = { git = "https://github.com/WalletConnect/gauth-rs.git", branch = "feat/key-not-from-file" } # TODO switch to tagged version once released
dotenv = { version = "0.15.0", optional = true }
thiserror = "1"

[dev-dependencies]
argparse = "0.2.1"
tokio = { version = "1.0", features = ["rt-multi-thread", "macros"] }
pretty_env_logger = "0.5.0"

# [patch.'https://github.com/WalletConnect/gauth-rs.git']
# gauth = { path = "../gauth-rs" }
10 changes: 5 additions & 5 deletions examples/simple_sender.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
ap.parse_args_or_exit();
}

let client = Client::new();
let client = Client::new().await.unwrap();

let data = json!({
"key": "value",
Expand All @@ -29,8 +29,8 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let builder = Message {
data: Some(data),
notification: Some(Notification {
title: Some("I'm high".to_string()),
body: Some(format!("it's {}", chrono::Utc::now())),
title: Some("Test FCM notification".to_string()),
body: Some(format!("It's {}", chrono::Utc::now())),
..Default::default()
}),
target: Target::Token(device_token),
Expand All @@ -40,8 +40,8 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
android: Some(AndroidConfig {
priority: Some(fcm::AndroidMessagePriority::High),
notification: Some(AndroidNotification {
title: Some("I'm Android high".to_string()),
body: Some(format!("Hi Android, it's {}", chrono::Utc::now())),
title: Some("Android: Test FCM notification".to_string()),
body: Some(format!("Android: It's {}", chrono::Utc::now())),
..Default::default()
}),
..Default::default()
Expand Down
250 changes: 122 additions & 128 deletions src/client/mod.rs
Original file line number Diff line number Diff line change
@@ -1,24 +1,33 @@
pub(crate) mod response;

use crate::client::response::{ErrorReason, FcmError, FcmResponse, RetryAfter};
use crate::{Message, MessageInternal};
use gauth::serv_account::ServiceAccount;
use reqwest::header::RETRY_AFTER;
use reqwest::{Body, StatusCode};
use crate::client::response::Response;
use crate::{
ClientBuildError, Error404Response, Message, MessageInternal, SendError, ERROR_404_CODE_UNREGISTERED,
TYPE_FCM_ERROR,
};
pub use gauth;
use gauth::serv_account::{ServiceAccount, ServiceAccountBuilder, ServiceAccountKey};
use reqwest::{Client as HttpClient, StatusCode};
use serde::Serialize;
use std::sync::Arc;

/// An async client for sending the notification payload.
pub struct Client {
http_client: reqwest::Client,
}
#[cfg(feature = "dotenv")]
use crate::DotEnvClientBuildError;

impl Default for Client {
fn default() -> Self {
Self::new()
}
const FIREBASE_MESSAGING_SCOPE: &str = "https://www.googleapis.com/auth/firebase.messaging";
#[cfg(feature = "dotenv")]
const ENV_VAR_FILE: &str = "GOOGLE_APPLICATION_CREDENTIALS";

/// An FCM v1 client that can be used to send messages to the FCM service. Can be constructed from a ServiceAccountKey using the `Client::builder()` method. The convenience methods `from_key()` and `new()` are also available.
///
/// Upon creation, the client will validate the provided ServiceAccountKey by requesting an initial access token and will return an error if invalid.
#[derive(Debug, Clone)]
pub struct Client {
http_client: HttpClient,
service_account: Arc<ServiceAccount>,
project_id: Arc<String>,
}

// will be used to wrap the message in a "message" field
#[derive(Serialize)]
struct MessageWrapper<'a> {
#[serde(rename = "message")]
Expand All @@ -32,141 +41,126 @@ impl MessageWrapper<'_> {
}

impl Client {
/// Get a new instance of Client.
pub fn new() -> Client {
let http_client = reqwest::ClientBuilder::new()
.pool_max_idle_per_host(usize::MAX)
.build()
.unwrap();

Client { http_client }
}

fn get_service_key_file_name(&self) -> Result<String, String> {
let key_path = match dotenv::var("GOOGLE_APPLICATION_CREDENTIALS") {
Ok(key_path) => key_path,
Err(err) => return Err(err.to_string()),
};

Ok(key_path)
pub fn builder() -> ClientBuilder {
ClientBuilder::new()
}

fn read_service_key_file(&self) -> Result<String, String> {
let key_path = self.get_service_key_file_name()?;

let private_key_content = match std::fs::read(key_path) {
Ok(content) => content,
Err(err) => return Err(err.to_string()),
};

Ok(String::from_utf8(private_key_content).unwrap())
}

fn read_service_key_file_json(&self) -> Result<serde_json::Value, String> {
let file_content = match self.read_service_key_file() {
Ok(content) => content,
Err(err) => return Err(err),
};

let json_content: serde_json::Value = match serde_json::from_str(&file_content) {
Ok(json) => json,
Err(err) => return Err(err.to_string()),
};

Ok(json_content)
/// Create a new Client using credentials from a file path specified in the GOOGLE_APPLICATION_CREDENTIALS environment variable.
#[cfg(feature = "dotenv")]
pub async fn new() -> Result<Client, DotEnvClientBuildError> {
let path = dotenv::var(ENV_VAR_FILE).map_err(DotEnvClientBuildError::DotEnv)?;
let bytes = std::fs::read(path).map_err(DotEnvClientBuildError::ReadFile)?;
let key = serde_json::from_slice::<ServiceAccountKey>(&bytes).map_err(DotEnvClientBuildError::ParseFile)?;
Self::from_key(key).await.map_err(DotEnvClientBuildError::ClientBuild)
}

fn get_project_id(&self) -> Result<String, String> {
let json_content = match self.read_service_key_file_json() {
Ok(json) => json,
Err(err) => return Err(err),
};

let project_id = match json_content["project_id"].as_str() {
Some(project_id) => project_id,
None => return Err("could not get project_id".to_string()),
};

Ok(project_id.to_string())
}

async fn get_auth_token(&self) -> Result<String, String> {
let tkn = match self.access_token().await {
Ok(tkn) => tkn,
Err(_) => return Err("could not get access token".to_string()),
};

Ok(tkn)
pub async fn from_key(key: ServiceAccountKey) -> Result<Client, ClientBuildError> {
Self::builder().build(key).await
}

async fn access_token(&self) -> Result<String, String> {
let scopes = vec!["https://www.googleapis.com/auth/firebase.messaging"];
let key_path = self.get_service_key_file_name()?;

let mut service_account = ServiceAccount::from_file(&key_path, scopes);
let access_token = match service_account.access_token().await {
Ok(access_token) => access_token,
Err(err) => return Err(err.to_string()),
};

let token_no_bearer = access_token.split(char::is_whitespace).collect::<Vec<&str>>()[1];

Ok(token_no_bearer.to_string())
}

pub async fn send(&self, message: Message) -> Result<FcmResponse, FcmError> {
pub async fn send(&self, message: Message) -> Result<Response, SendError> {
let fin = message.finalize();
let wrapper = MessageWrapper::new(&fin);
let payload = serde_json::to_vec(&wrapper).unwrap();

let project_id = match self.get_project_id() {
Ok(project_id) => project_id,
Err(err) => return Err(FcmError::ProjectIdError(err)),
};

let auth_token = match self.get_auth_token().await {
Ok(tkn) => tkn,
Err(err) => return Err(FcmError::ProjectIdError(err)),
};
let access_token = self
.service_account
.access_token()
.await
.map_err(SendError::AccessToken)?;

// https://firebase.google.com/docs/reference/fcm/rest/v1/projects.messages/send
let url = format!("https://fcm.googleapis.com/v1/projects/{}/messages:send", project_id);
let url = format!(
"https://fcm.googleapis.com/v1/projects/{}/messages:send",
self.project_id
);

let request = self
let response = self
.http_client
.post(&url)
.header("Content-Type", "application/json")
.bearer_auth(auth_token)
.body(Body::from(payload))
.build()?;

let response = self.http_client.execute(request).await?;
.bearer_auth(access_token.bearer_token)
.json(&wrapper)
.send()
.await
.map_err(SendError::HttpRequest)?;

let response_status = response.status();

let retry_after = response
.headers()
.get(RETRY_AFTER)
.and_then(|ra| ra.to_str().ok())
.and_then(|ra| ra.parse::<RetryAfter>().ok());
// let retry_after = response
// .headers()
// .get(RETRY_AFTER)
// .and_then(|ra| ra.to_str().ok())
// .and_then(|ra| ra.parse::<RetryAfter>().ok());

match response_status {
StatusCode::OK => {
let fcm_response: FcmResponse = response.json().await.unwrap();

match fcm_response.error {
Some(ErrorReason::Unavailable) => Err(FcmError::ServerError(retry_after)),
Some(ErrorReason::InternalServerError) => Err(FcmError::ServerError(retry_after)),
_ => Ok(fcm_response),
StatusCode::OK => response.json::<Response>().await.map_err(SendError::ResponseParse),
StatusCode::NOT_FOUND => {
let response = response
.json::<Error404Response>()
.await
.map_err(SendError::ResponseParse)?;
for detail in response.error.details.iter() {
if detail.typ == TYPE_FCM_ERROR && detail.error_code == ERROR_404_CODE_UNREGISTERED {
return Err(SendError::Unregistered);
}
}
Err(SendError::UnknownError404Response(response))
}
StatusCode::UNAUTHORIZED => Err(FcmError::Unauthorized),
StatusCode::BAD_REQUEST => {
let body = response.text().await.unwrap();
Err(FcmError::InvalidMessage(format!("Bad Request ({body}")))
}
status if status.is_server_error() => Err(FcmError::ServerError(retry_after)),
_ => Err(FcmError::InvalidMessage("Unknown Error".to_string())),
StatusCode::FORBIDDEN => Err(SendError::Forbidden),
// StatusCode::UNAUTHORIZED => Err(Error::Unauthorized),
// StatusCode::BAD_REQUEST => {
// let body = response.text().await.unwrap();
// Err(Error::InvalidMessage(format!("Bad Request ({body}")))
// }
// status if status.is_server_error() => Err(Error::ServerError(retry_after)),
// _ => Err(Error::InvalidMessage("Unknown Error".to_string())),
_ => Err(SendError::UnknownHttpResponse {
status: response_status,
body: response.text().await,
}),
}
}
}

pub struct ClientBuilder {
http_client: Option<HttpClient>,
}

impl ClientBuilder {
pub fn new() -> Self {
Self { http_client: None }
}

pub fn http_client(mut self, http_client: HttpClient) -> Self {
self.http_client = Some(http_client);
self
}

pub async fn build(self, key: ServiceAccountKey) -> Result<Client, ClientBuildError> {
let http_client = self.http_client.unwrap_or_default();
let project_id = key.project_id.clone();
let service_account = ServiceAccountBuilder::new()
.key(key)
.scopes(vec![FIREBASE_MESSAGING_SCOPE])
.http_client(http_client.clone())
.build()
.map_err(ClientBuildError::ServiceAccountBuild)?;

// Validate the key by requesting initial access token
let _access_token = service_account
.access_token()
.await
.map_err(ClientBuildError::GetAccessToken)?;

Ok(Client {
http_client,
project_id: Arc::new(project_id),
service_account: Arc::new(service_account),
})
}
}

impl Default for ClientBuilder {
fn default() -> Self {
Self::new()
}
}
Loading