409 lines
9.3 KiB
Go
409 lines
9.3 KiB
Go
// Package fcm is a client for the Firebase Cloud Messaging (FCM) HTTP v1 API.
//
// See the documentation here:
//
// # Sending notifications
//
// https://firebase.google.com/docs/cloud-messaging/send-message?hl=en
//
// # OAuth 2.0
//
// https://developers.google.com/identity/protocols/oauth2/service-account?hl=en
package fcm
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"os"
|
|
"sync"
|
|
|
|
"git.bit5.ru/backend/errors"
|
|
|
|
"github.com/go-logr/logr"
|
|
"golang.org/x/oauth2"
|
|
|
|
"go.opentelemetry.io/otel"
|
|
"go.opentelemetry.io/otel/attribute"
|
|
"go.opentelemetry.io/otel/codes"
|
|
)
|
|
|
|
const tracerName = "git.bit5.ru/backend/fcm"
|
|
|
|
var tracer = otel.Tracer(tracerName)
|
|
|
|
//-----------------------------------------------------------------------------
|
|
|
|
const (
	// fcmErrorType is the "@type" value that marks an FCM-specific entry in
	// SendError.Details.
	fcmErrorType = "type.googleapis.com/google.firebase.fcm.v1.FcmError"

	// maxMessages is the largest number of messages a single batch call
	// accepts; larger batches are rejected before anything is sent.
	maxMessages = 500

	// multipartBoundary is a MIME multipart boundary string.
	// NOTE(review): not referenced in the visible code; presumably kept for
	// multipart batch requests — confirm before removing.
	multipartBoundary = "msg_boundary"
)

