# Copyright (c) 2017 The Verde Developers.
# Distributed under the terms of the BSD 3-Clause License.
# SPDX-License-Identifier: BSD-3-Clause
#
# This code is part of the Fatiando a Terra project (https://www.fatiando.org)
#
"""
Test the coordinate generation functions
"""
import warnings

import numpy as np
import numpy.testing as npt
import pytest

from ..coordinates import (
    check_region,
    grid_coordinates,
    line_coordinates,
    longitude_continuity,
    profile_coordinates,
    rolling_window,
    spacing_to_size,
)


def test_rolling_window_invalid_coordinate_shapes():
    "Coordinate arrays passed to rolling_window must share a single shape"
    one_dimensional = np.arange(10)
    two_dimensional = np.arange(10).reshape((5, 2))
    with pytest.raises(ValueError):
        rolling_window([one_dimensional, two_dimensional], size=2, spacing=1)


def test_rolling_window_empty():
    "Make sure empty windows return an empty index"
    coords = grid_coordinates((-5, -1, 6, 10), spacing=1)
    # Use a larger region than the data so that the first window is empty.
    # Doing this raises a warning about non-overlapping windows. Record every
    # warning with the "always" filter so it neither pollutes the test log nor
    # gets turned into an exception when the run uses an "error" filter.
    with warnings.catch_warnings(record=True):
        warnings.simplefilter("always")
        result = rolling_window(coords, size=0.001, spacing=1, region=(-7, 1, 4, 12))
    # The second element of the output holds the indices for each window
    windows = result[1]
    first_window = windows[0, 0]
    # Separate asserts so a failure reports which index array is non-empty
    assert first_window[0].size == 0
    assert first_window[1].size == 0
    # Make sure we can still index with an empty array
    assert coords[0][first_window].size == 0


def test_rolling_window_warnings():
    "Should warn users if the windows don't overlap"
    coords = grid_coordinates((-5, -1, 6, 10), spacing=1)
    # For exact same size there will be 1 point overlapping so should not warn.
    # Force the "always" filter so the result doesn't depend on the active
    # warning filters: an "ignore" filter would make this check pass
    # vacuously, and an "error" filter would raise instead of recording.
    with warnings.catch_warnings(record=True) as warn:
        warnings.simplefilter("always")
        rolling_window(coords, size=2, spacing=2)
    assert not any(issubclass(w.category, UserWarning) for w in warn)
    args = [dict(spacing=3), dict(spacing=(4, 1)), dict(shape=(1, 2))]
    for arg in args:
        with warnings.catch_warnings(record=True) as warn:
            warnings.simplefilter("always")
            rolling_window(coords, size=2, **arg)
        # Filter out the user warnings from some deprecation warnings that
        # might be thrown by other packages.
        userwarnings = [w for w in warn if issubclass(w.category, UserWarning)]
        assert len(userwarnings) == 1, [str(w.message) for w in userwarnings]
        assert str(userwarnings[0].message).split()[0] == "Rolling"


def test_rolling_window_no_shape_or_spacing():
    """
    Check that an error is raised when neither a shape nor a spacing is given
    """
    expected_message = "Either a shape or a spacing must be provided."
    grid = grid_coordinates((-5, -1, 6, 10), spacing=1)
    with pytest.raises(ValueError, match=expected_message):
        rolling_window(grid, size=2)


def test_rolling_window_oversized_window():
    """
    Check that an error is raised when the window is larger than the region
    """
    oversize = 5
    # Pattern matching both integers and floats (e.g. "5", "-1.0", "+.5").
    # The expected message doesn't change between regions, so build it once.
    number = r"[+-]?([0-9]*[.])?[0-9]+"
    err_msg = (
        rf"Window size '{number}' is larger than dimensions of the region "
        rf"'\({number}, {number}, {number}, {number}\)'."
    )
    regions = [
        (-5, -1, 6, 20),  # window larger than west-east
        (-20, -1, 6, 10),  # window larger than south-north
        (-5, -1, 6, 10),  # window larger than both dims
    ]
    for region in regions:
        grid = grid_coordinates(region, spacing=1)
        with pytest.raises(ValueError, match=err_msg):
            rolling_window(grid, size=oversize, spacing=2)


def test_spacing_to_size():
    "Check that correct size and stop are returned"
    start, stop = -10, 0
    # Each case: (spacing, adjust, expected size, expected stop)
    cases = [
        (2.5, "spacing", 5, stop),
        (2, "spacing", 6, stop),
        (2.6, "spacing", 5, stop),
        (2.4, "spacing", 5, stop),
        # Adjusting the region moves the stop to fit a whole number of spacings
        (2.6, "region", 5, 0.4),
        (2.4, "region", 5, -0.4),
    ]
    for spacing, adjust, expected_size, expected_stop in cases:
        size, new_stop = spacing_to_size(start, stop, spacing=spacing, adjust=adjust)
        npt.assert_allclose(size, expected_size)
        npt.assert_allclose(new_stop, expected_stop)


def test_line_coordinates_fails():
    "Check failures for invalid arguments"
    start, stop = 0, 1
    size, spacing = 10, 0.1
    # Passing exactly one of size or spacing is valid
    for valid in (dict(size=size), dict(spacing=spacing)):
        line_coordinates(start, stop, **valid)
    # Passing neither or both must fail
    for invalid in (dict(), dict(size=size, spacing=spacing)):
        with pytest.raises(ValueError):
            line_coordinates(start, stop, **invalid)


