initial commit
This commit is contained in:
commit
997188e310
|
@ -0,0 +1,5 @@
|
|||
module git.bit5.ru/backend/oauth2
|
||||
|
||||
go 1.18
|
||||
|
||||
require github.com/pkg/errors v0.9.1
|
|
@ -0,0 +1,2 @@
|
|||
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
||||
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
|
@ -0,0 +1,166 @@
|
|||
package oauth2
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"mime"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
ClientId string
|
||||
ClientSecret string
|
||||
AuthURL string
|
||||
TokenURL string
|
||||
Scopes []string
|
||||
}
|
||||
|
||||
type TokenIssuer interface {
|
||||
Issue(context.Context) (Token, error)
|
||||
}
|
||||
|
||||
type ClientCredentialsIssuer struct {
|
||||
cfg Config
|
||||
retriever *TokenRetriever
|
||||
}
|
||||
|
||||
func NewClientCredentialsIssuer(cfg Config, tr *TokenRetriever) *ClientCredentialsIssuer {
|
||||
isr := ClientCredentialsIssuer{
|
||||
cfg: cfg,
|
||||
retriever: tr,
|
||||
}
|
||||
return &isr
|
||||
}
|
||||
|
||||
func (isr *ClientCredentialsIssuer) Issue(ctx context.Context) (Token, error) {
|
||||
data := url.Values{
|
||||
"grant_type": {"client_credentials"},
|
||||
"client_id": {isr.cfg.ClientId},
|
||||
"client_secret": {isr.cfg.ClientSecret},
|
||||
}
|
||||
if len(isr.cfg.Scopes) > 0 {
|
||||
data.Set("scope", strings.Join(isr.cfg.Scopes, " "))
|
||||
}
|
||||
|
||||
return isr.retriever.Retrieve(ctx, data)
|
||||
}
|
||||
|
||||
type TokenRetriever struct {
|
||||
tokenURL string
|
||||
httpClient *http.Client
|
||||
timeNow timeNowFunc
|
||||
}
|
||||
|
||||
func NewTokenRetriever(tokenURL string, httpClient *http.Client, timeNow timeNowFunc) *TokenRetriever {
|
||||
tr := TokenRetriever{
|
||||
tokenURL: tokenURL,
|
||||
httpClient: httpClient,
|
||||
timeNow: timeNow,
|
||||
}
|
||||
return &tr
|
||||
}
|
||||
|
||||
func (tr *TokenRetriever) Retrieve(ctx context.Context, data url.Values) (Token, error) {
|
||||
req, err := http.NewRequestWithContext(
|
||||
ctx,
|
||||
http.MethodPost,
|
||||
tr.tokenURL,
|
||||
strings.NewReader(data.Encode()),
|
||||
)
|
||||
if err != nil {
|
||||
return Token{}, errors.WithStack(err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
|
||||
resp, err := tr.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return Token{}, errors.WithStack(err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return Token{}, errors.WithStack(err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return Token{}, errors.Errorf("Unexpected status code: %v, body: %s", resp.StatusCode, body)
|
||||
}
|
||||
|
||||
contentType, _, _ := mime.ParseMediaType(resp.Header.Get("Content-Type"))
|
||||
token, err := parseToken(body, contentType, tr.timeNow)
|
||||
if err != nil {
|
||||
return Token{}, err
|
||||
}
|
||||
|
||||
return token, nil
|
||||
}
|
||||
|
||||
type TokenRefresher struct {
|
||||
retriever *TokenRetriever
|
||||
}
|
||||
|
||||
func NewTokenRefresher(retriever *TokenRetriever) *TokenRefresher {
|
||||
tr := TokenRefresher{
|
||||
retriever: retriever,
|
||||
}
|
||||
return &tr
|
||||
}
|
||||
|
||||
func (tr *TokenRefresher) Refresh(ctx context.Context, refreshToken string) (Token, error) {
|
||||
if len(refreshToken) == 0 {
|
||||
return Token{}, errors.Errorf("refresh token is empty!")
|
||||
}
|
||||
|
||||
data := url.Values{
|
||||
"grant_type": {"refresh_token"},
|
||||
"refresh_token": {refreshToken},
|
||||
}
|
||||
return tr.retriever.Retrieve(ctx, data)
|
||||
}
|
||||
|
||||
type ReuseIssuer struct {
|
||||
new TokenIssuer
|
||||
refresher *TokenRefresher
|
||||
timeNow timeNowFunc
|
||||
|
||||
token Token
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func NewReuseIssuer(new TokenIssuer, refresher *TokenRefresher, timeNow timeNowFunc) *ReuseIssuer {
|
||||
isr := ReuseIssuer{
|
||||
new: new,
|
||||
refresher: refresher,
|
||||
timeNow: timeNow,
|
||||
}
|
||||
return &isr
|
||||
}
|
||||
|
||||
func (isr *ReuseIssuer) Issue(ctx context.Context) (Token, error) {
|
||||
isr.mu.Lock()
|
||||
defer isr.mu.Unlock()
|
||||
|
||||
if isr.token.Valid(isr.timeNow()) {
|
||||
return isr.token, nil
|
||||
}
|
||||
|
||||
token, err := isr.refresher.Refresh(ctx, isr.token.RefreshToken)
|
||||
if err == nil {
|
||||
isr.token = token
|
||||
return isr.token, nil
|
||||
}
|
||||
|
||||
token, err = isr.new.Issue(ctx)
|
||||
if err != nil {
|
||||
return Token{}, err
|
||||
}
|
||||
|
||||
isr.token = token
|
||||
return isr.token, nil
|
||||
}
|
|
@ -0,0 +1,118 @@
|
|||
package oauth2
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
const (
|
||||
expiresThreshold = time.Second * 3
|
||||
)
|
||||
|
||||
type timeNowFunc func() time.Time
|
||||
|
||||
type rawToken struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
TokenType string `json:"token_type"`
|
||||
ExpiresIn int64 `json:"expires_in"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
Scope string `json:"scope"`
|
||||
|
||||
rawData map[string]any `json:"-"`
|
||||
}
|
||||
|
||||
func (rt rawToken) makeToken(timeNow timeNowFunc) Token {
|
||||
return Token{
|
||||
AccessToken: rt.AccessToken,
|
||||
TokenType: rt.TokenType,
|
||||
ExpiresIn: timeNow().Add(time.Second * time.Duration(rt.ExpiresIn)),
|
||||
RefreshToken: rt.RefreshToken,
|
||||
Scope: rt.Scope,
|
||||
rawData: rt.rawData,
|
||||
}
|
||||
}
|
||||
|
||||
type Token struct {
|
||||
AccessToken string
|
||||
TokenType string
|
||||
ExpiresIn time.Time
|
||||
RefreshToken string
|
||||
Scope string
|
||||
|
||||
rawData map[string]any
|
||||
}
|
||||
|
||||
func (t Token) Valid(now time.Time) bool {
|
||||
return len(t.AccessToken) > 0 && !t.expired(now)
|
||||
}
|
||||
|
||||
func (t Token) expired(now time.Time) bool {
|
||||
return t.ExpiresIn.Add(-expiresThreshold).Before(now)
|
||||
}
|
||||
|
||||
func (t Token) Type() string {
|
||||
if len(t.TokenType) == 0 {
|
||||
return "Bearer"
|
||||
}
|
||||
if strings.EqualFold(t.TokenType, "bearer") {
|
||||
return "Bearer"
|
||||
}
|
||||
return t.TokenType
|
||||
}
|
||||
|
||||
func (t Token) AuthorizationHeader() string {
|
||||
return fmt.Sprintf("%s %s", t.Type(), t.AccessToken)
|
||||
}
|
||||
|
||||
func (t Token) Extra(key string) any {
|
||||
v, ok := t.rawData[key]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
func parseToken(data []byte, contentType string, timeNow timeNowFunc) (Token, error) {
|
||||
switch contentType {
|
||||
case "application/x-www-form-urlencoded", "text/plain":
|
||||
vals, err := url.ParseQuery(string(data))
|
||||
if err != nil {
|
||||
return Token{}, errors.WithStack(err)
|
||||
}
|
||||
|
||||
expiresIn, err := strconv.ParseInt(vals.Get("expires_in"), 10, 64)
|
||||
if err != nil {
|
||||
return Token{}, errors.WithStack(err)
|
||||
}
|
||||
|
||||
rawData := make(map[string]any, len(vals))
|
||||
for k := range vals {
|
||||
rawData[k] = vals.Get(k)
|
||||
}
|
||||
|
||||
rt := rawToken{
|
||||
AccessToken: vals.Get("access_token"),
|
||||
TokenType: vals.Get("token_type"),
|
||||
ExpiresIn: expiresIn,
|
||||
RefreshToken: vals.Get("refresh_token"),
|
||||
Scope: vals.Get("scope"),
|
||||
rawData: rawData,
|
||||
}
|
||||
return rt.makeToken(timeNow), nil
|
||||
|
||||
default:
|
||||
var rt rawToken
|
||||
if err := json.Unmarshal(data, &rt); err != nil {
|
||||
return Token{}, errors.WithStack(err)
|
||||
}
|
||||
json.Unmarshal(data, &rt.rawData)
|
||||
|
||||
return rt.makeToken(timeNow), nil
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue