Coverage for manila/tests/test_ssh_utils.py: 98%

161 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2026-02-18 22:19 +0000

1# Licensed under the Apache License, Version 2.0 (the "License"); you may 

2# not use this file except in compliance with the License. You may obtain 

3# a copy of the License at 

4# 

5# http://www.apache.org/licenses/LICENSE-2.0 

6# 

7# Unless required by applicable law or agreed to in writing, software 

8# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 

9# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 

10# License for the specific language governing permissions and limitations 

11# under the License. 

12 

13import threading 

14import time 

15from unittest import mock 

16 

17from oslo_utils import uuidutils 

18import paramiko 

19 

20from manila import exception 

21from manila import ssh_utils 

22from manila import test 

23 

24 

class FakeSock:
    """Minimal socket stand-in exposed as ``FakeTransport.sock``."""

    def settimeout(self, timeout):
        """Accept any timeout value and do nothing with it."""
        return None

28 

29 

class FakeTransport:
    """Stand-in for the transport object returned by ``get_transport()``.

    ``active`` is a plain attribute, so a test can set it to ``False`` and
    ``is_active()`` will then report the connection as dead.
    """

    def __init__(self):
        self.active = True
        self.sock = FakeSock()

    def set_keepalive(self, timeout):
        """Accept a keepalive interval and ignore it."""

    def is_active(self):
        """Return whatever ``self.active`` currently holds."""
        return self.active

41 

42 

class FakeSSHClient:
    """Stand-in for ``paramiko.SSHClient`` with no network activity.

    Each instance carries a freshly generated UUID in ``id`` so tests can
    tell whether two pool items are the same client object.
    """

    def __init__(self):
        self.id = uuidutils.generate_uuid()
        self.transport = FakeTransport()

    def set_missing_host_key_policy(self, policy):
        """Accept a host-key policy and ignore it."""

    def connect(self, ip, port=22, username=None, password=None,
                key_filename=None, look_for_keys=None, timeout=10,
                banner_timeout=10):
        """Pretend to connect; every argument is accepted and ignored."""

    def get_transport(self):
        """Hand back the FakeTransport created in ``__init__``."""
        return self.transport

    def close(self):
        """Pretend to close the connection."""

    def __call__(self, *args, **kwargs):
        """Allow the instance itself to be called; always returns None."""

65 

66 

