#!/usr/bin/env python
#
# Copyright (c) 2003-2007 Andrea Luzzardi <scox@sig11.org>
#
# This file is part of the pam_usb project. pam_usb is free software;
# you can redistribute it and/or modify it under the terms of the GNU General
# Public License version 2, as published by the Free Software Foundation.
#
# pam_usb is distributed in the hope that it will be useful, but WITHOUT ANY
# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE.  See the GNU General Public License for more
# details.
#
# You should have received a copy of the GNU General Public License along with
# this program; if not, write to the Free Software Foundation, Inc., 51 Franklin
# Street, Fifth Floor, Boston, MA 02110-1301 USA.

import fcntl
import getopt
import gi
import os
import pwd
import re
import signal
import subprocess
import sys
import syslog
import threading

gi.require_version('UDisks', '2.0')

from gi.repository import GLib, UDisks

import xml.etree.ElementTree as et

class HotPlugDevice:
	"""Track one UDisks drive, identified by serial number, and report hotplug events.

	Callbacks registered with addCallback() are invoked with 'added' or
	'removed'. Uses the module-level ``udisks`` client and
	``udisksObjectManager``, which must exist before run() is called.
	"""

	def __init__(self, serial):
		self.__udi = None          # drive object currently matched, or None
		self.__serial = serial     # compared with the drive's 'serial' property
		self.__callbacks = []      # callables taking the event name
		self.__running = False     # set by run(); no callbacks fire for the initial scan

	def run(self):
		"""Scan present drives, subscribe to UDisks signals, then block in a GLib main loop."""
		self.__scanDevices()
		self.__registerSignals()
		self.__running = True
		# MainLoop.run() does not return while the loop is alive, so the
		# message must be printed before entering it to be seen at all.
		print('signals registered')
		GLib.MainLoop().run()

	def addCallback(self, callback):
		"""Register *callback*, later called as callback('added') or callback('removed')."""
		self.__callbacks.append(callback)

	def __driveFor(self, udi):
		"""Return the UDisks drive behind *udi*'s block interface, or None."""
		block = udi.get_block()
		if not block:
			return None
		return udisks.get_drive_for_block(block)

	def __scanDevices(self):
		"""Match drives that are already present when watching starts."""
		for udi in udisksObjectManager.get_objects():
			device = self.__driveFor(udi)
			if device:
				self.__deviceAdded(device)

	def __registerSignals(self):
		"""Connect to the object manager's add/remove notifications."""
		for signalName, callback in (('object-added', self.__objectAdded),
				('object-removed', self.__objectRemoved)):
			udisksObjectManager.connect(signalName, callback)

	def __objectAdded(self, _, udi):
		device = self.__driveFor(udi)
		if device:
			self.__deviceAdded(device)

	def __objectRemoved(self, _, udi):
		# NOTE(review): this relies on the drive object still being resolvable
		# when its block object is removed; confirm UDisks' removal ordering.
		device = self.__driveFor(udi)
		if device:
			self.__deviceRemoved(device)

	def __deviceAdded(self, udi):
		"""Adopt *udi* if nothing is matched yet and its serial is ours."""
		if self.__udi is not None:
			return
		if udi.get_property('serial') != self.__serial:
			return
		self.__udi = udi
		if self.__running:
			self.__notify('added')

	def __deviceRemoved(self, udi):
		"""Forget the matched drive if *udi* is that drive."""
		if self.__udi is None:
			return
		if self.__udi != udi:
			return
		self.__udi = None
		if self.__running:
			self.__notify('removed')

	def __notify(self, event):
		"""Invoke every registered callback with *event*."""
		for callback in self.__callbacks:
			callback(event)

class Log:
	"""Minimal syslog front-end for the agent.

	syslog is opened once with ident 'pamusb-agent' on the AUTH facility,
	with the PID attached to each record and a copy written to stderr
	(LOG_PERROR).
	"""

	def __init__(self):
		logopts = syslog.LOG_PID | syslog.LOG_PERROR
		syslog.openlog('pamusb-agent', logopts, syslog.LOG_AUTH)

	def info(self, message):
		"""Record an informational message at NOTICE priority."""
		return self.__logMessage(syslog.LOG_NOTICE, message)

	def error(self, message):
		"""Record an error message at ERR priority."""
		return self.__logMessage(syslog.LOG_ERR, message)

	def __logMessage(self, priority, message):
		"""Hand *message* to syslog at the given *priority*."""
		return syslog.syslog(priority, message)

def usage():
	"""Print the command-line synopsis and terminate with exit status 1."""
	script = os.path.basename(__file__)
	synopsis = 'Usage: %s [--help] [--config=path] [--daemon] [--check=path]'
	print(synopsis % script)
	raise SystemExit(1)

