#!/usr/libexec/platform-python
#
# Copyright (C) 2022 Olaf Kirch <okir@suse.de>

import susetest

# Declare the address resources used by this test suite. Going by the
# function names, an IPv4 address is mandatory while IPv6 is optional.
# NOTE(review): exact semantics are defined by susetest - confirm there.
susetest.requireAddressResource("ipv4_address")
susetest.optionalAddressResource("ipv6_address")

def _command_should_succeed(node, st):
	'''Status check for a command that is expected to succeed.

	Returns True if the status st is truthy. Otherwise a failure quoting the
	command line and the status message is logged on node, and False is
	returned.
	'''
	if st:
		return True

	node.logFailure(f"Command {st.command.commandline} failed: {st.message}")
	return False

def _command_should_fail(node, st):
	'''Status check for a command that is expected to fail.

	Returns True if the status st is falsy. If the command succeeded, a
	failure quoting the command line is logged on node and False is returned.
	'''
	if st:
		node.logFailure(f"Command {st.command.commandline} succeeded (but should have failed)")
		return False

	return True


class SSH:
	'''Shared helper for the SSH tests: nodes, test user and key bookkeeping.

	An instance wraps the driver's client and server nodes. It runs commands
	as the test user, keeps track of the keys created with keygen(), edits
	sshd_config and drives the login attempts used by the tests.
	'''

	def __init__(self, driver):
		'''Remember the driver and its client/server nodes.

		The test user is only looked up when first needed, see client_user
		and server_user.
		'''
		self.driver = driver
		self.client = driver.client
		self.server = driver.server

		# Test user handles, resolved on first access
		self._client_user = None
		self._server_user = None

		# Keys created by keygen(), indexed by key file name
		self._keys = {}
		# Key most recently installed by authorize_key(), if any
		self._authorized = None
		# Whether the client is known to have a ~/.ssh/known_hosts file
		self._haveKnownHosts = False

	def _require_test_user(self, node):
		'''Return the "test-user" resource of node.

		Logs an error on node and raises
		susetest.CriticalResourceMissingError if the user has no uid.
		'''
		user = node.requireUser("test-user")
		if not user.uid:
			node.logError("user %s does not seem to exist" % user.login)
			raise susetest.CriticalResourceMissingError("lacking test-user resource")
		return user

	@property
	def client_user(self):
		'''The test user on the client, looked up on first access.

		The first access also resolves the server side user, because both
		nodes are expected to use the same login name.
		'''
		if self._client_user is None:
			user = self._require_test_user(self.client)
			self._client_user = user

			assert user.login == self.server_user.login, \
				"test-user login differs between client and server"

		return self._client_user

	@property
	def server_user(self):
		'''The test user on the server, looked up on first access.'''
		if self._server_user is None:
			self._server_user = self._require_test_user(self.server)
		return self._server_user

	def getKey(self, keyfile):
		'''Return the Key recorded for keyfile, or None if there is none.'''
		return self._keys.get(keyfile)

	def dropKey(self, keyfile):
		'''Forget the key named keyfile and delete its files on the client.

		Does nothing if keygen() has not recorded a key of that name.
		'''
		key = self.getKey(keyfile)
		if key is None:
			return

		self.runClient(f"rm -f ~/.ssh/{keyfile}*", quiet = True)
		del self._keys[keyfile]

		if self._authorized is key:
			self._authorized = None

	def removeKnownHosts(self):
		'''Delete the test user's known_hosts file on the client.'''
		self.client.logInfo("Removing known_hosts file")
		if self.runClient("rm -f ~/.ssh/known_hosts"):
			# Keep build_ssh_command() in sync with the file being gone
			self._haveKnownHosts = False

	def checkKnownHosts(self):
		'''Check whether the client has a known_hosts file.

		The result is remembered for build_ssh_command() and returned.
		'''
		self._haveKnownHosts = bool(self.runClient("test -f ~/.ssh/known_hosts"))
		return self._haveKnownHosts

	@property
	def testuser(self):
		'''Login name of the test user (taken from the client).'''
		return self.client_user.login

	@property
	def testuser_home(self):
		'''Home directory of the test user (taken from the client).'''
		return self.client_user.home

	@property
	def testuser_password(self):
		'''Password of the test user (taken from the client).'''
		return self.client_user.password

	def runClient(self, command, **kwargs):
		'''Run command on the client as the test user.

		Extra keyword arguments are passed on to the node's run() method.
		Returns True on success; on failure a warning is written to the
		client's log and False is returned.
		'''
		st = self.client.run(command, user = self.client_user.login, **kwargs)
		if not st:
			self.client.logInfo(f"WARNING: {command} failed: {st.message}")
			return False
		return True

	def runServer(self, command, **kwargs):
		'''Run command on the server as the test user.

		Extra keyword arguments are passed on to the node's run() method.
		Returns True on success; on failure a warning is written to the
		server's log and False is returned.
		'''
		st = self.server.run(command, user = self.server_user.login, **kwargs)
		if not st:
			# Report the problem on the node the command was run on
			self.server.logInfo(f"WARNING: {command} failed: {st.message}")
			return False
		return True

	class Key:
		'''Record of a generated key: its file name and its passphrase.'''
		def __init__(self, name, passphrase = None):
			self.name = name
			self.passphrase = passphrase

	def keygen(self, node = None, passphrase = None, keyalgo = None, keyfile = None, *args):
		'''Create an SSH key for the test user with ssh-keygen.

		node defaults to the client. keyfile defaults to id_<keyalgo>, or to
		id_rsa when no algorithm is given. Additional positional arguments
		are placed in front of the generated ssh-keygen options.

		A key of the same name and passphrase that was created earlier is
		reused; one with a different passphrase is dropped and recreated.

		Returns the key file name, or None (after logging a failure on node)
		if ssh-keygen failed.
		'''
		import shlex

		if node is None:
			node = self.client

		if keyfile is None:
			keyfile = f"id_{keyalgo}" if keyalgo else "id_rsa"

		key = self.getKey(keyfile)
		if key is not None:
			if key.passphrase == passphrase:
				return keyfile

			self.dropKey(keyfile)
		else:
			# Still, better safe than sorry. There may be stuff left from experimentation...
			node.run(f"rm -f ~/.ssh/{keyfile}*", quiet = True, user = self.testuser)

		argv = list(args)
		# The command line contains shell syntax (~, globs), so quote the
		# passphrase; an empty or missing passphrase becomes ''.
		argv += ["-N", shlex.quote(passphrase or "")]
		if keyalgo:
			argv += ["-t", keyalgo]
		argv += ["-f", f"~/.ssh/{keyfile}"]

		command = "ssh-keygen " + " ".join(argv)

		st = node.run(command, user = self.testuser)
		if not st:
			node.logFailure(f"{command} failed: {st.message}")
			return None

		self._keys[keyfile] = self.Key(keyfile, passphrase)
		return keyfile

	def config_change_value(self, node, key, value):
		'''Set key to value in the sshd_config file of node and reload ssh.

		Returns False if the sshd_config file resource (or its path) is not
		available, True after the change was committed and the service was
		asked to reload.
		'''
		config = node.requireFile("sshd_config")
		if not config or not config.path:
			return False

		node.logInfo(f"{config.path}: Setting {key}={value}")

		editor = config.createEditor()
		editor.addOrReplaceEntry(name = key, value = value)
		editor.commit()

		node.logInfo("Reloading SSH service")
		sshd = node.requireService("ssh")
		sshd.reload()

		return True

	def allow_password_auth(self):
		'''Set PasswordAuthentication to yes on the server; see config_change_value().'''
		self.server.logInfo("Allowing password based authentication")
		return self.config_change_value(self.server, "PasswordAuthentication", "yes")

	def deny_password_auth(self):
		'''Set PasswordAuthentication to no on the server; see config_change_value().'''
		self.server.logInfo("Denying password based authentication")
		return self.config_change_value(self.server, "PasswordAuthentication", "no")

	def authorize_key(self, keyfile):
		'''Install the client's public key as the server's authorized_keys.

		The file ~/.ssh/<keyfile>.pub is fetched from the client and written
		to ~/.ssh/authorized_keys of the test user on the server, replacing
		the previous content. Returns True on success, False (after logging
		a failure) otherwise.
		'''
		server = self.server
		client = self.client

		self.driver.logInfo(f"== Authorizing ssh key {keyfile} ==")

		# FIXME: downloading ~/.ssh/something hangs quietly rather than fail
		path = f"{self.testuser_home}/.ssh/{keyfile}.pub"
		data = client.recvbuffer(path, user = self.testuser, quiet = True)
		if not data:
			client.logFailure(f"Failed to download publickey ~/.ssh/{keyfile}.pub")
			return False

		server.run("mkdir -m 0755 -p ~/.ssh", user = self.testuser)

		path = f"{self.testuser_home}/.ssh/authorized_keys"
		if not server.sendbuffer(path, data, user = self.testuser, quiet = True):
			server.logFailure("Failed to upload publickey to ~/.ssh/authorized_keys")
			return False

		# Remember which key is installed so that dropKey() can reset this
		self._authorized = self.getKey(keyfile)
		return True

	def build_ssh_command(self):
		'''Return the shell command that runs "true" on the host named server.

		Strict host key checking is switched off unless checkKnownHosts()
		has found a known_hosts file.
		'''
		command = ["ssh"]
		if not self._haveKnownHosts:
			command.append("-oStrictHostKeyChecking=no")
		command.append("server true")
		return " ".join(command)

	def try_keyauth(self, check_status = _command_should_succeed):
		'''Run the ssh command non-interactively on the client.

		check_status(client, status) decides the outcome; its result is
		returned.
		'''
		client = self.client

		command = self.build_ssh_command()

		st = client.run(command, user = self.testuser)
		return check_status(client, st)

	def try_password(self, check_status = _command_should_succeed):
		'''Run the ssh command and answer the password prompt.

		The server's authorized_keys file is removed first so that public
		key authentication cannot succeed. check_status(client, status)
		decides the outcome; its result is returned.
		'''
		client = self.client

		if self.runServer("rm -f ~/.ssh/authorized_keys"):
			# No key is authorized on the server any more
			self._authorized = None

		chat_script = [
			["assword:", self.testuser_password],
		]

		command = self.build_ssh_command()

		st = client.runChatScript(command, chat_script, user = self.testuser, tty = True, timeout = 10, timeoutOkay = True)
		return check_status(client, st)

	def try_passphrase(self, passphrase, check_status = _command_should_succeed):
		'''Run the ssh command and answer the key passphrase prompt.

		check_status(client, status) decides the outcome; its result is
		returned.
		'''
		client = self.client

		chat_script = [
			["nter passphrase for key", passphrase],
		]

		command = self.build_ssh_command()

		st = client.runChatScript(command, chat_script, user = self.testuser, tty = True, timeout = 10)
		return check_status(client, st)

