Coverage for manila/utils.py: 88%
326 statements
« prev ^ index » next coverage.py v7.11.0, created at 2026-02-18 22:19 +0000
« prev ^ index » next coverage.py v7.11.0, created at 2026-02-18 22:19 +0000
1# Copyright 2010 United States Government as represented by the
2# Administrator of the National Aeronautics and Space Administration.
3# Copyright 2011 Justin Santa Barbara
4# All Rights Reserved.
5#
6# Licensed under the Apache License, Version 2.0 (the "License"); you may
7# not use this file except in compliance with the License. You may obtain
8# a copy of the License at
9#
10# http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing, software
13# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
14# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
15# License for the specific language governing permissions and limitations
16# under the License.
18"""Utilities and helper functions."""
20import contextlib
21import functools
22import inspect
23import pyclbr
24import re
25import shutil
26import sys
27import tempfile
28import tenacity
29import time
31import logging
32import netaddr
33from oslo_concurrency import lockutils
34from oslo_concurrency import processutils
35from oslo_config import cfg
36from oslo_log import log
37from oslo_utils import importutils
38from oslo_utils import netutils
39from oslo_utils import strutils
40from oslo_utils import timeutils
41from webob import exc
44from manila.common import constants
45from manila.db import api as db_api
46from manila import exception
47from manila.i18n import _
CONF = cfg.CONF
LOG = log.getLogger(__name__)

# Raise paramiko's log level to DEBUG only when manila's own ``debug``
# option is enabled.
if getattr(CONF, 'debug', False):
    logging.getLogger("paramiko").setLevel(logging.DEBUG)

# strftime() patterns used by isotime(); the sub-second variant adds
# microseconds via %f.
_ISO8601_TIME_FORMAT = '%Y-%m-%dT%H:%M:%S'
_ISO8601_TIME_FORMAT_SUBSECOND = '%Y-%m-%dT%H:%M:%S.%f'

# Lock decorator whose lock names get the 'manila-' prefix.
synchronized = lockutils.synchronized_with_prefix('manila-')
def isotime(at=None, subsecond=False):
    """Stringify time in ISO 8601 format.

    :param at: datetime to format; when falsy, ``timeutils.utcnow()`` is
        used instead.
    :param subsecond: when true, include microseconds in the output.
    :returns: the formatted timestamp followed by ``Z`` for UTC, or by the
        tzinfo's name for any other zone.
    """
    # datetime.isoformat() is deliberately not used:
    # 1) strings produced here appear in tokens and other public APIs, and
    #    isoformat() emits a different format, so switching would need a
    #    deprecation period;
    # 2) isoformat() drops the microseconds when they happen to be 0, which
    #    breaks parsers that always expect them -- but only occasionally.
    if not at:
        at = timeutils.utcnow()
    pattern = (_ISO8601_TIME_FORMAT_SUBSECOND if subsecond
               else _ISO8601_TIME_FORMAT)
    stamp = at.strftime(pattern)
    zone = 'UTC' if not at.tzinfo else at.tzinfo.tzname(None)
    # Both the iso8601 library's and Python's names for UTC map to 'Z'.
    suffix = 'Z' if zone in ['UTC', 'UTC+00:00'] else zone
    return stamp + suffix
def _get_root_helper():
    """Return the command prefix used to run commands via manila-rootwrap."""
    config_path = CONF.rootwrap_config
    return 'sudo manila-rootwrap %s' % config_path
def execute(*cmd, **kwargs):
    """Convenience wrapper around oslo's processutils.execute().

    The rootwrap helper is used as ``root_helper`` unless the caller passes
    one, and ``loglevel`` is forced to DEBUG when the ``debug`` option is
    set.  All other arguments are forwarded unchanged.
    """
    # The default helper is computed even when the caller supplied one,
    # matching the original eager evaluation.
    default_helper = _get_root_helper()
    if 'root_helper' not in kwargs:
        kwargs['root_helper'] = default_helper
    if getattr(CONF, 'debug', False):
        kwargs['loglevel'] = logging.DEBUG
    return processutils.execute(*cmd, **kwargs)
