EmbeddedPgExtension.java
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package de.softwareforge.testing.postgres.junit5;
import static com.google.common.base.Preconditions.checkNotNull;
import static com.google.common.base.Preconditions.checkState;
import static de.softwareforge.testing.postgres.junit5.EmbeddedPgExtension.TestMode.TESTMODE_KEY;
import de.softwareforge.testing.postgres.embedded.DatabaseInfo;
import de.softwareforge.testing.postgres.embedded.DatabaseManager;
import de.softwareforge.testing.postgres.embedded.DatabaseManager.DatabaseManagerBuilder;
import de.softwareforge.testing.postgres.embedded.EmbeddedPostgres;
import java.sql.SQLException;
import java.util.UUID;
import javax.sql.DataSource;
import com.google.common.annotations.VisibleForTesting;
import org.junit.jupiter.api.extension.AfterAllCallback;
import org.junit.jupiter.api.extension.AfterEachCallback;
import org.junit.jupiter.api.extension.BeforeAllCallback;
import org.junit.jupiter.api.extension.BeforeEachCallback;
import org.junit.jupiter.api.extension.ExtensionContext;
import org.junit.jupiter.api.extension.ExtensionContext.Namespace;
import org.junit.jupiter.api.extension.ExtensionContext.Store;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
public final class EmbeddedPgExtension implements BeforeAllCallback, AfterAllCallback, BeforeEachCallback, AfterEachCallback {
private static final Logger LOG = LoggerFactory.getLogger(EmbeddedPgExtension.class);
// multiple instances must use different namespaces
private final Namespace PG_NAMESPACE = Namespace.create(UUID.randomUUID());
private final DatabaseManager.Builder<DatabaseManager> databaseManagerBuilder;
private volatile DatabaseManager databaseManager = null;
private EmbeddedPgExtension(DatabaseManager.Builder<DatabaseManager> databaseManagerBuilder) {
this.databaseManagerBuilder = databaseManagerBuilder;
}
public static EmbeddedPgExtensionBuilder multiDatabase() {
return new EmbeddedPgExtensionBuilder(true);
}
public static EmbeddedPgExtensionBuilder singleDatabase() {
return new EmbeddedPgExtensionBuilder(false);
}
/**
* Returns the data source for the current instance.
*/
public DataSource createDataSource() throws SQLException {
return createDatabaseInfo().asDataSource();
}
@VisibleForTesting
EmbeddedPostgres getEmbeddedPostgres() {
return databaseManager.getEmbeddedPostgres();
}
/**
* Returns a {@link DatabaseInfo} describing the database connection.
*/
public DatabaseInfo createDatabaseInfo() throws SQLException {
checkState(databaseManager != null, "no before method has been called!");
DatabaseInfo databaseInfo = databaseManager.getDatabaseInfo();
if (databaseInfo.exception().isEmpty()) {
LOG.info("Connection to {}", databaseInfo.asJdbcUrl());
}
return databaseInfo;
}
@Override
public void beforeAll(ExtensionContext extensionContext) throws Exception {
checkNotNull(extensionContext, "extensionContext is null");
Store pgStore = extensionContext.getStore(PG_NAMESPACE);
TestMode testMode = pgStore.getOrComputeIfAbsent(TESTMODE_KEY,
k -> new TestMode(extensionContext.getUniqueId(), databaseManagerBuilder.build()),
TestMode.class);
this.databaseManager = testMode.start(extensionContext.getUniqueId());
}
@Override
public void afterAll(ExtensionContext extensionContext) throws Exception {
checkNotNull(extensionContext, "extensionContext is null");
Store pgStore = extensionContext.getStore(PG_NAMESPACE);
TestMode testMode = pgStore.get(TESTMODE_KEY, TestMode.class);
if (testMode != null) {
this.databaseManager = testMode.stop(extensionContext.getUniqueId());
}
}
@Override
public void beforeEach(ExtensionContext extensionContext) throws Exception {
checkNotNull(extensionContext, "extensionContext is null");
Store pgStore = extensionContext.getStore(PG_NAMESPACE);
TestMode testMode = pgStore.getOrComputeIfAbsent(TESTMODE_KEY,
k -> new TestMode(extensionContext.getUniqueId(), databaseManagerBuilder.build()),
TestMode.class);
this.databaseManager = testMode.start(extensionContext.getUniqueId());
}
@Override
public void afterEach(ExtensionContext extensionContext) throws Exception {
checkNotNull(extensionContext, "extensionContext is null");
Store pgStore = extensionContext.getStore(PG_NAMESPACE);
TestMode testMode = pgStore.get(TESTMODE_KEY, TestMode.class);
if (testMode != null) {
this.databaseManager = testMode.stop(extensionContext.getUniqueId());
}
}
static class EmbeddedPgExtensionBuilder extends DatabaseManager.Builder<EmbeddedPgExtension> {
private EmbeddedPgExtensionBuilder(boolean multiMode) {
super(multiMode);
}
@Override
public EmbeddedPgExtension build() {
DatabaseManager.Builder<DatabaseManager> databaseManagerBuilder = new DatabaseManagerBuilder(multiMode)
.withPreparer(databasePreparer);
customizers.build().forEach(databaseManagerBuilder::withCustomizer);
return new EmbeddedPgExtension(databaseManagerBuilder);
}
}
static final class TestMode {
static final Object TESTMODE_KEY = new Object();
private final String id;
private final DatabaseManager databaseManager;
private TestMode(String id, DatabaseManager databaseManager) {
this.id = id;
this.databaseManager = databaseManager;
}
public DatabaseManager start(String id) throws Exception {
if (this.id.equals(id)) {
databaseManager.start();
}
return databaseManager;
}
public DatabaseManager stop(String id) throws Exception {
if (this.id.equals(id)) {
databaseManager.close();
return null;
}
return databaseManager;
}
}
}