/*
 * Decompiled with CFR 0.152.
 */
package org.dbunit.database;

import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.dbunit.database.IDatabaseConnection;
import org.dbunit.database.PrimaryKeyFilteredTableWrapper;
import org.dbunit.database.search.ForeignKeyRelationshipEdge;
import org.dbunit.dataset.DataSetException;
import org.dbunit.dataset.IDataSet;
import org.dbunit.dataset.ITable;
import org.dbunit.dataset.ITableIterator;
import org.dbunit.dataset.ITableMetaData;
import org.dbunit.dataset.filter.AbstractTableFilter;
import org.dbunit.util.SQLHelper;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Table filter that restricts a data set to the rows reachable from a seed set of primary
 * keys by following foreign-key relationships.
 *
 * <p>Tables are registered through {@link #nodeAdded(Object)} and foreign-key relationships
 * through {@link #edgeAdded(ForeignKeyRelationshipEdge)}. The first call to
 * {@link #iterator(IDataSet, boolean)} runs the search, which queries the database repeatedly.
 * For every pending primary key, it reads the row's foreign-key columns and schedules the
 * referenced keys. When reverse scanning is enabled, it also schedules the primary keys of
 * rows whose foreign key equals a scanned key. The search stops once no keys are pending.
 *
 * <p>If the caller's input map holds a non-empty key set for a table, only keys from that set
 * are kept for that table. Each iterated table that collected keys is wrapped in a
 * {@link PrimaryKeyFilteredTableWrapper} built from those keys.
 *
 * <p>Only the first primary-key column of each table is used. Instances keep mutable search
 * state in unsynchronized collections.
 */
public class PrimaryKeyFilter extends AbstractTableFilter {
    /** Connection used to run the foreign-key lookup queries. */
    private final IDatabaseConnection connection;
    /** Keys accepted so far, per table name; filled in by the search. */
    private final Map<String, Set<Object>> allowedPKsPerTable;
    /** Caller-supplied map (table name -> key set); a non-empty set limits what is kept. */
    private final Map<?, ?> allowedPKsInput;
    /** Keys still waiting to be scanned, per table name. */
    private final Map<Object, Set<Object>> pksToScanPerTable;
    /** Whether rows that reference a scanned key are followed as well. */
    private final boolean reverseScan;
    protected final Logger logger = LoggerFactory.getLogger(this.getClass());
    /** Cached primary-key column name per table. */
    private final Map<String, String> pkColumnPerTable = new HashMap<String, String>();
    /** Outgoing foreign-key edges, keyed by the referencing ("from") table. */
    private final Map<String, Set<ForeignKeyRelationshipEdge>> fkEdgesPerTable =
            new HashMap<String, Set<ForeignKeyRelationshipEdge>>();
    /** Incoming foreign-key edges, keyed by the referenced ("to") table. */
    private final Map<String, Set<ForeignKeyRelationshipEdge>> fkReverseEdgesPerTable =
            new HashMap<String, Set<ForeignKeyRelationshipEdge>>();
    /** Table names in the order they were registered through nodeAdded(). */
    private final List<Object> tableNames = new ArrayList<Object>();

    /**
     * Creates a filter seeded with the given primary keys.
     *
     * @param connection connection used to query foreign-key columns
     * @param allowedPKs map from table name to a {@code Set} of primary-key values. The sets
     *        seed the search. A non-empty set also restricts which keys are kept for that
     *        table. The map itself is retained (not copied); each seed set is copied.
     * @param reverseDependency when {@code true}, rows whose foreign key references a scanned
     *        key are followed too
     */
    public PrimaryKeyFilter(IDatabaseConnection connection, Map allowedPKs, boolean reverseDependency) {
        this.connection = connection;
        this.allowedPKsPerTable = new HashMap<String, Set<Object>>();
        this.allowedPKsInput = allowedPKs;
        this.reverseScan = reverseDependency;
        this.pksToScanPerTable = new HashMap<Object, Set<Object>>(allowedPKs.size());
        for (Object o : allowedPKs.entrySet()) {
            Map.Entry<?, ?> entry = (Map.Entry<?, ?>) o;
            Set<?> inputSet = (Set<?>) entry.getValue();
            // Copy the seed set: pending keys are removed as they are scanned.
            this.pksToScanPerTable.put(entry.getKey(), new HashSet<Object>(inputSet));
        }
    }

    /**
     * Registers a table to take part in the search.
     *
     * @param node the table name (treated as a {@code String} by the search)
     */
    public void nodeAdded(Object node) {
        this.logger.debug("nodeAdded(node={}) - start", node);
        this.tableNames.add(node);
        if (this.logger.isDebugEnabled()) {
            this.logger.debug("nodeAdded: " + node);
        }
    }

    /**
     * Registers a foreign-key relationship in both directions. It also records the referenced
     * table's primary-key column if none is cached yet.
     *
     * @param edge the relationship; its from/to ends are table names
     */
    public void edgeAdded(ForeignKeyRelationshipEdge edge) {
        if (this.logger.isDebugEnabled()) {
            this.logger.debug("edgeAdded: " + edge);
        }
        String from = (String) edge.getFrom();
        addEdge(this.fkEdgesPerTable, from, edge);
        String to = (String) edge.getTo();
        addEdge(this.fkReverseEdgesPerTable, to, edge);
        // Cache the referenced table's PK column so getPKColumn() can skip the metadata lookup.
        if (this.pkColumnPerTable.get(to) == null) {
            this.pkColumnPerTable.put(to, edge.getPKColumn());
        }
    }

    /** Adds {@code edge} to the set stored under {@code table}, creating the set if needed. */
    private static void addEdge(Map<String, Set<ForeignKeyRelationshipEdge>> edgesPerTable,
            String table, ForeignKeyRelationshipEdge edge) {
        Set<ForeignKeyRelationshipEdge> edges = edgesPerTable.get(table);
        if (edges == null) {
            edges = new HashSet<ForeignKeyRelationshipEdge>();
            edgesPerTable.put(table, edges);
        }
        edges.add(edge);
    }

    /**
     * Accepts every table name; row-level filtering happens in the returned iterator.
     *
     * @return always {@code true}
     */
    public boolean isValidName(String tableName) throws DataSetException {
        this.logger.debug("isValidName(tableName={}) - start", tableName);
        return true;
    }

    /**
     * Runs the primary-key search and returns an iterator over the accepted tables.
     *
     * <p>The search drains the pending keys, so later calls find nothing left to scan.
     *
     * @param dataSet the data set to iterate; also supplies table metadata for the search
     * @param reversed whether to iterate the data set in reverse order
     * @return a filtering iterator over {@code dataSet}
     * @throws DataSetException if a lookup query fails (the {@link SQLException} is wrapped)
     *         or table metadata cannot be read
     */
    public ITableIterator iterator(IDataSet dataSet, boolean reversed) throws DataSetException {
        if (this.logger.isDebugEnabled()) {
            this.logger.debug("Filter.iterator()");
        }
        try {
            this.searchPKs(dataSet);
        } catch (SQLException e) {
            this.logger.error("iterator()", e);
            throw new DataSetException(e);
        }
        ITableIterator delegate = reversed ? dataSet.reverseIterator() : dataSet.iterator();
        return new FilterIterator(delegate);
    }

    /**
     * Repeatedly scans pending keys of the registered tables until none remain.
     * Each pass visits the tables from last-registered to first.
     */
    private void searchPKs(IDataSet dataSet) throws DataSetException, SQLException {
        this.logger.debug("searchPKs(dataSet={}) - start", dataSet);
        int counter = 0;
        while (!this.pksToScanPerTable.isEmpty()) {
            ++counter;
            if (this.logger.isDebugEnabled()) {
                this.logger.debug("RUN # " + counter);
            }
            for (int i = this.tableNames.size() - 1; i >= 0; --i) {
                String tableName = (String) this.tableNames.get(i);
                Set<Object> tmpSet = this.pksToScanPerTable.get(tableName);
                if (tmpSet == null || tmpSet.isEmpty()) {
                    continue;
                }
                // Resolve the PK column only for tables that actually have work. Use the
                // metadata call instead of getTable(), which may load the table's rows.
                ITableMetaData metaData = dataSet.getTableMetaData(tableName);
                if (metaData.getPrimaryKeys().length == 0) {
                    throw new DataSetException("Table " + tableName
                            + " has no primary key; cannot filter it by primary key");
                }
                String pkColumn = metaData.getPrimaryKeys()[0].getColumnName();
                // Scan a copy: the scans may add keys to this table's own pending set
                // (e.g. self-referencing foreign keys).
                Set<Object> pksToScan = new HashSet<Object>(tmpSet);
                if (this.logger.isDebugEnabled()) {
                    this.logger.debug("before search: " + tableName + "=>" + pksToScan);
                }
                this.scanPKs(tableName, pkColumn, pksToScan);
                this.scanReversePKs(tableName, pksToScan);
                this.allowPKs(tableName, pksToScan);
                this.removePKsToScan(tableName, pksToScan);
            }
            this.removeScannedTables();
        }
        if (this.logger.isDebugEnabled()) {
            this.logger.debug("Finished searchIds()");
        }
    }

    /**
     * Drops pending entries that are exhausted, and entries for tables never registered
     * through nodeAdded(). A warning is logged for each unregistered table.
     */
    private void removeScannedTables() {
        this.logger.debug("removeScannedTables() - start");
        Iterator<Map.Entry<Object, Set<Object>>> it = this.pksToScanPerTable.entrySet().iterator();
        while (it.hasNext()) {
            Map.Entry<Object, Set<Object>> entry = it.next();
            String table = (String) entry.getKey();
            Set<Object> pksToScan = entry.getValue();
            boolean removeIt = pksToScan.isEmpty();
            if (!this.tableNames.contains(table)) {
                this.logger.warn("Discarding ids {} of table {} as this table has not been passed as input",
                        pksToScan, table);
                removeIt = true;
            }
            if (removeIt) {
                it.remove();
            }
        }
    }

    /**
     * Adds scanned keys to the table's accepted set. If the caller supplied a non-empty key
     * set for the table, only keys contained in it are accepted.
     */
    private void allowPKs(String table, Set<Object> newAllowedPKs) {
        this.logger.debug("allowPKs(table={}, newAllowedPKs={}) - start", table, newAllowedPKs);
        Set<Object> currentAllowedIds = this.allowedPKsPerTable.get(table);
        if (currentAllowedIds == null) {
            currentAllowedIds = new HashSet<Object>();
            this.allowedPKsPerTable.put(table, currentAllowedIds);
        }
        Set<?> forcedAllowedPKs = (Set<?>) this.allowedPKsInput.get(table);
        if (forcedAllowedPKs == null || forcedAllowedPKs.isEmpty()) {
            currentAllowedIds.addAll(newAllowedPKs);
            return;
        }
        for (Object id : newAllowedPKs) {
            if (forcedAllowedPKs.contains(id)) {
                currentAllowedIds.add(id);
            } else if (this.logger.isDebugEnabled()) {
                this.logger.debug("Discarding id " + id + " of table " + table + " as it was not included in the input!");
            }
        }
    }

    /**
     * For each key in {@code allowedIds}, reads the row's outgoing foreign-key columns and
     * schedules every non-null referenced key for scanning.
     *
     * @throws SQLException if preparing or running the query fails (logged, then rethrown)
     */
    private void scanPKs(String table, String pkColumn, Set<Object> allowedIds) throws SQLException {
        if (this.logger.isDebugEnabled()) {
            this.logger.debug("scanPKs(table=" + table + ", pkColumn=" + pkColumn + ", allowedIds=" + allowedIds + ") - start");
        }
        Set<ForeignKeyRelationshipEdge> fkEdges = this.fkEdgesPerTable.get(table);
        if (fkEdges == null || fkEdges.isEmpty()) {
            return;
        }
        // Result column i+1 holds the foreign key that points at fkTables.get(i).
        List<String> fkTables = new ArrayList<String>(fkEdges.size());
        StringBuilder cols = new StringBuilder();
        for (ForeignKeyRelationshipEdge edge : fkEdges) {
            if (!fkTables.isEmpty()) {
                cols.append(", ");
            }
            fkTables.add((String) edge.getTo());
            cols.append(edge.getFKColumn());
        }
        String sql = "SELECT " + cols + " FROM " + table + " WHERE " + pkColumn + " = ? ";
        if (this.logger.isDebugEnabled()) {
            this.logger.debug("SQL: " + sql);
        }
        PreparedStatement pstmt = null;
        ResultSet rs = null;
        try {
            pstmt = this.connection.getConnection().prepareStatement(sql);
            for (Object pk : allowedIds) {
                if (this.logger.isDebugEnabled()) {
                    this.logger.debug("Executing sql for ? = " + pk);
                }
                pstmt.setObject(1, pk);
                rs = pstmt.executeQuery();
                while (rs.next()) {
                    for (int i = 0; i < fkTables.size(); ++i) {
                        String newTable = fkTables.get(i);
                        Object fk = rs.getObject(i + 1);
                        if (fk == null) {
                            this.logger.warn("Found null FK for relationship  " + table + "=>" + newTable);
                            continue;
                        }
                        if (this.logger.isDebugEnabled()) {
                            this.logger.debug("New ID: " + newTable + "->" + fk);
                        }
                        this.addPKToScan(newTable, fk);
                    }
                }
                // Release this key's result set before the statement is executed again.
                rs.close();
                rs = null;
            }
        } catch (SQLException e) {
            this.logger.error("scanPKs()", e);
            throw e;
        } finally {
            SQLHelper.close(rs, pstmt);
        }
    }

    /**
     * When reverse scanning is enabled, follows every incoming foreign-key edge of
     * {@code table} for the given keys.
     */
    private void scanReversePKs(String table, Set<Object> pksToScan) throws SQLException {
        this.logger.debug("scanReversePKs(table={}, pksToScan={}) - start", table, pksToScan);
        if (!this.reverseScan) {
            return;
        }
        Set<ForeignKeyRelationshipEdge> fkReverseEdges = this.fkReverseEdgesPerTable.get(table);
        if (fkReverseEdges == null || fkReverseEdges.isEmpty()) {
            return;
        }
        for (ForeignKeyRelationshipEdge edge : fkReverseEdges) {
            this.addReverseEdge(edge, pksToScan);
        }
    }

    /**
     * Schedules the primary keys of rows in the edge's referencing table whose foreign-key
     * column equals one of {@code idsToScan}.
     *
     * @throws SQLException if preparing or running the query fails (logged, then rethrown)
     */
    private void addReverseEdge(ForeignKeyRelationshipEdge edge, Set<Object> idsToScan) throws SQLException {
        this.logger.debug("addReverseEdge(edge={}, idsToScan={}) - start", edge, idsToScan);
        String fkTable = (String) edge.getFrom();
        String fkColumn = edge.getFKColumn();
        String pkColumn = this.getPKColumn(fkTable);
        String sql = "SELECT " + pkColumn + " FROM " + fkTable + " WHERE " + fkColumn + " = ? ";
        PreparedStatement pstmt = null;
        ResultSet rs = null;
        try {
            if (this.logger.isDebugEnabled()) {
                this.logger.debug("Preparing SQL query '" + sql + "'");
            }
            pstmt = this.connection.getConnection().prepareStatement(sql);
            for (Object pk : idsToScan) {
                if (this.logger.isDebugEnabled()) {
                    this.logger.debug("executing query '" + sql + "' for ? = " + pk);
                }
                pstmt.setObject(1, pk);
                rs = pstmt.executeQuery();
                while (rs.next()) {
                    this.addPKToScan(fkTable, rs.getObject(1));
                }
                // Release this key's result set before the statement is executed again.
                rs.close();
                rs = null;
            }
        } catch (SQLException e) {
            this.logger.error("addReverseEdge()", e);
            throw e;
        } finally {
            SQLHelper.close(rs, pstmt);
        }
    }

    /**
     * Returns the cached primary-key column of {@code table}. On a cache miss it asks
     * {@link SQLHelper#getPrimaryKeyColumn} and caches the answer.
     */
    private String getPKColumn(String table) throws SQLException {
        this.logger.debug("getPKColumn(table={}) - start", table);
        String pkColumn = this.pkColumnPerTable.get(table);
        if (pkColumn == null) {
            pkColumn = SQLHelper.getPrimaryKeyColumn(this.connection.getConnection(), table);
            this.pkColumnPerTable.put(table, pkColumn);
        }
        return pkColumn;
    }

    /**
     * Removes {@code ids} from the table's pending set.
     *
     * @throws RuntimeException if {@code ids} is the pending set itself (internal misuse)
     */
    private void removePKsToScan(String table, Set<Object> ids) {
        this.logger.debug("removePKsToScan(table={}, ids={}) - start", table, ids);
        Set<Object> pksToScan = this.pksToScanPerTable.get(table);
        if (pksToScan == null) {
            return;
        }
        if (pksToScan == ids) {
            throw new RuntimeException("INTERNAL ERROR on removeIdsToScan() for table " + table);
        }
        pksToScan.removeAll(ids);
    }

    /** Schedules {@code pk} of {@code table} for scanning unless it was already accepted. */
    private void addPKToScan(String table, Object pk) {
        this.logger.debug("addPKToScan(table={}, pk={}) - start", table, pk);
        Set<Object> scannedIds = this.allowedPKsPerTable.get(table);
        if (scannedIds != null && scannedIds.contains(pk)) {
            if (this.logger.isDebugEnabled()) {
                this.logger.debug("Discarding already scanned id=" + pk + " for table " + table);
            }
            return;
        }
        Set<Object> pksToScan = this.pksToScanPerTable.get(table);
        if (pksToScan == null) {
            pksToScan = new HashSet<Object>();
            this.pksToScanPerTable.put(table, pksToScan);
        }
        pksToScan.add(pk);
    }

    /**
     * Iterator that skips tables rejected by {@code accept()}. Each table that collected keys
     * is wrapped in a {@link PrimaryKeyFilteredTableWrapper} built from those keys.
     */
    private class FilterIterator implements ITableIterator {
        private final Logger logger = LoggerFactory.getLogger(FilterIterator.class);
        private final ITableIterator _iterator;

        public FilterIterator(ITableIterator iterator) {
            this._iterator = iterator;
        }

        public boolean next() throws DataSetException {
            if (this.logger.isDebugEnabled()) {
                this.logger.debug("Iterator.next()");
            }
            while (this._iterator.next()) {
                if (PrimaryKeyFilter.this.accept(this._iterator.getTableMetaData().getTableName())) {
                    return true;
                }
            }
            return false;
        }

        public ITableMetaData getTableMetaData() throws DataSetException {
            if (this.logger.isDebugEnabled()) {
                this.logger.debug("Iterator.getTableMetaData()");
            }
            return this._iterator.getTableMetaData();
        }

        public ITable getTable() throws DataSetException {
            if (this.logger.isDebugEnabled()) {
                this.logger.debug("Iterator.getTable()");
            }
            ITable table = this._iterator.getTable();
            String tableName = table.getTableMetaData().getTableName();
            Set<Object> allowedPKs = PrimaryKeyFilter.this.allowedPKsPerTable.get(tableName);
            return allowedPKs != null ? new PrimaryKeyFilteredTableWrapper(table, allowedPKs) : table;
        }
    }
}

