#include "MCTargetDesc/X86MCTargetDesc.h"
#include "Views/SummaryView.h"
#include "X86TestBase.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/MC/MCInstBuilder.h"
#include "llvm/MCA/CustomBehaviour.h"
#include "llvm/MCA/IncrementalSourceMgr.h"
#include "llvm/MCA/InstrBuilder.h"
#include "llvm/MCA/Pipeline.h"
#include "llvm/Support/Format.h"
#include "llvm/Support/JSON.h"
#include "llvm/Support/raw_ostream.h"
#include <memory>
#include <unordered_map>

using namespace llvm;
using namespace mca;

// Feeds instructions to an incremental pipeline in fixed-size tiles, resuming
// the pipeline after each tile, and checks that the SummaryView results match
// a baseline (non-incremental) run over the same instructions.
TEST_F(X86TestBase, TestResumablePipeline) {
  mca::Context MCA(*MRI, *STI);

  mca::IncrementalSourceMgr ISM;
  // Empty CustomBehaviour.
  auto CB = std::make_unique<mca::CustomBehaviour>(*STI, ISM, *MCII);

  auto PO = getDefaultPipelineOptions();
  auto P = MCA.createDefaultPipeline(PO, ISM, *CB);
  ASSERT_TRUE(P);

  SmallVector<MCInst> MCIs;
  getSimpleInsts(MCIs, /*Repeats=*/100);

  // Add views.
  auto SV = std::make_unique<SummaryView>(STI->getSchedModel(), MCIs,
                                          PO.DispatchWidth);
  P->addEventListener(SV.get());

  auto IM = std::make_unique<mca::InstrumentManager>(*STI, *MCII);
  mca::InstrBuilder IB(*STI, *MCII, *MRI, MCIA.get(), *IM, /*CallLatency=*/100);

  // No instruments are attached to any instruction.
  const SmallVector<mca::Instrument *> Instruments;
  // Number of instructions handed to the source manager before each resume.
  constexpr unsigned TileSize = 7;
  for (unsigned i = 0U, E = MCIs.size(); i < E;) {
    for (unsigned TE = i + TileSize; i < TE && i < E; ++i) {
      Expected<std::unique_ptr<mca::Instruction>> InstOrErr =
          IB.createInstruction(MCIs[i], Instruments);
      // Consume the error into the failure message: an unchecked llvm::Error
      // would otherwise abort the test binary before gtest can report it.
      ASSERT_TRUE(bool(InstOrErr)) << toString(InstOrErr.takeError());
      ISM.addInst(std::move(InstOrErr.get()));
    }

    // Run the pipeline. Before endOfStream() the run is expected to stop
    // with an InstStreamPause error once the current tile is drained.
    Expected<unsigned> Cycles = P->run();
    if (!Cycles) {
      Error Err = Cycles.takeError();
      ASSERT_TRUE(Err.isA<mca::InstStreamPause>())
          << toString(std::move(Err));
      llvm::consumeError(std::move(Err));
    }
  }

  ISM.endOfStream();
  // Has to terminate properly.
  Expected<unsigned> Cycles = P->run();
  ASSERT_TRUE(bool(Cycles)) << toString(Cycles.takeError());

  json::Value Result = SV->toJSON();
  auto *ResultObj = Result.getAsObject();
  ASSERT_TRUE(ResultObj);

  // Run the baseline.
  json::Object BaselineResult;
  auto E = runBaselineMCA(BaselineResult, MCIs);
  ASSERT_FALSE(bool(E)) << "Failed to run baseline: "
                        << toString(std::move(E));
  auto *BaselineObj = BaselineResult.getObject(SV->getNameAsString());
  ASSERT_TRUE(BaselineObj) << "Does not contain SummaryView result";

  // Compare the results. The JSON values are compared directly rather than
  // through getInteger(), since some fields (BlockRThroughput presumably) may
  // be emitted as non-integral floating-point numbers, for which getInteger()
  // returns nothing.
  constexpr const char *Fields[] = {"Instructions", "TotalCycles", "TotaluOps",
                                    "BlockRThroughput"};
  for (const char *F : Fields) {
    const json::Value *V = ResultObj->get(F);
    const json::Value *BV = BaselineObj->get(F);
    ASSERT_TRUE(V && BV) << "Field '" << F << "' is missing";
    ASSERT_EQ(*BV, *V) << "Value of '" << F << "' does not match";
  }
}

