Coverage for manila/tests/test_ssh_utils.py: 98%
161 statements
« prev ^ index » next coverage.py v7.11.0, created at 2026-02-18 22:19 +0000
« prev ^ index » next coverage.py v7.11.0, created at 2026-02-18 22:19 +0000
1# Licensed under the Apache License, Version 2.0 (the "License"); you may
2# not use this file except in compliance with the License. You may obtain
3# a copy of the License at
4#
5# http://www.apache.org/licenses/LICENSE-2.0
6#
7# Unless required by applicable law or agreed to in writing, software
8# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
9# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
10# License for the specific language governing permissions and limitations
11# under the License.
13import threading
14import time
15from unittest import mock
17from oslo_utils import uuidutils
18import paramiko
20from manila import exception
21from manila import ssh_utils
22from manila import test
class FakeSock:
    """Socket double held by ``FakeTransport`` as its ``sock`` attribute."""

    def settimeout(self, timeout):
        """Accept a timeout value and do nothing with it."""
        return None
class FakeTransport:
    """Transport double whose liveness is controlled by ``active``.

    Tests flip ``active`` to ``False`` to make ``is_active()`` report a
    dead connection.
    """

    def __init__(self):
        self.sock = FakeSock()
        # Starts out alive; tests may overwrite this attribute directly.
        self.active = True

    def is_active(self):
        """Return the current value of the ``active`` flag."""
        return self.active

    def set_keepalive(self, timeout):
        """Accept a keepalive interval and do nothing with it."""
        return None
class FakeSSHClient:
    """SSH client double with a unique ``id`` and a ``FakeTransport``.

    Every method except ``get_transport`` is a no-op; the ``id`` lets
    tests tell one client instance from another.
    """

    def __init__(self):
        self.transport = FakeTransport()
        self.id = uuidutils.generate_uuid()

    def __call__(self, *args, **kwargs):
        """Accept any arguments and do nothing."""
        return None

    def set_missing_host_key_policy(self, policy):
        """Accept a host key policy and do nothing with it."""
        return None

    def connect(self, ip, port=22, username=None, password=None,
                key_filename=None, look_for_keys=None, timeout=10,
                banner_timeout=10):
        """Accept connection parameters without opening a connection."""
        return None

    def get_transport(self):
        """Return the transport created in ``__init__``."""
        return self.transport

    def close(self):
        """Do nothing; there is no real connection to close."""
        return None
