local_test.go 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225
  1. /*
  2. *
  3. * Copyright 2020 gRPC authors.
  4. *
  5. * Licensed under the Apache License, Version 2.0 (the "License");
  6. * you may not use this file except in compliance with the License.
  7. * You may obtain a copy of the License at
  8. *
  9. * http://www.apache.org/licenses/LICENSE-2.0
  10. *
  11. * Unless required by applicable law or agreed to in writing, software
  12. * distributed under the License is distributed on an "AS IS" BASIS,
  13. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. * See the License for the specific language governing permissions and
  15. * limitations under the License.
  16. *
  17. */
  18. package local
  19. import (
  20. "context"
  21. "fmt"
  22. "net"
  23. "runtime"
  24. "strings"
  25. "testing"
  26. "time"
  27. "google.golang.org/grpc/credentials"
  28. "google.golang.org/grpc/internal/grpctest"
  29. )
  30. const defaultTestTimeout = 10 * time.Second
  31. type s struct {
  32. grpctest.Tester
  33. }
  34. func Test(t *testing.T) {
  35. grpctest.RunSubTests(t, s{})
  36. }
  37. func (s) TestGetSecurityLevel(t *testing.T) {
  38. testCases := []struct {
  39. testNetwork string
  40. testAddr string
  41. want credentials.SecurityLevel
  42. }{
  43. {
  44. testNetwork: "tcp",
  45. testAddr: "127.0.0.1:10000",
  46. want: credentials.NoSecurity,
  47. },
  48. {
  49. testNetwork: "tcp",
  50. testAddr: "[::1]:10000",
  51. want: credentials.NoSecurity,
  52. },
  53. {
  54. testNetwork: "unix",
  55. testAddr: "/tmp/grpc_fullstack_test",
  56. want: credentials.PrivacyAndIntegrity,
  57. },
  58. {
  59. testNetwork: "tcp",
  60. testAddr: "192.168.0.1:10000",
  61. want: credentials.InvalidSecurityLevel,
  62. },
  63. }
  64. for _, tc := range testCases {
  65. got, _ := getSecurityLevel(tc.testNetwork, tc.testAddr)
  66. if got != tc.want {
  67. t.Fatalf("GetSeurityLevel(%s, %s) returned %s but want %s", tc.testNetwork, tc.testAddr, got.String(), tc.want.String())
  68. }
  69. }
  70. }
  71. type serverHandshake func(net.Conn) (credentials.AuthInfo, error)
  72. func getSecurityLevelFromAuthInfo(ai credentials.AuthInfo) credentials.SecurityLevel {
  73. if c, ok := ai.(interface {
  74. GetCommonAuthInfo() credentials.CommonAuthInfo
  75. }); ok {
  76. return c.GetCommonAuthInfo().SecurityLevel
  77. }
  78. return credentials.InvalidSecurityLevel
  79. }
  80. // Server local handshake implementation.
  81. func serverLocalHandshake(conn net.Conn) (credentials.AuthInfo, error) {
  82. cred := NewCredentials()
  83. _, authInfo, err := cred.ServerHandshake(conn)
  84. if err != nil {
  85. return nil, err
  86. }
  87. return authInfo, nil
  88. }
  89. // Client local handshake implementation.
  90. func clientLocalHandshake(conn net.Conn, lisAddr string) (credentials.AuthInfo, error) {
  91. cred := NewCredentials()
  92. ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
  93. defer cancel()
  94. _, authInfo, err := cred.ClientHandshake(ctx, lisAddr, conn)
  95. if err != nil {
  96. return nil, err
  97. }
  98. return authInfo, nil
  99. }
  100. // Client connects to a server with local credentials.
  101. func clientHandle(hs func(net.Conn, string) (credentials.AuthInfo, error), network, lisAddr string) (credentials.AuthInfo, error) {
  102. conn, _ := net.Dial(network, lisAddr)
  103. defer conn.Close()
  104. clientAuthInfo, err := hs(conn, lisAddr)
  105. if err != nil {
  106. return nil, fmt.Errorf("Error on client while handshake")
  107. }
  108. return clientAuthInfo, nil
  109. }
  110. type testServerHandleResult struct {
  111. authInfo credentials.AuthInfo
  112. err error
  113. }
  114. // Server accepts a client's connection with local credentials.
  115. func serverHandle(hs serverHandshake, done chan testServerHandleResult, lis net.Listener) {
  116. serverRawConn, err := lis.Accept()
  117. if err != nil {
  118. done <- testServerHandleResult{authInfo: nil, err: fmt.Errorf("Server failed to accept connection. Error: %v", err)}
  119. return
  120. }
  121. serverAuthInfo, err := hs(serverRawConn)
  122. if err != nil {
  123. serverRawConn.Close()
  124. done <- testServerHandleResult{authInfo: nil, err: fmt.Errorf("Server failed while handshake. Error: %v", err)}
  125. return
  126. }
  127. done <- testServerHandleResult{authInfo: serverAuthInfo, err: nil}
  128. }
  129. func serverAndClientHandshake(lis net.Listener) (credentials.SecurityLevel, error) {
  130. done := make(chan testServerHandleResult, 1)
  131. const timeout = 5 * time.Second
  132. timer := time.NewTimer(timeout)
  133. defer timer.Stop()
  134. go serverHandle(serverLocalHandshake, done, lis)
  135. defer lis.Close()
  136. clientAuthInfo, err := clientHandle(clientLocalHandshake, lis.Addr().Network(), lis.Addr().String())
  137. if err != nil {
  138. return credentials.InvalidSecurityLevel, fmt.Errorf("Error at client-side: %v", err)
  139. }
  140. select {
  141. case <-timer.C:
  142. return credentials.InvalidSecurityLevel, fmt.Errorf("Test didn't finish in time")
  143. case serverHandleResult := <-done:
  144. if serverHandleResult.err != nil {
  145. return credentials.InvalidSecurityLevel, fmt.Errorf("Error at server-side: %v", serverHandleResult.err)
  146. }
  147. clientSecLevel := getSecurityLevelFromAuthInfo(clientAuthInfo)
  148. serverSecLevel := getSecurityLevelFromAuthInfo(serverHandleResult.authInfo)
  149. if clientSecLevel == credentials.InvalidSecurityLevel {
  150. return credentials.InvalidSecurityLevel, fmt.Errorf("Error at client-side: client's AuthInfo does not implement GetCommonAuthInfo()")
  151. }
  152. if serverSecLevel == credentials.InvalidSecurityLevel {
  153. return credentials.InvalidSecurityLevel, fmt.Errorf("Error at server-side: server's AuthInfo does not implement GetCommonAuthInfo()")
  154. }
  155. if clientSecLevel != serverSecLevel {
  156. return credentials.InvalidSecurityLevel, fmt.Errorf("client's AuthInfo contains %s but server's AuthInfo contains %s", clientSecLevel.String(), serverSecLevel.String())
  157. }
  158. return clientSecLevel, nil
  159. }
  160. }
  161. func (s) TestServerAndClientHandshake(t *testing.T) {
  162. testCases := []struct {
  163. testNetwork string
  164. testAddr string
  165. want credentials.SecurityLevel
  166. }{
  167. {
  168. testNetwork: "tcp",
  169. testAddr: "127.0.0.1:0",
  170. want: credentials.NoSecurity,
  171. },
  172. {
  173. testNetwork: "tcp",
  174. testAddr: "[::1]:0",
  175. want: credentials.NoSecurity,
  176. },
  177. {
  178. testNetwork: "tcp",
  179. testAddr: "localhost:0",
  180. want: credentials.NoSecurity,
  181. },
  182. {
  183. testNetwork: "unix",
  184. testAddr: fmt.Sprintf("/tmp/grpc_fullstck_test%d", time.Now().UnixNano()),
  185. want: credentials.PrivacyAndIntegrity,
  186. },
  187. }
  188. for _, tc := range testCases {
  189. if runtime.GOOS == "windows" && tc.testNetwork == "unix" {
  190. t.Skip("skipping tests for unix connections on Windows")
  191. }
  192. t.Run("serverAndClientHandshakeResult", func(t *testing.T) {
  193. lis, err := net.Listen(tc.testNetwork, tc.testAddr)
  194. if err != nil {
  195. if strings.Contains(err.Error(), "bind: cannot assign requested address") ||
  196. strings.Contains(err.Error(), "socket: address family not supported by protocol") {
  197. t.Skipf("no support for address %v", tc.testAddr)
  198. }
  199. t.Fatalf("Failed to listen: %v", err)
  200. }
  201. got, err := serverAndClientHandshake(lis)
  202. if got != tc.want {
  203. t.Fatalf("serverAndClientHandshake(%s, %s) = %v, %v; want %v, nil", tc.testNetwork, tc.testAddr, got, err, tc.want)
  204. }
  205. })
  206. }
  207. }