package Allocator;

import java.util.HashMap;
import java.util.LinkedList;
import java.util.concurrent.locks.ReentrantReadWriteLock;
import java.util.concurrent.locks.ReadWriteLock;

public class MultithreadedBlockAllocator implements Allocator {

    /** Size of the region carved into small blocks; requests of at least this size are "large". */
    private static final int PAGE_SIZE = 4096;

    /** Largest power of two representable in an int; larger requests cannot be rounded up. */
    private static final int MAX_BLOCK_SIZE = 1 << 30;

    /** A thread drains its pending free requests on the call to free() after this many calls. */
    private static final int FREE_BATCH_THRESHOLD = 20;

    //-----------------------Global Variables-----------------------
    // Pending free request list: key is the id of the thread that allocated the blocks,
    // value is the list of addresses other threads asked to free. Each list is guarded by
    // its own lock in readWriteLocksForPFRLs; both maps are guarded by
    // globalReadWriteLockForPFRL. Entries are never removed, so a list/lock reference
    // obtained under the global lock stays valid after that lock is released.
    private final HashMap<Integer, LinkedList<Long>> PFRL = new HashMap<>();
    private final ReadWriteLock globalReadWriteLockForPFRL = new ReentrantReadWriteLock();
    private final HashMap<Integer, ReentrantReadWriteLock> readWriteLocksForPFRLs = new HashMap<>();

    // Address -> id of the thread that allocated it; guarded by
    // readWriteLockForAllocatedListByThread.
    private final HashMap<Long, Integer> allocatedListByThread = new HashMap<>();
    private final ReadWriteLock readWriteLockForAllocatedListByThread = new ReentrantReadWriteLock();
    //-----------------------Global Variables-----------------------


    //-----------------------Thread Local Variables-----------------------
    // Block size -> free blocks of that size owned by the current thread.
    private final ThreadLocal<HashMap<Integer, LinkedList<Long>>> freeList =
            ThreadLocal.withInitial(HashMap::new);

    // Block start address -> size of the block the current thread handed out there.
    private final ThreadLocal<HashMap<Long, Integer>> allocatedList =
            ThreadLocal.withInitial(HashMap::new);

    // Number of free() calls made by the current thread since its last drain.
    private final ThreadLocal<Integer> timesFreeCalled = ThreadLocal.withInitial(() -> 0);
    //-----------------------Thread Local Variables-----------------------


    /** Creates an allocator that has not mapped any memory yet. */
    public MultithreadedBlockAllocator() {
        // All state is initialised in the field declarations.
    }

    /**
     * Allocates a block of at least {@code size} bytes owned by the calling thread.
     *
     * <p>Requests smaller than {@link #PAGE_SIZE} are served from a page carved into blocks
     * of exactly {@code size} bytes. Larger requests are rounded up to the next power of two
     * and each block is mapped on its own. Because the whole size class is mapped, a recycled
     * block fits any later request of the same class.
     *
     * @param size requested size in bytes; must be positive
     * @return start address of the block
     * @throws IllegalArgumentException if {@code size} is not positive or is too large to
     *     round up to a power of two
     */
    @Override
    public Long allocate(int size) {
        if (size <= 0) {
            throw new IllegalArgumentException("size must be positive: " + size);
        }
        HashMap<Integer, LinkedList<Long>> freeBlocks = freeList.get();
        int blockSize;
        Long address;
        if (size >= PAGE_SIZE) {
            blockSize = smallestFittingSize(size);
            LinkedList<Long> bucket = freeBlocks.computeIfAbsent(blockSize, k -> new LinkedList<>());
            if (bucket.isEmpty()) {
                // Map the full size class (not just `size`) so the block is large enough
                // for any request reusing it. It is returned directly and NOT also
                // put on the free list, otherwise it would be handed out twice.
                address = OperatingSystem.getInstance().mmap(blockSize);
            } else {
                address = bucket.removeFirst();
            }
        } else {
            blockSize = size;
            LinkedList<Long> bucket = freeBlocks.computeIfAbsent(size, k -> new LinkedList<>());
            if (bucket.isEmpty()) {
                // addToFreeListActualSize carves PAGE_SIZE bytes, so map a whole page.
                addToFreeListActualSize(size, OperatingSystem.getInstance().mmap(PAGE_SIZE));
            }
            address = bucket.removeFirst();
        }
        allocatedList.get().put(address, blockSize);

        readWriteLockForAllocatedListByThread.writeLock().lock();
        try {
            allocatedListByThread.put(address, currentThreadId());
        } finally {
            readWriteLockForAllocatedListByThread.writeLock().unlock();
        }
        return address;
    }

    /**
     * Returns the smallest power of two that is greater than or equal to {@code size}.
     *
     * @throws IllegalArgumentException if that power of two does not fit in an int
     */
    private int smallestFittingSize(int size) {
        if (size > MAX_BLOCK_SIZE) {
            // Doubling past 2^30 overflows to a negative value and then 0, looping forever.
            throw new IllegalArgumentException("size too large: " + size);
        }
        int fit = 1;
        while (fit < size) {
            fit <<= 1;
        }
        return fit;
    }

