package oauth2

import (
	"context"
	"io"
	"mime"
	"net/http"
	"net/url"
	"strings"
	"sync"

	"github.com/pkg/errors"
	"go.uber.org/zap"
)

// Config holds the OAuth 2.0 client settings used to request tokens.
type Config struct {
	// ClientId is the OAuth 2.0 client identifier, sent as "client_id".
	ClientId string
	// ClientSecret is the OAuth 2.0 client secret, sent as "client_secret".
	ClientSecret string
	// AuthURL is the authorization endpoint. It is not read in this file.
	AuthURL string
	// TokenURL is the token endpoint. It is not read in this file; presumably
	// callers pass it to NewTokenRetriever — TODO confirm.
	TokenURL string
	// Scopes, when non-empty, are sent space-separated as "scope".
	Scopes []string
}

// TokenIssuer obtains a Token.
type TokenIssuer interface {
	Issue(context.Context) (Token, error)
}

// Compile-time checks that the issuers in this file satisfy TokenIssuer.
var (
	_ TokenIssuer = (*ClientCredentialsIssuer)(nil)
	_ TokenIssuer = (*ReuseIssuer)(nil)
)

// ClientCredentialsIssuer issues tokens using the OAuth 2.0 client
// credentials grant (RFC 6749, section 4.4).
type ClientCredentialsIssuer struct {
	cfg       Config
	retriever *TokenRetriever
}

// NewClientCredentialsIssuer returns an issuer that requests tokens for the
// client described by cfg, sending the requests through tr.
func NewClientCredentialsIssuer(cfg Config, tr *TokenRetriever) *ClientCredentialsIssuer {
	return &ClientCredentialsIssuer{
		cfg:       cfg,
		retriever: tr,
	}
}

// Issue requests a new token with the "client_credentials" grant. The client
// id and secret are sent as form parameters, and the configured scopes, if
// any, are joined with spaces into a single "scope" parameter.
func (isr *ClientCredentialsIssuer) Issue(ctx context.Context) (Token, error) {
	data := url.Values{
		"grant_type":    {"client_credentials"},
		"client_id":     {isr.cfg.ClientId},
		"client_secret": {isr.cfg.ClientSecret},
	}
	if len(isr.cfg.Scopes) > 0 {
		data.Set("scope", strings.Join(isr.cfg.Scopes, " "))
	}
	return isr.retriever.Retrieve(ctx, data)
}

// TokenRetriever POSTs form-encoded requests to a token endpoint and parses
// the token from the response.
type TokenRetriever struct {
	tokenURL   string
	httpClient *http.Client
	timeNow    timeNowFunc
	logger     *zap.SugaredLogger
}

// NewTokenRetriever returns a retriever that POSTs to tokenURL with
// httpClient. timeNow is handed to the token parser, and logger receives a
// line for each request.
func NewTokenRetriever(tokenURL string, httpClient *http.Client, timeNow timeNowFunc, logger *zap.SugaredLogger) *TokenRetriever {
	return &TokenRetriever{
		tokenURL:   tokenURL,
		httpClient: httpClient,
		timeNow:    timeNow,
		logger:     logger,
	}
}

// Retrieve POSTs data, form-encoded, to the token endpoint and parses the
// response body into a Token.
//
// Any status other than 200 OK is returned as an error that includes the
// status code and the response body. Transport and read errors are wrapped
// with a stack trace.
func (tr *TokenRetriever) Retrieve(ctx context.Context, data url.Values) (Token, error) {
	req, err := http.NewRequestWithContext(
		ctx,
		http.MethodPost,
		tr.tokenURL,
		strings.NewReader(data.Encode()),
	)
	if err != nil {
		return Token{}, errors.WithStack(err)
	}
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")

	// Log only the method and endpoint. Dumping the whole *http.Request with
	// %+v prints every header and internal field, which is noisy and would
	// expose any credential carried in a header. Redacted hides a URL password.
	tr.logger.Infof("retrieve token with request: %s %s", req.Method, req.URL.Redacted())

	resp, err := tr.httpClient.Do(req)
	if err != nil {
		return Token{}, errors.WithStack(err)
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return Token{}, errors.WithStack(err)
	}
	if resp.StatusCode != http.StatusOK {
		return Token{}, errors.Errorf("Unexpected status code: %v, body: %s", resp.StatusCode, body)
	}

	// The error is ignored on purpose. A missing or unparsable Content-Type
	// yields an empty media type, and invalid parameters still yield the base
	// type; either way the value is handed to parseToken as-is.
	// NOTE(review): presumably parseToken has a fallback for "" — confirm.
	contentType, _, _ := mime.ParseMediaType(resp.Header.Get("Content-Type"))
	token, err := parseToken(body, contentType, tr.timeNow)
	if err != nil {
		return Token{}, err
	}
	return token, nil
}