def check_ssh_injection(cmd_list):
    """Reject command arguments that could inject shell syntax over SSH.

    Each argument (after stripping surrounding whitespace) must either be a
    single whitespace-free token, or be wrapped in matching single or double
    quotes with no unescaped quote inside.  In both forms, every occurrence
    of a shell special character (`` ` $ | ; & > < ``) must be escaped by a
    preceding backslash.

    :param cmd_list: iterable of command-line argument strings.
    :raises exception.SSHInjectionThreat: if any argument fails a check.
    """
    ssh_injection_pattern = ['`', '$', '|', '||', ';', '&', '&&', '>', '>>',
                             '<']

    # Check whether injection attacks exist
    for arg in cmd_list:
        arg = arg.strip()

        # Check for matching quotes on the ends
        is_quoted = re.match('^(?P<quote>[\'"])(?P<quoted>.*)(?P=quote)$', arg)
        if is_quoted:
            # Check for unescaped quotes within the quoted argument
            quoted = is_quoted.group('quoted')
            if quoted:
                if (re.match('[\'"]', quoted) or
                        re.search('[^\\\\][\'"]', quoted)):
                    raise exception.SSHInjectionThreat(command=cmd_list)
        else:
            # We only allow spaces within quoted arguments, and that
            # is the only special character allowed within quotes
            if len(arg.split()) > 1:
                raise exception.SSHInjectionThreat(command=cmd_list)

        # Second, every occurrence of a shell special character must be
        # backslash-escaped.  Inspecting only the first occurrence (as this
        # function used to) let arguments such as 'x\;y;reboot' through.
        for c in ssh_injection_pattern:
            pos = arg.find(c)
            while pos != -1:
                if pos == 0 or arg[pos - 1] != '\\':
                    raise exception.SSHInjectionThreat(command=cmd_list)
                pos = arg.find(c, pos + 1)
class LazyPluggable(object):
    """A pluggable backend loaded lazily based on some value.

    The config option named by ``pivot`` selects one of ``backends`` the
    first time an attribute is accessed.  The chosen module is imported
    once and every later attribute lookup is forwarded to it.
    """

    def __init__(self, pivot, **backends):
        self.__backends = backends
        self.__pivot = pivot
        self.__backend = None

    def __get_backend(self):
        if self.__backend:
            return self.__backend

        backend_name = CONF[self.__pivot]
        if backend_name not in self.__backends:
            raise exception.Error(_('Invalid backend: %s') % backend_name)

        spec = self.__backends[backend_name]
        if isinstance(spec, tuple):
            # Explicit (module name, fromlist) pair.
            module_name, fromlist = spec[0], spec[1]
        else:
            # A bare module name doubles as its own fromlist.
            module_name = fromlist = spec

        self.__backend = __import__(module_name, None, None, fromlist)
        LOG.debug('backend %s', self.__backend)
        return self.__backend

    def __getattr__(self, key):
        return getattr(self.__get_backend(), key)
def monkey_patch():
    """Wrap the functions of configured modules in decorators.

    Does nothing unless ``CONF.monkey_patch`` is true.  Each entry of
    ``CONF.monkey_patch_modules`` has the form
    ``"<module path>:<decorator path>"``, e.g. ``manila.api.ec2.cloud``
    paired with ``manila.openstack.common.notifier.api.notify_decorator``.

    For every entry the module is imported and inspected with pyclbr.  Each
    top-level function, and each plain-function member of each top-level
    class, is replaced by ``decorator(name, function)``.  Here ``name`` is
    the dotted path of the function and ``function`` is the original
    object.
    """
    if not CONF.monkey_patch:
        return

    for entry in CONF.monkey_patch_modules:
        module, decorator_name = entry.split(':')
        decorator = importutils.import_class(decorator_name)
        __import__(module)
        # pyclbr reports the module's top-level classes and functions.
        for name, info in pyclbr.readmodule_ex(module).items():
            if isinstance(info, pyclbr.Class):
                cls = importutils.import_class("%s.%s" % (module, name))
                # Python 3 has no unbound methods: methods looked up on a
                # class are plain functions, so inspect.isfunction finds
                # them.
                members = inspect.getmembers(cls, inspect.isfunction)
                for method, func in members:
                    dotted = "%s.%s.%s" % (module, name, method)
                    setattr(cls, method, decorator(dotted, func))
            elif isinstance(info, pyclbr.Function):
                func = importutils.import_class("%s.%s" % (module, name))
                setattr(sys.modules[module], name,
                        decorator("%s.%s" % (module, name), func))
