alts_test.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413
  1. //go:build linux || windows
  2. // +build linux windows
  3. /*
  4. *
  5. * Copyright 2018 gRPC authors.
  6. *
  7. * Licensed under the Apache License, Version 2.0 (the "License");
  8. * you may not use this file except in compliance with the License.
  9. * You may obtain a copy of the License at
  10. *
  11. * http://www.apache.org/licenses/LICENSE-2.0
  12. *
  13. * Unless required by applicable law or agreed to in writing, software
  14. * distributed under the License is distributed on an "AS IS" BASIS,
  15. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  16. * See the License for the specific language governing permissions and
  17. * limitations under the License.
  18. *
  19. */
  20. package alts
  21. import (
  22. "context"
  23. "reflect"
  24. "sync"
  25. "testing"
  26. "time"
  27. "github.com/golang/protobuf/proto"
  28. "google.golang.org/grpc"
  29. "google.golang.org/grpc/codes"
  30. "google.golang.org/grpc/credentials/alts/internal/handshaker/service"
  31. altsgrpc "google.golang.org/grpc/credentials/alts/internal/proto/grpc_gcp"
  32. altspb "google.golang.org/grpc/credentials/alts/internal/proto/grpc_gcp"
  33. "google.golang.org/grpc/credentials/alts/internal/testutil"
  34. "google.golang.org/grpc/internal/grpctest"
  35. "google.golang.org/grpc/internal/testutils"
  36. testgrpc "google.golang.org/grpc/interop/grpc_testing"
  37. testpb "google.golang.org/grpc/interop/grpc_testing"
  38. "google.golang.org/grpc/status"
  39. )
  // Timeouts shared by the tests in this file.
