oauth2/token.go

119 lines
2.4 KiB
Go

package oauth2
import (
"encoding/json"
"fmt"
"net/url"
"strconv"
"strings"
"time"
"github.com/pkg/errors"
)
// expiresThreshold is the safety margin subtracted from a token's expiry time
// when checking expiration, so a token counts as expired slightly before the
// instant recorded in Token.ExpiresIn.
const expiresThreshold = 3 * time.Second
// timeNowFunc is a clock source: a function returning a time.Time.
// makeToken calls it to obtain the base instant from which a token's
// absolute expiry time is computed.
type timeNowFunc func() time.Time
// rawToken mirrors a token response as it arrives on the wire, before the
// relative expires_in value has been converted into an absolute expiry time.
type rawToken struct {
	// AccessToken is decoded from the "access_token" field.
	AccessToken string `json:"access_token"`
	// TokenType is decoded from the "token_type" field.
	TokenType string `json:"token_type"`
	// ExpiresIn is decoded from the "expires_in" field; makeToken interprets
	// it as a number of seconds.
	ExpiresIn int64 `json:"expires_in"`
	// RefreshToken is decoded from the "refresh_token" field.
	RefreshToken string `json:"refresh_token"`
	// Scope is decoded from the "scope" field.
	Scope string `json:"scope"`

	// rawData is skipped by encoding/json when decoding the struct itself;
	// parseToken fills it separately with the fields of the response.
	rawData map[string]any `json:"-"`
}
// makeToken converts the raw response into a Token. The relative ExpiresIn
// value (in seconds) is added to the instant returned by timeNow to produce
// the absolute expiry time stored in Token.ExpiresIn; all other fields are
// copied over unchanged.
func (rt rawToken) makeToken(timeNow timeNowFunc) Token {
	lifetime := time.Duration(rt.ExpiresIn) * time.Second
	expiry := timeNow().Add(lifetime)

	return Token{
		AccessToken:  rt.AccessToken,
		RefreshToken: rt.RefreshToken,
		TokenType:    rt.TokenType,
		Scope:        rt.Scope,
		ExpiresIn:    expiry,
		rawData:      rt.rawData,
	}
}
// Token is a parsed token response whose lifetime has been resolved to an
// absolute point in time.
type Token struct {
	// AccessToken is the credential sent in the Authorization header.
	AccessToken string
	// TokenType is the token type as reported in the response; use Type to
	// get the normalized form.
	TokenType string
	// ExpiresIn is the absolute time at which the token expires. Despite the
	// name it is an instant, not a duration.
	ExpiresIn time.Time
	// RefreshToken is the refresh token from the response, if any.
	RefreshToken string
	// Scope is the scope value from the response, if any.
	Scope string

	// rawData holds the fields of the original response; it is read
	// through Extra.
	rawData map[string]any
}
// Valid reports whether the token carries a non-empty access token and has
// not expired as of now.
func (t Token) Valid(now time.Time) bool {
	if t.AccessToken == "" {
		return false
	}
	return !t.expired(now)
}
// expired reports whether now is later than the token's expiry time moved
// earlier by expiresThreshold.
func (t Token) expired(now time.Time) bool {
	deadline := t.ExpiresIn.Add(-expiresThreshold)
	return now.After(deadline)
}
// Type returns the token type to use in an Authorization header. An empty
// TokenType, or any case variant of "bearer", yields the canonical "Bearer";
// every other value is returned unchanged.
func (t Token) Type() string {
	switch {
	case t.TokenType == "", strings.EqualFold(t.TokenType, "bearer"):
		return "Bearer"
	default:
		return t.TokenType
	}
}
// AuthorizationHeader returns the value for an HTTP Authorization header:
// the normalized token type, a single space, then the access token.
func (t Token) AuthorizationHeader() string {
	scheme, credentials := t.Type(), t.AccessToken
	return fmt.Sprintf("%s %s", scheme, credentials)
}
// Extra returns the value stored under key in the raw token response, or nil
// when the key is not present.
func (t Token) Extra(key string) any {
	// Indexing a map (including a nil one) with a missing key yields the
	// zero value of the element type, which is nil for any.
	return t.rawData[key]
}
// parseToken decodes a token endpoint response body into a Token.
//
// contentType selects the decoder: "application/x-www-form-urlencoded" and
// "text/plain" bodies are parsed as a URL-encoded query string, and every
// other value is parsed as JSON. The comparison is an exact string match.
// timeNow supplies the instant from which the relative expires_in value is
// converted into the absolute expiry time.
//
// The fields of the response stay reachable through Token.Extra. For form
// bodies only the first value of a repeated key is kept, and all values are
// strings; for JSON bodies values have the types produced by encoding/json
// when decoding into map[string]any.
//
// A missing (or empty) expires_in is treated as zero in both formats. A
// non-nil error, wrapped with a stack trace, is returned when the body is
// malformed or expires_in is present but not a base-10 integer; the returned
// Token is then the zero value.
func parseToken(data []byte, contentType string, timeNow timeNowFunc) (Token, error) {
	switch contentType {
	case "application/x-www-form-urlencoded", "text/plain":
		vals, err := url.ParseQuery(string(data))
		if err != nil {
			return Token{}, errors.WithStack(err)
		}

		// expires_in is optional in a token response. Leave it at zero when
		// absent, matching what the JSON branch does for a missing field,
		// instead of failing on strconv.ParseInt("").
		var expiresIn int64
		if raw := vals.Get("expires_in"); raw != "" {
			expiresIn, err = strconv.ParseInt(raw, 10, 64)
			if err != nil {
				return Token{}, errors.WithStack(err)
			}
		}

		rawData := make(map[string]any, len(vals))
		for k := range vals {
			rawData[k] = vals.Get(k)
		}

		rt := rawToken{
			AccessToken:  vals.Get("access_token"),
			TokenType:    vals.Get("token_type"),
			ExpiresIn:    expiresIn,
			RefreshToken: vals.Get("refresh_token"),
			Scope:        vals.Get("scope"),
			rawData:      rawData,
		}
		return rt.makeToken(timeNow), nil
	default:
		var rt rawToken
		if err := json.Unmarshal(data, &rt); err != nil {
			return Token{}, errors.WithStack(err)
		}
		// Decode a second time into a generic map so that fields outside
		// rawToken remain available through Token.Extra.
		if err := json.Unmarshal(data, &rt.rawData); err != nil {
			return Token{}, errors.WithStack(err)
		}
		return rt.makeToken(timeNow), nil
	}
}