#!/usr/bin/env python3

# sshdo - controls which commands may be executed via incoming ssh
#
# Copyright (C) 2018-2023 raf <raf@raf.org>
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see https://www.gnu.org/licenses/.
#
# 20230619 raf <raf@raf.org>

from __future__ import print_function
import os, sys, re, getopt, syslog, pwd, grp, glob, gzip, copy, operator, functools
# True when running under Python 2 (used to select version-specific behaviour)
py2 = sys.version_info[0] == 2
# Lowercase alias for None, used throughout this file
none = None

# Program identification: shown by --help and --version. prog_name is also
# the syslog ident and is used to recognise sshdo messages in log files.
prog_name = 'sshdo'
prog_version = '1.1.1'
prog_date = '20230619'
prog_author = 'raf <raf@raf.org>'
prog_url = 'https://raf.org/sshdo'

# Defaults that apply when not overridden by the environment, the command
# line or the configuration file
default_config_fname = '/etc/sshdoers'
default_logfiles = '/var/log/auth.log*'
default_match = 'digits'
# Maps the facility names accepted by the syslog directive to syslog module constants
syslog_facility_map = dict(auth=syslog.LOG_AUTH, daemon=syslog.LOG_DAEMON, user=syslog.LOG_USER, local0=syslog.LOG_LOCAL0, local1=syslog.LOG_LOCAL1, local2=syslog.LOG_LOCAL2, local3=syslog.LOG_LOCAL3, local4=syslog.LOG_LOCAL4, local5=syslog.LOG_LOCAL5, local6=syslog.LOG_LOCAL6, local7=syslog.LOG_LOCAL7)

def main():
	'''Parse the command line and dispatch to the requested mode of operation.

	The --check, --learn and --unlearn options are mutually exclusive. When
	one of them is given, the corresponding function is run with the
	remaining arguments and its return value becomes the exit status. The
	--accepting option is only valid with --learn or --unlearn.

	Otherwise, sshdo() is run (as when used as a forced command). Any
	remaining arguments are joined with spaces to form the label, with each
	whitespace character and colon replaced by an underscore. Without
	arguments, the label is none.

	The configuration file is $SSHDO_CONFIG if set, else default_config_fname,
	and can be overridden with -C/--config.'''
	config_fname = os.environ.get('SSHDO_CONFIG', default_config_fname)
	do_check = do_learn = do_unlearn = do_accepting = 0
	try:
		opts, args = getopt.gnu_getopt(sys.argv[1:], 'hVC:clua', longopts=['help', 'version', 'config=', 'check', 'learn', 'unlearn', 'accepting'])
	except getopt.GetoptError as e:
		fatal(e)
	for opt, arg in opts:
		if opt in ('-h', '--help'):
			print('usage:')
			print(' sshdo [label]               # For use as a forced command')
			print(' sshdo [options]             # For admin use on the command line')
			print('')
			print('options:')
			print(' -h, --help                  - Output the usage message')
			print(' -V, --version               - Output the version message')
			print(' -C, --config configfile     - Override default config: %s' % config_fname)
			print(' -c, --check [configfile...] - Check syntax in configuration files')
			print(' -l, --learn [logfiles...]   - Output config to allow training logs')
			print(' -u, --unlearn [logfiles...] - Output config removing unused commands')
			print(' -a, --accepting             - For learn/unlearn, accept disallowed')
			print('')
			print('usage as a forced command in ~/.ssh/authorized_keys:')
			print('  command="/usr/bin/sshdo [label]" ssh-rsa AAAA...== user@example.net')
			print('')
			print('usage as a forced command in /etc/ssh/sshd_config:')
			print('  Match User user')
			print('  ForceCommand /usr/bin/sshdo [label]')
			print('')
			print('sshdo provides an easily configurable way of controlling which commands')
			print('may be executed via incoming ssh connections.')
			print('')
			print('See the manual entries sshdo(8) and sshdoers(5) for details.')
			print('')
			print('Name: %s' % prog_name)
			print('Version: %s' % prog_version)
			print('Date: %s' % prog_date)
			print('Author: %s' % prog_author)
			print('URL: %s' % prog_url)
			print('')
			print('Copyright (C) 2018-2023 %s' % prog_author)
			print('')
			print('This is free software released under the terms of the GPLv2+:')
			print('')
			print('  https://www.gnu.org/licenses')
			print('')
			print('There is no warranty; not even for merchantability or fitness')
			print('for a particular purpose.')
			print('')
			print('Report bugs to %s' % prog_author)
			sys.exit(0)
		elif opt in ('-V', '--version'):
			print('%s-%s' % (prog_name, prog_version))
			sys.exit(0)
		elif opt in ('-C', '--config'):
			config_fname = arg
		elif opt in ('-c', '--check'):
			do_check = 1
		elif opt in ('-l', '--learn'):
			do_learn = 1
		elif opt in ('-u', '--unlearn'):
			do_unlearn = 1
		elif opt in ('-a', '--accepting'):
			do_accepting = 1
	# The flags are 0 or 1 so their sum counts how many modes were requested
	if do_check + do_learn + do_unlearn > 1:
		fatal('The %s options are mutually exclusive' % ' and '.join((['--check'] if do_check else []) + (['--learn'] if do_learn else []) + (['--unlearn'] if do_unlearn else [])))
	if do_accepting and do_learn + do_unlearn == 0:
		fatal('The --accepting option requires the --learn or --unlearn option')
	if do_check:
		sys.exit(check(config_fname, args))
	if do_learn:
		sys.exit(learn(config_fname, args, do_accepting))
	if do_unlearn:
		sys.exit(unlearn(config_fname, args, do_accepting))
	# Forced command mode. The pattern is a raw string because "\s" is not a
	# valid string escape and newer Python versions warn about it on stderr.
	sshdo(config_fname, s(r'[\s:]', '_', ' '.join(args)) if args else none)

def fatal(msg):
	'''Write "error: " followed by msg to stderr, then exit with status 1'''
	message = 'error: %s' % msg
	print(message, file=sys.stderr)
	raise SystemExit(1)

