deviceauth.go 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198
  1. package oauth2
  2. import (
  3. "context"
  4. "encoding/json"
  5. "errors"
  6. "fmt"
  7. "io"
  8. "net/http"
  9. "net/url"
  10. "strings"
  11. "time"
  12. "golang.org/x/oauth2/internal"
  13. )
  // Error codes the token endpoint may return while DeviceAccessToken is
// polling for a device access token.
// https://datatracker.ietf.org/doc/html/rfc8628#section-3.5
const (
	errAccessDenied         = "access_denied"
	errAuthorizationPending = "authorization_pending"
	errExpiredToken         = "expired_token"
	errSlowDown             = "slow_down"
)
  // DeviceAuthResponse describes a successful RFC 8628 Device Authorization Response.
// https://datatracker.ietf.org/doc/html/rfc8628#section-3.2
type DeviceAuthResponse struct {
	// DeviceCode is the device verification code; DeviceAccessToken sends it
	// as "device_code" when polling the token endpoint.
	DeviceCode string `json:"device_code"`
	// UserCode is the code the user should enter at the verification URI.
	UserCode string `json:"user_code"`
	// VerificationURI is where the user should enter the user code.
	VerificationURI string `json:"verification_uri"`
	// VerificationURIComplete (if populated) includes the user code in the verification URI. This is typically shown to the user in non-textual form, such as a QR code.
	VerificationURIComplete string `json:"verification_uri_complete,omitempty"`
	// Expiry is when the device code and user code expire. On the wire it is
	// the relative "expires_in" (seconds); MarshalJSON and UnmarshalJSON
	// convert between the two representations. A zero Expiry means no
	// expiry was provided.
	Expiry time.Time `json:"expires_in,omitempty"`
	// Interval is the minimum number of seconds DeviceAccessToken waits
	// between polling requests. Zero means the RFC default of 5 seconds.
	Interval int64 `json:"interval,omitempty"`
}

// MarshalJSON encodes d, replacing Expiry with "expires_in": the whole number
// of seconds remaining until Expiry (negative if it has already passed).
// "expires_in" is omitted when Expiry is zero or less than one second away.
func (d DeviceAuthResponse) MarshalJSON() ([]byte, error) {
	// Alias has the same fields but none of the methods, which avoids
	// infinite recursion into MarshalJSON.
	type Alias DeviceAuthResponse
	var expiresIn int64
	if !d.Expiry.IsZero() {
		expiresIn = int64(time.Until(d.Expiry).Seconds())
	}
	// The outer ExpiresIn field shadows the embedded Expiry field, which
	// shares the "expires_in" name at a deeper nesting level.
	return json.Marshal(&struct {
		ExpiresIn int64 `json:"expires_in,omitempty"`
		*Alias
	}{
		ExpiresIn: expiresIn,
		Alias:     (*Alias)(&d),
	})
}

