{-# LANGUAGE ApplicativeDo #-}
{-# LANGUAGE BlockArguments #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE ViewPatterns #-}

module U.Codebase.Sqlite.Sync22 where

import Control.Monad (when)
import Control.Monad.Except (ExceptT, MonadError (throwError))
import qualified Control.Monad.Except as Except
import Control.Monad.Extra (ifM)
import Control.Monad.RWS (MonadIO, MonadReader, lift)
import Control.Monad.Reader (ReaderT)
import qualified Control.Monad.Reader as Reader
import Control.Monad.Validate (ValidateT, runValidateT)
import qualified Control.Monad.Validate as Validate
import Data.Bifunctor (bimap)
import Data.Bitraversable (bitraverse)
import Data.ByteString (ByteString)
import Data.Bytes.Get (getWord8, runGetS)
import Data.Bytes.Put (putWord8, runPutS)
import Data.Foldable (for_, toList, traverse_)
import Data.Functor ((<&>))
import Data.List.Extra (nubOrd)
import Data.Maybe (catMaybes, fromMaybe)
import Data.Set (Set)
import qualified Data.Set as Set
import Debug.Trace (traceM, trace)
import qualified U.Codebase.Reference as Reference
import qualified U.Codebase.Sqlite.Branch.Format as BL
import U.Codebase.Sqlite.Connection (Connection)
import U.Codebase.Sqlite.DbId
import qualified U.Codebase.Sqlite.LocalIds as L
import qualified U.Codebase.Sqlite.ObjectType as OT
import qualified U.Codebase.Sqlite.Term.Format as TL
import qualified U.Codebase.Sqlite.Patch.Format as PL
import qualified U.Codebase.Sqlite.Queries as Q
import qualified U.Codebase.Sqlite.Reference as Sqlite
import qualified U.Codebase.Sqlite.Reference as Sqlite.Reference
import qualified U.Codebase.Sqlite.Referent as Sqlite.Referent
import qualified U.Codebase.Sqlite.Serialization as S
import U.Codebase.Sync (Sync (Sync), TrySyncResult)
import qualified U.Codebase.Sync as Sync
import qualified U.Codebase.WatchKind as WK
import U.Util.Cache (Cache)
import qualified U.Util.Cache as Cache

-- | One unit of work for the sync: something in the source database that
-- should end up in the destination database. All ids are source-database ids.
data Entity
  = O ObjectId
  -- ^ An object, identified by its 'ObjectId' in the source database.
  | C CausalHashId
  -- ^ A causal, identified by its 'CausalHashId' in the source database.
  | W WK.WatchKind Sqlite.Reference.IdH
  -- ^ A watch result of the given kind for the given (hash-based) reference
  -- in the source database.
  deriving (Eq, Ord, Show)

-- | Names the two databases taking part in a sync.
-- NOTE(review): not referenced by any code visible in this module; presumably
-- kept for callers or future use -- confirm before removing.
data DbTag = SrcDb | DestDb

-- | Which kind of serialized blob failed to decode. Carried as the first
-- field of the 'DecodeError' constructor of 'Error'.
data DecodeError
  = ErrTermComponent
  -- ^ A term component object could not be split into its elements.
  | ErrDeclComponent
  -- ^ A decl component object could not be split into its elements.
  | ErrBranchFormat
  -- ^ A namespace object could not be decoded as a branch format.
  | ErrPatchFormat
  -- ^ A patch object could not be decoded as a patch format.
  | ErrWatchResult
  -- ^ A watch result blob could not be decoded.
  deriving (Show)

-- | The decoder's own failure message (the 'Left' returned by 'runGetS'),
-- as carried by the 'DecodeError' constructor of 'Error'.
type ErrString = String

-- | Everything that can go wrong while syncing.
data Error
  = DbIntegrity Q.Integrity
  -- ^ A database action (run via 'runDB') failed with an integrity error.
  | DecodeError DecodeError ByteString ErrString
  -- ^ A blob failed to decode: which kind of blob it was, the offending
  -- bytes, and the decoder's message.
  | -- | The hashes of a single object in the source codebase correspond to
    -- more than one object in the destination codebase. Fields, in order:
    -- the source object id, its hash ids in the source, the ids of those
    -- same hashes in the destination, and the distinct destination object
    -- ids that were found for them.
    HashObjectCorrespondence ObjectId [HashId] [HashId] [ObjectId]
  | SourceDbNotExist
  -- ^ NOTE(review): never thrown by the code visible in this module;
  -- presumably reserved for callers that open the source database.
  deriving (Show)

-- | The environment a sync runs in: 'srcDB' is the connection that entities
-- are read from (see 'runSrc') and 'destDB' is the connection they are
-- written to (see 'runDest').
data Env = Env
  { srcDB :: Connection,
    destDB :: Connection,
    -- | Size passed to each id-translation cache created by 'sync22'.
    -- There are four caches of this size (text, hash, object and causal ids).
    idCacheSize :: Word
  }

-- | Compile-time switch for diagnostic output: when 'True', the
-- @when debug $ traceM ...@ calls (and the 'trace' guard in @syncWatch@)
-- in this module print source-to-destination id mappings as they are made.
debug :: Bool
debug = False

-- data Mappings
-- | Build the 'Sync' used to copy entities from the source database to the
-- destination database.
--
-- Reads 'idCacheSize' from the environment and allocates four fresh caches of
-- that size with 'Cache.semispaceCache' (for text, hash, object and causal
-- ids, in that order); the resulting 'Sync' steps with 'trySync' over those
-- caches.
sync22 ::
  ( MonadIO m,
    MonadError Error m,
    MonadReader Env m
  ) =>
  m (Sync m Entity)
sync22 =
  Reader.reader idCacheSize >>= \cacheSize ->
    fmap Sync $
      trySync
        <$> Cache.semispaceCache cacheSize
        <*> Cache.semispaceCache cacheSize
        <*> Cache.semispaceCache cacheSize
        <*> Cache.semispaceCache cacheSize

-- | Try to copy a single 'Entity' from the source database to the
-- destination database.
--
-- The result is one of:
--
-- * 'Sync.PreviouslyDone' -- the entity is already present in the destination;
-- * 'Sync.Missing' -- the entity refers to other entities that are not yet in
--   the destination; they are returned so the caller can sync them first;
-- * 'Sync.Done' -- the entity was copied by this call (for watches, also when
--   the source holds no result of that kind for the reference).
--
-- Missing dependencies are collected in a @'Set' 'Entity'@ through 'ValidateT'.
-- This module enables @ApplicativeDo@, so be careful when reordering or
-- merging statements inside the 'ValidateT' blocks below.
--
-- Throws (via 'MonadError'):
--
-- * 'DbIntegrity' when a query run through 'runSrc' / 'runDest' fails;
-- * 'DecodeError' when an object or watch blob cannot be decoded;
-- * 'HashObjectCorrespondence' when the hashes of one source object map to
--   several destination objects.
trySync ::
  forall m.
  (MonadIO m, MonadError Error m, MonadReader Env m) =>
  -- | source text id to destination text id
  Cache TextId TextId ->
  -- | source hash id to destination hash id
  Cache HashId HashId ->
  -- | source object id to destination object id
  Cache ObjectId ObjectId ->
  -- | source causal hash id to destination causal hash id
  Cache CausalHashId CausalHashId ->
  Entity ->
  m (TrySyncResult Entity)
trySync tCache hCache oCache cCache = \case
  -- Causals: look up the causal's value (branch) hash; require the namespace
  -- object for that hash (if the source has one) and every parent causal to
  -- be in the destination already; then save the causal and its parent links.
  C chId ->
    isSyncedCausal chId >>= \case
      Just {} -> pure Sync.PreviouslyDone
      Nothing -> do
        result <- runValidateT @(Set Entity) @m @() do
          bhId <- runSrc $ Q.loadCausalValueHashId chId
          mayBoId <- runSrc . Q.maybeObjectIdForAnyHashId $ unBranchHashId bhId
          traverse_ syncLocalObjectId mayBoId

          parents' :: [CausalHashId] <- findParents' chId
          bhId' <- lift $ syncBranchHashId bhId
          chId' <- lift $ syncCausalHashId chId
          runDest do
            Q.saveCausal chId' bhId'
            Q.saveCausalParents chId' parents'

        case result of
          Left deps -> pure . Sync.Missing $ toList deps
          Right () -> pure Sync.Done

  -- Objects: if the destination already has it we are done; otherwise decode
  -- it far enough to find the ids it embeds, translate those ids to
  -- destination ids (reporting untranslatable ones as missing dependencies),
  -- re-encode it and save it.
  O oId ->
    isSyncedObject oId >>= \case
      Just {} -> pure Sync.PreviouslyDone
      Nothing -> do
        (hId, objType, bytes) <- runSrc $ Q.loadObjectWithHashIdAndTypeById oId
        hId' <- syncHashLiteral hId
        result <- runValidateT @(Set Entity) @m @ObjectId case objType of
          OT.TermComponent -> do
            -- split up the format tag, the localIds (parsed), and the
            -- per-element blobs.
            -- note: this whole business with `fmt` is pretty weird, and will
            -- need to be revisited when there are more formats.
            (fmt, unzip -> (localIds, termBytes)) <-
              lift $ decodeComponent ErrTermComponent bytes
            -- iterate through the local ids looking for missing deps;
            -- then either enqueue the missing deps, or proceed to move the object
            when debug $ traceM $ "LocalIds for Source " ++ show oId ++ ": " ++ show localIds
            localIds' <- traverse syncLocalIds localIds
            when debug $ traceM $ "LocalIds for Dest: " ++ show localIds'
            -- reassemble and save the reindexed term component
            let bytes' =
                  runPutS $
                    putWord8 fmt >> S.recomposeComponent (zip localIds' termBytes)
            oId' <- runDest $ Q.saveObject hId' objType bytes'
            lift do
              -- per-element data: test watch results and the dependents index
              for_ [0 .. length localIds - 1] \(fromIntegral -> idx) -> do
                for_ [WK.TestWatch] \wk ->
                  syncWatch wk (Reference.Id hId idx)
                copyDependentsIndex (Reference.Id oId idx) (Reference.Id oId' idx)
              -- per-component data: only depends on the object id, so it is
              -- copied once rather than once per element
              copyTypeIndexes oId oId'
            pure oId'
          OT.DeclComponent -> do
            -- split up the format tag, the localIds (parsed), and the decl blobs
            (fmt, unzip -> (localIds, declBytes)) <-
              lift $ decodeComponent ErrDeclComponent bytes
            -- iterate through the local ids looking for missing deps;
            -- then either enqueue the missing deps, or proceed to move the object
            localIds' <- traverse syncLocalIds localIds
            -- reassemble and save the reindexed decl component
            let bytes' =
                  runPutS $
                    putWord8 fmt
                      >> S.recomposeComponent (zip localIds' declBytes)
            oId' <- runDest $ Q.saveObject hId' objType bytes'
            lift do
              -- per-element data: the dependents index
              for_ [0 .. length localIds - 1] \(fromIntegral -> idx) ->
                copyDependentsIndex (Reference.Id oId idx) (Reference.Id oId' idx)
              -- per-component data: copied once, see the term case
              copyTypeIndexes oId oId'
            pure oId'
          OT.Namespace -> case flip runGetS bytes S.decomposeBranchFormat of
            Right (BL.SyncFull ids body) -> do
              ids' <- syncBranchLocalIds ids
              let bytes' = runPutS $ S.recomposeBranchFormat (BL.SyncFull ids' body)
              runDest $ Q.saveObject hId' objType bytes'
            Right (BL.SyncDiff boId ids body) -> do
              -- a diff also depends on the branch object it is a diff against
              boId' <- syncBranchObjectId boId
              ids' <- syncBranchLocalIds ids
              let bytes' = runPutS $ S.recomposeBranchFormat (BL.SyncDiff boId' ids' body)
              runDest $ Q.saveObject hId' objType bytes'
            Left s -> throwError $ DecodeError ErrBranchFormat bytes s
          OT.Patch -> case flip runGetS bytes S.decomposePatchFormat of
            Right (PL.SyncFull ids body) -> do
              ids' <- syncPatchLocalIds ids
              let bytes' = runPutS $ S.recomposePatchFormat (PL.SyncFull ids' body)
              runDest $ Q.saveObject hId' objType bytes'
            Right (PL.SyncDiff poId ids body) -> do
              -- a diff also depends on the patch object it is a diff against
              poId' <- syncPatchObjectId poId
              ids' <- syncPatchLocalIds ids
              let bytes' = runPutS $ S.recomposePatchFormat (PL.SyncDiff poId' ids' body)
              runDest $ Q.saveObject hId' objType bytes'
            Left s -> throwError $ DecodeError ErrPatchFormat bytes s
        case result of
          Left deps -> pure . Sync.Missing $ toList deps
          Right oId' -> do
            syncSecondaryHashes oId oId'
            when debug $ traceM $ "Source " ++ show (hId, oId) ++ " becomes Dest " ++ show (hId', oId')
            -- remember the mapping so later lookups skip the database
            Cache.insert oCache oId oId'
            pure Sync.Done
  W k r -> syncWatch k r
  where
    -- Split a serialized term/decl component into its leading format tag and
    -- its (localIds, blob) elements, throwing a 'DecodeError' tagged with
    -- @errTag@ if the bytes do not parse. Left without a type signature so
    -- that the tag's type does not need to be imported here.
    decodeComponent errTag blob =
      case flip runGetS blob do
        tag <- getWord8
        component <- S.decomposeComponent
        pure (tag, component) of
        Right x -> pure x
        Left s -> throwError $ DecodeError errTag blob s

    -- Copy the dependents-index rows of the source reference @ref@ to the
    -- destination, recorded against the destination reference @ref'@.
    -- Every dependency is expected to be in the destination already.
    copyDependentsIndex :: Sqlite.Reference.Id -> Sqlite.Reference.Id -> m ()
    copyDependentsIndex ref ref' =
      runSrc (Q.getDependenciesForDependent ref)
        >>= traverse (fmap fromJust' . isSyncedObjectReference)
        >>= runDest . traverse_ (flip Q.addToDependentsIndex ref')
      where
        fromJust' = fromMaybe (error "missing objects should've been caught by `syncLocalIds` above")

    -- Copy the type index and the type-mentions index rows of the source
    -- component @oId@ to the destination, rewritten to point at @oId'@.
    copyTypeIndexes :: ObjectId -> ObjectId -> m ()
    copyTypeIndexes oId oId' = do
      runSrc (Q.getTypeReferencesForComponent oId)
        >>= traverse (syncTypeIndexRow oId')
        >>= traverse_ (runDest . uncurry Q.addToTypeIndex)
      runSrc (Q.getTypeMentionsReferencesForComponent oId)
        >>= traverse (syncTypeIndexRow oId')
        >>= traverse_ (runDest . uncurry Q.addToTypeMentionsIndex)

    -- Translate a source object id to its destination id, or report the
    -- object as a missing dependency.
    syncLocalObjectId :: ObjectId -> ValidateT (Set Entity) m ObjectId
    syncLocalObjectId oId =
      lift (isSyncedObject oId) >>= \case
        Just oId' -> pure oId'
        Nothing -> Validate.refute . Set.singleton $ O oId

    syncPatchObjectId :: PatchObjectId -> ValidateT (Set Entity) m PatchObjectId
    syncPatchObjectId = fmap PatchObjectId . syncLocalObjectId . unPatchObjectId

    syncBranchObjectId :: BranchObjectId -> ValidateT (Set Entity) m BranchObjectId
    syncBranchObjectId = fmap BranchObjectId . syncLocalObjectId . unBranchObjectId

    -- Translate a source causal hash id to its destination id, or report the
    -- causal as a missing dependency.
    syncCausal :: CausalHashId -> ValidateT (Set Entity) m CausalHashId
    syncCausal chId =
      lift (isSyncedCausal chId) >>= \case
        Just chId' -> pure chId'
        Nothing -> Validate.refute . Set.singleton $ C chId

    -- Translate the ids embedded in a term/decl component element. Objects
    -- must already be in the destination; texts are copied on demand.
    syncLocalIds :: L.LocalIds -> ValidateT (Set Entity) m L.LocalIds
    syncLocalIds (L.LocalIds tIds oIds) = do
      oIds' <- traverse syncLocalObjectId oIds
      tIds' <- lift $ traverse syncTextLiteral tIds
      pure $ L.LocalIds tIds' oIds'

    -- Translate the ids embedded in a patch.
    syncPatchLocalIds :: PL.PatchLocalIds -> ValidateT (Set Entity) m PL.PatchLocalIds
    syncPatchLocalIds (PL.LocalIds tIds hIds oIds) = do
      oIds' <- traverse syncLocalObjectId oIds
      tIds' <- lift $ traverse syncTextLiteral tIds
      hIds' <- lift $ traverse syncHashLiteral hIds

      -- workaround for requiring components to compute component lengths for references.
      -- this line requires objects in the destination for any hashes referenced in the source,
      -- (making those objects dependencies of this patch).  See Sync21.filter{Term,Type}Edit
      traverse_ syncLocalObjectId =<< traverse (runSrc . Q.expectObjectIdForAnyHashId) hIds

      pure $ PL.LocalIds tIds' hIds' oIds'

    -- Translate the ids embedded in a namespace: referenced objects, patches,
    -- and (branch object, causal) pairs must already be in the destination.
    syncBranchLocalIds :: BL.BranchLocalIds -> ValidateT (Set Entity) m BL.BranchLocalIds
    syncBranchLocalIds (BL.LocalIds tIds oIds poIds chboIds) = do
      oIds' <- traverse syncLocalObjectId oIds
      poIds' <- traverse syncPatchObjectId poIds
      chboIds' <- traverse (bitraverse syncBranchObjectId syncCausal) chboIds
      tIds' <- lift $ traverse syncTextLiteral tIds
      pure $ BL.LocalIds tIds' oIds' poIds' chboIds'

    -- Translate one (type reference, referent) index row: the type reference
    -- is hash-based and is copied; the referent is re-pointed at @oId'@.
    syncTypeIndexRow oId' = bitraverse syncHashReference (pure . rewriteTypeIndexReferent oId')

    -- Replace every object id inside the referent with @oId'@.
    rewriteTypeIndexReferent :: ObjectId -> Sqlite.Referent.Id -> Sqlite.Referent.Id
    rewriteTypeIndexReferent oId' = bimap (const oId') (const oId')

    -- Copy a text from source to destination (cached), returning its
    -- destination id.
    syncTextLiteral :: TextId -> m TextId
    syncTextLiteral = Cache.apply tCache \tId -> do
      t <- runSrc $ Q.loadTextById tId
      tId' <- runDest $ Q.saveText t
      when debug $ traceM $ "Source " ++ show tId ++ " is Dest " ++ show tId' ++ " (" ++ show t ++ ")"
      pure tId'

    -- Copy a hash from source to destination (cached), returning its
    -- destination id.
    syncHashLiteral :: HashId -> m HashId
    syncHashLiteral = Cache.apply hCache \hId -> do
      b32hex <- runSrc $ Q.loadHashById hId
      hId' <- runDest $ Q.saveHash b32hex
      when debug $ traceM $ "Source " ++ show hId ++ " is Dest " ++ show hId' ++ " (" ++ show b32hex ++ ")"
      pure hId'

    -- Translate an object-based reference to the destination, or 'Nothing'
    -- if the object it refers to is not in the destination yet. Builtins
    -- only involve a text and always translate.
    isSyncedObjectReference :: Sqlite.Reference -> m (Maybe Sqlite.Reference)
    isSyncedObjectReference = \case
      Reference.ReferenceBuiltin t ->
        Just . Reference.ReferenceBuiltin <$> syncTextLiteral t
      Reference.ReferenceDerived id ->
        fmap Reference.ReferenceDerived <$> isSyncedObjectReferenceId id

    isSyncedObjectReferenceId :: Sqlite.Reference.Id -> m (Maybe Sqlite.Reference.Id)
    isSyncedObjectReferenceId (Reference.Id oId idx) =
      isSyncedObject oId <&> fmap (\oId' -> Reference.Id oId' idx)

    syncHashReference :: Sqlite.ReferenceH -> m Sqlite.ReferenceH
    syncHashReference = bitraverse syncTextLiteral syncHashLiteral

    syncCausalHashId :: CausalHashId -> m CausalHashId
    syncCausalHashId = fmap CausalHashId . syncHashLiteral . unCausalHashId

    syncBranchHashId :: BranchHashId -> m BranchHashId
    syncBranchHashId = fmap BranchHashId . syncHashLiteral . unBranchHashId

    -- Destination ids of the causal's parents; any parent not yet in the
    -- destination is reported as a missing dependency.
    findParents' :: CausalHashId -> ValidateT (Set Entity) m [CausalHashId]
    findParents' chId = do
      srcParents <- runSrc $ Q.loadCausalParents chId
      traverse syncCausal srcParents

    -- Copy the watch result of kind @wk@ for reference @r@, unless the
    -- destination already has a result of that kind for the reference.
    syncWatch :: WK.WatchKind -> Sqlite.Reference.IdH -> m (TrySyncResult Entity)
    -- debug-only equation: the guard is never true, it exists for its trace
    syncWatch wk r | debug && trace ("Sync22.syncWatch " ++ show wk ++ " " ++ show r) False = undefined
    syncWatch wk r = do
      r' <- traverse syncHashLiteral r
      doneKinds <- runDest (Q.loadWatchKindsByReference r')
      if wk `elem` doneKinds
        then pure Sync.PreviouslyDone
        else do
          -- the source may have no result of this kind; then nothing is saved
          runSrc (Q.loadWatch wk r) >>= traverse_ \blob -> do
            TL.SyncWatchResult li body <-
              either (throwError . DecodeError ErrWatchResult blob) pure $ runGetS S.decomposeWatchFormat blob
            li' <- bitraverse syncTextLiteral syncHashLiteral li
            when debug $ traceM $ "LocalIds for Source watch result " ++ show r ++ ": " ++ show li
            when debug $ traceM $ "LocalIds for Dest watch result " ++ show r' ++ ": " ++ show li'
            let blob' = runPutS $ S.recomposeWatchFormat (TL.SyncWatchResult li' body)
            runDest (Q.saveWatch wk r' blob')
          pure Sync.Done

    -- Record, in the destination, every (hash, hash version) the source
    -- associates with object @oId@, attached to the destination object @oId'@.
    syncSecondaryHashes :: ObjectId -> ObjectId -> m ()
    syncSecondaryHashes oId oId' =
      runSrc (Q.hashIdWithVersionForObject oId) >>= traverse_ \(hId, hashVersion) -> do
        hId' <- syncHashLiteral hId
        runDest $ Q.saveHashObject hId' oId' hashVersion

    -- Find the destination object for a source object by looking up each of
    -- the source object's hashes in the destination. 'Nothing' if none of the
    -- hashes has an object there; an error if they disagree.
    isSyncedObject :: ObjectId -> m (Maybe ObjectId)
    isSyncedObject = Cache.applyDefined oCache \oId -> do
      hIds <- toList <$> runSrc (Q.hashIdsForObject oId)
      hIds' <- traverse syncHashLiteral hIds
      ( nubOrd . catMaybes
          <$> traverse (runDest . Q.maybeObjectIdForAnyHashId) hIds'
        )
        >>= \case
          [oId'] -> do
            when debug $ traceM $ "Source " ++ show oId ++ " is Dest " ++ show oId'
            pure $ Just oId'
          [] -> pure $ Nothing
          oIds' -> throwError (HashObjectCorrespondence oId hIds hIds' oIds')

    -- The destination id of a source causal, if the destination already
    -- records its hash as a causal hash.
    isSyncedCausal :: CausalHashId -> m (Maybe CausalHashId)
    isSyncedCausal = Cache.applyDefined cCache \chId -> do
      let hId = unCausalHashId chId
      hId' <- syncHashLiteral hId
      ifM
        (runDest $ Q.isCausalHash hId')
        (pure . Just $ CausalHashId hId')
        (pure Nothing)

-- | Run a database action against the source ('runSrc') or the destination
-- ('runDest') connection taken from the 'Env'. An integrity failure of the
-- action is rethrown as 'DbIntegrity' (see 'runDB').
runSrc,
  runDest ::
    (MonadError Error m, MonadReader Env m) =>
    ReaderT Connection (ExceptT Q.Integrity m) a ->
    m a
runSrc action = do
  conn <- Reader.reader srcDB
  runDB conn action
runDest action = do
  conn <- Reader.reader destDB
  runDB conn action

-- | Run a database action against the given connection. If the action fails
-- with a 'Q.Integrity' error, it is rethrown in the outer monad wrapped in
-- 'DbIntegrity'; otherwise the action's result is returned.
runDB ::
  MonadError Error m => Connection -> ReaderT Connection (ExceptT Q.Integrity m) a -> m a
runDB conn action = do
  outcome <- Except.runExceptT (Reader.runReaderT action conn)
  either (throwError . DbIntegrity) pure outcome