def sshdo(config_fname, label):
	'''Allow and execute or disallow $SSH_ORIGINAL_COMMAND (logging everything).

	config_fname is the configuration file to load. label is the label given
	on the forced command line, or none.

	The user is taken from $USER and the command from $SSH_ORIGINAL_COMMAND
	with surrounding whitespace removed ('<interactive>' when it isn't set).
	If check_auth() allows the command, or training mode applies, it is
	logged and executed via shell_exec() which does not return. Otherwise,
	it is logged as disallowed, the banner file (if configured) is copied to
	stderr, and the process exits with status 1.'''
	# Load the configuration
	config = load_config(config_fname)
	# Get the user (trust sshd) and the requested command.
	# The patterns are raw strings because "\s" is not a valid string escape.
	ssh_user = os.environ.get('USER')
	ssh_cmd = s(r'^\s+', '', s(r'\s+$', '', os.environ.get('SSH_ORIGINAL_COMMAND', '<interactive>')))
	# Check the command and allow and execute it or disallow it
	auth = check_auth(config, ssh_user, label, ssh_cmd)
	if auth == 'allowed':
		logmsg(config, syslog.LOG_INFO, 'allowed', labelf(label), commandf(ssh_cmd))
		shell_exec(config, ssh_user, label, ssh_cmd)
	if auth.startswith('allowed-group-'):
		# The name of the group that allowed it follows the prefix
		logmsg(config, syslog.LOG_INFO, 'allowed', labelf(label), commandf(ssh_cmd), groupf(auth[len('allowed-group-'):]))
		shell_exec(config, ssh_user, label, ssh_cmd)
	if auth == 'training':
		# Training messages are logged at error priority, like disallowed ones
		logmsg(config, syslog.LOG_ERR, 'training', labelf(label), commandf(ssh_cmd))
		shell_exec(config, ssh_user, label, ssh_cmd)
	if auth.startswith('training-group-'):
		logmsg(config, syslog.LOG_ERR, 'training', labelf(label), commandf(ssh_cmd), groupf(auth[len('training-group-'):]))
		shell_exec(config, ssh_user, label, ssh_cmd)
	if auth == 'disallowed':
		logmsg(config, syslog.LOG_ERR, 'disallowed', labelf(label), commandf(ssh_cmd))
		if 'banner:' in config:
			# An unreadable banner is logged but doesn't change the outcome
			try:
				with open(config['banner:']) as f:
					sys.stderr.write(f.read())
			except IOError as e:
				logmsg(config, syslog.LOG_ERR, 'configerror', filenamef(config['banner:']), errorf(e))
		sys.exit(1)

def check(config_fname, args):
	'''Check the syntax of each configuration file named in args, or of
	config_fname when args is empty. Files without problems are reported as
	OK on stdout. Returns the total number of problems, capped at 255 so
	that it can be used as an exit status.'''
	fnames = args if args else [config_fname]
	total_errors = 0
	for fname in fnames:
		# In verbose mode, load_config reports problems itself and counts them
		_config, errors = load_config(fname, verbose=1)
		if errors == 0:
			print('%s syntax OK' % fname)
		total_errors += errors
	return min(total_errors, 255)

def sshdo_log_messages(config, filenames):
	'''Scan the log files, find and yield sshdo log messages.

	filenames is the list of log files to read. The name "-" means stdin and
	names ending with ".gz" are read via gzip. If filenames is empty, each
	whitespace-separated pattern in the logfiles configuration (or in
	default_logfiles) is globbed, except for "-" which is taken as is.
	Exits via fatal() if that finds no files, or if a file can't be read.

	Yields tuples: (fname, line number, line, msgtype, user, label, cmd, cfg).
	When the message has a group field, user is "+" followed by the group.
	label and cfg are none when the message doesn't have those fields.'''
	# Get the configured or default log files if not supplied on the command line
	if not filenames:
		patterns = config['logfiles:'] if 'logfiles:' in config else default_logfiles
		filenames = []
		for pattern in split(r'\s+', patterns):
			# "-" is stdin, not a glob pattern. It goes in as a list item so
			# that it can be combined with the lists that globbing produces.
			filenames.extend(['-'] if pattern == '-' else sort(glob.glob(pattern)))
	if not filenames:
		fatal('No log files found: %s' % (config['logfiles:'] if 'logfiles:' in config else default_logfiles))
	# These are the same for every line so build them once.
	# Log lines contain " sshdo[pid]" or " sshdo:".
	tag_with_pid, tag_without_pid = ' %s[' % prog_name, ' %s:' % prog_name
	# A field value: any mix of \\, \", \xHH and characters other than \ and "
	valre = r'(?:\\(?:\\|"|x[0-9a-fA-F]{2})|[^\\"])*'
	# The remoteip field is recognised but its value is not captured
	msgre = ' type="(%s)" user="(%s)" (?:remoteip="%s" )?(?:label="(%s)" )?command="(%s)"(?: group="(%s)")?(?: config="(%s)")?' % (valre, valre, valre, valre, valre, valre, valre)
	# Yield sshdo log messages
	for fname in filenames:
		try:
			if fname == '-':
				f = sys.stdin
			elif fname.endswith('.gz'):
				f = gzip.open(fname, mode='rb' if py2 else 'rt')
			else:
				f = open(fname)
			try:
				l = 0
				for line in f:
					l += 1
					# Identify sshdo log messages
					if tag_with_pid not in line and tag_without_pid not in line:
						continue
					# Extract information
					match = m(msgre, line)
					if match is none:
						continue
					msgtype, user, label, cmd, group, cfg = [decode_log(_) if _ is not none else _ for _ in p(match)]
					yield fname, l, line, msgtype, user if group is none else '+' + group, label, cmd, cfg
			finally:
				# Close what was opened here (even if the caller stops early)
				if f is not sys.stdin:
					f.close()
		except IOError as e:
			fatal('Failed to read: %s: %s' % (fname, e))