// UnmarshalJSON decodes a Device Authorization Response into d.
//
// A non-zero "expires_in" is converted into an absolute UTC Expiry relative to
// the current time. If "verification_uri" is absent, the commonly seen
// misspelling "verification_url" is used instead. The JSON literal null is a
// no-op, matching the json.Unmarshaler convention.
func (d *DeviceAuthResponse) UnmarshalJSON(data []byte) error {
	// Alias has the same fields but none of the methods, which avoids
	// infinite recursion into UnmarshalJSON.
	type Alias DeviceAuthResponse
	aux := &struct {
		ExpiresIn int64 `json:"expires_in"`
		// workaround misspelling of verification_uri
		VerificationURL string `json:"verification_url"`
		*Alias
	}{
		Alias: (*Alias)(d),
	}
	// Decode into aux itself, not &aux: decoding null through a **struct would
	// set aux to nil and the field accesses below would panic.
	if err := json.Unmarshal(data, aux); err != nil {
		return err
	}
	if aux.ExpiresIn != 0 {
		d.Expiry = time.Now().UTC().Add(time.Second * time.Duration(aux.ExpiresIn))
	}
	if d.VerificationURI == "" {
		d.VerificationURI = aux.VerificationURL
	}
	return nil
}
  72. // DeviceAuth returns a device auth struct which contains a device code
  73. // and authorization information provided for users to enter on another device.
  74. func (c *Config) DeviceAuth(ctx context.Context, opts ...AuthCodeOption) (*DeviceAuthResponse, error) {
  75. // https://datatracker.ietf.org/doc/html/rfc8628#section-3.1
  76. v := url.Values{
  77. "client_id": {c.ClientID},
  78. }
  79. if len(c.Scopes) > 0 {
  80. v.Set("scope", strings.Join(c.Scopes, " "))
  81. }
  82. for _, opt := range opts {
  83. opt.setValue(v)
  84. }
  85. return retrieveDeviceAuth(ctx, c, v)
  86. }
  87. func retrieveDeviceAuth(ctx context.Context, c *Config, v url.Values) (*DeviceAuthResponse, error) {
  88. if c.Endpoint.DeviceAuthURL == "" {
  89. return nil, errors.New("endpoint missing DeviceAuthURL")
  90. }
  91. req, err := http.NewRequest("POST", c.Endpoint.DeviceAuthURL, strings.NewReader(v.Encode()))
  92. if err != nil {
  93. return nil, err
  94. }
  95. req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
  96. req.Header.Set("Accept", "application/json")
  97. t := time.Now()
  98. r, err := internal.ContextClient(ctx).Do(req)
  99. if err != nil {
  100. return nil, err
  101. }
  102. body, err := io.ReadAll(io.LimitReader(r.Body, 1<<20))
  103. if err != nil {
  104. return nil, fmt.Errorf("oauth2: cannot auth device: %v", err)
  105. }
  106. if code := r.StatusCode; code < 200 || code > 299 {
  107. return nil, &RetrieveError{
  108. Response: r,
  109. Body: body,
  110. }
  111. }
  112. da := &DeviceAuthResponse{}
  113. err = json.Unmarshal(body, &da)
  114. if err != nil {
  115. return nil, fmt.Errorf("unmarshal %s", err)
  116. }
  117. if !da.Expiry.IsZero() {
  118. // Make a small adjustment to account for time taken by the request
  119. da.Expiry = da.Expiry.Add(-time.Since(t))
  120. }
  121. return da, nil
  122. }
  123. // DeviceAccessToken polls the server to exchange a device code for a token.
  124. func (c *Config) DeviceAccessToken(ctx context.Context, da *DeviceAuthResponse, opts ...AuthCodeOption) (*Token, error) {
  125. if !da.Expiry.IsZero() {
  126. var cancel context.CancelFunc
  127. ctx, cancel = context.WithDeadline(ctx, da.Expiry)
  128. defer cancel()
  129. }
  130. // https://datatracker.ietf.org/doc/html/rfc8628#section-3.4
  131. v := url.Values{
  132. "client_id": {c.ClientID},
  133. "grant_type": {"urn:ietf:params:oauth:grant-type:device_code"},
  134. "device_code": {da.DeviceCode},
  135. }
  136. if len(c.Scopes) > 0 {
  137. v.Set("scope", strings.Join(c.Scopes, " "))
  138. }
  139. for _, opt := range opts {
  140. opt.setValue(v)
  141. }
  142. // "If no value is provided, clients MUST use 5 as the default."
  143. // https://datatracker.ietf.org/doc/html/rfc8628#section-3.2
  144. interval := da.Interval
  145. if interval == 0 {
  146. interval = 5
  147. }
  148. ticker := time.NewTicker(time.Duration(interval) * time.Second)
  149. defer ticker.Stop()
  150. for {
  151. select {
  152. case <-ctx.Done():
  153. return nil, ctx.Err()
  154. case <-ticker.C:
  155. tok, err := retrieveToken(ctx, c, v)
  156. if err == nil {
  157. return tok, nil
  158. }
  159. e, ok := err.(*RetrieveError)
  160. if !ok {
  161. return nil, err
  162. }
  163. switch e.ErrorCode {
  164. case errSlowDown:
  165. // https://datatracker.ietf.org/doc/html/rfc8628#section-3.5
  166. // "the interval MUST be increased by 5 seconds for this and all subsequent requests"
  167. interval += 5
  168. ticker.Reset(time.Duration(interval) * time.Second)
  169. case errAuthorizationPending:
  170. // Do nothing.
  171. case errAccessDenied, errExpiredToken:
  172. fallthrough
  173. default:
  174. return tok, err
  175. }
  176. }
  177. }
  178. }