def file_open(*args, **kwargs):
    """Open a file by forwarding every argument to the built-in open().

    See the built-in ``open()`` documentation for the arguments.

    This lives in its own function so tests can stub file access in a
    single place without altering system state.
    """
    handle = open(*args, **kwargs)
    return handle
def service_is_up(service):
    """Check whether a service is up based on its last heartbeat.

    :param service: mapping with ``updated_at`` and ``created_at``
        timestamps.  ``created_at`` is used when ``updated_at`` is empty.
    :returns: True if the heartbeat lies within ``CONF.service_down_time``
        seconds of now, in either direction.
    """
    heartbeat = service['updated_at'] or service['created_at']
    # Timestamps in the DB are UTC, hence utcnow().
    age_seconds = (timeutils.utcnow() - heartbeat).total_seconds()
    return abs(age_seconds) <= CONF.service_down_time
def validate_service_host(context, host):
    """Return the 'manila-share' service record for ``host`` if it is up.

    :raises exception.ServiceIsDown: if service_is_up() reports the
        service as down.
    """
    service = db_api.service_get_by_host_and_topic(
        context, host, 'manila-share')
    if service_is_up(service):
        return service
    raise exception.ServiceIsDown(service=service['host'])
@contextlib.contextmanager
def tempdir(**kwargs):
    """Yield a fresh temporary directory and remove it afterwards.

    Keyword arguments are passed to ``tempfile.mkdtemp``.  If the directory
    cannot be removed, the OSError is logged at debug level and ignored.
    """
    path = tempfile.mkdtemp(**kwargs)
    try:
        yield path
    finally:
        try:
            shutil.rmtree(path)
        except OSError as err:
            LOG.debug('Could not remove tmpdir: %s', err)
def walk_class_hierarchy(clazz, encountered=None):
    """Walk class hierarchy, yielding most derived classes first.

    Subclasses are visited depth-first, and each class is yielded after all
    of its own subclasses.  A class reachable along several paths (e.g.
    through multiple inheritance) is yielded only once.

    :param clazz: root class; it is not itself yielded.
    :param encountered: optional list of classes already visited.  It is
        updated in place, so a caller-supplied list (even an empty one)
        ends up containing every class that was yielded.
    """
    # Compare against None: an empty caller-supplied list must be reused,
    # not silently replaced by a new one.
    if encountered is None:
        encountered = []
    for subclass in clazz.__subclasses__():
        if subclass in encountered:
            continue
        encountered.append(subclass)
        # drill down to leaves first
        yield from walk_class_hierarchy(subclass, encountered)
        yield subclass
def cidr_to_network(cidr):
    """Convert cidr to network.

    :param cidr: CIDR notation, e.g. ``'10.0.0.0/24'``.
    :returns: the ``netaddr.IPNetwork`` for ``cidr``.
    :raises exception.InvalidInput: if netaddr raises AddrFormatError.
    """
    try:
        return netaddr.IPNetwork(cidr)
    except netaddr.AddrFormatError:
        raise exception.InvalidInput(_("Invalid cidr supplied %s") % cidr)
def cidr_to_netmask(cidr):
    """Convert cidr to its netmask, returned as a string.

    :raises exception.InvalidInput: if ``cidr`` cannot be parsed.
    """
    network = cidr_to_network(cidr)
    return str(network.netmask)
def cidr_to_prefixlen(cidr):
    """Convert cidr to its prefix length.

    :raises exception.InvalidInput: if ``cidr`` cannot be parsed.
    """
    network = cidr_to_network(cidr)
    return network.prefixlen