def learn(config_fname, filenames, accepting):
	'''Output (on stdout) the configuration directives needed to allow the
	commands found in sshdo training and disallowed log messages that the
	current configuration doesn't already allow. Directives for disallowed
	commands are commented out unless accepting is true. Directives for
	interactive logins are always commented out. Returns None.'''
	# Load the current configuration
	config = load_config(config_fname)
	# Examine log files, building up new configuration directives
	trained = {} # trained[cmd][user/label] = '# ' or ''
	for _fname, _lineno, _line, msgtype, user, label, cmd, cfg in sshdo_log_messages(config, filenames):
		# Only consider messages that relate to this configuration file
		# (messages without a config field relate to the default one)
		if cfg != config_fname and not (config_fname == default_config_fname and cfg is none):
			continue
		# Only consider commands that weren't allowed at the time
		if msgtype not in ('training', 'disallowed'):
			continue
		# Skip if currently allowed (might have changed since being logged)
		if check_auth(config, user, label, cmd) == 'allowed':
			continue
		# Learn this, possibly commented out
		commented = (msgtype == 'disallowed' and not accepting) or cmd == '<interactive>'
		userlabel = ('%s/%s' % (user, label)) if label else user
		if cmd not in trained:
			trained[cmd] = {}
		usage = trained[cmd]
		# Once recorded as not commented out, it stays that way
		if commented and usage.get(userlabel) == '':
			commented = False
		usage[userlabel] = '# ' if commented else ''
	if not trained:
		return
	# Output the configuration directives needed to allow these commands.
	# Note: Any -user config will override this.
	method = config['match:'] if 'match:' in config else default_match
	trained = coalesce_commands(method, trained)
	for cmd in sort(trained.keys()):
		usage = trained[cmd]
		# First the active directive, then the commented out one
		for state, fmt in (('', '%s: %s'), ('# ', '# %s: %s')):
			# Exclude users with labels if the same user is present without a label
			users = [u for u in usage if usage[u] == state and not ('/' in u and s('/.*$', '', u) in usage)]
			if users:
				print(fmt % (' '.join(sort(users)), encode_conf(cmd)))

def unlearn(config_fname, filenames, accepting):
	'''Output (on stdout) the current configuration's directives, with those
	that aren't needed by any sshdo log message commented out. Messages for
	disallowed commands only count as usage when accepting is true.
	Interactive logins never count. Directives for -user are always kept.
	Returns None.'''
	# Load the current configuration
	config = load_config(config_fname)
	method = config['match:'] if 'match:' in config else default_match
	# Examine log files, recording configuration usage. The 'match:' item is
	# there because used is passed to check_command in place of a config.
	used = { 'match:': method } # used[user][label][cmd] = 1
	for _fname, _lineno, _line, msgtype, user, label, cmd, cfg in sshdo_log_messages(config, filenames):
		# Only consider messages that relate to this configuration file
		# (messages without a config field relate to the default one)
		if cfg != config_fname and not (config_fname == default_config_fname and cfg is none):
			continue
		# Only consider commands that were executed, unless in accepting mode
		executed = msgtype == 'allowed' or msgtype == 'training'
		if not executed and not (msgtype == 'disallowed' and accepting):
			continue
		# Interactive logins don't count
		if cmd == '<interactive>':
			continue
		# Record this usage
		if user not in used:
			used[user] = {}
		if label not in used[user]:
			used[user][label] = {}
		used[user][label][cmd] = 1

	def was_used(cmd, used_cmds):
		'''Return whether the configured cmd is among the keys of used_cmds, or
		matches one of them when treated as a pattern by check_command'''
		return cmd in used_cmds or any([check_command(used, c, { cmd: '' }) for c in used_cmds.keys()])

	# Compare current configuration against log messages (prepare for output)
	current = {} # current[cmd][user/label] = '# ' or ''
	for user in [u for u in config if not u.endswith(':')]:
		for label in config[user]:
			for cmd in config[user][label]:
				keep = user.startswith('-') or (user in used and label in used[user] and was_used(cmd, used[user][label]))
				# A directive without a label is satisfied by usage under any label
				if not keep and user in used and label is none:
					for used_label in used[user].keys():
						keep = was_used(cmd, used[user][used_label])
						if keep:
							break
				userlabel = ('%s/%s' % (user, label)) if label else user
				current.setdefault(cmd, {})[userlabel] = '' if keep else '# '
	current = coalesce_commands(method, current, 1)
	# Output directives to allow used commands and commented out directives for unused commands
	for cmd in sort(current.keys()):
		usage = current[cmd]
		# First the active directive, then the commented out one
		for state, fmt in (('', '%s: %s'), ('# ', '# %s: %s')):
			# Exclude users with labels if the same user is present without a label
			users = [u for u in usage if usage[u] == state and not ('/' in u and s('/.*$', '', u) in usage)]
			if users:
				print(fmt % (' '.join(sort(users)), encode_conf(cmd)))

