oauth2/oauth2.go

172 lines
3.5 KiB
Go

package oauth2
import (
"context"
"io"
"mime"
"net/http"
"net/url"
"strings"
"sync"
"github.com/pkg/errors"
"go.uber.org/zap"
)
// Config holds the OAuth2 client settings.
type Config struct {
	// ClientId and ClientSecret are sent as the client_id and client_secret
	// form parameters by ClientCredentialsIssuer.
	ClientId, ClientSecret string
	// AuthURL is presumably the authorization endpoint.
	// NOTE(review): not read anywhere in this file — confirm it is still used.
	AuthURL string
	// TokenURL is presumably the token endpoint.
	// NOTE(review): TokenRetriever takes its URL separately; verify callers
	// pass this value through.
	TokenURL string
	// Scopes, when non-empty, are joined with spaces into the "scope"
	// parameter by ClientCredentialsIssuer.
	Scopes []string
}
// TokenIssuer is satisfied by anything that can hand out a Token on demand,
// such as ClientCredentialsIssuer and ReuseIssuer.
type TokenIssuer interface {
	// Issue returns a token, or an error if none could be obtained.
	Issue(ctx context.Context) (Token, error)
}
// ClientCredentialsIssuer obtains tokens with the OAuth2 client credentials
// grant, passing the configured client id, secret and scopes to a
// TokenRetriever.
type ClientCredentialsIssuer struct {
	cfg       Config
	retriever *TokenRetriever
}

// NewClientCredentialsIssuer returns an issuer that authenticates with cfg
// and performs the HTTP exchange through tr.
func NewClientCredentialsIssuer(cfg Config, tr *TokenRetriever) *ClientCredentialsIssuer {
	return &ClientCredentialsIssuer{cfg: cfg, retriever: tr}
}

// Issue requests a token with grant_type=client_credentials. The client id
// and secret are placed in the form body. Scopes, if any, are joined with
// single spaces into one "scope" parameter, and the parameter is omitted
// entirely when there are none.
func (isr *ClientCredentialsIssuer) Issue(ctx context.Context) (Token, error) {
	form := make(url.Values, 4)
	form.Set("grant_type", "client_credentials")
	form.Set("client_id", isr.cfg.ClientId)
	form.Set("client_secret", isr.cfg.ClientSecret)
	if scopes := isr.cfg.Scopes; len(scopes) != 0 {
		form.Set("scope", strings.Join(scopes, " "))
	}
	return isr.retriever.Retrieve(ctx, form)
}
// TokenRetriever performs the HTTP exchange with a token endpoint and turns
// the response into a Token.
type TokenRetriever struct {
	tokenURL   string
	httpClient *http.Client
	timeNow    timeNowFunc
	logger     *zap.SugaredLogger
}

// NewTokenRetriever returns a TokenRetriever that sends its requests to
// tokenURL using httpClient. timeNow is passed to the token parser
// (presumably to compute expiry — confirm), and logger receives request
// logging.
//
// NOTE(review): httpClient should carry a Timeout unless every caller passes
// a deadline-bound context.
func NewTokenRetriever(tokenURL string, httpClient *http.Client, timeNow timeNowFunc, logger *zap.SugaredLogger) *TokenRetriever {
	return &TokenRetriever{
		tokenURL:   tokenURL,
		httpClient: httpClient,
		timeNow:    timeNow,
		logger:     logger,
	}
}
// Retrieve POSTs data as an application/x-www-form-urlencoded body to the
// configured token URL and parses the response into a Token.
//
// It returns an error in these cases:
//   - the request cannot be built or sent;
//   - the response body cannot be read;
//   - the status code is not 200 OK;
//   - the body exceeds maxTokenResponseBytes;
//   - parseToken rejects the body.
func (tr *TokenRetriever) Retrieve(ctx context.Context, data url.Values) (Token, error) {
	// Token responses are small. Cap what we buffer so a misbehaving endpoint
	// cannot make us allocate an unbounded amount of memory.
	const maxTokenResponseBytes = 1 << 20 // 1 MiB

	req, err := http.NewRequestWithContext(
		ctx,
		http.MethodPost,
		tr.tokenURL,
		strings.NewReader(data.Encode()),
	)
	if err != nil {
		return Token{}, errors.WithStack(err)
	}
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")

	// Log only the method and a redacted URL. Callers put credentials into
	// the form (client_secret, refresh_token), so the request itself (its
	// headers and body) must not be dumped into logs.
	tr.logger.Infof("retrieve token with request: %s %s", req.Method, req.URL.Redacted())

	resp, err := tr.httpClient.Do(req)
	if err != nil {
		return Token{}, errors.WithStack(err)
	}
	defer resp.Body.Close()

	// Read one byte past the cap so an oversized body is detected explicitly
	// instead of being silently truncated into something unparseable.
	body, err := io.ReadAll(io.LimitReader(resp.Body, maxTokenResponseBytes+1))
	if err != nil {
		return Token{}, errors.WithStack(err)
	}
	if resp.StatusCode != http.StatusOK {
		return Token{}, errors.Errorf("Unexpected status code: %v, body: %s", resp.StatusCode, body)
	}
	if len(body) > maxTokenResponseBytes {
		return Token{}, errors.Errorf("token response body exceeds %d bytes", maxTokenResponseBytes)
	}

	// A missing or malformed Content-Type is deliberately not fatal: the
	// (possibly empty) media type is handed to parseToken as-is.
	// NOTE(review): confirm parseToken copes with an empty content type.
	contentType, _, _ := mime.ParseMediaType(resp.Header.Get("Content-Type"))
	token, err := parseToken(body, contentType, tr.timeNow)
	if err != nil {
		return Token{}, err
	}
	return token, nil
}
// TokenRefresher exchanges a refresh token for a new Token
// (grant_type=refresh_token) through a TokenRetriever.
type TokenRefresher struct {
	retriever *TokenRetriever
}

