package com.zibaldone.cats
package ch_01

import cats.syntax.applicative.*
import cats.syntax.flatMap.*
import cats.syntax.functor.*
import cats.{FlatMap, Monad}

import scala.annotation.tailrec
import scala.collection.mutable
import scala.util.Try

// mental model: flatMap chains transformations, where each step can depend on the previous step's result
extension [F[_]: FlatMap, A](container: F[A])

  /** Exercise: return all combinations `(A, B)`.
    *
    * Binds each value of `container`, then maps each value of `otherContainer` into a pair with it.
    * This is the same operation as `product` from the Semigroupal type class, spelled out as the
    * explicit `flatMap` + `map` chain.
    */
  def combine[B](otherContainer: F[B]): F[(A, B)] =
    container.flatMap(left => otherContainer.map(right => (left, right)))

trait `monad`[F[_]] extends ch_03.`applicative`[F] with ch_03.`flatMap`[F]:

  /** Exercise: implement `map` using only `flatMap` and `pure`.
    *
    * Applies `f` to the bound value and re-wraps the result with `pure`.
    */
  final override def map[A, B](fa: F[A])(f: A => B): F[B] =
    flatMap(fa) { value =>
      val mapped = f(value)
      pure(mapped)
    }

  /** Tuples the values of `fa` and `fb`: binds `fa` first, then maps over `fb`. */
  final override def product[A, B](fa: F[A], fb: F[B]): F[(A, B)] =
    flatMap(fa) { first =>
      map(fb)(second => (first, second))
    }

// ex. service layer API
/** Connection parameters.
  *
  * @param host host name
  * @param port port, kept as the raw string (it is not parsed into a number)
  */
final case class Connection(
  host: String,
  port: String
)

/** Exercise: a service-layer API abstracted over an effect `F` that has a cats `Monad`.
  *
  * Implementations supply the two primitive steps; `response` sequences them.
  */
sealed trait HttpService[F[_]: Monad]:

  /** Builds a [[Connection]] from a configuration map. */
  def connection(config: Map[String, String]): F[Connection]

  /** Sends `payload` over `connection`, producing an `Int` inside `F`. */
  def request(connection: Connection, payload: String): F[Int]

  /** Runs `connection(config)` and then `request` with the resulting connection.
    *
    * How a failed first step short-circuits the second is determined by `F`'s `flatMap`.
    */
  def response(config: Map[String, String], payload: String): F[Int] =
    for
      conn   <- connection(config)
      result <- request(conn, payload)
    yield result

/** [[HttpService]] in `Option`: any failure is reported as `None`. */
object OptionalHttpService extends HttpService[Option]:

  /** `Some` connection only when both the `host` and `port` keys are present in `config`. */
  override def connection(config: Map[String, String]): Option[Connection] =
    for
      host <- config.get("host")
      port <- config.get("port")
    yield Connection(host, port)

  /** `Some(42)` when `payload` has at least 20 characters, otherwise `None`. */
  override def request(connection: Connection, payload: String): Option[Int] =
    if payload.length < 20 then None
    else Some(42)

/** [[HttpService]] in `Try`: failures are captured as exceptions inside `Failure`. */
object TryHttpService extends HttpService[Try]:

  /** Reads the `host` and `port` entries of `config`.
    *
    * For a map without a default, a missing key makes `Map.apply` throw
    * `NoSuchElementException`, which `Try` captures as a `Failure`.
    */
  override def connection(config: Map[String, String]): Try[Connection] =
    Try { Connection(config("host"), config("port")) }

  /** `Success(42)` when `payload` has at least 20 characters.
    *
    * Otherwise returns a `Failure` holding an `IllegalArgumentException` (thrown by `require`)
    * whose message reports the actual payload length.
    */
  override def request(connection: Connection, payload: String): Try[Int] =
    Try {
      require(payload.length >= 20, s"payload must be at least 20 characters long, got ${payload.length}")
      42
    }

// ex. monad for identity type
/** Identity "container": an `Id[A]` is just an `A`.
  *
  * The alias is opaque, so it is only interchangeable with `A` inside the scope that defines it
  * (this file's top level). The instance below relies on that.
  */
opaque type Id[A] = A

given Monad[Id] = new Monad[Id]:

  /** Wrapping is a no-op. */
  override def pure[A](x: A): Id[A] = x

  /** Binding an identity value is plain function application. */
  override def flatMap[A, B](fa: Id[A])(f: A => Id[B]): Id[B] = f(fa)

  /** Repeatedly applies `f`: `Left(next)` feeds `next` back in, `Right(result)` stops. */
  @tailrec override def tailRecM[A, B](a: A)(f: A => Id[Either[A, B]]): Id[B] =
    val step: Either[A, B] = f(a)
    step match
      case Right(result) => result
      case Left(next)    => tailRecM(next)(f)

// ex. monad for tree type
/** Binary tree that stores values only in its leaves.
  *
  * Equality and hashing are the structural ones synthesized for enum cases, so trees that are
  * equal also hash equally (required for use as keys in hash-based collections).
  */
enum BinaryTree[+T]:

  case Leaf(value: T)
  case Branch(left: BinaryTree[T], right: BinaryTree[T])

given Monad[BinaryTree] = new Monad[BinaryTree]:

  import ch_01.BinaryTree.{Branch, Leaf}

  /** A single leaf holding `x`. */
  override def pure[A](x: A): BinaryTree[A] = Leaf(x)

  /** Replaces every leaf with the tree `f` produces for its value, keeping the branch structure.
    *
    * Recursion depth grows with the depth of `fa`, so this is not stack-safe on very deep trees;
    * `tailRecM` is the stack-safe entry point.
    */
  override def flatMap[A, B](fa: BinaryTree[A])(f: A => BinaryTree[B]): BinaryTree[B] = fa match
    case Leaf(value)         => f(value)
    case Branch(left, right) => Branch(flatMap(left)(f), flatMap(right)(f))

  /** Stack-safe unfolding.
    *
    * Each `Leaf(Left(a))` is replaced by `f(a)` until only `Leaf(Right(b))` leaves remain, and
    * the branch structure is rebuilt around them. Subtrees are processed left to right, matching
    * `flatMap`.
    *
    * Branches are tracked with an explicit marker on the work stack rather than a "visited" set.
    * As a result, a subtree object that occurs several times (e.g. `Branch(s, s)`) is expanded
    * once per occurrence.
    */
  override def tailRecM[A, B](a: A)(f: A => BinaryTree[Either[A, B]]): BinaryTree[B] =

    // `todo` is the work stack: `Some(tree)` is a subtree still to expand; `None` is a marker
    // meaning "the two most recently finished trees are the children of one branch".
    // `done` is a stack of finished subtrees, most recent on top.
    @tailrec def loop(
      todo: List[Option[BinaryTree[Either[A, B]]]],
      done: List[BinaryTree[B]]
    ): BinaryTree[B] = todo match
      case Nil                               => done.head // every expansion nets one finished tree
      case Some(Leaf(Right(value))) :: rest  => loop(rest, Leaf(value) :: done)
      case Some(Leaf(Left(value))) :: rest   => loop(Some(f(value)) :: rest, done)
      case Some(Branch(left, right)) :: rest => loop(Some(left) :: Some(right) :: None :: rest, done)
      case None :: rest                      =>
        // the right child finished last, so it sits above the left child on `done`
        val right :: left :: older = done: @unchecked
        loop(rest, Branch(left, right) :: older)

    loop(Some(f(a)) :: Nil, Nil)