/*
 * Decompiled with CFR 0.152.
 */
package org.apache.celeborn.common.network.sasl.registration;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.List;
import org.apache.celeborn.common.network.client.RpcResponseCallback;
import org.apache.celeborn.common.network.client.TransportClient;
import org.apache.celeborn.common.network.protocol.RequestMessage;
import org.apache.celeborn.common.network.protocol.RpcFailure;
import org.apache.celeborn.common.network.protocol.RpcRequest;
import org.apache.celeborn.common.network.protocol.TransportMessage;
import org.apache.celeborn.common.network.sasl.CelebornSaslServer;
import org.apache.celeborn.common.network.sasl.SaslRpcHandler;
import org.apache.celeborn.common.network.sasl.SecretRegistry;
import org.apache.celeborn.common.network.server.BaseMessageHandler;
import org.apache.celeborn.common.network.util.TransportConf;
import org.apache.celeborn.common.protocol.MessageType;
import org.apache.celeborn.common.protocol.PbAuthType;
import org.apache.celeborn.common.protocol.PbAuthenticationInitiationRequest;
import org.apache.celeborn.common.protocol.PbAuthenticationInitiationResponse;
import org.apache.celeborn.common.protocol.PbRegisterApplicationRequest;
import org.apache.celeborn.common.protocol.PbRegisterApplicationResponse;
import org.apache.celeborn.common.protocol.PbSaslMechanism;
import org.apache.celeborn.common.protocol.PbSaslRequest;
import org.apache.celeborn.shaded.com.google.common.base.Throwables;
import org.apache.celeborn.shaded.com.google.common.collect.Lists;
import org.apache.celeborn.shaded.io.netty.channel.Channel;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Server-side handler that runs the application registration handshake on a connection before any
 * other traffic is handed to the wrapped {@code delegate}.
 *
 * <p>The handshake advances through {@link RegistrationState}:
 *
 * <ul>
 *   <li>{@code NONE} — nothing received yet.
 *   <li>{@code INIT} — the authentication initiation request has been answered.
 *   <li>{@code AUTHENTICATED} — the ANONYMOUS SASL client-auth exchange has completed.
 *   <li>{@code REGISTERED} — the application id and secret have been stored in the
 *       {@link SecretRegistry}.
 * </ul>
 *
 * <p>Any exception raised while processing a handshake message moves the handler to
 * {@code FAILED}, after which every state check rejects further handshake steps. Once the handler
 * is {@code REGISTERED}, or the wrapped {@link SaslRpcHandler} reports the connection as
 * authenticated, messages are forwarded to {@code delegate} unchanged.
 *
 * <p>NOTE(review): the mutable handshake state is not synchronized. This presumably relies on one
 * instance per channel being driven from that channel's event loop — confirm.
 */
public class RegistrationRpcHandler extends BaseMessageHandler {
    private static final Logger LOG = LoggerFactory.getLogger(RegistrationRpcHandler.class);

    /** Protocol version advertised in the authentication initiation response. */
    private static final String VERSION = "1.0";

    /**
     * SASL mechanisms advertised to clients: ANONYMOUS for client (application) authentication and
     * DIGEST-MD5 for connection authentication.
     */
    private static final List<PbSaslMechanism> SASL_MECHANISMS = Lists.newArrayList(
        PbSaslMechanism.newBuilder()
            .setMechanism("ANONYMOUS")
            .addAuthTypes(PbAuthType.CLIENT_AUTH)
            .build(),
        PbSaslMechanism.newBuilder()
            .setMechanism("DIGEST-MD5")
            .addAuthTypes(PbAuthType.CONNECTION_AUTH)
            .build());

    /** Transport configuration; used to advertise whether auth is enabled and given to the SASL handler. */
    private final TransportConf conf;
    /** The channel this handler serves; handed to the wrapped SASL handler. */
    private final Channel channel;
    /** Handler that receives all traffic once the connection is registered or authenticated. */
    private final BaseMessageHandler delegate;
    /** Current step of the registration handshake. */
    private RegistrationState registrationState = RegistrationState.NONE;
    /** Store of per-application secrets; written when an application registers. */
    private final SecretRegistry secretRegistry;
    /** Handles SASL requests whose auth type is not CLIENT_AUTH (connection authentication). */
    private final SaslRpcHandler saslHandler;
    /** Server side of the ANONYMOUS client-auth exchange; {@code null} when none is in progress. */
    private CelebornSaslServer saslServer = null;

