/*******************************************************************************
 *
 * MIT License
 *
 * Copyright (C) 2022-2025 Advanced Micro Devices, Inc. All rights reserved.
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 *
 *******************************************************************************/

#include <Tensile/AMDGPU.hpp>

namespace TensileLite
{
    // TODO: This table exists to detect CU fallback combined with a runtime XCC
    //       modification. XCC was introduced with gfx942, so only gfx942 is
    //       listed for now. Devices that are not in the table do not have the
    //       CU-fallback / mismatched-XCC problem, so AMDGPU::isStandardCU()
    //       always treats them as standard-CU devices.
    //       THIS IS A TEMPORARY WORKAROUND; A BETTER WAY TO DETECT CU FALLBACK
    //       IS STILL NEEDED.
    //
    // Returns an immutable table that maps a processor to the CU count it is
    // expected to report when no CU fallback is in effect. The table is a
    // function-local static, so it is built on the first call only.
    static const std::map<AMDGPU::Processor, int>& StandardCUMap()
    {
        static const std::map<AMDGPU::Processor, int> expectedCUsByProcessor{
            {AMDGPU::Processor::gfx942, 304},
        };
        return expectedCUsByProcessor;
    }

    // Instance-level accessor: returns the same string as the class-level Type().
    TENSILE_API std::string AMDGPU::type() const
    {
        std::string typeName = Type();
        return typeName;
    }

    // Default constructor. Members take whatever default member initializers the
    // class declares (or are default-initialized). Unlike the parameterized
    // constructor below, this one does not call the get*() tuning helpers.
    TENSILE_API AMDGPU::AMDGPU() = default;

    // Constructs a device description.
    //
    // p    : processor (gfx target) of the device; stored in `processor`.
    // cus  : compute-unit count reported for the device; stored in `computeUnitCount`.
    // name : device name; copied into `deviceName`.
    //
    // The stream-K / workgroup-mapping fields (skDynamicGrid, skDynamicWGM,
    // fixedWGM, fixedWGMXCC, skMaxCUs, skGridMultiplier, skFixedGrid, skFullTiles)
    // are filled by their corresponding get*() helpers.
    // NOTE(review): those helpers are not visible here. They presumably read
    // user overrides (e.g. environment variables); confirm in their definitions.
    //
    // C++ initializes members in their declaration order in the class, not in the
    // order written below. If any get*() helper reads processor, computeUnitCount
    // or deviceName, those members must be declared before the field it fills.
    TENSILE_API AMDGPU::AMDGPU(AMDGPU::Processor p, int cus, std::string const& name)
        : processor(p)
        , computeUnitCount(cus)
        , deviceName(name)
        , skDynamicGrid(getSKDynamicGrid())
        , skDynamicWGM(getSKDynamicWGM())
        , fixedWGM(getFixedWGM())
        , fixedWGMXCC(getFixedWGMXCC())
        , skMaxCUs(getSKMaxCUs())
        , skGridMultiplier(getSKGridMultiplier())
        , skFixedGrid(getSKFixedGrid())
        , skFullTiles(getSKFullTiles())
    {
    }

    // Defaulted destructor: no custom teardown. Members are destroyed in reverse
    // declaration order.
    TENSILE_API AMDGPU::~AMDGPU()
        = default;

    // Reports whether this device exposes its "standard" number of compute units,
    // i.e. no CU fallback was detected.
    //
    // The answer is cached in isStandardCUs:
    //   -1 = not computed yet, 1 = standard, 0 = not standard.
    // NOTE(review): the cache is written without a lock. Unless isStandardCUs is
    // atomic, concurrent first calls on the same object race.
    TENSILE_API bool AMDGPU::isStandardCU() const
    {
        if(isStandardCUs == -1)
        {
            auto const& expectedCUs = StandardCUMap();
            auto const  entry       = expectedCUs.find(processor);

            // Processors absent from the table have no XCC/CU-fallback issue and
            // always count as standard. Listed processors must report exactly the
            // expected CU count.
            bool const standard
                = (entry == expectedCUs.end()) || (entry->second == computeUnitCount);

            isStandardCUs = standard ? 1 : 0;
        }

        return isStandardCUs == 1;
    }

    // Returns true if a kernel compiled for processor `other` can run on this
    // device.
    //
    // Rules (comparisons use the ordering of the Processor enum):
    //   - a kernel built for exactly this processor is always accepted;
    //   - a gfx900 kernel is accepted when gfx900 is ordered below this processor;
    //   - everything else is rejected. This covers targets ordered above this
    //     processor, gfx803, and every other lower target.
    TENSILE_API bool AMDGPU::runsKernelTargeting(AMDGPU::Processor other) const
    {
        if(other == processor)
            return true;

        return (processor > other) && (other == Processor::gfx900);
    }

    // Writes the textual name of a processor, as produced by AMDGPU::toString,
    // and returns the stream for chaining.
    std::ostream& operator<<(std::ostream& stream, AMDGPU::Processor p)
    {
        return stream << AMDGPU::toString(p);
    }

    // Builds a human-readable summary of the form
    //   <deviceName>(<computeUnitCount>-CU <processor>)
    // The processor is formatted with the Processor stream operator.
    TENSILE_API std::string AMDGPU::description() const
    {
        std::ostringstream text;

        text << deviceName;
        text << "(" << computeUnitCount << "-CU ";
        text << processor << ")";

        return text.str();
    }

    // Writes the device's description() to the stream and returns the stream.
    // NOTE(review): `g` is taken by value, so each call copies the whole AMDGPU,
    // including deviceName. Switching to `AMDGPU const&` needs a matching change
    // to the declaration in the header.
    TENSILE_API std::ostream& operator<<(std::ostream& stream, AMDGPU g)
    {
        stream << g.description();
        return stream;
    }
} // namespace TensileLite
