initial commit
This commit is contained in:
commit
997188e310
|
@ -0,0 +1,5 @@
|
||||||
|
module git.bit5.ru/backend/oauth2
|
||||||
|
|
||||||
|
go 1.18
|
||||||
|
|
||||||
|
require github.com/pkg/errors v0.9.1
|
|
@ -0,0 +1,2 @@
|
||||||
|
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
||||||
|
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
|
@ -0,0 +1,166 @@
|
||||||
|
package oauth2
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"io"
|
||||||
|
"mime"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Config struct {
|
||||||
|
ClientId string
|
||||||
|
ClientSecret string
|
||||||
|
AuthURL string
|
||||||
|
TokenURL string
|
||||||
|
Scopes []string
|
||||||
|
}
|
||||||
|
|
||||||
|
type TokenIssuer interface {
|
||||||
|
Issue(context.Context) (Token, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type ClientCredentialsIssuer struct {
|
||||||
|
cfg Config
|
||||||
|
retriever *TokenRetriever
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewClientCredentialsIssuer(cfg Config, tr *TokenRetriever) *ClientCredentialsIssuer {
|
||||||
|
isr := ClientCredentialsIssuer{
|
||||||
|
cfg: cfg,
|
||||||
|
retriever: tr,
|
||||||
|
}
|
||||||
|
return &isr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (isr *ClientCredentialsIssuer) Issue(ctx context.Context) (Token, error) {
|
||||||
|
data := url.Values{
|
||||||
|
"grant_type": {"client_credentials"},
|
||||||
|
"client_id": {isr.cfg.ClientId},
|
||||||
|
"client_secret": {isr.cfg.ClientSecret},
|
||||||
|
}
|
||||||
|
if len(isr.cfg.Scopes) > 0 {
|
||||||
|
data.Set("scope", strings.Join(isr.cfg.Scopes, " "))
|
||||||
|
}
|
||||||
|
|
||||||
|
return isr.retriever.Retrieve(ctx, data)
|
||||||
|
}
|
||||||
|
|
||||||
|
type TokenRetriever struct {
|
||||||
|
tokenURL string
|
||||||
|
httpClient *http.Client
|
||||||
|
timeNow timeNowFunc
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewTokenRetriever(tokenURL string, httpClient *http.Client, timeNow timeNowFunc) *TokenRetriever {
|
||||||
|
tr := TokenRetriever{
|
||||||
|
tokenURL: tokenURL,
|
||||||
|
httpClient: httpClient,
|
||||||
|
timeNow: timeNow,
|
||||||
|
}
|
||||||
|
return &tr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tr *TokenRetriever) Retrieve(ctx context.Context, data url.Values) (Token, error) {
|
||||||
|
req, err := http.NewRequestWithContext(
|
||||||
|
ctx,
|
||||||
|
http.MethodPost,
|
||||||
|
tr.tokenURL,
|
||||||
|
strings.NewReader(data.Encode()),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return Token{}, errors.WithStack(err)
|
||||||
|
}
|
||||||
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
|
|
||||||
|
resp, err := tr.httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return Token{}, errors.WithStack(err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
body, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return Token{}, errors.WithStack(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return Token{}, errors.Errorf("Unexpected status code: %v, body: %s", resp.StatusCode, body)
|
||||||
|
}
|
||||||
|
|
||||||
|
contentType, _, _ := mime.ParseMediaType(resp.Header.Get("Content-Type"))
|
||||||
|
token, err := parseToken(body, contentType, tr.timeNow)
|
||||||
|
if err != nil {
|
||||||
|
return Token{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return token, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type TokenRefresher struct {
|
||||||
|
retriever *TokenRetriever
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewTokenRefresher(retriever *TokenRetriever) *TokenRefresher {
|
||||||
|
tr := TokenRefresher{
|
||||||
|
retriever: retriever,
|
||||||
|
}
|
||||||
|
return &tr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tr *TokenRefresher) Refresh(ctx context.Context, refreshToken string) (Token, error) {
|
||||||
|
if len(refreshToken) == 0 {
|
||||||
|
return Token{}, errors.Errorf("refresh token is empty!")
|
||||||
|
}
|
||||||
|
|
||||||
|
data := url.Values{
|
||||||
|
"grant_type": {"refresh_token"},
|
||||||
|
"refresh_token": {refreshToken},
|
||||||
|
}
|
||||||
|
return tr.retriever.Retrieve(ctx, data)
|
||||||
|
}
|
||||||
|
|
||||||
|
type ReuseIssuer struct {
|
||||||
|
new TokenIssuer
|
||||||
|
refresher *TokenRefresher
|
||||||
|
timeNow timeNowFunc
|
||||||
|
|
||||||
|
token Token
|
||||||
|
mu sync.Mutex
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewReuseIssuer(new TokenIssuer, refresher *TokenRefresher, timeNow timeNowFunc) *ReuseIssuer {
|
||||||
|
isr := ReuseIssuer{
|
||||||
|
new: new,
|
||||||
|
refresher: refresher,
|
||||||
|
timeNow: timeNow,
|
||||||
|
}
|
||||||
|
return &isr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (isr *ReuseIssuer) Issue(ctx context.Context) (Token, error) {
|
||||||
|
isr.mu.Lock()
|
||||||
|
defer isr.mu.Unlock()
|
||||||
|
|
||||||
|
if isr.token.Valid(isr.timeNow()) {
|
||||||
|
return isr.token, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
token, err := isr.refresher.Refresh(ctx, isr.token.RefreshToken)
|
||||||
|
if err == nil {
|
||||||
|
isr.token = token
|
||||||
|
return isr.token, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
token, err = isr.new.Issue(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return Token{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
isr.token = token
|
||||||
|
return isr.token, nil
|
||||||
|
}
|
|
@ -0,0 +1,118 @@
|
||||||
|
package oauth2
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net/url"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
expiresThreshold = time.Second * 3
|
||||||
|
)
|
||||||
|
|
||||||
|
type timeNowFunc func() time.Time
|
||||||
|
|
||||||
|
type rawToken struct {
|
||||||
|
AccessToken string `json:"access_token"`
|
||||||
|
TokenType string `json:"token_type"`
|
||||||
|
ExpiresIn int64 `json:"expires_in"`
|
||||||
|
RefreshToken string `json:"refresh_token"`
|
||||||
|
Scope string `json:"scope"`
|
||||||
|
|
||||||
|
rawData map[string]any `json:"-"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rt rawToken) makeToken(timeNow timeNowFunc) Token {
|
||||||
|
return Token{
|
||||||
|
AccessToken: rt.AccessToken,
|
||||||
|
TokenType: rt.TokenType,
|
||||||
|
ExpiresIn: timeNow().Add(time.Second * time.Duration(rt.ExpiresIn)),
|
||||||
|
RefreshToken: rt.RefreshToken,
|
||||||
|
Scope: rt.Scope,
|
||||||
|
rawData: rt.rawData,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type Token struct {
|
||||||
|
AccessToken string
|
||||||
|
TokenType string
|
||||||
|
ExpiresIn time.Time
|
||||||
|
RefreshToken string
|
||||||
|
Scope string
|
||||||
|
|
||||||
|
rawData map[string]any
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t Token) Valid(now time.Time) bool {
|
||||||
|
return len(t.AccessToken) > 0 && !t.expired(now)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t Token) expired(now time.Time) bool {
|
||||||
|
return t.ExpiresIn.Add(-expiresThreshold).Before(now)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t Token) Type() string {
|
||||||
|
if len(t.TokenType) == 0 {
|
||||||
|
return "Bearer"
|
||||||
|
}
|
||||||
|
if strings.EqualFold(t.TokenType, "bearer") {
|
||||||
|
return "Bearer"
|
||||||
|
}
|
||||||
|
return t.TokenType
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t Token) AuthorizationHeader() string {
|
||||||
|
return fmt.Sprintf("%s %s", t.Type(), t.AccessToken)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t Token) Extra(key string) any {
|
||||||
|
v, ok := t.rawData[key]
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseToken(data []byte, contentType string, timeNow timeNowFunc) (Token, error) {
|
||||||
|
switch contentType {
|
||||||
|
case "application/x-www-form-urlencoded", "text/plain":
|
||||||
|
vals, err := url.ParseQuery(string(data))
|
||||||
|
if err != nil {
|
||||||
|
return Token{}, errors.WithStack(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
expiresIn, err := strconv.ParseInt(vals.Get("expires_in"), 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
return Token{}, errors.WithStack(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
rawData := make(map[string]any, len(vals))
|
||||||
|
for k := range vals {
|
||||||
|
rawData[k] = vals.Get(k)
|
||||||
|
}
|
||||||
|
|
||||||
|
rt := rawToken{
|
||||||
|
AccessToken: vals.Get("access_token"),
|
||||||
|
TokenType: vals.Get("token_type"),
|
||||||
|
ExpiresIn: expiresIn,
|
||||||
|
RefreshToken: vals.Get("refresh_token"),
|
||||||
|
Scope: vals.Get("scope"),
|
||||||
|
rawData: rawData,
|
||||||
|
}
|
||||||
|
return rt.makeToken(timeNow), nil
|
||||||
|
|
||||||
|
default:
|
||||||
|
var rt rawToken
|
||||||
|
if err := json.Unmarshal(data, &rt); err != nil {
|
||||||
|
return Token{}, errors.WithStack(err)
|
||||||
|
}
|
||||||
|
json.Unmarshal(data, &rt.rawData)
|
||||||
|
|
||||||
|
return rt.makeToken(timeNow), nil
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue