/*******************************************************************************
 *
 * MIT License
 *
 * Copyright 2024-2025 AMD ROCm(TM) Software
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 *
 *******************************************************************************/

#include <rocRoller/Expression.hpp>

namespace rocRoller
{
    namespace Expression
    {
        /**
         * Expression visitor that rewrites every BitfieldCombine node into an
         * equivalent mask / shift / or expression.
         *
         * All other node kinds are copied with their operands lowered recursively,
         * so a BitfieldCombine nested anywhere in the tree is reached.
         * Entry point is call(); the operator() overloads are dispatched through
         * std::visit.
         */
        struct LowerBitfieldCombineExpressionVisitor
        {
            /**
             * Builds a 32-bit mask with `width` consecutive bits set, the lowest of
             * which is bit `offset`.
             *
             * The arithmetic is done in 64 bits so that width == 32 is well defined
             * independently of how wide `unsigned long` is on the target; any bits
             * shifted past bit 31 are dropped by the final cast.
             *
             * Precondition: width < 64 and offset < 64.
             */
            template <typename Width, typename Offset>
            static constexpr uint32_t fieldMask(Width width, Offset offset)
            {
                auto const ones = (uint64_t{1} << static_cast<uint64_t>(width)) - 1u;
                return static_cast<uint32_t>(ones << static_cast<uint64_t>(offset));
            }

            /// Unary node: copy it and lower its single argument.
            template <CUnary Expr>
            ExpressionPtr operator()(Expr const& expr) const
            {
                Expr cpy = expr;
                cpy.arg  = call(expr.arg);
                return std::make_shared<Expression>(cpy);
            }

            /// Binary node (other than BitfieldCombine): copy it and lower both sides.
            template <CBinary Expr>
            ExpressionPtr operator()(Expr const& expr) const
            {
                Expr cpy = expr;
                cpy.lhs  = call(expr.lhs);
                cpy.rhs  = call(expr.rhs);
                return std::make_shared<Expression>(cpy);
            }

            /// ScaledMatrixMultiply: copy it and lower the three matrices and both scales.
            ExpressionPtr operator()(ScaledMatrixMultiply const& expr) const
            {
                ScaledMatrixMultiply cpy = expr;
                cpy.matA                 = call(expr.matA);
                cpy.matB                 = call(expr.matB);
                cpy.matC                 = call(expr.matC);
                cpy.scaleA               = call(expr.scaleA);
                cpy.scaleB               = call(expr.scaleB);
                return std::make_shared<Expression>(cpy);
            }

            /// Ternary node: copy it and lower all three operands.
            template <CTernary Expr>
            ExpressionPtr operator()(Expr const& expr) const
            {
                Expr cpy = expr;
                cpy.lhs  = call(expr.lhs);
                cpy.r1hs = call(expr.r1hs);
                cpy.r2hs = call(expr.r2hs);
                return std::make_shared<Expression>(cpy);
            }

            /// N-ary node: copy it and lower every operand in place.
            template <CNary Expr>
            ExpressionPtr operator()(Expr const& expr) const
            {
                auto cpy = expr;
                for(auto& operand : cpy.operands)
                    operand = call(operand);
                return std::make_shared<Expression>(std::move(cpy));
            }

