commit 997188e3107cad134bcc27538bfa0367eb710c79 Author: Pavel Merzlyakov Date: Tue Nov 15 10:22:46 2022 +0300 initial commit diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..e8da7a9 --- /dev/null +++ b/go.mod @@ -0,0 +1,5 @@ +module git.bit5.ru/backend/oauth2 + +go 1.18 + +require github.com/pkg/errors v0.9.1 diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..7c401c3 --- /dev/null +++ b/go.sum @@ -0,0 +1,2 @@ +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= diff --git a/oauth2.go b/oauth2.go new file mode 100644 index 0000000..cece370 --- /dev/null +++ b/oauth2.go @@ -0,0 +1,166 @@ +package oauth2 + +import ( + "context" + "io" + "mime" + "net/http" + "net/url" + "strings" + "sync" + + "github.com/pkg/errors" +) + +type Config struct { + ClientId string + ClientSecret string + AuthURL string + TokenURL string + Scopes []string +} + +type TokenIssuer interface { + Issue(context.Context) (Token, error) +} + +type ClientCredentialsIssuer struct { + cfg Config + retriever *TokenRetriever +} + +func NewClientCredentialsIssuer(cfg Config, tr *TokenRetriever) *ClientCredentialsIssuer { + isr := ClientCredentialsIssuer{ + cfg: cfg, + retriever: tr, + } + return &isr +} + +func (isr *ClientCredentialsIssuer) Issue(ctx context.Context) (Token, error) { + data := url.Values{ + "grant_type": {"client_credentials"}, + "client_id": {isr.cfg.ClientId}, + "client_secret": {isr.cfg.ClientSecret}, + } + if len(isr.cfg.Scopes) > 0 { + data.Set("scope", strings.Join(isr.cfg.Scopes, " ")) + } + + return isr.retriever.Retrieve(ctx, data) +} + +type TokenRetriever struct { + tokenURL string + httpClient *http.Client + timeNow timeNowFunc +} + +func NewTokenRetriever(tokenURL string, httpClient *http.Client, timeNow timeNowFunc) *TokenRetriever { + tr := TokenRetriever{ + tokenURL: tokenURL, + httpClient: httpClient, + timeNow: timeNow, + } + return &tr +} + +func (tr *TokenRetriever) Retrieve(ctx context.Context, data url.Values) (Token, error) { + req, err := http.NewRequestWithContext( + ctx, + http.MethodPost, + tr.tokenURL, + strings.NewReader(data.Encode()), + ) + if err != nil { + return Token{}, errors.WithStack(err) + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + resp, err := tr.httpClient.Do(req) + if err != nil { + return Token{}, errors.WithStack(err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return Token{}, errors.WithStack(err) + } + + if resp.StatusCode != http.StatusOK { + return Token{}, errors.Errorf("Unexpected status code: %v, body: %s", resp.StatusCode, body) + } + + contentType, _, _ := mime.ParseMediaType(resp.Header.Get("Content-Type")) + token, err := parseToken(body, contentType, tr.timeNow) + if err != nil { + return Token{}, err + } + + return token, nil +} + +type TokenRefresher struct { + retriever *TokenRetriever +} + +func NewTokenRefresher(retriever *TokenRetriever) *TokenRefresher { + tr := TokenRefresher{ + retriever: retriever, + } + return &tr +} + +func (tr *TokenRefresher) Refresh(ctx context.Context, refreshToken string) (Token, error) { + if len(refreshToken) == 0 { + return Token{}, errors.Errorf("refresh token is empty!") + } + + data := url.Values{ + "grant_type": {"refresh_token"}, + "refresh_token": {refreshToken}, + } + return tr.retriever.Retrieve(ctx, data) +} + +type ReuseIssuer struct { + new TokenIssuer + refresher *TokenRefresher + timeNow timeNowFunc + + token Token + mu sync.Mutex +} + +func NewReuseIssuer(new TokenIssuer, refresher *TokenRefresher, timeNow timeNowFunc) *ReuseIssuer { + isr := ReuseIssuer{ + new: new, + refresher: refresher, + timeNow: timeNow, + } + return &isr +} + +func (isr *ReuseIssuer) Issue(ctx context.Context) (Token, error) { + isr.mu.Lock() + defer isr.mu.Unlock() + + if isr.token.Valid(isr.timeNow()) { + return isr.token, nil + } + + token, err := isr.refresher.Refresh(ctx, isr.token.RefreshToken) + if err == nil { + isr.token = token + return isr.token, nil + } + + token, err = isr.new.Issue(ctx) + if err != nil { + return Token{}, err + } + + isr.token = token + return isr.token, nil +} diff --git a/token.go b/token.go new file mode 100644 index 0000000..f3a38f8 --- /dev/null +++ b/token.go @@ -0,0 +1,118 @@ +package oauth2 + +import ( + "encoding/json" + "fmt" + "net/url" + "strconv" + "strings" + "time" + + "github.com/pkg/errors" +) + +const ( + expiresThreshold = time.Second * 3 +) + +type timeNowFunc func() time.Time + +type rawToken struct { + AccessToken string `json:"access_token"` + TokenType string `json:"token_type"` + ExpiresIn int64 `json:"expires_in"` + RefreshToken string `json:"refresh_token"` + Scope string `json:"scope"` + + rawData map[string]any `json:"-"` +} + +func (rt rawToken) makeToken(timeNow timeNowFunc) Token { + return Token{ + AccessToken: rt.AccessToken, + TokenType: rt.TokenType, + ExpiresIn: timeNow().Add(time.Second * time.Duration(rt.ExpiresIn)), + RefreshToken: rt.RefreshToken, + Scope: rt.Scope, + rawData: rt.rawData, + } +} + +type Token struct { + AccessToken string + TokenType string + ExpiresIn time.Time + RefreshToken string + Scope string + + rawData map[string]any +} + +func (t Token) Valid(now time.Time) bool { + return len(t.AccessToken) > 0 && !t.expired(now) +} + +func (t Token) expired(now time.Time) bool { + return t.ExpiresIn.Add(-expiresThreshold).Before(now) +} + +func (t Token) Type() string { + if len(t.TokenType) == 0 { + return "Bearer" + } + if strings.EqualFold(t.TokenType, "bearer") { + return "Bearer" + } + return t.TokenType +} + +func (t Token) AuthorizationHeader() string { + return fmt.Sprintf("%s %s", t.Type(), t.AccessToken) +} + +func (t Token) Extra(key string) any { + v, ok := t.rawData[key] + if !ok { + return nil + } + return v +} + +func parseToken(data []byte, contentType string, timeNow timeNowFunc) (Token, error) { + switch contentType { + case "application/x-www-form-urlencoded", "text/plain": + vals, err := url.ParseQuery(string(data)) + if err != nil { + return Token{}, errors.WithStack(err) + } + + expiresIn, err := strconv.ParseInt(vals.Get("expires_in"), 10, 64) + if err != nil { + return Token{}, errors.WithStack(err) + } + + rawData := make(map[string]any, len(vals)) + for k := range vals { + rawData[k] = vals.Get(k) + } + + rt := rawToken{ + AccessToken: vals.Get("access_token"), + TokenType: vals.Get("token_type"), + ExpiresIn: expiresIn, + RefreshToken: vals.Get("refresh_token"), + Scope: vals.Get("scope"), + rawData: rawData, + } + return rt.makeToken(timeNow), nil + + default: + var rt rawToken + if err := json.Unmarshal(data, &rt); err != nil { + return Token{}, errors.WithStack(err) + } + json.Unmarshal(data, &rt.rawData) + + return rt.makeToken(timeNow), nil + } +}