#!/usr/bin/python


#author Justin Sherrill jsherril@redhat.com
#version: 2.0-beta
#
#



import getopt
import sys
import getpass
from xmlrpclib import Server
from xmlrpclib import Fault
import string
import os
from os import path
from optparse import OptionParser
from rhn.connections import idn_ascii_to_pune

# Directory gather() writes the collected package listings into.
DATA_DIR = "./data/"
# Directory the shipped data files are read from when -D is not given.
DATA_DIR_RHN = "/usr/share/rhn/channel-data/"
# Root of the release tree scanned by gather().
GATHER_DIR = "/mnt/engarchive2/released/"

# Accepted values, shown in the option help and iterated over by gather().
versions = (
   "7",
   "6",
   "5",
   "4",
   "3",
   "2.1",
)
releases = (
   "AS",
   "ES",
   "WS",
   "Server",
   "Client",
   "Desktop",
   "Appliance",
   "ComputeNode",
   "WebServer",
   "Workstation",
)
arches = (
   "i386",
   "ia64",
   "ppc",
   "s390",
   "s390x",
   "x86_64",
   "ppc64",
)
updates = (
   "GOLD",
   "U1",
   "U2",
   "U3",
   "U4",
   "U5",
   "U6",
   "U7",
   "U8",
   "U9",
   "U10",
)
# RHEL 5 yum repos on the discs
subrepos = (
   "VT",
   "Cluster",
   "ClusterStorage",
   "Workstation",
   "ResilientStorage",
   "ScalableFileSystem",
   "LoadBalancer",
   "HighAvailability",
)
# things shipped on other cds
extras = (
   "Extras",
   "Supplementary",
)

# Run-time settings; main() fills these in from the command line.
server = None
user = None
password = None
filename = None
ssl = True
clone = False
bundleSize = None
parser = None

client = None

def get_args():
   """Build the command-line parser and parse the process arguments.

   The parser is also stored in the module-level ``parser`` global.

   Returns the ``(options, args)`` pair from ``OptionParser.parse_args()``.
   """
   global parser
   parser = OptionParser(usage="""spacewalk-create-channel  -l USER  -s SERVER [-p PASSWORD]  -v VERSION -r RELEASE
                -u UPDATE -a ARCH -e EXTRA  [ -D DATAFILE]  [-c SRCCHAN] -d DESTCHAN [-N DESTNAME]  [-P PARENT]
                [-C | --clear] [-n | --nossl]""")

   # connection options
   parser.add_option("-n", "--nossl", action="store_false", dest="ssl", default=True, help="disables the use of SSL when connecting to the server")
   parser.add_option("-C", "--clear", action="store_true", dest="clear_channel", help="remove all channel packages from the channel before adding")
   parser.add_option("-s", "--server", action="store", dest="server", default="localhost", help="the server to connect to (i.e. hostname.domain.com) [default: %default]")
   parser.add_option("-l", "--user", action="store", dest="user", help="the user to connect as")
   parser.add_option("-p", "--password", action="store", dest="password", help="the password to use for the connection (if not specified, will be prompted for)")

   # what to populate the channel with; str.join replaces the deprecated string.join()
   parser.add_option("-v", "--version", action="store", dest="version", help='the version to use (i.e. ' + ', '.join(versions) + ')')
   parser.add_option("-r", "--release", action="store", dest="release", help='the release to use (i.e. ' + ', '.join(releases) + ')')
   parser.add_option("-a", "--arch", action="store", dest="arch", help='the arch to use (i.e. ' + ', '.join(arches) + ')')
   parser.add_option("-u", "--update", action="store", dest="update_level", help='the update level to use (i.e. ' + ', '.join(updates) + ')')
   parser.add_option("-e", "--extra", action="store", dest="extra", help='the child channel/repo to use (i.e. ' + ', '.join(extras) + ')')
   parser.add_option("-D", "--data", action="store", dest="data", help='a data file to use, may use this instead of version, release, update, arch,  and extra')

   # source / destination channels
   parser.add_option("-c", "--sourceChannel", action="store", dest="source_channel", help=' the channel to pull packages from (will be auto detected if not provided)')
   parser.add_option("-d", "--destChannel", action="store", dest="dest_channel", help='the label of the destination channel to push the packages to (will be created if not existing)')
   parser.add_option("-P", "--parentChannel", action="store", dest="parent_channel", default="", help='if specified, and --destChannel does not exist, --destChannel will be created with this parent')
   parser.add_option("-L", "--clone", action="store_true", dest="clone", default=False, help='if creating destChannel, clone it from the source channel, before adding packages (Satellite 5.1 or newer required)')
   parser.add_option("-K", "--skiplist", action="store", dest="skiplist", help="a filename with a list of package names to ignore when adding to the destination channel")
   parser.add_option("-g", "--gather", action="store_true", dest="gather", default=False , help='used to gather data files, generally not useful to most people')
   # the default is a real int so it has the declared type without relying on optparse converting it
   parser.add_option("-b", "--bundle", action="store", dest="bundle", type='int', default=50, help='if you are getting "502 proxy" errors, try using a smaller value (default 50)')
   parser.add_option("-N", "--name", action="store", dest="dest_name", help='if specified, set the descriptive name of the destination channel, else use channel label')

   return parser.parse_args()

