//
// ========================================================================
// Copyright (c) 1995-2020 Mort Bay Consulting Pty Ltd and others.
//
// This program and the accompanying materials are made available under
// the terms of the Eclipse Public License 2.0 which is available at
// https://www.eclipse.org/legal/epl-2.0
//
// This Source Code may also be made available under the following
// Secondary Licenses when the conditions for such availability set
// forth in the Eclipse Public License, v. 2.0 are satisfied:
// the Apache License v2.0 which is available at
// https://www.apache.org/licenses/LICENSE-2.0
//
// SPDX-License-Identifier: EPL-2.0 OR Apache-2.0
// ========================================================================
//

package org.eclipse.jetty.websocket.core;

import java.net.Socket;
import java.util.concurrent.Exchanger;
import java.util.concurrent.TimeUnit;
import java.util.function.BiFunction;

import org.eclipse.jetty.logging.StacklessLogging;
import org.eclipse.jetty.util.Callback;
import org.eclipse.jetty.websocket.core.internal.Parser;
import org.eclipse.jetty.websocket.core.internal.WebSocketCoreSession;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import static org.eclipse.jetty.util.Callback.NOOP;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.notNullValue;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;

/**
 * Tests of a core server with a fake (raw socket) client, exercising what the server
 * handler may and may not do from its {@code onOpen} callback: sending frames, failing
 * the open callback, closing the session, and completing the open asynchronously.
 */
public class WebSocketOpenTest extends WebSocketTester
{
    private static final Logger LOG = LoggerFactory.getLogger(WebSocketOpenTest.class);

    private WebSocketServer server;
    private DemandingAsyncFrameHandler serverHandler;
    private Socket client;

    /**
     * Closes the raw client socket and stops the server after each test.
     * The server is stopped even if closing the socket throws.
     */
    @AfterEach
    public void after() throws Exception
    {
        try
        {
            // Close the raw client socket so it is not leaked between tests.
            if (client != null)
                client.close();
        }
        finally
        {
            if (server != null)
                server.stop();
        }
    }

    /**
     * Starts a server whose handler delegates {@code onOpen} to the given function, then
     * connects a raw client socket to the server's local port.
     *
     * @param onOpen invoked from the server handler's {@code onOpen} with the session and the
     * open callback; its return value is ignored
     * @throws Exception if the server fails to start or the client fails to connect
     */
    public void setup(BiFunction<CoreSession, Callback, Void> onOpen) throws Exception
    {
        serverHandler = new DemandingAsyncFrameHandler(onOpen);
        server = new WebSocketServer(serverHandler);
        server.start();
        client = newClient(server.getLocalPort());
    }

    /**
     * A frame sent from within onOpen reaches the client. After demanding, a client CLOSE
     * is answered with a NORMAL close.
     */
    @Test
    public void testSendFrameInOnOpen() throws Exception
    {
        setup((s, c) ->
        {
            assertThat(s.toString(), containsString("CONNECTED"));
            s.sendFrame(new Frame(OpCode.TEXT, "Hello"), NOOP, false);
            c.succeeded();
            s.demand(1);
            return null;
        });
        Parser.ParsedFrame frame = receiveFrame(client.getInputStream());
        assertThat(frame.getPayloadAsUTF8(), is("Hello"));

        client.getOutputStream().write(RawFrameBuilder.buildClose(new CloseStatus(CloseStatus.NORMAL), true));
        assertTrue(serverHandler.closeLatch.await(5, TimeUnit.SECONDS));
        assertThat(serverHandler.closeStatus.getCode(), is(CloseStatus.NORMAL));

        frame = receiveFrame(client.getInputStream());
        assertThat(frame.getOpCode(), is(OpCode.CLOSE));
        assertThat(new CloseStatus(frame).getCode(), is(CloseStatus.NORMAL));
    }

    /**
     * Failing the open callback records an error on the handler and closes the session
     * with SERVER_ERROR, which is also sent to the client.
     */
    @Test
    public void testFailureInOnOpen() throws Exception
    {
        try (StacklessLogging stackless = new StacklessLogging(WebSocketCoreSession.class))
        {
            setup((s, c) ->
            {
                assertThat(s.toString(), containsString("CONNECTED"));
                c.failed(new Exception("Test Exception in onOpen"));
                return null;
            });

            assertTrue(serverHandler.closeLatch.await(5, TimeUnit.SECONDS));
            assertThat(serverHandler.error, notNullValue());
            assertThat(serverHandler.closeStatus.getCode(), is(CloseStatus.SERVER_ERROR));

            Parser.ParsedFrame frame = receiveFrame(client.getInputStream());
            assertThat(frame.getOpCode(), is(OpCode.CLOSE));
            assertThat(new CloseStatus(frame).getCode(), is(CloseStatus.SERVER_ERROR));
        }
    }