def test_line_coordinates_spacing_larger_than_twice_interval():
    "Check pixel_register and adjust when the spacing exceeds the interval"
    start, stop, spacing = 0, 1, 3
    # Each case: (extra keyword arguments, expected coordinates)
    cases = [
        (dict(), [0, 1]),
        (dict(pixel_register=True), [0.5]),
        (dict(adjust="region"), [0, 3]),
        (dict(pixel_register=True, adjust="region"), [1.5]),
    ]
    for extra, expected in cases:
        result = line_coordinates(start, stop, spacing=spacing, **extra)
        npt.assert_allclose(result, expected)


def test_grid_coordinates_fails():
    "Check failures for invalid arguments"
    region = (0, 1, 0, 10)
    shape = (10, 20)
    spacing = 0.5
    # Make sure it doesn't fail for these parameters
    grid_coordinates(region, shape)
    grid_coordinates(region, spacing=spacing)
    # Each of these argument combinations is invalid
    bad_arguments = [
        dict(shape=shape, spacing=spacing),
        dict(shape=None, spacing=None),
        dict(spacing=spacing, adjust="invalid adjust"),
        dict(spacing=(1, 2, 3)),
    ]
    for kwargs in bad_arguments:
        with pytest.raises(ValueError):
            grid_coordinates(region, **kwargs)


def test_check_region():
    "Make sure an exception is raised for bad regions"
    # Regions are ordered as [w, e, s, n]
    bad_regions = [
        [],  # no elements
        [1, 2, 3, 4, 5],  # five elements
        [1, 2, 3],  # three elements
        [1, 2, 3, 1],  # s > n
        [2, 1, 3, 4],  # w > e
        [-1, -2, -4, -3],  # w > e
        [-2, -1, -2, -3],  # s > n
    ]
    for region in bad_regions:
        with pytest.raises(ValueError):
            check_region(region)


def test_profile_coordinates_fails():
    "Should raise an exception for non-positive profile sizes"
    point1, point2 = (0, 1), (1, 2)
    for bad_size in (0, -10):
        with pytest.raises(ValueError):
            profile_coordinates(point1, point2, size=bad_size)


def test_longitude_continuity():
    "Test continuous boundary conditions in geographic coordinates."
    # Longitudes around the globe in the [0, 360) and [-180, 180) conventions
    longitude_360 = np.linspace(0, 350, 36)
    longitude_180 = np.hstack((longitude_360[:18], longitude_360[18:] - 360))
    latitude = np.linspace(-90, 90, 36)
    s, n = -90, 90
    # Each case: ((w, e) given, (w, e) expected, expected longitude array).
    # Every case is checked with both longitude conventions as input.
    cases = [
        # w, e in [0, 360)
        ((10.5, 20.3), (10.5, 20.3), longitude_360),
        # w, e in [-180, 180)
        ((-20, 20), (-20, 20), longitude_180),
        # Regions around the whole globe
        ((0, 360), (0, 360), longitude_360),
        ((-180, 180), (0, 360), longitude_360),
        ((-20, 340), (0, 360), longitude_360),
        # w == e
        ((20, 20), (20, 20), longitude_360),
        # Angle greater than 180
        ((0, 200), (0, 200), longitude_360),
        ((-160, 160), (-160, 160), longitude_180),
    ]
    for (w, e), (expected_w, expected_e), expected_longitude in cases:
        for longitude in (longitude_360, longitude_180):
            coordinates_new, region_new = longitude_continuity(
                [longitude, latitude], (w, e, s, n)
            )
            w_new, e_new = region_new[:2]
            assert w_new == expected_w
            assert e_new == expected_e
            npt.assert_allclose(coordinates_new[0], expected_longitude)


def test_invalid_geographic_region():
    "Check if invalid region in longitude_continuity raises a ValueError"
    # Regions are ordered as [w, e, s, n]
    invalid_regions = [
        # Latitude outside of its boundaries
        [-10, 10, -200, 90],
        [-10, 10, -90, 200],
        # Longitude outside of its boundaries
        [-200, 0, -10, 10],
        [0, 380, -10, 10],
        # Longitudinal range larger than 360
        [-180, 200, -10, 10],
    ]
    for region in invalid_regions:
        with pytest.raises(ValueError):
            longitude_continuity(None, region)


def test_invalid_geographic_coordinates():
    "Check if invalid coordinates in longitude_continuity raises a ValueError"
    boundaries = [0, 360, -90, 90]
    spacing = 10
    region = [-20, 20, -20, 20]
    # Each case: (index of the coordinate to corrupt, out-of-bounds values).
    # Index 0 is longitude and index 1 is latitude. Fresh coordinates are
    # generated for each coordinate being corrupted.
    for axis, bad_values in ((0, (-200, 400)), (1, (-100, 100))):
        longitude, latitude = grid_coordinates(boundaries, spacing=spacing)
        corrupted = (longitude, latitude)[axis]
        for value in bad_values:
            corrupted[0] = value
            with pytest.raises(ValueError):
                longitude_continuity([longitude, latitude], region)


def test_meshgrid_extra_coords_error():
    "Should raise an exception if meshgrid=False and extra_coords are used"
    arguments = dict(
        region=(0, 1, 0, 3),
        spacing=0.1,
        meshgrid=False,
        extra_coords=10,
    )
    with pytest.raises(ValueError):
        grid_coordinates(**arguments)
