Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

8314063 : The socket is not closed in Connection::createSocket when the handshake failed for LDAP connection #15294

Closed
wants to merge 15 commits into from
Closed
325 changes: 153 additions & 172 deletions src/java.naming/share/classes/com/sun/jndi/ldap/Connection.java
Original file line number Diff line number Diff line change
@@ -25,98 +25,95 @@

package com.sun.jndi.ldap;

import javax.naming.CommunicationException;
import javax.naming.InterruptedNamingException;
import javax.naming.NamingException;
import javax.naming.ServiceUnavailableException;
import javax.naming.ldap.Control;
import javax.net.SocketFactory;
import javax.net.ssl.HandshakeCompletedEvent;
import javax.net.ssl.HandshakeCompletedListener;
import javax.net.ssl.SSLParameters;
import javax.net.ssl.SSLPeerUnverifiedException;
import javax.net.ssl.SSLSocket;
import javax.security.sasl.SaslException;
import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.InterruptedIOException;
import java.io.IOException;
import java.io.OutputStream;
import java.io.InputStream;
import java.io.InterruptedIOException;
import java.io.OutputStream;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.net.InetSocketAddress;
import java.net.Socket;
import javax.net.ssl.SSLSocket;

import javax.naming.CommunicationException;
import javax.naming.ServiceUnavailableException;
import javax.naming.NamingException;
import javax.naming.InterruptedNamingException;

import javax.naming.ldap.Control;

import java.lang.reflect.Method;
import java.lang.reflect.InvocationTargetException;
import java.security.AccessController;
import java.security.PrivilegedAction;
import java.security.cert.Certificate;
import java.security.cert.X509Certificate;
import java.util.Arrays;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import javax.net.SocketFactory;
import javax.net.ssl.SSLParameters;
import javax.net.ssl.HandshakeCompletedEvent;
import javax.net.ssl.HandshakeCompletedListener;
import javax.net.ssl.SSLPeerUnverifiedException;
import javax.security.sasl.SaslException;