    /**
     * Closing the session from onOpen sends a SHUTDOWN close to the client. The handler's
     * recorded close status stays SHUTDOWN after the client replies with NORMAL.
     */
    @Test
    public void testCloseInOnOpen() throws Exception
    {
        setup((s, c) ->
        {
            assertThat(s.toString(), containsString("CONNECTED"));
            s.close(CloseStatus.SHUTDOWN, "Test close in onOpen", c);
            return null;
        });

        Parser.ParsedFrame frame = receiveFrame(client.getInputStream());
        assertThat(frame.getOpCode(), is(OpCode.CLOSE));
        assertThat(new CloseStatus(frame).getCode(), is(CloseStatus.SHUTDOWN));

        client.getOutputStream().write(RawFrameBuilder.buildClose(new CloseStatus(CloseStatus.NORMAL), true));
        assertTrue(serverHandler.closeLatch.await(5, TimeUnit.SECONDS));
        assertThat(serverHandler.closeStatus.getCode(), is(CloseStatus.SHUTDOWN));
    }

    /**
     * The open callback is handed to the test thread and completed later. While it is
     * pending the session can send but does not receive and rejects demand. Once it
     * succeeds and demand is issued, frames are received and the close completes normally.
     */
    @Test
    public void testAsyncOnOpen() throws Exception
    {
        Exchanger<CoreSession> sx = new Exchanger<>();
        Exchanger<Callback> cx = new Exchanger<>();
        setup((s, c) ->
        {
            assertThat(s.toString(), containsString("CONNECTED"));
            try
            {
                sx.exchange(s);
                cx.exchange(c);
            }
            catch (InterruptedException e)
            {
                // Restore the interrupt flag before surfacing the failure.
                Thread.currentThread().interrupt();
                throw new RuntimeException(e);
            }
            return null;
        });

        // Bounded waits: if onOpen is never reached, fail instead of hanging the test forever.
        CoreSession coreSession = sx.exchange(null, 5, TimeUnit.SECONDS);
        Callback onOpenCallback = cx.exchange(null, 5, TimeUnit.SECONDS);
        // NOTE(review): presumably gives the server thread time to return from onOpen
        // while the open callback is still pending.
        Thread.sleep(100);

        // Can send while onOpen is active
        coreSession.sendFrame(new Frame(OpCode.TEXT, "Hello"), NOOP, false);
        Parser.ParsedFrame frame = receiveFrame(client.getInputStream());
        assertThat(frame.getPayloadAsUTF8(), is("Hello"));

        // But cannot receive
        client.getOutputStream().write(RawFrameBuilder.buildClose(new CloseStatus(CloseStatus.NORMAL), true));
        assertFalse(serverHandler.closeLatch.await(1, TimeUnit.SECONDS));

        // Can't demand until open
        assertThrows(Throwable.class, () -> coreSession.demand(1));
        client.getOutputStream().write(RawFrameBuilder.buildClose(new CloseStatus(CloseStatus.NORMAL), true));
        assertFalse(serverHandler.closeLatch.await(1, TimeUnit.SECONDS));

        // Succeeded moves to OPEN state and still does not read CLOSE frame
        onOpenCallback.succeeded();
        assertThat(coreSession.toString(), containsString("OPEN"));

        // Demand starts receiving frames
        coreSession.demand(1);
        client.getOutputStream().write(RawFrameBuilder.buildClose(new CloseStatus(CloseStatus.NORMAL), true));
        assertTrue(serverHandler.closeLatch.await(5, TimeUnit.SECONDS));

        // Close handled normally
        assertThat(serverHandler.closeStatus.getCode(), is(CloseStatus.NORMAL));
        frame = receiveFrame(client.getInputStream());
        assertThat(frame.getOpCode(), is(OpCode.CLOSE));
        assertThat(new CloseStatus(frame).getCode(), is(CloseStatus.NORMAL));
    }

    /**
     * Test handler that records the session, delegates {@code onOpen} to a supplied function
     * and then counts down {@code openLatch}. It reports itself as demanding; the tests rely
     * on this to control when incoming frames are read via {@code demand}.
     */
    static class DemandingAsyncFrameHandler extends TestAsyncFrameHandler
    {
        private final BiFunction<CoreSession, Callback, Void> onOpen;

        DemandingAsyncFrameHandler(BiFunction<CoreSession, Callback, Void> onOpen)
        {
            this.onOpen = onOpen;
        }

        @Override
        public void onOpen(CoreSession coreSession, Callback callback)
        {
            if (LOG.isDebugEnabled())
                LOG.debug("[{}] onOpen {}", name, coreSession);
            this.coreSession = coreSession;
            // The test-supplied function is responsible for completing the callback.
            onOpen.apply(coreSession, callback);
            openLatch.countDown();
        }

        @Override
        public boolean isDemanding()
        {
            return true;
        }
    }
}