// Same as the resumable-pipeline scenario, but freed instructions are pooled
// and handed back to the InstrBuilder for reuse. The SummaryView results must
// still match a baseline (non-incremental) run.
TEST_F(X86TestBase, TestInstructionRecycling) {
  mca::Context MCA(*MRI, *STI);

  // Pool of freed instructions, bucketed by instruction descriptor so that a
  // pooled instruction is only offered for a matching descriptor.
  std::unordered_map<const mca::InstrDesc *, SmallPtrSet<mca::Instruction *, 2>>
      RecycledInsts;
  // Recycle callback for the InstrBuilder: removes and returns a pooled
  // instruction for Desc, or nullptr when none is available.
  auto GetRecycledInst = [&](const mca::InstrDesc &Desc) -> mca::Instruction * {
    auto It = RecycledInsts.find(&Desc);
    if (It == RecycledInsts.end() || It->second.empty())
      return nullptr;
    mca::Instruction *I = *It->second.begin();
    It->second.erase(I);
    return I;
  };
  // Invoked by the source manager when an instruction is freed; returns it to
  // the pool.
  auto AddRecycledInst = [&](mca::Instruction *I) {
    const mca::InstrDesc &D = I->getDesc();
    RecycledInsts[&D].insert(I);
  };

  mca::IncrementalSourceMgr ISM;
  ISM.setOnInstFreedCallback(AddRecycledInst);

  // Empty CustomBehaviour.
  auto CB = std::make_unique<mca::CustomBehaviour>(*STI, ISM, *MCII);

  auto PO = getDefaultPipelineOptions();
  auto P = MCA.createDefaultPipeline(PO, ISM, *CB);
  ASSERT_TRUE(P);

  SmallVector<MCInst> MCIs;
  getSimpleInsts(MCIs, /*Repeats=*/100);

  // Add views.
  auto SV = std::make_unique<SummaryView>(STI->getSchedModel(), MCIs,
                                          PO.DispatchWidth);
  P->addEventListener(SV.get());

  // Default InstrumentManager
  auto IM = std::make_unique<mca::InstrumentManager>(*STI, *MCII);

  mca::InstrBuilder IB(*STI, *MCII, *MRI, MCIA.get(), *IM, /*CallLatency=*/100);
  IB.setInstRecycleCallback(GetRecycledInst);

  // No instruments are attached to any instruction.
  const SmallVector<mca::Instrument *> Instruments;
  // Number of instructions handed to the source manager before each resume.
  constexpr unsigned TileSize = 7;
  for (unsigned i = 0U, E = MCIs.size(); i < E;) {
    for (unsigned TE = i + TileSize; i < TE && i < E; ++i) {
      Expected<std::unique_ptr<mca::Instruction>> InstOrErr =
          IB.createInstruction(MCIs[i], Instruments);

      if (!InstOrErr) {
        // Reuse of a pooled instruction is reported through a
        // RecycledInstErr carrying that instruction; any other error is a
        // test failure.
        mca::Instruction *RecycledInst = nullptr;
        auto RemainingE = handleErrors(InstOrErr.takeError(),
                                       [&](const mca::RecycledInstErr &RC) {
                                         RecycledInst = RC.getInst();
                                       });
        // Consume any leftover error into the failure message so an
        // unchecked llvm::Error does not abort the test binary.
        ASSERT_FALSE(bool(RemainingE)) << toString(std::move(RemainingE));
        ASSERT_TRUE(RecycledInst);
        ISM.addRecycledInst(RecycledInst);
      } else {
        ISM.addInst(std::move(InstOrErr.get()));
      }
    }

    // Run the pipeline. Before endOfStream() the run is expected to stop
    // with an InstStreamPause error once the current tile is drained.
    Expected<unsigned> Cycles = P->run();
    if (!Cycles) {
      Error Err = Cycles.takeError();
      ASSERT_TRUE(Err.isA<mca::InstStreamPause>())
          << toString(std::move(Err));
      llvm::consumeError(std::move(Err));
    }
  }

  ISM.endOfStream();
  // Has to terminate properly.
  Expected<unsigned> Cycles = P->run();
  ASSERT_TRUE(bool(Cycles)) << toString(Cycles.takeError());

  json::Value Result = SV->toJSON();
  auto *ResultObj = Result.getAsObject();
  ASSERT_TRUE(ResultObj);

  // Run the baseline.
  json::Object BaselineResult;
  auto E = runBaselineMCA(BaselineResult, MCIs);
  ASSERT_FALSE(bool(E)) << "Failed to run baseline: "
                        << toString(std::move(E));
  auto *BaselineObj = BaselineResult.getObject(SV->getNameAsString());
  ASSERT_TRUE(BaselineObj) << "Does not contain SummaryView result";

  // Compare the results. The JSON values are compared directly rather than
  // through getInteger(), since some fields (BlockRThroughput presumably) may
  // be emitted as non-integral floating-point numbers, for which getInteger()
  // returns nothing.
  constexpr const char *Fields[] = {"Instructions", "TotalCycles", "TotaluOps",
                                    "BlockRThroughput"};
  for (const char *F : Fields) {
    const json::Value *V = ResultObj->get(F);
    const json::Value *BV = BaselineObj->get(F);
    ASSERT_TRUE(V && BV) << "Field '" << F << "' is missing";
    ASSERT_EQ(*BV, *V) << "Value of '" << F << "' does not match";
  }
}