def load_config(config_fname, verbose=0):
	'''Parse the configuration files and return a nested dict:
	config[user][label][command]. The user can be preceded by '-' in which
	case it disallows the command. The user can be preceded by '+' in which
	case it is a group. The label may be None in which case it applies to
	all labels and no label.

	config_fname is read first, followed by the files matching
	config_fname + '.d/*'. Settings are stored under keys that end with a
	colon: 'config:' (always config_fname), 'training:' (a nested dict:
	training[user][label]), 'syslog:', 'logfiles:', 'match:' and 'banner:'.
	The syslog, logfiles, match and banner directives, and a training
	directive without users, only take effect in config_fname itself.

	If verbose is true (as for --check), problems are reported on stderr and
	the return value is the tuple (config, errors) where errors is the number
	of problems reported. Otherwise, unrecognised lines and unreadable files
	are logged via logmsg() and only config is returned.

	All regular expression patterns here are raw strings because "\\s" is not
	a valid string escape.'''
	config = {} # config[user][label][cmd] = 1
	config['config:'] = config_fname
	errors = 0
	users_seen = {}
	def errormsg(msg):
		'''Emit msg on stderr and return 1 (to be added to the error count)'''
		print(msg, file=sys.stderr)
		return 1
	for fname in [config_fname] + sort(glob.glob(config_fname + '.d/*')):
		try:
			l = 0
			with open(fname, 'U' if py2 else 'r') as f:
				for line in f:
					l += 1
					# Clean the input (remove comment lines and surrounding whitespace)
					line = s(r'^\s+', '', s(r'\s+$', '', s(r'^\s*#.*$', '', line)))
					if not line:
						continue
					# Parse an authorization directive: user user/label +group/label -user: cmd arg...
					match = m(r'^([+-]?[^\s:]+(?:\s+[+-]?[^\s:]+)*)\s*:\s*(.*)$', line)
					if match is not none:
						usernames, cmd = p(match, 1), decode_conf(p(match, 2))
						# Store the configuration data
						usernames = split(r'\s+', usernames)
						for username in usernames:
							# Is there a label?
							if '/' in username:
								username, label = split('/', username, maxsplit=1)
							else:
								label = none
							# Report possible problems when doing --check
							# NOTE(review): this only happens the first time that
							# each user name is seen, so clashes in a user's later
							# directives aren't reported - confirm this is intended
							if verbose and not username in users_seen:
								users_seen[username] = 1
								# The user name is empty for an item like "/label",
								# so test the prefix without indexing into it
								if username.startswith('+'):
									groupname = username[1:]
									try:
										_ = grp.getgrnam(groupname)
									except KeyError as e:
										errors += errormsg('warning: No such group: %s [%s line %s]' % (groupname, fname, l))
								else:
									actual_username = username[1:] if username.startswith('-') else username
									try:
										_ = pwd.getpwnam(actual_username)
									except KeyError as e:
										errors += errormsg('warning: No such user: %s [%s line %s]' % (actual_username, fname, l))
									# Check for clashes of user and -user:
									# (a) -user clashes with user and user/*
									# (b) -user/label clashes with user and user/label
									# (c) user clashes with -user and -user/*
									# (d) user/label clashes with -user and -user/label
									opposite = ('' if username.startswith('-') else '-') + actual_username
									# Check:
									# user against -user
									# user/label against -user/label
									# -user against user
									# -user/label against user/label
									if opposite in config and label in config[opposite] and cmd in config[opposite][label]:
										errors += errormsg('warning: Clashing allow/disallow: %s%s and %s%s: %s [%s line %s]' % (username, ('/' + label) if label else '', opposite, ('/' + label) if label else '', cmd, fname, l))
									# Check:
									# user/label against -user
									# -user/label against user
									if label is not none and opposite in config and none in config[opposite] and cmd in config[opposite][none]:
										errors += errormsg('warning: Clashing allow/disallow: %s/%s and %s: %s [%s line %s]' % (username, label, opposite, cmd, fname, l))
									# Check:
									# user against -user/*
									# -user against user/*
									if label is none and opposite in config:
										for other_label in sort(config[opposite].keys()):
											if other_label is not none and cmd in config[opposite][other_label]:
												errors += errormsg('warning: Clashing allow/disallow: %s and %s/%s: %s [%s line %s]' % (username, opposite, other_label, cmd, fname, l))
							# Build the configuration data
							if username not in config:
								config[username] = {}
							if label not in config[username]:
								config[username][label] = {}
							config[username][label][cmd] = 1
						continue
					# Parse a training directive: training user user/label +group/label -user
					match = m(r'^training((?:\s+[+-]?[^\s:]+)*)$', line)
					if match is not none:
						usernames = p(match, 1)
						usernames = split(r'\s+', s(r'^\s+', '', usernames)) if usernames else []
						# Training without users is only accepted in the main file
						if usernames or fname == config_fname:
							if 'training:' not in config:
								config['training:'] = {}
							for username in usernames:
								# Is there a label?
								if '/' in username:
									username, label = split('/', username, maxsplit=1)
								else:
									label = none
								# Report possible problems when doing --check
								# (only the first time that each user name is seen)
								if verbose and not username in users_seen:
									users_seen[username] = 1
									# The user name is empty for an item like "/label",
									# so test the prefix without indexing into it
									if username.startswith('+'):
										groupname = username[1:]
										try:
											_ = grp.getgrnam(groupname)
										except KeyError as e:
											errors += errormsg('warning: No such group: %s [%s line %s]' % (groupname, fname, l))
									else:
										actual_username = username[1:] if username.startswith('-') else username
										try:
											_ = pwd.getpwnam(actual_username)
										except KeyError as e:
											errors += errormsg('warning: No such user: %s [%s line %s]' % (actual_username, fname, l))
										# Check for clashes of user and -user:
										# (a) -user clashes with user and user/*
										# (b) -user/label clashes with user and user/label
										# (c) user clashes with -user and -user/*
										# (d) user/label clashes with -user and -user/label
										opposite = ('' if username.startswith('-') else '-') + actual_username
										training = config['training:']
										# Check:
										# user against -user
										# user/label against -user/label
										# -user against user
										# -user/label against user/label
										if opposite in training and label in training[opposite]:
											errors += errormsg('warning: Clashing training mode: %s%s and %s%s [%s line %s]' % (username, ('/' + label) if label else '', opposite, ('/' + label) if label else '', fname, l))
										# Check:
										# user/label against -user
										# -user/label against user
										if label is not none and opposite in training and none in training[opposite]:
											errors += errormsg('warning: Clashing training mode: %s/%s and %s [%s line %s]' % (username, label, opposite, fname, l))
										# Check:
										# user against -user/*
										# -user against user/*
										if label is none and opposite in training:
											for other_label in sort(training[opposite].keys()):
												if other_label is not none:
													errors += errormsg('warning: Clashing training mode: %s and %s/%s [%s line %s]' % (username, opposite, other_label, fname, l))
								# Build the configuration data
								if username not in config['training:']:
									config['training:'][username] = {}
								if label not in config['training:'][username]:
									config['training:'][username][label] = {}
						elif usernames == [] and fname != config_fname and verbose:
							errors += errormsg('warning: Invalid config: %s (only allowed in %s) [%s line %s]' % (line, config_fname, fname, l))
						continue
					# Parse a syslog directive
					match = m(r'^syslog\s+(%s)$' % '|'.join(syslog_facility_map.keys()), line, 'i')
					if match is not none:
						if fname == config_fname:
							if verbose and 'syslog:' in config:
								errors += errormsg('warning: syslog specified more than once: %r and %r [%s line %s]' % (config['syslog:'], p(match, 1).lower(), fname, l))
							config['syslog:'] = p(match, 1).lower()
						elif verbose:
							errors += errormsg('warning: Invalid config: %s (only allowed in %s) [%s line %s]' % (line, config_fname, fname, l))
						continue
					# Parse a logfiles directive (multiple directives accumulate)
					match = m(r'^logfiles\s+(.+)$', line, 'i')
					if match is not none:
						if fname == config_fname:
							config['logfiles:'] = (config['logfiles:'] + ' ' + p(match, 1)) if 'logfiles:' in config else p(match, 1)
							if verbose:
								for pattern in split(r'\s+', config['logfiles:']):
									if pattern != '-' and len(glob.glob(pattern)) == 0:
										errors += errormsg('warning: No such logfiles: %s [%s line %s]' % (pattern, fname, l))
						elif verbose:
							errors += errormsg('warning: Invalid config: %s (only allowed in %s) [%s line %s]' % (line, config_fname, fname, l))
						continue
					# Parse a match directive
					match = m(r'^match\s+(exact|digits|hexdigits)$', line, 'i')
					if match is not none:
						if fname == config_fname:
							if verbose and 'match:' in config:
								errors += errormsg('warning: match specified more than once: %r and %r [%s line %s]' % (config['match:'], p(match, 1).lower(), fname, l))
							config['match:'] = p(match, 1).lower()
						elif verbose:
							errors += errormsg('warning: Invalid config: %s (only allowed in %s) [%s line %s]' % (line, config_fname, fname, l))
						continue
					# Parse a banner directive
					match = m(r'^banner\s+(.*)$', line, 'i')
					if match is not none:
						if fname == config_fname:
							if verbose and 'banner:' in config:
								errors += errormsg('warning: banner specified more than once: %r and %r [%s line %s]' % (config['banner:'], p(match, 1), fname, l))
							config['banner:'] = p(match, 1)
							if verbose and not os.path.isfile(config['banner:']):
								errors += errormsg('warning: No such banner: %s [%s line %s]' % (config['banner:'], fname, l))
						elif verbose:
							errors += errormsg('warning: Invalid config: %s (only allowed in %s) [%s line %s]' % (line, config_fname, fname, l))
						continue
					# Report unrecognised lines
					if verbose:
						errors += errormsg('error: Invalid config: %s [%s line %s]' % (line, fname, l))
					else:
						logmsg(config, syslog.LOG_ERR, 'configerror', filenamef(fname), linenumberf(l), linef(line))
		except (IOError, UnicodeDecodeError) as e:
			if verbose:
				errors += errormsg('error: Failed to read: %s: %s' % (fname, e))
			else:
				logmsg(config, syslog.LOG_ERR, 'configerror', filenamef(fname), errorf(e))
	if verbose and 'logfiles:' not in config and len(glob.glob(default_logfiles)) == 0:
		errors += errormsg('warning: No default log files: %s' % default_logfiles)
	return (config, errors) if verbose else config