// AuthScopes lists the OAuth2 scope(s) requested for the FCM messaging API.
var AuthScopes = []string{
	"https://www.googleapis.com/auth/firebase.messaging",
}
|
|
|
|
// MakeSendEndpoint returns the FCM HTTP v1 "messages:send" URL for the given
// Firebase project ID. The ID is inserted verbatim (no escaping).
func MakeSendEndpoint(projectID string) string {
	const endpointFormat = "https://fcm.googleapis.com/v1/projects/%s/messages:send"
	return fmt.Sprintf(endpointFormat, projectID)
}
|
|
|
|
// Credentials holds the fields this package decodes from a service-account
// JSON key file (see ReadCredentialsFromFile).
type Credentials struct {
	// Type is the key file's "type" field.
	Type string `json:"type"`

	// Project and private-key identification.
	ProjectID    string `json:"project_id"`
	PrivateKeyID string `json:"private_key_id"`
	PrivateKey   string `json:"private_key"`

	// Identity of the service account.
	ClientID    string `json:"client_id"`
	ClientEmail string `json:"client_email"`

	// OAuth2 endpoints. The JSON keys end in "_uri" even though the Go field
	// names use "URL".
	AuthURL  string `json:"auth_uri"`
	TokenURL string `json:"token_uri"`
}
|
|
|
|
func ReadCredentialsFromFile(filename string) (Credentials, error) {
|
|
data, err := os.ReadFile(filename)
|
|
if err != nil {
|
|
return Credentials{}, errors.WithStack(err)
|
|
}
|
|
|
|
var c Credentials
|
|
if err := json.Unmarshal(data, &c); err != nil {
|
|
return Credentials{}, errors.WithStack(err)
|
|
}
|
|
|
|
return c, nil
|
|
}
|
|
|
|
// ClientConfig carries the settings a Client is built with.
type ClientConfig struct {
	// SendEndpoint is the full URL that send and validate requests are POSTed
	// to, for example a value produced by MakeSendEndpoint.
	SendEndpoint string
}
|
|
|
|
type Client struct {
|
|
cfg ClientConfig
|
|
ts oauth2.TokenSource
|
|
hc *http.Client
|
|
logger logr.Logger
|
|
}
|
|
|
|
func NewClient(projectId string, cfg ClientConfig, ts oauth2.TokenSource, hc *http.Client, logger logr.Logger) *Client {
|
|
return &Client{
|
|
cfg: cfg,
|
|
ts: ts,
|
|
hc: hc,
|
|
logger: logger,
|
|
}
|
|
}
|
|
|
|
func (c *Client) SendMessage(ctx context.Context, msg Message) (SendResponse, error) {
|
|
|
|
ctx, span := tracer.Start(ctx, "Client.SendMessage")
|
|
defer span.End()
|
|
|
|
sendRequest := SendRequest{
|
|
ValidateOnly: false,
|
|
Message: msg,
|
|
}
|
|
|
|
return c.doSendRequest(ctx, sendRequest, true)
|
|
}
|
|
|
|
func (c *Client) ValidateMessage(ctx context.Context, msg Message) (SendResponse, error) {
|
|
|
|
ctx, span := tracer.Start(ctx, "Client.ValidateMessage")
|
|
defer span.End()
|
|
|
|
sendRequest := SendRequest{
|
|
ValidateOnly: true,
|
|
Message: msg,
|
|
}
|
|
|
|
return c.doSendRequest(ctx, sendRequest, true)
|
|
}
|
|
|
|
func (c *Client) Send(ctx context.Context, message Message) (string, error) {
|
|
|
|
ctx, span := tracer.Start(ctx, "Client.Send")
|
|
defer span.End()
|
|
|
|
sendRequest := SendRequest{
|
|
ValidateOnly: false,
|
|
Message: message,
|
|
}
|
|
|
|
resp, err := c.doSendRequest(ctx, sendRequest, false)
|
|
if err != nil {
|
|
span.SetStatus(codes.Error, err.Error())
|
|
span.RecordError(err)
|
|
return "", err
|
|
}
|
|
|
|
return resp.MessageName, nil
|
|
}
|
|
|
|
func (c *Client) Validate(ctx context.Context, message Message) (string, error) {
|
|
|
|
ctx, span := tracer.Start(ctx, "Client.Validate")
|
|
defer span.End()
|
|
|
|
sendRequest := SendRequest{
|
|
ValidateOnly: true,
|
|
Message: message,
|
|
}
|
|
|
|
resp, err := c.doSendRequest(ctx, sendRequest, false)
|
|
if err != nil {
|
|
span.SetStatus(codes.Error, err.Error())
|
|
span.RecordError(err)
|
|
return "", err
|
|
}
|
|
|
|
return resp.MessageName, nil
|
|
}
|
|
|
|
func (c *Client) doSendRequest(ctx context.Context, req SendRequest, loggerEnabled bool) (SendResponse, error) {
|
|
|
|
_, span := tracer.Start(ctx, "Client.doSendRequest")
|
|
defer span.End()
|
|
|
|
accessToken, err := c.ts.Token()
|
|
if err != nil {
|
|
span.SetStatus(codes.Error, err.Error())
|
|
span.RecordError(err)
|
|
return SendResponse{}, err
|
|
}
|
|
|
|
data, err := json.Marshal(req)
|
|
if err != nil {
|
|
span.SetStatus(codes.Error, err.Error())
|
|
span.RecordError(err)
|
|
return SendResponse{}, errors.WithStack(err)
|
|
}
|
|
span.SetAttributes(
|
|
attribute.String("request.body", string(data)),
|
|
)
|
|
if loggerEnabled {
|
|
c.logger.Info("sending", "message", data)
|
|
}
|
|
|
|
request, err := http.NewRequest(http.MethodPost, c.cfg.SendEndpoint, bytes.NewReader(data))
|
|
if err != nil {
|
|
span.SetStatus(codes.Error, err.Error())
|
|
span.RecordError(err)
|
|
return SendResponse{}, errors.WithStack(err)
|
|
}
|
|
|
|
accessToken.SetAuthHeader(request)
|
|
request.Header.Set("Content-Type", "application/json")
|
|
|
|
response, err := c.hc.Do(request)
|
|
if err != nil {
|
|
span.SetStatus(codes.Error, err.Error())
|
|
span.RecordError(err)
|
|
return SendResponse{}, errors.WithStack(err)
|
|
}
|
|
defer response.Body.Close()
|
|
|
|
body, err := io.ReadAll(response.Body)
|
|
if err != nil {
|
|
span.SetStatus(codes.Error, err.Error())
|
|
span.RecordError(err)
|
|
return SendResponse{}, errors.WithStack(err)
|
|
}
|
|
|
|
bodyStr := string(body)
|
|
|
|
span.SetAttributes(
|
|
attribute.Int("response.status_code", response.StatusCode),
|
|
attribute.String("response.body", bodyStr),
|
|
)
|
|
|
|
if response.StatusCode != http.StatusOK {
|
|
err := errors.Errorf("Status is not OK. Status: %d. Body: %q.", response.StatusCode, bodyStr)
|
|
span.SetStatus(codes.Error, err.Error())
|
|
span.RecordError(err)
|
|
return SendResponse{}, err
|
|
}
|
|
|
|
var resp SendResponse
|
|
if err := json.Unmarshal(body, &resp); err != nil {
|
|
newErr := errors.Errorf("Can not parse send response as JSON. Response: %q. Error: %v", bodyStr, err)
|
|
span.SetStatus(codes.Error, err.Error())
|
|
span.RecordError(err)
|
|
return SendResponse{}, newErr
|
|
}
|
|
|
|
return resp, nil
|
|
}
|
|
|
|
func (c *Client) SendEach(ctx context.Context, messages []Message) (MessageMultiSendResponse, error) {
|
|
|
|
ctx, span := tracer.Start(ctx, "Client.SendEach")
|
|
defer span.End()
|
|
|
|
resp, err := c.doSendEachInBatch(ctx, messages, false)
|
|
if err != nil {
|
|
span.SetStatus(codes.Error, err.Error())
|
|
span.RecordError(err)
|
|
return MessageMultiSendResponse{}, err
|
|
}
|
|
|
|
if resp.Failed > 0 {
|
|
span.SetStatus(codes.Error, "Some notifications not sent.")
|
|
}
|
|
|
|
return resp, nil
|
|
}
|
|
|
|
func (c *Client) doSendEachInBatch(
|
|
ctx context.Context,
|
|
messages []Message,
|
|
validateOnly bool,
|
|
) (MessageMultiSendResponse, error) {
|
|
|
|
ctx, span := tracer.Start(ctx, "Client.doSendEachInBatch")
|
|
defer span.End()
|
|
|
|
messageCount := len(messages)
|
|
if messageCount == 0 {
|
|
return MessageMultiSendResponse{}, nil
|
|
}
|
|
|
|
if messageCount > maxMessages {
|
|
err := errors.New(fmt.Sprintf("messages limit (%d) exceeded: %d", maxMessages, messageCount))
|
|
span.SetStatus(codes.Error, err.Error())
|
|
span.RecordError(err)
|
|
return MessageMultiSendResponse{}, err
|
|
}
|
|
|
|
var responses = make([]MessageSendResponse, len(messages))
|
|
var wg sync.WaitGroup
|
|
|
|
for idx, m := range messages {
|
|
//if err := validateMessage(m); err != nil {
|
|
// return nil, fmt.Errorf("invalid message at index %d: %v", idx, err)
|
|
//}
|
|
wg.Add(1)
|
|
go func(idx int, m Message, validateOnly bool, responses []MessageSendResponse) {
|
|
defer wg.Done()
|
|
var resp string
|
|
var err error
|
|
if validateOnly {
|
|
resp, err = c.Validate(ctx, m)
|
|
} else {
|
|
resp, err = c.Send(ctx, m)
|
|
}
|
|
if err == nil {
|
|
responses[idx] = MessageSendResponse{
|
|
Success: true,
|
|
MessageID: resp,
|
|
}
|
|
} else {
|
|
span.SetStatus(codes.Error, "Some notifications not sent.")
|
|
responses[idx] = MessageSendResponse{
|
|
Success: false,
|
|
Error: err,
|
|
}
|
|
}
|
|
}(idx, m, validateOnly, responses)
|
|
}
|
|
// Wait for all Validate/Send calls to finish
|
|
wg.Wait()
|
|
|
|
sentAmount := 0
|
|
for _, r := range responses {
|
|
if r.Success {
|
|
sentAmount++
|
|
}
|
|
}
|
|
|
|
failedAmount := len(responses) - sentAmount
|
|
|
|
span.SetAttributes(
|
|
attribute.Int("resp.total", len(responses)),
|
|
attribute.Int("resp.sent_amount", sentAmount),
|
|
attribute.Int("resp.failed_amount", failedAmount),
|
|
)
|
|
|
|
resp := MessageMultiSendResponse{
|
|
Responses: responses,
|
|
Sent: sentAmount,
|
|
Failed: failedAmount,
|
|
}
|
|
|
|
return resp, nil
|
|
}
|
|
|
|
// MessageSendResponse is the outcome of sending (or validating) one message
// in a batch.
type MessageSendResponse struct {
	Success   bool   // true when the message was accepted
	MessageID string // message name returned by FCM; set only on success
	Error     error  // the send/validate error; set only on failure
}
|
|
|
|
type MessageMultiSendResponse struct {
|
|
Responses []MessageSendResponse
|
|
Sent int
|
|
Failed int
|
|
}
|
|
|
|
// Запрос на отправку пуш-оповещения.
|
|
type SendRequest struct {
|
|
// Flag for testing the request without actually delivering the message.
|
|
ValidateOnly bool `json:"validate_only,omitempty"`
|
|
|
|
Message Message `json:"message"`
|
|
}
|
|
|
|
type SendResponse struct {
|
|
MessageName string `json:"name"`
|
|
|
|
Error *SendError `json:"error"`
|
|
}
|
|
|
|
func (sr SendResponse) HasError() bool {
|
|
return sr.Error != nil
|
|
}
|
|
|
|
// SendErrorCode is an FCM error code, as carried in the "errorCode" field of
// an entry in SendError.Details.
type SendErrorCode string