def runAs(uid, gid):
	"""Build a subprocess ``preexec_fn`` that drops privileges to *uid*/*gid*.

	Supplementary groups are reset to those of the account owning *uid*
	(just *gid* if that uid has no passwd entry), so the child does not keep
	the agent's groups. The group list is resolved here, in the parent, so the
	returned function only makes system calls in the forked child.

	:returns: a no-argument callable that sets groups, gid, then uid.
	"""
	try:
		groups = os.getgrouplist(pwd.getpwuid(uid).pw_name, gid)
	except KeyError:
		groups = [gid]

	def set_id():
		# Order matters: once the uid is dropped, the group changes below
		# would no longer be permitted.
		os.setgroups(groups)
		os.setgid(gid)
		os.setuid(uid)
	return set_id

import getopt

# '-c' is the short form of --config; --check intentionally has no short
# form (getopt previously received 'c:' twice and '-c' set both options).
try:
	opts, args = getopt.getopt(sys.argv[1:], "hc:d",
			["help", "config=", "daemon", "check="])
except getopt.GetoptError:
	usage()

options = {'configFile' : '/etc/security/pam_usb.conf',
		'daemon' : False,
		'check' : '/usr/bin/pamusb-check'}

if args:
	usage()

for o, a in opts:
	if o in ('-h', '--help'):
		usage()
	elif o in ('-c', '--config'):
		options['configFile'] = a
	elif o in ('-d', '--daemon'):
		options['daemon'] = True
	elif o == '--check':
		options['check'] = a


if not os.path.exists(options['check']):
	print('%s not found.' % options['check'])
	print("You might specify manually pamusb-check's location using --check.")
	usage()

logger = Log()

# Fail with a log entry instead of a traceback when the config is unusable.
try:
	doc = et.parse(options['configFile'])
except (OSError, et.ParseError) as err:
	logger.error('Unable to read configuration file "%s": %s' % (options['configFile'], err))
	sys.exit(1)
users = doc.findall('users/user')

def _parseAgentEvents(user, userName):
	"""Collect the <agent> hooks configured for one user.

	:param user: the <user> element from the configuration file.
	:param userName: the user's id, used only in log messages.
	:returns: dict mapping 'lock' and 'unlock' to lists of
		{'env': {NAME: value, ...}, 'cmd': [command, ...]} entries.

	Empty commands, malformed environment entries and agents whose
	``event`` attribute is not 'lock' or 'unlock' are logged and skipped.
	"""
	events = {
		'lock' : [],
		'unlock' : []
	}

	for hotplug in user.findall('agent'):
		eventName = hotplug.get('event')
		if eventName not in events:
			logger.error('Ignoring agent with unknown event "%s" for user "%s".' % (eventName, userName))
			continue

		hcmds = []
		for hcmd in hotplug.findall('cmd'):
			if hcmd.text is not None and hcmd.text.strip():
				hcmds.append(hcmd.text)
			else:
				logger.error('Ignoring empty command for user "%s".' % userName)

		henvs = {}
		for henv in hotplug.findall('env'):
			if henv.text is None:
				logger.error('Ignoring empty environment variable for user "%s".' % userName)
				continue
			# Split on the first '=' only so values may themselves contain '=';
			# surrounding whitespace from the XML layout is not part of it.
			henv_var, sep, henv_arg = henv.text.strip().partition('=')
			if sep and henv_var != '' and henv_arg != '':
				henvs[henv_var] = henv_arg
			else:
				logger.error('Ignoring invalid command environment variable for user "%s".' % userName)

		events[eventName].append({'env': henvs, 'cmd': hcmds})

	return events

def _runEventCommands(entries, uid, gid):
	"""Run every command of *entries* (as built by _parseAgentEvents) as uid/gid.

	Commands are split with shell-like quoting and executed without a shell;
	each gets exactly the environment configured for its <agent> element.
	Failures are logged so one broken command does not stop the others.
	"""
	import shlex

	for entry in entries:
		for cmd in entry['cmd']:
			logger.info('Running "%s"' % cmd)
			try:
				argv = shlex.split(cmd)
			except ValueError as err:
				logger.error('Cannot parse command "%s": %s' % (cmd, err))
				continue
			if not argv:
				continue
			# NOTE(review): preexec_fn is documented as unsafe in threaded
			# programs; on Python >= 3.9, subprocess.run(user=, group=,
			# extra_groups=) would avoid it.
			try:
				subprocess.run(argv, env=entry['env'], preexec_fn=runAs(uid, gid))
			except (OSError, subprocess.SubprocessError) as err:
				logger.error('Failed to run "%s": %s' % (cmd, err))

