/* Filename : OptimisticList.cvl
   Author   : ZZZ, WWW, YYY, XXX
   Created  : 2018-05-01
   Modified : 2025-01-17

   CIVL model of class OptimisticList from "The Art of Multiprocessor
   Programming" 2nd ed, by Herlihy, Luchangco, Shavit, and Spear,
   Sec. 9.6, "Optimistic synchronization", and companion code
   ch9/src/lists/OptimisticList.java.
  
   This is a concurrent list using a lock per node and an optimistic
   synchronization strategy.  We don't deal with destruction of nodes
   since the original code relies on the garbage collector and to do
   it manually is complicated.
  
   Note it is assumed that the hash function is injective.
*/
#include "Lock.h"
#include "Set.h"
#include "hash.cvh"
#include "tid.h"
#include "types.h"
#include <limits.h>
#include <stdbool.h>
#include <stdlib.h>

// Nodes...

/* One list cell.  Note that the typedef name `Node` denotes a POINTER to
   struct Node, so values of type Node are node references. */
typedef struct Node {
  T item;             // element stored in this cell
  int key;            // ordering key; node_create sets it to hash_code(item)
  struct Node * next; // successor in the list; node_create initializes it to NULL
  Lock lock;          // per-node lock, acquired by add/remove/contains before validating
} * Node;

/* Allocates a new, unlinked node holding `item`.
   The node gets its own lock, its key is hash_code(item), and its
   successor is NULL.  The caller is responsible for linking the node
   into a list and for eventually passing it to node_destroy. */
static Node node_create(T item){
  Node fresh = malloc(sizeof(struct Node));
  fresh->item = item;
  fresh->next = NULL;
  fresh->lock = Lock_create();
  fresh->key = hash_code(item);
  return fresh;
}

/* Destroys a single node: disposes of its lock and frees its storage.
   Does not follow node->next.  `node` is dereferenced, so it must not
   be NULL. */
static void node_destroy(Node node) {
  Lock_destroy(node->lock);
  free(node);
}

/* Destroys `node` and every node reachable from it through `next`.
   The recursion reaches the end of the chain first, so nodes are
   destroyed back to front.  A NULL argument is accepted and ignored. */
static void node_destroy_rec(Node node) {
  if (node != NULL) {
    node_destroy_rec(node->next);
    node_destroy(node);
  }
}

// Sets...

/* The set is represented by a singly linked list of nodes. */
struct Set {
  Node head; // first node of the list; Set_create makes it a sentinel built from INT_MIN
};

Set Set_create() {
  Set this = malloc(sizeof(struct Set));
  this->head = node_create(INT_MIN);
  this->head->next = node_create(INT_MAX); 
  return this;
}

/* Destroys every node still reachable from list->head (sentinels
   included) and then frees the set itself.  Nodes that Set_remove has
   unlinked are no longer reachable from the head and are therefore not
   freed here. */
void Set_destroy(Set list) {
  node_destroy_rec(list->head);
  free(list);
}

/* Context-level initialization; this implementation has nothing to set
   up, so the body is empty.
   NOTE(review): presumably required by the Set.h interface -- confirm. */
void Set_initialize_context(void) {}

/* Context-level finalization; no-op in this implementation. */
void Set_finalize_context(void) {}

/* Per-set initialization; no-op in this implementation (`set` is unused). */
void Set_initialize(Set set) {}

/* Per-set finalization; no-op in this implementation (`set` is unused). */
void Set_finalize(Set set) {}

/* Per-thread termination notification; no-op in this implementation
   (`tid` is unused). */
void Set_terminate(int tid) {}

/* Always returns false for this implementation.
   NOTE(review): presumably reports whether the data structure can make
   no further progress -- confirm against Set.h. */
bool Set_stuck(void) {
  return false;
}

/* Prints the items of the set between braces, each followed by a
   space.  The traversal starts after the head sentinel and stops at the
   node whose `next` is NULL (the tail sentinel), so neither sentinel is
   printed.  No locks are taken. */
void Set_print(Set this) {
  Node node = this->head->next;
  $print("{ ");
  for (; node->next != NULL; node = node->next) {
    $print(node->item, " ");
  }
  $print("}");
}

/* Checks that `pred` is still reachable from the head of the list and
   that `curr` is still its immediate successor.  Walks from the head
   while keys do not exceed pred->key; returns false if the walk passes
   that key without meeting `pred`.  The callers in this file invoke it
   while holding the locks of both `pred` and `curr`. */
static bool validate(Set this, Node pred, Node curr) {
  for (Node walker = this->head; walker->key <= pred->key; walker = walker->next) {
    if (walker == pred) {
      return pred->next == curr;
    }
  }
  return false;
}

/* Adds `item` to the set.
   Returns true if the item was inserted, false if a node with the same
   key was already present.

   Each attempt searches without locks for the first node whose key is
   not less than the item's key, locks that node and its predecessor
   (predecessor first), and validates the pair.  If validation fails the
   locks are released and the search is retried from the head. */
