// See documentation here: // // # Sending notifications // // https://firebase.google.com/docs/cloud-messaging/send-message?hl=en // // # OAuth2.0 // // https://developers.google.com/identity/protocols/oauth2/service-account?hl=en package fcm import ( "bufio" "bytes" "context" "encoding/json" "fmt" "io" "io/ioutil" "mime" "mime/multipart" "net/http" "net/textproto" "sync" "git.bit5.ru/backend/errors" "github.com/go-logr/logr" "golang.org/x/oauth2" "go.opentelemetry.io/otel" "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/codes" ) const tracerName = "git.bit5.ru/backend/fcm" var tracer = otel.Tracer(tracerName) //----------------------------------------------------------------------------- const ( fcmErrorType = "type.googleapis.com/google.firebase.fcm.v1.FcmError" maxMessages = 500 multipartBoundary = "msg_boundary" BatchSendEndpoint = "https://fcm.googleapis.com/batch" ) var ( AuthScopes = []string{"https://www.googleapis.com/auth/firebase.messaging"} ) func MakeSendEndpoint(projectId string) string { return fmt.Sprintf("https://fcm.googleapis.com/v1/projects/%s/messages:send", projectId) } type Credentials struct { Type string `json:"type"` ProjectID string `json:"project_id"` PrivateKeyID string `json:"private_key_id"` PrivateKey string `json:"private_key"` ClientID string `json:"client_id"` ClientEmail string `json:"client_email"` AuthURL string `json:"auth_uri"` TokenURL string `json:"token_uri"` } func ReadCredentialsFromFile(filename string) (Credentials, error) { data, err := ioutil.ReadFile(filename) if err != nil { return Credentials{}, errors.WithStack(err) } var c Credentials if err := json.Unmarshal(data, &c); err != nil { return Credentials{}, errors.WithStack(err) } return c, nil } type ClientConfig struct { SendEndpoint string BatchSendEndpoint string } type Client struct { cfg ClientConfig ts oauth2.TokenSource hc *http.Client logger logr.Logger } func NewClient(cfg ClientConfig, ts oauth2.TokenSource, hc *http.Client, logger logr.Logger) *Client { return &Client{ cfg: cfg, ts: ts, hc: hc, logger: logger, } } func (c *Client) SendMessage(ctx context.Context, msg Message) (SendResponse, error) { ctx, span := tracer.Start(ctx, "Client.SendMessage") defer span.End() sendRequest := SendRequest{ ValidateOnly: false, Message: msg, } return c.doSendRequest(ctx, sendRequest, true) } func (c *Client) ValidateMessage(ctx context.Context, msg Message) (SendResponse, error) { ctx, span := tracer.Start(ctx, "Client.ValidateMessage") defer span.End() sendRequest := SendRequest{ ValidateOnly: true, Message: msg, } return c.doSendRequest(ctx, sendRequest, true) } func (c *Client) Send(ctx context.Context, message Message) (string, error) { ctx, span := tracer.Start(ctx, "Client.Send") defer span.End() sendRequest := SendRequest{ ValidateOnly: false, Message: message, } resp, err := c.doSendRequest(ctx, sendRequest, false) if err != nil { span.SetStatus(codes.Error, err.Error()) span.RecordError(err) return "", err } return resp.MessageName, nil } func (c *Client) Validate(ctx context.Context, message Message) (string, error) { ctx, span := tracer.Start(ctx, "Client.Validate") defer span.End() sendRequest := SendRequest{ ValidateOnly: true, Message: message, } resp, err := c.doSendRequest(ctx, sendRequest, false) if err != nil { span.SetStatus(codes.Error, err.Error()) span.RecordError(err) return "", err } return resp.MessageName, nil } func (c *Client) doSendRequest(ctx context.Context, req SendRequest, loggerEnabled bool) (SendResponse, error) { _, span := tracer.Start(ctx, "Client.doSendRequest") defer span.End() accessToken, err := c.ts.Token() if err != nil { span.SetStatus(codes.Error, err.Error()) span.RecordError(err) return SendResponse{}, err } data, err := json.Marshal(req) if err != nil { span.SetStatus(codes.Error, err.Error()) span.RecordError(err) return SendResponse{}, errors.WithStack(err) } span.SetAttributes( attribute.String("request.body", string(data)), ) if loggerEnabled { c.logger.Info("sending", "message", data) } request, err := http.NewRequest(http.MethodPost, c.cfg.SendEndpoint, bytes.NewReader(data)) if err != nil { span.SetStatus(codes.Error, err.Error()) span.RecordError(err) return SendResponse{}, errors.WithStack(err) } accessToken.SetAuthHeader(request) request.Header.Set("Content-Type", "application/json") response, err := c.hc.Do(request) if err != nil { span.SetStatus(codes.Error, err.Error()) span.RecordError(err) return SendResponse{}, errors.WithStack(err) } defer response.Body.Close() body, err := io.ReadAll(response.Body) if err != nil { span.SetStatus(codes.Error, err.Error()) span.RecordError(err) return SendResponse{}, errors.WithStack(err) } bodyStr := string(body) span.SetAttributes( attribute.Int("response.status_code", response.StatusCode), attribute.String("response.body", bodyStr), ) if response.StatusCode != http.StatusOK { err := errors.Errorf("Status is not OK. Status: %d. Body: %q.", response.StatusCode, bodyStr) span.SetStatus(codes.Error, err.Error()) span.RecordError(err) return SendResponse{}, err } var resp SendResponse if err := json.Unmarshal(body, &resp); err != nil { newErr := errors.Errorf("Can not parse send response as JSON. Response: %q. Error: %v", bodyStr, err) span.SetStatus(codes.Error, err.Error()) span.RecordError(err) return SendResponse{}, newErr } return resp, nil } // Deprecated: Use SendEach instead. func (c *Client) SendMessages(messages []Message) (MultiSendResponse, error) { return c.doSendMessages(messages, false) } func (c *Client) ValidateMessages(messages []Message) (MultiSendResponse, error) { return c.doSendMessages(messages, true) } func (c *Client) doSendMessages(messages []Message, validateOnly bool) (MultiSendResponse, error) { messageCount := len(messages) if messageCount == 0 { return MultiSendResponse{}, nil } if messageCount > maxMessages { return MultiSendResponse{}, errors.New(fmt.Sprintf("messages limit (%d) exceeded: %d", maxMessages, messageCount)) } accessToken, err := c.ts.Token() if err != nil { return MultiSendResponse{}, err } var body bytes.Buffer w := multipart.NewWriter(&body) w.SetBoundary(multipartBoundary) for index, msg := range messages { req := SendRequest{ ValidateOnly: validateOnly, Message: msg, } body, err := c.makeMessageRequest(req) if err != nil { return MultiSendResponse{}, err } if err := writePartTo(w, body, index); err != nil { return MultiSendResponse{}, err } } if err := w.Close(); err != nil { return MultiSendResponse{}, errors.WithStack(err) } request, err := http.NewRequest(http.MethodPost, c.cfg.BatchSendEndpoint, &body) if err != nil { return MultiSendResponse{}, errors.WithStack(err) } accessToken.SetAuthHeader(request) request.Header.Set("Content-Type", fmt.Sprintf(`multipart/mixed; boundary="%s"`, multipartBoundary)) response, err := c.hc.Do(request) if err != nil { return MultiSendResponse{}, errors.WithStack(err) } defer response.Body.Close() return c.makeMultiSendResponse(response, messageCount) } func (c *Client) makeMessageRequest(req SendRequest) ([]byte, error) { reqJson, err := json.Marshal(req) if err != nil { return nil, errors.WithStack(err) } request, err := http.NewRequest(http.MethodPost, c.cfg.SendEndpoint, bytes.NewBuffer(reqJson)) if err != nil { return nil, errors.WithStack(err) } request.Header.Set("Content-Type", "application/json; charset=UTF-8") request.Header.Set("User-Agent", "") var body bytes.Buffer if err := request.Write(&body); err != nil { return nil, errors.WithStack(err) } return body.Bytes(), nil } func writePartTo(w *multipart.Writer, bytes []byte, index int) error { header := make(textproto.MIMEHeader) header.Set("Content-Type", "application/http") header.Set("Content-Transfer-Encoding", "binary") header.Set("Content-ID", fmt.Sprintf("%d", index+1)) part, err := w.CreatePart(header) if err != nil { return errors.WithStack(err) } if _, err := part.Write(bytes); err != nil { return errors.WithStack(err) } return nil } func (c *Client) makeMultiSendResponse(response *http.Response, totalCount int) (MultiSendResponse, error) { responses := make([]SendResponse, 0, totalCount) var fails int _, params, err := mime.ParseMediaType(response.Header.Get("Content-Type")) if err != nil { return MultiSendResponse{}, errors.WithStack(err) } reader := multipart.NewReader(response.Body, params["boundary"]) for { part, err := reader.NextPart() if err == io.EOF { break } else if err != nil { return MultiSendResponse{}, errors.WithStack(err) } resp, err := makeSendResponseFromPart(part) if err != nil { return MultiSendResponse{}, err } responses = append(responses, resp) if resp.HasError() { c.logger.Info("fail", "error", fmt.Sprintf("%+v", *resp.Error)) fails++ } } return MultiSendResponse{ Responses: responses, Sent: totalCount - fails, Failed: fails, }, nil } func makeSendResponseFromPart(part *multipart.Part) (SendResponse, error) { response, err := http.ReadResponse(bufio.NewReader(part), nil) if err != nil { return SendResponse{}, errors.WithStack(err) } defer response.Body.Close() body, err := ioutil.ReadAll(response.Body) if err != nil { return SendResponse{}, errors.WithStack(err) } var resp SendResponse if err := json.Unmarshal(body, &resp); err != nil { return SendResponse{}, errors.WithMessagef(err, "response body: %s", body) } return resp, nil } func (c *Client) SendEach(ctx context.Context, messages []Message) (MessageMultiSendResponse, error) { ctx, span := tracer.Start(ctx, "Client.SendEach") defer span.End() return c.doSendEachInBatch(ctx, messages, false) } func (c *Client) doSendEachInBatch( ctx context.Context, messages []Message, validateOnly bool, ) (MessageMultiSendResponse, error) { ctx, span := tracer.Start(ctx, "Client.doSendEachInBatch") defer span.End() messageCount := len(messages) if messageCount == 0 { return MessageMultiSendResponse{}, nil } if messageCount > maxMessages { return MessageMultiSendResponse{}, errors.New(fmt.Sprintf("messages limit (%d) exceeded: %d", maxMessages, messageCount)) } var responses = make([]MessageSendResponse, len(messages)) var wg sync.WaitGroup for idx, m := range messages { //if err := validateMessage(m); err != nil { // return nil, fmt.Errorf("invalid message at index %d: %v", idx, err) //} wg.Add(1) go func(idx int, m Message, validateOnly bool, responses []MessageSendResponse) { defer wg.Done() var resp string var err error if validateOnly { resp, err = c.Validate(ctx, m) } else { resp, err = c.Send(ctx, m) } if err == nil { responses[idx] = MessageSendResponse{ Success: true, MessageID: resp, } } else { responses[idx] = MessageSendResponse{ Success: false, Error: err, } } }(idx, m, validateOnly, responses) } // Wait for all Validate/Send calls to finish wg.Wait() successCount := 0 for _, r := range responses { if r.Success { successCount++ } } return MessageMultiSendResponse{ Responses: responses, Sent: successCount, Failed: len(responses) - successCount, }, nil } type MessageSendResponse struct { Success bool MessageID string Error error } type MessageMultiSendResponse struct { Responses []MessageSendResponse Sent int Failed int } // Запрос на отправку пуш-оповещения. type SendRequest struct { // Flag for testing the request without actually delivering the message. ValidateOnly bool `json:"validate_only,omitempty"` Message Message `json:"message"` } type SendResponse struct { MessageName string `json:"name"` Error *SendError `json:"error"` } func (sr SendResponse) HasError() bool { return sr.Error != nil } type SendErrorCode string const ( SendErrorCode_UNSPECIFIED_ERROR SendErrorCode = "UNSPECIFIED_ERROR" SendErrorCode_UNREGISTERED SendErrorCode = "UNREGISTERED" SendErrorCode_SENDER_ID_MISMATCH SendErrorCode = "SENDER_ID_MISMATCH" SendErrorCode_QUOTA_EXCEEDED SendErrorCode = "QUOTA_EXCEEDED" SendErrorCode_THIRD_PARTY_AUTH_ERROR SendErrorCode = "THIRD_PARTY_AUTH_ERROR" ) type SendError struct { Code int `json:"code"` Message string `json:"message"` Status string `json:"status"` Details []struct { Type string `json:"@type"` ErrorCode SendErrorCode `json:"errorCode"` } `json:"details"` } func (se *SendError) IsUnregistered() bool { if se == nil { return false } for _, d := range se.Details { if d.Type == fcmErrorType && d.ErrorCode == SendErrorCode_UNREGISTERED { return true } } return false } type MultiSendResponse struct { Responses []SendResponse Sent int Failed int }