/**
* A thread that creates a connection to an LDAP server.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The formatting changes here is not related to the code changes under review in this PR. Could you please revert them?

* After the connection, the thread reads from the connection.
* A caller can invoke methods on the instance to read LDAP responses
* and to send LDAP requests.
* <p>
* There is a one-to-one correspondence between an LdapClient and
* a Connection. Access to Connection and its methods is only via
* LdapClient with two exceptions: SASL authentication and StartTLS.
* SASL needs to access Connection's socket IO streams (in order to do encryption
* of the security layer). StartTLS needs to do replace IO streams
* and close the IO streams on nonfatal close. The code for SASL
* authentication can be treated as being the same as from LdapClient
* because the SASL code is only ever called from LdapClient, from
* inside LdapClient's synchronized authenticate() method. StartTLS is called
* directly by the application but should only occur when the underlying
* connection is quiet.
* <p>
* In terms of synchronization, worry about data structures
* used by the Connection thread because that usage might contend
* with calls by the main threads (i.e., those that call LdapClient).
* Main threads need to worry about contention with each other.
* Fields that Connection thread uses:
* inStream - synced access and update; initialized in constructor;
* referenced outside class unsync'ed (by LdapSasl) only
* when connection is quiet
* traceFile, traceTagIn, traceTagOut - no sync; debugging only
* parent - no sync; initialized in constructor; no updates
* pendingRequests - sync
* pauseLock - per-instance lock;
* paused - sync via pauseLock (pauseReader())
* Members used by main threads (LdapClient):
* host, port - unsync; read-only access for StartTLS and debug messages
* setBound(), setV3() - no sync; called only by LdapClient.authenticate(),
* which is a sync method called only when connection is "quiet"
* getMsgId() - sync
* writeRequest(), removeRequest(),findRequest(), abandonOutstandingReqs() -
* access to shared pendingRequests is sync
* writeRequest(), abandonRequest(), ldapUnbind() - access to outStream sync
* cleanup() - sync
* readReply() - access to sock sync
* unpauseReader() - (indirectly via writeRequest) sync on pauseLock
* Members used by SASL auth (main thread):
* inStream, outStream - no sync; used to construct new stream; accessed
* only when conn is "quiet" and not shared
* replaceStreams() - sync method
* Members used by StartTLS:
* inStream, outStream - no sync; used to record the existing streams;
* accessed only when conn is "quiet" and not shared
* replaceStreams() - sync method
* <p>
* Handles anonymous, simple, and SASL bind for v3; anonymous and simple
* for v2.
* %%% made public for access by LdapSasl %%%
*
* @author Vincent Ryan
* @author Rosanna Lee
* @author Jagane Sundar
*/
* A thread that creates a connection to an LDAP server.
* After the connection, the thread reads from the connection.
* A caller can invoke methods on the instance to read LDAP responses
* and to send LDAP requests.
* <p>
* There is a one-to-one correspondence between an LdapClient and
* a Connection. Access to Connection and its methods is only via
* LdapClient with two exceptions: SASL authentication and StartTLS.
* SASL needs to access Connection's socket IO streams (in order to do encryption
* of the security layer). StartTLS needs to do replace IO streams
* and close the IO streams on nonfatal close. The code for SASL
* authentication can be treated as being the same as from LdapClient
* because the SASL code is only ever called from LdapClient, from
* inside LdapClient's synchronized authenticate() method. StartTLS is called
* directly by the application but should only occur when the underlying
* connection is quiet.
* <p>
* In terms of synchronization, worry about data structures
* used by the Connection thread because that usage might contend
* with calls by the main threads (i.e., those that call LdapClient).
* Main threads need to worry about contention with each other.
* Fields that Connection thread uses:
* inStream - synced access and update; initialized in constructor;
* referenced outside class unsync'ed (by LdapSasl) only
* when connection is quiet
* traceFile, traceTagIn, traceTagOut - no sync; debugging only
* parent - no sync; initialized in constructor; no updates
* pendingRequests - sync
* pauseLock - per-instance lock;
* paused - sync via pauseLock (pauseReader())
* Members used by main threads (LdapClient):
* host, port - unsync; read-only access for StartTLS and debug messages
* setBound(), setV3() - no sync; called only by LdapClient.authenticate(),
* which is a sync method called only when connection is "quiet"
* getMsgId() - sync
* writeRequest(), removeRequest(),findRequest(), abandonOutstandingReqs() -
* access to shared pendingRequests is sync
* writeRequest(), abandonRequest(), ldapUnbind() - access to outStream sync
* cleanup() - sync
* readReply() - access to sock sync
* unpauseReader() - (indirectly via writeRequest) sync on pauseLock
* Members used by SASL auth (main thread):
* inStream, outStream - no sync; used to construct new stream; accessed
* only when conn is "quiet" and not shared
* replaceStreams() - sync method
* Members used by StartTLS:
* inStream, outStream - no sync; used to record the existing streams;
* accessed only when conn is "quiet" and not shared
* replaceStreams() - sync method
* <p>
* Handles anonymous, simple, and SASL bind for v3; anonymous and simple
* for v2.
* %%% made public for access by LdapSasl %%%
*
* @author Vincent Ryan
* @author Rosanna Lee
* @author Jagane Sundar
*/
public final class Connection implements Runnable {

private static final boolean debug = false;
@@ -128,9 +125,9 @@ public final class Connection implements Runnable {
private boolean v3 = true; // Set in setV3()

public final String host; // used by LdapClient for generating exception messages
// used by StartTlsResponse when creating an SSL socket
// used by StartTlsResponse when creating an SSL socket
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same - reformatting of unrelated code

public final int port; // used by LdapClient for generating exception messages
// used by StartTlsResponse when creating an SSL socket
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same - reformatting of unrelated code

// used by StartTlsResponse when creating an SSL socket

private boolean bound = false; // Set in setBound()

@@ -189,6 +186,7 @@ private static boolean hostnameVerificationDisabledValue() {
}
return prop.isEmpty() ? true : Boolean.parseBoolean(prop);
}

// true means v3; false means v2
// Called in LdapClient.authenticate() (which is synchronized)
// when connection is "quiet" and not shared; no need to synchronize
@@ -211,7 +209,7 @@ void setBound() {
////////////////////////////////////////////////////////////////////////////

Connection(LdapClient parent, String host, int port, String socketFactory,
int connectTimeout, int readTimeout, OutputStream trace) throws NamingException {
int connectTimeout, int readTimeout, OutputStream trace) throws NamingException {

this.host = host;
this.port = port;
@@ -243,15 +241,15 @@ void setBound() {
// realException.printStackTrace();

CommunicationException ce =
new CommunicationException(host + ":" + port);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same - reformatting of unrelated code

new CommunicationException(host + ":" + port);
ce.setRootCause(realException);
throw ce;
} catch (Exception e) {
// We need to have a catch all here and
// ignore generic exceptions.
// Also catches all IO errors generated by socket creation.
CommunicationException ce =
new CommunicationException(host + ":" + port);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same - reformatting of unrelated code

new CommunicationException(host + ":" + port);
ce.setRootCause(e);
throw ce;
}
@@ -265,7 +263,7 @@ void setBound() {
* Create an InetSocketAddress using the specified hostname and port number.
*/
private InetSocketAddress createInetSocketAddress(String host, int port) {
return new InetSocketAddress(host, port);
return new InetSocketAddress(host, port);
}

/*
@@ -280,16 +278,13 @@ private InetSocketAddress createInetSocketAddress(String host, int port) {
private Socket createSocket(String host, int port, String socketFactory,
int connectTimeout) throws Exception {

Socket socket = null;
SocketFactory factory = getSocketFactory(socketFactory);
assert factory != null;
//create the socket with default socket factory or custom factory
Socket socket = createConnectionSocket(host, port, factory, connectTimeout);

// the handshake for SSL connection with server and reset timeout for the socket
try {
if (socketFactory != null) {
// create a connected socket with factory
socket = createConnectionSocket(host, port, socketFactory, connectTimeout);
} else {
// create a connected socket without factory
socket = createConnectionSocket(host, port, connectTimeout);
}
//the handshake for SSL connection with server and reset timeout for the socket
initialSSLHandshake(socket, connectTimeout);
} catch (Exception e) {
// 8314063 the socket is not closed after the failure of handshake
@@ -300,65 +295,51 @@ private Socket createSocket(String host, int port, String socketFactory,
return socket;
}

// create a connected socket without factory
private Socket createConnectionSocket(String host, int port, int connectTimeout) throws Exception {

Socket socket = null;
if (connectTimeout > 0) {

InetSocketAddress endpoint = createInetSocketAddress(host, port);
socket = new Socket();

socket.connect(endpoint, connectTimeout);
// get the socket factory, either default or custom
private SocketFactory getSocketFactory(String socketFactoryName) throws Exception {
if (socketFactoryName == null) {
if (debug) {
System.err.println("Connection: creating socket with " +
"a timeout");
System.err.println("Connection: using default SocketFactory");
}
}

// continue (but ignore connectTimeout)
if (socket == null) {
// connected socket
socket = new Socket(host, port);
return SocketFactory.getDefault();
} else {
if (debug) {
System.err.println("Connection: creating socket");
System.err.println("Connection: loading supplied SocketFactory: " + socketFactoryName);
}
@SuppressWarnings("unchecked")
Class<? extends SocketFactory> socketFactoryClass =
(Class<? extends SocketFactory>) Obj.helper.loadClass(socketFactoryName);
Method getDefault =
socketFactoryClass.getMethod("getDefault");
SocketFactory factory = (SocketFactory) getDefault.invoke(null, new Object[]{});
return factory;
}
return socket;
}

// create a connected socket with factory
private Socket createConnectionSocket(String host, int port, String socketFactory,
private Socket createConnectionSocket(String host, int port, SocketFactory factory,
int connectTimeout) throws Exception {
@SuppressWarnings("unchecked")
Class<? extends SocketFactory> socketFactoryClass =
(Class<? extends SocketFactory>) Obj.helper.loadClass(socketFactory);
Method getDefault =
socketFactoryClass.getMethod("getDefault", new Class<?>[]{});
SocketFactory factory = (SocketFactory) getDefault.invoke(null, new Object[]{});
Socket socket = null;

// create the socket
if (connectTimeout > 0) {
// create unconnected socket and then connect it if timeout
// is supplied
InetSocketAddress endpoint =
createInetSocketAddress(host, port);
// unconnected socket
socket = factory.createSocket();
// connected socket
// connect socket with a timeout
socket.connect(endpoint, connectTimeout);
if (debug) {
System.err.println("Connection: creating socket with " +
"a timeout using supplied socket factory");
"a connect timeout");
}
} else {
// continue (but ignore connectTimeout)
if (socket == null) {
// connected socket
socket = factory.createSocket(host, port);
if (debug) {
System.err.println("Connection: creating socket using " +
"supplied socket factory");
}
}
if (socket == null) {
// create connected socket
socket = factory.createSocket(host, port);
if (debug) {
System.err.println("Connection: creating connected socket with" +
" no connect timeout");
}
}
return socket;
@@ -368,7 +349,7 @@ private Socket createConnectionSocket(String host, int port, String socketFactor
// the SSL handshake following socket connection as part of the timeout.
// So explicitly set a socket read timeout, trigger the SSL handshake,
// then reset the timeout.
private void initialSSLHandshake(Socket socket , int connectTimeout) throws Exception {
private void initialSSLHandshake(Socket socket, int connectTimeout) throws Exception {

if (socket instanceof SSLSocket) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

instanceof pattern matching can be used here:

        if (socket instanceof SSLSocket sslSocket) {

SSLSocket sslSocket = (SSLSocket) socket;
@@ -402,15 +383,15 @@ LdapRequest writeRequest(BerEncoder ber, int msgId) throws IOException {
}

LdapRequest writeRequest(BerEncoder ber, int msgId,
boolean pauseAfterReceipt) throws IOException {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reformatting of unrelated code

boolean pauseAfterReceipt) throws IOException {
return writeRequest(ber, msgId, pauseAfterReceipt, -1);
}

LdapRequest writeRequest(BerEncoder ber, int msgId,
boolean pauseAfterReceipt, int replyQueueCapacity) throws IOException {
boolean pauseAfterReceipt, int replyQueueCapacity) throws IOException {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reformatting of unrelated code


LdapRequest req =
new LdapRequest(msgId, pauseAfterReceipt, replyQueueCapacity);
new LdapRequest(msgId, pauseAfterReceipt, replyQueueCapacity);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reformatting of unrelated code

addRequest(req);

if (traceFile != null) {
@@ -450,7 +431,7 @@ BerDecoder readReply(LdapRequest ldr) throws NamingException {
synchronized (this) {
if (sock == null) {
throw new ServiceUnavailableException(host + ":" + port +
"; socket closed");
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reformatting of unrelated code

"; socket closed");
}
}

@@ -462,7 +443,7 @@ BerDecoder readReply(LdapRequest ldr) throws NamingException {
rber = ldr.getReplyBer(readTimeout);
} catch (InterruptedException ex) {
throw new InterruptedNamingException(
"Interrupted during LDAP operation");
"Interrupted during LDAP operation");
} catch (IOException ioe) {
// Connection is timed out OR closed/cancelled
// getReplyBer throws IOException when the requests needs to be abandoned
@@ -553,17 +534,17 @@ void abandonRequest(LdapRequest ldr, Control[] reqCtls) {
//
try {
ber.beginSeq(Ber.ASN_SEQUENCE | Ber.ASN_CONSTRUCTOR);
ber.encodeInt(abandonMsgId);
ber.encodeInt(ldr.msgId, LdapClient.LDAP_REQ_ABANDON);
ber.encodeInt(abandonMsgId);
ber.encodeInt(ldr.msgId, LdapClient.LDAP_REQ_ABANDON);

if (v3) {
LdapClient.encodeControls(ber, reqCtls);
}
if (v3) {
LdapClient.encodeControls(ber, reqCtls);
}
ber.endSeq();

if (traceFile != null) {
Ber.dumpBER(traceFile, traceTagOut, ber.getBuf(), 0,
ber.getDataLen());
ber.getDataLen());
}

synchronized (this) {
@@ -606,19 +587,19 @@ private void ldapUnbind(Control[] reqCtls) {
try {

ber.beginSeq(Ber.ASN_SEQUENCE | Ber.ASN_CONSTRUCTOR);
ber.encodeInt(unbindMsgId);
// IMPLICIT TAGS
ber.encodeByte(LdapClient.LDAP_REQ_UNBIND);
ber.encodeByte(0);
ber.encodeInt(unbindMsgId);
// IMPLICIT TAGS
ber.encodeByte(LdapClient.LDAP_REQ_UNBIND);
ber.encodeByte(0);

if (v3) {
LdapClient.encodeControls(ber, reqCtls);
}
if (v3) {
LdapClient.encodeControls(ber, reqCtls);
}
ber.endSeq();

if (traceFile != null) {
Ber.dumpBER(traceFile, traceTagOut, ber.getBuf(),
0, ber.getDataLen());
0, ber.getDataLen());
}

synchronized (this) {
@@ -634,14 +615,14 @@ private void ldapUnbind(Control[] reqCtls) {
}

/**
* @param reqCtls Possibly null request controls that accompanies the
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reformatting of unrelated code

* abandon and unbind LDAP request.
* @param reqCtls Possibly null request controls that accompanies the
* abandon and unbind LDAP request.
* @param notifyParent true means to call parent LdapClient back, notifying
* it that the connection has been closed; false means not to notify
* parent. If LdapClient invokes cleanup(), notifyParent should be set to
* false because LdapClient already knows that it is closing
* the connection. If Connection invokes cleanup(), notifyParent should be
* set to true because LdapClient needs to know about the closure.
* it that the connection has been closed; false means not to notify
* parent. If LdapClient invokes cleanup(), notifyParent should be set to
* false because LdapClient already knows that it is closing
* the connection. If Connection invokes cleanup(), notifyParent should be
* set to true because LdapClient needs to know about the closure.
*/
void cleanup(Control[] reqCtls, boolean notifyParent) {
boolean nparent = false;
@@ -691,10 +672,10 @@ void cleanup(Control[] reqCtls, boolean notifyParent) {
LdapRequest ldr = pendingRequests;
while (ldr != null) {
ldr.close();
ldr = ldr.next;
}
ldr = ldr.next;
}
}
}
if (nparent) {
parent.processConnectionClosure();
}
@@ -848,23 +829,23 @@ private void unpauseReader() throws IOException {
if (paused) {
if (debug) {
System.err.println("Unpausing reader; read from: " +
inStream);
inStream);
}
paused = false;
pauseLock.notify();
}
}
}

/*
/*
* Pauses reader so that it stops reading from the input stream.
* Reader blocks on pauseLock instead of read().
* MUST be called from within synchronized (pauseLock) clause.
*/
private void pauseReader() throws IOException {
if (debug) {
System.err.println("Pausing reader; was reading from: " +
inStream);
inStream);
}
paused = true;
try {
@@ -945,8 +926,8 @@ public void run() {

// Read all length bytes
while (bytesread < seqlenlen) {
br = in.read(inbuf, offset+bytesread,
seqlenlen-bytesread);
br = in.read(inbuf, offset + bytesread,
seqlenlen - bytesread);
if (br < 0) {
eos = true;
break; // EOF
@@ -960,8 +941,8 @@ public void run() {

// Add contents of length bytes to determine length
seqlen = 0;
for( int i = 0; i < seqlenlen; i++) {
seqlen = (seqlen << 8) + (inbuf[offset+i] & 0xff);
for (int i = 0; i < seqlenlen; i++) {
seqlen = (seqlen << 8) + (inbuf[offset + i] & 0xff);
}
offset += bytesread;
}
@@ -1045,7 +1026,7 @@ public void run() {

if (debug) {
System.err.println("Connection: end-of-stream detected: "
+ in);
+ in);
}
} catch (IOException ex) {
if (debug) {
@@ -1061,8 +1042,7 @@ public void run() {
}

private static byte[] readFully(InputStream is, int length)
throws IOException
{
throws IOException {
byte[] buf = new byte[Math.min(length, 8192)];
int nread = 0;
while (nread < length) {
@@ -1106,7 +1086,7 @@ public synchronized void setHandshakeCompletedListener(SSLSocket sslSocket) {
}

public X509Certificate getTlsServerCertificate()
throws SaslException {
throws SaslException {
try {
if (isTlsConnection() && tlsHandshakeListener != null)
return tlsHandshakeListener.tlsHandshakeCompleted.get();
@@ -1122,6 +1102,7 @@ private class HandshakeListener implements HandshakeCompletedListener {

private final CompletableFuture<X509Certificate> tlsHandshakeCompleted =
new CompletableFuture<>();

@Override
public void handshakeCompleted(HandshakeCompletedEvent event) {
try {