// Test that we do not depend upon the MCInst address for variant description
// construction. This test creates two instructions that use a variant
// description, since both are zeroing idioms, but that write to different
// registers. Both are built from the same MCInst object, so they share an
// address. If the key used to look up the variant instruction description were
// the same for both (e.g. the MCInst pointer), we would hit an assertion
// failure due to the differing writes.
TEST_F(X86TestBase, TestVariantInstructionsSameAddress) {
  mca::Context MCA(*MRI, *STI);

  mca::IncrementalSourceMgr ISM;
  // Empty CustomBehaviour.
  auto CB = std::make_unique<mca::CustomBehaviour>(*STI, ISM, *MCII);

  auto PO = getDefaultPipelineOptions();
  auto P = MCA.createDefaultPipeline(PO, ISM, *CB);
  ASSERT_TRUE(P);

  auto IM = std::make_unique<mca::InstrumentManager>(*STI, *MCII);
  mca::InstrBuilder IB(*STI, *MCII, *MRI, MCIA.get(), *IM, /*CallLatency=*/100);

  // No instruments are attached to any instruction.
  const SmallVector<mca::Instrument *> Instruments;

  // The same MCInst object is reused for both instructions on purpose.
  MCInst InstructionToAdd;
  InstructionToAdd = MCInstBuilder(X86::XOR64rr)
                         .addReg(X86::RAX)
                         .addReg(X86::RAX)
                         .addReg(X86::RAX);
  Expected<std::unique_ptr<mca::Instruction>> Instruction1OrErr =
      IB.createInstruction(InstructionToAdd, Instruments);
  // Consume the error into the failure message so an unchecked Expected does
  // not abort the test binary before gtest can report the failure.
  ASSERT_TRUE(static_cast<bool>(Instruction1OrErr))
      << toString(Instruction1OrErr.takeError());
  ISM.addInst(std::move(Instruction1OrErr.get()));

  InstructionToAdd = MCInstBuilder(X86::XORPSrr)
                         .addReg(X86::XMM0)
                         .addReg(X86::XMM0)
                         .addReg(X86::XMM0);
  Expected<std::unique_ptr<mca::Instruction>> Instruction2OrErr =
      IB.createInstruction(InstructionToAdd, Instruments);
  ASSERT_TRUE(static_cast<bool>(Instruction2OrErr))
      << toString(Instruction2OrErr.takeError());
  ISM.addInst(std::move(Instruction2OrErr.get()));

  ISM.endOfStream();
  Expected<unsigned> Cycles = P->run();
  ASSERT_TRUE(static_cast<bool>(Cycles)) << toString(Cycles.takeError());
}