def validate_options(options):
   """Check that every mandatory command-line option was supplied.

   ``options`` is the options object returned by get_args().  One message is
   printed per missing option; if anything is missing the process exits with
   status 1.  Returns None when everything required is present.
   """
   missing = []

   if not options.user:
      missing.append("--user")

   if not options.data:
      # without a data file these four values are needed to pick one
      required = (
         ("version", "--version"),
         ("release", "--release"),
         ("arch", "--arch"),
         ("update_level", "--update"),
      )
      for attr_name, flag in required:
         if not getattr(options, attr_name):
            missing.append(flag)

   if not options.dest_channel:
      missing.append("--destChannel")

   # single-argument print(...) behaves the same on Python 2 and 3
   for flag in missing:
      print(flag + " must be specified")

   if missing:
      sys.exit(1)

def main():
   """Script entry point.

   Parses the command line, resolves the data file and the source channel,
   logs in to the server over XML-RPC and hands over to populate().
   Exits with status 1 on any validation or login error.
   """
   global server, user, password, filename, ssl, bundleSize, clone

   (options, args) = get_args()

   if options.gather:
      gather(False)
      sys.exit(0)

   validate_options(options)

   server = idn_ascii_to_pune(options.server)
   user = options.user
   password = options.password
   if password is None:
      password = getpass.getpass()
   #even though this is supposed to be an int already, in RHEL4's python it isn't
   bundleSize = int(options.bundle)
   if bundleSize < 1:
      # a bundle of zero packages would make the push/clear loops spin forever
      print("Error: --bundle must be at least 1")
      sys.exit(1)

   srcChanLabel = options.source_channel
   version = release = arch = extra = None

   #did you provide a data file?  If not we need all the information
   if options.data is None:
      version = options.version
      release = options.release
      update = options.update_level
      if update == "0":
         update = "gold"
      arch = options.arch
      extra = options.extra
      filename = getFileName(version, update, release, arch, extra)
      fullpath = os.path.join(DATA_DIR_RHN, filename)
   else:
      filename = os.path.basename(options.data)
      # absolute, so populate() reads exactly the file the user pointed at
      fullpath = os.path.abspath(options.data)

   if not os.path.exists(fullpath):
      print("Error: data file '" + fullpath + "' does not exist")
      sys.exit(1)

   ssl = options.ssl
   clear = options.clear_channel
   parent = options.parent_channel
   clone = options.clone

   newChannelLabel = options.dest_channel.lower()
   if options.dest_name:
      newChannelName = options.dest_name
   else:
      newChannelName = options.dest_channel

   if srcChanLabel is None:
      print("You have not specified a source channel, we will try to determine it from inputs")
      if options.data is not None:
         # the file name is only needed (and only has to be parseable) for this guess
         try:
            version, release, update, arch, extra = parseDataFilename(filename)
         except (IndexError, ValueError):
            print("Error: cannot determine the source channel from data file name '"
                  + filename + "', please specify --sourceChannel")
            sys.exit(1)
      srcChanLabel = findSrcChan(version, release, arch, extra)
   srcChanLabel = srcChanLabel.lower()

   print("Trying with source channel: " + srcChanLabel)

   #now lets login to the server
   proto = "http"
   if ssl:
      proto = proto + "s"
   client = Server(proto + "://" + server + "/rpc/api")
   if clone:
      # NOTE(review): assumes api.getVersion() returns a dotted version string
      # comparable with "5.1", as the original check implied -- confirm.
      reported = _version_tuple(client.api.getVersion())
      if reported and reported < (5, 1):
         print("--clone cannot be used with a Satellite version older than 5.1")
         sys.exit(1)
   try:
      auth = client.auth.login(user, password)
   except Fault:
      excInfo = sys.exc_info()[1]
      print("\n %s " % excInfo.faultString)
      sys.exit(1)

   skiplist = []
   if options.skiplist:
      skiplist = read_skip_list(options.skiplist)

   populate(client, auth, srcChanLabel, newChannelLabel, newChannelName, fullpath, parent, clear, skiplist=skiplist)