class SSHPoolTestCase(test.TestCase):
    """Unit tests for the SSH connection pool (``ssh_utils.SSHPool``)."""

    def test_single_ssh_connect(self):
        """A pool capped at one connection keeps reusing the same client."""
        with mock.patch.object(paramiko, "SSHClient",
                               mock.Mock(return_value=FakeSSHClient())):
            sshpool = ssh_utils.SSHPool("127.0.0.1", 22, 10, "test",
                                        password="test", min_size=1,
                                        max_size=1)
            with sshpool.item() as ssh:
                first_id = ssh.id

            with sshpool.item() as ssh:
                second_id = ssh.id

            self.assertEqual(first_id, second_id)
            # Only one client was ever constructed, so the second item()
            # was served from the pool instead of opening a new client.
            paramiko.SSHClient.assert_called_once_with()

    def test_create_ssh_with_password(self):
        """Password auth connects with look_for_keys=False."""
        fake_ssh_client = mock.Mock()
        fake_transport = mock.Mock()
        fake_ssh_client.get_transport.return_value = fake_transport
        ssh_pool = ssh_utils.SSHPool("127.0.0.1", 22, 10, "test",
                                     password="test")
        with mock.patch.object(paramiko, "SSHClient",
                               return_value=fake_ssh_client):
            ssh_pool.create()

        fake_ssh_client.connect.assert_called_once_with(
            "127.0.0.1", port=22, username="test",
            password="test", key_filename=None, look_for_keys=False,
            timeout=10, banner_timeout=10)
        # The value 10 given to SSHPool also shows up as the keepalive.
        fake_transport.set_keepalive.assert_called_once_with(10)

    def test_create_ssh_with_key(self):
        """Key-file auth passes the key path and disables key lookup."""
        path_to_private_key = "/fakepath/to/privatekey"
        fake_ssh_client = mock.Mock()
        fake_transport = mock.Mock()
        fake_ssh_client.get_transport.return_value = fake_transport
        # Use the same variable for input and expectation so the two
        # cannot silently drift apart.
        ssh_pool = ssh_utils.SSHPool("127.0.0.1", 22, 10,
                                     "test",
                                     privatekey=path_to_private_key)
        with mock.patch.object(paramiko, "SSHClient",
                               return_value=fake_ssh_client):
            ssh_pool.create()

        fake_ssh_client.connect.assert_called_once_with(
            "127.0.0.1", port=22, username="test", password=None,
            key_filename=path_to_private_key, look_for_keys=False,
            timeout=10, banner_timeout=10)
        fake_transport.set_keepalive.assert_called_once_with(10)

    def test_create_ssh_with_nothing(self):
        """Without password or key, connect() is told to look for keys."""
        fake_ssh_client = mock.Mock()
        fake_transport = mock.Mock()
        fake_ssh_client.get_transport.return_value = fake_transport
        ssh_pool = ssh_utils.SSHPool("127.0.0.1", 22, 10, "test")
        with mock.patch.object(paramiko, "SSHClient",
                               return_value=fake_ssh_client):
            ssh_pool.create()

        fake_ssh_client.connect.assert_called_once_with(
            "127.0.0.1", port=22, username="test", password=None,
            key_filename=None, look_for_keys=True,
            timeout=10, banner_timeout=10)
        fake_transport.set_keepalive.assert_called_once_with(10)

    def test_create_ssh_error_connecting(self):
        """A paramiko connect failure surfaces as exception.SSHException."""
        attrs = {'connect.side_effect': paramiko.SSHException, }
        fake_ssh_client = mock.Mock(**attrs)
        ssh_pool = ssh_utils.SSHPool("127.0.0.1", 22, 10, "test")
        with mock.patch.object(paramiko, "SSHClient",
                               return_value=fake_ssh_client):
            self.assertRaises(exception.SSHException, ssh_pool.create)

        fake_ssh_client.connect.assert_called_once_with(
            "127.0.0.1", port=22, username="test", password=None,
            key_filename=None, look_for_keys=True,
            timeout=10, banner_timeout=10)

    def test_closed_reopend_ssh_connections(self):
        """A pooled client whose transport went inactive gets replaced."""
        with mock.patch.object(paramiko, "SSHClient",
                               mock.Mock(return_value=FakeSSHClient())):
            sshpool = ssh_utils.SSHPool("127.0.0.1", 22, 10,
                                        "test", password="test",
                                        min_size=1, max_size=2)
            with sshpool.item() as ssh:
                first_id = ssh.id
            with sshpool.item() as ssh:
                second_id = ssh.id
                # Close the connection and test for a new connection
                ssh.get_transport().active = False
            self.assertEqual(first_id, second_id)
            paramiko.SSHClient.assert_called_once_with()

        # Expected new ssh pool
        with mock.patch.object(paramiko, "SSHClient",
                               mock.Mock(return_value=FakeSSHClient())):
            with sshpool.item() as ssh:
                third_id = ssh.id
            self.assertNotEqual(first_id, third_id)
            paramiko.SSHClient.assert_called_once_with()

    @mock.patch('builtins.open')
    @mock.patch('paramiko.SSHClient')
    @mock.patch('os.path.isfile', return_value=True)
    def test_sshpool_remove(self, mock_isfile, mock_sshclient, mock_open):
        """remove() takes a pooled connection out of free_items."""
        # NOTE(review): the open/isfile patches are not asserted on; they
        # presumably keep pool setup off the real filesystem - confirm
        # before dropping them.
        ssh_to_remove = mock.Mock()
        ssh_to_remove.get_transport.return_value.is_active.return_value = True
        # The second client constructed is the one we will remove.
        mock_sshclient.side_effect = [mock.Mock(), ssh_to_remove, mock.Mock()]
        sshpool = ssh_utils.SSHPool("127.0.0.1", 22, 10,
                                    "test", password="test",
                                    min_size=3, max_size=3)

        # Get connections to populate the pool
        conn1 = sshpool.get()
        conn2 = sshpool.get()
        conn3 = sshpool.get()

        # Put them back so they're in free_items
        sshpool.put(conn1)
        sshpool.put(conn2)
        sshpool.put(conn3)

        self.assertIn(ssh_to_remove, list(sshpool.free_items))

        sshpool.remove(ssh_to_remove)

        self.assertNotIn(ssh_to_remove, list(sshpool.free_items))

    @mock.patch('builtins.open')
    @mock.patch('paramiko.SSHClient')
    @mock.patch('os.path.isfile', return_value=True)
    def test_sshpool_remove_object_not_in_pool(self, mock_isfile,
                                               mock_sshclient, mock_open):
        """remove() of an unknown client leaves free_items unchanged."""
        # create an SSH Client that is not a part of sshpool.
        ssh_to_remove = mock.Mock()
        mock_conn1 = mock.Mock()
        mock_conn2 = mock.Mock()
        mock_conn1.get_transport.return_value.is_active.return_value = True
        mock_conn2.get_transport.return_value.is_active.return_value = True
        mock_sshclient.side_effect = [mock_conn1, mock_conn2]

        sshpool = ssh_utils.SSHPool("127.0.0.1", 22, 10,
                                    "test", password="test",
                                    min_size=2, max_size=2)

        # Get and put back connections to populate free_items
        conn1 = sshpool.get()
        conn2 = sshpool.get()
        sshpool.put(conn1)
        sshpool.put(conn2)

        items_before = list(sshpool.free_items)

        self.assertNotIn(ssh_to_remove, items_before)

        sshpool.remove(ssh_to_remove)

        self.assertEqual(items_before, list(sshpool.free_items))

    def test_sshpool_thread_safety(self):
        """Test that the pool is thread-safe.

        Ten threads share a pool capped at five connections. Every thread
        must obtain a connection, none may raise, and afterwards the pool
        size must not exceed its maximum.
        """
        # Generous compared to the ~0.1 s of simulated work per thread, but
        # bounded so a pool deadlock fails the test instead of hanging it.
        join_timeout = 10
        with mock.patch.object(paramiko, "SSHClient",
                               mock.Mock(return_value=FakeSSHClient())):
            sshpool = ssh_utils.SSHPool("127.0.0.1", 22, 10,
                                        "test", password="test",
                                        min_size=1, max_size=5)

            connections_acquired = []
            errors = []

            def acquire_connection():
                try:
                    with sshpool.item() as ssh:
                        connections_acquired.append(ssh.id)
                        time.sleep(0.1)  # Simulate work
                except Exception as e:
                    # Exceptions raised in worker threads are not seen by
                    # the test thread, so record them for the assertion.
                    errors.append(str(e))

            # Start multiple threads; daemon so a stuck worker cannot keep
            # the interpreter alive after the test has failed.
            threads = []
            for _ in range(10):
                thread = threading.Thread(target=acquire_connection,
                                          daemon=True)
                threads.append(thread)
                thread.start()

            # Wait for all threads to complete, but never forever.
            for thread in threads:
                thread.join(timeout=join_timeout)
                self.assertFalse(
                    thread.is_alive(),
                    "SSH pool worker thread did not finish in time")

            # Verify no errors
            self.assertEqual([], errors)
            self.assertEqual(10, len(connections_acquired))
            self.assertLessEqual(sshpool.current_size, 5)

    def test_sshpool_put_get_behavior(self):
        """A connection returned with put() is reused by the next get()."""
        with mock.patch.object(paramiko, "SSHClient",
                               mock.Mock(return_value=FakeSSHClient())):
            sshpool = ssh_utils.SSHPool("127.0.0.1", 22, 10,
                                        "test", password="test",
                                        min_size=1, max_size=3)

            conn1 = sshpool.get()
            self.assertIsNotNone(conn1)
            self.assertEqual(1, sshpool.current_size)

            sshpool.put(conn1)
            self.assertEqual(1, len(sshpool.free_items))

            conn2 = sshpool.get()
            self.assertEqual(conn1.id, conn2.id)