/*
  Filename : FineGrainedHeap.cvl
  Authors  : YYY and XXX
  Created  : 2021-03-02
  Modified : 2025-01-16

  CIVL model of FineGrainedHeap class from "The Art of Multiprocessor
  Programming" 2nd ed, by Herlihy, Luchangco, Shavit, and Spear,
  Sec. 15.4.2 "A concurrent heap", and companion code
  ch15/priority/src/priority/FineGrainedHeap.java.
  
  This is an unbounded-range priority queue implementation (priorities
  are plain ints) using fine-grained lock synchronization over a heap
  represented as an array of HeapNodes.  In this model the array has a
  fixed size of CAPACITY + 1 nodes.
  
  some Lab
  some division
  some institution
 */
#include "Lock.h"
#include "PQueue.h"
#include "tid.h"
#include "types.h"
#include <limits.h>
#include <stdbool.h>
#include <stdlib.h>
#include <string.h>

// Index of the heap root.  Children of node i live at 2*i and 2*i+1 and
// the parent of node i at i/2, so slot 0 is not part of the heap order.
#define ROOT 1
// Owner value used for nodes that no process currently owns.
#define NO_ONE $proc_null
// Value stored in PQueue.heap_length by PQueue_create, which allocates
// CAPACITY + 1 nodes (indices 0 .. CAPACITY).
#define CAPACITY 8

// State of a heap slot:
//   EMPTY     - set by HeapNode_create and by PQueue_removeMin on the root;
//   AVAILABLE - set by PQueue_add once the inserted node has settled;
//   BUSY      - set by HeapNode_init; the node is still being moved by
//               the process recorded in its owner field.
typedef enum _status {EMPTY, AVAILABLE, BUSY} Status;

// HeapNode class ...

// One slot of the heap array.  Note that the typedef name HeapNode is a
// POINTER to struct HeapNode.
typedef struct HeapNode {
  Status tag;   // EMPTY, AVAILABLE or BUSY
  int score;    // priority; smaller scores are moved toward the root
  T item;       // payload stored with the priority
  $proc owner;  // process that set the node BUSY, or NO_ONE
  Lock lock;    // per-node lock, created in HeapNode_create
} * HeapNode;

/*
 * Fills node n with a new entry and marks it as in transit.
 *
 *   n          - node to populate
 *   myItem     - payload to store
 *   myPriority - score of the payload
 *
 * The node becomes BUSY and is owned by the calling process ($self).
 * The node's lock is not touched here.
 */
static void HeapNode_init(HeapNode n, T myItem, int myPriority) {
  // Claim the slot for the caller ...
  n->tag = BUSY;
  n->owner = $self;
  // ... and store the entry itself.
  n->score = myPriority;
  n->item = myItem;
}

static HeapNode HeapNode_create() {
  HeapNode result = (HeapNode) malloc(sizeof(struct HeapNode));
  result->tag = EMPTY;
  result->lock = Lock_create();
  return result;
}

/*
 * Destroys the lock of node n and frees the node itself.
 *
 * The function returns nothing: it was previously declared to return a
 * HeapNode but had no return statement, so any use of its result would
 * have read an indeterminate value.  The caller in this file ignores
 * the result.
 */
static void HeapNode_destroy(HeapNode n) {
  Lock_destroy(n->lock);
  free(n);
}

// Acquires the per-node lock of n via Lock_acquire.
static void HeapNode_lock(HeapNode n) {
  Lock_acquire(n->lock);
}

// Releases the per-node lock of n via Lock_release.
static void HeapNode_unlock(HeapNode n) {
  Lock_release(n->lock);
}

/*
 * Reports whether the calling process owns node n.
 *
 * Returns true only when the node is BUSY and its owner field equals
 * $self; EMPTY and AVAILABLE nodes are never owned.
 */
static bool HeapNode_amOwner(HeapNode n) {
  if (n->tag != BUSY) {
    return false;
  }
  return n->owner == $self;
}

// FineGrainedHeap ...

// State of the priority queue.
struct PQueue {
  Lock heapLock;    // guards updates of next
  int next;         // index of the next free slot; starts at ROOT
  HeapNode* heap;   // array of heap_length + 1 node pointers
  int heap_length;  // set to CAPACITY by PQueue_create
};

/*
 * Allocates an empty priority queue.
 *
 * The queue gets its own heap lock, next is set to ROOT, heap_length to
 * CAPACITY, and the node array is filled with heap_length + 1 freshly
 * created EMPTY nodes (indices 0 .. heap_length).
 *
 * Returns the new queue; release it with PQueue_destroy.
 */
PQueue PQueue_create() {
  PQueue pq = malloc(sizeof(struct PQueue));
  pq->heapLock = Lock_create();
  pq->next = ROOT;
  pq->heap_length = CAPACITY;
  pq->heap = malloc(sizeof(HeapNode) * (pq->heap_length + 1));
  int slot = 0;
  while (slot < pq->heap_length + 1) {
    pq->heap[slot] = HeapNode_create();
    slot++;
  }
  return pq;
}