def check_auth(config, user, label, cmd):
	'''Decide what to do with the given command for the given user/label.
	Returns 'allowed', 'allowed-group-' + group, 'training',
	'training-group-' + group, or 'disallowed'. Exclusions (-user) are
	checked before anything else, for commands and again for training mode.'''
	# An explicit exclusion beats everything
	if check_config(config, '-' + user, label, cmd):
		return 'disallowed'
	# Allowed for this user?
	if check_config(config, user, label, cmd):
		return 'allowed'
	# Allowed for any of this user's groups?
	for group in get_groups(user):
		if check_config(config, '+' + group, label, cmd):
			return 'allowed-group-' + group
	# Without training mode, everything else is disallowed
	if 'training:' not in config:
		return 'disallowed'
	training = config['training:']
	# Training mode without any users applies to everyone
	if len(training) == 0:
		return 'training'
	# Otherwise, it applies to specific users and groups
	if check_config(training, '-' + user, label):
		return 'disallowed'
	if check_config(training, user, label):
		return 'training'
	for group in get_groups(user):
		if check_config(training, '+' + group, label):
			return 'training-group-' + group
	return 'disallowed'

def check_config(config, user, label, cmd=none):
	'''Return 1 if config has an entry for user, either without a label or
	with the given label, that contains cmd (literally or as a pattern that
	matches it). If cmd is none, the entry's existence is enough. Otherwise,
	return 0.'''
	if user not in config:
		return 0
	user_config = config[user]
	# Entries without a label apply to all labels so they're tried first
	for key in ([none] if label is none else [none, label]):
		if key not in user_config:
			continue
		commands = user_config[key]
		if cmd is none or cmd in commands or check_command(config, cmd, commands):
			return 1
	return 0

def check_command(config, cmd, user_label_config):
	'''Compare a command against config that includes # patterns. cmd is the
	command. user_label_config is a dict with shell command patterns as keys.
	config supplies the match style (via hash_pattern). Returns 1 if any
	pattern containing a hash matches the whole of cmd, otherwise 0.'''
	for command_pattern in sort(user_label_config.keys()):
		# Ignore commands without hashes because they've already been checked
		if '#' not in command_pattern:
			continue
		# Quote/backslash all meta-characters except "/" and "#"
		# (A raw string because "\w" is not a valid string escape)
		command_pattern = s(r'([^/\w#])', r'\\\1', command_pattern)
		# Replace # and ##+ with corresponding patterns (see hash_pattern)
		command_pattern = s('(#+)', lambda match: hash_pattern(config, p(match, 1)), command_pattern)
		# Does the command match this pattern?
		if m('^' + command_pattern + '$', cmd) is not none:
			return 1
	return 0

def hash_pattern(config, hashes):
	'''Return the regular expression fragment that a run of hashes in a command
	pattern stands for. The match style comes from config (or default_match).
	For digits: one hash matches a hash or one or more decimal digits, and
	several hashes match that many characters, each a hash or a decimal digit.
	For hexdigits: the same, but with hexadecimal digits. For anything else
	(i.e. exact), the hashes are returned unchanged.'''
	match_style = config['match:'] if 'match:' in config else default_match
	if match_style == 'digits':
		digit_class = '0-9'
	elif match_style == 'hexdigits':
		digit_class = '0-9a-fA-F'
	else:
		return hashes # exact
	count = len(hashes)
	if count == 1:
		return '(?:#|[%s]+)' % digit_class
	return '[#%s]{%s}' % (digit_class, count)

