sts_test.go 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764
  1. // +build go1.13
  2. /*
  3. *
  4. * Copyright 2020 gRPC authors.
  5. *
  6. * Licensed under the Apache License, Version 2.0 (the "License");
  7. * you may not use this file except in compliance with the License.
  8. * You may obtain a copy of the License at
  9. *
  10. * http://www.apache.org/licenses/LICENSE-2.0
  11. *
  12. * Unless required by applicable law or agreed to in writing, software
  13. * distributed under the License is distributed on an "AS IS" BASIS,
  14. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. * See the License for the specific language governing permissions and
  16. * limitations under the License.
  17. *
  18. */
  19. package sts
  20. import (
  21. "bytes"
  22. "context"
  23. "crypto/x509"
  24. "encoding/json"
  25. "errors"
  26. "fmt"
  27. "io/ioutil"
  28. "net/http"
  29. "net/http/httputil"
  30. "strings"
  31. "testing"
  32. "time"
  33. "github.com/google/go-cmp/cmp"
  34. "google.golang.org/grpc/credentials"
  35. "google.golang.org/grpc/internal"
  36. "google.golang.org/grpc/internal/grpctest"
  37. "google.golang.org/grpc/internal/testutils"
  38. )
  39. const (
  40. requestedTokenType = "urn:ietf:params:oauth:token-type:access-token"
  41. actorTokenPath = "/var/run/secrets/token.jwt"
  42. actorTokenType = "urn:ietf:params:oauth:token-type:refresh_token"
  43. actorTokenContents = "actorToken.jwt.contents"
  44. accessTokenContents = "access_token"
  45. subjectTokenPath = "/var/run/secrets/token.jwt"
  46. subjectTokenType = "urn:ietf:params:oauth:token-type:id_token"
  47. subjectTokenContents = "subjectToken.jwt.contents"
  48. serviceURI = "http://localhost"
  49. exampleResource = "https://backend.example.com/api"
  50. exampleAudience = "example-backend-service"
  51. testScope = "https://www.googleapis.com/auth/monitoring"
  52. )
  53. var (
  54. goodOptions = Options{
  55. TokenExchangeServiceURI: serviceURI,
  56. Audience: exampleAudience,
  57. RequestedTokenType: requestedTokenType,
  58. SubjectTokenPath: subjectTokenPath,
  59. SubjectTokenType: subjectTokenType,
  60. }
  61. goodRequestParams = &requestParameters{
  62. GrantType: tokenExchangeGrantType,
  63. Audience: exampleAudience,
  64. Scope: defaultCloudPlatformScope,
  65. RequestedTokenType: requestedTokenType,
  66. SubjectToken: subjectTokenContents,
  67. SubjectTokenType: subjectTokenType,
  68. }
  69. goodMetadata = map[string]string{
  70. "Authorization": fmt.Sprintf("Bearer %s", accessTokenContents),
  71. }
  72. )
  73. type s struct {
  74. grpctest.Tester
  75. }
  76. func Test(t *testing.T) {
  77. grpctest.RunSubTests(t, s{})
  78. }
  79. // A struct that implements AuthInfo interface and added to the context passed
  80. // to GetRequestMetadata from tests.
  81. type testAuthInfo struct {
  82. credentials.CommonAuthInfo
  83. }
  84. func (ta testAuthInfo) AuthType() string {
  85. return "testAuthInfo"
  86. }
  87. func createTestContext(ctx context.Context, s credentials.SecurityLevel) context.Context {
  88. auth := &testAuthInfo{CommonAuthInfo: credentials.CommonAuthInfo{SecurityLevel: s}}
  89. ri := credentials.RequestInfo{
  90. Method: "testInfo",
  91. AuthInfo: auth,
  92. }
  93. return internal.NewRequestInfoContext.(func(context.Context, credentials.RequestInfo) context.Context)(ctx, ri)
  94. }
  95. // errReader implements the io.Reader interface and returns an error from the
  96. // Read method.
  97. type errReader struct{}
  98. func (r errReader) Read(b []byte) (n int, err error) {
  99. return 0, errors.New("read error")
  100. }
  101. // We need a function to construct the response instead of simply declaring it
  102. // as a variable since the the response body will be consumed by the
  103. // credentials, and therefore we will need a new one everytime.
  104. func makeGoodResponse() *http.Response {
  105. respJSON, _ := json.Marshal(responseParameters{
  106. AccessToken: accessTokenContents,
  107. IssuedTokenType: "urn:ietf:params:oauth:token-type:access_token",
  108. TokenType: "Bearer",
  109. ExpiresIn: 3600,
  110. })
  111. respBody := ioutil.NopCloser(bytes.NewReader(respJSON))
  112. return &http.Response{
  113. Status: "200 OK",
  114. StatusCode: http.StatusOK,
  115. Body: respBody,
  116. }
  117. }
  118. // fakeHTTPDoer helps mock out the http.Client.Do calls made by the credentials
  119. // code under test. It makes the http.Request made by the credentials available
  120. // through a channel, and makes it possible to inject various responses.
  121. type fakeHTTPDoer struct {
  122. reqCh *testutils.Channel
  123. respCh *testutils.Channel
  124. err error
  125. }
  126. func (fc *fakeHTTPDoer) Do(req *http.Request) (*http.Response, error) {
  127. fc.reqCh.Send(req)
  128. val, err := fc.respCh.Receive()
  129. if err != nil {
  130. return nil, err
  131. }
  132. return val.(*http.Response), fc.err
  133. }
  134. // Overrides the http.Client with a fakeClient which sends a good response.
  135. func overrideHTTPClientGood() (*fakeHTTPDoer, func()) {
  136. fc := &fakeHTTPDoer{
  137. reqCh: testutils.NewChannel(),
  138. respCh: testutils.NewChannel(),
  139. }
  140. fc.respCh.Send(makeGoodResponse())
  141. origMakeHTTPDoer := makeHTTPDoer
  142. makeHTTPDoer = func(_ *x509.CertPool) httpDoer { return fc }
  143. return fc, func() { makeHTTPDoer = origMakeHTTPDoer }
  144. }
  145. // Overrides the http.Client with the provided fakeClient.
  146. func overrideHTTPClient(fc *fakeHTTPDoer) func() {
  147. origMakeHTTPDoer := makeHTTPDoer
  148. makeHTTPDoer = func(_ *x509.CertPool) httpDoer { return fc }
  149. return func() { makeHTTPDoer = origMakeHTTPDoer }
  150. }
  151. // Overrides the subject token read to return a const which we can compare in
  152. // our tests.
  153. func overrideSubjectTokenGood() func() {
  154. origReadSubjectTokenFrom := readSubjectTokenFrom
  155. readSubjectTokenFrom = func(path string) ([]byte, error) {
  156. return []byte(subjectTokenContents), nil
  157. }
  158. return func() { readSubjectTokenFrom = origReadSubjectTokenFrom }
  159. }
  160. // Overrides the subject token read to always return an error.
  161. func overrideSubjectTokenError() func() {
  162. origReadSubjectTokenFrom := readSubjectTokenFrom
  163. readSubjectTokenFrom = func(path string) ([]byte, error) {
  164. return nil, errors.New("error reading subject token")
  165. }
  166. return func() { readSubjectTokenFrom = origReadSubjectTokenFrom }
  167. }
  168. // Overrides the actor token read to return a const which we can compare in
  169. // our tests.
  170. func overrideActorTokenGood() func() {
  171. origReadActorTokenFrom := readActorTokenFrom
  172. readActorTokenFrom = func(path string) ([]byte, error) {
  173. return []byte(actorTokenContents), nil
  174. }
  175. return func() { readActorTokenFrom = origReadActorTokenFrom }
  176. }
  177. // Overrides the actor token read to always return an error.
  178. func overrideActorTokenError() func() {
  179. origReadActorTokenFrom := readActorTokenFrom
  180. readActorTokenFrom = func(path string) ([]byte, error) {
  181. return nil, errors.New("error reading actor token")
  182. }
  183. return func() { readActorTokenFrom = origReadActorTokenFrom }
  184. }
  185. // compareRequest compares the http.Request received in the test with the
  186. // expected requestParameters specified in wantReqParams.
  187. func compareRequest(gotRequest *http.Request, wantReqParams *requestParameters) error {
  188. jsonBody, err := json.Marshal(wantReqParams)
  189. if err != nil {
  190. return err
  191. }
  192. wantReq, err := http.NewRequest("POST", serviceURI, bytes.NewBuffer(jsonBody))
  193. if err != nil {
  194. return fmt.Errorf("failed to create http request: %v", err)
  195. }
  196. wantReq.Header.Set("Content-Type", "application/json")
  197. wantR, err := httputil.DumpRequestOut(wantReq, true)
  198. if err != nil {
  199. return err
  200. }
  201. gotR, err := httputil.DumpRequestOut(gotRequest, true)
  202. if err != nil {
  203. return err
  204. }
  205. if diff := cmp.Diff(string(wantR), string(gotR)); diff != "" {
  206. return fmt.Errorf("sts request diff (-want +got):\n%s", diff)
  207. }
  208. return nil
  209. }
  210. // receiveAndCompareRequest waits for a request to be sent out by the
  211. // credentials implementation using the fakeHTTPClient and compares it to an
  212. // expected goodRequest. This is expected to be called in a separate goroutine
  213. // by the tests. So, any errors encountered are pushed to an error channel
  214. // which is monitored by the test.
  215. func receiveAndCompareRequest(reqCh *testutils.Channel, errCh chan error) {
  216. val, err := reqCh.Receive()
  217. if err != nil {
  218. errCh <- err
  219. return
  220. }
  221. req := val.(*http.Request)
  222. if err := compareRequest(req, goodRequestParams); err != nil {
  223. errCh <- err
  224. return
  225. }
  226. errCh <- nil
  227. }
  228. // TestGetRequestMetadataSuccess verifies the successful case of sending an
  229. // token exchange request and processing the response.
  230. func (s) TestGetRequestMetadataSuccess(t *testing.T) {
  231. defer overrideSubjectTokenGood()()
  232. fc, cancel := overrideHTTPClientGood()
  233. defer cancel()
  234. creds, err := NewCredentials(goodOptions)
  235. if err != nil {
  236. t.Fatalf("NewCredentials(%v) = %v", goodOptions, err)
  237. }
  238. errCh := make(chan error, 1)
  239. go receiveAndCompareRequest(fc.reqCh, errCh)
  240. gotMetadata, err := creds.GetRequestMetadata(createTestContext(context.Background(), credentials.PrivacyAndIntegrity), "")
  241. if err != nil {
  242. t.Fatalf("creds.GetRequestMetadata() = %v", err)
  243. }
  244. if !cmp.Equal(gotMetadata, goodMetadata) {
  245. t.Fatalf("creds.GetRequestMetadata() = %v, want %v", gotMetadata, goodMetadata)
  246. }
  247. if err := <-errCh; err != nil {
  248. t.Fatal(err)
  249. }
  250. // Make another call to get request metadata and this should return contents
  251. // from the cache. This will fail if the credentials tries to send a fresh
  252. // request here since we have not configured our fakeClient to return any
  253. // response on retries.
  254. gotMetadata, err = creds.GetRequestMetadata(createTestContext(context.Background(), credentials.PrivacyAndIntegrity), "")
  255. if err != nil {
  256. t.Fatalf("creds.GetRequestMetadata() = %v", err)
  257. }
  258. if !cmp.Equal(gotMetadata, goodMetadata) {
  259. t.Fatalf("creds.GetRequestMetadata() = %v, want %v", gotMetadata, goodMetadata)
  260. }
  261. }
  262. // TestGetRequestMetadataBadSecurityLevel verifies the case where the
  263. // securityLevel specified in the context passed to GetRequestMetadata is not
  264. // sufficient.
  265. func (s) TestGetRequestMetadataBadSecurityLevel(t *testing.T) {
  266. defer overrideSubjectTokenGood()()
  267. creds, err := NewCredentials(goodOptions)
  268. if err != nil {
  269. t.Fatalf("NewCredentials(%v) = %v", goodOptions, err)
  270. }
  271. gotMetadata, err := creds.GetRequestMetadata(createTestContext(context.Background(), credentials.IntegrityOnly), "")
  272. if err == nil {
  273. t.Fatalf("creds.GetRequestMetadata() succeeded with metadata %v, expected to fail", gotMetadata)
  274. }
  275. }
  276. // TestGetRequestMetadataCacheExpiry verifies the case where the cached access
  277. // token has expired, and the credentials implementation will have to send a
  278. // fresh token exchange request.
  279. func (s) TestGetRequestMetadataCacheExpiry(t *testing.T) {
  280. const expiresInSecs = 1
  281. defer overrideSubjectTokenGood()()
  282. fc := &fakeHTTPDoer{
  283. reqCh: testutils.NewChannel(),
  284. respCh: testutils.NewChannel(),
  285. }
  286. defer overrideHTTPClient(fc)()
  287. creds, err := NewCredentials(goodOptions)
  288. if err != nil {
  289. t.Fatalf("NewCredentials(%v) = %v", goodOptions, err)
  290. }
  291. // The fakeClient is configured to return an access_token with a one second
  292. // expiry. So, in the second iteration, the credentials will find the cache
  293. // entry, but that would have expired, and therefore we expect it to send
  294. // out a fresh request.
  295. for i := 0; i < 2; i++ {
  296. errCh := make(chan error, 1)
  297. go receiveAndCompareRequest(fc.reqCh, errCh)
  298. respJSON, _ := json.Marshal(responseParameters{
  299. AccessToken: accessTokenContents,
  300. IssuedTokenType: "urn:ietf:params:oauth:token-type:access_token",
  301. TokenType: "Bearer",
  302. ExpiresIn: expiresInSecs,
  303. })
  304. respBody := ioutil.NopCloser(bytes.NewReader(respJSON))
  305. resp := &http.Response{
  306. Status: "200 OK",
  307. StatusCode: http.StatusOK,
  308. Body: respBody,
  309. }
  310. fc.respCh.Send(resp)
  311. gotMetadata, err := creds.GetRequestMetadata(createTestContext(context.Background(), credentials.PrivacyAndIntegrity), "")
  312. if err != nil {
  313. t.Fatalf("creds.GetRequestMetadata() = %v", err)
  314. }
  315. if !cmp.Equal(gotMetadata, goodMetadata) {
  316. t.Fatalf("creds.GetRequestMetadata() = %v, want %v", gotMetadata, goodMetadata)
  317. }
  318. if err := <-errCh; err != nil {
  319. t.Fatal(err)
  320. }
  321. time.Sleep(expiresInSecs * time.Second)
  322. }
  323. }
  324. // TestGetRequestMetadataBadResponses verifies the scenario where the token
  325. // exchange server returns bad responses.
  326. func (s) TestGetRequestMetadataBadResponses(t *testing.T) {
  327. tests := []struct {
  328. name string
  329. response *http.Response
  330. }{
  331. {
  332. name: "bad JSON",
  333. response: &http.Response{
  334. Status: "200 OK",
  335. StatusCode: http.StatusOK,
  336. Body: ioutil.NopCloser(strings.NewReader("not JSON")),
  337. },
  338. },
  339. {
  340. name: "no access token",
  341. response: &http.Response{
  342. Status: "200 OK",
  343. StatusCode: http.StatusOK,
  344. Body: ioutil.NopCloser(strings.NewReader("{}")),
  345. },
  346. },
  347. }
  348. for _, test := range tests {
  349. t.Run(test.name, func(t *testing.T) {
  350. defer overrideSubjectTokenGood()()
  351. fc := &fakeHTTPDoer{
  352. reqCh: testutils.NewChannel(),
  353. respCh: testutils.NewChannel(),
  354. }
  355. defer overrideHTTPClient(fc)()
  356. creds, err := NewCredentials(goodOptions)
  357. if err != nil {
  358. t.Fatalf("NewCredentials(%v) = %v", goodOptions, err)
  359. }
  360. errCh := make(chan error, 1)
  361. go receiveAndCompareRequest(fc.reqCh, errCh)
  362. fc.respCh.Send(test.response)
  363. if _, err := creds.GetRequestMetadata(createTestContext(context.Background(), credentials.PrivacyAndIntegrity), ""); err == nil {
  364. t.Fatal("creds.GetRequestMetadata() succeeded when expected to fail")
  365. }
  366. if err := <-errCh; err != nil {
  367. t.Fatal(err)
  368. }
  369. })
  370. }
  371. }
  372. // TestGetRequestMetadataBadSubjectTokenRead verifies the scenario where the
  373. // attempt to read the subjectToken fails.
  374. func (s) TestGetRequestMetadataBadSubjectTokenRead(t *testing.T) {
  375. defer overrideSubjectTokenError()()
  376. fc, cancel := overrideHTTPClientGood()
  377. defer cancel()
  378. creds, err := NewCredentials(goodOptions)
  379. if err != nil {
  380. t.Fatalf("NewCredentials(%v) = %v", goodOptions, err)
  381. }
  382. errCh := make(chan error, 1)
  383. go func() {
  384. if _, err := fc.reqCh.Receive(); err != testutils.ErrRecvTimeout {
  385. errCh <- err
  386. return
  387. }
  388. errCh <- nil
  389. }()
  390. if _, err := creds.GetRequestMetadata(createTestContext(context.Background(), credentials.PrivacyAndIntegrity), ""); err == nil {
  391. t.Fatal("creds.GetRequestMetadata() succeeded when expected to fail")
  392. }
  393. if err := <-errCh; err != nil {
  394. t.Fatal(err)
  395. }
  396. }
  397. func (s) TestNewCredentials(t *testing.T) {
  398. tests := []struct {
  399. name string
  400. opts Options
  401. errSystemRoots bool
  402. wantErr bool
  403. }{
  404. {
  405. name: "invalid options - empty subjectTokenPath",
  406. opts: Options{
  407. TokenExchangeServiceURI: serviceURI,
  408. },
  409. wantErr: true,
  410. },
  411. {
  412. name: "invalid system root certs",
  413. opts: goodOptions,
  414. errSystemRoots: true,
  415. wantErr: true,
  416. },
  417. {
  418. name: "good case",
  419. opts: goodOptions,
  420. },
  421. }
  422. for _, test := range tests {
  423. t.Run(test.name, func(t *testing.T) {
  424. if test.errSystemRoots {
  425. oldSystemRoots := loadSystemCertPool
  426. loadSystemCertPool = func() (*x509.CertPool, error) {
  427. return nil, errors.New("failed to load system cert pool")
  428. }
  429. defer func() {
  430. loadSystemCertPool = oldSystemRoots
  431. }()
  432. }
  433. creds, err := NewCredentials(test.opts)
  434. if (err != nil) != test.wantErr {
  435. t.Fatalf("NewCredentials(%v) = %v, want %v", test.opts, err, test.wantErr)
  436. }
  437. if err == nil {
  438. if !creds.RequireTransportSecurity() {
  439. t.Errorf("creds.RequireTransportSecurity() returned false")
  440. }
  441. }
  442. })
  443. }
  444. }
  445. func (s) TestValidateOptions(t *testing.T) {
  446. tests := []struct {
  447. name string
  448. opts Options
  449. wantErrPrefix string
  450. }{
  451. {
  452. name: "empty token exchange service URI",
  453. opts: Options{},
  454. wantErrPrefix: "empty token_exchange_service_uri in options",
  455. },
  456. {
  457. name: "invalid URI",
  458. opts: Options{
  459. TokenExchangeServiceURI: "\tI'm a bad URI\n",
  460. },
  461. wantErrPrefix: "invalid control character in URL",
  462. },
  463. {
  464. name: "unsupported scheme",
  465. opts: Options{
  466. TokenExchangeServiceURI: "unix:///path/to/socket",
  467. },
  468. wantErrPrefix: "scheme is not supported",
  469. },
  470. {
  471. name: "empty subjectTokenPath",
  472. opts: Options{
  473. TokenExchangeServiceURI: serviceURI,
  474. },
  475. wantErrPrefix: "required field SubjectTokenPath is not specified",
  476. },
  477. {
  478. name: "empty subjectTokenType",
  479. opts: Options{
  480. TokenExchangeServiceURI: serviceURI,
  481. SubjectTokenPath: subjectTokenPath,
  482. },
  483. wantErrPrefix: "required field SubjectTokenType is not specified",
  484. },
  485. {
  486. name: "good options",
  487. opts: goodOptions,
  488. },
  489. }
  490. for _, test := range tests {
  491. t.Run(test.name, func(t *testing.T) {
  492. err := validateOptions(test.opts)
  493. if (err != nil) != (test.wantErrPrefix != "") {
  494. t.Errorf("validateOptions(%v) = %v, want %v", test.opts, err, test.wantErrPrefix)
  495. }
  496. if err != nil && !strings.Contains(err.Error(), test.wantErrPrefix) {
  497. t.Errorf("validateOptions(%v) = %v, want %v", test.opts, err, test.wantErrPrefix)
  498. }
  499. })
  500. }
  501. }
  502. func (s) TestConstructRequest(t *testing.T) {
  503. tests := []struct {
  504. name string
  505. opts Options
  506. subjectTokenReadErr bool
  507. actorTokenReadErr bool
  508. wantReqParams *requestParameters
  509. wantErr bool
  510. }{
  511. {
  512. name: "subject token read failure",
  513. subjectTokenReadErr: true,
  514. opts: goodOptions,
  515. wantErr: true,
  516. },
  517. {
  518. name: "actor token read failure",
  519. actorTokenReadErr: true,
  520. opts: Options{
  521. TokenExchangeServiceURI: serviceURI,
  522. Audience: exampleAudience,
  523. RequestedTokenType: requestedTokenType,
  524. SubjectTokenPath: subjectTokenPath,
  525. SubjectTokenType: subjectTokenType,
  526. ActorTokenPath: actorTokenPath,
  527. ActorTokenType: actorTokenType,
  528. },
  529. wantErr: true,
  530. },
  531. {
  532. name: "default cloud platform scope",
  533. opts: goodOptions,
  534. wantReqParams: goodRequestParams,
  535. },
  536. {
  537. name: "all good",
  538. opts: Options{
  539. TokenExchangeServiceURI: serviceURI,
  540. Resource: exampleResource,
  541. Audience: exampleAudience,
  542. Scope: testScope,
  543. RequestedTokenType: requestedTokenType,
  544. SubjectTokenPath: subjectTokenPath,
  545. SubjectTokenType: subjectTokenType,
  546. ActorTokenPath: actorTokenPath,
  547. ActorTokenType: actorTokenType,
  548. },
  549. wantReqParams: &requestParameters{
  550. GrantType: tokenExchangeGrantType,
  551. Resource: exampleResource,
  552. Audience: exampleAudience,
  553. Scope: testScope,
  554. RequestedTokenType: requestedTokenType,
  555. SubjectToken: subjectTokenContents,
  556. SubjectTokenType: subjectTokenType,
  557. ActorToken: actorTokenContents,
  558. ActorTokenType: actorTokenType,
  559. },
  560. },
  561. }
  562. for _, test := range tests {
  563. t.Run(test.name, func(t *testing.T) {
  564. if test.subjectTokenReadErr {
  565. defer overrideSubjectTokenError()()
  566. } else {
  567. defer overrideSubjectTokenGood()()
  568. }
  569. if test.actorTokenReadErr {
  570. defer overrideActorTokenError()()
  571. } else {
  572. defer overrideActorTokenGood()()
  573. }
  574. gotRequest, err := constructRequest(context.Background(), test.opts)
  575. if (err != nil) != test.wantErr {
  576. t.Fatalf("constructRequest(%v) = %v, wantErr: %v", test.opts, err, test.wantErr)
  577. }
  578. if test.wantErr {
  579. return
  580. }
  581. if err := compareRequest(gotRequest, test.wantReqParams); err != nil {
  582. t.Fatal(err)
  583. }
  584. })
  585. }
  586. }
  587. func (s) TestSendRequest(t *testing.T) {
  588. defer overrideSubjectTokenGood()()
  589. req, err := constructRequest(context.Background(), goodOptions)
  590. if err != nil {
  591. t.Fatal(err)
  592. }
  593. tests := []struct {
  594. name string
  595. resp *http.Response
  596. respErr error
  597. wantErr bool
  598. }{
  599. {
  600. name: "client error",
  601. respErr: errors.New("http.Client.Do failed"),
  602. wantErr: true,
  603. },
  604. {
  605. name: "bad response body",
  606. resp: &http.Response{
  607. Status: "200 OK",
  608. StatusCode: http.StatusOK,
  609. Body: ioutil.NopCloser(errReader{}),
  610. },
  611. wantErr: true,
  612. },
  613. {
  614. name: "nonOK status code",
  615. resp: &http.Response{
  616. Status: "400 BadRequest",
  617. StatusCode: http.StatusBadRequest,
  618. Body: ioutil.NopCloser(strings.NewReader("")),
  619. },
  620. wantErr: true,
  621. },
  622. {
  623. name: "good case",
  624. resp: makeGoodResponse(),
  625. },
  626. }
  627. for _, test := range tests {
  628. t.Run(test.name, func(t *testing.T) {
  629. client := &fakeHTTPDoer{
  630. reqCh: testutils.NewChannel(),
  631. respCh: testutils.NewChannel(),
  632. err: test.respErr,
  633. }
  634. client.respCh.Send(test.resp)
  635. _, err := sendRequest(client, req)
  636. if (err != nil) != test.wantErr {
  637. t.Errorf("sendRequest(%v) = %v, wantErr: %v", req, err, test.wantErr)
  638. }
  639. })
  640. }
  641. }
  642. func (s) TestTokenInfoFromResponse(t *testing.T) {
  643. noAccessToken, _ := json.Marshal(responseParameters{
  644. IssuedTokenType: "urn:ietf:params:oauth:token-type:access_token",
  645. TokenType: "Bearer",
  646. ExpiresIn: 3600,
  647. })
  648. goodResponse, _ := json.Marshal(responseParameters{
  649. IssuedTokenType: requestedTokenType,
  650. AccessToken: accessTokenContents,
  651. TokenType: "Bearer",
  652. ExpiresIn: 3600,
  653. })
  654. tests := []struct {
  655. name string
  656. respBody []byte
  657. wantTokenInfo *tokenInfo
  658. wantErr bool
  659. }{
  660. {
  661. name: "bad JSON",
  662. respBody: []byte("not JSON"),
  663. wantErr: true,
  664. },
  665. {
  666. name: "empty response",
  667. respBody: []byte(""),
  668. wantErr: true,
  669. },
  670. {
  671. name: "non-empty response with no access token",
  672. respBody: noAccessToken,
  673. wantErr: true,
  674. },
  675. {
  676. name: "good response",
  677. respBody: goodResponse,
  678. wantTokenInfo: &tokenInfo{
  679. tokenType: "Bearer",
  680. token: accessTokenContents,
  681. },
  682. },
  683. }
  684. for _, test := range tests {
  685. t.Run(test.name, func(t *testing.T) {
  686. gotTokenInfo, err := tokenInfoFromResponse(test.respBody)
  687. if (err != nil) != test.wantErr {
  688. t.Fatalf("tokenInfoFromResponse(%+v) = %v, wantErr: %v", test.respBody, err, test.wantErr)
  689. }
  690. if test.wantErr {
  691. return
  692. }
  693. // Can't do a cmp.Equal on the whole struct since the expiryField
  694. // is populated based on time.Now().
  695. if gotTokenInfo.tokenType != test.wantTokenInfo.tokenType || gotTokenInfo.token != test.wantTokenInfo.token {
  696. t.Errorf("tokenInfoFromResponse(%+v) = %+v, want: %+v", test.respBody, gotTokenInfo, test.wantTokenInfo)
  697. }
  698. })
  699. }
  700. }