# Junarata-tehtävä (Finnish: "railway task")

class IntvalSet:
    """A set of natural numbers stored as a list of closed intervals [a, b].

    Intervals are kept in insertion order and are never merged, so
    overlapping intervals make getAll() yield some numbers more than once.
    An interval with a > b is allowed and simply contains nothing.
    """

    def __init__(self):
        self.intvals = []

    def has(self, x):
        """Return True if x lies inside at least one stored interval."""
        return any(lo <= x <= hi for lo, hi in self.intvals)

    def add(self, a, b):
        """Append the closed interval [a, b]; no merging is performed."""
        self.intvals.append((a, b))

    def getAll(self, lowerBdd=0):
        """Yield the members interval by interval, in insertion order.

        Only values >= lowerBdd are produced.
        """
        for lo, hi in self.intvals:
            start = lo if lo > lowerBdd else lowerBdd
            for value in range(start, hi + 1):
                yield value

    def __str__(self):
        return ", ".join(map(str, self.getAll()))
    


def getData(num):
    """Read test case `num` from the file "junar<num>.in" in the current directory.

    The first line of the file is skipped (presumably the count n; it is not
    validated).  Every remaining line that contains something other than
    whitespace is parsed as one integer.  Blank and whitespace-only lines
    are ignored instead of crashing int().

    @param num: test case number used to build the file name.
    @return: list of the integers found after the first line.
    """
    with open("junar" + str(num) + ".in") as f:
        lines = f.read().splitlines()
    return [int(line) for line in lines[1:] if line.strip()]


def getPoss(p):
    """Invert a permutation p of [1..n]: map each value to its 0-based index.

    Slot 0 of the result is never written and stays None, so a value can be
    used directly as an index.
        e.g. getPoss([2,5,3,1,4]) = [None, 3, 0, 2, 4, 1]
    """
    inverse = [None for _ in range(len(p) + 1)]
    for pos, value in zip(range(len(p)), p):
        inverse[value] = pos
    return inverse

"""
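Let's denote the types of orderings of the positions of k-1, k and k+1
(or of k and k+1 when k=1, and of k-1 and k when k=n) like this.
Each digit is the rank of that element's position, read in the order
k-1, k, k+1 (1 = leftmost); '_' marks the missing neighbour: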

123: 0
132: 1
213: 2
231: 3
312: 4
321: 5
_12: 6
_21: 7
12_: 8
21_: 9
"""
def getTyyppi(a, b, c):
    """Classify the relative order of three positions.

    a, b and c are the positions of k-1, k and k+1 (assumed distinct).
    The result encodes the rank pattern of (a, b, c), 1 = leftmost:
        0: 123   1: 132   2: 213   3: 231   4: 312   5: 321
    """
    if not a < b:
        # k lies to the left of k-1
        if c < b:
            return 5
        return 2 if c > a else 4
    # k-1 lies to the left of k
    if c > b:
        return 0
    return 1 if c > a else 3


def getNumOfRounds(poss):
    """How many rounds does it take to gather the numbers 1..n in order.

    A new round starts whenever k+1 lies to the left of k, so the answer is
    1 + #{k | poss[k+1] < poss[k]}.

    @param poss: index-positions of the permutation as built by getPoss,
        i.e. poss[k] is the 0-based index of value k (poss[0] is unused).
    @return: the number of rounds; 0 for an empty permutation (previously
        this raised IndexError).
    """
    n = len(poss) - 1
    if n <= 0:
        return 0
    return 1 + sum(1 for k in range(1, n) if poss[k + 1] < poss[k])

"""
As the above algorithm shows, the number of rounds for a permutation is
1 + #{k | poss[k+1]<poss[k]}
A swap can decrease the rounds by at most 2.
This can be seen from the following:
For each k in [1..n], consider the positions into which k can be moved
so that the number of rounds decreases. Moving k alone can decrease the
rounds by at most 1: k only takes part in the pairs (k-1, k) and (k, k+1),
and when both are out of order no single new position fixes both.
A swap moves two elements, so if the other move is also a decreasing one
we gain 2, except when we swap some k and k+1; that swap decreases the
rounds by only 1.

(Look at the types, and consider for each how it decreases rounds.)

Let's hope the answer will always be 'score 2' (a swap that decreases the
rounds by 2), as such a swap is very likely to exist in a large random
permutation.
"""

def _improvingTargets(x, i, poss, n):
    """Return the IntvalSet of indices j where moving value x (now at index i)
    removes one out-of-order pair among (x-1, x) and (x, x+1).

    Interval endpoints may coincide with the positions of x-1 or x+1; such
    swaps are with an adjacent value and must be rejected by the caller.
    """
    if x == 1:
        t = 6 if poss[x+1] > i else 7
    elif x == n:
        t = 8 if poss[x-1] < i else 9
    else:
        t = getTyyppi(poss[x-1], i, poss[x+1])

    targets = IntvalSet()
    if t == 2:
        # 213: x is left of x-1, x+1 rightmost -> move x between x-1 and x+1
        targets.add(poss[x-1], poss[x+1]-1)
    elif t == 5:
        # 321: both pairs out of order -> move x left of x+1 or right of x-1
        targets.add(0, poss[x+1])
        targets.add(poss[x-1], n-1)
    elif t == 1:
        # 132: x+1 sits between x-1 and x -> move x between x-1 and x+1
        targets.add(poss[x-1]+1, poss[x+1])
    elif t == 7:
        # _21: x == 1 lies right of 2 -> move it left of 2
        targets.add(0, poss[x+1])
    elif t == 9:
        # 21_: x == n lies left of n-1 -> move it right of n-1
        targets.add(poss[x-1], n-1)
    # types 0, 3, 4, 6, 8: no single move of x lowers the number of rounds
    return targets


def countScore2Pairs(arr):
    """Count index pairs i < j whose swap decreases the number of rounds by 2.

    Such a swap needs both moved values to land on an improving position
    for themselves.  Swapping values k and k+1 gains only 1 round, so pairs
    of adjacent values are excluded.

    @param arr: the permutation of [1..n] (assumed; not validated).
    @return: the number of such pairs; 0 when n < 2 (no swap exists; n == 1
        previously raised IndexError).
    """
    n = len(arr)
    if n < 2:
        return 0
    poss = getPoss(arr)
    toConsider = [_improvingTargets(x, i, poss, n) for i, x in enumerate(arr)]

    numPairs = 0
    for i, x in enumerate(arr):
        for j in toConsider[i].getAll(i+1):
            # index i must also be an improving target for the value at j,
            # and the two values must not be consecutive
            if toConsider[j].has(i) and abs(x - arr[j]) > 1:
                numPairs += 1
    return numPairs


def junarata(arr):
    """Solve one instance and return the answer line "n rounds pairs".

    rounds is the original number of rounds minus 2, and pairs is the
    number of swaps counted by countScore2Pairs.  NOTE: this relies on a
    swap that decreases the rounds by 2 existing; if pairs is 0, the
    printed round count is not a valid answer.
    """
    rounds = getNumOfRounds(getPoss(arr))
    pairs = countScore2Pairs(arr)
    return "%d %d %d" % (len(arr), rounds - 2, pairs)


##---------------------------------------------------
## Random tests

import random
def randomJuna(n):
    """Solve a uniformly random permutation of [1..n].

    @return: the three answer fields [n, rounds, pairs] as ints.
    """
    # range() is immutable on Python 3, so shuffle a real list
    a = list(range(1, n+1))
    random.shuffle(a)
    return [int(x) for x in junarata(a).split()]

from collections import Counter
def genJunaData(n, simuN, column=2):
    """Run simuN random instances of size n and tabulate one answer field.

    @param column: index of the field of randomJuna's [n, rounds, pairs]
        result to tabulate: 1 = the rounds field, 2 = the number of
        score-2 pairs (default, same as before).
    @return: [[value, frequency], ...] sorted by value.
    """
    results = [randomJuna(n) for _ in range(simuN)]
    counts = Counter(row[column] for row in results)
    return [[k, counts[k]] for k in sorted(counts)]
    
##-------------------------------------------------------


# Solve putkaposti: print the answer line for each input file junar1.in .. junar9.in.
# Guarded so that importing this module (e.g. to use randomJuna/genJunaData)
# does not try to read the input files.
if __name__ == "__main__":
    for i in range(1, 10):
        print(junarata(getData(i)))


# Disabled timing harness: it is wrapped in a string literal so it never runs.
# Copy its body into the script section to time test case 9.
""" Test execution time
import time
startT = time.time()
print (junarata(getData(9)))
print ("Time: "+str(time.time()-startT))
"""

# Jätä kommentti (Finnish: "Leave a comment")