def _version_tuple(version):
   """Turn a dotted version such as "10.11" into a tuple of ints, stopping at the first non-numeric part."""
   numbers = []
   for piece in str(version).split('.'):
      digits = ''
      for char in piece:
         if not char.isdigit():
            break
         digits = digits + char
      if not digits:
         break
      numbers.append(int(digits))
   return tuple(numbers)


def populate(client, auth, srcChannel, newChannel, newChannelName, filename, parent, clear, skiplist=None):
   """Push the packages listed in a data file from srcChannel into newChannel.

   client         -- XML-RPC server proxy
   auth           -- session key returned by auth.login
   srcChannel     -- label of the channel packages are taken from
   newChannel     -- label of the destination channel; created (or cloned
                     from srcChannel when the global ``clone`` is set) if it
                     does not exist yet
   newChannelName -- descriptive name used when creating the destination
   filename       -- data file listing rpm file names, one per line; a bare
                     name is looked up in DATA_DIR_RHN, an absolute path is
                     used as given
   parent         -- parent channel label for a newly created destination
                     (empty for none)
   clear          -- when true, empty the destination before pushing
   skiplist       -- package names that must not be pushed

   Exits with status 1 if the source channel cannot be found.
   """
   if skiplist is None:
      skiplist = []
   skip_names = set(skiplist)

   chanList = client.channel.listSoftwareChannels(auth)

   src_label = None
   arch_label = None
   new_label = None
   for chan in chanList:
      label = getChannelAttr(chan, 'label')
      if label == srcChannel:
         src_label = label
         arch_label = getChannelAttr(chan, 'arch')
      elif label == newChannel:
         new_label = label

   if src_label is None:
      print("Error: Source Channel '" + srcChannel + "' could not be found.")
      sys.exit(1)

   if new_label is None:
      # build the whole message first so the line always ends with a newline
      if clone:
         message = "Cloning channel %s to %s" % (srcChannel, newChannel)
      else:
         message = "Creating channel, %s, with arch %s" % (newChannel, arch_label)
      if parent:
         message = message + " with parent " + parent
      print(message)

      if clone:
         details_map = { 'name':newChannelName, 'label':newChannel, 'summary':newChannel }
         if parent:
            details_map['parent_label'] = parent
         client.channel.software.clone(auth, srcChannel, details_map, True)
      else:
         client.channel.software.create(auth, newChannel, newChannelName, newChannel, getArchLabel(arch_label, client, auth), parent)
   else:
      print("Reusing " + newChannel + " as destination channel")

   if clear:
      # remove what is actually in the destination, not the source's package list
      clearChannel(newChannel, client.channel.software.listAllPackages(auth, newChannel), client, auth)

   src_packs = client.channel.software.listAllPackages(auth, srcChannel)

   # os.path.join keeps an absolute filename as is and prefixes a bare one
   data_file = open(os.path.join(DATA_DIR_RHN, filename))
   try:
      fileList = [line.strip() for line in data_file]
   finally:
      data_file.close()
   # blank lines are not package names
   fileList = [name for name in fileList if name]
   fileList.sort()

   print(str(len(fileList)) + " packages in source file to push.")

   # index the source packages once instead of rescanning them for every rpm;
   # setdefault keeps the first package for a given name-version-release-arch
   packs_by_nvra = {}
   for src_pack in src_packs:
      packs_by_nvra.setdefault(splitPackage(src_pack), src_pack)

   skip_num = 0
   ids_to_add = []
   for rpm in fileList:
      try:
         nvra = splitFilename(rpm)
      except IndexError:
         # not of the form name-version-release.arch.rpm, cannot match anything
         continue
      src_pack = packs_by_nvra.get(nvra)
      if src_pack is None:
         continue
      if getPackageAttr(src_pack, 'name') in skip_names:
         skip_num = skip_num + 1
         continue
      ids_to_add.append(getPackageAttr(src_pack, 'id'))
   pack_num = len(ids_to_add)

   print("Pushing " + str(pack_num) + " packages, please wait.")
   if skiplist:
      print("Skipping " + str(skip_num) + " packages based off of skip list")

   # fall back to a single bundle if no usable bundle size was configured
   step = bundleSize
   if not step or step < 1:
      step = pack_num

   pushed = 0
   while pushed < pack_num:
      # '\r' rewrites the same terminal line on every bundle
      sys.stdout.write('\r%d of %d' % (pushed, pack_num))
      sys.stdout.flush()
      client.channel.software.addPackages(auth, newChannel, ids_to_add[pushed:pushed + step])
      pushed = pushed + step
   if pack_num:
      # wipe the progress text before the summary line
      sys.stdout.write('\r' + ' ' * 20 + '\r')
      sys.stdout.flush()

   print("Successfully pushed " + str(pack_num) + " packages out of " + str(len(fileList)))