    /**
     * Creates a handler for one connection.
     *
     * @param conf transport configuration
     * @param channel the channel being served
     * @param delegate handler that receives traffic after the handshake
     * @param secretRegistry registry in which application secrets are stored
     */
    public RegistrationRpcHandler(
            TransportConf conf, Channel channel, BaseMessageHandler delegate, SecretRegistry secretRegistry) {
        this.conf = conf;
        this.channel = channel;
        this.secretRegistry = secretRegistry;
        this.delegate = delegate;
        this.saslHandler = new SaslRpcHandler(conf, channel, delegate, secretRegistry);
    }

    @Override
    public boolean checkRegistered() {
        return this.delegate.checkRegistered();
    }

    /**
     * Forwards the message to the delegate once the handshake is done. Otherwise it processes the
     * message as a handshake step.
     *
     * <p>On a handshake failure the handler moves to {@code FAILED}, and an {@link RpcFailure}
     * carrying the stack trace is written to the client's channel.
     *
     * @throws SecurityException if the handshake is not done and {@code message} is not an
     *     {@link RpcRequest}
     */
    @Override
    public final void receive(TransportClient client, RequestMessage message, RpcResponseCallback callback) {
        if (this.isHandshakeComplete()) {
            LOG.trace("Already authenticated. Delegating {}", client.getClientId());
            this.delegate.receive(client, message, callback);
            return;
        }
        if (!(message instanceof RpcRequest)) {
            // Handshake steps only arrive as RPC requests; anything else is an unauthenticated call.
            throw new SecurityException("Unauthenticated call to receive().");
        }
        RpcRequest rpcRequest = (RpcRequest) message;
        try {
            this.processRpcMessage(client, rpcRequest, callback);
        } catch (Exception e) {
            LOG.error("Error while invoking RpcHandler#receive() on RPC id {}", rpcRequest.requestId, e);
            this.registrationState = RegistrationState.FAILED;
            // NOTE(review): the full server stack trace goes to a peer that has not authenticated;
            // consider whether a shorter message would be safer.
            client.getChannel().writeAndFlush(
                new RpcFailure(rpcRequest.requestId, Throwables.getStackTraceAsString(e)));
        }
    }

    /**
     * Forwards a one-way message to the delegate.
     *
     * @throws SecurityException if the handshake has not been completed
     */
    @Override
    public final void receive(TransportClient client, RequestMessage message) {
        if (!this.isHandshakeComplete()) {
            throw new SecurityException("Unauthenticated call to receive().");
        }
        if (LOG.isTraceEnabled()) {
            LOG.trace("Already authenticated. Delegating {}", client.getClientId());
        }
        this.delegate.receive(client, message);
    }

    /** Returns whether traffic may bypass the handshake and go straight to the delegate. */
    private boolean isHandshakeComplete() {
        return this.registrationState == RegistrationState.REGISTERED || this.saslHandler.isAuthenticated();
    }

    /**
     * Decodes {@code message} and runs the matching handshake step.
     *
     * @throws IllegalStateException if the step is not allowed in the current state
     * @throws SecurityException if the message type is not a handshake message
     */
    private void processRpcMessage(TransportClient client, RpcRequest message, RpcResponseCallback callback)
            throws IOException {
        TransportMessage pbMsg = TransportMessage.fromByteBuffer(message.body().nioByteBuffer());
        switch (pbMsg.getMessageTypeValue()) {
            case MessageType.AUTHENTICATION_INITIATION_REQUEST_VALUE: {
                // The payload is decoded (and cast) but its contents are not otherwise used here.
                PbAuthenticationInitiationRequest authInitRequest =
                    (PbAuthenticationInitiationRequest) pbMsg.getParsedPayload();
                this.checkRequestAllowed(RegistrationState.NONE);
                this.respondToAuthInitialization(callback);
                this.registrationState = RegistrationState.INIT;
                LOG.trace("Authentication initialization completed: rpcId {}", message.requestId);
                break;
            }
            case MessageType.SASL_REQUEST_VALUE: {
                PbSaslRequest saslRequest = (PbSaslRequest) pbMsg.getParsedPayload();
                if (saslRequest.getAuthType() != PbAuthType.CLIENT_AUTH) {
                    // Connection authentication is handled entirely by the SASL handler.
                    LOG.trace("Delegating to sasl handler: rpcId {}", message.requestId);
                    this.saslHandler.receive(client, message, callback);
                    break;
                }
                LOG.trace("Received Sasl Message for client authentication");
                this.checkRequestAllowed(RegistrationState.INIT);
                this.authenticateClient(saslRequest, callback);
                // A SASL exchange may need several round trips; stay in INIT until it completes.
                if (this.saslServer.isComplete()) {
                    LOG.debug("SASL authentication successful for channel {}", client);
                    this.complete();
                    this.registrationState = RegistrationState.AUTHENTICATED;
                    LOG.trace("Client authenticated: rpcId {}", message.requestId);
                }
                break;
            }
            case MessageType.REGISTER_APPLICATION_REQUEST_VALUE: {
                PbRegisterApplicationRequest registerApplicationRequest =
                    (PbRegisterApplicationRequest) pbMsg.getParsedPayload();
                this.checkRequestAllowed(RegistrationState.AUTHENTICATED);
                LOG.trace("Application registration started {}", registerApplicationRequest.getId());
                this.processRegisterApplicationRequest(registerApplicationRequest, callback);
                this.registrationState = RegistrationState.REGISTERED;
                client.setClientId(registerApplicationRequest.getId());
                LOG.info("Application registered: appId {} rpcId {}",
                    registerApplicationRequest.getId(), message.requestId);
                break;
            }
            default:
                throw new SecurityException(
                    "The app is not registered and the connection is not authenticated " + message.requestId);
        }
    }

