Junarata-tehtävä 2

Tiedosto fenwickTree.py

class FenwickTree:
    """Fenwick tree (binary indexed tree) of numeric counters.

    Callers use 0-based indices 0 .. n-1; internally the slot for index i
    lives at arr[i+1], and arr[0] is never used.
    """

    def __init__(self, n):
        """Create a tree with n slots, all initially 0."""
        # self.n is the length of the internal 1-based array, i.e. n+1.
        self.n = n+1
        self.arr = [0]*self.n

    def incr(self, ind, val):
        """Increment the value at index ind by val.

        Raises IndexError if ind is outside 0 .. n-1 (previously an index
        that was too large was silently ignored).
        """
        if ind<0 or ind>=self.n-1:
            raise IndexError("Index out of bounds")
        curr = ind+1
        while curr<self.n:
            self.arr[curr] += val
            # Adding the lowest set bit moves to the next node covering ind.
            curr += curr&(-curr)

    def getSum(self, upto):
        """Get the prefix sum of the values at indices 0 .. upto (inclusive).

        A negative upto denotes the empty prefix and yields 0.
        Raises IndexError if upto is larger than the last valid index.
        """
        if upto>=self.n-1:
            raise IndexError("Index out of bounds")
        s = 0
        ind = upto+1
        while ind>0:
            s += self.arr[ind]
            # Removing the lowest set bit jumps to the preceding block.
            ind -= ind&(-ind)
        return s

Tiedosto junarata6.py

from fenwickTree import FenwickTree

class IntvalSet:
    """A set of integers described as a list of closed intervals [a, b].

    The intervals are kept in insertion order in self.intvals as (a, b)
    tuples; they are neither merged nor sorted.
    """

    def __init__(self):
        self.intvals = []

    def has(self, x):
        """Return True if x lies inside at least one stored interval."""
        return any(lo <= x <= hi for (lo, hi) in self.intvals)

    def add(self, a, b):
        """Append the closed interval [a, b] to the set."""
        interval = (a, b)
        self.intvals.append(interval)

    def getAll(self, lowerBdd=0):
        """Yield every number of every stored interval, interval by interval.

        Numbers smaller than lowerBdd are skipped.
        """
        for (lo, hi) in self.intvals:
            first = max(lo, lowerBdd)
            for value in range(first, hi+1):
                yield value

    def __str__(self):
        """Render as "[a, b] U [c, d] ...", or "[]" when there are no intervals."""
        if not self.intvals:
            return "[]"
        rendered = []
        for interval in self.intvals:
            rendered.append("[%d, %d]" % interval)
        return " U ".join(rendered)


def getData(num):
    """Read the integers stored in the input file "junar<num>.in".

    The first line of the file is skipped (presumably it holds the element
    count -- verify against the task's input format). Every following line
    that is not blank is parsed as one integer.

    @param num: number inserted into the file name
    @return: list of the parsed integers, in file order
    """
    with open("junar"+str(num)+".in") as f:
        lines = f.read().split("\n")
    # Filter on strip() so that whitespace-only lines are skipped too;
    # int() would raise ValueError on them.
    return [int(x) for x in lines[1:] if x.strip()]

def getPoss(p):
    """Build the position table of the permutation p of [1..n].

    Entry v of the result is the 0-based index at which the value v
    occurs in p; entry 0 is left as None.
        e.g. getPoss([2,5,3,1,4]) = [None, 3, 0, 2, 4, 1]
    """
    size = len(p)
    table = [None for _ in range(size+1)]
    for index, value in zip(range(size), p):
        table[value] = index
    return table

"""
Let's denote the types of orderings of k-1, k and k+1
(or of k and k+1 when k=1, and of k-1 and k when k=n) like this:

123: 0
132: 1
213: 2
231: 3
312: 4
321: 5
_12: 6
_21: 7
12_: 8
21_: 9
"""
def getTyyppi(a, b, c):
    """Classify the relative order of three values as a type code 0..5.

    Returns 0 for a<b<c, 1 for a<c<=b, 3 for c<=a<b,
    5 when b<=a and c<b, 2 when b<=a, b<=c and a<c, and 4 otherwise.
    """
    if a<b and b<c:
        return 0
    if a<b:
        # b is not below c here; only the order of a and c is left to decide.
        return 1 if a<c else 3
    # From here on b<=a.
    if c<b:
        return 5
    return 2 if a<c else 4