67class SSHPoolTestCase(test.TestCase):
68 """Unit test for SSH Connection Pool."""
70 def test_single_ssh_connect(self):
71 with mock.patch.object(paramiko, "SSHClient",
72 mock.Mock(return_value=FakeSSHClient())):
73 sshpool = ssh_utils.SSHPool("127.0.0.1", 22, 10, "test",
74 password="test", min_size=1,
75 max_size=1)
76 with sshpool.item() as ssh:
77 first_id = ssh.id
79 with sshpool.item() as ssh:
80 second_id = ssh.id
82 self.assertEqual(first_id, second_id)
83 paramiko.SSHClient.assert_called_once_with()
85 def test_create_ssh_with_password(self):
86 fake_ssh_client = mock.Mock()
87 fake_transport = mock.Mock()
88 fake_ssh_client.get_transport.return_value = fake_transport
89 ssh_pool = ssh_utils.SSHPool("127.0.0.1", 22, 10, "test",
90 password="test")
91 with mock.patch.object(paramiko, "SSHClient",
92 return_value=fake_ssh_client):
93 ssh_pool.create()
95 fake_ssh_client.connect.assert_called_once_with(
96 "127.0.0.1", port=22, username="test",
97 password="test", key_filename=None, look_for_keys=False,
98 timeout=10, banner_timeout=10)
99 fake_transport.set_keepalive.assert_called_once_with(10)
101 def test_create_ssh_with_key(self):
102 path_to_private_key = "/fakepath/to/privatekey"
103 fake_ssh_client = mock.Mock()
104 fake_transport = mock.Mock()
105 fake_ssh_client.get_transport.return_value = fake_transport
106 ssh_pool = ssh_utils.SSHPool("127.0.0.1", 22, 10,
107 "test",
108 privatekey="/fakepath/to/privatekey")
109 with mock.patch.object(paramiko, "SSHClient",
110 return_value=fake_ssh_client):
111 ssh_pool.create()
112 fake_ssh_client.connect.assert_called_once_with(
113 "127.0.0.1", port=22, username="test", password=None,
114 key_filename=path_to_private_key, look_for_keys=False,
115 timeout=10, banner_timeout=10)
116 fake_transport.set_keepalive.assert_called_once_with(10)
118 def test_create_ssh_with_nothing(self):
119 fake_ssh_client = mock.Mock()
120 fake_transport = mock.Mock()
121 fake_ssh_client.get_transport.return_value = fake_transport
122 ssh_pool = ssh_utils.SSHPool("127.0.0.1", 22, 10, "test")
123 with mock.patch.object(paramiko, "SSHClient",
124 return_value=fake_ssh_client):
125 ssh_pool.create()
126 fake_ssh_client.connect.assert_called_once_with(
127 "127.0.0.1", port=22, username="test", password=None,
128 key_filename=None, look_for_keys=True,
129 timeout=10, banner_timeout=10)
130 fake_transport.set_keepalive.assert_called_once_with(10)
132 def test_create_ssh_error_connecting(self):
133 attrs = {'connect.side_effect': paramiko.SSHException, }
134 fake_ssh_client = mock.Mock(**attrs)
135 ssh_pool = ssh_utils.SSHPool("127.0.0.1", 22, 10, "test")
136 with mock.patch.object(paramiko, "SSHClient",
137 return_value=fake_ssh_client):
138 self.assertRaises(exception.SSHException, ssh_pool.create)
139 fake_ssh_client.connect.assert_called_once_with(
140 "127.0.0.1", port=22, username="test", password=None,
141 key_filename=None, look_for_keys=True,
142 timeout=10, banner_timeout=10)
144 def test_closed_reopend_ssh_connections(self):
145 with mock.patch.object(paramiko, "SSHClient",
146 mock.Mock(return_value=FakeSSHClient())):
147 sshpool = ssh_utils.SSHPool("127.0.0.1", 22, 10,
148 "test", password="test",
149 min_size=1, max_size=2)
150 with sshpool.item() as ssh:
151 first_id = ssh.id
152 with sshpool.item() as ssh:
153 second_id = ssh.id
154 # Close the connection and test for a new connection
155 ssh.get_transport().active = False
156 self.assertEqual(first_id, second_id)
157 paramiko.SSHClient.assert_called_once_with()
159 # Expected new ssh pool
160 with mock.patch.object(paramiko, "SSHClient",
161 mock.Mock(return_value=FakeSSHClient())):
162 with sshpool.item() as ssh:
163 third_id = ssh.id
164 self.assertNotEqual(first_id, third_id)
165 paramiko.SSHClient.assert_called_once_with()
167 @mock.patch('builtins.open')
168 @mock.patch('paramiko.SSHClient')
169 @mock.patch('os.path.isfile', return_value=True)
170 def test_sshpool_remove(self, mock_isfile, mock_sshclient, mock_open):
171 ssh_to_remove = mock.Mock()
172 ssh_to_remove.get_transport.return_value.is_active.return_value = True
173 mock_sshclient.side_effect = [mock.Mock(), ssh_to_remove, mock.Mock()]
174 sshpool = ssh_utils.SSHPool("127.0.0.1", 22, 10,
175 "test", password="test",
176 min_size=3, max_size=3)
178 # Get connections to populate the pool
179 conn1 = sshpool.get()
180 conn2 = sshpool.get()
181 conn3 = sshpool.get()
183 # Put them back so they're in free_items
184 sshpool.put(conn1)
185 sshpool.put(conn2)
186 sshpool.put(conn3)
188 self.assertIn(ssh_to_remove, list(sshpool.free_items))
190 sshpool.remove(ssh_to_remove)
192 self.assertNotIn(ssh_to_remove, list(sshpool.free_items))
194 @mock.patch('builtins.open')
195 @mock.patch('paramiko.SSHClient')
196 @mock.patch('os.path.isfile', return_value=True)
197 def test_sshpool_remove_object_not_in_pool(self, mock_isfile,
198 mock_sshclient, mock_open):
199 # create an SSH Client that is not a part of sshpool.
200 ssh_to_remove = mock.Mock()
201 mock_conn1 = mock.Mock()
202 mock_conn2 = mock.Mock()
203 mock_conn1.get_transport.return_value.is_active.return_value = True
204 mock_conn2.get_transport.return_value.is_active.return_value = True
205 mock_sshclient.side_effect = [mock_conn1, mock_conn2]
207 sshpool = ssh_utils.SSHPool("127.0.0.1", 22, 10,
208 "test", password="test",
209 min_size=2, max_size=2)
211 # Get and put back connections to populate free_items
212 conn1 = sshpool.get()
213 conn2 = sshpool.get()
214 sshpool.put(conn1)
215 sshpool.put(conn2)
217 listBefore = list(sshpool.free_items)
219 self.assertNotIn(ssh_to_remove, listBefore)
221 sshpool.remove(ssh_to_remove)
223 self.assertEqual(listBefore, list(sshpool.free_items))
225 def test_sshpool_thread_safety(self):
226 """Test that the pool is thread-safe."""
227 with mock.patch.object(paramiko, "SSHClient",
228 mock.Mock(return_value=FakeSSHClient())):
229 sshpool = ssh_utils.SSHPool("127.0.0.1", 22, 10,
230 "test", password="test",
231 min_size=1, max_size=5)
233 connections_acquired = []
234 errors = []
236 def acquire_connection():
237 try:
238 with sshpool.item() as ssh:
239 connections_acquired.append(ssh.id)
240 time.sleep(0.1) # Simulate work
241 except Exception as e:
242 errors.append(str(e))
244 # Start multiple threads
245 threads = []
246 for _ in range(10):
247 thread = threading.Thread(target=acquire_connection)
248 threads.append(thread)
249 thread.start()
251 # Wait for all threads to complete
252 for thread in threads:
253 thread.join()
255 # Verify no errors
256 self.assertEqual([], errors)
257 self.assertEqual(10, len(connections_acquired))
258 self.assertLessEqual(sshpool.current_size, 5)
260 def test_sshpool_put_get_behavior(self):
261 with mock.patch.object(paramiko, "SSHClient",
262 mock.Mock(return_value=FakeSSHClient())):
263 sshpool = ssh_utils.SSHPool("127.0.0.1", 22, 10,
264 "test", password="test",
265 min_size=1, max_size=3)
267 conn1 = sshpool.get()
268 self.assertIsNotNone(conn1)
269 self.assertEqual(1, sshpool.current_size)
271 sshpool.put(conn1)
272 self.assertEqual(1, len(sshpool.free_items))
274 conn2 = sshpool.get()
275 self.assertEqual(conn1.id, conn2.id)