    /**
     * Splits the {@link #PAGE_SIZE}-byte region starting at {@code startAddress} into
     * {@code PAGE_SIZE / size} blocks of {@code size} bytes. It then appends them to the
     * calling thread's free list for that size.
     *
     * @param size block size in bytes; must be positive
     * @param startAddress start of the mapped region
     * @throws IllegalArgumentException if {@code size} is not positive
     */
    public void addToFreeListActualSize(int size, long startAddress) {
        if (size <= 0) {
            throw new IllegalArgumentException("size must be positive: " + size);
        }
        LinkedList<Long> bucket = freeList.get().computeIfAbsent(size, k -> new LinkedList<>());
        int numBlocks = PAGE_SIZE / size; // integer division already rounds down
        for (int j = 0; j < numBlocks; j++) {
            bucket.add(startAddress + (long) j * size);
        }
    }

    /**
     * Releases a block previously returned by {@link #allocate(int)}.
     *
     * <p>If the calling thread allocated the block, the block goes straight back to that
     * thread's free list. Otherwise the block is queued on the owner's pending free request
     * list. The owner moves queued blocks to its free list when it drains, on every
     * ({@value #FREE_BATCH_THRESHOLD} + 1)-th call it makes to this method.
     *
     * @param address start address of the block
     * @throws IllegalArgumentException if {@code address} was never allocated
     */
    @Override
    public void free(Long address) {
        int owner = ownerOf(address);
        int me = currentThreadId();
        if (owner == me) {
            // Owner's structures are thread-local: no need to defer.
            actuallyFree(address);
        } else {
            enqueuePendingFree(owner, address);
        }

        int calls = timesFreeCalled.get() + 1;
        if (calls > FREE_BATCH_THRESHOLD) {
            timesFreeCalled.set(0);
            drainPendingFrees(me);
        } else {
            timesFreeCalled.set(calls);
        }
    }

    /** Returns the id of the thread that allocated {@code address}. */
    private int ownerOf(Long address) {
        Integer owner;
        readWriteLockForAllocatedListByThread.readLock().lock();
        try {
            owner = allocatedListByThread.get(address);
        } finally {
            readWriteLockForAllocatedListByThread.readLock().unlock();
        }
        if (owner == null) {
            throw new IllegalArgumentException("address was never allocated: " + address);
        }
        return owner;
    }

    /** Appends {@code address} to the pending free request list of thread {@code owner}. */
    private void enqueuePendingFree(int owner, Long address) {
        LinkedList<Long> pending;
        ReentrantReadWriteLock pendingLock;
        globalReadWriteLockForPFRL.writeLock().lock();
        try {
            pending = PFRL.computeIfAbsent(owner, k -> new LinkedList<>());
            pendingLock = readWriteLocksForPFRLs.computeIfAbsent(owner, k -> new ReentrantReadWriteLock());
        } finally {
            globalReadWriteLockForPFRL.writeLock().unlock();
        }
        pendingLock.writeLock().lock();
        try {
            pending.add(address);
        } finally {
            pendingLock.writeLock().unlock();
        }
    }

    /**
     * Moves every address queued for thread {@code threadId} back to its free list. Call
     * only from that thread, because it touches the thread-local maps.
     */
    private void drainPendingFrees(int threadId) {
        LinkedList<Long> pending;
        ReentrantReadWriteLock pendingLock;
        globalReadWriteLockForPFRL.readLock().lock();
        try {
            pending = PFRL.get(threadId);
            pendingLock = readWriteLocksForPFRLs.get(threadId);
        } finally {
            globalReadWriteLockForPFRL.readLock().unlock();
        }
        if (pending == null || pendingLock == null) {
            return; // no other thread has ever freed one of our blocks
        }

        // Snapshot and clear atomically so concurrent enqueues are never lost.
        LinkedList<Long> batch;
        pendingLock.writeLock().lock();
        try {
            batch = new LinkedList<>(pending);
            pending.clear();
        } finally {
            pendingLock.writeLock().unlock();
        }
        for (Long queued : batch) {
            actuallyFree(queued);
        }
    }

    /**
     * Returns {@code address} to the calling thread's free list if that thread currently has
     * it allocated. Otherwise does nothing. Needs no locks because it only touches
     * thread-local state.
     */
    private void actuallyFree(Long address) {
        Integer blockSize = allocatedList.get().remove(address);
        if (blockSize != null) {
            freeList.get().computeIfAbsent(blockSize, k -> new LinkedList<>()).add(address);
        }
    }

    /**
     * Frees {@code oldAddress} and allocates a new block of {@code newSize} bytes. The
     * block's contents are not copied.
     */
    @Override
    public Long reAllocate(Long oldAddress, int newSize) {
        free(oldAddress);
        return allocate(newSize);
    }

    /**
     * Returns whether {@code address} is the start of a block the calling thread currently
     * has allocated. Blocks freed by other threads still count as allocated until the owner
     * drains its pending free requests.
     */
    @Override
    public boolean isAccessible(Long address) {
        return allocatedList.get().containsKey(address);
    }

    /**
     * Same as {@link #isAccessible(Long)}, but checks the whole range
     * {@code [address, address + size)} at once. The range must start at an allocated block
     * and fit inside it, so every address in it belongs to that one block.
     */
    @Override
    public boolean isAccessible(Long address, int size) {
        Integer blockSize = allocatedList.get().get(address);
        return blockSize != null && size <= blockSize;
    }
}