def is_valid_ip_address(ip_address, ip_version):
    """Check whether ``ip_address`` is valid for any of the given versions.

    :param ip_address: candidate address; non-strings are never valid.
    :param ip_version: 4 or 6 (anything int() accepts), or a list of them.
    :returns: True if the address is valid for a requested version.
    :raises exception.ManilaException: if a version other than 4 or 6 is
        requested.
    """
    if isinstance(ip_version, list):
        versions = ip_version
    else:
        versions = [int(ip_version)]

    if not set(versions) <= {4, 6}:
        raise exception.ManilaException(
            _("Provided improper IP version '%s'.") % versions)

    if not isinstance(ip_address, str):
        return False

    # IPv4 is tried first; IPv6 only if IPv4 did not match.
    return bool(
        (4 in versions and netutils.is_valid_ipv4(ip_address)) or
        (6 in versions and netutils.is_valid_ipv6(ip_address)))
def get_bool_param(param_string, params, default=False):
    """Return ``params[param_string]`` parsed strictly as a boolean.

    :param param_string: key to look up.
    :param params: mapping of parameters.
    :param default: value used when the key is absent.
    :raises exception.InvalidParameterValue: if the value is not accepted
        by ``strutils.is_valid_boolstr``.
    """
    value = params.get(param_string, default)
    if strutils.is_valid_boolstr(value):
        return strutils.bool_from_string(value, strict=True)
    msg = _("Value '%(param)s' for '%(param_string)s' is not "
            "a boolean.") % {'param': value, 'param_string': param_string}
    raise exception.InvalidParameterValue(err=msg)
def is_all_tenants(search_opts):
    """Checks to see if the all_tenants flag is in search_opts

    :param dict search_opts: The search options for a request
    :returns: boolean indicating if all_tenants are being requested or not
    :raises exception.InvalidInput: if a truthy value is not a recognised
        boolean string.
    """
    value = search_opts.get('all_tenants')
    if not value:
        # Any present-but-falsy value (such as the empty string) enables
        # all_tenants; only an absent key disables it here.
        return 'all_tenants' in search_opts
    try:
        return strutils.bool_from_string(value, True)
    except ValueError as err:
        raise exception.InvalidInput(str(err))
class IsAMatcher(object):
    """Equality helper that matches any instance of the expected type.

    ``IsAMatcher(SomeType) == value`` is true when ``isinstance(value,
    SomeType)`` is; ``SomeType`` may also be a tuple of types.
    """

    def __init__(self, expected_value=None):
        self.expected_value = expected_value

    def __eq__(self, actual_value):
        wanted = self.expected_value
        return isinstance(actual_value, wanted)
class ComparableMixin(object):
    """Mixin deriving rich comparisons from a ``_cmpkey()`` method.

    Subclasses implement ``_cmpkey()``.  Comparisons apply the operator to
    both operands' keys.  If the other operand has no ``_cmpkey`` or the
    keys cannot be compared (AttributeError or TypeError), NotImplemented
    is returned so Python can try the reflected operation.

    Because ``__eq__`` is defined without ``__hash__``, subclasses are
    unhashable unless they define ``__hash__`` themselves.
    """

    def _compare(self, other, method):
        try:
            mine, theirs = self._cmpkey(), other._cmpkey()
            return method(mine, theirs)
        except (AttributeError, TypeError):
            return NotImplemented

    def __lt__(self, other):
        return self._compare(other, lambda a, b: a < b)

    def __le__(self, other):
        return self._compare(other, lambda a, b: a <= b)

    def __eq__(self, other):
        return self._compare(other, lambda a, b: a == b)

    def __ge__(self, other):
        return self._compare(other, lambda a, b: a >= b)

    def __gt__(self, other):
        return self._compare(other, lambda a, b: a > b)

    def __ne__(self, other):
        return self._compare(other, lambda a, b: a != b)
class retry_if_exit_code(tenacity.retry_if_exception):
    """Retry on ProcessExecutionError specific exit codes.

    :param codes: a single int exit code or a container of exit codes.
    """

    def __init__(self, codes):
        if isinstance(codes, int):
            codes = (codes,)
        self.codes = codes
        super(retry_if_exit_code, self).__init__(self._check_exit_code)

    def _check_exit_code(self, error):
        # A falsy value is passed straight back (there is nothing to match).
        if not error:
            return error
        return (isinstance(error, processutils.ProcessExecutionError) and
                error.exit_code in self.codes)