// TokenRefresher exchanges a refresh token for a new token using the
// "refresh_token" grant (RFC 6749, section 6).
type TokenRefresher struct {
	retriever *TokenRetriever
}

// NewTokenRefresher returns a refresher that sends its requests through
// retriever.
func NewTokenRefresher(retriever *TokenRetriever) *TokenRefresher {
	return &TokenRefresher{
		retriever: retriever,
	}
}

// Refresh requests a new token for refreshToken. It fails without making a
// request when refreshToken is empty.
//
// NOTE(review): no client credentials are sent with this request. RFC 6749
// section 6 requires confidential clients to authenticate on refresh —
// confirm the target server accepts unauthenticated refreshes.
func (tr *TokenRefresher) Refresh(ctx context.Context, refreshToken string) (Token, error) {
	if len(refreshToken) == 0 {
		return Token{}, errors.Errorf("refresh token is empty!")
	}
	data := url.Values{
		"grant_type":    {"refresh_token"},
		"refresh_token": {refreshToken},
	}
	return tr.retriever.Retrieve(ctx, data)
}

// ReuseIssuer caches a token and returns it for as long as it reports itself
// valid at the current time.
//
// When the cached token is no longer valid, ReuseIssuer first tries to
// refresh it. If that is not possible or fails, it issues a new token
// through the wrapped TokenIssuer. Calls are serialized by a mutex, so only
// one caller at a time fetches a replacement token.
type ReuseIssuer struct {
	new       TokenIssuer
	refresher *TokenRefresher
	timeNow   timeNowFunc
	token     Token      // guarded by mu
	mu        sync.Mutex // guards token
}

// NewReuseIssuer returns a caching issuer. It falls back to issuer for
// brand-new tokens, uses refresher to renew expired ones, and reads the
// current time from timeNow.
func NewReuseIssuer(issuer TokenIssuer, refresher *TokenRefresher, timeNow timeNowFunc) *ReuseIssuer {
	return &ReuseIssuer{
		new:       issuer,
		refresher: refresher,
		timeNow:   timeNow,
	}
}

// Issue returns the cached token while it is valid. Otherwise it refreshes
// the token or, failing that, issues a new one, and caches the result. Only
// an error from the fallback issuer is returned; a refresh failure just
// triggers the fallback.
func (isr *ReuseIssuer) Issue(ctx context.Context) (Token, error) {
	isr.mu.Lock()
	defer isr.mu.Unlock()

	if isr.token.Valid(isr.timeNow()) {
		return isr.token, nil
	}

	// Refresh rejects an empty refresh token without any I/O, so skipping the
	// call in that case is equivalent and makes the intent explicit.
	if refreshToken := isr.token.RefreshToken; len(refreshToken) > 0 {
		token, err := isr.refresher.Refresh(ctx, refreshToken)
		if err == nil {
			// RFC 6749 section 6: the server MAY omit a new refresh token, in
			// which case the previous one remains usable. Keep it so the next
			// expiry can be refreshed instead of forcing a full re-issue.
			if len(token.RefreshToken) == 0 {
				token.RefreshToken = refreshToken
			}
			isr.token = token
			return isr.token, nil
		}
		// A failed refresh is not fatal: fall through to a fresh issuance.
	}

	token, err := isr.new.Issue(ctx)
	if err != nil {
		return Token{}, err
	}
	isr.token = token
	return isr.token, nil
}