def read_skip_list(filename):
   """Read the package names to skip from ``filename``, one per line.

   Surrounding whitespace is stripped and blank lines are dropped.  The
   resulting list is printed and returned.  Exits with status 1 if the file
   does not exist.
   """
   if not os.path.exists(filename):
      print("Specified skip list " + filename + " does not exist!")
      sys.exit(1)

   skip_file = open(filename)
   try:
      nameList = [line.strip() for line in skip_file]
   finally:
      # close explicitly instead of leaving the handle to the garbage collector
      skip_file.close()
   # an empty entry can never match a package name
   nameList = [name for name in nameList if name]
   print(nameList)
   return nameList

def clearChannel(label, packages, client, auth):
   """Remove ``packages`` from the channel ``label``.

   label    -- label of the channel to remove packages from
   packages -- package dicts; their 'id' (or 'package_id') is used
   client   -- XML-RPC server proxy
   auth     -- session key

   Packages are removed in bundles of the global ``bundleSize``.
   """
   print("Clearing channel packages")
   ids_to_remove = [getPackageAttr(pack, 'id') for pack in packages]
   num_to_remove = len(ids_to_remove)

   # an unset or non-positive bundle size would never shrink the work list
   step = bundleSize
   if not step or step < 1:
      step = num_to_remove

   removed = 0
   while removed < num_to_remove:
      # '\r' rewrites the same terminal line on every bundle
      sys.stdout.write('\r%d of %d' % (removed, num_to_remove))
      sys.stdout.flush()
      client.channel.software.removePackages(auth, label, ids_to_remove[removed:removed + step])
      removed = removed + step
   if num_to_remove:
      # wipe the progress text before the final message
      sys.stdout.write('\r' + ' ' * 20 + '\r')
      sys.stdout.flush()
   print("Finished clearing channel")



def getArchLabel(archName, client, auth):
   """Return the label of the channel arch whose name is ``archName``.

   The arches come from channel.software.listArches; None is returned when
   no arch has that name.
   """
   for candidate in client.channel.software.listArches(auth):
      if getArchAttr(candidate, 'name') != archName:
         continue
      return getArchAttr(candidate, 'label')
   return None