def retry(retry_param=Exception,
          interval=1,
          retries=10,
          backoff_rate=2,
          backoff_sleep_max=None,
          wait_random=False,
          infinite=False,
          retry=tenacity.retry_if_exception_type):
    """Build a decorator that retries the wrapped callable via tenacity.

    :param retry_param: argument passed to the ``retry`` factory, e.g. the
        exception type(s) to retry on.
    :param interval: ``multiplier`` for the exponential wait.
    :param retries: maximum number of attempts (unused when ``infinite``).
    :param backoff_rate: ``exp_base`` of ``wait_exponential``; ignored when
        ``wait_random`` is set.
    :param backoff_sleep_max: optional ``max`` for the wait strategy.
    :param wait_random: use ``wait_random_exponential`` instead of
        ``wait_exponential``.
    :param infinite: never stop retrying.
    :param retry: factory called with ``retry_param`` on every invocation
        to build the tenacity retry condition.
    :raises ValueError: if ``retries`` is less than 1.
    """
    if retries < 1:
        raise ValueError('Retries must be greater than or '
                         'equal to 1 (received: %s). ' % retries)

    wait_kwargs = {'multiplier': interval}
    if not wait_random:
        wait_kwargs.update({'min': 0, 'exp_base': backoff_rate})
    if backoff_sleep_max is not None:
        wait_kwargs['max'] = backoff_sleep_max
    wait_strategy = (tenacity.wait_random_exponential if wait_random
                     else tenacity.wait_exponential)
    wait = wait_strategy(**wait_kwargs)

    stop = (tenacity.stop.stop_never if infinite
            else tenacity.stop_after_attempt(retries))

    def _decorator(f):

        @functools.wraps(f)
        def _wrapper(*args, **kwargs):
            # A fresh Retrying object per call keeps attempts independent.
            retrying = tenacity.Retrying(
                sleep=tenacity.nap.sleep,
                before_sleep=tenacity.before_sleep_log(LOG, logging.DEBUG),
                after=tenacity.after_log(LOG, logging.DEBUG),
                stop=stop,
                reraise=True,
                retry=retry(retry_param),
                wait=wait)
            return retrying(f, *args, **kwargs)

        return _wrapper

    return _decorator
def get_bool_from_api_params(key, params, default=False, strict=True):
    """Parse a bool value from request params.

    HTTPBadRequest is raised directly in either of these cases:
    1. an invalid bool string was found under ``key`` (with strict on);
    2. ``key`` is absent and ``default`` is not a valid bool (strict on).
    """
    raw = params.get(key, default)
    try:
        return strutils.bool_from_string(raw, strict=strict,
                                         default=default)
    except ValueError:
        msg = _('Invalid value %(param)s for %(param_string)s. '
                'Expecting a boolean.') % {'param': raw,
                                           'param_string': key}
        raise exc.HTTPBadRequest(explanation=msg)
def check_params_exist(keys, params):
    """Validates if keys exist in params.

    :param keys: List of keys to check
    :param params: Parameters received from REST API
    :raises webob.exc.HTTPBadRequest: if any of ``keys`` is missing from
        ``params``.
    """
    # Test the difference itself; any() would ignore a missing key whose
    # value is falsy (e.g. '' or 0) and report nothing missing.
    missing = set(keys) - set(params)
    if missing:
        # Wrap in a 1-tuple so a tuple of keys is formatted, not unpacked.
        msg = _("Must specify all mandatory parameters: %s") % (keys,)
        raise exc.HTTPBadRequest(explanation=msg)
def check_params_are_boolean(keys, params, default=False):
    """Validates if keys in params are boolean.

    :param keys: List of keys to check
    :param params: Parameters received from REST API
    :param default: default value when it does not exist
    :return: a dictionary mapping each key to its parsed boolean value
    :raises webob.exc.HTTPBadRequest: on the first non-boolean value.
    """
    return {
        key: get_bool_from_api_params(key, params, default, strict=True)
        for key in keys
    }