def getNumOfRounds(poss):
    """Count the rounds needed to gather the numbers in increasing order.

    Starts at one round and adds a round every time the position of a
    number is smaller than the position of the number before it.
        @param poss: index-positions of the permutation (entry 0 is ignored)
    """
    rounds = 1
    lastInd = poss[1]
    for ind in poss[2:]:
        if ind<lastInd:
            rounds += 1
        lastInd = ind
    return rounds


def countScore2Pairs(arr):
    """Count unordered pairs of positions that mutually "consider" each other.

    For every position i an IntvalSet toConsider[i] of partner positions is
    built from the ordering type of the neighbouring values (see the type
    table above getTyyppi). A pair {i, k} is counted when k is in
    toConsider[i] and i is in toConsider[k], except pairs whose values
    differ by exactly one, which are subtracted again.
    NOTE(review): judging by the name, these are presumably the swaps that
    improve the round count by 2 -- confirm against the task statement.

    @param arr: permutation of [1..n] as a list
    @return: the number of such pairs as an int (0 when len(arr) < 2)
    """
    n = len(arr)
    if n<2:
        # No pair of positions exists; for n==1 the neighbour lookup
        # poss[x+1] below would also run past the end of poss.
        return 0
    poss = getPoss(arr)
    toConsider = []
    for i, x in enumerate(arr):
        # Ordering type of the values x-1, x, x+1 by position; 6..9 are the
        # two-element cases for the smallest and the largest value.
        if x==1: t = 6 if poss[x+1] > i else 7
        elif x==n: t = 8 if poss[x-1] < i else 9
        else: t = getTyyppi(poss[x-1], i, poss[x+1])

        tempSet = IntvalSet()

        if t==2:
            tempSet.add(poss[x-1], poss[x+1]-1)
        elif t==5:
            tempSet.add(0, poss[x+1])
            tempSet.add(poss[x-1], n-1)
        elif t==1:
            tempSet.add(poss[x-1]+1, poss[x+1])
        elif t==7:
            tempSet.add(0, poss[x+1])
        elif t==9:
            tempSet.add(poss[x-1], n-1)
        # Types 0, 3, 4, 6 and 8 contribute an empty set.

        toConsider.append(tempSet)

    tree = FenwickTree(n)

    # Sweep over the positions k = 0..n-1. events[k] lists the rows i whose
    # interval starts (+1) or has just ended (-1) at k, so that after
    # applying them the tree holds 1 at i exactly when k is in toConsider[i].
    events = [[] for _ in range(n+1)] #grouped by time
    for i, c in enumerate(toConsider):
        for (j1, j2) in c.intvals:
            events[j1].append((i, 1))
            events[j2+1].append((i, -1))

    numPairs = 0
    for k, evs in enumerate(events[:-1]):
        for (i, d) in evs:
            tree.incr(i, d)
        # Rows i inside toConsider[k] that also contain k.
        for (j1, j2) in toConsider[k].intvals:
            numPairs += tree.getSum(j2) - tree.getSum(j1-1)
        #subtract the adjacent ones
        x = arr[k]
        if (x>1 and toConsider[k].has(poss[x-1])
            and toConsider[poss[x-1]].has(k) ): numPairs -= 1
        if (x<n and toConsider[k].has(poss[x+1])
            and toConsider[poss[x+1]].has(k) ): numPairs -= 1

    # Every mutual pair was seen once from each side, so the total is even;
    # floor division keeps the result an int on Python 3 as well.
    return numPairs//2


def junarata(arr):
    """Produce the answer line for the permutation arr.

    The line consists of three space-separated integers: the length of
    arr, the value of getNumOfRounds for arr minus 2, and the value of
    countScore2Pairs for arr.
    """
    roundsBefore = getNumOfRounds(getPoss(arr))
    pairCount = countScore2Pairs(arr)
    fields = (len(arr), roundsBefore-2, pairCount)
    return "%d %d %d" % fields



import random

#a = range(1, 10); random.shuffle(a); print(junarata(a));

# Solve the "putkaposti" task: compute and print the answer line for each
# of the input files junar1.in .. junar9.in.
for i in range(1, 10):
    answer = junarata(getData(i))
    print(answer)



#Test execution time
"""
import time
startT = time.time()
print (junarata(getData(9)))
print ("Time: "+str(time.time()-startT))
"""

Jätä kommentti