# The SSH helper instance shared by all tests in this file. It is created by
# setup(), or lazily by scp_simple() if it is still unset at that point.
ssh = None

# meh - this creates a setup function for the test GROUP, not for the entire test suite
@susetest.setup
def setup(driver):
	'''Ensure we have all the resources this test suite requires'''
	global ssh

	# Create the helper object that the tests below access through the
	# module-level name "ssh".
	ssh = SSH(driver)

@susetest.test
def ssh_keygen(driver):
	'''keygen: verify that we can generate SSH keys'''
	# On failure keygen() has already logged the problem and returned nothing
	keyfile = ssh.keygen()
	if keyfile:
		# List the directory so that the new key files show up in the log
		ssh.runClient("ls ~/.ssh -l")
		driver.logInfo("keygen okay")

@susetest.test
def ssh_simple_keyauth(driver):
	'''keyauth: verify that the SSH key we generated can be used for authentication'''
	# Preparation: create a key (default settings) and authorize it on the
	# server. Both helpers log their own failures.
	keyfile = ssh.keygen()
	if keyfile is None or not ssh.authorize_key(keyfile):
		return

	# The actual test: a non-interactive login has to succeed
	if not ssh.try_keyauth():
		return False

	driver.logInfo("OK, RSA key authentication seems to work")

