119 lines
2.4 KiB
Go
119 lines
2.4 KiB
Go
package oauth2
|
|
|
|
import (
|
|
"encoding/json"
|
|
"fmt"
|
|
"net/url"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/pkg/errors"
|
|
)
|
|
|
|
const (
|
|
expiresThreshold = time.Second * 3
|
|
)
|
|
|
|
// timeNowFunc supplies the current instant. makeToken calls it once to anchor
// a token's relative expires_in value to an absolute expiry time.
// NOTE(review): presumably time.Now outside of tests — confirm with callers.
type timeNowFunc func() time.Time
|
|
|
|
// rawToken is the wire form of a token endpoint response, before the relative
// expiry has been converted into an absolute time.
//
// The exported fields carry the standard response parameters. Their JSON tags
// are used when the body is JSON; for form-encoded bodies parseToken fills them
// by hand. ExpiresIn is the token lifetime in seconds (makeToken multiplies it
// by time.Second). rawData holds every top-level response field. It is tagged
// json:"-" so the struct decode skips it; parseToken populates it separately.
type rawToken struct {
	AccessToken  string `json:"access_token"`
	TokenType    string `json:"token_type"`
	ExpiresIn    int64  `json:"expires_in"`
	RefreshToken string `json:"refresh_token"`
	Scope        string `json:"scope"`

	rawData map[string]any `json:"-"`
}
|
|
|
|
func (rt rawToken) makeToken(timeNow timeNowFunc) Token {
|
|
return Token{
|
|
AccessToken: rt.AccessToken,
|
|
TokenType: rt.TokenType,
|
|
ExpiresIn: timeNow().Add(time.Second * time.Duration(rt.ExpiresIn)),
|
|
RefreshToken: rt.RefreshToken,
|
|
Scope: rt.Scope,
|
|
rawData: rt.rawData,
|
|
}
|
|
}
|
|
|
|
// Token is a parsed OAuth2 token.
//
// AccessToken, TokenType, RefreshToken and Scope are copied verbatim from the
// token endpoint response. Despite its name, ExpiresIn is an absolute time,
// not a duration: it is the instant at which the token expires, computed when
// the response was parsed. rawData keeps every field of the original response
// and is exposed through Extra.
type Token struct {
	AccessToken  string
	TokenType    string
	ExpiresIn    time.Time
	RefreshToken string
	Scope        string

	rawData map[string]any
}
|
|
|
|
func (t Token) Valid(now time.Time) bool {
|
|
return len(t.AccessToken) > 0 && !t.expired(now)
|
|
}
|
|
|
|
func (t Token) expired(now time.Time) bool {
|
|
return t.ExpiresIn.Add(-expiresThreshold).Before(now)
|
|
}
|
|
|
|
func (t Token) Type() string {
|
|
if len(t.TokenType) == 0 {
|
|
return "Bearer"
|
|
}
|
|
if strings.EqualFold(t.TokenType, "bearer") {
|
|
return "Bearer"
|
|
}
|
|
return t.TokenType
|
|
}
|
|
|
|
func (t Token) AuthorizationHeader() string {
|
|
return fmt.Sprintf("%s %s", t.Type(), t.AccessToken)
|
|
}
|
|
|
|
func (t Token) Extra(key string) any {
|
|
v, ok := t.rawData[key]
|
|
if !ok {
|
|
return nil
|
|
}
|
|
return v
|
|
}
|
|
|
|
func parseToken(data []byte, contentType string, timeNow timeNowFunc) (Token, error) {
|
|
switch contentType {
|
|
case "application/x-www-form-urlencoded", "text/plain":
|
|
vals, err := url.ParseQuery(string(data))
|
|
if err != nil {
|
|
return Token{}, errors.WithStack(err)
|
|
}
|
|
|
|
expiresIn, err := strconv.ParseInt(vals.Get("expires_in"), 10, 64)
|
|
if err != nil {
|
|
return Token{}, errors.WithStack(err)
|
|
}
|
|
|
|
rawData := make(map[string]any, len(vals))
|
|
for k := range vals {
|
|
rawData[k] = vals.Get(k)
|
|
}
|
|
|
|
rt := rawToken{
|
|
AccessToken: vals.Get("access_token"),
|
|
TokenType: vals.Get("token_type"),
|
|
ExpiresIn: expiresIn,
|
|
RefreshToken: vals.Get("refresh_token"),
|
|
Scope: vals.Get("scope"),
|
|
rawData: rawData,
|
|
}
|
|
return rt.makeToken(timeNow), nil
|
|
|
|
default:
|
|
var rt rawToken
|
|
if err := json.Unmarshal(data, &rt); err != nil {
|
|
return Token{}, errors.WithStack(err)
|
|
}
|
|
json.Unmarshal(data, &rt.rawData)
|
|
|
|
return rt.makeToken(timeNow), nil
|
|
}
|
|
}
|