bool Set_add(Set this, T item) {
  int key = hash_code(item);
  for (;;) {
    // Unlocked search for the insertion window (before, after).
    Node before = this->head;
    Node after = before->next;
    while (after->key < key) {
      before = after;
      after = after->next;
    }
    Lock_acquire(before->lock);
    Lock_acquire(after->lock);
    bool valid = validate(this, before, after);
    bool inserted = false;
    if (valid && after->key != key) {
      // Window is intact and the key is absent: splice in a new node.
      Node node = node_create(item);
      node->next = after;
      before->next = node;
      inserted = true;
    }
    // Single release point: successor first, then predecessor.
    Lock_release(after->lock);
    Lock_release(before->lock);
    if (valid) {
      return inserted;
    }
    // Validation failed: retry from the head.
  }
}

/* Removes `item` from the set.
   Returns true if a node with the item's key was found and unlinked,
   false if no such node was present.

   Each attempt searches without locks, locks the predecessor and then
   the candidate, and validates the pair; a failed validation releases
   both locks and restarts the search.  The unlinked node is NOT
   destroyed, because other threads may still hold references to it. */
bool Set_remove(Set this, T item) {
  int key = hash_code(item);
  for (;;) {
    Node before = this->head;
    Node victim = before->next;
    while (victim->key < key) {
      before = victim;
      victim = victim->next;
    }
    Lock_acquire(before->lock);
    Lock_acquire(victim->lock);
    if (!validate(this, before, victim)) {
      // The window changed under us: drop the locks and try again.
      Lock_release(victim->lock);
      Lock_release(before->lock);
      continue;
    }
    bool found = (victim->key == key);
    if (found) {
      // Unlink the node; it stays allocated (see header comment).
      before->next = victim->next;
    }
    Lock_release(victim->lock);
    Lock_release(before->lock);
    return found;
  }
}

/* Tests whether `item` is in the set.
   Returns true iff, in a validated window, the first node whose key is
   not less than the item's key has exactly that key.

   Like add and remove, the search itself is unlocked; the candidate and
   its predecessor are then locked (predecessor first) and validated,
   and the whole attempt is repeated if validation fails.  Here the
   locks are released predecessor first, then candidate. */
bool Set_contains(Set this, T item) {
  int key = hash_code(item);
  for (;;) {
    Node before = this->head;
    Node candidate;
    for (candidate = before->next; candidate->key < key; candidate = candidate->next) {
      before = candidate;
    }
    Lock_acquire(before->lock);
    Lock_acquire(candidate->lock);
    bool valid = validate(this, before, candidate);
    // The key is only inspected when the window validated.
    bool found = valid && (candidate->key == key);
    Lock_release(before->lock);
    Lock_release(candidate->lock);
    if (valid) {
      return found;
    }
  }
}

#ifdef _OPTIMISTIC_LIST_MAIN

/* Prepares one concurrent test phase: initializes the tid module for
   `nproc` threads, then runs the set's context-level and per-set
   initialization functions, in that order. */
static void startup(Set s, int nproc) {
  tid_init(nproc);
  Set_initialize_context();
  Set_initialize(s);
}

/* Ends one concurrent test phase: undoes startup in reverse order by
   finalizing the set, then the set context, then the tid module. */
static void shutdown(Set s) {
  Set_finalize(s);
  Set_finalize_context();
  tid_finalize();
}

/* Test 1: N threads concurrently add the values 0..N-1 (one each), the
   resulting set is printed, and then N threads concurrently check that
   each value is reported as present. */
static void test1(Set list, int N) {
  // Phase 1: concurrent insertions.
  startup(list, N);
  $parfor(int t: 0 .. N-1) {
    tid_register(t);
    Set_add(list, t);
    Set_terminate(t);
    tid_unregister();
  }
  Set_print(list);
  $print("\n");
  shutdown(list);

  // Phase 2: concurrent membership queries on the populated set.
  startup(list, N);
  $parfor(int t: 0 .. N-1) {
    tid_register(t);
    bool present = Set_contains(list, t);
    $assert(present);
    Set_terminate(t);
    tid_unregister();
  }
  shutdown(list);
}

/* Test 2: N threads run concurrently; thread i adds i, checks that it is
   present, removes it, and checks that it is absent.  Each thread works
   on its own value, so every operation's result is determined.

   Each thread calls Set_terminate(i) before unregistering, following
   the same per-thread protocol as test1. */
static void test2(Set list, int N) {
  startup(list, N);
  $parfor (int i: 0 .. N-1) {
    tid_register(i);
    bool result = Set_add(list, i);
    $assert(result);
    result = Set_contains(list, i);
    $assert(result);
    result = Set_remove(list, i);
    $assert(result);
    result = Set_contains(list, i);
    $assert(!result);
    Set_terminate(i);
    tid_unregister();
  }
  shutdown(list);
}

/* Driver: creates an empty set, nondeterministically selects test 1 or
   test 2 (so the verifier explores both), runs it, and destroys the
   set.  Test 1 runs with 3 threads, test 2 with 2. */
void main() {
  Set set = Set_create();
  int which = $choose_int(2) + 1;
  $print("\n*** Starting test ", which, " ***\n");
  switch (which) {
  case 1:
    test1(set, 3);
    break;
  default:
    test2(set, 2);
    break;
  }
  Set_destroy(set);
  $print("Test ", which, " complete.\n");
}
#endif