@susetest.test
def ssh_keyauth_with_hostkey(driver):
	'''verify-hostkey: verify that the SSH client asks the user to verify the host key'''
	keyname = "id_rsa"

	# Preparation: key based login must be possible
	if not (ssh.keygen(keyfile = keyname) and ssh.authorize_key(keyname)):
		return

	client = driver.client

	# Without a known_hosts file the client has no record of the server's host key
	ssh.removeKnownHosts()

	# Wait for the host key question, then answer "yes" at the question mark
	chat_script = [
		["continue connecting", None],
		["?", "yes"],
	]

	st = client.runChatScript("ssh server true", chat_script, user = ssh.testuser, timeout = 10)
	if not st:
		client.logFailure(f"ssh failed: {st.message}")
		outcome = False
	elif not ssh.checkKnownHosts():
		# Accepting the host key should have recorded it in known_hosts
		client.logFailure("BAD, cannot find known_hosts file")
		outcome = False
	else:
		driver.logInfo("OK, ssh client asked for host key, and known_hosts file was created")
		outcome = True

	return outcome

@susetest.test
def ssh_keyauth_with_passphrase(driver):
	'''keyauth-passphrase: verify that passphrase protected keys work'''
	secret = "funkyP@ssphrase"

	# Create (or reuse) an RSA key protected by the passphrase and authorize it
	keyfile = ssh.keygen(keyalgo = "rsa", passphrase = secret)
	if not (keyfile and ssh.authorize_key(keyfile)):
		return

	# Log in, answering the passphrase prompt
	if ssh.try_passphrase(secret):
		driver.logInfo("OK, ssh client asked for pass phrase")
		return True