def splitPackage(map):
   """Return the package dict's identity as 'name-version-release-arch_label'."""
   fields = ('name', 'version', 'release', 'arch_label')
   return '-'.join([getPackageAttr(map, field) for field in fields])

def splitFilename(filename):
   """Turn 'name-version-release.arch.rpm' into 'name-version-release-arch'.

   The last dot-separated field is dropped and the one before it is
   re-attached with a dash.  Raises IndexError if there is no dot at all.
   """
   pieces = filename.split('.')
   stem = '.'.join(pieces[:-2])
   arch = pieces[-2]
   return stem + '-' + arch
   


def getAttr(map, base, attribute):
   """Look up ``attribute`` in a dict, falling back to ``base + attribute``.

   map       -- dict to read from
   base      -- key prefix tried second, e.g. 'package_'
   attribute -- key tried first

   Returns the value found, or None if neither key is present.
   """
   value = map.get(attribute)
   # identity test: never invokes the value's own comparison methods
   if value is None:
      value = map.get(base + attribute)
   return value

def getPackageAttr(map, attribute):
   """Read a package attribute stored under ``attribute`` or 'package_' + ``attribute``."""
   prefix = 'package_'
   return getAttr(map, prefix, attribute)

def getChannelAttr(map, attribute):
   """Read a channel attribute stored under ``attribute`` or 'channel_' + ``attribute``."""
   prefix = 'channel_'
   return getAttr(map, prefix, attribute)

def getArchAttr(map, attribute):
   """Read an arch attribute stored under ``attribute`` or 'arch_' + ``attribute``."""
   prefix = 'arch_'
   return getAttr(map, prefix, attribute)

def parseDataFilename(file):
   """Split a data file name into its components.

   The name is lowercased and split on '-'; it is expected to look like
   'version-update-release-arch[-extra]', e.g. '6-gold-server-x86_64'.

   Returns the tuple (version, release, update, arch, extra), where extra is
   None when the name has exactly four fields.
   Raises ValueError if the name has fewer than four fields.
   """
   fields = file.lower().split('-')
   if len(fields) < 4:
      raise ValueError("data file name '%s' is not of the form "
                       "version-update-release-arch[-extra]" % file)

   version, update, release, arch = fields[:4]
   extra = None
   if len(fields) > 4:
      extra = fields[4]
   # note the order: release comes before update in the result
   return (version, release, update, arch, extra)

def findSrcChan(version, release, arch, extra = None):
   """Work out the label of the source channel for the given inputs.

   version -- version string, e.g. "6" or "2.1"
   release -- release name, e.g. "Server" or "AS" (case-insensitive)
   arch    -- architecture, e.g. "x86_64"
   extra   -- optional child channel/repo name, e.g. "Supplementary"
              (case-insensitive)

   Returns the channel label as a string.
   """
   low_release = release.lower()

   if version == "2.1":
      # 2.1 uses fixed labels; matched on the lowercased name so that
      # "ES"/"WS" work just like "AS"
      legacy_labels = {
         "as": "redhat-advanced-server-2.1",
         "es": "redhat-ent-linux-i386-es-2.1",
         "ws": "redhat-ent-linux-i386-ws-2.1",
      }
      if low_release in legacy_labels:
         return legacy_labels[low_release]

   if low_release == "computenode":
      low_release = "hpc-node"

   #if we don't have a subrepo/extras we're done
   if extra is None:
      return "rhel-%s-%s-%s" % (arch, low_release, version)

   #else we do, so lets process that
   low_extra = extra.lower()
   extra_map = {"clusterstorage": "cluster-storage", "resilientstorage": "rs",
           "scalablefilesystem": "sfs", "highavailability": "ha", "loadbalancer": "lb"}
   low_extra = extra_map.get(low_extra, low_extra)

   # "extras" goes after the version, every other child before it
   if low_extra == "extras":
      return "rhel-%s-%s-%s-%s" % (arch, low_release, version, low_extra)
   return "rhel-%s-%s-%s-%s" % (arch, low_release, low_extra, version)