def shell_exec(config, user, label, cmd):
	'''Replace this process with the user's login shell (/bin/sh if they
	don't have one) running cmd via -c. If cmd is "<interactive>", the shell
	is started as a login shell instead (its argv[0] starts with "-").
	This never returns: on failure, it logs an execerror message and exits
	with status 1.'''
	try:
		pw_shell = pwd.getpwnam(user).pw_shell
	except KeyError as e: # Can't happen if invoked by sshd
		logmsg(config, syslog.LOG_ERR, 'execerror', labelf(label), commandf(cmd), errorf('Invalid user'))
		sys.exit(1)
	login_shell = pw_shell or '/bin/sh'
	# The shell's name without its directory
	sh = s('^.*/', '', login_shell)
	if cmd == '<interactive>':
		argv = ['-' + sh]
	else:
		argv = [sh, '-c', cmd]
	try:
		os.execv(login_shell, argv)
	except OSError as e:
		logmsg(config, syslog.LOG_ERR, 'execerror', labelf(label), commandf(cmd), errorf(e))
	# Only reached if the exec failed
	sys.exit(1)

def coalesce_commands(method, commands, unlearning=0):
	'''commands = coalesce_commands(method, commands, unlearning=0)

	Return a coalesced commands data structure replacing the top-level keys,
	which are shell commands, with patterns that match similar commands. If
	unlearning, keys can be patterns.

	The method parameter is one of: "exact", "digits" or "hexdigits".

	The "exact" method: Each hash character represents a literal hash
	character only.

	The "digits" method: Single hash characters represent either a hash
	character or one or more decimal digits. Two or more hash characters
	represent the same number of hash characters or decimal digits (each
	character can either be a hash character or a digit).

	The "hexdigits" method: Single hash characters represent either a hash
	character or one or more hexadecimal digits. Two or more hash characters
	represent the same number of hash characters or hexadecimal digits (each
	character can either be a hash character or a digit).

	The commands parameter is a dict whose keys are shell commands/patterns
	and whose values are dicts whose keys are strings of the form "user" or
	"user/label" and whose values are either "# " or "".

	So, the commands data structure looks like:

		commands["cmd"]["user/label"] = "# " or ""

	It represents commands that were requested by certain users and whether
	("") or not ("# ") they were allowed at the time. If unlearning, it
	represents authorizations and whether or not they were encountered.

	This function examines the top-level command/pattern keys in the
	commands parameter, identifying subsets that are similar and can be
	replaced by a pattern that matches them. The data structure returned
	will have the patterns as the keys and the values will be the combined
	values of all keys in commands that match the pattern.

	Where similar commands keys differ by digit substrings of different
	lengths, the pattern has a single hash character to represent one or
	more digits.

	Where similar commands keys differ by digit substrings with the same
	number of digits, the pattern has the same number of hash characters to
	represent the fixed number of digits.

	The resulting data structure, whose keys are command patterns, has
	values that are dicts that combine the values of all matching commands.
	The values also include an additional key "commands:" whose value is a
	dict whose keys are the commands keys that match the pattern (the values
	are ""). This doesn't apply to the "exact" method.

	In cases where similar commands keys have values whose user-level keys
	are the same (e.g. "sam"), and the corresponding values differ (i.e. ""
	and "# "), the data structure returned will use the "# " value. When
	unlearning, it will use the "" value.

	Here's an illustrative example of the input and output of this function:

		coalesce_commands('digits', {
			'echo 1': { 'sam': '' },
			'echo 12': { 'drew/label': '# ', 'alex': '' },
			'echo 123': { 'jude': '# ' },

			'echo aaa 12': { 'sam': '' },
			'echo aaa 23': { 'drew': '', 'sam': '# ' },
			'echo aaa 34': { 'drew': '# ' },

			'echo bbb 1 11 1 x': { 'drew': '' },
			'echo bbb 1 11 11 x': { 'drew': '' },
			'echo bbb 1 11 111 x': { 'drew': '' },

			'echo ccc 1 11 1 y': { 'drew': '' },
			'echo ccc 2 22 22 y': { 'drew': '' },
			'echo ccc 3 33 333 y': { 'drew': '' },
		})
		==
		{
			'echo #': { 'sam': '', 'drew/label': '# ', 'alex': '', 'jude': '# ', 'commands:': { 'echo 1': '', 'echo 12': '', 'echo 123': '' } },
			'echo aaa ##': { 'sam': '# ', 'drew': '# ', 'commands:': { 'echo aaa 12': '', 'echo aaa 23': '', 'echo aaa 34': '' } },
			'echo bbb 1 11 # x': { 'drew': '', 'commands:': { 'echo bbb 1 11 1 x': '', 'echo bbb 1 11 11 x': '', 'echo bbb 1 11 111 x': '' } },
			'echo ccc # ## # y': { 'drew': '', 'commands:': { 'echo ccc 1 11 1 y': '', 'echo ccc 2 22 22 y': '', 'echo ccc 3 33 333 y': '' } },
		}
	'''

	# With exact matching there is nothing to coalesce: the input is returned
	# as is (the same object, without any "commands:" items)
	if method == 'exact':
		return commands

	# For each command, determine if it's similar to any preceding command
	# that is already in the result. If it isn't, add it to the result. If
	# it is, combine it with that preceding command in the result,
	# maintaining the set of possible patterns for each digits segment and
	# choosing the best available pattern for each segment at the end.

	def digit_patterns(digits):
		'''Return the list of all patterns that represent the digits parameter: e.g.
		digit_patterns('1') == ['1', '#']
		digit_patterns('123') == ['123', '#', '###']'''
		# The multiple hash pattern is left out when it would just be '#' again
		return [digits, '#'] + (['#' * len(digits)] if len(digits) != 1 else [])

	def isstr(v):
		# Literal segments are strings. Digits segments are lists of patterns.
		# This is what tells them apart.
		return isinstance(v, basestring) if py2 else type(v) in [bytes, str]

	def similar_cmd(asegments, bsegments):
		'''Return whether or not asegments is similar to bsegments'''
		if len(asegments) != len(bsegments):
			return 0
		for i in range(len(asegments)):
			a, b = asegments[i], bsegments[i]
			alit, blit = isstr(a), isstr(b)
			# Check that literals pair up and pattern lists pair up
			if alit and not blit or not alit and blit:
				return 0
			# Digit patterns are always compatible
			if not alit:
				continue
			# Check that literals are the same
			if a != b:
				return 0
		return 1

	def coalesce_segments(asegments, bsegments):
		'''Merge asegments into bsegments to represent asegments in addition
		to any other commands that it already represents. asegments and
		bsegments must be similar'''
		nsegments = []
		for i in range(len(asegments)):
			aseg, bseg = asegments[i], bsegments[i]
			# Literal segments are the same in both (they are similar)
			if isstr(aseg):
				nsegments.append(aseg)
				continue
			# First literal digits maybe, then single hash always, then multiple hashes maybe
			alit = not aseg[0].startswith('#')
			blit = not bseg[0].startswith('#')
			# The length of the multiple hash pattern, or 0 if there isn't one
			alen = len(aseg[-1]) if aseg[-1] != '#' else 0
			blen = len(bseg[-1]) if bseg[-1] != '#' else 0
			# Keep literal digits if they match exactly
			nseg = [aseg[0], '#'] if alit and blit and aseg[0] == bseg[0] else ['#']
			# Keep multiple hashes if they match exactly
			if alen and blen and alen == blen:
				nseg.append(aseg[-1])
			nsegments.append(nseg)
		return nsegments

	def best_pattern(patterns):
		'''Given a list of potential patterns (e.g. ['123', '#', '###']),
		return the best'''
		# Literal digits (always first when present) are the most specific.
		# Failing that, the last pattern is the multiple hash pattern when
		# there is one, or else the single hash.
		return patterns[0] if not patterns[0].startswith('#') else patterns[-1]

	def segments_to_cmd(command_segments):
		'''Return a command pattern based on the given segments'''
		return ''.join([_ if isstr(_) else best_pattern(_) for _ in command_segments])

	def coalesce_usage(ausage, busage, cmd, command_segments):
		'''Merge ausage and busage. Where there is a clash between '# ' and
		'', '# ' wins unless unlearning'''
		# Note: The cmd and command_segments parameters aren't used here.
		# busage is copied so that the caller's data isn't modified.
		nusage = copy.deepcopy(busage)
		for u in ausage.keys():
			if u not in nusage:
				nusage[u] = ausage[u]
				continue
			if ausage[u] != nusage[u]:
				nusage[u] = '' if unlearning else '# '
		return nusage

	# Construct result as a list rather than a dict because we need the
	# commands as lists of segments which aren't hashable
	build = []
	# Hashes count as digits so that existing patterns can be coalesced too
	digit_class = '#0-9' if method == 'digits' else '#0-9a-fA-F'
	for cmd in sort(commands.keys()):
		# Split the command into segments: literal, digits, literal, ...
		# Odd-numbered segments are the digits.
		segments = split('([%s]+)' % digit_class, cmd)
		# Replace digits segments with lists of possible replacement patterns
		segments = [segments[i] if i % 2 == 0 else digit_patterns(segments[i]) for i in range(len(segments))]
		# Look for a similar command already in build, or add it
		found = 0
		# build is sorted in place before each scan
		build.sort()
		for i in range(len(build)):
			bsegments, busage = build[i]
			if similar_cmd(segments, bsegments):
				msegments = coalesce_segments(segments, bsegments)
				musage = coalesce_usage(commands[cmd], busage, cmd, msegments)
				# Record which original command this pattern now covers
				musage['commands:'][cmd] = ''
				build[i] = [msegments, musage]
				# There is no break here: the rest of build is still scanned
				found = 1
		if not found:
			usage = copy.deepcopy(commands[cmd])
			usage['commands:'] = { cmd: '' }
			build.append([segments, usage])

	# Convert build to a dict and return it
	coalesced = {}
	for csegments, cusage in build:
		coalesced[segments_to_cmd(csegments)] = cusage
	return coalesced