def userDeviceThread(user):
	"""Watch the device configured for *user* and run its lock/unlock hooks.

	Meant as a thread target: after setup it blocks inside HotPlugDevice.run().
	On removal the 'lock' hooks run; on insertion pamusb-check verifies the
	device and, if it succeeds, the 'unlock' hooks run.

	:param user: the <user> element from the configuration file.
	:returns: 1 if the user's device or its serial is not configured.
	"""
	userName = user.get('id')
	pwEntry = pwd.getpwnam(userName)
	uid = pwEntry.pw_uid
	gid = pwEntry.pw_gid
	# Hook commands receive an explicit env= mapping, so the shared process
	# environment is deliberately left untouched here.

	events = _parseAgentEvents(user, userName)

	deviceNode = user.find('device')
	if deviceNode is None or deviceNode.text is None:
		logger.error('No device configured for user "%s".' % userName)
		return 1
	deviceName = deviceNode.text.strip()

	device = next((d for d in doc.findall("devices/device")
			if d.get('id') == deviceName), None)
	if device is None:
		logger.error('Device %s not found in configuration file.' % deviceName)
		return 1

	serialNode = device.find('serial')
	if serialNode is None or serialNode.text is None:
		logger.error('No serial configured for device "%s".' % deviceName)
		return 1
	serial = serialNode.text.strip()

	def authChangeCallback(event):
		"""React to HotPlugDevice events ('added' / 'removed')."""
		if event == 'removed':
			logger.info('Device "%s" has been removed, ' \
					'locking down user "%s"...' % (deviceName, userName))
			_runEventCommands(events['lock'], uid, gid)
			logger.info('Locked.')
			return

		logger.info('Device "%s" has been inserted. ' \
				'Performing verification...' % deviceName)
		# argv list instead of a shell string: paths and user names are passed
		# verbatim, with no shell parsing.
		checkArgs = [options['check'], '--debug',
				'--config=%s' % options['configFile'],
				'--service=pamusb-agent', userName]
		logger.info('Executing "%s"' % ' '.join(checkArgs))
		try:
			succeeded = subprocess.run(checkArgs).returncode == 0
		except OSError as err:
			logger.error('Failed to execute "%s": %s' % (options['check'], err))
			succeeded = False

		if succeeded:
			logger.info('Authentication succeeded. ' \
					'Unlocking user "%s"...' % userName)
			_runEventCommands(events['unlock'], uid, gid)
			logger.info('Unlocked.')
		else:
			logger.info('Authentication failed for device %s. ' \
				'Keeping user "%s" locked down.' % (deviceName, userName))

	hpDev = HotPlugDevice(serial)
	hpDev.addCallback(authChangeCallback)

	logger.info('Watching device "%s" for user "%s"' % (deviceName, userName))
	hpDev.run()

# Filled in below: user names read from /etc/passwd, and the configured
# <user> elements whose id matches one of them.
sysUsers = []
validUsers = []

def processCheck():
	"""Exit unless running as root and no other agent instance is active.

	Single-instance detection takes a non-blocking exclusive flock() on this
	script file. The handle is kept in the module-level ``filelock`` so it
	stays open, and the lock stays held, for the life of the process. flock()
	locks belong to the open file description, so a forked daemon child keeps
	the lock after the parent exits.

	Raises SystemExit(1) on failure.
	"""
	global filelock

	# Checked first so an unprivileged run reports the actual problem.
	if os.getuid() != 0:
		logger.error('Process must be run as root.')
		sys.exit(1)

	# NOTE(review): this file is usually world-readable and flock() needs
	# only read access, so any local user could hold this lock and keep the
	# agent from starting; a root-only lock file would close that gap.
	filelock = open(os.path.realpath(__file__), 'r')
	try:
		fcntl.flock(filelock, fcntl.LOCK_EX | fcntl.LOCK_NB)
	except OSError:
		logger.error('Process is already running.')
		sys.exit(1)

def sig_handler(sig, frame):
	"""Log and exit cleanly on a termination/stop signal."""
	logger.info('Stopping agent.')
	sys.exit(0)

processCheck()

try:
	with open('/etc/passwd', 'r') as f:
		for line in f:
			# The user name is the first ':'-separated field.
			sysUsers.append(line.rstrip('\n').split(':', 1)[0])
except OSError:
	logger.error('Couldn\'t read system user names from "/etc/passwd". Process can\'t continue.')
	sys.exit(1)

knownUsers = set(sysUsers)
for userObj in users:
	userId = userObj.get('id')
	if userId in knownUsers:
		validUsers.append(userObj)
	else:
		logger.error('User "%s" not found in /etc/passwd, ignoring.' % userId)

# Daemonize before connecting to D-Bus or starting any thread: fork() keeps
# only the calling thread, so threads created earlier (ours and any helper
# threads of the D-Bus connection) would not exist in the child, and the
# parent could not exit while non-daemon threads were still running.
if options['daemon'] and os.fork():
	sys.exit(0)

try:
	udisks = UDisks.Client.new_sync()
except GLib.Error as err:
	logger.error('Unable to connect to UDisks: %s' % err)
	sys.exit(1)
udisksObjectManager = udisks.get_object_manager()

sys_signals = ['SIGINT', 'SIGTERM', 'SIGTSTP', 'SIGTTIN', 'SIGTTOU']

for i in sys_signals:
	signal.signal(getattr(signal, i), sig_handler)

logger.info('pamusb-agent up and running.')

userThreads = []
for user in validUsers:
	thread = threading.Thread(target=userDeviceThread, args=(user,), daemon=True)
	thread.start()
	userThreads.append(thread)

# The watcher threads are daemonic, so the main thread stays alive here.
# Joining with a timeout returns to Python code regularly, which lets the
# signal handlers (executed only in the main thread) run promptly.
for thread in userThreads:
	while thread.is_alive():
		thread.join(1)