#this function has no real use outside of red hat
def gather(clear):
   """Collect package listings from the release tree under GATHER_DIR.

   clear -- when true, DATA_DIR is deleted (via ``rm -rf``) before gathering

   For every combination of versions, updates, releases and arches the
   matching directories are handed to saveRpmList(), which writes one
   listing per repo into DATA_DIR under the name built by getFileName().
   The directory layout that is probed depends on the version.
   """
   if clear:
      os.system("rm -rf " + DATA_DIR)

   if not path.isdir(DATA_DIR):
      os.mkdir(DATA_DIR)

   for version in versions:
      for update in updates:
         for release in releases:
            for arch in arches:
               # versions up to 4: packages live under tree/RedHat/RPMS/
               if float(version) <= 4:
                  repoDir = "RHEL-%s/%s/%s/%s/tree/" % (version, update, release, arch)  
                  fullDir = GATHER_DIR + repoDir
                  saveRpmList(fullDir +  "RedHat/RPMS/", DATA_DIR + getFileName(version, update,release, arch))
                  #save extras
                  suppl = extras[0]
                  repoDir = "RHEL-%s-%s/%s/%s/%s/tree/" % (version, suppl, update, release, arch)
                  fullDir = GATHER_DIR + repoDir 
                  saveRpmList(fullDir +  "RedHat/RPMS/", DATA_DIR + getFileName(version, update,release, arch, suppl))
               # version 5: one directory per release and per sub-repo under os/
               elif float(version) == 5:
                  repoDir = "RHEL-%s-%s/%s/%s/os/" % (version, release, update, arch) 
                  fullDir = GATHER_DIR + repoDir
                  saveRpmList(fullDir +  release + "/", DATA_DIR + getFileName(version, update,release, arch))
                  #now lets handle the other repos:
                  for subrepo in subrepos:
                     saveRpmList(fullDir +  subrepo + "/", DATA_DIR + getFileName(version, update,release, arch, subrepo))
                  #do supplementary
                  suppl = extras[1]
                  extraDir = "RHEL-%s-%s-%s/%s/%s/os/%s/" % (version, release, suppl, update, arch, suppl)
                  fullDir = GATHER_DIR + extraDir
                  saveRpmList(fullDir + "/", DATA_DIR + getFileName(version, update,release, arch, suppl))
               # version 6: a ready-made "listing" file is preferred, Packages/ is the fallback
               elif float(version) == 6:
                  if update == "GOLD":  #RHEL 6 uses X.Y notation instead of X UY
                    update_num = "0"
                  else:
                    update_num = update.replace("U","") #strip the U
                  # RHEL-6/6.0/Server/x86_64/os/
                  repoDir = "RHEL-%s/%s.%s/%s/%s/os/" % (version, version, update_num, release, arch)
                  fullDir = GATHER_DIR + repoDir
                  # saveRpmList returns False when nothing was saved, which triggers the fallback
                  if not saveRpmList(fullDir +  release + "/listing", DATA_DIR + getFileName(version, update,release, arch)):
                      saveRpmList(fullDir +  release + "/Packages/", DATA_DIR + getFileName(version, update,release, arch))
                  #now lets handle the other repos:
                  for subrepo in subrepos:
                     saveRpmList(fullDir +  subrepo + "/listing", DATA_DIR + getFileName(version, update,release, arch, subrepo))
                  #do supplementary
                  suppl = extras[1]
                  fullDir = GATHER_DIR + "RHEL-%s-%s/%s.%s/%s/%s/os/Packages/" % (version, suppl, version, update_num, release, arch)
                  saveRpmList(fullDir + "/", DATA_DIR + getFileName(version, update,release, arch, suppl)) 

                  #do optional
                  fullDir = GATHER_DIR + "RHEL-%s/%s.%s/%s/optional/%s/os/" % (version, version, update_num, release, arch)
                  saveRpmList(fullDir + "/Packages/", DATA_DIR + getFileName(version, update,release, arch, 'optional'))

               # version 7: Packages/ for the base repo, addons/<subrepo> for the rest
               elif float(version) == 7:
                  if update == "GOLD":
                    update_num = "0"
                  else:
                    update_num = update.replace("U","") #strip the U
                  # RHEL-7/7.0/Server/x86_64/os/
                  repoDir = "RHEL-%s/%s.%s/%s/%s/os/" % (version, version, update_num, release, arch)
                  fullDir = GATHER_DIR + repoDir

                  saveRpmList(fullDir + "/Packages/", DATA_DIR + getFileName(version, update,release, arch))

                  #now lets handle the other repos:
                  for subrepo in subrepos:
                     saveRpmList(fullDir +  '/addons/' + subrepo, DATA_DIR + getFileName(version, update,release, arch, subrepo))

                  #do optional
                  fullDir = GATHER_DIR + "RHEL-%s/%s.%s/%s-optional/%s/os/" % (version, version, update_num, release, arch)
                  saveRpmList(fullDir + "/Packages/", DATA_DIR + getFileName(version, update,release, arch, 'optional'))


