Search Options

Results per page
Sort
Preferred Languages
Advanced

Results 1 - 2 of 2 for getBatchDims (0.12 sec)

  1. tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc

    // GatherV2Op
    //===----------------------------------------------------------------------===//
    
    LogicalResult GatherV2Op::verify() {
      GatherV2Op op = *this;
      int64_t batch_dims = op.getBatchDims();
      if (auto ty = mlir::dyn_cast<RankedTensorType>(op.getIndices().getType())) {
        int64_t rank = ty.getRank();
        if (batch_dims > rank || batch_dims < -rank)
          return op.emitOpError()
    Registered: Sun Jun 16 05:45:23 UTC 2024
    - Last Modified: Thu Apr 25 16:01:03 UTC 2024
    - 146.7K bytes
    - Viewed (0)
  2. tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc

      assert(port[0] == 0);
    
      auto params = op.getParams();
      auto params_ty = mlir::dyn_cast<RankedTensorType>(params.getType());
      if (!params_ty || !params_ty.hasStaticShape() || params_ty.getRank() != 1 ||
          op.getBatchDims() != 0) {
        return {};
      }
    
      DenseIntElementsAttr axis;
      if (!matchPattern(op.getAxis(), m_Constant(&axis)) ||
          axis.getNumElements() != 1 ||
          !axis.getSplatValue<llvm::APInt>().isZero()) {
    Registered: Sun Jun 16 05:45:23 UTC 2024
    - Last Modified: Sat Jun 08 07:28:49 UTC 2024
    - 134.1K bytes
    - Viewed (0)
Back to top