def require_driver_initialized(func):
    """Decorate a method so it runs only once ``self.driver`` is initialized.

    :raises exception.DriverNotInitialized: if ``self.driver.initialized``
        is falsy; the error names the driver's class.
    """
    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        driver = self.driver
        if driver.initialized:
            return func(self, *args, **kwargs)
        # we can't do anything if the driver didn't init
        raise exception.DriverNotInitialized(
            driver=driver.__class__.__name__)
    return wrapper
def convert_str(text):
    """Convert to native string.

    ``bytes`` are decoded from UTF-8; every other value is returned as is.
    """
    if not isinstance(text, bytes):
        return text
    return text.decode('utf-8')
def translate_string_size_to_float(string, multiplier='G'):
    """Translates human-readable storage size to float value.

    A bare number is treated as a size in bytes.  A number followed by one
    of the suffixes K, M, G, T or P is scaled accordingly, and a comma is
    accepted as the decimal separator.  The result is expressed in units
    of ``multiplier``, using these factors relative to K:

        K - kilo | 1
        M - mega | 1024
        G - giga | 1024 * 1024
        T - tera | 1024 * 1024 * 1024
        P - peta | 1024 * 1024 * 1024 * 1024

    :param string: size to translate, e.g. ``'1073741824'``, ``'10G'`` or
        ``'1,5T'``.
    :param multiplier: unit of the returned value; one of K, M, G, T, P.
    :returns: float if correct input data provided, None if incorrect
        (including a suffix without a parsable number, such as ``'G'``).
    :raises exception.ManilaException: if ``multiplier`` is unsupported.
    """
    if not isinstance(string, str):
        return None
    multipliers = ('K', 'M', 'G', 'T', 'P')
    mapping = {unit: 1024.0 ** power
               for power, unit in enumerate(multipliers)}
    if multiplier not in multipliers:
        raise exception.ManilaException(
            "'multiplier' arg should be one of following: "
            "'%(multipliers)s'. But it is '%(multiplier)s'." % {
                'multiplier': multiplier,
                'multipliers': "', '".join(multipliers),
            }
        )
    try:
        # Bare number: bytes -> K, then K -> requested unit.
        value = float(string.replace(",", ".")) / 1024.0
        return value / mapping[multiplier]
    except (ValueError, TypeError):
        pass

    matched = re.match(
        r"^(\d*[.,]*\d*)([%s])$" % ''.join(multipliers), string)
    if not matched:
        return None
    number, suffix = matched.groups()
    try:
        # The replace() is needed in case decimal separator is a comma
        value = float(number.replace(",", "."))
    except ValueError:
        # The pattern also admits digit-less or malformed numbers such as
        # 'G', '.G' or '1..5G'; report them as incorrect input.
        return None
    return value * (mapping[suffix] / mapping[multiplier])
def wait_for_access_update(context, db, share_instance,
                           migration_wait_access_rules_timeout):
    """Poll a share instance until its access rules status is active.

    After the n-th non-active poll the function sleeps ``1.414 ** n``
    seconds, so the delay grows by roughly sqrt(2) each time.

    :param migration_wait_access_rules_timeout: timeout in seconds.
    :raises exception.ShareMigrationFailed: if the status becomes
        ``SHARE_INSTANCE_RULES_ERROR`` or the timeout elapses first.
    """
    deadline = time.time() + migration_wait_access_rules_timeout
    attempts = 0

    while True:
        instance = db.share_instance_get(context, share_instance['id'])
        status = instance['access_rules_status']
        if status == constants.STATUS_ACTIVE:
            return

        attempts += 1
        now = time.time()
        if status == constants.SHARE_INSTANCE_RULES_ERROR:
            msg = _("Failed to update access rules"
                    " on share instance %s") % share_instance['id']
            raise exception.ShareMigrationFailed(reason=msg)
        if now > deadline:
            msg = _("Timeout trying to update access rules"
                    " on share instance %(share_id)s. Timeout "
                    "was %(timeout)s seconds.") % {
                'share_id': share_instance['id'],
                'timeout': migration_wait_access_rules_timeout}
            raise exception.ShareMigrationFailed(reason=msg)
        # NOTE(review): the deadline is only re-checked after waking up and
        # the sleep is not capped, so a late backoff can overshoot the
        # timeout considerably.
        time.sleep(1.414 ** attempts)