/*
 * Releases everything owned by queue q: the heap lock, each of the
 * heap_length + 1 nodes, the node array, and finally q itself.
 */
void PQueue_destroy(PQueue q) {
  Lock_destroy(q->heapLock);
  int slot = 0;
  while (slot < q->heap_length + 1) {
    HeapNode_destroy(q->heap[slot]);
    slot++;
  }
  free(q->heap);
  free(q);
}

// No-op in this implementation (presumably a hook required by PQueue.h).
void PQueue_initialize_context(void) {}

// No-op in this implementation (presumably a hook required by PQueue.h).
void PQueue_finalize_context(void) {}

// No-op in this implementation; pq is not used.
void PQueue_initialize(PQueue pq) {}

// No-op in this implementation; pq is not used.
void PQueue_finalize(PQueue pq) {}

// No-op in this implementation; tid is not used.
void PQueue_terminate(int tid) {}

// Always returns false in this implementation.
bool PQueue_stuck(void) {
  return false;
}

static void swap(PQueue q, int i, int j);

/*
 * Inserts item with the given priority into queue q.
 *
 *   q        - queue to insert into
 *   item     - payload
 *   priority - score of the payload; smaller scores move toward the root
 *
 * The new entry is placed in slot q->next as a BUSY node owned by the
 * caller and is then moved up one level at a time.  Each step locks the
 * parent first and then the child, and releases both before the next
 * step.  No check is made that q->next stays within the allocated array.
 */
void PQueue_add(PQueue q, T item, int priority) {
  // Reserve the next free slot; the slot's lock is taken before the heap
  // lock is released.
  Lock_acquire(q->heapLock);
  int child = q->next++;
  HeapNode_lock(q->heap[child]);
  Lock_release(q->heapLock);
  HeapNode_init(q->heap[child], item, priority);
  HeapNode_unlock(q->heap[child]);
  while (child > ROOT) {
    int parent = child / 2;
    HeapNode_lock(q->heap[parent]);
    HeapNode_lock(q->heap[child]);
    // child may be reassigned below; remember which node is locked.
    int oldChild = child;
    // try
    if (q->heap[parent]->tag == AVAILABLE &&
        HeapNode_amOwner(q->heap[child])) {
      if (q->heap[child]->score < q->heap[parent]->score) {
        // Our entry is smaller than its parent: move it up one level.
        swap(q, child, parent);
        child = parent;
      } else {
        // Our entry is in place: publish it and stop.
        q->heap[child]->tag = AVAILABLE;
        q->heap[child]->owner = NO_ONE;
        // finally
        HeapNode_unlock(q->heap[oldChild]);
        HeapNode_unlock(q->heap[parent]);
        return;
      }
    } else if (!HeapNode_amOwner(q->heap[child])) {
      // The node at child is no longer ours (some other process changed
      // it); continue the search one level up.
      child = parent;
    }
    // Otherwise we still own child but parent is not AVAILABLE: retry
    // the same pair on the next iteration.
    // finally
    HeapNode_unlock(q->heap[oldChild]);
    HeapNode_unlock(q->heap[parent]);
  }
  if (child == ROOT) {
    // Reached the root: publish the entry if it is still ours.
    HeapNode_lock(q->heap[ROOT]);
    if (HeapNode_amOwner(q->heap[ROOT])) {
      q->heap[ROOT]->tag = AVAILABLE;
      q->heap[ROOT]->owner = NO_ONE;
    }
    HeapNode_unlock(q->heap[ROOT]);
  }
}

/*
 * Removes the entry stored at the root of queue q and returns its item.
 *
 *   q - queue to remove from
 *
 * Returns the root's item, or -1 if the root slot is EMPTY.
 *
 * The last slot (bottom) is moved into the root and then sifted down,
 * locking the two children of the current node before deciding which
 * one to follow.
 *
 * NOTE(review): on the -1 path q->next has already been decremented and
 * is not restored; confirm whether callers may invoke this on an empty
 * queue.
 * NOTE(review): when bottom == ROOT the same node lock is acquired
 * twice and released twice; this assumes Lock is reentrant - confirm
 * against Lock.h.
 */