            /**
             * Lowers a BitfieldCombine node (lhs is the source, rhs the destination).
             *
             * The result is `shift(srcMask & src) | (dstMask & dst)`, where the shift
             * moves the field from srcOffset to dstOffset. The source mask is skipped
             * when expr.srcIsZero is set and true, and the destination mask is skipped
             * when expr.dstIsZero is set and true.
             * NOTE(review): presumably those flags mean the bits the mask would clear
             * are already known to be zero -- confirm against the BitfieldCombine
             * declaration.
             *
             * Raises via AssertFatal when a non-null source is wider than 4 bytes, a
             * non-null destination is not exactly 4 bytes, or the field does not fit
             * in the respective operand.
             */
            ExpressionPtr operator()(BitfieldCombine const& expr) const
            {
                // Operands are lowered first so the size checks see the final expressions.
                auto lhs = call(expr.lhs);
                if(lhs)
                {
                    AssertFatal(resultVariableType(lhs).getElementSize() <= 4u,
                                "Currently BitfieldCombine only supports: src size <= 1 dword");
                    AssertFatal(resultVariableType(lhs).getElementSize() * 8u
                                    >= expr.srcOffset + expr.width,
                                "Bitfield exceeds the number of bits of source, source size "
                                "(bytes), offset, width = ",
                                ShowValue(resultVariableType(lhs).getElementSize()),
                                ShowValue(expr.srcOffset),
                                ShowValue(expr.width));
                }

                auto rhs = call(expr.rhs);
                if(rhs)
                {
                    AssertFatal(resultVariableType(rhs).getElementSize() == 4u,
                                "Currently BitfieldCombine only supports: dst size = 1 dword");
                    AssertFatal(resultVariableType(rhs).getElementSize() * 8u
                                    >= expr.dstOffset + expr.width,
                                "Bitfield exceeds the number of bits of destination, destination "
                                "size (bytes), offset, width = ",
                                ShowValue(resultVariableType(rhs).getElementSize()),
                                ShowValue(expr.dstOffset),
                                ShowValue(expr.width));
                }

                auto const srcIsZero = expr.srcIsZero && expr.srcIsZero.value();
                if(not srcIsZero)
                {
                    // Keep only the field's bits of the source.
                    rocRoller::Raw32 srcMask(fieldMask(expr.width, expr.srcOffset));
                    lhs = (literal(srcMask) & lhs); // Extract bits
                }

                auto const dstIsZero = expr.dstIsZero && expr.dstIsZero.value();
                if(not dstIsZero)
                {
                    // Clear the field's bits in the destination, keep everything else.
                    rocRoller::Raw32 dstMask(
                        static_cast<uint32_t>(~fieldMask(expr.width, expr.dstOffset)));
                    rhs = (literal(dstMask) & rhs); // Clear bits
                }

                // Align the extracted field with its position in the destination.
                if(expr.dstOffset > expr.srcOffset)
                    lhs = lhs << literal(expr.dstOffset - expr.srcOffset);
                else if(expr.dstOffset < expr.srcOffset)
                    lhs = logicalShiftR(lhs, literal(expr.srcOffset - expr.dstOffset));

                ExpressionPtr ret = lhs | rhs;
                setComment(ret, expr.comment);
                return ret;
            }

            /// Leaf value: nothing to lower, wrap a copy in a new expression node.
            template <CValue Value>
            ExpressionPtr operator()(Value const& expr) const
            {
                return std::make_shared<Expression>(expr);
            }

            /**
             * Lowers the tree rooted at `expr` by dispatching on the node's type.
             * A null pointer is returned unchanged.
             */
            ExpressionPtr call(ExpressionPtr expr) const
            {
                if(!expr)
                    return expr;

                return std::visit(*this, *expr);
            }
        };

        /**
         * Lowers every BitfieldCombine node found in `expr` into plain bitwise
         * operations and returns the rewritten expression tree. A null `expr` is
         * returned as-is.
         *
         * With src = lhs and dst = rhs, each BitfieldCombine becomes:
         *
         *   srcMask =   ((1 << width) - 1) << srcOffset
         *   dstMask = ~(((1 << width) - 1) << dstOffset)
         *   result  = shift(srcMask & src, |srcOffset - dstOffset|) | (dstMask & dst)
         *
         * where `shift` is a left shift when dstOffset > srcOffset, a logical right
         * shift when dstOffset < srcOffset, and no shift when the offsets are equal.
         */
        ExpressionPtr lowerBitfieldCombine(ExpressionPtr expr)
        {
            return LowerBitfieldCombineExpressionVisitor{}.call(expr);
        }

    }
}
