package b4j.ssh;
import anywheresoftware.b4a.BA.ShortName;
import anywheresoftware.b4a.BA.Version;
import com.jcraft.jsch.*;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.ByteBuffer;
import java.nio.CharBuffer;
import java.nio.charset.*;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.Properties;
import java.util.concurrent.LinkedBlockingDeque;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicLong;
/**
* Minimal deterministic SSH shell wrapper for B4J.
* Provides a simple, safe, predictable API for interactive SSH automation.
*/
@ShortName("B4JSSH")
@Version(3.2f)
public class B4JSSH {
// -------------------------------------------------------------------------
// Fields
// -------------------------------------------------------------------------
private JSch jsch;
private Session session;
private ChannelShell shell;
private InputStream in;
private OutputStream out;
private final LinkedBlockingDeque<Chunk> queue = new LinkedBlockingDeque<>();
private final byte[] readerBuf = new byte[4096];
private final CharsetDecoder decoder =
StandardCharsets.UTF_8.newDecoder()
.onMalformedInput(CodingErrorAction.REPORT)
.onUnmappableCharacter(CodingErrorAction.REPORT);
private final ByteBuffer decodeBytes = ByteBuffer.allocate(4096 + 3);
private final CharBuffer decodeChars = CharBuffer.allocate(8192);
private Thread readerThread;
private final AtomicLong generation = new AtomicLong(0);
private static final class Chunk {
final String data;
final boolean eof;
final long gen;
Chunk(String d, boolean eof, long gen) {
this.data = d;
this.eof = eof;
this.gen = gen;
}
static Chunk data(String s, long g) { return new Chunk(s, false, g); }
static Chunk eof(long g) { return new Chunk(null, true, g); }
}
private static final int INTER_CHUNK_IDLE_MS = 150;
// -------------------------------------------------------------------------
// Lifecycle
// -------------------------------------------------------------------------
public void Initialize() {
if (jsch != null)
throw new SshException(SshException.ErrorKind.INTERNAL,
"Initialize() called more than once.");
jsch = new JSch();
}
public void Connect(String host, int port, String user, String pass, int timeoutMs) {
Objects.requireNonNull(host, "host");
Objects.requireNonNull(user, "user");
Objects.requireNonNull(pass, "pass");
if (port <= 0 || port > 65535)
throw new SshException(SshException.ErrorKind.INTERNAL,
"Invalid SSH port: " + port);
if (timeoutMs <= 0)
throw new SshException(SshException.ErrorKind.INTERNAL,
"timeoutMs must be > 0. Value=" + timeoutMs);
if (jsch == null)
throw new SshException(SshException.ErrorKind.INTERNAL,
"Connect() called before Initialize().");
try {
session = jsch.getSession(user, host, port);
session.setPassword(pass);
Properties cfg = new Properties();
cfg.put("StrictHostKeyChecking", "no");
cfg.put("PreferredAuthentications", "password");
session.setConfig(cfg);
session.connect(timeoutMs);
} catch (JSchException ex) {
safeDisconnect();
throw new SshException(SshException.ErrorKind.CONNECT,
"SSH connection failed: " + ex.getMessage(), ex);
}
}
public void OpenShell() {
if (session == null || !session.isConnected())
throw new SshException(SshException.ErrorKind.INTERNAL,
"OpenShell() called when session is not connected.");
try {
Channel ch = session.openChannel("shell");
if (!(ch instanceof ChannelShell)) {
ch.disconnect();
throw new SshException(SshException.ErrorKind.PROTOCOL,
"Expected ChannelShell but got: " + ch.getClass().getName());
}
shell = (ChannelShell) ch;
shell.setPty(true);
in = shell.getInputStream();
out = shell.getOutputStream();
shell.connect();
} catch (JSchException | IOException ex) {
safeCloseShell();
throw new SshException(SshException.ErrorKind.PROTOCOL,
"Failed to open shell: " + ex.getMessage(), ex);
}
queue.clear();
long myGen = generation.incrementAndGet();
if (myGen == Long.MAX_VALUE)
generation.set(1);
startReader(myGen);
}
public void Disconnect() {
safeDisconnect();
}
// -------------------------------------------------------------------------
// I/O
// -------------------------------------------------------------------------
public void Write(String cmd) {
Objects.requireNonNull(cmd, "cmd");
ensureShell();
try {
out.write(cmd.getBytes(StandardCharsets.UTF_8));
out.write('\n');
out.flush();
} catch (IOException ex) {
throw new SshException(SshException.ErrorKind.INTERNAL,
"Write() I/O error: " + ex.getMessage(), ex);
}
}
public void WriteRaw(String data) {
Objects.requireNonNull(data, "data");
ensureShell();
try {
out.write(data.getBytes(StandardCharsets.UTF_8));
out.flush();
} catch (IOException ex) {
throw new SshException(SshException.ErrorKind.INTERNAL,
"WriteRaw() I/O error: " + ex.getMessage(), ex);
}
}
public String ReadWindow(int timeoutMs) {
ensureShell();
if (timeoutMs <= 0)
throw new SshException(SshException.ErrorKind.INTERNAL,
"timeoutMs must be > 0. Value=" + timeoutMs);
StringBuilder sb = new StringBuilder(512);
long myGen = generation.get();
long deadline = System.nanoTime() + TimeUnit.MILLISECONDS.toNanos(timeoutMs);
boolean gotAny = false;
while (true) {
long now = System.nanoTime();
if (now >= deadline)
return sb.toString();
long remMs = TimeUnit.NANOSECONDS.toMillis(deadline - now);
long pollMs = gotAny ? Math.min(INTER_CHUNK_IDLE_MS, remMs) : remMs;
if (pollMs < 1) pollMs = 1;
Chunk c;
try {
c = queue.poll(pollMs, TimeUnit.MILLISECONDS);
} catch (InterruptedException ie) {
Thread.currentThread().interrupt();
throw new SshException(SshException.ErrorKind.TIMEOUT,
"ReadWindow interrupted. Buffer: [" + sb + "]", ie);
}
if (c == null)
return sb.toString();
if (c.gen != myGen)
continue;
if (c.eof)
throw new SshException(SshException.ErrorKind.REMOTE_CLOSED,
"Remote closed during ReadWindow. Buffer: [" + sb + "]");
sb.append(c.data);
gotAny = true;
}
}
public String ReadUntil(String[] prompts, int timeoutMs) {
ensureShell();
if (prompts == null || prompts.length == 0)
throw new SshException(SshException.ErrorKind.INTERNAL,
"prompts must not be null or empty.");
if (timeoutMs <= 0)
throw new SshException(SshException.ErrorKind.INTERNAL,
"timeoutMs must be > 0. Value=" + timeoutMs);
List<String> list = new ArrayList<>();
int maxLen = 0;
for (String p : prompts) {
if (p == null || p.isEmpty())
throw new SshException(SshException.ErrorKind.INTERNAL,
"prompts must not contain null or empty strings.");
list.add(p);
if (p.length() > maxLen) maxLen = p.length();
}
StringBuilder sb = new StringBuilder(512);
long myGen = generation.get();
long deadline = System.nanoTime() + TimeUnit.MILLISECONDS.toNanos(timeoutMs);
while (true) {
long now = System.nanoTime();
if (now >= deadline)
throw new SshException(SshException.ErrorKind.TIMEOUT,
"ReadUntil timed out. Buffer: [" + sb + "]");
long remMs = TimeUnit.NANOSECONDS.toMillis(deadline - now);
if (remMs < 1) remMs = 1;
Chunk c;
try {
c = queue.poll(remMs, TimeUnit.MILLISECONDS);
} catch (InterruptedException ie) {
Thread.currentThread().interrupt();
throw new SshException(SshException.ErrorKind.TIMEOUT,
"ReadUntil interrupted. Buffer: [" + sb + "]", ie);
}
if (c == null)
throw new SshException(SshException.ErrorKind.TIMEOUT,
"ReadUntil timed out. Buffer: [" + sb + "]");
if (c.gen != myGen)
continue;
if (c.eof)
throw new SshException(SshException.ErrorKind.REMOTE_CLOSED,
"Remote closed during ReadUntil. Buffer: [" + sb + "]");
int prev = sb.length();
sb.append(c.data);
int searchFrom = Math.max(0, prev - (maxLen - 1));
for (String p : list) {
if (sb.indexOf(p, searchFrom) >= 0)
return sb.toString();
}
}
}
// -------------------------------------------------------------------------
// State
// -------------------------------------------------------------------------
public boolean IsConnected() {
return session != null && session.isConnected();
}
public boolean IsShellOpen() {
return shell != null && shell.isConnected();
}
// -------------------------------------------------------------------------
// Helpers
// -------------------------------------------------------------------------
private void ensureShell() {
if (shell == null || !shell.isConnected())
throw new SshException(SshException.ErrorKind.INTERNAL,
"Shell not open. Call OpenShell() first.");
}
private void safeCloseShell() {
try {
if (shell != null)
shell.disconnect();
} catch (Exception ignored) {
} finally {
shell = null;
}
try {
if (in != null) in.close();
} catch (Exception ignored) {}
in = null;
out = null;
if (readerThread != null)
readerThread.interrupt();
readerThread = null;
queue.clear();
}
private void safeDisconnect() {
safeCloseShell();
try {
if (session != null)
session.disconnect();
} catch (Exception ignored) {
} finally {
session = null;
}
}
private void startReader(long myGen) {
decoder.reset();
decodeBytes.clear();
decodeChars.clear();
readerThread = new Thread(() -> {
while (true) {
int len;
try {
len = in.read(readerBuf);
} catch (IOException e) {
queue.offer(Chunk.eof(myGen));
break;
}
if (len < 0) {
queue.offer(Chunk.eof(myGen));
break;
}
try {
String s = decodeUtf8(readerBuf, 0, len);
if (!s.isEmpty())
queue.offer(Chunk.data(s, myGen));
} catch (Exception ex) {
queue.offer(Chunk.eof(myGen));
break;
}
}
}, "B4JSSH-Reader");
readerThread.setDaemon(true);
readerThread.start();
}
private String decodeUtf8(byte[] buf, int off, int len) {
if (len > decodeBytes.remaining()) {
throw new SshException(SshException.ErrorKind.INTERNAL,
"UTF-8 decode overflow: len=" + len +
" remaining=" + decodeBytes.remaining());
}
decodeBytes.put(buf, off, len);
decodeBytes.flip();
StringBuilder sb = new StringBuilder(len);
while (true) {
CoderResult r = decoder.decode(decodeBytes, decodeChars, false);
if (r.isError())
throw new SshException(SshException.ErrorKind.INTERNAL,
"Invalid UTF-8 sequence received from remote.");
decodeChars.flip();
if (decodeChars.hasRemaining())
sb.append(decodeChars);
decodeChars.clear();
if (r.isUnderflow())
break;
}
decodeBytes.compact();
return sb.toString();
}
}