T PQueue_removeMin(PQueue q) {
  Lock_acquire(q->heapLock);
  int bottom = --q->next;
  // Lock the root BEFORE the bottom node.  PQueue_add locks parent then
  // child, and the sift-down loop below locks children while holding the
  // parent, so node locks are always taken top-down.  Taking bottom first
  // could deadlock against a thread that holds the root and is waiting
  // for that same bottom node.
  HeapNode_lock(q->heap[ROOT]);
  HeapNode_lock(q->heap[bottom]);
  Lock_release(q->heapLock);
  if (q->heap[ROOT]->tag == EMPTY) {
    // Nothing stored at the root.
    HeapNode_unlock(q->heap[ROOT]);
    HeapNode_unlock(q->heap[bottom]);
    return -1;
  }
  T item = q->heap[ROOT]->item;
  q->heap[ROOT]->tag = EMPTY;
  // After the swap the bottom slot holds the (now EMPTY) old root and the
  // root holds the former bottom entry, including its tag and owner.
  swap(q, bottom, ROOT);
  q->heap[bottom]->owner = NO_ONE;
  HeapNode_unlock(q->heap[bottom]);
  if (q->heap[ROOT]->tag == EMPTY) {
    // Root is still EMPTY after the swap: nothing to sift down.
    HeapNode_unlock(q->heap[ROOT]);
    return item;
  }
  int child = 0;
  int parent = ROOT;
  // Invariant: heap[parent] is locked by this process.  The bound keeps
  // both child indices inside the allocated array.
  while (parent < q->heap_length / 2) {
    int left = parent * 2;
    int right = (parent * 2) + 1;
    HeapNode_lock(q->heap[left]);
    HeapNode_lock(q->heap[right]);
    if (q->heap[left]->tag == EMPTY) {
      // No left child: parent is treated as a leaf.
      HeapNode_unlock(q->heap[right]);
      HeapNode_unlock(q->heap[left]);
      break;
    } else if (q->heap[right]->tag == EMPTY ||
               q->heap[left]->score < q->heap[right]->score) {
      // Follow the left child; the right one is released.
      HeapNode_unlock(q->heap[right]);
      child = left;
    } else {
      // Follow the right child; the left one is released.
      HeapNode_unlock(q->heap[left]);
      child = right;
    }
    if (q->heap[child]->score < q->heap[parent]->score) {
      // Child is smaller: swap and continue from the child, which stays
      // locked and becomes the new parent.
      swap(q, parent, child);
      HeapNode_unlock(q->heap[parent]);
      parent = child;
    } else {
      // Heap order holds here; stop.
      HeapNode_unlock(q->heap[child]);
      break;
    }
  }
  HeapNode_unlock(q->heap[parent]);
  return item;
}

/*
 * Exchanges the contents (tag, score, item, owner) of slots i and j of
 * q's heap array.  The node objects and their locks stay where they
 * are; only the four data fields are exchanged.  This function does no
 * locking itself.
 */
static void swap(PQueue q, int i, int j) {
  HeapNode a = q->heap[i];
  HeapNode b = q->heap[j];

  Status savedTag = a->tag;
  a->tag = b->tag;
  b->tag = savedTag;

  int savedScore = a->score;
  a->score = b->score;
  b->score = savedScore;

  T savedItem = a->item;
  a->item = b->item;
  b->item = savedItem;

  $proc savedOwner = a->owner;
  a->owner = b->owner;
  b->owner = savedOwner;
}

/*
 * Prints the occupied slots of q (indices 1 .. q->next - 1) in array
 * order as "{ (item,score) ... }".  No locks are taken.
 */
void PQueue_print(PQueue q) {
  $print("{ ");
  int idx = 1;
  while (idx < q->next) {
    $print("(", q->heap[idx]->item, ",", q->heap[idx]->score, ") ");
    idx++;
  }
  $print("}");
}

#ifdef _FINE_GRAINED_HEAP_MAIN

// Prepares one test phase: initializes the tid module for nproc
// processes, then runs the (no-op) PQueue context and queue hooks.
static void startup(PQueue pq, int nproc) {
  tid_init(nproc);
  PQueue_initialize_context();
  PQueue_initialize(pq);
}

// Ends one test phase: runs the finalize hooks in the reverse order of
// startup and then finalizes the tid module.
static void shutdown(PQueue pq) {
  PQueue_finalize(pq);
  PQueue_finalize_context();
  tid_finalize();
}

/*
 * Small driver, compiled only when _FINE_GRAINED_HEAP_MAIN is defined.
 *
 * Phase 1: two processes concurrently add item i with priority i.
 * Phase 2: two processes concurrently call PQueue_removeMin.
 * Finally the removed values are summed and checked against
 * 0 + 1 + ... + (nthreads - 1).
 */
int main() {
  PQueue pq = PQueue_create();
  int nthreads = 2;

  // Phase 1: concurrent adds.
  startup(pq, nthreads);
  $parfor(int t : 0 .. nthreads - 1) {
    tid_register(t);
    PQueue_add(pq, t, t);
    PQueue_terminate(t);
    tid_unregister();
  }
  $print("PQueue after adds: ");
  PQueue_print(pq);
  $print("\n");
  shutdown(pq);

  // Phase 2: concurrent removes, one result slot per process.
  int removed[nthreads];
  startup(pq, nthreads);
  $parfor(int t : 0 .. nthreads - 1) {
    tid_register(t);
    removed[t] = PQueue_removeMin(pq);
    PQueue_terminate(t);
    tid_unregister();
  }
  $print("PQueue after removes: ");
  PQueue_print(pq);
  $print("\n");
  shutdown(pq);

  // A checksum is used here; a forall/exists assertion would also work.
  int total = 0;
  for (int k = 0; k < nthreads; k++) {
    total += removed[k];
  }
  $assert(total == (nthreads * (nthreads - 1)) / 2);
  PQueue_destroy(pq);
}
#endif