def get_groups(username):
	'''Return a list of the names of the groups that the given user
	belongs to: the primary group first (when it can be resolved),
	followed by every group that lists the user as a member. An unknown
	username results in an empty list.'''
	# Look up the gid of the primary group
	try:
		primary_gid = pwd.getpwnam(username).pw_gid
	except KeyError:
		return [] # Can't happen if invoked by sshd
	names = []
	try:
		names.append(grp.getgrgid(primary_gid).gr_name)
	except KeyError:
		pass # Only if /etc/group is broken
	# Add the supplementary groups
	names.extend(entry.gr_name for entry in grp.getgrall() if username in entry.gr_mem)
	return names

def logmsg(config, priority, msgtype, *msg):
	'''Send a message made of name="value" fields to syslog at the given
	priority. The facility is looked up from config['syslog:'] when that
	is present, and is auth otherwise. The fields are: type (msgtype),
	user (from $USER, if set), remoteip (from $SSH_CLIENT, if set), the
	pre-formatted fields supplied in msg, and config (only when the config
	file isn't the default one). Fields that are None are left out.'''
	facility = syslog.LOG_AUTH
	if 'syslog:' in config:
		facility = syslog_facility_map[config['syslog:']]
	parts = [logfield('type', msgtype)]
	login = os.environ.get('USER')
	if login:
		parts.append(logfield('user', login))
	ssh_client = os.environ.get('SSH_CLIENT')
	if ssh_client:
		# Keep only the text before the first space
		parts.append(logfield('remoteip', s(' .*$', '', ssh_client)))
	parts.extend(msg)
	config_fname = config['config:']
	if config_fname != default_config_fname:
		parts.append(logfield('config', config_fname))
	syslog.openlog(prog_name, syslog.LOG_PID, facility)
	syslog.syslog(priority, ' '.join(part for part in parts if part is not none))
	syslog.closelog()

def logfield(name, value):
	'''Format a single log message field as name="value", with the value
	converted to a string and escaped by encode_log(). When value is None,
	there is no field and None is returned instead.'''
	if value is none:
		return none
	return '%s="%s"' % (name, encode_log(str(value)))

# Shorthand for the log fields that are used with logmsg().
# Each returns logfield(name, value) for its fixed field name.

def commandf(value):
	'''Return a command="value" log field'''
	return logfield('command', value)

def filenamef(value):
	'''Return a filename="value" log field'''
	return logfield('filename', value)

def linenumberf(value):
	'''Return a linenumber="value" log field'''
	return logfield('linenumber', value)

def linef(value):
	'''Return a line="value" log field'''
	return logfield('line', value)

