initial commit

This commit is contained in:
Pavel Merzlyakov 2022-11-15 10:22:46 +03:00
commit 997188e310
4 changed files with 291 additions and 0 deletions

5
go.mod Normal file
View File

@ -0,0 +1,5 @@
module git.bit5.ru/backend/oauth2
go 1.18
require github.com/pkg/errors v0.9.1

2
go.sum Normal file
View File

@ -0,0 +1,2 @@
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=

166
oauth2.go Normal file
View File

@ -0,0 +1,166 @@
package oauth2
import (
"context"
"io"
"mime"
"net/http"
"net/url"
"strings"
"sync"
"github.com/pkg/errors"
)
type Config struct {
ClientId string
ClientSecret string
AuthURL string
TokenURL string
Scopes []string
}
type TokenIssuer interface {
Issue(context.Context) (Token, error)
}
type ClientCredentialsIssuer struct {
cfg Config
retriever *TokenRetriever
}
func NewClientCredentialsIssuer(cfg Config, tr *TokenRetriever) *ClientCredentialsIssuer {
isr := ClientCredentialsIssuer{
cfg: cfg,
retriever: tr,
}
return &isr
}
func (isr *ClientCredentialsIssuer) Issue(ctx context.Context) (Token, error) {
data := url.Values{
"grant_type": {"client_credentials"},
"client_id": {isr.cfg.ClientId},
"client_secret": {isr.cfg.ClientSecret},
}
if len(isr.cfg.Scopes) > 0 {
data.Set("scope", strings.Join(isr.cfg.Scopes, " "))
}
return isr.retriever.Retrieve(ctx, data)
}
type TokenRetriever struct {
tokenURL string
httpClient *http.Client
timeNow timeNowFunc
}
func NewTokenRetriever(tokenURL string, httpClient *http.Client, timeNow timeNowFunc) *TokenRetriever {
tr := TokenRetriever{
tokenURL: tokenURL,
httpClient: httpClient,
timeNow: timeNow,
}
return &tr
}
func (tr *TokenRetriever) Retrieve(ctx context.Context, data url.Values) (Token, error) {
req, err := http.NewRequestWithContext(
ctx,
http.MethodPost,
tr.tokenURL,
strings.NewReader(data.Encode()),
)
if err != nil {
return Token{}, errors.WithStack(err)
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
resp, err := tr.httpClient.Do(req)
if err != nil {
return Token{}, errors.WithStack(err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return Token{}, errors.WithStack(err)
}
if resp.StatusCode != http.StatusOK {
return Token{}, errors.Errorf("Unexpected status code: %v, body: %s", resp.StatusCode, body)
}
contentType, _, _ := mime.ParseMediaType(resp.Header.Get("Content-Type"))
token, err := parseToken(body, contentType, tr.timeNow)
if err != nil {
return Token{}, err
}
return token, nil
}
type TokenRefresher struct {
retriever *TokenRetriever
}
func NewTokenRefresher(retriever *TokenRetriever) *TokenRefresher {
tr := TokenRefresher{
retriever: retriever,
}
return &tr
}
func (tr *TokenRefresher) Refresh(ctx context.Context, refreshToken string) (Token, error) {
if len(refreshToken) == 0 {
return Token{}, errors.Errorf("refresh token is empty!")
}
data := url.Values{
"grant_type": {"refresh_token"},
"refresh_token": {refreshToken},
}
return tr.retriever.Retrieve(ctx, data)
}
type ReuseIssuer struct {
new TokenIssuer
refresher *TokenRefresher
timeNow timeNowFunc
token Token
mu sync.Mutex
}
func NewReuseIssuer(new TokenIssuer, refresher *TokenRefresher, timeNow timeNowFunc) *ReuseIssuer {
isr := ReuseIssuer{
new: new,
refresher: refresher,
timeNow: timeNow,
}
return &isr
}
func (isr *ReuseIssuer) Issue(ctx context.Context) (Token, error) {
isr.mu.Lock()
defer isr.mu.Unlock()
if isr.token.Valid(isr.timeNow()) {
return isr.token, nil
}
token, err := isr.refresher.Refresh(ctx, isr.token.RefreshToken)
if err == nil {
isr.token = token
return isr.token, nil
}
token, err = isr.new.Issue(ctx)
if err != nil {
return Token{}, err
}
isr.token = token
return isr.token, nil
}

118
token.go Normal file
View File

@ -0,0 +1,118 @@
package oauth2
import (
"encoding/json"
"fmt"
"net/url"
"strconv"
"strings"
"time"
"github.com/pkg/errors"
)
const (
expiresThreshold = time.Second * 3
)
type timeNowFunc func() time.Time
type rawToken struct {
AccessToken string `json:"access_token"`
TokenType string `json:"token_type"`
ExpiresIn int64 `json:"expires_in"`
RefreshToken string `json:"refresh_token"`
Scope string `json:"scope"`
rawData map[string]any `json:"-"`
}
func (rt rawToken) makeToken(timeNow timeNowFunc) Token {
return Token{
AccessToken: rt.AccessToken,
TokenType: rt.TokenType,
ExpiresIn: timeNow().Add(time.Second * time.Duration(rt.ExpiresIn)),
RefreshToken: rt.RefreshToken,
Scope: rt.Scope,
rawData: rt.rawData,
}
}
type Token struct {
AccessToken string
TokenType string
ExpiresIn time.Time
RefreshToken string
Scope string
rawData map[string]any
}
func (t Token) Valid(now time.Time) bool {
return len(t.AccessToken) > 0 && !t.expired(now)
}
func (t Token) expired(now time.Time) bool {
return t.ExpiresIn.Add(-expiresThreshold).Before(now)
}
func (t Token) Type() string {
if len(t.TokenType) == 0 {
return "Bearer"
}
if strings.EqualFold(t.TokenType, "bearer") {
return "Bearer"
}
return t.TokenType
}
func (t Token) AuthorizationHeader() string {
return fmt.Sprintf("%s %s", t.Type(), t.AccessToken)
}
func (t Token) Extra(key string) any {
v, ok := t.rawData[key]
if !ok {
return nil
}
return v
}
func parseToken(data []byte, contentType string, timeNow timeNowFunc) (Token, error) {
switch contentType {
case "application/x-www-form-urlencoded", "text/plain":
vals, err := url.ParseQuery(string(data))
if err != nil {
return Token{}, errors.WithStack(err)
}
expiresIn, err := strconv.ParseInt(vals.Get("expires_in"), 10, 64)
if err != nil {
return Token{}, errors.WithStack(err)
}
rawData := make(map[string]any, len(vals))
for k := range vals {
rawData[k] = vals.Get(k)
}
rt := rawToken{
AccessToken: vals.Get("access_token"),
TokenType: vals.Get("token_type"),
ExpiresIn: expiresIn,
RefreshToken: vals.Get("refresh_token"),
Scope: vals.Get("scope"),
rawData: rawData,
}
return rt.makeToken(timeNow), nil
default:
var rt rawToken
if err := json.Unmarshal(data, &rt); err != nil {
return Token{}, errors.WithStack(err)
}
json.Unmarshal(data, &rt.rawData)
return rt.makeToken(timeNow), nil
}
}