const (
	// defaultTestShortTimeout is the pause between successive UnaryCall
	// attempts while TestFullHandshake waits for the server to become ready.
	defaultTestShortTimeout = time.Millisecond * 10
	// defaultTestLongTimeout bounds the context used for those attempts.
	defaultTestLongTimeout = time.Second * 10
)
  44. type s struct {
  45. grpctest.Tester
  46. }
  47. func Test(t *testing.T) {
  48. grpctest.RunSubTests(t, s{})
  49. }
  50. func (s) TestInfoServerName(t *testing.T) {
  51. // This is not testing any handshaker functionality, so it's fine to only
  52. // use NewServerCreds and not NewClientCreds.
  53. alts := NewServerCreds(DefaultServerOptions())
  54. if got, want := alts.Info().ServerName, ""; got != want {
  55. t.Fatalf("%v.Info().ServerName = %v, want %v", alts, got, want)
  56. }
  57. }
  58. func (s) TestOverrideServerName(t *testing.T) {
  59. wantServerName := "server.name"
  60. // This is not testing any handshaker functionality, so it's fine to only
  61. // use NewServerCreds and not NewClientCreds.
  62. c := NewServerCreds(DefaultServerOptions())
  63. c.OverrideServerName(wantServerName)
  64. if got, want := c.Info().ServerName, wantServerName; got != want {
  65. t.Fatalf("c.Info().ServerName = %v, want %v", got, want)
  66. }
  67. }
  68. func (s) TestCloneClient(t *testing.T) {
  69. wantServerName := "server.name"
  70. opt := DefaultClientOptions()
  71. opt.TargetServiceAccounts = []string{"not", "empty"}
  72. c := NewClientCreds(opt)
  73. c.OverrideServerName(wantServerName)
  74. cc := c.Clone()
  75. if got, want := cc.Info().ServerName, wantServerName; got != want {
  76. t.Fatalf("cc.Info().ServerName = %v, want %v", got, want)
  77. }
  78. cc.OverrideServerName("")
  79. if got, want := c.Info().ServerName, wantServerName; got != want {
  80. t.Fatalf("Change in clone should not affect the original, c.Info().ServerName = %v, want %v", got, want)
  81. }
  82. if got, want := cc.Info().ServerName, ""; got != want {
  83. t.Fatalf("cc.Info().ServerName = %v, want %v", got, want)
  84. }
  85. ct := c.(*altsTC)
  86. cct := cc.(*altsTC)
  87. if ct.side != cct.side {
  88. t.Errorf("cc.side = %q, want %q", cct.side, ct.side)
  89. }
  90. if ct.hsAddress != cct.hsAddress {
  91. t.Errorf("cc.hsAddress = %q, want %q", cct.hsAddress, ct.hsAddress)
  92. }
  93. if !reflect.DeepEqual(ct.accounts, cct.accounts) {
  94. t.Errorf("cc.accounts = %q, want %q", cct.accounts, ct.accounts)
  95. }
  96. }
  97. func (s) TestCloneServer(t *testing.T) {
  98. wantServerName := "server.name"
  99. c := NewServerCreds(DefaultServerOptions())
  100. c.OverrideServerName(wantServerName)
  101. cc := c.Clone()
  102. if got, want := cc.Info().ServerName, wantServerName; got != want {
  103. t.Fatalf("cc.Info().ServerName = %v, want %v", got, want)
  104. }
  105. cc.OverrideServerName("")
  106. if got, want := c.Info().ServerName, wantServerName; got != want {
  107. t.Fatalf("Change in clone should not affect the original, c.Info().ServerName = %v, want %v", got, want)
  108. }
  109. if got, want := cc.Info().ServerName, ""; got != want {
  110. t.Fatalf("cc.Info().ServerName = %v, want %v", got, want)
  111. }
  112. ct := c.(*altsTC)
  113. cct := cc.(*altsTC)
  114. if ct.side != cct.side {
  115. t.Errorf("cc.side = %q, want %q", cct.side, ct.side)
  116. }
  117. if ct.hsAddress != cct.hsAddress {
  118. t.Errorf("cc.hsAddress = %q, want %q", cct.hsAddress, ct.hsAddress)
  119. }
  120. if !reflect.DeepEqual(ct.accounts, cct.accounts) {
  121. t.Errorf("cc.accounts = %q, want %q", cct.accounts, ct.accounts)
  122. }
  123. }
  124. func (s) TestInfo(t *testing.T) {
  125. // This is not testing any handshaker functionality, so it's fine to only
  126. // use NewServerCreds and not NewClientCreds.
  127. c := NewServerCreds(DefaultServerOptions())
  128. info := c.Info()
  129. if got, want := info.ProtocolVersion, ""; got != want {
  130. t.Errorf("info.ProtocolVersion=%v, want %v", got, want)
  131. }
  132. if got, want := info.SecurityProtocol, "alts"; got != want {
  133. t.Errorf("info.SecurityProtocol=%v, want %v", got, want)
  134. }
  135. if got, want := info.SecurityVersion, "1.0"; got != want {
  136. t.Errorf("info.SecurityVersion=%v, want %v", got, want)
  137. }
  138. if got, want := info.ServerName, ""; got != want {
  139. t.Errorf("info.ServerName=%v, want %v", got, want)
  140. }
  141. }
  142. func (s) TestCompareRPCVersions(t *testing.T) {
  143. for _, tc := range []struct {
  144. v1 *altspb.RpcProtocolVersions_Version
  145. v2 *altspb.RpcProtocolVersions_Version
  146. output int
  147. }{
  148. {
  149. version(3, 2),
  150. version(2, 1),
  151. 1,
  152. },
  153. {
  154. version(3, 2),
  155. version(3, 1),
  156. 1,
  157. },
  158. {
  159. version(2, 1),
  160. version(3, 2),
  161. -1,
  162. },
  163. {
  164. version(3, 1),
  165. version(3, 2),
  166. -1,
  167. },
  168. {
  169. version(3, 2),
  170. version(3, 2),
  171. 0,
  172. },
  173. } {
  174. if got, want := compareRPCVersions(tc.v1, tc.v2), tc.output; got != want {
  175. t.Errorf("compareRPCVersions(%v, %v)=%v, want %v", tc.v1, tc.v2, got, want)
  176. }
  177. }
  178. }
  179. func (s) TestCheckRPCVersions(t *testing.T) {
  180. for _, tc := range []struct {
  181. desc string
  182. local *altspb.RpcProtocolVersions
  183. peer *altspb.RpcProtocolVersions
  184. output bool
  185. maxCommonVersion *altspb.RpcProtocolVersions_Version
  186. }{
  187. {
  188. "local.max > peer.max and local.min > peer.min",
  189. versions(2, 1, 3, 2),
  190. versions(1, 2, 2, 1),
  191. true,
  192. version(2, 1),
  193. },
  194. {
  195. "local.max > peer.max and local.min < peer.min",
  196. versions(1, 2, 3, 2),
  197. versions(2, 1, 2, 1),
  198. true,
  199. version(2, 1),
  200. },
  201. {
  202. "local.max > peer.max and local.min = peer.min",
  203. versions(2, 1, 3, 2),
  204. versions(2, 1, 2, 1),
  205. true,
  206. version(2, 1),
  207. },
  208. {
  209. "local.max < peer.max and local.min > peer.min",
  210. versions(2, 1, 2, 1),
  211. versions(1, 2, 3, 2),
  212. true,
  213. version(2, 1),
  214. },
  215. {
  216. "local.max = peer.max and local.min > peer.min",
  217. versions(2, 1, 2, 1),
  218. versions(1, 2, 2, 1),
  219. true,
  220. version(2, 1),
  221. },
  222. {
  223. "local.max < peer.max and local.min < peer.min",
  224. versions(1, 2, 2, 1),
  225. versions(2, 1, 3, 2),
  226. true,
  227. version(2, 1),
  228. },
  229. {
  230. "local.max < peer.max and local.min = peer.min",
  231. versions(1, 2, 2, 1),
  232. versions(1, 2, 3, 2),
  233. true,
  234. version(2, 1),
  235. },
  236. {
  237. "local.max = peer.max and local.min < peer.min",
  238. versions(1, 2, 2, 1),
  239. versions(2, 1, 2, 1),
  240. true,
  241. version(2, 1),
  242. },
  243. {
  244. "all equal",
  245. versions(2, 1, 2, 1),
  246. versions(2, 1, 2, 1),
  247. true,
  248. version(2, 1),
  249. },
  250. {
  251. "max is smaller than min",
  252. versions(2, 1, 1, 2),
  253. versions(2, 1, 1, 2),
  254. false,
  255. nil,
  256. },
  257. {
  258. "no overlap, local > peer",
  259. versions(4, 3, 6, 5),
  260. versions(1, 0, 2, 1),
  261. false,
  262. nil,
  263. },
  264. {
  265. "no overlap, local < peer",
  266. versions(1, 0, 2, 1),
  267. versions(4, 3, 6, 5),
  268. false,
  269. nil,
  270. },
  271. {
  272. "no overlap, max < min",
  273. versions(6, 5, 4, 3),
  274. versions(2, 1, 1, 0),
  275. false,
  276. nil,
  277. },
  278. } {
  279. output, maxCommonVersion := checkRPCVersions(tc.local, tc.peer)
  280. if got, want := output, tc.output; got != want {
  281. t.Errorf("%v: checkRPCVersions(%v, %v)=(%v, _), want (%v, _)", tc.desc, tc.local, tc.peer, got, want)
  282. }
  283. if got, want := maxCommonVersion, tc.maxCommonVersion; !proto.Equal(got, want) {
  284. t.Errorf("%v: checkRPCVersions(%v, %v)=(_, %v), want (_, %v)", tc.desc, tc.local, tc.peer, got, want)
  285. }
  286. }
  287. }
  288. // TestFullHandshake performs a full ALTS handshake between a test client and
  289. // server, where both client and server offload to a local, fake handshaker
  290. // service.
  291. func (s) TestFullHandshake(t *testing.T) {
  292. // The vmOnGCP global variable MUST be reset to true after the client
  293. // or server credentials have been created, but before the ALTS
  294. // handshake begins. If vmOnGCP is not reset and this test is run
  295. // anywhere except for a GCP VM, then the ALTS handshake will
  296. // immediately fail.
  297. once.Do(func() {})
  298. vmOnGCP = true
  299. // Start the fake handshaker service and the server.
  300. var wait sync.WaitGroup
  301. defer wait.Wait()
  302. stopHandshaker, handshakerAddress := startFakeHandshakerService(t, &wait)
  303. defer stopHandshaker()
  304. stopServer, serverAddress := startServer(t, handshakerAddress, &wait)
  305. defer stopServer()
  306. // Ping the server, authenticating with ALTS.
  307. clientCreds := NewClientCreds(&ClientOptions{HandshakerServiceAddress: handshakerAddress})
  308. conn, err := grpc.Dial(serverAddress, grpc.WithTransportCredentials(clientCreds))
  309. if err != nil {
  310. t.Fatalf("grpc.Dial(%v) failed: %v", serverAddress, err)
  311. }
  312. defer conn.Close()
  313. ctx, cancel := context.WithTimeout(context.Background(), defaultTestLongTimeout)
  314. defer cancel()
  315. c := testgrpc.NewTestServiceClient(conn)
  316. for ; ctx.Err() == nil; <-time.After(defaultTestShortTimeout) {
  317. _, err = c.UnaryCall(ctx, &testpb.SimpleRequest{})
  318. if err == nil {
  319. break
  320. }
  321. if code := status.Code(err); code == codes.Unavailable {
  322. // The server is not ready yet. Try again.
  323. continue
  324. }
  325. t.Fatalf("c.UnaryCall() failed: %v", err)
  326. }
  327. // Close open connections to the fake handshaker service.
  328. if err := service.CloseForTesting(); err != nil {
  329. t.Errorf("service.CloseForTesting() failed: %v", err)
  330. }
  331. }
  332. func version(major, minor uint32) *altspb.RpcProtocolVersions_Version {
  333. return &altspb.RpcProtocolVersions_Version{
  334. Major: major,
  335. Minor: minor,
  336. }
  337. }
  338. func versions(minMajor, minMinor, maxMajor, maxMinor uint32) *altspb.RpcProtocolVersions {
  339. return &altspb.RpcProtocolVersions{
  340. MinRpcVersion: version(minMajor, minMinor),
  341. MaxRpcVersion: version(maxMajor, maxMinor),
  342. }
  343. }
  344. func startFakeHandshakerService(t *testing.T, wait *sync.WaitGroup) (stop func(), address string) {
  345. listener, err := testutils.LocalTCPListener()
  346. if err != nil {
  347. t.Fatalf("LocalTCPListener() failed: %v", err)
  348. }
  349. s := grpc.NewServer()
  350. altsgrpc.RegisterHandshakerServiceServer(s, &testutil.FakeHandshaker{})
  351. wait.Add(1)
  352. go func() {
  353. defer wait.Done()
  354. if err := s.Serve(listener); err != nil {
  355. t.Errorf("failed to serve: %v", err)
  356. }
  357. }()
  358. return func() { s.Stop() }, listener.Addr().String()
  359. }
  360. func startServer(t *testing.T, handshakerServiceAddress string, wait *sync.WaitGroup) (stop func(), address string) {
  361. listener, err := testutils.LocalTCPListener()
  362. if err != nil {
  363. t.Fatalf("LocalTCPListener() failed: %v", err)
  364. }
  365. serverOpts := &ServerOptions{HandshakerServiceAddress: handshakerServiceAddress}
  366. creds := NewServerCreds(serverOpts)
  367. s := grpc.NewServer(grpc.Creds(creds))
  368. testgrpc.RegisterTestServiceServer(s, &testServer{})
  369. wait.Add(1)
  370. go func() {
  371. defer wait.Done()
  372. if err := s.Serve(listener); err != nil {
  373. t.Errorf("s.Serve(%v) failed: %v", listener, err)
  374. }
  375. }()
  376. return func() { s.Stop() }, listener.Addr().String()
  377. }
  378. type testServer struct {
  379. testgrpc.UnimplementedTestServiceServer
  380. }
  381. func (s *testServer) UnaryCall(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) {
  382. return &testpb.SimpleResponse{
  383. Payload: &testpb.Payload{},
  384. }, nil
  385. }