172 lines
3.5 KiB
Go
172 lines
3.5 KiB
Go
package oauth2
|
|
|
|
import (
|
|
"context"
|
|
"io"
|
|
"mime"
|
|
"net/http"
|
|
"net/url"
|
|
"strings"
|
|
"sync"
|
|
|
|
"github.com/pkg/errors"
|
|
"go.uber.org/zap"
|
|
)
|
|
|
|
// Config holds the OAuth2 client settings used when requesting tokens.
type Config struct {
	// ClientId and ClientSecret are sent as the client_id / client_secret
	// form fields by ClientCredentialsIssuer.
	// NOTE(review): Go style would spell this ClientID; the exported name is
	// kept as-is for compatibility with existing callers.
	ClientId, ClientSecret string
	// AuthURL and TokenURL are presumably the authorization and token
	// endpoints; neither is read by the code in this file (the token URL is
	// passed to NewTokenRetriever separately) — confirm against callers.
	AuthURL, TokenURL string
	// Scopes, when non-empty, are joined with single spaces into the "scope"
	// form field by ClientCredentialsIssuer.
	Scopes []string
}
|
|
|
|
type TokenIssuer interface {
|
|
Issue(context.Context) (Token, error)
|
|
}
|
|
|
|
type ClientCredentialsIssuer struct {
|
|
cfg Config
|
|
retriever *TokenRetriever
|
|
}
|
|
|
|
func NewClientCredentialsIssuer(cfg Config, tr *TokenRetriever) *ClientCredentialsIssuer {
|
|
isr := ClientCredentialsIssuer{
|
|
cfg: cfg,
|
|
retriever: tr,
|
|
}
|
|
return &isr
|
|
}
|
|
|
|
func (isr *ClientCredentialsIssuer) Issue(ctx context.Context) (Token, error) {
|
|
data := url.Values{
|
|
"grant_type": {"client_credentials"},
|
|
"client_id": {isr.cfg.ClientId},
|
|
"client_secret": {isr.cfg.ClientSecret},
|
|
}
|
|
if len(isr.cfg.Scopes) > 0 {
|
|
data.Set("scope", strings.Join(isr.cfg.Scopes, " "))
|
|
}
|
|
|
|
return isr.retriever.Retrieve(ctx, data)
|
|
}
|
|
|
|
type TokenRetriever struct {
|
|
tokenURL string
|
|
httpClient *http.Client
|
|
timeNow timeNowFunc
|
|
logger *zap.SugaredLogger
|
|
}
|
|
|
|
func NewTokenRetriever(tokenURL string, httpClient *http.Client, timeNow timeNowFunc, logger *zap.SugaredLogger) *TokenRetriever {
|
|
tr := TokenRetriever{
|
|
tokenURL: tokenURL,
|
|
httpClient: httpClient,
|
|
timeNow: timeNow,
|
|
logger: logger,
|
|
}
|
|
return &tr
|
|
}
|
|
|
|
func (tr *TokenRetriever) Retrieve(ctx context.Context, data url.Values) (Token, error) {
|
|
req, err := http.NewRequestWithContext(
|
|
ctx,
|
|
http.MethodPost,
|
|
tr.tokenURL,
|
|
strings.NewReader(data.Encode()),
|
|
)
|
|
if err != nil {
|
|
return Token{}, errors.WithStack(err)
|
|
}
|
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
|
|
|
tr.logger.Infof("retrieve token with request: %+v", req)
|
|
|
|
resp, err := tr.httpClient.Do(req)
|
|
if err != nil {
|
|
return Token{}, errors.WithStack(err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
body, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
return Token{}, errors.WithStack(err)
|
|
}
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
return Token{}, errors.Errorf("Unexpected status code: %v, body: %s", resp.StatusCode, body)
|
|
}
|
|
|
|
contentType, _, _ := mime.ParseMediaType(resp.Header.Get("Content-Type"))
|
|
token, err := parseToken(body, contentType, tr.timeNow)
|
|
if err != nil {
|
|
return Token{}, err
|
|
}
|
|
|
|
return token, nil
|
|
}
|
|
|
|
type TokenRefresher struct {
|
|
retriever *TokenRetriever
|
|
}
|
|
|
|
func NewTokenRefresher(retriever *TokenRetriever) *TokenRefresher {
|
|
tr := TokenRefresher{
|
|
retriever: retriever,
|
|
}
|
|
return &tr
|
|
}
|
|
|
|
func (tr *TokenRefresher) Refresh(ctx context.Context, refreshToken string) (Token, error) {
|
|
if len(refreshToken) == 0 {
|
|
return Token{}, errors.Errorf("refresh token is empty!")
|
|
}
|
|
|
|
data := url.Values{
|
|
"grant_type": {"refresh_token"},
|
|
"refresh_token": {refreshToken},
|
|
}
|
|
return tr.retriever.Retrieve(ctx, data)
|
|
}
|
|
|
|
type ReuseIssuer struct {
|
|
new TokenIssuer
|
|
refresher *TokenRefresher
|
|
timeNow timeNowFunc
|
|
|
|
token Token
|
|
mu sync.Mutex
|
|
}
|
|
|
|
func NewReuseIssuer(new TokenIssuer, refresher *TokenRefresher, timeNow timeNowFunc) *ReuseIssuer {
|
|
isr := ReuseIssuer{
|
|
new: new,
|
|
refresher: refresher,
|
|
timeNow: timeNow,
|
|
}
|
|
return &isr
|
|
}
|
|
|
|
func (isr *ReuseIssuer) Issue(ctx context.Context) (Token, error) {
|
|
isr.mu.Lock()
|
|
defer isr.mu.Unlock()
|
|
|
|
if isr.token.Valid(isr.timeNow()) {
|
|
return isr.token, nil
|
|
}
|
|
|
|
token, err := isr.refresher.Refresh(ctx, isr.token.RefreshToken)
|
|
if err == nil {
|
|
isr.token = token
|
|
return isr.token, nil
|
|
}
|
|
|
|
token, err = isr.new.Issue(ctx)
|
|
if err != nil {
|
|
return Token{}, err
|
|
}
|
|
|
|
isr.token = token
|
|
return isr.token, nil
|
|
}
|