def saveRpmList(source, filename):
   """Save a package listing for ``source`` into ``filename``.

   source   -- either an existing listing file, which is copied, or a
               directory, whose entries containing "rpm" are written sorted,
               one per line (hidden entries are left out, as ``ls`` does)
   filename -- destination file; an existing one is never overwritten

   Returns True if a listing was written.  Returns False if ``source`` does
   not exist, ``filename`` already exists, or reading/writing failed.
   """
   import shutil  # local import: only this maintenance helper needs it

   if not path.exists(source) or path.exists(filename):
      return False

   print(source)
   try:
      if path.isfile(source):
         shutil.copyfile(source, filename)
      else: # isdir
         names = [name for name in os.listdir(source)
                  if 'rpm' in name and not name.startswith('.')]
         names.sort()
         listing = open(filename, 'w')
         try:
            listing.writelines([name + "\n" for name in names])
         finally:
            listing.close()
   except EnvironmentError:
      # report and carry on so one unreadable directory does not abort a whole gather run
      print("Warning: could not save listing of %s: %s" % (source, sys.exc_info()[1]))
      return False
   return True

def getFileName(version, update, release, arch, additional=None):
   """Build a data file name: 'version-update-release-arch[-Additional]'.

   update and release are lowercased; additional, when given, is
   capitalized and appended as a fifth field.
   """
   fields = (version, update.lower(), release.lower(), arch)
   if additional is not None:
      fields = fields + (additional.capitalize(),)
   # one %s per field, joined with dashes
   template = "-".join(["%s"] * len(fields))
   return template % fields


def help():
   """Print the usage line (once the option parser exists) followed by worked examples."""
   # the parser global is only set by get_args(); skip the usage line until then
   if parser is not None:
      parser.print_usage()
   print("""
Examples
-------

Create channel 'my-stable-channel' for RHEL 6 Server GOLD x86_64:
spacewalk-create-channel  -l admin -s myserver.example.com -v 6 -u gold -r server -a x86_64 -d my-stable-channel

or another way:
spacewalk-create-channel  -l admin -s myserver.example.com  -D /usr/share/rhn/channel-data/6-gold-server-x86_64 -d my-stable-channel

Upgrade previously created channel 'my-stable-channel' to RHEL 6 Server u1 x86_64:
spacewalk-create-channel  -l admin -s myserver.example.com -v 6 -u u1 -r server -a x86_64 -d my-stable-channel

Add the Supplementary channel as a child channel of 'my-stable-channel':
spacewalk-create-channel  -l admin -s myserver.example.com -v 6 -u u1 -r server -a x86_64 -e supplementary -d my-stable-channel-supplementary -P my-stable-channel
""")

def error(msg):
   """Print ``msg`` and terminate the process with exit status -1."""
   # single-argument print(...) works on both Python 2 and 3
   print(msg)
   sys.exit(-1)


# Run the tool only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()