    /**
     * Ensures the handshake is in {@code expectedState}.
     *
     * @throws IllegalStateException if the current state differs
     */
    private void checkRequestAllowed(RegistrationState expectedState) {
        if (this.registrationState != expectedState) {
            throw new IllegalStateException(
                "Invalid registration state. Expected: " + expectedState + ", Actual: " + this.registrationState);
        }
    }

    /** Replies with whether auth is enabled, the protocol version, and the supported SASL mechanisms. */
    private void respondToAuthInitialization(RpcResponseCallback callback) {
        PbAuthenticationInitiationResponse response = PbAuthenticationInitiationResponse.newBuilder()
            .setAuthEnabled(this.conf.authEnabled())
            .setVersion(VERSION)
            .addAllSaslMechanisms(SASL_MECHANISMS)
            .build();
        TransportMessage message =
            new TransportMessage(MessageType.AUTHENTICATION_INITIATION_RESPONSE, response.toByteArray());
        callback.onSuccess(message.toByteBuffer());
    }

    /**
     * Feeds one client-auth SASL token to the ANONYMOUS server and replies with its response. The
     * server is created lazily on the first token.
     *
     * @throws IllegalArgumentException if the exchange has already completed
     */
    private void authenticateClient(PbSaslRequest saslMessage, RpcResponseCallback callback) {
        if (this.saslServer != null && this.saslServer.isComplete()) {
            throw new IllegalArgumentException("Unexpected message type " + saslMessage.toString());
        }
        if (this.saslServer == null) {
            this.saslServer = new CelebornSaslServer("ANONYMOUS", null, null);
        }
        byte[] response = this.saslServer.response(saslMessage.getPayload().toByteArray());
        callback.onSuccess(ByteBuffer.wrap(response));
    }

    /**
     * Stores the application's secret and replies with a successful registration response.
     *
     * @throws IllegalStateException if the application id is already registered
     */
    private void processRegisterApplicationRequest(
            PbRegisterApplicationRequest registerApplicationRequest, RpcResponseCallback callback) {
        // NOTE(review): check-then-register on a registry that other connections may share; two
        // concurrent registrations of the same id could both pass the check — confirm whether
        // SecretRegistry offers an atomic register-if-absent.
        if (this.secretRegistry.isRegistered(registerApplicationRequest.getId())) {
            throw new IllegalStateException("Application is already registered " + registerApplicationRequest.getId());
        }
        this.secretRegistry.register(registerApplicationRequest.getId(), registerApplicationRequest.getSecret());
        PbRegisterApplicationResponse response = PbRegisterApplicationResponse.newBuilder().setStatus(true).build();
        TransportMessage message =
            new TransportMessage(MessageType.REGISTER_APPLICATION_RESPONSE, response.toByteArray());
        callback.onSuccess(message.toByteBuffer());
    }

    /** Notifies the delegate, then releases SASL state even if the delegate throws. */
    @Override
    public void channelInactive(TransportClient client) {
        try {
            this.delegate.channelInactive(client);
        } finally {
            this.cleanup();
        }
    }

    @Override
    public void exceptionCaught(Throwable cause, TransportClient client) {
        this.delegate.exceptionCaught(cause, client);
    }

    /** Releases SASL state once client authentication has finished. */
    private void complete() {
        this.cleanup();
    }

    /**
     * Disposes the client-auth SASL server (logging, not propagating, runtime errors) and cleans up
     * the wrapped SASL handler.
     */
    private void cleanup() {
        if (this.saslServer != null) {
            try {
                this.saslServer.dispose();
            } catch (RuntimeException e) {
                LOG.error("Error while disposing SASL server", e);
            } finally {
                this.saslServer = null;
            }
        }
        this.saslHandler.cleanup();
    }

    /** Steps of the registration handshake. */
    private enum RegistrationState {
        NONE,
        INIT,
        AUTHENTICATED,
        REGISTERED,
        FAILED
    }
}