// NewTokenRefresher returns a TokenRefresher that sends its requests through
// retriever.
func NewTokenRefresher(retriever *TokenRetriever) *TokenRefresher {
	return &TokenRefresher{retriever: retriever}
}

// Refresh requests a new token using refreshToken. An empty refreshToken is
// rejected up front, without touching the retriever.
//
// NOTE(review): only grant_type and refresh_token are sent, with no client
// credentials. Servers that require confidential clients to authenticate
// on refresh (RFC 6749 §6) will reject this request.
func (tr *TokenRefresher) Refresh(ctx context.Context, refreshToken string) (Token, error) {
	if refreshToken == "" {
		return Token{}, errors.New("refresh token is empty!")
	}
	form := url.Values{}
	form.Set("grant_type", "refresh_token")
	form.Set("refresh_token", refreshToken)
	return tr.retriever.Retrieve(ctx, form)
}
// ReuseIssuer wraps another TokenIssuer and caches the last token obtained.
// It hands the cached token back while it is still valid. Otherwise it tries
// a refresh before falling back to the wrapped issuer.
//
// ReuseIssuer embeds a sync.Mutex and must not be copied after first use.
// Always work with the pointer returned by NewReuseIssuer.
type ReuseIssuer struct {
	new       TokenIssuer     // source of brand-new tokens
	refresher *TokenRefresher // used to renew the cached token
	timeNow   timeNowFunc     // clock used to check the cached token's validity
	token     Token           // last token obtained; zero value until the first Issue
	mu        sync.Mutex      // held by Issue while reading and replacing token
}

// NewReuseIssuer returns a ReuseIssuer with an empty cache. It obtains fresh
// tokens from issuer, renews them with refresher, and asks timeNow for the
// current time when checking validity.
func NewReuseIssuer(issuer TokenIssuer, refresher *TokenRefresher, timeNow timeNowFunc) *ReuseIssuer {
	return &ReuseIssuer{
		new:       issuer,
		refresher: refresher,
		timeNow:   timeNow,
	}
}
// Issue returns the cached token while it is still valid at isr.timeNow().
// Otherwise it tries to renew the token with the cached refresh token. If
// there is no refresh token, or the refresh fails, it obtains a brand-new
// token from isr.new. In that fallback the refresh error is discarded, and
// only an error from isr.new is returned.
//
// The mutex is held for the whole call, including any network round-trip.
// Concurrent callers are therefore serialized, and a caller that waited on
// the lock re-checks the freshly stored token before contacting the server.
func (isr *ReuseIssuer) Issue(ctx context.Context) (Token, error) {
	isr.mu.Lock()
	defer isr.mu.Unlock()

	if isr.token.Valid(isr.timeNow()) {
		return isr.token, nil
	}

	// Only attempt a refresh when there is something to refresh with.
	// TokenRefresher.Refresh rejects an empty refresh token anyway.
	if prev := isr.token.RefreshToken; len(prev) > 0 {
		token, err := isr.refresher.Refresh(ctx, prev)
		if err == nil {
			// RFC 6749 §6: the server MAY omit refresh_token from a refresh
			// response. The client then keeps using the one it already has,
			// instead of losing it and forcing a full re-issue next time.
			if len(token.RefreshToken) == 0 {
				token.RefreshToken = prev
			}
			isr.token = token
			return isr.token, nil
		}
		// Refresh failed; fall through and obtain a new token instead.
	}

	token, err := isr.new.Issue(ctx)
	if err != nil {
		return Token{}, err
	}
	isr.token = token
	return isr.token, nil
}