// The SendErrorCode values named by this package.
const (
	SendErrorCode_UNSPECIFIED_ERROR      SendErrorCode = "UNSPECIFIED_ERROR"
	SendErrorCode_UNREGISTERED           SendErrorCode = "UNREGISTERED"
	SendErrorCode_SENDER_ID_MISMATCH     SendErrorCode = "SENDER_ID_MISMATCH"
	SendErrorCode_QUOTA_EXCEEDED         SendErrorCode = "QUOTA_EXCEEDED"
	SendErrorCode_THIRD_PARTY_AUTH_ERROR SendErrorCode = "THIRD_PARTY_AUTH_ERROR"
)
|
|
|
|
type SendError struct {
|
|
Code int `json:"code"`
|
|
Message string `json:"message"`
|
|
Status string `json:"status"`
|
|
Details []struct {
|
|
Type string `json:"@type"`
|
|
ErrorCode SendErrorCode `json:"errorCode"`
|
|
} `json:"details"`
|
|
}
|
|
|
|
func (se *SendError) IsUnregistered() bool {
|
|
if se == nil {
|
|
return false
|
|
}
|
|
|
|
for _, d := range se.Details {
|
|
if d.Type == fcmErrorType && d.ErrorCode == SendErrorCode_UNREGISTERED {
|
|
return true
|
|
}
|
|
}
|
|
|
|
return false
|
|
}
|
|
|
|
type MultiSendResponse struct {
|
|
Responses []SendResponse
|
|
Sent int
|
|
Failed int
|
|
}
|