/**
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
#include <gtest/gtest.h>
#include <pulsar/Client.h>
#include <pulsar/ConsumerCryptoFailureAction.h>
#include <pulsar/MessageBatch.h>

#include <optional>
#include <stdexcept>

#include "lib/CompressionCodec.h"
#include "lib/MessageCrypto.h"
#include "lib/SharedBuffer.h"

// Service URL the tests connect to (assumes a Pulsar broker is reachable locally on the default port).
// Only ever read, so it is declared const to prevent accidental mutation.
static const std::string lookupUrl = "pulsar://localhost:6650";

using namespace pulsar;

/**
 * Build the crypto key reader shared by the producers and consumers in this file.
 *
 * @return a DefaultCryptoKeyReader backed by the client-rsa public/private key pair in TEST_CONF_DIR
 */
static CryptoKeyReaderPtr getDefaultCryptoKeyReader() {
    const std::string publicKeyPath = TEST_CONF_DIR "/public-key.client-rsa.pem";
    const std::string privateKeyPath = TEST_CONF_DIR "/private-key.client-rsa.pem";
    return std::make_shared<DefaultCryptoKeyReader>(publicKeyPath, privateKeyPath);
}

static std::vector<std::string> decryptValue(const char* data, size_t length,
                                             std::optional<const EncryptionContext*> context) {
    if (!context.has_value()) {
        return {std::string(data, length)};
    }
    if (!context.value()->isDecryptionFailed()) {
        return {std::string(data, length)};
    }

    MessageCrypto crypto{"test", false};
    SharedBuffer decryptedPayload;
    auto originalPayload = SharedBuffer::copy(data, length);
    if (!crypto.decrypt(*context.value(), originalPayload, getDefaultCryptoKeyReader(), decryptedPayload)) {
        throw std::runtime_error("Decryption failed");
    }

    SharedBuffer uncompressedPayload;
    if (!CompressionCodecProvider::getCodec(context.value()->compressionType())
             .decode(decryptedPayload, context.value()->uncompressedMessageSize(), uncompressedPayload)) {
        throw std::runtime_error("Decompression failed");
    }

    std::vector<std::string> values;
    if (auto batchSize = context.value()->batchSize(); batchSize > 0) {
        MessageBatch batch;
        for (auto&& msg : batch.parseFrom(uncompressedPayload, batchSize).messages()) {
            values.emplace_back(msg.getDataAsString());
        }
    } else {
        // non-batched message
        values.emplace_back(uncompressedPayload.data(), uncompressedPayload.readableBytes());
    }
    return values;
}

/**
 * Produce encrypted and unencrypted messages to `topic`, then consume them and check that every recovered
 * payload matches what was sent.
 *
 * An LZ4-compressing, encrypting producer sends 5 messages followed by a flush (the caller comments describe
 * these as forming one batch), then "last-msg" followed by a flush. A second, unencrypted producer then sends
 * "unencrypted-msg". Received payloads are passed through decryptValue(), which decrypts them only when the
 * consumer could not.
 *
 * @param client the client used to create the producers and the consumer
 * @param topic the topic to produce to and consume from
 * @param withDecryption if true, the consumer gets the crypto key reader and decrypts messages itself;
 *        otherwise it uses ConsumerCryptoFailureAction::CONSUME and receives the still-encrypted payloads
 * @param numMessageReceived number of messages expected from the consumer. Every message except the last must
 *        carry an encryption context; the last one (the unencrypted message) must not.
 */
static void testDecryption(Client& client, const std::string& topic, bool withDecryption,
                           int numMessageReceived) {
    ProducerConfiguration producerConf;
    producerConf.setCompressionType(CompressionLZ4);
    producerConf.addEncryptionKey("client-rsa.pem");
    producerConf.setCryptoKeyReader(getDefaultCryptoKeyReader());

    Producer producer;
    ASSERT_EQ(ResultOk, client.createProducer(topic, producerConf, producer));

    std::vector<std::string> sentValues;
    auto send = [&producer, &sentValues](const std::string& value) {
        Message msg = MessageBuilder().setContent(value).build();
        producer.sendAsync(msg, nullptr);
        sentValues.emplace_back(value);
    };

    for (int i = 0; i < 5; i++) {
        send("msg-" + std::to_string(i));
    }
    ASSERT_EQ(ResultOk, producer.flush());
    send("last-msg");
    ASSERT_EQ(ResultOk, producer.flush());
    // Close the encrypting producer before the handle is reassigned below; otherwise it would be destroyed
    // without ever being closed.
    ASSERT_EQ(ResultOk, producer.close());

    ASSERT_EQ(ResultOk, client.createProducer(topic, producer));
    send("unencrypted-msg");
    ASSERT_EQ(ResultOk, producer.flush());
    ASSERT_EQ(ResultOk, producer.close());

    ConsumerConfiguration consumerConf;
    consumerConf.setSubscriptionInitialPosition(InitialPositionEarliest);
    if (withDecryption) {
        consumerConf.setCryptoKeyReader(getDefaultCryptoKeyReader());
    } else {
        // Without a key reader, deliver the encrypted payloads instead of failing the receive.
        consumerConf.setCryptoFailureAction(ConsumerCryptoFailureAction::CONSUME);
    }
    Consumer consumer;
    ASSERT_EQ(ResultOk, client.subscribe(topic, "sub", consumerConf, consumer));

    std::vector<std::string> values;
    for (int i = 0; i < numMessageReceived; i++) {
        Message msg;
        ASSERT_EQ(ResultOk, consumer.receive(msg, 3000));
        const auto context = msg.getEncryptionContext();
        // Only the final message was produced by the unencrypted producer.
        if (i < numMessageReceived - 1) {
            ASSERT_TRUE(context.has_value());
        } else {
            ASSERT_FALSE(context.has_value());
        }
        for (auto&& value : decryptValue(static_cast<const char*>(msg.getData()), msg.getLength(), context)) {
            values.emplace_back(value);
        }
    }
    ASSERT_EQ(values, sentValues);
    ASSERT_EQ(ResultOk, consumer.close());
}

TEST(EncryptionTests, testDecryptionSuccess) {
    // The consumer holds the key reader, so each of the 5 batched messages is delivered on its own.
    // Together with "last-msg" and "unencrypted-msg" that makes 7 receives.
    const std::string topic = "test-decryption-success-" + std::to_string(time(nullptr));
    Client client{lookupUrl};
    testDecryption(client, topic, /* withDecryption= */ true, /* numMessageReceived= */ 7);
    client.close();
}

TEST(EncryptionTests, testDecryptionFailure) {
    // Without a key reader the consumer cannot decrypt the first batch. Its 5 messages therefore arrive together
    // as a single encrypted message, which decryptValue() unpacks. Together with "last-msg" and
    // "unencrypted-msg" that makes 3 receives.
    const std::string topic = "test-decryption-failure-" + std::to_string(time(nullptr));
    Client client{lookupUrl};
    testDecryption(client, topic, /* withDecryption= */ false, /* numMessageReceived= */ 3);
    client.close();
}
