package ch_03
import utils.*
import cats.effect.*
import cats.syntax.parallel.*
import scala.collection.immutable.Queue
import scala.concurrent.duration.*
import scala.util.Random
/** A mutual-exclusion lock whose operations are `IO` effects.
  *
  * Intended usage: run `acquire` before entering a critical section and `release`
  * after leaving it.
  */
abstract class MutexIO:

  /** Effect that completes once the caller holds the lock. */
  def acquire: IO[Unit]

  /** Effect that gives up the lock held by the caller (implementation-defined handoff). */
  def release: IO[Unit]

end MutexIO
/** Factory for [[MutexIO]]: a FIFO mutex built from a `Ref` holding the lock state and one
  * `Deferred` per blocked fiber.
  */
object MutexIO:

  /** One-shot signal a waiter blocks on until `release` hands it the lock. */
  private[MutexIO] type Signal = Deferred[IO, Unit]

  /** Lock state: whether the mutex is held, plus the blocked waiters in arrival order. */
  private[MutexIO] final case class State(isLocked: Boolean, queue: Queue[Signal])

  /** Initial state: not held and nobody waiting. */
  private[MutexIO] val unlocked: State = State(false, Queue.empty)

  /** Allocates a fresh, unlocked mutex. */
  def apply(): IO[MutexIO] = IO.ref(unlocked).map { state =>
    new MutexIO:

      /** Takes the lock immediately if it is free; otherwise enqueues a signal and waits on it.
        *
        * Only the wait itself is cancelable (through `poll`). Registering in the queue is
        * masked, so the state is never left half-updated.
        */
      override def acquire: IO[Unit] = IO.uncancelable { poll =>
        IO.deferred[Unit].flatMap { signal =>
          // Runs only if the wait below is cancelled. There are two cases:
          //  - `signal` is still queued: drop it, so `release` never hands the lock to a
          //    fiber that has gone away;
          //  - `signal` is no longer queued: `release` already dequeued it and transferred
          //    the lock to us, so we must pass the lock on by releasing.
          val cleanup: IO[Unit] = state.modify { case State(isLocked, queue) =>
            val remaining     = queue.filterNot(_ eq signal)
            val wasHandedLock = remaining.size == queue.size
            // Store `remaining`, not `queue`: keeping the cancelled signal queued would make a
            // later `release` complete it while leaving the mutex locked with no owner.
            State(isLocked, remaining) -> (if wasHandedLock then release else IO.unit)
          }.flatten

          state.modify {
            case State(false, _) => State(true, Queue.empty) -> IO.unit
            case State(true, queue) =>
              State(true, queue.enqueue(signal)) -> poll(signal.get).onCancel(cleanup)
          }.flatten
        }
      }

      /** Hands the lock to the oldest waiter if there is one, in which case the mutex stays
        * held, now by that waiter. Otherwise it unlocks. Releasing an unlocked mutex is a no-op.
        */
      override def release: IO[Unit] = state.modify {
        case State(false, _)                      => unlocked -> IO.unit
        case State(true, queue) if queue.isEmpty => unlocked -> IO.unit
        case State(true, queue) =>
          val (next, rest) = queue.dequeue
          State(true, rest) -> next.complete(()).void
      }.flatten
  }
/** Simulated slow work: sleeps for five seconds, then yields `Random.nextInt(100)`. */
def criticalTask: IO[Int] =
  IO.sleep(5.seconds).flatMap(_ => IO(Random.nextInt(100)))
/** Runs [[criticalTask]] while holding `mutex`, logging each step tagged with `id`.
  *
  * The mutex is released even if the critical section fails or the fiber is cancelled.
  * Without that guarantee, one failing task would leave every other task blocked forever.
  *
  * @param id    label used in the progress messages
  * @param mutex lock guarding the critical section
  * @return the value produced by [[criticalTask]]
  */
def lockingTask(id: Int, mutex: MutexIO): IO[Int] =
  // NOTE(review): `inspect` comes from `utils.*`; presumably it logs the value and passes it
  // through. Confirm there.
  val critical: IO[Int] =
    for
      _   <- IO.pure(s"[task-$id] - critical section").inspect
      res <- criticalTask
      _   <- IO.pure(s"[task-$id] - releasing mutex").inspect
    yield res

  for
    _ <- IO.pure(s"[task-$id] - acquiring lock").inspect
    // Cancellation is masked between "lock acquired" and "finalizer installed". Only the wait
    // for the lock and the critical section itself remain cancelable.
    res <- IO.uncancelable(poll => poll(mutex.acquire) >> poll(critical).guarantee(mutex.release))
    _   <- IO.pure(s"[task-$id] - lock removed").inspect
  yield res
/** Creates one shared mutex, runs tasks 1 to 10 in parallel against it, and sums their results. */
def lockingTasks: IO[Int] =
  MutexIO().flatMap { mutex =>
    (1 to 10).toList
      .parTraverse(taskId => lockingTask(taskId, mutex))
      .map(_.sum)
  }