@susetest.test
def ssh_keyauth_with_agent(driver):
	'''keyauth-agent: verify that ssh-agent works'''
	client = driver.client
	secret = "funkyP@ssphrase"

	# Create (or reuse) a passphrase protected RSA key and authorize it
	keyfile = ssh.keygen(keyalgo = "rsa", passphrase = secret)
	if not (keyfile and ssh.authorize_key(keyfile)):
		return

	# Start an agent, add the key to it, then log in. All of this runs in a
	# single shell command line.
	steps = [
		"eval `ssh-agent`",
		"ssh-add",
		"ssh -oStrictHostKeyChecking=no server true",
	]
	command = "; ".join(steps)

	# The only prompt we answer is the passphrase question
	chat_script = [
		["assphrase for", secret],
	]

	st = client.runChatScript(command, chat_script,
			user = ssh.testuser, tty = True, timeout = 10)
	if _command_should_succeed(client, st):
		driver.logInfo("OK, ssh-agent seems to work")
		return True

def __ssh_keyauth_with_algo(driver, keyalgo, check_status = _command_should_succeed):
	'''Common body of the per-algorithm key authentication tests.

	Creates a key of type keyalgo, authorizes it on the server and attempts a
	login. check_status decides whether the login is expected to succeed or
	to fail. Returns False if the login attempt did not match the
	expectation, and None otherwise.
	'''
	keyfile = ssh.keygen(keyalgo = keyalgo)
	if not (keyfile and ssh.authorize_key(keyfile)):
		return

	if not ssh.try_keyauth(check_status):
		return False

	# Word the success message according to the expected outcome
	if check_status == _command_should_fail:
		verdict = "fails as expected"
	else:
		verdict = "seems to work"

	driver.logInfo(f"OK, {keyalgo.upper()} key authentication {verdict}")

@susetest.test
def ssh_keyauth_dsa(driver):
	'''keyauth-dsa: verify that DSA key authentication works'''
	# NOTE: contrary to what the description above says, this test expects
	# the login with a DSA key to FAIL.
	__ssh_keyauth_with_algo(driver, keyalgo = "dsa", check_status = _command_should_fail)

@susetest.test
def ssh_keyauth_ecdsa(driver):
	'''keyauth-ecdsa: verify that ECDSA key authentication works'''
	# Uses the default status check, i.e. the login has to succeed
	__ssh_keyauth_with_algo(driver, keyalgo = "ecdsa")

@susetest.test
def ssh_password_auth_allowed(driver):
	'''passwordauth-allowed: verify that password authentication works (when allowed)'''
	# If sshd_config could not be updated we would merely be testing whatever
	# the server happens to be configured with, so report that as a failure.
	if not ssh.allow_password_auth():
		driver.logFailure("Unable to allow password authentication in sshd_config")
		return

	# try_password() logs the failure itself if the login does not succeed
	if not ssh.try_password():
		return

	driver.logInfo("OK, password authentication seems to work")

@susetest.test
def ssh_password_auth_denied(driver):
	'''passwordauth-denied: verify that password authentication fails (when denied)'''
	# If sshd_config could not be updated, a rejected login would not prove
	# that the setting works, so report that as a failure.
	if not ssh.deny_password_auth():
		driver.logFailure("Unable to deny password authentication in sshd_config")
		return

	# Here the login is expected to be rejected
	if not ssh.try_password(check_status = _command_should_fail):
		return

	driver.logInfo("OK, password authentication was denied")

@susetest.test
def scp_simple(driver):
	'''scp.simple: verify that we can scp files'''
	global ssh

	# Create the shared helper here if it has not been set up yet
	if ssh is None:
		ssh = SSH(driver)

	# Preparation: key based login must be possible for scp to work unattended
	if not ssh.keygen(keyfile = "id_rsa"):
		return

	if not ssh.authorize_key("id_rsa"):
		return

	client = driver.client
	server = driver.server

	testdata = "too late to pick some daffodils".encode('utf-8')

	# Place the test file in the test user's home directory on the client
	path = f"{ssh.client_user.home}/DATA"
	if not client.sendbuffer(path, testdata, user = ssh.testuser, quiet = True):
		client.logFailure("Unable to upload test file to client")
		return

	st = client.run("scp DATA server:", user = ssh.testuser)
	# Stop if the copy failed (the helper has logged the failure). Only a
	# successful transfer is worth verifying on the server side.
	if not _command_should_succeed(client, st):
		return

	# Fetch the copy from the server and compare it with what was sent
	path = f"{ssh.server_user.home}/DATA"
	data = server.recvbuffer(path, user = ssh.testuser, quiet = True)
	if not data:
		server.logFailure("Failed to download test data")
		return

	if data != testdata:
		driver.logFailure("Test data apparently got corrupted during transfer")
		driver.logInfo(f"   sent:     {testdata}")
		driver.logInfo(f"   received: {data}")
		return

	driver.logInfo("OK, scp seems to work")

# boilerplate tests
# NOTE(review): presumably this instantiates the shared
# 'selinux-verify-subsystem' test template for the ssh subsystem on the
# client node - confirm against the susetest documentation.
susetest.template('selinux-verify-subsystem', 'ssh', 'client')

# When executed as a script, hand control over to susetest
if __name__ == '__main__':
	susetest.perform()