class DoNothing(str):
    """Class that literally does nothing.

    Calling an instance, or reading any attribute it does not already have,
    returns the instance itself, so arbitrarily chained calls are no-ops.
    We inherit from str in case it's called with json.dumps.
    """

    def __call__(self, *args, **kwargs):
        return self

    def __getattr__(self, name):
        # Only reached for attributes normal lookup could not find.
        return self


# Shared no-op instance; it equals the empty string.
DO_NOTHING = DoNothing()
def notifications_enabled(conf):
    """Check if oslo notifications are enabled.

    :returns: the (falsy) empty set when no notification driver is
        configured; otherwise True unless the only driver is ``noop``.
    """
    drivers = set(conf.oslo_messaging_notifications.driver)
    if not drivers:
        return drivers
    return drivers != {'noop'}
def if_notifications_enabled(function):
    """Calls decorated method only if notifications are enabled.

    When notifications are disabled the call is skipped and ``DO_NOTHING``
    is returned in place of the function's result.
    """
    @functools.wraps(function)
    def wrapped(*args, **kwargs):
        if not notifications_enabled(CONF):
            return DO_NOTHING
        return function(*args, **kwargs)
    return wrapped
def write_remote_file(ssh, filename, contents, as_root=False):
    """Write ``contents`` to ``filename`` on a remote host.

    The data is streamed into ``<filename>.tmp`` and then moved over
    ``filename`` with ``mv -f``.  With ``as_root`` both steps go through
    ``sudo`` (``sudo tee`` is used for the write step).

    :param ssh: object whose ``exec_command(cmd)`` returns a
        ``(stdin, stdout, stderr)`` triple -- presumably a paramiko
        SSHClient; TODO confirm.
    """
    # NOTE(review): neither command's exit status is checked, and the mv
    # is issued without waiting for the write command to finish -- confirm
    # this cannot race.
    tmp_filename = "%s.tmp" % filename
    if as_root:
        write_tpl, move_tpl = ('sudo tee "%s" > /dev/null',
                               'sudo mv -f "%s" "%s"')
    else:
        write_tpl, move_tpl = ('cat > "%s"', 'mv -f "%s" "%s"')

    stdin, _stdout, _stderr = ssh.exec_command(write_tpl % tmp_filename)
    stdin.write(contents)
    stdin.close()
    stdin.channel.shutdown_write()
    ssh.exec_command(move_tpl % (tmp_filename, filename))
def convert_time_duration_to_iso_format(time_duration):
    """Convert a time duration to an ISO 8601 duration string.

    :param time_duration: string starting with an integer and a unit, e.g.
        ``'30 minutes'``, ``'6 hours'``, ``'7 days'``, ``'2 months'`` or
        ``'1 years'``.
    :returns: e.g. ``'PT30M'``, ``'PT6H'``, ``'P7D'``, ``'P2M'``, ``'P1Y'``.
    :raises exception.ManilaException: if ``time_duration`` does not start
        with such a value.
    """
    unit_mapping = {
        'minutes': 'M',
        'hours': 'H',
        'days': 'D',
        'months': 'M',
        'years': 'Y',
    }
    pattern = re.compile(r'(\d+)\s*(minutes|hours|days|months|years)')
    match = pattern.match(time_duration)
    if not match:
        raise exception.ManilaException(
            f"Invalid time duration format: {time_duration}")
    value, unit = match.groups()
    # ISO 8601 (PnYnMnDTnHnMnS): hours and minutes follow the 'T'
    # designator, while days, months and years belong to the date part, so
    # 'PT5D' is invalid.  Minutes and months share 'M' and are told apart
    # only by the 'T'.
    if unit in ("minutes", "hours"):
        return f"PT{value}{unit_mapping[unit]}"
    return f"P{value}{unit_mapping[unit]}"