def errorf(value):
	'''Return an error="value" log field'''
	return logfield('error', value)

def labelf(value):
	'''Return a label="value" log field'''
	return logfield('label', value)

def groupf(value):
	'''Return a group="value" log field'''
	return logfield('group', value)

def encode_log(data):
	'''Escape data for use inside a double-quoted log field. Double quotes
	and backslashes are preceded by a backslash. Control characters
	(0x00-0x1f) are replaced by a backslash followed by xHH (two hex
	digits). This is reversed by decode_log().'''
	def escape(match):
		char = p(match, 1)
		return '\\' + (char if char in '"\\' else 'x%02x' % ord(char))
	return s('([\x00-\x1f"\\\\])', escape, data)

def decode_log(data):
	'''Reverse encode_log(): a backslash followed by a double quote or a
	backslash becomes that character, and a backslash followed by xHH
	becomes the character with that hexadecimal code.'''
	def unescape(match):
		code = p(match, 1)
		return code if code in '"\\' else chr(int(code[1:], 16))
	return s('\\\\(x[0-9a-fA-F]{2}|["\\\\])', unescape, data)

def encode_conf(data):
	'''Encode data for the config file. Data without any control characters
	(0x00-0x1f) is returned unchanged. Otherwise, the result is '<binary> '
	followed by the data with each backslash doubled and each control
	character replaced by a backslash followed by xHH (two hex digits).
	This is reversed by decode_conf().'''
	if not m('[\x00-\x1f]', data):
		return data
	def escape(match):
		char = p(match, 1)
		return '\\' + (char if char == '\\' else 'x%02x' % ord(char))
	return '<binary> ' + s('([\x00-\x1f\\\\])', escape, data)

def decode_conf(data):
	'''Reverse encode_conf(). Data that doesn't start with '<binary>' is
	returned unchanged. Otherwise, the '<binary>' prefix and any whitespace
	that follows it are removed, doubled backslashes become single, and each
	backslash followed by xHH becomes the character with that hexadecimal
	code.'''
	if not data.startswith('<binary>'):
		return data
	def unescape(match):
		code = p(match, 1)
		# The pattern only ever captures a backslash or xHH
		return code if code == '\\' else chr(int(code[1:], 16))
	# Note: The backslash in \\s is doubled because '\s' is an invalid
	# escape sequence in a non-raw string literal (a warning in python3).
	return s('\\\\(x[0-9a-fA-F]{2}|[\\\\])', unescape, s('^<binary>\\s*', '', data))

_re_cache = {}
def r(pattern, opts=''):
	r'''r(pattern, opts='') -> compiled regular expression object

	Like re.compile() but with compact options. opts is a string containing
	any of the characters 'ilmsxu' each corresponding to re.I, re.L etc. If
	the pattern is a unicode object, the option re.U is automatically included.
	A pattern that isn't a string is assumed to be compiled already, and is
	returned unchanged. Compiled patterns are cached, so compiling the same
	pattern with the same opts again returns the same object. An unknown
	character in opts raises KeyError.

	usage: pattern = r('([a-e])\d', 'i')'''
	text_type = type(u'') # unicode in python2, str in python3
	# Not a string? Must already be compiled.
	if not isinstance(pattern, (bytes, text_type)):
		return pattern
	# Check the cache. The key is a tuple rather than a concatenation so
	# that bytes patterns work in python3. It includes the type so that
	# equal str and unicode patterns (python2) don't share an entry,
	# because they are compiled with different flags.
	key = (type(pattern), pattern, opts)
	try:
		return _re_cache[key]
	except KeyError:
		pass
	# Parse opts
	odict = { 'i': re.I, 'l': re.L, 'm': re.M, 's': re.S, 'x': re.X, 'u': re.U }
	flags = 0
	for o in opts:
		flags |= odict[o]
	if isinstance(pattern, text_type):
		flags |= re.U
	# Compile, cache and return
	recomp = _re_cache[key] = re.compile(pattern, flags)
	return recomp

def m(pattern, text, opts='', pos=0, endpos=none):
	'''m(pattern, text, opts='', pos=0, endpos=None) -> re.MatchObject

	Search text for pattern, like re.search(), but with the compact
	options of r(). The search covers text[pos:endpos], where endpos
	defaults to the end of text. Returns None when there is no match.

	usage: match = m(pattern, text)'''
	regex = r(pattern, opts)
	stop = len(text) if endpos is none else endpos
	return regex.search(text, pos, stop)

def s(pattern, rep, text, opts='', count=0):
	'''s(pattern, rep, text, opts='', count=0) -> str

	Substitute matches of pattern in text with rep, like re.sub(), but
	with the compact options of r(). rep is a string or a function that
	takes a match object. A count of zero replaces every match.

	usage: text = s(' +', ' ', text)'''
	regex = r(pattern, opts)
	return regex.sub(rep, text, count)

def p(match, index=none):
	'''p(match, index=None) -> a string or a list of strings

	Extract groups from a match object, tolerating a failed match.
	With an index, this is match.group(index), or None when match is None.
	Without an index, this is match.groups(), or [] when match is None.

	usage: match = m('a(b)c', 'abc'); print(len(p(match)), p(match, 0))'''
	want_all = index is none
	if match is none:
		return [] if want_all else none
	if want_all:
		return match.groups()
	return match.group(index)

def split(pattern, text, opts='', maxsplit=0):
	'''split(pattern, text, opts='', maxsplit=0) -> list

	Split text by the occurrences of pattern, like re.split(), but with
	the compact options of r(). A maxsplit of zero means no limit.

	usage: split(pattern, text)'''
	regex = r(pattern, opts)
	return regex.split(text, maxsplit)

def sort(seq, key=none):
	'''sort(sequence, key=None) -> sequence

	Sort the sequence, optionally by key, and return the result. In python2,
	the sequence itself is sorted in place and returned. In python3, a new
	sorted list is made from the sequence and returned.

	usage: print(sort([4, 3, 2, 1]))'''
	items = seq if py2 else list(seq)
	items.sort(key=key)
	return items

# Only run the command line interface when executed as a script
# (i.e. not when imported as a module)
if __name__ == '__main__':
	main()

# vi:set ts=4 sw=4:
