diff --git a/fcm.go b/fcm.go index 93837b4..d052474 100644 --- a/fcm.go +++ b/fcm.go @@ -1,10 +1,10 @@ // See documentation here: // -// Sending notifications +// # Sending notifications // // https://firebase.google.com/docs/cloud-messaging/send-message?hl=en // -// OAuth2.0 +// # OAuth2.0 // // https://developers.google.com/identity/protocols/oauth2/service-account?hl=en package fcm @@ -20,9 +20,10 @@ import ( "mime/multipart" "net/http" "net/textproto" + "sync" - "github.com/go-logr/logr" "git.bit5.ru/backend/errors" + "github.com/go-logr/logr" "golang.org/x/oauth2" ) @@ -98,7 +99,7 @@ func (c *Client) SendMessage(msg Message) (SendResponse, error) { Message: msg, } - return c.doSendRequest(sendRequest) + return c.doSendRequest(sendRequest, true) } func (c *Client) ValidateMessage(msg Message) (SendResponse, error) { @@ -107,10 +108,10 @@ func (c *Client) ValidateMessage(msg Message) (SendResponse, error) { Message: msg, } - return c.doSendRequest(sendRequest) + return c.doSendRequest(sendRequest, true) } -func (c *Client) doSendRequest(req SendRequest) (SendResponse, error) { +func (c *Client) doSendRequest(req SendRequest, loggerEnabled bool) (SendResponse, error) { accessToken, err := c.ts.Token() if err != nil { return SendResponse{}, err @@ -120,7 +121,9 @@ func (c *Client) doSendRequest(req SendRequest) (SendResponse, error) { if err != nil { return SendResponse{}, errors.WithStack(err) } - c.logger.Info("sending", "message", data) + if loggerEnabled { + c.logger.Info("sending", "message", data) + } request, err := http.NewRequest(http.MethodPost, c.cfg.SendEndpoint, bytes.NewReader(data)) if err != nil { @@ -153,6 +156,7 @@ func (c *Client) doSendRequest(req SendRequest) (SendResponse, error) { return resp, nil } +// Deprecated: Use SendEach instead. func (c *Client) SendMessages(messages []Message) (MultiSendResponse, error) { return c.doSendMessages(messages, false) } @@ -280,7 +284,7 @@ func (c *Client) makeMultiSendResponse(response *http.Response, totalCount int) responses = append(responses, resp) if resp.HasError() { - c.logger.Info("fail", "error", fmt.Sprintf("%+v", *resp.Error)) + c.logger.Info("fail", "error", fmt.Sprintf("%+v", *resp.Error)) fails++ } } @@ -312,6 +316,98 @@ func makeSendResponseFromPart(part *multipart.Part) (SendResponse, error) { return resp, nil } +func (c *Client) SendEach(messages []Message) (MessageMultiSendResponse, error) { + return c.doSendEachInBatch(messages, false) +} + +func (c *Client) doSendEachInBatch(messages []Message, validateOnly bool) (MessageMultiSendResponse, error) { + messageCount := len(messages) + if messageCount == 0 { + return MessageMultiSendResponse{}, nil + } + if messageCount > maxMessages { + return MessageMultiSendResponse{}, errors.New(fmt.Sprintf("messages limit (%d) exceeded: %d", maxMessages, messageCount)) + } + + var responses = make([]MessageSendResponse, len(messages)) + var wg sync.WaitGroup + + for idx, m := range messages { + //if err := validateMessage(m); err != nil { + // return nil, fmt.Errorf("invalid message at index %d: %v", idx, err) + //} + wg.Add(1) + go func(idx int, m Message, validateOnly bool, responses []MessageSendResponse) { + defer wg.Done() + var resp string + var err error + if validateOnly { + resp, err = c.Validate(m) + } else { + resp, err = c.Send(m) + } + if err == nil { + responses[idx] = MessageSendResponse{ + Success: true, + MessageID: resp, + } + } else { + responses[idx] = MessageSendResponse{ + Success: false, + Error: err, + } + } + }(idx, m, validateOnly, responses) + } + // Wait for all Validate/Send calls to finish + wg.Wait() + + successCount := 0 + for _, r := range responses { + if r.Success { + successCount++ + } + } + + return MessageMultiSendResponse{ + Responses: responses, + Sent: successCount, + Failed: len(responses) - successCount, + }, nil +} + +func (c *Client) Send(message Message) (string, error) { + sendRequest := SendRequest{ + ValidateOnly: false, + Message: message, + } + resp, err := c.doSendRequest(sendRequest, false) + + return resp.MessageName, err +} + +func (c *Client) Validate(message Message) (string, error) { + sendRequest := SendRequest{ + ValidateOnly: true, + Message: message, + } + resp, err := c.doSendRequest(sendRequest, false) + + return resp.MessageName, err +} + +type MessageSendResponse struct { + Success bool + MessageID string + Error error +} + +type MessageMultiSendResponse struct { + Responses []MessageSendResponse + Sent int + Failed int +} + type SendRequest struct { // Flag for testing the request without actually delivering the message. ValidateOnly bool `json:"validate_only,omitempty"`