credentials_test.go 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243
  1. /*
  2. *
  3. * Copyright 2016 gRPC authors.
  4. *
  5. * Licensed under the Apache License, Version 2.0 (the "License");
  6. * you may not use this file except in compliance with the License.
  7. * You may obtain a copy of the License at
  8. *
  9. * http://www.apache.org/licenses/LICENSE-2.0
  10. *
  11. * Unless required by applicable law or agreed to in writing, software
  12. * distributed under the License is distributed on an "AS IS" BASIS,
  13. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. * See the License for the specific language governing permissions and
  15. * limitations under the License.
  16. *
  17. */
  18. package credentials
  19. import (
  20. "context"
  21. "crypto/tls"
  22. "net"
  23. "reflect"
  24. "testing"
  25. "google.golang.org/grpc/testdata"
  26. )
  // TestTLSOverrideServerName verifies that OverrideServerName on credentials
  // built by NewTLS succeeds and that the overridden name is reported by
  // c.Info().ServerName.
  func TestTLSOverrideServerName(t *testing.T) {
  	expectedServerName := "server.name"
  	c := NewTLS(nil)
  	// OverrideServerName returns an error; don't silently drop it.
  	if err := c.OverrideServerName(expectedServerName); err != nil {
  		t.Fatalf("c.OverrideServerName(%q) failed: %v", expectedServerName, err)
  	}
  	if got := c.Info().ServerName; got != expectedServerName {
  		t.Fatalf("c.Info().ServerName = %v, want %v", got, expectedServerName)
  	}
  }
  // TestTLSClone verifies that Clone carries over an overridden server name and
  // that overriding the clone's server name changes the clone but leaves the
  // original credentials untouched.
  func TestTLSClone(t *testing.T) {
  	expectedServerName := "server.name"
  	c := NewTLS(nil)
  	if err := c.OverrideServerName(expectedServerName); err != nil {
  		t.Fatalf("c.OverrideServerName(%q) failed: %v", expectedServerName, err)
  	}
  	cc := c.Clone()
  	if cc.Info().ServerName != expectedServerName {
  		t.Fatalf("cc.Info().ServerName = %v, want %v", cc.Info().ServerName, expectedServerName)
  	}
  	if err := cc.OverrideServerName(""); err != nil {
  		t.Fatalf("cc.OverrideServerName(\"\") failed: %v", err)
  	}
  	// The clone itself must observe the change; otherwise the isolation check
  	// on the original below would pass trivially.
  	if got := cc.Info().ServerName; got != "" {
  		t.Fatalf("after cc.OverrideServerName(\"\"), cc.Info().ServerName = %q, want \"\"", got)
  	}
  	if c.Info().ServerName != expectedServerName {
  		t.Fatalf("Change in clone should not affect the original, c.Info().ServerName = %v, want %v", c.Info().ServerName, expectedServerName)
  	}
  }
  // serverHandshake performs the server side of a handshake on an accepted raw
  // connection and returns the AuthInfo it negotiated.
  type serverHandshake func(rawConn net.Conn) (AuthInfo, error)
  // TestClientHandshakeReturnsAuthInfo checks that the AuthInfo produced by the
  // gRPC client-side handshake matches the AuthInfo a plain crypto/tls server
  // reports for the same connection.
  func TestClientHandshakeReturnsAuthInfo(t *testing.T) {
  	serverResult := make(chan AuthInfo, 1)
  	lis := launchServer(t, tlsServerHandshake, serverResult)
  	defer lis.Close()

  	addr := lis.Addr().String()
  	clientInfo := clientHandle(t, gRPCClientHandshake, addr)

  	// Block until the server goroutine delivers its AuthInfo, or closes the
  	// channel to signal that it failed.
  	switch serverInfo, ok := <-serverResult; {
  	case !ok:
  		t.Fatalf("Error at server-side")
  	case !compare(clientInfo, serverInfo):
  		t.Fatalf("c.ClientHandshake(_, %v, _) = %v, want %v.", addr, clientInfo, serverInfo)
  	}
  }
  // TestServerHandshakeReturnsAuthInfo checks that the AuthInfo produced by the
  // gRPC server-side handshake matches the AuthInfo a plain crypto/tls client
  // reports for the same connection.
  func TestServerHandshakeReturnsAuthInfo(t *testing.T) {
  	results := make(chan AuthInfo, 1)
  	lis := launchServer(t, gRPCServerHandshake, results)
  	defer lis.Close()

  	clientInfo := clientHandle(t, tlsClientHandshake, lis.Addr().String())

  	// The server goroutine either sends its AuthInfo or closes results on error.
  	switch serverInfo, ok := <-results; {
  	case !ok:
  		t.Fatalf("Error at server-side")
  	case !compare(clientInfo, serverInfo):
  		t.Fatalf("ServerHandshake(_) = %v, want %v.", serverInfo, clientInfo)
  	}
  }
  // TestServerAndClientHandshake runs gRPC's TLS handshake on both ends of one
  // connection and checks that both sides report equivalent AuthInfo.
  func TestServerAndClientHandshake(t *testing.T) {
  	results := make(chan AuthInfo, 1)
  	lis := launchServer(t, gRPCServerHandshake, results)
  	defer lis.Close()

  	clientInfo := clientHandle(t, gRPCClientHandshake, lis.Addr().String())

  	// A closed channel (ok == false) means the server side already failed.
  	serverInfo, ok := <-results
  	switch {
  	case !ok:
  		t.Fatalf("Error at server-side")
  	case !compare(clientInfo, serverInfo):
  		t.Fatalf("AuthInfo returned by server: %v and client: %v aren't same", serverInfo, clientInfo)
  	}
  }
  92. func compare(a1, a2 AuthInfo) bool {
  93. if a1.AuthType() != a2.AuthType() {
  94. return false
  95. }
  96. switch a1.AuthType() {
  97. case "tls":
  98. state1 := a1.(TLSInfo).State
  99. state2 := a2.(TLSInfo).State
  100. if state1.Version == state2.Version &&
  101. state1.HandshakeComplete == state2.HandshakeComplete &&
  102. state1.CipherSuite == state2.CipherSuite &&
  103. state1.NegotiatedProtocol == state2.NegotiatedProtocol {
  104. return true
  105. }
  106. return false
  107. default:
  108. return false
  109. }
  110. }
  // launchServer listens on an ephemeral localhost TCP port and starts
  // serverHandle in a new goroutine. That goroutine accepts one connection, runs
  // hs on it, and either sends the resulting AuthInfo on done or closes done on
  // failure. The caller owns the returned listener and must Close it.
  func launchServer(t *testing.T, hs serverHandshake, done chan AuthInfo) net.Listener {
  	// Attribute a listen failure to the calling test, not to this helper.
  	t.Helper()
  	lis, err := net.Listen("tcp", "localhost:0")
  	if err != nil {
  		t.Fatalf("Failed to listen: %v", err)
  	}
  	go serverHandle(t, hs, done, lis)
  	return lis
  }
  // serverHandle accepts a single connection on lis, runs the server-side
  // handshake hs on it, and sends the resulting AuthInfo on done.
  // On any failure it reports the error with t.Errorf and closes done, so the
  // waiting test observes a closed channel. It uses t.Errorf rather than
  // t.Fatalf because Fatalf must only be called from the test goroutine.
  // Is run in a separate goroutine.
  func serverHandle(t *testing.T, hs serverHandshake, done chan AuthInfo, lis net.Listener) {
  	serverRawConn, err := lis.Accept()
  	if err != nil {
  		t.Errorf("Server failed to accept connection: %v", err)
  		close(done)
  		return
  	}
  	// Close the raw connection on every return path so that a successful
  	// handshake does not leak it.
  	defer serverRawConn.Close()

  	serverAuthInfo, err := hs(serverRawConn)
  	if err != nil {
  		t.Errorf("Server failed while handshake. Error: %v", err)
  		close(done)
  		return
  	}
  	done <- serverAuthInfo
  }
  // clientHandle dials lisAddr over TCP, runs the client-side handshake hs on
  // the connection (passing lisAddr through), and returns the resulting
  // AuthInfo. The TCP connection is closed before returning. A dial or
  // handshake failure aborts the calling test.
  func clientHandle(t *testing.T, hs func(net.Conn, string) (AuthInfo, error), lisAddr string) AuthInfo {
  	// Report failures at the calling test's line, not inside this helper.
  	t.Helper()
  	conn, err := net.Dial("tcp", lisAddr)
  	if err != nil {
  		t.Fatalf("Client failed to connect to %s. Error: %v", lisAddr, err)
  	}
  	defer conn.Close()
  	clientAuthInfo, err := hs(conn, lisAddr)
  	if err != nil {
  		t.Fatalf("Error on client while handshake. Error: %v", err)
  	}
  	return clientAuthInfo
  }
  // gRPCServerHandshake performs the server side of the handshake through
  // gRPC's own TLS credentials (built by NewServerTLSFromFile from the server1
  // test key pair). It returns the AuthInfo those credentials report. The
  // wrapped connection returned by ServerHandshake is not used.
  func gRPCServerHandshake(conn net.Conn) (AuthInfo, error) {
  	certFile, keyFile := testdata.Path("server1.pem"), testdata.Path("server1.key")
  	creds, err := NewServerTLSFromFile(certFile, keyFile)
  	if err != nil {
  		return nil, err
  	}
  	_, info, err := creds.ServerHandshake(conn)
  	if err != nil {
  		return nil, err
  	}
  	return info, nil
  }
  // gRPCClientHandshake performs the client side of the handshake through
  // gRPC's own TLS credentials (NewTLS) with certificate verification disabled
  // (InsecureSkipVerify). It passes lisAddr to ClientHandshake and returns the
  // AuthInfo the credentials report.
  func gRPCClientHandshake(conn net.Conn, lisAddr string) (AuthInfo, error) {
  	cfg := &tls.Config{InsecureSkipVerify: true}
  	creds := NewTLS(cfg)
  	ctx := context.Background()
  	_, info, err := creds.ClientHandshake(ctx, lisAddr, conn)
  	if err != nil {
  		return nil, err
  	}
  	return info, nil
  }
  // tlsServerHandshake performs a server-side handshake directly with
  // crypto/tls, bypassing gRPC's credentials. It uses the server1 test key pair
  // and wraps the resulting connection state in a TLSInfo.
  func tlsServerHandshake(conn net.Conn) (AuthInfo, error) {
  	keyPair, err := tls.LoadX509KeyPair(testdata.Path("server1.pem"), testdata.Path("server1.key"))
  	if err != nil {
  		return nil, err
  	}
  	tlsConn := tls.Server(conn, &tls.Config{Certificates: []tls.Certificate{keyPair}})
  	if err := tlsConn.Handshake(); err != nil {
  		return nil, err
  	}
  	return TLSInfo{State: tlsConn.ConnectionState()}, nil
  }
  // tlsClientHandshake performs a client-side handshake directly with
  // crypto/tls, with certificate verification disabled (InsecureSkipVerify),
  // and wraps the resulting connection state in a TLSInfo. The address
  // argument is ignored.
  func tlsClientHandshake(conn net.Conn, _ string) (AuthInfo, error) {
  	tlsConn := tls.Client(conn, &tls.Config{InsecureSkipVerify: true})
  	err := tlsConn.Handshake()
  	if err != nil {
  		return nil, err
  	}
  	return TLSInfo{State: tlsConn.ConnectionState()}, nil
  }
  // TestAppendH2ToNextProtos checks that appendH2ToNextProtos appends "h2" to a
  // protocol list that lacks it (including a nil list). It also checks that
  // lists which already contain "h2" come back unchanged.
  func TestAppendH2ToNextProtos(t *testing.T) {
  	for _, tc := range []struct {
  		name   string
  		in     []string
  		expect []string
  	}{
  		{name: "empty", in: nil, expect: []string{"h2"}},
  		{name: "only h2", in: []string{"h2"}, expect: []string{"h2"}},
  		{name: "with h2", in: []string{"alpn", "h2"}, expect: []string{"alpn", "h2"}},
  		{name: "no h2", in: []string{"alpn"}, expect: []string{"alpn", "h2"}},
  	} {
  		tc := tc
  		t.Run(tc.name, func(t *testing.T) {
  			got := appendH2ToNextProtos(tc.in)
  			if !reflect.DeepEqual(got, tc.expect) {
  				t.Errorf("appendH2ToNextProtos() = %v, want %v", got, tc.expect)
  			}
  		})
  	}
  }